pg-sui 1.0.2.1__py3-none-any.whl → 1.6.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- {pg_sui-1.0.2.1.dist-info → pg_sui-1.6.8.dist-info}/METADATA +51 -70
- pg_sui-1.6.8.dist-info/RECORD +78 -0
- {pg_sui-1.0.2.1.dist-info → pg_sui-1.6.8.dist-info}/WHEEL +1 -1
- pg_sui-1.6.8.dist-info/entry_points.txt +4 -0
- pg_sui-1.6.8.dist-info/top_level.txt +1 -0
- pgsui/__init__.py +35 -54
- pgsui/_version.py +34 -0
- pgsui/cli.py +635 -0
- pgsui/data_processing/config.py +576 -0
- pgsui/data_processing/containers.py +1782 -0
- pgsui/data_processing/transformers.py +121 -1103
- pgsui/electron/app/__main__.py +5 -0
- pgsui/electron/app/icons/icons/1024x1024.png +0 -0
- pgsui/electron/app/icons/icons/128x128.png +0 -0
- pgsui/electron/app/icons/icons/16x16.png +0 -0
- pgsui/electron/app/icons/icons/24x24.png +0 -0
- pgsui/electron/app/icons/icons/256x256.png +0 -0
- pgsui/electron/app/icons/icons/32x32.png +0 -0
- pgsui/electron/app/icons/icons/48x48.png +0 -0
- pgsui/electron/app/icons/icons/512x512.png +0 -0
- pgsui/electron/app/icons/icons/64x64.png +0 -0
- pgsui/electron/app/icons/icons/icon.icns +0 -0
- pgsui/electron/app/icons/icons/icon.ico +0 -0
- pgsui/electron/app/main.js +189 -0
- pgsui/electron/app/package-lock.json +6893 -0
- pgsui/electron/app/package.json +50 -0
- pgsui/electron/app/preload.js +15 -0
- pgsui/electron/app/server.py +146 -0
- pgsui/electron/app/ui/logo.png +0 -0
- pgsui/electron/app/ui/renderer.js +130 -0
- pgsui/electron/app/ui/styles.css +59 -0
- pgsui/electron/app/ui/ui_shim.js +72 -0
- pgsui/electron/bootstrap.py +43 -0
- pgsui/electron/launch.py +59 -0
- pgsui/electron/package.json +14 -0
- pgsui/example_data/popmaps/{test.popmap → phylogen_nomx.popmap} +185 -99
- pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz +0 -0
- pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz.tbi +0 -0
- pgsui/impute/deterministic/imputers/allele_freq.py +691 -0
- pgsui/impute/deterministic/imputers/mode.py +679 -0
- pgsui/impute/deterministic/imputers/nmf.py +221 -0
- pgsui/impute/deterministic/imputers/phylo.py +971 -0
- pgsui/impute/deterministic/imputers/ref_allele.py +530 -0
- pgsui/impute/supervised/base.py +339 -0
- pgsui/impute/supervised/imputers/hist_gradient_boosting.py +293 -0
- pgsui/impute/supervised/imputers/random_forest.py +287 -0
- pgsui/impute/unsupervised/base.py +924 -0
- pgsui/impute/unsupervised/callbacks.py +89 -263
- pgsui/impute/unsupervised/imputers/autoencoder.py +972 -0
- pgsui/impute/unsupervised/imputers/nlpca.py +1264 -0
- pgsui/impute/unsupervised/imputers/ubp.py +1288 -0
- pgsui/impute/unsupervised/imputers/vae.py +957 -0
- pgsui/impute/unsupervised/loss_functions.py +158 -0
- pgsui/impute/unsupervised/models/autoencoder_model.py +208 -558
- pgsui/impute/unsupervised/models/nlpca_model.py +149 -468
- pgsui/impute/unsupervised/models/ubp_model.py +198 -1317
- pgsui/impute/unsupervised/models/vae_model.py +259 -618
- pgsui/impute/unsupervised/nn_scorers.py +215 -0
- pgsui/utils/classification_viz.py +591 -0
- pgsui/utils/misc.py +35 -480
- pgsui/utils/plotting.py +514 -824
- pgsui/utils/scorers.py +212 -438
- pg_sui-1.0.2.1.dist-info/RECORD +0 -75
- pg_sui-1.0.2.1.dist-info/top_level.txt +0 -3
- pgsui/example_data/phylip_files/test_n10.phy +0 -118
- pgsui/example_data/phylip_files/test_n100.phy +0 -118
- pgsui/example_data/phylip_files/test_n2.phy +0 -118
- pgsui/example_data/phylip_files/test_n500.phy +0 -118
- pgsui/example_data/structure_files/test.nopops.1row.10sites.str +0 -117
- pgsui/example_data/structure_files/test.nopops.2row.100sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.10sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.30sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.allsites.str +0 -234
- pgsui/example_data/structure_files/test.pops.1row.10sites.str +0 -117
- pgsui/example_data/structure_files/test.pops.2row.10sites.str +0 -234
- pgsui/example_data/trees/test.iqtree +0 -376
- pgsui/example_data/trees/test.qmat +0 -5
- pgsui/example_data/trees/test.rate +0 -2033
- pgsui/example_data/trees/test.tre +0 -1
- pgsui/example_data/trees/test_n10.rate +0 -19
- pgsui/example_data/trees/test_n100.rate +0 -109
- pgsui/example_data/trees/test_n500.rate +0 -509
- pgsui/example_data/trees/test_siterates.txt +0 -2024
- pgsui/example_data/trees/test_siterates_n10.txt +0 -10
- pgsui/example_data/trees/test_siterates_n100.txt +0 -100
- pgsui/example_data/trees/test_siterates_n500.txt +0 -500
- pgsui/example_data/vcf_files/test.vcf +0 -244
- pgsui/example_data/vcf_files/test.vcf.gz +0 -0
- pgsui/example_data/vcf_files/test.vcf.gz.tbi +0 -0
- pgsui/impute/estimators.py +0 -735
- pgsui/impute/impute.py +0 -1486
- pgsui/impute/simple_imputers.py +0 -1439
- pgsui/impute/supervised/iterative_imputer_fixedparams.py +0 -785
- pgsui/impute/supervised/iterative_imputer_gridsearch.py +0 -1027
- pgsui/impute/unsupervised/keras_classifiers.py +0 -702
- pgsui/impute/unsupervised/models/in_development/cnn_model.py +0 -486
- pgsui/impute/unsupervised/neural_network_imputers.py +0 -1424
- pgsui/impute/unsupervised/neural_network_methods.py +0 -1549
- pgsui/pg_sui.py +0 -261
- pgsui/utils/sequence_tools.py +0 -407
- simulation/sim_benchmarks.py +0 -333
- simulation/sim_treeparams.py +0 -475
- test/__init__.py +0 -0
- test/pg_sui_simtest.py +0 -215
- test/pg_sui_testing.py +0 -523
- test/test.py +0 -297
- test/test_pgsui.py +0 -374
- test/test_tkc.py +0 -214
- {pg_sui-1.0.2.1.dist-info → pg_sui-1.6.8.dist-info/licenses}/LICENSE +0 -0
- /pgsui/{example_data/trees → electron/app}/__init__.py +0 -0
- /pgsui/impute/{unsupervised/models/in_development → supervised/imputers}/__init__.py +0 -0
- {simulation → pgsui/impute/unsupervised/imputers}/__init__.py +0 -0
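The bulk of this release is the rewritten, config-driven imputer API. As a quick orientation before the diff of the new autoencoder module below, here is a minimal usage sketch assembled from the added code; the snpio `VCFReader` call, its argument names, and the example config values are assumptions for illustration, not taken from the package's documentation.

```python
# Hypothetical usage sketch of the new ImputeAutoencoder API shown in the diff below.
# The snpio reader class and its argument names are assumptions; any snpio
# GenotypeData object should work according to the constructor signature.
from snpio import VCFReader  # assumed snpio reader

from pgsui.impute.unsupervised.imputers.autoencoder import ImputeAutoencoder

gd = VCFReader(
    filename="pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz",
    popmapfile="pgsui/example_data/popmaps/phylogen_nomx.popmap",
    prefix="ae_example",
)

# config may be an AutoencoderConfig dataclass, a nested dict, a YAML path, or None;
# dot-key overrides take highest precedence (see ensure_autoencoder_config below).
imputer = ImputeAutoencoder(
    gd,
    config={"model": {"latent_dim": 16}, "train": {"max_epochs": 200}},
    overrides={"model.dropout_rate": 0.25},
)
imputer.fit()
imputed = imputer.transform()  # IUPAC strings, shape (n_samples, n_loci)
```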
pgsui/impute/unsupervised/imputers/autoencoder.py (new file)

@@ -0,0 +1,972 @@
import copy
import json
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import optuna
import torch
import torch.nn.functional as F
from sklearn.exceptions import NotFittedError
from sklearn.model_selection import train_test_split
from snpio.analysis.genotype_encoder import GenotypeEncoder
from snpio.utils.logging import LoggerManager
from torch.optim.lr_scheduler import CosineAnnealingLR

from pgsui.data_processing.config import apply_dot_overrides, load_yaml_to_dataclass
from pgsui.data_processing.containers import AutoencoderConfig
from pgsui.impute.unsupervised.base import BaseNNImputer
from pgsui.impute.unsupervised.callbacks import EarlyStopping
from pgsui.impute.unsupervised.models.autoencoder_model import AutoencoderModel

if TYPE_CHECKING:
    from snpio.read_input.genotype_data import GenotypeData


def ensure_autoencoder_config(
    config: AutoencoderConfig | dict | str | None,
) -> AutoencoderConfig:
    """Return a concrete AutoencoderConfig from dataclass, dict, YAML path, or None.

    This function normalizes the configuration input for the Autoencoder imputer. It accepts a structured configuration in various formats, including a dataclass instance, a nested dictionary, a YAML file path, or None, and returns a concrete instance of AutoencoderConfig with all necessary fields populated.

    Args:
        config (AutoencoderConfig | dict | str | None): Structured configuration as dataclass, nested dict, YAML path, or None.

    Returns:
        AutoencoderConfig: Concrete configuration instance.
    """
    if config is None:
        return AutoencoderConfig()
    if isinstance(config, AutoencoderConfig):
        return config
    if isinstance(config, str):
        # YAML path — top-level `preset` key is supported
        return load_yaml_to_dataclass(
            config, AutoencoderConfig, preset_builder=AutoencoderConfig.from_preset
        )
    if isinstance(config, dict):
        # Flatten dict into dot-keys then overlay onto a fresh instance
        base = AutoencoderConfig()

        def _flatten(prefix: str, d: dict, out: dict) -> dict:
            for k, v in d.items():
                kk = f"{prefix}.{k}" if prefix else k
                if isinstance(v, dict):
                    _flatten(kk, v, out)
                else:
                    out[kk] = v
            return out

        # Lift any present preset first
        preset_name = config.pop("preset", None)
        if "io" in config and isinstance(config["io"], dict):
            preset_name = preset_name or config["io"].pop("preset", None)

        if preset_name:
            base = AutoencoderConfig.from_preset(preset_name)

        flat = _flatten("", config, {})
        return apply_dot_overrides(base, flat)

    raise TypeError("config must be an AutoencoderConfig, dict, YAML path, or None.")

class ImputeAutoencoder(BaseNNImputer):
    """Impute missing genotypes with a standard Autoencoder on 0/1/2 encodings.

    This imputer uses a feedforward autoencoder architecture to learn compressed and reconstructive representations of genotype data encoded as 0 (homozygous reference), 1 (heterozygous), and 2 (homozygous alternate). Missing genotypes are represented as -1 during training and imputation.

    The model is trained to minimize a focal cross-entropy loss, which helps to address class imbalance by focusing more on hard-to-classify examples. The architecture includes configurable parameters such as the number of hidden layers, latent dimension size, dropout rate, and activation functions.
    """

    def __init__(
        self,
        genotype_data: "GenotypeData",
        *,
        config: Optional[Union["AutoencoderConfig", dict, str]] = None,
        overrides: dict | None = None,
    ) -> None:
        """Initialize the Autoencoder imputer with a unified config interface.

        This initializer sets up the Autoencoder imputer by processing the provided configuration, initializing logging, and preparing the model and data encoder. It supports configuration input as a dataclass, nested dictionary, YAML file path, or None, with optional dot-key overrides for fine-tuning specific parameters.

        Args:
            genotype_data: Backing genotype data object.
            config: Structured configuration as dataclass, nested dict, YAML path, or None.
            overrides: Optional dot-key overrides with highest precedence (e.g., {'model.latent_dim': 32}).
        """
        self.model_name = "ImputeAutoencoder"
        self.genotype_data = genotype_data

        # Normalize config then apply highest-precedence overrides
        cfg = ensure_autoencoder_config(config)
        if overrides:
            cfg = apply_dot_overrides(cfg, overrides)
        self.cfg = cfg

        # Logger consistent with NLPCA
        logman = LoggerManager(
            __name__,
            prefix=self.cfg.io.prefix,
            debug=self.cfg.io.debug,
            verbose=self.cfg.io.verbose,
        )
        self.logger = logman.get_logger()

        # BaseNNImputer bootstrapping (device/dirs/logging handled here)
        super().__init__(
            prefix=self.cfg.io.prefix,
            device=self.cfg.train.device,
            verbose=self.cfg.io.verbose,
            debug=self.cfg.io.debug,
        )

        # Model hook & encoder
        self.Model = AutoencoderModel
        self.pgenc = GenotypeEncoder(genotype_data)

        # IO / global
        self.seed = self.cfg.io.seed
        self.n_jobs = self.cfg.io.n_jobs
        self.prefix = self.cfg.io.prefix
        self.scoring_averaging = self.cfg.io.scoring_averaging
        self.verbose = self.cfg.io.verbose
        self.debug = self.cfg.io.debug
        self.rng = np.random.default_rng(self.seed)

        # Model hyperparams
        self.latent_dim = self.cfg.model.latent_dim
        self.dropout_rate = self.cfg.model.dropout_rate
        self.num_hidden_layers = self.cfg.model.num_hidden_layers
        self.layer_scaling_factor = self.cfg.model.layer_scaling_factor
        self.layer_schedule = self.cfg.model.layer_schedule
        self.activation = self.cfg.model.hidden_activation
        self.gamma = self.cfg.model.gamma

        # Train hyperparams
        self.batch_size = self.cfg.train.batch_size
        self.learning_rate = self.cfg.train.learning_rate
        self.l1_penalty = self.cfg.train.l1_penalty
        self.early_stop_gen = self.cfg.train.early_stop_gen
        self.min_epochs = self.cfg.train.min_epochs
        self.epochs = self.cfg.train.max_epochs
        self.validation_split = self.cfg.train.validation_split
        self.beta = self.cfg.train.weights_beta
        self.max_ratio = self.cfg.train.weights_max_ratio

        # Tuning
        self.tune = self.cfg.tune.enabled
        self.tune_fast = self.cfg.tune.fast
        self.tune_batch_size = self.cfg.tune.batch_size
        self.tune_epochs = self.cfg.tune.epochs
        self.tune_eval_interval = self.cfg.tune.eval_interval
        self.tune_metric = self.cfg.tune.metric
        self.n_trials = self.cfg.tune.n_trials
        self.tune_save_db = self.cfg.tune.save_db
        self.tune_resume = self.cfg.tune.resume
        self.tune_max_samples = self.cfg.tune.max_samples
        self.tune_max_loci = self.cfg.tune.max_loci
        self.tune_infer_epochs = getattr(self.cfg.tune, "infer_epochs", 0)  # AE unused
        self.tune_patience = self.cfg.tune.patience

        # Evaluate
        # AE does not optimize latents, so these are unused / fixed
        self.eval_latent_steps = 0
        self.eval_latent_lr = 0.0
        self.eval_latent_weight_decay = 0.0

        # Plotting (parity with NLPCA PlotConfig)
        self.plot_format = self.cfg.plot.fmt
        self.plot_dpi = self.cfg.plot.dpi
        self.plot_fontsize = self.cfg.plot.fontsize
        self.title_fontsize = self.cfg.plot.fontsize
        self.despine = self.cfg.plot.despine
        self.show_plots = self.cfg.plot.show

        # Core derived at fit-time
        self.is_haploid: bool | None = None
        self.num_classes_: int | None = None
        self.model_params: Dict[str, Any] = {}

    def fit(self) -> "ImputeAutoencoder":
        """Fit the autoencoder on 0/1/2 encoded genotypes (missing → -1).

        This method trains the autoencoder model using the provided genotype data. It prepares the data by encoding genotypes as 0, 1, and 2, with missing values represented as -1. The method splits the data into training and validation sets, initializes the model and training parameters, and performs training with optional hyperparameter tuning. After training, it evaluates the model on the validation set and stores the fitted model and training history.

        Returns:
            ImputeAutoencoder: Fitted instance.

        Raises:
            RuntimeError: If training fails.
        """
        self.logger.info(f"Fitting {self.model_name} (0/1/2 AE) ...")

        # --- Data prep (mirror NLPCA) ---
        X = self.pgenc.genotypes_012.astype(np.float32)
        X[X < 0] = np.nan
        X[np.isnan(X)] = -1
        self.ground_truth_ = X.astype(np.int64)

        # Ploidy & classes
        self.is_haploid = np.all(
            np.isin(
                self.genotype_data.snp_data,
                ["A", "C", "G", "T", "N", "-", ".", "?"],
            )
        )
        self.ploidy = 1 if self.is_haploid else 2
        self.num_classes_ = 2 if self.is_haploid else 3
        self.logger.info(
            f"Data is {'haploid' if self.is_haploid else 'diploid'}; "
            f"using {self.num_classes_} classes."
        )

        n_samples, self.num_features_ = X.shape

        # Model params (decoder outputs L * K logits)
        self.model_params = {
            "n_features": self.num_features_,
            "num_classes": self.num_classes_,
            "latent_dim": self.latent_dim,
            "dropout_rate": self.dropout_rate,
            "activation": self.activation,
        }

        # Train/Val split
        indices = np.arange(n_samples)
        train_idx, val_idx = train_test_split(
            indices, test_size=self.validation_split, random_state=self.seed
        )
        self.train_idx_, self.test_idx_ = train_idx, val_idx
        self.X_train_ = self.ground_truth_[train_idx]
        self.X_val_ = self.ground_truth_[val_idx]

        # Plotters/scorers (shared utilities)
        self.plotter_, self.scorers_ = self.initialize_plotting_and_scorers()

        # Tuning (optional; AE never needs latent refinement)
        if self.tune:
            self.tune_hyperparameters()

        # Best params (tuned or default)
        self.best_params_ = getattr(self, "best_params_", self._default_best_params())

        # Class weights (device-aware)
        self.class_weights_ = self._class_weights_from_zygosity(self.X_train_).to(
            self.device
        )

        # DataLoader
        train_loader = self._get_data_loaders(self.X_train_)

        # Build & train
        model = self.build_model(self.Model, self.best_params_)
        model.apply(self.initialize_weights)

        loss, trained_model, history = self._train_and_validate_model(
            model=model,
            loader=train_loader,
            lr=self.learning_rate,
            l1_penalty=self.l1_penalty,
            return_history=True,
            class_weights=self.class_weights_,
            X_val=self.X_val_,
            params=self.best_params_,
            prune_metric=self.tune_metric,
            prune_warmup_epochs=5,
            eval_interval=1,
            eval_requires_latents=False,
            eval_latent_steps=0,
            eval_latent_lr=0.0,
            eval_latent_weight_decay=0.0,
        )

        if trained_model is None:
            msg = "Autoencoder training failed; no model was returned."
            self.logger.error(msg)
            raise RuntimeError(msg)

        torch.save(
            trained_model.state_dict(),
            self.models_dir / f"final_model_{self.model_name}.pt",
        )

        self.best_loss_, self.model_, self.history_ = (
            loss,
            trained_model,
            {"Train": history},
        )
        self.is_fit_ = True

        # Evaluate on validation set (parity with NLPCA reporting)
        self._evaluate_model(self.X_val_, self.model_, self.best_params_)
        self.plotter_.plot_history(self.history_)
        self._save_best_params(self.best_params_)

        return self

    def transform(self) -> np.ndarray:
        """Impute missing genotypes (0/1/2) and return IUPAC strings.

        This method imputes missing genotypes in the dataset using the trained autoencoder model. It predicts the most likely genotype (0, 1, or 2) for each missing entry and fills in these values. The imputed genotypes are then decoded back to IUPAC string format for easier interpretation.

        Returns:
            np.ndarray: IUPAC strings of shape (n_samples, n_loci).

        Raises:
            NotFittedError: If called before fit().
        """
        if not getattr(self, "is_fit_", False):
            raise NotFittedError("Model is not fitted. Call fit() before transform().")

        self.logger.info("Imputing entire dataset with AE (0/1/2)...")
        X_to_impute = self.ground_truth_.copy()

        # Predict with masked inputs (no latent optimization)
        pred_labels, _ = self._predict(self.model_, X=X_to_impute, return_proba=True)

        # Fill only missing
        missing_mask = X_to_impute == -1
        imputed_array = X_to_impute.copy()
        imputed_array[missing_mask] = pred_labels[missing_mask]

        # Decode to IUPAC & plot
        imputed_genotypes = self.pgenc.decode_012(imputed_array)
        original_genotypes = self.pgenc.decode_012(X_to_impute)

        plt.rcParams.update(self.plotter_.param_dict)
        self.plotter_.plot_gt_distribution(original_genotypes, is_imputed=False)
        self.plotter_.plot_gt_distribution(imputed_genotypes, is_imputed=True)

        return imputed_genotypes

    def _get_data_loaders(self, y: np.ndarray) -> torch.utils.data.DataLoader:
        """Create DataLoader over indices + integer targets (-1 for missing).

        This method creates a PyTorch DataLoader that yields batches of indices and their corresponding genotype targets encoded as integers (0, 1, 2) with -1 indicating missing values. The DataLoader is shuffled to ensure random sampling during training.

        Args:
            y (np.ndarray): 0/1/2 matrix with -1 for missing.

        Returns:
            torch.utils.data.DataLoader: Shuffled DataLoader.
        """
        y_tensor = torch.from_numpy(y).long().to(self.device)
        dataset = torch.utils.data.TensorDataset(
            torch.arange(len(y), device=self.device), y_tensor
        )
        return torch.utils.data.DataLoader(
            dataset, batch_size=self.batch_size, shuffle=True
        )

    def _train_and_validate_model(
        self,
        model: torch.nn.Module,
        loader: torch.utils.data.DataLoader,
        lr: float,
        l1_penalty: float,
        trial: optuna.Trial | None = None,
        return_history: bool = False,
        class_weights: torch.Tensor | None = None,
        *,
        X_val: np.ndarray | None = None,
        params: dict | None = None,
        prune_metric: str | None = None,  # "f1" | "accuracy" | "pr_macro"
        prune_warmup_epochs: int = 3,
        eval_interval: int = 1,
        # Evaluation parameters (AE ignores latent refinement knobs)
        eval_requires_latents: bool = False,  # AE: always False
        eval_latent_steps: int = 0,
        eval_latent_lr: float = 0.0,
        eval_latent_weight_decay: float = 0.0,
    ) -> Tuple[float, torch.nn.Module | None, list | None]:
        """Wrap the AE training loop (no latent optimizer), with Optuna pruning.

        This method orchestrates the training of the autoencoder model using the provided DataLoader. It sets up the optimizer and learning rate scheduler, and executes the training loop with support for early stopping and Optuna pruning based on validation performance. The method returns the best validation loss, the best model state, and optionally the training history.

        Args:
            model (torch.nn.Module): Autoencoder model.
            loader (torch.utils.data.DataLoader): Batches (indices, y_int) where y_int is 0/1/2; -1 for missing.
            lr (float): Learning rate.
            l1_penalty (float): L1 regularization coeff.
            trial (optuna.Trial | None): Optuna trial for pruning (optional).
            return_history (bool): If True, return train loss history.
            class_weights (torch.Tensor | None): Class weights tensor (on device).
            X_val (np.ndarray | None): Validation matrix (0/1/2 with -1 for missing).
            params (dict | None): Model params for evaluation.
            prune_metric (str | None): Metric for pruning reports.
            prune_warmup_epochs (int): Pruning warmup epochs.
            eval_interval (int): Eval frequency (epochs).
            eval_requires_latents (bool): Ignored for AE (no latent inference).
            eval_latent_steps (int): Unused for AE.
            eval_latent_lr (float): Unused for AE.
            eval_latent_weight_decay (float): Unused for AE.

        Returns:
            Tuple[float, torch.nn.Module | None, list | None]: (best_loss, best_model, history or None).
        """
        if class_weights is None:
            msg = "Must provide class_weights."
            self.logger.error(msg)
            raise TypeError(msg)

        # Epoch budget mirrors NLPCA config (tuning vs final)
        max_epochs = (
            self.tune_epochs if (trial is not None and self.tune_fast) else self.epochs
        )

        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        scheduler = CosineAnnealingLR(optimizer, T_max=max_epochs)

        best_loss, best_model, hist = self._execute_training_loop(
            loader=loader,
            optimizer=optimizer,
            scheduler=scheduler,
            model=model,
            l1_penalty=l1_penalty,
            trial=trial,
            return_history=return_history,
            class_weights=class_weights,
            X_val=X_val,
            params=params,
            prune_metric=prune_metric,
            prune_warmup_epochs=prune_warmup_epochs,
            eval_interval=eval_interval,
            eval_requires_latents=False,  # AE: no latent inference
            eval_latent_steps=0,
            eval_latent_lr=0.0,
            eval_latent_weight_decay=0.0,
        )
        if return_history:
            return best_loss, best_model, hist

        return best_loss, best_model, None

    def _execute_training_loop(
        self,
        loader: torch.utils.data.DataLoader,
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler._LRScheduler,
        model: torch.nn.Module,
        l1_penalty: float,
        trial: optuna.Trial | None,
        return_history: bool,
        class_weights: torch.Tensor,
        *,
        X_val: np.ndarray | None = None,
        params: dict | None = None,
        prune_metric: str | None = None,
        prune_warmup_epochs: int = 3,
        eval_interval: int = 1,
        # Evaluation parameters (AE ignores latent refinement knobs)
        eval_requires_latents: bool = False,  # AE: False
        eval_latent_steps: int = 0,
        eval_latent_lr: float = 0.0,
        eval_latent_weight_decay: float = 0.0,
    ) -> Tuple[float, torch.nn.Module, list]:
        """Train AE with focal CE (gamma warm/ramp) + early stopping & pruning.

        This method executes the training loop for the autoencoder model, performing one epoch at a time. It computes the focal cross-entropy loss while ignoring masked (missing) values and applies L1 regularization if specified. The method incorporates early stopping based on validation performance and supports Optuna pruning to terminate unpromising trials early. It returns the best validation loss, the best model state, and optionally the training history.

        Args:
            loader (torch.utils.data.DataLoader): Batches (indices, y_int) where y_int is 0/1/2; -1 for missing.
            optimizer (torch.optim.Optimizer): Optimizer.
            scheduler (torch.optim.lr_scheduler._LRScheduler): LR scheduler.
            model (torch.nn.Module): Autoencoder model.
            l1_penalty (float): L1 regularization coeff.
            trial (optuna.Trial | None): Optuna trial for pruning (optional).
            return_history (bool): If True, return train loss history.
            class_weights (torch.Tensor): Class weights tensor (on device).
            X_val (np.ndarray | None): Validation matrix (0/1/2 with -1 for missing).
            params (dict | None): Model params for evaluation.
            prune_metric (str | None): Metric for pruning reports.
            prune_warmup_epochs (int): Pruning warmup epochs.
            eval_interval (int): Eval frequency (epochs).
            eval_requires_latents (bool): Ignored for AE (no latent inference).
            eval_latent_steps (int): Unused for AE.
            eval_latent_lr (float): Unused for AE.
            eval_latent_weight_decay (float): Unused for AE.

        Returns:
            Tuple[float, torch.nn.Module, list]: Best validation loss, best model, and training history.
        """
        best_loss = float("inf")
        best_model = None
        history: list[float] = []

        early_stopping = EarlyStopping(
            patience=self.early_stop_gen,
            min_epochs=self.min_epochs,
            verbose=self.verbose,
            prefix=self.prefix,
            debug=self.debug,
        )

        # Parity with NLPCA (warm/ramp gamma schedule)
        warm, ramp, gamma_final = 50, 100, self.gamma

        # Epoch budget mirrors the caller's scheduler T_max
        # (already set to tune_epochs or epochs).
        for epoch in range(scheduler.T_max):
            # Gamma schedule
            if epoch < warm:
                model.gamma = 0.0
            elif epoch < warm + ramp:
                model.gamma = gamma_final * ((epoch - warm) / ramp)
            else:
                model.gamma = gamma_final

            # ---- one epoch ----
            train_loss = self._train_step(
                loader=loader,
                optimizer=optimizer,
                model=model,
                l1_penalty=l1_penalty,
                class_weights=class_weights,
            )

            if trial and (np.isnan(train_loss) or np.isinf(train_loss)):
                raise optuna.exceptions.TrialPruned("Loss is NaN or Inf.")

            scheduler.step()
            if return_history:
                history.append(train_loss)

            early_stopping(train_loss, model)
            if early_stopping.early_stop:
                self.logger.info(f"Early stopping at epoch {epoch + 1}.")
                break

            # Optuna report/prune on validation metric
            if (
                trial is not None
                and X_val is not None
                and ((epoch + 1) % eval_interval == 0)
            ):
                metric_key = prune_metric or getattr(self, "tune_metric", "f1")
                metric_val = self._eval_for_pruning(
                    model=model,
                    X_val=X_val,
                    params=params or getattr(self, "best_params_", {}),
                    metric=metric_key,
                    objective_mode=True,
                    do_latent_infer=False,  # AE: False
                    latent_steps=0,
                    latent_lr=0.0,
                    latent_weight_decay=0.0,
                    latent_seed=(self.seed if self.seed is not None else 123),
                    _latent_cache=None,  # AE: not used
                    _latent_cache_key=None,
                )
                trial.report(metric_val, step=epoch + 1)
                if (epoch + 1) >= prune_warmup_epochs and trial.should_prune():
                    raise optuna.exceptions.TrialPruned(
                        f"Pruned at epoch {epoch + 1}: {metric_key}={metric_val:.5f}"
                    )

        best_loss = early_stopping.best_score
        best_model = copy.deepcopy(early_stopping.best_model)
        return best_loss, best_model, history

    def _train_step(
        self,
        loader: torch.utils.data.DataLoader,
        optimizer: torch.optim.Optimizer,
        model: torch.nn.Module,
        l1_penalty: float,
        class_weights: torch.Tensor,
    ) -> float:
        """One epoch (indices, y_int) → one-hot inputs → logits → masked focal CE.

        This method performs a single training epoch, processing batches of data from the DataLoader. It computes the focal cross-entropy loss while ignoring masked (missing) values and applies L1 regularization if specified.

        Args:
            loader (DataLoader): Yields (indices, y_int) where y_int is 0/1/2, -1 for missing.
            optimizer (torch.optim.Optimizer): Optimizer.
            model (torch.nn.Module): Autoencoder model.
            l1_penalty (float): L1 regularization.
            class_weights (torch.Tensor): Class weights for CE.

        Returns:
            float: Mean training loss for the epoch.
        """
        model.train()
        running = 0.0

        for _, y_batch in loader:
            optimizer.zero_grad(set_to_none=True)

            # Inputs: one-hot with zeros for missing; Targets: ints with -1
            x_ohe = self._one_hot_encode_012(y_batch)  # (B, L, K)
            logits = model(x_ohe).view(-1, self.num_features_, self.num_classes_)

            logits_flat = logits.view(-1, self.num_classes_)
            targets_flat = y_batch.view(-1)

            ce = F.cross_entropy(
                logits_flat,
                targets_flat,
                weight=class_weights,
                reduction="none",
                ignore_index=-1,
            )
            pt = torch.exp(-ce)
            gamma = getattr(model, "gamma", self.gamma)
            focal = ((1 - pt) ** gamma) * ce

            valid_mask = targets_flat != -1
            loss = (
                focal[valid_mask].mean()
                if valid_mask.any()
                else torch.tensor(0.0, device=logits.device)
            )

            if l1_penalty > 0:
                loss = loss + l1_penalty * sum(
                    p.abs().sum() for p in model.parameters()
                )

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            running += float(loss.item())

        return running / len(loader)

    def _predict(
        self,
        model: torch.nn.Module,
        X: np.ndarray | torch.Tensor,
        return_proba: bool = False,
    ) -> Tuple[np.ndarray, np.ndarray] | np.ndarray:
        """Predict 0/1/2 labels (and probabilities) from masked inputs.

        This method generates predictions from the trained autoencoder model for the provided input data. It processes the input data, performs a forward pass through the model, and computes the predicted genotype labels (0, 1, or 2) along with their associated probabilities if requested.

        Args:
            model (torch.nn.Module): Trained model.
            X (np.ndarray | torch.Tensor): 0/1/2 matrix with -1
                for missing.
            return_proba (bool): If True, return probabilities.

        Returns:
            Tuple[np.ndarray, np.ndarray] | np.ndarray: Predicted labels,
                and probabilities if requested.
        """
        if model is None:
            msg = "Model is not trained. Call fit() before predict()."
            self.logger.error(msg)
            raise NotFittedError(msg)

        model.eval()
        with torch.no_grad():
            X_tensor = torch.from_numpy(X) if isinstance(X, np.ndarray) else X
            X_tensor = X_tensor.to(self.device).long()
            x_ohe = self._one_hot_encode_012(X_tensor)
            logits = model(x_ohe).view(-1, self.num_features_, self.num_classes_)
            probas = torch.softmax(logits, dim=-1)
            labels = torch.argmax(probas, dim=-1)

        if return_proba:
            return labels.cpu().numpy(), probas.cpu().numpy()

        return labels.cpu().numpy()

    def _evaluate_model(
        self,
        X_val: np.ndarray,
        model: torch.nn.Module,
        params: dict,
        objective_mode: bool = False,
        latent_vectors_val: Optional[np.ndarray] = None,
    ) -> Dict[str, float]:
        """Evaluate on 0/1/2; then IUPAC decoding and 10-base integer reports.

        This method evaluates the trained autoencoder model on a validation set, computing various classification metrics based on the predicted and true genotypes. It handles both haploid and diploid data appropriately and generates detailed classification reports for both genotype and IUPAC/10-base integer encodings.

        Args:
            X_val (np.ndarray): Validation set 0/1/2 matrix with -1
                for missing.
            model (torch.nn.Module): Trained model.
            params (dict): Model parameters.
            objective_mode (bool): If True, suppress logging and reports.

        Returns:
            Dict[str, float]: Dictionary of evaluation metrics.
        """
        pred_labels, pred_probas = self._predict(
            model=model, X=X_val, return_proba=True
        )

        # mask out true missing AND any non-finite prob rows
        finite_mask = np.all(np.isfinite(pred_probas), axis=-1)  # (N,L)
        eval_mask = (X_val != -1) & finite_mask

        y_true_flat = X_val[eval_mask].astype(np.int64, copy=False)
        y_pred_flat = pred_labels[eval_mask].astype(np.int64, copy=False)
        y_proba_flat = pred_probas[eval_mask].astype(np.float64, copy=False)

        if y_true_flat.size == 0:
            return {self.tune_metric: 0.0}

        # ensure valid probability simplex after masking (no NaNs/Infs, sums=1)
        y_proba_flat = np.clip(y_proba_flat, 0.0, 1.0)
        row_sums = y_proba_flat.sum(axis=1, keepdims=True)
        row_sums[row_sums == 0] = 1.0
        y_proba_flat = y_proba_flat / row_sums

        labels_for_scoring = [0, 1] if self.is_haploid else [0, 1, 2]
        target_names = ["REF", "ALT"] if self.is_haploid else ["REF", "HET", "ALT"]

        if self.is_haploid:
            y_true_flat = y_true_flat.copy()
            y_pred_flat = y_pred_flat.copy()
            y_true_flat[y_true_flat == 2] = 1
            y_pred_flat[y_pred_flat == 2] = 1
            # collapse probs to 2-class
            proba_2 = np.zeros((len(y_proba_flat), 2), dtype=y_proba_flat.dtype)
            proba_2[:, 0] = y_proba_flat[:, 0]
            proba_2[:, 1] = y_proba_flat[:, 2]
            y_proba_flat = proba_2

        y_true_ohe = np.eye(len(labels_for_scoring))[y_true_flat]

        metrics = self.scorers_.evaluate(
            y_true_flat,
            y_pred_flat,
            y_true_ohe,
            y_proba_flat,
            objective_mode,
            self.tune_metric,
        )

        if not objective_mode:
            self.logger.info(f"Validation Metrics: {metrics}")

            # Primary report (REF/HET/ALT or REF/ALT)
            self._make_class_reports(
                y_true=y_true_flat,
                y_pred_proba=y_proba_flat,
                y_pred=y_pred_flat,
                metrics=metrics,
                labels=target_names,
            )

            # IUPAC decode & 10-base integer report (parity with ImputeNLPCA)
            y_true_dec = self.pgenc.decode_012(X_val)
            X_pred = X_val.copy()
            X_pred[eval_mask] = y_pred_flat
            y_pred_dec = self.pgenc.decode_012(
                X_pred.reshape(X_val.shape[0], self.num_features_)
            )

            encodings_dict = {
                "A": 0,
                "C": 1,
                "G": 2,
                "T": 3,
                "W": 4,
                "R": 5,
                "M": 6,
                "K": 7,
                "Y": 8,
                "S": 9,
                "N": -1,
            }
            y_true_int = self.pgenc.convert_int_iupac(
                y_true_dec, encodings_dict=encodings_dict
            )
            y_pred_int = self.pgenc.convert_int_iupac(
                y_pred_dec, encodings_dict=encodings_dict
            )

            self._make_class_reports(
                y_true=y_true_int[eval_mask],
                y_pred=y_pred_int[eval_mask],
                metrics=metrics,
                y_pred_proba=None,
                labels=["A", "C", "G", "T", "W", "R", "M", "K", "Y", "S"],
            )

        return metrics

    def _objective(self, trial: optuna.Trial) -> float:
        """Optuna objective for AE; mirrors NLPCA study driver without latents.

        This method defines the objective function for hyperparameter tuning using Optuna. It samples hyperparameters, prepares the training and validation data, builds and trains the autoencoder model, and evaluates its performance on the validation set. The method returns the value of the tuning metric to be maximized.

        Args:
            trial (optuna.Trial): Optuna trial.

        Returns:
            float: Value of the tuning metric (maximize).
        """
        try:
            # Sample hyperparameters (existing helper; unchanged signature)
            params = self._sample_hyperparameters(trial)

            # Optionally sub-sample for fast tuning (same keys used by NLPCA if you adopt them)
            X_train = self.ground_truth_[self.train_idx_]
            X_val = self.ground_truth_[self.test_idx_]

            class_weights = self._class_weights_from_zygosity(X_train).to(self.device)
            train_loader = self._get_data_loaders(X_train)

            model = self.build_model(self.Model, params["model_params"])
            model.apply(self.initialize_weights)

            # Train + prune on metric
            _, model, _ = self._train_and_validate_model(
                model=model,
                loader=train_loader,
                lr=params["lr"],
                l1_penalty=params["l1_penalty"],
                trial=trial,
                return_history=False,
                class_weights=class_weights,
                X_val=X_val,
                params=params,
                prune_metric=self.tune_metric,
                prune_warmup_epochs=5,
                eval_interval=self.tune_eval_interval,
                eval_requires_latents=False,
                eval_latent_steps=0,
                eval_latent_lr=0.0,
                eval_latent_weight_decay=0.0,
            )

            metrics = self._evaluate_model(X_val, model, params, objective_mode=True)
            self._clear_resources(model, train_loader)
            return metrics[self.tune_metric]

        except Exception as e:
            # Keep sweeps moving if a trial fails
            raise optuna.exceptions.TrialPruned(f"Trial failed with error: {e}")

    def _sample_hyperparameters(
        self, trial: optuna.Trial
    ) -> Dict[str, int | float | str]:
        """Sample AE hyperparameters and compute hidden sizes for model params.

        This method samples hyperparameters for the autoencoder model using Optuna's trial object. It computes the hidden layer sizes based on the sampled parameters and prepares the model parameters dictionary.

        Args:
            trial (optuna.Trial): Optuna trial object.

        Returns:
            Dict[str, int | float | str]: Sampled hyperparameters and model_params.
        """
        params = {
            "latent_dim": trial.suggest_int("latent_dim", 2, 64),
            "lr": trial.suggest_float("learning_rate", 1e-5, 1e-2, log=True),
            "dropout_rate": trial.suggest_float("dropout_rate", 0.0, 0.6),
            "num_hidden_layers": trial.suggest_int("num_hidden_layers", 1, 8),
            "activation": trial.suggest_categorical(
                "activation", ["relu", "elu", "selu"]
            ),
            "l1_penalty": trial.suggest_float("l1_penalty", 1e-7, 1e-2, log=True),
            "layer_scaling_factor": trial.suggest_float(
                "layer_scaling_factor", 2.0, 10.0
            ),
            "layer_schedule": trial.suggest_categorical(
                "layer_schedule", ["pyramid", "constant", "linear"]
            ),
        }

        input_dim = self.num_features_ * self.num_classes_
        hidden_layer_sizes = self._compute_hidden_layer_sizes(
            n_inputs=input_dim,
            n_outputs=input_dim,
            n_samples=len(self.train_idx_),
            n_hidden=params["num_hidden_layers"],
            alpha=params["layer_scaling_factor"],
            schedule=params["layer_schedule"],
        )

        # Keep the latent_dim as the first element,
        # then the interior hidden widths.
        # If there are no interior widths (very small nets),
        # this still leaves [latent_dim].
        hidden_only = [hidden_layer_sizes[0]] + hidden_layer_sizes[1:-1]

        params["model_params"] = {
            "n_features": self.num_features_,
            "num_classes": self.num_classes_,
            "latent_dim": params["latent_dim"],
            "dropout_rate": params["dropout_rate"],
            "hidden_layer_sizes": hidden_only,
            "activation": params["activation"],
        }
        return params

    def _set_best_params(
        self, best_params: Dict[str, int | float | str | list]
    ) -> Dict[str, int | float | str | list]:
        """Adopt best params (ImputeNLPCA parity) and return model_params.

        This method sets the best hyperparameters found during tuning and computes the hidden layer sizes for the autoencoder model. It prepares the final model parameters dictionary to be used for building the model.

        Args:
            best_params (Dict[str, int | float | str | list]): Best hyperparameters from tuning.

        Returns:
            Dict[str, int | float | str | list]: Model parameters for building the model.
        """
        self.latent_dim = best_params["latent_dim"]
        self.dropout_rate = best_params["dropout_rate"]
        self.learning_rate = best_params["learning_rate"]
        self.l1_penalty = best_params["l1_penalty"]
        self.activation = best_params["activation"]
        self.layer_scaling_factor = best_params["layer_scaling_factor"]
        self.layer_schedule = best_params["layer_schedule"]

        hidden_layer_sizes = self._compute_hidden_layer_sizes(
            n_inputs=self.num_features_ * self.num_classes_,
            n_outputs=self.num_features_ * self.num_classes_,
            n_samples=len(self.train_idx_),
            n_hidden=best_params["num_hidden_layers"],
            alpha=best_params["layer_scaling_factor"],
            schedule=best_params["layer_schedule"],
        )

        # Keep the latent_dim as the first element,
        # then the interior hidden widths.
        # If there are no interior widths (very small nets),
        # this still leaves [latent_dim].
        hidden_only = [hidden_layer_sizes[0]] + hidden_layer_sizes[1:-1]

        return {
            "n_features": self.num_features_,
            "latent_dim": self.latent_dim,
            "hidden_layer_sizes": hidden_only,
            "dropout_rate": self.dropout_rate,
            "activation": self.activation,
            "num_classes": self.num_classes_,
        }

    def _default_best_params(self) -> Dict[str, int | float | str | list]:
        """Default model params when tuning is disabled.

        This method computes the default model parameters for the autoencoder when hyperparameter tuning is not performed. It calculates the hidden layer sizes based on the initial configuration.

        Returns:
            Dict[str, int | float | str | list]: Default model parameters.
        """
        hidden_layer_sizes = self._compute_hidden_layer_sizes(
            n_inputs=self.num_features_ * self.num_classes_,
            n_outputs=self.num_features_ * self.num_classes_,
            n_samples=len(self.ground_truth_),
            n_hidden=self.num_hidden_layers,
            alpha=self.layer_scaling_factor,
            schedule=self.layer_schedule,
        )
        return {
            "n_features": self.num_features_,
            "latent_dim": self.latent_dim,
            "hidden_layer_sizes": hidden_layer_sizes,
            "dropout_rate": self.dropout_rate,
            "activation": self.activation,
            "num_classes": self.num_classes_,
        }
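For readers who want to sanity-check the loss described in `_train_step`, the masked focal cross-entropy can be reproduced in isolation. The sketch below uses toy tensors and a fixed, illustrative gamma; in the package the trainer ramps gamma from 0 up to its configured value over the warm/ramp schedule shown above.

```python
# Minimal standalone sketch of the masked focal cross-entropy used in _train_step.
# Toy tensors only; shapes and gamma are illustrative, not taken from the package.
import torch
import torch.nn.functional as F

logits = torch.randn(6, 3)                    # 6 genotype calls, 3 classes (0/1/2)
targets = torch.tensor([0, 2, -1, 1, -1, 0])  # -1 marks missing genotypes

# Per-call cross-entropy; ignore_index zeroes out the missing entries.
ce = F.cross_entropy(logits, targets, reduction="none", ignore_index=-1)
pt = torch.exp(-ce)            # model's probability for the true class
gamma = 2.0                    # focusing parameter
focal = ((1 - pt) ** gamma) * ce

valid = targets != -1          # average only over observed genotypes
loss = focal[valid].mean()
print(float(loss))
```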