pg-sui 1.6.14.dev9__py3-none-any.whl → 1.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pg_sui-1.7.0.dist-info/METADATA +288 -0
- {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/RECORD +29 -33
- pgsui/__init__.py +0 -8
- pgsui/_version.py +2 -2
- pgsui/cli.py +591 -126
- pgsui/data_processing/config.py +1 -2
- pgsui/data_processing/containers.py +218 -533
- pgsui/data_processing/transformers.py +44 -20
- pgsui/impute/deterministic/imputers/mode.py +475 -182
- pgsui/impute/deterministic/imputers/ref_allele.py +454 -147
- pgsui/impute/supervised/imputers/hist_gradient_boosting.py +4 -3
- pgsui/impute/supervised/imputers/random_forest.py +3 -2
- pgsui/impute/unsupervised/base.py +1268 -530
- pgsui/impute/unsupervised/callbacks.py +28 -33
- pgsui/impute/unsupervised/imputers/autoencoder.py +869 -764
- pgsui/impute/unsupervised/imputers/vae.py +928 -696
- pgsui/impute/unsupervised/loss_functions.py +156 -202
- pgsui/impute/unsupervised/models/autoencoder_model.py +7 -49
- pgsui/impute/unsupervised/models/vae_model.py +40 -221
- pgsui/impute/unsupervised/nn_scorers.py +53 -13
- pgsui/utils/classification_viz.py +240 -97
- pgsui/utils/misc.py +201 -3
- pgsui/utils/plotting.py +73 -58
- pgsui/utils/pretty_metrics.py +2 -6
- pgsui/utils/scorers.py +39 -0
- pg_sui-1.6.14.dev9.dist-info/METADATA +0 -344
- pgsui/impute/unsupervised/imputers/nlpca.py +0 -1554
- pgsui/impute/unsupervised/imputers/ubp.py +0 -1575
- pgsui/impute/unsupervised/models/nlpca_model.py +0 -206
- pgsui/impute/unsupervised/models/ubp_model.py +0 -200
- {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/WHEEL +0 -0
- {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/entry_points.txt +0 -0
- {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/licenses/LICENSE +0 -0
- {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/top_level.txt +0 -0
pgsui/impute/unsupervised/imputers/autoencoder.py

@@ -1,22 +1,25 @@
+# -*- coding: utf-8 -*-
+from __future__ import annotations
+
 import copy
-
+import traceback
+from collections import defaultdict
+from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union, cast

 import matplotlib.pyplot as plt
 import numpy as np
 import optuna
 import torch
 from sklearn.exceptions import NotFittedError
-from sklearn.model_selection import train_test_split
 from snpio.analysis.genotype_encoder import GenotypeEncoder
 from snpio.utils.logging import LoggerManager
 from torch.optim.lr_scheduler import CosineAnnealingLR

 from pgsui.data_processing.config import apply_dot_overrides, load_yaml_to_dataclass
 from pgsui.data_processing.containers import AutoencoderConfig
-from pgsui.data_processing.transformers import SimMissingTransformer
 from pgsui.impute.unsupervised.base import BaseNNImputer
 from pgsui.impute.unsupervised.callbacks import EarlyStopping
-from pgsui.impute.unsupervised.loss_functions import
+from pgsui.impute.unsupervised.loss_functions import FocalCELoss
 from pgsui.impute.unsupervised.models.autoencoder_model import AutoencoderModel
 from pgsui.utils.logging_utils import configure_logger
 from pgsui.utils.pretty_metrics import PrettyMetrics
@@ -26,30 +29,72 @@ if TYPE_CHECKING:
     from snpio.read_input.genotype_data import GenotypeData


+def _make_warmup_cosine_scheduler(
+    optimizer: torch.optim.Optimizer,
+    *,
+    max_epochs: int,
+    warmup_epochs: int,
+    start_factor: float = 0.1,
+) -> torch.optim.lr_scheduler.CosineAnnealingLR | torch.optim.lr_scheduler.SequentialLR:
+    """Create a warmup->cosine LR scheduler.
+
+    Args:
+        optimizer: Optimizer to schedule.
+        max_epochs: Total number of epochs.
+        warmup_epochs: Number of warmup epochs.
+        start_factor: Starting LR factor for warmup.
+
+    Returns:
+        torch.optim.lr_scheduler.CosineAnnealingLR | torch.optim.lr_scheduler.SequentialLR: LR scheduler (SequentialLR if warmup_epochs > 0 else CosineAnnealingLR).
+    """
+    warmup_epochs = int(max(0, warmup_epochs))
+
+    if warmup_epochs == 0:
+        return CosineAnnealingLR(optimizer, T_max=max_epochs)
+
+    warmup = torch.optim.lr_scheduler.LinearLR(
+        optimizer, start_factor=float(start_factor), total_iters=warmup_epochs
+    )
+    cosine = CosineAnnealingLR(optimizer, T_max=max(1, max_epochs - warmup_epochs))
+
+    return torch.optim.lr_scheduler.SequentialLR(
+        optimizer, schedulers=[warmup, cosine], milestones=[warmup_epochs]
+    )
+
+
 def ensure_autoencoder_config(
     config: AutoencoderConfig | dict | str | None,
 ) -> AutoencoderConfig:
     """Return a concrete AutoencoderConfig from dataclass, dict, YAML path, or None.

-
+    Notes:
+        - Supports top-level preset, or io.preset inside dict/YAML.
+        - Does not mutate user-provided dict (deep-copies before processing).
+        - Flattens nested dicts into dot-keys and applies them as overrides.

     Args:
-        config
+        config: AutoencoderConfig instance, dict, YAML path, or None.

     Returns:
-
+        Concrete AutoencoderConfig.
     """
     if config is None:
         return AutoencoderConfig()
     if isinstance(config, AutoencoderConfig):
         return config
     if isinstance(config, str):
-        # YAML path — top-level `preset` key is supported
         return load_yaml_to_dataclass(config, AutoencoderConfig)
     if isinstance(config, dict):
-
+        cfg_in = copy.deepcopy(config)
         base = AutoencoderConfig()

+        preset = cfg_in.pop("preset", None)
+        if "io" in cfg_in and isinstance(cfg_in["io"], dict):
+            preset = preset or cfg_in["io"].pop("preset", None)
+
+        if preset:
+            base = AutoencoderConfig.from_preset(preset)
+
         def _flatten(prefix: str, d: dict, out: dict) -> dict:
             for k, v in d.items():
                 kk = f"{prefix}.{k}" if prefix else k
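For context, the warmup->cosine pattern added in `_make_warmup_cosine_scheduler` above composes stock PyTorch schedulers. A minimal standalone sketch of the same composition (the toy model, learning rate, and epoch counts are placeholders, not values taken from this diff):

```python
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR

# Toy model/optimizer purely for illustration.
model = torch.nn.Linear(8, 3)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

max_epochs, warmup_epochs = 50, 5
# Linear ramp from 10% of the base LR up to 100% over the warmup epochs,
# then cosine decay over the remaining epochs.
warmup = LinearLR(optimizer, start_factor=0.1, total_iters=warmup_epochs)
cosine = CosineAnnealingLR(optimizer, T_max=max_epochs - warmup_epochs)
scheduler = SequentialLR(optimizer, schedulers=[warmup, cosine], milestones=[warmup_epochs])

for epoch in range(max_epochs):
    optimizer.step()   # dummy optimization step per epoch
    scheduler.step()   # advance the LR schedule once per epoch
```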
@@ -59,26 +104,24 @@ def ensure_autoencoder_config(
                     out[kk] = v
             return out

-
-        preset_name = config.pop("preset", None)
-        if "io" in config and isinstance(config["io"], dict):
-            preset_name = preset_name or config["io"].pop("preset", None)
-
-        if preset_name:
-            base = AutoencoderConfig.from_preset(preset_name)
-
-        flat = _flatten("", config, {})
+        flat = _flatten("", cfg_in, {})
         return apply_dot_overrides(base, flat)

     raise TypeError("config must be an AutoencoderConfig, dict, YAML path, or None.")


 class ImputeAutoencoder(BaseNNImputer):
-    """
+    """Autoencoder imputer for 0/1/2 genotypes.

-
+    Trains a feedforward autoencoder on a genotype matrix encoded as 0/1/2 with missing values represented by any negative integer. Missingness is simulated once on the full matrix, then train/val/test splits reuse those masks. It supports haploid and diploid data, focal-CE reconstruction loss (optional scheduling), and Optuna-based hyperparameter tuning. Output is returned as IUPAC strings via ``decode_012``.

-
+    Notes:
+        - Simulates missingness once on the full 0/1/2 matrix, then splits indices on clean ground truth.
+        - Maintains clean targets and corrupted inputs per train/val/test, plus per-split masks.
+        - Haploid harmonization happens after the single simulation (no re-simulation).
+        - Training/validation loss is computed only where targets are known (~orig_mask_*).
+        - Evaluation is computed only on simulated-missing sites (sim_mask_*).
+        - ``transform()`` fills only originally missing sites and hard-errors if decoding yields "N".
     """

     def __init__(
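The dict branch of `ensure_autoencoder_config` flattens nested keys into dot-paths before applying them as overrides. A generic, self-contained sketch of that flattening idea (independent of pg-sui's `apply_dot_overrides`; the example config keys are made up):

```python
def flatten_dotkeys(d: dict, prefix: str = "") -> dict:
    """Flatten {"train": {"batch_size": 64}} into {"train.batch_size": 64}."""
    out = {}
    for k, v in d.items():
        key = f"{prefix}.{k}" if prefix else k
        if isinstance(v, dict):
            out.update(flatten_dotkeys(v, key))   # recurse into nested sections
        else:
            out[key] = v
    return out

nested = {"io": {"prefix": "run1"}, "train": {"batch_size": 64, "learning_rate": 1e-3}}
assert flatten_dotkeys(nested) == {
    "io.prefix": "run1",
    "train.batch_size": 64,
    "train.learning_rate": 0.001,
}
```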
@@ -87,8 +130,7 @@ class ImputeAutoencoder(BaseNNImputer):
         *,
         tree_parser: Optional["TreeParser"] = None,
         config: Optional[Union["AutoencoderConfig", dict, str]] = None,
-        overrides: dict
-        simulate_missing: bool | None = None,
+        overrides: Optional[dict] = None,
         sim_strategy: (
             Literal[
                 "random",
@@ -99,34 +141,29 @@ class ImputeAutoencoder(BaseNNImputer):
             ]
             | None
         ) = None,
-        sim_prop: float
-        sim_kwargs: dict
+        sim_prop: Optional[float] = None,
+        sim_kwargs: Optional[dict] = None,
     ) -> None:
         """Initialize the Autoencoder imputer with a unified config interface.

-        This initializer sets up the Autoencoder imputer by processing the provided configuration, initializing logging, and preparing the model and data encoder. It supports configuration input as a dataclass, nested dictionary, YAML file path, or None, with optional dot-key overrides for fine-tuning specific parameters.
-
         Args:
-            genotype_data (
-            tree_parser (Optional[
-            config (Union[
-            overrides (dict
-
-
-
-            sim_kwargs (dict | None): Additional keyword arguments for simulating missing data. If None, uses config default.
+            genotype_data (GenotypeData): Backing genotype data object.
+            tree_parser (Optional[TreeParser]): Optional SNPio tree parser for nonrandom simulated-missing modes.
+            config (Optional[Union[AutoencoderConfig, dict, str]]): AutoencoderConfig, nested dict, YAML path, or None.
+            overrides (Optional[dict]): Optional dot-key overrides with highest precedence.
+            sim_strategy (Literal["random", "random_weighted" "random_weighted_inv", "nonrandom", "nonrandom_weighted"]): Override sim strategy; if None, uses config default.
+            sim_prop (Optional[float]): Override simulated missing proportion; if None, uses config default. Default is None.
+            sim_kwargs (Optional[dict]): Override/extend simulated missing kwargs; if None, uses config default.
         """
         self.model_name = "ImputeAutoencoder"
         self.genotype_data = genotype_data
         self.tree_parser = tree_parser

-        # Normalize config then apply highest-precedence overrides
         cfg = ensure_autoencoder_config(config)
         if overrides:
             cfg = apply_dot_overrides(cfg, overrides)
         self.cfg = cfg

-        # Logger consistent with NLPCA
         logman = LoggerManager(
             __name__,
             prefix=self.cfg.io.prefix,
@@ -138,8 +175,8 @@ class ImputeAutoencoder(BaseNNImputer):
             verbose=self.cfg.io.verbose,
             debug=self.cfg.io.debug,
         )
+        self.logger.propagate = False

-        # BaseNNImputer bootstrapping (device/dirs/logging handled here)
         super().__init__(
             model_name=self.model_name,
             genotype_data=self.genotype_data,
@@ -150,11 +187,9 @@ class ImputeAutoencoder(BaseNNImputer):
         )

         self.Model = AutoencoderModel
-
-        # Model hook & encoder
         self.pgenc = GenotypeEncoder(genotype_data)

-        #
+        # I/O and global
         self.seed = self.cfg.io.seed
         self.n_jobs = self.cfg.io.n_jobs
         self.prefix = self.cfg.io.prefix
@@ -163,186 +198,150 @@ class ImputeAutoencoder(BaseNNImputer):
         self.debug = self.cfg.io.debug
         self.rng = np.random.default_rng(self.seed)

-        #
+        # Simulation controls (match VAE pattern)
         sim_cfg = getattr(self.cfg, "sim", None)
         sim_cfg_kwargs = copy.deepcopy(getattr(sim_cfg, "sim_kwargs", None) or {})
         if sim_kwargs:
             sim_cfg_kwargs.update(sim_kwargs)
-
-        (
-            sim_cfg.simulate_missing
-            if simulate_missing is None
-            else bool(simulate_missing)
-        )
-        if sim_cfg is not None
-        else bool(simulate_missing)
-        )
+
         if sim_cfg is None:
             default_strategy = "random"
-            default_prop = 0.
+            default_prop = 0.2
         else:
             default_strategy = sim_cfg.sim_strategy
             default_prop = sim_cfg.sim_prop
+
+        self.simulate_missing = True
         self.sim_strategy = sim_strategy or default_strategy
         self.sim_prop = float(sim_prop if sim_prop is not None else default_prop)
         self.sim_kwargs = sim_cfg_kwargs

         if self.tree_parser is None and self.sim_strategy.startswith("nonrandom"):
-            msg = "tree_parser is required for nonrandom
+            msg = "tree_parser is required for nonrandom sim strategies."
             self.logger.error(msg)
             raise ValueError(msg)

-        # Model
+        # Model architecture
         self.latent_dim = int(self.cfg.model.latent_dim)
         self.dropout_rate = float(self.cfg.model.dropout_rate)
         self.num_hidden_layers = int(self.cfg.model.num_hidden_layers)
         self.layer_scaling_factor = float(self.cfg.model.layer_scaling_factor)
-        self.layer_schedule
-        self.activation = str(self.cfg.model.
-
+        self.layer_schedule = str(self.cfg.model.layer_schedule)
+        self.activation = str(self.cfg.model.activation)
+
+        # Training / loss controls (align with VAE fields where present)
+        self.power = float(getattr(self.cfg.train, "weights_power", 1.0))
+        self.max_ratio = getattr(self.cfg.train, "weights_max_ratio", None)
+        self.normalize = bool(getattr(self.cfg.train, "weights_normalize", True))
+        self.inverse = bool(getattr(self.cfg.train, "weights_inverse", False))

-        # Train hyperparams
         self.batch_size = int(self.cfg.train.batch_size)
         self.learning_rate = float(self.cfg.train.learning_rate)
-        self.l1_penalty
+        self.l1_penalty = float(self.cfg.train.l1_penalty)
         self.early_stop_gen = int(self.cfg.train.early_stop_gen)
         self.min_epochs = int(self.cfg.train.min_epochs)
         self.epochs = int(self.cfg.train.max_epochs)
         self.validation_split = float(self.cfg.train.validation_split)
-        self.beta = float(self.cfg.train.weights_beta)
-        self.max_ratio = float(self.cfg.train.weights_max_ratio)

-        #
-
-
-
-
-
-
-
-
-
-            Literal[
-                "pr_macro",
-                "f1",
-                "accuracy",
-                "precision",
-                "recall",
-                "roc_auc",
-                "average_precision",
-            ]
-            | None
-        ) = self.cfg.tune.metric
+        # Gamma can live in cfg.model or cfg.train depending on your dataclasses
+        gamma_raw = getattr(
+            self.cfg.train, "gamma", getattr(self.cfg.model, "gamma", 0.0)
+        )
+        if not isinstance(gamma_raw, (float, int)):
+            msg = f"Gamma must be float|int; got {type(gamma_raw)}."
+            self.logger.error(msg)
+            raise TypeError(msg)
+        self.gamma = float(gamma_raw)
+        self.gamma_schedule = bool(getattr(self.cfg.train, "gamma_schedule", True))

+        # Hyperparameter tuning
+        self.tune = bool(self.cfg.tune.enabled)
+        self.tune_metric = cast(
+            Literal[
+                "pr_macro",
+                "f1",
+                "accuracy",
+                "precision",
+                "recall",
+                "roc_auc",
+                "average_precision",
+                "mcc",
+                "jaccard",
+            ],
+            self.cfg.tune.metric or "f1",
+        )
         self.n_trials = int(self.cfg.tune.n_trials)
         self.tune_save_db = bool(self.cfg.tune.save_db)
         self.tune_resume = bool(self.cfg.tune.resume)
-        self.tune_max_samples = int(self.cfg.tune.max_samples)
-        self.tune_max_loci = int(self.cfg.tune.max_loci)
-        self.tune_infer_epochs = int(
-            getattr(self.cfg.tune, "infer_epochs", 0)
-        ) # AE unused
         self.tune_patience = int(self.cfg.tune.patience)

-        #
-
-        self.eval_latent_steps: int = 0
-        self.eval_latent_lr: float = 0.0
-        self.eval_latent_weight_decay: float = 0.0
-
-        # Plotting (parity with NLPCA PlotConfig)
-        self.plot_format: Literal["pdf", "png", "jpg", "jpeg", "svg"] = (
-            self.cfg.plot.fmt
-        )
+        # Plotting
+        self.plot_format = self.cfg.plot.fmt
         self.plot_dpi = int(self.cfg.plot.dpi)
         self.plot_fontsize = int(self.cfg.plot.fontsize)
         self.title_fontsize = int(self.cfg.plot.fontsize)
         self.despine = bool(self.cfg.plot.despine)
         self.show_plots = bool(self.cfg.plot.show)

-        #
-        self.
-        self.num_classes_: int
+        # Fit-time attributes
+        self.is_haploid_: bool = False
+        self.num_classes_: int = 3
         self.model_params: Dict[str, Any] = {}
-        self.sim_mask_global_: np.ndarray | None = None
-        self.sim_mask_train_: np.ndarray | None = None
-        self.sim_mask_test_: np.ndarray | None = None

-
-
+        self.sim_mask_train_: np.ndarray
+        self.sim_mask_val_: np.ndarray
+        self.sim_mask_test_: np.ndarray

-
+        self.orig_mask_train_: np.ndarray
+        self.orig_mask_val_: np.ndarray
+        self.orig_mask_test_: np.ndarray

-
-
+    def fit(self) -> "ImputeAutoencoder":
+        """Fit the Autoencoder imputer model to the genotype data.
+
+        This method performs the following steps:
+        1. Validates the presence of SNP data in the genotype data.
+        2. Determines ploidy and sets up the number of classes accordingly.
+        3. Cleans the ground truth genotype matrix and simulates missingness.
+        4. Splits the data into training, validation, and test sets.
+        5. Prepares one-hot encoded inputs for the model.
+        6. Initializes plotting utilities and valid-class masks.
+        7. Sets up data loaders for training and validation.
+        8. Performs hyperparameter tuning if enabled, otherwise uses fixed hyperparameters.
+        9. Builds and trains the Autoencoder model.
+        10. Evaluates the trained model on the test set.
+        11. Returns the fitted ImputeAutoencoder instance.

-
-
+        Returns:
+            ImputeAutoencoder: The fitted ImputeAutoencoder instance.
         """
         self.logger.info(f"Fitting {self.model_name} model...")

-        # --- Data prep (mirror NLPCA) ---
-        X012 = self._get_float_genotypes(copy=True)
-        GT_full = np.nan_to_num(X012, nan=-1.0, copy=True)
-        self.ground_truth_ = GT_full.astype(np.int64, copy=False)
-
-        self.sim_mask_global_ = None
-        cache_key = self._sim_mask_cache_key()
-        if self.simulate_missing:
-            cached_mask = (
-                None if cache_key is None else self._sim_mask_cache.get(cache_key)
-            )
-            if cached_mask is not None:
-                self.sim_mask_global_ = cached_mask.copy()
-            else:
-                tr = SimMissingTransformer(
-                    genotype_data=self.genotype_data,
-                    tree_parser=self.tree_parser,
-                    prop_missing=self.sim_prop,
-                    strategy=self.sim_strategy,
-                    missing_val=-9,
-                    mask_missing=True,
-                    verbose=self.verbose,
-                    **self.sim_kwargs,
-                )
-                tr.fit(X012.copy())
-                self.sim_mask_global_ = tr.sim_missing_mask_.astype(bool)
-                if cache_key is not None:
-                    self._sim_mask_cache[cache_key] = self.sim_mask_global_.copy()
-
-            X_for_model = self.ground_truth_.copy()
-            X_for_model[self.sim_mask_global_] = -1
-        else:
-            X_for_model = self.ground_truth_.copy()
-
         if self.genotype_data.snp_data is None:
-            msg = "SNP data is required for
+            msg = f"SNP data is required for {self.model_name}."
             self.logger.error(msg)
-            raise
+            raise AttributeError(msg)

-
-        self.
-
-
-
-
-
+        self.ploidy = self.cfg.io.ploidy
+        self.is_haploid_ = self.ploidy == 1
+
+        if self.ploidy > 2:
+            msg = (
+                f"{self.model_name} currently supports only haploid (1) or diploid (2) "
+                f"data; got ploidy={self.ploidy}."
             )
-
-
-        self.num_classes_ = 2 if self.is_haploid else 3
-        self.logger.info(
-            f"Data is {'haploid' if self.is_haploid else 'diploid'}; "
-            f"using {self.num_classes_} classes."
-        )
+            self.logger.error(msg)
+            raise ValueError(msg)

-        if self.
-            self.ground_truth_[self.ground_truth_ == 2] = 1
-            X_for_model[X_for_model == 2] = 1
+        self.num_classes_ = 2 if self.is_haploid_ else 3

-
+        # Clean 0/1/2 ground truth (missing=-1)
+        gt_full = self.pgenc.genotypes_012.copy()
+        gt_full[gt_full < 0] = -1
+        gt_full = np.nan_to_num(gt_full, nan=-1.0)
+        self.ground_truth_ = gt_full.astype(np.int8)
+        self.num_features_ = int(self.ground_truth_.shape[1])

-        # Model params (decoder outputs L * K logits)
         self.model_params = {
             "n_features": self.num_features_,
             "num_classes": self.num_classes_,
@@ -351,66 +350,194 @@ class ImputeAutoencoder(BaseNNImputer):
             "activation": self.activation,
         }

-        #
-
-
-            indices, test_size=self.validation_split, random_state=self.seed
+        # Simulate missingness ONCE on the full matrix
+        X_for_model_full, self.sim_mask_, self.orig_mask_ = self.sim_missing_transform(
+            self.ground_truth_
         )
-        self.train_idx_, self.test_idx_ = train_idx, val_idx
-        self.X_train_ = X_for_model[train_idx]
-        self.X_val_ = X_for_model[val_idx]
-        self.GT_train_full_ = self.ground_truth_[train_idx]
-        self.GT_test_full_ = self.ground_truth_[val_idx]
-
-        if self.sim_mask_global_ is not None:
-            self.sim_mask_train_ = self.sim_mask_global_[train_idx]
-            self.sim_mask_test_ = self.sim_mask_global_[val_idx]
-        else:
-            self.sim_mask_train_ = None
-            self.sim_mask_test_ = None

-        #
+        # Split indices based on clean ground truth
+        self.train_idx_, self.val_idx_, self.test_idx_ = self._train_val_test_split(
+            self.ground_truth_
+        )
+
+        # --- Clean targets per split ---
+        X_train_clean = self.ground_truth_[self.train_idx_].copy()
+        X_val_clean = self.ground_truth_[self.val_idx_].copy()
+        X_test_clean = self.ground_truth_[self.test_idx_].copy()
+
+        # --- Corrupted inputs per split (from the single simulation) ---
+        X_train_corrupted = X_for_model_full[self.train_idx_].copy()
+        X_val_corrupted = X_for_model_full[self.val_idx_].copy()
+        X_test_corrupted = X_for_model_full[self.test_idx_].copy()
+
+        # --- Masks per split ---
+        self.sim_mask_train_ = self.sim_mask_[self.train_idx_].copy()
+        self.sim_mask_val_ = self.sim_mask_[self.val_idx_].copy()
+        self.sim_mask_test_ = self.sim_mask_[self.test_idx_].copy()
+
+        self.orig_mask_train_ = self.orig_mask_[self.train_idx_].copy()
+        self.orig_mask_val_ = self.orig_mask_[self.val_idx_].copy()
+        self.orig_mask_test_ = self.orig_mask_[self.test_idx_].copy()
+
+        # Persist per-split matrices
+        self.X_train_clean_ = X_train_clean
+        self.X_val_clean_ = X_val_clean
+        self.X_test_clean_ = X_test_clean
+
+        self.X_train_corrupted_ = X_train_corrupted
+        self.X_val_corrupted_ = X_val_corrupted
+        self.X_test_corrupted_ = X_test_corrupted
+
+        # Haploid harmonization (do NOT resimulate; just recode values)
+        if self.is_haploid_:
+
+            def _haploidize(arr: np.ndarray) -> np.ndarray:
+                out = arr.copy()
+                miss = out < 0
+                out = np.where(out > 0, 1, out).astype(np.int8, copy=False)
+                out[miss] = -1
+                return out
+
+            self.X_train_clean_ = _haploidize(self.X_train_clean_)
+            self.X_val_clean_ = _haploidize(self.X_val_clean_)
+            self.X_test_clean_ = _haploidize(self.X_test_clean_)
+
+            self.X_train_corrupted_ = _haploidize(self.X_train_corrupted_)
+            self.X_val_corrupted_ = _haploidize(self.X_val_corrupted_)
+            self.X_test_corrupted_ = _haploidize(self.X_test_corrupted_)
+
+        # Convention: X_* are corrupted inputs; y_* are clean targets
+        self.X_train_ = self.X_train_corrupted_
+        self.y_train_ = self.X_train_clean_
+
+        self.X_val_ = self.X_val_corrupted_
+        self.y_val_ = self.X_val_clean_
+
+        self.X_test_ = self.X_test_corrupted_
+        self.y_test_ = self.X_test_clean_
+
+        # One-hot for loaders/model input
+        X_train_ohe = self._one_hot_encode_012(
+            self.X_train_, num_classes=self.num_classes_
+        )
+        X_val_ohe = self._one_hot_encode_012(self.X_val_, num_classes=self.num_classes_)
+
+        # Plotters/scorers + valid-class mask repairs (copied from VAE flow)
         self.plotter_, self.scorers_ = self.initialize_plotting_and_scorers()
+        self.valid_class_mask_ = self._build_valid_class_mask()
+
+        loci = getattr(self, "valid_class_mask_conflict_loci_", None)
+        if loci is not None and loci.size:
+            self._repair_ref_alt_from_iupac(loci)
+            self.valid_class_mask_ = self._build_valid_class_mask()
+
+        train_loader = self._get_data_loaders(
+            X_train_ohe.detach().cpu().numpy(),
+            self.y_train_,
+            ~self.orig_mask_train_,
+            self.batch_size,
+            shuffle=True,
+        )
+        val_loader = self._get_data_loaders(
+            X_val_ohe.detach().cpu().numpy(),
+            self.y_val_,
+            ~self.orig_mask_val_,
+            self.batch_size,
+            shuffle=False,
+        )
+
+        self.train_loader_ = train_loader
+        self.val_loader_ = val_loader

-        #
+        # Hyperparameter tuning or fixed run
         if self.tune:
-            self.tune_hyperparameters()
+            self.tuned_params_ = self.tune_hyperparameters()
+            self.model_tuned_ = True
+        else:
+            self.model_tuned_ = False
+            self.class_weights_ = self._class_weights_from_zygosity(
+                self.y_train_,
+                train_mask=self.sim_mask_train_ & ~self.orig_mask_train_,
+                inverse=self.inverse,
+                normalize=self.normalize,
+                max_ratio=self.max_ratio,
+                power=self.power,
+            )
+            self.tuned_params_ = {
+                "latent_dim": self.latent_dim,
+                "learning_rate": self.learning_rate,
+                "dropout_rate": self.dropout_rate,
+                "num_hidden_layers": self.num_hidden_layers,
+                "activation": self.activation,
+                "l1_penalty": self.l1_penalty,
+                "layer_scaling_factor": self.layer_scaling_factor,
+                "layer_schedule": self.layer_schedule,
+                "gamma": self.gamma,
+                "gamma_schedule": self.gamma_schedule,
+                "inverse": self.inverse,
+                "normalize": self.normalize,
+                "power": self.power,
+            }
+            self.tuned_params_["model_params"] = self.model_params

-
-
+        if self.class_weights_ is not None:
+            self.logger.info(
+                f"class_weights={self.class_weights_.detach().cpu().numpy().tolist()}"
+            )

-        #
-        self.
-            self._class_weights_from_zygosity(self.X_train_)
-        )
+        # Always start clean
+        self.best_params_ = copy.deepcopy(self.tuned_params_)

-        #
-
+        # Final model params (compute hidden sizes using n_inputs=L*K, mirroring VAE)
+        input_dim = int(self.num_features_ * self.num_classes_)
+        model_params_final = {
+            "n_features": int(self.num_features_),
+            "num_classes": int(self.num_classes_),
+            "latent_dim": int(self.best_params_["latent_dim"]),
+            "dropout_rate": float(self.best_params_["dropout_rate"]),
+            "activation": str(self.best_params_["activation"]),
+        }
+        model_params_final["hidden_layer_sizes"] = self._compute_hidden_layer_sizes(
+            n_inputs=input_dim,
+            n_outputs=int(self.num_classes_),
+            n_samples=len(self.train_idx_),
+            n_hidden=int(self.best_params_["num_hidden_layers"]),
+            latent_dim=int(self.best_params_["latent_dim"]),
+            alpha=float(self.best_params_["layer_scaling_factor"]),
+            schedule=str(self.best_params_["layer_schedule"]),
+            min_size=max(16, 2 * int(self.best_params_["latent_dim"])),
+        )
+        self.best_params_["model_params"] = model_params_final

-        # Build
-        model = self.build_model(self.Model, self.best_params_)
+        # Build and train
+        model = self.build_model(self.Model, self.best_params_["model_params"])
         model.apply(self.initialize_weights)

+        if self.verbose or self.debug:
+            self.logger.info("Using model hyperparameters:")
+            pm = PrettyMetrics(
+                self.best_params_, precision=3, title="Model Hyperparameters"
+            )
+            pm.render()
+
+        lr_final = float(self.best_params_["learning_rate"])
+        l1_final = float(self.best_params_["l1_penalty"])
+        gamma_schedule = bool(
+            self.best_params_.get("gamma_schedule", self.gamma_schedule)
+        )
+
         loss, trained_model, history = self._train_and_validate_model(
             model=model,
-
-
-            l1_penalty=self.l1_penalty,
-            return_history=True,
-            class_weights=self.class_weights_,
-            X_val=self.X_val_,
+            lr=lr_final,
+            l1_penalty=l1_final,
             params=self.best_params_,
-
-
-
-            eval_requires_latents=False,
-            eval_latent_steps=0,
-            eval_latent_lr=0.0,
-            eval_latent_weight_decay=0.0,
+            trial=None,
+            class_weights=getattr(self, "class_weights_", None),
+            gamma_schedule=gamma_schedule,
         )

         if trained_model is None:
-            msg = "
+            msg = f"{self.model_name} training failed."
             self.logger.error(msg)
             raise RuntimeError(msg)

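The hunk above simulates missingness once on the full matrix and then slices the resulting masks per split. A small NumPy sketch of that bookkeeping, with plain random masking standing in for pg-sui's simulation transform (shapes, proportions, and split sizes are illustrative only):

```python
import numpy as np

rng = np.random.default_rng(0)
gt = rng.integers(0, 3, size=(10, 6)).astype(np.int8)   # clean 0/1/2 ground truth
orig_mask = rng.random(gt.shape) < 0.05                  # originally missing sites
gt[orig_mask] = -1

# Simulate extra missingness ONCE, only over known sites.
sim_mask = (rng.random(gt.shape) < 0.2) & ~orig_mask
corrupted = gt.copy()
corrupted[sim_mask] = -1

# Splits reuse the same masks; nothing is re-simulated per split.
train_idx, val_idx = np.arange(0, 7), np.arange(7, 10)
X_train, y_train = corrupted[train_idx], gt[train_idx]     # corrupted input, clean target
loss_mask_train = ~orig_mask[train_idx]                    # loss: only where targets are known
eval_mask_val = sim_mask[val_idx] & ~orig_mask[val_idx]    # metrics: simulated-missing sites only
```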
@@ -419,217 +546,194 @@ class ImputeAutoencoder(BaseNNImputer):
             self.models_dir / f"final_model_{self.model_name}.pt",
         )

-
-            "Train":
-
-
+        if history is None:
+            hist = {"Train": []}
+        elif isinstance(history, dict):
+            hist = dict(history)
+        else:
+            hist = {"Train": list(history["Train"]), "Val": list(history["Val"])}
+
+        self.best_loss_ = float(loss)
+        self.model_ = trained_model
+        self.history_ = hist
         self.is_fit_ = True

-        # Evaluate on
-        eval_mask = (
-            self.sim_mask_test_
-            if (self.simulate_missing and self.sim_mask_test_ is not None)
-            else None
-        )
+        # Evaluate on simulated-missing sites only
         self._evaluate_model(
-            self.
+            self.model_,
+            X=self.X_test_,
+            y=self.y_test_,
+            eval_mask=self.sim_mask_test_ & ~self.orig_mask_test_,
+            objective_mode=False,
         )
-
+
+        if self.show_plots:
+            self.plotter_.plot_history(self.history_)
+
         self._save_best_params(self.best_params_)

+        if self.model_tuned_:
+            title = f"{self.model_name} Optimized Parameters"
+
+            if self.verbose or self.debug:
+                pm = PrettyMetrics(self.best_params_, precision=2, title=title)
+                pm.render()
+
+            # Save best parameters to a JSON file.
+            self._save_best_params(self.best_params_, objective_mode=True)
+
         return self

     def transform(self) -> np.ndarray:
-        """Impute missing genotypes
+        """Impute missing genotypes and return IUPAC strings.

-        This method
+        This method performs the following steps:
+        1. Validates that the model has been fitted.
+        2. Uses the trained model to predict missing genotypes for the entire dataset.
+        3. Fills in the missing genotypes in the original dataset with the predicted values from the model.
+        4. Decodes the imputed genotype matrix from 0/1/2 encoding to IUPAC strings.
+        5. Checks for any remaining missing values or decoding issues, raising errors if found.
+        6. Optionally generates and displays plots comparing the original and imputed genotype distributions.
+        7. Returns the imputed IUPAC genotype matrix.

         Returns:
-            np.ndarray: IUPAC
+            np.ndarray: IUPAC genotype matrix of shape (n_samples, n_loci).

         Raises:
             NotFittedError: If called before fit().
+            RuntimeError: If any missing values remain or decoding yields "N".
+            RuntimeError: If loci contain 'N' after imputation due to missing REF/ALT metadata.
         """
         if not getattr(self, "is_fit_", False):
-
+            msg = "Model is not fitted. Call fit() before transform()."
+            self.logger.error(msg)
+            raise NotFittedError(msg)

-        self.logger.info(f"Imputing entire dataset with {self.model_name}...")
+        self.logger.info(f"Imputing entire dataset with {self.model_name} model...")
         X_to_impute = self.ground_truth_.copy()

-
-        pred_labels, _ = self._predict(self.model_, X=X_to_impute, return_proba=True)
+        pred_labels, _ = self._predict(self.model_, X=X_to_impute)

-
-        missing_mask = X_to_impute == -1
+        missing_mask = X_to_impute < 0
         imputed_array = X_to_impute.copy()
         imputed_array[missing_mask] = pred_labels[missing_mask]

-
-
+        if np.any(imputed_array < 0):
+            msg = f"[{self.model_name}] Some missing genotypes remain after imputation. This is unexpected."
+            self.logger.error(msg)
+            raise RuntimeError(msg)
+
+        decode_input = imputed_array
+        if self.is_haploid_:
+            decode_input = imputed_array.copy()
+            decode_input[decode_input == 1] = 2
+
+        imputed_genotypes = self.decode_012(decode_input)
+
+        bad_loci = np.where((imputed_genotypes == "N").any(axis=0))[0]
+        if bad_loci.size > 0:
+            msg = f"[{self.model_name}] {bad_loci.size} loci contain 'N' after imputation (e.g., first 10 indices: {bad_loci[:10].tolist()}). This occurs when REF/ALT metadata is missing and cannot be inferred from the source data (e.g., loci with 100 percent missing genotypes). Try filtering out these loci before imputation."
+            self.logger.error(msg)
+            self.logger.debug(
+                "All loci with 'N': " + ", ".join(map(str, bad_loci.tolist()))
+            )
+            raise RuntimeError(msg)
+
         if self.show_plots:
-
+            original_input = X_to_impute
+            if self.is_haploid_:
+                original_input = X_to_impute.copy()
+                original_input[original_input == 1] = 2
+
+            original_genotypes = self.decode_012(original_input)
+
             plt.rcParams.update(self.plotter_.param_dict)
             self.plotter_.plot_gt_distribution(original_genotypes, is_imputed=False)
             self.plotter_.plot_gt_distribution(imputed_genotypes, is_imputed=True)

         return imputed_genotypes

-    def _get_data_loaders(self, y: np.ndarray) -> torch.utils.data.DataLoader:
-        """Create DataLoader over indices + integer targets (-1 for missing).
-
-        This method creates a PyTorch DataLoader that yields batches of indices and their corresponding genotype targets encoded as integers (0, 1, 2) with -1 indicating missing values. The DataLoader is shuffled to ensure random sampling during training.
-
-        Args:
-            y (np.ndarray): 0/1/2 matrix with -1 for missing.
-
-        Returns:
-            torch.utils.data.DataLoader: Shuffled DataLoader.
-        """
-        y_tensor = torch.from_numpy(y).long()
-        indices = torch.arange(len(y), dtype=torch.long)
-        dataset = torch.utils.data.TensorDataset(indices, y_tensor)
-        pin_memory = self.device.type == "cuda"
-        return torch.utils.data.DataLoader(
-            dataset,
-            batch_size=self.batch_size,
-            shuffle=True,
-            pin_memory=pin_memory,
-        )
-
     def _train_and_validate_model(
         self,
         model: torch.nn.Module,
-
+        *,
         lr: float,
         l1_penalty: float,
-        trial: optuna.Trial
-
-        class_weights: torch.Tensor
-
-
-
-
-
-        eval_interval: int = 1,
-        # Evaluation parameters (AE ignores latent refinement knobs)
-        eval_requires_latents: bool = False, # AE: always False
-        eval_latent_steps: int = 0,
-        eval_latent_lr: float = 0.0,
-        eval_latent_weight_decay: float = 0.0,
-    ) -> Tuple[float, torch.nn.Module | None, list | None]:
-        """Wrap the AE training loop (no latent optimizer), with Optuna pruning.
-
-        This method orchestrates the training of the autoencoder model using the provided DataLoader. It sets up the optimizer and learning rate scheduler, and executes the training loop with support for early stopping and Optuna pruning based on validation performance. The method returns the best validation loss, the best model state, and optionally the training history.
+        trial: Optional[optuna.Trial] = None,
+        params: Optional[dict[str, Any]] = None,
+        class_weights: Optional[torch.Tensor] = None,
+        gamma_schedule: bool = False,
+    ) -> tuple[float, torch.nn.Module, dict[str, list[float]]]:
+        """Train and validate the model.
+
+        This method sets up the optimizer and learning rate scheduler, then executes the training loop with early stopping and optional hyperparameter tuning via Optuna. It returns the best validation loss, the best model, and the training history.

         Args:
             model (torch.nn.Module): Autoencoder model.
-            loader (torch.utils.data.DataLoader): Batches (indices, y_int) where y_int is 0/1/2; -1 for missing.
             lr (float): Learning rate.
-            l1_penalty (float): L1 regularization
-            trial (optuna.Trial
-
-            class_weights (torch.Tensor
-
-            params (dict | None): Model params for evaluation.
-            prune_metric (str): Metric for pruning reports.
-            prune_warmup_epochs (int): Pruning warmup epochs.
-            eval_interval (int): Eval frequency (epochs).
-            eval_requires_latents (bool): Ignored for AE (no latent inference).
-            eval_latent_steps (int): Unused for AE.
-            eval_latent_lr (float): Unused for AE.
-            eval_latent_weight_decay (float): Unused for AE.
+            l1_penalty (float): L1 regularization coefficient.
+            trial (Optional[optuna.Trial]): Optuna trial (optional).
+            params (Optional[dict[str, Any]]): Hyperparams dict (optional).
+            class_weights (Optional[torch.Tensor]): Class weights for focal CE (optional).
+            gamma_schedule (bool): Whether to schedule gamma.

         Returns:
-
+            tuple[float, torch.nn.Module, dict[str, list[float]]]: Best validation loss, best model, history.
         """
-
-
-            self.logger.error(msg)
-            raise TypeError(msg)
+        max_epochs = int(self.epochs)
+        optimizer = torch.optim.Adam(model.parameters(), lr=lr)

-
-
-            self.tune_epochs if (trial is not None and self.tune_fast) else self.epochs
+        scheduler = _make_warmup_cosine_scheduler(
+            optimizer, max_epochs=max_epochs, warmup_epochs=int(0.1 * max_epochs)
         )

-        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
-        scheduler = CosineAnnealingLR(optimizer, T_max=max_epochs)
-
         best_loss, best_model, hist = self._execute_training_loop(
-            loader=loader,
            optimizer=optimizer,
            scheduler=scheduler,
            model=model,
            l1_penalty=l1_penalty,
            trial=trial,
-            return_history=return_history,
-            class_weights=class_weights,
-            X_val=X_val,
            params=params,
-
-
-            eval_interval=eval_interval,
-            eval_requires_latents=False, # AE: no latent inference
-            eval_latent_steps=0,
-            eval_latent_lr=0.0,
-            eval_latent_weight_decay=0.0,
+            class_weights=class_weights,
+            gamma_schedule=gamma_schedule,
         )
-
-        return best_loss, best_model, hist
-
-        return best_loss, best_model, None
+        return best_loss, best_model, hist

     def _execute_training_loop(
         self,
-
+        *,
         optimizer: torch.optim.Optimizer,
-        scheduler:
+        scheduler: (
+            torch.optim.lr_scheduler.CosineAnnealingLR
+            | torch.optim.lr_scheduler.SequentialLR
+        ),
         model: torch.nn.Module,
         l1_penalty: float,
-        trial: optuna.Trial
-
-        class_weights: torch.Tensor,
-
-
-
-        prune_metric: str = "f1",
-        prune_warmup_epochs: int = 3,
-        eval_interval: int = 1,
-        # Evaluation parameters (AE ignores latent refinement knobs)
-        eval_requires_latents: bool = False, # AE: False
-        eval_latent_steps: int = 0,
-        eval_latent_lr: float = 0.0,
-        eval_latent_weight_decay: float = 0.0,
-    ) -> Tuple[float, torch.nn.Module, list]:
-        """Train AE with focal CE (gamma warm/ramp) + early stopping & pruning.
-
-        This method executes the training loop for the autoencoder model, performing one epoch at a time. It computes the focal cross-entropy loss while ignoring masked (missing) values and applies L1 regularization if specified. The method incorporates early stopping based on validation performance and supports Optuna pruning to terminate unpromising trials early. It returns the best validation loss, the best model state, and optionally the training history.
+        trial: Optional[optuna.Trial] = None,
+        params: Optional[dict[str, Any]] = None,
+        class_weights: Optional[torch.Tensor] = None,
+        gamma_schedule: bool = False,
+    ) -> tuple[float, torch.nn.Module, dict[str, list[float]]]:
+        """Train AE (masked focal CE) with EarlyStopping + Optuna pruning.

         Args:
-
-
-            scheduler (torch.optim.lr_scheduler._LRScheduler): LR scheduler.
+            optimizer (torch.optim.Optimizer): Optimizer for training.
+            scheduler (torch.optim.lr_scheduler.CosineAnnealingLR | torch.optim.lr_scheduler.SequentialLR): LR scheduler.
             model (torch.nn.Module): Autoencoder model.
-            l1_penalty (float): L1 regularization
-            trial (optuna.Trial
-
-            class_weights (torch.Tensor): Class weights
-
-            params (dict | None): Model params for evaluation.
-            prune_metric (str): Metric for pruning reports.
-            prune_warmup_epochs (int): Pruning warmup epochs.
-            eval_interval (int): Eval frequency (epochs).
-            eval_requires_latents (bool): Ignored for AE (no latent inference).
-            eval_latent_steps (int): Unused for AE.
-            eval_latent_lr (float): Unused for AE.
-            eval_latent_weight_decay (float): Unused for AE.
+            l1_penalty (float): L1 regularization coefficient.
+            trial (Optional[optuna.Trial]): Optuna trial (optional).
+            params (Optional[dict[str, Any]]): Hyperparams dict (optional).
+            class_weights (Optional[torch.Tensor]): Class weights for focal CE (optional).
+            gamma_schedule (bool): Whether to schedule gamma.

         Returns:
-
+            tuple[float, torch.nn.Module, dict[str, list[float]]]: Best loss, best model, and training history.
+
+        Notes:
+            - Computes loss only where targets are known (~orig_mask_*).
+            - Evaluates metrics only on simulated-missing sites (sim_mask_*).
         """
-
-        best_model = None
-        history: list[float] = []
+        history: dict[str, list[float]] = defaultdict(list)

         early_stopping = EarlyStopping(
             patience=self.early_stop_gen,
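The reworked `transform()` above only writes predictions into cells that were missing in the source matrix. A minimal NumPy illustration of that fill-in-place pattern (the "predictions" here are random placeholders, not model output):

```python
import numpy as np

rng = np.random.default_rng(1)
X = rng.integers(0, 3, size=(4, 5)).astype(np.int8)
X[rng.random(X.shape) < 0.3] = -1          # -1 marks missing genotypes

pred = rng.integers(0, 3, size=X.shape)    # stand-in for model predictions
missing = X < 0
imputed = X.copy()
imputed[missing] = pred[missing]           # observed calls are left untouched

assert not np.any(imputed < 0)             # nothing should remain missing
```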
@@ -639,146 +743,157 @@ class ImputeAutoencoder(BaseNNImputer):
             debug=self.debug,
         )

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                model.gamma = 0.0  # type: ignore
-            elif epoch < gamma_warm + gamma_ramp:
-                model.gamma = gamma_final * ((epoch - gamma_warm) / gamma_ramp)  # type: ignore
+        gamma_target, gamma_warm, gamma_ramp = self._anneal_config(
+            params, "gamma", default=self.gamma, max_epochs=self.epochs
+        )
+        gamma_target = float(gamma_target)
+
+        cw = class_weights
+        if cw is not None and cw.device != self.device:
+            cw = cw.to(self.device)
+
+        for epoch in range(int(self.epochs)):
+            if gamma_schedule:
+                gamma_current = self._update_anneal_schedule(
+                    gamma_target,
+                    warm=gamma_warm,
+                    ramp=gamma_ramp,
+                    epoch=epoch,
+                    init_val=0.0,
+                )
+                gamma_val = float(gamma_current)
             else:
-
+                gamma_val = gamma_target

-
-
-
-            for g in optimizer.param_groups:
-                g["lr"] = min_lr + (base_lr - min_lr) * scale
+            ce_criterion = FocalCELoss(
+                alpha=cw, gamma=gamma_val, ignore_index=-1, reduction="mean"
+            )

             train_loss = self._train_step(
-                loader=
+                loader=self.train_loader_,
                 optimizer=optimizer,
                 model=model,
+                ce_criterion=ce_criterion,
                 l1_penalty=l1_penalty,
-                class_weights=class_weights,
             )

-            # Abort or prune on non-finite epoch loss
             if not np.isfinite(train_loss):
                 if trial is not None:
-
-
-
-
-                    )
-
-
-
+                    msg = f"[{self.model_name}] Trial {trial.number} training loss non-finite."
+                    self.logger.warning(msg)
+                    raise optuna.exceptions.TrialPruned(msg)
+                msg = f"[{self.model_name}] Training loss is non-finite at epoch {epoch + 1}."
+                self.logger.error(msg)
+                raise RuntimeError(msg)
+
+            val_loss = self._val_step(
+                loader=self.val_loader_,
+                model=model,
+                ce_criterion=ce_criterion,
+                l1_penalty=l1_penalty,
+            )

             scheduler.step()
-
-
+            history["Train"].append(float(train_loss))
+            history["Val"].append(float(val_loss))

-            early_stopping(
+            early_stopping(val_loss, model)
             if early_stopping.early_stop:
-                self.logger.
+                self.logger.debug(
+                    f"[{self.model_name}] Early stopping at epoch {epoch + 1}."
+                )
                 break

-
-
-                trial is not None
-                and X_val is not None
-                and ((epoch + 1) % eval_interval == 0)
-            ):
-                metric_key = prune_metric or getattr(self, "tune_metric", "f1")
-                mask_override = None
-                if (
-                    self.simulate_missing
-                    and getattr(self, "sim_mask_test_", None) is not None
-                    and getattr(self, "X_val_", None) is not None
-                    and X_val.shape == self.X_val_.shape
-                ):
-                    mask_override = self.sim_mask_test_
-                metric_val = self._eval_for_pruning(
+            if trial is not None:
+                metric_vals = self._evaluate_model(
                     model=model,
-
-
-
+                    X=self.X_val_,
+                    y=self.y_val_,
+                    eval_mask=self.sim_mask_val_ & ~self.orig_mask_val_,
                     objective_mode=True,
-                    do_latent_infer=False, # AE: False
-                    latent_steps=0,
-                    latent_lr=0.0,
-                    latent_weight_decay=0.0,
-                    latent_seed=self.seed, # type: ignore
-                    _latent_cache=None, # AE: not used
-                    _latent_cache_key=None,
-                    eval_mask_override=mask_override,
                 )
-                trial.report(
-                if
+                trial.report(metric_vals[self.tune_metric], step=epoch + 1)
+                if trial.should_prune():
                     raise optuna.exceptions.TrialPruned(
-                        f"
+                        f"[{self.model_name}] Trial {trial.number} pruned at epoch {epoch + 1}."
                     )

-        best_loss = early_stopping.best_score
-        if early_stopping.
-
-
-
-        return best_loss, best_model, history
+        best_loss = float(early_stopping.best_score)
+        if early_stopping.best_state_dict is not None:
+            model.load_state_dict(early_stopping.best_state_dict)
+
+        return best_loss, model, dict(history)

     def _train_step(
         self,
         loader: torch.utils.data.DataLoader,
         optimizer: torch.optim.Optimizer,
         model: torch.nn.Module,
+        ce_criterion: torch.nn.Module,
+        *,
         l1_penalty: float,
-        class_weights: torch.Tensor,
     ) -> float:
-        """
+        """Single epoch train step (masked focal CE + optional L1).
+
+        Args:
+            loader (torch.utils.data.DataLoader): Training data loader.
+            optimizer (torch.optim.Optimizer): Optimizer for training.
+            model (torch.nn.Module): Autoencoder model.
+            ce_criterion (torch.nn.Module): Cross-entropy loss function.
+            l1_penalty (float): L1 regularization coefficient.
+
+        Returns:
+            float: Average training loss over the epoch.
+
+        Notes:
+            Expects loader batches as (X_ohe, y_int, mask_bool) where:
+            - X_ohe: (B, L, C) float/compatible
+            - y_int: (B, L) int, with -1 for unknown targets
+            - mask_bool: (B, L) bool selecting which positions contribute to loss
+        """
         model.train()
         running = 0.0
         num_batches = 0
-        l1_params = tuple(p for p in model.parameters() if p.requires_grad)
-        if class_weights is not None and class_weights.device != self.device:
-            class_weights = class_weights.to(self.device)

-
-
-
-        criterion = SafeFocalCELoss(gamma=gamma, weight=class_weights, ignore_index=-1)
+        nF_model = int(getattr(model, "n_features", self.num_features_))
+        nC_model = int(getattr(model, "num_classes", self.num_classes_))
+        l1_params = tuple(p for p in model.parameters() if p.requires_grad)

-        for
+        for X_batch, y_batch, m_batch in loader:
             optimizer.zero_grad(set_to_none=True)
-
+            X_batch = X_batch.to(self.device, non_blocking=True).float()
+            y_batch = y_batch.to(self.device, non_blocking=True).long()
+            m_batch = m_batch.to(self.device, non_blocking=True).bool()

-
-
-
-
-
+            if (
+                X_batch.dim() != 3
+                or X_batch.shape[1] != nF_model
+                or X_batch.shape[2] != nC_model
+            ):
+                msg = (
+                    f"Train batch X shape mismatch: expected (B,{nF_model},{nC_model}), "
+                    f"got {tuple(X_batch.shape)}."
+                )
+                self.logger.error(msg)
+                raise ValueError(msg)

-
-
-
-
+            logits_flat = model(X_batch)
+            expected = (X_batch.shape[0], nF_model * nC_model)
+            if logits_flat.dim() != 2 or tuple(logits_flat.shape) != expected:
+                try:
+                    logits_flat = logits_flat.view(*expected)
+                except Exception as e:
+                    msg = f"Model logits expected shape {expected}, got {tuple(logits_flat.shape)}."
+                    self.logger.error(msg)
+                    raise ValueError(msg) from e

-
+            logits = logits_flat.view(-1, nF_model, nC_model)
+            logits_masked = logits.view(-1, nC_model)[m_batch.view(-1)]
+
+            targets_masked = y_batch.view(-1)
+            targets_masked = targets_masked[m_batch.view(-1)]
+
+            loss = ce_criterion(logits_masked, targets_masked)

             if l1_penalty > 0:
                 l1 = torch.zeros((), device=self.device)
@@ -786,194 +901,234 @@ class ImputeAutoencoder(BaseNNImputer):
|
|
|
786
901
|
l1 = l1 + p.abs().sum()
|
|
787
902
|
loss = loss + l1_penalty * l1
|
|
788
903
|
|
|
789
|
-
# Final guard
|
|
790
904
|
if not torch.isfinite(loss):
|
|
791
905
|
continue
|
|
792
906
|
|
|
793
907
|
loss.backward()
|
|
794
|
-
|
|
795
|
-
# Clip to prevent exploding grads
|
|
796
908
|
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
|
|
797
|
-
|
|
798
|
-
# If grads blew up to non-finite, skip update
|
|
799
|
-
if any(
|
|
800
|
-
(not torch.isfinite(p.grad).all())
|
|
801
|
-
for p in model.parameters()
|
|
802
|
-
if p.grad is not None
|
|
803
|
-
):
|
|
804
|
-
optimizer.zero_grad(set_to_none=True)
|
|
805
|
-
continue
|
|
806
|
-
|
|
807
909
|
optimizer.step()
|
|
808
910
|
|
|
809
911
|
running += float(loss.detach().item())
|
|
810
912
|
num_batches += 1
|
|
811
913
|
|
|
812
|
-
if num_batches == 0
|
|
813
|
-
|
|
814
|
-
|
|
914
|
+
return float("inf") if num_batches == 0 else running / num_batches
|
|
915
|
+
|
|
916
|
+
def _val_step(
|
|
917
|
+
self,
|
|
918
|
+
loader: torch.utils.data.DataLoader,
|
|
919
|
+
model: torch.nn.Module,
|
|
920
|
+
ce_criterion: torch.nn.Module,
|
|
921
|
+
*,
|
|
922
|
+
l1_penalty: float,
|
|
923
|
+
) -> float:
|
|
924
|
+
"""Validation step (masked focal CE + optional L1).
|
|
925
|
+
|
|
926
|
+
Args:
|
|
927
|
+
loader (torch.utils.data.DataLoader): Validation data loader.
|
|
928
|
+
model (torch.nn.Module): Autoencoder model.
|
|
929
|
+
ce_criterion (torch.nn.Module): Cross-entropy loss function.
|
|
930
|
+
l1_penalty (float): L1 regularization coefficient.
|
|
931
|
+
|
|
932
|
+
Returns:
|
|
933
|
+
float: Average validation loss over the epoch.
|
|
934
|
+
"""
|
|
935
|
+
model.eval()
|
|
936
|
+
running = 0.0
|
|
937
|
+
num_batches = 0
|
|
938
|
+
|
|
939
|
+
nF_model = self.num_features_
|
|
940
|
+
nC_model = self.num_classes_
|
|
941
|
+
l1_params = tuple(p for p in model.parameters() if p.requires_grad)
|
|
942
|
+
|
|
943
|
+
with torch.no_grad():
|
|
944
|
+
for X_batch, y_batch, m_batch in loader:
|
|
945
|
+
X_batch = X_batch.to(self.device, non_blocking=True).float()
|
|
946
|
+
y_batch = y_batch.to(self.device, non_blocking=True).long()
|
|
947
|
+
m_batch = m_batch.to(self.device, non_blocking=True).bool()
|
|
948
|
+
|
|
949
|
+
logits_flat = model(X_batch)
|
|
950
|
+
expected = (X_batch.shape[0], nF_model * nC_model)
|
|
951
|
+
|
|
952
|
+
if logits_flat.dim() != 2 or tuple(logits_flat.shape) != expected:
|
|
953
|
+
try:
|
|
954
|
+
logits_flat = logits_flat.view(*expected)
|
|
955
|
+
except Exception as e:
|
|
956
|
+
msg = f"Model logits expected shape {expected}, got {tuple(logits_flat.shape)}."
|
|
957
|
+
self.logger.error(msg)
|
|
958
|
+
raise ValueError(msg) from e
|
|
959
|
+
|
|
960
|
+
logits = logits_flat.view(-1, nF_model, nC_model)
|
|
961
|
+
logits_masked = logits.view(-1, nC_model)[m_batch.view(-1)]
|
|
962
|
+
targets_masked = y_batch.view(-1)[m_batch.view(-1)]
|
|
963
|
+
|
|
964
|
+
if targets_masked.numel() == 0:
|
|
965
|
+
continue
|
|
966
|
+
|
|
967
|
+
loss = ce_criterion(logits_masked, targets_masked)
|
|
968
|
+
|
|
969
|
+
if l1_penalty > 0:
|
|
970
|
+
l1 = torch.zeros((), device=self.device)
|
|
971
|
+
for p in l1_params:
|
|
972
|
+
l1 = l1 + p.abs().sum()
|
|
973
|
+
loss = loss + l1_penalty * l1
|
|
974
|
+
|
|
975
|
+
if not torch.isfinite(loss):
|
|
976
|
+
continue
|
|
977
|
+
|
|
978
|
+
running += float(loss.item())
|
|
979
|
+
num_batches += 1
|
|
980
|
+
|
|
981
|
+
return float("inf") if num_batches == 0 else running / num_batches
|
|
815
982
|
|
|
816
983
|
def _predict(
|
|
817
984
|
self,
|
|
818
985
|
model: torch.nn.Module,
|
|
819
986
|
X: np.ndarray | torch.Tensor,
|
|
987
|
+
*,
|
|
820
988
|
return_proba: bool = False,
|
|
821
|
-
) ->
|
|
822
|
-
"""Predict
|
|
823
|
-
|
|
824
|
-
This method generates predictions from the trained autoencoder model for the provided input data. It processes the input data, performs a forward pass through the model, and computes the predicted genotype labels (0, 1, or 2) along with their associated probabilities if requested.
|
|
989
|
+
) -> tuple[np.ndarray, np.ndarray | None]:
|
|
990
|
+
"""Predict categorical genotype labels from logits.
|
|
825
991
|
|
|
826
992
|
Args:
|
|
827
993
|
model (torch.nn.Module): Trained model.
|
|
828
|
-
X (np.ndarray | torch.Tensor): 0/1/2 matrix with -1
|
|
829
|
-
|
|
830
|
-
return_proba (bool): If True, return probabilities.
|
|
994
|
+
X (np.ndarray | torch.Tensor): 2D 0/1/2 matrix with -1 for missing, or 3D one-hot (B, L, K).
|
|
995
|
+
return_proba (bool): If True, return probabilities (B, L, K).
|
|
831
996
|
|
|
832
997
|
Returns:
|
|
833
|
-
|
|
834
|
-
and probabilities if requested.
|
|
998
|
+
tuple[np.ndarray, np.ndarray | None]: Predicted labels and optionally probabilities.
|
|
835
999
|
"""
|
|
836
1000
|
if model is None:
|
|
837
1001
|
msg = "Model is not trained. Call fit() before predict()."
|
|
838
1002
|
self.logger.error(msg)
|
|
839
1003
|
raise NotFittedError(msg)
|
|
840
1004
|
|
|
1005
|
+
nF = self.num_features_
|
|
1006
|
+
nC = self.num_classes_
|
|
1007
|
+
|
|
1008
|
+
if isinstance(X, torch.Tensor):
|
|
1009
|
+
X_tensor = X
|
|
1010
|
+
else:
|
|
1011
|
+
X_tensor = torch.from_numpy(X)
|
|
1012
|
+
X_tensor = X_tensor.float()
|
|
1013
|
+
|
|
1014
|
+
if X_tensor.device != self.device:
|
|
1015
|
+
X_tensor = X_tensor.to(self.device)
|
|
1016
|
+
|
|
1017
|
+
if X_tensor.dim() == 2:
|
|
1018
|
+
# 0/1/2 matrix -> one-hot for model input
|
|
1019
|
+
X_tensor = self._one_hot_encode_012(X_tensor, num_classes=nC)
|
|
1020
|
+
X_tensor = X_tensor.float()
|
|
1021
|
+
if X_tensor.device != self.device:
|
|
1022
|
+
X_tensor = X_tensor.to(self.device)
|
|
1023
|
+
|
|
1024
|
+
elif X_tensor.dim() != 3:
|
|
1025
|
+
msg = f"_predict expects 2D 0/1/2 inputs or 3D one-hot inputs; got shape {tuple(X_tensor.shape)}."
|
|
1026
|
+
self.logger.error(msg)
|
|
1027
|
+
raise ValueError(msg)
|
|
1028
|
+
|
|
1029
|
+
if X_tensor.shape[1] != nF or X_tensor.shape[2] != nC:
|
|
1030
|
+
msg = f"_predict input shape mismatch: expected (B, {nF}, {nC}), got {tuple(X_tensor.shape)}."
|
|
1031
|
+
self.logger.error(msg)
|
|
1032
|
+
raise ValueError(msg)
|
|
1033
|
+
|
|
1034
|
+
X_tensor = X_tensor.reshape(X_tensor.shape[0], nF * nC)
|
|
1035
|
+
|
|
841
1036
|
model.eval()
|
|
842
1037
|
with torch.no_grad():
|
|
843
|
-
|
|
844
|
-
|
|
845
|
-
|
|
846
|
-
logits = model(x_ohe).view(-1, self.num_features_, self.num_classes_)
|
|
1038
|
+
logits_flat = model(X_tensor)
|
|
1039
|
+
logits = logits_flat.view(-1, nF, nC)
|
|
1040
|
+
|
|
847
1041
|
probas = torch.softmax(logits, dim=-1)
|
|
848
1042
|
labels = torch.argmax(probas, dim=-1)
|
|
849
1043
|
|
|
850
1044
|
if return_proba:
|
|
851
1045
|
return labels.cpu().numpy(), probas.cpu().numpy()
|
|
852
|
-
|
|
853
|
-
return labels.cpu().numpy()
|
|
1046
|
+
return labels.cpu().numpy(), None
|
|
854
1047
|
|
|
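_predict above accepts a 2D 0/1/2 matrix and one-hot encodes it before the forward pass. The helper below illustrates the assumed encoding convention (valid genotypes become a K-vector with a single 1, missing -1 entries become all-zero rows); the package's own _one_hot_encode_012 helper may differ in details.

    # Assumed behaviour of the one-hot step; illustrative only.
    import torch

    def one_hot_012(x: torch.Tensor, num_classes: int = 3) -> torch.Tensor:
        out = torch.zeros(*x.shape, num_classes, dtype=torch.float32, device=x.device)
        observed = x >= 0
        out[observed] = torch.nn.functional.one_hot(
            x[observed].long(), num_classes=num_classes
        ).float()
        return out

    X = torch.tensor([[0, 1, -1], [2, -1, 0]])
    print(one_hot_012(X).shape)  # torch.Size([2, 3, 3])
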
    def _evaluate_model(
        self,
-        X_val: np.ndarray,
        model: torch.nn.Module,
-
-
-
+        X: np.ndarray,
+        y: np.ndarray,
+        eval_mask: np.ndarray,
        *,
-
+        objective_mode: bool = False,
    ) -> Dict[str, float]:
        """Evaluate on 0/1/2; then IUPAC decoding and 10-base integer reports.

-        This method evaluates the trained autoencoder model on a validation set, computing various classification metrics based on the predicted and true genotypes. It handles both haploid and diploid data appropriately and generates detailed classification reports for both genotype and IUPAC/10-base integer encodings.
-
        Args:
-            X_val (np.ndarray): Validation set 0/1/2 matrix with -1
-                for missing.
            model (torch.nn.Module): Trained model.
-
-
-
-
+            X (np.ndarray): 2D 0/1/2 matrix with -1 for missing.
+            y (np.ndarray): 2D 0/1/2 ground truth matrix with -1 for missing.
+            eval_mask (np.ndarray): 2D boolean mask selecting sites to evaluate.
+            objective_mode (bool): If True, suppress detailed reports and plots.

        Returns:
            Dict[str, float]: Dictionary of evaluation metrics.
        """
-
-
-
+        if model is None:
+            msg = "Model passed to _evaluate_model() is not fitted. Call fit() before evaluation."
+            self.logger.error(msg)
+            raise NotFittedError(msg)

-
-
-        # FIX 1: Check ROWS (shape[0]) only. X_val might be a feature subset.
-        if (
-            hasattr(self, "X_val_")
-            and getattr(self, "X_val_", None) is not None
-            and X_val.shape[0] == self.X_val_.shape[0]
-        ):
-            GT_ref = getattr(self, "GT_test_full_", self.ground_truth_)
-        elif (
-            hasattr(self, "X_train_")
-            and getattr(self, "X_train_", None) is not None
-            and X_val.shape[0] == self.X_train_.shape[0]
-        ):
-            GT_ref = getattr(self, "GT_train_full_", self.ground_truth_)
-        else:
-            GT_ref = self.ground_truth_
-
-        # FIX 2: Handle Feature Mismatch (e.g., tune_fast feature subsetting)
-        # If the GT source has more columns than X_val, slice it to match.
-        if GT_ref.shape[1] > X_val.shape[1]:
-            GT_ref = GT_ref[:, : X_val.shape[1]]
-
-        # Fallback if rows mismatch (unlikely after Fix 1, but safe to keep)
-        if GT_ref.shape != X_val.shape:
-            # If completely different, we can't use the ground truth object.
-            # Fall back to X_val (this implies only observed values are scored)
-            GT_ref = X_val
-
-        if eval_mask_override is not None:
-            # FIX 3: Allow override mask to be sliced if it's too wide
-            if eval_mask_override.shape[0] != X_val.shape[0]:
-                msg = (
-                    f"eval_mask_override rows {eval_mask_override.shape[0]} "
-                    f"does not match X_val rows {X_val.shape[0]}"
-                )
-                self.logger.error(msg)
-                raise ValueError(msg)
+        pred_labels, pred_probas = self._predict(model=model, X=X, return_proba=True)

-
-
-
-
-        else:
-            eval_mask = X_val != -1
+        if pred_probas is None:
+            msg = "Predicted probabilities are None in _evaluate_model()."
+            self.logger.error(msg)
+            raise ValueError(msg)

-
-
+        y_true_flat = y[eval_mask].astype(np.int8, copy=False)
+        y_pred_flat = pred_labels[eval_mask].astype(np.int8, copy=False)
+        y_proba_flat = pred_probas[eval_mask].astype(np.float32, copy=False)

-
-
-
+        valid = y_true_flat >= 0
+        y_true_flat = y_true_flat[valid]
+        y_pred_flat = y_pred_flat[valid]
+        y_proba_flat = y_proba_flat[valid]

        if y_true_flat.size == 0:
-            self.tune_metric = "f1" if self.tune_metric is None else self.tune_metric
            return {self.tune_metric: 0.0}

-
+        if y_proba_flat.ndim != 2:
+            msg = f"Expected y_proba_flat to be 2D (n_eval, n_classes); got {y_proba_flat.shape}."
+            self.logger.error(msg)
+            raise ValueError(msg)
+
+        K = int(y_proba_flat.shape[1])
+        if self.is_haploid_:
+            if K not in (2, 3):
+                msg = f"Haploid evaluation expects 2 or 3 classes; got {K}."
+                self.logger.error(msg)
+                raise ValueError(msg)
+        else:
+            if K != 3:
+                msg = f"Diploid evaluation expects 3 classes; got {K}."
+                self.logger.error(msg)
+                raise ValueError(msg)
+
+        if self.is_haploid_:
+            y_true_flat = (y_true_flat > 0).astype(np.int8, copy=False)
+            y_pred_flat = (y_pred_flat > 0).astype(np.int8, copy=False)
+
+            if K == 3:
+                proba_2 = np.empty((y_proba_flat.shape[0], 2), dtype=y_proba_flat.dtype)
+                proba_2[:, 0] = y_proba_flat[:, 0]
+                proba_2[:, 1] = y_proba_flat[:, 1] + y_proba_flat[:, 2]
+                y_proba_flat = proba_2
+
+            labels_for_scoring = [0, 1]
+            target_names = ["REF", "ALT"]
+        else:
+            labels_for_scoring = [0, 1, 2]
+            target_names = ["REF", "HET", "ALT"]
+
        y_proba_flat = np.clip(y_proba_flat, 0.0, 1.0)
        row_sums = y_proba_flat.sum(axis=1, keepdims=True)
-        row_sums[row_sums == 0] = 1.0
+        row_sums[row_sums == 0.0] = 1.0
        y_proba_flat = y_proba_flat / row_sums

-
-        target_names = ["REF", "ALT"] if self.is_haploid else ["REF", "HET", "ALT"]
-
-        if self.is_haploid:
-            y_true_flat = y_true_flat.copy()
-            y_pred_flat = y_pred_flat.copy()
-            y_true_flat[y_true_flat == 2] = 1
-            y_pred_flat[y_pred_flat == 2] = 1
-            # collapse probs to 2-class
-            proba_2 = np.zeros((len(y_proba_flat), 2), dtype=y_proba_flat.dtype)
-            proba_2[:, 0] = y_proba_flat[:, 0]
-            proba_2[:, 1] = y_proba_flat[:, 2]
-            y_proba_flat = proba_2
-
-        y_true_ohe = np.eye(len(labels_for_scoring))[y_true_flat]
-
-        tune_metric_tmp: Literal[
-            "pr_macro",
-            "roc_auc",
-            "average_precision",
-            "accuracy",
-            "f1",
-            "precision",
-            "recall",
-        ]
-        if self.tune_metric_ is not None:
-            tune_metric_tmp = self.tune_metric_
-        else:
-            tune_metric_tmp = "f1"  # Default if not tuning
+        y_true_ohe = np.eye(len(labels_for_scoring), dtype=np.int8)[y_true_flat]

        metrics = self.scorers_.evaluate(
            y_true_flat,
@@ -981,16 +1136,29 @@ class ImputeAutoencoder(BaseNNImputer):
            y_true_ohe,
            y_proba_flat,
            objective_mode,
-
+            cast(
+                Literal[
+                    "pr_macro",
+                    "roc_auc",
+                    "accuracy",
+                    "f1",
+                    "average_precision",
+                    "precision",
+                    "recall",
+                    "mcc",
+                    "jaccard",
+                ],
+                self.tune_metric,
+            ),
        )

        if not objective_mode:
-
-
-
-
+            if self.verbose or self.debug:
+                pm = PrettyMetrics(
+                    metrics, precision=2, title=f"{self.model_name} Validation Metrics"
+                )
+                pm.render()

-            # Primary report (REF/HET/ALT or REF/ALT)
            self._make_class_reports(
                y_true=y_true_flat,
                y_pred_proba=y_proba_flat,
@@ -999,18 +1167,15 @@ class ImputeAutoencoder(BaseNNImputer):
                labels=target_names,
            )

-
-
-            y_true_dec = self.pgenc.decode_012(
-                GT_ref.reshape(X_val.shape[0], X_val.shape[1])
-            )
-            X_pred = X_val.copy()
-            X_pred[eval_mask] = y_pred_flat
+            y_true_matrix = np.array(y, copy=True)
+            y_pred_matrix = np.array(pred_labels, copy=True)

-
-
-
-
+            if self.is_haploid_:
+                y_true_matrix = np.where(y_true_matrix > 0, 2, y_true_matrix)
+                y_pred_matrix = np.where(y_pred_matrix > 0, 2, y_pred_matrix)
+
+            y_true_dec = self.decode_012(y_true_matrix)
+            y_pred_dec = self.decode_012(y_pred_matrix)

            encodings_dict = {
                "A": 0,
@@ -1049,237 +1214,177 @@ class ImputeAutoencoder(BaseNNImputer):
        return metrics

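For haploid data, _evaluate_model above pools the HET and ALT probability mass into a single ALT class and renormalizes each row before scoring. A small numeric sketch of that collapse, using made-up probabilities:

    # Haploid collapse: REF stays class 0, HET + ALT mass becomes class 1.
    import numpy as np

    y_proba = np.array([[0.7, 0.2, 0.1],
                        [0.1, 0.3, 0.6]], dtype=np.float32)   # columns: REF, HET, ALT

    proba_2 = np.empty((y_proba.shape[0], 2), dtype=y_proba.dtype)
    proba_2[:, 0] = y_proba[:, 0]
    proba_2[:, 1] = y_proba[:, 1] + y_proba[:, 2]

    proba_2 = np.clip(proba_2, 0.0, 1.0)
    row_sums = proba_2.sum(axis=1, keepdims=True)
    row_sums[row_sums == 0.0] = 1.0
    proba_2 = proba_2 / row_sums                               # rows sum to 1 again
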
    def _objective(self, trial: optuna.Trial) -> float:
-        """Optuna objective for AE
-
-        This method defines the objective function for hyperparameter tuning using Optuna. It samples hyperparameters, prepares the training and validation data, builds and trains the autoencoder model, and evaluates its performance on the validation set. The method returns the value of the tuning metric to be maximized.
+        """Optuna objective for AE (mirrors VAE flow, excluding KL-specific parts).

        Args:
-            trial (optuna.Trial): Optuna trial.
+            trial (optuna.Trial): Optuna trial object.

        Returns:
-            float: Value of the tuning metric
+            float: Value of the tuning metric to optimize.
        """
        try:
-            # Sample hyperparameters (existing helper; unchanged signature)
            params = self._sample_hyperparameters(trial)

-            # Optionally sub-sample for fast tuning (same keys used by NLPCA if you adopt them)
-            X_train = getattr(self, "X_train_", self.ground_truth_[self.train_idx_])
-            X_val = getattr(self, "X_val_", self.ground_truth_[self.test_idx_])
-
-            class_weights = self._normalize_class_weights(
-                self._class_weights_from_zygosity(X_train)
-            )
-            train_loader = self._get_data_loaders(X_train)
-
            model = self.build_model(self.Model, params["model_params"])
            model.apply(self.initialize_weights)

-            lr
-            l1_penalty
+            lr = float(params["learning_rate"])
+            l1_penalty = float(params["l1_penalty"])
+
+            class_weights = self._class_weights_from_zygosity(
+                self.y_train_,
+                train_mask=self.sim_mask_train_ & ~self.orig_mask_train_,
+                inverse=params["inverse"],
+                normalize=params["normalize"],
+                max_ratio=self.max_ratio if self.max_ratio is not None else 5.0,
+                power=params["power"],
+            )

-
-            _, model, __ = self._train_and_validate_model(
+            loss, model, _hist = self._train_and_validate_model(
                model=model,
-                loader=train_loader,
                lr=lr,
                l1_penalty=l1_penalty,
+                params=params,
                trial=trial,
-                return_history=False,
                class_weights=class_weights,
-
-                params=params,
-                prune_metric=self.tune_metric,
-                prune_warmup_epochs=5,
-                eval_interval=self.tune_eval_interval,
-                eval_requires_latents=False,
-                eval_latent_steps=0,
-                eval_latent_lr=0.0,
-                eval_latent_weight_decay=0.0,
+                gamma_schedule=params["gamma_schedule"],
            )

-
-
-
-
-                and getattr(self, "sim_mask_test_", None) is not None
-            )
-            else None
-            )
+            if model is None or not np.isfinite(loss):
+                msg = "Model training returned None or non-finite loss in tuning objective."
+                self.logger.error(msg)
+                raise RuntimeError(msg)

-
-
-
-
-
-
-
-            )
-            self._clear_resources(model, train_loader)
-            else:
-                raise TypeError("Model training failed; no model was returned.")
+            metrics = self._evaluate_model(
+                model=model,
+                X=self.X_val_,
+                y=self.y_val_,
+                eval_mask=self.sim_mask_val_ & ~self.orig_mask_val_,
+                objective_mode=True,
+            )

-
+            self._clear_resources(model)
+            return float(metrics[self.tune_metric])

        except Exception as e:
-
-
-
-
-
+            err_type = type(e).__name__
+            self.logger.warning(
+                f"Trial {trial.number} failed due to exception {err_type}: {e}"
+            )
+            self.logger.debug(traceback.format_exc())
+            raise optuna.exceptions.TrialPruned(
+                f"Trial {trial.number} failed due to an exception. {err_type}: {e}. "
+                "Enable debug logging for full traceback."
+            ) from e

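The objective above derives per-class weights from the zygosity counts of the simulated-missing training targets, controlled by inverse, normalize, power, and max_ratio. One plausible inverse-frequency weighting consistent with those knobs is sketched below; the package's _class_weights_from_zygosity may compute weights differently.

    # Hypothetical inverse-frequency class weighting; illustrative only.
    import numpy as np
    import torch

    def zygosity_class_weights(y, mask, inverse=True, normalize=True, power=1.0, max_ratio=5.0):
        vals = y[mask]
        vals = vals[vals >= 0]
        counts = np.bincount(vals.astype(np.int64), minlength=3).astype(np.float64)
        counts[counts == 0] = 1.0                       # avoid division by zero
        freqs = counts / counts.sum()
        w = (1.0 / freqs) ** power if inverse else np.ones_like(freqs)
        w = np.clip(w, w.min(), w.min() * max_ratio)    # cap imbalance at max_ratio
        if normalize:
            w = w / w.sum() * len(w)                    # mean weight around 1
        return torch.tensor(w, dtype=torch.float32)

    y = np.random.randint(-1, 3, size=(20, 50))
    mask = y >= 0
    print(zygosity_class_weights(y, mask))
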
-
+    def _sample_hyperparameters(self, trial: optuna.Trial) -> dict:
+        """Sample AE hyperparameters; hidden sizes mirror VAE helper (excluding KL).

        Args:
            trial (optuna.Trial): Optuna trial object.

        Returns:
-
+            dict: Sampled hyperparameters.
        """
        params = {
-            "latent_dim": trial.suggest_int("latent_dim", 2,
-            "
-            "dropout_rate": trial.suggest_float("dropout_rate", 0.0, 0.
-            "num_hidden_layers": trial.suggest_int("num_hidden_layers", 1,
+            "latent_dim": trial.suggest_int("latent_dim", 2, 32),
+            "learning_rate": trial.suggest_float("learning_rate", 3e-6, 1e-3, log=True),
+            "dropout_rate": trial.suggest_float("dropout_rate", 0.0, 0.5, step=0.025),
+            "num_hidden_layers": trial.suggest_int("num_hidden_layers", 1, 20),
            "activation": trial.suggest_categorical(
-                "activation", ["relu", "elu", "selu"]
+                "activation", ["relu", "elu", "selu", "leaky_relu"]
            ),
-            "l1_penalty": trial.suggest_float("l1_penalty", 1e-
+            "l1_penalty": trial.suggest_float("l1_penalty", 1e-6, 1e-3, log=True),
            "layer_scaling_factor": trial.suggest_float(
-                "layer_scaling_factor", 2.0, 10.0
+                "layer_scaling_factor", 2.0, 10.0, step=0.025
            ),
            "layer_schedule": trial.suggest_categorical(
-                "layer_schedule", ["pyramid", "
+                "layer_schedule", ["pyramid", "linear"]
+            ),
+            "power": trial.suggest_float("power", 0.1, 2.0, step=0.1),
+            "normalize": trial.suggest_categorical("normalize", [True, False]),
+            "inverse": trial.suggest_categorical("inverse", [True, False]),
+            "gamma": trial.suggest_float("gamma", 0.0, 10.0, step=0.1),
+            "gamma_schedule": trial.suggest_categorical(
+                "gamma_schedule", [True, False]
            ),
        }

-        nF
-        nC
+        nF = int(self.num_features_)
+        nC = int(self.num_classes_)
        input_dim = nF * nC
+
        hidden_layer_sizes = self._compute_hidden_layer_sizes(
            n_inputs=input_dim,
-            n_outputs=
+            n_outputs=nC,
            n_samples=len(self.train_idx_),
-            n_hidden=params["num_hidden_layers"],
-
-
+            n_hidden=int(params["num_hidden_layers"]),
+            latent_dim=int(params["latent_dim"]),
+            alpha=float(params["layer_scaling_factor"]),
+            schedule=str(params["layer_schedule"]),
        )

-        # Keep the latent_dim as the first element,
-        # then the interior hidden widths.
-        # If there are no interior widths (very small nets),
-        # this still leaves [latent_dim].
-        hidden_only: list[int] = [hidden_layer_sizes[0]] + hidden_layer_sizes[1:-1]
-
        params["model_params"] = {
-            "n_features":
-            "num_classes":
-                int(self.num_classes_) if self.num_classes_ is not None else 3
-            ),
+            "n_features": nF,
+            "num_classes": nC,
            "latent_dim": int(params["latent_dim"]),
            "dropout_rate": float(params["dropout_rate"]),
-            "hidden_layer_sizes":
+            "hidden_layer_sizes": hidden_layer_sizes,
            "activation": str(params["activation"]),
        }
        return params

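_sample_hyperparameters and _objective plug into a standard Optuna study. The wiring looks roughly like the toy example below; pg-sui drives its own study internally, so the study settings and the objective here are illustrative only.

    # Minimal Optuna wiring for an objective of this shape.
    import optuna

    def objective(trial: optuna.Trial) -> float:
        x = trial.suggest_float("x", -10.0, 10.0)
        return -(x - 2.0) ** 2            # stand-in for the validation tune metric

    study = optuna.create_study(direction="maximize")
    study.optimize(objective, n_trials=25)
    print(study.best_params, study.best_value)
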
-
def _set_best_params(
|
|
1187
|
-
|
|
1188
|
-
) -> Dict[str, int | float | str | List[int]]:
|
|
1189
|
-
"""Adopt best params (ImputeNLPCA parity) and return model_params.
|
|
1190
|
-
|
|
1191
|
-
This method sets the best hyperparameters found during tuning and computes the hidden layer sizes for the autoencoder model. It prepares the final model parameters dictionary to be used for building the model.
|
|
1337
|
+
def _set_best_params(self, params: dict) -> dict:
|
|
1338
|
+
"""Update instance fields from tuned params and return model_params dict.
|
|
1192
1339
|
|
|
1193
1340
|
Args:
|
|
1194
|
-
|
|
1341
|
+
params (dict): Best hyperparameters from tuning.
|
|
1195
1342
|
|
|
1196
1343
|
Returns:
|
|
1197
|
-
|
|
1344
|
+
dict: Model parameters for building the final model.
|
|
1198
1345
|
"""
|
|
1199
|
-
|
|
1200
|
-
|
|
1201
|
-
|
|
1202
|
-
|
|
1203
|
-
|
|
1204
|
-
|
|
1205
|
-
|
|
1206
|
-
|
|
1207
|
-
|
|
1208
|
-
|
|
1209
|
-
|
|
1210
|
-
|
|
1211
|
-
|
|
1212
|
-
|
|
1213
|
-
|
|
1214
|
-
|
|
1215
|
-
|
|
1216
|
-
|
|
1217
|
-
|
|
1218
|
-
else
|
|
1219
|
-
|
|
1220
|
-
|
|
1221
|
-
self.latent_dim: int = bp["latent_dim"]
|
|
1222
|
-
self.dropout_rate: float = bp["dropout_rate"]
|
|
1223
|
-
self.learning_rate: float = bp["learning_rate"]
|
|
1224
|
-
self.l1_penalty: float = bp["l1_penalty"]
|
|
1225
|
-
self.activation: str = bp["activation"]
|
|
1226
|
-
self.layer_scaling_factor: float = bp["layer_scaling_factor"]
|
|
1227
|
-
self.layer_schedule: str = bp["layer_schedule"]
|
|
1228
|
-
|
|
1229
|
-
nF: int = self.num_features_
|
|
1230
|
-
nC: int = int(self.num_classes_) if self.num_classes_ is not None else 3
|
|
1231
|
-
hidden_layer_sizes = self._compute_hidden_layer_sizes(
|
|
1232
|
-
n_inputs=nF * nC,
|
|
1233
|
-
n_outputs=nF * nC,
|
|
1234
|
-
n_samples=len(self.train_idx_),
|
|
1235
|
-
n_hidden=bp["num_hidden_layers"],
|
|
1236
|
-
alpha=bp["layer_scaling_factor"],
|
|
1237
|
-
schedule=bp["layer_schedule"],
|
|
1346
|
+
self.latent_dim = int(params["latent_dim"])
|
|
1347
|
+
self.dropout_rate = float(params["dropout_rate"])
|
|
1348
|
+
self.learning_rate = float(params["learning_rate"])
|
|
1349
|
+
self.l1_penalty = float(params["l1_penalty"])
|
|
1350
|
+
self.activation = str(params["activation"])
|
|
1351
|
+
self.layer_scaling_factor = float(params["layer_scaling_factor"])
|
|
1352
|
+
self.layer_schedule = str(params["layer_schedule"])
|
|
1353
|
+
|
|
1354
|
+
self.power = float(params["power"])
|
|
1355
|
+
self.normalize = bool(params["normalize"])
|
|
1356
|
+
self.inverse = bool(params["inverse"])
|
|
1357
|
+
self.gamma = float(params["gamma"])
|
|
1358
|
+
self.gamma_schedule = bool(params["gamma_schedule"])
|
|
1359
|
+
|
|
1360
|
+
self.class_weights_ = self._class_weights_from_zygosity(
|
|
1361
|
+
self.y_train_,
|
|
1362
|
+
train_mask=self.sim_mask_train_ & ~self.orig_mask_train_,
|
|
1363
|
+
inverse=self.inverse,
|
|
1364
|
+
normalize=self.normalize,
|
|
1365
|
+
max_ratio=self.max_ratio if self.max_ratio is not None else 5.0,
|
|
1366
|
+
power=self.power,
|
|
1238
1367
|
)
|
|
1239
1368
|
|
|
1240
|
-
|
|
1241
|
-
|
|
1242
|
-
|
|
1243
|
-
# this still leaves [latent_dim].
|
|
1244
|
-
hidden_only = [hidden_layer_sizes[0]] + hidden_layer_sizes[1:-1]
|
|
1245
|
-
|
|
1246
|
-
return {
|
|
1247
|
-
"n_features": self.num_features_,
|
|
1248
|
-
"latent_dim": self.latent_dim,
|
|
1249
|
-
"hidden_layer_sizes": hidden_only,
|
|
1250
|
-
"dropout_rate": self.dropout_rate,
|
|
1251
|
-
"activation": self.activation,
|
|
1252
|
-
"num_classes": nC,
|
|
1253
|
-
}
|
|
1254
|
-
|
|
1255
|
-
def _default_best_params(self) -> Dict[str, int | float | str | list]:
|
|
1256
|
-
"""Default model params when tuning is disabled.
|
|
1257
|
-
|
|
1258
|
-
This method computes the default model parameters for the autoencoder when hyperparameter tuning is not performed. It calculates the hidden layer sizes based on the initial configuration.
|
|
1259
|
-
|
|
1260
|
-
Returns:
|
|
1261
|
-
Dict[str, int | float | str | list]: Default model parameters.
|
|
1262
|
-
"""
|
|
1263
|
-
nF: int = self.num_features_
|
|
1264
|
-
nC: int = int(self.num_classes_) if self.num_classes_ is not None else 3
|
|
1265
|
-
ls = self.layer_schedule
|
|
1266
|
-
|
|
1267
|
-
if ls not in {"pyramid", "constant", "linear"}:
|
|
1268
|
-
raise ValueError(f"Invalid layer_schedule: {ls}")
|
|
1369
|
+
nF = int(self.num_features_)
|
|
1370
|
+
nC = int(self.num_classes_)
|
|
1371
|
+
input_dim = nF * nC
|
|
1269
1372
|
|
|
1270
1373
|
hidden_layer_sizes = self._compute_hidden_layer_sizes(
|
|
1271
|
-
n_inputs=
|
|
1272
|
-
n_outputs=
|
|
1273
|
-
n_samples=len(self.
|
|
1274
|
-
n_hidden=
|
|
1275
|
-
|
|
1276
|
-
|
|
1374
|
+
n_inputs=input_dim,
|
|
1375
|
+
n_outputs=nC,
|
|
1376
|
+
n_samples=len(self.train_idx_),
|
|
1377
|
+
n_hidden=int(params["num_hidden_layers"]),
|
|
1378
|
+
latent_dim=int(params["latent_dim"]),
|
|
1379
|
+
alpha=float(params["layer_scaling_factor"]),
|
|
1380
|
+
schedule=str(params["layer_schedule"]),
|
|
1277
1381
|
)
|
|
1382
|
+
|
|
1278
1383
|
return {
|
|
1279
|
-
"n_features":
|
|
1384
|
+
"n_features": nF,
|
|
1385
|
+
"num_classes": nC,
|
|
1280
1386
|
"latent_dim": self.latent_dim,
|
|
1281
1387
|
"hidden_layer_sizes": hidden_layer_sizes,
|
|
1282
1388
|
"dropout_rate": self.dropout_rate,
|
|
1283
1389
|
"activation": self.activation,
|
|
1284
|
-
"num_classes": nC,
|
|
1285
1390
|
}
|
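The tuned layer_scaling_factor and layer_schedule ("pyramid" or "linear") above control how the hidden widths taper from the flattened one-hot input toward latent_dim. One such tapering schedule is sketched below; the package's _compute_hidden_layer_sizes helper may use a different formula and signature, so treat this purely as an illustration of the idea.

    # Hypothetical pyramid/linear width schedule; not the package's implementation.
    import numpy as np

    def hidden_layer_sizes(n_inputs, latent_dim, n_hidden, alpha=4.0, schedule="pyramid"):
        first = max(latent_dim, int(n_inputs / alpha))        # widest hidden layer
        if schedule == "pyramid":
            widths = np.geomspace(first, latent_dim, num=n_hidden + 1)
        else:  # "linear"
            widths = np.linspace(first, latent_dim, num=n_hidden + 1)
        return [max(latent_dim, int(round(w))) for w in widths]

    print(hidden_layer_sizes(n_inputs=3000, latent_dim=8, n_hidden=4, alpha=6.0))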