pg-sui 1.0.2.1__py3-none-any.whl → 1.6.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pg_sui-1.0.2.1.dist-info → pg_sui-1.6.8.dist-info}/METADATA +51 -70
- pg_sui-1.6.8.dist-info/RECORD +78 -0
- {pg_sui-1.0.2.1.dist-info → pg_sui-1.6.8.dist-info}/WHEEL +1 -1
- pg_sui-1.6.8.dist-info/entry_points.txt +4 -0
- pg_sui-1.6.8.dist-info/top_level.txt +1 -0
- pgsui/__init__.py +35 -54
- pgsui/_version.py +34 -0
- pgsui/cli.py +635 -0
- pgsui/data_processing/config.py +576 -0
- pgsui/data_processing/containers.py +1782 -0
- pgsui/data_processing/transformers.py +121 -1103
- pgsui/electron/app/__main__.py +5 -0
- pgsui/electron/app/icons/icons/1024x1024.png +0 -0
- pgsui/electron/app/icons/icons/128x128.png +0 -0
- pgsui/electron/app/icons/icons/16x16.png +0 -0
- pgsui/electron/app/icons/icons/24x24.png +0 -0
- pgsui/electron/app/icons/icons/256x256.png +0 -0
- pgsui/electron/app/icons/icons/32x32.png +0 -0
- pgsui/electron/app/icons/icons/48x48.png +0 -0
- pgsui/electron/app/icons/icons/512x512.png +0 -0
- pgsui/electron/app/icons/icons/64x64.png +0 -0
- pgsui/electron/app/icons/icons/icon.icns +0 -0
- pgsui/electron/app/icons/icons/icon.ico +0 -0
- pgsui/electron/app/main.js +189 -0
- pgsui/electron/app/package-lock.json +6893 -0
- pgsui/electron/app/package.json +50 -0
- pgsui/electron/app/preload.js +15 -0
- pgsui/electron/app/server.py +146 -0
- pgsui/electron/app/ui/logo.png +0 -0
- pgsui/electron/app/ui/renderer.js +130 -0
- pgsui/electron/app/ui/styles.css +59 -0
- pgsui/electron/app/ui/ui_shim.js +72 -0
- pgsui/electron/bootstrap.py +43 -0
- pgsui/electron/launch.py +59 -0
- pgsui/electron/package.json +14 -0
- pgsui/example_data/popmaps/{test.popmap → phylogen_nomx.popmap} +185 -99
- pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz +0 -0
- pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz.tbi +0 -0
- pgsui/impute/deterministic/imputers/allele_freq.py +691 -0
- pgsui/impute/deterministic/imputers/mode.py +679 -0
- pgsui/impute/deterministic/imputers/nmf.py +221 -0
- pgsui/impute/deterministic/imputers/phylo.py +971 -0
- pgsui/impute/deterministic/imputers/ref_allele.py +530 -0
- pgsui/impute/supervised/base.py +339 -0
- pgsui/impute/supervised/imputers/hist_gradient_boosting.py +293 -0
- pgsui/impute/supervised/imputers/random_forest.py +287 -0
- pgsui/impute/unsupervised/base.py +924 -0
- pgsui/impute/unsupervised/callbacks.py +89 -263
- pgsui/impute/unsupervised/imputers/autoencoder.py +972 -0
- pgsui/impute/unsupervised/imputers/nlpca.py +1264 -0
- pgsui/impute/unsupervised/imputers/ubp.py +1288 -0
- pgsui/impute/unsupervised/imputers/vae.py +957 -0
- pgsui/impute/unsupervised/loss_functions.py +158 -0
- pgsui/impute/unsupervised/models/autoencoder_model.py +208 -558
- pgsui/impute/unsupervised/models/nlpca_model.py +149 -468
- pgsui/impute/unsupervised/models/ubp_model.py +198 -1317
- pgsui/impute/unsupervised/models/vae_model.py +259 -618
- pgsui/impute/unsupervised/nn_scorers.py +215 -0
- pgsui/utils/classification_viz.py +591 -0
- pgsui/utils/misc.py +35 -480
- pgsui/utils/plotting.py +514 -824
- pgsui/utils/scorers.py +212 -438
- pg_sui-1.0.2.1.dist-info/RECORD +0 -75
- pg_sui-1.0.2.1.dist-info/top_level.txt +0 -3
- pgsui/example_data/phylip_files/test_n10.phy +0 -118
- pgsui/example_data/phylip_files/test_n100.phy +0 -118
- pgsui/example_data/phylip_files/test_n2.phy +0 -118
- pgsui/example_data/phylip_files/test_n500.phy +0 -118
- pgsui/example_data/structure_files/test.nopops.1row.10sites.str +0 -117
- pgsui/example_data/structure_files/test.nopops.2row.100sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.10sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.30sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.allsites.str +0 -234
- pgsui/example_data/structure_files/test.pops.1row.10sites.str +0 -117
- pgsui/example_data/structure_files/test.pops.2row.10sites.str +0 -234
- pgsui/example_data/trees/test.iqtree +0 -376
- pgsui/example_data/trees/test.qmat +0 -5
- pgsui/example_data/trees/test.rate +0 -2033
- pgsui/example_data/trees/test.tre +0 -1
- pgsui/example_data/trees/test_n10.rate +0 -19
- pgsui/example_data/trees/test_n100.rate +0 -109
- pgsui/example_data/trees/test_n500.rate +0 -509
- pgsui/example_data/trees/test_siterates.txt +0 -2024
- pgsui/example_data/trees/test_siterates_n10.txt +0 -10
- pgsui/example_data/trees/test_siterates_n100.txt +0 -100
- pgsui/example_data/trees/test_siterates_n500.txt +0 -500
- pgsui/example_data/vcf_files/test.vcf +0 -244
- pgsui/example_data/vcf_files/test.vcf.gz +0 -0
- pgsui/example_data/vcf_files/test.vcf.gz.tbi +0 -0
- pgsui/impute/estimators.py +0 -735
- pgsui/impute/impute.py +0 -1486
- pgsui/impute/simple_imputers.py +0 -1439
- pgsui/impute/supervised/iterative_imputer_fixedparams.py +0 -785
- pgsui/impute/supervised/iterative_imputer_gridsearch.py +0 -1027
- pgsui/impute/unsupervised/keras_classifiers.py +0 -702
- pgsui/impute/unsupervised/models/in_development/cnn_model.py +0 -486
- pgsui/impute/unsupervised/neural_network_imputers.py +0 -1424
- pgsui/impute/unsupervised/neural_network_methods.py +0 -1549
- pgsui/pg_sui.py +0 -261
- pgsui/utils/sequence_tools.py +0 -407
- simulation/sim_benchmarks.py +0 -333
- simulation/sim_treeparams.py +0 -475
- test/__init__.py +0 -0
- test/pg_sui_simtest.py +0 -215
- test/pg_sui_testing.py +0 -523
- test/test.py +0 -297
- test/test_pgsui.py +0 -374
- test/test_tkc.py +0 -214
- {pg_sui-1.0.2.1.dist-info → pg_sui-1.6.8.dist-info/licenses}/LICENSE +0 -0
- /pgsui/{example_data/trees → electron/app}/__init__.py +0 -0
- /pgsui/impute/{unsupervised/models/in_development → supervised/imputers}/__init__.py +0 -0
- {simulation → pgsui/impute/unsupervised/imputers}/__init__.py +0 -0
pgsui/utils/scorers.py
CHANGED
@@ -1,508 +1,282 @@
+from typing import Dict, Literal

 import numpy as np
 from sklearn.metrics import (
-    roc_curve,
-    auc,
     accuracy_score,
-    hamming_loss,
-    make_scorer,
-    precision_recall_curve,
     average_precision_score,
-    multilabel_confusion_matrix,
     f1_score,
+    precision_score,
+    recall_score,
     roc_auc_score,
 )
 from sklearn.preprocessing import label_binarize
-    from ..impute.unsupervised.neural_network_methods import (
-        NeuralNetworkMethods,
-    )
-except (ModuleNotFoundError, ValueError, ImportError):
-    from impute.unsupervised.neural_network_methods import NeuralNetworkMethods
+from snpio.utils.logging import LoggerManager
+from torch import Tensor


-    def compute_roc_auc_micro_macro(
-        y_true, y_pred, num_classes=3, binarize_pred=True
-    ):
-        """Compute ROC curve with AUC scores.
+class Scorer:
+    """Class for evaluating the performance of a model using various metrics.

+    This class is used to evaluate the performance of a model using various metrics, such as accuracy, F1 score, precision, recall, average precision, and ROC AUC. The class can be used to evaluate the performance of a model on a dataset with ground truth labels. The class can also be used to evaluate the performance of a model in objective mode for hyperparameter tuning.
+    """

+    def __init__(
+        self,
+        prefix: str,
+        average: Literal["micro", "macro", "weighted"] = "macro",
+        verbose: bool = False,
+        debug: bool = False,
+    ) -> None:
+        """Initialize a Scorer object.

+        This class is used to evaluate the performance of a model using various metrics, such as accuracy, F1 score, precision, recall, average precision, and ROC AUC. The class can be used to evaluate the performance of a model on a dataset with ground truth labels. The class can also be used to evaluate the performance of a model in objective mode for hyperparameter tuning.

+        Args:
+            prefix (str): Prefix for logging messages.
+            average (Literal["micro", "macro", "weighted"]): Average method for metrics. Must be one of 'micro', 'macro', or 'weighted'.
+            verbose (bool): Verbosity level for logging messages. Default is False.
+            debug (bool): Debug mode for logging messages. Default is False.

+        Raises:
+            ValueError: If the average parameter is invalid. Must be one of 'micro', 'macro', or 'weighted'.
         """
-        # Get only classes that appear in y_true.
-        classes = [i for i in cats if i in y_true]
-
-        # Binarize the output for use with ROC-AUC.
-        y_true_bin = label_binarize(y_true, classes=cats)
-
-        if binarize_pred:
-            y_pred_bin = label_binarize(y_pred, classes=cats)
-        else:
-            y_pred_bin = y_pred
-
-        for i in range(y_true_bin.shape[1]):
-            if i not in classes:
-                y_true_bin = np.delete(y_true_bin, i, axis=-1)
-                y_pred_bin = np.delete(y_pred_bin, i, axis=-1)
-
-        n_classes = len(classes)
-
-        # Compute ROC curve and ROC area for each class.
-        fpr = dict()
-        tpr = dict()
-        roc_auc = dict()
-        for i, c in enumerate(classes):
-            fpr[c], tpr[c], _ = roc_curve(y_true_bin[:, i], y_pred_bin[:, i])
-            roc_auc[c] = auc(fpr[c], tpr[c])
-
-        # Compute micro-average ROC curve and ROC area.
-        fpr["micro"], tpr["micro"], _ = roc_curve(
-            y_true_bin.ravel(), y_pred_bin.ravel()
+        logman = LoggerManager(
+            name=__name__, prefix=prefix, debug=debug, verbose=verbose >= 1
         )
+        self.logger = logman.get_logger()

+        if average not in {"micro", "macro", "weighted"}:
+            msg = f"Invalid average parameter: {average}. Must be one of 'micro', 'macro', or 'weighted'."
+            self.logger.error(msg)
+            raise ValueError(msg)

-        mean_tpr = np.zeros_like(all_fpr)
-        for i in classes:
-            mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])
+        self.average = average

+    def accuracy(self, y_true: np.ndarray, y_pred: np.ndarray) -> float:
+        """Calculate the accuracy of the model.

-        tpr["macro"] = mean_tpr
+        This method calculates the accuracy of the model by comparing the ground truth labels with the predicted labels.

-        roc_auc["tpr_macro"] = tpr["macro"]
-        roc_auc["fpr_micro"] = fpr["micro"]
-        roc_auc["tpr_micro"] = tpr["micro"]
-
-        for i in classes:
-            roc_auc[f"fpr_{i}"] = fpr[i]
-            roc_auc[f"tpr_{i}"] = tpr[i]
+        Args:
+            y_true (np.ndarray): Ground truth labels.
+            y_pred (np.ndarray): Predicted labels.

+        Returns:
+            float: Accuracy score.
+        """
+        return accuracy_score(y_true, y_pred)

-        """Compute precision-recall curve with Average Precision scores.
+    def f1(self, y_true: np.ndarray, y_pred: np.ndarray) -> float:
+        """Calculate the F1 score of the model.

+        This method calculates the F1 score of the model by comparing the ground truth labels with the predicted labels.

         Args:
-            y_pred (numpy.ndarray): Ravelled numpy array of shape (n_samples * n_features,). y_pred should be integer-encoded.
-            num_classes (int, optional): How many classes to use. Defaults to 3.
-
-        Returns:
-            Dict[str, Any]: Dictionary with precision and recall curves per class and micro and macro averaged across classes, plus AP scores per-class and for micro and macro averages.
+            y_true (np.ndarray): Ground truth labels.
+            y_pred (np.ndarray): Predicted labels.

+        Returns:
+            float: F1 score.
         """
+        return f1_score(y_true, y_pred, average=self.average, zero_division=0.0)

+    def precision(self, y_true: np.ndarray, y_pred: np.ndarray) -> float:
+        """Calculate the precision of the model.

-        classes = [i for i in cats if i in y_true]
+        This method calculates the precision of the model by comparing the ground truth labels with the predicted labels.

-        if is_multiclass:
-            for i in range(y_true_bin.shape[1]):
-                if i not in classes:
-                    y_true_bin = np.delete(y_true_bin, i, axis=-1)
-                    y_pred_proba_bin = np.delete(y_pred_proba_bin, i, axis=-1)
-
-        nn = NeuralNetworkMethods()
-        if len(y_true.shape) == 1 or y_true.shape[1] != num_classes:
-            y_true = label_binarize(y_true, classes=cats)
+        Args:
+            y_true (np.ndarray): Ground truth labels.
+            y_pred (np.ndarray): Predicted labels.

-        else:
-            y_pred_012 = nn.decode_masked(
-                y_true,
-                y_pred_proba_bin,
-                is_multiclass=False,
-                return_int=False,
-                return_multilab=True,
-            )
+        Returns:
+            float: Precision score.
+        """
+        return precision_score(y_true, y_pred, average=self.average, zero_division=0.0)

+    def recall(self, y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray:
+        """Calculate the recall of the model.

-        tn /= num_classes
+        This method calculates the recall of the model by comparing the ground truth labels with the predicted labels.

+        Args:
+            y_true (np.ndarray): Ground truth labels.
+            y_pred (np.ndarray): Predicted labels.

+        Returns:
+            float: Recall score.
+        """
+        return recall_score(y_true, y_pred, average=self.average, zero_division=0.0)

-        average_precision = dict()
+    def roc_auc(self, y_true: np.ndarray, y_pred_proba: np.ndarray) -> float:
+        """Multiclass ROC-AUC with label targets.

-            precision[c], recall[c], _ = precision_recall_curve(
-                y_true_bin[:, i], y_pred_proba_bin[:, i]
-            )
-            average_precision[c] = average_precision_score(
-                y_true_bin[:, i], y_pred_proba_bin[:, i]
-            )
+        This method calculates the ROC-AUC score for multiclass classification problems. It handles both 1D integer labels and 2D one-hot/indicator matrices for the ground truth labels.

+        Args:
+            y_true: 1D integer labels (shape: [n]).
+                If a one-hot/indicator matrix is supplied, we convert to labels.
+            y_pred_proba: 2D probabilities (shape: [n, n_classes]).
+        """
+        y_true = np.asarray(y_true)
+        y_pred_proba = np.asarray(y_pred_proba)

+        if y_pred_proba.ndim == 3:
+            y_pred_proba = y_pred_proba.reshape(-1, y_pred_proba.shape[-1])

+        # If user passed indicator/one-hot, convert to labels.
+        if y_true.ndim == 2 and y_true.shape[1] == y_pred_proba.shape[1]:
+            y_true = y_true.argmax(axis=1)

+        # Guard: need >1 class present for AUC
+        if np.unique(y_true).size < 2:
+            return 0.5

+        return float(
+            roc_auc_score(
+                y_true,
+                y_pred_proba,
+                multi_class="ovr",
+                average=self.average,
             )
+        )

-        results["macro"] = average_precision["macro"]
-        results["f1_score"] = f1
-        results["f1_score_weighted"] = f1_weighted
-        results["recall_macro"] = all_recall
-        results["precision_macro"] = mean_precision
-        results["recall_micro"] = recall["micro"]
-        results["precision_micro"] = precision["micro"]
-
-        for i in classes:
-            results[f"recall_{i}"] = recall[i]
-            results[f"precision_{i}"] = precision[i]
-            results[i] = average_precision[i]
-        results["baseline"] = baseline
-
-        return results
-
-    @staticmethod
-    def check_if_tuple(y_pred):
-        """Checks if y_pred is a tuple and if so, returns the first element of the tuple."""
-        if isinstance(y_pred, tuple):
-            y_pred = y_pred[0]
-        return y_pred
-
-    @staticmethod
-    def accuracy_scorer(y_true, y_pred, **kwargs):
-        """Get accuracy score for grid search.
-
-        If provided, only calculates score where missing_mask is True (i.e., data were missing). This is so that users can simulate missing data for known values, and then the predictions for only those known values can be evaluated.
+    def evaluate(
+        self,
+        y_true: np.ndarray | Tensor | list,
+        y_pred: np.ndarray | Tensor | list,
+        y_true_ohe: np.ndarray | Tensor | list,
+        y_pred_proba: np.ndarray | Tensor | list,
+        objective_mode: bool = False,
+        tune_metric: Literal[
+            "pr_macro",
+            "roc_auc",
+            "average_precision",
+            "accuracy",
+            "f1",
+            "precision",
+            "recall",
+        ] = "pr_macro",
+    ) -> Dict[str, float]:
+        """Evaluate the model using various metrics.
+
+        This method evaluates the performance of a model using various metrics, such as accuracy, F1 score, precision, recall, average precision, and ROC AUC. The method can be used to evaluate the performance of a model on a dataset with ground truth labels. The method can also be used to evaluate the performance of a model in objective mode for hyperparameter tuning.

         Args:
+            y_true (np.ndarray | torch.Tensor): Ground truth labels.
+            y_pred (np.ndarray | torch.Tensor): Predicted labels.
+            y_true_ohe (np.ndarray | torch.Tensor): One-hot encoded ground truth labels.
+            y_pred_proba (np.ndarray | torch.Tensor): Predicted probabilities.
+            objective_mode (bool): Whether to use objective mode for evaluation. Default is False.
+            tune_metric (Literal["pr_macro", "roc_auc", "average_precision", "accuracy", "f1", "precision", "recall"]): Metric to use for tuning. Ignored if `objective_mode` is False. Default is 'pr_macro'.

         Returns:
-        """
-        missing_mask = kwargs.get("missing_mask")
-
-        y_pred = Scorers.check_if_tuple(y_pred)
-
-        y_true_masked = y_true[missing_mask]
-        y_pred_masked = y_pred[missing_mask]
+            Dict[str, float]: Dictionary of evaluation metrics. Keys are 'accuracy', 'f1', 'precision', 'recall', 'roc_auc', 'average_precision', and 'pr_macro'.

+        Raises:
+            ValueError: If the input data is invalid.
+            ValueError: If an invalid tune_metric is provided.
+        """
+        if not y_true.ndim < 3:
+            msg = "y_true must have 1 or 2 dimensions."
+            self.logger.error(msg)
+            raise ValueError(msg)
+
+        if not y_pred.ndim < 3:
+            msg = "y_pred must have 1 or 2 dimensions."
+            self.logger.error(msg)
+            raise ValueError(msg)
+
+        if not y_true_ohe.ndim == 2:
+            msg = "y_true_ohe must have 2 dimensions."
+            self.logger.error(msg)
+            raise ValueError(msg)
+
+        if y_pred_proba.ndim != 2:
+            y_pred_proba = y_pred_proba.reshape(-1, y_true_ohe.shape[-1])
+            self.logger.debug(f"Reshaped y_pred_proba to {y_pred_proba.shape}")
+
+        if objective_mode:
+            if tune_metric == "pr_macro":
+                metrics = {"pr_macro": self.pr_macro(y_true_ohe, y_pred_proba)}
+            elif tune_metric == "roc_auc":
+                metrics = {"roc_auc": self.roc_auc(y_true, y_pred_proba)}
+            elif tune_metric == "average_precision":
+                metrics = {
+                    "average_precision": self.average_precision(y_true, y_pred_proba)
+                }
+            elif tune_metric == "accuracy":
+                metrics = {"accuracy": self.accuracy(y_true, y_pred)}
+            elif tune_metric == "f1":
+                metrics = {"f1": self.f1(y_true, y_pred)}
+            elif tune_metric == "precision":
+                metrics = {"precision": self.precision(y_true, y_pred)}
+            elif tune_metric == "recall":
+                metrics = {"recall": self.recall(y_true, y_pred)}
+            else:
+                msg = f"Invalid tune_metric provided: '{tune_metric}'."
+                self.logger.error(msg)
+                raise ValueError(msg)
+        else:
+            metrics = {
+                "accuracy": self.accuracy(y_true, y_pred),
+                "f1": self.f1(y_true, y_pred),
+                "precision": self.precision(y_true, y_pred),
+                "recall": self.recall(y_true, y_pred),
+                "roc_auc": self.roc_auc(y_true, y_pred_proba),
+                "average_precision": self.average_precision(y_true, y_pred_proba),
+                "pr_macro": self.pr_macro(y_true_ohe, y_pred_proba),
+            }

+        return {k: float(v) for k, v in metrics.items()}

-        """Get Hamming score for grid search.
+    def average_precision(self, y_true: np.ndarray, y_pred_proba: np.ndarray) -> float:
+        """Average precision with safe multiclass handling.

+        If y_true is 1D of class indices, it is binarized against the number of columns in y_pred_proba. If y_true is already one-hot or indicator, it is used as-is.

         Args:
-            y_pred (tensorflow.EagerTensor): Predictions from model as probabilities. They must first be decoded to use with hamming_scorer.
-            kwargs (Any): Keyword arguments to use with scorer. Supported options include ``missing_mask`` and ``testing``\.
+            y_true (np.ndarray): Ground truth labels (1D class indices or 2D one-hot/indicator).
+            y_pred_proba (np.ndarray): Predicted probabilities (2D array).

         Returns:
+            float: Average precision score.
         """
-        y_pred = Scorers.check_if_tuple(y_pred)
-
-        y_true_masked = y_true[missing_mask]
-        y_pred_masked = y_pred[missing_mask]
-
-        nn = NeuralNetworkMethods()
-        y_pred_masked_decoded = nn.decode_masked(
-            y_true_masked,
-            y_pred_masked,
-            predict_still_missing=False,
-        )
-
-        return hamming_loss(y_true_masked, y_pred_masked_decoded)
-
-    @staticmethod
-    def compute_metric(
-        y_true, y_pred, metric_type, scoring_function, **kwargs
-    ):
-        y_true = np.array(y_true)
-        y_pred = np.array(y_pred)
+        y_true_arr = np.asarray(y_true)
+        y_proba_arr = np.asarray(y_pred_proba)

+        if y_proba_arr.ndim == 3:
+            y_proba_arr = y_proba_arr.reshape(-1, y_proba_arr.shape[-1])

-            return roc_auc_score(
-                y_true_bin, y_pred, multi_class="ovr", average=metric_type
-            )
-        elif scoring_function == "f1":
-            is_multiclass = num_classes != 4
-            y_pred_proba_bin = y_pred
-            nn = NeuralNetworkMethods()
-            y_pred_bin = nn.decode_masked(
-                y_true_bin,
-                y_pred_proba_bin,
-                is_multiclass=is_multiclass,
-                return_int=False,
-                return_multilab=True,
-            )
-            return f1_score(
-                y_true_bin, y_pred_bin, average=metric_type, zero_division=0
-            )
-        elif scoring_function == "average_precision":
-            return average_precision_score(
-                y_true_bin, y_pred, average=metric_type
-            )
+        # If y_true already matches proba columns (one-hot / indicator)
+        if y_true_arr.ndim == 2 and y_true_arr.shape[1] == y_proba_arr.shape[1]:
+            y_bin = y_true_arr
         else:
-
-    @staticmethod
-    def compute_score(y_true, y_pred, metric_type, scoring_function, **kwargs):
-        missing_mask = kwargs.get("missing_mask")
-        y_pred = Scorers.check_if_tuple(y_pred)
-
-        y_true_masked = y_true[missing_mask]
-        y_pred_masked = y_pred[missing_mask]
-
-        return Scorers.compute_metric(
-            y_true_masked,
-            y_pred_masked,
-            metric_type=metric_type,
-            scoring_function=scoring_function,
-            **kwargs,
-        )
-
-    @classmethod
-    def make_multimetric_scorer(
-        cls, metrics, missing_mask, num_classes=4, testing=False
-    ):
-        if isinstance(metrics, str):
-            metrics = [metrics]
-
-        metric_map = {
-            "roc_auc_macro": ("macro", "roc_auc"),
-            "roc_auc_micro": ("micro", "roc_auc"),
-            "roc_auc_weighted": ("weighted", "roc_auc"),
-            "f1_micro": ("micro", "f1"),
-            "f1_macro": ("macro", "f1"),
-            "f1_weighted": ("weighted", "f1"),
-            "average_precision_macro": ("macro", "average_precision"),
-            "average_precision_micro": ("micro", "average_precision"),
-            "average_precision_weighted": ("weighted", "average_precision"),
-        }
-
-        default_params = {
-            "missing_mask": missing_mask,
-            "num_classes": num_classes,
-            "testing": testing,
-        }
-
-        scorers = dict()
-        for item in metrics:
-            item = item.lower()
-
-            if item in metric_map:
-                metric_type, scoring_function = metric_map[item]
-                params = default_params.copy()
-                scorers[item] = make_scorer(
-                    Scorers.compute_score,
-                    metric_type=metric_type,
-                    scoring_function=scoring_function,
-                    **params,
-                )
-            elif item == "accuracy":
-                scorers[item] = make_scorer(
-                    cls.accuracy_scorer, **default_params
-                )
-            elif item == "hamming":
-                scorers[item] = make_scorer(
-                    cls.hamming_scorer, **default_params
-                )
-            else:
-                raise ValueError(f"Unsupported metric: {item}")
+            # Interpret y_true as class indices
+            n_classes = y_proba_arr.shape[1]
+            y_bin = label_binarize(y_true_arr.ravel(), classes=np.arange(n_classes))

+        return float(average_precision_score(y_bin, y_proba_arr, average=self.average))

-        # Get missing mask if provided.
-        # Otherwise default is all missing values (array all True).
-        missing_mask = kwargs.get("missing_mask")
-        nn_model = kwargs.get("nn_model", True)
-        num_classes = kwargs.get("num_classes", 3)
-        testing = kwargs.get("testing", False)
+    def pr_macro(self, y_true_ohe: np.ndarray, y_pred_proba: np.ndarray) -> float:
+        """Macro-averaged average precision (precision-recall AUC) across classes.

-        nn = NeuralNetworkMethods()
-
-        # VAE has tuple output.
-        if isinstance(y_pred, tuple):
-            y_pred = y_pred[0]
-
-        y_true_masked = y_true[missing_mask]
-        y_pred_masked = y_pred[missing_mask]
-
-        roc_auc = Scorers.compute_roc_auc_micro_macro(
-            y_true_masked,
-            y_pred_masked,
-            num_classes=num_classes,
-            binarize_pred=False,
-        )
+        Args:
+            y_true_ohe (np.ndarray): One-hot encoded ground truth labels (2D array).
+            y_pred_proba (np.ndarray): Predicted probabilities (2D array).

+        Returns:
+            float: Macro-averaged average precision score.
+        """
+        y_true_arr = np.asarray(y_true_ohe)
+        y_proba_arr = np.asarray(y_pred_proba)

-        ham = hamming_loss(
-            y_true_masked,
-            nn.decode_masked(
-                y_true_masked,
-                y_pred_masked,
-                is_multiclass=is_multiclass,
-                return_int=True,
-            ),
-        )
+        if y_proba_arr.ndim == 3:
+            y_proba_arr = y_proba_arr.reshape(-1, y_proba_arr.shape[-1])

+        # Ensure 2D indicator truth
+        if y_true_arr.ndim == 1:
+            n_classes = y_proba_arr.shape[1]
+            y_true_arr = label_binarize(y_true_arr, classes=np.arange(n_classes))

-        with open("genotype_dist.csv", "w") as fout:
-            fout.write(
-                "site,prob_vector,imputed_genotype,expected_genotype\n"
-            )
-            for i, (yt, yp, ypd) in enumerate(
-                zip(y_true_masked, bin_mapping, y_pred_masked_decoded)
-            ):
-                fout.write(f"{i},{yp},{ypd},{yt}\n")
-        # np.set_printoptions(threshold=np.inf)
-        # print(y_true_masked)
-        # print(y_pred_masked_decoded)
-
-        metrics = dict()
-        metrics["accuracy"] = acc
-        metrics["roc_auc"] = roc_auc
-        metrics["precision_recall"] = pr_ap
-        metrics["hamming"] = ham
-
-        return metrics
+        return float(average_precision_score(y_true_arr, y_proba_arr, average="macro"))