pg-sui 0.2.3__py3-none-any.whl → 1.6.16a3__py3-none-any.whl
This diff shows the changes between two package versions that have been publicly released to a supported registry, as they appear in that registry. It is provided for informational purposes only.
- pg_sui-1.6.16a3.dist-info/METADATA +292 -0
- pg_sui-1.6.16a3.dist-info/RECORD +81 -0
- {pg_sui-0.2.3.dist-info → pg_sui-1.6.16a3.dist-info}/WHEEL +1 -1
- pg_sui-1.6.16a3.dist-info/entry_points.txt +4 -0
- {pg_sui-0.2.3.dist-info → pg_sui-1.6.16a3.dist-info/licenses}/LICENSE +0 -0
- pg_sui-1.6.16a3.dist-info/top_level.txt +1 -0
- pgsui/__init__.py +35 -54
- pgsui/_version.py +34 -0
- pgsui/cli.py +922 -0
- pgsui/data_processing/__init__.py +0 -0
- pgsui/data_processing/config.py +565 -0
- pgsui/data_processing/containers.py +1436 -0
- pgsui/data_processing/transformers.py +557 -907
- pgsui/{example_data/trees → electron/app}/__init__.py +0 -0
- pgsui/electron/app/__main__.py +5 -0
- pgsui/electron/app/extra-resources/.gitkeep +1 -0
- pgsui/electron/app/icons/icons/1024x1024.png +0 -0
- pgsui/electron/app/icons/icons/128x128.png +0 -0
- pgsui/electron/app/icons/icons/16x16.png +0 -0
- pgsui/electron/app/icons/icons/24x24.png +0 -0
- pgsui/electron/app/icons/icons/256x256.png +0 -0
- pgsui/electron/app/icons/icons/32x32.png +0 -0
- pgsui/electron/app/icons/icons/48x48.png +0 -0
- pgsui/electron/app/icons/icons/512x512.png +0 -0
- pgsui/electron/app/icons/icons/64x64.png +0 -0
- pgsui/electron/app/icons/icons/icon.icns +0 -0
- pgsui/electron/app/icons/icons/icon.ico +0 -0
- pgsui/electron/app/main.js +227 -0
- pgsui/electron/app/package-lock.json +6894 -0
- pgsui/electron/app/package.json +51 -0
- pgsui/electron/app/preload.js +15 -0
- pgsui/electron/app/server.py +157 -0
- pgsui/electron/app/ui/logo.png +0 -0
- pgsui/electron/app/ui/renderer.js +131 -0
- pgsui/electron/app/ui/styles.css +59 -0
- pgsui/electron/app/ui/ui_shim.js +72 -0
- pgsui/electron/bootstrap.py +43 -0
- pgsui/electron/launch.py +57 -0
- pgsui/electron/package.json +14 -0
- pgsui/example_data/__init__.py +0 -0
- pgsui/example_data/phylip_files/__init__.py +0 -0
- pgsui/example_data/phylip_files/test.phy +0 -0
- pgsui/example_data/popmaps/__init__.py +0 -0
- pgsui/example_data/popmaps/{test.popmap → phylogen_nomx.popmap} +185 -99
- pgsui/example_data/structure_files/__init__.py +0 -0
- pgsui/example_data/structure_files/test.pops.2row.allsites.str +0 -0
- pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz +0 -0
- pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz.tbi +0 -0
- pgsui/impute/__init__.py +0 -0
- pgsui/impute/deterministic/imputers/allele_freq.py +725 -0
- pgsui/impute/deterministic/imputers/mode.py +844 -0
- pgsui/impute/deterministic/imputers/nmf.py +221 -0
- pgsui/impute/deterministic/imputers/phylo.py +973 -0
- pgsui/impute/deterministic/imputers/ref_allele.py +669 -0
- pgsui/impute/supervised/__init__.py +0 -0
- pgsui/impute/supervised/base.py +343 -0
- pgsui/impute/{unsupervised/models/in_development → supervised/imputers}/__init__.py +0 -0
- pgsui/impute/supervised/imputers/hist_gradient_boosting.py +317 -0
- pgsui/impute/supervised/imputers/random_forest.py +291 -0
- pgsui/impute/unsupervised/__init__.py +0 -0
- pgsui/impute/unsupervised/base.py +1121 -0
- pgsui/impute/unsupervised/callbacks.py +92 -262
- {simulation → pgsui/impute/unsupervised/imputers}/__init__.py +0 -0
- pgsui/impute/unsupervised/imputers/autoencoder.py +1361 -0
- pgsui/impute/unsupervised/imputers/nlpca.py +1666 -0
- pgsui/impute/unsupervised/imputers/ubp.py +1660 -0
- pgsui/impute/unsupervised/imputers/vae.py +1316 -0
- pgsui/impute/unsupervised/loss_functions.py +261 -0
- pgsui/impute/unsupervised/models/__init__.py +0 -0
- pgsui/impute/unsupervised/models/autoencoder_model.py +215 -567
- pgsui/impute/unsupervised/models/nlpca_model.py +155 -394
- pgsui/impute/unsupervised/models/ubp_model.py +180 -1106
- pgsui/impute/unsupervised/models/vae_model.py +269 -630
- pgsui/impute/unsupervised/nn_scorers.py +255 -0
- pgsui/utils/__init__.py +0 -0
- pgsui/utils/classification_viz.py +608 -0
- pgsui/utils/logging_utils.py +22 -0
- pgsui/utils/misc.py +35 -480
- pgsui/utils/plotting.py +996 -829
- pgsui/utils/pretty_metrics.py +290 -0
- pgsui/utils/scorers.py +213 -666
- pg_sui-0.2.3.dist-info/METADATA +0 -322
- pg_sui-0.2.3.dist-info/RECORD +0 -75
- pg_sui-0.2.3.dist-info/top_level.txt +0 -3
- pgsui/example_data/phylip_files/test_n10.phy +0 -118
- pgsui/example_data/phylip_files/test_n100.phy +0 -118
- pgsui/example_data/phylip_files/test_n2.phy +0 -118
- pgsui/example_data/phylip_files/test_n500.phy +0 -118
- pgsui/example_data/structure_files/test.nopops.1row.10sites.str +0 -117
- pgsui/example_data/structure_files/test.nopops.2row.100sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.10sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.30sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.allsites.str +0 -234
- pgsui/example_data/structure_files/test.pops.1row.10sites.str +0 -117
- pgsui/example_data/structure_files/test.pops.2row.10sites.str +0 -234
- pgsui/example_data/trees/test.iqtree +0 -376
- pgsui/example_data/trees/test.qmat +0 -5
- pgsui/example_data/trees/test.rate +0 -2033
- pgsui/example_data/trees/test.tre +0 -1
- pgsui/example_data/trees/test_n10.rate +0 -19
- pgsui/example_data/trees/test_n100.rate +0 -109
- pgsui/example_data/trees/test_n500.rate +0 -509
- pgsui/example_data/trees/test_siterates.txt +0 -2024
- pgsui/example_data/trees/test_siterates_n10.txt +0 -10
- pgsui/example_data/trees/test_siterates_n100.txt +0 -100
- pgsui/example_data/trees/test_siterates_n500.txt +0 -500
- pgsui/example_data/vcf_files/test.vcf +0 -244
- pgsui/example_data/vcf_files/test.vcf.gz +0 -0
- pgsui/example_data/vcf_files/test.vcf.gz.tbi +0 -0
- pgsui/impute/estimators.py +0 -1268
- pgsui/impute/impute.py +0 -1463
- pgsui/impute/simple_imputers.py +0 -1431
- pgsui/impute/supervised/iterative_imputer_fixedparams.py +0 -782
- pgsui/impute/supervised/iterative_imputer_gridsearch.py +0 -1024
- pgsui/impute/unsupervised/keras_classifiers.py +0 -697
- pgsui/impute/unsupervised/models/in_development/cnn_model.py +0 -486
- pgsui/impute/unsupervised/neural_network_imputers.py +0 -1440
- pgsui/impute/unsupervised/neural_network_methods.py +0 -1395
- pgsui/pg_sui.py +0 -261
- pgsui/utils/sequence_tools.py +0 -407
- simulation/sim_benchmarks.py +0 -333
- simulation/sim_treeparams.py +0 -475
- test/__init__.py +0 -0
- test/pg_sui_simtest.py +0 -215
- test/pg_sui_testing.py +0 -523
- test/test.py +0 -151
- test/test_pgsui.py +0 -374
- test/test_tkc.py +0 -185
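
The bulk of this release is a rewritten imputation API; the largest new module, `pgsui/impute/unsupervised/imputers/autoencoder.py` (+1361 lines), is reproduced below. As a rough orientation, the following sketch shows how the new `ImputeAutoencoder` class appears to be driven, inferred only from the constructor, `fit()`, and `transform()` signatures visible in that file. It is not taken from the package documentation: loading the `snpio` `GenotypeData` object is assumed to happen elsewhere, and the config keys shown are just examples of the `io.*`, `model.*`, and `train.*` fields referenced in `__init__`, not a complete schema.

```python
# Hedged usage sketch inferred from the autoencoder.py diff below.
# `genotype_data` is assumed to be a snpio GenotypeData object loaded elsewhere.
from pgsui.impute.unsupervised.imputers.autoencoder import ImputeAutoencoder

imputer = ImputeAutoencoder(
    genotype_data,                          # snpio GenotypeData (assumed preloaded)
    config={"model": {"latent_dim": 16}, "train": {"max_epochs": 200}},
    overrides={"io.prefix": "ae_run"},      # dot-key overrides take highest precedence
    simulate_missing=True,                  # mask known genotypes to score accuracy
    sim_strategy="random",
    sim_prop=0.10,
)
imputer.fit()                               # trains on 0/1/2 encodings (-1 = missing)
iupac_matrix = imputer.transform()          # (n_samples, n_loci) IUPAC genotype strings
```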
pgsui/impute/unsupervised/imputers/autoencoder.py

@@ -0,0 +1,1361 @@
import copy
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import optuna
import torch
import torch.nn.functional as F
from sklearn.exceptions import NotFittedError
from sklearn.model_selection import train_test_split
from snpio.analysis.genotype_encoder import GenotypeEncoder
from snpio.utils.logging import LoggerManager
from torch.optim.lr_scheduler import CosineAnnealingLR

from pgsui.data_processing.config import apply_dot_overrides, load_yaml_to_dataclass
from pgsui.data_processing.containers import AutoencoderConfig
from pgsui.data_processing.transformers import SimMissingTransformer
from pgsui.impute.unsupervised.base import BaseNNImputer
from pgsui.impute.unsupervised.callbacks import EarlyStopping
from pgsui.impute.unsupervised.loss_functions import SafeFocalCELoss
from pgsui.impute.unsupervised.models.autoencoder_model import AutoencoderModel
from pgsui.utils.logging_utils import configure_logger
from pgsui.utils.pretty_metrics import PrettyMetrics

if TYPE_CHECKING:
    from snpio import TreeParser
    from snpio.read_input.genotype_data import GenotypeData


def ensure_autoencoder_config(
    config: AutoencoderConfig | dict | str | None,
) -> AutoencoderConfig:
    """Return a concrete AutoencoderConfig from dataclass, dict, YAML path, or None.

    This function normalizes the configuration input for the Autoencoder imputer. It accepts a structured configuration in various formats, including a dataclass instance, a nested dictionary, a YAML file path, or None. The function processes the input accordingly and returns a concrete instance of AutoencoderConfig with all necessary fields populated.

    Args:
        config (AutoencoderConfig | dict | str | None): Structured configuration as dataclass, nested dict, YAML path, or None.

    Returns:
        AutoencoderConfig: Concrete configuration instance.
    """
    if config is None:
        return AutoencoderConfig()
    if isinstance(config, AutoencoderConfig):
        return config
    if isinstance(config, str):
        # YAML path — top-level `preset` key is supported
        return load_yaml_to_dataclass(config, AutoencoderConfig)
    if isinstance(config, dict):
        # Flatten dict into dot-keys then overlay onto a fresh instance
        base = AutoencoderConfig()

        def _flatten(prefix: str, d: dict, out: dict) -> dict:
            for k, v in d.items():
                kk = f"{prefix}.{k}" if prefix else k
                if isinstance(v, dict):
                    _flatten(kk, v, out)
                else:
                    out[kk] = v
            return out

        # Lift any present preset first
        preset_name = config.pop("preset", None)
        if "io" in config and isinstance(config["io"], dict):
            preset_name = preset_name or config["io"].pop("preset", None)

        if preset_name:
            base = AutoencoderConfig.from_preset(preset_name)

        flat = _flatten("", config, {})
        return apply_dot_overrides(base, flat)

    raise TypeError("config must be an AutoencoderConfig, dict, YAML path, or None.")


class ImputeAutoencoder(BaseNNImputer):
    """Impute missing genotypes with a standard Autoencoder on 0/1/2 encodings.

    This imputer uses a feedforward autoencoder architecture to learn compressed and reconstructive representations of genotype data encoded as 0 (homozygous reference), 1 (heterozygous), and 2 (homozygous alternate). Missing genotypes are represented as -1 during training and imputation.

    The model is trained to minimize a focal cross-entropy loss, which helps to address class imbalance by focusing more on hard-to-classify examples. The architecture includes configurable parameters such as the number of hidden layers, latent dimension size, dropout rate, and activation functions.
    """

    def __init__(
        self,
        genotype_data: "GenotypeData",
        *,
        tree_parser: Optional["TreeParser"] = None,
        config: Optional[Union["AutoencoderConfig", dict, str]] = None,
        overrides: dict | None = None,
        simulate_missing: bool | None = None,
        sim_strategy: (
            Literal[
                "random",
                "random_weighted",
                "random_weighted_inv",
                "nonrandom",
                "nonrandom_weighted",
            ]
            | None
        ) = None,
        sim_prop: float | None = None,
        sim_kwargs: dict | None = None,
    ) -> None:
        """Initialize the Autoencoder imputer with a unified config interface.

        This initializer sets up the Autoencoder imputer by processing the provided configuration, initializing logging, and preparing the model and data encoder. It supports configuration input as a dataclass, nested dictionary, YAML file path, or None, with optional dot-key overrides for fine-tuning specific parameters.

        Args:
            genotype_data ("GenotypeData"): Backing genotype data object.
            tree_parser (Optional["TreeParser"]): Optional SNPio phylogenetic tree parser for population-specific modes.
            config (Union["AutoencoderConfig", dict, str] | None): Structured configuration as dataclass, nested dict, YAML path, or None.
            overrides (dict | None): Optional dot-key overrides with highest precedence (e.g., {'model.latent_dim': 32}).
            simulate_missing (bool | None): Whether to simulate missing data during evaluation. If None, uses config default.
            sim_strategy (Literal["random", "random_weighted", "random_weighted_inv", "nonrandom", "nonrandom_weighted"] | None): Strategy for simulating missing data. If None, uses config default.
            sim_prop (float | None): Proportion of data to simulate as missing. If None, uses config default.
            sim_kwargs (dict | None): Additional keyword arguments for simulating missing data. If None, uses config default.
        """
        self.model_name = "ImputeAutoencoder"
        self.genotype_data = genotype_data
        self.tree_parser = tree_parser

        # Normalize config then apply highest-precedence overrides
        cfg = ensure_autoencoder_config(config)
        if overrides:
            cfg = apply_dot_overrides(cfg, overrides)
        self.cfg = cfg

        # Logger consistent with NLPCA
        logman = LoggerManager(
            __name__,
            prefix=self.cfg.io.prefix,
            debug=self.cfg.io.debug,
            verbose=self.cfg.io.verbose,
        )
        self.logger = configure_logger(
            logman.get_logger(),
            verbose=self.cfg.io.verbose,
            debug=self.cfg.io.debug,
        )

        # BaseNNImputer bootstrapping (device/dirs/logging handled here)
        super().__init__(
            model_name=self.model_name,
            genotype_data=self.genotype_data,
            prefix=self.cfg.io.prefix,
            device=self.cfg.train.device,
            verbose=self.cfg.io.verbose,
            debug=self.cfg.io.debug,
        )

        self.Model = AutoencoderModel

        # Model hook & encoder
        self.pgenc = GenotypeEncoder(genotype_data)

        # IO / global
        self.seed = self.cfg.io.seed
        self.n_jobs = self.cfg.io.n_jobs
        self.prefix = self.cfg.io.prefix
        self.scoring_averaging = self.cfg.io.scoring_averaging
        self.verbose = self.cfg.io.verbose
        self.debug = self.cfg.io.debug
        self.rng = np.random.default_rng(self.seed)
        self.pos_weights_: torch.Tensor | None = None

        # Simulated-missing controls (config defaults with ctor overrides)
        sim_cfg = getattr(self.cfg, "sim", None)
        sim_cfg_kwargs = copy.deepcopy(getattr(sim_cfg, "sim_kwargs", None) or {})
        if sim_kwargs:
            sim_cfg_kwargs.update(sim_kwargs)
        self.simulate_missing = (
            (
                sim_cfg.simulate_missing
                if simulate_missing is None
                else bool(simulate_missing)
            )
            if sim_cfg is not None
            else bool(simulate_missing)
        )
        if sim_cfg is None:
            default_strategy = "random"
            default_prop = 0.10
        else:
            default_strategy = sim_cfg.sim_strategy
            default_prop = sim_cfg.sim_prop
        self.sim_strategy = sim_strategy or default_strategy
        self.sim_prop = float(sim_prop if sim_prop is not None else default_prop)
        self.sim_kwargs = sim_cfg_kwargs

        if self.tree_parser is None and self.sim_strategy.startswith("nonrandom"):
            msg = "tree_parser is required for nonrandom and nonrandom_weighted simulated missing strategies."
            self.logger.error(msg)
            raise ValueError(msg)

        # Model hyperparams
        self.latent_dim = int(self.cfg.model.latent_dim)
        self.dropout_rate = float(self.cfg.model.dropout_rate)
        self.num_hidden_layers = int(self.cfg.model.num_hidden_layers)
        self.layer_scaling_factor = float(self.cfg.model.layer_scaling_factor)
        self.layer_schedule: str = str(self.cfg.model.layer_schedule)
        self.activation = str(self.cfg.model.hidden_activation)
        self.gamma = float(self.cfg.model.gamma)

        # Train hyperparams
        self.batch_size = int(self.cfg.train.batch_size)
        self.learning_rate = float(self.cfg.train.learning_rate)
        self.l1_penalty: float = float(self.cfg.train.l1_penalty)
        self.early_stop_gen = int(self.cfg.train.early_stop_gen)
        self.min_epochs = int(self.cfg.train.min_epochs)
        self.epochs = int(self.cfg.train.max_epochs)
        self.validation_split = float(self.cfg.train.validation_split)
        self.beta = float(self.cfg.train.weights_beta)
        self.max_ratio = float(self.cfg.train.weights_max_ratio)

        # Tuning
        self.tune = bool(self.cfg.tune.enabled)
        self.tune_fast = bool(self.cfg.tune.fast)
        self.tune_batch_size = int(self.cfg.tune.batch_size)
        self.tune_epochs = int(self.cfg.tune.epochs)
        self.tune_eval_interval = int(self.cfg.tune.eval_interval)
        self.tune_metric: str = self.cfg.tune.metric

        if self.tune_metric is not None:
            self.tune_metric_: (
                Literal[
                    "pr_macro",
                    "f1",
                    "accuracy",
                    "precision",
                    "recall",
                    "roc_auc",
                    "average_precision",
                ]
                | None
            ) = self.cfg.tune.metric

        self.n_trials = int(self.cfg.tune.n_trials)
        self.tune_save_db = bool(self.cfg.tune.save_db)
        self.tune_resume = bool(self.cfg.tune.resume)
        self.tune_max_samples = int(self.cfg.tune.max_samples)
        self.tune_max_loci = int(self.cfg.tune.max_loci)
        self.tune_infer_epochs = int(
            getattr(self.cfg.tune, "infer_epochs", 0)
        )  # AE unused
        self.tune_patience = int(self.cfg.tune.patience)

        # Evaluate
        # AE does not optimize latents, so these are unused / fixed
        self.eval_latent_steps: int = 0
        self.eval_latent_lr: float = 0.0
        self.eval_latent_weight_decay: float = 0.0

        # Plotting (parity with NLPCA PlotConfig)
        self.plot_format: Literal["pdf", "png", "jpg", "jpeg", "svg"] = (
            self.cfg.plot.fmt
        )
        self.plot_dpi = int(self.cfg.plot.dpi)
        self.plot_fontsize = int(self.cfg.plot.fontsize)
        self.title_fontsize = int(self.cfg.plot.fontsize)
        self.despine = bool(self.cfg.plot.despine)
        self.show_plots = bool(self.cfg.plot.show)

        # Core derived at fit-time
        self.is_haploid: bool = False
        self.num_classes_: int | None = None
        self.model_params: Dict[str, Any] = {}
        self.sim_mask_global_: np.ndarray | None = None
        self.sim_mask_train_: np.ndarray | None = None
        self.sim_mask_test_: np.ndarray | None = None

    def fit(self) -> "ImputeAutoencoder":
        """Fit the autoencoder on 0/1/2 encoded genotypes (missing -> -1).

        This method trains the autoencoder model using the provided genotype data. It prepares the data by encoding genotypes as 0, 1, and 2, with missing values represented internally as -1. (When simulated-missing loci are generated via ``SimMissingTransformer`` they are first marked with -9 but are immediately re-encoded as -1 prior to training.) The method splits the data into training and validation sets, initializes the model and training parameters, and performs training with optional hyperparameter tuning. After training, it evaluates the model on the validation set and stores the fitted model and training history.

        Returns:
            ImputeAutoencoder: Fitted instance.

        Raises:
            RuntimeError: If training fails.
        """
        self.logger.info(f"Fitting {self.model_name} model...")

        # --- Data prep (mirror NLPCA) ---
        X012 = self._get_float_genotypes(copy=True)
        GT_full = np.nan_to_num(X012, nan=-1.0, copy=True)
        self.ground_truth_ = GT_full.astype(np.int64, copy=False)

        self.sim_mask_global_ = None
        cache_key = self._sim_mask_cache_key()
        if self.simulate_missing:
            cached_mask = (
                None if cache_key is None else self._sim_mask_cache.get(cache_key)
            )
            if cached_mask is not None:
                self.sim_mask_global_ = cached_mask.copy()
            else:
                tr = SimMissingTransformer(
                    genotype_data=self.genotype_data,
                    tree_parser=self.tree_parser,
                    prop_missing=self.sim_prop,
                    strategy=self.sim_strategy,
                    missing_val=-9,
                    mask_missing=True,
                    verbose=self.verbose,
                    **self.sim_kwargs,
                )
                tr.fit(X012.copy())
                self.sim_mask_global_ = tr.sim_missing_mask_.astype(bool)
                if cache_key is not None:
                    self._sim_mask_cache[cache_key] = self.sim_mask_global_.copy()

            X_for_model = self.ground_truth_.copy()
            X_for_model[self.sim_mask_global_] = -1
        else:
            X_for_model = self.ground_truth_.copy()

        if self.genotype_data.snp_data is None:
            msg = "SNP data is required for Autoencoder imputer."
            self.logger.error(msg)
            raise TypeError(msg)

        # Ploidy & classes
        self.is_haploid = bool(
            np.all(
                np.isin(
                    self.genotype_data.snp_data,
                    ["A", "C", "G", "T", "N", "-", ".", "?"],
                )
            )
        )
        self.ploidy = 1 if self.is_haploid else 2
        # Scoring still uses 3 labels for diploid (REF/HET/ALT); model head uses 2 logits
        self.num_classes_ = 2 if self.is_haploid else 3
        self.output_classes_ = 2
        self.logger.info(
            f"Data is {'haploid' if self.is_haploid else 'diploid'}; "
            f"using {self.num_classes_} classes for scoring and {self.output_classes_} output channels."
        )

        if self.is_haploid:
            self.ground_truth_[self.ground_truth_ == 2] = 1
            X_for_model[X_for_model == 2] = 1

        n_samples, self.num_features_ = X_for_model.shape

        # Model params (decoder outputs L * K logits)
        self.model_params = {
            "n_features": self.num_features_,
            "num_classes": self.output_classes_,
            "latent_dim": self.latent_dim,
            "dropout_rate": self.dropout_rate,
            "activation": self.activation,
        }

        # Train/Val split
        indices = np.arange(n_samples)
        train_idx, val_idx = train_test_split(
            indices, test_size=self.validation_split, random_state=self.seed
        )
        self.train_idx_, self.test_idx_ = train_idx, val_idx
        self.X_train_ = X_for_model[train_idx]
        self.X_val_ = X_for_model[val_idx]
        self.GT_train_full_ = self.ground_truth_[train_idx]
        self.GT_test_full_ = self.ground_truth_[val_idx]

        if self.sim_mask_global_ is not None:
            self.sim_mask_train_ = self.sim_mask_global_[train_idx]
            self.sim_mask_test_ = self.sim_mask_global_[val_idx]
        else:
            self.sim_mask_train_ = None
            self.sim_mask_test_ = None

        # Pos weights for diploid multilabel path (must exist before tuning)
        if not self.is_haploid:
            self.pos_weights_ = self._compute_pos_weights(self.X_train_)
        else:
            self.pos_weights_ = None

        # Plotters/scorers (shared utilities)
        self.plotter_, self.scorers_ = self.initialize_plotting_and_scorers()

        # Tuning (optional; AE never needs latent refinement)
        if self.tune:
            self.tune_hyperparameters()

        # Best params (tuned or default)
        self.best_params_ = getattr(self, "best_params_", self._default_best_params())

        # Class weights (device-aware)
        self.class_weights_ = self._normalize_class_weights(
            self._class_weights_from_zygosity(self.X_train_)
        )

        # DataLoader
        train_loader = self._get_data_loaders(self.X_train_)

        # Build & train
        model = self.build_model(self.Model, self.best_params_)
        model.apply(self.initialize_weights)

        loss, trained_model, history = self._train_and_validate_model(
            model=model,
            loader=train_loader,
            lr=self.learning_rate,
            l1_penalty=self.l1_penalty,
            return_history=True,
            class_weights=self.class_weights_,
            X_val=self.X_val_,
            params=self.best_params_,
            prune_metric=self.tune_metric,
            prune_warmup_epochs=10,
            eval_interval=1,
            eval_requires_latents=False,
            eval_latent_steps=0,
            eval_latent_lr=0.0,
            eval_latent_weight_decay=0.0,
        )

        if trained_model is None:
            msg = "Autoencoder training failed; no model was returned."
            self.logger.error(msg)
            raise RuntimeError(msg)

        torch.save(
            trained_model.state_dict(),
            self.models_dir / f"final_model_{self.model_name}.pt",
        )

        hist: Dict[str, List[float] | Dict[str, List[float]] | None] | None = {
            "Train": history
        }
        self.best_loss_, self.model_, self.history_ = (loss, trained_model, hist)
        self.is_fit_ = True

        # Evaluate on validation set (parity with NLPCA reporting)
        eval_mask = (
            self.sim_mask_test_
            if (self.simulate_missing and self.sim_mask_test_ is not None)
            else None
        )
        self._evaluate_model(
            self.X_val_, self.model_, self.best_params_, eval_mask_override=eval_mask
        )
        self.plotter_.plot_history(self.history_)
        self._save_best_params(self.best_params_)

        return self

    def transform(self) -> np.ndarray:
        """Impute missing genotypes (0/1/2) and return IUPAC strings.

        This method imputes missing genotypes in the dataset using the trained autoencoder model. It predicts the most likely genotype (0, 1, or 2) for each missing entry and fills in these values. The imputed genotypes are then decoded back to IUPAC string format for easier interpretation.

        Returns:
            np.ndarray: IUPAC strings of shape (n_samples, n_loci).

        Raises:
            NotFittedError: If called before fit().
        """
        if not getattr(self, "is_fit_", False):
            raise NotFittedError("Model is not fitted. Call fit() before transform().")

        self.logger.info(f"Imputing entire dataset with {self.model_name}...")
        X_to_impute = self.ground_truth_.copy()

        # Predict with masked inputs (no latent optimization)
        pred_labels, _ = self._predict(self.model_, X=X_to_impute, return_proba=True)

        # Fill only missing
        missing_mask = X_to_impute == -1
        imputed_array = X_to_impute.copy()
        imputed_array[missing_mask] = pred_labels[missing_mask]

        # Decode to IUPAC & optionally plot
        imputed_genotypes = self.pgenc.decode_012(imputed_array)
        if self.show_plots:
            original_genotypes = self.pgenc.decode_012(X_to_impute)
            plt.rcParams.update(self.plotter_.param_dict)
            self.plotter_.plot_gt_distribution(original_genotypes, is_imputed=False)
            self.plotter_.plot_gt_distribution(imputed_genotypes, is_imputed=True)

        return imputed_genotypes

    def _get_data_loaders(self, y: np.ndarray) -> torch.utils.data.DataLoader:
        """Create DataLoader over indices + integer targets (-1 for missing).

        This method creates a PyTorch DataLoader that yields batches of indices and their corresponding genotype targets encoded as integers (0, 1, 2) with -1 indicating missing values. The DataLoader is shuffled to ensure random sampling during training.

        Args:
            y (np.ndarray): 0/1/2 matrix with -1 for missing.

        Returns:
            torch.utils.data.DataLoader: Shuffled DataLoader.
        """
        y_tensor = torch.from_numpy(y).long()
        indices = torch.arange(len(y), dtype=torch.long)
        dataset = torch.utils.data.TensorDataset(indices, y_tensor)
        pin_memory = self.device.type == "cuda"
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=True,
            pin_memory=pin_memory,
        )

    def _train_and_validate_model(
        self,
        model: torch.nn.Module,
        loader: torch.utils.data.DataLoader,
        lr: float,
        l1_penalty: float,
        trial: optuna.Trial | None = None,
        return_history: bool = False,
        class_weights: torch.Tensor | None = None,
        *,
        X_val: np.ndarray | None = None,
        params: dict | None = None,
        prune_metric: str = "f1",  # "f1" | "accuracy" | "pr_macro"
        prune_warmup_epochs: int = 10,
        eval_interval: int = 1,
        # Evaluation parameters (AE ignores latent refinement knobs)
        eval_requires_latents: bool = False,  # AE: always False
        eval_latent_steps: int = 0,
        eval_latent_lr: float = 0.0,
        eval_latent_weight_decay: float = 0.0,
    ) -> Tuple[float, torch.nn.Module | None, list | None]:
        """Wrap the AE training loop (no latent optimizer), with Optuna pruning.

        This method orchestrates the training of the autoencoder model using the provided DataLoader. It sets up the optimizer and learning rate scheduler, and executes the training loop with support for early stopping and Optuna pruning based on validation performance. The method returns the best validation loss, the best model state, and optionally the training history.

        Args:
            model (torch.nn.Module): Autoencoder model.
            loader (torch.utils.data.DataLoader): Batches (indices, y_int) where y_int is 0/1/2; -1 for missing.
            lr (float): Learning rate.
            l1_penalty (float): L1 regularization coeff.
            trial (optuna.Trial | None): Optuna trial for pruning (optional).
            return_history (bool): If True, return train loss history.
            class_weights (torch.Tensor | None): Class weights tensor (on device).
            X_val (np.ndarray | None): Validation matrix (0/1/2 with -1 for missing).
            params (dict | None): Model params for evaluation.
            prune_metric (str): Metric for pruning reports.
            prune_warmup_epochs (int): Pruning warmup epochs.
            eval_interval (int): Eval frequency (epochs).
            eval_requires_latents (bool): Ignored for AE (no latent inference).
            eval_latent_steps (int): Unused for AE.
            eval_latent_lr (float): Unused for AE.
            eval_latent_weight_decay (float): Unused for AE.

        Returns:
            Tuple[float, torch.nn.Module | None, list | None]: (best_loss, best_model, history or None).
        """
        if class_weights is None:
            msg = "Must provide class_weights."
            self.logger.error(msg)
            raise TypeError(msg)

        # Epoch budget mirrors NLPCA config (tuning vs final)
        max_epochs = (
            self.tune_epochs if (trial is not None and self.tune_fast) else self.epochs
        )

        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        scheduler = CosineAnnealingLR(optimizer, T_max=max_epochs)

        best_loss, best_model, hist = self._execute_training_loop(
            loader=loader,
            optimizer=optimizer,
            scheduler=scheduler,
            model=model,
            l1_penalty=l1_penalty,
            trial=trial,
            return_history=return_history,
            class_weights=class_weights,
            X_val=X_val,
            params=params,
            prune_metric=prune_metric,
            prune_warmup_epochs=prune_warmup_epochs,
            eval_interval=eval_interval,
            eval_requires_latents=False,  # AE: no latent inference
            eval_latent_steps=0,
            eval_latent_lr=0.0,
            eval_latent_weight_decay=0.0,
        )
        if return_history:
            return best_loss, best_model, hist

        return best_loss, best_model, None

    def _execute_training_loop(
        self,
        loader: torch.utils.data.DataLoader,
        optimizer: torch.optim.Optimizer,
        scheduler: CosineAnnealingLR,
        model: torch.nn.Module,
        l1_penalty: float,
        trial: optuna.Trial | None,
        return_history: bool,
        class_weights: torch.Tensor,
        *,
        X_val: np.ndarray | None = None,
        params: dict | None = None,
        prune_metric: str = "f1",
        prune_warmup_epochs: int = 10,
        eval_interval: int = 1,
        # Evaluation parameters (AE ignores latent refinement knobs)
        eval_requires_latents: bool = False,  # AE: False
        eval_latent_steps: int = 0,
        eval_latent_lr: float = 0.0,
        eval_latent_weight_decay: float = 0.0,
    ) -> Tuple[float, torch.nn.Module, list]:
        """Train AE with focal CE (gamma warm/ramp) + early stopping & pruning.

        This method executes the training loop for the autoencoder model, performing one epoch at a time. It computes the focal cross-entropy loss while ignoring masked (missing) values and applies L1 regularization if specified. The method incorporates early stopping based on validation performance and supports Optuna pruning to terminate unpromising trials early. It returns the best validation loss, the best model state, and optionally the training history.

        Args:
            loader (torch.utils.data.DataLoader): Batches (indices, y_int) where y_int is 0/1/2; -1 for missing.
            optimizer (torch.optim.Optimizer): Optimizer.
            scheduler (torch.optim.lr_scheduler._LRScheduler): LR scheduler.
            model (torch.nn.Module): Autoencoder model.
            l1_penalty (float): L1 regularization coeff.
            trial (optuna.Trial | None): Optuna trial for pruning (optional).
            return_history (bool): If True, return train loss history.
            class_weights (torch.Tensor): Class weights tensor (on device).
            X_val (np.ndarray | None): Validation matrix (0/1/2 with -1 for missing).
            params (dict | None): Model params for evaluation.
            prune_metric (str): Metric for pruning reports.
            prune_warmup_epochs (int): Pruning warmup epochs.
            eval_interval (int): Eval frequency (epochs).
            eval_requires_latents (bool): Ignored for AE (no latent inference).
            eval_latent_steps (int): Unused for AE.
            eval_latent_lr (float): Unused for AE.
            eval_latent_weight_decay (float): Unused for AE.

        Returns:
            Tuple[float, torch.nn.Module, list]: Best validation loss, best model, and training history.
        """
        best_loss = float("inf")
        best_model = None
        history: list[float] = []

        early_stopping = EarlyStopping(
            patience=self.early_stop_gen,
            min_epochs=self.min_epochs,
            verbose=self.verbose,
            prefix=self.prefix,
            debug=self.debug,
        )

        gamma_val = self.gamma
        if isinstance(gamma_val, (list, tuple)):
            if len(gamma_val) == 0:
                raise ValueError("gamma list is empty.")
            gamma_val = gamma_val[0]

        gamma_final = float(gamma_val)
        gamma_warm, gamma_ramp = 50, 100

        # Optional LR warmup
        warmup_epochs = int(getattr(self, "lr_warmup_epochs", 5))
        base_lr = float(optimizer.param_groups[0]["lr"])
        min_lr = base_lr * 0.1

        max_epochs = int(getattr(scheduler, "T_max", getattr(self, "epochs", 100)))

        for epoch in range(max_epochs):
            # focal γ schedule (for stable training)
            if epoch < gamma_warm:
                model.gamma = 0.0  # type: ignore
            elif epoch < gamma_warm + gamma_ramp:
                model.gamma = gamma_final * ((epoch - gamma_warm) / gamma_ramp)  # type: ignore
            else:
                model.gamma = gamma_final  # type: ignore

            # LR warmup
            if epoch < warmup_epochs:
                scale = float(epoch + 1) / warmup_epochs
                for g in optimizer.param_groups:
                    g["lr"] = min_lr + (base_lr - min_lr) * scale

            train_loss = self._train_step(
                loader=loader,
                optimizer=optimizer,
                model=model,
                l1_penalty=l1_penalty,
                class_weights=class_weights,
            )

            # Abort or prune on non-finite epoch loss
            if not np.isfinite(train_loss):
                if trial is not None:
                    raise optuna.exceptions.TrialPruned("Epoch loss non-finite.")
                # Soft reset suggestion: reduce LR and continue, or break
                self.logger.warning(
                    "Non-finite epoch loss. Reducing LR by 10 percent and continuing."
                )
                for g in optimizer.param_groups:
                    g["lr"] *= 0.9
                continue

            scheduler.step()
            if return_history:
                history.append(train_loss)

            early_stopping(train_loss, model)
            if early_stopping.early_stop:
                self.logger.info(f"Early stopping at epoch {epoch + 1}.")
                break

            # Optuna report/prune on validation metric
            if (
                trial is not None
                and X_val is not None
                and ((epoch + 1) % eval_interval == 0)
            ):
                metric_key = prune_metric or getattr(self, "tune_metric", "f1")
                mask_override = None
                if (
                    self.simulate_missing
                    and getattr(self, "sim_mask_test_", None) is not None
                    and getattr(self, "X_val_", None) is not None
                    and X_val.shape == self.X_val_.shape
                ):
                    mask_override = self.sim_mask_test_
                metric_val = self._eval_for_pruning(
                    model=model,
                    X_val=X_val,
                    params=params or getattr(self, "best_params_", {}),
                    metric=metric_key,
                    objective_mode=True,
                    do_latent_infer=False,  # AE: False
                    latent_steps=0,
                    latent_lr=0.0,
                    latent_weight_decay=0.0,
                    latent_seed=self.seed,  # type: ignore
                    _latent_cache=None,  # AE: not used
                    _latent_cache_key=None,
                    eval_mask_override=mask_override,
                )
                trial.report(metric_val, step=epoch + 1)
                if (epoch + 1) >= prune_warmup_epochs and trial.should_prune():
                    raise optuna.exceptions.TrialPruned(
                        f"Pruned at epoch {epoch + 1}: {metric_key}={metric_val:.5f}"
                    )

        best_loss = early_stopping.best_score
        if early_stopping.best_model is not None:
            best_model = copy.deepcopy(early_stopping.best_model)
        else:
            best_model = copy.deepcopy(model)
        return best_loss, best_model, history

    def _train_step(
        self,
        loader: torch.utils.data.DataLoader,
        optimizer: torch.optim.Optimizer,
        model: torch.nn.Module,
        l1_penalty: float,
        class_weights: torch.Tensor,
    ) -> float:
        """One epoch with stable focal CE and NaN/Inf guards."""
        model.train()
        running = 0.0
        num_batches = 0
        l1_params = tuple(p for p in model.parameters() if p.requires_grad)
        if class_weights is not None and class_weights.device != self.device:
            class_weights = class_weights.to(self.device)

        # Use model.gamma if present, else self.gamma
        gamma = float(getattr(model, "gamma", getattr(self, "gamma", 0.0)))
        gamma = float(torch.tensor(gamma).clamp(min=0.0, max=10.0))  # sane bound
        ce_criterion = SafeFocalCELoss(
            gamma=gamma, weight=class_weights, ignore_index=-1
        )

        for _, y_batch in loader:
            optimizer.zero_grad(set_to_none=True)
            y_batch = y_batch.to(self.device, non_blocking=True)

            # Inputs: one-hot with zeros for missing; Targets: long ints with -1 for missing
            if self.is_haploid:
                x_in = self._one_hot_encode_012(y_batch)  # (B, L, 2)
                logits = model(x_in).view(-1, self.num_features_, self.output_classes_)
                logits_flat = logits.view(-1, self.output_classes_)
                targets_flat = y_batch.view(-1).long()
                if not torch.isfinite(logits_flat).all():
                    continue
                loss = ce_criterion(logits_flat, targets_flat)
            else:
                x_in = self._encode_multilabel_inputs(y_batch)  # (B, L, 2)
                logits = model(x_in).view(-1, self.num_features_, self.output_classes_)
                if not torch.isfinite(logits).all():
                    continue
                pos_w = getattr(self, "pos_weights_", None)
                targets = self._multi_hot_targets(y_batch)  # float, same shape
                bce = F.binary_cross_entropy_with_logits(
                    logits, targets, pos_weight=pos_w, reduction="none"
                )
                mask = (y_batch != -1).unsqueeze(-1).float()
                loss = (bce * mask).sum() / mask.sum().clamp_min(1e-8)

            if l1_penalty > 0:
                l1 = torch.zeros((), device=self.device)
                for p in l1_params:
                    l1 = l1 + p.abs().sum()
                loss = loss + l1_penalty * l1

            # Final guard
            if not torch.isfinite(loss):
                continue

            loss.backward()

            # Clip to prevent exploding grads
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            # If grads blew up to non-finite, skip update
            if any(
                (not torch.isfinite(p.grad).all())
                for p in model.parameters()
                if p.grad is not None
            ):
                optimizer.zero_grad(set_to_none=True)
                continue

            optimizer.step()

            running += float(loss.detach().item())
            num_batches += 1

        if num_batches == 0:
            return float("inf")  # signal upstream that epoch had no usable batches
        return running / num_batches

    def _predict(
        self,
        model: torch.nn.Module,
        X: np.ndarray | torch.Tensor,
        return_proba: bool = False,
    ) -> Tuple[np.ndarray, np.ndarray] | np.ndarray:
        """Predict 0/1/2 labels (and probabilities) from masked inputs.

        This method generates predictions from the trained autoencoder model for the provided input data. It processes the input data, performs a forward pass through the model, and computes the predicted genotype labels (0, 1, or 2) along with their associated probabilities if requested.

        Args:
            model (torch.nn.Module): Trained model.
            X (np.ndarray | torch.Tensor): 0/1/2 matrix with -1 for missing.
            return_proba (bool): If True, return probabilities.

        Returns:
            Tuple[np.ndarray, np.ndarray] | np.ndarray: Predicted labels, and probabilities if requested.
        """
        if model is None:
            msg = "Model is not trained. Call fit() before predict()."
            self.logger.error(msg)
            raise NotFittedError(msg)

        model.eval()
        with torch.no_grad():
            X_tensor = torch.from_numpy(X) if isinstance(X, np.ndarray) else X
            X_tensor = X_tensor.to(self.device).long()
            if self.is_haploid:
                x_ohe = self._one_hot_encode_012(X_tensor)
                logits = model(x_ohe).view(-1, self.num_features_, self.output_classes_)
                probas = torch.softmax(logits, dim=-1)
                labels = torch.argmax(probas, dim=-1)
            else:
                x_in = self._encode_multilabel_inputs(X_tensor)
                logits = model(x_in).view(-1, self.num_features_, self.output_classes_)
                probas_2 = torch.sigmoid(logits)
                p_ref = probas_2[..., 0]
                p_alt = probas_2[..., 1]
                p_het = p_ref * p_alt
                p_ref_only = p_ref * (1 - p_alt)
                p_alt_only = p_alt * (1 - p_ref)
                stacked = torch.stack([p_ref_only, p_het, p_alt_only], dim=-1)
                stacked = stacked / stacked.sum(dim=-1, keepdim=True).clamp_min(1e-8)
                probas = stacked
                labels = torch.argmax(stacked, dim=-1)

        if return_proba:
            return labels.cpu().numpy(), probas.cpu().numpy()

        return labels.cpu().numpy()

    def _encode_multilabel_inputs(self, y: torch.Tensor) -> torch.Tensor:
        """Two-channel multi-hot for diploid: REF-only, ALT-only; HET sets both."""
        if self.is_haploid:
            return self._one_hot_encode_012(y)
        y = y.to(self.device)
        shape = y.shape + (2,)
        out = torch.zeros(shape, device=self.device, dtype=torch.float32)
        valid = y != -1
        ref_mask = valid & (y != 2)
        alt_mask = valid & (y != 0)
        out[ref_mask, 0] = 1.0
        out[alt_mask, 1] = 1.0
        return out

    def _multi_hot_targets(self, y: torch.Tensor) -> torch.Tensor:
        """Targets aligned with _encode_multilabel_inputs for diploid training."""
        if self.is_haploid:
            # One-hot CE path expects integer targets; handled upstream.
            raise RuntimeError("_multi_hot_targets called for haploid data.")
        y = y.to(self.device)
        out = torch.zeros(y.shape + (2,), device=self.device, dtype=torch.float32)
        valid = y != -1
        ref_mask = valid & (y != 2)
        alt_mask = valid & (y != 0)
        out[ref_mask, 0] = 1.0
        out[alt_mask, 1] = 1.0
        return out

    def _compute_pos_weights(self, X: np.ndarray) -> torch.Tensor:
        """Balance REF/ALT channels for multilabel BCE."""
        ref_pos = np.count_nonzero((X == 0) | (X == 1))
        alt_pos = np.count_nonzero((X == 2) | (X == 1))
        total_valid = np.count_nonzero(X != -1)
        pos_counts = np.array([ref_pos, alt_pos], dtype=np.float32)
        neg_counts = np.maximum(total_valid - pos_counts, 1.0)
        pos_counts = np.maximum(pos_counts, 1.0)
        weights = neg_counts / pos_counts
        return torch.tensor(weights, device=self.device, dtype=torch.float32)

def _evaluate_model(
|
|
930
|
+
self,
|
|
931
|
+
X_val: np.ndarray,
|
|
932
|
+
model: torch.nn.Module,
|
|
933
|
+
params: dict,
|
|
934
|
+
objective_mode: bool = False,
|
|
935
|
+
latent_vectors_val: Optional[np.ndarray] = None,
|
|
936
|
+
*,
|
|
937
|
+
eval_mask_override: np.ndarray | None = None,
|
|
938
|
+
) -> Dict[str, float]:
|
|
939
|
+
"""Evaluate on 0/1/2; then IUPAC decoding and 10-base integer reports.
|
|
940
|
+
|
|
941
|
+
This method evaluates the trained autoencoder model on a validation set, computing various classification metrics based on the predicted and true genotypes. It handles both haploid and diploid data appropriately and generates detailed classification reports for both genotype and IUPAC/10-base integer encodings.
|
|
942
|
+
|
|
943
|
+
Args:
|
|
944
|
+
X_val (np.ndarray): Validation set 0/1/2 matrix with -1
|
|
945
|
+
for missing.
|
|
946
|
+
model (torch.nn.Module): Trained model.
|
|
947
|
+
params (dict): Model parameters.
|
|
948
|
+
objective_mode (bool): If True, suppress logging and reports.
|
|
949
|
+
latent_vectors_val (Optional[np.ndarray]): Unused for AE.
|
|
950
|
+
eval_mask_override (np.ndarray | None): Optional mask to override default evaluation mask.
|
|
951
|
+
|
|
952
|
+
Returns:
|
|
953
|
+
Dict[str, float]: Dictionary of evaluation metrics.
|
|
954
|
+
"""
|
|
955
|
+
pred_labels, pred_probas = self._predict(
|
|
956
|
+
model=model, X=X_val, return_proba=True
|
|
957
|
+
)
|
|
958
|
+
|
|
959
|
+
finite_mask = np.all(np.isfinite(pred_probas), axis=-1) # (N, L)
|
|
960
|
+
|
|
961
|
+
# FIX 1: Check ROWS (shape[0]) only. X_val might be a feature subset.
|
|
962
|
+
if (
|
|
963
|
+
hasattr(self, "X_val_")
|
|
964
|
+
and getattr(self, "X_val_", None) is not None
|
|
965
|
+
and X_val.shape[0] == self.X_val_.shape[0]
|
|
966
|
+
):
|
|
967
|
+
GT_ref = getattr(self, "GT_test_full_", self.ground_truth_)
|
|
968
|
+
elif (
|
|
969
|
+
hasattr(self, "X_train_")
|
|
970
|
+
and getattr(self, "X_train_", None) is not None
|
|
971
|
+
and X_val.shape[0] == self.X_train_.shape[0]
|
|
972
|
+
):
|
|
973
|
+
GT_ref = getattr(self, "GT_train_full_", self.ground_truth_)
|
|
974
|
+
else:
|
|
975
|
+
GT_ref = self.ground_truth_
|
|
976
|
+
|
|
977
|
+
# FIX 2: Handle Feature Mismatch (e.g., tune_fast feature subsetting)
|
|
978
|
+
# If the GT source has more columns than X_val, slice it to match.
|
|
979
|
+
if GT_ref.shape[1] > X_val.shape[1]:
|
|
980
|
+
GT_ref = GT_ref[:, : X_val.shape[1]]
|
|
981
|
+
|
|
982
|
+
# Fallback if rows mismatch (unlikely after Fix 1, but safe to keep)
|
|
983
|
+
if GT_ref.shape != X_val.shape:
|
|
984
|
+
# If completely different, we can't use the ground truth object.
|
|
985
|
+
# Fall back to X_val (this implies only observed values are scored)
|
|
986
|
+
GT_ref = X_val
|
|
987
|
+
|
|
988
|
+
if eval_mask_override is not None:
|
|
989
|
+
# FIX 3: Allow override mask to be sliced if it's too wide
|
|
990
|
+
if eval_mask_override.shape[0] != X_val.shape[0]:
|
|
991
|
+
msg = (
|
|
992
|
+
f"eval_mask_override rows {eval_mask_override.shape[0]} "
|
|
993
|
+
f"does not match X_val rows {X_val.shape[0]}"
|
|
994
|
+
)
|
|
995
|
+
self.logger.error(msg)
|
|
996
|
+
raise ValueError(msg)
|
|
997
|
+
|
|
998
|
+
if eval_mask_override.shape[1] > X_val.shape[1]:
|
|
999
|
+
eval_mask = eval_mask_override[:, : X_val.shape[1]].astype(bool)
|
|
1000
|
+
else:
|
|
1001
|
+
eval_mask = eval_mask_override.astype(bool)
|
|
1002
|
+
else:
|
|
1003
|
+
eval_mask = X_val != -1
|
|
1004
|
+
|
|
1005
|
+
# Combine masks
|
|
1006
|
+
eval_mask = eval_mask & finite_mask & (GT_ref != -1)
|
|
1007
|
+
|
|
1008
|
+
y_true_flat = GT_ref[eval_mask].astype(np.int64, copy=False)
|
|
1009
|
+
y_pred_flat = pred_labels[eval_mask].astype(np.int64, copy=False)
|
|
1010
|
+
        y_proba_flat = pred_probas[eval_mask].astype(np.float64, copy=False)

        if y_true_flat.size == 0:
            self.tune_metric = "f1" if self.tune_metric is None else self.tune_metric
            return {self.tune_metric: 0.0}

        # ensure valid probability simplex after masking (no NaNs/Infs, sums=1)
        y_proba_flat = np.clip(y_proba_flat, 0.0, 1.0)
        row_sums = y_proba_flat.sum(axis=1, keepdims=True)
        row_sums[row_sums == 0] = 1.0
        y_proba_flat = y_proba_flat / row_sums

        labels_for_scoring = [0, 1] if self.is_haploid else [0, 1, 2]
        target_names = ["REF", "ALT"] if self.is_haploid else ["REF", "HET", "ALT"]

        if self.is_haploid:
            y_true_flat = y_true_flat.copy()
            y_pred_flat = y_pred_flat.copy()
            y_true_flat[y_true_flat == 2] = 1
            y_pred_flat[y_pred_flat == 2] = 1
            # collapse probs to 2-class
            proba_2 = np.zeros((len(y_proba_flat), 2), dtype=y_proba_flat.dtype)
            proba_2[:, 0] = y_proba_flat[:, 0]
            proba_2[:, 1] = y_proba_flat[:, 2]
            y_proba_flat = proba_2

        y_true_ohe = np.eye(len(labels_for_scoring))[y_true_flat]

        tune_metric_tmp: Literal[
            "pr_macro",
            "roc_auc",
            "average_precision",
            "accuracy",
            "f1",
            "precision",
            "recall",
        ]
        if self.tune_metric_ is not None:
            tune_metric_tmp = self.tune_metric_
        else:
            tune_metric_tmp = "f1"  # Default if not tuning

        metrics = self.scorers_.evaluate(
            y_true_flat,
            y_pred_flat,
            y_true_ohe,
            y_proba_flat,
            objective_mode,
            tune_metric_tmp,
        )

        if not objective_mode:
            pm = PrettyMetrics(
                metrics, precision=3, title=f"{self.model_name} Validation Metrics"
            )
            pm.render()  # prints a command-line table

            # Primary report (REF/HET/ALT or REF/ALT)
            self._make_class_reports(
                y_true=y_true_flat,
                y_pred_proba=y_proba_flat,
                y_pred=y_pred_flat,
                metrics=metrics,
                labels=target_names,
            )

            # IUPAC decode & 10-base integer reports
            # Now safe because GT_ref has been sliced to match X_val dimensions
            y_true_dec = self.pgenc.decode_012(
                GT_ref.reshape(X_val.shape[0], X_val.shape[1])
            )
            X_pred = X_val.copy()
            X_pred[eval_mask] = y_pred_flat

            # Use X_val.shape[1] (current features) not self.num_features_ (original features)
            y_pred_dec = self.pgenc.decode_012(
                X_pred.reshape(X_val.shape[0], X_val.shape[1])
            )

            encodings_dict = {
                "A": 0,
                "C": 1,
                "G": 2,
                "T": 3,
                "W": 4,
                "R": 5,
                "M": 6,
                "K": 7,
                "Y": 8,
                "S": 9,
                "N": -1,
            }
            y_true_int = self.pgenc.convert_int_iupac(
                y_true_dec, encodings_dict=encodings_dict
            )
            y_pred_int = self.pgenc.convert_int_iupac(
                y_pred_dec, encodings_dict=encodings_dict
            )

            valid_iupac_mask = y_true_int[eval_mask] >= 0
            if valid_iupac_mask.any():
                self._make_class_reports(
                    y_true=y_true_int[eval_mask][valid_iupac_mask],
                    y_pred=y_pred_int[eval_mask][valid_iupac_mask],
                    metrics=metrics,
                    y_pred_proba=None,
                    labels=["A", "C", "G", "T", "W", "R", "M", "K", "Y", "S"],
                )
            else:
                self.logger.warning(
                    "Skipped IUPAC confusion matrix: No valid ground truths."
                )

        return metrics

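The metric-preparation steps above are plain NumPy, so a small worked sketch may help. The values, shapes, and variable names below are invented for illustration, and the project-specific helpers (`pgenc.decode_012`, `convert_int_iupac`, `scorers_.evaluate`) are not reproduced; only the clip-and-renormalize, haploid collapse, one-hot, and IUPAC-to-integer steps are shown.

```python
import numpy as np

# 1) Clip to [0, 1] and renormalize each row to a valid probability simplex;
#    all-zero rows keep a denominator of 1.0 so they stay all-zero.
proba = np.array([[0.2, 0.2, 0.6],
                  [1.4, -0.1, 0.1],
                  [0.0, 0.0, 0.0]])
proba = np.clip(proba, 0.0, 1.0)
row_sums = proba.sum(axis=1, keepdims=True)
row_sums[row_sums == 0] = 1.0
proba = proba / row_sums  # rows: [0.2, 0.2, 0.6], [~0.91, 0, ~0.09], [0, 0, 0]

# 2) Haploid collapse: genotype 2 becomes class 1, and the HET column is
#    dropped so only the REF (col 0) and ALT (col 2) probabilities remain.
y_true = np.array([0, 2, 0])
y_true[y_true == 2] = 1
proba_2 = np.stack([proba[:, 0], proba[:, 2]], axis=1)

# 3) One-hot ground truth, as passed to the scorers alongside the probabilities.
y_true_ohe = np.eye(2)[y_true]

# 4) IUPAC bases mapped to the 10-state integer code used in the second report;
#    "N" maps to -1 and is removed by the `>= 0` mask before scoring.
enc = {"A": 0, "C": 1, "G": 2, "T": 3, "W": 4, "R": 5,
       "M": 6, "K": 7, "Y": 8, "S": 9, "N": -1}
iupac = np.array([["A", "R", "N"], ["T", "G", "C"]])
iupac_int = np.vectorize(enc.get)(iupac)  # [[0, 5, -1], [3, 2, 1]]
```

The real method writes the collapsed probabilities column-by-column into a preallocated `proba_2` array; `np.stack` is just an equivalent shorthand for this sketch.
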
    def _objective(self, trial: optuna.Trial) -> float:
        """Optuna objective for AE; mirrors NLPCA study driver without latents.

        This method defines the objective function for hyperparameter tuning with Optuna. It samples hyperparameters, prepares the training and validation data, builds and trains the autoencoder model, and evaluates it on the validation set. It returns the value of the tuning metric to be maximized.

        Args:
            trial (optuna.Trial): Optuna trial.

        Returns:
            float: Value of the tuning metric (maximized).
        """
        try:
            # Sample hyperparameters (existing helper; unchanged signature)
            params = self._sample_hyperparameters(trial)

            # Optionally sub-sample for fast tuning (same keys used by NLPCA if you adopt them)
            X_train = getattr(self, "X_train_", self.ground_truth_[self.train_idx_])
            X_val = getattr(self, "X_val_", self.ground_truth_[self.test_idx_])

            class_weights = self._normalize_class_weights(
                self._class_weights_from_zygosity(X_train)
            )
            train_loader = self._get_data_loaders(X_train)

            model = self.build_model(self.Model, params["model_params"])
            model.apply(self.initialize_weights)

            lr: float = float(params["lr"])
            l1_penalty: float = float(params["l1_penalty"])

            # Train + prune on metric
            _, model, __ = self._train_and_validate_model(
                model=model,
                loader=train_loader,
                lr=lr,
                l1_penalty=l1_penalty,
                trial=trial,
                return_history=False,
                class_weights=class_weights,
                X_val=X_val,
                params=params,
                prune_metric=self.tune_metric,
                prune_warmup_epochs=10,
                eval_interval=self.tune_eval_interval,
                eval_requires_latents=False,
                eval_latent_steps=0,
                eval_latent_lr=0.0,
                eval_latent_weight_decay=0.0,
            )

            eval_mask = (
                self.sim_mask_test_
                if (
                    self.simulate_missing
                    and getattr(self, "sim_mask_test_", None) is not None
                )
                else None
            )

            if model is not None:
                metrics = self._evaluate_model(
                    X_val,
                    model,
                    params,
                    objective_mode=True,
                    eval_mask_override=eval_mask,
                )
                self._clear_resources(model, train_loader)
            else:
                raise TypeError("Model training failed; no model was returned.")

            return metrics[self.tune_metric]

        except Exception as e:
            # Keep sweeps moving if a trial fails
            raise optuna.exceptions.TrialPruned(f"Trial failed with error: {e}")

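The study driver that consumes `_objective` is not part of this hunk. A minimal sketch of how such an objective is typically driven with Optuna follows; the `imputer` instance, trial count, and pruner settings are assumptions, not the package's actual tuning entry point.

```python
import optuna

# Hypothetical driver; the class's own tuning entry point may configure the
# sampler and pruner differently. The objective returns a score to maximize.
study = optuna.create_study(
    direction="maximize",
    pruner=optuna.pruners.MedianPruner(n_warmup_steps=10),
)
study.optimize(imputer._objective, n_trials=100)  # `imputer` is an assumed instance
print(study.best_params, study.best_value)
```

Because the method converts any failure into `optuna.exceptions.TrialPruned`, a crashing configuration is recorded as pruned rather than aborting the whole sweep.
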
    def _sample_hyperparameters(self, trial: optuna.Trial) -> Dict[str, Any]:
        """Sample AE hyperparameters and compute hidden sizes for model params.

        This method samples hyperparameters for the autoencoder model using Optuna's trial object, computes the hidden layer sizes from the sampled values, and assembles the model parameters dictionary.

        Args:
            trial (optuna.Trial): Optuna trial object.

        Returns:
            Dict[str, Any]: Sampled hyperparameters, including the nested model_params dictionary.
        """
        params = {
            "latent_dim": trial.suggest_int("latent_dim", 4, 16, step=2),
            "lr": trial.suggest_float("learning_rate", 3e-4, 1e-3, log=True),
            "dropout_rate": trial.suggest_float("dropout_rate", 0.0, 0.30, step=0.05),
            "num_hidden_layers": trial.suggest_int("num_hidden_layers", 1, 6),
            "activation": trial.suggest_categorical(
                "activation", ["relu", "elu", "selu", "leaky_relu"]
            ),
            "l1_penalty": trial.suggest_float("l1_penalty", 1e-6, 1e-3, log=True),
            "layer_scaling_factor": trial.suggest_float(
                "layer_scaling_factor", 2.0, 4.0, step=0.5
            ),
            "layer_schedule": trial.suggest_categorical(
                "layer_schedule", ["pyramid", "linear"]
            ),
        }

        nF: int = self.num_features_
        nC: int = int(getattr(self, "output_classes_", self.num_classes_ or 3))
        input_dim = nF * nC
        hidden_layer_sizes = self._compute_hidden_layer_sizes(
            n_inputs=input_dim,
            n_outputs=input_dim,
            n_samples=len(self.train_idx_),
            n_hidden=params["num_hidden_layers"],
            alpha=params["layer_scaling_factor"],
            schedule=params["layer_schedule"],
        )

        # Keep the latent_dim as the first element, then the interior hidden
        # widths. If there are no interior widths (very small nets), this
        # still leaves [latent_dim].
        hidden_only: list[int] = [hidden_layer_sizes[0]] + hidden_layer_sizes[1:-1]

        params["model_params"] = {
            "n_features": int(self.num_features_),
            "num_classes": int(
                getattr(self, "output_classes_", self.num_classes_ or 3)
            ),
            "latent_dim": int(params["latent_dim"]),
            "dropout_rate": float(params["dropout_rate"]),
            "hidden_layer_sizes": hidden_only,
            "activation": str(params["activation"]),
        }
        return params

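The `hidden_only` slicing above is easy to misread, so here is the arithmetic on a hypothetical size list (the numbers are made up; `_compute_hidden_layer_sizes` is not reproduced):

```python
# Hypothetical output of _compute_hidden_layer_sizes (values invented).
hidden_layer_sizes = [16, 1024, 256, 64, 7500]

hidden_only = [hidden_layer_sizes[0]] + hidden_layer_sizes[1:-1]
# -> [16, 1024, 256, 64]: the first entry and the interior widths are kept,
#    and the trailing output-sized width is dropped.

tiny = [16]
assert [tiny[0]] + tiny[1:-1] == [16]  # no interior widths -> just [latent_dim]
```
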
    def _set_best_params(
        self, best_params: Dict[str, int | float | str | List[int]]
    ) -> Dict[str, int | float | str | List[int]]:
        """Adopt best params (ImputeNLPCA parity) and return model_params.

        This method sets the best hyperparameters found during tuning and computes the hidden layer sizes for the autoencoder model. It prepares the final model parameters dictionary to be used for building the model.

        Args:
            best_params (Dict[str, int | float | str | List[int]]): Best hyperparameters from tuning.

        Returns:
            Dict[str, int | float | str | List[int]]: Model parameters for building the model.
        """
        bp = {}
        for k, v in best_params.items():
            if not isinstance(v, list):
                if k in {"latent_dim", "num_hidden_layers"}:
                    bp[k] = int(v)
                elif k in {
                    "dropout_rate",
                    "learning_rate",
                    "l1_penalty",
                    "layer_scaling_factor",
                }:
                    bp[k] = float(v)
                elif k in {"activation", "layer_schedule"}:
                    if k == "layer_schedule":
                        if v not in {"pyramid", "constant", "linear"}:
                            raise ValueError(f"Invalid layer_schedule: {v}")
                        bp[k] = v
                    else:
                        bp[k] = str(v)
            else:
                bp[k] = v  # keep lists as-is

        self.latent_dim: int = bp["latent_dim"]
        self.dropout_rate: float = bp["dropout_rate"]
        self.learning_rate: float = bp["learning_rate"]
        self.l1_penalty: float = bp["l1_penalty"]
        self.activation: str = bp["activation"]
        self.layer_scaling_factor: float = bp["layer_scaling_factor"]
        self.layer_schedule: str = bp["layer_schedule"]

        nF: int = self.num_features_
        nC: int = int(getattr(self, "output_classes_", self.num_classes_ or 3))
        hidden_layer_sizes = self._compute_hidden_layer_sizes(
            n_inputs=nF * nC,
            n_outputs=nF * nC,
            n_samples=len(self.train_idx_),
            n_hidden=bp["num_hidden_layers"],
            alpha=bp["layer_scaling_factor"],
            schedule=bp["layer_schedule"],
        )

        # Keep the latent_dim as the first element, then the interior hidden
        # widths. If there are no interior widths (very small nets), this
        # still leaves [latent_dim].
        hidden_only = [hidden_layer_sizes[0]] + hidden_layer_sizes[1:-1]

        return {
            "n_features": self.num_features_,
            "latent_dim": self.latent_dim,
            "hidden_layer_sizes": hidden_only,
            "dropout_rate": self.dropout_rate,
            "activation": self.activation,
            "num_classes": nC,
        }

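Downstream, the dictionary returned here is what `build_model` consumes. A hedged sketch of adopting a finished study's parameters, reusing the calls that `_objective` makes (the `imputer` and `study` names are assumptions):

```python
# Assumed post-tuning flow; mirrors the build/initialize calls in _objective().
model_params = imputer._set_best_params(study.best_params)
model = imputer.build_model(imputer.Model, model_params)
model.apply(imputer.initialize_weights)
```
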
    def _default_best_params(self) -> Dict[str, int | float | str | list]:
        """Default model params when tuning is disabled.

        This method computes the default model parameters for the autoencoder when hyperparameter tuning is not performed. It calculates the hidden layer sizes based on the initial configuration.

        Returns:
            Dict[str, int | float | str | list]: Default model parameters.
        """
        nF: int = self.num_features_
        # Use the number of output channels passed to the model (2 for diploid multilabel)
        # instead of the scoring classes (3) to keep layer shapes aligned.
        nC: int = int(getattr(self, "output_classes_", self.num_classes_ or 3))
        ls = self.layer_schedule

        if ls not in {"pyramid", "constant", "linear"}:
            raise ValueError(f"Invalid layer_schedule: {ls}")

        hidden_layer_sizes = self._compute_hidden_layer_sizes(
            n_inputs=nF * nC,
            n_outputs=nF * nC,
            n_samples=len(self.ground_truth_),
            n_hidden=self.num_hidden_layers,
            alpha=self.layer_scaling_factor,
            schedule=ls,
        )
        return {
            "n_features": self.num_features_,
            "latent_dim": self.latent_dim,
            "hidden_layer_sizes": hidden_layer_sizes,
            "dropout_rate": self.dropout_rate,
            "activation": self.activation,
            "num_classes": nC,
        }

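When tuning is disabled, this default dictionary plays the same role as the tuned one; a minimal sketch under the same assumed object names as above:

```python
# Assumed no-tuning path: defaults go straight to build_model().
model = imputer.build_model(imputer.Model, imputer._default_best_params())
model.apply(imputer.initialize_weights)
```
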