pheval 0.4.6__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pheval might be problematic. Click here for more details.

Files changed (33) hide show
  1. pheval/analyse/benchmark.py +156 -0
  2. pheval/analyse/benchmark_db_manager.py +16 -134
  3. pheval/analyse/benchmark_output_type.py +43 -0
  4. pheval/analyse/binary_classification_curves.py +132 -0
  5. pheval/analyse/binary_classification_stats.py +164 -307
  6. pheval/analyse/generate_plots.py +210 -395
  7. pheval/analyse/generate_rank_comparisons.py +44 -0
  8. pheval/analyse/rank_stats.py +190 -382
  9. pheval/analyse/run_data_parser.py +21 -39
  10. pheval/cli.py +28 -25
  11. pheval/cli_pheval_utils.py +7 -8
  12. pheval/post_processing/phenopacket_truth_set.py +235 -0
  13. pheval/post_processing/post_processing.py +183 -303
  14. pheval/post_processing/validate_result_format.py +92 -0
  15. pheval/prepare/update_phenopacket.py +11 -9
  16. pheval/utils/logger.py +35 -0
  17. pheval/utils/phenopacket_utils.py +85 -91
  18. {pheval-0.4.6.dist-info → pheval-0.5.0.dist-info}/METADATA +4 -4
  19. {pheval-0.4.6.dist-info → pheval-0.5.0.dist-info}/RECORD +22 -26
  20. {pheval-0.4.6.dist-info → pheval-0.5.0.dist-info}/WHEEL +1 -1
  21. pheval/analyse/analysis.py +0 -104
  22. pheval/analyse/assess_prioritisation_base.py +0 -108
  23. pheval/analyse/benchmark_generator.py +0 -126
  24. pheval/analyse/benchmarking_data.py +0 -25
  25. pheval/analyse/disease_prioritisation_analysis.py +0 -152
  26. pheval/analyse/gene_prioritisation_analysis.py +0 -147
  27. pheval/analyse/generate_summary_outputs.py +0 -105
  28. pheval/analyse/parse_benchmark_summary.py +0 -81
  29. pheval/analyse/parse_corpus.py +0 -219
  30. pheval/analyse/prioritisation_result_types.py +0 -52
  31. pheval/analyse/variant_prioritisation_analysis.py +0 -159
  32. {pheval-0.4.6.dist-info → pheval-0.5.0.dist-info}/LICENSE +0 -0
  33. {pheval-0.4.6.dist-info → pheval-0.5.0.dist-info}/entry_points.txt +0 -0
@@ -2,7 +2,9 @@ from pathlib import Path
2
2
  from typing import List, Optional
3
3
 
4
4
  import yaml
5
- from pydantic import BaseModel, root_validator
5
+ from pydantic import BaseModel, field_validator
6
+
7
+ from pheval.utils.logger import get_logger
6
8
 
7
9
 
8
10
  class RunConfig(BaseModel):
@@ -12,10 +14,10 @@ class RunConfig(BaseModel):
12
14
  Attributes:
13
15
  run_identifier (str): The run identifier.
14
16
  phenopacket_dir (str): The path to the phenopacket directory used for generating the results.
15
- results_dir (str): The path to the results directory.
16
- gene_analysis (bool): Whether or not to benchmark gene analysis results.
17
- variant_analysis (bool): Whether or not to benchmark variant analysis results.
18
- disease_analysis (bool): Whether or not to benchmark disease analysis results.
17
+ results_dir (str): The path to the result directory.
18
+ gene_analysis (bool): Whether to benchmark gene analysis results.
19
+ variant_analysis (bool): Whether to benchmark variant analysis results.
20
+ disease_analysis (bool): Whether to benchmark disease analysis results.
19
21
  threshold (Optional[float]): The threshold to consider for benchmarking.
20
22
  score_order (Optional[str]): The order of scores to consider for benchmarking, either ascending or descending.
21
23
  """
@@ -29,25 +31,15 @@ class RunConfig(BaseModel):
29
31
  threshold: Optional[float]
30
32
  score_order: Optional[str]
31
33
 
32
- @root_validator(pre=True)
33
- def handle_blank_fields(cls, values: dict) -> dict: # noqa: N805
34
- """
35
- Root validator to handle fields that may be explicitly set to None.
36
-
37
- This method checks if 'threshold' and 'score_order' are None and assigns default values if so.
38
-
39
- Args:
40
- values (dict): The input values provided to the model.
34
@field_validator("threshold", mode="before")
@classmethod
def set_threshold(cls, threshold):
    """
    Normalise a blank ``threshold`` field to ``None``.

    A missing/blank YAML field arrives here as ``None`` and is passed
    through unchanged. The previous ``threshold or None`` form also
    coerced an explicit threshold of ``0``/``0.0`` to ``None``, silently
    discarding a legitimate value — only genuinely missing values are
    treated as unset now.
    """
    return None if threshold is None else threshold
41
38
 
42
- Returns:
43
- dict: The updated values with defaults applied where necessary.
44
- """
45
- if values.get("threshold") is None:
46
- values["threshold"] = 0
47
- print("setting default threshold")
48
- if values.get("score_order") is None:
49
- values["score_order"] = "descending"
50
- return values
39
@field_validator("score_order", mode="before")
@classmethod
def set_score_order(cls, score_order):
    """Default a blank/missing ``score_order`` field to ``"descending"``."""
    if score_order:
        return score_order
    return "descending"
51
43
 
52
44
 
53
45
  class SinglePlotCustomisation(BaseModel):
@@ -66,22 +58,10 @@ class SinglePlotCustomisation(BaseModel):
66
58
  roc_curve_title: Optional[str]
67
59
  precision_recall_title: Optional[str]
68
60
 
69
- @root_validator(pre=True)
70
- def handle_blank_fields(cls, values: dict) -> dict: # noqa: N805
71
- """
72
- Root validator to handle fields that may be explicitly set to None.
73
-
74
- This method checks if 'plot_type' is None and assigns default value if so.
75
-
76
- Args:
77
- values (dict): The input values provided to the model.
78
-
79
- Returns:
80
- dict: The updated values with defaults applied where necessary.
81
- """
82
- if values.get("plot_type") is None:
83
- values["plot_type"] = "bar_cumulative"
84
- return values
61
@field_validator("plot_type", mode="before")
@classmethod
def set_plot_type(cls, plot_type):
    """Default a blank/missing ``plot_type`` field to ``"bar_cumulative"``."""
    return plot_type if plot_type else "bar_cumulative"
85
65
 
86
66
 
87
67
  class PlotCustomisation(BaseModel):
@@ -118,6 +98,8 @@ def parse_run_config(run_config: Path) -> Config:
118
98
  Returns:
119
99
  Config: The parsed run configurations.
120
100
  """
101
+ logger = get_logger()
102
+ logger.info(f"Loading benchmark configuration from {run_config}")
121
103
  with open(run_config, "r") as f:
122
104
  config_data = yaml.safe_load(f)
123
105
  f.close()
pheval/cli.py CHANGED
@@ -1,14 +1,16 @@
1
- """PhEval CLI Module """
1
+ """PhEval CLI Module"""
2
2
 
3
3
  import logging
4
4
 
5
5
  import click
6
6
 
7
+ from pheval.utils.logger import get_logger, initialise_context
8
+
7
9
  from .cli_pheval import run
8
10
  from .cli_pheval_utils import (
11
+ benchmark,
9
12
  create_spiked_vcfs_command,
10
- generate_benchmark_stats,
11
- generate_stats_plot,
13
+ generate_plots,
12
14
  prepare_corpus_command,
13
15
  scramble_phenopackets_command,
14
16
  semsim_scramble_command,
@@ -16,50 +18,51 @@ from .cli_pheval_utils import (
16
18
  update_phenopackets_command,
17
19
  )
18
20
 
19
- info_log = logging.getLogger("info")
21
+ logger = get_logger()
20
22
 
21
23
 
22
24
  @click.group()
23
25
  @click.option("-v", "--verbose", count=True)
24
- @click.option("-q", "--quiet")
25
- def main(verbose=1, quiet=False) -> None:
26
- """main CLI method for PhEval
27
-
28
- Args:
29
- verbose (int, optional): Verbose flag.
30
- quiet (bool, optional): Queit Flag.
31
- """
26
+ @click.option("-q", "--quiet", is_flag=True)
27
+ @click.pass_context
28
+ def main(ctx, verbose=1, quiet=False):
29
+ """Main CLI method for PhEval."""
30
+ initialise_context(ctx)
31
+
32
32
  if verbose >= 2:
33
- info_log.setLevel(level=logging.DEBUG)
33
+ logger.setLevel(logging.DEBUG)
34
34
  elif verbose == 1:
35
- info_log.setLevel(level=logging.INFO)
35
+ logger.setLevel(logging.INFO)
36
36
  else:
37
- info_log.setLevel(level=logging.WARNING)
37
+ logger.setLevel(logging.WARNING)
38
38
  if quiet:
39
- info_log.setLevel(level=logging.ERROR)
39
+ logger.setLevel(logging.ERROR)
40
40
 
41
41
 
42
- @click.group()
43
- def pheval():
42
+ @main.group()
43
+ @click.pass_context
44
+ def pheval(ctx):
44
45
  """pheval"""
46
+ initialise_context(ctx)
45
47
 
46
48
 
47
- pheval.add_command(run)
48
-
49
-
50
- @click.group()
51
- def pheval_utils():
49
+ @main.group()
50
+ @click.pass_context
51
+ def pheval_utils(ctx):
52
52
  """pheval_utils"""
53
+ initialise_context(ctx)
53
54
 
54
55
 
56
+ pheval.add_command(run)
57
+
55
58
  pheval_utils.add_command(semsim_scramble_command)
56
59
  pheval_utils.add_command(scramble_phenopackets_command)
57
60
  pheval_utils.add_command(update_phenopackets_command)
58
61
  pheval_utils.add_command(create_spiked_vcfs_command)
59
- pheval_utils.add_command(generate_benchmark_stats)
62
+ pheval_utils.add_command(benchmark)
60
63
  pheval_utils.add_command(semsim_to_exomiserdb_command)
61
- pheval_utils.add_command(generate_stats_plot)
62
64
  pheval_utils.add_command(prepare_corpus_command)
65
+ pheval_utils.add_command(generate_plots)
63
66
 
64
67
  if __name__ == "__main__":
65
68
  main()
@@ -5,9 +5,8 @@ from typing import List
5
5
 
6
6
  import click
7
7
 
8
- from pheval.analyse.analysis import benchmark_run_comparisons
9
- from pheval.analyse.generate_plots import generate_plots_from_benchmark_summary_db
10
- from pheval.analyse.run_data_parser import parse_run_config
8
+ from pheval.analyse.benchmark import benchmark_runs
9
+ from pheval.analyse.generate_plots import generate_plots_from_db
11
10
  from pheval.prepare.create_noisy_phenopackets import scramble_phenopackets
12
11
  from pheval.prepare.create_spiked_vcf import spike_vcfs
13
12
  from pheval.prepare.custom_exceptions import InputError, MutuallyExclusiveOptionError
@@ -353,12 +352,12 @@ def create_spiked_vcfs_command(
353
352
  help="Path to yaml configuration file for benchmarking.",
354
353
  type=Path,
355
354
  )
356
- def generate_benchmark_stats(
355
+ def benchmark(
357
356
  run_yaml: Path,
358
357
  ):
359
358
  """Benchmark the gene/variant/disease prioritisation performance for runs."""
360
- benchmark_run_comparisons(
361
- parse_run_config(run_yaml),
359
+ benchmark_runs(
360
+ run_yaml,
362
361
  )
363
362
 
364
363
 
@@ -426,12 +425,12 @@ def semsim_to_exomiserdb_command(
426
425
  help="Path to yaml configuration file for benchmarking.",
427
426
  type=Path,
428
427
  )
429
- def generate_stats_plot(
428
+ def generate_plots(
430
429
  benchmark_db: Path,
431
430
  run_data: Path,
432
431
  ):
433
432
  """Generate bar plot from benchmark db."""
434
- generate_plots_from_benchmark_summary_db(benchmark_db, run_data)
433
+ generate_plots_from_db(benchmark_db, run_data)
435
434
 
436
435
 
437
436
  @click.command("prepare-corpus")
@@ -0,0 +1,235 @@
1
+ from pathlib import Path
2
+ from typing import List
3
+
4
+ import polars as pl
5
+
6
+ from pheval.utils.phenopacket_utils import (
7
+ GenomicVariant,
8
+ PhenopacketUtil,
9
+ ProbandCausativeGene,
10
+ ProbandDisease,
11
+ phenopacket_reader,
12
+ )
13
+
14
+
15
class PhenopacketTruthSet:
    """Extract the known causative gene/variant/disease entities from a corpus of
    phenopackets and merge them with ranked prioritisation results for benchmarking."""

    def __init__(self, phenopacket_dir: Path):
        """
        Initialise the truth set.
        Args:
            phenopacket_dir (Path): Directory containing the phenopacket JSON files.
        """
        self.phenopacket_dir = phenopacket_dir

    def _get_phenopacket_path(self, phenopacket_name: str) -> Path:
        """
        Get the phenopacket path for a given phenopacket name.
        Args:
            phenopacket_name (str): Name of the phenopacket.
        Returns:
            Path: Path to the phenopacket file.
        Raises:
            FileNotFoundError: If no ``<name>.json`` file exists in the corpus directory.
        """
        phenopacket_path = self.phenopacket_dir.joinpath(f"{phenopacket_name}.json")
        if not phenopacket_path.exists():
            raise FileNotFoundError(f"{phenopacket_name} not found in corpus!")
        return phenopacket_path

    def _get_phenopacket_util(self, phenopacket_name: str) -> PhenopacketUtil:
        """
        Get the phenopacket util for a given phenopacket name.
        Args:
            phenopacket_name (str): Name of the phenopacket.
        Returns:
            PhenopacketUtil: PhenopacketUtil object wrapping the parsed phenopacket.
        """
        phenopacket_path = self._get_phenopacket_path(phenopacket_name)
        return PhenopacketUtil(phenopacket_reader(phenopacket_path))

    def _get_causative_genes(self, phenopacket_name: str) -> List[ProbandCausativeGene]:
        """
        Get the causative genes for a given phenopacket.
        Args:
            phenopacket_name (str): Name of the phenopacket.
        Returns:
            List[ProbandCausativeGene]: Diagnosed genes for the proband.
        """
        return self._get_phenopacket_util(phenopacket_name).diagnosed_genes()

    def _get_causative_variants(self, phenopacket_name: str) -> List[GenomicVariant]:
        """
        Get the causative variants for a given phenopacket.
        Args:
            phenopacket_name (str): Name of the phenopacket.
        Returns:
            List[GenomicVariant]: Diagnosed variants for the proband.
        """
        return self._get_phenopacket_util(phenopacket_name).diagnosed_variants()

    def _get_causative_diseases(self, phenopacket_name: str) -> List[ProbandDisease]:
        """
        Get the diagnosed diseases for a given phenopacket.
        Args:
            phenopacket_name (str): Name of the phenopacket.
        Returns:
            List[ProbandDisease]: Diagnosed diseases for the proband.
        """
        return self._get_phenopacket_util(phenopacket_name).diagnoses()

    @staticmethod
    def _with_truth_columns(entries: pl.DataFrame) -> pl.DataFrame:
        """
        Append the default columns shared by all truth-set entries.

        Truth entries carry a score of 0.0 and rank 0 (i.e. unranked by any
        tool) and are always flagged as true positives.
        Args:
            entries (pl.DataFrame): Identifier columns for the truth entities.
        Returns:
            pl.DataFrame: Input with score/rank/true_positive columns appended.
        """
        return entries.with_columns(
            [
                pl.lit(0.0).cast(pl.Float64).alias("score"),
                pl.lit(0).cast(pl.Int64).alias("rank"),
                pl.lit(True).alias("true_positive"),
            ]
        )

    def classified_gene(self, result_name: str) -> pl.DataFrame:
        """
        Build the gene truth entries for a given phenopacket.
        Args:
            result_name (str): Name of the result file (matches the phenopacket name).
        Returns:
            pl.DataFrame: Causative genes with default score/rank/true_positive columns.
        """
        causative_genes = self._get_causative_genes(result_name)
        truth = pl.DataFrame(
            {
                "gene_symbol": [gene.gene_symbol for gene in causative_genes],
                "gene_identifier": [gene.gene_identifier for gene in causative_genes],
            }
        )
        return self._with_truth_columns(truth)

    @staticmethod
    def merge_gene_results(ranked_results: pl.DataFrame, output_file: Path) -> pl.DataFrame:
        """
        Merge ranked gene results with the classified (truth) genes.

        Each ranked row is flagged ``true_positive`` when its gene symbol OR
        gene identifier appears in the truth set; truth genes absent from the
        ranked results are appended as unranked rows.
        Args:
            ranked_results (pl.DataFrame): Ranked gene results.
            output_file (Path): Parquet file containing the classified truth genes.
        Returns:
            pl.DataFrame: Merged ranked gene results.
        """
        classified_results = pl.read_parquet(output_file)
        flagged = ranked_results.with_columns(
            (
                pl.col("gene_symbol").is_in(classified_results["gene_symbol"])
                | pl.col("gene_identifier").is_in(classified_results["gene_identifier"])
            ).alias("true_positive")
        ).with_columns(pl.col("rank").cast(pl.Int64))
        # NOTE(review): true positives match on symbol OR identifier, but the
        # unranked remainder below is filtered on gene_symbol only — a truth
        # gene matched solely by identifier would also be appended here.
        # Confirm this asymmetry is intended.
        missing = classified_results.filter(
            ~pl.col("gene_symbol").is_in(ranked_results["gene_symbol"])
        )
        return flagged.select(classified_results.columns).vstack(missing)

    def classified_variant(self, result_name: str) -> pl.DataFrame:
        """
        Build the variant truth entries for a given phenopacket.
        Args:
            result_name (str): Name of the result file (matches the phenopacket name).
        Returns:
            pl.DataFrame: Causative variants with a derived ``variant_id`` and
            default score/rank/true_positive columns.
        """
        variants = self._get_causative_variants(result_name)
        truth = pl.DataFrame(
            {
                "chrom": [variant.chrom for variant in variants],
                "pos": [variant.pos for variant in variants],
                "ref": [variant.ref for variant in variants],
                "alt": [variant.alt for variant in variants],
            }
        ).with_columns(
            # "chrom-pos-ref-alt" composite key, e.g. "1-12345-A-T".
            pl.concat_str(["chrom", "pos", "ref", "alt"], separator="-").alias("variant_id")
        )
        return self._with_truth_columns(truth)

    @staticmethod
    def merge_variant_results(ranked_results: pl.DataFrame, output_file: Path) -> pl.DataFrame:
        """
        Merge ranked variant results with the classified (truth) variants.

        Variants are matched on the full (chrom, pos, ref, alt) tuple; truth
        variants absent from the ranked results are appended as unranked rows.
        Args:
            ranked_results (pl.DataFrame): Ranked variant results.
            output_file (Path): Parquet file containing the classified truth variants.
        Returns:
            pl.DataFrame: Merged ranked variant results.
        """
        classified_results = pl.read_parquet(output_file)
        # Polars expressions are immutable, so the composite key is safely reused.
        variant_key = pl.struct(["chrom", "pos", "ref", "alt"])
        flagged = ranked_results.with_columns(
            variant_key.is_in(
                classified_results.select(variant_key).to_series()
            ).alias("true_positive")
        ).with_columns(pl.col("rank").cast(pl.Int64))
        missing = classified_results.filter(
            ~variant_key.is_in(ranked_results.select(variant_key).to_series())
        )
        return flagged.select(classified_results.columns).vstack(missing)

    def classified_disease(self, result_name: str) -> pl.DataFrame:
        """
        Build the disease truth entries for a given phenopacket.
        Args:
            result_name (str): Name of the result file (matches the phenopacket name).
        Returns:
            pl.DataFrame: Unique causative disease identifiers with default
            score/rank/true_positive columns.
        """
        diseases = self._get_causative_diseases(result_name)
        # dict.fromkeys de-duplicates while preserving first-seen order, making
        # the output deterministic (a plain set() gives arbitrary ordering).
        disease_identifiers = list(
            dict.fromkeys(disease.disease_identifier for disease in diseases)
        )
        return self._with_truth_columns(
            pl.DataFrame({"disease_identifier": disease_identifiers})
        )

    @staticmethod
    def merge_disease_results(ranked_results: pl.DataFrame, output_file: Path) -> pl.DataFrame:
        """
        Merge ranked disease results with the classified (truth) diseases.

        Each ranked row is flagged ``true_positive`` when its disease identifier
        appears in the truth set; truth diseases absent from the ranked results
        are appended as unranked rows.
        Args:
            ranked_results (pl.DataFrame): Ranked disease results.
            output_file (Path): Parquet file containing the classified truth diseases.
        Returns:
            pl.DataFrame: Merged ranked disease results.
        """
        classified_results = pl.read_parquet(output_file)
        flagged = ranked_results.with_columns(
            pl.col("disease_identifier")
            .is_in(classified_results["disease_identifier"])
            .alias("true_positive")
        ).with_columns(pl.col("rank").cast(pl.Int64))
        missing = classified_results.filter(
            ~pl.col("disease_identifier").is_in(ranked_results["disease_identifier"])
        )
        return flagged.select(classified_results.columns).vstack(missing)