pheval 0.4.7__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of pheval might be problematic.
- pheval/analyse/benchmark.py +156 -0
- pheval/analyse/benchmark_db_manager.py +16 -134
- pheval/analyse/benchmark_output_type.py +43 -0
- pheval/analyse/binary_classification_curves.py +132 -0
- pheval/analyse/binary_classification_stats.py +164 -307
- pheval/analyse/generate_plots.py +210 -395
- pheval/analyse/generate_rank_comparisons.py +44 -0
- pheval/analyse/rank_stats.py +190 -382
- pheval/analyse/run_data_parser.py +21 -39
- pheval/cli.py +27 -24
- pheval/cli_pheval_utils.py +7 -8
- pheval/post_processing/phenopacket_truth_set.py +250 -0
- pheval/post_processing/post_processing.py +179 -345
- pheval/post_processing/validate_result_format.py +91 -0
- pheval/prepare/update_phenopacket.py +11 -9
- pheval/utils/logger.py +35 -0
- pheval/utils/phenopacket_utils.py +85 -91
- {pheval-0.4.7.dist-info → pheval-0.5.1.dist-info}/METADATA +4 -4
- {pheval-0.4.7.dist-info → pheval-0.5.1.dist-info}/RECORD +22 -26
- pheval/analyse/analysis.py +0 -104
- pheval/analyse/assess_prioritisation_base.py +0 -108
- pheval/analyse/benchmark_generator.py +0 -126
- pheval/analyse/benchmarking_data.py +0 -25
- pheval/analyse/disease_prioritisation_analysis.py +0 -152
- pheval/analyse/gene_prioritisation_analysis.py +0 -147
- pheval/analyse/generate_summary_outputs.py +0 -105
- pheval/analyse/parse_benchmark_summary.py +0 -81
- pheval/analyse/parse_corpus.py +0 -219
- pheval/analyse/prioritisation_result_types.py +0 -52
- pheval/analyse/variant_prioritisation_analysis.py +0 -159
- {pheval-0.4.7.dist-info → pheval-0.5.1.dist-info}/LICENSE +0 -0
- {pheval-0.4.7.dist-info → pheval-0.5.1.dist-info}/WHEEL +0 -0
- {pheval-0.4.7.dist-info → pheval-0.5.1.dist-info}/entry_points.txt +0 -0
pheval/analyse/generate_plots.py
CHANGED
@@ -1,35 +1,32 @@
+from enum import Enum
 from pathlib import Path
-from typing import List

+import duckdb
 import matplotlib
-import numpy as np
-import pandas as pd
+import polars as pl
 import seaborn as sns
 from matplotlib import pyplot as plt
-from sklearn.metrics import auc, precision_recall_curve, roc_curve
+from sklearn.metrics import auc

-from pheval.analyse.benchmark_generator import (
-    BenchmarkRunOutputGenerator,
-    DiseaseBenchmarkRunOutputGenerator,
-    GeneBenchmarkRunOutputGenerator,
-    VariantBenchmarkRunOutputGenerator,
+from pheval.analyse.benchmark_db_manager import load_table_lazy
+from pheval.analyse.benchmark_output_type import (
+    BenchmarkOutputType,
+    BenchmarkOutputTypeEnum,
 )
-from pheval.analyse.benchmarking_data import BenchmarkRunResults
-
-
-
+from pheval.analyse.run_data_parser import (
+    PlotCustomisation,
+    SinglePlotCustomisation,
+    parse_run_config,
+)
+from pheval.utils.logger import get_logger

-def trim_corpus_results_directory_suffix(corpus_results_directory: Path) -> Path:
-    """
-    Trim the suffix from the corpus results directory name.
+logger = get_logger()

-    Args:
-        corpus_results_directory (Path): The directory path containing corpus results.

-    Returns:
-
-    """
-
+class PlotTypes(Enum):
+    BAR_STACKED = "bar_stacked"
+    BAR_CUMULATIVE = "bar_cumulative"
+    BAR_NON_CUMULATIVE = "bar_non_cumulative"


 class PlotGenerator:
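Note on this hunk: the import changes summarise the scope of the 0.5.x rewrite — pandas and numpy give way to polars, results are read through a duckdb connection, and the per-analysis BenchmarkRunOutputGenerator classes are replaced by BenchmarkOutputType plus explicit plot-customisation objects. The new PlotTypes enum is looked up by value from the configured plot type. A minimal sketch of that lookup (assuming only the enum defined above):

    from enum import Enum

    class PlotTypes(Enum):
        BAR_STACKED = "bar_stacked"
        BAR_CUMULATIVE = "bar_cumulative"
        BAR_NON_CUMULATIVE = "bar_non_cumulative"

    # Value lookup, as used by generate_plots() further down in this diff.
    assert PlotTypes("bar_cumulative") is PlotTypes.BAR_CUMULATIVE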
@@ -52,216 +49,132 @@ class PlotGenerator:
         """
         Initialise the PlotGenerator class.
         Note:
-            `self.stats` will be used to store statistics data.
-            `self.mrr` will store Mean Reciprocal Rank (MRR) values.
         Matplotlib settings are configured to remove the right and top axes spines
         for generated plots.
         """
         self.benchmark_name = benchmark_name
-        self.stats, self.mrr = [], []
         matplotlib.rcParams["axes.spines.right"] = False
         matplotlib.rcParams["axes.spines.top"] = False

     @staticmethod
-    def
+    def _generate_stacked_data(benchmarking_stats_df: pl.DataFrame) -> pl.DataFrame:
         """
-
-
+        Generate stacked data.
         Args:
-
-
+            benchmarking_stats_df (pl.DataFrame): benchmarking stats dataframe.
         Returns:
-
+            pl.DataFrame: Data formatted for plotting stacked data.
         """
-        return
+        return benchmarking_stats_df.with_columns(
+            [
+                pl.col("run_identifier").alias("Run"),
+                pl.col("percentage@1").alias("Top"),
+                (pl.col("percentage@3") - pl.col("percentage@1")).alias("2-3"),
+                (pl.col("percentage@5") - pl.col("percentage@3")).alias("4-5"),
+                (pl.col("percentage@10") - pl.col("percentage@5")).alias("6-10"),
+                (pl.col("percentage_found") - pl.col("percentage@10")).alias(">10"),
+                (100 - pl.col("percentage_found")).alias("Missed"),
+            ]
+        ).select(["Run", "Top", "2-3", "4-5", "6-10", ">10", "Missed"])

-    def return_benchmark_name(self, benchmark_result: BenchmarkRunResults) -> str:
+    @staticmethod
+    def _extract_mrr_data(benchmarking_results_df: pl.DataFrame) -> pl.DataFrame:
         """
-
+        Generate data in the correct format for dataframe creation for MRR (Mean Reciprocal Rank) bar plot.

         Args:
-
-
+            benchmarking_results_df (pl.DataFrame): benchmarking stats dataframe.
         Returns:
-
-        """
-        return (
-            benchmark_result.benchmark_name
-            if benchmark_result.results_dir is None
-            else self._create_run_identifier(benchmark_result.results_dir)
-        )
-
-    def _generate_stacked_bar_plot_data(self, benchmark_result: BenchmarkRunResults) -> None:
-        """
-        Generate data in the correct format for dataframe creation for a stacked bar plot,
-        appending to the self.stats attribute of the class.
-
-        Args:
-            benchmark_result (BenchmarkRunResults): The benchmarking results for a run.
+            pl.DataFrame: Data formatted for plotting MRR bar plot.
         """
-
-
-                {
-                    "Run": self.return_benchmark_name(benchmark_result),
-                    "Top": benchmark_result.rank_stats.percentage_top(),
-                    "2-3": rank_stats.percentage_difference(
-                        rank_stats.percentage_top3(), rank_stats.percentage_top()
-                    ),
-                    "4-5": rank_stats.percentage_difference(
-                        rank_stats.percentage_top5(), rank_stats.percentage_top3()
-                    ),
-                    "6-10": rank_stats.percentage_difference(
-                        rank_stats.percentage_top10(), rank_stats.percentage_top5()
-                    ),
-                    ">10": rank_stats.percentage_difference(
-                        rank_stats.percentage_found(), rank_stats.percentage_top10()
-                    ),
-                    "Missed": rank_stats.percentage_difference(100, rank_stats.percentage_found()),
-                }
+        return benchmarking_results_df.select(["run_identifier", "mrr"]).rename(
+            {"run_identifier": "Run", "mrr": "Percentage"}
         )

-    def _generate_stats_mrr_bar_plot_data(self, benchmark_result: BenchmarkRunResults) -> None:
+    def _save_fig(
+        self, benchmark_output_type: BenchmarkOutputType, y_lower_limit: int, y_upper_limit: int
+    ) -> None:
         """
-        Generate data in the correct format for dataframe creation for MRR (Mean Reciprocal Rank) bar plot,
-        appending to the self.mrr attribute of the class.
-
+        Save the generated figure.
         Args:
-            benchmark_result (BenchmarkRunResults): The benchmarking results for a run.
+            benchmark_output_type (BenchmarkOutputType): Benchmark output type.
+            y_lower_limit (int): Lower limit for the y-axis.
+            y_upper_limit (int): Upper limit for the y-axis.
         """
-
-
-
-
-
-                    "Run": self.return_benchmark_name(benchmark_result),
-                }
-            ]
+        plt.ylim(y_lower_limit, y_upper_limit)
+        plt.savefig(
+            f"{self.benchmark_name}_{benchmark_output_type.prioritisation_type_string}_rank_stats.svg",
+            format="svg",
+            bbox_inches="tight",
         )

     def generate_stacked_bar_plot(
         self,
-        benchmarking_results: List[BenchmarkRunResults],
-        benchmark_generator: BenchmarkRunOutputGenerator,
+        benchmarking_results_df: pl.DataFrame,
+        benchmark_output_type: BenchmarkOutputType,
+        plot_customisation: SinglePlotCustomisation,
     ) -> None:
         """
         Generate a stacked bar plot and Mean Reciprocal Rank (MRR) bar plot.
-
         Args:
-            benchmarking_results (List[BenchmarkRunResults]): List of benchmarking results for multiple runs.
-            benchmark_generator (BenchmarkRunOutputGenerator): Object containing benchmarking output generation details.
+            benchmarking_results_df (pl.DataFrame): benchmarking stats dataframe.
+            benchmark_output_type (BenchmarkOutputType): Benchmark output type.
+            plot_customisation (SinglePlotCustomisation): Plotting customisation.
         """
-        for benchmark_result in benchmarking_results:
-            self._generate_stacked_bar_plot_data(benchmark_result)
-            self._generate_stats_mrr_bar_plot_data(benchmark_result)
-        stats_df = pd.DataFrame(self.stats)
         plt.clf()
-        stats_df.set_index("Run").plot(
+        stats_df = self._generate_stacked_data(benchmarking_results_df)
+        stats_df.to_pandas().set_index("Run").plot(
             kind="bar",
             stacked=True,
             color=self.palette_hex_codes,
-            ylabel=benchmark_generator.y_label,
+            ylabel=benchmark_output_type.y_label,
             edgecolor="white",
         ).legend(loc="center left", bbox_to_anchor=(1.0, 0.5))
-        if benchmark_generator.plot_customisation.rank_plot_title is None:
-
-        else:
-            plt.title(
-                benchmark_generator.plot_customisation.rank_plot_title, loc="center", fontsize=15
-            )
-        plt.ylim(0, 100)
-        plt.savefig(
-            f"{self.benchmark_name}_{benchmark_generator.prioritisation_type_string}_rank_stats.svg",
-            format="svg",
-            bbox_inches="tight",
-        )
-
-        mrr_df = pd.DataFrame(self.mrr)
-        mrr_df.set_index("Run").plot(
+        plt.title(plot_customisation.rank_plot_title, loc="center", fontsize=15)
+        self._save_fig(benchmark_output_type, 0, 100)
+        mrr_df = self._extract_mrr_data(benchmarking_results_df)
+        mrr_df.to_pandas().set_index("Run").plot(
             kind="bar",
             color=self.palette_hex_codes,
-            ylabel=f"{benchmark_generator.prioritisation_type_string.capitalize()} mean reciprocal rank",
+            ylabel=f"{benchmark_output_type.prioritisation_type_string.capitalize()} mean reciprocal rank",
             legend=False,
             edgecolor="white",
         )
         plt.title(
-            f"{benchmark_generator.prioritisation_type_string.capitalize()} results - mean reciprocal rank"
-        )
-        plt.ylim(0, 1)
-        plt.savefig(
-            f"{self.benchmark_name}_{benchmark_generator.prioritisation_type_string}_mrr.svg",
-            format="svg",
-            bbox_inches="tight",
+            f"{benchmark_output_type.prioritisation_type_string.capitalize()} results - mean reciprocal rank"
         )
+        self._save_fig(benchmark_output_type, 0, 1)

-
+    @staticmethod
+    def _generate_cumulative_bar_plot_data(benchmarking_results_df: pl.DataFrame) -> pl.DataFrame:
         """
         Generate data in the correct format for dataframe creation for a cumulative bar plot,
         appending to the self.stats attribute of the class.
-
-        Args:
-            benchmark_result (BenchmarkRunResults): The benchmarking results for a run.
         """
-
-        run_identifier = self.return_benchmark_name(benchmark_result)
-        self.stats.extend(
+        return benchmarking_results_df.select(
             [
-                {
-                    "Rank": "Top",
-                    "Percentage": rank_stats.percentage_top() / 100,
-                    "Run": run_identifier,
-                },
-                {
-                    "Rank": "Top3",
-                    "Percentage": rank_stats.percentage_top3() / 100,
-                    "Run": run_identifier,
-                },
-                {
-                    "Rank": "Top5",
-                    "Percentage": rank_stats.percentage_top5() / 100,
-                    "Run": run_identifier,
-                },
-                {
-                    "Rank": "Top10",
-                    "Percentage": rank_stats.percentage_top10() / 100,
-                    "Run": run_identifier,
-                },
-                {
-                    "Rank": "Found",
-                    "Percentage": rank_stats.percentage_found() / 100,
-                    "Run": run_identifier,
-                },
-                {
-                    "Rank": "Missed",
-                    "Percentage": rank_stats.percentage_difference(
-                        100, rank_stats.percentage_found()
-                    )
-                    / 100,
-                    "Run": run_identifier,
-                },
-                {
-                    "Rank": "MRR",
-                    "Percentage": rank_stats.return_mean_reciprocal_rank(),
-                    "Run": run_identifier,
-                },
+                pl.col("run_identifier").alias("Run"),
+                pl.col("percentage@1").alias("Top") / 100,
+                pl.col("percentage@3").alias("Top3") / 100,
+                pl.col("percentage@5").alias("Top5") / 100,
+                pl.col("percentage@10").alias("Top10") / 100,
+                pl.col("percentage_found").alias("Found") / 100,
+                pl.col("mrr").alias("MRR"),
             ]
         )

-    def generate_cumulative_bar(
+    def _plot_bar_plot(
         self,
-        benchmarking_results: List[BenchmarkRunResults],
-        benchmark_generator: BenchmarkRunOutputGenerator,
+        benchmark_output_type: BenchmarkOutputType,
+        stats_df: pl.DataFrame,
+        plot_customisation: SinglePlotCustomisation,
     ) -> None:
-        """
-        Generate a cumulative bar plot.
-
-        Args:
-            benchmarking_results (List[BenchmarkRunResults]): List of benchmarking results for multiple runs.
-            benchmark_generator (BenchmarkRunOutputGenerator): Object containing benchmarking output generation details.
-        """
-        for benchmark_result in benchmarking_results:
-            self._generate_cumulative_bar_plot_data(benchmark_result)
-        stats_df = pd.DataFrame(self.stats)
-        plt.clf()
+        stats_df = stats_df.to_pandas().melt(
+            id_vars=["Run"],
+            value_vars=["Top", "Top3", "Top5", "Top10", "Found", "MRR"],
+            var_name="Rank",
+            value_name="Percentage",
+        )
         sns.catplot(
             data=stats_df,
             kind="bar",
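Note on this hunk: the stateful self.stats/self.mrr accumulation is replaced by pure helpers that derive plot-ready frames straight from the summary table. A toy sketch (invented numbers, using only the columns named in the diff) of what _generate_stacked_data produces:

    import polars as pl

    summary = pl.DataFrame(
        {
            "run_identifier": ["run_a"],
            "percentage@1": [50.0],
            "percentage@3": [60.0],
            "percentage@5": [70.0],
            "percentage@10": [80.0],
            "percentage_found": [90.0],
        }
    )
    stacked = summary.with_columns(
        [
            pl.col("run_identifier").alias("Run"),
            pl.col("percentage@1").alias("Top"),
            (pl.col("percentage@3") - pl.col("percentage@1")).alias("2-3"),
            (pl.col("percentage@5") - pl.col("percentage@3")).alias("4-5"),
            (pl.col("percentage@10") - pl.col("percentage@5")).alias("6-10"),
            (pl.col("percentage_found") - pl.col("percentage@10")).alias(">10"),
            (100 - pl.col("percentage_found")).alias("Missed"),
        ]
    ).select(["Run", "Top", "2-3", "4-5", "6-10", ">10", "Missed"])
    # run_a -> Top=50, 2-3=10, 4-5=10, 6-10=10, >10=10, Missed=10: the bins sum
    # to 100, which is why the stacked bar plot is saved with y-limits (0, 100).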
@@ -271,132 +184,77 @@ class PlotGenerator:
             palette=self.palette_hex_codes,
             edgecolor="white",
             legend=False,
-        ).set(xlabel="Rank", ylabel=benchmark_generator.y_label)
+        ).set(xlabel="Rank", ylabel=benchmark_output_type.y_label)
         plt.legend(loc="upper center", bbox_to_anchor=(0.5, -0.15), ncol=3, title="Run")
-        if benchmark_generator.plot_customisation.rank_plot_title is None:
-            plt.title(
-                f"{benchmark_generator.prioritisation_type_string.capitalize()} Cumulative Rank Stats"
-            )
-        else:
-            plt.title(
-                benchmark_generator.plot_customisation.rank_plot_title, loc="center", fontsize=15
-            )
-        plt.ylim(0, 1)
-        plt.savefig(
-            f"{self.benchmark_name}_{benchmark_generator.prioritisation_type_string}_rank_stats.svg",
-            format="svg",
-            bbox_inches="tight",
-        )
+        plt.title(plot_customisation.rank_plot_title, loc="center", fontsize=15)
+        self._save_fig(benchmark_output_type, 0, 1)

     def _generate_non_cumulative_bar_plot_data(
-        self, benchmark_result: BenchmarkRunResults
-    ) -> None:
+        self, benchmarking_results_df: pl.DataFrame
+    ) -> pl.DataFrame:
         """
         Generate data in the correct format for dataframe creation for a non-cumulative bar plot,
         appending to the self.stats attribute of the class.
-
-        Args:
-            benchmark_result (BenchmarkRunResults): The benchmarking results for a run.
         """
-
-        run_identifier = self.return_benchmark_name(benchmark_result)
-        self.stats.extend(
-            [
-                {
-                    "Rank": "Top",
-                    "Percentage": rank_stats.percentage_top() / 100,
-                    "Run": run_identifier,
-                },
-                {
-                    "Rank": "2-3",
-                    "Percentage": rank_stats.percentage_difference(
-                        rank_stats.percentage_top3(), rank_stats.percentage_top()
-                    )
-                    / 100,
-                    "Run": run_identifier,
-                },
-                {
-                    "Rank": "4-5",
-                    "Percentage": rank_stats.percentage_difference(
-                        rank_stats.percentage_top5(), rank_stats.percentage_top3()
-                    )
-                    / 100,
-                    "Run": run_identifier,
-                },
-                {
-                    "Rank": "6-10",
-                    "Percentage": rank_stats.percentage_difference(
-                        rank_stats.percentage_top10(), rank_stats.percentage_top5()
-                    )
-                    / 100,
-                    "Run": run_identifier,
-                },
-                {
-                    "Rank": ">10",
-                    "Percentage": rank_stats.percentage_difference(
-                        rank_stats.percentage_found(), rank_stats.percentage_top10()
-                    )
-                    / 100,
-                    "Run": run_identifier,
-                },
-                {
-                    "Rank": "Missed",
-                    "Percentage": rank_stats.percentage_difference(
-                        100, rank_stats.percentage_found()
-                    )
-                    / 100,
-                    "Run": run_identifier,
-                },
-                {
-                    "Rank": "MRR",
-                    "Percentage": rank_stats.return_mean_reciprocal_rank(),
-                    "Run": run_identifier,
-                },
-            ]
+        return self._generate_stacked_data(benchmarking_results_df).hstack(
+            self._extract_mrr_data(benchmarking_results_df).select(
+                pl.col("Percentage").alias("MRR")
+            )
         )

+    def generate_cumulative_bar(
+        self,
+        benchmarking_results_df: pl.DataFrame,
+        benchmark_generator: BenchmarkOutputType,
+        plot_customisation: SinglePlotCustomisation,
+    ) -> None:
+        """
+        Generate a cumulative bar plot.
+        """
+        plt.clf()
+        stats_df = self._generate_cumulative_bar_plot_data(benchmarking_results_df)
+        self._plot_bar_plot(benchmark_generator, stats_df, plot_customisation)
+
+    def generate_non_cumulative_bar(
+        self,
+        benchmarking_results_df: pl.DataFrame,
+        benchmark_generator: BenchmarkOutputType,
+        plot_customisation: SinglePlotCustomisation,
+    ) -> None:
+        """
+        Generate a non-cumulative bar plot.
+        """
+        plt.clf()
+        stats_df = self._generate_non_cumulative_bar_plot_data(benchmarking_results_df)
+        self._plot_bar_plot(benchmark_generator, stats_df, plot_customisation)
+
     def generate_roc_curve(
         self,
-        benchmarking_results: List[BenchmarkRunResults],
-        benchmark_generator: BenchmarkRunOutputGenerator,
+        curves: pl.DataFrame,
+        benchmark_generator: BenchmarkOutputType,
+        plot_customisation: SinglePlotCustomisation,
     ):
         """
         Generate and plot Receiver Operating Characteristic (ROC) curves for binary classification benchmark results.

         Args:
-            benchmarking_results (List[BenchmarkRunResults]): List of benchmarking results for multiple runs.
-            benchmark_generator (BenchmarkRunOutputGenerator): Object containing benchmarking output generation details.
         """
         plt.clf()
-        for i, benchmark_result in enumerate(benchmarking_results):
-
-
-
-                nan=0.0,
-                posinf=max(y_score[np.isfinite(y_score)]),
-                neginf=min(y_score[np.isfinite(y_score)]),
-            )
-            fpr, tpr, thresh = roc_curve(
-                benchmark_result.binary_classification_stats.labels,
-                y_score,
-                pos_label=1,
-            )
+        for i, row in enumerate(curves.iter_rows(named=True)):
+            run_identifier = row["run_identifier"]
+            fpr = row["fpr"]
+            tpr = row["tpr"]
             roc_auc = auc(fpr, tpr)
-
             plt.plot(
                 fpr,
                 tpr,
-                label=f"{
+                label=f"{run_identifier} ROC Curve (AUC = {roc_auc:.2f})",
                 color=self.palette_hex_codes[i],
             )
-
-        plt.plot(linestyle="--", color="gray")
+        plt.plot([0, 1], [0, 1], linestyle="--", color="gray")
         plt.xlabel("False Positive Rate")
         plt.ylabel("True Positive Rate")
-        if benchmark_generator.plot_customisation.roc_curve_title is None:
-            plt.title("Receiver Operating Characteristic (ROC) Curve")
-        else:
-            plt.title(benchmark_generator.plot_customisation.roc_curve_title)
+        plt.title(plot_customisation.roc_curve_title)
         plt.legend(loc="upper center", bbox_to_anchor=(0.5, -0.15))
         plt.savefig(
             f"{self.benchmark_name}_{benchmark_generator.prioritisation_type_string}_roc_curve.svg",
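Note on this hunk: the non-cumulative frame is now composed from the two static helpers — the rank-bin columns are horizontally stacked with the MRR column — and both public bar methods funnel into _plot_bar_plot, which melts the wide frame into the long Run/Rank/Percentage shape seaborn expects. A toy sketch of that composition (invented values, assuming the helpers shown in this diff):

    import polars as pl

    bins = pl.DataFrame({"Run": ["run_a"], "Top": [50.0], "Missed": [10.0]})  # toy subset of bins
    mrr = pl.DataFrame({"Run": ["run_a"], "Percentage": [0.62]})
    combined = bins.hstack(mrr.select(pl.col("Percentage").alias("MRR")))
    # combined columns: Run, Top, Missed, MRR — one row per run, ready for the
    # wide-to-long melt performed in _plot_bar_plot before sns.catplot.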
@@ -406,46 +264,30 @@ class PlotGenerator:

     def generate_precision_recall(
         self,
-        benchmarking_results: List[BenchmarkRunResults],
-        benchmark_generator: BenchmarkRunOutputGenerator,
+        curves: pl.DataFrame,
+        benchmark_generator: BenchmarkOutputType,
+        plot_customisation: SinglePlotCustomisation,
     ):
         """
         Generate and plot Precision-Recall curves for binary classification benchmark results.
-
-        Args:
-            benchmarking_results (List[BenchmarkRunResults]): List of benchmarking results for multiple runs.
-            benchmark_generator (BenchmarkRunOutputGenerator): Object containing benchmarking output generation details.
         """
         plt.clf()
         plt.figure()
-        for i, benchmark_result in enumerate(benchmarking_results):
-
-
-
-
-                posinf=max(y_score[np.isfinite(y_score)]),
-                neginf=min(y_score[np.isfinite(y_score)]),
-            )
-            precision, recall, thresh = precision_recall_curve(
-                benchmark_result.binary_classification_stats.labels,
-                y_score,
-            )
-            precision_recall_auc = auc(recall, precision)
+        for i, row in enumerate(curves.iter_rows(named=True)):
+            run_identifier = row["run_identifier"]
+            precision = row["precision"]
+            recall = row["recall"]
+            pr_auc = auc(recall[::-1], precision[::-1])
             plt.plot(
                 recall,
                 precision,
-                label=f"{
-                f"(AUC = {precision_recall_auc:.2f})",
+                label=f"{run_identifier} Precision-Recall Curve (AUC = {pr_auc:.2f})",
                 color=self.palette_hex_codes[i],
             )
-
-        plt.plot(linestyle="--", color="gray")
+        plt.plot([0, 1], [0, 1], linestyle="--", color="gray")
         plt.xlabel("Recall")
         plt.ylabel("Precision")
-        if benchmark_generator.plot_customisation.precision_recall_title is None:
-            plt.title("Precision-Recall Curve")
-        else:
-            plt.title(benchmark_generator.plot_customisation.precision_recall_title)
+        plt.title(plot_customisation.precision_recall_title)
         plt.legend(loc="upper center", bbox_to_anchor=(0.5, -0.15))
         plt.savefig(
             f"{self.benchmark_name}_{benchmark_generator.prioritisation_type_string}_pr_curve.svg",
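Note on this hunk: curves are no longer recomputed from raw scores with precision_recall_curve; precomputed precision/recall arrays are read from the curves table instead. The [::-1] reversal puts the stored recall values (typically emitted in decreasing order by precision_recall_curve) into increasing order before sklearn.metrics.auc, which requires a monotonic x array, computes the trapezoidal area. A minimal sketch with toy arrays:

    from sklearn.metrics import auc

    recall = [1.0, 0.5, 0.0]        # decreasing, as precision_recall_curve emits it
    precision = [0.5, 0.75, 1.0]
    pr_auc = auc(recall[::-1], precision[::-1])  # trapezoidal area over increasing recall
    print(f"{pr_auc:.2f}")  # 0.75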
@@ -453,112 +295,85 @@ class PlotGenerator:
             bbox_inches="tight",
         )

-    def generate_non_cumulative_bar(
-        self,
-        benchmarking_results: List[BenchmarkRunResults],
-        benchmark_generator: BenchmarkRunOutputGenerator,
-    ) -> None:
-        """
-        Generate a non-cumulative bar plot.
-
-        Args:
-            benchmarking_results (List[BenchmarkRunResults]): List of benchmarking results for multiple runs.
-            benchmark_generator (BenchmarkRunOutputGenerator): Object containing benchmarking output generation details.
-        """
-        plt.clf()
-        for benchmark_result in benchmarking_results:
-            self._generate_non_cumulative_bar_plot_data(benchmark_result)
-
-        stats_df = pd.DataFrame(self.stats)
-        sns.catplot(
-            data=stats_df,
-            kind="bar",
-            x="Rank",
-            y="Percentage",
-            hue="Run",
-            palette=self.palette_hex_codes,
-            edgecolor="white",
-            legend=False,
-        ).set(xlabel="Rank", ylabel=benchmark_generator.y_label)
-        plt.legend(loc="upper center", bbox_to_anchor=(0.5, -0.15), ncol=3, title="Run")
-        if benchmark_generator.plot_customisation.rank_plot_title is None:
-            plt.title(
-                f"{benchmark_generator.prioritisation_type_string.capitalize()} Non-Cumulative Rank Stats"
-            )
-        else:
-            plt.title(
-                benchmark_generator.plot_customisation.rank_plot_title, loc="center", fontsize=15
-            )
-        plt.ylim(0, 1)
-        plt.savefig(
-            f"{self.benchmark_name}_{benchmark_generator.prioritisation_type_string}_rank_stats.svg",
-            format="svg",
-            bbox_inches="tight",
-        )
-

 def generate_plots(
     benchmark_name: str,
-    benchmarking_results: List[BenchmarkRunResults],
-    benchmark_generator: BenchmarkRunOutputGenerator,
-    generate_from_db: bool = False,
+    benchmarking_results_df: pl.DataFrame,
+    curves: pl.DataFrame,
+    benchmark_output_type: BenchmarkOutputType,
+    plot_customisation: PlotCustomisation,
 ) -> None:
     """
     Generate summary statistics bar plots for prioritisation.

     This method generates summary statistics bar plots based on the provided benchmarking results and plot type.
-
-    Args:
-        benchmarking_results (list[BenchmarkRunResults]): List of benchmarking results for multiple runs.
-        benchmark_generator (BenchmarkRunOutputGenerator): Object containing benchmarking output generation details.
-        generate_from_db (bool): Specify whether to generate plots from the db file. Defaults to False.
     """
     plot_generator = PlotGenerator(benchmark_name)
-
-
-
-
-
-
-
-
-
+    plot_customisation_type = getattr(
+        plot_customisation, f"{benchmark_output_type.prioritisation_type_string}_plots"
+    )
+    logger.info("Generating ROC curve visualisations.")
+    plot_generator.generate_roc_curve(curves, benchmark_output_type, plot_customisation_type)
+    logger.info("Generating Precision-Recall curves visualisations.")
+    plot_generator.generate_precision_recall(curves, benchmark_output_type, plot_customisation_type)
+    plot_type = PlotTypes(plot_customisation_type.plot_type)
+    match plot_type:
+        case PlotTypes.BAR_STACKED:
+            logger.info("Generating stacked bar plot.")
+            plot_generator.generate_stacked_bar_plot(
+                benchmarking_results_df, benchmark_output_type, plot_customisation_type
+            )
+        case PlotTypes.BAR_CUMULATIVE:
+            logger.info("Generating cumulative bar plot.")
+            plot_generator.generate_cumulative_bar(
+                benchmarking_results_df, benchmark_output_type, plot_customisation_type
+            )
+        case PlotTypes.BAR_NON_CUMULATIVE:
+            logger.info("Generating non cumulative bar plot.")
+            plot_generator.generate_non_cumulative_bar(
+                benchmarking_results_df, benchmark_output_type, plot_customisation_type
+            )


-def generate_plots_from_benchmark_summary_db(
-    benchmark_db: Path,
-    run_data: Path,
-):
+def generate_plots_from_db(db_path: Path, config: Path) -> None:
     """
-    Generate
-
-    Reads a summary of benchmark results from a benchmark db and generates a bar plot
-    based on the analysis type and plot type.
-
+    Generate plots from database file.
     Args:
-
-
+        db_path (Path): Path to the database file.
+        config (Path): Path to the benchmarking config file.
     """
-    benchmark_stats_summary = parse_benchmark_db(benchmark_db)
-    config = parse_run_config(run_data)
-    if benchmark_stats_summary.gene_results:
-        generate_plots(
-            config.benchmark_name,
-            benchmark_stats_summary.gene_results,
-            GeneBenchmarkRunOutputGenerator(config.plot_customisation.gene_plots),
-            True,
+    logger.info(f"Generating plots from {db_path}")
+    conn = duckdb.connect(db_path)
+    logger.info(f"Parsing configurations from {config}")
+    benchmark_config_file = parse_run_config(config)
+    tables = {
+        row[0]
+        for row in conn.execute(
+            """SELECT table_name FROM duckdb_tables WHERE table_name """
+            """LIKE '%_summary%' OR table_name LIKE '%_binary_classification_curves'"""
+        ).fetchall()
+    }
+    for benchmark_output_type in BenchmarkOutputTypeEnum:
+        summary_table = (
+            f"{benchmark_config_file.benchmark_name}_"
+            f"{benchmark_output_type.value.prioritisation_type_string}_summary"
         )
-    elif benchmark_stats_summary.variant_results:
-        generate_plots(
-            config.benchmark_name,
-            benchmark_stats_summary.variant_results,
-            VariantBenchmarkRunOutputGenerator(config.plot_customisation.variant_plots),
-            True,
-        )
-    elif benchmark_stats_summary.disease_results:
-        generate_plots(
-            config.benchmark_name,
-            benchmark_stats_summary.disease_results,
-            DiseaseBenchmarkRunOutputGenerator(config.plot_customisation.disease_plots),
-            True,
+        curve_table = (
+            f"{benchmark_config_file.benchmark_name}_"
+            f"{benchmark_output_type.value.prioritisation_type_string}_binary_classification_curves"
         )
+        if summary_table in tables and curve_table in tables:
+            logger.info(
+                f"Generating plots for {benchmark_output_type.value.prioritisation_type_string} prioritisation."
+            )
+            benchmarking_results_df = load_table_lazy(summary_table, conn).collect()
+            curves_df = load_table_lazy(curve_table, conn).collect()
+            generate_plots(
+                benchmark_name=benchmark_config_file.benchmark_name,
+                benchmarking_results_df=benchmarking_results_df,
+                curves=curves_df,
+                benchmark_output_type=benchmark_output_type.value,
+                plot_customisation=benchmark_config_file.plot_customisation,
+            )
+    logger.info("Finished generating plots.")
+    conn.close()
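Note on this hunk: generate_plots_from_db replaces the old gene/variant/disease dispatch with table discovery — it lists the *_summary and *_binary_classification_curves tables present in the DuckDB file and only plots the prioritisation types for which both exist. A hypothetical invocation (file names are placeholders, not from this diff):

    from pathlib import Path

    from pheval.analyse.generate_plots import generate_plots_from_db

    generate_plots_from_db(
        db_path=Path("benchmark.duckdb"),          # DuckDB file produced by the benchmark run
        config=Path("benchmarking_config.yaml"),   # same config used to run the benchmark
    )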