pg-sui 1.6.14.dev9__py3-none-any.whl → 1.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pg_sui-1.7.0.dist-info/METADATA +288 -0
- {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/RECORD +29 -33
- pgsui/__init__.py +0 -8
- pgsui/_version.py +2 -2
- pgsui/cli.py +591 -126
- pgsui/data_processing/config.py +1 -2
- pgsui/data_processing/containers.py +218 -533
- pgsui/data_processing/transformers.py +44 -20
- pgsui/impute/deterministic/imputers/mode.py +475 -182
- pgsui/impute/deterministic/imputers/ref_allele.py +454 -147
- pgsui/impute/supervised/imputers/hist_gradient_boosting.py +4 -3
- pgsui/impute/supervised/imputers/random_forest.py +3 -2
- pgsui/impute/unsupervised/base.py +1268 -530
- pgsui/impute/unsupervised/callbacks.py +28 -33
- pgsui/impute/unsupervised/imputers/autoencoder.py +869 -764
- pgsui/impute/unsupervised/imputers/vae.py +928 -696
- pgsui/impute/unsupervised/loss_functions.py +156 -202
- pgsui/impute/unsupervised/models/autoencoder_model.py +7 -49
- pgsui/impute/unsupervised/models/vae_model.py +40 -221
- pgsui/impute/unsupervised/nn_scorers.py +53 -13
- pgsui/utils/classification_viz.py +240 -97
- pgsui/utils/misc.py +201 -3
- pgsui/utils/plotting.py +73 -58
- pgsui/utils/pretty_metrics.py +2 -6
- pgsui/utils/scorers.py +39 -0
- pg_sui-1.6.14.dev9.dist-info/METADATA +0 -344
- pgsui/impute/unsupervised/imputers/nlpca.py +0 -1554
- pgsui/impute/unsupervised/imputers/ubp.py +0 -1575
- pgsui/impute/unsupervised/models/nlpca_model.py +0 -206
- pgsui/impute/unsupervised/models/ubp_model.py +0 -200
- {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/WHEEL +0 -0
- {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/entry_points.txt +0 -0
- {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/licenses/LICENSE +0 -0
- {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/top_level.txt +0 -0
|
@@ -2,8 +2,10 @@ import copy
|
|
|
2
2
|
import gc
|
|
3
3
|
import json
|
|
4
4
|
import logging
|
|
5
|
+
from collections import Counter
|
|
6
|
+
from datetime import datetime
|
|
5
7
|
from pathlib import Path
|
|
6
|
-
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Tuple
|
|
8
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Type
|
|
7
9
|
|
|
8
10
|
import matplotlib.pyplot as plt
|
|
9
11
|
import numpy as np
|
|
@@ -13,11 +15,18 @@ import plotly.graph_objects as go
|
|
|
13
15
|
import torch
|
|
14
16
|
import torch.nn.functional as F
|
|
15
17
|
from matplotlib.figure import Figure
|
|
16
|
-
from sklearn.metrics import
|
|
18
|
+
from sklearn.metrics import (
|
|
19
|
+
average_precision_score,
|
|
20
|
+
classification_report,
|
|
21
|
+
jaccard_score,
|
|
22
|
+
matthews_corrcoef,
|
|
23
|
+
)
|
|
17
24
|
from sklearn.model_selection import train_test_split
|
|
18
25
|
from snpio import SNPioMultiQC
|
|
19
26
|
from snpio.utils.logging import LoggerManager
|
|
27
|
+
from snpio.utils.misc import validate_input_type
|
|
20
28
|
|
|
29
|
+
from pgsui.data_processing.transformers import SimMissingTransformer
|
|
21
30
|
from pgsui.impute.unsupervised.nn_scorers import Scorer
|
|
22
31
|
from pgsui.utils.classification_viz import ClassificationReportVisualizer
|
|
23
32
|
from pgsui.utils.logging_utils import configure_logger
|
|
@@ -27,16 +36,24 @@ from pgsui.utils.pretty_metrics import PrettyMetrics
|
|
|
27
36
|
if TYPE_CHECKING:
|
|
28
37
|
from snpio.read_input.genotype_data import GenotypeData
|
|
29
38
|
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
39
|
+
|
|
40
|
+
class _MaskedNumpyDataset(torch.utils.data.Dataset):
|
|
41
|
+
def __init__(self, X: np.ndarray, y: np.ndarray, mask: np.ndarray):
|
|
42
|
+
self.X = X
|
|
43
|
+
self.y = y
|
|
44
|
+
self.mask = mask.astype(np.bool_, copy=False)
|
|
45
|
+
|
|
46
|
+
def __len__(self) -> int:
|
|
47
|
+
return self.X.shape[0]
|
|
48
|
+
|
|
49
|
+
def __getitem__(self, idx: int):
|
|
50
|
+
return self.X[idx], self.y[idx], self.mask[idx]
|
|
34
51
|
|
|
35
52
|
|
|
36
53
|
class BaseNNImputer:
|
|
37
|
-
"""
|
|
54
|
+
"""Abstract base class for neural network-based imputers.
|
|
38
55
|
|
|
39
|
-
This class provides
|
|
56
|
+
This class provides shared infrastructure for NN imputers (e.g., directory/logging setup, Optuna tuning, model construction helpers, class-weight utilities, standardized plotting/scoring, and IUPAC decoding). It is not intended to be instantiated directly; subclasses must implement the abstract methods.
|
|
40
57
|
"""
|
|
41
58
|
|
|
42
59
|
def __init__(
|
|
@@ -54,6 +71,8 @@ class BaseNNImputer:
|
|
|
54
71
|
This constructor sets up the device (CPU, GPU, or MPS), creates the necessary output directories for models and results, and a logger. It also initializes a genotype encoder for handling genotype data.
|
|
55
72
|
|
|
56
73
|
Args:
|
|
74
|
+
model_name (str): The model class name used in output paths and logs.
|
|
75
|
+
genotype_data (GenotypeData): Backing genotype data object.
|
|
57
76
|
prefix (str): A prefix used to name the output directory (e.g., 'pgsui_output').
|
|
58
77
|
device (Literal["gpu", "cpu", "mps"]): The device to use for PyTorch operations. If 'gpu' or 'mps' is chosen, it will fall back to 'cpu' if the required hardware is not available. Defaults to "cpu".
|
|
59
78
|
verbose (bool): If True, enables detailed logging output. Defaults to False.
|
|
@@ -62,6 +81,11 @@ class BaseNNImputer:
|
|
|
62
81
|
self.model_name = model_name
|
|
63
82
|
self.genotype_data = genotype_data
|
|
64
83
|
|
|
84
|
+
if not hasattr(self, "tree_parser"):
|
|
85
|
+
self.tree_parser = None
|
|
86
|
+
if not hasattr(self, "sim_kwargs"):
|
|
87
|
+
self.sim_kwargs = {}
|
|
88
|
+
|
|
65
89
|
self.prefix = prefix
|
|
66
90
|
self.verbose = verbose
|
|
67
91
|
self.debug = debug
|
|
@@ -89,15 +113,15 @@ class BaseNNImputer:
|
|
|
89
113
|
self.logger = configure_logger(
|
|
90
114
|
logman.get_logger(), verbose=self.verbose, debug=self.debug
|
|
91
115
|
)
|
|
92
|
-
|
|
93
|
-
self.
|
|
116
|
+
|
|
117
|
+
self.logger.info(f"Using PyTorch device: {self.device.type}.")
|
|
94
118
|
|
|
95
119
|
# To be initialized by child classes or fit method
|
|
96
120
|
self.tune_save_db: bool = False
|
|
97
121
|
self.tune_resume: bool = False
|
|
98
122
|
self.n_trials: int = 100
|
|
99
123
|
self.model_params: Dict[str, Any] = {}
|
|
100
|
-
self.tune_metric: str = "
|
|
124
|
+
self.tune_metric: str = "f1"
|
|
101
125
|
self.learning_rate: float = 1e-3
|
|
102
126
|
self.plotter_: "Plotting"
|
|
103
127
|
self.num_features_: int = 0
|
|
@@ -110,21 +134,16 @@ class BaseNNImputer:
|
|
|
110
134
|
self.show_plots: bool = False
|
|
111
135
|
self.scoring_averaging: Literal["macro", "micro", "weighted"] = "macro"
|
|
112
136
|
self.pgenc: Any = None
|
|
113
|
-
self.
|
|
137
|
+
self.is_haploid_: bool = False
|
|
114
138
|
self.ploidy: int = 2
|
|
115
139
|
self.beta: float = 0.9999
|
|
116
|
-
self.max_ratio: float =
|
|
117
|
-
self.sim_strategy: str = "
|
|
118
|
-
self.sim_prop: float = 0.
|
|
119
|
-
self.seed: int
|
|
140
|
+
self.max_ratio: Optional[float] = None
|
|
141
|
+
self.sim_strategy: str = "random"
|
|
142
|
+
self.sim_prop: float = 0.2
|
|
143
|
+
self.seed: Optional[int] = None
|
|
120
144
|
self.rng: np.random.Generator = np.random.default_rng(self.seed)
|
|
121
145
|
self.ground_truth_: np.ndarray
|
|
122
|
-
self.tune_fast: bool = False
|
|
123
|
-
self.tune_max_samples: int = 1000
|
|
124
|
-
self.tune_max_loci: int = 500
|
|
125
146
|
self.validation_split: float = 0.2
|
|
126
|
-
self.tune_batch_size: int = 64
|
|
127
|
-
self.tune_proxy_metric_batch: int = 512
|
|
128
147
|
self.batch_size: int = 64
|
|
129
148
|
self.best_params_: Dict[str, Any] = {}
|
|
130
149
|
|
|
@@ -133,9 +152,11 @@ class BaseNNImputer:
|
|
|
133
152
|
self.plots_dir: Path
|
|
134
153
|
self.metrics_dir: Path
|
|
135
154
|
self.parameters_dir: Path
|
|
136
|
-
self.study_db: Path
|
|
155
|
+
self.study_db: Optional[Path] = None
|
|
156
|
+
self.X_model_input_: Optional[np.ndarray] = None
|
|
157
|
+
self.class_weights_: Optional[torch.Tensor] = None
|
|
137
158
|
|
|
138
|
-
def tune_hyperparameters(self) ->
|
|
159
|
+
def tune_hyperparameters(self) -> Dict[str, Any]:
|
|
139
160
|
"""Tunes model hyperparameters using an Optuna study.
|
|
140
161
|
|
|
141
162
|
This method orchestrates the hyperparameter search process. It creates an Optuna study that aims to maximize the metric defined in `self.tune_metric`. The search is driven by the `_objective` method, which must be implemented by the child class. After the search, the best parameters are logged, saved to a JSON file, and visualizations of the study are generated.
|
|
@@ -153,15 +174,20 @@ class BaseNNImputer:
|
|
|
153
174
|
study_db = None
|
|
154
175
|
load_if_exists = False
|
|
155
176
|
if self.tune_save_db:
|
|
156
|
-
|
|
177
|
+
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
|
|
178
|
+
study_db = (
|
|
179
|
+
self.optimize_dir / "study_database" / f"optuna_study_{timestamp}.db"
|
|
180
|
+
)
|
|
157
181
|
study_db.parent.mkdir(parents=True, exist_ok=True)
|
|
158
182
|
|
|
159
183
|
if self.tune_resume and study_db.exists():
|
|
160
184
|
load_if_exists = True
|
|
161
185
|
|
|
162
186
|
if not self.tune_resume and study_db.exists():
|
|
163
|
-
|
|
187
|
+
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
|
|
188
|
+
study_db = study_db.with_name(f"optuna_study_{timestamp}.db")
|
|
164
189
|
|
|
190
|
+
self.study_db = study_db
|
|
165
191
|
study_name = f"{self.prefix} {self.model_name} Model Optimization"
|
|
166
192
|
storage = f"sqlite:///{study_db}" if self.tune_save_db else None
|
|
167
193
|
|
|
@@ -170,7 +196,17 @@ class BaseNNImputer:
|
|
|
170
196
|
study_name=study_name,
|
|
171
197
|
storage=storage,
|
|
172
198
|
load_if_exists=load_if_exists,
|
|
173
|
-
pruner=optuna.pruners.MedianPruner(
|
|
199
|
+
pruner=optuna.pruners.MedianPruner(
|
|
200
|
+
# Guard against small `n_trials` values (e.g., 1)
|
|
201
|
+
# that can otherwise produce 0 startup/warmup/min trials.
|
|
202
|
+
n_startup_trials=max(
|
|
203
|
+
1, min(int(self.n_trials * 0.1), 10, int(self.n_trials))
|
|
204
|
+
),
|
|
205
|
+
n_warmup_steps=150,
|
|
206
|
+
n_min_trials=max(
|
|
207
|
+
1, min(int(0.5 * self.n_trials), 25, int(self.n_trials))
|
|
208
|
+
),
|
|
209
|
+
),
|
|
174
210
|
)
|
|
175
211
|
|
|
176
212
|
if not hasattr(self, "_objective"):
|
|
@@ -185,6 +221,13 @@ class BaseNNImputer:
|
|
|
185
221
|
|
|
186
222
|
show_progress_bar = not self.verbose and not self.debug and self.n_jobs == 1
|
|
187
223
|
|
|
224
|
+
# Set the best parameters.
|
|
225
|
+
# NOTE: _set_best_params() must be implemented in the child class.
|
|
226
|
+
if not hasattr(self, "_set_best_params"):
|
|
227
|
+
msg = "Method `_set_best_params()` must be implemented in the child class."
|
|
228
|
+
self.logger.error(msg)
|
|
229
|
+
raise NotImplementedError(msg)
|
|
230
|
+
|
|
188
231
|
study.optimize(
|
|
189
232
|
lambda trial: self._objective(trial),
|
|
190
233
|
n_trials=self.n_trials,
|
|
@@ -193,15 +236,13 @@ class BaseNNImputer:
|
|
|
193
236
|
show_progress_bar=show_progress_bar,
|
|
194
237
|
)
|
|
195
238
|
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
if not hasattr(self, "_set_best_params"):
|
|
202
|
-
msg = "Method `_set_best_params()` must be implemented in the child class."
|
|
239
|
+
try:
|
|
240
|
+
best_metric = study.best_value
|
|
241
|
+
best_params = study.best_params
|
|
242
|
+
except Exception:
|
|
243
|
+
msg = "Tuning failed: No successful trials completed."
|
|
203
244
|
self.logger.error(msg)
|
|
204
|
-
raise
|
|
245
|
+
raise RuntimeError(msg)
|
|
205
246
|
|
|
206
247
|
self.best_params_ = self._set_best_params(best_params)
|
|
207
248
|
self.model_params.update(self.best_params_)
|
|
@@ -210,45 +251,34 @@ class BaseNNImputer:
|
|
|
210
251
|
best_params_tmp = copy.deepcopy(best_params)
|
|
211
252
|
best_params_tmp["learning_rate"] = self.learning_rate
|
|
212
253
|
|
|
213
|
-
|
|
214
|
-
pm = PrettyMetrics(best_params_tmp, precision=6, title=title)
|
|
215
|
-
pm.render()
|
|
254
|
+
tn = f"{self.tune_metric} Value"
|
|
216
255
|
|
|
217
|
-
|
|
218
|
-
|
|
256
|
+
if self.show_plots:
|
|
257
|
+
self.plotter_.plot_tuning(
|
|
258
|
+
study, self.model_name, self.optimize_dir / "plots", target_name=tn
|
|
259
|
+
)
|
|
219
260
|
|
|
220
|
-
|
|
221
|
-
self.plotter_.plot_tuning(
|
|
222
|
-
study, self.model_name, self.optimize_dir / "plots", target_name=tn
|
|
223
|
-
)
|
|
261
|
+
return best_params_tmp
|
|
224
262
|
|
|
225
263
|
@staticmethod
|
|
226
264
|
def initialize_weights(module: torch.nn.Module) -> None:
|
|
227
|
-
"""Initializes model weights using
|
|
228
|
-
|
|
229
|
-
This static method is intended to be applied to a PyTorch model to initialize the weights of its linear and convolutional layers. This initialization scheme is particularly effective for networks that use ReLU-family activation functions, as it helps maintain stable activation variances during training.
|
|
265
|
+
"""Initializes model weights using Xavier/Glorot Uniform distribution.
|
|
230
266
|
|
|
231
|
-
|
|
232
|
-
|
|
267
|
+
Switching from Kaiming to Xavier is safer for deep VAEs to prevent
|
|
268
|
+
exploding gradients or dead neurons in the early epochs.
|
|
233
269
|
"""
|
|
234
270
|
if isinstance(
|
|
235
271
|
module, (torch.nn.Linear, torch.nn.Conv1d, torch.nn.ConvTranspose1d)
|
|
236
272
|
):
|
|
237
|
-
#
|
|
238
|
-
torch.nn.init.
|
|
273
|
+
# Xavier is generally more stable for VAEs than Kaiming
|
|
274
|
+
torch.nn.init.xavier_uniform_(module.weight)
|
|
239
275
|
if module.bias is not None:
|
|
240
276
|
torch.nn.init.zeros_(module.bias)
|
|
241
277
|
|
|
242
278
|
def build_model(
|
|
243
279
|
self,
|
|
244
|
-
Model:
|
|
245
|
-
|
|
246
|
-
| type["AutoencoderModel"]
|
|
247
|
-
| type["NLPCAModel"]
|
|
248
|
-
| type["UBPModel"]
|
|
249
|
-
| type["VAEModel"]
|
|
250
|
-
),
|
|
251
|
-
model_params: Dict[str, int | float | str | bool],
|
|
280
|
+
Model: Type[torch.nn.Module],
|
|
281
|
+
model_params: Dict[str, Any],
|
|
252
282
|
) -> torch.nn.Module:
|
|
253
283
|
"""Builds and initializes a neural network model instance.
|
|
254
284
|
|
|
@@ -277,7 +307,6 @@ class BaseNNImputer:
|
|
|
277
307
|
self.logger.error(msg)
|
|
278
308
|
raise AttributeError(msg)
|
|
279
309
|
|
|
280
|
-
# Start with a base set of fixed (non-tuned) parameters.
|
|
281
310
|
all_params = {
|
|
282
311
|
"n_features": self.num_features_,
|
|
283
312
|
"prefix": self.prefix,
|
|
@@ -287,7 +316,7 @@ class BaseNNImputer:
|
|
|
287
316
|
"device": self.device,
|
|
288
317
|
}
|
|
289
318
|
|
|
290
|
-
# Update with the variable hyperparameters
|
|
319
|
+
# Update with the variable hyperparameters
|
|
291
320
|
all_params.update(model_params)
|
|
292
321
|
|
|
293
322
|
return Model(**all_params).to(self.device)
|
|
@@ -369,110 +398,12 @@ class BaseNNImputer:
|
|
|
369
398
|
X (np.ndarray | pd.DataFrame | list | None): The input data with missing values.
|
|
370
399
|
|
|
371
400
|
Returns:
|
|
372
|
-
np.ndarray:
|
|
401
|
+
np.ndarray: IUPAC strings with missing values imputed.
|
|
373
402
|
"""
|
|
374
403
|
msg = "Method ``transform()`` must be implemented in the child class."
|
|
375
404
|
self.logger.error(msg)
|
|
376
405
|
raise NotImplementedError(msg)
|
|
377
406
|
|
|
378
|
-
def _class_balanced_weights_from_mask(
|
|
379
|
-
self,
|
|
380
|
-
y: np.ndarray,
|
|
381
|
-
train_mask: np.ndarray,
|
|
382
|
-
num_classes: int,
|
|
383
|
-
beta: float = 0.9999,
|
|
384
|
-
max_ratio: float = 5.0,
|
|
385
|
-
mode: Literal["allele", "genotype10"] = "allele",
|
|
386
|
-
) -> torch.Tensor:
|
|
387
|
-
"""Class-balanced weights (Cui et al. 2019) with overflow-safe effective number.
|
|
388
|
-
|
|
389
|
-
mode="allele": y is 1D alleles in {0..3}, train_mask same shape. mode="genotype10": y is (nS,nF,2) alleles; train_mask is (nS,nF) loci where both alleles known.
|
|
390
|
-
|
|
391
|
-
Args:
|
|
392
|
-
y (np.ndarray): Ground truth labels.
|
|
393
|
-
train_mask (np.ndarray): Boolean mask of training examples (same shape as y or y without last dim for genotype10).
|
|
394
|
-
num_classes (int): Number of classes.
|
|
395
|
-
beta (float): Hyperparameter for effective number calculation. Clamped to (0,1). Default is 0.9999.
|
|
396
|
-
max_ratio (float): Maximum allowed ratio between largest and smallest non-zero weight. Default is 5.0.
|
|
397
|
-
mode (Literal["allele", "genotype10"]): Whether y contains allele labels or 10-class genotypes. Default is "allele".
|
|
398
|
-
|
|
399
|
-
Returns:
|
|
400
|
-
torch.Tensor: Class weights of shape (num_classes,). Mean weight is 1.0, zero-weight classes remain zero.
|
|
401
|
-
"""
|
|
402
|
-
if mode == "allele":
|
|
403
|
-
valid = (y >= 0) & train_mask
|
|
404
|
-
cls, cnt = np.unique(y[valid].astype(np.int64), return_counts=True)
|
|
405
|
-
counts = np.zeros(num_classes, dtype=np.float64)
|
|
406
|
-
counts[cls] = cnt
|
|
407
|
-
|
|
408
|
-
elif mode == "genotype10":
|
|
409
|
-
if y.ndim != 3 or y.shape[-1] != 2:
|
|
410
|
-
msg = "For genotype10, y must be (nS,nF,2)."
|
|
411
|
-
self.logger.error(msg)
|
|
412
|
-
raise ValueError(msg)
|
|
413
|
-
|
|
414
|
-
if train_mask.shape != y.shape[:2]:
|
|
415
|
-
msg = "train_mask must be (nS,nF) for genotype10."
|
|
416
|
-
self.logger.error(msg)
|
|
417
|
-
raise ValueError(msg)
|
|
418
|
-
|
|
419
|
-
# only loci where both alleles known and in training
|
|
420
|
-
m = train_mask & np.all(y >= 0, axis=-1)
|
|
421
|
-
if not np.any(m):
|
|
422
|
-
counts = np.zeros(num_classes, dtype=np.float64)
|
|
423
|
-
|
|
424
|
-
else:
|
|
425
|
-
a1 = y[:, :, 0][m].astype(int)
|
|
426
|
-
a2 = y[:, :, 1][m].astype(int)
|
|
427
|
-
lo, hi = np.minimum(a1, a2), np.maximum(a1, a2)
|
|
428
|
-
# map to 10-class index
|
|
429
|
-
map10 = self.pgenc.map10
|
|
430
|
-
idx10 = map10[lo, hi]
|
|
431
|
-
idx10 = idx10[(idx10 >= 0) & (idx10 < num_classes)]
|
|
432
|
-
counts = np.bincount(idx10, minlength=num_classes).astype(np.float64)
|
|
433
|
-
else:
|
|
434
|
-
msg = f"Unknown mode supplied to _class_balanced_weights_from_mask: {mode}"
|
|
435
|
-
self.logger.error(msg)
|
|
436
|
-
raise ValueError(msg)
|
|
437
|
-
|
|
438
|
-
# ---- Effective number ----
|
|
439
|
-
beta = float(beta)
|
|
440
|
-
|
|
441
|
-
# clamp beta ∈ (0,1)
|
|
442
|
-
if not np.isfinite(beta):
|
|
443
|
-
beta = 0.9999
|
|
444
|
-
|
|
445
|
-
beta = min(max(beta, 1e-8), 1.0 - 1e-8)
|
|
446
|
-
|
|
447
|
-
logb = np.log(beta) # < 0
|
|
448
|
-
t = counts * logb # ≤ 0
|
|
449
|
-
|
|
450
|
-
# 1 - beta^n = 1 - exp(n*log(beta)) = -(exp(n*log(beta)) - 1)
|
|
451
|
-
# use expm1 for accuracy near 0; for very negative t, eff≈1.0
|
|
452
|
-
eff = np.where(t > -50.0, -np.expm1(t), 1.0)
|
|
453
|
-
|
|
454
|
-
# class-balanced weights
|
|
455
|
-
w = (1.0 - beta) / (eff + 1e-12)
|
|
456
|
-
|
|
457
|
-
# Give unseen classes the largest non-zero weight (keeps it learnable)
|
|
458
|
-
if np.any(counts == 0) and np.any(counts > 0):
|
|
459
|
-
w[counts == 0] = w[counts > 0].max()
|
|
460
|
-
|
|
461
|
-
# normalize by mean of non-zero
|
|
462
|
-
nz = w > 0
|
|
463
|
-
w[nz] /= w[nz].mean() + 1e-12
|
|
464
|
-
|
|
465
|
-
# cap spread consistently with a single 'cap'
|
|
466
|
-
cap = float(max_ratio) if max_ratio is not None else 10.0
|
|
467
|
-
cap = max(cap, 5.0) # ensure we allow some differentiation
|
|
468
|
-
if np.any(nz):
|
|
469
|
-
spread = w[nz].max() / max(w[nz].min(), 1e-12)
|
|
470
|
-
if spread > cap:
|
|
471
|
-
scale = cap / spread
|
|
472
|
-
w[nz] = 1.0 + (w[nz] - 1.0) * scale
|
|
473
|
-
|
|
474
|
-
return torch.tensor(w.astype(np.float32), device=self.device)
|
|
475
|
-
|
|
476
407
|
def _select_device(self, device: Literal["gpu", "cpu", "mps"]) -> torch.device:
|
|
477
408
|
"""Selects the appropriate PyTorch device based on user preference and availability.
|
|
478
409
|
|
|
@@ -484,36 +415,37 @@ class BaseNNImputer:
|
|
|
484
415
|
Returns:
|
|
485
416
|
torch.device: The selected PyTorch device.
|
|
486
417
|
"""
|
|
487
|
-
dvc
|
|
488
|
-
dvc = dvc.lower().strip()
|
|
418
|
+
dvc = device.lower().strip()
|
|
489
419
|
if dvc == "cpu":
|
|
490
|
-
self.logger.info("Using PyTorch device: CPU.")
|
|
491
420
|
return torch.device("cpu")
|
|
492
421
|
if dvc == "mps":
|
|
493
422
|
if torch.backends.mps.is_available():
|
|
494
|
-
self.logger.info("Using PyTorch device: mps.")
|
|
495
423
|
return torch.device("mps")
|
|
496
|
-
self.logger.warning("MPS unavailable; falling back to CPU.")
|
|
497
424
|
return torch.device("cpu")
|
|
498
|
-
# gpu
|
|
499
425
|
if torch.cuda.is_available():
|
|
500
|
-
self.logger.info("Using PyTorch device: cuda.")
|
|
501
426
|
return torch.device("cuda")
|
|
502
|
-
self.logger.warning("CUDA unavailable; falling back to CPU.")
|
|
503
427
|
return torch.device("cpu")
|
|
504
428
|
|
|
505
|
-
def _create_model_directories(
|
|
429
|
+
def _create_model_directories(
|
|
430
|
+
self, prefix: str, outdirs: List[str], *, outdir: Path | str | None = None
|
|
431
|
+
) -> None:
|
|
506
432
|
"""Creates the directory structure for storing model outputs.
|
|
507
433
|
|
|
508
|
-
This method sets up a standardized folder hierarchy for saving models,
|
|
434
|
+
This method sets up a standardized folder hierarchy for saving models,
|
|
435
|
+
plots, metrics, and optimization results, organized under a main directory
|
|
436
|
+
named after the provided prefix. The current implementation always uses
|
|
437
|
+
``<cwd>/<prefix>_output`` regardless of ``outdir``.
|
|
509
438
|
|
|
510
439
|
Args:
|
|
511
440
|
prefix (str): The prefix for the main output directory.
|
|
512
441
|
outdirs (List[str]): A list of subdirectory names to create within the main directory.
|
|
442
|
+
outdir (Path | str | None): Requested base output directory (currently ignored).
|
|
513
443
|
|
|
514
444
|
Raises:
|
|
515
445
|
Exception: If any of the directories cannot be created.
|
|
516
446
|
"""
|
|
447
|
+
base_root = Path(outdir) if outdir is not None else Path.cwd()
|
|
448
|
+
formatted_output_dir = base_root / f"{prefix}_output"
|
|
517
449
|
formatted_output_dir = Path(f"{prefix}_output")
|
|
518
450
|
base_dir = formatted_output_dir / "Unsupervised"
|
|
519
451
|
|
|
@@ -527,27 +459,16 @@ class BaseNNImputer:
|
|
|
527
459
|
self.logger.error(msg)
|
|
528
460
|
raise Exception(msg)
|
|
529
461
|
|
|
530
|
-
def _clear_resources(
|
|
531
|
-
self,
|
|
532
|
-
model: torch.nn.Module,
|
|
533
|
-
train_loader: torch.utils.data.DataLoader,
|
|
534
|
-
latent_vectors: torch.nn.Parameter | None = None,
|
|
535
|
-
) -> None:
|
|
462
|
+
def _clear_resources(self, model: torch.nn.Module) -> None:
|
|
536
463
|
"""Releases GPU and CPU memory after an Optuna trial.
|
|
537
464
|
|
|
538
465
|
This is a crucial step during hyperparameter tuning to prevent memory leaks between trials, ensuring that each trial runs in a clean environment.
|
|
539
466
|
|
|
540
467
|
Args:
|
|
541
468
|
model (torch.nn.Module): The model from the completed trial.
|
|
542
|
-
train_loader (torch.utils.data.DataLoader): The data loader from the trial.
|
|
543
|
-
latent_vectors (torch.nn.Parameter | None): The latent vectors from the trial.
|
|
544
469
|
"""
|
|
545
470
|
try:
|
|
546
|
-
del model
|
|
547
|
-
|
|
548
|
-
if latent_vectors is not None:
|
|
549
|
-
del latent_vectors
|
|
550
|
-
|
|
471
|
+
del model
|
|
551
472
|
except NameError:
|
|
552
473
|
pass
|
|
553
474
|
|
|
@@ -568,7 +489,7 @@ class BaseNNImputer:
|
|
|
568
489
|
y_pred: np.ndarray,
|
|
569
490
|
metrics: Dict[str, float],
|
|
570
491
|
msg: str,
|
|
571
|
-
):
|
|
492
|
+
) -> None:
|
|
572
493
|
"""Generate and save evaluation visualizations.
|
|
573
494
|
|
|
574
495
|
3-class (zygosity) or 10-class (IUPAC) depending on `labels` length.
|
|
@@ -586,20 +507,109 @@ class BaseNNImputer:
|
|
|
586
507
|
prefix = "zygosity" if len(labels) == 3 else "iupac"
|
|
587
508
|
n_labels = len(labels)
|
|
588
509
|
|
|
589
|
-
self.
|
|
590
|
-
|
|
591
|
-
|
|
592
|
-
|
|
593
|
-
|
|
594
|
-
|
|
510
|
+
if self.show_plots:
|
|
511
|
+
self.plotter_.plot_metrics(
|
|
512
|
+
y_true=y_true,
|
|
513
|
+
y_pred_proba=y_pred_proba,
|
|
514
|
+
metrics=metrics,
|
|
515
|
+
label_names=labels,
|
|
516
|
+
prefix=f"geno{n_labels}_{prefix}",
|
|
517
|
+
)
|
|
518
|
+
self.plotter_.plot_confusion_matrix(
|
|
519
|
+
y_true_1d=y_true,
|
|
520
|
+
y_pred_1d=y_pred,
|
|
521
|
+
label_names=labels,
|
|
522
|
+
prefix=f"geno{n_labels}_{prefix}",
|
|
523
|
+
)
|
|
524
|
+
|
|
525
|
+
def _additional_metrics(
|
|
526
|
+
self,
|
|
527
|
+
y_true: np.ndarray,
|
|
528
|
+
y_pred: np.ndarray,
|
|
529
|
+
labels: list[int],
|
|
530
|
+
report_names: list[str],
|
|
531
|
+
report: dict,
|
|
532
|
+
) -> dict[str, dict[str, float] | float]:
|
|
533
|
+
"""Compute additional metrics and augment the report dictionary.
|
|
534
|
+
|
|
535
|
+
Args:
|
|
536
|
+
y_true (np.ndarray): True genotypes.
|
|
537
|
+
y_pred (np.ndarray): Predicted genotypes.
|
|
538
|
+
labels (list[int]): List of label indices.
|
|
539
|
+
report_names (list[str]): List of report names corresponding to labels.
|
|
540
|
+
report (dict): Classification report dictionary to augment.
|
|
541
|
+
|
|
542
|
+
Returns:
|
|
543
|
+
dict[str, dict[str, float] | float]: Augmented report dictionary with additional metrics.
|
|
544
|
+
"""
|
|
545
|
+
# Create an identity matrix and use the targets array as indices
|
|
546
|
+
y_score = np.eye(len(report_names))[y_pred]
|
|
547
|
+
|
|
548
|
+
# Per-class metrics
|
|
549
|
+
ap_pc = average_precision_score(y_true, y_score, average=None)
|
|
550
|
+
jaccard_pc = jaccard_score(
|
|
551
|
+
y_true, y_pred, average=None, labels=labels, zero_division=0
|
|
595
552
|
)
|
|
596
|
-
|
|
597
|
-
|
|
598
|
-
|
|
599
|
-
|
|
600
|
-
|
|
553
|
+
|
|
554
|
+
# Macro/weighted metrics
|
|
555
|
+
ap_macro = average_precision_score(y_true, y_score, average="macro")
|
|
556
|
+
ap_weighted = average_precision_score(y_true, y_score, average="weighted")
|
|
557
|
+
jaccard_macro = jaccard_score(y_true, y_pred, average="macro", zero_division=0)
|
|
558
|
+
jaccard_weighted = jaccard_score(
|
|
559
|
+
y_true, y_pred, average="weighted", zero_division=0
|
|
601
560
|
)
|
|
602
561
|
|
|
562
|
+
# Matthews correlation coefficient (MCC)
|
|
563
|
+
mcc = matthews_corrcoef(y_true, y_pred)
|
|
564
|
+
|
|
565
|
+
if not isinstance(ap_pc, np.ndarray):
|
|
566
|
+
msg = "average_precision_score or f1_score did not return np.ndarray as expected."
|
|
567
|
+
self.logger.error(msg)
|
|
568
|
+
raise TypeError(msg)
|
|
569
|
+
|
|
570
|
+
if not isinstance(jaccard_pc, np.ndarray):
|
|
571
|
+
msg = "jaccard_score did not return np.ndarray as expected."
|
|
572
|
+
self.logger.error(msg)
|
|
573
|
+
raise TypeError(msg)
|
|
574
|
+
|
|
575
|
+
# Add per-class metrics
|
|
576
|
+
report_full = {}
|
|
577
|
+
dd_subset = {
|
|
578
|
+
k: v for k, v in report.items() if k in report_names and isinstance(v, dict)
|
|
579
|
+
}
|
|
580
|
+
for i, class_name in enumerate(report_names):
|
|
581
|
+
class_report: dict[str, float] = {}
|
|
582
|
+
if class_name in dd_subset:
|
|
583
|
+
class_report = dd_subset[class_name]
|
|
584
|
+
|
|
585
|
+
if isinstance(class_report, float) or not class_report:
|
|
586
|
+
continue
|
|
587
|
+
|
|
588
|
+
report_full[class_name] = dict(class_report)
|
|
589
|
+
report_full[class_name]["average-precision"] = float(ap_pc[i])
|
|
590
|
+
report_full[class_name]["jaccard"] = float(jaccard_pc[i])
|
|
591
|
+
|
|
592
|
+
macro_avg = report.get("macro avg")
|
|
593
|
+
if isinstance(macro_avg, dict):
|
|
594
|
+
report_full["macro avg"] = dict(macro_avg)
|
|
595
|
+
report_full["macro avg"]["average-precision"] = float(ap_macro)
|
|
596
|
+
report_full["macro avg"]["jaccard"] = float(jaccard_macro)
|
|
597
|
+
|
|
598
|
+
weighted_avg = report.get("weighted avg")
|
|
599
|
+
if isinstance(weighted_avg, dict):
|
|
600
|
+
report_full["weighted avg"] = dict(weighted_avg)
|
|
601
|
+
report_full["weighted avg"]["average-precision"] = float(ap_weighted)
|
|
602
|
+
report_full["weighted avg"]["jaccard"] = float(jaccard_weighted)
|
|
603
|
+
|
|
604
|
+
# Add scalar summary metrics
|
|
605
|
+
report_full["mcc"] = float(mcc)
|
|
606
|
+
accuracy_val = report.get("accuracy")
|
|
607
|
+
|
|
608
|
+
if isinstance(accuracy_val, (int, float)):
|
|
609
|
+
report_full["accuracy"] = float(accuracy_val)
|
|
610
|
+
|
|
611
|
+
return report_full
|
|
612
|
+
|
|
603
613
|
def _make_class_reports(
|
|
604
614
|
self,
|
|
605
615
|
y_true: np.ndarray,
|
|
@@ -617,27 +627,28 @@ class BaseNNImputer:
|
|
|
617
627
|
y_pred (np.ndarray): Predicted labels (1D array).
|
|
618
628
|
metrics (Dict[str, float]): Computed metrics.
|
|
619
629
|
y_pred_proba (np.ndarray | None): Predicted probabilities (2D array). Defaults to None.
|
|
620
|
-
labels (List[str]): Class label names
|
|
621
|
-
(default: ["REF", "HET", "ALT"] for 3-class).
|
|
630
|
+
labels (List[str]): Class label names (default: ["REF", "HET", "ALT"] for 3-class).
|
|
622
631
|
"""
|
|
623
|
-
report_name = "zygosity" if len(labels)
|
|
632
|
+
report_name = "zygosity" if len(labels) <= 3 else "iupac"
|
|
624
633
|
middle = "IUPAC" if report_name == "iupac" else "Zygosity"
|
|
625
634
|
|
|
626
|
-
msg = f"{middle} Report (on {
|
|
635
|
+
msg = f"{middle} Report (on {y_pred.size} total genotypes)"
|
|
627
636
|
self.logger.info(msg)
|
|
628
637
|
|
|
629
638
|
if y_pred_proba is not None:
|
|
630
|
-
self.
|
|
631
|
-
|
|
632
|
-
|
|
633
|
-
|
|
634
|
-
|
|
635
|
-
|
|
636
|
-
|
|
639
|
+
if self.show_plots:
|
|
640
|
+
self.plotter_.plot_metrics(
|
|
641
|
+
y_true,
|
|
642
|
+
y_pred_proba,
|
|
643
|
+
metrics,
|
|
644
|
+
label_names=labels,
|
|
645
|
+
prefix=report_name,
|
|
646
|
+
)
|
|
637
647
|
|
|
638
|
-
|
|
639
|
-
|
|
640
|
-
|
|
648
|
+
if self.show_plots:
|
|
649
|
+
self.plotter_.plot_confusion_matrix(
|
|
650
|
+
y_true, y_pred, label_names=labels, prefix=report_name
|
|
651
|
+
)
|
|
641
652
|
|
|
642
653
|
report: str | dict = classification_report(
|
|
643
654
|
y_true,
|
|
@@ -650,62 +661,63 @@ class BaseNNImputer:
|
|
|
650
661
|
|
|
651
662
|
if not isinstance(report, dict):
|
|
652
663
|
msg = "Expected classification_report to return a dict."
|
|
653
|
-
self.logger.error(msg)
|
|
664
|
+
self.logger.error(msg, exc_info=True)
|
|
654
665
|
raise ValueError(msg)
|
|
655
666
|
|
|
656
|
-
|
|
657
|
-
|
|
658
|
-
|
|
659
|
-
|
|
660
|
-
|
|
661
|
-
|
|
662
|
-
|
|
663
|
-
|
|
664
|
-
|
|
665
|
-
|
|
666
|
-
|
|
667
|
+
if self.show_plots:
|
|
668
|
+
viz = ClassificationReportVisualizer(reset_kwargs=self.plotter_.param_dict)
|
|
669
|
+
try:
|
|
670
|
+
plots = viz.plot_all(
|
|
671
|
+
report, # type: ignore
|
|
672
|
+
title_prefix=f"{self.model_name} {middle} Report",
|
|
673
|
+
show=self.show_plots,
|
|
674
|
+
heatmap_classes_only=True,
|
|
675
|
+
)
|
|
676
|
+
finally:
|
|
677
|
+
viz._reset_mpl_style()
|
|
678
|
+
|
|
679
|
+
for name, fig in plots.items():
|
|
680
|
+
fout = (
|
|
681
|
+
self.plots_dir / f"{report_name}_report_{name}.{self.plot_format}"
|
|
682
|
+
)
|
|
683
|
+
if hasattr(fig, "savefig") and isinstance(fig, Figure):
|
|
684
|
+
fig.savefig(fout, dpi=300, facecolor="#111122")
|
|
685
|
+
plt.close(fig)
|
|
686
|
+
elif hasattr(fig, "write_html") and isinstance(fig, go.Figure):
|
|
687
|
+
fout_html = fout.with_suffix(".html")
|
|
688
|
+
fig.write_html(file=fout_html)
|
|
689
|
+
|
|
690
|
+
SNPioMultiQC.queue_html(
|
|
691
|
+
fout_html,
|
|
692
|
+
panel_id=f"pgsui_{self.model_name.lower()}_{report_name}_radar",
|
|
693
|
+
section=f"PG-SUI: {self.model_name} Model Imputation",
|
|
694
|
+
title=f"{self.model_name} {middle} Radar Plot",
|
|
695
|
+
index_label=name,
|
|
696
|
+
description=f"{self.model_name} {middle} {len(labels)}-base Radar Plot. This radar plot visualizes model performance for three metrics per-class: precision, recall, and F1-score. Each axis represents one of these metrics, allowing for a quick visual assessment of the model's strengths and weaknesses. Higher values towards the outer edge indicate better performance.",
|
|
697
|
+
)
|
|
698
|
+
|
|
699
|
+
if not self.is_haploid_:
|
|
700
|
+
msg = f"Ploidy: {self.ploidy}. Evaluating per genotype (REF, HET, ALT)."
|
|
701
|
+
self.logger.info(msg)
|
|
702
|
+
|
|
703
|
+
report_full = self._additional_metrics(
|
|
704
|
+
y_true,
|
|
705
|
+
y_pred,
|
|
706
|
+
labels=list(range(len(labels))),
|
|
707
|
+
report_names=labels,
|
|
708
|
+
report=report,
|
|
709
|
+
)
|
|
710
|
+
|
|
711
|
+
if self.verbose or self.debug:
|
|
667
712
|
pm = PrettyMetrics(
|
|
668
|
-
|
|
669
|
-
precision=
|
|
713
|
+
report_full,
|
|
714
|
+
precision=2,
|
|
670
715
|
title=f"{self.model_name} {middle} Report",
|
|
671
716
|
)
|
|
672
717
|
pm.render()
|
|
673
718
|
|
|
674
719
|
with open(self.metrics_dir / f"{report_name}_report.json", "w") as f:
|
|
675
|
-
json.dump(
|
|
676
|
-
|
|
677
|
-
viz = ClassificationReportVisualizer(reset_kwargs=self.plotter_.param_dict)
|
|
678
|
-
|
|
679
|
-
plots = viz.plot_all(
|
|
680
|
-
report, # type: ignore
|
|
681
|
-
title_prefix=f"{self.model_name} {middle} Report",
|
|
682
|
-
show=getattr(self, "show_plots", False),
|
|
683
|
-
heatmap_classes_only=True,
|
|
684
|
-
)
|
|
685
|
-
|
|
686
|
-
for name, fig in plots.items():
|
|
687
|
-
fout = self.plots_dir / f"{report_name}_report_{name}.{self.plot_format}"
|
|
688
|
-
if hasattr(fig, "savefig") and isinstance(fig, Figure):
|
|
689
|
-
fig.savefig(fout, dpi=300, facecolor="#111122")
|
|
690
|
-
plt.close(fig)
|
|
691
|
-
elif hasattr(fig, "write_html") and isinstance(fig, go.Figure):
|
|
692
|
-
fout_html = fout.with_suffix(".html")
|
|
693
|
-
fig.write_html(file=fout_html)
|
|
694
|
-
|
|
695
|
-
SNPioMultiQC.queue_html(
|
|
696
|
-
fout_html,
|
|
697
|
-
panel_id=f"pgsui_{self.model_name.lower()}_{report_name}_radar",
|
|
698
|
-
section=f"PG-SUI: {self.model_name} Model Imputation",
|
|
699
|
-
title=f"{self.model_name} {middle} Radar Plot",
|
|
700
|
-
index_label=name,
|
|
701
|
-
description=f"{self.model_name} {middle} {len(labels)}-base Radar Plot. This radar plot visualizes model performance for three metrics per-class: precision, recall, and F1-score. Each axis represents one of these metrics, allowing for a quick visual assessment of the model's strengths and weaknesses. Higher values towards the outer edge indicate better performance.",
|
|
702
|
-
)
|
|
703
|
-
|
|
704
|
-
if not self.is_haploid:
|
|
705
|
-
msg = f"Ploidy: {self.ploidy}. Evaluating per allele."
|
|
706
|
-
self.logger.info(msg)
|
|
707
|
-
|
|
708
|
-
viz._reset_mpl_style()
|
|
720
|
+
json.dump(report_full, f, indent=4)
|
|
709
721
|
|
|
710
722
|
def _compute_hidden_layer_sizes(
|
|
711
723
|
self,
|
|
@@ -713,6 +725,7 @@ class BaseNNImputer:
|
|
|
713
725
|
n_outputs: int,
|
|
714
726
|
n_samples: int,
|
|
715
727
|
n_hidden: int,
|
|
728
|
+
latent_dim: int,
|
|
716
729
|
*,
|
|
717
730
|
alpha: float = 4.0,
|
|
718
731
|
schedule: str = "pyramid",
|
|
@@ -724,182 +737,439 @@ class BaseNNImputer:
|
|
|
724
737
|
) -> list[int]:
|
|
725
738
|
"""Compute hidden layer sizes given problem scale and a layer count.
|
|
726
739
|
|
|
727
|
-
|
|
740
|
+
Notes:
|
|
741
|
+
- Returns sizes for *hidden layers only* (length = n_hidden).
|
|
742
|
+
- Does NOT include the input layer (n_inputs) or the latent layer (latent_dim).
|
|
743
|
+
- Enforces a latent-aware minimum: one discrete level above latent_dim, where a level is `multiple_of`.
|
|
744
|
+
- Enforces *strictly decreasing* hidden sizes (no repeats). This may require bumping `base` upward.
|
|
728
745
|
|
|
729
746
|
Args:
|
|
730
|
-
n_inputs
|
|
731
|
-
n_outputs
|
|
732
|
-
n_samples
|
|
733
|
-
n_hidden
|
|
734
|
-
|
|
735
|
-
|
|
736
|
-
|
|
737
|
-
|
|
738
|
-
|
|
739
|
-
|
|
740
|
-
|
|
747
|
+
n_inputs: Number of input features (e.g., flattened one-hot: num_features * num_classes).
|
|
748
|
+
n_outputs: Number of output classes (often equals num_classes).
|
|
749
|
+
n_samples: Number of training samples.
|
|
750
|
+
n_hidden: Number of hidden layers (excluding input and latent layers).
|
|
751
|
+
latent_dim: Latent dimensionality (not returned, used only to set a floor).
|
|
752
|
+
alpha: Scaling factor for base layer size.
|
|
753
|
+
schedule: Size schedule ("pyramid" or "linear").
|
|
754
|
+
min_size: Minimum layer size floor before latent-aware adjustment.
|
|
755
|
+
max_size: Maximum layer size cap. If None, a heuristic cap is used.
|
|
756
|
+
multiple_of: Hidden sizes are multiples of this value.
|
|
757
|
+
decay: Pyramid decay factor. If None, computed to land near the target.
|
|
758
|
+
cap_by_inputs: If True, cap layer sizes to n_inputs.
|
|
741
759
|
|
|
742
760
|
Returns:
|
|
743
|
-
list[int]:
|
|
761
|
+
list[int]: Hidden layer sizes (len = n_hidden).
|
|
744
762
|
|
|
745
763
|
Raises:
|
|
746
|
-
ValueError:
|
|
747
|
-
TypeError: If any argument is not of the expected type.
|
|
748
|
-
|
|
749
|
-
Notes:
|
|
750
|
-
- If n_hidden is 0, returns an empty list.
|
|
751
|
-
- The base layer size is computed as ceil(n_samples / (alpha * (n_inputs + n_outputs))).
|
|
752
|
-
- The sizes are adjusted according to the specified schedule and constraints.
|
|
764
|
+
ValueError: On invalid arguments or conflicting constraints.
|
|
753
765
|
"""
|
|
766
|
+
# ----------------------------
|
|
767
|
+
# Basic validation
|
|
768
|
+
# ----------------------------
|
|
754
769
|
if n_hidden < 0:
|
|
755
770
|
msg = f"n_hidden must be >= 0, got {n_hidden}."
|
|
756
771
|
self.logger.error(msg)
|
|
757
772
|
raise ValueError(msg)
|
|
758
773
|
|
|
759
|
-
if
|
|
760
|
-
|
|
774
|
+
if n_hidden == 0:
|
|
775
|
+
return []
|
|
776
|
+
|
|
777
|
+
if n_inputs <= 0:
|
|
778
|
+
msg = f"n_inputs must be > 0, got {n_inputs}."
|
|
761
779
|
self.logger.error(msg)
|
|
762
780
|
raise ValueError(msg)
|
|
763
781
|
|
|
764
|
-
if
|
|
765
|
-
|
|
782
|
+
if n_outputs <= 0:
|
|
783
|
+
msg = f"n_outputs must be > 0, got {n_outputs}."
|
|
784
|
+
self.logger.error(msg)
|
|
785
|
+
raise ValueError(msg)
|
|
766
786
|
|
|
767
|
-
|
|
787
|
+
if n_samples <= 0:
|
|
788
|
+
msg = f"n_samples must be > 0, got {n_samples}."
|
|
789
|
+
self.logger.error(msg)
|
|
790
|
+
raise ValueError(msg)
|
|
768
791
|
|
|
769
|
-
if
|
|
770
|
-
msg = f"
|
|
792
|
+
if latent_dim <= 0:
|
|
793
|
+
msg = f"latent_dim must be > 0, got {latent_dim}."
|
|
771
794
|
self.logger.error(msg)
|
|
772
795
|
raise ValueError(msg)
|
|
773
796
|
|
|
774
|
-
|
|
797
|
+
if multiple_of <= 0:
|
|
798
|
+
msg = f"multiple_of must be > 0, got {multiple_of}."
|
|
799
|
+
self.logger.error(msg)
|
|
800
|
+
raise ValueError(msg)
|
|
801
|
+
|
|
802
|
+
if alpha <= 0:
|
|
803
|
+
msg = f"alpha must be > 0, got {alpha}."
|
|
804
|
+
self.logger.error(msg)
|
|
805
|
+
raise ValueError(msg)
|
|
806
|
+
|
|
807
|
+
schedule = str(schedule).lower().strip()
|
|
808
|
+
if schedule not in {"pyramid", "linear"}:
|
|
809
|
+
msg = f"Invalid schedule '{schedule}'. Must be 'pyramid' or 'linear'."
|
|
810
|
+
self.logger.error(msg)
|
|
811
|
+
raise ValueError(msg)
|
|
812
|
+
|
|
813
|
+
# ----------------------------
|
|
814
|
+
# Latent-aware minimum floor
|
|
815
|
+
# ----------------------------
|
|
816
|
+
# Smallest multiple_of strictly greater than latent_dim
|
|
817
|
+
min_hidden_floor = int(np.ceil((latent_dim + 1) / multiple_of) * multiple_of)
|
|
818
|
+
effective_min = max(int(min_size), min_hidden_floor)
|
|
819
|
+
|
|
820
|
+
if cap_by_inputs and n_inputs < effective_min:
|
|
821
|
+
msg = (
|
|
822
|
+
"Cannot satisfy latent-aware minimum hidden size with cap_by_inputs=True. "
|
|
823
|
+
f"Required hidden size >= {effective_min} (one level above latent_dim={latent_dim}), "
|
|
824
|
+
f"but n_inputs={n_inputs}. Set cap_by_inputs=False or reduce latent_dim/multiple_of."
|
|
825
|
+
)
|
|
826
|
+
self.logger.error(msg)
|
|
827
|
+
raise ValueError(msg)
|
|
828
|
+
|
|
829
|
+
# ----------------------------
|
|
830
|
+
# Infer num_features (if using flattened one-hot: n_inputs = num_features * num_classes)
|
|
831
|
+
# ----------------------------
|
|
832
|
+
if n_inputs % n_outputs == 0:
|
|
833
|
+
num_features = n_inputs // n_outputs
|
|
834
|
+
else:
|
|
835
|
+
num_features = n_inputs
|
|
836
|
+
self.logger.warning(
|
|
837
|
+
"n_inputs is not divisible by n_outputs; falling back to num_features=n_inputs "
|
|
838
|
+
f"(n_inputs={n_inputs}, n_outputs={n_outputs}). If using one-hot flattening, "
|
|
839
|
+
"pass n_outputs=num_classes so num_features can be inferred correctly."
|
|
840
|
+
)
|
|
775
841
|
|
|
842
|
+
# ----------------------------
|
|
843
|
+
# Base size heuristic (feature-matrix aware; avoids collapse for huge n_inputs)
|
|
844
|
+
# ----------------------------
|
|
845
|
+
obs_scale = (float(n_samples) * float(num_features)) / float(
|
|
846
|
+
num_features + n_outputs
|
|
847
|
+
)
|
|
848
|
+
base = int(np.ceil(float(alpha) * np.sqrt(obs_scale)))
|
|
849
|
+
|
|
850
|
+
# ----------------------------
|
|
851
|
+
# Determine max_size
|
|
852
|
+
# ----------------------------
|
|
776
853
|
if max_size is None:
|
|
777
|
-
max_size = max(n_inputs, base)
|
|
854
|
+
max_size = max(int(n_inputs), int(base), int(effective_min))
|
|
855
|
+
|
|
856
|
+
if cap_by_inputs:
|
|
857
|
+
max_size = min(int(max_size), int(n_inputs))
|
|
858
|
+
else:
|
|
859
|
+
max_size = int(max_size)
|
|
860
|
+
|
|
861
|
+
if max_size < effective_min:
|
|
862
|
+
msg = (
|
|
863
|
+
f"max_size ({max_size}) must be >= effective_min ({effective_min}), where effective_min "
|
|
864
|
+
f"is max(min_size={min_size}, one-level-above latent_dim={latent_dim})."
|
|
865
|
+
)
|
|
866
|
+
self.logger.error(msg)
|
|
867
|
+
raise ValueError(msg)
|
|
868
|
+
|
|
869
|
+
# Round base up to a multiple and clip to bounds
|
|
870
|
+
base = int(np.clip(base, effective_min, max_size))
|
|
871
|
+
base = int(np.ceil(base / multiple_of) * multiple_of)
|
|
872
|
+
base = int(np.clip(base, effective_min, max_size))
|
|
873
|
+
|
|
874
|
+
# ----------------------------
|
|
875
|
+
# Enforce "no repeats" feasibility in discrete levels
|
|
876
|
+
# Need n_hidden distinct multiples between base and effective_min:
|
|
877
|
+
# base >= effective_min + (n_hidden - 1) * multiple_of
|
|
878
|
+
# ----------------------------
|
|
879
|
+
required_min_base = effective_min + (n_hidden - 1) * multiple_of
|
|
880
|
+
|
|
881
|
+
if required_min_base > max_size:
|
|
882
|
+
msg = (
|
|
883
|
+
"Cannot build strictly-decreasing (no-repeat) hidden sizes under current constraints. "
|
|
884
|
+
f"Need base >= {required_min_base} to fit n_hidden={n_hidden} distinct layers "
|
|
885
|
+
f"with multiple_of={multiple_of} down to effective_min={effective_min}, "
|
|
886
|
+
f"but max_size={max_size}. Reduce n_hidden, reduce multiple_of, lower latent_dim/min_size, "
|
|
887
|
+
"or increase max_size / set cap_by_inputs=False."
|
|
888
|
+
)
|
|
889
|
+
self.logger.error(msg)
|
|
890
|
+
raise ValueError(msg)
|
|
778
891
|
|
|
779
|
-
|
|
892
|
+
if base < required_min_base:
|
|
893
|
+
# Bump base upward so a strict staircase is possible
|
|
894
|
+
base = required_min_base
|
|
895
|
+
base = int(np.ceil(base / multiple_of) * multiple_of)
|
|
896
|
+
base = int(np.clip(base, effective_min, max_size))
|
|
780
897
|
|
|
781
|
-
|
|
782
|
-
|
|
898
|
+
# Work in "levels" of multiple_of for guaranteed uniqueness
|
|
899
|
+
start_level = base // multiple_of
|
|
900
|
+
end_level = effective_min // multiple_of
|
|
901
|
+
|
|
902
|
+
# Sanity: distinct levels available
|
|
903
|
+
if (start_level - end_level) < (n_hidden - 1):
|
|
904
|
+
# This should not happen due to required_min_base logic, but keep a hard guard.
|
|
905
|
+
msg = (
|
|
906
|
+
"Internal constraint failure: insufficient discrete levels to enforce no repeats. "
|
|
907
|
+
f"start_level={start_level}, end_level={end_level}, n_hidden={n_hidden}."
|
|
908
|
+
)
|
|
909
|
+
self.logger.error(msg)
|
|
910
|
+
raise ValueError(msg)
|
|
911
|
+
|
|
912
|
+
# ----------------------------
|
|
913
|
+
# Build schedule in level space (integers), then convert to sizes
|
|
914
|
+
# ----------------------------
|
|
915
|
+
if n_hidden == 1:
|
|
916
|
+
levels = np.array([start_level], dtype=int)
|
|
783
917
|
|
|
784
918
|
elif schedule == "linear":
|
|
785
|
-
|
|
786
|
-
|
|
787
|
-
|
|
788
|
-
if n_hidden == 1
|
|
789
|
-
else np.linspace(base, target, num=n_hidden, dtype=float)
|
|
919
|
+
# Linear interpolation in level space, then strictify
|
|
920
|
+
levels = np.round(np.linspace(start_level, end_level, num=n_hidden)).astype(
|
|
921
|
+
int
|
|
790
922
|
)
|
|
791
923
|
|
|
924
|
+
# Enforce bounds then strict decrease
|
|
925
|
+
levels = np.clip(levels, end_level, start_level)
|
|
926
|
+
|
|
927
|
+
for i in range(1, n_hidden):
|
|
928
|
+
if levels[i] >= levels[i - 1]:
|
|
929
|
+
levels[i] = levels[i - 1] - 1
|
|
930
|
+
|
|
931
|
+
if levels[-1] < end_level:
|
|
932
|
+
msg = (
|
|
933
|
+
"Failed to enforce strictly-decreasing linear schedule without violating the floor. "
|
|
934
|
+
f"(levels[-1]={levels[-1]} < end_level={end_level}). "
|
|
935
|
+
"Reduce n_hidden or multiple_of, or increase max_size."
|
|
936
|
+
)
|
|
937
|
+
self.logger.error(msg)
|
|
938
|
+
raise ValueError(msg)
|
|
939
|
+
|
|
940
|
+
# Force exact floor at the end (still strict because we have enough room by construction)
|
|
941
|
+
levels[-1] = end_level
|
|
942
|
+
for i in range(n_hidden - 2, -1, -1):
|
|
943
|
+
if levels[i] <= levels[i + 1]:
|
|
944
|
+
levels[i] = levels[i + 1] + 1
|
|
945
|
+
|
|
946
|
+
if levels[0] > start_level:
|
|
947
|
+
# If this happens, we would need an even larger base; handle by raising base once.
|
|
948
|
+
needed_base = int(levels[0] * multiple_of)
|
|
949
|
+
if needed_base > max_size:
|
|
950
|
+
msg = (
|
|
951
|
+
"Cannot enforce strictly-decreasing linear schedule after floor anchoring; "
|
|
952
|
+
f"would require base={needed_base} > max_size={max_size}."
|
|
953
|
+
)
|
|
954
|
+
self.logger.error(msg)
|
|
955
|
+
raise ValueError(msg)
|
|
956
|
+
# Rebuild with bumped base
|
|
957
|
+
start_level = needed_base // multiple_of
|
|
958
|
+
levels = np.arange(start_level, start_level - n_hidden, -1, dtype=int)
|
|
959
|
+
levels[-1] = end_level # keep floor
|
|
960
|
+
# Ensure strict with backward adjust
|
|
961
|
+
for i in range(n_hidden - 2, -1, -1):
|
|
962
|
+
if levels[i] <= levels[i + 1]:
|
|
963
|
+
levels[i] = levels[i + 1] + 1
|
|
964
|
+
|
|
792
965
|
elif schedule == "pyramid":
|
|
793
|
-
|
|
794
|
-
|
|
966
|
+
# Geometric decay in level space (more aggressive early taper than linear)
|
|
967
|
+
if decay is not None:
|
|
968
|
+
dcy = float(decay)
|
|
795
969
|
else:
|
|
970
|
+
# Choose decay to land exactly at end_level (in float space)
|
|
971
|
+
dcy = (float(end_level) / float(start_level)) ** (
|
|
972
|
+
1.0 / float(n_hidden - 1)
|
|
973
|
+
)
|
|
974
|
+
|
|
975
|
+
# Keep it in a sensible range
|
|
976
|
+
dcy = float(np.clip(dcy, 0.05, 0.99))
|
|
977
|
+
|
|
978
|
+
exponents = np.arange(n_hidden, dtype=float)
|
|
979
|
+
levels_float = float(start_level) * (dcy**exponents)
|
|
980
|
+
|
|
981
|
+
levels = np.round(levels_float).astype(int)
|
|
982
|
+
levels = np.clip(levels, end_level, start_level)
|
|
983
|
+
|
|
984
|
+
# Anchor the last layer at the floor, then strictify backward
|
|
985
|
+
levels[-1] = end_level
|
|
986
|
+
for i in range(n_hidden - 2, -1, -1):
|
|
987
|
+
if levels[i] <= levels[i + 1]:
|
|
988
|
+
levels[i] = levels[i + 1] + 1
|
|
989
|
+
|
|
990
|
+
# If we overshot the start, bump base (once) if possible, then rebuild
|
|
991
|
+
if levels[0] > start_level:
|
|
992
|
+
needed_base = int(levels[0] * multiple_of)
|
|
993
|
+
if needed_base > max_size:
|
|
994
|
+
msg = (
|
|
995
|
+
"Cannot enforce strictly-decreasing pyramid schedule; "
|
|
996
|
+
f"would require base={needed_base} > max_size={max_size}."
|
|
997
|
+
)
|
|
998
|
+
self.logger.error(msg)
|
|
999
|
+
raise ValueError(msg)
|
|
1000
|
+
|
|
1001
|
+
start_level = needed_base // multiple_of
|
|
1002
|
+
# Recompute with new start_level and same decay (or recompute decay if decay is None)
|
|
796
1003
|
if decay is None:
|
|
797
|
-
|
|
798
|
-
|
|
799
|
-
|
|
800
|
-
|
|
801
|
-
|
|
802
|
-
|
|
803
|
-
|
|
804
|
-
|
|
1004
|
+
dcy = (float(end_level) / float(start_level)) ** (
|
|
1005
|
+
1.0 / float(n_hidden - 1)
|
|
1006
|
+
)
|
|
1007
|
+
dcy = float(np.clip(dcy, 0.05, 0.99))
|
|
1008
|
+
|
|
1009
|
+
levels_float = float(start_level) * (dcy**exponents)
|
|
1010
|
+
levels = np.round(levels_float).astype(int)
|
|
1011
|
+
levels = np.clip(levels, end_level, start_level)
|
|
1012
|
+
levels[-1] = end_level
|
|
1013
|
+
for i in range(n_hidden - 2, -1, -1):
|
|
1014
|
+
if levels[i] <= levels[i + 1]:
|
|
1015
|
+
levels[i] = levels[i + 1] + 1
|
|
805
1016
|
|
|
806
1017
|
else:
|
|
807
|
-
msg = f"Unknown schedule '{schedule}'. Use 'pyramid'
|
|
1018
|
+
msg = f"Unknown schedule '{schedule}'. Use 'pyramid' or 'linear' (constant disallowed with no repeats)."
|
|
808
1019
|
self.logger.error(msg)
|
|
809
1020
|
raise ValueError(msg)
|
|
810
1021
|
|
|
811
|
-
|
|
1022
|
+
# Convert levels -> sizes
|
|
1023
|
+
sizes = (levels * multiple_of).astype(int)
|
|
812
1024
|
|
|
813
|
-
|
|
814
|
-
|
|
1025
|
+
# Final clip (should be redundant, but safe)
|
|
1026
|
+
sizes = np.clip(sizes, effective_min, max_size).astype(int)
|
|
815
1027
|
|
|
816
|
-
|
|
817
|
-
|
|
818
|
-
|
|
1028
|
+
# Final strict no-repeat assertion
|
|
1029
|
+
if np.any(np.diff(sizes) >= 0):
|
|
1030
|
+
msg = (
|
|
1031
|
+
"Internal error: produced non-decreasing or repeated hidden sizes after strict enforcement. "
|
|
1032
|
+
f"sizes={sizes.tolist()}"
|
|
1033
|
+
)
|
|
1034
|
+
self.logger.error(msg)
|
|
1035
|
+
raise ValueError(msg)
|
|
819
1036
|
|
|
820
|
-
|
|
821
|
-
"""Class-balanced weights for 0/1/2 (handles haploid collapse if needed).
|
|
1037
|
+
return sizes.tolist()
|
|
822
1038
|
|
|
823
|
-
|
|
1039
|
+
def _class_weights_from_zygosity(
|
|
1040
|
+
self,
|
|
1041
|
+
X: np.ndarray,
|
|
1042
|
+
train_mask: Optional[np.ndarray] = None,
|
|
1043
|
+
*,
|
|
1044
|
+
inverse: bool = False,
|
|
1045
|
+
normalize: bool = False,
|
|
1046
|
+
power: float = 1.0,
|
|
1047
|
+
max_ratio: float | None = None,
|
|
1048
|
+
eps: float = 1e-12,
|
|
1049
|
+
) -> torch.Tensor:
|
|
1050
|
+
"""Compute class weights for zygosity labels.
|
|
824
1051
|
|
|
825
|
-
|
|
826
|
-
|
|
1052
|
+
If inverse=False (default):
|
|
1053
|
+
w_c = N / (K * n_c) ("balanced")
|
|
1054
|
+
|
|
1055
|
+
If inverse=True:
|
|
1056
|
+
w_c = N / n_c (same ratios, scaled by K)
|
|
1057
|
+
|
|
1058
|
+
If power != 1.0:
|
|
1059
|
+
w_c <- w_c ** power (amplifies or softens imbalance handling)
|
|
1060
|
+
|
|
1061
|
+
If normalize=True:
|
|
1062
|
+
rescales nonzero weights so mean(nonzero_weights) == 1.
|
|
827
1063
|
|
|
828
1064
|
Returns:
|
|
829
|
-
torch.Tensor:
|
|
1065
|
+
torch.Tensor: Class weights of shape (num_classes,) on self.device.
|
|
830
1066
|
"""
|
|
831
|
-
y = X
|
|
832
|
-
if y.size == 0:
|
|
833
|
-
return torch.ones(
|
|
834
|
-
self.num_classes_, dtype=torch.float32, device=self.device
|
|
835
|
-
)
|
|
1067
|
+
y = np.asarray(X).ravel().astype(np.int8)
|
|
836
1068
|
|
|
837
|
-
|
|
838
|
-
|
|
839
|
-
|
|
840
|
-
|
|
841
|
-
|
|
842
|
-
|
|
843
|
-
|
|
844
|
-
|
|
1069
|
+
m = y >= 0
|
|
1070
|
+
if train_mask is not None:
|
|
1071
|
+
tm = np.asarray(train_mask, dtype=bool).ravel()
|
|
1072
|
+
if tm.shape != y.shape:
|
|
1073
|
+
msg = "train_mask must have the same shape as X."
|
|
1074
|
+
self.logger.error(msg)
|
|
1075
|
+
raise ValueError(msg)
|
|
1076
|
+
m &= tm
|
|
1077
|
+
|
|
1078
|
+
is_hap = bool(getattr(self, "is_haploid_", False))
|
|
1079
|
+
num_classes = 2 if is_hap else int(self.num_classes_)
|
|
1080
|
+
|
|
1081
|
+
if not np.any(m):
|
|
1082
|
+
return torch.ones(num_classes, dtype=torch.long, device=self.device)
|
|
1083
|
+
|
|
1084
|
+
if is_hap:
|
|
1085
|
+
y = y.copy()
|
|
1086
|
+
y[(y == 2) & m] = 1
|
|
1087
|
+
|
|
1088
|
+
y_m = y[m]
|
|
1089
|
+
if y_m.size:
|
|
1090
|
+
ymin = int(y_m.min())
|
|
1091
|
+
ymax = int(y_m.max())
|
|
1092
|
+
if ymin < 0 or ymax >= num_classes:
|
|
1093
|
+
msg = (
|
|
1094
|
+
f"Found out-of-range labels under mask: min={ymin}, max={ymax}, "
|
|
1095
|
+
f"expected in [0, {num_classes - 1}]."
|
|
1096
|
+
)
|
|
1097
|
+
self.logger.error(msg)
|
|
1098
|
+
raise ValueError(msg)
|
|
845
1099
|
|
|
846
|
-
|
|
847
|
-
|
|
848
|
-
|
|
849
|
-
) -> torch.Tensor | None:
|
|
850
|
-
"""Normalize class weights once to keep loss scale stable.
|
|
1100
|
+
counts = np.bincount(y_m, minlength=num_classes).astype(np.float32)
|
|
1101
|
+
N = float(counts.sum())
|
|
1102
|
+
K = float(num_classes)
|
|
851
1103
|
|
|
852
|
-
|
|
853
|
-
|
|
1104
|
+
w = np.zeros(num_classes, dtype=np.float32)
|
|
1105
|
+
nz = counts > 0
|
|
854
1106
|
|
|
855
|
-
|
|
856
|
-
|
|
857
|
-
|
|
858
|
-
|
|
859
|
-
|
|
860
|
-
return weights / weights.mean().clamp_min(1e-8)
|
|
1107
|
+
if np.any(nz):
|
|
1108
|
+
if inverse:
|
|
1109
|
+
w[nz] = N / (counts[nz] + eps)
|
|
1110
|
+
else:
|
|
1111
|
+
w[nz] = N / (K * (counts[nz] + eps))
|
|
861
1112
|
|
|
862
|
-
|
|
863
|
-
|
|
1113
|
+
# Amplify / soften class contrast
|
|
1114
|
+
if power <= 0.0:
|
|
1115
|
+
msg = "power must be > 0."
|
|
1116
|
+
self.logger.error(msg)
|
|
1117
|
+
raise ValueError(msg)
|
|
1118
|
+
if power != 1.0:
|
|
1119
|
+
w[nz] = np.power(w[nz], power)
|
|
864
1120
|
|
|
865
|
-
|
|
866
|
-
|
|
1121
|
+
if np.any(~nz):
|
|
1122
|
+
self.logger.warning(
|
|
1123
|
+
"Some classes have zero count under the provided mask: "
|
|
1124
|
+
f"{np.where(~nz)[0].tolist()}. Setting their weights to 0."
|
|
1125
|
+
)
|
|
867
1126
|
|
|
868
|
-
|
|
869
|
-
|
|
870
|
-
|
|
871
|
-
|
|
872
|
-
|
|
873
|
-
|
|
874
|
-
|
|
875
|
-
|
|
876
|
-
|
|
877
|
-
|
|
878
|
-
|
|
879
|
-
|
|
880
|
-
|
|
881
|
-
|
|
882
|
-
|
|
883
|
-
|
|
884
|
-
|
|
885
|
-
|
|
886
|
-
|
|
887
|
-
|
|
888
|
-
|
|
889
|
-
|
|
890
|
-
|
|
1127
|
+
# Cap ratio among observed classes
|
|
1128
|
+
if max_ratio is not None and np.any(nz):
|
|
1129
|
+
cap = float(max_ratio)
|
|
1130
|
+
if cap <= 1.0:
|
|
1131
|
+
msg = "max_ratio must be > 1.0 or None."
|
|
1132
|
+
self.logger.error(msg)
|
|
1133
|
+
raise ValueError(msg)
|
|
1134
|
+
|
|
1135
|
+
wmin = max(float(w[nz].min()), eps)
|
|
1136
|
+
wmax = wmin * cap
|
|
1137
|
+
w[nz] = np.clip(w[nz], wmin, wmax)
|
|
1138
|
+
|
|
1139
|
+
# Optional normalization: mean(nonzero) -> 1.0
|
|
1140
|
+
if normalize and np.any(nz):
|
|
1141
|
+
mean_nz = float(w[nz].mean())
|
|
1142
|
+
if mean_nz > 0.0:
|
|
1143
|
+
w[nz] /= mean_nz
|
|
1144
|
+
else:
|
|
1145
|
+
self.logger.warning(
|
|
1146
|
+
"normalize=True requested, but mean of nonzero weights is not positive; skipping normalization."
|
|
1147
|
+
)
|
|
1148
|
+
|
|
1149
|
+
self.logger.debug(f"Class counts: {counts.astype(np.int8)}")
|
|
1150
|
+
self.logger.debug(
|
|
1151
|
+
f"Class weights (inverse={inverse}, power={power}, normalize={normalize}): {w}"
|
|
891
1152
|
)
|
|
892
1153
|
|
|
893
|
-
|
|
894
|
-
"""One-hot 0/1/2; -1 rows are all-zeros (B, L, K).
|
|
1154
|
+
return torch.as_tensor(w, dtype=torch.long, device=self.device)
|
|
895
1155
|
|
|
896
|
-
|
|
1156
|
+
def _one_hot_encode_012(
|
|
1157
|
+
self, X: np.ndarray | torch.Tensor, num_classes: int | None
|
|
1158
|
+
) -> torch.Tensor:
|
|
1159
|
+
"""One-hot encode genotype calls. Missing inputs (<0) result in a vector of -1s.
|
|
897
1160
|
|
|
898
1161
|
Args:
|
|
899
|
-
X (np.ndarray | torch.Tensor):
|
|
1162
|
+
X (np.ndarray | torch.Tensor): Input genotype calls as integers (0,1, 2, etc.).
|
|
1163
|
+
num_classes (int | None): Number of classes (K). If None, uses self.num_classes_.
|
|
900
1164
|
|
|
901
1165
|
Returns:
|
|
902
|
-
torch.Tensor:
|
|
1166
|
+
torch.Tensor: One-hot encoded tensor of shape (B, L, K) with dtype
|
|
1167
|
+
``torch.long``. Valid calls are 0/1, missing calls are all -1.
|
|
1168
|
+
|
|
1169
|
+
Notes:
|
|
1170
|
+
- Valid classes must be integers in [0, K-1]
|
|
1171
|
+
- Missing is any value < 0; these positions become [-1, -1, ..., -1]
|
|
1172
|
+
- If K==2 and values are in {0,2} (no 1s), map 2->1.
|
|
903
1173
|
"""
|
|
904
1174
|
Xt = (
|
|
905
1175
|
torch.from_numpy(X).to(self.device)
|
|
@@ -907,212 +1177,680 @@ class BaseNNImputer:
|
|
|
907
1177
|
else X.to(self.device)
|
|
908
1178
|
)
|
|
909
1179
|
|
|
910
|
-
#
|
|
1180
|
+
# Make sure we have integer class labels
|
|
1181
|
+
if Xt.dtype.is_floating_point:
|
|
1182
|
+
# Convert NaN -> -1 and cast to long
|
|
1183
|
+
Xt = torch.nan_to_num(Xt, nan=-1.0).long()
|
|
1184
|
+
else:
|
|
1185
|
+
Xt = Xt.long()
|
|
1186
|
+
|
|
911
1187
|
B, L = Xt.shape
|
|
912
|
-
K = self.num_classes_
|
|
913
|
-
|
|
914
|
-
|
|
915
|
-
|
|
1188
|
+
K = int(num_classes) if num_classes is not None else int(self.num_classes_)
|
|
1189
|
+
|
|
1190
|
+
# Missing is anything < 0 (covers -1, -9, etc.)
|
|
1191
|
+
valid = Xt >= 0
|
|
1192
|
+
|
|
1193
|
+
# If binary mode but data is {0,2}
|
|
1194
|
+
# (haploid-like or "ref vs non-ref"), map 2->1
|
|
1195
|
+
if K == 2:
|
|
1196
|
+
has_het = torch.any(valid & (Xt == 1))
|
|
1197
|
+
has_alt2 = torch.any(valid & (Xt == 2))
|
|
1198
|
+
if has_alt2 and not has_het:
|
|
1199
|
+
Xt = Xt.clone()
|
|
1200
|
+
Xt[valid & (Xt == 2)] = 1
|
|
1201
|
+
|
|
1202
|
+
# Now enforce the one-hot precondition
|
|
1203
|
+
if torch.any(valid & (Xt >= K)):
|
|
1204
|
+
bad_vals = torch.unique(Xt[valid & (Xt >= K)]).detach().cpu().tolist()
|
|
1205
|
+
all_vals = torch.unique(Xt[valid]).detach().cpu().tolist()
|
|
1206
|
+
msg = f"_one_hot_encode_012 received class values outside [0, {K-1}]. num_classes={K}, offending_values={bad_vals}, observed_values={all_vals}. Upstream encoding mismatch (e.g., passing 0/1/2 with num_classes=2)."
|
|
1207
|
+
self.logger.error(msg)
|
|
1208
|
+
raise ValueError(msg)
|
|
1209
|
+
|
|
1210
|
+
# CHANGE: Initialize with -1.0 to ensure missing values are represented as [-1, -1, ... -1]
|
|
1211
|
+
X_ohe = torch.full((B, L, K), -1.0, dtype=torch.long, device=self.device)
|
|
1212
|
+
|
|
1213
|
+
idx = Xt[valid]
|
|
916
1214
|
|
|
917
1215
|
if idx.numel() > 0:
|
|
918
|
-
|
|
1216
|
+
# Overwrite valid positions (which were -1) with the correct one-hot vectors
|
|
1217
|
+
X_ohe[valid] = F.one_hot(idx, num_classes=K).long()
|
|
919
1218
|
|
|
920
1219
|
return X_ohe
|
|
921
1220
|
|
|
922
|
-
def
|
|
923
|
-
self,
|
|
924
|
-
|
|
925
|
-
|
|
926
|
-
|
|
927
|
-
|
|
928
|
-
|
|
929
|
-
|
|
930
|
-
do_latent_infer: bool = False,
|
|
931
|
-
latent_steps: int = 50,
|
|
932
|
-
latent_lr: float = 1e-2,
|
|
933
|
-
latent_weight_decay: float = 0.0,
|
|
934
|
-
latent_seed: int = 123,
|
|
935
|
-
_latent_cache: dict | None = None,
|
|
936
|
-
_latent_cache_key: str | None = None,
|
|
937
|
-
eval_mask_override: np.ndarray | None = None,
|
|
938
|
-
) -> float:
|
|
939
|
-
"""Compute a scalar metric (to MAXIMIZE) on a fixed validation set.
|
|
940
|
-
|
|
941
|
-
This method evaluates the model on a validation dataset and computes a specified metric, which is used for pruning decisions during hyperparameter tuning. It supports optional latent inference to optimize latent representations before evaluation. The method handles potential issues with non-finite metric values by returning negative infinity, making it easier to prune poorly performing trials.
|
|
1221
|
+
def decode_012(
    self, X: np.ndarray | pd.DataFrame | list[list[int]], is_nuc: bool = False
) -> np.ndarray:
    """Decode 012-encodings to IUPAC chars with metadata repair.

    This method converts genotype calls encoded as integers into IUPAC
    nucleotide codes. It supports two modes of decoding:

    1. Nucleotide mode (`is_nuc=True`): integer codes 0-9 map directly to
       ``A, C, G, T, W, R, M, K, Y, S`` (in that order).
    2. Metadata mode (`is_nuc=False`): code 0 maps to REF, 2 to ALT and 1 to
       the IUPAC ambiguity code covering REF+ALT. When REF and/or ALT metadata
       is missing for a column, it is repaired from ``genotype_data.snp_data``.
       Up to ``max_repair_scan`` non-missing calls are inspected, and
       ambiguity codes are split into their constituent bases. The most
       frequent base becomes REF, and the most frequent base different from
       REF becomes ALT. Ties are broken alphabetically.

    Args:
        X (np.ndarray | pd.DataFrame | list[list[int]]): Input genotype calls as integers. Can be a NumPy array, Pandas DataFrame, or nested list.
        is_nuc (bool): If True, decode 0-9 nucleotide codes; else use ref/alt metadata. Defaults to False.

    Returns:
        np.ndarray: IUPAC strings as a 2D array of shape (n_samples, n_snps).

    Notes:
        - Non-numeric, non-finite, or out-of-range codes decode to 'N'.
        - If the HET code cannot be formed (e.g. monomorphic site), HET and
          ALT calls fall back to REF instead of producing 'N'.
        - Columns whose REF/ALT cannot be resolved decode to 'N'.

    Raises:
        ValueError: If the input cannot be converted to a pandas DataFrame.
    """
    from collections import Counter

    df = validate_input_type(X, return_type="df")

    if not isinstance(df, pd.DataFrame):
        msg = f"Expected a pandas.DataFrame in 'decode_012', but got: {type(df)}."
        self.logger.error(msg)
        raise ValueError(msg)

    # IUPAC Definitions
    iupac_to_bases: dict[str, set[str]] = {
        "A": {"A"},
        "C": {"C"},
        "G": {"G"},
        "T": {"T"},
        "R": {"A", "G"},
        "Y": {"C", "T"},
        "S": {"G", "C"},
        "W": {"A", "T"},
        "K": {"G", "T"},
        "M": {"A", "C"},
        "B": {"C", "G", "T"},
        "D": {"A", "G", "T"},
        "H": {"A", "C", "T"},
        "V": {"A", "C", "G"},
        "N": set(),
    }
    bases_to_iupac = {
        frozenset(v): k for k, v in iupac_to_bases.items() if k != "N"
    }
    missing_codes = {"", ".", "N", "NONE", "-", "?", "./.", ".|.", "NAN", "nan"}
    # Upper bound on non-missing calls inspected per column during repair
    # (keeps the per-column Python loop cheap on large sample sets).
    max_repair_scan = 200

    def _normalize_iupac(value: object) -> str | None:
        """Normalize an input into a single IUPAC code token or None."""
        if value is None:
            return None

        # Bytes -> str (make type narrowing explicit)
        if isinstance(value, (bytes, np.bytes_)):
            value = bytes(value).decode("utf-8", errors="ignore")

        # Handle list/tuple/array/Series: take first valid
        if isinstance(value, (list, tuple, pd.Series, np.ndarray)):
            arr = value.to_numpy() if isinstance(value, pd.Series) else value

            # Scalar numpy array fast path
            if isinstance(arr, np.ndarray) and arr.ndim == 0:
                return _normalize_iupac(arr.item())

            if len(arr) == 0:
                return None

            # First valid element wins
            for item in arr:
                code = _normalize_iupac(item)
                if code is not None:
                    return code
            return None

        s = str(value).upper().strip()
        if not s or s in missing_codes:
            return None

        if "," in s:
            for tok in (t.strip() for t in s.split(",")):
                if tok and tok not in missing_codes and tok in iupac_to_bases:
                    return tok
            return None

        return s if s in iupac_to_bases else None

    def _infer_from_source(
        col_data: object, ref: str | None, alt: str | None
    ) -> tuple[str | None, str | None]:
        """Fill a missing REF and/or ALT from base counts observed in a column."""
        base_counts: Counter = Counter()
        n_seen = 0
        for val in col_data:
            code = _normalize_iupac(val)
            if code is None:
                continue
            # Split ambiguity codes (e.g. 'R' -> A, G) so heterozygous calls
            # contribute their alleles instead of being mistaken for one.
            base_counts.update(iupac_to_bases[code])
            n_seen += 1
            if n_seen >= max_repair_scan:
                break

        # Most frequent first; alphabetical order breaks ties deterministically.
        ranked = sorted(base_counts, key=lambda b: (-base_counts[b], b))
        if ref is None and ranked:
            ref = ranked[0]
        if alt is None:
            # ALT must differ from REF; otherwise leave it for the
            # monomorphic default below.
            alt = next((b for b in ranked if b != ref), None)
        return ref, alt

    # Coerce to numeric. NaN/inf/non-numeric become -1, and values are clipped
    # into the int8 range before casting, so out-of-range codes cannot wrap
    # around into valid 0-9 codes.
    codes_f = df.apply(pd.to_numeric, errors="coerce").to_numpy(
        dtype=np.float64, na_value=np.nan
    )
    codes_f = np.where(np.isfinite(codes_f), codes_f, -1.0)
    codes = np.clip(codes_f, -1, 127).astype(np.int8)
    n_rows, n_cols = codes.shape

    if is_nuc:
        iupac_list = np.array(
            ["A", "C", "G", "T", "W", "R", "M", "K", "Y", "S"], dtype="<U1"
        )
        out = np.full((n_rows, n_cols), "N", dtype="<U1")
        mask = (codes >= 0) & (codes <= 9)
        out[mask] = iupac_list[codes[mask]]
        return out

    def _metadata(attr: str, fallback_attr: str) -> object:
        """Return per-column allele metadata of length n_cols, else all-None."""
        for candidate in (
            getattr(self.genotype_data, attr, None),
            getattr(self, fallback_attr, None),
        ):
            if candidate is not None and len(candidate) == n_cols:
                return candidate
        return [None] * n_cols

    ref_alleles = _metadata("ref", "_ref")
    alt_alleles = _metadata("alt", "_alt")

    out = np.full((n_rows, n_cols), "N", dtype="<U1")
    source_snp_data = None
    source_checked = False  # load snp_data at most once, success or not

    for j in range(n_cols):
        ref = _normalize_iupac(ref_alleles[j])
        alt = _normalize_iupac(alt_alleles[j])

        # --- REPAIR LOGIC ---
        # If metadata is missing, scan the source column.
        if ref is None or alt is None:
            if not source_checked:
                source_checked = True
                try:
                    raw = getattr(self.genotype_data, "snp_data", None)
                    if raw is not None:
                        source_snp_data = np.asarray(raw)
                except Exception as exc:  # best effort: lazy loading may fail
                    self.logger.debug(
                        f"decode_012: could not load snp_data for REF/ALT repair: {exc}"
                    )

            if source_snp_data is not None:
                try:
                    col_data = source_snp_data[:, j]
                except (IndexError, TypeError):
                    col_data = None
                if col_data is not None:
                    ref, alt = _infer_from_source(col_data, ref, alt)

        # --- DEFAULTS FOR MISSING ---
        # If still missing, we cannot decode.
        if ref is None and alt is None:
            ref = alt = "N"
        elif ref is None:
            ref = alt
        elif alt is None:
            alt = ref  # Monomorphic site: ALT becomes REF

        # --- COMPUTE HET CODE ---
        if ref == alt:
            het_code = ref
        else:
            union_set = frozenset(
                iupac_to_bases.get(ref, set()) | iupac_to_bases.get(alt, set())
            )
            het_code = bases_to_iupac.get(union_set, "N")

        # --- ASSIGNMENT WITH SAFETY FALLBACKS ---
        # An unresolvable HET/ALT falls back to REF so predictions at
        # monomorphic sites do not decode to 'N'.
        col_codes = codes[:, j]
        het_fill = het_code if het_code != "N" else ref
        alt_fill = alt if alt != "N" else ref

        if ref != "N":
            out[col_codes == 0, j] = ref
        if het_fill != "N":
            out[col_codes == 1, j] = het_fill
        if alt_fill != "N":
            out[col_codes == 2, j] = alt_fill

    return out
|
|
1434
|
+
|
|
1435
|
+
def _save_best_params(
|
|
1436
|
+
self, best_params: Dict[str, Any], objective_mode: bool = False
|
|
1437
|
+
) -> None:
|
|
1438
|
+
"""Save the best hyperparameters to a JSON file.
|
|
1439
|
+
|
|
1440
|
+
This method saves the best hyperparameters found during hyperparameter tuning to a JSON file in the optimization directory. The filename includes the model name for easy identification.
|
|
1441
|
+
|
|
1442
|
+
Args:
|
|
1443
|
+
best_params (Dict[str, Any]): A dictionary of the best hyperparameters to save.
|
|
1444
|
+
"""
|
|
1445
|
+
if not hasattr(self, "parameters_dir"):
|
|
1446
|
+
msg = "Attribute 'parameters_dir' not found. Ensure _create_model_directories() has been called."
|
|
1447
|
+
self.logger.error(msg)
|
|
1448
|
+
raise AttributeError(msg)
|
|
1449
|
+
|
|
1450
|
+
if objective_mode:
|
|
1451
|
+
fout = self.optimize_dir / "parameters" / "best_tuned_parameters.json"
|
|
1452
|
+
else:
|
|
1453
|
+
fout = self.parameters_dir / "best_parameters.json"
|
|
1454
|
+
|
|
1455
|
+
fout.parent.mkdir(parents=True, exist_ok=True)
|
|
1456
|
+
|
|
1457
|
+
with open(fout, "w") as f:
|
|
1458
|
+
json.dump(best_params, f, indent=4)
|
|
1001
1459
|
|
|
1002
|
-
def
|
|
1003
|
-
"""
|
|
1460
|
+
def _set_best_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
|
|
1461
|
+
"""An abstract method for setting best parameters."""
|
|
1462
|
+
raise NotImplementedError
|
|
1004
1463
|
|
|
1005
|
-
|
|
1464
|
+
def sim_missing_transform(
    self, X: np.ndarray
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Simulate missing data according to the specified strategy.

    The input is copied to float32 and passed to ``SimMissingTransformer``,
    configured from ``self.sim_prop``, ``self.sim_strategy``,
    ``self.tree_parser`` (None if absent) and ``self.sim_kwargs`` (treated as
    empty if absent or None). Missing entries are encoded as -1.

    Args:
        X (np.ndarray): Genotype matrix to simulate missing data on.

    Returns:
        X_for_model (np.ndarray): Genotype matrix with simulated missing data.
        sim_mask (np.ndarray): Boolean mask of simulated missing entries.
        orig_mask (np.ndarray): Boolean mask of original missing entries.

    Raises:
        AttributeError: If ``sim_prop`` is unset or not strictly inside (0, 1)
            (NaN included), or if a 'nonrandom' strategy is requested without
            a ``tree_parser``.
    """
    sim_prop = getattr(self, "sim_prop", None)
    # The chained comparison also rejects NaN, which `<=`/`>=` checks let through.
    if sim_prop is None or not (0.0 < sim_prop < 1.0):
        msg = "sim_prop must be set and between 0.0 and 1.0."
        self.logger.error(msg)
        raise AttributeError(msg)

    # Treat a missing attribute and an explicit None the same way.
    tree_parser = getattr(self, "tree_parser", None)
    if tree_parser is None and "nonrandom" in str(self.sim_strategy):
        msg = "tree_parser must be set for 'nonrandom' or 'nonrandom_weighted' sim_strategy."
        self.logger.error(msg)
        raise AttributeError(msg)

    sim_kwargs = getattr(self, "sim_kwargs", None) or {}

    # --- Simulate missing data ---
    X_for_sim = X.astype(np.float32, copy=True)
    tr = SimMissingTransformer(
        genotype_data=self.genotype_data,
        tree_parser=tree_parser,
        prop_missing=sim_prop,
        strategy=self.sim_strategy,
        missing_val=-1,
        mask_missing=True,
        verbose=self.verbose,
        seed=self.seed,
        **sim_kwargs,
    )
    # fit() gets its own copy; X_for_sim is not reused after transform(), so
    # transform() may safely operate on it directly.
    tr.fit(X_for_sim.copy())
    X_for_model = tr.transform(X_for_sim)
    sim_mask = tr.sim_missing_mask_.astype(bool)
    orig_mask = tr.original_missing_mask_.astype(bool)

    return X_for_model, sim_mask, orig_mask
|
|
1510
|
+
|
|
1511
|
+
def _train_val_test_split(
|
|
1512
|
+
self, X: np.ndarray
|
|
1513
|
+
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
|
1514
|
+
"""Split data into train, validation, and test sets.
|
|
1024
1515
|
|
|
1025
1516
|
Args:
|
|
1026
|
-
|
|
1027
|
-
|
|
1517
|
+
X (np.ndarray): Genotype matrix to split.
|
|
1518
|
+
|
|
1519
|
+
Returns:
|
|
1520
|
+
tuple[np.ndarray, np.ndarray, np.ndarray]: Indices for train, validation, and test sets.
|
|
1028
1521
|
|
|
1029
1522
|
Raises:
|
|
1030
|
-
ValueError: If
|
|
1523
|
+
ValueError: If there are not enough samples for splitting.
|
|
1524
|
+
AssertionError: If validation_split is not in (0.0, 1.0).
|
|
1031
1525
|
"""
|
|
1032
|
-
|
|
1033
|
-
first_in = self._first_linear_in_features(model)
|
|
1034
|
-
if first_in != zdim:
|
|
1035
|
-
raise ValueError(
|
|
1036
|
-
f"Latent mismatch: zdim={zdim}, model first Linear expects in_features={first_in}"
|
|
1037
|
-
)
|
|
1526
|
+
n_samples = X.shape[0]
|
|
1038
1527
|
|
|
1039
|
-
|
|
1040
|
-
|
|
1528
|
+
if n_samples < 3:
|
|
1529
|
+
msg = f"Not enough samples ({n_samples}) for train/val/test split."
|
|
1530
|
+
self.logger.error(msg)
|
|
1531
|
+
raise ValueError(msg)
|
|
1041
1532
|
|
|
1042
|
-
|
|
1533
|
+
assert (
|
|
1534
|
+
self.validation_split > 0.0 and self.validation_split < 1.0
|
|
1535
|
+
), f"validation_split must be in (0.0, 1.0), but got {self.validation_split}."
|
|
1043
1536
|
|
|
1044
|
-
|
|
1045
|
-
|
|
1537
|
+
# Train/Val split
|
|
1538
|
+
indices = np.arange(n_samples)
|
|
1539
|
+
train_idx, val_test_idx = train_test_split(
|
|
1540
|
+
indices,
|
|
1541
|
+
test_size=self.validation_split,
|
|
1542
|
+
random_state=self.seed,
|
|
1543
|
+
)
|
|
1544
|
+
|
|
1545
|
+
if not val_test_idx.size >= 4:
|
|
1546
|
+
msg = f"Not enough samples ({val_test_idx.size}) for validation/test split."
|
|
1547
|
+
self.logger.error(msg)
|
|
1548
|
+
raise ValueError(msg)
|
|
1549
|
+
|
|
1550
|
+
# Split val and test equally
|
|
1551
|
+
val_idx, test_idx = train_test_split(
|
|
1552
|
+
val_test_idx, test_size=0.5, random_state=self.seed
|
|
1553
|
+
)
|
|
1554
|
+
|
|
1555
|
+
return train_idx, val_idx, test_idx
|
|
1556
|
+
|
|
1557
|
+
def _get_data_loaders(
    self,
    X: np.ndarray,
    y: np.ndarray,
    mask: np.ndarray,
    batch_size: int,
    *,
    shuffle: bool = True,
) -> torch.utils.data.DataLoader:
    """Wrap input/target/mask arrays in a masked dataset and return a DataLoader.

    Pinned host memory is requested only when ``self.device`` names a CUDA
    device.

    Args:
        X (np.ndarray): 0/1/2-encoded input matrix.
        y (np.ndarray): 0/1/2-encoded matrix with -1 for missing.
        mask (np.ndarray): Boolean mask of entries to score in the loss.
        batch_size (int): Batch size.
        shuffle (bool): Whether to shuffle batches. Defaults to True.

    Returns:
        The DataLoader.
    """
    masked_ds = _MaskedNumpyDataset(X, y, mask)
    on_cuda = str(self.device).startswith("cuda")
    loader_opts = {
        "batch_size": batch_size,
        "shuffle": shuffle,
        "pin_memory": on_cuda,
    }
    return torch.utils.data.DataLoader(masked_ds, **loader_opts)
|
|
1053
1586
|
|
|
1054
|
-
|
|
1055
|
-
|
|
1056
|
-
|
|
1587
|
+
def _update_anneal_schedule(
|
|
1588
|
+
self,
|
|
1589
|
+
final: float,
|
|
1590
|
+
warm: int,
|
|
1591
|
+
ramp: int,
|
|
1592
|
+
epoch: int,
|
|
1593
|
+
*,
|
|
1594
|
+
init_val: float = 0.0,
|
|
1595
|
+
) -> torch.Tensor:
|
|
1596
|
+
"""Update annealed hyperparameter value based on epoch.
|
|
1057
1597
|
|
|
1058
|
-
|
|
1059
|
-
|
|
1060
|
-
|
|
1598
|
+
Args:
|
|
1599
|
+
final (float): Final value after annealing.
|
|
1600
|
+
warm (int): Number of warm-up epochs.
|
|
1601
|
+
ramp (int): Number of ramp-up epochs.
|
|
1602
|
+
epoch (int): Current epoch number.
|
|
1603
|
+
init_val (float): Initial value before annealing starts.
|
|
1604
|
+
|
|
1605
|
+
Returns:
|
|
1606
|
+
torch.Tensor: Current value of the hyperparameter.
|
|
1607
|
+
"""
|
|
1608
|
+
if epoch < warm:
|
|
1609
|
+
val = torch.tensor(init_val)
|
|
1610
|
+
elif epoch < warm + ramp:
|
|
1611
|
+
val = torch.tensor(final * ((epoch - warm) / ramp))
|
|
1061
1612
|
else:
|
|
1062
|
-
|
|
1613
|
+
val = torch.tensor(final)
|
|
1063
1614
|
|
|
1064
|
-
|
|
1065
|
-
tr, te = train_test_split(
|
|
1066
|
-
idx, test_size=self.validation_split, random_state=self.seed
|
|
1067
|
-
)
|
|
1068
|
-
self._tune_train_idx = tr
|
|
1069
|
-
self._tune_test_idx = te
|
|
1070
|
-
self._tune_X_train = X_small[tr]
|
|
1071
|
-
self._tune_X_test = X_small[te]
|
|
1615
|
+
return val.to(self.device)
|
|
1072
1616
|
|
|
1073
|
-
|
|
1074
|
-
|
|
1075
|
-
|
|
1617
|
+
def _anneal_config(
|
|
1618
|
+
self,
|
|
1619
|
+
params: Optional[dict],
|
|
1620
|
+
key: str,
|
|
1621
|
+
default: float,
|
|
1622
|
+
max_epochs: int,
|
|
1623
|
+
*,
|
|
1624
|
+
warm_alt: int = 50,
|
|
1625
|
+
ramp_alt: int = 100,
|
|
1626
|
+
) -> Tuple[float, int, int]:
|
|
1627
|
+
"""Configure annealing schedule for a hyperparameter.
|
|
1076
1628
|
|
|
1077
|
-
|
|
1078
|
-
|
|
1079
|
-
|
|
1080
|
-
|
|
1081
|
-
|
|
1629
|
+
Args:
|
|
1630
|
+
params (Optional[dict]): Dictionary of parameters to extract from.
|
|
1631
|
+
key (str): Key to look for in params.
|
|
1632
|
+
default (float): Default final value if not specified in params.
|
|
1633
|
+
max_epochs (int): Total number of training epochs.
|
|
1634
|
+
warm_alt (int): Alternative warm-up period if 10% of epochs is too long
|
|
1635
|
+
ramp_alt (int): Alternative ramp-up period if 20% of epochs is too long
|
|
1082
1636
|
|
|
1083
|
-
|
|
1084
|
-
|
|
1085
|
-
|
|
1637
|
+
Returns:
|
|
1638
|
+
Tuple[float, int, int]: Final value, warm-up epochs, ramp-up epochs.
|
|
1639
|
+
"""
|
|
1640
|
+
val = None
|
|
1641
|
+
if params is not None and params:
|
|
1642
|
+
if not hasattr(self, key):
|
|
1643
|
+
msg = f"Attribute '{key}' not found for anneal_config."
|
|
1644
|
+
self.logger.error(msg)
|
|
1645
|
+
raise AttributeError(msg)
|
|
1086
1646
|
|
|
1087
|
-
|
|
1088
|
-
|
|
1089
|
-
|
|
1090
|
-
|
|
1091
|
-
):
|
|
1092
|
-
self._tune_eval_slice = np.arange(self.tune_proxy_metric_batch)
|
|
1647
|
+
val = params.get(key, getattr(self, key))
|
|
1648
|
+
|
|
1649
|
+
if val is not None and isinstance(val, (float, int)):
|
|
1650
|
+
final = float(val)
|
|
1093
1651
|
else:
|
|
1094
|
-
|
|
1652
|
+
final = default
|
|
1095
1653
|
|
|
1096
|
-
|
|
1654
|
+
warm, ramp = min(int(0.1 * max_epochs), warm_alt), min(
|
|
1655
|
+
int(0.2 * max_epochs), ramp_alt
|
|
1656
|
+
)
|
|
1657
|
+
return final, warm, ramp
|
|
1097
1658
|
|
|
1098
|
-
def
|
|
1099
|
-
"""
|
|
1659
|
+
def _repair_ref_alt_from_iupac(self, loci: np.ndarray) -> None:
|
|
1660
|
+
"""Repair REF/ALT for specific loci using observed IUPAC genotypes.
|
|
1100
1661
|
|
|
1101
|
-
|
|
1662
|
+
Args:
|
|
1663
|
+
loci (np.ndarray): Array of locus indices to repair.
|
|
1664
|
+
|
|
1665
|
+
Notes:
|
|
1666
|
+
- Modifies self.genotype_data.ref and self.genotype_data.alt in place.
|
|
1667
|
+
"""
|
|
1668
|
+
iupac_to_bases = {
|
|
1669
|
+
"A": {"A"},
|
|
1670
|
+
"C": {"C"},
|
|
1671
|
+
"G": {"G"},
|
|
1672
|
+
"T": {"T"},
|
|
1673
|
+
"R": {"A", "G"},
|
|
1674
|
+
"Y": {"C", "T"},
|
|
1675
|
+
"S": {"G", "C"},
|
|
1676
|
+
"W": {"A", "T"},
|
|
1677
|
+
"K": {"G", "T"},
|
|
1678
|
+
"M": {"A", "C"},
|
|
1679
|
+
"B": {"C", "G", "T"},
|
|
1680
|
+
"D": {"A", "G", "T"},
|
|
1681
|
+
"H": {"A", "C", "T"},
|
|
1682
|
+
"V": {"A", "C", "G"},
|
|
1683
|
+
}
|
|
1684
|
+
missing_codes = {"", ".", "N", "NONE", "-", "?", "./.", ".|."}
|
|
1685
|
+
|
|
1686
|
+
def norm(v: object) -> str | None:
|
|
1687
|
+
if v is None:
|
|
1688
|
+
return None
|
|
1689
|
+
s = str(v).upper().strip()
|
|
1690
|
+
if not s or s in missing_codes:
|
|
1691
|
+
return None
|
|
1692
|
+
return s if s in iupac_to_bases else None
|
|
1693
|
+
|
|
1694
|
+
snp = np.asarray(self.genotype_data.snp_data, dtype=object) # (N,L) IUPAC-ish
|
|
1695
|
+
refs = list(getattr(self.genotype_data, "ref", [None] * snp.shape[1]))
|
|
1696
|
+
alts = list(getattr(self.genotype_data, "alt", [None] * snp.shape[1]))
|
|
1697
|
+
|
|
1698
|
+
for j in loci:
|
|
1699
|
+
cnt = Counter()
|
|
1700
|
+
col = snp[:, int(j)]
|
|
1701
|
+
for g in col:
|
|
1702
|
+
code = norm(g)
|
|
1703
|
+
if code is None:
|
|
1704
|
+
continue
|
|
1705
|
+
for b in iupac_to_bases[code]:
|
|
1706
|
+
cnt[b] += 1
|
|
1707
|
+
|
|
1708
|
+
if not cnt:
|
|
1709
|
+
continue
|
|
1710
|
+
|
|
1711
|
+
common = [b for b, _ in cnt.most_common()]
|
|
1712
|
+
ref = common[0]
|
|
1713
|
+
alt = common[1] if len(common) > 1 else None
|
|
1714
|
+
|
|
1715
|
+
refs[int(j)] = ref
|
|
1716
|
+
alts[int(j)] = alt if alt is not None else "."
|
|
1717
|
+
|
|
1718
|
+
self.genotype_data.ref = np.asarray(refs, dtype=object)
|
|
1719
|
+
|
|
1720
|
+
if not isinstance(alts, np.ndarray):
|
|
1721
|
+
alts = np.array(alts, dtype=object).tolist()
|
|
1722
|
+
|
|
1723
|
+
self.genotype_data.alt = alts
|
|
1724
|
+
|
|
1725
|
+
def _aligned_ref_alt(self, L: int) -> tuple[list[object], list[object]]:
|
|
1726
|
+
"""Return REF/ALT aligned to the genotype matrix columns.
|
|
1102
1727
|
|
|
1103
1728
|
Args:
|
|
1104
|
-
|
|
1729
|
+
L (int): Number of loci (columns in genotype matrix).
|
|
1730
|
+
|
|
1731
|
+
Returns:
|
|
1732
|
+
tuple[list[object], list[object]]: Aligned REF and ALT lists.
|
|
1105
1733
|
"""
|
|
1106
|
-
|
|
1107
|
-
|
|
1734
|
+
refs = getattr(self.genotype_data, "ref", None)
|
|
1735
|
+
alts = getattr(self.genotype_data, "alt", None)
|
|
1736
|
+
|
|
1737
|
+
if refs is None or alts is None:
|
|
1738
|
+
msg = "genotype_data.ref/alt are required but missing."
|
|
1108
1739
|
self.logger.error(msg)
|
|
1109
|
-
raise
|
|
1740
|
+
raise ValueError(msg)
|
|
1110
1741
|
|
|
1111
|
-
|
|
1742
|
+
refs_arr = np.asarray(refs, dtype=object)
|
|
1743
|
+
alts_arr = np.asarray(alts, dtype=object)
|
|
1112
1744
|
|
|
1113
|
-
|
|
1114
|
-
|
|
1745
|
+
if refs_arr.shape[0] != L or alts_arr.shape[0] != L:
|
|
1746
|
+
msg = f"REF/ALT length mismatch vs matrix columns: L={L}, len(ref)={refs_arr.shape[0]}, len(alt)={alts_arr.shape[0]}. You are using REF/ALT metadata that is not aligned to pgenc.genotypes_012 columns. Fix by subsetting/refiltering ref/alt with the same locus mask used for the genotype matrix."
|
|
1747
|
+
self.logger.error(msg)
|
|
1748
|
+
raise ValueError(msg)
|
|
1115
1749
|
|
|
1116
|
-
|
|
1117
|
-
|
|
1118
|
-
|
|
1750
|
+
# Unwrap singleton ALT arrays like array(['T'], dtype=object)
|
|
1751
|
+
def unwrap(x: object) -> object:
|
|
1752
|
+
if isinstance(x, np.ndarray):
|
|
1753
|
+
if x.size == 0:
|
|
1754
|
+
return None
|
|
1755
|
+
if x.size == 1:
|
|
1756
|
+
return x.item()
|
|
1757
|
+
return x
|
|
1758
|
+
|
|
1759
|
+
refs_list = [unwrap(x) for x in refs_arr.tolist()]
|
|
1760
|
+
alts_list = [unwrap(x) for x in alts_arr.tolist()]
|
|
1761
|
+
return refs_list, alts_list
|
|
1762
|
+
|
|
1763
|
+
def _build_valid_class_mask(self) -> torch.Tensor:
|
|
1764
|
+
L = self.num_features_
|
|
1765
|
+
K = self.num_classes_
|
|
1766
|
+
mask = np.ones((L, K), dtype=bool)
|
|
1767
|
+
|
|
1768
|
+
# --- IUPAC helpers (single-character only) ---
|
|
1769
|
+
iupac_to_bases: dict[str, set[str]] = {
|
|
1770
|
+
"A": {"A"},
|
|
1771
|
+
"C": {"C"},
|
|
1772
|
+
"G": {"G"},
|
|
1773
|
+
"T": {"T"},
|
|
1774
|
+
"R": {"A", "G"},
|
|
1775
|
+
"Y": {"C", "T"},
|
|
1776
|
+
"S": {"G", "C"},
|
|
1777
|
+
"W": {"A", "T"},
|
|
1778
|
+
"K": {"G", "T"},
|
|
1779
|
+
"M": {"A", "C"},
|
|
1780
|
+
"B": {"C", "G", "T"},
|
|
1781
|
+
"D": {"A", "G", "T"},
|
|
1782
|
+
"H": {"A", "C", "T"},
|
|
1783
|
+
"V": {"A", "C", "G"},
|
|
1784
|
+
}
|
|
1785
|
+
missing_codes = {"", ".", "N", "NONE", "-", "?", "./.", ".|."}
|
|
1786
|
+
|
|
1787
|
+
# get aligned ref/alt (should be exactly length L)
|
|
1788
|
+
refs, alts = self._aligned_ref_alt(L)
|
|
1789
|
+
|
|
1790
|
+
def _normalize_iupac(value: object) -> str | None:
|
|
1791
|
+
"""Return a single-letter IUPAC code or None if missing/invalid."""
|
|
1792
|
+
if value is None:
|
|
1793
|
+
return None
|
|
1794
|
+
if isinstance(value, (bytes, np.bytes_)):
|
|
1795
|
+
value = value.decode("utf-8", errors="ignore")
|
|
1796
|
+
|
|
1797
|
+
# allow list/tuple/array containers (take first valid)
|
|
1798
|
+
if isinstance(value, (list, tuple, np.ndarray, pd.Series)):
|
|
1799
|
+
for item in value:
|
|
1800
|
+
code = _normalize_iupac(item)
|
|
1801
|
+
if code is not None:
|
|
1802
|
+
return code
|
|
1803
|
+
return None
|
|
1804
|
+
|
|
1805
|
+
s = str(value).upper().strip()
|
|
1806
|
+
if not s or s in missing_codes:
|
|
1807
|
+
return None
|
|
1808
|
+
|
|
1809
|
+
# handle comma-separated values
|
|
1810
|
+
if "," in s:
|
|
1811
|
+
for tok in (t.strip() for t in s.split(",")):
|
|
1812
|
+
if tok and tok not in missing_codes and tok in iupac_to_bases:
|
|
1813
|
+
return tok
|
|
1814
|
+
return None
|
|
1815
|
+
|
|
1816
|
+
return s if s in iupac_to_bases else None
|
|
1817
|
+
|
|
1818
|
+
# 1) metadata restriction
|
|
1819
|
+
for j in range(L):
|
|
1820
|
+
ref = _normalize_iupac(refs[j])
|
|
1821
|
+
alt = _normalize_iupac(alts[j])
|
|
1822
|
+
|
|
1823
|
+
if alt is None or (ref is not None and alt == ref):
|
|
1824
|
+
mask[j, :] = False
|
|
1825
|
+
mask[j, 0] = True
|
|
1826
|
+
|
|
1827
|
+
# 2) data-driven override
|
|
1828
|
+
y_train = getattr(self, "y_train_", None)
|
|
1829
|
+
if y_train is not None:
|
|
1830
|
+
y = np.asarray(y_train)
|
|
1831
|
+
if y.ndim == 2 and y.shape[1] == L:
|
|
1832
|
+
if K == 2:
|
|
1833
|
+
y = y.copy()
|
|
1834
|
+
y[y == 2] = 1
|
|
1835
|
+
valid = y >= 0
|
|
1836
|
+
if valid.any():
|
|
1837
|
+
observed = np.zeros((L, K), dtype=bool)
|
|
1838
|
+
for c in range(K):
|
|
1839
|
+
observed[:, c] = np.any(valid & (y == c), axis=0)
|
|
1840
|
+
|
|
1841
|
+
conflict = observed & (~mask)
|
|
1842
|
+
if conflict.any():
|
|
1843
|
+
loci = np.where(conflict.any(axis=1))[0]
|
|
1844
|
+
self.valid_class_mask_conflict_loci_ = loci
|
|
1845
|
+
self.logger.warning(
|
|
1846
|
+
f"valid_class_mask_ metadata forbids observed classes at {loci.size} loci. "
|
|
1847
|
+
"Expanding mask to include observed classes."
|
|
1848
|
+
)
|
|
1849
|
+
mask |= observed
|
|
1850
|
+
|
|
1851
|
+
bad = np.where(~mask.any(axis=1))[0]
|
|
1852
|
+
if bad.size:
|
|
1853
|
+
mask[bad, :] = False
|
|
1854
|
+
mask[bad, 0] = True
|
|
1855
|
+
|
|
1856
|
+
return torch.as_tensor(mask, dtype=torch.bool, device=self.device)
|