pg-sui 0.2.0-py3-none-any.whl → 1.6.14.dev9-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (127)
  1. {pg_sui-0.2.0.dist-info → pg_sui-1.6.14.dev9.dist-info}/METADATA +101 -79
  2. pg_sui-1.6.14.dev9.dist-info/RECORD +81 -0
  3. {pg_sui-0.2.0.dist-info → pg_sui-1.6.14.dev9.dist-info}/WHEEL +1 -1
  4. pg_sui-1.6.14.dev9.dist-info/entry_points.txt +4 -0
  5. {pg_sui-0.2.0.dist-info → pg_sui-1.6.14.dev9.dist-info/licenses}/LICENSE +0 -0
  6. pg_sui-1.6.14.dev9.dist-info/top_level.txt +1 -0
  7. pgsui/__init__.py +35 -54
  8. pgsui/_version.py +34 -0
  9. pgsui/cli.py +909 -0
  10. pgsui/data_processing/__init__.py +0 -0
  11. pgsui/data_processing/config.py +565 -0
  12. pgsui/data_processing/containers.py +1424 -0
  13. pgsui/data_processing/transformers.py +557 -907
  14. pgsui/{example_data/trees → electron/app}/__init__.py +0 -0
  15. pgsui/electron/app/__main__.py +5 -0
  16. pgsui/electron/app/extra-resources/.gitkeep +1 -0
  17. pgsui/electron/app/icons/icons/1024x1024.png +0 -0
  18. pgsui/electron/app/icons/icons/128x128.png +0 -0
  19. pgsui/electron/app/icons/icons/16x16.png +0 -0
  20. pgsui/electron/app/icons/icons/24x24.png +0 -0
  21. pgsui/electron/app/icons/icons/256x256.png +0 -0
  22. pgsui/electron/app/icons/icons/32x32.png +0 -0
  23. pgsui/electron/app/icons/icons/48x48.png +0 -0
  24. pgsui/electron/app/icons/icons/512x512.png +0 -0
  25. pgsui/electron/app/icons/icons/64x64.png +0 -0
  26. pgsui/electron/app/icons/icons/icon.icns +0 -0
  27. pgsui/electron/app/icons/icons/icon.ico +0 -0
  28. pgsui/electron/app/main.js +227 -0
  29. pgsui/electron/app/package-lock.json +6894 -0
  30. pgsui/electron/app/package.json +51 -0
  31. pgsui/electron/app/preload.js +15 -0
  32. pgsui/electron/app/server.py +157 -0
  33. pgsui/electron/app/ui/logo.png +0 -0
  34. pgsui/electron/app/ui/renderer.js +131 -0
  35. pgsui/electron/app/ui/styles.css +59 -0
  36. pgsui/electron/app/ui/ui_shim.js +72 -0
  37. pgsui/electron/bootstrap.py +43 -0
  38. pgsui/electron/launch.py +57 -0
  39. pgsui/electron/package.json +14 -0
  40. pgsui/example_data/__init__.py +0 -0
  41. pgsui/example_data/phylip_files/__init__.py +0 -0
  42. pgsui/example_data/phylip_files/test.phy +0 -0
  43. pgsui/example_data/popmaps/__init__.py +0 -0
  44. pgsui/example_data/popmaps/{test.popmap → phylogen_nomx.popmap} +185 -99
  45. pgsui/example_data/structure_files/__init__.py +0 -0
  46. pgsui/example_data/structure_files/test.pops.2row.allsites.str +0 -0
  47. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz +0 -0
  48. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz.tbi +0 -0
  49. pgsui/impute/__init__.py +0 -0
  50. pgsui/impute/deterministic/imputers/allele_freq.py +725 -0
  51. pgsui/impute/deterministic/imputers/mode.py +844 -0
  52. pgsui/impute/deterministic/imputers/nmf.py +221 -0
  53. pgsui/impute/deterministic/imputers/phylo.py +973 -0
  54. pgsui/impute/deterministic/imputers/ref_allele.py +669 -0
  55. pgsui/impute/supervised/__init__.py +0 -0
  56. pgsui/impute/supervised/base.py +343 -0
  57. pgsui/impute/{unsupervised/models/in_development → supervised/imputers}/__init__.py +0 -0
  58. pgsui/impute/supervised/imputers/hist_gradient_boosting.py +317 -0
  59. pgsui/impute/supervised/imputers/random_forest.py +291 -0
  60. pgsui/impute/unsupervised/__init__.py +0 -0
  61. pgsui/impute/unsupervised/base.py +1118 -0
  62. pgsui/impute/unsupervised/callbacks.py +92 -262
  63. {simulation → pgsui/impute/unsupervised/imputers}/__init__.py +0 -0
  64. pgsui/impute/unsupervised/imputers/autoencoder.py +1285 -0
  65. pgsui/impute/unsupervised/imputers/nlpca.py +1554 -0
  66. pgsui/impute/unsupervised/imputers/ubp.py +1575 -0
  67. pgsui/impute/unsupervised/imputers/vae.py +1228 -0
  68. pgsui/impute/unsupervised/loss_functions.py +261 -0
  69. pgsui/impute/unsupervised/models/__init__.py +0 -0
  70. pgsui/impute/unsupervised/models/autoencoder_model.py +215 -567
  71. pgsui/impute/unsupervised/models/nlpca_model.py +155 -394
  72. pgsui/impute/unsupervised/models/ubp_model.py +180 -1106
  73. pgsui/impute/unsupervised/models/vae_model.py +269 -630
  74. pgsui/impute/unsupervised/nn_scorers.py +255 -0
  75. pgsui/utils/__init__.py +0 -0
  76. pgsui/utils/classification_viz.py +608 -0
  77. pgsui/utils/logging_utils.py +22 -0
  78. pgsui/utils/misc.py +35 -480
  79. pgsui/utils/plotting.py +996 -829
  80. pgsui/utils/pretty_metrics.py +290 -0
  81. pgsui/utils/scorers.py +213 -666
  82. pg_sui-0.2.0.dist-info/RECORD +0 -75
  83. pg_sui-0.2.0.dist-info/top_level.txt +0 -3
  84. pgsui/example_data/phylip_files/test_n10.phy +0 -118
  85. pgsui/example_data/phylip_files/test_n100.phy +0 -118
  86. pgsui/example_data/phylip_files/test_n2.phy +0 -118
  87. pgsui/example_data/phylip_files/test_n500.phy +0 -118
  88. pgsui/example_data/structure_files/test.nopops.1row.10sites.str +0 -117
  89. pgsui/example_data/structure_files/test.nopops.2row.100sites.str +0 -234
  90. pgsui/example_data/structure_files/test.nopops.2row.10sites.str +0 -234
  91. pgsui/example_data/structure_files/test.nopops.2row.30sites.str +0 -234
  92. pgsui/example_data/structure_files/test.nopops.2row.allsites.str +0 -234
  93. pgsui/example_data/structure_files/test.pops.1row.10sites.str +0 -117
  94. pgsui/example_data/structure_files/test.pops.2row.10sites.str +0 -234
  95. pgsui/example_data/trees/test.iqtree +0 -376
  96. pgsui/example_data/trees/test.qmat +0 -5
  97. pgsui/example_data/trees/test.rate +0 -2033
  98. pgsui/example_data/trees/test.tre +0 -1
  99. pgsui/example_data/trees/test_n10.rate +0 -19
  100. pgsui/example_data/trees/test_n100.rate +0 -109
  101. pgsui/example_data/trees/test_n500.rate +0 -509
  102. pgsui/example_data/trees/test_siterates.txt +0 -2024
  103. pgsui/example_data/trees/test_siterates_n10.txt +0 -10
  104. pgsui/example_data/trees/test_siterates_n100.txt +0 -100
  105. pgsui/example_data/trees/test_siterates_n500.txt +0 -500
  106. pgsui/example_data/vcf_files/test.vcf +0 -244
  107. pgsui/example_data/vcf_files/test.vcf.gz +0 -0
  108. pgsui/example_data/vcf_files/test.vcf.gz.tbi +0 -0
  109. pgsui/impute/estimators.py +0 -1268
  110. pgsui/impute/impute.py +0 -1463
  111. pgsui/impute/simple_imputers.py +0 -1431
  112. pgsui/impute/supervised/iterative_imputer_fixedparams.py +0 -782
  113. pgsui/impute/supervised/iterative_imputer_gridsearch.py +0 -1024
  114. pgsui/impute/unsupervised/keras_classifiers.py +0 -697
  115. pgsui/impute/unsupervised/models/in_development/cnn_model.py +0 -486
  116. pgsui/impute/unsupervised/neural_network_imputers.py +0 -1440
  117. pgsui/impute/unsupervised/neural_network_methods.py +0 -1395
  118. pgsui/pg_sui.py +0 -261
  119. pgsui/utils/sequence_tools.py +0 -407
  120. simulation/sim_benchmarks.py +0 -333
  121. simulation/sim_treeparams.py +0 -475
  122. test/__init__.py +0 -0
  123. test/pg_sui_simtest.py +0 -215
  124. test/pg_sui_testing.py +0 -523
  125. test/test.py +0 -151
  126. test/test_pgsui.py +0 -374
  127. test/test_tkc.py +0 -185
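Only the rewritten pgsui/utils/plotting.py diff is reproduced in full below. That module replaces the old collection of static helper functions with an instance-based Plotting class. For orientation, here is a minimal usage sketch assembled from the class docstring and method signatures shown in the diff that follows; the arrays, metric value, and model name are illustrative placeholders, not output from the package:

>>> import numpy as np
>>> from pgsui import Plotting
>>> plotter = Plotting(model_name="ImputeVAE", prefix="pgsui_test", plot_format="png")
>>> y_true = np.array([0, 1, 2, 2])            # true integer class labels
>>> y_proba = np.array([[0.8, 0.1, 0.1],       # predicted class probabilities
...                     [0.2, 0.7, 0.1],
...                     [0.1, 0.2, 0.7],
...                     [0.3, 0.6, 0.1]])
>>> plotter.plot_metrics(y_true, y_proba, metrics={"accuracy": 0.75})
>>> plotter.plot_confusion_matrix(y_true, y_proba.argmax(axis=1))
>>> plotter.plot_gt_distribution(np.array([["A", "N"], ["T", "G"]]))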
pgsui/utils/plotting.py CHANGED
@@ -1,949 +1,1116 @@
- import os
- import sys
- from itertools import cycle
+ import logging
+ import warnings
  from pathlib import Path
+ from typing import Dict, List, Literal, Optional, Sequence, cast

+ import matplotlib as mpl
+
+ # Use Agg backend for headless plotting
+ mpl.use("Agg")
+
+ import matplotlib.pyplot as plt
  import numpy as np
+ import optuna
  import pandas as pd
- import matplotlib.pyplot as plt
  import seaborn as sns
- import plotly.express as px
-
- from sklearn.decomposition import PCA
- from sklearn.preprocessing import StandardScaler
- from sklearn_genetic.utils import logbook_to_pandas
- from sklearn.metrics import ConfusionMatrixDisplay
-
- try:
-     from . import misc
- except (ModuleNotFoundError, ValueError, ImportError):
-     from utils import misc
+ import torch
+ from optuna.exceptions import ExperimentalWarning
+ from sklearn.metrics import (
+     ConfusionMatrixDisplay,
+     auc,
+     average_precision_score,
+     confusion_matrix,
+     precision_recall_curve,
+     roc_curve,
+ )
+ from sklearn.preprocessing import label_binarize
+ from snpio import SNPioMultiQC
+ from snpio.utils.logging import LoggerManager
+
+ from pgsui.utils import misc
+ from pgsui.utils.logging_utils import configure_logger
+
+ # Quiet Matplotlib/fontTools INFO logging when saving PDF/SVG
+ for name in (
+     "fontTools",
+     "fontTools.subset",
+     "fontTools.ttLib",
+     "matplotlib.font_manager",
+ ):
+     lg = logging.getLogger(name)
+     lg.setLevel(logging.WARNING)
+     lg.propagate = False


  class Plotting:
-     """Functions for plotting imputer scoring and results."""
-
-     @staticmethod
-     def plot_grid_search(cv_results, nn_method, prefix):
-         """Plot cv_results\_ from a grid search for each parameter.
-
-         Saves a figure to disk.
+     """Class for plotting imputer scoring and results.
+
+     This class is used to plot the performance metrics of imputation models. It can plot ROC and Precision-Recall curves, model history, and the distribution of genotypes in the dataset.
+
+     Example:
+         >>> from pgsui import Plotting
+         >>> plotter = Plotting(model_name="ImputeVAE", prefix="pgsui_test", plot_format="png")
+         >>> plotter.plot_metrics(metrics, num_classes)
+         >>> plotter.plot_history(history)
+         >>> plotter.plot_confusion_matrix(y_true_1d, y_pred_1d)
+         >>> plotter.plot_tuning(study, model_name, optimize_dir, target_name="Objective Value")
+         >>> plotter.plot_gt_distribution(df)
+
+     Attributes:
+         model_name (str): Name of the model.
+         prefix (str): Prefix for the output directory.
+         plot_format (Literal["pdf", "png", "jpeg", "jpg", "svg"]): Format for the plots ('pdf', 'png', 'jpeg', 'jpg', 'svg').
+         plot_fontsize (int): Font size for the plots.
+         plot_dpi (int): Dots per inch for the plots.
+         title_fontsize (int): Font size for the plot titles.
+         show_plots (bool): Whether to display the plots inline or during execution.
+         output_dir (Path): Directory where plots will be saved.
+         logger (logging.Logger): Logger instance for logging messages.
+     """
+
+     def __init__(
+         self,
+         model_name: str,
+         *,
+         prefix: str = "pgsui",
+         plot_format: Literal["pdf", "png", "jpeg", "jpg", "svg"] = "pdf",
+         plot_fontsize: int = 18,
+         plot_dpi: int = 300,
+         title_fontsize: int = 20,
+         despine: bool = True,
+         show_plots: bool = False,
+         verbose: int = 0,
+         debug: bool = False,
+         multiqc: bool = False,
+         multiqc_section: Optional[str] = None,
+     ) -> None:
+         """Initialize the Plotting object.
+
+         This class is used to plot the performance metrics of imputation models. It can plot ROC and Precision-Recall curves, model history, and the distribution of genotypes in the dataset.

          Args:
-             cv_results (np.ndarray): the cv_results\_ attribute from a trained grid search object.
-
-             nn_method (str): Neural network algorithm name.
-
-             prefix (str): Prefix to use for saving the plot to file.
+             model_name (str): Name of the model.
+             prefix (str, optional): Prefix for the output directory. Defaults to 'pgsui'.
+             plot_format (Literal["pdf", "png", "jpeg", "jpg"]): Format for the plots ('pdf', 'png', 'jpeg', 'jpg'). Defaults to 'pdf'.
+             plot_fontsize (int): Font size for the plots. Defaults to 18.
+             plot_dpi (int): Dots per inch for the plots. Defaults to 300.
+             title_fontsize (int): Font size for the plot titles. Defaults to 20.
+             despine (bool): Whether to remove the top and right spines from the plots. Defaults to True.
+             show_plots (bool): Whether to display the plots. Defaults to False.
+             verbose (int): Verbosity level for logging. Defaults to 0.
+             debug (bool): Whether to enable debug mode. Defaults to False.
+             multiqc (bool): Whether to queue plots for a MultiQC HTML report. Defaults to False.
+             multiqc_section (Optional[str]): Section name to use in MultiQC. Defaults to 'PG-SUI (<model_name>)'.
          """
39
- ## Results from grid search
40
- results = pd.DataFrame(cv_results)
41
- means_test = [col for col in results if col.startswith("mean_test_")]
42
- filter_col = [col for col in results if col.startswith("param_")]
43
- params_df = results[filter_col].astype(str)
44
- for i, col in enumerate(means_test):
45
- params_df[col] = results[means_test[i]]
46
-
47
- # Get number of needed subplot rows.
48
- tot = len(filter_col)
49
- cols = 4
50
- rows = int(np.ceil(tot / cols))
51
-
52
- fig = plt.figure(1, figsize=(20, 10))
53
- fig.tight_layout(pad=3.0)
54
-
55
- # Set font properties.
56
- font = {"size": 12}
57
- plt.rc("font", **font)
58
-
59
- for i, p in enumerate(filter_col, start=1):
60
- ax = fig.add_subplot(rows, cols, i)
61
-
62
- # Plot each metric.
63
- for col in means_test:
64
- # Get maximum score for each parameter setting.
65
- df_plot = params_df.groupby(p)[col].agg("max")
66
-
67
- # Convert to float if not supposed to be string.
68
- try:
69
- df_plot.index = df_plot.index.astype(float)
70
- except TypeError:
71
- pass
72
-
73
- # Sort by index (numerically if possible).
74
- df_plot = df_plot.sort_index()
75
-
76
- # Remove prefix from score name.
77
- col_new_name = col[len("mean_test_") :]
78
-
79
- ax.plot(
80
- df_plot.index.astype(str),
81
- df_plot.values,
82
- "-o",
83
- label=col_new_name,
84
- )
105
+ logman = LoggerManager(
106
+ name=__name__, prefix=prefix, verbose=bool(verbose), debug=bool(debug)
107
+ )
108
+ self.logger = configure_logger(
109
+ logman.get_logger(), verbose=bool(verbose), debug=bool(debug)
110
+ )
85
111
 
86
- ax.legend(loc="best")
112
+ self.model_name = model_name
113
+ self.prefix = prefix
114
+ self.plot_format = plot_format
115
+ self.plot_fontsize = plot_fontsize
116
+ self.plot_dpi = plot_dpi
117
+ self.title_fontsize = title_fontsize
118
+ self.show_plots = show_plots
87
119
 
88
- param_new_name = p[len("param_") :]
89
- ax.set_xlabel(param_new_name.lower())
90
- ax.set_ylabel("Max Score")
91
- ax.set_ylim([0, 1])
120
+ # MultiQC configuration
121
+ self.use_multiqc: bool = bool(multiqc)
92
122
 
93
- fig.savefig(
94
- os.path.join(
95
- f"{prefix}_output",
96
- "plots",
97
- "Unsupervised",
98
- nn_method,
99
- "gridsearch_metrics.pdf",
100
- ),
101
- bbox_inches="tight",
102
- facecolor="white",
123
+ self.multiqc_section: str = (
124
+ multiqc_section if multiqc_section is not None else f"PG-SUI ({model_name})"
103
125
  )
104
126
 
105
- @staticmethod
106
- def plot_metrics(metrics, num_classes, prefix, nn_method):
107
- """Plot AUC-ROC and Precision-Recall performance metrics for neural network classifier.
127
+ if self.plot_format.startswith("."):
128
+ self.plot_format = self.plot_format.lstrip(".")
129
+
130
+ self.param_dict = {
131
+ "axes.labelsize": self.plot_fontsize,
132
+ "axes.titlesize": self.title_fontsize,
133
+ "axes.spines.top": despine,
134
+ "axes.spines.right": despine,
135
+ "xtick.labelsize": self.plot_fontsize,
136
+ "ytick.labelsize": self.plot_fontsize,
137
+ "legend.fontsize": self.plot_fontsize,
138
+ "legend.facecolor": "white",
139
+ "figure.titlesize": self.title_fontsize,
140
+ "figure.dpi": self.plot_dpi,
141
+ "figure.facecolor": "white",
142
+ "axes.linewidth": 2.0,
143
+ "lines.linewidth": 2.0,
144
+ "font.size": self.plot_fontsize,
145
+ "savefig.bbox": "tight",
146
+ "savefig.facecolor": "white",
147
+ "savefig.dpi": self.plot_dpi,
148
+ }
108
149
 
109
- Saves plot to PDF file on disk.
150
+ mpl.rcParams.update(self.param_dict)
110
151
 
111
- Args:
112
- metrics (Dict[str, Any]): Per-class, micro, and macro-averaged metrics including accuracy, ROC-AUC, and Precision-Recall with Average Precision scores.
152
+ unsuper = {"ImputeVAE", "ImputeNLPCA", "ImputeAutoencoder", "ImputeUBP"}
113
153
 
114
- num_classes (int): Number of classes evaluated.
154
+ det = {
155
+ "ImputeRefAllele",
156
+ "ImputeMostFrequent",
157
+ "ImputeMostFrequentPerPop",
158
+ "ImputePhylo",
159
+ }
160
+
161
+ sup = {"ImputeRandomForest", "ImputeHistGradientBoosting"}
115
162
 
116
- prefix (str): Prefix to use for output plot.
163
+ if model_name in unsuper:
164
+ plot_dir = "Unsupervised"
165
+ elif model_name in det:
166
+ plot_dir = "Deterministic"
167
+ elif model_name in sup:
168
+ plot_dir = "Supervised"
169
+ else:
170
+ msg = f"model_name '{model_name}' not recognized."
171
+ self.logger.error(msg)
172
+ raise ValueError(msg)
173
+
174
+ self.output_dir = Path(f"{self.prefix}_output", plot_dir)
175
+ self.output_dir = self.output_dir / "plots" / model_name
176
+ self.output_dir.mkdir(parents=True, exist_ok=True)
177
+
178
+ # --------------------------------------------------------------------- #
179
+ # Core plotting methods #
180
+ # --------------------------------------------------------------------- #
181
+ def plot_tuning(
182
+ self,
183
+ study: optuna.study.Study,
184
+ model_name: str,
185
+ optimize_dir: Path,
186
+ target_name: str = "Objective Value",
187
+ ) -> None:
188
+ """Plot the optimization history of a study.
189
+
190
+ This method plots the optimization history of a study. The plot is saved to disk as a ``<plot_format>`` file.
117
191
 
118
- nn_method (str): Neural network algorithm being used.
192
+ Args:
193
+ study (optuna.study.Study): Optuna study object.
194
+ model_name (str): Name of the model.
195
+ optimize_dir (Path): Directory to save the optimization plots.
196
+ target_name (str, optional): Name of the target value. Defaults to 'Objective Value'.
119
197
  """
120
- # Set font properties.
121
- font = {"size": 12}
122
- plt.rc("font", **font)
123
-
124
- fn = os.path.join(
125
- f"{prefix}_output",
126
- "plots",
127
- "Unsupervised",
128
- nn_method,
129
- f"auc_pr_curves.pdf",
130
- )
131
- fig = plt.figure(figsize=(20, 10))
198
+ with warnings.catch_warnings():
199
+ warnings.filterwarnings("ignore", category=UserWarning)
200
+ warnings.filterwarnings("ignore", category=FutureWarning)
201
+ warnings.filterwarnings("ignore", category=ExperimentalWarning)
132
202
 
133
- acc = round(metrics["accuracy"] * 100, 2)
134
- ham = round(metrics["hamming"], 2)
203
+ target_name = target_name.title()
135
204
 
136
- fig.suptitle(
137
- f"Performance Metrics\nAccuracy: {acc}\nHamming Loss: {ham}"
138
- )
139
- axs = fig.subplots(nrows=1, ncols=2)
140
- plt.subplots_adjust(hspace=0.5)
141
-
142
- # Line weight
143
- lw = 2
144
-
145
- roc_auc = metrics["roc_auc"]
146
- pr_ap = metrics["precision_recall"]
147
-
148
- metric_list = [roc_auc, pr_ap]
149
-
150
- for metric, ax in zip(metric_list, axs):
151
- if "fpr_micro" in metric:
152
- prefix1 = "fpr"
153
- prefix2 = "tpr"
154
- lab1 = "ROC"
155
- lab2 = "AUC"
156
- xlab = "False Positive Rate"
157
- ylab = "True Positive Rate"
158
- title = "Receiver Operating Characteristic (ROC)"
159
- baseline = [0, 1]
160
-
161
- elif "recall_micro" in metric:
162
- prefix1 = "recall"
163
- prefix2 = "precision"
164
- lab1 = "Precision-Recall"
165
- lab2 = "AP"
166
- xlab = "Recall"
167
- ylab = "Precision"
168
- title = "Precision-Recall"
169
- baseline = [metric["baseline"], metric["baseline"]]
170
-
171
- # Plot iso-f1 curves.
172
- f_scores = np.linspace(0.2, 0.8, num=4)
173
- for i, f_score in enumerate(f_scores):
174
- x = np.linspace(0.01, 1)
175
- y = f_score * x / (2 * x - f_score)
176
- ax.plot(
177
- x[y >= 0],
178
- y[y >= 0],
179
- color="gray",
180
- alpha=0.2,
181
- linewidth=lw,
182
- label="Iso-F1 Curves" if i == 0 else "",
183
- )
184
- ax.annotate(f"F1={f_score:0.1f}", xy=(0.9, y[45] + 0.02))
185
-
186
- # Plot ROC curves.
187
- ax.plot(
188
- metric[f"{prefix1}_micro"],
189
- metric[f"{prefix2}_micro"],
190
- label=f"Micro-averaged {lab1} Curve ({lab2} = {metric['micro']:.2f})",
191
- color="deeppink",
192
- linestyle=":",
193
- linewidth=4,
205
+ ax = optuna.visualization.matplotlib.plot_optimization_history(
206
+ study, target_name=target_name
194
207
  )
195
-
196
- ax.plot(
197
- metric[f"{prefix1}_macro"],
198
- metric[f"{prefix2}_macro"],
199
- label=f"Macro-averaged {lab1} Curve ({lab2} = {metric['macro']:.2f})",
200
- color="navy",
201
- linestyle=":",
202
- linewidth=4,
208
+ ax.set_title(f"{model_name} Optimization History")
209
+ ax.set_xlabel("Trial")
210
+ ax.set_ylabel(target_name)
211
+ ax.legend(
212
+ loc="best",
213
+ shadow=True,
214
+ fancybox=True,
215
+ fontsize=mpl.rcParamsDefault["legend.fontsize"],
203
216
  )
204
217
 
205
- colors = cycle(["aqua", "darkorange", "cornflowerblue"])
206
- for i, color in zip(range(num_classes), colors):
207
- if f"{prefix1}_{i}" in metric:
208
- ax.plot(
209
- metric[f"{prefix1}_{i}"],
210
- metric[f"{prefix2}_{i}"],
211
- color=color,
212
- lw=lw,
213
- label=f"{lab1} Curve of class {i} ({lab2} = {metric[i]:.2f})",
214
- )
215
-
216
- if "fpr_micro" in metric:
217
- # Make center baseline
218
- ax.plot(
219
- baseline,
220
- baseline,
221
- "k--",
222
- linewidth=lw,
223
- label="No Classification Skill",
224
- )
225
- else:
226
- ax.plot(
227
- [0, 1],
228
- baseline,
229
- "k--",
230
- linewidth=lw,
231
- label="No Classification Skill",
232
- )
233
-
234
- ax.set_xlim(0.0, 1.0)
235
- ax.set_ylim(0.0, 1.05)
236
- ax.set_xlabel(f"{xlab}")
237
- ax.set_ylabel(f"{ylab}")
238
- ax.set_title(f"{title}")
239
- ax.legend(loc="best")
240
-
241
- fig.savefig(fn, bbox_inches="tight", facecolor="white")
242
- plt.close()
243
- plt.clf()
244
- plt.cla()
245
-
246
- @staticmethod
247
- def plot_search_space(
248
- estimator,
249
- height=2,
250
- s=25,
251
- features=None,
252
- ):
253
- """Make density and contour plots for showing search space during grid search.
254
-
255
- Modified from sklearn-genetic-opt function to implement exception handling.
256
-
257
- Args:
258
- estimator (sklearn estimator object): A fitted estimator from :class:`~sklearn_genetic.GASearchCV`.
259
-
260
- height (float, optional): Height of each facet. Defaults to 2.
261
-
262
- s (float, optional): Size of the markers in scatter plot. Defaults to 5.
263
-
264
- features (list, optional): Subset of features to plot, if ``None`` it plots all the features by default. Defaults to None.
218
+ od = optimize_dir
219
+ fn = od / f"optuna_optimization_history.{self.plot_format}"
265
220
 
266
- Returns:
267
- g (seaborn.PairGrid): Pair plot of the used hyperparameters during the search.
268
- """
269
- sns.set_style("white")
270
-
271
- df = logbook_to_pandas(estimator.logbook)
272
- if features:
273
- _stats = df[features]
274
- else:
275
- variables = [*estimator.space.parameters, "score"]
276
- _stats = df[variables]
221
+ if not fn.parent.exists():
222
+ fn.parent.mkdir(parents=True, exist_ok=True)
277
223
 
278
- g = sns.PairGrid(_stats, diag_sharey=False, height=height)
279
-
280
- g = g.map_upper(sns.scatterplot, s=s, color="r", alpha=0.2)
224
+ plt.savefig(fn)
225
+ plt.close()
281
226
 
282
- try:
283
- g = g.map_lower(
284
- sns.kdeplot,
285
- shade=True,
286
- cmap=sns.color_palette("ch:s=.25,rot=-.25", as_cmap=True),
227
+ ax = optuna.visualization.matplotlib.plot_edf(
228
+ study, target_name=target_name
287
229
  )
288
- except np.linalg.LinAlgError as err:
289
- if "singular matrix" in str(err).lower():
290
- g = g.map_lower(sns.scatterplot, s=s, color="b", alpha=1.0)
291
- else:
292
- raise
293
-
294
- try:
295
- g = g.map_diag(
296
- sns.kdeplot,
297
- shade=True,
298
- palette="crest",
299
- alpha=0.2,
300
- color="red",
230
+ ax.set_title(f"{model_name} Empirical Distribution Function (EDF)")
231
+ ax.set_xlabel(target_name)
232
+ ax.set_ylabel(f"{model_name} Cumulative Probability")
233
+ ax.legend(
234
+ loc="best",
235
+ shadow=True,
236
+ fancybox=True,
237
+ fontsize=mpl.rcParamsDefault["legend.fontsize"],
301
238
  )
302
- except np.linalg.LinAlgError as err:
303
- if "singular matrix" in str(err).lower():
304
- g = g.map_diag(sns.histplot, color="red", alpha=1.0, kde=False)
305
-
306
- return g
307
-
308
- @staticmethod
309
- def visualize_missingness(
310
- genotype_data,
311
- df,
312
- zoom=True,
313
- prefix="imputer",
314
- horizontal_space=0.6,
315
- vertical_space=0.6,
316
- bar_color="gray",
317
- heatmap_palette="magma",
318
- plot_format="pdf",
319
- dpi=300,
320
- ):
321
- """Make multiple plots to visualize missing data.
322
-
323
- Args:
324
- genotype_data (GenotypeData): Initialized GentoypeData object.
325
239
 
326
- df (pandas.DataFrame): DataFrame with snps to visualize.
240
+ plt.savefig(fn.with_stem("optuna_edf_plot"))
241
+ plt.close()
327
242
 
328
- zoom (bool, optional): If True, zooms in to the missing proportion range on some of the plots. If False, the plot range is fixed at [0, 1]. Defaults to True.
243
+ ax = optuna.visualization.matplotlib.plot_param_importances(
244
+ study, target_name=target_name
245
+ )
246
+ ax.set_xlabel("Parameter Importance")
247
+ ax.set_ylabel("Parameter")
248
+ ax.legend(loc="best", shadow=True, fancybox=True)
329
249
 
330
- prefix (str, optional): Prefix for output directory and files. Plots and files will be written to a directory called <prefix>_reports. The report directory will be created if it does not already exist. If prefix is None, then the reports directory will not have a prefix. Defaults to 'imputer'.
250
+ plt.savefig(fn.with_stem("optuna_param_importances_plot"))
251
+ plt.close()
331
252
 
332
- horizontal_space (float, optional): Set width spacing between subplots. If your plot are overlapping horizontally, increase horizontal_space. If your plots are too far apart, decrease it. Defaults to 0.6.
253
+ ax = optuna.visualization.matplotlib.plot_timeline(study)
254
+ ax.set_title(f"{model_name} Timeline Plot")
255
+ ax.set_xlabel("Datetime")
256
+ ax.set_ylabel("Trial")
257
+ plt.savefig(fn.with_stem("optuna_timeline_plot"))
258
+ plt.close()
333
259
 
334
- vertical_space (float, optioanl): Set height spacing between subplots. If your plots are overlapping vertically, increase vertical_space. If your plots are too far apart, decrease it. Defaults to 0.6.
260
+ # Reset the style from Optuna's plotting.
261
+ sns.set_style("white", rc=self.param_dict)
262
+ mpl.rcParams.update(self.param_dict)
335
263
 
336
- bar_color (str, optional): Color of the bars on the non-stacked barplots. Can be any color supported by matplotlib. See matplotlib.pyplot.colors documentation. Defaults to 'gray'.
264
+ # ---- MultiQC: Optuna tuning line graph + best-params table --------
265
+ if self._multiqc_enabled():
266
+ try:
267
+ self._queue_multiqc_tuning(
268
+ study=study, model_name=model_name, target_name=target_name
269
+ )
270
+ except Exception as exc: # pragma: no cover - defensive
271
+ self.logger.warning(f"Failed to queue MultiQC tuning plots: {exc}")
337
272
 
338
- heatmap_palette (str, optional): Palette to use for heatmap plot. Can be any palette supported by seaborn. See seaborn documentation. Defaults to 'magma'.
273
+ def plot_metrics(
274
+ self,
275
+ y_true: np.ndarray,
276
+ y_pred_proba: np.ndarray,
277
+ metrics: Dict[str, float],
278
+ label_names: Optional[Sequence[str]] = None,
279
+ prefix: str = "",
280
+ ) -> None:
281
+ """Plot multi-class ROC-AUC and Precision-Recall curves.
339
282
 
340
- plot_format (str, optional): Format to save plots. Can be any of the following: "pdf", "png", "svg", "ps", "eps". Defaults to "pdf".
283
+ This method plots the multi-class ROC-AUC and Precision-Recall curves. The plot is saved to disk as a ``<plot_format>`` file.
341
284
 
342
- dpi (int): The resolution in dots per inch. Defaults to 300.
285
+ Args:
286
+ y_true (np.ndarray): 1D array of true integer labels in [0, n_classes-1].
287
+ y_pred_proba (np.ndarray): (n_samples, n_classes) array of predicted probabilities.
288
+ metrics (Dict[str, float]): Dict of summary metrics to annotate the figure.
289
+ label_names (Optional[Sequence[str]]): Optional sequence of class names (length must equal n_classes).
290
+ If provided, legends will use these names instead of 'Class i'.
291
+ prefix (str): Optional prefix for the output filename.
343
292
 
344
- Returns:
345
- pandas.DataFrame: Per-locus missing data proportions.
346
- pandas.DataFrame: Per-individual missing data proportions.
347
- pandas.DataFrame: Per-population + per-locus missing data proportions.
348
- pandas.DataFrame: Per-population missing data proportions.
349
- pandas.DataFrame: Per-individual and per-population missing data proportions.
293
+ Raises:
294
+ ValueError: If model_name is not recognized (legacy guard).
350
295
  """
296
+ num_classes = y_pred_proba.shape[1]
351
297
 
352
- loc, ind, poploc, poptotal, indpop = genotype_data.calc_missing(df)
353
-
354
- ncol = 3
355
- nrow = 1 if genotype_data.pops is None else 2
356
-
357
- fig, axes = plt.subplots(nrow, ncol, figsize=(8, 11))
358
- plt.subplots_adjust(wspace=horizontal_space, hspace=vertical_space)
359
- fig.suptitle("Missingness Report")
360
-
361
- ax = axes[0, 0]
362
-
363
- ax.set_title("Per-Individual")
364
- ax.barh(genotype_data.samples, ind, color=bar_color, height=1.0)
365
- if not zoom:
366
- ax.set_xlim([0, 1])
367
- ax.set_ylabel("Sample")
368
- ax.set_xlabel("Missing Prop.")
369
- ax.tick_params(
370
- axis="y",
371
- which="both",
372
- left=False,
373
- right=False,
374
- labelleft=False,
375
- )
376
-
377
- ax = axes[0, 1]
298
+ # Validate/normalize label names
299
+ if label_names is not None and len(label_names) != num_classes:
300
+ self.logger.warning(
301
+ f"plot_metrics: len(label_names)={len(label_names)} "
302
+ f"!= n_classes={num_classes}. Ignoring label_names."
303
+ )
304
+ label_names = None
305
+ if label_names is None:
306
+ label_names = [f"Class {i}" for i in range(num_classes)]
307
+
308
+ # Binarize y_true for one-vs-rest curves
309
+ y_true_bin = np.asarray(label_binarize(y_true, classes=np.arange(num_classes)))
310
+
311
+ # Containers
312
+ fpr, tpr, roc_auc_vals = {}, {}, {}
313
+ precision, recall, average_precision_vals = {}, {}, {}
314
+
315
+ # Per-class ROC & PR
316
+ for i in range(num_classes):
317
+ fpr[i], tpr[i], _ = roc_curve(y_true_bin[:, i], y_pred_proba[:, i])
318
+ roc_auc_vals[i] = auc(fpr[i], tpr[i])
319
+ precision[i], recall[i], _ = precision_recall_curve(
320
+ y_true_bin[:, i], y_pred_proba[:, i]
321
+ )
322
+ average_precision_vals[i] = average_precision_score(
323
+ y_true_bin[:, i], y_pred_proba[:, i]
324
+ )
378
325
 
379
- ax.set_title("Per-Locus")
380
- ax.barh(
381
- range(genotype_data.num_snps), loc, color=bar_color, height=1.0
326
+ # Micro-averages
327
+ fpr["micro"], tpr["micro"], _ = roc_curve(
328
+ y_true_bin.ravel(), y_pred_proba.ravel()
382
329
  )
383
- if not zoom:
384
- ax.set_xlim([0, 1])
385
- ax.set_ylabel("Locus")
386
- ax.set_xlabel("Missing Prop.")
387
- ax.tick_params(
388
- axis="y",
389
- which="both",
390
- left=False,
391
- right=False,
392
- labelleft=False,
330
+ roc_auc_vals["micro"] = auc(fpr["micro"], tpr["micro"])
331
+ precision["micro"], recall["micro"], _ = precision_recall_curve(
332
+ y_true_bin.ravel(), y_pred_proba.ravel()
333
+ )
334
+ average_precision_vals["micro"] = average_precision_score(
335
+ y_true_bin, y_pred_proba, average="micro"
393
336
  )
394
337
 
395
- id_vars = ["SampleID"]
396
- if poptotal is not None:
397
- ax = axes[0, 2]
398
-
399
- ax.set_title("Per-Population Total")
400
- ax.barh(poptotal.index, poptotal, color=bar_color, height=1.0)
401
- if not zoom:
402
- ax.set_xlim([0, 1])
403
- ax.set_xlabel("Missing Prop.")
404
- ax.set_ylabel("Population")
405
-
406
- ax = axes[1, 0]
407
-
408
- ax.set_title("Per-Population + Per-Locus")
409
- npops = len(poploc.columns)
338
+ # Macro-average ROC
339
+ all_fpr = np.unique(np.concatenate([fpr[i] for i in range(num_classes)]))
340
+ mean_tpr = np.zeros_like(all_fpr)
341
+ for i in range(num_classes):
342
+ mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])
343
+ mean_tpr /= num_classes
344
+ fpr["macro"], tpr["macro"] = all_fpr, mean_tpr
345
+ roc_auc_vals["macro"] = auc(fpr["macro"], tpr["macro"])
346
+
347
+ # Macro-average PR
348
+ all_recall = np.unique(np.concatenate([recall[i] for i in range(num_classes)]))
349
+ mean_precision = np.zeros_like(all_recall)
350
+ for i in range(num_classes):
351
+ # recall[i] increases, but precision[i] is given over decreasing thresholds
352
+ mean_precision += np.interp(all_recall, recall[i][::-1], precision[i][::-1])
353
+ mean_precision /= num_classes
354
+ average_precision_vals["macro"] = average_precision_score(
355
+ y_true_bin, y_pred_proba, average="macro"
356
+ )
410
357
 
411
- vmax = None if zoom else 1.0
358
+ # Plot
359
+ fig, axes = plt.subplots(1, 2, figsize=(14, 6))
412
360
 
413
- sns.heatmap(
414
- poploc,
415
- vmin=0.0,
416
- vmax=vmax,
417
- cmap=sns.color_palette(heatmap_palette, as_cmap=True),
418
- yticklabels=False,
419
- cbar_kws={"label": "Missing Prop."},
420
- ax=ax,
361
+ # ROC
362
+ axes[0].plot(
363
+ fpr["micro"],
364
+ tpr["micro"],
365
+ label=f"Micro-average ROC (AUC = {roc_auc_vals['micro']:.2f})",
366
+ linestyle=":",
367
+ linewidth=4,
368
+ )
369
+ axes[0].plot(
370
+ fpr["macro"],
371
+ tpr["macro"],
372
+ label=f"Macro-average ROC (AUC = {roc_auc_vals['macro']:.2f})",
373
+ linestyle="--",
374
+ linewidth=4,
375
+ )
376
+ for i in range(num_classes):
377
+ axes[0].plot(
378
+ fpr[i],
379
+ tpr[i],
380
+ label=f"{label_names[i]} ROC (AUC = {roc_auc_vals[i]:.2f})",
421
381
  )
422
- ax.set_xlabel("Population")
423
- ax.set_ylabel("Locus")
424
-
425
- id_vars.append("Population")
426
-
427
- melt_df = indpop.isna()
428
- melt_df["SampleID"] = genotype_data.samples
429
- indpop["SampleID"] = genotype_data.samples
430
-
431
- if poptotal is not None:
432
- melt_df["Population"] = genotype_data.pops
433
- indpop["Population"] = genotype_data.pops
434
-
435
- melt_df = melt_df.melt(value_name="Missing", id_vars=id_vars)
436
- melt_df.sort_values(by=id_vars[::-1], inplace=True)
437
- melt_df["Missing"].replace(False, "Present", inplace=True)
438
- melt_df["Missing"].replace(True, "Missing", inplace=True)
439
-
440
- ax = axes[0, 2] if poptotal is None else axes[1, 1]
382
+ axes[0].plot([0, 1], [0, 1], linestyle="--", color="black", label="Random")
383
+ axes[0].set_xlabel("False Positive Rate")
384
+ axes[0].set_ylabel("True Positive Rate")
385
+ axes[0].set_title("Multi-class ROC-AUC Curve")
386
+ axes[0].legend(
387
+ loc="upper center",
388
+ bbox_to_anchor=(0.5, -0.15),
389
+ fancybox=True,
390
+ shadow=True,
391
+ ncol=2,
392
+ )
441
393
 
442
- ax.set_title("Per-Individual")
443
- g = sns.histplot(
444
- data=melt_df,
445
- y="variable",
446
- hue="Missing",
447
- multiple="fill",
448
- ax=ax,
394
+ # PR
395
+ axes[1].plot(
396
+ recall["micro"],
397
+ precision["micro"],
398
+ label=f"Micro-average PR (AP = {average_precision_vals['micro']:.2f})",
399
+ linestyle=":",
400
+ linewidth=4,
449
401
  )
450
- ax.tick_params(
451
- axis="y",
452
- which="both",
453
- left=False,
454
- right=False,
455
- labelleft=False,
402
+ axes[1].plot(
403
+ all_recall,
404
+ mean_precision,
405
+ label=f"Macro-average PR (AP = {average_precision_vals['macro']:.2f})",
406
+ linestyle="--",
407
+ linewidth=4,
456
408
  )
457
- g.get_legend().set_title(None)
458
-
459
- if poptotal is not None:
460
- ax = axes[1, 2]
461
-
462
- ax.set_title("Per-Population")
463
- g = sns.histplot(
464
- data=melt_df,
465
- y="Population",
466
- hue="Missing",
467
- multiple="fill",
468
- ax=ax,
409
+ for i in range(num_classes):
410
+ axes[1].plot(
411
+ recall[i],
412
+ precision[i],
413
+ label=f"{label_names[i]} PR (AP = {average_precision_vals[i]:.2f})",
469
414
  )
470
- g.get_legend().set_title(None)
471
-
472
- fig.savefig(
473
- os.path.join(
474
- f"{prefix}_output", "plots", f"missingness.{plot_format}"
475
- ),
476
- bbox_inches="tight",
477
- facecolor="white",
415
+ axes[1].plot([0, 1], [1, 0], linestyle="--", color="black", label="Random")
416
+ axes[1].set_xlabel("Recall")
417
+ axes[1].set_ylabel("Precision")
418
+ axes[1].set_title("Multi-class Precision-Recall Curve")
419
+ axes[1].legend(
420
+ loc="upper center",
421
+ bbox_to_anchor=(0.5, -0.15),
422
+ fancybox=True,
423
+ shadow=True,
424
+ ncol=2,
478
425
  )
479
- plt.cla()
480
- plt.clf()
481
- plt.close()
482
426
 
483
- return loc, ind, poploc, poptotal, indpop
427
+ # Title & save
428
+ fig.suptitle("\n".join([f"{k}: {v:.2f}" for k, v in metrics.items()]), y=1.35)
484
429
 
485
- @staticmethod
486
- def run_and_plot_pca(
487
- original_genotype_data,
488
- imputer_object,
489
- prefix="imputer",
490
- n_components=3,
491
- center=True,
492
- scale=False,
493
- n_axes=2,
494
- point_size=15,
495
- font_size=15,
496
- plot_format="pdf",
497
- bottom_margin=0,
498
- top_margin=0,
499
- left_margin=0,
500
- right_margin=0,
501
- width=1088,
502
- height=700,
503
- ):
504
- """Runs PCA and makes scatterplot with colors showing missingness.
505
-
506
- Genotypes are plotted as separate shapes per population and colored according to missingness per individual.
430
+ prefix_for_name = f"{prefix}_" if prefix != "" else ""
431
+ out_name = (
432
+ f"{self.model_name}_{prefix_for_name}roc_pr_curves.{self.plot_format}"
433
+ )
434
+ fig.savefig(self.output_dir / out_name, bbox_inches="tight")
435
+ if self.show_plots:
436
+ plt.show()
437
+ plt.close(fig)
438
+
439
+ # ---- MultiQC: metrics table + per-class AUC/AP heatmap ------------
440
+ if self._multiqc_enabled():
441
+ try:
442
+ self._queue_multiqc_metrics(
443
+ metrics=metrics,
444
+ roc_auc=roc_auc_vals,
445
+ average_precision=average_precision_vals,
446
+ label_names=label_names,
447
+ panel_prefix=prefix,
448
+ )
449
+ except Exception as exc: # pragma: no cover - defensive
450
+ self.logger.warning(f"Failed to queue MultiQC metrics plots: {exc}")
451
+
452
+ try:
453
+ self._queue_multiqc_roc_curves(
454
+ fpr=fpr,
455
+ tpr=tpr,
456
+ label_names=label_names,
457
+ panel_prefix=prefix,
458
+ )
459
+ self._queue_multiqc_pr_curves(
460
+ precision=precision,
461
+ recall=recall,
462
+ label_names=label_names,
463
+ panel_prefix=prefix,
464
+ )
465
+ except Exception as exc: # pragma: no cover - defensive
466
+ self.logger.warning(f"Failed to queue MultiQC ROC/PR curves: {exc}")
507
467
 
508
- This function is run at the end of each imputation method, but can be run independently to change plot and PCA parameters such as ``n_axes=3`` or ``scale=True``.
468
+ def plot_history(
469
+ self,
470
+ history: Dict[str, List[float] | Dict[str, List[float]] | None] | None,
471
+ ) -> None:
472
+ """Plot model history traces. Will be saved to file.
509
473
 
510
- The imputed and original GenotypeData objects need to be passed to the function as positional arguments.
474
+ This method plots the deep learning model history traces. The plot is saved to disk as a ``<plot_format>`` file.
511
475
 
512
- PCA (principal component analysis) scatterplot can have either two or three axes, set with the n_axes parameter.
476
+ Args:
477
+ history (Dict[str, List[float]]): Dictionary with lists of history objects. Keys should be "Train" and "Validation".
513
478
 
514
- The plot is saved as both an interactive HTML file and as a static image. Each population is represented by point shapes. The interactive plot has associated metadata when hovering over the points.
479
+ Raises:
480
+ ValueError: nn_method must be either 'ImputeNLPCA', 'ImputeUBP', 'ImputeAutoencoder', 'ImputeVAE'.
481
+ """
482
+ if self.model_name not in {
483
+ "ImputeNLPCA",
484
+ "ImputeVAE",
485
+ "ImputeAutoencoder",
486
+ "ImputeUBP",
487
+ }:
488
+ msg = "nn_method must be either 'ImputeNLPCA', 'ImputeVAE', 'ImputeAutoencoder', 'ImputeUBP'."
489
+ self.logger.error(msg)
490
+ raise ValueError(msg)
491
+
492
+ if self.model_name != "ImputeUBP":
493
+ fig, ax = plt.subplots(1, 1, figsize=(12, 8))
494
+ df = pd.DataFrame(history)
495
+ df = df.iloc[1:]
515
496
 
516
- Files are saved to a reports directory as <prefix>_output/imputed_pca.<plot_format|html>. Supported image formats include: "pdf", "svg", "png", and "jpeg" (or "jpg").
497
+ # Plot train accuracy
498
+ ax.plot(df["Train"], c="blue", lw=3)
517
499
 
518
- Args:
519
- original_genotype_data (GenotypeData): Original GenotypeData object that was input into the imputer.
500
+ ax.set_title(f"{self.model_name} Loss per Epoch")
501
+ ax.set_ylabel("Loss")
502
+ ax.set_xlabel("Epoch")
503
+ ax.legend(["Train"], loc="best", shadow=True, fancybox=True)
520
504
 
521
- imputer_object (Any imputer instance): Imputer object created when imputing. Can be any of the imputers, such as: ``ImputePhylo()``, ``ImputeUBP()``, and ``ImputeRandomForest()``.
505
+ else:
506
+ fig, ax = plt.subplots(3, 1, figsize=(12, 8))
507
+
508
+ # Ensure history is the nested dictionary type for ImputeUBP
509
+ if not (
510
+ isinstance(history, dict)
511
+ and "Train" in history
512
+ and isinstance(history["Train"], dict)
513
+ ):
514
+ msg = "For ImputeUBP, history must be a nested dictionary with phases."
515
+ self.logger.error(msg)
516
+ raise TypeError(msg)
517
+
518
+ for i, phase in enumerate(range(1, 4)):
519
+ train = pd.Series(history["Train"][f"Phase {phase}"])
520
+ train = train.iloc[1:] # ignore first epoch
521
+
522
+ # Plot train accuracy
523
+ ax[i].plot(train, c="blue", lw=3)
524
+ ax[i].set_title(f"{self.model_name}: Phase {phase} Loss per Epoch")
525
+ ax[i].set_ylabel("Loss")
526
+ ax[i].set_xlabel("Epoch")
527
+ ax[i].legend([f"Phase {phase}"], loc="best", shadow=True, fancybox=True)
528
+
529
+ fn = f"{self.model_name.lower()}_history_plot.{self.plot_format}"
530
+ fn = self.output_dir / fn
531
+ fig.savefig(fn)
532
+
533
+ if self.show_plots:
534
+ plt.show()
535
+ plt.close(fig)
536
+
537
+ # ---- MultiQC: training-loss vs epoch linegraphs -------------------
538
+ if self._multiqc_enabled():
539
+ try:
540
+ self._queue_multiqc_history(history=history)
541
+ except Exception as exc: # pragma: no cover
542
+ self.logger.warning(f"Failed to queue MultiQC history plot: {exc}")
522
543
 
523
- original_012 (pandas.DataFrame, numpy.ndarray, or List[List[int]], optional): Original 012-encoded genotypes (before imputing). Missing values are encoded as -9. This object can be obtained as ``df = GenotypeData.genotypes012_df``.
544
+ def plot_confusion_matrix(
545
+ self,
546
+ y_true_1d: np.ndarray | pd.DataFrame | List[str | int] | torch.Tensor,
547
+ y_pred_1d: np.ndarray | pd.DataFrame | List[str | int] | torch.Tensor,
548
+ label_names: Sequence[str] | Dict[str, int] | None = None,
549
+ prefix: str = "",
550
+ ) -> None:
551
+ """Plot a confusion matrix with optional class labels.
524
552
 
525
- prefix (str, optional): Prefix for report directory. Plots will be save to a directory called <prefix>_output/imputed_pca<html|plot_format>. Report directory will be created if it does not already exist. Defaults to "imputer".
553
+ This method plots a confusion matrix using true and predicted labels. The plot is saved to disk as a ``<plot_format>`` file.
526
554
 
527
- n_components (int, optional): Number of principal components to include in the PCA. Defaults to 3.
555
+ Args:
556
+ y_true_1d (np.ndarray | pd.DataFrame | list | torch.Tensor): 1D array of true integer labels in [0, n_classes-1].
557
+ y_pred_1d (np.ndarray | pd.DataFrame | list | torch.Tensor): 1D array of predicted integer labels in [0, n_classes-1].
558
+ label_names (Sequence[str] | None): Optional sequence of class names (length must equal n_classes). If provided, both the internal label order and displayed tick labels will respect this order (assumed to be 0..n-1).
559
+ prefix (str): Optional prefix for the output filename.
528
560
 
529
- center (bool, optional): If True, centers the genotypes to the mean before doing the PCA. If False, no centering is done. Defaults to True.
561
+ Notes:
562
+ - If `label_names` is None, the display labels default to the numeric class indices inferred from `y_true_1d ∪ y_pred_1d`.
563
+ """
564
+ y_true_1d = misc.validate_input_type(y_true_1d, return_type="array")
565
+ y_pred_1d = misc.validate_input_type(y_pred_1d, return_type="array")
566
+
567
+ if not isinstance(y_true_1d, np.ndarray) or y_true_1d.ndim != 1:
568
+ msg = "y_true_1d must be a 1D array-like of true labels."
569
+ self.logger.error(msg)
570
+ raise TypeError(msg)
571
+
572
+ if not isinstance(y_pred_1d, np.ndarray) or y_pred_1d.ndim != 1:
573
+ msg = "y_pred_1d must be a 1D array-like of predicted labels."
574
+ self.logger.error(msg)
575
+ raise TypeError(msg)
576
+
577
+ if y_true_1d.ndim > 1:
578
+ y_true_1d = y_true_1d.flatten()
579
+
580
+ if y_pred_1d.ndim > 1:
581
+ y_pred_1d = y_pred_1d.flatten()
582
+
583
+ # Determine class count/order
584
+ if label_names is not None:
585
+ n_classes = len(label_names)
586
+ labels = np.arange(n_classes) # our y_* are ints 0..n-1
587
+ display_labels = list(map(str, label_names))
588
+ else:
589
+ # Infer labels from data to keep matrix tight
590
+ labels = np.unique(np.concatenate([y_true_1d, y_pred_1d]))
591
+ display_labels = labels # sklearn will convert to strings
530
592
 
531
- scale (bool, optional): If True, scales the genotypes to unit variance before doing the PCA. If False, no scaling is done. Defaults to False.
593
+ fig, ax = plt.subplots(1, 1, figsize=(15, 15))
532
594
 
533
- n_axes (int, optional): Number of principal component axes to plot. Must be set to either 2 or 3. If set to 3, a 3-dimensional plot will be made. Defaults to 2.
595
+ ConfusionMatrixDisplay.from_predictions(
596
+ y_true=y_true_1d,
597
+ y_pred=y_pred_1d,
598
+ labels=labels,
599
+ display_labels=display_labels,
600
+ ax=ax,
601
+ cmap="viridis",
602
+ colorbar=True,
603
+ )
534
604
 
535
- point_size (int, optional): Point size for scatterplot points. Defaults to 15.
605
+ # Build a stable panel id before mutating prefix
606
+ panel_suffix = f"{prefix}_" if prefix else ""
607
+ panel_id = f"{self.model_name.lower()}_{panel_suffix}confusion_matrix"
536
608
 
537
- plot_format (str, optional): Plot file format to use. Supported formats include: "pdf", "svg", "png", and "jpeg" (or "jpg"). An interactive HTML file is also created regardless of this setting. Defaults to "pdf".
609
+ if prefix != "":
610
+ prefix = f"{prefix}_"
538
611
 
539
- bottom_margin (int, optional): Adjust bottom margin. If whitespace cuts off some of your plot, lower the corresponding margins. The default corresponds to that of plotly update_layout(). Defaults to 0.
612
+ out_name = (
613
+ f"{self.model_name.lower()}_{prefix}confusion_matrix.{self.plot_format}"
614
+ )
615
+ fig.savefig(self.output_dir / out_name, bbox_inches="tight")
616
+ if self.show_plots:
617
+ plt.show()
618
+ plt.close(fig)
619
+
620
+ # ---- MultiQC: confusion-matrix heatmap ----------------------------
621
+ if self._multiqc_enabled():
622
+ try:
623
+ self._queue_multiqc_confusion(
624
+ y_true=y_true_1d,
625
+ y_pred=y_pred_1d,
626
+ labels=labels,
627
+ display_labels=display_labels,
628
+ panel_id=panel_id,
629
+ )
630
+ except Exception as exc: # pragma: no cover
631
+ self.logger.warning(f"Failed to queue MultiQC confusion matrix: {exc}")
540
632
 
541
- top (int, optional): Adjust top margin. If whitespace cuts off some of your plot, lower the corresponding margins. The default corresponds to that of plotly update_layout(). Defaults to 0.
633
+ def plot_gt_distribution(
634
+ self,
635
+ X: np.ndarray | pd.DataFrame | list | torch.Tensor,
636
+ is_imputed: bool = False,
637
+ ) -> None:
638
+ """Plot genotype distribution (IUPAC or integer-encoded).
542
639
 
543
- left_margin (int, optional): Adjust left margin. If whitespace cuts off some of your plot, lower the corresponding margins. The default corresponds to that of plotly update_layout(). Defaults to 0.
640
+ This plots counts for all genotypes present in X. It supports IUPAC single-letter genotypes and integer encodings. Missing markers '-', '.', '?' are normalized to 'N'. Bars are annotated with counts and percentages.
544
641
 
545
- right_margin (int, optional): Adjust right margin. If whitespace cuts off some of your plot, lower the corresponding margins. The default corresponds to that of plotly update_layout(). Defaults to 0.
642
+ Args:
643
+ X (np.ndarray | pd.DataFrame | list | torch.Tensor): Array-like genotype matrix. Rows=loci, cols=samples (any orientation is OK). Elements are IUPAC one-letter genotypes (e.g., 'A','C','G','T','N','R',...) or integers (e.g., 0/1/2[/3]).
644
+ is_imputed (bool): Whether these genotypes are imputed. Affects the title only. Defaults to False.
645
+ """
646
+ # Flatten X to a 1D Series
647
+ if isinstance(X, pd.DataFrame):
648
+ arr = X.values
649
+ elif torch.is_tensor(X):
650
+ arr = X.detach().cpu().numpy()
651
+ else:
652
+ arr = np.asarray(X)
653
+
654
+ s = pd.Series(arr.ravel())
655
+
656
+ # Detect string vs numeric encodings and normalize
657
+ if s.dtype.kind in ("O", "U", "S"): # string-like → IUPAC path
658
+ s = s.astype(str).str.upper().replace({"-": "N", ".": "N", "?": "N"})
659
+ x_label = "Genotype (IUPAC)"
660
+
661
+ # Define canonical order: N, A/C/T/G, then IUPAC ambiguity codes.
662
+ canonical = ["A", "C", "T", "G"]
663
+ iupac_ambiguity = sorted(["M", "R", "W", "S", "Y", "K", "V", "H", "D", "B"])
664
+ base_order = ["N"] + canonical + iupac_ambiguity
665
+ else: # numeric path (e.g., 0/1/2/[3], -1 for missing)
666
+ # Map common missing sentinels to 'N', keep others as strings for
667
+ # labeling
668
+ s = s.astype(float) # allow NaN comparisons
669
+ s = s.where(~np.isin(s, [-1, np.nan]), other=np.nan)
670
+ s = s.fillna("N").astype(int, errors="ignore").astype(str)
671
+
672
+ x_label = "Genotype (Integer-encoded)"
673
+
674
+ # Support both ternary and quaternary encodings; keep a stable order
675
+ base_order = ["N", "0", "1", "2", "3"]
676
+
677
+ # Include any unexpected symbols at the end (sorted) so nothing is
678
+ # dropped
679
+ extras = sorted(set(s.unique()) - set(base_order))
680
+ full_order = base_order + [e for e in extras if e not in base_order]
681
+
682
+ # Count and reindex to show zero-count categories
683
+ counts = s.value_counts().reindex(full_order, fill_value=0)
684
+ df = counts.rename_axis("Genotype").reset_index(name="Count")
685
+ df["Percent"] = df["Count"] / df["Count"].sum() * 100
686
+
687
+ title = "Imputed Genotype Counts" if is_imputed else "Genotype Counts"
688
+
689
+ # --- Plot ---
690
+ fig, ax = plt.subplots(figsize=(8, 5))
691
+ sns.despine(fig=fig)
692
+
693
+ ax = sns.barplot(
694
+ data=df,
695
+ x="Genotype",
696
+ y="Percent",
697
+ hue="Genotype",
698
+ order=full_order,
699
+ errorbar=None,
700
+ ax=ax,
701
+ palette="Set1",
702
+ legend=False,
703
+ fill=True,
704
+ )
546
705
 
547
- width (int, optional): Width of plot space. If your plot is cut off at the edges, even after adjusting the margins, increase the width and height. Try to keep the aspect ratio similar. Defaults to 1088.
706
+ ax.set_xlabel(x_label)
707
+ ax.set_ylabel("Percent")
708
+ ax.set_title(title)
709
+ ax.set_ylim((0.0, 50.0))
548
710
 
549
- height (int, optional): Height of plot space. If your plot is cut off at the edges, even after adjusting the margins, increase the width and height. Try to keep the aspect ratio similar. Defaults to 700.
711
+ fig.tight_layout()
550
712
 
551
- Returns:
552
- numpy.ndarray: PCA data as a numpy array with shape (n_samples, n_components).
713
+ suffix = "imputed" if is_imputed else "original"
714
+ fn = self.output_dir / f"gt_distributions_{suffix}.{self.plot_format}"
715
+ fig.savefig(fn, dpi=300)
716
+
717
+ if self.show_plots:
718
+ plt.show()
719
+ plt.close(fig)
720
+
721
+ # ---- MultiQC: genotype-distribution barplot -----------------------
722
+ if self._multiqc_enabled():
723
+ try:
724
+ self._queue_multiqc_gt_distribution(df=df, is_imputed=is_imputed)
725
+ except Exception as exc: # pragma: no cover
726
+ self.logger.warning(
727
+ f"Failed to queue MultiQC genotype distribution: {exc}"
728
+ )
553
729
 
554
- sklearn.decomposision.PCA: Scikit-learn PCA object from sklearn.decomposision.PCA. Any of the sklearn.decomposition.PCA attributes can be accessed from this object. See sklearn documentation.
730
+ # --------------------------------------------------------------------- #
731
+ # MultiQC helper methods #
732
+ # --------------------------------------------------------------------- #
733
+ def _multiqc_enabled(self) -> bool:
734
+ """Return True if MultiQC integration is active."""
735
+ return bool(self.use_multiqc)
736
+
737
+ def _queue_multiqc_tuning(
738
+ self,
739
+ *,
740
+ study: optuna.study.Study,
741
+ model_name: str,
742
+ target_name: str,
743
+ ) -> None:
744
+ """Queue Optuna tuning results for MultiQC.
555
745
 
556
- Examples:
557
- >>>data = GenotypeData(
558
- >>> filename="snps.str",
559
- >>> filetype="structure2row",
560
- >>> popmapfile="popmap.txt",
561
- >>>)
562
- >>>
563
- >>>ubp = ImputeUBP(genotype_data=data)
564
- >>>
565
- >>>components, pca = run_and_plot_pca(
566
- >>> data,
567
- >>> ubp,
568
- >>> scale=True,
569
- >>> center=True,
570
- >>> plot_format="png"
571
- >>>)
572
- >>>
573
- >>>explvar = pca.explained_variance_ratio\_
746
+ Args:
747
+ study (optuna.study.Study): Optuna study object.
748
+ model_name (str): Name of the model.
749
+ target_name (str): Name of the target value.
574
750
  """
575
- report_path = os.path.join(f"{prefix}_output", "plots")
576
- Path(report_path).mkdir(parents=True, exist_ok=True)
751
+ if not self._multiqc_enabled():
752
+ return
577
753
 
578
- if n_axes > 3:
579
- raise ValueError(
580
- ">3 axes is not supported; n_axes must be either 2 or 3."
581
- )
582
- if n_axes < 2:
583
- raise ValueError(
584
- "<2 axes is not supported; n_axes must be either 2 or 3."
754
+ # trial number vs objective value line graph
755
+ try:
756
+ df_trials = study.trials_dataframe(attrs=("number", "value"))
757
+ except Exception as exc: # pragma: no cover
758
+ self.logger.warning(
759
+ f"Could not extract trials_dataframe for MultiQC: {exc}"
585
760
  )
761
+ return
586
762
 
587
- imputer = imputer_object.imputed
763
+ if df_trials.empty or "value" not in df_trials:
764
+ return
588
765
 
589
- df = misc.validate_input_type(
590
- imputer.genotypes012_df, return_type="df"
591
- )
592
-
593
- original_df = misc.validate_input_type(
594
- original_genotype_data.genotypes_012(fmt="pandas"),
595
- return_type="df",
596
- )
597
-
598
- original_df.replace(-9, np.nan, inplace=True)
599
-
600
- if center or scale:
601
- # Center data to mean. Scaling to unit variance is off.
602
- scaler = StandardScaler(with_mean=center, with_std=scale)
603
- pca_df = scaler.fit_transform(df)
604
- else:
605
- pca_df = df.copy()
606
-
607
- # Run PCA.
608
- model = PCA(n_components=n_components)
609
- components = model.fit_transform(pca_df)
766
+ data: Dict[str, Dict[int, int]] = {
767
+ model_name: {
768
+ row["number"]: row["value"]
769
+ for _, row in df_trials.iterrows()
770
+ if row["value"] is not None
771
+ }
772
+ }
610
773
 
611
- df_pca = pd.DataFrame(
612
- components[:, [0, 1, 2]], columns=["Axis1", "Axis2", "Axis3"]
774
+ if not data[model_name]:
775
+ return
776
+
777
+ SNPioMultiQC.queue_linegraph(
778
+ data=data,
779
+ panel_id=f"{self.model_name}_optuna_history",
780
+ section=self.multiqc_section,
781
+ title=f"{self.model_name} Optuna Optimization History",
782
+ index_label="Trial",
783
+ description=f"Optuna optimization history for {self.model_name} "
784
+ f"(target={target_name}).",
613
785
  )
614
786
 
615
- df_pca["SampleID"] = original_genotype_data.samples
616
- df_pca["Population"] = original_genotype_data.pops
617
- df_pca["Size"] = point_size
618
-
619
- _, ind, _, _, _ = imputer.calc_missing(original_df, use_pops=False)
620
- df_pca["missPerc"] = ind
621
-
622
- my_scale = [("rgb(19, 43, 67)"), ("rgb(86,177,247)")] # ggplot default
787
+ # best-params table
788
+ try:
789
+ best_value = study.best_value
790
+ best_params = study.best_params
791
+ except Exception:
792
+ return
793
+
794
+ if best_params:
795
+ series = pd.Series(best_params, name="Best Value")
796
+ series["objective"] = best_value
797
+ SNPioMultiQC.queue_table(
798
+ df=series,
799
+ panel_id=f"{self.model_name}_optuna_best_params",
800
+ section=self.multiqc_section,
801
+ title=f"{self.model_name} Best Optuna Parameters",
802
+ index_label="Parameter",
803
+ description="Best Optuna hyperparameters and objective value.",
804
+ )
623
805
 
624
- z = "Axis3" if n_axes == 3 else None
625
- labs = {
626
- "Axis1": f"PC1 ({round(model.explained_variance_ratio_[0] * 100, 2)}%)",
627
- "Axis2": f"PC2 ({round(model.explained_variance_ratio_[1] * 100, 2)}%)",
628
- "missPerc": "Missing Prop.",
629
- "Population": "Population",
630
- }
806
+ def _queue_multiqc_roc_curves(
807
+ self,
808
+ *,
809
+ fpr: dict,
810
+ tpr: dict,
811
+ label_names: Sequence[str],
812
+ panel_prefix: str,
813
+ ) -> None:
814
+ """Queue ROC and Precision-Recall curves for MultiQC.
631
815
 
632
- if z is not None:
633
- labs[
634
- "Axis3"
635
- ] = f"PC3 ({round(model.explained_variance_ratio_[2] * 100, 2)}%)"
636
- fig = px.scatter_3d(
637
- df_pca,
638
- x="Axis1",
639
- y="Axis2",
640
- z="Axis3",
641
- color="missPerc",
642
- symbol="Population",
643
- color_continuous_scale=my_scale,
644
- custom_data=["Axis3", "SampleID", "Population", "missPerc"],
645
- size="Size",
646
- size_max=point_size,
647
- labels=labs,
648
- )
649
- else:
650
- fig = px.scatter(
651
- df_pca,
652
- x="Axis1",
653
- y="Axis2",
654
- color="missPerc",
655
- symbol="Population",
656
- color_continuous_scale=my_scale,
657
- custom_data=["Axis3", "SampleID", "Population", "missPerc"],
658
- size="Size",
659
- size_max=point_size,
660
- labels=labs,
661
- )
662
- fig.update_traces(
663
- hovertemplate="<br>".join(
664
- [
665
- "Axis 1: %{x}",
666
- "Axis 2: %{y}",
667
- "Axis 3: %{customdata[0]}",
668
- "Sample ID: %{customdata[1]}",
669
- "Population: %{customdata[2]}",
670
- "Missing Prop.: %{customdata[3]}",
671
- ]
816
+ Args:
817
+ fpr (dict): False positive rates for each class.
818
+ tpr (dict): True positive rates for each class.
819
+ label_names (Sequence[str]): Class names.
820
+ panel_prefix (str): Optional prefix for panel IDs.
821
+ """
822
+ if not self._multiqc_enabled():
823
+ return
824
+
825
+ def _curve_to_mapping(
826
+ x_vals: Sequence[float], y_vals: Sequence[float]
827
+ ) -> Dict[float, float]:
828
+ """Return {x: y} mapping expected by MultiQC linegraphs."""
829
+ return {float(x): float(y) for x, y in zip(x_vals, y_vals)}
830
+
831
+ data: Dict[str, Dict[float, float]] = {}
832
+
833
+ # Only report the first three classes (MultiQC plot readability) plus micro/macro averages
834
+ class_keys = sorted(k for k in fpr.keys() if isinstance(k, int))
835
+ for idx in class_keys[:3]:
836
+ label = label_names[idx] if idx < len(label_names) else f"Class {idx}"
837
+ data[label] = _curve_to_mapping(fpr[idx], tpr[idx])
838
+
839
+ for agg in ("micro", "macro"):
840
+ if agg in fpr and agg in tpr:
841
+ pretty_name = f"{agg.title()} Average"
842
+ data[pretty_name] = _curve_to_mapping(fpr[agg], tpr[agg])
843
+
844
+ if not data:
845
+ return
846
+
847
+ # ROC curves
848
+ curve_data = cast(Dict[str, Dict[int, int]], data)
849
+
850
+ SNPioMultiQC.queue_linegraph(
851
+ data=curve_data,
852
+ panel_id=(
853
+ f"{self.model_name}_{panel_prefix}_roc_curves"
854
+ if panel_prefix
855
+ else f"{self.model_name}_roc_curves"
672
856
  ),
857
+ section=self.multiqc_section,
858
+ title=f"{self.model_name} ROC Curves",
859
+ index_label="False Positive Rate",
860
+ description="Multi-class ROC curves for PG-SUI predictions.",
673
861
  )
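As a side note, a minimal sketch (an assumption about the caller, not code from this diff) of how the per-class fpr/tpr dictionaries consumed above are typically produced with scikit-learn's one-vs-rest ROC curves:

import numpy as np
from sklearn.metrics import roc_curve
from sklearn.preprocessing import label_binarize

rng = np.random.default_rng(0)
y_true = np.array([0, 1, 2, 1, 0, 2])   # integer class labels
y_score = rng.random((6, 3))            # per-class scores (illustrative)
y_bin = label_binarize(y_true, classes=[0, 1, 2])

fpr, tpr = {}, {}
for i in range(3):
    fpr[i], tpr[i], _ = roc_curve(y_bin[:, i], y_score[:, i])

# Micro-average over all classes, matching the "micro" key handled above.
fpr["micro"], tpr["micro"], _ = roc_curve(y_bin.ravel(), y_score.ravel())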
674
- fig.update_layout(
675
- showlegend=True,
676
- margin=dict(
677
- b=bottom_margin,
678
- t=top_margin,
679
- l=left_margin,
680
- r=right_margin,
862
+
863
+ def _queue_multiqc_pr_curves(
864
+ self,
865
+ *,
866
+ precision: dict,
867
+ recall: dict,
868
+ label_names: Sequence[str],
869
+ panel_prefix: str,
870
+ ) -> None:
871
+ """Queue Precision-Recall curves for MultiQC."""
872
+ if not self._multiqc_enabled():
873
+ return
874
+
875
+ def _curve_to_mapping(
876
+ x_vals: Sequence[float], y_vals: Sequence[float]
877
+ ) -> Dict[float, float]:
878
+ """Return {recall: precision} mapping expected by MultiQC linegraphs."""
879
+ return {float(x): float(y) for x, y in zip(x_vals, y_vals)}
880
+
881
+ data: Dict[str, Dict[float, float]] = {}
882
+
883
+ # Only report the first three classes (MultiQC plot readability) plus micro/macro averages
884
+ class_keys = sorted(k for k in recall.keys() if isinstance(k, int))
885
+ for idx in class_keys[:3]:
886
+ if idx not in precision or idx not in recall:
887
+ continue
888
+ label = label_names[idx] if idx < len(label_names) else f"Class {idx}"
889
+ data[label] = _curve_to_mapping(recall[idx], precision[idx])
890
+
891
+ for agg in ("micro", "macro"):
892
+ if agg in precision and agg in recall:
893
+ pretty_name = f"{agg.title()} Average"
894
+ data[pretty_name] = _curve_to_mapping(recall[agg], precision[agg])
895
+
896
+ if not data:
897
+ return
898
+
899
+ curve_data = cast(Dict[str, Dict[int, int]], data)
900
+
901
+ SNPioMultiQC.queue_linegraph(
902
+ data=curve_data,
903
+ panel_id=(
904
+ f"{self.model_name}_{panel_prefix}_pr_curves"
905
+ if panel_prefix
906
+ else f"{self.model_name}_pr_curves"
681
907
  ),
682
- width=width,
683
- height=height,
684
- legend_orientation="h",
685
- legend_title="Population",
686
- legend_title_font=dict(size=font_size),
687
- legend_title_side="top",
688
- font=dict(size=font_size),
689
- )
690
- fig.write_html(os.path.join(report_path, "imputed_pca.html"))
691
- fig.write_image(
692
- os.path.join(report_path, f"imputed_pca.{plot_format}"),
908
+ section=self.multiqc_section,
909
+ title=f"{self.model_name} Precision-Recall Curves",
910
+ index_label="Recall",
911
+ description="Multi-class Precision-Recall curves for PG-SUI predictions.",
693
912
  )
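Similarly, a minimal sketch (again an assumption about the caller) of the precision/recall inputs; because the helper above maps recall to precision in a plain dict, repeated recall values keep only the last precision seen, which thins out step-shaped PR curves.

import numpy as np
from sklearn.metrics import precision_recall_curve

y_bin = np.array([0, 1, 1, 0, 1])
y_score = np.array([0.1, 0.8, 0.4, 0.3, 0.9])
precision, recall, _ = precision_recall_curve(y_bin, y_score)

# Same mapping as _curve_to_mapping above: {recall: precision}.
pr_mapping = {float(r): float(p) for r, p in zip(recall, precision)}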
694
913
 
695
- return components, model
696
-
697
- @staticmethod
698
- def plot_history(lod, nn_method, prefix="imputer"):
699
- """Plot model history traces. Will be saved to file.
914
+ def _queue_multiqc_metrics(
915
+ self,
916
+ *,
917
+ metrics: Dict[str, float],
918
+ roc_auc: Dict[object, float],
919
+ average_precision: Dict[object, float],
920
+ label_names: Sequence[str],
921
+ panel_prefix: str,
922
+ ) -> None:
923
+ """Queue summary metrics and per-class AUC/AP for MultiQC.
700
924
 
701
925
  Args:
702
- lod (List[tf.keras.callbacks.History]): List of history objects.
703
- nn_method (str): Neural network method to plot. Possible options include: 'NLPCA', 'UBP', or 'VAE'. NLPCA and VAE get plotted the same, but UBP does it differently due to its three phases.
704
- prefix (str, optional): Prefix to use for output directory. Defaults to 'imputer'.
705
-
706
- Raises:
707
- ValueError: nn_method must be either 'NLPCA', 'UBP', or 'VAE'.
926
+ metrics (Dict[str, float]): Summary metrics (accuracy, F1, etc.).
927
+ roc_auc (Dict[object, float]): Per-class and aggregate ROC-AUC values.
928
+ average_precision (Dict[object, float]): Per-class and aggregate average precision values.
929
+ label_names (Sequence[str]): Class names.
930
+ panel_prefix (str): Optional prefix for panel IDs.
708
931
  """
709
- if nn_method == "NLPCA" or nn_method == "VAE" or nn_method == "SAE":
710
- title = nn_method
711
- fn = os.path.join(
712
- f"{prefix}_output",
713
- "plots",
714
- "Unsupervised",
715
- nn_method,
716
- "histplot.pdf",
932
+ if not self._multiqc_enabled():
933
+ return
934
+
935
+ # Summary metrics table (accuracy, F1, etc.)
936
+ if metrics:
937
+ series = pd.Series(metrics, name="Value")
938
+ SNPioMultiQC.queue_table(
939
+ df=series,
940
+ panel_id=f"{self.model_name}_summary_metrics",
941
+ section=self.multiqc_section,
942
+ title=f"{self.model_name} Summary Metrics",
943
+ index_label="Metric",
944
+ description="Global evaluation metrics produced by PG-SUI.",
717
945
  )
718
946
 
719
- if nn_method == "VAE":
720
- fig, axes = plt.subplots(2, 2)
721
- ax1 = axes[0, 0]
722
- ax2 = axes[0, 1]
723
- # ax3 = axes[1, 0]
724
- # ax4 = axes[1, 1]
725
- else:
726
- fig, (ax1, ax2) = plt.subplots(1, 2)
727
- fig.suptitle(title)
728
- fig.tight_layout(h_pad=3.0, w_pad=3.0)
729
- history = lod[0]
730
-
731
- acctrain = (
732
- "categorical_accuracy" if nn_method == "NLPCA" else "accuracy"
947
+ # Per-class ROC-AUC and AP heatmap
948
+ rows: List[Dict[str, float | str]] = []
949
+
950
+ # integer keys are classes; others are 'micro', 'macro'
951
+ class_keys = [k for k in roc_auc.keys() if isinstance(k, int)]
952
+ class_keys_sorted = sorted(class_keys)
953
+
954
+ for i in class_keys_sorted:
955
+ class_name = label_names[i] if i < len(label_names) else f"Class {i}"
956
+ rows.append(
957
+ {
958
+ "Class": str(class_name),
959
+ "ROC_AUC": float(roc_auc.get(i, np.nan)),
960
+ "AveragePrecision": float(average_precision.get(i, np.nan)),
961
+ }
733
962
  )
734
963
 
735
- # if nn_method == "VAE":
736
- # accval = "val_accuracy"
737
- # # recon_loss = "reconstruction_loss"
738
- # # kl_loss = "kl_loss"
739
- # # val_recon_loss = "val_reconstruction_loss"
740
- # # val_kl_loss = "val_kl_loss"
741
- # lossval = "val_loss"
964
+ for agg in ("micro", "macro"):
965
+ if agg in roc_auc:
966
+ rows.append(
967
+ {
968
+ "Class": agg,
969
+ "ROC_AUC": float(roc_auc.get(agg, np.nan)),
970
+ "AveragePrecision": float(average_precision.get(agg, np.nan)),
971
+ }
972
+ )
742
973
 
743
- if nn_method == "SAE":
744
- accval = "val_accuracy"
745
- lossval = "val_loss"
974
+ if not rows:
975
+ return
976
+
977
+ df = pd.DataFrame(rows).set_index("Class")
978
+ suffix = f"{panel_prefix}_" if panel_prefix else ""
979
+ panel_id = f"{self.model_name}_{suffix}roc_pr_summary"
980
+
981
+ SNPioMultiQC.queue_heatmap(
982
+ df=df,
983
+ panel_id=panel_id,
984
+ section=self.multiqc_section,
985
+ title=f"{self.model_name} ROC-AUC and Average Precision",
986
+ index_label="Class",
987
+ description=(
988
+ "Per-class ROC-AUC and average precision for PG-SUI predictions (including micro/macro averages where available)."
989
+ ),
990
+ )
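For context, a minimal sketch (illustrative only) of per-class ROC-AUC and average-precision dictionaries shaped like the roc_auc and average_precision arguments above, plus the Class-indexed frame that feeds the heatmap:

import numpy as np
import pandas as pd
from sklearn.metrics import average_precision_score, roc_auc_score
from sklearn.preprocessing import label_binarize

rng = np.random.default_rng(0)
y_true = np.array([0, 1, 2, 1, 0, 2])
y_score = rng.random((6, 3))
y_bin = label_binarize(y_true, classes=[0, 1, 2])

roc_auc = {i: float(roc_auc_score(y_bin[:, i], y_score[:, i])) for i in range(3)}
average_precision = {
    i: float(average_precision_score(y_bin[:, i], y_score[:, i])) for i in range(3)
}
roc_auc["macro"] = float(np.mean(list(roc_auc.values())))
average_precision["macro"] = float(np.mean(list(average_precision.values())))

rows = [
    {"Class": str(k), "ROC_AUC": roc_auc[k], "AveragePrecision": average_precision[k]}
    for k in roc_auc
]
df = pd.DataFrame(rows).set_index("Class")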
746
991
 
747
- # Plot train accuracy
748
- ax1.plot(history[acctrain])
749
- ax1.set_title("Model Accuracy")
750
- ax1.set_ylabel("Accuracy")
751
- ax1.set_xlabel("Epoch")
752
- ax1.set_ylim(bottom=0.0, top=1.0)
753
- ax1.set_yticks([0.0, 0.25, 0.5, 0.75, 1.0])
754
-
755
- labels = ["Train"]
756
- if nn_method == "SAE":
757
- # Plot validation accuracy
758
- ax1.plot(history[accval])
759
- labels.append("Validation")
760
-
761
- ax1.legend(labels, loc="best")
762
-
763
- # Plot model loss
764
- # if nn_method == "VAE":
765
- # # Reconstruction loss only.
766
- # ax2.plot(history["loss"])
767
- # ax2.plot(history[val_recon_loss])
768
-
769
- # # KL Loss
770
- # ax3.plot(history[kl_loss])
771
- # ax3.plot(history[val_kl_loss])
772
- # ax3.set_title("KL Divergence Loss")
773
- # ax3.set_ylabel("Loss")
774
- # ax3.set_xlabel("Epoch")
775
- # ax3.legend(labels, loc="best")
776
-
777
- # Total Loss (Reconstruction Loss + KL Loss)
778
- # ax4.plot(history["loss"])
779
- # ax4.plot(history[lossval])
780
- # ax4.set_title("Total Loss (Recon. + KL)")
781
- # ax4.set_ylabel("Loss")
782
- # ax4.set_xlabel("Epoch")
783
- # ax4.legend(labels, loc="best")
784
-
785
- # else:
786
- ax2.plot(history["loss"])
787
-
788
- if nn_method == "SAE":
789
- ax2.plot(history[lossval])
790
-
791
- ax2.set_title("Total Loss")
792
- ax2.set_ylabel("Loss")
793
- ax2.set_xlabel("Epoch")
794
- ax2.legend(labels, loc="best")
795
-
796
- fig.savefig(fn, bbox_inches="tight", facecolor="white")
992
+ def _queue_multiqc_history(
993
+ self,
994
+ *,
995
+ history: Dict[str, List[float] | Dict[str, List[float]] | None] | None,
996
+ ) -> None:
997
+ """Queue training history (loss vs epoch) for MultiQC.
797
998
 
798
- plt.close()
799
- plt.clf()
800
-
801
- elif nn_method == "UBP":
802
- fig = plt.figure(figsize=(12, 16))
803
- fig.suptitle(nn_method)
804
- fig.tight_layout(h_pad=2.0, w_pad=2.0)
805
- fn = os.path.join(
806
- f"{prefix}_output",
807
- "plots",
808
- "Unsupervised",
809
- nn_method,
810
- "histplot.pdf",
811
- )
999
+ Args:
1000
+ history (Dict[str, List[float] | Dict[str, List[float]] | None] | None): Training-history dictionary. The "Train" key maps to a list of per-epoch losses or, for ImputeUBP, to a nested dict keyed "Phase 1" through "Phase 3" with one loss list per phase.
1001
+ """
1002
+ if not self._multiqc_enabled() or history is None:
1003
+ return
812
1004
 
813
- idx = 1
814
- for i, history in enumerate(lod, start=1):
815
- plt.subplot(3, 2, idx)
816
- title = f"Phase {i}"
817
-
818
- # Plot model accuracy
819
- ax = plt.gca()
820
- ax.plot(history["categorical_accuracy"])
821
- ax.set_title(f"{title} Accuracy")
822
- ax.set_ylabel("Accuracy")
823
- ax.set_xlabel("Epoch")
824
- ax.set_yticks([0.0, 0.25, 0.5, 0.75, 1.0])
825
- ax.legend(["Training"], loc="best")
826
-
827
- # Plot model loss
828
- plt.subplot(3, 2, idx + 1)
829
- ax = plt.gca()
830
- ax.plot(history["loss"])
831
- ax.set_title(f"{title} Loss")
832
- ax.set_ylabel("Loss (MSE)")
833
- ax.set_xlabel("Epoch")
834
- ax.legend(["Train"], loc="best")
835
-
836
- idx += 2
837
-
838
- plt.savefig(fn, bbox_inches="tight", facecolor="white")
1005
+ data: Dict[str, Dict[int, int]] = {}
839
1006
 
840
- plt.close()
841
- plt.clf()
1007
+ if self.model_name != "ImputeUBP":
1008
+ if not isinstance(history, dict) or "Train" not in history:
1009
+ return
842
1010
 
843
- else:
844
- raise ValueError(
845
- f"nn_method must be either 'NLPCA', 'UBP', or 'VAE', but got {nn_method}"
846
- )
1011
+ train_vals = pd.Series(history["Train"]).iloc[1:]
847
1012
 
848
- @staticmethod
849
- def plot_certainty_heatmap(
850
- y_certainty, sample_ids=None, nn_method="VAE", prefix="imputer"
851
- ):
852
- fig = plt.figure()
853
- hm = sns.heatmap(
854
- data=y_certainty,
855
- cmap="viridis",
856
- vmin=0.0,
857
- vmax=1.0,
858
- cbar_kws={"label": "Prob."},
859
- )
860
- hm.set_xlabel("Site")
861
- hm.set_ylabel("Sample")
862
- hm.set_title("Probabilities of Uncertain Sites")
863
- fig.tight_layout()
864
- fig.savefig(
865
- os.path.join(
866
- f"{prefix}_output",
867
- "plots",
868
- "Unsupervised",
869
- nn_method,
870
- "uncertainty_plot.png",
871
- ),
872
- bbox_inches="tight",
873
- facecolor="white",
1013
+ data["Train"] = {
1014
+ epoch: val for epoch, val in enumerate(train_vals.values, start=1)
1015
+ }
1016
+ else:
1017
+ if not (
1018
+ isinstance(history, dict)
1019
+ and "Train" in history
1020
+ and isinstance(history["Train"], dict)
1021
+ ):
1022
+ return
1023
+ for phase in range(1, 4):
1024
+ key = f"Phase {phase}"
1025
+ if key not in history["Train"]:
1026
+ continue
1027
+ series = pd.Series(history["Train"][key]).iloc[1:]
1028
+ data[key] = {
1029
+ epoch: val for epoch, val in enumerate(series.values, start=1)
1030
+ }
1031
+
1032
+ if not data:
1033
+ return
1034
+
1035
+ SNPioMultiQC.queue_linegraph(
1036
+ data=data,
1037
+ panel_id=f"{self.model_name}_training_history",
1038
+ section=self.multiqc_section,
1039
+ title=f"{self.model_name} Training Loss per Epoch",
1040
+ index_label="Epoch",
1041
+ description="Training loss trajectory by epoch as recorded by PG-SUI.",
874
1042
  )
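For clarity, the two history shapes handled above, written out as a small sketch (the loss values are made up): most models pass a flat list of per-epoch losses under "Train", while ImputeUBP passes one loss list per training phase. Note the .iloc[1:] above, which drops the first recorded loss before plotting.

history_flat = {"Train": [1.9, 1.2, 0.9, 0.8]}

history_ubp = {
    "Train": {
        "Phase 1": [2.1, 1.5, 1.1],
        "Phase 2": [1.0, 0.8],
        "Phase 3": [0.7, 0.6, 0.55],
    }
}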
875
1043
 
876
- @staticmethod
877
- def plot_confusion_matrix(
878
- y_true_1d, y_pred_1d, nn_method, prefix="imputer"
879
- ):
880
- fig, ax = plt.subplots(1, 1, figsize=(15, 15))
881
- ConfusionMatrixDisplay.from_predictions(
882
- y_true=y_true_1d, y_pred=y_pred_1d, ax=ax
883
- )
1044
+ def _queue_multiqc_confusion(
1045
+ self,
1046
+ *,
1047
+ y_true: np.ndarray,
1048
+ y_pred: np.ndarray,
1049
+ labels: np.ndarray,
1050
+ display_labels: List[str] | np.ndarray,
1051
+ panel_id: str,
1052
+ ) -> None:
1053
+ """Queue confusion-matrix heatmap for MultiQC.
884
1054
 
885
- outfile = os.path.join(
886
- f"{prefix}_output",
887
- "plots",
888
- "Unsupervised",
889
- nn_method,
890
- f"confusion_matrix_{nn_method}.png",
1055
+ Args:
1056
+ y_true (np.ndarray): 1D array of true integer labels.
1057
+ y_pred (np.ndarray): 1D array of predicted integer labels.
1058
+ labels (np.ndarray): Array of label indices to index the confusion matrix.
1059
+ display_labels (List[str] | np.ndarray): Labels to display on axes.
1060
+ panel_id (str): Panel ID for MultiQC.
1061
+ """
1062
+ if not self._multiqc_enabled():
1063
+ return
1064
+
1065
+ cm = confusion_matrix(y_true, y_pred, labels=labels)
1066
+ df_cm = pd.DataFrame(cm, index=display_labels, columns=display_labels)
1067
+
1068
+ SNPioMultiQC.queue_heatmap(
1069
+ df=df_cm,
1070
+ panel_id=panel_id,
1071
+ section=self.multiqc_section,
1072
+ title=f"{self.model_name} Confusion Matrix",
1073
+ index_label="True Label",
1074
+ description=(
1075
+ "Confusion matrix for PG-SUI predictions. Rows correspond to true "
1076
+ "labels; columns correspond to predicted labels."
1077
+ ),
891
1078
  )
892
1079
 
893
- if os.path.isfile(outfile):
894
- os.remove(outfile)
895
-
896
- fig.savefig(outfile, facecolor="white")
897
-
898
- @staticmethod
899
- def plot_gt_distribution(df, plot_path):
900
- df = misc.validate_input_type(df, return_type="df")
901
- df_melt = pd.melt(df, value_name="Count")
902
- cnts = df_melt["Count"].value_counts()
903
- cnts.index.names = ["Genotype"]
904
- cnts = pd.DataFrame(cnts).reset_index()
905
- cnts.sort_values(by="Genotype", inplace=True)
906
- cnts["Genotype"] = cnts["Genotype"].astype(str)
907
-
908
- fig, ax = plt.subplots(1, 1, figsize=(15, 15))
909
- g = sns.barplot(x="Genotype", y="Count", data=cnts, ax=ax)
910
- g.set_xlabel("Integer-encoded Genotype")
911
- g.set_ylabel("Count")
912
- g.set_title("Genotype Counts")
913
- for p in g.patches:
914
- g.annotate(
915
- f"{p.get_height():.1f}",
916
- (p.get_x() + 0.25, p.get_height() + 0.01),
917
- xytext=(0, 1),
918
- textcoords="offset points",
919
- va="bottom",
920
- )
1080
+ def _queue_multiqc_gt_distribution(
1081
+ self,
1082
+ *,
1083
+ df: pd.DataFrame,
1084
+ is_imputed: bool,
1085
+ ) -> None:
1086
+ """Queue genotype-distribution barplot for MultiQC.
921
1087
 
922
- fig.savefig(
923
- os.path.join(plot_path, "genotype_distributions.png"),
924
- bbox_inches="tight",
925
- facecolor="white",
1088
+ Args:
1089
+ df (pd.DataFrame): DataFrame with 'Genotype' and 'Percent' columns.
1090
+ is_imputed (bool): Whether these genotypes are imputed.
1091
+ """
1092
+ if not self._multiqc_enabled():
1093
+ return
1094
+
1095
+ if "Genotype" not in df.columns or "Percent" not in df.columns:
1096
+ return
1097
+
1098
+ series = df.set_index("Genotype")["Percent"]
1099
+ suffix = "imputed" if is_imputed else "original"
1100
+ title = (
1101
+ f"{self.model_name} Imputed Genotype Distribution"
1102
+ if is_imputed
1103
+ else f"{self.model_name} Genotype Distribution"
926
1104
  )
927
- plt.close()
928
-
929
- @staticmethod
930
- def plot_label_clusters(z_mean, labels, prefix="imputer"):
931
- """Display a 2D plot of the classes in the latent space."""
932
- fig, ax = plt.subplots(1, 1, figsize=(15, 15))
933
-
934
- sns.scatterplot(x=z_mean[:, 0], y=z_mean[:, 1], ax=ax)
935
- ax.set_xlabel("Latent Dimension 1")
936
- ax.set_ylabel("Latent Dimension 2")
937
1105
 
938
- outfile = os.path.join(
939
- f"{prefix}_output",
940
- "plots",
941
- "Unsupervised",
942
- "VAE",
943
- "label_clusters.png",
1106
+ SNPioMultiQC.queue_barplot(
1107
+ df=series,
1108
+ panel_id=f"{self.model_name}_gt_distribution_{suffix}",
1109
+ section=self.multiqc_section,
1110
+ title=title,
1111
+ index_label="Genotype",
1112
+ value_label="Percent",
1113
+ description=(
1114
+ "Genotype frequency distribution (percent of calls per genotype) computed by PG-SUI."
1115
+ ),
944
1116
  )
945
-
946
- if os.path.isfile(outfile):
947
- os.remove(outfile)
948
-
949
- fig.savefig(outfile, facecolor="white", bbox_inches="tight")
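Finally, a minimal sketch (an assumption about the caller, not code from this diff) of a genotype-distribution frame with the 'Genotype' and 'Percent' columns checked for above, built from a 0/1/2-encoded matrix:

import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
geno = pd.DataFrame(rng.choice([0, 1, 2], size=(10, 50)))  # samples x loci

flat = pd.Series(geno.to_numpy().ravel())
percents = flat.value_counts(normalize=True).mul(100.0)
df = percents.rename_axis("Genotype").reset_index(name="Percent")
# df now has the 'Genotype' and 'Percent' columns expected by the helper above.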