pg-sui 0.2.0__py3-none-any.whl → 1.6.14.dev9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (127)
  1. {pg_sui-0.2.0.dist-info → pg_sui-1.6.14.dev9.dist-info}/METADATA +101 -79
  2. pg_sui-1.6.14.dev9.dist-info/RECORD +81 -0
  3. {pg_sui-0.2.0.dist-info → pg_sui-1.6.14.dev9.dist-info}/WHEEL +1 -1
  4. pg_sui-1.6.14.dev9.dist-info/entry_points.txt +4 -0
  5. {pg_sui-0.2.0.dist-info → pg_sui-1.6.14.dev9.dist-info/licenses}/LICENSE +0 -0
  6. pg_sui-1.6.14.dev9.dist-info/top_level.txt +1 -0
  7. pgsui/__init__.py +35 -54
  8. pgsui/_version.py +34 -0
  9. pgsui/cli.py +909 -0
  10. pgsui/data_processing/__init__.py +0 -0
  11. pgsui/data_processing/config.py +565 -0
  12. pgsui/data_processing/containers.py +1424 -0
  13. pgsui/data_processing/transformers.py +557 -907
  14. pgsui/{example_data/trees → electron/app}/__init__.py +0 -0
  15. pgsui/electron/app/__main__.py +5 -0
  16. pgsui/electron/app/extra-resources/.gitkeep +1 -0
  17. pgsui/electron/app/icons/icons/1024x1024.png +0 -0
  18. pgsui/electron/app/icons/icons/128x128.png +0 -0
  19. pgsui/electron/app/icons/icons/16x16.png +0 -0
  20. pgsui/electron/app/icons/icons/24x24.png +0 -0
  21. pgsui/electron/app/icons/icons/256x256.png +0 -0
  22. pgsui/electron/app/icons/icons/32x32.png +0 -0
  23. pgsui/electron/app/icons/icons/48x48.png +0 -0
  24. pgsui/electron/app/icons/icons/512x512.png +0 -0
  25. pgsui/electron/app/icons/icons/64x64.png +0 -0
  26. pgsui/electron/app/icons/icons/icon.icns +0 -0
  27. pgsui/electron/app/icons/icons/icon.ico +0 -0
  28. pgsui/electron/app/main.js +227 -0
  29. pgsui/electron/app/package-lock.json +6894 -0
  30. pgsui/electron/app/package.json +51 -0
  31. pgsui/electron/app/preload.js +15 -0
  32. pgsui/electron/app/server.py +157 -0
  33. pgsui/electron/app/ui/logo.png +0 -0
  34. pgsui/electron/app/ui/renderer.js +131 -0
  35. pgsui/electron/app/ui/styles.css +59 -0
  36. pgsui/electron/app/ui/ui_shim.js +72 -0
  37. pgsui/electron/bootstrap.py +43 -0
  38. pgsui/electron/launch.py +57 -0
  39. pgsui/electron/package.json +14 -0
  40. pgsui/example_data/__init__.py +0 -0
  41. pgsui/example_data/phylip_files/__init__.py +0 -0
  42. pgsui/example_data/phylip_files/test.phy +0 -0
  43. pgsui/example_data/popmaps/__init__.py +0 -0
  44. pgsui/example_data/popmaps/{test.popmap → phylogen_nomx.popmap} +185 -99
  45. pgsui/example_data/structure_files/__init__.py +0 -0
  46. pgsui/example_data/structure_files/test.pops.2row.allsites.str +0 -0
  47. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz +0 -0
  48. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz.tbi +0 -0
  49. pgsui/impute/__init__.py +0 -0
  50. pgsui/impute/deterministic/imputers/allele_freq.py +725 -0
  51. pgsui/impute/deterministic/imputers/mode.py +844 -0
  52. pgsui/impute/deterministic/imputers/nmf.py +221 -0
  53. pgsui/impute/deterministic/imputers/phylo.py +973 -0
  54. pgsui/impute/deterministic/imputers/ref_allele.py +669 -0
  55. pgsui/impute/supervised/__init__.py +0 -0
  56. pgsui/impute/supervised/base.py +343 -0
  57. pgsui/impute/{unsupervised/models/in_development → supervised/imputers}/__init__.py +0 -0
  58. pgsui/impute/supervised/imputers/hist_gradient_boosting.py +317 -0
  59. pgsui/impute/supervised/imputers/random_forest.py +291 -0
  60. pgsui/impute/unsupervised/__init__.py +0 -0
  61. pgsui/impute/unsupervised/base.py +1118 -0
  62. pgsui/impute/unsupervised/callbacks.py +92 -262
  63. {simulation → pgsui/impute/unsupervised/imputers}/__init__.py +0 -0
  64. pgsui/impute/unsupervised/imputers/autoencoder.py +1285 -0
  65. pgsui/impute/unsupervised/imputers/nlpca.py +1554 -0
  66. pgsui/impute/unsupervised/imputers/ubp.py +1575 -0
  67. pgsui/impute/unsupervised/imputers/vae.py +1228 -0
  68. pgsui/impute/unsupervised/loss_functions.py +261 -0
  69. pgsui/impute/unsupervised/models/__init__.py +0 -0
  70. pgsui/impute/unsupervised/models/autoencoder_model.py +215 -567
  71. pgsui/impute/unsupervised/models/nlpca_model.py +155 -394
  72. pgsui/impute/unsupervised/models/ubp_model.py +180 -1106
  73. pgsui/impute/unsupervised/models/vae_model.py +269 -630
  74. pgsui/impute/unsupervised/nn_scorers.py +255 -0
  75. pgsui/utils/__init__.py +0 -0
  76. pgsui/utils/classification_viz.py +608 -0
  77. pgsui/utils/logging_utils.py +22 -0
  78. pgsui/utils/misc.py +35 -480
  79. pgsui/utils/plotting.py +996 -829
  80. pgsui/utils/pretty_metrics.py +290 -0
  81. pgsui/utils/scorers.py +213 -666
  82. pg_sui-0.2.0.dist-info/RECORD +0 -75
  83. pg_sui-0.2.0.dist-info/top_level.txt +0 -3
  84. pgsui/example_data/phylip_files/test_n10.phy +0 -118
  85. pgsui/example_data/phylip_files/test_n100.phy +0 -118
  86. pgsui/example_data/phylip_files/test_n2.phy +0 -118
  87. pgsui/example_data/phylip_files/test_n500.phy +0 -118
  88. pgsui/example_data/structure_files/test.nopops.1row.10sites.str +0 -117
  89. pgsui/example_data/structure_files/test.nopops.2row.100sites.str +0 -234
  90. pgsui/example_data/structure_files/test.nopops.2row.10sites.str +0 -234
  91. pgsui/example_data/structure_files/test.nopops.2row.30sites.str +0 -234
  92. pgsui/example_data/structure_files/test.nopops.2row.allsites.str +0 -234
  93. pgsui/example_data/structure_files/test.pops.1row.10sites.str +0 -117
  94. pgsui/example_data/structure_files/test.pops.2row.10sites.str +0 -234
  95. pgsui/example_data/trees/test.iqtree +0 -376
  96. pgsui/example_data/trees/test.qmat +0 -5
  97. pgsui/example_data/trees/test.rate +0 -2033
  98. pgsui/example_data/trees/test.tre +0 -1
  99. pgsui/example_data/trees/test_n10.rate +0 -19
  100. pgsui/example_data/trees/test_n100.rate +0 -109
  101. pgsui/example_data/trees/test_n500.rate +0 -509
  102. pgsui/example_data/trees/test_siterates.txt +0 -2024
  103. pgsui/example_data/trees/test_siterates_n10.txt +0 -10
  104. pgsui/example_data/trees/test_siterates_n100.txt +0 -100
  105. pgsui/example_data/trees/test_siterates_n500.txt +0 -500
  106. pgsui/example_data/vcf_files/test.vcf +0 -244
  107. pgsui/example_data/vcf_files/test.vcf.gz +0 -0
  108. pgsui/example_data/vcf_files/test.vcf.gz.tbi +0 -0
  109. pgsui/impute/estimators.py +0 -1268
  110. pgsui/impute/impute.py +0 -1463
  111. pgsui/impute/simple_imputers.py +0 -1431
  112. pgsui/impute/supervised/iterative_imputer_fixedparams.py +0 -782
  113. pgsui/impute/supervised/iterative_imputer_gridsearch.py +0 -1024
  114. pgsui/impute/unsupervised/keras_classifiers.py +0 -697
  115. pgsui/impute/unsupervised/models/in_development/cnn_model.py +0 -486
  116. pgsui/impute/unsupervised/neural_network_imputers.py +0 -1440
  117. pgsui/impute/unsupervised/neural_network_methods.py +0 -1395
  118. pgsui/pg_sui.py +0 -261
  119. pgsui/utils/sequence_tools.py +0 -407
  120. simulation/sim_benchmarks.py +0 -333
  121. simulation/sim_treeparams.py +0 -475
  122. test/__init__.py +0 -0
  123. test/pg_sui_simtest.py +0 -215
  124. test/pg_sui_testing.py +0 -523
  125. test/test.py +0 -151
  126. test/test_pgsui.py +0 -374
  127. test/test_tkc.py +0 -185
pgsui/impute/unsupervised/imputers/ubp.py (new file)
@@ -0,0 +1,1575 @@
1
+ import copy
2
+ from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Tuple
3
+
4
+ import numpy as np
5
+ import optuna
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from sklearn.decomposition import PCA
9
+ from sklearn.exceptions import NotFittedError
10
+ from sklearn.model_selection import train_test_split
11
+ from snpio.analysis.genotype_encoder import GenotypeEncoder
12
+ from snpio.utils.logging import LoggerManager
13
+
14
+ from pgsui.data_processing.config import apply_dot_overrides, load_yaml_to_dataclass
15
+ from pgsui.data_processing.containers import UBPConfig
16
+ from pgsui.data_processing.transformers import SimMissingTransformer
17
+ from pgsui.impute.unsupervised.base import BaseNNImputer
18
+ from pgsui.impute.unsupervised.callbacks import EarlyStopping
19
+ from pgsui.impute.unsupervised.loss_functions import SafeFocalCELoss
20
+ from pgsui.impute.unsupervised.models.ubp_model import UBPModel
21
+ from pgsui.utils.logging_utils import configure_logger
22
+ from pgsui.utils.pretty_metrics import PrettyMetrics
23
+
24
+ if TYPE_CHECKING:
25
+ from snpio import TreeParser
26
+ from snpio.read_input.genotype_data import GenotypeData
27
+
28
+
29
+ def ensure_ubp_config(config: UBPConfig | dict | str | None) -> UBPConfig:
30
+ """Return a concrete UBPConfig from dataclass, dict, YAML path, or None.
31
+
32
+ This function normalizes the input configuration for the UBP imputer. If None is provided, it returns a default UBPConfig. A YAML path is loaded from disk, with support for a top-level preset. A dictionary is flattened to dot-keys and applied as overrides to a base configuration, which may itself be selected via a preset entry. The result is always a fully populated UBPConfig instance.
33
+
34
+ Args:
35
+ config (UBPConfig | dict | str | None): Configuration instance, nested dict, YAML file path, or None.
36
+
37
+ Returns:
38
+ UBPConfig: Normalized configuration instance.
39
+ """
40
+ if config is None:
41
+ return UBPConfig()
42
+ if isinstance(config, UBPConfig):
43
+ return config
44
+ if isinstance(config, str):
45
+ # YAML path — support top-level `preset`
46
+ return load_yaml_to_dataclass(config, UBPConfig)
47
+ if isinstance(config, dict):
48
+ base = UBPConfig()
49
+
50
+ def _flatten(prefix: str, d: dict, out: dict) -> dict:
51
+ for k, v in d.items():
52
+ kk = f"{prefix}.{k}" if prefix else k
53
+ if isinstance(v, dict):
54
+ _flatten(kk, v, out)
55
+ else:
56
+ out[kk] = v
57
+ return out
58
+
59
+ preset_name = config.pop("preset", None)
60
+ if "io" in config and isinstance(config["io"], dict):
61
+ preset_name = preset_name or config["io"].pop("preset", None)
62
+ if preset_name:
63
+ base = UBPConfig.from_preset(preset_name)
64
+
65
+ flat = _flatten("", config, {})
66
+ return apply_dot_overrides(base, flat)
67
+
68
+ raise TypeError("config must be a UBPConfig, dict, YAML path, or None.")
69
+
70
+
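A minimal usage sketch for ensure_ubp_config, assuming the pgsui 1.6.x import paths shown in this diff; the dot-keys and the YAML filename below are illustrative, not a verified or exhaustive schema.

    from pgsui.data_processing.containers import UBPConfig
    from pgsui.impute.unsupervised.imputers.ubp import ensure_ubp_config

    # None -> plain defaults
    cfg_default = ensure_ubp_config(None)

    # Nested dict -> flattened to dot-keys and applied over a base UBPConfig
    cfg_dict = ensure_ubp_config({"model": {"latent_dim": 8}, "train": {"max_epochs": 200}})

    # YAML path -> loaded via load_yaml_to_dataclass (top-level `preset` supported);
    # "ubp_config.yaml" is a hypothetical file path
    cfg_yaml = ensure_ubp_config("ubp_config.yaml")

    assert isinstance(cfg_dict, UBPConfig)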
71
+ class ImputeUBP(BaseNNImputer):
72
+ """UBP imputer for 0/1/2 genotypes with a three-phase decoder schedule.
73
+
74
+ This imputer follows the training recipe from Unsupervised Backpropagation:
75
+
76
+ 1. Phase 1 (joint warm start): Learn latent codes and the shallow linear decoder together.
77
+ 2. Phase 2 (deep decoder reset): Reinitialize the deeper decoder, freeze the latent codes, and train only the decoder parameters.
78
+ 3. Phase 3 (joint fine-tune): Unfreeze everything and jointly refine latent codes plus the deep decoder before evaluation/reporting.
79
+
80
+ References:
81
+ - Gashler, Michael S., Smith, Michael R., Morris, R., and Martinez, T. (2016) Missing Value Imputation with Unsupervised Backpropagation. Computational Intelligence, 32: 196-215. doi: 10.1111/coin.12048.
82
+ """
83
+
84
+ def __init__(
85
+ self,
86
+ genotype_data: "GenotypeData",
87
+ *,
88
+ tree_parser: Optional["TreeParser"] = None,
89
+ config: UBPConfig | dict | str | None = None,
90
+ overrides: dict | None = None,
91
+ simulate_missing: bool | None = None,
92
+ sim_strategy: (
93
+ Literal[
94
+ "random",
95
+ "random_weighted",
96
+ "random_weighted_inv",
97
+ "nonrandom",
98
+ "nonrandom_weighted",
99
+ ]
100
+ | None
101
+ ) = None,
102
+ sim_prop: float | None = None,
103
+ sim_kwargs: dict | None = None,
104
+ ):
105
+ """Initialize the UBP imputer via dataclass/dict/YAML config with overrides.
106
+
107
+ The constructor accepts a dataclass, dict, or YAML path as configuration, normalizes it, and applies any flat dot-key overrides. It then configures logging and populates the model, training, tuning, evaluation, and plotting attributes from the resulting config.
108
+
109
+ Args:
110
+ genotype_data (GenotypeData): Backing genotype data object.
111
+ tree_parser (TreeParser | None): Optional SNPio phylogenetic tree parser; required for the nonrandom sim_strategy modes.
112
+ config (UBPConfig | dict | str | None): UBP configuration.
113
+ overrides (dict | None): Flat dot-key overrides applied after `config`.
114
+ simulate_missing (bool | None): Whether to simulate missing data during training.
115
+ sim_strategy (Literal[...] | None): Simulated missing strategy if simulating.
116
+ sim_prop (float | None): Proportion of data to simulate as missing if simulating.
117
+ sim_kwargs (dict | None): Additional kwargs for SimMissingTransformer.
118
+ """
119
+ self.model_name = "ImputeUBP"
120
+ self.genotype_data = genotype_data
121
+ self.tree_parser = tree_parser
122
+
123
+ # ---- normalize config, then apply overrides ----
124
+ cfg = ensure_ubp_config(config)
125
+ if overrides:
126
+ cfg = apply_dot_overrides(cfg, overrides)
127
+ self.cfg = cfg
128
+
129
+ # ---- logging ----
130
+ logman = LoggerManager(
131
+ __name__,
132
+ prefix=self.cfg.io.prefix,
133
+ debug=self.cfg.io.debug,
134
+ verbose=self.cfg.io.verbose,
135
+ )
136
+ self.logger = configure_logger(
137
+ logman.get_logger(),
138
+ verbose=self.cfg.io.verbose,
139
+ debug=self.cfg.io.debug,
140
+ )
141
+
142
+ # ---- Base init ----
143
+ super().__init__(
144
+ model_name=self.model_name,
145
+ genotype_data=self.genotype_data,
146
+ prefix=self.cfg.io.prefix,
147
+ device=self.cfg.train.device,
148
+ verbose=self.cfg.io.verbose,
149
+ debug=self.cfg.io.debug,
150
+ )
151
+
152
+ # ---- model/meta ----
153
+ self.Model = UBPModel
154
+ self.pgenc = GenotypeEncoder(genotype_data)
155
+
156
+ self.seed = self.cfg.io.seed
157
+ self.n_jobs = self.cfg.io.n_jobs
158
+ self.prefix = self.cfg.io.prefix
159
+ self.scoring_averaging = self.cfg.io.scoring_averaging
160
+ self.verbose = self.cfg.io.verbose
161
+ self.debug = self.cfg.io.debug
162
+ self.rng = np.random.default_rng(self.seed)
163
+
164
+ # Simulated-missing controls (config defaults w/ overrides)
165
+ sim_cfg = getattr(self.cfg, "sim", None)
166
+ sim_cfg_kwargs = copy.deepcopy(getattr(sim_cfg, "sim_kwargs", None) or {})
167
+ if sim_kwargs:
168
+ sim_cfg_kwargs.update(sim_kwargs)
169
+ if sim_cfg is None:
170
+ default_sim_flag = bool(simulate_missing)
171
+ default_strategy = "random"
172
+ default_prop = 0.10
173
+ else:
174
+ default_sim_flag = sim_cfg.simulate_missing
175
+ default_strategy = sim_cfg.sim_strategy
176
+ default_prop = sim_cfg.sim_prop
177
+ self.simulate_missing = (
178
+ default_sim_flag if simulate_missing is None else bool(simulate_missing)
179
+ )
180
+ self.sim_strategy = sim_strategy or default_strategy
181
+ self.sim_prop = float(sim_prop if sim_prop is not None else default_prop)
182
+ self.sim_kwargs = sim_cfg_kwargs
183
+
184
+ # ---- model hyperparams ----
185
+ self.latent_dim = self.cfg.model.latent_dim
186
+ self.dropout_rate = self.cfg.model.dropout_rate
187
+ self.num_hidden_layers = self.cfg.model.num_hidden_layers
188
+ self.layer_scaling_factor = self.cfg.model.layer_scaling_factor
189
+ self.layer_schedule = self.cfg.model.layer_schedule
190
+ self.latent_init: Literal["pca", "random"] = self.cfg.model.latent_init
191
+ self.activation = self.cfg.model.hidden_activation
192
+ self.gamma = self.cfg.model.gamma
193
+
194
+ # ---- training ----
195
+ self.batch_size = self.cfg.train.batch_size
196
+ self.learning_rate = self.cfg.train.learning_rate
197
+ self.lr_input_factor = self.cfg.train.lr_input_factor
198
+ self.l1_penalty = self.cfg.train.l1_penalty
199
+ self.early_stop_gen = self.cfg.train.early_stop_gen
200
+ self.min_epochs = self.cfg.train.min_epochs
201
+ self.epochs = self.cfg.train.max_epochs
202
+ self.validation_split = self.cfg.train.validation_split
203
+ self.beta = self.cfg.train.weights_beta
204
+ self.max_ratio = self.cfg.train.weights_max_ratio
205
+
206
+ # ---- tuning ----
207
+ self.tune = self.cfg.tune.enabled
208
+ self.tune_fast = self.cfg.tune.fast
209
+ self.tune_proxy_metric_batch = self.cfg.tune.proxy_metric_batch
210
+ self.tune_batch_size = self.cfg.tune.batch_size
211
+ self.tune_epochs = self.cfg.tune.epochs
212
+ self.tune_eval_interval = self.cfg.tune.eval_interval
213
+ self.tune_metric: Literal[
214
+ "pr_macro",
215
+ "f1",
216
+ "accuracy",
217
+ "average_precision",
218
+ "precision",
219
+ "recall",
220
+ "roc_auc",
221
+ ] = self.cfg.tune.metric
222
+ self.n_trials = self.cfg.tune.n_trials
223
+ self.tune_save_db = self.cfg.tune.save_db
224
+ self.tune_resume = self.cfg.tune.resume
225
+ self.tune_max_samples = self.cfg.tune.max_samples
226
+ self.tune_max_loci = self.cfg.tune.max_loci
227
+ self.tune_infer_epochs = getattr(self.cfg.tune, "infer_epochs", 100)
228
+ self.tune_patience = self.cfg.tune.patience
229
+
230
+ # ---- evaluation ----
231
+ self.eval_latent_steps = self.cfg.evaluate.eval_latent_steps
232
+ self.eval_latent_lr = self.cfg.evaluate.eval_latent_lr
233
+ self.eval_latent_weight_decay = self.cfg.evaluate.eval_latent_weight_decay
234
+
235
+ # ---- plotting ----
236
+ self.plot_format = self.cfg.plot.fmt
237
+ self.plot_dpi = self.cfg.plot.dpi
238
+ self.plot_fontsize = self.cfg.plot.fontsize
239
+ self.title_fontsize = self.cfg.plot.fontsize
240
+ self.despine = self.cfg.plot.despine
241
+ self.show_plots = self.cfg.plot.show
242
+
243
+ # ---- core runtime ----
244
+ self.is_haploid = False
245
+ self.num_classes_ = False
246
+ self.model_params: Dict[str, Any] = {}
247
+ self.sim_mask_global_: np.ndarray | None = None
248
+ self.sim_mask_train_: np.ndarray | None = None
249
+ self.sim_mask_test_: np.ndarray | None = None
250
+
251
+ if self.tree_parser is None and self.sim_strategy.startswith("nonrandom"):
252
+ msg = "tree_parser is required for nonrandom and nonrandom_weighted simulated missing strategies."
253
+ self.logger.error(msg)
254
+ raise ValueError(msg)
255
+
256
+ def fit(self) -> "ImputeUBP":
257
+ """Fit the UBP decoder on 0/1/2 encodings (missing = -1) via three phases.
258
+
259
+ 1. Phase 1 initializes latent vectors alongside the linear decoder.
260
+ 2. Phase 2 resets and trains the deeper decoder while latents remain fixed.
261
+ 3. Phase 3 jointly fine-tunes latents plus the deep decoder before evaluation.
262
+
263
+ Returns:
264
+ ImputeUBP: Fitted instance.
265
+
266
+ Raises:
267
+ NotFittedError: If training fails.
268
+ """
269
+ self.logger.info(f"Fitting {self.model_name} model...")
270
+
271
+ # --- Use 0/1/2 with -1 for missing ---
272
+ X012 = self._get_float_genotypes(copy=True)
273
+ GT_full = np.nan_to_num(X012, nan=-1.0, copy=True)
274
+ self.ground_truth_ = GT_full.astype(np.int64, copy=False)
275
+
276
+ cache_key = self._sim_mask_cache_key()
277
+ self.sim_mask_global_ = None
278
+ if self.simulate_missing:
279
+ cached_mask = (
280
+ None if cache_key is None else self._sim_mask_cache.get(cache_key)
281
+ )
282
+ if cached_mask is not None:
283
+ self.sim_mask_global_ = cached_mask.copy()
284
+ else:
285
+ tr = SimMissingTransformer(
286
+ genotype_data=self.genotype_data,
287
+ tree_parser=self.tree_parser,
288
+ prop_missing=self.sim_prop,
289
+ strategy=self.sim_strategy,
290
+ missing_val=-9,
291
+ mask_missing=True,
292
+ verbose=self.verbose,
293
+ **self.sim_kwargs,
294
+ )
295
+ tr.fit(X012.copy())
296
+ self.sim_mask_global_ = tr.sim_missing_mask_.astype(bool)
297
+ if cache_key is not None:
298
+ self._sim_mask_cache[cache_key] = self.sim_mask_global_.copy()
299
+
300
+ X_for_model = self.ground_truth_.copy()
301
+ if self.sim_mask_global_ is not None:
302
+ X_for_model[self.sim_mask_global_] = -1
303
+
304
+ # --- Determine ploidy (haploid vs diploid) and classes ---
305
+ self.is_haploid = bool(
306
+ np.all(
307
+ np.isin(
308
+ self.genotype_data.snp_data,
309
+ ["A", "C", "G", "T", "N", "-", ".", "?"],
310
+ )
311
+ )
312
+ )
313
+ self.ploidy = 1 if self.is_haploid else 2
314
+
315
+ if self.is_haploid:
316
+ self.num_classes_ = 2
317
+ self.ground_truth_[self.ground_truth_ == 2] = 1
318
+ X_for_model[X_for_model == 2] = 1
319
+ self.logger.info("Haploid data detected. Using 2 classes (REF=0, ALT=1).")
320
+ else:
321
+ self.num_classes_ = 3
322
+ self.logger.info(
323
+ "Diploid data detected. Using 3 classes (REF=0, HET=1, ALT=2)."
324
+ )
325
+
326
+ n_samples, self.num_features_ = X_for_model.shape
327
+
328
+ # --- model params (decoder: Z -> L * num_classes) ---
329
+ self.model_params = {
330
+ "n_features": self.num_features_,
331
+ "num_classes": self.num_classes_,
332
+ "latent_dim": self.latent_dim,
333
+ "dropout_rate": self.dropout_rate,
334
+ "activation": self.activation,
335
+ # hidden_layer_sizes injected later
336
+ }
337
+
338
+ # --- split ---
339
+ indices = np.arange(n_samples)
340
+ train_idx, test_idx = train_test_split(
341
+ indices, test_size=self.validation_split, random_state=self.seed
342
+ )
343
+ self.train_idx_, self.test_idx_ = train_idx, test_idx
344
+ self.X_train_ = X_for_model[train_idx]
345
+ self.X_test_ = X_for_model[test_idx]
346
+ self.GT_train_full_ = self.ground_truth_[train_idx]
347
+ self.GT_test_full_ = self.ground_truth_[test_idx]
348
+
349
+ if self.sim_mask_global_ is not None:
350
+ self.sim_mask_train_ = self.sim_mask_global_[train_idx]
351
+ self.sim_mask_test_ = self.sim_mask_global_[test_idx]
352
+ else:
353
+ self.sim_mask_train_ = None
354
+ self.sim_mask_test_ = None
355
+
356
+ # --- plotting/scorers & tuning ---
357
+ self.plotter_, self.scorers_ = self.initialize_plotting_and_scorers()
358
+ if self.tune:
359
+ self.tune_hyperparameters()
360
+
361
+ # Fall back to default model params when none have been selected yet.
362
+ if not getattr(self, "best_params_", None):
363
+ self.best_params_ = self._set_best_params_default()
364
+
365
+ # --- class weights for 0/1/2 ---
366
+ self.class_weights_ = self._normalize_class_weights(
367
+ self._class_weights_from_zygosity(self.X_train_)
368
+ )
369
+
370
+ # --- latent init & loader ---
371
+ train_latent_vectors = self._create_latent_space(
372
+ self.best_params_, len(self.X_train_), self.X_train_, self.latent_init
373
+ )
374
+ train_loader = self._get_data_loaders(self.X_train_)
375
+
376
+ # --- final training (three-phase under the hood) ---
377
+ (self.best_loss_, self.model_, self.history_, self.train_latent_vectors_) = (
378
+ self._train_final_model(
379
+ loader=train_loader,
380
+ best_params=self.best_params_,
381
+ initial_latent_vectors=train_latent_vectors,
382
+ )
383
+ )
384
+
385
+ self.is_fit_ = True
386
+ self.plotter_.plot_history(self.history_)
387
+ eval_mask = (
388
+ self.sim_mask_test_
389
+ if (self.simulate_missing and self.sim_mask_test_ is not None)
390
+ else None
391
+ )
392
+ self._evaluate_model(
393
+ self.X_test_,
394
+ self.model_,
395
+ self.best_params_,
396
+ eval_mask_override=eval_mask,
397
+ )
398
+ self._save_best_params(self.best_params_)
399
+ return self
400
+
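A hedged sketch of the simulated-missing controls consumed by fit(); `gd` stands for a snpio GenotypeData object (construction not shown here), and the strategy names are those listed in the __init__ signature.

    from pgsui.impute.unsupervised.imputers.ubp import ImputeUBP

    imputer = ImputeUBP(
        gd,                               # snpio GenotypeData instance
        simulate_missing=True,
        sim_strategy="random_weighted",   # nonrandom / nonrandom_weighted also require tree_parser=...
        sim_prop=0.10,                    # hide ~10% of observed genotypes for held-out scoring
    )
    imputer.fit()                         # simulated-missing cells drive the evaluation metrics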
401
+ def transform(self) -> np.ndarray:
402
+ """Impute missing genotypes (0/1/2) and return IUPAC strings.
403
+
404
+ This method first checks if the model has been fitted. It then imputes the entire dataset by optimizing latent vectors for the ground truth data and predicting the missing genotypes using the trained UBP model. The imputed genotypes are decoded to IUPAC format, and genotype distributions are plotted only when ``self.show_plots`` is enabled.
405
+
406
+ Returns:
407
+ np.ndarray: IUPAC single-character array (n_samples x L).
408
+
409
+ Raises:
410
+ NotFittedError: If called before fit().
411
+ """
412
+ if not getattr(self, "is_fit_", False):
413
+ raise NotFittedError("Model is not fitted. Call fit() before transform().")
414
+
415
+ self.logger.info(f"Imputing entire dataset with {self.model_name}...")
416
+ X_to_impute = self.ground_truth_.copy()
417
+
418
+ optimized_latents = self._optimize_latents_for_inference(
419
+ X_to_impute, self.model_, self.best_params_
420
+ )
421
+
422
+ if not isinstance(optimized_latents, torch.nn.Parameter):
423
+ optimized_latents = torch.nn.Parameter(
424
+ optimized_latents, requires_grad=False
425
+ )
426
+
427
+ pred_labels, _ = self._predict(self.model_, latent_vectors=optimized_latents)
428
+
429
+ missing_mask = X_to_impute == -1
430
+ imputed_array = X_to_impute.copy()
431
+ imputed_array[missing_mask] = pred_labels[missing_mask]
432
+
433
+ # Decode to IUPAC for return & optional plots
434
+ imputed_genotypes = self.pgenc.decode_012(imputed_array)
435
+ if self.show_plots:
436
+ original_genotypes = self.pgenc.decode_012(X_to_impute)
437
+ self.plotter_.plot_gt_distribution(original_genotypes, is_imputed=False)
438
+ self.plotter_.plot_gt_distribution(imputed_genotypes, is_imputed=True)
439
+ return imputed_genotypes
440
+
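An end-to-end sketch of the public fit/transform API above, assuming the example data bundled in this wheel and SNPio's VCFReader; the reader arguments are an assumption and may differ across SNPio versions.

    from snpio import VCFReader  # assumed SNPio entry point; check your SNPio version
    from pgsui.impute.unsupervised.imputers.ubp import ImputeUBP

    gd = VCFReader(
        filename="pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz",
        popmapfile="pgsui/example_data/popmaps/phylogen_nomx.popmap",
    )

    imputer = ImputeUBP(gd, config={"train": {"max_epochs": 100}})
    imputer.fit()                        # three-phase UBP training on 0/1/2 encodings
    iupac = imputer.transform()          # (n_samples x n_loci) IUPAC single-character array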
441
+ def _train_step(
442
+ self,
443
+ loader: torch.utils.data.DataLoader,
444
+ optimizer: torch.optim.Optimizer,
445
+ latent_optimizer: torch.optim.Optimizer,
446
+ model: torch.nn.Module,
447
+ l1_penalty: float,
448
+ latent_vectors: torch.nn.Parameter,
449
+ class_weights: torch.Tensor,
450
+ phase: int,
451
+ ) -> Tuple[float, torch.nn.Parameter]:
452
+ """One epoch with stable focal CE, grad clipping, and NaN guards.
453
+
454
+ Returns:
455
+ Tuple[float, torch.nn.Parameter]: Mean loss and updated latents.
456
+ """
457
+ model.train()
458
+ running, used = 0.0, 0
459
+
460
+ if not isinstance(latent_vectors, torch.nn.Parameter):
461
+ latent_vectors = torch.nn.Parameter(latent_vectors, requires_grad=True)
462
+
463
+ gamma = float(getattr(model, "gamma", getattr(self, "gamma", 0.0)))
464
+ gamma = max(0.0, min(gamma, 10.0))
465
+ l1_params = tuple(p for p in model.parameters() if p.requires_grad)
466
+ if class_weights is not None and class_weights.device != self.device:
467
+ class_weights = class_weights.to(self.device)
468
+
469
+ criterion = SafeFocalCELoss(gamma=gamma, weight=class_weights, ignore_index=-1)
470
+ decoder: torch.Tensor | torch.nn.Module = (
471
+ model.phase1_decoder if phase == 1 else model.phase23_decoder
472
+ )
473
+
474
+ if not isinstance(decoder, torch.nn.Module):
475
+ msg = f"{self.model_name} Decoder is not a torch.nn.Module."
476
+ self.logger.error(msg)
477
+ raise TypeError(msg)
478
+
479
+ for batch_indices, y_batch in loader:
480
+ optimizer.zero_grad(set_to_none=True)
481
+ latent_optimizer.zero_grad(set_to_none=True)
482
+
483
+ batch_indices = batch_indices.to(latent_vectors.device, non_blocking=True)
484
+ z = latent_vectors[batch_indices]
485
+ y = y_batch.to(self.device, non_blocking=True).long()
486
+
487
+ logits = decoder(z).view(
488
+ len(batch_indices), self.num_features_, self.num_classes_
489
+ )
490
+
491
+ # Guard upstream explosions
492
+ if not torch.isfinite(logits).all():
493
+ continue
494
+
495
+ loss = criterion(logits.view(-1, self.num_classes_), y.view(-1))
496
+
497
+ if l1_penalty > 0:
498
+ l1 = torch.zeros((), device=self.device)
499
+ for p in l1_params:
500
+ l1 = l1 + p.abs().sum()
501
+ loss = loss + l1_penalty * l1
502
+
503
+ if not torch.isfinite(loss):
504
+ continue
505
+
506
+ loss.backward()
507
+
508
+ # clip_grad_norm_ returns the total gradient norm
509
+ model_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
510
+ latent_norm = torch.nn.utils.clip_grad_norm_([latent_vectors], 1.0)
511
+
512
+ # Skip update on non-finite grads
513
+ # Check norms instead of iterating all parameters
514
+ if torch.isfinite(model_norm) and torch.isfinite(latent_norm):
515
+ optimizer.step()
516
+ if phase != 2:
517
+ latent_optimizer.step()
518
+ else:
519
+ # Non-finite gradients: discard this batch's gradients and skip the update
520
+ optimizer.zero_grad(set_to_none=True)
521
+ latent_optimizer.zero_grad(set_to_none=True)
522
+
523
+ running += float(loss.detach().item())
524
+ used += 1
525
+
526
+ return (running / used if used > 0 else float("inf")), latent_vectors
527
+
528
+ def _predict(
529
+ self,
530
+ model: torch.nn.Module,
531
+ latent_vectors: Optional[torch.nn.Parameter | torch.Tensor] = None,
532
+ ) -> Tuple[np.ndarray, np.ndarray]:
533
+ """Predict 0/1/2 labels & probabilities from latents via phase23 decoder. This method requires a trained model and latent vectors.
534
+
535
+ Args:
536
+ model (torch.nn.Module): Trained model.
537
+ latent_vectors (torch.nn.Parameter | torch.Tensor | None): Per-sample latent vectors to decode.
538
+
539
+ Returns:
540
+ Tuple[np.ndarray, np.ndarray]: Predicted labels and probabilities.
541
+ """
542
+ if model is None or latent_vectors is None:
543
+ msg = "Model and latent vectors must be provided for prediction. Fit the model first."
544
+ self.logger.error(msg)
545
+ raise NotFittedError(msg)
546
+
547
+ model.eval()
548
+ nF = getattr(model, "n_features", self.num_features_)
549
+ with torch.no_grad():
550
+ decoder = model.phase23_decoder
551
+
552
+ if not isinstance(decoder, torch.nn.Module):
553
+ msg = f"{self.model_name} decoder is not a valid torch.nn.Module."
554
+ self.logger.error(msg)
555
+ raise TypeError(msg)
556
+
557
+ logits = decoder(latent_vectors.to(self.device)).view(
558
+ len(latent_vectors), nF, self.num_classes_
559
+ )
560
+ probas = torch.softmax(logits, dim=-1)
561
+ labels = torch.argmax(probas, dim=-1)
562
+
563
+ return labels.cpu().numpy(), probas.cpu().numpy()
564
+
565
+ def _evaluate_model(
566
+ self,
567
+ X_val: np.ndarray,
568
+ model: torch.nn.Module,
569
+ params: dict,
570
+ objective_mode: bool = False,
571
+ latent_vectors_val: torch.Tensor | None = None,
572
+ *,
573
+ eval_mask_override: np.ndarray | None = None,
574
+ ) -> Dict[str, float]:
575
+ """Evaluates the model on a validation set.
576
+
577
+ This method evaluates the trained UBP model on a validation dataset by optimizing latent vectors for the validation samples, predicting genotypes, and computing various performance metrics. It can operate in an objective mode that suppresses logging for automated evaluations.
578
+
579
+ Args:
580
+ X_val (np.ndarray): Validation data in 0/1/2 encoding with -1 for missing.
581
+ model (torch.nn.Module): Trained UBP model.
582
+ params (dict): Model parameters.
583
+ objective_mode (bool): If True, suppresses logging and reports only the metric.
584
+ latent_vectors_val (torch.Tensor | None): Pre-optimized latent vectors for validation data.
585
+ eval_mask_override (np.ndarray | None): Boolean mask to specify which entries to evaluate.
586
+
587
+ Returns:
588
+ Dict[str, float]: Dictionary of evaluation metrics.
589
+ """
590
+ if latent_vectors_val is not None:
591
+ test_latent_vectors = latent_vectors_val
592
+ else:
593
+ test_latent_vectors = self._optimize_latents_for_inference(
594
+ X_val, model, params
595
+ )
596
+
597
+ pred_labels, pred_probas = self._predict(
598
+ model=model, latent_vectors=test_latent_vectors
599
+ )
600
+
601
+ if eval_mask_override is not None:
602
+ # Validate row counts to allow feature subsetting during tuning
603
+ if eval_mask_override.shape[0] != X_val.shape[0]:
604
+ msg = (
605
+ f"eval_mask_override rows {eval_mask_override.shape[0]} "
606
+ f"does not match X_val rows {X_val.shape[0]}"
607
+ )
608
+ self.logger.error(msg)
609
+ raise ValueError(msg)
610
+
611
+ # FIX: Slice mask columns if override is wider than current X_val (tune_fast)
612
+ if eval_mask_override.shape[1] > X_val.shape[1]:
613
+ eval_mask = eval_mask_override[:, : X_val.shape[1]].astype(bool)
614
+ else:
615
+ eval_mask = eval_mask_override.astype(bool)
616
+ else:
617
+ # Default: score only observed entries
618
+ eval_mask = X_val != -1
619
+
620
+ # y_true should be drawn from the pre-mask ground truth
621
+ # Map X_val back to the correct full ground truth slice
622
+ # FIX: Check shape[0] (n_samples) only.
623
+ if X_val.shape[0] == self.X_test_.shape[0]:
624
+ GT_ref = self.GT_test_full_
625
+ elif X_val.shape[0] == self.X_train_.shape[0]:
626
+ GT_ref = self.GT_train_full_
627
+ else:
628
+ GT_ref = self.ground_truth_
629
+
630
+ # FIX: Slice Ground Truth columns if it is wider than X_val (tune_fast)
631
+ if GT_ref.shape[1] > X_val.shape[1]:
632
+ GT_ref = GT_ref[:, : X_val.shape[1]]
633
+
634
+ # Fallback safeguard
635
+ if GT_ref.shape != X_val.shape:
636
+ GT_ref = X_val
637
+
638
+ y_true_flat = GT_ref[eval_mask]
639
+ pred_labels_flat = pred_labels[eval_mask]
640
+ pred_probas_flat = pred_probas[eval_mask]
641
+
642
+ if y_true_flat.size == 0:
643
+ return {self.tune_metric: 0.0}
644
+
645
+ # Haploid genotypes were already remapped to {0, 1} in fit(); choose label sets to match
646
+ labels_for_scoring = [0, 1] if self.is_haploid else [0, 1, 2]
647
+ target_names = ["REF", "ALT"] if self.is_haploid else ["REF", "HET", "ALT"]
648
+
649
+ y_true_ohe = np.eye(len(labels_for_scoring))[y_true_flat]
650
+
651
+ metrics = self.scorers_.evaluate(
652
+ y_true_flat,
653
+ pred_labels_flat,
654
+ y_true_ohe,
655
+ pred_probas_flat,
656
+ objective_mode,
657
+ self.tune_metric,
658
+ )
659
+
660
+ if not objective_mode:
661
+ pm = PrettyMetrics(
662
+ metrics, precision=3, title=f"{self.model_name} Validation Metrics"
663
+ )
664
+ pm.render() # prints a command-line table
665
+
666
+ self._make_class_reports(
667
+ y_true=y_true_flat,
668
+ y_pred_proba=pred_probas_flat,
669
+ y_pred=pred_labels_flat,
670
+ metrics=metrics,
671
+ labels=target_names,
672
+ )
673
+
674
+ # FIX: Use X_val dimensions for reshaping, not self.num_features_
675
+ y_true_dec = self.pgenc.decode_012(
676
+ GT_ref.reshape(X_val.shape[0], X_val.shape[1])
677
+ )
678
+
679
+ X_pred = X_val.copy()
680
+ X_pred[eval_mask] = pred_labels_flat
681
+
682
+ y_pred_dec = self.pgenc.decode_012(
683
+ X_pred.reshape(X_val.shape[0], X_val.shape[1])
684
+ )
685
+
686
+ encodings_dict = {
687
+ "A": 0,
688
+ "C": 1,
689
+ "G": 2,
690
+ "T": 3,
691
+ "W": 4,
692
+ "R": 5,
693
+ "M": 6,
694
+ "K": 7,
695
+ "Y": 8,
696
+ "S": 9,
697
+ "N": -1,
698
+ }
699
+
700
+ y_true_int = self.pgenc.convert_int_iupac(
701
+ y_true_dec, encodings_dict=encodings_dict
702
+ )
703
+ y_pred_int = self.pgenc.convert_int_iupac(
704
+ y_pred_dec, encodings_dict=encodings_dict
705
+ )
706
+
707
+ # For IUPAC report
708
+ valid_true = y_true_int[eval_mask]
709
+ valid_true = valid_true[valid_true >= 0] # drop -1 (N)
710
+ iupac_label_set = ["A", "C", "G", "T", "W", "R", "M", "K", "Y", "S"]
711
+
712
+ # For numeric report
713
+ if (
714
+ np.intersect1d(np.unique(y_true_flat), labels_for_scoring).size == 0
715
+ or valid_true.size == 0
716
+ ):
717
+ if not objective_mode:
718
+ self.logger.warning(
719
+ "Skipped numeric confusion matrix: no y_true labels present."
720
+ )
721
+ else:
722
+ self._make_class_reports(
723
+ y_true=valid_true,
724
+ y_pred=y_pred_int[eval_mask][y_true_int[eval_mask] >= 0],
725
+ metrics=metrics,
726
+ y_pred_proba=None,
727
+ labels=iupac_label_set,
728
+ )
729
+
730
+ return metrics
731
+
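A small NumPy illustration of the masking logic above: by default only observed entries are scored, while eval_mask_override (for example the simulated-missing mask) restricts scoring to the artificially hidden cells.

    import numpy as np

    X_val = np.array([[0, -1, 2],
                      [1,  1, -1]])
    default_mask = X_val != -1                      # score observed entries only
    sim_mask = np.array([[True, False, False],
                         [False, True, False]])     # passed as eval_mask_override
    # With the override, y_true and predictions are compared only where sim_mask is True.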
732
+ def _get_data_loaders(self, y: np.ndarray) -> torch.utils.data.DataLoader:
733
+ """Create DataLoader over indices + 0/1/2 target matrix.
734
+
735
+ This method wraps the genotype matrix (0/1/2 encodings with -1 for missing values) in a TensorDataset of (index, genotype) pairs and returns a DataLoader that yields shuffled mini-batches of the configured batch size. Memory pinning is enabled when training on a CUDA device.
736
+
737
+ Args:
738
+ y (np.ndarray): (n_samples x L) int matrix with -1 missing.
739
+
740
+ Returns:
741
+ torch.utils.data.DataLoader: Shuffled mini-batches.
742
+ """
743
+ y_tensor = torch.from_numpy(y).long()
744
+ indices = torch.arange(len(y), dtype=torch.long)
745
+ dataset = torch.utils.data.TensorDataset(indices, y_tensor)
746
+ pin_memory = self.device.type == "cuda"
747
+ return torch.utils.data.DataLoader(
748
+ dataset,
749
+ batch_size=self.batch_size,
750
+ shuffle=True,
751
+ pin_memory=pin_memory,
752
+ )
753
+
754
+ def _objective(self, trial: optuna.Trial) -> float:
755
+ """Optuna objective using the UBP training loop."""
756
+ try:
757
+ params = self._sample_hyperparameters(trial)
758
+
759
+ X_train_trial = getattr(
760
+ self, "X_train_", self.ground_truth_[self.train_idx_]
761
+ )
762
+ X_test_trial = getattr(self, "X_test_", self.ground_truth_[self.test_idx_])
763
+
764
+ class_weights = self._normalize_class_weights(
765
+ self._class_weights_from_zygosity(X_train_trial)
766
+ )
767
+ train_loader = self._get_data_loaders(X_train_trial)
768
+
769
+ train_latent_vectors = self._create_latent_space(
770
+ params, len(X_train_trial), X_train_trial, params["latent_init"]
771
+ )
772
+
773
+ model = self.build_model(self.Model, params["model_params"])
774
+ model.n_features = params["model_params"]["n_features"]
775
+ model.apply(self.initialize_weights)
776
+
777
+ _, model, __ = self._train_and_validate_model(
778
+ model=model,
779
+ loader=train_loader,
780
+ lr=params["lr"],
781
+ l1_penalty=params["l1_penalty"],
782
+ trial=trial,
783
+ return_history=False,
784
+ latent_vectors=train_latent_vectors,
785
+ lr_input_factor=params["lr_input_factor"],
786
+ class_weights=class_weights,
787
+ X_val=X_test_trial,
788
+ params=params,
789
+ prune_metric=self.tune_metric,
790
+ prune_warmup_epochs=5,
791
+ eval_interval=1,
792
+ eval_requires_latents=True,
793
+ eval_latent_steps=self.eval_latent_steps,
794
+ eval_latent_lr=self.eval_latent_lr,
795
+ eval_latent_weight_decay=self.eval_latent_weight_decay,
796
+ )
797
+
798
+ eval_mask = (
799
+ self.sim_mask_test_
800
+ if (
801
+ self.simulate_missing
802
+ and getattr(self, "sim_mask_test_", None) is not None
803
+ )
804
+ else None
805
+ )
806
+ metrics = self._evaluate_model(
807
+ X_test_trial,
808
+ model,
809
+ params,
810
+ objective_mode=True,
811
+ eval_mask_override=eval_mask,
812
+ )
813
+ self._clear_resources(
814
+ model, train_loader, latent_vectors=train_latent_vectors
815
+ )
816
+ return metrics[self.tune_metric]
817
+ except Exception as e:
818
+ raise optuna.exceptions.TrialPruned(f"Trial failed with error: {e}")
819
+
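A hedged sketch of switching on Optuna tuning through the flat dot-key overrides mentioned in __init__; the keys mirror the self.cfg.tune.* attributes read above, and UBPConfig may define further fields not shown in this diff.

    overrides = {
        "tune.enabled": True,
        "tune.n_trials": 50,
        "tune.metric": "f1",   # pr_macro | f1 | accuracy | average_precision | precision | recall | roc_auc
        "tune.fast": True,
    }
    imputer = ImputeUBP(gd, overrides=overrides)  # `gd` is a snpio GenotypeData object
    imputer.fit()                                 # runs _objective() over the trials before the final fit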
820
+ def _sample_hyperparameters(self, trial: optuna.Trial) -> dict:
821
+ """Sample UBP hyperparameters; compute hidden sizes for model_params.
822
+
823
+ This method samples the UBP search space (latent dimension, learning rate, dropout rate, number of hidden layers, activation, focal gamma, latent learning-rate factor, L1 penalty, layer scaling factor and schedule, and latent initialization), derives the hidden-layer sizes from the sampled values, and returns the sampled hyperparameters together with the assembled model_params dictionary.
824
+
825
+ Args:
826
+ trial (optuna.Trial): Current trial.
827
+
828
+ Returns:
829
+ Dict[str, int | float | str | list]: Sampled hyperparameters.
830
+ """
831
+ params = {
832
+ "latent_dim": trial.suggest_int("latent_dim", 2, 32),
833
+ "lr": trial.suggest_float("learning_rate", 1e-5, 1e-2, log=True),
834
+ "dropout_rate": trial.suggest_float("dropout_rate", 0.0, 0.6),
835
+ "num_hidden_layers": trial.suggest_int("num_hidden_layers", 1, 8),
836
+ "activation": trial.suggest_categorical(
837
+ "activation", ["relu", "elu", "selu"]
838
+ ),
839
+ "gamma": trial.suggest_float("gamma", 0.0, 5.0),
840
+ "lr_input_factor": trial.suggest_float(
841
+ "lr_input_factor", 0.1, 10.0, log=True
842
+ ),
843
+ "l1_penalty": trial.suggest_float("l1_penalty", 1e-7, 1e-2, log=True),
844
+ "layer_scaling_factor": trial.suggest_float(
845
+ "layer_scaling_factor", 2.0, 10.0
846
+ ),
847
+ "layer_schedule": trial.suggest_categorical(
848
+ "layer_schedule", ["pyramid", "constant", "linear"]
849
+ ),
850
+ "latent_init": trial.suggest_categorical("latent_init", ["random", "pca"]),
851
+ }
852
+
853
+ hidden_layer_sizes = self._compute_hidden_layer_sizes(
854
+ n_inputs=params["latent_dim"],
855
+ n_outputs=self.num_features_ * self.num_classes_,
856
+ n_samples=len(self.train_idx_),
857
+ n_hidden=params["num_hidden_layers"],
858
+ alpha=params["layer_scaling_factor"],
859
+ schedule=params["layer_schedule"],
860
+ )
861
+ # Keep the latent_dim as the first element,
862
+ # then the interior hidden widths.
863
+ # If there are no interior widths (very small nets),
864
+ # this still leaves [latent_dim].
865
+ hidden_only = [hidden_layer_sizes[0]] + hidden_layer_sizes[1:-1]
866
+
867
+ params["model_params"] = {
868
+ "n_features": self.num_features_,
869
+ "num_classes": self.num_classes_,
870
+ "latent_dim": params["latent_dim"],
871
+ "dropout_rate": params["dropout_rate"],
872
+ "hidden_layer_sizes": hidden_only,
873
+ "activation": params["activation"],
874
+ }
875
+
876
+ return params
877
+
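An illustration of the hidden_only construction used here and in _set_best_params; the example return value of _compute_hidden_layer_sizes is hypothetical (that helper is not part of this diff), but the slicing behaviour matches the comment above.

    # Hypothetical sizes: [latent_dim, interior widths..., output width]
    hidden_layer_sizes = [8, 512, 128, 42]
    hidden_only = [hidden_layer_sizes[0]] + hidden_layer_sizes[1:-1]
    assert hidden_only == [8, 512, 128]    # output width dropped; latent_dim kept first

    tiny = [16, 99]                        # no interior widths (very small net)
    assert [tiny[0]] + tiny[1:-1] == [16]  # still leaves just [latent_dim]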
878
+ def _set_best_params(self, best_params: dict) -> dict:
879
+ """Set best params onto instance; return model_params payload.
880
+
881
+ This method copies the best hyperparameters found during tuning onto the instance attributes, recomputes the hidden-layer sizes from those values, and returns the model_params dictionary used to build the UBP model.
882
+
883
+ Args:
884
+ best_params (dict): Best hyperparameters.
885
+
886
+ Returns:
887
+ dict: model_params payload.
888
+
889
+ Raises:
890
+ KeyError: If best_params is missing required keys.
891
+ """
892
+ self.latent_dim = best_params["latent_dim"]
893
+ self.dropout_rate = best_params["dropout_rate"]
894
+ self.learning_rate = best_params["learning_rate"]
895
+ self.gamma = best_params["gamma"]
896
+ self.lr_input_factor = best_params["lr_input_factor"]
897
+ self.l1_penalty = best_params["l1_penalty"]
898
+ self.activation = best_params["activation"]
899
+ self.latent_init = best_params["latent_init"]
900
+
901
+ hidden_layer_sizes = self._compute_hidden_layer_sizes(
902
+ n_inputs=self.latent_dim,
903
+ n_outputs=self.num_features_ * self.num_classes_,
904
+ n_samples=len(self.train_idx_),
905
+ n_hidden=best_params["num_hidden_layers"],
906
+ alpha=best_params["layer_scaling_factor"],
907
+ schedule=best_params["layer_schedule"],
908
+ )
909
+
910
+ hidden_only = [hidden_layer_sizes[0]] + hidden_layer_sizes[1:-1]
911
+
912
+ return {
913
+ "n_features": self.num_features_,
914
+ "latent_dim": self.latent_dim,
915
+ "hidden_layer_sizes": hidden_only,
916
+ "dropout_rate": self.dropout_rate,
917
+ "activation": self.activation,
918
+ "gamma": self.gamma,
919
+ "num_classes": self.num_classes_,
920
+ }
921
+
922
+ def _set_best_params_default(self) -> dict:
923
+ """Default (no-tuning) model_params aligned with current attributes.
924
+
925
+ This method builds the model parameters from the current instance attributes, computing the hidden-layer sizes from the latent dimension, number of hidden layers, layer scaling factor, and layer schedule. It is used when no hyperparameter tuning has been performed.
926
+
927
+ Returns:
928
+ dict: model_params payload.
929
+ """
930
+ hidden_layer_sizes = self._compute_hidden_layer_sizes(
931
+ n_inputs=self.latent_dim,
932
+ n_outputs=self.num_features_ * self.num_classes_,
933
+ n_samples=len(self.ground_truth_),
934
+ n_hidden=self.num_hidden_layers,
935
+ alpha=self.layer_scaling_factor,
936
+ schedule=self.layer_schedule,
937
+ )
938
+
939
+ hidden_only = [hidden_layer_sizes[0]] + hidden_layer_sizes[1:-1]
940
+
941
+ return {
942
+ "n_features": self.num_features_,
943
+ "latent_dim": self.latent_dim,
944
+ "hidden_layer_sizes": hidden_only,
945
+ "dropout_rate": self.dropout_rate,
946
+ "activation": self.activation,
947
+ "gamma": self.gamma,
948
+ "num_classes": self.num_classes_,
949
+ }
950
+
951
+ def _train_and_validate_model(
952
+ self,
953
+ model: torch.nn.Module,
954
+ loader: torch.utils.data.DataLoader,
955
+ lr: float,
956
+ l1_penalty: float,
957
+ trial: optuna.Trial | None = None,
958
+ return_history: bool = False,
959
+ latent_vectors: torch.nn.Parameter | None = None,
960
+ lr_input_factor: float = 1.0,
961
+ class_weights: torch.Tensor | None = None,
962
+ *,
963
+ X_val: np.ndarray | None = None,
964
+ params: dict | None = None,
965
+ prune_metric: str | None = None, # "f1" | "accuracy" | "pr_macro"
966
+ prune_warmup_epochs: int = 3,
967
+ eval_interval: int = 1,
968
+ eval_requires_latents: bool = True, # UBP needs latent eval
969
+ eval_latent_steps: int = 50,
970
+ eval_latent_lr: float = 1e-2,
971
+ eval_latent_weight_decay: float = 0.0,
972
+ ) -> tuple:
973
+ """Train & validate UBP model with three-phase loop.
974
+
975
+ This method sets up the latent optimizer and delegates to the three-phase training loop (joint warm start, decoder reset, joint fine-tune). Latent vectors and class weights must be provided; the optional validation inputs enable per-epoch evaluation and Optuna pruning. The best loss, best model, training history, and optimized latent vectors are returned.
976
+
977
+ Args:
978
+ model (torch.nn.Module): UBP model with phase1_decoder & phase23_decoder.
979
+ loader (torch.utils.data.DataLoader): DataLoader for training data.
980
+ lr (float): Learning rate for decoder.
981
+ l1_penalty (float): L1 regularization weight.
982
+ trial (optuna.Trial | None): Current trial or None.
983
+ return_history (bool): If True, return loss history.
984
+ latent_vectors (torch.nn.Parameter | None): Trainable Z.
985
+ lr_input_factor (float): LR factor for latents.
986
+ class_weights (torch.Tensor | None): Class weights for 0/1/2.
987
+ X_val (np.ndarray | None): Validation set for pruning/eval.
988
+ params (dict | None): Model params for eval.
989
+ prune_metric (str | None): Metric to monitor for pruning.
990
+ prune_warmup_epochs (int): Epochs before pruning starts.
991
+ eval_interval (int): Epochs between evaluations.
992
+ eval_requires_latents (bool): If True, optimize latents for eval.
993
+ eval_latent_steps (int): Latent optimization steps for eval.
994
+ eval_latent_lr (float): Latent optimization LR for eval.
995
+ eval_latent_weight_decay (float): Latent optimization weight decay for eval.
996
+
997
+ Returns:
998
+ Tuple: (best_loss, best_model, history, latents) when return_history is True; otherwise (best_loss, best_model, latents).
999
+
1000
+ Raises:
1001
+ TypeError: If latent_vectors or class_weights are
1002
+ not provided.
1003
+ ValueError: If X_val is not provided for evaluation.
1004
+ RuntimeError: If eval_latent_steps is not positive.
1005
+ """
1006
+ if latent_vectors is None or class_weights is None:
1007
+ msg = "Must provide latent_vectors and class_weights."
1008
+ self.logger.error(msg)
1009
+ raise TypeError(msg)
1010
+
1011
+ latent_optimizer = torch.optim.Adam([latent_vectors], lr=lr * lr_input_factor)
1012
+
1013
+ result = self._execute_training_loop(
1014
+ loader=loader,
1015
+ latent_optimizer=latent_optimizer,
1016
+ lr=lr,
1017
+ model=model,
1018
+ l1_penalty=l1_penalty,
1019
+ trial=trial,
1020
+ return_history=return_history,
1021
+ latent_vectors=latent_vectors,
1022
+ class_weights=class_weights,
1023
+ # Evaluation/pruning options forwarded to the training loop
1024
+ X_val=X_val,
1025
+ params=params,
1026
+ prune_metric=prune_metric,
1027
+ prune_warmup_epochs=prune_warmup_epochs,
1028
+ eval_interval=eval_interval,
1029
+ eval_requires_latents=eval_requires_latents,
1030
+ eval_latent_steps=eval_latent_steps,
1031
+ eval_latent_lr=eval_latent_lr,
1032
+ eval_latent_weight_decay=eval_latent_weight_decay,
1033
+ )
1034
+
1035
+ if return_history:
1036
+ return result
1037
+
1038
+ return result[0], result[1], result[3]
1039
+
1040
+ def _train_final_model(
1041
+ self,
1042
+ loader: torch.utils.data.DataLoader,
1043
+ best_params: dict,
1044
+ initial_latent_vectors: torch.nn.Parameter,
1045
+ ) -> tuple:
1046
+ """Train final UBP model with best params; save weights to disk.
1047
+
1048
+ This method trains the final UBP model using the best hyperparameters found during tuning. It builds the model with the specified parameters, initializes the weights, and invokes the training and validation process. The method saves the trained model's state dictionary to disk and returns the final loss, trained model, training history, and optimized latent vectors.
1049
+
1050
+ Args:
1051
+ loader (torch.utils.data.DataLoader): DataLoader for training data.
1052
+ best_params (Dict[str, int | float | str | list]): Best hyperparameters.
1053
+ initial_latent_vectors (torch.nn.Parameter): Initialized latent vectors.
1054
+
1055
+ Returns:
1056
+ Tuple[float, torch.nn.Module, dict, torch.nn.Parameter]: (loss, model, {"Train": history}, latents).
1057
+ """
1058
+ self.logger.info(f"Training the final {self.model_name} model...")
1059
+
1060
+ model = self.build_model(self.Model, best_params)
1061
+ model.n_features = best_params["n_features"]
1062
+ model.apply(self.initialize_weights)
1063
+
1064
+ loss, trained_model, history, latent_vectors = self._train_and_validate_model(
1065
+ model=model,
1066
+ loader=loader,
1067
+ lr=self.learning_rate,
1068
+ l1_penalty=self.l1_penalty,
1069
+ return_history=True,
1070
+ latent_vectors=initial_latent_vectors,
1071
+ lr_input_factor=self.lr_input_factor,
1072
+ class_weights=self.class_weights_,
1073
+ X_val=self.X_test_,
1074
+ params=best_params,
1075
+ prune_metric=self.tune_metric,
1076
+ prune_warmup_epochs=5,
1077
+ eval_interval=1,
1078
+ eval_requires_latents=True,
1079
+ eval_latent_steps=self.eval_latent_steps,
1080
+ eval_latent_lr=self.eval_latent_lr,
1081
+ eval_latent_weight_decay=self.eval_latent_weight_decay,
1082
+ )
1083
+
1084
+ if trained_model is None:
1085
+ msg = "Final model training failed."
1086
+ self.logger.error(msg)
1087
+ raise RuntimeError(msg)
1088
+
1089
+ fout = self.models_dir / "final_model.pt"
1090
+ torch.save(trained_model.state_dict(), fout)
1091
+ return loss, trained_model, {"Train": history}, latent_vectors
1092
+
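A hedged sketch of reloading the weights saved here; `best_params` is the model_params payload described earlier, `models_dir` mirrors self.models_dir (a pathlib.Path set by the base class), and passing the payload straight to UBPModel's constructor is an assumption about its signature.

    import torch
    from pgsui.impute.unsupervised.models.ubp_model import UBPModel

    model = UBPModel(**best_params)                       # assumes UBPModel accepts the model_params keys
    state = torch.load(models_dir / "final_model.pt")     # same path written by _train_final_model
    model.load_state_dict(state)
    model.eval()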
1093
+ def _execute_training_loop(
1094
+ self,
1095
+ loader: torch.utils.data.DataLoader,
1096
+ latent_optimizer: torch.optim.Optimizer,
1097
+ lr: float,
1098
+ model: torch.nn.Module,
1099
+ l1_penalty: float,
1100
+ trial: optuna.Trial | None,
1101
+ return_history: bool,
1102
+ latent_vectors: torch.nn.Parameter,
1103
+ class_weights: torch.Tensor,
1104
+ *,
1105
+ X_val: np.ndarray | None = None,
1106
+ params: dict | None = None,
1107
+ prune_metric: str | None = None,
1108
+ prune_warmup_epochs: int = 3,
1109
+ eval_interval: int = 1,
1110
+ eval_requires_latents: bool = True,
1111
+ eval_latent_steps: int = 50,
1112
+ eval_latent_lr: float = 1e-2,
1113
+ eval_latent_weight_decay: float = 0.0,
1114
+ ) -> Tuple[float, torch.nn.Module, dict, torch.nn.Parameter]:
1115
+ """Three-phase UBP with numeric guards, LR warmup, and pruning.
1116
+
1117
+ This method executes the three-phase training loop for the UBP model with numeric stability guards, learning-rate warmup, and Optuna pruning. Phase 1 jointly warm-starts the latents and the shallow phase-1 decoder, phase 2 resets and trains the deeper phase-2/3 decoder with the latents held fixed, and phase 3 jointly fine-tunes latents and decoder. The loop tracks training loss, applies early stopping per phase, and periodically evaluates on a validation set for pruning. The best loss, best model, per-phase history, and optimized latent vectors are returned.
1118
+
1119
+ Args:
1120
+ loader (torch.utils.data.DataLoader): DataLoader for training data.
1121
+ latent_optimizer (torch.optim.Optimizer): Optimizer for latent vectors.
1122
+ lr (float): Learning rate for decoder.
1123
+ model (torch.nn.Module): UBP model with phase1_decoder & phase23_decoder.
1124
+ l1_penalty (float): L1 regularization weight.
1125
+ trial (optuna.Trial | None): Current trial or None.
1126
+ return_history (bool): If True, return loss history.
1127
+ latent_vectors (torch.nn.Parameter): Trainable Z.
1128
+ class_weights (torch.Tensor): Class weights for
1129
+ 0/1/2.
1130
+ X_val (np.ndarray | None): Validation set for pruning/eval.
1131
+ params (dict | None): Model params for eval.
1132
+ prune_metric (str | None): Metric to monitor for pruning.
1133
+ prune_warmup_epochs (int): Epochs before pruning starts.
1134
+ eval_interval (int): Epochs between evaluations.
1135
+ eval_requires_latents (bool): If True, optimize latents for eval.
1136
+ eval_latent_steps (int): Latent optimization steps for eval.
1137
+ eval_latent_lr (float): Latent optimization LR for eval.
1138
+ eval_latent_weight_decay (float): Latent optimization weight decay for eval.
1139
+
1140
+ Returns:
1141
+ Tuple[float, torch.nn.Module, dict, torch.nn.Parameter]: (best_loss, best_model, history, latents).
1142
+
1143
+ Raises:
1144
+ ValueError: If X_val is not provided for evaluation.
1145
+ RuntimeError: If eval_latent_steps is not positive.
1146
+ """
1147
+ history: dict[str, list[float]] = {}
1148
+ final_best_loss, final_best_model = float("inf"), None
1149
+
1150
+ warm, ramp, gamma_final = 50, 100, torch.tensor(self.gamma, device=self.device)
1151
+
1152
+ # Schema-aware latent cache for eval
1153
+ _latent_cache: dict = {}
1154
+ nF = getattr(model, "n_features", self.num_features_)
1155
+ cache_key_root = f"{self.prefix}_ubp_val_latents_L{nF}_K{self.num_classes_}"
1156
+
1157
+ E = int(self.epochs)
1158
+ phase_epochs = {
1159
+ 1: max(1, int(0.15 * E)),
1160
+ 2: max(1, int(0.35 * E)),
1161
+ 3: max(1, E - int(0.15 * E) - int(0.35 * E)),
1162
+ }
1163
+
1164
+ for phase in (1, 2, 3):
1165
+ steps_this_phase = phase_epochs[phase]
1166
+ warmup_epochs = getattr(self, "lr_warmup_epochs", 5) if phase == 1 else 0
1167
+
1168
+ early_stopping = EarlyStopping(
1169
+ patience=self.early_stop_gen,
1170
+ min_epochs=self.min_epochs,
1171
+ verbose=self.verbose,
1172
+ prefix=self.prefix,
1173
+ debug=self.debug,
1174
+ )
1175
+
1176
+ if phase == 2:
1177
+ self._reset_weights(model)
1178
+
1179
+ decoder: torch.Tensor | torch.nn.Module = (
1180
+ model.phase1_decoder if phase == 1 else model.phase23_decoder
1181
+ )
1182
+
1183
+ if not isinstance(decoder, torch.nn.Module):
1184
+ msg = f"{self.model_name} Decoder is not a torch.nn.Module."
1185
+ self.logger.error(msg)
1186
+ raise TypeError(msg)
1187
+
1188
+ decoder_params = decoder.parameters()
1189
+ optimizer = torch.optim.AdamW(decoder_params, lr=lr, eps=1e-7)
1190
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
1191
+ optimizer, T_max=steps_this_phase
1192
+ )
1193
+
1194
+ # Cache base LRs for warmup
1195
+ dec_lr0 = optimizer.param_groups[0]["lr"]
1196
+ lat_lr0 = latent_optimizer.param_groups[0]["lr"]
1197
+ dec_lr_min, lat_lr_min = dec_lr0 * 0.1, lat_lr0 * 0.1
1198
+
1199
+ phase_hist: list[float] = []
1200
+ gamma_init = torch.tensor(0.0, device=self.device)
1201
+
1202
+ for epoch in range(steps_this_phase):
1203
+ # Focal gamma warm/ramp
1204
+ if epoch < warm:
1205
+ model.gamma = gamma_init.cpu().numpy().item()
1206
+ elif epoch < warm + ramp:
1207
+ model.gamma = gamma_final * ((epoch - warm) / ramp)
1208
+ else:
1209
+ model.gamma = gamma_final
1210
+
1211
+ # Linear warmup for both optimizers
1212
+ if warmup_epochs and epoch < warmup_epochs:
1213
+ scale = float(epoch + 1) / warmup_epochs
1214
+ for g in optimizer.param_groups:
1215
+ g["lr"] = dec_lr_min + (dec_lr0 - dec_lr_min) * scale
1216
+ for g in latent_optimizer.param_groups:
1217
+ g["lr"] = lat_lr_min + (lat_lr0 - lat_lr_min) * scale
1218
+
1219
+ train_loss, latent_vectors = self._train_step(
1220
+ loader=loader,
1221
+ optimizer=optimizer,
1222
+ latent_optimizer=latent_optimizer,
1223
+ model=model,
1224
+ l1_penalty=l1_penalty,
1225
+ latent_vectors=latent_vectors,
1226
+ class_weights=class_weights,
1227
+ phase=phase,
1228
+ )
1229
+
1230
+ if not np.isfinite(train_loss):
1231
+ if trial:
1232
+ raise optuna.exceptions.TrialPruned("Epoch loss non-finite.")
1233
+ # reduce LRs and continue
1234
+ for g in optimizer.param_groups:
1235
+ g["lr"] *= 0.5
1236
+ for g in latent_optimizer.param_groups:
1237
+ g["lr"] *= 0.5
1238
+ continue
1239
+
1240
+ scheduler.step()
1241
+ if return_history:
1242
+ phase_hist.append(train_loss)
1243
+
1244
+ early_stopping(train_loss, model)
1245
+ if early_stopping.early_stop:
1246
+ self.logger.info(
1247
+ f"Early stopping at epoch {epoch + 1} (phase {phase})."
1248
+ )
1249
+ break
1250
+
1251
+ # Validation + pruning
1252
+ if (
1253
+ trial is not None
1254
+ and X_val is not None
1255
+ and ((epoch + 1) % eval_interval == 0)
1256
+ ):
1257
+ metric_key = prune_metric or getattr(self, "tune_metric", "f1")
1258
+ zdim = self._first_linear_in_features(model)
1259
+ schema_key = f"{cache_key_root}_z{zdim}"
1260
+ mask_override = None
1261
+ if (
1262
+ self.simulate_missing
1263
+ and getattr(self, "sim_mask_test_", None) is not None
1264
+ and getattr(self, "X_test_", None) is not None
1265
+ and X_val.shape == self.X_test_.shape
1266
+ ):
1267
+ mask_override = self.sim_mask_test_
1268
+
1269
+ metric_val = self._eval_for_pruning(
1270
+ model=model,
1271
+ X_val=X_val,
1272
+ params=params or getattr(self, "best_params_", {}),
1273
+ metric=metric_key,
1274
+ objective_mode=True,
1275
+ do_latent_infer=eval_requires_latents,
1276
+ latent_steps=eval_latent_steps,
1277
+ latent_lr=eval_latent_lr,
1278
+ latent_weight_decay=eval_latent_weight_decay,
1279
+ latent_seed=self.seed, # type: ignore
1280
+ _latent_cache=_latent_cache,
1281
+ _latent_cache_key=schema_key,
1282
+ eval_mask_override=mask_override,
1283
+ )
1284
+
1285
+ if phase == 3:
1286
+ trial.report(metric_val, step=epoch + 1)
1287
+ if (epoch + 1) >= prune_warmup_epochs and trial.should_prune():
1288
+ raise optuna.exceptions.TrialPruned(
1289
+ f"Pruned at epoch {epoch + 1} (phase {phase}): {metric_key}={metric_val:.5f}"
1290
+ )
1291
+
1292
+ history[f"Phase {phase}"] = phase_hist
1293
+ final_best_loss = early_stopping.best_score
1294
+ if early_stopping.best_model is not None:
1295
+ final_best_model = copy.deepcopy(early_stopping.best_model)
1296
+ else:
1297
+ final_best_model = copy.deepcopy(model)
1298
+
1299
+ if final_best_model is None:
1300
+ final_best_model = copy.deepcopy(model)
1301
+
1302
+ return final_best_loss, final_best_model, history, latent_vectors
1303
+
1304
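For readers skimming the hunk above: the loop splits the epoch budget roughly 15% / 35% / remainder across the three UBP phases and only ramps the focal-loss gamma after a warm-up window. The sketch below mirrors that arithmetic with standalone helpers; the names are illustrative and not part of PG-SUI's API.

```python
# Minimal sketch of the phase budgeting and focal-gamma schedule used above.
# `phase_budget` and `focal_gamma` are illustrative helpers, not PG-SUI functions.

def phase_budget(total_epochs: int) -> dict[int, int]:
    """Split a total epoch budget across the three UBP phases (~15% / ~35% / remainder)."""
    e1 = max(1, int(0.15 * total_epochs))
    e2 = max(1, int(0.35 * total_epochs))
    e3 = max(1, total_epochs - e1 - e2)
    return {1: e1, 2: e2, 3: e3}

def focal_gamma(epoch: int, gamma_final: float, warm: int = 50, ramp: int = 100) -> float:
    """Hold gamma at 0 for `warm` epochs, then ramp linearly up to `gamma_final`."""
    if epoch < warm:
        return 0.0
    if epoch < warm + ramp:
        return gamma_final * (epoch - warm) / ramp
    return gamma_final

if __name__ == "__main__":
    print(phase_budget(300))     # roughly 15% / 35% / remainder of 300 epochs
    print(focal_gamma(75, 2.0))  # 0.5: a quarter of the way through the ramp
```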
+    def _optimize_latents_for_inference(
+        self,
+        X_new: np.ndarray,
+        model: torch.nn.Module,
+        params: dict,
+        inference_epochs: int = 200,
+    ) -> torch.Tensor:
+        """Optimize latents for new 0/1/2 data with guards.
+
+        This method optimizes the latent vectors for new genotype data using the trained UBP model. It initializes the latent space from the provided data and iteratively updates the latent vectors to minimize the cross-entropy loss between the model's predictions and the observed genotype values. Numeric stability guards stop the optimization early if logits, losses, or gradients become non-finite. The optimized latent vectors are returned as a PyTorch tensor.
+
+        Args:
+            X_new (np.ndarray): New 0/1/2 data with -1 for missing.
+            model (torch.nn.Module): Trained UBP model.
+            params (dict): Model params.
+            inference_epochs (int): Number of optimization epochs.
+
+        Returns:
+            torch.Tensor: Optimized latent vectors.
+        """
+        model.eval()
+        nF = getattr(model, "n_features", self.num_features_)
+
+        if self.tune and self.tune_fast:
+            inference_epochs = min(
+                inference_epochs, getattr(self, "tune_infer_epochs", 20)
+            )
+
+        X_new = X_new.astype(np.int64, copy=False)
+        X_new[X_new < 0] = -1
+        y = torch.from_numpy(X_new).long().to(self.device)
+
+        z = self._create_latent_space(
+            params, len(X_new), X_new, self.latent_init
+        ).requires_grad_(True)
+        opt = torch.optim.AdamW(
+            [z], lr=self.learning_rate * self.lr_input_factor, eps=1e-7
+        )
+
+        # Validate the decoder once, outside the optimization loop.
+        decoder = model.phase23_decoder
+        if not isinstance(decoder, torch.nn.Module):
+            msg = f"{self.model_name} Decoder is not a torch.nn.Module."
+            self.logger.error(msg)
+            raise TypeError(msg)
+
+        for _ in range(inference_epochs):
+            opt.zero_grad(set_to_none=True)
+            logits = decoder(z).view(len(X_new), nF, self.num_classes_)
+
+            if not torch.isfinite(logits).all():
+                break
+
+            loss = F.cross_entropy(
+                logits.view(-1, self.num_classes_), y.view(-1), ignore_index=-1
+            )
+
+            if not torch.isfinite(loss):
+                break
+
+            loss.backward()
+
+            torch.nn.utils.clip_grad_norm_([z], 1.0)
+
+            if z.grad is None or not torch.isfinite(z.grad).all():
+                break
+
+            opt.step()
+
+        return z.detach()
+
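The method above refines per-sample latent inputs against a frozen decoder, keeping missing genotypes out of the loss via `ignore_index=-1`. A self-contained toy version of that pattern, using a stand-in `nn.Sequential` decoder rather than the UBP model, might look like this:

```python
# Toy sketch of decoder-only latent refinement with missing-data masking.
# The decoder is an illustrative stand-in, not PG-SUI's UBP model.
import numpy as np
import torch
import torch.nn.functional as F

n_samples, n_loci, n_classes, latent_dim = 8, 20, 3, 4
rng = np.random.default_rng(0)
X = rng.integers(0, n_classes, size=(n_samples, n_loci))
X[rng.random(X.shape) < 0.2] = -1          # mark ~20% of genotypes as missing
y = torch.from_numpy(X).long()

decoder = torch.nn.Sequential(              # frozen stand-in decoder
    torch.nn.Linear(latent_dim, 32), torch.nn.ReLU(),
    torch.nn.Linear(32, n_loci * n_classes),
)
for p in decoder.parameters():
    p.requires_grad_(False)

z = torch.randn(n_samples, latent_dim, requires_grad=True)
opt = torch.optim.AdamW([z], lr=0.05, eps=1e-7)

for _ in range(100):
    opt.zero_grad(set_to_none=True)
    logits = decoder(z).view(n_samples, n_loci, n_classes)
    # ignore_index=-1 excludes missing genotypes from the loss
    loss = F.cross_entropy(logits.view(-1, n_classes), y.view(-1), ignore_index=-1)
    if not torch.isfinite(loss):
        break
    loss.backward()
    torch.nn.utils.clip_grad_norm_([z], 1.0)
    opt.step()

print(float(loss))  # decreases as the latents fit the observed entries
```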
+    def _create_latent_space(
+        self,
+        params: dict,
+        n_samples: int,
+        X: np.ndarray,
+        latent_init: Literal["random", "pca"],
+    ) -> torch.nn.Parameter:
+        """Initialize latent space via random Xavier or PCA on 0/1/2 matrix.
+
+        This method initializes the latent space for the UBP model using either random Xavier initialization or PCA-based initialization, as selected by ``latent_init``. If PCA is selected, missing values are imputed with per-column means before the decomposition, and the resulting latent vectors are standardized and wrapped in a trainable PyTorch parameter.
+
+        Args:
+            params (dict): Contains 'latent_dim'.
+            n_samples (int): Number of samples.
+            X (np.ndarray): (n_samples x L) 0/1/2 with -1 missing.
+            latent_init (Literal["random", "pca"]): Init strategy.
+
+        Returns:
+            torch.nn.Parameter: Trainable latent matrix.
+        """
+        latent_dim = int(params["latent_dim"])
+
+        if latent_init == "pca":
+            X_pca = X.astype(np.float32, copy=True)
+            # Mark missing entries
+            X_pca[X_pca < 0] = np.nan
+
+            # Safe column means without runtime warnings
+            valid_counts = np.sum(~np.isnan(X_pca), axis=0)
+            col_sums = np.nansum(X_pca, axis=0)
+            col_means = np.divide(
+                col_sums,
+                valid_counts,
+                out=np.zeros_like(col_sums, dtype=np.float32),
+                where=valid_counts > 0,
+            )
+
+            # Impute NaNs with per-column means
+            # (all-NaN columns -> 0.0 via the guarded divide above)
+            nan_r, nan_c = np.where(np.isnan(X_pca))
+            if nan_r.size:
+                X_pca[nan_r, nan_c] = col_means[nan_c]
+
+            # Center columns
+            X_pca = X_pca - X_pca.mean(axis=0, keepdims=True)
+
+            # Guard: degenerate / all-zero after centering -> fall back to random
+            if (not np.isfinite(X_pca).all()) or np.allclose(X_pca, 0.0):
+                latents = torch.empty(n_samples, latent_dim, device=self.device)
+                torch.nn.init.xavier_uniform_(latents)
+                return torch.nn.Parameter(latents, requires_grad=True)
+
+            # Rank-aware component count, at least 1
+            try:
+                est_rank = np.linalg.matrix_rank(X_pca)
+            except Exception:
+                est_rank = min(n_samples, X_pca.shape[1])
+
+            n_components = max(1, min(latent_dim, est_rank, n_samples, X_pca.shape[1]))
+
+            # Seeded randomized SVD so the components are reproducible across runs
+            pca = PCA(
+                n_components=n_components,
+                svd_solver="randomized",
+                random_state=self.seed,
+            )
+            initial = pca.fit_transform(X_pca)  # (n_samples, n_components)
+
+            # Pad if latent_dim > n_components
+            if n_components < latent_dim:
+                pad = self.rng.standard_normal(
+                    size=(n_samples, latent_dim - n_components)
+                )
+                initial = np.hstack([initial, pad])
+
+            # Standardize latent dims
+            initial = (initial - initial.mean(axis=0)) / (initial.std(axis=0) + 1e-6)
+
+            latents = torch.from_numpy(initial).float().to(self.device)
+            return torch.nn.Parameter(latents, requires_grad=True)
+
+        else:
+            latents = torch.empty(n_samples, latent_dim, device=self.device)
+            torch.nn.init.xavier_uniform_(latents)
+
+            return torch.nn.Parameter(latents, requires_grad=True)
+
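The PCA branch above mean-imputes missing genotypes over observed entries, fits a seeded PCA with a rank-aware component count, pads with noise when the rank falls short of `latent_dim`, and standardizes the result. A standalone sketch of that recipe, with an illustrative helper name and assumed inputs, is:

```python
# Illustrative sketch (not PG-SUI's API): PCA-based latent initialization for a
# 0/1/2 genotype matrix that uses -1 for missing values.
import numpy as np
from sklearn.decomposition import PCA

def pca_latent_init(X: np.ndarray, latent_dim: int, seed: int = 42) -> np.ndarray:
    X = X.astype(np.float32, copy=True)
    X[X < 0] = np.nan

    # Column means over observed entries only; all-missing columns fall back to 0.0
    counts = np.sum(~np.isnan(X), axis=0)
    means = np.divide(np.nansum(X, axis=0), counts,
                      out=np.zeros(X.shape[1], dtype=np.float32),
                      where=counts > 0)
    r, c = np.where(np.isnan(X))
    X[r, c] = means[c]

    X -= X.mean(axis=0, keepdims=True)
    k = max(1, min(latent_dim, np.linalg.matrix_rank(X), *X.shape))
    Z = PCA(n_components=k, svd_solver="randomized", random_state=seed).fit_transform(X)

    if k < latent_dim:  # pad extra dimensions with standard-normal noise
        pad = np.random.default_rng(seed).standard_normal((X.shape[0], latent_dim - k))
        Z = np.hstack([Z, pad])
    return (Z - Z.mean(axis=0)) / (Z.std(axis=0) + 1e-6)

Z0 = pca_latent_init(np.array([[0, 1, 2, -1], [2, -1, 0, 1], [1, 1, -1, 0]]), latent_dim=3)
print(Z0.shape)  # (3, 3)
```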
+    def _reset_weights(self, model: torch.nn.Module) -> None:
+        """Selectively reset only the weights of the phase 2/3 decoder.
+
+        This method targets only the `phase23_decoder` attribute of the UBPModel, leaving the `phase1_decoder` and any other model components untouched. This allows the decoder to be re-initialized for the second phase of training without affecting the rest of the model.
+
+        Args:
+            model (torch.nn.Module): The PyTorch model whose parameters are to be reset.
+        """
+        if hasattr(model, "phase23_decoder"):
+            decoder = model.phase23_decoder
+            if not isinstance(decoder, torch.nn.Module):
+                msg = f"{self.model_name} phase23_decoder is not a torch.nn.Module."
+                self.logger.error(msg)
+                raise TypeError(msg)
+            # Iterate through only the modules of the second decoder;
+            # `reset_parameters` is a bound method, so guard with callable().
+            for layer in decoder.modules():
+                if hasattr(layer, "reset_parameters") and callable(
+                    layer.reset_parameters
+                ):
+                    layer.reset_parameters()
+        else:
+            self.logger.warning(
+                "Model does not have a 'phase23_decoder' attribute; skipping weight reset."
+            )
+
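The reset relies on layers such as `torch.nn.Linear` exposing a `reset_parameters()` method, so only the targeted submodule receives fresh weights while its siblings keep theirs. A minimal illustration with stand-in module names:

```python
# Minimal sketch of the selective reset pattern; "decoder_b" is an illustrative
# stand-in for the phase 2/3 decoder, "decoder_a" for everything left untouched.
import torch

model = torch.nn.ModuleDict({
    "decoder_a": torch.nn.Linear(4, 8),
    "decoder_b": torch.nn.Sequential(
        torch.nn.Linear(4, 8), torch.nn.ReLU(), torch.nn.Linear(8, 6)
    ),
})

before = model["decoder_a"].weight.clone()

for layer in model["decoder_b"].modules():
    # `reset_parameters` exists on parameterized layers (e.g. Linear) but not ReLU
    if hasattr(layer, "reset_parameters") and callable(layer.reset_parameters):
        layer.reset_parameters()

# decoder_a is untouched; decoder_b's Linear layers now have fresh weights
assert torch.equal(before, model["decoder_a"].weight)
```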
+    def _latent_infer_for_eval(
+        self,
+        model: torch.nn.Module,
+        X_val: np.ndarray,
+        *,
+        steps: int,
+        lr: float,
+        weight_decay: float,
+        seed: int | None,
+        cache: dict | None,
+        cache_key: str | None,
+    ) -> None:
+        """Freeze network; refine validation latents only, with guards.
+
+        This method refines the latent vectors for the validation dataset using the trained UBP model. The model parameters are frozen so that only the latent vectors are updated, minimizing the cross-entropy loss between the model's predictions and the observed genotype values. Numeric stability checks stop the refinement early if logits, losses, or gradients become non-finite. If a cache is provided, the optimized latent vectors are stored under a schema-aware key for reuse.
+
+        Args:
+            model (torch.nn.Module): Trained UBP model.
+            X_val (np.ndarray): Validation set 0/1/2 with -1 for missing.
+            steps (int): Number of optimization steps.
+            lr (float): Learning rate for latent optimization.
+            weight_decay (float): Weight decay for latent optimization.
+            seed (int | None): Random seed for reproducibility.
+            cache (dict | None): Optional cache for latent vectors.
+            cache_key (str | None): Key for storing/retrieving from cache.
+        """
+        if seed is None:
+            seed = np.random.randint(0, 999_999)
+
+        torch.manual_seed(seed)
+        np.random.seed(seed)
+
+        model.eval()
+        for p in model.parameters():
+            p.requires_grad_(False)
+
+        nF = getattr(model, "n_features", self.num_features_)
+        X_val = X_val.astype(np.int64, copy=False)
+        X_val[X_val < 0] = -1
+        y = torch.from_numpy(X_val).long().to(self.device)
+
+        zdim = self._first_linear_in_features(model)
+        schema_key = f"{self.prefix}_ubp_val_latents_z{zdim}_L{nF}_K{self.num_classes_}"
+
+        if cache is not None and schema_key in cache:
+            z = cache[schema_key].detach().clone().requires_grad_(True)
+        else:
+            z = self._create_latent_space(
+                {"latent_dim": zdim}, X_val.shape[0], X_val, self.latent_init
+            ).requires_grad_(True)
+
+        opt = torch.optim.AdamW([z], lr=lr, weight_decay=weight_decay, eps=1e-7)
+
+        # Validate the decoder once, outside the refinement loop.
+        decoder: torch.Tensor | torch.nn.Module = model.phase23_decoder
+        if not isinstance(decoder, torch.nn.Module):
+            msg = f"{self.model_name} Decoder is not a torch.nn.Module."
+            self.logger.error(msg)
+            raise TypeError(msg)
+
+        for _ in range(max(int(steps), 0)):
+            opt.zero_grad(set_to_none=True)
+
+            logits = decoder(z).view(X_val.shape[0], nF, self.num_classes_)
+            if not torch.isfinite(logits).all():
+                break
+
+            loss = F.cross_entropy(
+                logits.view(-1, self.num_classes_), y.view(-1), ignore_index=-1
+            )
+
+            if not torch.isfinite(loss):
+                break
+
+            loss.backward()
+
+            torch.nn.utils.clip_grad_norm_([z], 1.0)
+
+            if z.grad is None or not torch.isfinite(z.grad).all():
+                break
+
+            opt.step()
+
+        if cache is not None:
+            cache[schema_key] = z.detach().clone()
+
+        for p in model.parameters():
+            p.requires_grad_(True)
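The evaluation path freezes the network, refines only the validation latents, and caches them under a key encoding the latent width, locus count, and class count, so later calls with the same geometry can warm-start. A compact sketch of that freeze/refine/cache flow, using hypothetical helper names (`with_frozen` and `refine_latents` are not PG-SUI functions):

```python
# Sketch of the freeze -> refine latents -> cache pattern; `refine_latents` stands
# in for an optimization loop like the one sketched earlier in this diff.
import torch

def with_frozen(model: torch.nn.Module, fn):
    """Temporarily freeze all model weights while `fn` runs, then restore them."""
    for p in model.parameters():
        p.requires_grad_(False)
    try:
        return fn()
    finally:
        for p in model.parameters():
            p.requires_grad_(True)

latent_cache: dict[str, torch.Tensor] = {}

def eval_latents(model, zdim: int, n_loci: int, n_classes: int, refine_latents):
    # Schema-aware key: reuse cached latents only when the geometry matches.
    key = f"val_latents_z{zdim}_L{n_loci}_K{n_classes}"
    z0 = latent_cache.get(key)
    z = with_frozen(model, lambda: refine_latents(warm_start=z0))
    latent_cache[key] = z.detach().clone()
    return z

if __name__ == "__main__":
    net = torch.nn.Linear(4, 4)

    def dummy_refine(warm_start=None):
        return warm_start if warm_start is not None else torch.zeros(2, 4)

    z1 = eval_latents(net, zdim=4, n_loci=2, n_classes=2, refine_latents=dummy_refine)
    z2 = eval_latents(net, zdim=4, n_loci=2, n_classes=2, refine_latents=dummy_refine)
    print(torch.equal(z1, z2))  # True: second call warm-started from the cache
```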