pg-sui 1.0.2.1__py3-none-any.whl → 1.6.8__py3-none-any.whl
- {pg_sui-1.0.2.1.dist-info → pg_sui-1.6.8.dist-info}/METADATA +51 -70
- pg_sui-1.6.8.dist-info/RECORD +78 -0
- {pg_sui-1.0.2.1.dist-info → pg_sui-1.6.8.dist-info}/WHEEL +1 -1
- pg_sui-1.6.8.dist-info/entry_points.txt +4 -0
- pg_sui-1.6.8.dist-info/top_level.txt +1 -0
- pgsui/__init__.py +35 -54
- pgsui/_version.py +34 -0
- pgsui/cli.py +635 -0
- pgsui/data_processing/config.py +576 -0
- pgsui/data_processing/containers.py +1782 -0
- pgsui/data_processing/transformers.py +121 -1103
- pgsui/electron/app/__main__.py +5 -0
- pgsui/electron/app/icons/icons/1024x1024.png +0 -0
- pgsui/electron/app/icons/icons/128x128.png +0 -0
- pgsui/electron/app/icons/icons/16x16.png +0 -0
- pgsui/electron/app/icons/icons/24x24.png +0 -0
- pgsui/electron/app/icons/icons/256x256.png +0 -0
- pgsui/electron/app/icons/icons/32x32.png +0 -0
- pgsui/electron/app/icons/icons/48x48.png +0 -0
- pgsui/electron/app/icons/icons/512x512.png +0 -0
- pgsui/electron/app/icons/icons/64x64.png +0 -0
- pgsui/electron/app/icons/icons/icon.icns +0 -0
- pgsui/electron/app/icons/icons/icon.ico +0 -0
- pgsui/electron/app/main.js +189 -0
- pgsui/electron/app/package-lock.json +6893 -0
- pgsui/electron/app/package.json +50 -0
- pgsui/electron/app/preload.js +15 -0
- pgsui/electron/app/server.py +146 -0
- pgsui/electron/app/ui/logo.png +0 -0
- pgsui/electron/app/ui/renderer.js +130 -0
- pgsui/electron/app/ui/styles.css +59 -0
- pgsui/electron/app/ui/ui_shim.js +72 -0
- pgsui/electron/bootstrap.py +43 -0
- pgsui/electron/launch.py +59 -0
- pgsui/electron/package.json +14 -0
- pgsui/example_data/popmaps/{test.popmap → phylogen_nomx.popmap} +185 -99
- pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz +0 -0
- pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz.tbi +0 -0
- pgsui/impute/deterministic/imputers/allele_freq.py +691 -0
- pgsui/impute/deterministic/imputers/mode.py +679 -0
- pgsui/impute/deterministic/imputers/nmf.py +221 -0
- pgsui/impute/deterministic/imputers/phylo.py +971 -0
- pgsui/impute/deterministic/imputers/ref_allele.py +530 -0
- pgsui/impute/supervised/base.py +339 -0
- pgsui/impute/supervised/imputers/hist_gradient_boosting.py +293 -0
- pgsui/impute/supervised/imputers/random_forest.py +287 -0
- pgsui/impute/unsupervised/base.py +924 -0
- pgsui/impute/unsupervised/callbacks.py +89 -263
- pgsui/impute/unsupervised/imputers/autoencoder.py +972 -0
- pgsui/impute/unsupervised/imputers/nlpca.py +1264 -0
- pgsui/impute/unsupervised/imputers/ubp.py +1288 -0
- pgsui/impute/unsupervised/imputers/vae.py +957 -0
- pgsui/impute/unsupervised/loss_functions.py +158 -0
- pgsui/impute/unsupervised/models/autoencoder_model.py +208 -558
- pgsui/impute/unsupervised/models/nlpca_model.py +149 -468
- pgsui/impute/unsupervised/models/ubp_model.py +198 -1317
- pgsui/impute/unsupervised/models/vae_model.py +259 -618
- pgsui/impute/unsupervised/nn_scorers.py +215 -0
- pgsui/utils/classification_viz.py +591 -0
- pgsui/utils/misc.py +35 -480
- pgsui/utils/plotting.py +514 -824
- pgsui/utils/scorers.py +212 -438
- pg_sui-1.0.2.1.dist-info/RECORD +0 -75
- pg_sui-1.0.2.1.dist-info/top_level.txt +0 -3
- pgsui/example_data/phylip_files/test_n10.phy +0 -118
- pgsui/example_data/phylip_files/test_n100.phy +0 -118
- pgsui/example_data/phylip_files/test_n2.phy +0 -118
- pgsui/example_data/phylip_files/test_n500.phy +0 -118
- pgsui/example_data/structure_files/test.nopops.1row.10sites.str +0 -117
- pgsui/example_data/structure_files/test.nopops.2row.100sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.10sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.30sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.allsites.str +0 -234
- pgsui/example_data/structure_files/test.pops.1row.10sites.str +0 -117
- pgsui/example_data/structure_files/test.pops.2row.10sites.str +0 -234
- pgsui/example_data/trees/test.iqtree +0 -376
- pgsui/example_data/trees/test.qmat +0 -5
- pgsui/example_data/trees/test.rate +0 -2033
- pgsui/example_data/trees/test.tre +0 -1
- pgsui/example_data/trees/test_n10.rate +0 -19
- pgsui/example_data/trees/test_n100.rate +0 -109
- pgsui/example_data/trees/test_n500.rate +0 -509
- pgsui/example_data/trees/test_siterates.txt +0 -2024
- pgsui/example_data/trees/test_siterates_n10.txt +0 -10
- pgsui/example_data/trees/test_siterates_n100.txt +0 -100
- pgsui/example_data/trees/test_siterates_n500.txt +0 -500
- pgsui/example_data/vcf_files/test.vcf +0 -244
- pgsui/example_data/vcf_files/test.vcf.gz +0 -0
- pgsui/example_data/vcf_files/test.vcf.gz.tbi +0 -0
- pgsui/impute/estimators.py +0 -735
- pgsui/impute/impute.py +0 -1486
- pgsui/impute/simple_imputers.py +0 -1439
- pgsui/impute/supervised/iterative_imputer_fixedparams.py +0 -785
- pgsui/impute/supervised/iterative_imputer_gridsearch.py +0 -1027
- pgsui/impute/unsupervised/keras_classifiers.py +0 -702
- pgsui/impute/unsupervised/models/in_development/cnn_model.py +0 -486
- pgsui/impute/unsupervised/neural_network_imputers.py +0 -1424
- pgsui/impute/unsupervised/neural_network_methods.py +0 -1549
- pgsui/pg_sui.py +0 -261
- pgsui/utils/sequence_tools.py +0 -407
- simulation/sim_benchmarks.py +0 -333
- simulation/sim_treeparams.py +0 -475
- test/__init__.py +0 -0
- test/pg_sui_simtest.py +0 -215
- test/pg_sui_testing.py +0 -523
- test/test.py +0 -297
- test/test_pgsui.py +0 -374
- test/test_tkc.py +0 -214
- {pg_sui-1.0.2.1.dist-info → pg_sui-1.6.8.dist-info/licenses}/LICENSE +0 -0
- /pgsui/{example_data/trees → electron/app}/__init__.py +0 -0
- /pgsui/impute/{unsupervised/models/in_development → supervised/imputers}/__init__.py +0 -0
- {simulation → pgsui/impute/unsupervised/imputers}/__init__.py +0 -0
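The largest addition in this release is the new UBP imputer module, pgsui/impute/unsupervised/imputers/ubp.py (+1288 lines), shown below. As a quick orientation, here is a minimal usage sketch of the public API it introduces (ImputeUBP with fit()/transform()). The snpio VCFReader call and its arguments are assumptions, not part of this diff; ImputeUBP only needs an already-loaded GenotypeData object. The example file paths come from the pgsui/example_data files listed above.

# Minimal usage sketch for the new ImputeUBP class (see the diff below).
# The VCFReader loader call is an assumption; any snpio GenotypeData object works.
from snpio import VCFReader  # assumed import path/signature

from pgsui.impute.unsupervised.imputers.ubp import ImputeUBP

genotype_data = VCFReader(
    filename="pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz",
    popmapfile="pgsui/example_data/popmaps/phylogen_nomx.popmap",
)

imputer = ImputeUBP(
    genotype_data,
    # dict config; nested sections are flattened to dot-keys (io.prefix, train.max_epochs, ...)
    config={"io": {"prefix": "ubp_run"}, "train": {"max_epochs": 200}},
    # flat dot-key overrides applied after `config`
    overrides={"tune.enabled": False},
)
imputer.fit()                  # three-phase UBP training on 0/1/2 encodings
imputed = imputer.transform()  # IUPAC single-character array (n_samples x n_loci)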
|
@@ -0,0 +1,1288 @@
|
|
|
1
|
+
import copy
|
|
2
|
+
import warnings
|
|
3
|
+
from typing import TYPE_CHECKING, Any, Dict, Literal, Tuple
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import optuna
|
|
7
|
+
import torch
|
|
8
|
+
import torch.nn.functional as F
|
|
9
|
+
from sklearn.decomposition import PCA
|
|
10
|
+
from sklearn.exceptions import NotFittedError
|
|
11
|
+
from sklearn.model_selection import train_test_split
|
|
12
|
+
from snpio.analysis.genotype_encoder import GenotypeEncoder
|
|
13
|
+
from snpio.utils.logging import LoggerManager
|
|
14
|
+
from torch.optim.lr_scheduler import CosineAnnealingLR
|
|
15
|
+
|
|
16
|
+
from pgsui.data_processing.config import apply_dot_overrides, load_yaml_to_dataclass
|
|
17
|
+
from pgsui.data_processing.containers import UBPConfig
|
|
18
|
+
from pgsui.impute.unsupervised.base import BaseNNImputer
|
|
19
|
+
from pgsui.impute.unsupervised.callbacks import EarlyStopping
|
|
20
|
+
from pgsui.impute.unsupervised.models.ubp_model import UBPModel
|
|
21
|
+
|
|
22
|
+
if TYPE_CHECKING:
|
|
23
|
+
from snpio.read_input.genotype_data import GenotypeData
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def ensure_ubp_config(config: UBPConfig | dict | str | None) -> UBPConfig:
|
|
27
|
+
"""Return a concrete UBPConfig from dataclass, dict, YAML path, or None.
|
|
28
|
+
|
|
29
|
+
This method normalizes the input configuration for the UBP imputer. It accepts a UBPConfig instance, a dictionary, a YAML file path, or None. If None is provided, it returns a default UBPConfig instance. If a YAML path is given, it loads the configuration from the file, supporting top-level presets. If a dictionary is provided, it flattens any nested structures and applies dot-key overrides to a base configuration, which can also be influenced by a preset if specified. The method ensures that the final output is a fully populated UBPConfig instance.
|
|
30
|
+
|
|
31
|
+
Args:
|
|
32
|
+
config: UBPConfig | dict | YAML path | None.
|
|
33
|
+
|
|
34
|
+
Returns:
|
|
35
|
+
UBPConfig: Normalized configuration instance.
|
|
36
|
+
"""
|
|
37
|
+
if config is None:
|
|
38
|
+
return UBPConfig()
|
|
39
|
+
if isinstance(config, UBPConfig):
|
|
40
|
+
return config
|
|
41
|
+
if isinstance(config, str):
|
|
42
|
+
# YAML path — support top-level `preset`
|
|
43
|
+
return load_yaml_to_dataclass(
|
|
44
|
+
config,
|
|
45
|
+
UBPConfig,
|
|
46
|
+
preset_builder=UBPConfig.from_preset,
|
|
47
|
+
)
|
|
48
|
+
if isinstance(config, dict):
|
|
49
|
+
base = UBPConfig()
|
|
50
|
+
|
|
51
|
+
def _flatten(prefix: str, d: dict, out: dict) -> dict:
|
|
52
|
+
for k, v in d.items():
|
|
53
|
+
kk = f"{prefix}.{k}" if prefix else k
|
|
54
|
+
if isinstance(v, dict):
|
|
55
|
+
_flatten(kk, v, out)
|
|
56
|
+
else:
|
|
57
|
+
out[kk] = v
|
|
58
|
+
return out
|
|
59
|
+
|
|
60
|
+
preset_name = config.pop("preset", None)
|
|
61
|
+
if "io" in config and isinstance(config["io"], dict):
|
|
62
|
+
preset_name = preset_name or config["io"].pop("preset", None)
|
|
63
|
+
if preset_name:
|
|
64
|
+
base = UBPConfig.from_preset(preset_name)
|
|
65
|
+
|
|
66
|
+
flat = _flatten("", config, {})
|
|
67
|
+
return apply_dot_overrides(base, flat)
|
|
68
|
+
|
|
69
|
+
raise TypeError("config must be a UBPConfig, dict, YAML path, or None.")
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
class ImputeUBP(BaseNNImputer):
|
|
73
|
+
"""UBP imputer for 0/1/2 genotypes with three-phase training.
|
|
74
|
+
|
|
75
|
+
This imputer uses a three-phase training schedule specific to the UBP model:
|
|
76
|
+
|
|
77
|
+
1. Pre-training: Train the model on the full dataset with a small learning rate.
|
|
78
|
+
2. Fine-tuning: Train the model on the full dataset with a larger learning rate.
|
|
79
|
+
3. Evaluation: Evaluate the model on the test set. Optimize latents for test set. Predict 0/1/2. Decode to IUPAC. Plot & report.
|
|
80
|
+
4. Post-processing: Apply any necessary post-processing steps to the imputed genotypes.
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
References:
|
|
84
|
+
- Gashler, Michael S., Smith, Michael R., Morris, R., and Martinez, T. (2016) Missing Value Imputation with Unsupervised Backpropagation. Computational Intelligence, 32: 196-215. doi: 10.1111/coin.12048.
|
|
85
|
+
"""
|
|
86
|
+
|
|
87
|
+
def __init__(
|
|
88
|
+
self,
|
|
89
|
+
genotype_data: "GenotypeData",
|
|
90
|
+
*,
|
|
91
|
+
config: UBPConfig | dict | str | None = None,
|
|
92
|
+
overrides: dict | None = None,
|
|
93
|
+
):
|
|
94
|
+
"""Initialize the UBP imputer via dataclass/dict/YAML config with overrides.
|
|
95
|
+
|
|
96
|
+
This constructor allows for flexible initialization of the UBP imputer by accepting various forms of configuration input. It ensures that the configuration is properly normalized and any specified overrides are applied. The method also sets up logging and initializes various attributes related to the model, training, tuning, and evaluation based on the provided configuration.
|
|
97
|
+
|
|
98
|
+
Args:
|
|
99
|
+
genotype_data (GenotypeData): Backing genotype data object.
|
|
100
|
+
config (UBPConfig | dict | str | None): UBP configuration.
|
|
101
|
+
overrides (dict | None): Flat dot-key overrides applied after `config`.
|
|
102
|
+
"""
|
|
103
|
+
self.model_name = "ImputeUBP"
|
|
104
|
+
self.genotype_data = genotype_data
|
|
105
|
+
|
|
106
|
+
# ---- normalize config, then apply overrides ----
|
|
107
|
+
cfg = ensure_ubp_config(config)
|
|
108
|
+
if overrides:
|
|
109
|
+
cfg = apply_dot_overrides(cfg, overrides)
|
|
110
|
+
self.cfg = cfg
|
|
111
|
+
|
|
112
|
+
# ---- logging ----
|
|
113
|
+
logman = LoggerManager(
|
|
114
|
+
__name__,
|
|
115
|
+
prefix=self.cfg.io.prefix,
|
|
116
|
+
debug=self.cfg.io.debug,
|
|
117
|
+
verbose=self.cfg.io.verbose,
|
|
118
|
+
)
|
|
119
|
+
self.logger = logman.get_logger()
|
|
120
|
+
|
|
121
|
+
# ---- Base init ----
|
|
122
|
+
super().__init__(
|
|
123
|
+
prefix=self.cfg.io.prefix,
|
|
124
|
+
device=self.cfg.train.device,
|
|
125
|
+
verbose=self.cfg.io.verbose,
|
|
126
|
+
debug=self.cfg.io.debug,
|
|
127
|
+
)
|
|
128
|
+
|
|
129
|
+
# ---- model/meta ----
|
|
130
|
+
self.Model = UBPModel
|
|
131
|
+
self.pgenc = GenotypeEncoder(genotype_data)
|
|
132
|
+
|
|
133
|
+
self.seed = self.cfg.io.seed
|
|
134
|
+
self.n_jobs = self.cfg.io.n_jobs
|
|
135
|
+
self.prefix = self.cfg.io.prefix
|
|
136
|
+
self.scoring_averaging = self.cfg.io.scoring_averaging
|
|
137
|
+
self.verbose = self.cfg.io.verbose
|
|
138
|
+
self.debug = self.cfg.io.debug
|
|
139
|
+
self.rng = np.random.default_rng(self.seed)
|
|
140
|
+
|
|
141
|
+
# ---- model hyperparams ----
|
|
142
|
+
self.latent_dim = self.cfg.model.latent_dim
|
|
143
|
+
self.dropout_rate = self.cfg.model.dropout_rate
|
|
144
|
+
self.num_hidden_layers = self.cfg.model.num_hidden_layers
|
|
145
|
+
self.layer_scaling_factor = self.cfg.model.layer_scaling_factor
|
|
146
|
+
self.layer_schedule = self.cfg.model.layer_schedule
|
|
147
|
+
self.latent_init = self.cfg.model.latent_init
|
|
148
|
+
self.activation = self.cfg.model.hidden_activation
|
|
149
|
+
self.gamma = self.cfg.model.gamma
|
|
150
|
+
|
|
151
|
+
# ---- training ----
|
|
152
|
+
self.batch_size = self.cfg.train.batch_size
|
|
153
|
+
self.learning_rate = self.cfg.train.learning_rate
|
|
154
|
+
self.lr_input_factor = self.cfg.train.lr_input_factor
|
|
155
|
+
self.l1_penalty = self.cfg.train.l1_penalty
|
|
156
|
+
self.early_stop_gen = self.cfg.train.early_stop_gen
|
|
157
|
+
self.min_epochs = self.cfg.train.min_epochs
|
|
158
|
+
self.epochs = self.cfg.train.max_epochs
|
|
159
|
+
self.validation_split = self.cfg.train.validation_split
|
|
160
|
+
self.beta = self.cfg.train.weights_beta
|
|
161
|
+
self.max_ratio = self.cfg.train.weights_max_ratio
|
|
162
|
+
|
|
163
|
+
# ---- tuning ----
|
|
164
|
+
self.tune = self.cfg.tune.enabled
|
|
165
|
+
self.tune_fast = self.cfg.tune.fast
|
|
166
|
+
self.tune_proxy_metric_batch = self.cfg.tune.proxy_metric_batch
|
|
167
|
+
self.tune_batch_size = self.cfg.tune.batch_size
|
|
168
|
+
self.tune_epochs = self.cfg.tune.epochs
|
|
169
|
+
self.tune_eval_interval = self.cfg.tune.eval_interval
|
|
170
|
+
self.tune_metric = self.cfg.tune.metric
|
|
171
|
+
self.n_trials = self.cfg.tune.n_trials
|
|
172
|
+
self.tune_save_db = self.cfg.tune.save_db
|
|
173
|
+
self.tune_resume = self.cfg.tune.resume
|
|
174
|
+
self.tune_max_samples = self.cfg.tune.max_samples
|
|
175
|
+
self.tune_max_loci = self.cfg.tune.max_loci
|
|
176
|
+
self.tune_infer_epochs = getattr(self.cfg.tune, "infer_epochs", 100)
|
|
177
|
+
self.tune_patience = self.cfg.tune.patience
|
|
178
|
+
|
|
179
|
+
# ---- evaluation ----
|
|
180
|
+
self.eval_latent_steps = self.cfg.evaluate.eval_latent_steps
|
|
181
|
+
self.eval_latent_lr = self.cfg.evaluate.eval_latent_lr
|
|
182
|
+
self.eval_latent_weight_decay = self.cfg.evaluate.eval_latent_weight_decay
|
|
183
|
+
|
|
184
|
+
# ---- plotting ----
|
|
185
|
+
self.plot_format = self.cfg.plot.fmt
|
|
186
|
+
self.plot_dpi = self.cfg.plot.dpi
|
|
187
|
+
self.plot_fontsize = self.cfg.plot.fontsize
|
|
188
|
+
self.title_fontsize = self.cfg.plot.fontsize
|
|
189
|
+
self.despine = self.cfg.plot.despine
|
|
190
|
+
self.show_plots = self.cfg.plot.show
|
|
191
|
+
|
|
192
|
+
# ---- core runtime ----
|
|
193
|
+
self.is_haploid = None
|
|
194
|
+
self.num_classes_ = None
|
|
195
|
+
self.model_params: Dict[str, Any] = {}
|
|
196
|
+
|
|
197
|
+
def fit(self) -> "ImputeUBP":
|
|
198
|
+
"""Fit the UBP decoder on 0/1/2 encodings (missing = -1). Three phases.
|
|
199
|
+
|
|
200
|
+
1. Pre-training: Train the model on the full dataset with a small learning rate.
|
|
201
|
+
2. Fine-tuning: Train the model on the full dataset with a larger learning rate.
|
|
202
|
+
3. Evaluation: Evaluate the model on the test set. Optimize latents for test set. Predict 0/1/2. Decode to IUPAC. Plot & report.
|
|
203
|
+
4. Post-processing: Apply any necessary post-processing steps to the imputed genotypes.
|
|
204
|
+
|
|
205
|
+
Returns:
|
|
206
|
+
ImputeUBP: Fitted instance.
|
|
207
|
+
|
|
208
|
+
Raises:
|
|
209
|
+
NotFittedError: If training fails.
|
|
210
|
+
"""
|
|
211
|
+
self.logger.info(f"Fitting {self.model_name} model...")
|
|
212
|
+
|
|
213
|
+
# --- Use 0/1/2 with -1 for missing ---
|
|
214
|
+
X = self.pgenc.genotypes_012.astype(np.float32)
|
|
215
|
+
X[X < 0] = np.nan
|
|
216
|
+
X[np.isnan(X)] = -1
|
|
217
|
+
self.ground_truth_ = X.astype(np.int64)
|
|
218
|
+
|
|
219
|
+
# --- Determine ploidy (haploid vs diploid) and classes ---
|
|
220
|
+
self.is_haploid = np.all(
|
|
221
|
+
np.isin(
|
|
222
|
+
self.genotype_data.snp_data, ["A", "C", "G", "T", "N", "-", ".", "?"]
|
|
223
|
+
)
|
|
224
|
+
)
|
|
225
|
+
self.ploidy = 1 if self.is_haploid else 2
|
|
226
|
+
|
|
227
|
+
if self.is_haploid:
|
|
228
|
+
self.num_classes_ = 2
|
|
229
|
+
self.ground_truth_[self.ground_truth_ == 2] = 1
|
|
230
|
+
self.logger.info("Haploid data detected. Using 2 classes (REF=0, ALT=1).")
|
|
231
|
+
else:
|
|
232
|
+
self.num_classes_ = 3
|
|
233
|
+
self.logger.info(
|
|
234
|
+
"Diploid data detected. Using 3 classes (REF=0, HET=1, ALT=2)."
|
|
235
|
+
)
|
|
236
|
+
|
|
237
|
+
n_samples, self.num_features_ = X.shape
|
|
238
|
+
|
|
239
|
+
# --- model params (decoder: Z -> L * num_classes) ---
|
|
240
|
+
self.model_params = {
|
|
241
|
+
"n_features": self.num_features_,
|
|
242
|
+
"num_classes": self.num_classes_,
|
|
243
|
+
"latent_dim": self.latent_dim,
|
|
244
|
+
"dropout_rate": self.dropout_rate,
|
|
245
|
+
"activation": self.activation,
|
|
246
|
+
# hidden_layer_sizes injected later
|
|
247
|
+
}
|
|
248
|
+
|
|
249
|
+
# --- split ---
|
|
250
|
+
indices = np.arange(n_samples)
|
|
251
|
+
train_idx, test_idx = train_test_split(
|
|
252
|
+
indices, test_size=self.validation_split, random_state=self.seed
|
|
253
|
+
)
|
|
254
|
+
self.train_idx_, self.test_idx_ = train_idx, test_idx
|
|
255
|
+
self.X_train_ = self.ground_truth_[train_idx]
|
|
256
|
+
self.X_test_ = self.ground_truth_[test_idx]
|
|
257
|
+
|
|
258
|
+
# --- plotting/scorers & tuning ---
|
|
259
|
+
self.plotter_, self.scorers_ = self.initialize_plotting_and_scorers()
|
|
260
|
+
if self.tune:
|
|
261
|
+
self.tune_hyperparameters()
|
|
262
|
+
|
|
263
|
+
self.best_params_ = getattr(
|
|
264
|
+
self, "best_params_", self._set_best_params_default()
|
|
265
|
+
)
|
|
266
|
+
|
|
267
|
+
# --- class weights for 0/1/2 ---
|
|
268
|
+
self.class_weights_ = self._class_weights_from_zygosity(self.X_train_)
|
|
269
|
+
|
|
270
|
+
# --- latent init & loader ---
|
|
271
|
+
train_latent_vectors = self._create_latent_space(
|
|
272
|
+
self.best_params_, len(self.X_train_), self.X_train_, self.latent_init
|
|
273
|
+
)
|
|
274
|
+
train_loader = self._get_data_loaders(self.X_train_)
|
|
275
|
+
|
|
276
|
+
# --- final training (three-phase under the hood) ---
|
|
277
|
+
(self.best_loss_, self.model_, self.history_, self.train_latent_vectors_) = (
|
|
278
|
+
self._train_final_model(
|
|
279
|
+
loader=train_loader,
|
|
280
|
+
best_params=self.best_params_,
|
|
281
|
+
initial_latent_vectors=train_latent_vectors,
|
|
282
|
+
)
|
|
283
|
+
)
|
|
284
|
+
|
|
285
|
+
self.is_fit_ = True
|
|
286
|
+
self.plotter_.plot_history(self.history_)
|
|
287
|
+
self._evaluate_model(self.X_test_, self.model_, self.best_params_)
|
|
288
|
+
self._save_best_params(self.best_params_)
|
|
289
|
+
return self
|
|
290
|
+
|
|
291
|
+
def transform(self) -> np.ndarray:
|
|
292
|
+
"""Impute missing genotypes (0/1/2) and return IUPAC strings.
|
|
293
|
+
|
|
294
|
+
This method first checks if the model has been fitted. It then imputes the entire dataset by optimizing latent vectors for the ground truth data and predicting the missing genotypes using the trained UBP model. The imputed genotypes are decoded to IUPAC format, and distributions of original and imputed genotypes are plotted.
|
|
295
|
+
|
|
296
|
+
Returns:
|
|
297
|
+
np.ndarray: IUPAC single-character array (n_samples x L).
|
|
298
|
+
|
|
299
|
+
Raises:
|
|
300
|
+
NotFittedError: If called before fit().
|
|
301
|
+
"""
|
|
302
|
+
if not getattr(self, "is_fit_", False):
|
|
303
|
+
raise NotFittedError("Model is not fitted. Call fit() before transform().")
|
|
304
|
+
|
|
305
|
+
self.logger.info("Imputing entire dataset with UBP (0/1/2)...")
|
|
306
|
+
X_to_impute = self.ground_truth_.copy()
|
|
307
|
+
|
|
308
|
+
optimized_latents = self._optimize_latents_for_inference(
|
|
309
|
+
X_to_impute, self.model_, self.best_params_
|
|
310
|
+
)
|
|
311
|
+
pred_labels, _ = self._predict(self.model_, latent_vectors=optimized_latents)
|
|
312
|
+
|
|
313
|
+
missing_mask = X_to_impute == -1
|
|
314
|
+
imputed_array = X_to_impute.copy()
|
|
315
|
+
imputed_array[missing_mask] = pred_labels[missing_mask]
|
|
316
|
+
|
|
317
|
+
# Decode to IUPAC for return & plots
|
|
318
|
+
imputed_genotypes = self.pgenc.decode_012(imputed_array)
|
|
319
|
+
original_genotypes = self.pgenc.decode_012(X_to_impute)
|
|
320
|
+
|
|
321
|
+
self.plotter_.plot_gt_distribution(original_genotypes, is_imputed=False)
|
|
322
|
+
self.plotter_.plot_gt_distribution(imputed_genotypes, is_imputed=True)
|
|
323
|
+
return imputed_genotypes
|
|
324
|
+
|
|
325
|
+
def _train_step(
|
|
326
|
+
self,
|
|
327
|
+
loader: torch.utils.data.DataLoader,
|
|
328
|
+
optimizer: torch.optim.Optimizer,
|
|
329
|
+
latent_optimizer: torch.optim.Optimizer,
|
|
330
|
+
model: torch.nn.Module,
|
|
331
|
+
l1_penalty: float,
|
|
332
|
+
latent_vectors: torch.nn.Parameter,
|
|
333
|
+
class_weights: torch.Tensor,
|
|
334
|
+
phase: int,
|
|
335
|
+
) -> Tuple[float, torch.nn.Parameter]:
|
|
336
|
+
"""Single epoch over batches for UBP with 0/1/2 focal CE.
|
|
337
|
+
|
|
338
|
+
This method handles all three UBP phases:
|
|
339
|
+
|
|
340
|
+
1. Pre-training: Train the model on the full dataset with a small learning rate.
|
|
341
|
+
2. Fine-tuning: Train the model on the full dataset with a larger learning rate.
|
|
342
|
+
3. Joint training: Train both model and latents.
|
|
343
|
+
|
|
344
|
+
Args:
|
|
345
|
+
loader (torch.utils.data.DataLoader): DataLoader (indices, y_batch).
|
|
346
|
+
optimizer (torch.optim.Optimizer): Decoder optimizer.
|
|
347
|
+
latent_optimizer (torch.optim.Optimizer): Latent optimizer.
|
|
348
|
+
model (torch.nn.Module): UBP model with phase1_decoder & phase23_decoder.
|
|
349
|
+
l1_penalty (float): L1 regularization weight.
|
|
350
|
+
latent_vectors (torch.nn.Parameter): Trainable Z.
|
|
351
|
+
class_weights (torch.Tensor): Class weights for 0/1/2.
|
|
352
|
+
phase (int): Phase id (1, 2, 3). Phase 1 = warm-up, phase 2 = decoder-only, phase 3 = joint.
|
|
353
|
+
|
|
354
|
+
Returns:
|
|
355
|
+
Tuple[float, torch.nn.Parameter]: Average loss and updated latents.
|
|
356
|
+
"""
|
|
357
|
+
model.train()
|
|
358
|
+
running = 0.0
|
|
359
|
+
|
|
360
|
+
for batch_indices, y_batch in loader:
|
|
361
|
+
optimizer.zero_grad(set_to_none=True)
|
|
362
|
+
latent_optimizer.zero_grad(set_to_none=True)
|
|
363
|
+
|
|
364
|
+
decoder = model.phase1_decoder if phase == 1 else model.phase23_decoder
|
|
365
|
+
logits = decoder(latent_vectors[batch_indices]).view(
|
|
366
|
+
len(batch_indices), self.num_features_, self.num_classes_
|
|
367
|
+
)
|
|
368
|
+
|
|
369
|
+
logits_flat = logits.view(-1, self.num_classes_)
|
|
370
|
+
targets_flat = y_batch.view(-1)
|
|
371
|
+
|
|
372
|
+
ce = F.cross_entropy(
|
|
373
|
+
logits_flat,
|
|
374
|
+
targets_flat,
|
|
375
|
+
weight=class_weights,
|
|
376
|
+
reduction="none",
|
|
377
|
+
ignore_index=-1,
|
|
378
|
+
)
|
|
379
|
+
pt = torch.exp(-ce)
|
|
380
|
+
gamma = getattr(model, "gamma", self.gamma)
|
|
381
|
+
focal = ((1 - pt) ** gamma) * ce
|
|
382
|
+
|
|
383
|
+
valid_mask = targets_flat != -1
|
|
384
|
+
loss = (
|
|
385
|
+
focal[valid_mask].mean()
|
|
386
|
+
if valid_mask.any()
|
|
387
|
+
else torch.tensor(0.0, device=logits.device)
|
|
388
|
+
)
|
|
389
|
+
|
|
390
|
+
if l1_penalty > 0:
|
|
391
|
+
loss = loss + l1_penalty * sum(
|
|
392
|
+
p.abs().sum() for p in model.parameters()
|
|
393
|
+
)
|
|
394
|
+
|
|
395
|
+
loss.backward()
|
|
396
|
+
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
|
|
397
|
+
torch.nn.utils.clip_grad_norm_([latent_vectors], 1.0)
|
|
398
|
+
|
|
399
|
+
optimizer.step()
|
|
400
|
+
|
|
401
|
+
if phase != 2:
|
|
402
|
+
latent_optimizer.step()
|
|
403
|
+
|
|
404
|
+
running += float(loss.item())
|
|
405
|
+
|
|
406
|
+
return running / len(loader), latent_vectors
|
|
407
|
+
|
|
408
|
+
def _predict(
|
|
409
|
+
self, model: torch.nn.Module, latent_vectors: torch.nn.Parameter | None = None
|
|
410
|
+
) -> Tuple[np.ndarray, np.ndarray]:
|
|
411
|
+
"""Predict 0/1/2 labels & probabilities from latents via phase23 decoder. This method requires a trained model and latent vectors.
|
|
412
|
+
|
|
413
|
+
Args:
|
|
414
|
+
model (torch.nn.Module): Trained model.
|
|
415
|
+
latent_vectors (torch.nn.Parameter | None): Latent vectors.
|
|
416
|
+
|
|
417
|
+
Returns:
|
|
418
|
+
Tuple[np.ndarray, np.ndarray]: Predicted labels and probabilities.
|
|
419
|
+
"""
|
|
420
|
+
if model is None or latent_vectors is None:
|
|
421
|
+
msg = "Model and latent vectors must be provided for prediction. Fit the model first."
|
|
422
|
+
self.logger.error(msg)
|
|
423
|
+
raise NotFittedError(msg)
|
|
424
|
+
|
|
425
|
+
model.eval()
|
|
426
|
+
nF = getattr(model, "n_features", self.num_features_)
|
|
427
|
+
with torch.no_grad():
|
|
428
|
+
logits = model.phase23_decoder(latent_vectors.to(self.device)).view(
|
|
429
|
+
len(latent_vectors), nF, self.num_classes_
|
|
430
|
+
)
|
|
431
|
+
probas = torch.softmax(logits, dim=-1)
|
|
432
|
+
labels = torch.argmax(probas, dim=-1)
|
|
433
|
+
|
|
434
|
+
return labels.cpu().numpy(), probas.cpu().numpy()
|
|
435
|
+
|
|
436
|
+
def _evaluate_model(
|
|
437
|
+
self,
|
|
438
|
+
X_val: np.ndarray,
|
|
439
|
+
model: torch.nn.Module,
|
|
440
|
+
params: dict,
|
|
441
|
+
objective_mode: bool = False,
|
|
442
|
+
latent_vectors_val: torch.Tensor | None = None,
|
|
443
|
+
) -> Dict[str, float]:
|
|
444
|
+
"""Evaluate on held-out set with 0/1/2 classes; also IUPAC/10-base reports.
|
|
445
|
+
|
|
446
|
+
This method evaluates the trained UBP model on a held-out validation set. It optimizes latent vectors for the validation data if they are not provided, predicts 0/1/2 labels and probabilities, and computes various performance metrics. If not in objective mode, it generates detailed classification reports and confusion matrices for both 0/1/2 genotypes and their IUPAC/10-base representations. The method returns a dictionary of evaluation metrics.
|
|
447
|
+
|
|
448
|
+
Args:
|
|
449
|
+
X_val (np.ndarray): 0/1/2 with -1 for missing.
|
|
450
|
+
model (torch.nn.Module): Trained model.
|
|
451
|
+
params (dict): Model params.
|
|
452
|
+
objective_mode (bool): If True, return only tuned metric.
|
|
453
|
+
latent_vectors_val (torch.Tensor | None): Optional pre-optimized latents.
|
|
454
|
+
|
|
455
|
+
Returns:
|
|
456
|
+
Metrics dict.
|
|
457
|
+
"""
|
|
458
|
+
if latent_vectors_val is not None:
|
|
459
|
+
test_latent_vectors = latent_vectors_val
|
|
460
|
+
else:
|
|
461
|
+
test_latent_vectors = self._optimize_latents_for_inference(
|
|
462
|
+
X_val, model, params
|
|
463
|
+
)
|
|
464
|
+
|
|
465
|
+
pred_labels, pred_probas = self._predict(
|
|
466
|
+
model=model, latent_vectors=test_latent_vectors
|
|
467
|
+
)
|
|
468
|
+
|
|
469
|
+
eval_mask = X_val != -1
|
|
470
|
+
y_true_flat = X_val[eval_mask]
|
|
471
|
+
y_pred_flat = pred_labels[eval_mask]
|
|
472
|
+
y_proba_flat = pred_probas[eval_mask]
|
|
473
|
+
|
|
474
|
+
if y_true_flat.size == 0:
|
|
475
|
+
return {self.tune_metric: 0.0}
|
|
476
|
+
|
|
477
|
+
labels_for_scoring = [0, 1] if self.is_haploid else [0, 1, 2]
|
|
478
|
+
target_names = ["REF", "ALT"] if self.is_haploid else ["REF", "HET", "ALT"]
|
|
479
|
+
|
|
480
|
+
y_true_ohe = np.eye(len(labels_for_scoring))[y_true_flat]
|
|
481
|
+
|
|
482
|
+
metrics = self.scorers_.evaluate(
|
|
483
|
+
y_true_flat,
|
|
484
|
+
y_pred_flat,
|
|
485
|
+
y_true_ohe,
|
|
486
|
+
y_proba_flat,
|
|
487
|
+
objective_mode,
|
|
488
|
+
self.tune_metric,
|
|
489
|
+
)
|
|
490
|
+
|
|
491
|
+
if not objective_mode:
|
|
492
|
+
self.logger.info(f"Validation Metrics (0/1/2): {metrics}")
|
|
493
|
+
|
|
494
|
+
self._make_class_reports(
|
|
495
|
+
y_true=y_true_flat,
|
|
496
|
+
y_pred_proba=y_proba_flat,
|
|
497
|
+
y_pred=y_pred_flat,
|
|
498
|
+
metrics=metrics,
|
|
499
|
+
labels=target_names,
|
|
500
|
+
)
|
|
501
|
+
|
|
502
|
+
# IUPAC / 10-base auxiliary reports
|
|
503
|
+
y_true_dec = self.pgenc.decode_012(X_val)
|
|
504
|
+
X_pred = X_val.copy()
|
|
505
|
+
X_pred[eval_mask] = y_pred_flat
|
|
506
|
+
|
|
507
|
+
nF_eval = X_val.shape[1]
|
|
508
|
+
y_pred_dec = self.pgenc.decode_012(X_pred.reshape(X_val.shape[0], nF_eval))
|
|
509
|
+
|
|
510
|
+
encodings_dict = {
|
|
511
|
+
"A": 0,
|
|
512
|
+
"C": 1,
|
|
513
|
+
"G": 2,
|
|
514
|
+
"T": 3,
|
|
515
|
+
"W": 4,
|
|
516
|
+
"R": 5,
|
|
517
|
+
"M": 6,
|
|
518
|
+
"K": 7,
|
|
519
|
+
"Y": 8,
|
|
520
|
+
"S": 9,
|
|
521
|
+
"N": -1,
|
|
522
|
+
}
|
|
523
|
+
y_true_int = self.pgenc.convert_int_iupac(
|
|
524
|
+
y_true_dec, encodings_dict=encodings_dict
|
|
525
|
+
)
|
|
526
|
+
y_pred_int = self.pgenc.convert_int_iupac(
|
|
527
|
+
y_pred_dec, encodings_dict=encodings_dict
|
|
528
|
+
)
|
|
529
|
+
|
|
530
|
+
self._make_class_reports(
|
|
531
|
+
y_true=y_true_int[eval_mask],
|
|
532
|
+
y_pred=y_pred_int[eval_mask],
|
|
533
|
+
metrics=metrics,
|
|
534
|
+
y_pred_proba=None,
|
|
535
|
+
labels=["A", "C", "G", "T", "W", "R", "M", "K", "Y", "S"],
|
|
536
|
+
)
|
|
537
|
+
|
|
538
|
+
return metrics
|
|
539
|
+
|
|
540
|
+
def _get_data_loaders(self, y: np.ndarray) -> torch.utils.data.DataLoader:
|
|
541
|
+
"""Create DataLoader over indices + 0/1/2 target matrix.
|
|
542
|
+
|
|
543
|
+
This method creates a PyTorch DataLoader for the given genotype matrix, which contains 0/1/2 encodings with -1 for missing values. The DataLoader is constructed to yield batches of data during training, where each batch consists of indices and the corresponding genotype values. The genotype matrix is converted to a PyTorch tensor and moved to the appropriate device (CPU or GPU) before being wrapped in a TensorDataset. The DataLoader is configured to shuffle the data and use the specified batch size.
|
|
544
|
+
|
|
545
|
+
Args:
|
|
546
|
+
y (np.ndarray): (n_samples x L) int matrix with -1 missing.
|
|
547
|
+
|
|
548
|
+
Returns:
|
|
549
|
+
torch.utils.data.DataLoader: Shuffled mini-batches.
|
|
550
|
+
"""
|
|
551
|
+
y_tensor = torch.from_numpy(y).long().to(self.device)
|
|
552
|
+
dataset = torch.utils.data.TensorDataset(
|
|
553
|
+
torch.arange(len(y), device=self.device), y_tensor.to(self.device)
|
|
554
|
+
)
|
|
555
|
+
return torch.utils.data.DataLoader(
|
|
556
|
+
dataset, batch_size=self.batch_size, shuffle=True
|
|
557
|
+
)
|
|
558
|
+
|
|
559
|
+
def _objective(self, trial: optuna.Trial) -> float:
|
|
560
|
+
"""Optuna objective using the UBP training loop."""
|
|
561
|
+
try:
|
|
562
|
+
params = self._sample_hyperparameters(trial)
|
|
563
|
+
|
|
564
|
+
X_train_trial = self.ground_truth_[self.train_idx_]
|
|
565
|
+
X_test_trial = self.ground_truth_[self.test_idx_]
|
|
566
|
+
|
|
567
|
+
class_weights = self._class_weights_from_zygosity(X_train_trial)
|
|
568
|
+
train_loader = self._get_data_loaders(X_train_trial)
|
|
569
|
+
|
|
570
|
+
train_latent_vectors = self._create_latent_space(
|
|
571
|
+
params, len(X_train_trial), X_train_trial, params["latent_init"]
|
|
572
|
+
)
|
|
573
|
+
|
|
574
|
+
model = self.build_model(self.Model, params["model_params"])
|
|
575
|
+
model.n_features = params["model_params"]["n_features"]
|
|
576
|
+
model.apply(self.initialize_weights)
|
|
577
|
+
|
|
578
|
+
_, model, _ = self._train_and_validate_model(
|
|
579
|
+
model=model,
|
|
580
|
+
loader=train_loader,
|
|
581
|
+
lr=params["lr"],
|
|
582
|
+
l1_penalty=params["l1_penalty"],
|
|
583
|
+
trial=trial,
|
|
584
|
+
return_history=False,
|
|
585
|
+
latent_vectors=train_latent_vectors,
|
|
586
|
+
lr_input_factor=params["lr_input_factor"],
|
|
587
|
+
class_weights=class_weights,
|
|
588
|
+
X_val=X_test_trial,
|
|
589
|
+
params=params,
|
|
590
|
+
prune_metric=self.tune_metric,
|
|
591
|
+
prune_warmup_epochs=5,
|
|
592
|
+
eval_interval=1,
|
|
593
|
+
eval_requires_latents=True,
|
|
594
|
+
eval_latent_steps=50,
|
|
595
|
+
eval_latent_lr=params["lr"] * params["lr_input_factor"],
|
|
596
|
+
eval_latent_weight_decay=0.0,
|
|
597
|
+
)
|
|
598
|
+
|
|
599
|
+
metrics = self._evaluate_model(
|
|
600
|
+
X_test_trial, model, params, objective_mode=True
|
|
601
|
+
)
|
|
602
|
+
self._clear_resources(
|
|
603
|
+
model, train_loader, latent_vectors=train_latent_vectors
|
|
604
|
+
)
|
|
605
|
+
return metrics[self.tune_metric]
|
|
606
|
+
except Exception as e:
|
|
607
|
+
raise optuna.exceptions.TrialPruned(f"Trial failed with error: {e}")
|
|
608
|
+
|
|
609
|
+
def _sample_hyperparameters(
|
|
610
|
+
self, trial: optuna.Trial
|
|
611
|
+
) -> Dict[str, int | float | str | list]:
|
|
612
|
+
"""Sample UBP hyperparameters; compute hidden sizes for model_params.
|
|
613
|
+
|
|
614
|
+
This method samples a set of hyperparameters for the UBP model using the provided Optuna trial object. It defines a search space for various hyperparameters, including latent dimension, learning rate, dropout rate, number of hidden layers, activation function, and others. After sampling the hyperparameters, it computes the sizes of the hidden layers based on the sampled values and constructs the model parameters dictionary. The method returns a dictionary containing all sampled hyperparameters along with the computed model parameters.
|
|
615
|
+
|
|
616
|
+
Args:
|
|
617
|
+
trial (optuna.Trial): Current trial.
|
|
618
|
+
|
|
619
|
+
Returns:
|
|
620
|
+
Dict[str, int | float | str | list]: Sampled hyperparameters.
|
|
621
|
+
"""
|
|
622
|
+
params = {
|
|
623
|
+
"latent_dim": trial.suggest_int("latent_dim", 2, 32),
|
|
624
|
+
"lr": trial.suggest_float("learning_rate", 1e-5, 1e-2, log=True),
|
|
625
|
+
"dropout_rate": trial.suggest_float("dropout_rate", 0.0, 0.6),
|
|
626
|
+
"num_hidden_layers": trial.suggest_int("num_hidden_layers", 1, 8),
|
|
627
|
+
"activation": trial.suggest_categorical(
|
|
628
|
+
"activation", ["relu", "elu", "selu"]
|
|
629
|
+
),
|
|
630
|
+
"gamma": trial.suggest_float("gamma", 0.0, 5.0),
|
|
631
|
+
"lr_input_factor": trial.suggest_float(
|
|
632
|
+
"lr_input_factor", 0.1, 10.0, log=True
|
|
633
|
+
),
|
|
634
|
+
"l1_penalty": trial.suggest_float("l1_penalty", 1e-7, 1e-2, log=True),
|
|
635
|
+
"layer_scaling_factor": trial.suggest_float(
|
|
636
|
+
"layer_scaling_factor", 2.0, 10.0
|
|
637
|
+
),
|
|
638
|
+
"layer_schedule": trial.suggest_categorical(
|
|
639
|
+
"layer_schedule", ["pyramid", "constant", "linear"]
|
|
640
|
+
),
|
|
641
|
+
"latent_init": trial.suggest_categorical("latent_init", ["random", "pca"]),
|
|
642
|
+
}
|
|
643
|
+
|
|
644
|
+
hidden_layer_sizes = self._compute_hidden_layer_sizes(
|
|
645
|
+
n_inputs=params["latent_dim"],
|
|
646
|
+
n_outputs=self.num_features_ * self.num_classes_,
|
|
647
|
+
n_samples=len(self.train_idx_),
|
|
648
|
+
n_hidden=params["num_hidden_layers"],
|
|
649
|
+
alpha=params["layer_scaling_factor"],
|
|
650
|
+
schedule=params["layer_schedule"],
|
|
651
|
+
)
|
|
652
|
+
# Keep the latent_dim as the first element,
|
|
653
|
+
# then the interior hidden widths.
|
|
654
|
+
# If there are no interior widths (very small nets),
|
|
655
|
+
# this still leaves [latent_dim].
|
|
656
|
+
hidden_only = [hidden_layer_sizes[0]] + hidden_layer_sizes[1:-1]
|
|
657
|
+
|
|
658
|
+
params["model_params"] = {
|
|
659
|
+
"n_features": self.num_features_,
|
|
660
|
+
"num_classes": self.num_classes_,
|
|
661
|
+
"latent_dim": params["latent_dim"],
|
|
662
|
+
"dropout_rate": params["dropout_rate"],
|
|
663
|
+
"hidden_layer_sizes": hidden_only,
|
|
664
|
+
"activation": params["activation"],
|
|
665
|
+
}
|
|
666
|
+
|
|
667
|
+
return params
|
|
668
|
+
|
|
669
|
+
def _set_best_params(
|
|
670
|
+
self, best_params: Dict[str, int | float | str | list]
|
|
671
|
+
) -> Dict[str, int | float | str | list]:
|
|
672
|
+
"""Set best params onto instance; return model_params payload.
|
|
673
|
+
|
|
674
|
+
This method sets the best hyperparameters found during tuning onto the instance attributes of the ImputeUBP class. It extracts the relevant hyperparameters from the provided dictionary and updates the corresponding instance variables. Additionally, it computes the sizes of the hidden layers based on the best hyperparameters and constructs the model parameters dictionary. The method returns a dictionary containing the model parameters that can be used to build the UBP model.
|
|
675
|
+
|
|
676
|
+
Args:
|
|
677
|
+
best_params (Dict[str, int | float | str | list]): Best hyperparameters.
|
|
678
|
+
|
|
679
|
+
Returns:
|
|
680
|
+
Dict[str, int | float | str | list]: model_params payload.
|
|
681
|
+
|
|
682
|
+
Raises:
|
|
683
|
+
ValueError: If best_params is missing required keys.
|
|
684
|
+
"""
|
|
685
|
+
self.latent_dim = best_params["latent_dim"]
|
|
686
|
+
self.dropout_rate = best_params["dropout_rate"]
|
|
687
|
+
self.learning_rate = best_params["learning_rate"]
|
|
688
|
+
self.gamma = best_params["gamma"]
|
|
689
|
+
self.lr_input_factor = best_params["lr_input_factor"]
|
|
690
|
+
self.l1_penalty = best_params["l1_penalty"]
|
|
691
|
+
self.activation = best_params["activation"]
|
|
692
|
+
self.latent_init = best_params["latent_init"]
|
|
693
|
+
|
|
694
|
+
hidden_layer_sizes = self._compute_hidden_layer_sizes(
|
|
695
|
+
n_inputs=self.latent_dim,
|
|
696
|
+
n_outputs=self.num_features_ * self.num_classes_,
|
|
697
|
+
n_samples=len(self.train_idx_),
|
|
698
|
+
n_hidden=best_params["num_hidden_layers"],
|
|
699
|
+
alpha=best_params["layer_scaling_factor"],
|
|
700
|
+
schedule=best_params["layer_schedule"],
|
|
701
|
+
)
|
|
702
|
+
|
|
703
|
+
hidden_only = [hidden_layer_sizes[0]] + hidden_layer_sizes[1:-1]
|
|
704
|
+
|
|
705
|
+
return {
|
|
706
|
+
"n_features": self.num_features_,
|
|
707
|
+
"latent_dim": self.latent_dim,
|
|
708
|
+
"hidden_layer_sizes": hidden_only,
|
|
709
|
+
"dropout_rate": self.dropout_rate,
|
|
710
|
+
"activation": self.activation,
|
|
711
|
+
"gamma": self.gamma,
|
|
712
|
+
"num_classes": self.num_classes_,
|
|
713
|
+
}
|
|
714
|
+
|
|
715
|
+
def _set_best_params_default(self) -> Dict[str, int | float | str | list]:
|
|
716
|
+
"""Default (no-tuning) model_params aligned with current attributes.
|
|
717
|
+
|
|
718
|
+
This method constructs the model parameters dictionary using the current instance attributes of the ImputeUBP class. It computes the sizes of the hidden layers based on the instance's latent dimension, dropout rate, learning rate, and other relevant attributes. The method returns a dictionary containing the model parameters that can be used to build the UBP model when no hyperparameter tuning has been performed.
|
|
719
|
+
|
|
720
|
+
Returns:
|
|
721
|
+
Dict[str, int | float | str | list]: model_params payload.
|
|
722
|
+
"""
|
|
723
|
+
hidden_layer_sizes = self._compute_hidden_layer_sizes(
|
|
724
|
+
n_inputs=self.latent_dim,
|
|
725
|
+
n_outputs=self.num_features_ * self.num_classes_,
|
|
726
|
+
n_samples=len(self.ground_truth_),
|
|
727
|
+
n_hidden=self.num_hidden_layers,
|
|
728
|
+
alpha=self.layer_scaling_factor,
|
|
729
|
+
schedule=self.layer_schedule,
|
|
730
|
+
)
|
|
731
|
+
|
|
732
|
+
hidden_only = [hidden_layer_sizes[0]] + hidden_layer_sizes[1:-1]
|
|
733
|
+
|
|
734
|
+
return {
|
|
735
|
+
"n_features": self.num_features_,
|
|
736
|
+
"latent_dim": self.latent_dim,
|
|
737
|
+
"hidden_layer_sizes": hidden_only,
|
|
738
|
+
"dropout_rate": self.dropout_rate,
|
|
739
|
+
"activation": self.activation,
|
|
740
|
+
"gamma": self.gamma,
|
|
741
|
+
"num_classes": self.num_classes_,
|
|
742
|
+
}
|
|
743
|
+
|
|
744
|
+
def _train_and_validate_model(
|
|
745
|
+
self,
|
|
746
|
+
model: torch.nn.Module,
|
|
747
|
+
loader: torch.utils.data.DataLoader,
|
|
748
|
+
lr: float,
|
|
749
|
+
l1_penalty: float,
|
|
750
|
+
trial: optuna.Trial | None = None,
|
|
751
|
+
return_history: bool = False,
|
|
752
|
+
latent_vectors: torch.nn.Parameter | None = None,
|
|
753
|
+
lr_input_factor: float = 1.0,
|
|
754
|
+
class_weights: torch.Tensor | None = None,
|
|
755
|
+
*,
|
|
756
|
+
X_val: np.ndarray | None = None,
|
|
757
|
+
params: dict | None = None,
|
|
758
|
+
prune_metric: str | None = None, # "f1" | "accuracy" | "pr_macro"
|
|
759
|
+
prune_warmup_epochs: int = 3,
|
|
760
|
+
eval_interval: int = 1,
|
|
761
|
+
eval_requires_latents: bool = True, # UBP needs latent eval
|
|
762
|
+
eval_latent_steps: int = 50,
|
|
763
|
+
eval_latent_lr: float = 1e-2,
|
|
764
|
+
eval_latent_weight_decay: float = 0.0,
|
|
765
|
+
) -> Tuple[float, torch.nn.Module | None, dict, torch.nn.Parameter | None]:
|
|
766
|
+
"""Train & validate UBP model with three-phase loop.
|
|
767
|
+
|
|
768
|
+
This method trains and validates the UBP model using a three-phase training loop. It sets up the latent optimizer and invokes the training loop, which includes pre-training, fine-tuning, and joint training phases. The method ensures that the necessary latent vectors and class weights are provided before proceeding with training. It also incorporates new parameters for evaluation and pruning during training. The final best loss, best model, training history, and optimized latent vectors are returned.
|
|
769
|
+
|
|
770
|
+
Args:
|
|
771
|
+
model (torch.nn.Module): UBP model with phase1_decoder & phase23_decoder.
|
|
772
|
+
loader (torch.utils.data.DataLoader): DataLoader for training data.
|
|
773
|
+
lr (float): Learning rate for decoder.
|
|
774
|
+
l1_penalty (float): L1 regularization weight.
|
|
775
|
+
trial (optuna.Trial | None): Current trial or None.
|
|
776
|
+
return_history (bool): If True, return loss history.
|
|
777
|
+
latent_vectors (torch.nn.Parameter | None): Trainable Z.
|
|
778
|
+
lr_input_factor (float): LR factor for latents.
|
|
779
|
+
class_weights (torch.Tensor | None): Class weights for 0/1/2.
|
|
780
|
+
X_val (np.ndarray | None): Validation set for pruning/eval.
|
|
781
|
+
params (dict | None): Model params for eval.
|
|
782
|
+
prune_metric (str | None): Metric to monitor for pruning.
|
|
783
|
+
prune_warmup_epochs (int): Epochs before pruning starts.
|
|
784
|
+
eval_interval (int): Epochs between evaluations.
|
|
785
|
+
eval_requires_latents (bool): If True, optimize latents for eval.
|
|
786
|
+
eval_latent_steps (int): Latent optimization steps for eval.
|
|
787
|
+
eval_latent_lr (float): Latent optimization LR for eval.
|
|
788
|
+
eval_latent_weight_decay (float): Latent optimization weight decay for eval.
|
|
789
|
+
|
|
790
|
+
Returns:
|
|
791
|
+
Tuple[float, torch.nn.Module, dict, torch.nn.Parameter]: (best_loss, best_model, history, latents).
|
|
792
|
+
|
|
793
|
+
Raises:
|
|
794
|
+
TypeError: If latent_vectors or class_weights are
|
|
795
|
+
not provided.
|
|
796
|
+
ValueError: If X_val is not provided for evaluation.
|
|
797
|
+
RuntimeError: If eval_latent_steps is not positive.
|
|
798
|
+
"""
|
|
799
|
+
if latent_vectors is None or class_weights is None:
|
|
800
|
+
msg = "Must provide latent_vectors and class_weights."
|
|
801
|
+
self.logger.error(msg)
|
|
802
|
+
raise TypeError(msg)
|
|
803
|
+
|
|
804
|
+
latent_optimizer = torch.optim.Adam([latent_vectors], lr=lr * lr_input_factor)
|
|
805
|
+
|
|
806
|
+
result = self._execute_training_loop(
|
|
807
|
+
loader=loader,
|
|
808
|
+
latent_optimizer=latent_optimizer,
|
|
809
|
+
lr=lr,
|
|
810
|
+
model=model,
|
|
811
|
+
l1_penalty=l1_penalty,
|
|
812
|
+
trial=trial,
|
|
813
|
+
return_history=return_history,
|
|
814
|
+
latent_vectors=latent_vectors,
|
|
815
|
+
class_weights=class_weights,
|
|
816
|
+
# NEW ↓↓↓
|
|
817
|
+
X_val=X_val,
|
|
818
|
+
params=params,
|
|
819
|
+
prune_metric=prune_metric,
|
|
820
|
+
prune_warmup_epochs=prune_warmup_epochs,
|
|
821
|
+
eval_interval=eval_interval,
|
|
822
|
+
eval_requires_latents=eval_requires_latents,
|
|
823
|
+
eval_latent_steps=eval_latent_steps,
|
|
824
|
+
eval_latent_lr=eval_latent_lr,
|
|
825
|
+
eval_latent_weight_decay=eval_latent_weight_decay,
|
|
826
|
+
)
|
|
827
|
+
|
|
828
|
+
if return_history:
|
|
829
|
+
return result
|
|
830
|
+
|
|
831
|
+
return result[0], result[1], result[3]
|
|
832
|
+
|
|
833
|
+
def _train_final_model(
|
|
834
|
+
self,
|
|
835
|
+
loader: torch.utils.data.DataLoader,
|
|
836
|
+
best_params: Dict[str, int | float | str | list],
|
|
837
|
+
initial_latent_vectors: torch.nn.Parameter,
|
|
838
|
+
) -> Tuple[float, torch.nn.Module, dict, torch.nn.Parameter]:
|
|
839
|
+
"""Train final UBP model with best params; save weights to disk.
|
|
840
|
+
|
|
841
|
+
This method trains the final UBP model using the best hyperparameters found during tuning. It builds the model with the specified parameters, initializes the weights, and invokes the training and validation process. The method saves the trained model's state dictionary to disk and returns the final loss, trained model, training history, and optimized latent vectors.
|
|
842
|
+
|
|
843
|
+
Args:
|
|
844
|
+
loader (torch.utils.data.DataLoader): DataLoader for training data.
|
|
845
|
+
best_params (Dict[str, int | float | str | list]): Best hyperparameters.
|
|
846
|
+
initial_latent_vectors (torch.nn.Parameter): Initialized latent vectors.
|
|
847
|
+
|
|
848
|
+
Returns:
|
|
849
|
+
Tuple[float, torch.nn.Module, dict, torch.nn.Parameter]: (loss, model, {"Train": history}, latents).
|
|
850
|
+
"""
|
|
851
|
+
self.logger.info("Training the final UBP (0/1/2) model...")
|
|
852
|
+
|
|
853
|
+
model = self.build_model(self.Model, best_params)
|
|
854
|
+
model.n_features = best_params["n_features"]
|
|
855
|
+
model.apply(self.initialize_weights)
|
|
856
|
+
|
|
857
|
+
loss, trained_model, history, latent_vectors = self._train_and_validate_model(
|
|
858
|
+
model=model,
|
|
859
|
+
loader=loader,
|
|
860
|
+
lr=self.learning_rate,
|
|
861
|
+
l1_penalty=self.l1_penalty,
|
|
862
|
+
return_history=True,
|
|
863
|
+
latent_vectors=initial_latent_vectors,
|
|
864
|
+
lr_input_factor=self.lr_input_factor,
|
|
865
|
+
class_weights=self.class_weights_,
|
|
866
|
+
X_val=self.X_test_,
|
|
867
|
+
params=best_params,
|
|
868
|
+
prune_metric=self.tune_metric,
|
|
869
|
+
prune_warmup_epochs=5,
|
|
870
|
+
eval_interval=1,
|
|
871
|
+
eval_requires_latents=True,
|
|
872
|
+
eval_latent_steps=50,
|
|
873
|
+
eval_latent_lr=self.learning_rate * self.lr_input_factor,
|
|
874
|
+
eval_latent_weight_decay=0.0,
|
|
875
|
+
)
|
|
876
|
+
|
|
877
|
+
if trained_model is None:
|
|
878
|
+
msg = "Final model training failed."
|
|
879
|
+
self.logger.error(msg)
|
|
880
|
+
raise RuntimeError(msg)
|
|
881
|
+
|
|
882
|
+
fout = self.models_dir / "final_model.pt"
|
|
883
|
+
torch.save(trained_model.state_dict(), fout)
|
|
884
|
+
return loss, trained_model, {"Train": history}, latent_vectors
|
|
885
|
+
|
|
886
|
+
def _execute_training_loop(
|
|
887
|
+
self,
|
|
888
|
+
loader: torch.utils.data.DataLoader,
|
|
889
|
+
latent_optimizer: torch.optim.Optimizer,
|
|
890
|
+
lr: float,
|
|
891
|
+
model: torch.nn.Module,
|
|
892
|
+
l1_penalty: float,
|
|
893
|
+
trial,
|
|
894
|
+
return_history: bool,
|
|
895
|
+
latent_vectors: torch.nn.Parameter,
|
|
896
|
+
class_weights: torch.Tensor,
|
|
897
|
+
*,
|
|
898
|
+
X_val: np.ndarray | None = None,
|
|
899
|
+
params: dict | None = None,
|
|
900
|
+
prune_metric: str | None = None,
|
|
901
|
+
prune_warmup_epochs: int = 3,
|
|
902
|
+
eval_interval: int = 1,
|
|
903
|
+
eval_requires_latents: bool = True,
|
|
904
|
+
eval_latent_steps: int = 50,
|
|
905
|
+
eval_latent_lr: float = 1e-2,
|
|
906
|
+
eval_latent_weight_decay: float = 0.0,
|
|
907
|
+
) -> Tuple[float, torch.nn.Module, dict, torch.nn.Parameter]:
|
|
908
|
+
"""Three-phase UBP loop with cosine LR, gamma warmup, and pruning hook.
|
|
909
|
+
|
|
910
|
+
This method executes the three-phase training loop for the UBP model, which includes pre-training, fine-tuning, and joint training phases. It incorporates a cosine annealing learning rate scheduler, focal loss gamma warmup, and an early stopping mechanism. The method also includes a pruning hook for Optuna trials, allowing for early termination of unpromising trials based on validation performance. The final best loss, best model, training history, and optimized latent vectors are returned.
|
|
911
|
+
|
|
912
|
+
Args:
|
|
913
|
+
loader (torch.utils.data.DataLoader): DataLoader for training data.
|
|
914
|
+
latent_optimizer (torch.optim.Optimizer): Latent optimizer.
|
|
915
|
+
lr (float): Learning rate for decoder.
|
|
916
|
+
model (torch.nn.Module): UBP model with phase1_decoder & phase23_decoder.
|
|
917
|
+
l1_penalty (float): L1 regularization weight.
|
|
918
|
+
trial: Current trial or None.
|
|
919
|
+
return_history (bool): If True, return loss history.
|
|
920
|
+
latent_vectors (torch.nn.Parameter): Trainable Z.
|
|
921
|
+
class_weights (torch.Tensor): Class weights for 0/1/2.
|
|
922
|
+
X_val (np.ndarray | None): Validation set for pruning/eval.
|
|
923
|
+
params (dict | None): Model params for eval.
|
|
924
|
+
prune_metric (str | None): Metric to monitor for pruning.
|
|
925
|
+
prune_warmup_epochs (int): Epochs before pruning starts.
|
|
926
|
+
eval_interval (int): Epochs between evaluations.
|
|
927
|
+
eval_requires_latents (bool): If True, optimize latents for eval.
|
|
928
|
+
eval_latent_steps (int): Latent optimization steps for eval.
|
|
929
|
+
eval_latent_lr (float): Latent optimization LR for eval.
|
|
930
|
+
eval_latent_weight_decay (float): Latent optimization weight decay for eval.
|
|
931
|
+
|
|
932
|
+
Returns:
|
|
933
|
+
Tuple[float, torch.nn.Module, dict, torch.nn.Parameter]: (best_loss, best_model, history, latents).
|
|
934
|
+
|
|
935
|
+
Raises:
|
|
936
|
+
TypeError: If X_val is not provided for evaluation.
|
|
937
|
+
ValueError: If eval_latent_steps is not positive.
|
|
938
|
+
"""
|
|
939
|
+
history: dict[str, list[float]] = {}
|
|
940
|
+
final_best_loss = float("inf")
|
|
941
|
+
final_best_model = None
|
|
942
|
+
|
|
943
|
+
# Schema-aware latent cache for eval
|
|
944
|
+
_latent_cache: dict = {}
|
|
945
|
+
nF = getattr(model, "n_features", self.num_features_)
|
|
946
|
+
cache_key_root = f"{self.prefix}_ubp_val_latents_L{nF}_K{self.num_classes_}"
|
|
947
|
+
|
|
948
|
+
# Epoch budget; if you later add tune_fast behavior to UBP, wire it here
|
|
949
|
+
max_epochs = self.epochs
|
|
950
|
+
warm, ramp, gamma_final = 50, 100, self.gamma
|
|
951
|
+
|
|
952
|
+
for phase in (1, 2, 3):
|
|
953
|
+
early_stopping = EarlyStopping(
|
|
954
|
+
patience=self.early_stop_gen,
|
|
955
|
+
min_epochs=self.min_epochs,
|
|
956
|
+
verbose=self.verbose,
|
|
957
|
+
prefix=self.prefix,
|
|
958
|
+
debug=self.debug,
|
|
959
|
+
)
|
|
960
|
+
|
|
961
|
+
if phase == 2:
|
|
962
|
+
self._reset_weights(model)
|
|
963
|
+
|
|
964
|
+
decoder_params = (
|
|
965
|
+
model.phase1_decoder.parameters()
|
|
966
|
+
if phase == 1
|
|
967
|
+
else model.phase23_decoder.parameters()
|
|
968
|
+
)
|
|
969
|
+
optimizer = torch.optim.Adam(decoder_params, lr=lr)
|
|
970
|
+
scheduler = CosineAnnealingLR(optimizer, T_max=max_epochs)
|
|
971
|
+
|
|
972
|
+
phase_hist: list[float] = []
|
|
973
|
+
|
|
974
|
+
for epoch in range(max_epochs):
|
|
975
|
+
# Focal gamma warmup
|
|
976
|
+
if epoch < warm:
|
|
977
|
+
model.gamma = 0.0
|
|
978
|
+
elif epoch < warm + ramp:
|
|
979
|
+
model.gamma = gamma_final * ((epoch - warm) / ramp)
|
|
980
|
+
else:
|
|
981
|
+
model.gamma = gamma_final
|
|
982
|
+
|
|
983
|
+
train_loss, latent_vectors = self._train_step(
|
|
984
|
+
loader=loader,
|
|
985
|
+
optimizer=optimizer,
|
|
986
|
+
latent_optimizer=latent_optimizer,
|
|
987
|
+
model=model,
|
|
988
|
+
l1_penalty=l1_penalty,
|
|
989
|
+
latent_vectors=latent_vectors,
|
|
990
|
+
class_weights=class_weights,
|
|
991
|
+
phase=phase,
|
|
992
|
+
)
|
|
993
|
+
|
|
994
|
+
if trial and (np.isnan(train_loss) or np.isinf(train_loss)):
|
|
995
|
+
raise optuna.exceptions.TrialPruned("Loss is NaN or Inf.")
|
|
996
|
+
|
|
997
|
+
scheduler.step()
|
|
998
|
+
if return_history:
|
|
999
|
+
phase_hist.append(train_loss)
|
|
1000
|
+
|
|
1001
|
+
early_stopping(train_loss, model)
|
|
1002
|
+
if early_stopping.early_stop:
|
|
1003
|
+
self.logger.info(
|
|
1004
|
+
f"Early stopping at epoch {epoch + 1} (phase {phase})."
|
|
1005
|
+
)
|
|
1006
|
+
break
|
|
1007
|
+
|
|
1008
|
+
# Validation pruning hook
|
|
1009
|
+
if (
|
|
1010
|
+
trial is not None
|
|
1011
|
+
and X_val is not None
|
|
1012
|
+
and ((epoch + 1) % eval_interval == 0)
|
|
1013
|
+
):
|
|
1014
|
+
metric_key = prune_metric or getattr(self, "tune_metric", "f1")
|
|
1015
|
+
z = self._first_linear_in_features(model)
|
|
1016
|
+
schema_key = f"{cache_key_root}_z{z}"
|
|
1017
|
+
|
|
1018
|
+
metric_val = self._eval_for_pruning(
|
|
1019
|
+
model=model,
|
|
1020
|
+
X_val=X_val,
|
|
1021
|
+
params=params or getattr(self, "best_params_", {}),
|
|
1022
|
+
metric=metric_key,
|
|
1023
|
+
objective_mode=True,
|
|
1024
|
+
do_latent_infer=eval_requires_latents,
|
|
1025
|
+
latent_steps=eval_latent_steps,
|
|
1026
|
+
latent_lr=eval_latent_lr,
|
|
1027
|
+
latent_weight_decay=eval_latent_weight_decay,
|
|
1028
|
+
latent_seed=(self.seed if self.seed is not None else 123),
|
|
1029
|
+
_latent_cache=_latent_cache,
|
|
1030
|
+
_latent_cache_key=schema_key,
|
|
1031
|
+
)
|
|
1032
|
+
|
|
1033
|
+
with warnings.catch_warnings():
|
|
1034
|
+
warnings.simplefilter("ignore", category=UserWarning)
|
|
1035
|
+
trial.report(metric_val, step=epoch + 1)
|
|
1036
|
+
|
|
1037
|
+
if (epoch + 1) >= prune_warmup_epochs and trial.should_prune():
|
|
1038
|
+
raise optuna.exceptions.TrialPruned(
|
|
1039
|
+
f"Pruned at epoch {epoch + 1} (phase {phase}): "
|
|
1040
|
+
f"{metric_key}={metric_val:.5f}"
|
|
1041
|
+
)
|
|
1042
|
+
|
|
1043
|
+
history[f"Phase {phase}"] = phase_hist
|
|
1044
|
+
final_best_loss = early_stopping.best_score
|
|
1045
|
+
final_best_model = copy.deepcopy(early_stopping.best_model)
|
|
1046
|
+
|
|
1047
|
+
return final_best_loss, final_best_model, history, latent_vectors
|
|
1048
|
+
|
|
1049
|
+
    def _optimize_latents_for_inference(
        self,
        X_new: np.ndarray,
        model: torch.nn.Module,
        params: dict,
        inference_epochs: int = 200,
    ) -> torch.Tensor:
        """Optimize latent vectors for new 0/1/2 data by minimizing masked CE.

        This method optimizes latent vectors for a given genotype matrix using a trained UBP model. It initializes the latent vectors based on the specified strategy (random or PCA) and then refines them through gradient-based optimization to minimize the cross-entropy loss between the model's predictions and the provided genotype data. The optimization process is performed for a specified number of epochs, and the resulting optimized latent vectors are returned.

        Args:
            X_new (np.ndarray): 0/1/2 with -1 for missing.
            model (torch.nn.Module): Trained model.
            params (dict): Should include 'latent_dim'.
            inference_epochs (int): Steps for optimization.

        Returns:
            torch.Tensor: Optimized latent vectors.
        """
        model.eval()

        nF = getattr(model, "n_features", self.num_features_)

        X_new = X_new.astype(np.int64, copy=False)
        X_new[X_new < 0] = -1

        # Allow shorter inference when tune_fast is enabled, mirroring NLPCA
        if self.tune and self.tune_fast:
            inference_epochs = min(
                inference_epochs, getattr(self, "tune_infer_epochs", 20)
            )

        new_latent_vectors = self._create_latent_space(
            params, len(X_new), X_new, self.latent_init
        )
        opt = torch.optim.Adam(
            [new_latent_vectors], lr=self.learning_rate * self.lr_input_factor
        )
        y_target = torch.from_numpy(X_new).long().to(self.device)

        for _ in range(inference_epochs):
            opt.zero_grad(set_to_none=True)
            logits = model.phase23_decoder(new_latent_vectors).view(
                len(X_new), nF, self.num_classes_
            )
            loss = F.cross_entropy(
                logits.view(-1, self.num_classes_), y_target.view(-1), ignore_index=-1
            )
            if torch.isnan(loss) or torch.isinf(loss):
                self.logger.warning(
                    "Inference loss is NaN/Inf; stopping latent refinement."
                )
                break
            loss.backward()
            opt.step()

        return new_latent_vectors.detach()

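At its core, `_optimize_latents_for_inference` is gradient descent on the latent codes with the decoder held fixed, using `ignore_index=-1` so missing genotypes never contribute to the loss. A stripped-down sketch of that idea with a toy decoder (shapes, names, and the decoder itself are illustrative, not the pg-sui model):

```python
import numpy as np
import torch
import torch.nn.functional as F

n_samples, n_loci, n_classes, latent_dim = 8, 20, 3, 4

# Toy frozen "decoder": latent vector -> per-locus class logits.
decoder = torch.nn.Linear(latent_dim, n_loci * n_classes)
for p in decoder.parameters():
    p.requires_grad_(False)

# Genotypes coded 0/1/2 with -1 for missing, as in pg-sui.
X_new = np.random.default_rng(0).integers(-1, 3, size=(n_samples, n_loci))
y_target = torch.from_numpy(X_new).long()

# Only the latent codes are trainable.
latents = torch.nn.Parameter(torch.randn(n_samples, latent_dim))
opt = torch.optim.Adam([latents], lr=0.05)

for _ in range(200):
    opt.zero_grad(set_to_none=True)
    logits = decoder(latents).view(n_samples, n_loci, n_classes)
    # ignore_index=-1 masks missing genotypes out of the loss.
    loss = F.cross_entropy(logits.view(-1, n_classes), y_target.view(-1), ignore_index=-1)
    loss.backward()
    opt.step()

# Most likely genotype per sample/locus under the refined latents.
imputed = decoder(latents).view(n_samples, n_loci, n_classes).argmax(dim=-1)
```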
    def _create_latent_space(
        self,
        params: dict,
        n_samples: int,
        X: np.ndarray,
        latent_init: Literal["random", "pca"],
    ) -> torch.nn.Parameter:
        """Initialize latent space via random Xavier or PCA on 0/1/2 matrix.

        This method initializes the latent space for the UBP model using either random Xavier initialization or PCA-based initialization. The choice of initialization strategy is determined by the latent_init parameter. If PCA is selected, the method handles missing values by imputing them with column means before performing PCA. The resulting latent vectors are standardized and converted to a PyTorch parameter that can be optimized during training.

        Args:
            params (dict): Contains 'latent_dim'.
            n_samples (int): Number of samples.
            X (np.ndarray): (n_samples x L) 0/1/2 with -1 missing.
            latent_init (Literal["random","pca"]): Init strategy.

        Returns:
            torch.nn.Parameter: Trainable latent matrix.
        """
        latent_dim = int(params["latent_dim"])

        if latent_init == "pca":
            X_pca = X.astype(np.float32, copy=True)
            # mark missing
            X_pca[X_pca < 0] = np.nan

            # ---- SAFE column means without warnings ----
            valid_counts = np.sum(~np.isnan(X_pca), axis=0)
            col_sums = np.nansum(X_pca, axis=0)
            col_means = np.divide(
                col_sums,
                valid_counts,
                out=np.zeros_like(col_sums, dtype=np.float32),
                where=valid_counts > 0,
            )

            # impute NaNs with per-column means
            # (all-NaN cols -> 0.0 by the divide above)
            nan_r, nan_c = np.where(np.isnan(X_pca))
            if nan_r.size:
                X_pca[nan_r, nan_c] = col_means[nan_c]

            # center columns
            X_pca = X_pca - X_pca.mean(axis=0, keepdims=True)

            # guard: degenerate / all-zero after centering ->
            # fall back to random
            if (not np.isfinite(X_pca).all()) or np.allclose(X_pca, 0.0):
                latents = torch.empty(n_samples, latent_dim, device=self.device)
                torch.nn.init.xavier_uniform_(latents)
                return torch.nn.Parameter(latents, requires_grad=True)

            # rank-aware component count, at least 1
            try:
                est_rank = np.linalg.matrix_rank(X_pca)
            except Exception:
                est_rank = min(n_samples, X_pca.shape[1])

            n_components = max(1, min(latent_dim, est_rank, n_samples, X_pca.shape[1]))

            # use deterministic SVD to avoid power-iteration warnings
            pca = PCA(
                n_components=n_components, svd_solver="full", random_state=self.seed
            )
            initial = pca.fit_transform(X_pca)  # (n_samples, n_components)

            # pad if latent_dim > n_components
            if n_components < latent_dim:
                pad = self.rng.standard_normal(
                    size=(n_samples, latent_dim - n_components)
                )
                initial = np.hstack([initial, pad])

            # standardize latent dims
            initial = (initial - initial.mean(axis=0)) / (initial.std(axis=0) + 1e-6)

            latents = torch.from_numpy(initial).float().to(self.device)
            return torch.nn.Parameter(latents, requires_grad=True)

        else:
            latents = torch.empty(n_samples, latent_dim, device=self.device)
            torch.nn.init.xavier_uniform_(latents)

            return torch.nn.Parameter(latents, requires_grad=True)

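The PCA branch above boils down to: mean-impute missing entries per column, center, run a full-SVD PCA, pad with random dimensions when the requested latent size exceeds the usable component count, and standardize. A compact sketch of the same recipe outside the class; `pca_init_latents` and the toy data are illustrative, not pg-sui API:

```python
import numpy as np
import torch
from sklearn.decomposition import PCA


def pca_init_latents(X: np.ndarray, latent_dim: int, seed: int = 42) -> torch.nn.Parameter:
    """PCA-style latent init for a 0/1/2 genotype matrix with -1 marking missing."""
    rng = np.random.default_rng(seed)
    Xf = X.astype(np.float32, copy=True)
    Xf[Xf < 0] = np.nan

    # Safe per-column means: all-missing columns fall back to 0.0 without warnings.
    valid = np.sum(~np.isnan(Xf), axis=0)
    col_means = np.divide(
        np.nansum(Xf, axis=0), valid, out=np.zeros(Xf.shape[1], dtype=np.float32), where=valid > 0
    )
    nan_r, nan_c = np.where(np.isnan(Xf))
    Xf[nan_r, nan_c] = col_means[nan_c]

    Xf -= Xf.mean(axis=0, keepdims=True)  # center columns before PCA

    n_components = max(1, min(latent_dim, *Xf.shape))
    initial = PCA(n_components=n_components, svd_solver="full", random_state=seed).fit_transform(Xf)

    if n_components < latent_dim:  # pad extra latent dimensions with random noise
        initial = np.hstack([initial, rng.standard_normal((Xf.shape[0], latent_dim - n_components))])

    # Standardize each latent dimension, then hand back a trainable parameter.
    initial = (initial - initial.mean(axis=0)) / (initial.std(axis=0) + 1e-6)
    return torch.nn.Parameter(torch.from_numpy(initial).float(), requires_grad=True)


latents = pca_init_latents(np.random.default_rng(0).integers(-1, 3, size=(10, 30)), latent_dim=4)
```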
    def _reset_weights(self, model: torch.nn.Module) -> None:
        """Selectively resets only the weights of the phase 2/3 decoder.

        This method targets only the `phase23_decoder` attribute of the UBPModel, leaving the `phase1_decoder` and other potential model components untouched. This allows the model to be re-initialized for the second phase of training without affecting other parts.

        Args:
            model (torch.nn.Module): The PyTorch model whose parameters are to be reset.
        """
        if hasattr(model, "phase23_decoder"):
            # Iterate through only the modules of the second decoder
            for layer in model.phase23_decoder.modules():
                if hasattr(layer, "reset_parameters"):
                    layer.reset_parameters()
        else:
            self.logger.warning(
                "Model does not have a 'phase23_decoder' attribute; skipping weight reset."
            )

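`reset_parameters()` is the standard PyTorch hook for re-initializing a layer in place, which is why the selective reset above only needs to walk one submodule. A tiny sketch of that pattern on a stand-in two-decoder module (the class below is illustrative and only borrows the attribute names from the docstring, it is not the actual UBPModel):

```python
import torch


class TwoPhaseNet(torch.nn.Module):
    """Stand-in for a model with separate phase-1 and phase-2/3 decoders."""

    def __init__(self, latent_dim: int = 4, n_out: int = 12) -> None:
        super().__init__()
        self.phase1_decoder = torch.nn.Linear(latent_dim, n_out)
        self.phase23_decoder = torch.nn.Sequential(
            torch.nn.Linear(latent_dim, 16), torch.nn.ReLU(), torch.nn.Linear(16, n_out)
        )


model = TwoPhaseNet()
before = model.phase1_decoder.weight.clone()

# Re-initialize only the phase-2/3 decoder, leaving phase 1 untouched.
for layer in model.phase23_decoder.modules():
    if hasattr(layer, "reset_parameters"):
        layer.reset_parameters()

assert torch.equal(before, model.phase1_decoder.weight)  # phase 1 unchanged
```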
    def _latent_infer_for_eval(
        self,
        model: torch.nn.Module,
        X_val: np.ndarray,
        *,
        steps: int,
        lr: float,
        weight_decay: float,
        seed: int,
        cache: dict | None,
        cache_key: str | None,
    ) -> None:
        """Freeze weights; refine validation latents only.

        This method optimizes latent vectors for the validation set using a trained UBP model. It refines the latent vectors by minimizing the cross-entropy loss between the model's predictions and the provided genotype data. The optimization process is performed for a specified number of steps, and the resulting optimized latent vectors are stored in a cache for potential reuse. The method ensures that the model's weights remain unchanged during this process by freezing them.

        Args:
            model (torch.nn.Module): Trained UBP model.
            X_val (np.ndarray): Validation 0/1/2 with -1 for missing.
            steps (int): Number of optimization steps.
            lr (float): Learning rate for latent optimization.
            weight_decay (float): L2 weight decay on latents.
            seed (int): RNG seed for determinism across epochs.
            cache (dict | None): Optional dict to warm-start & persist val latents.
            cache_key (str | None): Ignored; we build a schema-aware key internally.
        """
        if seed is None:
            seed = np.random.randint(0, 999999)
        torch.manual_seed(seed)
        np.random.seed(seed)

        model.eval()
        for p in model.parameters():
            p.requires_grad_(False)

        nF = getattr(model, "n_features", self.num_features_)

        X_val = X_val.astype(np.int64, copy=False)
        X_val[X_val < 0] = -1
        y_target = torch.from_numpy(X_val).long().to(self.device)

        # Infer current model latent size to avoid shape mismatch
        latent_dim_model = self._first_linear_in_features(model)
        schema_key = f"{self.prefix}_ubp_val_latents_z{latent_dim_model}_L{nF}_K{self.num_classes_}"

        # Warm-start from cache if compatible
        if cache is not None and schema_key in cache:
            val_latents = cache[schema_key].detach().clone().requires_grad_(True)
        else:
            val_latents = self._create_latent_space(
                {"latent_dim": latent_dim_model},
                n_samples=X_val.shape[0],
                X=X_val,
                latent_init=self.latent_init,
            ).requires_grad_(True)

        opt = torch.optim.AdamW([val_latents], lr=lr, weight_decay=weight_decay)

        for _ in range(max(int(steps), 0)):
            opt.zero_grad(set_to_none=True)
            logits = model.phase23_decoder(val_latents).view(
                X_val.shape[0], nF, self.num_classes_
            )
            loss = F.cross_entropy(
                logits.view(-1, self.num_classes_),
                y_target.view(-1),
                ignore_index=-1,
                reduction="mean",
            )
            loss.backward()
            opt.step()

        if cache is not None:
            cache[schema_key] = val_latents.detach().clone()

        for p in model.parameters():
            p.requires_grad_(True)
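Taken together, `_latent_infer_for_eval` is a transductive evaluation step: the model's weights are frozen, only the validation latents are optimized, and the result is cached under a schema-aware key so the next evaluation warm-starts instead of re-initializing. A minimal sketch of that freeze / refine / cache / unfreeze flow with a toy decoder (names, sizes, and the cache key are illustrative):

```python
import torch
import torch.nn.functional as F

n_val, n_loci, n_classes, latent_dim = 6, 15, 3, 4
decoder = torch.nn.Linear(latent_dim, n_loci * n_classes)
y_val = torch.randint(-1, n_classes, (n_val, n_loci))  # -1 marks missing genotypes

cache: dict[str, torch.Tensor] = {}
cache_key = f"val_latents_z{latent_dim}_L{n_loci}_K{n_classes}"


def refine_val_latents(steps: int = 25) -> torch.Tensor:
    # Freeze model weights; only the latent codes receive gradients.
    for p in decoder.parameters():
        p.requires_grad_(False)

    # Warm-start from the cache when a compatible entry exists.
    if cache_key in cache:
        latents = cache[cache_key].detach().clone().requires_grad_(True)
    else:
        latents = torch.randn(n_val, latent_dim, requires_grad=True)

    opt = torch.optim.AdamW([latents], lr=0.05, weight_decay=1e-4)
    for _ in range(steps):
        opt.zero_grad(set_to_none=True)
        logits = decoder(latents).view(n_val, n_loci, n_classes)
        loss = F.cross_entropy(logits.view(-1, n_classes), y_val.view(-1), ignore_index=-1)
        loss.backward()
        opt.step()

    cache[cache_key] = latents.detach().clone()

    # Restore trainability for the next training phase.
    for p in decoder.parameters():
        p.requires_grad_(True)
    return latents.detach()


z1 = refine_val_latents()  # cold start
z2 = refine_val_latents()  # warm-started from the cached validation latents
```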