pg-sui 1.0.2.1__py3-none-any.whl → 1.6.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of pg-sui might be problematic.
- {pg_sui-1.0.2.1.dist-info → pg_sui-1.6.8.dist-info}/METADATA +51 -70
- pg_sui-1.6.8.dist-info/RECORD +78 -0
- {pg_sui-1.0.2.1.dist-info → pg_sui-1.6.8.dist-info}/WHEEL +1 -1
- pg_sui-1.6.8.dist-info/entry_points.txt +4 -0
- pg_sui-1.6.8.dist-info/top_level.txt +1 -0
- pgsui/__init__.py +35 -54
- pgsui/_version.py +34 -0
- pgsui/cli.py +635 -0
- pgsui/data_processing/config.py +576 -0
- pgsui/data_processing/containers.py +1782 -0
- pgsui/data_processing/transformers.py +121 -1103
- pgsui/electron/app/__main__.py +5 -0
- pgsui/electron/app/icons/icons/1024x1024.png +0 -0
- pgsui/electron/app/icons/icons/128x128.png +0 -0
- pgsui/electron/app/icons/icons/16x16.png +0 -0
- pgsui/electron/app/icons/icons/24x24.png +0 -0
- pgsui/electron/app/icons/icons/256x256.png +0 -0
- pgsui/electron/app/icons/icons/32x32.png +0 -0
- pgsui/electron/app/icons/icons/48x48.png +0 -0
- pgsui/electron/app/icons/icons/512x512.png +0 -0
- pgsui/electron/app/icons/icons/64x64.png +0 -0
- pgsui/electron/app/icons/icons/icon.icns +0 -0
- pgsui/electron/app/icons/icons/icon.ico +0 -0
- pgsui/electron/app/main.js +189 -0
- pgsui/electron/app/package-lock.json +6893 -0
- pgsui/electron/app/package.json +50 -0
- pgsui/electron/app/preload.js +15 -0
- pgsui/electron/app/server.py +146 -0
- pgsui/electron/app/ui/logo.png +0 -0
- pgsui/electron/app/ui/renderer.js +130 -0
- pgsui/electron/app/ui/styles.css +59 -0
- pgsui/electron/app/ui/ui_shim.js +72 -0
- pgsui/electron/bootstrap.py +43 -0
- pgsui/electron/launch.py +59 -0
- pgsui/electron/package.json +14 -0
- pgsui/example_data/popmaps/{test.popmap → phylogen_nomx.popmap} +185 -99
- pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz +0 -0
- pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz.tbi +0 -0
- pgsui/impute/deterministic/imputers/allele_freq.py +691 -0
- pgsui/impute/deterministic/imputers/mode.py +679 -0
- pgsui/impute/deterministic/imputers/nmf.py +221 -0
- pgsui/impute/deterministic/imputers/phylo.py +971 -0
- pgsui/impute/deterministic/imputers/ref_allele.py +530 -0
- pgsui/impute/supervised/base.py +339 -0
- pgsui/impute/supervised/imputers/hist_gradient_boosting.py +293 -0
- pgsui/impute/supervised/imputers/random_forest.py +287 -0
- pgsui/impute/unsupervised/base.py +924 -0
- pgsui/impute/unsupervised/callbacks.py +89 -263
- pgsui/impute/unsupervised/imputers/autoencoder.py +972 -0
- pgsui/impute/unsupervised/imputers/nlpca.py +1264 -0
- pgsui/impute/unsupervised/imputers/ubp.py +1288 -0
- pgsui/impute/unsupervised/imputers/vae.py +957 -0
- pgsui/impute/unsupervised/loss_functions.py +158 -0
- pgsui/impute/unsupervised/models/autoencoder_model.py +208 -558
- pgsui/impute/unsupervised/models/nlpca_model.py +149 -468
- pgsui/impute/unsupervised/models/ubp_model.py +198 -1317
- pgsui/impute/unsupervised/models/vae_model.py +259 -618
- pgsui/impute/unsupervised/nn_scorers.py +215 -0
- pgsui/utils/classification_viz.py +591 -0
- pgsui/utils/misc.py +35 -480
- pgsui/utils/plotting.py +514 -824
- pgsui/utils/scorers.py +212 -438
- pg_sui-1.0.2.1.dist-info/RECORD +0 -75
- pg_sui-1.0.2.1.dist-info/top_level.txt +0 -3
- pgsui/example_data/phylip_files/test_n10.phy +0 -118
- pgsui/example_data/phylip_files/test_n100.phy +0 -118
- pgsui/example_data/phylip_files/test_n2.phy +0 -118
- pgsui/example_data/phylip_files/test_n500.phy +0 -118
- pgsui/example_data/structure_files/test.nopops.1row.10sites.str +0 -117
- pgsui/example_data/structure_files/test.nopops.2row.100sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.10sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.30sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.allsites.str +0 -234
- pgsui/example_data/structure_files/test.pops.1row.10sites.str +0 -117
- pgsui/example_data/structure_files/test.pops.2row.10sites.str +0 -234
- pgsui/example_data/trees/test.iqtree +0 -376
- pgsui/example_data/trees/test.qmat +0 -5
- pgsui/example_data/trees/test.rate +0 -2033
- pgsui/example_data/trees/test.tre +0 -1
- pgsui/example_data/trees/test_n10.rate +0 -19
- pgsui/example_data/trees/test_n100.rate +0 -109
- pgsui/example_data/trees/test_n500.rate +0 -509
- pgsui/example_data/trees/test_siterates.txt +0 -2024
- pgsui/example_data/trees/test_siterates_n10.txt +0 -10
- pgsui/example_data/trees/test_siterates_n100.txt +0 -100
- pgsui/example_data/trees/test_siterates_n500.txt +0 -500
- pgsui/example_data/vcf_files/test.vcf +0 -244
- pgsui/example_data/vcf_files/test.vcf.gz +0 -0
- pgsui/example_data/vcf_files/test.vcf.gz.tbi +0 -0
- pgsui/impute/estimators.py +0 -735
- pgsui/impute/impute.py +0 -1486
- pgsui/impute/simple_imputers.py +0 -1439
- pgsui/impute/supervised/iterative_imputer_fixedparams.py +0 -785
- pgsui/impute/supervised/iterative_imputer_gridsearch.py +0 -1027
- pgsui/impute/unsupervised/keras_classifiers.py +0 -702
- pgsui/impute/unsupervised/models/in_development/cnn_model.py +0 -486
- pgsui/impute/unsupervised/neural_network_imputers.py +0 -1424
- pgsui/impute/unsupervised/neural_network_methods.py +0 -1549
- pgsui/pg_sui.py +0 -261
- pgsui/utils/sequence_tools.py +0 -407
- simulation/sim_benchmarks.py +0 -333
- simulation/sim_treeparams.py +0 -475
- test/__init__.py +0 -0
- test/pg_sui_simtest.py +0 -215
- test/pg_sui_testing.py +0 -523
- test/test.py +0 -297
- test/test_pgsui.py +0 -374
- test/test_tkc.py +0 -214
- {pg_sui-1.0.2.1.dist-info → pg_sui-1.6.8.dist-info/licenses}/LICENSE +0 -0
- /pgsui/{example_data/trees → electron/app}/__init__.py +0 -0
- /pgsui/impute/{unsupervised/models/in_development → supervised/imputers}/__init__.py +0 -0
- {simulation → pgsui/impute/unsupervised/imputers}/__init__.py +0 -0
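
The largest change in this release is the rewritten imputer API shown in the vae.py hunk below: each unsupervised imputer now takes a single configuration (dataclass, nested dict, or YAML path) plus optional dot-key overrides, then exposes fit() and transform(). A minimal usage sketch, assuming a snpio GenotypeData object has already been loaded and that the section/field names mirror the cfg.* accesses visible in the code; the specific values are illustrative only, not defaults from the package:

    from pgsui.impute.unsupervised.imputers.vae import ImputeVAE

    # genotype_data: a snpio GenotypeData instance loaded elsewhere (e.g., from a VCF).
    imputer = ImputeVAE(
        genotype_data,
        config={  # nested dict mirroring the cfg.model / cfg.vae / cfg.train sections
            "model": {"latent_dim": 8, "dropout_rate": 0.2},
            "vae": {"kl_beta": 1.0, "kl_warmup": 50, "kl_ramp": 100},
            "train": {"max_epochs": 300, "learning_rate": 1e-3},
        },
        overrides={"io.prefix": "vae_run1"},  # dot-key overrides take highest precedence
    )
    imputer.fit()                       # trains on 0/1/2 encodings; missing values are masked
    iupac_matrix = imputer.transform()  # IUPAC genotype strings, shape (n_samples, n_loci)
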
pgsui/impute/unsupervised/imputers/vae.py (new file)
@@ -0,0 +1,957 @@
+from __future__ import annotations
+
+import copy
+from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
+
+import matplotlib.pyplot as plt
+import numpy as np
+import optuna
+import torch
+from sklearn.exceptions import NotFittedError
+from sklearn.model_selection import train_test_split
+from snpio.analysis.genotype_encoder import GenotypeEncoder
+from snpio.utils.logging import LoggerManager
+from torch.optim.lr_scheduler import CosineAnnealingLR
+
+from pgsui.data_processing.config import apply_dot_overrides, load_yaml_to_dataclass
+from pgsui.data_processing.containers import VAEConfig
+from pgsui.impute.unsupervised.base import BaseNNImputer
+from pgsui.impute.unsupervised.callbacks import EarlyStopping
+from pgsui.impute.unsupervised.models.vae_model import VAEModel
+
+if TYPE_CHECKING:
+    from snpio.read_input.genotype_data import GenotypeData
+
+
+def ensure_vae_config(config: Union[VAEConfig, dict, str, None]) -> VAEConfig:
+    """Normalize VAEConfig input from various sources.
+
+    Args:
+        config (Union[VAEConfig, dict, str, None]): VAEConfig, nested dict, YAML path, or None (defaults).
+
+    Returns:
+        VAEConfig: Normalized configuration dataclass.
+    """
+    if config is None:
+        return VAEConfig()
+    if isinstance(config, VAEConfig):
+        return config
+    if isinstance(config, str):
+        return load_yaml_to_dataclass(
+            config, VAEConfig, preset_builder=VAEConfig.from_preset
+        )
+    if isinstance(config, dict):
+        base = VAEConfig()
+        # Respect top-level preset
+        preset = config.pop("preset", None)
+        if preset:
+            base = VAEConfig.from_preset(preset)
+        # Flatten + apply
+        flat: Dict[str, object] = {}
+
+        def _flatten(prefix: str, d: dict, out: dict) -> dict:
+            for k, v in d.items():
+                kk = f"{prefix}.{k}" if prefix else k
+                if isinstance(v, dict):
+                    _flatten(kk, v, out)
+                else:
+                    out[kk] = v
+            return out
+
+        flat = _flatten("", config, {})
+        return apply_dot_overrides(base, flat)
+    raise TypeError("config must be a VAEConfig, dict, YAML path, or None.")
+
+
+class ImputeVAE(BaseNNImputer):
+    """Variational Autoencoder imputer on 0/1/2 encodings (missing=-1).
+
+    This imputer implements a VAE with a multinomial (categorical) latent space. It is designed to handle missing data by inferring the latent distribution and generating plausible predictions. The model is trained using a combination of reconstruction loss (cross-entropy) and a KL divergence term, with the KL weight (beta) annealed over time. The imputer supports both haploid and diploid genotype data.
+    """
+
+    def __init__(
+        self,
+        genotype_data: "GenotypeData",
+        *,
+        config: Optional[Union["VAEConfig", dict, str]] = None,
+        overrides: dict | None = None,
+    ):
+        """Initialize the VAE imputer with a unified config interface.
+
+        This initializer sets up the VAE imputer by processing the provided configuration, initializing logging, and preparing the model and data encoder. It supports configuration input as a dataclass, nested dictionary, YAML file path, or None, with optional dot-key overrides for fine-tuning specific parameters.
+
+        Args:
+            genotype_data (GenotypeData): Backing genotype data object.
+            config (Union[VAEConfig, dict, str, None]): VAEConfig, nested dict, YAML path, or None (defaults).
+            overrides (dict | None): Optional dot-key overrides with highest precedence.
+        """
+        self.model_name = "ImputeVAE"
+        self.genotype_data = genotype_data
+
+        # Normalize configuration and apply top-precedence overrides
+        cfg = ensure_vae_config(config)
+        if overrides:
+            cfg = apply_dot_overrides(cfg, overrides)
+        self.cfg = cfg
+
+        # Logger (align with AE/NLPCA)
+        logman = LoggerManager(
+            __name__,
+            prefix=self.cfg.io.prefix,
+            debug=self.cfg.io.debug,
+            verbose=self.cfg.io.verbose,
+        )
+        self.logger = logman.get_logger()
+
+        # BaseNNImputer bootstraps device/dirs/log formatting
+        super().__init__(
+            prefix=self.cfg.io.prefix,
+            device=self.cfg.train.device,
+            verbose=self.cfg.io.verbose,
+            debug=self.cfg.io.debug,
+        )
+
+        # Model hook & encoder
+        self.Model = VAEModel
+        self.pgenc = GenotypeEncoder(genotype_data)
+
+        # IO/global
+        self.seed = self.cfg.io.seed
+        self.n_jobs = self.cfg.io.n_jobs
+        self.prefix = self.cfg.io.prefix
+        self.scoring_averaging = self.cfg.io.scoring_averaging
+        self.verbose = self.cfg.io.verbose
+        self.debug = self.cfg.io.debug
+        self.rng = np.random.default_rng(self.seed)
+
+        # Model hyperparams (AE-parity)
+        self.latent_dim = self.cfg.model.latent_dim
+        self.dropout_rate = self.cfg.model.dropout_rate
+        self.num_hidden_layers = self.cfg.model.num_hidden_layers
+        self.layer_scaling_factor = self.cfg.model.layer_scaling_factor
+        self.layer_schedule = self.cfg.model.layer_schedule
+        self.activation = self.cfg.model.hidden_activation
+        self.gamma = self.cfg.model.gamma  # focal loss focusing (for recon CE)
+
+        # VAE-only KL controls
+        self.kl_beta_final = self.cfg.vae.kl_beta
+        self.kl_warmup = self.cfg.vae.kl_warmup
+        self.kl_ramp = self.cfg.vae.kl_ramp
+
+        # Train hyperparams (AE-parity)
+        self.batch_size = self.cfg.train.batch_size
+        self.learning_rate = self.cfg.train.learning_rate
+        self.l1_penalty = self.cfg.train.l1_penalty
+        self.early_stop_gen = self.cfg.train.early_stop_gen
+        self.min_epochs = self.cfg.train.min_epochs
+        self.epochs = self.cfg.train.max_epochs
+        self.validation_split = self.cfg.train.validation_split
+        self.beta = self.cfg.train.weights_beta
+        self.max_ratio = self.cfg.train.weights_max_ratio
+
+        # Tuning (AE-parity surface; VAE ignores latent refinement during eval)
+        self.tune = self.cfg.tune.enabled
+        self.tune_fast = self.cfg.tune.fast
+        self.tune_batch_size = self.cfg.tune.batch_size
+        self.tune_epochs = self.cfg.tune.epochs
+        self.tune_eval_interval = self.cfg.tune.eval_interval
+        self.tune_metric = self.cfg.tune.metric
+        self.n_trials = self.cfg.tune.n_trials
+        self.tune_save_db = self.cfg.tune.save_db
+        self.tune_resume = self.cfg.tune.resume
+        self.tune_max_samples = self.cfg.tune.max_samples
+        self.tune_max_loci = self.cfg.tune.max_loci
+        self.tune_patience = self.cfg.tune.patience
+
+        # Plotting (AE-parity)
+        self.plot_format = self.cfg.plot.fmt
+        self.plot_dpi = self.cfg.plot.dpi
+        self.plot_fontsize = self.cfg.plot.fontsize
+        self.title_fontsize = self.cfg.plot.fontsize
+        self.despine = self.cfg.plot.despine
+        self.show_plots = self.cfg.plot.show
+
+        # Derived at fit-time
+        self.is_haploid: bool | None = None
+        self.num_classes_: int | None = None
+        self.model_params: Dict[str, Any] = {}
+
+    # -------------------- Fit -------------------- #
+    def fit(self) -> "ImputeVAE":
+        """Fit the VAE on 0/1/2 encoded genotypes (missing → -9).
+
+        This method prepares the genotype data, initializes model parameters, splits the data into training and validation sets, and trains the VAE model. It handles both haploid and diploid data, applies class weighting, and supports optional hyperparameter tuning. After training, it evaluates the model on the validation set and saves the trained model.
+
+        Returns:
+            ImputeVAE: Fitted instance.
+
+        Raises:
+            RuntimeError: If training fails to produce a model.
+        """
+        self.logger.info(f"Fitting {self.model_name} (0/1/2 VAE) ...")
+
+        # Data prep aligns with AE/NLPCA
+        X = self.pgenc.genotypes_012.astype(np.float32)
+        X[X < 0] = np.nan
+        X[np.isnan(X)] = -1
+        self.ground_truth_ = X.astype(np.int64)
+
+        # Ploidy/classes
+        self.is_haploid = np.all(
+            np.isin(
+                self.genotype_data.snp_data,
+                ["A", "C", "G", "T", "N", "-", ".", "?"],
+            )
+        )
+        self.ploidy = 1 if self.is_haploid else 2
+        self.num_classes_ = 2 if self.is_haploid else 3
+        self.logger.info(
+            f"Data is {'haploid' if self.is_haploid else 'diploid'}; "
+            f"using {self.num_classes_} classes."
+        )
+
+        n_samples, self.num_features_ = X.shape
+
+        # Model params (decoder outputs L*K logits)
+        self.model_params = {
+            "n_features": self.num_features_,
+            "num_classes": self.num_classes_,
+            "latent_dim": self.latent_dim,
+            "dropout_rate": self.dropout_rate,
+            "activation": self.activation,
+        }
+
+        # Train/Val split
+        indices = np.arange(n_samples)
+        train_idx, val_idx = train_test_split(
+            indices, test_size=self.validation_split, random_state=self.seed
+        )
+        self.train_idx_, self.test_idx_ = train_idx, val_idx
+        self.X_train_ = self.ground_truth_[train_idx]
+        self.X_val_ = self.ground_truth_[val_idx]
+
+        # Plotters/scorers (shared utilities)
+        self.plotter_, self.scorers_ = self.initialize_plotting_and_scorers()
+
+        # Optional tuning
+        if self.tune:
+            self.tune_hyperparameters()
+
+        # Best params (tuned or default)
+        self.best_params_ = getattr(self, "best_params_", self._default_best_params())
+
+        # Class weights (device-aware)
+        self.class_weights_ = self._class_weights_from_zygosity(self.X_train_).to(
+            self.device
+        )
+
+        # DataLoader
+        train_loader = self._get_data_loader(self.X_train_)
+
+        # Build & train
+        model = self.build_model(self.Model, self.best_params_)
+        model.apply(self.initialize_weights)
+
+        loss, trained_model, history = self._train_and_validate_model(
+            model=model,
+            loader=train_loader,
+            lr=self.learning_rate,
+            l1_penalty=self.l1_penalty,
+            return_history=True,
+            class_weights=self.class_weights_,
+            X_val=self.X_val_,
+            params=self.best_params_,
+            prune_metric=self.tune_metric,
+            prune_warmup_epochs=5,
+            eval_interval=1,
+            eval_requires_latents=False,  # no latent refinement for eval
+            eval_latent_steps=0,
+            eval_latent_lr=0.0,
+            eval_latent_weight_decay=0.0,
+        )
+
+        if trained_model is None:
+            msg = "VAE training failed; no model was returned."
+            self.logger.error(msg)
+            raise RuntimeError(msg)
+
+        torch.save(
+            trained_model.state_dict(),
+            self.models_dir / f"final_model_{self.model_name}.pt",
+        )
+
+        self.best_loss_, self.model_, self.history_ = (
+            loss,
+            trained_model,
+            {"Train": history},
+        )
+        self.is_fit_ = True
+
+        # Evaluate (AE-parity reporting)
+        self._evaluate_model(self.X_val_, self.model_, self.best_params_)
+        self.plotter_.plot_history(self.history_)
+        self._save_best_params(self.best_params_)
+        return self
+
+    def transform(self) -> np.ndarray:
+        """Impute missing genotypes and return IUPAC strings.
+
+        This method uses the trained VAE model to impute missing genotypes in the dataset. It predicts the most likely genotype for each missing entry based on the learned latent representations and fills in these values. The imputed genotypes are then decoded back to IUPAC string format for easy interpretation.
+
+        Returns:
+            np.ndarray: IUPAC strings of shape (n_samples, n_loci).
+
+        Raises:
+            NotFittedError: If called before fit().
+        """
+        if not getattr(self, "is_fit_", False):
+            raise NotFittedError("Model is not fitted. Call fit() before transform().")
+
+        self.logger.info("Imputing entire dataset with VAE (0/1/2)...")
+        X_to_impute = self.ground_truth_.copy()
+
+        pred_labels, _ = self._predict(self.model_, X=X_to_impute, return_proba=True)
+
+        # Fill only missing
+        missing_mask = X_to_impute == -1
+        imputed_array = X_to_impute.copy()
+        imputed_array[missing_mask] = pred_labels[missing_mask]
+
+        # Decode to IUPAC & plot
+        imputed_genotypes = self.pgenc.decode_012(imputed_array)
+        original_genotypes = self.pgenc.decode_012(X_to_impute)
+
+        plt.rcParams.update(self.plotter_.param_dict)
+        self.plotter_.plot_gt_distribution(original_genotypes, is_imputed=False)
+        self.plotter_.plot_gt_distribution(imputed_genotypes, is_imputed=True)
+
+        return imputed_genotypes
+
+    # ---------- plumbing identical to AE, naming aligned ---------- #
+
+    def _get_data_loader(self, y: np.ndarray) -> torch.utils.data.DataLoader:
+        """Create DataLoader over indices + integer targets (-1 for missing).
+
+        This method creates a PyTorch DataLoader for the training data. It converts the input genotype matrix into a tensor and constructs a dataset that includes both the indices and the genotype values. The DataLoader is configured to shuffle the data and use the specified batch size for training.
+
+        Args:
+            y (np.ndarray): 0/1/2 matrix with -1 for missing.
+
+        Returns:
+            torch.utils.data.DataLoader: Shuffled DataLoader.
+        """
+        y_tensor = torch.from_numpy(y).long().to(self.device)
+        dataset = torch.utils.data.TensorDataset(
+            torch.arange(len(y), device=self.device), y_tensor
+        )
+        return torch.utils.data.DataLoader(
+            dataset, batch_size=self.batch_size, shuffle=True
+        )
+
+    def _train_and_validate_model(
+        self,
+        model: torch.nn.Module,
+        loader: torch.utils.data.DataLoader,
+        lr: float,
+        l1_penalty: float,
+        trial: optuna.Trial | None = None,
+        return_history: bool = False,
+        class_weights: torch.Tensor | None = None,
+        *,
+        X_val: np.ndarray | None = None,
+        params: dict | None = None,
+        prune_metric: str | None = None,  # "f1" | "accuracy" | "pr_macro"
+        prune_warmup_epochs: int = 3,
+        eval_interval: int = 1,
+        eval_requires_latents: bool = False,  # VAE: no latent eval refinement
+        eval_latent_steps: int = 0,
+        eval_latent_lr: float = 0.0,
+        eval_latent_weight_decay: float = 0.0,
+    ) -> Tuple[float, torch.nn.Module | None, list | None]:
+        """Wrap the VAE training loop with β-anneal & Optuna pruning.
+
+        This method orchestrates the training of the VAE model, including setting up the optimizer and learning rate scheduler, and executing the training loop with support for early stopping and Optuna pruning. It manages the training process, monitors performance on a validation set if provided, and returns the best model and training history.
+
+        Args:
+            model (torch.nn.Module): VAE model.
+            loader (torch.utils.data.DataLoader): Training data loader.
+            lr (float): Learning rate.
+            l1_penalty (float): L1 regularization coefficient.
+            trial (optuna.Trial | None): Optuna trial for pruning.
+            return_history (bool): If True, return training history.
+            class_weights (torch.Tensor | None): CE class weights on device.
+            X_val (np.ndarray | None): Validation data for pruning eval.
+            params (dict | None): Current hyperparameters (for logging).
+            prune_metric (str | None): Metric for pruning decisions.
+            prune_warmup_epochs (int): Epochs to skip before pruning.
+            eval_interval (int): Epochs between validation evaluations.
+            eval_requires_latents (bool): If True, refine latents during eval.
+            eval_latent_steps (int): Latent refinement steps if needed.
+            eval_latent_lr (float): Latent refinement learning rate.
+            eval_latent_weight_decay (float): Latent refinement L2 penalty.
+
+        Returns:
+            Tuple[float, torch.nn.Module | None, list | None]: Best loss, best model, and training history (if requested).
+        """
+        if class_weights is None:
+            msg = "Must provide class_weights."
+            self.logger.error(msg)
+            raise TypeError(msg)
+
+        max_epochs = (
+            self.tune_epochs if (trial is not None and self.tune_fast) else self.epochs
+        )
+
+        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
+        scheduler = CosineAnnealingLR(optimizer, T_max=max_epochs)
+
+        best_loss, best_model, hist = self._execute_training_loop(
+            loader=loader,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            model=model,
+            l1_penalty=l1_penalty,
+            trial=trial,
+            return_history=return_history,
+            class_weights=class_weights,
+            X_val=X_val,
+            params=params,
+            prune_metric=prune_metric,
+            prune_warmup_epochs=prune_warmup_epochs,
+            eval_interval=eval_interval,
+            eval_requires_latents=eval_requires_latents,
+            eval_latent_steps=eval_latent_steps,
+            eval_latent_lr=eval_latent_lr,
+            eval_latent_weight_decay=eval_latent_weight_decay,
+        )
+        if return_history:
+            return best_loss, best_model, hist
+
+        return best_loss, best_model, None
+
+    def _execute_training_loop(
+        self,
+        loader: torch.utils.data.DataLoader,
+        optimizer: torch.optim.Optimizer,
+        scheduler: torch.optim.lr_scheduler._LRScheduler,
+        model: torch.nn.Module,
+        l1_penalty: float,
+        trial: optuna.Trial | None,
+        return_history: bool,
+        class_weights: torch.Tensor,
+        *,
+        X_val: np.ndarray | None = None,
+        params: dict | None = None,
+        prune_metric: str | None = None,
+        prune_warmup_epochs: int = 3,
+        eval_interval: int = 1,
+        eval_requires_latents: bool = False,
+        eval_latent_steps: int = 0,
+        eval_latent_lr: float = 0.0,
+        eval_latent_weight_decay: float = 0.0,
+    ) -> Tuple[float, torch.nn.Module, list]:
+        """Train VAE with focal CE + KL(β) anneal, early stopping & pruning.
+
+        This method implements the core training loop for the VAE model, incorporating focal cross-entropy loss for reconstruction and KL divergence with an annealed beta weight. It includes mechanisms for early stopping based on validation performance and supports pruning of unpromising trials when used with Optuna. The training process is monitored, and the best model is retained.
+
+        Args:
+            loader (torch.utils.data.DataLoader): Training data loader.
+            optimizer (torch.optim.Optimizer): Optimizer.
+            scheduler (torch.optim.lr_scheduler._LRScheduler): LR scheduler.
+            model (torch.nn.Module): VAE model.
+            l1_penalty (float): L1 regularization coefficient.
+            trial (optuna.Trial | None): Optuna trial for pruning.
+            return_history (bool): If True, return training history.
+            class_weights (torch.Tensor): CE class weights on device.
+            X_val (np.ndarray | None): Validation data for pruning eval.
+            params (dict | None): Current hyperparameters (for logging).
+            prune_metric (str | None): Metric for pruning decisions.
+            prune_warmup_epochs (int): Epochs to skip before pruning.
+            eval_interval (int): Epochs between validation evaluations.
+            eval_requires_latents (bool): If True, refine latents during eval.
+            eval_latent_steps (int): Latent refinement steps if needed.
+            eval_latent_lr (float): Latent refinement learning rate.
+            eval_latent_weight_decay (float): Latent refinement L2 penalty.
+
+        Returns:
+            Tuple[float, torch.nn.Module, list[float]]: Best loss, best model, history.
+        """
+        best_model = None
+        history: list[float] = []
+
+        early_stopping = EarlyStopping(
+            patience=self.early_stop_gen,
+            min_epochs=self.min_epochs,
+            verbose=self.verbose,
+            prefix=self.prefix,
+            debug=self.debug,
+        )
+
+        # AE-parity gamma schedule for focal CE (reconstruction)
+        gamma_warm, gamma_ramp, gamma_final = 50, 100, self.gamma
+        # VAE β schedule for KL term
+        beta_warm, beta_ramp, beta_final = (
+            self.kl_warmup,
+            self.kl_ramp,
+            self.kl_beta_final,
+        )
+
+        for epoch in range(scheduler.T_max):
+            # schedules
+            # focal γ schedule (if your VAEModel uses it for recon CE)
+            if epoch < gamma_warm:
+                model.gamma = 0.0
+            elif epoch < gamma_warm + gamma_ramp:
+                model.gamma = gamma_final * ((epoch - gamma_warm) / gamma_ramp)
+            else:
+                model.gamma = gamma_final
+
+            # KL β schedule
+            if epoch < beta_warm:
+                model.beta = 0.0
+            elif epoch < beta_warm + beta_ramp:
+                model.beta = beta_final * ((epoch - beta_warm) / beta_ramp)
+            else:
+                model.beta = beta_final
+
+            # one epoch
+            train_loss = self._train_step(
+                loader=loader,
+                optimizer=optimizer,
+                model=model,
+                l1_penalty=l1_penalty,
+                class_weights=class_weights,
+            )
+            if trial and (np.isnan(train_loss) or np.isinf(train_loss)):
+                raise optuna.exceptions.TrialPruned("Loss is NaN or Inf.")
+
+            scheduler.step()
+            if return_history:
+                history.append(train_loss)
+
+            early_stopping(train_loss, model)
+            if early_stopping.early_stop:
+                self.logger.info(f"Early stopping at epoch {epoch + 1}.")
+                break
+
+            # Optuna report/prune on validation metric
+            if (
+                trial is not None
+                and X_val is not None
+                and ((epoch + 1) % eval_interval == 0)
+            ):
+                metric_key = prune_metric or getattr(self, "tune_metric", "f1")
+                metric_val = self._eval_for_pruning(
+                    model=model,
+                    X_val=X_val,
+                    params=params or getattr(self, "best_params_", {}),
+                    metric=metric_key,
+                    objective_mode=True,
+                    do_latent_infer=False,  # VAE: no latent refinement needed
+                    latent_steps=0,
+                    latent_lr=0.0,
+                    latent_weight_decay=0.0,
+                    latent_seed=(self.seed if self.seed is not None else 123),
+                    _latent_cache=None,
+                    _latent_cache_key=None,
+                )
+                trial.report(metric_val, step=epoch + 1)
+                if (epoch + 1) >= prune_warmup_epochs and trial.should_prune():
+                    raise optuna.exceptions.TrialPruned(
+                        f"Pruned at epoch {epoch + 1}: {metric_key}={metric_val:.5f}"
+                    )
+
+        best_loss = early_stopping.best_score
+        best_model = copy.deepcopy(early_stopping.best_model)
+        return best_loss, best_model, history
+
+    def _train_step(
+        self,
+        loader: torch.utils.data.DataLoader,
+        optimizer: torch.optim.Optimizer,
+        model: torch.nn.Module,
+        l1_penalty: float,
+        class_weights: torch.Tensor,
+    ) -> float:
+        """One epoch: one-hot inputs → VAE forward → recon (focal) + KL.
+
+        The VAEModel is expected to return (recon_logits, mu, logvar, ...) and expose a `compute_loss(outputs, y, mask, class_weights)` method that reads scheduled `model.beta` (and optionally `model.gamma`) attributes.
+
+        Args:
+            loader (torch.utils.data.DataLoader): Yields (indices, y_int) where y_int is 0/1/2; -1 for missing.
+            optimizer (torch.optim.Optimizer): Optimizer.
+            model (torch.nn.Module): VAE model.
+            l1_penalty (float): L1 regularization coefficient.
+            class_weights (torch.Tensor): CE class weights on device.
+
+        Returns:
+            float: Mean training loss for the epoch.
+        """
+        model.train()
+        running = 0.0
+
+        for _, y_batch in loader:
+            optimizer.zero_grad(set_to_none=True)
+
+            x_ohe = self._one_hot_encode_012(y_batch)  # (B, L, K), zeros for -1
+            outputs = model(x_ohe)  # (recon_logits, mu, logvar, ...)
+
+            # Targets for masked focal CE, same shapes as AE path
+            y_ohe = self._one_hot_encode_012(y_batch)
+            valid_mask = y_batch != -1
+
+            loss = model.compute_loss(
+                outputs=outputs,
+                y=y_ohe,  # (B, L, K)
+                mask=valid_mask,  # (B, L)
+                class_weights=class_weights,
+            )
+
+            if l1_penalty > 0:
+                loss = loss + l1_penalty * sum(
+                    p.abs().sum() for p in model.parameters()
+                )
+
+            loss.backward()
+            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+            optimizer.step()
+            running += float(loss.item())
+
+        return running / len(loader)
+
+    def _predict(
+        self,
+        model: torch.nn.Module,
+        X: np.ndarray | torch.Tensor,
+        return_proba: bool = False,
+    ) -> Tuple[np.ndarray, np.ndarray] | np.ndarray:
+        """Predict 0/1/2 labels (and probabilities) from masked inputs.
+
+        This method uses the trained VAE model to predict genotype labels for the provided input data. It processes the input data, performs a forward pass through the model, and computes the predicted labels and probabilities. The method can return either just the predicted labels or both labels and probabilities based on the `return_proba` flag.
+
+        Args:
+            model (torch.nn.Module): Trained model.
+            X (np.ndarray | torch.Tensor): 0/1/2 matrix with -1 for missing.
+            return_proba (bool): If True, also return probabilities.
+
+        Returns:
+            Tuple[np.ndarray, np.ndarray] | np.ndarray: Predicted labels, and probabilities if requested.
+        """
+        if model is None:
+            msg = "Model is not trained. Call fit() before predict()."
+            self.logger.error(msg)
+            raise NotFittedError(msg)
+
+        model.eval()
+        with torch.no_grad():
+            X_tensor = torch.from_numpy(X) if isinstance(X, np.ndarray) else X
+            X_tensor = X_tensor.to(self.device).long()
+            x_ohe = self._one_hot_encode_012(X_tensor)
+            outputs = model(x_ohe)  # first element must be recon logits
+            logits = outputs[0].view(-1, self.num_features_, self.num_classes_)
+            probas = torch.softmax(logits, dim=-1)
+            labels = torch.argmax(probas, dim=-1)
+
+        if return_proba:
+            return labels.cpu().numpy(), probas.cpu().numpy()
+
+        return labels.cpu().numpy()
+
+    def _evaluate_model(
+        self,
+        X_val: np.ndarray,
+        model: torch.nn.Module,
+        params: dict,
+        objective_mode: bool = False,
+        latent_vectors_val: np.ndarray | None = None,
+    ) -> Dict[str, float]:
+        """Evaluate on 0/1/2; then IUPAC decoding and 10-base integer reports.
+
+        This method evaluates the trained VAE model on a validation dataset, computing various performance metrics. It handles missing data appropriately and generates detailed classification reports for both the original 0/1/2 encoding and the decoded IUPAC and integer formats. The evaluation metrics are logged for review.
+
+        Args:
+            X_val (np.ndarray): Validation 0/1/2 matrix with -1 for missing.
+            model (torch.nn.Module): Trained model.
+            params (dict): Current hyperparameters (for logging).
+            objective_mode (bool): If True, minimize logging for Optuna.
+            latent_vectors_val (np.ndarray | None): Not used by VAE.
+
+        Returns:
+            Dict[str, float]: Computed metrics.
+
+        Raises:
+            NotFittedError: If called before fit().
+        """
+        pred_labels, pred_probas = self._predict(
+            model=model, X=X_val, return_proba=True
+        )
+
+        # mask out true missing AND any non-finite prob rows
+        finite_mask = np.all(np.isfinite(pred_probas), axis=-1)  # (N,L)
+        eval_mask = (X_val != -1) & finite_mask
+
+        y_true_flat = X_val[eval_mask].astype(np.int64, copy=False)
+        y_pred_flat = pred_labels[eval_mask].astype(np.int64, copy=False)
+        y_proba_flat = pred_probas[eval_mask].astype(np.float64, copy=False)
+
+        if y_true_flat.size == 0:
+            return {self.tune_metric: 0.0}
+
+        # ensure valid probability simplex after masking
+        y_proba_flat = np.clip(y_proba_flat, 0.0, 1.0)
+        row_sums = y_proba_flat.sum(axis=1, keepdims=True)
+        row_sums[row_sums == 0] = 1.0
+        y_proba_flat = y_proba_flat / row_sums
+
+        labels_for_scoring = [0, 1] if self.is_haploid else [0, 1, 2]
+        target_names = ["REF", "ALT"] if self.is_haploid else ["REF", "HET", "ALT"]
+
+        if self.is_haploid:
+            y_true_flat = y_true_flat.copy()
+            y_pred_flat = y_pred_flat.copy()
+            y_true_flat[y_true_flat == 2] = 1
+            y_pred_flat[y_pred_flat == 2] = 1
+            proba_2 = np.zeros((len(y_proba_flat), 2), dtype=y_proba_flat.dtype)
+            proba_2[:, 0] = y_proba_flat[:, 0]
+            proba_2[:, 1] = y_proba_flat[:, 2]
+            y_proba_flat = proba_2
+
+        y_true_ohe = np.eye(len(labels_for_scoring))[y_true_flat]
+
+        metrics = self.scorers_.evaluate(
+            y_true_flat,
+            y_pred_flat,
+            y_true_ohe,
+            y_proba_flat,
+            objective_mode,
+            self.tune_metric,
+        )
+
+        if not objective_mode:
+            self.logger.info(f"Validation Metrics: {metrics}")
+
+            # Primary report
+            self._make_class_reports(
+                y_true=y_true_flat,
+                y_pred_proba=y_proba_flat,
+                y_pred=y_pred_flat,
+                metrics=metrics,
+                labels=target_names,
+            )
+
+            # IUPAC decode & 10-base integer report
+            y_true_dec = self.pgenc.decode_012(X_val)
+            X_pred = X_val.copy()
+            X_pred[eval_mask] = y_pred_flat
+            y_pred_dec = self.pgenc.decode_012(
+                X_pred.reshape(X_val.shape[0], self.num_features_)
+            )
+
+            encodings_dict = {
+                "A": 0,
+                "C": 1,
+                "G": 2,
+                "T": 3,
+                "W": 4,
+                "R": 5,
+                "M": 6,
+                "K": 7,
+                "Y": 8,
+                "S": 9,
+                "N": -1,
+            }
+            y_true_int = self.pgenc.convert_int_iupac(
+                y_true_dec, encodings_dict=encodings_dict
+            )
+            y_pred_int = self.pgenc.convert_int_iupac(
+                y_pred_dec, encodings_dict=encodings_dict
+            )
+
+            self._make_class_reports(
+                y_true=y_true_int[eval_mask],
+                y_pred=y_pred_int[eval_mask],
+                metrics=metrics,
+                y_pred_proba=None,
+                labels=["A", "C", "G", "T", "W", "R", "M", "K", "Y", "S"],
+            )
+
+        return metrics
+
+    def _objective(self, trial: optuna.Trial) -> float:
+        """Optuna objective for VAE (no latent refinement during eval).
+
+        This method defines the objective function for hyperparameter tuning using Optuna. It samples hyperparameters, trains the VAE model with these parameters, and evaluates its performance on a validation set. The evaluation metric specified by `self.tune_metric` is returned for optimization. If training fails, the trial is pruned to keep the tuning process efficient.
+
+        Args:
+            trial (optuna.Trial): Optuna trial object.
+
+        Returns:
+            float: Value of the tuning metric to be optimized.
+        """
+        try:
+            params = self._sample_hyperparameters(trial)
+
+            X_train = self.ground_truth_[self.train_idx_]
+            X_val = self.ground_truth_[self.test_idx_]
+
+            class_weights = self._class_weights_from_zygosity(X_train).to(self.device)
+            train_loader = self._get_data_loader(X_train)
+
+            model = self.build_model(self.Model, params["model_params"])
+            model.apply(self.initialize_weights)
+
+            # Train + prune on metric
+            _, model, _ = self._train_and_validate_model(
+                model=model,
+                loader=train_loader,
+                lr=params["lr"],
+                l1_penalty=params["l1_penalty"],
+                trial=trial,
+                return_history=False,
+                class_weights=class_weights,
+                X_val=X_val,
+                params=params,
+                prune_metric=self.tune_metric,
+                prune_warmup_epochs=5,
+                eval_interval=self.tune_eval_interval,
+                eval_requires_latents=False,
+                eval_latent_steps=0,
+                eval_latent_lr=0.0,
+                eval_latent_weight_decay=0.0,
+            )
+
+            metrics = self._evaluate_model(X_val, model, params, objective_mode=True)
+            self._clear_resources(model, train_loader)
+            return metrics[self.tune_metric]
+
+        except Exception as e:
+            # Keep sweeps moving
+            self.logger.debug(f"Trial failed with error: {e}")
+            raise optuna.exceptions.TrialPruned(
+                f"Trial failed with error. Enable debug logging for details."
+            )
+
+    def _sample_hyperparameters(
+        self, trial: optuna.Trial
+    ) -> Dict[str, int | float | str]:
+        """Sample VAE hyperparameters; hidden sizes mirror AE/NLPCA helper.
+
+        Args:
+            trial (optuna.Trial): Optuna trial object.
+
+        Returns:
+            Dict[str, int | float | str]: Sampled hyperparameters.
+        """
+        params = {
+            "latent_dim": trial.suggest_int("latent_dim", 2, 64),
+            "lr": trial.suggest_float("learning_rate", 1e-5, 1e-2, log=True),
+            "dropout_rate": trial.suggest_float("dropout_rate", 0.0, 0.6),
+            "num_hidden_layers": trial.suggest_int("num_hidden_layers", 1, 8),
+            "activation": trial.suggest_categorical(
+                "activation", ["relu", "elu", "selu"]
+            ),
+            "l1_penalty": trial.suggest_float("l1_penalty", 1e-7, 1e-2, log=True),
+            "layer_scaling_factor": trial.suggest_float(
+                "layer_scaling_factor", 2.0, 10.0
+            ),
+            "layer_schedule": trial.suggest_categorical(
+                "layer_schedule", ["pyramid", "constant", "linear"]
+            ),
+            # VAE-specific β (final value after anneal)
+            "beta": trial.suggest_float("beta", 0.25, 4.0),
+            # focal gamma (if used in VAE recon CE)
+            "gamma": trial.suggest_float("gamma", 0.0, 5.0),
+        }
+
+        input_dim = self.num_features_ * self.num_classes_
+        hidden_layer_sizes = self._compute_hidden_layer_sizes(
+            n_inputs=input_dim,
+            n_outputs=input_dim,
+            n_samples=len(self.train_idx_),
+            n_hidden=params["num_hidden_layers"],
+            alpha=params["layer_scaling_factor"],
+            schedule=params["layer_schedule"],
+        )
+
+        # [latent_dim] + interior widths (exclude output width)
+        hidden_only = [hidden_layer_sizes[0]] + hidden_layer_sizes[1:-1]
+
+        params["model_params"] = {
+            "n_features": self.num_features_,
+            "num_classes": self.num_classes_,
+            "latent_dim": params["latent_dim"],
+            "dropout_rate": params["dropout_rate"],
+            "hidden_layer_sizes": hidden_only,
+            "activation": params["activation"],
+            # Pass through VAE recon/regularization coefficients
+            "beta": params["beta"],
+            "gamma": params["gamma"],
+        }
+        return params
+
+    def _set_best_params(
+        self, best_params: Dict[str, int | float | str | list]
+    ) -> Dict[str, int | float | str | list]:
+        """Adopt best params and return VAE model_params.
+
+        Args:
+            best_params (Dict[str, int | float | str | list]): Best hyperparameters from tuning.
+
+        Returns:
+            Dict[str, int | float | str | list]: VAE model parameters.
+        """
+        self.latent_dim = best_params["latent_dim"]
+        self.dropout_rate = best_params["dropout_rate"]
+        self.learning_rate = best_params["learning_rate"]
+        self.l1_penalty = best_params["l1_penalty"]
+        self.activation = best_params["activation"]
+        self.layer_scaling_factor = best_params["layer_scaling_factor"]
+        self.layer_schedule = best_params["layer_schedule"]
+        self.kl_beta_final = best_params.get("beta", self.kl_beta_final)
+        self.gamma = best_params.get("gamma", self.gamma)
+
+        hidden_layer_sizes = self._compute_hidden_layer_sizes(
+            n_inputs=self.num_features_ * self.num_classes_,
+            n_outputs=self.num_features_ * self.num_classes_,
+            n_samples=len(self.train_idx_),
+            n_hidden=best_params["num_hidden_layers"],
+            alpha=best_params["layer_scaling_factor"],
+            schedule=best_params["layer_schedule"],
+        )
+        hidden_only = [hidden_layer_sizes[0]] + hidden_layer_sizes[1:-1]
+
+        return {
+            "n_features": self.num_features_,
+            "latent_dim": self.latent_dim,
+            "hidden_layer_sizes": hidden_only,
+            "dropout_rate": self.dropout_rate,
+            "activation": self.activation,
+            "num_classes": self.num_classes_,
+            "beta": self.kl_beta_final,
+            "gamma": self.gamma,
+        }
+
+    def _default_best_params(self) -> Dict[str, int | float | str | list]:
+        """Default VAE model params when tuning is disabled.
+
+        Returns:
+            Dict[str, int | float | str | list]: VAE model parameters.
+        """
+        hidden_layer_sizes = self._compute_hidden_layer_sizes(
+            n_inputs=self.num_features_ * self.num_classes_,
+            n_outputs=self.num_features_ * self.num_classes_,
+            n_samples=len(self.ground_truth_),
+            n_hidden=self.num_hidden_layers,
+            alpha=self.layer_scaling_factor,
+            schedule=self.layer_schedule,
+        )
+        return {
+            "n_features": self.num_features_,
+            "latent_dim": self.latent_dim,
+            "hidden_layer_sizes": hidden_layer_sizes,
+            "dropout_rate": self.dropout_rate,
+            "activation": self.activation,
+            "num_classes": self.num_classes_,
+            "beta": self.kl_beta_final,
+            "gamma": self.gamma,
+        }
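
Both the focal γ for the reconstruction cross-entropy and the KL weight β in `_execute_training_loop` follow the same pattern: hold the coefficient at zero for a warm-up period, ramp it linearly to its final value, then keep it constant. A minimal standalone sketch of that schedule, with an illustrative function name that is not part of the package:

    def annealed_weight(epoch: int, warmup: int, ramp: int, final: float) -> float:
        """Linear warm-up/ramp schedule, as used for model.beta and model.gamma."""
        if epoch < warmup:
            return 0.0                                  # hold at zero during warm-up
        if epoch < warmup + ramp:
            return final * (epoch - warmup) / ramp      # linear ramp toward the final value
        return final                                    # constant afterwards

    # e.g., with kl_warmup=50, kl_ramp=100, kl_beta=1.0:
    # epoch 49 -> 0.0, epoch 100 -> 0.5, epoch 150+ -> 1.0
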