pheval 0.4.6__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pheval might be problematic. Click here for more details.
- pheval/analyse/benchmark.py +156 -0
- pheval/analyse/benchmark_db_manager.py +16 -134
- pheval/analyse/benchmark_output_type.py +43 -0
- pheval/analyse/binary_classification_curves.py +132 -0
- pheval/analyse/binary_classification_stats.py +164 -307
- pheval/analyse/generate_plots.py +210 -395
- pheval/analyse/generate_rank_comparisons.py +44 -0
- pheval/analyse/rank_stats.py +190 -382
- pheval/analyse/run_data_parser.py +21 -39
- pheval/cli.py +28 -25
- pheval/cli_pheval_utils.py +7 -8
- pheval/post_processing/phenopacket_truth_set.py +235 -0
- pheval/post_processing/post_processing.py +183 -303
- pheval/post_processing/validate_result_format.py +92 -0
- pheval/prepare/update_phenopacket.py +11 -9
- pheval/utils/logger.py +35 -0
- pheval/utils/phenopacket_utils.py +85 -91
- {pheval-0.4.6.dist-info → pheval-0.5.0.dist-info}/METADATA +4 -4
- {pheval-0.4.6.dist-info → pheval-0.5.0.dist-info}/RECORD +22 -26
- {pheval-0.4.6.dist-info → pheval-0.5.0.dist-info}/WHEEL +1 -1
- pheval/analyse/analysis.py +0 -104
- pheval/analyse/assess_prioritisation_base.py +0 -108
- pheval/analyse/benchmark_generator.py +0 -126
- pheval/analyse/benchmarking_data.py +0 -25
- pheval/analyse/disease_prioritisation_analysis.py +0 -152
- pheval/analyse/gene_prioritisation_analysis.py +0 -147
- pheval/analyse/generate_summary_outputs.py +0 -105
- pheval/analyse/parse_benchmark_summary.py +0 -81
- pheval/analyse/parse_corpus.py +0 -219
- pheval/analyse/prioritisation_result_types.py +0 -52
- pheval/analyse/variant_prioritisation_analysis.py +0 -159
- {pheval-0.4.6.dist-info → pheval-0.5.0.dist-info}/LICENSE +0 -0
- {pheval-0.4.6.dist-info → pheval-0.5.0.dist-info}/entry_points.txt +0 -0
|
@@ -2,7 +2,9 @@ from pathlib import Path
|
|
|
2
2
|
from typing import List, Optional
|
|
3
3
|
|
|
4
4
|
import yaml
|
|
5
|
-
from pydantic import BaseModel,
|
|
5
|
+
from pydantic import BaseModel, field_validator
|
|
6
|
+
|
|
7
|
+
from pheval.utils.logger import get_logger
|
|
6
8
|
|
|
7
9
|
|
|
8
10
|
class RunConfig(BaseModel):
|
|
@@ -12,10 +14,10 @@ class RunConfig(BaseModel):
|
|
|
12
14
|
Attributes:
|
|
13
15
|
run_identifier (str): The run identifier.
|
|
14
16
|
phenopacket_dir (str): The path to the phenopacket directory used for generating the results.
|
|
15
|
-
results_dir (str): The path to the
|
|
16
|
-
gene_analysis (bool): Whether
|
|
17
|
-
variant_analysis (bool): Whether
|
|
18
|
-
disease_analysis (bool): Whether
|
|
17
|
+
results_dir (str): The path to the result directory.
|
|
18
|
+
gene_analysis (bool): Whether to benchmark gene analysis results.
|
|
19
|
+
variant_analysis (bool): Whether to benchmark variant analysis results.
|
|
20
|
+
disease_analysis (bool): Whether to benchmark disease analysis results.
|
|
19
21
|
threshold (Optional[float]): The threshold to consider for benchmarking.
|
|
20
22
|
score_order (Optional[str]): The order of scores to consider for benchmarking, either ascending or descending.
|
|
21
23
|
"""
|
|
@@ -29,25 +31,15 @@ class RunConfig(BaseModel):
|
|
|
29
31
|
threshold: Optional[float]
|
|
30
32
|
score_order: Optional[str]
|
|
31
33
|
|
|
32
|
-
@
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
This method checks if 'threshold' and 'score_order' are None and assigns default values if so.
|
|
38
|
-
|
|
39
|
-
Args:
|
|
40
|
-
values (dict): The input values provided to the model.
|
|
34
|
+
@field_validator("threshold", mode="before")
|
|
35
|
+
@classmethod
|
|
36
|
+
def set_threshold(cls, threshold):
|
|
37
|
+
return threshold or None
|
|
41
38
|
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
values["threshold"] = 0
|
|
47
|
-
print("setting default threshold")
|
|
48
|
-
if values.get("score_order") is None:
|
|
49
|
-
values["score_order"] = "descending"
|
|
50
|
-
return values
|
|
39
|
+
@field_validator("score_order", mode="before")
|
|
40
|
+
@classmethod
|
|
41
|
+
def set_score_order(cls, score_order):
|
|
42
|
+
return score_order or "descending"
|
|
51
43
|
|
|
52
44
|
|
|
53
45
|
class SinglePlotCustomisation(BaseModel):
|
|
@@ -66,22 +58,10 @@ class SinglePlotCustomisation(BaseModel):
|
|
|
66
58
|
roc_curve_title: Optional[str]
|
|
67
59
|
precision_recall_title: Optional[str]
|
|
68
60
|
|
|
69
|
-
@
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
This method checks if 'plot_type' is None and assigns default value if so.
|
|
75
|
-
|
|
76
|
-
Args:
|
|
77
|
-
values (dict): The input values provided to the model.
|
|
78
|
-
|
|
79
|
-
Returns:
|
|
80
|
-
dict: The updated values with defaults applied where necessary.
|
|
81
|
-
"""
|
|
82
|
-
if values.get("plot_type") is None:
|
|
83
|
-
values["plot_type"] = "bar_cumulative"
|
|
84
|
-
return values
|
|
61
|
+
@field_validator("plot_type", mode="before")
|
|
62
|
+
@classmethod
|
|
63
|
+
def set_plot_type(cls, plot_type):
|
|
64
|
+
return plot_type or "bar_cumulative"
|
|
85
65
|
|
|
86
66
|
|
|
87
67
|
class PlotCustomisation(BaseModel):
|
|
@@ -118,6 +98,8 @@ def parse_run_config(run_config: Path) -> Config:
|
|
|
118
98
|
Returns:
|
|
119
99
|
Config: The parsed run configurations.
|
|
120
100
|
"""
|
|
101
|
+
logger = get_logger()
|
|
102
|
+
logger.info(f"Loading benchmark configuration from {run_config}")
|
|
121
103
|
with open(run_config, "r") as f:
|
|
122
104
|
config_data = yaml.safe_load(f)
|
|
123
105
|
f.close()
|
pheval/cli.py
CHANGED
|
@@ -1,14 +1,16 @@
|
|
|
1
|
-
"""PhEval CLI Module
|
|
1
|
+
"""PhEval CLI Module"""
|
|
2
2
|
|
|
3
3
|
import logging
|
|
4
4
|
|
|
5
5
|
import click
|
|
6
6
|
|
|
7
|
+
from pheval.utils.logger import get_logger, initialise_context
|
|
8
|
+
|
|
7
9
|
from .cli_pheval import run
|
|
8
10
|
from .cli_pheval_utils import (
|
|
11
|
+
benchmark,
|
|
9
12
|
create_spiked_vcfs_command,
|
|
10
|
-
|
|
11
|
-
generate_stats_plot,
|
|
13
|
+
generate_plots,
|
|
12
14
|
prepare_corpus_command,
|
|
13
15
|
scramble_phenopackets_command,
|
|
14
16
|
semsim_scramble_command,
|
|
@@ -16,50 +18,51 @@ from .cli_pheval_utils import (
|
|
|
16
18
|
update_phenopackets_command,
|
|
17
19
|
)
|
|
18
20
|
|
|
19
|
-
|
|
21
|
+
logger = get_logger()
|
|
20
22
|
|
|
21
23
|
|
|
22
24
|
@click.group()
|
|
23
25
|
@click.option("-v", "--verbose", count=True)
|
|
24
|
-
@click.option("-q", "--quiet")
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
quiet (bool, optional): Queit Flag.
|
|
31
|
-
"""
|
|
26
|
+
@click.option("-q", "--quiet", is_flag=True)
|
|
27
|
+
@click.pass_context
|
|
28
|
+
def main(ctx, verbose=1, quiet=False):
|
|
29
|
+
"""Main CLI method for PhEval."""
|
|
30
|
+
initialise_context(ctx)
|
|
31
|
+
|
|
32
32
|
if verbose >= 2:
|
|
33
|
-
|
|
33
|
+
logger.setLevel(logging.DEBUG)
|
|
34
34
|
elif verbose == 1:
|
|
35
|
-
|
|
35
|
+
logger.setLevel(logging.INFO)
|
|
36
36
|
else:
|
|
37
|
-
|
|
37
|
+
logger.setLevel(logging.WARNING)
|
|
38
38
|
if quiet:
|
|
39
|
-
|
|
39
|
+
logger.setLevel(logging.ERROR)
|
|
40
40
|
|
|
41
41
|
|
|
42
|
-
@
|
|
43
|
-
|
|
42
|
+
@main.group()
|
|
43
|
+
@click.pass_context
|
|
44
|
+
def pheval(ctx):
|
|
44
45
|
"""pheval"""
|
|
46
|
+
initialise_context(ctx)
|
|
45
47
|
|
|
46
48
|
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
@click.group()
|
|
51
|
-
def pheval_utils():
|
|
49
|
+
@main.group()
|
|
50
|
+
@click.pass_context
|
|
51
|
+
def pheval_utils(ctx):
|
|
52
52
|
"""pheval_utils"""
|
|
53
|
+
initialise_context(ctx)
|
|
53
54
|
|
|
54
55
|
|
|
56
|
+
pheval.add_command(run)
|
|
57
|
+
|
|
55
58
|
pheval_utils.add_command(semsim_scramble_command)
|
|
56
59
|
pheval_utils.add_command(scramble_phenopackets_command)
|
|
57
60
|
pheval_utils.add_command(update_phenopackets_command)
|
|
58
61
|
pheval_utils.add_command(create_spiked_vcfs_command)
|
|
59
|
-
pheval_utils.add_command(
|
|
62
|
+
pheval_utils.add_command(benchmark)
|
|
60
63
|
pheval_utils.add_command(semsim_to_exomiserdb_command)
|
|
61
|
-
pheval_utils.add_command(generate_stats_plot)
|
|
62
64
|
pheval_utils.add_command(prepare_corpus_command)
|
|
65
|
+
pheval_utils.add_command(generate_plots)
|
|
63
66
|
|
|
64
67
|
if __name__ == "__main__":
|
|
65
68
|
main()
|
pheval/cli_pheval_utils.py
CHANGED
|
@@ -5,9 +5,8 @@ from typing import List
|
|
|
5
5
|
|
|
6
6
|
import click
|
|
7
7
|
|
|
8
|
-
from pheval.analyse.
|
|
9
|
-
from pheval.analyse.generate_plots import
|
|
10
|
-
from pheval.analyse.run_data_parser import parse_run_config
|
|
8
|
+
from pheval.analyse.benchmark import benchmark_runs
|
|
9
|
+
from pheval.analyse.generate_plots import generate_plots_from_db
|
|
11
10
|
from pheval.prepare.create_noisy_phenopackets import scramble_phenopackets
|
|
12
11
|
from pheval.prepare.create_spiked_vcf import spike_vcfs
|
|
13
12
|
from pheval.prepare.custom_exceptions import InputError, MutuallyExclusiveOptionError
|
|
@@ -353,12 +352,12 @@ def create_spiked_vcfs_command(
|
|
|
353
352
|
help="Path to yaml configuration file for benchmarking.",
|
|
354
353
|
type=Path,
|
|
355
354
|
)
|
|
356
|
-
def
|
|
355
|
+
def benchmark(
|
|
357
356
|
run_yaml: Path,
|
|
358
357
|
):
|
|
359
358
|
"""Benchmark the gene/variant/disease prioritisation performance for runs."""
|
|
360
|
-
|
|
361
|
-
|
|
359
|
+
benchmark_runs(
|
|
360
|
+
run_yaml,
|
|
362
361
|
)
|
|
363
362
|
|
|
364
363
|
|
|
@@ -426,12 +425,12 @@ def semsim_to_exomiserdb_command(
|
|
|
426
425
|
help="Path to yaml configuration file for benchmarking.",
|
|
427
426
|
type=Path,
|
|
428
427
|
)
|
|
429
|
-
def
|
|
428
|
+
def generate_plots(
|
|
430
429
|
benchmark_db: Path,
|
|
431
430
|
run_data: Path,
|
|
432
431
|
):
|
|
433
432
|
"""Generate bar plot from benchmark db."""
|
|
434
|
-
|
|
433
|
+
generate_plots_from_db(benchmark_db, run_data)
|
|
435
434
|
|
|
436
435
|
|
|
437
436
|
@click.command("prepare-corpus")
|
|
@@ -0,0 +1,235 @@
|
|
|
1
|
+
from pathlib import Path
|
|
2
|
+
from typing import List
|
|
3
|
+
|
|
4
|
+
import polars as pl
|
|
5
|
+
|
|
6
|
+
from pheval.utils.phenopacket_utils import (
|
|
7
|
+
GenomicVariant,
|
|
8
|
+
PhenopacketUtil,
|
|
9
|
+
ProbandCausativeGene,
|
|
10
|
+
ProbandDisease,
|
|
11
|
+
phenopacket_reader,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class PhenopacketTruthSet:
    """Class for finding the causative gene/disease/variant from a phenopacket.

    Each result is matched to the phenopacket of the same name
    (``<result_name>.json``) inside ``phenopacket_dir``. The diagnosed genes,
    variants and diseases recorded in that phenopacket form the truth set used
    to flag true positives in ranked results.
    """

    def __init__(self, phenopacket_dir: Path):
        """
        Initialise the truth set.
        Args:
            phenopacket_dir (Path): Directory containing the phenopacket JSON files.
                A ``str`` path is also accepted and converted to ``Path``.
        """
        # Coerce so that str inputs also support the Path operations used below.
        self.phenopacket_dir = Path(phenopacket_dir)

    def _get_phenopacket_path(self, phenopacket_name: str) -> Path:
        """
        Get the phenopacket path for a given phenopacket name.
        Args:
            phenopacket_name (str): Name of the phenopacket (without the ``.json`` suffix).
        Returns:
            Path: Path to the phenopacket file.
        Raises:
            FileNotFoundError: If ``<phenopacket_name>.json`` does not exist in the directory.
        """
        phenopacket_path = self.phenopacket_dir / f"{phenopacket_name}.json"
        if not phenopacket_path.exists():
            raise FileNotFoundError(phenopacket_name + " not found in corpus!")
        return phenopacket_path

    def _get_phenopacket_util(self, phenopacket_name: str) -> PhenopacketUtil:
        """
        Get the phenopacket util for a given phenopacket name.
        Args:
            phenopacket_name (str): Name of the phenopacket.
        Returns:
            PhenopacketUtil: PhenopacketUtil wrapping the parsed phenopacket.
        """
        phenopacket = phenopacket_reader(self._get_phenopacket_path(phenopacket_name))
        return PhenopacketUtil(phenopacket)

    def _get_causative_genes(self, phenopacket_name: str) -> List[ProbandCausativeGene]:
        """
        Get the causative genes for a given phenopacket.
        Args:
            phenopacket_name (str): Name of the phenopacket.
        Returns:
            List[ProbandCausativeGene]: Genes returned by ``PhenopacketUtil.diagnosed_genes``.
        """
        return self._get_phenopacket_util(phenopacket_name).diagnosed_genes()

    def _get_causative_variants(self, phenopacket_name: str) -> List[GenomicVariant]:
        """
        Get the causative variants for a given phenopacket.
        Args:
            phenopacket_name (str): Name of the phenopacket.
        Returns:
            List[GenomicVariant]: Variants returned by ``PhenopacketUtil.diagnosed_variants``.
        """
        return self._get_phenopacket_util(phenopacket_name).diagnosed_variants()

    def _get_causative_diseases(self, phenopacket_name: str) -> List[ProbandDisease]:
        """
        Get the diseases for a given phenopacket.
        Args:
            phenopacket_name (str): Name of the phenopacket.
        Returns:
            List[ProbandDisease]: Diseases returned by ``PhenopacketUtil.diagnoses``.
        """
        return self._get_phenopacket_util(phenopacket_name).diagnoses()

    def classified_gene(self, result_name: str) -> pl.DataFrame:
        """
        Classify gene results for a given phenopacket.
        Args:
            result_name (str): Name of the result file (same as the phenopacket name).
        Returns:
            pl.DataFrame: One row per causative gene with ``gene_symbol``,
                ``gene_identifier``, ``score`` = 0.0, ``rank`` = 0 and
                ``true_positive`` = True.
        """
        causative_genes = self._get_causative_genes(result_name)
        return pl.DataFrame(
            {
                "gene_symbol": [gene.gene_symbol for gene in causative_genes],
                "gene_identifier": [gene.gene_identifier for gene in causative_genes],
            }
        ).with_columns(
            [
                pl.lit(0.0).cast(pl.Float64).alias("score"),
                pl.lit(0).cast(pl.Int64).alias("rank"),
                pl.lit(True).alias("true_positive"),
            ]
        )

    @staticmethod
    def merge_gene_results(ranked_results: pl.DataFrame, output_file: Path) -> pl.DataFrame:
        """
        Merge ranked gene results with the classified genes.

        A ranked gene is a true positive when its symbol OR its identifier
        matches a causative gene. Causative genes that the ranked results do
        not contain (by the same symbol-or-identifier rule) are appended
        unchanged from the classified file.
        Args:
            ranked_results (pl.DataFrame): Ranked gene results.
            output_file (Path): Path to the parquet file of classified genes.
        Returns:
            pl.DataFrame: Merged ranked gene results, restricted to the classified columns.
        """
        classified_results = pl.read_parquet(output_file)
        is_causative = pl.col("gene_symbol").is_in(classified_results["gene_symbol"]) | pl.col(
            "gene_identifier"
        ).is_in(classified_results["gene_identifier"])
        # Use the same symbol-or-identifier rule as above. Otherwise a causative gene
        # matched only by its identifier would be both a ranked true positive and
        # re-appended as an unranked row, i.e. counted twice.
        was_ranked = pl.col("gene_symbol").is_in(ranked_results["gene_symbol"]) | pl.col(
            "gene_identifier"
        ).is_in(ranked_results["gene_identifier"])
        return (
            ranked_results.with_columns(is_causative.alias("true_positive"))
            .with_columns(pl.col("rank").cast(pl.Int64))
            .select(classified_results.columns)
            .vstack(classified_results.filter(~was_ranked))
        )

    def classified_variant(self, result_name: str) -> pl.DataFrame:
        """
        Classify variant results for a given phenopacket.
        Args:
            result_name (str): Name of the result file (same as the phenopacket name).
        Returns:
            pl.DataFrame: One row per causative variant with ``chrom``, ``pos``,
                ``ref``, ``alt``, a ``variant_id`` of the form ``chrom-pos-ref-alt``,
                ``score`` = 0.0, ``rank`` = 0 and ``true_positive`` = True.
        """
        variants = self._get_causative_variants(result_name)
        variant_key = ["chrom", "pos", "ref", "alt"]
        return pl.DataFrame(
            {
                "chrom": [variant.chrom for variant in variants],
                "pos": [variant.pos for variant in variants],
                "ref": [variant.ref for variant in variants],
                "alt": [variant.alt for variant in variants],
            }
        ).with_columns(
            [
                pl.concat_str(variant_key, separator="-").alias("variant_id"),
                pl.lit(0.0).cast(pl.Float64).alias("score"),
                pl.lit(0).cast(pl.Int64).alias("rank"),
                pl.lit(True).alias("true_positive"),
            ]
        )

    @staticmethod
    def merge_variant_results(ranked_results: pl.DataFrame, output_file: Path) -> pl.DataFrame:
        """
        Merge ranked variant results with the classified variants.

        Variants are matched on the full (chrom, pos, ref, alt) tuple. Causative
        variants absent from the ranked results are appended unchanged from the
        classified file.
        Args:
            ranked_results (pl.DataFrame): Ranked variant results.
            output_file (Path): Path to the parquet file of classified variants.
        Returns:
            pl.DataFrame: Merged ranked variant results, restricted to the classified columns.
        """
        classified_results = pl.read_parquet(output_file)
        variant_key = ["chrom", "pos", "ref", "alt"]
        causative_keys = classified_results.select(pl.struct(variant_key)).to_series()
        ranked_keys = ranked_results.select(pl.struct(variant_key)).to_series()
        return (
            ranked_results.with_columns(
                pl.struct(variant_key).is_in(causative_keys).alias("true_positive")
            )
            .with_columns(pl.col("rank").cast(pl.Int64))
            .select(classified_results.columns)
            .vstack(classified_results.filter(~pl.struct(variant_key).is_in(ranked_keys)))
        )

    def classified_disease(self, result_name: str) -> pl.DataFrame:
        """
        Classify disease results for a given phenopacket.
        Args:
            result_name (str): Name of the result file (same as the phenopacket name).
        Returns:
            pl.DataFrame: One row per distinct causative disease identifier, in
                first-seen order, with ``score`` = 0.0, ``rank`` = 0 and
                ``true_positive`` = True.
        """
        diseases = self._get_causative_diseases(result_name)
        # dict.fromkeys de-duplicates while keeping first-seen order, unlike set(),
        # whose iteration order is not deterministic across runs for strings.
        disease_identifiers = list(dict.fromkeys(disease.disease_identifier for disease in diseases))
        return pl.DataFrame({"disease_identifier": disease_identifiers}).with_columns(
            [
                pl.lit(0.0).cast(pl.Float64).alias("score"),
                pl.lit(0).cast(pl.Int64).alias("rank"),
                pl.lit(True).alias("true_positive"),
            ]
        )

    @staticmethod
    def merge_disease_results(ranked_results: pl.DataFrame, output_file: Path) -> pl.DataFrame:
        """
        Merge ranked disease results with the classified diseases.

        Diseases are matched on ``disease_identifier``. Causative diseases absent
        from the ranked results are appended unchanged from the classified file.
        Args:
            ranked_results (pl.DataFrame): Ranked disease results.
            output_file (Path): Path to the parquet file of classified diseases.
        Returns:
            pl.DataFrame: Merged ranked disease results, restricted to the classified columns.
        """
        classified_results = pl.read_parquet(output_file)
        is_causative = pl.col("disease_identifier").is_in(classified_results["disease_identifier"])
        was_ranked = pl.col("disease_identifier").is_in(ranked_results["disease_identifier"])
        return (
            ranked_results.with_columns(is_causative.alias("true_positive"))
            .with_columns(pl.col("rank").cast(pl.Int64))
            .select(classified_results.columns)
            .vstack(classified_results.filter(~was_ranked))
        )
|