pg-sui 0.2.3__py3-none-any.whl → 1.6.16a3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (128)
  1. pg_sui-1.6.16a3.dist-info/METADATA +292 -0
  2. pg_sui-1.6.16a3.dist-info/RECORD +81 -0
  3. {pg_sui-0.2.3.dist-info → pg_sui-1.6.16a3.dist-info}/WHEEL +1 -1
  4. pg_sui-1.6.16a3.dist-info/entry_points.txt +4 -0
  5. {pg_sui-0.2.3.dist-info → pg_sui-1.6.16a3.dist-info/licenses}/LICENSE +0 -0
  6. pg_sui-1.6.16a3.dist-info/top_level.txt +1 -0
  7. pgsui/__init__.py +35 -54
  8. pgsui/_version.py +34 -0
  9. pgsui/cli.py +922 -0
  10. pgsui/data_processing/__init__.py +0 -0
  11. pgsui/data_processing/config.py +565 -0
  12. pgsui/data_processing/containers.py +1436 -0
  13. pgsui/data_processing/transformers.py +557 -907
  14. pgsui/{example_data/trees → electron/app}/__init__.py +0 -0
  15. pgsui/electron/app/__main__.py +5 -0
  16. pgsui/electron/app/extra-resources/.gitkeep +1 -0
  17. pgsui/electron/app/icons/icons/1024x1024.png +0 -0
  18. pgsui/electron/app/icons/icons/128x128.png +0 -0
  19. pgsui/electron/app/icons/icons/16x16.png +0 -0
  20. pgsui/electron/app/icons/icons/24x24.png +0 -0
  21. pgsui/electron/app/icons/icons/256x256.png +0 -0
  22. pgsui/electron/app/icons/icons/32x32.png +0 -0
  23. pgsui/electron/app/icons/icons/48x48.png +0 -0
  24. pgsui/electron/app/icons/icons/512x512.png +0 -0
  25. pgsui/electron/app/icons/icons/64x64.png +0 -0
  26. pgsui/electron/app/icons/icons/icon.icns +0 -0
  27. pgsui/electron/app/icons/icons/icon.ico +0 -0
  28. pgsui/electron/app/main.js +227 -0
  29. pgsui/electron/app/package-lock.json +6894 -0
  30. pgsui/electron/app/package.json +51 -0
  31. pgsui/electron/app/preload.js +15 -0
  32. pgsui/electron/app/server.py +157 -0
  33. pgsui/electron/app/ui/logo.png +0 -0
  34. pgsui/electron/app/ui/renderer.js +131 -0
  35. pgsui/electron/app/ui/styles.css +59 -0
  36. pgsui/electron/app/ui/ui_shim.js +72 -0
  37. pgsui/electron/bootstrap.py +43 -0
  38. pgsui/electron/launch.py +57 -0
  39. pgsui/electron/package.json +14 -0
  40. pgsui/example_data/__init__.py +0 -0
  41. pgsui/example_data/phylip_files/__init__.py +0 -0
  42. pgsui/example_data/phylip_files/test.phy +0 -0
  43. pgsui/example_data/popmaps/__init__.py +0 -0
  44. pgsui/example_data/popmaps/{test.popmap → phylogen_nomx.popmap} +185 -99
  45. pgsui/example_data/structure_files/__init__.py +0 -0
  46. pgsui/example_data/structure_files/test.pops.2row.allsites.str +0 -0
  47. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz +0 -0
  48. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz.tbi +0 -0
  49. pgsui/impute/__init__.py +0 -0
  50. pgsui/impute/deterministic/imputers/allele_freq.py +725 -0
  51. pgsui/impute/deterministic/imputers/mode.py +844 -0
  52. pgsui/impute/deterministic/imputers/nmf.py +221 -0
  53. pgsui/impute/deterministic/imputers/phylo.py +973 -0
  54. pgsui/impute/deterministic/imputers/ref_allele.py +669 -0
  55. pgsui/impute/supervised/__init__.py +0 -0
  56. pgsui/impute/supervised/base.py +343 -0
  57. pgsui/impute/{unsupervised/models/in_development → supervised/imputers}/__init__.py +0 -0
  58. pgsui/impute/supervised/imputers/hist_gradient_boosting.py +317 -0
  59. pgsui/impute/supervised/imputers/random_forest.py +291 -0
  60. pgsui/impute/unsupervised/__init__.py +0 -0
  61. pgsui/impute/unsupervised/base.py +1121 -0
  62. pgsui/impute/unsupervised/callbacks.py +92 -262
  63. {simulation → pgsui/impute/unsupervised/imputers}/__init__.py +0 -0
  64. pgsui/impute/unsupervised/imputers/autoencoder.py +1361 -0
  65. pgsui/impute/unsupervised/imputers/nlpca.py +1666 -0
  66. pgsui/impute/unsupervised/imputers/ubp.py +1660 -0
  67. pgsui/impute/unsupervised/imputers/vae.py +1316 -0
  68. pgsui/impute/unsupervised/loss_functions.py +261 -0
  69. pgsui/impute/unsupervised/models/__init__.py +0 -0
  70. pgsui/impute/unsupervised/models/autoencoder_model.py +215 -567
  71. pgsui/impute/unsupervised/models/nlpca_model.py +155 -394
  72. pgsui/impute/unsupervised/models/ubp_model.py +180 -1106
  73. pgsui/impute/unsupervised/models/vae_model.py +269 -630
  74. pgsui/impute/unsupervised/nn_scorers.py +255 -0
  75. pgsui/utils/__init__.py +0 -0
  76. pgsui/utils/classification_viz.py +608 -0
  77. pgsui/utils/logging_utils.py +22 -0
  78. pgsui/utils/misc.py +35 -480
  79. pgsui/utils/plotting.py +996 -829
  80. pgsui/utils/pretty_metrics.py +290 -0
  81. pgsui/utils/scorers.py +213 -666
  82. pg_sui-0.2.3.dist-info/METADATA +0 -322
  83. pg_sui-0.2.3.dist-info/RECORD +0 -75
  84. pg_sui-0.2.3.dist-info/top_level.txt +0 -3
  85. pgsui/example_data/phylip_files/test_n10.phy +0 -118
  86. pgsui/example_data/phylip_files/test_n100.phy +0 -118
  87. pgsui/example_data/phylip_files/test_n2.phy +0 -118
  88. pgsui/example_data/phylip_files/test_n500.phy +0 -118
  89. pgsui/example_data/structure_files/test.nopops.1row.10sites.str +0 -117
  90. pgsui/example_data/structure_files/test.nopops.2row.100sites.str +0 -234
  91. pgsui/example_data/structure_files/test.nopops.2row.10sites.str +0 -234
  92. pgsui/example_data/structure_files/test.nopops.2row.30sites.str +0 -234
  93. pgsui/example_data/structure_files/test.nopops.2row.allsites.str +0 -234
  94. pgsui/example_data/structure_files/test.pops.1row.10sites.str +0 -117
  95. pgsui/example_data/structure_files/test.pops.2row.10sites.str +0 -234
  96. pgsui/example_data/trees/test.iqtree +0 -376
  97. pgsui/example_data/trees/test.qmat +0 -5
  98. pgsui/example_data/trees/test.rate +0 -2033
  99. pgsui/example_data/trees/test.tre +0 -1
  100. pgsui/example_data/trees/test_n10.rate +0 -19
  101. pgsui/example_data/trees/test_n100.rate +0 -109
  102. pgsui/example_data/trees/test_n500.rate +0 -509
  103. pgsui/example_data/trees/test_siterates.txt +0 -2024
  104. pgsui/example_data/trees/test_siterates_n10.txt +0 -10
  105. pgsui/example_data/trees/test_siterates_n100.txt +0 -100
  106. pgsui/example_data/trees/test_siterates_n500.txt +0 -500
  107. pgsui/example_data/vcf_files/test.vcf +0 -244
  108. pgsui/example_data/vcf_files/test.vcf.gz +0 -0
  109. pgsui/example_data/vcf_files/test.vcf.gz.tbi +0 -0
  110. pgsui/impute/estimators.py +0 -1268
  111. pgsui/impute/impute.py +0 -1463
  112. pgsui/impute/simple_imputers.py +0 -1431
  113. pgsui/impute/supervised/iterative_imputer_fixedparams.py +0 -782
  114. pgsui/impute/supervised/iterative_imputer_gridsearch.py +0 -1024
  115. pgsui/impute/unsupervised/keras_classifiers.py +0 -697
  116. pgsui/impute/unsupervised/models/in_development/cnn_model.py +0 -486
  117. pgsui/impute/unsupervised/neural_network_imputers.py +0 -1440
  118. pgsui/impute/unsupervised/neural_network_methods.py +0 -1395
  119. pgsui/pg_sui.py +0 -261
  120. pgsui/utils/sequence_tools.py +0 -407
  121. simulation/sim_benchmarks.py +0 -333
  122. simulation/sim_treeparams.py +0 -475
  123. test/__init__.py +0 -0
  124. test/pg_sui_simtest.py +0 -215
  125. test/pg_sui_testing.py +0 -523
  126. test/test.py +0 -151
  127. test/test_pgsui.py +0 -374
  128. test/test_tkc.py +0 -185
pgsui/impute/unsupervised/imputers/autoencoder.py (new file)
@@ -0,0 +1,1361 @@
1
+ import copy
2
+ from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union
3
+
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ import optuna
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from sklearn.exceptions import NotFittedError
10
+ from sklearn.model_selection import train_test_split
11
+ from snpio.analysis.genotype_encoder import GenotypeEncoder
12
+ from snpio.utils.logging import LoggerManager
13
+ from torch.optim.lr_scheduler import CosineAnnealingLR
14
+
15
+ from pgsui.data_processing.config import apply_dot_overrides, load_yaml_to_dataclass
16
+ from pgsui.data_processing.containers import AutoencoderConfig
17
+ from pgsui.data_processing.transformers import SimMissingTransformer
18
+ from pgsui.impute.unsupervised.base import BaseNNImputer
19
+ from pgsui.impute.unsupervised.callbacks import EarlyStopping
20
+ from pgsui.impute.unsupervised.loss_functions import SafeFocalCELoss
21
+ from pgsui.impute.unsupervised.models.autoencoder_model import AutoencoderModel
22
+ from pgsui.utils.logging_utils import configure_logger
23
+ from pgsui.utils.pretty_metrics import PrettyMetrics
24
+
25
+ if TYPE_CHECKING:
26
+ from snpio import TreeParser
27
+ from snpio.read_input.genotype_data import GenotypeData
28
+
29
+
30
+ def ensure_autoencoder_config(
31
+ config: AutoencoderConfig | dict | str | None,
32
+ ) -> AutoencoderConfig:
33
+ """Return a concrete AutoencoderConfig from dataclass, dict, YAML path, or None.
34
+
35
+ This function normalizes the configuration input for the Autoencoder imputer. It accepts a structured configuration as a dataclass instance, a nested dictionary, a YAML file path, or None, and returns a concrete AutoencoderConfig with all necessary fields populated.
36
+
37
+ Args:
38
+ config (AutoencoderConfig | dict | str | None): Structured configuration as dataclass, nested dict, YAML path, or None.
39
+
40
+ Returns:
41
+ AutoencoderConfig: Concrete configuration instance.
42
+ """
43
+ if config is None:
44
+ return AutoencoderConfig()
45
+ if isinstance(config, AutoencoderConfig):
46
+ return config
47
+ if isinstance(config, str):
48
+ # YAML path — top-level `preset` key is supported
49
+ return load_yaml_to_dataclass(config, AutoencoderConfig)
50
+ if isinstance(config, dict):
51
+ # Flatten dict into dot-keys then overlay onto a fresh instance
52
+ base = AutoencoderConfig()
53
+
54
+ def _flatten(prefix: str, d: dict, out: dict) -> dict:
55
+ for k, v in d.items():
56
+ kk = f"{prefix}.{k}" if prefix else k
57
+ if isinstance(v, dict):
58
+ _flatten(kk, v, out)
59
+ else:
60
+ out[kk] = v
61
+ return out
62
+
63
+ # Lift any present preset first
64
+ preset_name = config.pop("preset", None)
65
+ if "io" in config and isinstance(config["io"], dict):
66
+ preset_name = preset_name or config["io"].pop("preset", None)
67
+
68
+ if preset_name:
69
+ base = AutoencoderConfig.from_preset(preset_name)
70
+
71
+ flat = _flatten("", config, {})
72
+ return apply_dot_overrides(base, flat)
73
+
74
+ raise TypeError("config must be an AutoencoderConfig, dict, YAML path, or None.")
75
+
76
+
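A minimal usage sketch for the helper above (not part of the package diff): the preset name and YAML filename are placeholders, while the dot-keys mirror fields referenced elsewhere in this file (model.latent_dim, train.learning_rate).

# Usage sketch only; preset name and YAML path are hypothetical placeholders.
from pgsui.impute.unsupervised.imputers.autoencoder import ensure_autoencoder_config

cfg = ensure_autoencoder_config(None)                       # all defaults
cfg = ensure_autoencoder_config("autoencoder_config.yaml")  # YAML path; a top-level `preset` key is honored
cfg = ensure_autoencoder_config(
    {
        "preset": "balanced",              # hypothetical preset; lifted before flattening
        "model": {"latent_dim": 8},        # flattened to the dot-key "model.latent_dim"
        "train": {"learning_rate": 1e-3},  # flattened to "train.learning_rate"
    }
)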
77
+ class ImputeAutoencoder(BaseNNImputer):
78
+ """Impute missing genotypes with a standard Autoencoder on 0/1/2 encodings.
79
+
80
+ This imputer uses a feedforward autoencoder to learn compressed representations of genotype data encoded as 0 (homozygous reference), 1 (heterozygous), and 2 (homozygous alternate), from which missing entries are reconstructed. Missing genotypes are represented as -1 during training and imputation.
81
+
82
+ The model is trained to minimize a focal cross-entropy loss, which helps to address class imbalance by focusing more on hard-to-classify examples. The architecture includes configurable parameters such as the number of hidden layers, latent dimension size, dropout rate, and activation functions.
83
+ """
84
+
85
+ def __init__(
86
+ self,
87
+ genotype_data: "GenotypeData",
88
+ *,
89
+ tree_parser: Optional["TreeParser"] = None,
90
+ config: Optional[Union["AutoencoderConfig", dict, str]] = None,
91
+ overrides: dict | None = None,
92
+ simulate_missing: bool | None = None,
93
+ sim_strategy: (
94
+ Literal[
95
+ "random",
96
+ "random_weighted",
97
+ "random_weighted_inv",
98
+ "nonrandom",
99
+ "nonrandom_weighted",
100
+ ]
101
+ | None
102
+ ) = None,
103
+ sim_prop: float | None = None,
104
+ sim_kwargs: dict | None = None,
105
+ ) -> None:
106
+ """Initialize the Autoencoder imputer with a unified config interface.
107
+
108
+ This initializer sets up the Autoencoder imputer by processing the provided configuration, initializing logging, and preparing the model and data encoder. It supports configuration input as a dataclass, nested dictionary, YAML file path, or None, with optional dot-key overrides for fine-tuning specific parameters.
109
+
110
+ Args:
111
+ genotype_data ("GenotypeData"): Backing genotype data object.
112
+ tree_parser (Optional["TreeParser"]): Optional SNPio phylogenetic tree parser for population-specific modes.
113
+ config (Union["AutoencoderConfig", dict, str] | None): Structured configuration as dataclass, nested dict, YAML path, or None.
114
+ overrides (dict | None): Optional dot-key overrides with highest precedence (e.g., {'model.latent_dim': 32}).
115
+ simulate_missing (bool | None): Whether to simulate missing data during evaluation. If None, uses config default.
116
+ sim_strategy (Literal["random", "random_weighted", "random_weighted_inv", "nonrandom", "nonrandom_weighted"] | None): Strategy for simulating missing data. If None, uses config default.
117
+ sim_prop (float | None): Proportion of data to simulate as missing. If None, uses config default.
118
+ sim_kwargs (dict | None): Additional keyword arguments for simulating missing data. If None, uses config default.
119
+ """
120
+ self.model_name = "ImputeAutoencoder"
121
+ self.genotype_data = genotype_data
122
+ self.tree_parser = tree_parser
123
+
124
+ # Normalize config then apply highest-precedence overrides
125
+ cfg = ensure_autoencoder_config(config)
126
+ if overrides:
127
+ cfg = apply_dot_overrides(cfg, overrides)
128
+ self.cfg = cfg
129
+
130
+ # Logger consistent with NLPCA
131
+ logman = LoggerManager(
132
+ __name__,
133
+ prefix=self.cfg.io.prefix,
134
+ debug=self.cfg.io.debug,
135
+ verbose=self.cfg.io.verbose,
136
+ )
137
+ self.logger = configure_logger(
138
+ logman.get_logger(),
139
+ verbose=self.cfg.io.verbose,
140
+ debug=self.cfg.io.debug,
141
+ )
142
+
143
+ # BaseNNImputer bootstrapping (device/dirs/logging handled here)
144
+ super().__init__(
145
+ model_name=self.model_name,
146
+ genotype_data=self.genotype_data,
147
+ prefix=self.cfg.io.prefix,
148
+ device=self.cfg.train.device,
149
+ verbose=self.cfg.io.verbose,
150
+ debug=self.cfg.io.debug,
151
+ )
152
+
153
+ self.Model = AutoencoderModel
154
+
155
+ # Model hook & encoder
156
+ self.pgenc = GenotypeEncoder(genotype_data)
157
+
158
+ # IO / global
159
+ self.seed = self.cfg.io.seed
160
+ self.n_jobs = self.cfg.io.n_jobs
161
+ self.prefix = self.cfg.io.prefix
162
+ self.scoring_averaging = self.cfg.io.scoring_averaging
163
+ self.verbose = self.cfg.io.verbose
164
+ self.debug = self.cfg.io.debug
165
+ self.rng = np.random.default_rng(self.seed)
166
+ self.pos_weights_: torch.Tensor | None = None
167
+
168
+ # Simulated-missing controls (config defaults with ctor overrides)
169
+ sim_cfg = getattr(self.cfg, "sim", None)
170
+ sim_cfg_kwargs = copy.deepcopy(getattr(sim_cfg, "sim_kwargs", None) or {})
171
+ if sim_kwargs:
172
+ sim_cfg_kwargs.update(sim_kwargs)
173
+ self.simulate_missing = (
174
+ (
175
+ sim_cfg.simulate_missing
176
+ if simulate_missing is None
177
+ else bool(simulate_missing)
178
+ )
179
+ if sim_cfg is not None
180
+ else bool(simulate_missing)
181
+ )
182
+ if sim_cfg is None:
183
+ default_strategy = "random"
184
+ default_prop = 0.10
185
+ else:
186
+ default_strategy = sim_cfg.sim_strategy
187
+ default_prop = sim_cfg.sim_prop
188
+ self.sim_strategy = sim_strategy or default_strategy
189
+ self.sim_prop = float(sim_prop if sim_prop is not None else default_prop)
190
+ self.sim_kwargs = sim_cfg_kwargs
191
+
192
+ if self.tree_parser is None and self.sim_strategy.startswith("nonrandom"):
193
+ msg = "tree_parser is required for nonrandom and nonrandom_weighted simulated missing strategies."
194
+ self.logger.error(msg)
195
+ raise ValueError(msg)
196
+
197
+ # Model hyperparams
198
+ self.latent_dim = int(self.cfg.model.latent_dim)
199
+ self.dropout_rate = float(self.cfg.model.dropout_rate)
200
+ self.num_hidden_layers = int(self.cfg.model.num_hidden_layers)
201
+ self.layer_scaling_factor = float(self.cfg.model.layer_scaling_factor)
202
+ self.layer_schedule: str = str(self.cfg.model.layer_schedule)
203
+ self.activation = str(self.cfg.model.hidden_activation)
204
+ self.gamma = float(self.cfg.model.gamma)
205
+
206
+ # Train hyperparams
207
+ self.batch_size = int(self.cfg.train.batch_size)
208
+ self.learning_rate = float(self.cfg.train.learning_rate)
209
+ self.l1_penalty: float = float(self.cfg.train.l1_penalty)
210
+ self.early_stop_gen = int(self.cfg.train.early_stop_gen)
211
+ self.min_epochs = int(self.cfg.train.min_epochs)
212
+ self.epochs = int(self.cfg.train.max_epochs)
213
+ self.validation_split = float(self.cfg.train.validation_split)
214
+ self.beta = float(self.cfg.train.weights_beta)
215
+ self.max_ratio = float(self.cfg.train.weights_max_ratio)
216
+
217
+ # Tuning
218
+ self.tune = bool(self.cfg.tune.enabled)
219
+ self.tune_fast = bool(self.cfg.tune.fast)
220
+ self.tune_batch_size = int(self.cfg.tune.batch_size)
221
+ self.tune_epochs = int(self.cfg.tune.epochs)
222
+ self.tune_eval_interval = int(self.cfg.tune.eval_interval)
223
+ self.tune_metric: str = self.cfg.tune.metric
224
+
225
+ if self.tune_metric is not None:
226
+ self.tune_metric_: (
227
+ Literal[
228
+ "pr_macro",
229
+ "f1",
230
+ "accuracy",
231
+ "precision",
232
+ "recall",
233
+ "roc_auc",
234
+ "average_precision",
235
+ ]
236
+ | None
237
+ ) = self.cfg.tune.metric
238
+
239
+ self.n_trials = int(self.cfg.tune.n_trials)
240
+ self.tune_save_db = bool(self.cfg.tune.save_db)
241
+ self.tune_resume = bool(self.cfg.tune.resume)
242
+ self.tune_max_samples = int(self.cfg.tune.max_samples)
243
+ self.tune_max_loci = int(self.cfg.tune.max_loci)
244
+ self.tune_infer_epochs = int(
245
+ getattr(self.cfg.tune, "infer_epochs", 0)
246
+ ) # AE unused
247
+ self.tune_patience = int(self.cfg.tune.patience)
248
+
249
+ # Evaluate
250
+ # AE does not optimize latents, so these are unused / fixed
251
+ self.eval_latent_steps: int = 0
252
+ self.eval_latent_lr: float = 0.0
253
+ self.eval_latent_weight_decay: float = 0.0
254
+
255
+ # Plotting (parity with NLPCA PlotConfig)
256
+ self.plot_format: Literal["pdf", "png", "jpg", "jpeg", "svg"] = (
257
+ self.cfg.plot.fmt
258
+ )
259
+ self.plot_dpi = int(self.cfg.plot.dpi)
260
+ self.plot_fontsize = int(self.cfg.plot.fontsize)
261
+ self.title_fontsize = int(self.cfg.plot.fontsize)
262
+ self.despine = bool(self.cfg.plot.despine)
263
+ self.show_plots = bool(self.cfg.plot.show)
264
+
265
+ # Core derived at fit-time
266
+ self.is_haploid: bool = False
267
+ self.num_classes_: int | None = None
268
+ self.model_params: Dict[str, Any] = {}
269
+ self.sim_mask_global_: np.ndarray | None = None
270
+ self.sim_mask_train_: np.ndarray | None = None
271
+ self.sim_mask_test_: np.ndarray | None = None
272
+
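A hedged construction sketch for the initializer above; `genotype_data` is assumed to be an already-loaded SNPio GenotypeData object, and the values shown are illustrative.

# `genotype_data` is assumed to be an snpio GenotypeData instance loaded elsewhere.
imputer = ImputeAutoencoder(
    genotype_data,
    config={"model": {"latent_dim": 16}},  # nested dict, normalized by ensure_autoencoder_config
    overrides={"model.latent_dim": 32},    # dot-key overrides take highest precedence
    simulate_missing=True,                 # ctor args override the config's `sim` block
    sim_strategy="random",                 # "nonrandom*" strategies require tree_parser
    sim_prop=0.10,
)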
273
+ def fit(self) -> "ImputeAutoencoder":
274
+ """Fit the autoencoder on 0/1/2 encoded genotypes (missing -> -1).
275
+
276
+ This method trains the autoencoder model using the provided genotype data. It prepares the data by encoding genotypes as 0, 1, and 2, with missing values represented internally as -1. (When simulated-missing loci are generated via ``SimMissingTransformer`` they are first marked with -9 but are immediately re-encoded as -1 prior to training.) The method splits the data into training and validation sets, initializes the model and training parameters, and performs training with optional hyperparameter tuning. After training, it evaluates the model on the validation set and stores the fitted model and training history.
277
+
278
+ Returns:
279
+ ImputeAutoencoder: Fitted instance.
280
+
281
+ Raises:
282
+ RuntimeError: If training fails and no model is returned.
283
+ """
284
+ self.logger.info(f"Fitting {self.model_name} model...")
285
+
286
+ # --- Data prep (mirror NLPCA) ---
287
+ X012 = self._get_float_genotypes(copy=True)
288
+ GT_full = np.nan_to_num(X012, nan=-1.0, copy=True)
289
+ self.ground_truth_ = GT_full.astype(np.int64, copy=False)
290
+
291
+ self.sim_mask_global_ = None
292
+ cache_key = self._sim_mask_cache_key()
293
+ if self.simulate_missing:
294
+ cached_mask = (
295
+ None if cache_key is None else self._sim_mask_cache.get(cache_key)
296
+ )
297
+ if cached_mask is not None:
298
+ self.sim_mask_global_ = cached_mask.copy()
299
+ else:
300
+ tr = SimMissingTransformer(
301
+ genotype_data=self.genotype_data,
302
+ tree_parser=self.tree_parser,
303
+ prop_missing=self.sim_prop,
304
+ strategy=self.sim_strategy,
305
+ missing_val=-9,
306
+ mask_missing=True,
307
+ verbose=self.verbose,
308
+ **self.sim_kwargs,
309
+ )
310
+ tr.fit(X012.copy())
311
+ self.sim_mask_global_ = tr.sim_missing_mask_.astype(bool)
312
+ if cache_key is not None:
313
+ self._sim_mask_cache[cache_key] = self.sim_mask_global_.copy()
314
+
315
+ X_for_model = self.ground_truth_.copy()
316
+ X_for_model[self.sim_mask_global_] = -1
317
+ else:
318
+ X_for_model = self.ground_truth_.copy()
319
+
320
+ if self.genotype_data.snp_data is None:
321
+ msg = "SNP data is required for Autoencoder imputer."
322
+ self.logger.error(msg)
323
+ raise TypeError(msg)
324
+
325
+ # Ploidy & classes
326
+ self.is_haploid = bool(
327
+ np.all(
328
+ np.isin(
329
+ self.genotype_data.snp_data,
330
+ ["A", "C", "G", "T", "N", "-", ".", "?"],
331
+ )
332
+ )
333
+ )
334
+ self.ploidy = 1 if self.is_haploid else 2
335
+ # Scoring still uses 3 labels for diploid (REF/HET/ALT); model head uses 2 logits
336
+ self.num_classes_ = 2 if self.is_haploid else 3
337
+ self.output_classes_ = 2
338
+ self.logger.info(
339
+ f"Data is {'haploid' if self.is_haploid else 'diploid'}; "
340
+ f"using {self.num_classes_} classes for scoring and {self.output_classes_} output channels."
341
+ )
342
+
343
+ if self.is_haploid:
344
+ self.ground_truth_[self.ground_truth_ == 2] = 1
345
+ X_for_model[X_for_model == 2] = 1
346
+
347
+ n_samples, self.num_features_ = X_for_model.shape
348
+
349
+ # Model params (decoder outputs L * K logits)
350
+ self.model_params = {
351
+ "n_features": self.num_features_,
352
+ "num_classes": self.output_classes_,
353
+ "latent_dim": self.latent_dim,
354
+ "dropout_rate": self.dropout_rate,
355
+ "activation": self.activation,
356
+ }
357
+
358
+ # Train/Val split
359
+ indices = np.arange(n_samples)
360
+ train_idx, val_idx = train_test_split(
361
+ indices, test_size=self.validation_split, random_state=self.seed
362
+ )
363
+ self.train_idx_, self.test_idx_ = train_idx, val_idx
364
+ self.X_train_ = X_for_model[train_idx]
365
+ self.X_val_ = X_for_model[val_idx]
366
+ self.GT_train_full_ = self.ground_truth_[train_idx]
367
+ self.GT_test_full_ = self.ground_truth_[val_idx]
368
+
369
+ if self.sim_mask_global_ is not None:
370
+ self.sim_mask_train_ = self.sim_mask_global_[train_idx]
371
+ self.sim_mask_test_ = self.sim_mask_global_[val_idx]
372
+ else:
373
+ self.sim_mask_train_ = None
374
+ self.sim_mask_test_ = None
375
+
376
+ # Pos weights for diploid multilabel path (must exist before tuning)
377
+ if not self.is_haploid:
378
+ self.pos_weights_ = self._compute_pos_weights(self.X_train_)
379
+ else:
380
+ self.pos_weights_ = None
381
+
382
+ # Plotters/scorers (shared utilities)
383
+ self.plotter_, self.scorers_ = self.initialize_plotting_and_scorers()
384
+
385
+ # Tuning (optional; AE never needs latent refinement)
386
+ if self.tune:
387
+ self.tune_hyperparameters()
388
+
389
+ # Best params (tuned or default)
390
+ self.best_params_ = getattr(self, "best_params_", self._default_best_params())
391
+
392
+ # Class weights (device-aware)
393
+ self.class_weights_ = self._normalize_class_weights(
394
+ self._class_weights_from_zygosity(self.X_train_)
395
+ )
396
+
397
+ # DataLoader
398
+ train_loader = self._get_data_loaders(self.X_train_)
399
+
400
+ # Build & train
401
+ model = self.build_model(self.Model, self.best_params_)
402
+ model.apply(self.initialize_weights)
403
+
404
+ loss, trained_model, history = self._train_and_validate_model(
405
+ model=model,
406
+ loader=train_loader,
407
+ lr=self.learning_rate,
408
+ l1_penalty=self.l1_penalty,
409
+ return_history=True,
410
+ class_weights=self.class_weights_,
411
+ X_val=self.X_val_,
412
+ params=self.best_params_,
413
+ prune_metric=self.tune_metric,
414
+ prune_warmup_epochs=10,
415
+ eval_interval=1,
416
+ eval_requires_latents=False,
417
+ eval_latent_steps=0,
418
+ eval_latent_lr=0.0,
419
+ eval_latent_weight_decay=0.0,
420
+ )
421
+
422
+ if trained_model is None:
423
+ msg = "Autoencoder training failed; no model was returned."
424
+ self.logger.error(msg)
425
+ raise RuntimeError(msg)
426
+
427
+ torch.save(
428
+ trained_model.state_dict(),
429
+ self.models_dir / f"final_model_{self.model_name}.pt",
430
+ )
431
+
432
+ hist: Dict[str, List[float] | Dict[str, List[float]] | None] | None = {
433
+ "Train": history
434
+ }
435
+ self.best_loss_, self.model_, self.history_ = (loss, trained_model, hist)
436
+ self.is_fit_ = True
437
+
438
+ # Evaluate on validation set (parity with NLPCA reporting)
439
+ eval_mask = (
440
+ self.sim_mask_test_
441
+ if (self.simulate_missing and self.sim_mask_test_ is not None)
442
+ else None
443
+ )
444
+ self._evaluate_model(
445
+ self.X_val_, self.model_, self.best_params_, eval_mask_override=eval_mask
446
+ )
447
+ self.plotter_.plot_history(self.history_)
448
+ self._save_best_params(self.best_params_)
449
+
450
+ return self
451
+
452
+ def transform(self) -> np.ndarray:
453
+ """Impute missing genotypes (0/1/2) and return IUPAC strings.
454
+
455
+ This method imputes missing genotypes in the dataset using the trained autoencoder model. It predicts the most likely genotype (0, 1, or 2) for each missing entry and fills in these values. The imputed genotypes are then decoded back to IUPAC string format for easier interpretation.
456
+
457
+ Returns:
458
+ np.ndarray: IUPAC strings of shape (n_samples, n_loci).
459
+
460
+ Raises:
461
+ NotFittedError: If called before fit().
462
+ """
463
+ if not getattr(self, "is_fit_", False):
464
+ raise NotFittedError("Model is not fitted. Call fit() before transform().")
465
+
466
+ self.logger.info(f"Imputing entire dataset with {self.model_name}...")
467
+ X_to_impute = self.ground_truth_.copy()
468
+
469
+ # Predict with masked inputs (no latent optimization)
470
+ pred_labels, _ = self._predict(self.model_, X=X_to_impute, return_proba=True)
471
+
472
+ # Fill only missing
473
+ missing_mask = X_to_impute == -1
474
+ imputed_array = X_to_impute.copy()
475
+ imputed_array[missing_mask] = pred_labels[missing_mask]
476
+
477
+ # Decode to IUPAC & optionally plot
478
+ imputed_genotypes = self.pgenc.decode_012(imputed_array)
479
+ if self.show_plots:
480
+ original_genotypes = self.pgenc.decode_012(X_to_impute)
481
+ plt.rcParams.update(self.plotter_.param_dict)
482
+ self.plotter_.plot_gt_distribution(original_genotypes, is_imputed=False)
483
+ self.plotter_.plot_gt_distribution(imputed_genotypes, is_imputed=True)
484
+
485
+ return imputed_genotypes
486
+
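A short end-to-end sketch of the public API defined above (fit, then transform back to IUPAC strings); the print call is illustrative only.

imputer = ImputeAutoencoder(genotype_data)  # as constructed in the sketch further above
imputer.fit()                               # trains the AE; RuntimeError if training fails
iupac = imputer.transform()                 # (n_samples, n_loci) array of IUPAC strings
print(iupac.shape)                          # only originally-missing entries were replaced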
487
+ def _get_data_loaders(self, y: np.ndarray) -> torch.utils.data.DataLoader:
488
+ """Create DataLoader over indices + integer targets (-1 for missing).
489
+
490
+ This method creates a PyTorch DataLoader that yields batches of indices and their corresponding genotype targets encoded as integers (0, 1, 2) with -1 indicating missing values. The DataLoader is shuffled to ensure random sampling during training.
491
+
492
+ Args:
493
+ y (np.ndarray): 0/1/2 matrix with -1 for missing.
494
+
495
+ Returns:
496
+ torch.utils.data.DataLoader: Shuffled DataLoader.
497
+ """
498
+ y_tensor = torch.from_numpy(y).long()
499
+ indices = torch.arange(len(y), dtype=torch.long)
500
+ dataset = torch.utils.data.TensorDataset(indices, y_tensor)
501
+ pin_memory = self.device.type == "cuda"
502
+ return torch.utils.data.DataLoader(
503
+ dataset,
504
+ batch_size=self.batch_size,
505
+ shuffle=True,
506
+ pin_memory=pin_memory,
507
+ )
508
+
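What a batch from the loader above looks like, shown with a toy 0/1/2 matrix using -1 for missing values.

import numpy as np
import torch

y = np.array([[0, 1, -1], [2, -1, 0]], dtype=np.int64)
ds = torch.utils.data.TensorDataset(torch.arange(len(y)), torch.from_numpy(y))
loader = torch.utils.data.DataLoader(ds, batch_size=2, shuffle=True)
idx, y_batch = next(iter(loader))
print(idx.shape, y_batch.shape)  # torch.Size([2]) torch.Size([2, 3])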
509
+ def _train_and_validate_model(
510
+ self,
511
+ model: torch.nn.Module,
512
+ loader: torch.utils.data.DataLoader,
513
+ lr: float,
514
+ l1_penalty: float,
515
+ trial: optuna.Trial | None = None,
516
+ return_history: bool = False,
517
+ class_weights: torch.Tensor | None = None,
518
+ *,
519
+ X_val: np.ndarray | None = None,
520
+ params: dict | None = None,
521
+ prune_metric: str = "f1", # "f1" | "accuracy" | "pr_macro"
522
+ prune_warmup_epochs: int = 10,
523
+ eval_interval: int = 1,
524
+ # Evaluation parameters (AE ignores latent refinement knobs)
525
+ eval_requires_latents: bool = False, # AE: always False
526
+ eval_latent_steps: int = 0,
527
+ eval_latent_lr: float = 0.0,
528
+ eval_latent_weight_decay: float = 0.0,
529
+ ) -> Tuple[float, torch.nn.Module | None, list | None]:
530
+ """Wrap the AE training loop (no latent optimizer), with Optuna pruning.
531
+
532
+ This method orchestrates the training of the autoencoder model using the provided DataLoader. It sets up the optimizer and learning rate scheduler, and executes the training loop with support for early stopping and Optuna pruning based on validation performance. The method returns the best validation loss, the best model state, and optionally the training history.
533
+
534
+ Args:
535
+ model (torch.nn.Module): Autoencoder model.
536
+ loader (torch.utils.data.DataLoader): Batches (indices, y_int) where y_int is 0/1/2; -1 for missing.
537
+ lr (float): Learning rate.
538
+ l1_penalty (float): L1 regularization coeff.
539
+ trial (optuna.Trial | None): Optuna trial for pruning (optional).
540
+ return_history (bool): If True, return train loss history.
541
+ class_weights (torch.Tensor | None): Class weights tensor (on device).
542
+ X_val (np.ndarray | None): Validation matrix (0/1/2 with -1 for missing).
543
+ params (dict | None): Model params for evaluation.
544
+ prune_metric (str): Metric for pruning reports.
545
+ prune_warmup_epochs (int): Pruning warmup epochs.
546
+ eval_interval (int): Eval frequency (epochs).
547
+ eval_requires_latents (bool): Ignored for AE (no latent inference).
548
+ eval_latent_steps (int): Unused for AE.
549
+ eval_latent_lr (float): Unused for AE.
550
+ eval_latent_weight_decay (float): Unused for AE.
551
+
552
+ Returns:
553
+ Tuple[float, torch.nn.Module | None, list | None]: (best_loss, best_model, history or None).
554
+ """
555
+ if class_weights is None:
556
+ msg = "Must provide class_weights."
557
+ self.logger.error(msg)
558
+ raise TypeError(msg)
559
+
560
+ # Epoch budget mirrors NLPCA config (tuning vs final)
561
+ max_epochs = (
562
+ self.tune_epochs if (trial is not None and self.tune_fast) else self.epochs
563
+ )
564
+
565
+ optimizer = torch.optim.Adam(model.parameters(), lr=lr)
566
+ scheduler = CosineAnnealingLR(optimizer, T_max=max_epochs)
567
+
568
+ best_loss, best_model, hist = self._execute_training_loop(
569
+ loader=loader,
570
+ optimizer=optimizer,
571
+ scheduler=scheduler,
572
+ model=model,
573
+ l1_penalty=l1_penalty,
574
+ trial=trial,
575
+ return_history=return_history,
576
+ class_weights=class_weights,
577
+ X_val=X_val,
578
+ params=params,
579
+ prune_metric=prune_metric,
580
+ prune_warmup_epochs=prune_warmup_epochs,
581
+ eval_interval=eval_interval,
582
+ eval_requires_latents=False, # AE: no latent inference
583
+ eval_latent_steps=0,
584
+ eval_latent_lr=0.0,
585
+ eval_latent_weight_decay=0.0,
586
+ )
587
+ if return_history:
588
+ return best_loss, best_model, hist
589
+
590
+ return best_loss, best_model, None
591
+
592
+ def _execute_training_loop(
593
+ self,
594
+ loader: torch.utils.data.DataLoader,
595
+ optimizer: torch.optim.Optimizer,
596
+ scheduler: CosineAnnealingLR,
597
+ model: torch.nn.Module,
598
+ l1_penalty: float,
599
+ trial: optuna.Trial | None,
600
+ return_history: bool,
601
+ class_weights: torch.Tensor,
602
+ *,
603
+ X_val: np.ndarray | None = None,
604
+ params: dict | None = None,
605
+ prune_metric: str = "f1",
606
+ prune_warmup_epochs: int = 10,
607
+ eval_interval: int = 1,
608
+ # Evaluation parameters (AE ignores latent refinement knobs)
609
+ eval_requires_latents: bool = False, # AE: False
610
+ eval_latent_steps: int = 0,
611
+ eval_latent_lr: float = 0.0,
612
+ eval_latent_weight_decay: float = 0.0,
613
+ ) -> Tuple[float, torch.nn.Module, list]:
614
+ """Train AE with focal CE (gamma warm/ramp) + early stopping & pruning.
615
+
616
+ This method executes the training loop for the autoencoder model, performing one epoch at a time. It computes the focal cross-entropy loss while ignoring masked (missing) values and applies L1 regularization if specified. The method incorporates early stopping based on validation performance and supports Optuna pruning to terminate unpromising trials early. It returns the best validation loss, the best model state, and optionally the training history.
617
+
618
+ Args:
619
+ loader (torch.utils.data.DataLoader): Batches (indices, y_int) where y_int is 0/1/2; -1 for missing.
620
+ optimizer (torch.optim.Optimizer): Optimizer.
621
+ scheduler (torch.optim.lr_scheduler._LRScheduler): LR scheduler.
622
+ model (torch.nn.Module): Autoencoder model.
623
+ l1_penalty (float): L1 regularization coeff.
624
+ trial (optuna.Trial | None): Optuna trial for pruning (optional).
625
+ return_history (bool): If True, return train loss history.
626
+ class_weights (torch.Tensor): Class weights tensor (on device).
627
+ X_val (np.ndarray | None): Validation matrix (0/1/2 with -1 for missing).
628
+ params (dict | None): Model params for evaluation.
629
+ prune_metric (str): Metric for pruning reports.
630
+ prune_warmup_epochs (int): Pruning warmup epochs.
631
+ eval_interval (int): Eval frequency (epochs).
632
+ eval_requires_latents (bool): Ignored for AE (no latent inference).
633
+ eval_latent_steps (int): Unused for AE.
634
+ eval_latent_lr (float): Unused for AE.
635
+ eval_latent_weight_decay (float): Unused for AE.
636
+
637
+ Returns:
638
+ Tuple[float, torch.nn.Module, list]: Best validation loss, best model, and training history.
639
+ """
640
+ best_loss = float("inf")
641
+ best_model = None
642
+ history: list[float] = []
643
+
644
+ early_stopping = EarlyStopping(
645
+ patience=self.early_stop_gen,
646
+ min_epochs=self.min_epochs,
647
+ verbose=self.verbose,
648
+ prefix=self.prefix,
649
+ debug=self.debug,
650
+ )
651
+
652
+ gamma_val = self.gamma
653
+ if isinstance(gamma_val, (list, tuple)):
654
+ if len(gamma_val) == 0:
655
+ raise ValueError("gamma list is empty.")
656
+ gamma_val = gamma_val[0]
657
+
658
+ gamma_final = float(gamma_val)
659
+ gamma_warm, gamma_ramp = 50, 100
660
+
661
+ # Optional LR warmup
662
+ warmup_epochs = int(getattr(self, "lr_warmup_epochs", 5))
663
+ base_lr = float(optimizer.param_groups[0]["lr"])
664
+ min_lr = base_lr * 0.1
665
+
666
+ max_epochs = int(getattr(scheduler, "T_max", getattr(self, "epochs", 100)))
667
+
668
+ for epoch in range(max_epochs):
669
+ # focal γ schedule (for stable training)
670
+ if epoch < gamma_warm:
671
+ model.gamma = 0.0 # type: ignore
672
+ elif epoch < gamma_warm + gamma_ramp:
673
+ model.gamma = gamma_final * ((epoch - gamma_warm) / gamma_ramp) # type: ignore
674
+ else:
675
+ model.gamma = gamma_final # type: ignore
676
+
677
+ # LR warmup
678
+ if epoch < warmup_epochs:
679
+ scale = float(epoch + 1) / warmup_epochs
680
+ for g in optimizer.param_groups:
681
+ g["lr"] = min_lr + (base_lr - min_lr) * scale
682
+
683
+ train_loss = self._train_step(
684
+ loader=loader,
685
+ optimizer=optimizer,
686
+ model=model,
687
+ l1_penalty=l1_penalty,
688
+ class_weights=class_weights,
689
+ )
690
+
691
+ # Abort or prune on non-finite epoch loss
692
+ if not np.isfinite(train_loss):
693
+ if trial is not None:
694
+ raise optuna.exceptions.TrialPruned("Epoch loss non-finite.")
695
+ # Soft reset suggestion: reduce LR and continue, or break
696
+ self.logger.warning(
697
+ "Non-finite epoch loss. Reducing LR by 10 percent and continuing."
698
+ )
699
+ for g in optimizer.param_groups:
700
+ g["lr"] *= 0.9
701
+ continue
702
+
703
+ scheduler.step()
704
+ if return_history:
705
+ history.append(train_loss)
706
+
707
+ early_stopping(train_loss, model)
708
+ if early_stopping.early_stop:
709
+ self.logger.info(f"Early stopping at epoch {epoch + 1}.")
710
+ break
711
+
712
+ # Optuna report/prune on validation metric
713
+ if (
714
+ trial is not None
715
+ and X_val is not None
716
+ and ((epoch + 1) % eval_interval == 0)
717
+ ):
718
+ metric_key = prune_metric or getattr(self, "tune_metric", "f1")
719
+ mask_override = None
720
+ if (
721
+ self.simulate_missing
722
+ and getattr(self, "sim_mask_test_", None) is not None
723
+ and getattr(self, "X_val_", None) is not None
724
+ and X_val.shape == self.X_val_.shape
725
+ ):
726
+ mask_override = self.sim_mask_test_
727
+ metric_val = self._eval_for_pruning(
728
+ model=model,
729
+ X_val=X_val,
730
+ params=params or getattr(self, "best_params_", {}),
731
+ metric=metric_key,
732
+ objective_mode=True,
733
+ do_latent_infer=False, # AE: False
734
+ latent_steps=0,
735
+ latent_lr=0.0,
736
+ latent_weight_decay=0.0,
737
+ latent_seed=self.seed, # type: ignore
738
+ _latent_cache=None, # AE: not used
739
+ _latent_cache_key=None,
740
+ eval_mask_override=mask_override,
741
+ )
742
+ trial.report(metric_val, step=epoch + 1)
743
+ if (epoch + 1) >= prune_warmup_epochs and trial.should_prune():
744
+ raise optuna.exceptions.TrialPruned(
745
+ f"Pruned at epoch {epoch + 1}: {metric_key}={metric_val:.5f}"
746
+ )
747
+
748
+ best_loss = early_stopping.best_score
749
+ if early_stopping.best_model is not None:
750
+ best_model = copy.deepcopy(early_stopping.best_model)
751
+ else:
752
+ best_model = copy.deepcopy(model)
753
+ return best_loss, best_model, history
754
+
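The focal-gamma schedule used in the loop above, pulled out as a standalone helper for clarity; the warm/ramp constants 50 and 100 are the hard-coded values from the loop.

def focal_gamma_at(epoch: int, gamma_final: float, warm: int = 50, ramp: int = 100) -> float:
    """Return the focal-loss gamma for an epoch: 0 during warmup, then a linear ramp."""
    if epoch < warm:
        return 0.0
    if epoch < warm + ramp:
        return gamma_final * ((epoch - warm) / ramp)
    return gamma_final

# e.g. with gamma_final=2.0: epoch 0 -> 0.0, epoch 100 -> 1.0, epoch 200 -> 2.0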
755
+ def _train_step(
756
+ self,
757
+ loader: torch.utils.data.DataLoader,
758
+ optimizer: torch.optim.Optimizer,
759
+ model: torch.nn.Module,
760
+ l1_penalty: float,
761
+ class_weights: torch.Tensor,
762
+ ) -> float:
763
+ """One epoch with stable focal CE and NaN/Inf guards."""
764
+ model.train()
765
+ running = 0.0
766
+ num_batches = 0
767
+ l1_params = tuple(p for p in model.parameters() if p.requires_grad)
768
+ if class_weights is not None and class_weights.device != self.device:
769
+ class_weights = class_weights.to(self.device)
770
+
771
+ # Use model.gamma if present, else self.gamma
772
+ gamma = float(getattr(model, "gamma", getattr(self, "gamma", 0.0)))
773
+ gamma = float(torch.tensor(gamma).clamp(min=0.0, max=10.0)) # sane bound
774
+ ce_criterion = SafeFocalCELoss(
775
+ gamma=gamma, weight=class_weights, ignore_index=-1
776
+ )
777
+
778
+ for _, y_batch in loader:
779
+ optimizer.zero_grad(set_to_none=True)
780
+ y_batch = y_batch.to(self.device, non_blocking=True)
781
+
782
+ # Inputs: one-hot with zeros for missing; Targets: long ints with -1 for missing
783
+ if self.is_haploid:
784
+ x_in = self._one_hot_encode_012(y_batch) # (B, L, 2)
785
+ logits = model(x_in).view(-1, self.num_features_, self.output_classes_)
786
+ logits_flat = logits.view(-1, self.output_classes_)
787
+ targets_flat = y_batch.view(-1).long()
788
+ if not torch.isfinite(logits_flat).all():
789
+ continue
790
+ loss = ce_criterion(logits_flat, targets_flat)
791
+ else:
792
+ x_in = self._encode_multilabel_inputs(y_batch) # (B, L, 2)
793
+ logits = model(x_in).view(-1, self.num_features_, self.output_classes_)
794
+ if not torch.isfinite(logits).all():
795
+ continue
796
+ pos_w = getattr(self, "pos_weights_", None)
797
+ targets = self._multi_hot_targets(y_batch) # float, same shape
798
+ bce = F.binary_cross_entropy_with_logits(
799
+ logits, targets, pos_weight=pos_w, reduction="none"
800
+ )
801
+ mask = (y_batch != -1).unsqueeze(-1).float()
802
+ loss = (bce * mask).sum() / mask.sum().clamp_min(1e-8)
803
+
804
+ if l1_penalty > 0:
805
+ l1 = torch.zeros((), device=self.device)
806
+ for p in l1_params:
807
+ l1 = l1 + p.abs().sum()
808
+ loss = loss + l1_penalty * l1
809
+
810
+ # Final guard
811
+ if not torch.isfinite(loss):
812
+ continue
813
+
814
+ loss.backward()
815
+
816
+ # Clip to prevent exploding grads
817
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
818
+
819
+ # If grads blew up to non-finite, skip update
820
+ if any(
821
+ (not torch.isfinite(p.grad).all())
822
+ for p in model.parameters()
823
+ if p.grad is not None
824
+ ):
825
+ optimizer.zero_grad(set_to_none=True)
826
+ continue
827
+
828
+ optimizer.step()
829
+
830
+ running += float(loss.detach().item())
831
+ num_batches += 1
832
+
833
+ if num_batches == 0:
834
+ return float("inf") # signal upstream that epoch had no usable batches
835
+ return running / num_batches
836
+
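A small self-contained check of the masked BCE reduction used in the diploid branch above (random tensors; shapes follow the (B, L, 2) convention in this file).

import torch
import torch.nn.functional as F

B, L = 4, 7
logits = torch.randn(B, L, 2)
targets = torch.randint(0, 2, (B, L, 2)).float()
y = torch.randint(-1, 3, (B, L))                    # 0/1/2 with -1 for missing

bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
mask = (y != -1).unsqueeze(-1).float()              # zero out missing loci
loss = (bce * mask).sum() / mask.sum().clamp_min(1e-8)
print(float(loss))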
837
+ def _predict(
838
+ self,
839
+ model: torch.nn.Module,
840
+ X: np.ndarray | torch.Tensor,
841
+ return_proba: bool = False,
842
+ ) -> Tuple[np.ndarray, np.ndarray] | np.ndarray:
843
+ """Predict 0/1/2 labels (and probabilities) from masked inputs.
844
+
845
+ This method generates predictions from the trained autoencoder model for the provided input data. It processes the input data, performs a forward pass through the model, and computes the predicted genotype labels (0, 1, or 2) along with their associated probabilities if requested.
846
+
847
+ Args:
848
+ model (torch.nn.Module): Trained model.
849
+ X (np.ndarray | torch.Tensor): 0/1/2 matrix with -1
850
+ for missing.
851
+ return_proba (bool): If True, return probabilities.
852
+
853
+ Returns:
854
+ Tuple[np.ndarray, np.ndarray] | np.ndarray: Predicted labels,
855
+ and probabilities if requested.
856
+ """
857
+ if model is None:
858
+ msg = "Model is not trained. Call fit() before predict()."
859
+ self.logger.error(msg)
860
+ raise NotFittedError(msg)
861
+
862
+ model.eval()
863
+ with torch.no_grad():
864
+ X_tensor = torch.from_numpy(X) if isinstance(X, np.ndarray) else X
865
+ X_tensor = X_tensor.to(self.device).long()
866
+ if self.is_haploid:
867
+ x_ohe = self._one_hot_encode_012(X_tensor)
868
+ logits = model(x_ohe).view(-1, self.num_features_, self.output_classes_)
869
+ probas = torch.softmax(logits, dim=-1)
870
+ labels = torch.argmax(probas, dim=-1)
871
+ else:
872
+ x_in = self._encode_multilabel_inputs(X_tensor)
873
+ logits = model(x_in).view(-1, self.num_features_, self.output_classes_)
874
+ probas_2 = torch.sigmoid(logits)
875
+ p_ref = probas_2[..., 0]
876
+ p_alt = probas_2[..., 1]
877
+ p_het = p_ref * p_alt
878
+ p_ref_only = p_ref * (1 - p_alt)
879
+ p_alt_only = p_alt * (1 - p_ref)
880
+ stacked = torch.stack([p_ref_only, p_het, p_alt_only], dim=-1)
881
+ stacked = stacked / stacked.sum(dim=-1, keepdim=True).clamp_min(1e-8)
882
+ probas = stacked
883
+ labels = torch.argmax(stacked, dim=-1)
884
+
885
+ if return_proba:
886
+ return labels.cpu().numpy(), probas.cpu().numpy()
887
+
888
+ return labels.cpu().numpy()
889
+
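A numeric illustration of the diploid probability composition in _predict above: p_ref and p_alt are independent sigmoid outputs, and the three-class vector is renormalized after composition.

import torch

p_ref, p_alt = torch.tensor(0.9), torch.tensor(0.6)
p_het = p_ref * p_alt                    # 0.54
p_ref_only = p_ref * (1 - p_alt)         # 0.36
p_alt_only = p_alt * (1 - p_ref)         # 0.06
probs = torch.stack([p_ref_only, p_het, p_alt_only])
probs = probs / probs.sum().clamp_min(1e-8)
print(probs)                             # ~[0.375, 0.562, 0.062] -> predicted class: HET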
890
+ def _encode_multilabel_inputs(self, y: torch.Tensor) -> torch.Tensor:
891
+ """Two-channel multi-hot for diploid: REF-only, ALT-only; HET sets both."""
892
+ if self.is_haploid:
893
+ return self._one_hot_encode_012(y)
894
+ y = y.to(self.device)
895
+ shape = y.shape + (2,)
896
+ out = torch.zeros(shape, device=self.device, dtype=torch.float32)
897
+ valid = y != -1
898
+ ref_mask = valid & (y != 2)
899
+ alt_mask = valid & (y != 0)
900
+ out[ref_mask, 0] = 1.0
901
+ out[alt_mask, 1] = 1.0
902
+ return out
903
+
904
+ def _multi_hot_targets(self, y: torch.Tensor) -> torch.Tensor:
905
+ """Targets aligned with _encode_multilabel_inputs for diploid training."""
906
+ if self.is_haploid:
907
+ # One-hot CE path expects integer targets; handled upstream.
908
+ raise RuntimeError("_multi_hot_targets called for haploid data.")
909
+ y = y.to(self.device)
910
+ out = torch.zeros(y.shape + (2,), device=self.device, dtype=torch.float32)
911
+ valid = y != -1
912
+ ref_mask = valid & (y != 2)
913
+ alt_mask = valid & (y != 0)
914
+ out[ref_mask, 0] = 1.0
915
+ out[alt_mask, 1] = 1.0
916
+ return out
917
+
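The two-channel encoding shared by _encode_multilabel_inputs and _multi_hot_targets, written out for the four possible diploid inputs: the REF channel is set unless the genotype is 2, the ALT channel is set unless it is 0, and missing stays all-zero.

import torch

y = torch.tensor([[0, 1, 2, -1]])        # REF, HET, ALT, missing
out = torch.zeros(y.shape + (2,))
valid = y != -1
out[valid & (y != 2), 0] = 1.0           # REF channel
out[valid & (y != 0), 1] = 1.0           # ALT channel
print(out[0].tolist())                   # [[1.0, 0.0], [1.0, 1.0], [0.0, 1.0], [0.0, 0.0]]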
918
+ def _compute_pos_weights(self, X: np.ndarray) -> torch.Tensor:
919
+ """Balance REF/ALT channels for multilabel BCE."""
920
+ ref_pos = np.count_nonzero((X == 0) | (X == 1))
921
+ alt_pos = np.count_nonzero((X == 2) | (X == 1))
922
+ total_valid = np.count_nonzero(X != -1)
923
+ pos_counts = np.array([ref_pos, alt_pos], dtype=np.float32)
924
+ neg_counts = np.maximum(total_valid - pos_counts, 1.0)
925
+ pos_counts = np.maximum(pos_counts, 1.0)
926
+ weights = neg_counts / pos_counts
927
+ return torch.tensor(weights, device=self.device, dtype=torch.float32)
928
+
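A quick numeric check of the REF/ALT positive-class weights above: each channel's weight is its negative/positive count ratio, with both counts floored at 1.

import numpy as np

X = np.array([[0, 1, 2, 2], [1, 2, -1, 2]])      # 0/1/2 with -1 for missing
ref_pos = np.count_nonzero((X == 0) | (X == 1))  # 3 genotypes carry a REF allele
alt_pos = np.count_nonzero((X == 2) | (X == 1))  # 6 genotypes carry an ALT allele
total_valid = np.count_nonzero(X != -1)          # 7 observed genotypes
pos = np.maximum(np.array([ref_pos, alt_pos], dtype=float), 1.0)
neg = np.maximum(total_valid - pos, 1.0)
print(neg / pos)                                 # [1.3333 0.1667] -> rarer REF channel weighted up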
929
+ def _evaluate_model(
930
+ self,
931
+ X_val: np.ndarray,
932
+ model: torch.nn.Module,
933
+ params: dict,
934
+ objective_mode: bool = False,
935
+ latent_vectors_val: Optional[np.ndarray] = None,
936
+ *,
937
+ eval_mask_override: np.ndarray | None = None,
938
+ ) -> Dict[str, float]:
939
+ """Evaluate on 0/1/2; then IUPAC decoding and 10-base integer reports.
940
+
941
+ This method evaluates the trained autoencoder model on a validation set, computing various classification metrics based on the predicted and true genotypes. It handles both haploid and diploid data appropriately and generates detailed classification reports for both genotype and IUPAC/10-base integer encodings.
942
+
943
+ Args:
944
+ X_val (np.ndarray): Validation set 0/1/2 matrix with -1
945
+ for missing.
946
+ model (torch.nn.Module): Trained model.
947
+ params (dict): Model parameters.
948
+ objective_mode (bool): If True, suppress logging and reports.
949
+ latent_vectors_val (Optional[np.ndarray]): Unused for AE.
950
+ eval_mask_override (np.ndarray | None): Optional mask to override default evaluation mask.
951
+
952
+ Returns:
953
+ Dict[str, float]: Dictionary of evaluation metrics.
954
+ """
955
+ pred_labels, pred_probas = self._predict(
956
+ model=model, X=X_val, return_proba=True
957
+ )
958
+
959
+ finite_mask = np.all(np.isfinite(pred_probas), axis=-1) # (N, L)
960
+
961
+ # FIX 1: Check ROWS (shape[0]) only. X_val might be a feature subset.
962
+ if (
963
+ hasattr(self, "X_val_")
964
+ and getattr(self, "X_val_", None) is not None
965
+ and X_val.shape[0] == self.X_val_.shape[0]
966
+ ):
967
+ GT_ref = getattr(self, "GT_test_full_", self.ground_truth_)
968
+ elif (
969
+ hasattr(self, "X_train_")
970
+ and getattr(self, "X_train_", None) is not None
971
+ and X_val.shape[0] == self.X_train_.shape[0]
972
+ ):
973
+ GT_ref = getattr(self, "GT_train_full_", self.ground_truth_)
974
+ else:
975
+ GT_ref = self.ground_truth_
976
+
977
+ # FIX 2: Handle Feature Mismatch (e.g., tune_fast feature subsetting)
978
+ # If the GT source has more columns than X_val, slice it to match.
979
+ if GT_ref.shape[1] > X_val.shape[1]:
980
+ GT_ref = GT_ref[:, : X_val.shape[1]]
981
+
982
+ # Fallback if rows mismatch (unlikely after Fix 1, but safe to keep)
983
+ if GT_ref.shape != X_val.shape:
984
+ # If completely different, we can't use the ground truth object.
985
+ # Fall back to X_val (this implies only observed values are scored)
986
+ GT_ref = X_val
987
+
988
+ if eval_mask_override is not None:
989
+ # FIX 3: Allow override mask to be sliced if it's too wide
990
+ if eval_mask_override.shape[0] != X_val.shape[0]:
991
+ msg = (
992
+ f"eval_mask_override rows {eval_mask_override.shape[0]} "
993
+ f"does not match X_val rows {X_val.shape[0]}"
994
+ )
995
+ self.logger.error(msg)
996
+ raise ValueError(msg)
997
+
998
+ if eval_mask_override.shape[1] > X_val.shape[1]:
999
+ eval_mask = eval_mask_override[:, : X_val.shape[1]].astype(bool)
1000
+ else:
1001
+ eval_mask = eval_mask_override.astype(bool)
1002
+ else:
1003
+ eval_mask = X_val != -1
1004
+
1005
+ # Combine masks
1006
+ eval_mask = eval_mask & finite_mask & (GT_ref != -1)
1007
+
1008
+ y_true_flat = GT_ref[eval_mask].astype(np.int64, copy=False)
1009
+ y_pred_flat = pred_labels[eval_mask].astype(np.int64, copy=False)
1010
+ y_proba_flat = pred_probas[eval_mask].astype(np.float64, copy=False)
1011
+
1012
+ if y_true_flat.size == 0:
1013
+ self.tune_metric = "f1" if self.tune_metric is None else self.tune_metric
1014
+ return {self.tune_metric: 0.0}
1015
+
1016
+ # ensure valid probability simplex after masking (no NaNs/Infs, sums=1)
1017
+ y_proba_flat = np.clip(y_proba_flat, 0.0, 1.0)
1018
+ row_sums = y_proba_flat.sum(axis=1, keepdims=True)
1019
+ row_sums[row_sums == 0] = 1.0
1020
+ y_proba_flat = y_proba_flat / row_sums
1021
+
1022
+ labels_for_scoring = [0, 1] if self.is_haploid else [0, 1, 2]
1023
+ target_names = ["REF", "ALT"] if self.is_haploid else ["REF", "HET", "ALT"]
1024
+
1025
+ if self.is_haploid:
1026
+ y_true_flat = y_true_flat.copy()
1027
+ y_pred_flat = y_pred_flat.copy()
1028
+ y_true_flat[y_true_flat == 2] = 1
1029
+ y_pred_flat[y_pred_flat == 2] = 1
1030
+ # collapse probs to 2-class
1031
+ proba_2 = np.zeros((len(y_proba_flat), 2), dtype=y_proba_flat.dtype)
1032
+ proba_2[:, 0] = y_proba_flat[:, 0]
1033
+ proba_2[:, 1] = y_proba_flat[:, 2]
1034
+ y_proba_flat = proba_2
1035
+
1036
+ y_true_ohe = np.eye(len(labels_for_scoring))[y_true_flat]
1037
+
1038
+ tune_metric_tmp: Literal[
1039
+ "pr_macro",
1040
+ "roc_auc",
1041
+ "average_precision",
1042
+ "accuracy",
1043
+ "f1",
1044
+ "precision",
1045
+ "recall",
1046
+ ]
1047
+ if getattr(self, "tune_metric_", None) is not None:
1048
+ tune_metric_tmp = self.tune_metric_
1049
+ else:
1050
+ tune_metric_tmp = "f1" # Default if not tuning
1051
+
1052
+ metrics = self.scorers_.evaluate(
1053
+ y_true_flat,
1054
+ y_pred_flat,
1055
+ y_true_ohe,
1056
+ y_proba_flat,
1057
+ objective_mode,
1058
+ tune_metric_tmp,
1059
+ )
1060
+
1061
+ if not objective_mode:
1062
+ pm = PrettyMetrics(
1063
+ metrics, precision=3, title=f"{self.model_name} Validation Metrics"
1064
+ )
1065
+ pm.render() # prints a command-line table
1066
+
1067
+ # Primary report (REF/HET/ALT or REF/ALT)
1068
+ self._make_class_reports(
1069
+ y_true=y_true_flat,
1070
+ y_pred_proba=y_proba_flat,
1071
+ y_pred=y_pred_flat,
1072
+ metrics=metrics,
1073
+ labels=target_names,
1074
+ )
1075
+
1076
+ # IUPAC decode & 10-base integer reports
1077
+ # Now safe because GT_ref has been sliced to match X_val dimensions
1078
+ y_true_dec = self.pgenc.decode_012(
1079
+ GT_ref.reshape(X_val.shape[0], X_val.shape[1])
1080
+ )
1081
+ X_pred = X_val.copy()
1082
+ X_pred[eval_mask] = y_pred_flat
1083
+
1084
+ # Use X_val.shape[1] (current features) not self.num_features_ (original features)
1085
+ y_pred_dec = self.pgenc.decode_012(
1086
+ X_pred.reshape(X_val.shape[0], X_val.shape[1])
1087
+ )
1088
+
1089
+ encodings_dict = {
1090
+ "A": 0,
1091
+ "C": 1,
1092
+ "G": 2,
1093
+ "T": 3,
1094
+ "W": 4,
1095
+ "R": 5,
1096
+ "M": 6,
1097
+ "K": 7,
1098
+ "Y": 8,
1099
+ "S": 9,
1100
+ "N": -1,
1101
+ }
1102
+ y_true_int = self.pgenc.convert_int_iupac(
1103
+ y_true_dec, encodings_dict=encodings_dict
1104
+ )
1105
+ y_pred_int = self.pgenc.convert_int_iupac(
1106
+ y_pred_dec, encodings_dict=encodings_dict
1107
+ )
1108
+
1109
+ valid_iupac_mask = y_true_int[eval_mask] >= 0
1110
+ if valid_iupac_mask.any():
1111
+ self._make_class_reports(
1112
+ y_true=y_true_int[eval_mask][valid_iupac_mask],
1113
+ y_pred=y_pred_int[eval_mask][valid_iupac_mask],
1114
+ metrics=metrics,
1115
+ y_pred_proba=None,
1116
+ labels=["A", "C", "G", "T", "W", "R", "M", "K", "Y", "S"],
1117
+ )
1118
+ else:
1119
+ self.logger.warning(
1120
+ "Skipped IUPAC confusion matrix: No valid ground truths."
1121
+ )
1122
+
1123
+ return metrics
1124
+
1125
+ def _objective(self, trial: optuna.Trial) -> float:
1126
+ """Optuna objective for AE; mirrors NLPCA study driver without latents.
1127
+
1128
+ This method defines the objective function for hyperparameter tuning using Optuna. It samples hyperparameters, prepares the training and validation data, builds and trains the autoencoder model, and evaluates its performance on the validation set. The method returns the value of the tuning metric to be maximized.
1129
+
1130
+ Args:
1131
+ trial (optuna.Trial): Optuna trial.
1132
+
1133
+ Returns:
1134
+ float: Value of the tuning metric (maximize).
1135
+ """
1136
+ try:
1137
+ # Sample hyperparameters (existing helper; unchanged signature)
1138
+ params = self._sample_hyperparameters(trial)
1139
+
1140
+ # Optionally sub-sample for fast tuning (same keys used by NLPCA if you adopt them)
1141
+ X_train = getattr(self, "X_train_", self.ground_truth_[self.train_idx_])
1142
+ X_val = getattr(self, "X_val_", self.ground_truth_[self.test_idx_])
1143
+
1144
+ class_weights = self._normalize_class_weights(
1145
+ self._class_weights_from_zygosity(X_train)
1146
+ )
1147
+ train_loader = self._get_data_loaders(X_train)
1148
+
1149
+ model = self.build_model(self.Model, params["model_params"])
1150
+ model.apply(self.initialize_weights)
1151
+
1152
+ lr: float = float(params["lr"])
1153
+ l1_penalty: float = float(params["l1_penalty"])
1154
+
1155
+ # Train + prune on metric
1156
+ _, model, __ = self._train_and_validate_model(
1157
+ model=model,
1158
+ loader=train_loader,
1159
+ lr=lr,
1160
+ l1_penalty=l1_penalty,
1161
+ trial=trial,
1162
+ return_history=False,
1163
+ class_weights=class_weights,
1164
+ X_val=X_val,
1165
+ params=params,
1166
+ prune_metric=self.tune_metric,
1167
+ prune_warmup_epochs=10,
1168
+ eval_interval=self.tune_eval_interval,
1169
+ eval_requires_latents=False,
1170
+ eval_latent_steps=0,
1171
+ eval_latent_lr=0.0,
1172
+ eval_latent_weight_decay=0.0,
1173
+ )
1174
+
1175
+ eval_mask = (
1176
+ self.sim_mask_test_
1177
+ if (
1178
+ self.simulate_missing
1179
+ and getattr(self, "sim_mask_test_", None) is not None
1180
+ )
1181
+ else None
1182
+ )
1183
+
1184
+ if model is not None:
1185
+ metrics = self._evaluate_model(
1186
+ X_val,
1187
+ model,
1188
+ params,
1189
+ objective_mode=True,
1190
+ eval_mask_override=eval_mask,
1191
+ )
1192
+ self._clear_resources(model, train_loader)
1193
+ else:
1194
+ raise TypeError("Model training failed; no model was returned.")
1195
+
1196
+ return metrics[self.tune_metric]
1197
+
1198
+ except Exception as e:
1199
+ # Keep sweeps moving if a trial fails
1200
+ raise optuna.exceptions.TrialPruned(f"Trial failed with error: {e}")
1201
+
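The actual study driver is tune_hyperparameters() from the base class and is not shown in this diff; the following is only a generic Optuna sketch of how an objective like the one above is typically driven, with run_study_sketch being a hypothetical helper.

import optuna

def run_study_sketch(imputer, n_trials: int = 20) -> dict:
    """Generic driver sketch; the packaged tune_hyperparameters() may differ."""
    study = optuna.create_study(direction="maximize")  # the tuning metric is maximized
    study.optimize(imputer._objective, n_trials=n_trials)
    return study.best_params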
1202
+ def _sample_hyperparameters(self, trial: optuna.Trial) -> Dict[str, Any]:
1203
+ """Sample AE hyperparameters and compute hidden sizes for model params.
1204
+
1205
+ This method samples hyperparameters for the autoencoder model using Optuna's trial object. It computes the hidden layer sizes based on the sampled parameters and prepares the model parameters dictionary.
1206
+
1207
+ Args:
1208
+ trial (optuna.Trial): Optuna trial object.
1209
+
1210
+ Returns:
1211
+ Dict[str, int | float | str | bool]: Sampled hyperparameters and model_params.
1212
+ """
1213
+ params = {
1214
+ "latent_dim": trial.suggest_int("latent_dim", 4, 16, step=2),
1215
+ "lr": trial.suggest_float("learning_rate", 3e-4, 1e-3, log=True),
1216
+ "dropout_rate": trial.suggest_float("dropout_rate", 0.0, 0.30, step=0.05),
1217
+ "num_hidden_layers": trial.suggest_int("num_hidden_layers", 1, 6),
1218
+ "activation": trial.suggest_categorical(
1219
+ "activation", ["relu", "elu", "selu", "leaky_relu"]
1220
+ ),
1221
+ "l1_penalty": trial.suggest_float("l1_penalty", 1e-6, 1e-3, log=True),
1222
+ "layer_scaling_factor": trial.suggest_float(
1223
+ "layer_scaling_factor", 2.0, 4.0, step=0.5
1224
+ ),
1225
+ "layer_schedule": trial.suggest_categorical(
1226
+ "layer_schedule", ["pyramid", "linear"]
1227
+ ),
1228
+ }
1229
+
1230
+ nF: int = self.num_features_
1231
+ nC: int = int(getattr(self, "output_classes_", self.num_classes_ or 3))
1232
+ input_dim = nF * nC
1233
+ hidden_layer_sizes = self._compute_hidden_layer_sizes(
1234
+ n_inputs=input_dim,
1235
+ n_outputs=input_dim,
1236
+ n_samples=len(self.train_idx_),
1237
+ n_hidden=params["num_hidden_layers"],
1238
+ alpha=params["layer_scaling_factor"],
1239
+ schedule=params["layer_schedule"],
1240
+ )
1241
+
1242
+ # Keep the latent_dim as the first element,
1243
+ # then the interior hidden widths.
1244
+ # If there are no interior widths (very small nets),
1245
+ # this still leaves [latent_dim].
1246
+ hidden_only: list[int] = [hidden_layer_sizes[0]] + hidden_layer_sizes[1:-1]
1247
+
1248
+ params["model_params"] = {
1249
+ "n_features": int(self.num_features_),
1250
+ "num_classes": int(
1251
+ getattr(self, "output_classes_", self.num_classes_ or 3)
1252
+ ),
1253
+ "latent_dim": int(params["latent_dim"]),
1254
+ "dropout_rate": float(params["dropout_rate"]),
1255
+ "hidden_layer_sizes": hidden_only,
1256
+ "activation": str(params["activation"]),
1257
+ }
1258
+ return params
1259
+
1260
+ def _set_best_params(
1261
+ self, best_params: Dict[str, int | float | str | List[int]]
1262
+ ) -> Dict[str, int | float | str | List[int]]:
1263
+ """Adopt best params (ImputeNLPCA parity) and return model_params.
1264
+
1265
+ This method sets the best hyperparameters found during tuning and computes the hidden layer sizes for the autoencoder model. It prepares the final model parameters dictionary to be used for building the model.
1266
+
1267
+ Args:
1268
+ best_params (Dict[str, int | float | str | List[int]]): Best hyperparameters from tuning.
1269
+
1270
+ Returns:
1271
+ Dict[str, int | float | str | List[int]]: Model parameters for building the model.
1272
+ """
1273
+ bp = {}
1274
+ for k, v in best_params.items():
1275
+ if not isinstance(v, list):
1276
+ if k in {"latent_dim", "num_hidden_layers"}:
1277
+ bp[k] = int(v)
1278
+ elif k in {
1279
+ "dropout_rate",
1280
+ "learning_rate",
1281
+ "l1_penalty",
1282
+ "layer_scaling_factor",
1283
+ }:
1284
+ bp[k] = float(v)
1285
+ elif k in {"activation", "layer_schedule"}:
1286
+ if k == "layer_schedule":
1287
+ if v not in {"pyramid", "constant", "linear"}:
1288
+ raise ValueError(f"Invalid layer_schedule: {v}")
1289
+ bp[k] = v
1290
+ else:
1291
+ bp[k] = str(v)
1292
+ else:
1293
+ bp[k] = v # keep lists as-is
1294
+
1295
+ self.latent_dim: int = bp["latent_dim"]
1296
+ self.dropout_rate: float = bp["dropout_rate"]
1297
+ self.learning_rate: float = bp["learning_rate"]
1298
+ self.l1_penalty: float = bp["l1_penalty"]
1299
+ self.activation: str = bp["activation"]
1300
+ self.layer_scaling_factor: float = bp["layer_scaling_factor"]
1301
+ self.layer_schedule: str = bp["layer_schedule"]
1302
+
1303
+ nF: int = self.num_features_
1304
+ nC: int = int(getattr(self, "output_classes_", self.num_classes_ or 3))
1305
+ hidden_layer_sizes = self._compute_hidden_layer_sizes(
1306
+ n_inputs=nF * nC,
1307
+ n_outputs=nF * nC,
1308
+ n_samples=len(self.train_idx_),
1309
+ n_hidden=bp["num_hidden_layers"],
1310
+ alpha=bp["layer_scaling_factor"],
1311
+ schedule=bp["layer_schedule"],
1312
+ )
1313
+
1314
+ # Keep the latent_dim as the first element,
1315
+ # then the interior hidden widths.
1316
+ # If there are no interior widths (very small nets),
1317
+ # this still leaves [latent_dim].
1318
+ hidden_only = [hidden_layer_sizes[0]] + hidden_layer_sizes[1:-1]
1319
+
1320
+ return {
1321
+ "n_features": self.num_features_,
1322
+ "latent_dim": self.latent_dim,
1323
+ "hidden_layer_sizes": hidden_only,
1324
+ "dropout_rate": self.dropout_rate,
1325
+ "activation": self.activation,
1326
+ "num_classes": nC,
1327
+ }
1328
+
1329
+ def _default_best_params(self) -> Dict[str, int | float | str | list]:
1330
+ """Default model params when tuning is disabled.
1331
+
1332
+ This method computes the default model parameters for the autoencoder when hyperparameter tuning is not performed. It calculates the hidden layer sizes based on the initial configuration.
1333
+
1334
+ Returns:
1335
+ Dict[str, int | float | str | list]: Default model parameters.
1336
+ """
1337
+ nF: int = self.num_features_
1338
+ # Use the number of output channels passed to the model (2 for diploid multilabel)
1339
+ # instead of the scoring classes (3) to keep layer shapes aligned.
1340
+ nC: int = int(getattr(self, "output_classes_", self.num_classes_ or 3))
1341
+ ls = self.layer_schedule
1342
+
1343
+ if ls not in {"pyramid", "constant", "linear"}:
1344
+ raise ValueError(f"Invalid layer_schedule: {ls}")
1345
+
1346
+ hidden_layer_sizes = self._compute_hidden_layer_sizes(
1347
+ n_inputs=nF * nC,
1348
+ n_outputs=nF * nC,
1349
+ n_samples=len(self.ground_truth_),
1350
+ n_hidden=self.num_hidden_layers,
1351
+ alpha=self.layer_scaling_factor,
1352
+ schedule=ls,
1353
+ )
1354
+ return {
1355
+ "n_features": self.num_features_,
1356
+ "latent_dim": self.latent_dim,
1357
+ "hidden_layer_sizes": hidden_layer_sizes,
1358
+ "dropout_rate": self.dropout_rate,
1359
+ "activation": self.activation,
1360
+ "num_classes": nC,
1361
+ }