pg-sui 1.6.14.dev9__py3-none-any.whl → 1.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pg_sui-1.7.0.dist-info/METADATA +288 -0
- {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/RECORD +29 -33
- pgsui/__init__.py +0 -8
- pgsui/_version.py +2 -2
- pgsui/cli.py +591 -126
- pgsui/data_processing/config.py +1 -2
- pgsui/data_processing/containers.py +218 -533
- pgsui/data_processing/transformers.py +44 -20
- pgsui/impute/deterministic/imputers/mode.py +475 -182
- pgsui/impute/deterministic/imputers/ref_allele.py +454 -147
- pgsui/impute/supervised/imputers/hist_gradient_boosting.py +4 -3
- pgsui/impute/supervised/imputers/random_forest.py +3 -2
- pgsui/impute/unsupervised/base.py +1268 -530
- pgsui/impute/unsupervised/callbacks.py +28 -33
- pgsui/impute/unsupervised/imputers/autoencoder.py +869 -764
- pgsui/impute/unsupervised/imputers/vae.py +928 -696
- pgsui/impute/unsupervised/loss_functions.py +156 -202
- pgsui/impute/unsupervised/models/autoencoder_model.py +7 -49
- pgsui/impute/unsupervised/models/vae_model.py +40 -221
- pgsui/impute/unsupervised/nn_scorers.py +53 -13
- pgsui/utils/classification_viz.py +240 -97
- pgsui/utils/misc.py +201 -3
- pgsui/utils/plotting.py +73 -58
- pgsui/utils/pretty_metrics.py +2 -6
- pgsui/utils/scorers.py +39 -0
- pg_sui-1.6.14.dev9.dist-info/METADATA +0 -344
- pgsui/impute/unsupervised/imputers/nlpca.py +0 -1554
- pgsui/impute/unsupervised/imputers/ubp.py +0 -1575
- pgsui/impute/unsupervised/models/nlpca_model.py +0 -206
- pgsui/impute/unsupervised/models/ubp_model.py +0 -200
- {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/WHEEL +0 -0
- {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/entry_points.txt +0 -0
- {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/licenses/LICENSE +0 -0
- {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/top_level.txt +0 -0
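The file-level additions and removals listed above can be checked locally against the two wheels. Below is a minimal sketch, assuming both wheels have already been downloaded into the working directory under the hypothetical filenames shown; it only compares the archives' member lists with the standard-library `zipfile` module and does not reproduce the line-level diff that follows.

import zipfile

# Hypothetical local filenames; adjust to wherever the wheels were downloaded.
OLD_WHEEL = "pg_sui-1.6.14.dev9-py3-none-any.whl"
NEW_WHEEL = "pg_sui-1.7.0-py3-none-any.whl"

def member_names(path: str) -> set:
    """Return the set of file paths stored inside a wheel (a zip archive)."""
    with zipfile.ZipFile(path) as zf:
        return set(zf.namelist())

old_members = member_names(OLD_WHEEL)
new_members = member_names(NEW_WHEEL)

# Files present in only one of the two releases.
print("Removed:", sorted(old_members - new_members))
print("Added:  ", sorted(new_members - old_members))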
@@ -1,1554 +0,0 @@
-from __future__ import annotations
-
-import copy
-from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Tuple
-
-import matplotlib.pyplot as plt
-import numpy as np
-import optuna
-import torch
-import torch.nn.functional as F
-from sklearn.decomposition import PCA
-from sklearn.exceptions import NotFittedError
-from sklearn.model_selection import train_test_split
-from snpio.analysis.genotype_encoder import GenotypeEncoder
-from snpio.utils.logging import LoggerManager
-from torch.optim.lr_scheduler import CosineAnnealingLR
-
-from pgsui.data_processing.config import apply_dot_overrides, load_yaml_to_dataclass
-from pgsui.data_processing.containers import NLPCAConfig
-from pgsui.data_processing.transformers import SimMissingTransformer
-from pgsui.impute.unsupervised.base import BaseNNImputer
-from pgsui.impute.unsupervised.callbacks import EarlyStopping
-from pgsui.impute.unsupervised.loss_functions import SafeFocalCELoss
-from pgsui.impute.unsupervised.models.nlpca_model import NLPCAModel
-from pgsui.utils.logging_utils import configure_logger
-from pgsui.utils.pretty_metrics import PrettyMetrics
-
-if TYPE_CHECKING:
-    from snpio import TreeParser
-    from snpio.read_input.genotype_data import GenotypeData
-
-
-def ensure_nlpca_config(config: NLPCAConfig | dict | str | None) -> NLPCAConfig:
-    """Return a concrete NLPCAConfig from dataclass, dict, YAML path, or None.
-
-    Args:
-        config (NLPCAConfig | dict | str | None): Structured configuration as dataclass, nested dict, YAML path, or None.
-
-    Returns:
-        NLPCAConfig: Concrete configuration instance.
-    """
-    if config is None:
-        return NLPCAConfig()
-    if isinstance(config, NLPCAConfig):
-        return config
-    if isinstance(config, str):
-        # YAML path — top-level `preset` key is supported
-        return load_yaml_to_dataclass(config, NLPCAConfig)
-    if isinstance(config, dict):
-        # Flatten dict into dot-keys then overlay onto a fresh instance
-        base = NLPCAConfig()
-
-        def _flatten(prefix: str, d: dict, out: dict) -> dict:
-            for k, v in d.items():
-                kk = f"{prefix}.{k}" if prefix else k
-                if isinstance(v, dict):
-                    _flatten(kk, v, out)
-                else:
-                    out[kk] = v
-            return out
-
-        # Lift any present preset first
-        preset_name = config.pop("preset", None)
-        if "io" in config and isinstance(config["io"], dict):
-            preset_name = preset_name or config["io"].pop("preset", None)
-
-        if preset_name:
-            base = NLPCAConfig.from_preset(preset_name)
-
-        flat = _flatten("", config, {})
-        return apply_dot_overrides(base, flat)
-
-    raise TypeError("config must be an NLPCAConfig, dict, YAML path, or None.")
-
-
-class ImputeNLPCA(BaseNNImputer):
-    """Imputes missing genotypes using a Non-linear Principal Component Analysis (NLPCA) model.
-
-    This class implements an imputer based on Non-linear Principal Component Analysis (NLPCA) using a neural network architecture. It is designed to handle genotype data encoded in 0/1/2 format, where 0 represents the reference allele, 1 represents the heterozygous genotype, and 2 represents the alternate allele. Missing genotypes should be represented as -9 or -1.
-
-    The NLPCA model consists of an encoder-decoder architecture that learns a low-dimensional latent representation of the genotype data. The model is trained using a focal loss function to address class imbalance, and it can incorporate L1 regularization to promote sparsity in the learned representations.
-
-    Notes:
-        - Supports both haploid and diploid genotype data.
-        - Configurable model architecture with options for latent dimension, dropout rate, number of hidden layers, and activation functions.
-        - Hyperparameter tuning using Optuna for optimal model performance.
-        - Evaluation metrics including accuracy, F1-score, precision, recall, and ROC-AUC.
-        - Visualization of training history and genotype distributions.
-        - Flexible configuration via dataclass, dictionary, or YAML file.
-
-    Example:
-        >>> from snpio import VCFReader
-        >>> from pgsui import ImputeNLPCA
-        >>> gdata = VCFReader("genotypes.vcf.gz")
-        >>> imputer = ImputeNLPCA(gdata, config="nlpca_config.yaml")
-        >>> imputer.fit()
-        >>> imputed_genotypes = imputer.transform()
-        >>> print(imputed_genotypes)
-        [['A' 'G' 'C' ...],
-        ['G' 'G' 'C' ...],
-        ...
-        ['T' 'C' 'A' ...],
-        ['C' 'C' 'C' ...]]
-    """
-
-    def __init__(
-        self,
-        genotype_data: "GenotypeData",
-        *,
-        tree_parser: Optional["TreeParser"] = None,
-        config: NLPCAConfig | dict | str | None = None,
-        overrides: dict | None = None,
-        simulate_missing: bool = False,
-        sim_strategy: Literal[
-            "random",
-            "random_weighted",
-            "random_weighted_inv",
-            "nonrandom",
-            "nonrandom_weighted",
-        ] = "random",
-        sim_prop: float = 0.10,
-        sim_kwargs: dict | None = None,
-    ):
-        """Initializes the ImputeNLPCA imputer with genotype data and configuration.
-
-        This constructor sets up the ImputeNLPCA imputer by accepting genotype data and a configuration that can be provided in various formats. It initializes logging, device settings, and model parameters based on the provided configuration.
-
-        Args:
-            genotype_data (GenotypeData): Backing genotype data.
-            tree_parser (TreeParser | None): Optional SNPio phylogenetic tree parser for population-specific modes.
-            config (NLPCAConfig | dict | str | None): Structured configuration as dataclass, nested dict, YAML path, or None.
-            overrides (dict | None): Dot-key overrides (e.g. {'model.latent_dim': 4}).
-            simulate_missing (bool): Whether to simulate missing data during training.
-            sim_strategy (Literal["random", "random_weighted", "random_weighted_inv", "nonrandom", "nonrandom_weighted"]): Strategy for simulating missing data.
-            sim_prop (float): Proportion of data to simulate as missing.
-            sim_kwargs (dict | None): Additional keyword arguments for missing data simulation.
-        """
-        self.model_name = "ImputeNLPCA"
-        self.genotype_data = genotype_data
-        self.tree_parser = tree_parser
-
-        # Normalize config first, then apply overrides (highest precedence)
-        cfg = ensure_nlpca_config(config)
-        if overrides:
-            cfg = apply_dot_overrides(cfg, overrides)
-
-        self.cfg = cfg
-
-        logman = LoggerManager(
-            __name__,
-            prefix=self.cfg.io.prefix,
-            debug=self.cfg.io.debug,
-            verbose=self.cfg.io.verbose,
-        )
-        self.logger = configure_logger(
-            logman.get_logger(),
-            verbose=self.cfg.io.verbose,
-            debug=self.cfg.io.debug,
-        )
-
-        # Initialize BaseNNImputer with device/dirs/logging from config
-        super().__init__(
-            model_name=self.model_name,
-            genotype_data=self.genotype_data,
-            prefix=self.cfg.io.prefix,
-            device=self.cfg.train.device,
-            verbose=self.cfg.io.verbose,
-            debug=self.cfg.io.debug,
-        )
-
-        self.Model = NLPCAModel
-        self.pgenc = GenotypeEncoder(genotype_data)
-        self.seed = self.cfg.io.seed
-        self.n_jobs = self.cfg.io.n_jobs
-        self.prefix = self.cfg.io.prefix
-        self.scoring_averaging = self.cfg.io.scoring_averaging
-        self.verbose = self.cfg.io.verbose
-        self.debug = self.cfg.io.debug
-
-        self.rng = np.random.default_rng(self.seed)
-
-        # Model/train hyperparams
-        self.latent_dim = self.cfg.model.latent_dim
-        self.dropout_rate = self.cfg.model.dropout_rate
-        self.num_hidden_layers = self.cfg.model.num_hidden_layers
-        self.layer_scaling_factor = self.cfg.model.layer_scaling_factor
-        self.layer_schedule = self.cfg.model.layer_schedule
-        self.latent_init: Literal["random", "pca"] = self.cfg.model.latent_init
-        self.activation = self.cfg.model.hidden_activation
-        self.gamma = self.cfg.model.gamma
-
-        self.batch_size = self.cfg.train.batch_size
-        self.learning_rate: float = self.cfg.train.learning_rate
-        self.lr_input_factor = self.cfg.train.lr_input_factor
-        self.l1_penalty = self.cfg.train.l1_penalty
-        self.early_stop_gen = self.cfg.train.early_stop_gen
-        self.min_epochs = self.cfg.train.min_epochs
-        self.epochs = self.cfg.train.max_epochs
-        self.validation_split = self.cfg.train.validation_split
-        self.beta = self.cfg.train.weights_beta
-        self.max_ratio = self.cfg.train.weights_max_ratio
-
-        # Tuning
-        self.tune = self.cfg.tune.enabled
-        self.tune_fast = self.cfg.tune.fast
-        self.tune_proxy_metric_batch = self.cfg.tune.proxy_metric_batch
-        self.tune_batch_size = self.cfg.tune.batch_size
-        self.tune_epochs = self.cfg.tune.epochs
-        self.tune_eval_interval = self.cfg.tune.eval_interval
-        self.tune_metric: Literal[
-            "pr_macro",
-            "f1",
-            "accuracy",
-            "average_precision",
-            "precision",
-            "recall",
-            "roc_auc",
-        ] = self.cfg.tune.metric
-        self.n_trials = self.cfg.tune.n_trials
-        self.tune_save_db = self.cfg.tune.save_db
-        self.tune_resume = self.cfg.tune.resume
-        self.tune_max_samples = self.cfg.tune.max_samples
-        self.tune_max_loci = self.cfg.tune.max_loci
-        self.tune_infer_epochs = getattr(self.cfg.tune, "infer_epochs", 100)
-        self.tune_patience = self.cfg.tune.patience
-
-        # Eval
-        self.eval_latent_steps = self.cfg.evaluate.eval_latent_steps
-        self.eval_latent_lr = self.cfg.evaluate.eval_latent_lr
-        self.eval_latent_weight_decay = self.cfg.evaluate.eval_latent_weight_decay
-
-        # Plotting (NOTE: PlotConfig has 'show', not 'show_plots')
-        self.plot_format = self.cfg.plot.fmt
-        self.plot_dpi = self.cfg.plot.dpi
-        self.plot_fontsize = self.cfg.plot.fontsize
-        self.title_fontsize = self.cfg.plot.fontsize
-        self.despine = self.cfg.plot.despine
-        self.show_plots = self.cfg.plot.show
-
-        # Core model config
-        self.is_haploid = False
-        self.num_classes_ = 3
-        self.model_params: Dict[str, Any] = {}
-
-        self.simulate_missing = simulate_missing
-        self.sim_strategy = sim_strategy
-        self.sim_prop = float(sim_prop)
-        self.sim_kwargs = sim_kwargs or {}
-
-        if self.tree_parser is None and self.sim_strategy.startswith("nonrandom"):
-            msg = "tree_parser is required for nonrandom and nonrandom_weighted simulated missing strategies."
-            self.logger.error(msg)
-            raise ValueError(msg)
-
-    def fit(self) -> "ImputeNLPCA":
-        """Fits the NLPCA model to the 0/1/2 encoded genotype data.
-
-        This method prepares the data, splits it into training and validation sets, initializes the model, and trains it. If hyperparameter tuning is enabled, it will perform tuning before final training. After training, it evaluates the model on a test set and generates relevant plots.
-
-        Returns:
-            ImputeNLPCA: The fitted imputer instance.
-        """
-        self.logger.info(f"Fitting {self.model_name} model...")
-
-        # --- BASE MATRIX AND GROUND TRUTH ---
-        X012 = self.pgenc.genotypes_012.astype(np.float32)
-        X012[X012 < 0] = np.nan  # NaN = original missing
-
-        # Keep an immutable ground-truth copy in 0/1/2 with -1 for original
-        # missing
-        GT_full = X012.copy()
-        GT_full[np.isnan(GT_full)] = -1
-        self.ground_truth_ = GT_full.astype(np.int64)
-
-        # --- OPTIONAL SIMULATED MISSING VIA SimMissingTransformer ---
-        self.sim_mask_global_ = None
-        if self.simulate_missing:
-            tr = SimMissingTransformer(
-                genotype_data=self.genotype_data,
-                tree_parser=self.tree_parser,
-                prop_missing=self.sim_prop,
-                strategy=self.sim_strategy,
-                missing_val=-9,
-                mask_missing=True,
-                verbose=self.verbose,
-                tol=None,
-                max_tries=None,
-            )
-            # NOTE: pass NaN-coded missing; transformer handles NaNs correctly
-            tr.fit(X012.copy())
-
-            # Store boolean mask of simulated positions only (excludes original-missing)
-            self.sim_mask_global_ = tr.sim_missing_mask_.astype(bool)
-
-            # Apply simulation to the model’s input copy: encode as -1 for loss
-            X_for_model = self.ground_truth_.copy()
-            X_for_model[self.sim_mask_global_] = -1
-        else:
-            X_for_model = self.ground_truth_.copy()
-
-        # --- Determine Ploidy and Number of Classes ---
-        self.is_haploid = bool(
-            np.all(
-                np.isin(
-                    self.genotype_data.snp_data,
-                    ["A", "C", "G", "T", "N", "-", ".", "?"],
-                )
-            )
-        )
-
-        self.ploidy = 1 if self.is_haploid else 2
-
-        if self.is_haploid:
-            self.num_classes_ = 2
-
-            # Remap labels from {0, 2} to {0, 1}
-            self.ground_truth_[self.ground_truth_ == 2] = 1
-            X_for_model[X_for_model == 2] = 1  # <- add this line
-            self.logger.info("Haploid data detected. Using 2 classes (REF=0, ALT=1).")
-        else:
-            self.num_classes_ = 3
-
-            self.logger.info(
-                "Diploid data detected. Using 3 classes (REF=0, HET=1, ALT=2)."
-            )
-
-        n_samples, self.num_features_ = X_for_model.shape
-
-        self.model_params = {
-            "n_features": self.num_features_,
-            "latent_dim": self.latent_dim,
-            "dropout_rate": self.dropout_rate,
-            "activation": self.activation,
-            "gamma": self.gamma,
-            "num_classes": self.num_classes_,
-        }
-
-        # --- Train/Test Split ---
-        indices = np.arange(n_samples)
-        train_idx, test_idx = train_test_split(
-            indices, test_size=self.validation_split, random_state=self.seed
-        )
-        self.train_idx_, self.test_idx_ = train_idx, test_idx
-        # Subset matrices for training/eval
-        self.X_train_ = X_for_model[train_idx]
-        self.X_test_ = X_for_model[test_idx]
-        self.GT_train_full_ = self.ground_truth_[train_idx]  # pre-mask truth
-        self.GT_test_full_ = self.ground_truth_[test_idx]
-
-        # Slice the simulation mask by split if present
-        if self.sim_mask_global_ is not None:
-            self.sim_mask_train_ = self.sim_mask_global_[train_idx]
-            self.sim_mask_test_ = self.sim_mask_global_[test_idx]
-        else:
-            self.sim_mask_train_ = None
-            self.sim_mask_test_ = None
-
-        # Tuning, model setup, training (unchanged except DataLoader input)
-        self.plotter_, self.scorers_ = self.initialize_plotting_and_scorers()
-
-        if self.tune:
-            self.tune_hyperparameters()
-            self.best_params_ = getattr(self, "best_params_", self.model_params.copy())
-        else:
-            self.best_params_ = self._set_best_params_default()
-
-        # Class weights from 0/1/2 training data
-        self.class_weights_ = self._class_weights_from_zygosity(self.X_train_)
-
-        if not self.latent_init in {"random", "pca"} and isinstance(
-            self.latent_init, str
-        ):
-            msg = (
-                f"Invalid latent_init '{self.latent_init}'; must be 'random' or 'pca'."
-            )
-            self.logger.error(msg)
-            raise ValueError(msg)
-
-        li: Literal["random", "pca"] = self.latent_init
-
-        # Latent vectors for training set
-        self.class_weights_ = self._class_weights_from_zygosity(self.X_train_)
-        train_latent_vectors = self._create_latent_space(
-            self.best_params_, len(self.X_train_), self.X_train_, li
-        )
-        train_loader = self._get_data_loaders(self.X_train_)
-
-        # Train the final model
-        (self.best_loss_, self.model_, self.history_, self.train_latent_vectors_) = (
-            self._train_final_model(
-                train_loader, self.best_params_, train_latent_vectors
-            )
-        )
-
-        self.is_fit_ = True
-        self.plotter_.plot_history(self.history_)
-
-        if self.sim_mask_test_ is not None:
-            # Evaluate exactly on simulated-missing sites
-            self.logger.info("Evaluating on simulated-missing positions only.")
-            self._evaluate_model(
-                self.X_test_,
-                self.model_,
-                self.best_params_,
-                eval_mask_override=self.sim_mask_test_,
-            )
-        else:
-            self._evaluate_model(self.X_test_, self.model_, self.best_params_)
-
-        self._save_best_params(self.best_params_)
-
-        return self
-
-    def transform(self) -> np.ndarray:
-        """Imputes missing genotypes using the trained model.
-
-        This method uses the trained NLPCA model to impute missing genotypes in the entire dataset. It optimizes latent vectors for all samples, predicts missing values, and fills them in. The imputed genotypes are returned in IUPAC string format.
-
-        Returns:
-            np.ndarray: Imputed genotypes in IUPAC string format.
-
-        Raises:
-            NotFittedError: If the model has not been fitted.
-        """
-        if not getattr(self, "is_fit_", False):
-            raise NotFittedError("Model is not fitted. Call fit() before transform().")
-
-        self.logger.info(f"Imputing entire dataset with {self.model_name}...")
-        X_to_impute = self.ground_truth_.copy()
-
-        # Optimize latents for the full dataset
-        optimized_latents = self._optimize_latents_for_inference(
-            X_to_impute, self.model_, self.best_params_
-        )
-
-        # Predict missing values
-        pred_labels, _ = self._predict(self.model_, latent_vectors=optimized_latents)
-
-        # Fill in missing values
-        missing_mask = X_to_impute == -1
-        imputed_array = X_to_impute.copy()
-        imputed_array[missing_mask] = pred_labels[missing_mask]
-
-        # Decode back to IUPAC strings
-        imputed_genotypes = self.pgenc.decode_012(imputed_array)
-        if self.show_plots:
-            original_genotypes = self.pgenc.decode_012(X_to_impute)
-            plt.rcParams.update(self.plotter_.param_dict)  # Ensure consistent style
-            self.plotter_.plot_gt_distribution(original_genotypes, is_imputed=False)
-            self.plotter_.plot_gt_distribution(imputed_genotypes, is_imputed=True)
-
-        return imputed_genotypes
-
-    def _train_step(
-        self,
-        loader: torch.utils.data.DataLoader,
-        optimizer: torch.optim.Optimizer,
-        latent_optimizer: torch.optim.Optimizer,
-        model: torch.nn.Module,
-        l1_penalty: float,
-        latent_vectors: torch.nn.Parameter,
-        class_weights: torch.Tensor,
-    ) -> Tuple[float, torch.nn.Parameter]:
-        """One epoch with stable focal CE, latent+weight updates, and NaN guards.
-
-        Args:
-            loader (torch.utils.data.DataLoader): DataLoader for training data.
-            optimizer (torch.optim.Optimizer): Optimizer for model parameters.
-            latent_optimizer (torch.optim.Optimizer): Optimizer for latent vectors.
-            model (torch.nn.Module): NLPCA model.
-            l1_penalty (float): L1 regularization penalty.
-            latent_vectors (torch.nn.Parameter): Latent vectors for samples.
-            class_weights (torch.Tensor): Class weights for focal loss.
-
-        Returns:
-            Tuple[float, torch.nn.Parameter]: Average loss and updated latent vectors.
-
-        Notes:
-            - Implements focal cross-entropy loss with class weights.
-            - Applies L1 regularization on model weights.
-            - Includes guards against NaN/infinite values in logits, loss, and gradients.
-        """
-        model.train()
-        running = 0.0
-        used = 0
-
-        # Ensure latent vectors are trainable
-        if not isinstance(latent_vectors, torch.nn.Parameter):
-            latent_vectors = torch.nn.Parameter(latent_vectors, requires_grad=True)
-
-        # Bound gamma to a sane range
-        gamma = float(getattr(model, "gamma", getattr(self, "gamma", 0.0)))
-        gamma = max(0.0, min(gamma, 10.0))
-
-        # Normalize class weights to mean≈1 to keep loss scale stable
-        if class_weights is not None:
-            cw = class_weights.to(self.device)
-            cw = cw / cw.mean().clamp_min(1e-8)
-        else:
-            cw = None
-
-        nF = getattr(model, "n_features", self.num_features_)
-
-        criterion = SafeFocalCELoss(gamma=gamma, weight=cw, ignore_index=-1)
-
-        for batch_indices, y_batch in loader:
-            optimizer.zero_grad(set_to_none=True)
-            latent_optimizer.zero_grad(set_to_none=True)
-
-            # Targets
-            y_batch = y_batch.to(self.device, non_blocking=True).long()
-
-            decoder: torch.Tensor | torch.nn.Module = model.phase23_decoder
-
-            if not isinstance(decoder, torch.nn.Module):
-                msg = "Model decoder is not a valid torch.nn.Module."
-                self.logger.error(msg)
-                raise TypeError(msg)
-
-            # Forward
-            z = latent_vectors[batch_indices].to(self.device)
-            logits = decoder(z).view(len(batch_indices), nF, self.num_classes_)
-
-            # Guard upstream explosions
-            if not torch.isfinite(logits).all():
-                # Skip batch if model already produced non-finite values
-                continue
-
-            logits_flat = logits.view(-1, self.num_classes_)
-            targets_flat = y_batch.view(-1)
-
-            loss = criterion(logits_flat, targets_flat)
-
-            # L1 on model weights only (exclude latents)
-            if l1_penalty > 0:
-                l1 = torch.stack(
-                    [p.abs().sum() for p in model.parameters() if p.requires_grad]
-                ).sum()
-                loss = loss + l1_penalty * l1
-
-            if not torch.isfinite(loss):
-                # Skip pathological batch
-                continue
-
-            loss.backward()
-
-            # Clip both parameter sets to keep grads finite
-            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
-            torch.nn.utils.clip_grad_norm_([latent_vectors], max_norm=1.0)
-
-            # If any grad is non-finite, skip updates
-            bad_grad = False
-            for p in model.parameters():
-                if p.grad is not None and not torch.isfinite(p.grad).all():
-                    bad_grad = True
-                    break
-            if (
-                not bad_grad
-                and latent_vectors.grad is not None
-                and not torch.isfinite(latent_vectors.grad).all()
-            ):
-                bad_grad = True
-            if bad_grad:
-                optimizer.zero_grad(set_to_none=True)
-                latent_optimizer.zero_grad(set_to_none=True)
-                continue
-
-            optimizer.step()
-            latent_optimizer.step()
-
-            running += float(loss.detach().item())
-            used += 1
-
-        if used == 0:
-            # Signal upstream that no safe batches were used
-            return float("inf"), latent_vectors
-
-        return running / used, latent_vectors
-
-    def _predict(
-        self, model: torch.nn.Module, latent_vectors: torch.Tensor | None = None
-    ) -> Tuple[np.ndarray, np.ndarray]:
-        """Generates 0/1/2 predictions from latent vectors.
-
-        This method uses the trained NLPCA model to generate predictions from the latent vectors by passing them through the decoder. It returns both the predicted labels and their associated probabilities.
-
-        Args:
-            model (torch.nn.Module): Trained NLPCA model.
-            latent_vectors (torch.nn.Parameter | None): Latent vectors for samples.
-
-        Returns:
-            Tuple[np.ndarray, np.ndarray]: Predicted labels and probabilities.
-        """
-        if model is None or latent_vectors is None:
-            raise NotFittedError("Model or latent vectors not available.")
-
-        model.eval()
-
-        nF = getattr(model, "n_features", self.num_features_)
-
-        if not isinstance(model.phase23_decoder, torch.nn.Module):
-            msg = f"{self.model_name} decoder is not a valid torch.nn.Module."
-            self.logger.error(msg)
-            raise TypeError(msg)
-
-        with torch.no_grad():
-            logits = model.phase23_decoder(latent_vectors.to(self.device)).view(
-                len(latent_vectors), nF, self.num_classes_
-            )
-            probas = torch.softmax(logits, dim=-1)
-            labels = torch.argmax(probas, dim=-1)
-
-        return labels.cpu().numpy(), probas.cpu().numpy()
-
-    def _evaluate_model(
-        self,
-        X_val: np.ndarray,
-        model: torch.nn.Module,
-        params: dict,
-        objective_mode: bool = False,
-        latent_vectors_val: torch.Tensor | None = None,
-        *,
-        eval_mask_override: np.ndarray | None = None,
-    ) -> Dict[str, float]:
-        """Evaluates the model on a validation set.
-
-        This method evaluates the trained NLPCA model on a validation dataset by optimizing latent vectors for the validation samples, predicting genotypes, and computing various performance metrics. It can operate in an objective mode that suppresses logging for automated evaluations.
-
-        Args:
-            X_val (np.ndarray): Validation data in 0/1/2 encoding with -1 for missing.
-            model (torch.nn.Module): Trained NLPCA model.
-            params (dict): Model parameters.
-            objective_mode (bool): If True, suppresses logging and reports only the metric.
-            latent_vectors_val (torch.Tensor | None): Pre-optimized latent vectors for validation data.
-            eval_mask_override (np.ndarray | None): Boolean mask to specify which entries to evaluate.
-
-        Returns:
-            Dict[str, float]: Dictionary of evaluation metrics.
-        """
-        if latent_vectors_val is not None:
-            test_latent_vectors = latent_vectors_val
-        else:
-            test_latent_vectors = self._optimize_latents_for_inference(
-                X_val, model, params
-            )
-
-        pred_labels, pred_probas = self._predict(
-            model=model, latent_vectors=test_latent_vectors
-        )
-
-        if eval_mask_override is not None:
-            # Validate row counts to allow feature subsetting during tuning
-            if eval_mask_override.shape[0] != X_val.shape[0]:
-                msg = (
-                    f"eval_mask_override rows {eval_mask_override.shape[0]} "
-                    f"does not match X_val rows {X_val.shape[0]}"
-                )
-                self.logger.error(msg)
-                raise ValueError(msg)
-
-            # Slice mask columns if override is wider than current X_val (tune_fast)
-            if eval_mask_override.shape[1] > X_val.shape[1]:
-                eval_mask = eval_mask_override[:, : X_val.shape[1]].astype(bool)
-            else:
-                eval_mask = eval_mask_override.astype(bool)
-        else:
-            # Default: score only observed entries
-            eval_mask = X_val != -1
-
-        # y_true should be drawn from the pre-mask ground truth
-        # Map X_val back to the correct full ground truth slice
-        # FIX: Check shape[0] (n_samples) only.
-        if X_val.shape[0] == self.X_test_.shape[0]:
-            GT_ref = self.GT_test_full_
-        elif X_val.shape[0] == self.X_train_.shape[0]:
-            GT_ref = self.GT_train_full_
-        else:
-            GT_ref = self.ground_truth_
-
-        # FIX: Slice Ground Truth columns if it is wider than X_val (tune_fast)
-        if GT_ref.shape[1] > X_val.shape[1]:
-            GT_ref = GT_ref[:, : X_val.shape[1]]
-
-        # Fallback safeguard
-        if GT_ref.shape != X_val.shape:
-            GT_ref = X_val
-
-        y_true_flat = GT_ref[eval_mask]
-        pred_labels_flat = pred_labels[eval_mask]
-        pred_probas_flat = pred_probas[eval_mask]
-
-        if y_true_flat.size == 0:
-            return {self.tune_metric: 0.0}
-
-        # For haploids, remap class 2 to 1 for scoring (e.g., f1-score)
-        labels_for_scoring = [0, 1] if self.is_haploid else [0, 1, 2]
-        target_names = ["REF", "ALT"] if self.is_haploid else ["REF", "HET", "ALT"]
-
-        y_true_ohe = np.eye(len(labels_for_scoring))[y_true_flat]
-
-        metrics = self.scorers_.evaluate(
-            y_true_flat,
-            pred_labels_flat,
-            y_true_ohe,
-            pred_probas_flat,
-            objective_mode,
-            self.tune_metric,
-        )
-
-        if not objective_mode:
-            pm = PrettyMetrics(
-                metrics, precision=3, title=f"{self.model_name} Validation Metrics"
-            )
-            pm.render()  # prints a command-line table
-
-            self._make_class_reports(
-                y_true=y_true_flat,
-                y_pred_proba=pred_probas_flat,
-                y_pred=pred_labels_flat,
-                metrics=metrics,
-                labels=target_names,
-            )
-
-            # FIX: Use X_val dimensions for reshaping, not self.num_features_
-            y_true_dec = self.pgenc.decode_012(
-                GT_ref.reshape(X_val.shape[0], X_val.shape[1])
-            )
-
-            X_pred = X_val.copy()
-            X_pred[eval_mask] = pred_labels_flat
-
-            y_pred_dec = self.pgenc.decode_012(
-                X_pred.reshape(X_val.shape[0], X_val.shape[1])
-            )
-
-            encodings_dict = {
-                "A": 0,
-                "C": 1,
-                "G": 2,
-                "T": 3,
-                "W": 4,
-                "R": 5,
-                "M": 6,
-                "K": 7,
-                "Y": 8,
-                "S": 9,
-                "N": -1,
-            }
-
-            y_true_int = self.pgenc.convert_int_iupac(
-                y_true_dec, encodings_dict=encodings_dict
-            )
-            y_pred_int = self.pgenc.convert_int_iupac(
-                y_pred_dec, encodings_dict=encodings_dict
-            )
-
-            # For IUPAC report
-            valid_true = y_true_int[eval_mask]
-            valid_true = valid_true[valid_true >= 0]  # drop -1 (N)
-            iupac_label_set = ["A", "C", "G", "T", "W", "R", "M", "K", "Y", "S"]
-
-            # For numeric report
-            if (
-                np.intersect1d(np.unique(y_true_flat), labels_for_scoring).size == 0
-                or valid_true.size == 0
-            ):
-                if not objective_mode:
-                    self.logger.warning(
-                        "Skipped numeric confusion matrix: no y_true labels present."
-                    )
-            else:
-                self._make_class_reports(
-                    y_true=valid_true,
-                    y_pred=y_pred_int[eval_mask][y_true_int[eval_mask] >= 0],
-                    metrics=metrics,
-                    y_pred_proba=None,
-                    labels=iupac_label_set,
-                )
-
-        return metrics
-
-    def _get_data_loaders(self, y: np.ndarray) -> torch.utils.data.DataLoader:
-        """Creates a PyTorch DataLoader for the 0/1/2 encoded data.
-
-        This method constructs a DataLoader from the provided genotype data, which is expected to be in 0/1/2 encoding with -1 for missing values. The DataLoader is used for batching and shuffling the data during model training. It converts the numpy array to a PyTorch tensor and creates a TensorDataset. The DataLoader is configured with the specified batch size and shuffling enabled.
-
-        Args:
-            y (np.ndarray): 0/1/2 encoded genotype data with -1 for missing.
-
-        Returns:
-            torch.utils.data.DataLoader: DataLoader for the dataset.
-        """
-        y_tensor = torch.from_numpy(y).long().to(self.device)
-        dataset = torch.utils.data.TensorDataset(
-            torch.arange(len(y), device=self.device), y_tensor.to(self.device)
-        )
-        return torch.utils.data.DataLoader(
-            dataset, batch_size=self.batch_size, shuffle=True
-        )
-
-    def _create_latent_space(
-        self,
-        params: dict,
-        n_samples: int,
-        X: np.ndarray,
-        latent_init: Literal["random", "pca"],
-    ) -> torch.nn.Parameter:
-        """Initializes the latent space for the NLPCA model.
-
-        This method initializes the latent space for the NLPCA model based on the specified initialization method. It supports two methods: 'random' initialization using Xavier uniform distribution, and 'pca' initialization which uses PCA to derive initial latent vectors from the data. The latent vectors are returned as a PyTorch Parameter, allowing them to be optimized during training.
-
-        Args:
-            params (dict): Model parameters including 'latent_dim'.
-            n_samples (int): Number of samples in the dataset.
-            X (np.ndarray): 0/1/2 encoded genotype data with -1 for missing.
-            latent_init (str): Method to initialize latent space ('random' or 'pca').
-
-        Returns:
-            torch.nn.Parameter: Initialized latent vectors as a PyTorch Parameter.
-        """
-        latent_dim = int(params["latent_dim"])
-
-        if latent_init == "pca":
-            X_pca = X.astype(np.float32, copy=True)
-            # mark missing
-            X_pca[X_pca < 0] = np.nan
-
-            # ---- SAFE column means without warnings ----
-            valid_counts = np.sum(~np.isnan(X_pca), axis=0)
-            col_sums = np.nansum(X_pca, axis=0)
-            col_means = np.divide(
-                col_sums,
-                valid_counts,
-                out=np.zeros_like(col_sums, dtype=np.float32),
-                where=valid_counts > 0,
-            )
-
-            # impute NaNs with per-column means
-            # (all-NaN cols -> 0.0 by the divide above)
-            nan_r, nan_c = np.where(np.isnan(X_pca))
-            if nan_r.size:
-                X_pca[nan_r, nan_c] = col_means[nan_c]
-
-            # center columns
-            X_pca = X_pca - X_pca.mean(axis=0, keepdims=True)
-
-            # guard: degenerate / all-zero after centering ->
-            # fall back to random
-            if (not np.isfinite(X_pca).all()) or np.allclose(X_pca, 0.0):
-                latents = torch.empty(n_samples, latent_dim, device=self.device)
-                torch.nn.init.xavier_uniform_(latents)
-                return torch.nn.Parameter(latents, requires_grad=True)
-
-            # rank-aware component count, at least 1
-            try:
-                est_rank = np.linalg.matrix_rank(X_pca)
-            except Exception:
-                est_rank = min(n_samples, X_pca.shape[1])
-
-            n_components = max(1, min(latent_dim, est_rank, n_samples, X_pca.shape[1]))
-
-            # use deterministic SVD to avoid power-iteration warnings
-            pca = PCA(
-                n_components=n_components, svd_solver="full", random_state=self.seed
-            )
-            initial = pca.fit_transform(X_pca)  # (n_samples, n_components)
-
-            # pad if latent_dim > n_components
-            if n_components < latent_dim:
-                pad = self.rng.standard_normal(
-                    size=(n_samples, latent_dim - n_components)
-                )
-                initial = np.hstack([initial, pad])
-
-            # standardize latent dims
-            initial = (initial - initial.mean(axis=0)) / (initial.std(axis=0) + 1e-6)
-
-            latents = torch.from_numpy(initial).float().to(self.device)
-            return torch.nn.Parameter(latents, requires_grad=True)
-
-        # --- Random init path (unchanged) ---
-        latents = torch.empty(n_samples, latent_dim, device=self.device)
-        torch.nn.init.xavier_uniform_(latents)
-        return torch.nn.Parameter(latents, requires_grad=True)
-
-    def _objective(self, trial: optuna.Trial) -> float:
-        """Objective function for hyperparameter tuning with Optuna.
-
-        This method defines the objective function used by Optuna for hyperparameter tuning of the NLPCA model. It samples a set of hyperparameters, prepares the training and validation data, initializes the model and latent vectors, and trains the model. After training, it evaluates the model on a validation set and returns the value of the specified tuning metric.
-
-        Args:
-            trial (optuna.Trial): An Optuna trial object for hyperparameter suggestions.
-
-        Returns:
-            float: The value of the tuning metric to be minimized or maximized.
-        """
-        self._prepare_tuning_artifacts()
-        trial_params = self._sample_hyperparameters(trial)
-        model_params = trial_params["model_params"]
-
-        nfeat = self._tune_num_features
-        if self.tune and self.tune_fast:
-            model_params["n_features"] = nfeat
-
-        lr = trial_params["lr"]
-        l1_penalty = trial_params["l1_penalty"]
-        lr_input_fac = trial_params["lr_input_factor"]
-
-        X_train_trial = self._tune_X_train
-        X_test_trial = self._tune_X_test
-        class_weights = self._tune_class_weights
-        train_loader = self._tune_loader
-
-        train_latents = self._create_latent_space(
-            model_params, len(X_train_trial), X_train_trial, trial_params["latent_init"]
-        )
-
-        model = self.build_model(self.Model, model_params)
-        model.n_features = model_params["n_features"]
-        model.apply(self.initialize_weights)
-
-        _, model, __ = self._train_and_validate_model(
-            model=model,
-            loader=train_loader,
-            lr=lr,
-            l1_penalty=l1_penalty,
-            trial=trial,
-            latent_vectors=train_latents,
-            lr_input_factor=lr_input_fac,
-            class_weights=class_weights,
-            X_val=X_test_trial,
-            params=model_params,
-            prune_metric=self.tune_metric,
-            prune_warmup_epochs=5,
-            eval_interval=self.tune_eval_interval,
-            eval_latent_steps=self.eval_latent_steps,
-            eval_latent_lr=self.eval_latent_lr,
-            eval_latent_weight_decay=self.eval_latent_weight_decay,
-        )
-
-        # --- simulate-only eval mask for tuning ---
-        eval_mask = None
-        if (
-            self.simulate_missing
-            and getattr(self, "sim_mask_global_", None) is not None
-        ):
-            if hasattr(self, "_tune_test_idx") and self.sim_mask_global_ is not None:
-                eval_mask = self.sim_mask_global_[self._tune_test_idx]
-            elif getattr(self, "sim_mask_test_", None) is not None:
-                eval_mask = self.sim_mask_test_
-
-        metrics = self._evaluate_model(
-            X_test_trial,
-            model,
-            model_params,
-            objective_mode=True,
-            eval_mask_override=eval_mask,
-        )
-
-        self._clear_resources(model, train_loader, latent_vectors=train_latents)
-        return metrics[self.tune_metric]
-
-    def _sample_hyperparameters(self, trial: optuna.Trial) -> Dict[str, Any]:
-        """Samples hyperparameters for the simplified NLPCA model.
-
-        This method defines the hyperparameter search space for the NLPCA model and samples a set of hyperparameters using the provided Optuna trial object. It computes the hidden layer sizes based on the sampled parameters and prepares the model parameters dictionary.
-
-        Args:
-            trial (optuna.Trial): An Optuna trial object for hyperparameter suggestions.
-
-        Returns:
-            Dict[str, int | float | str | list]: A dictionary of sampled hyperparameters.
-        """
-        params = {
-            "latent_dim": trial.suggest_int("latent_dim", 2, 32),
-            "lr": trial.suggest_float("learning_rate", 1e-5, 1e-3, log=True),
-            "dropout_rate": trial.suggest_float("dropout_rate", 0.0, 0.5, step=0.05),
-            "num_hidden_layers": trial.suggest_int("num_hidden_layers", 1, 16),
-            "activation": trial.suggest_categorical(
-                "activation", ["relu", "elu", "selu", "leaky_relu"]
-            ),
-            "gamma": trial.suggest_float("gamma", 0.1, 5.0, step=0.1),
-            "lr_input_factor": trial.suggest_float(
-                "lr_input_factor", 0.1, 10.0, log=True
-            ),
-            "l1_penalty": trial.suggest_float("l1_penalty", 1e-6, 1e-3, log=True),
-            "layer_scaling_factor": trial.suggest_float(
-                "layer_scaling_factor", 2.0, 10.0
-            ),
-            "layer_schedule": trial.suggest_categorical(
-                "layer_schedule", ["pyramid", "constant", "linear"]
-            ),
-            "latent_init": trial.suggest_categorical("latent_init", ["random", "pca"]),
-        }
-
-        use_n_features = (
-            self._tune_num_features
-            if (self.tune and self.tune_fast and hasattr(self, "_tune_num_features"))
-            else self.num_features_
-        )
-        use_n_samples = (
-            len(self._tune_train_idx)
-            if (self.tune and self.tune_fast and hasattr(self, "_tune_train_idx"))
-            else len(self.train_idx_)
-        )
-
-        hidden_layer_sizes = self._compute_hidden_layer_sizes(
-            n_inputs=params["latent_dim"],
-            n_outputs=use_n_features * self.num_classes_,
-            n_samples=use_n_samples,
-            n_hidden=params["num_hidden_layers"],
-            alpha=params["layer_scaling_factor"],
-            schedule=params["layer_schedule"],
-        )
-
-        params["model_params"] = {
-            "n_features": use_n_features,
-            "num_classes": self.num_classes_,
-            "latent_dim": params["latent_dim"],
-            "dropout_rate": params["dropout_rate"],
-            "hidden_layer_sizes": hidden_layer_sizes,
-            "activation": params["activation"],
-            "gamma": params["gamma"],
-        }
-
-        return params
-
-    def _set_best_params(self, best_params: dict) -> dict:
-        """Sets the best hyperparameters found during tuning.
-
-        This method updates the model's attributes with the best hyperparameters obtained from tuning. It also computes the hidden layer sizes based on these parameters and prepares the final model parameters dictionary.
-
-        Args:
-            best_params (dict): Best hyperparameters from tuning.
-
-        Returns:
-            dict: Model parameters configured with the best hyperparameters.
-        """
-        self.latent_dim = best_params["latent_dim"]
-        self.dropout_rate = best_params["dropout_rate"]
-        self.learning_rate = best_params["learning_rate"]
-        self.gamma = best_params["gamma"]
-        self.lr_input_factor = best_params["lr_input_factor"]
-        self.l1_penalty = best_params["l1_penalty"]
-        self.activation = best_params["activation"]
-
-        hidden_layer_sizes = self._compute_hidden_layer_sizes(
-            n_inputs=self.latent_dim,
-            n_outputs=self.num_features_ * self.num_classes_,
-            n_samples=len(self.train_idx_),
-            n_hidden=best_params["num_hidden_layers"],
-            alpha=best_params["layer_scaling_factor"],
-            schedule=best_params["layer_schedule"],
-        )
-
-        return {
-            "n_features": self.num_features_,
-            "latent_dim": self.latent_dim,
-            "hidden_layer_sizes": hidden_layer_sizes,
-            "dropout_rate": self.dropout_rate,
-            "activation": self.activation,
-            "gamma": self.gamma,
-            "num_classes": self.num_classes_,
-        }
-
-    def _set_best_params_default(self) -> Dict[str, int | float | str | list]:
-        """Default (no-tuning) model_params aligned with current attributes.
-
-        This method constructs the model parameters dictionary using the current instance attributes of the ImputeUBP class. It computes the sizes of the hidden layers based on the instance's latent dimension, dropout rate, learning rate, and other relevant attributes. The method returns a dictionary containing the model parameters that can be used to build the UBP model when no hyperparameter tuning has been performed.
-
-        Returns:
-            Dict[str, int | float | str | list]: model_params payload.
-        """
-        hidden_layer_sizes = self._compute_hidden_layer_sizes(
-            n_inputs=self.latent_dim,
-            n_outputs=self.num_features_ * self.num_classes_,
-            n_samples=len(self.ground_truth_),
-            n_hidden=self.num_hidden_layers,
-            alpha=self.layer_scaling_factor,
-            schedule=self.layer_schedule,
-        )
-
-        return {
-            "n_features": self.num_features_,
-            "latent_dim": self.latent_dim,
-            "hidden_layer_sizes": hidden_layer_sizes,
-            "dropout_rate": self.dropout_rate,
-            "activation": self.activation,
-            "gamma": self.gamma,
-            "num_classes": self.num_classes_,
-        }
-
-    def _train_and_validate_model(
-        self,
-        model: torch.nn.Module,
-        loader: torch.utils.data.DataLoader,
-        lr: float,
-        l1_penalty: float,
-        trial: optuna.Trial | None = None,
-        return_history: bool = False,
-        latent_vectors: torch.nn.Parameter | None = None,
-        lr_input_factor: float = 1.0,
-        class_weights: torch.Tensor | None = None,
-        *,
-        X_val: np.ndarray | None = None,
-        params: dict | None = None,
-        prune_metric: str | None = None,  # "f1" | "accuracy" | "pr_macro"
-        prune_warmup_epochs: int = 3,
-        eval_interval: int = 1,
-        eval_latent_steps: int = 50,
-        eval_latent_lr: float = 1e-2,
-        eval_latent_weight_decay: float = 0.0,
-    ) -> Tuple:
-        """Trains and validates the NLPCA model.
-
-        This method trains the provided NLPCA model using the specified training data and hyperparameters. It supports optional integration with Optuna for hyperparameter tuning and pruning based on validation performance. The method initializes optimizers for both the model parameters and latent vectors, sets up a learning rate scheduler, and executes the training loop. It can return the training history if requested.
-
-        Args:
-            model (torch.nn.Module): The NLPCA model to be trained.
-            loader (torch.utils.data.DataLoader): DataLoader for training data.
-            lr (float): Learning rate for the model optimizer.
-            l1_penalty (float): L1 regularization penalty.
-            trial (optuna.Trial | None): Optuna trial for hyperparameter tuning.
-            return_history (bool): Whether to return training history.
-            latent_vectors (torch.nn.Parameter | None): Latent vectors for samples.
-            lr_input_factor (float): Learning rate factor for latent vectors.
-            class_weights (torch.Tensor | None): Class weights for handling class imbalance.
-            X_val (np.ndarray | None): Validation data for pruning.
-            params (dict | None): Model parameters.
-            prune_metric (str | None): Metric for pruning decisions.
-            prune_warmup_epochs (int): Number of epochs before pruning starts.
-            eval_interval (int): Interval (in epochs) for evaluation during training.
-            eval_latent_steps (int): Steps for latent optimization during evaluation.
-            eval_latent_lr (float): Learning rate for latent optimization during evaluation.
-            eval_latent_weight_decay (float): Weight decay for latent optimization during evaluation.
-
-        Returns:
-            Tuple[float, torch.nn.Module, Dict[str, float], torch.nn.Parameter] | Tuple[float, torch.nn.Module, torch.nn.Parameter]: Training loss, trained model, training history (if requested), and optimized latent vectors.
-
-        Raises:
-            TypeError: If latent_vectors or class_weights are not provided.
-        """
-
-        if latent_vectors is None or class_weights is None:
-            msg = "latent_vectors and class_weights must be provided."
-            self.logger.error(msg)
-            raise TypeError("Must provide latent_vectors and class_weights.")
-
-        latent_optimizer = torch.optim.Adam([latent_vectors], lr=lr * lr_input_factor)
-
-        decoder: torch.Tensor | torch.nn.Module = model.phase23_decoder
-
-        if not isinstance(decoder, torch.nn.Module):
-            msg = f"{self.model_name} decoder is not a valid torch.nn.Module."
-            self.logger.error(msg)
-            raise TypeError(msg)
-
-        optimizer = torch.optim.Adam(decoder.parameters(), lr=lr)
-        scheduler = CosineAnnealingLR(optimizer, T_max=self.epochs)
-
-        result = self._execute_training_loop(
-            loader=loader,
-            optimizer=optimizer,
-            latent_optimizer=latent_optimizer,
-            scheduler=scheduler,
-            model=model,
-            l1_penalty=l1_penalty,
-            return_history=return_history,
-            latent_vectors=latent_vectors,
-            class_weights=class_weights,
-            trial=trial,
-            X_val=X_val,
-            params=params,
-            prune_metric=prune_metric,
-            prune_warmup_epochs=prune_warmup_epochs,
-            eval_interval=eval_interval,
-            eval_latent_steps=eval_latent_steps,
-            eval_latent_lr=eval_latent_lr,
-            eval_latent_weight_decay=eval_latent_weight_decay,
-        )
-
-        if return_history:
-            return result
-
-        return result[0], result[1], result[3]
-
def _train_final_model(
|
|
1188
|
-
self,
|
|
1189
|
-
loader: torch.utils.data.DataLoader,
|
|
1190
|
-
best_params: dict,
|
|
1191
|
-
initial_latent_vectors: torch.nn.Parameter,
|
|
1192
|
-
) -> Tuple[float, torch.nn.Module, dict, torch.nn.Parameter]:
|
|
1193
|
-
"""Trains the final model using the best hyperparameters.
|
|
1194
|
-
|
|
1195
|
-
This method builds and trains the final NLPCA model using the best hyperparameters obtained from tuning. It initializes the model weights, trains the model on the entire training set, and saves the trained model to disk. It returns the final training loss, trained model, training history, and optimized latent vectors.
|
|
1196
|
-
|
|
1197
|
-
Args:
|
|
1198
|
-
loader (torch.utils.data.DataLoader): DataLoader for training data.
|
|
1199
|
-
best_params (dict): Best hyperparameters for the model.
|
|
1200
|
-
initial_latent_vectors (torch.nn.Parameter): Initial latent vectors for samples.
|
|
1201
|
-
|
|
1202
|
-
Returns:
|
|
1203
|
-
Tuple[float, torch.nn.Module, dict, torch.nn.Parameter]: Final training loss, trained model, training history, and optimized latent vectors.
|
|
1204
|
-
Raises:
|
|
1205
|
-
RuntimeError: If model training fails.
|
|
1206
|
-
"""
|
|
1207
|
-
self.logger.info(f"Training the final {self.model_name} model...")
|
|
1208
|
-
|
|
1209
|
-
model = self.build_model(self.Model, best_params)
|
|
1210
|
-
model.n_features = best_params["n_features"]
|
|
1211
|
-
model.apply(self.initialize_weights)
|
|
1212
|
-
|
|
1213
|
-
loss, trained_model, history, latent_vectors = self._train_and_validate_model(
|
|
1214
|
-
model=model,
|
|
1215
|
-
loader=loader,
|
|
1216
|
-
lr=self.learning_rate,
|
|
1217
|
-
l1_penalty=self.l1_penalty,
|
|
1218
|
-
return_history=True,
|
|
1219
|
-
latent_vectors=initial_latent_vectors,
|
|
1220
|
-
lr_input_factor=self.lr_input_factor,
|
|
1221
|
-
class_weights=self.class_weights_,
|
|
1222
|
-
X_val=self.X_test_,
|
|
1223
|
-
params=best_params,
|
|
1224
|
-
prune_metric=self.tune_metric,
|
|
1225
|
-
prune_warmup_epochs=5,
|
|
1226
|
-
eval_interval=1,
|
|
1227
|
-
eval_latent_steps=self.eval_latent_steps,
|
|
1228
|
-
eval_latent_lr=self.eval_latent_lr,
|
|
1229
|
-
eval_latent_weight_decay=self.eval_latent_weight_decay,
|
|
1230
|
-
)
|
|
1231
|
-
|
|
1232
|
-
if trained_model is None:
|
|
1233
|
-
msg = "Final model training failed."
|
|
1234
|
-
self.logger.error(msg)
|
|
1235
|
-
raise RuntimeError(msg)
|
|
1236
|
-
|
|
1237
|
-
fn = self.models_dir / "final_model.pt"
|
|
1238
|
-
torch.save(trained_model.state_dict(), fn)
|
|
1239
|
-
|
|
1240
|
-
return (loss, trained_model, {"Train": history}, latent_vectors)
|
|
1241
|
-
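For reference, torch.save(model.state_dict(), ...) persists only the weights, so reloading "final_model.pt" later requires rebuilding an identical architecture first. A minimal sketch of that round-trip, with build_decoder as a hypothetical stand-in for however the network is reconstructed:

import torch

def build_decoder(latent_dim: int = 3, n_out: int = 150) -> torch.nn.Module:
    # Placeholder architecture; the real model is rebuilt from best_params.
    return torch.nn.Sequential(
        torch.nn.Linear(latent_dim, 64),
        torch.nn.ReLU(),
        torch.nn.Linear(64, n_out),
    )

model = build_decoder()
torch.save(model.state_dict(), "final_model.pt")  # weights only, no architecture

# Reloading: rebuild the same architecture, then load the saved weights into it.
restored = build_decoder()
restored.load_state_dict(torch.load("final_model.pt", map_location="cpu"))
restored.eval()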
-    def _execute_training_loop(
-        self,
-        loader,
-        optimizer,
-        latent_optimizer,
-        scheduler,  # do not overwrite; honor caller's scheduler
-        model,
-        l1_penalty,
-        return_history,
-        latent_vectors,
-        class_weights,
-        *,
-        trial: optuna.Trial | None = None,
-        X_val: np.ndarray | None = None,
-        params: dict | None = None,
-        prune_metric: str | None = None,
-        prune_warmup_epochs: int = 3,
-        eval_interval: int = 1,
-        eval_latent_steps: int = 50,
-        eval_latent_lr: float = 1e-2,
-        eval_latent_weight_decay: float = 0.0,
-    ) -> Tuple[float, torch.nn.Module, list, torch.nn.Parameter]:
-        """Train NLPCA with warmup, pruning, and early stopping."""
-        best_model = None
-        history: list[float] = []
-
-        early_stopping = EarlyStopping(
-            patience=self.early_stop_gen,
-            min_epochs=self.min_epochs,
-            verbose=self.verbose,
-            prefix=self.prefix,
-            debug=self.debug,
-        )
-
-        # Epoch budget
-        max_epochs = (
-            self.tune_epochs if (trial is not None and self.tune_fast) else self.epochs
-        )
-
-        # Optional LR warmup for both optimizers
-        warmup_epochs = getattr(self, "lr_warmup_epochs", 5)
-        model_lr0 = optimizer.param_groups[0]["lr"]
-        latent_lr0 = latent_optimizer.param_groups[0]["lr"]
-        model_lr_min = model_lr0 * 0.1
-        latent_lr_min = latent_lr0 * 0.1
-
-        _latent_cache: dict = {}
-        _latent_cache_key = f"{self.prefix}_{self.model_name}_val_latents"
-
-        for epoch in range(max_epochs):
-            # Linear warmup LRs for first few epochs
-            if epoch < warmup_epochs:
-                scale = float(epoch + 1) / warmup_epochs
-                for g in optimizer.param_groups:
-                    g["lr"] = model_lr_min + (model_lr0 - model_lr_min) * scale
-                for g in latent_optimizer.param_groups:
-                    g["lr"] = latent_lr_min + (latent_lr0 - latent_lr_min) * scale
-
-            train_loss, latent_vectors = self._train_step(
-                loader=loader,
-                optimizer=optimizer,
-                latent_optimizer=latent_optimizer,
-                model=model,
-                l1_penalty=l1_penalty,
-                latent_vectors=latent_vectors,
-                class_weights=class_weights,
-            )
-
-            if not np.isfinite(train_loss):
-                if trial:
-                    raise optuna.exceptions.TrialPruned("Epoch loss non-finite.")
-                # Reduce both LRs and continue
-                for g in optimizer.param_groups:
-                    g["lr"] *= 0.5
-                for g in latent_optimizer.param_groups:
-                    g["lr"] *= 0.5
-                continue
-
-            if scheduler is not None:
-                scheduler.step()
-
-            if return_history:
-                history.append(train_loss)
-
-            # Optuna prune on validation metric
-            if (
-                trial is not None
-                and X_val is not None
-                and ((epoch + 1) % eval_interval == 0)
-            ):
-                seed = int(
-                    self.rng.integers(0, 1_000_000) if self.seed is None else self.seed
-                )
-                metric_key = prune_metric or getattr(self, "tune_metric", "f1")
-                do_infer = int(eval_latent_steps) > 0
-                metric_val = self._eval_for_pruning(
-                    model=model,
-                    X_val=X_val,
-                    params=params or getattr(self, "best_params_", {}),
-                    metric=metric_key,
-                    objective_mode=True,
-                    do_latent_infer=do_infer,
-                    latent_steps=eval_latent_steps,
-                    latent_lr=eval_latent_lr,
-                    latent_weight_decay=eval_latent_weight_decay,
-                    latent_seed=seed,
-                    _latent_cache=_latent_cache,
-                    _latent_cache_key=_latent_cache_key,
-                    eval_mask_override=(
-                        self.sim_mask_test_
-                        if (
-                            self.simulate_missing
-                            and getattr(self, "sim_mask_test_", None) is not None
-                            and X_val.shape == self.X_test_.shape
-                        )
-                        else (
-                            self.sim_mask_global_[self._tune_test_idx]
-                            if (
-                                self.simulate_missing
-                                and self.sim_mask_global_ is not None
-                                and hasattr(self, "_tune_test_idx")
-                                and X_val.shape[0] == len(self._tune_test_idx)
-                            )
-                            else None
-                        )
-                    ),
-                )
-
-                trial.report(metric_val, step=epoch + 1)
-
-                if (epoch + 1) >= prune_warmup_epochs and trial.should_prune():
-                    raise optuna.exceptions.TrialPruned(
-                        f"Pruned at epoch {epoch + 1}: {metric_key}={metric_val:.4f}"
-                    )
-
-            early_stopping(train_loss, model)
-            if early_stopping.early_stop:
-                self.logger.info(f"Early stopping at epoch {epoch + 1}.")
-                break
-
-        best_loss = early_stopping.best_score
-        best_model = copy.deepcopy(early_stopping.best_model)
-        if best_model is None:
-            best_model = copy.deepcopy(model)
-        return best_loss, best_model, history, latent_vectors
-
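The removed training loop combines a linear learning-rate warmup with Optuna's epoch-level pruning: intermediate validation metrics are reported every eval_interval epochs, but trial.should_prune() is only honored after prune_warmup_epochs. A self-contained sketch of that report-then-prune pattern follows; the objective and its synthetic metric are illustrative stand-ins, not the imputer's actual scoring.

import optuna

def objective(trial: optuna.Trial) -> float:
    lr0 = trial.suggest_float("lr", 1e-4, 1e-2, log=True)
    warmup_epochs, prune_warmup_epochs, max_epochs = 5, 3, 30
    metric = 0.0
    for epoch in range(max_epochs):
        # Linear warmup: ramp from 10% of lr0 up to lr0 over the first few epochs.
        if epoch < warmup_epochs:
            scale = float(epoch + 1) / warmup_epochs
            lr = 0.1 * lr0 + (lr0 - 0.1 * lr0) * scale
        else:
            lr = lr0
        # Stand-in for one epoch of training plus validation scoring.
        metric = 1.0 - 1.0 / (epoch + 2) - 0.1 * abs(lr - 3e-3)
        # Report the intermediate value so the pruner can compare trials per epoch.
        trial.report(metric, step=epoch + 1)
        if (epoch + 1) >= prune_warmup_epochs and trial.should_prune():
            raise optuna.exceptions.TrialPruned()
    return metric

study = optuna.create_study(direction="maximize", pruner=optuna.pruners.MedianPruner())
study.optimize(objective, n_trials=10)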
-    def _optimize_latents_for_inference(
-        self,
-        X_new: np.ndarray,
-        model: torch.nn.Module,
-        params: dict,
-        inference_epochs: int = 200,
-    ) -> torch.Tensor:
-        """Refine latents for new data with guards.
-
-        This method optimizes latent vectors for new data samples by refining them through gradient-based optimization. It initializes the latent space and iteratively updates the latent vectors to minimize the reconstruction loss using cross-entropy. The method includes safeguards to handle non-finite values during optimization.
-
-        Args:
-            X_new (np.ndarray): New data in 0/1/2 encoding with -1 for missing.
-            model (torch.nn.Module): Trained NLPCA model.
-            params (dict): Model parameters.
-            inference_epochs (int): Number of optimization epochs.
-
-        Returns:
-            torch.Tensor: Optimized latent vectors for the new data.
-
-        """
-        if self.tune and self.tune_fast:
-            inference_epochs = min(
-                inference_epochs, getattr(self, "tune_infer_epochs", 20)
-            )
-
-        model.eval()
-        nF = getattr(model, "n_features", self.num_features_)
-
-        z = self._create_latent_space(
-            params, len(X_new), X_new, self.latent_init
-        ).requires_grad_(True)
-        opt = torch.optim.AdamW(
-            [z], lr=self.learning_rate * self.lr_input_factor, eps=1e-7
-        )
-
-        X_new = X_new.astype(np.int64, copy=False)
-        X_new[X_new < 0] = -1
-        y = torch.from_numpy(X_new).long().to(self.device)
-
-        for _ in range(inference_epochs):
-            opt.zero_grad(set_to_none=True)
-
-            decoder: torch.Tensor | torch.nn.Module = model.phase23_decoder
-
-            if not isinstance(decoder, torch.nn.Module):
-                msg = f"{self.model_name} decoder is not a valid torch.nn.Module."
-                self.logger.error(msg)
-                raise TypeError(msg)
-
-            logits = decoder(z).view(len(X_new), nF, self.num_classes_)
-
-            if not torch.isfinite(logits).all():
-                break
-
-            loss = F.cross_entropy(
-                logits.view(-1, self.num_classes_),
-                y.view(-1),
-                ignore_index=-1,
-                reduction="mean",
-            )
-            if not torch.isfinite(loss):
-                break
-
-            loss.backward()
-            torch.nn.utils.clip_grad_norm_([z], max_norm=1.0)
-            if z.grad is None or not torch.isfinite(z.grad).all():
-                break
-            opt.step()
-
-        return z.detach()
-
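This inference routine amounts to transductive latent refinement: the trained decoder stays fixed while per-sample latent vectors are fit to the observed genotype calls with a masked cross-entropy (calls coded -1 are ignored). The sketch below shows the same idea end to end on synthetic 0/1/2 data, including how imputed calls could be read off the logits afterwards; the toy decoder, sizes, and step counts are assumptions for illustration only.

import numpy as np
import torch
import torch.nn.functional as F

torch.manual_seed(0)
rng = np.random.default_rng(0)
n_new, latent_dim, n_loci, n_classes = 8, 3, 50, 3

# Pretend this decoder was already trained; only the latents are refined here.
decoder = torch.nn.Sequential(
    torch.nn.Linear(latent_dim, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, n_loci * n_classes),
)
decoder.eval()
for p in decoder.parameters():
    p.requires_grad_(False)

# New 0/1/2 genotypes with -1 for missing calls.
X_new = rng.integers(0, n_classes, size=(n_new, n_loci)).astype(np.int64)
X_new[rng.random((n_new, n_loci)) < 0.3] = -1
y = torch.from_numpy(X_new).long()

z = torch.zeros(n_new, latent_dim, requires_grad=True)
opt = torch.optim.AdamW([z], lr=1e-2, eps=1e-7)

for _ in range(200):
    opt.zero_grad(set_to_none=True)
    logits = decoder(z).view(n_new, n_loci, n_classes)
    loss = F.cross_entropy(
        logits.view(-1, n_classes), y.view(-1), ignore_index=-1, reduction="mean"
    )
    if not torch.isfinite(loss):
        break  # guard against divergence, mirroring the removed method
    loss.backward()
    torch.nn.utils.clip_grad_norm_([z], max_norm=1.0)
    opt.step()

# Imputed genotypes: argmax over class logits, filled in only where calls were missing.
with torch.no_grad():
    pred = decoder(z).view(n_new, n_loci, n_classes).argmax(dim=-1).numpy()
imputed = np.where(X_new == -1, pred, X_new)

Because only z is updated, new samples never alter the fitted decoder; the model simply finds the point in latent space whose decoded genotype profile best matches the observed calls.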
-    def _latent_infer_for_eval(
-        self,
-        model: torch.nn.Module,
-        X_val: np.ndarray,
-        *,
-        steps: int,
-        lr: float,
-        weight_decay: float,
-        seed: int,
-        cache: dict | None,
-        cache_key: str | None,
-    ) -> None:
-        """Freeze weights; refine validation latents only (no leakage).
-
-        This method refines latent vectors for validation data by optimizing them while keeping the model weights frozen. It initializes the latent space, optionally using cached latent vectors, and iteratively updates the latent vectors to minimize the reconstruction loss using cross-entropy. The method includes safeguards to handle non-finite values during optimization and can store the optimized latent vectors in a cache.
-
-        Args:
-            model (torch.nn.Module): Trained NLPCA model.
-            X_val (np.ndarray): Validation data in 0/1/2 encoding with -1 for missing.
-            steps (int): Number of optimization steps.
-            lr (float): Learning rate for latent optimization.
-            weight_decay (float): Weight decay for latent optimization.
-            seed (int): Random seed for reproducibility.
-            cache (dict | None): Cache for storing optimized latent vectors.
-            cache_key (str | None): Key for storing/retrieving from cache.
-        """
-        if seed is None:
-            seed = np.random.randint(0, 999_999)
-        torch.manual_seed(seed)
-        np.random.seed(seed)
-
-        model.eval()
-        nF = getattr(model, "n_features", self.num_features_)
-
-        for p in model.parameters():
-            p.requires_grad_(False)
-
-        X_val = X_val.astype(np.int64, copy=False)
-        X_val[X_val < 0] = -1
-        y = torch.from_numpy(X_val).long().to(self.device)
-
-        latent_dim = self._first_linear_in_features(model)
-        cache_key = f"{self.prefix}_nlpca_val_latents_z{latent_dim}_L{self.num_features_}_K{self.num_classes_}"
-
-        if cache is not None and cache_key in cache:
-            z = cache[cache_key].detach().clone().requires_grad_(True)
-        else:
-            z = self._create_latent_space(
-                {"latent_dim": latent_dim},
-                n_samples=X_val.shape[0],
-                X=X_val,
-                latent_init=self.latent_init,
-            ).requires_grad_(True)
-
-        opt = torch.optim.AdamW([z], lr=lr, weight_decay=weight_decay, eps=1e-7)
-
-        for _ in range(max(int(steps), 0)):
-            opt.zero_grad(set_to_none=True)
-
-            decoder: torch.Tensor | torch.nn.Module = model.phase23_decoder
-
-            if not isinstance(decoder, torch.nn.Module):
-                msg = f"{self.model_name} decoder is not a valid torch.nn.Module."
-                self.logger.error(msg)
-                raise TypeError(msg)
-
-            logits = decoder(z).view(X_val.shape[0], nF, self.num_classes_)
-
-            if not torch.isfinite(logits).all():
-                break
-
-            loss = F.cross_entropy(
-                logits.view(-1, self.num_classes_),
-                y.view(-1),
-                ignore_index=-1,
-                reduction="mean",
-            )
-
-            if not torch.isfinite(loss):
-                break
-
-            loss.backward()
-
-            torch.nn.utils.clip_grad_norm_([z], max_norm=1.0)
-
-            if z.grad is None or not torch.isfinite(z.grad).all():
-                break
-
-            opt.step()
-
-        if cache is not None:
-            cache[cache_key] = z.detach().clone()
-
-        for p in model.parameters():
-            p.requires_grad_(True)
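The evaluation helper above keeps the model weights frozen (so nothing learned from validation labels flows back into the decoder) and warm-starts the validation latents from a cache keyed by latent dimensionality and data shape, so repeated evaluations during tuning do not restart from scratch. A minimal sketch of that cache pattern, with illustrative key components and a zero-initialized fallback standing in for the real latent initialization:

import torch

# Warm-start cache: the key encodes latent dim, locus count, and class count so
# stale entries are never reused across incompatible configurations.
cache: dict[str, torch.Tensor] = {}
prefix, latent_dim, n_loci, n_classes = "run1", 3, 50, 3  # illustrative values
cache_key = f"{prefix}_nlpca_val_latents_z{latent_dim}_L{n_loci}_K{n_classes}"

n_val = 16
if cache_key in cache:
    # Resume from previously refined latents instead of re-initializing.
    z = cache[cache_key].detach().clone().requires_grad_(True)
else:
    z = torch.zeros(n_val, latent_dim, requires_grad=True)

# ... refine z against the frozen decoder as in the sketch above ...

cache[cache_key] = z.detach().clone()  # store for the next evaluation call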