pg-sui 1.6.16a3__py3-none-any.whl → 1.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33) hide show
  1. {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/METADATA +26 -30
  2. {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/RECORD +29 -33
  3. pgsui/__init__.py +0 -8
  4. pgsui/_version.py +2 -2
  5. pgsui/cli.py +577 -125
  6. pgsui/data_processing/config.py +1 -2
  7. pgsui/data_processing/containers.py +203 -530
  8. pgsui/data_processing/transformers.py +44 -20
  9. pgsui/impute/deterministic/imputers/mode.py +475 -182
  10. pgsui/impute/deterministic/imputers/ref_allele.py +454 -147
  11. pgsui/impute/supervised/imputers/hist_gradient_boosting.py +4 -3
  12. pgsui/impute/supervised/imputers/random_forest.py +3 -2
  13. pgsui/impute/unsupervised/base.py +1269 -534
  14. pgsui/impute/unsupervised/callbacks.py +28 -33
  15. pgsui/impute/unsupervised/imputers/autoencoder.py +870 -841
  16. pgsui/impute/unsupervised/imputers/vae.py +931 -787
  17. pgsui/impute/unsupervised/loss_functions.py +156 -202
  18. pgsui/impute/unsupervised/models/autoencoder_model.py +7 -49
  19. pgsui/impute/unsupervised/models/vae_model.py +40 -221
  20. pgsui/impute/unsupervised/nn_scorers.py +53 -13
  21. pgsui/utils/classification_viz.py +240 -97
  22. pgsui/utils/misc.py +201 -3
  23. pgsui/utils/plotting.py +73 -58
  24. pgsui/utils/pretty_metrics.py +2 -6
  25. pgsui/utils/scorers.py +39 -0
  26. pgsui/impute/unsupervised/imputers/nlpca.py +0 -1666
  27. pgsui/impute/unsupervised/imputers/ubp.py +0 -1660
  28. pgsui/impute/unsupervised/models/nlpca_model.py +0 -206
  29. pgsui/impute/unsupervised/models/ubp_model.py +0 -200
  30. {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/WHEEL +0 -0
  31. {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/entry_points.txt +0 -0
  32. {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/licenses/LICENSE +0 -0
  33. {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/top_level.txt +0 -0
@@ -2,8 +2,10 @@ import copy
2
2
  import gc
3
3
  import json
4
4
  import logging
5
+ from collections import Counter
6
+ from datetime import datetime
5
7
  from pathlib import Path
6
- from typing import TYPE_CHECKING, Any, Dict, List, Literal, Tuple
8
+ from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Type
7
9
 
8
10
  import matplotlib.pyplot as plt
9
11
  import numpy as np
@@ -13,11 +15,18 @@ import plotly.graph_objects as go
13
15
  import torch
14
16
  import torch.nn.functional as F
15
17
  from matplotlib.figure import Figure
16
- from sklearn.metrics import classification_report
18
+ from sklearn.metrics import (
19
+ average_precision_score,
20
+ classification_report,
21
+ jaccard_score,
22
+ matthews_corrcoef,
23
+ )
17
24
  from sklearn.model_selection import train_test_split
18
25
  from snpio import SNPioMultiQC
19
26
  from snpio.utils.logging import LoggerManager
27
+ from snpio.utils.misc import validate_input_type
20
28
 
29
+ from pgsui.data_processing.transformers import SimMissingTransformer
21
30
  from pgsui.impute.unsupervised.nn_scorers import Scorer
22
31
  from pgsui.utils.classification_viz import ClassificationReportVisualizer
23
32
  from pgsui.utils.logging_utils import configure_logger
@@ -27,16 +36,24 @@ from pgsui.utils.pretty_metrics import PrettyMetrics
27
36
  if TYPE_CHECKING:
28
37
  from snpio.read_input.genotype_data import GenotypeData
29
38
 
30
- from pgsui.impute.unsupervised.models.autoencoder_model import AutoencoderModel
31
- from pgsui.impute.unsupervised.models.nlpca_model import NLPCAModel
32
- from pgsui.impute.unsupervised.models.ubp_model import UBPModel
33
- from pgsui.impute.unsupervised.models.vae_model import VAEModel
39
+
40
+ class _MaskedNumpyDataset(torch.utils.data.Dataset):
41
+ def __init__(self, X: np.ndarray, y: np.ndarray, mask: np.ndarray):
42
+ self.X = X
43
+ self.y = y
44
+ self.mask = mask.astype(np.bool_, copy=False)
45
+
46
+ def __len__(self) -> int:
47
+ return self.X.shape[0]
48
+
49
+ def __getitem__(self, idx: int):
50
+ return self.X[idx], self.y[idx], self.mask[idx]
34
51
 
35
52
 
36
53
  class BaseNNImputer:
37
- """An abstract base class for neural network-based imputers.
54
+ """Abstract base class for neural network-based imputers.
38
55
 
39
- This class provides a shared framework and common functionality for all neural network imputers. It is not meant to be instantiated directly. Instead, child classes should inherit from it and implement the abstract methods. Provided functionality: Directory setup and logging initialization; A hyperparameter tuning pipeline using Optuna; Utility methods for building models (`build_model`), initializing weights (`initialize_weights`), and checking for fitted attributes (`ensure_attribute`); Helper methods for calculating class weights for imbalanced data; Setup for standardized plotting and model scoring classes.
56
+ This class provides shared infrastructure for NN imputers (e.g., directory/logging setup, Optuna tuning, model construction helpers, class-weight utilities, standardized plotting/scoring, and IUPAC decoding). It is not intended to be instantiated directly; subclasses must implement the abstract methods.
40
57
  """
41
58
 
42
59
  def __init__(
@@ -54,6 +71,8 @@ class BaseNNImputer:
54
71
  This constructor sets up the device (CPU, GPU, or MPS), creates the necessary output directories for models and results, and a logger. It also initializes a genotype encoder for handling genotype data.
55
72
 
56
73
  Args:
74
+ model_name (str): The model class name used in output paths and logs.
75
+ genotype_data (GenotypeData): Backing genotype data object.
57
76
  prefix (str): A prefix used to name the output directory (e.g., 'pgsui_output').
58
77
  device (Literal["gpu", "cpu", "mps"]): The device to use for PyTorch operations. If 'gpu' or 'mps' is chosen, it will fall back to 'cpu' if the required hardware is not available. Defaults to "cpu".
59
78
  verbose (bool): If True, enables detailed logging output. Defaults to False.
@@ -62,6 +81,11 @@ class BaseNNImputer:
62
81
  self.model_name = model_name
63
82
  self.genotype_data = genotype_data
64
83
 
84
+ if not hasattr(self, "tree_parser"):
85
+ self.tree_parser = None
86
+ if not hasattr(self, "sim_kwargs"):
87
+ self.sim_kwargs = {}
88
+
65
89
  self.prefix = prefix
66
90
  self.verbose = verbose
67
91
  self.debug = debug
@@ -89,15 +113,15 @@ class BaseNNImputer:
89
113
  self.logger = configure_logger(
90
114
  logman.get_logger(), verbose=self.verbose, debug=self.debug
91
115
  )
92
- self._float_genotype_cache: np.ndarray | None = None
93
- self._sim_mask_cache: dict[tuple, np.ndarray] = {}
116
+
117
+ self.logger.info(f"Using PyTorch device: {self.device.type}.")
94
118
 
95
119
  # To be initialized by child classes or fit method
96
120
  self.tune_save_db: bool = False
97
121
  self.tune_resume: bool = False
98
122
  self.n_trials: int = 100
99
123
  self.model_params: Dict[str, Any] = {}
100
- self.tune_metric: str = "val_f1_macro"
124
+ self.tune_metric: str = "f1"
101
125
  self.learning_rate: float = 1e-3
102
126
  self.plotter_: "Plotting"
103
127
  self.num_features_: int = 0
@@ -110,21 +134,16 @@ class BaseNNImputer:
110
134
  self.show_plots: bool = False
111
135
  self.scoring_averaging: Literal["macro", "micro", "weighted"] = "macro"
112
136
  self.pgenc: Any = None
113
- self.is_haploid: bool = False
137
+ self.is_haploid_: bool = False
114
138
  self.ploidy: int = 2
115
139
  self.beta: float = 0.9999
116
- self.max_ratio: float = 5.0
117
- self.sim_strategy: str = "mcar"
118
- self.sim_prop: float = 0.1
119
- self.seed: int | None = 42
140
+ self.max_ratio: Optional[float] = None
141
+ self.sim_strategy: str = "random"
142
+ self.sim_prop: float = 0.2
143
+ self.seed: Optional[int] = None
120
144
  self.rng: np.random.Generator = np.random.default_rng(self.seed)
121
145
  self.ground_truth_: np.ndarray
122
- self.tune_fast: bool = False
123
- self.tune_max_samples: int = 1000
124
- self.tune_max_loci: int = 500
125
146
  self.validation_split: float = 0.2
126
- self.tune_batch_size: int = 64
127
- self.tune_proxy_metric_batch: int = 512
128
147
  self.batch_size: int = 64
129
148
  self.best_params_: Dict[str, Any] = {}
130
149
 
@@ -133,9 +152,11 @@ class BaseNNImputer:
133
152
  self.plots_dir: Path
134
153
  self.metrics_dir: Path
135
154
  self.parameters_dir: Path
136
- self.study_db: Path
155
+ self.study_db: Optional[Path] = None
156
+ self.X_model_input_: Optional[np.ndarray] = None
157
+ self.class_weights_: Optional[torch.Tensor] = None
137
158
 
138
- def tune_hyperparameters(self) -> None:
159
+ def tune_hyperparameters(self) -> Dict[str, Any]:
139
160
  """Tunes model hyperparameters using an Optuna study.
140
161
 
141
162
  This method orchestrates the hyperparameter search process. It creates an Optuna study that aims to maximize the metric defined in `self.tune_metric`. The search is driven by the `_objective` method, which must be implemented by the child class. After the search, the best parameters are logged, saved to a JSON file, and visualizations of the study are generated.
@@ -153,15 +174,20 @@ class BaseNNImputer:
153
174
  study_db = None
154
175
  load_if_exists = False
155
176
  if self.tune_save_db:
156
- study_db = self.optimize_dir / "study_database" / "optuna_study.db"
177
+ timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
178
+ study_db = (
179
+ self.optimize_dir / "study_database" / f"optuna_study_{timestamp}.db"
180
+ )
157
181
  study_db.parent.mkdir(parents=True, exist_ok=True)
158
182
 
159
183
  if self.tune_resume and study_db.exists():
160
184
  load_if_exists = True
161
185
 
162
186
  if not self.tune_resume and study_db.exists():
163
- study_db.unlink()
187
+ timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
188
+ study_db = study_db.with_name(f"optuna_study_{timestamp}.db")
164
189
 
190
+ self.study_db = study_db
165
191
  study_name = f"{self.prefix} {self.model_name} Model Optimization"
166
192
  storage = f"sqlite:///{study_db}" if self.tune_save_db else None
167
193
 
@@ -170,7 +196,17 @@ class BaseNNImputer:
170
196
  study_name=study_name,
171
197
  storage=storage,
172
198
  load_if_exists=load_if_exists,
173
- pruner=optuna.pruners.MedianPruner(n_startup_trials=5, n_warmup_steps=10),
199
+ pruner=optuna.pruners.MedianPruner(
200
+ # Guard against small `n_trials` values (e.g., 1)
201
+ # that can otherwise produce 0 startup/warmup/min trials.
202
+ n_startup_trials=max(
203
+ 1, min(int(self.n_trials * 0.1), 10, int(self.n_trials))
204
+ ),
205
+ n_warmup_steps=150,
206
+ n_min_trials=max(
207
+ 1, min(int(0.5 * self.n_trials), 25, int(self.n_trials))
208
+ ),
209
+ ),
174
210
  )
175
211
 
176
212
  if not hasattr(self, "_objective"):
@@ -185,6 +221,13 @@ class BaseNNImputer:
185
221
 
186
222
  show_progress_bar = not self.verbose and not self.debug and self.n_jobs == 1
187
223
 
224
+ # Set the best parameters.
225
+ # NOTE: _set_best_params() must be implemented in the child class.
226
+ if not hasattr(self, "_set_best_params"):
227
+ msg = "Method `_set_best_params()` must be implemented in the child class."
228
+ self.logger.error(msg)
229
+ raise NotImplementedError(msg)
230
+
188
231
  study.optimize(
189
232
  lambda trial: self._objective(trial),
190
233
  n_trials=self.n_trials,
@@ -193,15 +236,13 @@ class BaseNNImputer:
193
236
  show_progress_bar=show_progress_bar,
194
237
  )
195
238
 
196
- best_metric = study.best_value
197
- best_params = study.best_params
198
-
199
- # Set the best parameters.
200
- # NOTE: `_set_best_params()` must be implemented in the child class.
201
- if not hasattr(self, "_set_best_params"):
202
- msg = "Method `_set_best_params()` must be implemented in the child class."
239
+ try:
240
+ best_metric = study.best_value
241
+ best_params = study.best_params
242
+ except Exception:
243
+ msg = "Tuning failed: No successful trials completed."
203
244
  self.logger.error(msg)
204
- raise NotImplementedError(msg)
245
+ raise RuntimeError(msg)
205
246
 
206
247
  self.best_params_ = self._set_best_params(best_params)
207
248
  self.model_params.update(self.best_params_)
@@ -210,45 +251,34 @@ class BaseNNImputer:
210
251
  best_params_tmp = copy.deepcopy(best_params)
211
252
  best_params_tmp["learning_rate"] = self.learning_rate
212
253
 
213
- title = f"{self.model_name} Optimized Parameters"
214
- pm = PrettyMetrics(best_params_tmp, precision=6, title=title)
215
- pm.render()
254
+ tn = f"{self.tune_metric} Value"
216
255
 
217
- # Save best parameters to a JSON file.
218
- self._save_best_params(best_params)
256
+ if self.show_plots:
257
+ self.plotter_.plot_tuning(
258
+ study, self.model_name, self.optimize_dir / "plots", target_name=tn
259
+ )
219
260
 
220
- tn = f"{self.tune_metric} Value"
221
- self.plotter_.plot_tuning(
222
- study, self.model_name, self.optimize_dir / "plots", target_name=tn
223
- )
261
+ return best_params_tmp
224
262
 
225
263
  @staticmethod
226
264
  def initialize_weights(module: torch.nn.Module) -> None:
227
- """Initializes model weights using the Kaiming Uniform distribution.
228
-
229
- This static method is intended to be applied to a PyTorch model to initialize the weights of its linear and convolutional layers. This initialization scheme is particularly effective for networks that use ReLU-family activation functions, as it helps maintain stable activation variances during training.
265
+ """Initializes model weights using Xavier/Glorot Uniform distribution.
230
266
 
231
- Args:
232
- module (torch.nn.Module): The PyTorch module (e.g., a layer) to initialize.
267
+ Switching from Kaiming to Xavier is safer for deep VAEs to prevent
268
+ exploding gradients or dead neurons in the early epochs.
233
269
  """
234
270
  if isinstance(
235
271
  module, (torch.nn.Linear, torch.nn.Conv1d, torch.nn.ConvTranspose1d)
236
272
  ):
237
- # Use Kaiming Uniform initialization for Linear and Conv layers
238
- torch.nn.init.kaiming_uniform_(module.weight, nonlinearity="relu")
273
+ # Xavier is generally more stable for VAEs than Kaiming
274
+ torch.nn.init.xavier_uniform_(module.weight)
239
275
  if module.bias is not None:
240
276
  torch.nn.init.zeros_(module.bias)
241
277
 
242
278
  def build_model(
243
279
  self,
244
- Model: (
245
- torch.nn.Module
246
- | type["AutoencoderModel"]
247
- | type["NLPCAModel"]
248
- | type["UBPModel"]
249
- | type["VAEModel"]
250
- ),
251
- model_params: Dict[str, int | float | str | bool],
280
+ Model: Type[torch.nn.Module],
281
+ model_params: Dict[str, Any],
252
282
  ) -> torch.nn.Module:
253
283
  """Builds and initializes a neural network model instance.
254
284
 
@@ -277,20 +307,16 @@ class BaseNNImputer:
277
307
  self.logger.error(msg)
278
308
  raise AttributeError(msg)
279
309
 
280
- # Start with a base set of fixed (non-tuned) parameters.
281
- base_num_classes = getattr(self, "output_classes_", None)
282
- if base_num_classes is None:
283
- base_num_classes = self.num_classes_
284
310
  all_params = {
285
311
  "n_features": self.num_features_,
286
312
  "prefix": self.prefix,
287
- "num_classes": base_num_classes,
313
+ "num_classes": self.num_classes_,
288
314
  "verbose": self.verbose,
289
315
  "debug": self.debug,
290
316
  "device": self.device,
291
317
  }
292
318
 
293
- # Update with the variable hyperparameters from the provided dictionary
319
+ # Update with the variable hyperparameters
294
320
  all_params.update(model_params)
295
321
 
296
322
  return Model(**all_params).to(self.device)
@@ -372,110 +398,12 @@ class BaseNNImputer:
372
398
  X (np.ndarray | pd.DataFrame | list | None): The input data with missing values.
373
399
 
374
400
  Returns:
375
- np.ndarray: The data with missing values imputed.
401
+ np.ndarray: IUPAC strings with missing values imputed.
376
402
  """
377
403
  msg = "Method ``transform()`` must be implemented in the child class."
378
404
  self.logger.error(msg)
379
405
  raise NotImplementedError(msg)
380
406
 
381
- def _class_balanced_weights_from_mask(
382
- self,
383
- y: np.ndarray,
384
- train_mask: np.ndarray,
385
- num_classes: int,
386
- beta: float = 0.9999,
387
- max_ratio: float = 5.0,
388
- mode: Literal["allele", "genotype10"] = "allele",
389
- ) -> torch.Tensor:
390
- """Class-balanced weights (Cui et al. 2019) with overflow-safe effective number.
391
-
392
- mode="allele": y is 1D alleles in {0..3}, train_mask same shape. mode="genotype10": y is (nS,nF,2) alleles; train_mask is (nS,nF) loci where both alleles known.
393
-
394
- Args:
395
- y (np.ndarray): Ground truth labels.
396
- train_mask (np.ndarray): Boolean mask of training examples (same shape as y or y without last dim for genotype10).
397
- num_classes (int): Number of classes.
398
- beta (float): Hyperparameter for effective number calculation. Clamped to (0,1). Default is 0.9999.
399
- max_ratio (float): Maximum allowed ratio between largest and smallest non-zero weight. Default is 5.0.
400
- mode (Literal["allele", "genotype10"]): Whether y contains allele labels or 10-class genotypes. Default is "allele".
401
-
402
- Returns:
403
- torch.Tensor: Class weights of shape (num_classes,). Mean weight is 1.0, zero-weight classes remain zero.
404
- """
405
- if mode == "allele":
406
- valid = (y >= 0) & train_mask
407
- cls, cnt = np.unique(y[valid].astype(np.int64), return_counts=True)
408
- counts = np.zeros(num_classes, dtype=np.float64)
409
- counts[cls] = cnt
410
-
411
- elif mode == "genotype10":
412
- if y.ndim != 3 or y.shape[-1] != 2:
413
- msg = "For genotype10, y must be (nS,nF,2)."
414
- self.logger.error(msg)
415
- raise ValueError(msg)
416
-
417
- if train_mask.shape != y.shape[:2]:
418
- msg = "train_mask must be (nS,nF) for genotype10."
419
- self.logger.error(msg)
420
- raise ValueError(msg)
421
-
422
- # only loci where both alleles known and in training
423
- m = train_mask & np.all(y >= 0, axis=-1)
424
- if not np.any(m):
425
- counts = np.zeros(num_classes, dtype=np.float64)
426
-
427
- else:
428
- a1 = y[:, :, 0][m].astype(int)
429
- a2 = y[:, :, 1][m].astype(int)
430
- lo, hi = np.minimum(a1, a2), np.maximum(a1, a2)
431
- # map to 10-class index
432
- map10 = self.pgenc.map10
433
- idx10 = map10[lo, hi]
434
- idx10 = idx10[(idx10 >= 0) & (idx10 < num_classes)]
435
- counts = np.bincount(idx10, minlength=num_classes).astype(np.float64)
436
- else:
437
- msg = f"Unknown mode supplied to _class_balanced_weights_from_mask: {mode}"
438
- self.logger.error(msg)
439
- raise ValueError(msg)
440
-
441
- # ---- Effective number ----
442
- beta = float(beta)
443
-
444
- # clamp beta ∈ (0,1)
445
- if not np.isfinite(beta):
446
- beta = 0.9999
447
-
448
- beta = min(max(beta, 1e-8), 1.0 - 1e-8)
449
-
450
- logb = np.log(beta) # < 0
451
- t = counts * logb # ≤ 0
452
-
453
- # 1 - beta^n = 1 - exp(n*log(beta)) = -(exp(n*log(beta)) - 1)
454
- # use expm1 for accuracy near 0; for very negative t, eff≈1.0
455
- eff = np.where(t > -50.0, -np.expm1(t), 1.0)
456
-
457
- # class-balanced weights
458
- w = (1.0 - beta) / (eff + 1e-12)
459
-
460
- # Give unseen classes the largest non-zero weight (keeps it learnable)
461
- if np.any(counts == 0) and np.any(counts > 0):
462
- w[counts == 0] = w[counts > 0].max()
463
-
464
- # normalize by mean of non-zero
465
- nz = w > 0
466
- w[nz] /= w[nz].mean() + 1e-12
467
-
468
- # cap spread consistently with a single 'cap'
469
- cap = float(max_ratio) if max_ratio is not None else 10.0
470
- cap = max(cap, 5.0) # ensure we allow some differentiation
471
- if np.any(nz):
472
- spread = w[nz].max() / max(w[nz].min(), 1e-12)
473
- if spread > cap:
474
- scale = cap / spread
475
- w[nz] = 1.0 + (w[nz] - 1.0) * scale
476
-
477
- return torch.tensor(w.astype(np.float32), device=self.device)
478
-
479
407
  def _select_device(self, device: Literal["gpu", "cpu", "mps"]) -> torch.device:
480
408
  """Selects the appropriate PyTorch device based on user preference and availability.
481
409
 
@@ -487,36 +415,37 @@ class BaseNNImputer:
487
415
  Returns:
488
416
  torch.device: The selected PyTorch device.
489
417
  """
490
- dvc: str = device
491
- dvc = dvc.lower().strip()
418
+ dvc = device.lower().strip()
492
419
  if dvc == "cpu":
493
- self.logger.info("Using PyTorch device: CPU.")
494
420
  return torch.device("cpu")
495
421
  if dvc == "mps":
496
422
  if torch.backends.mps.is_available():
497
- self.logger.info("Using PyTorch device: mps.")
498
423
  return torch.device("mps")
499
- self.logger.warning("MPS unavailable; falling back to CPU.")
500
424
  return torch.device("cpu")
501
- # gpu
502
425
  if torch.cuda.is_available():
503
- self.logger.info("Using PyTorch device: cuda.")
504
426
  return torch.device("cuda")
505
- self.logger.warning("CUDA unavailable; falling back to CPU.")
506
427
  return torch.device("cpu")
507
428
 
508
- def _create_model_directories(self, prefix: str, outdirs: List[str]) -> None:
429
+ def _create_model_directories(
430
+ self, prefix: str, outdirs: List[str], *, outdir: Path | str | None = None
431
+ ) -> None:
509
432
  """Creates the directory structure for storing model outputs.
510
433
 
511
- This method sets up a standardized folder hierarchy for saving models, plots, metrics, and optimization results, organized under a main directory named after the provided prefix.
434
+ This method sets up a standardized folder hierarchy for saving models,
435
+ plots, metrics, and optimization results, organized under a main directory
436
+ named after the provided prefix. The current implementation always uses
437
+ ``<cwd>/<prefix>_output`` regardless of ``outdir``.
512
438
 
513
439
  Args:
514
440
  prefix (str): The prefix for the main output directory.
515
441
  outdirs (List[str]): A list of subdirectory names to create within the main directory.
442
+ outdir (Path | str | None): Requested base output directory (currently ignored).
516
443
 
517
444
  Raises:
518
445
  Exception: If any of the directories cannot be created.
519
446
  """
447
+ base_root = Path(outdir) if outdir is not None else Path.cwd()
448
+ formatted_output_dir = base_root / f"{prefix}_output"
520
449
  formatted_output_dir = Path(f"{prefix}_output")
521
450
  base_dir = formatted_output_dir / "Unsupervised"
522
451
 
@@ -530,27 +459,16 @@ class BaseNNImputer:
530
459
  self.logger.error(msg)
531
460
  raise Exception(msg)
532
461
 
533
- def _clear_resources(
534
- self,
535
- model: torch.nn.Module,
536
- train_loader: torch.utils.data.DataLoader,
537
- latent_vectors: torch.nn.Parameter | None = None,
538
- ) -> None:
462
+ def _clear_resources(self, model: torch.nn.Module) -> None:
539
463
  """Releases GPU and CPU memory after an Optuna trial.
540
464
 
541
465
  This is a crucial step during hyperparameter tuning to prevent memory leaks between trials, ensuring that each trial runs in a clean environment.
542
466
 
543
467
  Args:
544
468
  model (torch.nn.Module): The model from the completed trial.
545
- train_loader (torch.utils.data.DataLoader): The data loader from the trial.
546
- latent_vectors (torch.nn.Parameter | None): The latent vectors from the trial.
547
469
  """
548
470
  try:
549
- del model, train_loader
550
-
551
- if latent_vectors is not None:
552
- del latent_vectors
553
-
471
+ del model
554
472
  except NameError:
555
473
  pass
556
474
 
@@ -571,7 +489,7 @@ class BaseNNImputer:
571
489
  y_pred: np.ndarray,
572
490
  metrics: Dict[str, float],
573
491
  msg: str,
574
- ):
492
+ ) -> None:
575
493
  """Generate and save evaluation visualizations.
576
494
 
577
495
  3-class (zygosity) or 10-class (IUPAC) depending on `labels` length.
@@ -589,20 +507,109 @@ class BaseNNImputer:
589
507
  prefix = "zygosity" if len(labels) == 3 else "iupac"
590
508
  n_labels = len(labels)
591
509
 
592
- self.plotter_.plot_metrics(
593
- y_true=y_true,
594
- y_pred_proba=y_pred_proba,
595
- metrics=metrics,
596
- label_names=labels,
597
- prefix=f"geno{n_labels}_{prefix}",
510
+ if self.show_plots:
511
+ self.plotter_.plot_metrics(
512
+ y_true=y_true,
513
+ y_pred_proba=y_pred_proba,
514
+ metrics=metrics,
515
+ label_names=labels,
516
+ prefix=f"geno{n_labels}_{prefix}",
517
+ )
518
+ self.plotter_.plot_confusion_matrix(
519
+ y_true_1d=y_true,
520
+ y_pred_1d=y_pred,
521
+ label_names=labels,
522
+ prefix=f"geno{n_labels}_{prefix}",
523
+ )
524
+
525
+ def _additional_metrics(
526
+ self,
527
+ y_true: np.ndarray,
528
+ y_pred: np.ndarray,
529
+ labels: list[int],
530
+ report_names: list[str],
531
+ report: dict,
532
+ ) -> dict[str, dict[str, float] | float]:
533
+ """Compute additional metrics and augment the report dictionary.
534
+
535
+ Args:
536
+ y_true (np.ndarray): True genotypes.
537
+ y_pred (np.ndarray): Predicted genotypes.
538
+ labels (list[int]): List of label indices.
539
+ report_names (list[str]): List of report names corresponding to labels.
540
+ report (dict): Classification report dictionary to augment.
541
+
542
+ Returns:
543
+ dict[str, dict[str, float] | float]: Augmented report dictionary with additional metrics.
544
+ """
545
+ # Create an identity matrix and use the targets array as indices
546
+ y_score = np.eye(len(report_names))[y_pred]
547
+
548
+ # Per-class metrics
549
+ ap_pc = average_precision_score(y_true, y_score, average=None)
550
+ jaccard_pc = jaccard_score(
551
+ y_true, y_pred, average=None, labels=labels, zero_division=0
598
552
  )
599
- self.plotter_.plot_confusion_matrix(
600
- y_true_1d=y_true,
601
- y_pred_1d=y_pred,
602
- label_names=labels,
603
- prefix=f"geno{n_labels}_{prefix}",
553
+
554
+ # Macro/weighted metrics
555
+ ap_macro = average_precision_score(y_true, y_score, average="macro")
556
+ ap_weighted = average_precision_score(y_true, y_score, average="weighted")
557
+ jaccard_macro = jaccard_score(y_true, y_pred, average="macro", zero_division=0)
558
+ jaccard_weighted = jaccard_score(
559
+ y_true, y_pred, average="weighted", zero_division=0
604
560
  )
605
561
 
562
+ # Matthews correlation coefficient (MCC)
563
+ mcc = matthews_corrcoef(y_true, y_pred)
564
+
565
+ if not isinstance(ap_pc, np.ndarray):
566
+ msg = "average_precision_score or f1_score did not return np.ndarray as expected."
567
+ self.logger.error(msg)
568
+ raise TypeError(msg)
569
+
570
+ if not isinstance(jaccard_pc, np.ndarray):
571
+ msg = "jaccard_score did not return np.ndarray as expected."
572
+ self.logger.error(msg)
573
+ raise TypeError(msg)
574
+
575
+ # Add per-class metrics
576
+ report_full = {}
577
+ dd_subset = {
578
+ k: v for k, v in report.items() if k in report_names and isinstance(v, dict)
579
+ }
580
+ for i, class_name in enumerate(report_names):
581
+ class_report: dict[str, float] = {}
582
+ if class_name in dd_subset:
583
+ class_report = dd_subset[class_name]
584
+
585
+ if isinstance(class_report, float) or not class_report:
586
+ continue
587
+
588
+ report_full[class_name] = dict(class_report)
589
+ report_full[class_name]["average-precision"] = float(ap_pc[i])
590
+ report_full[class_name]["jaccard"] = float(jaccard_pc[i])
591
+
592
+ macro_avg = report.get("macro avg")
593
+ if isinstance(macro_avg, dict):
594
+ report_full["macro avg"] = dict(macro_avg)
595
+ report_full["macro avg"]["average-precision"] = float(ap_macro)
596
+ report_full["macro avg"]["jaccard"] = float(jaccard_macro)
597
+
598
+ weighted_avg = report.get("weighted avg")
599
+ if isinstance(weighted_avg, dict):
600
+ report_full["weighted avg"] = dict(weighted_avg)
601
+ report_full["weighted avg"]["average-precision"] = float(ap_weighted)
602
+ report_full["weighted avg"]["jaccard"] = float(jaccard_weighted)
603
+
604
+ # Add scalar summary metrics
605
+ report_full["mcc"] = float(mcc)
606
+ accuracy_val = report.get("accuracy")
607
+
608
+ if isinstance(accuracy_val, (int, float)):
609
+ report_full["accuracy"] = float(accuracy_val)
610
+
611
+ return report_full
612
+
606
613
  def _make_class_reports(
607
614
  self,
608
615
  y_true: np.ndarray,
@@ -620,27 +627,28 @@ class BaseNNImputer:
620
627
  y_pred (np.ndarray): Predicted labels (1D array).
621
628
  metrics (Dict[str, float]): Computed metrics.
622
629
  y_pred_proba (np.ndarray | None): Predicted probabilities (2D array). Defaults to None.
623
- labels (List[str]): Class label names
624
- (default: ["REF", "HET", "ALT"] for 3-class).
630
+ labels (List[str]): Class label names (default: ["REF", "HET", "ALT"] for 3-class).
625
631
  """
626
- report_name = "zygosity" if len(labels) == 3 else "iupac"
632
+ report_name = "zygosity" if len(labels) <= 3 else "iupac"
627
633
  middle = "IUPAC" if report_name == "iupac" else "Zygosity"
628
634
 
629
- msg = f"{middle} Report (on {y_true.size} total genotypes)"
635
+ msg = f"{middle} Report (on {y_pred.size} total genotypes)"
630
636
  self.logger.info(msg)
631
637
 
632
638
  if y_pred_proba is not None:
633
- self.plotter_.plot_metrics(
634
- y_true,
635
- y_pred_proba,
636
- metrics,
637
- label_names=labels,
638
- prefix=report_name,
639
- )
639
+ if self.show_plots:
640
+ self.plotter_.plot_metrics(
641
+ y_true,
642
+ y_pred_proba,
643
+ metrics,
644
+ label_names=labels,
645
+ prefix=report_name,
646
+ )
640
647
 
641
- self.plotter_.plot_confusion_matrix(
642
- y_true, y_pred, label_names=labels, prefix=report_name
643
- )
648
+ if self.show_plots:
649
+ self.plotter_.plot_confusion_matrix(
650
+ y_true, y_pred, label_names=labels, prefix=report_name
651
+ )
644
652
 
645
653
  report: str | dict = classification_report(
646
654
  y_true,
@@ -653,62 +661,63 @@ class BaseNNImputer:
653
661
 
654
662
  if not isinstance(report, dict):
655
663
  msg = "Expected classification_report to return a dict."
656
- self.logger.error(msg)
664
+ self.logger.error(msg, exc_info=True)
657
665
  raise ValueError(msg)
658
666
 
659
- report_subset = {}
660
- for k, v in report.items():
661
- tmp = {}
662
- if isinstance(v, dict) and "support" in v:
663
- for k2, v2 in v.items():
664
- if k2 != "support":
665
- tmp[k2] = v2
666
- if tmp:
667
- report_subset[k] = tmp
668
-
669
- if report_subset:
667
+ if self.show_plots:
668
+ viz = ClassificationReportVisualizer(reset_kwargs=self.plotter_.param_dict)
669
+ try:
670
+ plots = viz.plot_all(
671
+ report, # type: ignore
672
+ title_prefix=f"{self.model_name} {middle} Report",
673
+ show=self.show_plots,
674
+ heatmap_classes_only=True,
675
+ )
676
+ finally:
677
+ viz._reset_mpl_style()
678
+
679
+ for name, fig in plots.items():
680
+ fout = (
681
+ self.plots_dir / f"{report_name}_report_{name}.{self.plot_format}"
682
+ )
683
+ if hasattr(fig, "savefig") and isinstance(fig, Figure):
684
+ fig.savefig(fout, dpi=300, facecolor="#111122")
685
+ plt.close(fig)
686
+ elif hasattr(fig, "write_html") and isinstance(fig, go.Figure):
687
+ fout_html = fout.with_suffix(".html")
688
+ fig.write_html(file=fout_html)
689
+
690
+ SNPioMultiQC.queue_html(
691
+ fout_html,
692
+ panel_id=f"pgsui_{self.model_name.lower()}_{report_name}_radar",
693
+ section=f"PG-SUI: {self.model_name} Model Imputation",
694
+ title=f"{self.model_name} {middle} Radar Plot",
695
+ index_label=name,
696
+ description=f"{self.model_name} {middle} {len(labels)}-base Radar Plot. This radar plot visualizes model performance for three metrics per-class: precision, recall, and F1-score. Each axis represents one of these metrics, allowing for a quick visual assessment of the model's strengths and weaknesses. Higher values towards the outer edge indicate better performance.",
697
+ )
698
+
699
+ if not self.is_haploid_:
700
+ msg = f"Ploidy: {self.ploidy}. Evaluating per genotype (REF, HET, ALT)."
701
+ self.logger.info(msg)
702
+
703
+ report_full = self._additional_metrics(
704
+ y_true,
705
+ y_pred,
706
+ labels=list(range(len(labels))),
707
+ report_names=labels,
708
+ report=report,
709
+ )
710
+
711
+ if self.verbose or self.debug:
670
712
  pm = PrettyMetrics(
671
- report_subset,
672
- precision=3,
713
+ report_full,
714
+ precision=2,
673
715
  title=f"{self.model_name} {middle} Report",
674
716
  )
675
717
  pm.render()
676
718
 
677
719
  with open(self.metrics_dir / f"{report_name}_report.json", "w") as f:
678
- json.dump(report, f, indent=4)
679
-
680
- viz = ClassificationReportVisualizer(reset_kwargs=self.plotter_.param_dict)
681
-
682
- plots = viz.plot_all(
683
- report, # type: ignore
684
- title_prefix=f"{self.model_name} {middle} Report",
685
- show=getattr(self, "show_plots", False),
686
- heatmap_classes_only=True,
687
- )
688
-
689
- for name, fig in plots.items():
690
- fout = self.plots_dir / f"{report_name}_report_{name}.{self.plot_format}"
691
- if hasattr(fig, "savefig") and isinstance(fig, Figure):
692
- fig.savefig(fout, dpi=300, facecolor="#111122")
693
- plt.close(fig)
694
- elif hasattr(fig, "write_html") and isinstance(fig, go.Figure):
695
- fout_html = fout.with_suffix(".html")
696
- fig.write_html(file=fout_html)
697
-
698
- SNPioMultiQC.queue_html(
699
- fout_html,
700
- panel_id=f"pgsui_{self.model_name.lower()}_{report_name}_radar",
701
- section=f"PG-SUI: {self.model_name} Model Imputation",
702
- title=f"{self.model_name} {middle} Radar Plot",
703
- index_label=name,
704
- description=f"{self.model_name} {middle} {len(labels)}-base Radar Plot. This radar plot visualizes model performance for three metrics per-class: precision, recall, and F1-score. Each axis represents one of these metrics, allowing for a quick visual assessment of the model's strengths and weaknesses. Higher values towards the outer edge indicate better performance.",
705
- )
706
-
707
- if not self.is_haploid:
708
- msg = f"Ploidy: {self.ploidy}. Evaluating per allele."
709
- self.logger.info(msg)
710
-
711
- viz._reset_mpl_style()
720
+ json.dump(report_full, f, indent=4)
712
721
 
713
722
  def _compute_hidden_layer_sizes(
714
723
  self,
@@ -716,6 +725,7 @@ class BaseNNImputer:
716
725
  n_outputs: int,
717
726
  n_samples: int,
718
727
  n_hidden: int,
728
+ latent_dim: int,
719
729
  *,
720
730
  alpha: float = 4.0,
721
731
  schedule: str = "pyramid",
@@ -727,182 +737,439 @@ class BaseNNImputer:
727
737
  ) -> list[int]:
728
738
  """Compute hidden layer sizes given problem scale and a layer count.
729
739
 
730
- This method computes a list of hidden layer sizes based on the number of input features, output classes, training samples, and desired hidden layers. The sizes are determined using a specified schedule (pyramid, constant, or linear) and are constrained by minimum and maximum sizes, as well as rounding to multiples of a specified value.
740
+ Notes:
741
+ - Returns sizes for *hidden layers only* (length = n_hidden).
742
+ - Does NOT include the input layer (n_inputs) or the latent layer (latent_dim).
743
+ - Enforces a latent-aware minimum: one discrete level above latent_dim, where a level is `multiple_of`.
744
+ - Enforces *strictly decreasing* hidden sizes (no repeats). This may require bumping `base` upward.
731
745
 
732
746
  Args:
733
- n_inputs (int): Number of input features.
734
- n_outputs (int): Number of output classes.
735
- n_samples (int): Number of training samples.
736
- n_hidden (int): Number of hidden layers.
737
- alpha (float): Scaling factor for base layer size. Default is 4.0.
738
- schedule (Literal["pyramid", "constant", "linear"]): Size schedule. Default is "pyramid".
739
- min_size (int): Minimum layer size. Default is 16.
740
- max_size (int | None): Maximum layer size. Default is None (no limit).
741
- multiple_of (int): Round layer sizes to be multiples of this. Default is 8.
742
- decay (float | None): Decay factor for "pyramid" schedule. If None, it is computed automatically. Default is None.
743
- cap_by_inputs (bool): If True, cap layer sizes to n_inputs. Default is True.
747
+ n_inputs: Number of input features (e.g., flattened one-hot: num_features * num_classes).
748
+ n_outputs: Number of output classes (often equals num_classes).
749
+ n_samples: Number of training samples.
750
+ n_hidden: Number of hidden layers (excluding input and latent layers).
751
+ latent_dim: Latent dimensionality (not returned, used only to set a floor).
752
+ alpha: Scaling factor for base layer size.
753
+ schedule: Size schedule ("pyramid" or "linear").
754
+ min_size: Minimum layer size floor before latent-aware adjustment.
755
+ max_size: Maximum layer size cap. If None, a heuristic cap is used.
756
+ multiple_of: Hidden sizes are multiples of this value.
757
+ decay: Pyramid decay factor. If None, computed to land near the target.
758
+ cap_by_inputs: If True, cap layer sizes to n_inputs.
744
759
 
745
760
  Returns:
746
- list[int]: List of hidden layer sizes.
761
+ list[int]: Hidden layer sizes (len = n_hidden).
747
762
 
748
763
  Raises:
749
- ValueError: If n_hidden < 0 or if alpha * (n_inputs + n_outputs) <= 0 or if schedule is unknown.
750
- TypeError: If any argument is not of the expected type.
751
-
752
- Notes:
753
- - If n_hidden is 0, returns an empty list.
754
- - The base layer size is computed as ceil(n_samples / (alpha * (n_inputs + n_outputs))).
755
- - The sizes are adjusted according to the specified schedule and constraints.
764
+ ValueError: On invalid arguments or conflicting constraints.
756
765
  """
766
+ # ----------------------------
767
+ # Basic validation
768
+ # ----------------------------
757
769
  if n_hidden < 0:
758
770
  msg = f"n_hidden must be >= 0, got {n_hidden}."
759
771
  self.logger.error(msg)
760
772
  raise ValueError(msg)
761
773
 
762
- if schedule not in {"pyramid", "constant", "linear"}:
763
- msg = f"Unknown schedule '{schedule}'. Use 'pyramid', 'constant', or 'linear'."
774
+ if n_hidden == 0:
775
+ return []
776
+
777
+ if n_inputs <= 0:
778
+ msg = f"n_inputs must be > 0, got {n_inputs}."
764
779
  self.logger.error(msg)
765
780
  raise ValueError(msg)
766
781
 
767
- if n_hidden == 0:
768
- return []
782
+ if n_outputs <= 0:
783
+ msg = f"n_outputs must be > 0, got {n_outputs}."
784
+ self.logger.error(msg)
785
+ raise ValueError(msg)
769
786
 
770
- denom = float(alpha) * float(n_inputs + n_outputs)
787
+ if n_samples <= 0:
788
+ msg = f"n_samples must be > 0, got {n_samples}."
789
+ self.logger.error(msg)
790
+ raise ValueError(msg)
771
791
 
772
- if denom <= 0:
773
- msg = f"alpha * (n_inputs + n_outputs) must be > 0, got {denom}."
792
+ if latent_dim <= 0:
793
+ msg = f"latent_dim must be > 0, got {latent_dim}."
774
794
  self.logger.error(msg)
775
795
  raise ValueError(msg)
776
796
 
777
- base = int(np.ceil(float(n_samples) / denom))
797
+ if multiple_of <= 0:
798
+ msg = f"multiple_of must be > 0, got {multiple_of}."
799
+ self.logger.error(msg)
800
+ raise ValueError(msg)
801
+
802
+ if alpha <= 0:
803
+ msg = f"alpha must be > 0, got {alpha}."
804
+ self.logger.error(msg)
805
+ raise ValueError(msg)
806
+
807
+ schedule = str(schedule).lower().strip()
808
+ if schedule not in {"pyramid", "linear"}:
809
+ msg = f"Invalid schedule '{schedule}'. Must be 'pyramid' or 'linear'."
810
+ self.logger.error(msg)
811
+ raise ValueError(msg)
812
+
813
+ # ----------------------------
814
+ # Latent-aware minimum floor
815
+ # ----------------------------
816
+ # Smallest multiple_of strictly greater than latent_dim
817
+ min_hidden_floor = int(np.ceil((latent_dim + 1) / multiple_of) * multiple_of)
818
+ effective_min = max(int(min_size), min_hidden_floor)
819
+
820
+ if cap_by_inputs and n_inputs < effective_min:
821
+ msg = (
822
+ "Cannot satisfy latent-aware minimum hidden size with cap_by_inputs=True. "
823
+ f"Required hidden size >= {effective_min} (one level above latent_dim={latent_dim}), "
824
+ f"but n_inputs={n_inputs}. Set cap_by_inputs=False or reduce latent_dim/multiple_of."
825
+ )
826
+ self.logger.error(msg)
827
+ raise ValueError(msg)
828
+
829
+ # ----------------------------
830
+ # Infer num_features (if using flattened one-hot: n_inputs = num_features * num_classes)
831
+ # ----------------------------
832
+ if n_inputs % n_outputs == 0:
833
+ num_features = n_inputs // n_outputs
834
+ else:
835
+ num_features = n_inputs
836
+ self.logger.warning(
837
+ "n_inputs is not divisible by n_outputs; falling back to num_features=n_inputs "
838
+ f"(n_inputs={n_inputs}, n_outputs={n_outputs}). If using one-hot flattening, "
839
+ "pass n_outputs=num_classes so num_features can be inferred correctly."
840
+ )
778
841
 
842
+ # ----------------------------
843
+ # Base size heuristic (feature-matrix aware; avoids collapse for huge n_inputs)
844
+ # ----------------------------
845
+ obs_scale = (float(n_samples) * float(num_features)) / float(
846
+ num_features + n_outputs
847
+ )
848
+ base = int(np.ceil(float(alpha) * np.sqrt(obs_scale)))
849
+
850
+ # ----------------------------
851
+ # Determine max_size
852
+ # ----------------------------
779
853
  if max_size is None:
780
- max_size = max(n_inputs, base)
854
+ max_size = max(int(n_inputs), int(base), int(effective_min))
855
+
856
+ if cap_by_inputs:
857
+ max_size = min(int(max_size), int(n_inputs))
858
+ else:
859
+ max_size = int(max_size)
860
+
861
+ if max_size < effective_min:
862
+ msg = (
863
+ f"max_size ({max_size}) must be >= effective_min ({effective_min}), where effective_min "
864
+ f"is max(min_size={min_size}, one-level-above latent_dim={latent_dim})."
865
+ )
866
+ self.logger.error(msg)
867
+ raise ValueError(msg)
868
+
869
+ # Round base up to a multiple and clip to bounds
870
+ base = int(np.clip(base, effective_min, max_size))
871
+ base = int(np.ceil(base / multiple_of) * multiple_of)
872
+ base = int(np.clip(base, effective_min, max_size))
873
+
874
+ # ----------------------------
875
+ # Enforce "no repeats" feasibility in discrete levels
876
+ # Need n_hidden distinct multiples between base and effective_min:
877
+ # base >= effective_min + (n_hidden - 1) * multiple_of
878
+ # ----------------------------
879
+ required_min_base = effective_min + (n_hidden - 1) * multiple_of
880
+
881
+ if required_min_base > max_size:
882
+ msg = (
883
+ "Cannot build strictly-decreasing (no-repeat) hidden sizes under current constraints. "
884
+ f"Need base >= {required_min_base} to fit n_hidden={n_hidden} distinct layers "
885
+ f"with multiple_of={multiple_of} down to effective_min={effective_min}, "
886
+ f"but max_size={max_size}. Reduce n_hidden, reduce multiple_of, lower latent_dim/min_size, "
887
+ "or increase max_size / set cap_by_inputs=False."
888
+ )
889
+ self.logger.error(msg)
890
+ raise ValueError(msg)
781
891
 
782
- base = int(np.clip(base, min_size, max_size))
892
+ if base < required_min_base:
893
+ # Bump base upward so a strict staircase is possible
894
+ base = required_min_base
895
+ base = int(np.ceil(base / multiple_of) * multiple_of)
896
+ base = int(np.clip(base, effective_min, max_size))
783
897
 
784
- if schedule == "constant":
785
- sizes = np.full(shape=(n_hidden,), fill_value=base, dtype=float)
898
+ # Work in "levels" of multiple_of for guaranteed uniqueness
899
+ start_level = base // multiple_of
900
+ end_level = effective_min // multiple_of
901
+
902
+ # Sanity: distinct levels available
903
+ if (start_level - end_level) < (n_hidden - 1):
904
+ # This should not happen due to required_min_base logic, but keep a hard guard.
905
+ msg = (
906
+ "Internal constraint failure: insufficient discrete levels to enforce no repeats. "
907
+ f"start_level={start_level}, end_level={end_level}, n_hidden={n_hidden}."
908
+ )
909
+ self.logger.error(msg)
910
+ raise ValueError(msg)
911
+
912
+ # ----------------------------
913
+ # Build schedule in level space (integers), then convert to sizes
914
+ # ----------------------------
915
+ if n_hidden == 1:
916
+ levels = np.array([start_level], dtype=int)
786
917
 
787
918
  elif schedule == "linear":
788
- target = max(min_size, min(base, base // 4))
789
- sizes = (
790
- np.array([base], dtype=float)
791
- if n_hidden == 1
792
- else np.linspace(base, target, num=n_hidden, dtype=float)
919
+ # Linear interpolation in level space, then strictify
920
+ levels = np.round(np.linspace(start_level, end_level, num=n_hidden)).astype(
921
+ int
793
922
  )
794
923
 
924
+ # Enforce bounds then strict decrease
925
+ levels = np.clip(levels, end_level, start_level)
926
+
927
+ for i in range(1, n_hidden):
928
+ if levels[i] >= levels[i - 1]:
929
+ levels[i] = levels[i - 1] - 1
930
+
931
+ if levels[-1] < end_level:
932
+ msg = (
933
+ "Failed to enforce strictly-decreasing linear schedule without violating the floor. "
934
+ f"(levels[-1]={levels[-1]} < end_level={end_level}). "
935
+ "Reduce n_hidden or multiple_of, or increase max_size."
936
+ )
937
+ self.logger.error(msg)
938
+ raise ValueError(msg)
939
+
940
+ # Force exact floor at the end (still strict because we have enough room by construction)
941
+ levels[-1] = end_level
942
+ for i in range(n_hidden - 2, -1, -1):
943
+ if levels[i] <= levels[i + 1]:
944
+ levels[i] = levels[i + 1] + 1
945
+
946
+ if levels[0] > start_level:
947
+ # If this happens, we would need an even larger base; handle by raising base once.
948
+ needed_base = int(levels[0] * multiple_of)
949
+ if needed_base > max_size:
950
+ msg = (
951
+ "Cannot enforce strictly-decreasing linear schedule after floor anchoring; "
952
+ f"would require base={needed_base} > max_size={max_size}."
953
+ )
954
+ self.logger.error(msg)
955
+ raise ValueError(msg)
956
+ # Rebuild with bumped base
957
+ start_level = needed_base // multiple_of
958
+ levels = np.arange(start_level, start_level - n_hidden, -1, dtype=int)
959
+ levels[-1] = end_level # keep floor
960
+ # Ensure strict with backward adjust
961
+ for i in range(n_hidden - 2, -1, -1):
962
+ if levels[i] <= levels[i + 1]:
963
+ levels[i] = levels[i + 1] + 1
964
+
795
965
  elif schedule == "pyramid":
796
- if n_hidden == 1:
797
- sizes = np.array([base], dtype=float)
966
+ # Geometric decay in level space (more aggressive early taper than linear)
967
+ if decay is not None:
968
+ dcy = float(decay)
798
969
  else:
970
+ # Choose decay to land exactly at end_level (in float space)
971
+ dcy = (float(end_level) / float(start_level)) ** (
972
+ 1.0 / float(n_hidden - 1)
973
+ )
974
+
975
+ # Keep it in a sensible range
976
+ dcy = float(np.clip(dcy, 0.05, 0.99))
977
+
978
+ exponents = np.arange(n_hidden, dtype=float)
979
+ levels_float = float(start_level) * (dcy**exponents)
980
+
981
+ levels = np.round(levels_float).astype(int)
982
+ levels = np.clip(levels, end_level, start_level)
983
+
984
+ # Anchor the last layer at the floor, then strictify backward
985
+ levels[-1] = end_level
986
+ for i in range(n_hidden - 2, -1, -1):
987
+ if levels[i] <= levels[i + 1]:
988
+ levels[i] = levels[i + 1] + 1
989
+
990
+ # If we overshot the start, bump base (once) if possible, then rebuild
991
+ if levels[0] > start_level:
992
+ needed_base = int(levels[0] * multiple_of)
993
+ if needed_base > max_size:
994
+ msg = (
995
+ "Cannot enforce strictly-decreasing pyramid schedule; "
996
+ f"would require base={needed_base} > max_size={max_size}."
997
+ )
998
+ self.logger.error(msg)
999
+ raise ValueError(msg)
1000
+
1001
+ start_level = needed_base // multiple_of
1002
+ # Recompute with new start_level and same decay (or recompute decay if decay is None)
799
1003
  if decay is None:
800
- target = max(min_size, base // 4)
801
- if base <= 0 or target <= 0:
802
- dcy = 1.0
803
- else:
804
- dcy = (target / float(base)) ** (1.0 / (n_hidden - 1))
805
- dcy = float(np.clip(dcy, 0.25, 0.99))
806
- exponents = np.arange(n_hidden, dtype=float)
807
- sizes = base * (dcy**exponents)
1004
+ dcy = (float(end_level) / float(start_level)) ** (
1005
+ 1.0 / float(n_hidden - 1)
1006
+ )
1007
+ dcy = float(np.clip(dcy, 0.05, 0.99))
1008
+
1009
+ levels_float = float(start_level) * (dcy**exponents)
1010
+ levels = np.round(levels_float).astype(int)
1011
+ levels = np.clip(levels, end_level, start_level)
1012
+ levels[-1] = end_level
1013
+ for i in range(n_hidden - 2, -1, -1):
1014
+ if levels[i] <= levels[i + 1]:
1015
+ levels[i] = levels[i + 1] + 1
808
1016
 
809
1017
  else:
810
- msg = f"Unknown schedule '{schedule}'. Use 'pyramid', 'constant', or 'linear'."
1018
+ msg = f"Unknown schedule '{schedule}'. Use 'pyramid' or 'linear' (constant disallowed with no repeats)."
811
1019
  self.logger.error(msg)
812
1020
  raise ValueError(msg)
813
1021
 
814
- sizes = np.clip(sizes, min_size, max_size)
1022
+ # Convert levels -> sizes
1023
+ sizes = (levels * multiple_of).astype(int)
815
1024
 
816
- if cap_by_inputs:
817
- sizes = np.minimum(sizes, float(n_inputs))
1025
+ # Final clip (should be redundant, but safe)
1026
+ sizes = np.clip(sizes, effective_min, max_size).astype(int)
818
1027
 
819
- sizes = (np.ceil(sizes / multiple_of) * multiple_of).astype(int)
820
- sizes = np.minimum.accumulate(sizes)
821
- return np.clip(sizes, min_size, max_size).astype(int).tolist()
1028
+ # Final strict no-repeat assertion
1029
+ if np.any(np.diff(sizes) >= 0):
1030
+ msg = (
1031
+ "Internal error: produced non-decreasing or repeated hidden sizes after strict enforcement. "
1032
+ f"sizes={sizes.tolist()}"
1033
+ )
1034
+ self.logger.error(msg)
1035
+ raise ValueError(msg)
822
1036
 
823
- def _class_weights_from_zygosity(self, X: np.ndarray) -> torch.Tensor:
824
- """Class-balanced weights for 0/1/2 (handles haploid collapse if needed).
1037
+ return sizes.tolist()
825
1038
 
826
- This method computes class-balanced weights for the genotype classes (0/1/2) based on the provided genotype matrix. It handles cases where the data is haploid by collapsing the ALT class to 1, effectively treating the problem as binary classification (REF vs ALT). The weights are calculated using a class-balanced weighting scheme that considers the frequency of each class in the training data, with parameters for beta and maximum ratio to control the weighting behavior. The resulting weights are returned as a PyTorch tensor on the current device.
1039
+ def _class_weights_from_zygosity(
1040
+ self,
1041
+ X: np.ndarray,
1042
+ train_mask: Optional[np.ndarray] = None,
1043
+ *,
1044
+ inverse: bool = False,
1045
+ normalize: bool = False,
1046
+ power: float = 1.0,
1047
+ max_ratio: float | None = None,
1048
+ eps: float = 1e-12,
1049
+ ) -> torch.Tensor:
1050
+ """Compute class weights for zygosity labels.
827
1051
 
828
- Args:
829
- X (np.ndarray): 0/1/2 with -1 for missing.
1052
+ If inverse=False (default):
1053
+ w_c = N / (K * n_c) ("balanced")
1054
+
1055
+ If inverse=True:
1056
+ w_c = N / n_c (same ratios, scaled by K)
1057
+
1058
+ If power != 1.0:
1059
+ w_c <- w_c ** power (amplifies or softens imbalance handling)
1060
+
1061
+ If normalize=True:
1062
+ rescales nonzero weights so mean(nonzero_weights) == 1.
830
1063
 
831
1064
  Returns:
832
- torch.Tensor: Weights on current device.
1065
+ torch.Tensor: Class weights of shape (num_classes,) on self.device.
833
1066
  """
834
- y = X[X != -1].ravel().astype(np.int64)
835
- if y.size == 0:
836
- return torch.ones(
837
- self.num_classes_, dtype=torch.float32, device=self.device
838
- )
1067
+ y = np.asarray(X).ravel().astype(np.int8)
839
1068
 
840
- return self._class_balanced_weights_from_mask(
841
- y=y,
842
- train_mask=np.ones_like(y, dtype=bool),
843
- num_classes=self.num_classes_,
844
- beta=self.beta,
845
- max_ratio=self.max_ratio,
846
- mode="allele", # 1D int vector
847
- ).to(self.device)
1069
+ m = y >= 0
1070
+ if train_mask is not None:
1071
+ tm = np.asarray(train_mask, dtype=bool).ravel()
1072
+ if tm.shape != y.shape:
1073
+ msg = "train_mask must have the same shape as X."
1074
+ self.logger.error(msg)
1075
+ raise ValueError(msg)
1076
+ m &= tm
1077
+
1078
+ is_hap = bool(getattr(self, "is_haploid_", False))
1079
+ num_classes = 2 if is_hap else int(self.num_classes_)
1080
+
1081
+ if not np.any(m):
1082
+ return torch.ones(num_classes, dtype=torch.long, device=self.device)
1083
+
1084
+ if is_hap:
1085
+ y = y.copy()
1086
+ y[(y == 2) & m] = 1
1087
+
1088
+ y_m = y[m]
1089
+ if y_m.size:
1090
+ ymin = int(y_m.min())
1091
+ ymax = int(y_m.max())
1092
+ if ymin < 0 or ymax >= num_classes:
1093
+ msg = (
1094
+ f"Found out-of-range labels under mask: min={ymin}, max={ymax}, "
1095
+ f"expected in [0, {num_classes - 1}]."
1096
+ )
1097
+ self.logger.error(msg)
1098
+ raise ValueError(msg)
848
1099
 
849
- @staticmethod
850
- def _normalize_class_weights(
851
- weights: torch.Tensor | None,
852
- ) -> torch.Tensor | None:
853
- """Normalize class weights once to keep loss scale stable.
1100
+ counts = np.bincount(y_m, minlength=num_classes).astype(np.float32)
1101
+ N = float(counts.sum())
1102
+ K = float(num_classes)
854
1103
 
855
- Args:
856
- weights (torch.Tensor | None): Class weights to normalize.
1104
+ w = np.zeros(num_classes, dtype=np.float32)
1105
+ nz = counts > 0
857
1106
 
858
- Returns:
859
- torch.Tensor | None: Normalized class weights or None if input is None.
860
- """
861
- if weights is None:
862
- return None
863
- return weights / weights.mean().clamp_min(1e-8)
1107
+ if np.any(nz):
1108
+ if inverse:
1109
+ w[nz] = N / (counts[nz] + eps)
1110
+ else:
1111
+ w[nz] = N / (K * (counts[nz] + eps))
864
1112
 
865
- def _get_float_genotypes(self, *, copy: bool = True) -> np.ndarray:
866
- """Float32 0/1/2 matrix with NaNs for missing, cached per dataset.
1113
+ # Amplify / soften class contrast
1114
+ if power <= 0.0:
1115
+ msg = "power must be > 0."
1116
+ self.logger.error(msg)
1117
+ raise ValueError(msg)
1118
+ if power != 1.0:
1119
+ w[nz] = np.power(w[nz], power)
867
1120
 
868
- Args:
869
- copy (bool): If True, return a copy of the cached array. Default is True.
1121
+ if np.any(~nz):
1122
+ self.logger.warning(
1123
+ "Some classes have zero count under the provided mask: "
1124
+ f"{np.where(~nz)[0].tolist()}. Setting their weights to 0."
1125
+ )
870
1126
 
871
- Returns:
872
- np.ndarray: Float32 genotype matrix with NaNs for missing values.
873
- """
874
- cache = self._float_genotype_cache
875
- current = self.pgenc.genotypes_012
876
- if cache is None or cache.shape != current.shape or cache.dtype != np.float32:
877
- arr = np.asarray(current, dtype=np.float32)
878
- arr = np.where(arr < 0, np.nan, arr)
879
- self._float_genotype_cache = arr
880
- cache = arr
881
- return cache.copy() if copy else cache
882
-
883
- def _sim_mask_cache_key(self) -> tuple | None:
884
- """Key for caching simulated-missing masks."""
885
- if not getattr(self, "simulate_missing", False):
886
- return None
887
- shape = tuple(self.pgenc.genotypes_012.shape)
888
- return (
889
- id(self.genotype_data),
890
- self.sim_strategy,
891
- round(float(self.sim_prop), 6),
892
- self.seed,
893
- shape,
1127
+ # Cap ratio among observed classes
1128
+ if max_ratio is not None and np.any(nz):
1129
+ cap = float(max_ratio)
1130
+ if cap <= 1.0:
1131
+ msg = "max_ratio must be > 1.0 or None."
1132
+ self.logger.error(msg)
1133
+ raise ValueError(msg)
1134
+
1135
+ wmin = max(float(w[nz].min()), eps)
1136
+ wmax = wmin * cap
1137
+ w[nz] = np.clip(w[nz], wmin, wmax)
1138
+
1139
+ # Optional normalization: mean(nonzero) -> 1.0
1140
+ if normalize and np.any(nz):
1141
+ mean_nz = float(w[nz].mean())
1142
+ if mean_nz > 0.0:
1143
+ w[nz] /= mean_nz
1144
+ else:
1145
+ self.logger.warning(
1146
+ "normalize=True requested, but mean of nonzero weights is not positive; skipping normalization."
1147
+ )
1148
+
1149
+ self.logger.debug(f"Class counts: {counts.astype(np.int8)}")
1150
+ self.logger.debug(
1151
+ f"Class weights (inverse={inverse}, power={power}, normalize={normalize}): {w}"
894
1152
  )
895
1153
 
896
- def _one_hot_encode_012(self, X: np.ndarray | torch.Tensor) -> torch.Tensor:
897
- """One-hot 0/1/2; -1 rows are all-zeros (B, L, K).
1154
+ return torch.as_tensor(w, dtype=torch.long, device=self.device)
898
1155
 
899
- This method performs one-hot encoding of the input genotype data (0, 1, 2) while handling missing values represented by -1. The output is a tensor of shape (B, L, K), where B is the batch size, L is the number of features, and K is the number of classes.
1156
+ def _one_hot_encode_012(
1157
+ self, X: np.ndarray | torch.Tensor, num_classes: int | None
1158
+ ) -> torch.Tensor:
1159
+ """One-hot encode genotype calls. Missing inputs (<0) result in a vector of -1s.
900
1160
 
901
1161
  Args:
902
- X (np.ndarray | torch.Tensor): The input data to be one-hot encoded, either as a NumPy array or a PyTorch tensor.
1162
+ X (np.ndarray | torch.Tensor): Input genotype calls as integers (0,1, 2, etc.).
1163
+ num_classes (int | None): Number of classes (K). If None, uses self.num_classes_.
903
1164
 
904
1165
  Returns:
905
- torch.Tensor: A one-hot encoded tensor of shape (B, L, K), where B is the batch size, L is the number of features, and K is the number of classes.
1166
+ torch.Tensor: One-hot encoded tensor of shape (B, L, K) with dtype
1167
+ ``torch.long``. Valid positions hold one-hot 0/1 values; missing positions are filled with -1.
1168
+
1169
+ Notes:
1170
+ - Valid classes must be integers in [0, K-1]
1171
+ - Missing is any value < 0; these positions become [-1, -1, ..., -1]
1172
+ - If K==2 and values are in {0,2} (no 1s), map 2->1.
906
1173
  """
907
1174
  Xt = (
908
1175
  torch.from_numpy(X).to(self.device)
@@ -910,212 +1177,680 @@ class BaseNNImputer:
910
1177
  else X.to(self.device)
911
1178
  )
912
1179
 
913
- # B=batch, L=features, K=classes
1180
+ # Make sure we have integer class labels
1181
+ if Xt.dtype.is_floating_point:
1182
+ # Convert NaN -> -1 and cast to long
1183
+ Xt = torch.nan_to_num(Xt, nan=-1.0).long()
1184
+ else:
1185
+ Xt = Xt.long()
1186
+
914
1187
  B, L = Xt.shape
915
- K = self.num_classes_
916
- X_ohe = torch.zeros(B, L, K, dtype=torch.float32, device=self.device)
917
- valid = Xt != -1
918
- idx = Xt[valid].long()
1188
+ K = int(num_classes) if num_classes is not None else int(self.num_classes_)
1189
+
1190
+ # Missing is anything < 0 (covers -1, -9, etc.)
1191
+ valid = Xt >= 0
1192
+
1193
+ # If binary mode but data is {0,2}
1194
+ # (haploid-like or "ref vs non-ref"), map 2->1
1195
+ if K == 2:
1196
+ has_het = torch.any(valid & (Xt == 1))
1197
+ has_alt2 = torch.any(valid & (Xt == 2))
1198
+ if has_alt2 and not has_het:
1199
+ Xt = Xt.clone()
1200
+ Xt[valid & (Xt == 2)] = 1
1201
+
1202
+ # Now enforce the one-hot precondition
1203
+ if torch.any(valid & (Xt >= K)):
1204
+ bad_vals = torch.unique(Xt[valid & (Xt >= K)]).detach().cpu().tolist()
1205
+ all_vals = torch.unique(Xt[valid]).detach().cpu().tolist()
1206
+ msg = f"_one_hot_encode_012 received class values outside [0, {K-1}]. num_classes={K}, offending_values={bad_vals}, observed_values={all_vals}. Upstream encoding mismatch (e.g., passing 0/1/2 with num_classes=2)."
1207
+ self.logger.error(msg)
1208
+ raise ValueError(msg)
1209
+
1210
+ # CHANGE: Initialize with -1 so that missing values are represented as [-1, -1, ..., -1]
1211
+ X_ohe = torch.full((B, L, K), -1.0, dtype=torch.long, device=self.device)
1212
+
1213
+ idx = Xt[valid]
919
1214
 
920
1215
  if idx.numel() > 0:
921
- X_ohe[valid] = F.one_hot(idx, num_classes=K).float()
1216
+ # Overwrite valid positions (which were -1) with the correct one-hot vectors
1217
+ X_ohe[valid] = F.one_hot(idx, num_classes=K).long()
922
1218
 
923
1219
  return X_ohe
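+ # A minimal sketch of the encoding above (``imp`` is a hypothetical fitted
+ # instance; values are illustrative):
+ # >>> import numpy as np
+ # >>> calls = np.array([[0, 2, -1]])
+ # >>> imp._one_hot_encode_012(calls, num_classes=3).tolist()
+ # [[[1, 0, 0], [0, 0, 1], [-1, -1, -1]]]
+ # With num_classes=2 and calls drawn only from {0, 2}, the 2s are remapped
+ # to 1 first, giving rows [1, 0], [0, 1], and [-1, -1] instead.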
924
1220
 
925
- def _eval_for_pruning(
926
- self,
927
- *,
928
- model: torch.nn.Module,
929
- X_val: np.ndarray,
930
- params: dict,
931
- metric: str,
932
- objective_mode: bool = True,
933
- do_latent_infer: bool = False,
934
- latent_steps: int = 50,
935
- latent_lr: float = 1e-2,
936
- latent_weight_decay: float = 0.0,
937
- latent_seed: int = 123,
938
- _latent_cache: dict | None = None,
939
- _latent_cache_key: str | None = None,
940
- eval_mask_override: np.ndarray | None = None,
941
- ) -> float:
942
- """Compute a scalar metric (to MAXIMIZE) on a fixed validation set.
943
-
944
- This method evaluates the model on a validation dataset and computes a specified metric, which is used for pruning decisions during hyperparameter tuning. It supports optional latent inference to optimize latent representations before evaluation. The method handles potential issues with non-finite metric values by returning negative infinity, making it easier to prune poorly performing trials.
1221
+ def decode_012(
1222
+ self, X: np.ndarray | pd.DataFrame | list[list[int]], is_nuc: bool = False
1223
+ ) -> np.ndarray:
1224
+ """Decode 012-encodings to IUPAC chars with metadata repair.
1225
+
1226
+ This method converts genotype calls encoded as integers (0, 1, 2, etc.) into their corresponding IUPAC nucleotide codes. It supports two modes of decoding:
1227
+ 1. Nucleotide mode (`is_nuc=True`): Decodes integer codes (0-9) directly to IUPAC nucleotide codes.
1228
+ 2. Metadata mode (`is_nuc=False`): Uses reference and alternate allele metadata to determine the appropriate IUPAC codes. If metadata is missing or inconsistent, the method attempts to repair the decoding by scanning the source SNP data for valid IUPAC codes.
945
1229
 
946
1230
  Args:
947
- model (torch.nn.Module): The model to evaluate.
948
- X_val (np.ndarray): Validation data.
949
- params (dict): Model parameters.
950
- metric (str): Metric name to return.
951
- objective_mode (bool): If True, use objective-mode evaluation. Default is True.
952
- do_latent_infer (bool): If True, perform latent inference before evaluation. Default
953
- latent_steps (int): Number of steps for latent inference. Default is 50.
954
- latent_lr (float): Learning rate for latent inference. Default is 1e-2
955
- latent_weight_decay (float): Weight decay for latent inference. Default is 0.0.
956
- latent_seed (int): Random seed for latent inference. Default is 123.
957
- _latent_cache (dict | None): Optional cache for storing/retrieving optimized latents
958
- _latent_cache_key (str | None): Key for storing/retrieving in _latent_cache.
959
- eval_mask_override (np.ndarray | None): Optional mask to override default evaluation mask.
1231
+ X (np.ndarray | pd.DataFrame | list[list[int]]): Input genotype calls as integers. Can be a NumPy array, Pandas DataFrame, or nested list.
1232
+ is_nuc (bool): If True, decode 0-9 nucleotide codes; else use ref/alt metadata. Defaults to False.
960
1233
 
961
1234
  Returns:
962
- float: The computed metric value to maximize. Returns -inf on failure.
1235
+ np.ndarray: IUPAC strings as a 2D array of shape (n_samples, n_snps).
1236
+
1237
+ Notes:
1238
+ - The method normalizes input values to handle various formats, including strings, lists, and arrays.
1239
+ - It uses a predefined mapping of IUPAC codes to nucleotide bases and vice versa.
1240
+ - Missing or invalid codes are represented as 'N' if they can't be resolved.
1241
+ - The method includes repair logic to infer missing metadata from the source SNP data when necessary.
1242
+
1243
+ Raises:
1244
+ ValueError: If the input cannot be coerced to a pandas DataFrame.
963
1245
  """
964
- optimized_val_latents = None
965
-
966
- # Optional latent inference path for models that need it.
967
- if do_latent_infer and hasattr(self, "_latent_infer_for_eval"):
968
- optimized_val_latents = self._latent_infer_for_eval( # type: ignore
969
- model=model,
970
- X_val=X_val,
971
- steps=latent_steps,
972
- lr=latent_lr,
973
- weight_decay=latent_weight_decay,
974
- seed=latent_seed,
975
- cache=_latent_cache,
976
- cache_key=_latent_cache_key,
1246
+ df = validate_input_type(X, return_type="df")
1247
+
1248
+ if not isinstance(df, pd.DataFrame):
1249
+ msg = f"Expected a pandas.DataFrame in 'decode_012', but got: {type(df)}."
1250
+ self.logger.error(msg)
1251
+ raise ValueError(msg)
1252
+
1253
+ # IUPAC Definitions
1254
+ iupac_to_bases: dict[str, set[str]] = {
1255
+ "A": {"A"},
1256
+ "C": {"C"},
1257
+ "G": {"G"},
1258
+ "T": {"T"},
1259
+ "R": {"A", "G"},
1260
+ "Y": {"C", "T"},
1261
+ "S": {"G", "C"},
1262
+ "W": {"A", "T"},
1263
+ "K": {"G", "T"},
1264
+ "M": {"A", "C"},
1265
+ "B": {"C", "G", "T"},
1266
+ "D": {"A", "G", "T"},
1267
+ "H": {"A", "C", "T"},
1268
+ "V": {"A", "C", "G"},
1269
+ "N": set(),
1270
+ }
1271
+ bases_to_iupac = {
1272
+ frozenset(v): k for k, v in iupac_to_bases.items() if k != "N"
1273
+ }
1274
+ missing_codes = {"", ".", "N", "NONE", "-", "?", "./.", ".|.", "NAN", "nan"}
1275
+
1276
+ def _normalize_iupac(value: object) -> str | None:
1277
+ """Normalize an input into a single IUPAC code token or None."""
1278
+ if value is None:
1279
+ return None
1280
+
1281
+ # Bytes -> str (make type narrowing explicit)
1282
+ if isinstance(value, (bytes, np.bytes_)):
1283
+ value = bytes(value).decode("utf-8", errors="ignore")
1284
+
1285
+ # Handle list/tuple/array/Series: take first valid
1286
+ if isinstance(value, (list, tuple, pd.Series, np.ndarray)):
1287
+ # Convert Series to numpy array for consistent behavior
1288
+ if isinstance(value, pd.Series):
1289
+ arr = value.to_numpy()
1290
+ else:
1291
+ arr = value
1292
+
1293
+ # Scalar numpy array fast path
1294
+ if isinstance(arr, np.ndarray) and arr.ndim == 0:
1295
+ return _normalize_iupac(arr.item())
1296
+
1297
+ # Empty sequence/array
1298
+ if len(arr) == 0:
1299
+ return None
1300
+
1301
+ # First valid element wins
1302
+ for item in arr:
1303
+ code = _normalize_iupac(item)
1304
+ if code is not None:
1305
+ return code
1306
+ return None
1307
+
1308
+ s = str(value).upper().strip()
1309
+ if not s or s in missing_codes:
1310
+ return None
1311
+
1312
+ if "," in s:
1313
+ for tok in (t.strip() for t in s.split(",")):
1314
+ if tok and tok not in missing_codes and tok in iupac_to_bases:
1315
+ return tok
1316
+ return None
1317
+
1318
+ return s if s in iupac_to_bases else None
1319
+
1320
+ codes_df = df.apply(pd.to_numeric, errors="coerce")
1321
+ codes = codes_df.fillna(-1).astype(np.int8).to_numpy()
1322
+ n_rows, n_cols = codes.shape
1323
+
1324
+ if is_nuc:
1325
+ iupac_list = np.array(
1326
+ ["A", "C", "G", "T", "W", "R", "M", "K", "Y", "S"], dtype="<U1"
977
1327
  )
978
- # Retrieve the optimized latents from the cache
979
- if _latent_cache is not None and _latent_cache_key in _latent_cache:
980
- optimized_val_latents = _latent_cache[_latent_cache_key]
981
-
982
- if getattr(self, "_tune_eval_slice", None) is not None:
983
- X_val = X_val[self._tune_eval_slice]
984
- if eval_mask_override is not None:
985
- eval_mask_override = eval_mask_override[self._tune_eval_slice]
986
-
987
- # Child's evaluator now accepts the pre-computed latents
988
- metrics = self._evaluate_model( # type: ignore
989
- X_val=X_val,
990
- model=model,
991
- params=params,
992
- objective_mode=objective_mode,
993
- latent_vectors_val=optimized_val_latents,
994
- eval_mask_override=eval_mask_override,
995
- )
1328
+ out = np.full((n_rows, n_cols), "N", dtype="<U1")
1329
+ mask = (codes >= 0) & (codes <= 9)
1330
+ out[mask] = iupac_list[codes[mask]]
1331
+ return out
1332
+
1333
+ # Metadata fetch
1334
+ ref_alleles = getattr(self.genotype_data, "ref", [])
1335
+ alt_alleles = getattr(self.genotype_data, "alt", [])
1336
+
1337
+ if len(ref_alleles) != n_cols:
1338
+ ref_alleles = getattr(self, "_ref", [None] * n_cols)
1339
+ if len(alt_alleles) != n_cols:
1340
+ alt_alleles = getattr(self, "_alt", [None] * n_cols)
1341
+
1342
+ # Ensure list length matches
1343
+ if len(ref_alleles) != n_cols:
1344
+ ref_alleles = [None] * n_cols
1345
+ if len(alt_alleles) != n_cols:
1346
+ alt_alleles = [None] * n_cols
1347
+
1348
+ out = np.full((n_rows, n_cols), "N", dtype="<U1")
1349
+ source_snp_data = None
1350
+
1351
+ for j in range(n_cols):
1352
+ ref = _normalize_iupac(ref_alleles[j])
1353
+ alt = _normalize_iupac(alt_alleles[j])
1354
+
1355
+ # --- REPAIR LOGIC ---
1356
+ # If metadata is missing, scan the source column.
1357
+ if ref is None or alt is None:
1358
+ if source_snp_data is None and self.genotype_data.snp_data is not None:
1359
+ try:
1360
+ source_snp_data = np.asarray(self.genotype_data.snp_data)
1361
+ except Exception:
1362
+ pass # if lazy loading fails
1363
+
1364
+ if source_snp_data is not None:
1365
+ try:
1366
+ col_data = source_snp_data[:, j]
1367
+ uniques = set()
1368
+ # Optimization: stop after finding two distinct codes or scanning ~200 non-missing values
1369
+ count = 0
1370
+ for val in col_data:
1371
+ norm = _normalize_iupac(val)
1372
+ if norm:
1373
+ uniques.add(norm)
1374
+ count += 1
1375
+ if len(uniques) >= 2 or count > 200:
1376
+ break
1377
+
1378
+ sorted_u = sorted(list(uniques))
1379
+ if len(sorted_u) >= 1 and ref is None:
1380
+ ref = sorted_u[0]
1381
+ if len(sorted_u) >= 2 and alt is None:
1382
+ alt = sorted_u[1]
1383
+ except Exception:
1384
+ pass
1385
+
1386
+ # --- DEFAULTS FOR MISSING ---
1387
+ # If still missing, we cannot decode.
1388
+ if ref is None and alt is None:
1389
+ ref = "N"
1390
+ alt = "N"
1391
+ elif ref is None:
1392
+ ref = alt
1393
+ elif alt is None:
1394
+ alt = ref # Monomorphic site: ALT becomes REF
1395
+
1396
+ # --- COMPUTE HET CODE ---
1397
+ if ref == alt:
1398
+ het_code = ref
1399
+ else:
1400
+ ref_set = iupac_to_bases.get(ref, set()) if ref is not None else set()
1401
+ alt_set = iupac_to_bases.get(alt, set()) if alt is not None else set()
1402
+ union_set = frozenset(ref_set | alt_set)
1403
+ het_code = bases_to_iupac.get(union_set, "N")
996
1404
 
997
- # Prefer the requested metric; fall back to self.tune_metric if needed.
998
- val = metrics.get(metric, metrics.get(getattr(self, "tune_metric", ""), None))
1405
+ # --- ASSIGNMENT WITH SAFETY FALLBACKS ---
1406
+ col_codes = codes[:, j]
999
1407
 
1000
- if val is None or not np.isfinite(val):
1001
- return -np.inf # make pruning decisions easy/robust on bad reads
1408
+ # Case 0: REF
1409
+ if ref != "N":
1410
+ out[col_codes == 0, j] = ref
1002
1411
 
1003
- return float(val)
1412
+ # Case 1: HET
1413
+ if het_code != "N":
1414
+ out[col_codes == 1, j] = het_code
1415
+ else:
1416
+ # If HET code is invalid (e.g. ref='A', alt='N'),
1417
+ # fallback to REF
1418
+ # Fix for an issue where a HET prediction at a monomorphic site
1419
+ # produced 'N'
1420
+ if ref != "N":
1421
+ out[col_codes == 1, j] = ref
1422
+
1423
+ # Case 2: ALT
1424
+ if alt != "N":
1425
+ out[col_codes == 2, j] = alt
1426
+ else:
1427
+ # If ALT is invalid (e.g. ref='A', alt='N'), fallback to REF
1428
+ # Fix for an issue where an ALT prediction on a monomorphic site
1429
+ # produced 'N'
1430
+ if ref != "N":
1431
+ out[col_codes == 2, j] = ref
1432
+
1433
+ return out
1434
+
1435
+ def _save_best_params(
1436
+ self, best_params: Dict[str, Any], objective_mode: bool = False
1437
+ ) -> None:
1438
+ """Save the best hyperparameters to a JSON file.
1439
+
1440
+ This method saves the best hyperparameters found during hyperparameter tuning to a JSON file: ``best_tuned_parameters.json`` in the optimization directory when ``objective_mode`` is True, or ``best_parameters.json`` in the parameters directory otherwise.
1441
+
1442
+ Args:
1443
+ best_params (Dict[str, Any]): A dictionary of the best hyperparameters to save.
+ objective_mode (bool): If True, write to the tuning (optimization) directory; otherwise write to the standard parameters directory. Defaults to False.
1444
+ """
1445
+ if not hasattr(self, "parameters_dir"):
1446
+ msg = "Attribute 'parameters_dir' not found. Ensure _create_model_directories() has been called."
1447
+ self.logger.error(msg)
1448
+ raise AttributeError(msg)
1449
+
1450
+ if objective_mode:
1451
+ fout = self.optimize_dir / "parameters" / "best_tuned_parameters.json"
1452
+ else:
1453
+ fout = self.parameters_dir / "best_parameters.json"
1454
+
1455
+ fout.parent.mkdir(parents=True, exist_ok=True)
1456
+
1457
+ with open(fout, "w") as f:
1458
+ json.dump(best_params, f, indent=4)
1004
1459
 
1005
- def _first_linear_in_features(self, model: torch.nn.Module) -> int:
1006
- """Return in_features of the model's first Linear layer.
1460
+ def _set_best_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
1461
+ """An abstract method for setting best parameters."""
1462
+ raise NotImplementedError
1007
1463
 
1008
- This method iterates through the modules of the provided PyTorch model to find the first instance of a Linear layer. It then retrieves and returns the `in_features` attribute of that layer, which indicates the number of input features expected by the layer.
1464
+ def sim_missing_transform(
1465
+ self, X: np.ndarray
1466
+ ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
1467
+ """Simulate missing data according to the specified strategy.
1009
1468
 
1010
1469
  Args:
1011
- model (torch.nn.Module): The model to inspect.
1470
+ X (np.ndarray): Genotype matrix to simulate missing data on.
1012
1471
 
1013
1472
  Returns:
1014
- int: The in_features of the first Linear layer.
1473
+ X_for_model (np.ndarray): Genotype matrix with simulated missing data.
1474
+ sim_mask (np.ndarray): Boolean mask of simulated missing entries.
1475
+ orig_mask (np.ndarray): Boolean mask of original missing entries.
1015
1476
  """
1016
- for m in model.modules():
1017
- if isinstance(m, torch.nn.Linear):
1018
- return int(m.in_features)
1019
- raise RuntimeError("No Linear layers found in model.")
1477
+ if (
1478
+ not hasattr(self, "sim_prop")
1479
+ or self.sim_prop <= 0.0
1480
+ or self.sim_prop >= 1.0
1481
+ ):
1482
+ msg = "sim_prop must be set and between 0.0 and 1.0."
1483
+ self.logger.error(msg)
1484
+ raise AttributeError(msg)
1020
1485
 
1021
- def _assert_model_latent_compat(
1022
- self, model: torch.nn.Module, latent_vectors: torch.nn.Parameter
1023
- ) -> None:
1024
- """Raise if model's first Linear doesn't match latent_vectors width.
1486
+ if not hasattr(self, "tree_parser") and "nonrandom" in self.sim_strategy:
1487
+ msg = "tree_parser must be set for 'nonrandom' or 'nonrandom_weighted' sim_strategy."
1488
+ self.logger.error(msg)
1489
+ raise AttributeError(msg)
1025
1490
 
1026
- This method checks that the dimensionality of the provided latent vectors matches the expected input feature size of the model's first linear layer. If there is a mismatch, it raises a ValueError with a descriptive message.
1491
+ # --- Simulate missing data ---
1492
+ X_for_sim = X.astype(np.float32, copy=True)
1493
+ tr = SimMissingTransformer(
1494
+ genotype_data=self.genotype_data,
1495
+ tree_parser=self.tree_parser,
1496
+ prop_missing=self.sim_prop,
1497
+ strategy=self.sim_strategy,
1498
+ missing_val=-1,
1499
+ mask_missing=True,
1500
+ verbose=self.verbose,
1501
+ seed=self.seed,
1502
+ **self.sim_kwargs,
1503
+ )
1504
+ tr.fit(X_for_sim.copy())
1505
+ X_for_model = tr.transform(X_for_sim.copy())
1506
+ sim_mask = tr.sim_missing_mask_.astype(bool)
1507
+ orig_mask = tr.original_missing_mask_.astype(bool)
1508
+
1509
+ return X_for_model, sim_mask, orig_mask
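+ # A minimal sketch of the mask semantics above, with hypothetical values:
+ # entries hidden by the transformer are flagged in sim_mask, while entries
+ # that were already missing in the input are tracked by orig_mask only.
+ # >>> import numpy as np
+ # >>> X = np.array([[0.0, 1.0, -1.0], [2.0, 0.0, 1.0]])
+ # >>> sim_mask = np.array([[False, True, False], [False, False, False]])
+ # >>> np.where(sim_mask, -1.0, X).tolist()
+ # [[0.0, -1.0, -1.0], [2.0, 0.0, 1.0]]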
1510
+
1511
+ def _train_val_test_split(
1512
+ self, X: np.ndarray
1513
+ ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
1514
+ """Split data into train, validation, and test sets.
1027
1515
 
1028
1516
  Args:
1029
- model (torch.nn.Module): The model to check.
1030
- latent_vectors (torch.nn.Parameter): The latent vectors to check.
1517
+ X (np.ndarray): Genotype matrix to split.
1518
+
1519
+ Returns:
1520
+ tuple[np.ndarray, np.ndarray, np.ndarray]: Indices for train, validation, and test sets.
1031
1521
 
1032
1522
  Raises:
1033
- ValueError: If the latent dimension does not match the model's expected input features.
1523
+ ValueError: If there are not enough samples for splitting.
1524
+ AssertionError: If validation_split is not in (0.0, 1.0).
1034
1525
  """
1035
- zdim = int(latent_vectors.shape[1])
1036
- first_in = self._first_linear_in_features(model)
1037
- if first_in != zdim:
1038
- raise ValueError(
1039
- f"Latent mismatch: zdim={zdim}, model first Linear expects in_features={first_in}"
1040
- )
1526
+ n_samples = X.shape[0]
1041
1527
 
1042
- def _prepare_tuning_artifacts(self) -> None:
1043
- """Prepare data and artifacts needed for hyperparameter tuning.
1528
+ if n_samples < 3:
1529
+ msg = f"Not enough samples ({n_samples}) for train/val/test split."
1530
+ self.logger.error(msg)
1531
+ raise ValueError(msg)
1044
1532
 
1045
- This method sets up the necessary data splits, data loaders, and class weights required for hyperparameter tuning. It creates training and validation sets from the ground truth data, initializes data loaders with a specified batch size, and computes class-balanced weights based on the training data. The method also handles optional subsampling of the dataset for faster tuning and prepares slices for evaluation if needed.
1533
+ assert (
1534
+ self.validation_split > 0.0 and self.validation_split < 1.0
1535
+ ), f"validation_split must be in (0.0, 1.0), but got {self.validation_split}."
1046
1536
 
1047
- Raises:
1048
- AttributeError: If the ground truth data (`ground_truth_`) is not set.
1537
+ # Train/Val split
1538
+ indices = np.arange(n_samples)
1539
+ train_idx, val_test_idx = train_test_split(
1540
+ indices,
1541
+ test_size=self.validation_split,
1542
+ random_state=self.seed,
1543
+ )
1544
+
1545
+ if val_test_idx.size < 4:
1546
+ msg = f"Not enough samples ({val_test_idx.size}) for validation/test split."
1547
+ self.logger.error(msg)
1548
+ raise ValueError(msg)
1549
+
1550
+ # Split val and test equally
1551
+ val_idx, test_idx = train_test_split(
1552
+ val_test_idx, test_size=0.5, random_state=self.seed
1553
+ )
1554
+
1555
+ return train_idx, val_idx, test_idx
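+ # A minimal sketch of the split arithmetic above with hypothetical sizes:
+ # 100 samples and validation_split=0.2 hold out ~20 samples, which are then
+ # divided evenly between validation and test.
+ # >>> n, split = 100, 0.2
+ # >>> held_out = int(n * split)
+ # >>> n - held_out, held_out // 2, held_out // 2
+ # (80, 10, 10)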
1556
+
1557
+ def _get_data_loaders(
1558
+ self,
1559
+ X: np.ndarray,
1560
+ y: np.ndarray,
1561
+ mask: np.ndarray,
1562
+ batch_size: int,
1563
+ *,
1564
+ shuffle: bool = True,
1565
+ ) -> torch.utils.data.DataLoader:
1566
+ """Create DataLoader for training and validation.
1567
+
1568
+ Args:
1569
+ X (np.ndarray): 0/1/2-encoded input matrix.
1570
+ y (np.ndarray): 0/1/2-encoded matrix with -1 for missing.
1571
+ mask (np.ndarray): Boolean mask of entries to score in the loss.
1572
+ batch_size (int): Batch size.
1573
+ shuffle (bool): Whether to shuffle batches. Defaults to True.
1574
+
1575
+ Returns:
1576
+ torch.utils.data.DataLoader: The constructed DataLoader.
1049
1577
  """
1050
- if getattr(self, "_tune_ready", False):
1051
- return
1578
+ dataset = _MaskedNumpyDataset(X, y, mask)
1052
1579
 
1053
- X = self.ground_truth_
1054
- n_samp, n_loci = X.shape
1055
- rng = self.rng
1580
+ return torch.utils.data.DataLoader(
1581
+ dataset,
1582
+ batch_size=batch_size,
1583
+ shuffle=shuffle,
1584
+ pin_memory=(str(self.device).startswith("cuda")),
1585
+ )
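+ # A minimal sketch of the construction above using a generic TensorDataset
+ # in place of the internal _MaskedNumpyDataset (which presumably yields
+ # (x, y, mask) triples):
+ # >>> import torch
+ # >>> ds = torch.utils.data.TensorDataset(torch.zeros(8, 4), torch.zeros(8, 4))
+ # >>> dl = torch.utils.data.DataLoader(ds, batch_size=4, shuffle=True)
+ # >>> len(dl)
+ # 2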
1056
1586
 
1057
- if self.tune_fast:
1058
- s = min(n_samp, self.tune_max_samples)
1059
- l = n_loci if self.tune_max_loci == 0 else min(n_loci, self.tune_max_loci)
1587
+ def _update_anneal_schedule(
1588
+ self,
1589
+ final: float,
1590
+ warm: int,
1591
+ ramp: int,
1592
+ epoch: int,
1593
+ *,
1594
+ init_val: float = 0.0,
1595
+ ) -> torch.Tensor:
1596
+ """Update annealed hyperparameter value based on epoch.
1060
1597
 
1061
- samp_idx = np.sort(rng.choice(n_samp, size=s, replace=False))
1062
- loci_idx = np.sort(rng.choice(n_loci, size=l, replace=False))
1063
- X_small = X[samp_idx][:, loci_idx]
1598
+ Args:
1599
+ final (float): Final value after annealing.
1600
+ warm (int): Number of warm-up epochs.
1601
+ ramp (int): Number of ramp-up epochs.
1602
+ epoch (int): Current epoch number.
1603
+ init_val (float): Initial value before annealing starts. Defaults to 0.0.
1604
+
1605
+ Returns:
1606
+ torch.Tensor: Current value of the hyperparameter.
1607
+ """
1608
+ if epoch < warm:
1609
+ val = torch.tensor(init_val)
1610
+ elif epoch < warm + ramp:
1611
+ val = torch.tensor(final * ((epoch - warm) / ramp))
1064
1612
  else:
1065
- X_small = X
1613
+ val = torch.tensor(final)
1066
1614
 
1067
- idx = np.arange(X_small.shape[0])
1068
- tr, te = train_test_split(
1069
- idx, test_size=self.validation_split, random_state=self.seed
1070
- )
1071
- self._tune_train_idx = tr
1072
- self._tune_test_idx = te
1073
- self._tune_X_train = X_small[tr]
1074
- self._tune_X_test = X_small[te]
1615
+ return val.to(self.device)
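+ # A minimal sketch of the schedule above with hypothetical settings
+ # (final=1.0, warm=2, ramp=4, init_val=0.0):
+ # >>> def sched(epoch, final=1.0, warm=2, ramp=4, init_val=0.0):
+ # ...     if epoch < warm:
+ # ...         return init_val
+ # ...     if epoch < warm + ramp:
+ # ...         return final * ((epoch - warm) / ramp)
+ # ...     return final
+ # >>> [sched(e) for e in range(8)]
+ # [0.0, 0.0, 0.0, 0.25, 0.5, 0.75, 1.0, 1.0]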
1075
1616
 
1076
- self._tune_class_weights = self._normalize_class_weights(
1077
- self._class_weights_from_zygosity(self._tune_X_train)
1078
- )
1617
+ def _anneal_config(
1618
+ self,
1619
+ params: Optional[dict],
1620
+ key: str,
1621
+ default: float,
1622
+ max_epochs: int,
1623
+ *,
1624
+ warm_alt: int = 50,
1625
+ ramp_alt: int = 100,
1626
+ ) -> Tuple[float, int, int]:
1627
+ """Configure annealing schedule for a hyperparameter.
1079
1628
 
1080
- # Temporarily bump batch size only for tuning loader
1081
- orig_bs = self.batch_size
1082
- self.batch_size = self.tune_batch_size
1083
- self._tune_loader = self._get_data_loaders(self._tune_X_train) # type: ignore
1084
- self.batch_size = orig_bs
1629
+ Args:
1630
+ params (Optional[dict]): Dictionary of parameters to extract from.
1631
+ key (str): Key to look for in params.
1632
+ default (float): Default final value if not specified in params.
1633
+ max_epochs (int): Total number of training epochs.
1634
+ warm_alt (int): Cap on the warm-up period when 10% of max_epochs would be longer. Defaults to 50.
1635
+ ramp_alt (int): Cap on the ramp-up period when 20% of max_epochs would be longer. Defaults to 100.
1085
1636
 
1086
- self._tune_num_features = self._tune_X_train.shape[1]
1087
- self._tune_val_latents_source = None
1088
- self._tune_train_latents_source = None
1637
+ Returns:
1638
+ Tuple[float, int, int]: Final value, warm-up epochs, ramp-up epochs.
1639
+ """
1640
+ val = None
1641
+ if params is not None and params:
1642
+ if not hasattr(self, key):
1643
+ msg = f"Attribute '{key}' not found for anneal_config."
1644
+ self.logger.error(msg)
1645
+ raise AttributeError(msg)
1089
1646
 
1090
- # Optional: for huge val sets, thin them for proxy metric
1091
- if (
1092
- self.tune_proxy_metric_batch
1093
- and self._tune_X_test.shape[0] > self.tune_proxy_metric_batch
1094
- ):
1095
- self._tune_eval_slice = np.arange(self.tune_proxy_metric_batch)
1647
+ val = params.get(key, getattr(self, key))
1648
+
1649
+ if val is not None and isinstance(val, (float, int)):
1650
+ final = float(val)
1096
1651
  else:
1097
- self._tune_eval_slice = None
1652
+ final = default
1098
1653
 
1099
- self._tune_ready = True
1654
+ warm, ramp = min(int(0.1 * max_epochs), warm_alt), min(
1655
+ int(0.2 * max_epochs), ramp_alt
1656
+ )
1657
+ return final, warm, ramp
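+ # A minimal sketch of the warm/ramp arithmetic above: 10% / 20% of
+ # max_epochs, capped at warm_alt=50 and ramp_alt=100 respectively.
+ # >>> [(min(int(0.1 * e), 50), min(int(0.2 * e), 100)) for e in (100, 1000)]
+ # [(10, 20), (50, 100)]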
1100
1658
 
1101
- def _save_best_params(self, best_params: Dict[str, Any]) -> None:
1102
- """Save the best hyperparameters to a JSON file.
1659
+ def _repair_ref_alt_from_iupac(self, loci: np.ndarray) -> None:
1660
+ """Repair REF/ALT for specific loci using observed IUPAC genotypes.
1103
1661
 
1104
- This method saves the best hyperparameters found during hyperparameter tuning to a JSON file in the optimization directory. The filename includes the model name for easy identification.
1662
+ Args:
1663
+ loci (np.ndarray): Array of locus indices to repair.
1664
+
1665
+ Notes:
1666
+ - Modifies self.genotype_data.ref and self.genotype_data.alt in place.
1667
+ """
1668
+ iupac_to_bases = {
1669
+ "A": {"A"},
1670
+ "C": {"C"},
1671
+ "G": {"G"},
1672
+ "T": {"T"},
1673
+ "R": {"A", "G"},
1674
+ "Y": {"C", "T"},
1675
+ "S": {"G", "C"},
1676
+ "W": {"A", "T"},
1677
+ "K": {"G", "T"},
1678
+ "M": {"A", "C"},
1679
+ "B": {"C", "G", "T"},
1680
+ "D": {"A", "G", "T"},
1681
+ "H": {"A", "C", "T"},
1682
+ "V": {"A", "C", "G"},
1683
+ }
1684
+ missing_codes = {"", ".", "N", "NONE", "-", "?", "./.", ".|."}
1685
+
1686
+ def norm(v: object) -> str | None:
1687
+ if v is None:
1688
+ return None
1689
+ s = str(v).upper().strip()
1690
+ if not s or s in missing_codes:
1691
+ return None
1692
+ return s if s in iupac_to_bases else None
1693
+
1694
+ snp = np.asarray(self.genotype_data.snp_data, dtype=object) # (N,L) IUPAC-ish
1695
+ refs = list(getattr(self.genotype_data, "ref", [None] * snp.shape[1]))
1696
+ alts = list(getattr(self.genotype_data, "alt", [None] * snp.shape[1]))
1697
+
1698
+ for j in loci:
1699
+ cnt = Counter()
1700
+ col = snp[:, int(j)]
1701
+ for g in col:
1702
+ code = norm(g)
1703
+ if code is None:
1704
+ continue
1705
+ for b in iupac_to_bases[code]:
1706
+ cnt[b] += 1
1707
+
1708
+ if not cnt:
1709
+ continue
1710
+
1711
+ common = [b for b, _ in cnt.most_common()]
1712
+ ref = common[0]
1713
+ alt = common[1] if len(common) > 1 else None
1714
+
1715
+ refs[int(j)] = ref
1716
+ alts[int(j)] = alt if alt is not None else "."
1717
+
1718
+ self.genotype_data.ref = np.asarray(refs, dtype=object)
1719
+
1720
+ if not isinstance(alts, np.ndarray):
1721
+ alts = np.array(alts, dtype=object).tolist()
1722
+
1723
+ self.genotype_data.alt = alts
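+ # A minimal sketch of the repair above with a hypothetical column of IUPAC
+ # calls: ambiguity codes contribute to every base they cover, and the two
+ # most frequent bases become REF and ALT.
+ # >>> from collections import Counter
+ # >>> bases = {"A": {"A"}, "G": {"G"}, "R": {"A", "G"}}
+ # >>> cnt = Counter()
+ # >>> for g in ["A", "R", "A", "G"]:
+ # ...     for b in bases[g]:
+ # ...         cnt[b] += 1
+ # >>> [b for b, _ in cnt.most_common()]
+ # ['A', 'G']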
1724
+
1725
+ def _aligned_ref_alt(self, L: int) -> tuple[list[object], list[object]]:
1726
+ """Return REF/ALT aligned to the genotype matrix columns.
1105
1727
 
1106
1728
  Args:
1107
- best_params (Dict[str, Any]): A dictionary of the best hyperparameters to save.
1729
+ L (int): Number of loci (columns in genotype matrix).
1730
+
1731
+ Returns:
1732
+ tuple[list[object], list[object]]: Aligned REF and ALT lists.
1108
1733
  """
1109
- if not hasattr(self, "parameters_dir"):
1110
- msg = "Attribute 'parameters_dir' not found. Ensure _create_model_directories() has been called."
1734
+ refs = getattr(self.genotype_data, "ref", None)
1735
+ alts = getattr(self.genotype_data, "alt", None)
1736
+
1737
+ if refs is None or alts is None:
1738
+ msg = "genotype_data.ref/alt are required but missing."
1111
1739
  self.logger.error(msg)
1112
- raise AttributeError(msg)
1740
+ raise ValueError(msg)
1113
1741
 
1114
- fout = self.parameters_dir / "best_parameters.json"
1742
+ refs_arr = np.asarray(refs, dtype=object)
1743
+ alts_arr = np.asarray(alts, dtype=object)
1115
1744
 
1116
- with open(fout, "w") as f:
1117
- json.dump(best_params, f, indent=4)
1745
+ if refs_arr.shape[0] != L or alts_arr.shape[0] != L:
1746
+ msg = f"REF/ALT length mismatch vs matrix columns: L={L}, len(ref)={refs_arr.shape[0]}, len(alt)={alts_arr.shape[0]}. You are using REF/ALT metadata that is not aligned to pgenc.genotypes_012 columns. Fix by subsetting/refiltering ref/alt with the same locus mask used for the genotype matrix."
1747
+ self.logger.error(msg)
1748
+ raise ValueError(msg)
1118
1749
 
1119
- def _set_best_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
1120
- """An abstract method for setting best parameters."""
1121
- raise NotImplementedError
1750
+ # Unwrap singleton ALT arrays like array(['T'], dtype=object)
1751
+ def unwrap(x: object) -> object:
1752
+ if isinstance(x, np.ndarray):
1753
+ if x.size == 0:
1754
+ return None
1755
+ if x.size == 1:
1756
+ return x.item()
1757
+ return x
1758
+
1759
+ refs_list = [unwrap(x) for x in refs_arr.tolist()]
1760
+ alts_list = [unwrap(x) for x in alts_arr.tolist()]
1761
+ return refs_list, alts_list
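+ # A minimal sketch of the unwrap step above: singleton object arrays such
+ # as array(['T'], dtype=object) collapse to their scalar, empty arrays map
+ # to None, and plain strings pass through unchanged.
+ # >>> import numpy as np
+ # >>> np.array(["T"], dtype=object).item(), np.array([], dtype=object).size
+ # ('T', 0)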
1762
+
1763
+ def _build_valid_class_mask(self) -> torch.Tensor:
+ """Build a boolean (L, K) mask of valid genotype classes per locus, from REF/ALT metadata plus classes observed in ``y_train_``."""
1764
+ L = self.num_features_
1765
+ K = self.num_classes_
1766
+ mask = np.ones((L, K), dtype=bool)
1767
+
1768
+ # --- IUPAC helpers (single-character only) ---
1769
+ iupac_to_bases: dict[str, set[str]] = {
1770
+ "A": {"A"},
1771
+ "C": {"C"},
1772
+ "G": {"G"},
1773
+ "T": {"T"},
1774
+ "R": {"A", "G"},
1775
+ "Y": {"C", "T"},
1776
+ "S": {"G", "C"},
1777
+ "W": {"A", "T"},
1778
+ "K": {"G", "T"},
1779
+ "M": {"A", "C"},
1780
+ "B": {"C", "G", "T"},
1781
+ "D": {"A", "G", "T"},
1782
+ "H": {"A", "C", "T"},
1783
+ "V": {"A", "C", "G"},
1784
+ }
1785
+ missing_codes = {"", ".", "N", "NONE", "-", "?", "./.", ".|."}
1786
+
1787
+ # get aligned ref/alt (should be exactly length L)
1788
+ refs, alts = self._aligned_ref_alt(L)
1789
+
1790
+ def _normalize_iupac(value: object) -> str | None:
1791
+ """Return a single-letter IUPAC code or None if missing/invalid."""
1792
+ if value is None:
1793
+ return None
1794
+ if isinstance(value, (bytes, np.bytes_)):
1795
+ value = value.decode("utf-8", errors="ignore")
1796
+
1797
+ # allow list/tuple/array containers (take first valid)
1798
+ if isinstance(value, (list, tuple, np.ndarray, pd.Series)):
1799
+ for item in value:
1800
+ code = _normalize_iupac(item)
1801
+ if code is not None:
1802
+ return code
1803
+ return None
1804
+
1805
+ s = str(value).upper().strip()
1806
+ if not s or s in missing_codes:
1807
+ return None
1808
+
1809
+ # handle comma-separated values
1810
+ if "," in s:
1811
+ for tok in (t.strip() for t in s.split(",")):
1812
+ if tok and tok not in missing_codes and tok in iupac_to_bases:
1813
+ return tok
1814
+ return None
1815
+
1816
+ return s if s in iupac_to_bases else None
1817
+
1818
+ # 1) metadata restriction
1819
+ for j in range(L):
1820
+ ref = _normalize_iupac(refs[j])
1821
+ alt = _normalize_iupac(alts[j])
1822
+
1823
+ if alt is None or (ref is not None and alt == ref):
1824
+ mask[j, :] = False
1825
+ mask[j, 0] = True
1826
+
1827
+ # 2) data-driven override
1828
+ y_train = getattr(self, "y_train_", None)
1829
+ if y_train is not None:
1830
+ y = np.asarray(y_train)
1831
+ if y.ndim == 2 and y.shape[1] == L:
1832
+ if K == 2:
1833
+ y = y.copy()
1834
+ y[y == 2] = 1
1835
+ valid = y >= 0
1836
+ if valid.any():
1837
+ observed = np.zeros((L, K), dtype=bool)
1838
+ for c in range(K):
1839
+ observed[:, c] = np.any(valid & (y == c), axis=0)
1840
+
1841
+ conflict = observed & (~mask)
1842
+ if conflict.any():
1843
+ loci = np.where(conflict.any(axis=1))[0]
1844
+ self.valid_class_mask_conflict_loci_ = loci
1845
+ self.logger.warning(
1846
+ f"valid_class_mask_ metadata forbids observed classes at {loci.size} loci. "
1847
+ "Expanding mask to include observed classes."
1848
+ )
1849
+ mask |= observed
1850
+
1851
+ bad = np.where(~mask.any(axis=1))[0]
1852
+ if bad.size:
1853
+ mask[bad, :] = False
1854
+ mask[bad, 0] = True
1855
+
1856
+ return torch.as_tensor(mask, dtype=torch.bool, device=self.device)
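+ # A minimal sketch of the resulting mask for two hypothetical loci with
+ # K = 3: a biallelic locus keeps all classes, while a locus whose ALT is
+ # missing (or equal to REF) is restricted to class 0; classes observed in
+ # y_train_ are added back if the metadata would otherwise forbid them.
+ # >>> import numpy as np
+ # >>> mask = np.ones((2, 3), dtype=bool)
+ # >>> mask[1, :] = False; mask[1, 0] = True
+ # >>> mask.tolist()
+ # [[True, True, True], [True, False, False]]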