pg-sui 0.2.3__py3-none-any.whl → 1.6.14.dev9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pg_sui-0.2.3.dist-info → pg_sui-1.6.14.dev9.dist-info}/METADATA +99 -77
- pg_sui-1.6.14.dev9.dist-info/RECORD +81 -0
- {pg_sui-0.2.3.dist-info → pg_sui-1.6.14.dev9.dist-info}/WHEEL +1 -1
- pg_sui-1.6.14.dev9.dist-info/entry_points.txt +4 -0
- {pg_sui-0.2.3.dist-info → pg_sui-1.6.14.dev9.dist-info/licenses}/LICENSE +0 -0
- pg_sui-1.6.14.dev9.dist-info/top_level.txt +1 -0
- pgsui/__init__.py +35 -54
- pgsui/_version.py +34 -0
- pgsui/cli.py +909 -0
- pgsui/data_processing/__init__.py +0 -0
- pgsui/data_processing/config.py +565 -0
- pgsui/data_processing/containers.py +1424 -0
- pgsui/data_processing/transformers.py +557 -907
- pgsui/{example_data/trees → electron/app}/__init__.py +0 -0
- pgsui/electron/app/__main__.py +5 -0
- pgsui/electron/app/extra-resources/.gitkeep +1 -0
- pgsui/electron/app/icons/icons/1024x1024.png +0 -0
- pgsui/electron/app/icons/icons/128x128.png +0 -0
- pgsui/electron/app/icons/icons/16x16.png +0 -0
- pgsui/electron/app/icons/icons/24x24.png +0 -0
- pgsui/electron/app/icons/icons/256x256.png +0 -0
- pgsui/electron/app/icons/icons/32x32.png +0 -0
- pgsui/electron/app/icons/icons/48x48.png +0 -0
- pgsui/electron/app/icons/icons/512x512.png +0 -0
- pgsui/electron/app/icons/icons/64x64.png +0 -0
- pgsui/electron/app/icons/icons/icon.icns +0 -0
- pgsui/electron/app/icons/icons/icon.ico +0 -0
- pgsui/electron/app/main.js +227 -0
- pgsui/electron/app/package-lock.json +6894 -0
- pgsui/electron/app/package.json +51 -0
- pgsui/electron/app/preload.js +15 -0
- pgsui/electron/app/server.py +157 -0
- pgsui/electron/app/ui/logo.png +0 -0
- pgsui/electron/app/ui/renderer.js +131 -0
- pgsui/electron/app/ui/styles.css +59 -0
- pgsui/electron/app/ui/ui_shim.js +72 -0
- pgsui/electron/bootstrap.py +43 -0
- pgsui/electron/launch.py +57 -0
- pgsui/electron/package.json +14 -0
- pgsui/example_data/__init__.py +0 -0
- pgsui/example_data/phylip_files/__init__.py +0 -0
- pgsui/example_data/phylip_files/test.phy +0 -0
- pgsui/example_data/popmaps/__init__.py +0 -0
- pgsui/example_data/popmaps/{test.popmap → phylogen_nomx.popmap} +185 -99
- pgsui/example_data/structure_files/__init__.py +0 -0
- pgsui/example_data/structure_files/test.pops.2row.allsites.str +0 -0
- pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz +0 -0
- pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz.tbi +0 -0
- pgsui/impute/__init__.py +0 -0
- pgsui/impute/deterministic/imputers/allele_freq.py +725 -0
- pgsui/impute/deterministic/imputers/mode.py +844 -0
- pgsui/impute/deterministic/imputers/nmf.py +221 -0
- pgsui/impute/deterministic/imputers/phylo.py +973 -0
- pgsui/impute/deterministic/imputers/ref_allele.py +669 -0
- pgsui/impute/supervised/__init__.py +0 -0
- pgsui/impute/supervised/base.py +343 -0
- pgsui/impute/{unsupervised/models/in_development → supervised/imputers}/__init__.py +0 -0
- pgsui/impute/supervised/imputers/hist_gradient_boosting.py +317 -0
- pgsui/impute/supervised/imputers/random_forest.py +291 -0
- pgsui/impute/unsupervised/__init__.py +0 -0
- pgsui/impute/unsupervised/base.py +1118 -0
- pgsui/impute/unsupervised/callbacks.py +92 -262
- {simulation → pgsui/impute/unsupervised/imputers}/__init__.py +0 -0
- pgsui/impute/unsupervised/imputers/autoencoder.py +1285 -0
- pgsui/impute/unsupervised/imputers/nlpca.py +1554 -0
- pgsui/impute/unsupervised/imputers/ubp.py +1575 -0
- pgsui/impute/unsupervised/imputers/vae.py +1228 -0
- pgsui/impute/unsupervised/loss_functions.py +261 -0
- pgsui/impute/unsupervised/models/__init__.py +0 -0
- pgsui/impute/unsupervised/models/autoencoder_model.py +215 -567
- pgsui/impute/unsupervised/models/nlpca_model.py +155 -394
- pgsui/impute/unsupervised/models/ubp_model.py +180 -1106
- pgsui/impute/unsupervised/models/vae_model.py +269 -630
- pgsui/impute/unsupervised/nn_scorers.py +255 -0
- pgsui/utils/__init__.py +0 -0
- pgsui/utils/classification_viz.py +608 -0
- pgsui/utils/logging_utils.py +22 -0
- pgsui/utils/misc.py +35 -480
- pgsui/utils/plotting.py +996 -829
- pgsui/utils/pretty_metrics.py +290 -0
- pgsui/utils/scorers.py +213 -666
- pg_sui-0.2.3.dist-info/RECORD +0 -75
- pg_sui-0.2.3.dist-info/top_level.txt +0 -3
- pgsui/example_data/phylip_files/test_n10.phy +0 -118
- pgsui/example_data/phylip_files/test_n100.phy +0 -118
- pgsui/example_data/phylip_files/test_n2.phy +0 -118
- pgsui/example_data/phylip_files/test_n500.phy +0 -118
- pgsui/example_data/structure_files/test.nopops.1row.10sites.str +0 -117
- pgsui/example_data/structure_files/test.nopops.2row.100sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.10sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.30sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.allsites.str +0 -234
- pgsui/example_data/structure_files/test.pops.1row.10sites.str +0 -117
- pgsui/example_data/structure_files/test.pops.2row.10sites.str +0 -234
- pgsui/example_data/trees/test.iqtree +0 -376
- pgsui/example_data/trees/test.qmat +0 -5
- pgsui/example_data/trees/test.rate +0 -2033
- pgsui/example_data/trees/test.tre +0 -1
- pgsui/example_data/trees/test_n10.rate +0 -19
- pgsui/example_data/trees/test_n100.rate +0 -109
- pgsui/example_data/trees/test_n500.rate +0 -509
- pgsui/example_data/trees/test_siterates.txt +0 -2024
- pgsui/example_data/trees/test_siterates_n10.txt +0 -10
- pgsui/example_data/trees/test_siterates_n100.txt +0 -100
- pgsui/example_data/trees/test_siterates_n500.txt +0 -500
- pgsui/example_data/vcf_files/test.vcf +0 -244
- pgsui/example_data/vcf_files/test.vcf.gz +0 -0
- pgsui/example_data/vcf_files/test.vcf.gz.tbi +0 -0
- pgsui/impute/estimators.py +0 -1268
- pgsui/impute/impute.py +0 -1463
- pgsui/impute/simple_imputers.py +0 -1431
- pgsui/impute/supervised/iterative_imputer_fixedparams.py +0 -782
- pgsui/impute/supervised/iterative_imputer_gridsearch.py +0 -1024
- pgsui/impute/unsupervised/keras_classifiers.py +0 -697
- pgsui/impute/unsupervised/models/in_development/cnn_model.py +0 -486
- pgsui/impute/unsupervised/neural_network_imputers.py +0 -1440
- pgsui/impute/unsupervised/neural_network_methods.py +0 -1395
- pgsui/pg_sui.py +0 -261
- pgsui/utils/sequence_tools.py +0 -407
- simulation/sim_benchmarks.py +0 -333
- simulation/sim_treeparams.py +0 -475
- test/__init__.py +0 -0
- test/pg_sui_simtest.py +0 -215
- test/pg_sui_testing.py +0 -523
- test/test.py +0 -151
- test/test_pgsui.py +0 -374
- test/test_tkc.py +0 -185
The +1228-line hunk below is the new module `pgsui/impute/unsupervised/imputers/vae.py`, which implements the `ImputeVAE` class:

@@ -0,0 +1,1228 @@

```python
from __future__ import annotations

import copy
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import optuna
import torch
from sklearn.exceptions import NotFittedError
from sklearn.model_selection import train_test_split
from snpio.analysis.genotype_encoder import GenotypeEncoder
from snpio.utils.logging import LoggerManager
from torch.optim.lr_scheduler import CosineAnnealingLR

from pgsui.data_processing.config import apply_dot_overrides, load_yaml_to_dataclass
from pgsui.data_processing.containers import VAEConfig
from pgsui.data_processing.transformers import SimMissingTransformer
from pgsui.impute.unsupervised.base import BaseNNImputer
from pgsui.impute.unsupervised.callbacks import EarlyStopping
from pgsui.impute.unsupervised.loss_functions import compute_vae_loss
from pgsui.impute.unsupervised.models.vae_model import VAEModel
from pgsui.utils.logging_utils import configure_logger
from pgsui.utils.pretty_metrics import PrettyMetrics

if TYPE_CHECKING:
    from snpio import TreeParser
    from snpio.read_input.genotype_data import GenotypeData


def ensure_vae_config(config: Union[VAEConfig, dict, str, None]) -> VAEConfig:
    """Normalize VAEConfig input from various sources.

    Args:
        config (Union[VAEConfig, dict, str, None]): VAEConfig, nested dict, YAML path, or None (defaults).

    Returns:
        VAEConfig: Normalized configuration dataclass.
    """
    if config is None:
        return VAEConfig()
    if isinstance(config, VAEConfig):
        return config
    if isinstance(config, str):
        return load_yaml_to_dataclass(config, VAEConfig)
    if isinstance(config, dict):
        base = VAEConfig()
        # Respect top-level preset
        preset = config.pop("preset", None)
        if preset:
            base = VAEConfig.from_preset(preset)
        # Flatten + apply
        flat: Dict[str, object] = {}

        def _flatten(prefix: str, d: dict, out: dict) -> dict:
            for k, v in d.items():
                kk = f"{prefix}.{k}" if prefix else k
                if isinstance(v, dict):
                    _flatten(kk, v, out)
                else:
                    out[kk] = v
            return out

        flat = _flatten("", config, {})
        return apply_dot_overrides(base, flat)
    raise TypeError("config must be a VAEConfig, dict, YAML path, or None.")
```
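The dict branch of `ensure_vae_config` flattens arbitrarily nested keys into dot-keys before handing them to `apply_dot_overrides`. A minimal standalone sketch of that flattening step (illustrative only, independent of the package):

```python
# Illustrative sketch: how a nested config dict collapses into the dot-keys
# that apply_dot_overrides consumes. Not part of the packaged module.
from typing import Dict


def flatten(prefix: str, d: dict, out: Dict[str, object]) -> Dict[str, object]:
    """Turn {"train": {"batch_size": 64}} into {"train.batch_size": 64}."""
    for k, v in d.items():
        kk = f"{prefix}.{k}" if prefix else k
        if isinstance(v, dict):
            flatten(kk, v, out)
        else:
            out[kk] = v
    return out


nested = {"train": {"batch_size": 64, "learning_rate": 1e-3}, "io": {"verbose": True}}
print(flatten("", nested, {}))
# {'train.batch_size': 64, 'train.learning_rate': 0.001, 'io.verbose': True}
```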
```python
class ImputeVAE(BaseNNImputer):
    """Variational Autoencoder imputer on 0/1/2 encodings (missing=-1).

    This imputer implements a VAE with a multinomial (categorical) latent space. It is designed to handle missing data by inferring the latent distribution and generating plausible predictions. The model is trained using a combination of reconstruction loss (cross-entropy) and a KL divergence term, with the KL weight (beta) annealed over time. The imputer supports both haploid and diploid genotype data.
    """

    def __init__(
        self,
        genotype_data: "GenotypeData",
        *,
        tree_parser: Optional["TreeParser"] = None,
        config: Optional[Union["VAEConfig", dict, str]] = None,
        overrides: dict | None = None,
        simulate_missing: bool | None = None,
        sim_strategy: (
            Literal[
                "random",
                "random_weighted",
                "random_weighted_inv",
                "nonrandom",
                "nonrandom_weighted",
            ]
            | None
        ) = None,
        sim_prop: float | None = None,
        sim_kwargs: dict | None = None,
    ):
        """Initialize the VAE imputer with a unified config interface.

        This initializer sets up the VAE imputer by processing the provided configuration, initializing logging, and preparing the model and data encoder. It supports configuration input as a dataclass, nested dictionary, YAML file path, or None, with optional dot-key overrides for fine-tuning specific parameters.

        Args:
            genotype_data (GenotypeData): Backing genotype data object.
            tree_parser (TreeParser | None): Optional SNPio phylogenetic tree parser for nonrandom sim_strategy modes.
            config (Union[VAEConfig, dict, str, None]): VAEConfig, nested dict, YAML path, or None (defaults).
            overrides (dict | None): Optional dot-key overrides with highest precedence.
            simulate_missing (bool | None): Whether to simulate missing data during training.
            sim_strategy (Literal[...] | None): Simulated missing strategy if simulating.
            sim_prop (float | None): Proportion of data to simulate as missing if simulating.
            sim_kwargs (dict | None): Additional kwargs for SimMissingTransformer.
        """
        self.model_name = "ImputeVAE"
        self.genotype_data = genotype_data
        self.tree_parser = tree_parser

        # Normalize configuration and apply top-precedence overrides
        cfg = ensure_vae_config(config)
        if overrides:
            cfg = apply_dot_overrides(cfg, overrides)
        self.cfg = cfg

        # Logger (align with AE/NLPCA)
        logman = LoggerManager(
            __name__,
            prefix=self.cfg.io.prefix,
            debug=self.cfg.io.debug,
            verbose=self.cfg.io.verbose,
        )
        self.logger = configure_logger(
            logman.get_logger(),
            verbose=self.cfg.io.verbose,
            debug=self.cfg.io.debug,
        )

        # BaseNNImputer bootstraps device/dirs/log formatting
        super().__init__(
            model_name=self.model_name,
            genotype_data=self.genotype_data,
            prefix=self.cfg.io.prefix,
            device=self.cfg.train.device,
            verbose=self.cfg.io.verbose,
            debug=self.cfg.io.debug,
        )

        # Model hook & encoder
        self.Model = VAEModel
        self.pgenc = GenotypeEncoder(genotype_data)

        # IO/global
        self.seed = self.cfg.io.seed
        self.n_jobs = self.cfg.io.n_jobs
        self.prefix = self.cfg.io.prefix
        self.scoring_averaging = self.cfg.io.scoring_averaging
        self.verbose = self.cfg.io.verbose
        self.debug = self.cfg.io.debug
        self.rng = np.random.default_rng(self.seed)

        # Simulated-missing controls (config defaults + ctor overrides)
        sim_cfg = getattr(self.cfg, "sim", None)
        sim_cfg_kwargs = copy.deepcopy(getattr(sim_cfg, "sim_kwargs", None) or {})
        if sim_kwargs:
            sim_cfg_kwargs.update(sim_kwargs)
        if sim_cfg is None:
            default_sim_flag = bool(simulate_missing)
            default_strategy = "random"
            default_prop = 0.10
        else:
            default_sim_flag = sim_cfg.simulate_missing
            default_strategy = sim_cfg.sim_strategy
            default_prop = sim_cfg.sim_prop
        self.simulate_missing = (
            default_sim_flag if simulate_missing is None else bool(simulate_missing)
        )
        self.sim_strategy = sim_strategy or default_strategy
        self.sim_prop = float(sim_prop if sim_prop is not None else default_prop)
        self.sim_kwargs = sim_cfg_kwargs

        # Model hyperparams (AE-parity)
        self.latent_dim = self.cfg.model.latent_dim
        self.dropout_rate = self.cfg.model.dropout_rate
        self.num_hidden_layers = self.cfg.model.num_hidden_layers
        self.layer_scaling_factor = self.cfg.model.layer_scaling_factor
        self.layer_schedule = self.cfg.model.layer_schedule
        self.activation = self.cfg.model.hidden_activation
        self.gamma = self.cfg.model.gamma  # focal loss focusing (for recon CE)

        # VAE-only KL controls
        self.kl_beta_final = self.cfg.vae.kl_beta
        self.kl_warmup = self.cfg.vae.kl_warmup
        self.kl_ramp = self.cfg.vae.kl_ramp

        # Train hyperparams (AE-parity)
        self.batch_size = self.cfg.train.batch_size
        self.learning_rate = self.cfg.train.learning_rate
        self.l1_penalty: float = self.cfg.train.l1_penalty
        self.early_stop_gen = self.cfg.train.early_stop_gen
        self.min_epochs = self.cfg.train.min_epochs
        self.epochs = self.cfg.train.max_epochs
        self.validation_split = self.cfg.train.validation_split
        self.beta = self.cfg.train.weights_beta
        self.max_ratio = self.cfg.train.weights_max_ratio

        # Tuning (AE-parity surface; VAE ignores latent refinement during eval)
        self.tune = self.cfg.tune.enabled
        self.tune_fast = self.cfg.tune.fast
        self.tune_batch_size = self.cfg.tune.batch_size
        self.tune_epochs = self.cfg.tune.epochs
        self.tune_eval_interval = self.cfg.tune.eval_interval
        self.tune_metric: Literal[
            "pr_macro",
            "f1",
            "accuracy",
            "average_precision",
            "precision",
            "recall",
            "roc_auc",
        ] = self.cfg.tune.metric
        self.n_trials = self.cfg.tune.n_trials
        self.tune_save_db = self.cfg.tune.save_db
        self.tune_resume = self.cfg.tune.resume
        self.tune_max_samples = self.cfg.tune.max_samples
        self.tune_max_loci = self.cfg.tune.max_loci
        self.tune_patience = self.cfg.tune.patience

        # Plotting (AE-parity)
        self.plot_format = self.cfg.plot.fmt
        self.plot_dpi = self.cfg.plot.dpi
        self.plot_fontsize = self.cfg.plot.fontsize
        self.title_fontsize = self.cfg.plot.fontsize
        self.despine = self.cfg.plot.despine
        self.show_plots = self.cfg.plot.show

        # Derived at fit-time
        self.is_haploid: bool = False
        self.num_classes_: int = 3  # diploid default
        self.model_params: Dict[str, Any] = {}
        self.sim_mask_global_: np.ndarray | None = None
        self.sim_mask_train_: np.ndarray | None = None
        self.sim_mask_test_: np.ndarray | None = None

        if self.tree_parser is None and self.sim_strategy.startswith("nonrandom"):
            msg = "tree_parser is required for nonrandom and nonrandom_weighted simulated missing strategies."
            self.logger.error(msg)
            raise ValueError(msg)
```
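Taken together with `fit()` and `transform()` below, the constructor gives a fairly small public surface. A hypothetical usage sketch (not part of the diff; it assumes `gd` is a SNPio `GenotypeData` object that has already been loaded, and the config keys mirror the `cfg.*` attributes read above):

```python
# Hypothetical usage sketch; assumes an existing SNPio GenotypeData object `gd`.
from pgsui.impute.unsupervised.imputers.vae import ImputeVAE

imputer = ImputeVAE(
    gd,
    config={"train": {"max_epochs": 200}, "io": {"prefix": "vae_run"}},  # nested dict form
    overrides={"io.verbose": True},  # dot-key overrides take highest precedence
    simulate_missing=True,           # hide some known genotypes so accuracy can be scored
    sim_strategy="random",
    sim_prop=0.1,
)
imputer.fit()                # trains the VAE and reports validation metrics
iupac = imputer.transform()  # (n_samples, n_loci) array of IUPAC genotype strings
```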
```python
    # -------------------- Fit -------------------- #
    def fit(self) -> "ImputeVAE":
        """Fit the VAE on 0/1/2 encoded genotypes (missing -> -1).

        This method prepares the genotype data, initializes model parameters, splits the data into training and validation sets, and trains the VAE model. Missing positions are encoded as -1 for loss masking (any simulated-missing loci are temporarily tagged with -9 by ``SimMissingTransformer`` before being re-encoded as -1). It handles both haploid and diploid data, applies class weighting, and supports optional hyperparameter tuning. After training, it evaluates the model on the validation set and saves the trained model.

        Returns:
            ImputeVAE: Fitted instance.

        Raises:
            RuntimeError: If training fails to produce a model.
        """
        self.logger.info(f"Fitting {self.model_name} model...")

        # Data prep aligns with AE/NLPCA
        X012 = self._get_float_genotypes(copy=True)
        GT_full = np.nan_to_num(X012, nan=-1.0, copy=True)
        self.ground_truth_ = GT_full.astype(np.int64, copy=False)

        self.sim_mask_global_ = None
        cache_key = self._sim_mask_cache_key()
        if self.simulate_missing:
            cached_mask = (
                None if cache_key is None else self._sim_mask_cache.get(cache_key)
            )
            if cached_mask is not None:
                self.sim_mask_global_ = cached_mask.copy()
            else:
                tr = SimMissingTransformer(
                    genotype_data=self.genotype_data,
                    tree_parser=self.tree_parser,
                    prop_missing=self.sim_prop,
                    strategy=self.sim_strategy,
                    missing_val=-9,
                    mask_missing=True,
                    verbose=self.verbose,
                    **self.sim_kwargs,
                )
                tr.fit(X012.copy())
                self.sim_mask_global_ = tr.sim_missing_mask_.astype(bool)
                if cache_key is not None:
                    self._sim_mask_cache[cache_key] = self.sim_mask_global_.copy()

            X_for_model = self.ground_truth_.copy()
            X_for_model[self.sim_mask_global_] = -1
        else:
            X_for_model = self.ground_truth_.copy()

        # Ploidy/classes
        self.is_haploid = bool(
            np.all(
                np.isin(
                    self.genotype_data.snp_data,
                    ["A", "C", "G", "T", "N", "-", ".", "?"],
                )
            )
        )
        self.ploidy = 1 if self.is_haploid else 2
        self.num_classes_ = 2 if self.is_haploid else 3
        self.logger.info(
            f"Data is {'haploid' if self.is_haploid else 'diploid'}; "
            f"using {self.num_classes_} classes."
        )

        if self.is_haploid:
            self.ground_truth_[self.ground_truth_ == 2] = 1
            X_for_model[X_for_model == 2] = 1

        n_samples, self.num_features_ = X_for_model.shape

        # Model params (decoder outputs L*K logits)
        self.model_params = {
            "n_features": self.num_features_,
            "num_classes": self.num_classes_,
            "latent_dim": self.latent_dim,
            "dropout_rate": self.dropout_rate,
            "activation": self.activation,
        }

        # Train/Val split
        indices = np.arange(n_samples)
        train_idx, val_idx = train_test_split(
            indices, test_size=self.validation_split, random_state=self.seed
        )
        self.train_idx_, self.test_idx_ = train_idx, val_idx
        self.X_train_ = X_for_model[train_idx]
        self.X_val_ = X_for_model[val_idx]
        self.GT_train_full_ = self.ground_truth_[train_idx]
        self.GT_test_full_ = self.ground_truth_[val_idx]

        if self.sim_mask_global_ is not None:
            self.sim_mask_train_ = self.sim_mask_global_[train_idx]
            self.sim_mask_test_ = self.sim_mask_global_[val_idx]
        else:
            self.sim_mask_train_ = None
            self.sim_mask_test_ = None

        # Plotters/scorers (shared utilities)
        self.plotter_, self.scorers_ = self.initialize_plotting_and_scorers()

        # Optional tuning
        if self.tune:
            self.tune_hyperparameters()

        # Best params (tuned or default)
        self.best_params_ = getattr(self, "best_params_", self._default_best_params())

        # Class weights (device-aware)
        self.class_weights_ = self._normalize_class_weights(
            self._class_weights_from_zygosity(self.X_train_)
        )

        # DataLoader
        train_loader = self._get_data_loader(self.X_train_)

        # Build & train
        model = self.build_model(self.Model, self.best_params_)
        model.apply(self.initialize_weights)

        loss, trained_model, history = self._train_and_validate_model(
            model=model,
            loader=train_loader,
            lr=self.learning_rate,
            l1_penalty=self.l1_penalty,
            return_history=True,
            class_weights=self.class_weights_,
            X_val=self.X_val_,
            params=self.best_params_,
            prune_metric=self.tune_metric,
            prune_warmup_epochs=5,
            eval_interval=1,
            eval_requires_latents=False,  # no latent refinement for eval
            eval_latent_steps=0,
            eval_latent_lr=0.0,
            eval_latent_weight_decay=0.0,
        )

        if trained_model is None:
            msg = "VAE training failed; no model was returned."
            self.logger.error(msg)
            raise RuntimeError(msg)

        torch.save(
            trained_model.state_dict(),
            self.models_dir / f"final_model_{self.model_name}.pt",
        )

        hist: dict = {"Train": history}
        self.best_loss_, self.model_, self.history_ = (loss, trained_model, hist)
        self.is_fit_ = True

        # Evaluate (AE-parity reporting)
        eval_mask = (
            self.sim_mask_test_
            if (self.simulate_missing and self.sim_mask_test_ is not None)
            else None
        )
        self._evaluate_model(
            self.X_val_,
            self.model_,
            self.best_params_,
            eval_mask_override=eval_mask,
        )

        self.plotter_.plot_history(self.history_)
        self._save_best_params(self.best_params_)
        return self
```
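The evaluation protocol in `fit()` rests on a simple idea: hide a fraction of the known 0/1/2 genotypes behind a simulated-missing mask, train on the masked matrix, and score predictions only at the hidden cells. Reduced to NumPy (the real `SimMissingTransformer` also supports weighted and phylogeny-aware strategies):

```python
# Minimal sketch of the simulated-missing evaluation protocol.
import numpy as np

rng = np.random.default_rng(42)
ground_truth = rng.integers(0, 3, size=(8, 20))    # 0/1/2 genotypes, fully observed

sim_mask = rng.random(ground_truth.shape) < 0.10   # hide ~10% of known cells
x_for_model = ground_truth.copy()
x_for_model[sim_mask] = -1                         # the model only ever sees -1 here

# ... train on x_for_model, then predict ...
pred = rng.integers(0, 3, size=ground_truth.shape)  # stand-in for model predictions

accuracy = (pred[sim_mask] == ground_truth[sim_mask]).mean()
print(f"imputation accuracy on simulated-missing cells: {accuracy:.3f}")
```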
```python
    def transform(self) -> np.ndarray:
        """Impute missing genotypes and return IUPAC strings.

        This method uses the trained VAE model to impute missing genotypes in the dataset. It predicts the most likely genotype for each missing entry based on the learned latent representations and fills in these values. The imputed genotypes are then decoded back to IUPAC string format for easy interpretation.

        Returns:
            np.ndarray: IUPAC strings of shape (n_samples, n_loci).

        Raises:
            NotFittedError: If called before fit().
        """
        if not getattr(self, "is_fit_", False):
            raise NotFittedError("Model is not fitted. Call fit() before transform().")

        self.logger.info(f"Imputing entire dataset with {self.model_name} model...")
        X_to_impute = self.ground_truth_.copy()

        pred_labels, _ = self._predict(self.model_, X=X_to_impute, return_proba=True)

        # Fill only missing
        missing_mask = X_to_impute == -1
        imputed_array = X_to_impute.copy()
        imputed_array[missing_mask] = pred_labels[missing_mask]

        # Decode to IUPAC & optionally plot
        imputed_genotypes = self.pgenc.decode_012(imputed_array)
        original_genotypes = self.pgenc.decode_012(X_to_impute)

        plt.rcParams.update(self.plotter_.param_dict)
        self.plotter_.plot_gt_distribution(original_genotypes, is_imputed=False)
        self.plotter_.plot_gt_distribution(imputed_genotypes, is_imputed=True)

        return imputed_genotypes

    # ---------- plumbing identical to AE, naming aligned ---------- #

    def _get_data_loader(self, y: np.ndarray) -> torch.utils.data.DataLoader:
        """Create DataLoader over indices + integer targets (-1 for missing).

        This method creates a PyTorch DataLoader for the training data. It converts the input genotype matrix into a tensor and constructs a dataset that includes both the indices and the genotype values. The DataLoader is configured to shuffle the data and use the specified batch size for training.

        Args:
            y (np.ndarray): 0/1/2 matrix with -1 for missing.

        Returns:
            torch.utils.data.DataLoader: Shuffled DataLoader.
        """
        y_tensor = torch.from_numpy(y).long()
        indices = torch.arange(len(y), dtype=torch.long)
        dataset = torch.utils.data.TensorDataset(indices, y_tensor)
        pin_memory = self.device.type == "cuda"
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=True,
            pin_memory=pin_memory,
        )

    def _train_and_validate_model(
        self,
        model: torch.nn.Module,
        loader: torch.utils.data.DataLoader,
        lr: float,
        l1_penalty: float,
        trial: optuna.Trial | None = None,
        return_history: bool = False,
        class_weights: torch.Tensor | None = None,
        *,
        X_val: np.ndarray | None = None,
        params: dict | None = None,
        prune_metric: str | None = None,  # "f1" | "accuracy" | "pr_macro"
        prune_warmup_epochs: int = 3,
        eval_interval: int = 1,
        eval_requires_latents: bool = False,  # VAE: no latent eval refinement
        eval_latent_steps: int = 0,
        eval_latent_lr: float = 0.0,
        eval_latent_weight_decay: float = 0.0,
    ) -> Tuple[float, torch.nn.Module | None, list | None]:
        """Wrap the VAE training loop with β-anneal & Optuna pruning.

        This method orchestrates the training of the VAE model, including setting up the optimizer and learning rate scheduler, and executing the training loop with support for early stopping and Optuna pruning. It manages the training process, monitors performance on a validation set if provided, and returns the best model and training history.

        Args:
            model (torch.nn.Module): VAE model.
            loader (torch.utils.data.DataLoader): Training data loader.
            lr (float): Learning rate.
            l1_penalty (float): L1 regularization coefficient.
            trial (optuna.Trial | None): Optuna trial for pruning.
            return_history (bool): If True, return training history.
            class_weights (torch.Tensor | None): CE class weights on device.
            X_val (np.ndarray | None): Validation data for pruning eval.
            params (dict | None): Current hyperparameters (for logging).
            prune_metric (str | None): Metric for pruning decisions.
            prune_warmup_epochs (int): Epochs to skip before pruning.
            eval_interval (int): Epochs between validation evaluations.
            eval_requires_latents (bool): If True, refine latents during eval.
            eval_latent_steps (int): Latent refinement steps if needed.
            eval_latent_lr (float): Latent refinement learning rate.
            eval_latent_weight_decay (float): Latent refinement L2 penalty.

        Returns:
            Tuple[float, torch.nn.Module | None, list | None]: Best loss, best model, and training history (if requested).
        """
        if class_weights is None:
            msg = "Must provide class_weights."
            self.logger.error(msg)
            raise TypeError(msg)

        max_epochs = (
            self.tune_epochs if (trial is not None and self.tune_fast) else self.epochs
        )

        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        scheduler = CosineAnnealingLR(optimizer, T_max=max_epochs)

        best_loss, best_model, hist = self._execute_training_loop(
            loader=loader,
            optimizer=optimizer,
            scheduler=scheduler,
            model=model,
            l1_penalty=l1_penalty,
            trial=trial,
            return_history=return_history,
            class_weights=class_weights,
            X_val=X_val,
            params=params,
            prune_metric=prune_metric,
            prune_warmup_epochs=prune_warmup_epochs,
            eval_interval=eval_interval,
            eval_requires_latents=eval_requires_latents,
            eval_latent_steps=eval_latent_steps,
            eval_latent_lr=eval_latent_lr,
            eval_latent_weight_decay=eval_latent_weight_decay,
        )
        if return_history:
            return best_loss, best_model, hist

        return best_loss, best_model, None

    def _execute_training_loop(
        self,
        loader: torch.utils.data.DataLoader,
        optimizer: torch.optim.Optimizer,
        scheduler: CosineAnnealingLR,
        model: torch.nn.Module,
        l1_penalty: float,
        trial: optuna.Trial | None,
        return_history: bool,
        class_weights: torch.Tensor,
        *,
        X_val: np.ndarray | None = None,
        params: dict | None = None,
        prune_metric: str | None = None,
        prune_warmup_epochs: int = 3,
        eval_interval: int = 1,
        eval_requires_latents: bool = False,
        eval_latent_steps: int = 0,
        eval_latent_lr: float = 0.0,
        eval_latent_weight_decay: float = 0.0,
    ) -> Tuple[float, torch.nn.Module, list]:
        """Train VAE with stable focal CE + KL(β) anneal and numeric guards.

        This method implements the core training loop for the VAE model, incorporating a focal cross-entropy loss for reconstruction and a KL divergence term with an annealed weight (beta). It includes mechanisms for early stopping based on validation performance, learning rate scheduling, and optional pruning of unpromising trials when using Optuna for hyperparameter optimization. The method ensures numerical stability throughout the training process.

        Args:
            loader (torch.utils.data.DataLoader): Training data loader.
            optimizer (torch.optim.Optimizer): Optimizer.
            scheduler (torch.optim.lr_scheduler._LRScheduler): Learning rate scheduler.
            model (torch.nn.Module): VAE model.
            l1_penalty (float): L1 regularization coefficient.
            trial (optuna.Trial | None): Optuna trial for pruning.
            return_history (bool): If True, return training history.
            class_weights (torch.Tensor): CE class weights on device.
            X_val (np.ndarray | None): Validation data for pruning eval.
            params (dict | None): Current hyperparameters (for logging).
            prune_metric (str | None): Metric for pruning decisions.
            prune_warmup_epochs (int): Epochs to skip before pruning.
            eval_interval (int): Epochs between validation evaluations.
            eval_requires_latents (bool): If True, refine latents during eval.
            eval_latent_steps (int): Latent refinement steps if needed.
            eval_latent_lr (float): Latent refinement learning rate.
            eval_latent_weight_decay (float): Latent refinement L2 penalty.

        Returns:
            Tuple[float, torch.nn.Module, list]: Best loss, best model, and training history.
        """
        best_model = None
        history: list[float] = []

        early_stopping = EarlyStopping(
            patience=self.early_stop_gen,
            min_epochs=self.min_epochs,
            verbose=self.verbose,
            prefix=self.prefix,
            debug=self.debug,
        )

        # ---- scalarize schedule endpoints up front ----
        kl = self.kl_beta_final
        if isinstance(kl, (list, tuple)):
            if len(kl) == 0:
                msg = "kl_beta_final list is empty."
                self.logger.error(msg)
                raise ValueError(msg)
            kl = kl[0]
        beta_final = float(kl)

        gamma_val = self.gamma
        if isinstance(gamma_val, (list, tuple)):
            if len(gamma_val) == 0:
                raise ValueError("gamma list is empty.")
            gamma_val = gamma_val[0]
        gamma_final = float(gamma_val)

        gamma_warm, gamma_ramp = 50, 100
        beta_warm, beta_ramp = int(self.kl_warmup), int(self.kl_ramp)

        # Optional LR warmup
        warmup_epochs = int(getattr(self, "lr_warmup_epochs", 5))
        base_lr = float(optimizer.param_groups[0]["lr"])
        min_lr = base_lr * 0.1

        max_epochs = int(getattr(scheduler, "T_max", getattr(self, "epochs", 100)))

        for epoch in range(max_epochs):
            # focal γ schedule
            if epoch < gamma_warm:
                model.gamma = 0.0  # type: ignore[attr-defined]
            elif epoch < gamma_warm + gamma_ramp:
                model.gamma = gamma_final * ((epoch - gamma_warm) / gamma_ramp)  # type: ignore[attr-defined]
            else:
                model.gamma = gamma_final  # type: ignore[attr-defined]

            # KL β schedule (float throughout + ramp guard)
            if epoch < beta_warm:
                model.beta = 0.0  # type: ignore[attr-defined]
            elif beta_ramp > 0 and epoch < beta_warm + beta_ramp:
                model.beta = beta_final * ((epoch - beta_warm) / beta_ramp)  # type: ignore[attr-defined]
            else:
                model.beta = beta_final  # type: ignore[attr-defined]
            # LR warmup
            if epoch < warmup_epochs:
                scale = float(epoch + 1) / warmup_epochs
                for g in optimizer.param_groups:
                    g["lr"] = min_lr + (base_lr - min_lr) * scale

            train_loss = self._train_step(
                loader=loader,
                optimizer=optimizer,
                model=model,
                l1_penalty=l1_penalty,
                class_weights=class_weights,
            )

            if not np.isfinite(train_loss):
                if trial:
                    raise optuna.exceptions.TrialPruned("Epoch loss non-finite.")
                # shrink LR and continue
                for g in optimizer.param_groups:
                    g["lr"] *= 0.5
                continue

            if scheduler is not None:
                scheduler.step()

            if return_history:
                history.append(train_loss)

            early_stopping(train_loss, model)
            if early_stopping.early_stop:
                self.logger.info(f"Early stopping at epoch {epoch + 1}.")
                break

            # Optional Optuna pruning on a validation metric
            if (
                trial is not None
                and X_val is not None
                and ((epoch + 1) % eval_interval == 0)
            ):
                metric_key = prune_metric or getattr(self, "tune_metric", "f1")
                mask_override = None
                if (
                    self.simulate_missing
                    and getattr(self, "sim_mask_test_", None) is not None
                    and getattr(self, "X_val_", None) is not None
                    and X_val.shape == self.X_val_.shape
                ):
                    mask_override = self.sim_mask_test_
                metric_val = self._eval_for_pruning(
                    model=model,
                    X_val=X_val,
                    params=params or getattr(self, "best_params_", {}),
                    metric=metric_key,
                    objective_mode=True,
                    do_latent_infer=False,  # VAE uses encoder; no latent refine
                    latent_steps=0,
                    latent_lr=0.0,
                    latent_weight_decay=0.0,
                    latent_seed=self.seed,  # type: ignore
                    _latent_cache=None,
                    _latent_cache_key=None,
                    eval_mask_override=mask_override,
                )
                trial.report(metric_val, step=epoch + 1)
                if (epoch + 1) >= prune_warmup_epochs and trial.should_prune():
                    raise optuna.exceptions.TrialPruned(
                        f"Pruned at epoch {epoch + 1}: {metric_key}={metric_val:.5f}"
                    )

        best_loss = early_stopping.best_score
        best_model = copy.deepcopy(early_stopping.best_model)
        if best_model is None:
            best_model = copy.deepcopy(model)
        return best_loss, best_model, history
```
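Both the focal γ and the KL β follow the same warmup-then-linear-ramp schedule inside the epoch loop above (γ uses the hard-coded `gamma_warm, gamma_ramp = 50, 100`; β uses `kl_warmup`/`kl_ramp` from the config). Expressed as a pure function:

```python
# The per-epoch warmup/ramp schedule from _execute_training_loop, factored out
# for clarity; the example numbers below are placeholders, not package defaults.
def annealed_weight(epoch: int, final: float, warmup: int, ramp: int) -> float:
    """0 during warmup, linear ramp to `final` over `ramp` epochs, then constant."""
    if epoch < warmup:
        return 0.0
    if ramp > 0 and epoch < warmup + ramp:
        return final * ((epoch - warmup) / ramp)
    return final


print([annealed_weight(e, 1.0, 10, 40) for e in (0, 10, 20, 30, 50)])
# [0.0, 0.0, 0.25, 0.5, 1.0]
```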
```python
    def _train_step(
        self,
        loader: torch.utils.data.DataLoader,
        optimizer: torch.optim.Optimizer,
        model: torch.nn.Module,
        l1_penalty: float,
        class_weights: torch.Tensor,
    ) -> float:
        """One epoch: inputs → VAE forward → focal CE + β·KL with guards.

        This method performs a single training epoch for the VAE model. It processes batches of data, computes the reconstruction and KL divergence losses, applies L1 regularization if specified, and updates the model parameters. The method includes safeguards against non-finite values in the model outputs and gradients to ensure stable training.

        Args:
            loader (torch.utils.data.DataLoader): Training data loader.
            optimizer (torch.optim.Optimizer): Optimizer.
            model (torch.nn.Module): VAE model.
            l1_penalty (float): L1 regularization coefficient.
            class_weights (torch.Tensor): CE class weights on device.

        Returns:
            float: Average training loss for the epoch.
        """
        model.train()
        running, used = 0.0, 0
        l1_params = tuple(p for p in model.parameters() if p.requires_grad)
        if class_weights is not None and class_weights.device != self.device:
            class_weights = class_weights.to(self.device)

        for _, y_batch in loader:
            optimizer.zero_grad(set_to_none=True)

            # targets: (B, L) int in {0,1,2,-1}
            y_int = y_batch.to(self.device, non_blocking=True).long()

            # inputs: one-hot with zeros for missing
            x_ohe = self._one_hot_encode_012(y_int)  # (B, L, K)

            # Forward. Expect model to return recon_logits, mu, logvar, ...
            out = model(x_ohe)
            if isinstance(out, (list, tuple)):
                recon_logits, mu, logvar = out[0], out[1], out[2]
            else:
                recon_logits, mu, logvar = out["recon_logits"], out["mu"], out["logvar"]

            # Upstream guard
            if (
                not torch.isfinite(recon_logits).all()
                or not torch.isfinite(mu).all()
                or not torch.isfinite(logvar).all()
            ):
                continue

            gamma = float(getattr(model, "gamma", getattr(self, "gamma", 0.0)))
            beta = float(getattr(model, "beta", getattr(self, "kl_beta_final", 0.0)))
            gamma = max(0.0, min(gamma, 10.0))

            loss = compute_vae_loss(
                recon_logits=recon_logits,
                targets=y_int,
                mu=mu,
                logvar=logvar,
                class_weights=class_weights,
                gamma=gamma,
                beta=beta,
            )

            if l1_penalty > 0:
                l1 = torch.zeros((), device=self.device)
                for p in l1_params:
                    l1 = l1 + p.abs().sum()
                loss = loss + l1_penalty * l1

            if not torch.isfinite(loss):
                continue

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            # skip update if any grad is non-finite
            bad = any(
                p.grad is not None and not torch.isfinite(p.grad).all()
                for p in model.parameters()
            )
            if bad:
                optimizer.zero_grad(set_to_none=True)
                continue

            optimizer.step()

            running += float(loss.detach().item())
            used += 1

        return (running / used) if used > 0 else float("inf")
```
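`compute_vae_loss` itself lives in `pgsui/impute/unsupervised/loss_functions.py` (also new in this release) and is not shown in this hunk. The sketch below is only the standard form its arguments suggest — class-weighted focal cross-entropy over non-missing targets plus a β-weighted Gaussian KL term — not the package's exact implementation:

```python
# Hedged sketch of a masked focal-CE + beta*KL objective with the same argument
# shape as compute_vae_loss; the packaged implementation may differ in detail.
import torch
import torch.nn.functional as F


def vae_loss_sketch(recon_logits, targets, mu, logvar, class_weights, gamma, beta):
    valid = targets != -1                      # ignore missing cells entirely
    logits = recon_logits[valid]               # (n_valid, K)
    y = targets[valid]                         # (n_valid,)

    ce = F.cross_entropy(logits, y, weight=class_weights, reduction="none")
    pt = torch.exp(-ce)                        # probability of the true class
    focal_ce = ((1.0 - pt) ** gamma * ce).mean()

    # KL(q(z|x) || N(0, I)), averaged over the batch
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1).mean()
    return focal_ce + beta * kld
```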
```python
    def _predict(
        self,
        model: torch.nn.Module,
        X: np.ndarray | torch.Tensor,
        return_proba: bool = False,
    ) -> Tuple[np.ndarray, np.ndarray] | np.ndarray:
        """Predict 0/1/2 labels (and probabilities) from masked inputs.

        This method uses the trained VAE model to predict genotype labels for the provided input data. It processes the input data, performs a forward pass through the model, and computes the predicted labels and probabilities. The method can return either just the predicted labels or both labels and probabilities based on the `return_proba` flag.

        Args:
            model (torch.nn.Module): Trained model.
            X (np.ndarray | torch.Tensor): 0/1/2 matrix with -1 for missing.
            return_proba (bool): If True, also return probabilities.

        Returns:
            Tuple[np.ndarray, np.ndarray] | np.ndarray: Predicted labels, and probabilities if requested.
        """
        if model is None:
            msg = "Model is not trained. Call fit() before predict()."
            self.logger.error(msg)
            raise NotFittedError(msg)

        model.eval()
        with torch.no_grad():
            X_tensor = torch.from_numpy(X) if isinstance(X, np.ndarray) else X
            X_tensor = X_tensor.to(self.device).long()
            x_ohe = self._one_hot_encode_012(X_tensor)
            outputs = model(x_ohe)  # first element must be recon logits
            logits = outputs[0].view(-1, self.num_features_, self.num_classes_)
            probas = torch.softmax(logits, dim=-1)
            labels = torch.argmax(probas, dim=-1)

        if return_proba:
            return labels.cpu().numpy(), probas.cpu().numpy()

        return labels.cpu().numpy()

    def _evaluate_model(
        self,
        X_val: np.ndarray,
        model: torch.nn.Module,
        params: dict,
        objective_mode: bool = False,
        latent_vectors_val: np.ndarray | None = None,
        *,
        eval_mask_override: np.ndarray | None = None,
    ) -> Dict[str, float]:
        """Evaluate on 0/1/2; then IUPAC decoding and 10-base integer reports.

        This method evaluates the trained VAE model on a validation dataset, computing various performance metrics. It handles missing data appropriately and generates detailed classification reports for both the original 0/1/2 encoding and the decoded IUPAC and integer formats. The evaluation metrics are logged for review.

        Args:
            X_val (np.ndarray): Validation 0/1/2 matrix with -1 for missing.
            model (torch.nn.Module): Trained model.
            params (dict): Current hyperparameters (for logging).
            objective_mode (bool): If True, minimize logging for Optuna.
            latent_vectors_val (np.ndarray | None): Not used by VAE.
            eval_mask_override (np.ndarray | None): Optional mask to override default eval mask.

        Returns:
            Dict[str, float]: Computed metrics.

        Raises:
            NotFittedError: If called before fit().
        """
        pred_labels, pred_probas = self._predict(
            model=model, X=X_val, return_proba=True
        )

        finite_mask = np.all(np.isfinite(pred_probas), axis=-1)  # (N, L)

        # FIX 1: Match rows (shape[0]) only to allow feature subsets (tune_fast)
        if (
            hasattr(self, "X_val_")
            and getattr(self, "X_val_", None) is not None
            and X_val.shape[0] == self.X_val_.shape[0]
        ):
            GT_ref = getattr(self, "GT_test_full_", self.ground_truth_)
        elif (
            hasattr(self, "X_train_")
            and getattr(self, "X_train_", None) is not None
            and X_val.shape[0] == self.X_train_.shape[0]
        ):
            GT_ref = getattr(self, "GT_train_full_", self.ground_truth_)
        else:
            GT_ref = self.ground_truth_

        # FIX 2: Handle Feature Mismatch
        # If GT has more columns than X_val, slice it to match.
        if GT_ref.shape[1] > X_val.shape[1]:
            GT_ref = GT_ref[:, : X_val.shape[1]]

        # Fallback safeguard
        if GT_ref.shape != X_val.shape:
            GT_ref = X_val

        # FIX 3: Allow override mask to be sliced if it's too wide
        if eval_mask_override is not None:
            if eval_mask_override.shape[0] != X_val.shape[0]:
                msg = (
                    f"eval_mask_override rows {eval_mask_override.shape[0]} "
                    f"does not match X_val rows {X_val.shape[0]}"
                )
                self.logger.error(msg)
                raise ValueError(msg)

            if eval_mask_override.shape[1] > X_val.shape[1]:
                eval_mask = eval_mask_override[:, : X_val.shape[1]].astype(bool)
            else:
                eval_mask = eval_mask_override.astype(bool)
        else:
            eval_mask = X_val != -1

        eval_mask = eval_mask & finite_mask & (GT_ref != -1)

        y_true_flat = GT_ref[eval_mask].astype(np.int64, copy=False)
        y_pred_flat = pred_labels[eval_mask].astype(np.int64, copy=False)
        y_proba_flat = pred_probas[eval_mask].astype(np.float64, copy=False)

        if y_true_flat.size == 0:
            return {self.tune_metric: 0.0}

        # ensure valid probability simplex after masking
        y_proba_flat = np.clip(y_proba_flat, 0.0, 1.0)
        row_sums = y_proba_flat.sum(axis=1, keepdims=True)
        row_sums[row_sums == 0] = 1.0
        y_proba_flat = y_proba_flat / row_sums

        labels_for_scoring = [0, 1] if self.is_haploid else [0, 1, 2]
        target_names = ["REF", "ALT"] if self.is_haploid else ["REF", "HET", "ALT"]

        if self.is_haploid:
            y_true_flat = y_true_flat.copy()
            y_pred_flat = y_pred_flat.copy()
            y_true_flat[y_true_flat == 2] = 1
            y_pred_flat[y_pred_flat == 2] = 1
            proba_2 = np.zeros((len(y_proba_flat), 2), dtype=y_proba_flat.dtype)
            proba_2[:, 0] = y_proba_flat[:, 0]
            proba_2[:, 1] = y_proba_flat[:, 2]
            y_proba_flat = proba_2

        y_true_ohe = np.eye(len(labels_for_scoring))[y_true_flat]

        metrics = self.scorers_.evaluate(
            y_true_flat,
            y_pred_flat,
            y_true_ohe,
            y_proba_flat,
            objective_mode,
            self.tune_metric,
        )

        if not objective_mode:
            pm = PrettyMetrics(
                metrics, precision=3, title=f"{self.model_name} Validation Metrics"
            )
            pm.render()

            # Primary report
            self._make_class_reports(
                y_true=y_true_flat,
                y_pred_proba=y_proba_flat,
                y_pred=y_pred_flat,
                metrics=metrics,
                labels=target_names,
            )

            # IUPAC decode & 10-base integer report
            # FIX 4: Use current shape (X_val.shape) not self.num_features_
            y_true_dec = self.pgenc.decode_012(
                GT_ref.reshape(X_val.shape[0], X_val.shape[1])
            )
            X_pred = X_val.copy()
            X_pred[eval_mask] = y_pred_flat
            y_pred_dec = self.pgenc.decode_012(
                X_pred.reshape(X_val.shape[0], X_val.shape[1])
            )

            encodings_dict = {
                "A": 0,
                "C": 1,
                "G": 2,
                "T": 3,
                "W": 4,
                "R": 5,
                "M": 6,
                "K": 7,
                "Y": 8,
                "S": 9,
                "N": -1,
            }
            y_true_int = self.pgenc.convert_int_iupac(
                y_true_dec, encodings_dict=encodings_dict
            )
            y_pred_int = self.pgenc.convert_int_iupac(
                y_pred_dec, encodings_dict=encodings_dict
            )

            valid_iupac_mask = y_true_int[eval_mask] >= 0
            if valid_iupac_mask.any():
                self._make_class_reports(
                    y_true=y_true_int[eval_mask][valid_iupac_mask],
                    y_pred=y_pred_int[eval_mask][valid_iupac_mask],
                    metrics=metrics,
                    y_pred_proba=None,
                    labels=["A", "C", "G", "T", "W", "R", "M", "K", "Y", "S"],
                )
            else:
                self.logger.warning(
                    "Skipped IUPAC confusion matrix: No valid ground truths."
                )

        return metrics
```
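One detail worth calling out from `_evaluate_model`: for haploid data the three-class head is collapsed to REF/ALT before scoring — label 2 folds into 1 and the HET probability column is simply dropped. In isolation:

```python
# Haploid collapse used before scoring: keep the REF and ALT probability columns.
import numpy as np

proba = np.array([[0.7, 0.1, 0.2],   # P(REF), P(HET), P(ALT) from the 3-class head
                  [0.2, 0.1, 0.7]])
haploid_proba = proba[:, [0, 2]]     # columns for classes 0 and 2 only
print(haploid_proba)
```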
```python
    def _objective(self, trial: optuna.Trial) -> float:
        """Optuna objective for VAE (no latent refinement during eval).

        This method defines the objective function for hyperparameter tuning using Optuna. It samples hyperparameters, trains the VAE model with these parameters, and evaluates its performance on a validation set. The evaluation metric specified by `self.tune_metric` is returned for optimization. If training fails, the trial is pruned to keep the tuning process efficient.

        Args:
            trial (optuna.Trial): Optuna trial object.

        Returns:
            float: Value of the tuning metric to be optimized.
        """
        try:
            params = self._sample_hyperparameters(trial)

            X_train = getattr(self, "X_train_", self.ground_truth_[self.train_idx_])
            X_val = getattr(self, "X_val_", self.ground_truth_[self.test_idx_])

            class_weights = self._normalize_class_weights(
                self._class_weights_from_zygosity(X_train)
            )
            train_loader = self._get_data_loader(X_train)

            model = self.build_model(self.Model, params["model_params"])
            model.apply(self.initialize_weights)

            lr: float = params["lr"]
            l1_penalty: float = params["l1_penalty"]

            # Train + prune on metric
            _, model, __ = self._train_and_validate_model(
                model=model,
                loader=train_loader,
                lr=lr,
                l1_penalty=l1_penalty,
                trial=trial,
                return_history=False,
                class_weights=class_weights,
                X_val=X_val,
                params=params,
                prune_metric=self.tune_metric,
                prune_warmup_epochs=5,
                eval_interval=self.tune_eval_interval,
                eval_requires_latents=False,
                eval_latent_steps=0,
                eval_latent_lr=0.0,
                eval_latent_weight_decay=0.0,
            )

            eval_mask = (
                self.sim_mask_test_
                if (
                    self.simulate_missing
                    and getattr(self, "sim_mask_test_", None) is not None
                )
                else None
            )

            if model is None:
                raise RuntimeError("Model training failed; no model was returned.")

            metrics = self._evaluate_model(
                X_val, model, params, objective_mode=True, eval_mask_override=eval_mask
            )
            self._clear_resources(model, train_loader)
            return metrics[self.tune_metric]

        except Exception as e:
            # Keep sweeps moving
            self.logger.debug(f"Trial failed with error: {e}")
            raise optuna.exceptions.TrialPruned(
                f"Trial failed with error. Enable debug logging for details."
            )

    def _sample_hyperparameters(self, trial: optuna.Trial) -> dict:
        """Sample VAE hyperparameters; hidden sizes mirror AE/NLPCA helper.

        Args:
            trial (optuna.Trial): Optuna trial object.

        Returns:
            Dict[str, int | float | str]: Sampled hyperparameters.
        """
        params = {
            "latent_dim": trial.suggest_int("latent_dim", 2, 64),
            "lr": trial.suggest_float("learning_rate", 1e-5, 1e-2, log=True),
            "dropout_rate": trial.suggest_float("dropout_rate", 0.0, 0.6),
            "num_hidden_layers": trial.suggest_int("num_hidden_layers", 1, 8),
            "activation": trial.suggest_categorical(
                "activation", ["relu", "elu", "selu"]
            ),
            "l1_penalty": trial.suggest_float("l1_penalty", 1e-7, 1e-2, log=True),
            "layer_scaling_factor": trial.suggest_float(
                "layer_scaling_factor", 2.0, 10.0
            ),
            "layer_schedule": trial.suggest_categorical(
                "layer_schedule", ["pyramid", "constant", "linear"]
            ),
            # VAE-specific β (final value after anneal)
            "beta": trial.suggest_float("beta", 0.25, 4.0),
            # focal gamma (if used in VAE recon CE)
            "gamma": trial.suggest_float("gamma", 0.0, 5.0),
        }

        input_dim = self.num_features_ * self.num_classes_
        hidden_layer_sizes = self._compute_hidden_layer_sizes(
            n_inputs=input_dim,
            n_outputs=input_dim,
            n_samples=len(self.train_idx_),
            n_hidden=params["num_hidden_layers"],
            alpha=params["layer_scaling_factor"],
            schedule=params["layer_schedule"],
        )

        # [latent_dim] + interior widths (exclude output width)
        hidden_only = [hidden_layer_sizes[0]] + hidden_layer_sizes[1:-1]

        params["model_params"] = {
            "n_features": self.num_features_,
            "num_classes": self.num_classes_,
            "latent_dim": params["latent_dim"],
            "dropout_rate": params["dropout_rate"],
            "hidden_layer_sizes": hidden_only,
            "activation": params["activation"],
            # Pass through VAE recon/regularization coefficients
            "beta": params["beta"],
            "gamma": params["gamma"],
        }
        return params
```
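The Optuna study itself (sampler, pruner, optional storage via `tune_save_db`/`tune_resume`) is created by `tune_hyperparameters()` in `BaseNNImputer`, which is outside this file. Purely to illustrate how `_objective` and the per-epoch `trial.report(...)` calls plug into the Optuna API, a generic sketch:

```python
# Generic Optuna wiring sketch (illustrative; the package's actual study setup
# lives in BaseNNImputer.tune_hyperparameters and may differ).
import optuna

study = optuna.create_study(
    direction="maximize",  # tune_metric values such as f1 or pr_macro are scores
    pruner=optuna.pruners.MedianPruner(n_warmup_steps=5),
)
# study.optimize(imputer._objective, n_trials=imputer.n_trials)
# imputer._set_best_params(study.best_params)  # adopt the winning configuration
```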
```python
    def _set_best_params(self, best_params: dict) -> dict:
        """Adopt best params and return VAE model_params.

        Args:
            best_params (dict): Best hyperparameters from tuning.

        Returns:
            dict: VAE model parameters.
        """
        self.latent_dim = best_params["latent_dim"]
        self.dropout_rate = best_params["dropout_rate"]
        self.learning_rate = best_params["learning_rate"]
        self.l1_penalty = best_params["l1_penalty"]
        self.activation = best_params["activation"]
        self.layer_scaling_factor = best_params["layer_scaling_factor"]
        self.layer_schedule = best_params["layer_schedule"]
        self.kl_beta_final = best_params.get("beta", self.kl_beta_final)
        self.gamma = best_params.get("gamma", self.gamma)

        hidden_layer_sizes = self._compute_hidden_layer_sizes(
            n_inputs=self.num_features_ * self.num_classes_,
            n_outputs=self.num_features_ * self.num_classes_,
            n_samples=len(self.train_idx_),
            n_hidden=best_params["num_hidden_layers"],
            alpha=best_params["layer_scaling_factor"],
            schedule=best_params["layer_schedule"],
        )
        hidden_only = [hidden_layer_sizes[0]] + hidden_layer_sizes[1:-1]

        return {
            "n_features": self.num_features_,
            "latent_dim": self.latent_dim,
            "hidden_layer_sizes": hidden_only,
            "dropout_rate": self.dropout_rate,
            "activation": self.activation,
            "num_classes": self.num_classes_,
            "beta": self.kl_beta_final,
            "gamma": self.gamma,
        }

    def _default_best_params(self) -> Dict[str, int | float | str | list]:
        """Default VAE model params when tuning is disabled.

        Returns:
            Dict[str, int | float | str | list]: VAE model parameters.
        """
        hidden_layer_sizes = self._compute_hidden_layer_sizes(
            n_inputs=self.num_features_ * self.num_classes_,
            n_outputs=self.num_features_ * self.num_classes_,
            n_samples=len(self.ground_truth_),
            n_hidden=self.num_hidden_layers,
            alpha=self.layer_scaling_factor,
            schedule=self.layer_schedule,
        )
        return {
            "n_features": self.num_features_,
            "latent_dim": self.latent_dim,
            "hidden_layer_sizes": hidden_layer_sizes,
            "dropout_rate": self.dropout_rate,
            "activation": self.activation,
            "num_classes": self.num_classes_,
            "beta": self.kl_beta_final,
            "gamma": self.gamma,
        }
```