pheval 0.3.9__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.

Potentially problematic release: this version of pheval might be problematic.

@@ -1,12 +1,12 @@
-import csv
 from dataclasses import dataclass, field
-from pathlib import Path
 from statistics import mean
 from typing import List

 import numpy as np
+from duckdb import DuckDBPyConnection
 from sklearn.metrics import ndcg_score

+from pheval.analyse.benchmark_db_manager import BenchmarkDBManager
 from pheval.analyse.binary_classification_stats import BinaryClassificationStats


@@ -36,29 +36,76 @@ class RankStats:
     relevant_result_ranks: List[List[int]] = field(default_factory=list)
     mrr: float = None

-    def add_rank(self, rank: int) -> None:
+    def add_ranks(self, benchmark_name: str, table_name: str, column_name: str) -> None:
         """
-        Add rank for matched result.
+        Add ranks to RankStats instance from table.
+        Args:
+            table_name (str): Name of the table to add ranks from.
+            column_name (str): Name of the column to add ranks from.:
+        """
+        conn = BenchmarkDBManager(benchmark_name).conn
+        self.top = self._execute_count_query(conn, table_name, column_name, " = 1")
+        self.top3 = self._execute_count_query(conn, table_name, column_name, " BETWEEN 1 AND 3")
+        self.top5 = self._execute_count_query(conn, table_name, column_name, " BETWEEN 1 AND 5")
+        self.top10 = self._execute_count_query(conn, table_name, column_name, " BETWEEN 1 AND 10")
+        self.found = self._execute_count_query(conn, table_name, column_name, " > 0")
+        self.total = self._execute_count_query(conn, table_name, column_name, " >= 0")
+        self.reciprocal_ranks = self._fetch_reciprocal_ranks(conn, table_name, column_name)
+        self.relevant_result_ranks = self._fetch_relevant_ranks(conn, table_name, column_name)
+        conn.close()
+
+    @staticmethod
+    def _execute_count_query(
+        conn: DuckDBPyConnection, table_name: str, column_name: str, condition: str
+    ) -> int:
+        """
+        Execute count query on table.
+        Args:
+            conn (DuckDBPyConnection): Connection to the database.
+            table_name (str): Name of the table to execute count query on.
+            column_name (str): Name of the column to execute count query on.
+            condition (str): Condition to execute count query.
+        Returns:
+            int: Count query result.
+        """
+        query = f'SELECT COUNT(*) FROM {table_name} WHERE "{column_name}" {condition}'
+        return conn.execute(query).fetchone()[0]

+    @staticmethod
+    def _fetch_reciprocal_ranks(
+        conn: DuckDBPyConnection, table_name: str, column_name: str
+    ) -> List[float]:
+        """
+        Fetch reciprocal ranks from table.
         Args:
-            rank (int): The rank value to be added.
-
-        Notes:
-            This method updates the internal attributes of the RankStats object based on the provided rank value.
-            It calculates various statistics such as the count of top ranks (1, 3, 5, and 10),
-            the total number of ranks found,and the reciprocal rank.
-            This function modifies the object's state by updating the internal attributes.
-        """
-        self.reciprocal_ranks.append(1 / rank)
-        self.found += 1
-        if rank == 1:
-            self.top += 1
-        if rank != "" and rank <= 3:
-            self.top3 += 1
-        if rank != "" and rank <= 5:
-            self.top5 += 1
-        if rank != "" and rank <= 10:
-            self.top10 += 1
+            conn (DuckDBPyConnection): Connection to the database.
+            table_name (str): Name of the table to fetch reciprocal ranks from.
+            column_name (str): Name of the column to fetch reciprocal ranks from.
+
+        Returns:
+            List[float]: List of reciprocal ranks.
+        """
+        query = f'SELECT "{column_name}" FROM {table_name}'
+        return [1 / rank[0] if rank[0] > 0 else 0 for rank in conn.execute(query).fetchall()]
+
+    @staticmethod
+    def _fetch_relevant_ranks(
+        conn: DuckDBPyConnection, table_name: str, column_name: str
+    ) -> List[List[int]]:
+        """
+        Fetch relevant ranks from table.
+        Args:
+            conn (DuckDBPyConnection): Connection to the database.
+            table_name (str): Name of the table to fetch relevant ranks from.
+            column_name (str): Name of the column to fetch relevant ranks from.
+
+        Returns:
+            List[List[int]]: List of relevant ranks.
+        """
+        query = (
+            f'SELECT LIST("{column_name}") as values_list FROM {table_name} GROUP BY phenopacket'
+        )
+        return [rank[0] for rank in conn.execute(query).fetchall()]

     def percentage_rank(self, value: int) -> float:
         """
@@ -280,135 +327,121 @@ class RankStats:
 class RankStatsWriter:
     """Class for writing the rank stats to a file."""

-    def __init__(self, file: Path):
+    def __init__(self, benchmark_name: str, table_name: str):
         """
         Initialise the RankStatsWriter class
         Args:
-            file (Path): Path to the file where rank stats will be written
-        """
-        self.file = open(file, "w")
-        self.writer = csv.writer(self.file, delimiter="\t")
-        self.writer.writerow(
-            [
-                "results_directory_path",
-                "top",
-                "top3",
-                "top5",
-                "top10",
-                "found",
-                "total",
-                "mean_reciprocal_rank",
-                "percentage_top",
-                "percentage_top3",
-                "percentage_top5",
-                "percentage_top10",
-                "percentage_found",
-                "precision@1",
-                "precision@3",
-                "precision@5",
-                "precision@10",
-                "MAP@1",
-                "MAP@3",
-                "MAP@5",
-                "MAP@10",
-                "f_beta_score@1",
-                "f_beta_score@3",
-                "f_beta_score@5",
-                "f_beta_score@10",
-                "NDCG@3",
-                "NDCG@5",
-                "NDCG@10",
-                "true_positives",
-                "false_positives",
-                "true_negatives",
-                "false_negatives",
-                "sensitivity",
-                "specificity",
-                "precision",
-                "negative_predictive_value",
-                "false_positive_rate",
-                "false_discovery_rate",
-                "false_negative_rate",
-                "accuracy",
-                "f1_score",
-                "matthews_correlation_coefficient",
-            ]
+            table_name (str): Name of table to add statistics.
+        """
+
+        self.table_name = table_name
+        self.benchmark_name = benchmark_name
+        conn = BenchmarkDBManager(benchmark_name).conn
+        conn.execute(
+            f'CREATE TABLE IF NOT EXISTS "{self.table_name}" ('
+            f"results_directory_path VARCHAR,"
+            f"top INT,"
+            f"top3 INT,"
+            f"top5 INT,"
+            f"top10 INT,"
+            f'"found" INT,'
+            f"total INT,"
+            f"mean_reciprocal_rank FLOAT,"
+            f"percentage_top FLOAT,"
+            f"percentage_top3 FLOAT,"
+            f"percentage_top5 FLOAT,"
+            f"percentage_top10 FLOAT,"
+            f"percentage_found FLOAT,"
+            f'"precision@1" FLOAT,'
+            f'"precision@3" FLOAT,'
+            f'"precision@5" FLOAT,'
+            f'"precision@10" FLOAT,'
+            f'"MAP@1" FLOAT,'
+            f'"MAP@3" FLOAT,'
+            f'"MAP@5" FLOAT,'
+            f'"MAP@10" FLOAT,'
+            f'"f_beta_score@1" FLOAT,'
+            f'"f_beta_score@3"FLOAT,'
+            f'"f_beta_score@5" FLOAT,'
+            f'"f_beta_score@10" FLOAT,'
+            f'"NDCG@3" FLOAT,'
+            f'"NDCG@5" FLOAT,'
+            f'"NDCG@10" FLOAT,'
+            f"true_positives INT,"
+            f"false_positives INT,"
+            f"true_negatives INT,"
+            f"false_negatives INT,"
+            f"sensitivity FLOAT,"
+            f"specificity FLOAT,"
+            f'"precision" FLOAT,'
+            f"negative_predictive_value FLOAT,"
+            f"false_positive_rate FLOAT,"
+            f"false_discovery_rate FLOAT,"
+            f"false_negative_rate FLOAT,"
+            f"accuracy FLOAT,"
+            f"f1_score FLOAT,"
+            f"matthews_correlation_coefficient FLOAT, )"
         )
+        conn.close()

-    def write_row(
+    def add_statistics_entry(
         self,
-        directory: Path,
+        run_identifier: str,
         rank_stats: RankStats,
         binary_classification: BinaryClassificationStats,
-    ) -> None:
+    ):
         """
-        Write summary rank statistics row for a run to the file.
-
+        Add statistics row to table for a run.
         Args:
-            directory (Path): Path to the results directory corresponding to the run
-            rank_stats (RankStats): RankStats instance containing rank statistics corresponding to the run
-
-        Raises:
-            IOError: If there is an error writing to the file.
-        """
-        try:
-            self.writer.writerow(
-                [
-                    directory,
-                    rank_stats.top,
-                    rank_stats.top3,
-                    rank_stats.top5,
-                    rank_stats.top10,
-                    rank_stats.found,
-                    rank_stats.total,
-                    rank_stats.mean_reciprocal_rank(),
-                    rank_stats.percentage_top(),
-                    rank_stats.percentage_top3(),
-                    rank_stats.percentage_top5(),
-                    rank_stats.percentage_top10(),
-                    rank_stats.percentage_found(),
-                    rank_stats.precision_at_k(1),
-                    rank_stats.precision_at_k(3),
-                    rank_stats.precision_at_k(5),
-                    rank_stats.precision_at_k(10),
-                    rank_stats.mean_average_precision_at_k(1),
-                    rank_stats.mean_average_precision_at_k(3),
-                    rank_stats.mean_average_precision_at_k(5),
-                    rank_stats.mean_average_precision_at_k(10),
-                    rank_stats.f_beta_score_at_k(rank_stats.percentage_top(), 1),
-                    rank_stats.f_beta_score_at_k(rank_stats.percentage_top3(), 3),
-                    rank_stats.f_beta_score_at_k(rank_stats.percentage_top5(), 5),
-                    rank_stats.f_beta_score_at_k(rank_stats.percentage_top10(), 10),
-                    rank_stats.mean_normalised_discounted_cumulative_gain(3),
-                    rank_stats.mean_normalised_discounted_cumulative_gain(5),
-                    rank_stats.mean_normalised_discounted_cumulative_gain(10),
-                    binary_classification.true_positives,
-                    binary_classification.false_positives,
-                    binary_classification.true_negatives,
-                    binary_classification.false_negatives,
-                    binary_classification.sensitivity(),
-                    binary_classification.specificity(),
-                    binary_classification.precision(),
-                    binary_classification.negative_predictive_value(),
-                    binary_classification.false_positive_rate(),
-                    binary_classification.false_discovery_rate(),
-                    binary_classification.false_negative_rate(),
-                    binary_classification.accuracy(),
-                    binary_classification.f1_score(),
-                    binary_classification.matthews_correlation_coefficient(),
-                ]
-            )
-        except IOError:
-            print("Error writing ", self.file)
-
-    def close(self) -> None:
-        """
-        Close the file used for writing rank statistics.
-
-        Raises:
-            IOError: If there's an error while closing the file.
-        """
-        try:
-            self.file.close()
-        except IOError:
-            print("Error closing ", self.file)
+            run_identifier (str): The run identifier.
+            rank_stats (RankStats): RankStats object for the run.
+            binary_classification (BinaryClassificationStats): BinaryClassificationStats object for the run.
+        """
+        conn = BenchmarkDBManager(self.benchmark_name).conn
+        conn.execute(
+            f' INSERT INTO "{self.table_name}" VALUES ( '
+            f"'{run_identifier}',"
+            f"{rank_stats.top},"
+            f"{rank_stats.top3},"
+            f"{rank_stats.top5},"
+            f"{rank_stats.top10},"
+            f"{rank_stats.found},"
+            f"{rank_stats.total},"
+            f"{rank_stats.mean_reciprocal_rank()},"
+            f"{rank_stats.percentage_top()},"
+            f"{rank_stats.percentage_top3()},"
+            f"{rank_stats.percentage_top5()},"
+            f"{rank_stats.percentage_top10()},"
+            f"{rank_stats.percentage_found()},"
+            f"{rank_stats.precision_at_k(1)},"
+            f"{rank_stats.precision_at_k(3)},"
+            f"{rank_stats.precision_at_k(5)},"
+            f"{rank_stats.precision_at_k(10)},"
+            f"{rank_stats.mean_average_precision_at_k(1)},"
+            f"{rank_stats.mean_average_precision_at_k(3)},"
+            f"{rank_stats.mean_average_precision_at_k(5)},"
+            f"{rank_stats.mean_average_precision_at_k(10)},"
+            f"{rank_stats.f_beta_score_at_k(rank_stats.percentage_top(), 1)},"
+            f"{rank_stats.f_beta_score_at_k(rank_stats.percentage_top(), 3)},"
+            f"{rank_stats.f_beta_score_at_k(rank_stats.percentage_top(), 5)},"
+            f"{rank_stats.f_beta_score_at_k(rank_stats.percentage_top(), 10)},"
+            f"{rank_stats.mean_normalised_discounted_cumulative_gain(3)},"
+            f"{rank_stats.mean_normalised_discounted_cumulative_gain(5)},"
+            f"{rank_stats.mean_normalised_discounted_cumulative_gain(10)},"
+            f"{binary_classification.true_positives},"
+            f"{binary_classification.false_positives},"
+            f"{binary_classification.true_negatives},"
+            f"{binary_classification.false_negatives},"
+            f"{binary_classification.sensitivity()},"
+            f"{binary_classification.specificity()},"
+            f"{binary_classification.precision()},"
+            f"{binary_classification.negative_predictive_value()},"
+            f"{binary_classification.false_positive_rate()},"
+            f"{binary_classification.false_discovery_rate()},"
+            f"{binary_classification.false_negative_rate()},"
+            f"{binary_classification.accuracy()},"
+            f"{binary_classification.f1_score()},"
+            f"{binary_classification.matthews_correlation_coefficient()})"
+        )
+
+        conn.close()
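Summary statistics likewise move from a TSV file to a DuckDB table: the writer now creates the table in `__init__` and `add_statistics_entry` inserts one row per run. A rough sketch of how the new interface fits together, with the same caveats: names are placeholders, the `RankStatsWriter`/`RankStats` import path is assumed (only the `BinaryClassificationStats` import is shown in this diff), and `BinaryClassificationStats()` is assumed to default-construct with zeroed counts.

```python
from pheval.analyse.binary_classification_stats import BinaryClassificationStats
from pheval.analyse.rank_stats import RankStats, RankStatsWriter  # path assumed

# Creates the summary table in the benchmark DB if it does not already exist.
writer = RankStatsWriter(
    benchmark_name="example_benchmark",   # hypothetical benchmark DB name
    table_name="gene_summary",            # hypothetical summary table
)

rank_stats = RankStats()
rank_stats.add_ranks("example_benchmark", "run_1_gene_rank_results", "rank")

# Inserts one row of rank and classification statistics for this run.
writer.add_statistics_entry(
    run_identifier="run_1",               # hypothetical run identifier
    rank_stats=rank_stats,
    binary_classification=BinaryClassificationStats(),  # assumed zeroed defaults
)
```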
@@ -1,44 +1,125 @@
-from dataclasses import dataclass
 from pathlib import Path
-from typing import List
+from typing import List, Optional

-import pandas as pd
+import yaml
+from pydantic import BaseModel, root_validator


-@dataclass
-class TrackInputOutputDirectories:
+class RunConfig(BaseModel):
     """
-    Track the input phenopacket test data for a corresponding pheval output directory.
+    Store configurations for a run.

     Attributes:
-        phenopacket_dir (Path): The directory containing input phenopackets.
-        results_dir (Path): The directory containing output results from pheval.
+        run_identifier (str): The run identifier.
+        phenopacket_dir (str): The path to the phenopacket directory used for generating the results.
+        results_dir (str): The path to the results directory.
+        gene_analysis (bool): Whether or not to benchmark gene analysis results.
+        variant_analysis (bool): Whether or not to benchmark variant analysis results.
+        disease_analysis (bool): Whether or not to benchmark disease analysis results.
+        threshold (Optional[float]): The threshold to consider for benchmarking.
+        score_order (Optional[str]): The order of scores to consider for benchmarking, either ascending or descending.
     """

+    run_identifier: str
     phenopacket_dir: Path
     results_dir: Path
+    gene_analysis: bool
+    variant_analysis: bool
+    disease_analysis: bool
+    threshold: Optional[float]
+    score_order: Optional[str]

+    @root_validator(pre=True)
+    def handle_blank_fields(cls, values: dict) -> dict:  # noqa: N805
+        """
+        Root validator to handle fields that may be explicitly set to None.

-def parse_run_data_text_file(run_data_path: Path) -> List[TrackInputOutputDirectories]:
+        This method checks if 'threshold' and 'score_order' are None and assigns default values if so.
+
+        Args:
+            values (dict): The input values provided to the model.
+
+        Returns:
+            dict: The updated values with defaults applied where necessary.
+        """
+        if values.get("threshold") is None:
+            values["threshold"] = 0
+            print("setting default threshold")
+        if values.get("score_order") is None:
+            values["score_order"] = "descending"
+        return values
+
+
+class SinglePlotCustomisation(BaseModel):
     """
-    Parse run data .txt file returning a list of input phenopacket and corresponding output directories.
+    Store customisations for plots.
+
+    Attributes:
+        plot_type (str): The plot type.
+        rank_plot_title (str): The title for the rank summary plot.
+        roc_curve_title (str): The title for the roc curve plot.
+        precision_recall_title (str): The title for the precision-recall plot.
+    """
+
+    plot_type: Optional[str] = "bar_cumulative"
+    rank_plot_title: Optional[str]
+    roc_curve_title: Optional[str]
+    precision_recall_title: Optional[str]
+
+    @root_validator(pre=True)
+    def handle_blank_fields(cls, values: dict) -> dict:  # noqa: N805
+        """
+        Root validator to handle fields that may be explicitly set to None.
+
+        This method checks if 'plot_type' is None and assigns default value if so.
+
+        Args:
+            values (dict): The input values provided to the model.
+
+        Returns:
+            dict: The updated values with defaults applied where necessary.
+        """
+        if values.get("plot_type") is None:
+            values["plot_type"] = "bar_cumulative"
+        return values
+
+
+class PlotCustomisation(BaseModel):
+    """
+    Store customisations for all plots.
+    Attributes:
+        gene_plots (SinglePlotCustomisation): Customisation for all gene benchmarking plots.
+        disease_plots (SinglePlotCustomisation): Customisation for all disease benchmarking plots.
+        variant_plots (SinglePlotCustomisation): Customisation for all variant benchmarking plots.
+    """
+
+    gene_plots: SinglePlotCustomisation
+    disease_plots: SinglePlotCustomisation
+    variant_plots: SinglePlotCustomisation
+
+
+class Config(BaseModel):
+    """
+    Store configurations for a runs.
+    Attributes:
+        runs (List[RunConfig]): The list of run configurations.
+    """
+
+    benchmark_name: str
+    runs: List[RunConfig]
+    plot_customisation: PlotCustomisation

-    Args:
-        run_data_path (Path): The path to the run data .txt file.

+def parse_run_config(run_config: Path) -> Config:
+    """
+    Parse a run configuration yaml file.
+    Args:
+        run_config (Path): The path to the run data yaml configuration.
     Returns:
-        List[TrackInputOutputDirectories]: A list of TrackInputOutputDirectories objects, containing
-        input test data directories and their corresponding output directories.
-
-    Notes:
-        The run data .txt file should be formatted with tab-separated values. Each row should contain
-        two columns: the first column representing the input test data phenopacket directory, and
-        the second column representing the corresponding run output directory.
-    """
-    run_data = pd.read_csv(run_data_path, delimiter="\t", header=None)
-    run_data_list = []
-    for _index, row in run_data.iterrows():
-        run_data_list.append(
-            TrackInputOutputDirectories(phenopacket_dir=Path(row[0]), results_dir=Path(row[1]))
-        )
-    return run_data_list
+        Config: The parsed run configurations.
+    """
+    with open(run_config, "r") as f:
+        config_data = yaml.safe_load(f)
+    f.close()
+    config = Config(**config_data)
+    return config
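The new pydantic models replace the two-column run data .txt file with a YAML configuration. Below is a minimal sketch of a config that `parse_run_config` should accept, built only from the fields declared above; the file name, paths, and plot titles are illustrative placeholders, and the parser's module path is an assumption not shown in this diff.

```python
from pathlib import Path

from pheval.analyse.run_data_parser import parse_run_config  # module path assumed

# Placeholder configuration mirroring the models above: one run plus plot
# customisation for each analysis type. Blank values exercise the root
# validators, which fill in threshold=0, score_order="descending" and
# plot_type="bar_cumulative".
config_yaml = """\
benchmark_name: example_benchmark
runs:
  - run_identifier: run_1
    phenopacket_dir: /data/phenopackets
    results_dir: /data/results/run_1
    gene_analysis: true
    variant_analysis: false
    disease_analysis: false
    threshold:
    score_order:
plot_customisation:
  gene_plots:
    plot_type: bar_cumulative
    rank_plot_title: Gene rank comparison
    roc_curve_title: Gene ROC curve
    precision_recall_title: Gene precision-recall
  variant_plots:
    plot_type:
    rank_plot_title:
    roc_curve_title:
    precision_recall_title:
  disease_plots:
    plot_type:
    rank_plot_title:
    roc_curve_title:
    precision_recall_title:
"""

config_path = Path("benchmark_config.yaml")  # hypothetical file name
config_path.write_text(config_yaml)
config = parse_run_config(config_path)
print(config.runs[0].threshold)                            # 0, set by RunConfig.handle_blank_fields
print(config.plot_customisation.variant_plots.plot_type)   # "bar_cumulative"
```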