pg-sui 1.0.2.1__py3-none-any.whl → 1.6.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (112)
  1. {pg_sui-1.0.2.1.dist-info → pg_sui-1.6.8.dist-info}/METADATA +51 -70
  2. pg_sui-1.6.8.dist-info/RECORD +78 -0
  3. {pg_sui-1.0.2.1.dist-info → pg_sui-1.6.8.dist-info}/WHEEL +1 -1
  4. pg_sui-1.6.8.dist-info/entry_points.txt +4 -0
  5. pg_sui-1.6.8.dist-info/top_level.txt +1 -0
  6. pgsui/__init__.py +35 -54
  7. pgsui/_version.py +34 -0
  8. pgsui/cli.py +635 -0
  9. pgsui/data_processing/config.py +576 -0
  10. pgsui/data_processing/containers.py +1782 -0
  11. pgsui/data_processing/transformers.py +121 -1103
  12. pgsui/electron/app/__main__.py +5 -0
  13. pgsui/electron/app/icons/icons/1024x1024.png +0 -0
  14. pgsui/electron/app/icons/icons/128x128.png +0 -0
  15. pgsui/electron/app/icons/icons/16x16.png +0 -0
  16. pgsui/electron/app/icons/icons/24x24.png +0 -0
  17. pgsui/electron/app/icons/icons/256x256.png +0 -0
  18. pgsui/electron/app/icons/icons/32x32.png +0 -0
  19. pgsui/electron/app/icons/icons/48x48.png +0 -0
  20. pgsui/electron/app/icons/icons/512x512.png +0 -0
  21. pgsui/electron/app/icons/icons/64x64.png +0 -0
  22. pgsui/electron/app/icons/icons/icon.icns +0 -0
  23. pgsui/electron/app/icons/icons/icon.ico +0 -0
  24. pgsui/electron/app/main.js +189 -0
  25. pgsui/electron/app/package-lock.json +6893 -0
  26. pgsui/electron/app/package.json +50 -0
  27. pgsui/electron/app/preload.js +15 -0
  28. pgsui/electron/app/server.py +146 -0
  29. pgsui/electron/app/ui/logo.png +0 -0
  30. pgsui/electron/app/ui/renderer.js +130 -0
  31. pgsui/electron/app/ui/styles.css +59 -0
  32. pgsui/electron/app/ui/ui_shim.js +72 -0
  33. pgsui/electron/bootstrap.py +43 -0
  34. pgsui/electron/launch.py +59 -0
  35. pgsui/electron/package.json +14 -0
  36. pgsui/example_data/popmaps/{test.popmap → phylogen_nomx.popmap} +185 -99
  37. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz +0 -0
  38. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz.tbi +0 -0
  39. pgsui/impute/deterministic/imputers/allele_freq.py +691 -0
  40. pgsui/impute/deterministic/imputers/mode.py +679 -0
  41. pgsui/impute/deterministic/imputers/nmf.py +221 -0
  42. pgsui/impute/deterministic/imputers/phylo.py +971 -0
  43. pgsui/impute/deterministic/imputers/ref_allele.py +530 -0
  44. pgsui/impute/supervised/base.py +339 -0
  45. pgsui/impute/supervised/imputers/hist_gradient_boosting.py +293 -0
  46. pgsui/impute/supervised/imputers/random_forest.py +287 -0
  47. pgsui/impute/unsupervised/base.py +924 -0
  48. pgsui/impute/unsupervised/callbacks.py +89 -263
  49. pgsui/impute/unsupervised/imputers/autoencoder.py +972 -0
  50. pgsui/impute/unsupervised/imputers/nlpca.py +1264 -0
  51. pgsui/impute/unsupervised/imputers/ubp.py +1288 -0
  52. pgsui/impute/unsupervised/imputers/vae.py +957 -0
  53. pgsui/impute/unsupervised/loss_functions.py +158 -0
  54. pgsui/impute/unsupervised/models/autoencoder_model.py +208 -558
  55. pgsui/impute/unsupervised/models/nlpca_model.py +149 -468
  56. pgsui/impute/unsupervised/models/ubp_model.py +198 -1317
  57. pgsui/impute/unsupervised/models/vae_model.py +259 -618
  58. pgsui/impute/unsupervised/nn_scorers.py +215 -0
  59. pgsui/utils/classification_viz.py +591 -0
  60. pgsui/utils/misc.py +35 -480
  61. pgsui/utils/plotting.py +514 -824
  62. pgsui/utils/scorers.py +212 -438
  63. pg_sui-1.0.2.1.dist-info/RECORD +0 -75
  64. pg_sui-1.0.2.1.dist-info/top_level.txt +0 -3
  65. pgsui/example_data/phylip_files/test_n10.phy +0 -118
  66. pgsui/example_data/phylip_files/test_n100.phy +0 -118
  67. pgsui/example_data/phylip_files/test_n2.phy +0 -118
  68. pgsui/example_data/phylip_files/test_n500.phy +0 -118
  69. pgsui/example_data/structure_files/test.nopops.1row.10sites.str +0 -117
  70. pgsui/example_data/structure_files/test.nopops.2row.100sites.str +0 -234
  71. pgsui/example_data/structure_files/test.nopops.2row.10sites.str +0 -234
  72. pgsui/example_data/structure_files/test.nopops.2row.30sites.str +0 -234
  73. pgsui/example_data/structure_files/test.nopops.2row.allsites.str +0 -234
  74. pgsui/example_data/structure_files/test.pops.1row.10sites.str +0 -117
  75. pgsui/example_data/structure_files/test.pops.2row.10sites.str +0 -234
  76. pgsui/example_data/trees/test.iqtree +0 -376
  77. pgsui/example_data/trees/test.qmat +0 -5
  78. pgsui/example_data/trees/test.rate +0 -2033
  79. pgsui/example_data/trees/test.tre +0 -1
  80. pgsui/example_data/trees/test_n10.rate +0 -19
  81. pgsui/example_data/trees/test_n100.rate +0 -109
  82. pgsui/example_data/trees/test_n500.rate +0 -509
  83. pgsui/example_data/trees/test_siterates.txt +0 -2024
  84. pgsui/example_data/trees/test_siterates_n10.txt +0 -10
  85. pgsui/example_data/trees/test_siterates_n100.txt +0 -100
  86. pgsui/example_data/trees/test_siterates_n500.txt +0 -500
  87. pgsui/example_data/vcf_files/test.vcf +0 -244
  88. pgsui/example_data/vcf_files/test.vcf.gz +0 -0
  89. pgsui/example_data/vcf_files/test.vcf.gz.tbi +0 -0
  90. pgsui/impute/estimators.py +0 -735
  91. pgsui/impute/impute.py +0 -1486
  92. pgsui/impute/simple_imputers.py +0 -1439
  93. pgsui/impute/supervised/iterative_imputer_fixedparams.py +0 -785
  94. pgsui/impute/supervised/iterative_imputer_gridsearch.py +0 -1027
  95. pgsui/impute/unsupervised/keras_classifiers.py +0 -702
  96. pgsui/impute/unsupervised/models/in_development/cnn_model.py +0 -486
  97. pgsui/impute/unsupervised/neural_network_imputers.py +0 -1424
  98. pgsui/impute/unsupervised/neural_network_methods.py +0 -1549
  99. pgsui/pg_sui.py +0 -261
  100. pgsui/utils/sequence_tools.py +0 -407
  101. simulation/sim_benchmarks.py +0 -333
  102. simulation/sim_treeparams.py +0 -475
  103. test/__init__.py +0 -0
  104. test/pg_sui_simtest.py +0 -215
  105. test/pg_sui_testing.py +0 -523
  106. test/test.py +0 -297
  107. test/test_pgsui.py +0 -374
  108. test/test_tkc.py +0 -214
  109. {pg_sui-1.0.2.1.dist-info → pg_sui-1.6.8.dist-info/licenses}/LICENSE +0 -0
  110. /pgsui/{example_data/trees → electron/app}/__init__.py +0 -0
  111. /pgsui/impute/{unsupervised/models/in_development → supervised/imputers}/__init__.py +0 -0
  112. {simulation → pgsui/impute/unsupervised/imputers}/__init__.py +0 -0
pgsui/impute/unsupervised/imputers/ubp.py (new file)
@@ -0,0 +1,1288 @@
1
+ import copy
2
+ import warnings
3
+ from typing import TYPE_CHECKING, Any, Dict, Literal, Tuple
4
+
5
+ import numpy as np
6
+ import optuna
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from sklearn.decomposition import PCA
10
+ from sklearn.exceptions import NotFittedError
11
+ from sklearn.model_selection import train_test_split
12
+ from snpio.analysis.genotype_encoder import GenotypeEncoder
13
+ from snpio.utils.logging import LoggerManager
14
+ from torch.optim.lr_scheduler import CosineAnnealingLR
15
+
16
+ from pgsui.data_processing.config import apply_dot_overrides, load_yaml_to_dataclass
17
+ from pgsui.data_processing.containers import UBPConfig
18
+ from pgsui.impute.unsupervised.base import BaseNNImputer
19
+ from pgsui.impute.unsupervised.callbacks import EarlyStopping
20
+ from pgsui.impute.unsupervised.models.ubp_model import UBPModel
21
+
22
+ if TYPE_CHECKING:
23
+ from snpio.read_input.genotype_data import GenotypeData
24
+
25
+
26
+ def ensure_ubp_config(config: UBPConfig | dict | str | None) -> UBPConfig:
27
+ """Return a concrete UBPConfig from dataclass, dict, YAML path, or None.
28
+
29
+ This function normalizes the input configuration for the UBP imputer. If None is provided, it returns a default UBPConfig instance. If a YAML path is given, the configuration is loaded from the file, with support for a top-level preset. If a dictionary is provided, nested sections are flattened into dot keys and applied as overrides to a base configuration, which may itself be built from a preset if one is specified. The result is always a fully populated UBPConfig instance.
30
+
31
+ Args:
32
+ config: UBPConfig | dict | YAML path | None.
33
+
34
+ Returns:
35
+ UBPConfig: Normalized configuration instance.
36
+ """
37
+ if config is None:
38
+ return UBPConfig()
39
+ if isinstance(config, UBPConfig):
40
+ return config
41
+ if isinstance(config, str):
42
+ # YAML path — support top-level `preset`
43
+ return load_yaml_to_dataclass(
44
+ config,
45
+ UBPConfig,
46
+ preset_builder=UBPConfig.from_preset,
47
+ )
48
+ if isinstance(config, dict):
49
+ base = UBPConfig()
50
+
51
+ def _flatten(prefix: str, d: dict, out: dict) -> dict:
52
+ for k, v in d.items():
53
+ kk = f"{prefix}.{k}" if prefix else k
54
+ if isinstance(v, dict):
55
+ _flatten(kk, v, out)
56
+ else:
57
+ out[kk] = v
58
+ return out
59
+
60
+ preset_name = config.pop("preset", None)
61
+ if "io" in config and isinstance(config["io"], dict):
62
+ preset_name = preset_name or config["io"].pop("preset", None)
63
+ if preset_name:
64
+ base = UBPConfig.from_preset(preset_name)
65
+
66
+ flat = _flatten("", config, {})
67
+ return apply_dot_overrides(base, flat)
68
+
69
+ raise TypeError("config must be a UBPConfig, dict, YAML path, or None.")
70
+
71
+
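For illustration only (not part of the diff): a minimal sketch of how the dict and override paths above behave. The field names (model.latent_dim, train.learning_rate, io.prefix, tune.enabled) are taken from the attribute accesses later in this file.

    # Nested dict sections are flattened to dot keys and applied over a default UBPConfig.
    cfg = ensure_ubp_config({"model": {"latent_dim": 8}, "train": {"learning_rate": 1e-3}})
    # Flat dot-key overrides can then be layered on top, as ImputeUBP.__init__ does.
    cfg = apply_dot_overrides(cfg, {"io.prefix": "run1", "tune.enabled": False})
    print(cfg.model.latent_dim, cfg.train.learning_rate, cfg.io.prefix)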
72
+ class ImputeUBP(BaseNNImputer):
73
+ """UBP imputer for 0/1/2 genotypes with three-phase training.
74
+
75
+ This imputer uses a three-phase training schedule specific to the UBP model:
76
+
77
+ 1. Pre-training: Train the model on the full dataset with a small learning rate.
78
+ 2. Fine-tuning: Train the model on the full dataset with a larger learning rate.
79
+ 3. Evaluation: Evaluate the model on the test set. Optimize latents for test set. Predict 0/1/2. Decode to IUPAC. Plot & report.
80
+ 4. Post-processing: Apply any necessary post-processing steps to the imputed genotypes.
81
+
82
+
83
+ References:
84
+ - Gashler, Michael S., Smith, Michael R., Morris, R., and Martinez, T. (2016) Missing Value Imputation with Unsupervised Backpropagation. Computational Intelligence, 32: 196-215. doi: 10.1111/coin.12048.
85
+ """
86
+
87
+ def __init__(
88
+ self,
89
+ genotype_data: "GenotypeData",
90
+ *,
91
+ config: UBPConfig | dict | str | None = None,
92
+ overrides: dict | None = None,
93
+ ):
94
+ """Initialize the UBP imputer via dataclass/dict/YAML config with overrides.
95
+
96
+ This constructor allows for flexible initialization of the UBP imputer by accepting various forms of configuration input. It ensures that the configuration is properly normalized and any specified overrides are applied. The method also sets up logging and initializes various attributes related to the model, training, tuning, and evaluation based on the provided configuration.
97
+
98
+ Args:
99
+ genotype_data (GenotypeData): Backing genotype data object.
100
+ config (UBPConfig | dict | str | None): UBP configuration.
101
+ overrides (dict | None): Flat dot-key overrides applied after `config`.
102
+ """
103
+ self.model_name = "ImputeUBP"
104
+ self.genotype_data = genotype_data
105
+
106
+ # ---- normalize config, then apply overrides ----
107
+ cfg = ensure_ubp_config(config)
108
+ if overrides:
109
+ cfg = apply_dot_overrides(cfg, overrides)
110
+ self.cfg = cfg
111
+
112
+ # ---- logging ----
113
+ logman = LoggerManager(
114
+ __name__,
115
+ prefix=self.cfg.io.prefix,
116
+ debug=self.cfg.io.debug,
117
+ verbose=self.cfg.io.verbose,
118
+ )
119
+ self.logger = logman.get_logger()
120
+
121
+ # ---- Base init ----
122
+ super().__init__(
123
+ prefix=self.cfg.io.prefix,
124
+ device=self.cfg.train.device,
125
+ verbose=self.cfg.io.verbose,
126
+ debug=self.cfg.io.debug,
127
+ )
128
+
129
+ # ---- model/meta ----
130
+ self.Model = UBPModel
131
+ self.pgenc = GenotypeEncoder(genotype_data)
132
+
133
+ self.seed = self.cfg.io.seed
134
+ self.n_jobs = self.cfg.io.n_jobs
135
+ self.prefix = self.cfg.io.prefix
136
+ self.scoring_averaging = self.cfg.io.scoring_averaging
137
+ self.verbose = self.cfg.io.verbose
138
+ self.debug = self.cfg.io.debug
139
+ self.rng = np.random.default_rng(self.seed)
140
+
141
+ # ---- model hyperparams ----
142
+ self.latent_dim = self.cfg.model.latent_dim
143
+ self.dropout_rate = self.cfg.model.dropout_rate
144
+ self.num_hidden_layers = self.cfg.model.num_hidden_layers
145
+ self.layer_scaling_factor = self.cfg.model.layer_scaling_factor
146
+ self.layer_schedule = self.cfg.model.layer_schedule
147
+ self.latent_init = self.cfg.model.latent_init
148
+ self.activation = self.cfg.model.hidden_activation
149
+ self.gamma = self.cfg.model.gamma
150
+
151
+ # ---- training ----
152
+ self.batch_size = self.cfg.train.batch_size
153
+ self.learning_rate = self.cfg.train.learning_rate
154
+ self.lr_input_factor = self.cfg.train.lr_input_factor
155
+ self.l1_penalty = self.cfg.train.l1_penalty
156
+ self.early_stop_gen = self.cfg.train.early_stop_gen
157
+ self.min_epochs = self.cfg.train.min_epochs
158
+ self.epochs = self.cfg.train.max_epochs
159
+ self.validation_split = self.cfg.train.validation_split
160
+ self.beta = self.cfg.train.weights_beta
161
+ self.max_ratio = self.cfg.train.weights_max_ratio
162
+
163
+ # ---- tuning ----
164
+ self.tune = self.cfg.tune.enabled
165
+ self.tune_fast = self.cfg.tune.fast
166
+ self.tune_proxy_metric_batch = self.cfg.tune.proxy_metric_batch
167
+ self.tune_batch_size = self.cfg.tune.batch_size
168
+ self.tune_epochs = self.cfg.tune.epochs
169
+ self.tune_eval_interval = self.cfg.tune.eval_interval
170
+ self.tune_metric = self.cfg.tune.metric
171
+ self.n_trials = self.cfg.tune.n_trials
172
+ self.tune_save_db = self.cfg.tune.save_db
173
+ self.tune_resume = self.cfg.tune.resume
174
+ self.tune_max_samples = self.cfg.tune.max_samples
175
+ self.tune_max_loci = self.cfg.tune.max_loci
176
+ self.tune_infer_epochs = getattr(self.cfg.tune, "infer_epochs", 100)
177
+ self.tune_patience = self.cfg.tune.patience
178
+
179
+ # ---- evaluation ----
180
+ self.eval_latent_steps = self.cfg.evaluate.eval_latent_steps
181
+ self.eval_latent_lr = self.cfg.evaluate.eval_latent_lr
182
+ self.eval_latent_weight_decay = self.cfg.evaluate.eval_latent_weight_decay
183
+
184
+ # ---- plotting ----
185
+ self.plot_format = self.cfg.plot.fmt
186
+ self.plot_dpi = self.cfg.plot.dpi
187
+ self.plot_fontsize = self.cfg.plot.fontsize
188
+ self.title_fontsize = self.cfg.plot.fontsize
189
+ self.despine = self.cfg.plot.despine
190
+ self.show_plots = self.cfg.plot.show
191
+
192
+ # ---- core runtime ----
193
+ self.is_haploid = None
194
+ self.num_classes_ = None
195
+ self.model_params: Dict[str, Any] = {}
196
+
197
+ def fit(self) -> "ImputeUBP":
198
+ """Fit the UBP decoder on 0/1/2 encodings (missing = -1). Three phases.
199
+
200
+ 1. Pre-training: Train the model on the full dataset with a small learning rate.
201
+ 2. Fine-tuning: Train the model on the full dataset with a larger learning rate.
202
+ 3. Evaluation: Evaluate the model on the test set. Optimize latents for test set. Predict 0/1/2. Decode to IUPAC. Plot & report.
203
+ 4. Post-processing: Apply any necessary post-processing steps to the imputed genotypes.
204
+
205
+ Returns:
206
+ ImputeUBP: Fitted instance.
207
+
208
+ Raises:
209
+ NotFittedError: If training fails.
210
+ """
211
+ self.logger.info(f"Fitting {self.model_name} model...")
212
+
213
+ # --- Use 0/1/2 with -1 for missing ---
214
+ X = self.pgenc.genotypes_012.astype(np.float32)
215
+ X[X < 0] = np.nan
216
+ X[np.isnan(X)] = -1
217
+ self.ground_truth_ = X.astype(np.int64)
218
+
219
+ # --- Determine ploidy (haploid vs diploid) and classes ---
220
+ self.is_haploid = np.all(
221
+ np.isin(
222
+ self.genotype_data.snp_data, ["A", "C", "G", "T", "N", "-", ".", "?"]
223
+ )
224
+ )
225
+ self.ploidy = 1 if self.is_haploid else 2
226
+
227
+ if self.is_haploid:
228
+ self.num_classes_ = 2
229
+ self.ground_truth_[self.ground_truth_ == 2] = 1
230
+ self.logger.info("Haploid data detected. Using 2 classes (REF=0, ALT=1).")
231
+ else:
232
+ self.num_classes_ = 3
233
+ self.logger.info(
234
+ "Diploid data detected. Using 3 classes (REF=0, HET=1, ALT=2)."
235
+ )
236
+
237
+ n_samples, self.num_features_ = X.shape
238
+
239
+ # --- model params (decoder: Z -> L * num_classes) ---
240
+ self.model_params = {
241
+ "n_features": self.num_features_,
242
+ "num_classes": self.num_classes_,
243
+ "latent_dim": self.latent_dim,
244
+ "dropout_rate": self.dropout_rate,
245
+ "activation": self.activation,
246
+ # hidden_layer_sizes injected later
247
+ }
248
+
249
+ # --- split ---
250
+ indices = np.arange(n_samples)
251
+ train_idx, test_idx = train_test_split(
252
+ indices, test_size=self.validation_split, random_state=self.seed
253
+ )
254
+ self.train_idx_, self.test_idx_ = train_idx, test_idx
255
+ self.X_train_ = self.ground_truth_[train_idx]
256
+ self.X_test_ = self.ground_truth_[test_idx]
257
+
258
+ # --- plotting/scorers & tuning ---
259
+ self.plotter_, self.scorers_ = self.initialize_plotting_and_scorers()
260
+ if self.tune:
261
+ self.tune_hyperparameters()
262
+
263
+ self.best_params_ = getattr(
264
+ self, "best_params_", self._set_best_params_default()
265
+ )
266
+
267
+ # --- class weights for 0/1/2 ---
268
+ self.class_weights_ = self._class_weights_from_zygosity(self.X_train_)
269
+
270
+ # --- latent init & loader ---
271
+ train_latent_vectors = self._create_latent_space(
272
+ self.best_params_, len(self.X_train_), self.X_train_, self.latent_init
273
+ )
274
+ train_loader = self._get_data_loaders(self.X_train_)
275
+
276
+ # --- final training (three-phase under the hood) ---
277
+ (self.best_loss_, self.model_, self.history_, self.train_latent_vectors_) = (
278
+ self._train_final_model(
279
+ loader=train_loader,
280
+ best_params=self.best_params_,
281
+ initial_latent_vectors=train_latent_vectors,
282
+ )
283
+ )
284
+
285
+ self.is_fit_ = True
286
+ self.plotter_.plot_history(self.history_)
287
+ self._evaluate_model(self.X_test_, self.model_, self.best_params_)
288
+ self._save_best_params(self.best_params_)
289
+ return self
290
+
291
+ def transform(self) -> np.ndarray:
292
+ """Impute missing genotypes (0/1/2) and return IUPAC strings.
293
+
294
+ This method first checks if the model has been fitted. It then imputes the entire dataset by optimizing latent vectors for the ground truth data and predicting the missing genotypes using the trained UBP model. The imputed genotypes are decoded to IUPAC format, and distributions of original and imputed genotypes are plotted.
295
+
296
+ Returns:
297
+ np.ndarray: IUPAC single-character array (n_samples x L).
298
+
299
+ Raises:
300
+ NotFittedError: If called before fit().
301
+ """
302
+ if not getattr(self, "is_fit_", False):
303
+ raise NotFittedError("Model is not fitted. Call fit() before transform().")
304
+
305
+ self.logger.info("Imputing entire dataset with UBP (0/1/2)...")
306
+ X_to_impute = self.ground_truth_.copy()
307
+
308
+ optimized_latents = self._optimize_latents_for_inference(
309
+ X_to_impute, self.model_, self.best_params_
310
+ )
311
+ pred_labels, _ = self._predict(self.model_, latent_vectors=optimized_latents)
312
+
313
+ missing_mask = X_to_impute == -1
314
+ imputed_array = X_to_impute.copy()
315
+ imputed_array[missing_mask] = pred_labels[missing_mask]
316
+
317
+ # Decode to IUPAC for return & plots
318
+ imputed_genotypes = self.pgenc.decode_012(imputed_array)
319
+ original_genotypes = self.pgenc.decode_012(X_to_impute)
320
+
321
+ self.plotter_.plot_gt_distribution(original_genotypes, is_imputed=False)
322
+ self.plotter_.plot_gt_distribution(imputed_genotypes, is_imputed=True)
323
+ return imputed_genotypes
324
+
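For illustration only (not part of the diff): a minimal end-to-end sketch of the public fit/transform API shown above. It assumes `gd` is an already-constructed snpio GenotypeData object and that "ubp_config.yaml" is a hypothetical config file.

    imputer = ImputeUBP(gd, config="ubp_config.yaml", overrides={"train.max_epochs": 200})
    imputer.fit()                       # three-phase training on 0/1/2 encodings
    iupac_matrix = imputer.transform()  # (n_samples x L) IUPAC characters, missing sites imputed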
325
+ def _train_step(
326
+ self,
327
+ loader: torch.utils.data.DataLoader,
328
+ optimizer: torch.optim.Optimizer,
329
+ latent_optimizer: torch.optim.Optimizer,
330
+ model: torch.nn.Module,
331
+ l1_penalty: float,
332
+ latent_vectors: torch.nn.Parameter,
333
+ class_weights: torch.Tensor,
334
+ phase: int,
335
+ ) -> Tuple[float, torch.nn.Parameter]:
336
+ """Single epoch over batches for UBP with 0/1/2 focal CE.
337
+
338
+ This method handles all three UBP phases:
339
+
340
+ 1. Pre-training: Train the model on the full dataset with a small learning rate.
341
+ 2. Fine-tuning: Train the model on the full dataset with a larger learning rate.
342
+ 3. Joint training: Train both model and latents.
343
+
344
+ Args:
345
+ loader (torch.utils.data.DataLoader): DataLoader (indices, y_batch).
346
+ optimizer (torch.optim.Optimizer): Decoder optimizer.
347
+ latent_optimizer (torch.optim.Optimizer): Latent optimizer.
348
+ model (torch.nn.Module): UBP model with phase1_decoder & phase23_decoder.
349
+ l1_penalty (float): L1 regularization weight.
350
+ latent_vectors (torch.nn.Parameter): Trainable Z.
351
+ class_weights (torch.Tensor): Class weights for 0/1/2.
352
+ phase (int): Phase id (1, 2, 3). Phase 1 = warm-up, phase 2 = decoder-only, phase 3 = joint.
353
+
354
+ Returns:
355
+ Tuple[float, torch.nn.Parameter]: Average loss and updated latents.
356
+ """
357
+ model.train()
358
+ running = 0.0
359
+
360
+ for batch_indices, y_batch in loader:
361
+ optimizer.zero_grad(set_to_none=True)
362
+ latent_optimizer.zero_grad(set_to_none=True)
363
+
364
+ decoder = model.phase1_decoder if phase == 1 else model.phase23_decoder
365
+ logits = decoder(latent_vectors[batch_indices]).view(
366
+ len(batch_indices), self.num_features_, self.num_classes_
367
+ )
368
+
369
+ logits_flat = logits.view(-1, self.num_classes_)
370
+ targets_flat = y_batch.view(-1)
371
+
372
+ ce = F.cross_entropy(
373
+ logits_flat,
374
+ targets_flat,
375
+ weight=class_weights,
376
+ reduction="none",
377
+ ignore_index=-1,
378
+ )
379
+ pt = torch.exp(-ce)
380
+ gamma = getattr(model, "gamma", self.gamma)
381
+ focal = ((1 - pt) ** gamma) * ce
382
+
383
+ valid_mask = targets_flat != -1
384
+ loss = (
385
+ focal[valid_mask].mean()
386
+ if valid_mask.any()
387
+ else torch.tensor(0.0, device=logits.device)
388
+ )
389
+
390
+ if l1_penalty > 0:
391
+ loss = loss + l1_penalty * sum(
392
+ p.abs().sum() for p in model.parameters()
393
+ )
394
+
395
+ loss.backward()
396
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
397
+ torch.nn.utils.clip_grad_norm_([latent_vectors], 1.0)
398
+
399
+ optimizer.step()
400
+
401
+ if phase != 2:
402
+ latent_optimizer.step()
403
+
404
+ running += float(loss.item())
405
+
406
+ return running / len(loader), latent_vectors
407
+
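For illustration only (not part of the diff): a standalone sketch of the masked focal cross-entropy used in _train_step, with toy tensors and gamma fixed at 2.0.

    import torch
    import torch.nn.functional as F

    logits = torch.randn(4, 3)             # 4 genotype calls, 3 classes (REF/HET/ALT)
    targets = torch.tensor([0, 2, -1, 1])  # -1 marks a missing genotype
    ce = F.cross_entropy(logits, targets, reduction="none", ignore_index=-1)
    focal = ((1 - torch.exp(-ce)) ** 2.0) * ce   # down-weight easy, confidently classified calls
    loss = focal[targets != -1].mean()           # average only over observed genotypes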
408
+ def _predict(
409
+ self, model: torch.nn.Module, latent_vectors: torch.nn.Parameter | None = None
410
+ ) -> Tuple[np.ndarray, np.ndarray]:
411
+ """Predict 0/1/2 labels & probabilities from latents via phase23 decoder. This method requires a trained model and latent vectors.
412
+
413
+ Args:
414
+ model (torch.nn.Module): Trained model.
415
+ latent_vectors (torch.nn.Parameter | None): Latent vectors.
416
+
417
+ Returns:
418
+ Tuple[np.ndarray, np.ndarray]: Predicted labels and probabilities.
419
+ """
420
+ if model is None or latent_vectors is None:
421
+ msg = "Model and latent vectors must be provided for prediction. Fit the model first."
422
+ self.logger.error(msg)
423
+ raise NotFittedError(msg)
424
+
425
+ model.eval()
426
+ nF = getattr(model, "n_features", self.num_features_)
427
+ with torch.no_grad():
428
+ logits = model.phase23_decoder(latent_vectors.to(self.device)).view(
429
+ len(latent_vectors), nF, self.num_classes_
430
+ )
431
+ probas = torch.softmax(logits, dim=-1)
432
+ labels = torch.argmax(probas, dim=-1)
433
+
434
+ return labels.cpu().numpy(), probas.cpu().numpy()
435
+
436
+ def _evaluate_model(
437
+ self,
438
+ X_val: np.ndarray,
439
+ model: torch.nn.Module,
440
+ params: dict,
441
+ objective_mode: bool = False,
442
+ latent_vectors_val: torch.Tensor | None = None,
443
+ ) -> Dict[str, float]:
444
+ """Evaluate on held-out set with 0/1/2 classes; also IUPAC/10-base reports.
445
+
446
+ This method evaluates the trained UBP model on a held-out validation set. It optimizes latent vectors for the validation data if they are not provided, predicts 0/1/2 labels and probabilities, and computes various performance metrics. If not in objective mode, it generates detailed classification reports and confusion matrices for both 0/1/2 genotypes and their IUPAC/10-base representations. The method returns a dictionary of evaluation metrics.
447
+
448
+ Args:
449
+ X_val (np.ndarray): 0/1/2 with -1 for missing.
450
+ model (torch.nn.Module): Trained model.
451
+ params (dict): Model params.
452
+ objective_mode (bool): If True, return only tuned metric.
453
+ latent_vectors_val (torch.Tensor | None): Optional pre-optimized latents.
454
+
455
+ Returns:
456
+ Metrics dict.
457
+ """
458
+ if latent_vectors_val is not None:
459
+ test_latent_vectors = latent_vectors_val
460
+ else:
461
+ test_latent_vectors = self._optimize_latents_for_inference(
462
+ X_val, model, params
463
+ )
464
+
465
+ pred_labels, pred_probas = self._predict(
466
+ model=model, latent_vectors=test_latent_vectors
467
+ )
468
+
469
+ eval_mask = X_val != -1
470
+ y_true_flat = X_val[eval_mask]
471
+ y_pred_flat = pred_labels[eval_mask]
472
+ y_proba_flat = pred_probas[eval_mask]
473
+
474
+ if y_true_flat.size == 0:
475
+ return {self.tune_metric: 0.0}
476
+
477
+ labels_for_scoring = [0, 1] if self.is_haploid else [0, 1, 2]
478
+ target_names = ["REF", "ALT"] if self.is_haploid else ["REF", "HET", "ALT"]
479
+
480
+ y_true_ohe = np.eye(len(labels_for_scoring))[y_true_flat]
481
+
482
+ metrics = self.scorers_.evaluate(
483
+ y_true_flat,
484
+ y_pred_flat,
485
+ y_true_ohe,
486
+ y_proba_flat,
487
+ objective_mode,
488
+ self.tune_metric,
489
+ )
490
+
491
+ if not objective_mode:
492
+ self.logger.info(f"Validation Metrics (0/1/2): {metrics}")
493
+
494
+ self._make_class_reports(
495
+ y_true=y_true_flat,
496
+ y_pred_proba=y_proba_flat,
497
+ y_pred=y_pred_flat,
498
+ metrics=metrics,
499
+ labels=target_names,
500
+ )
501
+
502
+ # IUPAC / 10-base auxiliary reports
503
+ y_true_dec = self.pgenc.decode_012(X_val)
504
+ X_pred = X_val.copy()
505
+ X_pred[eval_mask] = y_pred_flat
506
+
507
+ nF_eval = X_val.shape[1]
508
+ y_pred_dec = self.pgenc.decode_012(X_pred.reshape(X_val.shape[0], nF_eval))
509
+
510
+ encodings_dict = {
511
+ "A": 0,
512
+ "C": 1,
513
+ "G": 2,
514
+ "T": 3,
515
+ "W": 4,
516
+ "R": 5,
517
+ "M": 6,
518
+ "K": 7,
519
+ "Y": 8,
520
+ "S": 9,
521
+ "N": -1,
522
+ }
523
+ y_true_int = self.pgenc.convert_int_iupac(
524
+ y_true_dec, encodings_dict=encodings_dict
525
+ )
526
+ y_pred_int = self.pgenc.convert_int_iupac(
527
+ y_pred_dec, encodings_dict=encodings_dict
528
+ )
529
+
530
+ self._make_class_reports(
531
+ y_true=y_true_int[eval_mask],
532
+ y_pred=y_pred_int[eval_mask],
533
+ metrics=metrics,
534
+ y_pred_proba=None,
535
+ labels=["A", "C", "G", "T", "W", "R", "M", "K", "Y", "S"],
536
+ )
537
+
538
+ return metrics
539
+
540
+ def _get_data_loaders(self, y: np.ndarray) -> torch.utils.data.DataLoader:
541
+ """Create DataLoader over indices + 0/1/2 target matrix.
542
+
543
+ This method creates a PyTorch DataLoader for the given genotype matrix, which contains 0/1/2 encodings with -1 for missing values. The DataLoader is constructed to yield batches of data during training, where each batch consists of indices and the corresponding genotype values. The genotype matrix is converted to a PyTorch tensor and moved to the appropriate device (CPU or GPU) before being wrapped in a TensorDataset. The DataLoader is configured to shuffle the data and use the specified batch size.
544
+
545
+ Args:
546
+ y (np.ndarray): (n_samples x L) int matrix with -1 missing.
547
+
548
+ Returns:
549
+ torch.utils.data.DataLoader: Shuffled mini-batches.
550
+ """
551
+ y_tensor = torch.from_numpy(y).long().to(self.device)
552
+ dataset = torch.utils.data.TensorDataset(
553
+ torch.arange(len(y), device=self.device), y_tensor.to(self.device)
554
+ )
555
+ return torch.utils.data.DataLoader(
556
+ dataset, batch_size=self.batch_size, shuffle=True
557
+ )
558
+
559
+ def _objective(self, trial: optuna.Trial) -> float:
560
+ """Optuna objective using the UBP training loop."""
561
+ try:
562
+ params = self._sample_hyperparameters(trial)
563
+
564
+ X_train_trial = self.ground_truth_[self.train_idx_]
565
+ X_test_trial = self.ground_truth_[self.test_idx_]
566
+
567
+ class_weights = self._class_weights_from_zygosity(X_train_trial)
568
+ train_loader = self._get_data_loaders(X_train_trial)
569
+
570
+ train_latent_vectors = self._create_latent_space(
571
+ params, len(X_train_trial), X_train_trial, params["latent_init"]
572
+ )
573
+
574
+ model = self.build_model(self.Model, params["model_params"])
575
+ model.n_features = params["model_params"]["n_features"]
576
+ model.apply(self.initialize_weights)
577
+
578
+ _, model, _ = self._train_and_validate_model(
579
+ model=model,
580
+ loader=train_loader,
581
+ lr=params["lr"],
582
+ l1_penalty=params["l1_penalty"],
583
+ trial=trial,
584
+ return_history=False,
585
+ latent_vectors=train_latent_vectors,
586
+ lr_input_factor=params["lr_input_factor"],
587
+ class_weights=class_weights,
588
+ X_val=X_test_trial,
589
+ params=params,
590
+ prune_metric=self.tune_metric,
591
+ prune_warmup_epochs=5,
592
+ eval_interval=1,
593
+ eval_requires_latents=True,
594
+ eval_latent_steps=50,
595
+ eval_latent_lr=params["lr"] * params["lr_input_factor"],
596
+ eval_latent_weight_decay=0.0,
597
+ )
598
+
599
+ metrics = self._evaluate_model(
600
+ X_test_trial, model, params, objective_mode=True
601
+ )
602
+ self._clear_resources(
603
+ model, train_loader, latent_vectors=train_latent_vectors
604
+ )
605
+ return metrics[self.tune_metric]
606
+ except Exception as e:
607
+ raise optuna.exceptions.TrialPruned(f"Trial failed with error: {e}")
608
+
609
+ def _sample_hyperparameters(
610
+ self, trial: optuna.Trial
611
+ ) -> Dict[str, int | float | str | list]:
612
+ """Sample UBP hyperparameters; compute hidden sizes for model_params.
613
+
614
+ This method samples a set of hyperparameters for the UBP model using the provided Optuna trial object. It defines a search space for various hyperparameters, including latent dimension, learning rate, dropout rate, number of hidden layers, activation function, and others. After sampling the hyperparameters, it computes the sizes of the hidden layers based on the sampled values and constructs the model parameters dictionary. The method returns a dictionary containing all sampled hyperparameters along with the computed model parameters.
615
+
616
+ Args:
617
+ trial (optuna.Trial): Current trial.
618
+
619
+ Returns:
620
+ Dict[str, int | float | str | list]: Sampled hyperparameters.
621
+ """
622
+ params = {
623
+ "latent_dim": trial.suggest_int("latent_dim", 2, 32),
624
+ "lr": trial.suggest_float("learning_rate", 1e-5, 1e-2, log=True),
625
+ "dropout_rate": trial.suggest_float("dropout_rate", 0.0, 0.6),
626
+ "num_hidden_layers": trial.suggest_int("num_hidden_layers", 1, 8),
627
+ "activation": trial.suggest_categorical(
628
+ "activation", ["relu", "elu", "selu"]
629
+ ),
630
+ "gamma": trial.suggest_float("gamma", 0.0, 5.0),
631
+ "lr_input_factor": trial.suggest_float(
632
+ "lr_input_factor", 0.1, 10.0, log=True
633
+ ),
634
+ "l1_penalty": trial.suggest_float("l1_penalty", 1e-7, 1e-2, log=True),
635
+ "layer_scaling_factor": trial.suggest_float(
636
+ "layer_scaling_factor", 2.0, 10.0
637
+ ),
638
+ "layer_schedule": trial.suggest_categorical(
639
+ "layer_schedule", ["pyramid", "constant", "linear"]
640
+ ),
641
+ "latent_init": trial.suggest_categorical("latent_init", ["random", "pca"]),
642
+ }
643
+
644
+ hidden_layer_sizes = self._compute_hidden_layer_sizes(
645
+ n_inputs=params["latent_dim"],
646
+ n_outputs=self.num_features_ * self.num_classes_,
647
+ n_samples=len(self.train_idx_),
648
+ n_hidden=params["num_hidden_layers"],
649
+ alpha=params["layer_scaling_factor"],
650
+ schedule=params["layer_schedule"],
651
+ )
652
+ # Keep the latent_dim as the first element,
653
+ # then the interior hidden widths.
654
+ # If there are no interior widths (very small nets),
655
+ # this still leaves [latent_dim].
656
+ hidden_only = [hidden_layer_sizes[0]] + hidden_layer_sizes[1:-1]
657
+
658
+ params["model_params"] = {
659
+ "n_features": self.num_features_,
660
+ "num_classes": self.num_classes_,
661
+ "latent_dim": params["latent_dim"],
662
+ "dropout_rate": params["dropout_rate"],
663
+ "hidden_layer_sizes": hidden_only,
664
+ "activation": params["activation"],
665
+ }
666
+
667
+ return params
668
+
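For illustration only (not part of the diff): a hedged sketch of how _objective could be wired into an Optuna study. The actual wiring lives in BaseNNImputer.tune_hyperparameters(), which is not part of this file, so the study settings below are assumptions.

    import optuna

    study = optuna.create_study(
        direction="maximize",  # _objective returns the tuned metric (e.g., f1), so maximize
        pruner=optuna.pruners.MedianPruner(n_warmup_steps=5),
    )
    study.optimize(imputer._objective, n_trials=imputer.n_trials)
    print(study.best_trial.params)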
669
+ def _set_best_params(
670
+ self, best_params: Dict[str, int | float | str | list]
671
+ ) -> Dict[str, int | float | str | list]:
672
+ """Set best params onto instance; return model_params payload.
673
+
674
+ This method sets the best hyperparameters found during tuning onto the instance attributes of the ImputeUBP class. It extracts the relevant hyperparameters from the provided dictionary and updates the corresponding instance variables. Additionally, it computes the sizes of the hidden layers based on the best hyperparameters and constructs the model parameters dictionary. The method returns a dictionary containing the model parameters that can be used to build the UBP model.
675
+
676
+ Args:
677
+ best_params (Dict[str, int | float | str | list]): Best hyperparameters.
678
+
679
+ Returns:
680
+ Dict[str, int | float | str | list]: model_params payload.
681
+
682
+ Raises:
683
+ ValueError: If best_params is missing required keys.
684
+ """
685
+ self.latent_dim = best_params["latent_dim"]
686
+ self.dropout_rate = best_params["dropout_rate"]
687
+ self.learning_rate = best_params["learning_rate"]
688
+ self.gamma = best_params["gamma"]
689
+ self.lr_input_factor = best_params["lr_input_factor"]
690
+ self.l1_penalty = best_params["l1_penalty"]
691
+ self.activation = best_params["activation"]
692
+ self.latent_init = best_params["latent_init"]
693
+
694
+ hidden_layer_sizes = self._compute_hidden_layer_sizes(
695
+ n_inputs=self.latent_dim,
696
+ n_outputs=self.num_features_ * self.num_classes_,
697
+ n_samples=len(self.train_idx_),
698
+ n_hidden=best_params["num_hidden_layers"],
699
+ alpha=best_params["layer_scaling_factor"],
700
+ schedule=best_params["layer_schedule"],
701
+ )
702
+
703
+ hidden_only = [hidden_layer_sizes[0]] + hidden_layer_sizes[1:-1]
704
+
705
+ return {
706
+ "n_features": self.num_features_,
707
+ "latent_dim": self.latent_dim,
708
+ "hidden_layer_sizes": hidden_only,
709
+ "dropout_rate": self.dropout_rate,
710
+ "activation": self.activation,
711
+ "gamma": self.gamma,
712
+ "num_classes": self.num_classes_,
713
+ }
714
+
715
+ def _set_best_params_default(self) -> Dict[str, int | float | str | list]:
716
+ """Default (no-tuning) model_params aligned with current attributes.
717
+
718
+ This method constructs the model parameters dictionary using the current instance attributes of the ImputeUBP class. It computes the sizes of the hidden layers based on the instance's latent dimension, dropout rate, learning rate, and other relevant attributes. The method returns a dictionary containing the model parameters that can be used to build the UBP model when no hyperparameter tuning has been performed.
719
+
720
+ Returns:
721
+ Dict[str, int | float | str | list]: model_params payload.
722
+ """
723
+ hidden_layer_sizes = self._compute_hidden_layer_sizes(
724
+ n_inputs=self.latent_dim,
725
+ n_outputs=self.num_features_ * self.num_classes_,
726
+ n_samples=len(self.ground_truth_),
727
+ n_hidden=self.num_hidden_layers,
728
+ alpha=self.layer_scaling_factor,
729
+ schedule=self.layer_schedule,
730
+ )
731
+
732
+ hidden_only = [hidden_layer_sizes[0]] + hidden_layer_sizes[1:-1]
733
+
734
+ return {
735
+ "n_features": self.num_features_,
736
+ "latent_dim": self.latent_dim,
737
+ "hidden_layer_sizes": hidden_only,
738
+ "dropout_rate": self.dropout_rate,
739
+ "activation": self.activation,
740
+ "gamma": self.gamma,
741
+ "num_classes": self.num_classes_,
742
+ }
743
+
744
+ def _train_and_validate_model(
745
+ self,
746
+ model: torch.nn.Module,
747
+ loader: torch.utils.data.DataLoader,
748
+ lr: float,
749
+ l1_penalty: float,
750
+ trial: optuna.Trial | None = None,
751
+ return_history: bool = False,
752
+ latent_vectors: torch.nn.Parameter | None = None,
753
+ lr_input_factor: float = 1.0,
754
+ class_weights: torch.Tensor | None = None,
755
+ *,
756
+ X_val: np.ndarray | None = None,
757
+ params: dict | None = None,
758
+ prune_metric: str | None = None, # "f1" | "accuracy" | "pr_macro"
759
+ prune_warmup_epochs: int = 3,
760
+ eval_interval: int = 1,
761
+ eval_requires_latents: bool = True, # UBP needs latent eval
762
+ eval_latent_steps: int = 50,
763
+ eval_latent_lr: float = 1e-2,
764
+ eval_latent_weight_decay: float = 0.0,
765
+ ) -> Tuple[float, torch.nn.Module | None, dict, torch.nn.Parameter | None]:
766
+ """Train & validate UBP model with three-phase loop.
767
+
768
+ This method trains and validates the UBP model using the three-phase UBP loop (warm-up decoder, decoder-only, then joint refinement of decoder and latents). It sets up the latent optimizer and invokes the training loop, ensuring that the required latent vectors and class weights are provided before training, and forwards the evaluation and pruning parameters to the loop. The final best loss, best model, training history, and optimized latent vectors are returned.
769
+
770
+ Args:
771
+ model (torch.nn.Module): UBP model with phase1_decoder & phase23_decoder.
772
+ loader (torch.utils.data.DataLoader): DataLoader for training data.
773
+ lr (float): Learning rate for decoder.
774
+ l1_penalty (float): L1 regularization weight.
775
+ trial (optuna.Trial | None): Current trial or None.
776
+ return_history (bool): If True, return loss history.
777
+ latent_vectors (torch.nn.Parameter | None): Trainable Z.
778
+ lr_input_factor (float): LR factor for latents.
779
+ class_weights (torch.Tensor | None): Class weights for 0/1/2.
780
+ X_val (np.ndarray | None): Validation set for pruning/eval.
781
+ params (dict | None): Model params for eval.
782
+ prune_metric (str | None): Metric to monitor for pruning.
783
+ prune_warmup_epochs (int): Epochs before pruning starts.
784
+ eval_interval (int): Epochs between evaluations.
785
+ eval_requires_latents (bool): If True, optimize latents for eval.
786
+ eval_latent_steps (int): Latent optimization steps for eval.
787
+ eval_latent_lr (float): Latent optimization LR for eval.
788
+ eval_latent_weight_decay (float): Latent optimization weight decay for eval.
789
+
790
+ Returns:
791
+ Tuple: (best_loss, best_model, history, latents) when return_history is True; otherwise (best_loss, best_model, latents), with the history omitted.
792
+
793
+ Raises:
794
+ TypeError: If latent_vectors or class_weights are
795
+ not provided.
796
+ ValueError: If X_val is not provided for evaluation.
797
+ RuntimeError: If eval_latent_steps is not positive.
798
+ """
799
+ if latent_vectors is None or class_weights is None:
800
+ msg = "Must provide latent_vectors and class_weights."
801
+ self.logger.error(msg)
802
+ raise TypeError(msg)
803
+
804
+ latent_optimizer = torch.optim.Adam([latent_vectors], lr=lr * lr_input_factor)
805
+
806
+ result = self._execute_training_loop(
807
+ loader=loader,
808
+ latent_optimizer=latent_optimizer,
809
+ lr=lr,
810
+ model=model,
811
+ l1_penalty=l1_penalty,
812
+ trial=trial,
813
+ return_history=return_history,
814
+ latent_vectors=latent_vectors,
815
+ class_weights=class_weights,
816
+ # NEW ↓↓↓
817
+ X_val=X_val,
818
+ params=params,
819
+ prune_metric=prune_metric,
820
+ prune_warmup_epochs=prune_warmup_epochs,
821
+ eval_interval=eval_interval,
822
+ eval_requires_latents=eval_requires_latents,
823
+ eval_latent_steps=eval_latent_steps,
824
+ eval_latent_lr=eval_latent_lr,
825
+ eval_latent_weight_decay=eval_latent_weight_decay,
826
+ )
827
+
828
+ if return_history:
829
+ return result
830
+
831
+ return result[0], result[1], result[3]
832
+
833
+ def _train_final_model(
834
+ self,
835
+ loader: torch.utils.data.DataLoader,
836
+ best_params: Dict[str, int | float | str | list],
837
+ initial_latent_vectors: torch.nn.Parameter,
838
+ ) -> Tuple[float, torch.nn.Module, dict, torch.nn.Parameter]:
839
+ """Train final UBP model with best params; save weights to disk.
840
+
841
+ This method trains the final UBP model using the best hyperparameters found during tuning. It builds the model with the specified parameters, initializes the weights, and invokes the training and validation process. The method saves the trained model's state dictionary to disk and returns the final loss, trained model, training history, and optimized latent vectors.
842
+
843
+ Args:
844
+ loader (torch.utils.data.DataLoader): DataLoader for training data.
845
+ best_params (Dict[str, int | float | str | list]): Best hyperparameters.
846
+ initial_latent_vectors (torch.nn.Parameter): Initialized latent vectors.
847
+
848
+ Returns:
849
+ Tuple[float, torch.nn.Module, dict, torch.nn.Parameter]: (loss, model, {"Train": history}, latents).
850
+ """
851
+ self.logger.info("Training the final UBP (0/1/2) model...")
852
+
853
+ model = self.build_model(self.Model, best_params)
854
+ model.n_features = best_params["n_features"]
855
+ model.apply(self.initialize_weights)
856
+
857
+ loss, trained_model, history, latent_vectors = self._train_and_validate_model(
858
+ model=model,
859
+ loader=loader,
860
+ lr=self.learning_rate,
861
+ l1_penalty=self.l1_penalty,
862
+ return_history=True,
863
+ latent_vectors=initial_latent_vectors,
864
+ lr_input_factor=self.lr_input_factor,
865
+ class_weights=self.class_weights_,
866
+ X_val=self.X_test_,
867
+ params=best_params,
868
+ prune_metric=self.tune_metric,
869
+ prune_warmup_epochs=5,
870
+ eval_interval=1,
871
+ eval_requires_latents=True,
872
+ eval_latent_steps=50,
873
+ eval_latent_lr=self.learning_rate * self.lr_input_factor,
874
+ eval_latent_weight_decay=0.0,
875
+ )
876
+
877
+ if trained_model is None:
878
+ msg = "Final model training failed."
879
+ self.logger.error(msg)
880
+ raise RuntimeError(msg)
881
+
882
+ fout = self.models_dir / "final_model.pt"
883
+ torch.save(trained_model.state_dict(), fout)
884
+ return loss, trained_model, {"Train": history}, latent_vectors
885
+
886
+ def _execute_training_loop(
887
+ self,
888
+ loader: torch.utils.data.DataLoader,
889
+ latent_optimizer: torch.optim.Optimizer,
890
+ lr: float,
891
+ model: torch.nn.Module,
892
+ l1_penalty: float,
893
+ trial,
894
+ return_history: bool,
895
+ latent_vectors: torch.nn.Parameter,
896
+ class_weights: torch.Tensor,
897
+ *,
898
+ X_val: np.ndarray | None = None,
899
+ params: dict | None = None,
900
+ prune_metric: str | None = None,
901
+ prune_warmup_epochs: int = 3,
902
+ eval_interval: int = 1,
903
+ eval_requires_latents: bool = True,
904
+ eval_latent_steps: int = 50,
905
+ eval_latent_lr: float = 1e-2,
906
+ eval_latent_weight_decay: float = 0.0,
907
+ ) -> Tuple[float, torch.nn.Module, dict, torch.nn.Parameter]:
908
+ """Three-phase UBP loop with cosine LR, gamma warmup, and pruning hook.
909
+
910
+ This method executes the three-phase training loop for the UBP model (warm-up decoder, decoder-only, then joint refinement of decoder and latents). It uses a cosine annealing learning rate scheduler, focal-loss gamma warmup, and early stopping, and includes a pruning hook for Optuna trials so unpromising trials can be terminated early based on validation performance. The final best loss, best model, training history, and optimized latent vectors are returned.
911
+
912
+ Args:
913
+ loader (torch.utils.data.DataLoader): DataLoader for training data.
914
+ latent_optimizer (torch.optim.Optimizer): Latent optimizer.
915
+ lr (float): Learning rate for decoder.
916
+ model (torch.nn.Module): UBP model with phase1_decoder & phase23_decoder.
917
+ l1_penalty (float): L1 regularization weight.
918
+ trial: Current trial or None.
919
+ return_history (bool): If True, return loss history.
920
+ latent_vectors (torch.nn.Parameter): Trainable Z.
921
+ class_weights (torch.Tensor): Class weights for 0/1/2.
922
+ X_val (np.ndarray | None): Validation set for pruning/eval.
923
+ params (dict | None): Model params for eval.
924
+ prune_metric (str | None): Metric to monitor for pruning.
925
+ prune_warmup_epochs (int): Epochs before pruning starts.
926
+ eval_interval (int): Epochs between evaluations.
927
+ eval_requires_latents (bool): If True, optimize latents for eval.
928
+ eval_latent_steps (int): Latent optimization steps for eval.
929
+ eval_latent_lr (float): Latent optimization LR for eval.
930
+ eval_latent_weight_decay (float): Latent optimization weight decay for eval.
931
+
932
+ Returns:
933
+ Tuple[float, torch.nn.Module, dict, torch.nn.Parameter]: (best_loss, best_model, history, latents).
934
+
935
+ Raises:
936
+ TypeError: If X_val is not provided for evaluation.
937
+ ValueError: If eval_latent_steps is not positive.
938
+ """
939
+ history: dict[str, list[float]] = {}
940
+ final_best_loss = float("inf")
941
+ final_best_model = None
942
+
943
+ # Schema-aware latent cache for eval
944
+ _latent_cache: dict = {}
945
+ nF = getattr(model, "n_features", self.num_features_)
946
+ cache_key_root = f"{self.prefix}_ubp_val_latents_L{nF}_K{self.num_classes_}"
947
+
948
+ # Epoch budget; if you later add tune_fast behavior to UBP, wire it here
949
+ max_epochs = self.epochs
950
+ warm, ramp, gamma_final = 50, 100, self.gamma
951
+
952
+ for phase in (1, 2, 3):
953
+ early_stopping = EarlyStopping(
954
+ patience=self.early_stop_gen,
955
+ min_epochs=self.min_epochs,
956
+ verbose=self.verbose,
957
+ prefix=self.prefix,
958
+ debug=self.debug,
959
+ )
960
+
961
+ if phase == 2:
962
+ self._reset_weights(model)
963
+
964
+ decoder_params = (
965
+ model.phase1_decoder.parameters()
966
+ if phase == 1
967
+ else model.phase23_decoder.parameters()
968
+ )
969
+ optimizer = torch.optim.Adam(decoder_params, lr=lr)
970
+ scheduler = CosineAnnealingLR(optimizer, T_max=max_epochs)
971
+
972
+ phase_hist: list[float] = []
973
+
974
+ for epoch in range(max_epochs):
975
+ # Focal gamma warmup
976
+ if epoch < warm:
977
+ model.gamma = 0.0
978
+ elif epoch < warm + ramp:
979
+ model.gamma = gamma_final * ((epoch - warm) / ramp)
980
+ else:
981
+ model.gamma = gamma_final
982
+
983
+ train_loss, latent_vectors = self._train_step(
984
+ loader=loader,
985
+ optimizer=optimizer,
986
+ latent_optimizer=latent_optimizer,
987
+ model=model,
988
+ l1_penalty=l1_penalty,
989
+ latent_vectors=latent_vectors,
990
+ class_weights=class_weights,
991
+ phase=phase,
992
+ )
993
+
994
+ if trial and (np.isnan(train_loss) or np.isinf(train_loss)):
995
+ raise optuna.exceptions.TrialPruned("Loss is NaN or Inf.")
996
+
997
+ scheduler.step()
998
+ if return_history:
999
+ phase_hist.append(train_loss)
1000
+
1001
+ early_stopping(train_loss, model)
1002
+ if early_stopping.early_stop:
1003
+ self.logger.info(
1004
+ f"Early stopping at epoch {epoch + 1} (phase {phase})."
1005
+ )
1006
+ break
1007
+
1008
+ # Validation pruning hook
1009
+ if (
1010
+ trial is not None
1011
+ and X_val is not None
1012
+ and ((epoch + 1) % eval_interval == 0)
1013
+ ):
1014
+ metric_key = prune_metric or getattr(self, "tune_metric", "f1")
1015
+ z = self._first_linear_in_features(model)
1016
+ schema_key = f"{cache_key_root}_z{z}"
1017
+
1018
+ metric_val = self._eval_for_pruning(
1019
+ model=model,
1020
+ X_val=X_val,
1021
+ params=params or getattr(self, "best_params_", {}),
1022
+ metric=metric_key,
1023
+ objective_mode=True,
1024
+ do_latent_infer=eval_requires_latents,
1025
+ latent_steps=eval_latent_steps,
1026
+ latent_lr=eval_latent_lr,
1027
+ latent_weight_decay=eval_latent_weight_decay,
1028
+ latent_seed=(self.seed if self.seed is not None else 123),
1029
+ _latent_cache=_latent_cache,
1030
+ _latent_cache_key=schema_key,
1031
+ )
1032
+
1033
+ with warnings.catch_warnings():
1034
+ warnings.simplefilter("ignore", category=UserWarning)
1035
+ trial.report(metric_val, step=epoch + 1)
1036
+
1037
+ if (epoch + 1) >= prune_warmup_epochs and trial.should_prune():
1038
+ raise optuna.exceptions.TrialPruned(
1039
+ f"Pruned at epoch {epoch + 1} (phase {phase}): "
1040
+ f"{metric_key}={metric_val:.5f}"
1041
+ )
1042
+
1043
+ history[f"Phase {phase}"] = phase_hist
1044
+ final_best_loss = early_stopping.best_score
1045
+ final_best_model = copy.deepcopy(early_stopping.best_model)
1046
+
1047
+ return final_best_loss, final_best_model, history, latent_vectors
1048
+
1049
+ def _optimize_latents_for_inference(
1050
+ self,
1051
+ X_new: np.ndarray,
1052
+ model: torch.nn.Module,
1053
+ params: dict,
1054
+ inference_epochs: int = 200,
1055
+ ) -> torch.Tensor:
1056
+ """Optimize latent vectors for new 0/1/2 data by minimizing masked CE.
1057
+
1058
+ This method optimizes latent vectors for a given genotype matrix using a trained UBP model. It initializes the latent vectors based on the specified strategy (random or PCA) and then refines them through gradient-based optimization to minimize the cross-entropy loss between the model's predictions and the provided genotype data. The optimization process is performed for a specified number of epochs, and the resulting optimized latent vectors are returned.
1059
+
1060
+ Args:
1061
+ X_new (np.ndarray): 0/1/2 with -1 for missing.
1062
+ model (torch.nn.Module): Trained model.
1063
+ params (dict): Should include 'latent_dim'.
1064
+ inference_epochs (int): Steps for optimization.
1065
+
1066
+ Returns:
1067
+ torch.Tensor: Optimized latent vectors.
1068
+ """
1069
+ model.eval()
1070
+
1071
+ nF = getattr(model, "n_features", self.num_features_)
1072
+
1073
+ X_new = X_new.astype(np.int64, copy=False)
1074
+ X_new[X_new < 0] = -1
1075
+
1076
+ # Allow shorter inference when tune_fast is enabled, mirroring NLPCA
1077
+ if self.tune and self.tune_fast:
1078
+ inference_epochs = min(
1079
+ inference_epochs, getattr(self, "tune_infer_epochs", 20)
1080
+ )
1081
+
1082
+ new_latent_vectors = self._create_latent_space(
1083
+ params, len(X_new), X_new, self.latent_init
1084
+ )
1085
+ opt = torch.optim.Adam(
1086
+ [new_latent_vectors], lr=self.learning_rate * self.lr_input_factor
1087
+ )
1088
+ y_target = torch.from_numpy(X_new).long().to(self.device)
1089
+
1090
+ for _ in range(inference_epochs):
1091
+ opt.zero_grad(set_to_none=True)
1092
+ logits = model.phase23_decoder(new_latent_vectors).view(
1093
+ len(X_new), nF, self.num_classes_
1094
+ )
1095
+ loss = F.cross_entropy(
1096
+ logits.view(-1, self.num_classes_), y_target.view(-1), ignore_index=-1
1097
+ )
1098
+ if torch.isnan(loss) or torch.isinf(loss):
1099
+ self.logger.warning(
1100
+ "Inference loss is NaN/Inf; stopping latent refinement."
1101
+ )
1102
+ break
1103
+ loss.backward()
1104
+ opt.step()
1105
+
1106
+ return new_latent_vectors.detach()
1107
+
1108
+ def _create_latent_space(
1109
+ self,
1110
+ params: dict,
1111
+ n_samples: int,
1112
+ X: np.ndarray,
1113
+ latent_init: Literal["random", "pca"],
1114
+ ) -> torch.nn.Parameter:
1115
+ """Initialize latent space via random Xavier or PCA on 0/1/2 matrix.
1116
+
1117
+ This method initializes the latent space for the UBP model using either random Xavier initialization or PCA-based initialization. The choice of initialization strategy is determined by the latent_init parameter. If PCA is selected, the method handles missing values by imputing them with column means before performing PCA. The resulting latent vectors are standardized and converted to a PyTorch parameter that can be optimized during training.
1118
+
1119
+ Args:
1120
+ params (dict): Contains 'latent_dim'.
1121
+ n_samples (int): Number of samples.
1122
+ X (np.ndarray): (n_samples x L) 0/1/2 with -1 missing.
1123
+ latent_init (Literal["random","pca"]): Init strategy.
1124
+
1125
+ Returns:
1126
+ torch.nn.Parameter: Trainable latent matrix.
1127
+ """
1128
+ latent_dim = int(params["latent_dim"])
1129
+
1130
+ if latent_init == "pca":
1131
+ X_pca = X.astype(np.float32, copy=True)
1132
+ # mark missing
1133
+ X_pca[X_pca < 0] = np.nan
1134
+
1135
+ # ---- SAFE column means without warnings ----
1136
+ valid_counts = np.sum(~np.isnan(X_pca), axis=0)
1137
+ col_sums = np.nansum(X_pca, axis=0)
1138
+ col_means = np.divide(
1139
+ col_sums,
1140
+ valid_counts,
1141
+ out=np.zeros_like(col_sums, dtype=np.float32),
1142
+ where=valid_counts > 0,
1143
+ )
1144
+
1145
+ # impute NaNs with per-column means
1146
+ # (all-NaN cols -> 0.0 by the divide above)
1147
+ nan_r, nan_c = np.where(np.isnan(X_pca))
1148
+ if nan_r.size:
1149
+ X_pca[nan_r, nan_c] = col_means[nan_c]
1150
+
1151
+ # center columns
1152
+ X_pca = X_pca - X_pca.mean(axis=0, keepdims=True)
1153
+
1154
+ # guard: degenerate / all-zero after centering ->
1155
+ # fall back to random
1156
+ if (not np.isfinite(X_pca).all()) or np.allclose(X_pca, 0.0):
1157
+ latents = torch.empty(n_samples, latent_dim, device=self.device)
1158
+ torch.nn.init.xavier_uniform_(latents)
1159
+ return torch.nn.Parameter(latents, requires_grad=True)
1160
+
1161
+ # rank-aware component count, at least 1
1162
+ try:
1163
+ est_rank = np.linalg.matrix_rank(X_pca)
1164
+ except Exception:
1165
+ est_rank = min(n_samples, X_pca.shape[1])
1166
+
1167
+ n_components = max(1, min(latent_dim, est_rank, n_samples, X_pca.shape[1]))
1168
+
1169
+ # use deterministic SVD to avoid power-iteration warnings
1170
+ pca = PCA(
1171
+ n_components=n_components, svd_solver="full", random_state=self.seed
1172
+ )
1173
+ initial = pca.fit_transform(X_pca) # (n_samples, n_components)
1174
+
1175
+ # pad if latent_dim > n_components
1176
+ if n_components < latent_dim:
1177
+ pad = self.rng.standard_normal(
1178
+ size=(n_samples, latent_dim - n_components)
1179
+ )
1180
+ initial = np.hstack([initial, pad])
1181
+
1182
+ # standardize latent dims
1183
+ initial = (initial - initial.mean(axis=0)) / (initial.std(axis=0) + 1e-6)
1184
+
1185
+ latents = torch.from_numpy(initial).float().to(self.device)
1186
+ return torch.nn.Parameter(latents, requires_grad=True)
1187
+
1188
+ else:
1189
+ latents = torch.empty(n_samples, latent_dim, device=self.device)
1190
+ torch.nn.init.xavier_uniform_(latents)
1191
+
1192
+ return torch.nn.Parameter(latents, requires_grad=True)
1193
+
1194
+ def _reset_weights(self, model: torch.nn.Module) -> None:
1195
+ """Selectively resets only the weights of the phase 2/3 decoder.
1196
+
1197
+ This method targets only the `phase23_decoder` attribute of the UBPModel, leaving the `phase1_decoder` and other potential model components untouched. This allows the model to be re-initialized for the second phase of training without affecting other parts.
1198
+
1199
+ Args:
1200
+ model (torch.nn.Module): The PyTorch model whose parameters are to be reset.
1201
+ """
1202
+ if hasattr(model, "phase23_decoder"):
1203
+ # Iterate through only the modules of the second decoder
1204
+ for layer in model.phase23_decoder.modules():
1205
+ if hasattr(layer, "reset_parameters"):
1206
+ layer.reset_parameters()
1207
+ else:
1208
+ self.logger.warning(
1209
+ "Model does not have a 'phase23_decoder' attribute; skipping weight reset."
1210
+ )
1211
+
1212
+ def _latent_infer_for_eval(
1213
+ self,
1214
+ model: torch.nn.Module,
1215
+ X_val: np.ndarray,
1216
+ *,
1217
+ steps: int,
1218
+ lr: float,
1219
+ weight_decay: float,
1220
+ seed: int,
1221
+ cache: dict | None,
1222
+ cache_key: str | None,
1223
+ ) -> None:
1224
+ """Freeze weights; refine validation latents only.
1225
+
1226
+ This method optimizes latent vectors for the validation set using a trained UBP model. It refines the latent vectors by minimizing the cross-entropy loss between the model's predictions and the provided genotype data. The optimization process is performed for a specified number of steps, and the resulting optimized latent vectors are stored in a cache for potential reuse. The method ensures that the model's weights remain unchanged during this process by freezing them.
1227
+
1228
+ Args:
1229
+ model (torch.nn.Module): Trained UBP model.
1230
+ X_val (np.ndarray): Validation 0/1/2 with -1 for missing.
1231
+ steps (int): Number of optimization steps.
1232
+ lr (float): Learning rate for latent optimization.
1233
+ weight_decay (float): L2 weight decay on latents.
1234
+ seed (int): RNG seed for determinism across epochs.
1235
+ cache (dict | None): Optional dict to warm-start & persist val latents.
1236
+ cache_key (str | None): Ignored; we build a schema-aware key internally.
1237
+ """
1238
+ if seed is None:
1239
+ seed = np.random.randint(0, 999999)
1240
+ torch.manual_seed(seed)
1241
+ np.random.seed(seed)
1242
+
1243
+ model.eval()
1244
+ for p in model.parameters():
1245
+ p.requires_grad_(False)
1246
+
1247
+ nF = getattr(model, "n_features", self.num_features_)
1248
+
1249
+ X_val = X_val.astype(np.int64, copy=False)
1250
+ X_val[X_val < 0] = -1
1251
+ y_target = torch.from_numpy(X_val).long().to(self.device)
1252
+
1253
+ # Infer current model latent size to avoid shape mismatch
1254
+ latent_dim_model = self._first_linear_in_features(model)
1255
+ schema_key = f"{self.prefix}_ubp_val_latents_z{latent_dim_model}_L{nF}_K{self.num_classes_}"
1256
+
1257
+ # Warm-start from cache if compatible
1258
+ if cache is not None and schema_key in cache:
1259
+ val_latents = cache[schema_key].detach().clone().requires_grad_(True)
1260
+ else:
1261
+ val_latents = self._create_latent_space(
1262
+ {"latent_dim": latent_dim_model},
1263
+ n_samples=X_val.shape[0],
1264
+ X=X_val,
1265
+ latent_init=self.latent_init,
1266
+ ).requires_grad_(True)
1267
+
1268
+ opt = torch.optim.AdamW([val_latents], lr=lr, weight_decay=weight_decay)
1269
+
1270
+ for _ in range(max(int(steps), 0)):
1271
+ opt.zero_grad(set_to_none=True)
1272
+ logits = model.phase23_decoder(val_latents).view(
1273
+ X_val.shape[0], nF, self.num_classes_
1274
+ )
1275
+ loss = F.cross_entropy(
1276
+ logits.view(-1, self.num_classes_),
1277
+ y_target.view(-1),
1278
+ ignore_index=-1,
1279
+ reduction="mean",
1280
+ )
1281
+ loss.backward()
1282
+ opt.step()
1283
+
1284
+ if cache is not None:
1285
+ cache[schema_key] = val_latents.detach().clone()
1286
+
1287
+ for p in model.parameters():
1288
+ p.requires_grad_(True)
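For illustration only (not part of the diff): a sketch of reloading the weights saved by _train_final_model(). It assumes UBPModel accepts the best_params_ payload as keyword arguments (as build_model() suggests) and that models_dir is set by the BaseNNImputer base class.

    import torch
    from pgsui.impute.unsupervised.models.ubp_model import UBPModel

    model = UBPModel(**imputer.best_params_)  # n_features, latent_dim, hidden_layer_sizes, ...
    state = torch.load(imputer.models_dir / "final_model.pt", map_location="cpu")
    model.load_state_dict(state)
    model.eval()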