pg-sui 1.6.14.dev9__py3-none-any.whl → 1.7.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (34)
  1. pg_sui-1.7.0.dist-info/METADATA +288 -0
  2. {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/RECORD +29 -33
  3. pgsui/__init__.py +0 -8
  4. pgsui/_version.py +2 -2
  5. pgsui/cli.py +591 -126
  6. pgsui/data_processing/config.py +1 -2
  7. pgsui/data_processing/containers.py +218 -533
  8. pgsui/data_processing/transformers.py +44 -20
  9. pgsui/impute/deterministic/imputers/mode.py +475 -182
  10. pgsui/impute/deterministic/imputers/ref_allele.py +454 -147
  11. pgsui/impute/supervised/imputers/hist_gradient_boosting.py +4 -3
  12. pgsui/impute/supervised/imputers/random_forest.py +3 -2
  13. pgsui/impute/unsupervised/base.py +1268 -530
  14. pgsui/impute/unsupervised/callbacks.py +28 -33
  15. pgsui/impute/unsupervised/imputers/autoencoder.py +869 -764
  16. pgsui/impute/unsupervised/imputers/vae.py +928 -696
  17. pgsui/impute/unsupervised/loss_functions.py +156 -202
  18. pgsui/impute/unsupervised/models/autoencoder_model.py +7 -49
  19. pgsui/impute/unsupervised/models/vae_model.py +40 -221
  20. pgsui/impute/unsupervised/nn_scorers.py +53 -13
  21. pgsui/utils/classification_viz.py +240 -97
  22. pgsui/utils/misc.py +201 -3
  23. pgsui/utils/plotting.py +73 -58
  24. pgsui/utils/pretty_metrics.py +2 -6
  25. pgsui/utils/scorers.py +39 -0
  26. pg_sui-1.6.14.dev9.dist-info/METADATA +0 -344
  27. pgsui/impute/unsupervised/imputers/nlpca.py +0 -1554
  28. pgsui/impute/unsupervised/imputers/ubp.py +0 -1575
  29. pgsui/impute/unsupervised/models/nlpca_model.py +0 -206
  30. pgsui/impute/unsupervised/models/ubp_model.py +0 -200
  31. {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/WHEEL +0 -0
  32. {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/entry_points.txt +0 -0
  33. {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/licenses/LICENSE +0 -0
  34. {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/top_level.txt +0 -0
@@ -2,8 +2,10 @@ import copy
2
2
  import gc
3
3
  import json
4
4
  import logging
5
+ from collections import Counter
6
+ from datetime import datetime
5
7
  from pathlib import Path
6
- from typing import TYPE_CHECKING, Any, Dict, List, Literal, Tuple
8
+ from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Type
7
9
 
8
10
  import matplotlib.pyplot as plt
9
11
  import numpy as np
@@ -13,11 +15,18 @@ import plotly.graph_objects as go
13
15
  import torch
14
16
  import torch.nn.functional as F
15
17
  from matplotlib.figure import Figure
16
- from sklearn.metrics import classification_report
18
+ from sklearn.metrics import (
19
+ average_precision_score,
20
+ classification_report,
21
+ jaccard_score,
22
+ matthews_corrcoef,
23
+ )
17
24
  from sklearn.model_selection import train_test_split
18
25
  from snpio import SNPioMultiQC
19
26
  from snpio.utils.logging import LoggerManager
27
+ from snpio.utils.misc import validate_input_type
20
28
 
29
+ from pgsui.data_processing.transformers import SimMissingTransformer
21
30
  from pgsui.impute.unsupervised.nn_scorers import Scorer
22
31
  from pgsui.utils.classification_viz import ClassificationReportVisualizer
23
32
  from pgsui.utils.logging_utils import configure_logger
@@ -27,16 +36,24 @@ from pgsui.utils.pretty_metrics import PrettyMetrics
27
36
  if TYPE_CHECKING:
28
37
  from snpio.read_input.genotype_data import GenotypeData
29
38
 
30
- from pgsui.impute.unsupervised.models.autoencoder_model import AutoencoderModel
31
- from pgsui.impute.unsupervised.models.nlpca_model import NLPCAModel
32
- from pgsui.impute.unsupervised.models.ubp_model import UBPModel
33
- from pgsui.impute.unsupervised.models.vae_model import VAEModel
39
+
40
+ class _MaskedNumpyDataset(torch.utils.data.Dataset):
41
+ def __init__(self, X: np.ndarray, y: np.ndarray, mask: np.ndarray):
42
+ self.X = X
43
+ self.y = y
44
+ self.mask = mask.astype(np.bool_, copy=False)
45
+
46
+ def __len__(self) -> int:
47
+ return self.X.shape[0]
48
+
49
+ def __getitem__(self, idx: int):
50
+ return self.X[idx], self.y[idx], self.mask[idx]
34
51
 
35
52
 
36
53
  class BaseNNImputer:
37
- """An abstract base class for neural network-based imputers.
54
+ """Abstract base class for neural network-based imputers.
38
55
 
39
- This class provides a shared framework and common functionality for all neural network imputers. It is not meant to be instantiated directly. Instead, child classes should inherit from it and implement the abstract methods. Provided functionality: Directory setup and logging initialization; A hyperparameter tuning pipeline using Optuna; Utility methods for building models (`build_model`), initializing weights (`initialize_weights`), and checking for fitted attributes (`ensure_attribute`); Helper methods for calculating class weights for imbalanced data; Setup for standardized plotting and model scoring classes.
56
+ This class provides shared infrastructure for NN imputers (e.g., directory/logging setup, Optuna tuning, model construction helpers, class-weight utilities, standardized plotting/scoring, and IUPAC decoding). It is not intended to be instantiated directly; subclasses must implement the abstract methods.
40
57
  """
41
58
 
42
59
  def __init__(
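The new `_MaskedNumpyDataset` added above is a thin `torch.utils.data.Dataset` over pre-encoded arrays; each item is an `(X[i], y[i], mask[i])` triple. A minimal usage sketch (array shapes, contents, and batch size are illustrative, not taken from the package):

    import numpy as np
    import torch

    X = np.random.rand(8, 100).astype(np.float32)   # inputs (samples x loci)
    y = np.random.randint(0, 3, size=(8, 100))       # integer targets
    mask = np.random.rand(8, 100) < 0.2              # simulated-missing mask

    ds = _MaskedNumpyDataset(X, y, mask)
    loader = torch.utils.data.DataLoader(ds, batch_size=4, shuffle=True)
    xb, yb, mb = next(iter(loader))                  # batched tensors, each of shape (4, 100)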
@@ -54,6 +71,8 @@ class BaseNNImputer:
54
71
  This constructor sets up the device (CPU, GPU, or MPS), creates the necessary output directories for models and results, and a logger. It also initializes a genotype encoder for handling genotype data.
55
72
 
56
73
  Args:
74
+ model_name (str): The model class name used in output paths and logs.
75
+ genotype_data (GenotypeData): Backing genotype data object.
57
76
  prefix (str): A prefix used to name the output directory (e.g., 'pgsui_output').
58
77
  device (Literal["gpu", "cpu", "mps"]): The device to use for PyTorch operations. If 'gpu' or 'mps' is chosen, it will fall back to 'cpu' if the required hardware is not available. Defaults to "cpu".
59
78
  verbose (bool): If True, enables detailed logging output. Defaults to False.
@@ -62,6 +81,11 @@ class BaseNNImputer:
62
81
  self.model_name = model_name
63
82
  self.genotype_data = genotype_data
64
83
 
84
+ if not hasattr(self, "tree_parser"):
85
+ self.tree_parser = None
86
+ if not hasattr(self, "sim_kwargs"):
87
+ self.sim_kwargs = {}
88
+
65
89
  self.prefix = prefix
66
90
  self.verbose = verbose
67
91
  self.debug = debug
@@ -89,15 +113,15 @@ class BaseNNImputer:
89
113
  self.logger = configure_logger(
90
114
  logman.get_logger(), verbose=self.verbose, debug=self.debug
91
115
  )
92
- self._float_genotype_cache: np.ndarray | None = None
93
- self._sim_mask_cache: dict[tuple, np.ndarray] = {}
116
+
117
+ self.logger.info(f"Using PyTorch device: {self.device.type}.")
94
118
 
95
119
  # To be initialized by child classes or fit method
96
120
  self.tune_save_db: bool = False
97
121
  self.tune_resume: bool = False
98
122
  self.n_trials: int = 100
99
123
  self.model_params: Dict[str, Any] = {}
100
- self.tune_metric: str = "val_f1_macro"
124
+ self.tune_metric: str = "f1"
101
125
  self.learning_rate: float = 1e-3
102
126
  self.plotter_: "Plotting"
103
127
  self.num_features_: int = 0
@@ -110,21 +134,16 @@ class BaseNNImputer:
110
134
  self.show_plots: bool = False
111
135
  self.scoring_averaging: Literal["macro", "micro", "weighted"] = "macro"
112
136
  self.pgenc: Any = None
113
- self.is_haploid: bool = False
137
+ self.is_haploid_: bool = False
114
138
  self.ploidy: int = 2
115
139
  self.beta: float = 0.9999
116
- self.max_ratio: float = 5.0
117
- self.sim_strategy: str = "mcar"
118
- self.sim_prop: float = 0.1
119
- self.seed: int | None = 42
140
+ self.max_ratio: Optional[float] = None
141
+ self.sim_strategy: str = "random"
142
+ self.sim_prop: float = 0.2
143
+ self.seed: Optional[int] = None
120
144
  self.rng: np.random.Generator = np.random.default_rng(self.seed)
121
145
  self.ground_truth_: np.ndarray
122
- self.tune_fast: bool = False
123
- self.tune_max_samples: int = 1000
124
- self.tune_max_loci: int = 500
125
146
  self.validation_split: float = 0.2
126
- self.tune_batch_size: int = 64
127
- self.tune_proxy_metric_batch: int = 512
128
147
  self.batch_size: int = 64
129
148
  self.best_params_: Dict[str, Any] = {}
130
149
 
@@ -133,9 +152,11 @@ class BaseNNImputer:
133
152
  self.plots_dir: Path
134
153
  self.metrics_dir: Path
135
154
  self.parameters_dir: Path
136
- self.study_db: Path
155
+ self.study_db: Optional[Path] = None
156
+ self.X_model_input_: Optional[np.ndarray] = None
157
+ self.class_weights_: Optional[torch.Tensor] = None
137
158
 
138
- def tune_hyperparameters(self) -> None:
159
+ def tune_hyperparameters(self) -> Dict[str, Any]:
139
160
  """Tunes model hyperparameters using an Optuna study.
140
161
 
141
162
  This method orchestrates the hyperparameter search process. It creates an Optuna study that aims to maximize the metric defined in `self.tune_metric`. The search is driven by the `_objective` method, which must be implemented by the child class. After the search, the best parameters are logged, saved to a JSON file, and visualizations of the study are generated.
@@ -153,15 +174,20 @@ class BaseNNImputer:
153
174
  study_db = None
154
175
  load_if_exists = False
155
176
  if self.tune_save_db:
156
- study_db = self.optimize_dir / "study_database" / "optuna_study.db"
177
+ timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
178
+ study_db = (
179
+ self.optimize_dir / "study_database" / f"optuna_study_{timestamp}.db"
180
+ )
157
181
  study_db.parent.mkdir(parents=True, exist_ok=True)
158
182
 
159
183
  if self.tune_resume and study_db.exists():
160
184
  load_if_exists = True
161
185
 
162
186
  if not self.tune_resume and study_db.exists():
163
- study_db.unlink()
187
+ timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
188
+ study_db = study_db.with_name(f"optuna_study_{timestamp}.db")
164
189
 
190
+ self.study_db = study_db
165
191
  study_name = f"{self.prefix} {self.model_name} Model Optimization"
166
192
  storage = f"sqlite:///{study_db}" if self.tune_save_db else None
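With `tune_save_db=True`, the code above now writes the study to a timestamped SQLite file rather than a fixed name. For illustration (the timestamp is made up; `<optimize_dir>` stands for the model's optimization output directory):

    study_db = <optimize_dir>/study_database/optuna_study_20250101-120000.db
    storage  = "sqlite:///<optimize_dir>/study_database/optuna_study_20250101-120000.db"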
167
193
 
@@ -170,7 +196,17 @@ class BaseNNImputer:
170
196
  study_name=study_name,
171
197
  storage=storage,
172
198
  load_if_exists=load_if_exists,
173
- pruner=optuna.pruners.MedianPruner(n_startup_trials=5, n_warmup_steps=10),
199
+ pruner=optuna.pruners.MedianPruner(
200
+ # Guard against small `n_trials` values (e.g., 1)
201
+ # that can otherwise produce 0 startup/warmup/min trials.
202
+ n_startup_trials=max(
203
+ 1, min(int(self.n_trials * 0.1), 10, int(self.n_trials))
204
+ ),
205
+ n_warmup_steps=150,
206
+ n_min_trials=max(
207
+ 1, min(int(0.5 * self.n_trials), 25, int(self.n_trials))
208
+ ),
209
+ ),
174
210
  )
175
211
 
176
212
  if not hasattr(self, "_objective"):
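For reference, the guarded `MedianPruner` arguments above evaluate as follows for two illustrative `n_trials` values (pure arithmetic on the expressions shown):

    n_trials = 1:   n_startup_trials = max(1, min(0, 10, 1))    = 1     n_min_trials = max(1, min(0, 25, 1))    = 1
    n_trials = 100: n_startup_trials = max(1, min(10, 10, 100)) = 10    n_min_trials = max(1, min(50, 25, 100)) = 25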
@@ -185,6 +221,13 @@ class BaseNNImputer:
185
221
 
186
222
  show_progress_bar = not self.verbose and not self.debug and self.n_jobs == 1
187
223
 
224
+ # Set the best parameters.
225
+ # NOTE: _set_best_params() must be implemented in the child class.
226
+ if not hasattr(self, "_set_best_params"):
227
+ msg = "Method `_set_best_params()` must be implemented in the child class."
228
+ self.logger.error(msg)
229
+ raise NotImplementedError(msg)
230
+
188
231
  study.optimize(
189
232
  lambda trial: self._objective(trial),
190
233
  n_trials=self.n_trials,
@@ -193,15 +236,13 @@ class BaseNNImputer:
193
236
  show_progress_bar=show_progress_bar,
194
237
  )
195
238
 
196
- best_metric = study.best_value
197
- best_params = study.best_params
198
-
199
- # Set the best parameters.
200
- # NOTE: `_set_best_params()` must be implemented in the child class.
201
- if not hasattr(self, "_set_best_params"):
202
- msg = "Method `_set_best_params()` must be implemented in the child class."
239
+ try:
240
+ best_metric = study.best_value
241
+ best_params = study.best_params
242
+ except Exception:
243
+ msg = "Tuning failed: No successful trials completed."
203
244
  self.logger.error(msg)
204
- raise NotImplementedError(msg)
245
+ raise RuntimeError(msg)
205
246
 
206
247
  self.best_params_ = self._set_best_params(best_params)
207
248
  self.model_params.update(self.best_params_)
@@ -210,45 +251,34 @@ class BaseNNImputer:
210
251
  best_params_tmp = copy.deepcopy(best_params)
211
252
  best_params_tmp["learning_rate"] = self.learning_rate
212
253
 
213
- title = f"{self.model_name} Optimized Parameters"
214
- pm = PrettyMetrics(best_params_tmp, precision=6, title=title)
215
- pm.render()
254
+ tn = f"{self.tune_metric} Value"
216
255
 
217
- # Save best parameters to a JSON file.
218
- self._save_best_params(best_params)
256
+ if self.show_plots:
257
+ self.plotter_.plot_tuning(
258
+ study, self.model_name, self.optimize_dir / "plots", target_name=tn
259
+ )
219
260
 
220
- tn = f"{self.tune_metric} Value"
221
- self.plotter_.plot_tuning(
222
- study, self.model_name, self.optimize_dir / "plots", target_name=tn
223
- )
261
+ return best_params_tmp
224
262
 
225
263
  @staticmethod
226
264
  def initialize_weights(module: torch.nn.Module) -> None:
227
- """Initializes model weights using the Kaiming Uniform distribution.
228
-
229
- This static method is intended to be applied to a PyTorch model to initialize the weights of its linear and convolutional layers. This initialization scheme is particularly effective for networks that use ReLU-family activation functions, as it helps maintain stable activation variances during training.
265
+ """Initializes model weights using Xavier/Glorot Uniform distribution.
230
266
 
231
- Args:
232
- module (torch.nn.Module): The PyTorch module (e.g., a layer) to initialize.
267
+ Switching from Kaiming to Xavier is safer for deep VAEs to prevent
268
+ exploding gradients or dead neurons in the early epochs.
233
269
  """
234
270
  if isinstance(
235
271
  module, (torch.nn.Linear, torch.nn.Conv1d, torch.nn.ConvTranspose1d)
236
272
  ):
237
- # Use Kaiming Uniform initialization for Linear and Conv layers
238
- torch.nn.init.kaiming_uniform_(module.weight, nonlinearity="relu")
273
+ # Xavier is generally more stable for VAEs than Kaiming
274
+ torch.nn.init.xavier_uniform_(module.weight)
239
275
  if module.bias is not None:
240
276
  torch.nn.init.zeros_(module.bias)
241
277
 
242
278
  def build_model(
243
279
  self,
244
- Model: (
245
- torch.nn.Module
246
- | type["AutoencoderModel"]
247
- | type["NLPCAModel"]
248
- | type["UBPModel"]
249
- | type["VAEModel"]
250
- ),
251
- model_params: Dict[str, int | float | str | bool],
280
+ Model: Type[torch.nn.Module],
281
+ model_params: Dict[str, Any],
252
282
  ) -> torch.nn.Module:
253
283
  """Builds and initializes a neural network model instance.
254
284
 
@@ -277,7 +307,6 @@ class BaseNNImputer:
277
307
  self.logger.error(msg)
278
308
  raise AttributeError(msg)
279
309
 
280
- # Start with a base set of fixed (non-tuned) parameters.
281
310
  all_params = {
282
311
  "n_features": self.num_features_,
283
312
  "prefix": self.prefix,
@@ -287,7 +316,7 @@ class BaseNNImputer:
287
316
  "device": self.device,
288
317
  }
289
318
 
290
- # Update with the variable hyperparameters from the provided dictionary
319
+ # Update with the variable hyperparameters
291
320
  all_params.update(model_params)
292
321
 
293
322
  return Model(**all_params).to(self.device)
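`build_model` merges the fixed parameters above (n_features, prefix, device, ...) with the tuned hyperparameters before instantiating the class. A hedged usage sketch; the keyword arguments are hypothetical and `imputer` stands for any concrete subclass instance:

    from pgsui.impute.unsupervised.models.vae_model import VAEModel

    model = imputer.build_model(
        VAEModel,
        model_params={"latent_dim": 3, "dropout_rate": 0.2},  # hypothetical tuned values
    )
    # `model` is already moved to imputer.device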
@@ -369,110 +398,12 @@ class BaseNNImputer:
369
398
  X (np.ndarray | pd.DataFrame | list | None): The input data with missing values.
370
399
 
371
400
  Returns:
372
- np.ndarray: The data with missing values imputed.
401
+ np.ndarray: IUPAC strings with missing values imputed.
373
402
  """
374
403
  msg = "Method ``transform()`` must be implemented in the child class."
375
404
  self.logger.error(msg)
376
405
  raise NotImplementedError(msg)
377
406
 
378
- def _class_balanced_weights_from_mask(
379
- self,
380
- y: np.ndarray,
381
- train_mask: np.ndarray,
382
- num_classes: int,
383
- beta: float = 0.9999,
384
- max_ratio: float = 5.0,
385
- mode: Literal["allele", "genotype10"] = "allele",
386
- ) -> torch.Tensor:
387
- """Class-balanced weights (Cui et al. 2019) with overflow-safe effective number.
388
-
389
- mode="allele": y is 1D alleles in {0..3}, train_mask same shape. mode="genotype10": y is (nS,nF,2) alleles; train_mask is (nS,nF) loci where both alleles known.
390
-
391
- Args:
392
- y (np.ndarray): Ground truth labels.
393
- train_mask (np.ndarray): Boolean mask of training examples (same shape as y or y without last dim for genotype10).
394
- num_classes (int): Number of classes.
395
- beta (float): Hyperparameter for effective number calculation. Clamped to (0,1). Default is 0.9999.
396
- max_ratio (float): Maximum allowed ratio between largest and smallest non-zero weight. Default is 5.0.
397
- mode (Literal["allele", "genotype10"]): Whether y contains allele labels or 10-class genotypes. Default is "allele".
398
-
399
- Returns:
400
- torch.Tensor: Class weights of shape (num_classes,). Mean weight is 1.0, zero-weight classes remain zero.
401
- """
402
- if mode == "allele":
403
- valid = (y >= 0) & train_mask
404
- cls, cnt = np.unique(y[valid].astype(np.int64), return_counts=True)
405
- counts = np.zeros(num_classes, dtype=np.float64)
406
- counts[cls] = cnt
407
-
408
- elif mode == "genotype10":
409
- if y.ndim != 3 or y.shape[-1] != 2:
410
- msg = "For genotype10, y must be (nS,nF,2)."
411
- self.logger.error(msg)
412
- raise ValueError(msg)
413
-
414
- if train_mask.shape != y.shape[:2]:
415
- msg = "train_mask must be (nS,nF) for genotype10."
416
- self.logger.error(msg)
417
- raise ValueError(msg)
418
-
419
- # only loci where both alleles known and in training
420
- m = train_mask & np.all(y >= 0, axis=-1)
421
- if not np.any(m):
422
- counts = np.zeros(num_classes, dtype=np.float64)
423
-
424
- else:
425
- a1 = y[:, :, 0][m].astype(int)
426
- a2 = y[:, :, 1][m].astype(int)
427
- lo, hi = np.minimum(a1, a2), np.maximum(a1, a2)
428
- # map to 10-class index
429
- map10 = self.pgenc.map10
430
- idx10 = map10[lo, hi]
431
- idx10 = idx10[(idx10 >= 0) & (idx10 < num_classes)]
432
- counts = np.bincount(idx10, minlength=num_classes).astype(np.float64)
433
- else:
434
- msg = f"Unknown mode supplied to _class_balanced_weights_from_mask: {mode}"
435
- self.logger.error(msg)
436
- raise ValueError(msg)
437
-
438
- # ---- Effective number ----
439
- beta = float(beta)
440
-
441
- # clamp beta ∈ (0,1)
442
- if not np.isfinite(beta):
443
- beta = 0.9999
444
-
445
- beta = min(max(beta, 1e-8), 1.0 - 1e-8)
446
-
447
- logb = np.log(beta) # < 0
448
- t = counts * logb # ≤ 0
449
-
450
- # 1 - beta^n = 1 - exp(n*log(beta)) = -(exp(n*log(beta)) - 1)
451
- # use expm1 for accuracy near 0; for very negative t, eff≈1.0
452
- eff = np.where(t > -50.0, -np.expm1(t), 1.0)
453
-
454
- # class-balanced weights
455
- w = (1.0 - beta) / (eff + 1e-12)
456
-
457
- # Give unseen classes the largest non-zero weight (keeps it learnable)
458
- if np.any(counts == 0) and np.any(counts > 0):
459
- w[counts == 0] = w[counts > 0].max()
460
-
461
- # normalize by mean of non-zero
462
- nz = w > 0
463
- w[nz] /= w[nz].mean() + 1e-12
464
-
465
- # cap spread consistently with a single 'cap'
466
- cap = float(max_ratio) if max_ratio is not None else 10.0
467
- cap = max(cap, 5.0) # ensure we allow some differentiation
468
- if np.any(nz):
469
- spread = w[nz].max() / max(w[nz].min(), 1e-12)
470
- if spread > cap:
471
- scale = cap / spread
472
- w[nz] = 1.0 + (w[nz] - 1.0) * scale
473
-
474
- return torch.tensor(w.astype(np.float32), device=self.device)
475
-
476
407
  def _select_device(self, device: Literal["gpu", "cpu", "mps"]) -> torch.device:
477
408
  """Selects the appropriate PyTorch device based on user preference and availability.
478
409
 
@@ -484,36 +415,37 @@ class BaseNNImputer:
484
415
  Returns:
485
416
  torch.device: The selected PyTorch device.
486
417
  """
487
- dvc: str = device
488
- dvc = dvc.lower().strip()
418
+ dvc = device.lower().strip()
489
419
  if dvc == "cpu":
490
- self.logger.info("Using PyTorch device: CPU.")
491
420
  return torch.device("cpu")
492
421
  if dvc == "mps":
493
422
  if torch.backends.mps.is_available():
494
- self.logger.info("Using PyTorch device: mps.")
495
423
  return torch.device("mps")
496
- self.logger.warning("MPS unavailable; falling back to CPU.")
497
424
  return torch.device("cpu")
498
- # gpu
499
425
  if torch.cuda.is_available():
500
- self.logger.info("Using PyTorch device: cuda.")
501
426
  return torch.device("cuda")
502
- self.logger.warning("CUDA unavailable; falling back to CPU.")
503
427
  return torch.device("cpu")
504
428
 
505
- def _create_model_directories(self, prefix: str, outdirs: List[str]) -> None:
429
+ def _create_model_directories(
430
+ self, prefix: str, outdirs: List[str], *, outdir: Path | str | None = None
431
+ ) -> None:
506
432
  """Creates the directory structure for storing model outputs.
507
433
 
508
- This method sets up a standardized folder hierarchy for saving models, plots, metrics, and optimization results, organized under a main directory named after the provided prefix.
434
+ This method sets up a standardized folder hierarchy for saving models,
435
+ plots, metrics, and optimization results, organized under a main directory
436
+ named after the provided prefix. The current implementation always uses
437
+ ``<cwd>/<prefix>_output`` regardless of ``outdir``.
509
438
 
510
439
  Args:
511
440
  prefix (str): The prefix for the main output directory.
512
441
  outdirs (List[str]): A list of subdirectory names to create within the main directory.
442
+ outdir (Path | str | None): Requested base output directory (currently ignored).
513
443
 
514
444
  Raises:
515
445
  Exception: If any of the directories cannot be created.
516
446
  """
447
+ base_root = Path(outdir) if outdir is not None else Path.cwd()
448
+ formatted_output_dir = base_root / f"{prefix}_output"
517
449
  formatted_output_dir = Path(f"{prefix}_output")
518
450
  base_dir = formatted_output_dir / "Unsupervised"
519
451
 
@@ -527,27 +459,16 @@ class BaseNNImputer:
527
459
  self.logger.error(msg)
528
460
  raise Exception(msg)
529
461
 
530
- def _clear_resources(
531
- self,
532
- model: torch.nn.Module,
533
- train_loader: torch.utils.data.DataLoader,
534
- latent_vectors: torch.nn.Parameter | None = None,
535
- ) -> None:
462
+ def _clear_resources(self, model: torch.nn.Module) -> None:
536
463
  """Releases GPU and CPU memory after an Optuna trial.
537
464
 
538
465
  This is a crucial step during hyperparameter tuning to prevent memory leaks between trials, ensuring that each trial runs in a clean environment.
539
466
 
540
467
  Args:
541
468
  model (torch.nn.Module): The model from the completed trial.
542
- train_loader (torch.utils.data.DataLoader): The data loader from the trial.
543
- latent_vectors (torch.nn.Parameter | None): The latent vectors from the trial.
544
469
  """
545
470
  try:
546
- del model, train_loader
547
-
548
- if latent_vectors is not None:
549
- del latent_vectors
550
-
471
+ del model
551
472
  except NameError:
552
473
  pass
553
474
 
@@ -568,7 +489,7 @@ class BaseNNImputer:
568
489
  y_pred: np.ndarray,
569
490
  metrics: Dict[str, float],
570
491
  msg: str,
571
- ):
492
+ ) -> None:
572
493
  """Generate and save evaluation visualizations.
573
494
 
574
495
  3-class (zygosity) or 10-class (IUPAC) depending on `labels` length.
@@ -586,20 +507,109 @@ class BaseNNImputer:
586
507
  prefix = "zygosity" if len(labels) == 3 else "iupac"
587
508
  n_labels = len(labels)
588
509
 
589
- self.plotter_.plot_metrics(
590
- y_true=y_true,
591
- y_pred_proba=y_pred_proba,
592
- metrics=metrics,
593
- label_names=labels,
594
- prefix=f"geno{n_labels}_{prefix}",
510
+ if self.show_plots:
511
+ self.plotter_.plot_metrics(
512
+ y_true=y_true,
513
+ y_pred_proba=y_pred_proba,
514
+ metrics=metrics,
515
+ label_names=labels,
516
+ prefix=f"geno{n_labels}_{prefix}",
517
+ )
518
+ self.plotter_.plot_confusion_matrix(
519
+ y_true_1d=y_true,
520
+ y_pred_1d=y_pred,
521
+ label_names=labels,
522
+ prefix=f"geno{n_labels}_{prefix}",
523
+ )
524
+
525
+ def _additional_metrics(
526
+ self,
527
+ y_true: np.ndarray,
528
+ y_pred: np.ndarray,
529
+ labels: list[int],
530
+ report_names: list[str],
531
+ report: dict,
532
+ ) -> dict[str, dict[str, float] | float]:
533
+ """Compute additional metrics and augment the report dictionary.
534
+
535
+ Args:
536
+ y_true (np.ndarray): True genotypes.
537
+ y_pred (np.ndarray): Predicted genotypes.
538
+ labels (list[int]): List of label indices.
539
+ report_names (list[str]): List of report names corresponding to labels.
540
+ report (dict): Classification report dictionary to augment.
541
+
542
+ Returns:
543
+ dict[str, dict[str, float] | float]: Augmented report dictionary with additional metrics.
544
+ """
545
+ # Create an identity matrix and use the targets array as indices
546
+ y_score = np.eye(len(report_names))[y_pred]
547
+
548
+ # Per-class metrics
549
+ ap_pc = average_precision_score(y_true, y_score, average=None)
550
+ jaccard_pc = jaccard_score(
551
+ y_true, y_pred, average=None, labels=labels, zero_division=0
595
552
  )
596
- self.plotter_.plot_confusion_matrix(
597
- y_true_1d=y_true,
598
- y_pred_1d=y_pred,
599
- label_names=labels,
600
- prefix=f"geno{n_labels}_{prefix}",
553
+
554
+ # Macro/weighted metrics
555
+ ap_macro = average_precision_score(y_true, y_score, average="macro")
556
+ ap_weighted = average_precision_score(y_true, y_score, average="weighted")
557
+ jaccard_macro = jaccard_score(y_true, y_pred, average="macro", zero_division=0)
558
+ jaccard_weighted = jaccard_score(
559
+ y_true, y_pred, average="weighted", zero_division=0
601
560
  )
602
561
 
562
+ # Matthews correlation coefficient (MCC)
563
+ mcc = matthews_corrcoef(y_true, y_pred)
564
+
565
+ if not isinstance(ap_pc, np.ndarray):
566
+ msg = "average_precision_score or f1_score did not return np.ndarray as expected."
567
+ self.logger.error(msg)
568
+ raise TypeError(msg)
569
+
570
+ if not isinstance(jaccard_pc, np.ndarray):
571
+ msg = "jaccard_score did not return np.ndarray as expected."
572
+ self.logger.error(msg)
573
+ raise TypeError(msg)
574
+
575
+ # Add per-class metrics
576
+ report_full = {}
577
+ dd_subset = {
578
+ k: v for k, v in report.items() if k in report_names and isinstance(v, dict)
579
+ }
580
+ for i, class_name in enumerate(report_names):
581
+ class_report: dict[str, float] = {}
582
+ if class_name in dd_subset:
583
+ class_report = dd_subset[class_name]
584
+
585
+ if isinstance(class_report, float) or not class_report:
586
+ continue
587
+
588
+ report_full[class_name] = dict(class_report)
589
+ report_full[class_name]["average-precision"] = float(ap_pc[i])
590
+ report_full[class_name]["jaccard"] = float(jaccard_pc[i])
591
+
592
+ macro_avg = report.get("macro avg")
593
+ if isinstance(macro_avg, dict):
594
+ report_full["macro avg"] = dict(macro_avg)
595
+ report_full["macro avg"]["average-precision"] = float(ap_macro)
596
+ report_full["macro avg"]["jaccard"] = float(jaccard_macro)
597
+
598
+ weighted_avg = report.get("weighted avg")
599
+ if isinstance(weighted_avg, dict):
600
+ report_full["weighted avg"] = dict(weighted_avg)
601
+ report_full["weighted avg"]["average-precision"] = float(ap_weighted)
602
+ report_full["weighted avg"]["jaccard"] = float(jaccard_weighted)
603
+
604
+ # Add scalar summary metrics
605
+ report_full["mcc"] = float(mcc)
606
+ accuracy_val = report.get("accuracy")
607
+
608
+ if isinstance(accuracy_val, (int, float)):
609
+ report_full["accuracy"] = float(accuracy_val)
610
+
611
+ return report_full
612
+
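The `np.eye(len(report_names))[y_pred]` line above converts hard predictions into one-hot pseudo-probabilities so that `average_precision_score` can be applied; for example:

    import numpy as np

    y_pred = np.array([0, 2, 1])
    np.eye(3)[y_pred]
    # array([[1., 0., 0.],
    #        [0., 0., 1.],
    #        [0., 1., 0.]])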
603
613
  def _make_class_reports(
604
614
  self,
605
615
  y_true: np.ndarray,
@@ -617,27 +627,28 @@ class BaseNNImputer:
617
627
  y_pred (np.ndarray): Predicted labels (1D array).
618
628
  metrics (Dict[str, float]): Computed metrics.
619
629
  y_pred_proba (np.ndarray | None): Predicted probabilities (2D array). Defaults to None.
620
- labels (List[str]): Class label names
621
- (default: ["REF", "HET", "ALT"] for 3-class).
630
+ labels (List[str]): Class label names (default: ["REF", "HET", "ALT"] for 3-class).
622
631
  """
623
- report_name = "zygosity" if len(labels) == 3 else "iupac"
632
+ report_name = "zygosity" if len(labels) <= 3 else "iupac"
624
633
  middle = "IUPAC" if report_name == "iupac" else "Zygosity"
625
634
 
626
- msg = f"{middle} Report (on {y_true.size} total genotypes)"
635
+ msg = f"{middle} Report (on {y_pred.size} total genotypes)"
627
636
  self.logger.info(msg)
628
637
 
629
638
  if y_pred_proba is not None:
630
- self.plotter_.plot_metrics(
631
- y_true,
632
- y_pred_proba,
633
- metrics,
634
- label_names=labels,
635
- prefix=report_name,
636
- )
639
+ if self.show_plots:
640
+ self.plotter_.plot_metrics(
641
+ y_true,
642
+ y_pred_proba,
643
+ metrics,
644
+ label_names=labels,
645
+ prefix=report_name,
646
+ )
637
647
 
638
- self.plotter_.plot_confusion_matrix(
639
- y_true, y_pred, label_names=labels, prefix=report_name
640
- )
648
+ if self.show_plots:
649
+ self.plotter_.plot_confusion_matrix(
650
+ y_true, y_pred, label_names=labels, prefix=report_name
651
+ )
641
652
 
642
653
  report: str | dict = classification_report(
643
654
  y_true,
@@ -650,62 +661,63 @@ class BaseNNImputer:
650
661
 
651
662
  if not isinstance(report, dict):
652
663
  msg = "Expected classification_report to return a dict."
653
- self.logger.error(msg)
664
+ self.logger.error(msg, exc_info=True)
654
665
  raise ValueError(msg)
655
666
 
656
- report_subset = {}
657
- for k, v in report.items():
658
- tmp = {}
659
- if isinstance(v, dict) and "support" in v:
660
- for k2, v2 in v.items():
661
- if k2 != "support":
662
- tmp[k2] = v2
663
- if tmp:
664
- report_subset[k] = tmp
665
-
666
- if report_subset:
667
+ if self.show_plots:
668
+ viz = ClassificationReportVisualizer(reset_kwargs=self.plotter_.param_dict)
669
+ try:
670
+ plots = viz.plot_all(
671
+ report, # type: ignore
672
+ title_prefix=f"{self.model_name} {middle} Report",
673
+ show=self.show_plots,
674
+ heatmap_classes_only=True,
675
+ )
676
+ finally:
677
+ viz._reset_mpl_style()
678
+
679
+ for name, fig in plots.items():
680
+ fout = (
681
+ self.plots_dir / f"{report_name}_report_{name}.{self.plot_format}"
682
+ )
683
+ if hasattr(fig, "savefig") and isinstance(fig, Figure):
684
+ fig.savefig(fout, dpi=300, facecolor="#111122")
685
+ plt.close(fig)
686
+ elif hasattr(fig, "write_html") and isinstance(fig, go.Figure):
687
+ fout_html = fout.with_suffix(".html")
688
+ fig.write_html(file=fout_html)
689
+
690
+ SNPioMultiQC.queue_html(
691
+ fout_html,
692
+ panel_id=f"pgsui_{self.model_name.lower()}_{report_name}_radar",
693
+ section=f"PG-SUI: {self.model_name} Model Imputation",
694
+ title=f"{self.model_name} {middle} Radar Plot",
695
+ index_label=name,
696
+ description=f"{self.model_name} {middle} {len(labels)}-base Radar Plot. This radar plot visualizes model performance for three metrics per-class: precision, recall, and F1-score. Each axis represents one of these metrics, allowing for a quick visual assessment of the model's strengths and weaknesses. Higher values towards the outer edge indicate better performance.",
697
+ )
698
+
699
+ if not self.is_haploid_:
700
+ msg = f"Ploidy: {self.ploidy}. Evaluating per genotype (REF, HET, ALT)."
701
+ self.logger.info(msg)
702
+
703
+ report_full = self._additional_metrics(
704
+ y_true,
705
+ y_pred,
706
+ labels=list(range(len(labels))),
707
+ report_names=labels,
708
+ report=report,
709
+ )
710
+
711
+ if self.verbose or self.debug:
667
712
  pm = PrettyMetrics(
668
- report_subset,
669
- precision=3,
713
+ report_full,
714
+ precision=2,
670
715
  title=f"{self.model_name} {middle} Report",
671
716
  )
672
717
  pm.render()
673
718
 
674
719
  with open(self.metrics_dir / f"{report_name}_report.json", "w") as f:
675
- json.dump(report, f, indent=4)
676
-
677
- viz = ClassificationReportVisualizer(reset_kwargs=self.plotter_.param_dict)
678
-
679
- plots = viz.plot_all(
680
- report, # type: ignore
681
- title_prefix=f"{self.model_name} {middle} Report",
682
- show=getattr(self, "show_plots", False),
683
- heatmap_classes_only=True,
684
- )
685
-
686
- for name, fig in plots.items():
687
- fout = self.plots_dir / f"{report_name}_report_{name}.{self.plot_format}"
688
- if hasattr(fig, "savefig") and isinstance(fig, Figure):
689
- fig.savefig(fout, dpi=300, facecolor="#111122")
690
- plt.close(fig)
691
- elif hasattr(fig, "write_html") and isinstance(fig, go.Figure):
692
- fout_html = fout.with_suffix(".html")
693
- fig.write_html(file=fout_html)
694
-
695
- SNPioMultiQC.queue_html(
696
- fout_html,
697
- panel_id=f"pgsui_{self.model_name.lower()}_{report_name}_radar",
698
- section=f"PG-SUI: {self.model_name} Model Imputation",
699
- title=f"{self.model_name} {middle} Radar Plot",
700
- index_label=name,
701
- description=f"{self.model_name} {middle} {len(labels)}-base Radar Plot. This radar plot visualizes model performance for three metrics per-class: precision, recall, and F1-score. Each axis represents one of these metrics, allowing for a quick visual assessment of the model's strengths and weaknesses. Higher values towards the outer edge indicate better performance.",
702
- )
703
-
704
- if not self.is_haploid:
705
- msg = f"Ploidy: {self.ploidy}. Evaluating per allele."
706
- self.logger.info(msg)
707
-
708
- viz._reset_mpl_style()
720
+ json.dump(report_full, f, indent=4)
709
721
 
710
722
  def _compute_hidden_layer_sizes(
711
723
  self,
@@ -713,6 +725,7 @@ class BaseNNImputer:
713
725
  n_outputs: int,
714
726
  n_samples: int,
715
727
  n_hidden: int,
728
+ latent_dim: int,
716
729
  *,
717
730
  alpha: float = 4.0,
718
731
  schedule: str = "pyramid",
@@ -724,182 +737,439 @@ class BaseNNImputer:
724
737
  ) -> list[int]:
725
738
  """Compute hidden layer sizes given problem scale and a layer count.
726
739
 
727
- This method computes a list of hidden layer sizes based on the number of input features, output classes, training samples, and desired hidden layers. The sizes are determined using a specified schedule (pyramid, constant, or linear) and are constrained by minimum and maximum sizes, as well as rounding to multiples of a specified value.
740
+ Notes:
741
+ - Returns sizes for *hidden layers only* (length = n_hidden).
742
+ - Does NOT include the input layer (n_inputs) or the latent layer (latent_dim).
743
+ - Enforces a latent-aware minimum: one discrete level above latent_dim, where a level is `multiple_of`.
744
+ - Enforces *strictly decreasing* hidden sizes (no repeats). This may require bumping `base` upward.
728
745
 
729
746
  Args:
730
- n_inputs (int): Number of input features.
731
- n_outputs (int): Number of output classes.
732
- n_samples (int): Number of training samples.
733
- n_hidden (int): Number of hidden layers.
734
- alpha (float): Scaling factor for base layer size. Default is 4.0.
735
- schedule (Literal["pyramid", "constant", "linear"]): Size schedule. Default is "pyramid".
736
- min_size (int): Minimum layer size. Default is 16.
737
- max_size (int | None): Maximum layer size. Default is None (no limit).
738
- multiple_of (int): Round layer sizes to be multiples of this. Default is 8.
739
- decay (float | None): Decay factor for "pyramid" schedule. If None, it is computed automatically. Default is None.
740
- cap_by_inputs (bool): If True, cap layer sizes to n_inputs. Default is True.
747
+ n_inputs: Number of input features (e.g., flattened one-hot: num_features * num_classes).
748
+ n_outputs: Number of output classes (often equals num_classes).
749
+ n_samples: Number of training samples.
750
+ n_hidden: Number of hidden layers (excluding input and latent layers).
751
+ latent_dim: Latent dimensionality (not returned, used only to set a floor).
752
+ alpha: Scaling factor for base layer size.
753
+ schedule: Size schedule ("pyramid" or "linear").
754
+ min_size: Minimum layer size floor before latent-aware adjustment.
755
+ max_size: Maximum layer size cap. If None, a heuristic cap is used.
756
+ multiple_of: Hidden sizes are multiples of this value.
757
+ decay: Pyramid decay factor. If None, computed to land near the target.
758
+ cap_by_inputs: If True, cap layer sizes to n_inputs.
741
759
 
742
760
  Returns:
743
- list[int]: List of hidden layer sizes.
761
+ list[int]: Hidden layer sizes (len = n_hidden).
744
762
 
745
763
  Raises:
746
- ValueError: If n_hidden < 0 or if alpha * (n_inputs + n_outputs) <= 0 or if schedule is unknown.
747
- TypeError: If any argument is not of the expected type.
748
-
749
- Notes:
750
- - If n_hidden is 0, returns an empty list.
751
- - The base layer size is computed as ceil(n_samples / (alpha * (n_inputs + n_outputs))).
752
- - The sizes are adjusted according to the specified schedule and constraints.
764
+ ValueError: On invalid arguments or conflicting constraints.
753
765
  """
766
+ # ----------------------------
767
+ # Basic validation
768
+ # ----------------------------
754
769
  if n_hidden < 0:
755
770
  msg = f"n_hidden must be >= 0, got {n_hidden}."
756
771
  self.logger.error(msg)
757
772
  raise ValueError(msg)
758
773
 
759
- if schedule not in {"pyramid", "constant", "linear"}:
760
- msg = f"Unknown schedule '{schedule}'. Use 'pyramid', 'constant', or 'linear'."
774
+ if n_hidden == 0:
775
+ return []
776
+
777
+ if n_inputs <= 0:
778
+ msg = f"n_inputs must be > 0, got {n_inputs}."
761
779
  self.logger.error(msg)
762
780
  raise ValueError(msg)
763
781
 
764
- if n_hidden == 0:
765
- return []
782
+ if n_outputs <= 0:
783
+ msg = f"n_outputs must be > 0, got {n_outputs}."
784
+ self.logger.error(msg)
785
+ raise ValueError(msg)
766
786
 
767
- denom = float(alpha) * float(n_inputs + n_outputs)
787
+ if n_samples <= 0:
788
+ msg = f"n_samples must be > 0, got {n_samples}."
789
+ self.logger.error(msg)
790
+ raise ValueError(msg)
768
791
 
769
- if denom <= 0:
770
- msg = f"alpha * (n_inputs + n_outputs) must be > 0, got {denom}."
792
+ if latent_dim <= 0:
793
+ msg = f"latent_dim must be > 0, got {latent_dim}."
771
794
  self.logger.error(msg)
772
795
  raise ValueError(msg)
773
796
 
774
- base = int(np.ceil(float(n_samples) / denom))
797
+ if multiple_of <= 0:
798
+ msg = f"multiple_of must be > 0, got {multiple_of}."
799
+ self.logger.error(msg)
800
+ raise ValueError(msg)
801
+
802
+ if alpha <= 0:
803
+ msg = f"alpha must be > 0, got {alpha}."
804
+ self.logger.error(msg)
805
+ raise ValueError(msg)
806
+
807
+ schedule = str(schedule).lower().strip()
808
+ if schedule not in {"pyramid", "linear"}:
809
+ msg = f"Invalid schedule '{schedule}'. Must be 'pyramid' or 'linear'."
810
+ self.logger.error(msg)
811
+ raise ValueError(msg)
812
+
813
+ # ----------------------------
814
+ # Latent-aware minimum floor
815
+ # ----------------------------
816
+ # Smallest multiple_of strictly greater than latent_dim
817
+ min_hidden_floor = int(np.ceil((latent_dim + 1) / multiple_of) * multiple_of)
818
+ effective_min = max(int(min_size), min_hidden_floor)
819
+
820
+ if cap_by_inputs and n_inputs < effective_min:
821
+ msg = (
822
+ "Cannot satisfy latent-aware minimum hidden size with cap_by_inputs=True. "
823
+ f"Required hidden size >= {effective_min} (one level above latent_dim={latent_dim}), "
824
+ f"but n_inputs={n_inputs}. Set cap_by_inputs=False or reduce latent_dim/multiple_of."
825
+ )
826
+ self.logger.error(msg)
827
+ raise ValueError(msg)
828
+
829
+ # ----------------------------
830
+ # Infer num_features (if using flattened one-hot: n_inputs = num_features * num_classes)
831
+ # ----------------------------
832
+ if n_inputs % n_outputs == 0:
833
+ num_features = n_inputs // n_outputs
834
+ else:
835
+ num_features = n_inputs
836
+ self.logger.warning(
837
+ "n_inputs is not divisible by n_outputs; falling back to num_features=n_inputs "
838
+ f"(n_inputs={n_inputs}, n_outputs={n_outputs}). If using one-hot flattening, "
839
+ "pass n_outputs=num_classes so num_features can be inferred correctly."
840
+ )
775
841
 
842
+ # ----------------------------
843
+ # Base size heuristic (feature-matrix aware; avoids collapse for huge n_inputs)
844
+ # ----------------------------
845
+ obs_scale = (float(n_samples) * float(num_features)) / float(
846
+ num_features + n_outputs
847
+ )
848
+ base = int(np.ceil(float(alpha) * np.sqrt(obs_scale)))
849
+
850
+ # ----------------------------
851
+ # Determine max_size
852
+ # ----------------------------
776
853
  if max_size is None:
777
- max_size = max(n_inputs, base)
854
+ max_size = max(int(n_inputs), int(base), int(effective_min))
855
+
856
+ if cap_by_inputs:
857
+ max_size = min(int(max_size), int(n_inputs))
858
+ else:
859
+ max_size = int(max_size)
860
+
861
+ if max_size < effective_min:
862
+ msg = (
863
+ f"max_size ({max_size}) must be >= effective_min ({effective_min}), where effective_min "
864
+ f"is max(min_size={min_size}, one-level-above latent_dim={latent_dim})."
865
+ )
866
+ self.logger.error(msg)
867
+ raise ValueError(msg)
868
+
869
+ # Round base up to a multiple and clip to bounds
870
+ base = int(np.clip(base, effective_min, max_size))
871
+ base = int(np.ceil(base / multiple_of) * multiple_of)
872
+ base = int(np.clip(base, effective_min, max_size))
873
+
874
+ # ----------------------------
875
+ # Enforce "no repeats" feasibility in discrete levels
876
+ # Need n_hidden distinct multiples between base and effective_min:
877
+ # base >= effective_min + (n_hidden - 1) * multiple_of
878
+ # ----------------------------
879
+ required_min_base = effective_min + (n_hidden - 1) * multiple_of
880
+
881
+ if required_min_base > max_size:
882
+ msg = (
883
+ "Cannot build strictly-decreasing (no-repeat) hidden sizes under current constraints. "
884
+ f"Need base >= {required_min_base} to fit n_hidden={n_hidden} distinct layers "
885
+ f"with multiple_of={multiple_of} down to effective_min={effective_min}, "
886
+ f"but max_size={max_size}. Reduce n_hidden, reduce multiple_of, lower latent_dim/min_size, "
887
+ "or increase max_size / set cap_by_inputs=False."
888
+ )
889
+ self.logger.error(msg)
890
+ raise ValueError(msg)
778
891
 
779
- base = int(np.clip(base, min_size, max_size))
892
+ if base < required_min_base:
893
+ # Bump base upward so a strict staircase is possible
894
+ base = required_min_base
895
+ base = int(np.ceil(base / multiple_of) * multiple_of)
896
+ base = int(np.clip(base, effective_min, max_size))
780
897
 
781
- if schedule == "constant":
782
- sizes = np.full(shape=(n_hidden,), fill_value=base, dtype=float)
898
+ # Work in "levels" of multiple_of for guaranteed uniqueness
899
+ start_level = base // multiple_of
900
+ end_level = effective_min // multiple_of
901
+
902
+ # Sanity: distinct levels available
903
+ if (start_level - end_level) < (n_hidden - 1):
904
+ # This should not happen due to required_min_base logic, but keep a hard guard.
905
+ msg = (
906
+ "Internal constraint failure: insufficient discrete levels to enforce no repeats. "
907
+ f"start_level={start_level}, end_level={end_level}, n_hidden={n_hidden}."
908
+ )
909
+ self.logger.error(msg)
910
+ raise ValueError(msg)
911
+
912
+ # ----------------------------
913
+ # Build schedule in level space (integers), then convert to sizes
914
+ # ----------------------------
915
+ if n_hidden == 1:
916
+ levels = np.array([start_level], dtype=int)
783
917
 
784
918
  elif schedule == "linear":
785
- target = max(min_size, min(base, base // 4))
786
- sizes = (
787
- np.array([base], dtype=float)
788
- if n_hidden == 1
789
- else np.linspace(base, target, num=n_hidden, dtype=float)
919
+ # Linear interpolation in level space, then strictify
920
+ levels = np.round(np.linspace(start_level, end_level, num=n_hidden)).astype(
921
+ int
790
922
  )
791
923
 
924
+ # Enforce bounds then strict decrease
925
+ levels = np.clip(levels, end_level, start_level)
926
+
927
+ for i in range(1, n_hidden):
928
+ if levels[i] >= levels[i - 1]:
929
+ levels[i] = levels[i - 1] - 1
930
+
931
+ if levels[-1] < end_level:
932
+ msg = (
933
+ "Failed to enforce strictly-decreasing linear schedule without violating the floor. "
934
+ f"(levels[-1]={levels[-1]} < end_level={end_level}). "
935
+ "Reduce n_hidden or multiple_of, or increase max_size."
936
+ )
937
+ self.logger.error(msg)
938
+ raise ValueError(msg)
939
+
940
+ # Force exact floor at the end (still strict because we have enough room by construction)
941
+ levels[-1] = end_level
942
+ for i in range(n_hidden - 2, -1, -1):
943
+ if levels[i] <= levels[i + 1]:
944
+ levels[i] = levels[i + 1] + 1
945
+
946
+ if levels[0] > start_level:
947
+ # If this happens, we would need an even larger base; handle by raising base once.
948
+ needed_base = int(levels[0] * multiple_of)
949
+ if needed_base > max_size:
950
+ msg = (
951
+ "Cannot enforce strictly-decreasing linear schedule after floor anchoring; "
952
+ f"would require base={needed_base} > max_size={max_size}."
953
+ )
954
+ self.logger.error(msg)
955
+ raise ValueError(msg)
956
+ # Rebuild with bumped base
957
+ start_level = needed_base // multiple_of
958
+ levels = np.arange(start_level, start_level - n_hidden, -1, dtype=int)
959
+ levels[-1] = end_level # keep floor
960
+ # Ensure strict with backward adjust
961
+ for i in range(n_hidden - 2, -1, -1):
962
+ if levels[i] <= levels[i + 1]:
963
+ levels[i] = levels[i + 1] + 1
964
+
792
965
  elif schedule == "pyramid":
793
- if n_hidden == 1:
794
- sizes = np.array([base], dtype=float)
966
+ # Geometric decay in level space (more aggressive early taper than linear)
967
+ if decay is not None:
968
+ dcy = float(decay)
795
969
  else:
970
+ # Choose decay to land exactly at end_level (in float space)
971
+ dcy = (float(end_level) / float(start_level)) ** (
972
+ 1.0 / float(n_hidden - 1)
973
+ )
974
+
975
+ # Keep it in a sensible range
976
+ dcy = float(np.clip(dcy, 0.05, 0.99))
977
+
978
+ exponents = np.arange(n_hidden, dtype=float)
979
+ levels_float = float(start_level) * (dcy**exponents)
980
+
981
+ levels = np.round(levels_float).astype(int)
982
+ levels = np.clip(levels, end_level, start_level)
983
+
984
+ # Anchor the last layer at the floor, then strictify backward
985
+ levels[-1] = end_level
986
+ for i in range(n_hidden - 2, -1, -1):
987
+ if levels[i] <= levels[i + 1]:
988
+ levels[i] = levels[i + 1] + 1
989
+
990
+ # If we overshot the start, bump base (once) if possible, then rebuild
991
+ if levels[0] > start_level:
992
+ needed_base = int(levels[0] * multiple_of)
993
+ if needed_base > max_size:
994
+ msg = (
995
+ "Cannot enforce strictly-decreasing pyramid schedule; "
996
+ f"would require base={needed_base} > max_size={max_size}."
997
+ )
998
+ self.logger.error(msg)
999
+ raise ValueError(msg)
1000
+
1001
+ start_level = needed_base // multiple_of
1002
+ # Recompute with new start_level and same decay (or recompute decay if decay is None)
796
1003
  if decay is None:
797
- target = max(min_size, base // 4)
798
- if base <= 0 or target <= 0:
799
- dcy = 1.0
800
- else:
801
- dcy = (target / float(base)) ** (1.0 / (n_hidden - 1))
802
- dcy = float(np.clip(dcy, 0.25, 0.99))
803
- exponents = np.arange(n_hidden, dtype=float)
804
- sizes = base * (dcy**exponents)
1004
+ dcy = (float(end_level) / float(start_level)) ** (
1005
+ 1.0 / float(n_hidden - 1)
1006
+ )
1007
+ dcy = float(np.clip(dcy, 0.05, 0.99))
1008
+
1009
+ levels_float = float(start_level) * (dcy**exponents)
1010
+ levels = np.round(levels_float).astype(int)
1011
+ levels = np.clip(levels, end_level, start_level)
1012
+ levels[-1] = end_level
1013
+ for i in range(n_hidden - 2, -1, -1):
1014
+ if levels[i] <= levels[i + 1]:
1015
+ levels[i] = levels[i + 1] + 1
805
1016
 
806
1017
  else:
807
- msg = f"Unknown schedule '{schedule}'. Use 'pyramid', 'constant', or 'linear'."
1018
+ msg = f"Unknown schedule '{schedule}'. Use 'pyramid' or 'linear' (constant disallowed with no repeats)."
808
1019
  self.logger.error(msg)
809
1020
  raise ValueError(msg)
810
1021
 
811
- sizes = np.clip(sizes, min_size, max_size)
1022
+ # Convert levels -> sizes
1023
+ sizes = (levels * multiple_of).astype(int)
812
1024
 
813
- if cap_by_inputs:
814
- sizes = np.minimum(sizes, float(n_inputs))
1025
+ # Final clip (should be redundant, but safe)
1026
+ sizes = np.clip(sizes, effective_min, max_size).astype(int)
815
1027
 
816
- sizes = (np.ceil(sizes / multiple_of) * multiple_of).astype(int)
817
- sizes = np.minimum.accumulate(sizes)
818
- return np.clip(sizes, min_size, max_size).astype(int).tolist()
1028
+ # Final strict no-repeat assertion
1029
+ if np.any(np.diff(sizes) >= 0):
1030
+ msg = (
1031
+ "Internal error: produced non-decreasing or repeated hidden sizes after strict enforcement. "
1032
+ f"sizes={sizes.tolist()}"
1033
+ )
1034
+ self.logger.error(msg)
1035
+ raise ValueError(msg)
819
1036
 
820
- def _class_weights_from_zygosity(self, X: np.ndarray) -> torch.Tensor:
821
- """Class-balanced weights for 0/1/2 (handles haploid collapse if needed).
1037
+ return sizes.tolist()
822
1038
 
823
- This method computes class-balanced weights for the genotype classes (0/1/2) based on the provided genotype matrix. It handles cases where the data is haploid by collapsing the ALT class to 1, effectively treating the problem as binary classification (REF vs ALT). The weights are calculated using a class-balanced weighting scheme that considers the frequency of each class in the training data, with parameters for beta and maximum ratio to control the weighting behavior. The resulting weights are returned as a PyTorch tensor on the current device.
1039
+ def _class_weights_from_zygosity(
1040
+ self,
1041
+ X: np.ndarray,
1042
+ train_mask: Optional[np.ndarray] = None,
1043
+ *,
1044
+ inverse: bool = False,
1045
+ normalize: bool = False,
1046
+ power: float = 1.0,
1047
+ max_ratio: float | None = None,
1048
+ eps: float = 1e-12,
1049
+ ) -> torch.Tensor:
1050
+ """Compute class weights for zygosity labels.
824
1051
 
825
- Args:
826
- X (np.ndarray): 0/1/2 with -1 for missing.
1052
+ If inverse=False (default):
1053
+ w_c = N / (K * n_c) ("balanced")
1054
+
1055
+ If inverse=True:
1056
+ w_c = N / n_c (same ratios, scaled by K)
1057
+
1058
+ If power != 1.0:
1059
+ w_c <- w_c ** power (amplifies or softens imbalance handling)
1060
+
1061
+ If normalize=True:
1062
+ rescales nonzero weights so mean(nonzero_weights) == 1.
827
1063
 
828
1064
  Returns:
829
- torch.Tensor: Weights on current device.
1065
+ torch.Tensor: Class weights of shape (num_classes,) on self.device.
830
1066
  """
831
- y = X[X != -1].ravel().astype(np.int64)
832
- if y.size == 0:
833
- return torch.ones(
834
- self.num_classes_, dtype=torch.float32, device=self.device
835
- )
1067
+ y = np.asarray(X).ravel().astype(np.int8)
836
1068
 
837
- return self._class_balanced_weights_from_mask(
838
- y=y,
839
- train_mask=np.ones_like(y, dtype=bool),
840
- num_classes=self.num_classes_,
841
- beta=self.beta,
842
- max_ratio=self.max_ratio,
843
- mode="allele", # 1D int vector
844
- ).to(self.device)
1069
+ m = y >= 0
1070
+ if train_mask is not None:
1071
+ tm = np.asarray(train_mask, dtype=bool).ravel()
1072
+ if tm.shape != y.shape:
1073
+ msg = "train_mask must have the same shape as X."
1074
+ self.logger.error(msg)
1075
+ raise ValueError(msg)
1076
+ m &= tm
1077
+
1078
+ is_hap = bool(getattr(self, "is_haploid_", False))
1079
+ num_classes = 2 if is_hap else int(self.num_classes_)
1080
+
1081
+ if not np.any(m):
1082
+ return torch.ones(num_classes, dtype=torch.long, device=self.device)
1083
+
1084
+ if is_hap:
1085
+ y = y.copy()
1086
+ y[(y == 2) & m] = 1
1087
+
1088
+ y_m = y[m]
1089
+ if y_m.size:
1090
+ ymin = int(y_m.min())
1091
+ ymax = int(y_m.max())
1092
+ if ymin < 0 or ymax >= num_classes:
1093
+ msg = (
1094
+ f"Found out-of-range labels under mask: min={ymin}, max={ymax}, "
1095
+ f"expected in [0, {num_classes - 1}]."
1096
+ )
1097
+ self.logger.error(msg)
1098
+ raise ValueError(msg)
845
1099
 
846
- @staticmethod
847
- def _normalize_class_weights(
848
- weights: torch.Tensor | None,
849
- ) -> torch.Tensor | None:
850
- """Normalize class weights once to keep loss scale stable.
1100
+ counts = np.bincount(y_m, minlength=num_classes).astype(np.float32)
1101
+ N = float(counts.sum())
1102
+ K = float(num_classes)
851
1103
 
852
- Args:
853
- weights (torch.Tensor | None): Class weights to normalize.
1104
+ w = np.zeros(num_classes, dtype=np.float32)
1105
+ nz = counts > 0
854
1106
 
855
- Returns:
856
- torch.Tensor | None: Normalized class weights or None if input is None.
857
- """
858
- if weights is None:
859
- return None
860
- return weights / weights.mean().clamp_min(1e-8)
1107
+ if np.any(nz):
1108
+ if inverse:
1109
+ w[nz] = N / (counts[nz] + eps)
1110
+ else:
1111
+ w[nz] = N / (K * (counts[nz] + eps))
861
1112
 
862
- def _get_float_genotypes(self, *, copy: bool = True) -> np.ndarray:
863
- """Float32 0/1/2 matrix with NaNs for missing, cached per dataset.
1113
+ # Amplify / soften class contrast
1114
+ if power <= 0.0:
1115
+ msg = "power must be > 0."
1116
+ self.logger.error(msg)
1117
+ raise ValueError(msg)
1118
+ if power != 1.0:
1119
+ w[nz] = np.power(w[nz], power)
864
1120
 
865
- Args:
866
- copy (bool): If True, return a copy of the cached array. Default is True.
1121
+ if np.any(~nz):
1122
+ self.logger.warning(
1123
+ "Some classes have zero count under the provided mask: "
1124
+ f"{np.where(~nz)[0].tolist()}. Setting their weights to 0."
1125
+ )
867
1126
 
868
- Returns:
869
- np.ndarray: Float32 genotype matrix with NaNs for missing values.
870
- """
871
- cache = self._float_genotype_cache
872
- current = self.pgenc.genotypes_012
873
- if cache is None or cache.shape != current.shape or cache.dtype != np.float32:
874
- arr = np.asarray(current, dtype=np.float32)
875
- arr = np.where(arr < 0, np.nan, arr)
876
- self._float_genotype_cache = arr
877
- cache = arr
878
- return cache.copy() if copy else cache
879
-
880
- def _sim_mask_cache_key(self) -> tuple | None:
881
- """Key for caching simulated-missing masks."""
882
- if not getattr(self, "simulate_missing", False):
883
- return None
884
- shape = tuple(self.pgenc.genotypes_012.shape)
885
- return (
886
- id(self.genotype_data),
887
- self.sim_strategy,
888
- round(float(self.sim_prop), 6),
889
- self.seed,
890
- shape,
1127
+ # Cap ratio among observed classes
1128
+ if max_ratio is not None and np.any(nz):
1129
+ cap = float(max_ratio)
1130
+ if cap <= 1.0:
1131
+ msg = "max_ratio must be > 1.0 or None."
1132
+ self.logger.error(msg)
1133
+ raise ValueError(msg)
1134
+
1135
+ wmin = max(float(w[nz].min()), eps)
1136
+ wmax = wmin * cap
1137
+ w[nz] = np.clip(w[nz], wmin, wmax)
1138
+
1139
+ # Optional normalization: mean(nonzero) -> 1.0
1140
+ if normalize and np.any(nz):
1141
+ mean_nz = float(w[nz].mean())
1142
+ if mean_nz > 0.0:
1143
+ w[nz] /= mean_nz
1144
+ else:
1145
+ self.logger.warning(
1146
+ "normalize=True requested, but mean of nonzero weights is not positive; skipping normalization."
1147
+ )
1148
+
1149
+ self.logger.debug(f"Class counts: {counts.astype(np.int8)}")
1150
+ self.logger.debug(
1151
+ f"Class weights (inverse={inverse}, power={power}, normalize={normalize}): {w}"
891
1152
  )
892
1153
 
893
- def _one_hot_encode_012(self, X: np.ndarray | torch.Tensor) -> torch.Tensor:
894
- """One-hot 0/1/2; -1 rows are all-zeros (B, L, K).
1154
+ return torch.as_tensor(w, dtype=torch.long, device=self.device)
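A quick numeric illustration of the weighting formulas documented in the new `_class_weights_from_zygosity` (the class counts are made up):

    import numpy as np

    counts = np.array([700.0, 250.0, 50.0])   # hypothetical REF/HET/ALT counts
    N, K = counts.sum(), len(counts)

    N / (K * counts)   # inverse=False ("balanced"): [0.476, 1.333, 6.667]
    N / counts         # inverse=True:               [1.429, 4.000, 20.000]

    # with max_ratio=5.0, nonzero weights are clipped to [wmin, 5 * wmin]:
    w = N / (K * counts)
    np.clip(w, w.min(), 5.0 * w.min())        # [0.476, 1.333, 2.381]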
895
1155
 
896
- This method performs one-hot encoding of the input genotype data (0, 1, 2) while handling missing values represented by -1. The output is a tensor of shape (B, L, K), where B is the batch size, L is the number of features, and K is the number of classes.
1156
+ def _one_hot_encode_012(
1157
+ self, X: np.ndarray | torch.Tensor, num_classes: int | None
1158
+ ) -> torch.Tensor:
1159
+ """One-hot encode genotype calls. Missing inputs (<0) result in a vector of -1s.
897
1160
 
898
1161
  Args:
899
- X (np.ndarray | torch.Tensor): The input data to be one-hot encoded, either as a NumPy array or a PyTorch tensor.
1162
+ X (np.ndarray | torch.Tensor): Input genotype calls as integers (0,1, 2, etc.).
1163
+ num_classes (int | None): Number of classes (K). If None, uses self.num_classes_.
900
1164
 
901
1165
  Returns:
902
- torch.Tensor: A one-hot encoded tensor of shape (B, L, K), where B is the batch size, L is the number of features, and K is the number of classes.
1166
+ torch.Tensor: One-hot encoded tensor of shape (B, L, K) with dtype
1167
+ ``torch.long``. Valid calls are 0/1, missing calls are all -1.
1168
+
1169
+ Notes:
1170
+ - Valid classes must be integers in [0, K-1]
1171
+ - Missing is any value < 0; these positions become [-1, -1, ..., -1]
1172
+ - If K==2 and values are in {0,2} (no 1s), map 2->1.
903
1173
  """
904
1174
  Xt = (
905
1175
  torch.from_numpy(X).to(self.device)
@@ -907,212 +1177,680 @@ class BaseNNImputer:
907
1177
  else X.to(self.device)
908
1178
  )
909
1179
 
910
- # B=batch, L=features, K=classes
1180
+ # Make sure we have integer class labels
1181
+ if Xt.dtype.is_floating_point:
1182
+ # Convert NaN -> -1 and cast to long
1183
+ Xt = torch.nan_to_num(Xt, nan=-1.0).long()
1184
+ else:
1185
+ Xt = Xt.long()
1186
+
911
1187
  B, L = Xt.shape
912
- K = self.num_classes_
913
- X_ohe = torch.zeros(B, L, K, dtype=torch.float32, device=self.device)
914
- valid = Xt != -1
915
- idx = Xt[valid].long()
1188
+ K = int(num_classes) if num_classes is not None else int(self.num_classes_)
1189
+
1190
+ # Missing is anything < 0 (covers -1, -9, etc.)
1191
+ valid = Xt >= 0
1192
+
1193
+ # If binary mode but data is {0,2}
1194
+ # (haploid-like or "ref vs non-ref"), map 2->1
1195
+ if K == 2:
1196
+ has_het = torch.any(valid & (Xt == 1))
1197
+ has_alt2 = torch.any(valid & (Xt == 2))
1198
+ if has_alt2 and not has_het:
1199
+ Xt = Xt.clone()
1200
+ Xt[valid & (Xt == 2)] = 1
1201
+
1202
+ # Now enforce the one-hot precondition
1203
+ if torch.any(valid & (Xt >= K)):
1204
+ bad_vals = torch.unique(Xt[valid & (Xt >= K)]).detach().cpu().tolist()
1205
+ all_vals = torch.unique(Xt[valid]).detach().cpu().tolist()
1206
+ msg = f"_one_hot_encode_012 received class values outside [0, {K-1}]. num_classes={K}, offending_values={bad_vals}, observed_values={all_vals}. Upstream encoding mismatch (e.g., passing 0/1/2 with num_classes=2)."
1207
+ self.logger.error(msg)
1208
+ raise ValueError(msg)
1209
+
1210
+ # Initialize with -1 so missing positions are represented as [-1, -1, ..., -1]
1211
+ X_ohe = torch.full((B, L, K), -1.0, dtype=torch.long, device=self.device)
1212
+
1213
+ idx = Xt[valid]
916
1214
 
917
1215
  if idx.numel() > 0:
918
- X_ohe[valid] = F.one_hot(idx, num_classes=K).float()
1216
+ # Overwrite valid positions (which were -1) with the correct one-hot vectors
1217
+ X_ohe[valid] = F.one_hot(idx, num_classes=K).long()
919
1218
 
920
1219
  return X_ohe
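As a quick illustration of the contract documented above (valid calls become one-hot rows, anything negative becomes a row of -1s), the standalone snippet below reproduces the same behaviour with illustrative inputs:

    import torch
    import torch.nn.functional as F

    X = torch.tensor([[0, 1, -1],
                      [2, -1, 0]])      # (B=2, L=3) genotype calls; -1 marks missing
    K = 3
    valid = X >= 0
    ohe = torch.full((*X.shape, K), -1, dtype=torch.long)
    ohe[valid] = F.one_hot(X[valid], num_classes=K)
    # ohe[0, 2] -> tensor([-1, -1, -1]); ohe[1, 0] -> tensor([0, 0, 1])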
921
1220
 
922
- def _eval_for_pruning(
923
- self,
924
- *,
925
- model: torch.nn.Module,
926
- X_val: np.ndarray,
927
- params: dict,
928
- metric: str,
929
- objective_mode: bool = True,
930
- do_latent_infer: bool = False,
931
- latent_steps: int = 50,
932
- latent_lr: float = 1e-2,
933
- latent_weight_decay: float = 0.0,
934
- latent_seed: int = 123,
935
- _latent_cache: dict | None = None,
936
- _latent_cache_key: str | None = None,
937
- eval_mask_override: np.ndarray | None = None,
938
- ) -> float:
939
- """Compute a scalar metric (to MAXIMIZE) on a fixed validation set.
940
-
941
- This method evaluates the model on a validation dataset and computes a specified metric, which is used for pruning decisions during hyperparameter tuning. It supports optional latent inference to optimize latent representations before evaluation. The method handles potential issues with non-finite metric values by returning negative infinity, making it easier to prune poorly performing trials.
1221
+ def decode_012(
1222
+ self, X: np.ndarray | pd.DataFrame | list[list[int]], is_nuc: bool = False
1223
+ ) -> np.ndarray:
1224
+ """Decode 012-encodings to IUPAC chars with metadata repair.
1225
+
1226
+ This method converts genotype calls encoded as integers (0, 1, 2, etc.) into their corresponding IUPAC nucleotide codes. It supports two modes of decoding:
1227
+ 1. Nucleotide mode (`is_nuc=True`): Decodes integer codes (0-9) directly to IUPAC nucleotide codes.
1228
+ 2. Metadata mode (`is_nuc=False`): Uses reference and alternate allele metadata to determine the appropriate IUPAC codes. If metadata is missing or inconsistent, the method attempts to repair the decoding by scanning the source SNP data for valid IUPAC codes.
942
1229
 
943
1230
  Args:
944
- model (torch.nn.Module): The model to evaluate.
945
- X_val (np.ndarray): Validation data.
946
- params (dict): Model parameters.
947
- metric (str): Metric name to return.
948
- objective_mode (bool): If True, use objective-mode evaluation. Default is True.
949
- do_latent_infer (bool): If True, perform latent inference before evaluation. Default
950
- latent_steps (int): Number of steps for latent inference. Default is 50.
951
- latent_lr (float): Learning rate for latent inference. Default is 1e-2
952
- latent_weight_decay (float): Weight decay for latent inference. Default is 0.0.
953
- latent_seed (int): Random seed for latent inference. Default is 123.
954
- _latent_cache (dict | None): Optional cache for storing/retrieving optimized latents
955
- _latent_cache_key (str | None): Key for storing/retrieving in _latent_cache.
956
- eval_mask_override (np.ndarray | None): Optional mask to override default evaluation mask.
1231
+ X (np.ndarray | pd.DataFrame | list[list[int]]): Input genotype calls as integers. Can be a NumPy array, Pandas DataFrame, or nested list.
1232
+ is_nuc (bool): If True, decode 0-9 nucleotide codes; else use ref/alt metadata. Defaults to False.
957
1233
 
958
1234
  Returns:
959
- float: The computed metric value to maximize. Returns -inf on failure.
1235
+ np.ndarray: IUPAC strings as a 2D array of shape (n_samples, n_snps).
1236
+
1237
+ Notes:
1238
+ - The method normalizes input values to handle various formats, including strings, lists, and arrays.
1239
+ - It uses a predefined mapping of IUPAC codes to nucleotide bases and vice versa.
1240
+ - Missing or invalid codes are represented as 'N' if they can't be resolved.
1241
+ - The method includes repair logic to infer missing metadata from the source SNP data when necessary.
1242
+
1243
+ Raises:
1244
+ ValueError: If the input cannot be converted to a pandas DataFrame.
960
1245
  """
961
- optimized_val_latents = None
962
-
963
- # Optional latent inference path for models that need it.
964
- if do_latent_infer and hasattr(self, "_latent_infer_for_eval"):
965
- optimized_val_latents = self._latent_infer_for_eval( # type: ignore
966
- model=model,
967
- X_val=X_val,
968
- steps=latent_steps,
969
- lr=latent_lr,
970
- weight_decay=latent_weight_decay,
971
- seed=latent_seed,
972
- cache=_latent_cache,
973
- cache_key=_latent_cache_key,
1246
+ df = validate_input_type(X, return_type="df")
1247
+
1248
+ if not isinstance(df, pd.DataFrame):
1249
+ msg = f"Expected a pandas.DataFrame in 'decode_012', but got: {type(df)}."
1250
+ self.logger.error(msg)
1251
+ raise ValueError(msg)
1252
+
1253
+ # IUPAC Definitions
1254
+ iupac_to_bases: dict[str, set[str]] = {
1255
+ "A": {"A"},
1256
+ "C": {"C"},
1257
+ "G": {"G"},
1258
+ "T": {"T"},
1259
+ "R": {"A", "G"},
1260
+ "Y": {"C", "T"},
1261
+ "S": {"G", "C"},
1262
+ "W": {"A", "T"},
1263
+ "K": {"G", "T"},
1264
+ "M": {"A", "C"},
1265
+ "B": {"C", "G", "T"},
1266
+ "D": {"A", "G", "T"},
1267
+ "H": {"A", "C", "T"},
1268
+ "V": {"A", "C", "G"},
1269
+ "N": set(),
1270
+ }
1271
+ bases_to_iupac = {
1272
+ frozenset(v): k for k, v in iupac_to_bases.items() if k != "N"
1273
+ }
1274
+ missing_codes = {"", ".", "N", "NONE", "-", "?", "./.", ".|.", "NAN", "nan"}
1275
+
1276
+ def _normalize_iupac(value: object) -> str | None:
1277
+ """Normalize an input into a single IUPAC code token or None."""
1278
+ if value is None:
1279
+ return None
1280
+
1281
+ # Bytes -> str (make type narrowing explicit)
1282
+ if isinstance(value, (bytes, np.bytes_)):
1283
+ value = bytes(value).decode("utf-8", errors="ignore")
1284
+
1285
+ # Handle list/tuple/array/Series: take first valid
1286
+ if isinstance(value, (list, tuple, pd.Series, np.ndarray)):
1287
+ # Convert Series to numpy array for consistent behavior
1288
+ if isinstance(value, pd.Series):
1289
+ arr = value.to_numpy()
1290
+ else:
1291
+ arr = value
1292
+
1293
+ # Scalar numpy array fast path
1294
+ if isinstance(arr, np.ndarray) and arr.ndim == 0:
1295
+ return _normalize_iupac(arr.item())
1296
+
1297
+ # Empty sequence/array
1298
+ if len(arr) == 0:
1299
+ return None
1300
+
1301
+ # First valid element wins
1302
+ for item in arr:
1303
+ code = _normalize_iupac(item)
1304
+ if code is not None:
1305
+ return code
1306
+ return None
1307
+
1308
+ s = str(value).upper().strip()
1309
+ if not s or s in missing_codes:
1310
+ return None
1311
+
1312
+ if "," in s:
1313
+ for tok in (t.strip() for t in s.split(",")):
1314
+ if tok and tok not in missing_codes and tok in iupac_to_bases:
1315
+ return tok
1316
+ return None
1317
+
1318
+ return s if s in iupac_to_bases else None
1319
+
1320
+ codes_df = df.apply(pd.to_numeric, errors="coerce")
1321
+ codes = codes_df.fillna(-1).astype(np.int8).to_numpy()
1322
+ n_rows, n_cols = codes.shape
1323
+
1324
+ if is_nuc:
1325
+ iupac_list = np.array(
1326
+ ["A", "C", "G", "T", "W", "R", "M", "K", "Y", "S"], dtype="<U1"
974
1327
  )
975
- # Retrieve the optimized latents from the cache
976
- if _latent_cache is not None and _latent_cache_key in _latent_cache:
977
- optimized_val_latents = _latent_cache[_latent_cache_key]
978
-
979
- if getattr(self, "_tune_eval_slice", None) is not None:
980
- X_val = X_val[self._tune_eval_slice]
981
- if eval_mask_override is not None:
982
- eval_mask_override = eval_mask_override[self._tune_eval_slice]
983
-
984
- # Child's evaluator now accepts the pre-computed latents
985
- metrics = self._evaluate_model( # type: ignore
986
- X_val=X_val,
987
- model=model,
988
- params=params,
989
- objective_mode=objective_mode,
990
- latent_vectors_val=optimized_val_latents,
991
- eval_mask_override=eval_mask_override,
992
- )
1328
+ out = np.full((n_rows, n_cols), "N", dtype="<U1")
1329
+ mask = (codes >= 0) & (codes <= 9)
1330
+ out[mask] = iupac_list[codes[mask]]
1331
+ return out
1332
+
1333
+ # Metadata fetch
1334
+ ref_alleles = getattr(self.genotype_data, "ref", [])
1335
+ alt_alleles = getattr(self.genotype_data, "alt", [])
1336
+
1337
+ if len(ref_alleles) != n_cols:
1338
+ ref_alleles = getattr(self, "_ref", [None] * n_cols)
1339
+ if len(alt_alleles) != n_cols:
1340
+ alt_alleles = getattr(self, "_alt", [None] * n_cols)
1341
+
1342
+ # Ensure list length matches
1343
+ if len(ref_alleles) != n_cols:
1344
+ ref_alleles = [None] * n_cols
1345
+ if len(alt_alleles) != n_cols:
1346
+ alt_alleles = [None] * n_cols
1347
+
1348
+ out = np.full((n_rows, n_cols), "N", dtype="<U1")
1349
+ source_snp_data = None
1350
+
1351
+ for j in range(n_cols):
1352
+ ref = _normalize_iupac(ref_alleles[j])
1353
+ alt = _normalize_iupac(alt_alleles[j])
1354
+
1355
+ # --- REPAIR LOGIC ---
1356
+ # If metadata is missing, scan the source column.
1357
+ if ref is None or alt is None:
1358
+ if source_snp_data is None and self.genotype_data.snp_data is not None:
1359
+ try:
1360
+ source_snp_data = np.asarray(self.genotype_data.snp_data)
1361
+ except Exception:
1362
+ pass # if lazy loading fails
1363
+
1364
+ if source_snp_data is not None:
1365
+ try:
1366
+ col_data = source_snp_data[:, j]
1367
+ uniques = set()
1368
+ # Optimization: check up to 200 non-empty values
1369
+ count = 0
1370
+ for val in col_data:
1371
+ norm = _normalize_iupac(val)
1372
+ if norm:
1373
+ uniques.add(norm)
1374
+ count += 1
1375
+ if len(uniques) >= 2 or count > 200:
1376
+ break
1377
+
1378
+ sorted_u = sorted(list(uniques))
1379
+ if len(sorted_u) >= 1 and ref is None:
1380
+ ref = sorted_u[0]
1381
+ if len(sorted_u) >= 2 and alt is None:
1382
+ alt = sorted_u[1]
1383
+ except Exception:
1384
+ pass
1385
+
1386
+ # --- DEFAULTS FOR MISSING ---
1387
+ # If still missing, we cannot decode.
1388
+ if ref is None and alt is None:
1389
+ ref = "N"
1390
+ alt = "N"
1391
+ elif ref is None:
1392
+ ref = alt
1393
+ elif alt is None:
1394
+ alt = ref # Monomorphic site: ALT becomes REF
1395
+
1396
+ # --- COMPUTE HET CODE ---
1397
+ if ref == alt:
1398
+ het_code = ref
1399
+ else:
1400
+ ref_set = iupac_to_bases.get(ref, set()) if ref is not None else set()
1401
+ alt_set = iupac_to_bases.get(alt, set()) if alt is not None else set()
1402
+ union_set = frozenset(ref_set | alt_set)
1403
+ het_code = bases_to_iupac.get(union_set, "N")
993
1404
 
994
- # Prefer the requested metric; fall back to self.tune_metric if needed.
995
- val = metrics.get(metric, metrics.get(getattr(self, "tune_metric", ""), None))
1405
+ # --- ASSIGNMENT WITH SAFETY FALLBACKS ---
1406
+ col_codes = codes[:, j]
996
1407
 
997
- if val is None or not np.isfinite(val):
998
- return -np.inf # make pruning decisions easy/robust on bad reads
1408
+ # Case 0: REF
1409
+ if ref != "N":
1410
+ out[col_codes == 0, j] = ref
999
1411
 
1000
- return float(val)
1412
+ # Case 1: HET
1413
+ if het_code != "N":
1414
+ out[col_codes == 1, j] = het_code
1415
+ else:
1416
+ # If HET code is invalid (e.g. ref='A', alt='N'),
1417
+ # fall back to REF
1418
+ # Fix for an issue where a HET prediction at a monomorphic site
1419
+ # produced 'N'
1420
+ if ref != "N":
1421
+ out[col_codes == 1, j] = ref
1422
+
1423
+ # Case 2: ALT
1424
+ if alt != "N":
1425
+ out[col_codes == 2, j] = alt
1426
+ else:
1427
+ # If ALT is invalid (e.g. ref='A', alt='N'), fall back to REF
1428
+ # Fix for an issue where an ALT prediction on a monomorphic site
1429
+ # produced 'N'
1430
+ if ref != "N":
1431
+ out[col_codes == 2, j] = ref
1432
+
1433
+ return out
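A toy illustration of the per-locus decoding rule above, assuming a biallelic A/G site with intact REF/ALT metadata (the real method additionally repairs missing metadata and falls back to 'N' or REF where a code cannot be resolved):

    # 0 -> REF, 1 -> IUPAC code for the REF+ALT union, 2 -> ALT, negative -> 'N'
    iupac_to_bases = {"A": {"A"}, "G": {"G"}, "R": {"A", "G"}}
    bases_to_iupac = {frozenset(v): k for k, v in iupac_to_bases.items()}
    ref, alt = "A", "G"
    het = bases_to_iupac[frozenset(iupac_to_bases[ref] | iupac_to_bases[alt])]  # "R"
    decode = {0: ref, 1: het, 2: alt, -1: "N"}
    print([decode[c] for c in [0, 1, 2, -1]])  # ['A', 'R', 'G', 'N']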
1434
+
1435
+ def _save_best_params(
1436
+ self, best_params: Dict[str, Any], objective_mode: bool = False
1437
+ ) -> None:
1438
+ """Save the best hyperparameters to a JSON file.
1439
+
1440
+ This method writes the best hyperparameters found during tuning to a JSON file: ``best_tuned_parameters.json`` under the optimization directory when ``objective_mode`` is True, otherwise ``best_parameters.json`` under the parameters directory.
1441
+
1442
+ Args:
1443
+ best_params (Dict[str, Any]): A dictionary of the best hyperparameters to save.
+ objective_mode (bool): If True, save under the optimization directory; otherwise save under the parameters directory. Defaults to False.
1444
+ """
1445
+ if not hasattr(self, "parameters_dir"):
1446
+ msg = "Attribute 'parameters_dir' not found. Ensure _create_model_directories() has been called."
1447
+ self.logger.error(msg)
1448
+ raise AttributeError(msg)
1449
+
1450
+ if objective_mode:
1451
+ fout = self.optimize_dir / "parameters" / "best_tuned_parameters.json"
1452
+ else:
1453
+ fout = self.parameters_dir / "best_parameters.json"
1454
+
1455
+ fout.parent.mkdir(parents=True, exist_ok=True)
1456
+
1457
+ with open(fout, "w") as f:
1458
+ json.dump(best_params, f, indent=4)
1001
1459
 
1002
- def _first_linear_in_features(self, model: torch.nn.Module) -> int:
1003
- """Return in_features of the model's first Linear layer.
1460
+ def _set_best_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
1461
+ """An abstract method for setting best parameters."""
1462
+ raise NotImplementedError
1004
1463
 
1005
- This method iterates through the modules of the provided PyTorch model to find the first instance of a Linear layer. It then retrieves and returns the `in_features` attribute of that layer, which indicates the number of input features expected by the layer.
1464
+ def sim_missing_transform(
1465
+ self, X: np.ndarray
1466
+ ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
1467
+ """Simulate missing data according to the specified strategy.
1006
1468
 
1007
1469
  Args:
1008
- model (torch.nn.Module): The model to inspect.
1470
+ X (np.ndarray): Genotype matrix to simulate missing data on.
1009
1471
 
1010
1472
  Returns:
1011
- int: The in_features of the first Linear layer.
1473
+ X_for_model (np.ndarray): Genotype matrix with simulated missing data.
1474
+ sim_mask (np.ndarray): Boolean mask of simulated missing entries.
1475
+ orig_mask (np.ndarray): Boolean mask of original missing entries.
1012
1476
  """
1013
- for m in model.modules():
1014
- if isinstance(m, torch.nn.Linear):
1015
- return int(m.in_features)
1016
- raise RuntimeError("No Linear layers found in model.")
1477
+ if (
1478
+ not hasattr(self, "sim_prop")
1479
+ or self.sim_prop <= 0.0
1480
+ or self.sim_prop >= 1.0
1481
+ ):
1482
+ msg = "sim_prop must be set and between 0.0 and 1.0."
1483
+ self.logger.error(msg)
1484
+ raise AttributeError(msg)
1017
1485
 
1018
- def _assert_model_latent_compat(
1019
- self, model: torch.nn.Module, latent_vectors: torch.nn.Parameter
1020
- ) -> None:
1021
- """Raise if model's first Linear doesn't match latent_vectors width.
1486
+ if not hasattr(self, "tree_parser") and "nonrandom" in self.sim_strategy:
1487
+ msg = "tree_parser must be set for 'nonrandom' or 'nonrandom_weighted' sim_strategy."
1488
+ self.logger.error(msg)
1489
+ raise AttributeError(msg)
1022
1490
 
1023
- This method checks that the dimensionality of the provided latent vectors matches the expected input feature size of the model's first linear layer. If there is a mismatch, it raises a ValueError with a descriptive message.
1491
+ # --- Simulate missing data ---
1492
+ X_for_sim = X.astype(np.float32, copy=True)
1493
+ tr = SimMissingTransformer(
1494
+ genotype_data=self.genotype_data,
1495
+ tree_parser=self.tree_parser,
1496
+ prop_missing=self.sim_prop,
1497
+ strategy=self.sim_strategy,
1498
+ missing_val=-1,
1499
+ mask_missing=True,
1500
+ verbose=self.verbose,
1501
+ seed=self.seed,
1502
+ **self.sim_kwargs,
1503
+ )
1504
+ tr.fit(X_for_sim.copy())
1505
+ X_for_model = tr.transform(X_for_sim.copy())
1506
+ sim_mask = tr.sim_missing_mask_.astype(bool)
1507
+ orig_mask = tr.original_missing_mask_.astype(bool)
1508
+
1509
+ return X_for_model, sim_mask, orig_mask
1510
+
1511
+ def _train_val_test_split(
1512
+ self, X: np.ndarray
1513
+ ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
1514
+ """Split data into train, validation, and test sets.
1024
1515
 
1025
1516
  Args:
1026
- model (torch.nn.Module): The model to check.
1027
- latent_vectors (torch.nn.Parameter): The latent vectors to check.
1517
+ X (np.ndarray): Genotype matrix to split.
1518
+
1519
+ Returns:
1520
+ tuple[np.ndarray, np.ndarray, np.ndarray]: Indices for train, validation, and test sets.
1028
1521
 
1029
1522
  Raises:
1030
- ValueError: If the latent dimension does not match the model's expected input features.
1523
+ ValueError: If there are not enough samples for splitting.
1524
+ AssertionError: If validation_split is not in (0.0, 1.0).
1031
1525
  """
1032
- zdim = int(latent_vectors.shape[1])
1033
- first_in = self._first_linear_in_features(model)
1034
- if first_in != zdim:
1035
- raise ValueError(
1036
- f"Latent mismatch: zdim={zdim}, model first Linear expects in_features={first_in}"
1037
- )
1526
+ n_samples = X.shape[0]
1038
1527
 
1039
- def _prepare_tuning_artifacts(self) -> None:
1040
- """Prepare data and artifacts needed for hyperparameter tuning.
1528
+ if n_samples < 3:
1529
+ msg = f"Not enough samples ({n_samples}) for train/val/test split."
1530
+ self.logger.error(msg)
1531
+ raise ValueError(msg)
1041
1532
 
1042
- This method sets up the necessary data splits, data loaders, and class weights required for hyperparameter tuning. It creates training and validation sets from the ground truth data, initializes data loaders with a specified batch size, and computes class-balanced weights based on the training data. The method also handles optional subsampling of the dataset for faster tuning and prepares slices for evaluation if needed.
1533
+ assert (
1534
+ self.validation_split > 0.0 and self.validation_split < 1.0
1535
+ ), f"validation_split must be in (0.0, 1.0), but got {self.validation_split}."
1043
1536
 
1044
- Raises:
1045
- AttributeError: If the ground truth data (`ground_truth_`) is not set.
1537
+ # Train/Val split
1538
+ indices = np.arange(n_samples)
1539
+ train_idx, val_test_idx = train_test_split(
1540
+ indices,
1541
+ test_size=self.validation_split,
1542
+ random_state=self.seed,
1543
+ )
1544
+
1545
+ if val_test_idx.size < 4:
1546
+ msg = f"Not enough samples ({val_test_idx.size}) for validation/test split."
1547
+ self.logger.error(msg)
1548
+ raise ValueError(msg)
1549
+
1550
+ # Split val and test equally
1551
+ val_idx, test_idx = train_test_split(
1552
+ val_test_idx, test_size=0.5, random_state=self.seed
1553
+ )
1554
+
1555
+ return train_idx, val_idx, test_idx
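The split above is two nested calls to scikit-learn's train_test_split: hold out validation_split of the samples, then divide that hold-out evenly into validation and test indices. A minimal sketch with illustrative numbers:

    import numpy as np
    from sklearn.model_selection import train_test_split

    indices = np.arange(100)
    train_idx, val_test_idx = train_test_split(indices, test_size=0.2, random_state=42)
    val_idx, test_idx = train_test_split(val_test_idx, test_size=0.5, random_state=42)
    # -> 80 training, 10 validation, and 10 test samples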
1556
+
1557
+ def _get_data_loaders(
1558
+ self,
1559
+ X: np.ndarray,
1560
+ y: np.ndarray,
1561
+ mask: np.ndarray,
1562
+ batch_size: int,
1563
+ *,
1564
+ shuffle: bool = True,
1565
+ ) -> torch.utils.data.DataLoader:
1566
+ """Create DataLoader for training and validation.
1567
+
1568
+ Args:
1569
+ X (np.ndarray): 0/1/2-encoded input matrix.
1570
+ y (np.ndarray): 0/1/2-encoded matrix with -1 for missing.
1571
+ mask (np.ndarray): Boolean mask of entries to score in the loss.
1572
+ batch_size (int): Batch size.
1573
+ shuffle (bool): Whether to shuffle batches. Defaults to True.
1574
+
1575
+ Returns:
1576
+ torch.utils.data.DataLoader: The constructed DataLoader.
1046
1577
  """
1047
- if getattr(self, "_tune_ready", False):
1048
- return
1578
+ dataset = _MaskedNumpyDataset(X, y, mask)
1049
1579
 
1050
- X = self.ground_truth_
1051
- n_samp, n_loci = X.shape
1052
- rng = self.rng
1580
+ return torch.utils.data.DataLoader(
1581
+ dataset,
1582
+ batch_size=batch_size,
1583
+ shuffle=shuffle,
1584
+ pin_memory=(str(self.device).startswith("cuda")),
1585
+ )
1053
1586
 
1054
- if self.tune_fast:
1055
- s = min(n_samp, self.tune_max_samples)
1056
- l = n_loci if self.tune_max_loci == 0 else min(n_loci, self.tune_max_loci)
1587
+ def _update_anneal_schedule(
1588
+ self,
1589
+ final: float,
1590
+ warm: int,
1591
+ ramp: int,
1592
+ epoch: int,
1593
+ *,
1594
+ init_val: float = 0.0,
1595
+ ) -> torch.Tensor:
1596
+ """Update annealed hyperparameter value based on epoch.
1057
1597
 
1058
- samp_idx = np.sort(rng.choice(n_samp, size=s, replace=False))
1059
- loci_idx = np.sort(rng.choice(n_loci, size=l, replace=False))
1060
- X_small = X[samp_idx][:, loci_idx]
1598
+ Args:
1599
+ final (float): Final value after annealing.
1600
+ warm (int): Number of warm-up epochs.
1601
+ ramp (int): Number of ramp-up epochs.
1602
+ epoch (int): Current epoch number.
1603
+ init_val (float): Initial value before annealing starts.
1604
+
1605
+ Returns:
1606
+ torch.Tensor: Current value of the hyperparameter.
1607
+ """
1608
+ if epoch < warm:
1609
+ val = torch.tensor(init_val)
1610
+ elif epoch < warm + ramp:
1611
+ val = torch.tensor(final * ((epoch - warm) / ramp))
1061
1612
  else:
1062
- X_small = X
1613
+ val = torch.tensor(final)
1063
1614
 
1064
- idx = np.arange(X_small.shape[0])
1065
- tr, te = train_test_split(
1066
- idx, test_size=self.validation_split, random_state=self.seed
1067
- )
1068
- self._tune_train_idx = tr
1069
- self._tune_test_idx = te
1070
- self._tune_X_train = X_small[tr]
1071
- self._tune_X_test = X_small[te]
1615
+ return val.to(self.device)
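The schedule above holds the value at init_val during the warm-up epochs, ramps linearly to the final value over the ramp epochs, and then stays constant. A plain-Python sketch of the same rule (without the tensor/device handling):

    def anneal(epoch, final, warm, ramp, init_val=0.0):
        if epoch < warm:
            return init_val                       # warm-up: hold the initial value
        if epoch < warm + ramp:
            return final * (epoch - warm) / ramp  # linear ramp toward the final value
        return final                              # plateau

    # With final=1.0, warm=5, ramp=10: epochs 0-4 -> 0.0, epoch 10 -> 0.5, epoch 15+ -> 1.0

Combined with the caps applied by _anneal_config below (warm at most 50 epochs, ramp at most 100), a 1000-epoch run warms up for 50 epochs and ramps over the following 100.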
1072
1616
 
1073
- self._tune_class_weights = self._normalize_class_weights(
1074
- self._class_weights_from_zygosity(self._tune_X_train)
1075
- )
1617
+ def _anneal_config(
1618
+ self,
1619
+ params: Optional[dict],
1620
+ key: str,
1621
+ default: float,
1622
+ max_epochs: int,
1623
+ *,
1624
+ warm_alt: int = 50,
1625
+ ramp_alt: int = 100,
1626
+ ) -> Tuple[float, int, int]:
1627
+ """Configure annealing schedule for a hyperparameter.
1076
1628
 
1077
- # Temporarily bump batch size only for tuning loader
1078
- orig_bs = self.batch_size
1079
- self.batch_size = self.tune_batch_size
1080
- self._tune_loader = self._get_data_loaders(self._tune_X_train) # type: ignore
1081
- self.batch_size = orig_bs
1629
+ Args:
1630
+ params (Optional[dict]): Dictionary of parameters to extract from.
1631
+ key (str): Key to look for in params.
1632
+ default (float): Default final value if not specified in params.
1633
+ max_epochs (int): Total number of training epochs.
1634
+ warm_alt (int): Upper bound on warm-up epochs, applied when 10% of max_epochs would exceed it. Defaults to 50.
1635
+ ramp_alt (int): Upper bound on ramp-up epochs, applied when 20% of max_epochs would exceed it. Defaults to 100.
1082
1636
 
1083
- self._tune_num_features = self._tune_X_train.shape[1]
1084
- self._tune_val_latents_source = None
1085
- self._tune_train_latents_source = None
1637
+ Returns:
1638
+ Tuple[float, int, int]: Final value, warm-up epochs, ramp-up epochs.
1639
+ """
1640
+ val = None
1641
+ if params is not None and params:
1642
+ if not hasattr(self, key):
1643
+ msg = f"Attribute '{key}' not found for anneal_config."
1644
+ self.logger.error(msg)
1645
+ raise AttributeError(msg)
1086
1646
 
1087
- # Optional: for huge val sets, thin them for proxy metric
1088
- if (
1089
- self.tune_proxy_metric_batch
1090
- and self._tune_X_test.shape[0] > self.tune_proxy_metric_batch
1091
- ):
1092
- self._tune_eval_slice = np.arange(self.tune_proxy_metric_batch)
1647
+ val = params.get(key, getattr(self, key))
1648
+
1649
+ if val is not None and isinstance(val, (float, int)):
1650
+ final = float(val)
1093
1651
  else:
1094
- self._tune_eval_slice = None
1652
+ final = default
1095
1653
 
1096
- self._tune_ready = True
1654
+ warm, ramp = min(int(0.1 * max_epochs), warm_alt), min(
1655
+ int(0.2 * max_epochs), ramp_alt
1656
+ )
1657
+ return final, warm, ramp
1097
1658
 
1098
- def _save_best_params(self, best_params: Dict[str, Any]) -> None:
1099
- """Save the best hyperparameters to a JSON file.
1659
+ def _repair_ref_alt_from_iupac(self, loci: np.ndarray) -> None:
1660
+ """Repair REF/ALT for specific loci using observed IUPAC genotypes.
1100
1661
 
1101
- This method saves the best hyperparameters found during hyperparameter tuning to a JSON file in the optimization directory. The filename includes the model name for easy identification.
1662
+ Args:
1663
+ loci (np.ndarray): Array of locus indices to repair.
1664
+
1665
+ Notes:
1666
+ - Modifies self.genotype_data.ref and self.genotype_data.alt in place.
1667
+ """
1668
+ iupac_to_bases = {
1669
+ "A": {"A"},
1670
+ "C": {"C"},
1671
+ "G": {"G"},
1672
+ "T": {"T"},
1673
+ "R": {"A", "G"},
1674
+ "Y": {"C", "T"},
1675
+ "S": {"G", "C"},
1676
+ "W": {"A", "T"},
1677
+ "K": {"G", "T"},
1678
+ "M": {"A", "C"},
1679
+ "B": {"C", "G", "T"},
1680
+ "D": {"A", "G", "T"},
1681
+ "H": {"A", "C", "T"},
1682
+ "V": {"A", "C", "G"},
1683
+ }
1684
+ missing_codes = {"", ".", "N", "NONE", "-", "?", "./.", ".|."}
1685
+
1686
+ def norm(v: object) -> str | None:
1687
+ if v is None:
1688
+ return None
1689
+ s = str(v).upper().strip()
1690
+ if not s or s in missing_codes:
1691
+ return None
1692
+ return s if s in iupac_to_bases else None
1693
+
1694
+ snp = np.asarray(self.genotype_data.snp_data, dtype=object) # (N,L) IUPAC-ish
1695
+ refs = list(getattr(self.genotype_data, "ref", [None] * snp.shape[1]))
1696
+ alts = list(getattr(self.genotype_data, "alt", [None] * snp.shape[1]))
1697
+
1698
+ for j in loci:
1699
+ cnt = Counter()
1700
+ col = snp[:, int(j)]
1701
+ for g in col:
1702
+ code = norm(g)
1703
+ if code is None:
1704
+ continue
1705
+ for b in iupac_to_bases[code]:
1706
+ cnt[b] += 1
1707
+
1708
+ if not cnt:
1709
+ continue
1710
+
1711
+ common = [b for b, _ in cnt.most_common()]
1712
+ ref = common[0]
1713
+ alt = common[1] if len(common) > 1 else None
1714
+
1715
+ refs[int(j)] = ref
1716
+ alts[int(j)] = alt if alt is not None else "."
1717
+
1718
+ self.genotype_data.ref = np.asarray(refs, dtype=object)
1719
+
1720
+ if not isinstance(alts, np.ndarray):
1721
+ alts = np.array(alts, dtype=object).tolist()
1722
+
1723
+ self.genotype_data.alt = alts
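The repair loop above expands every observed IUPAC call at a locus into its constituent bases, tallies them with a Counter, and takes the two most frequent bases as REF and ALT (loci where only one base is observed keep that base as REF and receive "." as ALT). A toy version of the counting step for a single locus, with illustrative calls:

    from collections import Counter

    iupac_to_bases = {"A": {"A"}, "G": {"G"}, "R": {"A", "G"}}
    column = ["A", "A", "R", "G", "N"]          # calls at one locus; "N" is skipped
    counts = Counter()
    for call in column:
        for base in iupac_to_bases.get(call, set()):
            counts[base] += 1
    (ref, _), (alt, _) = counts.most_common(2)  # -> ref = "A", alt = "G"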
1724
+
1725
+ def _aligned_ref_alt(self, L: int) -> tuple[list[object], list[object]]:
1726
+ """Return REF/ALT aligned to the genotype matrix columns.
1102
1727
 
1103
1728
  Args:
1104
- best_params (Dict[str, Any]): A dictionary of the best hyperparameters to save.
1729
+ L (int): Number of loci (columns in genotype matrix).
1730
+
1731
+ Returns:
1732
+ tuple[list[object], list[object]]: Aligned REF and ALT lists.
1105
1733
  """
1106
- if not hasattr(self, "parameters_dir"):
1107
- msg = "Attribute 'parameters_dir' not found. Ensure _create_model_directories() has been called."
1734
+ refs = getattr(self.genotype_data, "ref", None)
1735
+ alts = getattr(self.genotype_data, "alt", None)
1736
+
1737
+ if refs is None or alts is None:
1738
+ msg = "genotype_data.ref/alt are required but missing."
1108
1739
  self.logger.error(msg)
1109
- raise AttributeError(msg)
1740
+ raise ValueError(msg)
1110
1741
 
1111
- fout = self.parameters_dir / "best_parameters.json"
1742
+ refs_arr = np.asarray(refs, dtype=object)
1743
+ alts_arr = np.asarray(alts, dtype=object)
1112
1744
 
1113
- with open(fout, "w") as f:
1114
- json.dump(best_params, f, indent=4)
1745
+ if refs_arr.shape[0] != L or alts_arr.shape[0] != L:
1746
+ msg = f"REF/ALT length mismatch vs matrix columns: L={L}, len(ref)={refs_arr.shape[0]}, len(alt)={alts_arr.shape[0]}. You are using REF/ALT metadata that is not aligned to pgenc.genotypes_012 columns. Fix by subsetting/refiltering ref/alt with the same locus mask used for the genotype matrix."
1747
+ self.logger.error(msg)
1748
+ raise ValueError(msg)
1115
1749
 
1116
- def _set_best_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
1117
- """An abstract method for setting best parameters."""
1118
- raise NotImplementedError
1750
+ # Unwrap singleton ALT arrays like array(['T'], dtype=object)
1751
+ def unwrap(x: object) -> object:
1752
+ if isinstance(x, np.ndarray):
1753
+ if x.size == 0:
1754
+ return None
1755
+ if x.size == 1:
1756
+ return x.item()
1757
+ return x
1758
+
1759
+ refs_list = [unwrap(x) for x in refs_arr.tolist()]
1760
+ alts_list = [unwrap(x) for x in alts_arr.tolist()]
1761
+ return refs_list, alts_list
1762
+
1763
+ def _build_valid_class_mask(self) -> torch.Tensor:
1764
+ L = self.num_features_
1765
+ K = self.num_classes_
1766
+ mask = np.ones((L, K), dtype=bool)
1767
+
1768
+ # --- IUPAC helpers (single-character only) ---
1769
+ iupac_to_bases: dict[str, set[str]] = {
1770
+ "A": {"A"},
1771
+ "C": {"C"},
1772
+ "G": {"G"},
1773
+ "T": {"T"},
1774
+ "R": {"A", "G"},
1775
+ "Y": {"C", "T"},
1776
+ "S": {"G", "C"},
1777
+ "W": {"A", "T"},
1778
+ "K": {"G", "T"},
1779
+ "M": {"A", "C"},
1780
+ "B": {"C", "G", "T"},
1781
+ "D": {"A", "G", "T"},
1782
+ "H": {"A", "C", "T"},
1783
+ "V": {"A", "C", "G"},
1784
+ }
1785
+ missing_codes = {"", ".", "N", "NONE", "-", "?", "./.", ".|."}
1786
+
1787
+ # get aligned ref/alt (should be exactly length L)
1788
+ refs, alts = self._aligned_ref_alt(L)
1789
+
1790
+ def _normalize_iupac(value: object) -> str | None:
1791
+ """Return a single-letter IUPAC code or None if missing/invalid."""
1792
+ if value is None:
1793
+ return None
1794
+ if isinstance(value, (bytes, np.bytes_)):
1795
+ value = value.decode("utf-8", errors="ignore")
1796
+
1797
+ # allow list/tuple/array containers (take first valid)
1798
+ if isinstance(value, (list, tuple, np.ndarray, pd.Series)):
1799
+ for item in value:
1800
+ code = _normalize_iupac(item)
1801
+ if code is not None:
1802
+ return code
1803
+ return None
1804
+
1805
+ s = str(value).upper().strip()
1806
+ if not s or s in missing_codes:
1807
+ return None
1808
+
1809
+ # handle comma-separated values
1810
+ if "," in s:
1811
+ for tok in (t.strip() for t in s.split(",")):
1812
+ if tok and tok not in missing_codes and tok in iupac_to_bases:
1813
+ return tok
1814
+ return None
1815
+
1816
+ return s if s in iupac_to_bases else None
1817
+
1818
+ # 1) metadata restriction
1819
+ for j in range(L):
1820
+ ref = _normalize_iupac(refs[j])
1821
+ alt = _normalize_iupac(alts[j])
1822
+
1823
+ if alt is None or (ref is not None and alt == ref):
1824
+ mask[j, :] = False
1825
+ mask[j, 0] = True
1826
+
1827
+ # 2) data-driven override
1828
+ y_train = getattr(self, "y_train_", None)
1829
+ if y_train is not None:
1830
+ y = np.asarray(y_train)
1831
+ if y.ndim == 2 and y.shape[1] == L:
1832
+ if K == 2:
1833
+ y = y.copy()
1834
+ y[y == 2] = 1
1835
+ valid = y >= 0
1836
+ if valid.any():
1837
+ observed = np.zeros((L, K), dtype=bool)
1838
+ for c in range(K):
1839
+ observed[:, c] = np.any(valid & (y == c), axis=0)
1840
+
1841
+ conflict = observed & (~mask)
1842
+ if conflict.any():
1843
+ loci = np.where(conflict.any(axis=1))[0]
1844
+ self.valid_class_mask_conflict_loci_ = loci
1845
+ self.logger.warning(
1846
+ f"valid_class_mask_ metadata forbids observed classes at {loci.size} loci. "
1847
+ "Expanding mask to include observed classes."
1848
+ )
1849
+ mask |= observed
1850
+
1851
+ bad = np.where(~mask.any(axis=1))[0]
1852
+ if bad.size:
1853
+ mask[bad, :] = False
1854
+ mask[bad, 0] = True
1855
+
1856
+ return torch.as_tensor(mask, dtype=torch.bool, device=self.device)
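_build_valid_class_mask carries no docstring; in short, it returns an (L, K) boolean mask of genotype classes permitted at each locus: REF/ALT metadata first restricts the allowed classes (loci without a usable ALT keep only class 0), any class actually observed in y_train_ is then re-enabled, and loci left with no allowed class fall back to class 0. A toy sketch of that two-stage logic, using illustrative shapes and values:

    import numpy as np

    L, K = 3, 3
    mask = np.ones((L, K), dtype=bool)
    mask[1, :] = False
    mask[1, 0] = True                             # metadata says locus 1 is monomorphic
    y_train = np.array([[0, 2, 1],
                        [0, -1, 2]])              # (samples, loci); -1 marks missing
    observed = np.zeros((L, K), dtype=bool)
    for c in range(K):
        observed[:, c] = np.any((y_train >= 0) & (y_train == c), axis=0)
    mask |= observed                              # observed classes override the metadata
    # mask[1] -> [True, False, True]: class 2 was seen at locus 1, so it is re-enabled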