pheval 0.4.7__tar.gz → 0.5.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of pheval might be problematic.
Files changed (68)
  1. {pheval-0.4.7 → pheval-0.5.0}/PKG-INFO +4 -4
  2. {pheval-0.4.7 → pheval-0.5.0}/pyproject.toml +4 -4
  3. pheval-0.5.0/src/pheval/analyse/benchmark.py +156 -0
  4. pheval-0.5.0/src/pheval/analyse/benchmark_db_manager.py +23 -0
  5. pheval-0.5.0/src/pheval/analyse/benchmark_output_type.py +43 -0
  6. pheval-0.5.0/src/pheval/analyse/binary_classification_curves.py +132 -0
  7. pheval-0.5.0/src/pheval/analyse/binary_classification_stats.py +186 -0
  8. pheval-0.5.0/src/pheval/analyse/generate_plots.py +379 -0
  9. pheval-0.5.0/src/pheval/analyse/generate_rank_comparisons.py +44 -0
  10. pheval-0.5.0/src/pheval/analyse/rank_stats.py +255 -0
  11. {pheval-0.4.7 → pheval-0.5.0}/src/pheval/analyse/run_data_parser.py +21 -39
  12. {pheval-0.4.7 → pheval-0.5.0}/src/pheval/cli.py +27 -24
  13. {pheval-0.4.7 → pheval-0.5.0}/src/pheval/cli_pheval_utils.py +7 -8
  14. pheval-0.5.0/src/pheval/post_processing/phenopacket_truth_set.py +235 -0
  15. pheval-0.5.0/src/pheval/post_processing/post_processing.py +266 -0
  16. pheval-0.5.0/src/pheval/post_processing/validate_result_format.py +92 -0
  17. {pheval-0.4.7 → pheval-0.5.0}/src/pheval/prepare/update_phenopacket.py +11 -9
  18. pheval-0.5.0/src/pheval/utils/logger.py +35 -0
  19. {pheval-0.4.7 → pheval-0.5.0}/src/pheval/utils/phenopacket_utils.py +85 -91
  20. pheval-0.4.7/src/pheval/analyse/analysis.py +0 -104
  21. pheval-0.4.7/src/pheval/analyse/assess_prioritisation_base.py +0 -108
  22. pheval-0.4.7/src/pheval/analyse/benchmark_db_manager.py +0 -141
  23. pheval-0.4.7/src/pheval/analyse/benchmark_generator.py +0 -126
  24. pheval-0.4.7/src/pheval/analyse/benchmarking_data.py +0 -25
  25. pheval-0.4.7/src/pheval/analyse/binary_classification_stats.py +0 -329
  26. pheval-0.4.7/src/pheval/analyse/disease_prioritisation_analysis.py +0 -152
  27. pheval-0.4.7/src/pheval/analyse/gene_prioritisation_analysis.py +0 -147
  28. pheval-0.4.7/src/pheval/analyse/generate_plots.py +0 -564
  29. pheval-0.4.7/src/pheval/analyse/generate_summary_outputs.py +0 -105
  30. pheval-0.4.7/src/pheval/analyse/parse_benchmark_summary.py +0 -81
  31. pheval-0.4.7/src/pheval/analyse/parse_corpus.py +0 -219
  32. pheval-0.4.7/src/pheval/analyse/prioritisation_result_types.py +0 -52
  33. pheval-0.4.7/src/pheval/analyse/rank_stats.py +0 -447
  34. pheval-0.4.7/src/pheval/analyse/variant_prioritisation_analysis.py +0 -159
  35. pheval-0.4.7/src/pheval/post_processing/post_processing.py +0 -418
  36. {pheval-0.4.7 → pheval-0.5.0}/LICENSE +0 -0
  37. {pheval-0.4.7 → pheval-0.5.0}/README.md +0 -0
  38. {pheval-0.4.7 → pheval-0.5.0}/src/pheval/__init__.py +0 -0
  39. {pheval-0.4.7 → pheval-0.5.0}/src/pheval/analyse/__init__.py +0 -0
  40. {pheval-0.4.7 → pheval-0.5.0}/src/pheval/cli_pheval.py +0 -0
  41. {pheval-0.4.7 → pheval-0.5.0}/src/pheval/config_parser.py +0 -0
  42. {pheval-0.4.7 → pheval-0.5.0}/src/pheval/implementations/__init__.py +0 -0
  43. {pheval-0.4.7 → pheval-0.5.0}/src/pheval/infra/__init__.py +0 -0
  44. {pheval-0.4.7 → pheval-0.5.0}/src/pheval/infra/exomiserdb.py +0 -0
  45. {pheval-0.4.7 → pheval-0.5.0}/src/pheval/post_processing/__init__.py +0 -0
  46. {pheval-0.4.7 → pheval-0.5.0}/src/pheval/prepare/__init__.py +0 -0
  47. {pheval-0.4.7 → pheval-0.5.0}/src/pheval/prepare/create_noisy_phenopackets.py +0 -0
  48. {pheval-0.4.7 → pheval-0.5.0}/src/pheval/prepare/create_spiked_vcf.py +0 -0
  49. {pheval-0.4.7 → pheval-0.5.0}/src/pheval/prepare/custom_exceptions.py +0 -0
  50. {pheval-0.4.7 → pheval-0.5.0}/src/pheval/prepare/prepare_corpus.py +0 -0
  51. {pheval-0.4.7 → pheval-0.5.0}/src/pheval/resources/alternate_ouputs/CADA_results.txt +0 -0
  52. {pheval-0.4.7 → pheval-0.5.0}/src/pheval/resources/alternate_ouputs/DeepPVP_results.txt +0 -0
  53. {pheval-0.4.7 → pheval-0.5.0}/src/pheval/resources/alternate_ouputs/OVA_results.txt +0 -0
  54. {pheval-0.4.7 → pheval-0.5.0}/src/pheval/resources/alternate_ouputs/Phen2Gene_results.json +0 -0
  55. {pheval-0.4.7 → pheval-0.5.0}/src/pheval/resources/alternate_ouputs/Phenolyzer_results.txt +0 -0
  56. {pheval-0.4.7 → pheval-0.5.0}/src/pheval/resources/alternate_ouputs/lirical_results.tsv +0 -0
  57. {pheval-0.4.7 → pheval-0.5.0}/src/pheval/resources/alternate_ouputs/svanna_results.tsv +0 -0
  58. {pheval-0.4.7 → pheval-0.5.0}/src/pheval/resources/hgnc_complete_set.txt +0 -0
  59. {pheval-0.4.7 → pheval-0.5.0}/src/pheval/run_metadata.py +0 -0
  60. {pheval-0.4.7 → pheval-0.5.0}/src/pheval/runners/__init__.py +0 -0
  61. {pheval-0.4.7 → pheval-0.5.0}/src/pheval/runners/runner.py +0 -0
  62. {pheval-0.4.7 → pheval-0.5.0}/src/pheval/utils/__init__.py +0 -0
  63. {pheval-0.4.7 → pheval-0.5.0}/src/pheval/utils/docs_gen.py +0 -0
  64. {pheval-0.4.7 → pheval-0.5.0}/src/pheval/utils/docs_gen.sh +0 -0
  65. {pheval-0.4.7 → pheval-0.5.0}/src/pheval/utils/exomiser.py +0 -0
  66. {pheval-0.4.7 → pheval-0.5.0}/src/pheval/utils/file_utils.py +0 -0
  67. {pheval-0.4.7 → pheval-0.5.0}/src/pheval/utils/semsim_utils.py +0 -0
  68. {pheval-0.4.7 → pheval-0.5.0}/src/pheval/utils/utils.py +0 -0
{pheval-0.4.7 → pheval-0.5.0}/PKG-INFO

@@ -1,12 +1,11 @@
  Metadata-Version: 2.3
  Name: pheval
- Version: 0.4.7
+ Version: 0.5.0
  Summary:
  Author: Yasemin Bridges
  Author-email: y.bridges@qmul.ac.uk
- Requires-Python: >=3.9,<4.0.0
+ Requires-Python: >=3.10,<4.0.0
  Classifier: Programming Language :: Python :: 3
- Classifier: Programming Language :: Python :: 3.9
  Classifier: Programming Language :: Python :: 3.10
  Classifier: Programming Language :: Python :: 3.11
  Classifier: Programming Language :: Python :: 3.12
@@ -22,8 +21,9 @@ Requires-Dist: oaklib (>=0.5.6)
  Requires-Dist: pandas (>=1.5.1)
  Requires-Dist: phenopackets (>=2.0.2,<3.0.0)
  Requires-Dist: plotly (>=5.13.0,<6.0.0)
- Requires-Dist: polars (>=0.19.15,<0.20.0)
+ Requires-Dist: polars (>=1.23,<2.0)
  Requires-Dist: pyaml (>=21.10.1,<22.0.0)
+ Requires-Dist: pyarrow (>=19.0.1,<20.0.0)
  Requires-Dist: pyserde (>=0.9.8,<0.10.0)
  Requires-Dist: scikit-learn (>=1.4.0,<2.0.0)
  Requires-Dist: seaborn (>=0.12.2,<0.13.0)
{pheval-0.4.7 → pheval-0.5.0}/pyproject.toml

@@ -1,6 +1,6 @@
  [tool.poetry]
  name = "pheval"
- version = "0.4.7"
+ version = "0.5.0"
  description = ""
  authors = ["Yasemin Bridges <y.bridges@qmul.ac.uk>",
  "Julius Jacobsen <j.jacobsen@qmul.ac.uk>",
@@ -10,7 +10,7 @@ readme = "README.md"
  packages = [{include = "pheval", from = "src"}]

  [tool.poetry.dependencies]
- python = ">=3.9,<4.0.0"
+ python = ">=3.10,<4.0.0"
  jaydebeapi = ">=1.2.3"
  tqdm = ">=4.64.1"
  pandas = ">=1.5.1"
@@ -25,14 +25,14 @@ plotly = "^5.13.0"
  seaborn = "^0.12.2"
  matplotlib = "^3.7.0"
  pyserde = "^0.9.8"
- polars = "^0.19.15"
+ polars = "^1.23"
  scikit-learn = "^1.4.0"
  duckdb = "^1.0.0"
+ pyarrow = "^19.0.1"

  [tool.poetry.dev-dependencies]
  pytest = "^7.2.0"
  coverage = "^6.5.0"
- pheval-template = "^0.1.2"
  pytest-workflow = "^2.0.1"

  [tool.poetry.scripts]
pheval-0.5.0/src/pheval/analyse/benchmark.py

@@ -0,0 +1,156 @@
+ import time
+ from pathlib import Path
+ from typing import List, Tuple
+
+ import duckdb
+ import polars as pl
+
+ from pheval.analyse.benchmark_db_manager import write_table
+ from pheval.analyse.benchmark_output_type import BenchmarkOutputType, BenchmarkOutputTypeEnum
+ from pheval.analyse.binary_classification_curves import compute_curves
+ from pheval.analyse.binary_classification_stats import compute_confusion_matrix
+ from pheval.analyse.generate_plots import generate_plots
+ from pheval.analyse.generate_rank_comparisons import calculate_rank_changes
+ from pheval.analyse.rank_stats import compute_rank_stats
+ from pheval.analyse.run_data_parser import Config, RunConfig, parse_run_config
+ from pheval.utils.logger import get_logger
+
+
+ def scan_directory(run: RunConfig, benchmark_type: BenchmarkOutputType) -> pl.LazyFrame:
+     """
+     Scan a results directory containing PhEval-standardised parquet results and return a LazyFrame.
+     Args:
+         run (RunConfig): RunConfig object.
+         benchmark_type (BenchmarkOutputType): Benchmark output type.
+     Returns:
+         pl.LazyFrame: LazyFrame containing all the results in the directory.
+     """
+     logger = get_logger()
+     logger.info(f"Analysing results in {run.results_dir.joinpath(benchmark_type.result_directory)}")
+     return (
+         pl.scan_parquet(
+             run.results_dir.joinpath(benchmark_type.result_directory),
+             include_file_paths="file_path",
+         ).with_columns(
+             pl.col("rank").cast(pl.Int64),
+             pl.col("file_path").str.extract(r"([^/\\]+)$").alias("result_file"),
+             pl.col("true_positive").fill_null(False),
+         )
+     ).filter(
+         (
+             pl.col("score") >= run.threshold
+             if run.score_order.lower() == "descending"
+             else pl.col("score") <= run.threshold
+         )
+         if run.threshold is not None
+         else True
+     )
+
+
+ def process_stats(
+     runs: List[RunConfig], benchmark_type: BenchmarkOutputType
+ ) -> Tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame]:
+     """
+     Process stats outputs for the specified runs to compare.
+     Args:
+         runs (List[RunConfig]): List of runs to benchmark.
+         benchmark_type (BenchmarkOutputType): Benchmark output type.
+     Returns:
+         Tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame]: The stats for all runs.
+     """
+     stats, curve_results, true_positive_cases = [], [], []
+     for run in runs:
+         result_scan = scan_directory(run, benchmark_type)
+         stats.append(
+             compute_rank_stats(run.run_identifier, result_scan).join(
+                 compute_confusion_matrix(run.run_identifier, result_scan), on="run_identifier"
+             )
+         )
+         curve_results.append(compute_curves(run.run_identifier, result_scan))
+         true_positive_cases.append(
+             result_scan.filter(pl.col("true_positive")).select(
+                 ["result_file", *benchmark_type.columns, pl.col("rank").alias(run.run_identifier)]
+             )
+         )
+     return (
+         pl.concat(stats, how="vertical").collect(),
+         pl.concat(curve_results, how="vertical").collect(),
+         pl.concat(true_positive_cases, how="align_inner").collect(),
+     )
+
+
+ def benchmark(config: Config, benchmark_type: BenchmarkOutputType) -> None:
+     """
+     Benchmark results for specified runs for a specified prioritisation type for comparison.
+     Args:
+         config (Config): Configuration for benchmarking.
+         benchmark_type (BenchmarkOutputType): Benchmark output type.
+     """
+     conn = duckdb.connect(f"{config.benchmark_name}.duckdb")
+     stats, curve_results, true_positive_cases = process_stats(config.runs, benchmark_type)
+     write_table(
+         conn, stats, f"{config.benchmark_name}_{benchmark_type.prioritisation_type_string}_summary"
+     )
+     write_table(
+         conn,
+         curve_results,
+         f"{config.benchmark_name}_{benchmark_type.prioritisation_type_string}_binary_classification_curves",
+     )
+     calculate_rank_changes(
+         conn, [run.run_identifier for run in config.runs], true_positive_cases, benchmark_type
+     )
+     generate_plots(
+         config.benchmark_name, stats, curve_results, benchmark_type, config.plot_customisation
+     )
+     conn.close()
+
+
+ def benchmark_runs(benchmark_config_file: Path) -> None:
+     """
+     Benchmark results for specified runs for comparison.
+     Args:
+         benchmark_config_file (Path): Path to benchmark config file.
+     """
+     logger = get_logger()
+     start_time = time.perf_counter()
+     logger.info("Initiated benchmarking process.")
+     config = parse_run_config(benchmark_config_file)
+     gene_analysis_runs = [run for run in config.runs if run.gene_analysis]
+     variant_analysis_runs = [run for run in config.runs if run.variant_analysis]
+     disease_analysis_runs = [run for run in config.runs if run.disease_analysis]
+     if gene_analysis_runs:
+         logger.info("Initiating benchmarking for gene results.")
+         benchmark(
+             Config(
+                 benchmark_name=config.benchmark_name,
+                 runs=gene_analysis_runs,
+                 plot_customisation=config.plot_customisation,
+             ),
+             BenchmarkOutputTypeEnum.GENE.value,
+         )
+         logger.info("Finished benchmarking for gene results.")
+     if variant_analysis_runs:
+         logger.info("Initiating benchmarking for variant results")
+         benchmark(
+             Config(
+                 benchmark_name=config.benchmark_name,
+                 runs=variant_analysis_runs,
+                 plot_customisation=config.plot_customisation,
+             ),
+             BenchmarkOutputTypeEnum.VARIANT.value,
+         )
+         logger.info("Finished benchmarking for variant results.")
+     if disease_analysis_runs:
+         logger.info("Initiating benchmarking for disease results")
+         benchmark(
+             Config(
+                 benchmark_name=config.benchmark_name,
+                 runs=disease_analysis_runs,
+                 plot_customisation=config.plot_customisation,
+             ),
+             BenchmarkOutputTypeEnum.DISEASE.value,
+         )
+         logger.info("Finished benchmarking for disease results.")
+     logger.info(
+         f"Finished benchmarking! Total time: {time.perf_counter() - start_time:.2f} seconds."
+     )
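A minimal usage sketch of the new benchmark_runs entry point defined above, assuming a benchmark run configuration YAML (the file name here is illustrative, not part of the release):

    from pathlib import Path

    from pheval.analyse.benchmark import benchmark_runs

    # Parses the run configuration, benchmarks gene/variant/disease results for the
    # configured runs, and stores summary tables in <benchmark_name>.duckdb.
    benchmark_runs(Path("benchmark_config.yaml"))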
pheval-0.5.0/src/pheval/analyse/benchmark_db_manager.py

@@ -0,0 +1,23 @@
+ import polars as pl
+ from duckdb import DuckDBPyConnection
+
+ from pheval.utils.logger import get_logger
+
+ logger = get_logger()
+
+
+ def load_table_lazy(table_name: str, conn: DuckDBPyConnection) -> pl.LazyFrame:
+     logger.info(f"Loading table {table_name}")
+     return pl.from_arrow(conn.execute(f"SELECT * FROM {table_name}").fetch_arrow_table()).lazy()
+
+
+ def write_table(conn: DuckDBPyConnection, df: pl.DataFrame, table_name: str) -> None:
+     """
+     Write a table to the DuckDB database.
+     Args:
+         conn (DuckDBPyConnection): DuckDB connection.
+         df (pl.DataFrame): Polars DataFrame to store.
+         table_name (str): Table name.
+     """
+     logger.info(f"Storing results in {table_name}.")
+     conn.execute(f"""CREATE TABLE "{table_name}" AS SELECT * FROM df""")
pheval-0.5.0/src/pheval/analyse/benchmark_output_type.py

@@ -0,0 +1,43 @@
+ from enum import Enum
+ from typing import List, NamedTuple
+
+
+ class BenchmarkOutputType(NamedTuple):
+     """
+     Represents the structure of benchmark output types.
+
+     Attributes:
+         prioritisation_type_string (str): The type of prioritisation being performed.
+         y_label (str): The label for the y-axis in performance evaluation plots.
+         columns (List[str]): The list of column names relevant to the benchmark output.
+         result_directory (str): The directory where benchmark results are stored.
+     """
+
+     prioritisation_type_string: str
+     y_label: str
+     columns: List[str]
+     result_directory: str
+
+
+ class BenchmarkOutputTypeEnum(Enum):
+     """
+     Enumeration of benchmark output types, representing different entities.
+
+     Attributes:
+         GENE (BenchmarkOutputType): Benchmark output type for gene prioritisation.
+         VARIANT (BenchmarkOutputType): Benchmark output type for variant prioritisation.
+         DISEASE (BenchmarkOutputType): Benchmark output type for disease prioritisation.
+     """
+
+     GENE = BenchmarkOutputType(
+         "gene",
+         "Disease-causing genes (%)",
+         ["gene_identifier", "gene_symbol"],
+         "pheval_gene_results",
+     )
+     VARIANT = BenchmarkOutputType(
+         "variant", "Disease-causing variants (%)", ["variant_id"], "pheval_variant_results"
+     )
+     DISEASE = BenchmarkOutputType(
+         "disease", "Known diseases (%)", ["disease_identifier"], "pheval_disease_results"
+     )
pheval-0.5.0/src/pheval/analyse/binary_classification_curves.py

@@ -0,0 +1,132 @@
+ from typing import Tuple
+
+ import numpy as np
+ import polars as pl
+ from sklearn.metrics import precision_recall_curve, roc_curve
+
+ from pheval.utils.logger import get_logger
+
+
+ class BinaryClassificationCurves:
+     """Class for computing and storing ROC & Precision-Recall curves in Polars."""
+
+     @staticmethod
+     def _compute_finite_bounds(result_scan: pl.LazyFrame) -> Tuple[float, float]:
+         """
+         Compute min and max finite values in the 'score' column to handle NaN and Inf values.
+         Args:
+             result_scan (pl.LazyFrame): The LazyFrame containing the results for the directory.
+
+         Returns:
+             Tuple[float, float]: The (max_finite, min_finite) values for normalising scores.
+         """
+         return (
+             result_scan.select(
+                 [
+                     pl.col("score").filter(pl.col("score").is_finite()).max().alias("max_finite"),
+                     pl.col("score").filter(pl.col("score").is_finite()).min().alias("min_finite"),
+                 ]
+             )
+             .collect()
+             .row(0)
+         )
+
+     @staticmethod
+     def _clean_and_extract_data(
+         result_scan: pl.LazyFrame, max_finite: float, min_finite: float
+     ) -> pl.LazyFrame:
+         """
+         Normalise the 'score' column (handling NaNs and Inf values) and extract 'true_positive' labels.
+
+         Args:
+             result_scan (pl.LazyFrame): The LazyFrame containing the results for the directory.
+             max_finite (float): The maximum finite score value.
+             min_finite (float): The minimum finite score value.
+
+         Returns:
+             pl.LazyFrame: A LazyFrame with cleaned 'score' and binary 'true_positive' columns.
+         """
+         return result_scan.with_columns(
+             [
+                 pl.when(pl.col("score").is_nan())
+                 .then(0.0)
+                 .when(pl.col("score").is_infinite() & (pl.col("score") > 0))
+                 .then(max_finite)
+                 .when(pl.col("score").is_infinite() & (pl.col("score") < 0))
+                 .then(min_finite)
+                 .otherwise(pl.col("score"))
+                 .alias("score"),
+                 pl.when(pl.col("true_positive").is_null())
+                 .then(0)
+                 .otherwise(pl.col("true_positive").cast(pl.Int8))
+                 .alias("true_positive"),
+             ]
+         )
+
+     @staticmethod
+     def _compute_roc_pr_curves(
+         run_identifier: str, labels: np.ndarray, scores: np.ndarray
+     ) -> pl.LazyFrame:
+         """
+         Compute ROC and Precision-Recall curves.
+
+         Args:
+             labels (np.ndarray): Binary ground truth labels (0 or 1).
+             scores (np.ndarray): Prediction scores.
+
+         Returns:
+             pl.LazyFrame: A LazyFrame containing the computed FPR, TPR, Precision, Recall, and Thresholds.
+         """
+         fpr, tpr, roc_thresholds = roc_curve(labels, scores, pos_label=1)
+         precision, recall, pr_thresholds = precision_recall_curve(labels, scores, pos_label=1)
+
+         return pl.LazyFrame(
+             {
+                 "run_identifier": [run_identifier],
+                 "fpr": [fpr.tolist()],
+                 "tpr": [tpr.tolist()],
+                 "threshold_roc": [roc_thresholds.tolist()],
+                 "precision": [precision.tolist()],
+                 "recall": [recall.tolist()],
+                 "threshold_pr": [pr_thresholds.tolist()],
+             }
+         )
+
+     @classmethod
+     def process(cls, result_scan: pl.LazyFrame, run_identifier: str) -> pl.LazyFrame:
+         """
+         Process scores, extract true labels, compute ROC and Precision-Recall curves,
+         and store results in a Polars LazyFrame with NumPy arrays.
+
+         Args:
+             result_scan (pl.LazyFrame): The LazyFrame containing the results for the directory.
+             run_identifier (str): Identifier for this run.
+
+         Returns:
+             pl.LazyFrame: A LazyFrame containing ROC & PR curve data with NumPy arrays.
+         """
+         max_finite, min_finite = cls._compute_finite_bounds(result_scan)
+         cleaned_data = (
+             cls._clean_and_extract_data(result_scan, max_finite, min_finite)
+             .select(["true_positive", "score"])
+             .collect()
+         )
+         return cls._compute_roc_pr_curves(
+             run_identifier,
+             cleaned_data["true_positive"].to_numpy().flatten(),
+             cleaned_data["score"].to_numpy().flatten(),
+         )
+
+
+ def compute_curves(run_identifier: str, result_scan: pl.LazyFrame) -> pl.LazyFrame:
+     """
+     Compute ROC and Precision-Recall curves.
+     Args:
+         result_scan (pl.LazyFrame): The LazyFrame containing the results for the directory.
+         run_identifier (str): Identifier for this run.
+     Returns:
+         pl.LazyFrame: LazyFrame containing the ROC & Precision-Recall curve data with NumPy arrays.
+     """
+     logger = get_logger()
+     logger.info("Calculating ROC and Precision-Recall metrics")
+     return BinaryClassificationCurves.process(result_scan, run_identifier)
pheval-0.5.0/src/pheval/analyse/binary_classification_stats.py

@@ -0,0 +1,186 @@
+ from dataclasses import dataclass
+ from pheval.utils.logger import get_logger
+
+ import polars as pl
+
+
+ @dataclass(frozen=True)
+ class ConfusionMatrix:
+     """
+     Define logical conditions for computing a confusion matrix using Polars expressions.
+
+     Attributes:
+         TRUE_POSITIVES (pl.Expr): Condition identifying true positive cases,
+             where `rank == 1` and `true_positive` is `True`.
+         FALSE_POSITIVES (pl.Expr): Condition identifying false positive cases,
+             where `rank == 1` and `true_positive` is `False`.
+         TRUE_NEGATIVES (pl.Expr): Condition identifying true negative cases,
+             where `rank != 1` and `true_positive` is `False`.
+         FALSE_NEGATIVES (pl.Expr): Condition identifying false negative cases,
+             where `rank != 1` and `true_positive` is `True`.
+     """
+
+     TRUE_POSITIVES = (pl.col("rank") == 1) & (pl.col("true_positive"))
+     FALSE_POSITIVES = (pl.col("rank") == 1) & (~pl.col("true_positive"))
+     TRUE_NEGATIVES = (pl.col("rank") != 1) & (~pl.col("true_positive"))
+     FALSE_NEGATIVES = (pl.col("rank") != 1) & (pl.col("true_positive"))
+
+
+ @dataclass(frozen=True)
+ class BinaryClassificationStats:
+     """Binary classification statistic expressions."""
+
+     SENSITIVITY = (
+         pl.when((pl.col("true_positives") + pl.col("false_negatives")) != 0)
+         .then(pl.col("true_positives") / (pl.col("true_positives") + pl.col("false_negatives")))
+         .otherwise(0.0)
+         .alias("sensitivity")
+     )
+
+     SPECIFICITY = (
+         pl.when((pl.col("true_negatives") + pl.col("false_positives")) != 0)
+         .then(pl.col("true_negatives") / (pl.col("true_negatives") + pl.col("false_positives")))
+         .otherwise(0.0)
+         .alias("specificity")
+     )
+
+     PRECISION = (
+         pl.when((pl.col("true_positives") + pl.col("false_positives")) != 0)
+         .then(pl.col("true_positives") / (pl.col("true_positives") + pl.col("false_positives")))
+         .otherwise(0.0)
+         .alias("precision")
+     )
+
+     NEGATIVE_PREDICTIVE_VALUE = (
+         pl.when((pl.col("true_negatives") + pl.col("false_negatives")) != 0)
+         .then(pl.col("true_negatives") / (pl.col("true_negatives") + pl.col("false_negatives")))
+         .otherwise(0.0)
+         .alias("negative_predictive_value")
+     )
+
+     FALSE_POSITIVE_RATE = (
+         pl.when((pl.col("false_positives") + pl.col("true_negatives")) != 0)
+         .then(pl.col("false_positives") / (pl.col("false_positives") + pl.col("true_negatives")))
+         .otherwise(0.0)
+         .alias("false_positive_rate")
+     )
+
+     FALSE_DISCOVERY_RATE = (
+         pl.when((pl.col("false_positives") + pl.col("true_positives")) != 0)
+         .then(pl.col("false_positives") / (pl.col("false_positives") + pl.col("true_positives")))
+         .otherwise(0.0)
+         .alias("false_discovery_rate")
+     )
+
+     FALSE_NEGATIVE_RATE = (
+         pl.when((pl.col("false_negatives") + pl.col("true_positives")) != 0)
+         .then(pl.col("false_negatives") / (pl.col("false_negatives") + pl.col("true_positives")))
+         .otherwise(0.0)
+         .alias("false_negative_rate")
+     )
+
+     ACCURACY = (
+         pl.when(
+             (
+                 pl.col("true_positives")
+                 + pl.col("false_positives")
+                 + pl.col("true_negatives")
+                 + pl.col("false_negatives")
+             )
+             != 0
+         )
+         .then(
+             (pl.col("true_positives") + pl.col("true_negatives"))
+             / (
+                 pl.col("true_positives")
+                 + pl.col("false_positives")
+                 + pl.col("true_negatives")
+                 + pl.col("false_negatives")
+             )
+         )
+         .otherwise(0.0)
+         .alias("accuracy")
+     )
+
+     F1_SCORE = (
+         pl.when(
+             2 * (pl.col("true_positives") + pl.col("false_positives") + pl.col("false_negatives"))
+             != 0
+         )
+         .then(
+             2
+             * pl.col("true_positives")
+             / (2 * pl.col("true_positives") + pl.col("false_positives") + pl.col("false_negatives"))
+         )
+         .otherwise(0.0)
+         .alias("f1_score")
+     )
+
+     MATTHEWS_CORRELATION_COEFFICIENT = (
+         pl.when(
+             (
+                 (pl.col("true_positives") + pl.col("false_positives"))
+                 * (pl.col("true_positives") + pl.col("false_negatives"))
+                 * (pl.col("true_negatives") + pl.col("false_positives"))
+                 * (pl.col("true_negatives") + pl.col("false_negatives"))
+             )
+             > 0
+         )
+         .then(
+             (
+                 (pl.col("true_positives") * pl.col("true_negatives"))
+                 - (pl.col("false_positives") * pl.col("false_negatives"))
+             )
+             / (
+                 (pl.col("true_positives") + pl.col("false_positives"))
+                 * (pl.col("true_positives") + pl.col("false_negatives"))
+                 * (pl.col("true_negatives") + pl.col("false_positives"))
+                 * (pl.col("true_negatives") + pl.col("false_negatives"))
+             ).sqrt()
+         )
+         .otherwise(0.0)
+         .alias("matthews_correlation_coefficient")
+     )
+
+
+ def compute_confusion_matrix(run_identifier: str, result_scan: pl.LazyFrame) -> pl.LazyFrame:
+     """
+     Compute binary classification statistics.
+
+     Args:
+         run_identifier (str): The identifier for the run.
+         result_scan (pl.LazyFrame): The LazyFrame containing the results for the directory.
+
+     Returns:
+         pl.LazyFrame: The LazyFrame containing the binary classification statistics.
+     """
+     logger = get_logger()
+     logger.info(f"Computing binary classification statistics for {run_identifier}")
+     confusion_matrix = result_scan.select(
+         [
+             pl.lit(run_identifier).alias("run_identifier"),
+             ConfusionMatrix.TRUE_POSITIVES.sum().alias("true_positives").cast(pl.Int64),
+             ConfusionMatrix.FALSE_POSITIVES.sum().alias("false_positives").cast(pl.Int64),
+             ConfusionMatrix.TRUE_NEGATIVES.sum().alias("true_negatives").cast(pl.Int64),
+             ConfusionMatrix.FALSE_NEGATIVES.sum().alias("false_negatives").cast(pl.Int64),
+         ]
+     )
+     return confusion_matrix.select(
+         [
+             pl.col("run_identifier"),
+             pl.col("true_positives"),
+             pl.col("false_positives"),
+             pl.col("true_negatives"),
+             pl.col("false_negatives"),
+             BinaryClassificationStats.SENSITIVITY,
+             BinaryClassificationStats.SPECIFICITY,
+             BinaryClassificationStats.PRECISION,
+             BinaryClassificationStats.NEGATIVE_PREDICTIVE_VALUE,
+             BinaryClassificationStats.FALSE_POSITIVE_RATE,
+             BinaryClassificationStats.FALSE_DISCOVERY_RATE,
+             BinaryClassificationStats.FALSE_NEGATIVE_RATE,
+             BinaryClassificationStats.ACCURACY,
+             BinaryClassificationStats.F1_SCORE,
+             BinaryClassificationStats.MATTHEWS_CORRELATION_COEFFICIENT,
+         ]
+     )