pheval 0.4.6__tar.gz → 0.5.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of pheval might be problematic.
- {pheval-0.4.6 → pheval-0.5.0}/PKG-INFO +4 -4
- {pheval-0.4.6 → pheval-0.5.0}/pyproject.toml +4 -4
- pheval-0.5.0/src/pheval/analyse/benchmark.py +156 -0
- pheval-0.5.0/src/pheval/analyse/benchmark_db_manager.py +23 -0
- pheval-0.5.0/src/pheval/analyse/benchmark_output_type.py +43 -0
- pheval-0.5.0/src/pheval/analyse/binary_classification_curves.py +132 -0
- pheval-0.5.0/src/pheval/analyse/binary_classification_stats.py +186 -0
- pheval-0.5.0/src/pheval/analyse/generate_plots.py +379 -0
- pheval-0.5.0/src/pheval/analyse/generate_rank_comparisons.py +44 -0
- pheval-0.5.0/src/pheval/analyse/rank_stats.py +255 -0
- {pheval-0.4.6 → pheval-0.5.0}/src/pheval/analyse/run_data_parser.py +21 -39
- {pheval-0.4.6 → pheval-0.5.0}/src/pheval/cli.py +28 -25
- {pheval-0.4.6 → pheval-0.5.0}/src/pheval/cli_pheval_utils.py +7 -8
- pheval-0.5.0/src/pheval/post_processing/phenopacket_truth_set.py +235 -0
- pheval-0.5.0/src/pheval/post_processing/post_processing.py +266 -0
- pheval-0.5.0/src/pheval/post_processing/validate_result_format.py +92 -0
- {pheval-0.4.6 → pheval-0.5.0}/src/pheval/prepare/update_phenopacket.py +11 -9
- pheval-0.5.0/src/pheval/utils/logger.py +35 -0
- {pheval-0.4.6 → pheval-0.5.0}/src/pheval/utils/phenopacket_utils.py +85 -91
- pheval-0.4.6/src/pheval/analyse/analysis.py +0 -104
- pheval-0.4.6/src/pheval/analyse/assess_prioritisation_base.py +0 -108
- pheval-0.4.6/src/pheval/analyse/benchmark_db_manager.py +0 -141
- pheval-0.4.6/src/pheval/analyse/benchmark_generator.py +0 -126
- pheval-0.4.6/src/pheval/analyse/benchmarking_data.py +0 -25
- pheval-0.4.6/src/pheval/analyse/binary_classification_stats.py +0 -329
- pheval-0.4.6/src/pheval/analyse/disease_prioritisation_analysis.py +0 -152
- pheval-0.4.6/src/pheval/analyse/gene_prioritisation_analysis.py +0 -147
- pheval-0.4.6/src/pheval/analyse/generate_plots.py +0 -564
- pheval-0.4.6/src/pheval/analyse/generate_summary_outputs.py +0 -105
- pheval-0.4.6/src/pheval/analyse/parse_benchmark_summary.py +0 -81
- pheval-0.4.6/src/pheval/analyse/parse_corpus.py +0 -219
- pheval-0.4.6/src/pheval/analyse/prioritisation_result_types.py +0 -52
- pheval-0.4.6/src/pheval/analyse/rank_stats.py +0 -447
- pheval-0.4.6/src/pheval/analyse/variant_prioritisation_analysis.py +0 -159
- pheval-0.4.6/src/pheval/post_processing/post_processing.py +0 -386
- {pheval-0.4.6 → pheval-0.5.0}/LICENSE +0 -0
- {pheval-0.4.6 → pheval-0.5.0}/README.md +0 -0
- {pheval-0.4.6 → pheval-0.5.0}/src/pheval/__init__.py +0 -0
- {pheval-0.4.6 → pheval-0.5.0}/src/pheval/analyse/__init__.py +0 -0
- {pheval-0.4.6 → pheval-0.5.0}/src/pheval/cli_pheval.py +0 -0
- {pheval-0.4.6 → pheval-0.5.0}/src/pheval/config_parser.py +0 -0
- {pheval-0.4.6 → pheval-0.5.0}/src/pheval/implementations/__init__.py +0 -0
- {pheval-0.4.6 → pheval-0.5.0}/src/pheval/infra/__init__.py +0 -0
- {pheval-0.4.6 → pheval-0.5.0}/src/pheval/infra/exomiserdb.py +0 -0
- {pheval-0.4.6 → pheval-0.5.0}/src/pheval/post_processing/__init__.py +0 -0
- {pheval-0.4.6 → pheval-0.5.0}/src/pheval/prepare/__init__.py +0 -0
- {pheval-0.4.6 → pheval-0.5.0}/src/pheval/prepare/create_noisy_phenopackets.py +0 -0
- {pheval-0.4.6 → pheval-0.5.0}/src/pheval/prepare/create_spiked_vcf.py +0 -0
- {pheval-0.4.6 → pheval-0.5.0}/src/pheval/prepare/custom_exceptions.py +0 -0
- {pheval-0.4.6 → pheval-0.5.0}/src/pheval/prepare/prepare_corpus.py +0 -0
- {pheval-0.4.6 → pheval-0.5.0}/src/pheval/resources/alternate_ouputs/CADA_results.txt +0 -0
- {pheval-0.4.6 → pheval-0.5.0}/src/pheval/resources/alternate_ouputs/DeepPVP_results.txt +0 -0
- {pheval-0.4.6 → pheval-0.5.0}/src/pheval/resources/alternate_ouputs/OVA_results.txt +0 -0
- {pheval-0.4.6 → pheval-0.5.0}/src/pheval/resources/alternate_ouputs/Phen2Gene_results.json +0 -0
- {pheval-0.4.6 → pheval-0.5.0}/src/pheval/resources/alternate_ouputs/Phenolyzer_results.txt +0 -0
- {pheval-0.4.6 → pheval-0.5.0}/src/pheval/resources/alternate_ouputs/lirical_results.tsv +0 -0
- {pheval-0.4.6 → pheval-0.5.0}/src/pheval/resources/alternate_ouputs/svanna_results.tsv +0 -0
- {pheval-0.4.6 → pheval-0.5.0}/src/pheval/resources/hgnc_complete_set.txt +0 -0
- {pheval-0.4.6 → pheval-0.5.0}/src/pheval/run_metadata.py +0 -0
- {pheval-0.4.6 → pheval-0.5.0}/src/pheval/runners/__init__.py +0 -0
- {pheval-0.4.6 → pheval-0.5.0}/src/pheval/runners/runner.py +0 -0
- {pheval-0.4.6 → pheval-0.5.0}/src/pheval/utils/__init__.py +0 -0
- {pheval-0.4.6 → pheval-0.5.0}/src/pheval/utils/docs_gen.py +0 -0
- {pheval-0.4.6 → pheval-0.5.0}/src/pheval/utils/docs_gen.sh +0 -0
- {pheval-0.4.6 → pheval-0.5.0}/src/pheval/utils/exomiser.py +0 -0
- {pheval-0.4.6 → pheval-0.5.0}/src/pheval/utils/file_utils.py +0 -0
- {pheval-0.4.6 → pheval-0.5.0}/src/pheval/utils/semsim_utils.py +0 -0
- {pheval-0.4.6 → pheval-0.5.0}/src/pheval/utils/utils.py +0 -0
{pheval-0.4.6 → pheval-0.5.0}/PKG-INFO
@@ -1,12 +1,11 @@
 Metadata-Version: 2.3
 Name: pheval
-Version: 0.4.6
+Version: 0.5.0
 Summary:
 Author: Yasemin Bridges
 Author-email: y.bridges@qmul.ac.uk
-Requires-Python: >=3.9,<4.0.0
+Requires-Python: >=3.10,<4.0.0
 Classifier: Programming Language :: Python :: 3
-Classifier: Programming Language :: Python :: 3.9
 Classifier: Programming Language :: Python :: 3.10
 Classifier: Programming Language :: Python :: 3.11
 Classifier: Programming Language :: Python :: 3.12
@@ -22,8 +21,9 @@ Requires-Dist: oaklib (>=0.5.6)
 Requires-Dist: pandas (>=1.5.1)
 Requires-Dist: phenopackets (>=2.0.2,<3.0.0)
 Requires-Dist: plotly (>=5.13.0,<6.0.0)
-Requires-Dist: polars (>=
+Requires-Dist: polars (>=1.23,<2.0)
 Requires-Dist: pyaml (>=21.10.1,<22.0.0)
+Requires-Dist: pyarrow (>=19.0.1,<20.0.0)
 Requires-Dist: pyserde (>=0.9.8,<0.10.0)
 Requires-Dist: scikit-learn (>=1.4.0,<2.0.0)
 Requires-Dist: seaborn (>=0.12.2,<0.13.0)
{pheval-0.4.6 → pheval-0.5.0}/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "pheval"
-version = "0.4.6"
+version = "0.5.0"
 description = ""
 authors = ["Yasemin Bridges <y.bridges@qmul.ac.uk>",
            "Julius Jacobsen <j.jacobsen@qmul.ac.uk>",
@@ -10,7 +10,7 @@ readme = "README.md"
 packages = [{include = "pheval", from = "src"}]
 
 [tool.poetry.dependencies]
-python = ">=3.9,<4.0.0"
+python = ">=3.10,<4.0.0"
 jaydebeapi = ">=1.2.3"
 tqdm = ">=4.64.1"
 pandas = ">=1.5.1"
@@ -25,14 +25,14 @@ plotly = "^5.13.0"
 seaborn = "^0.12.2"
 matplotlib = "^3.7.0"
 pyserde = "^0.9.8"
-polars = "^
+polars = "^1.23"
 scikit-learn = "^1.4.0"
 duckdb = "^1.0.0"
+pyarrow = "^19.0.1"
 
 [tool.poetry.dev-dependencies]
 pytest = "^7.2.0"
 coverage = "^6.5.0"
-pheval-template = "^0.1.2"
 pytest-workflow = "^2.0.1"
 
 [tool.poetry.scripts]
pheval-0.5.0/src/pheval/analyse/benchmark.py
@@ -0,0 +1,156 @@
+import time
+from pathlib import Path
+from typing import List, Tuple
+
+import duckdb
+import polars as pl
+
+from pheval.analyse.benchmark_db_manager import write_table
+from pheval.analyse.benchmark_output_type import BenchmarkOutputType, BenchmarkOutputTypeEnum
+from pheval.analyse.binary_classification_curves import compute_curves
+from pheval.analyse.binary_classification_stats import compute_confusion_matrix
+from pheval.analyse.generate_plots import generate_plots
+from pheval.analyse.generate_rank_comparisons import calculate_rank_changes
+from pheval.analyse.rank_stats import compute_rank_stats
+from pheval.analyse.run_data_parser import Config, RunConfig, parse_run_config
+from pheval.utils.logger import get_logger
+
+
+def scan_directory(run: RunConfig, benchmark_type: BenchmarkOutputType) -> pl.LazyFrame:
+    """
+    Scan a results directory containing pheval parquet standardised results and return a LazyFrame object.
+    Args:
+        run (RunConfig): RunConfig object.
+        benchmark_type (BenchmarkOutputTypeEnum): Benchmark output type.
+    Returns:
+        pl.LazyFrame: LazyFrame object containing all the results in the directory..
+    """
+    logger = get_logger()
+    logger.info(f"Analysing results in {run.results_dir.joinpath(benchmark_type.result_directory)}")
+    return (
+        pl.scan_parquet(
+            run.results_dir.joinpath(benchmark_type.result_directory),
+            include_file_paths="file_path",
+        ).with_columns(
+            pl.col("rank").cast(pl.Int64),
+            pl.col("file_path").str.extract(r"([^/\\]+)$").alias("result_file"),
+            pl.col("true_positive").fill_null(False),
+        )
+    ).filter(
+        (
+            pl.col("score") >= run.threshold
+            if run.score_order.lower() == "descending"
+            else pl.col("score") <= run.threshold
+        )
+        if run.threshold is not None
+        else True
+    )
+
+
+def process_stats(
+    runs: List[RunConfig], benchmark_type: BenchmarkOutputType
+) -> Tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame]:
+    """
+    Processes stats outputs for specified runs to compare.
+    Args:
+        runs (List[RunConfig]): List of runs to benchmark.
+        benchmark_type (BenchmarkOutputTypeEnum): Benchmark output type.
+    Returns:
+        Tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame, pl.DataFrame]: The stats for all runs.
+    """
+    stats, curve_results, true_positive_cases = [], [], []
+    for run in runs:
+        result_scan = scan_directory(run, benchmark_type)
+        stats.append(
+            compute_rank_stats(run.run_identifier, result_scan).join(
+                compute_confusion_matrix(run.run_identifier, result_scan), on="run_identifier"
+            )
+        )
+        curve_results.append(compute_curves(run.run_identifier, result_scan))
+        true_positive_cases.append(
+            result_scan.filter(pl.col("true_positive")).select(
+                ["result_file", *benchmark_type.columns, pl.col("rank").alias(run.run_identifier)]
+            )
+        )
+    return (
+        pl.concat(stats, how="vertical").collect(),
+        pl.concat(curve_results, how="vertical").collect(),
+        pl.concat(true_positive_cases, how="align_inner").collect(),
+    )
+
+
+def benchmark(config: Config, benchmark_type: BenchmarkOutputType) -> None:
+    """
+    Benchmark results for specified runs for a specified prioritisation type for comparison.
+    Args:
+        config (Config): Configuration for benchmarking.
+        benchmark_type (BenchmarkOutputType): Benchmark output type.
+    """
+    conn = duckdb.connect(f"{config.benchmark_name}.duckdb")
+    stats, curve_results, true_positive_cases = process_stats(config.runs, benchmark_type)
+    write_table(
+        conn, stats, f"{config.benchmark_name}_{benchmark_type.prioritisation_type_string}_summary"
+    )
+    write_table(
+        conn,
+        curve_results,
+        f"{config.benchmark_name}_{benchmark_type.prioritisation_type_string}_binary_classification_curves",
+    )
+    calculate_rank_changes(
+        conn, [run.run_identifier for run in config.runs], true_positive_cases, benchmark_type
+    )
+    generate_plots(
+        config.benchmark_name, stats, curve_results, benchmark_type, config.plot_customisation
+    )
+    conn.close()
+
+
+def benchmark_runs(benchmark_config_file: Path) -> None:
+    """
+    Benchmark results for specified runs for comparison.
+    Args:
+        benchmark_config_file (Path): Path to benchmark config file.
+    """
+    logger = get_logger()
+    start_time = time.perf_counter()
+    logger.info("Initiated benchmarking process.")
+    config = parse_run_config(benchmark_config_file)
+    gene_analysis_runs = [run for run in config.runs if run.gene_analysis]
+    variant_analysis_runs = [run for run in config.runs if run.variant_analysis]
+    disease_analysis_runs = [run for run in config.runs if run.disease_analysis]
+    if gene_analysis_runs:
+        logger.info("Initiating benchmarking for gene results.")
+        benchmark(
+            Config(
+                benchmark_name=config.benchmark_name,
+                runs=gene_analysis_runs,
+                plot_customisation=config.plot_customisation,
+            ),
+            BenchmarkOutputTypeEnum.GENE.value,
+        )
+        logger.info("Finished benchmarking for gene results.")
+    if variant_analysis_runs:
+        logger.info("Initiating benchmarking for variant results")
+        benchmark(
+            Config(
+                benchmark_name=config.benchmark_name,
+                runs=variant_analysis_runs,
+                plot_customisation=config.plot_customisation,
+            ),
+            BenchmarkOutputTypeEnum.VARIANT.value,
+        )
+        logger.info("Finished benchmarking for variant results.")
+    if disease_analysis_runs:
+        logger.info("Initiating benchmarking for disease results")
+        benchmark(
+            Config(
+                benchmark_name=config.benchmark_name,
+                runs=disease_analysis_runs,
+                plot_customisation=config.plot_customisation,
+            ),
+            BenchmarkOutputTypeEnum.DISEASE.value,
+        )
+        logger.info("Finished benchmarking for disease results.")
+    logger.info(
+        f"Finished benchmarking! Total time: {time.perf_counter() - start_time:.2f} seconds."
+    )
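For orientation, the new benchmarking flow is driven entirely by a run configuration parsed with parse_run_config. A minimal usage sketch (the config file name and its contents are hypothetical, not taken from the package) might look like:

import logging
from pathlib import Path

from pheval.analyse.benchmark import benchmark_runs

logging.basicConfig(level=logging.INFO)

# Assumed example: "benchmark_config.yaml" would define the benchmark_name,
# the runs to compare (results_dir, run_identifier, gene/variant/disease
# analysis flags) and any plot customisation expected by parse_run_config.
benchmark_runs(Path("benchmark_config.yaml"))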
pheval-0.5.0/src/pheval/analyse/benchmark_db_manager.py
@@ -0,0 +1,23 @@
+import polars as pl
+from duckdb import DuckDBPyConnection
+
+from pheval.utils.logger import get_logger
+
+logger = get_logger()
+
+
+def load_table_lazy(table_name: str, conn: DuckDBPyConnection) -> pl.LazyFrame:
+    logger.info(f"Loading table {table_name}")
+    return pl.from_arrow(conn.execute(f"SELECT * FROM {table_name}").fetch_arrow_table()).lazy()
+
+
+def write_table(conn: DuckDBPyConnection, df: pl.DataFrame, table_name: str) -> None:
+    """
+    Write table to DuckDB database.
+    Args:
+        conn (DuckDBPyConnection): DuckDB connection.
+        df (pl.DataFrame): DuckDB dataframe.
+        table_name (str): Table name.
+    """
+    logger.info(f"Storing results in {table_name}.")
+    conn.execute(f"""CREATE TABLE "{table_name}" AS SELECT * FROM df""")
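As a rough sketch of how these two helpers fit together (table name and data below are invented for illustration), write_table registers a Polars DataFrame as a DuckDB table via DuckDB's dataframe replacement scan, and load_table_lazy reads it back as a LazyFrame through Arrow:

import duckdb
import polars as pl

from pheval.analyse.benchmark_db_manager import load_table_lazy, write_table

conn = duckdb.connect(":memory:")  # illustrative in-memory database
df = pl.DataFrame({"run_identifier": ["run_1"], "top": [10]})  # made-up data
write_table(conn, df, "example_summary")
print(load_table_lazy("example_summary", conn).collect())
conn.close()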
pheval-0.5.0/src/pheval/analyse/benchmark_output_type.py
@@ -0,0 +1,43 @@
+from enum import Enum
+from typing import List, NamedTuple
+
+
+class BenchmarkOutputType(NamedTuple):
+    """
+    Represents the structure of benchmark output types.
+
+    Attributes:
+        prioritisation_type_string (str): The type of prioritisation being performed.
+        y_label (str): The label for the y-axis in performance evaluation plots.
+        columns (List[str]): The list of column names relevant to the benchmark output.
+        result_directory (str): The directory where benchmark results are stored.
+    """
+
+    prioritisation_type_string: str
+    y_label: str
+    columns: List[str]
+    result_directory: str
+
+
+class BenchmarkOutputTypeEnum(Enum):
+    """
+    Enumeration of benchmark output types, representing different entities.
+
+    Attributes:
+        GENE (BenchmarkOutputType): Benchmark output type for gene prioritisation.
+        VARIANT (BenchmarkOutputType): Benchmark output type for variant prioritisation.
+        DISEASE (BenchmarkOutputType): Benchmark output type for disease prioritisation.
+    """
+
+    GENE = BenchmarkOutputType(
+        "gene",
+        "Disease-causing genes (%)",
+        ["gene_identifier", "gene_symbol"],
+        "pheval_gene_results",
+    )
+    VARIANT = BenchmarkOutputType(
+        "variant", "Disease-causing variants (%)", ["variant_id"], "pheval_variant_results"
+    )
+    DISEASE = BenchmarkOutputType(
+        "disease", "Known diseases (%)", ["disease_identifier"], "pheval_disease_results"
+    )
pheval-0.5.0/src/pheval/analyse/binary_classification_curves.py
@@ -0,0 +1,132 @@
+from typing import Tuple
+
+import numpy as np
+import polars as pl
+from sklearn.metrics import precision_recall_curve, roc_curve
+
+from pheval.utils.logger import get_logger
+
+
+class BinaryClassificationCurves:
+    """Class for computing and storing ROC & Precision-Recall curves in Polars."""
+
+    @staticmethod
+    def _compute_finite_bounds(result_scan: pl.LazyFrame) -> Tuple[float, float]:
+        """
+        Compute min and max finite values in the 'score' column to handle NaN and Inf values.
+        Args:
+            result_scan (pl.LazyFrame): The LazyFrame containing the results for the directory.
+
+        Returns:
+            Tuple[float, float]: The (max_finite, min_finite) values for normalising scores.
+        """
+        return (
+            result_scan.select(
+                [
+                    pl.col("score").filter(pl.col("score").is_finite()).max().alias("max_finite"),
+                    pl.col("score").filter(pl.col("score").is_finite()).min().alias("min_finite"),
+                ]
+            )
+            .collect()
+            .row(0)
+        )
+
+    @staticmethod
+    def _clean_and_extract_data(
+        result_scan: pl.LazyFrame, max_finite: float, min_finite: float
+    ) -> pl.LazyFrame:
+        """
+        Normalise the 'score' column (handling NaNs and Inf values) and extract 'true_positive' labels.
+
+        Args:
+            result_scan (pl.LazyFrame): The LazyFrame containing the results for the directory.
+            max_finite (float): The maximum finite score value.
+            min_finite (float): The minimum finite score value.
+
+        Returns:
+            pl.LazyFrame: A LazyFrame with cleaned 'score' and binary 'true_positive' columns.
+        """
+        return result_scan.with_columns(
+            [
+                pl.when(pl.col("score").is_nan())
+                .then(0.0)
+                .when(pl.col("score").is_infinite() & (pl.col("score") > 0))
+                .then(max_finite)
+                .when(pl.col("score").is_infinite() & (pl.col("score") < 0))
+                .then(min_finite)
+                .otherwise(pl.col("score"))
+                .alias("score"),
+                pl.when(pl.col("true_positive").is_null())
+                .then(0)
+                .otherwise(pl.col("true_positive").cast(pl.Int8))
+                .alias("true_positive"),
+            ]
+        )
+
+    @staticmethod
+    def _compute_roc_pr_curves(
+        run_identifier: str, labels: np.ndarray, scores: np.ndarray
+    ) -> pl.LazyFrame:
+        """
+        Compute ROC and Precision-Recall curves.
+
+        Args:
+            labels (np.ndarray): Binary ground truth labels (0 or 1).
+            scores (np.ndarray): Prediction scores.
+
+        Returns:
+            pl.LazyFrame: A LazyFrame containing the computed FPR, TPR, Precision, Recall, and Thresholds.
+        """
+        fpr, tpr, roc_thresholds = roc_curve(labels, scores, pos_label=1)
+        precision, recall, pr_thresholds = precision_recall_curve(labels, scores, pos_label=1)
+
+        return pl.LazyFrame(
+            {
+                "run_identifier": [run_identifier],
+                "fpr": [fpr.tolist()],
+                "tpr": [tpr.tolist()],
+                "threshold_roc": [roc_thresholds.tolist()],
+                "precision": [precision.tolist()],
+                "recall": [recall.tolist()],
+                "threshold_pr": [pr_thresholds.tolist()],
+            }
+        )
+
+    @classmethod
+    def process(cls, result_scan: pl.LazyFrame, run_identifier: str) -> pl.LazyFrame:
+        """
+        Process scores, extract true labels, compute ROC and Precision-Recall curves,
+        and store results in a Polars LazyFrame with NumPy arrays.
+
+        Args:
+            result_scan (pl.LazyFrame): The LazyFrame containing the results for the directory.
+            run_identifier (str): Identifier for this run.
+
+        Returns:
+            pl.LazyFrame: A LazyFrame containing ROC & PR curve data with NumPy arrays.
+        """
+        max_finite, min_finite = cls._compute_finite_bounds(result_scan)
+        cleaned_data = (
+            cls._clean_and_extract_data(result_scan, max_finite, min_finite)
+            .select(["true_positive", "score"])
+            .collect()
+        )
+        return cls._compute_roc_pr_curves(
+            run_identifier,
+            cleaned_data["true_positive"].to_numpy().flatten(),
+            cleaned_data["score"].to_numpy().flatten(),
+        )
+
+
+def compute_curves(run_identifier: str, result_scan: pl.LazyFrame) -> pl.LazyFrame:
+    """
+    Compute ROC and Precision-Recall curves.
+    Args:
+        result_scan (pl.LazyFrame): The LazyFrame containing the results for the directory.
+        run_identifier (str): Identifier for this run.
+    Returns:
+        pl.LazyFrame: LazyFrame containing the ROC & Precision-Recall curve data with NumPy arrays.
+    """
+    logger = get_logger()
+    logger.info("Calculating ROC and Precision-Recall metrics")
+    return BinaryClassificationCurves.process(result_scan, run_identifier)
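A small, self-contained illustration of what compute_curves consumes (the toy scores and labels below are invented, not package data): it expects a LazyFrame with 'score' and 'true_positive' columns and returns a single row of list-typed curve data per run:

import polars as pl

from pheval.analyse.binary_classification_curves import compute_curves

# Invented toy results: three ranked candidates, one of which is the true positive.
toy_results = pl.LazyFrame(
    {"score": [0.9, 0.4, 0.1], "true_positive": [True, False, False]}
)
curves = compute_curves("toy_run", toy_results).collect()
# Columns: run_identifier, fpr, tpr, threshold_roc, precision, recall, threshold_pr
print(curves)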
pheval-0.5.0/src/pheval/analyse/binary_classification_stats.py
@@ -0,0 +1,186 @@
+from dataclasses import dataclass
+from multiprocessing.util import get_logger
+
+import polars as pl
+
+
+@dataclass(frozen=True)
+class ConfusionMatrix:
+    """
+    Define logical conditions for computing a confusion matrix using Polars expressions.
+
+    Attributes:
+        TRUE_POSITIVES (pl.Expr): Condition identifying true positive cases,
+            where `rank == 1` and `true_positive` is `True`.
+        FALSE_POSITIVES (pl.Expr): Condition identifying false positive cases,
+            where `rank == 1` and `true_positive` is `False`.
+        TRUE_NEGATIVES (pl.Expr): Condition identifying true negative cases,
+            where `rank != 1` and `true_positive` is `False`.
+        FALSE_NEGATIVES (pl.Expr): Condition identifying false negative cases,
+            where `rank != 1` and `true_positive` is `True`.
+    """
+
+    TRUE_POSITIVES = (pl.col("rank") == 1) & (pl.col("true_positive"))
+    FALSE_POSITIVES = (pl.col("rank") == 1) & (~pl.col("true_positive"))
+    TRUE_NEGATIVES = (pl.col("rank") != 1) & (~pl.col("true_positive"))
+    FALSE_NEGATIVES = (pl.col("rank") != 1) & (pl.col("true_positive"))
+
+
+@dataclass(frozen=True)
+class BinaryClassificationStats:
+    """Binary classification statistic expressions."""
+
+    SENSITIVITY = (
+        pl.when((pl.col("true_positives") + pl.col("false_negatives")) != 0)
+        .then(pl.col("true_positives") / (pl.col("true_positives") + pl.col("false_negatives")))
+        .otherwise(0.0)
+        .alias("sensitivity")
+    )
+
+    SPECIFICITY = (
+        pl.when((pl.col("true_negatives") + pl.col("false_positives")) != 0)
+        .then(pl.col("true_negatives") / (pl.col("true_negatives") + pl.col("false_positives")))
+        .otherwise(0.0)
+        .alias("specificity")
+    )
+
+    PRECISION = (
+        pl.when((pl.col("true_positives") + pl.col("false_positives")) != 0)
+        .then(pl.col("true_positives") / (pl.col("true_positives") + pl.col("false_positives")))
+        .otherwise(0.0)
+        .alias("precision")
+    )
+
+    NEGATIVE_PREDICTIVE_VALUE = (
+        pl.when((pl.col("true_negatives") + pl.col("false_negatives")) != 0)
+        .then(pl.col("true_negatives") / (pl.col("true_negatives") + pl.col("false_negatives")))
+        .otherwise(0.0)
+        .alias("negative_predictive_value")
+    )
+
+    FALSE_POSITIVE_RATE = (
+        pl.when((pl.col("false_positives") + pl.col("true_negatives")) != 0)
+        .then(pl.col("false_positives") / (pl.col("false_positives") + pl.col("true_negatives")))
+        .otherwise(0.0)
+        .alias("false_positive_rate")
+    )
+
+    FALSE_DISCOVERY_RATE = (
+        pl.when((pl.col("false_positives") + pl.col("true_positives")) != 0)
+        .then(pl.col("false_positives") / (pl.col("false_positives") + pl.col("true_positives")))
+        .otherwise(0.0)
+        .alias("false_discovery_rate")
+    )
+
+    FALSE_NEGATIVE_RATE = (
+        pl.when((pl.col("false_negatives") + pl.col("true_positives")) != 0)
+        .then(pl.col("false_negatives") / (pl.col("false_negatives") + pl.col("true_positives")))
+        .otherwise(0.0)
+        .alias("false_negative_rate")
+    )
+
+    ACCURACY = (
+        pl.when(
+            (
+                pl.col("true_positives")
+                + pl.col("false_positives")
+                + pl.col("true_negatives")
+                + pl.col("false_negatives")
+            )
+            != 0
+        )
+        .then(
+            (pl.col("true_positives") + pl.col("true_negatives"))
+            / (
+                pl.col("true_positives")
+                + pl.col("false_positives")
+                + pl.col("true_negatives")
+                + pl.col("false_negatives")
+            )
+        )
+        .otherwise(0.0)
+        .alias("accuracy")
+    )
+
+    F1_SCORE = (
+        pl.when(
+            2 * (pl.col("true_positives") + pl.col("false_positives") + pl.col("false_negatives"))
+            != 0
+        )
+        .then(
+            2
+            * pl.col("true_positives")
+            / (2 * pl.col("true_positives") + pl.col("false_positives") + pl.col("false_negatives"))
+        )
+        .otherwise(0.0)
+        .alias("f1_score")
+    )
+
+    MATTHEWS_CORRELATION_COEFFICIENT = (
+        pl.when(
+            (
+                (pl.col("true_positives") + pl.col("false_positives"))
+                * (pl.col("true_positives") + pl.col("false_negatives"))
+                * (pl.col("true_negatives") + pl.col("false_positives"))
+                * (pl.col("true_negatives") + pl.col("false_negatives"))
+            )
+            > 0
+        )
+        .then(
+            (
+                (pl.col("true_positives") * pl.col("true_negatives"))
+                - (pl.col("false_positives") * pl.col("false_negatives"))
+            )
+            / (
+                (pl.col("true_positives") + pl.col("false_positives"))
+                * (pl.col("true_positives") + pl.col("false_negatives"))
+                * (pl.col("true_negatives") + pl.col("false_positives"))
+                * (pl.col("true_negatives") + pl.col("false_negatives"))
+            ).sqrt()
+        )
+        .otherwise(0.0)
+        .alias("matthews_correlation_coefficient")
+    )
+
+
+def compute_confusion_matrix(run_identifier: str, result_scan: pl.LazyFrame) -> pl.LazyFrame:
+    """
+    Computes binary classification statistics.
+
+    Args:
+        run_identifier (str): The identifier for the run.
+        result_scan (pl.LazyFrame): The LazyFrame containing the results for the directory.
+
+    Returns:
+        pl.LazyFrame: The LazyFrame containing the binary classification statistics.
+    """
+    logger = get_logger()
+    logger.info(f"Computing binary classification statistics for {run_identifier}")
+    confusion_matrix = result_scan.select(
+        [
+            pl.lit(run_identifier).alias("run_identifier"),
+            ConfusionMatrix.TRUE_POSITIVES.sum().alias("true_positives").cast(pl.Int64),
+            ConfusionMatrix.FALSE_POSITIVES.sum().alias("false_positives").cast(pl.Int64),
+            ConfusionMatrix.TRUE_NEGATIVES.sum().alias("true_negatives").cast(pl.Int64),
+            ConfusionMatrix.FALSE_NEGATIVES.sum().alias("false_negatives").cast(pl.Int64),
+        ]
+    )
+    return confusion_matrix.select(
+        [
+            pl.col("run_identifier"),
+            pl.col("true_positives"),
+            pl.col("false_positives"),
+            pl.col("true_negatives"),
+            pl.col("false_negatives"),
+            BinaryClassificationStats.SENSITIVITY,
+            BinaryClassificationStats.SPECIFICITY,
+            BinaryClassificationStats.PRECISION,
+            BinaryClassificationStats.NEGATIVE_PREDICTIVE_VALUE,
+            BinaryClassificationStats.FALSE_POSITIVE_RATE,
+            BinaryClassificationStats.FALSE_DISCOVERY_RATE,
+            BinaryClassificationStats.FALSE_NEGATIVE_RATE,
+            BinaryClassificationStats.ACCURACY,
+            BinaryClassificationStats.F1_SCORE,
+            BinaryClassificationStats.MATTHEWS_CORRELATION_COEFFICIENT,
+        ]
+    )