pheval 0.4.6__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of pheval has been flagged as potentially problematic by the registry.
- pheval/analyse/benchmark.py +156 -0
- pheval/analyse/benchmark_db_manager.py +16 -134
- pheval/analyse/benchmark_output_type.py +43 -0
- pheval/analyse/binary_classification_curves.py +132 -0
- pheval/analyse/binary_classification_stats.py +164 -307
- pheval/analyse/generate_plots.py +210 -395
- pheval/analyse/generate_rank_comparisons.py +44 -0
- pheval/analyse/rank_stats.py +190 -382
- pheval/analyse/run_data_parser.py +21 -39
- pheval/cli.py +28 -25
- pheval/cli_pheval_utils.py +7 -8
- pheval/post_processing/phenopacket_truth_set.py +235 -0
- pheval/post_processing/post_processing.py +183 -303
- pheval/post_processing/validate_result_format.py +92 -0
- pheval/prepare/update_phenopacket.py +11 -9
- pheval/utils/logger.py +35 -0
- pheval/utils/phenopacket_utils.py +85 -91
- {pheval-0.4.6.dist-info → pheval-0.5.0.dist-info}/METADATA +4 -4
- {pheval-0.4.6.dist-info → pheval-0.5.0.dist-info}/RECORD +22 -26
- {pheval-0.4.6.dist-info → pheval-0.5.0.dist-info}/WHEEL +1 -1
- pheval/analyse/analysis.py +0 -104
- pheval/analyse/assess_prioritisation_base.py +0 -108
- pheval/analyse/benchmark_generator.py +0 -126
- pheval/analyse/benchmarking_data.py +0 -25
- pheval/analyse/disease_prioritisation_analysis.py +0 -152
- pheval/analyse/gene_prioritisation_analysis.py +0 -147
- pheval/analyse/generate_summary_outputs.py +0 -105
- pheval/analyse/parse_benchmark_summary.py +0 -81
- pheval/analyse/parse_corpus.py +0 -219
- pheval/analyse/prioritisation_result_types.py +0 -52
- pheval/analyse/variant_prioritisation_analysis.py +0 -159
- {pheval-0.4.6.dist-info → pheval-0.5.0.dist-info}/LICENSE +0 -0
- {pheval-0.4.6.dist-info → pheval-0.5.0.dist-info}/entry_points.txt +0 -0
pheval/analyse/binary_classification_stats.py

@@ -1,329 +1,186 @@

Removed: the previous mutable BinaryClassificationStats dataclass, which accumulated labels, scores and true/false positive/negative counts imperatively (remove_relevant_ranks, add_classification_for_known_entities, add_classification_for_other_entities, add_labels_and_scores, add_classification) and computed each metric in its own method (sensitivity, specificity, precision, false_positive_rate, false_discovery_rate, false_negative_rate, accuracy, f1_score, matthews_correlation_coefficient), together with its typing imports and the RankedPhEvalGeneResult / RankedPhEvalVariantResult / RankedPhEvalDiseaseResult imports.

Added: Polars-expression-based replacements.

from dataclasses import dataclass
from multiprocessing.util import get_logger

import polars as pl


@dataclass(frozen=True)
class ConfusionMatrix:
    """
    Define logical conditions for computing a confusion matrix using Polars expressions.

    Attributes:
        TRUE_POSITIVES (pl.Expr): Condition identifying true positive cases,
            where `rank == 1` and `true_positive` is `True`.
        FALSE_POSITIVES (pl.Expr): Condition identifying false positive cases,
            where `rank == 1` and `true_positive` is `False`.
        TRUE_NEGATIVES (pl.Expr): Condition identifying true negative cases,
            where `rank != 1` and `true_positive` is `False`.
        FALSE_NEGATIVES (pl.Expr): Condition identifying false negative cases,
            where `rank != 1` and `true_positive` is `True`.
    """

    TRUE_POSITIVES = (pl.col("rank") == 1) & (pl.col("true_positive"))
    FALSE_POSITIVES = (pl.col("rank") == 1) & (~pl.col("true_positive"))
    TRUE_NEGATIVES = (pl.col("rank") != 1) & (~pl.col("true_positive"))
    FALSE_NEGATIVES = (pl.col("rank") != 1) & (pl.col("true_positive"))


@dataclass(frozen=True)
class BinaryClassificationStats:
    """Binary classification statistic expressions."""

    SENSITIVITY = (
        pl.when((pl.col("true_positives") + pl.col("false_negatives")) != 0)
        .then(pl.col("true_positives") / (pl.col("true_positives") + pl.col("false_negatives")))
        .otherwise(0.0)
        .alias("sensitivity")
    )

    SPECIFICITY = (
        pl.when((pl.col("true_negatives") + pl.col("false_positives")) != 0)
        .then(pl.col("true_negatives") / (pl.col("true_negatives") + pl.col("false_positives")))
        .otherwise(0.0)
        .alias("specificity")
    )

    PRECISION = (
        pl.when((pl.col("true_positives") + pl.col("false_positives")) != 0)
        .then(pl.col("true_positives") / (pl.col("true_positives") + pl.col("false_positives")))
        .otherwise(0.0)
        .alias("precision")
    )

    NEGATIVE_PREDICTIVE_VALUE = (
        pl.when((pl.col("true_negatives") + pl.col("false_negatives")) != 0)
        .then(pl.col("true_negatives") / (pl.col("true_negatives") + pl.col("false_negatives")))
        .otherwise(0.0)
        .alias("negative_predictive_value")
    )

    FALSE_POSITIVE_RATE = (
        pl.when((pl.col("false_positives") + pl.col("true_negatives")) != 0)
        .then(pl.col("false_positives") / (pl.col("false_positives") + pl.col("true_negatives")))
        .otherwise(0.0)
        .alias("false_positive_rate")
    )

    FALSE_DISCOVERY_RATE = (
        pl.when((pl.col("false_positives") + pl.col("true_positives")) != 0)
        .then(pl.col("false_positives") / (pl.col("false_positives") + pl.col("true_positives")))
        .otherwise(0.0)
        .alias("false_discovery_rate")
    )

    FALSE_NEGATIVE_RATE = (
        pl.when((pl.col("false_negatives") + pl.col("true_positives")) != 0)
        .then(pl.col("false_negatives") / (pl.col("false_negatives") + pl.col("true_positives")))
        .otherwise(0.0)
        .alias("false_negative_rate")
    )

    ACCURACY = (
        pl.when(
            (
                pl.col("true_positives")
                + pl.col("false_positives")
                + pl.col("true_negatives")
                + pl.col("false_negatives")
            )
            != 0
        )
        .then(
            (pl.col("true_positives") + pl.col("true_negatives"))
            / (
                pl.col("true_positives")
                + pl.col("false_positives")
                + pl.col("true_negatives")
                + pl.col("false_negatives")
            )
        )
        .otherwise(0.0)
        .alias("accuracy")
    )

    F1_SCORE = (
        pl.when(
            2 * (pl.col("true_positives") + pl.col("false_positives") + pl.col("false_negatives"))
            != 0
        )
        .then(
            2
            * pl.col("true_positives")
            / (2 * pl.col("true_positives") + pl.col("false_positives") + pl.col("false_negatives"))
        )
        .otherwise(0.0)
        .alias("f1_score")
    )

    MATTHEWS_CORRELATION_COEFFICIENT = (
        pl.when(
            (
                (pl.col("true_positives") + pl.col("false_positives"))
                * (pl.col("true_positives") + pl.col("false_negatives"))
                * (pl.col("true_negatives") + pl.col("false_positives"))
                * (pl.col("true_negatives") + pl.col("false_negatives"))
            )
            > 0
        )
        .then(
            (
                (pl.col("true_positives") * pl.col("true_negatives"))
                - (pl.col("false_positives") * pl.col("false_negatives"))
            )
            / (
                (pl.col("true_positives") + pl.col("false_positives"))
                * (pl.col("true_positives") + pl.col("false_negatives"))
                * (pl.col("true_negatives") + pl.col("false_positives"))
                * (pl.col("true_negatives") + pl.col("false_negatives"))
            ).sqrt()
        )
        .otherwise(0.0)
        .alias("matthews_correlation_coefficient")
    )


def compute_confusion_matrix(run_identifier: str, result_scan: pl.LazyFrame) -> pl.LazyFrame:
    """
    Computes binary classification statistics.

    Args:
        run_identifier (str): The identifier for the run.
        result_scan (pl.LazyFrame): The LazyFrame containing the results for the directory.

    Returns:
        pl.LazyFrame: The LazyFrame containing the binary classification statistics.
    """
    logger = get_logger()
    logger.info(f"Computing binary classification statistics for {run_identifier}")
    confusion_matrix = result_scan.select(
        [
            pl.lit(run_identifier).alias("run_identifier"),
            ConfusionMatrix.TRUE_POSITIVES.sum().alias("true_positives").cast(pl.Int64),
            ConfusionMatrix.FALSE_POSITIVES.sum().alias("false_positives").cast(pl.Int64),
            ConfusionMatrix.TRUE_NEGATIVES.sum().alias("true_negatives").cast(pl.Int64),
            ConfusionMatrix.FALSE_NEGATIVES.sum().alias("false_negatives").cast(pl.Int64),
        ]
    )
    return confusion_matrix.select(
        [
            pl.col("run_identifier"),
            pl.col("true_positives"),
            pl.col("false_positives"),
            pl.col("true_negatives"),
            pl.col("false_negatives"),
            BinaryClassificationStats.SENSITIVITY,
            BinaryClassificationStats.SPECIFICITY,
            BinaryClassificationStats.PRECISION,
            BinaryClassificationStats.NEGATIVE_PREDICTIVE_VALUE,
            BinaryClassificationStats.FALSE_POSITIVE_RATE,
            BinaryClassificationStats.FALSE_DISCOVERY_RATE,
            BinaryClassificationStats.FALSE_NEGATIVE_RATE,
            BinaryClassificationStats.ACCURACY,
            BinaryClassificationStats.F1_SCORE,
            BinaryClassificationStats.MATTHEWS_CORRELATION_COEFFICIENT,
        ]
    )