pg-sui 1.6.14.dev9__py3-none-any.whl → 1.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pg_sui-1.7.0.dist-info/METADATA +288 -0
- {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/RECORD +29 -33
- pgsui/__init__.py +0 -8
- pgsui/_version.py +2 -2
- pgsui/cli.py +591 -126
- pgsui/data_processing/config.py +1 -2
- pgsui/data_processing/containers.py +218 -533
- pgsui/data_processing/transformers.py +44 -20
- pgsui/impute/deterministic/imputers/mode.py +475 -182
- pgsui/impute/deterministic/imputers/ref_allele.py +454 -147
- pgsui/impute/supervised/imputers/hist_gradient_boosting.py +4 -3
- pgsui/impute/supervised/imputers/random_forest.py +3 -2
- pgsui/impute/unsupervised/base.py +1268 -530
- pgsui/impute/unsupervised/callbacks.py +28 -33
- pgsui/impute/unsupervised/imputers/autoencoder.py +869 -764
- pgsui/impute/unsupervised/imputers/vae.py +928 -696
- pgsui/impute/unsupervised/loss_functions.py +156 -202
- pgsui/impute/unsupervised/models/autoencoder_model.py +7 -49
- pgsui/impute/unsupervised/models/vae_model.py +40 -221
- pgsui/impute/unsupervised/nn_scorers.py +53 -13
- pgsui/utils/classification_viz.py +240 -97
- pgsui/utils/misc.py +201 -3
- pgsui/utils/plotting.py +73 -58
- pgsui/utils/pretty_metrics.py +2 -6
- pgsui/utils/scorers.py +39 -0
- pg_sui-1.6.14.dev9.dist-info/METADATA +0 -344
- pgsui/impute/unsupervised/imputers/nlpca.py +0 -1554
- pgsui/impute/unsupervised/imputers/ubp.py +0 -1575
- pgsui/impute/unsupervised/models/nlpca_model.py +0 -206
- pgsui/impute/unsupervised/models/ubp_model.py +0 -200
- {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/WHEEL +0 -0
- {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/entry_points.txt +0 -0
- {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/licenses/LICENSE +0 -0
- {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/top_level.txt +0 -0
@@ -1,24 +1,25 @@
 from __future__ import annotations
 
 import copy
-
+import traceback
+from collections import defaultdict
+from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union, cast
 
 import matplotlib.pyplot as plt
 import numpy as np
 import optuna
 import torch
+import torch.nn.functional as F
 from sklearn.exceptions import NotFittedError
-from sklearn.model_selection import train_test_split
 from snpio.analysis.genotype_encoder import GenotypeEncoder
 from snpio.utils.logging import LoggerManager
 from torch.optim.lr_scheduler import CosineAnnealingLR
 
 from pgsui.data_processing.config import apply_dot_overrides, load_yaml_to_dataclass
 from pgsui.data_processing.containers import VAEConfig
-from pgsui.data_processing.transformers import SimMissingTransformer
 from pgsui.impute.unsupervised.base import BaseNNImputer
 from pgsui.impute.unsupervised.callbacks import EarlyStopping
-from pgsui.impute.unsupervised.loss_functions import compute_vae_loss
+from pgsui.impute.unsupervised.loss_functions import FocalCELoss, compute_vae_loss
 from pgsui.impute.unsupervised.models.vae_model import VAEModel
 from pgsui.utils.logging_utils import configure_logger
 from pgsui.utils.pretty_metrics import PrettyMetrics
@@ -28,14 +29,47 @@ if TYPE_CHECKING:
     from snpio.read_input.genotype_data import GenotypeData
 
 
-def
-
+def _make_warmup_cosine_scheduler(
+    optimizer: torch.optim.Optimizer,
+    *,
+    max_epochs: int,
+    warmup_epochs: int,
+    start_factor: float = 0.1,
+) -> torch.optim.lr_scheduler.CosineAnnealingLR | torch.optim.lr_scheduler.SequentialLR:
+    """Create a warmup->cosine LR scheduler.
 
     Args:
-
+        optimizer (torch.optim.Optimizer): The optimizer to schedule.
+        max_epochs (int): Total number of epochs for training.
+        warmup_epochs (int): Number of warmup epochs.
+        start_factor (float): Initial LR factor for warmup.
 
     Returns:
-
+        torch.optim.lr_scheduler.CosineAnnealingLR | torch.optim.lr_scheduler.SequentialLR: The learning rate scheduler.
+    """
+    warmup_epochs = int(max(0, warmup_epochs))
+
+    if warmup_epochs == 0:
+        return CosineAnnealingLR(optimizer, T_max=max_epochs)
+
+    warmup = torch.optim.lr_scheduler.LinearLR(
+        optimizer, start_factor=float(start_factor), total_iters=warmup_epochs
+    )
+    cosine = CosineAnnealingLR(optimizer, T_max=max(1, max_epochs - warmup_epochs))
+
+    return torch.optim.lr_scheduler.SequentialLR(
+        optimizer, schedulers=[warmup, cosine], milestones=[warmup_epochs]
+    )
+
+
+def ensure_vae_config(config: VAEConfig | dict | str | None) -> VAEConfig:
+    """Ensure a VAEConfig instance from various input types.
+
+    Args:
+        config (VAEConfig | dict | str | None): Configuration input.
+
+    Returns:
+        VAEConfig: The resulting VAEConfig instance.
     """
     if config is None:
         return VAEConfig()
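For orientation, the `_make_warmup_cosine_scheduler` helper added in the hunk above composes two standard PyTorch schedulers: a `LinearLR` warmup chained into `CosineAnnealingLR` via `SequentialLR`. The standalone sketch below is not taken from the package; the tiny linear model and the epoch counts are illustrative assumptions, and it only shows the resulting learning-rate trajectory.

    import torch
    from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR

    model = torch.nn.Linear(4, 2)  # placeholder parameters, assumed for the example
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    max_epochs, warmup_epochs = 20, 2  # assumed; the diff derives warmup as int(0.1 * max_epochs)
    warmup = LinearLR(optimizer, start_factor=0.1, total_iters=warmup_epochs)
    cosine = CosineAnnealingLR(optimizer, T_max=max_epochs - warmup_epochs)
    scheduler = SequentialLR(optimizer, schedulers=[warmup, cosine], milestones=[warmup_epochs])

    for epoch in range(max_epochs):
        optimizer.step()   # stands in for one training epoch
        scheduler.step()   # LR ramps 1e-4 -> 1e-3 during warmup, then follows a cosine decay
        print(epoch, scheduler.get_last_lr()[0])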
@@ -44,13 +78,13 @@ def ensure_vae_config(config: Union[VAEConfig, dict, str, None]) -> VAEConfig:
     if isinstance(config, str):
         return load_yaml_to_dataclass(config, VAEConfig)
     if isinstance(config, dict):
+        cfg_in = copy.deepcopy(config)
         base = VAEConfig()
-
-
+        preset = cfg_in.pop("preset", None)
+        if "io" in cfg_in and isinstance(cfg_in["io"], dict):
+            preset = preset or cfg_in["io"].pop("preset", None)
         if preset:
             base = VAEConfig.from_preset(preset)
-        # Flatten + apply
-        flat: Dict[str, object] = {}
 
         def _flatten(prefix: str, d: dict, out: dict) -> dict:
             for k, v in d.items():
@@ -61,15 +95,19 @@ def ensure_vae_config(config: Union[VAEConfig, dict, str, None]) -> VAEConfig:
                     out[kk] = v
             return out
 
-        flat = _flatten("",
+        flat = _flatten("", cfg_in, {})
         return apply_dot_overrides(base, flat)
     raise TypeError("config must be a VAEConfig, dict, YAML path, or None.")
 
 
 class ImputeVAE(BaseNNImputer):
-    """Variational Autoencoder imputer
+    """Variational Autoencoder (VAE) imputer for 0/1/2 genotypes.
+
+    Trains a VAE on a genotype matrix encoded as 0/1/2 with missing values represented by any negative integer. The workflow simulates missingness once on the full matrix, then creates train/val/test splits. It supports haploid and diploid data, focal-CE reconstruction loss with a KL term (optional scheduling), and Optuna-based hyperparameter tuning. Output is returned as IUPAC strings via ``decode_012``.
 
-
+    Notes:
+        - Training includes early stopping based on validation loss.
+        - The imputer can handle both haploid and diploid genotype data.
     """
 
     def __init__(
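For orientation, `ensure_vae_config` flattens a nested config dict into dot-keys before handing it to `apply_dot_overrides`. The middle of the nested `_flatten` helper falls between the two hunks above, so the sketch below re-implements the usual recursion under that assumption, purely to illustrate the key format; the example dict and its values are made up.

    def flatten(prefix: str, d: dict, out: dict) -> dict:
        # Assumed recursion; mirrors the visible signature and return of _flatten above.
        for k, v in d.items():
            kk = f"{prefix}.{k}" if prefix else k
            if isinstance(v, dict):
                flatten(kk, v, out)
            else:
                out[kk] = v
        return out

    nested = {"train": {"batch_size": 64, "learning_rate": 1e-3}, "io": {"prefix": "run1"}}
    print(flatten("", nested, {}))
    # {'train.batch_size': 64, 'train.learning_rate': 0.001, 'io.prefix': 'run1'}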
@@ -78,46 +116,37 @@ class ImputeVAE(BaseNNImputer):
         *,
         tree_parser: Optional["TreeParser"] = None,
         config: Optional[Union["VAEConfig", dict, str]] = None,
-        overrides: dict
-
-
-
-
-
-
-
-
-
-
-
-        sim_prop: float | None = None,
-        sim_kwargs: dict | None = None,
-    ):
-        """Initialize the VAE imputer with a unified config interface.
-
-        This initializer sets up the VAE imputer by processing the provided configuration, initializing logging, and preparing the model and data encoder. It supports configuration input as a dataclass, nested dictionary, YAML file path, or None, with optional dot-key overrides for fine-tuning specific parameters.
+        overrides: Optional[dict] = None,
+        sim_strategy: Literal[
+            "random",
+            "random_weighted",
+            "random_weighted_inv",
+            "nonrandom",
+            "nonrandom_weighted",
+        ] = "random",
+        sim_prop: Optional[float] = None,
+        sim_kwargs: Optional[dict] = None,
+    ) -> None:
+        """Initialize the ImputeVAE imputer.
 
         Args:
-            genotype_data (GenotypeData):
-            tree_parser (TreeParser
-            config (Union[VAEConfig, dict, str
-            overrides (dict
-
-
-
-            sim_kwargs (dict | None): Additional kwargs for SimMissingTransformer.
+            genotype_data (GenotypeData): Genotype data for imputation.
+            tree_parser (Optional[TreeParser]): Tree parser required for nonrandom strategies.
+            config (Optional[Union[VAEConfig, dict, str]]): Config dataclass, nested dict, YAML path, or None.
+            overrides (Optional[dict]): Dot-key overrides applied last with highest precedence.
+            sim_strategy (Literal["random", "random_weighted", "random_weighted_inv", "nonrandom", "nonrandom_weighted"]): Missingness simulation strategy (overrides config).
+            sim_prop (Optional[float]): Proportion of entries to simulate as missing (overrides config). Default is None.
+            sim_kwargs (Optional[dict]): Extra missingness kwargs merged into config.
         """
         self.model_name = "ImputeVAE"
         self.genotype_data = genotype_data
         self.tree_parser = tree_parser
 
-        # Normalize configuration and apply top-precedence overrides
         cfg = ensure_vae_config(config)
         if overrides:
             cfg = apply_dot_overrides(cfg, overrides)
         self.cfg = cfg
 
-        # Logger (align with AE/NLPCA)
         logman = LoggerManager(
             __name__,
             prefix=self.cfg.io.prefix,
@@ -125,12 +154,10 @@ class ImputeVAE(BaseNNImputer):
             verbose=self.cfg.io.verbose,
         )
         self.logger = configure_logger(
-            logman.get_logger(),
-            verbose=self.cfg.io.verbose,
-            debug=self.cfg.io.debug,
+            logman.get_logger(), verbose=self.cfg.io.verbose, debug=self.cfg.io.debug
         )
+        self.logger.propagate = False
 
-        # BaseNNImputer bootstraps device/dirs/log formatting
         super().__init__(
             model_name=self.model_name,
             genotype_data=self.genotype_data,
@@ -140,11 +167,10 @@ class ImputeVAE(BaseNNImputer):
             debug=self.cfg.io.debug,
         )
 
-        # Model hook & encoder
         self.Model = VAEModel
         self.pgenc = GenotypeEncoder(genotype_data)
 
-        #
+        # I/O and general parameters
         self.seed = self.cfg.io.seed
         self.n_jobs = self.cfg.io.n_jobs
         self.prefix = self.cfg.io.prefix
@@ -153,41 +179,40 @@ class ImputeVAE(BaseNNImputer):
         self.debug = self.cfg.io.debug
         self.rng = np.random.default_rng(self.seed)
 
-        #
+        # Simulation parameters
         sim_cfg = getattr(self.cfg, "sim", None)
         sim_cfg_kwargs = copy.deepcopy(getattr(sim_cfg, "sim_kwargs", None) or {})
         if sim_kwargs:
             sim_cfg_kwargs.update(sim_kwargs)
         if sim_cfg is None:
-            default_sim_flag = bool(simulate_missing)
             default_strategy = "random"
-            default_prop = 0.
+            default_prop = 0.2
         else:
-            default_sim_flag = sim_cfg.simulate_missing
             default_strategy = sim_cfg.sim_strategy
             default_prop = sim_cfg.sim_prop
-
-
-        )
+
+        self.simulate_missing = True
         self.sim_strategy = sim_strategy or default_strategy
         self.sim_prop = float(sim_prop if sim_prop is not None else default_prop)
         self.sim_kwargs = sim_cfg_kwargs
 
-        # Model
+        # Model architecture parameters
         self.latent_dim = self.cfg.model.latent_dim
         self.dropout_rate = self.cfg.model.dropout_rate
         self.num_hidden_layers = self.cfg.model.num_hidden_layers
         self.layer_scaling_factor = self.cfg.model.layer_scaling_factor
         self.layer_schedule = self.cfg.model.layer_schedule
-        self.activation = self.cfg.model.
-        self.gamma = self.cfg.model.gamma  # focal loss focusing (for recon CE)
+        self.activation = self.cfg.model.activation
 
-        # VAE-
-        self.
-        self.
-        self.kl_ramp = self.cfg.vae.kl_ramp
+        # VAE-specific parameters
+        self.kl_beta = self.cfg.vae.kl_beta
+        self.kl_beta_schedule = self.cfg.vae.kl_beta_schedule
 
-        #
+        # Training parameters
+        self.power: float = self.cfg.train.weights_power
+        self.max_ratio: Any | float | None = self.cfg.train.weights_max_ratio
+        self.normalize: bool = self.cfg.train.weights_normalize
+        self.inverse: bool = self.cfg.train.weights_inverse
         self.batch_size = self.cfg.train.batch_size
         self.learning_rate = self.cfg.train.learning_rate
         self.l1_penalty: float = self.cfg.train.l1_penalty
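For orientation, the new weights_* training options (power, max_ratio, normalize, inverse) feed `_class_weights_from_zygosity`, whose body is not shown in this diff. A hypothetical minimal sketch of inverse-frequency class weighting with those four knobs, only to illustrate what each parameter could control; this is not the pg-sui implementation.

    import numpy as np

    def class_weights_sketch(labels, num_classes=3, *, power=1.0, max_ratio=None,
                             normalize=True, inverse=True):
        # Hypothetical illustration; the real helper may differ.
        counts = np.bincount(labels[labels >= 0], minlength=num_classes).astype(float)
        counts = np.maximum(counts, 1.0)            # guard against empty classes
        w = (1.0 / counts) ** power if inverse else counts ** power
        if max_ratio is not None:
            w = np.minimum(w, w.min() * max_ratio)  # cap the spread between weights
        if normalize:
            w = w * (num_classes / w.sum())         # keep the mean weight near 1
        return w

    print(class_weights_sketch(np.array([0, 0, 0, 1, 2, 2, -1]), power=0.5, max_ratio=10.0))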
@@ -195,32 +220,18 @@ class ImputeVAE(BaseNNImputer):
         self.min_epochs = self.cfg.train.min_epochs
         self.epochs = self.cfg.train.max_epochs
         self.validation_split = self.cfg.train.validation_split
-        self.
-        self.
+        self.gamma = self.cfg.train.gamma
+        self.gamma_schedule = self.cfg.train.gamma_schedule
 
-        #
+        # Hyperparameter tuning
         self.tune = self.cfg.tune.enabled
-        self.
-        self.tune_batch_size = self.cfg.tune.batch_size
-        self.tune_epochs = self.cfg.tune.epochs
-        self.tune_eval_interval = self.cfg.tune.eval_interval
-        self.tune_metric: Literal[
-            "pr_macro",
-            "f1",
-            "accuracy",
-            "average_precision",
-            "precision",
-            "recall",
-            "roc_auc",
-        ] = self.cfg.tune.metric
+        self.tune_metric = self.cfg.tune.metric
         self.n_trials = self.cfg.tune.n_trials
         self.tune_save_db = self.cfg.tune.save_db
         self.tune_resume = self.cfg.tune.resume
-        self.tune_max_samples = self.cfg.tune.max_samples
-        self.tune_max_loci = self.cfg.tune.max_loci
         self.tune_patience = self.cfg.tune.patience
 
-        # Plotting
+        # Plotting parameters
         self.plot_format = self.cfg.plot.fmt
         self.plot_dpi = self.cfg.plot.dpi
         self.plot_fontsize = self.cfg.plot.fontsize
@@ -228,90 +239,61 @@ class ImputeVAE(BaseNNImputer):
         self.despine = self.cfg.plot.despine
         self.show_plots = self.cfg.plot.show
 
-        #
-        self.
-        self.num_classes_: int = 3
+        # Internal attributes set during fitting
+        self.is_haploid_: bool = False
+        self.num_classes_: int = 3
         self.model_params: Dict[str, Any] = {}
-        self.
-        self.sim_mask_train_: np.ndarray | None = None
-        self.sim_mask_test_: np.ndarray | None = None
+        self.sim_mask_test_: np.ndarray
 
         if self.tree_parser is None and self.sim_strategy.startswith("nonrandom"):
-            msg = "tree_parser is required for nonrandom
+            msg = "tree_parser is required for nonrandom sim strategies."
             self.logger.error(msg)
             raise ValueError(msg)
 
-    # -------------------- Fit -------------------- #
     def fit(self) -> "ImputeVAE":
-        """Fit the VAE
-
-        This method
+        """Fit the VAE imputer model to the genotype data.
+
+        This method performs the following steps:
+        1. Validates the presence of SNP data in the genotype data.
+        2. Determines the ploidy of the genotype data and sets up haploid/diploid handling.
+        3. Simulates missingness in the genotype data based on the specified strategy.
+        4. Splits the data into training, validation, and test sets.
+        5. One-hot encodes the genotype data for model input.
+        6. Initializes data loaders for training and validation.
+        7. If hyperparameter tuning is enabled, tunes the model hyperparameters.
+        8. Builds the VAE model with the best hyperparameters.
+        9. Trains the VAE model using the training data and validates on the validation set.
+        10. Evaluates the trained model on the test set and computes performance metrics.
+        11. Saves the trained model and best hyperparameters.
+        12. Generates plots of training history if enabled.
+        13. Returns the fitted ImputeVAE instance.
 
         Returns:
-            ImputeVAE:
-
-        Raises:
-            RuntimeError: If training fails to produce a model.
+            ImputeVAE: The fitted ImputeVAE instance.
         """
         self.logger.info(f"Fitting {self.model_name} model...")
 
-
-
-
-
+        if self.genotype_data.snp_data is None:
+            msg = f"SNP data is required for {self.model_name}."
+            self.logger.error(msg)
+            raise AttributeError(msg)
 
-        self.
-
-        if self.simulate_missing:
-            cached_mask = (
-                None if cache_key is None else self._sim_mask_cache.get(cache_key)
-            )
-            if cached_mask is not None:
-                self.sim_mask_global_ = cached_mask.copy()
-            else:
-                tr = SimMissingTransformer(
-                    genotype_data=self.genotype_data,
-                    tree_parser=self.tree_parser,
-                    prop_missing=self.sim_prop,
-                    strategy=self.sim_strategy,
-                    missing_val=-9,
-                    mask_missing=True,
-                    verbose=self.verbose,
-                    **self.sim_kwargs,
-                )
-                tr.fit(X012.copy())
-                self.sim_mask_global_ = tr.sim_missing_mask_.astype(bool)
-                if cache_key is not None:
-                    self._sim_mask_cache[cache_key] = self.sim_mask_global_.copy()
+        self.ploidy = self.cfg.io.ploidy
+        self.is_haploid_ = self.ploidy == 1
 
-
-
-
-
-
-        # Ploidy/classes
-        self.is_haploid = bool(
-            np.all(
-                np.isin(
-                    self.genotype_data.snp_data,
-                    ["A", "C", "G", "T", "N", "-", ".", "?"],
-                )
-            )
-        )
-        self.ploidy = 1 if self.is_haploid else 2
-        self.num_classes_ = 2 if self.is_haploid else 3
-        self.logger.info(
-            f"Data is {'haploid' if self.is_haploid else 'diploid'}; "
-            f"using {self.num_classes_} classes."
-        )
+        if self.ploidy > 2:
+            msg = f"{self.model_name} currently supports only haploid (1) or diploid (2) data; got ploidy={self.ploidy}."
+            self.logger.error(msg)
+            raise ValueError(msg)
 
-        if self.
-            self.ground_truth_[self.ground_truth_ == 2] = 1
-            X_for_model[X_for_model == 2] = 1
+        self.num_classes_ = 2 if self.is_haploid_ else 3
 
-
+        gt_full = self.pgenc.genotypes_012.copy()
+        gt_full[gt_full < 0] = -1
+        gt_full = np.nan_to_num(gt_full, nan=-1.0)
+        self.ground_truth_ = gt_full.astype(np.int8)
+        self.num_features_ = gt_full.shape[1]
 
-        # Model params (decoder outputs L*K logits)
         self.model_params = {
             "n_features": self.num_features_,
             "num_classes": self.num_classes_,
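For orientation, the rewritten fit() above normalizes the 0/1/2 genotype matrix so that every missing entry (any negative code or NaN) becomes -1 before modeling, exactly as in the added gt_full lines. A small standalone illustration with made-up values:

    import numpy as np

    gt = np.array([[0.0, 1.0, -9.0],
                   [2.0, np.nan, 1.0]])
    gt[gt < 0] = -1                    # any negative missing code -> -1
    gt = np.nan_to_num(gt, nan=-1.0)   # NaN -> -1 as well
    gt = gt.astype(np.int8)
    print(gt)                          # [[ 0  1 -1]
                                       #  [ 2 -1  1]]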
@@ -320,66 +302,218 @@ class ImputeVAE(BaseNNImputer):
             "activation": self.activation,
         }
 
-        #
-
-
-
+        # Simulate missingness ONCE on the full matrix
+        sim_tup = self.sim_missing_transform(self.ground_truth_)
+        X_for_model_full = sim_tup[0]
+        self.sim_mask_ = sim_tup[1]
+        self.orig_mask_ = sim_tup[2]
+
+        # Split indices based on clean ground truth
+        self.train_idx_, self.val_idx_, self.test_idx_ = self._train_val_test_split(
+            self.ground_truth_
         )
-        self.train_idx_, self.test_idx_ = train_idx, val_idx
-        self.X_train_ = X_for_model[train_idx]
-        self.X_val_ = X_for_model[val_idx]
-        self.GT_train_full_ = self.ground_truth_[train_idx]
-        self.GT_test_full_ = self.ground_truth_[val_idx]
-
-        if self.sim_mask_global_ is not None:
-            self.sim_mask_train_ = self.sim_mask_global_[train_idx]
-            self.sim_mask_test_ = self.sim_mask_global_[val_idx]
-        else:
-            self.sim_mask_train_ = None
-            self.sim_mask_test_ = None
 
-        #
+        # --- Clean (targets) per split ---
+        X_train_clean = self.ground_truth_[self.train_idx_].copy()
+        X_val_clean = self.ground_truth_[self.val_idx_].copy()
+        X_test_clean = self.ground_truth_[self.test_idx_].copy()
+
+        # --- Corrupted (inputs) per split (from the single simulation) ---
+        X_train_corrupted = X_for_model_full[self.train_idx_].copy()
+        X_val_corrupted = X_for_model_full[self.val_idx_].copy()
+        X_test_corrupted = X_for_model_full[self.test_idx_].copy()
+
+        # --- Masks per split ---
+        self.sim_mask_train_ = self.sim_mask_[self.train_idx_].copy()
+        self.sim_mask_val_ = self.sim_mask_[self.val_idx_].copy()
+        self.sim_mask_test_ = self.sim_mask_[self.test_idx_].copy()
+
+        self.orig_mask_train_ = self.orig_mask_[self.train_idx_].copy()
+        self.orig_mask_val_ = self.orig_mask_[self.val_idx_].copy()
+        self.orig_mask_test_ = self.orig_mask_[self.test_idx_].copy()
+
+        # Persist clean/corrupted matrices if you want them accessible later
+        self.X_train_clean_ = X_train_clean
+        self.X_val_clean_ = X_val_clean
+        self.X_test_clean_ = X_test_clean
+
+        self.X_train_corrupted_ = X_train_corrupted
+        self.X_val_corrupted_ = X_val_corrupted
+        self.X_test_corrupted_ = X_test_corrupted
+
+        # --- Haploid harmonization (do NOT resimulate; just recode values) ---
+        if self.is_haploid_:
+
+            def _haploidize(arr: np.ndarray) -> np.ndarray:
+                out = arr.copy()
+                miss = out < 0
+                out = np.where(out > 0, 1, out).astype(np.int8, copy=False)
+                out[miss] = -1
+                return out
+
+            X_train_clean = _haploidize(X_train_clean)
+            X_val_clean = _haploidize(X_val_clean)
+            X_test_clean = _haploidize(X_test_clean)
+
+            X_train_corrupted = _haploidize(X_train_corrupted)
+            X_val_corrupted = _haploidize(X_val_corrupted)
+            X_test_corrupted = _haploidize(X_test_corrupted)
+
+            # Write back the persisted versions too
+            self.X_train_clean_ = X_train_clean
+            self.X_val_clean_ = X_val_clean
+            self.X_test_clean_ = X_test_clean
+
+            self.X_train_corrupted_ = X_train_corrupted
+            self.X_val_corrupted_ = X_val_corrupted
+            self.X_test_corrupted_ = X_test_corrupted
+
+        # Final training tensors/matrices used by the pipeline
+        # Convention: X_* are corrupted inputs; y_* are clean targets
+        self.X_train_ = self.X_train_corrupted_
+        self.y_train_ = self.X_train_clean_
+
+        self.X_val_ = self.X_val_corrupted_
+        self.y_val_ = self.X_val_clean_
+
+        self.X_test_ = self.X_test_corrupted_
+        self.y_test_ = self.X_test_clean_
+
+        self.X_train_ = self._one_hot_encode_012(
+            self.X_train_, num_classes=self.num_classes_
+        )
+        self.X_val_ = self._one_hot_encode_012(
+            self.X_val_, num_classes=self.num_classes_
+        )
+        self.X_test_ = self._one_hot_encode_012(
+            self.X_test_, num_classes=self.num_classes_
+        )
+        self.X_test_ = self.X_test_.detach().cpu().numpy()
+
         self.plotter_, self.scorers_ = self.initialize_plotting_and_scorers()
+        self.valid_class_mask_ = self._build_valid_class_mask()
+
+        loci = getattr(self, "valid_class_mask_conflict_loci_", None)
+        if loci is not None and loci.size:
+            self._repair_ref_alt_from_iupac(loci)
+            self.valid_class_mask_ = self._build_valid_class_mask()
+
+        train_loader = self._get_data_loaders(
+            self.X_train_.detach().cpu().numpy(),
+            self.y_train_,
+            ~self.orig_mask_train_,
+            self.batch_size,
+            shuffle=True,
+        )
+
+        val_loader = self._get_data_loaders(
+            self.X_val_.detach().cpu().numpy(),
+            self.y_val_,
+            ~self.orig_mask_val_,
+            self.batch_size,
+            shuffle=False,
+        )
+
+        self.train_loader_ = train_loader
+        self.val_loader_ = val_loader
 
-        # Optional tuning
         if self.tune:
-            self.tune_hyperparameters()
+            self.tuned_params_ = self.tune_hyperparameters()
+            self.model_tuned_ = True
+        else:
+            self.model_tuned_ = False
+            self.class_weights_ = self._class_weights_from_zygosity(
+                self.y_train_,
+                train_mask=self.sim_mask_train_ & ~self.orig_mask_train_,
+                inverse=self.inverse,
+                normalize=self.normalize,
+                max_ratio=self.max_ratio,
+                power=self.power,
+            )
+            self.tuned_params_ = {
+                "latent_dim": self.latent_dim,
+                "learning_rate": self.learning_rate,
+                "dropout_rate": self.dropout_rate,
+                "num_hidden_layers": self.num_hidden_layers,
+                "activation": self.activation,
+                "l1_penalty": self.l1_penalty,
+                "layer_scaling_factor": self.layer_scaling_factor,
+                "layer_schedule": self.layer_schedule,
+            }
+            self.tuned_params_.update(
+                {
+                    "kl_beta": self.kl_beta,
+                    "kl_beta_schedule": self.kl_beta_schedule,
+                    "gamma": self.gamma,
+                    "gamma_schedule": self.gamma_schedule,
+                    "inverse": self.inverse,
+                    "normalize": self.normalize,
+                    "power": self.power,
+                }
+            )
+
+            self.tuned_params_["model_params"] = self.model_params
+
+            if self.class_weights_ is not None:
+                self.logger.info(
+                    f"class_weights={self.class_weights_.detach().cpu().numpy().tolist()}"
+                )
 
-        #
-        self.best_params_ =
+        # Always start clean
+        self.best_params_ = copy.deepcopy(self.tuned_params_)
+
+        model_params_final = {
+            "n_features": self.num_features_,
+            "num_classes": self.num_classes_,
+            "latent_dim": int(self.best_params_["latent_dim"]),
+            "dropout_rate": float(self.best_params_["dropout_rate"]),
+            "activation": str(self.best_params_["activation"]),
+            "kl_beta": float(
+                self.best_params_.get("kl_beta", getattr(self, "kl_beta", 1.0))
+            ),
+        }
 
-
-
-
+        input_dim = self.num_features_ * self.num_classes_
+        model_params_final["hidden_layer_sizes"] = self._compute_hidden_layer_sizes(
+            n_inputs=input_dim,
+            n_outputs=self.num_classes_,
+            n_samples=len(self.X_train_),
+            n_hidden=int(self.best_params_["num_hidden_layers"]),
+            latent_dim=int(self.best_params_["latent_dim"]),
+            alpha=float(self.best_params_["layer_scaling_factor"]),
+            schedule=str(self.best_params_["layer_schedule"]),
+            min_size=max(16, 2 * int(self.best_params_["latent_dim"])),
        )
 
-
-        train_loader = self._get_data_loader(self.X_train_)
+        self.best_params_["model_params"] = model_params_final
 
-        #
-        model = self.build_model(self.Model, self.best_params_)
+        # Now build the model
+        model = self.build_model(self.Model, self.best_params_["model_params"])
         model.apply(self.initialize_weights)
 
+        if self.verbose or self.debug:
+            self.logger.info("Using model hyperparameters:")
+            pm = PrettyMetrics(
+                self.best_params_, precision=3, title="Model Hyperparameters"
+            )
+            pm.render()
+
+        lr_final = float(self.best_params_["learning_rate"])
+        l1_final = float(self.best_params_["l1_penalty"])
+
         loss, trained_model, history = self._train_and_validate_model(
             model=model,
-
-
-            l1_penalty=self.l1_penalty,
-            return_history=True,
-            class_weights=self.class_weights_,
-            X_val=self.X_val_,
+            lr=lr_final,
+            l1_penalty=l1_final,
             params=self.best_params_,
-
-
-
-
-            eval_latent_steps=0,
-            eval_latent_lr=0.0,
-            eval_latent_weight_decay=0.0,
+            trial=None,
+            class_weights=self.class_weights_,
+            kl_beta_schedule=self.best_params_["kl_beta_schedule"],
+            gamma_schedule=self.best_params_["gamma_schedule"],
         )
 
         if trained_model is None:
-            msg = "
+            msg = f"{self.model_name} training failed."
             self.logger.error(msg)
             raise RuntimeError(msg)
 
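For orientation, `_one_hot_encode_012`, used above to build the model inputs, lives in the imputer base class and is not part of this diff. The sketch below shows one way such an encoder can behave for 0/1/2 data with -1 marking missing cells; this is an assumed behavior, included only to make the tensor shapes concrete.

    import torch
    import torch.nn.functional as F

    def one_hot_012_sketch(x: torch.Tensor, num_classes: int = 3) -> torch.Tensor:
        # Assumption: valid genotypes become one-hot rows; missing (-1) become all-zero rows.
        missing = x < 0
        one_hot = F.one_hot(x.clamp(min=0).long(), num_classes=num_classes).float()
        one_hot[missing] = 0.0
        return one_hot  # shape: (n_samples, n_loci, num_classes)

    x = torch.tensor([[0, 1, -1], [2, 2, 0]])
    print(one_hot_012_sketch(x).shape)  # torch.Size([2, 3, 3])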
@@ -388,215 +522,209 @@ class ImputeVAE(BaseNNImputer):
             self.models_dir / f"final_model_{self.model_name}.pt",
         )
 
-
-
+        if history is None:
+            hist = {"Train": []}
+        elif isinstance(history, dict):
+            hist = dict(history)
+        else:
+            hist = {"Train": list(history["Train"]), "Val": list(history["Val"])}
+
+        self.best_loss_ = loss
+        self.model_ = trained_model
+        self.history_ = hist
         self.is_fit_ = True
 
-        # Evaluate (AE-parity reporting)
-        eval_mask = (
-            self.sim_mask_test_
-            if (self.simulate_missing and self.sim_mask_test_ is not None)
-            else None
-        )
         self._evaluate_model(
-            self.X_val_,
             self.model_,
-            self.
-
+            X=self.X_test_,
+            y=self.y_test_,
+            eval_mask=self.sim_mask_test_ & ~self.orig_mask_test_,
+            objective_mode=False,
         )
 
-        self.
+        if self.show_plots:
+            self.plotter_.plot_history(self.history_)
+
         self._save_best_params(self.best_params_)
+
+        if self.model_tuned_:
+            title = f"{self.model_name} Optimized Parameters"
+
+            if self.verbose or self.debug:
+                pm = PrettyMetrics(self.best_params_, precision=2, title=title)
+                pm.render()
+
+            # Save best parameters to a JSON file.
+            self._save_best_params(self.best_params_, objective_mode=True)
         return self
 
     def transform(self) -> np.ndarray:
         """Impute missing genotypes and return IUPAC strings.
 
-        This method
+        This method performs the following steps:
+        1. Validates that the model has been fitted.
+        2. Uses the trained model to predict missing genotypes for the entire dataset.
+        3. Fills in the missing genotypes in the original dataset with the predicted values from the model.
+        4. Decodes the imputed genotype matrix from 0/1/2 encoding to IUPAC strings.
+        5. Checks for any remaining missing values or decoding issues, raising errors if found.
+        6. Optionally generates and displays plots comparing the original and imputed genotype distributions.
+        7. Returns the imputed IUPAC genotype matrix.
 
         Returns:
-            np.ndarray: IUPAC
+            np.ndarray: IUPAC genotype matrix of shape (n_samples, n_loci).
 
-
-
+        Notes:
+            - ``transform()`` does not take any arguments; it operates on the data provided during initialization.
+            - Ensure that the model has been fitted before calling this method.
+            - For haploid data, genotypes encoded as '1' are treated as '2' during decoding.
+            - The method checks for decoding failures (i.e., resulting in 'N') and raises an error if any are found.
         """
         if not getattr(self, "is_fit_", False):
-
+            msg = "Model is not fitted. Call fit() before transform()."
+            self.logger.error(msg)
+            raise NotFittedError(msg)
 
         self.logger.info(f"Imputing entire dataset with {self.model_name} model...")
         X_to_impute = self.ground_truth_.copy()
 
-
+        # 1. Predict labels (0/1/2) for the entire matrix
+        pred_labels, _ = self._predict(self.model_, X=X_to_impute)
 
-        # Fill
-        missing_mask = X_to_impute
+        # 2. Fill ONLY originally missing values
+        missing_mask = X_to_impute < 0
         imputed_array = X_to_impute.copy()
         imputed_array[missing_mask] = pred_labels[missing_mask]
 
-        #
-
-
+        # Sanity check: all -1s should be gone
+        if np.any(imputed_array < 0):
+            msg = f"[{self.model_name}] Some missing genotypes remain after imputation. This is unexpected."
+            self.logger.error(msg)
+            raise RuntimeError(msg)
 
-
-
-        self.
+        # 3. Handle Haploid mapping (2->1) before decoding if needed
+        decode_input = imputed_array
+        if self.is_haploid_:
+            decode_input = imputed_array.copy()
+            decode_input[decode_input == 1] = 2
 
-
+        # 4. Decode integers to IUPAC strings
+        imputed_genotypes = self.decode_012(decode_input)
 
-
+        # 5. Check for decoding failures (N)
+        # Hard error: downstream pipelines expect fully imputed data.
+        bad_loci = np.where((imputed_genotypes == "N").any(axis=0))[0]
+        if bad_loci.size > 0:
+            msg = f"[{self.model_name}] {bad_loci.size} loci contain 'N' after imputation (e.g., first 10 indices: {bad_loci[:10].tolist()}). This occurs when REF/ALT metadata is missing and cannot be inferred from the source data (e.g., loci with 100 percent missing genotypes). Try filtering out these loci before imputation."
+            self.logger.error(msg)
+            self.logger.debug(
+                "All loci with 'N': " + ", ".join(map(str, bad_loci.tolist()))
+            )
+            raise RuntimeError(msg)
 
-
-
+        if self.show_plots:
+            original_input = X_to_impute
+            if self.is_haploid_:
+                original_input = X_to_impute.copy()
+                original_input[original_input == 1] = 2
 
-
+            original_genotypes = self.decode_012(original_input)
 
-
-
+            plt.rcParams.update(self.plotter_.param_dict)
+            self.plotter_.plot_gt_distribution(original_genotypes, is_imputed=False)
+            self.plotter_.plot_gt_distribution(imputed_genotypes, is_imputed=True)
 
-
-            torch.utils.data.DataLoader: Shuffled DataLoader.
-        """
-        y_tensor = torch.from_numpy(y).long()
-        indices = torch.arange(len(y), dtype=torch.long)
-        dataset = torch.utils.data.TensorDataset(indices, y_tensor)
-        pin_memory = self.device.type == "cuda"
-        return torch.utils.data.DataLoader(
-            dataset,
-            batch_size=self.batch_size,
-            shuffle=True,
-            pin_memory=pin_memory,
-        )
+        return imputed_genotypes
 
     def _train_and_validate_model(
         self,
         model: torch.nn.Module,
-
+        *,
         lr: float,
         l1_penalty: float,
-        trial: optuna.Trial
-
-        class_weights: torch.Tensor
-
-
-
-
-
-
-        eval_requires_latents: bool = False,  # VAE: no latent eval refinement
-        eval_latent_steps: int = 0,
-        eval_latent_lr: float = 0.0,
-        eval_latent_weight_decay: float = 0.0,
-    ) -> Tuple[float, torch.nn.Module | None, list | None]:
-        """Wrap the VAE training loop with β-anneal & Optuna pruning.
-
-        This method orchestrates the training of the VAE model, including setting up the optimizer and learning rate scheduler, and executing the training loop with support for early stopping and Optuna pruning. It manages the training process, monitors performance on a validation set if provided, and returns the best model and training history.
+        trial: Optional[optuna.Trial] = None,
+        params: Optional[dict[str, Any]] = None,
+        class_weights: Optional[torch.Tensor] = None,
+        kl_beta_schedule: bool = False,
+        gamma_schedule: bool = False,
+    ) -> tuple[float, torch.nn.Module, dict[str, list[float]]]:
+        """Train and validate the model.
+
+        This method orchestrates training with early stopping and optional Optuna pruning based on validation performance. It returns the best validation loss, the best model (with best weights loaded), and training history.
 
         Args:
             model (torch.nn.Module): VAE model.
-            loader (torch.utils.data.DataLoader): Training data loader.
             lr (float): Learning rate.
             l1_penalty (float): L1 regularization coefficient.
-            trial (optuna.Trial
-
-            class_weights (torch.Tensor
-
-
-            prune_metric (str | None): Metric for pruning decisions.
-            prune_warmup_epochs (int): Epochs to skip before pruning.
-            eval_interval (int): Epochs between validation evaluations.
-            eval_requires_latents (bool): If True, refine latents during eval.
-            eval_latent_steps (int): Latent refinement steps if needed.
-            eval_latent_lr (float): Latent refinement learning rate.
-            eval_latent_weight_decay (float): Latent refinement L2 penalty.
+            trial (Optional[optuna.Trial]): Optuna trial for pruning (optional).
+            params (Optional[dict[str, float | int | str | dict[str, Any]]]): Model params for evaluation.
+            class_weights (Optional[torch.Tensor]): Class weights for loss computation.
+            kl_beta_schedule (bool): Whether to use KL beta scheduling.
+            gamma_schedule (bool): Whether to use gamma scheduling for focal CE loss.
 
         Returns:
-
+            tuple[float, torch.nn.Module, dict[str, list[float]]]:
+                Best validation loss, best model, and training history.
         """
-
-
-            self.logger.error(msg)
-            raise TypeError(msg)
+        max_epochs = self.epochs
+        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
 
-
-
+        scheduler = _make_warmup_cosine_scheduler(
+            optimizer, max_epochs=max_epochs, warmup_epochs=int(0.1 * max_epochs)
        )
 
-        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
-        scheduler = CosineAnnealingLR(optimizer, T_max=max_epochs)
-
         best_loss, best_model, hist = self._execute_training_loop(
-            loader=loader,
             optimizer=optimizer,
             scheduler=scheduler,
             model=model,
             l1_penalty=l1_penalty,
             trial=trial,
-            return_history=return_history,
-            class_weights=class_weights,
-            X_val=X_val,
             params=params,
-
-
-
-            eval_requires_latents=eval_requires_latents,
-            eval_latent_steps=eval_latent_steps,
-            eval_latent_lr=eval_latent_lr,
-            eval_latent_weight_decay=eval_latent_weight_decay,
+            class_weights=class_weights,
+            kl_beta_schedule=kl_beta_schedule,
+            gamma_schedule=gamma_schedule,
         )
-
-        return best_loss, best_model, hist
-
-        return best_loss, best_model, None
+        return best_loss, best_model, hist
 
     def _execute_training_loop(
         self,
-
+        *,
         optimizer: torch.optim.Optimizer,
-        scheduler:
+        scheduler: (
+            torch.optim.lr_scheduler.CosineAnnealingLR
+            | torch.optim.lr_scheduler.SequentialLR
+        ),
         model: torch.nn.Module,
         l1_penalty: float,
-        trial: optuna.Trial
-
-        class_weights: torch.Tensor,
-
-
-
-
-
-
-        eval_requires_latents: bool = False,
-        eval_latent_steps: int = 0,
-        eval_latent_lr: float = 0.0,
-        eval_latent_weight_decay: float = 0.0,
-    ) -> Tuple[float, torch.nn.Module, list]:
-        """Train VAE with stable focal CE + KL(β) anneal and numeric guards.
-
-        This method implements the core training loop for the VAE model, incorporating a focal cross-entropy loss for reconstruction and a KL divergence term with an annealed weight (beta). It includes mechanisms for early stopping based on validation performance, learning rate scheduling, and optional pruning of unpromising trials when using Optuna for hyperparameter optimization. The method ensures numerical stability throughout the training process.
+        trial: Optional[optuna.Trial] = None,
+        params: Optional[dict[str, Any]] = None,
+        class_weights: Optional[torch.Tensor] = None,
+        kl_beta_schedule: bool = False,
+        gamma_schedule: bool = False,
+    ) -> tuple[float, torch.nn.Module, dict[str, list[float]]]:
+        """Train the model with focal CE reconstruction + KL divergence.
+
+        This method performs the training loop for the model using the provided optimizer and learning rate scheduler. It supports early stopping based on validation loss and integrates with Optuna for hyperparameter tuning. The method returns the best validation loss, the best model state, and the training history.
 
         Args:
-            loader (torch.utils.data.DataLoader): Training data loader.
             optimizer (torch.optim.Optimizer): Optimizer.
-            scheduler (torch.optim.lr_scheduler.
+            scheduler (torch.optim.lr_scheduler.CosineAnnealingLR | torch.optim.lr_scheduler.SequentialLR): Learning rate scheduler.
             model (torch.nn.Module): VAE model.
             l1_penalty (float): L1 regularization coefficient.
-            trial (optuna.Trial
-
-            class_weights (torch.Tensor):
-
-
-            prune_metric (str | None): Metric for pruning decisions.
-            prune_warmup_epochs (int): Epochs to skip before pruning.
-            eval_interval (int): Epochs between validation evaluations.
-            eval_requires_latents (bool): If True, refine latents during eval.
-            eval_latent_steps (int): Latent refinement steps if needed.
-            eval_latent_lr (float): Latent refinement learning rate.
-            eval_latent_weight_decay (float): Latent refinement L2 penalty.
+            trial (Optional[optuna.Trial]): Optuna trial for pruning (optional).
+            params (Optional[dict[str, Any]]): Model params for evaluation.
+            class_weights (Optional[torch.Tensor]): Class weights for loss computation.
+            kl_beta_schedule (bool): Whether to use KL beta scheduling.
+            gamma_schedule (bool): Whether to use gamma scheduling for focal CE loss.
 
         Returns:
-
+            tuple[float, torch.nn.Module, dict[str, list[float]]]: Best validation loss, best model, training history.
+
+        Notes:
+            - Use CE with class weights during training/validation.
+            - Inference de-bias happens in _predict (separate).
+            - If `class_weights` is None, this will fall back to self.class_weights_ if present.
         """
-
-        history: list[float] = []
+        history: dict[str, list[float]] = defaultdict(list)
 
         early_stopping = EarlyStopping(
             patience=self.early_stop_gen,
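For orientation, the rewritten transform() above fills only the originally missing cells with the model's predicted 0/1/2 labels and leaves observed genotypes untouched; the masking pattern is plain NumPy. A tiny illustration with made-up values:

    import numpy as np

    X = np.array([[0, -1, 2],
                  [-1, 1, 1]], dtype=np.int8)    # -1 marks missing genotypes
    pred = np.array([[0, 1, 2],
                     [2, 1, 1]], dtype=np.int8)  # model predictions for every cell

    missing = X < 0
    imputed = X.copy()
    imputed[missing] = pred[missing]             # observed genotypes are left untouched
    print(imputed)                               # [[0 1 2]
                                                 #  [2 1 1]]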
@@ -606,188 +734,174 @@ class ImputeVAE(BaseNNImputer):
             debug=self.debug,
         )
 
-        #
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        beta_warm, beta_ramp = int(self.kl_warmup), int(self.kl_ramp)
-
-        # Optional LR warmup
-        warmup_epochs = int(getattr(self, "lr_warmup_epochs", 5))
-        base_lr = float(optimizer.param_groups[0]["lr"])
-        min_lr = base_lr * 0.1
-
-        max_epochs = int(getattr(scheduler, "T_max", getattr(self, "epochs", 100)))
-
-        for epoch in range(max_epochs):
-            # focal γ schedule
-            if epoch < gamma_warm:
-                model.gamma = 0.0  # type: ignore[attr-defined]
-            elif epoch < gamma_warm + gamma_ramp:
-                model.gamma = gamma_final * ((epoch - gamma_warm) / gamma_ramp)  # type: ignore[attr-defined]
-            else:
-                model.gamma = gamma_final  # type: ignore[attr-defined]
+        # KL schedule
+        kl_beta_target, kl_warm, kl_ramp = self._anneal_config(
+            params, "kl_beta", default=self.kl_beta, max_epochs=self.epochs
+        )
+
+        kl_beta_target = float(kl_beta_target)
+
+        gamma_target, gamma_warm, gamma_ramp = self._anneal_config(
+            params, "gamma", default=self.gamma, max_epochs=self.epochs
+        )
+
+        cw = class_weights
+        if cw is not None and cw.device != self.device:
+            cw = cw.to(self.device)
+
+        ce_criterion = FocalCELoss(
+            alpha=cw, gamma=gamma_target, reduction="mean", ignore_index=-1
+        )
 
-
-            if
-
-
-
+        for epoch in range(int(self.epochs)):
+            if kl_beta_schedule:
+                kl_beta_current = self._update_anneal_schedule(
+                    kl_beta_target,
+                    warm=kl_warm,
+                    ramp=kl_ramp,
+                    epoch=epoch,
+                    init_val=0.0,
+                )
             else:
-
-
-            if
-
-
-
+                kl_beta_current = kl_beta_target
+
+            if gamma_schedule:
+                gamma_current = self._update_anneal_schedule(
+                    gamma_target,
+                    warm=gamma_warm,
+                    ramp=gamma_ramp,
+                    epoch=epoch,
+                    init_val=0.0,
+                )
+                ce_criterion.gamma = gamma_current
 
             train_loss = self._train_step(
-                loader=
+                loader=self.train_loader_,
                 optimizer=optimizer,
                 model=model,
+                ce_criterion=ce_criterion,
                 l1_penalty=l1_penalty,
-
+                kl_beta=kl_beta_current,
             )
 
             if not np.isfinite(train_loss):
-                if trial:
-
-
-
-                    g["lr"] *= 0.5
-                continue
+                if trial is not None:
+                    msg = f"[{self.model_name}] Trial {trial.number} training loss non-finite."
+                    self.logger.warning(msg)
+                    raise optuna.exceptions.TrialPruned(msg)
 
-
-
+                msg = f"[{self.model_name}] Training loss is non-finite at epoch {epoch + 1}."
+                self.logger.error(msg)
+                raise RuntimeError(msg)
+
+            val_loss = self._val_step(
+                loader=self.val_loader_,
+                model=model,
+                ce_criterion=ce_criterion,
+                l1_penalty=l1_penalty,
+                kl_beta=kl_beta_current,
+            )
 
-
-
+            scheduler.step()
+            history["Train"].append(float(train_loss))
+            history["Val"].append(float(val_loss))
 
-            early_stopping(
+            early_stopping(val_loss, model)
             if early_stopping.early_stop:
-                self.logger.
+                self.logger.debug(
+                    f"[{self.model_name}] Early stopping at epoch {epoch + 1}."
+                )
                 break
 
-
-
-                trial is not None
-                and X_val is not None
-                and ((epoch + 1) % eval_interval == 0)
-            ):
-                metric_key = prune_metric or getattr(self, "tune_metric", "f1")
-                mask_override = None
-                if (
-                    self.simulate_missing
-                    and getattr(self, "sim_mask_test_", None) is not None
-                    and getattr(self, "X_val_", None) is not None
-                    and X_val.shape == self.X_val_.shape
-                ):
-                    mask_override = self.sim_mask_test_
-                metric_val = self._eval_for_pruning(
+            if trial is not None:
+                metric_vals = self._evaluate_model(
                     model=model,
-
-
-
+                    X=self.X_val_corrupted_,
+                    y=self.y_val_,
+                    eval_mask=self.sim_mask_val_ & ~self.orig_mask_val_,
                     objective_mode=True,
-                    do_latent_infer=False,  # VAE uses encoder; no latent refine
-                    latent_steps=0,
-                    latent_lr=0.0,
-                    latent_weight_decay=0.0,
-                    latent_seed=self.seed,  # type: ignore
-                    _latent_cache=None,
-                    _latent_cache_key=None,
-                    eval_mask_override=mask_override,
                 )
-                trial.report(
-                if
+                trial.report(metric_vals[self.tune_metric], step=epoch + 1)
+                if trial.should_prune():
                     raise optuna.exceptions.TrialPruned(
-                        f"
+                        f"[{self.model_name}] Trial {trial.number} pruned at epoch {epoch + 1}."
                     )
 
-        best_loss = early_stopping.best_score
-
-        if
-
-
+        best_loss = float(early_stopping.best_score)
+
+        if early_stopping.best_state_dict is not None:
+            model.load_state_dict(early_stopping.best_state_dict)
+
+        return best_loss, model, dict(history)
 
     def _train_step(
         self,
         loader: torch.utils.data.DataLoader,
         optimizer: torch.optim.Optimizer,
         model: torch.nn.Module,
+        ce_criterion: torch.nn.Module,
+        *,
         l1_penalty: float,
-
+        kl_beta: torch.Tensor | float,
     ) -> float:
-        """
-
-        This method performs a single training epoch for the VAE model. It processes batches of data, computes the reconstruction and KL divergence losses, applies L1 regularization if specified, and updates the model parameters. The method includes safeguards against non-finite values in the model outputs and gradients to ensure stable training.
+        """Single epoch train step across batches (focal CE + KL + optional L1).
 
         Args:
             loader (torch.utils.data.DataLoader): Training data loader.
             optimizer (torch.optim.Optimizer): Optimizer.
             model (torch.nn.Module): VAE model.
+            ce_criterion (torch.nn.Module): Cross-entropy loss function.
            l1_penalty (float): L1 regularization coefficient.
-
+            kl_beta (torch.Tensor | float): KL divergence weight.
 
         Returns:
-            float: Average training loss
+            float: Average training loss.
+
         """
         model.train()
-        running
+        running = 0.0
+        num_batches = 0
+
+        nF_model = self.num_features_
+        nC_model = self.num_classes_
         l1_params = tuple(p for p in model.parameters() if p.requires_grad)
-        if class_weights is not None and class_weights.device != self.device:
-            class_weights = class_weights.to(self.device)
 
-        for
+        for X_batch, y_batch, m_batch in loader:
             optimizer.zero_grad(set_to_none=True)
+            X_batch = X_batch.to(self.device, non_blocking=True).float()
+            y_batch = y_batch.to(self.device, non_blocking=True).long()
+            m_batch = m_batch.to(self.device, non_blocking=True).bool()
 
-
-
+            raw = model(X_batch)
+            logits0 = raw[0]
 
-
-
+            expected = X_batch.shape[0] * nF_model * nC_model
+            if logits0.numel() != expected:
+                msg = f"VAE logits size mismatch: got {logits0.numel()}, expected {expected}"
+                self.logger.error(msg)
+                raise ValueError(msg)
 
-
-
-
-
-
-
-
-
-            if (
-                not torch.isfinite(recon_logits).all()
-                or not torch.isfinite(mu).all()
-                or not torch.isfinite(logvar).all()
-            ):
+            logits_masked = logits0.view(-1, nC_model)
+            logits_masked = logits_masked[m_batch.view(-1)]
+
+            targets_masked = y_batch.view(-1)
+            targets_masked = targets_masked[m_batch.view(-1)]
+
+            mask_flat = m_batch.view(-1)
+            if not bool(mask_flat.any()):
                 continue
 
-
-
-            gamma = max(0.0, min(gamma, 10.0))
+            # average number of masked loci per sample (scalar)
+            recon_scale = (mask_flat.sum().float() / float(X_batch.shape[0])).detach()
 
             loss = compute_vae_loss(
-
-
-
-
-
-
-
+                ce_criterion,
+                logits_masked,
+                targets_masked,
+                mu=raw[1],
+                logvar=raw[2],
+                kl_beta=kl_beta,
+                recon_scale=recon_scale,
             )
 
             if l1_penalty > 0:
@@ -796,171 +910,279 @@ class ImputeVAE(BaseNNImputer):
|
|
|
796
910
|
l1 = l1 + p.abs().sum()
|
|
797
911
|
loss = loss + l1_penalty * l1
|
|
798
912
|
|
|
799
|
-
if not torch.isfinite(loss):
|
|
800
|
-
continue
|
|
801
|
-
|
|
802
913
|
loss.backward()
|
|
803
914
|
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
|
|
804
|
-
|
|
805
|
-
# skip update if any grad is non-finite
|
|
806
|
-
bad = any(
|
|
807
|
-
p.grad is not None and not torch.isfinite(p.grad).all()
|
|
808
|
-
for p in model.parameters()
|
|
809
|
-
)
|
|
810
|
-
if bad:
|
|
811
|
-
optimizer.zero_grad(set_to_none=True)
|
|
812
|
-
continue
|
|
813
|
-
|
|
814
915
|
optimizer.step()
|
|
815
916
|
|
|
816
917
|
running += float(loss.detach().item())
|
|
817
|
-
|
|
918
|
+
num_batches += 1
|
|
919
|
+
|
|
920
|
+
return float("inf") if num_batches == 0 else running / num_batches
|
|
921
|
+
|
|
922
|
+
def _val_step(
|
|
923
|
+
self,
|
|
924
|
+
loader: torch.utils.data.DataLoader,
|
|
925
|
+
model: torch.nn.Module,
|
|
926
|
+
ce_criterion: torch.nn.Module,
|
|
927
|
+
*,
|
|
928
|
+
l1_penalty: float,
|
|
929
|
+
kl_beta: torch.Tensor | float = 1.0,
|
|
930
|
+
) -> float:
|
|
931
|
+
"""Validation step for a single epoch (focal CE + KL + optional L1).
|
|
932
|
+
|
|
933
|
+
Args:
|
|
934
|
+
loader (torch.utils.data.DataLoader): Validation data loader.
|
|
935
|
+
model (torch.nn.Module): VAE model.
|
|
936
|
+
ce_criterion (torch.nn.Module): Cross-entropy loss function.
|
|
937
|
+
l1_penalty (float): L1 regularization coefficient.
|
|
938
|
+
kl_beta (torch.Tensor | float): KL divergence weight.
|
|
939
|
+
|
|
940
|
+
Returns:
|
|
941
|
+
float: Average validation loss.
|
|
942
|
+
"""
|
|
943
|
+
model.eval()
|
|
944
|
+
running = 0.0
|
|
945
|
+
num_batches = 0
|
|
946
|
+
|
|
947
|
+
nF_model = self.num_features_
|
|
948
|
+
nC_model = self.num_classes_
|
|
949
|
+
l1_params = tuple(p for p in model.parameters() if p.requires_grad)
|
|
950
|
+
|
|
951
|
+
with torch.no_grad():
|
|
952
|
+
for X_batch, y_batch, m_batch in loader:
|
|
953
|
+
X_batch = X_batch.to(self.device, non_blocking=True).float()
|
|
954
|
+
y_batch = y_batch.to(self.device, non_blocking=True).long()
|
|
955
|
+
m_batch = m_batch.to(self.device, non_blocking=True).bool()
|
|
956
|
+
|
|
957
|
+
raw = model(X_batch)
|
|
958
|
+
logits0 = raw[0]
|
|
959
|
+
|
|
960
|
+
expected = X_batch.shape[0] * nF_model * nC_model
|
|
961
|
+
if logits0.numel() != expected:
|
|
962
|
+
msg = f"VAE logits size mismatch: got {logits0.numel()}, expected {expected}"
|
|
963
|
+
self.logger.error(msg)
|
|
964
|
+
raise ValueError(msg)
|
|
965
|
+
|
|
966
|
+
logits_masked = logits0.view(-1, nC_model)
|
|
967
|
+
logits_masked = logits_masked[m_batch.view(-1)]
|
|
968
|
+
|
|
969
|
+
targets_masked = y_batch.view(-1)
|
|
970
|
+
targets_masked = targets_masked[m_batch.view(-1)]
|
|
971
|
+
|
|
972
|
+
mask_flat = m_batch.view(-1)
|
|
973
|
+
|
|
974
|
+
if not bool(mask_flat.any()):
|
|
975
|
+
continue
|
|
976
|
+
|
|
977
|
+
# average number of masked loci per sample (scalar)
|
|
978
|
+
recon_scale = (
|
|
979
|
+
mask_flat.sum().float() / float(X_batch.shape[0])
|
|
980
|
+
).detach()
|
|
981
|
+
|
|
982
|
+
loss = compute_vae_loss(
|
|
983
|
+
ce_criterion,
|
|
984
|
+
logits_masked,
|
|
985
|
+
targets_masked,
|
|
986
|
+
mu=raw[1],
|
|
987
|
+
logvar=raw[2],
|
|
988
|
+
kl_beta=kl_beta,
|
|
989
|
+
recon_scale=recon_scale,
|
|
990
|
+
)
|
|
991
|
+
|
|
992
|
+
if l1_penalty > 0:
|
|
993
|
+
l1 = torch.zeros((), device=self.device)
|
|
994
|
+
for p in l1_params:
|
|
995
|
+
l1 = l1 + p.abs().sum()
|
|
996
|
+
loss = loss + l1_penalty * l1
|
|
997
|
+
|
|
998
|
+
if not torch.isfinite(loss):
|
|
999
|
+
continue
|
|
818
1000
|
|
|
819
|
-
|
|
1001
|
+
running += float(loss.item())
|
|
1002
|
+
num_batches += 1
|
|
1003
|
+
|
|
1004
|
+
return float("inf") if num_batches == 0 else running / num_batches
|
|
820
1005
|
|
|
821
1006
|
def _predict(
|
|
822
1007
|
self,
|
|
823
1008
|
model: torch.nn.Module,
|
|
824
1009
|
X: np.ndarray | torch.Tensor,
|
|
1010
|
+
*,
|
|
825
1011
|
return_proba: bool = False,
|
|
826
|
-
) ->
|
|
827
|
-
"""Predict
|
|
1012
|
+
) -> tuple[np.ndarray, np.ndarray | None]:
|
|
1013
|
+
"""Predict categorical genotype labels from logits.
|
|
828
1014
|
|
|
829
|
-
This method uses the trained
|
|
1015
|
+
This method uses the trained model to predict genotype labels for the provided input data. It handles both 0/1/2 encoded matrices and one-hot encoded matrices, converting them as necessary for model input. The method returns the predicted labels and, optionally, the predicted probabilities.
|
|
830
1016
|
|
|
831
1017
|
Args:
|
|
832
1018
|
model (torch.nn.Module): Trained model.
|
|
833
|
-
X (np.ndarray | torch.Tensor): 0/1/2 matrix with -1 for missing.
|
|
834
|
-
return_proba (bool): If True,
|
|
1019
|
+
X (np.ndarray | torch.Tensor): 0/1/2 matrix with -1 for missing, or one-hot encoded (B, L, K).
|
|
1020
|
+
return_proba (bool): If True, return probabilities.
|
|
835
1021
|
|
|
836
1022
|
Returns:
|
|
837
|
-
|
|
1023
|
+
tuple[np.ndarray, np.ndarray | None]: (labels, probas|None).
|
|
838
1024
|
"""
|
|
839
1025
|
if model is None:
|
|
840
|
-
msg =
|
|
1026
|
+
msg = (
|
|
1027
|
+
"Model passed to predict() is not trained. "
|
|
1028
|
+
"Call fit() before predict()."
|
|
1029
|
+
)
|
|
841
1030
|
self.logger.error(msg)
|
|
842
1031
|
raise NotFittedError(msg)
|
|
843
1032
|
|
|
844
1033
|
model.eval()
|
|
1034
|
+
|
|
1035
|
+
nF = self.num_features_
|
|
1036
|
+
nC = self.num_classes_
|
|
1037
|
+
|
|
1038
|
+
if isinstance(X, torch.Tensor):
|
|
1039
|
+
X_tensor = X
|
|
1040
|
+
else:
|
|
1041
|
+
X_tensor = torch.from_numpy(X)
|
|
1042
|
+
X_tensor = X_tensor.float()
|
|
1043
|
+
|
|
1044
|
+
if X_tensor.device != self.device:
|
|
1045
|
+
X_tensor = X_tensor.to(self.device)
|
|
1046
|
+
|
|
1047
|
+
if X_tensor.dim() == 2:
|
|
1048
|
+
# 0/1/2 matrix -> one-hot for model input
|
|
1049
|
+
X_tensor = self._one_hot_encode_012(X_tensor, num_classes=nC)
|
|
1050
|
+
X_tensor = X_tensor.float()
|
|
1051
|
+
|
|
1052
|
+
if X_tensor.device != self.device:
|
|
1053
|
+
X_tensor = X_tensor.to(self.device)
|
|
1054
|
+
|
|
1055
|
+
elif X_tensor.dim() != 3:
|
|
1056
|
+
msg = f"_predict expects 2D 0/1/2 inputs or 3D one-hot inputs; got shape {tuple(X_tensor.shape)}."
|
|
1057
|
+
self.logger.error(msg)
|
|
1058
|
+
raise ValueError(msg)
|
|
1059
|
+
|
|
1060
|
+
if X_tensor.shape[1] != nF or X_tensor.shape[2] != nC:
|
|
1061
|
+
msg = f"_predict input shape mismatch: expected (B, {nF}, {nC}), got {tuple(X_tensor.shape)}."
|
|
1062
|
+
self.logger.error(msg)
|
|
1063
|
+
raise ValueError(msg)
|
|
1064
|
+
|
|
1065
|
+
X_tensor = X_tensor.reshape(X_tensor.shape[0], nF * nC)
|
|
1066
|
+
|
|
845
1067
|
with torch.no_grad():
|
|
846
|
-
|
|
847
|
-
|
|
848
|
-
x_ohe = self._one_hot_encode_012(X_tensor)
|
|
849
|
-
outputs = model(x_ohe) # first element must be recon logits
|
|
850
|
-
logits = outputs[0].view(-1, self.num_features_, self.num_classes_)
|
|
1068
|
+
raw = model(X_tensor)
|
|
1069
|
+
logits = raw[0].view(-1, nF, nC)
|
|
851
1070
|
probas = torch.softmax(logits, dim=-1)
|
|
852
1071
|
labels = torch.argmax(probas, dim=-1)
|
|
853
1072
|
|
|
854
1073
|
if return_proba:
|
|
855
1074
|
return labels.cpu().numpy(), probas.cpu().numpy()
|
|
856
|
-
|
|
857
|
-
return labels.cpu().numpy()
|
|
1075
|
+
return labels.cpu().numpy(), None
|
|
858
1076
|
|
|
859
1077
|
def _evaluate_model(
|
|
860
1078
|
self,
|
|
861
|
-
X_val: np.ndarray,
|
|
862
1079
|
model: torch.nn.Module,
|
|
863
|
-
|
|
864
|
-
|
|
865
|
-
|
|
1080
|
+
X: np.ndarray | torch.Tensor,
|
|
1081
|
+
y: np.ndarray,
|
|
1082
|
+
eval_mask: np.ndarray,
|
|
866
1083
|
*,
|
|
867
|
-
|
|
1084
|
+
objective_mode: bool = False,
|
|
868
1085
|
) -> Dict[str, float]:
|
|
869
|
-
"""Evaluate
|
|
1086
|
+
"""Evaluate model performance on masked genotypes.
|
|
870
1087
|
|
|
871
|
-
This method evaluates the trained
|
|
1088
|
+
This method evaluates the performance of the trained model on a given dataset using a specified evaluation mask. It computes various classification metrics based on the predicted labels and probabilities, comparing them to the ground truth labels. The method returns a dictionary of evaluation metrics.
|
|
872
1089
|
|
|
873
1090
|
Args:
|
|
874
|
-
X_val (np.ndarray): Validation 0/1/2 matrix with -1 for missing.
|
|
875
1091
|
model (torch.nn.Module): Trained model.
|
|
876
|
-
|
|
877
|
-
|
|
878
|
-
|
|
879
|
-
|
|
1092
|
+
X (np.ndarray | torch.Tensor): 0/1/2 matrix with -1 for missing, or one-hot encoded (B, L, K).
|
|
1093
|
+
y (np.ndarray): Ground truth 0/1/2 matrix with -1 for missing.
|
|
1094
|
+
eval_mask (np.ndarray): Boolean mask indicating which genotypes to evaluate.
|
|
1095
|
+
objective_mode (bool): If True, suppresses verbose output.
|
|
880
1096
|
|
|
881
1097
|
Returns:
|
|
882
|
-
Dict[str, float]:
|
|
883
|
-
|
|
884
|
-
Raises:
|
|
885
|
-
NotFittedError: If called before fit().
|
|
1098
|
+
Dict[str, float]: Evaluation metrics.
|
|
886
1099
|
"""
|
|
887
|
-
|
|
888
|
-
|
|
889
|
-
|
|
1100
|
+
if model is None:
|
|
1101
|
+
msg = "Model passed to _evaluate_model() is not fitted. Call fit() before evaluation."
|
|
1102
|
+
self.logger.error(msg)
|
|
1103
|
+
raise NotFittedError(msg)
|
|
890
1104
|
|
|
891
|
-
|
|
892
|
-
|
|
893
|
-
# FIX 1: Match rows (shape[0]) only to allow feature subsets (tune_fast)
|
|
894
|
-
if (
|
|
895
|
-
hasattr(self, "X_val_")
|
|
896
|
-
and getattr(self, "X_val_", None) is not None
|
|
897
|
-
and X_val.shape[0] == self.X_val_.shape[0]
|
|
898
|
-
):
|
|
899
|
-
GT_ref = getattr(self, "GT_test_full_", self.ground_truth_)
|
|
900
|
-
elif (
|
|
901
|
-
hasattr(self, "X_train_")
|
|
902
|
-
and getattr(self, "X_train_", None) is not None
|
|
903
|
-
and X_val.shape[0] == self.X_train_.shape[0]
|
|
904
|
-
):
|
|
905
|
-
GT_ref = getattr(self, "GT_train_full_", self.ground_truth_)
|
|
906
|
-
else:
|
|
907
|
-
GT_ref = self.ground_truth_
|
|
1105
|
+
pred_labels, pred_probas = self._predict(model=model, X=X, return_proba=True)
|
|
908
1106
|
|
|
909
|
-
|
|
910
|
-
|
|
911
|
-
|
|
912
|
-
|
|
1107
|
+
if pred_probas is None:
|
|
1108
|
+
msg = "Predicted probabilities are None in _evaluate_model()."
|
|
1109
|
+
self.logger.error(msg)
|
|
1110
|
+
raise ValueError(msg)
|
|
913
1111
|
|
|
914
|
-
|
|
915
|
-
|
|
916
|
-
|
|
1112
|
+
y_true_flat = y[eval_mask].astype(np.int8, copy=False)
|
|
1113
|
+
y_pred_flat = pred_labels[eval_mask].astype(np.int8, copy=False)
|
|
1114
|
+
y_proba_flat = pred_probas[eval_mask].astype(np.float32, copy=False)
|
|
1115
|
+
|
|
1116
|
+
valid = y_true_flat >= 0
|
|
1117
|
+
y_true_flat = y_true_flat[valid]
|
|
1118
|
+
y_pred_flat = y_pred_flat[valid]
|
|
1119
|
+
y_proba_flat = y_proba_flat[valid]
|
|
1120
|
+
|
|
1121
|
+
if y_true_flat.size == 0:
|
|
1122
|
+
return {self.tune_metric: 0.0}
|
|
917
1123
|
|
|
918
|
-
#
|
|
919
|
-
if
|
|
920
|
-
|
|
1124
|
+
# --- Hard assertions on probability shape ---
|
|
1125
|
+
if y_proba_flat.ndim != 2:
|
|
1126
|
+
msg = f"Expected y_proba_flat to be 2D (n_eval, n_classes); got shape {y_proba_flat.shape}."
|
|
1127
|
+
self.logger.error(msg)
|
|
1128
|
+
raise ValueError(msg)
|
|
1129
|
+
|
|
1130
|
+
K = int(y_proba_flat.shape[1])
|
|
1131
|
+
if self.is_haploid_:
|
|
1132
|
+
if K not in (2, 3):
|
|
1133
|
+
msg = f"Haploid evaluation expects 2 or 3 classes; got {K}."
|
|
1134
|
+
self.logger.error(msg)
|
|
1135
|
+
raise ValueError(msg)
|
|
1136
|
+
else:
|
|
1137
|
+
if K != 3:
|
|
1138
|
+
msg = f"Diploid evaluation expects 3 classes; got {K}."
|
|
1139
|
+
self.logger.error(msg)
|
|
1140
|
+
raise ValueError(msg)
|
|
1141
|
+
|
|
1142
|
+
if not self.is_haploid_:
|
|
1143
|
+
if np.any((y_true_flat < 0) | (y_true_flat > 2)):
|
|
921
1144
|
msg = (
|
|
922
|
-
|
|
923
|
-
f"does not match X_val rows {X_val.shape[0]}"
|
|
1145
|
+
"Diploid y_true_flat contains values outside {0,1,2} after masking."
|
|
924
1146
|
)
|
|
925
1147
|
self.logger.error(msg)
|
|
926
1148
|
raise ValueError(msg)
|
|
927
1149
|
|
|
928
|
-
|
|
929
|
-
|
|
1150
|
+
# --- Harmonize for haploid vs diploid ---
|
|
1151
|
+
if self.is_haploid_:
|
|
1152
|
+
# Binary scoring: REF=0, ALT=1 (treat any non-zero as ALT)
|
|
1153
|
+
y_true_flat = (y_true_flat > 0).astype(np.int8, copy=False)
|
|
1154
|
+
y_pred_flat = (y_pred_flat > 0).astype(np.int8, copy=False)
|
|
1155
|
+
|
|
1156
|
+
K = y_proba_flat.shape[1]
|
|
1157
|
+
if K == 2:
|
|
1158
|
+
pass
|
|
1159
|
+
elif K == 3:
|
|
1160
|
+
proba_2 = np.empty((y_proba_flat.shape[0], 2), dtype=y_proba_flat.dtype)
|
|
1161
|
+
proba_2[:, 0] = y_proba_flat[:, 0]
|
|
1162
|
+
proba_2[:, 1] = y_proba_flat[:, 1] + y_proba_flat[:, 2]
|
|
1163
|
+
y_proba_flat = proba_2
|
|
930
1164
|
else:
|
|
931
|
-
|
|
932
|
-
|
|
933
|
-
|
|
934
|
-
|
|
935
|
-
eval_mask = eval_mask & finite_mask & (GT_ref != -1)
|
|
936
|
-
|
|
937
|
-
y_true_flat = GT_ref[eval_mask].astype(np.int64, copy=False)
|
|
938
|
-
y_pred_flat = pred_labels[eval_mask].astype(np.int64, copy=False)
|
|
939
|
-
y_proba_flat = pred_probas[eval_mask].astype(np.float64, copy=False)
|
|
1165
|
+
msg = f"Haploid evaluation expects 2 or 3 prob columns; got {K}"
|
|
1166
|
+
self.logger.error(msg)
|
|
1167
|
+
raise ValueError(msg)
|
|
940
1168
|
|
|
941
|
-
|
|
942
|
-
|
|
1169
|
+
labels_for_scoring = [0, 1]
|
|
1170
|
+
target_names = ["REF", "ALT"]
|
|
1171
|
+
else:
|
|
1172
|
+
if y_proba_flat.shape[1] != 3:
|
|
1173
|
+
msg = f"Diploid evaluation expects 3 prob columns; got {y_proba_flat.shape[1]}"
|
|
1174
|
+
self.logger.error(msg)
|
|
1175
|
+
raise ValueError(msg)
|
|
1176
|
+
labels_for_scoring = [0, 1, 2]
|
|
1177
|
+
target_names = ["REF", "HET", "ALT"]
|
|
943
1178
|
|
|
944
|
-
#
|
|
1179
|
+
# Ensure valid probability simplex after masking/collapsing
|
|
945
1180
|
y_proba_flat = np.clip(y_proba_flat, 0.0, 1.0)
|
|
946
1181
|
row_sums = y_proba_flat.sum(axis=1, keepdims=True)
|
|
947
|
-
row_sums[row_sums == 0] = 1.0
|
|
1182
|
+
row_sums[row_sums == 0.0] = 1.0
|
|
948
1183
|
y_proba_flat = y_proba_flat / row_sums
|
|
949
1184
|
|
|
950
|
-
|
|
951
|
-
target_names = ["REF", "ALT"] if self.is_haploid else ["REF", "HET", "ALT"]
|
|
952
|
-
|
|
953
|
-
if self.is_haploid:
|
|
954
|
-
y_true_flat = y_true_flat.copy()
|
|
955
|
-
y_pred_flat = y_pred_flat.copy()
|
|
956
|
-
y_true_flat[y_true_flat == 2] = 1
|
|
957
|
-
y_pred_flat[y_pred_flat == 2] = 1
|
|
958
|
-
proba_2 = np.zeros((len(y_proba_flat), 2), dtype=y_proba_flat.dtype)
|
|
959
|
-
proba_2[:, 0] = y_proba_flat[:, 0]
|
|
960
|
-
proba_2[:, 1] = y_proba_flat[:, 2]
|
|
961
|
-
y_proba_flat = proba_2
|
|
962
|
-
|
|
963
|
-
y_true_ohe = np.eye(len(labels_for_scoring))[y_true_flat]
|
|
1185
|
+
y_true_ohe = np.eye(len(labels_for_scoring), dtype=np.int8)[y_true_flat]
|
|
964
1186
|
|
|
965
1187
|
metrics = self.scorers_.evaluate(
|
|
966
1188
|
y_true_flat,
|
|
@@ -968,16 +1190,29 @@ class ImputeVAE(BaseNNImputer):
|
|
|
968
1190
|
y_true_ohe,
|
|
969
1191
|
y_proba_flat,
|
|
970
1192
|
objective_mode,
|
|
971
|
-
|
|
1193
|
+
cast(
|
|
1194
|
+
Literal[
|
|
1195
|
+
"pr_macro",
|
|
1196
|
+
"roc_auc",
|
|
1197
|
+
"accuracy",
|
|
1198
|
+
"f1",
|
|
1199
|
+
"average_precision",
|
|
1200
|
+
"precision",
|
|
1201
|
+
"recall",
|
|
1202
|
+
"mcc",
|
|
1203
|
+
"jaccard",
|
|
1204
|
+
],
|
|
1205
|
+
self.tune_metric,
|
|
1206
|
+
),
|
|
972
1207
|
)
|
|
973
1208
|
|
|
974
1209
|
if not objective_mode:
|
|
975
|
-
|
|
976
|
-
|
|
977
|
-
|
|
978
|
-
|
|
1210
|
+
if self.verbose or self.debug:
|
|
1211
|
+
pm = PrettyMetrics(
|
|
1212
|
+
metrics, precision=2, title=f"{self.model_name} Validation Metrics"
|
|
1213
|
+
)
|
|
1214
|
+
pm.render()
|
|
979
1215
|
|
|
980
|
-
# Primary report
|
|
981
1216
|
self._make_class_reports(
|
|
982
1217
|
y_true=y_true_flat,
|
|
983
1218
|
y_pred_proba=y_proba_flat,
|
|
@@ -986,16 +1221,17 @@ class ImputeVAE(BaseNNImputer):
|
|
|
986
1221
|
labels=target_names,
|
|
987
1222
|
)
|
|
988
1223
|
|
|
989
|
-
# IUPAC decode
|
|
990
|
-
|
|
991
|
-
|
|
992
|
-
|
|
993
|
-
|
|
994
|
-
|
|
995
|
-
|
|
996
|
-
|
|
997
|
-
|
|
998
|
-
)
|
|
1224
|
+
# --- IUPAC decode and 10-base integer report ---
|
|
1225
|
+
y_true_matrix = np.array(y, copy=True)
|
|
1226
|
+
y_pred_matrix = np.array(pred_labels, copy=True)
|
|
1227
|
+
|
|
1228
|
+
if self.is_haploid_:
|
|
1229
|
+
# Map any ALT-coded >0 to 2 for decode_012, preserve missing -1
|
|
1230
|
+
y_true_matrix = np.where(y_true_matrix > 0, 2, y_true_matrix)
|
|
1231
|
+
y_pred_matrix = np.where(y_pred_matrix > 0, 2, y_pred_matrix)
|
|
1232
|
+
|
|
1233
|
+
y_true_dec = self.decode_012(y_true_matrix)
|
|
1234
|
+
y_pred_dec = self.decode_012(y_pred_matrix)
|
|
999
1235
|
|
|
1000
1236
|
encodings_dict = {
|
|
1001
1237
|
"A": 0,
|
|
@@ -1043,71 +1279,71 @@ class ImputeVAE(BaseNNImputer):
|
|
|
1043
1279
|
|
|
1044
1280
|
Returns:
|
|
1045
1281
|
float: Value of the tuning metric to be optimized.
|
|
1282
|
+
|
|
1283
|
+
Raises:
|
|
1284
|
+
RuntimeError: If model training returns None.
|
|
1285
|
+
optuna.exceptions.TrialPruned: If training fails unexpectedly or is unpromising.
|
|
1046
1286
|
"""
|
|
1047
1287
|
try:
|
|
1048
1288
|
params = self._sample_hyperparameters(trial)
|
|
1049
1289
|
|
|
1050
|
-
X_train = getattr(self, "X_train_", self.ground_truth_[self.train_idx_])
|
|
1051
|
-
X_val = getattr(self, "X_val_", self.ground_truth_[self.test_idx_])
|
|
1052
|
-
|
|
1053
|
-
class_weights = self._normalize_class_weights(
|
|
1054
|
-
self._class_weights_from_zygosity(X_train)
|
|
1055
|
-
)
|
|
1056
|
-
train_loader = self._get_data_loader(X_train)
|
|
1057
|
-
|
|
1058
1290
|
model = self.build_model(self.Model, params["model_params"])
|
|
1059
1291
|
model.apply(self.initialize_weights)
|
|
1060
1292
|
|
|
1061
|
-
lr: float = params["
|
|
1293
|
+
lr: float = params["learning_rate"]
|
|
1062
1294
|
l1_penalty: float = params["l1_penalty"]
|
|
1063
1295
|
|
|
1064
|
-
|
|
1065
|
-
|
|
1296
|
+
class_weights = self._class_weights_from_zygosity(
|
|
1297
|
+
self.y_train_,
|
|
1298
|
+
train_mask=self.sim_mask_train_ & ~self.orig_mask_train_,
|
|
1299
|
+
inverse=params["inverse"],
|
|
1300
|
+
normalize=params["normalize"],
|
|
1301
|
+
max_ratio=self.max_ratio if self.max_ratio is not None else 5.0,
|
|
1302
|
+
power=params["power"],
|
|
1303
|
+
)
|
|
1304
|
+
|
|
1305
|
+
res = self._train_and_validate_model(
|
|
1066
1306
|
model=model,
|
|
1067
|
-
loader=train_loader,
|
|
1068
1307
|
lr=lr,
|
|
1069
1308
|
l1_penalty=l1_penalty,
|
|
1309
|
+
params=params,
|
|
1070
1310
|
trial=trial,
|
|
1071
|
-
return_history=False,
|
|
1072
1311
|
class_weights=class_weights,
|
|
1073
|
-
|
|
1074
|
-
|
|
1075
|
-
prune_metric=self.tune_metric,
|
|
1076
|
-
prune_warmup_epochs=5,
|
|
1077
|
-
eval_interval=self.tune_eval_interval,
|
|
1078
|
-
eval_requires_latents=False,
|
|
1079
|
-
eval_latent_steps=0,
|
|
1080
|
-
eval_latent_lr=0.0,
|
|
1081
|
-
eval_latent_weight_decay=0.0,
|
|
1082
|
-
)
|
|
1083
|
-
|
|
1084
|
-
eval_mask = (
|
|
1085
|
-
self.sim_mask_test_
|
|
1086
|
-
if (
|
|
1087
|
-
self.simulate_missing
|
|
1088
|
-
and getattr(self, "sim_mask_test_", None) is not None
|
|
1089
|
-
)
|
|
1090
|
-
else None
|
|
1312
|
+
kl_beta_schedule=params["kl_beta_schedule"],
|
|
1313
|
+
gamma_schedule=params["gamma_schedule"],
|
|
1091
1314
|
)
|
|
1315
|
+
model = res[1]
|
|
1092
1316
|
|
|
1093
1317
|
if model is None:
|
|
1094
|
-
|
|
1318
|
+
msg = "Model training returned None in tuning objective."
|
|
1319
|
+
self.logger.error(msg)
|
|
1320
|
+
raise RuntimeError(msg)
|
|
1095
1321
|
|
|
1096
1322
|
metrics = self._evaluate_model(
|
|
1097
|
-
|
|
1323
|
+
model=model,
|
|
1324
|
+
X=self.X_val_corrupted_,
|
|
1325
|
+
y=self.y_val_,
|
|
1326
|
+
eval_mask=self.sim_mask_val_ & ~self.orig_mask_val_,
|
|
1327
|
+
objective_mode=True,
|
|
1098
1328
|
)
|
|
1099
|
-
|
|
1329
|
+
|
|
1330
|
+
self._clear_resources(model)
|
|
1100
1331
|
return metrics[self.tune_metric]
|
|
1101
1332
|
|
|
1102
1333
|
except Exception as e:
|
|
1103
|
-
#
|
|
1104
|
-
|
|
1105
|
-
|
|
1106
|
-
|
|
1334
|
+
# Unexpected failure: surface full details in logs while still
|
|
1335
|
+
# pruning the trial to keep sweeps moving.
|
|
1336
|
+
err_type = type(e).__name__
|
|
1337
|
+
self.logger.warning(
|
|
1338
|
+
f"Trial {trial.number} failed due to exception {err_type}: {e}"
|
|
1107
1339
|
)
|
|
1340
|
+
self.logger.debug(traceback.format_exc())
|
|
1341
|
+
raise optuna.exceptions.TrialPruned(
|
|
1342
|
+
f"Trial {trial.number} failed due to an exception. {err_type}: {e}. Enable debug logging for full traceback."
|
|
1343
|
+
) from e
|
|
1108
1344
|
|
|
1109
1345
|
def _sample_hyperparameters(self, trial: optuna.Trial) -> dict:
|
|
1110
|
-
"""Sample VAE hyperparameters; hidden sizes
|
|
1346
|
+
"""Sample VAE hyperparameters; hidden sizes use BaseNNImputer helper.
|
|
1111
1347
|
|
|
1112
1348
|
Args:
|
|
1113
1349
|
trial (optuna.Trial): Optuna trial object.
|
|
@@ -1116,113 +1352,109 @@ class ImputeVAE(BaseNNImputer):
|
|
|
1116
1352
|
Dict[str, int | float | str]: Sampled hyperparameters.
|
|
1117
1353
|
"""
|
|
1118
1354
|
params = {
|
|
1119
|
-
"latent_dim": trial.suggest_int("latent_dim", 2,
|
|
1120
|
-
"
|
|
1121
|
-
"dropout_rate": trial.suggest_float("dropout_rate", 0.0, 0.
|
|
1122
|
-
"num_hidden_layers": trial.suggest_int("num_hidden_layers", 1,
|
|
1355
|
+
"latent_dim": trial.suggest_int("latent_dim", 2, 32),
|
|
1356
|
+
"learning_rate": trial.suggest_float("learning_rate", 3e-6, 1e-3, log=True),
|
|
1357
|
+
"dropout_rate": trial.suggest_float("dropout_rate", 0.0, 0.5, step=0.025),
|
|
1358
|
+
"num_hidden_layers": trial.suggest_int("num_hidden_layers", 1, 20),
|
|
1123
1359
|
"activation": trial.suggest_categorical(
|
|
1124
|
-
"activation", ["relu", "elu", "selu"]
|
|
1360
|
+
"activation", ["relu", "elu", "selu", "leaky_relu"]
|
|
1125
1361
|
),
|
|
1126
|
-
"l1_penalty": trial.suggest_float("l1_penalty", 1e-
|
|
1362
|
+
"l1_penalty": trial.suggest_float("l1_penalty", 1e-6, 1e-3, log=True),
|
|
1127
1363
|
"layer_scaling_factor": trial.suggest_float(
|
|
1128
|
-
"layer_scaling_factor", 2.0, 10.0
|
|
1364
|
+
"layer_scaling_factor", 2.0, 10.0, step=0.025
|
|
1129
1365
|
),
|
|
1130
1366
|
"layer_schedule": trial.suggest_categorical(
|
|
1131
|
-
"layer_schedule", ["pyramid", "
|
|
1367
|
+
"layer_schedule", ["pyramid", "linear"]
|
|
1368
|
+
),
|
|
1369
|
+
"power": trial.suggest_float("power", 0.1, 2.0, step=0.1),
|
|
1370
|
+
"normalize": trial.suggest_categorical("normalize", [True, False]),
|
|
1371
|
+
"inverse": trial.suggest_categorical("inverse", [True, False]),
|
|
1372
|
+
"gamma": trial.suggest_float("gamma", 0.0, 3.0, step=0.1),
|
|
1373
|
+
"kl_beta": trial.suggest_float("kl_beta", 0.1, 5.0, step=0.1),
|
|
1374
|
+
"kl_beta_schedule": trial.suggest_categorical(
|
|
1375
|
+
"kl_beta_schedule", [True, False]
|
|
1376
|
+
),
|
|
1377
|
+
"gamma_schedule": trial.suggest_categorical(
|
|
1378
|
+
"gamma_schedule", [True, False]
|
|
1132
1379
|
),
|
|
1133
|
-
# VAE-specific β (final value after anneal)
|
|
1134
|
-
"beta": trial.suggest_float("beta", 0.25, 4.0),
|
|
1135
|
-
# focal gamma (if used in VAE recon CE)
|
|
1136
|
-
"gamma": trial.suggest_float("gamma", 0.0, 5.0),
|
|
1137
1380
|
}
|
|
1138
1381
|
|
|
1139
|
-
|
|
1382
|
+
nF: int = self.num_features_
|
|
1383
|
+
nC: int = self.num_classes_
|
|
1384
|
+
input_dim = nF * nC
|
|
1385
|
+
|
|
1140
1386
|
hidden_layer_sizes = self._compute_hidden_layer_sizes(
|
|
1141
1387
|
n_inputs=input_dim,
|
|
1142
|
-
n_outputs=
|
|
1143
|
-
n_samples=len(self.
|
|
1388
|
+
n_outputs=nC,
|
|
1389
|
+
n_samples=len(self.X_train_),
|
|
1144
1390
|
n_hidden=params["num_hidden_layers"],
|
|
1391
|
+
latent_dim=params["latent_dim"],
|
|
1145
1392
|
alpha=params["layer_scaling_factor"],
|
|
1146
1393
|
schedule=params["layer_schedule"],
|
|
1147
1394
|
)
|
|
1148
1395
|
|
|
1149
|
-
# [latent_dim] + interior widths (exclude output width)
|
|
1150
|
-
hidden_only = [hidden_layer_sizes[0]] + hidden_layer_sizes[1:-1]
|
|
1151
|
-
|
|
1152
1396
|
params["model_params"] = {
|
|
1153
1397
|
"n_features": self.num_features_,
|
|
1154
|
-
"num_classes":
|
|
1155
|
-
"latent_dim": params["latent_dim"],
|
|
1398
|
+
"num_classes": nC, # categorical head: 2 or 3
|
|
1156
1399
|
"dropout_rate": params["dropout_rate"],
|
|
1157
|
-
"hidden_layer_sizes":
|
|
1400
|
+
"hidden_layer_sizes": hidden_layer_sizes,
|
|
1158
1401
|
"activation": params["activation"],
|
|
1159
|
-
|
|
1160
|
-
"beta": params["beta"],
|
|
1161
|
-
"gamma": params["gamma"],
|
|
1402
|
+
"kl_beta": params["kl_beta"],
|
|
1162
1403
|
}
|
|
1404
|
+
|
|
1163
1405
|
return params
|
|
1164
1406
|
|
|
1165
|
-
def _set_best_params(self,
|
|
1166
|
-
"""
|
|
1407
|
+
def _set_best_params(self, params: dict) -> dict:
|
|
1408
|
+
"""Update instance fields from tuned params and return model_params dict.
|
|
1167
1409
|
|
|
1168
1410
|
Args:
|
|
1169
|
-
|
|
1411
|
+
params (dict): Best hyperparameters from tuning.
|
|
1170
1412
|
|
|
1171
1413
|
Returns:
|
|
1172
|
-
dict:
|
|
1414
|
+
dict: Model parameters for building the VAE.
|
|
1173
1415
|
"""
|
|
1174
|
-
self.latent_dim =
|
|
1175
|
-
self.dropout_rate =
|
|
1176
|
-
self.learning_rate =
|
|
1177
|
-
self.l1_penalty =
|
|
1178
|
-
self.activation =
|
|
1179
|
-
self.layer_scaling_factor =
|
|
1180
|
-
self.layer_schedule =
|
|
1181
|
-
self.
|
|
1182
|
-
self.
|
|
1183
|
-
|
|
1184
|
-
|
|
1185
|
-
|
|
1186
|
-
|
|
1187
|
-
|
|
1188
|
-
|
|
1189
|
-
|
|
1190
|
-
|
|
1416
|
+
self.latent_dim = params["latent_dim"]
|
|
1417
|
+
self.dropout_rate = params["dropout_rate"]
|
|
1418
|
+
self.learning_rate = params["learning_rate"]
|
|
1419
|
+
self.l1_penalty = params["l1_penalty"]
|
|
1420
|
+
self.activation = params["activation"]
|
|
1421
|
+
self.layer_scaling_factor = params["layer_scaling_factor"]
|
|
1422
|
+
self.layer_schedule = params["layer_schedule"]
|
|
1423
|
+
self.power = params["power"]
|
|
1424
|
+
self.normalize = params["normalize"]
|
|
1425
|
+
self.inverse = params["inverse"]
|
|
1426
|
+
self.gamma = params["gamma"]
|
|
1427
|
+
self.gamma_schedule = params["gamma_schedule"]
|
|
1428
|
+
self.kl_beta = params["kl_beta"]
|
|
1429
|
+
self.kl_beta_schedule = params["kl_beta_schedule"]
|
|
1430
|
+
self.class_weights_ = self._class_weights_from_zygosity(
|
|
1431
|
+
self.y_train_,
|
|
1432
|
+
train_mask=self.sim_mask_train_ & ~self.orig_mask_train_,
|
|
1433
|
+
inverse=self.inverse,
|
|
1434
|
+
normalize=self.normalize,
|
|
1435
|
+
max_ratio=self.max_ratio if self.max_ratio is not None else 5.0,
|
|
1436
|
+
power=self.power,
|
|
1191
1437
|
)
|
|
1192
|
-
|
|
1438
|
+
nF = self.num_features_
|
|
1439
|
+
nC = self.num_classes_
|
|
1440
|
+
input_dim = nF * nC
|
|
1193
1441
|
|
|
1194
|
-
return {
|
|
1195
|
-
"n_features": self.num_features_,
|
|
1196
|
-
"latent_dim": self.latent_dim,
|
|
1197
|
-
"hidden_layer_sizes": hidden_only,
|
|
1198
|
-
"dropout_rate": self.dropout_rate,
|
|
1199
|
-
"activation": self.activation,
|
|
1200
|
-
"num_classes": self.num_classes_,
|
|
1201
|
-
"beta": self.kl_beta_final,
|
|
1202
|
-
"gamma": self.gamma,
|
|
1203
|
-
}
|
|
1204
|
-
|
|
1205
|
-
def _default_best_params(self) -> Dict[str, int | float | str | list]:
|
|
1206
|
-
"""Default VAE model params when tuning is disabled.
|
|
1207
|
-
|
|
1208
|
-
Returns:
|
|
1209
|
-
Dict[str, int | float | str | list]: VAE model parameters.
|
|
1210
|
-
"""
|
|
1211
1442
|
hidden_layer_sizes = self._compute_hidden_layer_sizes(
|
|
1212
|
-
n_inputs=
|
|
1213
|
-
n_outputs=
|
|
1214
|
-
n_samples=len(self.
|
|
1215
|
-
n_hidden=
|
|
1216
|
-
|
|
1217
|
-
|
|
1443
|
+
n_inputs=input_dim,
|
|
1444
|
+
n_outputs=nC,
|
|
1445
|
+
n_samples=len(self.X_train_),
|
|
1446
|
+
n_hidden=params["num_hidden_layers"],
|
|
1447
|
+
latent_dim=params["latent_dim"],
|
|
1448
|
+
alpha=params["layer_scaling_factor"],
|
|
1449
|
+
schedule=params["layer_schedule"],
|
|
1218
1450
|
)
|
|
1451
|
+
|
|
1219
1452
|
return {
|
|
1220
|
-
"n_features":
|
|
1453
|
+
"n_features": nF,
|
|
1221
1454
|
"latent_dim": self.latent_dim,
|
|
1222
1455
|
"hidden_layer_sizes": hidden_layer_sizes,
|
|
1223
1456
|
"dropout_rate": self.dropout_rate,
|
|
1224
1457
|
"activation": self.activation,
|
|
1225
|
-
"num_classes":
|
|
1226
|
-
"
|
|
1227
|
-
"gamma": self.gamma,
|
|
1458
|
+
"num_classes": nC,
|
|
1459
|
+
"kl_beta": params["kl_beta"],
|
|
1228
1460
|
}
|