pheval 0.4.7__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff compares the contents of two publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in their public registries.

Files changed (33)
  1. pheval/analyse/benchmark.py +156 -0
  2. pheval/analyse/benchmark_db_manager.py +16 -134
  3. pheval/analyse/benchmark_output_type.py +43 -0
  4. pheval/analyse/binary_classification_curves.py +132 -0
  5. pheval/analyse/binary_classification_stats.py +164 -307
  6. pheval/analyse/generate_plots.py +210 -395
  7. pheval/analyse/generate_rank_comparisons.py +44 -0
  8. pheval/analyse/rank_stats.py +190 -382
  9. pheval/analyse/run_data_parser.py +21 -39
  10. pheval/cli.py +27 -24
  11. pheval/cli_pheval_utils.py +7 -8
  12. pheval/post_processing/phenopacket_truth_set.py +250 -0
  13. pheval/post_processing/post_processing.py +179 -345
  14. pheval/post_processing/validate_result_format.py +91 -0
  15. pheval/prepare/update_phenopacket.py +11 -9
  16. pheval/utils/logger.py +35 -0
  17. pheval/utils/phenopacket_utils.py +85 -91
  18. {pheval-0.4.7.dist-info → pheval-0.5.1.dist-info}/METADATA +4 -4
  19. {pheval-0.4.7.dist-info → pheval-0.5.1.dist-info}/RECORD +22 -26
  20. pheval/analyse/analysis.py +0 -104
  21. pheval/analyse/assess_prioritisation_base.py +0 -108
  22. pheval/analyse/benchmark_generator.py +0 -126
  23. pheval/analyse/benchmarking_data.py +0 -25
  24. pheval/analyse/disease_prioritisation_analysis.py +0 -152
  25. pheval/analyse/gene_prioritisation_analysis.py +0 -147
  26. pheval/analyse/generate_summary_outputs.py +0 -105
  27. pheval/analyse/parse_benchmark_summary.py +0 -81
  28. pheval/analyse/parse_corpus.py +0 -219
  29. pheval/analyse/prioritisation_result_types.py +0 -52
  30. pheval/analyse/variant_prioritisation_analysis.py +0 -159
  31. {pheval-0.4.7.dist-info → pheval-0.5.1.dist-info}/LICENSE +0 -0
  32. {pheval-0.4.7.dist-info → pheval-0.5.1.dist-info}/WHEEL +0 -0
  33. {pheval-0.4.7.dist-info → pheval-0.5.1.dist-info}/entry_points.txt +0 -0
@@ -1,35 +1,32 @@
+ from enum import Enum
  from pathlib import Path
- from typing import List

+ import duckdb
  import matplotlib
- import numpy as np
- import pandas as pd
+ import polars as pl
  import seaborn as sns
  from matplotlib import pyplot as plt
- from sklearn.metrics import auc, precision_recall_curve, roc_curve
+ from sklearn.metrics import auc

- from pheval.analyse.benchmark_generator import (
-     BenchmarkRunOutputGenerator,
-     DiseaseBenchmarkRunOutputGenerator,
-     GeneBenchmarkRunOutputGenerator,
-     VariantBenchmarkRunOutputGenerator,
+ from pheval.analyse.benchmark_db_manager import load_table_lazy
+ from pheval.analyse.benchmark_output_type import (
+     BenchmarkOutputType,
+     BenchmarkOutputTypeEnum,
  )
- from pheval.analyse.benchmarking_data import BenchmarkRunResults
- from pheval.analyse.parse_benchmark_summary import parse_benchmark_db
- from pheval.analyse.run_data_parser import parse_run_config
-
+ from pheval.analyse.run_data_parser import (
+     PlotCustomisation,
+     SinglePlotCustomisation,
+     parse_run_config,
+ )
+ from pheval.utils.logger import get_logger

- def trim_corpus_results_directory_suffix(corpus_results_directory: Path) -> Path:
-     """
-     Trim the suffix from the corpus results directory name.
+ logger = get_logger()

-     Args:
-         corpus_results_directory (Path): The directory path containing corpus results.

-     Returns:
-         Path: The Path object with the suffix removed from the directory name.
-     """
-     return Path(str(corpus_results_directory).replace("_results", ""))
+ class PlotTypes(Enum):
+     BAR_STACKED = "bar_stacked"
+     BAR_CUMULATIVE = "bar_cumulative"
+     BAR_NON_CUMULATIVE = "bar_non_cumulative"


  class PlotGenerator:
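
Note: the new PlotTypes enum replaces the bare string comparisons the old generate_plots used for dispatch (see the match statement further down in this file). A minimal sketch of the lookup, assuming only standard-library Enum behaviour; the values come from the diff above:

from enum import Enum

class PlotTypes(Enum):
    BAR_STACKED = "bar_stacked"
    BAR_CUMULATIVE = "bar_cumulative"
    BAR_NON_CUMULATIVE = "bar_non_cumulative"

# A plot_type string from the run config maps onto an enum member...
plot_type = PlotTypes("bar_stacked")  # -> PlotTypes.BAR_STACKED
# ...and an unrecognised value now fails fast with a ValueError
# instead of falling through an if/elif chain with no plot produced.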
@@ -52,216 +49,132 @@ class PlotGenerator:
          """
          Initialise the PlotGenerator class.
          Note:
-             `self.stats` will be used to store statistics data.
-             `self.mrr` will store Mean Reciprocal Rank (MRR) values.
          Matplotlib settings are configured to remove the right and top axes spines
          for generated plots.
          """
          self.benchmark_name = benchmark_name
-         self.stats, self.mrr = [], []
          matplotlib.rcParams["axes.spines.right"] = False
          matplotlib.rcParams["axes.spines.top"] = False

      @staticmethod
-     def _create_run_identifier(results_dir: Path) -> str:
+     def _generate_stacked_data(benchmarking_stats_df: pl.DataFrame) -> pl.DataFrame:
          """
-         Create a run identifier from a path.
-
+         Generate stacked data.
          Args:
-             results_dir (Path): The directory path for results.
-
+             benchmarking_stats_df (pl.DataFrame): benchmarking stats dataframe.
          Returns:
-             str: A string representing the run identifier created from the given path.
+             pl.DataFrame: Data formatted for plotting stacked data.
          """
-         return f"{Path(results_dir).parents[0].name}_{trim_corpus_results_directory_suffix(Path(results_dir).name)}"
+         return benchmarking_stats_df.with_columns(
+             [
+                 pl.col("run_identifier").alias("Run"),
+                 pl.col("percentage@1").alias("Top"),
+                 (pl.col("percentage@3") - pl.col("percentage@1")).alias("2-3"),
+                 (pl.col("percentage@5") - pl.col("percentage@3")).alias("4-5"),
+                 (pl.col("percentage@10") - pl.col("percentage@5")).alias("6-10"),
+                 (pl.col("percentage_found") - pl.col("percentage@10")).alias(">10"),
+                 (100 - pl.col("percentage_found")).alias("Missed"),
+             ]
+         ).select(["Run", "Top", "2-3", "4-5", "6-10", ">10", "Missed"])

-     def return_benchmark_name(self, benchmark_result: BenchmarkRunResults) -> str:
+     @staticmethod
+     def _extract_mrr_data(benchmarking_results_df: pl.DataFrame) -> pl.DataFrame:
          """
-         Return the benchmark name for a run.
+         Generate data in the correct format for dataframe creation for MRR (Mean Reciprocal Rank) bar plot.

          Args:
-             benchmark_result (BenchmarkRunResults): The benchmarking results for a run.
-
+             benchmarking_results_df (pl.DataFrame): benchmarking stats dataframe.
          Returns:
-             str: The benchmark name obtained from the given BenchmarkRunResults instance.
-         """
-         return (
-             benchmark_result.benchmark_name
-             if benchmark_result.results_dir is None
-             else self._create_run_identifier(benchmark_result.results_dir)
-         )
-
-     def _generate_stacked_bar_plot_data(self, benchmark_result: BenchmarkRunResults) -> None:
-         """
-         Generate data in the correct format for dataframe creation for a stacked bar plot,
-         appending to the self.stats attribute of the class.
-
-         Args:
-             benchmark_result (BenchmarkRunResults): The benchmarking results for a run.
+             pl.DataFrame: Data formatted for plotting MRR bar plot.
          """
-         rank_stats = benchmark_result.rank_stats
-         self.stats.append(
-             {
-                 "Run": self.return_benchmark_name(benchmark_result),
-                 "Top": benchmark_result.rank_stats.percentage_top(),
-                 "2-3": rank_stats.percentage_difference(
-                     rank_stats.percentage_top3(), rank_stats.percentage_top()
-                 ),
-                 "4-5": rank_stats.percentage_difference(
-                     rank_stats.percentage_top5(), rank_stats.percentage_top3()
-                 ),
-                 "6-10": rank_stats.percentage_difference(
-                     rank_stats.percentage_top10(), rank_stats.percentage_top5()
-                 ),
-                 ">10": rank_stats.percentage_difference(
-                     rank_stats.percentage_found(), rank_stats.percentage_top10()
-                 ),
-                 "Missed": rank_stats.percentage_difference(100, rank_stats.percentage_found()),
-             }
+         return benchmarking_results_df.select(["run_identifier", "mrr"]).rename(
+             {"run_identifier": "Run", "mrr": "Percentage"}
          )

-     def _generate_stats_mrr_bar_plot_data(self, benchmark_result: BenchmarkRunResults) -> None:
+     def _save_fig(
+         self, benchmark_output_type: BenchmarkOutputType, y_lower_limit: int, y_upper_limit: int
+     ) -> None:
          """
-         Generate data in the correct format for dataframe creation for MRR (Mean Reciprocal Rank) bar plot,
-         appending to the self.mrr attribute of the class.
-
+         Save the generated figure.
          Args:
-             benchmark_result (BenchmarkRunResults): The benchmarking results for a run.
+             benchmark_output_type (BenchmarkOutputType): Benchmark output type.
+             y_lower_limit (int): Lower limit for the y-axis.
+             y_upper_limit (int): Upper limit for the y-axis.
          """
-         self.mrr.extend(
-             [
-                 {
-                     "Rank": "MRR",
-                     "Percentage": benchmark_result.rank_stats.return_mean_reciprocal_rank(),
-                     "Run": self.return_benchmark_name(benchmark_result),
-                 }
-             ]
+         plt.ylim(y_lower_limit, y_upper_limit)
+         plt.savefig(
+             f"{self.benchmark_name}_{benchmark_output_type.prioritisation_type_string}_rank_stats.svg",
+             format="svg",
+             bbox_inches="tight",
          )

      def generate_stacked_bar_plot(
          self,
-         benchmarking_results: List[BenchmarkRunResults],
-         benchmark_generator: BenchmarkRunOutputGenerator,
+         benchmarking_results_df: pl.DataFrame,
+         benchmark_output_type: BenchmarkOutputType,
+         plot_customisation: SinglePlotCustomisation,
      ) -> None:
          """
          Generate a stacked bar plot and Mean Reciprocal Rank (MRR) bar plot.
-
          Args:
-             benchmarking_results (List[BenchmarkRunResults]): List of benchmarking results for multiple runs.
-             benchmark_generator (BenchmarkRunOutputGenerator): Object containing benchmarking output generation details.
+             benchmarking_results_df (pl.DataFrame): benchmarking stats dataframe.
+             benchmark_output_type (BenchmarkOutputType): Benchmark output type.
+             plot_customisation (SinglePlotCustomisation): Plotting customisation.
          """
-         for benchmark_result in benchmarking_results:
-             self._generate_stacked_bar_plot_data(benchmark_result)
-             self._generate_stats_mrr_bar_plot_data(benchmark_result)
-         stats_df = pd.DataFrame(self.stats)
          plt.clf()
-         stats_df.set_index("Run").plot(
+         stats_df = self._generate_stacked_data(benchmarking_results_df)
+         stats_df.to_pandas().set_index("Run").plot(
              kind="bar",
              stacked=True,
              color=self.palette_hex_codes,
-             ylabel=benchmark_generator.y_label,
+             ylabel=benchmark_output_type.y_label,
              edgecolor="white",
          ).legend(loc="center left", bbox_to_anchor=(1.0, 0.5))
-         if benchmark_generator.plot_customisation.rank_plot_title is None:
-             plt.title(f"{benchmark_generator.prioritisation_type_string.capitalize()} Rank Stats")
-         else:
-             plt.title(
-                 benchmark_generator.plot_customisation.rank_plot_title, loc="center", fontsize=15
-             )
-         plt.ylim(0, 100)
-         plt.savefig(
-             f"{self.benchmark_name}_{benchmark_generator.prioritisation_type_string}_rank_stats.svg",
-             format="svg",
-             bbox_inches="tight",
-         )
-
-         mrr_df = pd.DataFrame(self.mrr)
-         mrr_df.set_index("Run").plot(
+         plt.title(plot_customisation.rank_plot_title, loc="center", fontsize=15)
+         self._save_fig(benchmark_output_type, 0, 100)
+         mrr_df = self._extract_mrr_data(benchmarking_results_df)
+         mrr_df.to_pandas().set_index("Run").plot(
              kind="bar",
              color=self.palette_hex_codes,
-             ylabel=f"{benchmark_generator.prioritisation_type_string.capitalize()} mean reciprocal rank",
+             ylabel=f"{benchmark_output_type.prioritisation_type_string.capitalize()} mean reciprocal rank",
              legend=False,
              edgecolor="white",
          )
          plt.title(
-             f"{benchmark_generator.prioritisation_type_string.capitalize()} results - mean reciprocal rank"
-         )
-         plt.ylim(0, 1)
-         plt.savefig(
-             f"{self.benchmark_name}_{benchmark_generator.prioritisation_type_string}_mrr.svg",
-             format="svg",
-             bbox_inches="tight",
+             f"{benchmark_output_type.prioritisation_type_string.capitalize()} results - mean reciprocal rank"
          )
+         self._save_fig(benchmark_output_type, 0, 1)

-     def _generate_cumulative_bar_plot_data(self, benchmark_result: BenchmarkRunResults):
+     @staticmethod
+     def _generate_cumulative_bar_plot_data(benchmarking_results_df: pl.DataFrame) -> pl.DataFrame:
          """
          Generate data in the correct format for dataframe creation for a cumulative bar plot,
          appending to the self.stats attribute of the class.
-
-         Args:
-             benchmark_result (BenchmarkRunResults): The benchmarking results for a run.
          """
-         rank_stats = benchmark_result.rank_stats
-         run_identifier = self.return_benchmark_name(benchmark_result)
-         self.stats.extend(
+         return benchmarking_results_df.select(
              [
-                 {
-                     "Rank": "Top",
-                     "Percentage": rank_stats.percentage_top() / 100,
-                     "Run": run_identifier,
-                 },
-                 {
-                     "Rank": "Top3",
-                     "Percentage": rank_stats.percentage_top3() / 100,
-                     "Run": run_identifier,
-                 },
-                 {
-                     "Rank": "Top5",
-                     "Percentage": rank_stats.percentage_top5() / 100,
-                     "Run": run_identifier,
-                 },
-                 {
-                     "Rank": "Top10",
-                     "Percentage": rank_stats.percentage_top10() / 100,
-                     "Run": run_identifier,
-                 },
-                 {
-                     "Rank": "Found",
-                     "Percentage": rank_stats.percentage_found() / 100,
-                     "Run": run_identifier,
-                 },
-                 {
-                     "Rank": "Missed",
-                     "Percentage": rank_stats.percentage_difference(
-                         100, rank_stats.percentage_found()
-                     )
-                     / 100,
-                     "Run": run_identifier,
-                 },
-                 {
-                     "Rank": "MRR",
-                     "Percentage": rank_stats.return_mean_reciprocal_rank(),
-                     "Run": run_identifier,
-                 },
+                 pl.col("run_identifier").alias("Run"),
+                 pl.col("percentage@1").alias("Top") / 100,
+                 pl.col("percentage@3").alias("Top3") / 100,
+                 pl.col("percentage@5").alias("Top5") / 100,
+                 pl.col("percentage@10").alias("Top10") / 100,
+                 pl.col("percentage_found").alias("Found") / 100,
+                 pl.col("mrr").alias("MRR"),
              ]
          )

-     def generate_cumulative_bar(
+     def _plot_bar_plot(
          self,
-         benchmarking_results: List[BenchmarkRunResults],
-         benchmark_generator: BenchmarkRunOutputGenerator,
+         benchmark_output_type: BenchmarkOutputType,
+         stats_df: pl.DataFrame,
+         plot_customisation: SinglePlotCustomisation,
      ) -> None:
-         """
-         Generate a cumulative bar plot.
-
-         Args:
-             benchmarking_results (List[BenchmarkRunResults]): List of benchmarking results for multiple runs.
-             benchmark_generator (BenchmarkRunOutputGenerator): Object containing benchmarking output generation details.
-         """
-         for benchmark_result in benchmarking_results:
-             self._generate_cumulative_bar_plot_data(benchmark_result)
-         stats_df = pd.DataFrame(self.stats)
-         plt.clf()
+         stats_df = stats_df.to_pandas().melt(
+             id_vars=["Run"],
+             value_vars=["Top", "Top3", "Top5", "Top10", "Found", "MRR"],
+             var_name="Rank",
+             value_name="Percentage",
+         )
          sns.catplot(
              data=stats_df,
              kind="bar",
@@ -271,132 +184,77 @@ class PlotGenerator:
              palette=self.palette_hex_codes,
              edgecolor="white",
              legend=False,
-         ).set(xlabel="Rank", ylabel=benchmark_generator.y_label)
+         ).set(xlabel="Rank", ylabel=benchmark_output_type.y_label)
          plt.legend(loc="upper center", bbox_to_anchor=(0.5, -0.15), ncol=3, title="Run")
-         if benchmark_generator.plot_customisation.rank_plot_title is None:
-             plt.title(
-                 f"{benchmark_generator.prioritisation_type_string.capitalize()} Cumulative Rank Stats"
-             )
-         else:
-             plt.title(
-                 benchmark_generator.plot_customisation.rank_plot_title, loc="center", fontsize=15
-             )
-         plt.ylim(0, 1)
-         plt.savefig(
-             f"{self.benchmark_name}_{benchmark_generator.prioritisation_type_string}_rank_stats.svg",
-             format="svg",
-             bbox_inches="tight",
-         )
+         plt.title(plot_customisation.rank_plot_title, loc="center", fontsize=15)
+         self._save_fig(benchmark_output_type, 0, 1)

      def _generate_non_cumulative_bar_plot_data(
-         self, benchmark_result: BenchmarkRunResults
-     ) -> [dict]:
+         self, benchmarking_results_df: pl.DataFrame
+     ) -> pl.DataFrame:
          """
          Generate data in the correct format for dataframe creation for a non-cumulative bar plot,
          appending to the self.stats attribute of the class.
-
-         Args:
-             benchmark_result (BenchmarkRunResults): The benchmarking results for a run.
          """
-         rank_stats = benchmark_result.rank_stats
-         run_identifier = self.return_benchmark_name(benchmark_result)
-         self.stats.extend(
-             [
-                 {
-                     "Rank": "Top",
-                     "Percentage": rank_stats.percentage_top() / 100,
-                     "Run": run_identifier,
-                 },
-                 {
-                     "Rank": "2-3",
-                     "Percentage": rank_stats.percentage_difference(
-                         rank_stats.percentage_top3(), rank_stats.percentage_top()
-                     )
-                     / 100,
-                     "Run": run_identifier,
-                 },
-                 {
-                     "Rank": "4-5",
-                     "Percentage": rank_stats.percentage_difference(
-                         rank_stats.percentage_top5(), rank_stats.percentage_top3()
-                     )
-                     / 100,
-                     "Run": run_identifier,
-                 },
-                 {
-                     "Rank": "6-10",
-                     "Percentage": rank_stats.percentage_difference(
-                         rank_stats.percentage_top10(), rank_stats.percentage_top5()
-                     )
-                     / 100,
-                     "Run": run_identifier,
-                 },
-                 {
-                     "Rank": ">10",
-                     "Percentage": rank_stats.percentage_difference(
-                         rank_stats.percentage_found(), rank_stats.percentage_top10()
-                     )
-                     / 100,
-                     "Run": run_identifier,
-                 },
-                 {
-                     "Rank": "Missed",
-                     "Percentage": rank_stats.percentage_difference(
-                         100, rank_stats.percentage_found()
-                     )
-                     / 100,
-                     "Run": run_identifier,
-                 },
-                 {
-                     "Rank": "MRR",
-                     "Percentage": rank_stats.return_mean_reciprocal_rank(),
-                     "Run": run_identifier,
-                 },
-             ]
+         return self._generate_stacked_data(benchmarking_results_df).hstack(
+             self._extract_mrr_data(benchmarking_results_df).select(
+                 pl.col("Percentage").alias("MRR")
+             )
          )

+     def generate_cumulative_bar(
+         self,
+         benchmarking_results_df: pl.DataFrame,
+         benchmark_generator: BenchmarkOutputType,
+         plot_customisation: SinglePlotCustomisation,
+     ) -> None:
+         """
+         Generate a cumulative bar plot.
+         """
+         plt.clf()
+         stats_df = self._generate_cumulative_bar_plot_data(benchmarking_results_df)
+         self._plot_bar_plot(benchmark_generator, stats_df, plot_customisation)
+
+     def generate_non_cumulative_bar(
+         self,
+         benchmarking_results_df: pl.DataFrame,
+         benchmark_generator: BenchmarkOutputType,
+         plot_customisation: SinglePlotCustomisation,
+     ) -> None:
+         """
+         Generate a non-cumulative bar plot.
+         """
+         plt.clf()
+         stats_df = self._generate_non_cumulative_bar_plot_data(benchmarking_results_df)
+         self._plot_bar_plot(benchmark_generator, stats_df, plot_customisation)
+
      def generate_roc_curve(
          self,
-         benchmarking_results: List[BenchmarkRunResults],
-         benchmark_generator: BenchmarkRunOutputGenerator,
+         curves: pl.DataFrame,
+         benchmark_generator: BenchmarkOutputType,
+         plot_customisation: SinglePlotCustomisation,
      ):
          """
          Generate and plot Receiver Operating Characteristic (ROC) curves for binary classification benchmark results.

          Args:
-             benchmarking_results (List[BenchmarkRunResults]): List of benchmarking results for multiple runs.
-             benchmark_generator (BenchmarkRunOutputGenerator): Object containing benchmarking output generation details.
          """
          plt.clf()
-         for i, benchmark_result in enumerate(benchmarking_results):
-             y_score = np.array(benchmark_result.binary_classification_stats.scores)
-             y_score = np.nan_to_num(
-                 y_score,
-                 nan=0.0,
-                 posinf=max(y_score[np.isfinite(y_score)]),
-                 neginf=min(y_score[np.isfinite(y_score)]),
-             )
-             fpr, tpr, thresh = roc_curve(
-                 benchmark_result.binary_classification_stats.labels,
-                 y_score,
-                 pos_label=1,
-             )
+         for i, row in enumerate(curves.iter_rows(named=True)):
+             run_identifier = row["run_identifier"]
+             fpr = row["fpr"]
+             tpr = row["tpr"]
              roc_auc = auc(fpr, tpr)
-
              plt.plot(
                  fpr,
                  tpr,
-                 label=f"{self.return_benchmark_name(benchmark_result)} ROC Curve (AUC = {roc_auc:.2f})",
+                 label=f"{run_identifier} ROC Curve (AUC = {roc_auc:.2f})",
                  color=self.palette_hex_codes[i],
              )
-
-         plt.plot(linestyle="--", color="gray")
+         plt.plot([0, 1], [0, 1], linestyle="--", color="gray")
          plt.xlabel("False Positive Rate")
          plt.ylabel("True Positive Rate")
-         if benchmark_generator.plot_customisation.roc_curve_title is None:
-             plt.title("Receiver Operating Characteristic (ROC) Curve")
-         else:
-             plt.title(benchmark_generator.plot_customisation.roc_curve_title)
+         plt.title(plot_customisation.roc_curve_title)
          plt.legend(loc="upper center", bbox_to_anchor=(0.5, -0.15))
          plt.savefig(
              f"{self.benchmark_name}_{benchmark_generator.prioritisation_type_string}_roc_curve.svg",
@@ -406,46 +264,30 @@

      def generate_precision_recall(
          self,
-         benchmarking_results: List[BenchmarkRunResults],
-         benchmark_generator: BenchmarkRunOutputGenerator,
+         curves: pl.DataFrame,
+         benchmark_generator: BenchmarkOutputType,
+         plot_customisation: SinglePlotCustomisation,
      ):
          """
          Generate and plot Precision-Recall curves for binary classification benchmark results.
-
-         Args:
-             benchmarking_results (List[BenchmarkRunResults]): List of benchmarking results for multiple runs.
-             benchmark_generator (BenchmarkRunOutputGenerator): Object containing benchmarking output generation details.
          """
          plt.clf()
          plt.figure()
-         for i, benchmark_result in enumerate(benchmarking_results):
-             y_score = np.array(benchmark_result.binary_classification_stats.scores)
-             y_score = np.nan_to_num(
-                 y_score,
-                 nan=0.0,
-                 posinf=max(y_score[np.isfinite(y_score)]),
-                 neginf=min(y_score[np.isfinite(y_score)]),
-             )
-             precision, recall, thresh = precision_recall_curve(
-                 benchmark_result.binary_classification_stats.labels,
-                 y_score,
-             )
-             precision_recall_auc = auc(recall, precision)
+         for i, row in enumerate(curves.iter_rows(named=True)):
+             run_identifier = row["run_identifier"]
+             precision = row["precision"]
+             recall = row["recall"]
+             pr_auc = auc(recall[::-1], precision[::-1])
              plt.plot(
                  recall,
                  precision,
-                 label=f"{self.return_benchmark_name(benchmark_result)} Precision-Recall Curve "
-                 f"(AUC = {precision_recall_auc:.2f})",
+                 label=f"{run_identifier} Precision-Recall Curve (AUC = {pr_auc:.2f})",
                  color=self.palette_hex_codes[i],
              )
-
-         plt.plot(linestyle="--", color="gray")
+         plt.plot([0, 1], [0, 1], linestyle="--", color="gray")
          plt.xlabel("Recall")
          plt.ylabel("Precision")
-         if benchmark_generator.plot_customisation.precision_recall_title is None:
-             plt.title("Precision-Recall Curve")
-         else:
-             plt.title(benchmark_generator.plot_customisation.precision_recall_title)
+         plt.title(plot_customisation.precision_recall_title)
          plt.legend(loc="upper center", bbox_to_anchor=(0.5, -0.15))
          plt.savefig(
              f"{self.benchmark_name}_{benchmark_generator.prioritisation_type_string}_pr_curve.svg",
@@ -453,112 +295,85 @@
              bbox_inches="tight",
          )

-     def generate_non_cumulative_bar(
-         self,
-         benchmarking_results: List[BenchmarkRunResults],
-         benchmark_generator: BenchmarkRunOutputGenerator,
-     ) -> None:
-         """
-         Generate a non-cumulative bar plot.
-
-         Args:
-             benchmarking_results (List[BenchmarkRunResults]): List of benchmarking results for multiple runs.
-             benchmark_generator (BenchmarkRunOutputGenerator): Object containing benchmarking output generation details.
-         """
-         plt.clf()
-         for benchmark_result in benchmarking_results:
-             self._generate_non_cumulative_bar_plot_data(benchmark_result)
-
-         stats_df = pd.DataFrame(self.stats)
-         sns.catplot(
-             data=stats_df,
-             kind="bar",
-             x="Rank",
-             y="Percentage",
-             hue="Run",
-             palette=self.palette_hex_codes,
-             edgecolor="white",
-             legend=False,
-         ).set(xlabel="Rank", ylabel=benchmark_generator.y_label)
-         plt.legend(loc="upper center", bbox_to_anchor=(0.5, -0.15), ncol=3, title="Run")
-         if benchmark_generator.plot_customisation.rank_plot_title is None:
-             plt.title(
-                 f"{benchmark_generator.prioritisation_type_string.capitalize()} Non-Cumulative Rank Stats"
-             )
-         else:
-             plt.title(
-                 benchmark_generator.plot_customisation.rank_plot_title, loc="center", fontsize=15
-             )
-         plt.ylim(0, 1)
-         plt.savefig(
-             f"{self.benchmark_name}_{benchmark_generator.prioritisation_type_string}_rank_stats.svg",
-             format="svg",
-             bbox_inches="tight",
-         )
-

  def generate_plots(
      benchmark_name: str,
-     benchmarking_results: List[BenchmarkRunResults],
-     benchmark_generator: BenchmarkRunOutputGenerator,
-     generate_from_db: bool = False,
+     benchmarking_results_df: pl.DataFrame,
+     curves: pl.DataFrame,
+     benchmark_output_type: BenchmarkOutputType,
+     plot_customisation: PlotCustomisation,
  ) -> None:
      """
      Generate summary statistics bar plots for prioritisation.

      This method generates summary statistics bar plots based on the provided benchmarking results and plot type.
-
-     Args:
-         benchmarking_results (list[BenchmarkRunResults]): List of benchmarking results for multiple runs.
-         benchmark_generator (BenchmarkRunOutputGenerator): Object containing benchmarking output generation details.
-         generate_from_db (bool): Specify whether to generate plots from the db file. Defaults to False.
      """
      plot_generator = PlotGenerator(benchmark_name)
-     if not generate_from_db:
-         plot_generator.generate_roc_curve(benchmarking_results, benchmark_generator)
-         plot_generator.generate_precision_recall(benchmarking_results, benchmark_generator)
-     if benchmark_generator.plot_customisation.plot_type == "bar_stacked":
-         plot_generator.generate_stacked_bar_plot(benchmarking_results, benchmark_generator)
-     elif benchmark_generator.plot_customisation.plot_type == "bar_cumulative":
-         plot_generator.generate_cumulative_bar(benchmarking_results, benchmark_generator)
-     elif benchmark_generator.plot_customisation.plot_type == "bar_non_cumulative":
-         plot_generator.generate_non_cumulative_bar(benchmarking_results, benchmark_generator)
+     plot_customisation_type = getattr(
+         plot_customisation, f"{benchmark_output_type.prioritisation_type_string}_plots"
+     )
+     logger.info("Generating ROC curve visualisations.")
+     plot_generator.generate_roc_curve(curves, benchmark_output_type, plot_customisation_type)
+     logger.info("Generating Precision-Recall curves visualisations.")
+     plot_generator.generate_precision_recall(curves, benchmark_output_type, plot_customisation_type)
+     plot_type = PlotTypes(plot_customisation_type.plot_type)
+     match plot_type:
+         case PlotTypes.BAR_STACKED:
+             logger.info("Generating stacked bar plot.")
+             plot_generator.generate_stacked_bar_plot(
+                 benchmarking_results_df, benchmark_output_type, plot_customisation_type
+             )
+         case PlotTypes.BAR_CUMULATIVE:
+             logger.info("Generating cumulative bar plot.")
+             plot_generator.generate_cumulative_bar(
+                 benchmarking_results_df, benchmark_output_type, plot_customisation_type
+             )
+         case PlotTypes.BAR_NON_CUMULATIVE:
+             logger.info("Generating non cumulative bar plot.")
+             plot_generator.generate_non_cumulative_bar(
+                 benchmarking_results_df, benchmark_output_type, plot_customisation_type
+             )


- def generate_plots_from_benchmark_summary_db(
-     benchmark_db: Path,
-     run_data: Path,
- ):
+ def generate_plots_from_db(db_path: Path, config: Path) -> None:
      """
-     Generate bar plot from summary benchmark results.
-
-     Reads a summary of benchmark results from a benchmark db and generates a bar plot
-     based on the analysis type and plot type.
-
+     Generate plots from database file.
      Args:
-         benchmark_db (Path): Path to the summary TSV file containing benchmark results.
-         run_data (Path): Path to YAML benchmarking configuration file.
+         db_path (Path): Path to the database file.
+         config (Path): Path to the benchmarking config file.
      """
-     benchmark_stats_summary = parse_benchmark_db(benchmark_db)
-     config = parse_run_config(run_data)
-     if benchmark_stats_summary.gene_results:
-         generate_plots(
-             config.benchmark_name,
-             benchmark_stats_summary.gene_results,
-             GeneBenchmarkRunOutputGenerator(config.plot_customisation.gene_plots),
-             True,
+     logger.info(f"Generating plots from {db_path}")
+     conn = duckdb.connect(db_path)
+     logger.info(f"Parsing configurations from {config}")
+     benchmark_config_file = parse_run_config(config)
+     tables = {
+         row[0]
+         for row in conn.execute(
+             """SELECT table_name FROM duckdb_tables WHERE table_name """
+             """LIKE '%_summary%' OR table_name LIKE '%_binary_classification_curves'"""
+         ).fetchall()
+     }
+     for benchmark_output_type in BenchmarkOutputTypeEnum:
+         summary_table = (
+             f"{benchmark_config_file.benchmark_name}_"
+             f"{benchmark_output_type.value.prioritisation_type_string}_summary"
          )
-     if benchmark_stats_summary.variant_results:
-         generate_plots(
-             config.benchmark_name,
-             benchmark_stats_summary.variant_results,
-             VariantBenchmarkRunOutputGenerator(config.plot_customisation.variant_plots),
-             True,
-         )
-     elif benchmark_stats_summary.disease_results:
-         generate_plots(
-             config.benchmark_name,
-             benchmark_stats_summary.disease_results,
-             DiseaseBenchmarkRunOutputGenerator(config.plot_customisation.disease_plots),
-             True,
+         curve_table = (
+             f"{benchmark_config_file.benchmark_name}_"
+             f"{benchmark_output_type.value.prioritisation_type_string}_binary_classification_curves"
          )
+         if summary_table in tables and curve_table in tables:
+             logger.info(
+                 f"Generating plots for {benchmark_output_type.value.prioritisation_type_string} prioritisation."
+             )
+             benchmarking_results_df = load_table_lazy(summary_table, conn).collect()
+             curves_df = load_table_lazy(curve_table, conn).collect()
+             generate_plots(
+                 benchmark_name=benchmark_config_file.benchmark_name,
+                 benchmarking_results_df=benchmarking_results_df,
+                 curves=curves_df,
+                 benchmark_output_type=benchmark_output_type.value,
+                 plot_customisation=benchmark_config_file.plot_customisation,
+             )
+     logger.info("Finished generating plots.")
+     conn.close()
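
Note: generate_plots_from_db discovers the per-prioritisation summary and curve tables by name and only plots the pairs that actually exist in the database. A rough standalone sketch of the same lookup using plain duckdb (the file name benchmark.db, the benchmark name and the "gene" type string are assumptions; pheval itself goes through load_table_lazy from benchmark_db_manager):

import duckdb

conn = duckdb.connect("benchmark.db")
# Collect table names matching the summary / curve naming scheme used above.
tables = {
    row[0]
    for row in conn.execute(
        "SELECT table_name FROM duckdb_tables WHERE table_name LIKE '%_summary%' "
        "OR table_name LIKE '%_binary_classification_curves'"
    ).fetchall()
}
summary_table = "my_benchmark_gene_summary"
if summary_table in tables:
    # conn.sql(...).pl() materialises the table as a polars DataFrame,
    # standing in for load_table_lazy(summary_table, conn).collect().
    summary_df = conn.sql(f"SELECT * FROM {summary_table}").pl()
conn.close()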