pg-sui 1.6.16a3__py3-none-any.whl → 1.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/METADATA +26 -30
- {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/RECORD +29 -33
- pgsui/__init__.py +0 -8
- pgsui/_version.py +2 -2
- pgsui/cli.py +577 -125
- pgsui/data_processing/config.py +1 -2
- pgsui/data_processing/containers.py +203 -530
- pgsui/data_processing/transformers.py +44 -20
- pgsui/impute/deterministic/imputers/mode.py +475 -182
- pgsui/impute/deterministic/imputers/ref_allele.py +454 -147
- pgsui/impute/supervised/imputers/hist_gradient_boosting.py +4 -3
- pgsui/impute/supervised/imputers/random_forest.py +3 -2
- pgsui/impute/unsupervised/base.py +1269 -534
- pgsui/impute/unsupervised/callbacks.py +28 -33
- pgsui/impute/unsupervised/imputers/autoencoder.py +870 -841
- pgsui/impute/unsupervised/imputers/vae.py +931 -787
- pgsui/impute/unsupervised/loss_functions.py +156 -202
- pgsui/impute/unsupervised/models/autoencoder_model.py +7 -49
- pgsui/impute/unsupervised/models/vae_model.py +40 -221
- pgsui/impute/unsupervised/nn_scorers.py +53 -13
- pgsui/utils/classification_viz.py +240 -97
- pgsui/utils/misc.py +201 -3
- pgsui/utils/plotting.py +73 -58
- pgsui/utils/pretty_metrics.py +2 -6
- pgsui/utils/scorers.py +39 -0
- pgsui/impute/unsupervised/imputers/nlpca.py +0 -1666
- pgsui/impute/unsupervised/imputers/ubp.py +0 -1660
- pgsui/impute/unsupervised/models/nlpca_model.py +0 -206
- pgsui/impute/unsupervised/models/ubp_model.py +0 -200
- {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/WHEEL +0 -0
- {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/entry_points.txt +0 -0
- {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/licenses/LICENSE +0 -0
- {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/top_level.txt +0 -0
|
@@ -1,7 +1,9 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import copy
|
|
4
|
-
|
|
4
|
+
import traceback
|
|
5
|
+
from collections import defaultdict
|
|
6
|
+
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union, cast
|
|
5
7
|
|
|
6
8
|
import matplotlib.pyplot as plt
|
|
7
9
|
import numpy as np
|
|
@@ -9,17 +11,15 @@ import optuna
|
|
|
9
11
|
import torch
|
|
10
12
|
import torch.nn.functional as F
|
|
11
13
|
from sklearn.exceptions import NotFittedError
|
|
12
|
-
from sklearn.model_selection import train_test_split
|
|
13
14
|
from snpio.analysis.genotype_encoder import GenotypeEncoder
|
|
14
15
|
from snpio.utils.logging import LoggerManager
|
|
15
16
|
from torch.optim.lr_scheduler import CosineAnnealingLR
|
|
16
17
|
|
|
17
18
|
from pgsui.data_processing.config import apply_dot_overrides, load_yaml_to_dataclass
|
|
18
19
|
from pgsui.data_processing.containers import VAEConfig
|
|
19
|
-
from pgsui.data_processing.transformers import SimMissingTransformer
|
|
20
20
|
from pgsui.impute.unsupervised.base import BaseNNImputer
|
|
21
21
|
from pgsui.impute.unsupervised.callbacks import EarlyStopping
|
|
22
|
-
from pgsui.impute.unsupervised.loss_functions import compute_vae_loss
|
|
22
|
+
from pgsui.impute.unsupervised.loss_functions import FocalCELoss, compute_vae_loss
|
|
23
23
|
from pgsui.impute.unsupervised.models.vae_model import VAEModel
|
|
24
24
|
from pgsui.utils.logging_utils import configure_logger
|
|
25
25
|
from pgsui.utils.pretty_metrics import PrettyMetrics
|
|
@@ -29,14 +29,47 @@ if TYPE_CHECKING:
|
|
|
29
29
|
from snpio.read_input.genotype_data import GenotypeData
|
|
30
30
|
|
|
31
31
|
|
|
32
|
-
def
|
|
33
|
-
|
|
32
|
+
def _make_warmup_cosine_scheduler(
|
|
33
|
+
optimizer: torch.optim.Optimizer,
|
|
34
|
+
*,
|
|
35
|
+
max_epochs: int,
|
|
36
|
+
warmup_epochs: int,
|
|
37
|
+
start_factor: float = 0.1,
|
|
38
|
+
) -> torch.optim.lr_scheduler.CosineAnnealingLR | torch.optim.lr_scheduler.SequentialLR:
|
|
39
|
+
"""Create a warmup->cosine LR scheduler.
|
|
34
40
|
|
|
35
41
|
Args:
|
|
36
|
-
|
|
42
|
+
optimizer (torch.optim.Optimizer): The optimizer to schedule.
|
|
43
|
+
max_epochs (int): Total number of epochs for training.
|
|
44
|
+
warmup_epochs (int): Number of warmup epochs.
|
|
45
|
+
start_factor (float): Initial LR factor for warmup.
|
|
37
46
|
|
|
38
47
|
Returns:
|
|
39
|
-
|
|
48
|
+
torch.optim.lr_scheduler.CosineAnnealingLR | torch.optim.lr_scheduler.SequentialLR: The learning rate scheduler.
|
|
49
|
+
"""
|
|
50
|
+
warmup_epochs = int(max(0, warmup_epochs))
|
|
51
|
+
|
|
52
|
+
if warmup_epochs == 0:
|
|
53
|
+
return CosineAnnealingLR(optimizer, T_max=max_epochs)
|
|
54
|
+
|
|
55
|
+
warmup = torch.optim.lr_scheduler.LinearLR(
|
|
56
|
+
optimizer, start_factor=float(start_factor), total_iters=warmup_epochs
|
|
57
|
+
)
|
|
58
|
+
cosine = CosineAnnealingLR(optimizer, T_max=max(1, max_epochs - warmup_epochs))
|
|
59
|
+
|
|
60
|
+
return torch.optim.lr_scheduler.SequentialLR(
|
|
61
|
+
optimizer, schedulers=[warmup, cosine], milestones=[warmup_epochs]
|
|
62
|
+
)
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def ensure_vae_config(config: VAEConfig | dict | str | None) -> VAEConfig:
|
|
66
|
+
"""Ensure a VAEConfig instance from various input types.
|
|
67
|
+
|
|
68
|
+
Args:
|
|
69
|
+
config (VAEConfig | dict | str | None): Configuration input.
|
|
70
|
+
|
|
71
|
+
Returns:
|
|
72
|
+
VAEConfig: The resulting VAEConfig instance.
|
|
40
73
|
"""
|
|
41
74
|
if config is None:
|
|
42
75
|
return VAEConfig()
|
|
@@ -45,13 +78,13 @@ def ensure_vae_config(config: Union[VAEConfig, dict, str, None]) -> VAEConfig:
|
|
|
45
78
|
if isinstance(config, str):
|
|
46
79
|
return load_yaml_to_dataclass(config, VAEConfig)
|
|
47
80
|
if isinstance(config, dict):
|
|
81
|
+
cfg_in = copy.deepcopy(config)
|
|
48
82
|
base = VAEConfig()
|
|
49
|
-
|
|
50
|
-
|
|
83
|
+
preset = cfg_in.pop("preset", None)
|
|
84
|
+
if "io" in cfg_in and isinstance(cfg_in["io"], dict):
|
|
85
|
+
preset = preset or cfg_in["io"].pop("preset", None)
|
|
51
86
|
if preset:
|
|
52
87
|
base = VAEConfig.from_preset(preset)
|
|
53
|
-
# Flatten + apply
|
|
54
|
-
flat: Dict[str, object] = {}
|
|
55
88
|
|
|
56
89
|
def _flatten(prefix: str, d: dict, out: dict) -> dict:
|
|
57
90
|
for k, v in d.items():
|
|
@@ -62,15 +95,19 @@ def ensure_vae_config(config: Union[VAEConfig, dict, str, None]) -> VAEConfig:
|
|
|
62
95
|
out[kk] = v
|
|
63
96
|
return out
|
|
64
97
|
|
|
65
|
-
flat = _flatten("",
|
|
98
|
+
flat = _flatten("", cfg_in, {})
|
|
66
99
|
return apply_dot_overrides(base, flat)
|
|
67
100
|
raise TypeError("config must be a VAEConfig, dict, YAML path, or None.")
|
|
68
101
|
|
|
69
102
|
|
|
70
103
|
class ImputeVAE(BaseNNImputer):
|
|
71
|
-
"""Variational Autoencoder imputer
|
|
104
|
+
"""Variational Autoencoder (VAE) imputer for 0/1/2 genotypes.
|
|
105
|
+
|
|
106
|
+
Trains a VAE on a genotype matrix encoded as 0/1/2 with missing values represented by any negative integer. The workflow simulates missingness once on the full matrix, then creates train/val/test splits. It supports haploid and diploid data, focal-CE reconstruction loss with a KL term (optional scheduling), and Optuna-based hyperparameter tuning. Output is returned as IUPAC strings via ``decode_012``.
|
|
72
107
|
|
|
73
|
-
|
|
108
|
+
Notes:
|
|
109
|
+
- Training includes early stopping based on validation loss.
|
|
110
|
+
- The imputer can handle both haploid and diploid genotype data.
|
|
74
111
|
"""
|
|
75
112
|
|
|
76
113
|
def __init__(
|
|
@@ -79,46 +116,37 @@ class ImputeVAE(BaseNNImputer):
|
|
|
79
116
|
*,
|
|
80
117
|
tree_parser: Optional["TreeParser"] = None,
|
|
81
118
|
config: Optional[Union["VAEConfig", dict, str]] = None,
|
|
82
|
-
overrides: dict
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
sim_prop: float | None = None,
|
|
95
|
-
sim_kwargs: dict | None = None,
|
|
96
|
-
):
|
|
97
|
-
"""Initialize the VAE imputer with a unified config interface.
|
|
98
|
-
|
|
99
|
-
This initializer sets up the VAE imputer by processing the provided configuration, initializing logging, and preparing the model and data encoder. It supports configuration input as a dataclass, nested dictionary, YAML file path, or None, with optional dot-key overrides for fine-tuning specific parameters.
|
|
119
|
+
overrides: Optional[dict] = None,
|
|
120
|
+
sim_strategy: Literal[
|
|
121
|
+
"random",
|
|
122
|
+
"random_weighted",
|
|
123
|
+
"random_weighted_inv",
|
|
124
|
+
"nonrandom",
|
|
125
|
+
"nonrandom_weighted",
|
|
126
|
+
] = "random",
|
|
127
|
+
sim_prop: Optional[float] = None,
|
|
128
|
+
sim_kwargs: Optional[dict] = None,
|
|
129
|
+
) -> None:
|
|
130
|
+
"""Initialize the ImputeVAE imputer.
|
|
100
131
|
|
|
101
132
|
Args:
|
|
102
|
-
genotype_data (GenotypeData):
|
|
103
|
-
tree_parser (TreeParser
|
|
104
|
-
config (Union[VAEConfig, dict, str
|
|
105
|
-
overrides (dict
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
sim_kwargs (dict | None): Additional kwargs for SimMissingTransformer.
|
|
133
|
+
genotype_data (GenotypeData): Genotype data for imputation.
|
|
134
|
+
tree_parser (Optional[TreeParser]): Tree parser required for nonrandom strategies.
|
|
135
|
+
config (Optional[Union[VAEConfig, dict, str]]): Config dataclass, nested dict, YAML path, or None.
|
|
136
|
+
overrides (Optional[dict]): Dot-key overrides applied last with highest precedence.
|
|
137
|
+
sim_strategy (Literal["random", "random_weighted", "random_weighted_inv", "nonrandom", "nonrandom_weighted"]): Missingness simulation strategy (overrides config).
|
|
138
|
+
sim_prop (Optional[float]): Proportion of entries to simulate as missing (overrides config). Default is None.
|
|
139
|
+
sim_kwargs (Optional[dict]): Extra missingness kwargs merged into config.
|
|
110
140
|
"""
|
|
111
141
|
self.model_name = "ImputeVAE"
|
|
112
142
|
self.genotype_data = genotype_data
|
|
113
143
|
self.tree_parser = tree_parser
|
|
114
144
|
|
|
115
|
-
# Normalize configuration and apply top-precedence overrides
|
|
116
145
|
cfg = ensure_vae_config(config)
|
|
117
146
|
if overrides:
|
|
118
147
|
cfg = apply_dot_overrides(cfg, overrides)
|
|
119
148
|
self.cfg = cfg
|
|
120
149
|
|
|
121
|
-
# Logger (align with AE/NLPCA)
|
|
122
150
|
logman = LoggerManager(
|
|
123
151
|
__name__,
|
|
124
152
|
prefix=self.cfg.io.prefix,
|
|
@@ -126,12 +154,10 @@ class ImputeVAE(BaseNNImputer):
|
|
|
126
154
|
verbose=self.cfg.io.verbose,
|
|
127
155
|
)
|
|
128
156
|
self.logger = configure_logger(
|
|
129
|
-
logman.get_logger(),
|
|
130
|
-
verbose=self.cfg.io.verbose,
|
|
131
|
-
debug=self.cfg.io.debug,
|
|
157
|
+
logman.get_logger(), verbose=self.cfg.io.verbose, debug=self.cfg.io.debug
|
|
132
158
|
)
|
|
159
|
+
self.logger.propagate = False
|
|
133
160
|
|
|
134
|
-
# BaseNNImputer bootstraps device/dirs/log formatting
|
|
135
161
|
super().__init__(
|
|
136
162
|
model_name=self.model_name,
|
|
137
163
|
genotype_data=self.genotype_data,
|
|
@@ -141,11 +167,10 @@ class ImputeVAE(BaseNNImputer):
|
|
|
141
167
|
debug=self.cfg.io.debug,
|
|
142
168
|
)
|
|
143
169
|
|
|
144
|
-
# Model hook & encoder
|
|
145
170
|
self.Model = VAEModel
|
|
146
171
|
self.pgenc = GenotypeEncoder(genotype_data)
|
|
147
172
|
|
|
148
|
-
#
|
|
173
|
+
# I/O and general parameters
|
|
149
174
|
self.seed = self.cfg.io.seed
|
|
150
175
|
self.n_jobs = self.cfg.io.n_jobs
|
|
151
176
|
self.prefix = self.cfg.io.prefix
|
|
@@ -153,43 +178,41 @@ class ImputeVAE(BaseNNImputer):
|
|
|
153
178
|
self.verbose = self.cfg.io.verbose
|
|
154
179
|
self.debug = self.cfg.io.debug
|
|
155
180
|
self.rng = np.random.default_rng(self.seed)
|
|
156
|
-
self.pos_weights_: torch.Tensor | None = None
|
|
157
181
|
|
|
158
|
-
#
|
|
182
|
+
# Simulation parameters
|
|
159
183
|
sim_cfg = getattr(self.cfg, "sim", None)
|
|
160
184
|
sim_cfg_kwargs = copy.deepcopy(getattr(sim_cfg, "sim_kwargs", None) or {})
|
|
161
185
|
if sim_kwargs:
|
|
162
186
|
sim_cfg_kwargs.update(sim_kwargs)
|
|
163
187
|
if sim_cfg is None:
|
|
164
|
-
default_sim_flag = bool(simulate_missing)
|
|
165
188
|
default_strategy = "random"
|
|
166
|
-
default_prop = 0.
|
|
189
|
+
default_prop = 0.2
|
|
167
190
|
else:
|
|
168
|
-
default_sim_flag = sim_cfg.simulate_missing
|
|
169
191
|
default_strategy = sim_cfg.sim_strategy
|
|
170
192
|
default_prop = sim_cfg.sim_prop
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
)
|
|
193
|
+
|
|
194
|
+
self.simulate_missing = True
|
|
174
195
|
self.sim_strategy = sim_strategy or default_strategy
|
|
175
196
|
self.sim_prop = float(sim_prop if sim_prop is not None else default_prop)
|
|
176
197
|
self.sim_kwargs = sim_cfg_kwargs
|
|
177
198
|
|
|
178
|
-
# Model
|
|
199
|
+
# Model architecture parameters
|
|
179
200
|
self.latent_dim = self.cfg.model.latent_dim
|
|
180
201
|
self.dropout_rate = self.cfg.model.dropout_rate
|
|
181
202
|
self.num_hidden_layers = self.cfg.model.num_hidden_layers
|
|
182
203
|
self.layer_scaling_factor = self.cfg.model.layer_scaling_factor
|
|
183
204
|
self.layer_schedule = self.cfg.model.layer_schedule
|
|
184
|
-
self.activation = self.cfg.model.
|
|
185
|
-
self.gamma = self.cfg.model.gamma # focal loss focusing (for recon CE)
|
|
205
|
+
self.activation = self.cfg.model.activation
|
|
186
206
|
|
|
187
|
-
# VAE-
|
|
188
|
-
self.
|
|
189
|
-
self.
|
|
190
|
-
self.kl_ramp = self.cfg.vae.kl_ramp
|
|
207
|
+
# VAE-specific parameters
|
|
208
|
+
self.kl_beta = self.cfg.vae.kl_beta
|
|
209
|
+
self.kl_beta_schedule = self.cfg.vae.kl_beta_schedule
|
|
191
210
|
|
|
192
|
-
#
|
|
211
|
+
# Training parameters
|
|
212
|
+
self.power: float = self.cfg.train.weights_power
|
|
213
|
+
self.max_ratio: Any | float | None = self.cfg.train.weights_max_ratio
|
|
214
|
+
self.normalize: bool = self.cfg.train.weights_normalize
|
|
215
|
+
self.inverse: bool = self.cfg.train.weights_inverse
|
|
193
216
|
self.batch_size = self.cfg.train.batch_size
|
|
194
217
|
self.learning_rate = self.cfg.train.learning_rate
|
|
195
218
|
self.l1_penalty: float = self.cfg.train.l1_penalty
|
|
@@ -197,32 +220,18 @@ class ImputeVAE(BaseNNImputer):
|
|
|
197
220
|
self.min_epochs = self.cfg.train.min_epochs
|
|
198
221
|
self.epochs = self.cfg.train.max_epochs
|
|
199
222
|
self.validation_split = self.cfg.train.validation_split
|
|
200
|
-
self.
|
|
201
|
-
self.
|
|
223
|
+
self.gamma = self.cfg.train.gamma
|
|
224
|
+
self.gamma_schedule = self.cfg.train.gamma_schedule
|
|
202
225
|
|
|
203
|
-
#
|
|
226
|
+
# Hyperparameter tuning
|
|
204
227
|
self.tune = self.cfg.tune.enabled
|
|
205
|
-
self.
|
|
206
|
-
self.tune_batch_size = self.cfg.tune.batch_size
|
|
207
|
-
self.tune_epochs = self.cfg.tune.epochs
|
|
208
|
-
self.tune_eval_interval = self.cfg.tune.eval_interval
|
|
209
|
-
self.tune_metric: Literal[
|
|
210
|
-
"pr_macro",
|
|
211
|
-
"f1",
|
|
212
|
-
"accuracy",
|
|
213
|
-
"average_precision",
|
|
214
|
-
"precision",
|
|
215
|
-
"recall",
|
|
216
|
-
"roc_auc",
|
|
217
|
-
] = self.cfg.tune.metric
|
|
228
|
+
self.tune_metric = self.cfg.tune.metric
|
|
218
229
|
self.n_trials = self.cfg.tune.n_trials
|
|
219
230
|
self.tune_save_db = self.cfg.tune.save_db
|
|
220
231
|
self.tune_resume = self.cfg.tune.resume
|
|
221
|
-
self.tune_max_samples = self.cfg.tune.max_samples
|
|
222
|
-
self.tune_max_loci = self.cfg.tune.max_loci
|
|
223
232
|
self.tune_patience = self.cfg.tune.patience
|
|
224
233
|
|
|
225
|
-
# Plotting
|
|
234
|
+
# Plotting parameters
|
|
226
235
|
self.plot_format = self.cfg.plot.fmt
|
|
227
236
|
self.plot_dpi = self.cfg.plot.dpi
|
|
228
237
|
self.plot_fontsize = self.cfg.plot.fontsize
|
|
@@ -230,163 +239,281 @@ class ImputeVAE(BaseNNImputer):
|
|
|
230
239
|
self.despine = self.cfg.plot.despine
|
|
231
240
|
self.show_plots = self.cfg.plot.show
|
|
232
241
|
|
|
233
|
-
#
|
|
234
|
-
self.
|
|
235
|
-
self.num_classes_: int = 3
|
|
242
|
+
# Internal attributes set during fitting
|
|
243
|
+
self.is_haploid_: bool = False
|
|
244
|
+
self.num_classes_: int = 3
|
|
236
245
|
self.model_params: Dict[str, Any] = {}
|
|
237
|
-
self.
|
|
238
|
-
self.sim_mask_train_: np.ndarray | None = None
|
|
239
|
-
self.sim_mask_test_: np.ndarray | None = None
|
|
246
|
+
self.sim_mask_test_: np.ndarray
|
|
240
247
|
|
|
241
248
|
if self.tree_parser is None and self.sim_strategy.startswith("nonrandom"):
|
|
242
|
-
msg = "tree_parser is required for nonrandom
|
|
249
|
+
msg = "tree_parser is required for nonrandom sim strategies."
|
|
243
250
|
self.logger.error(msg)
|
|
244
251
|
raise ValueError(msg)
|
|
245
252
|
|
|
246
|
-
# -------------------- Fit -------------------- #
|
|
247
253
|
def fit(self) -> "ImputeVAE":
|
|
248
|
-
"""Fit the VAE
|
|
249
|
-
|
|
250
|
-
This method
|
|
254
|
+
"""Fit the VAE imputer model to the genotype data.
|
|
255
|
+
|
|
256
|
+
This method performs the following steps:
|
|
257
|
+
1. Validates the presence of SNP data in the genotype data.
|
|
258
|
+
2. Determines the ploidy of the genotype data and sets up haploid/diploid handling.
|
|
259
|
+
3. Simulates missingness in the genotype data based on the specified strategy.
|
|
260
|
+
4. Splits the data into training, validation, and test sets.
|
|
261
|
+
5. One-hot encodes the genotype data for model input.
|
|
262
|
+
6. Initializes data loaders for training and validation.
|
|
263
|
+
7. If hyperparameter tuning is enabled, tunes the model hyperparameters.
|
|
264
|
+
8. Builds the VAE model with the best hyperparameters.
|
|
265
|
+
9. Trains the VAE model using the training data and validates on the validation set.
|
|
266
|
+
10. Evaluates the trained model on the test set and computes performance metrics.
|
|
267
|
+
11. Saves the trained model and best hyperparameters.
|
|
268
|
+
12. Generates plots of training history if enabled.
|
|
269
|
+
13. Returns the fitted ImputeVAE instance.
|
|
251
270
|
|
|
252
271
|
Returns:
|
|
253
|
-
ImputeVAE:
|
|
254
|
-
|
|
255
|
-
Raises:
|
|
256
|
-
RuntimeError: If training fails to produce a model.
|
|
272
|
+
ImputeVAE: The fitted ImputeVAE instance.
|
|
257
273
|
"""
|
|
258
274
|
self.logger.info(f"Fitting {self.model_name} model...")
|
|
259
275
|
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
276
|
+
if self.genotype_data.snp_data is None:
|
|
277
|
+
msg = f"SNP data is required for {self.model_name}."
|
|
278
|
+
self.logger.error(msg)
|
|
279
|
+
raise AttributeError(msg)
|
|
264
280
|
|
|
265
|
-
self.
|
|
266
|
-
|
|
267
|
-
if self.simulate_missing:
|
|
268
|
-
cached_mask = (
|
|
269
|
-
None if cache_key is None else self._sim_mask_cache.get(cache_key)
|
|
270
|
-
)
|
|
271
|
-
if cached_mask is not None:
|
|
272
|
-
self.sim_mask_global_ = cached_mask.copy()
|
|
273
|
-
else:
|
|
274
|
-
tr = SimMissingTransformer(
|
|
275
|
-
genotype_data=self.genotype_data,
|
|
276
|
-
tree_parser=self.tree_parser,
|
|
277
|
-
prop_missing=self.sim_prop,
|
|
278
|
-
strategy=self.sim_strategy,
|
|
279
|
-
missing_val=-9,
|
|
280
|
-
mask_missing=True,
|
|
281
|
-
verbose=self.verbose,
|
|
282
|
-
**self.sim_kwargs,
|
|
283
|
-
)
|
|
284
|
-
tr.fit(X012.copy())
|
|
285
|
-
self.sim_mask_global_ = tr.sim_missing_mask_.astype(bool)
|
|
286
|
-
if cache_key is not None:
|
|
287
|
-
self._sim_mask_cache[cache_key] = self.sim_mask_global_.copy()
|
|
281
|
+
self.ploidy = self.cfg.io.ploidy
|
|
282
|
+
self.is_haploid_ = self.ploidy == 1
|
|
288
283
|
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
# Ploidy/classes
|
|
295
|
-
self.is_haploid = bool(
|
|
296
|
-
np.all(
|
|
297
|
-
np.isin(
|
|
298
|
-
self.genotype_data.snp_data,
|
|
299
|
-
["A", "C", "G", "T", "N", "-", ".", "?"],
|
|
300
|
-
)
|
|
301
|
-
)
|
|
302
|
-
)
|
|
303
|
-
self.ploidy = 1 if self.is_haploid else 2
|
|
304
|
-
self.num_classes_ = 2 if self.is_haploid else 3
|
|
305
|
-
self.output_classes_ = 2
|
|
306
|
-
self.logger.info(
|
|
307
|
-
f"Data is {'haploid' if self.is_haploid else 'diploid'}; "
|
|
308
|
-
f"using {self.num_classes_} classes for scoring and {self.output_classes_} output channels."
|
|
309
|
-
)
|
|
284
|
+
if self.ploidy > 2:
|
|
285
|
+
msg = f"{self.model_name} currently supports only haploid (1) or diploid (2) data; got ploidy={self.ploidy}."
|
|
286
|
+
self.logger.error(msg)
|
|
287
|
+
raise ValueError(msg)
|
|
310
288
|
|
|
311
|
-
if self.
|
|
312
|
-
self.ground_truth_[self.ground_truth_ == 2] = 1
|
|
313
|
-
X_for_model[X_for_model == 2] = 1
|
|
289
|
+
self.num_classes_ = 2 if self.is_haploid_ else 3
|
|
314
290
|
|
|
315
|
-
|
|
291
|
+
gt_full = self.pgenc.genotypes_012.copy()
|
|
292
|
+
gt_full[gt_full < 0] = -1
|
|
293
|
+
gt_full = np.nan_to_num(gt_full, nan=-1.0)
|
|
294
|
+
self.ground_truth_ = gt_full.astype(np.int8)
|
|
295
|
+
self.num_features_ = gt_full.shape[1]
|
|
316
296
|
|
|
317
|
-
# Model params (decoder outputs L*K logits)
|
|
318
297
|
self.model_params = {
|
|
319
298
|
"n_features": self.num_features_,
|
|
320
|
-
"num_classes": self.
|
|
299
|
+
"num_classes": self.num_classes_,
|
|
321
300
|
"latent_dim": self.latent_dim,
|
|
322
301
|
"dropout_rate": self.dropout_rate,
|
|
323
302
|
"activation": self.activation,
|
|
324
303
|
}
|
|
325
304
|
|
|
326
|
-
#
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
305
|
+
# Simulate missingness ONCE on the full matrix
|
|
306
|
+
sim_tup = self.sim_missing_transform(self.ground_truth_)
|
|
307
|
+
X_for_model_full = sim_tup[0]
|
|
308
|
+
self.sim_mask_ = sim_tup[1]
|
|
309
|
+
self.orig_mask_ = sim_tup[2]
|
|
310
|
+
|
|
311
|
+
# Split indices based on clean ground truth
|
|
312
|
+
self.train_idx_, self.val_idx_, self.test_idx_ = self._train_val_test_split(
|
|
313
|
+
self.ground_truth_
|
|
330
314
|
)
|
|
331
|
-
self.train_idx_, self.test_idx_ = train_idx, val_idx
|
|
332
|
-
self.X_train_ = X_for_model[train_idx]
|
|
333
|
-
self.X_val_ = X_for_model[val_idx]
|
|
334
|
-
self.GT_train_full_ = self.ground_truth_[train_idx]
|
|
335
|
-
self.GT_test_full_ = self.ground_truth_[val_idx]
|
|
336
|
-
|
|
337
|
-
if self.sim_mask_global_ is not None:
|
|
338
|
-
self.sim_mask_train_ = self.sim_mask_global_[train_idx]
|
|
339
|
-
self.sim_mask_test_ = self.sim_mask_global_[val_idx]
|
|
340
|
-
else:
|
|
341
|
-
self.sim_mask_train_ = None
|
|
342
|
-
self.sim_mask_test_ = None
|
|
343
315
|
|
|
344
|
-
#
|
|
316
|
+
# --- Clean (targets) per split ---
|
|
317
|
+
X_train_clean = self.ground_truth_[self.train_idx_].copy()
|
|
318
|
+
X_val_clean = self.ground_truth_[self.val_idx_].copy()
|
|
319
|
+
X_test_clean = self.ground_truth_[self.test_idx_].copy()
|
|
320
|
+
|
|
321
|
+
# --- Corrupted (inputs) per split (from the single simulation) ---
|
|
322
|
+
X_train_corrupted = X_for_model_full[self.train_idx_].copy()
|
|
323
|
+
X_val_corrupted = X_for_model_full[self.val_idx_].copy()
|
|
324
|
+
X_test_corrupted = X_for_model_full[self.test_idx_].copy()
|
|
325
|
+
|
|
326
|
+
# --- Masks per split ---
|
|
327
|
+
self.sim_mask_train_ = self.sim_mask_[self.train_idx_].copy()
|
|
328
|
+
self.sim_mask_val_ = self.sim_mask_[self.val_idx_].copy()
|
|
329
|
+
self.sim_mask_test_ = self.sim_mask_[self.test_idx_].copy()
|
|
330
|
+
|
|
331
|
+
self.orig_mask_train_ = self.orig_mask_[self.train_idx_].copy()
|
|
332
|
+
self.orig_mask_val_ = self.orig_mask_[self.val_idx_].copy()
|
|
333
|
+
self.orig_mask_test_ = self.orig_mask_[self.test_idx_].copy()
|
|
334
|
+
|
|
335
|
+
# Persist clean/corrupted matrices if you want them accessible later
|
|
336
|
+
self.X_train_clean_ = X_train_clean
|
|
337
|
+
self.X_val_clean_ = X_val_clean
|
|
338
|
+
self.X_test_clean_ = X_test_clean
|
|
339
|
+
|
|
340
|
+
self.X_train_corrupted_ = X_train_corrupted
|
|
341
|
+
self.X_val_corrupted_ = X_val_corrupted
|
|
342
|
+
self.X_test_corrupted_ = X_test_corrupted
|
|
343
|
+
|
|
344
|
+
# --- Haploid harmonization (do NOT resimulate; just recode values) ---
|
|
345
|
+
if self.is_haploid_:
|
|
346
|
+
|
|
347
|
+
def _haploidize(arr: np.ndarray) -> np.ndarray:
|
|
348
|
+
out = arr.copy()
|
|
349
|
+
miss = out < 0
|
|
350
|
+
out = np.where(out > 0, 1, out).astype(np.int8, copy=False)
|
|
351
|
+
out[miss] = -1
|
|
352
|
+
return out
|
|
353
|
+
|
|
354
|
+
X_train_clean = _haploidize(X_train_clean)
|
|
355
|
+
X_val_clean = _haploidize(X_val_clean)
|
|
356
|
+
X_test_clean = _haploidize(X_test_clean)
|
|
357
|
+
|
|
358
|
+
X_train_corrupted = _haploidize(X_train_corrupted)
|
|
359
|
+
X_val_corrupted = _haploidize(X_val_corrupted)
|
|
360
|
+
X_test_corrupted = _haploidize(X_test_corrupted)
|
|
361
|
+
|
|
362
|
+
# Write back the persisted versions too
|
|
363
|
+
self.X_train_clean_ = X_train_clean
|
|
364
|
+
self.X_val_clean_ = X_val_clean
|
|
365
|
+
self.X_test_clean_ = X_test_clean
|
|
366
|
+
|
|
367
|
+
self.X_train_corrupted_ = X_train_corrupted
|
|
368
|
+
self.X_val_corrupted_ = X_val_corrupted
|
|
369
|
+
self.X_test_corrupted_ = X_test_corrupted
|
|
370
|
+
|
|
371
|
+
# Final training tensors/matrices used by the pipeline
|
|
372
|
+
# Convention: X_* are corrupted inputs; y_* are clean targets
|
|
373
|
+
self.X_train_ = self.X_train_corrupted_
|
|
374
|
+
self.y_train_ = self.X_train_clean_
|
|
375
|
+
|
|
376
|
+
self.X_val_ = self.X_val_corrupted_
|
|
377
|
+
self.y_val_ = self.X_val_clean_
|
|
378
|
+
|
|
379
|
+
self.X_test_ = self.X_test_corrupted_
|
|
380
|
+
self.y_test_ = self.X_test_clean_
|
|
381
|
+
|
|
382
|
+
self.X_train_ = self._one_hot_encode_012(
|
|
383
|
+
self.X_train_, num_classes=self.num_classes_
|
|
384
|
+
)
|
|
385
|
+
self.X_val_ = self._one_hot_encode_012(
|
|
386
|
+
self.X_val_, num_classes=self.num_classes_
|
|
387
|
+
)
|
|
388
|
+
self.X_test_ = self._one_hot_encode_012(
|
|
389
|
+
self.X_test_, num_classes=self.num_classes_
|
|
390
|
+
)
|
|
391
|
+
self.X_test_ = self.X_test_.detach().cpu().numpy()
|
|
392
|
+
|
|
345
393
|
self.plotter_, self.scorers_ = self.initialize_plotting_and_scorers()
|
|
394
|
+
self.valid_class_mask_ = self._build_valid_class_mask()
|
|
395
|
+
|
|
396
|
+
loci = getattr(self, "valid_class_mask_conflict_loci_", None)
|
|
397
|
+
if loci is not None and loci.size:
|
|
398
|
+
self._repair_ref_alt_from_iupac(loci)
|
|
399
|
+
self.valid_class_mask_ = self._build_valid_class_mask()
|
|
400
|
+
|
|
401
|
+
train_loader = self._get_data_loaders(
|
|
402
|
+
self.X_train_.detach().cpu().numpy(),
|
|
403
|
+
self.y_train_,
|
|
404
|
+
~self.orig_mask_train_,
|
|
405
|
+
self.batch_size,
|
|
406
|
+
shuffle=True,
|
|
407
|
+
)
|
|
408
|
+
|
|
409
|
+
val_loader = self._get_data_loaders(
|
|
410
|
+
self.X_val_.detach().cpu().numpy(),
|
|
411
|
+
self.y_val_,
|
|
412
|
+
~self.orig_mask_val_,
|
|
413
|
+
self.batch_size,
|
|
414
|
+
shuffle=False,
|
|
415
|
+
)
|
|
416
|
+
|
|
417
|
+
self.train_loader_ = train_loader
|
|
418
|
+
self.val_loader_ = val_loader
|
|
346
419
|
|
|
347
|
-
# Optional tuning
|
|
348
420
|
if self.tune:
|
|
349
|
-
self.tune_hyperparameters()
|
|
421
|
+
self.tuned_params_ = self.tune_hyperparameters()
|
|
422
|
+
self.model_tuned_ = True
|
|
423
|
+
else:
|
|
424
|
+
self.model_tuned_ = False
|
|
425
|
+
self.class_weights_ = self._class_weights_from_zygosity(
|
|
426
|
+
self.y_train_,
|
|
427
|
+
train_mask=self.sim_mask_train_ & ~self.orig_mask_train_,
|
|
428
|
+
inverse=self.inverse,
|
|
429
|
+
normalize=self.normalize,
|
|
430
|
+
max_ratio=self.max_ratio,
|
|
431
|
+
power=self.power,
|
|
432
|
+
)
|
|
433
|
+
self.tuned_params_ = {
|
|
434
|
+
"latent_dim": self.latent_dim,
|
|
435
|
+
"learning_rate": self.learning_rate,
|
|
436
|
+
"dropout_rate": self.dropout_rate,
|
|
437
|
+
"num_hidden_layers": self.num_hidden_layers,
|
|
438
|
+
"activation": self.activation,
|
|
439
|
+
"l1_penalty": self.l1_penalty,
|
|
440
|
+
"layer_scaling_factor": self.layer_scaling_factor,
|
|
441
|
+
"layer_schedule": self.layer_schedule,
|
|
442
|
+
}
|
|
443
|
+
self.tuned_params_.update(
|
|
444
|
+
{
|
|
445
|
+
"kl_beta": self.kl_beta,
|
|
446
|
+
"kl_beta_schedule": self.kl_beta_schedule,
|
|
447
|
+
"gamma": self.gamma,
|
|
448
|
+
"gamma_schedule": self.gamma_schedule,
|
|
449
|
+
"inverse": self.inverse,
|
|
450
|
+
"normalize": self.normalize,
|
|
451
|
+
"power": self.power,
|
|
452
|
+
}
|
|
453
|
+
)
|
|
454
|
+
|
|
455
|
+
self.tuned_params_["model_params"] = self.model_params
|
|
350
456
|
|
|
351
|
-
|
|
352
|
-
|
|
457
|
+
if self.class_weights_ is not None:
|
|
458
|
+
self.logger.info(
|
|
459
|
+
f"class_weights={self.class_weights_.detach().cpu().numpy().tolist()}"
|
|
460
|
+
)
|
|
461
|
+
|
|
462
|
+
# Always start clean
|
|
463
|
+
self.best_params_ = copy.deepcopy(self.tuned_params_)
|
|
353
464
|
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
self.
|
|
465
|
+
model_params_final = {
|
|
466
|
+
"n_features": self.num_features_,
|
|
467
|
+
"num_classes": self.num_classes_,
|
|
468
|
+
"latent_dim": int(self.best_params_["latent_dim"]),
|
|
469
|
+
"dropout_rate": float(self.best_params_["dropout_rate"]),
|
|
470
|
+
"activation": str(self.best_params_["activation"]),
|
|
471
|
+
"kl_beta": float(
|
|
472
|
+
self.best_params_.get("kl_beta", getattr(self, "kl_beta", 1.0))
|
|
473
|
+
),
|
|
474
|
+
}
|
|
475
|
+
|
|
476
|
+
input_dim = self.num_features_ * self.num_classes_
|
|
477
|
+
model_params_final["hidden_layer_sizes"] = self._compute_hidden_layer_sizes(
|
|
478
|
+
n_inputs=input_dim,
|
|
479
|
+
n_outputs=self.num_classes_,
|
|
480
|
+
n_samples=len(self.X_train_),
|
|
481
|
+
n_hidden=int(self.best_params_["num_hidden_layers"]),
|
|
482
|
+
latent_dim=int(self.best_params_["latent_dim"]),
|
|
483
|
+
alpha=float(self.best_params_["layer_scaling_factor"]),
|
|
484
|
+
schedule=str(self.best_params_["layer_schedule"]),
|
|
485
|
+
min_size=max(16, 2 * int(self.best_params_["latent_dim"])),
|
|
357
486
|
)
|
|
358
|
-
if not self.is_haploid:
|
|
359
|
-
self.pos_weights_ = self._compute_pos_weights(self.X_train_)
|
|
360
|
-
else:
|
|
361
|
-
self.pos_weights_ = None
|
|
362
487
|
|
|
363
|
-
|
|
364
|
-
train_loader = self._get_data_loader(self.X_train_)
|
|
488
|
+
self.best_params_["model_params"] = model_params_final
|
|
365
489
|
|
|
366
|
-
#
|
|
367
|
-
model = self.build_model(self.Model, self.best_params_)
|
|
490
|
+
# Now build the model
|
|
491
|
+
model = self.build_model(self.Model, self.best_params_["model_params"])
|
|
368
492
|
model.apply(self.initialize_weights)
|
|
369
493
|
|
|
494
|
+
if self.verbose or self.debug:
|
|
495
|
+
self.logger.info("Using model hyperparameters:")
|
|
496
|
+
pm = PrettyMetrics(
|
|
497
|
+
self.best_params_, precision=3, title="Model Hyperparameters"
|
|
498
|
+
)
|
|
499
|
+
pm.render()
|
|
500
|
+
|
|
501
|
+
lr_final = float(self.best_params_["learning_rate"])
|
|
502
|
+
l1_final = float(self.best_params_["l1_penalty"])
|
|
503
|
+
|
|
370
504
|
loss, trained_model, history = self._train_and_validate_model(
|
|
371
505
|
model=model,
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
l1_penalty=self.l1_penalty,
|
|
375
|
-
return_history=True,
|
|
376
|
-
class_weights=self.class_weights_,
|
|
377
|
-
X_val=self.X_val_,
|
|
506
|
+
lr=lr_final,
|
|
507
|
+
l1_penalty=l1_final,
|
|
378
508
|
params=self.best_params_,
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
eval_latent_steps=0,
|
|
384
|
-
eval_latent_lr=0.0,
|
|
385
|
-
eval_latent_weight_decay=0.0,
|
|
509
|
+
trial=None,
|
|
510
|
+
class_weights=self.class_weights_,
|
|
511
|
+
kl_beta_schedule=self.best_params_["kl_beta_schedule"],
|
|
512
|
+
gamma_schedule=self.best_params_["gamma_schedule"],
|
|
386
513
|
)
|
|
387
514
|
|
|
388
515
|
if trained_model is None:
|
|
389
|
-
msg = "
|
|
516
|
+
msg = f"{self.model_name} training failed."
|
|
390
517
|
self.logger.error(msg)
|
|
391
518
|
raise RuntimeError(msg)
|
|
392
519
|
|
|
@@ -395,215 +522,209 @@ class ImputeVAE(BaseNNImputer):
|
|
|
395
522
|
self.models_dir / f"final_model_{self.model_name}.pt",
|
|
396
523
|
)
|
|
397
524
|
|
|
398
|
-
|
|
399
|
-
|
|
525
|
+
if history is None:
|
|
526
|
+
hist = {"Train": []}
|
|
527
|
+
elif isinstance(history, dict):
|
|
528
|
+
hist = dict(history)
|
|
529
|
+
else:
|
|
530
|
+
hist = {"Train": list(history["Train"]), "Val": list(history["Val"])}
|
|
531
|
+
|
|
532
|
+
self.best_loss_ = loss
|
|
533
|
+
self.model_ = trained_model
|
|
534
|
+
self.history_ = hist
|
|
400
535
|
self.is_fit_ = True
|
|
401
536
|
|
|
402
|
-
# Evaluate (AE-parity reporting)
|
|
403
|
-
eval_mask = (
|
|
404
|
-
self.sim_mask_test_
|
|
405
|
-
if (self.simulate_missing and self.sim_mask_test_ is not None)
|
|
406
|
-
else None
|
|
407
|
-
)
|
|
408
537
|
self._evaluate_model(
|
|
409
|
-
self.X_val_,
|
|
410
538
|
self.model_,
|
|
411
|
-
self.
|
|
412
|
-
|
|
539
|
+
X=self.X_test_,
|
|
540
|
+
y=self.y_test_,
|
|
541
|
+
eval_mask=self.sim_mask_test_ & ~self.orig_mask_test_,
|
|
542
|
+
objective_mode=False,
|
|
413
543
|
)
|
|
414
544
|
|
|
415
|
-
self.
|
|
545
|
+
if self.show_plots:
|
|
546
|
+
self.plotter_.plot_history(self.history_)
|
|
547
|
+
|
|
416
548
|
self._save_best_params(self.best_params_)
|
|
549
|
+
|
|
550
|
+
if self.model_tuned_:
|
|
551
|
+
title = f"{self.model_name} Optimized Parameters"
|
|
552
|
+
|
|
553
|
+
if self.verbose or self.debug:
|
|
554
|
+
pm = PrettyMetrics(self.best_params_, precision=2, title=title)
|
|
555
|
+
pm.render()
|
|
556
|
+
|
|
557
|
+
# Save best parameters to a JSON file.
|
|
558
|
+
self._save_best_params(self.best_params_, objective_mode=True)
|
|
417
559
|
return self
|
|
418
560
|
|
|
419
561
|
def transform(self) -> np.ndarray:
|
|
420
562
|
"""Impute missing genotypes and return IUPAC strings.
|
|
421
563
|
|
|
422
|
-
This method
|
|
564
|
+
This method performs the following steps:
|
|
565
|
+
1. Validates that the model has been fitted.
|
|
566
|
+
2. Uses the trained model to predict missing genotypes for the entire dataset.
|
|
567
|
+
3. Fills in the missing genotypes in the original dataset with the predicted values from the model.
|
|
568
|
+
4. Decodes the imputed genotype matrix from 0/1/2 encoding to IUPAC strings.
|
|
569
|
+
5. Checks for any remaining missing values or decoding issues, raising errors if found.
|
|
570
|
+
6. Optionally generates and displays plots comparing the original and imputed genotype distributions.
|
|
571
|
+
7. Returns the imputed IUPAC genotype matrix.
|
|
423
572
|
|
|
424
573
|
Returns:
|
|
425
|
-
np.ndarray: IUPAC
|
|
574
|
+
np.ndarray: IUPAC genotype matrix of shape (n_samples, n_loci).
|
|
426
575
|
|
|
427
|
-
|
|
428
|
-
|
|
576
|
+
Notes:
|
|
577
|
+
- ``transform()`` does not take any arguments; it operates on the data provided during initialization.
|
|
578
|
+
- Ensure that the model has been fitted before calling this method.
|
|
579
|
+
- For haploid data, genotypes encoded as '1' are treated as '2' during decoding.
|
|
580
|
+
- The method checks for decoding failures (i.e., resulting in 'N') and raises an error if any are found.
|
|
429
581
|
"""
|
|
430
582
|
if not getattr(self, "is_fit_", False):
|
|
431
|
-
|
|
583
|
+
msg = "Model is not fitted. Call fit() before transform()."
|
|
584
|
+
self.logger.error(msg)
|
|
585
|
+
raise NotFittedError(msg)
|
|
432
586
|
|
|
433
587
|
self.logger.info(f"Imputing entire dataset with {self.model_name} model...")
|
|
434
588
|
X_to_impute = self.ground_truth_.copy()
|
|
435
589
|
|
|
436
|
-
|
|
590
|
+
# 1. Predict labels (0/1/2) for the entire matrix
|
|
591
|
+
pred_labels, _ = self._predict(self.model_, X=X_to_impute)
|
|
437
592
|
|
|
438
|
-
# Fill
|
|
439
|
-
missing_mask = X_to_impute
|
|
593
|
+
# 2. Fill ONLY originally missing values
|
|
594
|
+
missing_mask = X_to_impute < 0
|
|
440
595
|
imputed_array = X_to_impute.copy()
|
|
441
596
|
imputed_array[missing_mask] = pred_labels[missing_mask]
|
|
442
597
|
|
|
443
|
-
#
|
|
444
|
-
|
|
445
|
-
|
|
598
|
+
# Sanity check: all -1s should be gone
|
|
599
|
+
if np.any(imputed_array < 0):
|
|
600
|
+
msg = f"[{self.model_name}] Some missing genotypes remain after imputation. This is unexpected."
|
|
601
|
+
self.logger.error(msg)
|
|
602
|
+
raise RuntimeError(msg)
|
|
446
603
|
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
self.
|
|
604
|
+
# 3. Handle Haploid mapping (2->1) before decoding if needed
|
|
605
|
+
decode_input = imputed_array
|
|
606
|
+
if self.is_haploid_:
|
|
607
|
+
decode_input = imputed_array.copy()
|
|
608
|
+
decode_input[decode_input == 1] = 2
|
|
450
609
|
|
|
451
|
-
|
|
610
|
+
# 4. Decode integers to IUPAC strings
|
|
611
|
+
imputed_genotypes = self.decode_012(decode_input)
|
|
452
612
|
|
|
453
|
-
|
|
613
|
+
# 5. Check for decoding failures (N)
|
|
614
|
+
# Hard error: downstream pipelines expect fully imputed data.
|
|
615
|
+
bad_loci = np.where((imputed_genotypes == "N").any(axis=0))[0]
|
|
616
|
+
if bad_loci.size > 0:
|
|
617
|
+
msg = f"[{self.model_name}] {bad_loci.size} loci contain 'N' after imputation (e.g., first 10 indices: {bad_loci[:10].tolist()}). This occurs when REF/ALT metadata is missing and cannot be inferred from the source data (e.g., loci with 100 percent missing genotypes). Try filtering out these loci before imputation."
|
|
618
|
+
self.logger.error(msg)
|
|
619
|
+
self.logger.debug(
|
|
620
|
+
"All loci with 'N': " + ", ".join(map(str, bad_loci.tolist()))
|
|
621
|
+
)
|
|
622
|
+
raise RuntimeError(msg)
|
|
454
623
|
|
|
455
|
-
|
|
456
|
-
|
|
624
|
+
if self.show_plots:
|
|
625
|
+
original_input = X_to_impute
|
|
626
|
+
if self.is_haploid_:
|
|
627
|
+
original_input = X_to_impute.copy()
|
|
628
|
+
original_input[original_input == 1] = 2
|
|
457
629
|
|
|
458
|
-
|
|
630
|
+
original_genotypes = self.decode_012(original_input)
|
|
459
631
|
|
|
460
|
-
|
|
461
|
-
|
|
632
|
+
plt.rcParams.update(self.plotter_.param_dict)
|
|
633
|
+
self.plotter_.plot_gt_distribution(original_genotypes, is_imputed=False)
|
|
634
|
+
self.plotter_.plot_gt_distribution(imputed_genotypes, is_imputed=True)
|
|
462
635
|
|
|
463
|
-
|
|
464
|
-
torch.utils.data.DataLoader: Shuffled DataLoader.
|
|
465
|
-
"""
|
|
466
|
-
y_tensor = torch.from_numpy(y).long()
|
|
467
|
-
indices = torch.arange(len(y), dtype=torch.long)
|
|
468
|
-
dataset = torch.utils.data.TensorDataset(indices, y_tensor)
|
|
469
|
-
pin_memory = self.device.type == "cuda"
|
|
470
|
-
return torch.utils.data.DataLoader(
|
|
471
|
-
dataset,
|
|
472
|
-
batch_size=self.batch_size,
|
|
473
|
-
shuffle=True,
|
|
474
|
-
pin_memory=pin_memory,
|
|
475
|
-
)
|
|
636
|
+
return imputed_genotypes
|
|
476
637
|
|
|
477
638
|
def _train_and_validate_model(
|
|
478
639
|
self,
|
|
479
640
|
model: torch.nn.Module,
|
|
480
|
-
|
|
641
|
+
*,
|
|
481
642
|
lr: float,
|
|
482
643
|
l1_penalty: float,
|
|
483
|
-
trial: optuna.Trial
|
|
484
|
-
|
|
485
|
-
class_weights: torch.Tensor
|
|
486
|
-
|
|
487
|
-
|
|
488
|
-
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
eval_requires_latents: bool = False, # VAE: no latent eval refinement
|
|
493
|
-
eval_latent_steps: int = 0,
|
|
494
|
-
eval_latent_lr: float = 0.0,
|
|
495
|
-
eval_latent_weight_decay: float = 0.0,
|
|
496
|
-
) -> Tuple[float, torch.nn.Module | None, list | None]:
|
|
497
|
-
"""Wrap the VAE training loop with β-anneal & Optuna pruning.
|
|
498
|
-
|
|
499
|
-
This method orchestrates the training of the VAE model, including setting up the optimizer and learning rate scheduler, and executing the training loop with support for early stopping and Optuna pruning. It manages the training process, monitors performance on a validation set if provided, and returns the best model and training history.
|
|
644
|
+
trial: Optional[optuna.Trial] = None,
|
|
645
|
+
params: Optional[dict[str, Any]] = None,
|
|
646
|
+
class_weights: Optional[torch.Tensor] = None,
|
|
647
|
+
kl_beta_schedule: bool = False,
|
|
648
|
+
gamma_schedule: bool = False,
|
|
649
|
+
) -> tuple[float, torch.nn.Module, dict[str, list[float]]]:
|
|
650
|
+
"""Train and validate the model.
|
|
651
|
+
|
|
652
|
+
This method orchestrates training with early stopping and optional Optuna pruning based on validation performance. It returns the best validation loss, the best model (with best weights loaded), and training history.
|
|
500
653
|
|
|
501
654
|
Args:
|
|
502
655
|
model (torch.nn.Module): VAE model.
|
|
503
|
-
loader (torch.utils.data.DataLoader): Training data loader.
|
|
504
656
|
lr (float): Learning rate.
|
|
505
657
|
l1_penalty (float): L1 regularization coefficient.
|
|
506
|
-
trial (optuna.Trial
|
|
507
|
-
|
|
508
|
-
class_weights (torch.Tensor
|
|
509
|
-
|
|
510
|
-
|
|
511
|
-
prune_metric (str | None): Metric for pruning decisions.
|
|
512
|
-
prune_warmup_epochs (int): Epochs to skip before pruning.
|
|
513
|
-
eval_interval (int): Epochs between validation evaluations.
|
|
514
|
-
eval_requires_latents (bool): If True, refine latents during eval.
|
|
515
|
-
eval_latent_steps (int): Latent refinement steps if needed.
|
|
516
|
-
eval_latent_lr (float): Latent refinement learning rate.
|
|
517
|
-
eval_latent_weight_decay (float): Latent refinement L2 penalty.
|
|
658
|
+
trial (Optional[optuna.Trial]): Optuna trial for pruning (optional).
|
|
659
|
+
params (Optional[dict[str, float | int | str | dict[str, Any]]]): Model params for evaluation.
|
|
660
|
+
class_weights (Optional[torch.Tensor]): Class weights for loss computation.
|
|
661
|
+
kl_beta_schedule (bool): Whether to use KL beta scheduling.
|
|
662
|
+
gamma_schedule (bool): Whether to use gamma scheduling for focal CE loss.
|
|
518
663
|
|
|
519
664
|
Returns:
|
|
520
|
-
|
|
665
|
+
tuple[float, torch.nn.Module, dict[str, list[float]]]:
|
|
666
|
+
Best validation loss, best model, and training history.
|
|
521
667
|
"""
|
|
522
|
-
|
|
523
|
-
|
|
524
|
-
self.logger.error(msg)
|
|
525
|
-
raise TypeError(msg)
|
|
668
|
+
max_epochs = self.epochs
|
|
669
|
+
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
|
|
526
670
|
|
|
527
|
-
|
|
528
|
-
|
|
671
|
+
scheduler = _make_warmup_cosine_scheduler(
|
|
672
|
+
optimizer, max_epochs=max_epochs, warmup_epochs=int(0.1 * max_epochs)
|
|
529
673
|
)
|
|
530
674
|
|
|
531
|
-
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
|
|
532
|
-
scheduler = CosineAnnealingLR(optimizer, T_max=max_epochs)
|
|
533
|
-
|
|
534
675
|
best_loss, best_model, hist = self._execute_training_loop(
|
|
535
|
-
loader=loader,
|
|
536
676
|
optimizer=optimizer,
|
|
537
677
|
scheduler=scheduler,
|
|
538
678
|
model=model,
|
|
539
679
|
l1_penalty=l1_penalty,
|
|
540
680
|
trial=trial,
|
|
541
|
-
return_history=return_history,
|
|
542
|
-
class_weights=class_weights,
|
|
543
|
-
X_val=X_val,
|
|
544
681
|
params=params,
|
|
545
|
-
|
|
546
|
-
|
|
547
|
-
|
|
548
|
-
eval_requires_latents=eval_requires_latents,
|
|
549
|
-
eval_latent_steps=eval_latent_steps,
|
|
550
|
-
eval_latent_lr=eval_latent_lr,
|
|
551
|
-
eval_latent_weight_decay=eval_latent_weight_decay,
|
|
682
|
+
class_weights=class_weights,
|
|
683
|
+
kl_beta_schedule=kl_beta_schedule,
|
|
684
|
+
gamma_schedule=gamma_schedule,
|
|
552
685
|
)
|
|
553
|
-
|
|
554
|
-
return best_loss, best_model, hist
|
|
555
|
-
|
|
556
|
-
return best_loss, best_model, None
|
|
686
|
+
return best_loss, best_model, hist
|
|
557
687
|
|
|
558
688
|
def _execute_training_loop(
|
|
559
689
|
self,
|
|
560
|
-
|
|
690
|
+
*,
|
|
561
691
|
optimizer: torch.optim.Optimizer,
|
|
562
|
-
scheduler:
|
|
692
|
+
scheduler: (
|
|
693
|
+
torch.optim.lr_scheduler.CosineAnnealingLR
|
|
694
|
+
| torch.optim.lr_scheduler.SequentialLR
|
|
695
|
+
),
|
|
563
696
|
model: torch.nn.Module,
|
|
564
697
|
l1_penalty: float,
|
|
565
|
-
trial: optuna.Trial
|
|
566
|
-
|
|
567
|
-
class_weights: torch.Tensor,
|
|
568
|
-
|
|
569
|
-
|
|
570
|
-
|
|
571
|
-
|
|
572
|
-
|
|
573
|
-
|
|
574
|
-
eval_requires_latents: bool = False,
|
|
575
|
-
eval_latent_steps: int = 0,
|
|
576
|
-
eval_latent_lr: float = 0.0,
|
|
577
|
-
eval_latent_weight_decay: float = 0.0,
|
|
578
|
-
) -> Tuple[float, torch.nn.Module, list]:
|
|
579
|
-
"""Train VAE with stable focal CE + KL(β) anneal and numeric guards.
|
|
580
|
-
|
|
581
|
-
This method implements the core training loop for the VAE model, incorporating a focal cross-entropy loss for reconstruction and a KL divergence term with an annealed weight (beta). It includes mechanisms for early stopping based on validation performance, learning rate scheduling, and optional pruning of unpromising trials when using Optuna for hyperparameter optimization. The method ensures numerical stability throughout the training process.
|
|
698
|
+
trial: Optional[optuna.Trial] = None,
|
|
699
|
+
params: Optional[dict[str, Any]] = None,
|
|
700
|
+
class_weights: Optional[torch.Tensor] = None,
|
|
701
|
+
kl_beta_schedule: bool = False,
|
|
702
|
+
gamma_schedule: bool = False,
|
|
703
|
+
) -> tuple[float, torch.nn.Module, dict[str, list[float]]]:
|
|
704
|
+
"""Train the model with focal CE reconstruction + KL divergence.
|
|
705
|
+
|
|
706
|
+
This method performs the training loop for the model using the provided optimizer and learning rate scheduler. It supports early stopping based on validation loss and integrates with Optuna for hyperparameter tuning. The method returns the best validation loss, the best model state, and the training history.
|
|
582
707
|
|
|
583
708
|
Args:
|
|
584
|
-
loader (torch.utils.data.DataLoader): Training data loader.
|
|
585
709
|
optimizer (torch.optim.Optimizer): Optimizer.
|
|
586
|
-
scheduler (torch.optim.lr_scheduler.
|
|
710
|
+
scheduler (torch.optim.lr_scheduler.CosineAnnealingLR | torch.optim.lr_scheduler.SequentialLR): Learning rate scheduler.
|
|
587
711
|
model (torch.nn.Module): VAE model.
|
|
588
712
|
l1_penalty (float): L1 regularization coefficient.
|
|
589
|
-
trial (optuna.Trial
|
|
590
|
-
|
|
591
|
-
class_weights (torch.Tensor):
|
|
592
|
-
|
|
593
|
-
|
|
594
|
-
prune_metric (str | None): Metric for pruning decisions.
|
|
595
|
-
prune_warmup_epochs (int): Epochs to skip before pruning.
|
|
596
|
-
eval_interval (int): Epochs between validation evaluations.
|
|
597
|
-
eval_requires_latents (bool): If True, refine latents during eval.
|
|
598
|
-
eval_latent_steps (int): Latent refinement steps if needed.
|
|
599
|
-
eval_latent_lr (float): Latent refinement learning rate.
|
|
600
|
-
eval_latent_weight_decay (float): Latent refinement L2 penalty.
|
|
713
|
+
trial (Optional[optuna.Trial]): Optuna trial for pruning (optional).
|
|
714
|
+
params (Optional[dict[str, Any]]): Model params for evaluation.
|
|
715
|
+
class_weights (Optional[torch.Tensor]): Class weights for loss computation.
|
|
716
|
+
kl_beta_schedule (bool): Whether to use KL beta scheduling.
|
|
717
|
+
gamma_schedule (bool): Whether to use gamma scheduling for focal CE loss.
|
|
601
718
|
|
|
602
719
|
Returns:
|
|
603
|
-
|
|
720
|
+
tuple[float, torch.nn.Module, dict[str, list[float]]]: Best validation loss, best model, training history.
|
|
721
|
+
|
|
722
|
+
Notes:
|
|
723
|
+
- Use CE with class weights during training/validation.
|
|
724
|
+
- Inference de-bias happens in _predict (separate).
|
|
725
|
+
- If `class_weights` is None, this will fall back to self.class_weights_ if present.
|
|
604
726
|
"""
|
|
605
|
-
|
|
606
|
-
history: list[float] = []
|
|
727
|
+
history: dict[str, list[float]] = defaultdict(list)
|
|
607
728
|
|
|
608
729
|
early_stopping = EarlyStopping(
|
|
609
730
|
patience=self.early_stop_gen,
|
|
@@ -613,204 +734,175 @@ class ImputeVAE(BaseNNImputer):
|
|
|
613
734
|
debug=self.debug,
|
|
614
735
|
)
|
|
615
736
|
|
|
616
|
-
#
|
|
617
|
-
|
|
618
|
-
|
|
619
|
-
|
|
620
|
-
|
|
621
|
-
|
|
622
|
-
|
|
623
|
-
|
|
624
|
-
|
|
625
|
-
|
|
626
|
-
|
|
627
|
-
|
|
628
|
-
|
|
629
|
-
|
|
630
|
-
gamma_val = gamma_val[0]
|
|
631
|
-
gamma_final = float(gamma_val)
|
|
632
|
-
|
|
633
|
-
gamma_warm, gamma_ramp = 50, 100
|
|
634
|
-
beta_warm, beta_ramp = int(self.kl_warmup), int(self.kl_ramp)
|
|
635
|
-
|
|
636
|
-
# Optional LR warmup
|
|
637
|
-
warmup_epochs = int(getattr(self, "lr_warmup_epochs", 5))
|
|
638
|
-
base_lr = float(optimizer.param_groups[0]["lr"])
|
|
639
|
-
min_lr = base_lr * 0.1
|
|
640
|
-
|
|
641
|
-
max_epochs = int(getattr(scheduler, "T_max", getattr(self, "epochs", 100)))
|
|
642
|
-
|
|
643
|
-
for epoch in range(max_epochs):
|
|
644
|
-
# focal γ schedule
|
|
645
|
-
if epoch < gamma_warm:
|
|
646
|
-
model.gamma = 0.0 # type: ignore[attr-defined]
|
|
647
|
-
elif epoch < gamma_warm + gamma_ramp:
|
|
648
|
-
model.gamma = gamma_final * ((epoch - gamma_warm) / gamma_ramp) # type: ignore[attr-defined]
|
|
649
|
-
else:
|
|
650
|
-
model.gamma = gamma_final # type: ignore[attr-defined]
|
|
737
|
+
# KL schedule
|
|
738
|
+
kl_beta_target, kl_warm, kl_ramp = self._anneal_config(
|
|
739
|
+
params, "kl_beta", default=self.kl_beta, max_epochs=self.epochs
|
|
740
|
+
)
|
|
741
|
+
|
|
742
|
+
kl_beta_target = float(kl_beta_target)
|
|
743
|
+
|
|
744
|
+
gamma_target, gamma_warm, gamma_ramp = self._anneal_config(
|
|
745
|
+
params, "gamma", default=self.gamma, max_epochs=self.epochs
|
|
746
|
+
)
|
|
747
|
+
|
|
748
|
+
cw = class_weights
|
|
749
|
+
if cw is not None and cw.device != self.device:
|
|
750
|
+
cw = cw.to(self.device)
|
|
651
751
|
|
|
652
|
-
|
|
653
|
-
|
|
654
|
-
|
|
655
|
-
|
|
656
|
-
|
|
752
|
+
ce_criterion = FocalCELoss(
|
|
753
|
+
alpha=cw, gamma=gamma_target, reduction="mean", ignore_index=-1
|
|
754
|
+
)
|
|
755
|
+
|
|
756
|
+
for epoch in range(int(self.epochs)):
|
|
757
|
+
if kl_beta_schedule:
|
|
758
|
+
kl_beta_current = self._update_anneal_schedule(
|
|
759
|
+
kl_beta_target,
|
|
760
|
+
warm=kl_warm,
|
|
761
|
+
ramp=kl_ramp,
|
|
762
|
+
epoch=epoch,
|
|
763
|
+
init_val=0.0,
|
|
764
|
+
)
|
|
657
765
|
else:
|
|
658
|
-
|
|
659
|
-
|
|
660
|
-
if
|
|
661
|
-
|
|
662
|
-
|
|
663
|
-
|
|
766
|
+
kl_beta_current = kl_beta_target
|
|
767
|
+
|
|
768
|
+
if gamma_schedule:
|
|
769
|
+
gamma_current = self._update_anneal_schedule(
|
|
770
|
+
gamma_target,
|
|
771
|
+
warm=gamma_warm,
|
|
772
|
+
ramp=gamma_ramp,
|
|
773
|
+
epoch=epoch,
|
|
774
|
+
init_val=0.0,
|
|
775
|
+
)
|
|
776
|
+
ce_criterion.gamma = gamma_current
|
|
664
777
|
|
|
665
778
|
train_loss = self._train_step(
|
|
666
|
-
loader=
|
|
779
|
+
loader=self.train_loader_,
|
|
667
780
|
optimizer=optimizer,
|
|
668
781
|
model=model,
|
|
782
|
+
ce_criterion=ce_criterion,
|
|
669
783
|
l1_penalty=l1_penalty,
|
|
670
|
-
|
|
784
|
+
kl_beta=kl_beta_current,
|
|
671
785
|
)
|
|
672
786
|
|
|
673
787
|
if not np.isfinite(train_loss):
|
|
674
|
-
if trial:
|
|
675
|
-
|
|
676
|
-
|
|
677
|
-
|
|
678
|
-
|
|
679
|
-
|
|
788
|
+
if trial is not None:
|
|
789
|
+
msg = f"[{self.model_name}] Trial {trial.number} training loss non-finite."
|
|
790
|
+
self.logger.warning(msg)
|
|
791
|
+
raise optuna.exceptions.TrialPruned(msg)
|
|
792
|
+
|
|
793
|
+
msg = f"[{self.model_name}] Training loss is non-finite at epoch {epoch + 1}."
|
|
794
|
+
self.logger.error(msg)
|
|
795
|
+
raise RuntimeError(msg)
|
|
680
796
|
|
|
681
|
-
|
|
682
|
-
|
|
797
|
+
val_loss = self._val_step(
|
|
798
|
+
loader=self.val_loader_,
|
|
799
|
+
model=model,
|
|
800
|
+
ce_criterion=ce_criterion,
|
|
801
|
+
l1_penalty=l1_penalty,
|
|
802
|
+
kl_beta=kl_beta_current,
|
|
803
|
+
)
|
|
683
804
|
|
|
684
|
-
|
|
685
|
-
|
|
805
|
+
scheduler.step()
|
|
806
|
+
history["Train"].append(float(train_loss))
|
|
807
|
+
history["Val"].append(float(val_loss))
|
|
686
808
|
|
|
687
|
-
early_stopping(
|
|
809
|
+
early_stopping(val_loss, model)
|
|
688
810
|
if early_stopping.early_stop:
|
|
689
|
-
self.logger.
|
|
811
|
+
self.logger.debug(
|
|
812
|
+
f"[{self.model_name}] Early stopping at epoch {epoch + 1}."
|
|
813
|
+
)
|
|
690
814
|
break
|
|
691
815
|
|
|
692
|
-
|
|
693
|
-
|
|
694
|
-
trial is not None
|
|
695
|
-
and X_val is not None
|
|
696
|
-
and ((epoch + 1) % eval_interval == 0)
|
|
697
|
-
):
|
|
698
|
-
metric_key = prune_metric or getattr(self, "tune_metric", "f1")
|
|
699
|
-
mask_override = None
|
|
700
|
-
if (
|
|
701
|
-
self.simulate_missing
|
|
702
|
-
and getattr(self, "sim_mask_test_", None) is not None
|
|
703
|
-
and getattr(self, "X_val_", None) is not None
|
|
704
|
-
and X_val.shape == self.X_val_.shape
|
|
705
|
-
):
|
|
706
|
-
mask_override = self.sim_mask_test_
|
|
707
|
-
metric_val = self._eval_for_pruning(
|
|
816
|
+
if trial is not None:
|
|
817
|
+
metric_vals = self._evaluate_model(
|
|
708
818
|
model=model,
|
|
709
|
-
|
|
710
|
-
|
|
711
|
-
|
|
819
|
+
X=self.X_val_corrupted_,
|
|
820
|
+
y=self.y_val_,
|
|
821
|
+
eval_mask=self.sim_mask_val_ & ~self.orig_mask_val_,
|
|
712
822
|
objective_mode=True,
|
|
713
|
-
do_latent_infer=False, # VAE uses encoder; no latent refine
|
|
714
|
-
latent_steps=0,
|
|
715
|
-
latent_lr=0.0,
|
|
716
|
-
latent_weight_decay=0.0,
|
|
717
|
-
latent_seed=self.seed, # type: ignore
|
|
718
|
-
_latent_cache=None,
|
|
719
|
-
_latent_cache_key=None,
|
|
720
|
-
eval_mask_override=mask_override,
|
|
721
823
|
)
|
|
722
|
-
trial.report(
|
|
723
|
-
if
|
|
824
|
+
trial.report(metric_vals[self.tune_metric], step=epoch + 1)
|
|
825
|
+
if trial.should_prune():
|
|
724
826
|
raise optuna.exceptions.TrialPruned(
|
|
725
|
-
f"
|
|
827
|
+
f"[{self.model_name}] Trial {trial.number} pruned at epoch {epoch + 1}."
|
|
726
828
|
)
|
|
727
829
|
|
|
728
|
-
best_loss = early_stopping.best_score
|
|
729
|
-
|
|
730
|
-
if
|
|
731
|
-
|
|
732
|
-
|
|
830
|
+
best_loss = float(early_stopping.best_score)
|
|
831
|
+
|
|
832
|
+
if early_stopping.best_state_dict is not None:
|
|
833
|
+
model.load_state_dict(early_stopping.best_state_dict)
|
|
834
|
+
|
|
835
|
+
return best_loss, model, dict(history)
|
|
733
836
|
|
|
734
837
|
def _train_step(
|
|
735
838
|
self,
|
|
736
839
|
loader: torch.utils.data.DataLoader,
|
|
737
840
|
optimizer: torch.optim.Optimizer,
|
|
738
841
|
model: torch.nn.Module,
|
|
842
|
+
ce_criterion: torch.nn.Module,
|
|
843
|
+
*,
|
|
739
844
|
l1_penalty: float,
|
|
740
|
-
|
|
845
|
+
kl_beta: torch.Tensor | float,
|
|
741
846
|
) -> float:
|
|
742
|
-
"""
|
|
743
|
-
|
|
744
|
-
This method performs a single training epoch for the VAE model. It processes batches of data, computes the reconstruction and KL divergence losses, applies L1 regularization if specified, and updates the model parameters. The method includes safeguards against non-finite values in the model outputs and gradients to ensure stable training.
|
|
847
|
+
"""Single epoch train step across batches (focal CE + KL + optional L1).
|
|
745
848
|
|
|
746
849
|
Args:
|
|
747
850
|
loader (torch.utils.data.DataLoader): Training data loader.
|
|
748
851
|
optimizer (torch.optim.Optimizer): Optimizer.
|
|
749
852
|
model (torch.nn.Module): VAE model.
|
|
853
|
+
ce_criterion (torch.nn.Module): Cross-entropy loss function.
|
|
750
854
|
l1_penalty (float): L1 regularization coefficient.
|
|
751
|
-
|
|
855
|
+
kl_beta (torch.Tensor | float): KL divergence weight.
|
|
752
856
|
|
|
753
857
|
Returns:
|
|
754
|
-
float: Average training loss
|
|
858
|
+
float: Average training loss.
|
|
859
|
+
|
|
755
860
|
"""
|
|
756
861
|
model.train()
|
|
757
|
-
running
|
|
862
|
+
running = 0.0
|
|
863
|
+
num_batches = 0
|
|
864
|
+
|
|
865
|
+
nF_model = self.num_features_
|
|
866
|
+
nC_model = self.num_classes_
|
|
758
867
|
l1_params = tuple(p for p in model.parameters() if p.requires_grad)
|
|
759
|
-
if class_weights is not None and class_weights.device != self.device:
|
|
760
|
-
class_weights = class_weights.to(self.device)
|
|
761
868
|
|
|
762
|
-
for
|
|
869
|
+
for X_batch, y_batch, m_batch in loader:
|
|
763
870
|
optimizer.zero_grad(set_to_none=True)
|
|
871
|
+
X_batch = X_batch.to(self.device, non_blocking=True).float()
|
|
872
|
+
y_batch = y_batch.to(self.device, non_blocking=True).long()
|
|
873
|
+
m_batch = m_batch.to(self.device, non_blocking=True).bool()
|
|
764
874
|
|
|
765
|
-
|
|
875
|
+
raw = model(X_batch)
|
|
876
|
+
logits0 = raw[0]
|
|
766
877
|
|
|
767
|
-
|
|
768
|
-
|
|
769
|
-
|
|
770
|
-
|
|
878
|
+
expected = X_batch.shape[0] * nF_model * nC_model
|
|
879
|
+
if logits0.numel() != expected:
|
|
880
|
+
msg = f"VAE logits size mismatch: got {logits0.numel()}, expected {expected}"
|
|
881
|
+
self.logger.error(msg)
|
|
882
|
+
raise ValueError(msg)
|
|
771
883
|
|
|
772
|
-
|
|
773
|
-
|
|
774
|
-
recon_logits, mu, logvar = out[0], out[1], out[2]
|
|
775
|
-
else:
|
|
776
|
-
recon_logits, mu, logvar = out["recon_logits"], out["mu"], out["logvar"]
|
|
777
|
-
|
|
778
|
-
# Upstream guard
|
|
779
|
-
if (
|
|
780
|
-
not torch.isfinite(recon_logits).all()
|
|
781
|
-
or not torch.isfinite(mu).all()
|
|
782
|
-
or not torch.isfinite(logvar).all()
|
|
783
|
-
):
|
|
784
|
-
continue
|
|
884
|
+
logits_masked = logits0.view(-1, nC_model)
|
|
885
|
+
logits_masked = logits_masked[m_batch.view(-1)]
|
|
785
886
|
|
|
786
|
-
|
|
787
|
-
|
|
788
|
-
gamma = max(0.0, min(gamma, 10.0))
|
|
887
|
+
targets_masked = y_batch.view(-1)
|
|
888
|
+
targets_masked = targets_masked[m_batch.view(-1)]
|
|
789
889
|
|
|
790
|
-
|
|
791
|
-
|
|
792
|
-
|
|
793
|
-
|
|
794
|
-
|
|
795
|
-
|
|
796
|
-
|
|
797
|
-
|
|
798
|
-
|
|
799
|
-
|
|
800
|
-
|
|
801
|
-
|
|
802
|
-
|
|
803
|
-
|
|
804
|
-
|
|
805
|
-
|
|
806
|
-
mask = (y_int != -1).unsqueeze(-1).float()
|
|
807
|
-
recon_loss = (bce * mask).sum() / mask.sum().clamp_min(1e-8)
|
|
808
|
-
kl = (
|
|
809
|
-
-0.5
|
|
810
|
-
* torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
|
|
811
|
-
/ (y_int.shape[0] + 1e-8)
|
|
812
|
-
)
|
|
813
|
-
loss = recon_loss + beta * kl
|
|
890
|
+
mask_flat = m_batch.view(-1)
|
|
891
|
+
if not bool(mask_flat.any()):
|
|
892
|
+
continue
|
|
893
|
+
|
|
894
|
+
# average number of masked loci per sample (scalar)
|
|
895
|
+
recon_scale = (mask_flat.sum().float() / float(X_batch.shape[0])).detach()
|
|
896
|
+
|
|
897
|
+
loss = compute_vae_loss(
|
|
898
|
+
ce_criterion,
|
|
899
|
+
logits_masked,
|
|
900
|
+
targets_masked,
|
|
901
|
+
mu=raw[1],
|
|
902
|
+
logvar=raw[2],
|
|
903
|
+
kl_beta=kl_beta,
|
|
904
|
+
recon_scale=recon_scale,
|
|
905
|
+
)
|
|
814
906
|
|
|
815
907
|
if l1_penalty > 0:
|
|
816
908
|
l1 = torch.zeros((), device=self.device)
|
|
@@ -818,185 +910,279 @@ class ImputeVAE(BaseNNImputer):
|
|
|
818
910
|
l1 = l1 + p.abs().sum()
|
|
819
911
|
loss = loss + l1_penalty * l1
|
|
820
912
|
|
|
821
|
-
if not torch.isfinite(loss):
|
|
822
|
-
continue
|
|
823
|
-
|
|
824
913
|
loss.backward()
|
|
825
914
|
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
|
|
826
|
-
|
|
827
|
-
# skip update if any grad is non-finite
|
|
828
|
-
bad = any(
|
|
829
|
-
p.grad is not None and not torch.isfinite(p.grad).all()
|
|
830
|
-
for p in model.parameters()
|
|
831
|
-
)
|
|
832
|
-
if bad:
|
|
833
|
-
optimizer.zero_grad(set_to_none=True)
|
|
834
|
-
continue
|
|
835
|
-
|
|
836
915
|
optimizer.step()
|
|
837
916
|
|
|
838
917
|
running += float(loss.detach().item())
|
|
839
|
-
|
|
918
|
+
num_batches += 1
|
|
919
|
+
|
|
920
|
+
return float("inf") if num_batches == 0 else running / num_batches
|
|
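The new _train_step scores only the genotypes flagged by the batch mask: logits of shape (B, L, K) and integer targets of shape (B, L) are flattened and filtered with the boolean mask before the loss is computed, and recon_scale is the average number of masked loci per sample. A small sketch with illustrative shapes (not the package's data pipeline), assuming a plain cross-entropy in place of the focal criterion:

    import torch
    import torch.nn.functional as F

    B, L, K = 2, 5, 3                      # batch, loci, genotype classes (illustrative)
    logits = torch.randn(B, L, K)          # model output reshaped to (B, L, K)
    targets = torch.randint(0, K, (B, L))  # 0/1/2 genotype labels
    mask = torch.zeros(B, L, dtype=torch.bool)
    mask[0, 1] = True                      # pretend these two genotypes were simulated missing
    mask[1, 3] = True

    flat_mask = mask.view(-1)
    logits_masked = logits.view(-1, K)[flat_mask]   # (n_masked, K)
    targets_masked = targets.view(-1)[flat_mask]    # (n_masked,)

    # Average number of masked loci per sample; keeps the reconstruction term on a
    # scale comparable to the per-sample KL term.
    recon_scale = flat_mask.sum().float() / float(B)

    ce = F.cross_entropy(logits_masked, targets_masked)
    print(logits_masked.shape, targets_masked.shape, float(recon_scale), float(ce))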
921
|
+
|
|
922
|
+
def _val_step(
|
|
923
|
+
self,
|
|
924
|
+
loader: torch.utils.data.DataLoader,
|
|
925
|
+
model: torch.nn.Module,
|
|
926
|
+
ce_criterion: torch.nn.Module,
|
|
927
|
+
*,
|
|
928
|
+
l1_penalty: float,
|
|
929
|
+
kl_beta: torch.Tensor | float = 1.0,
|
|
930
|
+
) -> float:
|
|
931
|
+
"""Validation step for a single epoch (focal CE + KL + optional L1).
|
|
932
|
+
|
|
933
|
+
Args:
|
|
934
|
+
loader (torch.utils.data.DataLoader): Validation data loader.
|
|
935
|
+
model (torch.nn.Module): VAE model.
|
|
936
|
+
ce_criterion (torch.nn.Module): Cross-entropy loss function.
|
|
937
|
+
l1_penalty (float): L1 regularization coefficient.
|
|
938
|
+
kl_beta (torch.Tensor | float): KL divergence weight.
|
|
939
|
+
|
|
940
|
+
Returns:
|
|
941
|
+
float: Average validation loss.
|
|
942
|
+
"""
|
|
943
|
+
model.eval()
|
|
944
|
+
running = 0.0
|
|
945
|
+
num_batches = 0
|
|
946
|
+
|
|
947
|
+
nF_model = self.num_features_
|
|
948
|
+
nC_model = self.num_classes_
|
|
949
|
+
l1_params = tuple(p for p in model.parameters() if p.requires_grad)
|
|
950
|
+
|
|
951
|
+
with torch.no_grad():
|
|
952
|
+
for X_batch, y_batch, m_batch in loader:
|
|
953
|
+
X_batch = X_batch.to(self.device, non_blocking=True).float()
|
|
954
|
+
y_batch = y_batch.to(self.device, non_blocking=True).long()
|
|
955
|
+
m_batch = m_batch.to(self.device, non_blocking=True).bool()
|
|
956
|
+
|
|
957
|
+
raw = model(X_batch)
|
|
958
|
+
logits0 = raw[0]
|
|
959
|
+
|
|
960
|
+
expected = X_batch.shape[0] * nF_model * nC_model
|
|
961
|
+
if logits0.numel() != expected:
|
|
962
|
+
msg = f"VAE logits size mismatch: got {logits0.numel()}, expected {expected}"
|
|
963
|
+
self.logger.error(msg)
|
|
964
|
+
raise ValueError(msg)
|
|
965
|
+
|
|
966
|
+
logits_masked = logits0.view(-1, nC_model)
|
|
967
|
+
logits_masked = logits_masked[m_batch.view(-1)]
|
|
968
|
+
|
|
969
|
+
targets_masked = y_batch.view(-1)
|
|
970
|
+
targets_masked = targets_masked[m_batch.view(-1)]
|
|
971
|
+
|
|
972
|
+
mask_flat = m_batch.view(-1)
|
|
973
|
+
|
|
974
|
+
if not bool(mask_flat.any()):
|
|
975
|
+
continue
|
|
840
976
|
|
|
841
|
-
|
|
977
|
+
# average number of masked loci per sample (scalar)
|
|
978
|
+
recon_scale = (
|
|
979
|
+
mask_flat.sum().float() / float(X_batch.shape[0])
|
|
980
|
+
).detach()
|
|
981
|
+
|
|
982
|
+
loss = compute_vae_loss(
|
|
983
|
+
ce_criterion,
|
|
984
|
+
logits_masked,
|
|
985
|
+
targets_masked,
|
|
986
|
+
mu=raw[1],
|
|
987
|
+
logvar=raw[2],
|
|
988
|
+
kl_beta=kl_beta,
|
|
989
|
+
recon_scale=recon_scale,
|
|
990
|
+
)
|
|
991
|
+
|
|
992
|
+
if l1_penalty > 0:
|
|
993
|
+
l1 = torch.zeros((), device=self.device)
|
|
994
|
+
for p in l1_params:
|
|
995
|
+
l1 = l1 + p.abs().sum()
|
|
996
|
+
loss = loss + l1_penalty * l1
|
|
997
|
+
|
|
998
|
+
if not torch.isfinite(loss):
|
|
999
|
+
continue
|
|
1000
|
+
|
|
1001
|
+
running += float(loss.item())
|
|
1002
|
+
num_batches += 1
|
|
1003
|
+
|
|
1004
|
+
return float("inf") if num_batches == 0 else running / num_batches
|
|
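Both _train_step and _val_step delegate the loss to compute_vae_loss, which combines a class-weighted focal cross-entropy reconstruction term with a KL divergence term weighted by kl_beta. The function's internals are not part of this hunk; the sketch below is a simplified stand-in (plain cross-entropy instead of the focal criterion, and the package's exact scaling may differ) showing how the two terms are typically combined:

    import torch
    import torch.nn.functional as F

    def vae_loss_sketch(logits: torch.Tensor, targets: torch.Tensor, mu: torch.Tensor,
                        logvar: torch.Tensor, *, kl_beta: float = 1.0,
                        recon_scale: float = 1.0) -> torch.Tensor:
        """Simplified stand-in for compute_vae_loss: CE reconstruction + beta-weighted KL."""
        # Reconstruction over the masked genotypes only, rescaled to a per-sample magnitude.
        recon = F.cross_entropy(logits, targets) * recon_scale
        # KL(q(z|x) || N(0, I)), averaged over the batch of latent statistics.
        kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / mu.shape[0]
        return recon + kl_beta * kl

    logits = torch.randn(6, 3)                          # masked genotype logits (n_masked, K)
    targets = torch.randint(0, 3, (6,))
    mu, logvar = torch.randn(4, 8), torch.zeros(4, 8)   # latent stats for a batch of 4 samples
    print(float(vae_loss_sketch(logits, targets, mu, logvar, kl_beta=0.5, recon_scale=2.0)))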
842
1005
|
|
|
843
1006
|
def _predict(
|
|
844
1007
|
self,
|
|
845
1008
|
model: torch.nn.Module,
|
|
846
1009
|
X: np.ndarray | torch.Tensor,
|
|
1010
|
+
*,
|
|
847
1011
|
return_proba: bool = False,
|
|
848
|
-
) ->
|
|
849
|
-
"""Predict
|
|
1012
|
+
) -> tuple[np.ndarray, np.ndarray | None]:
|
|
1013
|
+
"""Predict categorical genotype labels from logits.
|
|
850
1014
|
|
|
851
|
-
This method uses the trained
|
|
1015
|
+
This method uses the trained model to predict genotype labels for the provided input data. It handles both 0/1/2 encoded matrices and one-hot encoded matrices, converting them as necessary for model input. The method returns the predicted labels and, optionally, the predicted probabilities.
|
|
852
1016
|
|
|
853
1017
|
Args:
|
|
854
1018
|
model (torch.nn.Module): Trained model.
|
|
855
|
-
X (np.ndarray | torch.Tensor): 0/1/2 matrix with -1 for missing.
|
|
856
|
-
return_proba (bool): If True,
|
|
1019
|
+
X (np.ndarray | torch.Tensor): 0/1/2 matrix with -1 for missing, or one-hot encoded (B, L, K).
|
|
1020
|
+
return_proba (bool): If True, return probabilities.
|
|
857
1021
|
|
|
858
1022
|
Returns:
|
|
859
|
-
|
|
1023
|
+
tuple[np.ndarray, np.ndarray | None]: (labels, probas|None).
|
|
860
1024
|
"""
|
|
861
1025
|
if model is None:
|
|
862
|
-
msg =
|
|
1026
|
+
msg = (
|
|
1027
|
+
"Model passed to predict() is not trained. "
|
|
1028
|
+
"Call fit() before predict()."
|
|
1029
|
+
)
|
|
863
1030
|
self.logger.error(msg)
|
|
864
1031
|
raise NotFittedError(msg)
|
|
865
1032
|
|
|
866
1033
|
model.eval()
|
|
1034
|
+
|
|
1035
|
+
nF = self.num_features_
|
|
1036
|
+
nC = self.num_classes_
|
|
1037
|
+
|
|
1038
|
+
if isinstance(X, torch.Tensor):
|
|
1039
|
+
X_tensor = X
|
|
1040
|
+
else:
|
|
1041
|
+
X_tensor = torch.from_numpy(X)
|
|
1042
|
+
X_tensor = X_tensor.float()
|
|
1043
|
+
|
|
1044
|
+
if X_tensor.device != self.device:
|
|
1045
|
+
X_tensor = X_tensor.to(self.device)
|
|
1046
|
+
|
|
1047
|
+
if X_tensor.dim() == 2:
|
|
1048
|
+
# 0/1/2 matrix -> one-hot for model input
|
|
1049
|
+
X_tensor = self._one_hot_encode_012(X_tensor, num_classes=nC)
|
|
1050
|
+
X_tensor = X_tensor.float()
|
|
1051
|
+
|
|
1052
|
+
if X_tensor.device != self.device:
|
|
1053
|
+
X_tensor = X_tensor.to(self.device)
|
|
1054
|
+
|
|
1055
|
+
elif X_tensor.dim() != 3:
|
|
1056
|
+
msg = f"_predict expects 2D 0/1/2 inputs or 3D one-hot inputs; got shape {tuple(X_tensor.shape)}."
|
|
1057
|
+
self.logger.error(msg)
|
|
1058
|
+
raise ValueError(msg)
|
|
1059
|
+
|
|
1060
|
+
if X_tensor.shape[1] != nF or X_tensor.shape[2] != nC:
|
|
1061
|
+
msg = f"_predict input shape mismatch: expected (B, {nF}, {nC}), got {tuple(X_tensor.shape)}."
|
|
1062
|
+
self.logger.error(msg)
|
|
1063
|
+
raise ValueError(msg)
|
|
1064
|
+
|
|
1065
|
+
X_tensor = X_tensor.reshape(X_tensor.shape[0], nF * nC)
|
|
1066
|
+
|
|
867
1067
|
with torch.no_grad():
|
|
868
|
-
|
|
869
|
-
|
|
870
|
-
|
|
871
|
-
|
|
872
|
-
outputs = model(x_ohe)
|
|
873
|
-
logits = outputs[0].view(-1, self.num_features_, self.output_classes_)
|
|
874
|
-
probas = torch.softmax(logits, dim=-1)
|
|
875
|
-
labels = torch.argmax(probas, dim=-1)
|
|
876
|
-
else:
|
|
877
|
-
x_in = self._encode_multilabel_inputs(X_tensor)
|
|
878
|
-
outputs = model(x_in)
|
|
879
|
-
logits = outputs[0].view(-1, self.num_features_, self.output_classes_)
|
|
880
|
-
probas2 = torch.sigmoid(logits)
|
|
881
|
-
p_ref = probas2[..., 0]
|
|
882
|
-
p_alt = probas2[..., 1]
|
|
883
|
-
p_het = p_ref * p_alt
|
|
884
|
-
p_ref_only = p_ref * (1 - p_alt)
|
|
885
|
-
p_alt_only = p_alt * (1 - p_ref)
|
|
886
|
-
probas = torch.stack([p_ref_only, p_het, p_alt_only], dim=-1)
|
|
887
|
-
probas = probas / probas.sum(dim=-1, keepdim=True).clamp_min(1e-8)
|
|
888
|
-
labels = torch.argmax(probas, dim=-1)
|
|
1068
|
+
raw = model(X_tensor)
|
|
1069
|
+
logits = raw[0].view(-1, nF, nC)
|
|
1070
|
+
probas = torch.softmax(logits, dim=-1)
|
|
1071
|
+
labels = torch.argmax(probas, dim=-1)
|
|
889
1072
|
|
|
890
1073
|
if return_proba:
|
|
891
1074
|
return labels.cpu().numpy(), probas.cpu().numpy()
|
|
892
|
-
|
|
893
|
-
return labels.cpu().numpy()
|
|
1075
|
+
return labels.cpu().numpy(), None
|
|
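The reworked _predict accepts either a 0/1/2 matrix with -1 for missing or an already one-hot (B, L, K) tensor, flattens it to (B, L*K) for the model, and decodes the logits with a softmax followed by argmax. A generic sketch of that encode/flatten/decode round trip; one_hot_012 here is a stand-in for illustration, not the package's _one_hot_encode_012 helper:

    import torch

    def one_hot_012(x: torch.Tensor, num_classes: int = 3) -> torch.Tensor:
        """One-hot encode 0/1/2 genotypes; rows for -1 (missing) become all zeros."""
        valid = x >= 0
        safe = x.clamp(min=0).long()
        ohe = torch.nn.functional.one_hot(safe, num_classes=num_classes).float()
        return ohe * valid.unsqueeze(-1).float()

    X = torch.tensor([[0, 2, -1], [1, -1, 0]])       # (B=2, L=3) with -1 = missing
    ohe = one_hot_012(X)                             # (2, 3, 3)
    flat = ohe.reshape(ohe.shape[0], -1)             # (2, 9) flattened model input

    logits = torch.randn(2, 3 * 3)                   # placeholder for model(flat)
    probas = torch.softmax(logits.view(-1, 3, 3), dim=-1)
    labels = torch.argmax(probas, dim=-1)            # (2, 3) imputed 0/1/2 calls
    print(flat.shape, labels.shape)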
894
1076
|
|
|
895
1077
|
def _evaluate_model(
|
|
896
1078
|
self,
|
|
897
|
-
X_val: np.ndarray,
|
|
898
1079
|
model: torch.nn.Module,
|
|
899
|
-
|
|
900
|
-
|
|
901
|
-
|
|
1080
|
+
X: np.ndarray | torch.Tensor,
|
|
1081
|
+
y: np.ndarray,
|
|
1082
|
+
eval_mask: np.ndarray,
|
|
902
1083
|
*,
|
|
903
|
-
|
|
1084
|
+
objective_mode: bool = False,
|
|
904
1085
|
) -> Dict[str, float]:
|
|
905
|
-
"""Evaluate
|
|
1086
|
+
"""Evaluate model performance on masked genotypes.
|
|
906
1087
|
|
|
907
|
-
This method evaluates the trained
|
|
1088
|
+
This method evaluates the performance of the trained model on a given dataset using a specified evaluation mask. It computes various classification metrics based on the predicted labels and probabilities, comparing them to the ground truth labels. The method returns a dictionary of evaluation metrics.
|
|
908
1089
|
|
|
909
1090
|
Args:
|
|
910
|
-
X_val (np.ndarray): Validation 0/1/2 matrix with -1 for missing.
|
|
911
1091
|
model (torch.nn.Module): Trained model.
|
|
912
|
-
|
|
913
|
-
|
|
914
|
-
|
|
915
|
-
|
|
1092
|
+
X (np.ndarray | torch.Tensor): 0/1/2 matrix with -1 for missing, or one-hot encoded (B, L, K).
|
|
1093
|
+
y (np.ndarray): Ground truth 0/1/2 matrix with -1 for missing.
|
|
1094
|
+
eval_mask (np.ndarray): Boolean mask indicating which genotypes to evaluate.
|
|
1095
|
+
objective_mode (bool): If True, suppresses verbose output.
|
|
916
1096
|
|
|
917
1097
|
Returns:
|
|
918
|
-
Dict[str, float]:
|
|
919
|
-
|
|
920
|
-
Raises:
|
|
921
|
-
NotFittedError: If called before fit().
|
|
1098
|
+
Dict[str, float]: Evaluation metrics.
|
|
922
1099
|
"""
|
|
923
|
-
|
|
924
|
-
|
|
925
|
-
|
|
1100
|
+
if model is None:
|
|
1101
|
+
msg = "Model passed to _evaluate_model() is not fitted. Call fit() before evaluation."
|
|
1102
|
+
self.logger.error(msg)
|
|
1103
|
+
raise NotFittedError(msg)
|
|
926
1104
|
|
|
927
|
-
|
|
928
|
-
|
|
929
|
-
|
|
930
|
-
|
|
931
|
-
|
|
932
|
-
|
|
933
|
-
|
|
934
|
-
)
|
|
935
|
-
|
|
936
|
-
|
|
937
|
-
hasattr(self, "X_train_")
|
|
938
|
-
and getattr(self, "X_train_", None) is not None
|
|
939
|
-
and X_val.shape[0] == self.X_train_.shape[0]
|
|
940
|
-
):
|
|
941
|
-
GT_ref = getattr(self, "GT_train_full_", self.ground_truth_)
|
|
942
|
-
else:
|
|
943
|
-
GT_ref = self.ground_truth_
|
|
1105
|
+
pred_labels, pred_probas = self._predict(model=model, X=X, return_proba=True)
|
|
1106
|
+
|
|
1107
|
+
if pred_probas is None:
|
|
1108
|
+
msg = "Predicted probabilities are None in _evaluate_model()."
|
|
1109
|
+
self.logger.error(msg)
|
|
1110
|
+
raise ValueError(msg)
|
|
1111
|
+
|
|
1112
|
+
y_true_flat = y[eval_mask].astype(np.int8, copy=False)
|
|
1113
|
+
y_pred_flat = pred_labels[eval_mask].astype(np.int8, copy=False)
|
|
1114
|
+
y_proba_flat = pred_probas[eval_mask].astype(np.float32, copy=False)
|
|
944
1115
|
|
|
945
|
-
|
|
946
|
-
|
|
947
|
-
|
|
948
|
-
|
|
1116
|
+
valid = y_true_flat >= 0
|
|
1117
|
+
y_true_flat = y_true_flat[valid]
|
|
1118
|
+
y_pred_flat = y_pred_flat[valid]
|
|
1119
|
+
y_proba_flat = y_proba_flat[valid]
|
|
949
1120
|
|
|
950
|
-
|
|
951
|
-
|
|
952
|
-
GT_ref = X_val
|
|
1121
|
+
if y_true_flat.size == 0:
|
|
1122
|
+
return {self.tune_metric: 0.0}
|
|
953
1123
|
|
|
954
|
-
#
|
|
955
|
-
if
|
|
956
|
-
|
|
1124
|
+
# --- Hard assertions on probability shape ---
|
|
1125
|
+
if y_proba_flat.ndim != 2:
|
|
1126
|
+
msg = f"Expected y_proba_flat to be 2D (n_eval, n_classes); got shape {y_proba_flat.shape}."
|
|
1127
|
+
self.logger.error(msg)
|
|
1128
|
+
raise ValueError(msg)
|
|
1129
|
+
|
|
1130
|
+
K = int(y_proba_flat.shape[1])
|
|
1131
|
+
if self.is_haploid_:
|
|
1132
|
+
if K not in (2, 3):
|
|
1133
|
+
msg = f"Haploid evaluation expects 2 or 3 classes; got {K}."
|
|
1134
|
+
self.logger.error(msg)
|
|
1135
|
+
raise ValueError(msg)
|
|
1136
|
+
else:
|
|
1137
|
+
if K != 3:
|
|
1138
|
+
msg = f"Diploid evaluation expects 3 classes; got {K}."
|
|
1139
|
+
self.logger.error(msg)
|
|
1140
|
+
raise ValueError(msg)
|
|
1141
|
+
|
|
1142
|
+
if not self.is_haploid_:
|
|
1143
|
+
if np.any((y_true_flat < 0) | (y_true_flat > 2)):
|
|
957
1144
|
msg = (
|
|
958
|
-
|
|
959
|
-
f"does not match X_val rows {X_val.shape[0]}"
|
|
1145
|
+
"Diploid y_true_flat contains values outside {0,1,2} after masking."
|
|
960
1146
|
)
|
|
961
1147
|
self.logger.error(msg)
|
|
962
1148
|
raise ValueError(msg)
|
|
963
1149
|
|
|
964
|
-
|
|
965
|
-
|
|
1150
|
+
# --- Harmonize for haploid vs diploid ---
|
|
1151
|
+
if self.is_haploid_:
|
|
1152
|
+
# Binary scoring: REF=0, ALT=1 (treat any non-zero as ALT)
|
|
1153
|
+
y_true_flat = (y_true_flat > 0).astype(np.int8, copy=False)
|
|
1154
|
+
y_pred_flat = (y_pred_flat > 0).astype(np.int8, copy=False)
|
|
1155
|
+
|
|
1156
|
+
K = y_proba_flat.shape[1]
|
|
1157
|
+
if K == 2:
|
|
1158
|
+
pass
|
|
1159
|
+
elif K == 3:
|
|
1160
|
+
proba_2 = np.empty((y_proba_flat.shape[0], 2), dtype=y_proba_flat.dtype)
|
|
1161
|
+
proba_2[:, 0] = y_proba_flat[:, 0]
|
|
1162
|
+
proba_2[:, 1] = y_proba_flat[:, 1] + y_proba_flat[:, 2]
|
|
1163
|
+
y_proba_flat = proba_2
|
|
966
1164
|
else:
|
|
967
|
-
|
|
968
|
-
|
|
969
|
-
|
|
970
|
-
|
|
971
|
-
eval_mask = eval_mask & finite_mask & (GT_ref != -1)
|
|
972
|
-
|
|
973
|
-
y_true_flat = GT_ref[eval_mask].astype(np.int64, copy=False)
|
|
974
|
-
y_pred_flat = pred_labels[eval_mask].astype(np.int64, copy=False)
|
|
975
|
-
y_proba_flat = pred_probas[eval_mask].astype(np.float64, copy=False)
|
|
1165
|
+
msg = f"Haploid evaluation expects 2 or 3 prob columns; got {K}"
|
|
1166
|
+
self.logger.error(msg)
|
|
1167
|
+
raise ValueError(msg)
|
|
976
1168
|
|
|
977
|
-
|
|
978
|
-
|
|
1169
|
+
labels_for_scoring = [0, 1]
|
|
1170
|
+
target_names = ["REF", "ALT"]
|
|
1171
|
+
else:
|
|
1172
|
+
if y_proba_flat.shape[1] != 3:
|
|
1173
|
+
msg = f"Diploid evaluation expects 3 prob columns; got {y_proba_flat.shape[1]}"
|
|
1174
|
+
self.logger.error(msg)
|
|
1175
|
+
raise ValueError(msg)
|
|
1176
|
+
labels_for_scoring = [0, 1, 2]
|
|
1177
|
+
target_names = ["REF", "HET", "ALT"]
|
|
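For haploid data the evaluation above collapses the three-class probabilities to binary REF/ALT by pooling the HET and ALT columns and binarizing the labels, so the same scorers work for both ploidies. A small numpy illustration of that collapse:

    import numpy as np

    # Three-class probabilities (REF, HET, ALT) for four masked genotypes.
    proba3 = np.array([
        [0.7, 0.2, 0.1],
        [0.1, 0.3, 0.6],
        [0.5, 0.4, 0.1],
        [0.2, 0.2, 0.6],
    ])

    # Haploid scoring treats any non-REF call as ALT, so HET and ALT mass is pooled.
    proba2 = np.empty((proba3.shape[0], 2), dtype=proba3.dtype)
    proba2[:, 0] = proba3[:, 0]
    proba2[:, 1] = proba3[:, 1] + proba3[:, 2]

    y_true = np.array([0, 2, 1, 2])
    y_true_bin = (y_true > 0).astype(np.int8)  # REF=0, ALT=1
    print(proba2, y_true_bin)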
979
1178
|
|
|
980
|
-
#
|
|
1179
|
+
# Ensure valid probability simplex after masking/collapsing
|
|
981
1180
|
y_proba_flat = np.clip(y_proba_flat, 0.0, 1.0)
|
|
982
1181
|
row_sums = y_proba_flat.sum(axis=1, keepdims=True)
|
|
983
|
-
row_sums[row_sums == 0] = 1.0
|
|
1182
|
+
row_sums[row_sums == 0.0] = 1.0
|
|
984
1183
|
y_proba_flat = y_proba_flat / row_sums
|
|
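After masking and, for haploids, column pooling, the probabilities are clipped to [0, 1] and renormalized row-wise, with zero-sum rows guarded so the division is safe. A compact sketch of that guard:

    import numpy as np

    proba = np.array([[0.2, 0.5, 0.3],
                      [0.0, 0.0, 0.0],      # degenerate row: would otherwise divide by zero
                      [1.2, -0.1, 0.4]])    # out-of-range values from numerical noise

    proba = np.clip(proba, 0.0, 1.0)
    row_sums = proba.sum(axis=1, keepdims=True)
    row_sums[row_sums == 0.0] = 1.0         # leave all-zero rows untouched instead of NaN
    proba = proba / row_sums
    print(proba.sum(axis=1))                # [1., 0., 1.]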
985
1184
|
|
|
986
|
-
|
|
987
|
-
target_names = ["REF", "ALT"] if self.is_haploid else ["REF", "HET", "ALT"]
|
|
988
|
-
|
|
989
|
-
if self.is_haploid:
|
|
990
|
-
y_true_flat = y_true_flat.copy()
|
|
991
|
-
y_pred_flat = y_pred_flat.copy()
|
|
992
|
-
y_true_flat[y_true_flat == 2] = 1
|
|
993
|
-
y_pred_flat[y_pred_flat == 2] = 1
|
|
994
|
-
proba_2 = np.zeros((len(y_proba_flat), 2), dtype=y_proba_flat.dtype)
|
|
995
|
-
proba_2[:, 0] = y_proba_flat[:, 0]
|
|
996
|
-
proba_2[:, 1] = y_proba_flat[:, 2]
|
|
997
|
-
y_proba_flat = proba_2
|
|
998
|
-
|
|
999
|
-
y_true_ohe = np.eye(len(labels_for_scoring))[y_true_flat]
|
|
1185
|
+
y_true_ohe = np.eye(len(labels_for_scoring), dtype=np.int8)[y_true_flat]
|
|
1000
1186
|
|
|
1001
1187
|
metrics = self.scorers_.evaluate(
|
|
1002
1188
|
y_true_flat,
|
|
@@ -1004,16 +1190,29 @@ class ImputeVAE(BaseNNImputer):
|
|
|
1004
1190
|
y_true_ohe,
|
|
1005
1191
|
y_proba_flat,
|
|
1006
1192
|
objective_mode,
|
|
1007
|
-
|
|
1193
|
+
cast(
|
|
1194
|
+
Literal[
|
|
1195
|
+
"pr_macro",
|
|
1196
|
+
"roc_auc",
|
|
1197
|
+
"accuracy",
|
|
1198
|
+
"f1",
|
|
1199
|
+
"average_precision",
|
|
1200
|
+
"precision",
|
|
1201
|
+
"recall",
|
|
1202
|
+
"mcc",
|
|
1203
|
+
"jaccard",
|
|
1204
|
+
],
|
|
1205
|
+
self.tune_metric,
|
|
1206
|
+
),
|
|
1008
1207
|
)
|
|
1009
1208
|
|
|
1010
1209
|
if not objective_mode:
|
|
1011
|
-
|
|
1012
|
-
|
|
1013
|
-
|
|
1014
|
-
|
|
1210
|
+
if self.verbose or self.debug:
|
|
1211
|
+
pm = PrettyMetrics(
|
|
1212
|
+
metrics, precision=2, title=f"{self.model_name} Validation Metrics"
|
|
1213
|
+
)
|
|
1214
|
+
pm.render()
|
|
1015
1215
|
|
|
1016
|
-
# Primary report
|
|
1017
1216
|
self._make_class_reports(
|
|
1018
1217
|
y_true=y_true_flat,
|
|
1019
1218
|
y_pred_proba=y_proba_flat,
|
|
@@ -1022,16 +1221,17 @@ class ImputeVAE(BaseNNImputer):
|
|
|
1022
1221
|
labels=target_names,
|
|
1023
1222
|
)
|
|
1024
1223
|
|
|
1025
|
-
# IUPAC decode
|
|
1026
|
-
|
|
1027
|
-
|
|
1028
|
-
|
|
1029
|
-
|
|
1030
|
-
|
|
1031
|
-
|
|
1032
|
-
|
|
1033
|
-
|
|
1034
|
-
)
|
|
1224
|
+
# --- IUPAC decode and 10-base integer report ---
|
|
1225
|
+
y_true_matrix = np.array(y, copy=True)
|
|
1226
|
+
y_pred_matrix = np.array(pred_labels, copy=True)
|
|
1227
|
+
|
|
1228
|
+
if self.is_haploid_:
|
|
1229
|
+
# Map any ALT-coded >0 to 2 for decode_012, preserve missing -1
|
|
1230
|
+
y_true_matrix = np.where(y_true_matrix > 0, 2, y_true_matrix)
|
|
1231
|
+
y_pred_matrix = np.where(y_pred_matrix > 0, 2, y_pred_matrix)
|
|
1232
|
+
|
|
1233
|
+
y_true_dec = self.decode_012(y_true_matrix)
|
|
1234
|
+
y_pred_dec = self.decode_012(y_pred_matrix)
|
|
1035
1235
|
|
|
1036
1236
|
encodings_dict = {
|
|
1037
1237
|
"A": 0,
|
|
@@ -1079,80 +1279,71 @@ class ImputeVAE(BaseNNImputer):
|
|
|
1079
1279
|
|
|
1080
1280
|
Returns:
|
|
1081
1281
|
float: Value of the tuning metric to be optimized.
|
|
1282
|
+
|
|
1283
|
+
Raises:
|
|
1284
|
+
RuntimeError: If model training returns None.
|
|
1285
|
+
optuna.exceptions.TrialPruned: If training fails unexpectedly or is unpromising.
|
|
1082
1286
|
"""
|
|
1083
1287
|
try:
|
|
1084
1288
|
params = self._sample_hyperparameters(trial)
|
|
1085
1289
|
|
|
1086
|
-
# Use tune subsets when available (tune_fast)
|
|
1087
|
-
X_train = getattr(self, "_tune_X_train", None)
|
|
1088
|
-
X_val = getattr(self, "_tune_X_test", None)
|
|
1089
|
-
if X_train is None or X_val is None:
|
|
1090
|
-
X_train = getattr(self, "X_train_", self.ground_truth_[self.train_idx_])
|
|
1091
|
-
X_val = getattr(self, "X_val_", self.ground_truth_[self.test_idx_])
|
|
1092
|
-
|
|
1093
|
-
class_weights = self._normalize_class_weights(
|
|
1094
|
-
self._class_weights_from_zygosity(X_train)
|
|
1095
|
-
)
|
|
1096
|
-
# Pos weights for diploid multilabel BCE during tuning
|
|
1097
|
-
if not self.is_haploid:
|
|
1098
|
-
self.pos_weights_ = self._compute_pos_weights(X_train)
|
|
1099
|
-
else:
|
|
1100
|
-
self.pos_weights_ = None
|
|
1101
|
-
train_loader = self._get_data_loader(X_train)
|
|
1102
|
-
|
|
1103
1290
|
model = self.build_model(self.Model, params["model_params"])
|
|
1104
1291
|
model.apply(self.initialize_weights)
|
|
1105
1292
|
|
|
1106
|
-
lr: float = params["
|
|
1293
|
+
lr: float = params["learning_rate"]
|
|
1107
1294
|
l1_penalty: float = params["l1_penalty"]
|
|
1108
1295
|
|
|
1109
|
-
|
|
1110
|
-
|
|
1296
|
+
class_weights = self._class_weights_from_zygosity(
|
|
1297
|
+
self.y_train_,
|
|
1298
|
+
train_mask=self.sim_mask_train_ & ~self.orig_mask_train_,
|
|
1299
|
+
inverse=params["inverse"],
|
|
1300
|
+
normalize=params["normalize"],
|
|
1301
|
+
max_ratio=self.max_ratio if self.max_ratio is not None else 5.0,
|
|
1302
|
+
power=params["power"],
|
|
1303
|
+
)
|
|
1304
|
+
|
|
1305
|
+
res = self._train_and_validate_model(
|
|
1111
1306
|
model=model,
|
|
1112
|
-
loader=train_loader,
|
|
1113
1307
|
lr=lr,
|
|
1114
1308
|
l1_penalty=l1_penalty,
|
|
1309
|
+
params=params,
|
|
1115
1310
|
trial=trial,
|
|
1116
|
-
return_history=False,
|
|
1117
1311
|
class_weights=class_weights,
|
|
1118
|
-
|
|
1119
|
-
|
|
1120
|
-
prune_metric=self.tune_metric,
|
|
1121
|
-
prune_warmup_epochs=10,
|
|
1122
|
-
eval_interval=self.tune_eval_interval,
|
|
1123
|
-
eval_requires_latents=False,
|
|
1124
|
-
eval_latent_steps=0,
|
|
1125
|
-
eval_latent_lr=0.0,
|
|
1126
|
-
eval_latent_weight_decay=0.0,
|
|
1127
|
-
)
|
|
1128
|
-
|
|
1129
|
-
eval_mask = (
|
|
1130
|
-
self.sim_mask_test_
|
|
1131
|
-
if (
|
|
1132
|
-
self.simulate_missing
|
|
1133
|
-
and getattr(self, "sim_mask_test_", None) is not None
|
|
1134
|
-
)
|
|
1135
|
-
else None
|
|
1312
|
+
kl_beta_schedule=params["kl_beta_schedule"],
|
|
1313
|
+
gamma_schedule=params["gamma_schedule"],
|
|
1136
1314
|
)
|
|
1315
|
+
model = res[1]
|
|
1137
1316
|
|
|
1138
1317
|
if model is None:
|
|
1139
|
-
|
|
1318
|
+
msg = "Model training returned None in tuning objective."
|
|
1319
|
+
self.logger.error(msg)
|
|
1320
|
+
raise RuntimeError(msg)
|
|
1140
1321
|
|
|
1141
1322
|
metrics = self._evaluate_model(
|
|
1142
|
-
|
|
1323
|
+
model=model,
|
|
1324
|
+
X=self.X_val_corrupted_,
|
|
1325
|
+
y=self.y_val_,
|
|
1326
|
+
eval_mask=self.sim_mask_val_ & ~self.orig_mask_val_,
|
|
1327
|
+
objective_mode=True,
|
|
1143
1328
|
)
|
|
1144
|
-
|
|
1329
|
+
|
|
1330
|
+
self._clear_resources(model)
|
|
1145
1331
|
return metrics[self.tune_metric]
|
|
1146
1332
|
|
|
1147
1333
|
except Exception as e:
|
|
1148
|
-
#
|
|
1149
|
-
|
|
1150
|
-
|
|
1151
|
-
|
|
1334
|
+
# Unexpected failure: surface full details in logs while still
|
|
1335
|
+
# pruning the trial to keep sweeps moving.
|
|
1336
|
+
err_type = type(e).__name__
|
|
1337
|
+
self.logger.warning(
|
|
1338
|
+
f"Trial {trial.number} failed due to exception {err_type}: {e}"
|
|
1152
1339
|
)
|
|
1340
|
+
self.logger.debug(traceback.format_exc())
|
|
1341
|
+
raise optuna.exceptions.TrialPruned(
|
|
1342
|
+
f"Trial {trial.number} failed due to an exception. {err_type}: {e}. Enable debug logging for full traceback."
|
|
1343
|
+
) from e
|
|
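The tuning objective above converts unexpected failures into pruned trials so a single bad hyperparameter draw does not abort the whole Optuna study, while the full traceback is kept in the debug log. A self-contained toy objective showing the same pattern; the failure condition here is fabricated purely for illustration:

    import traceback

    import optuna

    def objective(trial: optuna.Trial) -> float:
        try:
            lr = trial.suggest_float("learning_rate", 1e-6, 1e-3, log=True)
            if lr > 5e-4:                        # stand-in for a diverging training run
                raise RuntimeError("training diverged")
            return 1.0 - lr                      # stand-in for the validation metric
        except Exception as e:
            print(traceback.format_exc())        # full details, analogous to debug logging
            raise optuna.exceptions.TrialPruned(f"{type(e).__name__}: {e}") from e

    study = optuna.create_study(direction="maximize")
    study.optimize(objective, n_trials=5)
    completed = [t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE]
    print(f"{len(completed)} of {len(study.trials)} trials completed")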
1153
1344
|
|
|
1154
1345
|
def _sample_hyperparameters(self, trial: optuna.Trial) -> dict:
|
|
1155
|
-
"""Sample VAE hyperparameters; hidden sizes
|
|
1346
|
+
"""Sample VAE hyperparameters; hidden sizes use BaseNNImputer helper.
|
|
1156
1347
|
|
|
1157
1348
|
Args:
|
|
1158
1349
|
trial (optuna.Trial): Optuna trial object.
|
|
@@ -1161,156 +1352,109 @@ class ImputeVAE(BaseNNImputer):
|
|
|
1161
1352
|
Dict[str, int | float | str]: Sampled hyperparameters.
|
|
1162
1353
|
"""
|
|
1163
1354
|
params = {
|
|
1164
|
-
"latent_dim": trial.suggest_int("latent_dim",
|
|
1165
|
-
"
|
|
1166
|
-
"dropout_rate": trial.suggest_float("dropout_rate", 0.0, 0.
|
|
1167
|
-
"num_hidden_layers": trial.suggest_int("num_hidden_layers", 1,
|
|
1355
|
+
"latent_dim": trial.suggest_int("latent_dim", 2, 32),
|
|
1356
|
+
"learning_rate": trial.suggest_float("learning_rate", 3e-6, 1e-3, log=True),
|
|
1357
|
+
"dropout_rate": trial.suggest_float("dropout_rate", 0.0, 0.5, step=0.025),
|
|
1358
|
+
"num_hidden_layers": trial.suggest_int("num_hidden_layers", 1, 20),
|
|
1168
1359
|
"activation": trial.suggest_categorical(
|
|
1169
1360
|
"activation", ["relu", "elu", "selu", "leaky_relu"]
|
|
1170
1361
|
),
|
|
1171
1362
|
"l1_penalty": trial.suggest_float("l1_penalty", 1e-6, 1e-3, log=True),
|
|
1172
1363
|
"layer_scaling_factor": trial.suggest_float(
|
|
1173
|
-
"layer_scaling_factor", 2.0,
|
|
1364
|
+
"layer_scaling_factor", 2.0, 10.0, step=0.025
|
|
1174
1365
|
),
|
|
1175
1366
|
"layer_schedule": trial.suggest_categorical(
|
|
1176
1367
|
"layer_schedule", ["pyramid", "linear"]
|
|
1177
1368
|
),
|
|
1178
|
-
|
|
1179
|
-
"
|
|
1180
|
-
|
|
1181
|
-
"gamma": trial.suggest_float("gamma", 0.
|
|
1369
|
+
"power": trial.suggest_float("power", 0.1, 2.0, step=0.1),
|
|
1370
|
+
"normalize": trial.suggest_categorical("normalize", [True, False]),
|
|
1371
|
+
"inverse": trial.suggest_categorical("inverse", [True, False]),
|
|
1372
|
+
"gamma": trial.suggest_float("gamma", 0.0, 3.0, step=0.1),
|
|
1373
|
+
"kl_beta": trial.suggest_float("kl_beta", 0.1, 5.0, step=0.1),
|
|
1374
|
+
"kl_beta_schedule": trial.suggest_categorical(
|
|
1375
|
+
"kl_beta_schedule", [True, False]
|
|
1376
|
+
),
|
|
1377
|
+
"gamma_schedule": trial.suggest_categorical(
|
|
1378
|
+
"gamma_schedule", [True, False]
|
|
1379
|
+
),
|
|
1182
1380
|
}
|
|
1183
1381
|
|
|
1184
|
-
|
|
1185
|
-
|
|
1186
|
-
|
|
1187
|
-
|
|
1188
|
-
)
|
|
1189
|
-
input_dim = use_n_features * self.output_classes_
|
|
1382
|
+
nF: int = self.num_features_
|
|
1383
|
+
nC: int = self.num_classes_
|
|
1384
|
+
input_dim = nF * nC
|
|
1385
|
+
|
|
1190
1386
|
hidden_layer_sizes = self._compute_hidden_layer_sizes(
|
|
1191
1387
|
n_inputs=input_dim,
|
|
1192
|
-
n_outputs=
|
|
1193
|
-
n_samples=len(self.
|
|
1388
|
+
n_outputs=nC,
|
|
1389
|
+
n_samples=len(self.X_train_),
|
|
1194
1390
|
n_hidden=params["num_hidden_layers"],
|
|
1391
|
+
latent_dim=params["latent_dim"],
|
|
1195
1392
|
alpha=params["layer_scaling_factor"],
|
|
1196
1393
|
schedule=params["layer_schedule"],
|
|
1197
1394
|
)
|
|
1198
1395
|
|
|
1199
|
-
# [latent_dim] + interior widths (exclude output width)
|
|
1200
|
-
hidden_only = [hidden_layer_sizes[0]] + hidden_layer_sizes[1:-1]
|
|
1201
|
-
|
|
1202
1396
|
params["model_params"] = {
|
|
1203
|
-
"n_features":
|
|
1204
|
-
"num_classes":
|
|
1205
|
-
"latent_dim": params["latent_dim"],
|
|
1397
|
+
"n_features": self.num_features_,
|
|
1398
|
+
"num_classes": nC, # categorical head: 2 or 3
|
|
1206
1399
|
"dropout_rate": params["dropout_rate"],
|
|
1207
|
-
"hidden_layer_sizes":
|
|
1400
|
+
"hidden_layer_sizes": hidden_layer_sizes,
|
|
1208
1401
|
"activation": params["activation"],
|
|
1209
|
-
|
|
1210
|
-
"beta": params["beta"],
|
|
1211
|
-
"gamma": params["gamma"],
|
|
1402
|
+
"kl_beta": params["kl_beta"],
|
|
1212
1403
|
}
|
|
1404
|
+
|
|
1213
1405
|
return params
|
|
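The sampled num_hidden_layers, layer_scaling_factor, and layer_schedule values are turned into concrete widths by the base class's _compute_hidden_layer_sizes, whose internals are not shown in this hunk. The sketch below only illustrates the general idea of a "pyramid" (geometric) versus "linear" width schedule between the flattened input width and the latent width; it ignores the scaling factor and is not the package's implementation:

    import numpy as np

    def hidden_sizes(n_inputs: int, latent_dim: int, n_hidden: int,
                     schedule: str = "pyramid") -> list[int]:
        """Illustrative layer-width schedule between the input width and the latent width."""
        if schedule == "pyramid":
            # Geometric interpolation: widths shrink by a constant ratio per layer.
            widths = np.geomspace(n_inputs, latent_dim, num=n_hidden + 2)[1:-1]
        else:  # "linear"
            widths = np.linspace(n_inputs, latent_dim, num=n_hidden + 2)[1:-1]
        return [max(int(round(w)), latent_dim) for w in widths]

    print(hidden_sizes(n_inputs=300 * 3, latent_dim=8, n_hidden=3, schedule="pyramid"))
    print(hidden_sizes(n_inputs=300 * 3, latent_dim=8, n_hidden=3, schedule="linear"))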
1214
1406
|
|
|
1215
|
-
def _set_best_params(self,
|
|
1216
|
-
"""
|
|
1407
|
+
def _set_best_params(self, params: dict) -> dict:
|
|
1408
|
+
"""Update instance fields from tuned params and return model_params dict.
|
|
1217
1409
|
|
|
1218
1410
|
Args:
|
|
1219
|
-
|
|
1411
|
+
params (dict): Best hyperparameters from tuning.
|
|
1220
1412
|
|
|
1221
1413
|
Returns:
|
|
1222
|
-
dict:
|
|
1414
|
+
dict: Model parameters for building the VAE.
|
|
1223
1415
|
"""
|
|
1224
|
-
self.latent_dim =
|
|
1225
|
-
self.dropout_rate =
|
|
1226
|
-
self.learning_rate =
|
|
1227
|
-
self.l1_penalty =
|
|
1228
|
-
self.activation =
|
|
1229
|
-
self.layer_scaling_factor =
|
|
1230
|
-
self.layer_schedule =
|
|
1231
|
-
self.
|
|
1232
|
-
self.
|
|
1233
|
-
|
|
1234
|
-
|
|
1235
|
-
|
|
1236
|
-
|
|
1237
|
-
|
|
1238
|
-
|
|
1239
|
-
|
|
1240
|
-
|
|
1416
|
+
self.latent_dim = params["latent_dim"]
|
|
1417
|
+
self.dropout_rate = params["dropout_rate"]
|
|
1418
|
+
self.learning_rate = params["learning_rate"]
|
|
1419
|
+
self.l1_penalty = params["l1_penalty"]
|
|
1420
|
+
self.activation = params["activation"]
|
|
1421
|
+
self.layer_scaling_factor = params["layer_scaling_factor"]
|
|
1422
|
+
self.layer_schedule = params["layer_schedule"]
|
|
1423
|
+
self.power = params["power"]
|
|
1424
|
+
self.normalize = params["normalize"]
|
|
1425
|
+
self.inverse = params["inverse"]
|
|
1426
|
+
self.gamma = params["gamma"]
|
|
1427
|
+
self.gamma_schedule = params["gamma_schedule"]
|
|
1428
|
+
self.kl_beta = params["kl_beta"]
|
|
1429
|
+
self.kl_beta_schedule = params["kl_beta_schedule"]
|
|
1430
|
+
self.class_weights_ = self._class_weights_from_zygosity(
|
|
1431
|
+
self.y_train_,
|
|
1432
|
+
train_mask=self.sim_mask_train_ & ~self.orig_mask_train_,
|
|
1433
|
+
inverse=self.inverse,
|
|
1434
|
+
normalize=self.normalize,
|
|
1435
|
+
max_ratio=self.max_ratio if self.max_ratio is not None else 5.0,
|
|
1436
|
+
power=self.power,
|
|
1241
1437
|
)
|
|
1242
|
-
|
|
1438
|
+
nF = self.num_features_
|
|
1439
|
+
nC = self.num_classes_
|
|
1440
|
+
input_dim = nF * nC
|
|
1243
1441
|
|
|
1244
|
-
return {
|
|
1245
|
-
"n_features": self.num_features_,
|
|
1246
|
-
"latent_dim": self.latent_dim,
|
|
1247
|
-
"hidden_layer_sizes": hidden_only,
|
|
1248
|
-
"dropout_rate": self.dropout_rate,
|
|
1249
|
-
"activation": self.activation,
|
|
1250
|
-
"num_classes": self.output_classes_,
|
|
1251
|
-
"beta": self.kl_beta_final,
|
|
1252
|
-
"gamma": self.gamma,
|
|
1253
|
-
}
|
|
1254
|
-
|
|
1255
|
-
def _default_best_params(self) -> Dict[str, int | float | str | list]:
|
|
1256
|
-
"""Default VAE model params when tuning is disabled.
|
|
1257
|
-
|
|
1258
|
-
Returns:
|
|
1259
|
-
Dict[str, int | float | str | list]: VAE model parameters.
|
|
1260
|
-
"""
|
|
1261
1442
|
hidden_layer_sizes = self._compute_hidden_layer_sizes(
|
|
1262
|
-
n_inputs=
|
|
1263
|
-
n_outputs=
|
|
1264
|
-
n_samples=len(self.
|
|
1265
|
-
n_hidden=
|
|
1266
|
-
|
|
1267
|
-
|
|
1443
|
+
n_inputs=input_dim,
|
|
1444
|
+
n_outputs=nC,
|
|
1445
|
+
n_samples=len(self.X_train_),
|
|
1446
|
+
n_hidden=params["num_hidden_layers"],
|
|
1447
|
+
latent_dim=params["latent_dim"],
|
|
1448
|
+
alpha=params["layer_scaling_factor"],
|
|
1449
|
+
schedule=params["layer_schedule"],
|
|
1268
1450
|
)
|
|
1451
|
+
|
|
1269
1452
|
return {
|
|
1270
|
-
"n_features":
|
|
1453
|
+
"n_features": nF,
|
|
1271
1454
|
"latent_dim": self.latent_dim,
|
|
1272
1455
|
"hidden_layer_sizes": hidden_layer_sizes,
|
|
1273
1456
|
"dropout_rate": self.dropout_rate,
|
|
1274
1457
|
"activation": self.activation,
|
|
1275
|
-
"num_classes":
|
|
1276
|
-
"
|
|
1277
|
-
"gamma": self.gamma,
|
|
1458
|
+
"num_classes": nC,
|
|
1459
|
+
"kl_beta": params["kl_beta"],
|
|
1278
1460
|
}
|
|
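_set_best_params above recomputes class_weights_ from the training genotypes restricted to the simulated-missing mask, controlled by the tuned inverse, power, normalize, and max_ratio settings. The helper's internals are not part of this hunk; below is a generic inverse-frequency sketch of what such zygosity-based weighting typically looks like, not the pgsui _class_weights_from_zygosity implementation:

    import numpy as np
    import torch

    def zygosity_weights(y: np.ndarray, mask: np.ndarray, *, power: float = 1.0,
                         max_ratio: float = 5.0, normalize: bool = True) -> torch.Tensor:
        """Generic inverse-frequency class weights over masked, non-missing genotypes."""
        vals = y[mask]
        vals = vals[vals >= 0]
        counts = np.bincount(vals, minlength=3).astype(np.float64)
        freqs = counts / max(counts.sum(), 1.0)
        weights = 1.0 / np.clip(freqs, 1e-8, None) ** power       # rarer classes weigh more
        weights = np.minimum(weights, weights.min() * max_ratio)  # cap the spread
        if normalize:
            weights = weights * (len(weights) / weights.sum())    # mean weight of roughly 1
        return torch.tensor(weights, dtype=torch.float32)

    y = np.array([[0, 0, 1, 2, -1], [0, 1, 0, 0, 2]])
    mask = np.ones_like(y, dtype=bool)
    print(zygosity_weights(y, mask, power=1.0, max_ratio=5.0))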
1279
|
-
|
|
1280
|
-
def _encode_multilabel_inputs(self, y: torch.Tensor) -> torch.Tensor:
|
|
1281
|
-
"""Two-channel multi-hot for diploid: REF-only, ALT-only; HET sets both."""
|
|
1282
|
-
if self.is_haploid:
|
|
1283
|
-
return self._one_hot_encode_012(y)
|
|
1284
|
-
y = y.to(self.device)
|
|
1285
|
-
shape = y.shape + (2,)
|
|
1286
|
-
out = torch.zeros(shape, device=self.device, dtype=torch.float32)
|
|
1287
|
-
valid = y != -1
|
|
1288
|
-
ref_mask = valid & (y != 2)
|
|
1289
|
-
alt_mask = valid & (y != 0)
|
|
1290
|
-
out[ref_mask, 0] = 1.0
|
|
1291
|
-
out[alt_mask, 1] = 1.0
|
|
1292
|
-
return out
|
|
1293
|
-
|
|
1294
|
-
def _multi_hot_targets(self, y: torch.Tensor) -> torch.Tensor:
|
|
1295
|
-
"""Targets aligned with _encode_multilabel_inputs for diploid training."""
|
|
1296
|
-
if self.is_haploid:
|
|
1297
|
-
raise RuntimeError("_multi_hot_targets called for haploid data.")
|
|
1298
|
-
y = y.to(self.device)
|
|
1299
|
-
out = torch.zeros(y.shape + (2,), device=self.device, dtype=torch.float32)
|
|
1300
|
-
valid = y != -1
|
|
1301
|
-
ref_mask = valid & (y != 2)
|
|
1302
|
-
alt_mask = valid & (y != 0)
|
|
1303
|
-
out[ref_mask, 0] = 1.0
|
|
1304
|
-
out[alt_mask, 1] = 1.0
|
|
1305
|
-
return out
|
|
1306
|
-
|
|
1307
|
-
def _compute_pos_weights(self, X: np.ndarray) -> torch.Tensor:
|
|
1308
|
-
"""Balance REF/ALT channels for multilabel BCE."""
|
|
1309
|
-
ref_pos = np.count_nonzero((X == 0) | (X == 1))
|
|
1310
|
-
alt_pos = np.count_nonzero((X == 2) | (X == 1))
|
|
1311
|
-
total_valid = np.count_nonzero(X != -1)
|
|
1312
|
-
pos_counts = np.array([ref_pos, alt_pos], dtype=np.float32)
|
|
1313
|
-
neg_counts = np.maximum(total_valid - pos_counts, 1.0)
|
|
1314
|
-
pos_counts = np.maximum(pos_counts, 1.0)
|
|
1315
|
-
weights = neg_counts / pos_counts
|
|
1316
|
-
return torch.tensor(weights, device=self.device, dtype=torch.float32)
|