pheval 0.3.9__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



@@ -14,11 +14,8 @@ from pheval.analyse.benchmark_generator import (
     VariantBenchmarkRunOutputGenerator,
 )
 from pheval.analyse.benchmarking_data import BenchmarkRunResults
-from pheval.analyse.parse_benchmark_summary import (
-    parse_benchmark_result_summary,
-    read_benchmark_tsv_result_summary,
-)
-from pheval.constants import PHEVAL_RESULTS_DIRECTORY_SUFFIX
+from pheval.analyse.parse_benchmark_summary import parse_benchmark_db
+from pheval.analyse.run_data_parser import parse_run_config


 def trim_corpus_results_directory_suffix(corpus_results_directory: Path) -> Path:
@@ -31,7 +28,7 @@ def trim_corpus_results_directory_suffix(corpus_results_directory: Path) -> Path
     Returns:
         Path: The Path object with the suffix removed from the directory name.
     """
-    return Path(str(corpus_results_directory).replace(PHEVAL_RESULTS_DIRECTORY_SUFFIX, ""))
+    return Path(str(corpus_results_directory).replace("_results", ""))


 class PlotGenerator:
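
As an aside, the trimming change above simply hard-codes the "_results" suffix that was previously imported from pheval.constants. A minimal sketch of the behaviour, using a hypothetical directory name:

    from pathlib import Path

    # Hypothetical corpus results directory; only the literal "_results" suffix is stripped.
    corpus_results_directory = Path("corpus_1_phen2gene_results")
    print(Path(str(corpus_results_directory).replace("_results", "")))
    # corpus_1_phen2gene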
@@ -50,9 +47,7 @@ class PlotGenerator:
         "#1b9e77",
     ]

-    def __init__(
-        self,
-    ):
+    def __init__(self, benchmark_name: str):
         """
         Initialise the PlotGenerator class.
         Note:
@@ -61,6 +56,7 @@ class PlotGenerator:
             Matplotlib settings are configured to remove the right and top axes spines
             for generated plots.
         """
+        self.benchmark_name = benchmark_name
         self.stats, self.mrr = [], []
         matplotlib.rcParams["axes.spines.right"] = False
         matplotlib.rcParams["axes.spines.top"] = False
@@ -145,7 +141,6 @@ class PlotGenerator:
         self,
         benchmarking_results: List[BenchmarkRunResults],
         benchmark_generator: BenchmarkRunOutputGenerator,
-        title: str = None,
     ) -> None:
         """
         Generate a stacked bar plot and Mean Reciprocal Rank (MRR) bar plot.
@@ -153,12 +148,12 @@ class PlotGenerator:
         Args:
             benchmarking_results (List[BenchmarkRunResults]): List of benchmarking results for multiple runs.
             benchmark_generator (BenchmarkRunOutputGenerator): Object containing benchmarking output generation details.
-            title (str, optional): Title for the generated plot. Defaults to None.
         """
         for benchmark_result in benchmarking_results:
             self._generate_stacked_bar_plot_data(benchmark_result)
             self._generate_stats_mrr_bar_plot_data(benchmark_result)
         stats_df = pd.DataFrame(self.stats)
+        plt.clf()
         stats_df.set_index("Run").plot(
             kind="bar",
             stacked=True,
@@ -166,15 +161,15 @@ class PlotGenerator:
             ylabel=benchmark_generator.y_label,
             edgecolor="white",
         ).legend(loc="center left", bbox_to_anchor=(1.0, 0.5))
-        if title is None:
+        if benchmark_generator.plot_customisation.rank_plot_title is None:
+            plt.title(f"{benchmark_generator.prioritisation_type_string.capitalize()} Rank Stats")
+        else:
             plt.title(
-                f"{benchmark_generator.prioritisation_type_file_prefix.capitalize()} Rank Stats"
+                benchmark_generator.plot_customisation.rank_plot_title, loc="center", fontsize=15
             )
-        else:
-            plt.title(title, loc="center", fontsize=15)
         plt.ylim(0, 100)
         plt.savefig(
-            f"{benchmark_generator.prioritisation_type_file_prefix}_rank_stats.svg",
+            f"{self.benchmark_name}_{benchmark_generator.prioritisation_type_string}_rank_stats.svg",
             format="svg",
             bbox_inches="tight",
         )
@@ -183,16 +178,16 @@ class PlotGenerator:
         mrr_df.set_index("Run").plot(
             kind="bar",
             color=self.palette_hex_codes,
-            ylabel=f"{benchmark_generator.prioritisation_type_file_prefix.capitalize()} mean reciprocal rank",
+            ylabel=f"{benchmark_generator.prioritisation_type_string.capitalize()} mean reciprocal rank",
             legend=False,
             edgecolor="white",
         )
         plt.title(
-            f"{benchmark_generator.prioritisation_type_file_prefix.capitalize()} results - mean reciprocal rank"
+            f"{benchmark_generator.prioritisation_type_string.capitalize()} results - mean reciprocal rank"
         )
         plt.ylim(0, 1)
         plt.savefig(
-            f"{benchmark_generator.prioritisation_type_file_prefix}_mrr.svg",
+            f"{self.benchmark_name}_{benchmark_generator.prioritisation_type_string}_mrr.svg",
             format="svg",
             bbox_inches="tight",
         )
@@ -254,7 +249,6 @@ class PlotGenerator:
         self,
         benchmarking_results: List[BenchmarkRunResults],
         benchmark_generator: BenchmarkRunOutputGenerator,
-        title: str = None,
     ) -> None:
         """
         Generate a cumulative bar plot.
@@ -262,11 +256,11 @@ class PlotGenerator:
         Args:
             benchmarking_results (List[BenchmarkRunResults]): List of benchmarking results for multiple runs.
             benchmark_generator (BenchmarkRunOutputGenerator): Object containing benchmarking output generation details.
-            title (str, optional): Title for the generated plot. Defaults to None.
         """
         for benchmark_result in benchmarking_results:
             self._generate_cumulative_bar_plot_data(benchmark_result)
         stats_df = pd.DataFrame(self.stats)
+        plt.clf()
         sns.catplot(
             data=stats_df,
             kind="bar",
@@ -278,15 +272,17 @@ class PlotGenerator:
             legend=False,
         ).set(xlabel="Rank", ylabel=benchmark_generator.y_label)
         plt.legend(loc="upper center", bbox_to_anchor=(0.5, -0.15), ncol=3, title="Run")
-        if title is None:
+        if benchmark_generator.plot_customisation.rank_plot_title is None:
             plt.title(
-                f"{benchmark_generator.prioritisation_type_file_prefix.capitalize()} Cumulative Rank Stats"
+                f"{benchmark_generator.prioritisation_type_string.capitalize()} Cumulative Rank Stats"
             )
         else:
-            plt.title(title, loc="center", fontsize=15)
+            plt.title(
+                benchmark_generator.plot_customisation.rank_plot_title, loc="center", fontsize=15
+            )
         plt.ylim(0, 1)
         plt.savefig(
-            f"{benchmark_generator.prioritisation_type_file_prefix}_rank_stats.svg",
+            f"{self.benchmark_name}_{benchmark_generator.prioritisation_type_string}_rank_stats.svg",
             format="svg",
             bbox_inches="tight",
         )
@@ -370,6 +366,7 @@ class PlotGenerator:
             benchmarking_results (List[BenchmarkRunResults]): List of benchmarking results for multiple runs.
             benchmark_generator (BenchmarkRunOutputGenerator): Object containing benchmarking output generation details.
         """
+        plt.clf()
         for i, benchmark_result in enumerate(benchmarking_results):
             fpr, tpr, thresh = roc_curve(
                 benchmark_result.binary_classification_stats.labels,
@@ -388,10 +385,13 @@ class PlotGenerator:
         plt.plot(linestyle="--", color="gray")
         plt.xlabel("False Positive Rate")
         plt.ylabel("True Positive Rate")
-        plt.title("Receiver Operating Characteristic (ROC) Curve")
+        if benchmark_generator.plot_customisation.roc_curve_title is None:
+            plt.title("Receiver Operating Characteristic (ROC) Curve")
+        else:
+            plt.title(benchmark_generator.plot_customisation.roc_curve_title)
         plt.legend(loc="upper center", bbox_to_anchor=(0.5, -0.15))
         plt.savefig(
-            f"{benchmark_generator.prioritisation_type_file_prefix}_roc_curve.svg",
+            f"{self.benchmark_name}_{benchmark_generator.prioritisation_type_string}_roc_curve.svg",
             format="svg",
             bbox_inches="tight",
         )
@@ -408,6 +408,7 @@ class PlotGenerator:
             benchmarking_results (List[BenchmarkRunResults]): List of benchmarking results for multiple runs.
             benchmark_generator (BenchmarkRunOutputGenerator): Object containing benchmarking output generation details.
         """
+        plt.clf()
         plt.figure()
         for i, benchmark_result in enumerate(benchmarking_results):
             precision, recall, thresh = precision_recall_curve(
@@ -426,10 +427,13 @@ class PlotGenerator:
         plt.plot(linestyle="--", color="gray")
         plt.xlabel("Recall")
         plt.ylabel("Precision")
-        plt.title("Precision-Recall Curve")
+        if benchmark_generator.plot_customisation.precision_recall_title is None:
+            plt.title("Precision-Recall Curve")
+        else:
+            plt.title(benchmark_generator.plot_customisation.precision_recall_title)
         plt.legend(loc="upper center", bbox_to_anchor=(0.5, -0.15))
         plt.savefig(
-            f"{benchmark_generator.prioritisation_type_file_prefix}_precision_recall_curve.svg",
+            f"{self.benchmark_name}_{benchmark_generator.prioritisation_type_string}_pr_curve.svg",
             format="svg",
             bbox_inches="tight",
         )
@@ -438,7 +442,6 @@ class PlotGenerator:
         self,
         benchmarking_results: List[BenchmarkRunResults],
         benchmark_generator: BenchmarkRunOutputGenerator,
-        title: str = None,
     ) -> None:
         """
         Generate a non-cumulative bar plot.
@@ -446,8 +449,8 @@ class PlotGenerator:
         Args:
             benchmarking_results (List[BenchmarkRunResults]): List of benchmarking results for multiple runs.
             benchmark_generator (BenchmarkRunOutputGenerator): Object containing benchmarking output generation details.
-            title (str, optional): Title for the generated plot. Defaults to None.
         """
+        plt.clf()
         for benchmark_result in benchmarking_results:
             self._generate_non_cumulative_bar_plot_data(benchmark_result)

@@ -463,26 +466,27 @@ class PlotGenerator:
             legend=False,
         ).set(xlabel="Rank", ylabel=benchmark_generator.y_label)
         plt.legend(loc="upper center", bbox_to_anchor=(0.5, -0.15), ncol=3, title="Run")
-        if title is None:
+        if benchmark_generator.plot_customisation.rank_plot_title is None:
             plt.title(
-                f"{benchmark_generator.prioritisation_type_file_prefix.capitalize()} Non-Cumulative Rank Stats"
+                f"{benchmark_generator.prioritisation_type_string.capitalize()} Non-Cumulative Rank Stats"
             )
         else:
-            plt.title(title, loc="center", fontsize=15)
+            plt.title(
+                benchmark_generator.plot_customisation.rank_plot_title, loc="center", fontsize=15
+            )
         plt.ylim(0, 1)
         plt.savefig(
-            f"{benchmark_generator.prioritisation_type_file_prefix}_rank_stats.svg",
+            f"{self.benchmark_name}_{benchmark_generator.prioritisation_type_string}_rank_stats.svg",
             format="svg",
             bbox_inches="tight",
         )


 def generate_plots(
+    benchmark_name: str,
     benchmarking_results: List[BenchmarkRunResults],
     benchmark_generator: BenchmarkRunOutputGenerator,
-    plot_type: str,
-    title: str = None,
-    generate_from_tsv: bool = False,
+    generate_from_db: bool = False,
 ) -> None:
     """
     Generate summary statistics bar plots for prioritisation.
@@ -492,56 +496,54 @@ def generate_plots(
     Args:
         benchmarking_results (list[BenchmarkRunResults]): List of benchmarking results for multiple runs.
         benchmark_generator (BenchmarkRunOutputGenerator): Object containing benchmarking output generation details.
-        plot_type (str): Type of plot to be generated ("bar_stacked", "bar_cumulative", "bar_non_cumulative").
-        title (str, optional): Title for the generated plot. Defaults to None.
-        generate_from_tsv (bool): Specify whether to generate plots from the TSV file. Defaults to False.
+        generate_from_db (bool): Specify whether to generate plots from the db file. Defaults to False.
     """
-    plot_generator = PlotGenerator()
-    if not generate_from_tsv:
+    plot_generator = PlotGenerator(benchmark_name)
+    if not generate_from_db:
         plot_generator.generate_roc_curve(benchmarking_results, benchmark_generator)
         plot_generator.generate_precision_recall(benchmarking_results, benchmark_generator)
-    if plot_type == "bar_stacked":
-        plot_generator.generate_stacked_bar_plot(benchmarking_results, benchmark_generator, title)
-    elif plot_type == "bar_cumulative":
-        plot_generator.generate_cumulative_bar(benchmarking_results, benchmark_generator, title)
-    elif plot_type == "bar_non_cumulative":
-        plot_generator.generate_non_cumulative_bar(benchmarking_results, benchmark_generator, title)
-
-
-def generate_plots_from_benchmark_summary_tsv(
-    benchmark_summary_tsv: Path,
-    gene_analysis: bool,
-    variant_analysis: bool,
-    disease_analysis: bool,
-    plot_type: str,
-    title: str,
+    if benchmark_generator.plot_customisation.plot_type == "bar_stacked":
+        plot_generator.generate_stacked_bar_plot(benchmarking_results, benchmark_generator)
+    elif benchmark_generator.plot_customisation.plot_type == "bar_cumulative":
+        plot_generator.generate_cumulative_bar(benchmarking_results, benchmark_generator)
+    elif benchmark_generator.plot_customisation.plot_type == "bar_non_cumulative":
+        plot_generator.generate_non_cumulative_bar(benchmarking_results, benchmark_generator)
+
+
+def generate_plots_from_benchmark_summary_db(
+    benchmark_db: Path,
+    run_data: Path,
 ):
     """
     Generate bar plot from summary benchmark results.

-    Reads a summary of benchmark results from a TSV file and generates a bar plot
+    Reads a summary of benchmark results from a benchmark db and generates a bar plot
     based on the analysis type and plot type.

     Args:
-        benchmark_summary_tsv (Path): Path to the summary TSV file containing benchmark results.
-        gene_analysis (bool): Flag indicating whether to analyse gene prioritisation.
-        variant_analysis (bool): Flag indicating whether to analyse variant prioritisation.
-        disease_analysis (bool): Flag indicating whether to analyse disease prioritisation.
-        plot_type (str): Type of plot to be generated ("bar_stacked", "bar_cumulative", "bar_non_cumulative").
-        title (str): Title for the generated plot.
-    Raises:
-        ValueError: If an unsupported plot type is specified.
+        benchmark_db (Path): Path to the summary TSV file containing benchmark results.
+        run_data (Path): Path to YAML benchmarking configuration file.
     """
-    benchmark_stats_summary = read_benchmark_tsv_result_summary(benchmark_summary_tsv)
-    benchmarking_results = parse_benchmark_result_summary(benchmark_stats_summary)
-    if gene_analysis:
-        benchmark_generator = GeneBenchmarkRunOutputGenerator()
-    elif variant_analysis:
-        benchmark_generator = VariantBenchmarkRunOutputGenerator()
-    elif disease_analysis:
-        benchmark_generator = DiseaseBenchmarkRunOutputGenerator()
-    else:
-        raise ValueError(
-            "Specify one analysis type (gene_analysis, variant_analysis, or disease_analysis)"
+    benchmark_stats_summary = parse_benchmark_db(benchmark_db)
+    config = parse_run_config(run_data)
+    if benchmark_stats_summary.gene_results:
+        generate_plots(
+            config.benchmark_name,
+            benchmark_stats_summary.gene_results,
+            GeneBenchmarkRunOutputGenerator(config.plot_customisation.gene_plots),
+            True,
+        )
+    if benchmark_stats_summary.variant_results:
+        generate_plots(
+            config.benchmark_name,
+            benchmark_stats_summary.variant_results,
+            VariantBenchmarkRunOutputGenerator(config.plot_customisation.variant_plots),
+            True,
+        )
+    elif benchmark_stats_summary.disease_results:
+        generate_plots(
+            config.benchmark_name,
+            benchmark_stats_summary.disease_results,
+            DiseaseBenchmarkRunOutputGenerator(config.plot_customisation.disease_plots),
+            True,
         )
-    generate_plots(benchmarking_results, benchmark_generator, plot_type, title, True)
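
Taken together, plot generation is now driven by the run configuration rather than by plot_type/title arguments. A minimal usage sketch of the new entry point, assuming this file is pheval/analyse/generate_plots.py (as the import in the second file below suggests) and that the DuckDB benchmark database and YAML run configuration exist at the hypothetical paths shown:

    from pathlib import Path

    from pheval.analyse.generate_plots import generate_plots_from_benchmark_summary_db

    # The db is read with parse_benchmark_db and the YAML with parse_run_config;
    # plots are then generated per prioritisation type found in the summary.
    generate_plots_from_benchmark_summary_db(
        benchmark_db=Path("benchmark.db"),       # hypothetical path
        run_data=Path("benchmark_config.yaml"),  # hypothetical path
    )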
@@ -1,143 +1,68 @@
 import itertools
-from collections import defaultdict
-from copy import deepcopy
 from typing import List

-import numpy as np
-import pandas as pd
-
+from pheval.analyse.benchmark_db_manager import BenchmarkDBManager
 from pheval.analyse.benchmark_generator import BenchmarkRunOutputGenerator
 from pheval.analyse.benchmarking_data import BenchmarkRunResults
 from pheval.analyse.generate_plots import generate_plots
-from pheval.constants import RANK_COMPARISON_FILE_SUFFIX
-
-
-class RankComparisonGenerator:
-    """Class for writing the run comparison of rank assignment for prioritisation."""
-
-    def __init__(self, run_comparison: defaultdict):
-        """
-        Initialise the RankComparisonGenerator class.
-
-        Args:
-            run_comparison (defaultdict): A nested dictionary containing the run comparison data.
-        """
-        self.run_comparison = run_comparison

-    def _generate_dataframe(self) -> pd.DataFrame:
-        """
-        Generate a Pandas DataFrame based on the run comparison data.

-        Returns:
-            pd.DataFrame: DataFrame containing the run comparison data.
-        """
-        return pd.DataFrame.from_dict(self.run_comparison, orient="index")
-
-    def _calculate_rank_difference(self) -> pd.DataFrame:
-        """
-        Calculate the rank decrease for runs, taking the first directory as a baseline.
-
-        Returns:
-            pd.DataFrame: DataFrame containing the calculated rank differences.
-        """
-        comparison_df = self._generate_dataframe()
-        comparison_df["rank_change"] = comparison_df.iloc[:, 2] - comparison_df.iloc[:, 3]
-        comparison_df["rank_change"] = np.where(
-            (comparison_df.iloc[:, 2] == 0) & (comparison_df.iloc[:, 3] != 0),
-            "GAINED",
-            np.where(
-                (comparison_df.iloc[:, 3] == 0) & (comparison_df.iloc[:, 2] != 0),
-                "LOST",
-                comparison_df["rank_change"],
-            ),
-        )
-        comparison_df["rank_change"] = comparison_df["rank_change"].apply(
-            lambda x: int(x) if str(x).lstrip("-").isdigit() else x
-        )
-        return comparison_df
-
-    def generate_output(self, prefix: str, suffix: str) -> None:
-        """
-        Generate output file from the run comparison data.
-
-        Args:
-            prefix (str): Prefix for the output file name.
-            suffix (str): Suffix for the output file name.
-        """
-        self._generate_dataframe().to_csv(prefix + suffix, sep="\t")
-
-    def generate_comparison_output(self, prefix: str, suffix: str) -> None:
-        """
-        Generate output file with calculated rank differences.
-
-        Args:
-            prefix (str): Prefix for the output file name.
-            suffix (str): Suffix for the output file name.
-        """
-        self._calculate_rank_difference().to_csv(prefix + suffix, sep="\t")
-
-
-def generate_benchmark_output(
-    benchmarking_results: BenchmarkRunResults,
-    plot_type: str,
-    benchmark_generator: BenchmarkRunOutputGenerator,
-) -> None:
+def get_new_table_name(run_identifier_1: str, run_identifier_2: str, output_prefix: str) -> str:
     """
-    Generate prioritisation outputs for a single benchmarking run.
-
+    Get the new table name for rank comparison tables.
     Args:
-        benchmarking_results (BenchmarkRunResults): Results of a benchmarking run.
-        plot_type (str): Type of plot to generate.
-        benchmark_generator (BenchmarkRunOutputGenerator): Object containing benchmarking output generation details.
+        run_identifier_1: The first run identifier.
+        run_identifier_2: The second run identifier.
+        output_prefix: The output prefix of the table
+    Returns:
+        The new table name.
     """
-    rank_comparison_data = benchmarking_results.ranks
-    results_dir_name = benchmarking_results.results_dir.name
-    RankComparisonGenerator(rank_comparison_data).generate_output(
-        f"{results_dir_name}",
-        f"-{benchmark_generator.prioritisation_type_file_prefix}{RANK_COMPARISON_FILE_SUFFIX}",
-    )
-    generate_plots(
-        [benchmarking_results],
-        benchmark_generator,
-        plot_type,
-    )
+    return f"{run_identifier_1}_vs_" f"{run_identifier_2}_" f"{output_prefix}_rank_comparison"


-def merge_results(result1: dict, result2: dict) -> defaultdict:
+def create_comparison_table(
+    comparison_table_name: str,
+    connector: BenchmarkDBManager,
+    drop_columns: List[str],
+    run_identifier_1: str,
+    run_identifier_2: str,
+    table_name: str,
+) -> None:
     """
-    Merge two nested dictionaries containing results on commonalities.
-
-    This function merges two dictionaries, `result1` and `result2`, containing nested structures.
-    It traverses the dictionaries recursively and merges their contents based on common keys.
-    If a key is present in both dictionaries and points to another dictionary, the function
-    will further merge their nested contents. If a key exists in `result2` but not in `result1`,
-    it will be added to `result1`.
-
+    Create rank comparison tables.
     Args:
-        result1 (dict): The first dictionary to be merged.
-        result2 (dict): The second dictionary to be merged.
-
-    Returns:
-        defaultdict: The merged dictionary containing the combined contents of `result1` and `result2`.
+        comparison_table_name (str): Name of the comparison table to create.
+        connector (BenchmarkDBManager): DBConnector instance.
+        drop_columns (List[str]): List of columns to drop.
+        run_identifier_1 (str): The first run identifier.
+        run_identifier_2 (str): The second run identifier.
+        table_name (str): Name of the table to extract ranks from
     """
-    for key, val in result1.items():
-        if type(val) == dict:
-            if key in result2 and type(result2[key] == dict):
-                merge_results(result1[key], result2[key])
-        else:
-            if key in result2:
-                result1[key] = result2[key]
+    connector.drop_table(comparison_table_name)
+    excluded_columns = tuple(drop_columns + ["identifier"]) if drop_columns else ("identifier",)
+    connector.conn.execute(
+        f'CREATE TABLE "{comparison_table_name}" AS SELECT * '
+        f"EXCLUDE {excluded_columns} FROM {table_name}"
+    )

-    for key, val in result2.items():
-        if key not in result1:
-            result1[key] = val
-    return result1
+    connector.conn.execute(
+        f"""ALTER TABLE "{comparison_table_name}" ADD COLUMN rank_change VARCHAR;"""
+    )
+    connector.conn.execute(
+        f'UPDATE "{comparison_table_name}" SET rank_change = CASE WHEN "{run_identifier_1}" = 0 '
+        f'AND "{run_identifier_2}" != 0 '
+        f"THEN 'GAINED' WHEN \"{run_identifier_1}\" != 0 AND \"{run_identifier_2}\" = 0 THEN 'LOST' ELSE "
+        f'CAST ("{run_identifier_1}" - "{run_identifier_2}" AS VARCHAR) END;'
+    )
+    connector.conn.commit()


 def generate_benchmark_comparison_output(
+    benchmark_name: str,
     benchmarking_results: List[BenchmarkRunResults],
-    plot_type: str,
+    run_identifiers: List[str],
     benchmark_generator: BenchmarkRunOutputGenerator,
+    table_name: str,
 ) -> None:
     """
     Generate prioritisation outputs for benchmarking multiple runs.
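
The pandas/NumPy rank-difference logic removed above is replaced by SQL executed through BenchmarkDBManager: the CASE expression labels a result 'GAINED' when it is unranked in the first run but ranked in the second, 'LOST' in the opposite case, and otherwise stores the signed rank difference as text. A standalone sketch of that logic against a toy in-memory DuckDB table (the connection, table, and column names here are hypothetical, not pheval's own):

    import duckdb

    conn = duckdb.connect()
    conn.execute("CREATE TABLE gene_rank (identifier VARCHAR, run_a INTEGER, run_b INTEGER)")
    conn.execute("INSERT INTO gene_rank VALUES ('case-1', 1, 3), ('case-2', 0, 2), ('case-3', 5, 0)")
    # Copy everything except the identifier column, as the EXCLUDE clause does above.
    conn.execute("CREATE TABLE comparison AS SELECT * EXCLUDE (identifier) FROM gene_rank")
    conn.execute("ALTER TABLE comparison ADD COLUMN rank_change VARCHAR")
    conn.execute(
        "UPDATE comparison SET rank_change = CASE "
        "WHEN run_a = 0 AND run_b != 0 THEN 'GAINED' "
        "WHEN run_a != 0 AND run_b = 0 THEN 'LOST' "
        "ELSE CAST(run_a - run_b AS VARCHAR) END"
    )
    print(conn.execute("SELECT * FROM comparison").fetchall())
    # Expected rows (order may vary): [(1, 3, '-2'), (0, 2, 'GAINED'), (5, 0, 'LOST')]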
@@ -147,29 +72,34 @@ def generate_benchmark_comparison_output
     comparison outputs using `RankComparisonGenerator` for each pair.

     Args:
+        benchmark_name (str): Name of the benchmark.
         benchmarking_results (List[BenchmarkRunResults]): A list containing BenchmarkRunResults instances
             representing the benchmarking results of multiple runs.
-        plot_type (str): The type of plot to be generated.
+        run_identifiers (List[str]): A list of run identifiers.
         benchmark_generator (BenchmarkRunOutputGenerator): Object containing benchmarking output generation details.
+        table_name (str): The name of the table where ranks are stored.
     """
-    output_prefix = benchmark_generator.prioritisation_type_file_prefix
-    for pair in itertools.combinations(benchmarking_results, 2):
-        result1 = pair[0]
-        result2 = pair[1]
-        merged_results = merge_results(
-            deepcopy(result1.ranks),
-            deepcopy(result2.ranks),
+    output_prefix = benchmark_generator.prioritisation_type_string
+    connector = BenchmarkDBManager(benchmark_name)
+    for pair in itertools.combinations(
+        [str(result.benchmark_name) for result in benchmarking_results], 2
+    ):
+        run_identifier_1 = pair[0]
+        run_identifier_2 = pair[1]
+        drop_columns = [run for run in run_identifiers if run not in pair]
+        comparison_table_name = get_new_table_name(
+            run_identifier_1, run_identifier_2, output_prefix
         )
-        RankComparisonGenerator(merged_results).generate_comparison_output(
-            f"{result1.results_dir.parents[0].name}_"
-            f"{result1.results_dir.name}"
-            f"_vs_{result2.results_dir.parents[0].name}_"
-            f"{result2.results_dir.name}",
-            f"-{output_prefix}{RANK_COMPARISON_FILE_SUFFIX}",
+        create_comparison_table(
+            comparison_table_name,
+            connector,
+            drop_columns,
+            run_identifier_1,
+            run_identifier_2,
+            table_name,
         )
-
     generate_plots(
+        benchmark_name,
         benchmarking_results,
         benchmark_generator,
-        plot_type,
     )
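
For a sense of how the pairwise loop above fans out, a small sketch of the comparison table names get_new_table_name would produce for three hypothetical run identifiers, assuming the gene prioritisation prefix is "gene":

    import itertools

    run_identifiers = ["run_a", "run_b", "run_c"]  # hypothetical run names
    for run_identifier_1, run_identifier_2 in itertools.combinations(run_identifiers, 2):
        print(f"{run_identifier_1}_vs_{run_identifier_2}_gene_rank_comparison")
    # run_a_vs_run_b_gene_rank_comparison
    # run_a_vs_run_c_gene_rank_comparison
    # run_b_vs_run_c_gene_rank_comparison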