pg-sui 0.2.3__py3-none-any.whl → 1.6.16a3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (128)
  1. pg_sui-1.6.16a3.dist-info/METADATA +292 -0
  2. pg_sui-1.6.16a3.dist-info/RECORD +81 -0
  3. {pg_sui-0.2.3.dist-info → pg_sui-1.6.16a3.dist-info}/WHEEL +1 -1
  4. pg_sui-1.6.16a3.dist-info/entry_points.txt +4 -0
  5. {pg_sui-0.2.3.dist-info → pg_sui-1.6.16a3.dist-info/licenses}/LICENSE +0 -0
  6. pg_sui-1.6.16a3.dist-info/top_level.txt +1 -0
  7. pgsui/__init__.py +35 -54
  8. pgsui/_version.py +34 -0
  9. pgsui/cli.py +922 -0
  10. pgsui/data_processing/__init__.py +0 -0
  11. pgsui/data_processing/config.py +565 -0
  12. pgsui/data_processing/containers.py +1436 -0
  13. pgsui/data_processing/transformers.py +557 -907
  14. pgsui/{example_data/trees → electron/app}/__init__.py +0 -0
  15. pgsui/electron/app/__main__.py +5 -0
  16. pgsui/electron/app/extra-resources/.gitkeep +1 -0
  17. pgsui/electron/app/icons/icons/1024x1024.png +0 -0
  18. pgsui/electron/app/icons/icons/128x128.png +0 -0
  19. pgsui/electron/app/icons/icons/16x16.png +0 -0
  20. pgsui/electron/app/icons/icons/24x24.png +0 -0
  21. pgsui/electron/app/icons/icons/256x256.png +0 -0
  22. pgsui/electron/app/icons/icons/32x32.png +0 -0
  23. pgsui/electron/app/icons/icons/48x48.png +0 -0
  24. pgsui/electron/app/icons/icons/512x512.png +0 -0
  25. pgsui/electron/app/icons/icons/64x64.png +0 -0
  26. pgsui/electron/app/icons/icons/icon.icns +0 -0
  27. pgsui/electron/app/icons/icons/icon.ico +0 -0
  28. pgsui/electron/app/main.js +227 -0
  29. pgsui/electron/app/package-lock.json +6894 -0
  30. pgsui/electron/app/package.json +51 -0
  31. pgsui/electron/app/preload.js +15 -0
  32. pgsui/electron/app/server.py +157 -0
  33. pgsui/electron/app/ui/logo.png +0 -0
  34. pgsui/electron/app/ui/renderer.js +131 -0
  35. pgsui/electron/app/ui/styles.css +59 -0
  36. pgsui/electron/app/ui/ui_shim.js +72 -0
  37. pgsui/electron/bootstrap.py +43 -0
  38. pgsui/electron/launch.py +57 -0
  39. pgsui/electron/package.json +14 -0
  40. pgsui/example_data/__init__.py +0 -0
  41. pgsui/example_data/phylip_files/__init__.py +0 -0
  42. pgsui/example_data/phylip_files/test.phy +0 -0
  43. pgsui/example_data/popmaps/__init__.py +0 -0
  44. pgsui/example_data/popmaps/{test.popmap → phylogen_nomx.popmap} +185 -99
  45. pgsui/example_data/structure_files/__init__.py +0 -0
  46. pgsui/example_data/structure_files/test.pops.2row.allsites.str +0 -0
  47. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz +0 -0
  48. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz.tbi +0 -0
  49. pgsui/impute/__init__.py +0 -0
  50. pgsui/impute/deterministic/imputers/allele_freq.py +725 -0
  51. pgsui/impute/deterministic/imputers/mode.py +844 -0
  52. pgsui/impute/deterministic/imputers/nmf.py +221 -0
  53. pgsui/impute/deterministic/imputers/phylo.py +973 -0
  54. pgsui/impute/deterministic/imputers/ref_allele.py +669 -0
  55. pgsui/impute/supervised/__init__.py +0 -0
  56. pgsui/impute/supervised/base.py +343 -0
  57. pgsui/impute/{unsupervised/models/in_development → supervised/imputers}/__init__.py +0 -0
  58. pgsui/impute/supervised/imputers/hist_gradient_boosting.py +317 -0
  59. pgsui/impute/supervised/imputers/random_forest.py +291 -0
  60. pgsui/impute/unsupervised/__init__.py +0 -0
  61. pgsui/impute/unsupervised/base.py +1121 -0
  62. pgsui/impute/unsupervised/callbacks.py +92 -262
  63. {simulation → pgsui/impute/unsupervised/imputers}/__init__.py +0 -0
  64. pgsui/impute/unsupervised/imputers/autoencoder.py +1361 -0
  65. pgsui/impute/unsupervised/imputers/nlpca.py +1666 -0
  66. pgsui/impute/unsupervised/imputers/ubp.py +1660 -0
  67. pgsui/impute/unsupervised/imputers/vae.py +1316 -0
  68. pgsui/impute/unsupervised/loss_functions.py +261 -0
  69. pgsui/impute/unsupervised/models/__init__.py +0 -0
  70. pgsui/impute/unsupervised/models/autoencoder_model.py +215 -567
  71. pgsui/impute/unsupervised/models/nlpca_model.py +155 -394
  72. pgsui/impute/unsupervised/models/ubp_model.py +180 -1106
  73. pgsui/impute/unsupervised/models/vae_model.py +269 -630
  74. pgsui/impute/unsupervised/nn_scorers.py +255 -0
  75. pgsui/utils/__init__.py +0 -0
  76. pgsui/utils/classification_viz.py +608 -0
  77. pgsui/utils/logging_utils.py +22 -0
  78. pgsui/utils/misc.py +35 -480
  79. pgsui/utils/plotting.py +996 -829
  80. pgsui/utils/pretty_metrics.py +290 -0
  81. pgsui/utils/scorers.py +213 -666
  82. pg_sui-0.2.3.dist-info/METADATA +0 -322
  83. pg_sui-0.2.3.dist-info/RECORD +0 -75
  84. pg_sui-0.2.3.dist-info/top_level.txt +0 -3
  85. pgsui/example_data/phylip_files/test_n10.phy +0 -118
  86. pgsui/example_data/phylip_files/test_n100.phy +0 -118
  87. pgsui/example_data/phylip_files/test_n2.phy +0 -118
  88. pgsui/example_data/phylip_files/test_n500.phy +0 -118
  89. pgsui/example_data/structure_files/test.nopops.1row.10sites.str +0 -117
  90. pgsui/example_data/structure_files/test.nopops.2row.100sites.str +0 -234
  91. pgsui/example_data/structure_files/test.nopops.2row.10sites.str +0 -234
  92. pgsui/example_data/structure_files/test.nopops.2row.30sites.str +0 -234
  93. pgsui/example_data/structure_files/test.nopops.2row.allsites.str +0 -234
  94. pgsui/example_data/structure_files/test.pops.1row.10sites.str +0 -117
  95. pgsui/example_data/structure_files/test.pops.2row.10sites.str +0 -234
  96. pgsui/example_data/trees/test.iqtree +0 -376
  97. pgsui/example_data/trees/test.qmat +0 -5
  98. pgsui/example_data/trees/test.rate +0 -2033
  99. pgsui/example_data/trees/test.tre +0 -1
  100. pgsui/example_data/trees/test_n10.rate +0 -19
  101. pgsui/example_data/trees/test_n100.rate +0 -109
  102. pgsui/example_data/trees/test_n500.rate +0 -509
  103. pgsui/example_data/trees/test_siterates.txt +0 -2024
  104. pgsui/example_data/trees/test_siterates_n10.txt +0 -10
  105. pgsui/example_data/trees/test_siterates_n100.txt +0 -100
  106. pgsui/example_data/trees/test_siterates_n500.txt +0 -500
  107. pgsui/example_data/vcf_files/test.vcf +0 -244
  108. pgsui/example_data/vcf_files/test.vcf.gz +0 -0
  109. pgsui/example_data/vcf_files/test.vcf.gz.tbi +0 -0
  110. pgsui/impute/estimators.py +0 -1268
  111. pgsui/impute/impute.py +0 -1463
  112. pgsui/impute/simple_imputers.py +0 -1431
  113. pgsui/impute/supervised/iterative_imputer_fixedparams.py +0 -782
  114. pgsui/impute/supervised/iterative_imputer_gridsearch.py +0 -1024
  115. pgsui/impute/unsupervised/keras_classifiers.py +0 -697
  116. pgsui/impute/unsupervised/models/in_development/cnn_model.py +0 -486
  117. pgsui/impute/unsupervised/neural_network_imputers.py +0 -1440
  118. pgsui/impute/unsupervised/neural_network_methods.py +0 -1395
  119. pgsui/pg_sui.py +0 -261
  120. pgsui/utils/sequence_tools.py +0 -407
  121. simulation/sim_benchmarks.py +0 -333
  122. simulation/sim_treeparams.py +0 -475
  123. test/__init__.py +0 -0
  124. test/pg_sui_simtest.py +0 -215
  125. test/pg_sui_testing.py +0 -523
  126. test/test.py +0 -151
  127. test/test_pgsui.py +0 -374
  128. test/test_tkc.py +0 -185
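The listing above also documents a reorganized package layout: imputers now live under pgsui/impute/deterministic, pgsui/impute/supervised, and pgsui/impute/unsupervised, while the legacy estimators.py, impute.py, and simple_imputers.py modules are removed. As a quick orientation aid, the snippet below sketches a hypothetical import check against the 1.6.16a3 wheel; the module paths are taken from the file list above, but the classes each module exports are not visible in this diff, so only module-level imports are attempted.

    # Hypothetical layout check for pg-sui 1.6.16a3 (module paths taken from the file list above).
    # Only module imports are attempted; exported class names are not shown in this diff.
    import importlib

    for name in [
        "pgsui.impute.deterministic.imputers.mode",
        "pgsui.impute.supervised.imputers.random_forest",
        "pgsui.impute.unsupervised.imputers.vae",
        "pgsui.utils.scorers",
    ]:
        module = importlib.import_module(name)
        print(f"OK: {name} -> {module.__file__}")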
pgsui/utils/scorers.py CHANGED
@@ -1,750 +1,297 @@
-import sys
+from typing import Dict, Literal
 
 import numpy as np
-
 from sklearn.metrics import (
-    roc_curve,
-    auc,
     accuracy_score,
-    hamming_loss,
-    make_scorer,
-    precision_recall_curve,
     average_precision_score,
-    multilabel_confusion_matrix,
     f1_score,
+    precision_score,
+    recall_score,
+    roc_auc_score,
 )
-
 from sklearn.preprocessing import label_binarize
+from snpio.utils.logging import LoggerManager
+from torch import Tensor
 
-try:
-    from ..impute.unsupervised.neural_network_methods import (
-        NeuralNetworkMethods,
-    )
-except (ModuleNotFoundError, ValueError, ImportError):
-    from impute.unsupervised.neural_network_methods import NeuralNetworkMethods
-
-
-class Scorers:
-    @staticmethod
-    def compute_roc_auc_micro_macro(
-        y_true, y_pred, num_classes=3, binarize_pred=True
-    ):
-        """Compute ROC curve with AUC scores.
-
-        ROC (Receiver Operating Characteristic) curves and AUC (area under curve) scores are computed per-class and for micro and macro averages.
-
-        Args:
-            y_true (numpy.ndarray): Ravelled numpy array of shape (n_samples * n_features,). y_true should be integer-encoded.
-
-            y_pred (numpy.ndarray): Ravelled numpy array of shape (n_samples * n_features,). y_pred should be probabilities.
-
-            num_classes (int, optional): How many classes to use. Defaults to 3.
-
-            binarize_pred (bool, optional): Whether to binarize y_pred. If False, y_pred should be probabilities of each class. Defaults to True.
-
-        Returns:
-            Dict[str, Any]: Dictionary with true and false positive rates along probability threshold curve per class, micro and macro tpr and fpr curves averaged across classes, and AUC scores per-class and for micro and macro averages.
-        """
-        cats = range(num_classes)
-
-        # Get only classes that appear in y_true.
-        classes = [i for i in cats if i in y_true]
-
-        # Binarize the output for use with ROC-AUC.
-        y_true_bin = label_binarize(y_true, classes=cats)
-
-        if binarize_pred:
-            y_pred_bin = label_binarize(y_pred, classes=cats)
-        else:
-            y_pred_bin = y_pred
-
-        for i in range(y_true_bin.shape[1]):
-            if i not in classes:
-                y_true_bin = np.delete(y_true_bin, i, axis=-1)
-                y_pred_bin = np.delete(y_pred_bin, i, axis=-1)
-
-        n_classes = len(classes)
-
-        # Compute ROC curve and ROC area for each class.
-        fpr = dict()
-        tpr = dict()
-        roc_auc = dict()
-        for i, c in enumerate(classes):
-            fpr[c], tpr[c], _ = roc_curve(y_true_bin[:, i], y_pred_bin[:, i])
-            roc_auc[c] = auc(fpr[c], tpr[c])
-
-        # Compute micro-average ROC curve and ROC area.
-        fpr["micro"], tpr["micro"], _ = roc_curve(
-            y_true_bin.ravel(), y_pred_bin.ravel()
-        )
-
-        roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])
-
-        # Aggregate all false positive rates
-        all_fpr = np.unique(np.concatenate([fpr[i] for i in classes]))
+from pgsui.utils.logging_utils import configure_logger
+from pgsui.utils.misc import validate_input_type
 
-        # Then interpolate all ROC curves at these points.
-        mean_tpr = np.zeros_like(all_fpr)
-        for i in classes:
-            mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])
 
-        # Finally, average it and compute AUC.
-        mean_tpr /= n_classes
+class Scorer:
+    """Class for evaluating the performance of a model using various metrics.
 
-        fpr["macro"] = all_fpr
-        tpr["macro"] = mean_tpr
+    This class is used to evaluate the performance of a model using various metrics, such as accuracy, F1 score, precision, recall, average precision, and ROC AUC. The class can be used to evaluate the performance of a model on a dataset with ground truth labels. The class can also be used to evaluate the performance of a model in objective mode for hyperparameter tuning.
+    """
 
-        roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])
+    def __init__(
+        self,
+        prefix: str,
+        average: Literal["micro", "macro", "weighted"] = "macro",
+        verbose: bool = False,
+        debug: bool = False,
+    ) -> None:
+        """Initialize a Scorer object.
 
-        roc_auc["fpr_macro"] = fpr["macro"]
-        roc_auc["tpr_macro"] = tpr["macro"]
-        roc_auc["fpr_micro"] = fpr["micro"]
-        roc_auc["tpr_micro"] = tpr["micro"]
-
-        for i in classes:
-            roc_auc[f"fpr_{i}"] = fpr[i]
-            roc_auc[f"tpr_{i}"] = tpr[i]
-
-        return roc_auc
-
-    @staticmethod
-    def compute_pr(y_true, y_pred, use_int_encodings=False, num_classes=3):
-        """Compute precision-recall curve with Average Precision scores.
-
-        PR and AP scores are computed per-class and for micro and macro averages.
+        This class is used to evaluate the performance of a model using various metrics, such as accuracy, F1 score, precision, recall, average precision, and ROC AUC. The class can be used to evaluate the performance of a model on a dataset with ground truth labels. The class can also be used to evaluate the performance of a model in objective mode for hyperparameter tuning.
 
         Args:
-            y_true (numpy.ndarray): Ravelled numpy array of shape (n_samples * n_features,).
-
-            y_pred (numpy.ndarray): Ravelled numpy array of shape (n_samples * n_features,). y_pred should be integer-encoded.
+            prefix (str): Prefix for logging messages.
+            average (Literal["micro", "macro", "weighted"]): Average method for metrics. Must be one of 'micro', 'macro', or 'weighted'.
+            verbose (bool): Verbosity level for logging messages. Default is False.
+            debug (bool): Debug mode for logging messages. Default is False.
 
-            use_int_encodings (bool, optional): Whether the imputer model is a neural network model. Defaults to False.
-
-            num_classes (int, optional): How many classes to use. Defaults to 3.
-
-        Returns:
-            Dict[str, Any]: Dictionary with precision and recall curves per class and micro and macro averaged across classes, plus AP scores per-class and for micro and macro averages.
+        Raises:
+            ValueError: If the average parameter is invalid. Must be one of 'micro', 'macro', or 'weighted'.
         """
-
-        cats = range(num_classes)
-
-        is_multiclass = True if num_classes != 4 else False
-
-        # Get only classes that appear in y_true.
-        classes = [i for i in cats if i in y_true]
-
-        # Binarize the output for use with ROC-AUC.
-        y_true_bin = label_binarize(y_true, classes=cats)
-        y_pred_proba_bin = y_pred
-
-        if is_multiclass:
-            for i in range(y_true_bin.shape[1]):
-                if i not in classes:
-                    y_true_bin = np.delete(y_true_bin, i, axis=-1)
-                    y_pred_proba_bin = np.delete(y_pred_proba_bin, i, axis=-1)
-
-        nn = NeuralNetworkMethods()
-        if use_int_encodings:
-            y_pred_012 = nn.decode_masked(y_true_bin, y_pred_proba_bin)
-            thresh = 0.5
-        else:
-            y_pred_012 = nn.decode_masked(
-                y_true_bin,
-                y_pred_proba_bin,
-                is_multiclass=is_multiclass,
-                return_int=False,
-                return_multilab=True,
-            )
-
-        encode_func = (
-            nn.encode_multiclass if is_multiclass else nn.encode_multilab
-        )
-
-        y_true = encode_func(y_true, num_classes=num_classes)
-
-        # Make confusion matrix to get true negatives and true positives.
-        mcm = multilabel_confusion_matrix(y_true, y_pred_012)
-
-        tn = np.sum(mcm[:, 0, 0])
-        tn /= num_classes
-
-        tp = np.sum(mcm[:, 1, 1])
-        tp /= num_classes
-
-        baseline = tp / (tn + tp)
-
-        precision = dict()
-        recall = dict()
-        average_precision = dict()
-
-        for i, c in enumerate(classes):
-            precision[c], recall[c], _ = precision_recall_curve(
-                y_true_bin[:, i], y_pred_proba_bin[:, i]
-            )
-            average_precision[c] = average_precision_score(
-                y_true_bin[:, i], y_pred_proba_bin[:, i]
-            )
-
-        # A "micro-average": quantifying score on all classes jointly.
-        precision["micro"], recall["micro"], _ = precision_recall_curve(
-            y_true_bin.ravel(), y_pred_proba_bin.ravel()
-        )
-
-        average_precision["micro"] = average_precision_score(
-            y_true_bin, y_pred_proba_bin, average="micro"
+        logman = LoggerManager(
+            name=__name__, prefix=prefix, debug=debug, verbose=verbose >= 1
         )
-
-        average_precision["macro"] = average_precision_score(
-            y_true_bin, y_pred_proba_bin, average="macro"
+        self.logger = configure_logger(
+            logman.get_logger(), verbose=verbose >= 1, debug=debug
         )
 
-        if use_int_encodings:
-            y_pred_012 = (
-                nn.decode_masked(
-                    y_true_bin,
-                    y_pred_proba_bin,
-                    threshold=thresh,
-                    return_multilab=True,
-                    predict_still_missing=False,
-                ),
-            )
-
-        f1 = f1_score(y_true_bin, y_pred_012, average="macro")
-
-        # Aggregate all recalls
-        all_recall = np.unique(np.concatenate([recall[i] for i in classes]))
-
-        # Then interpolate all PR curves at these points.
-        mean_precision = np.zeros_like(all_recall)
-        for i in classes:
-            mean_precision += np.interp(all_recall, precision[i], recall[i])
-
-        # Finally, average it and compute AUC.
-        mean_precision /= num_classes
-
-        recall["macro"] = all_recall
-        precision["macro"] = mean_precision
-
-        results = dict()
-
-        results["micro"] = average_precision["micro"]
-        results["macro"] = average_precision["macro"]
-        results["f1_score"] = f1
-        results["recall_macro"] = all_recall
-        results["precision_macro"] = mean_precision
-        results["recall_micro"] = recall["micro"]
-        results["precision_micro"] = precision["micro"]
-
-        for i in classes:
-            results[f"recall_{i}"] = recall[i]
-            results[f"precision_{i}"] = precision[i]
-            results[i] = average_precision[i]
-        results["baseline"] = baseline
+        if average not in {"micro", "macro", "weighted"}:
+            msg = f"Invalid average parameter: {average}. Must be one of 'micro', 'macro', or 'weighted'."
+            self.logger.error(msg)
+            raise ValueError(msg)
 
-        return results
+        self.average: Literal["micro", "macro", "weighted"] = average
 
-    @staticmethod
-    def check_if_tuple(y_pred):
-        """Checks if y_pred is a tuple and if so, returns the first element of the tuple."""
-        if isinstance(y_pred, tuple):
-            y_pred = y_pred[0]
-        return y_pred
+    def accuracy(self, y_true: np.ndarray, y_pred: np.ndarray) -> float:
+        """Calculate the accuracy of the model.
 
-    @staticmethod
-    def accuracy_scorer(y_true, y_pred, **kwargs):
-        """Get accuracy score for grid search.
-
-        If provided, only calculates score where missing_mask is True (i.e., data were missing). This is so that users can simulate missing data for known values, and then the predictions for only those known values can be evaluated.
+        This method calculates the accuracy of the model by comparing the ground truth labels with the predicted labels.
 
         Args:
-            y_true (numpy.ndarray): 012-encoded true target values.
-
-            y_pred (tensorflow.EagerTensor): Predictions from model as probabilities. They must first be decoded to use with accuracy_score.
-
-            kwargs (Any): Keyword arguments to use with scorer. Supported options include ``missing_mask`` and ``testing``\.
+            y_true (np.ndarray): Ground truth labels.
+            y_pred (np.ndarray): Predicted labels.
 
         Returns:
-            float: Metric score by comparing y_true and y_pred.
+            float: Accuracy score.
         """
-        # Get missing mask if provided.
-        # Otherwise default is all missing values (array all True).
-
-        missing_mask = kwargs.get("missing_mask")
-        testing = kwargs.get("testing", False)
-        nn_model = kwargs.get("nn_model", True)
-
-        y_pred = Scorers.check_if_tuple(y_pred)
-
-        if nn_model:
-            nn = NeuralNetworkMethods()
-
-        y_true_masked = y_true[missing_mask]
-        y_pred_masked = y_pred[missing_mask]
-
-        if nn_model:
-            y_pred_masked_decoded = nn.decode_masked(
-                y_true_masked, y_pred_masked, predict_still_missing=False
-            )
-        else:
-            y_pred_masked_decoded = y_pred_masked
+        return float(accuracy_score(y_true, y_pred))
 
-        acc = accuracy_score(y_true_masked, y_pred_masked_decoded)
+    def f1(self, y_true: np.ndarray, y_pred: np.ndarray) -> float:
+        """Calculate the F1 score of the model.
 
-        if testing:
-            np.set_printoptions(threshold=np.inf)
-            print(y_true_masked)
-            print(y_pred_masked_decoded)
-
-        return acc
-
-    @staticmethod
-    def hamming_scorer(y_true, y_pred, **kwargs):
-        """Get Hamming score for grid search.
-
-        If provided, only calculates score where missing_mask is True (i.e., data were missing). This is so that users can simulate missing data for known values, and then the predictions for only those known values can be evaluated.
+        This method calculates the F1 score of the model by comparing the ground truth labels with the predicted labels.
 
         Args:
-            y_true (numpy.ndarray): 012-encoded true target values.
-
-            y_pred (tensorflow.EagerTensor): Predictions from model as probabilities. They must first be decoded to use with hamming_scorer.
-
-            kwargs (Any): Keyword arguments to use with scorer. Supported options include ``missing_mask`` and ``testing``\.
+            y_true (np.ndarray): Ground truth labels.
+            y_pred (np.ndarray): Predicted labels.
 
         Returns:
-            float: Metric score by comparing y_true and y_pred.
+            float: F1 score.
         """
-        # Get missing mask if provided.
-        # Otherwise default is all missing values (array all True).
-
-        missing_mask = kwargs.get("missing_mask")
-        testing = kwargs.get("testing", False)
-        nn_model = kwargs.get("nn_model", True)
-        num_classes = kwargs.get("num_classes", 3)
-
-        y_pred = Scorers.check_if_tuple(y_pred)
-
-        if nn_model:
-            nn = NeuralNetworkMethods()
-
-        y_true_masked = y_true[missing_mask]
-        y_pred_masked = y_pred[missing_mask]
-
-        if nn_model:
-            y_pred_masked_decoded = nn.decode_masked(
-                y_true_masked,
-                y_pred_masked,
-                predict_still_missing=False,
-            )
-        else:
-            y_pred_masked_decoded = y_pred_masked
-
-        ham = hamming_loss(y_true_masked, y_pred_masked_decoded)
-
-        if testing:
-            np.set_printoptions(threshold=np.inf)
-            print(y_true_masked)
-            print(y_pred_masked_decoded)
-
-        return ham
+        avg: str = self.average
+        return float(f1_score(y_true, y_pred, average=avg, zero_division=0))
 
-    @staticmethod
-    def auc_macro(y_true, y_pred, **kwargs):
-        """Get AUC score with macro averaging for grid search.
+    def precision(self, y_true: np.ndarray, y_pred: np.ndarray) -> float:
+        """Calculate the precision of the model.
 
-        If provided, only calculates score where missing_mask is True (i.e., data were missing). This is so that users can simulate missing data for known values, and then the predictions for only those known values can be evaluated.
+        This method calculates the precision of the model by comparing the ground truth labels with the predicted labels.
 
         Args:
-            y_true (numpy.ndarray): 012-encoded true target values.
-
-            y_pred (tensorflow.EagerTensor): Predictions from model as probabilities.
-
-            kwargs (Any): Keyword arguments to use with scorer. Supported options include ``missing_mask`` and ``testing``\.
+            y_true (np.ndarray): Ground truth labels.
+            y_pred (np.ndarray): Predicted labels.
 
         Returns:
-            float: Metric score by comparing y_true and y_pred.
+            float: Precision score.
         """
-        # Get missing mask if provided.
-        # Otherwise default is all missing values (array all True).
-        missing_mask = kwargs.get("missing_mask")
-        num_classes = kwargs.get("num_classes", 3)
-        nn_model = kwargs.get("nn_model", True)
-        testing = kwargs.get("testing", False)
-
-        is_multiclass = True if num_classes != 4 else False
-
-        y_pred = Scorers.check_if_tuple(y_pred)
-
-        if nn_model:
-            nn = NeuralNetworkMethods()
+        avg: str = self.average
+        return float(precision_score(y_true, y_pred, average=avg, zero_division=0))
 
-        y_true_masked = y_true[missing_mask]
-        y_pred_masked = y_pred[missing_mask]
+    def recall(self, y_true: np.ndarray, y_pred: np.ndarray) -> float:
+        """Calculate the recall of the model.
 
-        if nn_model:
-            y_pred_masked_decoded = nn.decode_masked(
-                y_true_masked, y_pred_masked, is_multiclass=is_multiclass
-            )
-        else:
-            y_pred_masked_decoded = y_pred_masked
-
-        roc_auc = Scorers.compute_roc_auc_micro_macro(
-            y_true_masked,
-            y_pred_masked,
-            num_classes=num_classes,
-            binarize_pred=False,
-        )
-
-        return roc_auc["macro"]
-
-    @staticmethod
-    def auc_micro(y_true, y_pred, **kwargs):
-        """Get AUC score with micro averaging for grid search.
-
-        If provided, only calculates score where missing_mask is True (i.e., data were missing). This is so that users can simulate missing data for known values, and then the predictions for only those known values can be evaluated.
+        This method calculates the recall of the model by comparing the ground truth labels with the predicted labels.
 
         Args:
-            y_true (numpy.ndarray): 012-encoded true target values.
-
-            y_pred (tensorflow.EagerTensor): Predictions from model as probabilities.
-
-            kwargs (Any): Keyword arguments to use with scorer. Supported options include ``missing_mask`` and ``testing``\.
+            y_true (np.ndarray): Ground truth labels.
+            y_pred (np.ndarray): Predicted labels.
 
         Returns:
-            float: Metric score by comparing y_true and y_pred.
+            float: Recall score.
         """
-        # Get missing mask if provided.
-        # Otherwise default is all missing values (array all True).
-        missing_mask = kwargs.get("missing_mask")
-        nn_model = kwargs.get("nn_model", True)
-        num_classes = kwargs.get("num_classes", 3)
-
-        is_multiclass = True if num_classes != 4 else False
+        avg: str = self.average
+        return float(recall_score(y_true, y_pred, average=avg, zero_division=0))
 
-        y_pred = Scorers.check_if_tuple(y_pred)
+    def roc_auc(self, y_true: np.ndarray, y_pred_proba: np.ndarray) -> float:
+        """Multiclass ROC-AUC with label targets.
 
-        if nn_model:
-            nn = NeuralNetworkMethods()
-
-        y_true_masked = y_true[missing_mask]
-        y_pred_masked = y_pred[missing_mask]
-
-        if nn_model:
-            y_pred_masked_decoded = nn.decode_masked(
-                y_true_masked, y_pred_masked, is_multiclass=is_multiclass
-            )
-        else:
-            y_pred_masked_decoded = y_pred_masked
-
-        roc_auc = Scorers.compute_roc_auc_micro_macro(
-            y_true_masked,
-            y_pred_masked,
-            num_classes=num_classes,
-            binarize_pred=False,
-        )
-
-        return roc_auc["micro"]
-
-    @staticmethod
-    def pr_macro(y_true, y_pred, **kwargs):
-        """Get Precision-Recall score with macro averaging for grid search.
-
-        If provided, only calculates score where missing_mask is True (i.e., data were missing). This is so that users can simulate missing data for known values, and then the predictions for only those known values can be evaluated.
+        This method calculates the ROC-AUC score for multiclass classification problems. It handles both 1D integer labels and 2D one-hot/indicator matrices for the ground truth labels.
 
         Args:
-            y_true (numpy.ndarray): 012-encoded true target values.
-
-            y_pred (tensorflow.EagerTensor): Predictions from model as probabilities.
-
-            kwargs (Any): Keyword arguments to use with scorer. Supported options include ``missing_mask`` and ``testing``\.
-
-        Returns:
-            float: Metric score by comparing y_true and y_pred.
+            y_true: 1D integer labels (shape: [n]).
+                If a one-hot/indicator matrix is supplied, we convert to labels.
+            y_pred_proba: 2D probabilities (shape: [n, n_classes]).
         """
-
-        # Get missing mask if provided.
-        # Otherwise default is all missing values (array all True).
-        missing_mask = kwargs.get("missing_mask")
-        num_classes = kwargs.get("num_classes", 3)
-        testing = kwargs.get("testing", False)
-
-        y_pred = Scorers.check_if_tuple(y_pred)
-
-        y_true_masked = y_true[missing_mask]
-        y_pred_masked = y_pred[missing_mask]
-
-        pr_ap = Scorers.compute_pr(
-            y_true_masked, y_pred_masked, num_classes=num_classes
+        y_true = np.asarray(y_true)
+        y_pred_proba = np.asarray(y_pred_proba)
+
+        if y_pred_proba.ndim == 3:
+            y_pred_proba = y_pred_proba.reshape(-1, y_pred_proba.shape[-1])
+
+        # If user passed indicator/one-hot, convert to labels.
+        if y_true.ndim == 2 and y_true.shape[1] == y_pred_proba.shape[1]:
+            y_true = y_true.argmax(axis=1)
+
+        # Guard: need >1 class present for AUC
+        if np.unique(y_true).size < 2:
+            return 0.5
+
+        return float(
+            roc_auc_score(
+                y_true,
+                y_pred_proba,
+                multi_class="ovr",
+                average=self.average,
+            )
         )
 
-        return pr_ap["macro"]
-
-    @staticmethod
-    def pr_samples(y_true, y_pred, **kwargs):
-        """Get Precision-Recall score with samples averaging for grid search.
-
-        If provided, only calculates score where missing_mask is True (i.e., data were missing). This is so that users can simulate missing data for known values, and then the predictions for only those known values can be evaluated.
+    def evaluate(
+        self,
+        y_true: np.ndarray | Tensor | list,
+        y_pred: np.ndarray | Tensor | list,
+        y_true_ohe: np.ndarray | Tensor | list,
+        y_pred_proba: np.ndarray | Tensor | list,
+        objective_mode: bool = False,
+        tune_metric: Literal[
+            "pr_macro",
+            "roc_auc",
+            "average_precision",
+            "accuracy",
+            "f1",
+            "precision",
+            "recall",
+        ] = "pr_macro",
+    ) -> Dict[str, float] | None:
+        """Evaluate the model using various metrics.
+
+        This method evaluates the performance of a model using various metrics, such as accuracy, F1 score, precision, recall, average precision, and ROC AUC. The method can be used to evaluate the performance of a model on a dataset with ground truth labels. The method can also be used to evaluate the performance of a model in objective mode for hyperparameter tuning.
 
         Args:
-            y_true (numpy.ndarray): 012-encoded true target values.
-
-            y_pred (tensorflow.EagerTensor): Predictions from model as probabilities.
-
-            kwargs (Any): Keyword arguments to use with scorer. Supported options include ``missing_mask`` and ``testing``\.
+            y_true (np.ndarray | torch.Tensor): Ground truth labels.
+            y_pred (np.ndarray | torch.Tensor): Predicted labels.
+            y_true_ohe (np.ndarray | torch.Tensor): One-hot encoded ground truth labels.
+            y_pred_proba (np.ndarray | torch.Tensor): Predicted probabilities.
+            objective_mode (bool): Whether to use objective mode for evaluation. Default is False.
+            tune_metric (Literal["pr_macro", "roc_auc", "average_precision", "accuracy", "f1", "precision", "recall"]): Metric to use for tuning. Ignored if `objective_mode` is False. Default is 'pr_macro'.
 
         Returns:
-            float: Metric score by comparing y_true and y_pred.
-        """
-        # Get missing mask if provided.
-        # Otherwise default is all missing values (array all True).
-        missing_mask = kwargs.get("missing_mask")
-        num_classes = kwargs.get("num_classes", 3)
-        testing = kwargs.get("testing", False)
-
-        y_pred = Scorers.check_if_tuple(y_pred)
-
-        nn = NeuralNetworkMethods()
-
-        y_true_masked = y_true[missing_mask]
-        y_pred_masked = y_pred[missing_mask]
-
-        pr_ap = Scorers.compute_pr(
-            y_true_masked, y_pred_masked, num_classes=num_classes
-        )
-
-        return pr_ap["samples"]
-
-    @staticmethod
-    def f1_samples(y_true, y_pred, **kwargs):
-        """Get F1 score with samples averaging for grid search.
+            Dict[str, float]: Dictionary of evaluation metrics. Keys are 'accuracy', 'f1', 'precision', 'recall', 'roc_auc', 'average_precision', and 'pr_macro'.
 
-        If provided, only calculates score where missing_mask is True (i.e., data were missing). This is so that users can simulate missing data for known values, and then the predictions for only those known values can be evaluated.
-
-        Args:
-            y_true (numpy.ndarray): 012-encoded true target values.
-
-            y_pred (tensorflow.EagerTensor): Predictions from model as probabilities.
-
-            kwargs (Any): Keyword arguments to use with scorer. Supported options include ``missing_mask`` and ``testing``\.
-
-        Returns:
-            float: Metric score by comparing y_true and y_pred.
+        Raises:
+            ValueError: If the input data is invalid.
+            ValueError: If an invalid tune_metric is provided.
         """
-        # Get missing mask if provided.
-        # Otherwise default is all missing values (array all True).
-        missing_mask = kwargs.get("missing_mask")
-        num_classes = kwargs.get("num_classes", 3)
-
-        y_pred = Scorers.check_if_tuple(y_pred)
-
-        y_true_masked = y_true[missing_mask]
-        y_pred_masked = y_pred[missing_mask]
-
-        pr_ap = Scorers.compute_pr(
-            y_true_masked, y_pred_masked, num_classes=num_classes
+        y_true = np.asarray(validate_input_type(y_true, return_type="array"))
+        y_pred = np.asarray(validate_input_type(y_pred, return_type="array"))
+        y_true_ohe = np.asarray(validate_input_type(y_true_ohe, return_type="array"))
+        y_pred_proba = np.asarray(
+            validate_input_type(y_pred_proba, return_type="array")
         )
 
-        return pr_ap["f1_score"]
+        if not y_true.ndim < 3:
+            msg = "y_true must have 1 or 2 dimensions."
+            self.logger.error(msg)
+            raise ValueError(msg)
+
+        if not y_pred.ndim < 3:
+            msg = "y_pred must have 1 or 2 dimensions."
+            self.logger.error(msg)
+            raise ValueError(msg)
+
+        if not y_true_ohe.ndim == 2:
+            msg = "y_true_ohe must have 2 dimensions."
+            self.logger.error(msg)
+            raise ValueError(msg)
+
+        if y_pred_proba.ndim != 2:
+            y_pred_proba = y_pred_proba.reshape(-1, y_true_ohe.shape[-1])
+            self.logger.debug(f"Reshaped y_pred_proba to {y_pred_proba.shape}")
+
+        if objective_mode:
+            if tune_metric == "pr_macro":
+                metrics = {"pr_macro": self.pr_macro(y_true_ohe, y_pred_proba)}
+            elif tune_metric == "roc_auc":
+                metrics = {"roc_auc": self.roc_auc(y_true, y_pred_proba)}
+            elif tune_metric == "average_precision":
+                metrics = {
+                    "average_precision": self.average_precision(y_true, y_pred_proba)
+                }
+            elif tune_metric == "accuracy":
+                metrics = {"accuracy": self.accuracy(y_true, y_pred)}
+            elif tune_metric == "f1":
+                metrics = {"f1": self.f1(y_true, y_pred)}
+            elif tune_metric == "precision":
+                metrics = {"precision": self.precision(y_true, y_pred)}
+            elif tune_metric == "recall":
+                metrics = {"recall": self.recall(y_true, y_pred)}
+            else:
+                msg = f"Invalid tune_metric provided: '{tune_metric}'."
+                self.logger.error(msg)
+                raise ValueError(msg)
+        else:
+            metrics = {
+                "accuracy": self.accuracy(y_true, y_pred),
+                "f1": self.f1(y_true, y_pred),
+                "precision": self.precision(y_true, y_pred),
+                "recall": self.recall(y_true, y_pred),
+                "roc_auc": self.roc_auc(y_true, y_pred_proba),
+                "average_precision": self.average_precision(y_true, y_pred_proba),
+                "pr_macro": self.pr_macro(y_true_ohe, y_pred_proba),
+            }
 
-    @staticmethod
-    def pr_micro(y_true, y_pred, **kwargs):
-        """Get Precision-Recall score with micro averaging for grid search.
+        return {k: float(v) for k, v in metrics.items()}
 
-        If provided, only calculates score where missing_mask is True (i.e., data were missing). This is so that users can simulate missing data for known values, and then the predictions for only those known values can be evaluated.
+    def average_precision(self, y_true: np.ndarray, y_pred_proba: np.ndarray) -> float:
+        """Average precision with safe multiclass handling.
 
-        Args:
-            y_true (numpy.ndarray): 012-encoded true target values.
+        If y_true is 1D of class indices, it is binarized against the number of columns in y_pred_proba. If y_true is already one-hot or indicator, it is used as-is.
 
-            y_pred (tensorflow.EagerTensor): Predictions from model as probabilities.
-
-            kwargs (Any): Keyword arguments to use with scorer. Supported options include ``missing_mask`` and ``testing``\.
+        Args:
+            y_true (np.ndarray): Ground truth labels (1D class indices or 2D one-hot/indicator).
+            y_pred_proba (np.ndarray): Predicted probabilities (2D array).
 
         Returns:
-            float: Metric score by comparing y_true and y_pred.
+            float: Average precision score.
         """
-        # Get missing mask if provided.
-        # Otherwise default is all missing values (array all True).
-        missing_mask = kwargs.get("missing_mask")
-        num_classes = kwargs.get("num_classes", 3)
-        testing = kwargs.get("testing", False)
+        y_true_arr = np.asarray(y_true)
+        y_proba_arr = np.asarray(y_pred_proba)
 
-        y_pred = Scorers.check_if_tuple(y_pred)
+        if y_proba_arr.ndim == 3:
+            y_proba_arr = y_proba_arr.reshape(-1, y_proba_arr.shape[-1])
 
-        nn = NeuralNetworkMethods()
-
-        y_true_masked = y_true[missing_mask]
-        y_pred_masked = y_pred[missing_mask]
-
-        pr_ap = Scorers.compute_pr(
-            y_true_masked, y_pred_masked, num_classes=num_classes
-        )
+        # If y_true already matches proba columns (one-hot / indicator)
+        if y_true_arr.ndim == 2 and y_true_arr.shape[1] == y_proba_arr.shape[1]:
+            y_bin = y_true_arr
+        else:
+            # Interpret y_true as class indices
+            n_classes = y_proba_arr.shape[1]
+            y_bin = label_binarize(y_true_arr.ravel(), classes=np.arange(n_classes))
 
-        return pr_ap["micro"]
+        return float(average_precision_score(y_bin, y_proba_arr, average=self.average))
 
-    @classmethod
-    def make_multimetric_scorer(
-        cls, metrics, missing_mask, num_classes=3, testing=False
-    ):
-        """Get all scoring metrics and make an sklearn scorer.
+    def pr_macro(self, y_true_ohe: np.ndarray, y_pred_proba: np.ndarray) -> float:
+        """Macro-averaged average precision (precision-recall AUC) across classes.
 
         Args:
-            metrics (str or List[str]): Metrics to use with grid search. If string, it will be converted to a list of one element.
-
-            missing_mask (numpy.ndarray): Missing mask to use to demarcate values to use with scoring.
+            y_true_ohe (np.ndarray): One-hot encoded ground truth labels (2D array).
+            y_pred_proba (np.ndarray): Predicted probabilities (2D array).
 
-            num_classes (int, optional): How many classes to use. Defaults to 3.
-
-            testing (bool, optional): True if in test mode, wherein it prints y_true and y_pred_decoded as 1D lists for comparison. Otherwise False. Defaults to False.
         Returns:
-            Dict[str, Callable]: Dictionary with callable scoring functions to use with grid search as the values.
-
-        Raises:
-            ValueError: Invalid scoring metric provided.
+            float: Macro-averaged average precision score.
         """
-        if isinstance(metrics, str):
-            metrics = [metrics]
-
-        scorers = dict()
-        for item in metrics:
-            if item.lower() == "accuracy":
-                scorers["accuracy"] = make_scorer(
-                    cls.accuracy_scorer,
-                    missing_mask=missing_mask,
-                    num_classes=num_classes,
-                    testing=testing,
-                )
-            elif item.lower() == "hamming":
-                scorers["hamming"] = make_scorer(
-                    cls.hamming_scorer,
-                    missing_mask=missing_mask,
-                    num_classes=num_classes,
-                    testing=testing,
-                    greater_is_better=False,
-                )
-            elif item.lower() == "auc_macro":
-                scorers["auc_macro"] = make_scorer(
-                    cls.auc_macro,
-                    missing_mask=missing_mask,
-                    num_classes=num_classes,
-                    testing=testing,
-                )
-            elif item.lower() == "auc_micro":
-                scorers["auc_micro"] = make_scorer(
-                    cls.auc_micro,
-                    missing_mask=missing_mask,
-                    num_classes=num_classes,
-                    testing=testing,
-                )
-            elif item.lower() == "precision_recall_macro":
-                scorers["precision_recall_macro"] = make_scorer(
-                    cls.pr_macro,
-                    missing_mask=missing_mask,
-                    num_classes=num_classes,
-                    testing=testing,
-                )
-            elif item.lower() == "precision_recall_micro":
-                scorers["precision_recall_micro"] = make_scorer(
-                    cls.pr_micro,
-                    missing_mask=missing_mask,
-                    num_classes=num_classes,
-                    testing=testing,
-                )
-            elif item.lower() == "precision_recall_samples":
-                scorers["precision_recall_samples"] = make_scorer(
-                    cls.pr_samples,
-                    missing_mask=missing_mask,
-                    num_classes=num_classes,
-                    testing=testing,
-                )
-            elif item.lower() == "f1_score":
-                scorers["f1_score"] = make_scorer(
-                    cls.f1_samples,
-                    missing_mask=missing_mask,
-                    num_classes=num_classes,
-                    testing=testing,
-                )
-            else:
-                raise ValueError(f"Invalid scoring_metric provided: {item}")
-        return scorers
-
-    @staticmethod
-    def scorer(y_true, y_pred, **kwargs):
-        # Get missing mask if provided.
-        # Otherwise default is all missing values (array all True).
-        missing_mask = kwargs.get("missing_mask")
-        nn_model = kwargs.get("nn_model", True)
-        num_classes = kwargs.get("num_classes", 3)
-        testing = kwargs.get("testing", False)
-
-        is_multiclass = True if num_classes != 4 else False
-
-        if nn_model:
-            nn = NeuralNetworkMethods()
-
-        # VAE has tuple output.
-        if isinstance(y_pred, tuple):
-            y_pred = y_pred[0]
-
-        y_true_masked = y_true[missing_mask]
-        y_pred_masked = y_pred[missing_mask]
-
-        roc_auc = Scorers.compute_roc_auc_micro_macro(
-            y_true_masked,
-            y_pred_masked,
-            num_classes=num_classes,
-            binarize_pred=False,
-        )
+        y_true_arr = np.asarray(y_true_ohe)
+        y_proba_arr = np.asarray(y_pred_proba)
 
-        pr_ap = Scorers.compute_pr(
-            y_true_masked,
-            y_pred_masked,
-            num_classes=num_classes,
-        )
+        if y_proba_arr.ndim == 3:
+            y_proba_arr = y_proba_arr.reshape(-1, y_proba_arr.shape[-1])
 
-        acc = accuracy_score(
-            y_true_masked,
-            nn.decode_masked(
-                y_true_masked,
-                y_pred_masked,
-                is_multiclass=is_multiclass,
-                return_int=True,
-            ),
-        )
-        ham = hamming_loss(
-            y_true_masked,
-            nn.decode_masked(
-                y_true_masked,
-                y_pred_masked,
-                is_multiclass=is_multiclass,
-                return_int=True,
-            ),
-        )
-
-        if testing:
-            y_pred_masked_decoded = nn.decode_masked(
-                y_true_masked,
-                y_pred_masked,
-                is_multiclass=is_multiclass,
-                return_int=True,
-            )
+        # Ensure 2D indicator truth
+        if y_true_arr.ndim == 1:
+            n_classes = y_proba_arr.shape[1]
+            y_true_arr = label_binarize(y_true_arr, classes=np.arange(n_classes))
 
-            bin_mapping = [np.array2string(x) for x in y_pred_masked]
-
-            with open("genotype_dist.csv", "w") as fout:
-                fout.write(
-                    "site,prob_vector,imputed_genotype,expected_genotype\n"
-                )
-                for i, (yt, yp, ypd) in enumerate(
-                    zip(y_true_masked, bin_mapping, y_pred_masked_decoded)
-                ):
-                    fout.write(f"{i},{yp},{ypd},{yt}\n")
-            # np.set_printoptions(threshold=np.inf)
-            # print(y_true_masked)
-            # print(y_pred_masked_decoded)
-
-        metrics = dict()
-        metrics["accuracy"] = acc
-        metrics["roc_auc"] = roc_auc
-        metrics["precision_recall"] = pr_ap
-        metrics["hamming"] = ham
-
-        return metrics
+        return float(average_precision_score(y_true_arr, y_proba_arr, average="macro"))