pg-sui 0.2.0__py3-none-any.whl → 1.6.14.dev9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (127)
  1. {pg_sui-0.2.0.dist-info → pg_sui-1.6.14.dev9.dist-info}/METADATA +101 -79
  2. pg_sui-1.6.14.dev9.dist-info/RECORD +81 -0
  3. {pg_sui-0.2.0.dist-info → pg_sui-1.6.14.dev9.dist-info}/WHEEL +1 -1
  4. pg_sui-1.6.14.dev9.dist-info/entry_points.txt +4 -0
  5. {pg_sui-0.2.0.dist-info → pg_sui-1.6.14.dev9.dist-info/licenses}/LICENSE +0 -0
  6. pg_sui-1.6.14.dev9.dist-info/top_level.txt +1 -0
  7. pgsui/__init__.py +35 -54
  8. pgsui/_version.py +34 -0
  9. pgsui/cli.py +909 -0
  10. pgsui/data_processing/__init__.py +0 -0
  11. pgsui/data_processing/config.py +565 -0
  12. pgsui/data_processing/containers.py +1424 -0
  13. pgsui/data_processing/transformers.py +557 -907
  14. pgsui/{example_data/trees → electron/app}/__init__.py +0 -0
  15. pgsui/electron/app/__main__.py +5 -0
  16. pgsui/electron/app/extra-resources/.gitkeep +1 -0
  17. pgsui/electron/app/icons/icons/1024x1024.png +0 -0
  18. pgsui/electron/app/icons/icons/128x128.png +0 -0
  19. pgsui/electron/app/icons/icons/16x16.png +0 -0
  20. pgsui/electron/app/icons/icons/24x24.png +0 -0
  21. pgsui/electron/app/icons/icons/256x256.png +0 -0
  22. pgsui/electron/app/icons/icons/32x32.png +0 -0
  23. pgsui/electron/app/icons/icons/48x48.png +0 -0
  24. pgsui/electron/app/icons/icons/512x512.png +0 -0
  25. pgsui/electron/app/icons/icons/64x64.png +0 -0
  26. pgsui/electron/app/icons/icons/icon.icns +0 -0
  27. pgsui/electron/app/icons/icons/icon.ico +0 -0
  28. pgsui/electron/app/main.js +227 -0
  29. pgsui/electron/app/package-lock.json +6894 -0
  30. pgsui/electron/app/package.json +51 -0
  31. pgsui/electron/app/preload.js +15 -0
  32. pgsui/electron/app/server.py +157 -0
  33. pgsui/electron/app/ui/logo.png +0 -0
  34. pgsui/electron/app/ui/renderer.js +131 -0
  35. pgsui/electron/app/ui/styles.css +59 -0
  36. pgsui/electron/app/ui/ui_shim.js +72 -0
  37. pgsui/electron/bootstrap.py +43 -0
  38. pgsui/electron/launch.py +57 -0
  39. pgsui/electron/package.json +14 -0
  40. pgsui/example_data/__init__.py +0 -0
  41. pgsui/example_data/phylip_files/__init__.py +0 -0
  42. pgsui/example_data/phylip_files/test.phy +0 -0
  43. pgsui/example_data/popmaps/__init__.py +0 -0
  44. pgsui/example_data/popmaps/{test.popmap → phylogen_nomx.popmap} +185 -99
  45. pgsui/example_data/structure_files/__init__.py +0 -0
  46. pgsui/example_data/structure_files/test.pops.2row.allsites.str +0 -0
  47. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz +0 -0
  48. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz.tbi +0 -0
  49. pgsui/impute/__init__.py +0 -0
  50. pgsui/impute/deterministic/imputers/allele_freq.py +725 -0
  51. pgsui/impute/deterministic/imputers/mode.py +844 -0
  52. pgsui/impute/deterministic/imputers/nmf.py +221 -0
  53. pgsui/impute/deterministic/imputers/phylo.py +973 -0
  54. pgsui/impute/deterministic/imputers/ref_allele.py +669 -0
  55. pgsui/impute/supervised/__init__.py +0 -0
  56. pgsui/impute/supervised/base.py +343 -0
  57. pgsui/impute/{unsupervised/models/in_development → supervised/imputers}/__init__.py +0 -0
  58. pgsui/impute/supervised/imputers/hist_gradient_boosting.py +317 -0
  59. pgsui/impute/supervised/imputers/random_forest.py +291 -0
  60. pgsui/impute/unsupervised/__init__.py +0 -0
  61. pgsui/impute/unsupervised/base.py +1118 -0
  62. pgsui/impute/unsupervised/callbacks.py +92 -262
  63. {simulation → pgsui/impute/unsupervised/imputers}/__init__.py +0 -0
  64. pgsui/impute/unsupervised/imputers/autoencoder.py +1285 -0
  65. pgsui/impute/unsupervised/imputers/nlpca.py +1554 -0
  66. pgsui/impute/unsupervised/imputers/ubp.py +1575 -0
  67. pgsui/impute/unsupervised/imputers/vae.py +1228 -0
  68. pgsui/impute/unsupervised/loss_functions.py +261 -0
  69. pgsui/impute/unsupervised/models/__init__.py +0 -0
  70. pgsui/impute/unsupervised/models/autoencoder_model.py +215 -567
  71. pgsui/impute/unsupervised/models/nlpca_model.py +155 -394
  72. pgsui/impute/unsupervised/models/ubp_model.py +180 -1106
  73. pgsui/impute/unsupervised/models/vae_model.py +269 -630
  74. pgsui/impute/unsupervised/nn_scorers.py +255 -0
  75. pgsui/utils/__init__.py +0 -0
  76. pgsui/utils/classification_viz.py +608 -0
  77. pgsui/utils/logging_utils.py +22 -0
  78. pgsui/utils/misc.py +35 -480
  79. pgsui/utils/plotting.py +996 -829
  80. pgsui/utils/pretty_metrics.py +290 -0
  81. pgsui/utils/scorers.py +213 -666
  82. pg_sui-0.2.0.dist-info/RECORD +0 -75
  83. pg_sui-0.2.0.dist-info/top_level.txt +0 -3
  84. pgsui/example_data/phylip_files/test_n10.phy +0 -118
  85. pgsui/example_data/phylip_files/test_n100.phy +0 -118
  86. pgsui/example_data/phylip_files/test_n2.phy +0 -118
  87. pgsui/example_data/phylip_files/test_n500.phy +0 -118
  88. pgsui/example_data/structure_files/test.nopops.1row.10sites.str +0 -117
  89. pgsui/example_data/structure_files/test.nopops.2row.100sites.str +0 -234
  90. pgsui/example_data/structure_files/test.nopops.2row.10sites.str +0 -234
  91. pgsui/example_data/structure_files/test.nopops.2row.30sites.str +0 -234
  92. pgsui/example_data/structure_files/test.nopops.2row.allsites.str +0 -234
  93. pgsui/example_data/structure_files/test.pops.1row.10sites.str +0 -117
  94. pgsui/example_data/structure_files/test.pops.2row.10sites.str +0 -234
  95. pgsui/example_data/trees/test.iqtree +0 -376
  96. pgsui/example_data/trees/test.qmat +0 -5
  97. pgsui/example_data/trees/test.rate +0 -2033
  98. pgsui/example_data/trees/test.tre +0 -1
  99. pgsui/example_data/trees/test_n10.rate +0 -19
  100. pgsui/example_data/trees/test_n100.rate +0 -109
  101. pgsui/example_data/trees/test_n500.rate +0 -509
  102. pgsui/example_data/trees/test_siterates.txt +0 -2024
  103. pgsui/example_data/trees/test_siterates_n10.txt +0 -10
  104. pgsui/example_data/trees/test_siterates_n100.txt +0 -100
  105. pgsui/example_data/trees/test_siterates_n500.txt +0 -500
  106. pgsui/example_data/vcf_files/test.vcf +0 -244
  107. pgsui/example_data/vcf_files/test.vcf.gz +0 -0
  108. pgsui/example_data/vcf_files/test.vcf.gz.tbi +0 -0
  109. pgsui/impute/estimators.py +0 -1268
  110. pgsui/impute/impute.py +0 -1463
  111. pgsui/impute/simple_imputers.py +0 -1431
  112. pgsui/impute/supervised/iterative_imputer_fixedparams.py +0 -782
  113. pgsui/impute/supervised/iterative_imputer_gridsearch.py +0 -1024
  114. pgsui/impute/unsupervised/keras_classifiers.py +0 -697
  115. pgsui/impute/unsupervised/models/in_development/cnn_model.py +0 -486
  116. pgsui/impute/unsupervised/neural_network_imputers.py +0 -1440
  117. pgsui/impute/unsupervised/neural_network_methods.py +0 -1395
  118. pgsui/pg_sui.py +0 -261
  119. pgsui/utils/sequence_tools.py +0 -407
  120. simulation/sim_benchmarks.py +0 -333
  121. simulation/sim_treeparams.py +0 -475
  122. test/__init__.py +0 -0
  123. test/pg_sui_simtest.py +0 -215
  124. test/pg_sui_testing.py +0 -523
  125. test/test.py +0 -151
  126. test/test_pgsui.py +0 -374
  127. test/test_tkc.py +0 -185
pgsui/impute/unsupervised/imputers/autoencoder.py (new file)
@@ -0,0 +1,1285 @@
1
+ import copy
2
+ from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union
3
+
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ import optuna
7
+ import torch
8
+ from sklearn.exceptions import NotFittedError
9
+ from sklearn.model_selection import train_test_split
10
+ from snpio.analysis.genotype_encoder import GenotypeEncoder
11
+ from snpio.utils.logging import LoggerManager
12
+ from torch.optim.lr_scheduler import CosineAnnealingLR
13
+
14
+ from pgsui.data_processing.config import apply_dot_overrides, load_yaml_to_dataclass
15
+ from pgsui.data_processing.containers import AutoencoderConfig
16
+ from pgsui.data_processing.transformers import SimMissingTransformer
17
+ from pgsui.impute.unsupervised.base import BaseNNImputer
18
+ from pgsui.impute.unsupervised.callbacks import EarlyStopping
19
+ from pgsui.impute.unsupervised.loss_functions import SafeFocalCELoss
20
+ from pgsui.impute.unsupervised.models.autoencoder_model import AutoencoderModel
21
+ from pgsui.utils.logging_utils import configure_logger
22
+ from pgsui.utils.pretty_metrics import PrettyMetrics
23
+
24
+ if TYPE_CHECKING:
25
+ from snpio import TreeParser
26
+ from snpio.read_input.genotype_data import GenotypeData
27
+
28
+
29
+ def ensure_autoencoder_config(
30
+ config: AutoencoderConfig | dict | str | None,
31
+ ) -> AutoencoderConfig:
32
+ """Return a concrete AutoencoderConfig from dataclass, dict, YAML path, or None.
33
+
34
+ This function normalizes the configuration input for the Autoencoder imputer. It accepts a dataclass instance, a nested dictionary, a YAML file path, or None, and returns a concrete AutoencoderConfig with all necessary fields populated.
35
+
36
+ Args:
37
+ config (AutoencoderConfig | dict | str | None): Structured configuration as dataclass, nested dict, YAML path, or None.
38
+
39
+ Returns:
40
+ AutoencoderConfig: Concrete configuration instance.
41
+ """
42
+ if config is None:
43
+ return AutoencoderConfig()
44
+ if isinstance(config, AutoencoderConfig):
45
+ return config
46
+ if isinstance(config, str):
47
+ # YAML path — top-level `preset` key is supported
48
+ return load_yaml_to_dataclass(config, AutoencoderConfig)
49
+ if isinstance(config, dict):
50
+ # Flatten dict into dot-keys then overlay onto a fresh instance
51
+ base = AutoencoderConfig()
52
+
53
+ def _flatten(prefix: str, d: dict, out: dict) -> dict:
54
+ for k, v in d.items():
55
+ kk = f"{prefix}.{k}" if prefix else k
56
+ if isinstance(v, dict):
57
+ _flatten(kk, v, out)
58
+ else:
59
+ out[kk] = v
60
+ return out
61
+
62
+ # Lift any present preset first
63
+ preset_name = config.pop("preset", None)
64
+ if "io" in config and isinstance(config["io"], dict):
65
+ preset_name = preset_name or config["io"].pop("preset", None)
66
+
67
+ if preset_name:
68
+ base = AutoencoderConfig.from_preset(preset_name)
69
+
70
+ flat = _flatten("", config, {})
71
+ return apply_dot_overrides(base, flat)
72
+
73
+ raise TypeError("config must be an AutoencoderConfig, dict, YAML path, or None.")
74
+
75
+
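As a standalone illustration of the dict branch above, the sketch below reproduces the dot-key flattening that `ensure_autoencoder_config` performs before calling `apply_dot_overrides`. The nested values are made up for the example and are not package defaults.

```python
# Standalone sketch of the dot-key flattening used above; values are hypothetical.
def flatten(prefix: str, d: dict, out: dict) -> dict:
    for k, v in d.items():
        kk = f"{prefix}.{k}" if prefix else k
        if isinstance(v, dict):
            flatten(kk, v, out)   # recurse into nested sections
        else:
            out[kk] = v           # leaf becomes a dot-key override
    return out

nested = {"model": {"latent_dim": 32, "dropout_rate": 0.2}, "train": {"batch_size": 64}}
print(flatten("", nested, {}))
# {'model.latent_dim': 32, 'model.dropout_rate': 0.2, 'train.batch_size': 64}
```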
76
+ class ImputeAutoencoder(BaseNNImputer):
77
+ """Impute missing genotypes with a standard Autoencoder on 0/1/2 encodings.
78
+
79
+ This imputer uses a feedforward autoencoder to learn compact latent representations of genotype data encoded as 0 (homozygous reference), 1 (heterozygous), and 2 (homozygous alternate), and to reconstruct the full genotype matrix from them. Missing genotypes are represented internally as -1 during training and imputation.
80
+
81
+ The model is trained to minimize a focal cross-entropy loss, which helps to address class imbalance by focusing more on hard-to-classify examples. The architecture includes configurable parameters such as the number of hidden layers, latent dimension size, dropout rate, and activation functions.
82
+ """
83
+
84
+ def __init__(
85
+ self,
86
+ genotype_data: "GenotypeData",
87
+ *,
88
+ tree_parser: Optional["TreeParser"] = None,
89
+ config: Optional[Union["AutoencoderConfig", dict, str]] = None,
90
+ overrides: dict | None = None,
91
+ simulate_missing: bool | None = None,
92
+ sim_strategy: (
93
+ Literal[
94
+ "random",
95
+ "random_weighted",
96
+ "random_weighted_inv",
97
+ "nonrandom",
98
+ "nonrandom_weighted",
99
+ ]
100
+ | None
101
+ ) = None,
102
+ sim_prop: float | None = None,
103
+ sim_kwargs: dict | None = None,
104
+ ) -> None:
105
+ """Initialize the Autoencoder imputer with a unified config interface.
106
+
107
+ This initializer sets up the Autoencoder imputer by processing the provided configuration, initializing logging, and preparing the model and data encoder. It supports configuration input as a dataclass, nested dictionary, YAML file path, or None, with optional dot-key overrides for fine-tuning specific parameters.
108
+
109
+ Args:
110
+ genotype_data ("GenotypeData"): Backing genotype data object.
111
+ tree_parser (Optional["TreeParser"]): Optional SNPio phylogenetic tree parser for population-specific modes.
112
+ config (Union["AutoencoderConfig", dict, str] | None): Structured configuration as dataclass, nested dict, YAML path, or None.
113
+ overrides (dict | None): Optional dot-key overrides with highest precedence (e.g., {'model.latent_dim': 32}).
114
+ simulate_missing (bool | None): Whether to simulate missing data during evaluation. If None, uses config default.
115
+ sim_strategy (Literal["random", "random_weighted", "random_weighted_inv", "nonrandom", "nonrandom_weighted"] | None): Strategy for simulating missing data. If None, uses config default.
116
+ sim_prop (float | None): Proportion of data to simulate as missing. If None, uses config default.
117
+ sim_kwargs (dict | None): Additional keyword arguments for simulating missing data. If None, uses config default.
118
+ """
119
+ self.model_name = "ImputeAutoencoder"
120
+ self.genotype_data = genotype_data
121
+ self.tree_parser = tree_parser
122
+
123
+ # Normalize config then apply highest-precedence overrides
124
+ cfg = ensure_autoencoder_config(config)
125
+ if overrides:
126
+ cfg = apply_dot_overrides(cfg, overrides)
127
+ self.cfg = cfg
128
+
129
+ # Logger consistent with NLPCA
130
+ logman = LoggerManager(
131
+ __name__,
132
+ prefix=self.cfg.io.prefix,
133
+ debug=self.cfg.io.debug,
134
+ verbose=self.cfg.io.verbose,
135
+ )
136
+ self.logger = configure_logger(
137
+ logman.get_logger(),
138
+ verbose=self.cfg.io.verbose,
139
+ debug=self.cfg.io.debug,
140
+ )
141
+
142
+ # BaseNNImputer bootstrapping (device/dirs/logging handled here)
143
+ super().__init__(
144
+ model_name=self.model_name,
145
+ genotype_data=self.genotype_data,
146
+ prefix=self.cfg.io.prefix,
147
+ device=self.cfg.train.device,
148
+ verbose=self.cfg.io.verbose,
149
+ debug=self.cfg.io.debug,
150
+ )
151
+
152
+ self.Model = AutoencoderModel
153
+
154
+ # Model hook & encoder
155
+ self.pgenc = GenotypeEncoder(genotype_data)
156
+
157
+ # IO / global
158
+ self.seed = self.cfg.io.seed
159
+ self.n_jobs = self.cfg.io.n_jobs
160
+ self.prefix = self.cfg.io.prefix
161
+ self.scoring_averaging = self.cfg.io.scoring_averaging
162
+ self.verbose = self.cfg.io.verbose
163
+ self.debug = self.cfg.io.debug
164
+ self.rng = np.random.default_rng(self.seed)
165
+
166
+ # Simulated-missing controls (config defaults with ctor overrides)
167
+ sim_cfg = getattr(self.cfg, "sim", None)
168
+ sim_cfg_kwargs = copy.deepcopy(getattr(sim_cfg, "sim_kwargs", None) or {})
169
+ if sim_kwargs:
170
+ sim_cfg_kwargs.update(sim_kwargs)
171
+ self.simulate_missing = (
172
+ (
173
+ sim_cfg.simulate_missing
174
+ if simulate_missing is None
175
+ else bool(simulate_missing)
176
+ )
177
+ if sim_cfg is not None
178
+ else bool(simulate_missing)
179
+ )
180
+ if sim_cfg is None:
181
+ default_strategy = "random"
182
+ default_prop = 0.10
183
+ else:
184
+ default_strategy = sim_cfg.sim_strategy
185
+ default_prop = sim_cfg.sim_prop
186
+ self.sim_strategy = sim_strategy or default_strategy
187
+ self.sim_prop = float(sim_prop if sim_prop is not None else default_prop)
188
+ self.sim_kwargs = sim_cfg_kwargs
189
+
190
+ if self.tree_parser is None and self.sim_strategy.startswith("nonrandom"):
191
+ msg = "tree_parser is required for nonrandom and nonrandom_weighted simulated missing strategies."
192
+ self.logger.error(msg)
193
+ raise ValueError(msg)
194
+
195
+ # Model hyperparams
196
+ self.latent_dim = int(self.cfg.model.latent_dim)
197
+ self.dropout_rate = float(self.cfg.model.dropout_rate)
198
+ self.num_hidden_layers = int(self.cfg.model.num_hidden_layers)
199
+ self.layer_scaling_factor = float(self.cfg.model.layer_scaling_factor)
200
+ self.layer_schedule: str = str(self.cfg.model.layer_schedule)
201
+ self.activation = str(self.cfg.model.hidden_activation)
202
+ self.gamma = float(self.cfg.model.gamma)
203
+
204
+ # Train hyperparams
205
+ self.batch_size = int(self.cfg.train.batch_size)
206
+ self.learning_rate = float(self.cfg.train.learning_rate)
207
+ self.l1_penalty: float = float(self.cfg.train.l1_penalty)
208
+ self.early_stop_gen = int(self.cfg.train.early_stop_gen)
209
+ self.min_epochs = int(self.cfg.train.min_epochs)
210
+ self.epochs = int(self.cfg.train.max_epochs)
211
+ self.validation_split = float(self.cfg.train.validation_split)
212
+ self.beta = float(self.cfg.train.weights_beta)
213
+ self.max_ratio = float(self.cfg.train.weights_max_ratio)
214
+
215
+ # Tuning
216
+ self.tune = bool(self.cfg.tune.enabled)
217
+ self.tune_fast = bool(self.cfg.tune.fast)
218
+ self.tune_batch_size = int(self.cfg.tune.batch_size)
219
+ self.tune_epochs = int(self.cfg.tune.epochs)
220
+ self.tune_eval_interval = int(self.cfg.tune.eval_interval)
221
+ self.tune_metric: str = self.cfg.tune.metric
222
+
223
+ if self.tune_metric is not None:
224
+ self.tune_metric_: (
225
+ Literal[
226
+ "pr_macro",
227
+ "f1",
228
+ "accuracy",
229
+ "precision",
230
+ "recall",
231
+ "roc_auc",
232
+ "average_precision",
233
+ ]
234
+ | None
235
+ ) = self.cfg.tune.metric
236
+
237
+ self.n_trials = int(self.cfg.tune.n_trials)
238
+ self.tune_save_db = bool(self.cfg.tune.save_db)
239
+ self.tune_resume = bool(self.cfg.tune.resume)
240
+ self.tune_max_samples = int(self.cfg.tune.max_samples)
241
+ self.tune_max_loci = int(self.cfg.tune.max_loci)
242
+ self.tune_infer_epochs = int(
243
+ getattr(self.cfg.tune, "infer_epochs", 0)
244
+ ) # AE unused
245
+ self.tune_patience = int(self.cfg.tune.patience)
246
+
247
+ # Evaluate
248
+ # AE does not optimize latents, so these are unused / fixed
249
+ self.eval_latent_steps: int = 0
250
+ self.eval_latent_lr: float = 0.0
251
+ self.eval_latent_weight_decay: float = 0.0
252
+
253
+ # Plotting (parity with NLPCA PlotConfig)
254
+ self.plot_format: Literal["pdf", "png", "jpg", "jpeg", "svg"] = (
255
+ self.cfg.plot.fmt
256
+ )
257
+ self.plot_dpi = int(self.cfg.plot.dpi)
258
+ self.plot_fontsize = int(self.cfg.plot.fontsize)
259
+ self.title_fontsize = int(self.cfg.plot.fontsize)
260
+ self.despine = bool(self.cfg.plot.despine)
261
+ self.show_plots = bool(self.cfg.plot.show)
262
+
263
+ # Core derived at fit-time
264
+ self.is_haploid: bool = False
265
+ self.num_classes_: int | None = None
266
+ self.model_params: Dict[str, Any] = {}
267
+ self.sim_mask_global_: np.ndarray | None = None
268
+ self.sim_mask_train_: np.ndarray | None = None
269
+ self.sim_mask_test_: np.ndarray | None = None
270
+
271
+ def fit(self) -> "ImputeAutoencoder":
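A hedged usage sketch of the constructor above, showing the config-then-overrides precedence described in the docstring. The config keys (`model.latent_dim`, `model.dropout_rate`, `train.max_epochs`) all appear in this module; the `genotype_data` argument is assumed to be a snpio `GenotypeData` object obtained elsewhere.

```python
from pgsui.impute.unsupervised.imputers.autoencoder import ImputeAutoencoder

def build_imputer(genotype_data, tree_parser=None):
    """genotype_data: a snpio GenotypeData object created by the caller."""
    return ImputeAutoencoder(
        genotype_data,
        tree_parser=tree_parser,
        config={"model": {"latent_dim": 16}, "train": {"max_epochs": 200}},
        overrides={"model.dropout_rate": 0.3},  # highest precedence, per the docstring
        simulate_missing=True,
        sim_strategy="random",
        sim_prop=0.1,
    )

# imputed_iupac = build_imputer(genotype_data).fit().transform()
```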
272
+ """Fit the autoencoder on 0/1/2 encoded genotypes (missing -> -1).
273
+
274
+ This method trains the autoencoder model using the provided genotype data. It prepares the data by encoding genotypes as 0, 1, and 2, with missing values represented internally as -1. (When simulated-missing loci are generated via ``SimMissingTransformer`` they are first marked with -9 but are immediately re-encoded as -1 prior to training.) The method splits the data into training and validation sets, initializes the model and training parameters, and performs training with optional hyperparameter tuning. After training, it evaluates the model on the validation set and stores the fitted model and training history.
275
+
276
+ Returns:
277
+ ImputeAutoencoder: Fitted instance.
278
+
279
+ Raises:
280
+ NotFittedError: If training fails.
281
+ """
282
+ self.logger.info(f"Fitting {self.model_name} model...")
283
+
284
+ # --- Data prep (mirror NLPCA) ---
285
+ X012 = self._get_float_genotypes(copy=True)
286
+ GT_full = np.nan_to_num(X012, nan=-1.0, copy=True)
287
+ self.ground_truth_ = GT_full.astype(np.int64, copy=False)
288
+
289
+ self.sim_mask_global_ = None
290
+ cache_key = self._sim_mask_cache_key()
291
+ if self.simulate_missing:
292
+ cached_mask = (
293
+ None if cache_key is None else self._sim_mask_cache.get(cache_key)
294
+ )
295
+ if cached_mask is not None:
296
+ self.sim_mask_global_ = cached_mask.copy()
297
+ else:
298
+ tr = SimMissingTransformer(
299
+ genotype_data=self.genotype_data,
300
+ tree_parser=self.tree_parser,
301
+ prop_missing=self.sim_prop,
302
+ strategy=self.sim_strategy,
303
+ missing_val=-9,
304
+ mask_missing=True,
305
+ verbose=self.verbose,
306
+ **self.sim_kwargs,
307
+ )
308
+ tr.fit(X012.copy())
309
+ self.sim_mask_global_ = tr.sim_missing_mask_.astype(bool)
310
+ if cache_key is not None:
311
+ self._sim_mask_cache[cache_key] = self.sim_mask_global_.copy()
312
+
313
+ X_for_model = self.ground_truth_.copy()
314
+ X_for_model[self.sim_mask_global_] = -1
315
+ else:
316
+ X_for_model = self.ground_truth_.copy()
317
+
318
+ if self.genotype_data.snp_data is None:
319
+ msg = "SNP data is required for Autoencoder imputer."
320
+ self.logger.error(msg)
321
+ raise TypeError(msg)
322
+
323
+ # Ploidy & classes
324
+ self.is_haploid = bool(
325
+ np.all(
326
+ np.isin(
327
+ self.genotype_data.snp_data,
328
+ ["A", "C", "G", "T", "N", "-", ".", "?"],
329
+ )
330
+ )
331
+ )
332
+ self.ploidy = 1 if self.is_haploid else 2
333
+ self.num_classes_ = 2 if self.is_haploid else 3
334
+ self.logger.info(
335
+ f"Data is {'haploid' if self.is_haploid else 'diploid'}; "
336
+ f"using {self.num_classes_} classes."
337
+ )
338
+
339
+ if self.is_haploid:
340
+ self.ground_truth_[self.ground_truth_ == 2] = 1
341
+ X_for_model[X_for_model == 2] = 1
342
+
343
+ n_samples, self.num_features_ = X_for_model.shape
344
+
345
+ # Model params (decoder outputs L * K logits)
346
+ self.model_params = {
347
+ "n_features": self.num_features_,
348
+ "num_classes": self.num_classes_,
349
+ "latent_dim": self.latent_dim,
350
+ "dropout_rate": self.dropout_rate,
351
+ "activation": self.activation,
352
+ }
353
+
354
+ # Train/Val split
355
+ indices = np.arange(n_samples)
356
+ train_idx, val_idx = train_test_split(
357
+ indices, test_size=self.validation_split, random_state=self.seed
358
+ )
359
+ self.train_idx_, self.test_idx_ = train_idx, val_idx
360
+ self.X_train_ = X_for_model[train_idx]
361
+ self.X_val_ = X_for_model[val_idx]
362
+ self.GT_train_full_ = self.ground_truth_[train_idx]
363
+ self.GT_test_full_ = self.ground_truth_[val_idx]
364
+
365
+ if self.sim_mask_global_ is not None:
366
+ self.sim_mask_train_ = self.sim_mask_global_[train_idx]
367
+ self.sim_mask_test_ = self.sim_mask_global_[val_idx]
368
+ else:
369
+ self.sim_mask_train_ = None
370
+ self.sim_mask_test_ = None
371
+
372
+ # Plotters/scorers (shared utilities)
373
+ self.plotter_, self.scorers_ = self.initialize_plotting_and_scorers()
374
+
375
+ # Tuning (optional; AE never needs latent refinement)
376
+ if self.tune:
377
+ self.tune_hyperparameters()
378
+
379
+ # Best params (tuned or default)
380
+ self.best_params_ = getattr(self, "best_params_", self._default_best_params())
381
+
382
+ # Class weights (device-aware)
383
+ self.class_weights_ = self._normalize_class_weights(
384
+ self._class_weights_from_zygosity(self.X_train_)
385
+ )
386
+
387
+ # DataLoader
388
+ train_loader = self._get_data_loaders(self.X_train_)
389
+
390
+ # Build & train
391
+ model = self.build_model(self.Model, self.best_params_)
392
+ model.apply(self.initialize_weights)
393
+
394
+ loss, trained_model, history = self._train_and_validate_model(
395
+ model=model,
396
+ loader=train_loader,
397
+ lr=self.learning_rate,
398
+ l1_penalty=self.l1_penalty,
399
+ return_history=True,
400
+ class_weights=self.class_weights_,
401
+ X_val=self.X_val_,
402
+ params=self.best_params_,
403
+ prune_metric=self.tune_metric,
404
+ prune_warmup_epochs=5,
405
+ eval_interval=1,
406
+ eval_requires_latents=False,
407
+ eval_latent_steps=0,
408
+ eval_latent_lr=0.0,
409
+ eval_latent_weight_decay=0.0,
410
+ )
411
+
412
+ if trained_model is None:
413
+ msg = "Autoencoder training failed; no model was returned."
414
+ self.logger.error(msg)
415
+ raise RuntimeError(msg)
416
+
417
+ torch.save(
418
+ trained_model.state_dict(),
419
+ self.models_dir / f"final_model_{self.model_name}.pt",
420
+ )
421
+
422
+ hist: Dict[str, List[float] | Dict[str, List[float]] | None] | None = {
423
+ "Train": history
424
+ }
425
+ self.best_loss_, self.model_, self.history_ = (loss, trained_model, hist)
426
+ self.is_fit_ = True
427
+
428
+ # Evaluate on validation set (parity with NLPCA reporting)
429
+ eval_mask = (
430
+ self.sim_mask_test_
431
+ if (self.simulate_missing and self.sim_mask_test_ is not None)
432
+ else None
433
+ )
434
+ self._evaluate_model(
435
+ self.X_val_, self.model_, self.best_params_, eval_mask_override=eval_mask
436
+ )
437
+ self.plotter_.plot_history(self.history_)
438
+ self._save_best_params(self.best_params_)
439
+
440
+ return self
441
+
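The simulated-missing evaluation in `fit()` can be pictured with the small NumPy sketch below: a fraction of the observed cells is hidden from the model and scored later. This is a simplified stand-in for the `random` strategy of `SimMissingTransformer`, not the transformer itself.

```python
import numpy as np

rng = np.random.default_rng(42)
X = rng.integers(0, 3, size=(8, 10)).astype(np.int64)   # toy 0/1/2 genotypes
X[rng.random(X.shape) < 0.05] = -1                       # pre-existing missing

observed = X != -1
sim_mask = observed & (rng.random(X.shape) < 0.10)       # hide ~10% of observed cells

X_for_model = X.copy()
X_for_model[sim_mask] = -1       # the model never sees these values

# After imputation, accuracy is computed only on the simulated-missing cells.
X_imputed = X_for_model.copy()
X_imputed[X_imputed == -1] = 0   # placeholder "prediction" for this sketch
acc = (X_imputed[sim_mask] == X[sim_mask]).mean()
print(f"simulated-missing accuracy: {acc:.3f}")
```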
442
+ def transform(self) -> np.ndarray:
443
+ """Impute missing genotypes (0/1/2) and return IUPAC strings.
444
+
445
+ This method imputes missing genotypes in the dataset using the trained autoencoder model. It predicts the most likely genotype (0, 1, or 2) for each missing entry and fills in these values. The imputed genotypes are then decoded back to IUPAC string format for easier interpretation.
446
+
447
+ Returns:
448
+ np.ndarray: IUPAC strings of shape (n_samples, n_loci).
449
+
450
+ Raises:
451
+ NotFittedError: If called before fit().
452
+ """
453
+ if not getattr(self, "is_fit_", False):
454
+ raise NotFittedError("Model is not fitted. Call fit() before transform().")
455
+
456
+ self.logger.info(f"Imputing entire dataset with {self.model_name}...")
457
+ X_to_impute = self.ground_truth_.copy()
458
+
459
+ # Predict with masked inputs (no latent optimization)
460
+ pred_labels, _ = self._predict(self.model_, X=X_to_impute, return_proba=True)
461
+
462
+ # Fill only missing
463
+ missing_mask = X_to_impute == -1
464
+ imputed_array = X_to_impute.copy()
465
+ imputed_array[missing_mask] = pred_labels[missing_mask]
466
+
467
+ # Decode to IUPAC & optionally plot
468
+ imputed_genotypes = self.pgenc.decode_012(imputed_array)
469
+ if self.show_plots:
470
+ original_genotypes = self.pgenc.decode_012(X_to_impute)
471
+ plt.rcParams.update(self.plotter_.param_dict)
472
+ self.plotter_.plot_gt_distribution(original_genotypes, is_imputed=False)
473
+ self.plotter_.plot_gt_distribution(imputed_genotypes, is_imputed=True)
474
+
475
+ return imputed_genotypes
476
+
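A minimal NumPy sketch of the fill-only-missing rule in `transform()`: predictions overwrite only the `-1` cells, while observed genotypes pass through unchanged.

```python
import numpy as np

X = np.array([[0, -1, 2],
              [-1, 1, -1]])
pred = np.array([[0, 1, 2],
                 [2, 1, 0]])
missing = X == -1
imputed = X.copy()
imputed[missing] = pred[missing]
print(imputed)   # [[0 1 2], [2 1 0]]; observed 0, 2, and 1 are preserved
```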
477
+ def _get_data_loaders(self, y: np.ndarray) -> torch.utils.data.DataLoader:
478
+ """Create DataLoader over indices + integer targets (-1 for missing).
479
+
480
+ This method creates a PyTorch DataLoader that yields batches of indices and their corresponding genotype targets encoded as integers (0, 1, 2) with -1 indicating missing values. The DataLoader is shuffled to ensure random sampling during training.
481
+
482
+ Args:
483
+ y (np.ndarray): 0/1/2 matrix with -1 for missing.
484
+
485
+ Returns:
486
+ torch.utils.data.DataLoader: Shuffled DataLoader.
487
+ """
488
+ y_tensor = torch.from_numpy(y).long()
489
+ indices = torch.arange(len(y), dtype=torch.long)
490
+ dataset = torch.utils.data.TensorDataset(indices, y_tensor)
491
+ pin_memory = self.device.type == "cuda"
492
+ return torch.utils.data.DataLoader(
493
+ dataset,
494
+ batch_size=self.batch_size,
495
+ shuffle=True,
496
+ pin_memory=pin_memory,
497
+ )
498
+
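The loader above yields integer targets with `-1` for missing; the model input is built from them by `_one_hot_encode_012`, which is defined in the base class. The sketch below is an illustrative equivalent of that encoding, assuming missing cells map to an all-zeros row as the training-loop comment states.

```python
import torch

def one_hot_012(y: torch.Tensor, num_classes: int = 3) -> torch.Tensor:
    """Illustrative equivalent: 0/1/2 -> one-hot rows, -1 (missing) -> all zeros."""
    observed = y >= 0
    safe = y.clamp(min=0)                                    # avoid indexing with -1
    ohe = torch.nn.functional.one_hot(safe, num_classes).float()
    return ohe * observed.unsqueeze(-1).float()              # zero out missing cells

y = torch.tensor([[0, 2, -1], [1, -1, 2]])
print(one_hot_012(y))
```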
499
+ def _train_and_validate_model(
500
+ self,
501
+ model: torch.nn.Module,
502
+ loader: torch.utils.data.DataLoader,
503
+ lr: float,
504
+ l1_penalty: float,
505
+ trial: optuna.Trial | None = None,
506
+ return_history: bool = False,
507
+ class_weights: torch.Tensor | None = None,
508
+ *,
509
+ X_val: np.ndarray | None = None,
510
+ params: dict | None = None,
511
+ prune_metric: str = "f1", # "f1" | "accuracy" | "pr_macro"
512
+ prune_warmup_epochs: int = 3,
513
+ eval_interval: int = 1,
514
+ # Evaluation parameters (AE ignores latent refinement knobs)
515
+ eval_requires_latents: bool = False, # AE: always False
516
+ eval_latent_steps: int = 0,
517
+ eval_latent_lr: float = 0.0,
518
+ eval_latent_weight_decay: float = 0.0,
519
+ ) -> Tuple[float, torch.nn.Module | None, list | None]:
520
+ """Wrap the AE training loop (no latent optimizer), with Optuna pruning.
521
+
522
+ This method orchestrates the training of the autoencoder model using the provided DataLoader. It sets up the optimizer and learning rate scheduler, and executes the training loop with support for early stopping and Optuna pruning based on validation performance. The method returns the best validation loss, the best model state, and optionally the training history.
523
+
524
+ Args:
525
+ model (torch.nn.Module): Autoencoder model.
526
+ loader (torch.utils.data.DataLoader): Batches (indices, y_int) where y_int is 0/1/2; -1 for missing.
527
+ lr (float): Learning rate.
528
+ l1_penalty (float): L1 regularization coeff.
529
+ trial (optuna.Trial | None): Optuna trial for pruning (optional).
530
+ return_history (bool): If True, return train loss history.
531
+ class_weights (torch.Tensor | None): Class weights tensor (on device).
532
+ X_val (np.ndarray | None): Validation matrix (0/1/2 with -1 for missing).
533
+ params (dict | None): Model params for evaluation.
534
+ prune_metric (str): Metric for pruning reports.
535
+ prune_warmup_epochs (int): Pruning warmup epochs.
536
+ eval_interval (int): Eval frequency (epochs).
537
+ eval_requires_latents (bool): Ignored for AE (no latent inference).
538
+ eval_latent_steps (int): Unused for AE.
539
+ eval_latent_lr (float): Unused for AE.
540
+ eval_latent_weight_decay (float): Unused for AE.
541
+
542
+ Returns:
543
+ Tuple[float, torch.nn.Module | None, list | None]: (best_loss, best_model, history or None).
544
+ """
545
+ if class_weights is None:
546
+ msg = "Must provide class_weights."
547
+ self.logger.error(msg)
548
+ raise TypeError(msg)
549
+
550
+ # Epoch budget mirrors NLPCA config (tuning vs final)
551
+ max_epochs = (
552
+ self.tune_epochs if (trial is not None and self.tune_fast) else self.epochs
553
+ )
554
+
555
+ optimizer = torch.optim.Adam(model.parameters(), lr=lr)
556
+ scheduler = CosineAnnealingLR(optimizer, T_max=max_epochs)
557
+
558
+ best_loss, best_model, hist = self._execute_training_loop(
559
+ loader=loader,
560
+ optimizer=optimizer,
561
+ scheduler=scheduler,
562
+ model=model,
563
+ l1_penalty=l1_penalty,
564
+ trial=trial,
565
+ return_history=return_history,
566
+ class_weights=class_weights,
567
+ X_val=X_val,
568
+ params=params,
569
+ prune_metric=prune_metric,
570
+ prune_warmup_epochs=prune_warmup_epochs,
571
+ eval_interval=eval_interval,
572
+ eval_requires_latents=False, # AE: no latent inference
573
+ eval_latent_steps=0,
574
+ eval_latent_lr=0.0,
575
+ eval_latent_weight_decay=0.0,
576
+ )
577
+ if return_history:
578
+ return best_loss, best_model, hist
579
+
580
+ return best_loss, best_model, None
581
+
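A minimal, standalone pairing of Adam with `CosineAnnealingLR` as configured in the wrapper above (`T_max` set to the epoch budget, one `scheduler.step()` per epoch); the tiny linear model is just a placeholder.

```python
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR

model = torch.nn.Linear(10, 3)   # placeholder model
max_epochs = 20
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = CosineAnnealingLR(optimizer, T_max=max_epochs)

for epoch in range(max_epochs):
    # ... one training epoch would run here ...
    scheduler.step()
    if epoch in (0, 9, 19):
        print(epoch, optimizer.param_groups[0]["lr"])   # LR decays along a cosine
```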
582
+ def _execute_training_loop(
583
+ self,
584
+ loader: torch.utils.data.DataLoader,
585
+ optimizer: torch.optim.Optimizer,
586
+ scheduler: CosineAnnealingLR,
587
+ model: torch.nn.Module,
588
+ l1_penalty: float,
589
+ trial: optuna.Trial | None,
590
+ return_history: bool,
591
+ class_weights: torch.Tensor,
592
+ *,
593
+ X_val: np.ndarray | None = None,
594
+ params: dict | None = None,
595
+ prune_metric: str = "f1",
596
+ prune_warmup_epochs: int = 3,
597
+ eval_interval: int = 1,
598
+ # Evaluation parameters (AE ignores latent refinement knobs)
599
+ eval_requires_latents: bool = False, # AE: False
600
+ eval_latent_steps: int = 0,
601
+ eval_latent_lr: float = 0.0,
602
+ eval_latent_weight_decay: float = 0.0,
603
+ ) -> Tuple[float, torch.nn.Module, list]:
604
+ """Train AE with focal CE (gamma warm/ramp) + early stopping & pruning.
605
+
606
+ This method executes the training loop for the autoencoder model, performing one epoch at a time. It computes the focal cross-entropy loss while ignoring masked (missing) values and applies L1 regularization if specified. The method incorporates early stopping based on validation performance and supports Optuna pruning to terminate unpromising trials early. It returns the best validation loss, the best model state, and optionally the training history.
607
+
608
+ Args:
609
+ loader (torch.utils.data.DataLoader): Batches (indices, y_int) where y_int is 0/1/2; -1 for missing.
610
+ optimizer (torch.optim.Optimizer): Optimizer.
611
+ scheduler (torch.optim.lr_scheduler._LRScheduler): LR scheduler.
612
+ model (torch.nn.Module): Autoencoder model.
613
+ l1_penalty (float): L1 regularization coeff.
614
+ trial (optuna.Trial | None): Optuna trial for pruning (optional).
615
+ return_history (bool): If True, return train loss history.
616
+ class_weights (torch.Tensor): Class weights tensor (on device).
617
+ X_val (np.ndarray | None): Validation matrix (0/1/2 with -1 for missing).
618
+ params (dict | None): Model params for evaluation.
619
+ prune_metric (str): Metric for pruning reports.
620
+ prune_warmup_epochs (int): Pruning warmup epochs.
621
+ eval_interval (int): Eval frequency (epochs).
622
+ eval_requires_latents (bool): Ignored for AE (no latent inference).
623
+ eval_latent_steps (int): Unused for AE.
624
+ eval_latent_lr (float): Unused for AE.
625
+ eval_latent_weight_decay (float): Unused for AE.
626
+
627
+ Returns:
628
+ Tuple[float, torch.nn.Module, list]: Best validation loss, best model, and training history.
629
+ """
630
+ best_loss = float("inf")
631
+ best_model = None
632
+ history: list[float] = []
633
+
634
+ early_stopping = EarlyStopping(
635
+ patience=self.early_stop_gen,
636
+ min_epochs=self.min_epochs,
637
+ verbose=self.verbose,
638
+ prefix=self.prefix,
639
+ debug=self.debug,
640
+ )
641
+
642
+ gamma_val = self.gamma
643
+ if isinstance(gamma_val, (list, tuple)):
644
+ if len(gamma_val) == 0:
645
+ raise ValueError("gamma list is empty.")
646
+ gamma_val = gamma_val[0]
647
+
648
+ gamma_final = float(gamma_val)
649
+ gamma_warm, gamma_ramp = 50, 100
650
+
651
+ # Optional LR warmup
652
+ warmup_epochs = int(getattr(self, "lr_warmup_epochs", 5))
653
+ base_lr = float(optimizer.param_groups[0]["lr"])
654
+ min_lr = base_lr * 0.1
655
+
656
+ max_epochs = int(getattr(scheduler, "T_max", getattr(self, "epochs", 100)))
657
+
658
+ for epoch in range(max_epochs):
659
+ # focal γ schedule (for stable training)
660
+ if epoch < gamma_warm:
661
+ model.gamma = 0.0 # type: ignore
662
+ elif epoch < gamma_warm + gamma_ramp:
663
+ model.gamma = gamma_final * ((epoch - gamma_warm) / gamma_ramp) # type: ignore
664
+ else:
665
+ model.gamma = gamma_final # type: ignore
666
+
667
+ # LR warmup
668
+ if epoch < warmup_epochs:
669
+ scale = float(epoch + 1) / warmup_epochs
670
+ for g in optimizer.param_groups:
671
+ g["lr"] = min_lr + (base_lr - min_lr) * scale
672
+
673
+ train_loss = self._train_step(
674
+ loader=loader,
675
+ optimizer=optimizer,
676
+ model=model,
677
+ l1_penalty=l1_penalty,
678
+ class_weights=class_weights,
679
+ )
680
+
681
+ # Abort or prune on non-finite epoch loss
682
+ if not np.isfinite(train_loss):
683
+ if trial is not None:
684
+ raise optuna.exceptions.TrialPruned("Epoch loss non-finite.")
685
+ # Soft reset suggestion: reduce LR and continue, or break
686
+ self.logger.warning(
687
+ "Non-finite epoch loss. Reducing LR by 10 percent and continuing."
688
+ )
689
+ for g in optimizer.param_groups:
690
+ g["lr"] *= 0.9
691
+ continue
692
+
693
+ scheduler.step()
694
+ if return_history:
695
+ history.append(train_loss)
696
+
697
+ early_stopping(train_loss, model)
698
+ if early_stopping.early_stop:
699
+ self.logger.info(f"Early stopping at epoch {epoch + 1}.")
700
+ break
701
+
702
+ # Optuna report/prune on validation metric
703
+ if (
704
+ trial is not None
705
+ and X_val is not None
706
+ and ((epoch + 1) % eval_interval == 0)
707
+ ):
708
+ metric_key = prune_metric or getattr(self, "tune_metric", "f1")
709
+ mask_override = None
710
+ if (
711
+ self.simulate_missing
712
+ and getattr(self, "sim_mask_test_", None) is not None
713
+ and getattr(self, "X_val_", None) is not None
714
+ and X_val.shape == self.X_val_.shape
715
+ ):
716
+ mask_override = self.sim_mask_test_
717
+ metric_val = self._eval_for_pruning(
718
+ model=model,
719
+ X_val=X_val,
720
+ params=params or getattr(self, "best_params_", {}),
721
+ metric=metric_key,
722
+ objective_mode=True,
723
+ do_latent_infer=False, # AE: False
724
+ latent_steps=0,
725
+ latent_lr=0.0,
726
+ latent_weight_decay=0.0,
727
+ latent_seed=self.seed, # type: ignore
728
+ _latent_cache=None, # AE: not used
729
+ _latent_cache_key=None,
730
+ eval_mask_override=mask_override,
731
+ )
732
+ trial.report(metric_val, step=epoch + 1)
733
+ if (epoch + 1) >= prune_warmup_epochs and trial.should_prune():
734
+ raise optuna.exceptions.TrialPruned(
735
+ f"Pruned at epoch {epoch + 1}: {metric_key}={metric_val:.5f}"
736
+ )
737
+
738
+ best_loss = early_stopping.best_score
739
+ if early_stopping.best_model is not None:
740
+ best_model = copy.deepcopy(early_stopping.best_model)
741
+ else:
742
+ best_model = copy.deepcopy(model)
743
+ return best_loss, best_model, history
744
+
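The two schedules hard-coded in the loop above (a 50-epoch focal-gamma warmup followed by a 100-epoch linear ramp, and a linear LR warmup from 10% of the base rate) can be written as small pure functions:

```python
def focal_gamma(epoch: int, gamma_final: float, warm: int = 50, ramp: int = 100) -> float:
    """Focal-CE gamma schedule: 0 during warmup, linear ramp, then constant."""
    if epoch < warm:
        return 0.0
    if epoch < warm + ramp:
        return gamma_final * (epoch - warm) / ramp
    return gamma_final

def warmup_lr(epoch: int, base_lr: float, warmup_epochs: int = 5) -> float:
    """Linear LR warmup from 10% of base_lr up to base_lr over warmup_epochs."""
    if epoch >= warmup_epochs:
        return base_lr
    min_lr = 0.1 * base_lr
    return min_lr + (base_lr - min_lr) * (epoch + 1) / warmup_epochs

print([round(focal_gamma(e, 2.0), 2) for e in (0, 49, 50, 100, 149, 150)])
print([round(warmup_lr(e, 1e-3), 5) for e in range(6)])
```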
745
+ def _train_step(
746
+ self,
747
+ loader: torch.utils.data.DataLoader,
748
+ optimizer: torch.optim.Optimizer,
749
+ model: torch.nn.Module,
750
+ l1_penalty: float,
751
+ class_weights: torch.Tensor,
752
+ ) -> float:
753
+ """One epoch with stable focal CE and NaN/Inf guards."""
754
+ model.train()
755
+ running = 0.0
756
+ num_batches = 0
757
+ l1_params = tuple(p for p in model.parameters() if p.requires_grad)
758
+ if class_weights is not None and class_weights.device != self.device:
759
+ class_weights = class_weights.to(self.device)
760
+
761
+ # Use model.gamma if present, else self.gamma
762
+ gamma = float(getattr(model, "gamma", getattr(self, "gamma", 0.0)))
763
+ gamma = float(torch.tensor(gamma).clamp(min=0.0, max=10.0)) # sane bound
764
+ criterion = SafeFocalCELoss(gamma=gamma, weight=class_weights, ignore_index=-1)
765
+
766
+ for _, y_batch in loader:
767
+ optimizer.zero_grad(set_to_none=True)
768
+ y_batch = y_batch.to(self.device, non_blocking=True)
769
+
770
+ # Inputs: one-hot with zeros for missing; Targets: long ints with -1 for missing
771
+ x_ohe = self._one_hot_encode_012(y_batch) # (B, L, K)
772
+ logits = model(x_ohe).view(-1, self.num_features_, self.num_classes_)
773
+ logits_flat = logits.view(-1, self.num_classes_)
774
+ targets_flat = y_batch.view(-1).long()
775
+
776
+ # Upfront guards on inputs
777
+ if not torch.isfinite(logits_flat).all():
778
+ # Skip this batch if model already produced non-finite
779
+ continue
780
+
781
+ loss = criterion(logits_flat, targets_flat)
782
+
783
+ if l1_penalty > 0:
784
+ l1 = torch.zeros((), device=self.device)
785
+ for p in l1_params:
786
+ l1 = l1 + p.abs().sum()
787
+ loss = loss + l1_penalty * l1
788
+
789
+ # Final guard
790
+ if not torch.isfinite(loss):
791
+ continue
792
+
793
+ loss.backward()
794
+
795
+ # Clip to prevent exploding grads
796
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
797
+
798
+ # If grads blew up to non-finite, skip update
799
+ if any(
800
+ (not torch.isfinite(p.grad).all())
801
+ for p in model.parameters()
802
+ if p.grad is not None
803
+ ):
804
+ optimizer.zero_grad(set_to_none=True)
805
+ continue
806
+
807
+ optimizer.step()
808
+
809
+ running += float(loss.detach().item())
810
+ num_batches += 1
811
+
812
+ if num_batches == 0:
813
+ return float("inf") # signal upstream that epoch had no usable batches
814
+ return running / num_batches
815
+
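For reference, a minimal focal cross-entropy with an `ignore_index`, in the spirit of the `SafeFocalCELoss` used above; the package's implementation may differ in details such as clamping and how class weights interact with the focal term.

```python
import torch
import torch.nn.functional as F

def focal_ce(logits: torch.Tensor, targets: torch.Tensor, gamma: float = 2.0,
             weight: torch.Tensor | None = None, ignore_index: int = -1) -> torch.Tensor:
    """Illustrative focal CE: down-weights easy examples, skips ignore_index targets."""
    valid = targets != ignore_index
    if not valid.any():
        return logits.new_zeros(())
    logits, targets = logits[valid], targets[valid]
    ce = F.cross_entropy(logits, targets, weight=weight, reduction="none")
    pt = torch.exp(-ce)                       # (weighted) true-class probability proxy
    return ((1.0 - pt) ** gamma * ce).mean()

logits = torch.randn(6, 3)
targets = torch.tensor([0, 2, 1, -1, 2, 0])   # -1 entries are ignored
print(focal_ce(logits, targets, gamma=2.0))
```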
816
+ def _predict(
817
+ self,
818
+ model: torch.nn.Module,
819
+ X: np.ndarray | torch.Tensor,
820
+ return_proba: bool = False,
821
+ ) -> Tuple[np.ndarray, np.ndarray] | np.ndarray:
822
+ """Predict 0/1/2 labels (and probabilities) from masked inputs.
823
+
824
+ This method generates predictions from the trained autoencoder model for the provided input data. It processes the input data, performs a forward pass through the model, and computes the predicted genotype labels (0, 1, or 2) along with their associated probabilities if requested.
825
+
826
+ Args:
827
+ model (torch.nn.Module): Trained model.
828
+ X (np.ndarray | torch.Tensor): 0/1/2 matrix with -1
829
+ for missing.
830
+ return_proba (bool): If True, return probabilities.
831
+
832
+ Returns:
833
+ Tuple[np.ndarray, np.ndarray] | np.ndarray: Predicted labels,
834
+ and probabilities if requested.
835
+ """
836
+ if model is None:
837
+ msg = "Model is not trained. Call fit() before predict()."
838
+ self.logger.error(msg)
839
+ raise NotFittedError(msg)
840
+
841
+ model.eval()
842
+ with torch.no_grad():
843
+ X_tensor = torch.from_numpy(X) if isinstance(X, np.ndarray) else X
844
+ X_tensor = X_tensor.to(self.device).long()
845
+ x_ohe = self._one_hot_encode_012(X_tensor)
846
+ logits = model(x_ohe).view(-1, self.num_features_, self.num_classes_)
847
+ probas = torch.softmax(logits, dim=-1)
848
+ labels = torch.argmax(probas, dim=-1)
849
+
850
+ if return_proba:
851
+ return labels.cpu().numpy(), probas.cpu().numpy()
852
+
853
+ return labels.cpu().numpy()
854
+
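The key shape manipulation in `_predict` is viewing the decoder's flat `L * K` logits as `(batch, loci, classes)` before softmax/argmax, e.g.:

```python
import torch

batch, loci, classes = 2, 4, 3
flat_logits = torch.randn(batch, loci * classes)          # decoder output
logits = flat_logits.view(-1, loci, classes)
probas = torch.softmax(logits, dim=-1)
labels = torch.argmax(probas, dim=-1)                     # (batch, loci) 0/1/2 calls
print(labels.shape, probas.sum(dim=-1))                   # probabilities sum to 1 per locus
```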
855
+ def _evaluate_model(
856
+ self,
857
+ X_val: np.ndarray,
858
+ model: torch.nn.Module,
859
+ params: dict,
860
+ objective_mode: bool = False,
861
+ latent_vectors_val: Optional[np.ndarray] = None,
862
+ *,
863
+ eval_mask_override: np.ndarray | None = None,
864
+ ) -> Dict[str, float]:
865
+ """Evaluate on 0/1/2; then IUPAC decoding and 10-base integer reports.
866
+
867
+ This method evaluates the trained autoencoder model on a validation set, computing various classification metrics based on the predicted and true genotypes. It handles both haploid and diploid data appropriately and generates detailed classification reports for both genotype and IUPAC/10-base integer encodings.
868
+
869
+ Args:
870
+ X_val (np.ndarray): Validation set 0/1/2 matrix with -1
871
+ for missing.
872
+ model (torch.nn.Module): Trained model.
873
+ params (dict): Model parameters.
874
+ objective_mode (bool): If True, suppress logging and reports.
875
+ latent_vectors_val (Optional[np.ndarray]): Unused for AE.
876
+ eval_mask_override (np.ndarray | None): Optional mask to override default evaluation mask.
877
+
878
+ Returns:
879
+ Dict[str, float]: Dictionary of evaluation metrics.
880
+ """
881
+ pred_labels, pred_probas = self._predict(
882
+ model=model, X=X_val, return_proba=True
883
+ )
884
+
885
+ finite_mask = np.all(np.isfinite(pred_probas), axis=-1) # (N, L)
886
+
887
+ # FIX 1: Check ROWS (shape[0]) only. X_val might be a feature subset.
888
+ if (
889
+ hasattr(self, "X_val_")
890
+ and getattr(self, "X_val_", None) is not None
891
+ and X_val.shape[0] == self.X_val_.shape[0]
892
+ ):
893
+ GT_ref = getattr(self, "GT_test_full_", self.ground_truth_)
894
+ elif (
895
+ hasattr(self, "X_train_")
896
+ and getattr(self, "X_train_", None) is not None
897
+ and X_val.shape[0] == self.X_train_.shape[0]
898
+ ):
899
+ GT_ref = getattr(self, "GT_train_full_", self.ground_truth_)
900
+ else:
901
+ GT_ref = self.ground_truth_
902
+
903
+ # FIX 2: Handle Feature Mismatch (e.g., tune_fast feature subsetting)
904
+ # If the GT source has more columns than X_val, slice it to match.
905
+ if GT_ref.shape[1] > X_val.shape[1]:
906
+ GT_ref = GT_ref[:, : X_val.shape[1]]
907
+
908
+ # Fallback if rows mismatch (unlikely after Fix 1, but safe to keep)
909
+ if GT_ref.shape != X_val.shape:
910
+ # If completely different, we can't use the ground truth object.
911
+ # Fall back to X_val (this implies only observed values are scored)
912
+ GT_ref = X_val
913
+
914
+ if eval_mask_override is not None:
915
+ # FIX 3: Allow override mask to be sliced if it's too wide
916
+ if eval_mask_override.shape[0] != X_val.shape[0]:
917
+ msg = (
918
+ f"eval_mask_override rows {eval_mask_override.shape[0]} "
919
+ f"does not match X_val rows {X_val.shape[0]}"
920
+ )
921
+ self.logger.error(msg)
922
+ raise ValueError(msg)
923
+
924
+ if eval_mask_override.shape[1] > X_val.shape[1]:
925
+ eval_mask = eval_mask_override[:, : X_val.shape[1]].astype(bool)
926
+ else:
927
+ eval_mask = eval_mask_override.astype(bool)
928
+ else:
929
+ eval_mask = X_val != -1
930
+
931
+ # Combine masks
932
+ eval_mask = eval_mask & finite_mask & (GT_ref != -1)
933
+
934
+ y_true_flat = GT_ref[eval_mask].astype(np.int64, copy=False)
935
+ y_pred_flat = pred_labels[eval_mask].astype(np.int64, copy=False)
936
+ y_proba_flat = pred_probas[eval_mask].astype(np.float64, copy=False)
937
+
938
+ if y_true_flat.size == 0:
939
+ self.tune_metric = "f1" if self.tune_metric is None else self.tune_metric
940
+ return {self.tune_metric: 0.0}
941
+
942
+ # ensure valid probability simplex after masking (no NaNs/Infs, sums=1)
943
+ y_proba_flat = np.clip(y_proba_flat, 0.0, 1.0)
944
+ row_sums = y_proba_flat.sum(axis=1, keepdims=True)
945
+ row_sums[row_sums == 0] = 1.0
946
+ y_proba_flat = y_proba_flat / row_sums
947
+
948
+ labels_for_scoring = [0, 1] if self.is_haploid else [0, 1, 2]
949
+ target_names = ["REF", "ALT"] if self.is_haploid else ["REF", "HET", "ALT"]
950
+
951
+ if self.is_haploid:
952
+ y_true_flat = y_true_flat.copy()
953
+ y_pred_flat = y_pred_flat.copy()
954
+ y_true_flat[y_true_flat == 2] = 1
955
+ y_pred_flat[y_pred_flat == 2] = 1
956
+ # collapse probs to 2-class
957
+ proba_2 = np.zeros((len(y_proba_flat), 2), dtype=y_proba_flat.dtype)
958
+ proba_2[:, 0] = y_proba_flat[:, 0]
959
+ proba_2[:, 1] = y_proba_flat[:, 2]
960
+ y_proba_flat = proba_2
961
+
962
+ y_true_ohe = np.eye(len(labels_for_scoring))[y_true_flat]
963
+
964
+ tune_metric_tmp: Literal[
965
+ "pr_macro",
966
+ "roc_auc",
967
+ "average_precision",
968
+ "accuracy",
969
+ "f1",
970
+ "precision",
971
+ "recall",
972
+ ]
973
+ if self.tune_metric_ is not None:
974
+ tune_metric_tmp = self.tune_metric_
975
+ else:
976
+ tune_metric_tmp = "f1" # Default if not tuning
977
+
978
+ metrics = self.scorers_.evaluate(
979
+ y_true_flat,
980
+ y_pred_flat,
981
+ y_true_ohe,
982
+ y_proba_flat,
983
+ objective_mode,
984
+ tune_metric_tmp,
985
+ )
986
+
987
+ if not objective_mode:
988
+ pm = PrettyMetrics(
989
+ metrics, precision=3, title=f"{self.model_name} Validation Metrics"
990
+ )
991
+ pm.render() # prints a command-line table
992
+
993
+ # Primary report (REF/HET/ALT or REF/ALT)
994
+ self._make_class_reports(
995
+ y_true=y_true_flat,
996
+ y_pred_proba=y_proba_flat,
997
+ y_pred=y_pred_flat,
998
+ metrics=metrics,
999
+ labels=target_names,
1000
+ )
1001
+
1002
+ # IUPAC decode & 10-base integer reports
1003
+ # Now safe because GT_ref has been sliced to match X_val dimensions
1004
+ y_true_dec = self.pgenc.decode_012(
1005
+ GT_ref.reshape(X_val.shape[0], X_val.shape[1])
1006
+ )
1007
+ X_pred = X_val.copy()
1008
+ X_pred[eval_mask] = y_pred_flat
1009
+
1010
+ # Use X_val.shape[1] (current features) not self.num_features_ (original features)
1011
+ y_pred_dec = self.pgenc.decode_012(
1012
+ X_pred.reshape(X_val.shape[0], X_val.shape[1])
1013
+ )
1014
+
1015
+ encodings_dict = {
1016
+ "A": 0,
1017
+ "C": 1,
1018
+ "G": 2,
1019
+ "T": 3,
1020
+ "W": 4,
1021
+ "R": 5,
1022
+ "M": 6,
1023
+ "K": 7,
1024
+ "Y": 8,
1025
+ "S": 9,
1026
+ "N": -1,
1027
+ }
1028
+ y_true_int = self.pgenc.convert_int_iupac(
1029
+ y_true_dec, encodings_dict=encodings_dict
1030
+ )
1031
+ y_pred_int = self.pgenc.convert_int_iupac(
1032
+ y_pred_dec, encodings_dict=encodings_dict
1033
+ )
1034
+
1035
+ valid_iupac_mask = y_true_int[eval_mask] >= 0
1036
+ if valid_iupac_mask.any():
1037
+ self._make_class_reports(
1038
+ y_true=y_true_int[eval_mask][valid_iupac_mask],
1039
+ y_pred=y_pred_int[eval_mask][valid_iupac_mask],
1040
+ metrics=metrics,
1041
+ y_pred_proba=None,
1042
+ labels=["A", "C", "G", "T", "W", "R", "M", "K", "Y", "S"],
1043
+ )
1044
+ else:
1045
+ self.logger.warning(
1046
+ "Skipped IUPAC confusion matrix: No valid ground truths."
1047
+ )
1048
+
1049
+ return metrics
1050
+
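The haploid branch above can be seen in isolation with a small NumPy sketch: probabilities are first renormalized to a valid simplex, then REF/HET/ALT is collapsed to REF/ALT by keeping columns 0 and 2 and mapping label 2 to 1 (the HET mass is dropped, as in the code).

```python
import numpy as np

proba3 = np.array([[0.69, 0.10, 0.21],
                   [0.20, 0.30, 0.50]])
proba3 = np.clip(proba3, 0.0, 1.0)
proba3 = proba3 / proba3.sum(axis=1, keepdims=True)   # valid simplex

labels = np.array([2, 0])
labels[labels == 2] = 1                               # ALT becomes class 1

proba2 = np.stack([proba3[:, 0], proba3[:, 2]], axis=1)   # keep REF and ALT columns
print(labels)
print(proba2)
```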
1051
+ def _objective(self, trial: optuna.Trial) -> float:
1052
+ """Optuna objective for AE; mirrors NLPCA study driver without latents.
1053
+
1054
+ This method defines the objective function for hyperparameter tuning using Optuna. It samples hyperparameters, prepares the training and validation data, builds and trains the autoencoder model, and evaluates its performance on the validation set. The method returns the value of the tuning metric to be maximized.
1055
+
1056
+ Args:
1057
+ trial (optuna.Trial): Optuna trial.
1058
+
1059
+ Returns:
1060
+ float: Value of the tuning metric (maximize).
1061
+ """
1062
+ try:
1063
+ # Sample hyperparameters (existing helper; unchanged signature)
1064
+ params = self._sample_hyperparameters(trial)
1065
+
1066
+ # Optionally sub-sample for fast tuning (same keys used by NLPCA if you adopt them)
1067
+ X_train = getattr(self, "X_train_", self.ground_truth_[self.train_idx_])
1068
+ X_val = getattr(self, "X_val_", self.ground_truth_[self.test_idx_])
1069
+
1070
+ class_weights = self._normalize_class_weights(
1071
+ self._class_weights_from_zygosity(X_train)
1072
+ )
1073
+ train_loader = self._get_data_loaders(X_train)
1074
+
1075
+ model = self.build_model(self.Model, params["model_params"])
1076
+ model.apply(self.initialize_weights)
1077
+
1078
+ lr: float = float(params["lr"])
1079
+ l1_penalty: float = float(params["l1_penalty"])
1080
+
1081
+ # Train + prune on metric
1082
+ _, model, __ = self._train_and_validate_model(
1083
+ model=model,
1084
+ loader=train_loader,
1085
+ lr=lr,
1086
+ l1_penalty=l1_penalty,
1087
+ trial=trial,
1088
+ return_history=False,
1089
+ class_weights=class_weights,
1090
+ X_val=X_val,
1091
+ params=params,
1092
+ prune_metric=self.tune_metric,
1093
+ prune_warmup_epochs=5,
1094
+ eval_interval=self.tune_eval_interval,
1095
+ eval_requires_latents=False,
1096
+ eval_latent_steps=0,
1097
+ eval_latent_lr=0.0,
1098
+ eval_latent_weight_decay=0.0,
1099
+ )
1100
+
1101
+ eval_mask = (
1102
+ self.sim_mask_test_
1103
+ if (
1104
+ self.simulate_missing
1105
+ and getattr(self, "sim_mask_test_", None) is not None
1106
+ )
1107
+ else None
1108
+ )
1109
+
1110
+ if model is not None:
1111
+ metrics = self._evaluate_model(
1112
+ X_val,
1113
+ model,
1114
+ params,
1115
+ objective_mode=True,
1116
+ eval_mask_override=eval_mask,
1117
+ )
1118
+ self._clear_resources(model, train_loader)
1119
+ else:
1120
+ raise TypeError("Model training failed; no model was returned.")
1121
+
1122
+ return metrics[self.tune_metric]
1123
+
1124
+ except Exception as e:
1125
+ # Keep sweeps moving if a trial fails
1126
+ raise optuna.exceptions.TrialPruned(f"Trial failed with error: {e}")
1127
+
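The objective relies on Optuna's report/prune protocol. The toy study below shows that protocol end to end with real Optuna calls; it is not the package's `tune_hyperparameters` driver, and the metric is a stand-in.

```python
import optuna

def toy_objective(trial: optuna.Trial) -> float:
    lr = trial.suggest_float("learning_rate", 1e-5, 1e-2, log=True)
    score = 0.0
    for epoch in range(30):
        score = 1.0 - abs(lr - 1e-3) / 1e-2 + epoch * 1e-3   # stand-in for a val metric
        trial.report(score, step=epoch + 1)                  # report intermediate value
        if (epoch + 1) >= 5 and trial.should_prune():        # warmup, then allow pruning
            raise optuna.exceptions.TrialPruned()
    return score

study = optuna.create_study(direction="maximize",
                            pruner=optuna.pruners.MedianPruner())
study.optimize(toy_objective, n_trials=10)
print(study.best_params)
```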
1128
+ def _sample_hyperparameters(self, trial: optuna.Trial) -> Dict[str, Any]:
1129
+ """Sample AE hyperparameters and compute hidden sizes for model params.
1130
+
1131
+ This method samples hyperparameters for the autoencoder model using Optuna's trial object. It computes the hidden layer sizes based on the sampled parameters and prepares the model parameters dictionary.
1132
+
1133
+ Args:
1134
+ trial (optuna.Trial): Optuna trial object.
1135
+
1136
+ Returns:
1137
+ Dict[str, int | float | str | bool]: Sampled hyperparameters and model_params.
1138
+ """
1139
+ params = {
1140
+ "latent_dim": trial.suggest_int("latent_dim", 2, 64),
1141
+ "lr": trial.suggest_float("learning_rate", 1e-5, 1e-2, log=True),
1142
+ "dropout_rate": trial.suggest_float("dropout_rate", 0.0, 0.6),
1143
+ "num_hidden_layers": trial.suggest_int("num_hidden_layers", 1, 8),
1144
+ "activation": trial.suggest_categorical(
1145
+ "activation", ["relu", "elu", "selu"]
1146
+ ),
1147
+ "l1_penalty": trial.suggest_float("l1_penalty", 1e-7, 1e-2, log=True),
1148
+ "layer_scaling_factor": trial.suggest_float(
1149
+ "layer_scaling_factor", 2.0, 10.0
1150
+ ),
1151
+ "layer_schedule": trial.suggest_categorical(
1152
+ "layer_schedule", ["pyramid", "constant", "linear"]
1153
+ ),
1154
+ }
1155
+
1156
+ nF: int = self.num_features_
1157
+ nC: int = int(self.num_classes_) if self.num_classes_ is not None else 3
1158
+ input_dim = nF * nC
1159
+ hidden_layer_sizes = self._compute_hidden_layer_sizes(
1160
+ n_inputs=input_dim,
1161
+ n_outputs=input_dim,
1162
+ n_samples=len(self.train_idx_),
1163
+ n_hidden=params["num_hidden_layers"],
1164
+ alpha=params["layer_scaling_factor"],
1165
+ schedule=params["layer_schedule"],
1166
+ )
1167
+
1168
+ # Keep the latent_dim as the first element,
1169
+ # then the interior hidden widths.
1170
+ # If there are no interior widths (very small nets),
1171
+ # this still leaves [latent_dim].
1172
+ hidden_only: list[int] = [hidden_layer_sizes[0]] + hidden_layer_sizes[1:-1]
1173
+
1174
+ params["model_params"] = {
1175
+ "n_features": int(self.num_features_),
1176
+ "num_classes": (
1177
+ int(self.num_classes_) if self.num_classes_ is not None else 3
1178
+ ),
1179
+ "latent_dim": int(params["latent_dim"]),
1180
+ "dropout_rate": float(params["dropout_rate"]),
1181
+ "hidden_layer_sizes": hidden_only,
1182
+ "activation": str(params["activation"]),
1183
+ }
1184
+ return params
1185
+
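The `layer_schedule` choices sampled above ("pyramid", "constant", "linear") describe how hidden widths shrink toward the latent layer. The toy sketch below only conveys the shapes; the real `_compute_hidden_layer_sizes` also uses `n_samples` and the scaling factor, so actual widths will differ.

```python
import numpy as np

def toy_schedule(kind: str, input_dim: int, latent_dim: int, n_hidden: int) -> list[int]:
    """Toy widths for illustration only; not the package's sizing formula."""
    if kind == "constant":
        widths = [max(input_dim // 2, latent_dim)] * n_hidden
    elif kind == "linear":
        widths = np.linspace(input_dim, latent_dim, n_hidden + 2)[1:-1]
    elif kind == "pyramid":
        widths = input_dim / (2.0 ** np.arange(1, n_hidden + 1))
    else:
        raise ValueError(f"Invalid layer_schedule: {kind}")
    return [int(max(round(w), latent_dim)) for w in np.atleast_1d(widths)]

for kind in ("pyramid", "constant", "linear"):
    print(kind, toy_schedule(kind, input_dim=600, latent_dim=8, n_hidden=3))
```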
1186
+ def _set_best_params(
1187
+ self, best_params: Dict[str, int | float | str | List[int]]
1188
+ ) -> Dict[str, int | float | str | List[int]]:
1189
+ """Adopt best params (ImputeNLPCA parity) and return model_params.
1190
+
1191
+ This method sets the best hyperparameters found during tuning and computes the hidden layer sizes for the autoencoder model. It prepares the final model parameters dictionary to be used for building the model.
1192
+
1193
+ Args:
1194
+ best_params (Dict[str, int | float | str | List[int]]): Best hyperparameters from tuning.
1195
+
1196
+ Returns:
1197
+ Dict[str, int | float | str | List[int]]: Model parameters for building the model.
1198
+ """
1199
+ bp = {}
1200
+ for k, v in best_params.items():
1201
+ if not isinstance(v, list):
1202
+ if k in {"latent_dim", "num_hidden_layers"}:
1203
+ bp[k] = int(v)
1204
+ elif k in {
1205
+ "dropout_rate",
1206
+ "learning_rate",
1207
+ "l1_penalty",
1208
+ "layer_scaling_factor",
1209
+ }:
1210
+ bp[k] = float(v)
1211
+ elif k in {"activation", "layer_schedule"}:
1212
+ if k == "layer_schedule":
1213
+ if v not in {"pyramid", "constant", "linear"}:
1214
+ raise ValueError(f"Invalid layer_schedule: {v}")
1215
+ bp[k] = v
1216
+ else:
1217
+ bp[k] = str(v)
1218
+ else:
1219
+ bp[k] = v # keep lists as-is
1220
+
1221
+ self.latent_dim: int = bp["latent_dim"]
1222
+ self.dropout_rate: float = bp["dropout_rate"]
1223
+ self.learning_rate: float = bp["learning_rate"]
1224
+ self.l1_penalty: float = bp["l1_penalty"]
1225
+ self.activation: str = bp["activation"]
1226
+ self.layer_scaling_factor: float = bp["layer_scaling_factor"]
1227
+ self.layer_schedule: str = bp["layer_schedule"]
1228
+
1229
+ nF: int = self.num_features_
1230
+ nC: int = int(self.num_classes_) if self.num_classes_ is not None else 3
1231
+ hidden_layer_sizes = self._compute_hidden_layer_sizes(
1232
+ n_inputs=nF * nC,
1233
+ n_outputs=nF * nC,
1234
+ n_samples=len(self.train_idx_),
1235
+ n_hidden=bp["num_hidden_layers"],
1236
+ alpha=bp["layer_scaling_factor"],
1237
+ schedule=bp["layer_schedule"],
1238
+ )
1239
+
1240
+ # Keep the latent_dim as the first element,
1241
+ # then the interior hidden widths.
1242
+ # If there are no interior widths (very small nets),
1243
+ # this still leaves [latent_dim].
1244
+ hidden_only = [hidden_layer_sizes[0]] + hidden_layer_sizes[1:-1]
1245
+
1246
+ return {
1247
+ "n_features": self.num_features_,
1248
+ "latent_dim": self.latent_dim,
1249
+ "hidden_layer_sizes": hidden_only,
1250
+ "dropout_rate": self.dropout_rate,
1251
+ "activation": self.activation,
1252
+ "num_classes": nC,
1253
+ }
1254
+
1255
+ def _default_best_params(self) -> Dict[str, int | float | str | list]:
1256
+ """Default model params when tuning is disabled.
1257
+
1258
+ This method computes the default model parameters for the autoencoder when hyperparameter tuning is not performed. It calculates the hidden layer sizes based on the initial configuration.
1259
+
1260
+ Returns:
1261
+ Dict[str, int | float | str | list]: Default model parameters.
1262
+ """
1263
+ nF: int = self.num_features_
1264
+ nC: int = int(self.num_classes_) if self.num_classes_ is not None else 3
1265
+ ls = self.layer_schedule
1266
+
1267
+ if ls not in {"pyramid", "constant", "linear"}:
1268
+ raise ValueError(f"Invalid layer_schedule: {ls}")
1269
+
1270
+ hidden_layer_sizes = self._compute_hidden_layer_sizes(
1271
+ n_inputs=nF * nC,
1272
+ n_outputs=nF * nC,
1273
+ n_samples=len(self.ground_truth_),
1274
+ n_hidden=self.num_hidden_layers,
1275
+ alpha=self.layer_scaling_factor,
1276
+ schedule=ls,
1277
+ )
1278
+ return {
1279
+ "n_features": self.num_features_,
1280
+ "latent_dim": self.latent_dim,
1281
+ "hidden_layer_sizes": hidden_layer_sizes,
1282
+ "dropout_rate": self.dropout_rate,
1283
+ "activation": self.activation,
1284
+ "num_classes": nC,
1285
+ }