pg-sui 0.2.3__py3-none-any.whl → 1.6.14.dev9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (127) hide show
  1. {pg_sui-0.2.3.dist-info → pg_sui-1.6.14.dev9.dist-info}/METADATA +99 -77
  2. pg_sui-1.6.14.dev9.dist-info/RECORD +81 -0
  3. {pg_sui-0.2.3.dist-info → pg_sui-1.6.14.dev9.dist-info}/WHEEL +1 -1
  4. pg_sui-1.6.14.dev9.dist-info/entry_points.txt +4 -0
  5. {pg_sui-0.2.3.dist-info → pg_sui-1.6.14.dev9.dist-info/licenses}/LICENSE +0 -0
  6. pg_sui-1.6.14.dev9.dist-info/top_level.txt +1 -0
  7. pgsui/__init__.py +35 -54
  8. pgsui/_version.py +34 -0
  9. pgsui/cli.py +909 -0
  10. pgsui/data_processing/__init__.py +0 -0
  11. pgsui/data_processing/config.py +565 -0
  12. pgsui/data_processing/containers.py +1424 -0
  13. pgsui/data_processing/transformers.py +557 -907
  14. pgsui/{example_data/trees → electron/app}/__init__.py +0 -0
  15. pgsui/electron/app/__main__.py +5 -0
  16. pgsui/electron/app/extra-resources/.gitkeep +1 -0
  17. pgsui/electron/app/icons/icons/1024x1024.png +0 -0
  18. pgsui/electron/app/icons/icons/128x128.png +0 -0
  19. pgsui/electron/app/icons/icons/16x16.png +0 -0
  20. pgsui/electron/app/icons/icons/24x24.png +0 -0
  21. pgsui/electron/app/icons/icons/256x256.png +0 -0
  22. pgsui/electron/app/icons/icons/32x32.png +0 -0
  23. pgsui/electron/app/icons/icons/48x48.png +0 -0
  24. pgsui/electron/app/icons/icons/512x512.png +0 -0
  25. pgsui/electron/app/icons/icons/64x64.png +0 -0
  26. pgsui/electron/app/icons/icons/icon.icns +0 -0
  27. pgsui/electron/app/icons/icons/icon.ico +0 -0
  28. pgsui/electron/app/main.js +227 -0
  29. pgsui/electron/app/package-lock.json +6894 -0
  30. pgsui/electron/app/package.json +51 -0
  31. pgsui/electron/app/preload.js +15 -0
  32. pgsui/electron/app/server.py +157 -0
  33. pgsui/electron/app/ui/logo.png +0 -0
  34. pgsui/electron/app/ui/renderer.js +131 -0
  35. pgsui/electron/app/ui/styles.css +59 -0
  36. pgsui/electron/app/ui/ui_shim.js +72 -0
  37. pgsui/electron/bootstrap.py +43 -0
  38. pgsui/electron/launch.py +57 -0
  39. pgsui/electron/package.json +14 -0
  40. pgsui/example_data/__init__.py +0 -0
  41. pgsui/example_data/phylip_files/__init__.py +0 -0
  42. pgsui/example_data/phylip_files/test.phy +0 -0
  43. pgsui/example_data/popmaps/__init__.py +0 -0
  44. pgsui/example_data/popmaps/{test.popmap → phylogen_nomx.popmap} +185 -99
  45. pgsui/example_data/structure_files/__init__.py +0 -0
  46. pgsui/example_data/structure_files/test.pops.2row.allsites.str +0 -0
  47. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz +0 -0
  48. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz.tbi +0 -0
  49. pgsui/impute/__init__.py +0 -0
  50. pgsui/impute/deterministic/imputers/allele_freq.py +725 -0
  51. pgsui/impute/deterministic/imputers/mode.py +844 -0
  52. pgsui/impute/deterministic/imputers/nmf.py +221 -0
  53. pgsui/impute/deterministic/imputers/phylo.py +973 -0
  54. pgsui/impute/deterministic/imputers/ref_allele.py +669 -0
  55. pgsui/impute/supervised/__init__.py +0 -0
  56. pgsui/impute/supervised/base.py +343 -0
  57. pgsui/impute/{unsupervised/models/in_development → supervised/imputers}/__init__.py +0 -0
  58. pgsui/impute/supervised/imputers/hist_gradient_boosting.py +317 -0
  59. pgsui/impute/supervised/imputers/random_forest.py +291 -0
  60. pgsui/impute/unsupervised/__init__.py +0 -0
  61. pgsui/impute/unsupervised/base.py +1118 -0
  62. pgsui/impute/unsupervised/callbacks.py +92 -262
  63. {simulation → pgsui/impute/unsupervised/imputers}/__init__.py +0 -0
  64. pgsui/impute/unsupervised/imputers/autoencoder.py +1285 -0
  65. pgsui/impute/unsupervised/imputers/nlpca.py +1554 -0
  66. pgsui/impute/unsupervised/imputers/ubp.py +1575 -0
  67. pgsui/impute/unsupervised/imputers/vae.py +1228 -0
  68. pgsui/impute/unsupervised/loss_functions.py +261 -0
  69. pgsui/impute/unsupervised/models/__init__.py +0 -0
  70. pgsui/impute/unsupervised/models/autoencoder_model.py +215 -567
  71. pgsui/impute/unsupervised/models/nlpca_model.py +155 -394
  72. pgsui/impute/unsupervised/models/ubp_model.py +180 -1106
  73. pgsui/impute/unsupervised/models/vae_model.py +269 -630
  74. pgsui/impute/unsupervised/nn_scorers.py +255 -0
  75. pgsui/utils/__init__.py +0 -0
  76. pgsui/utils/classification_viz.py +608 -0
  77. pgsui/utils/logging_utils.py +22 -0
  78. pgsui/utils/misc.py +35 -480
  79. pgsui/utils/plotting.py +996 -829
  80. pgsui/utils/pretty_metrics.py +290 -0
  81. pgsui/utils/scorers.py +213 -666
  82. pg_sui-0.2.3.dist-info/RECORD +0 -75
  83. pg_sui-0.2.3.dist-info/top_level.txt +0 -3
  84. pgsui/example_data/phylip_files/test_n10.phy +0 -118
  85. pgsui/example_data/phylip_files/test_n100.phy +0 -118
  86. pgsui/example_data/phylip_files/test_n2.phy +0 -118
  87. pgsui/example_data/phylip_files/test_n500.phy +0 -118
  88. pgsui/example_data/structure_files/test.nopops.1row.10sites.str +0 -117
  89. pgsui/example_data/structure_files/test.nopops.2row.100sites.str +0 -234
  90. pgsui/example_data/structure_files/test.nopops.2row.10sites.str +0 -234
  91. pgsui/example_data/structure_files/test.nopops.2row.30sites.str +0 -234
  92. pgsui/example_data/structure_files/test.nopops.2row.allsites.str +0 -234
  93. pgsui/example_data/structure_files/test.pops.1row.10sites.str +0 -117
  94. pgsui/example_data/structure_files/test.pops.2row.10sites.str +0 -234
  95. pgsui/example_data/trees/test.iqtree +0 -376
  96. pgsui/example_data/trees/test.qmat +0 -5
  97. pgsui/example_data/trees/test.rate +0 -2033
  98. pgsui/example_data/trees/test.tre +0 -1
  99. pgsui/example_data/trees/test_n10.rate +0 -19
  100. pgsui/example_data/trees/test_n100.rate +0 -109
  101. pgsui/example_data/trees/test_n500.rate +0 -509
  102. pgsui/example_data/trees/test_siterates.txt +0 -2024
  103. pgsui/example_data/trees/test_siterates_n10.txt +0 -10
  104. pgsui/example_data/trees/test_siterates_n100.txt +0 -100
  105. pgsui/example_data/trees/test_siterates_n500.txt +0 -500
  106. pgsui/example_data/vcf_files/test.vcf +0 -244
  107. pgsui/example_data/vcf_files/test.vcf.gz +0 -0
  108. pgsui/example_data/vcf_files/test.vcf.gz.tbi +0 -0
  109. pgsui/impute/estimators.py +0 -1268
  110. pgsui/impute/impute.py +0 -1463
  111. pgsui/impute/simple_imputers.py +0 -1431
  112. pgsui/impute/supervised/iterative_imputer_fixedparams.py +0 -782
  113. pgsui/impute/supervised/iterative_imputer_gridsearch.py +0 -1024
  114. pgsui/impute/unsupervised/keras_classifiers.py +0 -697
  115. pgsui/impute/unsupervised/models/in_development/cnn_model.py +0 -486
  116. pgsui/impute/unsupervised/neural_network_imputers.py +0 -1440
  117. pgsui/impute/unsupervised/neural_network_methods.py +0 -1395
  118. pgsui/pg_sui.py +0 -261
  119. pgsui/utils/sequence_tools.py +0 -407
  120. simulation/sim_benchmarks.py +0 -333
  121. simulation/sim_treeparams.py +0 -475
  122. test/__init__.py +0 -0
  123. test/pg_sui_simtest.py +0 -215
  124. test/pg_sui_testing.py +0 -523
  125. test/test.py +0 -151
  126. test/test_pgsui.py +0 -374
  127. test/test_tkc.py +0 -185
@@ -0,0 +1,1118 @@
1
+ import copy
2
+ import gc
3
+ import json
4
+ import logging
5
+ from pathlib import Path
6
+ from typing import TYPE_CHECKING, Any, Dict, List, Literal, Tuple
7
+
8
+ import matplotlib.pyplot as plt
9
+ import numpy as np
10
+ import optuna
11
+ import pandas as pd
12
+ import plotly.graph_objects as go
13
+ import torch
14
+ import torch.nn.functional as F
15
+ from matplotlib.figure import Figure
16
+ from sklearn.metrics import classification_report
17
+ from sklearn.model_selection import train_test_split
18
+ from snpio import SNPioMultiQC
19
+ from snpio.utils.logging import LoggerManager
20
+
21
+ from pgsui.impute.unsupervised.nn_scorers import Scorer
22
+ from pgsui.utils.classification_viz import ClassificationReportVisualizer
23
+ from pgsui.utils.logging_utils import configure_logger
24
+ from pgsui.utils.plotting import Plotting
25
+ from pgsui.utils.pretty_metrics import PrettyMetrics
26
+
27
+ if TYPE_CHECKING:
28
+ from snpio.read_input.genotype_data import GenotypeData
29
+
30
+ from pgsui.impute.unsupervised.models.autoencoder_model import AutoencoderModel
31
+ from pgsui.impute.unsupervised.models.nlpca_model import NLPCAModel
32
+ from pgsui.impute.unsupervised.models.ubp_model import UBPModel
33
+ from pgsui.impute.unsupervised.models.vae_model import VAEModel
34
+
35
+
36
+ class BaseNNImputer:
37
+ """An abstract base class for neural network-based imputers.
38
+
39
+ This class provides a shared framework and common functionality for all neural network imputers. It is not meant to be instantiated directly. Instead, child classes should inherit from it and implement the abstract methods. Provided functionality: Directory setup and logging initialization; A hyperparameter tuning pipeline using Optuna; Utility methods for building models (`build_model`), initializing weights (`initialize_weights`), and checking for fitted attributes (`ensure_attribute`); Helper methods for calculating class weights for imbalanced data; Setup for standardized plotting and model scoring classes.
40
+ """
41
+
42
def __init__(
    self,
    model_name: str,
    genotype_data: "GenotypeData",
    prefix: str,
    *,
    device: Literal["gpu", "cpu", "mps"] = "cpu",
    verbose: bool = False,
    debug: bool = False,
):
    """Initializes the base class for neural network imputers.

    Creates the output directories, configures the logger, selects the PyTorch device, and assigns default values for the tuning, plotting, and simulation attributes that child classes or ``fit`` are expected to override.

    Args:
        model_name (str): Name of the concrete model. Used in output sub-directory names and log/plot titles.
        genotype_data (GenotypeData): The genotype data object stored on the instance as ``self.genotype_data``.
        prefix (str): A prefix used to name the output directory (``{prefix}_output``).
        device (Literal["gpu", "cpu", "mps"]): The device to use for PyTorch operations. If 'gpu' or 'mps' is chosen, it will fall back to 'cpu' if the required hardware is not available. Defaults to "cpu".
        verbose (bool): If True, enables detailed logging output. Defaults to False.
        debug (bool): If True, enables debug mode. Defaults to False.
    """
    self.model_name = model_name
    self.genotype_data = genotype_data

    self.prefix = prefix
    self.verbose = verbose
    self.debug = debug

    # Quiet Matplotlib/fontTools INFO logging when saving PDF/SVG
    for name in (
        "fontTools",
        "fontTools.subset",
        "fontTools.ttLib",
        "matplotlib.font_manager",
    ):
        lg = logging.getLogger(name)
        lg.setLevel(logging.WARNING)
        lg.propagate = False

    # Prepare directory structure
    outdirs = ["models", "plots", "metrics", "optimize", "parameters"]
    self._create_model_directories(prefix, outdirs)

    # Initialize loggers
    kwargs = {"prefix": prefix, "verbose": verbose, "debug": debug}
    logman = LoggerManager(__name__, **kwargs)
    self.logger = configure_logger(
        logman.get_logger(), verbose=self.verbose, debug=self.debug
    )

    # Device selection logs through ``self.logger``, so it must run only
    # after the logger above has been created.
    self.device = self._select_device(device)

    self._float_genotype_cache: np.ndarray | None = None
    self._sim_mask_cache: dict[tuple, np.ndarray] = {}

    # To be initialized by child classes or fit method
    self.tune_save_db: bool = False
    self.tune_resume: bool = False
    self.n_trials: int = 100
    self.model_params: Dict[str, Any] = {}
    self.tune_metric: str = "val_f1_macro"
    self.learning_rate: float = 1e-3
    self.plotter_: "Plotting"
    self.num_features_: int = 0
    self.num_classes_: int = 3
    self.plot_format: Literal["pdf", "png", "jpg", "jpeg", "svg"] = "pdf"
    self.plot_fontsize: int = 10
    self.plot_dpi: int = 300
    self.title_fontsize: int = 12
    self.despine: bool = True
    self.show_plots: bool = False
    self.scoring_averaging: Literal["macro", "micro", "weighted"] = "macro"
    self.pgenc: Any = None
    self.is_haploid: bool = False
    self.ploidy: int = 2
    self.beta: float = 0.9999
    self.max_ratio: float = 5.0
    self.sim_strategy: str = "mcar"
    self.sim_prop: float = 0.1
    self.seed: int | None = 42
    self.rng: np.random.Generator = np.random.default_rng(self.seed)
    self.ground_truth_: np.ndarray
    self.tune_fast: bool = False
    self.tune_max_samples: int = 1000
    self.tune_max_loci: int = 500
    self.validation_split: float = 0.2
    self.tune_batch_size: int = 64
    self.tune_proxy_metric_batch: int = 512
    self.batch_size: int = 64
    self.best_params_: Dict[str, Any] = {}

    # Declared here for type checkers; the ``*_dir`` attributes are
    # assigned dynamically by ``_create_model_directories``.
    self.optimize_dir: Path
    self.models_dir: Path
    self.plots_dir: Path
    self.metrics_dir: Path
    self.parameters_dir: Path
    self.study_db: Path
137
+
138
def tune_hyperparameters(self) -> None:
    """Tunes model hyperparameters using an Optuna study.

    Creates an Optuna study that maximizes the value returned by ``self._objective``, runs ``self.n_trials`` trials, stores the best parameters on the instance, renders and saves them, and plots the study.

    The required child-class hooks are validated before any work is done, so a missing hook is reported immediately instead of after the whole search has run.

    Raises:
        NotImplementedError: If the `_objective` or `_set_best_params` methods are not available on the instance.
    """
    # Fail fast: both hooks are needed, and `_set_best_params` is only
    # used once the (potentially very long) search has finished.
    required_hooks = (
        ("_objective", "`_objective()` must be implemented in the child class."),
        (
            "_set_best_params",
            "Method `_set_best_params()` must be implemented in the child class.",
        ),
    )
    for hook_name, msg in required_hooks:
        if not hasattr(self, hook_name):
            self.logger.error(msg)
            raise NotImplementedError(msg)

    self.logger.info("Tuning hyperparameters. This might take a while...")

    if self.verbose or self.debug:
        optuna.logging.set_verbosity(optuna.logging.INFO)
    else:
        optuna.logging.set_verbosity(optuna.logging.WARNING)

    study_db = None
    load_if_exists = False
    if self.tune_save_db:
        study_db = self.optimize_dir / "study_database" / "optuna_study.db"
        study_db.parent.mkdir(parents=True, exist_ok=True)

        if study_db.exists():
            if self.tune_resume:
                # Continue the stored study instead of starting over.
                load_if_exists = True
            else:
                # A stale database would clash with the fresh study name.
                study_db.unlink()

    study_name = f"{self.prefix} {self.model_name} Model Optimization"
    storage = f"sqlite:///{study_db}" if self.tune_save_db else None

    study = optuna.create_study(
        direction="maximize",
        study_name=study_name,
        storage=storage,
        load_if_exists=load_if_exists,
        pruner=optuna.pruners.MedianPruner(n_startup_trials=5, n_warmup_steps=10),
    )

    # ``n_jobs`` is optional on the instance; -1 is accepted, 0 and
    # anything below -1 are replaced by a single job.
    self.n_jobs = getattr(self, "n_jobs", 1)
    if self.n_jobs < -1 or self.n_jobs == 0:
        self.logger.warning(f"Invalid n_jobs={self.n_jobs}. Setting n_jobs=1.")
        self.n_jobs = 1

    # The progress bar is only shown for quiet, single-job runs.
    show_progress_bar = not self.verbose and not self.debug and self.n_jobs == 1

    study.optimize(
        self._objective,
        n_trials=self.n_trials,
        n_jobs=self.n_jobs,
        gc_after_trial=True,
        show_progress_bar=show_progress_bar,
    )

    best_metric = study.best_value
    best_params = study.best_params

    self.best_params_ = self._set_best_params(best_params)
    self.model_params.update(self.best_params_)
    self.logger.info(f"Best {self.tune_metric} metric: {best_metric}")
    self.logger.info("Best parameters:")

    # Display copy only: add the learning rate without mutating the
    # dictionary returned by the study.
    best_params_tmp = copy.deepcopy(best_params)
    best_params_tmp["learning_rate"] = self.learning_rate

    title = f"{self.model_name} Optimized Parameters"
    pm = PrettyMetrics(best_params_tmp, precision=6, title=title)
    pm.render()

    # Save best parameters to a JSON file.
    self._save_best_params(best_params)

    tn = f"{self.tune_metric} Value"
    self.plotter_.plot_tuning(
        study, self.model_name, self.optimize_dir / "plots", target_name=tn
    )
224
+
225
+ @staticmethod
226
+ def initialize_weights(module: torch.nn.Module) -> None:
227
+ """Initializes model weights using the Kaiming Uniform distribution.
228
+
229
+ This static method is intended to be applied to a PyTorch model to initialize the weights of its linear and convolutional layers. This initialization scheme is particularly effective for networks that use ReLU-family activation functions, as it helps maintain stable activation variances during training.
230
+
231
+ Args:
232
+ module (torch.nn.Module): The PyTorch module (e.g., a layer) to initialize.
233
+ """
234
+ if isinstance(
235
+ module, (torch.nn.Linear, torch.nn.Conv1d, torch.nn.ConvTranspose1d)
236
+ ):
237
+ # Use Kaiming Uniform initialization for Linear and Conv layers
238
+ torch.nn.init.kaiming_uniform_(module.weight, nonlinearity="relu")
239
+ if module.bias is not None:
240
+ torch.nn.init.zeros_(module.bias)
241
+
242
+ def build_model(
243
+ self,
244
+ Model: (
245
+ torch.nn.Module
246
+ | type["AutoencoderModel"]
247
+ | type["NLPCAModel"]
248
+ | type["UBPModel"]
249
+ | type["VAEModel"]
250
+ ),
251
+ model_params: Dict[str, int | float | str | bool],
252
+ ) -> torch.nn.Module:
253
+ """Builds and initializes a neural network model instance.
254
+
255
+ This method instantiates a model by combining fixed, data-dependent parameters (like `n_features`) with variable hyperparameters (like `latent_dim`). The resulting model is then moved to the appropriate compute device.
256
+
257
+ Args:
258
+ Model (torch.nn.Module): The model class to be instantiated.
259
+ model_params (Dict[str, Any]): A dictionary of variable model hyperparameters, typically sampled during a hyperparameter search.
260
+
261
+ Returns:
262
+ torch.nn.Module: The constructed model instance, ready for training.
263
+
264
+ Raises:
265
+ TypeError: If `model_params` is not a dictionary.
266
+ AttributeError: If a required data-dependent attribute like `num_features_` has not been set, typically by calling `fit` first.
267
+ """
268
+ if not isinstance(model_params, dict):
269
+ msg = f"'model_params' must be a dictionary, but got {type(model_params)}."
270
+ self.logger.error(msg)
271
+ raise TypeError(msg)
272
+
273
+ if not hasattr(self, "num_features_"):
274
+ msg = (
275
+ "Attribute 'num_features_' is not set. Call fit() before build_model()."
276
+ )
277
+ self.logger.error(msg)
278
+ raise AttributeError(msg)
279
+
280
+ # Start with a base set of fixed (non-tuned) parameters.
281
+ all_params = {
282
+ "n_features": self.num_features_,
283
+ "prefix": self.prefix,
284
+ "num_classes": self.num_classes_,
285
+ "verbose": self.verbose,
286
+ "debug": self.debug,
287
+ "device": self.device,
288
+ }
289
+
290
+ # Update with the variable hyperparameters from the provided dictionary
291
+ all_params.update(model_params)
292
+
293
+ return Model(**all_params).to(self.device)
294
+
295
def initialize_plotting_and_scorers(self) -> Tuple[Plotting, Scorer]:
    """Builds the plotting and scoring helpers from the instance settings.

    Both helpers are configured from attributes already stored on the instance (model name, prefix, plot options, averaging mode, verbosity). The plotter is always created with MultiQC output enabled.

    Returns:
        Tuple[Plotting, Scorer]: The configured ``Plotting`` object followed by the configured ``Scorer`` object.
    """
    section_title = f"PG-SUI: {self.model_name} Model Imputation"

    plotting_helper = Plotting(
        model_name=self.model_name,
        prefix=self.prefix,
        plot_format=self.plot_format,
        plot_fontsize=self.plot_fontsize,
        plot_dpi=self.plot_dpi,
        title_fontsize=self.title_fontsize,
        despine=self.despine,
        show_plots=self.show_plots,
        verbose=self.verbose,
        debug=self.debug,
        multiqc=True,
        multiqc_section=section_title,
    )

    scoring_helper = Scorer(
        prefix=self.prefix,
        average=self.scoring_averaging,
        verbose=self.verbose,
        debug=self.debug,
    )

    return plotting_helper, scoring_helper
330
+
331
def _objective(self, trial: optuna.Trial) -> float:
    """Placeholder for the per-trial Optuna objective.

    Child classes override this with the logic for one tuning trial. The base implementation only logs an error and raises.

    Args:
        trial (optuna.Trial): The Optuna trial object, used to sample hyperparameters.

    Returns:
        float: The value of the metric to be optimized (in overriding implementations).

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    not_implemented = "Method `_objective()` must be implemented in the child class."
    self.logger.error(not_implemented)
    raise NotImplementedError(not_implemented)
345
+
346
+ def fit(self, X: np.ndarray | pd.DataFrame | list | None = None) -> "BaseNNImputer":
347
+ """Fits the imputer model to the data.
348
+
349
+ This abstract method must be implemented by the child class. It should contain the logic for training the neural network model on the provided input data `X`.
350
+
351
+ Args:
352
+ X (np.ndarray | pd.DataFrame | list | None): The input data, which may contain missing values.
353
+
354
+ Returns:
355
+ BaseNNImputer: The fitted imputer instance.
356
+ """
357
+ msg = "Method ``fit()`` must be implemented in the child class."
358
+ self.logger.error(msg)
359
+ raise NotImplementedError(msg)
360
+
361
+ def transform(
362
+ self, X: np.ndarray | pd.DataFrame | list | None = None
363
+ ) -> np.ndarray:
364
+ """Imputes missing values in the data using the trained model.
365
+
366
+ This abstract method must be implemented by the child class. It should use the fitted model to fill in missing values in the provided data `X`.
367
+
368
+ Args:
369
+ X (np.ndarray | pd.DataFrame | list | None): The input data with missing values.
370
+
371
+ Returns:
372
+ np.ndarray: The data with missing values imputed.
373
+ """
374
+ msg = "Method ``transform()`` must be implemented in the child class."
375
+ self.logger.error(msg)
376
+ raise NotImplementedError(msg)
377
+
378
def _class_balanced_weights_from_mask(
    self,
    y: np.ndarray,
    train_mask: np.ndarray,
    num_classes: int,
    beta: float = 0.9999,
    max_ratio: float = 5.0,
    mode: Literal["allele", "genotype10"] = "allele",
) -> torch.Tensor:
    """Class-balanced weights (Cui et al. 2019) with overflow-safe effective number.

    Per-class counts are taken from the entries selected by `train_mask` whose labels are non-negative. With mode="allele", `y` holds integer class labels and `train_mask` has the same shape. With mode="genotype10", `y` is (nS,nF,2) allele pairs, `train_mask` is (nS,nF), and each unordered pair is converted to a class index through `self.pgenc.map10`.

    Args:
        y (np.ndarray): Ground truth labels; negative values are treated as missing.
        train_mask (np.ndarray): Boolean mask of training examples (same shape as y, or y without its last dim for genotype10).
        num_classes (int): Number of classes.
        beta (float): Hyperparameter for effective number calculation. Non-finite values are replaced by 0.9999 and the result is clamped to [1e-8, 1 - 1e-8]. Default is 0.9999.
        max_ratio (float): Target ratio between largest and smallest weight. Values below 5.0 are raised to 5.0; None is treated as 10.0. Default is 5.0.
        mode (Literal["allele", "genotype10"]): How `y` is interpreted. Default is "allele".

    Returns:
        torch.Tensor: float32 class weights of shape (num_classes,) on `self.device`, normalized by the mean of the positive weights. Classes with no observations receive the largest weight among observed classes.

    Raises:
        ValueError: If `mode` is unknown, or if the array shapes are invalid for genotype10.
    """
    if mode not in ("allele", "genotype10"):
        msg = f"Unknown mode supplied to _class_balanced_weights_from_mask: {mode}"
        self.logger.error(msg)
        raise ValueError(msg)

    # ---- Per-class counts ----
    if mode == "allele":
        usable = (y >= 0) & train_mask
        seen_classes, seen_counts = np.unique(
            y[usable].astype(np.int64), return_counts=True
        )
        class_counts = np.zeros(num_classes, dtype=np.float64)
        class_counts[seen_classes] = seen_counts
    else:
        if y.ndim != 3 or y.shape[-1] != 2:
            msg = "For genotype10, y must be (nS,nF,2)."
            self.logger.error(msg)
            raise ValueError(msg)

        if train_mask.shape != y.shape[:2]:
            msg = "train_mask must be (nS,nF) for genotype10."
            self.logger.error(msg)
            raise ValueError(msg)

        # Keep only training loci where both alleles are known.
        both_known = train_mask & np.all(y >= 0, axis=-1)
        if np.any(both_known):
            first = y[:, :, 0][both_known].astype(int)
            second = y[:, :, 1][both_known].astype(int)
            smaller = np.minimum(first, second)
            larger = np.maximum(first, second)
            # presumably a 2D lookup from (low allele, high allele) to a
            # genotype class index — TODO confirm against the encoder
            lookup = self.pgenc.map10
            genotype_idx = lookup[smaller, larger]
            in_range = (genotype_idx >= 0) & (genotype_idx < num_classes)
            genotype_idx = genotype_idx[in_range]
            class_counts = np.bincount(genotype_idx, minlength=num_classes).astype(
                np.float64
            )
        else:
            class_counts = np.zeros(num_classes, dtype=np.float64)

    # ---- Effective number ----
    beta = float(beta)
    if not np.isfinite(beta):
        beta = 0.9999
    beta = min(max(beta, 1e-8), 1.0 - 1e-8)

    log_beta = np.log(beta)  # < 0
    exponent = class_counts * log_beta  # <= 0

    # 1 - beta^n == -expm1(n * log(beta)); expm1 keeps accuracy near 0, and
    # for very negative exponents the effective number is taken as 1.0.
    effective = np.where(exponent > -50.0, -np.expm1(exponent), 1.0)

    # ---- Class-balanced weights ----
    weights = (1.0 - beta) / (effective + 1e-12)

    # Unseen classes get the largest weight of any seen class.
    unseen = class_counts == 0
    seen = class_counts > 0
    if np.any(unseen) and np.any(seen):
        weights[unseen] = weights[seen].max()

    # Normalize by the mean of the positive weights.
    positive = weights > 0
    weights[positive] /= weights[positive].mean() + 1e-12

    # Compress the spread toward 1.0 when it exceeds the cap.
    if max_ratio is not None:
        cap = float(max_ratio)
    else:
        cap = 10.0
    cap = max(cap, 5.0)  # ensure we allow some differentiation

    if np.any(positive):
        spread = weights[positive].max() / max(weights[positive].min(), 1e-12)
        if spread > cap:
            shrink = cap / spread
            weights[positive] = 1.0 + (weights[positive] - 1.0) * shrink

    return torch.tensor(weights.astype(np.float32), device=self.device)
475
+
476
def _select_device(self, device: Literal["gpu", "cpu", "mps"]) -> torch.device:
    """Resolves the requested device name to an available ``torch.device``.

    The name is lower-cased and stripped before matching. "cpu" always yields the CPU; "mps" yields MPS when available; any other value is treated as a GPU request and yields CUDA when available. An unavailable accelerator logs a warning and falls back to CPU.

    Args:
        device (Literal["gpu", "cpu", "mps"]): The preferred device type for PyTorch operations.

    Returns:
        torch.device: The selected PyTorch device.
    """
    requested = device.lower().strip()

    if requested == "cpu":
        self.logger.info("Using PyTorch device: CPU.")
        return torch.device("cpu")

    if requested == "mps":
        if not torch.backends.mps.is_available():
            self.logger.warning("MPS unavailable; falling back to CPU.")
            return torch.device("cpu")
        self.logger.info("Using PyTorch device: mps.")
        return torch.device("mps")

    # Everything else is handled as a GPU (CUDA) request.
    if not torch.cuda.is_available():
        self.logger.warning("CUDA unavailable; falling back to CPU.")
        return torch.device("cpu")
    self.logger.info("Using PyTorch device: cuda.")
    return torch.device("cuda")
504
+
505
def _create_model_directories(self, prefix: str, outdirs: List[str]) -> None:
    """Creates the directory structure for storing model outputs.

    For every name ``d`` in `outdirs`, the directory ``{prefix}_output/Unsupervised/{d}/{self.model_name}`` is created (parents included, existing directories tolerated) and stored on the instance as the attribute ``{d}_dir``.

    Args:
        prefix (str): The prefix for the main output directory.
        outdirs (List[str]): A list of subdirectory names to create within the main directory.

    Raises:
        Exception: If any of the directories cannot be created. The underlying error is attached as ``__cause__``.
    """
    base_dir = Path(f"{prefix}_output") / "Unsupervised"

    for d in outdirs:
        subdir = base_dir / d / self.model_name
        setattr(self, f"{d}_dir", subdir)
        try:
            subdir.mkdir(parents=True, exist_ok=True)
        except Exception as e:
            msg = f"Failed to create directory {subdir}: {e}"
            # This method can run before the logger is configured; a missing
            # logger must not hide the real failure behind an AttributeError.
            logger = getattr(self, "logger", None)
            if logger is not None:
                logger.error(msg)
            raise Exception(msg) from e
529
+
530
+ def _clear_resources(
531
+ self,
532
+ model: torch.nn.Module,
533
+ train_loader: torch.utils.data.DataLoader,
534
+ latent_vectors: torch.nn.Parameter | None = None,
535
+ ) -> None:
536
+ """Releases GPU and CPU memory after an Optuna trial.
537
+
538
+ This is a crucial step during hyperparameter tuning to prevent memory leaks between trials, ensuring that each trial runs in a clean environment.
539
+
540
+ Args:
541
+ model (torch.nn.Module): The model from the completed trial.
542
+ train_loader (torch.utils.data.DataLoader): The data loader from the trial.
543
+ latent_vectors (torch.nn.Parameter | None): The latent vectors from the trial.
544
+ """
545
+ try:
546
+ del model, train_loader
547
+
548
+ if latent_vectors is not None:
549
+ del latent_vectors
550
+
551
+ except NameError:
552
+ pass
553
+
554
+ gc.collect()
555
+ if torch.cuda.is_available():
556
+ torch.cuda.empty_cache()
557
+ elif hasattr(torch, "mps") and torch.backends.mps.is_available():
558
+ try:
559
+ torch.mps.empty_cache()
560
+ except Exception:
561
+ pass
562
+
563
def _make_eval_visualizations(
    self,
    labels: List[str],
    y_pred_proba: np.ndarray,
    y_true: np.ndarray,
    y_pred: np.ndarray,
    metrics: Dict[str, float],
    msg: str,
):
    """Logs a message, then draws the metric and confusion-matrix plots.

    Both plots share the file prefix ``geno{n}_{kind}``, where ``n`` is the number of labels and ``kind`` is "zygosity" for exactly three labels and "iupac" otherwise.

    Args:
        labels (List[str]): Class label names.
        y_pred_proba (np.ndarray): Predicted probabilities (2D array).
        y_true (np.ndarray): True labels (1D array).
        y_pred (np.ndarray): Predicted labels (1D array).
        metrics (Dict[str, float]): Computed metrics.
        msg (str): Message to log before generating plots.
    """
    self.logger.info(msg)

    n_labels = len(labels)
    kind = "zygosity" if n_labels == 3 else "iupac"
    file_prefix = f"geno{n_labels}_{kind}"

    self.plotter_.plot_metrics(
        y_true=y_true,
        y_pred_proba=y_pred_proba,
        metrics=metrics,
        label_names=labels,
        prefix=file_prefix,
    )
    self.plotter_.plot_confusion_matrix(
        y_true_1d=y_true,
        y_pred_1d=y_pred,
        label_names=labels,
        prefix=file_prefix,
    )
602
+
603
def _make_class_reports(
    self,
    y_true: np.ndarray,
    y_pred: np.ndarray,
    metrics: Dict[str, float],
    y_pred_proba: np.ndarray | None = None,
    labels: List[str] | None = None,
) -> None:
    """Generate and save detailed classification reports and visualizations.

    The report is labelled "zygosity" when ``labels`` has exactly three entries and "iupac" otherwise. The method logs a header, draws the metric plots (only when probabilities are supplied) and the confusion matrix through ``self.plotter_``, renders a per-class table, writes the full report to ``<metrics_dir>/<report_name>_report.json``, and saves every figure returned by ``ClassificationReportVisualizer.plot_all`` into ``self.plots_dir``.

    Args:
        y_true (np.ndarray): True labels (1D array).
        y_pred (np.ndarray): Predicted labels (1D array).
        metrics (Dict[str, float]): Computed metrics.
        y_pred_proba (np.ndarray | None): Predicted probabilities (2D array). Defaults to None.
        labels (List[str] | None): Class label names. Defaults to None, which is resolved to ["REF", "HET", "ALT"] (3-class).

    Raises:
        ValueError: If ``classification_report`` does not return a dict.
    """
    # Resolve the default inside the body: a list literal in the signature is
    # created once and shared by every call that omits ``labels``.
    if labels is None:
        labels = ["REF", "HET", "ALT"]

    report_name = "zygosity" if len(labels) == 3 else "iupac"
    middle = "IUPAC" if report_name == "iupac" else "Zygosity"

    msg = f"{middle} Report (on {y_true.size} total genotypes)"
    self.logger.info(msg)

    if y_pred_proba is not None:
        self.plotter_.plot_metrics(
            y_true,
            y_pred_proba,
            metrics,
            label_names=labels,
            prefix=report_name,
        )

    self.plotter_.plot_confusion_matrix(
        y_true, y_pred, label_names=labels, prefix=report_name
    )

    report: str | dict = classification_report(
        y_true,
        y_pred,
        labels=list(range(len(labels))),
        target_names=labels,
        zero_division=0,
        output_dict=True,
    )

    if not isinstance(report, dict):
        msg = "Expected classification_report to return a dict."
        self.logger.error(msg)
        raise ValueError(msg)

    # Keep only dict-valued entries that carry a "support" field, and drop
    # that field so the rendered table shows the remaining scores only.
    report_subset = {}
    for key, entry in report.items():
        if not (isinstance(entry, dict) and "support" in entry):
            continue
        scores = {k2: v2 for k2, v2 in entry.items() if k2 != "support"}
        if scores:
            report_subset[key] = scores

    if report_subset:
        pm = PrettyMetrics(
            report_subset,
            precision=3,
            title=f"{self.model_name} {middle} Report",
        )
        pm.render()

    # The JSON file holds the complete report, including the support counts.
    with open(self.metrics_dir / f"{report_name}_report.json", "w") as f:
        json.dump(report, f, indent=4)

    viz = ClassificationReportVisualizer(reset_kwargs=self.plotter_.param_dict)

    plots = viz.plot_all(
        report,  # type: ignore
        title_prefix=f"{self.model_name} {middle} Report",
        show=getattr(self, "show_plots", False),
        heatmap_classes_only=True,
    )

    for name, fig in plots.items():
        fout = self.plots_dir / f"{report_name}_report_{name}.{self.plot_format}"
        if hasattr(fig, "savefig") and isinstance(fig, Figure):
            # Matplotlib figure: save in the configured format and release it.
            fig.savefig(fout, dpi=300, facecolor="#111122")
            plt.close(fig)
        elif hasattr(fig, "write_html") and isinstance(fig, go.Figure):
            # Plotly figure: always written as HTML, then queued for MultiQC.
            fout_html = fout.with_suffix(".html")
            fig.write_html(file=fout_html)

            SNPioMultiQC.queue_html(
                fout_html,
                panel_id=f"pgsui_{self.model_name.lower()}_{report_name}_radar",
                section=f"PG-SUI: {self.model_name} Model Imputation",
                title=f"{self.model_name} {middle} Radar Plot",
                index_label=name,
                description=f"{self.model_name} {middle} {len(labels)}-base Radar Plot. This radar plot visualizes model performance for three metrics per-class: precision, recall, and F1-score. Each axis represents one of these metrics, allowing for a quick visual assessment of the model's strengths and weaknesses. Higher values towards the outer edge indicate better performance.",
            )

    if not self.is_haploid:
        msg = f"Ploidy: {self.ploidy}. Evaluating per allele."
        self.logger.info(msg)

    viz._reset_mpl_style()
709
+
710
+ def _compute_hidden_layer_sizes(
711
+ self,
712
+ n_inputs: int,
713
+ n_outputs: int,
714
+ n_samples: int,
715
+ n_hidden: int,
716
+ *,
717
+ alpha: float = 4.0,
718
+ schedule: str = "pyramid",
719
+ min_size: int = 16,
720
+ max_size: int | None = None,
721
+ multiple_of: int = 8,
722
+ decay: float | None = None,
723
+ cap_by_inputs: bool = True,
724
+ ) -> list[int]:
725
+ """Compute hidden layer sizes given problem scale and a layer count.
726
+
727
+ This method computes a list of hidden layer sizes based on the number of input features, output classes, training samples, and desired hidden layers. The sizes are determined using a specified schedule (pyramid, constant, or linear) and are constrained by minimum and maximum sizes, as well as rounding to multiples of a specified value.
728
+
729
+ Args:
730
+ n_inputs (int): Number of input features.
731
+ n_outputs (int): Number of output classes.
732
+ n_samples (int): Number of training samples.
733
+ n_hidden (int): Number of hidden layers.
734
+ alpha (float): Scaling factor for base layer size. Default is 4.0.
735
+ schedule (Literal["pyramid", "constant", "linear"]): Size schedule. Default is "pyramid".
736
+ min_size (int): Minimum layer size. Default is 16.
737
+ max_size (int | None): Maximum layer size. Default is None (no limit).
738
+ multiple_of (int): Round layer sizes to be multiples of this. Default is 8.
739
+ decay (float | None): Decay factor for "pyramid" schedule. If None, it is computed automatically. Default is None.
740
+ cap_by_inputs (bool): If True, cap layer sizes to n_inputs. Default is True.
741
+
742
+ Returns:
743
+ list[int]: List of hidden layer sizes.
744
+
745
+ Raises:
746
+ ValueError: If n_hidden < 0 or if alpha * (n_inputs + n_outputs) <= 0 or if schedule is unknown.
747
+ TypeError: If any argument is not of the expected type.
748
+
749
+ Notes:
750
+ - If n_hidden is 0, returns an empty list.
751
+ - The base layer size is computed as ceil(n_samples / (alpha * (n_inputs + n_outputs))).
752
+ - The sizes are adjusted according to the specified schedule and constraints.
753
+ """
754
+ if n_hidden < 0:
755
+ msg = f"n_hidden must be >= 0, got {n_hidden}."
756
+ self.logger.error(msg)
757
+ raise ValueError(msg)
758
+
759
+ if schedule not in {"pyramid", "constant", "linear"}:
760
+ msg = f"Unknown schedule '{schedule}'. Use 'pyramid', 'constant', or 'linear'."
761
+ self.logger.error(msg)
762
+ raise ValueError(msg)
763
+
764
+ if n_hidden == 0:
765
+ return []
766
+
767
+ denom = float(alpha) * float(n_inputs + n_outputs)
768
+
769
+ if denom <= 0:
770
+ msg = f"alpha * (n_inputs + n_outputs) must be > 0, got {denom}."
771
+ self.logger.error(msg)
772
+ raise ValueError(msg)
773
+
774
+ base = int(np.ceil(float(n_samples) / denom))
775
+
776
+ if max_size is None:
777
+ max_size = max(n_inputs, base)
778
+
779
+ base = int(np.clip(base, min_size, max_size))
780
+
781
+ if schedule == "constant":
782
+ sizes = np.full(shape=(n_hidden,), fill_value=base, dtype=float)
783
+
784
+ elif schedule == "linear":
785
+ target = max(min_size, min(base, base // 4))
786
+ sizes = (
787
+ np.array([base], dtype=float)
788
+ if n_hidden == 1
789
+ else np.linspace(base, target, num=n_hidden, dtype=float)
790
+ )
791
+
792
+ elif schedule == "pyramid":
793
+ if n_hidden == 1:
794
+ sizes = np.array([base], dtype=float)
795
+ else:
796
+ if decay is None:
797
+ target = max(min_size, base // 4)
798
+ if base <= 0 or target <= 0:
799
+ dcy = 1.0
800
+ else:
801
+ dcy = (target / float(base)) ** (1.0 / (n_hidden - 1))
802
+ dcy = float(np.clip(dcy, 0.25, 0.99))
803
+ exponents = np.arange(n_hidden, dtype=float)
804
+ sizes = base * (dcy**exponents)
805
+
806
+ else:
807
+ msg = f"Unknown schedule '{schedule}'. Use 'pyramid', 'constant', or 'linear'."
808
+ self.logger.error(msg)
809
+ raise ValueError(msg)
810
+
811
+ sizes = np.clip(sizes, min_size, max_size)
812
+
813
+ if cap_by_inputs:
814
+ sizes = np.minimum(sizes, float(n_inputs))
815
+
816
+ sizes = (np.ceil(sizes / multiple_of) * multiple_of).astype(int)
817
+ sizes = np.minimum.accumulate(sizes)
818
+ return np.clip(sizes, min_size, max_size).astype(int).tolist()
819
+
820
+ def _class_weights_from_zygosity(self, X: np.ndarray) -> torch.Tensor:
821
+ """Class-balanced weights for 0/1/2 (handles haploid collapse if needed).
822
+
823
+ This method computes class-balanced weights for the genotype classes (0/1/2) based on the provided genotype matrix. It handles cases where the data is haploid by collapsing the ALT class to 1, effectively treating the problem as binary classification (REF vs ALT). The weights are calculated using a class-balanced weighting scheme that considers the frequency of each class in the training data, with parameters for beta and maximum ratio to control the weighting behavior. The resulting weights are returned as a PyTorch tensor on the current device.
824
+
825
+ Args:
826
+ X (np.ndarray): 0/1/2 with -1 for missing.
827
+
828
+ Returns:
829
+ torch.Tensor: Weights on current device.
830
+ """
831
+ y = X[X != -1].ravel().astype(np.int64)
832
+ if y.size == 0:
833
+ return torch.ones(
834
+ self.num_classes_, dtype=torch.float32, device=self.device
835
+ )
836
+
837
+ return self._class_balanced_weights_from_mask(
838
+ y=y,
839
+ train_mask=np.ones_like(y, dtype=bool),
840
+ num_classes=self.num_classes_,
841
+ beta=self.beta,
842
+ max_ratio=self.max_ratio,
843
+ mode="allele", # 1D int vector
844
+ ).to(self.device)
845
+
846
+ @staticmethod
847
+ def _normalize_class_weights(
848
+ weights: torch.Tensor | None,
849
+ ) -> torch.Tensor | None:
850
+ """Normalize class weights once to keep loss scale stable.
851
+
852
+ Args:
853
+ weights (torch.Tensor | None): Class weights to normalize.
854
+
855
+ Returns:
856
+ torch.Tensor | None: Normalized class weights or None if input is None.
857
+ """
858
+ if weights is None:
859
+ return None
860
+ return weights / weights.mean().clamp_min(1e-8)
861
+
862
+ def _get_float_genotypes(self, *, copy: bool = True) -> np.ndarray:
863
+ """Float32 0/1/2 matrix with NaNs for missing, cached per dataset.
864
+
865
+ Args:
866
+ copy (bool): If True, return a copy of the cached array. Default is True.
867
+
868
+ Returns:
869
+ np.ndarray: Float32 genotype matrix with NaNs for missing values.
870
+ """
871
+ cache = self._float_genotype_cache
872
+ current = self.pgenc.genotypes_012
873
+ if cache is None or cache.shape != current.shape or cache.dtype != np.float32:
874
+ arr = np.asarray(current, dtype=np.float32)
875
+ arr = np.where(arr < 0, np.nan, arr)
876
+ self._float_genotype_cache = arr
877
+ cache = arr
878
+ return cache.copy() if copy else cache
879
+
880
+ def _sim_mask_cache_key(self) -> tuple | None:
881
+ """Key for caching simulated-missing masks."""
882
+ if not getattr(self, "simulate_missing", False):
883
+ return None
884
+ shape = tuple(self.pgenc.genotypes_012.shape)
885
+ return (
886
+ id(self.genotype_data),
887
+ self.sim_strategy,
888
+ round(float(self.sim_prop), 6),
889
+ self.seed,
890
+ shape,
891
+ )
892
+
893
+ def _one_hot_encode_012(self, X: np.ndarray | torch.Tensor) -> torch.Tensor:
894
+ """One-hot 0/1/2; -1 rows are all-zeros (B, L, K).
895
+
896
+ This method performs one-hot encoding of the input genotype data (0, 1, 2) while handling missing values represented by -1. The output is a tensor of shape (B, L, K), where B is the batch size, L is the number of features, and K is the number of classes.
897
+
898
+ Args:
899
+ X (np.ndarray | torch.Tensor): The input data to be one-hot encoded, either as a NumPy array or a PyTorch tensor.
900
+
901
+ Returns:
902
+ torch.Tensor: A one-hot encoded tensor of shape (B, L, K), where B is the batch size, L is the number of features, and K is the number of classes.
903
+ """
904
+ Xt = (
905
+ torch.from_numpy(X).to(self.device)
906
+ if isinstance(X, np.ndarray)
907
+ else X.to(self.device)
908
+ )
909
+
910
+ # B=batch, L=features, K=classes
911
+ B, L = Xt.shape
912
+ K = self.num_classes_
913
+ X_ohe = torch.zeros(B, L, K, dtype=torch.float32, device=self.device)
914
+ valid = Xt != -1
915
+ idx = Xt[valid].long()
916
+
917
+ if idx.numel() > 0:
918
+ X_ohe[valid] = F.one_hot(idx, num_classes=K).float()
919
+
920
+ return X_ohe
921
+
922
def _eval_for_pruning(
    self,
    *,
    model: torch.nn.Module,
    X_val: np.ndarray,
    params: dict,
    metric: str,
    objective_mode: bool = True,
    do_latent_infer: bool = False,
    latent_steps: int = 50,
    latent_lr: float = 1e-2,
    latent_weight_decay: float = 0.0,
    latent_seed: int = 123,
    _latent_cache: dict | None = None,
    _latent_cache_key: str | None = None,
    eval_mask_override: np.ndarray | None = None,
) -> float:
    """Compute a scalar metric (to MAXIMIZE) on a fixed validation set.

    This method evaluates the model through ``self._evaluate_model`` and returns one metric from the result, for use in pruning decisions during hyperparameter tuning. When ``do_latent_infer`` is True and the instance defines ``_latent_infer_for_eval``, latents are optimized first and passed to the evaluator. A missing or non-finite metric is reported as negative infinity so poorly performing trials are easy to prune.

    Args:
        model (torch.nn.Module): The model to evaluate.
        X_val (np.ndarray): Validation data.
        params (dict): Model parameters.
        metric (str): Metric name to return.
        objective_mode (bool): If True, use objective-mode evaluation. Default is True.
        do_latent_infer (bool): If True, perform latent inference before evaluation. Default is False.
        latent_steps (int): Number of steps for latent inference. Default is 50.
        latent_lr (float): Learning rate for latent inference. Default is 1e-2.
        latent_weight_decay (float): Weight decay for latent inference. Default is 0.0.
        latent_seed (int): Random seed for latent inference. Default is 123.
        _latent_cache (dict | None): Optional cache for storing/retrieving optimized latents. Default is None.
        _latent_cache_key (str | None): Key for storing/retrieving in _latent_cache. Default is None.
        eval_mask_override (np.ndarray | None): Optional mask to override default evaluation mask. Default is None.

    Returns:
        float: The computed metric value to maximize. Returns -inf when the metric is missing from the evaluation result or is not finite.
    """
    optimized_val_latents = None

    # Optional latent inference path for models that need it.
    if do_latent_infer and hasattr(self, "_latent_infer_for_eval"):
        optimized_val_latents = self._latent_infer_for_eval(  # type: ignore
            model=model,
            X_val=X_val,
            steps=latent_steps,
            lr=latent_lr,
            weight_decay=latent_weight_decay,
            seed=latent_seed,
            cache=_latent_cache,
            cache_key=_latent_cache_key,
        )
        # Retrieve the optimized latents from the cache; a cached entry takes
        # precedence over the value returned by the call above.
        if _latent_cache is not None and _latent_cache_key in _latent_cache:
            optimized_val_latents = _latent_cache[_latent_cache_key]

    # Thin the validation set (and any override mask) when a tuning slice is set.
    # NOTE(review): the latents above were inferred on the unsliced X_val --
    # confirm _evaluate_model copes with the sliced X_val alongside them.
    if getattr(self, "_tune_eval_slice", None) is not None:
        X_val = X_val[self._tune_eval_slice]
        if eval_mask_override is not None:
            eval_mask_override = eval_mask_override[self._tune_eval_slice]

    # Child's evaluator now accepts the pre-computed latents
    metrics = self._evaluate_model(  # type: ignore
        X_val=X_val,
        model=model,
        params=params,
        objective_mode=objective_mode,
        latent_vectors_val=optimized_val_latents,
        eval_mask_override=eval_mask_override,
    )

    # Prefer the requested metric; fall back to self.tune_metric if needed.
    val = metrics.get(metric, metrics.get(getattr(self, "tune_metric", ""), None))

    if val is None or not np.isfinite(val):
        return -np.inf  # make pruning decisions easy/robust on bad reads

    return float(val)
1001
+
1002
+ def _first_linear_in_features(self, model: torch.nn.Module) -> int:
1003
+ """Return in_features of the model's first Linear layer.
1004
+
1005
+ This method iterates through the modules of the provided PyTorch model to find the first instance of a Linear layer. It then retrieves and returns the `in_features` attribute of that layer, which indicates the number of input features expected by the layer.
1006
+
1007
+ Args:
1008
+ model (torch.nn.Module): The model to inspect.
1009
+
1010
+ Returns:
1011
+ int: The in_features of the first Linear layer.
1012
+ """
1013
+ for m in model.modules():
1014
+ if isinstance(m, torch.nn.Linear):
1015
+ return int(m.in_features)
1016
+ raise RuntimeError("No Linear layers found in model.")
1017
+
1018
def _assert_model_latent_compat(
    self, model: torch.nn.Module, latent_vectors: torch.nn.Parameter
) -> None:
    """Verify that the latent width matches the model's first Linear layer.

    The second dimension of ``latent_vectors`` is compared with the ``in_features`` reported by ``self._first_linear_in_features(model)``.

    Args:
        model (torch.nn.Module): The model to check.
        latent_vectors (torch.nn.Parameter): The latent vectors to check.

    Raises:
        ValueError: If the latent dimension does not match the model's expected input features.
    """
    zdim = int(latent_vectors.shape[1])
    first_in = self._first_linear_in_features(model)

    if first_in == zdim:
        return

    raise ValueError(
        f"Latent mismatch: zdim={zdim}, model first Linear expects in_features={first_in}"
    )
1038
+
1039
def _prepare_tuning_artifacts(self) -> None:
    """Prepare data and artifacts needed for hyperparameter tuning.

    The method runs once per instance (it is a no-op when ``_tune_ready`` is already set). It optionally subsamples samples and loci from ``ground_truth_`` when ``tune_fast`` is enabled, splits the rows into train/test index sets, computes normalized class weights from the training rows, builds the tuning data loader with ``tune_batch_size``, and records an optional evaluation slice that thins large validation sets.

    Raises:
        AttributeError: If the ground truth data (`ground_truth_`) is not set.
    """
    if getattr(self, "_tune_ready", False):
        return

    X = self.ground_truth_
    n_samp, n_loci = X.shape
    rng = self.rng

    if self.tune_fast:
        n_keep_samples = min(n_samp, self.tune_max_samples)
        # A tune_max_loci of 0 means "keep every locus".
        n_keep_loci = (
            n_loci if self.tune_max_loci == 0 else min(n_loci, self.tune_max_loci)
        )

        samp_idx = np.sort(rng.choice(n_samp, size=n_keep_samples, replace=False))
        loci_idx = np.sort(rng.choice(n_loci, size=n_keep_loci, replace=False))
        X_small = X[samp_idx][:, loci_idx]
    else:
        X_small = X

    idx = np.arange(X_small.shape[0])
    tr, te = train_test_split(
        idx, test_size=self.validation_split, random_state=self.seed
    )
    self._tune_train_idx = tr
    self._tune_test_idx = te
    self._tune_X_train = X_small[tr]
    self._tune_X_test = X_small[te]

    self._tune_class_weights = self._normalize_class_weights(
        self._class_weights_from_zygosity(self._tune_X_train)
    )

    # Temporarily bump batch size only for tuning loader. The restore sits in
    # ``finally`` so a failure while building the loader cannot leave the
    # instance stuck on the tuning batch size.
    orig_bs = self.batch_size
    self.batch_size = self.tune_batch_size
    try:
        self._tune_loader = self._get_data_loaders(self._tune_X_train)  # type: ignore
    finally:
        self.batch_size = orig_bs

    self._tune_num_features = self._tune_X_train.shape[1]
    self._tune_val_latents_source = None
    self._tune_train_latents_source = None

    # Optional: for huge val sets, thin them for proxy metric
    if (
        self.tune_proxy_metric_batch
        and self._tune_X_test.shape[0] > self.tune_proxy_metric_batch
    ):
        self._tune_eval_slice = np.arange(self.tune_proxy_metric_batch)
    else:
        self._tune_eval_slice = None

    self._tune_ready = True
1097
+
1098
+ def _save_best_params(self, best_params: Dict[str, Any]) -> None:
1099
+ """Save the best hyperparameters to a JSON file.
1100
+
1101
+ This method saves the best hyperparameters found during hyperparameter tuning to a JSON file in the optimization directory. The filename includes the model name for easy identification.
1102
+
1103
+ Args:
1104
+ best_params (Dict[str, Any]): A dictionary of the best hyperparameters to save.
1105
+ """
1106
+ if not hasattr(self, "parameters_dir"):
1107
+ msg = "Attribute 'parameters_dir' not found. Ensure _create_model_directories() has been called."
1108
+ self.logger.error(msg)
1109
+ raise AttributeError(msg)
1110
+
1111
+ fout = self.parameters_dir / "best_parameters.json"
1112
+
1113
+ with open(fout, "w") as f:
1114
+ json.dump(best_params, f, indent=4)
1115
+
1116
+ def _set_best_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
1117
+ """An abstract method for setting best parameters."""
1118
+ raise NotImplementedError