pg-sui 1.0.2.1__py3-none-any.whl → 1.6.8__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package as they appear in the public registry. It is provided for informational purposes only.
Potentially problematic release: this version of pg-sui has been flagged as possibly problematic.
- {pg_sui-1.0.2.1.dist-info → pg_sui-1.6.8.dist-info}/METADATA +51 -70
- pg_sui-1.6.8.dist-info/RECORD +78 -0
- {pg_sui-1.0.2.1.dist-info → pg_sui-1.6.8.dist-info}/WHEEL +1 -1
- pg_sui-1.6.8.dist-info/entry_points.txt +4 -0
- pg_sui-1.6.8.dist-info/top_level.txt +1 -0
- pgsui/__init__.py +35 -54
- pgsui/_version.py +34 -0
- pgsui/cli.py +635 -0
- pgsui/data_processing/config.py +576 -0
- pgsui/data_processing/containers.py +1782 -0
- pgsui/data_processing/transformers.py +121 -1103
- pgsui/electron/app/__main__.py +5 -0
- pgsui/electron/app/icons/icons/1024x1024.png +0 -0
- pgsui/electron/app/icons/icons/128x128.png +0 -0
- pgsui/electron/app/icons/icons/16x16.png +0 -0
- pgsui/electron/app/icons/icons/24x24.png +0 -0
- pgsui/electron/app/icons/icons/256x256.png +0 -0
- pgsui/electron/app/icons/icons/32x32.png +0 -0
- pgsui/electron/app/icons/icons/48x48.png +0 -0
- pgsui/electron/app/icons/icons/512x512.png +0 -0
- pgsui/electron/app/icons/icons/64x64.png +0 -0
- pgsui/electron/app/icons/icons/icon.icns +0 -0
- pgsui/electron/app/icons/icons/icon.ico +0 -0
- pgsui/electron/app/main.js +189 -0
- pgsui/electron/app/package-lock.json +6893 -0
- pgsui/electron/app/package.json +50 -0
- pgsui/electron/app/preload.js +15 -0
- pgsui/electron/app/server.py +146 -0
- pgsui/electron/app/ui/logo.png +0 -0
- pgsui/electron/app/ui/renderer.js +130 -0
- pgsui/electron/app/ui/styles.css +59 -0
- pgsui/electron/app/ui/ui_shim.js +72 -0
- pgsui/electron/bootstrap.py +43 -0
- pgsui/electron/launch.py +59 -0
- pgsui/electron/package.json +14 -0
- pgsui/example_data/popmaps/{test.popmap → phylogen_nomx.popmap} +185 -99
- pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz +0 -0
- pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz.tbi +0 -0
- pgsui/impute/deterministic/imputers/allele_freq.py +691 -0
- pgsui/impute/deterministic/imputers/mode.py +679 -0
- pgsui/impute/deterministic/imputers/nmf.py +221 -0
- pgsui/impute/deterministic/imputers/phylo.py +971 -0
- pgsui/impute/deterministic/imputers/ref_allele.py +530 -0
- pgsui/impute/supervised/base.py +339 -0
- pgsui/impute/supervised/imputers/hist_gradient_boosting.py +293 -0
- pgsui/impute/supervised/imputers/random_forest.py +287 -0
- pgsui/impute/unsupervised/base.py +924 -0
- pgsui/impute/unsupervised/callbacks.py +89 -263
- pgsui/impute/unsupervised/imputers/autoencoder.py +972 -0
- pgsui/impute/unsupervised/imputers/nlpca.py +1264 -0
- pgsui/impute/unsupervised/imputers/ubp.py +1288 -0
- pgsui/impute/unsupervised/imputers/vae.py +957 -0
- pgsui/impute/unsupervised/loss_functions.py +158 -0
- pgsui/impute/unsupervised/models/autoencoder_model.py +208 -558
- pgsui/impute/unsupervised/models/nlpca_model.py +149 -468
- pgsui/impute/unsupervised/models/ubp_model.py +198 -1317
- pgsui/impute/unsupervised/models/vae_model.py +259 -618
- pgsui/impute/unsupervised/nn_scorers.py +215 -0
- pgsui/utils/classification_viz.py +591 -0
- pgsui/utils/misc.py +35 -480
- pgsui/utils/plotting.py +514 -824
- pgsui/utils/scorers.py +212 -438
- pg_sui-1.0.2.1.dist-info/RECORD +0 -75
- pg_sui-1.0.2.1.dist-info/top_level.txt +0 -3
- pgsui/example_data/phylip_files/test_n10.phy +0 -118
- pgsui/example_data/phylip_files/test_n100.phy +0 -118
- pgsui/example_data/phylip_files/test_n2.phy +0 -118
- pgsui/example_data/phylip_files/test_n500.phy +0 -118
- pgsui/example_data/structure_files/test.nopops.1row.10sites.str +0 -117
- pgsui/example_data/structure_files/test.nopops.2row.100sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.10sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.30sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.allsites.str +0 -234
- pgsui/example_data/structure_files/test.pops.1row.10sites.str +0 -117
- pgsui/example_data/structure_files/test.pops.2row.10sites.str +0 -234
- pgsui/example_data/trees/test.iqtree +0 -376
- pgsui/example_data/trees/test.qmat +0 -5
- pgsui/example_data/trees/test.rate +0 -2033
- pgsui/example_data/trees/test.tre +0 -1
- pgsui/example_data/trees/test_n10.rate +0 -19
- pgsui/example_data/trees/test_n100.rate +0 -109
- pgsui/example_data/trees/test_n500.rate +0 -509
- pgsui/example_data/trees/test_siterates.txt +0 -2024
- pgsui/example_data/trees/test_siterates_n10.txt +0 -10
- pgsui/example_data/trees/test_siterates_n100.txt +0 -100
- pgsui/example_data/trees/test_siterates_n500.txt +0 -500
- pgsui/example_data/vcf_files/test.vcf +0 -244
- pgsui/example_data/vcf_files/test.vcf.gz +0 -0
- pgsui/example_data/vcf_files/test.vcf.gz.tbi +0 -0
- pgsui/impute/estimators.py +0 -735
- pgsui/impute/impute.py +0 -1486
- pgsui/impute/simple_imputers.py +0 -1439
- pgsui/impute/supervised/iterative_imputer_fixedparams.py +0 -785
- pgsui/impute/supervised/iterative_imputer_gridsearch.py +0 -1027
- pgsui/impute/unsupervised/keras_classifiers.py +0 -702
- pgsui/impute/unsupervised/models/in_development/cnn_model.py +0 -486
- pgsui/impute/unsupervised/neural_network_imputers.py +0 -1424
- pgsui/impute/unsupervised/neural_network_methods.py +0 -1549
- pgsui/pg_sui.py +0 -261
- pgsui/utils/sequence_tools.py +0 -407
- simulation/sim_benchmarks.py +0 -333
- simulation/sim_treeparams.py +0 -475
- test/__init__.py +0 -0
- test/pg_sui_simtest.py +0 -215
- test/pg_sui_testing.py +0 -523
- test/test.py +0 -297
- test/test_pgsui.py +0 -374
- test/test_tkc.py +0 -214
- {pg_sui-1.0.2.1.dist-info → pg_sui-1.6.8.dist-info/licenses}/LICENSE +0 -0
- /pgsui/{example_data/trees → electron/app}/__init__.py +0 -0
- /pgsui/impute/{unsupervised/models/in_development → supervised/imputers}/__init__.py +0 -0
- {simulation → pgsui/impute/unsupervised/imputers}/__init__.py +0 -0
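
Among the new modules is the rewritten NLPCA imputer (pgsui/impute/unsupervised/imputers/nlpca.py, shown in full below). For orientation, here is a minimal usage sketch assembled from that file's docstrings and its ensure_nlpca_config()/overrides machinery; the VCF path and the specific config values are illustrative placeholders, not values shipped with the package.

# Sketch based on the ImputeNLPCA docstring and config helpers in the diff below.
# The input file name and the config/override values are hypothetical.
from snpio import VCFReader
from pgsui import ImputeNLPCA

gdata = VCFReader("genotypes.vcf.gz")  # placeholder input VCF

# `config` may be an NLPCAConfig dataclass, a nested dict, a YAML path, or None;
# `overrides` applies dot-keyed values on top with the highest precedence.
imputer = ImputeNLPCA(
    gdata,
    config={"model": {"latent_dim": 8}, "train": {"max_epochs": 100}},
    overrides={"model.dropout_rate": 0.2},
)
imputer.fit()                  # trains the decoder and per-sample latent vectors
imputed = imputer.transform()  # imputed genotypes returned as IUPAC strings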
pgsui/impute/unsupervised/imputers/nlpca.py (new file; all 1,264 lines added)
@@ -0,0 +1,1264 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Dict, Literal, Tuple

import matplotlib.pyplot as plt
import numpy as np
import optuna
import torch
import torch.nn.functional as F
from sklearn.decomposition import PCA
from sklearn.exceptions import NotFittedError
from sklearn.model_selection import train_test_split
from snpio.analysis.genotype_encoder import GenotypeEncoder
from snpio.utils.logging import LoggerManager
from torch.optim.lr_scheduler import CosineAnnealingLR

from pgsui.data_processing.config import apply_dot_overrides, load_yaml_to_dataclass
from pgsui.data_processing.containers import NLPCAConfig
from pgsui.impute.unsupervised.base import BaseNNImputer
from pgsui.impute.unsupervised.callbacks import EarlyStopping
from pgsui.impute.unsupervised.models.nlpca_model import NLPCAModel

if TYPE_CHECKING:
    from snpio.read_input.genotype_data import GenotypeData


def ensure_nlpca_config(config: NLPCAConfig | dict | str | None) -> NLPCAConfig:
    """Return a concrete NLPCAConfig from dataclass, dict, YAML path, or None.

    Args:
        config (NLPCAConfig | dict | str | None): Structured configuration as dataclass, nested dict, YAML path, or None.

    Returns:
        NLPCAConfig: Concrete configuration instance.
    """
    if config is None:
        return NLPCAConfig()
    if isinstance(config, NLPCAConfig):
        return config
    if isinstance(config, str):
        # YAML path — top-level `preset` key is supported
        return load_yaml_to_dataclass(
            config,
            NLPCAConfig,
            preset_builder=NLPCAConfig.from_preset,
        )
    if isinstance(config, dict):
        # Flatten dict into dot-keys then overlay onto a fresh instance
        base = NLPCAConfig()

        def _flatten(prefix: str, d: dict, out: dict) -> dict:
            for k, v in d.items():
                kk = f"{prefix}.{k}" if prefix else k
                if isinstance(v, dict):
                    _flatten(kk, v, out)
                else:
                    out[kk] = v
            return out

        # Lift any present preset first
        preset_name = config.pop("preset", None)
        if "io" in config and isinstance(config["io"], dict):
            preset_name = preset_name or config["io"].pop("preset", None)

        if preset_name:
            base = NLPCAConfig.from_preset(preset_name)

        flat = _flatten("", config, {})
        return apply_dot_overrides(base, flat)

    raise TypeError("config must be an NLPCAConfig, dict, YAML path, or None.")


class ImputeNLPCA(BaseNNImputer):
    """Imputes missing genotypes using a Non-linear Principal Component Analysis (NLPCA) model.

    This class implements an imputer based on Non-linear Principal Component Analysis (NLPCA) using a neural network architecture. It is designed to handle genotype data encoded in 0/1/2 format, where 0 represents the reference allele, 1 represents the heterozygous genotype, and 2 represents the alternate allele. Missing genotypes should be represented as -9 or -1.

    The NLPCA model consists of an encoder-decoder architecture that learns a low-dimensional latent representation of the genotype data. The model is trained using a focal loss function to address class imbalance, and it can incorporate L1 regularization to promote sparsity in the learned representations.

    Notes:
        - Supports both haploid and diploid genotype data.
        - Configurable model architecture with options for latent dimension, dropout rate, number of hidden layers, and activation functions.
        - Hyperparameter tuning using Optuna for optimal model performance.
        - Evaluation metrics including accuracy, F1-score, precision, recall, and ROC-AUC.
        - Visualization of training history and genotype distributions.
        - Flexible configuration via dataclass, dictionary, or YAML file.

    Example:
        >>> from snpio import VCFReader
        >>> from pgsui import ImputeNLPCA
        >>> gdata = VCFReader("genotypes.vcf.gz")
        >>> imputer = ImputeNLPCA(gdata, config="nlpca_config.yaml")
        >>> imputer.fit()
        >>> imputed_genotypes = imputer.transform()
        >>> print(imputed_genotypes)
        [['A' 'G' 'C' ...],
        ['G' 'G' 'C' ...],
        ...
        ['T' 'C' 'A' ...],
        ['C' 'C' 'C' ...]]
    """

    def __init__(
        self,
        genotype_data: "GenotypeData",
        *,
        config: NLPCAConfig | dict | str | None = None,
        overrides: dict | None = None,
    ):
        """Initializes the ImputeNLPCA imputer with genotype data and configuration.

        This constructor sets up the ImputeNLPCA imputer by accepting genotype data and a configuration that can be provided in various formats. It initializes logging, device settings, and model parameters based on the provided configuration.

        Args:
            genotype_data (GenotypeData): Backing genotype data.
            config (NLPCAConfig | dict | str | None): Structured configuration as dataclass, nested dict, YAML path, or None.
            overrides (dict | None): Dot-key overrides (e.g. {'model.latent_dim': 4}).
        """
        self.model_name = "ImputeNLPCA"
        self.genotype_data = genotype_data

        # Normalize config first, then apply overrides (highest precedence)
        cfg = ensure_nlpca_config(config)
        if overrides:
            cfg = apply_dot_overrides(cfg, overrides)

        self.cfg = cfg

        logman = LoggerManager(
            __name__,
            prefix=self.cfg.io.prefix,
            debug=self.cfg.io.debug,
            verbose=self.cfg.io.verbose,
        )
        self.logger = logman.get_logger()

        # Initialize BaseNNImputer with device/dirs/logging from config
        super().__init__(
            prefix=self.cfg.io.prefix,
            device=self.cfg.train.device,
            verbose=self.cfg.io.verbose,
            debug=self.cfg.io.debug,
        )

        self.Model = NLPCAModel
        self.pgenc = GenotypeEncoder(genotype_data)
        self.seed = self.cfg.io.seed
        self.n_jobs = self.cfg.io.n_jobs
        self.prefix = self.cfg.io.prefix
        self.scoring_averaging = self.cfg.io.scoring_averaging
        self.verbose = self.cfg.io.verbose
        self.debug = self.cfg.io.debug

        self.rng = np.random.default_rng(self.seed)

        # Model/train hyperparams
        self.latent_dim = self.cfg.model.latent_dim
        self.dropout_rate = self.cfg.model.dropout_rate
        self.num_hidden_layers = self.cfg.model.num_hidden_layers
        self.layer_scaling_factor = self.cfg.model.layer_scaling_factor
        self.layer_schedule = self.cfg.model.layer_schedule
        self.latent_init = self.cfg.model.latent_init
        self.activation = self.cfg.model.hidden_activation
        self.gamma = self.cfg.model.gamma

        self.batch_size = self.cfg.train.batch_size
        self.learning_rate = self.cfg.train.learning_rate
        self.lr_input_factor = self.cfg.train.lr_input_factor
        self.l1_penalty = self.cfg.train.l1_penalty
        self.early_stop_gen = self.cfg.train.early_stop_gen
        self.min_epochs = self.cfg.train.min_epochs
        self.epochs = self.cfg.train.max_epochs
        self.validation_split = self.cfg.train.validation_split
        self.beta = self.cfg.train.weights_beta
        self.max_ratio = self.cfg.train.weights_max_ratio

        # Tuning
        self.tune = self.cfg.tune.enabled
        self.tune_fast = self.cfg.tune.fast
        self.tune_proxy_metric_batch = self.cfg.tune.proxy_metric_batch
        self.tune_batch_size = self.cfg.tune.batch_size
        self.tune_epochs = self.cfg.tune.epochs
        self.tune_eval_interval = self.cfg.tune.eval_interval
        self.tune_metric = self.cfg.tune.metric
        self.n_trials = self.cfg.tune.n_trials
        self.tune_save_db = self.cfg.tune.save_db
        self.tune_resume = self.cfg.tune.resume
        self.tune_max_samples = self.cfg.tune.max_samples
        self.tune_max_loci = self.cfg.tune.max_loci
        self.tune_infer_epochs = getattr(self.cfg.tune, "infer_epochs", 100)
        self.tune_patience = self.cfg.tune.patience

        # Eval
        self.eval_latent_steps = self.cfg.evaluate.eval_latent_steps
        self.eval_latent_lr = self.cfg.evaluate.eval_latent_lr
        self.eval_latent_weight_decay = self.cfg.evaluate.eval_latent_weight_decay

        # Plotting (note: PlotConfig has 'show', not 'show_plots')
        self.plot_format = self.cfg.plot.fmt
        self.plot_dpi = self.cfg.plot.dpi
        self.plot_fontsize = self.cfg.plot.fontsize
        self.title_fontsize = self.cfg.plot.fontsize
        self.despine = self.cfg.plot.despine
        self.show_plots = self.cfg.plot.show

        # Core model config
        self.is_haploid = None
        self.num_classes_ = None
        self.model_params: Dict[str, Any] = {}

    def fit(self) -> "ImputeNLPCA":
        """Fits the NLPCA model to the 0/1/2 encoded genotype data.

        This method prepares the data, splits it into training and validation sets, initializes the model, and trains it. If hyperparameter tuning is enabled, it will perform tuning before final training. After training, it evaluates the model on a test set and generates relevant plots.

        Returns:
            ImputeNLPCA: The fitted imputer instance.
        """
        self.logger.info(f"Fitting {self.model_name} model...")

        # --- DATA PREPARATION ---
        X = self.pgenc.genotypes_012.astype(np.float32)
        X[X < 0] = np.nan  # Ensure missing are NaN
        X[np.isnan(X)] = -1  # Use -1 for missing, required by loss function
        self.ground_truth_ = X.astype(np.int64)

        # --- Determine Ploidy and Number of Classes ---
        self.is_haploid = np.all(
            np.isin(
                self.genotype_data.snp_data, ["A", "C", "G", "T", "N", "-", ".", "?"]
            )
        )

        self.ploidy = 1 if self.is_haploid else 2

        if self.is_haploid:
            self.num_classes_ = 2

            # Remap labels from {0, 2} to {0, 1}
            self.ground_truth_[self.ground_truth_ == 2] = 1
            self.logger.info("Haploid data detected. Using 2 classes (REF=0, ALT=1).")
        else:
            self.num_classes_ = 3

            self.logger.info(
                "Diploid data detected. Using 3 classes (REF=0, HET=1, ALT=2)."
            )

        n_samples, self.num_features_ = X.shape

        self.model_params = {
            "n_features": self.num_features_,
            "latent_dim": self.latent_dim,
            "dropout_rate": self.dropout_rate,
            "activation": self.activation,
            "gamma": self.gamma,
            "num_classes": self.num_classes_,
        }

        # --- Train/Test Split ---
        indices = np.arange(n_samples)
        train_idx, test_idx = train_test_split(
            indices, test_size=self.validation_split, random_state=self.seed
        )
        self.train_idx_, self.test_idx_ = train_idx, test_idx
        self.X_train_ = self.ground_truth_[train_idx]
        self.X_test_ = self.ground_truth_[test_idx]

        # --- Tuning & Model Setup ---
        self.plotter_, self.scorers_ = self.initialize_plotting_and_scorers()

        if self.tune:
            self.tune_hyperparameters()
            self.best_params_ = getattr(self, "best_params_", self.model_params.copy())
        else:
            self.best_params_ = self._set_best_params_default()

        # Class weights from 0/1/2 training data
        self.class_weights_ = self._class_weights_from_zygosity(self.X_train_)

        # Latent vectors for training set
        train_latent_vectors = self._create_latent_space(
            self.best_params_, len(self.X_train_), self.X_train_, self.latent_init
        )

        train_loader = self._get_data_loaders(self.X_train_)

        # Train the final model
        (self.best_loss_, self.model_, self.history_, self.train_latent_vectors_) = (
            self._train_final_model(
                train_loader, self.best_params_, train_latent_vectors
            )
        )

        self.is_fit_ = True
        self.plotter_.plot_history(self.history_)
        self._evaluate_model(self.X_test_, self.model_, self.best_params_)
        self._save_best_params(self.best_params_)
        return self

    def transform(self) -> np.ndarray:
        """Imputes missing genotypes using the trained model.

        This method uses the trained NLPCA model to impute missing genotypes in the entire dataset. It optimizes latent vectors for all samples, predicts missing values, and fills them in. The imputed genotypes are returned in IUPAC string format.

        Returns:
            np.ndarray: Imputed genotypes in IUPAC string format.

        Raises:
            NotFittedError: If the model has not been fitted.
        """
        if not getattr(self, "is_fit_", False):
            raise NotFittedError("Model is not fitted. Call fit() before transform().")

        self.logger.info("Imputing entire dataset...")
        X_to_impute = self.ground_truth_.copy()

        # Optimize latents for the full dataset
        optimized_latents = self._optimize_latents_for_inference(
            X_to_impute, self.model_, self.best_params_
        )

        # Predict missing values
        pred_labels, _ = self._predict(self.model_, latent_vectors=optimized_latents)

        # Fill in missing values
        missing_mask = X_to_impute == -1
        imputed_array = X_to_impute.copy()
        imputed_array[missing_mask] = pred_labels[missing_mask]

        # Decode back to IUPAC strings
        imputed_genotypes = self.pgenc.decode_012(imputed_array)
        original_genotypes = self.pgenc.decode_012(X_to_impute)

        # Plot distributions
        plt.rcParams.update(self.plotter_.param_dict)  # Ensure consistent style
        self.plotter_.plot_gt_distribution(original_genotypes, is_imputed=False)
        self.plotter_.plot_gt_distribution(imputed_genotypes, is_imputed=True)

        return imputed_genotypes

    def _train_step(
        self,
        loader: torch.utils.data.DataLoader,
        optimizer: torch.optim.Optimizer,
        latent_optimizer: torch.optim.Optimizer,
        model: torch.nn.Module,
        l1_penalty: float,
        latent_vectors: torch.nn.Parameter,
        class_weights: torch.Tensor,
    ) -> Tuple[float, torch.nn.Parameter]:
        """Performs one epoch of training.

        This method executes a single training epoch for the NLPCA model. It processes batches of data, computes the focal loss while handling missing values, applies L1 regularization if specified, and updates both the model parameters and latent vectors using their respective optimizers.

        Args:
            loader (torch.utils.data.DataLoader): DataLoader for training data.
            optimizer (torch.optim.Optimizer): Optimizer for model parameters.
            latent_optimizer (torch.optim.Optimizer): Optimizer for latent vectors.
            model (torch.nn.Module): The NLPCA model.
            l1_penalty (float): L1 regularization penalty.
            latent_vectors (torch.nn.Parameter): Latent vectors for samples.
            class_weights (torch.Tensor): Class weights for handling class imbalance.

        Returns:
            Tuple[float, torch.nn.Parameter]: Average training loss and updated latent vectors.
        """
        model.train()
        running_loss = 0.0

        nF = getattr(model, "n_features", self.num_features_)

        for batch_indices, y_batch in loader:
            optimizer.zero_grad(set_to_none=True)
            latent_optimizer.zero_grad(set_to_none=True)

            logits = model.phase23_decoder(latent_vectors[batch_indices]).view(
                len(batch_indices), nF, self.num_classes_
            )

            # --- Simplified Focal Loss on 0/1/2 Classes ---
            logits_flat = logits.view(-1, self.num_classes_)
            targets_flat = y_batch.view(-1)

            ce_loss = F.cross_entropy(
                logits_flat,
                targets_flat,
                weight=class_weights,
                reduction="none",
                ignore_index=-1,
            )

            pt = torch.exp(-ce_loss)
            gamma = getattr(model, "gamma", self.gamma)
            focal_loss = ((1 - pt) ** gamma) * ce_loss

            valid_mask = targets_flat != -1
            loss = focal_loss[valid_mask].mean() if valid_mask.any() else 0.0

            if l1_penalty > 0:
                loss += l1_penalty * sum(p.abs().sum() for p in model.parameters())

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            torch.nn.utils.clip_grad_norm_([latent_vectors], max_norm=1.0)
            optimizer.step()
            latent_optimizer.step()

            running_loss += loss.item()

        return running_loss / len(loader), latent_vectors

    def _predict(
        self, model: torch.nn.Module, latent_vectors: torch.nn.Parameter | None = None
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Generates 0/1/2 predictions from latent vectors.

        This method uses the trained NLPCA model to generate predictions from the latent vectors by passing them through the decoder. It returns both the predicted labels and their associated probabilities.

        Args:
            model (torch.nn.Module): Trained NLPCA model.
            latent_vectors (torch.nn.Parameter | None): Latent vectors for samples.

        Returns:
            Tuple[np.ndarray, np.ndarray]: Predicted labels and probabilities.
        """
        if model is None or latent_vectors is None:
            raise NotFittedError("Model or latent vectors not available.")

        model.eval()

        nF = getattr(model, "n_features", self.num_features_)

        with torch.no_grad():
            logits = model.phase23_decoder(latent_vectors.to(self.device)).view(
                len(latent_vectors), nF, self.num_classes_
            )
            probas = torch.softmax(logits, dim=-1)
            labels = torch.argmax(probas, dim=-1)

        return labels.cpu().numpy(), probas.cpu().numpy()

    def _evaluate_model(
        self,
        X_val: np.ndarray,
        model: torch.nn.Module,
        params: dict,
        objective_mode: bool = False,
        latent_vectors_val: torch.Tensor | None = None,
    ) -> Dict[str, float]:
        """Evaluates the model on a validation set.

        This method evaluates the trained NLPCA model on a validation dataset by optimizing latent vectors for the validation samples, predicting genotypes, and computing various performance metrics. It can operate in an objective mode that suppresses logging for automated evaluations.

        Args:
            X_val (np.ndarray): Validation data in 0/1/2 encoding with -1 for missing.
            model (torch.nn.Module): Trained NLPCA model.
            params (dict): Model parameters.
            objective_mode (bool): If True, suppresses logging and reports only the metric.
            latent_vectors_val (torch.Tensor | None): Pre-optimized latent vectors for validation data.

        Returns:
            Dict[str, float]: Dictionary of evaluation metrics.
        """
        if latent_vectors_val is not None:
            test_latent_vectors = latent_vectors_val
        else:
            test_latent_vectors = self._optimize_latents_for_inference(
                X_val, model, params
            )

        # The rest of the function remains the same...
        pred_labels, pred_probas = self._predict(
            model=model, latent_vectors=test_latent_vectors
        )

        eval_mask = X_val != -1
        y_true_flat = X_val[eval_mask]
        pred_labels_flat = pred_labels[eval_mask]
        pred_probas_flat = pred_probas[eval_mask]

        if y_true_flat.size == 0:
            return {self.tune_metric: 0.0}

        # For haploids, remap class 2 to 1 for scoring (e.g., f1-score)
        labels_for_scoring = [0, 1] if self.is_haploid else [0, 1, 2]
        target_names = ["REF", "ALT"] if self.is_haploid else ["REF", "HET", "ALT"]

        y_true_ohe = np.eye(len(labels_for_scoring))[y_true_flat]

        metrics = self.scorers_.evaluate(
            y_true_flat,
            pred_labels_flat,
            y_true_ohe,
            pred_probas_flat,
            objective_mode,
            self.tune_metric,
        )

        if not objective_mode:
            self.logger.info(f"Validation Metrics: {metrics}")

            self._make_class_reports(
                y_true=y_true_flat,
                y_pred_proba=pred_probas_flat,
                y_pred=pred_labels_flat,
                metrics=metrics,
                labels=target_names,
            )

            y_true_dec = self.pgenc.decode_012(X_val)
            X_pred = X_val.copy()
            X_pred[eval_mask] = pred_labels_flat

            nF_eval = X_val.shape[1]
            y_pred_dec = self.pgenc.decode_012(X_pred.reshape(X_val.shape[0], nF_eval))

            encodings_dict = {
                "A": 0,
                "C": 1,
                "G": 2,
                "T": 3,
                "W": 4,
                "R": 5,
                "M": 6,
                "K": 7,
                "Y": 8,
                "S": 9,
                "N": -1,
            }

            y_true_int = self.pgenc.convert_int_iupac(
                y_true_dec, encodings_dict=encodings_dict
            )
            y_pred_int = self.pgenc.convert_int_iupac(
                y_pred_dec, encodings_dict=encodings_dict
            )

            self._make_class_reports(
                y_true=y_true_int[eval_mask],
                y_pred=y_pred_int[eval_mask],
                metrics=metrics,
                y_pred_proba=None,
                labels=["A", "C", "G", "T", "W", "R", "M", "K", "Y", "S"],
            )

        return metrics

    def _get_data_loaders(self, y: np.ndarray) -> torch.utils.data.DataLoader:
        """Creates a PyTorch DataLoader for the 0/1/2 encoded data.

        This method constructs a DataLoader from the provided genotype data, which is expected to be in 0/1/2 encoding with -1 for missing values. The DataLoader is used for batching and shuffling the data during model training. It converts the numpy array to a PyTorch tensor and creates a TensorDataset. The DataLoader is configured with the specified batch size and shuffling enabled.

        Args:
            y (np.ndarray): 0/1/2 encoded genotype data with -1 for missing.

        Returns:
            torch.utils.data.DataLoader: DataLoader for the dataset.
        """
        y_tensor = torch.from_numpy(y).long().to(self.device)
        dataset = torch.utils.data.TensorDataset(
            torch.arange(len(y), device=self.device), y_tensor.to(self.device)
        )
        return torch.utils.data.DataLoader(
            dataset, batch_size=self.batch_size, shuffle=True
        )

    def _create_latent_space(
        self,
        params: dict,
        n_samples: int,
        X: np.ndarray,
        latent_init: Literal["random", "pca"],
    ) -> torch.nn.Parameter:
        """Initializes the latent space for the NLPCA model.

        This method initializes the latent space for the NLPCA model based on the specified initialization method. It supports two methods: 'random' initialization using Xavier uniform distribution, and 'pca' initialization which uses PCA to derive initial latent vectors from the data. The latent vectors are returned as a PyTorch Parameter, allowing them to be optimized during training.

        Args:
            params (dict): Model parameters including 'latent_dim'.
            n_samples (int): Number of samples in the dataset.
            X (np.ndarray): 0/1/2 encoded genotype data with -1 for missing.
            latent_init (str): Method to initialize latent space ('random' or 'pca').

        Returns:
            torch.nn.Parameter: Initialized latent vectors as a PyTorch Parameter.
        """
        latent_dim = int(params["latent_dim"])

        if latent_init == "pca":
            X_pca = X.astype(np.float32, copy=True)
            # mark missing
            X_pca[X_pca < 0] = np.nan

            # ---- SAFE column means without warnings ----
            valid_counts = np.sum(~np.isnan(X_pca), axis=0)
            col_sums = np.nansum(X_pca, axis=0)
            col_means = np.divide(
                col_sums,
                valid_counts,
                out=np.zeros_like(col_sums, dtype=np.float32),
                where=valid_counts > 0,
            )

            # impute NaNs with per-column means
            # (all-NaN cols -> 0.0 by the divide above)
            nan_r, nan_c = np.where(np.isnan(X_pca))
            if nan_r.size:
                X_pca[nan_r, nan_c] = col_means[nan_c]

            # center columns
            X_pca = X_pca - X_pca.mean(axis=0, keepdims=True)

            # guard: degenerate / all-zero after centering ->
            # fall back to random
            if (not np.isfinite(X_pca).all()) or np.allclose(X_pca, 0.0):
                latents = torch.empty(n_samples, latent_dim, device=self.device)
                torch.nn.init.xavier_uniform_(latents)
                return torch.nn.Parameter(latents, requires_grad=True)

            # rank-aware component count, at least 1
            try:
                est_rank = np.linalg.matrix_rank(X_pca)
            except Exception:
                est_rank = min(n_samples, X_pca.shape[1])

            n_components = max(1, min(latent_dim, est_rank, n_samples, X_pca.shape[1]))

            # use deterministic SVD to avoid power-iteration warnings
            pca = PCA(
                n_components=n_components, svd_solver="full", random_state=self.seed
            )
            initial = pca.fit_transform(X_pca)  # (n_samples, n_components)

            # pad if latent_dim > n_components
            if n_components < latent_dim:
                pad = self.rng.standard_normal(
                    size=(n_samples, latent_dim - n_components)
                )
                initial = np.hstack([initial, pad])

            # standardize latent dims
            initial = (initial - initial.mean(axis=0)) / (initial.std(axis=0) + 1e-6)

            latents = torch.from_numpy(initial).float().to(self.device)
            return torch.nn.Parameter(latents, requires_grad=True)

        # --- Random init path (unchanged) ---
        latents = torch.empty(n_samples, latent_dim, device=self.device)
        torch.nn.init.xavier_uniform_(latents)
        return torch.nn.Parameter(latents, requires_grad=True)

    def _objective(self, trial: optuna.Trial) -> float:
        """Objective function for hyperparameter tuning with Optuna.

        This method defines the objective function used by Optuna for hyperparameter tuning of the NLPCA model. It samples a set of hyperparameters, prepares the training and validation data, initializes the model and latent vectors, and trains the model. After training, it evaluates the model on a validation set and returns the value of the specified tuning metric.

        Args:
            trial (optuna.Trial): An Optuna trial object for hyperparameter suggestions.

        Returns:
            float: The value of the tuning metric to be minimized or maximized.
        """
        self._prepare_tuning_artifacts()
        trial_params = self._sample_hyperparameters(trial)
        model_params = dict(trial_params["model_params"])

        if self.tune and self.tune_fast:
            model_params["n_features"] = self._tune_num_features

        lr = trial_params["lr"]
        l1_penalty = trial_params["l1_penalty"]
        lr_input_fac = trial_params["lr_input_factor"]

        X_train_trial = self._tune_X_train
        X_test_trial = self._tune_X_test
        class_weights = self._tune_class_weights
        train_loader = self._tune_loader

        train_latents = self._create_latent_space(
            model_params,
            len(X_train_trial),
            X_train_trial,
            trial_params["latent_init"],
        )

        model = self.build_model(self.Model, model_params)
        model.n_features = model_params["n_features"]
        model.apply(self.initialize_weights)

        # train; pass an explicit flag that we are in tuning + whether to fix latents
        _, model, _ = self._train_and_validate_model(
            model=model,
            loader=train_loader,
            lr=lr,
            l1_penalty=l1_penalty,
            trial=trial,
            latent_vectors=train_latents,
            lr_input_factor=lr_input_fac,
            class_weights=class_weights,
            X_val=X_test_trial,
            params=model_params,
            prune_metric=self.tune_metric,
            prune_warmup_epochs=5,
            eval_interval=self.tune_eval_interval,
            eval_latent_steps=0,
            eval_latent_lr=0.0,
            eval_latent_weight_decay=0.0,
        )

        metrics = self._evaluate_model(
            X_test_trial, model, model_params, objective_mode=True
        )

        self._clear_resources(model, train_loader, latent_vectors=train_latents)
        return metrics[self.tune_metric]

    def _sample_hyperparameters(
        self, trial: optuna.Trial
    ) -> Dict[str, int | float | str | list]:
        """Samples hyperparameters for the simplified NLPCA model.

        This method defines the hyperparameter search space for the NLPCA model and samples a set of hyperparameters using the provided Optuna trial object. It computes the hidden layer sizes based on the sampled parameters and prepares the model parameters dictionary.

        Args:
            trial (optuna.Trial): An Optuna trial object for hyperparameter suggestions.

        Returns:
            Dict[str, int | float | str | list]: A dictionary of sampled hyperparameters.
        """
        params = {
            "latent_dim": trial.suggest_int("latent_dim", 2, 32),
            "lr": trial.suggest_float("learning_rate", 1e-5, 1e-3, log=True),
            "dropout_rate": trial.suggest_float("dropout_rate", 0.0, 0.5, step=0.05),
            "num_hidden_layers": trial.suggest_int("num_hidden_layers", 1, 16),
            "activation": trial.suggest_categorical(
                "activation", ["relu", "elu", "selu", "leaky_relu"]
            ),
            "gamma": trial.suggest_float("gamma", 0.1, 5.0, step=0.1),
            "lr_input_factor": trial.suggest_float(
                "lr_input_factor", 0.1, 10.0, log=True
            ),
            "l1_penalty": trial.suggest_float("l1_penalty", 1e-6, 1e-3, log=True),
            "layer_scaling_factor": trial.suggest_float(
                "layer_scaling_factor", 2.0, 10.0
            ),
            "layer_schedule": trial.suggest_categorical(
                "layer_schedule", ["pyramid", "constant", "linear"]
            ),
            "latent_init": trial.suggest_categorical("latent_init", ["random", "pca"]),
        }

        use_n_features = (
            self._tune_num_features
            if (self.tune and self.tune_fast and hasattr(self, "_tune_num_features"))
            else self.num_features_
        )
        use_n_samples = (
            len(self._tune_train_idx)
            if (self.tune and self.tune_fast and hasattr(self, "_tune_train_idx"))
            else len(self.train_idx_)
        )

        hidden_layer_sizes = self._compute_hidden_layer_sizes(
            n_inputs=params["latent_dim"],
            n_outputs=use_n_features * self.num_classes_,
            n_samples=use_n_samples,
            n_hidden=params["num_hidden_layers"],
            alpha=params["layer_scaling_factor"],
            schedule=params["layer_schedule"],
        )

        params["model_params"] = {
            "n_features": use_n_features,
            "num_classes": self.num_classes_,
            "latent_dim": params["latent_dim"],
            "dropout_rate": params["dropout_rate"],
            "hidden_layer_sizes": hidden_layer_sizes,
            "activation": params["activation"],
            "gamma": params["gamma"],
        }

        return params

    def _set_best_params(
        self, best_params: Dict[str, int | float | str | list]
    ) -> Dict[str, int | float | str | list]:
        """Sets the best hyperparameters found during tuning.

        This method updates the model's attributes with the best hyperparameters obtained from tuning. It also computes the hidden layer sizes based on these parameters and prepares the final model parameters dictionary.

        Args:
            best_params (dict): Best hyperparameters from tuning.

        Returns:
            Dict[str, int | float | str | list]: Model parameters configured with the best hyperparameters.
        """
        self.latent_dim = best_params["latent_dim"]
        self.dropout_rate = best_params["dropout_rate"]
        self.learning_rate = best_params["learning_rate"]
        self.gamma = best_params["gamma"]
        self.lr_input_factor = best_params["lr_input_factor"]
        self.l1_penalty = best_params["l1_penalty"]
        self.activation = best_params["activation"]

        hidden_layer_sizes = self._compute_hidden_layer_sizes(
            n_inputs=self.latent_dim,
            n_outputs=self.num_features_ * self.num_classes_,
            n_samples=len(self.train_idx_),
            n_hidden=best_params["num_hidden_layers"],
            alpha=best_params["layer_scaling_factor"],
            schedule=best_params["layer_schedule"],
        )

        return {
            "n_features": self.num_features_,
            "latent_dim": self.latent_dim,
            "hidden_layer_sizes": hidden_layer_sizes,
            "dropout_rate": self.dropout_rate,
            "activation": self.activation,
            "gamma": self.gamma,
            "num_classes": self.num_classes_,
        }

    def _set_best_params_default(self) -> Dict[str, int | float | str | list]:
        """Default (no-tuning) model_params aligned with current attributes.

        This method constructs the model parameters dictionary using the current instance attributes of the ImputeUBP class. It computes the sizes of the hidden layers based on the instance's latent dimension, dropout rate, learning rate, and other relevant attributes. The method returns a dictionary containing the model parameters that can be used to build the UBP model when no hyperparameter tuning has been performed.

        Returns:
            Dict[str, int | float | str | list]: model_params payload.
        """
        hidden_layer_sizes = self._compute_hidden_layer_sizes(
            n_inputs=self.latent_dim,
            n_outputs=self.num_features_ * self.num_classes_,
            n_samples=len(self.ground_truth_),
            n_hidden=self.num_hidden_layers,
            alpha=self.layer_scaling_factor,
            schedule=self.layer_schedule,
        )

        return {
            "n_features": self.num_features_,
            "latent_dim": self.latent_dim,
            "hidden_layer_sizes": hidden_layer_sizes,
            "dropout_rate": self.dropout_rate,
            "activation": self.activation,
            "gamma": self.gamma,
            "num_classes": self.num_classes_,
        }

    def _train_and_validate_model(
        self,
        model: torch.nn.Module,
        loader: torch.utils.data.DataLoader,
        lr: float,
        l1_penalty: float,
        trial: optuna.Trial | None = None,
        return_history: bool = False,
        latent_vectors: torch.nn.Parameter | None = None,
        lr_input_factor: float = 1.0,
        class_weights: torch.Tensor | None = None,
        *,
        X_val: np.ndarray | None = None,
        params: dict | None = None,
        prune_metric: str | None = None,  # "f1" | "accuracy" | "pr_macro"
        prune_warmup_epochs: int = 3,
        eval_interval: int = 1,
        eval_latent_steps: int = 50,
        eval_latent_lr: float = 1e-2,
        eval_latent_weight_decay: float = 0.0,
    ) -> Tuple:
        """Trains and validates the NLPCA model.

        This method trains the provided NLPCA model using the specified training data and hyperparameters. It supports optional integration with Optuna for hyperparameter tuning and pruning based on validation performance. The method initializes optimizers for both the model parameters and latent vectors, sets up a learning rate scheduler, and executes the training loop. It can return the training history if requested.

        Args:
            model (torch.nn.Module): The NLPCA model to be trained.
            loader (torch.utils.data.DataLoader): DataLoader for training data.
            lr (float): Learning rate for the model optimizer.
            l1_penalty (float): L1 regularization penalty.
            trial (optuna.Trial | None): Optuna trial for hyperparameter tuning.
            return_history (bool): Whether to return training history.
            latent_vectors (torch.nn.Parameter | None): Latent vectors for samples.
            lr_input_factor (float): Learning rate factor for latent vectors.
            class_weights (torch.Tensor | None): Class weights for handling class imbalance.
            X_val (np.ndarray | None): Validation data for pruning.
            params (dict | None): Model parameters.
            prune_metric (str | None): Metric for pruning decisions.
            prune_warmup_epochs (int): Number of epochs before pruning starts.
            eval_interval (int): Interval (in epochs) for evaluation during training.
            eval_latent_steps (int): Steps for latent optimization during evaluation.
            eval_latent_lr (float): Learning rate for latent optimization during evaluation.
            eval_latent_weight_decay (float): Weight decay for latent optimization during evaluation.

        Returns:
            Tuple[float, torch.nn.Module, Dict[str, float], torch.nn.Parameter] | Tuple[float, torch.nn.Module, torch.nn.Parameter]: Training loss, trained model, training history (if requested), and optimized latent vectors.

        Raises:
            TypeError: If latent_vectors or class_weights are not provided.
        """

        if latent_vectors is None or class_weights is None:
            msg = "latent_vectors and class_weights must be provided."
            self.logger.error(msg)
            raise TypeError("Must provide latent_vectors and class_weights.")

        latent_optimizer = torch.optim.Adam([latent_vectors], lr=lr * lr_input_factor)

        optimizer = torch.optim.Adam(model.phase23_decoder.parameters(), lr=lr)
        scheduler = CosineAnnealingLR(optimizer, T_max=self.epochs)

        result = self._execute_training_loop(
            loader=loader,
            optimizer=optimizer,
            latent_optimizer=latent_optimizer,
            scheduler=scheduler,
            model=model,
            l1_penalty=l1_penalty,
            return_history=return_history,
            latent_vectors=latent_vectors,
            class_weights=class_weights,
            trial=trial,
            X_val=X_val,
            params=params,
            prune_metric=prune_metric,
            prune_warmup_epochs=prune_warmup_epochs,
            eval_interval=eval_interval,
            eval_latent_steps=eval_latent_steps,
            eval_latent_lr=eval_latent_lr,
            eval_latent_weight_decay=eval_latent_weight_decay,
        )

        if return_history:
            return result

        return result[0], result[1], result[3]

    def _train_final_model(
        self,
        loader: torch.utils.data.DataLoader,
        best_params: dict,
        initial_latent_vectors: torch.nn.Parameter,
    ) -> Tuple[float, torch.nn.Module, dict, torch.nn.Parameter]:
        """Trains the final model using the best hyperparameters.

        This method builds and trains the final NLPCA model using the best hyperparameters obtained from tuning. It initializes the model weights, trains the model on the entire training set, and saves the trained model to disk. It returns the final training loss, trained model, training history, and optimized latent vectors.

        Args:
            loader (torch.utils.data.DataLoader): DataLoader for training data.
            best_params (dict): Best hyperparameters for the model.
            initial_latent_vectors (torch.nn.Parameter): Initial latent vectors for samples.

        Returns:
            Tuple[float, torch.nn.Module, dict, torch.nn.Parameter]: Final training loss, trained model, training history, and optimized latent vectors.
        Raises:
            RuntimeError: If model training fails.
        """
        self.logger.info(f"Training the final model...")

        model = self.build_model(self.Model, best_params)
        model.n_features = best_params["n_features"]
        model.apply(self.initialize_weights)

        loss, trained_model, history, latent_vectors = self._train_and_validate_model(
            model=model,
            loader=loader,
            lr=self.learning_rate,
            l1_penalty=self.l1_penalty,
            return_history=True,
            latent_vectors=initial_latent_vectors,
            lr_input_factor=self.lr_input_factor,
            class_weights=self.class_weights_,
            X_val=self.X_test_,
            params=best_params,
            prune_metric=self.tune_metric,
            prune_warmup_epochs=5,
            eval_interval=1,
            eval_latent_steps=50,
            eval_latent_lr=self.learning_rate * self.lr_input_factor,
            eval_latent_weight_decay=0.0,
        )

        if trained_model is None:
            msg = "Final model training failed."
            self.logger.error(msg)
            raise RuntimeError(msg)

        fn = self.models_dir / "final_model.pt"
        torch.save(trained_model.state_dict(), fn)

        return (loss, trained_model, {"Train": history}, latent_vectors)

    def _execute_training_loop(
        self,
        loader,
        optimizer,
        latent_optimizer,
        scheduler,
        model,
        l1_penalty,
        return_history,
        latent_vectors,
        class_weights,
        *,
        trial: optuna.Trial | None = None,
        X_val: np.ndarray | None = None,
        params: dict | None = None,
        prune_metric: str | None = None,
        prune_warmup_epochs: int = 3,
        eval_interval: int = 1,
        eval_latent_steps: int = 50,
        eval_latent_lr: float = 1e-2,
        eval_latent_weight_decay: float = 0.0,
    ) -> Tuple[float, torch.nn.Module, list, torch.nn.Parameter]:
        """Executes the training loop with optional Optuna pruning.

        This method runs the training loop for the NLPCA model, performing multiple epochs of training. It supports optional integration with Optuna for hyperparameter tuning and pruning based on validation performance. The method tracks training history, applies early stopping, and returns the best model and training metrics.

        Args:
            loader (torch.utils.data.DataLoader): DataLoader for training data.
            optimizer (torch.optim.Optimizer): Optimizer for model parameters.
            latent_optimizer (torch.optim.Optimizer): Optimizer for latent vectors.
            scheduler (torch.optim.lr_scheduler._LRScheduler): Learning rate scheduler.
            model (torch.nn.Module): The NLPCA model.
            l1_penalty (float): L1 regularization penalty.
            return_history (bool): Whether to return training history.
            latent_vectors (torch.nn.Parameter): Latent vectors for samples.
            class_weights (torch.Tensor): Class weights for
                handling class imbalance.
            trial (optuna.Trial | None): Optuna trial for hyperparameter tuning.
            X_val (np.ndarray | None): Validation data for pruning.
            params (dict | None): Model parameters.
            prune_metric (str | None): Metric to monitor for pruning.
            prune_warmup_epochs (int): Epochs to wait before pruning.
            eval_interval (int): Epoch interval for evaluation.
            eval_latent_steps (int): Steps to refine latents during eval.
            eval_latent_lr (float): Learning rate for latent refinement during eval.
            eval_latent_weight_decay (float): Weight decay for latent refinement during eval.

        Returns:
            Tuple[float, torch.nn.Module, list, torch.nn.Parameter]: Best loss, best model, training history, and optimized latent vectors.

        Raises:
            optuna.exceptions.TrialPruned: If the trial is pruned based on validation performance.
        """
        best_model = None
        train_history = []
        early_stopping = EarlyStopping(
            patience=self.early_stop_gen,
            min_epochs=self.min_epochs,
            verbose=self.verbose,
            prefix=self.prefix,
            debug=self.debug,
        )

        # compute the epoch budget used by the loop
        max_epochs = (
            self.tune_epochs if (trial is not None and self.tune_fast) else self.epochs
        )
        scheduler = CosineAnnealingLR(optimizer, T_max=max_epochs)

        # just above the for-epoch loop
        _latent_cache: dict = {}
        _latent_cache_key = f"{self.prefix}_{self.model_name}_val_latents"

        for epoch in range(max_epochs):
            train_loss, latent_vectors = self._train_step(
                loader,
                optimizer,
                latent_optimizer,
                model,
                l1_penalty,
                latent_vectors,
                class_weights,
            )
            scheduler.step()

            if np.isnan(train_loss) or np.isinf(train_loss):
                raise optuna.exceptions.TrialPruned("Loss is NaN or Inf.")

            if return_history:
                train_history.append(train_loss)

            if (
                trial is not None
                and X_val is not None
                and ((epoch + 1) % eval_interval == 0)
            ):
                metric_key = prune_metric or getattr(self, "tune_metric", "f1")

                do_infer = (eval_latent_steps or 0) > 0
                metric_val = self._eval_for_pruning(
                    model=model,
                    X_val=X_val,
                    params=params or getattr(self, "best_params_", {}),
                    metric=metric_key,
                    objective_mode=True,
                    do_latent_infer=do_infer,
                    latent_steps=eval_latent_steps,
                    latent_lr=eval_latent_lr,
                    latent_weight_decay=eval_latent_weight_decay,
                    latent_seed=(self.seed if self.seed is not None else None),
                    _latent_cache=_latent_cache,
                    _latent_cache_key=_latent_cache_key,
                )

                trial.report(metric_val, step=epoch + 1)
                if (epoch + 1) >= prune_warmup_epochs and trial.should_prune():
                    raise optuna.exceptions.TrialPruned(
                        f"Pruned at epoch {epoch + 1}: {metric_key}={metric_val:.3f}"
                    )

            early_stopping(train_loss, model)
            if early_stopping.early_stop:
                self.logger.info(f"Early stopping at epoch {epoch + 1}.")
                break

        best_loss = early_stopping.best_score
        best_model = model  # reuse instance
        best_model.load_state_dict(early_stopping.best_model.state_dict())
        return best_loss, best_model, train_history, latent_vectors

    def _optimize_latents_for_inference(
        self,
        X_new: np.ndarray,
        model: torch.nn.Module,
        params: dict,
        inference_epochs: int = 200,
    ) -> torch.Tensor:
        """Optimizes latent vectors for new, unseen data.

        This method optimizes latent vectors for new data samples that were not part of the training set. It initializes latent vectors and performs gradient-based optimization to minimize the reconstruction loss using the trained NLPCA model. The optimized latent vectors are returned for further predictions.

        Args:
            X_new (np.ndarray): New data in 0/1/2 encoding with -1 for missing values.
            model (torch.nn.Module): Trained NLPCA model.
            params (dict): Model parameters.
            inference_epochs (int): Number of epochs to optimize latent vectors.

        Returns:
            torch.Tensor: Optimized latent vectors for the new data.
        """
        if self.tune and self.tune_fast:
            inference_epochs = min(
                inference_epochs, getattr(self, "tune_infer_epochs", 20)
            )

        model.eval()

        nF = getattr(model, "n_features", self.num_features_)

        new_latent_vectors = self._create_latent_space(
            params, len(X_new), X_new, self.latent_init
        )
        latent_optimizer = torch.optim.Adam(
            [new_latent_vectors], lr=self.learning_rate * self.lr_input_factor
        )
        y_target = torch.from_numpy(X_new).long().to(self.device)

        for _ in range(inference_epochs):
            latent_optimizer.zero_grad()
            logits = model.phase23_decoder(new_latent_vectors).view(
                len(X_new), nF, self.num_classes_
            )
            loss = F.cross_entropy(
                logits.view(-1, self.num_classes_), y_target.view(-1), ignore_index=-1
            )
            if torch.isnan(loss):
                self.logger.warning("Inference loss is NaN; stopping.")
                break
            loss.backward()
            latent_optimizer.step()

        return new_latent_vectors.detach()

    def _latent_infer_for_eval(
        self,
        model: torch.nn.Module,
        X_val: np.ndarray,
        *,
        steps: int,
        lr: float,
        weight_decay: float,
        seed: int,
        cache: dict | None,
        cache_key: str | None,
    ) -> None:
        """Freeze weights; refine validation latents only (no leakage).

        This method refines latent vectors for validation data by optimizing them while keeping the model weights frozen. It uses gradient-based optimization to minimize the reconstruction loss on the validation set. The optimized latent vectors can be cached for future use.

        Args:
            model (torch.nn.Module): Trained NLPCA model.
            X_val (np.ndarray): Validation data in 0/1/2 encoding with - 1 for missing.
            steps (int): Number of optimization steps for latent refinement.
            lr (float): Learning rate for latent optimization.
            weight_decay (float): Weight decay for latent optimization.
            seed (int): Random seed for reproducibility.
            cache (dict | None): Cache for storing optimized latents.
            cache_key (str | None): Key for storing/retrieving latents in/from cache

        Returns:
            None. Updates cache in place if provided.
        """
        if seed is None:
            seed = np.random.randint(0, 999999)

        torch.manual_seed(seed)
        np.random.seed(seed)

        model.eval()

        nF = getattr(model, "n_features", self.num_features_)

        for p in model.parameters():
            p.requires_grad_(False)

        X_val = X_val.astype(np.int64, copy=False)
        X_val[X_val < 0] = -1
        y_target = torch.from_numpy(X_val).long().to(self.device)

        # Get latent_dim from the *model actually being evaluated*
        latent_dim_model = self._first_linear_in_features(model)

        # Make a cache key that is specific to this latent size (and feature schema)
        cache_key = (
            f"{self.prefix}_nlpca_val_latents_"
            f"z{latent_dim_model}_L{self.num_features_}_K{self.num_classes_}"
        )

        # Warm-start from cache if available *and* shape-compatible
        if cache is not None and cache_key in cache:
            val_latents = cache[cache_key].detach().clone().requires_grad_(True)
        else:
            val_latents = self._create_latent_space(
                {"latent_dim": latent_dim_model},  # use model's latent size
                n_samples=X_val.shape[0],
                X=X_val,
                latent_init=self.latent_init,
            ).requires_grad_(True)

        opt = torch.optim.AdamW([val_latents], lr=lr, weight_decay=weight_decay)

        for _ in range(max(int(steps), 0)):
            opt.zero_grad(set_to_none=True)
            logits = model.phase23_decoder(val_latents).view(
                X_val.shape[0], nF, self.num_classes_
            )
            loss = F.cross_entropy(
                logits.view(-1, self.num_classes_),
                y_target.view(-1),
                ignore_index=-1,
                reduction="mean",
            )
            loss.backward()
            opt.step()

        if cache is not None:
            cache[cache_key] = val_latents.detach().clone()

        for p in model.parameters():
            p.requires_grad_(True)