pg-sui 1.6.16a3__py3-none-any.whl → 1.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/METADATA +26 -30
- {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/RECORD +29 -33
- pgsui/__init__.py +0 -8
- pgsui/_version.py +2 -2
- pgsui/cli.py +577 -125
- pgsui/data_processing/config.py +1 -2
- pgsui/data_processing/containers.py +203 -530
- pgsui/data_processing/transformers.py +44 -20
- pgsui/impute/deterministic/imputers/mode.py +475 -182
- pgsui/impute/deterministic/imputers/ref_allele.py +454 -147
- pgsui/impute/supervised/imputers/hist_gradient_boosting.py +4 -3
- pgsui/impute/supervised/imputers/random_forest.py +3 -2
- pgsui/impute/unsupervised/base.py +1269 -534
- pgsui/impute/unsupervised/callbacks.py +28 -33
- pgsui/impute/unsupervised/imputers/autoencoder.py +870 -841
- pgsui/impute/unsupervised/imputers/vae.py +931 -787
- pgsui/impute/unsupervised/loss_functions.py +156 -202
- pgsui/impute/unsupervised/models/autoencoder_model.py +7 -49
- pgsui/impute/unsupervised/models/vae_model.py +40 -221
- pgsui/impute/unsupervised/nn_scorers.py +53 -13
- pgsui/utils/classification_viz.py +240 -97
- pgsui/utils/misc.py +201 -3
- pgsui/utils/plotting.py +73 -58
- pgsui/utils/pretty_metrics.py +2 -6
- pgsui/utils/scorers.py +39 -0
- pgsui/impute/unsupervised/imputers/nlpca.py +0 -1666
- pgsui/impute/unsupervised/imputers/ubp.py +0 -1660
- pgsui/impute/unsupervised/models/nlpca_model.py +0 -206
- pgsui/impute/unsupervised/models/ubp_model.py +0 -200
- {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/WHEEL +0 -0
- {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/entry_points.txt +0 -0
- {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/licenses/LICENSE +0 -0
- {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/top_level.txt +0 -0
|
@@ -2,8 +2,10 @@ import copy
|
|
|
2
2
|
import gc
|
|
3
3
|
import json
|
|
4
4
|
import logging
|
|
5
|
+
from collections import Counter
|
|
6
|
+
from datetime import datetime
|
|
5
7
|
from pathlib import Path
|
|
6
|
-
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Tuple
|
|
8
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Type
|
|
7
9
|
|
|
8
10
|
import matplotlib.pyplot as plt
|
|
9
11
|
import numpy as np
|
|
@@ -13,11 +15,18 @@ import plotly.graph_objects as go
|
|
|
13
15
|
import torch
|
|
14
16
|
import torch.nn.functional as F
|
|
15
17
|
from matplotlib.figure import Figure
|
|
16
|
-
from sklearn.metrics import
|
|
18
|
+
from sklearn.metrics import (
|
|
19
|
+
average_precision_score,
|
|
20
|
+
classification_report,
|
|
21
|
+
jaccard_score,
|
|
22
|
+
matthews_corrcoef,
|
|
23
|
+
)
|
|
17
24
|
from sklearn.model_selection import train_test_split
|
|
18
25
|
from snpio import SNPioMultiQC
|
|
19
26
|
from snpio.utils.logging import LoggerManager
|
|
27
|
+
from snpio.utils.misc import validate_input_type
|
|
20
28
|
|
|
29
|
+
from pgsui.data_processing.transformers import SimMissingTransformer
|
|
21
30
|
from pgsui.impute.unsupervised.nn_scorers import Scorer
|
|
22
31
|
from pgsui.utils.classification_viz import ClassificationReportVisualizer
|
|
23
32
|
from pgsui.utils.logging_utils import configure_logger
|
|
@@ -27,16 +36,24 @@ from pgsui.utils.pretty_metrics import PrettyMetrics
|
|
|
27
36
|
if TYPE_CHECKING:
|
|
28
37
|
from snpio.read_input.genotype_data import GenotypeData
|
|
29
38
|
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
39
|
+
|
|
40
|
+
class _MaskedNumpyDataset(torch.utils.data.Dataset):
|
|
41
|
+
def __init__(self, X: np.ndarray, y: np.ndarray, mask: np.ndarray):
|
|
42
|
+
self.X = X
|
|
43
|
+
self.y = y
|
|
44
|
+
self.mask = mask.astype(np.bool_, copy=False)
|
|
45
|
+
|
|
46
|
+
def __len__(self) -> int:
|
|
47
|
+
return self.X.shape[0]
|
|
48
|
+
|
|
49
|
+
def __getitem__(self, idx: int):
|
|
50
|
+
return self.X[idx], self.y[idx], self.mask[idx]
|
|
34
51
|
|
|
35
52
|
|
|
36
53
|
class BaseNNImputer:
|
|
37
|
-
"""
|
|
54
|
+
"""Abstract base class for neural network-based imputers.
|
|
38
55
|
|
|
39
|
-
This class provides
|
|
56
|
+
This class provides shared infrastructure for NN imputers (e.g., directory/logging setup, Optuna tuning, model construction helpers, class-weight utilities, standardized plotting/scoring, and IUPAC decoding). It is not intended to be instantiated directly; subclasses must implement the abstract methods.
|
|
40
57
|
"""
|
|
41
58
|
|
|
42
59
|
def __init__(
|
|
@@ -54,6 +71,8 @@ class BaseNNImputer:
|
|
|
54
71
|
This constructor sets up the device (CPU, GPU, or MPS), creates the necessary output directories for models and results, and a logger. It also initializes a genotype encoder for handling genotype data.
|
|
55
72
|
|
|
56
73
|
Args:
|
|
74
|
+
model_name (str): The model class name used in output paths and logs.
|
|
75
|
+
genotype_data (GenotypeData): Backing genotype data object.
|
|
57
76
|
prefix (str): A prefix used to name the output directory (e.g., 'pgsui_output').
|
|
58
77
|
device (Literal["gpu", "cpu", "mps"]): The device to use for PyTorch operations. If 'gpu' or 'mps' is chosen, it will fall back to 'cpu' if the required hardware is not available. Defaults to "cpu".
|
|
59
78
|
verbose (bool): If True, enables detailed logging output. Defaults to False.
|
|
@@ -62,6 +81,11 @@ class BaseNNImputer:
|
|
|
62
81
|
self.model_name = model_name
|
|
63
82
|
self.genotype_data = genotype_data
|
|
64
83
|
|
|
84
|
+
if not hasattr(self, "tree_parser"):
|
|
85
|
+
self.tree_parser = None
|
|
86
|
+
if not hasattr(self, "sim_kwargs"):
|
|
87
|
+
self.sim_kwargs = {}
|
|
88
|
+
|
|
65
89
|
self.prefix = prefix
|
|
66
90
|
self.verbose = verbose
|
|
67
91
|
self.debug = debug
|
|
@@ -89,15 +113,15 @@ class BaseNNImputer:
|
|
|
89
113
|
self.logger = configure_logger(
|
|
90
114
|
logman.get_logger(), verbose=self.verbose, debug=self.debug
|
|
91
115
|
)
|
|
92
|
-
|
|
93
|
-
self.
|
|
116
|
+
|
|
117
|
+
self.logger.info(f"Using PyTorch device: {self.device.type}.")
|
|
94
118
|
|
|
95
119
|
# To be initialized by child classes or fit method
|
|
96
120
|
self.tune_save_db: bool = False
|
|
97
121
|
self.tune_resume: bool = False
|
|
98
122
|
self.n_trials: int = 100
|
|
99
123
|
self.model_params: Dict[str, Any] = {}
|
|
100
|
-
self.tune_metric: str = "
|
|
124
|
+
self.tune_metric: str = "f1"
|
|
101
125
|
self.learning_rate: float = 1e-3
|
|
102
126
|
self.plotter_: "Plotting"
|
|
103
127
|
self.num_features_: int = 0
|
|
@@ -110,21 +134,16 @@ class BaseNNImputer:
|
|
|
110
134
|
self.show_plots: bool = False
|
|
111
135
|
self.scoring_averaging: Literal["macro", "micro", "weighted"] = "macro"
|
|
112
136
|
self.pgenc: Any = None
|
|
113
|
-
self.
|
|
137
|
+
self.is_haploid_: bool = False
|
|
114
138
|
self.ploidy: int = 2
|
|
115
139
|
self.beta: float = 0.9999
|
|
116
|
-
self.max_ratio: float =
|
|
117
|
-
self.sim_strategy: str = "
|
|
118
|
-
self.sim_prop: float = 0.
|
|
119
|
-
self.seed: int
|
|
140
|
+
self.max_ratio: Optional[float] = None
|
|
141
|
+
self.sim_strategy: str = "random"
|
|
142
|
+
self.sim_prop: float = 0.2
|
|
143
|
+
self.seed: Optional[int] = None
|
|
120
144
|
self.rng: np.random.Generator = np.random.default_rng(self.seed)
|
|
121
145
|
self.ground_truth_: np.ndarray
|
|
122
|
-
self.tune_fast: bool = False
|
|
123
|
-
self.tune_max_samples: int = 1000
|
|
124
|
-
self.tune_max_loci: int = 500
|
|
125
146
|
self.validation_split: float = 0.2
|
|
126
|
-
self.tune_batch_size: int = 64
|
|
127
|
-
self.tune_proxy_metric_batch: int = 512
|
|
128
147
|
self.batch_size: int = 64
|
|
129
148
|
self.best_params_: Dict[str, Any] = {}
|
|
130
149
|
|
|
@@ -133,9 +152,11 @@ class BaseNNImputer:
|
|
|
133
152
|
self.plots_dir: Path
|
|
134
153
|
self.metrics_dir: Path
|
|
135
154
|
self.parameters_dir: Path
|
|
136
|
-
self.study_db: Path
|
|
155
|
+
self.study_db: Optional[Path] = None
|
|
156
|
+
self.X_model_input_: Optional[np.ndarray] = None
|
|
157
|
+
self.class_weights_: Optional[torch.Tensor] = None
|
|
137
158
|
|
|
138
|
-
def tune_hyperparameters(self) ->
|
|
159
|
+
def tune_hyperparameters(self) -> Dict[str, Any]:
|
|
139
160
|
"""Tunes model hyperparameters using an Optuna study.
|
|
140
161
|
|
|
141
162
|
This method orchestrates the hyperparameter search process. It creates an Optuna study that aims to maximize the metric defined in `self.tune_metric`. The search is driven by the `_objective` method, which must be implemented by the child class. After the search, the best parameters are logged, saved to a JSON file, and visualizations of the study are generated.
|
|
@@ -153,15 +174,20 @@ class BaseNNImputer:
|
|
|
153
174
|
study_db = None
|
|
154
175
|
load_if_exists = False
|
|
155
176
|
if self.tune_save_db:
|
|
156
|
-
|
|
177
|
+
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
|
|
178
|
+
study_db = (
|
|
179
|
+
self.optimize_dir / "study_database" / f"optuna_study_{timestamp}.db"
|
|
180
|
+
)
|
|
157
181
|
study_db.parent.mkdir(parents=True, exist_ok=True)
|
|
158
182
|
|
|
159
183
|
if self.tune_resume and study_db.exists():
|
|
160
184
|
load_if_exists = True
|
|
161
185
|
|
|
162
186
|
if not self.tune_resume and study_db.exists():
|
|
163
|
-
|
|
187
|
+
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
|
|
188
|
+
study_db = study_db.with_name(f"optuna_study_{timestamp}.db")
|
|
164
189
|
|
|
190
|
+
self.study_db = study_db
|
|
165
191
|
study_name = f"{self.prefix} {self.model_name} Model Optimization"
|
|
166
192
|
storage = f"sqlite:///{study_db}" if self.tune_save_db else None
|
|
167
193
|
|
|
@@ -170,7 +196,17 @@ class BaseNNImputer:
|
|
|
170
196
|
study_name=study_name,
|
|
171
197
|
storage=storage,
|
|
172
198
|
load_if_exists=load_if_exists,
|
|
173
|
-
pruner=optuna.pruners.MedianPruner(
|
|
199
|
+
pruner=optuna.pruners.MedianPruner(
|
|
200
|
+
# Guard against small `n_trials` values (e.g., 1)
|
|
201
|
+
# that can otherwise produce 0 startup/warmup/min trials.
|
|
202
|
+
n_startup_trials=max(
|
|
203
|
+
1, min(int(self.n_trials * 0.1), 10, int(self.n_trials))
|
|
204
|
+
),
|
|
205
|
+
n_warmup_steps=150,
|
|
206
|
+
n_min_trials=max(
|
|
207
|
+
1, min(int(0.5 * self.n_trials), 25, int(self.n_trials))
|
|
208
|
+
),
|
|
209
|
+
),
|
|
174
210
|
)
|
|
175
211
|
|
|
176
212
|
if not hasattr(self, "_objective"):
|
|
@@ -185,6 +221,13 @@ class BaseNNImputer:
|
|
|
185
221
|
|
|
186
222
|
show_progress_bar = not self.verbose and not self.debug and self.n_jobs == 1
|
|
187
223
|
|
|
224
|
+
# Set the best parameters.
|
|
225
|
+
# NOTE: _set_best_params() must be implemented in the child class.
|
|
226
|
+
if not hasattr(self, "_set_best_params"):
|
|
227
|
+
msg = "Method `_set_best_params()` must be implemented in the child class."
|
|
228
|
+
self.logger.error(msg)
|
|
229
|
+
raise NotImplementedError(msg)
|
|
230
|
+
|
|
188
231
|
study.optimize(
|
|
189
232
|
lambda trial: self._objective(trial),
|
|
190
233
|
n_trials=self.n_trials,
|
|
@@ -193,15 +236,13 @@ class BaseNNImputer:
|
|
|
193
236
|
show_progress_bar=show_progress_bar,
|
|
194
237
|
)
|
|
195
238
|
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
if not hasattr(self, "_set_best_params"):
|
|
202
|
-
msg = "Method `_set_best_params()` must be implemented in the child class."
|
|
239
|
+
try:
|
|
240
|
+
best_metric = study.best_value
|
|
241
|
+
best_params = study.best_params
|
|
242
|
+
except Exception:
|
|
243
|
+
msg = "Tuning failed: No successful trials completed."
|
|
203
244
|
self.logger.error(msg)
|
|
204
|
-
raise
|
|
245
|
+
raise RuntimeError(msg)
|
|
205
246
|
|
|
206
247
|
self.best_params_ = self._set_best_params(best_params)
|
|
207
248
|
self.model_params.update(self.best_params_)
|
|
@@ -210,45 +251,34 @@ class BaseNNImputer:
|
|
|
210
251
|
best_params_tmp = copy.deepcopy(best_params)
|
|
211
252
|
best_params_tmp["learning_rate"] = self.learning_rate
|
|
212
253
|
|
|
213
|
-
|
|
214
|
-
pm = PrettyMetrics(best_params_tmp, precision=6, title=title)
|
|
215
|
-
pm.render()
|
|
254
|
+
tn = f"{self.tune_metric} Value"
|
|
216
255
|
|
|
217
|
-
|
|
218
|
-
|
|
256
|
+
if self.show_plots:
|
|
257
|
+
self.plotter_.plot_tuning(
|
|
258
|
+
study, self.model_name, self.optimize_dir / "plots", target_name=tn
|
|
259
|
+
)
|
|
219
260
|
|
|
220
|
-
|
|
221
|
-
self.plotter_.plot_tuning(
|
|
222
|
-
study, self.model_name, self.optimize_dir / "plots", target_name=tn
|
|
223
|
-
)
|
|
261
|
+
return best_params_tmp
|
|
224
262
|
|
|
225
263
|
@staticmethod
|
|
226
264
|
def initialize_weights(module: torch.nn.Module) -> None:
|
|
227
|
-
"""Initializes model weights using
|
|
228
|
-
|
|
229
|
-
This static method is intended to be applied to a PyTorch model to initialize the weights of its linear and convolutional layers. This initialization scheme is particularly effective for networks that use ReLU-family activation functions, as it helps maintain stable activation variances during training.
|
|
265
|
+
"""Initializes model weights using Xavier/Glorot Uniform distribution.
|
|
230
266
|
|
|
231
|
-
|
|
232
|
-
|
|
267
|
+
Switching from Kaiming to Xavier is safer for deep VAEs to prevent
|
|
268
|
+
exploding gradients or dead neurons in the early epochs.
|
|
233
269
|
"""
|
|
234
270
|
if isinstance(
|
|
235
271
|
module, (torch.nn.Linear, torch.nn.Conv1d, torch.nn.ConvTranspose1d)
|
|
236
272
|
):
|
|
237
|
-
#
|
|
238
|
-
torch.nn.init.
|
|
273
|
+
# Xavier is generally more stable for VAEs than Kaiming
|
|
274
|
+
torch.nn.init.xavier_uniform_(module.weight)
|
|
239
275
|
if module.bias is not None:
|
|
240
276
|
torch.nn.init.zeros_(module.bias)
|
|
241
277
|
|
|
242
278
|
def build_model(
|
|
243
279
|
self,
|
|
244
|
-
Model:
|
|
245
|
-
|
|
246
|
-
| type["AutoencoderModel"]
|
|
247
|
-
| type["NLPCAModel"]
|
|
248
|
-
| type["UBPModel"]
|
|
249
|
-
| type["VAEModel"]
|
|
250
|
-
),
|
|
251
|
-
model_params: Dict[str, int | float | str | bool],
|
|
280
|
+
Model: Type[torch.nn.Module],
|
|
281
|
+
model_params: Dict[str, Any],
|
|
252
282
|
) -> torch.nn.Module:
|
|
253
283
|
"""Builds and initializes a neural network model instance.
|
|
254
284
|
|
|
@@ -277,20 +307,16 @@ class BaseNNImputer:
|
|
|
277
307
|
self.logger.error(msg)
|
|
278
308
|
raise AttributeError(msg)
|
|
279
309
|
|
|
280
|
-
# Start with a base set of fixed (non-tuned) parameters.
|
|
281
|
-
base_num_classes = getattr(self, "output_classes_", None)
|
|
282
|
-
if base_num_classes is None:
|
|
283
|
-
base_num_classes = self.num_classes_
|
|
284
310
|
all_params = {
|
|
285
311
|
"n_features": self.num_features_,
|
|
286
312
|
"prefix": self.prefix,
|
|
287
|
-
"num_classes":
|
|
313
|
+
"num_classes": self.num_classes_,
|
|
288
314
|
"verbose": self.verbose,
|
|
289
315
|
"debug": self.debug,
|
|
290
316
|
"device": self.device,
|
|
291
317
|
}
|
|
292
318
|
|
|
293
|
-
# Update with the variable hyperparameters
|
|
319
|
+
# Update with the variable hyperparameters
|
|
294
320
|
all_params.update(model_params)
|
|
295
321
|
|
|
296
322
|
return Model(**all_params).to(self.device)
|
|
@@ -372,110 +398,12 @@ class BaseNNImputer:
|
|
|
372
398
|
X (np.ndarray | pd.DataFrame | list | None): The input data with missing values.
|
|
373
399
|
|
|
374
400
|
Returns:
|
|
375
|
-
np.ndarray:
|
|
401
|
+
np.ndarray: IUPAC strings with missing values imputed.
|
|
376
402
|
"""
|
|
377
403
|
msg = "Method ``transform()`` must be implemented in the child class."
|
|
378
404
|
self.logger.error(msg)
|
|
379
405
|
raise NotImplementedError(msg)
|
|
380
406
|
|
|
381
|
-
def _class_balanced_weights_from_mask(
|
|
382
|
-
self,
|
|
383
|
-
y: np.ndarray,
|
|
384
|
-
train_mask: np.ndarray,
|
|
385
|
-
num_classes: int,
|
|
386
|
-
beta: float = 0.9999,
|
|
387
|
-
max_ratio: float = 5.0,
|
|
388
|
-
mode: Literal["allele", "genotype10"] = "allele",
|
|
389
|
-
) -> torch.Tensor:
|
|
390
|
-
"""Class-balanced weights (Cui et al. 2019) with overflow-safe effective number.
|
|
391
|
-
|
|
392
|
-
mode="allele": y is 1D alleles in {0..3}, train_mask same shape. mode="genotype10": y is (nS,nF,2) alleles; train_mask is (nS,nF) loci where both alleles known.
|
|
393
|
-
|
|
394
|
-
Args:
|
|
395
|
-
y (np.ndarray): Ground truth labels.
|
|
396
|
-
train_mask (np.ndarray): Boolean mask of training examples (same shape as y or y without last dim for genotype10).
|
|
397
|
-
num_classes (int): Number of classes.
|
|
398
|
-
beta (float): Hyperparameter for effective number calculation. Clamped to (0,1). Default is 0.9999.
|
|
399
|
-
max_ratio (float): Maximum allowed ratio between largest and smallest non-zero weight. Default is 5.0.
|
|
400
|
-
mode (Literal["allele", "genotype10"]): Whether y contains allele labels or 10-class genotypes. Default is "allele".
|
|
401
|
-
|
|
402
|
-
Returns:
|
|
403
|
-
torch.Tensor: Class weights of shape (num_classes,). Mean weight is 1.0, zero-weight classes remain zero.
|
|
404
|
-
"""
|
|
405
|
-
if mode == "allele":
|
|
406
|
-
valid = (y >= 0) & train_mask
|
|
407
|
-
cls, cnt = np.unique(y[valid].astype(np.int64), return_counts=True)
|
|
408
|
-
counts = np.zeros(num_classes, dtype=np.float64)
|
|
409
|
-
counts[cls] = cnt
|
|
410
|
-
|
|
411
|
-
elif mode == "genotype10":
|
|
412
|
-
if y.ndim != 3 or y.shape[-1] != 2:
|
|
413
|
-
msg = "For genotype10, y must be (nS,nF,2)."
|
|
414
|
-
self.logger.error(msg)
|
|
415
|
-
raise ValueError(msg)
|
|
416
|
-
|
|
417
|
-
if train_mask.shape != y.shape[:2]:
|
|
418
|
-
msg = "train_mask must be (nS,nF) for genotype10."
|
|
419
|
-
self.logger.error(msg)
|
|
420
|
-
raise ValueError(msg)
|
|
421
|
-
|
|
422
|
-
# only loci where both alleles known and in training
|
|
423
|
-
m = train_mask & np.all(y >= 0, axis=-1)
|
|
424
|
-
if not np.any(m):
|
|
425
|
-
counts = np.zeros(num_classes, dtype=np.float64)
|
|
426
|
-
|
|
427
|
-
else:
|
|
428
|
-
a1 = y[:, :, 0][m].astype(int)
|
|
429
|
-
a2 = y[:, :, 1][m].astype(int)
|
|
430
|
-
lo, hi = np.minimum(a1, a2), np.maximum(a1, a2)
|
|
431
|
-
# map to 10-class index
|
|
432
|
-
map10 = self.pgenc.map10
|
|
433
|
-
idx10 = map10[lo, hi]
|
|
434
|
-
idx10 = idx10[(idx10 >= 0) & (idx10 < num_classes)]
|
|
435
|
-
counts = np.bincount(idx10, minlength=num_classes).astype(np.float64)
|
|
436
|
-
else:
|
|
437
|
-
msg = f"Unknown mode supplied to _class_balanced_weights_from_mask: {mode}"
|
|
438
|
-
self.logger.error(msg)
|
|
439
|
-
raise ValueError(msg)
|
|
440
|
-
|
|
441
|
-
# ---- Effective number ----
|
|
442
|
-
beta = float(beta)
|
|
443
|
-
|
|
444
|
-
# clamp beta ∈ (0,1)
|
|
445
|
-
if not np.isfinite(beta):
|
|
446
|
-
beta = 0.9999
|
|
447
|
-
|
|
448
|
-
beta = min(max(beta, 1e-8), 1.0 - 1e-8)
|
|
449
|
-
|
|
450
|
-
logb = np.log(beta) # < 0
|
|
451
|
-
t = counts * logb # ≤ 0
|
|
452
|
-
|
|
453
|
-
# 1 - beta^n = 1 - exp(n*log(beta)) = -(exp(n*log(beta)) - 1)
|
|
454
|
-
# use expm1 for accuracy near 0; for very negative t, eff≈1.0
|
|
455
|
-
eff = np.where(t > -50.0, -np.expm1(t), 1.0)
|
|
456
|
-
|
|
457
|
-
# class-balanced weights
|
|
458
|
-
w = (1.0 - beta) / (eff + 1e-12)
|
|
459
|
-
|
|
460
|
-
# Give unseen classes the largest non-zero weight (keeps it learnable)
|
|
461
|
-
if np.any(counts == 0) and np.any(counts > 0):
|
|
462
|
-
w[counts == 0] = w[counts > 0].max()
|
|
463
|
-
|
|
464
|
-
# normalize by mean of non-zero
|
|
465
|
-
nz = w > 0
|
|
466
|
-
w[nz] /= w[nz].mean() + 1e-12
|
|
467
|
-
|
|
468
|
-
# cap spread consistently with a single 'cap'
|
|
469
|
-
cap = float(max_ratio) if max_ratio is not None else 10.0
|
|
470
|
-
cap = max(cap, 5.0) # ensure we allow some differentiation
|
|
471
|
-
if np.any(nz):
|
|
472
|
-
spread = w[nz].max() / max(w[nz].min(), 1e-12)
|
|
473
|
-
if spread > cap:
|
|
474
|
-
scale = cap / spread
|
|
475
|
-
w[nz] = 1.0 + (w[nz] - 1.0) * scale
|
|
476
|
-
|
|
477
|
-
return torch.tensor(w.astype(np.float32), device=self.device)
|
|
478
|
-
|
|
479
407
|
def _select_device(self, device: Literal["gpu", "cpu", "mps"]) -> torch.device:
|
|
480
408
|
"""Selects the appropriate PyTorch device based on user preference and availability.
|
|
481
409
|
|
|
@@ -487,36 +415,37 @@ class BaseNNImputer:
|
|
|
487
415
|
Returns:
|
|
488
416
|
torch.device: The selected PyTorch device.
|
|
489
417
|
"""
|
|
490
|
-
dvc
|
|
491
|
-
dvc = dvc.lower().strip()
|
|
418
|
+
dvc = device.lower().strip()
|
|
492
419
|
if dvc == "cpu":
|
|
493
|
-
self.logger.info("Using PyTorch device: CPU.")
|
|
494
420
|
return torch.device("cpu")
|
|
495
421
|
if dvc == "mps":
|
|
496
422
|
if torch.backends.mps.is_available():
|
|
497
|
-
self.logger.info("Using PyTorch device: mps.")
|
|
498
423
|
return torch.device("mps")
|
|
499
|
-
self.logger.warning("MPS unavailable; falling back to CPU.")
|
|
500
424
|
return torch.device("cpu")
|
|
501
|
-
# gpu
|
|
502
425
|
if torch.cuda.is_available():
|
|
503
|
-
self.logger.info("Using PyTorch device: cuda.")
|
|
504
426
|
return torch.device("cuda")
|
|
505
|
-
self.logger.warning("CUDA unavailable; falling back to CPU.")
|
|
506
427
|
return torch.device("cpu")
|
|
507
428
|
|
|
508
|
-
def _create_model_directories(
|
|
429
|
+
def _create_model_directories(
|
|
430
|
+
self, prefix: str, outdirs: List[str], *, outdir: Path | str | None = None
|
|
431
|
+
) -> None:
|
|
509
432
|
"""Creates the directory structure for storing model outputs.
|
|
510
433
|
|
|
511
|
-
This method sets up a standardized folder hierarchy for saving models,
|
|
434
|
+
This method sets up a standardized folder hierarchy for saving models,
|
|
435
|
+
plots, metrics, and optimization results, organized under a main directory
|
|
436
|
+
named after the provided prefix. The current implementation always uses
|
|
437
|
+
``<cwd>/<prefix>_output`` regardless of ``outdir``.
|
|
512
438
|
|
|
513
439
|
Args:
|
|
514
440
|
prefix (str): The prefix for the main output directory.
|
|
515
441
|
outdirs (List[str]): A list of subdirectory names to create within the main directory.
|
|
442
|
+
outdir (Path | str | None): Requested base output directory (currently ignored).
|
|
516
443
|
|
|
517
444
|
Raises:
|
|
518
445
|
Exception: If any of the directories cannot be created.
|
|
519
446
|
"""
|
|
447
|
+
base_root = Path(outdir) if outdir is not None else Path.cwd()
|
|
448
|
+
formatted_output_dir = base_root / f"{prefix}_output"
|
|
520
449
|
formatted_output_dir = Path(f"{prefix}_output")
|
|
521
450
|
base_dir = formatted_output_dir / "Unsupervised"
|
|
522
451
|
|
|
@@ -530,27 +459,16 @@ class BaseNNImputer:
|
|
|
530
459
|
self.logger.error(msg)
|
|
531
460
|
raise Exception(msg)
|
|
532
461
|
|
|
533
|
-
def _clear_resources(
|
|
534
|
-
self,
|
|
535
|
-
model: torch.nn.Module,
|
|
536
|
-
train_loader: torch.utils.data.DataLoader,
|
|
537
|
-
latent_vectors: torch.nn.Parameter | None = None,
|
|
538
|
-
) -> None:
|
|
462
|
+
def _clear_resources(self, model: torch.nn.Module) -> None:
|
|
539
463
|
"""Releases GPU and CPU memory after an Optuna trial.
|
|
540
464
|
|
|
541
465
|
This is a crucial step during hyperparameter tuning to prevent memory leaks between trials, ensuring that each trial runs in a clean environment.
|
|
542
466
|
|
|
543
467
|
Args:
|
|
544
468
|
model (torch.nn.Module): The model from the completed trial.
|
|
545
|
-
train_loader (torch.utils.data.DataLoader): The data loader from the trial.
|
|
546
|
-
latent_vectors (torch.nn.Parameter | None): The latent vectors from the trial.
|
|
547
469
|
"""
|
|
548
470
|
try:
|
|
549
|
-
del model
|
|
550
|
-
|
|
551
|
-
if latent_vectors is not None:
|
|
552
|
-
del latent_vectors
|
|
553
|
-
|
|
471
|
+
del model
|
|
554
472
|
except NameError:
|
|
555
473
|
pass
|
|
556
474
|
|
|
@@ -571,7 +489,7 @@ class BaseNNImputer:
|
|
|
571
489
|
y_pred: np.ndarray,
|
|
572
490
|
metrics: Dict[str, float],
|
|
573
491
|
msg: str,
|
|
574
|
-
):
|
|
492
|
+
) -> None:
|
|
575
493
|
"""Generate and save evaluation visualizations.
|
|
576
494
|
|
|
577
495
|
3-class (zygosity) or 10-class (IUPAC) depending on `labels` length.
|
|
@@ -589,20 +507,109 @@ class BaseNNImputer:
|
|
|
589
507
|
prefix = "zygosity" if len(labels) == 3 else "iupac"
|
|
590
508
|
n_labels = len(labels)
|
|
591
509
|
|
|
592
|
-
self.
|
|
593
|
-
|
|
594
|
-
|
|
595
|
-
|
|
596
|
-
|
|
597
|
-
|
|
510
|
+
if self.show_plots:
|
|
511
|
+
self.plotter_.plot_metrics(
|
|
512
|
+
y_true=y_true,
|
|
513
|
+
y_pred_proba=y_pred_proba,
|
|
514
|
+
metrics=metrics,
|
|
515
|
+
label_names=labels,
|
|
516
|
+
prefix=f"geno{n_labels}_{prefix}",
|
|
517
|
+
)
|
|
518
|
+
self.plotter_.plot_confusion_matrix(
|
|
519
|
+
y_true_1d=y_true,
|
|
520
|
+
y_pred_1d=y_pred,
|
|
521
|
+
label_names=labels,
|
|
522
|
+
prefix=f"geno{n_labels}_{prefix}",
|
|
523
|
+
)
|
|
524
|
+
|
|
525
|
+
def _additional_metrics(
|
|
526
|
+
self,
|
|
527
|
+
y_true: np.ndarray,
|
|
528
|
+
y_pred: np.ndarray,
|
|
529
|
+
labels: list[int],
|
|
530
|
+
report_names: list[str],
|
|
531
|
+
report: dict,
|
|
532
|
+
) -> dict[str, dict[str, float] | float]:
|
|
533
|
+
"""Compute additional metrics and augment the report dictionary.
|
|
534
|
+
|
|
535
|
+
Args:
|
|
536
|
+
y_true (np.ndarray): True genotypes.
|
|
537
|
+
y_pred (np.ndarray): Predicted genotypes.
|
|
538
|
+
labels (list[int]): List of label indices.
|
|
539
|
+
report_names (list[str]): List of report names corresponding to labels.
|
|
540
|
+
report (dict): Classification report dictionary to augment.
|
|
541
|
+
|
|
542
|
+
Returns:
|
|
543
|
+
dict[str, dict[str, float] | float]: Augmented report dictionary with additional metrics.
|
|
544
|
+
"""
|
|
545
|
+
# Create an identity matrix and use the targets array as indices
|
|
546
|
+
y_score = np.eye(len(report_names))[y_pred]
|
|
547
|
+
|
|
548
|
+
# Per-class metrics
|
|
549
|
+
ap_pc = average_precision_score(y_true, y_score, average=None)
|
|
550
|
+
jaccard_pc = jaccard_score(
|
|
551
|
+
y_true, y_pred, average=None, labels=labels, zero_division=0
|
|
598
552
|
)
|
|
599
|
-
|
|
600
|
-
|
|
601
|
-
|
|
602
|
-
|
|
603
|
-
|
|
553
|
+
|
|
554
|
+
# Macro/weighted metrics
|
|
555
|
+
ap_macro = average_precision_score(y_true, y_score, average="macro")
|
|
556
|
+
ap_weighted = average_precision_score(y_true, y_score, average="weighted")
|
|
557
|
+
jaccard_macro = jaccard_score(y_true, y_pred, average="macro", zero_division=0)
|
|
558
|
+
jaccard_weighted = jaccard_score(
|
|
559
|
+
y_true, y_pred, average="weighted", zero_division=0
|
|
604
560
|
)
|
|
605
561
|
|
|
562
|
+
# Matthews correlation coefficient (MCC)
|
|
563
|
+
mcc = matthews_corrcoef(y_true, y_pred)
|
|
564
|
+
|
|
565
|
+
if not isinstance(ap_pc, np.ndarray):
|
|
566
|
+
msg = "average_precision_score or f1_score did not return np.ndarray as expected."
|
|
567
|
+
self.logger.error(msg)
|
|
568
|
+
raise TypeError(msg)
|
|
569
|
+
|
|
570
|
+
if not isinstance(jaccard_pc, np.ndarray):
|
|
571
|
+
msg = "jaccard_score did not return np.ndarray as expected."
|
|
572
|
+
self.logger.error(msg)
|
|
573
|
+
raise TypeError(msg)
|
|
574
|
+
|
|
575
|
+
# Add per-class metrics
|
|
576
|
+
report_full = {}
|
|
577
|
+
dd_subset = {
|
|
578
|
+
k: v for k, v in report.items() if k in report_names and isinstance(v, dict)
|
|
579
|
+
}
|
|
580
|
+
for i, class_name in enumerate(report_names):
|
|
581
|
+
class_report: dict[str, float] = {}
|
|
582
|
+
if class_name in dd_subset:
|
|
583
|
+
class_report = dd_subset[class_name]
|
|
584
|
+
|
|
585
|
+
if isinstance(class_report, float) or not class_report:
|
|
586
|
+
continue
|
|
587
|
+
|
|
588
|
+
report_full[class_name] = dict(class_report)
|
|
589
|
+
report_full[class_name]["average-precision"] = float(ap_pc[i])
|
|
590
|
+
report_full[class_name]["jaccard"] = float(jaccard_pc[i])
|
|
591
|
+
|
|
592
|
+
macro_avg = report.get("macro avg")
|
|
593
|
+
if isinstance(macro_avg, dict):
|
|
594
|
+
report_full["macro avg"] = dict(macro_avg)
|
|
595
|
+
report_full["macro avg"]["average-precision"] = float(ap_macro)
|
|
596
|
+
report_full["macro avg"]["jaccard"] = float(jaccard_macro)
|
|
597
|
+
|
|
598
|
+
weighted_avg = report.get("weighted avg")
|
|
599
|
+
if isinstance(weighted_avg, dict):
|
|
600
|
+
report_full["weighted avg"] = dict(weighted_avg)
|
|
601
|
+
report_full["weighted avg"]["average-precision"] = float(ap_weighted)
|
|
602
|
+
report_full["weighted avg"]["jaccard"] = float(jaccard_weighted)
|
|
603
|
+
|
|
604
|
+
# Add scalar summary metrics
|
|
605
|
+
report_full["mcc"] = float(mcc)
|
|
606
|
+
accuracy_val = report.get("accuracy")
|
|
607
|
+
|
|
608
|
+
if isinstance(accuracy_val, (int, float)):
|
|
609
|
+
report_full["accuracy"] = float(accuracy_val)
|
|
610
|
+
|
|
611
|
+
return report_full
|
|
612
|
+
|
|
606
613
|
def _make_class_reports(
|
|
607
614
|
self,
|
|
608
615
|
y_true: np.ndarray,
|
|
@@ -620,27 +627,28 @@ class BaseNNImputer:
|
|
|
620
627
|
y_pred (np.ndarray): Predicted labels (1D array).
|
|
621
628
|
metrics (Dict[str, float]): Computed metrics.
|
|
622
629
|
y_pred_proba (np.ndarray | None): Predicted probabilities (2D array). Defaults to None.
|
|
623
|
-
labels (List[str]): Class label names
|
|
624
|
-
(default: ["REF", "HET", "ALT"] for 3-class).
|
|
630
|
+
labels (List[str]): Class label names (default: ["REF", "HET", "ALT"] for 3-class).
|
|
625
631
|
"""
|
|
626
|
-
report_name = "zygosity" if len(labels)
|
|
632
|
+
report_name = "zygosity" if len(labels) <= 3 else "iupac"
|
|
627
633
|
middle = "IUPAC" if report_name == "iupac" else "Zygosity"
|
|
628
634
|
|
|
629
|
-
msg = f"{middle} Report (on {
|
|
635
|
+
msg = f"{middle} Report (on {y_pred.size} total genotypes)"
|
|
630
636
|
self.logger.info(msg)
|
|
631
637
|
|
|
632
638
|
if y_pred_proba is not None:
|
|
633
|
-
self.
|
|
634
|
-
|
|
635
|
-
|
|
636
|
-
|
|
637
|
-
|
|
638
|
-
|
|
639
|
-
|
|
639
|
+
if self.show_plots:
|
|
640
|
+
self.plotter_.plot_metrics(
|
|
641
|
+
y_true,
|
|
642
|
+
y_pred_proba,
|
|
643
|
+
metrics,
|
|
644
|
+
label_names=labels,
|
|
645
|
+
prefix=report_name,
|
|
646
|
+
)
|
|
640
647
|
|
|
641
|
-
|
|
642
|
-
|
|
643
|
-
|
|
648
|
+
if self.show_plots:
|
|
649
|
+
self.plotter_.plot_confusion_matrix(
|
|
650
|
+
y_true, y_pred, label_names=labels, prefix=report_name
|
|
651
|
+
)
|
|
644
652
|
|
|
645
653
|
report: str | dict = classification_report(
|
|
646
654
|
y_true,
|
|
@@ -653,62 +661,63 @@ class BaseNNImputer:
|
|
|
653
661
|
|
|
654
662
|
if not isinstance(report, dict):
|
|
655
663
|
msg = "Expected classification_report to return a dict."
|
|
656
|
-
self.logger.error(msg)
|
|
664
|
+
self.logger.error(msg, exc_info=True)
|
|
657
665
|
raise ValueError(msg)
|
|
658
666
|
|
|
659
|
-
|
|
660
|
-
|
|
661
|
-
|
|
662
|
-
|
|
663
|
-
|
|
664
|
-
|
|
665
|
-
|
|
666
|
-
|
|
667
|
-
|
|
668
|
-
|
|
669
|
-
|
|
667
|
+
if self.show_plots:
|
|
668
|
+
viz = ClassificationReportVisualizer(reset_kwargs=self.plotter_.param_dict)
|
|
669
|
+
try:
|
|
670
|
+
plots = viz.plot_all(
|
|
671
|
+
report, # type: ignore
|
|
672
|
+
title_prefix=f"{self.model_name} {middle} Report",
|
|
673
|
+
show=self.show_plots,
|
|
674
|
+
heatmap_classes_only=True,
|
|
675
|
+
)
|
|
676
|
+
finally:
|
|
677
|
+
viz._reset_mpl_style()
|
|
678
|
+
|
|
679
|
+
for name, fig in plots.items():
|
|
680
|
+
fout = (
|
|
681
|
+
self.plots_dir / f"{report_name}_report_{name}.{self.plot_format}"
|
|
682
|
+
)
|
|
683
|
+
if hasattr(fig, "savefig") and isinstance(fig, Figure):
|
|
684
|
+
fig.savefig(fout, dpi=300, facecolor="#111122")
|
|
685
|
+
plt.close(fig)
|
|
686
|
+
elif hasattr(fig, "write_html") and isinstance(fig, go.Figure):
|
|
687
|
+
fout_html = fout.with_suffix(".html")
|
|
688
|
+
fig.write_html(file=fout_html)
|
|
689
|
+
|
|
690
|
+
SNPioMultiQC.queue_html(
|
|
691
|
+
fout_html,
|
|
692
|
+
panel_id=f"pgsui_{self.model_name.lower()}_{report_name}_radar",
|
|
693
|
+
section=f"PG-SUI: {self.model_name} Model Imputation",
|
|
694
|
+
title=f"{self.model_name} {middle} Radar Plot",
|
|
695
|
+
index_label=name,
|
|
696
|
+
description=f"{self.model_name} {middle} {len(labels)}-base Radar Plot. This radar plot visualizes model performance for three metrics per-class: precision, recall, and F1-score. Each axis represents one of these metrics, allowing for a quick visual assessment of the model's strengths and weaknesses. Higher values towards the outer edge indicate better performance.",
|
|
697
|
+
)
|
|
698
|
+
|
|
699
|
+
if not self.is_haploid_:
|
|
700
|
+
msg = f"Ploidy: {self.ploidy}. Evaluating per genotype (REF, HET, ALT)."
|
|
701
|
+
self.logger.info(msg)
|
|
702
|
+
|
|
703
|
+
report_full = self._additional_metrics(
|
|
704
|
+
y_true,
|
|
705
|
+
y_pred,
|
|
706
|
+
labels=list(range(len(labels))),
|
|
707
|
+
report_names=labels,
|
|
708
|
+
report=report,
|
|
709
|
+
)
|
|
710
|
+
|
|
711
|
+
if self.verbose or self.debug:
|
|
670
712
|
pm = PrettyMetrics(
|
|
671
|
-
|
|
672
|
-
precision=
|
|
713
|
+
report_full,
|
|
714
|
+
precision=2,
|
|
673
715
|
title=f"{self.model_name} {middle} Report",
|
|
674
716
|
)
|
|
675
717
|
pm.render()
|
|
676
718
|
|
|
677
719
|
with open(self.metrics_dir / f"{report_name}_report.json", "w") as f:
|
|
678
|
-
json.dump(
|
|
679
|
-
|
|
680
|
-
viz = ClassificationReportVisualizer(reset_kwargs=self.plotter_.param_dict)
|
|
681
|
-
|
|
682
|
-
plots = viz.plot_all(
|
|
683
|
-
report, # type: ignore
|
|
684
|
-
title_prefix=f"{self.model_name} {middle} Report",
|
|
685
|
-
show=getattr(self, "show_plots", False),
|
|
686
|
-
heatmap_classes_only=True,
|
|
687
|
-
)
|
|
688
|
-
|
|
689
|
-
for name, fig in plots.items():
|
|
690
|
-
fout = self.plots_dir / f"{report_name}_report_{name}.{self.plot_format}"
|
|
691
|
-
if hasattr(fig, "savefig") and isinstance(fig, Figure):
|
|
692
|
-
fig.savefig(fout, dpi=300, facecolor="#111122")
|
|
693
|
-
plt.close(fig)
|
|
694
|
-
elif hasattr(fig, "write_html") and isinstance(fig, go.Figure):
|
|
695
|
-
fout_html = fout.with_suffix(".html")
|
|
696
|
-
fig.write_html(file=fout_html)
|
|
697
|
-
|
|
698
|
-
SNPioMultiQC.queue_html(
|
|
699
|
-
fout_html,
|
|
700
|
-
panel_id=f"pgsui_{self.model_name.lower()}_{report_name}_radar",
|
|
701
|
-
section=f"PG-SUI: {self.model_name} Model Imputation",
|
|
702
|
-
title=f"{self.model_name} {middle} Radar Plot",
|
|
703
|
-
index_label=name,
|
|
704
|
-
description=f"{self.model_name} {middle} {len(labels)}-base Radar Plot. This radar plot visualizes model performance for three metrics per-class: precision, recall, and F1-score. Each axis represents one of these metrics, allowing for a quick visual assessment of the model's strengths and weaknesses. Higher values towards the outer edge indicate better performance.",
|
|
705
|
-
)
|
|
706
|
-
|
|
707
|
-
if not self.is_haploid:
|
|
708
|
-
msg = f"Ploidy: {self.ploidy}. Evaluating per allele."
|
|
709
|
-
self.logger.info(msg)
|
|
710
|
-
|
|
711
|
-
viz._reset_mpl_style()
|
|
720
|
+
json.dump(report_full, f, indent=4)
|
|
712
721
|
|
|
713
722
|
def _compute_hidden_layer_sizes(
|
|
714
723
|
self,
|
|
@@ -716,6 +725,7 @@ class BaseNNImputer:
|
|
|
716
725
|
n_outputs: int,
|
|
717
726
|
n_samples: int,
|
|
718
727
|
n_hidden: int,
|
|
728
|
+
latent_dim: int,
|
|
719
729
|
*,
|
|
720
730
|
alpha: float = 4.0,
|
|
721
731
|
schedule: str = "pyramid",
|
|
@@ -727,182 +737,439 @@ class BaseNNImputer:
|
|
|
727
737
|
) -> list[int]:
|
|
728
738
|
"""Compute hidden layer sizes given problem scale and a layer count.
|
|
729
739
|
|
|
730
|
-
|
|
740
|
+
Notes:
|
|
741
|
+
- Returns sizes for *hidden layers only* (length = n_hidden).
|
|
742
|
+
- Does NOT include the input layer (n_inputs) or the latent layer (latent_dim).
|
|
743
|
+
- Enforces a latent-aware minimum: one discrete level above latent_dim, where a level is `multiple_of`.
|
|
744
|
+
- Enforces *strictly decreasing* hidden sizes (no repeats). This may require bumping `base` upward.
|
|
731
745
|
|
|
732
746
|
Args:
|
|
733
|
-
n_inputs
|
|
734
|
-
n_outputs
|
|
735
|
-
n_samples
|
|
736
|
-
n_hidden
|
|
737
|
-
|
|
738
|
-
|
|
739
|
-
|
|
740
|
-
|
|
741
|
-
|
|
742
|
-
|
|
743
|
-
|
|
747
|
+
n_inputs: Number of input features (e.g., flattened one-hot: num_features * num_classes).
|
|
748
|
+
n_outputs: Number of output classes (often equals num_classes).
|
|
749
|
+
n_samples: Number of training samples.
|
|
750
|
+
n_hidden: Number of hidden layers (excluding input and latent layers).
|
|
751
|
+
latent_dim: Latent dimensionality (not returned, used only to set a floor).
|
|
752
|
+
alpha: Scaling factor for base layer size.
|
|
753
|
+
schedule: Size schedule ("pyramid" or "linear").
|
|
754
|
+
min_size: Minimum layer size floor before latent-aware adjustment.
|
|
755
|
+
max_size: Maximum layer size cap. If None, a heuristic cap is used.
|
|
756
|
+
multiple_of: Hidden sizes are multiples of this value.
|
|
757
|
+
decay: Pyramid decay factor. If None, computed to land near the target.
|
|
758
|
+
cap_by_inputs: If True, cap layer sizes to n_inputs.
|
|
744
759
|
|
|
745
760
|
Returns:
|
|
746
|
-
list[int]:
|
|
761
|
+
list[int]: Hidden layer sizes (len = n_hidden).
|
|
747
762
|
|
|
748
763
|
Raises:
|
|
749
|
-
ValueError:
|
|
750
|
-
TypeError: If any argument is not of the expected type.
|
|
751
|
-
|
|
752
|
-
Notes:
|
|
753
|
-
- If n_hidden is 0, returns an empty list.
|
|
754
|
-
- The base layer size is computed as ceil(n_samples / (alpha * (n_inputs + n_outputs))).
|
|
755
|
-
- The sizes are adjusted according to the specified schedule and constraints.
|
|
764
|
+
ValueError: On invalid arguments or conflicting constraints.
|
|
756
765
|
"""
|
|
766
|
+
# ----------------------------
|
|
767
|
+
# Basic validation
|
|
768
|
+
# ----------------------------
|
|
757
769
|
if n_hidden < 0:
|
|
758
770
|
msg = f"n_hidden must be >= 0, got {n_hidden}."
|
|
759
771
|
self.logger.error(msg)
|
|
760
772
|
raise ValueError(msg)
|
|
761
773
|
|
|
762
|
-
if
|
|
763
|
-
|
|
774
|
+
if n_hidden == 0:
|
|
775
|
+
return []
|
|
776
|
+
|
|
777
|
+
if n_inputs <= 0:
|
|
778
|
+
msg = f"n_inputs must be > 0, got {n_inputs}."
|
|
764
779
|
self.logger.error(msg)
|
|
765
780
|
raise ValueError(msg)
|
|
766
781
|
|
|
767
|
-
if
|
|
768
|
-
|
|
782
|
+
if n_outputs <= 0:
|
|
783
|
+
msg = f"n_outputs must be > 0, got {n_outputs}."
|
|
784
|
+
self.logger.error(msg)
|
|
785
|
+
raise ValueError(msg)
|
|
769
786
|
|
|
770
|
-
|
|
787
|
+
if n_samples <= 0:
|
|
788
|
+
msg = f"n_samples must be > 0, got {n_samples}."
|
|
789
|
+
self.logger.error(msg)
|
|
790
|
+
raise ValueError(msg)
|
|
771
791
|
|
|
772
|
-
if
|
|
773
|
-
msg = f"
|
|
792
|
+
if latent_dim <= 0:
|
|
793
|
+
msg = f"latent_dim must be > 0, got {latent_dim}."
|
|
774
794
|
self.logger.error(msg)
|
|
775
795
|
raise ValueError(msg)
|
|
776
796
|
|
|
777
|
-
|
|
797
|
+
if multiple_of <= 0:
|
|
798
|
+
msg = f"multiple_of must be > 0, got {multiple_of}."
|
|
799
|
+
self.logger.error(msg)
|
|
800
|
+
raise ValueError(msg)
|
|
801
|
+
|
|
802
|
+
if alpha <= 0:
|
|
803
|
+
msg = f"alpha must be > 0, got {alpha}."
|
|
804
|
+
self.logger.error(msg)
|
|
805
|
+
raise ValueError(msg)
|
|
806
|
+
|
|
807
|
+
schedule = str(schedule).lower().strip()
|
|
808
|
+
if schedule not in {"pyramid", "linear"}:
|
|
809
|
+
msg = f"Invalid schedule '{schedule}'. Must be 'pyramid' or 'linear'."
|
|
810
|
+
self.logger.error(msg)
|
|
811
|
+
raise ValueError(msg)
|
|
812
|
+
|
|
813
|
+
# ----------------------------
|
|
814
|
+
# Latent-aware minimum floor
|
|
815
|
+
# ----------------------------
|
|
816
|
+
# Smallest multiple_of strictly greater than latent_dim
|
|
817
|
+
min_hidden_floor = int(np.ceil((latent_dim + 1) / multiple_of) * multiple_of)
|
|
818
|
+
effective_min = max(int(min_size), min_hidden_floor)
|
|
819
|
+
|
|
820
|
+
if cap_by_inputs and n_inputs < effective_min:
|
|
821
|
+
msg = (
|
|
822
|
+
"Cannot satisfy latent-aware minimum hidden size with cap_by_inputs=True. "
|
|
823
|
+
f"Required hidden size >= {effective_min} (one level above latent_dim={latent_dim}), "
|
|
824
|
+
f"but n_inputs={n_inputs}. Set cap_by_inputs=False or reduce latent_dim/multiple_of."
|
|
825
|
+
)
|
|
826
|
+
self.logger.error(msg)
|
|
827
|
+
raise ValueError(msg)
|
|
828
|
+
|
|
829
|
+
# ----------------------------
|
|
830
|
+
# Infer num_features (if using flattened one-hot: n_inputs = num_features * num_classes)
|
|
831
|
+
# ----------------------------
|
|
832
|
+
if n_inputs % n_outputs == 0:
|
|
833
|
+
num_features = n_inputs // n_outputs
|
|
834
|
+
else:
|
|
835
|
+
num_features = n_inputs
|
|
836
|
+
self.logger.warning(
|
|
837
|
+
"n_inputs is not divisible by n_outputs; falling back to num_features=n_inputs "
|
|
838
|
+
f"(n_inputs={n_inputs}, n_outputs={n_outputs}). If using one-hot flattening, "
|
|
839
|
+
"pass n_outputs=num_classes so num_features can be inferred correctly."
|
|
840
|
+
)
|
|
778
841
|
|
|
842
|
+
# ----------------------------
|
|
843
|
+
# Base size heuristic (feature-matrix aware; avoids collapse for huge n_inputs)
|
|
844
|
+
# ----------------------------
|
|
845
|
+
obs_scale = (float(n_samples) * float(num_features)) / float(
|
|
846
|
+
num_features + n_outputs
|
|
847
|
+
)
|
|
848
|
+
base = int(np.ceil(float(alpha) * np.sqrt(obs_scale)))
|
|
849
|
+
|
|
850
|
+
# ----------------------------
|
|
851
|
+
# Determine max_size
|
|
852
|
+
# ----------------------------
|
|
779
853
|
if max_size is None:
|
|
780
|
-
max_size = max(n_inputs, base)
|
|
854
|
+
max_size = max(int(n_inputs), int(base), int(effective_min))
|
|
855
|
+
|
|
856
|
+
if cap_by_inputs:
|
|
857
|
+
max_size = min(int(max_size), int(n_inputs))
|
|
858
|
+
else:
|
|
859
|
+
max_size = int(max_size)
|
|
860
|
+
|
|
861
|
+
if max_size < effective_min:
|
|
862
|
+
msg = (
|
|
863
|
+
f"max_size ({max_size}) must be >= effective_min ({effective_min}), where effective_min "
|
|
864
|
+
f"is max(min_size={min_size}, one-level-above latent_dim={latent_dim})."
|
|
865
|
+
)
|
|
866
|
+
self.logger.error(msg)
|
|
867
|
+
raise ValueError(msg)
|
|
868
|
+
|
|
869
|
+
# Round base up to a multiple and clip to bounds
|
|
870
|
+
base = int(np.clip(base, effective_min, max_size))
|
|
871
|
+
base = int(np.ceil(base / multiple_of) * multiple_of)
|
|
872
|
+
base = int(np.clip(base, effective_min, max_size))
|
|
873
|
+
|
|
874
|
+
# ----------------------------
|
|
875
|
+
# Enforce "no repeats" feasibility in discrete levels
|
|
876
|
+
# Need n_hidden distinct multiples between base and effective_min:
|
|
877
|
+
# base >= effective_min + (n_hidden - 1) * multiple_of
|
|
878
|
+
# ----------------------------
|
|
879
|
+
required_min_base = effective_min + (n_hidden - 1) * multiple_of
|
|
880
|
+
|
|
881
|
+
if required_min_base > max_size:
|
|
882
|
+
msg = (
|
|
883
|
+
"Cannot build strictly-decreasing (no-repeat) hidden sizes under current constraints. "
|
|
884
|
+
f"Need base >= {required_min_base} to fit n_hidden={n_hidden} distinct layers "
|
|
885
|
+
f"with multiple_of={multiple_of} down to effective_min={effective_min}, "
|
|
886
|
+
f"but max_size={max_size}. Reduce n_hidden, reduce multiple_of, lower latent_dim/min_size, "
|
|
887
|
+
"or increase max_size / set cap_by_inputs=False."
|
|
888
|
+
)
|
|
889
|
+
self.logger.error(msg)
|
|
890
|
+
raise ValueError(msg)
|
|
781
891
|
|
|
782
|
-
|
|
892
|
+
if base < required_min_base:
|
|
893
|
+
# Bump base upward so a strict staircase is possible
|
|
894
|
+
base = required_min_base
|
|
895
|
+
base = int(np.ceil(base / multiple_of) * multiple_of)
|
|
896
|
+
base = int(np.clip(base, effective_min, max_size))
|
|
783
897
|
|
|
784
|
-
|
|
785
|
-
|
|
898
|
+
# Work in "levels" of multiple_of for guaranteed uniqueness
|
|
899
|
+
start_level = base // multiple_of
|
|
900
|
+
end_level = effective_min // multiple_of
|
|
901
|
+
|
|
902
|
+
# Sanity: distinct levels available
|
|
903
|
+
if (start_level - end_level) < (n_hidden - 1):
|
|
904
|
+
# This should not happen due to required_min_base logic, but keep a hard guard.
|
|
905
|
+
msg = (
|
|
906
|
+
"Internal constraint failure: insufficient discrete levels to enforce no repeats. "
|
|
907
|
+
f"start_level={start_level}, end_level={end_level}, n_hidden={n_hidden}."
|
|
908
|
+
)
|
|
909
|
+
self.logger.error(msg)
|
|
910
|
+
raise ValueError(msg)
|
|
911
|
+
|
|
912
|
+
# ----------------------------
|
|
913
|
+
# Build schedule in level space (integers), then convert to sizes
|
|
914
|
+
# ----------------------------
|
|
915
|
+
if n_hidden == 1:
|
|
916
|
+
levels = np.array([start_level], dtype=int)
|
|
786
917
|
|
|
787
918
|
elif schedule == "linear":
|
|
788
|
-
|
|
789
|
-
|
|
790
|
-
|
|
791
|
-
if n_hidden == 1
|
|
792
|
-
else np.linspace(base, target, num=n_hidden, dtype=float)
|
|
919
|
+
# Linear interpolation in level space, then strictify
|
|
920
|
+
levels = np.round(np.linspace(start_level, end_level, num=n_hidden)).astype(
|
|
921
|
+
int
|
|
793
922
|
)
|
|
794
923
|
|
|
924
|
+
# Enforce bounds then strict decrease
|
|
925
|
+
levels = np.clip(levels, end_level, start_level)
|
|
926
|
+
|
|
927
|
+
for i in range(1, n_hidden):
|
|
928
|
+
if levels[i] >= levels[i - 1]:
|
|
929
|
+
levels[i] = levels[i - 1] - 1
|
|
930
|
+
|
|
931
|
+
if levels[-1] < end_level:
|
|
932
|
+
msg = (
|
|
933
|
+
"Failed to enforce strictly-decreasing linear schedule without violating the floor. "
|
|
934
|
+
f"(levels[-1]={levels[-1]} < end_level={end_level}). "
|
|
935
|
+
"Reduce n_hidden or multiple_of, or increase max_size."
|
|
936
|
+
)
|
|
937
|
+
self.logger.error(msg)
|
|
938
|
+
raise ValueError(msg)
|
|
939
|
+
|
|
940
|
+
# Force exact floor at the end (still strict because we have enough room by construction)
|
|
941
|
+
levels[-1] = end_level
|
|
942
|
+
for i in range(n_hidden - 2, -1, -1):
|
|
943
|
+
if levels[i] <= levels[i + 1]:
|
|
944
|
+
levels[i] = levels[i + 1] + 1
|
|
945
|
+
|
|
946
|
+
if levels[0] > start_level:
|
|
947
|
+
# If this happens, we would need an even larger base; handle by raising base once.
|
|
948
|
+
needed_base = int(levels[0] * multiple_of)
|
|
949
|
+
if needed_base > max_size:
|
|
950
|
+
msg = (
|
|
951
|
+
"Cannot enforce strictly-decreasing linear schedule after floor anchoring; "
|
|
952
|
+
f"would require base={needed_base} > max_size={max_size}."
|
|
953
|
+
)
|
|
954
|
+
self.logger.error(msg)
|
|
955
|
+
raise ValueError(msg)
|
|
956
|
+
# Rebuild with bumped base
|
|
957
|
+
start_level = needed_base // multiple_of
|
|
958
|
+
levels = np.arange(start_level, start_level - n_hidden, -1, dtype=int)
|
|
959
|
+
levels[-1] = end_level # keep floor
|
|
960
|
+
# Ensure strict with backward adjust
|
|
961
|
+
for i in range(n_hidden - 2, -1, -1):
|
|
962
|
+
if levels[i] <= levels[i + 1]:
|
|
963
|
+
levels[i] = levels[i + 1] + 1
|
|
964
|
+
|
|
795
965
|
elif schedule == "pyramid":
|
|
796
|
-
|
|
797
|
-
|
|
966
|
+
# Geometric decay in level space (more aggressive early taper than linear)
|
|
967
|
+
if decay is not None:
|
|
968
|
+
dcy = float(decay)
|
|
798
969
|
else:
|
|
970
|
+
# Choose decay to land exactly at end_level (in float space)
|
|
971
|
+
dcy = (float(end_level) / float(start_level)) ** (
|
|
972
|
+
1.0 / float(n_hidden - 1)
|
|
973
|
+
)
|
|
974
|
+
|
|
975
|
+
# Keep it in a sensible range
|
|
976
|
+
dcy = float(np.clip(dcy, 0.05, 0.99))
|
|
977
|
+
|
|
978
|
+
exponents = np.arange(n_hidden, dtype=float)
|
|
979
|
+
levels_float = float(start_level) * (dcy**exponents)
|
|
980
|
+
|
|
981
|
+
levels = np.round(levels_float).astype(int)
|
|
982
|
+
levels = np.clip(levels, end_level, start_level)
|
|
983
|
+
|
|
984
|
+
# Anchor the last layer at the floor, then strictify backward
|
|
985
|
+
levels[-1] = end_level
|
|
986
|
+
for i in range(n_hidden - 2, -1, -1):
|
|
987
|
+
if levels[i] <= levels[i + 1]:
|
|
988
|
+
levels[i] = levels[i + 1] + 1
|
|
989
|
+
|
|
990
|
+
# If we overshot the start, bump base (once) if possible, then rebuild
|
|
991
|
+
if levels[0] > start_level:
|
|
992
|
+
needed_base = int(levels[0] * multiple_of)
|
|
993
|
+
if needed_base > max_size:
|
|
994
|
+
msg = (
|
|
995
|
+
"Cannot enforce strictly-decreasing pyramid schedule; "
|
|
996
|
+
f"would require base={needed_base} > max_size={max_size}."
|
|
997
|
+
)
|
|
998
|
+
self.logger.error(msg)
|
|
999
|
+
raise ValueError(msg)
|
|
1000
|
+
|
|
1001
|
+
start_level = needed_base // multiple_of
|
|
1002
|
+
# Recompute with new start_level and same decay (or recompute decay if decay is None)
|
|
799
1003
|
if decay is None:
|
|
800
|
-
|
|
801
|
-
|
|
802
|
-
|
|
803
|
-
|
|
804
|
-
|
|
805
|
-
|
|
806
|
-
|
|
807
|
-
|
|
1004
|
+
dcy = (float(end_level) / float(start_level)) ** (
|
|
1005
|
+
1.0 / float(n_hidden - 1)
|
|
1006
|
+
)
|
|
1007
|
+
dcy = float(np.clip(dcy, 0.05, 0.99))
|
|
1008
|
+
|
|
1009
|
+
levels_float = float(start_level) * (dcy**exponents)
|
|
1010
|
+
levels = np.round(levels_float).astype(int)
|
|
1011
|
+
levels = np.clip(levels, end_level, start_level)
|
|
1012
|
+
levels[-1] = end_level
|
|
1013
|
+
for i in range(n_hidden - 2, -1, -1):
|
|
1014
|
+
if levels[i] <= levels[i + 1]:
|
|
1015
|
+
levels[i] = levels[i + 1] + 1
|
|
808
1016
|
|
|
809
1017
|
else:
|
|
810
|
-
msg = f"Unknown schedule '{schedule}'. Use 'pyramid'
|
|
1018
|
+
msg = f"Unknown schedule '{schedule}'. Use 'pyramid' or 'linear' (constant disallowed with no repeats)."
|
|
811
1019
|
self.logger.error(msg)
|
|
812
1020
|
raise ValueError(msg)
|
|
813
1021
|
|
|
814
|
-
|
|
1022
|
+
# Convert levels -> sizes
|
|
1023
|
+
sizes = (levels * multiple_of).astype(int)
|
|
815
1024
|
|
|
816
|
-
|
|
817
|
-
|
|
1025
|
+
# Final clip (should be redundant, but safe)
|
|
1026
|
+
sizes = np.clip(sizes, effective_min, max_size).astype(int)
|
|
818
1027
|
|
|
819
|
-
|
|
820
|
-
|
|
821
|
-
|
|
1028
|
+
# Final strict no-repeat assertion
|
|
1029
|
+
if np.any(np.diff(sizes) >= 0):
|
|
1030
|
+
msg = (
|
|
1031
|
+
"Internal error: produced non-decreasing or repeated hidden sizes after strict enforcement. "
|
|
1032
|
+
f"sizes={sizes.tolist()}"
|
|
1033
|
+
)
|
|
1034
|
+
self.logger.error(msg)
|
|
1035
|
+
raise ValueError(msg)
|
|
822
1036
|
|
|
823
|
-
|
|
824
|
-
"""Class-balanced weights for 0/1/2 (handles haploid collapse if needed).
|
|
1037
|
+
return sizes.tolist()
|
|
825
1038
|
|
|
826
|
-
|
|
1039
|
+
def _class_weights_from_zygosity(
|
|
1040
|
+
self,
|
|
1041
|
+
X: np.ndarray,
|
|
1042
|
+
train_mask: Optional[np.ndarray] = None,
|
|
1043
|
+
*,
|
|
1044
|
+
inverse: bool = False,
|
|
1045
|
+
normalize: bool = False,
|
|
1046
|
+
power: float = 1.0,
|
|
1047
|
+
max_ratio: float | None = None,
|
|
1048
|
+
eps: float = 1e-12,
|
|
1049
|
+
) -> torch.Tensor:
|
|
1050
|
+
"""Compute class weights for zygosity labels.
|
|
827
1051
|
|
|
828
|
-
|
|
829
|
-
|
|
1052
|
+
If inverse=False (default):
|
|
1053
|
+
w_c = N / (K * n_c) ("balanced")
|
|
1054
|
+
|
|
1055
|
+
If inverse=True:
|
|
1056
|
+
w_c = N / n_c (same ratios, scaled by K)
|
|
1057
|
+
|
|
1058
|
+
If power != 1.0:
|
|
1059
|
+
w_c <- w_c ** power (amplifies or softens imbalance handling)
|
|
1060
|
+
|
|
1061
|
+
If normalize=True:
|
|
1062
|
+
rescales nonzero weights so mean(nonzero_weights) == 1.
|
|
830
1063
|
|
|
831
1064
|
Returns:
|
|
832
|
-
torch.Tensor:
|
|
1065
|
+
torch.Tensor: Class weights of shape (num_classes,) on self.device.
|
|
833
1066
|
"""
|
|
834
|
-
y = X
|
|
835
|
-
if y.size == 0:
|
|
836
|
-
return torch.ones(
|
|
837
|
-
self.num_classes_, dtype=torch.float32, device=self.device
|
|
838
|
-
)
|
|
1067
|
+
y = np.asarray(X).ravel().astype(np.int8)
|
|
839
1068
|
|
|
840
|
-
|
|
841
|
-
|
|
842
|
-
|
|
843
|
-
|
|
844
|
-
|
|
845
|
-
|
|
846
|
-
|
|
847
|
-
|
|
1069
|
+
m = y >= 0
|
|
1070
|
+
if train_mask is not None:
|
|
1071
|
+
tm = np.asarray(train_mask, dtype=bool).ravel()
|
|
1072
|
+
if tm.shape != y.shape:
|
|
1073
|
+
msg = "train_mask must have the same shape as X."
|
|
1074
|
+
self.logger.error(msg)
|
|
1075
|
+
raise ValueError(msg)
|
|
1076
|
+
m &= tm
|
|
1077
|
+
|
|
1078
|
+
is_hap = bool(getattr(self, "is_haploid_", False))
|
|
1079
|
+
num_classes = 2 if is_hap else int(self.num_classes_)
|
|
1080
|
+
|
|
1081
|
+
if not np.any(m):
|
|
1082
|
+
return torch.ones(num_classes, dtype=torch.long, device=self.device)
|
|
1083
|
+
|
|
1084
|
+
if is_hap:
|
|
1085
|
+
y = y.copy()
|
|
1086
|
+
y[(y == 2) & m] = 1
|
|
1087
|
+
|
|
1088
|
+
y_m = y[m]
|
|
1089
|
+
if y_m.size:
|
|
1090
|
+
ymin = int(y_m.min())
|
|
1091
|
+
ymax = int(y_m.max())
|
|
1092
|
+
if ymin < 0 or ymax >= num_classes:
|
|
1093
|
+
msg = (
|
|
1094
|
+
f"Found out-of-range labels under mask: min={ymin}, max={ymax}, "
|
|
1095
|
+
f"expected in [0, {num_classes - 1}]."
|
|
1096
|
+
)
|
|
1097
|
+
self.logger.error(msg)
|
|
1098
|
+
raise ValueError(msg)
|
|
848
1099
|
|
|
849
|
-
|
|
850
|
-
|
|
851
|
-
|
|
852
|
-
) -> torch.Tensor | None:
|
|
853
|
-
"""Normalize class weights once to keep loss scale stable.
|
|
1100
|
+
counts = np.bincount(y_m, minlength=num_classes).astype(np.float32)
|
|
1101
|
+
N = float(counts.sum())
|
|
1102
|
+
K = float(num_classes)
|
|
854
1103
|
|
|
855
|
-
|
|
856
|
-
|
|
1104
|
+
w = np.zeros(num_classes, dtype=np.float32)
|
|
1105
|
+
nz = counts > 0
|
|
857
1106
|
|
|
858
|
-
|
|
859
|
-
|
|
860
|
-
|
|
861
|
-
|
|
862
|
-
|
|
863
|
-
return weights / weights.mean().clamp_min(1e-8)
|
|
1107
|
+
if np.any(nz):
|
|
1108
|
+
if inverse:
|
|
1109
|
+
w[nz] = N / (counts[nz] + eps)
|
|
1110
|
+
else:
|
|
1111
|
+
w[nz] = N / (K * (counts[nz] + eps))
|
|
864
1112
|
|
|
865
|
-
|
|
866
|
-
|
|
1113
|
+
# Amplify / soften class contrast
|
|
1114
|
+
if power <= 0.0:
|
|
1115
|
+
msg = "power must be > 0."
|
|
1116
|
+
self.logger.error(msg)
|
|
1117
|
+
raise ValueError(msg)
|
|
1118
|
+
if power != 1.0:
|
|
1119
|
+
w[nz] = np.power(w[nz], power)
|
|
867
1120
|
|
|
868
|
-
|
|
869
|
-
|
|
1121
|
+
if np.any(~nz):
|
|
1122
|
+
self.logger.warning(
|
|
1123
|
+
"Some classes have zero count under the provided mask: "
|
|
1124
|
+
f"{np.where(~nz)[0].tolist()}. Setting their weights to 0."
|
|
1125
|
+
)
|
|
870
1126
|
|
|
871
|
-
|
|
872
|
-
|
|
873
|
-
|
|
874
|
-
|
|
875
|
-
|
|
876
|
-
|
|
877
|
-
|
|
878
|
-
|
|
879
|
-
|
|
880
|
-
|
|
881
|
-
|
|
882
|
-
|
|
883
|
-
|
|
884
|
-
|
|
885
|
-
|
|
886
|
-
|
|
887
|
-
|
|
888
|
-
|
|
889
|
-
|
|
890
|
-
|
|
891
|
-
|
|
892
|
-
|
|
893
|
-
|
|
1127
|
+
# Cap ratio among observed classes
|
|
1128
|
+
if max_ratio is not None and np.any(nz):
|
|
1129
|
+
cap = float(max_ratio)
|
|
1130
|
+
if cap <= 1.0:
|
|
1131
|
+
msg = "max_ratio must be > 1.0 or None."
|
|
1132
|
+
self.logger.error(msg)
|
|
1133
|
+
raise ValueError(msg)
|
|
1134
|
+
|
|
1135
|
+
wmin = max(float(w[nz].min()), eps)
|
|
1136
|
+
wmax = wmin * cap
|
|
1137
|
+
w[nz] = np.clip(w[nz], wmin, wmax)
|
|
1138
|
+
|
|
1139
|
+
# Optional normalization: mean(nonzero) -> 1.0
|
|
1140
|
+
if normalize and np.any(nz):
|
|
1141
|
+
mean_nz = float(w[nz].mean())
|
|
1142
|
+
if mean_nz > 0.0:
|
|
1143
|
+
w[nz] /= mean_nz
|
|
1144
|
+
else:
|
|
1145
|
+
self.logger.warning(
|
|
1146
|
+
"normalize=True requested, but mean of nonzero weights is not positive; skipping normalization."
|
|
1147
|
+
)
|
|
1148
|
+
|
|
1149
|
+
self.logger.debug(f"Class counts: {counts.astype(np.int8)}")
|
|
1150
|
+
self.logger.debug(
|
|
1151
|
+
f"Class weights (inverse={inverse}, power={power}, normalize={normalize}): {w}"
|
|
894
1152
|
)
|
|
895
1153
|
|
|
896
|
-
|
|
897
|
-
"""One-hot 0/1/2; -1 rows are all-zeros (B, L, K).
|
|
1154
|
+
return torch.as_tensor(w, dtype=torch.long, device=self.device)
|
|
898
1155
|
|
|
899
|
-
|
|
1156
|
+
def _one_hot_encode_012(
|
|
1157
|
+
self, X: np.ndarray | torch.Tensor, num_classes: int | None
|
|
1158
|
+
) -> torch.Tensor:
|
|
1159
|
+
"""One-hot encode genotype calls. Missing inputs (<0) result in a vector of -1s.
|
|
900
1160
|
|
|
901
1161
|
Args:
|
|
902
|
-
X (np.ndarray | torch.Tensor):
|
|
1162
|
+
X (np.ndarray | torch.Tensor): Input genotype calls as integers (0,1, 2, etc.).
|
|
1163
|
+
num_classes (int | None): Number of classes (K). If None, uses self.num_classes_.
|
|
903
1164
|
|
|
904
1165
|
Returns:
|
|
905
|
-
torch.Tensor:
|
|
1166
|
+
torch.Tensor: One-hot encoded tensor of shape (B, L, K) with dtype
|
|
1167
|
+
``torch.long``. Valid calls are 0/1, missing calls are all -1.
|
|
1168
|
+
|
|
1169
|
+
Notes:
|
|
1170
|
+
- Valid classes must be integers in [0, K-1]
|
|
1171
|
+
- Missing is any value < 0; these positions become [-1, -1, ..., -1]
|
|
1172
|
+
- If K==2 and values are in {0,2} (no 1s), map 2->1.
|
|
906
1173
|
"""
|
|
907
1174
|
Xt = (
|
|
908
1175
|
torch.from_numpy(X).to(self.device)
|
|
@@ -910,212 +1177,680 @@ class BaseNNImputer:
|
|
|
910
1177
|
else X.to(self.device)
|
|
911
1178
|
)
|
|
912
1179
|
|
|
913
|
-
#
|
|
1180
|
+
# Make sure we have integer class labels
|
|
1181
|
+
if Xt.dtype.is_floating_point:
|
|
1182
|
+
# Convert NaN -> -1 and cast to long
|
|
1183
|
+
Xt = torch.nan_to_num(Xt, nan=-1.0).long()
|
|
1184
|
+
else:
|
|
1185
|
+
Xt = Xt.long()
|
|
1186
|
+
|
|
914
1187
|
B, L = Xt.shape
|
|
915
|
-
K = self.num_classes_
|
|
916
|
-
|
|
917
|
-
|
|
918
|
-
|
|
1188
|
+
K = int(num_classes) if num_classes is not None else int(self.num_classes_)
|
|
1189
|
+
|
|
1190
|
+
# Missing is anything < 0 (covers -1, -9, etc.)
|
|
1191
|
+
valid = Xt >= 0
|
|
1192
|
+
|
|
1193
|
+
# If binary mode but data is {0,2}
|
|
1194
|
+
# (haploid-like or "ref vs non-ref"), map 2->1
|
|
1195
|
+
if K == 2:
|
|
1196
|
+
has_het = torch.any(valid & (Xt == 1))
|
|
1197
|
+
has_alt2 = torch.any(valid & (Xt == 2))
|
|
1198
|
+
if has_alt2 and not has_het:
|
|
1199
|
+
Xt = Xt.clone()
|
|
1200
|
+
Xt[valid & (Xt == 2)] = 1
|
|
1201
|
+
|
|
1202
|
+
# Now enforce the one-hot precondition
|
|
1203
|
+
if torch.any(valid & (Xt >= K)):
|
|
1204
|
+
bad_vals = torch.unique(Xt[valid & (Xt >= K)]).detach().cpu().tolist()
|
|
1205
|
+
all_vals = torch.unique(Xt[valid]).detach().cpu().tolist()
|
|
1206
|
+
msg = f"_one_hot_encode_012 received class values outside [0, {K-1}]. num_classes={K}, offending_values={bad_vals}, observed_values={all_vals}. Upstream encoding mismatch (e.g., passing 0/1/2 with num_classes=2)."
|
|
1207
|
+
self.logger.error(msg)
|
|
1208
|
+
raise ValueError(msg)
|
|
1209
|
+
|
|
1210
|
+
# CHANGE: Initialize with -1.0 to ensure missing values are represented as [-1, -1, ... -1]
|
|
1211
|
+
X_ohe = torch.full((B, L, K), -1.0, dtype=torch.long, device=self.device)
|
|
1212
|
+
|
|
1213
|
+
idx = Xt[valid]
|
|
919
1214
|
|
|
920
1215
|
if idx.numel() > 0:
|
|
921
|
-
|
|
1216
|
+
# Overwrite valid positions (which were -1) with the correct one-hot vectors
|
|
1217
|
+
X_ohe[valid] = F.one_hot(idx, num_classes=K).long()
|
|
922
1218
|
|
|
923
1219
|
return X_ohe
|
|
924
1220
|
|
|
925
|
-
def
|
|
926
|
-
self,
|
|
927
|
-
|
|
928
|
-
|
|
929
|
-
|
|
930
|
-
|
|
931
|
-
|
|
932
|
-
|
|
933
|
-
do_latent_infer: bool = False,
|
|
934
|
-
latent_steps: int = 50,
|
|
935
|
-
latent_lr: float = 1e-2,
|
|
936
|
-
latent_weight_decay: float = 0.0,
|
|
937
|
-
latent_seed: int = 123,
|
|
938
|
-
_latent_cache: dict | None = None,
|
|
939
|
-
_latent_cache_key: str | None = None,
|
|
940
|
-
eval_mask_override: np.ndarray | None = None,
|
|
941
|
-
) -> float:
|
|
942
|
-
"""Compute a scalar metric (to MAXIMIZE) on a fixed validation set.
|
|
943
|
-
|
|
944
|
-
This method evaluates the model on a validation dataset and computes a specified metric, which is used for pruning decisions during hyperparameter tuning. It supports optional latent inference to optimize latent representations before evaluation. The method handles potential issues with non-finite metric values by returning negative infinity, making it easier to prune poorly performing trials.
|
|
1221
|
+
def decode_012(
    self, X: np.ndarray | pd.DataFrame | list[list[int]], is_nuc: bool = False
) -> np.ndarray:
    """Decode integer genotype codes to single-character IUPAC codes.

    Two decoding modes are supported:

    1. Nucleotide mode (``is_nuc=True``): codes 0-9 index directly into
       ``["A", "C", "G", "T", "W", "R", "M", "K", "Y", "S"]``.
    2. Metadata mode (``is_nuc=False``): per locus, code 0 decodes to REF,
       code 2 to ALT and code 1 to the IUPAC ambiguity code covering
       REF and ALT. REF/ALT come from ``self.genotype_data.ref/alt`` (falling
       back to ``self._ref/_alt``). When either allele is missing, it is
       inferred from the bases observed in the matching column of
       ``self.genotype_data.snp_data``.

    Args:
        X (np.ndarray | pd.DataFrame | list[list[int]]): Integer genotype calls.
        is_nuc (bool): If True, decode 0-9 nucleotide codes; otherwise use
            REF/ALT metadata. Defaults to False.

    Returns:
        np.ndarray: ``<U1`` array of shape (n_samples, n_snps). Entries that
        cannot be decoded are ``"N"``.

    Raises:
        ValueError: If ``validate_input_type`` does not return a DataFrame.

    Notes:
        - Non-numeric cells are coerced to missing (-1).
        - Codes outside the int8 range are clipped before the cast so they
          decode to ``"N"`` instead of wrapping around to a valid code.
        - At a monomorphic locus (ALT unavailable), codes 1 and 2 fall back
          to REF rather than ``"N"``.
    """
    df = validate_input_type(X, return_type="df")

    if not isinstance(df, pd.DataFrame):
        msg = f"Expected a pandas.DataFrame in 'decode_012', but got: {type(df)}."
        self.logger.error(msg)
        raise ValueError(msg)

    # IUPAC Definitions
    iupac_to_bases: dict[str, set[str]] = {
        "A": {"A"},
        "C": {"C"},
        "G": {"G"},
        "T": {"T"},
        "R": {"A", "G"},
        "Y": {"C", "T"},
        "S": {"G", "C"},
        "W": {"A", "T"},
        "K": {"G", "T"},
        "M": {"A", "C"},
        "B": {"C", "G", "T"},
        "D": {"A", "G", "T"},
        "H": {"A", "C", "T"},
        "V": {"A", "C", "G"},
        "N": set(),
    }
    bases_to_iupac = {
        frozenset(v): k for k, v in iupac_to_bases.items() if k != "N"
    }
    missing_codes = {"", ".", "N", "NONE", "-", "?", "./.", ".|.", "NAN", "nan"}

    def _normalize_iupac(value: object) -> str | None:
        """Normalize an input into a single IUPAC code token or None."""
        if value is None:
            return None

        # Bytes -> str (make type narrowing explicit)
        if isinstance(value, (bytes, np.bytes_)):
            value = bytes(value).decode("utf-8", errors="ignore")

        # Containers: the first element that normalizes to a code wins.
        if isinstance(value, (list, tuple, pd.Series, np.ndarray)):
            arr = value.to_numpy() if isinstance(value, pd.Series) else value

            # 0-d arrays are not iterable; unwrap them to a scalar.
            if isinstance(arr, np.ndarray) and arr.ndim == 0:
                return _normalize_iupac(arr.item())

            for item in arr:
                code = _normalize_iupac(item)
                if code is not None:
                    return code
            return None

        s = str(value).upper().strip()
        if not s or s in missing_codes:
            return None

        # Comma-separated alleles (e.g. multi-allelic ALT): first valid token.
        if "," in s:
            for tok in (t.strip() for t in s.split(",")):
                if tok and tok not in missing_codes and tok in iupac_to_bases:
                    return tok
            return None

        return s if s in iupac_to_bases else None

    def _metadata_alleles(primary: str, fallback: str) -> object:
        """Return the first allele sequence whose length matches the columns."""
        for owner, attr in ((self.genotype_data, primary), (self, fallback)):
            values = getattr(owner, attr, None)
            if (
                values is not None
                and hasattr(values, "__len__")
                and len(values) == n_cols
            ):
                return values
        return [None] * n_cols

    def _load_source_matrix() -> np.ndarray | None:
        """Materialize ``genotype_data.snp_data`` once; None if unusable."""
        try:
            raw = self.genotype_data.snp_data
            if raw is None:
                return None
            matrix = np.asarray(raw)
        except Exception as exc:  # lazy loading is best-effort only
            self.logger.debug(
                f"decode_012: could not load snp_data for REF/ALT repair: {exc}"
            )
            return None
        return matrix if matrix.ndim == 2 else None

    def _observed_bases(column: np.ndarray) -> list[str]:
        """Collect the distinct bases seen in a source column, sorted.

        Ambiguity codes are expanded to their bases so that an inferred
        allele is always a single base. The scan stops once two bases have
        been seen or more than 200 valid calls were inspected.
        """
        bases: set[str] = set()
        n_calls = 0
        for val in column:
            code = _normalize_iupac(val)
            if code is None:
                continue
            bases |= iupac_to_bases[code]
            n_calls += 1
            if len(bases) >= 2 or n_calls > 200:
                break
        return sorted(bases)

    codes_df = df.apply(pd.to_numeric, errors="coerce")
    # Clip before the int8 cast: out-of-range values would otherwise wrap
    # (e.g. 256 -> 0, i.e. REF) and +/-inf would make the cast fail.
    codes = (
        codes_df.fillna(-1).clip(lower=-1, upper=127).astype(np.int8).to_numpy()
    )
    n_rows, n_cols = codes.shape

    if is_nuc:
        iupac_list = np.array(
            ["A", "C", "G", "T", "W", "R", "M", "K", "Y", "S"], dtype="<U1"
        )
        out = np.full((n_rows, n_cols), "N", dtype="<U1")
        mask = (codes >= 0) & (codes <= 9)
        out[mask] = iupac_list[codes[mask]]
        return out

    # Metadata fetch (falls back to all-None when lengths do not match).
    ref_alleles = _metadata_alleles("ref", "_ref")
    alt_alleles = _metadata_alleles("alt", "_alt")

    out = np.full((n_rows, n_cols), "N", dtype="<U1")
    source_snp_data: np.ndarray | None = None
    source_loaded = False  # load at most once, even if loading fails

    for j in range(n_cols):
        ref = _normalize_iupac(ref_alleles[j])
        alt = _normalize_iupac(alt_alleles[j])

        # --- REPAIR LOGIC ---
        # If metadata is missing, infer alleles from the source column.
        if ref is None or alt is None:
            if not source_loaded:
                source_snp_data = _load_source_matrix()
                source_loaded = True

            if source_snp_data is not None and j < source_snp_data.shape[1]:
                try:
                    observed = _observed_bases(source_snp_data[:, j])
                except (TypeError, ValueError) as exc:
                    self.logger.debug(
                        f"decode_012: REF/ALT repair failed at locus {j}: {exc}"
                    )
                    observed = []

                # Never infer an allele equal to the one already known;
                # that would silently turn the locus monomorphic.
                if ref is None:
                    ref = next((b for b in observed if b != alt), None)
                if alt is None:
                    alt = next((b for b in observed if b != ref), None)

        # --- DEFAULTS FOR MISSING ---
        if ref is None and alt is None:
            ref = "N"
            alt = "N"
        elif ref is None:
            ref = alt
        elif alt is None:
            alt = ref  # Monomorphic site: ALT becomes REF

        # --- COMPUTE HET CODE ---
        if ref == alt:
            het_code = ref
        else:
            union_set = frozenset(
                iupac_to_bases.get(ref, set()) | iupac_to_bases.get(alt, set())
            )
            het_code = bases_to_iupac.get(union_set, "N")

        # --- ASSIGNMENT WITH SAFETY FALLBACKS ---
        col_codes = codes[:, j]

        # Case 0: REF
        if ref != "N":
            out[col_codes == 0, j] = ref

        # Case 1: HET (fall back to REF when no valid het code exists)
        if het_code != "N":
            out[col_codes == 1, j] = het_code
        elif ref != "N":
            out[col_codes == 1, j] = ref

        # Case 2: ALT (fall back to REF when ALT is unavailable)
        if alt != "N":
            out[col_codes == 2, j] = alt
        elif ref != "N":
            out[col_codes == 2, j] = ref

    return out
|
|
1434
|
+
|
|
1435
|
+
def _save_best_params(
    self, best_params: Dict[str, Any], objective_mode: bool = False
) -> None:
    """Save the best hyperparameters to a JSON file.

    Args:
        best_params (Dict[str, Any]): Best hyperparameters to save. NumPy
            scalars/arrays and path-like values are converted to plain
            JSON types.
        objective_mode (bool): If True, write
            ``<optimize_dir>/parameters/best_tuned_parameters.json``;
            otherwise write ``<parameters_dir>/best_parameters.json``.
            Defaults to False.

    Raises:
        AttributeError: If ``parameters_dir`` is missing, or if
            ``optimize_dir`` is missing while ``objective_mode`` is True.
        TypeError: If a value cannot be serialized to JSON.
    """
    if not hasattr(self, "parameters_dir"):
        msg = "Attribute 'parameters_dir' not found. Ensure _create_model_directories() has been called."
        self.logger.error(msg)
        raise AttributeError(msg)

    if objective_mode and not hasattr(self, "optimize_dir"):
        msg = "Attribute 'optimize_dir' not found; it is required when objective_mode=True."
        self.logger.error(msg)
        raise AttributeError(msg)

    def _to_jsonable(obj: object) -> object:
        """Convert values the stock JSON encoder rejects."""
        if isinstance(obj, np.generic):
            return obj.item()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if hasattr(obj, "__fspath__"):
            return str(obj)
        raise TypeError(
            f"Object of type {type(obj).__name__} is not JSON serializable"
        )

    if objective_mode:
        fout = self.optimize_dir / "parameters" / "best_tuned_parameters.json"
    else:
        fout = self.parameters_dir / "best_parameters.json"

    fout.parent.mkdir(parents=True, exist_ok=True)

    with open(fout, "w", encoding="utf-8") as f:
        json.dump(best_params, f, indent=4, default=_to_jsonable)
|
|
1004
1459
|
|
|
1005
|
-
def
|
|
1006
|
-
"""
|
|
1460
|
+
def _set_best_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
    """Apply tuned hyperparameters; abstract hook.

    This base implementation only raises.

    Args:
        params (Dict[str, Any]): Hyperparameters to apply.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError()
|
|
1007
1463
|
|
|
1008
|
-
|
|
1464
|
+
def sim_missing_transform(
    self, X: np.ndarray
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Simulate missing data according to the configured strategy.

    Args:
        X (np.ndarray): Genotype matrix to simulate missing data on. It is
            copied (as float32) and never modified in place.

    Returns:
        X_for_model (np.ndarray): Output of ``SimMissingTransformer.transform``.
        sim_mask (np.ndarray): Boolean mask of simulated missing entries.
        orig_mask (np.ndarray): Boolean mask of original missing entries.

    Raises:
        AttributeError: If ``sim_prop`` is unset or not strictly between 0
            and 1, or if a ``nonrandom`` strategy is requested without a
            ``tree_parser``.
    """
    sim_prop = getattr(self, "sim_prop", None)
    # The chained comparison also rejects NaN.
    if sim_prop is None or not (0.0 < sim_prop < 1.0):
        msg = "sim_prop must be set and between 0.0 and 1.0."
        self.logger.error(msg)
        raise AttributeError(msg)

    # Read once with a default so random strategies work without the
    # attribute; a tree_parser of None counts as missing.
    tree_parser = getattr(self, "tree_parser", None)
    if tree_parser is None and "nonrandom" in self.sim_strategy:
        msg = "tree_parser must be set for 'nonrandom' or 'nonrandom_weighted' sim_strategy."
        self.logger.error(msg)
        raise AttributeError(msg)

    # --- Simulate missing data ---
    X_for_sim = X.astype(np.float32, copy=True)
    tr = SimMissingTransformer(
        genotype_data=self.genotype_data,
        tree_parser=tree_parser,
        prop_missing=sim_prop,
        strategy=self.sim_strategy,
        missing_val=-1,
        mask_missing=True,
        verbose=self.verbose,
        seed=self.seed,
        **(getattr(self, "sim_kwargs", None) or {}),
    )
    # fit and transform each get their own copy of the input.
    tr.fit(X_for_sim.copy())
    X_for_model = tr.transform(X_for_sim.copy())
    sim_mask = tr.sim_missing_mask_.astype(bool)
    orig_mask = tr.original_missing_mask_.astype(bool)

    return X_for_model, sim_mask, orig_mask
|
|
1510
|
+
|
|
1511
|
+
def _train_val_test_split(
    self, X: np.ndarray
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Split sample indices into train, validation, and test sets.

    ``self.validation_split`` is the fraction of samples held out; that
    held-out portion is then split 50/50 into validation and test indices.
    Both splits use ``self.seed``.

    Args:
        X (np.ndarray): Genotype matrix; only ``X.shape[0]`` is used.

    Returns:
        tuple[np.ndarray, np.ndarray, np.ndarray]: Indices for train, validation, and test sets.

    Raises:
        ValueError: If there are fewer than 3 samples, or fewer than 4
            held-out samples to divide between validation and test.
        AssertionError: If validation_split is not in (0.0, 1.0).
    """
    n_samples = X.shape[0]

    if n_samples < 3:
        msg = f"Not enough samples ({n_samples}) for train/val/test split."
        self.logger.error(msg)
        raise ValueError(msg)

    # Raised explicitly instead of via `assert` so the check is not removed
    # under `python -O`; the exception type is unchanged for callers.
    if not (0.0 < self.validation_split < 1.0):
        msg = f"validation_split must be in (0.0, 1.0), but got {self.validation_split}."
        self.logger.error(msg)
        raise AssertionError(msg)

    # Train/Val split
    indices = np.arange(n_samples)
    train_idx, val_test_idx = train_test_split(
        indices,
        test_size=self.validation_split,
        random_state=self.seed,
    )

    if val_test_idx.size < 4:
        msg = f"Not enough samples ({val_test_idx.size}) for validation/test split."
        self.logger.error(msg)
        raise ValueError(msg)

    # Split val and test equally
    val_idx, test_idx = train_test_split(
        val_test_idx, test_size=0.5, random_state=self.seed
    )

    return train_idx, val_idx, test_idx
|
|
1556
|
+
|
|
1557
|
+
def _get_data_loaders(
    self,
    X: np.ndarray,
    y: np.ndarray,
    mask: np.ndarray,
    batch_size: int,
    *,
    shuffle: bool = True,
) -> torch.utils.data.DataLoader:
    """Wrap the input, target and mask arrays in a batched DataLoader.

    Args:
        X (np.ndarray): 0/1/2-encoded input matrix.
        y (np.ndarray): 0/1/2-encoded matrix with -1 for missing.
        mask (np.ndarray): Boolean mask of entries to score in the loss.
        batch_size (int): Batch size.
        shuffle (bool): Whether to shuffle batches. Defaults to True.

    Returns:
        torch.utils.data.DataLoader: Loader over ``_MaskedNumpyDataset(X, y, mask)``.
    """
    masked_dataset = _MaskedNumpyDataset(X, y, mask)

    # Pin host memory only when the target device is CUDA.
    use_pinned = str(self.device).startswith("cuda")

    loader = torch.utils.data.DataLoader(
        masked_dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        pin_memory=use_pinned,
    )
    return loader
|
|
1056
1586
|
|
|
1057
|
-
|
|
1058
|
-
|
|
1059
|
-
|
|
1587
|
+
def _update_anneal_schedule(
    self,
    final: float,
    warm: int,
    ramp: int,
    epoch: int,
    *,
    init_val: float = 0.0,
) -> torch.Tensor:
    """Return the annealed hyperparameter value for the given epoch.

    The schedule has three phases: ``init_val`` while ``epoch < warm``, a
    linear ramp from ``init_val`` to ``final`` over the next ``ramp``
    epochs, then ``final``. With the default ``init_val=0.0`` the ramp is
    ``final * (epoch - warm) / ramp``.

    Args:
        final (float): Final value after annealing.
        warm (int): Number of warm-up epochs.
        ramp (int): Number of ramp-up epochs.
        epoch (int): Current epoch number.
        init_val (float): Value during warm-up and at the start of the ramp.

    Returns:
        torch.Tensor: 0-d floating-point tensor on ``self.device``.
    """
    start = float(init_val)
    end = float(final)

    if epoch < warm:
        value = start
    elif epoch < warm + ramp:
        # ramp > 0 is guaranteed here: with ramp == 0 this branch is
        # unreachable because `epoch < warm` was already handled.
        progress = (epoch - warm) / ramp
        # Interpolate from init_val so a non-zero start does not drop to 0.
        value = start + (end - start) * progress
    else:
        value = end

    # Explicit float dtype keeps an integer `final` from yielding int64.
    return torch.tensor(
        value, dtype=torch.get_default_dtype(), device=self.device
    )
|
|
1075
1616
|
|
|
1076
|
-
|
|
1077
|
-
|
|
1078
|
-
|
|
1617
|
+
def _anneal_config(
    self,
    params: Optional[dict],
    key: str,
    default: float,
    max_epochs: int,
    *,
    warm_alt: int = 50,
    ramp_alt: int = 100,
) -> Tuple[float, int, int]:
    """Resolve the final value and warm-up/ramp lengths of an anneal schedule.

    When ``params`` is non-empty, the final value is ``params[key]``, or
    the attribute ``self.<key>`` if the key is absent. If that value is
    not a float or int, or ``params`` is empty/None, ``default`` is used.

    Args:
        params (Optional[dict]): Dictionary of parameters to extract from.
        key (str): Key to look up in ``params``; must also be an attribute
            of ``self`` whenever ``params`` is non-empty.
        default (float): Fallback final value.
        max_epochs (int): Total number of training epochs.
        warm_alt (int): Upper bound on the warm-up length, which is
            otherwise 10% of ``max_epochs``.
        ramp_alt (int): Upper bound on the ramp length, which is otherwise
            20% of ``max_epochs``.

    Returns:
        Tuple[float, int, int]: Final value, warm-up epochs, ramp-up epochs.

    Raises:
        AttributeError: If ``params`` is non-empty and ``self`` has no
            attribute named ``key``.
    """
    candidate = None
    if params:
        if not hasattr(self, key):
            msg = f"Attribute '{key}' not found for anneal_config."
            self.logger.error(msg)
            raise AttributeError(msg)
        candidate = params.get(key, getattr(self, key))

    # isinstance() is False for None, so no separate None check is needed.
    final = float(candidate) if isinstance(candidate, (float, int)) else default

    warm = min(int(0.1 * max_epochs), warm_alt)
    ramp = min(int(0.2 * max_epochs), ramp_alt)
    return final, warm, ramp
|
|
1100
1658
|
|
|
1101
|
-
def
|
|
1102
|
-
"""
|
|
1659
|
+
def _repair_ref_alt_from_iupac(self, loci: np.ndarray) -> None:
|
|
1660
|
+
"""Repair REF/ALT for specific loci using observed IUPAC genotypes.
|
|
1103
1661
|
|
|
1104
|
-
|
|
1662
|
+
Args:
|
|
1663
|
+
loci (np.ndarray): Array of locus indices to repair.
|
|
1664
|
+
|
|
1665
|
+
Notes:
|
|
1666
|
+
- Modifies self.genotype_data.ref and self.genotype_data.alt in place.
|
|
1667
|
+
"""
|
|
1668
|
+
iupac_to_bases = {
|
|
1669
|
+
"A": {"A"},
|
|
1670
|
+
"C": {"C"},
|
|
1671
|
+
"G": {"G"},
|
|
1672
|
+
"T": {"T"},
|
|
1673
|
+
"R": {"A", "G"},
|
|
1674
|
+
"Y": {"C", "T"},
|
|
1675
|
+
"S": {"G", "C"},
|
|
1676
|
+
"W": {"A", "T"},
|
|
1677
|
+
"K": {"G", "T"},
|
|
1678
|
+
"M": {"A", "C"},
|
|
1679
|
+
"B": {"C", "G", "T"},
|
|
1680
|
+
"D": {"A", "G", "T"},
|
|
1681
|
+
"H": {"A", "C", "T"},
|
|
1682
|
+
"V": {"A", "C", "G"},
|
|
1683
|
+
}
|
|
1684
|
+
missing_codes = {"", ".", "N", "NONE", "-", "?", "./.", ".|."}
|
|
1685
|
+
|
|
1686
|
+
def norm(v: object) -> str | None:
|
|
1687
|
+
if v is None:
|
|
1688
|
+
return None
|
|
1689
|
+
s = str(v).upper().strip()
|
|
1690
|
+
if not s or s in missing_codes:
|
|
1691
|
+
return None
|
|
1692
|
+
return s if s in iupac_to_bases else None
|
|
1693
|
+
|
|
1694
|
+
snp = np.asarray(self.genotype_data.snp_data, dtype=object) # (N,L) IUPAC-ish
|
|
1695
|
+
refs = list(getattr(self.genotype_data, "ref", [None] * snp.shape[1]))
|
|
1696
|
+
alts = list(getattr(self.genotype_data, "alt", [None] * snp.shape[1]))
|
|
1697
|
+
|
|
1698
|
+
for j in loci:
|
|
1699
|
+
cnt = Counter()
|
|
1700
|
+
col = snp[:, int(j)]
|
|
1701
|
+
for g in col:
|
|
1702
|
+
code = norm(g)
|
|
1703
|
+
if code is None:
|
|
1704
|
+
continue
|
|
1705
|
+
for b in iupac_to_bases[code]:
|
|
1706
|
+
cnt[b] += 1
|
|
1707
|
+
|
|
1708
|
+
if not cnt:
|
|
1709
|
+
continue
|
|
1710
|
+
|
|
1711
|
+
common = [b for b, _ in cnt.most_common()]
|
|
1712
|
+
ref = common[0]
|
|
1713
|
+
alt = common[1] if len(common) > 1 else None
|
|
1714
|
+
|
|
1715
|
+
refs[int(j)] = ref
|
|
1716
|
+
alts[int(j)] = alt if alt is not None else "."
|
|
1717
|
+
|
|
1718
|
+
self.genotype_data.ref = np.asarray(refs, dtype=object)
|
|
1719
|
+
|
|
1720
|
+
if not isinstance(alts, np.ndarray):
|
|
1721
|
+
alts = np.array(alts, dtype=object).tolist()
|
|
1722
|
+
|
|
1723
|
+
self.genotype_data.alt = alts
|
|
1724
|
+
|
|
1725
|
+
def _aligned_ref_alt(self, L: int) -> tuple[list[object], list[object]]:
    """Return REF/ALT aligned to the genotype matrix columns.

    Args:
        L (int): Number of loci (columns in genotype matrix).

    Returns:
        tuple[list[object], list[object]]: REF and ALT lists of length L.
        Singleton containers (ndarray, list, tuple) are unwrapped to their
        single element and empty containers become None; containers with
        several entries are returned as-is.

    Raises:
        ValueError: If ``genotype_data.ref``/``alt`` is missing or its
            length differs from ``L``.
    """
    refs = getattr(self.genotype_data, "ref", None)
    alts = getattr(self.genotype_data, "alt", None)

    if refs is None or alts is None:
        msg = "genotype_data.ref/alt are required but missing."
        self.logger.error(msg)
        raise ValueError(msg)

    # atleast_1d: a scalar input yields a length-1 array, so it reaches the
    # length check below instead of failing on `.shape[0]`.
    refs_arr = np.atleast_1d(np.asarray(refs, dtype=object))
    alts_arr = np.atleast_1d(np.asarray(alts, dtype=object))

    if refs_arr.shape[0] != L or alts_arr.shape[0] != L:
        msg = f"REF/ALT length mismatch vs matrix columns: L={L}, len(ref)={refs_arr.shape[0]}, len(alt)={alts_arr.shape[0]}. You are using REF/ALT metadata that is not aligned to pgenc.genotypes_012 columns. Fix by subsetting/refiltering ref/alt with the same locus mask used for the genotype matrix."
        self.logger.error(msg)
        raise ValueError(msg)

    def unwrap(x: object) -> object:
        """Unwrap singleton containers like array(['T']) or ['T'].

        Lists must be handled too: np.asarray stacks equal-length per-locus
        arrays into a 2-D array, and ``tolist()`` then returns plain lists.
        """
        if isinstance(x, np.ndarray):
            if x.size == 0:
                return None
            if x.size == 1:
                return unwrap(x.item())
            return x
        if isinstance(x, (list, tuple)):
            if len(x) == 0:
                return None
            if len(x) == 1:
                return unwrap(x[0])
            return x
        return x

    refs_list = [unwrap(x) for x in refs_arr.tolist()]
    alts_list = [unwrap(x) for x in alts_arr.tolist()]
    return refs_list, alts_list
|
|
1762
|
+
|
|
1763
|
+
def _build_valid_class_mask(self) -> torch.Tensor:
|
|
1764
|
+
L = self.num_features_
|
|
1765
|
+
K = self.num_classes_
|
|
1766
|
+
mask = np.ones((L, K), dtype=bool)
|
|
1767
|
+
|
|
1768
|
+
# --- IUPAC helpers (single-character only) ---
|
|
1769
|
+
iupac_to_bases: dict[str, set[str]] = {
|
|
1770
|
+
"A": {"A"},
|
|
1771
|
+
"C": {"C"},
|
|
1772
|
+
"G": {"G"},
|
|
1773
|
+
"T": {"T"},
|
|
1774
|
+
"R": {"A", "G"},
|
|
1775
|
+
"Y": {"C", "T"},
|
|
1776
|
+
"S": {"G", "C"},
|
|
1777
|
+
"W": {"A", "T"},
|
|
1778
|
+
"K": {"G", "T"},
|
|
1779
|
+
"M": {"A", "C"},
|
|
1780
|
+
"B": {"C", "G", "T"},
|
|
1781
|
+
"D": {"A", "G", "T"},
|
|
1782
|
+
"H": {"A", "C", "T"},
|
|
1783
|
+
"V": {"A", "C", "G"},
|
|
1784
|
+
}
|
|
1785
|
+
missing_codes = {"", ".", "N", "NONE", "-", "?", "./.", ".|."}
|
|
1786
|
+
|
|
1787
|
+
# get aligned ref/alt (should be exactly length L)
|
|
1788
|
+
refs, alts = self._aligned_ref_alt(L)
|
|
1789
|
+
|
|
1790
|
+
def _normalize_iupac(value: object) -> str | None:
|
|
1791
|
+
"""Return a single-letter IUPAC code or None if missing/invalid."""
|
|
1792
|
+
if value is None:
|
|
1793
|
+
return None
|
|
1794
|
+
if isinstance(value, (bytes, np.bytes_)):
|
|
1795
|
+
value = value.decode("utf-8", errors="ignore")
|
|
1796
|
+
|
|
1797
|
+
# allow list/tuple/array containers (take first valid)
|
|
1798
|
+
if isinstance(value, (list, tuple, np.ndarray, pd.Series)):
|
|
1799
|
+
for item in value:
|
|
1800
|
+
code = _normalize_iupac(item)
|
|
1801
|
+
if code is not None:
|
|
1802
|
+
return code
|
|
1803
|
+
return None
|
|
1804
|
+
|
|
1805
|
+
s = str(value).upper().strip()
|
|
1806
|
+
if not s or s in missing_codes:
|
|
1807
|
+
return None
|
|
1808
|
+
|
|
1809
|
+
# handle comma-separated values
|
|
1810
|
+
if "," in s:
|
|
1811
|
+
for tok in (t.strip() for t in s.split(",")):
|
|
1812
|
+
if tok and tok not in missing_codes and tok in iupac_to_bases:
|
|
1813
|
+
return tok
|
|
1814
|
+
return None
|
|
1815
|
+
|
|
1816
|
+
return s if s in iupac_to_bases else None
|
|
1817
|
+
|
|
1818
|
+
# 1) metadata restriction
|
|
1819
|
+
for j in range(L):
|
|
1820
|
+
ref = _normalize_iupac(refs[j])
|
|
1821
|
+
alt = _normalize_iupac(alts[j])
|
|
1822
|
+
|
|
1823
|
+
if alt is None or (ref is not None and alt == ref):
|
|
1824
|
+
mask[j, :] = False
|
|
1825
|
+
mask[j, 0] = True
|
|
1826
|
+
|
|
1827
|
+
# 2) data-driven override
|
|
1828
|
+
y_train = getattr(self, "y_train_", None)
|
|
1829
|
+
if y_train is not None:
|
|
1830
|
+
y = np.asarray(y_train)
|
|
1831
|
+
if y.ndim == 2 and y.shape[1] == L:
|
|
1832
|
+
if K == 2:
|
|
1833
|
+
y = y.copy()
|
|
1834
|
+
y[y == 2] = 1
|
|
1835
|
+
valid = y >= 0
|
|
1836
|
+
if valid.any():
|
|
1837
|
+
observed = np.zeros((L, K), dtype=bool)
|
|
1838
|
+
for c in range(K):
|
|
1839
|
+
observed[:, c] = np.any(valid & (y == c), axis=0)
|
|
1840
|
+
|
|
1841
|
+
conflict = observed & (~mask)
|
|
1842
|
+
if conflict.any():
|
|
1843
|
+
loci = np.where(conflict.any(axis=1))[0]
|
|
1844
|
+
self.valid_class_mask_conflict_loci_ = loci
|
|
1845
|
+
self.logger.warning(
|
|
1846
|
+
f"valid_class_mask_ metadata forbids observed classes at {loci.size} loci. "
|
|
1847
|
+
"Expanding mask to include observed classes."
|
|
1848
|
+
)
|
|
1849
|
+
mask |= observed
|
|
1850
|
+
|
|
1851
|
+
bad = np.where(~mask.any(axis=1))[0]
|
|
1852
|
+
if bad.size:
|
|
1853
|
+
mask[bad, :] = False
|
|
1854
|
+
mask[bad, 0] = True
|
|
1855
|
+
|
|
1856
|
+
return torch.as_tensor(mask, dtype=torch.bool, device=self.device)
|