pg-sui 0.2.3__py3-none-any.whl → 1.6.14.dev9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pg_sui-0.2.3.dist-info → pg_sui-1.6.14.dev9.dist-info}/METADATA +99 -77
- pg_sui-1.6.14.dev9.dist-info/RECORD +81 -0
- {pg_sui-0.2.3.dist-info → pg_sui-1.6.14.dev9.dist-info}/WHEEL +1 -1
- pg_sui-1.6.14.dev9.dist-info/entry_points.txt +4 -0
- {pg_sui-0.2.3.dist-info → pg_sui-1.6.14.dev9.dist-info/licenses}/LICENSE +0 -0
- pg_sui-1.6.14.dev9.dist-info/top_level.txt +1 -0
- pgsui/__init__.py +35 -54
- pgsui/_version.py +34 -0
- pgsui/cli.py +909 -0
- pgsui/data_processing/__init__.py +0 -0
- pgsui/data_processing/config.py +565 -0
- pgsui/data_processing/containers.py +1424 -0
- pgsui/data_processing/transformers.py +557 -907
- pgsui/{example_data/trees → electron/app}/__init__.py +0 -0
- pgsui/electron/app/__main__.py +5 -0
- pgsui/electron/app/extra-resources/.gitkeep +1 -0
- pgsui/electron/app/icons/icons/1024x1024.png +0 -0
- pgsui/electron/app/icons/icons/128x128.png +0 -0
- pgsui/electron/app/icons/icons/16x16.png +0 -0
- pgsui/electron/app/icons/icons/24x24.png +0 -0
- pgsui/electron/app/icons/icons/256x256.png +0 -0
- pgsui/electron/app/icons/icons/32x32.png +0 -0
- pgsui/electron/app/icons/icons/48x48.png +0 -0
- pgsui/electron/app/icons/icons/512x512.png +0 -0
- pgsui/electron/app/icons/icons/64x64.png +0 -0
- pgsui/electron/app/icons/icons/icon.icns +0 -0
- pgsui/electron/app/icons/icons/icon.ico +0 -0
- pgsui/electron/app/main.js +227 -0
- pgsui/electron/app/package-lock.json +6894 -0
- pgsui/electron/app/package.json +51 -0
- pgsui/electron/app/preload.js +15 -0
- pgsui/electron/app/server.py +157 -0
- pgsui/electron/app/ui/logo.png +0 -0
- pgsui/electron/app/ui/renderer.js +131 -0
- pgsui/electron/app/ui/styles.css +59 -0
- pgsui/electron/app/ui/ui_shim.js +72 -0
- pgsui/electron/bootstrap.py +43 -0
- pgsui/electron/launch.py +57 -0
- pgsui/electron/package.json +14 -0
- pgsui/example_data/__init__.py +0 -0
- pgsui/example_data/phylip_files/__init__.py +0 -0
- pgsui/example_data/phylip_files/test.phy +0 -0
- pgsui/example_data/popmaps/__init__.py +0 -0
- pgsui/example_data/popmaps/{test.popmap → phylogen_nomx.popmap} +185 -99
- pgsui/example_data/structure_files/__init__.py +0 -0
- pgsui/example_data/structure_files/test.pops.2row.allsites.str +0 -0
- pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz +0 -0
- pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz.tbi +0 -0
- pgsui/impute/__init__.py +0 -0
- pgsui/impute/deterministic/imputers/allele_freq.py +725 -0
- pgsui/impute/deterministic/imputers/mode.py +844 -0
- pgsui/impute/deterministic/imputers/nmf.py +221 -0
- pgsui/impute/deterministic/imputers/phylo.py +973 -0
- pgsui/impute/deterministic/imputers/ref_allele.py +669 -0
- pgsui/impute/supervised/__init__.py +0 -0
- pgsui/impute/supervised/base.py +343 -0
- pgsui/impute/{unsupervised/models/in_development → supervised/imputers}/__init__.py +0 -0
- pgsui/impute/supervised/imputers/hist_gradient_boosting.py +317 -0
- pgsui/impute/supervised/imputers/random_forest.py +291 -0
- pgsui/impute/unsupervised/__init__.py +0 -0
- pgsui/impute/unsupervised/base.py +1118 -0
- pgsui/impute/unsupervised/callbacks.py +92 -262
- {simulation → pgsui/impute/unsupervised/imputers}/__init__.py +0 -0
- pgsui/impute/unsupervised/imputers/autoencoder.py +1285 -0
- pgsui/impute/unsupervised/imputers/nlpca.py +1554 -0
- pgsui/impute/unsupervised/imputers/ubp.py +1575 -0
- pgsui/impute/unsupervised/imputers/vae.py +1228 -0
- pgsui/impute/unsupervised/loss_functions.py +261 -0
- pgsui/impute/unsupervised/models/__init__.py +0 -0
- pgsui/impute/unsupervised/models/autoencoder_model.py +215 -567
- pgsui/impute/unsupervised/models/nlpca_model.py +155 -394
- pgsui/impute/unsupervised/models/ubp_model.py +180 -1106
- pgsui/impute/unsupervised/models/vae_model.py +269 -630
- pgsui/impute/unsupervised/nn_scorers.py +255 -0
- pgsui/utils/__init__.py +0 -0
- pgsui/utils/classification_viz.py +608 -0
- pgsui/utils/logging_utils.py +22 -0
- pgsui/utils/misc.py +35 -480
- pgsui/utils/plotting.py +996 -829
- pgsui/utils/pretty_metrics.py +290 -0
- pgsui/utils/scorers.py +213 -666
- pg_sui-0.2.3.dist-info/RECORD +0 -75
- pg_sui-0.2.3.dist-info/top_level.txt +0 -3
- pgsui/example_data/phylip_files/test_n10.phy +0 -118
- pgsui/example_data/phylip_files/test_n100.phy +0 -118
- pgsui/example_data/phylip_files/test_n2.phy +0 -118
- pgsui/example_data/phylip_files/test_n500.phy +0 -118
- pgsui/example_data/structure_files/test.nopops.1row.10sites.str +0 -117
- pgsui/example_data/structure_files/test.nopops.2row.100sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.10sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.30sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.allsites.str +0 -234
- pgsui/example_data/structure_files/test.pops.1row.10sites.str +0 -117
- pgsui/example_data/structure_files/test.pops.2row.10sites.str +0 -234
- pgsui/example_data/trees/test.iqtree +0 -376
- pgsui/example_data/trees/test.qmat +0 -5
- pgsui/example_data/trees/test.rate +0 -2033
- pgsui/example_data/trees/test.tre +0 -1
- pgsui/example_data/trees/test_n10.rate +0 -19
- pgsui/example_data/trees/test_n100.rate +0 -109
- pgsui/example_data/trees/test_n500.rate +0 -509
- pgsui/example_data/trees/test_siterates.txt +0 -2024
- pgsui/example_data/trees/test_siterates_n10.txt +0 -10
- pgsui/example_data/trees/test_siterates_n100.txt +0 -100
- pgsui/example_data/trees/test_siterates_n500.txt +0 -500
- pgsui/example_data/vcf_files/test.vcf +0 -244
- pgsui/example_data/vcf_files/test.vcf.gz +0 -0
- pgsui/example_data/vcf_files/test.vcf.gz.tbi +0 -0
- pgsui/impute/estimators.py +0 -1268
- pgsui/impute/impute.py +0 -1463
- pgsui/impute/simple_imputers.py +0 -1431
- pgsui/impute/supervised/iterative_imputer_fixedparams.py +0 -782
- pgsui/impute/supervised/iterative_imputer_gridsearch.py +0 -1024
- pgsui/impute/unsupervised/keras_classifiers.py +0 -697
- pgsui/impute/unsupervised/models/in_development/cnn_model.py +0 -486
- pgsui/impute/unsupervised/neural_network_imputers.py +0 -1440
- pgsui/impute/unsupervised/neural_network_methods.py +0 -1395
- pgsui/pg_sui.py +0 -261
- pgsui/utils/sequence_tools.py +0 -407
- simulation/sim_benchmarks.py +0 -333
- simulation/sim_treeparams.py +0 -475
- test/__init__.py +0 -0
- test/pg_sui_simtest.py +0 -215
- test/pg_sui_testing.py +0 -523
- test/test.py +0 -151
- test/test_pgsui.py +0 -374
- test/test_tkc.py +0 -185
|
@@ -0,0 +1,1575 @@
|
|
|
1
|
+
import copy
|
|
2
|
+
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Tuple
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
import optuna
|
|
6
|
+
import torch
|
|
7
|
+
import torch.nn.functional as F
|
|
8
|
+
from sklearn.decomposition import PCA
|
|
9
|
+
from sklearn.exceptions import NotFittedError
|
|
10
|
+
from sklearn.model_selection import train_test_split
|
|
11
|
+
from snpio.analysis.genotype_encoder import GenotypeEncoder
|
|
12
|
+
from snpio.utils.logging import LoggerManager
|
|
13
|
+
|
|
14
|
+
from pgsui.data_processing.config import apply_dot_overrides, load_yaml_to_dataclass
|
|
15
|
+
from pgsui.data_processing.containers import UBPConfig
|
|
16
|
+
from pgsui.data_processing.transformers import SimMissingTransformer
|
|
17
|
+
from pgsui.impute.unsupervised.base import BaseNNImputer
|
|
18
|
+
from pgsui.impute.unsupervised.callbacks import EarlyStopping
|
|
19
|
+
from pgsui.impute.unsupervised.loss_functions import SafeFocalCELoss
|
|
20
|
+
from pgsui.impute.unsupervised.models.ubp_model import UBPModel
|
|
21
|
+
from pgsui.utils.logging_utils import configure_logger
|
|
22
|
+
from pgsui.utils.pretty_metrics import PrettyMetrics
|
|
23
|
+
|
|
24
|
+
if TYPE_CHECKING:
|
|
25
|
+
from snpio import TreeParser
|
|
26
|
+
from snpio.read_input.genotype_data import GenotypeData
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def ensure_ubp_config(config: UBPConfig | dict | str | None) -> UBPConfig:
    """Return a concrete UBPConfig from a dataclass, dict, YAML path, or None.

    Dispatch by input type:

    * ``None``: a default-constructed ``UBPConfig``.
    * ``UBPConfig``: returned unchanged (same object, not a copy).
    * ``str``: treated as a YAML path and handed to
      ``load_yaml_to_dataclass``.
    * ``dict``: an optional ``preset`` key (top level, or under ``io`` when
      no truthy top-level preset is given) selects the base config via
      ``UBPConfig.from_preset``; the remaining, possibly nested, entries are
      flattened to dot-keys and applied with ``apply_dot_overrides``.

    The caller's dictionary is never modified: the preset keys are removed
    from shallow copies only, so the same dict can be reused across calls.

    Args:
        config: UBPConfig | dict | YAML path | None.

    Returns:
        UBPConfig: Normalized configuration instance.

    Raises:
        TypeError: If ``config`` is not one of the supported types.
    """
    if config is None:
        return UBPConfig()
    if isinstance(config, UBPConfig):
        return config
    if isinstance(config, str):
        # YAML path; any preset handling is delegated to the loader.
        return load_yaml_to_dataclass(config, UBPConfig)
    if isinstance(config, dict):

        def _flatten(prefix: str, d: dict, out: dict) -> dict:
            """Collapse nested dicts into ``{"a.b.c": value}`` form."""
            for k, v in d.items():
                kk = f"{prefix}.{k}" if prefix else k
                if isinstance(v, dict):
                    _flatten(kk, v, out)
                else:
                    out[kk] = v
            return out

        # Shallow copies: popping ``preset`` below must not leak into the
        # caller's dict (the original code mutated its argument in place).
        overrides = dict(config)
        preset_name = overrides.pop("preset", None)

        io_section = overrides.get("io")
        if isinstance(io_section, dict):
            io_section = dict(io_section)
            overrides["io"] = io_section
            # ``or`` short-circuits: ``io.preset`` is only consumed when no
            # truthy top-level preset was supplied; otherwise it stays in the
            # overrides and is forwarded like any other key.
            preset_name = preset_name or io_section.pop("preset", None)

        base = UBPConfig.from_preset(preset_name) if preset_name else UBPConfig()

        flat = _flatten("", overrides, {})
        return apply_dot_overrides(base, flat)

    raise TypeError("config must be a UBPConfig, dict, YAML path, or None.")
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
class ImputeUBP(BaseNNImputer):
|
|
72
|
+
"""UBP imputer for 0/1/2 genotypes with a three-phase decoder schedule.
|
|
73
|
+
|
|
74
|
+
This imputer follows the training recipe from Unsupervised Backpropagation:
|
|
75
|
+
|
|
76
|
+
1. Phase 1 (joint warm start): Learn latent codes and the shallow linear decoder together.
|
|
77
|
+
2. Phase 2 (deep decoder reset): Reinitialize the deeper decoder, freeze the latent codes, and train only the decoder parameters.
|
|
78
|
+
3. Phase 3 (joint fine-tune): Unfreeze everything and jointly refine latent codes plus the deep decoder before evaluation/reporting.
|
|
79
|
+
|
|
80
|
+
References:
|
|
81
|
+
- Gashler, Michael S., Smith, Michael R., Morris, R., and Martinez, T. (2016) Missing Value Imputation with Unsupervised Backpropagation. Computational Intelligence, 32: 196-215. doi: 10.1111/coin.12048.
|
|
82
|
+
"""
|
|
83
|
+
|
|
84
|
+
def __init__(
    self,
    genotype_data: "GenotypeData",
    *,
    tree_parser: Optional["TreeParser"] = None,
    config: UBPConfig | dict | str | None = None,
    overrides: dict | None = None,
    simulate_missing: bool | None = None,
    sim_strategy: (
        Literal[
            "random",
            "random_weighted",
            "random_weighted_inv",
            "nonrandom",
            "nonrandom_weighted",
        ]
        | None
    ) = None,
    sim_prop: float | None = None,
    sim_kwargs: dict | None = None,
):
    """Build the UBP imputer from a config plus optional overrides.

    The configuration is normalized with ``ensure_ubp_config``; ``overrides``
    (flat dot-keys) are applied on top. Logging and the base class are then
    initialized, and the config sections (``io``, ``sim``, ``model``,
    ``train``, ``tune``, ``evaluate``, ``plot``) are copied onto instance
    attributes.

    Explicit simulation arguments take precedence over the ``sim`` config
    section. ``sim_kwargs`` is merged on top of a deep copy of the
    configured ``sim_kwargs``.

    Args:
        genotype_data (GenotypeData): Backing genotype data object.
        tree_parser (TreeParser | None): Optional SNPio phylogenetic tree
            parser. Required when the resolved ``sim_strategy`` starts with
            ``"nonrandom"``.
        config (UBPConfig | dict | str | None): UBP configuration.
        overrides (dict | None): Flat dot-key overrides applied after
            ``config``.
        simulate_missing (bool | None): Whether to simulate missing data;
            ``None`` defers to the config.
        sim_strategy (Literal[...] | None): Simulated-missing strategy; a
            falsy value defers to the config.
        sim_prop (float | None): Proportion of data to simulate as missing;
            ``None`` defers to the config.
        sim_kwargs (dict | None): Extra kwargs for SimMissingTransformer.

    Raises:
        ValueError: If a ``nonrandom*`` strategy is selected without a
            ``tree_parser``.
    """
    self.model_name = "ImputeUBP"
    self.genotype_data = genotype_data
    self.tree_parser = tree_parser

    # ---- normalize config, then apply overrides ----
    resolved_cfg = ensure_ubp_config(config)
    if overrides:
        resolved_cfg = apply_dot_overrides(resolved_cfg, overrides)
    self.cfg = resolved_cfg

    # ---- logging ----
    io_cfg = self.cfg.io
    log_manager = LoggerManager(
        __name__,
        prefix=io_cfg.prefix,
        debug=io_cfg.debug,
        verbose=io_cfg.verbose,
    )
    self.logger = configure_logger(
        log_manager.get_logger(),
        verbose=io_cfg.verbose,
        debug=io_cfg.debug,
    )

    # ---- Base init ----
    super().__init__(
        model_name=self.model_name,
        genotype_data=self.genotype_data,
        prefix=io_cfg.prefix,
        device=self.cfg.train.device,
        verbose=io_cfg.verbose,
        debug=io_cfg.debug,
    )

    # ---- model/meta ----
    self.Model = UBPModel
    self.pgenc = GenotypeEncoder(genotype_data)

    # Re-read the section from ``self.cfg`` after the base initializer ran.
    io_cfg = self.cfg.io
    self.seed = io_cfg.seed
    self.n_jobs = io_cfg.n_jobs
    self.prefix = io_cfg.prefix
    self.scoring_averaging = io_cfg.scoring_averaging
    self.verbose = io_cfg.verbose
    self.debug = io_cfg.debug
    self.rng = np.random.default_rng(self.seed)

    # ---- simulated-missing controls: explicit args win over config ----
    sim_section = getattr(self.cfg, "sim", None)
    merged_sim_kwargs = copy.deepcopy(
        getattr(sim_section, "sim_kwargs", None) or {}
    )
    if sim_kwargs:
        merged_sim_kwargs.update(sim_kwargs)

    if sim_section is not None:
        fallback_flag, fallback_strategy, fallback_prop = (
            sim_section.simulate_missing,
            sim_section.sim_strategy,
            sim_section.sim_prop,
        )
    else:
        fallback_flag, fallback_strategy, fallback_prop = (
            bool(simulate_missing),
            "random",
            0.10,
        )

    if simulate_missing is None:
        self.simulate_missing = fallback_flag
    else:
        self.simulate_missing = bool(simulate_missing)
    self.sim_strategy = sim_strategy or fallback_strategy
    self.sim_prop = float(fallback_prop if sim_prop is None else sim_prop)
    self.sim_kwargs = merged_sim_kwargs

    # ---- model hyperparams ----
    model_cfg = self.cfg.model
    self.latent_dim = model_cfg.latent_dim
    self.dropout_rate = model_cfg.dropout_rate
    self.num_hidden_layers = model_cfg.num_hidden_layers
    self.layer_scaling_factor = model_cfg.layer_scaling_factor
    self.layer_schedule = model_cfg.layer_schedule
    self.latent_init: Literal["pca", "random"] = model_cfg.latent_init
    self.activation = model_cfg.hidden_activation
    self.gamma = model_cfg.gamma

    # ---- training ----
    train_cfg = self.cfg.train
    self.batch_size = train_cfg.batch_size
    self.learning_rate = train_cfg.learning_rate
    self.lr_input_factor = train_cfg.lr_input_factor
    self.l1_penalty = train_cfg.l1_penalty
    self.early_stop_gen = train_cfg.early_stop_gen
    self.min_epochs = train_cfg.min_epochs
    self.epochs = train_cfg.max_epochs
    self.validation_split = train_cfg.validation_split
    self.beta = train_cfg.weights_beta
    self.max_ratio = train_cfg.weights_max_ratio

    # ---- tuning ----
    tune_cfg = self.cfg.tune
    self.tune = tune_cfg.enabled
    self.tune_fast = tune_cfg.fast
    self.tune_proxy_metric_batch = tune_cfg.proxy_metric_batch
    self.tune_batch_size = tune_cfg.batch_size
    self.tune_epochs = tune_cfg.epochs
    self.tune_eval_interval = tune_cfg.eval_interval
    self.tune_metric: Literal[
        "pr_macro",
        "f1",
        "accuracy",
        "average_precision",
        "precision",
        "recall",
        "roc_auc",
    ] = tune_cfg.metric
    self.n_trials = tune_cfg.n_trials
    self.tune_save_db = tune_cfg.save_db
    self.tune_resume = tune_cfg.resume
    self.tune_max_samples = tune_cfg.max_samples
    self.tune_max_loci = tune_cfg.max_loci
    # ``infer_epochs`` is optional on the tune section; default to 100.
    self.tune_infer_epochs = getattr(self.cfg.tune, "infer_epochs", 100)
    self.tune_patience = tune_cfg.patience

    # ---- evaluation ----
    eval_cfg = self.cfg.evaluate
    self.eval_latent_steps = eval_cfg.eval_latent_steps
    self.eval_latent_lr = eval_cfg.eval_latent_lr
    self.eval_latent_weight_decay = eval_cfg.eval_latent_weight_decay

    # ---- plotting ----
    plot_cfg = self.cfg.plot
    self.plot_format = plot_cfg.fmt
    self.plot_dpi = plot_cfg.dpi
    self.plot_fontsize = plot_cfg.fontsize
    # NOTE(review): title font size mirrors the body font size here; confirm
    # whether a dedicated title-size setting was intended.
    self.title_fontsize = plot_cfg.fontsize
    self.despine = plot_cfg.despine
    self.show_plots = plot_cfg.show

    # ---- core runtime (placeholders until fit() populates them) ----
    self.is_haploid = False
    self.num_classes_ = False
    self.model_params: Dict[str, Any] = {}
    self.sim_mask_global_: np.ndarray | None = None
    self.sim_mask_train_: np.ndarray | None = None
    self.sim_mask_test_: np.ndarray | None = None

    # Tree-based strategies cannot run without a parsed tree.
    if self.tree_parser is None and self.sim_strategy.startswith("nonrandom"):
        msg = "tree_parser is required for nonrandom and nonrandom_weighted simulated missing strategies."
        self.logger.error(msg)
        raise ValueError(msg)
|
|
255
|
+
|
|
256
|
+
def fit(self) -> "ImputeUBP":
    """Fit the UBP decoder on 0/1/2 encodings (missing = -1).

    Steps performed here, in order:

    1. Fetch float genotypes, replace NaN with -1 and store the integer
       matrix as ``ground_truth_``.
    2. If ``simulate_missing`` is set, obtain a boolean simulated-missing
       mask (from ``_sim_mask_cache`` when a cache key is available,
       otherwise from ``SimMissingTransformer``) and set those entries to -1
       in the model input.
    3. Decide haploid vs. diploid from the raw ``snp_data`` symbols; haploid
       data is collapsed to two classes (2 -> 1).
    4. Split sample indices into train/test, slicing the inputs, the ground
       truth and the simulated mask.
    5. Initialize plotting/scorers, optionally tune hyperparameters, and fall
       back to default parameters when none are set.
    6. Compute class weights, initialize latents, train via
       ``_train_final_model``, plot the history, evaluate on the test split
       and save the best parameters.

    The three-phase schedule (linear warm start, deep-decoder training with
    fixed latents, joint fine-tune) is presumably implemented inside
    ``_train_final_model``, which is not visible in this section.

    Returns:
        ImputeUBP: Fitted instance (``self``), with ``is_fit_`` set to True.

    Raises:
        Nothing is raised directly here; exceptions from the helper methods
        and from ``train_test_split`` propagate to the caller.
    """
    self.logger.info(f"Fitting {self.model_name} model...")

    # --- Use 0/1/2 with -1 for missing ---
    X012 = self._get_float_genotypes(copy=True)
    # NaN marks missing calls in the float matrix; switch to the -1 sentinel
    # before casting to integers.
    GT_full = np.nan_to_num(X012, nan=-1.0, copy=True)
    self.ground_truth_ = GT_full.astype(np.int64, copy=False)

    # --- simulated-missing mask (cached per key when possible) ---
    cache_key = self._sim_mask_cache_key()
    self.sim_mask_global_ = None
    if self.simulate_missing:
        cached_mask = (
            None if cache_key is None else self._sim_mask_cache.get(cache_key)
        )
        if cached_mask is not None:
            # Copy so later in-place edits cannot corrupt the cached mask.
            self.sim_mask_global_ = cached_mask.copy()
        else:
            tr = SimMissingTransformer(
                genotype_data=self.genotype_data,
                tree_parser=self.tree_parser,
                prop_missing=self.sim_prop,
                strategy=self.sim_strategy,
                missing_val=-9,
                mask_missing=True,
                verbose=self.verbose,
                **self.sim_kwargs,
            )
            tr.fit(X012.copy())
            self.sim_mask_global_ = tr.sim_missing_mask_.astype(bool)
            if cache_key is not None:
                self._sim_mask_cache[cache_key] = self.sim_mask_global_.copy()

    # Model input = ground truth with the simulated entries hidden as -1.
    X_for_model = self.ground_truth_.copy()
    if self.sim_mask_global_ is not None:
        X_for_model[self.sim_mask_global_] = -1

    # --- Determine ploidy (haploid vs diploid) and classes ---
    # Haploid iff every raw symbol is one of the listed single-character
    # codes (the four bases plus "N", "-", ".", "?").
    self.is_haploid = bool(
        np.all(
            np.isin(
                self.genotype_data.snp_data,
                ["A", "C", "G", "T", "N", "-", ".", "?"],
            )
        )
    )
    self.ploidy = 1 if self.is_haploid else 2

    if self.is_haploid:
        self.num_classes_ = 2
        # Collapse 2 -> 1 in both the ground truth and the model input so
        # haploid labels are binary.
        self.ground_truth_[self.ground_truth_ == 2] = 1
        X_for_model[X_for_model == 2] = 1
        self.logger.info("Haploid data detected. Using 2 classes (REF=0, ALT=1).")
    else:
        self.num_classes_ = 3
        self.logger.info(
            "Diploid data detected. Using 3 classes (REF=0, HET=1, ALT=2)."
        )

    n_samples, self.num_features_ = X_for_model.shape

    # --- model params (decoder: Z -> L * num_classes) ---
    self.model_params = {
        "n_features": self.num_features_,
        "num_classes": self.num_classes_,
        "latent_dim": self.latent_dim,
        "dropout_rate": self.dropout_rate,
        "activation": self.activation,
        # hidden_layer_sizes injected later
    }

    # --- split ---
    # Split row indices rather than data so inputs, ground truth and the
    # simulated mask are all sliced with the same train/test partition.
    indices = np.arange(n_samples)
    train_idx, test_idx = train_test_split(
        indices, test_size=self.validation_split, random_state=self.seed
    )
    self.train_idx_, self.test_idx_ = train_idx, test_idx
    self.X_train_ = X_for_model[train_idx]
    self.X_test_ = X_for_model[test_idx]
    self.GT_train_full_ = self.ground_truth_[train_idx]
    self.GT_test_full_ = self.ground_truth_[test_idx]

    if self.sim_mask_global_ is not None:
        self.sim_mask_train_ = self.sim_mask_global_[train_idx]
        self.sim_mask_test_ = self.sim_mask_global_[test_idx]
    else:
        self.sim_mask_train_ = None
        self.sim_mask_test_ = None

    # --- plotting/scorers & tuning ---
    self.plotter_, self.scorers_ = self.initialize_plotting_and_scorers()
    if self.tune:
        self.tune_hyperparameters()

    # Fall back to default model params when none have been selected yet.
    if not getattr(self, "best_params_", None):
        self.best_params_ = self._set_best_params_default()

    # --- class weights for 0/1/2 ---
    self.class_weights_ = self._normalize_class_weights(
        self._class_weights_from_zygosity(self.X_train_)
    )

    # --- latent init & loader ---
    train_latent_vectors = self._create_latent_space(
        self.best_params_, len(self.X_train_), self.X_train_, self.latent_init
    )
    train_loader = self._get_data_loaders(self.X_train_)

    # --- final training (three-phase under the hood) ---
    (self.best_loss_, self.model_, self.history_, self.train_latent_vectors_) = (
        self._train_final_model(
            loader=train_loader,
            best_params=self.best_params_,
            initial_latent_vectors=train_latent_vectors,
        )
    )

    # Mark as fitted before plotting/evaluation so transform() is usable
    # from here on.
    self.is_fit_ = True
    self.plotter_.plot_history(self.history_)
    # Pass the simulated-missing test mask to evaluation only when
    # simulation is on and the mask exists; otherwise pass None.
    eval_mask = (
        self.sim_mask_test_
        if (self.simulate_missing and self.sim_mask_test_ is not None)
        else None
    )
    self._evaluate_model(
        self.X_test_,
        self.model_,
        self.best_params_,
        eval_mask_override=eval_mask,
    )
    self._save_best_params(self.best_params_)
    return self
|
|
400
|
+
|
|
401
|
+
def transform(self) -> np.ndarray:
    """Impute missing genotypes and return the decoded (IUPAC) matrix.

    Latent vectors are optimized for the stored ``ground_truth_`` matrix,
    the trained model predicts a label for every entry, and only entries
    equal to -1 (missing) are replaced by predictions. The result is decoded
    with ``self.pgenc.decode_012``. Genotype distributions are plotted only
    when ``self.show_plots`` is enabled.

    Haploid handling: ``fit()`` collapses haploid calls to two classes
    (REF=0, ALT=1). Before decoding, ALT is mapped back to 2 so it is not
    read as the heterozygous class of the 0/1/2 encoding.
    NOTE(review): this assumes ``decode_012`` follows the REF=0 / HET=1 /
    ALT=2 convention logged by ``fit()``; confirm against SNPio.

    Returns:
        np.ndarray: Decoded genotype array (n_samples x n_loci).

    Raises:
        NotFittedError: If called before fit().
    """
    if not getattr(self, "is_fit_", False):
        raise NotFittedError("Model is not fitted. Call fit() before transform().")

    self.logger.info(f"Imputing entire dataset with {self.model_name}...")
    X_to_impute = self.ground_truth_.copy()

    optimized_latents = self._optimize_latents_for_inference(
        X_to_impute, self.model_, self.best_params_
    )

    # Inference only: wrap as a Parameter that does not track gradients.
    if not isinstance(optimized_latents, torch.nn.Parameter):
        optimized_latents = torch.nn.Parameter(
            optimized_latents, requires_grad=False
        )

    pred_labels, _ = self._predict(self.model_, latent_vectors=optimized_latents)

    # Keep observed calls; fill only the entries that were missing.
    missing_mask = X_to_impute == -1
    imputed_array = X_to_impute.copy()
    imputed_array[missing_mask] = pred_labels[missing_mask]

    haploid = bool(getattr(self, "is_haploid", False))

    def _as_012(calls: np.ndarray) -> np.ndarray:
        """Undo the haploid 2 -> 1 collapse so ALT decodes as ALT."""
        if not haploid:
            return calls
        restored = calls.copy()
        restored[restored == 1] = 2
        return restored

    # Decode to IUPAC for return & optional plots
    imputed_genotypes = self.pgenc.decode_012(_as_012(imputed_array))
    if self.show_plots:
        original_genotypes = self.pgenc.decode_012(_as_012(X_to_impute))
        self.plotter_.plot_gt_distribution(original_genotypes, is_imputed=False)
        self.plotter_.plot_gt_distribution(imputed_genotypes, is_imputed=True)
    return imputed_genotypes
|
|
440
|
+
|
|
441
|
+
def _train_step(
    self,
    loader: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
    latent_optimizer: torch.optim.Optimizer,
    model: torch.nn.Module,
    l1_penalty: float,
    latent_vectors: torch.nn.Parameter,
    class_weights: torch.Tensor,
    phase: int,
) -> Tuple[float, torch.nn.Parameter]:
    """Run one training epoch over ``loader``.

    Each batch decodes the indexed latent rows with the phase-appropriate
    decoder (``phase1_decoder`` for phase 1, ``phase23_decoder`` otherwise),
    scores the logits with ``SafeFocalCELoss`` (targets of -1 are ignored)
    and optionally adds an L1 penalty over the trainable model parameters.
    Batches with non-finite logits or loss are skipped. Gradients are
    clipped to norm 1.0; optimizer steps are taken only when both gradient
    norms are finite, and the latent optimizer never steps in phase 2.

    Args:
        loader: Yields ``(batch_indices, y_batch)`` pairs.
        optimizer: Optimizer for the model parameters.
        latent_optimizer: Optimizer for the latent vectors.
        model: Model exposing ``phase1_decoder`` and ``phase23_decoder``.
        l1_penalty: L1 coefficient; applied only when greater than 0.
        latent_vectors: Latent matrix indexed by the batch indices.
        class_weights: Per-class loss weights (moved to ``self.device``).
        phase: Training phase (1, 2 or 3).

    Returns:
        Tuple[float, torch.nn.Parameter]: Mean loss over the batches that
        were counted (``inf`` if none were), and the latent vectors.

    Raises:
        TypeError: If the selected decoder is not a ``torch.nn.Module``.
    """
    model.train()

    if not isinstance(latent_vectors, torch.nn.Parameter):
        latent_vectors = torch.nn.Parameter(latent_vectors, requires_grad=True)

    # Focal gamma: prefer the model's value, then the imputer's, then 0;
    # clamp into [0, 10].
    focal_gamma = float(getattr(model, "gamma", getattr(self, "gamma", 0.0)))
    focal_gamma = max(0.0, min(focal_gamma, 10.0))

    # Only currently-trainable parameters take part in the L1 penalty.
    penalized = tuple(p for p in model.parameters() if p.requires_grad)

    if class_weights is not None and class_weights.device != self.device:
        class_weights = class_weights.to(self.device)

    criterion = SafeFocalCELoss(
        gamma=focal_gamma, weight=class_weights, ignore_index=-1
    )

    if phase == 1:
        decoder: torch.Tensor | torch.nn.Module = model.phase1_decoder
    else:
        decoder = model.phase23_decoder

    if not isinstance(decoder, torch.nn.Module):
        msg = f"{self.model_name} Decoder is not a torch.nn.Module."
        self.logger.error(msg)
        raise TypeError(msg)

    clip_norm = torch.nn.utils.clip_grad_norm_
    loss_sum = 0.0
    n_counted = 0

    for row_idx, targets in loader:
        optimizer.zero_grad(set_to_none=True)
        latent_optimizer.zero_grad(set_to_none=True)

        row_idx = row_idx.to(latent_vectors.device, non_blocking=True)
        batch_latents = latent_vectors[row_idx]
        y = targets.to(self.device, non_blocking=True).long()

        logits = decoder(batch_latents).view(
            len(row_idx), self.num_features_, self.num_classes_
        )

        # Guard upstream explosions
        if not torch.isfinite(logits).all():
            continue

        loss = criterion(logits.view(-1, self.num_classes_), y.view(-1))

        if l1_penalty > 0:
            l1_total = sum(
                (p.abs().sum() for p in penalized),
                torch.zeros((), device=self.device),
            )
            loss = loss + l1_penalty * l1_total

        if not torch.isfinite(loss):
            continue

        loss.backward()

        # clip_grad_norm_ returns the total (pre-clip) gradient norm, which
        # doubles as a cheap finiteness check.
        model_norm = clip_norm(model.parameters(), 1.0)
        latent_norm = clip_norm([latent_vectors], 1.0)

        if not (torch.isfinite(model_norm) and torch.isfinite(latent_norm)):
            # Bad gradients: drop them and skip the update for this batch.
            optimizer.zero_grad(set_to_none=True)
            latent_optimizer.zero_grad(set_to_none=True)
        else:
            optimizer.step()
            # Phase 2 keeps the latents fixed.
            if phase != 2:
                latent_optimizer.step()

        loss_sum += float(loss.detach().item())
        n_counted += 1

    if n_counted > 0:
        return loss_sum / n_counted, latent_vectors
    return float("inf"), latent_vectors
|
|
527
|
+
|
|
528
|
+
def _predict(
|
|
529
|
+
self,
|
|
530
|
+
model: torch.nn.Module,
|
|
531
|
+
latent_vectors: Optional[torch.nn.Parameter | torch.Tensor] = None,
|
|
532
|
+
) -> Tuple[np.ndarray, np.ndarray]:
|
|
533
|
+
"""Predict 0/1/2 labels & probabilities from latents via phase23 decoder. This method requires a trained model and latent vectors.
|
|
534
|
+
|
|
535
|
+
Args:
|
|
536
|
+
model (torch.nn.Module): Trained model.
|
|
537
|
+
latent_vectors (torch.nn.Parameter | None): Latent vectors.
|
|
538
|
+
|
|
539
|
+
Returns:
|
|
540
|
+
Tuple[np.ndarray, np.ndarray]: Predicted labels and probabilities.
|
|
541
|
+
"""
|
|
542
|
+
if model is None or latent_vectors is None:
|
|
543
|
+
msg = "Model and latent vectors must be provided for prediction. Fit the model first."
|
|
544
|
+
self.logger.error(msg)
|
|
545
|
+
raise NotFittedError(msg)
|
|
546
|
+
|
|
547
|
+
model.eval()
|
|
548
|
+
nF = getattr(model, "n_features", self.num_features_)
|
|
549
|
+
with torch.no_grad():
|
|
550
|
+
decoder = model.phase23_decoder
|
|
551
|
+
|
|
552
|
+
if not isinstance(decoder, torch.nn.Module):
|
|
553
|
+
msg = f"{self.model_name} decoder is not a valid torch.nn.Module."
|
|
554
|
+
self.logger.error(msg)
|
|
555
|
+
raise TypeError(msg)
|
|
556
|
+
|
|
557
|
+
logits = decoder(latent_vectors.to(self.device)).view(
|
|
558
|
+
len(latent_vectors), nF, self.num_classes_
|
|
559
|
+
)
|
|
560
|
+
probas = torch.softmax(logits, dim=-1)
|
|
561
|
+
labels = torch.argmax(probas, dim=-1)
|
|
562
|
+
|
|
563
|
+
return labels.cpu().numpy(), probas.cpu().numpy()
|
|
564
|
+
|
|
565
|
+
def _evaluate_model(
    self,
    X_val: np.ndarray,
    model: torch.nn.Module,
    params: dict,
    objective_mode: bool = False,
    latent_vectors_val: torch.Tensor | None = None,
    *,
    eval_mask_override: np.ndarray | None = None,
) -> Dict[str, float]:
    """Evaluate the model on a validation matrix.

    Latent vectors are either taken from ``latent_vectors_val`` or optimized
    via ``_optimize_latents_for_inference``; predictions come from
    ``_predict``. Only cells selected by the evaluation mask *and* having a
    known (non-negative) ground-truth genotype are scored.

    Args:
        X_val (np.ndarray): Validation data in 0/1/2 encoding with -1 for missing.
        model (torch.nn.Module): Trained UBP model.
        params (dict): Model parameters (forwarded to latent optimization).
        objective_mode (bool): If True, skip table rendering and class reports
            and only compute the metrics.
        latent_vectors_val (torch.Tensor | None): Pre-optimized latent vectors
            for the validation rows. If None they are optimized here.
        eval_mask_override (np.ndarray | None): Boolean mask selecting the
            cells to score. Extra trailing columns are dropped when the mask
            is wider than ``X_val``. Defaults to all observed cells of ``X_val``.

    Returns:
        Dict[str, float]: Evaluation metrics. If nothing is left to score the
        result is ``{self.tune_metric: 0.0}``.

    Raises:
        ValueError: If ``eval_mask_override`` has a different number of rows
            than ``X_val``.
    """
    if latent_vectors_val is not None:
        test_latent_vectors = latent_vectors_val
    else:
        test_latent_vectors = self._optimize_latents_for_inference(
            X_val, model, params
        )

    pred_labels, pred_probas = self._predict(
        model=model, latent_vectors=test_latent_vectors
    )

    if eval_mask_override is not None:
        # Only the row count must match; columns may be subset during tuning.
        if eval_mask_override.shape[0] != X_val.shape[0]:
            msg = (
                f"eval_mask_override rows {eval_mask_override.shape[0]} "
                f"does not match X_val rows {X_val.shape[0]}"
            )
            self.logger.error(msg)
            raise ValueError(msg)

        # Slice mask columns if the override is wider than X_val (tune_fast).
        if eval_mask_override.shape[1] > X_val.shape[1]:
            eval_mask = eval_mask_override[:, : X_val.shape[1]].astype(bool)
        else:
            eval_mask = eval_mask_override.astype(bool)
    else:
        # Default: score only observed entries.
        eval_mask = X_val != -1

    # Pick the pre-mask ground truth whose row count matches X_val.
    if X_val.shape[0] == self.X_test_.shape[0]:
        GT_ref = self.GT_test_full_
    elif X_val.shape[0] == self.X_train_.shape[0]:
        GT_ref = self.GT_train_full_
    else:
        GT_ref = self.ground_truth_

    # Slice ground-truth columns if wider than X_val (tune_fast).
    if GT_ref.shape[1] > X_val.shape[1]:
        GT_ref = GT_ref[:, : X_val.shape[1]]

    # Fallback safeguard: score against X_val itself when shapes disagree.
    if GT_ref.shape != X_val.shape:
        GT_ref = X_val

    # Never score cells whose truth is unknown. Without this, a -1 truth
    # value would be one-hot encoded as the *last* class by negative
    # indexing below and would be passed to the scorers as a label.
    eval_mask = eval_mask & (GT_ref >= 0)

    y_true_flat = GT_ref[eval_mask]
    pred_labels_flat = pred_labels[eval_mask]
    pred_probas_flat = pred_probas[eval_mask]

    if y_true_flat.size == 0:
        return {self.tune_metric: 0.0}

    # NOTE(review): the original comment here said haploid class 2 is
    # remapped to 1 for scoring, but no remapping is performed in this
    # method; a truth value of 2 with is_haploid=True would raise an
    # IndexError in the one-hot step below - confirm upstream encoding.
    labels_for_scoring = [0, 1] if self.is_haploid else [0, 1, 2]
    target_names = ["REF", "ALT"] if self.is_haploid else ["REF", "HET", "ALT"]

    y_true_ohe = np.eye(len(labels_for_scoring))[y_true_flat]

    metrics = self.scorers_.evaluate(
        y_true_flat,
        pred_labels_flat,
        y_true_ohe,
        pred_probas_flat,
        objective_mode,
        self.tune_metric,
    )

    if not objective_mode:
        pm = PrettyMetrics(
            metrics, precision=3, title=f"{self.model_name} Validation Metrics"
        )
        pm.render()  # prints a command-line table

        self._make_class_reports(
            y_true=y_true_flat,
            y_pred_proba=pred_probas_flat,
            y_pred=pred_labels_flat,
            metrics=metrics,
            labels=target_names,
        )

        # Use X_val dimensions for reshaping, not self.num_features_.
        y_true_dec = self.pgenc.decode_012(
            GT_ref.reshape(X_val.shape[0], X_val.shape[1])
        )

        X_pred = X_val.copy()
        X_pred[eval_mask] = pred_labels_flat

        y_pred_dec = self.pgenc.decode_012(
            X_pred.reshape(X_val.shape[0], X_val.shape[1])
        )

        encodings_dict = {
            "A": 0,
            "C": 1,
            "G": 2,
            "T": 3,
            "W": 4,
            "R": 5,
            "M": 6,
            "K": 7,
            "Y": 8,
            "S": 9,
            "N": -1,
        }

        y_true_int = self.pgenc.convert_int_iupac(
            y_true_dec, encodings_dict=encodings_dict
        )
        y_pred_int = self.pgenc.convert_int_iupac(
            y_pred_dec, encodings_dict=encodings_dict
        )

        # IUPAC report: keep scored cells whose decoded truth is not N (-1).
        valid_true = y_true_int[eval_mask]
        valid_true = valid_true[valid_true >= 0]
        iupac_label_set = ["A", "C", "G", "T", "W", "R", "M", "K", "Y", "S"]

        if (
            np.intersect1d(np.unique(y_true_flat), labels_for_scoring).size == 0
            or valid_true.size == 0
        ):
            if not objective_mode:
                self.logger.warning(
                    "Skipped numeric confusion matrix: no y_true labels present."
                )
        else:
            self._make_class_reports(
                y_true=valid_true,
                y_pred=y_pred_int[eval_mask][y_true_int[eval_mask] >= 0],
                metrics=metrics,
                y_pred_proba=None,
                labels=iupac_label_set,
            )

    return metrics
|
|
731
|
+
|
|
732
|
+
def _get_data_loaders(self, y: np.ndarray) -> torch.utils.data.DataLoader:
    """Build a shuffled DataLoader yielding ``(row_index, genotype_row)`` batches.

    The genotype matrix is converted to an int64 tensor and paired with a
    running row index in a ``TensorDataset``. The tensors are not moved to
    ``self.device`` here; memory pinning is requested only when the device
    type is ``"cuda"``.

    Args:
        y (np.ndarray): (n_samples x L) int matrix with -1 missing.

    Returns:
        torch.utils.data.DataLoader: Shuffled mini-batches of size
        ``self.batch_size``.
    """
    targets = torch.from_numpy(y).long()
    row_ids = torch.arange(len(y), dtype=torch.long)
    paired = torch.utils.data.TensorDataset(row_ids, targets)
    use_pinned = self.device.type == "cuda"
    return torch.utils.data.DataLoader(
        paired,
        batch_size=self.batch_size,
        shuffle=True,
        pin_memory=use_pinned,
    )
|
|
753
|
+
|
|
754
|
+
def _objective(self, trial: optuna.Trial) -> float:
    """Optuna objective using the UBP training loop.

    Samples hyperparameters, trains a fresh model on the training split,
    evaluates it on the test split (restricted to the simulated-missing mask
    when available) and returns ``metrics[self.tune_metric]``.

    Args:
        trial (optuna.Trial): Current trial.

    Returns:
        float: Value of the tuning metric for this trial.

    Raises:
        optuna.exceptions.TrialPruned: When the trial is pruned during
            training (re-raised unchanged), or when any other exception
            occurs (wrapped, with the original exception chained as cause).
    """
    try:
        params = self._sample_hyperparameters(trial)

        # Resolve the fallbacks lazily so the ground-truth matrix is not
        # fancy-indexed (copied) on every trial when the splits already exist.
        X_train_trial = getattr(self, "X_train_", None)
        if X_train_trial is None:
            X_train_trial = self.ground_truth_[self.train_idx_]
        X_test_trial = getattr(self, "X_test_", None)
        if X_test_trial is None:
            X_test_trial = self.ground_truth_[self.test_idx_]

        class_weights = self._normalize_class_weights(
            self._class_weights_from_zygosity(X_train_trial)
        )
        train_loader = self._get_data_loaders(X_train_trial)

        train_latent_vectors = self._create_latent_space(
            params, len(X_train_trial), X_train_trial, params["latent_init"]
        )

        model = self.build_model(self.Model, params["model_params"])
        model.n_features = params["model_params"]["n_features"]
        model.apply(self.initialize_weights)

        _, model, __ = self._train_and_validate_model(
            model=model,
            loader=train_loader,
            lr=params["lr"],
            l1_penalty=params["l1_penalty"],
            trial=trial,
            return_history=False,
            latent_vectors=train_latent_vectors,
            lr_input_factor=params["lr_input_factor"],
            class_weights=class_weights,
            X_val=X_test_trial,
            params=params,
            prune_metric=self.tune_metric,
            prune_warmup_epochs=5,
            eval_interval=1,
            eval_requires_latents=True,
            eval_latent_steps=self.eval_latent_steps,
            eval_latent_lr=self.eval_latent_lr,
            eval_latent_weight_decay=self.eval_latent_weight_decay,
        )

        eval_mask = (
            self.sim_mask_test_
            if (
                self.simulate_missing
                and getattr(self, "sim_mask_test_", None) is not None
            )
            else None
        )
        metrics = self._evaluate_model(
            X_test_trial,
            model,
            params,
            objective_mode=True,
            eval_mask_override=eval_mask,
        )
        self._clear_resources(
            model, train_loader, latent_vectors=train_latent_vectors
        )
        return metrics[self.tune_metric]
    except optuna.exceptions.TrialPruned:
        # A genuine prune already carries its own message; do not re-label
        # it as a failure.
        raise
    except Exception as e:
        # Deliberate best-effort: a failing trial is reported as pruned so
        # the study continues, but the cause is kept for debugging.
        raise optuna.exceptions.TrialPruned(
            f"Trial failed with error: {e}"
        ) from e
|
|
819
|
+
|
|
820
|
+
def _sample_hyperparameters(self, trial: optuna.Trial) -> dict:
    """Sample UBP hyperparameters; compute hidden sizes for model_params.

    Draws every tunable hyperparameter from ``trial`` and derives the hidden
    layer widths with ``_compute_hidden_layer_sizes``. The result carries a
    nested ``"model_params"`` dict whose keys mirror the payloads returned by
    ``_set_best_params`` and ``_set_best_params_default`` (including
    ``"gamma"``), so models built during tuning and the final model receive
    the same set of constructor arguments.

    Note that the learning rate is stored under ``"lr"`` in the returned dict
    but is registered with Optuna under the name ``"learning_rate"``.

    Args:
        trial (optuna.Trial): Current trial.

    Returns:
        Dict[str, int | float | str | list]: Sampled hyperparameters plus the
        ``"model_params"`` payload.
    """
    # The order of the suggest_* calls is kept stable on purpose.
    params = {
        "latent_dim": trial.suggest_int("latent_dim", 2, 32),
        "lr": trial.suggest_float("learning_rate", 1e-5, 1e-2, log=True),
        "dropout_rate": trial.suggest_float("dropout_rate", 0.0, 0.6),
        "num_hidden_layers": trial.suggest_int("num_hidden_layers", 1, 8),
        "activation": trial.suggest_categorical(
            "activation", ["relu", "elu", "selu"]
        ),
        "gamma": trial.suggest_float("gamma", 0.0, 5.0),
        "lr_input_factor": trial.suggest_float(
            "lr_input_factor", 0.1, 10.0, log=True
        ),
        "l1_penalty": trial.suggest_float("l1_penalty", 1e-7, 1e-2, log=True),
        "layer_scaling_factor": trial.suggest_float(
            "layer_scaling_factor", 2.0, 10.0
        ),
        "layer_schedule": trial.suggest_categorical(
            "layer_schedule", ["pyramid", "constant", "linear"]
        ),
        "latent_init": trial.suggest_categorical("latent_init", ["random", "pca"]),
    }

    hidden_layer_sizes = self._compute_hidden_layer_sizes(
        n_inputs=params["latent_dim"],
        n_outputs=self.num_features_ * self.num_classes_,
        n_samples=len(self.train_idx_),
        n_hidden=params["num_hidden_layers"],
        alpha=params["layer_scaling_factor"],
        schedule=params["layer_schedule"],
    )
    # Keep the first element, then the interior widths (the last element is
    # dropped). With no interior widths this still leaves the first element.
    hidden_only = [hidden_layer_sizes[0]] + hidden_layer_sizes[1:-1]

    params["model_params"] = {
        "n_features": self.num_features_,
        "num_classes": self.num_classes_,
        "latent_dim": params["latent_dim"],
        "dropout_rate": params["dropout_rate"],
        "hidden_layer_sizes": hidden_only,
        "activation": params["activation"],
        # Same key the final-model payloads carry.
        "gamma": params["gamma"],
    }

    return params
|
|
877
|
+
|
|
878
|
+
def _set_best_params(self, best_params: dict) -> dict:
    """Set best params onto instance; return model_params payload.

    All required keys are checked before anything is assigned, so a bad
    ``best_params`` dict never leaves the instance half-updated. Besides the
    training hyperparameters, the architecture settings
    (``num_hidden_layers``, ``layer_scaling_factor``, ``layer_schedule``)
    are stored as well, because ``_set_best_params_default`` reads them from
    the instance.

    Args:
        best_params (dict): Best hyperparameters, keyed by the names
            registered with Optuna (e.g. ``"learning_rate"``).

    Returns:
        dict: model_params payload.

    Raises:
        KeyError: If best_params is missing required keys.
    """
    required = (
        "latent_dim",
        "dropout_rate",
        "learning_rate",
        "gamma",
        "lr_input_factor",
        "l1_penalty",
        "activation",
        "latent_init",
        "num_hidden_layers",
        "layer_scaling_factor",
        "layer_schedule",
    )
    missing = [key for key in required if key not in best_params]
    if missing:
        msg = f"best_params is missing required keys: {missing}"
        self.logger.error(msg)
        raise KeyError(msg)

    self.latent_dim = best_params["latent_dim"]
    self.dropout_rate = best_params["dropout_rate"]
    self.learning_rate = best_params["learning_rate"]
    self.gamma = best_params["gamma"]
    self.lr_input_factor = best_params["lr_input_factor"]
    self.l1_penalty = best_params["l1_penalty"]
    self.activation = best_params["activation"]
    self.latent_init = best_params["latent_init"]
    # Keep the instance consistent with the architecture actually built.
    self.num_hidden_layers = best_params["num_hidden_layers"]
    self.layer_scaling_factor = best_params["layer_scaling_factor"]
    self.layer_schedule = best_params["layer_schedule"]

    hidden_layer_sizes = self._compute_hidden_layer_sizes(
        n_inputs=self.latent_dim,
        n_outputs=self.num_features_ * self.num_classes_,
        n_samples=len(self.train_idx_),
        n_hidden=self.num_hidden_layers,
        alpha=self.layer_scaling_factor,
        schedule=self.layer_schedule,
    )

    # First element plus interior widths; the last element is dropped.
    hidden_only = [hidden_layer_sizes[0]] + hidden_layer_sizes[1:-1]

    return {
        "n_features": self.num_features_,
        "latent_dim": self.latent_dim,
        "hidden_layer_sizes": hidden_only,
        "dropout_rate": self.dropout_rate,
        "activation": self.activation,
        "gamma": self.gamma,
        "num_classes": self.num_classes_,
    }
|
|
921
|
+
|
|
922
|
+
def _set_best_params_default(self) -> dict:
    """Default (no-tuning) model_params aligned with current attributes.

    Derives the hidden layer widths from the instance's current
    ``latent_dim``, ``num_hidden_layers``, ``layer_scaling_factor`` and
    ``layer_schedule`` and packs them, together with the remaining model
    settings, into the payload used to build the model.

    Returns:
        dict: model_params payload.
    """
    # NOTE(review): n_samples is the full ground-truth row count here,
    # whereas the tuned code paths pass len(self.train_idx_) - confirm this
    # difference is intended.
    widths = self._compute_hidden_layer_sizes(
        n_inputs=self.latent_dim,
        n_outputs=self.num_features_ * self.num_classes_,
        n_samples=len(self.ground_truth_),
        n_hidden=self.num_hidden_layers,
        alpha=self.layer_scaling_factor,
        schedule=self.layer_schedule,
    )

    # First element plus interior widths; the last element is dropped.
    hidden_only = [widths[0]] + widths[1:-1]

    payload = {
        "n_features": self.num_features_,
        "latent_dim": self.latent_dim,
        "hidden_layer_sizes": hidden_only,
        "dropout_rate": self.dropout_rate,
        "activation": self.activation,
        "gamma": self.gamma,
        "num_classes": self.num_classes_,
    }
    return payload
|
|
950
|
+
|
|
951
|
+
def _train_and_validate_model(
    self,
    model: torch.nn.Module,
    loader: torch.utils.data.DataLoader,
    lr: float,
    l1_penalty: float,
    trial: optuna.Trial | None = None,
    return_history: bool = False,
    latent_vectors: torch.nn.Parameter | None = None,
    lr_input_factor: float = 1.0,
    class_weights: torch.Tensor | None = None,
    *,
    X_val: np.ndarray | None = None,
    params: dict | None = None,
    prune_metric: str | None = None,  # "f1" | "accuracy" | "pr_macro"
    prune_warmup_epochs: int = 3,
    eval_interval: int = 1,
    eval_requires_latents: bool = True,  # UBP needs latent eval
    eval_latent_steps: int = 50,
    eval_latent_lr: float = 1e-2,
    eval_latent_weight_decay: float = 0.0,
) -> tuple:
    """Train & validate UBP model with three-phase loop.

    Creates an Adam optimizer over the latent vectors (learning rate
    ``lr * lr_input_factor``) and delegates the actual work to
    ``_execute_training_loop``, forwarding every evaluation/pruning option
    unchanged.

    Args:
        model (torch.nn.Module): UBP model with phase1_decoder & phase23_decoder.
        loader (torch.utils.data.DataLoader): DataLoader for training data.
        lr (float): Learning rate for decoder.
        l1_penalty (float): L1 regularization weight.
        trial (optuna.Trial | None): Current trial or None.
        return_history (bool): If True, return loss history.
        latent_vectors (torch.nn.Parameter | None): Trainable Z.
        lr_input_factor (float): LR factor for latents.
        class_weights (torch.Tensor | None): Class weights for 0/1/2.
        X_val (np.ndarray | None): Validation set for pruning/eval.
        params (dict | None): Model params for eval.
        prune_metric (str | None): Metric to monitor for pruning.
        prune_warmup_epochs (int): Epochs before pruning starts.
        eval_interval (int): Epochs between evaluations.
        eval_requires_latents (bool): If True, optimize latents for eval.
        eval_latent_steps (int): Latent optimization steps for eval.
        eval_latent_lr (float): Latent optimization LR for eval.
        eval_latent_weight_decay (float): Latent optimization weight decay for eval.

    Returns:
        tuple: ``(best_loss, best_model, history, latents)`` when
        ``return_history`` is True, otherwise the 3-tuple
        ``(best_loss, best_model, latents)``.

    Raises:
        TypeError: If latent_vectors or class_weights are not provided.
    """
    if latent_vectors is None or class_weights is None:
        msg = "Must provide latent_vectors and class_weights."
        self.logger.error(msg)
        raise TypeError(msg)

    # The latents get their own optimizer with a scaled learning rate.
    latent_lr = lr * lr_input_factor
    latent_optimizer = torch.optim.Adam([latent_vectors], lr=latent_lr)

    outcome = self._execute_training_loop(
        loader=loader,
        latent_optimizer=latent_optimizer,
        lr=lr,
        model=model,
        l1_penalty=l1_penalty,
        trial=trial,
        return_history=return_history,
        latent_vectors=latent_vectors,
        class_weights=class_weights,
        X_val=X_val,
        params=params,
        prune_metric=prune_metric,
        prune_warmup_epochs=prune_warmup_epochs,
        eval_interval=eval_interval,
        eval_requires_latents=eval_requires_latents,
        eval_latent_steps=eval_latent_steps,
        eval_latent_lr=eval_latent_lr,
        eval_latent_weight_decay=eval_latent_weight_decay,
    )

    if not return_history:
        # Drop the history slot (index 2) for callers that did not ask for it.
        return outcome[0], outcome[1], outcome[3]

    return outcome
|
|
1039
|
+
|
|
1040
|
+
def _train_final_model(
    self,
    loader: torch.utils.data.DataLoader,
    best_params: dict,
    initial_latent_vectors: torch.nn.Parameter,
) -> tuple:
    """Train final UBP model with best params; save weights to disk.

    Builds the model from ``best_params``, initializes its weights, runs
    ``_train_and_validate_model`` with the instance-level training settings
    and ``self.X_test_`` as validation data, and writes the trained
    ``state_dict`` to ``self.models_dir / "final_model.pt"``.

    Args:
        loader (torch.utils.data.DataLoader): DataLoader for training data.
        best_params (Dict[str, int | float | str | list]): Best
            hyperparameters (model_params payload, must contain
            ``"n_features"``).
        initial_latent_vectors (torch.nn.Parameter): Initialized latent vectors.

    Returns:
        Tuple[float, torch.nn.Module, dict, torch.nn.Parameter]:
        (loss, model, {"Train": history}, latents).

    Raises:
        RuntimeError: If training returns no model.
    """
    self.logger.info(f"Training the final {self.model_name} model...")

    net = self.build_model(self.Model, best_params)
    net.n_features = best_params["n_features"]
    net.apply(self.initialize_weights)

    train_kwargs = dict(
        model=net,
        loader=loader,
        lr=self.learning_rate,
        l1_penalty=self.l1_penalty,
        return_history=True,
        latent_vectors=initial_latent_vectors,
        lr_input_factor=self.lr_input_factor,
        class_weights=self.class_weights_,
        X_val=self.X_test_,
        params=best_params,
        prune_metric=self.tune_metric,
        prune_warmup_epochs=5,
        eval_interval=1,
        eval_requires_latents=True,
        eval_latent_steps=self.eval_latent_steps,
        eval_latent_lr=self.eval_latent_lr,
        eval_latent_weight_decay=self.eval_latent_weight_decay,
    )
    loss, trained_model, history, latent_vectors = self._train_and_validate_model(
        **train_kwargs
    )

    if trained_model is None:
        msg = "Final model training failed."
        self.logger.error(msg)
        raise RuntimeError(msg)

    # Persist only the weights, not the whole module object.
    weights_path = self.models_dir / "final_model.pt"
    torch.save(trained_model.state_dict(), weights_path)

    return loss, trained_model, {"Train": history}, latent_vectors
|
|
1092
|
+
|
|
1093
|
+
def _execute_training_loop(
    self,
    loader: torch.utils.data.DataLoader,
    latent_optimizer: torch.optim.Optimizer,
    lr: float,
    model: torch.nn.Module,
    l1_penalty: float,
    trial: optuna.Trial | None,
    return_history: bool,
    latent_vectors: torch.nn.Parameter,
    class_weights: torch.Tensor,
    *,
    X_val: np.ndarray | None = None,
    params: dict | None = None,
    prune_metric: str | None = None,
    prune_warmup_epochs: int = 3,
    eval_interval: int = 1,
    eval_requires_latents: bool = True,
    eval_latent_steps: int = 50,
    eval_latent_lr: float = 1e-2,
    eval_latent_weight_decay: float = 0.0,
) -> Tuple[float, torch.nn.Module, dict, torch.nn.Parameter]:
    """Three-phase UBP with numeric guards, LR warmup, and pruning.

    ``self.epochs`` is split roughly 15% / 35% / remainder over phases 1-3
    (each at least one epoch). Phase 1 optimizes ``model.phase1_decoder``;
    phases 2 and 3 optimize ``model.phase23_decoder`` (weights are reset via
    ``_reset_weights`` at the start of phase 2). Every phase gets a fresh
    AdamW optimizer, cosine LR schedule and ``EarlyStopping`` tracker; the
    latent optimizer is shared across phases.

    Per epoch the focal gamma on the model is scheduled (0 for the first 50
    epochs of a phase, then a linear ramp over 100 epochs up to the target
    gamma), phase 1 applies a linear LR warmup, and a non-finite epoch loss
    either prunes the trial or halves both learning rates. When a trial and
    validation data are given, the model is evaluated via
    ``_eval_for_pruning``; the value is reported to Optuna in phase 3 only.

    The target gamma is taken from ``params["gamma"]`` when present so that
    a gamma sampled for the current trial is actually used; otherwise
    ``self.gamma`` is used.

    Args:
        loader (torch.utils.data.DataLoader): DataLoader for training data.
        latent_optimizer (torch.optim.Optimizer): Optimizer for latent vectors.
        lr (float): Learning rate for decoder.
        model (torch.nn.Module): UBP model with phase1_decoder & phase23_decoder.
        l1_penalty (float): L1 regularization weight.
        trial (optuna.Trial | None): Current trial or None.
        return_history (bool): If True, per-epoch losses are recorded.
        latent_vectors (torch.nn.Parameter): Trainable Z.
        class_weights (torch.Tensor): Class weights for 0/1/2.
        X_val (np.ndarray | None): Validation set for pruning/eval.
        params (dict | None): Model params for eval (and optional "gamma").
        prune_metric (str | None): Metric to monitor for pruning.
        prune_warmup_epochs (int): Epochs before pruning starts.
        eval_interval (int): Epochs between evaluations; values below 1 are
            treated as 1.
        eval_requires_latents (bool): If True, optimize latents for eval.
        eval_latent_steps (int): Latent optimization steps for eval.
        eval_latent_lr (float): Latent optimization LR for eval.
        eval_latent_weight_decay (float): Latent optimization weight decay for eval.

    Returns:
        Tuple[float, torch.nn.Module, dict, torch.nn.Parameter]:
        (best_loss, best_model, history, latents), where best_loss and
        best_model come from the last phase's early-stopping tracker.

    Raises:
        TypeError: If the selected decoder is not a ``torch.nn.Module``.
        optuna.exceptions.TrialPruned: On a non-finite epoch loss during a
            trial, or when the pruner asks to stop in phase 3.
    """
    history: dict[str, list[float]] = {}
    final_best_loss, final_best_model = float("inf"), None

    # Prefer the gamma supplied with the run parameters (the per-trial value
    # during tuning); fall back to the instance default.
    gamma_target = self.gamma
    if params is not None and params.get("gamma") is not None:
        gamma_target = params["gamma"]
    warm, ramp, gamma_final = 50, 100, torch.tensor(gamma_target, device=self.device)

    # Guard the modulo below against eval_interval <= 0.
    eval_every = max(1, int(eval_interval))

    # Schema-aware latent cache for eval
    _latent_cache: dict = {}
    nF = getattr(model, "n_features", self.num_features_)
    cache_key_root = f"{self.prefix}_ubp_val_latents_L{nF}_K{self.num_classes_}"

    E = int(self.epochs)
    phase_epochs = {
        1: max(1, int(0.15 * E)),
        2: max(1, int(0.35 * E)),
        3: max(1, E - int(0.15 * E) - int(0.35 * E)),
    }

    for phase in (1, 2, 3):
        steps_this_phase = phase_epochs[phase]
        # LR warmup only applies to the first phase.
        warmup_epochs = getattr(self, "lr_warmup_epochs", 5) if phase == 1 else 0

        early_stopping = EarlyStopping(
            patience=self.early_stop_gen,
            min_epochs=self.min_epochs,
            verbose=self.verbose,
            prefix=self.prefix,
            debug=self.debug,
        )

        if phase == 2:
            self._reset_weights(model)

        decoder: torch.Tensor | torch.nn.Module = (
            model.phase1_decoder if phase == 1 else model.phase23_decoder
        )

        if not isinstance(decoder, torch.nn.Module):
            msg = f"{self.model_name} Decoder is not a torch.nn.Module."
            self.logger.error(msg)
            raise TypeError(msg)

        decoder_params = decoder.parameters()
        optimizer = torch.optim.AdamW(decoder_params, lr=lr, eps=1e-7)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=steps_this_phase
        )

        # Cache base LRs for warmup
        dec_lr0 = optimizer.param_groups[0]["lr"]
        lat_lr0 = latent_optimizer.param_groups[0]["lr"]
        dec_lr_min, lat_lr_min = dec_lr0 * 0.1, lat_lr0 * 0.1

        phase_hist: list[float] = []
        gamma_init = torch.tensor(0.0, device=self.device)

        for epoch in range(steps_this_phase):
            # Focal gamma warm/ramp
            if epoch < warm:
                model.gamma = gamma_init.cpu().numpy().item()
            elif epoch < warm + ramp:
                model.gamma = gamma_final * ((epoch - warm) / ramp)
            else:
                model.gamma = gamma_final

            # Linear warmup for both optimizers
            if warmup_epochs and epoch < warmup_epochs:
                scale = float(epoch + 1) / warmup_epochs
                for g in optimizer.param_groups:
                    g["lr"] = dec_lr_min + (dec_lr0 - dec_lr_min) * scale
                for g in latent_optimizer.param_groups:
                    g["lr"] = lat_lr_min + (lat_lr0 - lat_lr_min) * scale

            train_loss, latent_vectors = self._train_step(
                loader=loader,
                optimizer=optimizer,
                latent_optimizer=latent_optimizer,
                model=model,
                l1_penalty=l1_penalty,
                latent_vectors=latent_vectors,
                class_weights=class_weights,
                phase=phase,
            )

            if not np.isfinite(train_loss):
                if trial:
                    raise optuna.exceptions.TrialPruned("Epoch loss non-finite.")
                # reduce LRs and continue
                for g in optimizer.param_groups:
                    g["lr"] *= 0.5
                for g in latent_optimizer.param_groups:
                    g["lr"] *= 0.5
                continue

            scheduler.step()
            if return_history:
                phase_hist.append(train_loss)

            early_stopping(train_loss, model)
            if early_stopping.early_stop:
                self.logger.info(
                    f"Early stopping at epoch {epoch + 1} (phase {phase})."
                )
                break

            # Validation + pruning
            if (
                trial is not None
                and X_val is not None
                and ((epoch + 1) % eval_every == 0)
            ):
                metric_key = prune_metric or getattr(self, "tune_metric", "f1")
                zdim = self._first_linear_in_features(model)
                schema_key = f"{cache_key_root}_z{zdim}"
                mask_override = None
                # Score only simulated-missing cells when X_val is the test
                # split the simulation mask was built for.
                if (
                    self.simulate_missing
                    and getattr(self, "sim_mask_test_", None) is not None
                    and getattr(self, "X_test_", None) is not None
                    and X_val.shape == self.X_test_.shape
                ):
                    mask_override = self.sim_mask_test_

                # NOTE(review): the metric is computed in every phase but
                # only consumed in phase 3 below - confirm whether the
                # phase 1/2 evaluations are needed (e.g. to warm the cache).
                metric_val = self._eval_for_pruning(
                    model=model,
                    X_val=X_val,
                    params=params or getattr(self, "best_params_", {}),
                    metric=metric_key,
                    objective_mode=True,
                    do_latent_infer=eval_requires_latents,
                    latent_steps=eval_latent_steps,
                    latent_lr=eval_latent_lr,
                    latent_weight_decay=eval_latent_weight_decay,
                    latent_seed=self.seed,  # type: ignore
                    _latent_cache=_latent_cache,
                    _latent_cache_key=schema_key,
                    eval_mask_override=mask_override,
                )

                if phase == 3:
                    trial.report(metric_val, step=epoch + 1)
                    if (epoch + 1) >= prune_warmup_epochs and trial.should_prune():
                        raise optuna.exceptions.TrialPruned(
                            f"Pruned at epoch {epoch + 1} (phase {phase}): {metric_key}={metric_val:.5f}"
                        )

        history[f"Phase {phase}"] = phase_hist
        final_best_loss = early_stopping.best_score
        if early_stopping.best_model is not None:
            final_best_model = copy.deepcopy(early_stopping.best_model)
        else:
            final_best_model = copy.deepcopy(model)

    if final_best_model is None:
        final_best_model = copy.deepcopy(model)

    return final_best_loss, final_best_model, history, latent_vectors
|
|
1303
|
+
|
|
1304
|
+
def _optimize_latents_for_inference(
    self,
    X_new: np.ndarray,
    model: torch.nn.Module,
    params: dict,
    inference_epochs: int = 200,
) -> torch.Tensor:
    """Optimize latent vectors for new 0/1/2 data with numeric guards.

    A latent matrix is initialized with ``self._create_latent_space`` and
    then updated with AdamW so that ``model.phase23_decoder`` reproduces the
    observed genotypes under a cross-entropy loss. Entries equal to -1 are
    excluded from the loss via ``ignore_index=-1``. Optimization stops early
    as soon as the logits, the loss, or the latent gradient become
    non-finite; the latents reached so far are returned.

    The caller's ``X_new`` is never modified: negative values are normalized
    to -1 on a private copy.

    Args:
        X_new (np.ndarray): New 0/1/2 data with negative values for missing.
        model (torch.nn.Module): Trained model exposing ``phase23_decoder``.
        params (dict): Model params, forwarded to ``_create_latent_space``.
        inference_epochs (int): Maximum number of optimization epochs. When
            both ``self.tune`` and ``self.tune_fast`` are truthy this is
            capped at ``self.tune_infer_epochs`` (20 if the attribute is
            absent).

    Returns:
        torch.Tensor: Optimized latent vectors, detached from the graph.

    Raises:
        TypeError: If ``model.phase23_decoder`` is not a ``torch.nn.Module``.
    """
    model.eval()
    nF = getattr(model, "n_features", self.num_features_)

    if self.tune and self.tune_fast:
        inference_epochs = min(
            inference_epochs, getattr(self, "tune_infer_epochs", 20)
        )

    # The decoder does not change between epochs, so validate it once
    # up front instead of on every iteration.
    decoder = model.phase23_decoder
    if not isinstance(decoder, torch.nn.Module):
        msg = f"{self.model_name} Decoder is not a torch.nn.Module."
        self.logger.error(msg)
        raise TypeError(msg)

    # Always work on a copy: astype(copy=False) would alias an int64 input
    # and the in-place normalization below would then leak to the caller.
    X_new = X_new.astype(np.int64, copy=True)
    X_new[X_new < 0] = -1
    n_samples = len(X_new)
    y = torch.from_numpy(X_new).long().to(self.device)

    z = self._create_latent_space(
        params, n_samples, X_new, self.latent_init
    ).requires_grad_(True)
    opt = torch.optim.AdamW(
        [z], lr=self.learning_rate * self.lr_input_factor, eps=1e-7
    )

    for _ in range(inference_epochs):
        opt.zero_grad(set_to_none=True)
        logits = decoder(z).view(n_samples, nF, self.num_classes_)

        # Guard: stop rather than propagate NaN/Inf into the latents.
        if not torch.isfinite(logits).all():
            break

        loss = F.cross_entropy(
            logits.view(-1, self.num_classes_), y.view(-1), ignore_index=-1
        )

        if not torch.isfinite(loss):
            break

        loss.backward()

        torch.nn.utils.clip_grad_norm_([z], 1.0)

        # Guard: skip the update if the gradient is missing or non-finite.
        if z.grad is None or not torch.isfinite(z.grad).all():
            break

        opt.step()

    return z.detach()
|
|
1374
|
+
|
|
1375
|
+
def _create_latent_space(
    self,
    params: dict,
    n_samples: int,
    X: np.ndarray,
    latent_init: Literal["random", "pca"],
) -> torch.nn.Parameter:
    """Build the trainable latent matrix, either Xavier-random or PCA-seeded.

    With ``latent_init == "pca"`` the genotype matrix is copied to float32,
    negative entries are treated as missing and replaced by the per-column
    mean of the observed values (0.0 for columns with no observed value),
    columns are centered, and a PCA projection is standardized per latent
    dimension. If the centered matrix is non-finite or numerically all zero,
    or for any other ``latent_init`` value, a Xavier-uniform matrix is
    returned instead.

    Args:
        params (dict): Contains 'latent_dim'.
        n_samples (int): Number of samples (rows of the latent matrix).
        X (np.ndarray): (n_samples x L) 0/1/2 matrix, negative = missing.
        latent_init (Literal["random","pca"]): Init strategy.

    Returns:
        torch.nn.Parameter: Trainable latent matrix of shape
        (n_samples, latent_dim) on ``self.device``.
    """
    latent_dim = int(params["latent_dim"])

    def _xavier_latents() -> torch.nn.Parameter:
        # Shared by the plain random strategy and the PCA fallback.
        z0 = torch.empty(n_samples, latent_dim, device=self.device)
        torch.nn.init.xavier_uniform_(z0)
        return torch.nn.Parameter(z0, requires_grad=True)

    if latent_init != "pca":
        return _xavier_latents()

    # Work on a float copy so missing calls can be represented as NaN.
    geno = X.astype(np.float32, copy=True)
    geno[geno < 0] = np.nan

    # Column means over observed entries only. Dividing with `where=` leaves
    # the pre-zeroed output untouched for all-missing columns, so no
    # divide-by-zero warning is emitted and those columns impute to 0.0.
    n_observed = np.sum(~np.isnan(geno), axis=0)
    observed_sum = np.nansum(geno, axis=0)
    fill_values = np.divide(
        observed_sum,
        n_observed,
        out=np.zeros_like(observed_sum, dtype=np.float32),
        where=n_observed > 0,
    )

    rows, cols = np.where(np.isnan(geno))
    if rows.size:
        geno[rows, cols] = fill_values[cols]

    # Center every column.
    geno = geno - geno.mean(axis=0, keepdims=True)

    # Nothing for PCA to decompose (or non-finite data): use random init.
    if np.allclose(geno, 0.0) or not np.isfinite(geno).all():
        return _xavier_latents()

    n_loci = geno.shape[1]
    try:
        rank = np.linalg.matrix_rank(geno)
    except Exception:
        rank = min(n_samples, n_loci)

    # Never request more components than the data can support; at least 1.
    n_components = max(1, min(latent_dim, rank, n_samples, n_loci))

    # Randomized solver, seeded with self.seed so the projection is
    # reproducible for a fixed seed.
    pca = PCA(
        n_components=n_components,
        svd_solver="randomized",
        random_state=self.seed,
    )
    coords = pca.fit_transform(geno)  # (n_samples, n_components)

    # Fill the remaining latent dimensions with Gaussian noise.
    deficit = latent_dim - n_components
    if deficit > 0:
        noise = self.rng.standard_normal(size=(n_samples, deficit))
        coords = np.hstack([coords, noise])

    # Standardize each latent dimension (epsilon avoids division by zero).
    coords = (coords - coords.mean(axis=0)) / (coords.std(axis=0) + 1e-6)

    seeded = torch.from_numpy(coords).float().to(self.device)
    return torch.nn.Parameter(seeded, requires_grad=True)
|
|
1462
|
+
|
|
1463
|
+
def _reset_weights(self, model: torch.nn.Module) -> None:
    """Re-initialize only the parameters of the phase 2/3 decoder.

    Every submodule of ``model.phase23_decoder`` that exposes a callable
    ``reset_parameters`` has it invoked. Modules outside that decoder are
    left untouched. If the model has no ``phase23_decoder`` attribute a
    warning is logged and nothing is reset.

    Args:
        model (torch.nn.Module): The PyTorch model whose phase 2/3 decoder
            parameters are to be reset.

    Raises:
        TypeError: If ``model.phase23_decoder`` is not a ``torch.nn.Module``.
    """
    if not hasattr(model, "phase23_decoder"):
        self.logger.warning(
            "Model does not have a 'phase23_decoder' attribute; skipping weight reset."
        )
        return

    decoder = model.phase23_decoder
    if not isinstance(decoder, torch.nn.Module):
        msg = f"{self.model_name} phase23_decoder is not a torch.nn.Module."
        self.logger.error(msg)
        raise TypeError(msg)

    # Iterate through only the modules of the second decoder.
    for layer in decoder.modules():
        # `reset_parameters` is a bound method, never an nn.Module, so it
        # must be tested with callable(); an isinstance check against
        # nn.Module would silently skip every layer.
        reset = getattr(layer, "reset_parameters", None)
        if callable(reset):
            reset()
|
|
1487
|
+
|
|
1488
|
+
def _latent_infer_for_eval(
|
|
1489
|
+
self,
|
|
1490
|
+
model: torch.nn.Module,
|
|
1491
|
+
X_val: np.ndarray,
|
|
1492
|
+
*,
|
|
1493
|
+
steps: int,
|
|
1494
|
+
lr: float,
|
|
1495
|
+
weight_decay: float,
|
|
1496
|
+
seed: int,
|
|
1497
|
+
cache: dict | None,
|
|
1498
|
+
cache_key: str | None,
|
|
1499
|
+
) -> None:
|
|
1500
|
+
"""Freeze network; refine validation latents only with guards.
|
|
1501
|
+
|
|
1502
|
+
This method refines the latent vectors for the validation dataset using the trained UBP model. It freezes the model parameters to prevent updates during this phase and optimizes the latent vectors to minimize the cross-entropy loss between the model's predictions and the true genotype values. The optimization process includes numeric stability checks to ensure that gradients and losses remain finite. If a cache is provided, it stores the optimized latent vectors for future use.
|
|
1503
|
+
|
|
1504
|
+
Args:
|
|
1505
|
+
model (torch.nn.Module): Trained UBP model.
|
|
1506
|
+
X_val (np.ndarray): Validation set 0/1/2 with -1 missing
|
|
1507
|
+
steps (int): Number of optimization steps.
|
|
1508
|
+
lr (float): Learning rate for latent optimization.
|
|
1509
|
+
weight_decay (float): Weight decay for latent optimization.
|
|
1510
|
+
seed (int): Random seed for reproducibility.
|
|
1511
|
+
cache (dict | None): Optional cache for latent vectors.
|
|
1512
|
+
cache_key (str | None): Key for storing/retrieving from cache.
|
|
1513
|
+
"""
|
|
1514
|
+
if seed is None:
|
|
1515
|
+
seed = np.random.randint(0, 999_999)
|
|
1516
|
+
|
|
1517
|
+
torch.manual_seed(seed)
|
|
1518
|
+
np.random.seed(seed)
|
|
1519
|
+
|
|
1520
|
+
model.eval()
|
|
1521
|
+
for p in model.parameters():
|
|
1522
|
+
p.requires_grad_(False)
|
|
1523
|
+
|
|
1524
|
+
nF = getattr(model, "n_features", self.num_features_)
|
|
1525
|
+
X_val = X_val.astype(np.int64, copy=False)
|
|
1526
|
+
X_val[X_val < 0] = -1
|
|
1527
|
+
y = torch.from_numpy(X_val).long().to(self.device)
|
|
1528
|
+
|
|
1529
|
+
zdim = self._first_linear_in_features(model)
|
|
1530
|
+
schema_key = f"{self.prefix}_ubp_val_latents_z{zdim}_L{nF}_K{self.num_classes_}"
|
|
1531
|
+
|
|
1532
|
+
if cache is not None and schema_key in cache:
|
|
1533
|
+
z = cache[schema_key].detach().clone().requires_grad_(True)
|
|
1534
|
+
else:
|
|
1535
|
+
z = self._create_latent_space(
|
|
1536
|
+
{"latent_dim": zdim}, X_val.shape[0], X_val, self.latent_init
|
|
1537
|
+
).requires_grad_(True)
|
|
1538
|
+
|
|
1539
|
+
opt = torch.optim.AdamW([z], lr=lr, weight_decay=weight_decay, eps=1e-7)
|
|
1540
|
+
|
|
1541
|
+
for _ in range(max(int(steps), 0)):
|
|
1542
|
+
opt.zero_grad(set_to_none=True)
|
|
1543
|
+
|
|
1544
|
+
decoder: torch.Tensor | torch.nn.Module = model.phase23_decoder
|
|
1545
|
+
|
|
1546
|
+
if not isinstance(decoder, torch.nn.Module):
|
|
1547
|
+
msg = f"{self.model_name} Decoder is not a torch.nn.Module."
|
|
1548
|
+
self.logger.error(msg)
|
|
1549
|
+
raise TypeError(msg)
|
|
1550
|
+
|
|
1551
|
+
logits = decoder(z).view(X_val.shape[0], nF, self.num_classes_)
|
|
1552
|
+
if not torch.isfinite(logits).all():
|
|
1553
|
+
break
|
|
1554
|
+
|
|
1555
|
+
loss = F.cross_entropy(
|
|
1556
|
+
logits.view(-1, self.num_classes_), y.view(-1), ignore_index=-1
|
|
1557
|
+
)
|
|
1558
|
+
|
|
1559
|
+
if not torch.isfinite(loss):
|
|
1560
|
+
break
|
|
1561
|
+
|
|
1562
|
+
loss.backward()
|
|
1563
|
+
|
|
1564
|
+
torch.nn.utils.clip_grad_norm_([z], 1.0)
|
|
1565
|
+
|
|
1566
|
+
if z.grad is None or not torch.isfinite(z.grad).all():
|
|
1567
|
+
break
|
|
1568
|
+
|
|
1569
|
+
opt.step()
|
|
1570
|
+
|
|
1571
|
+
if cache is not None:
|
|
1572
|
+
cache[schema_key] = z.detach().clone()
|
|
1573
|
+
|
|
1574
|
+
for p in model.parameters():
|
|
1575
|
+
p.requires_grad_(True)
|