pg-sui 0.2.0__py3-none-any.whl → 1.6.14.dev9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pg_sui-0.2.0.dist-info → pg_sui-1.6.14.dev9.dist-info}/METADATA +101 -79
- pg_sui-1.6.14.dev9.dist-info/RECORD +81 -0
- {pg_sui-0.2.0.dist-info → pg_sui-1.6.14.dev9.dist-info}/WHEEL +1 -1
- pg_sui-1.6.14.dev9.dist-info/entry_points.txt +4 -0
- {pg_sui-0.2.0.dist-info → pg_sui-1.6.14.dev9.dist-info/licenses}/LICENSE +0 -0
- pg_sui-1.6.14.dev9.dist-info/top_level.txt +1 -0
- pgsui/__init__.py +35 -54
- pgsui/_version.py +34 -0
- pgsui/cli.py +909 -0
- pgsui/data_processing/__init__.py +0 -0
- pgsui/data_processing/config.py +565 -0
- pgsui/data_processing/containers.py +1424 -0
- pgsui/data_processing/transformers.py +557 -907
- pgsui/{example_data/trees → electron/app}/__init__.py +0 -0
- pgsui/electron/app/__main__.py +5 -0
- pgsui/electron/app/extra-resources/.gitkeep +1 -0
- pgsui/electron/app/icons/icons/1024x1024.png +0 -0
- pgsui/electron/app/icons/icons/128x128.png +0 -0
- pgsui/electron/app/icons/icons/16x16.png +0 -0
- pgsui/electron/app/icons/icons/24x24.png +0 -0
- pgsui/electron/app/icons/icons/256x256.png +0 -0
- pgsui/electron/app/icons/icons/32x32.png +0 -0
- pgsui/electron/app/icons/icons/48x48.png +0 -0
- pgsui/electron/app/icons/icons/512x512.png +0 -0
- pgsui/electron/app/icons/icons/64x64.png +0 -0
- pgsui/electron/app/icons/icons/icon.icns +0 -0
- pgsui/electron/app/icons/icons/icon.ico +0 -0
- pgsui/electron/app/main.js +227 -0
- pgsui/electron/app/package-lock.json +6894 -0
- pgsui/electron/app/package.json +51 -0
- pgsui/electron/app/preload.js +15 -0
- pgsui/electron/app/server.py +157 -0
- pgsui/electron/app/ui/logo.png +0 -0
- pgsui/electron/app/ui/renderer.js +131 -0
- pgsui/electron/app/ui/styles.css +59 -0
- pgsui/electron/app/ui/ui_shim.js +72 -0
- pgsui/electron/bootstrap.py +43 -0
- pgsui/electron/launch.py +57 -0
- pgsui/electron/package.json +14 -0
- pgsui/example_data/__init__.py +0 -0
- pgsui/example_data/phylip_files/__init__.py +0 -0
- pgsui/example_data/phylip_files/test.phy +0 -0
- pgsui/example_data/popmaps/__init__.py +0 -0
- pgsui/example_data/popmaps/{test.popmap → phylogen_nomx.popmap} +185 -99
- pgsui/example_data/structure_files/__init__.py +0 -0
- pgsui/example_data/structure_files/test.pops.2row.allsites.str +0 -0
- pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz +0 -0
- pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz.tbi +0 -0
- pgsui/impute/__init__.py +0 -0
- pgsui/impute/deterministic/imputers/allele_freq.py +725 -0
- pgsui/impute/deterministic/imputers/mode.py +844 -0
- pgsui/impute/deterministic/imputers/nmf.py +221 -0
- pgsui/impute/deterministic/imputers/phylo.py +973 -0
- pgsui/impute/deterministic/imputers/ref_allele.py +669 -0
- pgsui/impute/supervised/__init__.py +0 -0
- pgsui/impute/supervised/base.py +343 -0
- pgsui/impute/{unsupervised/models/in_development → supervised/imputers}/__init__.py +0 -0
- pgsui/impute/supervised/imputers/hist_gradient_boosting.py +317 -0
- pgsui/impute/supervised/imputers/random_forest.py +291 -0
- pgsui/impute/unsupervised/__init__.py +0 -0
- pgsui/impute/unsupervised/base.py +1118 -0
- pgsui/impute/unsupervised/callbacks.py +92 -262
- {simulation → pgsui/impute/unsupervised/imputers}/__init__.py +0 -0
- pgsui/impute/unsupervised/imputers/autoencoder.py +1285 -0
- pgsui/impute/unsupervised/imputers/nlpca.py +1554 -0
- pgsui/impute/unsupervised/imputers/ubp.py +1575 -0
- pgsui/impute/unsupervised/imputers/vae.py +1228 -0
- pgsui/impute/unsupervised/loss_functions.py +261 -0
- pgsui/impute/unsupervised/models/__init__.py +0 -0
- pgsui/impute/unsupervised/models/autoencoder_model.py +215 -567
- pgsui/impute/unsupervised/models/nlpca_model.py +155 -394
- pgsui/impute/unsupervised/models/ubp_model.py +180 -1106
- pgsui/impute/unsupervised/models/vae_model.py +269 -630
- pgsui/impute/unsupervised/nn_scorers.py +255 -0
- pgsui/utils/__init__.py +0 -0
- pgsui/utils/classification_viz.py +608 -0
- pgsui/utils/logging_utils.py +22 -0
- pgsui/utils/misc.py +35 -480
- pgsui/utils/plotting.py +996 -829
- pgsui/utils/pretty_metrics.py +290 -0
- pgsui/utils/scorers.py +213 -666
- pg_sui-0.2.0.dist-info/RECORD +0 -75
- pg_sui-0.2.0.dist-info/top_level.txt +0 -3
- pgsui/example_data/phylip_files/test_n10.phy +0 -118
- pgsui/example_data/phylip_files/test_n100.phy +0 -118
- pgsui/example_data/phylip_files/test_n2.phy +0 -118
- pgsui/example_data/phylip_files/test_n500.phy +0 -118
- pgsui/example_data/structure_files/test.nopops.1row.10sites.str +0 -117
- pgsui/example_data/structure_files/test.nopops.2row.100sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.10sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.30sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.allsites.str +0 -234
- pgsui/example_data/structure_files/test.pops.1row.10sites.str +0 -117
- pgsui/example_data/structure_files/test.pops.2row.10sites.str +0 -234
- pgsui/example_data/trees/test.iqtree +0 -376
- pgsui/example_data/trees/test.qmat +0 -5
- pgsui/example_data/trees/test.rate +0 -2033
- pgsui/example_data/trees/test.tre +0 -1
- pgsui/example_data/trees/test_n10.rate +0 -19
- pgsui/example_data/trees/test_n100.rate +0 -109
- pgsui/example_data/trees/test_n500.rate +0 -509
- pgsui/example_data/trees/test_siterates.txt +0 -2024
- pgsui/example_data/trees/test_siterates_n10.txt +0 -10
- pgsui/example_data/trees/test_siterates_n100.txt +0 -100
- pgsui/example_data/trees/test_siterates_n500.txt +0 -500
- pgsui/example_data/vcf_files/test.vcf +0 -244
- pgsui/example_data/vcf_files/test.vcf.gz +0 -0
- pgsui/example_data/vcf_files/test.vcf.gz.tbi +0 -0
- pgsui/impute/estimators.py +0 -1268
- pgsui/impute/impute.py +0 -1463
- pgsui/impute/simple_imputers.py +0 -1431
- pgsui/impute/supervised/iterative_imputer_fixedparams.py +0 -782
- pgsui/impute/supervised/iterative_imputer_gridsearch.py +0 -1024
- pgsui/impute/unsupervised/keras_classifiers.py +0 -697
- pgsui/impute/unsupervised/models/in_development/cnn_model.py +0 -486
- pgsui/impute/unsupervised/neural_network_imputers.py +0 -1440
- pgsui/impute/unsupervised/neural_network_methods.py +0 -1395
- pgsui/pg_sui.py +0 -261
- pgsui/utils/sequence_tools.py +0 -407
- simulation/sim_benchmarks.py +0 -333
- simulation/sim_treeparams.py +0 -475
- test/__init__.py +0 -0
- test/pg_sui_simtest.py +0 -215
- test/pg_sui_testing.py +0 -523
- test/test.py +0 -151
- test/test_pgsui.py +0 -374
- test/test_tkc.py +0 -185
pgsui/utils/scorers.py
CHANGED
@@ -1,750 +1,297 @@

This module was rewritten. The old `Scorers` class of static and class-method helpers built around `NeuralNetworkMethods` (imported with a try/except fallback from `impute.unsupervised.neural_network_methods`) is removed, including `compute_roc_auc_micro_macro`, `compute_pr`, `check_if_tuple`, `accuracy_scorer`, `hamming_scorer`, `auc_macro`, `auc_micro`, `pr_macro`, `pr_micro`, `pr_samples`, `f1_samples`, the classmethod that assembled a dictionary of sklearn `make_scorer` objects for grid search, and the combined `scorer` method that also wrote a debugging `genotype_dist.csv`. The now-unused sklearn imports (`roc_curve`, `auc`, `hamming_loss`, `make_scorer`, `precision_recall_curve`, `multilabel_confusion_matrix`) are dropped as well. The replacement is a single `Scorer` class; the new module contents (indentation restored) are:

```python
from typing import Dict, Literal

import numpy as np
from sklearn.metrics import (
    accuracy_score,
    average_precision_score,
    f1_score,
    precision_score,
    recall_score,
    roc_auc_score,
)
from sklearn.preprocessing import label_binarize
from snpio.utils.logging import LoggerManager
from torch import Tensor

from pgsui.utils.logging_utils import configure_logger
from pgsui.utils.misc import validate_input_type


class Scorer:
    """Class for evaluating the performance of a model using various metrics.

    This class is used to evaluate the performance of a model using various metrics, such as accuracy, F1 score, precision, recall, average precision, and ROC AUC. The class can be used to evaluate the performance of a model on a dataset with ground truth labels. The class can also be used to evaluate the performance of a model in objective mode for hyperparameter tuning.
    """

    def __init__(
        self,
        prefix: str,
        average: Literal["micro", "macro", "weighted"] = "macro",
        verbose: bool = False,
        debug: bool = False,
    ) -> None:
        """Initialize a Scorer object.

        Args:
            prefix (str): Prefix for logging messages.
            average (Literal["micro", "macro", "weighted"]): Average method for metrics. Must be one of 'micro', 'macro', or 'weighted'.
            verbose (bool): Verbosity level for logging messages. Default is False.
            debug (bool): Debug mode for logging messages. Default is False.

        Raises:
            ValueError: If the average parameter is invalid. Must be one of 'micro', 'macro', or 'weighted'.
        """
        logman = LoggerManager(
            name=__name__, prefix=prefix, debug=debug, verbose=verbose >= 1
        )
        self.logger = configure_logger(
            logman.get_logger(), verbose=verbose >= 1, debug=debug
        )

        if average not in {"micro", "macro", "weighted"}:
            msg = f"Invalid average parameter: {average}. Must be one of 'micro', 'macro', or 'weighted'."
            self.logger.error(msg)
            raise ValueError(msg)

        self.average: Literal["micro", "macro", "weighted"] = average

    def accuracy(self, y_true: np.ndarray, y_pred: np.ndarray) -> float:
        """Calculate the accuracy of the model by comparing ground truth and predicted labels.

        Args:
            y_true (np.ndarray): Ground truth labels.
            y_pred (np.ndarray): Predicted labels.

        Returns:
            float: Accuracy score.
        """
        return float(accuracy_score(y_true, y_pred))

    def f1(self, y_true: np.ndarray, y_pred: np.ndarray) -> float:
        """Calculate the F1 score of the model by comparing ground truth and predicted labels.

        Args:
            y_true (np.ndarray): Ground truth labels.
            y_pred (np.ndarray): Predicted labels.

        Returns:
            float: F1 score.
        """
        avg: str = self.average
        return float(f1_score(y_true, y_pred, average=avg, zero_division=0))

    def precision(self, y_true: np.ndarray, y_pred: np.ndarray) -> float:
        """Calculate the precision of the model by comparing ground truth and predicted labels.

        Args:
            y_true (np.ndarray): Ground truth labels.
            y_pred (np.ndarray): Predicted labels.

        Returns:
            float: Precision score.
        """
        avg: str = self.average
        return float(precision_score(y_true, y_pred, average=avg, zero_division=0))

    def recall(self, y_true: np.ndarray, y_pred: np.ndarray) -> float:
        """Calculate the recall of the model by comparing ground truth and predicted labels.

        Args:
            y_true (np.ndarray): Ground truth labels.
            y_pred (np.ndarray): Predicted labels.

        Returns:
            float: Recall score.
        """
        avg: str = self.average
        return float(recall_score(y_true, y_pred, average=avg, zero_division=0))

    def roc_auc(self, y_true: np.ndarray, y_pred_proba: np.ndarray) -> float:
        """Multiclass ROC-AUC with label targets.

        This method calculates the ROC-AUC score for multiclass classification problems. It handles both 1D integer labels and 2D one-hot/indicator matrices for the ground truth labels.

        Args:
            y_true: 1D integer labels (shape: [n]).
                If a one-hot/indicator matrix is supplied, we convert to labels.
            y_pred_proba: 2D probabilities (shape: [n, n_classes]).
        """
        y_true = np.asarray(y_true)
        y_pred_proba = np.asarray(y_pred_proba)

        if y_pred_proba.ndim == 3:
            y_pred_proba = y_pred_proba.reshape(-1, y_pred_proba.shape[-1])

        # If user passed indicator/one-hot, convert to labels.
        if y_true.ndim == 2 and y_true.shape[1] == y_pred_proba.shape[1]:
            y_true = y_true.argmax(axis=1)

        # Guard: need >1 class present for AUC
        if np.unique(y_true).size < 2:
            return 0.5

        return float(
            roc_auc_score(
                y_true,
                y_pred_proba,
                multi_class="ovr",
                average=self.average,
            )
        )

    def evaluate(
        self,
        y_true: np.ndarray | Tensor | list,
        y_pred: np.ndarray | Tensor | list,
        y_true_ohe: np.ndarray | Tensor | list,
        y_pred_proba: np.ndarray | Tensor | list,
        objective_mode: bool = False,
        tune_metric: Literal[
            "pr_macro",
            "roc_auc",
            "average_precision",
            "accuracy",
            "f1",
            "precision",
            "recall",
        ] = "pr_macro",
    ) -> Dict[str, float] | None:
        """Evaluate the model using various metrics.

        This method evaluates model performance using accuracy, F1 score, precision, recall, average precision, and ROC AUC. It can evaluate a model against ground truth labels, or run in objective mode to return only the metric used for hyperparameter tuning.

        Args:
            y_true (np.ndarray | torch.Tensor): Ground truth labels.
            y_pred (np.ndarray | torch.Tensor): Predicted labels.
            y_true_ohe (np.ndarray | torch.Tensor): One-hot encoded ground truth labels.
            y_pred_proba (np.ndarray | torch.Tensor): Predicted probabilities.
            objective_mode (bool): Whether to use objective mode for evaluation. Default is False.
            tune_metric (Literal["pr_macro", "roc_auc", "average_precision", "accuracy", "f1", "precision", "recall"]): Metric to use for tuning. Ignored if `objective_mode` is False. Default is 'pr_macro'.

        Returns:
            Dict[str, float]: Dictionary of evaluation metrics. Keys are 'accuracy', 'f1', 'precision', 'recall', 'roc_auc', 'average_precision', and 'pr_macro'.

        Raises:
            ValueError: If the input data is invalid.
            ValueError: If an invalid tune_metric is provided.
        """
        y_true = np.asarray(validate_input_type(y_true, return_type="array"))
        y_pred = np.asarray(validate_input_type(y_pred, return_type="array"))
        y_true_ohe = np.asarray(validate_input_type(y_true_ohe, return_type="array"))
        y_pred_proba = np.asarray(
            validate_input_type(y_pred_proba, return_type="array")
        )

        if not y_true.ndim < 3:
            msg = "y_true must have 1 or 2 dimensions."
            self.logger.error(msg)
            raise ValueError(msg)

        if not y_pred.ndim < 3:
            msg = "y_pred must have 1 or 2 dimensions."
            self.logger.error(msg)
            raise ValueError(msg)

        if not y_true_ohe.ndim == 2:
            msg = "y_true_ohe must have 2 dimensions."
            self.logger.error(msg)
            raise ValueError(msg)

        if y_pred_proba.ndim != 2:
            y_pred_proba = y_pred_proba.reshape(-1, y_true_ohe.shape[-1])
            self.logger.debug(f"Reshaped y_pred_proba to {y_pred_proba.shape}")

        if objective_mode:
            if tune_metric == "pr_macro":
                metrics = {"pr_macro": self.pr_macro(y_true_ohe, y_pred_proba)}
            elif tune_metric == "roc_auc":
                metrics = {"roc_auc": self.roc_auc(y_true, y_pred_proba)}
            elif tune_metric == "average_precision":
                metrics = {
                    "average_precision": self.average_precision(y_true, y_pred_proba)
                }
            elif tune_metric == "accuracy":
                metrics = {"accuracy": self.accuracy(y_true, y_pred)}
            elif tune_metric == "f1":
                metrics = {"f1": self.f1(y_true, y_pred)}
            elif tune_metric == "precision":
                metrics = {"precision": self.precision(y_true, y_pred)}
            elif tune_metric == "recall":
                metrics = {"recall": self.recall(y_true, y_pred)}
            else:
                msg = f"Invalid tune_metric provided: '{tune_metric}'."
                self.logger.error(msg)
                raise ValueError(msg)
        else:
            metrics = {
                "accuracy": self.accuracy(y_true, y_pred),
                "f1": self.f1(y_true, y_pred),
                "precision": self.precision(y_true, y_pred),
                "recall": self.recall(y_true, y_pred),
                "roc_auc": self.roc_auc(y_true, y_pred_proba),
                "average_precision": self.average_precision(y_true, y_pred_proba),
                "pr_macro": self.pr_macro(y_true_ohe, y_pred_proba),
            }

        return {k: float(v) for k, v in metrics.items()}

    def average_precision(self, y_true: np.ndarray, y_pred_proba: np.ndarray) -> float:
        """Average precision with safe multiclass handling.

        If y_true is 1D of class indices, it is binarized against the number of columns in y_pred_proba. If y_true is already one-hot or indicator, it is used as-is.

        Args:
            y_true (np.ndarray): Ground truth labels (1D class indices or 2D one-hot/indicator).
            y_pred_proba (np.ndarray): Predicted probabilities (2D array).

        Returns:
            float: Average precision score.
        """
        y_true_arr = np.asarray(y_true)
        y_proba_arr = np.asarray(y_pred_proba)

        if y_proba_arr.ndim == 3:
            y_proba_arr = y_proba_arr.reshape(-1, y_proba_arr.shape[-1])

        # If y_true already matches proba columns (one-hot / indicator)
        if y_true_arr.ndim == 2 and y_true_arr.shape[1] == y_proba_arr.shape[1]:
            y_bin = y_true_arr
        else:
            # Interpret y_true as class indices
            n_classes = y_proba_arr.shape[1]
            y_bin = label_binarize(y_true_arr.ravel(), classes=np.arange(n_classes))

        return float(average_precision_score(y_bin, y_proba_arr, average=self.average))

    def pr_macro(self, y_true_ohe: np.ndarray, y_pred_proba: np.ndarray) -> float:
        """Macro-averaged average precision (precision-recall AUC) across classes.

        Args:
            y_true_ohe (np.ndarray): One-hot encoded ground truth labels (2D array).
            y_pred_proba (np.ndarray): Predicted probabilities (2D array).

        Returns:
            float: Macro-averaged average precision score.
        """
        y_true_arr = np.asarray(y_true_ohe)
        y_proba_arr = np.asarray(y_pred_proba)

        if y_proba_arr.ndim == 3:
            y_proba_arr = y_proba_arr.reshape(-1, y_proba_arr.shape[-1])

        # Ensure 2D indicator truth
        if y_true_arr.ndim == 1:
            n_classes = y_proba_arr.shape[1]
            y_true_arr = label_binarize(y_true_arr, classes=np.arange(n_classes))

        return float(average_precision_score(y_true_arr, y_proba_arr, average="macro"))
```