pg-sui 1.6.16a3__py3-none-any.whl → 1.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/METADATA +26 -30
- {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/RECORD +29 -33
- pgsui/__init__.py +0 -8
- pgsui/_version.py +2 -2
- pgsui/cli.py +577 -125
- pgsui/data_processing/config.py +1 -2
- pgsui/data_processing/containers.py +203 -530
- pgsui/data_processing/transformers.py +44 -20
- pgsui/impute/deterministic/imputers/mode.py +475 -182
- pgsui/impute/deterministic/imputers/ref_allele.py +454 -147
- pgsui/impute/supervised/imputers/hist_gradient_boosting.py +4 -3
- pgsui/impute/supervised/imputers/random_forest.py +3 -2
- pgsui/impute/unsupervised/base.py +1269 -534
- pgsui/impute/unsupervised/callbacks.py +28 -33
- pgsui/impute/unsupervised/imputers/autoencoder.py +870 -841
- pgsui/impute/unsupervised/imputers/vae.py +931 -787
- pgsui/impute/unsupervised/loss_functions.py +156 -202
- pgsui/impute/unsupervised/models/autoencoder_model.py +7 -49
- pgsui/impute/unsupervised/models/vae_model.py +40 -221
- pgsui/impute/unsupervised/nn_scorers.py +53 -13
- pgsui/utils/classification_viz.py +240 -97
- pgsui/utils/misc.py +201 -3
- pgsui/utils/plotting.py +73 -58
- pgsui/utils/pretty_metrics.py +2 -6
- pgsui/utils/scorers.py +39 -0
- pgsui/impute/unsupervised/imputers/nlpca.py +0 -1666
- pgsui/impute/unsupervised/imputers/ubp.py +0 -1660
- pgsui/impute/unsupervised/models/nlpca_model.py +0 -206
- pgsui/impute/unsupervised/models/ubp_model.py +0 -200
- {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/WHEEL +0 -0
- {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/entry_points.txt +0 -0
- {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/licenses/LICENSE +0 -0
- {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/top_level.txt +0 -0
@@ -1,23 +1,25 @@
+# -*- coding: utf-8 -*-
+from __future__ import annotations
+
 import copy
-
+import traceback
+from collections import defaultdict
+from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union, cast
 
 import matplotlib.pyplot as plt
 import numpy as np
 import optuna
 import torch
-import torch.nn.functional as F
 from sklearn.exceptions import NotFittedError
-from sklearn.model_selection import train_test_split
 from snpio.analysis.genotype_encoder import GenotypeEncoder
 from snpio.utils.logging import LoggerManager
 from torch.optim.lr_scheduler import CosineAnnealingLR
 
 from pgsui.data_processing.config import apply_dot_overrides, load_yaml_to_dataclass
 from pgsui.data_processing.containers import AutoencoderConfig
-from pgsui.data_processing.transformers import SimMissingTransformer
 from pgsui.impute.unsupervised.base import BaseNNImputer
 from pgsui.impute.unsupervised.callbacks import EarlyStopping
-from pgsui.impute.unsupervised.loss_functions import
+from pgsui.impute.unsupervised.loss_functions import FocalCELoss
 from pgsui.impute.unsupervised.models.autoencoder_model import AutoencoderModel
 from pgsui.utils.logging_utils import configure_logger
 from pgsui.utils.pretty_metrics import PrettyMetrics
@@ -27,30 +29,72 @@ if TYPE_CHECKING:
     from snpio.read_input.genotype_data import GenotypeData
 
 
+def _make_warmup_cosine_scheduler(
+    optimizer: torch.optim.Optimizer,
+    *,
+    max_epochs: int,
+    warmup_epochs: int,
+    start_factor: float = 0.1,
+) -> torch.optim.lr_scheduler.CosineAnnealingLR | torch.optim.lr_scheduler.SequentialLR:
+    """Create a warmup->cosine LR scheduler.
+
+    Args:
+        optimizer: Optimizer to schedule.
+        max_epochs: Total number of epochs.
+        warmup_epochs: Number of warmup epochs.
+        start_factor: Starting LR factor for warmup.
+
+    Returns:
+        torch.optim.lr_scheduler.CosineAnnealingLR | torch.optim.lr_scheduler.SequentialLR: LR scheduler (SequentialLR if warmup_epochs > 0 else CosineAnnealingLR).
+    """
+    warmup_epochs = int(max(0, warmup_epochs))
+
+    if warmup_epochs == 0:
+        return CosineAnnealingLR(optimizer, T_max=max_epochs)
+
+    warmup = torch.optim.lr_scheduler.LinearLR(
+        optimizer, start_factor=float(start_factor), total_iters=warmup_epochs
+    )
+    cosine = CosineAnnealingLR(optimizer, T_max=max(1, max_epochs - warmup_epochs))
+
+    return torch.optim.lr_scheduler.SequentialLR(
+        optimizer, schedulers=[warmup, cosine], milestones=[warmup_epochs]
+    )
+
+
 def ensure_autoencoder_config(
     config: AutoencoderConfig | dict | str | None,
 ) -> AutoencoderConfig:
     """Return a concrete AutoencoderConfig from dataclass, dict, YAML path, or None.
 
-
+    Notes:
+        - Supports top-level preset, or io.preset inside dict/YAML.
+        - Does not mutate user-provided dict (deep-copies before processing).
+        - Flattens nested dicts into dot-keys and applies them as overrides.
 
     Args:
-        config
+        config: AutoencoderConfig instance, dict, YAML path, or None.
 
     Returns:
-
+        Concrete AutoencoderConfig.
     """
     if config is None:
         return AutoencoderConfig()
     if isinstance(config, AutoencoderConfig):
         return config
     if isinstance(config, str):
-        # YAML path — top-level `preset` key is supported
         return load_yaml_to_dataclass(config, AutoencoderConfig)
     if isinstance(config, dict):
-
+        cfg_in = copy.deepcopy(config)
         base = AutoencoderConfig()
 
+        preset = cfg_in.pop("preset", None)
+        if "io" in cfg_in and isinstance(cfg_in["io"], dict):
+            preset = preset or cfg_in["io"].pop("preset", None)
+
+        if preset:
+            base = AutoencoderConfig.from_preset(preset)
+
         def _flatten(prefix: str, d: dict, out: dict) -> dict:
             for k, v in d.items():
                 kk = f"{prefix}.{k}" if prefix else k
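For readers comparing the old cosine-only schedule with the new warmup->cosine helper added above, here is a minimal standalone sketch (not part of pg-sui) of how `LinearLR` warmup composes with `CosineAnnealingLR` via `SequentialLR`; the toy `Linear` model and epoch counts are illustrative assumptions.

```python
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

max_epochs, warmup_epochs = 50, 5
# Ramp from 10% of the base LR to 100% over the warmup epochs...
warmup = LinearLR(optimizer, start_factor=0.1, total_iters=warmup_epochs)
# ...then decay with a cosine curve over the remaining epochs.
cosine = CosineAnnealingLR(optimizer, T_max=max(1, max_epochs - warmup_epochs))
scheduler = SequentialLR(optimizer, schedulers=[warmup, cosine], milestones=[warmup_epochs])

for epoch in range(max_epochs):
    optimizer.step()   # a real loop would train one epoch here
    scheduler.step()   # advance the LR schedule once per epoch
    if epoch in (0, warmup_epochs - 1, max_epochs - 1):
        print(epoch, scheduler.get_last_lr())
```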
@@ -60,26 +104,24 @@ def ensure_autoencoder_config(
                 out[kk] = v
             return out
 
-
-        preset_name = config.pop("preset", None)
-        if "io" in config and isinstance(config["io"], dict):
-            preset_name = preset_name or config["io"].pop("preset", None)
-
-        if preset_name:
-            base = AutoencoderConfig.from_preset(preset_name)
-
-        flat = _flatten("", config, {})
+        flat = _flatten("", cfg_in, {})
         return apply_dot_overrides(base, flat)
 
     raise TypeError("config must be an AutoencoderConfig, dict, YAML path, or None.")
 
 
 class ImputeAutoencoder(BaseNNImputer):
-    """
+    """Autoencoder imputer for 0/1/2 genotypes.
 
-
+    Trains a feedforward autoencoder on a genotype matrix encoded as 0/1/2 with missing values represented by any negative integer. Missingness is simulated once on the full matrix, then train/val/test splits reuse those masks. It supports haploid and diploid data, focal-CE reconstruction loss (optional scheduling), and Optuna-based hyperparameter tuning. Output is returned as IUPAC strings via ``decode_012``.
 
-
+    Notes:
+        - Simulates missingness once on the full 0/1/2 matrix, then splits indices on clean ground truth.
+        - Maintains clean targets and corrupted inputs per train/val/test, plus per-split masks.
+        - Haploid harmonization happens after the single simulation (no re-simulation).
+        - Training/validation loss is computed only where targets are known (~orig_mask_*).
+        - Evaluation is computed only on simulated-missing sites (sim_mask_*).
+        - ``transform()`` fills only originally missing sites and hard-errors if decoding yields "N".
     """
 
     def __init__(
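The dict branch of `ensure_autoencoder_config` relies on flattening nested config keys into dot-keys before handing them to `apply_dot_overrides`. A self-contained sketch of just that flattening step (the recursion into nested dicts is implied by the surrounding code but elided in this hunk, so treat the middle branch as an assumption):

```python
def flatten(prefix: str, d: dict, out: dict) -> dict:
    # Turn {"train": {"batch_size": 64}} into {"train.batch_size": 64}.
    for k, v in d.items():
        kk = f"{prefix}.{k}" if prefix else k
        if isinstance(v, dict):
            flatten(kk, v, out)   # assumed recursive branch
        else:
            out[kk] = v
    return out

nested = {"io": {"prefix": "run1", "seed": 42}, "train": {"batch_size": 64}}
print(flatten("", nested, {}))
# {'io.prefix': 'run1', 'io.seed': 42, 'train.batch_size': 64}
```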
@@ -88,8 +130,7 @@ class ImputeAutoencoder(BaseNNImputer):
         *,
         tree_parser: Optional["TreeParser"] = None,
         config: Optional[Union["AutoencoderConfig", dict, str]] = None,
-        overrides: dict
-        simulate_missing: bool | None = None,
+        overrides: Optional[dict] = None,
         sim_strategy: (
             Literal[
                 "random",
@@ -100,34 +141,29 @@ class ImputeAutoencoder(BaseNNImputer):
             ]
             | None
         ) = None,
-        sim_prop: float
-        sim_kwargs: dict
+        sim_prop: Optional[float] = None,
+        sim_kwargs: Optional[dict] = None,
     ) -> None:
         """Initialize the Autoencoder imputer with a unified config interface.
 
-        This initializer sets up the Autoencoder imputer by processing the provided configuration, initializing logging, and preparing the model and data encoder. It supports configuration input as a dataclass, nested dictionary, YAML file path, or None, with optional dot-key overrides for fine-tuning specific parameters.
-
         Args:
-            genotype_data (
-            tree_parser (Optional[
-            config (Union[
-            overrides (dict
-
-
-
-            sim_kwargs (dict | None): Additional keyword arguments for simulating missing data. If None, uses config default.
+            genotype_data (GenotypeData): Backing genotype data object.
+            tree_parser (Optional[TreeParser]): Optional SNPio tree parser for nonrandom simulated-missing modes.
+            config (Optional[Union[AutoencoderConfig, dict, str]]): AutoencoderConfig, nested dict, YAML path, or None.
+            overrides (Optional[dict]): Optional dot-key overrides with highest precedence.
+            sim_strategy (Literal["random", "random_weighted" "random_weighted_inv", "nonrandom", "nonrandom_weighted"]): Override sim strategy; if None, uses config default.
+            sim_prop (Optional[float]): Override simulated missing proportion; if None, uses config default. Default is None.
+            sim_kwargs (Optional[dict]): Override/extend simulated missing kwargs; if None, uses config default.
         """
         self.model_name = "ImputeAutoencoder"
         self.genotype_data = genotype_data
         self.tree_parser = tree_parser
 
-        # Normalize config then apply highest-precedence overrides
         cfg = ensure_autoencoder_config(config)
         if overrides:
             cfg = apply_dot_overrides(cfg, overrides)
         self.cfg = cfg
 
-        # Logger consistent with NLPCA
        logman = LoggerManager(
             __name__,
             prefix=self.cfg.io.prefix,
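Putting the new `__init__` signature together, a hedged usage sketch: the keyword names (`config`, `overrides`, `sim_strategy`, `sim_prop`) come from the signature shown in the hunk above, while the SNPio reader call and the input file names are assumptions for illustration only.

```python
from snpio import VCFReader  # assumed SNPio entry point for loading GenotypeData
from pgsui.impute.unsupervised.imputers.autoencoder import ImputeAutoencoder

# Hypothetical input files; substitute your own VCF and popmap.
gd = VCFReader(filename="example.vcf.gz", popmapfile="example.popmap")

imputer = ImputeAutoencoder(
    gd,
    config={"io": {"prefix": "ae_run", "seed": 42}, "train": {"max_epochs": 100}},
    overrides={"model.latent_dim": 8},  # dot-key overrides take highest precedence
    sim_strategy="random",
    sim_prop=0.2,
)
imputed_iupac = imputer.fit().transform()  # IUPAC strings per transform()
```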
@@ -139,8 +175,8 @@ class ImputeAutoencoder(BaseNNImputer):
             verbose=self.cfg.io.verbose,
             debug=self.cfg.io.debug,
         )
+        self.logger.propagate = False
 
-        # BaseNNImputer bootstrapping (device/dirs/logging handled here)
         super().__init__(
             model_name=self.model_name,
             genotype_data=self.genotype_data,
@@ -151,11 +187,9 @@ class ImputeAutoencoder(BaseNNImputer):
         )
 
         self.Model = AutoencoderModel
-
-        # Model hook & encoder
         self.pgenc = GenotypeEncoder(genotype_data)
 
-        #
+        # I/O and global
         self.seed = self.cfg.io.seed
         self.n_jobs = self.cfg.io.n_jobs
         self.prefix = self.cfg.io.prefix
@@ -163,264 +197,347 @@ class ImputeAutoencoder(BaseNNImputer):
         self.verbose = self.cfg.io.verbose
         self.debug = self.cfg.io.debug
         self.rng = np.random.default_rng(self.seed)
-        self.pos_weights_: torch.Tensor | None = None
 
-        #
+        # Simulation controls (match VAE pattern)
         sim_cfg = getattr(self.cfg, "sim", None)
         sim_cfg_kwargs = copy.deepcopy(getattr(sim_cfg, "sim_kwargs", None) or {})
         if sim_kwargs:
             sim_cfg_kwargs.update(sim_kwargs)
-
-            (
-                sim_cfg.simulate_missing
-                if simulate_missing is None
-                else bool(simulate_missing)
-            )
-            if sim_cfg is not None
-            else bool(simulate_missing)
-        )
+
         if sim_cfg is None:
             default_strategy = "random"
-            default_prop = 0.
+            default_prop = 0.2
         else:
             default_strategy = sim_cfg.sim_strategy
             default_prop = sim_cfg.sim_prop
+
+        self.simulate_missing = True
         self.sim_strategy = sim_strategy or default_strategy
         self.sim_prop = float(sim_prop if sim_prop is not None else default_prop)
         self.sim_kwargs = sim_cfg_kwargs
 
         if self.tree_parser is None and self.sim_strategy.startswith("nonrandom"):
-            msg = "tree_parser is required for nonrandom
+            msg = "tree_parser is required for nonrandom sim strategies."
             self.logger.error(msg)
             raise ValueError(msg)
 
-        # Model
+        # Model architecture
         self.latent_dim = int(self.cfg.model.latent_dim)
         self.dropout_rate = float(self.cfg.model.dropout_rate)
         self.num_hidden_layers = int(self.cfg.model.num_hidden_layers)
         self.layer_scaling_factor = float(self.cfg.model.layer_scaling_factor)
-        self.layer_schedule
-        self.activation = str(self.cfg.model.
-
+        self.layer_schedule = str(self.cfg.model.layer_schedule)
+        self.activation = str(self.cfg.model.activation)
+
+        # Training / loss controls (align with VAE fields where present)
+        self.power = float(getattr(self.cfg.train, "weights_power", 1.0))
+        self.max_ratio = getattr(self.cfg.train, "weights_max_ratio", None)
+        self.normalize = bool(getattr(self.cfg.train, "weights_normalize", True))
+        self.inverse = bool(getattr(self.cfg.train, "weights_inverse", False))
 
-        # Train hyperparams
         self.batch_size = int(self.cfg.train.batch_size)
         self.learning_rate = float(self.cfg.train.learning_rate)
-        self.l1_penalty
+        self.l1_penalty = float(self.cfg.train.l1_penalty)
         self.early_stop_gen = int(self.cfg.train.early_stop_gen)
         self.min_epochs = int(self.cfg.train.min_epochs)
         self.epochs = int(self.cfg.train.max_epochs)
         self.validation_split = float(self.cfg.train.validation_split)
-        self.beta = float(self.cfg.train.weights_beta)
-        self.max_ratio = float(self.cfg.train.weights_max_ratio)
 
-        #
-
-
-
-
-
-
-
-
-
-            Literal[
-                "pr_macro",
-                "f1",
-                "accuracy",
-                "precision",
-                "recall",
-                "roc_auc",
-                "average_precision",
-            ]
-            | None
-        ) = self.cfg.tune.metric
+        # Gamma can live in cfg.model or cfg.train depending on your dataclasses
+        gamma_raw = getattr(
+            self.cfg.train, "gamma", getattr(self.cfg.model, "gamma", 0.0)
+        )
+        if not isinstance(gamma_raw, (float, int)):
+            msg = f"Gamma must be float|int; got {type(gamma_raw)}."
+            self.logger.error(msg)
+            raise TypeError(msg)
+        self.gamma = float(gamma_raw)
+        self.gamma_schedule = bool(getattr(self.cfg.train, "gamma_schedule", True))
 
+        # Hyperparameter tuning
+        self.tune = bool(self.cfg.tune.enabled)
+        self.tune_metric = cast(
+            Literal[
+                "pr_macro",
+                "f1",
+                "accuracy",
+                "precision",
+                "recall",
+                "roc_auc",
+                "average_precision",
+                "mcc",
+                "jaccard",
+            ],
+            self.cfg.tune.metric or "f1",
+        )
         self.n_trials = int(self.cfg.tune.n_trials)
         self.tune_save_db = bool(self.cfg.tune.save_db)
         self.tune_resume = bool(self.cfg.tune.resume)
-        self.tune_max_samples = int(self.cfg.tune.max_samples)
-        self.tune_max_loci = int(self.cfg.tune.max_loci)
-        self.tune_infer_epochs = int(
-            getattr(self.cfg.tune, "infer_epochs", 0)
-        )  # AE unused
         self.tune_patience = int(self.cfg.tune.patience)
 
-        #
-
-        self.eval_latent_steps: int = 0
-        self.eval_latent_lr: float = 0.0
-        self.eval_latent_weight_decay: float = 0.0
-
-        # Plotting (parity with NLPCA PlotConfig)
-        self.plot_format: Literal["pdf", "png", "jpg", "jpeg", "svg"] = (
-            self.cfg.plot.fmt
-        )
+        # Plotting
+        self.plot_format = self.cfg.plot.fmt
         self.plot_dpi = int(self.cfg.plot.dpi)
         self.plot_fontsize = int(self.cfg.plot.fontsize)
         self.title_fontsize = int(self.cfg.plot.fontsize)
         self.despine = bool(self.cfg.plot.despine)
         self.show_plots = bool(self.cfg.plot.show)
 
-        #
-        self.
-        self.num_classes_: int
+        # Fit-time attributes
+        self.is_haploid_: bool = False
+        self.num_classes_: int = 3
         self.model_params: Dict[str, Any] = {}
-        self.sim_mask_global_: np.ndarray | None = None
-        self.sim_mask_train_: np.ndarray | None = None
-        self.sim_mask_test_: np.ndarray | None = None
 
-
-
+        self.sim_mask_train_: np.ndarray
+        self.sim_mask_val_: np.ndarray
+        self.sim_mask_test_: np.ndarray
 
-
+        self.orig_mask_train_: np.ndarray
+        self.orig_mask_val_: np.ndarray
+        self.orig_mask_test_: np.ndarray
 
-
-
+    def fit(self) -> "ImputeAutoencoder":
+        """Fit the Autoencoder imputer model to the genotype data.
+
+        This method performs the following steps:
+        1. Validates the presence of SNP data in the genotype data.
+        2. Determines ploidy and sets up the number of classes accordingly.
+        3. Cleans the ground truth genotype matrix and simulates missingness.
+        4. Splits the data into training, validation, and test sets.
+        5. Prepares one-hot encoded inputs for the model.
+        6. Initializes plotting utilities and valid-class masks.
+        7. Sets up data loaders for training and validation.
+        8. Performs hyperparameter tuning if enabled, otherwise uses fixed hyperparameters.
+        9. Builds and trains the Autoencoder model.
+        10. Evaluates the trained model on the test set.
+        11. Returns the fitted ImputeAutoencoder instance.
 
-
-
+        Returns:
+            ImputeAutoencoder: The fitted ImputeAutoencoder instance.
         """
         self.logger.info(f"Fitting {self.model_name} model...")
 
-        # --- Data prep (mirror NLPCA) ---
-        X012 = self._get_float_genotypes(copy=True)
-        GT_full = np.nan_to_num(X012, nan=-1.0, copy=True)
-        self.ground_truth_ = GT_full.astype(np.int64, copy=False)
-
-        self.sim_mask_global_ = None
-        cache_key = self._sim_mask_cache_key()
-        if self.simulate_missing:
-            cached_mask = (
-                None if cache_key is None else self._sim_mask_cache.get(cache_key)
-            )
-            if cached_mask is not None:
-                self.sim_mask_global_ = cached_mask.copy()
-            else:
-                tr = SimMissingTransformer(
-                    genotype_data=self.genotype_data,
-                    tree_parser=self.tree_parser,
-                    prop_missing=self.sim_prop,
-                    strategy=self.sim_strategy,
-                    missing_val=-9,
-                    mask_missing=True,
-                    verbose=self.verbose,
-                    **self.sim_kwargs,
-                )
-                tr.fit(X012.copy())
-                self.sim_mask_global_ = tr.sim_missing_mask_.astype(bool)
-                if cache_key is not None:
-                    self._sim_mask_cache[cache_key] = self.sim_mask_global_.copy()
-
-            X_for_model = self.ground_truth_.copy()
-            X_for_model[self.sim_mask_global_] = -1
-        else:
-            X_for_model = self.ground_truth_.copy()
-
         if self.genotype_data.snp_data is None:
-            msg = "SNP data is required for
+            msg = f"SNP data is required for {self.model_name}."
             self.logger.error(msg)
-            raise
+            raise AttributeError(msg)
 
-
-        self.
-
-
-
-
-
+        self.ploidy = self.cfg.io.ploidy
+        self.is_haploid_ = self.ploidy == 1
+
+        if self.ploidy > 2:
+            msg = (
+                f"{self.model_name} currently supports only haploid (1) or diploid (2) "
+                f"data; got ploidy={self.ploidy}."
             )
-
-
-        # Scoring still uses 3 labels for diploid (REF/HET/ALT); model head uses 2 logits
-        self.num_classes_ = 2 if self.is_haploid else 3
-        self.output_classes_ = 2
-        self.logger.info(
-            f"Data is {'haploid' if self.is_haploid else 'diploid'}; "
-            f"using {self.num_classes_} classes for scoring and {self.output_classes_} output channels."
-        )
+            self.logger.error(msg)
+            raise ValueError(msg)
 
-        if self.
-            self.ground_truth_[self.ground_truth_ == 2] = 1
-            X_for_model[X_for_model == 2] = 1
+        self.num_classes_ = 2 if self.is_haploid_ else 3
 
-
+        # Clean 0/1/2 ground truth (missing=-1)
+        gt_full = self.pgenc.genotypes_012.copy()
+        gt_full[gt_full < 0] = -1
+        gt_full = np.nan_to_num(gt_full, nan=-1.0)
+        self.ground_truth_ = gt_full.astype(np.int8)
+        self.num_features_ = int(self.ground_truth_.shape[1])
 
-        # Model params (decoder outputs L * K logits)
         self.model_params = {
             "n_features": self.num_features_,
-            "num_classes": self.
+            "num_classes": self.num_classes_,
             "latent_dim": self.latent_dim,
             "dropout_rate": self.dropout_rate,
             "activation": self.activation,
         }
 
-        #
-
-
-            indices, test_size=self.validation_split, random_state=self.seed
+        # Simulate missingness ONCE on the full matrix
+        X_for_model_full, self.sim_mask_, self.orig_mask_ = self.sim_missing_transform(
+            self.ground_truth_
         )
-        self.train_idx_, self.test_idx_ = train_idx, val_idx
-        self.X_train_ = X_for_model[train_idx]
-        self.X_val_ = X_for_model[val_idx]
-        self.GT_train_full_ = self.ground_truth_[train_idx]
-        self.GT_test_full_ = self.ground_truth_[val_idx]
-
-        if self.sim_mask_global_ is not None:
-            self.sim_mask_train_ = self.sim_mask_global_[train_idx]
-            self.sim_mask_test_ = self.sim_mask_global_[val_idx]
-        else:
-            self.sim_mask_train_ = None
-            self.sim_mask_test_ = None
 
-        #
-
-        self.
-
-
+        # Split indices based on clean ground truth
+        self.train_idx_, self.val_idx_, self.test_idx_ = self._train_val_test_split(
+            self.ground_truth_
+        )
+
+        # --- Clean targets per split ---
+        X_train_clean = self.ground_truth_[self.train_idx_].copy()
+        X_val_clean = self.ground_truth_[self.val_idx_].copy()
+        X_test_clean = self.ground_truth_[self.test_idx_].copy()
+
+        # --- Corrupted inputs per split (from the single simulation) ---
+        X_train_corrupted = X_for_model_full[self.train_idx_].copy()
+        X_val_corrupted = X_for_model_full[self.val_idx_].copy()
+        X_test_corrupted = X_for_model_full[self.test_idx_].copy()
+
+        # --- Masks per split ---
+        self.sim_mask_train_ = self.sim_mask_[self.train_idx_].copy()
+        self.sim_mask_val_ = self.sim_mask_[self.val_idx_].copy()
+        self.sim_mask_test_ = self.sim_mask_[self.test_idx_].copy()
+
+        self.orig_mask_train_ = self.orig_mask_[self.train_idx_].copy()
+        self.orig_mask_val_ = self.orig_mask_[self.val_idx_].copy()
+        self.orig_mask_test_ = self.orig_mask_[self.test_idx_].copy()
+
+        # Persist per-split matrices
+        self.X_train_clean_ = X_train_clean
+        self.X_val_clean_ = X_val_clean
+        self.X_test_clean_ = X_test_clean
+
+        self.X_train_corrupted_ = X_train_corrupted
+        self.X_val_corrupted_ = X_val_corrupted
+        self.X_test_corrupted_ = X_test_corrupted
+
+        # Haploid harmonization (do NOT resimulate; just recode values)
+        if self.is_haploid_:
+
+            def _haploidize(arr: np.ndarray) -> np.ndarray:
+                out = arr.copy()
+                miss = out < 0
+                out = np.where(out > 0, 1, out).astype(np.int8, copy=False)
+                out[miss] = -1
+                return out
+
+            self.X_train_clean_ = _haploidize(self.X_train_clean_)
+            self.X_val_clean_ = _haploidize(self.X_val_clean_)
+            self.X_test_clean_ = _haploidize(self.X_test_clean_)
+
+            self.X_train_corrupted_ = _haploidize(self.X_train_corrupted_)
+            self.X_val_corrupted_ = _haploidize(self.X_val_corrupted_)
+            self.X_test_corrupted_ = _haploidize(self.X_test_corrupted_)
+
+        # Convention: X_* are corrupted inputs; y_* are clean targets
+        self.X_train_ = self.X_train_corrupted_
+        self.y_train_ = self.X_train_clean_
+
+        self.X_val_ = self.X_val_corrupted_
+        self.y_val_ = self.X_val_clean_
+
+        self.X_test_ = self.X_test_corrupted_
+        self.y_test_ = self.X_test_clean_
+
+        # One-hot for loaders/model input
+        X_train_ohe = self._one_hot_encode_012(
+            self.X_train_, num_classes=self.num_classes_
+        )
+        X_val_ohe = self._one_hot_encode_012(self.X_val_, num_classes=self.num_classes_)
 
-        # Plotters/scorers (
+        # Plotters/scorers + valid-class mask repairs (copied from VAE flow)
         self.plotter_, self.scorers_ = self.initialize_plotting_and_scorers()
+        self.valid_class_mask_ = self._build_valid_class_mask()
+
+        loci = getattr(self, "valid_class_mask_conflict_loci_", None)
+        if loci is not None and loci.size:
+            self._repair_ref_alt_from_iupac(loci)
+            self.valid_class_mask_ = self._build_valid_class_mask()
+
+        train_loader = self._get_data_loaders(
+            X_train_ohe.detach().cpu().numpy(),
+            self.y_train_,
+            ~self.orig_mask_train_,
+            self.batch_size,
+            shuffle=True,
+        )
+        val_loader = self._get_data_loaders(
+            X_val_ohe.detach().cpu().numpy(),
+            self.y_val_,
+            ~self.orig_mask_val_,
+            self.batch_size,
+            shuffle=False,
+        )
+
+        self.train_loader_ = train_loader
+        self.val_loader_ = val_loader
 
-        #
+        # Hyperparameter tuning or fixed run
         if self.tune:
-            self.tune_hyperparameters()
+            self.tuned_params_ = self.tune_hyperparameters()
+            self.model_tuned_ = True
+        else:
+            self.model_tuned_ = False
+            self.class_weights_ = self._class_weights_from_zygosity(
+                self.y_train_,
+                train_mask=self.sim_mask_train_ & ~self.orig_mask_train_,
+                inverse=self.inverse,
+                normalize=self.normalize,
+                max_ratio=self.max_ratio,
+                power=self.power,
+            )
+            self.tuned_params_ = {
+                "latent_dim": self.latent_dim,
+                "learning_rate": self.learning_rate,
+                "dropout_rate": self.dropout_rate,
+                "num_hidden_layers": self.num_hidden_layers,
+                "activation": self.activation,
+                "l1_penalty": self.l1_penalty,
+                "layer_scaling_factor": self.layer_scaling_factor,
+                "layer_schedule": self.layer_schedule,
+                "gamma": self.gamma,
+                "gamma_schedule": self.gamma_schedule,
+                "inverse": self.inverse,
+                "normalize": self.normalize,
+                "power": self.power,
+            }
+            self.tuned_params_["model_params"] = self.model_params
 
-
-
+        if self.class_weights_ is not None:
+            self.logger.info(
+                f"class_weights={self.class_weights_.detach().cpu().numpy().tolist()}"
+            )
 
-        #
-        self.
-            self._class_weights_from_zygosity(self.X_train_)
-        )
+        # Always start clean
+        self.best_params_ = copy.deepcopy(self.tuned_params_)
 
-        #
-
+        # Final model params (compute hidden sizes using n_inputs=L*K, mirroring VAE)
+        input_dim = int(self.num_features_ * self.num_classes_)
+        model_params_final = {
+            "n_features": int(self.num_features_),
+            "num_classes": int(self.num_classes_),
+            "latent_dim": int(self.best_params_["latent_dim"]),
+            "dropout_rate": float(self.best_params_["dropout_rate"]),
+            "activation": str(self.best_params_["activation"]),
+        }
+        model_params_final["hidden_layer_sizes"] = self._compute_hidden_layer_sizes(
+            n_inputs=input_dim,
+            n_outputs=int(self.num_classes_),
+            n_samples=len(self.train_idx_),
+            n_hidden=int(self.best_params_["num_hidden_layers"]),
+            latent_dim=int(self.best_params_["latent_dim"]),
+            alpha=float(self.best_params_["layer_scaling_factor"]),
+            schedule=str(self.best_params_["layer_schedule"]),
+            min_size=max(16, 2 * int(self.best_params_["latent_dim"])),
+        )
+        self.best_params_["model_params"] = model_params_final
 
-        # Build
-        model = self.build_model(self.Model, self.best_params_)
+        # Build and train
+        model = self.build_model(self.Model, self.best_params_["model_params"])
         model.apply(self.initialize_weights)
 
+        if self.verbose or self.debug:
+            self.logger.info("Using model hyperparameters:")
+            pm = PrettyMetrics(
+                self.best_params_, precision=3, title="Model Hyperparameters"
+            )
+            pm.render()
+
+        lr_final = float(self.best_params_["learning_rate"])
+        l1_final = float(self.best_params_["l1_penalty"])
+        gamma_schedule = bool(
+            self.best_params_.get("gamma_schedule", self.gamma_schedule)
+        )
+
         loss, trained_model, history = self._train_and_validate_model(
             model=model,
-
-
-            l1_penalty=self.l1_penalty,
-            return_history=True,
-            class_weights=self.class_weights_,
-            X_val=self.X_val_,
+            lr=lr_final,
+            l1_penalty=l1_final,
             params=self.best_params_,
-
-
-
-            eval_requires_latents=False,
-            eval_latent_steps=0,
-            eval_latent_lr=0.0,
-            eval_latent_weight_decay=0.0,
+            trial=None,
+            class_weights=getattr(self, "class_weights_", None),
+            gamma_schedule=gamma_schedule,
         )
 
         if trained_model is None:
-            msg = "
+            msg = f"{self.model_name} training failed."
             self.logger.error(msg)
             raise RuntimeError(msg)
 
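The rewritten `fit()` above leans on two boolean masks with fixed roles: `orig_mask_*` marks genotypes missing in the raw data, `sim_mask_*` marks genotypes hidden by the one-time simulation; loss uses `~orig_mask` and scoring uses `sim_mask & ~orig_mask`. A small NumPy sketch of that convention (shapes and proportions are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
gt = rng.integers(0, 3, size=(4, 6)).astype(np.int8)   # clean 0/1/2 matrix
gt[0, 1] = -1                                          # originally missing value
orig_mask = gt < 0

sim_mask = rng.random(gt.shape) < 0.2                  # simulate extra missingness once
X_corrupted = gt.copy()
X_corrupted[sim_mask | orig_mask] = -1                 # the model never sees these values

loss_mask = ~orig_mask                                 # train/val loss: targets must be known
eval_mask = sim_mask & ~orig_mask                      # scoring: only simulated-missing sites
print(loss_mask.sum(), eval_mask.sum())
```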
@@ -429,217 +546,194 @@ class ImputeAutoencoder(BaseNNImputer):
             self.models_dir / f"final_model_{self.model_name}.pt",
         )
 
-
-            "Train":
-
-
+        if history is None:
+            hist = {"Train": []}
+        elif isinstance(history, dict):
+            hist = dict(history)
+        else:
+            hist = {"Train": list(history["Train"]), "Val": list(history["Val"])}
+
+        self.best_loss_ = float(loss)
+        self.model_ = trained_model
+        self.history_ = hist
         self.is_fit_ = True
 
-        # Evaluate on
-        eval_mask = (
-            self.sim_mask_test_
-            if (self.simulate_missing and self.sim_mask_test_ is not None)
-            else None
-        )
+        # Evaluate on simulated-missing sites only
         self._evaluate_model(
-            self.
+            self.model_,
+            X=self.X_test_,
+            y=self.y_test_,
+            eval_mask=self.sim_mask_test_ & ~self.orig_mask_test_,
+            objective_mode=False,
         )
-
+
+        if self.show_plots:
+            self.plotter_.plot_history(self.history_)
+
         self._save_best_params(self.best_params_)
 
+        if self.model_tuned_:
+            title = f"{self.model_name} Optimized Parameters"
+
+            if self.verbose or self.debug:
+                pm = PrettyMetrics(self.best_params_, precision=2, title=title)
+                pm.render()
+
+            # Save best parameters to a JSON file.
+            self._save_best_params(self.best_params_, objective_mode=True)
+
         return self
 
     def transform(self) -> np.ndarray:
-        """Impute missing genotypes
+        """Impute missing genotypes and return IUPAC strings.
 
-        This method
+        This method performs the following steps:
+        1. Validates that the model has been fitted.
+        2. Uses the trained model to predict missing genotypes for the entire dataset.
+        3. Fills in the missing genotypes in the original dataset with the predicted values from the model.
+        4. Decodes the imputed genotype matrix from 0/1/2 encoding to IUPAC strings.
+        5. Checks for any remaining missing values or decoding issues, raising errors if found.
+        6. Optionally generates and displays plots comparing the original and imputed genotype distributions.
+        7. Returns the imputed IUPAC genotype matrix.
 
         Returns:
-            np.ndarray: IUPAC
+            np.ndarray: IUPAC genotype matrix of shape (n_samples, n_loci).
 
         Raises:
             NotFittedError: If called before fit().
+            RuntimeError: If any missing values remain or decoding yields "N".
+            RuntimeError: If loci contain 'N' after imputation due to missing REF/ALT metadata.
         """
         if not getattr(self, "is_fit_", False):
-
+            msg = "Model is not fitted. Call fit() before transform()."
+            self.logger.error(msg)
+            raise NotFittedError(msg)
 
-        self.logger.info(f"Imputing entire dataset with {self.model_name}...")
+        self.logger.info(f"Imputing entire dataset with {self.model_name} model...")
         X_to_impute = self.ground_truth_.copy()
 
-
-        pred_labels, _ = self._predict(self.model_, X=X_to_impute, return_proba=True)
+        pred_labels, _ = self._predict(self.model_, X=X_to_impute)
 
-
-        missing_mask = X_to_impute == -1
+        missing_mask = X_to_impute < 0
         imputed_array = X_to_impute.copy()
         imputed_array[missing_mask] = pred_labels[missing_mask]
 
-
-
+        if np.any(imputed_array < 0):
+            msg = f"[{self.model_name}] Some missing genotypes remain after imputation. This is unexpected."
+            self.logger.error(msg)
+            raise RuntimeError(msg)
+
+        decode_input = imputed_array
+        if self.is_haploid_:
+            decode_input = imputed_array.copy()
+            decode_input[decode_input == 1] = 2
+
+        imputed_genotypes = self.decode_012(decode_input)
+
+        bad_loci = np.where((imputed_genotypes == "N").any(axis=0))[0]
+        if bad_loci.size > 0:
+            msg = f"[{self.model_name}] {bad_loci.size} loci contain 'N' after imputation (e.g., first 10 indices: {bad_loci[:10].tolist()}). This occurs when REF/ALT metadata is missing and cannot be inferred from the source data (e.g., loci with 100 percent missing genotypes). Try filtering out these loci before imputation."
+            self.logger.error(msg)
+            self.logger.debug(
+                "All loci with 'N': " + ", ".join(map(str, bad_loci.tolist()))
+            )
+            raise RuntimeError(msg)
+
         if self.show_plots:
-
+            original_input = X_to_impute
+            if self.is_haploid_:
+                original_input = X_to_impute.copy()
+                original_input[original_input == 1] = 2
+
+            original_genotypes = self.decode_012(original_input)
+
             plt.rcParams.update(self.plotter_.param_dict)
             self.plotter_.plot_gt_distribution(original_genotypes, is_imputed=False)
             self.plotter_.plot_gt_distribution(imputed_genotypes, is_imputed=True)
 
         return imputed_genotypes
 
-    def _get_data_loaders(self, y: np.ndarray) -> torch.utils.data.DataLoader:
-        """Create DataLoader over indices + integer targets (-1 for missing).
-
-        This method creates a PyTorch DataLoader that yields batches of indices and their corresponding genotype targets encoded as integers (0, 1, 2) with -1 indicating missing values. The DataLoader is shuffled to ensure random sampling during training.
-
-        Args:
-            y (np.ndarray): 0/1/2 matrix with -1 for missing.
-
-        Returns:
-            torch.utils.data.DataLoader: Shuffled DataLoader.
-        """
-        y_tensor = torch.from_numpy(y).long()
-        indices = torch.arange(len(y), dtype=torch.long)
-        dataset = torch.utils.data.TensorDataset(indices, y_tensor)
-        pin_memory = self.device.type == "cuda"
-        return torch.utils.data.DataLoader(
-            dataset,
-            batch_size=self.batch_size,
-            shuffle=True,
-            pin_memory=pin_memory,
-        )
-
     def _train_and_validate_model(
         self,
         model: torch.nn.Module,
-
+        *,
         lr: float,
         l1_penalty: float,
-        trial: optuna.Trial
-
-        class_weights: torch.Tensor
-
-
-
-
-
-        eval_interval: int = 1,
-        # Evaluation parameters (AE ignores latent refinement knobs)
-        eval_requires_latents: bool = False,  # AE: always False
-        eval_latent_steps: int = 0,
-        eval_latent_lr: float = 0.0,
-        eval_latent_weight_decay: float = 0.0,
-    ) -> Tuple[float, torch.nn.Module | None, list | None]:
-        """Wrap the AE training loop (no latent optimizer), with Optuna pruning.
-
-        This method orchestrates the training of the autoencoder model using the provided DataLoader. It sets up the optimizer and learning rate scheduler, and executes the training loop with support for early stopping and Optuna pruning based on validation performance. The method returns the best validation loss, the best model state, and optionally the training history.
+        trial: Optional[optuna.Trial] = None,
+        params: Optional[dict[str, Any]] = None,
+        class_weights: Optional[torch.Tensor] = None,
+        gamma_schedule: bool = False,
+    ) -> tuple[float, torch.nn.Module, dict[str, list[float]]]:
+        """Train and validate the model.
+
+        This method sets up the optimizer and learning rate scheduler, then executes the training loop with early stopping and optional hyperparameter tuning via Optuna. It returns the best validation loss, the best model, and the training history.
 
         Args:
             model (torch.nn.Module): Autoencoder model.
-            loader (torch.utils.data.DataLoader): Batches (indices, y_int) where y_int is 0/1/2; -1 for missing.
             lr (float): Learning rate.
-            l1_penalty (float): L1 regularization
-            trial (optuna.Trial
-
-            class_weights (torch.Tensor
-
-            params (dict | None): Model params for evaluation.
-            prune_metric (str): Metric for pruning reports.
-            prune_warmup_epochs (int): Pruning warmup epochs.
-            eval_interval (int): Eval frequency (epochs).
-            eval_requires_latents (bool): Ignored for AE (no latent inference).
-            eval_latent_steps (int): Unused for AE.
-            eval_latent_lr (float): Unused for AE.
-            eval_latent_weight_decay (float): Unused for AE.
+            l1_penalty (float): L1 regularization coefficient.
+            trial (Optional[optuna.Trial]): Optuna trial (optional).
+            params (Optional[dict[str, Any]]): Hyperparams dict (optional).
+            class_weights (Optional[torch.Tensor]): Class weights for focal CE (optional).
+            gamma_schedule (bool): Whether to schedule gamma.
 
         Returns:
-
+            tuple[float, torch.nn.Module, dict[str, list[float]]]: Best validation loss, best model, history.
         """
-
-
-            self.logger.error(msg)
-            raise TypeError(msg)
+        max_epochs = int(self.epochs)
+        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
 
-
-
-            self.tune_epochs if (trial is not None and self.tune_fast) else self.epochs
+        scheduler = _make_warmup_cosine_scheduler(
+            optimizer, max_epochs=max_epochs, warmup_epochs=int(0.1 * max_epochs)
         )
 
-        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
-        scheduler = CosineAnnealingLR(optimizer, T_max=max_epochs)
-
         best_loss, best_model, hist = self._execute_training_loop(
-            loader=loader,
             optimizer=optimizer,
             scheduler=scheduler,
             model=model,
             l1_penalty=l1_penalty,
             trial=trial,
-            return_history=return_history,
-            class_weights=class_weights,
-            X_val=X_val,
             params=params,
-
-
-            eval_interval=eval_interval,
-            eval_requires_latents=False,  # AE: no latent inference
-            eval_latent_steps=0,
-            eval_latent_lr=0.0,
-            eval_latent_weight_decay=0.0,
+            class_weights=class_weights,
+            gamma_schedule=gamma_schedule,
        )
-
-            return best_loss, best_model, hist
-
-        return best_loss, best_model, None
+        return best_loss, best_model, hist
 
     def _execute_training_loop(
         self,
-
+        *,
         optimizer: torch.optim.Optimizer,
-        scheduler:
+        scheduler: (
+            torch.optim.lr_scheduler.CosineAnnealingLR
+            | torch.optim.lr_scheduler.SequentialLR
+        ),
         model: torch.nn.Module,
         l1_penalty: float,
-        trial: optuna.Trial
-
-        class_weights: torch.Tensor,
-
-
-
-        prune_metric: str = "f1",
-        prune_warmup_epochs: int = 10,
-        eval_interval: int = 1,
-        # Evaluation parameters (AE ignores latent refinement knobs)
-        eval_requires_latents: bool = False,  # AE: False
-        eval_latent_steps: int = 0,
-        eval_latent_lr: float = 0.0,
-        eval_latent_weight_decay: float = 0.0,
-    ) -> Tuple[float, torch.nn.Module, list]:
-        """Train AE with focal CE (gamma warm/ramp) + early stopping & pruning.
-
-        This method executes the training loop for the autoencoder model, performing one epoch at a time. It computes the focal cross-entropy loss while ignoring masked (missing) values and applies L1 regularization if specified. The method incorporates early stopping based on validation performance and supports Optuna pruning to terminate unpromising trials early. It returns the best validation loss, the best model state, and optionally the training history.
+        trial: Optional[optuna.Trial] = None,
+        params: Optional[dict[str, Any]] = None,
+        class_weights: Optional[torch.Tensor] = None,
+        gamma_schedule: bool = False,
+    ) -> tuple[float, torch.nn.Module, dict[str, list[float]]]:
+        """Train AE (masked focal CE) with EarlyStopping + Optuna pruning.
 
         Args:
-
-
-            scheduler (torch.optim.lr_scheduler._LRScheduler): LR scheduler.
+            optimizer (torch.optim.Optimizer): Optimizer for training.
+            scheduler (torch.optim.lr_scheduler.CosineAnnealingLR | torch.optim.lr_scheduler.SequentialLR): LR scheduler.
             model (torch.nn.Module): Autoencoder model.
-            l1_penalty (float): L1 regularization
-            trial (optuna.Trial
-
-            class_weights (torch.Tensor): Class weights
-
-            params (dict | None): Model params for evaluation.
-            prune_metric (str): Metric for pruning reports.
-            prune_warmup_epochs (int): Pruning warmup epochs.
-            eval_interval (int): Eval frequency (epochs).
-            eval_requires_latents (bool): Ignored for AE (no latent inference).
-            eval_latent_steps (int): Unused for AE.
-            eval_latent_lr (float): Unused for AE.
-            eval_latent_weight_decay (float): Unused for AE.
+            l1_penalty (float): L1 regularization coefficient.
+            trial (Optional[optuna.Trial]): Optuna trial (optional).
+            params (Optional[dict[str, Any]]): Hyperparams dict (optional).
+            class_weights (Optional[torch.Tensor]): Class weights for focal CE (optional).
+            gamma_schedule (bool): Whether to schedule gamma.
 
         Returns:
-
+            tuple[float, torch.nn.Module, dict[str, list[float]]]: Best loss, best model, and training history.
+
+        Notes:
+            - Computes loss only where targets are known (~orig_mask_*).
+            - Evaluates metrics only on simulated-missing sites (sim_mask_*).
         """
-
-        best_model = None
-        history: list[float] = []
+        history: dict[str, list[float]] = defaultdict(list)
 
         early_stopping = EarlyStopping(
             patience=self.early_stop_gen,
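The new `transform()` recodes haploid calls (1 becomes 2) and then decodes 0/1/2 to IUPAC letters via the package's `decode_012`, which is not shown in this diff. As an illustrative stand-in only, a single biallelic site decodes roughly like this, with the REF/ALT alleles chosen arbitrarily for the example:

```python
# Standard IUPAC ambiguity codes for heterozygous calls.
IUPAC_HET = {
    frozenset("AG"): "R", frozenset("CT"): "Y", frozenset("AC"): "M",
    frozenset("GT"): "K", frozenset("AT"): "W", frozenset("CG"): "S",
}

def decode_site(code: int, ref: str, alt: str) -> str:
    # 0 = homozygous REF, 2 = homozygous ALT, 1 = heterozygous, else missing.
    if code == 0:
        return ref
    if code == 2:
        return alt
    if code == 1:
        return IUPAC_HET[frozenset(ref + alt)]
    return "N"

print([decode_site(c, "A", "G") for c in (0, 1, 2, -1)])  # ['A', 'R', 'G', 'N']
```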
@@ -649,157 +743,157 @@ class ImputeAutoencoder(BaseNNImputer):
             debug=self.debug,
         )
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            model.gamma = 0.0  # type: ignore
-        elif epoch < gamma_warm + gamma_ramp:
-            model.gamma = gamma_final * ((epoch - gamma_warm) / gamma_ramp)  # type: ignore
+        gamma_target, gamma_warm, gamma_ramp = self._anneal_config(
+            params, "gamma", default=self.gamma, max_epochs=self.epochs
+        )
+        gamma_target = float(gamma_target)
+
+        cw = class_weights
+        if cw is not None and cw.device != self.device:
+            cw = cw.to(self.device)
+
+        for epoch in range(int(self.epochs)):
+            if gamma_schedule:
+                gamma_current = self._update_anneal_schedule(
+                    gamma_target,
+                    warm=gamma_warm,
+                    ramp=gamma_ramp,
+                    epoch=epoch,
+                    init_val=0.0,
+                )
+                gamma_val = float(gamma_current)
             else:
-
+                gamma_val = gamma_target
 
-
-
-
-            for g in optimizer.param_groups:
-                g["lr"] = min_lr + (base_lr - min_lr) * scale
+            ce_criterion = FocalCELoss(
+                alpha=cw, gamma=gamma_val, ignore_index=-1, reduction="mean"
+            )
 
             train_loss = self._train_step(
-                loader=
+                loader=self.train_loader_,
                 optimizer=optimizer,
                 model=model,
+                ce_criterion=ce_criterion,
                 l1_penalty=l1_penalty,
-                class_weights=class_weights,
             )
 
-            # Abort or prune on non-finite epoch loss
             if not np.isfinite(train_loss):
                 if trial is not None:
-
-
-
-
-                    )
-
-
-
+                    msg = f"[{self.model_name}] Trial {trial.number} training loss non-finite."
+                    self.logger.warning(msg)
+                    raise optuna.exceptions.TrialPruned(msg)
+                msg = f"[{self.model_name}] Training loss is non-finite at epoch {epoch + 1}."
+                self.logger.error(msg)
+                raise RuntimeError(msg)
+
+            val_loss = self._val_step(
+                loader=self.val_loader_,
+                model=model,
+                ce_criterion=ce_criterion,
+                l1_penalty=l1_penalty,
+            )
 
             scheduler.step()
-
-
+            history["Train"].append(float(train_loss))
+            history["Val"].append(float(val_loss))
 
-            early_stopping(
+            early_stopping(val_loss, model)
             if early_stopping.early_stop:
-                self.logger.
+                self.logger.debug(
+                    f"[{self.model_name}] Early stopping at epoch {epoch + 1}."
+                )
                 break
 
-
-
-                trial is not None
-                and X_val is not None
-                and ((epoch + 1) % eval_interval == 0)
-            ):
-                metric_key = prune_metric or getattr(self, "tune_metric", "f1")
-                mask_override = None
-                if (
-                    self.simulate_missing
-                    and getattr(self, "sim_mask_test_", None) is not None
-                    and getattr(self, "X_val_", None) is not None
-                    and X_val.shape == self.X_val_.shape
-                ):
-                    mask_override = self.sim_mask_test_
-                metric_val = self._eval_for_pruning(
+            if trial is not None:
+                metric_vals = self._evaluate_model(
                     model=model,
-
-
-
+                    X=self.X_val_,
+                    y=self.y_val_,
+                    eval_mask=self.sim_mask_val_ & ~self.orig_mask_val_,
                     objective_mode=True,
-                    do_latent_infer=False,  # AE: False
-                    latent_steps=0,
-                    latent_lr=0.0,
-                    latent_weight_decay=0.0,
-                    latent_seed=self.seed,  # type: ignore
-                    _latent_cache=None,  # AE: not used
-                    _latent_cache_key=None,
-                    eval_mask_override=mask_override,
                 )
-                trial.report(
-                if
+                trial.report(metric_vals[self.tune_metric], step=epoch + 1)
+                if trial.should_prune():
                     raise optuna.exceptions.TrialPruned(
-                        f"
+                        f"[{self.model_name}] Trial {trial.number} pruned at epoch {epoch + 1}."
                     )
 
-        best_loss = early_stopping.best_score
-        if early_stopping.
-
-
-
-        return best_loss, best_model, history
+        best_loss = float(early_stopping.best_score)
+        if early_stopping.best_state_dict is not None:
+            model.load_state_dict(early_stopping.best_state_dict)
+
+        return best_loss, model, dict(history)
 
     def _train_step(
         self,
         loader: torch.utils.data.DataLoader,
         optimizer: torch.optim.Optimizer,
         model: torch.nn.Module,
+        ce_criterion: torch.nn.Module,
+        *,
         l1_penalty: float,
-        class_weights: torch.Tensor,
     ) -> float:
-        """
+        """Single epoch train step (masked focal CE + optional L1).
+
+        Args:
+            loader (torch.utils.data.DataLoader): Training data loader.
+            optimizer (torch.optim.Optimizer): Optimizer for training.
+            model (torch.nn.Module): Autoencoder model.
+            ce_criterion (torch.nn.Module): Cross-entropy loss function.
+            l1_penalty (float): L1 regularization coefficient.
+
+        Returns:
+            float: Average training loss over the epoch.
+
+        Notes:
+            Expects loader batches as (X_ohe, y_int, mask_bool) where:
+            - X_ohe: (B, L, C) float/compatible
+            - y_int: (B, L) int, with -1 for unknown targets
+            - mask_bool: (B, L) bool selecting which positions contribute to loss
+        """
         model.train()
         running = 0.0
         num_batches = 0
+
+        nF_model = int(getattr(model, "n_features", self.num_features_))
+        nC_model = int(getattr(model, "num_classes", self.num_classes_))
         l1_params = tuple(p for p in model.parameters() if p.requires_grad)
-        if class_weights is not None and class_weights.device != self.device:
-            class_weights = class_weights.to(self.device)
-
-        # Use model.gamma if present, else self.gamma
-        gamma = float(getattr(model, "gamma", getattr(self, "gamma", 0.0)))
-        gamma = float(torch.tensor(gamma).clamp(min=0.0, max=10.0))  # sane bound
-        ce_criterion = SafeFocalCELoss(
-            gamma=gamma, weight=class_weights, ignore_index=-1
-        )
 
-        for
+        for X_batch, y_batch, m_batch in loader:
             optimizer.zero_grad(set_to_none=True)
-
-
-
-
-
-
-
-
-
-
-
-
-            x_in = self._encode_multilabel_inputs(y_batch)  # (B, L, 2)
-            logits = model(x_in).view(-1, self.num_features_, self.output_classes_)
-            if not torch.isfinite(logits).all():
-                continue
-            pos_w = getattr(self, "pos_weights_", None)
-            targets = self._multi_hot_targets(y_batch)  # float, same shape
-            bce = F.binary_cross_entropy_with_logits(
-                logits, targets, pos_weight=pos_w, reduction="none"
+            X_batch = X_batch.to(self.device, non_blocking=True).float()
+            y_batch = y_batch.to(self.device, non_blocking=True).long()
+            m_batch = m_batch.to(self.device, non_blocking=True).bool()
+
+            if (
+                X_batch.dim() != 3
+                or X_batch.shape[1] != nF_model
+                or X_batch.shape[2] != nC_model
+            ):
+                msg = (
+                    f"Train batch X shape mismatch: expected (B,{nF_model},{nC_model}), "
+                    f"got {tuple(X_batch.shape)}."
                 )
-
-
+                self.logger.error(msg)
+                raise ValueError(msg)
+
+            logits_flat = model(X_batch)
+            expected = (X_batch.shape[0], nF_model * nC_model)
+            if logits_flat.dim() != 2 or tuple(logits_flat.shape) != expected:
+                try:
+                    logits_flat = logits_flat.view(*expected)
+                except Exception as e:
+                    msg = f"Model logits expected shape {expected}, got {tuple(logits_flat.shape)}."
+                    self.logger.error(msg)
+                    raise ValueError(msg) from e
+
+            logits = logits_flat.view(-1, nF_model, nC_model)
+            logits_masked = logits.view(-1, nC_model)[m_batch.view(-1)]
+
+            targets_masked = y_batch.view(-1)
+            targets_masked = targets_masked[m_batch.view(-1)]
+
+            loss = ce_criterion(logits_masked, targets_masked)
 
             if l1_penalty > 0:
                 l1 = torch.zeros((), device=self.device)
@@ -807,247 +901,234 @@ class ImputeAutoencoder(BaseNNImputer):
|
|
|
807
901
|
l1 = l1 + p.abs().sum()
|
|
808
902
|
loss = loss + l1_penalty * l1
|
|
809
903
|
|
|
810
|
-
# Final guard
|
|
811
904
|
if not torch.isfinite(loss):
|
|
812
905
|
continue
|
|
813
906
|
|
|
814
907
|
loss.backward()
|
|
815
|
-
|
|
816
|
-
# Clip to prevent exploding grads
|
|
817
908
|
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
|
|
818
|
-
|
|
819
|
-
# If grads blew up to non-finite, skip update
|
|
820
|
-
if any(
|
|
821
|
-
(not torch.isfinite(p.grad).all())
|
|
822
|
-
for p in model.parameters()
|
|
823
|
-
if p.grad is not None
|
|
824
|
-
):
|
|
825
|
-
optimizer.zero_grad(set_to_none=True)
|
|
826
|
-
continue
|
|
827
|
-
|
|
828
909
|
optimizer.step()
|
|
829
910
|
|
|
830
911
|
running += float(loss.detach().item())
|
|
831
912
|
num_batches += 1
|
|
832
913
|
|
|
833
|
-
if num_batches == 0
|
|
834
|
-
|
|
835
|
-
|
|
914
|
+
return float("inf") if num_batches == 0 else running / num_batches
|
|
915
|
+
|
|
916
|
+
def _val_step(
|
|
917
|
+
self,
|
|
918
|
+
loader: torch.utils.data.DataLoader,
|
|
919
|
+
model: torch.nn.Module,
|
|
920
|
+
ce_criterion: torch.nn.Module,
|
|
921
|
+
*,
|
|
922
|
+
l1_penalty: float,
|
|
923
|
+
) -> float:
|
|
924
|
+
"""Validation step (masked focal CE + optional L1).
|
|
925
|
+
|
|
926
|
+
Args:
|
|
927
|
+
loader (torch.utils.data.DataLoader): Validation data loader.
|
|
928
|
+
model (torch.nn.Module): Autoencoder model.
|
|
929
|
+
ce_criterion (torch.nn.Module): Cross-entropy loss function.
|
|
930
|
+
l1_penalty (float): L1 regularization coefficient.
|
|
931
|
+
|
|
932
|
+
Returns:
|
|
933
|
+
float: Average validation loss over the epoch.
|
|
934
|
+
"""
|
|
935
|
+
model.eval()
|
|
936
|
+
running = 0.0
|
|
937
|
+
num_batches = 0
|
|
938
|
+
|
|
939
|
+
nF_model = self.num_features_
|
|
940
|
+
nC_model = self.num_classes_
|
|
941
|
+
l1_params = tuple(p for p in model.parameters() if p.requires_grad)
|
|
942
|
+
|
|
943
|
+
with torch.no_grad():
|
|
944
|
+
for X_batch, y_batch, m_batch in loader:
|
|
945
|
+
X_batch = X_batch.to(self.device, non_blocking=True).float()
|
|
946
|
+
y_batch = y_batch.to(self.device, non_blocking=True).long()
|
|
947
|
+
m_batch = m_batch.to(self.device, non_blocking=True).bool()
|
|
948
|
+
|
|
949
|
+
logits_flat = model(X_batch)
|
|
950
|
+
expected = (X_batch.shape[0], nF_model * nC_model)
|
|
951
|
+
|
|
952
|
+
if logits_flat.dim() != 2 or tuple(logits_flat.shape) != expected:
|
|
953
|
+
try:
|
|
954
|
+
logits_flat = logits_flat.view(*expected)
|
|
955
|
+
except Exception as e:
|
|
956
|
+
msg = f"Model logits expected shape {expected}, got {tuple(logits_flat.shape)}."
|
|
957
|
+
self.logger.error(msg)
|
|
958
|
+
raise ValueError(msg) from e
|
|
959
|
+
|
|
960
|
+
logits = logits_flat.view(-1, nF_model, nC_model)
|
|
961
|
+
logits_masked = logits.view(-1, nC_model)[m_batch.view(-1)]
|
|
962
|
+
targets_masked = y_batch.view(-1)[m_batch.view(-1)]
|
|
963
|
+
|
|
964
|
+
if targets_masked.numel() == 0:
|
|
965
|
+
continue
|
|
966
|
+
|
|
967
|
+
loss = ce_criterion(logits_masked, targets_masked)
|
|
968
|
+
|
|
969
|
+
if l1_penalty > 0:
|
|
970
|
+
l1 = torch.zeros((), device=self.device)
|
|
971
|
+
for p in l1_params:
|
|
972
|
+
l1 = l1 + p.abs().sum()
|
|
973
|
+
loss = loss + l1_penalty * l1
|
|
974
|
+
|
|
975
|
+
if not torch.isfinite(loss):
|
|
976
|
+
continue
|
|
977
|
+
|
|
978
|
+
running += float(loss.item())
|
|
979
|
+
num_batches += 1
|
|
980
|
+
|
|
981
|
+
return float("inf") if num_batches == 0 else running / num_batches
|
|
836
982
|
|
|
837
983
|
def _predict(
|
|
838
984
|
self,
|
|
839
985
|
model: torch.nn.Module,
|
|
840
986
|
X: np.ndarray | torch.Tensor,
|
|
987
|
+
*,
|
|
841
988
|
return_proba: bool = False,
|
|
842
|
-
) ->
|
|
843
|
-
"""Predict
|
|
844
|
-
|
|
845
|
-
This method generates predictions from the trained autoencoder model for the provided input data. It processes the input data, performs a forward pass through the model, and computes the predicted genotype labels (0, 1, or 2) along with their associated probabilities if requested.
|
|
989
|
+
) -> tuple[np.ndarray, np.ndarray | None]:
|
|
990
|
+
"""Predict categorical genotype labels from logits.
|
|
846
991
|
|
|
847
992
|
Args:
|
|
848
993
|
model (torch.nn.Module): Trained model.
|
|
849
|
-
X (np.ndarray | torch.Tensor): 0/1/2 matrix with -1
|
|
850
|
-
|
|
851
|
-
return_proba (bool): If True, return probabilities.
|
|
994
|
+
X (np.ndarray | torch.Tensor): 2D 0/1/2 matrix with -1 for missing, or 3D one-hot (B, L, K).
|
|
995
|
+
return_proba (bool): If True, return probabilities (B, L, K).
|
|
852
996
|
|
|
853
997
|
Returns:
|
|
854
|
-
|
|
855
|
-
and probabilities if requested.
|
|
998
|
+
tuple[np.ndarray, np.ndarray | None]: Predicted labels and optionally probabilities.
|
|
856
999
|
"""
|
|
857
1000
|
if model is None:
|
|
858
1001
|
msg = "Model is not trained. Call fit() before predict()."
|
|
859
1002
|
self.logger.error(msg)
|
|
860
1003
|
raise NotFittedError(msg)
|
|
861
1004
|
|
|
1005
|
+
nF = self.num_features_
|
|
1006
|
+
nC = self.num_classes_
|
|
1007
|
+
|
|
1008
|
+
if isinstance(X, torch.Tensor):
|
|
1009
|
+
X_tensor = X
|
|
1010
|
+
else:
|
|
1011
|
+
X_tensor = torch.from_numpy(X)
|
|
1012
|
+
X_tensor = X_tensor.float()
|
|
1013
|
+
|
|
1014
|
+
if X_tensor.device != self.device:
|
|
1015
|
+
X_tensor = X_tensor.to(self.device)
|
|
1016
|
+
|
|
1017
|
+
if X_tensor.dim() == 2:
|
|
1018
|
+
# 0/1/2 matrix -> one-hot for model input
|
|
1019
|
+
X_tensor = self._one_hot_encode_012(X_tensor, num_classes=nC)
|
|
1020
|
+
X_tensor = X_tensor.float()
|
|
1021
|
+
if X_tensor.device != self.device:
|
|
1022
|
+
X_tensor = X_tensor.to(self.device)
|
|
1023
|
+
|
|
1024
|
+
elif X_tensor.dim() != 3:
|
|
1025
|
+
msg = f"_predict expects 2D 0/1/2 inputs or 3D one-hot inputs; got shape {tuple(X_tensor.shape)}."
|
|
1026
|
+
self.logger.error(msg)
|
|
1027
|
+
raise ValueError(msg)
|
|
1028
|
+
|
|
1029
|
+
if X_tensor.shape[1] != nF or X_tensor.shape[2] != nC:
|
|
1030
|
+
msg = f"_predict input shape mismatch: expected (B, {nF}, {nC}), got {tuple(X_tensor.shape)}."
|
|
1031
|
+
self.logger.error(msg)
|
|
1032
|
+
raise ValueError(msg)
|
|
1033
|
+
|
|
1034
|
+
X_tensor = X_tensor.reshape(X_tensor.shape[0], nF * nC)
|
|
1035
|
+
|
|
862
1036
|
model.eval()
|
|
863
1037
|
with torch.no_grad():
|
|
864
|
-
|
|
865
|
-
|
|
866
|
-
|
|
867
|
-
|
|
868
|
-
|
|
869
|
-
probas = torch.softmax(logits, dim=-1)
|
|
870
|
-
labels = torch.argmax(probas, dim=-1)
|
|
871
|
-
else:
|
|
872
|
-
x_in = self._encode_multilabel_inputs(X_tensor)
|
|
873
|
-
logits = model(x_in).view(-1, self.num_features_, self.output_classes_)
|
|
874
|
-
probas_2 = torch.sigmoid(logits)
|
|
875
|
-
p_ref = probas_2[..., 0]
|
|
876
|
-
p_alt = probas_2[..., 1]
|
|
877
|
-
p_het = p_ref * p_alt
|
|
878
|
-
p_ref_only = p_ref * (1 - p_alt)
|
|
879
|
-
p_alt_only = p_alt * (1 - p_ref)
|
|
880
|
-
stacked = torch.stack([p_ref_only, p_het, p_alt_only], dim=-1)
|
|
881
|
-
stacked = stacked / stacked.sum(dim=-1, keepdim=True).clamp_min(1e-8)
|
|
882
|
-
probas = stacked
|
|
883
|
-
labels = torch.argmax(stacked, dim=-1)
|
|
1038
|
+
logits_flat = model(X_tensor)
|
|
1039
|
+
logits = logits_flat.view(-1, nF, nC)
|
|
1040
|
+
|
|
1041
|
+
probas = torch.softmax(logits, dim=-1)
|
|
1042
|
+
labels = torch.argmax(probas, dim=-1)
|
|
884
1043
|
|
|
885
1044
|
if return_proba:
|
|
886
1045
|
return labels.cpu().numpy(), probas.cpu().numpy()
|
|
887
|
-
|
|
888
|
-
return labels.cpu().numpy()
|
|
889
|
-
|
|
890
|
-
def _encode_multilabel_inputs(self, y: torch.Tensor) -> torch.Tensor:
|
|
891
|
-
"""Two-channel multi-hot for diploid: REF-only, ALT-only; HET sets both."""
|
|
892
|
-
if self.is_haploid:
|
|
893
|
-
return self._one_hot_encode_012(y)
|
|
894
|
-
y = y.to(self.device)
|
|
895
|
-
shape = y.shape + (2,)
|
|
896
|
-
out = torch.zeros(shape, device=self.device, dtype=torch.float32)
|
|
897
|
-
valid = y != -1
|
|
898
|
-
ref_mask = valid & (y != 2)
|
|
899
|
-
alt_mask = valid & (y != 0)
|
|
900
|
-
out[ref_mask, 0] = 1.0
|
|
901
|
-
out[alt_mask, 1] = 1.0
|
|
902
|
-
return out
|
|
903
|
-
|
|
904
|
-
def _multi_hot_targets(self, y: torch.Tensor) -> torch.Tensor:
|
|
905
|
-
"""Targets aligned with _encode_multilabel_inputs for diploid training."""
|
|
906
|
-
if self.is_haploid:
|
|
907
|
-
# One-hot CE path expects integer targets; handled upstream.
|
|
908
|
-
raise RuntimeError("_multi_hot_targets called for haploid data.")
|
|
909
|
-
y = y.to(self.device)
|
|
910
|
-
out = torch.zeros(y.shape + (2,), device=self.device, dtype=torch.float32)
|
|
911
|
-
valid = y != -1
|
|
912
|
-
ref_mask = valid & (y != 2)
|
|
913
|
-
alt_mask = valid & (y != 0)
|
|
914
|
-
out[ref_mask, 0] = 1.0
|
|
915
|
-
out[alt_mask, 1] = 1.0
|
|
916
|
-
return out
|
|
917
|
-
|
|
918
|
-
def _compute_pos_weights(self, X: np.ndarray) -> torch.Tensor:
|
|
919
|
-
"""Balance REF/ALT channels for multilabel BCE."""
|
|
920
|
-
ref_pos = np.count_nonzero((X == 0) | (X == 1))
|
|
921
|
-
alt_pos = np.count_nonzero((X == 2) | (X == 1))
|
|
922
|
-
total_valid = np.count_nonzero(X != -1)
|
|
923
|
-
pos_counts = np.array([ref_pos, alt_pos], dtype=np.float32)
|
|
924
|
-
neg_counts = np.maximum(total_valid - pos_counts, 1.0)
|
|
925
|
-
pos_counts = np.maximum(pos_counts, 1.0)
|
|
926
|
-
weights = neg_counts / pos_counts
|
|
927
|
-
return torch.tensor(weights, device=self.device, dtype=torch.float32)
|
|
1046
|
+
return labels.cpu().numpy(), None
|
|
928
1047
|
|
|
929
1048
|
def _evaluate_model(
|
|
930
1049
|
self,
|
|
931
|
-
X_val: np.ndarray,
|
|
932
1050
|
model: torch.nn.Module,
|
|
933
|
-
|
|
934
|
-
|
|
935
|
-
|
|
1051
|
+
X: np.ndarray,
|
|
1052
|
+
y: np.ndarray,
|
|
1053
|
+
eval_mask: np.ndarray,
|
|
936
1054
|
*,
|
|
937
|
-
|
|
1055
|
+
objective_mode: bool = False,
|
|
938
1056
|
) -> Dict[str, float]:
|
|
939
1057
|
"""Evaluate on 0/1/2; then IUPAC decoding and 10-base integer reports.
|
|
940
1058
|
|
|
941
|
-
This method evaluates the trained autoencoder model on a validation set, computing various classification metrics based on the predicted and true genotypes. It handles both haploid and diploid data appropriately and generates detailed classification reports for both genotype and IUPAC/10-base integer encodings.
|
|
942
|
-
|
|
943
1059
|
Args:
|
|
944
|
-
X_val (np.ndarray): Validation set 0/1/2 matrix with -1
|
|
945
|
-
for missing.
|
|
946
1060
|
model (torch.nn.Module): Trained model.
|
|
947
|
-
|
|
948
|
-
|
|
949
|
-
|
|
950
|
-
|
|
1061
|
+
X (np.ndarray): 2D 0/1/2 matrix with -1 for missing.
|
|
1062
|
+
y (np.ndarray): 2D 0/1/2 ground truth matrix with -1 for missing.
|
|
1063
|
+
eval_mask (np.ndarray): 2D boolean mask selecting sites to evaluate.
|
|
1064
|
+
objective_mode (bool): If True, suppress detailed reports and plots.
|
|
951
1065
|
|
|
952
1066
|
Returns:
|
|
953
1067
|
Dict[str, float]: Dictionary of evaluation metrics.
|
|
954
1068
|
"""
|
|
955
|
-
|
|
956
|
-
|
|
957
|
-
|
|
1069
|
+
if model is None:
|
|
1070
|
+
msg = "Model passed to _evaluate_model() is not fitted. Call fit() before evaluation."
|
|
1071
|
+
self.logger.error(msg)
|
|
1072
|
+
raise NotFittedError(msg)
|
|
958
1073
|
|
|
959
|
-
|
|
960
|
-
|
|
961
|
-
# FIX 1: Check ROWS (shape[0]) only. X_val might be a feature subset.
|
|
962
|
-
if (
|
|
963
|
-
hasattr(self, "X_val_")
|
|
964
|
-
and getattr(self, "X_val_", None) is not None
|
|
965
|
-
and X_val.shape[0] == self.X_val_.shape[0]
|
|
966
|
-
):
|
|
967
|
-
GT_ref = getattr(self, "GT_test_full_", self.ground_truth_)
|
|
968
|
-
elif (
|
|
969
|
-
hasattr(self, "X_train_")
|
|
970
|
-
and getattr(self, "X_train_", None) is not None
|
|
971
|
-
and X_val.shape[0] == self.X_train_.shape[0]
|
|
972
|
-
):
|
|
973
|
-
GT_ref = getattr(self, "GT_train_full_", self.ground_truth_)
|
|
974
|
-
else:
|
|
975
|
-
GT_ref = self.ground_truth_
|
|
976
|
-
|
|
977
|
-
# FIX 2: Handle Feature Mismatch (e.g., tune_fast feature subsetting)
|
|
978
|
-
# If the GT source has more columns than X_val, slice it to match.
|
|
979
|
-
if GT_ref.shape[1] > X_val.shape[1]:
|
|
980
|
-
GT_ref = GT_ref[:, : X_val.shape[1]]
|
|
981
|
-
|
|
982
|
-
# Fallback if rows mismatch (unlikely after Fix 1, but safe to keep)
|
|
983
|
-
if GT_ref.shape != X_val.shape:
|
|
984
|
-
# If completely different, we can't use the ground truth object.
|
|
985
|
-
# Fall back to X_val (this implies only observed values are scored)
|
|
986
|
-
GT_ref = X_val
|
|
987
|
-
|
|
988
|
-
if eval_mask_override is not None:
|
|
989
|
-
# FIX 3: Allow override mask to be sliced if it's too wide
|
|
990
|
-
if eval_mask_override.shape[0] != X_val.shape[0]:
|
|
991
|
-
msg = (
|
|
992
|
-
f"eval_mask_override rows {eval_mask_override.shape[0]} "
|
|
993
|
-
f"does not match X_val rows {X_val.shape[0]}"
|
|
994
|
-
)
|
|
995
|
-
self.logger.error(msg)
|
|
996
|
-
raise ValueError(msg)
|
|
1074
|
+
pred_labels, pred_probas = self._predict(model=model, X=X, return_proba=True)
|
|
997
1075
|
|
|
998
|
-
|
|
999
|
-
|
|
1000
|
-
|
|
1001
|
-
|
|
1002
|
-
else:
|
|
1003
|
-
eval_mask = X_val != -1
|
|
1076
|
+
if pred_probas is None:
|
|
1077
|
+
msg = "Predicted probabilities are None in _evaluate_model()."
|
|
1078
|
+
self.logger.error(msg)
|
|
1079
|
+
raise ValueError(msg)
|
|
1004
1080
|
|
|
1005
|
-
|
|
1006
|
-
|
|
1081
|
+
y_true_flat = y[eval_mask].astype(np.int8, copy=False)
|
|
1082
|
+
y_pred_flat = pred_labels[eval_mask].astype(np.int8, copy=False)
|
|
1083
|
+
y_proba_flat = pred_probas[eval_mask].astype(np.float32, copy=False)
|
|
1007
1084
|
|
|
1008
|
-
|
|
1009
|
-
|
|
1010
|
-
|
|
1085
|
+
valid = y_true_flat >= 0
|
|
1086
|
+
y_true_flat = y_true_flat[valid]
|
|
1087
|
+
y_pred_flat = y_pred_flat[valid]
|
|
1088
|
+
y_proba_flat = y_proba_flat[valid]
|
|
1011
1089
|
|
|
1012
1090
|
if y_true_flat.size == 0:
|
|
1013
|
-
self.tune_metric = "f1" if self.tune_metric is None else self.tune_metric
|
|
1014
1091
|
return {self.tune_metric: 0.0}
|
|
1015
1092
|
|
|
1016
|
-
|
|
1093
|
+
if y_proba_flat.ndim != 2:
|
|
1094
|
+
msg = f"Expected y_proba_flat to be 2D (n_eval, n_classes); got {y_proba_flat.shape}."
|
|
1095
|
+
self.logger.error(msg)
|
|
1096
|
+
raise ValueError(msg)
|
|
1097
|
+
|
|
1098
|
+
K = int(y_proba_flat.shape[1])
|
|
1099
|
+
if self.is_haploid_:
|
|
1100
|
+
if K not in (2, 3):
|
|
1101
|
+
msg = f"Haploid evaluation expects 2 or 3 classes; got {K}."
|
|
1102
|
+
self.logger.error(msg)
|
|
1103
|
+
raise ValueError(msg)
|
|
1104
|
+
else:
|
|
1105
|
+
if K != 3:
|
|
1106
|
+
msg = f"Diploid evaluation expects 3 classes; got {K}."
|
|
1107
|
+
self.logger.error(msg)
|
|
1108
|
+
raise ValueError(msg)
|
|
1109
|
+
|
|
1110
|
+
if self.is_haploid_:
|
|
1111
|
+
y_true_flat = (y_true_flat > 0).astype(np.int8, copy=False)
|
|
1112
|
+
y_pred_flat = (y_pred_flat > 0).astype(np.int8, copy=False)
|
|
1113
|
+
|
|
1114
|
+
if K == 3:
|
|
1115
|
+
proba_2 = np.empty((y_proba_flat.shape[0], 2), dtype=y_proba_flat.dtype)
|
|
1116
|
+
proba_2[:, 0] = y_proba_flat[:, 0]
|
|
1117
|
+
proba_2[:, 1] = y_proba_flat[:, 1] + y_proba_flat[:, 2]
|
|
1118
|
+
y_proba_flat = proba_2
|
|
1119
|
+
|
|
1120
|
+
labels_for_scoring = [0, 1]
|
|
1121
|
+
target_names = ["REF", "ALT"]
|
|
1122
|
+
else:
|
|
1123
|
+
labels_for_scoring = [0, 1, 2]
|
|
1124
|
+
target_names = ["REF", "HET", "ALT"]
|
|
1125
|
+
|
|
1017
1126
|
y_proba_flat = np.clip(y_proba_flat, 0.0, 1.0)
|
|
1018
1127
|
row_sums = y_proba_flat.sum(axis=1, keepdims=True)
|
|
1019
|
-
row_sums[row_sums == 0] = 1.0
|
|
1128
|
+
row_sums[row_sums == 0.0] = 1.0
|
|
1020
1129
|
y_proba_flat = y_proba_flat / row_sums
|
|
1021
1130
|
|
|
1022
|
-
|
|
1023
|
-
target_names = ["REF", "ALT"] if self.is_haploid else ["REF", "HET", "ALT"]
|
|
1024
|
-
|
|
1025
|
-
if self.is_haploid:
|
|
1026
|
-
y_true_flat = y_true_flat.copy()
|
|
1027
|
-
y_pred_flat = y_pred_flat.copy()
|
|
1028
|
-
y_true_flat[y_true_flat == 2] = 1
|
|
1029
|
-
y_pred_flat[y_pred_flat == 2] = 1
|
|
1030
|
-
# collapse probs to 2-class
|
|
1031
|
-
proba_2 = np.zeros((len(y_proba_flat), 2), dtype=y_proba_flat.dtype)
|
|
1032
|
-
proba_2[:, 0] = y_proba_flat[:, 0]
|
|
1033
|
-
proba_2[:, 1] = y_proba_flat[:, 2]
|
|
1034
|
-
y_proba_flat = proba_2
|
|
1035
|
-
|
|
1036
|
-
y_true_ohe = np.eye(len(labels_for_scoring))[y_true_flat]
|
|
1037
|
-
|
|
1038
|
-
tune_metric_tmp: Literal[
|
|
1039
|
-
"pr_macro",
|
|
1040
|
-
"roc_auc",
|
|
1041
|
-
"average_precision",
|
|
1042
|
-
"accuracy",
|
|
1043
|
-
"f1",
|
|
1044
|
-
"precision",
|
|
1045
|
-
"recall",
|
|
1046
|
-
]
|
|
1047
|
-
if self.tune_metric_ is not None:
|
|
1048
|
-
tune_metric_tmp = self.tune_metric_
|
|
1049
|
-
else:
|
|
1050
|
-
tune_metric_tmp = "f1" # Default if not tuning
|
|
1131
|
+
y_true_ohe = np.eye(len(labels_for_scoring), dtype=np.int8)[y_true_flat]
|
|
1051
1132
|
|
|
1052
1133
|
metrics = self.scorers_.evaluate(
|
|
1053
1134
|
y_true_flat,
|
|
@@ -1055,16 +1136,29 @@ class ImputeAutoencoder(BaseNNImputer):
|
|
|
1055
1136
|
y_true_ohe,
|
|
1056
1137
|
y_proba_flat,
|
|
1057
1138
|
objective_mode,
|
|
1058
|
-
|
|
1139
|
+
cast(
|
|
1140
|
+
Literal[
|
|
1141
|
+
"pr_macro",
|
|
1142
|
+
"roc_auc",
|
|
1143
|
+
"accuracy",
|
|
1144
|
+
"f1",
|
|
1145
|
+
"average_precision",
|
|
1146
|
+
"precision",
|
|
1147
|
+
"recall",
|
|
1148
|
+
"mcc",
|
|
1149
|
+
"jaccard",
|
|
1150
|
+
],
|
|
1151
|
+
self.tune_metric,
|
|
1152
|
+
),
|
|
1059
1153
|
)
|
|
1060
1154
|
|
|
1061
1155
|
if not objective_mode:
|
|
1062
|
-
|
|
1063
|
-
|
|
1064
|
-
|
|
1065
|
-
|
|
1156
|
+
if self.verbose or self.debug:
|
|
1157
|
+
pm = PrettyMetrics(
|
|
1158
|
+
metrics, precision=2, title=f"{self.model_name} Validation Metrics"
|
|
1159
|
+
)
|
|
1160
|
+
pm.render()
|
|
1066
1161
|
|
|
1067
|
-
# Primary report (REF/HET/ALT or REF/ALT)
|
|
1068
1162
|
self._make_class_reports(
|
|
1069
1163
|
y_true=y_true_flat,
|
|
1070
1164
|
y_pred_proba=y_proba_flat,
|
|
@@ -1073,18 +1167,15 @@ class ImputeAutoencoder(BaseNNImputer):
|
|
|
1073
1167
|
labels=target_names,
|
|
1074
1168
|
)
|
|
1075
1169
|
|
|
1076
|
-
|
|
1077
|
-
|
|
1078
|
-
y_true_dec = self.pgenc.decode_012(
|
|
1079
|
-
GT_ref.reshape(X_val.shape[0], X_val.shape[1])
|
|
1080
|
-
)
|
|
1081
|
-
X_pred = X_val.copy()
|
|
1082
|
-
X_pred[eval_mask] = y_pred_flat
|
|
1170
|
+
y_true_matrix = np.array(y, copy=True)
|
|
1171
|
+
y_pred_matrix = np.array(pred_labels, copy=True)
|
|
1083
1172
|
|
|
1084
|
-
|
|
1085
|
-
|
|
1086
|
-
|
|
1087
|
-
|
|
1173
|
+
if self.is_haploid_:
|
|
1174
|
+
y_true_matrix = np.where(y_true_matrix > 0, 2, y_true_matrix)
|
|
1175
|
+
y_pred_matrix = np.where(y_pred_matrix > 0, 2, y_pred_matrix)
|
|
1176
|
+
|
|
1177
|
+
y_true_dec = self.decode_012(y_true_matrix)
|
|
1178
|
+
y_pred_dec = self.decode_012(y_pred_matrix)
|
|
1088
1179
|
|
|
1089
1180
|
encodings_dict = {
|
|
1090
1181
|
"A": 0,
|
|
@@ -1123,239 +1214,177 @@ class ImputeAutoencoder(BaseNNImputer):
|
|
|
1123
1214
|
return metrics
|
|
1124
1215
|
|
|
1125
1216
|
def _objective(self, trial: optuna.Trial) -> float:
|
|
1126
|
-
"""Optuna objective for AE
|
|
1127
|
-
|
|
1128
|
-
This method defines the objective function for hyperparameter tuning using Optuna. It samples hyperparameters, prepares the training and validation data, builds and trains the autoencoder model, and evaluates its performance on the validation set. The method returns the value of the tuning metric to be maximized.
|
|
1217
|
+
"""Optuna objective for AE (mirrors VAE flow, excluding KL-specific parts).
|
|
1129
1218
|
|
|
1130
1219
|
Args:
|
|
1131
|
-
trial (optuna.Trial): Optuna trial.
|
|
1220
|
+
trial (optuna.Trial): Optuna trial object.
|
|
1132
1221
|
|
|
1133
1222
|
Returns:
|
|
1134
|
-
float: Value of the tuning metric
|
|
1223
|
+
float: Value of the tuning metric to optimize.
|
|
1135
1224
|
"""
|
|
1136
1225
|
try:
|
|
1137
|
-
# Sample hyperparameters (existing helper; unchanged signature)
|
|
1138
1226
|
params = self._sample_hyperparameters(trial)
|
|
1139
1227
|
|
|
1140
|
-
# Optionally sub-sample for fast tuning (same keys used by NLPCA if you adopt them)
|
|
1141
|
-
X_train = getattr(self, "X_train_", self.ground_truth_[self.train_idx_])
|
|
1142
|
-
X_val = getattr(self, "X_val_", self.ground_truth_[self.test_idx_])
|
|
1143
|
-
|
|
1144
|
-
class_weights = self._normalize_class_weights(
|
|
1145
|
-
self._class_weights_from_zygosity(X_train)
|
|
1146
|
-
)
|
|
1147
|
-
train_loader = self._get_data_loaders(X_train)
|
|
1148
|
-
|
|
1149
1228
|
model = self.build_model(self.Model, params["model_params"])
|
|
1150
1229
|
model.apply(self.initialize_weights)
|
|
1151
1230
|
|
|
1152
|
-
lr
|
|
1153
|
-
l1_penalty
|
|
1231
|
+
lr = float(params["learning_rate"])
|
|
1232
|
+
l1_penalty = float(params["l1_penalty"])
|
|
1233
|
+
|
|
1234
|
+
class_weights = self._class_weights_from_zygosity(
|
|
1235
|
+
self.y_train_,
|
|
1236
|
+
train_mask=self.sim_mask_train_ & ~self.orig_mask_train_,
|
|
1237
|
+
inverse=params["inverse"],
|
|
1238
|
+
normalize=params["normalize"],
|
|
1239
|
+
max_ratio=self.max_ratio if self.max_ratio is not None else 5.0,
|
|
1240
|
+
power=params["power"],
|
|
1241
|
+
)
|
|
1154
1242
|
|
|
1155
|
-
|
|
1156
|
-
_, model, __ = self._train_and_validate_model(
|
|
1243
|
+
loss, model, _hist = self._train_and_validate_model(
|
|
1157
1244
|
model=model,
|
|
1158
|
-
loader=train_loader,
|
|
1159
1245
|
lr=lr,
|
|
1160
1246
|
l1_penalty=l1_penalty,
|
|
1247
|
+
params=params,
|
|
1161
1248
|
trial=trial,
|
|
1162
|
-
return_history=False,
|
|
1163
1249
|
class_weights=class_weights,
|
|
1164
|
-
|
|
1165
|
-
params=params,
|
|
1166
|
-
prune_metric=self.tune_metric,
|
|
1167
|
-
prune_warmup_epochs=10,
|
|
1168
|
-
eval_interval=self.tune_eval_interval,
|
|
1169
|
-
eval_requires_latents=False,
|
|
1170
|
-
eval_latent_steps=0,
|
|
1171
|
-
eval_latent_lr=0.0,
|
|
1172
|
-
eval_latent_weight_decay=0.0,
|
|
1250
|
+
gamma_schedule=params["gamma_schedule"],
|
|
1173
1251
|
)
|
|
1174
1252
|
|
|
1175
|
-
|
|
1176
|
-
|
|
1177
|
-
|
|
1178
|
-
|
|
1179
|
-
and getattr(self, "sim_mask_test_", None) is not None
|
|
1180
|
-
)
|
|
1181
|
-
else None
|
|
1182
|
-
)
|
|
1253
|
+
if model is None or not np.isfinite(loss):
|
|
1254
|
+
msg = "Model training returned None or non-finite loss in tuning objective."
|
|
1255
|
+
self.logger.error(msg)
|
|
1256
|
+
raise RuntimeError(msg)
|
|
1183
1257
|
|
|
1184
|
-
|
|
1185
|
-
|
|
1186
|
-
|
|
1187
|
-
|
|
1188
|
-
|
|
1189
|
-
|
|
1190
|
-
|
|
1191
|
-
)
|
|
1192
|
-
self._clear_resources(model, train_loader)
|
|
1193
|
-
else:
|
|
1194
|
-
raise TypeError("Model training failed; no model was returned.")
|
|
1258
|
+
metrics = self._evaluate_model(
|
|
1259
|
+
model=model,
|
|
1260
|
+
X=self.X_val_,
|
|
1261
|
+
y=self.y_val_,
|
|
1262
|
+
eval_mask=self.sim_mask_val_ & ~self.orig_mask_val_,
|
|
1263
|
+
objective_mode=True,
|
|
1264
|
+
)
|
|
1195
1265
|
|
|
1196
|
-
|
|
1266
|
+
self._clear_resources(model)
|
|
1267
|
+
return float(metrics[self.tune_metric])
|
|
1197
1268
|
|
|
1198
1269
|
except Exception as e:
|
|
1199
|
-
|
|
1200
|
-
|
|
1201
|
-
|
|
1202
|
-
|
|
1203
|
-
|
|
1270
|
+
err_type = type(e).__name__
|
|
1271
|
+
self.logger.warning(
|
|
1272
|
+
f"Trial {trial.number} failed due to exception {err_type}: {e}"
|
|
1273
|
+
)
|
|
1274
|
+
self.logger.debug(traceback.format_exc())
|
|
1275
|
+
raise optuna.exceptions.TrialPruned(
|
|
1276
|
+
f"Trial {trial.number} failed due to an exception. {err_type}: {e}. "
|
|
1277
|
+
"Enable debug logging for full traceback."
|
|
1278
|
+
) from e
|
|
1204
1279
|
|
|
1205
|
-
|
|
1280
|
+
def _sample_hyperparameters(self, trial: optuna.Trial) -> dict:
|
|
1281
|
+
"""Sample AE hyperparameters; hidden sizes mirror VAE helper (excluding KL).
|
|
1206
1282
|
|
|
1207
1283
|
Args:
|
|
1208
1284
|
trial (optuna.Trial): Optuna trial object.
|
|
1209
1285
|
|
|
1210
1286
|
Returns:
|
|
1211
|
-
|
|
1287
|
+
dict: Sampled hyperparameters.
|
|
1212
1288
|
"""
|
|
1213
1289
|
params = {
|
|
1214
|
-
"latent_dim": trial.suggest_int("latent_dim",
|
|
1215
|
-
"
|
|
1216
|
-
"dropout_rate": trial.suggest_float("dropout_rate", 0.0, 0.
|
|
1217
|
-
"num_hidden_layers": trial.suggest_int("num_hidden_layers", 1,
|
|
1290
|
+
"latent_dim": trial.suggest_int("latent_dim", 2, 32),
|
|
1291
|
+
"learning_rate": trial.suggest_float("learning_rate", 3e-6, 1e-3, log=True),
|
|
1292
|
+
"dropout_rate": trial.suggest_float("dropout_rate", 0.0, 0.5, step=0.025),
|
|
1293
|
+
"num_hidden_layers": trial.suggest_int("num_hidden_layers", 1, 20),
|
|
1218
1294
|
"activation": trial.suggest_categorical(
|
|
1219
1295
|
"activation", ["relu", "elu", "selu", "leaky_relu"]
|
|
1220
1296
|
),
|
|
1221
1297
|
"l1_penalty": trial.suggest_float("l1_penalty", 1e-6, 1e-3, log=True),
|
|
1222
1298
|
"layer_scaling_factor": trial.suggest_float(
|
|
1223
|
-
"layer_scaling_factor", 2.0,
|
|
1299
|
+
"layer_scaling_factor", 2.0, 10.0, step=0.025
|
|
1224
1300
|
),
|
|
1225
1301
|
"layer_schedule": trial.suggest_categorical(
|
|
1226
1302
|
"layer_schedule", ["pyramid", "linear"]
|
|
1227
1303
|
),
|
|
1304
|
+
"power": trial.suggest_float("power", 0.1, 2.0, step=0.1),
|
|
1305
|
+
"normalize": trial.suggest_categorical("normalize", [True, False]),
|
|
1306
|
+
"inverse": trial.suggest_categorical("inverse", [True, False]),
|
|
1307
|
+
"gamma": trial.suggest_float("gamma", 0.0, 10.0, step=0.1),
|
|
1308
|
+
"gamma_schedule": trial.suggest_categorical(
|
|
1309
|
+
"gamma_schedule", [True, False]
|
|
1310
|
+
),
|
|
1228
1311
|
}
|
|
1229
1312
|
|
|
1230
|
-
nF
|
|
1231
|
-
nC
|
|
1313
|
+
nF = int(self.num_features_)
|
|
1314
|
+
nC = int(self.num_classes_)
|
|
1232
1315
|
input_dim = nF * nC
|
|
1316
|
+
|
|
1233
1317
|
hidden_layer_sizes = self._compute_hidden_layer_sizes(
|
|
1234
1318
|
n_inputs=input_dim,
|
|
1235
|
-
n_outputs=
|
|
1319
|
+
n_outputs=nC,
|
|
1236
1320
|
n_samples=len(self.train_idx_),
|
|
1237
|
-
n_hidden=params["num_hidden_layers"],
|
|
1238
|
-
|
|
1239
|
-
|
|
1321
|
+
n_hidden=int(params["num_hidden_layers"]),
|
|
1322
|
+
latent_dim=int(params["latent_dim"]),
|
|
1323
|
+
alpha=float(params["layer_scaling_factor"]),
|
|
1324
|
+
schedule=str(params["layer_schedule"]),
|
|
1240
1325
|
)
|
|
1241
1326
|
|
|
1242
|
-
# Keep the latent_dim as the first element,
|
|
1243
|
-
# then the interior hidden widths.
|
|
1244
|
-
# If there are no interior widths (very small nets),
|
|
1245
|
-
# this still leaves [latent_dim].
|
|
1246
|
-
hidden_only: list[int] = [hidden_layer_sizes[0]] + hidden_layer_sizes[1:-1]
|
|
1247
|
-
|
|
1248
1327
|
params["model_params"] = {
|
|
1249
|
-
"n_features":
|
|
1250
|
-
"num_classes":
|
|
1251
|
-
getattr(self, "output_classes_", self.num_classes_ or 3)
|
|
1252
|
-
),
|
|
1328
|
+
"n_features": nF,
|
|
1329
|
+
"num_classes": nC,
|
|
1253
1330
|
"latent_dim": int(params["latent_dim"]),
|
|
1254
1331
|
"dropout_rate": float(params["dropout_rate"]),
|
|
1255
|
-
"hidden_layer_sizes":
|
|
1332
|
+
"hidden_layer_sizes": hidden_layer_sizes,
|
|
1256
1333
|
"activation": str(params["activation"]),
|
|
1257
1334
|
}
|
|
1258
1335
|
return params
|
|
1259
1336
|
|
|
1260
|
-
def _set_best_params(
|
|
1261
|
-
|
|
1262
|
-
) -> Dict[str, int | float | str | List[int]]:
|
|
1263
|
-
"""Adopt best params (ImputeNLPCA parity) and return model_params.
|
|
1264
|
-
|
|
1265
|
-
This method sets the best hyperparameters found during tuning and computes the hidden layer sizes for the autoencoder model. It prepares the final model parameters dictionary to be used for building the model.
|
|
1337
|
+
def _set_best_params(self, params: dict) -> dict:
|
|
1338
|
+
"""Update instance fields from tuned params and return model_params dict.
|
|
1266
1339
|
|
|
1267
1340
|
Args:
|
|
1268
|
-
|
|
1341
|
+
params (dict): Best hyperparameters from tuning.
|
|
1269
1342
|
|
|
1270
1343
|
Returns:
|
|
1271
|
-
|
|
1344
|
+
dict: Model parameters for building the final model.
|
|
1272
1345
|
"""
|
|
1273
|
-
|
|
1274
|
-
|
|
1275
|
-
|
|
1276
|
-
|
|
1277
|
-
|
|
1278
|
-
|
|
1279
|
-
|
|
1280
|
-
|
|
1281
|
-
|
|
1282
|
-
|
|
1283
|
-
|
|
1284
|
-
|
|
1285
|
-
|
|
1286
|
-
|
|
1287
|
-
|
|
1288
|
-
|
|
1289
|
-
|
|
1290
|
-
|
|
1291
|
-
|
|
1292
|
-
else
|
|
1293
|
-
|
|
1294
|
-
|
|
1295
|
-
self.latent_dim: int = bp["latent_dim"]
|
|
1296
|
-
self.dropout_rate: float = bp["dropout_rate"]
|
|
1297
|
-
self.learning_rate: float = bp["learning_rate"]
|
|
1298
|
-
self.l1_penalty: float = bp["l1_penalty"]
|
|
1299
|
-
self.activation: str = bp["activation"]
|
|
1300
|
-
self.layer_scaling_factor: float = bp["layer_scaling_factor"]
|
|
1301
|
-
self.layer_schedule: str = bp["layer_schedule"]
|
|
1302
|
-
|
|
1303
|
-
nF: int = self.num_features_
|
|
1304
|
-
nC: int = int(getattr(self, "output_classes_", self.num_classes_ or 3))
|
|
1305
|
-
hidden_layer_sizes = self._compute_hidden_layer_sizes(
|
|
1306
|
-
n_inputs=nF * nC,
|
|
1307
|
-
n_outputs=nF * nC,
|
|
1308
|
-
n_samples=len(self.train_idx_),
|
|
1309
|
-
n_hidden=bp["num_hidden_layers"],
|
|
1310
|
-
alpha=bp["layer_scaling_factor"],
|
|
1311
|
-
schedule=bp["layer_schedule"],
|
|
1346
|
+
self.latent_dim = int(params["latent_dim"])
|
|
1347
|
+
self.dropout_rate = float(params["dropout_rate"])
|
|
1348
|
+
self.learning_rate = float(params["learning_rate"])
|
|
1349
|
+
self.l1_penalty = float(params["l1_penalty"])
|
|
1350
|
+
self.activation = str(params["activation"])
|
|
1351
|
+
self.layer_scaling_factor = float(params["layer_scaling_factor"])
|
|
1352
|
+
self.layer_schedule = str(params["layer_schedule"])
|
|
1353
|
+
|
|
1354
|
+
self.power = float(params["power"])
|
|
1355
|
+
self.normalize = bool(params["normalize"])
|
|
1356
|
+
self.inverse = bool(params["inverse"])
|
|
1357
|
+
self.gamma = float(params["gamma"])
|
|
1358
|
+
self.gamma_schedule = bool(params["gamma_schedule"])
|
|
1359
|
+
|
|
1360
|
+
self.class_weights_ = self._class_weights_from_zygosity(
|
|
1361
|
+
self.y_train_,
|
|
1362
|
+
train_mask=self.sim_mask_train_ & ~self.orig_mask_train_,
|
|
1363
|
+
inverse=self.inverse,
|
|
1364
|
+
normalize=self.normalize,
|
|
1365
|
+
max_ratio=self.max_ratio if self.max_ratio is not None else 5.0,
|
|
1366
|
+
power=self.power,
|
|
1312
1367
|
)
|
|
1313
1368
|
|
|
1314
|
-
|
|
1315
|
-
|
|
1316
|
-
|
|
1317
|
-
# this still leaves [latent_dim].
|
|
1318
|
-
hidden_only = [hidden_layer_sizes[0]] + hidden_layer_sizes[1:-1]
|
|
1319
|
-
|
|
1320
|
-
return {
|
|
1321
|
-
"n_features": self.num_features_,
|
|
1322
|
-
"latent_dim": self.latent_dim,
|
|
1323
|
-
"hidden_layer_sizes": hidden_only,
|
|
1324
|
-
"dropout_rate": self.dropout_rate,
|
|
1325
|
-
"activation": self.activation,
|
|
1326
|
-
"num_classes": nC,
|
|
1327
|
-
}
|
|
1328
|
-
|
|
1329
|
-
def _default_best_params(self) -> Dict[str, int | float | str | list]:
|
|
1330
|
-
"""Default model params when tuning is disabled.
|
|
1331
|
-
|
|
1332
|
-
This method computes the default model parameters for the autoencoder when hyperparameter tuning is not performed. It calculates the hidden layer sizes based on the initial configuration.
|
|
1333
|
-
|
|
1334
|
-
Returns:
|
|
1335
|
-
Dict[str, int | float | str | list]: Default model parameters.
|
|
1336
|
-
"""
|
|
1337
|
-
nF: int = self.num_features_
|
|
1338
|
-
# Use the number of output channels passed to the model (2 for diploid multilabel)
|
|
1339
|
-
# instead of the scoring classes (3) to keep layer shapes aligned.
|
|
1340
|
-
nC: int = int(getattr(self, "output_classes_", self.num_classes_ or 3))
|
|
1341
|
-
ls = self.layer_schedule
|
|
1342
|
-
|
|
1343
|
-
if ls not in {"pyramid", "constant", "linear"}:
|
|
1344
|
-
raise ValueError(f"Invalid layer_schedule: {ls}")
|
|
1369
|
+
nF = int(self.num_features_)
|
|
1370
|
+
nC = int(self.num_classes_)
|
|
1371
|
+
input_dim = nF * nC
|
|
1345
1372
|
|
|
1346
1373
|
hidden_layer_sizes = self._compute_hidden_layer_sizes(
|
|
1347
|
-
n_inputs=
|
|
1348
|
-
n_outputs=
|
|
1349
|
-
n_samples=len(self.
|
|
1350
|
-
n_hidden=
|
|
1351
|
-
|
|
1352
|
-
|
|
1374
|
+
n_inputs=input_dim,
|
|
1375
|
+
n_outputs=nC,
|
|
1376
|
+
n_samples=len(self.train_idx_),
|
|
1377
|
+
n_hidden=int(params["num_hidden_layers"]),
|
|
1378
|
+
latent_dim=int(params["latent_dim"]),
|
|
1379
|
+
alpha=float(params["layer_scaling_factor"]),
|
|
1380
|
+
schedule=str(params["layer_schedule"]),
|
|
1353
1381
|
)
|
|
1382
|
+
|
|
1354
1383
|
return {
|
|
1355
|
-
"n_features":
|
|
1384
|
+
"n_features": nF,
|
|
1385
|
+
"num_classes": nC,
|
|
1356
1386
|
"latent_dim": self.latent_dim,
|
|
1357
1387
|
"hidden_layer_sizes": hidden_layer_sizes,
|
|
1358
1388
|
"dropout_rate": self.dropout_rate,
|
|
1359
1389
|
"activation": self.activation,
|
|
1360
|
-
"num_classes": nC,
|
|
1361
1390
|
}
|