pheval 0.4.7__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of pheval might be problematic.
Files changed (33)
  1. pheval/analyse/benchmark.py +156 -0
  2. pheval/analyse/benchmark_db_manager.py +16 -134
  3. pheval/analyse/benchmark_output_type.py +43 -0
  4. pheval/analyse/binary_classification_curves.py +132 -0
  5. pheval/analyse/binary_classification_stats.py +164 -307
  6. pheval/analyse/generate_plots.py +210 -395
  7. pheval/analyse/generate_rank_comparisons.py +44 -0
  8. pheval/analyse/rank_stats.py +190 -382
  9. pheval/analyse/run_data_parser.py +21 -39
  10. pheval/cli.py +27 -24
  11. pheval/cli_pheval_utils.py +7 -8
  12. pheval/post_processing/phenopacket_truth_set.py +250 -0
  13. pheval/post_processing/post_processing.py +179 -345
  14. pheval/post_processing/validate_result_format.py +91 -0
  15. pheval/prepare/update_phenopacket.py +11 -9
  16. pheval/utils/logger.py +35 -0
  17. pheval/utils/phenopacket_utils.py +85 -91
  18. {pheval-0.4.7.dist-info → pheval-0.5.1.dist-info}/METADATA +4 -4
  19. {pheval-0.4.7.dist-info → pheval-0.5.1.dist-info}/RECORD +22 -26
  20. pheval/analyse/analysis.py +0 -104
  21. pheval/analyse/assess_prioritisation_base.py +0 -108
  22. pheval/analyse/benchmark_generator.py +0 -126
  23. pheval/analyse/benchmarking_data.py +0 -25
  24. pheval/analyse/disease_prioritisation_analysis.py +0 -152
  25. pheval/analyse/gene_prioritisation_analysis.py +0 -147
  26. pheval/analyse/generate_summary_outputs.py +0 -105
  27. pheval/analyse/parse_benchmark_summary.py +0 -81
  28. pheval/analyse/parse_corpus.py +0 -219
  29. pheval/analyse/prioritisation_result_types.py +0 -52
  30. pheval/analyse/variant_prioritisation_analysis.py +0 -159
  31. {pheval-0.4.7.dist-info → pheval-0.5.1.dist-info}/LICENSE +0 -0
  32. {pheval-0.4.7.dist-info → pheval-0.5.1.dist-info}/WHEEL +0 -0
  33. {pheval-0.4.7.dist-info → pheval-0.5.1.dist-info}/entry_points.txt +0 -0
pheval/analyse/benchmark.py
@@ -0,0 +1,156 @@
+ import time
+ from pathlib import Path
+ from typing import List, Tuple
+
+ import duckdb
+ import polars as pl
+
+ from pheval.analyse.benchmark_db_manager import write_table
+ from pheval.analyse.benchmark_output_type import BenchmarkOutputType, BenchmarkOutputTypeEnum
+ from pheval.analyse.binary_classification_curves import compute_curves
+ from pheval.analyse.binary_classification_stats import compute_confusion_matrix
+ from pheval.analyse.generate_plots import generate_plots
+ from pheval.analyse.generate_rank_comparisons import calculate_rank_changes
+ from pheval.analyse.rank_stats import compute_rank_stats
+ from pheval.analyse.run_data_parser import Config, RunConfig, parse_run_config
+ from pheval.utils.logger import get_logger
+
+
+ def scan_directory(run: RunConfig, benchmark_type: BenchmarkOutputType) -> pl.LazyFrame:
+     """
+     Scan a results directory containing pheval parquet standardised results and return a LazyFrame object.
+     Args:
+         run (RunConfig): RunConfig object.
+         benchmark_type (BenchmarkOutputType): Benchmark output type.
+     Returns:
+         pl.LazyFrame: LazyFrame object containing all the results in the directory.
+     """
+     logger = get_logger()
+     logger.info(f"Analysing results in {run.results_dir.joinpath(benchmark_type.result_directory)}")
+     return (
+         pl.scan_parquet(
+             run.results_dir.joinpath(benchmark_type.result_directory),
+             include_file_paths="file_path",
+         ).with_columns(
+             pl.col("rank").cast(pl.Int64),
+             pl.col("file_path").str.extract(r"([^/\\]+)$").alias("result_file"),
+             pl.col("true_positive").fill_null(False),
+         )
+     ).filter(
+         (
+             pl.col("score") >= run.threshold
+             if run.score_order.lower() == "descending"
+             else pl.col("score") <= run.threshold
+         )
+         if run.threshold is not None
+         else True
+     )
+
+
+ def process_stats(
+     runs: List[RunConfig], benchmark_type: BenchmarkOutputType
+ ) -> Tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame]:
+     """
+     Processes stats outputs for specified runs to compare.
+     Args:
+         runs (List[RunConfig]): List of runs to benchmark.
+         benchmark_type (BenchmarkOutputType): Benchmark output type.
+     Returns:
+         Tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame]: The stats for all runs.
+     """
+     stats, curve_results, true_positive_cases = [], [], []
+     for run in runs:
+         result_scan = scan_directory(run, benchmark_type)
+         stats.append(
+             compute_rank_stats(run.run_identifier, result_scan).join(
+                 compute_confusion_matrix(run.run_identifier, result_scan), on="run_identifier"
+             )
+         )
+         curve_results.append(compute_curves(run.run_identifier, result_scan))
+         true_positive_cases.append(
+             result_scan.filter(pl.col("true_positive")).select(
+                 ["result_file", *benchmark_type.columns, pl.col("rank").alias(run.run_identifier)]
+             )
+         )
+     return (
+         pl.concat(stats, how="vertical").collect(),
+         pl.concat(curve_results, how="vertical").collect(),
+         pl.concat(true_positive_cases, how="align_inner").collect(),
+     )
+
+
+ def benchmark(config: Config, benchmark_type: BenchmarkOutputType) -> None:
+     """
+     Benchmark results for specified runs for a specified prioritisation type for comparison.
+     Args:
+         config (Config): Configuration for benchmarking.
+         benchmark_type (BenchmarkOutputType): Benchmark output type.
+     """
+     conn = duckdb.connect(f"{config.benchmark_name}.duckdb")
+     stats, curve_results, true_positive_cases = process_stats(config.runs, benchmark_type)
+     write_table(
+         conn, stats, f"{config.benchmark_name}_{benchmark_type.prioritisation_type_string}_summary"
+     )
+     write_table(
+         conn,
+         curve_results,
+         f"{config.benchmark_name}_{benchmark_type.prioritisation_type_string}_binary_classification_curves",
+     )
+     calculate_rank_changes(
+         conn, [run.run_identifier for run in config.runs], true_positive_cases, benchmark_type
+     )
+     generate_plots(
+         config.benchmark_name, stats, curve_results, benchmark_type, config.plot_customisation
+     )
+     conn.close()
+
+
+ def benchmark_runs(benchmark_config_file: Path) -> None:
+     """
+     Benchmark results for specified runs for comparison.
+     Args:
+         benchmark_config_file (Path): Path to benchmark config file.
+     """
+     logger = get_logger()
+     start_time = time.perf_counter()
+     logger.info("Initiated benchmarking process.")
+     config = parse_run_config(benchmark_config_file)
+     gene_analysis_runs = [run for run in config.runs if run.gene_analysis]
+     variant_analysis_runs = [run for run in config.runs if run.variant_analysis]
+     disease_analysis_runs = [run for run in config.runs if run.disease_analysis]
+     if gene_analysis_runs:
+         logger.info("Initiating benchmarking for gene results.")
+         benchmark(
+             Config(
+                 benchmark_name=config.benchmark_name,
+                 runs=gene_analysis_runs,
+                 plot_customisation=config.plot_customisation,
+             ),
+             BenchmarkOutputTypeEnum.GENE.value,
+         )
+         logger.info("Finished benchmarking for gene results.")
+     if variant_analysis_runs:
+         logger.info("Initiating benchmarking for variant results.")
+         benchmark(
+             Config(
+                 benchmark_name=config.benchmark_name,
+                 runs=variant_analysis_runs,
+                 plot_customisation=config.plot_customisation,
+             ),
+             BenchmarkOutputTypeEnum.VARIANT.value,
+         )
+         logger.info("Finished benchmarking for variant results.")
+     if disease_analysis_runs:
+         logger.info("Initiating benchmarking for disease results.")
+         benchmark(
+             Config(
+                 benchmark_name=config.benchmark_name,
+                 runs=disease_analysis_runs,
+                 plot_customisation=config.plot_customisation,
+             ),
+             BenchmarkOutputTypeEnum.DISEASE.value,
+         )
+         logger.info("Finished benchmarking for disease results.")
+     logger.info(
+         f"Finished benchmarking! Total time: {time.perf_counter() - start_time:.2f} seconds."
+     )
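For orientation, a minimal driver sketch (not part of the diff) showing how the new benchmark_runs entry point might be invoked directly; the file name "benchmark_config.yaml" is a placeholder, and the exact config schema is whatever parse_run_config in run_data_parser.py accepts.

# Hypothetical driver script; the config path is a placeholder and its schema
# is defined by pheval.analyse.run_data_parser.parse_run_config.
from pathlib import Path

from pheval.analyse.benchmark import benchmark_runs

if __name__ == "__main__":
    # Parses the config, splits runs by their gene/variant/disease analysis flags,
    # and writes the benchmark results into a <benchmark_name>.duckdb database.
    benchmark_runs(Path("benchmark_config.yaml"))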
pheval/analyse/benchmark_db_manager.py
@@ -1,141 +1,23 @@
- import ast
- import re
- from typing import List, Type, Union
-
- import duckdb
+ import polars as pl
  from duckdb import DuckDBPyConnection

- from pheval.post_processing.post_processing import (
-     RankedPhEvalDiseaseResult,
-     RankedPhEvalGeneResult,
-     RankedPhEvalVariantResult,
- )
-
-
- class BenchmarkDBManager:
-     """
-     Class to connect to database.
-     """
-
-     def __init__(self, benchmark_name: str):
-         """Initialise the BenchmarkDBManager class."""
-         self.conn = self.get_connection(
-             f"{benchmark_name}" if str(benchmark_name).endswith(".db") else f"{benchmark_name}.db"
-         )
-
-     def initialise(self):
-         """Initialise the duckdb connection."""
-         self.add_contains_function()
+ from pheval.utils.logger import get_logger

-     @staticmethod
-     def get_connection(db_name: str) -> DuckDBPyConnection:
-         """
-         Get a connection to the database.
-         Returns:
-             DuckDBPyConnection: Connection to the database.
-         """
-         conn = duckdb.connect(db_name)
-         return conn
+ logger = get_logger()

-     def add_column_integer_default(self, table_name: str, column: str, default: int = 0) -> None:
-         """
-         Add a column to an existing table with an integer default value.
-         Args:
-             table_name (str): Name of the table.
-             column (str): Name of the column to add.
-             default (int): Default integer value to add.
-         """
-         try:
-             self.conn.execute(
-                 f'ALTER TABLE "{table_name}" ADD COLUMN "{column}" INTEGER DEFAULT {default}'
-             )
-             self.conn.execute(f'UPDATE "{table_name}" SET "{column}" = ?', (default,))
-             self.conn.commit()
-         except duckdb.CatalogException:
-             pass

-     def drop_table(self, table_name: str) -> None:
-         """
-         Drop a table from the database.
-         Args:
-             table_name: Name of the table to drop from the database
-         """
-         self.conn.execute(f"""DROP TABLE IF EXISTS "{table_name}";""")
+ def load_table_lazy(table_name: str, conn: DuckDBPyConnection) -> pl.LazyFrame:
+     logger.info(f"Loading table {table_name}")
+     return pl.from_arrow(conn.execute(f"SELECT * FROM {table_name}").fetch_arrow_table()).lazy()

-     @staticmethod
-     def contains_entity_function(entity: str, known_causative_entity: str) -> bool:
-         """
-         Determines if a known causative entity is present within an entity or list of entities.
-         Args:
-             entity (str): The entity to be checked. It can be a single entity or a string representation of a list.
-             known_causative_entity (str): The entity to search for within the `entity`.

-         Returns:
-             bool: `True` if `known_causative_entity` is found in `entity` (or its list representation),
-             `False` otherwise.
-         """
-         list_pattern = re.compile(r"^\[\s*(?:[^\[\],\s]+(?:\s*,\s*[^\[\],\s]+)*)?\s*]$")
-         entity = entity.replace("nan", "None").replace("NaN", "None")
-         if list_pattern.match(str(entity)):
-             list_representation = ast.literal_eval(entity)
-             if isinstance(list_representation, list):
-                 return known_causative_entity in list_representation
-         return known_causative_entity == entity
-
-     def add_contains_function(self) -> None:
-         """
-         Adds a custom `contains_entity_function` to the DuckDB connection if it does not already exist.
-         """
-         result = self.conn.execute(
-             "SELECT * FROM duckdb_functions() WHERE function_name = ?", ["contains_entity_function"]
-         ).fetchall()
-         if not result:
-             self.conn.create_function("contains_entity_function", self.contains_entity_function)
-
-     def parse_table_into_dataclass(
-         self,
-         table_name: str,
-         dataclass: Union[
-             Type[RankedPhEvalGeneResult],
-             Type[RankedPhEvalVariantResult],
-             Type[RankedPhEvalDiseaseResult],
-         ],
-     ) -> Union[
-         List[RankedPhEvalGeneResult],
-         List[RankedPhEvalVariantResult],
-         List[RankedPhEvalDiseaseResult],
-     ]:
-         """
-         Parses a DuckDB table into a list of dataclass instances.
-         Args:
-             table_name (str): The name of the DuckDB table to be parsed.
-             dataclass (Union[Type[RankedPhEvalGeneResult], Type[RankedPhEvalVariantResult],
-                 Type[RankedPhEvalDiseaseResult]]):
-                 The dataclass type to which each row in the table should be mapped.
-
-         Returns:
-             List[dataclass]: A list of instances of the provided dataclass, each representing a row from the table.
-         """
-         result = (
-             self.conn.execute(f"SELECT * FROM '{table_name}'").fetchdf().to_dict(orient="records")
-         )
-         return [dataclass(**row) for row in result]
-
-     def check_table_exists(self, table_name: str) -> bool:
-         """
-         Check if a table exists in the connected DuckDB database.
-         Args:
-             table_name (str): The name of the table to check for existence.
-         Returns:
-             bool: Returns `True` if the table exists in the database, `False` otherwise.
-         """
-         result = self.conn.execute(
-             f"SELECT * FROM information_schema.tables WHERE table_name = '{table_name}'"
-         ).fetchall()
-         if result:
-             return True
-         return False
-
-     def close(self):
-         """Close the connection to the database."""
-         self.conn.close()
+ def write_table(conn: DuckDBPyConnection, df: pl.DataFrame, table_name: str) -> None:
+     """
+     Write a table to the DuckDB database.
+     Args:
+         conn (DuckDBPyConnection): DuckDB connection.
+         df (pl.DataFrame): Polars DataFrame to write.
+         table_name (str): Table name.
+     """
+     logger.info(f"Storing results in {table_name}.")
+     conn.execute(f"""CREATE TABLE "{table_name}" AS SELECT * FROM df""")
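The BenchmarkDBManager class is replaced by two module-level helpers. A minimal round-trip sketch (illustrative, not part of the diff); the database name, table name, and data are invented.

import duckdb
import polars as pl

from pheval.analyse.benchmark_db_manager import load_table_lazy, write_table

# Invented database/table names, purely for illustration.
conn = duckdb.connect("example_benchmark.duckdb")
stats = pl.DataFrame({"run_identifier": ["run_1"], "top1": [42]})
write_table(conn, stats, "example_gene_summary")          # CREATE TABLE ... AS SELECT * FROM df
summary = load_table_lazy("example_gene_summary", conn)   # returned as a Polars LazyFrame
print(summary.collect())
conn.close()

write_table relies on DuckDB's replacement scan: the identifier df in the CREATE TABLE statement resolves to the local Polars DataFrame passed into the function.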
pheval/analyse/benchmark_output_type.py
@@ -0,0 +1,43 @@
+ from enum import Enum
+ from typing import List, NamedTuple
+
+
+ class BenchmarkOutputType(NamedTuple):
+     """
+     Represents the structure of benchmark output types.
+
+     Attributes:
+         prioritisation_type_string (str): The type of prioritisation being performed.
+         y_label (str): The label for the y-axis in performance evaluation plots.
+         columns (List[str]): The list of column names relevant to the benchmark output.
+         result_directory (str): The directory where benchmark results are stored.
+     """
+
+     prioritisation_type_string: str
+     y_label: str
+     columns: List[str]
+     result_directory: str
+
+
+ class BenchmarkOutputTypeEnum(Enum):
+     """
+     Enumeration of benchmark output types, representing different entities.
+
+     Attributes:
+         GENE (BenchmarkOutputType): Benchmark output type for gene prioritisation.
+         VARIANT (BenchmarkOutputType): Benchmark output type for variant prioritisation.
+         DISEASE (BenchmarkOutputType): Benchmark output type for disease prioritisation.
+     """
+
+     GENE = BenchmarkOutputType(
+         "gene",
+         "Disease-causing genes (%)",
+         ["gene_identifier", "gene_symbol"],
+         "pheval_gene_results",
+     )
+     VARIANT = BenchmarkOutputType(
+         "variant", "Disease-causing variants (%)", ["variant_id"], "pheval_variant_results"
+     )
+     DISEASE = BenchmarkOutputType(
+         "disease", "Known diseases (%)", ["disease_identifier"], "pheval_disease_results"
+     )
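Illustrative only (not part of the diff): how callers such as benchmark.py unpack these enum members; the printed values come straight from the definitions above.

from pheval.analyse.benchmark_output_type import BenchmarkOutputTypeEnum

gene_output = BenchmarkOutputTypeEnum.GENE.value   # a BenchmarkOutputType NamedTuple
print(gene_output.prioritisation_type_string)      # "gene"
print(gene_output.columns)                         # ["gene_identifier", "gene_symbol"]
print(gene_output.result_directory)                # "pheval_gene_results"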
pheval/analyse/binary_classification_curves.py
@@ -0,0 +1,132 @@
+ from typing import Tuple
+
+ import numpy as np
+ import polars as pl
+ from sklearn.metrics import precision_recall_curve, roc_curve
+
+ from pheval.utils.logger import get_logger
+
+
+ class BinaryClassificationCurves:
+     """Class for computing and storing ROC & Precision-Recall curves in Polars."""
+
+     @staticmethod
+     def _compute_finite_bounds(result_scan: pl.LazyFrame) -> Tuple[float, float]:
+         """
+         Compute min and max finite values in the 'score' column to handle NaN and Inf values.
+         Args:
+             result_scan (pl.LazyFrame): The LazyFrame containing the results for the directory.
+
+         Returns:
+             Tuple[float, float]: The (max_finite, min_finite) values for normalising scores.
+         """
+         return (
+             result_scan.select(
+                 [
+                     pl.col("score").filter(pl.col("score").is_finite()).max().alias("max_finite"),
+                     pl.col("score").filter(pl.col("score").is_finite()).min().alias("min_finite"),
+                 ]
+             )
+             .collect()
+             .row(0)
+         )
+
+     @staticmethod
+     def _clean_and_extract_data(
+         result_scan: pl.LazyFrame, max_finite: float, min_finite: float
+     ) -> pl.LazyFrame:
+         """
+         Normalise the 'score' column (handling NaNs and Inf values) and extract 'true_positive' labels.
+
+         Args:
+             result_scan (pl.LazyFrame): The LazyFrame containing the results for the directory.
+             max_finite (float): The maximum finite score value.
+             min_finite (float): The minimum finite score value.
+
+         Returns:
+             pl.LazyFrame: A LazyFrame with cleaned 'score' and binary 'true_positive' columns.
+         """
+         return result_scan.with_columns(
+             [
+                 pl.when(pl.col("score").is_nan())
+                 .then(0.0)
+                 .when(pl.col("score").is_infinite() & (pl.col("score") > 0))
+                 .then(max_finite)
+                 .when(pl.col("score").is_infinite() & (pl.col("score") < 0))
+                 .then(min_finite)
+                 .otherwise(pl.col("score"))
+                 .alias("score"),
+                 pl.when(pl.col("true_positive").is_null())
+                 .then(0)
+                 .otherwise(pl.col("true_positive").cast(pl.Int8))
+                 .alias("true_positive"),
+             ]
+         )
+
+     @staticmethod
+     def _compute_roc_pr_curves(
+         run_identifier: str, labels: np.ndarray, scores: np.ndarray
+     ) -> pl.LazyFrame:
+         """
+         Compute ROC and Precision-Recall curves.
+
+         Args:
+             labels (np.ndarray): Binary ground truth labels (0 or 1).
+             scores (np.ndarray): Prediction scores.
+
+         Returns:
+             pl.LazyFrame: A LazyFrame containing the computed FPR, TPR, Precision, Recall, and Thresholds.
+         """
+         fpr, tpr, roc_thresholds = roc_curve(labels, scores, pos_label=1)
+         precision, recall, pr_thresholds = precision_recall_curve(labels, scores, pos_label=1)
+
+         return pl.LazyFrame(
+             {
+                 "run_identifier": [run_identifier],
+                 "fpr": [fpr.tolist()],
+                 "tpr": [tpr.tolist()],
+                 "threshold_roc": [roc_thresholds.tolist()],
+                 "precision": [precision.tolist()],
+                 "recall": [recall.tolist()],
+                 "threshold_pr": [pr_thresholds.tolist()],
+             }
+         )
+
+     @classmethod
+     def process(cls, result_scan: pl.LazyFrame, run_identifier: str) -> pl.LazyFrame:
+         """
+         Process scores, extract true labels, compute ROC and Precision-Recall curves,
+         and store results in a Polars LazyFrame with NumPy arrays.
+
+         Args:
+             result_scan (pl.LazyFrame): The LazyFrame containing the results for the directory.
+             run_identifier (str): Identifier for this run.
+
+         Returns:
+             pl.LazyFrame: A LazyFrame containing ROC & PR curve data with NumPy arrays.
+         """
+         max_finite, min_finite = cls._compute_finite_bounds(result_scan)
+         cleaned_data = (
+             cls._clean_and_extract_data(result_scan, max_finite, min_finite)
+             .select(["true_positive", "score"])
+             .collect()
+         )
+         return cls._compute_roc_pr_curves(
+             run_identifier,
+             cleaned_data["true_positive"].to_numpy().flatten(),
+             cleaned_data["score"].to_numpy().flatten(),
+         )
+
+
+ def compute_curves(run_identifier: str, result_scan: pl.LazyFrame) -> pl.LazyFrame:
+     """
+     Compute ROC and Precision-Recall curves.
+     Args:
+         result_scan (pl.LazyFrame): The LazyFrame containing the results for the directory.
+         run_identifier (str): Identifier for this run.
+     Returns:
+         pl.LazyFrame: LazyFrame containing the ROC & Precision-Recall curve data with NumPy arrays.
+     """
+     logger = get_logger()
+     logger.info("Calculating ROC and Precision-Recall metrics")
+     return BinaryClassificationCurves.process(result_scan, run_identifier)
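A toy sketch (not part of the diff) of compute_curves on a hand-built LazyFrame; the scores and labels are invented solely to show the expected input columns (score, true_positive) and the single-row, list-valued output.

import polars as pl

from pheval.analyse.binary_classification_curves import compute_curves

# Invented scores and labels; the infinite score and null label exercise the cleaning step.
toy_results = pl.LazyFrame(
    {
        "score": [0.9, 0.75, float("inf"), 0.1],
        "true_positive": [True, False, True, None],
    }
)
curves = compute_curves("toy_run", toy_results).collect()
# One row per run, with list columns fpr/tpr/threshold_roc and precision/recall/threshold_pr.
print(curves.columns)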