pg-sui 1.0.2.1__py3-none-any.whl → 1.6.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pg-sui might be problematic. Click here for more details.

Files changed (112)
  1. {pg_sui-1.0.2.1.dist-info → pg_sui-1.6.8.dist-info}/METADATA +51 -70
  2. pg_sui-1.6.8.dist-info/RECORD +78 -0
  3. {pg_sui-1.0.2.1.dist-info → pg_sui-1.6.8.dist-info}/WHEEL +1 -1
  4. pg_sui-1.6.8.dist-info/entry_points.txt +4 -0
  5. pg_sui-1.6.8.dist-info/top_level.txt +1 -0
  6. pgsui/__init__.py +35 -54
  7. pgsui/_version.py +34 -0
  8. pgsui/cli.py +635 -0
  9. pgsui/data_processing/config.py +576 -0
  10. pgsui/data_processing/containers.py +1782 -0
  11. pgsui/data_processing/transformers.py +121 -1103
  12. pgsui/electron/app/__main__.py +5 -0
  13. pgsui/electron/app/icons/icons/1024x1024.png +0 -0
  14. pgsui/electron/app/icons/icons/128x128.png +0 -0
  15. pgsui/electron/app/icons/icons/16x16.png +0 -0
  16. pgsui/electron/app/icons/icons/24x24.png +0 -0
  17. pgsui/electron/app/icons/icons/256x256.png +0 -0
  18. pgsui/electron/app/icons/icons/32x32.png +0 -0
  19. pgsui/electron/app/icons/icons/48x48.png +0 -0
  20. pgsui/electron/app/icons/icons/512x512.png +0 -0
  21. pgsui/electron/app/icons/icons/64x64.png +0 -0
  22. pgsui/electron/app/icons/icons/icon.icns +0 -0
  23. pgsui/electron/app/icons/icons/icon.ico +0 -0
  24. pgsui/electron/app/main.js +189 -0
  25. pgsui/electron/app/package-lock.json +6893 -0
  26. pgsui/electron/app/package.json +50 -0
  27. pgsui/electron/app/preload.js +15 -0
  28. pgsui/electron/app/server.py +146 -0
  29. pgsui/electron/app/ui/logo.png +0 -0
  30. pgsui/electron/app/ui/renderer.js +130 -0
  31. pgsui/electron/app/ui/styles.css +59 -0
  32. pgsui/electron/app/ui/ui_shim.js +72 -0
  33. pgsui/electron/bootstrap.py +43 -0
  34. pgsui/electron/launch.py +59 -0
  35. pgsui/electron/package.json +14 -0
  36. pgsui/example_data/popmaps/{test.popmap → phylogen_nomx.popmap} +185 -99
  37. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz +0 -0
  38. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz.tbi +0 -0
  39. pgsui/impute/deterministic/imputers/allele_freq.py +691 -0
  40. pgsui/impute/deterministic/imputers/mode.py +679 -0
  41. pgsui/impute/deterministic/imputers/nmf.py +221 -0
  42. pgsui/impute/deterministic/imputers/phylo.py +971 -0
  43. pgsui/impute/deterministic/imputers/ref_allele.py +530 -0
  44. pgsui/impute/supervised/base.py +339 -0
  45. pgsui/impute/supervised/imputers/hist_gradient_boosting.py +293 -0
  46. pgsui/impute/supervised/imputers/random_forest.py +287 -0
  47. pgsui/impute/unsupervised/base.py +924 -0
  48. pgsui/impute/unsupervised/callbacks.py +89 -263
  49. pgsui/impute/unsupervised/imputers/autoencoder.py +972 -0
  50. pgsui/impute/unsupervised/imputers/nlpca.py +1264 -0
  51. pgsui/impute/unsupervised/imputers/ubp.py +1288 -0
  52. pgsui/impute/unsupervised/imputers/vae.py +957 -0
  53. pgsui/impute/unsupervised/loss_functions.py +158 -0
  54. pgsui/impute/unsupervised/models/autoencoder_model.py +208 -558
  55. pgsui/impute/unsupervised/models/nlpca_model.py +149 -468
  56. pgsui/impute/unsupervised/models/ubp_model.py +198 -1317
  57. pgsui/impute/unsupervised/models/vae_model.py +259 -618
  58. pgsui/impute/unsupervised/nn_scorers.py +215 -0
  59. pgsui/utils/classification_viz.py +591 -0
  60. pgsui/utils/misc.py +35 -480
  61. pgsui/utils/plotting.py +514 -824
  62. pgsui/utils/scorers.py +212 -438
  63. pg_sui-1.0.2.1.dist-info/RECORD +0 -75
  64. pg_sui-1.0.2.1.dist-info/top_level.txt +0 -3
  65. pgsui/example_data/phylip_files/test_n10.phy +0 -118
  66. pgsui/example_data/phylip_files/test_n100.phy +0 -118
  67. pgsui/example_data/phylip_files/test_n2.phy +0 -118
  68. pgsui/example_data/phylip_files/test_n500.phy +0 -118
  69. pgsui/example_data/structure_files/test.nopops.1row.10sites.str +0 -117
  70. pgsui/example_data/structure_files/test.nopops.2row.100sites.str +0 -234
  71. pgsui/example_data/structure_files/test.nopops.2row.10sites.str +0 -234
  72. pgsui/example_data/structure_files/test.nopops.2row.30sites.str +0 -234
  73. pgsui/example_data/structure_files/test.nopops.2row.allsites.str +0 -234
  74. pgsui/example_data/structure_files/test.pops.1row.10sites.str +0 -117
  75. pgsui/example_data/structure_files/test.pops.2row.10sites.str +0 -234
  76. pgsui/example_data/trees/test.iqtree +0 -376
  77. pgsui/example_data/trees/test.qmat +0 -5
  78. pgsui/example_data/trees/test.rate +0 -2033
  79. pgsui/example_data/trees/test.tre +0 -1
  80. pgsui/example_data/trees/test_n10.rate +0 -19
  81. pgsui/example_data/trees/test_n100.rate +0 -109
  82. pgsui/example_data/trees/test_n500.rate +0 -509
  83. pgsui/example_data/trees/test_siterates.txt +0 -2024
  84. pgsui/example_data/trees/test_siterates_n10.txt +0 -10
  85. pgsui/example_data/trees/test_siterates_n100.txt +0 -100
  86. pgsui/example_data/trees/test_siterates_n500.txt +0 -500
  87. pgsui/example_data/vcf_files/test.vcf +0 -244
  88. pgsui/example_data/vcf_files/test.vcf.gz +0 -0
  89. pgsui/example_data/vcf_files/test.vcf.gz.tbi +0 -0
  90. pgsui/impute/estimators.py +0 -735
  91. pgsui/impute/impute.py +0 -1486
  92. pgsui/impute/simple_imputers.py +0 -1439
  93. pgsui/impute/supervised/iterative_imputer_fixedparams.py +0 -785
  94. pgsui/impute/supervised/iterative_imputer_gridsearch.py +0 -1027
  95. pgsui/impute/unsupervised/keras_classifiers.py +0 -702
  96. pgsui/impute/unsupervised/models/in_development/cnn_model.py +0 -486
  97. pgsui/impute/unsupervised/neural_network_imputers.py +0 -1424
  98. pgsui/impute/unsupervised/neural_network_methods.py +0 -1549
  99. pgsui/pg_sui.py +0 -261
  100. pgsui/utils/sequence_tools.py +0 -407
  101. simulation/sim_benchmarks.py +0 -333
  102. simulation/sim_treeparams.py +0 -475
  103. test/__init__.py +0 -0
  104. test/pg_sui_simtest.py +0 -215
  105. test/pg_sui_testing.py +0 -523
  106. test/test.py +0 -297
  107. test/test_pgsui.py +0 -374
  108. test/test_tkc.py +0 -214
  109. {pg_sui-1.0.2.1.dist-info → pg_sui-1.6.8.dist-info/licenses}/LICENSE +0 -0
  110. pgsui/{example_data/trees → electron/app}/__init__.py +0 -0
  111. pgsui/impute/{unsupervised/models/in_development → supervised/imputers}/__init__.py +0 -0
  112. {simulation → pgsui/impute/unsupervised/imputers}/__init__.py +0 -0
pgsui/utils/plotting.py CHANGED
@@ -1,920 +1,610 @@
1
- import os
2
- import sys
3
- from itertools import cycle
4
- from pathlib import Path
1
+ import logging
5
2
  import warnings
3
+ from pathlib import Path
4
+ from typing import Dict, List, Literal, Optional, Sequence
6
5
 
7
- warnings.simplefilter(action="ignore", category=FutureWarning)
6
+ import matplotlib as mpl
8
7
 
8
+ # Use Agg backend for headless plotting
9
+ mpl.use("Agg")
9
10
 
11
+ import matplotlib.pyplot as plt
10
12
  import numpy as np
13
+ import optuna
11
14
  import pandas as pd
12
- import matplotlib.pyplot as plt
13
15
  import seaborn as sns
14
- import plotly.express as px
15
-
16
- from sklearn.decomposition import PCA
17
- from sklearn.preprocessing import StandardScaler
18
- from sklearn_genetic.utils import logbook_to_pandas
19
- from sklearn.metrics import ConfusionMatrixDisplay
20
-
21
- try:
22
- from . import misc
23
- except (ModuleNotFoundError, ValueError, ImportError):
24
- from utils import misc
16
+ import torch
17
+ from optuna.exceptions import ExperimentalWarning
18
+ from sklearn.metrics import (
19
+ ConfusionMatrixDisplay,
20
+ auc,
21
+ average_precision_score,
22
+ precision_recall_curve,
23
+ roc_curve,
24
+ )
25
+ from sklearn.preprocessing import label_binarize
26
+ from snpio.utils.logging import LoggerManager
27
+
28
+ from pgsui.utils import misc
29
+
30
+ # Quiet Matplotlib/fontTools INFO logging when saving PDF/SVG
31
+ for name in (
32
+ "fontTools",
33
+ "fontTools.subset",
34
+ "fontTools.ttLib",
35
+ "matplotlib.font_manager",
36
+ ):
37
+ lg = logging.getLogger(name)
38
+ lg.setLevel(logging.WARNING)
39
+ lg.propagate = False
25
40
 
26
41
 
27
42
  class Plotting:
28
- """Functions for plotting imputer scoring and results."""
29
-
30
- @staticmethod
31
- def plot_grid_search(cv_results, nn_method, prefix):
32
- """Plot cv_results\_ from a grid search for each parameter.
33
-
34
- Saves a figure to disk.
43
+ """Class for plotting imputer scoring and results.
44
+
45
+ This class is used to plot the performance metrics of imputation models. It can plot ROC and Precision-Recall curves, model history, and the distribution of genotypes in the dataset.
46
+
47
+ Example:
48
+ >>> from pgsui import Plotting
49
+ >>> plotter = Plotting(model_name="ImputeVAE")
50
+ >>> plotter.plot_metrics(metrics, num_classes)
51
+ >>> plotter.plot_history(history)
52
+ >>> plotter.plot_certainty_heatmap(y_certainty)
53
+ >>> plotter.plot_confusion_matrix(y_true_1d, y_pred_1d)
54
+ >>> plotter.plot_gt_distribution(df)
55
+ >>> plotter.plot_label_clusters(z_mean, z_log_var)
56
+
57
+ Attributes:
58
+ model_name (str): Name of the model.
59
+ prefix (str): Prefix for the output directory.
60
+ plot_format (Literal["pdf", "png", "jpeg", "jpg"]): Format for the plots ('pdf', 'png', 'jpeg', 'jpg').
61
+ plot_fontsize (int): Font size for the plots.
62
+ plot_dpi (int): Dots per inch for the plots.
63
+ title_fontsize (int): Font size for the plot titles.
64
+ show_plots (bool): Whether to display the plots.
65
+ output_dir (Path): Directory where plots will be saved.
66
+ logger (logging.Logger): Logger instance for logging messages.
67
+ """
68
+
69
+ def __init__(
70
+ self,
71
+ model_name: str,
72
+ *,
73
+ prefix: str = "pgsui",
74
+ plot_format: Literal["pdf", "png", "jpeg", "jpg"] = "pdf",
75
+ plot_fontsize: int = 18,
76
+ plot_dpi: int = 300,
77
+ title_fontsize: int = 20,
78
+ despine: bool = True,
79
+ show_plots: bool = False,
80
+ verbose: int = 0,
81
+ debug: bool = False,
82
+ ) -> None:
83
+ """Initialize the Plotting object.
84
+
85
+ This class is used to plot the performance metrics of imputation models. It can plot ROC and Precision-Recall curves, model history, and the distribution of genotypes in the dataset.
35
86
 
36
87
  Args:
37
- cv_results (np.ndarray): the cv_results\_ attribute from a trained grid search object.
38
-
39
- nn_method (str): Neural network algorithm name.
40
-
41
- prefix (str): Prefix to use for saving the plot to file.
88
+ model_name (str): Name of the model.
89
+ prefix (str, optional): Prefix for the output directory. Defaults to 'pgsui'.
90
+ plot_format (Literal["pdf", "png", "jpeg", "jpg"]): Format for the plots ('pdf', 'png', 'jpeg', 'jpg'). Defaults to 'pdf'.
91
+ plot_fontsize (int): Font size for the plots. Defaults to 18.
92
+ plot_dpi (int): Dots per inch for the plots. Defaults to 300.
93
+ title_fontsize (int): Font size for the plot titles. Defaults to 20.
94
+ despine (bool): Whether to remove the top and right spines from the plots. Defaults to True.
95
+ show_plots (bool): Whether to display the plots. Defaults to False.
96
+ verbose (int): Verbosity level for logging. Defaults to 0.
97
+ debug (bool): Whether to enable debug mode. Defaults to False.
42
98
  """
43
- ## Results from grid search
44
- results = pd.DataFrame(cv_results)
45
- means_test = [col for col in results if col.startswith("mean_test_")]
46
- filter_col = [col for col in results if col.startswith("param_")]
47
- params_df = results[filter_col].astype(str)
48
- for i, col in enumerate(means_test):
49
- params_df[col] = results[means_test[i]]
50
-
51
- # Get number of needed subplot rows.
52
- tot = len(filter_col)
53
- cols = 4
54
- rows = int(np.ceil(tot / cols))
55
-
56
- fig = plt.figure(1, figsize=(20, 10))
57
- fig.tight_layout(pad=3.0)
58
-
59
- # Set font properties.
60
- font = {"size": 12}
61
- plt.rc("font", **font)
62
-
63
- for i, p in enumerate(filter_col, start=1):
64
- ax = fig.add_subplot(rows, cols, i)
65
-
66
- # Plot each metric.
67
- for col in means_test:
68
- # Get maximum score for each parameter setting.
69
- df_plot = params_df.groupby(p)[col].agg("max")
70
-
71
- # Convert to float if not supposed to be string.
72
- try:
73
- df_plot.index = df_plot.index.astype(float)
74
- except TypeError:
75
- pass
76
-
77
- # Sort by index (numerically if possible).
78
- df_plot = df_plot.sort_index()
79
-
80
- # Remove prefix from score name.
81
- col_new_name = col[len("mean_test_") :]
82
-
83
- ax.plot(
84
- df_plot.index.astype(str),
85
- df_plot.values,
86
- "-o",
87
- label=col_new_name,
88
- )
89
-
90
- ax.legend(loc="best")
91
-
92
- param_new_name = p[len("param_") :]
93
- ax.set_xlabel(param_new_name.lower())
94
- ax.set_ylabel("Max Score")
95
- ax.set_ylim([0, 1])
96
-
97
- fig.savefig(
98
- os.path.join(
99
- f"{prefix}_output",
100
- "plots",
101
- "Unsupervised",
102
- nn_method,
103
- "gridsearch_metrics.pdf",
104
- ),
105
- bbox_inches="tight",
106
- facecolor="white",
99
+ logman = LoggerManager(
100
+ name=__name__, prefix=prefix, verbose=verbose, debug=debug
107
101
  )
102
+ self.logger = logman.get_logger()
103
+
104
+ self.model_name = model_name
105
+ self.prefix = prefix
106
+ self.plot_format = plot_format
107
+ self.plot_fontsize = plot_fontsize
108
+ self.plot_dpi = plot_dpi
109
+ self.title_fontsize = title_fontsize
110
+ self.show_plots = show_plots
111
+
112
+ if self.plot_format.startswith("."):
113
+ self.plot_format = self.plot_format.lstrip(".")
114
+
115
+ self.param_dict = {
116
+ "axes.labelsize": self.plot_fontsize,
117
+ "axes.titlesize": self.title_fontsize,
118
+ "axes.spines.top": despine,
119
+ "axes.spines.right": despine,
120
+ "xtick.labelsize": self.plot_fontsize,
121
+ "ytick.labelsize": self.plot_fontsize,
122
+ "legend.fontsize": self.plot_fontsize,
123
+ "legend.facecolor": "white",
124
+ "figure.titlesize": self.title_fontsize,
125
+ "figure.dpi": self.plot_dpi,
126
+ "figure.facecolor": "white",
127
+ "axes.linewidth": 2.0,
128
+ "lines.linewidth": 2.0,
129
+ "font.size": self.plot_fontsize,
130
+ "savefig.bbox": "tight",
131
+ "savefig.facecolor": "white",
132
+ "savefig.dpi": self.plot_dpi,
133
+ }
108
134
 
109
- @staticmethod
110
- def plot_metrics(metrics, num_classes, prefix, nn_method):
111
- """Plot AUC-ROC and Precision-Recall performance metrics for neural network classifier.
135
+ mpl.rcParams.update(self.param_dict)
112
136
 
113
- Saves plot to PDF file on disk.
137
+ unsuper = {"ImputeVAE", "ImputeNLPCA", "ImputeAutoencoder", "ImputeUBP"}
138
+ det = {
139
+ "ImputeRefAllele",
140
+ "ImputeMostFrequent",
141
+ "ImputeMostFrequentPerPop",
142
+ "ImputePhylo",
143
+ }
144
+ sup = {"ImputeRandomForest", "ImputeHistGradientBoosting"}
145
+
146
+ if model_name in unsuper:
147
+ plot_dir = "Unsupervised"
148
+ elif model_name in det:
149
+ plot_dir = "Deterministic"
150
+ elif model_name in sup:
151
+ plot_dir = "Supervised"
152
+ else:
153
+ msg = f"model_name '{model_name}' not recognized."
154
+ self.logger.error(msg)
155
+ raise ValueError(msg)
114
156
 
115
- Args:
116
- metrics (Dict[str, Any]): Per-class, micro, and macro-averaged metrics including accuracy, ROC-AUC, and Precision-Recall with Average Precision scores.
157
+ self.output_dir = Path(f"{self.prefix}_output", plot_dir)
158
+ self.output_dir = self.output_dir / "plots" / model_name
159
+ self.output_dir.mkdir(parents=True, exist_ok=True)
117
160
 
118
- num_classes (int): Number of classes evaluated.
161
+ def plot_tuning(
162
+ self,
163
+ study: optuna.study.Study,
164
+ model_name: str,
165
+ target_name: str = "Objective Value",
166
+ ) -> None:
167
+ """Plot the optimization history of a study.
119
168
 
120
- prefix (str): Prefix to use for output plot.
169
+ This method plots the optimization history of a study. The plot is saved to disk as a ``<plot_format>`` file.
121
170
 
122
- nn_method (str): Neural network algorithm being used.
171
+ Args:
172
+ study (optuna.study.Study): Optuna study object.
173
+ model_name (str): Name of the model.
174
+ target_name (str, optional): Name of the target value. Defaults to 'Objective Value'.
123
175
  """
124
- # Set font properties.
125
- font = {"size": 12}
126
- plt.rc("font", **font)
127
-
128
- fn = os.path.join(
129
- f"{prefix}_output",
130
- "plots",
131
- "Unsupervised",
132
- nn_method,
133
- f"auc_pr_curves.pdf",
134
- )
135
- fig = plt.figure(figsize=(20, 10))
176
+ with warnings.catch_warnings():
177
+ warnings.filterwarnings("ignore", category=UserWarning)
178
+ warnings.filterwarnings("ignore", category=FutureWarning)
179
+ warnings.filterwarnings("ignore", category=ExperimentalWarning)
136
180
 
137
- acc = round(metrics["accuracy"] * 100, 2)
138
- ham = round(metrics["hamming"], 2)
181
+ od = self.output_dir / "optimize"
182
+ target_name = target_name.title()
139
183
 
140
- fig.suptitle(
141
- f"Performance Metrics\nAccuracy: {acc}\nHamming Loss: {ham}"
142
- )
143
- axs = fig.subplots(nrows=1, ncols=2)
144
- plt.subplots_adjust(hspace=0.5)
145
-
146
- # Line weight
147
- lw = 2
148
-
149
- roc_auc = metrics["roc_auc"]
150
- pr_ap = metrics["precision_recall"]
151
-
152
- metric_list = [roc_auc, pr_ap]
153
-
154
- for metric, ax in zip(metric_list, axs):
155
- if "fpr_micro" in metric:
156
- prefix1 = "fpr"
157
- prefix2 = "tpr"
158
- lab1 = "ROC"
159
- lab2 = "AUC"
160
- xlab = "False Positive Rate"
161
- ylab = "True Positive Rate"
162
- title = "Receiver Operating Characteristic (ROC)"
163
- baseline = [0, 1]
164
-
165
- elif "recall_micro" in metric:
166
- prefix1 = "recall"
167
- prefix2 = "precision"
168
- lab1 = "Precision-Recall"
169
- lab2 = "AP"
170
- xlab = "Recall"
171
- ylab = "Precision"
172
- title = "Precision-Recall"
173
- baseline = [metric["baseline"], metric["baseline"]]
174
-
175
- # Plot iso-f1 curves.
176
- f_scores = np.linspace(0.2, 0.8, num=4)
177
- for i, f_score in enumerate(f_scores):
178
- x = np.linspace(0.01, 1)
179
- y = f_score * x / (2 * x - f_score)
180
- ax.plot(
181
- x[y >= 0],
182
- y[y >= 0],
183
- color="gray",
184
- alpha=0.2,
185
- linewidth=lw,
186
- label="Iso-F1 Curves" if i == 0 else "",
187
- )
188
- ax.annotate(f"F1={f_score:0.1f}", xy=(0.9, y[45] + 0.02))
189
-
190
- # Plot ROC curves.
191
- ax.plot(
192
- metric[f"{prefix1}_micro"],
193
- metric[f"{prefix2}_micro"],
194
- label=f"Micro-averaged {lab1} Curve ({lab2} = {metric['micro']:.2f})",
195
- color="deeppink",
196
- linestyle=":",
197
- linewidth=4,
184
+ ax = optuna.visualization.matplotlib.plot_optimization_history(
185
+ study, target_name=target_name
198
186
  )
199
-
200
- ax.plot(
201
- metric[f"{prefix1}_macro"],
202
- metric[f"{prefix2}_macro"],
203
- label=f"Macro-averaged {lab1} Curve ({lab2} = {metric['macro']:.2f})",
204
- color="navy",
205
- linestyle=":",
206
- linewidth=4,
187
+ ax.set_title(f"{model_name} Optimization History")
188
+ ax.set_xlabel("Trial")
189
+ ax.set_ylabel(target_name)
190
+ ax.legend(
191
+ loc="best",
192
+ shadow=True,
193
+ fancybox=True,
194
+ fontsize=mpl.rcParamsDefault["legend.fontsize"],
207
195
  )
208
196
 
209
- colors = cycle(["aqua", "darkorange", "cornflowerblue"])
210
- for i, color in zip(range(num_classes), colors):
211
- if f"{prefix1}_{i}" in metric:
212
- ax.plot(
213
- metric[f"{prefix1}_{i}"],
214
- metric[f"{prefix2}_{i}"],
215
- color=color,
216
- lw=lw,
217
- label=f"{lab1} Curve of class {i} ({lab2} = {metric[i]:.2f})",
218
- )
219
-
220
- if "fpr_micro" in metric:
221
- # Make center baseline
222
- ax.plot(
223
- baseline,
224
- baseline,
225
- "k--",
226
- linewidth=lw,
227
- label="No Classification Skill",
228
- )
229
- else:
230
- ax.plot(
231
- [0, 1],
232
- baseline,
233
- "k--",
234
- linewidth=lw,
235
- label="No Classification Skill",
236
- )
237
-
238
- ax.set_xlim(0.0, 1.0)
239
- ax.set_ylim(0.0, 1.05)
240
- ax.set_xlabel(f"{xlab}")
241
- ax.set_ylabel(f"{ylab}")
242
- ax.set_title(f"{title}")
243
- ax.legend(loc="best")
244
-
245
- fig.savefig(fn, bbox_inches="tight", facecolor="white")
246
- plt.close()
247
- plt.clf()
248
- plt.cla()
249
-
250
- @staticmethod
251
- def plot_search_space(
252
- estimator,
253
- height=2,
254
- s=25,
255
- features=None,
256
- ):
257
- """Make density and contour plots for showing search space during grid search.
258
-
259
- Modified from sklearn-genetic-opt function to implement exception handling.
197
+ fn = od / f"optuna_optimization_history.{self.plot_format}"
260
198
 
261
- Args:
262
- estimator (sklearn estimator object): A fitted estimator from :class:`~sklearn_genetic.GASearchCV`.
263
-
264
- height (float, optional): Height of each facet. Defaults to 2.
265
-
266
- s (float, optional): Size of the markers in scatter plot. Defaults to 5.
267
-
268
- features (list, optional): Subset of features to plot, if ``None`` it plots all the features by default. Defaults to None.
269
-
270
- Returns:
271
- g (seaborn.PairGrid): Pair plot of the used hyperparameters during the search.
272
- """
273
- sns.set_style("white")
274
-
275
- df = logbook_to_pandas(estimator.logbook)
276
- if features:
277
- _stats = df[features]
278
- else:
279
- variables = [*estimator.space.parameters, "score"]
280
- _stats = df[variables]
281
-
282
- g = sns.PairGrid(_stats, diag_sharey=False, height=height)
199
+ if not fn.parent.exists():
200
+ fn.parent.mkdir(parents=True, exist_ok=True)
283
201
 
284
- g = g.map_upper(sns.scatterplot, s=s, color="r", alpha=0.2)
202
+ plt.savefig(fn)
203
+ plt.close()
285
204
 
286
- try:
287
- g = g.map_lower(
288
- sns.kdeplot,
289
- shade=True,
290
- cmap=sns.color_palette("ch:s=.25,rot=-.25", as_cmap=True),
205
+ ax = optuna.visualization.matplotlib.plot_edf(
206
+ study, target_name=target_name
291
207
  )
292
- except np.linalg.LinAlgError as err:
293
- if "singular matrix" in str(err).lower():
294
- g = g.map_lower(sns.scatterplot, s=s, color="b", alpha=1.0)
295
- else:
296
- raise
297
-
298
- try:
299
- g = g.map_diag(
300
- sns.kdeplot,
301
- shade=True,
302
- palette="crest",
303
- alpha=0.2,
304
- color="red",
208
+ ax.set_title(f"{model_name} Empirical Distribution Function (EDF)")
209
+ ax.set_xlabel(target_name)
210
+ ax.set_ylabel(f"{model_name} Cumulative Probability")
211
+ ax.legend(
212
+ loc="best",
213
+ shadow=True,
214
+ fancybox=True,
215
+ fontsize=mpl.rcParamsDefault["legend.fontsize"],
305
216
  )
306
- except np.linalg.LinAlgError as err:
307
- if "singular matrix" in str(err).lower():
308
- g = g.map_diag(sns.histplot, color="red", alpha=1.0, kde=False)
309
-
310
- return g
311
-
312
- @staticmethod
313
- def visualize_missingness(
314
- genotype_data,
315
- df,
316
- zoom=True,
317
- prefix="imputer",
318
- horizontal_space=0.6,
319
- vertical_space=0.6,
320
- bar_color="gray",
321
- heatmap_palette="magma",
322
- plot_format="pdf",
323
- dpi=300,
324
- ):
325
- """Make multiple plots to visualize missing data.
326
-
327
- Args:
328
- genotype_data (GenotypeData): Initialized GentoypeData object.
329
217
 
330
- df (pandas.DataFrame): DataFrame with snps to visualize.
331
-
332
- zoom (bool, optional): If True, zooms in to the missing proportion range on some of the plots. If False, the plot range is fixed at [0, 1]. Defaults to True.
218
+ plt.savefig(fn.with_stem("optuna_edf_plot"))
219
+ plt.close()
333
220
 
334
- prefix (str, optional): Prefix for output directory and files. Plots and files will be written to a directory called <prefix>_reports. The report directory will be created if it does not already exist. If prefix is None, then the reports directory will not have a prefix. Defaults to 'imputer'.
221
+ ax = optuna.visualization.matplotlib.plot_param_importances(
222
+ study, target_name=target_name
223
+ )
224
+ ax.set_xlabel("Parameter Importance")
225
+ ax.set_ylabel("Parameter")
226
+ ax.legend(loc="best", shadow=True, fancybox=True)
335
227
 
336
- horizontal_space (float, optional): Set width spacing between subplots. If your plot are overlapping horizontally, increase horizontal_space. If your plots are too far apart, decrease it. Defaults to 0.6.
228
+ plt.savefig(fn.with_stem("optuna_param_importances_plot"))
229
+ plt.close()
337
230
 
338
- vertical_space (float, optioanl): Set height spacing between subplots. If your plots are overlapping vertically, increase vertical_space. If your plots are too far apart, decrease it. Defaults to 0.6.
231
+ ax = optuna.visualization.matplotlib.plot_timeline(study)
232
+ ax.set_title(f"{model_name} Timeline Plot")
233
+ ax.set_xlabel("Datetime")
234
+ ax.set_ylabel("Trial")
235
+ plt.savefig(fn.with_stem("optuna_timeline_plot"))
236
+ plt.close()
339
237
 
340
- bar_color (str, optional): Color of the bars on the non-stacked barplots. Can be any color supported by matplotlib. See matplotlib.pyplot.colors documentation. Defaults to 'gray'.
238
+ # Reset the style from Optuna's plotting.
239
+ sns.set_style("white", rc=self.param_dict)
240
+ mpl.rcParams.update(self.param_dict)
341
241
 
342
- heatmap_palette (str, optional): Palette to use for heatmap plot. Can be any palette supported by seaborn. See seaborn documentation. Defaults to 'magma'.
242
+ def plot_metrics(
243
+ self,
244
+ y_true: np.ndarray,
245
+ y_pred_proba: np.ndarray,
246
+ metrics: Dict[str, float],
247
+ label_names: Optional[Sequence[str]] = None,
248
+ prefix: str = "",
249
+ ) -> None:
250
+ """Plot multi-class ROC-AUC and Precision-Recall curves.
343
251
 
344
- plot_format (str, optional): Format to save plots. Can be any of the following: "pdf", "png", "svg", "ps", "eps". Defaults to "pdf".
252
+ This method plots the multi-class ROC-AUC and Precision-Recall curves. The plot is saved to disk as a ``<plot_format>`` file.
345
253
 
346
- dpi (int): The resolution in dots per inch. Defaults to 300.
254
+ Args:
255
+ y_true (np.ndarray): 1D array of true integer labels in [0, n_classes-1].
256
+ y_pred_proba (np.ndarray): (n_samples, n_classes) array of predicted probabilities.
257
+ metrics (Dict[str, float]): Dict of summary metrics to annotate the figure.
258
+ label_names (Optional[Sequence[str]]): Optional sequence of class names (length must equal n_classes).
259
+ If provided, legends will use these names instead of 'Class i'.
260
+ prefix (str): Optional prefix for the output filename.
347
261
 
348
- Returns:
349
- pandas.DataFrame: Per-locus missing data proportions.
350
- pandas.DataFrame: Per-individual missing data proportions.
351
- pandas.DataFrame: Per-population + per-locus missing data proportions.
352
- pandas.DataFrame: Per-population missing data proportions.
353
- pandas.DataFrame: Per-individual and per-population missing data proportions.
262
+ Raises:
263
+ ValueError: If model_name is not recognized (legacy guard).
354
264
  """
265
+ num_classes = y_pred_proba.shape[1]
355
266
 
356
- loc, ind, poploc, poptotal, indpop = genotype_data.calc_missing(df)
357
-
358
- ncol = 3
359
- nrow = 1 if genotype_data.pops is None else 2
360
-
361
- fig, axes = plt.subplots(nrow, ncol, figsize=(8, 11))
362
- plt.subplots_adjust(wspace=horizontal_space, hspace=vertical_space)
363
- fig.suptitle("Missingness Report")
364
-
365
- ax = axes[0, 0]
366
-
367
- ax.set_title("Per-Individual")
368
- ax.barh(genotype_data.samples, ind, color=bar_color, height=1.0)
369
- if not zoom:
370
- ax.set_xlim([0, 1])
371
- ax.set_ylabel("Sample")
372
- ax.set_xlabel("Missing Prop.")
373
- ax.tick_params(
374
- axis="y",
375
- which="both",
376
- left=False,
377
- right=False,
378
- labelleft=False,
379
- )
380
-
381
- ax = axes[0, 1]
267
+ # Validate/normalize label names
268
+ if label_names is not None and len(label_names) != num_classes:
269
+ self.logger.warning(
270
+ f"plot_metrics: len(label_names)={len(label_names)} "
271
+ f"!= n_classes={num_classes}. Ignoring label_names."
272
+ )
273
+ label_names = None
274
+ if label_names is None:
275
+ label_names = [f"Class {i}" for i in range(num_classes)]
276
+
277
+ # Binarize y_true for one-vs-rest curves
278
+ y_true_bin = label_binarize(y_true, classes=np.arange(num_classes))
279
+
280
+ # Containers
281
+ fpr, tpr, roc_auc = {}, {}, {}
282
+ precision, recall, average_precision = {}, {}, {}
283
+
284
+ # Per-class ROC & PR
285
+ for i in range(num_classes):
286
+ fpr[i], tpr[i], _ = roc_curve(y_true_bin[:, i], y_pred_proba[:, i])
287
+ roc_auc[i] = auc(fpr[i], tpr[i])
288
+ precision[i], recall[i], _ = precision_recall_curve(
289
+ y_true_bin[:, i], y_pred_proba[:, i]
290
+ )
291
+ average_precision[i] = average_precision_score(
292
+ y_true_bin[:, i], y_pred_proba[:, i]
293
+ )
382
294
 
383
- ax.set_title("Per-Locus")
384
- ax.barh(
385
- range(genotype_data.num_snps), loc, color=bar_color, height=1.0
295
+ # Micro-averages
296
+ fpr["micro"], tpr["micro"], _ = roc_curve(
297
+ y_true_bin.ravel(), y_pred_proba.ravel()
386
298
  )
387
- if not zoom:
388
- ax.set_xlim([0, 1])
389
- ax.set_ylabel("Locus")
390
- ax.set_xlabel("Missing Prop.")
391
- ax.tick_params(
392
- axis="y",
393
- which="both",
394
- left=False,
395
- right=False,
396
- labelleft=False,
299
+ roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])
300
+ precision["micro"], recall["micro"], _ = precision_recall_curve(
301
+ y_true_bin.ravel(), y_pred_proba.ravel()
302
+ )
303
+ average_precision["micro"] = average_precision_score(
304
+ y_true_bin, y_pred_proba, average="micro"
397
305
  )
398
306
 
399
- id_vars = ["SampleID"]
400
- if poptotal is not None:
401
- ax = axes[0, 2]
402
-
403
- ax.set_title("Per-Population Total")
404
- ax.barh(poptotal.index, poptotal, color=bar_color, height=1.0)
405
- if not zoom:
406
- ax.set_xlim([0, 1])
407
- ax.set_xlabel("Missing Prop.")
408
- ax.set_ylabel("Population")
409
-
410
- ax = axes[1, 0]
411
-
412
- ax.set_title("Per-Population + Per-Locus")
413
- npops = len(poploc.columns)
414
-
415
- vmax = None if zoom else 1.0
416
-
417
- sns.heatmap(
418
- poploc,
419
- vmin=0.0,
420
- vmax=vmax,
421
- cmap=sns.color_palette(heatmap_palette, as_cmap=True),
422
- yticklabels=False,
423
- cbar_kws={"label": "Missing Prop."},
424
- ax=ax,
425
- )
426
- ax.set_xlabel("Population")
427
- ax.set_ylabel("Locus")
428
-
429
- id_vars.append("Population")
430
-
431
- melt_df = indpop.isna()
432
- melt_df["SampleID"] = genotype_data.samples
433
- indpop["SampleID"] = genotype_data.samples
434
-
435
- if poptotal is not None:
436
- melt_df["Population"] = genotype_data.pops
437
- indpop["Population"] = genotype_data.pops
438
-
439
- melt_df = melt_df.melt(value_name="Missing", id_vars=id_vars)
440
- melt_df.sort_values(by=id_vars[::-1], inplace=True)
441
- melt_df["Missing"].replace(False, "Present", inplace=True)
442
- melt_df["Missing"].replace(True, "Missing", inplace=True)
307
+ # Macro-average ROC
308
+ all_fpr = np.unique(np.concatenate([fpr[i] for i in range(num_classes)]))
309
+ mean_tpr = np.zeros_like(all_fpr)
310
+ for i in range(num_classes):
311
+ mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])
312
+ mean_tpr /= num_classes
313
+ fpr["macro"], tpr["macro"] = all_fpr, mean_tpr
314
+ roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])
315
+
316
+ # Macro-average PR
317
+ all_recall = np.unique(np.concatenate([recall[i] for i in range(num_classes)]))
318
+ mean_precision = np.zeros_like(all_recall)
319
+ for i in range(num_classes):
320
+ # recall[i] increases, but precision[i] is given over decreasing thresholds
321
+ mean_precision += np.interp(all_recall, recall[i][::-1], precision[i][::-1])
322
+ mean_precision /= num_classes
323
+ average_precision["macro"] = average_precision_score(
324
+ y_true_bin, y_pred_proba, average="macro"
325
+ )
443
326
 
444
- ax = axes[0, 2] if poptotal is None else axes[1, 1]
327
+ # Plot
328
+ fig, axes = plt.subplots(1, 2, figsize=(14, 6))
445
329
 
446
- ax.set_title("Per-Individual")
447
- g = sns.histplot(
448
- data=melt_df,
449
- y="variable",
450
- hue="Missing",
451
- multiple="fill",
452
- ax=ax,
330
+ # ROC
331
+ axes[0].plot(
332
+ fpr["micro"],
333
+ tpr["micro"],
334
+ label=f"Micro-average ROC (AUC = {roc_auc['micro']:.2f})",
335
+ linestyle=":",
336
+ linewidth=4,
453
337
  )
454
- ax.tick_params(
455
- axis="y",
456
- which="both",
457
- left=False,
458
- right=False,
459
- labelleft=False,
338
+ axes[0].plot(
339
+ fpr["macro"],
340
+ tpr["macro"],
341
+ label=f"Macro-average ROC (AUC = {roc_auc['macro']:.2f})",
342
+ linestyle="--",
343
+ linewidth=4,
460
344
  )
461
- g.get_legend().set_title(None)
462
-
463
- if poptotal is not None:
464
- ax = axes[1, 2]
465
-
466
- ax.set_title("Per-Population")
467
- g = sns.histplot(
468
- data=melt_df,
469
- y="Population",
470
- hue="Missing",
471
- multiple="fill",
472
- ax=ax,
345
+ for i in range(num_classes):
346
+ axes[0].plot(
347
+ fpr[i], tpr[i], label=f"{label_names[i]} ROC (AUC = {roc_auc[i]:.2f})"
473
348
  )
474
- g.get_legend().set_title(None)
475
-
476
- fig.savefig(
477
- os.path.join(
478
- f"{prefix}_output", "plots", f"missingness.{plot_format}"
479
- ),
480
- bbox_inches="tight",
481
- facecolor="white",
349
+ axes[0].plot([0, 1], [0, 1], linestyle="--", color="black", label="Random")
350
+ axes[0].set_xlabel("False Positive Rate")
351
+ axes[0].set_ylabel("True Positive Rate")
352
+ axes[0].set_title("Multi-class ROC-AUC Curve")
353
+ axes[0].legend(
354
+ loc="upper center",
355
+ bbox_to_anchor=(0.5, -0.15),
356
+ fancybox=True,
357
+ shadow=True,
358
+ ncol=2,
482
359
  )
483
- plt.cla()
484
- plt.clf()
485
- plt.close()
486
-
487
- return loc, ind, poploc, poptotal, indpop
488
-
489
- @staticmethod
490
- def run_and_plot_pca(
491
- original_genotype_data,
492
- imputer_object,
493
- prefix="imputer",
494
- n_components=3,
495
- center=True,
496
- scale=False,
497
- n_axes=2,
498
- point_size=15,
499
- font_size=15,
500
- plot_format="pdf",
501
- bottom_margin=0,
502
- top_margin=0,
503
- left_margin=0,
504
- right_margin=0,
505
- width=1088,
506
- height=700,
507
- ):
508
- """Runs PCA and makes scatterplot with colors showing missingness.
509
-
510
- Genotypes are plotted as separate shapes per population and colored according to missingness per individual.
511
-
512
- This function is run at the end of each imputation method, but can be run independently to change plot and PCA parameters such as ``n_axes=3`` or ``scale=True``.
513
-
514
- The imputed and original GenotypeData objects need to be passed to the function as positional arguments.
515
-
516
- PCA (principal component analysis) scatterplot can have either two or three axes, set with the n_axes parameter.
517
-
518
- The plot is saved as both an interactive HTML file and as a static image. Each population is represented by point shapes. The interactive plot has associated metadata when hovering over the points.
519
-
520
- Files are saved to a reports directory as <prefix>_output/imputed_pca.<plot_format|html>. Supported image formats include: "pdf", "svg", "png", and "jpeg" (or "jpg").
521
-
522
- Args:
523
- original_genotype_data (GenotypeData): Original GenotypeData object that was input into the imputer.
524
-
525
- imputer_object (Any imputer instance): Imputer object created when imputing. Can be any of the imputers, such as: ``ImputePhylo()``, ``ImputeUBP()``, and ``ImputeRandomForest()``.
526
-
527
- original_012 (pandas.DataFrame, numpy.ndarray, or List[List[int]], optional): Original 012-encoded genotypes (before imputing). Missing values are encoded as -9. This object can be obtained as ``df = GenotypeData.genotypes012_df``.
528
-
529
- prefix (str, optional): Prefix for report directory. Plots will be save to a directory called <prefix>_output/imputed_pca<html|plot_format>. Report directory will be created if it does not already exist. Defaults to "imputer".
530
-
531
- n_components (int, optional): Number of principal components to include in the PCA. Defaults to 3.
532
-
533
- center (bool, optional): If True, centers the genotypes to the mean before doing the PCA. If False, no centering is done. Defaults to True.
534
-
535
- scale (bool, optional): If True, scales the genotypes to unit variance before doing the PCA. If False, no scaling is done. Defaults to False.
536
-
537
- n_axes (int, optional): Number of principal component axes to plot. Must be set to either 2 or 3. If set to 3, a 3-dimensional plot will be made. Defaults to 2.
538
360
 
539
- point_size (int, optional): Point size for scatterplot points. Defaults to 15.
540
-
541
- plot_format (str, optional): Plot file format to use. Supported formats include: "pdf", "svg", "png", and "jpeg" (or "jpg"). An interactive HTML file is also created regardless of this setting. Defaults to "pdf".
542
-
543
- bottom_margin (int, optional): Adjust bottom margin. If whitespace cuts off some of your plot, lower the corresponding margins. The default corresponds to that of plotly update_layout(). Defaults to 0.
544
-
545
- top (int, optional): Adjust top margin. If whitespace cuts off some of your plot, lower the corresponding margins. The default corresponds to that of plotly update_layout(). Defaults to 0.
546
-
547
- left_margin (int, optional): Adjust left margin. If whitespace cuts off some of your plot, lower the corresponding margins. The default corresponds to that of plotly update_layout(). Defaults to 0.
548
-
549
- right_margin (int, optional): Adjust right margin. If whitespace cuts off some of your plot, lower the corresponding margins. The default corresponds to that of plotly update_layout(). Defaults to 0.
550
-
551
- width (int, optional): Width of plot space. If your plot is cut off at the edges, even after adjusting the margins, increase the width and height. Try to keep the aspect ratio similar. Defaults to 1088.
552
-
553
- height (int, optional): Height of plot space. If your plot is cut off at the edges, even after adjusting the margins, increase the width and height. Try to keep the aspect ratio similar. Defaults to 700.
554
-
555
- Returns:
556
- numpy.ndarray: PCA data as a numpy array with shape (n_samples, n_components).
557
-
558
- sklearn.decomposision.PCA: Scikit-learn PCA object from sklearn.decomposision.PCA. Any of the sklearn.decomposition.PCA attributes can be accessed from this object. See sklearn documentation.
559
-
560
- Examples:
561
- >>>data = GenotypeData(
562
- >>> filename="snps.str",
563
- >>> filetype="structure2row",
564
- >>> popmapfile="popmap.txt",
565
- >>>)
566
- >>>
567
- >>>ubp = ImputeUBP(genotype_data=data)
568
- >>>
569
- >>>components, pca = run_and_plot_pca(
570
- >>> data,
571
- >>> ubp,
572
- >>> scale=True,
573
- >>> center=True,
574
- >>> plot_format="png"
575
- >>>)
576
- >>>
577
- >>>explvar = pca.explained_variance_ratio\_
578
- """
579
- report_path = os.path.join(f"{prefix}_output", "plots")
580
- Path(report_path).mkdir(parents=True, exist_ok=True)
581
-
582
- if n_axes > 3:
583
- raise ValueError(
584
- ">3 axes is not supported; n_axes must be either 2 or 3."
585
- )
586
- if n_axes < 2:
587
- raise ValueError(
588
- "<2 axes is not supported; n_axes must be either 2 or 3."
589
- )
590
-
591
- imputer = imputer_object.imputed
592
-
593
- df = misc.validate_input_type(
594
- imputer.genotypes012_df, return_type="df"
361
+ # PR
362
+ axes[1].plot(
363
+ recall["micro"],
364
+ precision["micro"],
365
+ label=f"Micro-average PR (AP = {average_precision['micro']:.2f})",
366
+ linestyle=":",
367
+ linewidth=4,
595
368
  )
596
-
597
- original_df = misc.validate_input_type(
598
- original_genotype_data.genotypes_012(fmt="pandas"),
599
- return_type="df",
369
+ axes[1].plot(
370
+ all_recall,
371
+ mean_precision,
372
+ label=f"Macro-average PR (AP = {average_precision['macro']:.2f})",
373
+ linestyle="--",
374
+ linewidth=4,
600
375
  )
601
-
602
- original_df.replace(-9, np.nan, inplace=True)
603
-
604
- if center or scale:
605
- # Center data to mean. Scaling to unit variance is off.
606
- scaler = StandardScaler(with_mean=center, with_std=scale)
607
- pca_df = scaler.fit_transform(df)
608
- else:
609
- pca_df = df.copy()
610
-
611
- # Run PCA.
612
- model = PCA(n_components=n_components)
613
- components = model.fit_transform(pca_df)
614
-
615
- df_pca = pd.DataFrame(
616
- components[:, [0, 1, 2]], columns=["Axis1", "Axis2", "Axis3"]
376
+ for i in range(num_classes):
377
+ axes[1].plot(
378
+ recall[i],
379
+ precision[i],
380
+ label=f"{label_names[i]} PR (AP = {average_precision[i]:.2f})",
381
+ )
382
+ axes[1].plot([0, 1], [1, 0], linestyle="--", color="black", label="Random")
383
+ axes[1].set_xlabel("Recall")
384
+ axes[1].set_ylabel("Precision")
385
+ axes[1].set_title("Multi-class Precision-Recall Curve")
386
+ axes[1].legend(
387
+ loc="upper center",
388
+ bbox_to_anchor=(0.5, -0.15),
389
+ fancybox=True,
390
+ shadow=True,
391
+ ncol=2,
617
392
  )
618
393
 
619
- df_pca["SampleID"] = original_genotype_data.samples
620
- df_pca["Population"] = original_genotype_data.pops
621
- df_pca["Size"] = point_size
622
-
623
- _, ind, _, _, _ = imputer.calc_missing(original_df, use_pops=False)
624
- df_pca["missPerc"] = ind
625
-
626
- my_scale = [("rgb(19, 43, 67)"), ("rgb(86,177,247)")] # ggplot default
627
-
628
- z = "Axis3" if n_axes == 3 else None
629
- labs = {
630
- "Axis1": f"PC1 ({round(model.explained_variance_ratio_[0] * 100, 2)}%)",
631
- "Axis2": f"PC2 ({round(model.explained_variance_ratio_[1] * 100, 2)}%)",
632
- "missPerc": "Missing Prop.",
633
- "Population": "Population",
634
- }
394
+ # Title & save
395
+ fig.suptitle("\n".join([f"{k}: {v:.2f}" for k, v in metrics.items()]), y=1.35)
635
396
 
636
- if z is not None:
637
- labs[
638
- "Axis3"
639
- ] = f"PC3 ({round(model.explained_variance_ratio_[2] * 100, 2)}%)"
640
- fig = px.scatter_3d(
641
- df_pca,
642
- x="Axis1",
643
- y="Axis2",
644
- z="Axis3",
645
- color="missPerc",
646
- symbol="Population",
647
- color_continuous_scale=my_scale,
648
- custom_data=["Axis3", "SampleID", "Population", "missPerc"],
649
- size="Size",
650
- size_max=point_size,
651
- labels=labs,
652
- )
653
- else:
654
- fig = px.scatter(
655
- df_pca,
656
- x="Axis1",
657
- y="Axis2",
658
- color="missPerc",
659
- symbol="Population",
660
- color_continuous_scale=my_scale,
661
- custom_data=["Axis3", "SampleID", "Population", "missPerc"],
662
- size="Size",
663
- size_max=point_size,
664
- labels=labs,
665
- )
666
- fig.update_traces(
667
- hovertemplate="<br>".join(
668
- [
669
- "Axis 1: %{x}",
670
- "Axis 2: %{y}",
671
- "Axis 3: %{customdata[0]}",
672
- "Sample ID: %{customdata[1]}",
673
- "Population: %{customdata[2]}",
674
- "Missing Prop.: %{customdata[3]}",
675
- ]
676
- ),
677
- )
678
- fig.update_layout(
679
- showlegend=True,
680
- margin=dict(
681
- b=bottom_margin,
682
- t=top_margin,
683
- l=left_margin,
684
- r=right_margin,
685
- ),
686
- width=width,
687
- height=height,
688
- legend_orientation="h",
689
- legend_title="Population",
690
- legend_title_font=dict(size=font_size),
691
- legend_title_side="top",
692
- font=dict(size=font_size),
693
- )
694
- fig.write_html(os.path.join(report_path, "imputed_pca.html"))
695
- fig.write_image(
696
- os.path.join(report_path, f"imputed_pca.{plot_format}"),
697
- )
397
+ if prefix != "":
398
+ prefix = f"{prefix}_"
698
399
 
699
- return components, model
400
+ out_name = f"{self.model_name}_{prefix}roc_pr_curves.{self.plot_format}"
401
+ fig.savefig(self.output_dir / out_name, bbox_inches="tight")
402
+ if self.show_plots:
403
+ plt.show()
404
+ plt.close(fig)
700
405
 
701
- @staticmethod
702
- def plot_history(lod, nn_method, prefix="imputer"):
406
+ def plot_history(self, history: Dict[str, List[float]]) -> None:
703
407
  """Plot model history traces. Will be saved to file.
704
408
 
409
+ This method plots the deep learning model history traces. The plot is saved to disk as a ``<plot_format>`` file.
410
+
705
411
  Args:
706
- lod (List[tf.keras.callbacks.History]): List of history objects.
707
- nn_method (str): Neural network method to plot. Possible options include: 'NLPCA', 'UBP', or 'VAE'. NLPCA and VAE get plotted the same, but UBP does it differently due to its three phases.
708
- prefix (str, optional): Prefix to use for output directory. Defaults to 'imputer'.
412
+ history (Dict[str, List[float]]): Dictionary with lists of history objects. Keys should be "Train" and "Validation".
709
413
 
710
414
  Raises:
711
- ValueError: nn_method must be either 'NLPCA', 'UBP', or 'VAE'.
415
+ ValueError: nn_method must be either 'ImputeNLPCA', 'ImputeUBP', 'ImputeAutoencoder', 'ImputeVAE'.
712
416
  """
713
- if nn_method == "NLPCA" or nn_method == "VAE" or nn_method == "SAE":
714
- title = nn_method
715
- fn = os.path.join(
716
- f"{prefix}_output",
717
- "plots",
718
- "Unsupervised",
719
- nn_method,
720
- "histplot.pdf",
721
- )
722
-
723
- # if nn_method == "VAE":
724
- fig, axes = plt.subplots(1, 2)
725
- ax1 = axes[0]
726
- ax2 = axes[1]
727
-
728
- fig.suptitle(title)
729
- fig.tight_layout(h_pad=3.0, w_pad=3.0)
730
- history = lod[0]
731
-
732
- acctrain = "binary_accuracy"
733
- accval = "val_binary_accuracy"
734
- lossval = "val_loss"
417
+ if self.model_name not in {
418
+ "ImputeNLPCA",
419
+ "ImputeVAE",
420
+ "ImputeAutoencoder",
421
+ "ImputeUBP",
422
+ }:
423
+ msg = "nn_method must be either 'ImputeNLPCA', 'ImputeVAE', 'ImputeAutoencoder', 'ImputeUBP'."
424
+ self.logger.error(msg)
425
+ raise ValueError(msg)
426
+
427
+ if self.model_name != "ImputeUBP":
428
+ fig, ax = plt.subplots(1, 1, figsize=(12, 8))
429
+ df = pd.DataFrame(history)
430
+ df = df.iloc[1:]
735
431
 
736
432
  # Plot train accuracy
737
- ax1.plot(history[acctrain])
738
- ax1.set_title("Model Accuracy")
739
- ax1.set_ylabel("Accuracy")
740
- ax1.set_xlabel("Epoch")
741
- ax1.set_ylim(bottom=0.0, top=1.0)
742
- ax1.set_yticks([0.0, 0.25, 0.5, 0.75, 1.0])
743
-
744
- labels = ["Train"]
745
-
746
- # Plot validation accuracy
747
- ax1.plot(history[accval])
748
- labels.append("Validation")
749
-
750
- ax1.legend(labels, loc="best")
433
+ ax.plot(df["Train"], c="blue", lw=3)
751
434
 
752
- # else:
753
- ax2.plot(history["loss"])
754
- ax2.plot(history[lossval])
435
+ ax.set_title(f"{self.model_name} Loss per Epoch")
436
+ ax.set_ylabel("Loss")
437
+ ax.set_xlabel("Epoch")
438
+ ax.legend(["Train"], loc="best", shadow=True, fancybox=True)
755
439
 
756
- ax2.set_title("Total Loss")
757
- ax2.set_ylabel("Loss")
758
- ax2.set_xlabel("Epoch")
759
- ax2.legend(labels, loc="best")
440
+ else:
441
+ fig, ax = plt.subplots(3, 1, figsize=(12, 8))
760
442
 
761
- fig.savefig(fn, bbox_inches="tight", facecolor="white")
443
+ for i, phase in enumerate(range(1, 4)):
444
+ train = pd.Series(history["Train"][f"Phase {phase}"])
445
+ train = train.iloc[1:] # ignore first epoch
762
446
 
763
- plt.close()
764
- plt.clf()
765
-
766
- elif nn_method == "UBP":
767
- fig = plt.figure(figsize=(12, 16))
768
- fig.suptitle(nn_method)
769
- fig.tight_layout(h_pad=2.0, w_pad=2.0)
770
- fn = os.path.join(
771
- f"{prefix}_output",
772
- "plots",
773
- "Unsupervised",
774
- nn_method,
775
- "histplot.pdf",
776
- )
447
+ # Plot train accuracy
448
+ ax[i].plot(train, c="blue", lw=3)
449
+ ax[i].set_title(f"{self.model_name}: Phase {phase} Loss per Epoch")
450
+ ax[i].set_ylabel("Loss")
451
+ ax[i].set_xlabel("Epoch")
452
+ ax[i].legend([f"Phase {phase}"], loc="best", shadow=True, fancybox=True)
777
453
 
778
- idx = 1
779
- for i, history in enumerate(lod, start=1):
780
- plt.subplot(3, 2, idx)
781
- title = f"Phase {i}"
454
+ fn = f"{self.model_name.lower()}_history_plot.{self.plot_format}"
455
+ fn = self.output_dir / fn
456
+ fig.savefig(fn)
782
457
 
783
- # Plot model accuracy
784
- ax = plt.gca()
785
- ax.plot(history["binary_accuracy"])
786
- ax.set_title(f"{title} Accuracy")
787
- ax.set_ylabel("Accuracy")
788
- ax.set_xlabel("Epoch")
789
- ax.set_yticks([0.0, 0.25, 0.5, 0.75, 1.0])
458
+ if self.show_plots:
459
+ plt.show()
460
+ plt.close(fig)
790
461
 
791
- # Plot validation accuracy
792
- ax.plot(history["val_binary_accuracy"])
793
- ax.legend(["Train", "Validation"], loc="best")
462
+ def plot_confusion_matrix(
463
+ self,
464
+ y_true_1d: np.ndarray,
465
+ y_pred_1d: np.ndarray,
466
+ label_names: Sequence[str] | None = None,
467
+ prefix: str = "",
468
+ ) -> None:
469
+ """Plot a confusion matrix with optional class labels.
794
470
 
795
- # Plot model loss
796
- plt.subplot(3, 2, idx + 1)
797
- ax = plt.gca()
798
- ax.plot(history["loss"])
799
- ax.set_title(f"{title} Loss")
800
- ax.set_ylabel("Loss (MSE)")
801
- ax.set_xlabel("Epoch")
471
+ This method plots a confusion matrix using true and predicted labels. The plot is saved to disk as a ``<plot_format>`` file.
802
472
 
803
- # Plot validation loss
804
- ax.plot(history["val_loss"])
805
- ax.legend(["Train", "Validation"], loc="best")
473
+ Args:
474
+ y_true_1d (np.ndarray): 1D array of true integer labels in [0, n_classes-1].
475
+ y_pred_1d (np.ndarray): 1D array of predicted integer labels in [0, n_classes-1].
476
+ label_names (Sequence[str] | None): Optional sequence of class names (length must equal n_classes). If provided, both the internal label order and displayed tick labels will respect this order (assumed to be 0..n-1).
477
+ prefix (str): Optional prefix for the output filename.
806
478
 
807
- idx += 2
479
+ Notes:
480
+ - If `label_names` is None, the display labels default to the numeric class indices inferred from `y_true_1d ∪ y_pred_1d`.
481
+ """
482
+ y_true_1d = misc.validate_input_type(y_true_1d, return_type="array")
483
+ y_pred_1d = misc.validate_input_type(y_pred_1d, return_type="array")
808
484
 
809
- plt.savefig(fn, bbox_inches="tight", facecolor="white")
485
+ if y_true_1d.ndim > 1:
486
+ y_true_1d = y_true_1d.flatten()
810
487
 
811
- plt.close()
812
- plt.clf()
488
+ if y_pred_1d.ndim > 1:
489
+ y_pred_1d = y_pred_1d.flatten()
813
490
 
491
+ # Determine class count/order
492
+ if label_names is not None:
493
+ n_classes = len(label_names)
494
+ labels = np.arange(n_classes) # our y_* are ints 0..n-1
495
+ display_labels = list(map(str, label_names))
814
496
  else:
815
- raise ValueError(
816
- f"nn_method must be either 'NLPCA', 'UBP', or 'VAE', but got {nn_method}"
817
- )
818
-
819
- @staticmethod
820
- def plot_certainty_heatmap(
821
- y_certainty, sample_ids=None, nn_method="VAE", prefix="imputer"
822
- ):
823
- fig = plt.figure()
824
- hm = sns.heatmap(
825
- data=y_certainty,
826
- cmap="viridis",
827
- vmin=0.0,
828
- vmax=1.0,
829
- cbar_kws={"label": "Prob."},
830
- )
831
- hm.set_xlabel("Site")
832
- hm.set_ylabel("Sample")
833
- hm.set_title("Probabilities of Uncertain Sites")
834
- fig.tight_layout()
835
- fig.savefig(
836
- os.path.join(
837
- f"{prefix}_output",
838
- "plots",
839
- "Unsupervised",
840
- nn_method,
841
- "uncertainty_plot.png",
842
- ),
843
- bbox_inches="tight",
844
- facecolor="white",
845
- )
497
+ # Infer labels from data to keep matrix tight
498
+ labels = np.unique(np.concatenate([y_true_1d, y_pred_1d]))
499
+ display_labels = labels # sklearn will convert to strings
846
500
 
847
- @staticmethod
848
- def plot_confusion_matrix(
849
- y_true_1d, y_pred_1d, nn_method, prefix="imputer"
850
- ):
851
501
  fig, ax = plt.subplots(1, 1, figsize=(15, 15))
852
- ConfusionMatrixDisplay.from_predictions(
853
- y_true=y_true_1d, y_pred=y_pred_1d, ax=ax
854
- )
855
502
 
856
- outfile = os.path.join(
857
- f"{prefix}_output",
858
- "plots",
859
- "Unsupervised",
860
- nn_method,
861
- f"confusion_matrix_{nn_method}.png",
503
+ ConfusionMatrixDisplay.from_predictions(
504
+ y_true=y_true_1d,
505
+ y_pred=y_pred_1d,
506
+ labels=labels,
507
+ display_labels=display_labels,
508
+ ax=ax,
509
+ cmap="viridis",
510
+ colorbar=True,
862
511
  )
863
512
 
864
- if os.path.isfile(outfile):
865
- os.remove(outfile)
513
+ if prefix != "":
514
+ prefix = f"{prefix}_"
866
515
 
867
- fig.savefig(outfile, facecolor="white")
516
+ out_name = (
517
+ f"{self.model_name.lower()}_{prefix}confusion_matrix.{self.plot_format}"
518
+ )
519
+ fig.savefig(self.output_dir / out_name, bbox_inches="tight")
520
+ if self.show_plots:
521
+ plt.show()
522
+ plt.close(fig)
868
523
 
869
- @staticmethod
870
- def plot_gt_distribution(df, plot_path):
871
- df = misc.validate_input_type(df, return_type="df")
872
- df_melt = pd.melt(df, value_name="Count")
873
- cnts = df_melt["Count"].value_counts()
874
- cnts.index.names = ["Genotype"]
875
- cnts = pd.DataFrame(cnts).reset_index()
876
- cnts.sort_values(by="Genotype", inplace=True)
877
- cnts["Genotype"] = cnts["Genotype"].astype(str)
524
+ def plot_gt_distribution(
525
+ self,
526
+ X: np.ndarray | pd.DataFrame | list | torch.Tensor,
527
+ is_imputed: bool = False,
528
+ ) -> None:
529
+ """Plot genotype distribution (IUPAC or integer-encoded).
878
530
 
879
- fig, ax = plt.subplots(1, 1, figsize=(15, 15))
880
- g = sns.barplot(x="Genotype", y="Count", data=cnts, ax=ax)
881
- g.set_xlabel("Integer-encoded Genotype")
882
- g.set_ylabel("Count")
883
- g.set_title("Genotype Counts")
884
- for p in g.patches:
885
- g.annotate(
886
- f"{p.get_height():.1f}",
887
- (p.get_x() + 0.25, p.get_height() + 0.01),
888
- xytext=(0, 1),
889
- textcoords="offset points",
890
- va="bottom",
891
- )
531
+ This plots counts for all genotypes present in X. It supports IUPAC single-letter genotypes and integer encodings. Missing markers '-', '.', '?' are normalized to 'N'. Bars are annotated with counts and percentages.
892
532
 
893
- fig.savefig(
894
- os.path.join(plot_path, "genotype_distributions.png"),
895
- bbox_inches="tight",
896
- facecolor="white",
533
+ Args:
534
+ X (np.ndarray | pd.DataFrame | list | torch.Tensor): Array-like genotype matrix. Rows=loci, cols=samples (any orientation is OK). Elements are IUPAC one-letter genotypes (e.g., 'A','C','G','T','N','R',...) or integers (e.g., 0/1/2[/3]).
535
+ is_imputed (bool): Whether these genotypes are imputed. Affects the title only. Defaults to False.
536
+ """
537
+ # Flatten X to a 1D Series
538
+ if isinstance(X, pd.DataFrame):
539
+ arr = X.values
540
+ elif torch.is_tensor(X):
541
+ arr = X.detach().cpu().numpy()
542
+ else:
543
+ arr = np.asarray(X)
544
+
545
+ s = pd.Series(arr.ravel())
546
+
547
+ # Detect string vs numeric encodings and normalize
548
+ if s.dtype.kind in ("O", "U", "S"): # string-like → IUPAC path
549
+ s = s.astype(str).str.upper().replace({"-": "N", ".": "N", "?": "N"})
550
+ x_label = "Genotype (IUPAC)"
551
+
552
+ # Define canonical order: N, A/C/T/G, then IUPAC ambiguity codes.
553
+ canonical = ["A", "C", "T", "G"]
554
+ iupac_ambiguity = sorted(["M", "R", "W", "S", "Y", "K", "V", "H", "D", "B"])
555
+ base_order = ["N"] + canonical + iupac_ambiguity
556
+ else: # numeric path (e.g., 0/1/2/[3], -1 for missing)
557
+ # Map common missing sentinels to 'N', keep others as strings for
558
+ # labeling
559
+ s = s.astype(float) # allow NaN comparisons
560
+ s = s.where(~np.isin(s, [-1, np.nan]), other=np.nan)
561
+ s = s.fillna("N").astype(int, errors="ignore").astype(str)
562
+
563
+ x_label = "Genotype (Integer-encoded)"
564
+
565
+ # Support both ternary and quaternary encodings; keep a stable order
566
+ base_order = ["N", "0", "1", "2", "3"]
567
+
568
+ # Include any unexpected symbols at the end (sorted) so nothing is
569
+ # dropped
570
+ extras = sorted(set(s.unique()) - set(base_order))
571
+ full_order = base_order + [e for e in extras if e not in base_order]
572
+
573
+ # Count and reindex to show zero-count categories
574
+ counts = s.value_counts().reindex(full_order, fill_value=0)
575
+ df = counts.rename_axis("Genotype").reset_index(name="Count")
576
+ df["Percent"] = df["Count"] / df["Count"].sum() * 100
577
+
578
+ title = "Imputed Genotype Counts" if is_imputed else "Genotype Counts"
579
+
580
+ # --- Plot ---
581
+ fig, ax = plt.subplots(figsize=(8, 5))
582
+ sns.despine(fig=fig)
583
+
584
+ ax = sns.barplot(
585
+ data=df,
586
+ x="Genotype",
587
+ y="Percent",
588
+ hue="Genotype",
589
+ order=full_order,
590
+ errorbar=None,
591
+ ax=ax,
592
+ palette="Set1",
593
+ legend=False,
594
+ fill=True,
897
595
  )
898
- plt.close()
899
596
 
900
- @staticmethod
901
- def plot_label_clusters(z_mean, labels, prefix="imputer"):
902
- """Display a 2D plot of the classes in the latent space."""
903
- fig, ax = plt.subplots(1, 1, figsize=(15, 15))
904
-
905
- sns.scatterplot(x=z_mean[:, 0], y=z_mean[:, 1], ax=ax)
906
- ax.set_xlabel("Latent Dimension 1")
907
- ax.set_ylabel("Latent Dimension 2")
597
+ ax.set_xlabel(x_label)
598
+ ax.set_ylabel("Percent")
599
+ ax.set_title(title)
600
+ ax.set_ylim([0, 50])
908
601
 
909
- outfile = os.path.join(
910
- f"{prefix}_output",
911
- "plots",
912
- "Unsupervised",
913
- "VAE",
914
- "label_clusters.png",
915
- )
602
+ fig.tight_layout()
916
603
 
917
- if os.path.isfile(outfile):
918
- os.remove(outfile)
604
+ suffix = "imputed" if is_imputed else "original"
605
+ fn = self.output_dir / f"gt_distributions_{suffix}.{self.plot_format}"
606
+ fig.savefig(fn, dpi=300)
919
607
 
920
- fig.savefig(outfile, facecolor="white", bbox_inches="tight")
608
+ if self.show_plots:
609
+ plt.show()
610
+ plt.close(fig)