pg-sui 0.2.0__py3-none-any.whl → 1.6.14.dev9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pg_sui-0.2.0.dist-info → pg_sui-1.6.14.dev9.dist-info}/METADATA +101 -79
- pg_sui-1.6.14.dev9.dist-info/RECORD +81 -0
- {pg_sui-0.2.0.dist-info → pg_sui-1.6.14.dev9.dist-info}/WHEEL +1 -1
- pg_sui-1.6.14.dev9.dist-info/entry_points.txt +4 -0
- {pg_sui-0.2.0.dist-info → pg_sui-1.6.14.dev9.dist-info/licenses}/LICENSE +0 -0
- pg_sui-1.6.14.dev9.dist-info/top_level.txt +1 -0
- pgsui/__init__.py +35 -54
- pgsui/_version.py +34 -0
- pgsui/cli.py +909 -0
- pgsui/data_processing/__init__.py +0 -0
- pgsui/data_processing/config.py +565 -0
- pgsui/data_processing/containers.py +1424 -0
- pgsui/data_processing/transformers.py +557 -907
- pgsui/{example_data/trees → electron/app}/__init__.py +0 -0
- pgsui/electron/app/__main__.py +5 -0
- pgsui/electron/app/extra-resources/.gitkeep +1 -0
- pgsui/electron/app/icons/icons/1024x1024.png +0 -0
- pgsui/electron/app/icons/icons/128x128.png +0 -0
- pgsui/electron/app/icons/icons/16x16.png +0 -0
- pgsui/electron/app/icons/icons/24x24.png +0 -0
- pgsui/electron/app/icons/icons/256x256.png +0 -0
- pgsui/electron/app/icons/icons/32x32.png +0 -0
- pgsui/electron/app/icons/icons/48x48.png +0 -0
- pgsui/electron/app/icons/icons/512x512.png +0 -0
- pgsui/electron/app/icons/icons/64x64.png +0 -0
- pgsui/electron/app/icons/icons/icon.icns +0 -0
- pgsui/electron/app/icons/icons/icon.ico +0 -0
- pgsui/electron/app/main.js +227 -0
- pgsui/electron/app/package-lock.json +6894 -0
- pgsui/electron/app/package.json +51 -0
- pgsui/electron/app/preload.js +15 -0
- pgsui/electron/app/server.py +157 -0
- pgsui/electron/app/ui/logo.png +0 -0
- pgsui/electron/app/ui/renderer.js +131 -0
- pgsui/electron/app/ui/styles.css +59 -0
- pgsui/electron/app/ui/ui_shim.js +72 -0
- pgsui/electron/bootstrap.py +43 -0
- pgsui/electron/launch.py +57 -0
- pgsui/electron/package.json +14 -0
- pgsui/example_data/__init__.py +0 -0
- pgsui/example_data/phylip_files/__init__.py +0 -0
- pgsui/example_data/phylip_files/test.phy +0 -0
- pgsui/example_data/popmaps/__init__.py +0 -0
- pgsui/example_data/popmaps/{test.popmap → phylogen_nomx.popmap} +185 -99
- pgsui/example_data/structure_files/__init__.py +0 -0
- pgsui/example_data/structure_files/test.pops.2row.allsites.str +0 -0
- pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz +0 -0
- pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz.tbi +0 -0
- pgsui/impute/__init__.py +0 -0
- pgsui/impute/deterministic/imputers/allele_freq.py +725 -0
- pgsui/impute/deterministic/imputers/mode.py +844 -0
- pgsui/impute/deterministic/imputers/nmf.py +221 -0
- pgsui/impute/deterministic/imputers/phylo.py +973 -0
- pgsui/impute/deterministic/imputers/ref_allele.py +669 -0
- pgsui/impute/supervised/__init__.py +0 -0
- pgsui/impute/supervised/base.py +343 -0
- pgsui/impute/{unsupervised/models/in_development → supervised/imputers}/__init__.py +0 -0
- pgsui/impute/supervised/imputers/hist_gradient_boosting.py +317 -0
- pgsui/impute/supervised/imputers/random_forest.py +291 -0
- pgsui/impute/unsupervised/__init__.py +0 -0
- pgsui/impute/unsupervised/base.py +1118 -0
- pgsui/impute/unsupervised/callbacks.py +92 -262
- {simulation → pgsui/impute/unsupervised/imputers}/__init__.py +0 -0
- pgsui/impute/unsupervised/imputers/autoencoder.py +1285 -0
- pgsui/impute/unsupervised/imputers/nlpca.py +1554 -0
- pgsui/impute/unsupervised/imputers/ubp.py +1575 -0
- pgsui/impute/unsupervised/imputers/vae.py +1228 -0
- pgsui/impute/unsupervised/loss_functions.py +261 -0
- pgsui/impute/unsupervised/models/__init__.py +0 -0
- pgsui/impute/unsupervised/models/autoencoder_model.py +215 -567
- pgsui/impute/unsupervised/models/nlpca_model.py +155 -394
- pgsui/impute/unsupervised/models/ubp_model.py +180 -1106
- pgsui/impute/unsupervised/models/vae_model.py +269 -630
- pgsui/impute/unsupervised/nn_scorers.py +255 -0
- pgsui/utils/__init__.py +0 -0
- pgsui/utils/classification_viz.py +608 -0
- pgsui/utils/logging_utils.py +22 -0
- pgsui/utils/misc.py +35 -480
- pgsui/utils/plotting.py +996 -829
- pgsui/utils/pretty_metrics.py +290 -0
- pgsui/utils/scorers.py +213 -666
- pg_sui-0.2.0.dist-info/RECORD +0 -75
- pg_sui-0.2.0.dist-info/top_level.txt +0 -3
- pgsui/example_data/phylip_files/test_n10.phy +0 -118
- pgsui/example_data/phylip_files/test_n100.phy +0 -118
- pgsui/example_data/phylip_files/test_n2.phy +0 -118
- pgsui/example_data/phylip_files/test_n500.phy +0 -118
- pgsui/example_data/structure_files/test.nopops.1row.10sites.str +0 -117
- pgsui/example_data/structure_files/test.nopops.2row.100sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.10sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.30sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.allsites.str +0 -234
- pgsui/example_data/structure_files/test.pops.1row.10sites.str +0 -117
- pgsui/example_data/structure_files/test.pops.2row.10sites.str +0 -234
- pgsui/example_data/trees/test.iqtree +0 -376
- pgsui/example_data/trees/test.qmat +0 -5
- pgsui/example_data/trees/test.rate +0 -2033
- pgsui/example_data/trees/test.tre +0 -1
- pgsui/example_data/trees/test_n10.rate +0 -19
- pgsui/example_data/trees/test_n100.rate +0 -109
- pgsui/example_data/trees/test_n500.rate +0 -509
- pgsui/example_data/trees/test_siterates.txt +0 -2024
- pgsui/example_data/trees/test_siterates_n10.txt +0 -10
- pgsui/example_data/trees/test_siterates_n100.txt +0 -100
- pgsui/example_data/trees/test_siterates_n500.txt +0 -500
- pgsui/example_data/vcf_files/test.vcf +0 -244
- pgsui/example_data/vcf_files/test.vcf.gz +0 -0
- pgsui/example_data/vcf_files/test.vcf.gz.tbi +0 -0
- pgsui/impute/estimators.py +0 -1268
- pgsui/impute/impute.py +0 -1463
- pgsui/impute/simple_imputers.py +0 -1431
- pgsui/impute/supervised/iterative_imputer_fixedparams.py +0 -782
- pgsui/impute/supervised/iterative_imputer_gridsearch.py +0 -1024
- pgsui/impute/unsupervised/keras_classifiers.py +0 -697
- pgsui/impute/unsupervised/models/in_development/cnn_model.py +0 -486
- pgsui/impute/unsupervised/neural_network_imputers.py +0 -1440
- pgsui/impute/unsupervised/neural_network_methods.py +0 -1395
- pgsui/pg_sui.py +0 -261
- pgsui/utils/sequence_tools.py +0 -407
- simulation/sim_benchmarks.py +0 -333
- simulation/sim_treeparams.py +0 -475
- test/__init__.py +0 -0
- test/pg_sui_simtest.py +0 -215
- test/pg_sui_testing.py +0 -523
- test/test.py +0 -151
- test/test_pgsui.py +0 -374
- test/test_tkc.py +0 -185
pgsui/impute/supervised/base.py (new file)
@@ -0,0 +1,343 @@
```python
import json
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.figure import Figure
from plotly.graph_objs._figure import Figure as PlotlyFigure
from sklearn.exceptions import NotFittedError
from sklearn.experimental import enable_iterative_imputer  # noqa
from sklearn.metrics import (
    accuracy_score,
    classification_report,
    f1_score,
    precision_score,
    recall_score,
)
from snpio.utils.logging import LoggerManager

from pgsui.utils.classification_viz import ClassificationReportVisualizer
from pgsui.utils.logging_utils import configure_logger


class BaseImputer:
    """A base class for supervised, iterative imputer models.

    This class provides a common framework and shared functionality for imputers that use scikit-learn's `IterativeImputer`. It is not intended for direct instantiation. Child classes should inherit from this class and provide a specific estimator model (e.g., RandomForest, GradientBoosting).

    Notes:
        - A hyperparameter tuning workflow using Optuna.
        - Standardized data splitting, model training, and evaluation methods.
        - Utilities for creating output directories and handling model state.
    """

    def __init__(self, verbose: bool = False, debug: bool = False) -> None:
        """Initializes the BaseImputer class.

        This class sets up logging and verbosity/debug settings. It also contains methods that all supervised imputers will share.

        Note:
            Inheriting child classes must define `self.prefix` before calling `super().__init__()`, as it is required for logger initialization.

        Args:
            verbose (bool): If True, enables detailed logging output. Defaults to False.
            debug (bool): If True, enables debug mode. Defaults to False.
        """
        self.verbose = verbose
        self.debug = debug

        self.prefix: str  # Must be set by child class before super().__init__()
        self.metrics_dir: Path  # Must be set by child class after super().__init__()
        self.plots_dir: Path  # Must be set by child class after super().__init__()
        self.parameters_dir: Path  # Must be set by child class after super().__init__()
        self.model_name: str  # Must be set by child class after super().__init__()
        self.plotter_: Any  # Must be set by child class after super().__init__()
        self.plot_format: str  # Must be set by child class after super().__init__()
        self.is_haploid_: bool  # Must be set by child class after super().__init__()
        self.is_fit_: bool  # Must be set by child class after super().__init__()

        logman = LoggerManager(
            __name__, prefix=self.prefix, verbose=self.verbose, debug=self.debug
        )
        self.logger = configure_logger(
            logman.get_logger(), verbose=self.verbose, debug=self.debug
        )

    def _create_model_directories(self, prefix: str, outdirs: List[str]) -> None:
        """Creates the output directory structure for the imputer.

        This method sets up a standardized folder hierarchy for saving models, plots, metrics, and optimization results, organized by the model's name.

        Args:
            prefix (str): The prefix for the main output directory.
            outdirs (List[str]): A list of subdirectories to create (e.g., 'models', 'plots').
        """
        base_dir = Path(f"{prefix}_output") / "Supervised"
        for d in outdirs:
            subdir = base_dir / d / self.model_name
            setattr(self, f"{d}_dir", subdir)
            subdir.mkdir(parents=True, exist_ok=True)

    def _make_class_reports(
        self,
        y_true: np.ndarray,
        y_pred: np.ndarray,
        metrics: Dict[str, float],
        y_pred_proba: np.ndarray | None = None,
        labels: List[str] = ["REF", "HET", "ALT"],
    ) -> None:
        """Generate and save detailed classification reports and visualizations.

        Produces a 3-class (zygosity) or 10-class (IUPAC) report depending on the length of `labels`.

        Args:
            y_true (np.ndarray): True labels (1D array).
            y_pred (np.ndarray): Predicted labels (1D array).
            metrics (Dict[str, float]): Computed metrics.
            y_pred_proba (np.ndarray | None): Predicted probabilities (2D array). Defaults to None.
            labels (List[str], optional): Class label names
                (default: ["REF", "HET", "ALT"] for 3-class).
        """
        report_name = "zygosity" if len(labels) == 3 else "iupac"
        middle = "IUPAC" if report_name == "iupac" else "Zygosity"

        msg = f"{middle} Report (on {y_true.size} total genotypes)"
        self.logger.info(msg)

        if y_pred_proba is not None:
            self.plotter_.plot_metrics(
                y_true,
                y_pred_proba,
                metrics,
                label_names=labels,
                prefix=report_name,
            )

        self.plotter_.plot_confusion_matrix(
            y_true, y_pred, label_names=labels, prefix=report_name
        )

        report = classification_report(
            y_true,
            y_pred,
            labels=list(range(len(labels))),
            target_names=labels,
            zero_division=0,
            output_dict=True,
        )

        with open(self.metrics_dir / f"{report_name}_report.json", "w") as f:
            json.dump(report, f, indent=4)

        viz = ClassificationReportVisualizer(reset_kwargs=self.plotter_.param_dict)

        plots = viz.plot_all(
            report,  # type: ignore
            title_prefix=f"{self.model_name} {middle} Report",
            show=getattr(self, "show_plots", False),
            heatmap_classes_only=True,
        )

        for name, fig in plots.items():
            fout = self.plots_dir / f"{report_name}_report_{name}.{self.plot_format}"
            if hasattr(fig, "savefig") and isinstance(fig, Figure):
                fig.savefig(fout, dpi=300, facecolor="#111122")
                plt.close(fig)
            elif hasattr(fig, "write_html") and isinstance(fig, PlotlyFigure):
                fig.write_html(file=fout.with_suffix(".html"))

        viz._reset_mpl_style()

    def _evaluate_012_and_plot(self, y_true: np.ndarray, y_pred: np.ndarray) -> None:
        """0/1/2 zygosity report & confusion matrix.

        This method generates a classification report and confusion matrix for genotypes encoded as 0, 1, or 2. If the data is haploid, it treats genotypes 1 and 2 as equivalent (presence of the alternate allele).

        Args:
            y_true (np.ndarray): True genotypes (0/1/2) for masked entries.
            y_pred (np.ndarray): Predicted genotypes (0/1/2) for masked entries.

        Raises:
            NotFittedError: If fit() and transform() have not been called.
        """
        labels = [0, 1, 2]
        # Haploid parity: fold ALT (2) into ALT/Present (1)
        if self.is_haploid_:
            y_true[y_true == 2] = 1
            y_pred[y_pred == 2] = 1
            labels = [0, 1]

        metrics = {
            "n_masked_test": int(y_true.size),
            "accuracy": accuracy_score(y_true, y_pred),
            "f1": f1_score(
                y_true, y_pred, average="macro", labels=labels, zero_division=0
            ),
            "precision": precision_score(
                y_true, y_pred, average="macro", labels=labels, zero_division=0
            ),
            "recall": recall_score(
                y_true, y_pred, average="macro", labels=labels, zero_division=0
            ),
        }

        metrics.update({f"zygosity_{k}": v for k, v in metrics.items()})

        report_names = ["REF", "HET"] if self.is_haploid_ else ["REF", "HET", "ALT"]

        self.logger.info(
            f"\n{classification_report(y_true, y_pred, labels=labels, target_names=report_names, zero_division=0)}"
        )

        report = classification_report(
            y_true,
            y_pred,
            labels=labels,
            target_names=report_names,
            zero_division=0,
            output_dict=True,
        )

        viz = ClassificationReportVisualizer(reset_kwargs=self.plotter_.param_dict)

        plots = viz.plot_all(
            report,  # type: ignore
            title_prefix=f"{self.model_name} Zygosity Report",
            show=getattr(self, "show_plots", False),
            heatmap_classes_only=True,
        )

        for name, fig in plots.items():
            fout = self.plots_dir / f"zygosity_report_{name}.{self.plot_format}"
            if hasattr(fig, "savefig") and isinstance(fig, Figure):
                fig.savefig(fout, dpi=300, facecolor="#111122")
                plt.close(fig)
            elif hasattr(fig, "write_html") and isinstance(fig, PlotlyFigure):
                fig.write_html(file=fout.with_suffix(".html"))

        viz._reset_mpl_style()

        # Save JSON
        self._save_report(report, suffix="zygosity")  # type: ignore

        # Confusion matrix
        self.plotter_.plot_confusion_matrix(
            y_true, y_pred, label_names=report_names, prefix="zygosity"
        )

    def _evaluate_iupac10_and_plot(
        self, y_true: np.ndarray, y_pred: np.ndarray
    ) -> None:
        """10-class IUPAC report & confusion matrix.

        This method generates a classification report and confusion matrix for genotypes encoded using the 10 IUPAC codes (0-9). The IUPAC codes represent various nucleotide combinations, including ambiguous bases.

        Args:
            y_true (np.ndarray): True genotypes (0-9) for masked entries.
            y_pred (np.ndarray): Predicted genotypes (0-9) for masked entries.

        Raises:
            NotFittedError: If fit() and transform() have not been called.
        """
        labels_idx = list(range(10))
        labels_names = ["A", "C", "G", "T", "W", "R", "M", "K", "Y", "S"]

        metrics = {
            "accuracy": accuracy_score(y_true, y_pred),
            "f1": f1_score(
                y_true, y_pred, average="macro", labels=labels_idx, zero_division=0
            ),
            "precision": precision_score(
                y_true, y_pred, average="macro", labels=labels_idx, zero_division=0
            ),
            "recall": recall_score(
                y_true, y_pred, average="macro", labels=labels_idx, zero_division=0
            ),
        }
        metrics.update({f"iupac_{k}": v for k, v in metrics.items()})

        self.logger.info(
            f"\n{classification_report(y_true, y_pred, labels=labels_idx, target_names=labels_names, zero_division=0)}"
        )

        report = classification_report(
            y_true,
            y_pred,
            labels=labels_idx,
            target_names=labels_names,
            zero_division=0,
            output_dict=True,
        )

        viz = ClassificationReportVisualizer(reset_kwargs=self.plotter_.param_dict)

        plots = viz.plot_all(
            report,  # type: ignore
            title_prefix=f"{self.model_name} IUPAC Report",
            show=getattr(self, "show_plots", False),
            heatmap_classes_only=True,
        )

        # Reset the style from Optuna's plotting.
        plt.rcParams.update(self.plotter_.param_dict)

        for name, fig in plots.items():
            fout = self.plots_dir / f"iupac_report_{name}.{self.plot_format}"
            if hasattr(fig, "savefig") and isinstance(fig, Figure):
                fig.savefig(fout, dpi=300, facecolor="#111122")
                plt.close(fig)
            elif hasattr(fig, "write_html") and isinstance(fig, PlotlyFigure):
                fig.write_html(file=fout.with_suffix(".html"))

        # Reset the style
        viz._reset_mpl_style()

        # Save JSON
        self._save_report(report, suffix="iupac")  # type: ignore

        # Confusion matrix
        self.plotter_.plot_confusion_matrix(
            y_true, y_pred, label_names=labels_names, prefix="iupac"
        )

    def _save_report(self, report_dict: Dict[str, float], suffix: str) -> None:
        """Save classification report dictionary as a JSON file.

        This method saves the provided classification report dictionary to a JSON file in the metrics directory, appending the specified suffix to the filename.

        Args:
            report_dict (Dict[str, float]): The classification report dictionary to save.
            suffix (str): Suffix to append to the filename (e.g., 'zygosity' or 'iupac').

        Raises:
            NotFittedError: If fit() and transform() have not been called.
        """
        if not self.is_fit_:
            msg = "No report to save. Ensure fit() has been called."
            raise NotFittedError(msg)

        out_fp = self.metrics_dir / f"classification_report_{suffix}.json"

        with open(out_fp, "w") as f:
            json.dump(report_dict, f, indent=4)

        self.logger.info(f"{self.model_name} {suffix} report saved to {out_fp}.")

    def _save_best_params(self, best_params: Dict[str, Any]) -> None:
        """Save the best hyperparameters to a JSON file.

        This method saves the best hyperparameters found during hyperparameter tuning to a JSON file in the parameters directory; the output path includes the model name for easy identification.

        Args:
            best_params (Dict[str, Any]): A dictionary of the best hyperparameters to save.
        """
        if not hasattr(self, "parameters_dir"):
            msg = "Attribute 'parameters_dir' not found. Ensure _create_model_directories() has been called."
            self.logger.error(msg)
            raise AttributeError(msg)

        fout = self.parameters_dir / "best_parameters.json"

        with open(fout, "w") as f:
            json.dump(best_params, f, indent=4)
```
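For orientation, a minimal sketch of what a conforming subclass looks like under this contract: `prefix` must exist before `super().__init__()` (the logger reads it), and the directory and plotting attributes are filled in afterwards. `MyDeterministicImputer` and its defaults are illustrative only and not part of the package.

```python
# Hypothetical subclass sketch; MyDeterministicImputer is not part of pg-sui.
# It only demonstrates the attribute contract BaseImputer expects.
from pgsui.impute.supervised.base import BaseImputer


class MyDeterministicImputer(BaseImputer):
    def __init__(self, prefix: str = "demo", verbose: bool = False) -> None:
        # BaseImputer builds its logger from self.prefix, so set it first.
        self.prefix = prefix
        self.model_name = "MyDeterministicImputer"

        super().__init__(verbose=verbose, debug=False)

        # Creates <prefix>_output/Supervised/<subdir>/<model_name>/ and sets
        # self.models_dir, self.plots_dir, self.metrics_dir, self.parameters_dir.
        self._create_model_directories(
            self.prefix, ["models", "plots", "metrics", "parameters"]
        )

        self.plot_format = "png"  # extension used when saving report figures
        self.is_haploid_ = False  # 2-class vs 3-class zygosity reporting
        self.is_fit_ = False      # _save_report() raises NotFittedError until True
```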
pgsui/impute/{unsupervised/models/in_development → supervised/imputers}/__init__.py: file without changes (renamed only).
pgsui/impute/supervised/imputers/hist_gradient_boosting.py (new file)
@@ -0,0 +1,317 @@
```python
# Standard library
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Dict, List, Literal

# Third-party
import numpy as np
from sklearn.ensemble import HistGradientBoostingClassifier
from sklearn.exceptions import NotFittedError
from sklearn.experimental import enable_iterative_imputer  # noqa
from sklearn.impute import IterativeImputer
from sklearn.model_selection import train_test_split

# Project
from snpio.analysis.genotype_encoder import GenotypeEncoder
from snpio.utils.logging import LoggerManager

from pgsui.data_processing.config import apply_dot_overrides, load_yaml_to_dataclass
from pgsui.data_processing.containers import (
    HGBConfig,
    _HGBParams,
    _ImputerParams,
    _SimParams,
)
from pgsui.data_processing.transformers import SimGenotypeDataTransformer
from pgsui.impute.supervised.base import BaseImputer
from pgsui.utils.logging_utils import configure_logger
from pgsui.utils.plotting import Plotting
from pgsui.utils.scorers import Scorer

if TYPE_CHECKING:
    from snpio.read_input.genotype_data import GenotypeData


def ensure_hgb_config(config: HGBConfig | Dict | str | None) -> HGBConfig:
    """Resolve HGB configuration from dataclass, mapping, or YAML path."""

    if config is None:
        return HGBConfig()
    if isinstance(config, HGBConfig):
        return config
    if isinstance(config, str):
        return load_yaml_to_dataclass(config, HGBConfig)
    if isinstance(config, dict):
        payload = dict(config)
        preset = payload.pop("preset", None)
        base = HGBConfig.from_preset(preset) if preset else HGBConfig()

        def _flatten(prefix: str, data: Dict[str, Any], out: Dict[str, Any]) -> None:
            for key, value in data.items():
                dotted = f"{prefix}.{key}" if prefix else key
                if isinstance(value, dict):
                    _flatten(dotted, value, out)
                else:
                    out[dotted] = value

        flat: Dict[str, Any] = {}
        _flatten("", payload, flat)
        return apply_dot_overrides(base, flat)

    raise TypeError("config must be an HGBConfig, dict, YAML path, or None.")


class ImputeHistGradientBoosting(BaseImputer):
    """Supervised HGB imputer driven by :class:`HGBConfig`."""

    def __init__(
        self,
        genotype_data: "GenotypeData",
        *,
        config: HGBConfig | Dict | str | None = None,
        overrides: Dict | None = None,
    ) -> None:
        self.model_name = "ImputeHistGradientBoosting"
        self.Model = HistGradientBoostingClassifier

        cfg = ensure_hgb_config(config)
        if overrides:
            cfg = cfg.apply_overrides(overrides)
        self.cfg = cfg

        self.genotype_data = genotype_data
        self.pgenc = GenotypeEncoder(genotype_data)

        self.prefix = cfg.io.prefix
        self.seed = cfg.io.seed
        self.n_jobs = cfg.io.n_jobs
        self.verbose = cfg.io.verbose
        self.debug = cfg.io.debug

        super().__init__(verbose=self.verbose, debug=self.debug)

        logman = LoggerManager(
            __name__, prefix=self.prefix, verbose=self.verbose, debug=self.debug
        )
        self.logger = configure_logger(
            logman.get_logger(), verbose=self.verbose, debug=self.debug
        )

        self._create_model_directories(
            self.prefix, ["models", "plots", "metrics", "optimize", "parameters"]
        )

        self.plot_format: Literal["png", "pdf", "svg", "jpg", "jpeg"] = cfg.plot.fmt

        self.plot_fontsize = cfg.plot.fontsize
        self.title_fontsize = cfg.plot.fontsize
        self.plot_dpi = cfg.plot.dpi
        self.despine = cfg.plot.despine
        self.show_plots = cfg.plot.show

        self.validation_split = cfg.train.validation_split

        if cfg.model.max_features is None:
            max_feat = None
        else:
            max_feat = cfg.model.max_features

        class_weight: Literal["balanced", "balanced_subsample", None] = getattr(
            cfg.model, "class_weight", "balanced"
        )

        if class_weight not in {"balanced", "balanced_subsample", None}:
            msg = (
                f"Invalid class_weight '{class_weight}'; "
                "must be one of: 'balanced', 'balanced_subsample', or None."
            )
            self.logger.error(msg)
            raise ValueError(msg)

        self.params = _HGBParams(
            max_iter=cfg.model.n_estimators,
            learning_rate=cfg.model.learning_rate,
            max_depth=cfg.model.max_depth,
            min_samples_leaf=cfg.model.min_samples_leaf,
            max_features=max_feat,
            n_iter_no_change=cfg.model.n_iter_no_change,
            tol=cfg.model.tol,
            class_weight=class_weight,
            random_state=self.seed,
            verbose=self.debug,
        )

        self.imputer_params = _ImputerParams(
            n_nearest_features=cfg.imputer.n_nearest_features,
            max_iter=cfg.imputer.max_iter,
            random_state=self.seed,
            verbose=self.verbose,
        )

        self.sim_params = _SimParams(
            prop_missing=cfg.sim.prop_missing,
            strategy=cfg.sim.strategy,
            missing_val=cfg.sim.missing_val,
            het_boost=cfg.sim.het_boost,
            seed=self.seed,
        )

        self.max_iter = cfg.imputer.max_iter
        self.n_nearest_features = cfg.imputer.n_nearest_features

        # Will be set in fit()
        self.is_haploid_: bool | None = None
        self.num_classes_: int | None = None
        self.num_features_: int | None = None
        self.models_: List[HistGradientBoostingClassifier | None] | None = None
        self.is_fit_: bool = False

    def fit(self) -> "BaseImputer":
        """Fit the imputer using `self.genotype_data`; takes no arguments.

        This method splits the samples into training and testing sets, simulates missing genotypes on the test set, trains scikit-learn's `IterativeImputer` (with a `HistGradientBoostingClassifier` estimator) on the training set, and evaluates reconstruction accuracy on the simulated-missing entries.

        Steps:
            1) Encode to 0/1/2 with -9/-1 as missing.
            2) Split samples into train/test.
            3) Train IterativeImputer on train (convert missing -> NaN).
            4) Evaluate on the simulated-missing test positions (reconstruction metrics) and generate the zygosity and IUPAC classification reports and plots.

        Returns:
            BaseImputer: self.
        """
        # Prepare utilities & metadata
        self.scorers_ = Scorer(
            prefix=self.prefix, average="macro", verbose=self.verbose, debug=self.debug
        )

        if self.plot_format not in {"png", "pdf", "svg", "jpg", "jpeg"}:
            msg = (
                f"Invalid plot format '{self.plot_format}'; "
                "must be one of: png, pdf, svg, jpg, jpeg."
            )
            self.logger.error(msg)
            raise ValueError(msg)

        self.plotter_ = Plotting(
            self.model_name,
            prefix=self.prefix,
            plot_format=self.plot_format,
            plot_dpi=self.plot_dpi,
            plot_fontsize=self.plot_fontsize,
            title_fontsize=self.title_fontsize,
            despine=self.despine,
            show_plots=self.show_plots,
            verbose=self.verbose,
            debug=self.debug,
        )

        X_int = self.pgenc.genotypes_012
        self.X012_ = X_int.astype(float)
        self.X012_[self.X012_ < 0] = np.nan  # Ensure missing are NaN
        self.is_haploid_ = np.count_nonzero(self.X012_ == 1) == 0
        self.num_classes_ = 2 if self.is_haploid_ else 3
        self.n_samples_, self.n_features_ = X_int.shape

        # Split
        X_train, X_test = train_test_split(
            self.X012_,
            test_size=self.validation_split,
            random_state=self.seed,
            shuffle=True,
        )

        # Simulate missing values on test set.
        sim_transformer = SimGenotypeDataTransformer(**self.sim_params.to_dict())

        X_test = np.nan_to_num(X_test, nan=-1)  # ensure missing are -1
        sim_transformer.fit(X_test)
        X_test_sim, missing_masks = sim_transformer.transform(X_test)
        sim_mask = missing_masks["simulated"]
        X_test_sim[X_test_sim < 0] = np.nan  # ensure missing are NaN

        self.model_params_ = self.params.to_dict()
        self.model_params_["random_state"] = self.seed

        # Train IterativeImputer
        est = self.Model(**self.model_params_)

        self.imputer_ = IterativeImputer(estimator=est, **self.imputer_params.to_dict())

        self.imputer_.fit(X_train)
        self.is_fit_ = True

        X_test_imputed = self.imputer_.transform(X_test_sim)

        # Predict on simulated test set
        y_true_flat = X_test[sim_mask].copy()
        y_pred_flat = X_test_imputed[sim_mask].copy()

        # Round and clip predictions to valid {0,1,2} or {0,1} if haploid.
        if self.is_haploid_:
            y_pred_flat = np.clip(np.rint(y_pred_flat), 0, 1).astype(int, copy=False)
            y_true_flat = np.clip(np.rint(y_true_flat), 0, 1).astype(int, copy=False)
        else:
            y_pred_flat = np.clip(np.rint(y_pred_flat), 0, 2).astype(int, copy=False)
            y_true_flat = np.clip(np.rint(y_true_flat), 0, 2).astype(int, copy=False)

        # Evaluate (012 / zygosity)
        self._evaluate_012_and_plot(y_true_flat.copy(), y_pred_flat.copy())

        # Evaluate (IUPAC)
        encodings_dict = {
            "A": 0,
            "C": 1,
            "G": 2,
            "T": 3,
            "W": 4,
            "R": 5,
            "M": 6,
            "K": 7,
            "Y": 8,
            "S": 9,
            "N": -1,
        }

        y_true_iupac_tmp = self.pgenc.decode_012(y_true_flat)
        y_pred_iupac_tmp = self.pgenc.decode_012(y_pred_flat)
        y_true_iupac = self.pgenc.convert_int_iupac(
            y_true_iupac_tmp, encodings_dict=encodings_dict
        )
        y_pred_iupac = self.pgenc.convert_int_iupac(
            y_pred_iupac_tmp, encodings_dict=encodings_dict
        )
        self._evaluate_iupac10_and_plot(y_true_iupac, y_pred_iupac)

        self.best_params_ = self.model_params_
        self.best_params_.update(self.imputer_params.to_dict())
        self.best_params_.update(self.sim_params.to_dict())
        self._save_best_params(self.best_params_)

        return self

    def transform(self) -> np.ndarray:
        """Impute all samples and return imputed genotypes.

        This method applies the trained imputer to the entire dataset, filling in missing genotype values. It ensures that any remaining missing values after imputation are set to -9, and decodes the imputed 0/1/2 genotypes back to their original format.

        Returns:
            np.ndarray: (n_samples, n_loci) genotypes decoded from 0/1/2; any values still missing after imputation are encoded as -9 before decoding.

        Raises:
            NotFittedError: If fit() has not been called prior to transform().
        """
        if not self.is_fit_:
            msg = "Imputer has not been fit; call fit() before transform()."
            self.logger.error(msg)
            raise NotFittedError(msg)

        X = self.X012_.copy()
        X_imp = self.imputer_.transform(X)

        if np.any(X_imp < 0) or np.isnan(X_imp).any():
            self.logger.warning("Some imputed values are still missing; setting to -9.")
            X_imp[X_imp < 0] = -9
            X_imp[np.isnan(X_imp)] = -9

        return self.pgenc.decode_012(X_imp)
```
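For orientation, a minimal usage sketch of the new class. It assumes `genotype_data` is an already-loaded snpio `GenotypeData` object (loading it is outside this diff); the nested config keys mirror the `cfg.io.*` and `cfg.plot.*` fields the constructor reads and are flattened into dotted overrides by `ensure_hgb_config()` before being applied on top of the `HGBConfig` defaults.

```python
# Usage sketch (not part of the diff). Assumes `genotype_data` is a snpio
# GenotypeData instance that has already been loaded elsewhere.
from pgsui.impute.supervised.imputers.hist_gradient_boosting import (
    ImputeHistGradientBoosting,
)

imputer = ImputeHistGradientBoosting(
    genotype_data,
    config={
        # Nested dicts become dotted overrides ("io.prefix", "plot.fmt", ...)
        # applied to the default HGBConfig.
        "io": {"prefix": "hgb_run", "seed": 42, "verbose": True},
        "plot": {"fmt": "pdf"},
    },
)

imputer.fit()                    # trains IterativeImputer, writes reports/plots
X_imputed = imputer.transform()  # imputes the full matrix and decodes from 0/1/2
```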