pheval 0.4.6__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of pheval might be problematic.

Files changed (33)
  1. pheval/analyse/benchmark.py +156 -0
  2. pheval/analyse/benchmark_db_manager.py +16 -134
  3. pheval/analyse/benchmark_output_type.py +43 -0
  4. pheval/analyse/binary_classification_curves.py +132 -0
  5. pheval/analyse/binary_classification_stats.py +164 -307
  6. pheval/analyse/generate_plots.py +210 -395
  7. pheval/analyse/generate_rank_comparisons.py +44 -0
  8. pheval/analyse/rank_stats.py +190 -382
  9. pheval/analyse/run_data_parser.py +21 -39
  10. pheval/cli.py +28 -25
  11. pheval/cli_pheval_utils.py +7 -8
  12. pheval/post_processing/phenopacket_truth_set.py +235 -0
  13. pheval/post_processing/post_processing.py +183 -303
  14. pheval/post_processing/validate_result_format.py +92 -0
  15. pheval/prepare/update_phenopacket.py +11 -9
  16. pheval/utils/logger.py +35 -0
  17. pheval/utils/phenopacket_utils.py +85 -91
  18. {pheval-0.4.6.dist-info → pheval-0.5.0.dist-info}/METADATA +4 -4
  19. {pheval-0.4.6.dist-info → pheval-0.5.0.dist-info}/RECORD +22 -26
  20. {pheval-0.4.6.dist-info → pheval-0.5.0.dist-info}/WHEEL +1 -1
  21. pheval/analyse/analysis.py +0 -104
  22. pheval/analyse/assess_prioritisation_base.py +0 -108
  23. pheval/analyse/benchmark_generator.py +0 -126
  24. pheval/analyse/benchmarking_data.py +0 -25
  25. pheval/analyse/disease_prioritisation_analysis.py +0 -152
  26. pheval/analyse/gene_prioritisation_analysis.py +0 -147
  27. pheval/analyse/generate_summary_outputs.py +0 -105
  28. pheval/analyse/parse_benchmark_summary.py +0 -81
  29. pheval/analyse/parse_corpus.py +0 -219
  30. pheval/analyse/prioritisation_result_types.py +0 -52
  31. pheval/analyse/variant_prioritisation_analysis.py +0 -159
  32. {pheval-0.4.6.dist-info → pheval-0.5.0.dist-info}/LICENSE +0 -0
  33. {pheval-0.4.6.dist-info → pheval-0.5.0.dist-info}/entry_points.txt +0 -0
@@ -1,447 +1,255 @@
- from dataclasses import dataclass, field
- from statistics import mean
+ from dataclasses import dataclass
  from typing import List

  import numpy as np
- from duckdb import DuckDBPyConnection
+ import polars as pl
  from sklearn.metrics import ndcg_score

- from pheval.analyse.benchmark_db_manager import BenchmarkDBManager
- from pheval.analyse.binary_classification_stats import BinaryClassificationStats
+ from pheval.utils.logger import get_logger


- @dataclass
- class RankStats:
- """Store statistics related to ranking.
-
- Attributes:
- top (int): Count of top-ranked matches.
- top3 (int): Count of matches within the top 3 ranks.
- top5 (int): Count of matches within the top 5 ranks.
- top10 (int): Count of matches within the top 10 ranks.
- found (int): Count of found matches.
- total (int): Total count of matches.
- reciprocal_ranks (List[float]): List of reciprocal ranks.
- relevant_ranks List[List[int]]: Nested list of ranks for the known entities for all cases in a run.
- mrr (float): Mean Reciprocal Rank (MRR). Defaults to None.
+ @dataclass(frozen=True)
+ class Ranks:
+ """
+ Class for calculating ranking statistics.
  """

- top: int = 0
- top3: int = 0
- top5: int = 0
- top10: int = 0
- found: int = 0
- total: int = 0
- reciprocal_ranks: List = field(default_factory=list)
- relevant_result_ranks: List[List[int]] = field(default_factory=list)
- mrr: float = None
-
- def add_ranks(self, benchmark_name: str, table_name: str, column_name: str) -> None:
- """
- Add ranks to RankStats instance from table.
- Args:
- table_name (str): Name of the table to add ranks from.
- column_name (str): Name of the column to add ranks from.:
- """
- conn = BenchmarkDBManager(benchmark_name).conn
- self.top = self._execute_count_query(conn, table_name, column_name, " = 1")
- self.top3 = self._execute_count_query(conn, table_name, column_name, " BETWEEN 1 AND 3")
- self.top5 = self._execute_count_query(conn, table_name, column_name, " BETWEEN 1 AND 5")
- self.top10 = self._execute_count_query(conn, table_name, column_name, " BETWEEN 1 AND 10")
- self.found = self._execute_count_query(conn, table_name, column_name, " > 0")
- self.total = self._execute_count_query(conn, table_name, column_name, " >= 0")
- self.reciprocal_ranks = self._fetch_reciprocal_ranks(conn, table_name, column_name)
- self.relevant_result_ranks = self._fetch_relevant_ranks(conn, table_name, column_name)
- conn.close()
-
- @staticmethod
- def _execute_count_query(
- conn: DuckDBPyConnection, table_name: str, column_name: str, condition: str
- ) -> int:
- """
- Execute count query on table.
- Args:
- conn (DuckDBPyConnection): Connection to the database.
- table_name (str): Name of the table to execute count query on.
- column_name (str): Name of the column to execute count query on.
- condition (str): Condition to execute count query.
- Returns:
- int: Count query result.
- """
- query = f'SELECT COUNT(*) FROM "{table_name}" WHERE "{column_name}" {condition}'
- return conn.execute(query).fetchone()[0]
-
- @staticmethod
- def _fetch_reciprocal_ranks(
- conn: DuckDBPyConnection, table_name: str, column_name: str
- ) -> List[float]:
- """
- Fetch reciprocal ranks from table.
- Args:
- conn (DuckDBPyConnection): Connection to the database.
- table_name (str): Name of the table to fetch reciprocal ranks from.
- column_name (str): Name of the column to fetch reciprocal ranks from.
-
- Returns:
- List[float]: List of reciprocal ranks.
- """
- query = f'SELECT "{column_name}" FROM "{table_name}"'
- return [1 / rank[0] if rank[0] > 0 else 0 for rank in conn.execute(query).fetchall()]
+ TOP_1 = pl.col("rank").eq(1).sum().alias("top1")
+ TOP_3 = pl.col("rank").is_between(1, 3, closed="both").sum().alias("top3")
+ TOP_5 = pl.col("rank").is_between(1, 5, closed="both").sum().alias("top5")
+ TOP_10 = pl.col("rank").is_between(1, 10, closed="both").sum().alias("top10")
+ FOUND = pl.col("rank").gt(0).sum().alias("found")
+ TOTAL = pl.len().alias("total")
+ NUMBER_OF_SAMPLES = pl.col("file_path").n_unique().alias("number_of_samples")
+ MRR = ((1 / pl.col("rank").filter(pl.col("rank") > 0)).sum() / pl.len()).alias("mrr")

- @staticmethod
- def _fetch_relevant_ranks(
- conn: DuckDBPyConnection, table_name: str, column_name: str
- ) -> List[List[int]]:
+ @classmethod
+ def _filter_results(cls, df: pl.LazyFrame, k: int) -> pl.LazyFrame:
  """
- Fetch relevant ranks from table.
+ Filter for ranks within k.
  Args:
- conn (DuckDBPyConnection): Connection to the database.
- table_name (str): Name of the table to fetch relevant ranks from.
- column_name (str): Name of the column to fetch relevant ranks from.
+ df (pl.LazyFrame): The dataframe to filter.
+ k (int): The number upper rank limit.

  Returns:
- List[List[int]]: List of relevant ranks.
+ pl.LazyFrame: The filtered dataframe.
  """
- query = (
- f'SELECT LIST("{column_name}") as values_list FROM "{table_name}" GROUP BY phenopacket'
+ df = df.filter(pl.col("rank").is_between(1, k, closed="both"))
+ return df.group_by("file_path").agg(
+ pl.col("rank").sort().alias("ranks"),
  )
- return [rank[0] for rank in conn.execute(query).fetchall()]

- def percentage_rank(self, value: int) -> float:
+ @classmethod
+ def percentage_at_k(cls, k: int) -> pl.Expr:
  """
- Calculate the percentage rank.
-
+ Compute percentage at k dynamically.
  Args:
- value (int): The value for which the percentage rank needs to be calculated.
-
- Returns:
- float: The calculated percentage rank based on the provided value and the total count.
- """
- return 100 * value / self.total
-
- def percentage_top(self) -> float:
- """
- Calculate the percentage of top matches.
-
- Returns:
- float: The percentage of top matches compared to the total count.
- """
- return self.percentage_rank(self.top)
-
- def percentage_top3(self) -> float:
- """
- Calculate the percentage of matches within the top 3.
-
- Returns:
- float: The percentage of matches within the top 3 compared to the total count.
- """
- return self.percentage_rank(self.top3)
-
- def percentage_top5(self) -> float:
- """
- Calculate the percentage of matches within the top 5.
-
- Returns:
- float: The percentage of matches within the top 5 compared to the total count.
- """
- return self.percentage_rank(self.top5)
-
- def percentage_top10(self) -> float:
- """
- Calculate the percentage of matches within the top 10.
-
+ k (int): The upper rank limit.
  Returns:
- float: The percentage of matches within the top 10 compared to the total count.
+ pl.Expr: The expression for calculating percentage at k.
  """
- return self.percentage_rank(self.top10)
+ return (100 * pl.col(f"top{k}") / pl.col("total")).alias(f"percentage@{k}")

- def percentage_found(self) -> float:
+ @classmethod
+ def percentage_found(cls) -> pl.Expr:
  """
- Calculate the percentage of matches found.
-
+ Compute the percentage of found items.
  Returns:
- float: The percentage of matches found compared to the total count.
+ pl.Expr: The expression for calculating percentage of found items.
  """
- return self.percentage_rank(self.found)
+ return (100 * pl.col("found") / pl.col("total")).alias("percentage_found")

- @staticmethod
- def percentage_difference(percentage_value_1: float, percentage_value_2: float) -> float:
+ @classmethod
+ def precision_at_k(cls, k: int) -> pl.Expr:
  """
- Calculate the percentage difference between two percentage values.
-
+ Compute precision at k dynamically.
  Args:
- percentage_value_1 (float): The first percentage value.
- percentage_value_2 (float): The second percentage value.
-
+ k (int): The upper rank limit.
  Returns:
- float: The difference between the two percentage values.
+ pl.Expr: The expression for calculating precision at k.
  """
- return percentage_value_1 - percentage_value_2
+ return (pl.col(f"top{k}") / (pl.col("number_of_samples") * k)).alias(f"precision@{k}")

- def mean_reciprocal_rank(self) -> float:
+ @classmethod
+ def f_beta_score_at_k(cls, k: int) -> pl.Expr:
  """
- Calculate the Mean Reciprocal Rank (MRR) for the stored ranks.
-
- The Mean Reciprocal Rank is computed as the mean of the reciprocal ranks
- for the found cases.
-
- If the total number of cases differs from the number of found cases,
- this method extends the reciprocal ranks list with zeroes for missing cases.
-
- Returns:
- float: The calculated Mean Reciprocal Rank.
- """
- if len(self.reciprocal_ranks) != self.total:
- missing_cases = self.total - self.found
- self.reciprocal_ranks.extend([0] * missing_cases)
- return mean(self.reciprocal_ranks)
- return mean(self.reciprocal_ranks)
-
- def return_mean_reciprocal_rank(self) -> float:
- """
- Retrieve or calculate the Mean Reciprocal Rank (MRR).
-
- If a pre-calculated MRR value exists (stored in the 'mrr' attribute), this method returns that value.
- Otherwise, it computes the Mean Reciprocal Rank using the 'mean_reciprocal_rank' method.
-
- Returns:
- float: The Mean Reciprocal Rank value.
- """
- if self.mrr is not None:
- return self.mrr
- else:
- return self.mean_reciprocal_rank()
-
- def precision_at_k(self, k: int) -> float:
- """
- Calculate the precision at k.
- Precision at k is the ratio of relevant items in the top-k predictions to the total number of predictions.
- It measures the accuracy of the top-k predictions made by a model.
-
+ Compute f_beta_score at k.
  Args:
- k (int): The number of top predictions to consider.
-
+ k (int): The upper rank limit.
  Returns:
- float: The precision at k, ranging from 0.0 to 1.0.
- A higher precision indicates a better performance in identifying relevant items in the top-k predictions.
+ pl.Expr: The expression for calculating f_beta_score at k.
  """
- k_attr = getattr(self, f"top{k}") if k > 1 else self.top
- return k_attr / (self.total * k)
+ precision_expr = pl.col(f"top{k}") / (pl.col("number_of_samples") * k)
+ recall_expr = pl.col(f"top{k}") / pl.col("total")
+ return (
+ ((2 * precision_expr * recall_expr) / (precision_expr + recall_expr))
+ .fill_nan(0)
+ .alias(f"f_beta@{k}")
+ )

- @staticmethod
- def _average_precision_at_k(
- number_of_relevant_entities_at_k: int, precision_at_k: float
- ) -> float:
+ @classmethod
+ def _average_precision_at_k(cls, df: pl.LazyFrame, k: int) -> pl.LazyFrame:
  """
- Calculate the Average Precision at k.
+ Compute Average Precision at K (AP@K) for each query.

- Average Precision at k (AP@k) is a metric used to evaluate the precision of a ranked retrieval system.
- It measures the precision at each relevant position up to k and takes the average.
+ AP@K = (1 / min(k, R)) * sum(P(i) * rel(i)) for i ≤ k

  Args:
- number_of_relevant_entities_at_k (int): The count of relevant entities in the top-k predictions.
- precision_at_k (float): The precision at k - the sum of the precision values at each relevant position.
-
- Returns:
- float: The Average Precision at k, ranging from 0.0 to 1.0.
- A higher value indicates better precision in the top-k predictions.
- """
- return (
- (1 / number_of_relevant_entities_at_k) * precision_at_k
- if number_of_relevant_entities_at_k > 0
- else 0.0
+ df (pl.LazyFrame): The dataframe calculate AP@K for each query.
+ k (int): The upper rank limit.
+ Returns:
+ pl.LazyFrame: The dataframe with AP@K for each query.
+ """
+ filtered_df = cls._filter_results(df, k)
+ df_grouped = filtered_df.with_columns(
+ pl.struct("ranks")
+ .map_elements(
+ lambda row: cls._compute_ap_k(np.array(row["ranks"])), return_dtype=pl.Float64
+ )
+ .alias(f"ap@{k}")
  )
+ return df_grouped.select(["file_path", f"ap@{k}"])

- def mean_average_precision_at_k(self, k: int) -> float:
+ @staticmethod
+ def _compute_ap_k(ranks: np.array) -> np.floating:
  """
- Calculate the Mean Average Precision at k.
-
- Mean Average Precision at k (MAP@k) is a performance metric for ranked data.
- It calculates the average precision at k for each result rank and then takes the mean across all queries.
-
+ Helper function to compute AP@K for a single query.
  Args:
- k (int): The number of top predictions to consider for precision calculation.
-
+ ranks (np.array): The ranks to compute AP@K.
  Returns:
- float: The Mean Average Precision at k, ranging from 0.0 to 1.0.
- A higher value indicates better performance in ranking relevant entities higher in the predictions.
+ float: The AP@K.
  """
- cumulative_average_precision_scores = 0
- for result_ranks in self.relevant_result_ranks:
- precision_at_k, number_of_relevant_entities_at_k = 0, 0
- for rank in result_ranks:
- if 0 < rank <= k:
- number_of_relevant_entities_at_k += 1
- precision_at_k += number_of_relevant_entities_at_k / rank
- cumulative_average_precision_scores += self._average_precision_at_k(
- number_of_relevant_entities_at_k, precision_at_k
- )
- return (1 / self.total) * cumulative_average_precision_scores
+ num_relevant = np.arange(1, len(ranks) + 1)
+ precision_at_k = num_relevant / ranks
+ return np.mean(precision_at_k)

- def f_beta_score_at_k(self, percentage_at_k: float, k: int) -> float:
+ @classmethod
+ def mean_average_precision_at_k(cls, df: pl.LazyFrame, k: int) -> pl.LazyFrame:
  """
- Calculate the F-beta score at k.
-
- The F-beta score is a metric that combines precision and recall,
- with beta controlling the emphasis on precision.
- The Beta value is set to the value of 1 to allow for equal weighting for both precision and recall.
- This method computes the F-beta score at a specific percentage threshold within the top-k predictions.
-
+ Compute Mean Average Precision at K (MAP@K) by averaging AP@K scores.
  Args:
- percentage_at_k (float): The percentage of true positive predictions within the top-k.
- k (int): The number of top predictions to consider.
-
+ df (pl.LazyFrame): The dataframe calculate MAP@K for each query.
+ k (int): The upper rank limit.
  Returns:
- float: The F-beta score at k, ranging from 0.0 to 1.0.
- A higher score indicates better trade-off between precision and recall.
+ pl.LazyFrame: The dataframe with MAP@K for each query.
  """
- precision = self.precision_at_k(k)
- recall_at_k = percentage_at_k / 100
+ ap_at_k_df = cls._average_precision_at_k(df, k)
  return (
- (2 * precision * recall_at_k) / (precision + recall_at_k)
- if (precision + recall_at_k) > 0
- else 0
+ ap_at_k_df.select(
+ pl.col(f"ap@{k}").sum() / df.select(Ranks.NUMBER_OF_SAMPLES).collect()
+ )
+ .fill_null(0.0)
+ .collect()
+ .item()
  )

- def mean_normalised_discounted_cumulative_gain(self, k: int) -> float:
+ @classmethod
+ def _calculate_ndcg_at_k(cls, ranks: List[int], k: int) -> float:
  """
- Calculate the mean Normalised Discounted Cumulative Gain (NDCG) for a given rank cutoff.
-
- NDCG measures the effectiveness of a ranking by considering both the relevance and the order of items.
-
+ Compute NDCG@K for a single query.
  Args:
- k (int): The rank cutoff for calculating NDCG.
-
+ ranks (List[int]): The ranks to compute NDCG@K.
+ k (int): The upper rank limit.
  Returns:
- float: The mean NDCG score across all query results.
- """
- ndcg_scores = []
- for result_ranks in self.relevant_result_ranks:
- result_ranks = [rank for rank in result_ranks if rank <= k]
- result_ranks = [3 if i in result_ranks else 0 for i in range(k)]
- ideal_ranking = sorted(result_ranks, reverse=True)
- ndcg_scores.append(ndcg_score(np.asarray([ideal_ranking]), np.asarray([result_ranks])))
- return np.mean(ndcg_scores)
-
-
- class RankStatsWriter:
- """Class for writing the rank stats to a file."""
-
- def __init__(self, benchmark_name: str, table_name: str):
+ float: The NDCG@K.
  """
- Initialise the RankStatsWriter class
- Args:
- table_name (str): Name of table to add statistics.
- """
-
- self.table_name = table_name
- self.benchmark_name = benchmark_name
- conn = BenchmarkDBManager(benchmark_name).conn
- conn.execute(
- f'CREATE TABLE IF NOT EXISTS "{self.table_name}" ('
- f"results_directory_path VARCHAR,"
- f"top INT,"
- f"top3 INT,"
- f"top5 INT,"
- f"top10 INT,"
- f'"found" INT,'
- f"total INT,"
- f"mean_reciprocal_rank FLOAT,"
- f"percentage_top FLOAT,"
- f"percentage_top3 FLOAT,"
- f"percentage_top5 FLOAT,"
- f"percentage_top10 FLOAT,"
- f"percentage_found FLOAT,"
- f'"precision@1" FLOAT,'
- f'"precision@3" FLOAT,'
- f'"precision@5" FLOAT,'
- f'"precision@10" FLOAT,'
- f'"MAP@1" FLOAT,'
- f'"MAP@3" FLOAT,'
- f'"MAP@5" FLOAT,'
- f'"MAP@10" FLOAT,'
- f'"f_beta_score@1" FLOAT,'
- f'"f_beta_score@3"FLOAT,'
- f'"f_beta_score@5" FLOAT,'
- f'"f_beta_score@10" FLOAT,'
- f'"NDCG@3" FLOAT,'
- f'"NDCG@5" FLOAT,'
- f'"NDCG@10" FLOAT,'
- f"true_positives INT,"
- f"false_positives INT,"
- f"true_negatives INT,"
- f"false_negatives INT,"
- f"sensitivity FLOAT,"
- f"specificity FLOAT,"
- f'"precision" FLOAT,'
- f"negative_predictive_value FLOAT,"
- f"false_positive_rate FLOAT,"
- f"false_discovery_rate FLOAT,"
- f"false_negative_rate FLOAT,"
- f"accuracy FLOAT,"
- f"f1_score FLOAT,"
- f"matthews_correlation_coefficient FLOAT, )"
+ result_ranks = np.zeros(k, dtype=int)
+ indices = np.array(ranks) - 1
+ valid_indices = indices[(indices >= 0) & (indices < k)]
+ result_ranks[valid_indices] = 3
+ ideal_ranking = np.sort(result_ranks)[::-1]
+ return (
+ ndcg_score(result_ranks.reshape(1, -1), ideal_ranking.reshape(1, -1))
+ if np.sum(result_ranks) > 0
+ else 0.0
  )
- conn.close()

- def add_statistics_entry(
- self,
- run_identifier: str,
- rank_stats: RankStats,
- binary_classification: BinaryClassificationStats,
- ):
+ @classmethod
+ def mean_normalised_discounted_cumulative_gain(cls, df: pl.LazyFrame, k: int) -> pl.Float64:
  """
- Add statistics row to table for a run.
+ Compute mean normalised discounted cumulative gain.
  Args:
- run_identifier (str): The run identifier.
- rank_stats (RankStats): RankStats object for the run.
- binary_classification (BinaryClassificationStats): BinaryClassificationStats object for the run.
+ df (pl.LazyFrame): The dataframe to calculate mean normalised cumulative gain.
+ k (int): The upper rank limit.
+ Returns:
+ pl.LazyFrame: The dataframe with mean normalised cumulative gain.
  """
- conn = BenchmarkDBManager(self.benchmark_name).conn
- conn.execute(
- f' INSERT INTO "{self.table_name}" VALUES ( '
- f"'{run_identifier}',"
- f"{rank_stats.top},"
- f"{rank_stats.top3},"
- f"{rank_stats.top5},"
- f"{rank_stats.top10},"
- f"{rank_stats.found},"
- f"{rank_stats.total},"
- f"{rank_stats.mean_reciprocal_rank()},"
- f"{rank_stats.percentage_top()},"
- f"{rank_stats.percentage_top3()},"
- f"{rank_stats.percentage_top5()},"
- f"{rank_stats.percentage_top10()},"
- f"{rank_stats.percentage_found()},"
- f"{rank_stats.precision_at_k(1)},"
- f"{rank_stats.precision_at_k(3)},"
- f"{rank_stats.precision_at_k(5)},"
- f"{rank_stats.precision_at_k(10)},"
- f"{rank_stats.mean_average_precision_at_k(1)},"
- f"{rank_stats.mean_average_precision_at_k(3)},"
- f"{rank_stats.mean_average_precision_at_k(5)},"
- f"{rank_stats.mean_average_precision_at_k(10)},"
- f"{rank_stats.f_beta_score_at_k(rank_stats.percentage_top(), 1)},"
- f"{rank_stats.f_beta_score_at_k(rank_stats.percentage_top(), 3)},"
- f"{rank_stats.f_beta_score_at_k(rank_stats.percentage_top(), 5)},"
- f"{rank_stats.f_beta_score_at_k(rank_stats.percentage_top(), 10)},"
- f"{rank_stats.mean_normalised_discounted_cumulative_gain(3)},"
- f"{rank_stats.mean_normalised_discounted_cumulative_gain(5)},"
- f"{rank_stats.mean_normalised_discounted_cumulative_gain(10)},"
- f"{binary_classification.true_positives},"
- f"{binary_classification.false_positives},"
- f"{binary_classification.true_negatives},"
- f"{binary_classification.false_negatives},"
- f"{binary_classification.sensitivity()},"
- f"{binary_classification.specificity()},"
- f"{binary_classification.precision()},"
- f"{binary_classification.negative_predictive_value()},"
- f"{binary_classification.false_positive_rate()},"
- f"{binary_classification.false_discovery_rate()},"
- f"{binary_classification.false_negative_rate()},"
- f"{binary_classification.accuracy()},"
- f"{binary_classification.f1_score()},"
- f"{binary_classification.matthews_correlation_coefficient()})"
+ filtered_df = cls._filter_results(df, k)
+ return (
+ filtered_df.with_columns(
+ pl.struct("ranks")
+ .map_elements(
+ lambda row: cls._calculate_ndcg_at_k(row["ranks"], k), return_dtype=pl.Float64
+ )
+ .alias(f"NDCG@{k}")
+ )
+ .select(pl.col(f"NDCG@{k}").sum() / df.select(Ranks.NUMBER_OF_SAMPLES).collect())
+ .fill_null(0.0)
+ .collect()
+ .item()
  )

- conn.close()
+
+ def compute_rank_stats(run_identifier: str, result_scan: pl.LazyFrame) -> pl.LazyFrame:
+ """
+ Computes ranking statistics for a given benchmarking run.
+ Args:
+ run_identifier (str): The identifier of the benchmarking run.
+ result_scan (pl.LazyFrame): The scan of the directory to compute ranking statistics for.
+ """
+ logger = get_logger()
+ logger.info(f"Generating ranking statistics for {run_identifier}...")
+ true_positive_scan = result_scan.filter(pl.col("true_positive"))
+ rankings = true_positive_scan.select(
+ [
+ pl.lit(run_identifier).alias("run_identifier"),
+ Ranks.TOP_1.alias("top1"),
+ Ranks.TOP_3.alias("top3"),
+ Ranks.TOP_5.alias("top5"),
+ Ranks.TOP_10.alias("top10"),
+ Ranks.FOUND.alias("found"),
+ Ranks.TOTAL.alias("total"),
+ Ranks.NUMBER_OF_SAMPLES.alias("number_of_samples"),
+ Ranks.MRR.alias("mrr"),
+ ]
+ )
+
+ return rankings.select(
+ [
+ pl.col("run_identifier"),
+ pl.col("top1"),
+ pl.col("top3"),
+ pl.col("top5"),
+ pl.col("top10"),
+ pl.col("found"),
+ pl.col("total"),
+ pl.col("number_of_samples"),
+ pl.col("mrr"),
+ Ranks.percentage_at_k(1),
+ Ranks.percentage_at_k(3),
+ Ranks.percentage_at_k(5),
+ Ranks.percentage_at_k(10),
+ Ranks.percentage_found(),
+ Ranks.precision_at_k(1),
+ Ranks.precision_at_k(3),
+ Ranks.precision_at_k(5),
+ Ranks.precision_at_k(10),
+ Ranks.f_beta_score_at_k(1),
+ Ranks.f_beta_score_at_k(3),
+ Ranks.f_beta_score_at_k(5),
+ Ranks.f_beta_score_at_k(10),
+ pl.lit(Ranks.mean_average_precision_at_k(true_positive_scan, 1)).alias("MAP@1"),
+ pl.lit(Ranks.mean_average_precision_at_k(true_positive_scan, 3)).alias("MAP@3"),
+ pl.lit(Ranks.mean_average_precision_at_k(true_positive_scan, 5)).alias("MAP@5"),
+ pl.lit(Ranks.mean_average_precision_at_k(true_positive_scan, 10)).alias("MAP@10"),
+ pl.lit(Ranks.mean_normalised_discounted_cumulative_gain(true_positive_scan, 3)).alias(
+ "NDCG@3"
+ ),
+ pl.lit(Ranks.mean_normalised_discounted_cumulative_gain(true_positive_scan, 5)).alias(
+ "NDCG@5"
+ ),
+ pl.lit(Ranks.mean_normalised_discounted_cumulative_gain(true_positive_scan, 10)).alias(
+ "NDCG@10"
+ ),
+ ]
+ )
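
For orientation, here is a minimal sketch of how the new polars-based statistics above could be exercised. It assumes pheval 0.5.0 (and its polars dependency) is installed and that the module keeps the pheval.analyse.rank_stats path and the "file_path"/"rank"/"true_positive" columns shown in the diff; the toy frame and the values quoted in the comments are illustrative only, not taken from the package.

import polars as pl

from pheval.analyse.rank_stats import Ranks

# Toy results frame (illustrative): real frames are produced by PhEval's result
# parsing and carry at least the columns used by the expressions above.
toy_results = pl.LazyFrame(
    {
        "file_path": ["case_1.json", "case_1.json", "case_2.json", "case_3.json"],
        "rank": [1, 7, 4, 0],  # 0 = the known entity was not returned at all
        "true_positive": [True, True, True, True],
    }
)

# The Ranks attributes are plain polars expressions, so the raw counts can be
# aggregated with an ordinary select over the true-positive rows.
counts = (
    toy_results.filter(pl.col("true_positive"))
    .select(
        [Ranks.TOP_1, Ranks.TOP_3, Ranks.TOP_10, Ranks.FOUND, Ranks.TOTAL,
         Ranks.NUMBER_OF_SAMPLES, Ranks.MRR]
    )
    .collect()
)
print(counts)  # top1=1, top3=1, top10=3, found=3, total=4, number_of_samples=3, mrr≈0.348

# The derived metrics are expressions over those counts, e.g.
# percentage@1 = 100 * top1 / total and precision@1 = top1 / (number_of_samples * 1).
print(
    counts.select([Ranks.percentage_at_k(1), Ranks.percentage_found(), Ranks.precision_at_k(1)])
)  # percentage@1=25.0, percentage_found=75.0, precision@1≈0.333

The module-level compute_rank_stats() at the end of the diff bundles these expressions, together with the MAP@k and NDCG@k helpers, into a single per-run summary frame.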