pg-sui 1.0.2.1__py3-none-any.whl → 1.6.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pg-sui might be problematic. Click here for more details.
- {pg_sui-1.0.2.1.dist-info → pg_sui-1.6.8.dist-info}/METADATA +51 -70
- pg_sui-1.6.8.dist-info/RECORD +78 -0
- {pg_sui-1.0.2.1.dist-info → pg_sui-1.6.8.dist-info}/WHEEL +1 -1
- pg_sui-1.6.8.dist-info/entry_points.txt +4 -0
- pg_sui-1.6.8.dist-info/top_level.txt +1 -0
- pgsui/__init__.py +35 -54
- pgsui/_version.py +34 -0
- pgsui/cli.py +635 -0
- pgsui/data_processing/config.py +576 -0
- pgsui/data_processing/containers.py +1782 -0
- pgsui/data_processing/transformers.py +121 -1103
- pgsui/electron/app/__main__.py +5 -0
- pgsui/electron/app/icons/icons/1024x1024.png +0 -0
- pgsui/electron/app/icons/icons/128x128.png +0 -0
- pgsui/electron/app/icons/icons/16x16.png +0 -0
- pgsui/electron/app/icons/icons/24x24.png +0 -0
- pgsui/electron/app/icons/icons/256x256.png +0 -0
- pgsui/electron/app/icons/icons/32x32.png +0 -0
- pgsui/electron/app/icons/icons/48x48.png +0 -0
- pgsui/electron/app/icons/icons/512x512.png +0 -0
- pgsui/electron/app/icons/icons/64x64.png +0 -0
- pgsui/electron/app/icons/icons/icon.icns +0 -0
- pgsui/electron/app/icons/icons/icon.ico +0 -0
- pgsui/electron/app/main.js +189 -0
- pgsui/electron/app/package-lock.json +6893 -0
- pgsui/electron/app/package.json +50 -0
- pgsui/electron/app/preload.js +15 -0
- pgsui/electron/app/server.py +146 -0
- pgsui/electron/app/ui/logo.png +0 -0
- pgsui/electron/app/ui/renderer.js +130 -0
- pgsui/electron/app/ui/styles.css +59 -0
- pgsui/electron/app/ui/ui_shim.js +72 -0
- pgsui/electron/bootstrap.py +43 -0
- pgsui/electron/launch.py +59 -0
- pgsui/electron/package.json +14 -0
- pgsui/example_data/popmaps/{test.popmap → phylogen_nomx.popmap} +185 -99
- pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz +0 -0
- pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz.tbi +0 -0
- pgsui/impute/deterministic/imputers/allele_freq.py +691 -0
- pgsui/impute/deterministic/imputers/mode.py +679 -0
- pgsui/impute/deterministic/imputers/nmf.py +221 -0
- pgsui/impute/deterministic/imputers/phylo.py +971 -0
- pgsui/impute/deterministic/imputers/ref_allele.py +530 -0
- pgsui/impute/supervised/base.py +339 -0
- pgsui/impute/supervised/imputers/hist_gradient_boosting.py +293 -0
- pgsui/impute/supervised/imputers/random_forest.py +287 -0
- pgsui/impute/unsupervised/base.py +924 -0
- pgsui/impute/unsupervised/callbacks.py +89 -263
- pgsui/impute/unsupervised/imputers/autoencoder.py +972 -0
- pgsui/impute/unsupervised/imputers/nlpca.py +1264 -0
- pgsui/impute/unsupervised/imputers/ubp.py +1288 -0
- pgsui/impute/unsupervised/imputers/vae.py +957 -0
- pgsui/impute/unsupervised/loss_functions.py +158 -0
- pgsui/impute/unsupervised/models/autoencoder_model.py +208 -558
- pgsui/impute/unsupervised/models/nlpca_model.py +149 -468
- pgsui/impute/unsupervised/models/ubp_model.py +198 -1317
- pgsui/impute/unsupervised/models/vae_model.py +259 -618
- pgsui/impute/unsupervised/nn_scorers.py +215 -0
- pgsui/utils/classification_viz.py +591 -0
- pgsui/utils/misc.py +35 -480
- pgsui/utils/plotting.py +514 -824
- pgsui/utils/scorers.py +212 -438
- pg_sui-1.0.2.1.dist-info/RECORD +0 -75
- pg_sui-1.0.2.1.dist-info/top_level.txt +0 -3
- pgsui/example_data/phylip_files/test_n10.phy +0 -118
- pgsui/example_data/phylip_files/test_n100.phy +0 -118
- pgsui/example_data/phylip_files/test_n2.phy +0 -118
- pgsui/example_data/phylip_files/test_n500.phy +0 -118
- pgsui/example_data/structure_files/test.nopops.1row.10sites.str +0 -117
- pgsui/example_data/structure_files/test.nopops.2row.100sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.10sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.30sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.allsites.str +0 -234
- pgsui/example_data/structure_files/test.pops.1row.10sites.str +0 -117
- pgsui/example_data/structure_files/test.pops.2row.10sites.str +0 -234
- pgsui/example_data/trees/test.iqtree +0 -376
- pgsui/example_data/trees/test.qmat +0 -5
- pgsui/example_data/trees/test.rate +0 -2033
- pgsui/example_data/trees/test.tre +0 -1
- pgsui/example_data/trees/test_n10.rate +0 -19
- pgsui/example_data/trees/test_n100.rate +0 -109
- pgsui/example_data/trees/test_n500.rate +0 -509
- pgsui/example_data/trees/test_siterates.txt +0 -2024
- pgsui/example_data/trees/test_siterates_n10.txt +0 -10
- pgsui/example_data/trees/test_siterates_n100.txt +0 -100
- pgsui/example_data/trees/test_siterates_n500.txt +0 -500
- pgsui/example_data/vcf_files/test.vcf +0 -244
- pgsui/example_data/vcf_files/test.vcf.gz +0 -0
- pgsui/example_data/vcf_files/test.vcf.gz.tbi +0 -0
- pgsui/impute/estimators.py +0 -735
- pgsui/impute/impute.py +0 -1486
- pgsui/impute/simple_imputers.py +0 -1439
- pgsui/impute/supervised/iterative_imputer_fixedparams.py +0 -785
- pgsui/impute/supervised/iterative_imputer_gridsearch.py +0 -1027
- pgsui/impute/unsupervised/keras_classifiers.py +0 -702
- pgsui/impute/unsupervised/models/in_development/cnn_model.py +0 -486
- pgsui/impute/unsupervised/neural_network_imputers.py +0 -1424
- pgsui/impute/unsupervised/neural_network_methods.py +0 -1549
- pgsui/pg_sui.py +0 -261
- pgsui/utils/sequence_tools.py +0 -407
- simulation/sim_benchmarks.py +0 -333
- simulation/sim_treeparams.py +0 -475
- test/__init__.py +0 -0
- test/pg_sui_simtest.py +0 -215
- test/pg_sui_testing.py +0 -523
- test/test.py +0 -297
- test/test_pgsui.py +0 -374
- test/test_tkc.py +0 -214
- {pg_sui-1.0.2.1.dist-info → pg_sui-1.6.8.dist-info/licenses}/LICENSE +0 -0
- /pgsui/{example_data/trees → electron/app}/__init__.py +0 -0
- /pgsui/impute/{unsupervised/models/in_development → supervised/imputers}/__init__.py +0 -0
- {simulation → pgsui/impute/unsupervised/imputers}/__init__.py +0 -0
pgsui/utils/plotting.py
CHANGED
|
@@ -1,920 +1,610 @@
|
|
|
1
|
-
import
|
|
2
|
-
import sys
|
|
3
|
-
from itertools import cycle
|
|
4
|
-
from pathlib import Path
|
|
1
|
+
import logging
|
|
5
2
|
import warnings
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Dict, List, Literal, Optional, Sequence
|
|
6
5
|
|
|
7
|
-
|
|
6
|
+
import matplotlib as mpl
|
|
8
7
|
|
|
8
|
+
# Use Agg backend for headless plotting
|
|
9
|
+
mpl.use("Agg")
|
|
9
10
|
|
|
11
|
+
import matplotlib.pyplot as plt
|
|
10
12
|
import numpy as np
|
|
13
|
+
import optuna
|
|
11
14
|
import pandas as pd
|
|
12
|
-
import matplotlib.pyplot as plt
|
|
13
15
|
import seaborn as sns
|
|
14
|
-
import
|
|
15
|
-
|
|
16
|
-
from sklearn.
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
16
|
+
import torch
|
|
17
|
+
from optuna.exceptions import ExperimentalWarning
|
|
18
|
+
from sklearn.metrics import (
|
|
19
|
+
ConfusionMatrixDisplay,
|
|
20
|
+
auc,
|
|
21
|
+
average_precision_score,
|
|
22
|
+
precision_recall_curve,
|
|
23
|
+
roc_curve,
|
|
24
|
+
)
|
|
25
|
+
from sklearn.preprocessing import label_binarize
|
|
26
|
+
from snpio.utils.logging import LoggerManager
|
|
27
|
+
|
|
28
|
+
from pgsui.utils import misc
|
|
29
|
+
|
|
30
|
+
# Quiet Matplotlib/fontTools INFO logging when saving PDF/SVG
|
|
31
|
+
for name in (
|
|
32
|
+
"fontTools",
|
|
33
|
+
"fontTools.subset",
|
|
34
|
+
"fontTools.ttLib",
|
|
35
|
+
"matplotlib.font_manager",
|
|
36
|
+
):
|
|
37
|
+
lg = logging.getLogger(name)
|
|
38
|
+
lg.setLevel(logging.WARNING)
|
|
39
|
+
lg.propagate = False
|
|
25
40
|
|
|
26
41
|
|
|
27
42
|
class Plotting:
|
|
28
|
-
"""
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
43
|
+
"""Class for plotting imputer scoring and results.
|
|
44
|
+
|
|
45
|
+
This class is used to plot the performance metrics of imputation models. It can plot ROC and Precision-Recall curves, model history, and the distribution of genotypes in the dataset.
|
|
46
|
+
|
|
47
|
+
Example:
|
|
48
|
+
>>> from pgsui import Plotting
|
|
49
|
+
>>> plotter = Plotting(model_name="ImputeVAE")
|
|
50
|
+
>>> plotter.plot_metrics(metrics, num_classes)
|
|
51
|
+
>>> plotter.plot_history(history)
|
|
52
|
+
>>> plotter.plot_certainty_heatmap(y_certainty)
|
|
53
|
+
>>> plotter.plot_confusion_matrix(y_true_1d, y_pred_1d)
|
|
54
|
+
>>> plotter.plot_gt_distribution(df)
|
|
55
|
+
>>> plotter.plot_label_clusters(z_mean, z_log_var)
|
|
56
|
+
|
|
57
|
+
Attributes:
|
|
58
|
+
model_name (str): Name of the model.
|
|
59
|
+
prefix (str): Prefix for the output directory.
|
|
60
|
+
plot_format (Literal["pdf", "png", "jpeg", "jpg"]): Format for the plots ('pdf', 'png', 'jpeg', 'jpg').
|
|
61
|
+
plot_fontsize (int): Font size for the plots.
|
|
62
|
+
plot_dpi (int): Dots per inch for the plots.
|
|
63
|
+
title_fontsize (int): Font size for the plot titles.
|
|
64
|
+
show_plots (bool): Whether to display the plots.
|
|
65
|
+
output_dir (Path): Directory where plots will be saved.
|
|
66
|
+
logger (logging.Logger): Logger instance for logging messages.
|
|
67
|
+
"""
|
|
68
|
+
|
|
69
|
+
def __init__(
|
|
70
|
+
self,
|
|
71
|
+
model_name: str,
|
|
72
|
+
*,
|
|
73
|
+
prefix: str = "pgsui",
|
|
74
|
+
plot_format: Literal["pdf", "png", "jpeg", "jpg"] = "pdf",
|
|
75
|
+
plot_fontsize: int = 18,
|
|
76
|
+
plot_dpi: int = 300,
|
|
77
|
+
title_fontsize: int = 20,
|
|
78
|
+
despine: bool = True,
|
|
79
|
+
show_plots: bool = False,
|
|
80
|
+
verbose: int = 0,
|
|
81
|
+
debug: bool = False,
|
|
82
|
+
) -> None:
|
|
83
|
+
"""Initialize the Plotting object.
|
|
84
|
+
|
|
85
|
+
This class is used to plot the performance metrics of imputation models. It can plot ROC and Precision-Recall curves, model history, and the distribution of genotypes in the dataset.
|
|
35
86
|
|
|
36
87
|
Args:
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
88
|
+
model_name (str): Name of the model.
|
|
89
|
+
prefix (str, optional): Prefix for the output directory. Defaults to 'pgsui'.
|
|
90
|
+
plot_format (Literal["pdf", "png", "jpeg", "jpg"]): Format for the plots ('pdf', 'png', 'jpeg', 'jpg'). Defaults to 'pdf'.
|
|
91
|
+
plot_fontsize (int): Font size for the plots. Defaults to 18.
|
|
92
|
+
plot_dpi (int): Dots per inch for the plots. Defaults to 300.
|
|
93
|
+
title_fontsize (int): Font size for the plot titles. Defaults to 20.
|
|
94
|
+
despine (bool): Whether to remove the top and right spines from the plots. Defaults to True.
|
|
95
|
+
show_plots (bool): Whether to display the plots. Defaults to False.
|
|
96
|
+
verbose (int): Verbosity level for logging. Defaults to 0.
|
|
97
|
+
debug (bool): Whether to enable debug mode. Defaults to False.
|
|
42
98
|
"""
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
means_test = [col for col in results if col.startswith("mean_test_")]
|
|
46
|
-
filter_col = [col for col in results if col.startswith("param_")]
|
|
47
|
-
params_df = results[filter_col].astype(str)
|
|
48
|
-
for i, col in enumerate(means_test):
|
|
49
|
-
params_df[col] = results[means_test[i]]
|
|
50
|
-
|
|
51
|
-
# Get number of needed subplot rows.
|
|
52
|
-
tot = len(filter_col)
|
|
53
|
-
cols = 4
|
|
54
|
-
rows = int(np.ceil(tot / cols))
|
|
55
|
-
|
|
56
|
-
fig = plt.figure(1, figsize=(20, 10))
|
|
57
|
-
fig.tight_layout(pad=3.0)
|
|
58
|
-
|
|
59
|
-
# Set font properties.
|
|
60
|
-
font = {"size": 12}
|
|
61
|
-
plt.rc("font", **font)
|
|
62
|
-
|
|
63
|
-
for i, p in enumerate(filter_col, start=1):
|
|
64
|
-
ax = fig.add_subplot(rows, cols, i)
|
|
65
|
-
|
|
66
|
-
# Plot each metric.
|
|
67
|
-
for col in means_test:
|
|
68
|
-
# Get maximum score for each parameter setting.
|
|
69
|
-
df_plot = params_df.groupby(p)[col].agg("max")
|
|
70
|
-
|
|
71
|
-
# Convert to float if not supposed to be string.
|
|
72
|
-
try:
|
|
73
|
-
df_plot.index = df_plot.index.astype(float)
|
|
74
|
-
except TypeError:
|
|
75
|
-
pass
|
|
76
|
-
|
|
77
|
-
# Sort by index (numerically if possible).
|
|
78
|
-
df_plot = df_plot.sort_index()
|
|
79
|
-
|
|
80
|
-
# Remove prefix from score name.
|
|
81
|
-
col_new_name = col[len("mean_test_") :]
|
|
82
|
-
|
|
83
|
-
ax.plot(
|
|
84
|
-
df_plot.index.astype(str),
|
|
85
|
-
df_plot.values,
|
|
86
|
-
"-o",
|
|
87
|
-
label=col_new_name,
|
|
88
|
-
)
|
|
89
|
-
|
|
90
|
-
ax.legend(loc="best")
|
|
91
|
-
|
|
92
|
-
param_new_name = p[len("param_") :]
|
|
93
|
-
ax.set_xlabel(param_new_name.lower())
|
|
94
|
-
ax.set_ylabel("Max Score")
|
|
95
|
-
ax.set_ylim([0, 1])
|
|
96
|
-
|
|
97
|
-
fig.savefig(
|
|
98
|
-
os.path.join(
|
|
99
|
-
f"{prefix}_output",
|
|
100
|
-
"plots",
|
|
101
|
-
"Unsupervised",
|
|
102
|
-
nn_method,
|
|
103
|
-
"gridsearch_metrics.pdf",
|
|
104
|
-
),
|
|
105
|
-
bbox_inches="tight",
|
|
106
|
-
facecolor="white",
|
|
99
|
+
logman = LoggerManager(
|
|
100
|
+
name=__name__, prefix=prefix, verbose=verbose, debug=debug
|
|
107
101
|
)
|
|
102
|
+
self.logger = logman.get_logger()
|
|
103
|
+
|
|
104
|
+
self.model_name = model_name
|
|
105
|
+
self.prefix = prefix
|
|
106
|
+
self.plot_format = plot_format
|
|
107
|
+
self.plot_fontsize = plot_fontsize
|
|
108
|
+
self.plot_dpi = plot_dpi
|
|
109
|
+
self.title_fontsize = title_fontsize
|
|
110
|
+
self.show_plots = show_plots
|
|
111
|
+
|
|
112
|
+
if self.plot_format.startswith("."):
|
|
113
|
+
self.plot_format = self.plot_format.lstrip(".")
|
|
114
|
+
|
|
115
|
+
self.param_dict = {
|
|
116
|
+
"axes.labelsize": self.plot_fontsize,
|
|
117
|
+
"axes.titlesize": self.title_fontsize,
|
|
118
|
+
"axes.spines.top": despine,
|
|
119
|
+
"axes.spines.right": despine,
|
|
120
|
+
"xtick.labelsize": self.plot_fontsize,
|
|
121
|
+
"ytick.labelsize": self.plot_fontsize,
|
|
122
|
+
"legend.fontsize": self.plot_fontsize,
|
|
123
|
+
"legend.facecolor": "white",
|
|
124
|
+
"figure.titlesize": self.title_fontsize,
|
|
125
|
+
"figure.dpi": self.plot_dpi,
|
|
126
|
+
"figure.facecolor": "white",
|
|
127
|
+
"axes.linewidth": 2.0,
|
|
128
|
+
"lines.linewidth": 2.0,
|
|
129
|
+
"font.size": self.plot_fontsize,
|
|
130
|
+
"savefig.bbox": "tight",
|
|
131
|
+
"savefig.facecolor": "white",
|
|
132
|
+
"savefig.dpi": self.plot_dpi,
|
|
133
|
+
}
|
|
108
134
|
|
|
109
|
-
|
|
110
|
-
def plot_metrics(metrics, num_classes, prefix, nn_method):
|
|
111
|
-
"""Plot AUC-ROC and Precision-Recall performance metrics for neural network classifier.
|
|
135
|
+
mpl.rcParams.update(self.param_dict)
|
|
112
136
|
|
|
113
|
-
|
|
137
|
+
unsuper = {"ImputeVAE", "ImputeNLPCA", "ImputeAutoencoder", "ImputeUBP"}
|
|
138
|
+
det = {
|
|
139
|
+
"ImputeRefAllele",
|
|
140
|
+
"ImputeMostFrequent",
|
|
141
|
+
"ImputeMostFrequentPerPop",
|
|
142
|
+
"ImputePhylo",
|
|
143
|
+
}
|
|
144
|
+
sup = {"ImputeRandomForest", "ImputeHistGradientBoosting"}
|
|
145
|
+
|
|
146
|
+
if model_name in unsuper:
|
|
147
|
+
plot_dir = "Unsupervised"
|
|
148
|
+
elif model_name in det:
|
|
149
|
+
plot_dir = "Deterministic"
|
|
150
|
+
elif model_name in sup:
|
|
151
|
+
plot_dir = "Supervised"
|
|
152
|
+
else:
|
|
153
|
+
msg = f"model_name '{model_name}' not recognized."
|
|
154
|
+
self.logger.error(msg)
|
|
155
|
+
raise ValueError(msg)
|
|
114
156
|
|
|
115
|
-
|
|
116
|
-
|
|
157
|
+
self.output_dir = Path(f"{self.prefix}_output", plot_dir)
|
|
158
|
+
self.output_dir = self.output_dir / "plots" / model_name
|
|
159
|
+
self.output_dir.mkdir(parents=True, exist_ok=True)
|
|
117
160
|
|
|
118
|
-
|
|
161
|
+
def plot_tuning(
|
|
162
|
+
self,
|
|
163
|
+
study: optuna.study.Study,
|
|
164
|
+
model_name: str,
|
|
165
|
+
target_name: str = "Objective Value",
|
|
166
|
+
) -> None:
|
|
167
|
+
"""Plot the optimization history of a study.
|
|
119
168
|
|
|
120
|
-
|
|
169
|
+
This method plots the optimization history of a study. The plot is saved to disk as a ``<plot_format>`` file.
|
|
121
170
|
|
|
122
|
-
|
|
171
|
+
Args:
|
|
172
|
+
study (optuna.study.Study): Optuna study object.
|
|
173
|
+
model_name (str): Name of the model.
|
|
174
|
+
target_name (str, optional): Name of the target value. Defaults to 'Objective Value'.
|
|
123
175
|
"""
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
fn = os.path.join(
|
|
129
|
-
f"{prefix}_output",
|
|
130
|
-
"plots",
|
|
131
|
-
"Unsupervised",
|
|
132
|
-
nn_method,
|
|
133
|
-
f"auc_pr_curves.pdf",
|
|
134
|
-
)
|
|
135
|
-
fig = plt.figure(figsize=(20, 10))
|
|
176
|
+
with warnings.catch_warnings():
|
|
177
|
+
warnings.filterwarnings("ignore", category=UserWarning)
|
|
178
|
+
warnings.filterwarnings("ignore", category=FutureWarning)
|
|
179
|
+
warnings.filterwarnings("ignore", category=ExperimentalWarning)
|
|
136
180
|
|
|
137
|
-
|
|
138
|
-
|
|
181
|
+
od = self.output_dir / "optimize"
|
|
182
|
+
target_name = target_name.title()
|
|
139
183
|
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
)
|
|
143
|
-
axs = fig.subplots(nrows=1, ncols=2)
|
|
144
|
-
plt.subplots_adjust(hspace=0.5)
|
|
145
|
-
|
|
146
|
-
# Line weight
|
|
147
|
-
lw = 2
|
|
148
|
-
|
|
149
|
-
roc_auc = metrics["roc_auc"]
|
|
150
|
-
pr_ap = metrics["precision_recall"]
|
|
151
|
-
|
|
152
|
-
metric_list = [roc_auc, pr_ap]
|
|
153
|
-
|
|
154
|
-
for metric, ax in zip(metric_list, axs):
|
|
155
|
-
if "fpr_micro" in metric:
|
|
156
|
-
prefix1 = "fpr"
|
|
157
|
-
prefix2 = "tpr"
|
|
158
|
-
lab1 = "ROC"
|
|
159
|
-
lab2 = "AUC"
|
|
160
|
-
xlab = "False Positive Rate"
|
|
161
|
-
ylab = "True Positive Rate"
|
|
162
|
-
title = "Receiver Operating Characteristic (ROC)"
|
|
163
|
-
baseline = [0, 1]
|
|
164
|
-
|
|
165
|
-
elif "recall_micro" in metric:
|
|
166
|
-
prefix1 = "recall"
|
|
167
|
-
prefix2 = "precision"
|
|
168
|
-
lab1 = "Precision-Recall"
|
|
169
|
-
lab2 = "AP"
|
|
170
|
-
xlab = "Recall"
|
|
171
|
-
ylab = "Precision"
|
|
172
|
-
title = "Precision-Recall"
|
|
173
|
-
baseline = [metric["baseline"], metric["baseline"]]
|
|
174
|
-
|
|
175
|
-
# Plot iso-f1 curves.
|
|
176
|
-
f_scores = np.linspace(0.2, 0.8, num=4)
|
|
177
|
-
for i, f_score in enumerate(f_scores):
|
|
178
|
-
x = np.linspace(0.01, 1)
|
|
179
|
-
y = f_score * x / (2 * x - f_score)
|
|
180
|
-
ax.plot(
|
|
181
|
-
x[y >= 0],
|
|
182
|
-
y[y >= 0],
|
|
183
|
-
color="gray",
|
|
184
|
-
alpha=0.2,
|
|
185
|
-
linewidth=lw,
|
|
186
|
-
label="Iso-F1 Curves" if i == 0 else "",
|
|
187
|
-
)
|
|
188
|
-
ax.annotate(f"F1={f_score:0.1f}", xy=(0.9, y[45] + 0.02))
|
|
189
|
-
|
|
190
|
-
# Plot ROC curves.
|
|
191
|
-
ax.plot(
|
|
192
|
-
metric[f"{prefix1}_micro"],
|
|
193
|
-
metric[f"{prefix2}_micro"],
|
|
194
|
-
label=f"Micro-averaged {lab1} Curve ({lab2} = {metric['micro']:.2f})",
|
|
195
|
-
color="deeppink",
|
|
196
|
-
linestyle=":",
|
|
197
|
-
linewidth=4,
|
|
184
|
+
ax = optuna.visualization.matplotlib.plot_optimization_history(
|
|
185
|
+
study, target_name=target_name
|
|
198
186
|
)
|
|
199
|
-
|
|
200
|
-
ax.
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
187
|
+
ax.set_title(f"{model_name} Optimization History")
|
|
188
|
+
ax.set_xlabel("Trial")
|
|
189
|
+
ax.set_ylabel(target_name)
|
|
190
|
+
ax.legend(
|
|
191
|
+
loc="best",
|
|
192
|
+
shadow=True,
|
|
193
|
+
fancybox=True,
|
|
194
|
+
fontsize=mpl.rcParamsDefault["legend.fontsize"],
|
|
207
195
|
)
|
|
208
196
|
|
|
209
|
-
|
|
210
|
-
for i, color in zip(range(num_classes), colors):
|
|
211
|
-
if f"{prefix1}_{i}" in metric:
|
|
212
|
-
ax.plot(
|
|
213
|
-
metric[f"{prefix1}_{i}"],
|
|
214
|
-
metric[f"{prefix2}_{i}"],
|
|
215
|
-
color=color,
|
|
216
|
-
lw=lw,
|
|
217
|
-
label=f"{lab1} Curve of class {i} ({lab2} = {metric[i]:.2f})",
|
|
218
|
-
)
|
|
219
|
-
|
|
220
|
-
if "fpr_micro" in metric:
|
|
221
|
-
# Make center baseline
|
|
222
|
-
ax.plot(
|
|
223
|
-
baseline,
|
|
224
|
-
baseline,
|
|
225
|
-
"k--",
|
|
226
|
-
linewidth=lw,
|
|
227
|
-
label="No Classification Skill",
|
|
228
|
-
)
|
|
229
|
-
else:
|
|
230
|
-
ax.plot(
|
|
231
|
-
[0, 1],
|
|
232
|
-
baseline,
|
|
233
|
-
"k--",
|
|
234
|
-
linewidth=lw,
|
|
235
|
-
label="No Classification Skill",
|
|
236
|
-
)
|
|
237
|
-
|
|
238
|
-
ax.set_xlim(0.0, 1.0)
|
|
239
|
-
ax.set_ylim(0.0, 1.05)
|
|
240
|
-
ax.set_xlabel(f"{xlab}")
|
|
241
|
-
ax.set_ylabel(f"{ylab}")
|
|
242
|
-
ax.set_title(f"{title}")
|
|
243
|
-
ax.legend(loc="best")
|
|
244
|
-
|
|
245
|
-
fig.savefig(fn, bbox_inches="tight", facecolor="white")
|
|
246
|
-
plt.close()
|
|
247
|
-
plt.clf()
|
|
248
|
-
plt.cla()
|
|
249
|
-
|
|
250
|
-
@staticmethod
|
|
251
|
-
def plot_search_space(
|
|
252
|
-
estimator,
|
|
253
|
-
height=2,
|
|
254
|
-
s=25,
|
|
255
|
-
features=None,
|
|
256
|
-
):
|
|
257
|
-
"""Make density and contour plots for showing search space during grid search.
|
|
258
|
-
|
|
259
|
-
Modified from sklearn-genetic-opt function to implement exception handling.
|
|
197
|
+
fn = od / f"optuna_optimization_history.{self.plot_format}"
|
|
260
198
|
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
height (float, optional): Height of each facet. Defaults to 2.
|
|
265
|
-
|
|
266
|
-
s (float, optional): Size of the markers in scatter plot. Defaults to 5.
|
|
267
|
-
|
|
268
|
-
features (list, optional): Subset of features to plot, if ``None`` it plots all the features by default. Defaults to None.
|
|
269
|
-
|
|
270
|
-
Returns:
|
|
271
|
-
g (seaborn.PairGrid): Pair plot of the used hyperparameters during the search.
|
|
272
|
-
"""
|
|
273
|
-
sns.set_style("white")
|
|
274
|
-
|
|
275
|
-
df = logbook_to_pandas(estimator.logbook)
|
|
276
|
-
if features:
|
|
277
|
-
_stats = df[features]
|
|
278
|
-
else:
|
|
279
|
-
variables = [*estimator.space.parameters, "score"]
|
|
280
|
-
_stats = df[variables]
|
|
281
|
-
|
|
282
|
-
g = sns.PairGrid(_stats, diag_sharey=False, height=height)
|
|
199
|
+
if not fn.parent.exists():
|
|
200
|
+
fn.parent.mkdir(parents=True, exist_ok=True)
|
|
283
201
|
|
|
284
|
-
|
|
202
|
+
plt.savefig(fn)
|
|
203
|
+
plt.close()
|
|
285
204
|
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
sns.kdeplot,
|
|
289
|
-
shade=True,
|
|
290
|
-
cmap=sns.color_palette("ch:s=.25,rot=-.25", as_cmap=True),
|
|
205
|
+
ax = optuna.visualization.matplotlib.plot_edf(
|
|
206
|
+
study, target_name=target_name
|
|
291
207
|
)
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
sns.kdeplot,
|
|
301
|
-
shade=True,
|
|
302
|
-
palette="crest",
|
|
303
|
-
alpha=0.2,
|
|
304
|
-
color="red",
|
|
208
|
+
ax.set_title(f"{model_name} Empirical Distribution Function (EDF)")
|
|
209
|
+
ax.set_xlabel(target_name)
|
|
210
|
+
ax.set_ylabel(f"{model_name} Cumulative Probability")
|
|
211
|
+
ax.legend(
|
|
212
|
+
loc="best",
|
|
213
|
+
shadow=True,
|
|
214
|
+
fancybox=True,
|
|
215
|
+
fontsize=mpl.rcParamsDefault["legend.fontsize"],
|
|
305
216
|
)
|
|
306
|
-
except np.linalg.LinAlgError as err:
|
|
307
|
-
if "singular matrix" in str(err).lower():
|
|
308
|
-
g = g.map_diag(sns.histplot, color="red", alpha=1.0, kde=False)
|
|
309
|
-
|
|
310
|
-
return g
|
|
311
|
-
|
|
312
|
-
@staticmethod
|
|
313
|
-
def visualize_missingness(
|
|
314
|
-
genotype_data,
|
|
315
|
-
df,
|
|
316
|
-
zoom=True,
|
|
317
|
-
prefix="imputer",
|
|
318
|
-
horizontal_space=0.6,
|
|
319
|
-
vertical_space=0.6,
|
|
320
|
-
bar_color="gray",
|
|
321
|
-
heatmap_palette="magma",
|
|
322
|
-
plot_format="pdf",
|
|
323
|
-
dpi=300,
|
|
324
|
-
):
|
|
325
|
-
"""Make multiple plots to visualize missing data.
|
|
326
|
-
|
|
327
|
-
Args:
|
|
328
|
-
genotype_data (GenotypeData): Initialized GentoypeData object.
|
|
329
217
|
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
zoom (bool, optional): If True, zooms in to the missing proportion range on some of the plots. If False, the plot range is fixed at [0, 1]. Defaults to True.
|
|
218
|
+
plt.savefig(fn.with_stem("optuna_edf_plot"))
|
|
219
|
+
plt.close()
|
|
333
220
|
|
|
334
|
-
|
|
221
|
+
ax = optuna.visualization.matplotlib.plot_param_importances(
|
|
222
|
+
study, target_name=target_name
|
|
223
|
+
)
|
|
224
|
+
ax.set_xlabel("Parameter Importance")
|
|
225
|
+
ax.set_ylabel("Parameter")
|
|
226
|
+
ax.legend(loc="best", shadow=True, fancybox=True)
|
|
335
227
|
|
|
336
|
-
|
|
228
|
+
plt.savefig(fn.with_stem("optuna_param_importances_plot"))
|
|
229
|
+
plt.close()
|
|
337
230
|
|
|
338
|
-
|
|
231
|
+
ax = optuna.visualization.matplotlib.plot_timeline(study)
|
|
232
|
+
ax.set_title(f"{model_name} Timeline Plot")
|
|
233
|
+
ax.set_xlabel("Datetime")
|
|
234
|
+
ax.set_ylabel("Trial")
|
|
235
|
+
plt.savefig(fn.with_stem("optuna_timeline_plot"))
|
|
236
|
+
plt.close()
|
|
339
237
|
|
|
340
|
-
|
|
238
|
+
# Reset the style from Optuna's plotting.
|
|
239
|
+
sns.set_style("white", rc=self.param_dict)
|
|
240
|
+
mpl.rcParams.update(self.param_dict)
|
|
341
241
|
|
|
342
|
-
|
|
242
|
+
def plot_metrics(
|
|
243
|
+
self,
|
|
244
|
+
y_true: np.ndarray,
|
|
245
|
+
y_pred_proba: np.ndarray,
|
|
246
|
+
metrics: Dict[str, float],
|
|
247
|
+
label_names: Optional[Sequence[str]] = None,
|
|
248
|
+
prefix: str = "",
|
|
249
|
+
) -> None:
|
|
250
|
+
"""Plot multi-class ROC-AUC and Precision-Recall curves.
|
|
343
251
|
|
|
344
|
-
|
|
252
|
+
This method plots the multi-class ROC-AUC and Precision-Recall curves. The plot is saved to disk as a ``<plot_format>`` file.
|
|
345
253
|
|
|
346
|
-
|
|
254
|
+
Args:
|
|
255
|
+
y_true (np.ndarray): 1D array of true integer labels in [0, n_classes-1].
|
|
256
|
+
y_pred_proba (np.ndarray): (n_samples, n_classes) array of predicted probabilities.
|
|
257
|
+
metrics (Dict[str, float]): Dict of summary metrics to annotate the figure.
|
|
258
|
+
label_names (Optional[Sequence[str]]): Optional sequence of class names (length must equal n_classes).
|
|
259
|
+
If provided, legends will use these names instead of 'Class i'.
|
|
260
|
+
prefix (str): Optional prefix for the output filename.
|
|
347
261
|
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
pandas.DataFrame: Per-individual missing data proportions.
|
|
351
|
-
pandas.DataFrame: Per-population + per-locus missing data proportions.
|
|
352
|
-
pandas.DataFrame: Per-population missing data proportions.
|
|
353
|
-
pandas.DataFrame: Per-individual and per-population missing data proportions.
|
|
262
|
+
Raises:
|
|
263
|
+
ValueError: If model_name is not recognized (legacy guard).
|
|
354
264
|
"""
|
|
265
|
+
num_classes = y_pred_proba.shape[1]
|
|
355
266
|
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
267
|
+
# Validate/normalize label names
|
|
268
|
+
if label_names is not None and len(label_names) != num_classes:
|
|
269
|
+
self.logger.warning(
|
|
270
|
+
f"plot_metrics: len(label_names)={len(label_names)} "
|
|
271
|
+
f"!= n_classes={num_classes}. Ignoring label_names."
|
|
272
|
+
)
|
|
273
|
+
label_names = None
|
|
274
|
+
if label_names is None:
|
|
275
|
+
label_names = [f"Class {i}" for i in range(num_classes)]
|
|
276
|
+
|
|
277
|
+
# Binarize y_true for one-vs-rest curves
|
|
278
|
+
y_true_bin = label_binarize(y_true, classes=np.arange(num_classes))
|
|
279
|
+
|
|
280
|
+
# Containers
|
|
281
|
+
fpr, tpr, roc_auc = {}, {}, {}
|
|
282
|
+
precision, recall, average_precision = {}, {}, {}
|
|
283
|
+
|
|
284
|
+
# Per-class ROC & PR
|
|
285
|
+
for i in range(num_classes):
|
|
286
|
+
fpr[i], tpr[i], _ = roc_curve(y_true_bin[:, i], y_pred_proba[:, i])
|
|
287
|
+
roc_auc[i] = auc(fpr[i], tpr[i])
|
|
288
|
+
precision[i], recall[i], _ = precision_recall_curve(
|
|
289
|
+
y_true_bin[:, i], y_pred_proba[:, i]
|
|
290
|
+
)
|
|
291
|
+
average_precision[i] = average_precision_score(
|
|
292
|
+
y_true_bin[:, i], y_pred_proba[:, i]
|
|
293
|
+
)
|
|
382
294
|
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
295
|
+
# Micro-averages
|
|
296
|
+
fpr["micro"], tpr["micro"], _ = roc_curve(
|
|
297
|
+
y_true_bin.ravel(), y_pred_proba.ravel()
|
|
386
298
|
)
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
which="both",
|
|
394
|
-
left=False,
|
|
395
|
-
right=False,
|
|
396
|
-
labelleft=False,
|
|
299
|
+
roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])
|
|
300
|
+
precision["micro"], recall["micro"], _ = precision_recall_curve(
|
|
301
|
+
y_true_bin.ravel(), y_pred_proba.ravel()
|
|
302
|
+
)
|
|
303
|
+
average_precision["micro"] = average_precision_score(
|
|
304
|
+
y_true_bin, y_pred_proba, average="micro"
|
|
397
305
|
)
|
|
398
306
|
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
poploc,
|
|
419
|
-
vmin=0.0,
|
|
420
|
-
vmax=vmax,
|
|
421
|
-
cmap=sns.color_palette(heatmap_palette, as_cmap=True),
|
|
422
|
-
yticklabels=False,
|
|
423
|
-
cbar_kws={"label": "Missing Prop."},
|
|
424
|
-
ax=ax,
|
|
425
|
-
)
|
|
426
|
-
ax.set_xlabel("Population")
|
|
427
|
-
ax.set_ylabel("Locus")
|
|
428
|
-
|
|
429
|
-
id_vars.append("Population")
|
|
430
|
-
|
|
431
|
-
melt_df = indpop.isna()
|
|
432
|
-
melt_df["SampleID"] = genotype_data.samples
|
|
433
|
-
indpop["SampleID"] = genotype_data.samples
|
|
434
|
-
|
|
435
|
-
if poptotal is not None:
|
|
436
|
-
melt_df["Population"] = genotype_data.pops
|
|
437
|
-
indpop["Population"] = genotype_data.pops
|
|
438
|
-
|
|
439
|
-
melt_df = melt_df.melt(value_name="Missing", id_vars=id_vars)
|
|
440
|
-
melt_df.sort_values(by=id_vars[::-1], inplace=True)
|
|
441
|
-
melt_df["Missing"].replace(False, "Present", inplace=True)
|
|
442
|
-
melt_df["Missing"].replace(True, "Missing", inplace=True)
|
|
307
|
+
# Macro-average ROC
|
|
308
|
+
all_fpr = np.unique(np.concatenate([fpr[i] for i in range(num_classes)]))
|
|
309
|
+
mean_tpr = np.zeros_like(all_fpr)
|
|
310
|
+
for i in range(num_classes):
|
|
311
|
+
mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])
|
|
312
|
+
mean_tpr /= num_classes
|
|
313
|
+
fpr["macro"], tpr["macro"] = all_fpr, mean_tpr
|
|
314
|
+
roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])
|
|
315
|
+
|
|
316
|
+
# Macro-average PR
|
|
317
|
+
all_recall = np.unique(np.concatenate([recall[i] for i in range(num_classes)]))
|
|
318
|
+
mean_precision = np.zeros_like(all_recall)
|
|
319
|
+
for i in range(num_classes):
|
|
320
|
+
# recall[i] increases, but precision[i] is given over decreasing thresholds
|
|
321
|
+
mean_precision += np.interp(all_recall, recall[i][::-1], precision[i][::-1])
|
|
322
|
+
mean_precision /= num_classes
|
|
323
|
+
average_precision["macro"] = average_precision_score(
|
|
324
|
+
y_true_bin, y_pred_proba, average="macro"
|
|
325
|
+
)
|
|
443
326
|
|
|
444
|
-
|
|
327
|
+
# Plot
|
|
328
|
+
fig, axes = plt.subplots(1, 2, figsize=(14, 6))
|
|
445
329
|
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
|
|
330
|
+
# ROC
|
|
331
|
+
axes[0].plot(
|
|
332
|
+
fpr["micro"],
|
|
333
|
+
tpr["micro"],
|
|
334
|
+
label=f"Micro-average ROC (AUC = {roc_auc['micro']:.2f})",
|
|
335
|
+
linestyle=":",
|
|
336
|
+
linewidth=4,
|
|
453
337
|
)
|
|
454
|
-
|
|
455
|
-
|
|
456
|
-
|
|
457
|
-
|
|
458
|
-
|
|
459
|
-
|
|
338
|
+
axes[0].plot(
|
|
339
|
+
fpr["macro"],
|
|
340
|
+
tpr["macro"],
|
|
341
|
+
label=f"Macro-average ROC (AUC = {roc_auc['macro']:.2f})",
|
|
342
|
+
linestyle="--",
|
|
343
|
+
linewidth=4,
|
|
460
344
|
)
|
|
461
|
-
|
|
462
|
-
|
|
463
|
-
|
|
464
|
-
ax = axes[1, 2]
|
|
465
|
-
|
|
466
|
-
ax.set_title("Per-Population")
|
|
467
|
-
g = sns.histplot(
|
|
468
|
-
data=melt_df,
|
|
469
|
-
y="Population",
|
|
470
|
-
hue="Missing",
|
|
471
|
-
multiple="fill",
|
|
472
|
-
ax=ax,
|
|
345
|
+
for i in range(num_classes):
|
|
346
|
+
axes[0].plot(
|
|
347
|
+
fpr[i], tpr[i], label=f"{label_names[i]} ROC (AUC = {roc_auc[i]:.2f})"
|
|
473
348
|
)
|
|
474
|
-
|
|
475
|
-
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
|
|
349
|
+
axes[0].plot([0, 1], [0, 1], linestyle="--", color="black", label="Random")
|
|
350
|
+
axes[0].set_xlabel("False Positive Rate")
|
|
351
|
+
axes[0].set_ylabel("True Positive Rate")
|
|
352
|
+
axes[0].set_title("Multi-class ROC-AUC Curve")
|
|
353
|
+
axes[0].legend(
|
|
354
|
+
loc="upper center",
|
|
355
|
+
bbox_to_anchor=(0.5, -0.15),
|
|
356
|
+
fancybox=True,
|
|
357
|
+
shadow=True,
|
|
358
|
+
ncol=2,
|
|
482
359
|
)
|
|
483
|
-
plt.cla()
|
|
484
|
-
plt.clf()
|
|
485
|
-
plt.close()
|
|
486
|
-
|
|
487
|
-
return loc, ind, poploc, poptotal, indpop
|
|
488
|
-
|
|
489
|
-
@staticmethod
|
|
490
|
-
def run_and_plot_pca(
|
|
491
|
-
original_genotype_data,
|
|
492
|
-
imputer_object,
|
|
493
|
-
prefix="imputer",
|
|
494
|
-
n_components=3,
|
|
495
|
-
center=True,
|
|
496
|
-
scale=False,
|
|
497
|
-
n_axes=2,
|
|
498
|
-
point_size=15,
|
|
499
|
-
font_size=15,
|
|
500
|
-
plot_format="pdf",
|
|
501
|
-
bottom_margin=0,
|
|
502
|
-
top_margin=0,
|
|
503
|
-
left_margin=0,
|
|
504
|
-
right_margin=0,
|
|
505
|
-
width=1088,
|
|
506
|
-
height=700,
|
|
507
|
-
):
|
|
508
|
-
"""Runs PCA and makes scatterplot with colors showing missingness.
|
|
509
|
-
|
|
510
|
-
Genotypes are plotted as separate shapes per population and colored according to missingness per individual.
|
|
511
|
-
|
|
512
|
-
This function is run at the end of each imputation method, but can be run independently to change plot and PCA parameters such as ``n_axes=3`` or ``scale=True``.
|
|
513
|
-
|
|
514
|
-
The imputed and original GenotypeData objects need to be passed to the function as positional arguments.
|
|
515
|
-
|
|
516
|
-
PCA (principal component analysis) scatterplot can have either two or three axes, set with the n_axes parameter.
|
|
517
|
-
|
|
518
|
-
The plot is saved as both an interactive HTML file and as a static image. Each population is represented by point shapes. The interactive plot has associated metadata when hovering over the points.
|
|
519
|
-
|
|
520
|
-
Files are saved to a reports directory as <prefix>_output/imputed_pca.<plot_format|html>. Supported image formats include: "pdf", "svg", "png", and "jpeg" (or "jpg").
|
|
521
|
-
|
|
522
|
-
Args:
|
|
523
|
-
original_genotype_data (GenotypeData): Original GenotypeData object that was input into the imputer.
|
|
524
|
-
|
|
525
|
-
imputer_object (Any imputer instance): Imputer object created when imputing. Can be any of the imputers, such as: ``ImputePhylo()``, ``ImputeUBP()``, and ``ImputeRandomForest()``.
|
|
526
|
-
|
|
527
|
-
original_012 (pandas.DataFrame, numpy.ndarray, or List[List[int]], optional): Original 012-encoded genotypes (before imputing). Missing values are encoded as -9. This object can be obtained as ``df = GenotypeData.genotypes012_df``.
|
|
528
|
-
|
|
529
|
-
prefix (str, optional): Prefix for report directory. Plots will be save to a directory called <prefix>_output/imputed_pca<html|plot_format>. Report directory will be created if it does not already exist. Defaults to "imputer".
|
|
530
|
-
|
|
531
|
-
n_components (int, optional): Number of principal components to include in the PCA. Defaults to 3.
|
|
532
|
-
|
|
533
|
-
center (bool, optional): If True, centers the genotypes to the mean before doing the PCA. If False, no centering is done. Defaults to True.
|
|
534
|
-
|
|
535
|
-
scale (bool, optional): If True, scales the genotypes to unit variance before doing the PCA. If False, no scaling is done. Defaults to False.
|
|
536
|
-
|
|
537
|
-
n_axes (int, optional): Number of principal component axes to plot. Must be set to either 2 or 3. If set to 3, a 3-dimensional plot will be made. Defaults to 2.
|
|
538
360
|
|
|
539
|
-
|
|
540
|
-
|
|
541
|
-
|
|
542
|
-
|
|
543
|
-
|
|
544
|
-
|
|
545
|
-
|
|
546
|
-
|
|
547
|
-
left_margin (int, optional): Adjust left margin. If whitespace cuts off some of your plot, lower the corresponding margins. The default corresponds to that of plotly update_layout(). Defaults to 0.
|
|
548
|
-
|
|
549
|
-
right_margin (int, optional): Adjust right margin. If whitespace cuts off some of your plot, lower the corresponding margins. The default corresponds to that of plotly update_layout(). Defaults to 0.
|
|
550
|
-
|
|
551
|
-
width (int, optional): Width of plot space. If your plot is cut off at the edges, even after adjusting the margins, increase the width and height. Try to keep the aspect ratio similar. Defaults to 1088.
|
|
552
|
-
|
|
553
|
-
height (int, optional): Height of plot space. If your plot is cut off at the edges, even after adjusting the margins, increase the width and height. Try to keep the aspect ratio similar. Defaults to 700.
|
|
554
|
-
|
|
555
|
-
Returns:
|
|
556
|
-
numpy.ndarray: PCA data as a numpy array with shape (n_samples, n_components).
|
|
557
|
-
|
|
558
|
-
sklearn.decomposision.PCA: Scikit-learn PCA object from sklearn.decomposision.PCA. Any of the sklearn.decomposition.PCA attributes can be accessed from this object. See sklearn documentation.
|
|
559
|
-
|
|
560
|
-
Examples:
|
|
561
|
-
>>>data = GenotypeData(
|
|
562
|
-
>>> filename="snps.str",
|
|
563
|
-
>>> filetype="structure2row",
|
|
564
|
-
>>> popmapfile="popmap.txt",
|
|
565
|
-
>>>)
|
|
566
|
-
>>>
|
|
567
|
-
>>>ubp = ImputeUBP(genotype_data=data)
|
|
568
|
-
>>>
|
|
569
|
-
>>>components, pca = run_and_plot_pca(
|
|
570
|
-
>>> data,
|
|
571
|
-
>>> ubp,
|
|
572
|
-
>>> scale=True,
|
|
573
|
-
>>> center=True,
|
|
574
|
-
>>> plot_format="png"
|
|
575
|
-
>>>)
|
|
576
|
-
>>>
|
|
577
|
-
>>>explvar = pca.explained_variance_ratio\_
|
|
578
|
-
"""
|
|
579
|
-
report_path = os.path.join(f"{prefix}_output", "plots")
|
|
580
|
-
Path(report_path).mkdir(parents=True, exist_ok=True)
|
|
581
|
-
|
|
582
|
-
if n_axes > 3:
|
|
583
|
-
raise ValueError(
|
|
584
|
-
">3 axes is not supported; n_axes must be either 2 or 3."
|
|
585
|
-
)
|
|
586
|
-
if n_axes < 2:
|
|
587
|
-
raise ValueError(
|
|
588
|
-
"<2 axes is not supported; n_axes must be either 2 or 3."
|
|
589
|
-
)
|
|
590
|
-
|
|
591
|
-
imputer = imputer_object.imputed
|
|
592
|
-
|
|
593
|
-
df = misc.validate_input_type(
|
|
594
|
-
imputer.genotypes012_df, return_type="df"
|
|
361
|
+
# PR
|
|
362
|
+
axes[1].plot(
|
|
363
|
+
recall["micro"],
|
|
364
|
+
precision["micro"],
|
|
365
|
+
label=f"Micro-average PR (AP = {average_precision['micro']:.2f})",
|
|
366
|
+
linestyle=":",
|
|
367
|
+
linewidth=4,
|
|
595
368
|
)
|
|
596
|
-
|
|
597
|
-
|
|
598
|
-
|
|
599
|
-
|
|
369
|
+
axes[1].plot(
|
|
370
|
+
all_recall,
|
|
371
|
+
mean_precision,
|
|
372
|
+
label=f"Macro-average PR (AP = {average_precision['macro']:.2f})",
|
|
373
|
+
linestyle="--",
|
|
374
|
+
linewidth=4,
|
|
600
375
|
)
|
|
601
|
-
|
|
602
|
-
|
|
603
|
-
|
|
604
|
-
|
|
605
|
-
|
|
606
|
-
|
|
607
|
-
|
|
608
|
-
|
|
609
|
-
|
|
610
|
-
|
|
611
|
-
|
|
612
|
-
|
|
613
|
-
|
|
614
|
-
|
|
615
|
-
|
|
616
|
-
|
|
376
|
+
for i in range(num_classes):
|
|
377
|
+
axes[1].plot(
|
|
378
|
+
recall[i],
|
|
379
|
+
precision[i],
|
|
380
|
+
label=f"{label_names[i]} PR (AP = {average_precision[i]:.2f})",
|
|
381
|
+
)
|
|
382
|
+
axes[1].plot([0, 1], [1, 0], linestyle="--", color="black", label="Random")
|
|
383
|
+
axes[1].set_xlabel("Recall")
|
|
384
|
+
axes[1].set_ylabel("Precision")
|
|
385
|
+
axes[1].set_title("Multi-class Precision-Recall Curve")
|
|
386
|
+
axes[1].legend(
|
|
387
|
+
loc="upper center",
|
|
388
|
+
bbox_to_anchor=(0.5, -0.15),
|
|
389
|
+
fancybox=True,
|
|
390
|
+
shadow=True,
|
|
391
|
+
ncol=2,
|
|
617
392
|
)
|
|
618
393
|
|
|
619
|
-
|
|
620
|
-
|
|
621
|
-
df_pca["Size"] = point_size
|
|
622
|
-
|
|
623
|
-
_, ind, _, _, _ = imputer.calc_missing(original_df, use_pops=False)
|
|
624
|
-
df_pca["missPerc"] = ind
|
|
625
|
-
|
|
626
|
-
my_scale = [("rgb(19, 43, 67)"), ("rgb(86,177,247)")] # ggplot default
|
|
627
|
-
|
|
628
|
-
z = "Axis3" if n_axes == 3 else None
|
|
629
|
-
labs = {
|
|
630
|
-
"Axis1": f"PC1 ({round(model.explained_variance_ratio_[0] * 100, 2)}%)",
|
|
631
|
-
"Axis2": f"PC2 ({round(model.explained_variance_ratio_[1] * 100, 2)}%)",
|
|
632
|
-
"missPerc": "Missing Prop.",
|
|
633
|
-
"Population": "Population",
|
|
634
|
-
}
|
|
394
|
+
# Title & save
|
|
395
|
+
fig.suptitle("\n".join([f"{k}: {v:.2f}" for k, v in metrics.items()]), y=1.35)
|
|
635
396
|
|
|
636
|
-
if
|
|
637
|
-
|
|
638
|
-
"Axis3"
|
|
639
|
-
] = f"PC3 ({round(model.explained_variance_ratio_[2] * 100, 2)}%)"
|
|
640
|
-
fig = px.scatter_3d(
|
|
641
|
-
df_pca,
|
|
642
|
-
x="Axis1",
|
|
643
|
-
y="Axis2",
|
|
644
|
-
z="Axis3",
|
|
645
|
-
color="missPerc",
|
|
646
|
-
symbol="Population",
|
|
647
|
-
color_continuous_scale=my_scale,
|
|
648
|
-
custom_data=["Axis3", "SampleID", "Population", "missPerc"],
|
|
649
|
-
size="Size",
|
|
650
|
-
size_max=point_size,
|
|
651
|
-
labels=labs,
|
|
652
|
-
)
|
|
653
|
-
else:
|
|
654
|
-
fig = px.scatter(
|
|
655
|
-
df_pca,
|
|
656
|
-
x="Axis1",
|
|
657
|
-
y="Axis2",
|
|
658
|
-
color="missPerc",
|
|
659
|
-
symbol="Population",
|
|
660
|
-
color_continuous_scale=my_scale,
|
|
661
|
-
custom_data=["Axis3", "SampleID", "Population", "missPerc"],
|
|
662
|
-
size="Size",
|
|
663
|
-
size_max=point_size,
|
|
664
|
-
labels=labs,
|
|
665
|
-
)
|
|
666
|
-
fig.update_traces(
|
|
667
|
-
hovertemplate="<br>".join(
|
|
668
|
-
[
|
|
669
|
-
"Axis 1: %{x}",
|
|
670
|
-
"Axis 2: %{y}",
|
|
671
|
-
"Axis 3: %{customdata[0]}",
|
|
672
|
-
"Sample ID: %{customdata[1]}",
|
|
673
|
-
"Population: %{customdata[2]}",
|
|
674
|
-
"Missing Prop.: %{customdata[3]}",
|
|
675
|
-
]
|
|
676
|
-
),
|
|
677
|
-
)
|
|
678
|
-
fig.update_layout(
|
|
679
|
-
showlegend=True,
|
|
680
|
-
margin=dict(
|
|
681
|
-
b=bottom_margin,
|
|
682
|
-
t=top_margin,
|
|
683
|
-
l=left_margin,
|
|
684
|
-
r=right_margin,
|
|
685
|
-
),
|
|
686
|
-
width=width,
|
|
687
|
-
height=height,
|
|
688
|
-
legend_orientation="h",
|
|
689
|
-
legend_title="Population",
|
|
690
|
-
legend_title_font=dict(size=font_size),
|
|
691
|
-
legend_title_side="top",
|
|
692
|
-
font=dict(size=font_size),
|
|
693
|
-
)
|
|
694
|
-
fig.write_html(os.path.join(report_path, "imputed_pca.html"))
|
|
695
|
-
fig.write_image(
|
|
696
|
-
os.path.join(report_path, f"imputed_pca.{plot_format}"),
|
|
697
|
-
)
|
|
397
|
+
if prefix != "":
|
|
398
|
+
prefix = f"{prefix}_"
|
|
698
399
|
|
|
699
|
-
|
|
400
|
+
out_name = f"{self.model_name}_{prefix}roc_pr_curves.{self.plot_format}"
|
|
401
|
+
fig.savefig(self.output_dir / out_name, bbox_inches="tight")
|
|
402
|
+
if self.show_plots:
|
|
403
|
+
plt.show()
|
|
404
|
+
plt.close(fig)
|
|
700
405
|
|
|
701
|
-
|
|
702
|
-
def plot_history(lod, nn_method, prefix="imputer"):
|
|
406
|
+
def plot_history(self, history: Dict[str, List[float]]) -> None:
|
|
703
407
|
"""Plot model history traces. Will be saved to file.
|
|
704
408
|
|
|
409
|
+
This method plots the deep learning model history traces. The plot is saved to disk as a ``<plot_format>`` file.
|
|
410
|
+
|
|
705
411
|
Args:
|
|
706
|
-
|
|
707
|
-
nn_method (str): Neural network method to plot. Possible options include: 'NLPCA', 'UBP', or 'VAE'. NLPCA and VAE get plotted the same, but UBP does it differently due to its three phases.
|
|
708
|
-
prefix (str, optional): Prefix to use for output directory. Defaults to 'imputer'.
|
|
412
|
+
history (Dict[str, List[float]]): Dictionary with lists of history objects. Keys should be "Train" and "Validation".
|
|
709
413
|
|
|
710
414
|
Raises:
|
|
711
|
-
ValueError: nn_method must be either '
|
|
415
|
+
ValueError: nn_method must be either 'ImputeNLPCA', 'ImputeUBP', 'ImputeAutoencoder', 'ImputeVAE'.
|
|
712
416
|
"""
|
|
713
|
-
if
|
|
714
|
-
|
|
715
|
-
|
|
716
|
-
|
|
717
|
-
|
|
718
|
-
|
|
719
|
-
|
|
720
|
-
|
|
721
|
-
)
|
|
722
|
-
|
|
723
|
-
|
|
724
|
-
fig,
|
|
725
|
-
|
|
726
|
-
|
|
727
|
-
|
|
728
|
-
fig.suptitle(title)
|
|
729
|
-
fig.tight_layout(h_pad=3.0, w_pad=3.0)
|
|
730
|
-
history = lod[0]
|
|
731
|
-
|
|
732
|
-
acctrain = "binary_accuracy"
|
|
733
|
-
accval = "val_binary_accuracy"
|
|
734
|
-
lossval = "val_loss"
|
|
417
|
+
if self.model_name not in {
|
|
418
|
+
"ImputeNLPCA",
|
|
419
|
+
"ImputeVAE",
|
|
420
|
+
"ImputeAutoencoder",
|
|
421
|
+
"ImputeUBP",
|
|
422
|
+
}:
|
|
423
|
+
msg = "nn_method must be either 'ImputeNLPCA', 'ImputeVAE', 'ImputeAutoencoder', 'ImputeUBP'."
|
|
424
|
+
self.logger.error(msg)
|
|
425
|
+
raise ValueError(msg)
|
|
426
|
+
|
|
427
|
+
if self.model_name != "ImputeUBP":
|
|
428
|
+
fig, ax = plt.subplots(1, 1, figsize=(12, 8))
|
|
429
|
+
df = pd.DataFrame(history)
|
|
430
|
+
df = df.iloc[1:]
|
|
735
431
|
|
|
736
432
|
# Plot train accuracy
|
|
737
|
-
|
|
738
|
-
ax1.set_title("Model Accuracy")
|
|
739
|
-
ax1.set_ylabel("Accuracy")
|
|
740
|
-
ax1.set_xlabel("Epoch")
|
|
741
|
-
ax1.set_ylim(bottom=0.0, top=1.0)
|
|
742
|
-
ax1.set_yticks([0.0, 0.25, 0.5, 0.75, 1.0])
|
|
743
|
-
|
|
744
|
-
labels = ["Train"]
|
|
745
|
-
|
|
746
|
-
# Plot validation accuracy
|
|
747
|
-
ax1.plot(history[accval])
|
|
748
|
-
labels.append("Validation")
|
|
749
|
-
|
|
750
|
-
ax1.legend(labels, loc="best")
|
|
433
|
+
ax.plot(df["Train"], c="blue", lw=3)
|
|
751
434
|
|
|
752
|
-
|
|
753
|
-
|
|
754
|
-
|
|
435
|
+
ax.set_title(f"{self.model_name} Loss per Epoch")
|
|
436
|
+
ax.set_ylabel("Loss")
|
|
437
|
+
ax.set_xlabel("Epoch")
|
|
438
|
+
ax.legend(["Train"], loc="best", shadow=True, fancybox=True)
|
|
755
439
|
|
|
756
|
-
|
|
757
|
-
|
|
758
|
-
ax2.set_xlabel("Epoch")
|
|
759
|
-
ax2.legend(labels, loc="best")
|
|
440
|
+
else:
|
|
441
|
+
fig, ax = plt.subplots(3, 1, figsize=(12, 8))
|
|
760
442
|
|
|
761
|
-
|
|
443
|
+
for i, phase in enumerate(range(1, 4)):
|
|
444
|
+
train = pd.Series(history["Train"][f"Phase {phase}"])
|
|
445
|
+
train = train.iloc[1:] # ignore first epoch
|
|
762
446
|
|
|
763
|
-
|
|
764
|
-
|
|
765
|
-
|
|
766
|
-
|
|
767
|
-
|
|
768
|
-
|
|
769
|
-
fig.tight_layout(h_pad=2.0, w_pad=2.0)
|
|
770
|
-
fn = os.path.join(
|
|
771
|
-
f"{prefix}_output",
|
|
772
|
-
"plots",
|
|
773
|
-
"Unsupervised",
|
|
774
|
-
nn_method,
|
|
775
|
-
"histplot.pdf",
|
|
776
|
-
)
|
|
447
|
+
# Plot train accuracy
|
|
448
|
+
ax[i].plot(train, c="blue", lw=3)
|
|
449
|
+
ax[i].set_title(f"{self.model_name}: Phase {phase} Loss per Epoch")
|
|
450
|
+
ax[i].set_ylabel("Loss")
|
|
451
|
+
ax[i].set_xlabel("Epoch")
|
|
452
|
+
ax[i].legend([f"Phase {phase}"], loc="best", shadow=True, fancybox=True)
|
|
777
453
|
|
|
778
|
-
|
|
779
|
-
|
|
780
|
-
|
|
781
|
-
title = f"Phase {i}"
|
|
454
|
+
fn = f"{self.model_name.lower()}_history_plot.{self.plot_format}"
|
|
455
|
+
fn = self.output_dir / fn
|
|
456
|
+
fig.savefig(fn)
|
|
782
457
|
|
|
783
|
-
|
|
784
|
-
|
|
785
|
-
|
|
786
|
-
ax.set_title(f"{title} Accuracy")
|
|
787
|
-
ax.set_ylabel("Accuracy")
|
|
788
|
-
ax.set_xlabel("Epoch")
|
|
789
|
-
ax.set_yticks([0.0, 0.25, 0.5, 0.75, 1.0])
|
|
458
|
+
if self.show_plots:
|
|
459
|
+
plt.show()
|
|
460
|
+
plt.close(fig)
|
|
790
461
|
|
|
791
|
-
|
|
792
|
-
|
|
793
|
-
|
|
462
|
+
def plot_confusion_matrix(
|
|
463
|
+
self,
|
|
464
|
+
y_true_1d: np.ndarray,
|
|
465
|
+
y_pred_1d: np.ndarray,
|
|
466
|
+
label_names: Sequence[str] | None = None,
|
|
467
|
+
prefix: str = "",
|
|
468
|
+
) -> None:
|
|
469
|
+
"""Plot a confusion matrix with optional class labels.
|
|
794
470
|
|
|
795
|
-
|
|
796
|
-
plt.subplot(3, 2, idx + 1)
|
|
797
|
-
ax = plt.gca()
|
|
798
|
-
ax.plot(history["loss"])
|
|
799
|
-
ax.set_title(f"{title} Loss")
|
|
800
|
-
ax.set_ylabel("Loss (MSE)")
|
|
801
|
-
ax.set_xlabel("Epoch")
|
|
471
|
+
This method plots a confusion matrix using true and predicted labels. The plot is saved to disk as a ``<plot_format>`` file.
|
|
802
472
|
|
|
803
|
-
|
|
804
|
-
|
|
805
|
-
|
|
473
|
+
Args:
|
|
474
|
+
y_true_1d (np.ndarray): 1D array of true integer labels in [0, n_classes-1].
|
|
475
|
+
y_pred_1d (np.ndarray): 1D array of predicted integer labels in [0, n_classes-1].
|
|
476
|
+
label_names (Sequence[str] | None): Optional sequence of class names (length must equal n_classes). If provided, both the internal label order and displayed tick labels will respect this order (assumed to be 0..n-1).
|
|
477
|
+
prefix (str): Optional prefix for the output filename.
|
|
806
478
|
|
|
807
|
-
|
|
479
|
+
Notes:
|
|
480
|
+
- If `label_names` is None, the display labels default to the numeric class indices inferred from `y_true_1d ∪ y_pred_1d`.
|
|
481
|
+
"""
|
|
482
|
+
y_true_1d = misc.validate_input_type(y_true_1d, return_type="array")
|
|
483
|
+
y_pred_1d = misc.validate_input_type(y_pred_1d, return_type="array")
|
|
808
484
|
|
|
809
|
-
|
|
485
|
+
if y_true_1d.ndim > 1:
|
|
486
|
+
y_true_1d = y_true_1d.flatten()
|
|
810
487
|
|
|
811
|
-
|
|
812
|
-
|
|
488
|
+
if y_pred_1d.ndim > 1:
|
|
489
|
+
y_pred_1d = y_pred_1d.flatten()
|
|
813
490
|
|
|
491
|
+
# Determine class count/order
|
|
492
|
+
if label_names is not None:
|
|
493
|
+
n_classes = len(label_names)
|
|
494
|
+
labels = np.arange(n_classes) # our y_* are ints 0..n-1
|
|
495
|
+
display_labels = list(map(str, label_names))
|
|
814
496
|
else:
|
|
815
|
-
|
|
816
|
-
|
|
817
|
-
|
|
818
|
-
|
|
819
|
-
@staticmethod
|
|
820
|
-
def plot_certainty_heatmap(
|
|
821
|
-
y_certainty, sample_ids=None, nn_method="VAE", prefix="imputer"
|
|
822
|
-
):
|
|
823
|
-
fig = plt.figure()
|
|
824
|
-
hm = sns.heatmap(
|
|
825
|
-
data=y_certainty,
|
|
826
|
-
cmap="viridis",
|
|
827
|
-
vmin=0.0,
|
|
828
|
-
vmax=1.0,
|
|
829
|
-
cbar_kws={"label": "Prob."},
|
|
830
|
-
)
|
|
831
|
-
hm.set_xlabel("Site")
|
|
832
|
-
hm.set_ylabel("Sample")
|
|
833
|
-
hm.set_title("Probabilities of Uncertain Sites")
|
|
834
|
-
fig.tight_layout()
|
|
835
|
-
fig.savefig(
|
|
836
|
-
os.path.join(
|
|
837
|
-
f"{prefix}_output",
|
|
838
|
-
"plots",
|
|
839
|
-
"Unsupervised",
|
|
840
|
-
nn_method,
|
|
841
|
-
"uncertainty_plot.png",
|
|
842
|
-
),
|
|
843
|
-
bbox_inches="tight",
|
|
844
|
-
facecolor="white",
|
|
845
|
-
)
|
|
497
|
+
# Infer labels from data to keep matrix tight
|
|
498
|
+
labels = np.unique(np.concatenate([y_true_1d, y_pred_1d]))
|
|
499
|
+
display_labels = labels # sklearn will convert to strings
|
|
846
500
|
|
|
847
|
-
@staticmethod
|
|
848
|
-
def plot_confusion_matrix(
|
|
849
|
-
y_true_1d, y_pred_1d, nn_method, prefix="imputer"
|
|
850
|
-
):
|
|
851
501
|
fig, ax = plt.subplots(1, 1, figsize=(15, 15))
|
|
852
|
-
ConfusionMatrixDisplay.from_predictions(
|
|
853
|
-
y_true=y_true_1d, y_pred=y_pred_1d, ax=ax
|
|
854
|
-
)
|
|
855
502
|
|
|
856
|
-
|
|
857
|
-
|
|
858
|
-
|
|
859
|
-
|
|
860
|
-
|
|
861
|
-
|
|
503
|
+
ConfusionMatrixDisplay.from_predictions(
|
|
504
|
+
y_true=y_true_1d,
|
|
505
|
+
y_pred=y_pred_1d,
|
|
506
|
+
labels=labels,
|
|
507
|
+
display_labels=display_labels,
|
|
508
|
+
ax=ax,
|
|
509
|
+
cmap="viridis",
|
|
510
|
+
colorbar=True,
|
|
862
511
|
)
|
|
863
512
|
|
|
864
|
-
if
|
|
865
|
-
|
|
513
|
+
if prefix != "":
|
|
514
|
+
prefix = f"{prefix}_"
|
|
866
515
|
|
|
867
|
-
|
|
516
|
+
out_name = (
|
|
517
|
+
f"{self.model_name.lower()}_{prefix}confusion_matrix.{self.plot_format}"
|
|
518
|
+
)
|
|
519
|
+
fig.savefig(self.output_dir / out_name, bbox_inches="tight")
|
|
520
|
+
if self.show_plots:
|
|
521
|
+
plt.show()
|
|
522
|
+
plt.close(fig)
|
|
868
523
|
|
|
869
|
-
|
|
870
|
-
|
|
871
|
-
|
|
872
|
-
|
|
873
|
-
|
|
874
|
-
|
|
875
|
-
cnts = pd.DataFrame(cnts).reset_index()
|
|
876
|
-
cnts.sort_values(by="Genotype", inplace=True)
|
|
877
|
-
cnts["Genotype"] = cnts["Genotype"].astype(str)
|
|
524
|
+
def plot_gt_distribution(
|
|
525
|
+
self,
|
|
526
|
+
X: np.ndarray | pd.DataFrame | list | torch.Tensor,
|
|
527
|
+
is_imputed: bool = False,
|
|
528
|
+
) -> None:
|
|
529
|
+
"""Plot genotype distribution (IUPAC or integer-encoded).
|
|
878
530
|
|
|
879
|
-
|
|
880
|
-
g = sns.barplot(x="Genotype", y="Count", data=cnts, ax=ax)
|
|
881
|
-
g.set_xlabel("Integer-encoded Genotype")
|
|
882
|
-
g.set_ylabel("Count")
|
|
883
|
-
g.set_title("Genotype Counts")
|
|
884
|
-
for p in g.patches:
|
|
885
|
-
g.annotate(
|
|
886
|
-
f"{p.get_height():.1f}",
|
|
887
|
-
(p.get_x() + 0.25, p.get_height() + 0.01),
|
|
888
|
-
xytext=(0, 1),
|
|
889
|
-
textcoords="offset points",
|
|
890
|
-
va="bottom",
|
|
891
|
-
)
|
|
531
|
+
This plots counts for all genotypes present in X. It supports IUPAC single-letter genotypes and integer encodings. Missing markers '-', '.', '?' are normalized to 'N'. Bars are annotated with counts and percentages.
|
|
892
532
|
|
|
893
|
-
|
|
894
|
-
|
|
895
|
-
|
|
896
|
-
|
|
533
|
+
Args:
|
|
534
|
+
X (np.ndarray | pd.DataFrame | list | torch.Tensor): Array-like genotype matrix. Rows=loci, cols=samples (any orientation is OK). Elements are IUPAC one-letter genotypes (e.g., 'A','C','G','T','N','R',...) or integers (e.g., 0/1/2[/3]).
|
|
535
|
+
is_imputed (bool): Whether these genotypes are imputed. Affects the title only. Defaults to False.
|
|
536
|
+
"""
|
|
537
|
+
# Flatten X to a 1D Series
|
|
538
|
+
if isinstance(X, pd.DataFrame):
|
|
539
|
+
arr = X.values
|
|
540
|
+
elif torch.is_tensor(X):
|
|
541
|
+
arr = X.detach().cpu().numpy()
|
|
542
|
+
else:
|
|
543
|
+
arr = np.asarray(X)
|
|
544
|
+
|
|
545
|
+
s = pd.Series(arr.ravel())
|
|
546
|
+
|
|
547
|
+
# Detect string vs numeric encodings and normalize
|
|
548
|
+
if s.dtype.kind in ("O", "U", "S"): # string-like → IUPAC path
|
|
549
|
+
s = s.astype(str).str.upper().replace({"-": "N", ".": "N", "?": "N"})
|
|
550
|
+
x_label = "Genotype (IUPAC)"
|
|
551
|
+
|
|
552
|
+
# Define canonical order: N, A/C/T/G, then IUPAC ambiguity codes.
|
|
553
|
+
canonical = ["A", "C", "T", "G"]
|
|
554
|
+
iupac_ambiguity = sorted(["M", "R", "W", "S", "Y", "K", "V", "H", "D", "B"])
|
|
555
|
+
base_order = ["N"] + canonical + iupac_ambiguity
|
|
556
|
+
else: # numeric path (e.g., 0/1/2/[3], -1 for missing)
|
|
557
|
+
# Map common missing sentinels to 'N', keep others as strings for
|
|
558
|
+
# labeling
|
|
559
|
+
s = s.astype(float) # allow NaN comparisons
|
|
560
|
+
s = s.where(~np.isin(s, [-1, np.nan]), other=np.nan)
|
|
561
|
+
s = s.fillna("N").astype(int, errors="ignore").astype(str)
|
|
562
|
+
|
|
563
|
+
x_label = "Genotype (Integer-encoded)"
|
|
564
|
+
|
|
565
|
+
# Support both ternary and quaternary encodings; keep a stable order
|
|
566
|
+
base_order = ["N", "0", "1", "2", "3"]
|
|
567
|
+
|
|
568
|
+
# Include any unexpected symbols at the end (sorted) so nothing is
|
|
569
|
+
# dropped
|
|
570
|
+
extras = sorted(set(s.unique()) - set(base_order))
|
|
571
|
+
full_order = base_order + [e for e in extras if e not in base_order]
|
|
572
|
+
|
|
573
|
+
# Count and reindex to show zero-count categories
|
|
574
|
+
counts = s.value_counts().reindex(full_order, fill_value=0)
|
|
575
|
+
df = counts.rename_axis("Genotype").reset_index(name="Count")
|
|
576
|
+
df["Percent"] = df["Count"] / df["Count"].sum() * 100
|
|
577
|
+
|
|
578
|
+
title = "Imputed Genotype Counts" if is_imputed else "Genotype Counts"
|
|
579
|
+
|
|
580
|
+
# --- Plot ---
|
|
581
|
+
fig, ax = plt.subplots(figsize=(8, 5))
|
|
582
|
+
sns.despine(fig=fig)
|
|
583
|
+
|
|
584
|
+
ax = sns.barplot(
|
|
585
|
+
data=df,
|
|
586
|
+
x="Genotype",
|
|
587
|
+
y="Percent",
|
|
588
|
+
hue="Genotype",
|
|
589
|
+
order=full_order,
|
|
590
|
+
errorbar=None,
|
|
591
|
+
ax=ax,
|
|
592
|
+
palette="Set1",
|
|
593
|
+
legend=False,
|
|
594
|
+
fill=True,
|
|
897
595
|
)
|
|
898
|
-
plt.close()
|
|
899
596
|
|
|
900
|
-
|
|
901
|
-
|
|
902
|
-
|
|
903
|
-
|
|
904
|
-
|
|
905
|
-
sns.scatterplot(x=z_mean[:, 0], y=z_mean[:, 1], ax=ax)
|
|
906
|
-
ax.set_xlabel("Latent Dimension 1")
|
|
907
|
-
ax.set_ylabel("Latent Dimension 2")
|
|
597
|
+
ax.set_xlabel(x_label)
|
|
598
|
+
ax.set_ylabel("Percent")
|
|
599
|
+
ax.set_title(title)
|
|
600
|
+
ax.set_ylim([0, 50])
|
|
908
601
|
|
|
909
|
-
|
|
910
|
-
f"{prefix}_output",
|
|
911
|
-
"plots",
|
|
912
|
-
"Unsupervised",
|
|
913
|
-
"VAE",
|
|
914
|
-
"label_clusters.png",
|
|
915
|
-
)
|
|
602
|
+
fig.tight_layout()
|
|
916
603
|
|
|
917
|
-
if
|
|
918
|
-
|
|
604
|
+
suffix = "imputed" if is_imputed else "original"
|
|
605
|
+
fn = self.output_dir / f"gt_distributions_{suffix}.{self.plot_format}"
|
|
606
|
+
fig.savefig(fn, dpi=300)
|
|
919
607
|
|
|
920
|
-
|
|
608
|
+
if self.show_plots:
|
|
609
|
+
plt.show()
|
|
610
|
+
plt.close(fig)
|