pg-sui 0.2.0__py3-none-any.whl → 1.6.14.dev9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pg_sui-0.2.0.dist-info → pg_sui-1.6.14.dev9.dist-info}/METADATA +101 -79
- pg_sui-1.6.14.dev9.dist-info/RECORD +81 -0
- {pg_sui-0.2.0.dist-info → pg_sui-1.6.14.dev9.dist-info}/WHEEL +1 -1
- pg_sui-1.6.14.dev9.dist-info/entry_points.txt +4 -0
- {pg_sui-0.2.0.dist-info → pg_sui-1.6.14.dev9.dist-info/licenses}/LICENSE +0 -0
- pg_sui-1.6.14.dev9.dist-info/top_level.txt +1 -0
- pgsui/__init__.py +35 -54
- pgsui/_version.py +34 -0
- pgsui/cli.py +909 -0
- pgsui/data_processing/__init__.py +0 -0
- pgsui/data_processing/config.py +565 -0
- pgsui/data_processing/containers.py +1424 -0
- pgsui/data_processing/transformers.py +557 -907
- pgsui/{example_data/trees → electron/app}/__init__.py +0 -0
- pgsui/electron/app/__main__.py +5 -0
- pgsui/electron/app/extra-resources/.gitkeep +1 -0
- pgsui/electron/app/icons/icons/1024x1024.png +0 -0
- pgsui/electron/app/icons/icons/128x128.png +0 -0
- pgsui/electron/app/icons/icons/16x16.png +0 -0
- pgsui/electron/app/icons/icons/24x24.png +0 -0
- pgsui/electron/app/icons/icons/256x256.png +0 -0
- pgsui/electron/app/icons/icons/32x32.png +0 -0
- pgsui/electron/app/icons/icons/48x48.png +0 -0
- pgsui/electron/app/icons/icons/512x512.png +0 -0
- pgsui/electron/app/icons/icons/64x64.png +0 -0
- pgsui/electron/app/icons/icons/icon.icns +0 -0
- pgsui/electron/app/icons/icons/icon.ico +0 -0
- pgsui/electron/app/main.js +227 -0
- pgsui/electron/app/package-lock.json +6894 -0
- pgsui/electron/app/package.json +51 -0
- pgsui/electron/app/preload.js +15 -0
- pgsui/electron/app/server.py +157 -0
- pgsui/electron/app/ui/logo.png +0 -0
- pgsui/electron/app/ui/renderer.js +131 -0
- pgsui/electron/app/ui/styles.css +59 -0
- pgsui/electron/app/ui/ui_shim.js +72 -0
- pgsui/electron/bootstrap.py +43 -0
- pgsui/electron/launch.py +57 -0
- pgsui/electron/package.json +14 -0
- pgsui/example_data/__init__.py +0 -0
- pgsui/example_data/phylip_files/__init__.py +0 -0
- pgsui/example_data/phylip_files/test.phy +0 -0
- pgsui/example_data/popmaps/__init__.py +0 -0
- pgsui/example_data/popmaps/{test.popmap → phylogen_nomx.popmap} +185 -99
- pgsui/example_data/structure_files/__init__.py +0 -0
- pgsui/example_data/structure_files/test.pops.2row.allsites.str +0 -0
- pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz +0 -0
- pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz.tbi +0 -0
- pgsui/impute/__init__.py +0 -0
- pgsui/impute/deterministic/imputers/allele_freq.py +725 -0
- pgsui/impute/deterministic/imputers/mode.py +844 -0
- pgsui/impute/deterministic/imputers/nmf.py +221 -0
- pgsui/impute/deterministic/imputers/phylo.py +973 -0
- pgsui/impute/deterministic/imputers/ref_allele.py +669 -0
- pgsui/impute/supervised/__init__.py +0 -0
- pgsui/impute/supervised/base.py +343 -0
- pgsui/impute/{unsupervised/models/in_development → supervised/imputers}/__init__.py +0 -0
- pgsui/impute/supervised/imputers/hist_gradient_boosting.py +317 -0
- pgsui/impute/supervised/imputers/random_forest.py +291 -0
- pgsui/impute/unsupervised/__init__.py +0 -0
- pgsui/impute/unsupervised/base.py +1118 -0
- pgsui/impute/unsupervised/callbacks.py +92 -262
- {simulation → pgsui/impute/unsupervised/imputers}/__init__.py +0 -0
- pgsui/impute/unsupervised/imputers/autoencoder.py +1285 -0
- pgsui/impute/unsupervised/imputers/nlpca.py +1554 -0
- pgsui/impute/unsupervised/imputers/ubp.py +1575 -0
- pgsui/impute/unsupervised/imputers/vae.py +1228 -0
- pgsui/impute/unsupervised/loss_functions.py +261 -0
- pgsui/impute/unsupervised/models/__init__.py +0 -0
- pgsui/impute/unsupervised/models/autoencoder_model.py +215 -567
- pgsui/impute/unsupervised/models/nlpca_model.py +155 -394
- pgsui/impute/unsupervised/models/ubp_model.py +180 -1106
- pgsui/impute/unsupervised/models/vae_model.py +269 -630
- pgsui/impute/unsupervised/nn_scorers.py +255 -0
- pgsui/utils/__init__.py +0 -0
- pgsui/utils/classification_viz.py +608 -0
- pgsui/utils/logging_utils.py +22 -0
- pgsui/utils/misc.py +35 -480
- pgsui/utils/plotting.py +996 -829
- pgsui/utils/pretty_metrics.py +290 -0
- pgsui/utils/scorers.py +213 -666
- pg_sui-0.2.0.dist-info/RECORD +0 -75
- pg_sui-0.2.0.dist-info/top_level.txt +0 -3
- pgsui/example_data/phylip_files/test_n10.phy +0 -118
- pgsui/example_data/phylip_files/test_n100.phy +0 -118
- pgsui/example_data/phylip_files/test_n2.phy +0 -118
- pgsui/example_data/phylip_files/test_n500.phy +0 -118
- pgsui/example_data/structure_files/test.nopops.1row.10sites.str +0 -117
- pgsui/example_data/structure_files/test.nopops.2row.100sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.10sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.30sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.allsites.str +0 -234
- pgsui/example_data/structure_files/test.pops.1row.10sites.str +0 -117
- pgsui/example_data/structure_files/test.pops.2row.10sites.str +0 -234
- pgsui/example_data/trees/test.iqtree +0 -376
- pgsui/example_data/trees/test.qmat +0 -5
- pgsui/example_data/trees/test.rate +0 -2033
- pgsui/example_data/trees/test.tre +0 -1
- pgsui/example_data/trees/test_n10.rate +0 -19
- pgsui/example_data/trees/test_n100.rate +0 -109
- pgsui/example_data/trees/test_n500.rate +0 -509
- pgsui/example_data/trees/test_siterates.txt +0 -2024
- pgsui/example_data/trees/test_siterates_n10.txt +0 -10
- pgsui/example_data/trees/test_siterates_n100.txt +0 -100
- pgsui/example_data/trees/test_siterates_n500.txt +0 -500
- pgsui/example_data/vcf_files/test.vcf +0 -244
- pgsui/example_data/vcf_files/test.vcf.gz +0 -0
- pgsui/example_data/vcf_files/test.vcf.gz.tbi +0 -0
- pgsui/impute/estimators.py +0 -1268
- pgsui/impute/impute.py +0 -1463
- pgsui/impute/simple_imputers.py +0 -1431
- pgsui/impute/supervised/iterative_imputer_fixedparams.py +0 -782
- pgsui/impute/supervised/iterative_imputer_gridsearch.py +0 -1024
- pgsui/impute/unsupervised/keras_classifiers.py +0 -697
- pgsui/impute/unsupervised/models/in_development/cnn_model.py +0 -486
- pgsui/impute/unsupervised/neural_network_imputers.py +0 -1440
- pgsui/impute/unsupervised/neural_network_methods.py +0 -1395
- pgsui/pg_sui.py +0 -261
- pgsui/utils/sequence_tools.py +0 -407
- simulation/sim_benchmarks.py +0 -333
- simulation/sim_treeparams.py +0 -475
- test/__init__.py +0 -0
- test/pg_sui_simtest.py +0 -215
- test/pg_sui_testing.py +0 -523
- test/test.py +0 -151
- test/test_pgsui.py +0 -374
- test/test_tkc.py +0 -185
pgsui/utils/plotting.py
CHANGED
|
@@ -1,949 +1,1116 @@
|
|
|
1
|
-
import
|
|
2
|
-
import
|
|
3
|
-
from itertools import cycle
|
|
1
|
+
import logging
|
|
2
|
+
import warnings
|
|
4
3
|
from pathlib import Path
|
|
4
|
+
from typing import Dict, List, Literal, Optional, Sequence, cast
|
|
5
5
|
|
|
6
|
+
import matplotlib as mpl
|
|
7
|
+
|
|
8
|
+
# Use Agg backend for headless plotting
|
|
9
|
+
mpl.use("Agg")
|
|
10
|
+
|
|
11
|
+
import matplotlib.pyplot as plt
|
|
6
12
|
import numpy as np
|
|
13
|
+
import optuna
|
|
7
14
|
import pandas as pd
|
|
8
|
-
import matplotlib.pyplot as plt
|
|
9
15
|
import seaborn as sns
|
|
10
|
-
import
|
|
11
|
-
|
|
12
|
-
from sklearn.
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
16
|
+
import torch
|
|
17
|
+
from optuna.exceptions import ExperimentalWarning
|
|
18
|
+
from sklearn.metrics import (
|
|
19
|
+
ConfusionMatrixDisplay,
|
|
20
|
+
auc,
|
|
21
|
+
average_precision_score,
|
|
22
|
+
confusion_matrix,
|
|
23
|
+
precision_recall_curve,
|
|
24
|
+
roc_curve,
|
|
25
|
+
)
|
|
26
|
+
from sklearn.preprocessing import label_binarize
|
|
27
|
+
from snpio import SNPioMultiQC
|
|
28
|
+
from snpio.utils.logging import LoggerManager
|
|
29
|
+
|
|
30
|
+
from pgsui.utils import misc
|
|
31
|
+
from pgsui.utils.logging_utils import configure_logger
|
|
32
|
+
|
|
33
|
+
# Quiet Matplotlib/fontTools INFO logging when saving PDF/SVG
|
|
34
|
+
for name in (
|
|
35
|
+
"fontTools",
|
|
36
|
+
"fontTools.subset",
|
|
37
|
+
"fontTools.ttLib",
|
|
38
|
+
"matplotlib.font_manager",
|
|
39
|
+
):
|
|
40
|
+
lg = logging.getLogger(name)
|
|
41
|
+
lg.setLevel(logging.WARNING)
|
|
42
|
+
lg.propagate = False
|
|
21
43
|
|
|
22
44
|
|
|
23
45
|
class Plotting:
|
|
24
|
-
"""
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
46
|
+
"""Class for plotting imputer scoring and results.
|
|
47
|
+
|
|
48
|
+
This class is used to plot the performance metrics of imputation models. It can plot ROC and Precision-Recall curves, model history, and the distribution of genotypes in the dataset.
|
|
49
|
+
|
|
50
|
+
Example:
|
|
51
|
+
>>> from pgsui import Plotting
|
|
52
|
+
>>> plotter = Plotting(model_name="ImputeVAE", prefix="pgsui_test", plot_format="png")
|
|
53
|
+
>>> plotter.plot_metrics(metrics, num_classes)
|
|
54
|
+
>>> plotter.plot_history(history)
|
|
55
|
+
>>> plotter.plot_confusion_matrix(y_true_1d, y_pred_1d)
|
|
56
|
+
>>> plotter.plot_tuning(study, model_name, optimize_dir, target_name="Objective Value")
|
|
57
|
+
>>> plotter.plot_gt_distribution(df)
|
|
58
|
+
|
|
59
|
+
Attributes:
|
|
60
|
+
model_name (str): Name of the model.
|
|
61
|
+
prefix (str): Prefix for the output directory.
|
|
62
|
+
plot_format (Literal["pdf", "png", "jpeg", "jpg", "svg"]): Format for the plots ('pdf', 'png', 'jpeg', 'jpg', 'svg').
|
|
63
|
+
plot_fontsize (int): Font size for the plots.
|
|
64
|
+
plot_dpi (int): Dots per inch for the plots.
|
|
65
|
+
title_fontsize (int): Font size for the plot titles.
|
|
66
|
+
show_plots (bool): Whether to display the plots inline or during execution.
|
|
67
|
+
output_dir (Path): Directory where plots will be saved.
|
|
68
|
+
logger (logging.Logger): Logger instance for logging messages.
|
|
69
|
+
"""
|
|
70
|
+
|
|
71
|
+
def __init__(
|
|
72
|
+
self,
|
|
73
|
+
model_name: str,
|
|
74
|
+
*,
|
|
75
|
+
prefix: str = "pgsui",
|
|
76
|
+
plot_format: Literal["pdf", "png", "jpeg", "jpg", "svg"] = "pdf",
|
|
77
|
+
plot_fontsize: int = 18,
|
|
78
|
+
plot_dpi: int = 300,
|
|
79
|
+
title_fontsize: int = 20,
|
|
80
|
+
despine: bool = True,
|
|
81
|
+
show_plots: bool = False,
|
|
82
|
+
verbose: int = 0,
|
|
83
|
+
debug: bool = False,
|
|
84
|
+
multiqc: bool = False,
|
|
85
|
+
multiqc_section: Optional[str] = None,
|
|
86
|
+
) -> None:
|
|
87
|
+
"""Initialize the Plotting object.
|
|
88
|
+
|
|
89
|
+
This class is used to plot the performance metrics of imputation models. It can plot ROC and Precision-Recall curves, model history, and the distribution of genotypes in the dataset.
|
|
31
90
|
|
|
32
91
|
Args:
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
92
|
+
model_name (str): Name of the model.
|
|
93
|
+
prefix (str, optional): Prefix for the output directory. Defaults to 'pgsui'.
|
|
94
|
+
plot_format (Literal["pdf", "png", "jpeg", "jpg"]): Format for the plots ('pdf', 'png', 'jpeg', 'jpg'). Defaults to 'pdf'.
|
|
95
|
+
plot_fontsize (int): Font size for the plots. Defaults to 18.
|
|
96
|
+
plot_dpi (int): Dots per inch for the plots. Defaults to 300.
|
|
97
|
+
title_fontsize (int): Font size for the plot titles. Defaults to 20.
|
|
98
|
+
despine (bool): Whether to remove the top and right spines from the plots. Defaults to True.
|
|
99
|
+
show_plots (bool): Whether to display the plots. Defaults to False.
|
|
100
|
+
verbose (int): Verbosity level for logging. Defaults to 0.
|
|
101
|
+
debug (bool): Whether to enable debug mode. Defaults to False.
|
|
102
|
+
multiqc (bool): Whether to queue plots for a MultiQC HTML report. Defaults to False.
|
|
103
|
+
multiqc_section (Optional[str]): Section name to use in MultiQC. Defaults to 'PG-SUI (<model_name>)'.
|
|
38
104
|
"""
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
params_df[col] = results[means_test[i]]
|
|
46
|
-
|
|
47
|
-
# Get number of needed subplot rows.
|
|
48
|
-
tot = len(filter_col)
|
|
49
|
-
cols = 4
|
|
50
|
-
rows = int(np.ceil(tot / cols))
|
|
51
|
-
|
|
52
|
-
fig = plt.figure(1, figsize=(20, 10))
|
|
53
|
-
fig.tight_layout(pad=3.0)
|
|
54
|
-
|
|
55
|
-
# Set font properties.
|
|
56
|
-
font = {"size": 12}
|
|
57
|
-
plt.rc("font", **font)
|
|
58
|
-
|
|
59
|
-
for i, p in enumerate(filter_col, start=1):
|
|
60
|
-
ax = fig.add_subplot(rows, cols, i)
|
|
61
|
-
|
|
62
|
-
# Plot each metric.
|
|
63
|
-
for col in means_test:
|
|
64
|
-
# Get maximum score for each parameter setting.
|
|
65
|
-
df_plot = params_df.groupby(p)[col].agg("max")
|
|
66
|
-
|
|
67
|
-
# Convert to float if not supposed to be string.
|
|
68
|
-
try:
|
|
69
|
-
df_plot.index = df_plot.index.astype(float)
|
|
70
|
-
except TypeError:
|
|
71
|
-
pass
|
|
72
|
-
|
|
73
|
-
# Sort by index (numerically if possible).
|
|
74
|
-
df_plot = df_plot.sort_index()
|
|
75
|
-
|
|
76
|
-
# Remove prefix from score name.
|
|
77
|
-
col_new_name = col[len("mean_test_") :]
|
|
78
|
-
|
|
79
|
-
ax.plot(
|
|
80
|
-
df_plot.index.astype(str),
|
|
81
|
-
df_plot.values,
|
|
82
|
-
"-o",
|
|
83
|
-
label=col_new_name,
|
|
84
|
-
)
|
|
105
|
+
logman = LoggerManager(
|
|
106
|
+
name=__name__, prefix=prefix, verbose=bool(verbose), debug=bool(debug)
|
|
107
|
+
)
|
|
108
|
+
self.logger = configure_logger(
|
|
109
|
+
logman.get_logger(), verbose=bool(verbose), debug=bool(debug)
|
|
110
|
+
)
|
|
85
111
|
|
|
86
|
-
|
|
112
|
+
self.model_name = model_name
|
|
113
|
+
self.prefix = prefix
|
|
114
|
+
self.plot_format = plot_format
|
|
115
|
+
self.plot_fontsize = plot_fontsize
|
|
116
|
+
self.plot_dpi = plot_dpi
|
|
117
|
+
self.title_fontsize = title_fontsize
|
|
118
|
+
self.show_plots = show_plots
|
|
87
119
|
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
ax.set_ylabel("Max Score")
|
|
91
|
-
ax.set_ylim([0, 1])
|
|
120
|
+
# MultiQC configuration
|
|
121
|
+
self.use_multiqc: bool = bool(multiqc)
|
|
92
122
|
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
f"{prefix}_output",
|
|
96
|
-
"plots",
|
|
97
|
-
"Unsupervised",
|
|
98
|
-
nn_method,
|
|
99
|
-
"gridsearch_metrics.pdf",
|
|
100
|
-
),
|
|
101
|
-
bbox_inches="tight",
|
|
102
|
-
facecolor="white",
|
|
123
|
+
self.multiqc_section: str = (
|
|
124
|
+
multiqc_section if multiqc_section is not None else f"PG-SUI ({model_name})"
|
|
103
125
|
)
|
|
104
126
|
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
127
|
+
if self.plot_format.startswith("."):
|
|
128
|
+
self.plot_format = self.plot_format.lstrip(".")
|
|
129
|
+
|
|
130
|
+
self.param_dict = {
|
|
131
|
+
"axes.labelsize": self.plot_fontsize,
|
|
132
|
+
"axes.titlesize": self.title_fontsize,
|
|
133
|
+
"axes.spines.top": despine,
|
|
134
|
+
"axes.spines.right": despine,
|
|
135
|
+
"xtick.labelsize": self.plot_fontsize,
|
|
136
|
+
"ytick.labelsize": self.plot_fontsize,
|
|
137
|
+
"legend.fontsize": self.plot_fontsize,
|
|
138
|
+
"legend.facecolor": "white",
|
|
139
|
+
"figure.titlesize": self.title_fontsize,
|
|
140
|
+
"figure.dpi": self.plot_dpi,
|
|
141
|
+
"figure.facecolor": "white",
|
|
142
|
+
"axes.linewidth": 2.0,
|
|
143
|
+
"lines.linewidth": 2.0,
|
|
144
|
+
"font.size": self.plot_fontsize,
|
|
145
|
+
"savefig.bbox": "tight",
|
|
146
|
+
"savefig.facecolor": "white",
|
|
147
|
+
"savefig.dpi": self.plot_dpi,
|
|
148
|
+
}
|
|
108
149
|
|
|
109
|
-
|
|
150
|
+
mpl.rcParams.update(self.param_dict)
|
|
110
151
|
|
|
111
|
-
|
|
112
|
-
metrics (Dict[str, Any]): Per-class, micro, and macro-averaged metrics including accuracy, ROC-AUC, and Precision-Recall with Average Precision scores.
|
|
152
|
+
unsuper = {"ImputeVAE", "ImputeNLPCA", "ImputeAutoencoder", "ImputeUBP"}
|
|
113
153
|
|
|
114
|
-
|
|
154
|
+
det = {
|
|
155
|
+
"ImputeRefAllele",
|
|
156
|
+
"ImputeMostFrequent",
|
|
157
|
+
"ImputeMostFrequentPerPop",
|
|
158
|
+
"ImputePhylo",
|
|
159
|
+
}
|
|
160
|
+
|
|
161
|
+
sup = {"ImputeRandomForest", "ImputeHistGradientBoosting"}
|
|
115
162
|
|
|
116
|
-
|
|
163
|
+
if model_name in unsuper:
|
|
164
|
+
plot_dir = "Unsupervised"
|
|
165
|
+
elif model_name in det:
|
|
166
|
+
plot_dir = "Deterministic"
|
|
167
|
+
elif model_name in sup:
|
|
168
|
+
plot_dir = "Supervised"
|
|
169
|
+
else:
|
|
170
|
+
msg = f"model_name '{model_name}' not recognized."
|
|
171
|
+
self.logger.error(msg)
|
|
172
|
+
raise ValueError(msg)
|
|
173
|
+
|
|
174
|
+
self.output_dir = Path(f"{self.prefix}_output", plot_dir)
|
|
175
|
+
self.output_dir = self.output_dir / "plots" / model_name
|
|
176
|
+
self.output_dir.mkdir(parents=True, exist_ok=True)
|
|
177
|
+
|
|
178
|
+
# --------------------------------------------------------------------- #
|
|
179
|
+
# Core plotting methods #
|
|
180
|
+
# --------------------------------------------------------------------- #
|
|
181
|
+
def plot_tuning(
|
|
182
|
+
self,
|
|
183
|
+
study: optuna.study.Study,
|
|
184
|
+
model_name: str,
|
|
185
|
+
optimize_dir: Path,
|
|
186
|
+
target_name: str = "Objective Value",
|
|
187
|
+
) -> None:
|
|
188
|
+
"""Plot the optimization history of a study.
|
|
189
|
+
|
|
190
|
+
This method plots the optimization history of a study. The plot is saved to disk as a ``<plot_format>`` file.
|
|
117
191
|
|
|
118
|
-
|
|
192
|
+
Args:
|
|
193
|
+
study (optuna.study.Study): Optuna study object.
|
|
194
|
+
model_name (str): Name of the model.
|
|
195
|
+
optimize_dir (Path): Directory to save the optimization plots.
|
|
196
|
+
target_name (str, optional): Name of the target value. Defaults to 'Objective Value'.
|
|
119
197
|
"""
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
fn = os.path.join(
|
|
125
|
-
f"{prefix}_output",
|
|
126
|
-
"plots",
|
|
127
|
-
"Unsupervised",
|
|
128
|
-
nn_method,
|
|
129
|
-
f"auc_pr_curves.pdf",
|
|
130
|
-
)
|
|
131
|
-
fig = plt.figure(figsize=(20, 10))
|
|
198
|
+
with warnings.catch_warnings():
|
|
199
|
+
warnings.filterwarnings("ignore", category=UserWarning)
|
|
200
|
+
warnings.filterwarnings("ignore", category=FutureWarning)
|
|
201
|
+
warnings.filterwarnings("ignore", category=ExperimentalWarning)
|
|
132
202
|
|
|
133
|
-
|
|
134
|
-
ham = round(metrics["hamming"], 2)
|
|
203
|
+
target_name = target_name.title()
|
|
135
204
|
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
)
|
|
139
|
-
axs = fig.subplots(nrows=1, ncols=2)
|
|
140
|
-
plt.subplots_adjust(hspace=0.5)
|
|
141
|
-
|
|
142
|
-
# Line weight
|
|
143
|
-
lw = 2
|
|
144
|
-
|
|
145
|
-
roc_auc = metrics["roc_auc"]
|
|
146
|
-
pr_ap = metrics["precision_recall"]
|
|
147
|
-
|
|
148
|
-
metric_list = [roc_auc, pr_ap]
|
|
149
|
-
|
|
150
|
-
for metric, ax in zip(metric_list, axs):
|
|
151
|
-
if "fpr_micro" in metric:
|
|
152
|
-
prefix1 = "fpr"
|
|
153
|
-
prefix2 = "tpr"
|
|
154
|
-
lab1 = "ROC"
|
|
155
|
-
lab2 = "AUC"
|
|
156
|
-
xlab = "False Positive Rate"
|
|
157
|
-
ylab = "True Positive Rate"
|
|
158
|
-
title = "Receiver Operating Characteristic (ROC)"
|
|
159
|
-
baseline = [0, 1]
|
|
160
|
-
|
|
161
|
-
elif "recall_micro" in metric:
|
|
162
|
-
prefix1 = "recall"
|
|
163
|
-
prefix2 = "precision"
|
|
164
|
-
lab1 = "Precision-Recall"
|
|
165
|
-
lab2 = "AP"
|
|
166
|
-
xlab = "Recall"
|
|
167
|
-
ylab = "Precision"
|
|
168
|
-
title = "Precision-Recall"
|
|
169
|
-
baseline = [metric["baseline"], metric["baseline"]]
|
|
170
|
-
|
|
171
|
-
# Plot iso-f1 curves.
|
|
172
|
-
f_scores = np.linspace(0.2, 0.8, num=4)
|
|
173
|
-
for i, f_score in enumerate(f_scores):
|
|
174
|
-
x = np.linspace(0.01, 1)
|
|
175
|
-
y = f_score * x / (2 * x - f_score)
|
|
176
|
-
ax.plot(
|
|
177
|
-
x[y >= 0],
|
|
178
|
-
y[y >= 0],
|
|
179
|
-
color="gray",
|
|
180
|
-
alpha=0.2,
|
|
181
|
-
linewidth=lw,
|
|
182
|
-
label="Iso-F1 Curves" if i == 0 else "",
|
|
183
|
-
)
|
|
184
|
-
ax.annotate(f"F1={f_score:0.1f}", xy=(0.9, y[45] + 0.02))
|
|
185
|
-
|
|
186
|
-
# Plot ROC curves.
|
|
187
|
-
ax.plot(
|
|
188
|
-
metric[f"{prefix1}_micro"],
|
|
189
|
-
metric[f"{prefix2}_micro"],
|
|
190
|
-
label=f"Micro-averaged {lab1} Curve ({lab2} = {metric['micro']:.2f})",
|
|
191
|
-
color="deeppink",
|
|
192
|
-
linestyle=":",
|
|
193
|
-
linewidth=4,
|
|
205
|
+
ax = optuna.visualization.matplotlib.plot_optimization_history(
|
|
206
|
+
study, target_name=target_name
|
|
194
207
|
)
|
|
195
|
-
|
|
196
|
-
ax.
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
208
|
+
ax.set_title(f"{model_name} Optimization History")
|
|
209
|
+
ax.set_xlabel("Trial")
|
|
210
|
+
ax.set_ylabel(target_name)
|
|
211
|
+
ax.legend(
|
|
212
|
+
loc="best",
|
|
213
|
+
shadow=True,
|
|
214
|
+
fancybox=True,
|
|
215
|
+
fontsize=mpl.rcParamsDefault["legend.fontsize"],
|
|
203
216
|
)
|
|
204
217
|
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
if f"{prefix1}_{i}" in metric:
|
|
208
|
-
ax.plot(
|
|
209
|
-
metric[f"{prefix1}_{i}"],
|
|
210
|
-
metric[f"{prefix2}_{i}"],
|
|
211
|
-
color=color,
|
|
212
|
-
lw=lw,
|
|
213
|
-
label=f"{lab1} Curve of class {i} ({lab2} = {metric[i]:.2f})",
|
|
214
|
-
)
|
|
215
|
-
|
|
216
|
-
if "fpr_micro" in metric:
|
|
217
|
-
# Make center baseline
|
|
218
|
-
ax.plot(
|
|
219
|
-
baseline,
|
|
220
|
-
baseline,
|
|
221
|
-
"k--",
|
|
222
|
-
linewidth=lw,
|
|
223
|
-
label="No Classification Skill",
|
|
224
|
-
)
|
|
225
|
-
else:
|
|
226
|
-
ax.plot(
|
|
227
|
-
[0, 1],
|
|
228
|
-
baseline,
|
|
229
|
-
"k--",
|
|
230
|
-
linewidth=lw,
|
|
231
|
-
label="No Classification Skill",
|
|
232
|
-
)
|
|
233
|
-
|
|
234
|
-
ax.set_xlim(0.0, 1.0)
|
|
235
|
-
ax.set_ylim(0.0, 1.05)
|
|
236
|
-
ax.set_xlabel(f"{xlab}")
|
|
237
|
-
ax.set_ylabel(f"{ylab}")
|
|
238
|
-
ax.set_title(f"{title}")
|
|
239
|
-
ax.legend(loc="best")
|
|
240
|
-
|
|
241
|
-
fig.savefig(fn, bbox_inches="tight", facecolor="white")
|
|
242
|
-
plt.close()
|
|
243
|
-
plt.clf()
|
|
244
|
-
plt.cla()
|
|
245
|
-
|
|
246
|
-
@staticmethod
|
|
247
|
-
def plot_search_space(
|
|
248
|
-
estimator,
|
|
249
|
-
height=2,
|
|
250
|
-
s=25,
|
|
251
|
-
features=None,
|
|
252
|
-
):
|
|
253
|
-
"""Make density and contour plots for showing search space during grid search.
|
|
254
|
-
|
|
255
|
-
Modified from sklearn-genetic-opt function to implement exception handling.
|
|
256
|
-
|
|
257
|
-
Args:
|
|
258
|
-
estimator (sklearn estimator object): A fitted estimator from :class:`~sklearn_genetic.GASearchCV`.
|
|
259
|
-
|
|
260
|
-
height (float, optional): Height of each facet. Defaults to 2.
|
|
261
|
-
|
|
262
|
-
s (float, optional): Size of the markers in scatter plot. Defaults to 5.
|
|
263
|
-
|
|
264
|
-
features (list, optional): Subset of features to plot, if ``None`` it plots all the features by default. Defaults to None.
|
|
218
|
+
od = optimize_dir
|
|
219
|
+
fn = od / f"optuna_optimization_history.{self.plot_format}"
|
|
265
220
|
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
"""
|
|
269
|
-
sns.set_style("white")
|
|
270
|
-
|
|
271
|
-
df = logbook_to_pandas(estimator.logbook)
|
|
272
|
-
if features:
|
|
273
|
-
_stats = df[features]
|
|
274
|
-
else:
|
|
275
|
-
variables = [*estimator.space.parameters, "score"]
|
|
276
|
-
_stats = df[variables]
|
|
221
|
+
if not fn.parent.exists():
|
|
222
|
+
fn.parent.mkdir(parents=True, exist_ok=True)
|
|
277
223
|
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
g = g.map_upper(sns.scatterplot, s=s, color="r", alpha=0.2)
|
|
224
|
+
plt.savefig(fn)
|
|
225
|
+
plt.close()
|
|
281
226
|
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
sns.kdeplot,
|
|
285
|
-
shade=True,
|
|
286
|
-
cmap=sns.color_palette("ch:s=.25,rot=-.25", as_cmap=True),
|
|
227
|
+
ax = optuna.visualization.matplotlib.plot_edf(
|
|
228
|
+
study, target_name=target_name
|
|
287
229
|
)
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
sns.kdeplot,
|
|
297
|
-
shade=True,
|
|
298
|
-
palette="crest",
|
|
299
|
-
alpha=0.2,
|
|
300
|
-
color="red",
|
|
230
|
+
ax.set_title(f"{model_name} Empirical Distribution Function (EDF)")
|
|
231
|
+
ax.set_xlabel(target_name)
|
|
232
|
+
ax.set_ylabel(f"{model_name} Cumulative Probability")
|
|
233
|
+
ax.legend(
|
|
234
|
+
loc="best",
|
|
235
|
+
shadow=True,
|
|
236
|
+
fancybox=True,
|
|
237
|
+
fontsize=mpl.rcParamsDefault["legend.fontsize"],
|
|
301
238
|
)
|
|
302
|
-
except np.linalg.LinAlgError as err:
|
|
303
|
-
if "singular matrix" in str(err).lower():
|
|
304
|
-
g = g.map_diag(sns.histplot, color="red", alpha=1.0, kde=False)
|
|
305
|
-
|
|
306
|
-
return g
|
|
307
|
-
|
|
308
|
-
@staticmethod
|
|
309
|
-
def visualize_missingness(
|
|
310
|
-
genotype_data,
|
|
311
|
-
df,
|
|
312
|
-
zoom=True,
|
|
313
|
-
prefix="imputer",
|
|
314
|
-
horizontal_space=0.6,
|
|
315
|
-
vertical_space=0.6,
|
|
316
|
-
bar_color="gray",
|
|
317
|
-
heatmap_palette="magma",
|
|
318
|
-
plot_format="pdf",
|
|
319
|
-
dpi=300,
|
|
320
|
-
):
|
|
321
|
-
"""Make multiple plots to visualize missing data.
|
|
322
|
-
|
|
323
|
-
Args:
|
|
324
|
-
genotype_data (GenotypeData): Initialized GentoypeData object.
|
|
325
239
|
|
|
326
|
-
|
|
240
|
+
plt.savefig(fn.with_stem("optuna_edf_plot"))
|
|
241
|
+
plt.close()
|
|
327
242
|
|
|
328
|
-
|
|
243
|
+
ax = optuna.visualization.matplotlib.plot_param_importances(
|
|
244
|
+
study, target_name=target_name
|
|
245
|
+
)
|
|
246
|
+
ax.set_xlabel("Parameter Importance")
|
|
247
|
+
ax.set_ylabel("Parameter")
|
|
248
|
+
ax.legend(loc="best", shadow=True, fancybox=True)
|
|
329
249
|
|
|
330
|
-
|
|
250
|
+
plt.savefig(fn.with_stem("optuna_param_importances_plot"))
|
|
251
|
+
plt.close()
|
|
331
252
|
|
|
332
|
-
|
|
253
|
+
ax = optuna.visualization.matplotlib.plot_timeline(study)
|
|
254
|
+
ax.set_title(f"{model_name} Timeline Plot")
|
|
255
|
+
ax.set_xlabel("Datetime")
|
|
256
|
+
ax.set_ylabel("Trial")
|
|
257
|
+
plt.savefig(fn.with_stem("optuna_timeline_plot"))
|
|
258
|
+
plt.close()
|
|
333
259
|
|
|
334
|
-
|
|
260
|
+
# Reset the style from Optuna's plotting.
|
|
261
|
+
sns.set_style("white", rc=self.param_dict)
|
|
262
|
+
mpl.rcParams.update(self.param_dict)
|
|
335
263
|
|
|
336
|
-
|
|
264
|
+
# ---- MultiQC: Optuna tuning line graph + best-params table --------
|
|
265
|
+
if self._multiqc_enabled():
|
|
266
|
+
try:
|
|
267
|
+
self._queue_multiqc_tuning(
|
|
268
|
+
study=study, model_name=model_name, target_name=target_name
|
|
269
|
+
)
|
|
270
|
+
except Exception as exc: # pragma: no cover - defensive
|
|
271
|
+
self.logger.warning(f"Failed to queue MultiQC tuning plots: {exc}")
|
|
337
272
|
|
|
338
|
-
|
|
273
|
+
def plot_metrics(
|
|
274
|
+
self,
|
|
275
|
+
y_true: np.ndarray,
|
|
276
|
+
y_pred_proba: np.ndarray,
|
|
277
|
+
metrics: Dict[str, float],
|
|
278
|
+
label_names: Optional[Sequence[str]] = None,
|
|
279
|
+
prefix: str = "",
|
|
280
|
+
) -> None:
|
|
281
|
+
"""Plot multi-class ROC-AUC and Precision-Recall curves.
|
|
339
282
|
|
|
340
|
-
|
|
283
|
+
This method plots the multi-class ROC-AUC and Precision-Recall curves. The plot is saved to disk as a ``<plot_format>`` file.
|
|
341
284
|
|
|
342
|
-
|
|
285
|
+
Args:
|
|
286
|
+
y_true (np.ndarray): 1D array of true integer labels in [0, n_classes-1].
|
|
287
|
+
y_pred_proba (np.ndarray): (n_samples, n_classes) array of predicted probabilities.
|
|
288
|
+
metrics (Dict[str, float]): Dict of summary metrics to annotate the figure.
|
|
289
|
+
label_names (Optional[Sequence[str]]): Optional sequence of class names (length must equal n_classes).
|
|
290
|
+
If provided, legends will use these names instead of 'Class i'.
|
|
291
|
+
prefix (str): Optional prefix for the output filename.
|
|
343
292
|
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
pandas.DataFrame: Per-individual missing data proportions.
|
|
347
|
-
pandas.DataFrame: Per-population + per-locus missing data proportions.
|
|
348
|
-
pandas.DataFrame: Per-population missing data proportions.
|
|
349
|
-
pandas.DataFrame: Per-individual and per-population missing data proportions.
|
|
293
|
+
Raises:
|
|
294
|
+
ValueError: If model_name is not recognized (legacy guard).
|
|
350
295
|
"""
|
|
296
|
+
num_classes = y_pred_proba.shape[1]
|
|
351
297
|
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
298
|
+
# Validate/normalize label names
|
|
299
|
+
if label_names is not None and len(label_names) != num_classes:
|
|
300
|
+
self.logger.warning(
|
|
301
|
+
f"plot_metrics: len(label_names)={len(label_names)} "
|
|
302
|
+
f"!= n_classes={num_classes}. Ignoring label_names."
|
|
303
|
+
)
|
|
304
|
+
label_names = None
|
|
305
|
+
if label_names is None:
|
|
306
|
+
label_names = [f"Class {i}" for i in range(num_classes)]
|
|
307
|
+
|
|
308
|
+
# Binarize y_true for one-vs-rest curves
|
|
309
|
+
y_true_bin = np.asarray(label_binarize(y_true, classes=np.arange(num_classes)))
|
|
310
|
+
|
|
311
|
+
# Containers
|
|
312
|
+
fpr, tpr, roc_auc_vals = {}, {}, {}
|
|
313
|
+
precision, recall, average_precision_vals = {}, {}, {}
|
|
314
|
+
|
|
315
|
+
# Per-class ROC & PR
|
|
316
|
+
for i in range(num_classes):
|
|
317
|
+
fpr[i], tpr[i], _ = roc_curve(y_true_bin[:, i], y_pred_proba[:, i])
|
|
318
|
+
roc_auc_vals[i] = auc(fpr[i], tpr[i])
|
|
319
|
+
precision[i], recall[i], _ = precision_recall_curve(
|
|
320
|
+
y_true_bin[:, i], y_pred_proba[:, i]
|
|
321
|
+
)
|
|
322
|
+
average_precision_vals[i] = average_precision_score(
|
|
323
|
+
y_true_bin[:, i], y_pred_proba[:, i]
|
|
324
|
+
)
|
|
378
325
|
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
326
|
+
# Micro-averages
|
|
327
|
+
fpr["micro"], tpr["micro"], _ = roc_curve(
|
|
328
|
+
y_true_bin.ravel(), y_pred_proba.ravel()
|
|
382
329
|
)
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
which="both",
|
|
390
|
-
left=False,
|
|
391
|
-
right=False,
|
|
392
|
-
labelleft=False,
|
|
330
|
+
roc_auc_vals["micro"] = auc(fpr["micro"], tpr["micro"])
|
|
331
|
+
precision["micro"], recall["micro"], _ = precision_recall_curve(
|
|
332
|
+
y_true_bin.ravel(), y_pred_proba.ravel()
|
|
333
|
+
)
|
|
334
|
+
average_precision_vals["micro"] = average_precision_score(
|
|
335
|
+
y_true_bin, y_pred_proba, average="micro"
|
|
393
336
|
)
|
|
394
337
|
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
338
|
+
# Macro-average ROC
|
|
339
|
+
all_fpr = np.unique(np.concatenate([fpr[i] for i in range(num_classes)]))
|
|
340
|
+
mean_tpr = np.zeros_like(all_fpr)
|
|
341
|
+
for i in range(num_classes):
|
|
342
|
+
mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])
|
|
343
|
+
mean_tpr /= num_classes
|
|
344
|
+
fpr["macro"], tpr["macro"] = all_fpr, mean_tpr
|
|
345
|
+
roc_auc_vals["macro"] = auc(fpr["macro"], tpr["macro"])
|
|
346
|
+
|
|
347
|
+
# Macro-average PR
|
|
348
|
+
all_recall = np.unique(np.concatenate([recall[i] for i in range(num_classes)]))
|
|
349
|
+
mean_precision = np.zeros_like(all_recall)
|
|
350
|
+
for i in range(num_classes):
|
|
351
|
+
# recall[i] increases, but precision[i] is given over decreasing thresholds
|
|
352
|
+
mean_precision += np.interp(all_recall, recall[i][::-1], precision[i][::-1])
|
|
353
|
+
mean_precision /= num_classes
|
|
354
|
+
average_precision_vals["macro"] = average_precision_score(
|
|
355
|
+
y_true_bin, y_pred_proba, average="macro"
|
|
356
|
+
)
|
|
410
357
|
|
|
411
|
-
|
|
358
|
+
# Plot
|
|
359
|
+
fig, axes = plt.subplots(1, 2, figsize=(14, 6))
|
|
412
360
|
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
|
|
420
|
-
|
|
361
|
+
# ROC
|
|
362
|
+
axes[0].plot(
|
|
363
|
+
fpr["micro"],
|
|
364
|
+
tpr["micro"],
|
|
365
|
+
label=f"Micro-average ROC (AUC = {roc_auc_vals['micro']:.2f})",
|
|
366
|
+
linestyle=":",
|
|
367
|
+
linewidth=4,
|
|
368
|
+
)
|
|
369
|
+
axes[0].plot(
|
|
370
|
+
fpr["macro"],
|
|
371
|
+
tpr["macro"],
|
|
372
|
+
label=f"Macro-average ROC (AUC = {roc_auc_vals['macro']:.2f})",
|
|
373
|
+
linestyle="--",
|
|
374
|
+
linewidth=4,
|
|
375
|
+
)
|
|
376
|
+
for i in range(num_classes):
|
|
377
|
+
axes[0].plot(
|
|
378
|
+
fpr[i],
|
|
379
|
+
tpr[i],
|
|
380
|
+
label=f"{label_names[i]} ROC (AUC = {roc_auc_vals[i]:.2f})",
|
|
421
381
|
)
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
indpop["Population"] = genotype_data.pops
|
|
434
|
-
|
|
435
|
-
melt_df = melt_df.melt(value_name="Missing", id_vars=id_vars)
|
|
436
|
-
melt_df.sort_values(by=id_vars[::-1], inplace=True)
|
|
437
|
-
melt_df["Missing"].replace(False, "Present", inplace=True)
|
|
438
|
-
melt_df["Missing"].replace(True, "Missing", inplace=True)
|
|
439
|
-
|
|
440
|
-
ax = axes[0, 2] if poptotal is None else axes[1, 1]
|
|
382
|
+
axes[0].plot([0, 1], [0, 1], linestyle="--", color="black", label="Random")
|
|
383
|
+
axes[0].set_xlabel("False Positive Rate")
|
|
384
|
+
axes[0].set_ylabel("True Positive Rate")
|
|
385
|
+
axes[0].set_title("Multi-class ROC-AUC Curve")
|
|
386
|
+
axes[0].legend(
|
|
387
|
+
loc="upper center",
|
|
388
|
+
bbox_to_anchor=(0.5, -0.15),
|
|
389
|
+
fancybox=True,
|
|
390
|
+
shadow=True,
|
|
391
|
+
ncol=2,
|
|
392
|
+
)
|
|
441
393
|
|
|
442
|
-
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
394
|
+
# PR
|
|
395
|
+
axes[1].plot(
|
|
396
|
+
recall["micro"],
|
|
397
|
+
precision["micro"],
|
|
398
|
+
label=f"Micro-average PR (AP = {average_precision_vals['micro']:.2f})",
|
|
399
|
+
linestyle=":",
|
|
400
|
+
linewidth=4,
|
|
449
401
|
)
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
|
|
402
|
+
axes[1].plot(
|
|
403
|
+
all_recall,
|
|
404
|
+
mean_precision,
|
|
405
|
+
label=f"Macro-average PR (AP = {average_precision_vals['macro']:.2f})",
|
|
406
|
+
linestyle="--",
|
|
407
|
+
linewidth=4,
|
|
456
408
|
)
|
|
457
|
-
|
|
458
|
-
|
|
459
|
-
|
|
460
|
-
|
|
461
|
-
|
|
462
|
-
ax.set_title("Per-Population")
|
|
463
|
-
g = sns.histplot(
|
|
464
|
-
data=melt_df,
|
|
465
|
-
y="Population",
|
|
466
|
-
hue="Missing",
|
|
467
|
-
multiple="fill",
|
|
468
|
-
ax=ax,
|
|
409
|
+
for i in range(num_classes):
|
|
410
|
+
axes[1].plot(
|
|
411
|
+
recall[i],
|
|
412
|
+
precision[i],
|
|
413
|
+
label=f"{label_names[i]} PR (AP = {average_precision_vals[i]:.2f})",
|
|
469
414
|
)
|
|
470
|
-
|
|
471
|
-
|
|
472
|
-
|
|
473
|
-
|
|
474
|
-
|
|
475
|
-
|
|
476
|
-
|
|
477
|
-
|
|
415
|
+
axes[1].plot([0, 1], [1, 0], linestyle="--", color="black", label="Random")
|
|
416
|
+
axes[1].set_xlabel("Recall")
|
|
417
|
+
axes[1].set_ylabel("Precision")
|
|
418
|
+
axes[1].set_title("Multi-class Precision-Recall Curve")
|
|
419
|
+
axes[1].legend(
|
|
420
|
+
loc="upper center",
|
|
421
|
+
bbox_to_anchor=(0.5, -0.15),
|
|
422
|
+
fancybox=True,
|
|
423
|
+
shadow=True,
|
|
424
|
+
ncol=2,
|
|
478
425
|
)
|
|
479
|
-
plt.cla()
|
|
480
|
-
plt.clf()
|
|
481
|
-
plt.close()
|
|
482
426
|
|
|
483
|
-
|
|
427
|
+
# Title & save
|
|
428
|
+
fig.suptitle("\n".join([f"{k}: {v:.2f}" for k, v in metrics.items()]), y=1.35)
|
|
484
429
|
|
|
485
|
-
|
|
486
|
-
|
|
487
|
-
|
|
488
|
-
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
|
|
493
|
-
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
|
|
498
|
-
|
|
499
|
-
|
|
500
|
-
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
|
|
504
|
-
|
|
505
|
-
|
|
506
|
-
|
|
430
|
+
prefix_for_name = f"{prefix}_" if prefix != "" else ""
|
|
431
|
+
out_name = (
|
|
432
|
+
f"{self.model_name}_{prefix_for_name}roc_pr_curves.{self.plot_format}"
|
|
433
|
+
)
|
|
434
|
+
fig.savefig(self.output_dir / out_name, bbox_inches="tight")
|
|
435
|
+
if self.show_plots:
|
|
436
|
+
plt.show()
|
|
437
|
+
plt.close(fig)
|
|
438
|
+
|
|
439
|
+
# ---- MultiQC: metrics table + per-class AUC/AP heatmap ------------
|
|
440
|
+
if self._multiqc_enabled():
|
|
441
|
+
try:
|
|
442
|
+
self._queue_multiqc_metrics(
|
|
443
|
+
metrics=metrics,
|
|
444
|
+
roc_auc=roc_auc_vals,
|
|
445
|
+
average_precision=average_precision_vals,
|
|
446
|
+
label_names=label_names,
|
|
447
|
+
panel_prefix=prefix,
|
|
448
|
+
)
|
|
449
|
+
except Exception as exc: # pragma: no cover - defensive
|
|
450
|
+
self.logger.warning(f"Failed to queue MultiQC metrics plots: {exc}")
|
|
451
|
+
|
|
452
|
+
try:
|
|
453
|
+
self._queue_multiqc_roc_curves(
|
|
454
|
+
fpr=fpr,
|
|
455
|
+
tpr=tpr,
|
|
456
|
+
label_names=label_names,
|
|
457
|
+
panel_prefix=prefix,
|
|
458
|
+
)
|
|
459
|
+
self._queue_multiqc_pr_curves(
|
|
460
|
+
precision=precision,
|
|
461
|
+
recall=recall,
|
|
462
|
+
label_names=label_names,
|
|
463
|
+
panel_prefix=prefix,
|
|
464
|
+
)
|
|
465
|
+
except Exception as exc: # pragma: no cover - defensive
|
|
466
|
+
self.logger.warning(f"Failed to queue MultiQC ROC/PR curves: {exc}")
|
|
507
467
|
|
|
508
|
-
|
|
468
|
+
def plot_history(
|
|
469
|
+
self,
|
|
470
|
+
history: Dict[str, List[float] | Dict[str, List[float]] | None] | None,
|
|
471
|
+
) -> None:
|
|
472
|
+
"""Plot model history traces. Will be saved to file.
|
|
509
473
|
|
|
510
|
-
|
|
474
|
+
This method plots the deep learning model history traces. The plot is saved to disk as a ``<plot_format>`` file.
|
|
511
475
|
|
|
512
|
-
|
|
476
|
+
Args:
|
|
477
|
+
history (Dict[str, List[float]]): Dictionary with lists of history objects. Keys should be "Train" and "Validation".
|
|
513
478
|
|
|
514
|
-
|
|
479
|
+
Raises:
|
|
480
|
+
ValueError: nn_method must be either 'ImputeNLPCA', 'ImputeUBP', 'ImputeAutoencoder', 'ImputeVAE'.
|
|
481
|
+
"""
|
|
482
|
+
if self.model_name not in {
|
|
483
|
+
"ImputeNLPCA",
|
|
484
|
+
"ImputeVAE",
|
|
485
|
+
"ImputeAutoencoder",
|
|
486
|
+
"ImputeUBP",
|
|
487
|
+
}:
|
|
488
|
+
msg = "nn_method must be either 'ImputeNLPCA', 'ImputeVAE', 'ImputeAutoencoder', 'ImputeUBP'."
|
|
489
|
+
self.logger.error(msg)
|
|
490
|
+
raise ValueError(msg)
|
|
491
|
+
|
|
492
|
+
if self.model_name != "ImputeUBP":
|
|
493
|
+
fig, ax = plt.subplots(1, 1, figsize=(12, 8))
|
|
494
|
+
df = pd.DataFrame(history)
|
|
495
|
+
df = df.iloc[1:]
|
|
515
496
|
|
|
516
|
-
|
|
497
|
+
# Plot train accuracy
|
|
498
|
+
ax.plot(df["Train"], c="blue", lw=3)
|
|
517
499
|
|
|
518
|
-
|
|
519
|
-
|
|
500
|
+
ax.set_title(f"{self.model_name} Loss per Epoch")
|
|
501
|
+
ax.set_ylabel("Loss")
|
|
502
|
+
ax.set_xlabel("Epoch")
|
|
503
|
+
ax.legend(["Train"], loc="best", shadow=True, fancybox=True)
|
|
520
504
|
|
|
521
|
-
|
|
505
|
+
else:
|
|
506
|
+
fig, ax = plt.subplots(3, 1, figsize=(12, 8))
|
|
507
|
+
|
|
508
|
+
# Ensure history is the nested dictionary type for ImputeUBP
|
|
509
|
+
if not (
|
|
510
|
+
isinstance(history, dict)
|
|
511
|
+
and "Train" in history
|
|
512
|
+
and isinstance(history["Train"], dict)
|
|
513
|
+
):
|
|
514
|
+
msg = "For ImputeUBP, history must be a nested dictionary with phases."
|
|
515
|
+
self.logger.error(msg)
|
|
516
|
+
raise TypeError(msg)
|
|
517
|
+
|
|
518
|
+
for i, phase in enumerate(range(1, 4)):
|
|
519
|
+
train = pd.Series(history["Train"][f"Phase {phase}"])
|
|
520
|
+
train = train.iloc[1:] # ignore first epoch
|
|
521
|
+
|
|
522
|
+
# Plot train accuracy
|
|
523
|
+
ax[i].plot(train, c="blue", lw=3)
|
|
524
|
+
ax[i].set_title(f"{self.model_name}: Phase {phase} Loss per Epoch")
|
|
525
|
+
ax[i].set_ylabel("Loss")
|
|
526
|
+
ax[i].set_xlabel("Epoch")
|
|
527
|
+
ax[i].legend([f"Phase {phase}"], loc="best", shadow=True, fancybox=True)
|
|
528
|
+
|
|
529
|
+
fn = f"{self.model_name.lower()}_history_plot.{self.plot_format}"
|
|
530
|
+
fn = self.output_dir / fn
|
|
531
|
+
fig.savefig(fn)
|
|
532
|
+
|
|
533
|
+
if self.show_plots:
|
|
534
|
+
plt.show()
|
|
535
|
+
plt.close(fig)
|
|
536
|
+
|
|
537
|
+
# ---- MultiQC: training-loss vs epoch linegraphs -------------------
|
|
538
|
+
if self._multiqc_enabled():
|
|
539
|
+
try:
|
|
540
|
+
self._queue_multiqc_history(history=history)
|
|
541
|
+
except Exception as exc: # pragma: no cover
|
|
542
|
+
self.logger.warning(f"Failed to queue MultiQC history plot: {exc}")
|
|
522
543
|
|
|
523
|
-
|
|
544
|
+
def plot_confusion_matrix(
|
|
545
|
+
self,
|
|
546
|
+
y_true_1d: np.ndarray | pd.DataFrame | List[str | int] | torch.Tensor,
|
|
547
|
+
y_pred_1d: np.ndarray | pd.DataFrame | List[str | int] | torch.Tensor,
|
|
548
|
+
label_names: Sequence[str] | Dict[str, int] | None = None,
|
|
549
|
+
prefix: str = "",
|
|
550
|
+
) -> None:
|
|
551
|
+
"""Plot a confusion matrix with optional class labels.
|
|
524
552
|
|
|
525
|
-
|
|
553
|
+
This method plots a confusion matrix using true and predicted labels. The plot is saved to disk as a ``<plot_format>`` file.
|
|
526
554
|
|
|
527
|
-
|
|
555
|
+
Args:
|
|
556
|
+
y_true_1d (np.ndarray | pd.DataFrame | list | torch.Tensor): 1D array of true integer labels in [0, n_classes-1].
|
|
557
|
+
y_pred_1d (np.ndarray | pd.DataFrame | list | torch.Tensor): 1D array of predicted integer labels in [0, n_classes-1].
|
|
558
|
+
label_names (Sequence[str] | None): Optional sequence of class names (length must equal n_classes). If provided, both the internal label order and displayed tick labels will respect this order (assumed to be 0..n-1).
|
|
559
|
+
prefix (str): Optional prefix for the output filename.
|
|
528
560
|
|
|
529
|
-
|
|
561
|
+
Notes:
|
|
562
|
+
- If `label_names` is None, the display labels default to the numeric class indices inferred from `y_true_1d ∪ y_pred_1d`.
|
|
563
|
+
"""
|
|
564
|
+
y_true_1d = misc.validate_input_type(y_true_1d, return_type="array")
|
|
565
|
+
y_pred_1d = misc.validate_input_type(y_pred_1d, return_type="array")
|
|
566
|
+
|
|
567
|
+
if not isinstance(y_true_1d, np.ndarray) or y_true_1d.ndim != 1:
|
|
568
|
+
msg = "y_true_1d must be a 1D array-like of true labels."
|
|
569
|
+
self.logger.error(msg)
|
|
570
|
+
raise TypeError(msg)
|
|
571
|
+
|
|
572
|
+
if not isinstance(y_pred_1d, np.ndarray) or y_pred_1d.ndim != 1:
|
|
573
|
+
msg = "y_pred_1d must be a 1D array-like of predicted labels."
|
|
574
|
+
self.logger.error(msg)
|
|
575
|
+
raise TypeError(msg)
|
|
576
|
+
|
|
577
|
+
if y_true_1d.ndim > 1:
|
|
578
|
+
y_true_1d = y_true_1d.flatten()
|
|
579
|
+
|
|
580
|
+
if y_pred_1d.ndim > 1:
|
|
581
|
+
y_pred_1d = y_pred_1d.flatten()
|
|
582
|
+
|
|
583
|
+
# Determine class count/order
|
|
584
|
+
if label_names is not None:
|
|
585
|
+
n_classes = len(label_names)
|
|
586
|
+
labels = np.arange(n_classes) # our y_* are ints 0..n-1
|
|
587
|
+
display_labels = list(map(str, label_names))
|
|
588
|
+
else:
|
|
589
|
+
# Infer labels from data to keep matrix tight
|
|
590
|
+
labels = np.unique(np.concatenate([y_true_1d, y_pred_1d]))
|
|
591
|
+
display_labels = labels # sklearn will convert to strings
|
|
530
592
|
|
|
531
|
-
|
|
593
|
+
fig, ax = plt.subplots(1, 1, figsize=(15, 15))
|
|
532
594
|
|
|
533
|
-
|
|
595
|
+
ConfusionMatrixDisplay.from_predictions(
|
|
596
|
+
y_true=y_true_1d,
|
|
597
|
+
y_pred=y_pred_1d,
|
|
598
|
+
labels=labels,
|
|
599
|
+
display_labels=display_labels,
|
|
600
|
+
ax=ax,
|
|
601
|
+
cmap="viridis",
|
|
602
|
+
colorbar=True,
|
|
603
|
+
)
|
|
534
604
|
|
|
535
|
-
|
|
605
|
+
# Build a stable panel id before mutating prefix
|
|
606
|
+
panel_suffix = f"{prefix}_" if prefix else ""
|
|
607
|
+
panel_id = f"{self.model_name.lower()}_{panel_suffix}confusion_matrix"
|
|
536
608
|
|
|
537
|
-
|
|
609
|
+
if prefix != "":
|
|
610
|
+
prefix = f"{prefix}_"
|
|
538
611
|
|
|
539
|
-
|
|
612
|
+
out_name = (
|
|
613
|
+
f"{self.model_name.lower()}_{prefix}confusion_matrix.{self.plot_format}"
|
|
614
|
+
)
|
|
615
|
+
fig.savefig(self.output_dir / out_name, bbox_inches="tight")
|
|
616
|
+
if self.show_plots:
|
|
617
|
+
plt.show()
|
|
618
|
+
plt.close(fig)
|
|
619
|
+
|
|
620
|
+
# ---- MultiQC: confusion-matrix heatmap ----------------------------
|
|
621
|
+
if self._multiqc_enabled():
|
|
622
|
+
try:
|
|
623
|
+
self._queue_multiqc_confusion(
|
|
624
|
+
y_true=y_true_1d,
|
|
625
|
+
y_pred=y_pred_1d,
|
|
626
|
+
labels=labels,
|
|
627
|
+
display_labels=display_labels,
|
|
628
|
+
panel_id=panel_id,
|
|
629
|
+
)
|
|
630
|
+
except Exception as exc: # pragma: no cover
|
|
631
|
+
self.logger.warning(f"Failed to queue MultiQC confusion matrix: {exc}")
|
|
540
632
|
|
|
541
|
-
|
|
633
|
+
def plot_gt_distribution(
|
|
634
|
+
self,
|
|
635
|
+
X: np.ndarray | pd.DataFrame | list | torch.Tensor,
|
|
636
|
+
is_imputed: bool = False,
|
|
637
|
+
) -> None:
|
|
638
|
+
"""Plot genotype distribution (IUPAC or integer-encoded).
|
|
542
639
|
|
|
543
|
-
|
|
640
|
+
This plots counts for all genotypes present in X. It supports IUPAC single-letter genotypes and integer encodings. Missing markers '-', '.', '?' are normalized to 'N'. Bars are annotated with counts and percentages.
|
|
544
641
|
|
|
545
|
-
|
|
642
|
+
Args:
|
|
643
|
+
X (np.ndarray | pd.DataFrame | list | torch.Tensor): Array-like genotype matrix. Rows=loci, cols=samples (any orientation is OK). Elements are IUPAC one-letter genotypes (e.g., 'A','C','G','T','N','R',...) or integers (e.g., 0/1/2[/3]).
|
|
644
|
+
is_imputed (bool): Whether these genotypes are imputed. Affects the title only. Defaults to False.
|
|
645
|
+
"""
|
|
646
|
+
# Flatten X to a 1D Series
|
|
647
|
+
if isinstance(X, pd.DataFrame):
|
|
648
|
+
arr = X.values
|
|
649
|
+
elif torch.is_tensor(X):
|
|
650
|
+
arr = X.detach().cpu().numpy()
|
|
651
|
+
else:
|
|
652
|
+
arr = np.asarray(X)
|
|
653
|
+
|
|
654
|
+
s = pd.Series(arr.ravel())
|
|
655
|
+
|
|
656
|
+
# Detect string vs numeric encodings and normalize
|
|
657
|
+
if s.dtype.kind in ("O", "U", "S"): # string-like → IUPAC path
|
|
658
|
+
s = s.astype(str).str.upper().replace({"-": "N", ".": "N", "?": "N"})
|
|
659
|
+
x_label = "Genotype (IUPAC)"
|
|
660
|
+
|
|
661
|
+
# Define canonical order: N, A/C/T/G, then IUPAC ambiguity codes.
|
|
662
|
+
canonical = ["A", "C", "T", "G"]
|
|
663
|
+
iupac_ambiguity = sorted(["M", "R", "W", "S", "Y", "K", "V", "H", "D", "B"])
|
|
664
|
+
base_order = ["N"] + canonical + iupac_ambiguity
|
|
665
|
+
else: # numeric path (e.g., 0/1/2/[3], -1 for missing)
|
|
666
|
+
# Map common missing sentinels to 'N', keep others as strings for
|
|
667
|
+
# labeling
|
|
668
|
+
s = s.astype(float) # allow NaN comparisons
|
|
669
|
+
s = s.where(~np.isin(s, [-1, np.nan]), other=np.nan)
|
|
670
|
+
s = s.fillna("N").astype(int, errors="ignore").astype(str)
|
|
671
|
+
|
|
672
|
+
x_label = "Genotype (Integer-encoded)"
|
|
673
|
+
|
|
674
|
+
# Support both ternary and quaternary encodings; keep a stable order
|
|
675
|
+
base_order = ["N", "0", "1", "2", "3"]
|
|
676
|
+
|
|
677
|
+
# Include any unexpected symbols at the end (sorted) so nothing is
|
|
678
|
+
# dropped
|
|
679
|
+
extras = sorted(set(s.unique()) - set(base_order))
|
|
680
|
+
full_order = base_order + [e for e in extras if e not in base_order]
|
|
681
|
+
|
|
682
|
+
# Count and reindex to show zero-count categories
|
|
683
|
+
counts = s.value_counts().reindex(full_order, fill_value=0)
|
|
684
|
+
df = counts.rename_axis("Genotype").reset_index(name="Count")
|
|
685
|
+
df["Percent"] = df["Count"] / df["Count"].sum() * 100
|
|
686
|
+
|
|
687
|
+
title = "Imputed Genotype Counts" if is_imputed else "Genotype Counts"
|
|
688
|
+
|
|
689
|
+
# --- Plot ---
|
|
690
|
+
fig, ax = plt.subplots(figsize=(8, 5))
|
|
691
|
+
sns.despine(fig=fig)
|
|
692
|
+
|
|
693
|
+
ax = sns.barplot(
|
|
694
|
+
data=df,
|
|
695
|
+
x="Genotype",
|
|
696
|
+
y="Percent",
|
|
697
|
+
hue="Genotype",
|
|
698
|
+
order=full_order,
|
|
699
|
+
errorbar=None,
|
|
700
|
+
ax=ax,
|
|
701
|
+
palette="Set1",
|
|
702
|
+
legend=False,
|
|
703
|
+
fill=True,
|
|
704
|
+
)
|
|
546
705
|
|
|
547
|
-
|
|
706
|
+
ax.set_xlabel(x_label)
|
|
707
|
+
ax.set_ylabel("Percent")
|
|
708
|
+
ax.set_title(title)
|
|
709
|
+
ax.set_ylim((0.0, 50.0))
|
|
548
710
|
|
|
549
|
-
|
|
711
|
+
fig.tight_layout()
|
|
550
712
|
|
|
551
|
-
|
|
552
|
-
|
|
713
|
+
suffix = "imputed" if is_imputed else "original"
|
|
714
|
+
fn = self.output_dir / f"gt_distributions_{suffix}.{self.plot_format}"
|
|
715
|
+
fig.savefig(fn, dpi=300)
|
|
716
|
+
|
|
717
|
+
if self.show_plots:
|
|
718
|
+
plt.show()
|
|
719
|
+
plt.close(fig)
|
|
720
|
+
|
|
721
|
+
# ---- MultiQC: genotype-distribution barplot -----------------------
|
|
722
|
+
if self._multiqc_enabled():
|
|
723
|
+
try:
|
|
724
|
+
self._queue_multiqc_gt_distribution(df=df, is_imputed=is_imputed)
|
|
725
|
+
except Exception as exc: # pragma: no cover
|
|
726
|
+
self.logger.warning(
|
|
727
|
+
f"Failed to queue MultiQC genotype distribution: {exc}"
|
|
728
|
+
)
|
|
553
729
|
|
|
554
|
-
|
|
730
|
+
# --------------------------------------------------------------------- #
|
|
731
|
+
# MultiQC helper methods #
|
|
732
|
+
# --------------------------------------------------------------------- #
|
|
733
|
+
def _multiqc_enabled(self) -> bool:
|
|
734
|
+
"""Return True if MultiQC integration is active."""
|
|
735
|
+
return bool(self.use_multiqc)
|
|
736
|
+
|
|
737
|
+
def _queue_multiqc_tuning(
|
|
738
|
+
self,
|
|
739
|
+
*,
|
|
740
|
+
study: optuna.study.Study,
|
|
741
|
+
model_name: str,
|
|
742
|
+
target_name: str,
|
|
743
|
+
) -> None:
|
|
744
|
+
"""Queue Optuna tuning results for MultiQC.
|
|
555
745
|
|
|
556
|
-
|
|
557
|
-
|
|
558
|
-
|
|
559
|
-
|
|
560
|
-
>>> popmapfile="popmap.txt",
|
|
561
|
-
>>>)
|
|
562
|
-
>>>
|
|
563
|
-
>>>ubp = ImputeUBP(genotype_data=data)
|
|
564
|
-
>>>
|
|
565
|
-
>>>components, pca = run_and_plot_pca(
|
|
566
|
-
>>> data,
|
|
567
|
-
>>> ubp,
|
|
568
|
-
>>> scale=True,
|
|
569
|
-
>>> center=True,
|
|
570
|
-
>>> plot_format="png"
|
|
571
|
-
>>>)
|
|
572
|
-
>>>
|
|
573
|
-
>>>explvar = pca.explained_variance_ratio\_
|
|
746
|
+
Args:
|
|
747
|
+
study (optuna.study.Study): Optuna study object.
|
|
748
|
+
model_name (str): Name of the model.
|
|
749
|
+
target_name (str): Name of the target value.
|
|
574
750
|
"""
|
|
575
|
-
|
|
576
|
-
|
|
751
|
+
if not self._multiqc_enabled():
|
|
752
|
+
return
|
|
577
753
|
|
|
578
|
-
|
|
579
|
-
|
|
580
|
-
|
|
581
|
-
|
|
582
|
-
|
|
583
|
-
|
|
584
|
-
"<2 axes is not supported; n_axes must be either 2 or 3."
|
|
754
|
+
# trial number vs objective value line graph
|
|
755
|
+
try:
|
|
756
|
+
df_trials = study.trials_dataframe(attrs=("number", "value"))
|
|
757
|
+
except Exception as exc: # pragma: no cover
|
|
758
|
+
self.logger.warning(
|
|
759
|
+
f"Could not extract trials_dataframe for MultiQC: {exc}"
|
|
585
760
|
)
|
|
761
|
+
return
|
|
586
762
|
|
|
587
|
-
|
|
763
|
+
if df_trials.empty or "value" not in df_trials:
|
|
764
|
+
return
|
|
588
765
|
|
|
589
|
-
|
|
590
|
-
|
|
591
|
-
|
|
592
|
-
|
|
593
|
-
|
|
594
|
-
|
|
595
|
-
|
|
596
|
-
)
|
|
597
|
-
|
|
598
|
-
original_df.replace(-9, np.nan, inplace=True)
|
|
599
|
-
|
|
600
|
-
if center or scale:
|
|
601
|
-
# Center data to mean. Scaling to unit variance is off.
|
|
602
|
-
scaler = StandardScaler(with_mean=center, with_std=scale)
|
|
603
|
-
pca_df = scaler.fit_transform(df)
|
|
604
|
-
else:
|
|
605
|
-
pca_df = df.copy()
|
|
606
|
-
|
|
607
|
-
# Run PCA.
|
|
608
|
-
model = PCA(n_components=n_components)
|
|
609
|
-
components = model.fit_transform(pca_df)
|
|
766
|
+
data: Dict[str, Dict[int, int]] = {
|
|
767
|
+
model_name: {
|
|
768
|
+
row["number"]: row["value"]
|
|
769
|
+
for _, row in df_trials.iterrows()
|
|
770
|
+
if row["value"] is not None
|
|
771
|
+
}
|
|
772
|
+
}
|
|
610
773
|
|
|
611
|
-
|
|
612
|
-
|
|
774
|
+
if not data[model_name]:
|
|
775
|
+
return
|
|
776
|
+
|
|
777
|
+
SNPioMultiQC.queue_linegraph(
|
|
778
|
+
data=data,
|
|
779
|
+
panel_id=f"{self.model_name}_optuna_history",
|
|
780
|
+
section=self.multiqc_section,
|
|
781
|
+
title=f"{self.model_name} Optuna Optimization History",
|
|
782
|
+
index_label="Trial",
|
|
783
|
+
description=f"Optuna optimization history for {self.model_name} "
|
|
784
|
+
f"(target={target_name}).",
|
|
613
785
|
)
|
|
614
786
|
|
|
615
|
-
|
|
616
|
-
|
|
617
|
-
|
|
618
|
-
|
|
619
|
-
|
|
620
|
-
|
|
621
|
-
|
|
622
|
-
|
|
787
|
+
# best-params table
|
|
788
|
+
try:
|
|
789
|
+
best_value = study.best_value
|
|
790
|
+
best_params = study.best_params
|
|
791
|
+
except Exception:
|
|
792
|
+
return
|
|
793
|
+
|
|
794
|
+
if best_params:
|
|
795
|
+
series = pd.Series(best_params, name="Best Value")
|
|
796
|
+
series["objective"] = best_value
|
|
797
|
+
SNPioMultiQC.queue_table(
|
|
798
|
+
df=series,
|
|
799
|
+
panel_id=f"{self.model_name}_optuna_best_params",
|
|
800
|
+
section=self.multiqc_section,
|
|
801
|
+
title=f"{self.model_name} Best Optuna Parameters",
|
|
802
|
+
index_label="Parameter",
|
|
803
|
+
description="Best Optuna hyperparameters and objective value.",
|
|
804
|
+
)
|
|
623
805
|
|
|
624
|
-
|
|
625
|
-
|
|
626
|
-
|
|
627
|
-
|
|
628
|
-
|
|
629
|
-
|
|
630
|
-
|
|
806
|
+
def _queue_multiqc_roc_curves(
|
|
807
|
+
self,
|
|
808
|
+
*,
|
|
809
|
+
fpr: dict,
|
|
810
|
+
tpr: dict,
|
|
811
|
+
label_names: Sequence[str],
|
|
812
|
+
panel_prefix: str,
|
|
813
|
+
) -> None:
|
|
814
|
+
"""Queue ROC and Precision-Recall curves for MultiQC.
|
|
631
815
|
|
|
632
|
-
|
|
633
|
-
|
|
634
|
-
|
|
635
|
-
|
|
636
|
-
|
|
637
|
-
|
|
638
|
-
|
|
639
|
-
|
|
640
|
-
|
|
641
|
-
|
|
642
|
-
|
|
643
|
-
|
|
644
|
-
|
|
645
|
-
|
|
646
|
-
|
|
647
|
-
|
|
648
|
-
|
|
649
|
-
|
|
650
|
-
|
|
651
|
-
|
|
652
|
-
|
|
653
|
-
|
|
654
|
-
|
|
655
|
-
|
|
656
|
-
|
|
657
|
-
|
|
658
|
-
|
|
659
|
-
|
|
660
|
-
|
|
661
|
-
|
|
662
|
-
|
|
663
|
-
|
|
664
|
-
|
|
665
|
-
|
|
666
|
-
|
|
667
|
-
|
|
668
|
-
|
|
669
|
-
|
|
670
|
-
|
|
671
|
-
|
|
816
|
+
Args:
|
|
817
|
+
fpr (dict): False positive rates for each class.
|
|
818
|
+
tpr (dict): True positive rates for each class.
|
|
819
|
+
label_names (Sequence[str]): Class names.
|
|
820
|
+
panel_prefix (str): Optional prefix for panel IDs.
|
|
821
|
+
"""
|
|
822
|
+
if not self._multiqc_enabled():
|
|
823
|
+
return
|
|
824
|
+
|
|
825
|
+
def _curve_to_mapping(
|
|
826
|
+
x_vals: Sequence[float], y_vals: Sequence[float]
|
|
827
|
+
) -> Dict[float, float]:
|
|
828
|
+
"""Return {x: y} mapping expected by MultiQC linegraphs."""
|
|
829
|
+
return {float(x): float(y) for x, y in zip(x_vals, y_vals)}
|
|
830
|
+
|
|
831
|
+
data: Dict[str, Dict[float, float]] = {}
|
|
832
|
+
|
|
833
|
+
# Only report the first three classes (MultiQC plot readability) plus micro/macro averages
|
|
834
|
+
class_keys = sorted(k for k in fpr.keys() if isinstance(k, int))
|
|
835
|
+
for idx in class_keys[:3]:
|
|
836
|
+
label = label_names[idx] if idx < len(label_names) else f"Class {idx}"
|
|
837
|
+
data[label] = _curve_to_mapping(fpr[idx], tpr[idx])
|
|
838
|
+
|
|
839
|
+
for agg in ("micro", "macro"):
|
|
840
|
+
if agg in fpr and agg in tpr:
|
|
841
|
+
pretty_name = f"{agg.title()} Average"
|
|
842
|
+
data[pretty_name] = _curve_to_mapping(fpr[agg], tpr[agg])
|
|
843
|
+
|
|
844
|
+
if not data:
|
|
845
|
+
return
|
|
846
|
+
|
|
847
|
+
# ROC curves
|
|
848
|
+
curve_data = cast(Dict[str, Dict[int, int]], data)
|
|
849
|
+
|
|
850
|
+
SNPioMultiQC.queue_linegraph(
|
|
851
|
+
data=curve_data,
|
|
852
|
+
panel_id=(
|
|
853
|
+
f"{self.model_name}_{panel_prefix}_roc_curves"
|
|
854
|
+
if panel_prefix
|
|
855
|
+
else f"{self.model_name}_roc_curves"
|
|
672
856
|
),
|
|
857
|
+
section=self.multiqc_section,
|
|
858
|
+
title=f"{self.model_name} ROC Curves",
|
|
859
|
+
index_label="False Positive Rate",
|
|
860
|
+
description="Multi-class ROC curves for PG-SUI predictions.",
|
|
673
861
|
)
|
|
Hunk: old lines 674-694, new lines 862-913:

```diff
- [... 7 removed lines (old 674-680); content not rendered in this view ...]
+
+    def _queue_multiqc_pr_curves(
+        self,
+        *,
+        precision: dict,
+        recall: dict,
+        label_names: Sequence[str],
+        panel_prefix: str,
+    ) -> None:
+        """Queue Precision-Recall curves for MultiQC."""
+        if not self._multiqc_enabled():
+            return
+
+        def _curve_to_mapping(
+            x_vals: Sequence[float], y_vals: Sequence[float]
+        ) -> Dict[float, float]:
+            """Return {recall: precision} mapping expected by MultiQC linegraphs."""
+            return {float(x): float(y) for x, y in zip(x_vals, y_vals)}
+
+        data: Dict[str, Dict[float, float]] = {}
+
+        # Only report the first three classes (MultiQC plot readability) plus micro/macro averages
+        class_keys = sorted(k for k in recall.keys() if isinstance(k, int))
+        for idx in class_keys[:3]:
+            if idx not in precision or idx not in recall:
+                continue
+            label = label_names[idx] if idx < len(label_names) else f"Class {idx}"
+            data[label] = _curve_to_mapping(recall[idx], precision[idx])
+
+        for agg in ("micro", "macro"):
+            if agg in precision and agg in recall:
+                pretty_name = f"{agg.title()} Average"
+                data[pretty_name] = _curve_to_mapping(recall[agg], precision[agg])
+
+        if not data:
+            return
+
+        curve_data = cast(Dict[str, Dict[int, int]], data)
+
+        SNPioMultiQC.queue_linegraph(
+            data=curve_data,
+            panel_id=(
+                f"{self.model_name}_{panel_prefix}_pr_curves"
+                if panel_prefix
+                else f"{self.model_name}_pr_curves"
             ),
- [... 4 removed lines (old 682-685); content not rendered in this view ...]
-            legend_title_font=dict(size=font_size),
-            legend_title_side="top",
-            font=dict(size=font_size),
-        )
-        fig.write_html(os.path.join(report_path, "imputed_pca.html"))
-        fig.write_image(
-            os.path.join(report_path, f"imputed_pca.{plot_format}"),
+            section=self.multiqc_section,
+            title=f"{self.model_name} Precision-Recall Curves",
+            index_label="Recall",
+            description="Multi-class Precision-Recall curves for PG-SUI predictions.",
         )
 
```
Hunk: old lines 695-746, new lines 914-991:

```diff
- [... 5 removed lines (old 695-699); content not rendered in this view ...]
+    def _queue_multiqc_metrics(
+        self,
+        *,
+        metrics: Dict[str, float],
+        roc_auc: Dict[object, float],
+        average_precision: Dict[object, float],
+        label_names: Sequence[str],
+        panel_prefix: str,
+    ) -> None:
+        """Queue summary metrics and per-class AUC/AP for MultiQC.
 
         Args:
- [... 5 removed lines (old 702-706); content not rendered in this view ...]
-            ValueError: nn_method must be either 'NLPCA', 'UBP', or 'VAE'.
+            metrics (Dict[str, float]): Summary metrics (accuracy, F1, etc.).
+            roc_auc (Dict[object, float]): Per-class and aggregate ROC-AUC values.
+            average_precision (Dict[object, float]): Per-class and aggregate average precision values.
+            label_names (Sequence[str]): Class names.
+            panel_prefix (str): Optional prefix for panel IDs.
         """
-        if
- [... 7 removed lines (old 710-716); content not rendered in this view ...]
+        if not self._multiqc_enabled():
+            return
+
+        # Summary metrics table (accuracy, F1, etc.)
+        if metrics:
+            series = pd.Series(metrics, name="Value")
+            SNPioMultiQC.queue_table(
+                df=series,
+                panel_id=f"{self.model_name}_summary_metrics",
+                section=self.multiqc_section,
+                title=f"{self.model_name} Summary Metrics",
+                index_label="Metric",
+                description="Global evaluation metrics produced by PG-SUI.",
             )
 
- [... 14 removed lines (old 719-732); content not rendered in this view ...]
+        # Per-class ROC-AUC and AP heatmap
+        rows: List[Dict[str, float | str]] = []
+
+        # integer keys are classes; others are 'micro', 'macro'
+        class_keys = [k for k in roc_auc.keys() if isinstance(k, int)]
+        class_keys_sorted = sorted(class_keys)
+
+        for i in class_keys_sorted:
+            class_name = label_names[i] if i < len(label_names) else f"Class {i}"
+            rows.append(
+                {
+                    "Class": str(class_name),
+                    "ROC_AUC": float(roc_auc.get(i, np.nan)),
+                    "AveragePrecision": float(average_precision.get(i, np.nan)),
+                }
             )
 
- [... 7 removed lines (old 735-741); content not rendered in this view ...]
+        for agg in ("micro", "macro"):
+            if agg in roc_auc:
+                rows.append(
+                    {
+                        "Class": agg,
+                        "ROC_AUC": float(roc_auc.get(agg, np.nan)),
+                        "AveragePrecision": float(average_precision.get(agg, np.nan)),
+                    }
+                )
 
- [... 3 removed lines (old 743-745); content not rendered in this view ...]
+        if not rows:
+            return
+
+        df = pd.DataFrame(rows).set_index("Class")
+        suffix = f"{panel_prefix}_" if panel_prefix else ""
+        panel_id = f"{self.model_name}_{suffix}roc_pr_summary"
+
+        SNPioMultiQC.queue_heatmap(
+            df=df,
+            panel_id=panel_id,
+            section=self.multiqc_section,
+            title=f"{self.model_name} ROC-AUC and Average Precision",
+            index_label="Class",
+            description=(
+                "Per-class ROC-AUC and average precision for PG-SUI predictions (including micro/macro averages where available)."
+            ),
+        )
 
```
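As a rough illustration of the table the heatmap panel receives, the sketch below builds the same rows structure from hypothetical per-class scores (the real values would come from ROC-AUC and average-precision calculations upstream; none of these numbers are from the package):

```python
import numpy as np
import pandas as pd

# Hypothetical scores keyed the same way as in the diff:
# integer keys are classes, string keys are micro/macro averages.
roc_auc = {0: 0.94, 1: 0.88, 2: 0.91, "micro": 0.92, "macro": 0.91}
average_precision = {0: 0.90, 1: 0.81, 2: 0.86, "micro": 0.87}
label_names = ["REF", "HET", "ALT"]

rows = []
for i in sorted(k for k in roc_auc if isinstance(k, int)):
    rows.append(
        {
            "Class": label_names[i] if i < len(label_names) else f"Class {i}",
            "ROC_AUC": float(roc_auc.get(i, np.nan)),
            "AveragePrecision": float(average_precision.get(i, np.nan)),
        }
    )
for agg in ("micro", "macro"):
    if agg in roc_auc:
        rows.append(
            {
                "Class": agg,
                "ROC_AUC": float(roc_auc.get(agg, np.nan)),
                "AveragePrecision": float(average_precision.get(agg, np.nan)),
            }
        )

df = pd.DataFrame(rows).set_index("Class")
print(df)  # one row per class plus micro/macro; missing aggregates fall back to NaN
```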
Hunk: old lines 747-875, new lines 992-1043:

```diff
- [... 6 removed lines (old 747-752); content not rendered in this view ...]
-        ax1.set_yticks([0.0, 0.25, 0.5, 0.75, 1.0])
-
-        labels = ["Train"]
-        if nn_method == "SAE":
-            # Plot validation accuracy
-            ax1.plot(history[accval])
-            labels.append("Validation")
-
-        ax1.legend(labels, loc="best")
-
-        # Plot model loss
-        # if nn_method == "VAE":
-        # # Reconstruction loss only.
-        # ax2.plot(history["loss"])
-        # ax2.plot(history[val_recon_loss])
-
-        # # KL Loss
-        # ax3.plot(history[kl_loss])
-        # ax3.plot(history[val_kl_loss])
-        # ax3.set_title("KL Divergence Loss")
-        # ax3.set_ylabel("Loss")
-        # ax3.set_xlabel("Epoch")
-        # ax3.legend(labels, loc="best")
-
-        # Total Loss (Reconstruction Loss + KL Loss)
-        # ax4.plot(history["loss"])
-        # ax4.plot(history[lossval])
-        # ax4.set_title("Total Loss (Recon. + KL)")
-        # ax4.set_ylabel("Loss")
-        # ax4.set_xlabel("Epoch")
-        # ax4.legend(labels, loc="best")
-
-        # else:
-        ax2.plot(history["loss"])
-
-        if nn_method == "SAE":
-            ax2.plot(history[lossval])
-
-        ax2.set_title("Total Loss")
-        ax2.set_ylabel("Loss")
-        ax2.set_xlabel("Epoch")
-        ax2.legend(labels, loc="best")
-
-        fig.savefig(fn, bbox_inches="tight", facecolor="white")
+    def _queue_multiqc_history(
+        self,
+        *,
+        history: Dict[str, List[float] | Dict[str, List[float]] | None] | None,
+    ) -> None:
+        """Queue training history (loss vs epoch) for MultiQC.
 
- [... 5 removed lines (old 798-802); content not rendered in this view ...]
-        fig.suptitle(nn_method)
-        fig.tight_layout(h_pad=2.0, w_pad=2.0)
-        fn = os.path.join(
-            f"{prefix}_output",
-            "plots",
-            "Unsupervised",
-            nn_method,
-            "histplot.pdf",
-        )
+        Args:
+            history (Dict[str, List[float]] | None): Dictionary with lists of history objects. Keys should be "Train" and "Validation".
+        """
+        if not self._multiqc_enabled() or history is None:
+            return
 
-
-        for i, history in enumerate(lod, start=1):
-            plt.subplot(3, 2, idx)
-            title = f"Phase {i}"
-
-            # Plot model accuracy
-            ax = plt.gca()
-            ax.plot(history["categorical_accuracy"])
-            ax.set_title(f"{title} Accuracy")
-            ax.set_ylabel("Accuracy")
-            ax.set_xlabel("Epoch")
-            ax.set_yticks([0.0, 0.25, 0.5, 0.75, 1.0])
-            ax.legend(["Training"], loc="best")
-
-            # Plot model loss
-            plt.subplot(3, 2, idx + 1)
-            ax = plt.gca()
-            ax.plot(history["loss"])
-            ax.set_title(f"{title} Loss")
-            ax.set_ylabel("Loss (MSE)")
-            ax.set_xlabel("Epoch")
-            ax.legend(["Train"], loc="best")
-
-            idx += 2
-
-        plt.savefig(fn, bbox_inches="tight", facecolor="white")
+        data: Dict[str, Dict[int, int]] = {}
 
-
-
+        if self.model_name != "ImputeUBP":
+            if not isinstance(history, dict) or "Train" not in history:
+                return
 
-
-        raise ValueError(
-            f"nn_method must be either 'NLPCA', 'UBP', or 'VAE', but got {nn_method}"
-        )
+            train_vals = pd.Series(history["Train"]).iloc[1:]
 
- [... 26 removed lines (old 848-873); content not rendered in this view ...]
+            data["Train"] = {
+                epoch: val for epoch, val in enumerate(train_vals.values, start=1)
+            }
+        else:
+            if not (
+                isinstance(history, dict)
+                and "Train" in history
+                and isinstance(history["Train"], dict)
+            ):
+                return
+            for phase in range(1, 4):
+                key = f"Phase {phase}"
+                if key not in history["Train"]:
+                    continue
+                series = pd.Series(history["Train"][key]).iloc[1:]
+                data[key] = {
+                    epoch: val for epoch, val in enumerate(series.values, start=1)
+                }
+
+        if not data:
+            return
+
+        SNPioMultiQC.queue_linegraph(
+            data=data,
+            panel_id=f"{self.model_name}_training_history",
+            section=self.multiqc_section,
+            title=f"{self.model_name} Training Loss per Epoch",
+            index_label="Epoch",
+            description="Training loss trajectory by epoch as recorded by PG-SUI.",
         )
 
```
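The new method accepts two history layouts: a flat `{"Train": [losses]}` for most models, and per-phase lists under `history["Train"]` for ImputeUBP. A small sketch with made-up loss values showing how each layout is flattened into `{epoch: loss}` (the first recorded value is dropped, mirroring the `.iloc[1:]` in the diff; the `float()` cast is only for clean printing here):

```python
import pandas as pd


def flatten(values):
    # Drop the first entry, then re-key by 1-based epoch, as the diff does.
    series = pd.Series(values).iloc[1:]
    return {epoch: float(val) for epoch, val in enumerate(series.values, start=1)}


# Non-UBP layout: a flat list of per-epoch training losses (hypothetical values).
flat_history = {"Train": [1.20, 0.95, 0.80, 0.72]}
print(flatten(flat_history["Train"]))  # {1: 0.95, 2: 0.8, 3: 0.72}

# ImputeUBP layout: one loss trajectory per training phase (hypothetical values).
ubp_history = {"Train": {"Phase 1": [1.5, 1.1], "Phase 2": [1.0, 0.7], "Phase 3": [0.9, 0.6]}}
data = {key: flatten(vals) for key, vals in ubp_history["Train"].items()}
print(data["Phase 2"])  # {1: 0.7}
```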
Hunk: old lines 876-892, new lines 1044-1079:

```diff
- [... 8 removed lines (old 876-883); content not rendered in this view ...]
+    def _queue_multiqc_confusion(
+        self,
+        *,
+        y_true: np.ndarray,
+        y_pred: np.ndarray,
+        labels: np.ndarray,
+        display_labels: List[str] | np.ndarray,
+        panel_id: str,
+    ) -> None:
+        """Queue confusion-matrix heatmap for MultiQC.
 
- [... 6 removed lines (old 885-890); content not rendered in this view ...]
+        Args:
+            y_true (np.ndarray): 1D array of true integer labels.
+            y_pred (np.ndarray): 1D array of predicted integer labels.
+            labels (np.ndarray): Array of label indices to index the confusion matrix.
+            display_labels (List[str] | np.ndarray): Labels to display on axes.
+            panel_id (str): Panel ID for MultiQC.
+        """
+        if not self._multiqc_enabled():
+            return
+
+        cm = confusion_matrix(y_true, y_pred, labels=labels)
+        df_cm = pd.DataFrame(cm, index=display_labels, columns=display_labels)
+
+        SNPioMultiQC.queue_heatmap(
+            df=df_cm,
+            panel_id=panel_id,
+            section=self.multiqc_section,
+            title=f"{self.model_name} Confusion Matrix",
+            index_label="True Label",
+            description=(
+                "Confusion matrix for PG-SUI predictions. Rows correspond to true "
+                "labels; columns correspond to predicted labels."
+            ),
         )
 
```
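A minimal, self-contained sketch of the matrix-to-DataFrame step used above, with hypothetical genotype classes and calls (the real method receives these arrays from the evaluation pipeline):

```python
import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix

# Hypothetical true and predicted integer genotype classes.
y_true = np.array([0, 1, 2, 1, 0, 2, 2])
y_pred = np.array([0, 1, 1, 1, 0, 2, 2])
labels = np.array([0, 1, 2])
display_labels = ["REF", "HET", "ALT"]

# Same two lines as in the diff: count label pairs, then attach readable axis labels.
cm = confusion_matrix(y_true, y_pred, labels=labels)
df_cm = pd.DataFrame(cm, index=display_labels, columns=display_labels)

print(df_cm)  # rows = true labels, columns = predicted labels
```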
Hunk: old lines 893-949, new lines 1080-1116:

```diff
- [... 7 removed lines (old 893-899); content not rendered in this view ...]
-        df = misc.validate_input_type(df, return_type="df")
-        df_melt = pd.melt(df, value_name="Count")
-        cnts = df_melt["Count"].value_counts()
-        cnts.index.names = ["Genotype"]
-        cnts = pd.DataFrame(cnts).reset_index()
-        cnts.sort_values(by="Genotype", inplace=True)
-        cnts["Genotype"] = cnts["Genotype"].astype(str)
-
-        fig, ax = plt.subplots(1, 1, figsize=(15, 15))
-        g = sns.barplot(x="Genotype", y="Count", data=cnts, ax=ax)
-        g.set_xlabel("Integer-encoded Genotype")
-        g.set_ylabel("Count")
-        g.set_title("Genotype Counts")
-        for p in g.patches:
-            g.annotate(
-                f"{p.get_height():.1f}",
-                (p.get_x() + 0.25, p.get_height() + 0.01),
-                xytext=(0, 1),
-                textcoords="offset points",
-                va="bottom",
-            )
+    def _queue_multiqc_gt_distribution(
+        self,
+        *,
+        df: pd.DataFrame,
+        is_imputed: bool,
+    ) -> None:
+        """Queue genotype-distribution barplot for MultiQC.
 
- [... 4 removed lines (old 922-925); content not rendered in this view ...]
+        Args:
+            df (pd.DataFrame): DataFrame with 'Genotype' and 'Percent' columns
+            is_imputed (bool): Whether these genotypes are imputed.
+        """
+        if not self._multiqc_enabled():
+            return
+
+        if "Genotype" not in df.columns or "Percent" not in df.columns:
+            return
+
+        series = df.set_index("Genotype")["Percent"]
+        suffix = "imputed" if is_imputed else "original"
+        title = (
+            f"{self.model_name} Imputed Genotype Distribution"
+            if is_imputed
+            else f"{self.model_name} Genotype Distribution"
         )
-        plt.close()
-
-    @staticmethod
-    def plot_label_clusters(z_mean, labels, prefix="imputer"):
-        """Display a 2D plot of the classes in the latent space."""
-        fig, ax = plt.subplots(1, 1, figsize=(15, 15))
-
-        sns.scatterplot(x=z_mean[:, 0], y=z_mean[:, 1], ax=ax)
-        ax.set_xlabel("Latent Dimension 1")
-        ax.set_ylabel("Latent Dimension 2")
 
-
-
-        "
-
-
-        "
+        SNPioMultiQC.queue_barplot(
+            df=series,
+            panel_id=f"{self.model_name}_gt_distribution_{suffix}",
+            section=self.multiqc_section,
+            title=title,
+            index_label="Genotype",
+            value_label="Percent",
+            description=(
+                "Genotype frequency distribution (percent of calls per genotype) computed by PG-SUI."
+            ),
+        )
 
-
-        if os.path.isfile(outfile):
-            os.remove(outfile)
-
-        fig.savefig(outfile, facecolor="white", bbox_inches="tight")
```