pg-sui 0.2.3__py3-none-any.whl → 1.6.16a3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pg_sui-1.6.16a3.dist-info/METADATA +292 -0
- pg_sui-1.6.16a3.dist-info/RECORD +81 -0
- {pg_sui-0.2.3.dist-info → pg_sui-1.6.16a3.dist-info}/WHEEL +1 -1
- pg_sui-1.6.16a3.dist-info/entry_points.txt +4 -0
- {pg_sui-0.2.3.dist-info → pg_sui-1.6.16a3.dist-info/licenses}/LICENSE +0 -0
- pg_sui-1.6.16a3.dist-info/top_level.txt +1 -0
- pgsui/__init__.py +35 -54
- pgsui/_version.py +34 -0
- pgsui/cli.py +922 -0
- pgsui/data_processing/__init__.py +0 -0
- pgsui/data_processing/config.py +565 -0
- pgsui/data_processing/containers.py +1436 -0
- pgsui/data_processing/transformers.py +557 -907
- pgsui/{example_data/trees → electron/app}/__init__.py +0 -0
- pgsui/electron/app/__main__.py +5 -0
- pgsui/electron/app/extra-resources/.gitkeep +1 -0
- pgsui/electron/app/icons/icons/1024x1024.png +0 -0
- pgsui/electron/app/icons/icons/128x128.png +0 -0
- pgsui/electron/app/icons/icons/16x16.png +0 -0
- pgsui/electron/app/icons/icons/24x24.png +0 -0
- pgsui/electron/app/icons/icons/256x256.png +0 -0
- pgsui/electron/app/icons/icons/32x32.png +0 -0
- pgsui/electron/app/icons/icons/48x48.png +0 -0
- pgsui/electron/app/icons/icons/512x512.png +0 -0
- pgsui/electron/app/icons/icons/64x64.png +0 -0
- pgsui/electron/app/icons/icons/icon.icns +0 -0
- pgsui/electron/app/icons/icons/icon.ico +0 -0
- pgsui/electron/app/main.js +227 -0
- pgsui/electron/app/package-lock.json +6894 -0
- pgsui/electron/app/package.json +51 -0
- pgsui/electron/app/preload.js +15 -0
- pgsui/electron/app/server.py +157 -0
- pgsui/electron/app/ui/logo.png +0 -0
- pgsui/electron/app/ui/renderer.js +131 -0
- pgsui/electron/app/ui/styles.css +59 -0
- pgsui/electron/app/ui/ui_shim.js +72 -0
- pgsui/electron/bootstrap.py +43 -0
- pgsui/electron/launch.py +57 -0
- pgsui/electron/package.json +14 -0
- pgsui/example_data/__init__.py +0 -0
- pgsui/example_data/phylip_files/__init__.py +0 -0
- pgsui/example_data/phylip_files/test.phy +0 -0
- pgsui/example_data/popmaps/__init__.py +0 -0
- pgsui/example_data/popmaps/{test.popmap → phylogen_nomx.popmap} +185 -99
- pgsui/example_data/structure_files/__init__.py +0 -0
- pgsui/example_data/structure_files/test.pops.2row.allsites.str +0 -0
- pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz +0 -0
- pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz.tbi +0 -0
- pgsui/impute/__init__.py +0 -0
- pgsui/impute/deterministic/imputers/allele_freq.py +725 -0
- pgsui/impute/deterministic/imputers/mode.py +844 -0
- pgsui/impute/deterministic/imputers/nmf.py +221 -0
- pgsui/impute/deterministic/imputers/phylo.py +973 -0
- pgsui/impute/deterministic/imputers/ref_allele.py +669 -0
- pgsui/impute/supervised/__init__.py +0 -0
- pgsui/impute/supervised/base.py +343 -0
- pgsui/impute/{unsupervised/models/in_development → supervised/imputers}/__init__.py +0 -0
- pgsui/impute/supervised/imputers/hist_gradient_boosting.py +317 -0
- pgsui/impute/supervised/imputers/random_forest.py +291 -0
- pgsui/impute/unsupervised/__init__.py +0 -0
- pgsui/impute/unsupervised/base.py +1121 -0
- pgsui/impute/unsupervised/callbacks.py +92 -262
- {simulation → pgsui/impute/unsupervised/imputers}/__init__.py +0 -0
- pgsui/impute/unsupervised/imputers/autoencoder.py +1361 -0
- pgsui/impute/unsupervised/imputers/nlpca.py +1666 -0
- pgsui/impute/unsupervised/imputers/ubp.py +1660 -0
- pgsui/impute/unsupervised/imputers/vae.py +1316 -0
- pgsui/impute/unsupervised/loss_functions.py +261 -0
- pgsui/impute/unsupervised/models/__init__.py +0 -0
- pgsui/impute/unsupervised/models/autoencoder_model.py +215 -567
- pgsui/impute/unsupervised/models/nlpca_model.py +155 -394
- pgsui/impute/unsupervised/models/ubp_model.py +180 -1106
- pgsui/impute/unsupervised/models/vae_model.py +269 -630
- pgsui/impute/unsupervised/nn_scorers.py +255 -0
- pgsui/utils/__init__.py +0 -0
- pgsui/utils/classification_viz.py +608 -0
- pgsui/utils/logging_utils.py +22 -0
- pgsui/utils/misc.py +35 -480
- pgsui/utils/plotting.py +996 -829
- pgsui/utils/pretty_metrics.py +290 -0
- pgsui/utils/scorers.py +213 -666
- pg_sui-0.2.3.dist-info/METADATA +0 -322
- pg_sui-0.2.3.dist-info/RECORD +0 -75
- pg_sui-0.2.3.dist-info/top_level.txt +0 -3
- pgsui/example_data/phylip_files/test_n10.phy +0 -118
- pgsui/example_data/phylip_files/test_n100.phy +0 -118
- pgsui/example_data/phylip_files/test_n2.phy +0 -118
- pgsui/example_data/phylip_files/test_n500.phy +0 -118
- pgsui/example_data/structure_files/test.nopops.1row.10sites.str +0 -117
- pgsui/example_data/structure_files/test.nopops.2row.100sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.10sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.30sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.allsites.str +0 -234
- pgsui/example_data/structure_files/test.pops.1row.10sites.str +0 -117
- pgsui/example_data/structure_files/test.pops.2row.10sites.str +0 -234
- pgsui/example_data/trees/test.iqtree +0 -376
- pgsui/example_data/trees/test.qmat +0 -5
- pgsui/example_data/trees/test.rate +0 -2033
- pgsui/example_data/trees/test.tre +0 -1
- pgsui/example_data/trees/test_n10.rate +0 -19
- pgsui/example_data/trees/test_n100.rate +0 -109
- pgsui/example_data/trees/test_n500.rate +0 -509
- pgsui/example_data/trees/test_siterates.txt +0 -2024
- pgsui/example_data/trees/test_siterates_n10.txt +0 -10
- pgsui/example_data/trees/test_siterates_n100.txt +0 -100
- pgsui/example_data/trees/test_siterates_n500.txt +0 -500
- pgsui/example_data/vcf_files/test.vcf +0 -244
- pgsui/example_data/vcf_files/test.vcf.gz +0 -0
- pgsui/example_data/vcf_files/test.vcf.gz.tbi +0 -0
- pgsui/impute/estimators.py +0 -1268
- pgsui/impute/impute.py +0 -1463
- pgsui/impute/simple_imputers.py +0 -1431
- pgsui/impute/supervised/iterative_imputer_fixedparams.py +0 -782
- pgsui/impute/supervised/iterative_imputer_gridsearch.py +0 -1024
- pgsui/impute/unsupervised/keras_classifiers.py +0 -697
- pgsui/impute/unsupervised/models/in_development/cnn_model.py +0 -486
- pgsui/impute/unsupervised/neural_network_imputers.py +0 -1440
- pgsui/impute/unsupervised/neural_network_methods.py +0 -1395
- pgsui/pg_sui.py +0 -261
- pgsui/utils/sequence_tools.py +0 -407
- simulation/sim_benchmarks.py +0 -333
- simulation/sim_treeparams.py +0 -475
- test/__init__.py +0 -0
- test/pg_sui_simtest.py +0 -215
- test/pg_sui_testing.py +0 -523
- test/test.py +0 -151
- test/test_pgsui.py +0 -374
- test/test_tkc.py +0 -185
|
@@ -0,0 +1,1121 @@
|
|
|
1
|
+
import copy
|
|
2
|
+
import gc
|
|
3
|
+
import json
|
|
4
|
+
import logging
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Tuple
|
|
7
|
+
|
|
8
|
+
import matplotlib.pyplot as plt
|
|
9
|
+
import numpy as np
|
|
10
|
+
import optuna
|
|
11
|
+
import pandas as pd
|
|
12
|
+
import plotly.graph_objects as go
|
|
13
|
+
import torch
|
|
14
|
+
import torch.nn.functional as F
|
|
15
|
+
from matplotlib.figure import Figure
|
|
16
|
+
from sklearn.metrics import classification_report
|
|
17
|
+
from sklearn.model_selection import train_test_split
|
|
18
|
+
from snpio import SNPioMultiQC
|
|
19
|
+
from snpio.utils.logging import LoggerManager
|
|
20
|
+
|
|
21
|
+
from pgsui.impute.unsupervised.nn_scorers import Scorer
|
|
22
|
+
from pgsui.utils.classification_viz import ClassificationReportVisualizer
|
|
23
|
+
from pgsui.utils.logging_utils import configure_logger
|
|
24
|
+
from pgsui.utils.plotting import Plotting
|
|
25
|
+
from pgsui.utils.pretty_metrics import PrettyMetrics
|
|
26
|
+
|
|
27
|
+
if TYPE_CHECKING:
|
|
28
|
+
from snpio.read_input.genotype_data import GenotypeData
|
|
29
|
+
|
|
30
|
+
from pgsui.impute.unsupervised.models.autoencoder_model import AutoencoderModel
|
|
31
|
+
from pgsui.impute.unsupervised.models.nlpca_model import NLPCAModel
|
|
32
|
+
from pgsui.impute.unsupervised.models.ubp_model import UBPModel
|
|
33
|
+
from pgsui.impute.unsupervised.models.vae_model import VAEModel
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class BaseNNImputer:
|
|
37
|
+
"""An abstract base class for neural network-based imputers.
|
|
38
|
+
|
|
39
|
+
This class provides a shared framework and common functionality for all neural network imputers. It is not meant to be instantiated directly. Instead, child classes should inherit from it and implement the abstract methods. Provided functionality: Directory setup and logging initialization; A hyperparameter tuning pipeline using Optuna; Utility methods for building models (`build_model`), initializing weights (`initialize_weights`), and checking for fitted attributes (`ensure_attribute`); Helper methods for calculating class weights for imbalanced data; Setup for standardized plotting and model scoring classes.
|
|
40
|
+
"""
|
|
41
|
+
|
|
42
|
+
def __init__(
    self,
    model_name: str,
    genotype_data: "GenotypeData",
    prefix: str,
    *,
    device: Literal["gpu", "cpu", "mps"] = "cpu",
    verbose: bool = False,
    debug: bool = False,
):
    """Initializes the base class for neural network imputers.

    Creates the logger, resolves the PyTorch device (CPU, CUDA GPU, or MPS),
    creates the output directory tree, and assigns default values for the
    attributes that child classes or ``fit()`` are expected to overwrite.

    Args:
        model_name (str): Name of the model. Used for the per-model output
            sub-directories, the Optuna study name, and plot titles.
        genotype_data (GenotypeData): Genotype data object for the imputer;
            stored as ``self.genotype_data``.
        prefix (str): A prefix used to name the output directory
            (``f"{prefix}_output"``).
        device (Literal["gpu", "cpu", "mps"]): The device to use for PyTorch
            operations. If 'gpu' or 'mps' is chosen, it falls back to 'cpu'
            when the required hardware is not available. Defaults to "cpu".
        verbose (bool): If True, enables detailed logging output. Defaults to
            False.
        debug (bool): If True, enables debug mode. Defaults to False.
    """
    self.model_name = model_name
    self.genotype_data = genotype_data

    self.prefix = prefix
    self.verbose = verbose
    self.debug = debug

    # Quiet Matplotlib/fontTools INFO logging when saving PDF/SVG
    for name in (
        "fontTools",
        "fontTools.subset",
        "fontTools.ttLib",
        "matplotlib.font_manager",
    ):
        lg = logging.getLogger(name)
        lg.setLevel(logging.WARNING)
        lg.propagate = False

    # Initialize the logger FIRST: both _select_device() and
    # _create_model_directories() call self.logger. Creating it after them
    # (as before) raised AttributeError unless a subclass had already set it.
    # NOTE(review): assumes LoggerManager does not rely on the model
    # directories created below existing yet — confirm against snpio.
    kwargs = {"prefix": prefix, "verbose": verbose, "debug": debug}
    logman = LoggerManager(__name__, **kwargs)
    self.logger = configure_logger(
        logman.get_logger(), verbose=self.verbose, debug=self.debug
    )

    self.device = self._select_device(device)

    # Prepare directory structure; sets self.<name>_dir for each entry.
    outdirs = ["models", "plots", "metrics", "optimize", "parameters"]
    self._create_model_directories(prefix, outdirs)

    self._float_genotype_cache: np.ndarray | None = None
    self._sim_mask_cache: dict[tuple, np.ndarray] = {}

    # Defaults; to be overwritten by child classes or the fit method.
    self.tune_save_db: bool = False
    self.tune_resume: bool = False
    self.n_trials: int = 100
    self.model_params: Dict[str, Any] = {}
    self.tune_metric: str = "val_f1_macro"
    self.learning_rate: float = 1e-3
    self.plotter_: "Plotting"
    self.num_features_: int = 0
    self.num_classes_: int = 3
    self.plot_format: Literal["pdf", "png", "jpg", "jpeg", "svg"] = "pdf"
    self.plot_fontsize: int = 10
    self.plot_dpi: int = 300
    self.title_fontsize: int = 12
    self.despine: bool = True
    self.show_plots: bool = False
    self.scoring_averaging: Literal["macro", "micro", "weighted"] = "macro"
    # Placeholder only; the genotype encoder is expected to be set elsewhere.
    self.pgenc: Any = None
    self.is_haploid: bool = False
    self.ploidy: int = 2
    self.beta: float = 0.9999
    self.max_ratio: float = 5.0
    self.sim_strategy: str = "mcar"
    self.sim_prop: float = 0.1
    self.seed: int | None = 42
    self.rng: np.random.Generator = np.random.default_rng(self.seed)
    self.ground_truth_: np.ndarray
    self.tune_fast: bool = False
    self.tune_max_samples: int = 1000
    self.tune_max_loci: int = 500
    self.validation_split: float = 0.2
    self.tune_batch_size: int = 64
    self.tune_proxy_metric_batch: int = 512
    self.batch_size: int = 64
    self.best_params_: Dict[str, Any] = {}

    # Output directories (assigned by _create_model_directories above).
    self.optimize_dir: Path
    self.models_dir: Path
    self.plots_dir: Path
    self.metrics_dir: Path
    self.parameters_dir: Path
    self.study_db: Path
|
|
137
|
+
|
|
138
|
+
def tune_hyperparameters(self) -> None:
    """Tunes model hyperparameters using an Optuna study.

    Creates an Optuna study that maximizes the value returned by the child
    class's ``_objective()``. When ``self.tune_resume`` is set and a saved
    database exists, the existing study is resumed instead.

    The study runs ``self.n_trials`` trials with ``self.n_jobs`` workers,
    pruned by a ``MedianPruner``. Afterwards:

    * the best parameters are passed through ``_set_best_params()``;
    * the result is stored in ``self.best_params_`` and merged into
      ``self.model_params``;
    * the parameters are logged and saved via ``_save_best_params()``;
    * the study is plotted into ``self.optimize_dir / "plots"``.

    When ``self.tune_save_db`` is True, the study is persisted to
    ``self.optimize_dir / "study_database" / "optuna_study.db"``. That path
    is also stored as ``self.study_db``. An existing database is deleted
    unless ``self.tune_resume`` is True.

    Raises:
        NotImplementedError: If the child class does not override
            ``_objective()`` or does not define ``_set_best_params()``. This is
            checked before any database is touched and before any trial runs.
        ValueError: If the study ends without a single completed trial (for
            example, every trial was pruned or failed).
    """
    self.logger.info("Tuning hyperparameters. This might take a while...")

    # Validate the child-class contract up front, so a missing method does
    # not surface only after a possibly very long search, or after an
    # existing study database has already been deleted. A hasattr() check is
    # useless for _objective: BaseNNImputer defines a placeholder that only
    # raises, so hasattr() is always True.
    if type(self)._objective is BaseNNImputer._objective:
        msg = "`_objective()` must be implemented in the child class."
        self.logger.error(msg)
        raise NotImplementedError(msg)

    if not hasattr(self, "_set_best_params"):
        msg = "Method `_set_best_params()` must be implemented in the child class."
        self.logger.error(msg)
        raise NotImplementedError(msg)

    if self.verbose or self.debug:
        optuna.logging.set_verbosity(optuna.logging.INFO)
    else:
        optuna.logging.set_verbosity(optuna.logging.WARNING)

    study_db = None
    load_if_exists = False
    if self.tune_save_db:
        study_db = self.optimize_dir / "study_database" / "optuna_study.db"
        study_db.parent.mkdir(parents=True, exist_ok=True)
        self.study_db = study_db

        if self.tune_resume and study_db.exists():
            load_if_exists = True

        if not self.tune_resume and study_db.exists():
            study_db.unlink()

    study_name = f"{self.prefix} {self.model_name} Model Optimization"
    storage = f"sqlite:///{study_db}" if self.tune_save_db else None

    study = optuna.create_study(
        direction="maximize",
        study_name=study_name,
        storage=storage,
        load_if_exists=load_if_exists,
        pruner=optuna.pruners.MedianPruner(n_startup_trials=5, n_warmup_steps=10),
    )

    self.n_jobs = getattr(self, "n_jobs", 1)
    if self.n_jobs < -1 or self.n_jobs == 0:
        self.logger.warning(f"Invalid n_jobs={self.n_jobs}. Setting n_jobs=1.")
        self.n_jobs = 1

    # A progress bar only makes sense for quiet, single-worker runs.
    show_progress_bar = not self.verbose and not self.debug and self.n_jobs == 1

    study.optimize(
        lambda trial: self._objective(trial),
        n_trials=self.n_trials,
        n_jobs=self.n_jobs,
        gc_after_trial=True,
        show_progress_bar=show_progress_bar,
    )

    # study.best_value/best_params fail with an opaque error if nothing
    # completed; report the situation explicitly instead.
    completed = study.get_trials(
        deepcopy=False, states=(optuna.trial.TrialState.COMPLETE,)
    )
    if not completed:
        msg = (
            "Hyperparameter tuning finished without any completed trials "
            "(all trials were pruned or failed); cannot select best parameters."
        )
        self.logger.error(msg)
        raise ValueError(msg)

    best_metric = study.best_value
    best_params = study.best_params

    self.best_params_ = self._set_best_params(best_params)
    self.model_params.update(self.best_params_)
    self.logger.info(f"Best {self.tune_metric} metric: {best_metric}")
    self.logger.info("Best parameters:")
    best_params_tmp = copy.deepcopy(best_params)
    # NOTE(review): this overrides any tuned "learning_rate" in the displayed
    # table with self.learning_rate — presumably _set_best_params() has
    # already updated self.learning_rate; confirm in the child classes.
    best_params_tmp["learning_rate"] = self.learning_rate

    title = f"{self.model_name} Optimized Parameters"
    pm = PrettyMetrics(best_params_tmp, precision=6, title=title)
    pm.render()

    # Save best parameters to a JSON file.
    self._save_best_params(best_params)

    tn = f"{self.tune_metric} Value"
    self.plotter_.plot_tuning(
        study, self.model_name, self.optimize_dir / "plots", target_name=tn
    )
|
|
224
|
+
|
|
225
|
+
@staticmethod
|
|
226
|
+
def initialize_weights(module: torch.nn.Module) -> None:
|
|
227
|
+
"""Initializes model weights using the Kaiming Uniform distribution.
|
|
228
|
+
|
|
229
|
+
This static method is intended to be applied to a PyTorch model to initialize the weights of its linear and convolutional layers. This initialization scheme is particularly effective for networks that use ReLU-family activation functions, as it helps maintain stable activation variances during training.
|
|
230
|
+
|
|
231
|
+
Args:
|
|
232
|
+
module (torch.nn.Module): The PyTorch module (e.g., a layer) to initialize.
|
|
233
|
+
"""
|
|
234
|
+
if isinstance(
|
|
235
|
+
module, (torch.nn.Linear, torch.nn.Conv1d, torch.nn.ConvTranspose1d)
|
|
236
|
+
):
|
|
237
|
+
# Use Kaiming Uniform initialization for Linear and Conv layers
|
|
238
|
+
torch.nn.init.kaiming_uniform_(module.weight, nonlinearity="relu")
|
|
239
|
+
if module.bias is not None:
|
|
240
|
+
torch.nn.init.zeros_(module.bias)
|
|
241
|
+
|
|
242
|
+
def build_model(
|
|
243
|
+
self,
|
|
244
|
+
Model: (
|
|
245
|
+
torch.nn.Module
|
|
246
|
+
| type["AutoencoderModel"]
|
|
247
|
+
| type["NLPCAModel"]
|
|
248
|
+
| type["UBPModel"]
|
|
249
|
+
| type["VAEModel"]
|
|
250
|
+
),
|
|
251
|
+
model_params: Dict[str, int | float | str | bool],
|
|
252
|
+
) -> torch.nn.Module:
|
|
253
|
+
"""Builds and initializes a neural network model instance.
|
|
254
|
+
|
|
255
|
+
This method instantiates a model by combining fixed, data-dependent parameters (like `n_features`) with variable hyperparameters (like `latent_dim`). The resulting model is then moved to the appropriate compute device.
|
|
256
|
+
|
|
257
|
+
Args:
|
|
258
|
+
Model (torch.nn.Module): The model class to be instantiated.
|
|
259
|
+
model_params (Dict[str, Any]): A dictionary of variable model hyperparameters, typically sampled during a hyperparameter search.
|
|
260
|
+
|
|
261
|
+
Returns:
|
|
262
|
+
torch.nn.Module: The constructed model instance, ready for training.
|
|
263
|
+
|
|
264
|
+
Raises:
|
|
265
|
+
TypeError: If `model_params` is not a dictionary.
|
|
266
|
+
AttributeError: If a required data-dependent attribute like `num_features_` has not been set, typically by calling `fit` first.
|
|
267
|
+
"""
|
|
268
|
+
if not isinstance(model_params, dict):
|
|
269
|
+
msg = f"'model_params' must be a dictionary, but got {type(model_params)}."
|
|
270
|
+
self.logger.error(msg)
|
|
271
|
+
raise TypeError(msg)
|
|
272
|
+
|
|
273
|
+
if not hasattr(self, "num_features_"):
|
|
274
|
+
msg = (
|
|
275
|
+
"Attribute 'num_features_' is not set. Call fit() before build_model()."
|
|
276
|
+
)
|
|
277
|
+
self.logger.error(msg)
|
|
278
|
+
raise AttributeError(msg)
|
|
279
|
+
|
|
280
|
+
# Start with a base set of fixed (non-tuned) parameters.
|
|
281
|
+
base_num_classes = getattr(self, "output_classes_", None)
|
|
282
|
+
if base_num_classes is None:
|
|
283
|
+
base_num_classes = self.num_classes_
|
|
284
|
+
all_params = {
|
|
285
|
+
"n_features": self.num_features_,
|
|
286
|
+
"prefix": self.prefix,
|
|
287
|
+
"num_classes": base_num_classes,
|
|
288
|
+
"verbose": self.verbose,
|
|
289
|
+
"debug": self.debug,
|
|
290
|
+
"device": self.device,
|
|
291
|
+
}
|
|
292
|
+
|
|
293
|
+
# Update with the variable hyperparameters from the provided dictionary
|
|
294
|
+
all_params.update(model_params)
|
|
295
|
+
|
|
296
|
+
return Model(**all_params).to(self.device)
|
|
297
|
+
|
|
298
|
+
def initialize_plotting_and_scorers(self) -> Tuple[Plotting, Scorer]:
    """Creates the plotting and scoring helpers configured from this imputer.

    Call this from a child class's ``fit`` method. The ``Plotting`` helper is
    configured from the instance's plot settings (format, font sizes, DPI,
    despine, show_plots), with MultiQC output enabled under a section named
    after ``self.model_name``. The ``Scorer`` uses ``self.scoring_averaging``
    as its averaging mode.

    Returns:
        Tuple[Plotting, Scorer]: The configured plotter and scorer, in that
            order.
    """
    multiqc_title = f"PG-SUI: {self.model_name} Model Imputation"

    plotting_helper = Plotting(
        model_name=self.model_name,
        prefix=self.prefix,
        plot_format=self.plot_format,
        plot_fontsize=self.plot_fontsize,
        plot_dpi=self.plot_dpi,
        title_fontsize=self.title_fontsize,
        despine=self.despine,
        show_plots=self.show_plots,
        verbose=self.verbose,
        debug=self.debug,
        multiqc=True,
        multiqc_section=multiqc_title,
    )

    metric_helper = Scorer(
        prefix=self.prefix,
        average=self.scoring_averaging,
        verbose=self.verbose,
        debug=self.debug,
    )

    return plotting_helper, metric_helper
|
|
333
|
+
|
|
334
|
+
def _objective(self, trial: optuna.Trial) -> float:
|
|
335
|
+
"""Defines the objective function for Optuna hyperparameter tuning.
|
|
336
|
+
|
|
337
|
+
This abstract method must be implemented by the child class. It should define a single hyperparameter tuning trial, which typically involves building, training, and evaluating a model with a set of sampled hyperparameters.
|
|
338
|
+
|
|
339
|
+
Args:
|
|
340
|
+
trial (optuna.Trial): The Optuna trial object, used to sample hyperparameters.
|
|
341
|
+
|
|
342
|
+
Returns:
|
|
343
|
+
float: The value of the metric to be optimized (e.g., validation accuracy, F1-score).
|
|
344
|
+
"""
|
|
345
|
+
msg = "Method `_objective()` must be implemented in the child class."
|
|
346
|
+
self.logger.error(msg)
|
|
347
|
+
raise NotImplementedError(msg)
|
|
348
|
+
|
|
349
|
+
def fit(self, X: np.ndarray | pd.DataFrame | list | None = None) -> "BaseNNImputer":
|
|
350
|
+
"""Fits the imputer model to the data.
|
|
351
|
+
|
|
352
|
+
This abstract method must be implemented by the child class. It should contain the logic for training the neural network model on the provided input data `X`.
|
|
353
|
+
|
|
354
|
+
Args:
|
|
355
|
+
X (np.ndarray | pd.DataFrame | list | None): The input data, which may contain missing values.
|
|
356
|
+
|
|
357
|
+
Returns:
|
|
358
|
+
BaseNNImputer: The fitted imputer instance.
|
|
359
|
+
"""
|
|
360
|
+
msg = "Method ``fit()`` must be implemented in the child class."
|
|
361
|
+
self.logger.error(msg)
|
|
362
|
+
raise NotImplementedError(msg)
|
|
363
|
+
|
|
364
|
+
def transform(
|
|
365
|
+
self, X: np.ndarray | pd.DataFrame | list | None = None
|
|
366
|
+
) -> np.ndarray:
|
|
367
|
+
"""Imputes missing values in the data using the trained model.
|
|
368
|
+
|
|
369
|
+
This abstract method must be implemented by the child class. It should use the fitted model to fill in missing values in the provided data `X`.
|
|
370
|
+
|
|
371
|
+
Args:
|
|
372
|
+
X (np.ndarray | pd.DataFrame | list | None): The input data with missing values.
|
|
373
|
+
|
|
374
|
+
Returns:
|
|
375
|
+
np.ndarray: The data with missing values imputed.
|
|
376
|
+
"""
|
|
377
|
+
msg = "Method ``transform()`` must be implemented in the child class."
|
|
378
|
+
self.logger.error(msg)
|
|
379
|
+
raise NotImplementedError(msg)
|
|
380
|
+
|
|
381
|
+
def _class_balanced_weights_from_mask(
|
|
382
|
+
self,
|
|
383
|
+
y: np.ndarray,
|
|
384
|
+
train_mask: np.ndarray,
|
|
385
|
+
num_classes: int,
|
|
386
|
+
beta: float = 0.9999,
|
|
387
|
+
max_ratio: float = 5.0,
|
|
388
|
+
mode: Literal["allele", "genotype10"] = "allele",
|
|
389
|
+
) -> torch.Tensor:
|
|
390
|
+
"""Class-balanced weights (Cui et al. 2019) with overflow-safe effective number.
|
|
391
|
+
|
|
392
|
+
mode="allele": y is 1D alleles in {0..3}, train_mask same shape. mode="genotype10": y is (nS,nF,2) alleles; train_mask is (nS,nF) loci where both alleles known.
|
|
393
|
+
|
|
394
|
+
Args:
|
|
395
|
+
y (np.ndarray): Ground truth labels.
|
|
396
|
+
train_mask (np.ndarray): Boolean mask of training examples (same shape as y or y without last dim for genotype10).
|
|
397
|
+
num_classes (int): Number of classes.
|
|
398
|
+
beta (float): Hyperparameter for effective number calculation. Clamped to (0,1). Default is 0.9999.
|
|
399
|
+
max_ratio (float): Maximum allowed ratio between largest and smallest non-zero weight. Default is 5.0.
|
|
400
|
+
mode (Literal["allele", "genotype10"]): Whether y contains allele labels or 10-class genotypes. Default is "allele".
|
|
401
|
+
|
|
402
|
+
Returns:
|
|
403
|
+
torch.Tensor: Class weights of shape (num_classes,). Mean weight is 1.0, zero-weight classes remain zero.
|
|
404
|
+
"""
|
|
405
|
+
if mode == "allele":
|
|
406
|
+
valid = (y >= 0) & train_mask
|
|
407
|
+
cls, cnt = np.unique(y[valid].astype(np.int64), return_counts=True)
|
|
408
|
+
counts = np.zeros(num_classes, dtype=np.float64)
|
|
409
|
+
counts[cls] = cnt
|
|
410
|
+
|
|
411
|
+
elif mode == "genotype10":
|
|
412
|
+
if y.ndim != 3 or y.shape[-1] != 2:
|
|
413
|
+
msg = "For genotype10, y must be (nS,nF,2)."
|
|
414
|
+
self.logger.error(msg)
|
|
415
|
+
raise ValueError(msg)
|
|
416
|
+
|
|
417
|
+
if train_mask.shape != y.shape[:2]:
|
|
418
|
+
msg = "train_mask must be (nS,nF) for genotype10."
|
|
419
|
+
self.logger.error(msg)
|
|
420
|
+
raise ValueError(msg)
|
|
421
|
+
|
|
422
|
+
# only loci where both alleles known and in training
|
|
423
|
+
m = train_mask & np.all(y >= 0, axis=-1)
|
|
424
|
+
if not np.any(m):
|
|
425
|
+
counts = np.zeros(num_classes, dtype=np.float64)
|
|
426
|
+
|
|
427
|
+
else:
|
|
428
|
+
a1 = y[:, :, 0][m].astype(int)
|
|
429
|
+
a2 = y[:, :, 1][m].astype(int)
|
|
430
|
+
lo, hi = np.minimum(a1, a2), np.maximum(a1, a2)
|
|
431
|
+
# map to 10-class index
|
|
432
|
+
map10 = self.pgenc.map10
|
|
433
|
+
idx10 = map10[lo, hi]
|
|
434
|
+
idx10 = idx10[(idx10 >= 0) & (idx10 < num_classes)]
|
|
435
|
+
counts = np.bincount(idx10, minlength=num_classes).astype(np.float64)
|
|
436
|
+
else:
|
|
437
|
+
msg = f"Unknown mode supplied to _class_balanced_weights_from_mask: {mode}"
|
|
438
|
+
self.logger.error(msg)
|
|
439
|
+
raise ValueError(msg)
|
|
440
|
+
|
|
441
|
+
# ---- Effective number ----
|
|
442
|
+
beta = float(beta)
|
|
443
|
+
|
|
444
|
+
# clamp beta ∈ (0,1)
|
|
445
|
+
if not np.isfinite(beta):
|
|
446
|
+
beta = 0.9999
|
|
447
|
+
|
|
448
|
+
beta = min(max(beta, 1e-8), 1.0 - 1e-8)
|
|
449
|
+
|
|
450
|
+
logb = np.log(beta) # < 0
|
|
451
|
+
t = counts * logb # ≤ 0
|
|
452
|
+
|
|
453
|
+
# 1 - beta^n = 1 - exp(n*log(beta)) = -(exp(n*log(beta)) - 1)
|
|
454
|
+
# use expm1 for accuracy near 0; for very negative t, eff≈1.0
|
|
455
|
+
eff = np.where(t > -50.0, -np.expm1(t), 1.0)
|
|
456
|
+
|
|
457
|
+
# class-balanced weights
|
|
458
|
+
w = (1.0 - beta) / (eff + 1e-12)
|
|
459
|
+
|
|
460
|
+
# Give unseen classes the largest non-zero weight (keeps it learnable)
|
|
461
|
+
if np.any(counts == 0) and np.any(counts > 0):
|
|
462
|
+
w[counts == 0] = w[counts > 0].max()
|
|
463
|
+
|
|
464
|
+
# normalize by mean of non-zero
|
|
465
|
+
nz = w > 0
|
|
466
|
+
w[nz] /= w[nz].mean() + 1e-12
|
|
467
|
+
|
|
468
|
+
# cap spread consistently with a single 'cap'
|
|
469
|
+
cap = float(max_ratio) if max_ratio is not None else 10.0
|
|
470
|
+
cap = max(cap, 5.0) # ensure we allow some differentiation
|
|
471
|
+
if np.any(nz):
|
|
472
|
+
spread = w[nz].max() / max(w[nz].min(), 1e-12)
|
|
473
|
+
if spread > cap:
|
|
474
|
+
scale = cap / spread
|
|
475
|
+
w[nz] = 1.0 + (w[nz] - 1.0) * scale
|
|
476
|
+
|
|
477
|
+
return torch.tensor(w.astype(np.float32), device=self.device)
|
|
478
|
+
|
|
479
|
+
def _select_device(self, device: Literal["gpu", "cpu", "mps"]) -> torch.device:
|
|
480
|
+
"""Selects the appropriate PyTorch device based on user preference and availability.
|
|
481
|
+
|
|
482
|
+
This method checks the user's device preference ('gpu', 'cpu', or 'mps') and verifies if the requested hardware is available. If the preferred device is not available, it falls back to CPU and logs a warning.
|
|
483
|
+
|
|
484
|
+
Args:
|
|
485
|
+
device (Literal["gpu", "cpu", "mps"]): The preferred device type for PyTorch operations.
|
|
486
|
+
|
|
487
|
+
Returns:
|
|
488
|
+
torch.device: The selected PyTorch device.
|
|
489
|
+
"""
|
|
490
|
+
dvc: str = device
|
|
491
|
+
dvc = dvc.lower().strip()
|
|
492
|
+
if dvc == "cpu":
|
|
493
|
+
self.logger.info("Using PyTorch device: CPU.")
|
|
494
|
+
return torch.device("cpu")
|
|
495
|
+
if dvc == "mps":
|
|
496
|
+
if torch.backends.mps.is_available():
|
|
497
|
+
self.logger.info("Using PyTorch device: mps.")
|
|
498
|
+
return torch.device("mps")
|
|
499
|
+
self.logger.warning("MPS unavailable; falling back to CPU.")
|
|
500
|
+
return torch.device("cpu")
|
|
501
|
+
# gpu
|
|
502
|
+
if torch.cuda.is_available():
|
|
503
|
+
self.logger.info("Using PyTorch device: cuda.")
|
|
504
|
+
return torch.device("cuda")
|
|
505
|
+
self.logger.warning("CUDA unavailable; falling back to CPU.")
|
|
506
|
+
return torch.device("cpu")
|
|
507
|
+
|
|
508
|
+
def _create_model_directories(self, prefix: str, outdirs: List[str]) -> None:
|
|
509
|
+
"""Creates the directory structure for storing model outputs.
|
|
510
|
+
|
|
511
|
+
This method sets up a standardized folder hierarchy for saving models, plots, metrics, and optimization results, organized under a main directory named after the provided prefix.
|
|
512
|
+
|
|
513
|
+
Args:
|
|
514
|
+
prefix (str): The prefix for the main output directory.
|
|
515
|
+
outdirs (List[str]): A list of subdirectory names to create within the main directory.
|
|
516
|
+
|
|
517
|
+
Raises:
|
|
518
|
+
Exception: If any of the directories cannot be created.
|
|
519
|
+
"""
|
|
520
|
+
formatted_output_dir = Path(f"{prefix}_output")
|
|
521
|
+
base_dir = formatted_output_dir / "Unsupervised"
|
|
522
|
+
|
|
523
|
+
for d in outdirs:
|
|
524
|
+
subdir = base_dir / d / self.model_name
|
|
525
|
+
setattr(self, f"{d}_dir", subdir)
|
|
526
|
+
try:
|
|
527
|
+
getattr(self, f"{d}_dir").mkdir(parents=True, exist_ok=True)
|
|
528
|
+
except Exception as e:
|
|
529
|
+
msg = f"Failed to create directory {getattr(self, f'{d}_dir')}: {e}"
|
|
530
|
+
self.logger.error(msg)
|
|
531
|
+
raise Exception(msg)
|
|
532
|
+
|
|
533
|
+
def _clear_resources(
|
|
534
|
+
self,
|
|
535
|
+
model: torch.nn.Module,
|
|
536
|
+
train_loader: torch.utils.data.DataLoader,
|
|
537
|
+
latent_vectors: torch.nn.Parameter | None = None,
|
|
538
|
+
) -> None:
|
|
539
|
+
"""Releases GPU and CPU memory after an Optuna trial.
|
|
540
|
+
|
|
541
|
+
This is a crucial step during hyperparameter tuning to prevent memory leaks between trials, ensuring that each trial runs in a clean environment.
|
|
542
|
+
|
|
543
|
+
Args:
|
|
544
|
+
model (torch.nn.Module): The model from the completed trial.
|
|
545
|
+
train_loader (torch.utils.data.DataLoader): The data loader from the trial.
|
|
546
|
+
latent_vectors (torch.nn.Parameter | None): The latent vectors from the trial.
|
|
547
|
+
"""
|
|
548
|
+
try:
|
|
549
|
+
del model, train_loader
|
|
550
|
+
|
|
551
|
+
if latent_vectors is not None:
|
|
552
|
+
del latent_vectors
|
|
553
|
+
|
|
554
|
+
except NameError:
|
|
555
|
+
pass
|
|
556
|
+
|
|
557
|
+
gc.collect()
|
|
558
|
+
if torch.cuda.is_available():
|
|
559
|
+
torch.cuda.empty_cache()
|
|
560
|
+
elif hasattr(torch, "mps") and torch.backends.mps.is_available():
|
|
561
|
+
try:
|
|
562
|
+
torch.mps.empty_cache()
|
|
563
|
+
except Exception:
|
|
564
|
+
pass
|
|
565
|
+
|
|
566
|
+
def _make_eval_visualizations(
|
|
567
|
+
self,
|
|
568
|
+
labels: List[str],
|
|
569
|
+
y_pred_proba: np.ndarray,
|
|
570
|
+
y_true: np.ndarray,
|
|
571
|
+
y_pred: np.ndarray,
|
|
572
|
+
metrics: Dict[str, float],
|
|
573
|
+
msg: str,
|
|
574
|
+
):
|
|
575
|
+
"""Generate and save evaluation visualizations.
|
|
576
|
+
|
|
577
|
+
3-class (zygosity) or 10-class (IUPAC) depending on `labels` length.
|
|
578
|
+
|
|
579
|
+
Args:
|
|
580
|
+
labels (List[str]): Class label names.
|
|
581
|
+
y_pred_proba (np.ndarray): Predicted probabilities (2D array).
|
|
582
|
+
y_true (np.ndarray): True labels (1D array).
|
|
583
|
+
y_pred (np.ndarray): Predicted labels (1D array).
|
|
584
|
+
metrics (Dict[str, float]): Computed metrics.
|
|
585
|
+
msg (str): Message to log before generating plots.
|
|
586
|
+
"""
|
|
587
|
+
self.logger.info(msg)
|
|
588
|
+
|
|
589
|
+
prefix = "zygosity" if len(labels) == 3 else "iupac"
|
|
590
|
+
n_labels = len(labels)
|
|
591
|
+
|
|
592
|
+
self.plotter_.plot_metrics(
|
|
593
|
+
y_true=y_true,
|
|
594
|
+
y_pred_proba=y_pred_proba,
|
|
595
|
+
metrics=metrics,
|
|
596
|
+
label_names=labels,
|
|
597
|
+
prefix=f"geno{n_labels}_{prefix}",
|
|
598
|
+
)
|
|
599
|
+
self.plotter_.plot_confusion_matrix(
|
|
600
|
+
y_true_1d=y_true,
|
|
601
|
+
y_pred_1d=y_pred,
|
|
602
|
+
label_names=labels,
|
|
603
|
+
prefix=f"geno{n_labels}_{prefix}",
|
|
604
|
+
)
|
|
605
|
+
|
|
606
|
+
def _make_class_reports(
    self,
    y_true: np.ndarray,
    y_pred: np.ndarray,
    metrics: Dict[str, float],
    y_pred_proba: np.ndarray | None = None,
    labels: List[str] | None = None,
) -> None:
    """Generate and save detailed classification reports and visualizations.

    Steps performed:
      1. Log a header.
      2. Plot metrics (only when ``y_pred_proba`` is given).
      3. Plot a confusion matrix.
      4. Build a per-class report with ``classification_report(...,
         output_dict=True)``.
      5. Render a table of that report without the "support" column.
      6. Write the full report to
         ``self.metrics_dir / f"{report_name}_report.json"``.
      7. Save the report figures to ``self.plots_dir``.

    ``report_name`` is "zygosity" for exactly three labels and "iupac"
    otherwise. Figures that are instances of ``Figure`` are saved with
    ``savefig``. Instances of ``go.Figure`` are written as HTML and queued via
    ``SNPioMultiQC.queue_html``.

    Args:
        y_true (np.ndarray): True labels (1D array).
        y_pred (np.ndarray): Predicted labels (1D array).
        metrics (Dict[str, float]): Computed metrics.
        y_pred_proba (np.ndarray | None): Predicted probabilities (2D array).
            Defaults to None, which skips the metric plots.
        labels (List[str] | None): Class label names. Defaults to
            ``["REF", "HET", "ALT"]`` when None.

    Raises:
        ValueError: If ``classification_report`` does not return a dict.
    """
    if labels is None:
        # Build a fresh list per call instead of sharing a mutable default.
        labels = ["REF", "HET", "ALT"]

    report_name = "zygosity" if len(labels) == 3 else "iupac"
    middle = "IUPAC" if report_name == "iupac" else "Zygosity"

    msg = f"{middle} Report (on {y_true.size} total genotypes)"
    self.logger.info(msg)

    if y_pred_proba is not None:
        self.plotter_.plot_metrics(
            y_true,
            y_pred_proba,
            metrics,
            label_names=labels,
            prefix=report_name,
        )

    self.plotter_.plot_confusion_matrix(
        y_true, y_pred, label_names=labels, prefix=report_name
    )

    report: str | dict = classification_report(
        y_true,
        y_pred,
        labels=list(range(len(labels))),
        target_names=labels,
        zero_division=0,
        output_dict=True,
    )

    if not isinstance(report, dict):
        msg = "Expected classification_report to return a dict."
        self.logger.error(msg)
        raise ValueError(msg)

    # Keep only dict rows that carry "support" (per-class and averaged rows),
    # and drop the "support" column itself. Scalar entries are skipped, and
    # so are rows left empty after dropping the column.
    report_subset = {
        k: {k2: v2 for k2, v2 in v.items() if k2 != "support"}
        for k, v in report.items()
        if isinstance(v, dict) and "support" in v
    }
    report_subset = {k: v for k, v in report_subset.items() if v}

    if report_subset:
        pm = PrettyMetrics(
            report_subset,
            precision=3,
            title=f"{self.model_name} {middle} Report",
        )
        pm.render()

    with open(self.metrics_dir / f"{report_name}_report.json", "w") as f:
        json.dump(report, f, indent=4)

    viz = ClassificationReportVisualizer(reset_kwargs=self.plotter_.param_dict)

    try:
        plots = viz.plot_all(
            report,  # type: ignore
            title_prefix=f"{self.model_name} {middle} Report",
            show=getattr(self, "show_plots", False),
            heatmap_classes_only=True,
        )

        for name, fig in plots.items():
            fout = self.plots_dir / f"{report_name}_report_{name}.{self.plot_format}"
            if isinstance(fig, Figure):
                fig.savefig(fout, dpi=300, facecolor="#111122")
                plt.close(fig)
            elif isinstance(fig, go.Figure):
                fout_html = fout.with_suffix(".html")
                fig.write_html(file=fout_html)

                SNPioMultiQC.queue_html(
                    fout_html,
                    panel_id=f"pgsui_{self.model_name.lower()}_{report_name}_radar",
                    section=f"PG-SUI: {self.model_name} Model Imputation",
                    title=f"{self.model_name} {middle} Radar Plot",
                    index_label=name,
                    description=f"{self.model_name} {middle} {len(labels)}-base Radar Plot. This radar plot visualizes model performance for three metrics per-class: precision, recall, and F1-score. Each axis represents one of these metrics, allowing for a quick visual assessment of the model's strengths and weaknesses. Higher values towards the outer edge indicate better performance.",
                )

        if not self.is_haploid:
            msg = f"Ploidy: {self.ploidy}. Evaluating per allele."
            self.logger.info(msg)
    finally:
        # Restore the matplotlib style even if plotting or saving fails.
        viz._reset_mpl_style()
|
|
712
|
+
|
|
713
|
+
def _compute_hidden_layer_sizes(
|
|
714
|
+
self,
|
|
715
|
+
n_inputs: int,
|
|
716
|
+
n_outputs: int,
|
|
717
|
+
n_samples: int,
|
|
718
|
+
n_hidden: int,
|
|
719
|
+
*,
|
|
720
|
+
alpha: float = 4.0,
|
|
721
|
+
schedule: str = "pyramid",
|
|
722
|
+
min_size: int = 16,
|
|
723
|
+
max_size: int | None = None,
|
|
724
|
+
multiple_of: int = 8,
|
|
725
|
+
decay: float | None = None,
|
|
726
|
+
cap_by_inputs: bool = True,
|
|
727
|
+
) -> list[int]:
|
|
728
|
+
"""Compute hidden layer sizes given problem scale and a layer count.
|
|
729
|
+
|
|
730
|
+
This method computes a list of hidden layer sizes based on the number of input features, output classes, training samples, and desired hidden layers. The sizes are determined using a specified schedule (pyramid, constant, or linear) and are constrained by minimum and maximum sizes, as well as rounding to multiples of a specified value.
|
|
731
|
+
|
|
732
|
+
Args:
|
|
733
|
+
n_inputs (int): Number of input features.
|
|
734
|
+
n_outputs (int): Number of output classes.
|
|
735
|
+
n_samples (int): Number of training samples.
|
|
736
|
+
n_hidden (int): Number of hidden layers.
|
|
737
|
+
alpha (float): Scaling factor for base layer size. Default is 4.0.
|
|
738
|
+
schedule (Literal["pyramid", "constant", "linear"]): Size schedule. Default is "pyramid".
|
|
739
|
+
min_size (int): Minimum layer size. Default is 16.
|
|
740
|
+
max_size (int | None): Maximum layer size. Default is None (no limit).
|
|
741
|
+
multiple_of (int): Round layer sizes to be multiples of this. Default is 8.
|
|
742
|
+
decay (float | None): Decay factor for "pyramid" schedule. If None, it is computed automatically. Default is None.
|
|
743
|
+
cap_by_inputs (bool): If True, cap layer sizes to n_inputs. Default is True.
|
|
744
|
+
|
|
745
|
+
Returns:
|
|
746
|
+
list[int]: List of hidden layer sizes.
|
|
747
|
+
|
|
748
|
+
Raises:
|
|
749
|
+
ValueError: If n_hidden < 0 or if alpha * (n_inputs + n_outputs) <= 0 or if schedule is unknown.
|
|
750
|
+
TypeError: If any argument is not of the expected type.
|
|
751
|
+
|
|
752
|
+
Notes:
|
|
753
|
+
- If n_hidden is 0, returns an empty list.
|
|
754
|
+
- The base layer size is computed as ceil(n_samples / (alpha * (n_inputs + n_outputs))).
|
|
755
|
+
- The sizes are adjusted according to the specified schedule and constraints.
|
|
756
|
+
"""
|
|
757
|
+
if n_hidden < 0:
|
|
758
|
+
msg = f"n_hidden must be >= 0, got {n_hidden}."
|
|
759
|
+
self.logger.error(msg)
|
|
760
|
+
raise ValueError(msg)
|
|
761
|
+
|
|
762
|
+
if schedule not in {"pyramid", "constant", "linear"}:
|
|
763
|
+
msg = f"Unknown schedule '{schedule}'. Use 'pyramid', 'constant', or 'linear'."
|
|
764
|
+
self.logger.error(msg)
|
|
765
|
+
raise ValueError(msg)
|
|
766
|
+
|
|
767
|
+
if n_hidden == 0:
|
|
768
|
+
return []
|
|
769
|
+
|
|
770
|
+
denom = float(alpha) * float(n_inputs + n_outputs)
|
|
771
|
+
|
|
772
|
+
if denom <= 0:
|
|
773
|
+
msg = f"alpha * (n_inputs + n_outputs) must be > 0, got {denom}."
|
|
774
|
+
self.logger.error(msg)
|
|
775
|
+
raise ValueError(msg)
|
|
776
|
+
|
|
777
|
+
base = int(np.ceil(float(n_samples) / denom))
|
|
778
|
+
|
|
779
|
+
if max_size is None:
|
|
780
|
+
max_size = max(n_inputs, base)
|
|
781
|
+
|
|
782
|
+
base = int(np.clip(base, min_size, max_size))
|
|
783
|
+
|
|
784
|
+
if schedule == "constant":
|
|
785
|
+
sizes = np.full(shape=(n_hidden,), fill_value=base, dtype=float)
|
|
786
|
+
|
|
787
|
+
elif schedule == "linear":
|
|
788
|
+
target = max(min_size, min(base, base // 4))
|
|
789
|
+
sizes = (
|
|
790
|
+
np.array([base], dtype=float)
|
|
791
|
+
if n_hidden == 1
|
|
792
|
+
else np.linspace(base, target, num=n_hidden, dtype=float)
|
|
793
|
+
)
|
|
794
|
+
|
|
795
|
+
elif schedule == "pyramid":
|
|
796
|
+
if n_hidden == 1:
|
|
797
|
+
sizes = np.array([base], dtype=float)
|
|
798
|
+
else:
|
|
799
|
+
if decay is None:
|
|
800
|
+
target = max(min_size, base // 4)
|
|
801
|
+
if base <= 0 or target <= 0:
|
|
802
|
+
dcy = 1.0
|
|
803
|
+
else:
|
|
804
|
+
dcy = (target / float(base)) ** (1.0 / (n_hidden - 1))
|
|
805
|
+
dcy = float(np.clip(dcy, 0.25, 0.99))
|
|
806
|
+
exponents = np.arange(n_hidden, dtype=float)
|
|
807
|
+
sizes = base * (dcy**exponents)
|
|
808
|
+
|
|
809
|
+
else:
|
|
810
|
+
msg = f"Unknown schedule '{schedule}'. Use 'pyramid', 'constant', or 'linear'."
|
|
811
|
+
self.logger.error(msg)
|
|
812
|
+
raise ValueError(msg)
|
|
813
|
+
|
|
814
|
+
sizes = np.clip(sizes, min_size, max_size)
|
|
815
|
+
|
|
816
|
+
if cap_by_inputs:
|
|
817
|
+
sizes = np.minimum(sizes, float(n_inputs))
|
|
818
|
+
|
|
819
|
+
sizes = (np.ceil(sizes / multiple_of) * multiple_of).astype(int)
|
|
820
|
+
sizes = np.minimum.accumulate(sizes)
|
|
821
|
+
return np.clip(sizes, min_size, max_size).astype(int).tolist()
|
|
822
|
+
|
|
823
|
+
def _class_weights_from_zygosity(self, X: np.ndarray) -> torch.Tensor:
|
|
824
|
+
"""Class-balanced weights for 0/1/2 (handles haploid collapse if needed).
|
|
825
|
+
|
|
826
|
+
This method computes class-balanced weights for the genotype classes (0/1/2) based on the provided genotype matrix. It handles cases where the data is haploid by collapsing the ALT class to 1, effectively treating the problem as binary classification (REF vs ALT). The weights are calculated using a class-balanced weighting scheme that considers the frequency of each class in the training data, with parameters for beta and maximum ratio to control the weighting behavior. The resulting weights are returned as a PyTorch tensor on the current device.
|
|
827
|
+
|
|
828
|
+
Args:
|
|
829
|
+
X (np.ndarray): 0/1/2 with -1 for missing.
|
|
830
|
+
|
|
831
|
+
Returns:
|
|
832
|
+
torch.Tensor: Weights on current device.
|
|
833
|
+
"""
|
|
834
|
+
y = X[X != -1].ravel().astype(np.int64)
|
|
835
|
+
if y.size == 0:
|
|
836
|
+
return torch.ones(
|
|
837
|
+
self.num_classes_, dtype=torch.float32, device=self.device
|
|
838
|
+
)
|
|
839
|
+
|
|
840
|
+
return self._class_balanced_weights_from_mask(
|
|
841
|
+
y=y,
|
|
842
|
+
train_mask=np.ones_like(y, dtype=bool),
|
|
843
|
+
num_classes=self.num_classes_,
|
|
844
|
+
beta=self.beta,
|
|
845
|
+
max_ratio=self.max_ratio,
|
|
846
|
+
mode="allele", # 1D int vector
|
|
847
|
+
).to(self.device)
|
|
848
|
+
|
|
849
|
+
@staticmethod
|
|
850
|
+
def _normalize_class_weights(
|
|
851
|
+
weights: torch.Tensor | None,
|
|
852
|
+
) -> torch.Tensor | None:
|
|
853
|
+
"""Normalize class weights once to keep loss scale stable.
|
|
854
|
+
|
|
855
|
+
Args:
|
|
856
|
+
weights (torch.Tensor | None): Class weights to normalize.
|
|
857
|
+
|
|
858
|
+
Returns:
|
|
859
|
+
torch.Tensor | None: Normalized class weights or None if input is None.
|
|
860
|
+
"""
|
|
861
|
+
if weights is None:
|
|
862
|
+
return None
|
|
863
|
+
return weights / weights.mean().clamp_min(1e-8)
|
|
864
|
+
|
|
865
|
+
def _get_float_genotypes(self, *, copy: bool = True) -> np.ndarray:
|
|
866
|
+
"""Float32 0/1/2 matrix with NaNs for missing, cached per dataset.
|
|
867
|
+
|
|
868
|
+
Args:
|
|
869
|
+
copy (bool): If True, return a copy of the cached array. Default is True.
|
|
870
|
+
|
|
871
|
+
Returns:
|
|
872
|
+
np.ndarray: Float32 genotype matrix with NaNs for missing values.
|
|
873
|
+
"""
|
|
874
|
+
cache = self._float_genotype_cache
|
|
875
|
+
current = self.pgenc.genotypes_012
|
|
876
|
+
if cache is None or cache.shape != current.shape or cache.dtype != np.float32:
|
|
877
|
+
arr = np.asarray(current, dtype=np.float32)
|
|
878
|
+
arr = np.where(arr < 0, np.nan, arr)
|
|
879
|
+
self._float_genotype_cache = arr
|
|
880
|
+
cache = arr
|
|
881
|
+
return cache.copy() if copy else cache
|
|
882
|
+
|
|
883
|
+
def _sim_mask_cache_key(self) -> tuple | None:
|
|
884
|
+
"""Key for caching simulated-missing masks."""
|
|
885
|
+
if not getattr(self, "simulate_missing", False):
|
|
886
|
+
return None
|
|
887
|
+
shape = tuple(self.pgenc.genotypes_012.shape)
|
|
888
|
+
return (
|
|
889
|
+
id(self.genotype_data),
|
|
890
|
+
self.sim_strategy,
|
|
891
|
+
round(float(self.sim_prop), 6),
|
|
892
|
+
self.seed,
|
|
893
|
+
shape,
|
|
894
|
+
)
|
|
895
|
+
|
|
896
|
+
def _one_hot_encode_012(self, X: np.ndarray | torch.Tensor) -> torch.Tensor:
|
|
897
|
+
"""One-hot 0/1/2; -1 rows are all-zeros (B, L, K).
|
|
898
|
+
|
|
899
|
+
This method performs one-hot encoding of the input genotype data (0, 1, 2) while handling missing values represented by -1. The output is a tensor of shape (B, L, K), where B is the batch size, L is the number of features, and K is the number of classes.
|
|
900
|
+
|
|
901
|
+
Args:
|
|
902
|
+
X (np.ndarray | torch.Tensor): The input data to be one-hot encoded, either as a NumPy array or a PyTorch tensor.
|
|
903
|
+
|
|
904
|
+
Returns:
|
|
905
|
+
torch.Tensor: A one-hot encoded tensor of shape (B, L, K), where B is the batch size, L is the number of features, and K is the number of classes.
|
|
906
|
+
"""
|
|
907
|
+
Xt = (
|
|
908
|
+
torch.from_numpy(X).to(self.device)
|
|
909
|
+
if isinstance(X, np.ndarray)
|
|
910
|
+
else X.to(self.device)
|
|
911
|
+
)
|
|
912
|
+
|
|
913
|
+
# B=batch, L=features, K=classes
|
|
914
|
+
B, L = Xt.shape
|
|
915
|
+
K = self.num_classes_
|
|
916
|
+
X_ohe = torch.zeros(B, L, K, dtype=torch.float32, device=self.device)
|
|
917
|
+
valid = Xt != -1
|
|
918
|
+
idx = Xt[valid].long()
|
|
919
|
+
|
|
920
|
+
if idx.numel() > 0:
|
|
921
|
+
X_ohe[valid] = F.one_hot(idx, num_classes=K).float()
|
|
922
|
+
|
|
923
|
+
return X_ohe
|
|
924
|
+
|
|
925
|
+
def _eval_for_pruning(
|
|
926
|
+
self,
|
|
927
|
+
*,
|
|
928
|
+
model: torch.nn.Module,
|
|
929
|
+
X_val: np.ndarray,
|
|
930
|
+
params: dict,
|
|
931
|
+
metric: str,
|
|
932
|
+
objective_mode: bool = True,
|
|
933
|
+
do_latent_infer: bool = False,
|
|
934
|
+
latent_steps: int = 50,
|
|
935
|
+
latent_lr: float = 1e-2,
|
|
936
|
+
latent_weight_decay: float = 0.0,
|
|
937
|
+
latent_seed: int = 123,
|
|
938
|
+
_latent_cache: dict | None = None,
|
|
939
|
+
_latent_cache_key: str | None = None,
|
|
940
|
+
eval_mask_override: np.ndarray | None = None,
|
|
941
|
+
) -> float:
|
|
942
|
+
"""Compute a scalar metric (to MAXIMIZE) on a fixed validation set.
|
|
943
|
+
|
|
944
|
+
This method evaluates the model on a validation dataset and computes a specified metric, which is used for pruning decisions during hyperparameter tuning. It supports optional latent inference to optimize latent representations before evaluation. The method handles potential issues with non-finite metric values by returning negative infinity, making it easier to prune poorly performing trials.
|
|
945
|
+
|
|
946
|
+
Args:
|
|
947
|
+
model (torch.nn.Module): The model to evaluate.
|
|
948
|
+
X_val (np.ndarray): Validation data.
|
|
949
|
+
params (dict): Model parameters.
|
|
950
|
+
metric (str): Metric name to return.
|
|
951
|
+
objective_mode (bool): If True, use objective-mode evaluation. Default is True.
|
|
952
|
+
do_latent_infer (bool): If True, perform latent inference before evaluation. Default
|
|
953
|
+
latent_steps (int): Number of steps for latent inference. Default is 50.
|
|
954
|
+
latent_lr (float): Learning rate for latent inference. Default is 1e-2
|
|
955
|
+
latent_weight_decay (float): Weight decay for latent inference. Default is 0.0.
|
|
956
|
+
latent_seed (int): Random seed for latent inference. Default is 123.
|
|
957
|
+
_latent_cache (dict | None): Optional cache for storing/retrieving optimized latents
|
|
958
|
+
_latent_cache_key (str | None): Key for storing/retrieving in _latent_cache.
|
|
959
|
+
eval_mask_override (np.ndarray | None): Optional mask to override default evaluation mask.
|
|
960
|
+
|
|
961
|
+
Returns:
|
|
962
|
+
float: The computed metric value to maximize. Returns -inf on failure.
|
|
963
|
+
"""
|
|
964
|
+
optimized_val_latents = None
|
|
965
|
+
|
|
966
|
+
# Optional latent inference path for models that need it.
|
|
967
|
+
if do_latent_infer and hasattr(self, "_latent_infer_for_eval"):
|
|
968
|
+
optimized_val_latents = self._latent_infer_for_eval( # type: ignore
|
|
969
|
+
model=model,
|
|
970
|
+
X_val=X_val,
|
|
971
|
+
steps=latent_steps,
|
|
972
|
+
lr=latent_lr,
|
|
973
|
+
weight_decay=latent_weight_decay,
|
|
974
|
+
seed=latent_seed,
|
|
975
|
+
cache=_latent_cache,
|
|
976
|
+
cache_key=_latent_cache_key,
|
|
977
|
+
)
|
|
978
|
+
# Retrieve the optimized latents from the cache
|
|
979
|
+
if _latent_cache is not None and _latent_cache_key in _latent_cache:
|
|
980
|
+
optimized_val_latents = _latent_cache[_latent_cache_key]
|
|
981
|
+
|
|
982
|
+
if getattr(self, "_tune_eval_slice", None) is not None:
|
|
983
|
+
X_val = X_val[self._tune_eval_slice]
|
|
984
|
+
if eval_mask_override is not None:
|
|
985
|
+
eval_mask_override = eval_mask_override[self._tune_eval_slice]
|
|
986
|
+
|
|
987
|
+
# Child's evaluator now accepts the pre-computed latents
|
|
988
|
+
metrics = self._evaluate_model( # type: ignore
|
|
989
|
+
X_val=X_val,
|
|
990
|
+
model=model,
|
|
991
|
+
params=params,
|
|
992
|
+
objective_mode=objective_mode,
|
|
993
|
+
latent_vectors_val=optimized_val_latents,
|
|
994
|
+
eval_mask_override=eval_mask_override,
|
|
995
|
+
)
|
|
996
|
+
|
|
997
|
+
# Prefer the requested metric; fall back to self.tune_metric if needed.
|
|
998
|
+
val = metrics.get(metric, metrics.get(getattr(self, "tune_metric", ""), None))
|
|
999
|
+
|
|
1000
|
+
if val is None or not np.isfinite(val):
|
|
1001
|
+
return -np.inf # make pruning decisions easy/robust on bad reads
|
|
1002
|
+
|
|
1003
|
+
return float(val)
|
|
1004
|
+
|
|
1005
|
+
def _first_linear_in_features(self, model: torch.nn.Module) -> int:
|
|
1006
|
+
"""Return in_features of the model's first Linear layer.
|
|
1007
|
+
|
|
1008
|
+
This method iterates through the modules of the provided PyTorch model to find the first instance of a Linear layer. It then retrieves and returns the `in_features` attribute of that layer, which indicates the number of input features expected by the layer.
|
|
1009
|
+
|
|
1010
|
+
Args:
|
|
1011
|
+
model (torch.nn.Module): The model to inspect.
|
|
1012
|
+
|
|
1013
|
+
Returns:
|
|
1014
|
+
int: The in_features of the first Linear layer.
|
|
1015
|
+
"""
|
|
1016
|
+
for m in model.modules():
|
|
1017
|
+
if isinstance(m, torch.nn.Linear):
|
|
1018
|
+
return int(m.in_features)
|
|
1019
|
+
raise RuntimeError("No Linear layers found in model.")
|
|
1020
|
+
|
|
1021
|
+
def _assert_model_latent_compat(
|
|
1022
|
+
self, model: torch.nn.Module, latent_vectors: torch.nn.Parameter
|
|
1023
|
+
) -> None:
|
|
1024
|
+
"""Raise if model's first Linear doesn't match latent_vectors width.
|
|
1025
|
+
|
|
1026
|
+
This method checks that the dimensionality of the provided latent vectors matches the expected input feature size of the model's first linear layer. If there is a mismatch, it raises a ValueError with a descriptive message.
|
|
1027
|
+
|
|
1028
|
+
Args:
|
|
1029
|
+
model (torch.nn.Module): The model to check.
|
|
1030
|
+
latent_vectors (torch.nn.Parameter): The latent vectors to check.
|
|
1031
|
+
|
|
1032
|
+
Raises:
|
|
1033
|
+
ValueError: If the latent dimension does not match the model's expected input features.
|
|
1034
|
+
"""
|
|
1035
|
+
zdim = int(latent_vectors.shape[1])
|
|
1036
|
+
first_in = self._first_linear_in_features(model)
|
|
1037
|
+
if first_in != zdim:
|
|
1038
|
+
raise ValueError(
|
|
1039
|
+
f"Latent mismatch: zdim={zdim}, model first Linear expects in_features={first_in}"
|
|
1040
|
+
)
|
|
1041
|
+
|
|
1042
|
+
def _prepare_tuning_artifacts(self) -> None:
    """Prepare data and artifacts needed for hyperparameter tuning.

    Runs once per instance; it is a no-op when ``self._tune_ready`` is set.

    Steps performed:
      1. When ``self.tune_fast`` is set, randomly subsample rows (at most
         ``tune_max_samples``) and columns (at most ``tune_max_loci``, where 0
         means all) of ``self.ground_truth_``.
      2. Split the rows into tuning train/validation sets using
         ``validation_split`` and ``seed``.
      3. Compute normalized class weights from the training rows.
      4. Build the tuning data loader with ``tune_batch_size``.
      5. Optionally restrict proxy-metric evaluation to the first
         ``tune_proxy_metric_batch`` validation rows.

    Raises:
        AttributeError: If the ground truth data (``ground_truth_``) is not set.
    """
    if getattr(self, "_tune_ready", False):
        return

    X = self.ground_truth_
    n_samp, n_loci = X.shape
    rng = self.rng

    if self.tune_fast:
        n_keep_samp = min(n_samp, self.tune_max_samples)
        n_keep_loci = (
            n_loci if self.tune_max_loci == 0 else min(n_loci, self.tune_max_loci)
        )

        samp_idx = np.sort(rng.choice(n_samp, size=n_keep_samp, replace=False))
        loci_idx = np.sort(rng.choice(n_loci, size=n_keep_loci, replace=False))
        X_small = X[samp_idx][:, loci_idx]
    else:
        X_small = X

    idx = np.arange(X_small.shape[0])
    tr, te = train_test_split(
        idx, test_size=self.validation_split, random_state=self.seed
    )
    self._tune_train_idx = tr
    self._tune_test_idx = te
    self._tune_X_train = X_small[tr]
    self._tune_X_test = X_small[te]

    self._tune_class_weights = self._normalize_class_weights(
        self._class_weights_from_zygosity(self._tune_X_train)
    )

    # Temporarily bump batch size only for the tuning loader. Restore it even
    # if loader construction fails, so the instance is not left modified.
    orig_bs = self.batch_size
    self.batch_size = self.tune_batch_size
    try:
        self._tune_loader = self._get_data_loaders(self._tune_X_train)  # type: ignore
    finally:
        self.batch_size = orig_bs

    self._tune_num_features = self._tune_X_train.shape[1]
    self._tune_val_latents_source = None
    self._tune_train_latents_source = None

    # Optional: for huge val sets, thin them for proxy metric
    if (
        self.tune_proxy_metric_batch
        and self._tune_X_test.shape[0] > self.tune_proxy_metric_batch
    ):
        self._tune_eval_slice = np.arange(self.tune_proxy_metric_batch)
    else:
        self._tune_eval_slice = None

    self._tune_ready = True
|
|
1100
|
+
|
|
1101
|
+
def _save_best_params(self, best_params: Dict[str, Any]) -> None:
|
|
1102
|
+
"""Save the best hyperparameters to a JSON file.
|
|
1103
|
+
|
|
1104
|
+
This method saves the best hyperparameters found during hyperparameter tuning to a JSON file in the optimization directory. The filename includes the model name for easy identification.
|
|
1105
|
+
|
|
1106
|
+
Args:
|
|
1107
|
+
best_params (Dict[str, Any]): A dictionary of the best hyperparameters to save.
|
|
1108
|
+
"""
|
|
1109
|
+
if not hasattr(self, "parameters_dir"):
|
|
1110
|
+
msg = "Attribute 'parameters_dir' not found. Ensure _create_model_directories() has been called."
|
|
1111
|
+
self.logger.error(msg)
|
|
1112
|
+
raise AttributeError(msg)
|
|
1113
|
+
|
|
1114
|
+
fout = self.parameters_dir / "best_parameters.json"
|
|
1115
|
+
|
|
1116
|
+
with open(fout, "w") as f:
|
|
1117
|
+
json.dump(best_params, f, indent=4)
|
|
1118
|
+
|
|
1119
|
+
def _set_best_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
|
|
1120
|
+
"""An abstract method for setting best parameters."""
|
|
1121
|
+
raise NotImplementedError
|