pg-sui 1.0.2.1__py3-none-any.whl → 1.6.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pg-sui might be problematic.

Files changed (112)
  1. {pg_sui-1.0.2.1.dist-info → pg_sui-1.6.8.dist-info}/METADATA +51 -70
  2. pg_sui-1.6.8.dist-info/RECORD +78 -0
  3. {pg_sui-1.0.2.1.dist-info → pg_sui-1.6.8.dist-info}/WHEEL +1 -1
  4. pg_sui-1.6.8.dist-info/entry_points.txt +4 -0
  5. pg_sui-1.6.8.dist-info/top_level.txt +1 -0
  6. pgsui/__init__.py +35 -54
  7. pgsui/_version.py +34 -0
  8. pgsui/cli.py +635 -0
  9. pgsui/data_processing/config.py +576 -0
  10. pgsui/data_processing/containers.py +1782 -0
  11. pgsui/data_processing/transformers.py +121 -1103
  12. pgsui/electron/app/__main__.py +5 -0
  13. pgsui/electron/app/icons/icons/1024x1024.png +0 -0
  14. pgsui/electron/app/icons/icons/128x128.png +0 -0
  15. pgsui/electron/app/icons/icons/16x16.png +0 -0
  16. pgsui/electron/app/icons/icons/24x24.png +0 -0
  17. pgsui/electron/app/icons/icons/256x256.png +0 -0
  18. pgsui/electron/app/icons/icons/32x32.png +0 -0
  19. pgsui/electron/app/icons/icons/48x48.png +0 -0
  20. pgsui/electron/app/icons/icons/512x512.png +0 -0
  21. pgsui/electron/app/icons/icons/64x64.png +0 -0
  22. pgsui/electron/app/icons/icons/icon.icns +0 -0
  23. pgsui/electron/app/icons/icons/icon.ico +0 -0
  24. pgsui/electron/app/main.js +189 -0
  25. pgsui/electron/app/package-lock.json +6893 -0
  26. pgsui/electron/app/package.json +50 -0
  27. pgsui/electron/app/preload.js +15 -0
  28. pgsui/electron/app/server.py +146 -0
  29. pgsui/electron/app/ui/logo.png +0 -0
  30. pgsui/electron/app/ui/renderer.js +130 -0
  31. pgsui/electron/app/ui/styles.css +59 -0
  32. pgsui/electron/app/ui/ui_shim.js +72 -0
  33. pgsui/electron/bootstrap.py +43 -0
  34. pgsui/electron/launch.py +59 -0
  35. pgsui/electron/package.json +14 -0
  36. pgsui/example_data/popmaps/{test.popmap → phylogen_nomx.popmap} +185 -99
  37. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz +0 -0
  38. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz.tbi +0 -0
  39. pgsui/impute/deterministic/imputers/allele_freq.py +691 -0
  40. pgsui/impute/deterministic/imputers/mode.py +679 -0
  41. pgsui/impute/deterministic/imputers/nmf.py +221 -0
  42. pgsui/impute/deterministic/imputers/phylo.py +971 -0
  43. pgsui/impute/deterministic/imputers/ref_allele.py +530 -0
  44. pgsui/impute/supervised/base.py +339 -0
  45. pgsui/impute/supervised/imputers/hist_gradient_boosting.py +293 -0
  46. pgsui/impute/supervised/imputers/random_forest.py +287 -0
  47. pgsui/impute/unsupervised/base.py +924 -0
  48. pgsui/impute/unsupervised/callbacks.py +89 -263
  49. pgsui/impute/unsupervised/imputers/autoencoder.py +972 -0
  50. pgsui/impute/unsupervised/imputers/nlpca.py +1264 -0
  51. pgsui/impute/unsupervised/imputers/ubp.py +1288 -0
  52. pgsui/impute/unsupervised/imputers/vae.py +957 -0
  53. pgsui/impute/unsupervised/loss_functions.py +158 -0
  54. pgsui/impute/unsupervised/models/autoencoder_model.py +208 -558
  55. pgsui/impute/unsupervised/models/nlpca_model.py +149 -468
  56. pgsui/impute/unsupervised/models/ubp_model.py +198 -1317
  57. pgsui/impute/unsupervised/models/vae_model.py +259 -618
  58. pgsui/impute/unsupervised/nn_scorers.py +215 -0
  59. pgsui/utils/classification_viz.py +591 -0
  60. pgsui/utils/misc.py +35 -480
  61. pgsui/utils/plotting.py +514 -824
  62. pgsui/utils/scorers.py +212 -438
  63. pg_sui-1.0.2.1.dist-info/RECORD +0 -75
  64. pg_sui-1.0.2.1.dist-info/top_level.txt +0 -3
  65. pgsui/example_data/phylip_files/test_n10.phy +0 -118
  66. pgsui/example_data/phylip_files/test_n100.phy +0 -118
  67. pgsui/example_data/phylip_files/test_n2.phy +0 -118
  68. pgsui/example_data/phylip_files/test_n500.phy +0 -118
  69. pgsui/example_data/structure_files/test.nopops.1row.10sites.str +0 -117
  70. pgsui/example_data/structure_files/test.nopops.2row.100sites.str +0 -234
  71. pgsui/example_data/structure_files/test.nopops.2row.10sites.str +0 -234
  72. pgsui/example_data/structure_files/test.nopops.2row.30sites.str +0 -234
  73. pgsui/example_data/structure_files/test.nopops.2row.allsites.str +0 -234
  74. pgsui/example_data/structure_files/test.pops.1row.10sites.str +0 -117
  75. pgsui/example_data/structure_files/test.pops.2row.10sites.str +0 -234
  76. pgsui/example_data/trees/test.iqtree +0 -376
  77. pgsui/example_data/trees/test.qmat +0 -5
  78. pgsui/example_data/trees/test.rate +0 -2033
  79. pgsui/example_data/trees/test.tre +0 -1
  80. pgsui/example_data/trees/test_n10.rate +0 -19
  81. pgsui/example_data/trees/test_n100.rate +0 -109
  82. pgsui/example_data/trees/test_n500.rate +0 -509
  83. pgsui/example_data/trees/test_siterates.txt +0 -2024
  84. pgsui/example_data/trees/test_siterates_n10.txt +0 -10
  85. pgsui/example_data/trees/test_siterates_n100.txt +0 -100
  86. pgsui/example_data/trees/test_siterates_n500.txt +0 -500
  87. pgsui/example_data/vcf_files/test.vcf +0 -244
  88. pgsui/example_data/vcf_files/test.vcf.gz +0 -0
  89. pgsui/example_data/vcf_files/test.vcf.gz.tbi +0 -0
  90. pgsui/impute/estimators.py +0 -735
  91. pgsui/impute/impute.py +0 -1486
  92. pgsui/impute/simple_imputers.py +0 -1439
  93. pgsui/impute/supervised/iterative_imputer_fixedparams.py +0 -785
  94. pgsui/impute/supervised/iterative_imputer_gridsearch.py +0 -1027
  95. pgsui/impute/unsupervised/keras_classifiers.py +0 -702
  96. pgsui/impute/unsupervised/models/in_development/cnn_model.py +0 -486
  97. pgsui/impute/unsupervised/neural_network_imputers.py +0 -1424
  98. pgsui/impute/unsupervised/neural_network_methods.py +0 -1549
  99. pgsui/pg_sui.py +0 -261
  100. pgsui/utils/sequence_tools.py +0 -407
  101. simulation/sim_benchmarks.py +0 -333
  102. simulation/sim_treeparams.py +0 -475
  103. test/__init__.py +0 -0
  104. test/pg_sui_simtest.py +0 -215
  105. test/pg_sui_testing.py +0 -523
  106. test/test.py +0 -297
  107. test/test_pgsui.py +0 -374
  108. test/test_tkc.py +0 -214
  109. {pg_sui-1.0.2.1.dist-info → pg_sui-1.6.8.dist-info/licenses}/LICENSE +0 -0
  110. /pgsui/{example_data/trees → electron/app}/__init__.py +0 -0
  111. /pgsui/impute/{unsupervised/models/in_development → supervised/imputers}/__init__.py +0 -0
  112. {simulation → pgsui/impute/unsupervised/imputers}/__init__.py +0 -0
pgsui/utils/scorers.py CHANGED
@@ -1,508 +1,282 @@
-import sys
+from typing import Dict, Literal
 
 import numpy as np
-
 from sklearn.metrics import (
-    roc_curve,
-    auc,
     accuracy_score,
-    hamming_loss,
-    make_scorer,
-    precision_recall_curve,
     average_precision_score,
-    multilabel_confusion_matrix,
     f1_score,
+    precision_score,
+    recall_score,
     roc_auc_score,
 )
-
 from sklearn.preprocessing import label_binarize
-
-try:
-    from ..impute.unsupervised.neural_network_methods import (
-        NeuralNetworkMethods,
-    )
-except (ModuleNotFoundError, ValueError, ImportError):
-    from impute.unsupervised.neural_network_methods import NeuralNetworkMethods
+from snpio.utils.logging import LoggerManager
+from torch import Tensor


-class Scorers:
-    @staticmethod
-    def compute_roc_auc_micro_macro(
-        y_true, y_pred, num_classes=3, binarize_pred=True
-    ):
-        """Compute ROC curve with AUC scores.
+class Scorer:
+    """Class for evaluating the performance of a model using various metrics.

-        ROC (Receiver Operating Characteristic) curves and AUC (area under curve) scores are computed per-class and for micro and macro averages.
+    This class is used to evaluate the performance of a model using various metrics, such as accuracy, F1 score, precision, recall, average precision, and ROC AUC. The class can be used to evaluate the performance of a model on a dataset with ground truth labels. The class can also be used to evaluate the performance of a model in objective mode for hyperparameter tuning.
+    """

-        Args:
-            y_true (numpy.ndarray): Ravelled numpy array of shape (n_samples * n_features,). y_true should be integer-encoded.
-
-            y_pred (numpy.ndarray): Ravelled numpy array of shape (n_samples * n_features,). y_pred should be probabilities.
+    def __init__(
+        self,
+        prefix: str,
+        average: Literal["micro", "macro", "weighted"] = "macro",
+        verbose: bool = False,
+        debug: bool = False,
+    ) -> None:
+        """Initialize a Scorer object.

-            num_classes (int, optional): How many classes to use. Defaults to 3.
+        This class is used to evaluate the performance of a model using various metrics, such as accuracy, F1 score, precision, recall, average precision, and ROC AUC. The class can be used to evaluate the performance of a model on a dataset with ground truth labels. The class can also be used to evaluate the performance of a model in objective mode for hyperparameter tuning.

-            binarize_pred (bool, optional): Whether to binarize y_pred. If False, y_pred should be probabilities of each class. Defaults to True.
+        Args:
+            prefix (str): Prefix for logging messages.
+            average (Literal["micro", "macro", "weighted"]): Average method for metrics. Must be one of 'micro', 'macro', or 'weighted'.
+            verbose (bool): Verbosity level for logging messages. Default is False.
+            debug (bool): Debug mode for logging messages. Default is False.

-        Returns:
-            Dict[str, Any]: Dictionary with true and false positive rates along probability threshold curve per class, micro and macro tpr and fpr curves averaged across classes, and AUC scores per-class and for micro and macro averages.
+        Raises:
+            ValueError: If the average parameter is invalid. Must be one of 'micro', 'macro', or 'weighted'.
         """
-        cats = range(num_classes)
-
-        # Get only classes that appear in y_true.
-        classes = [i for i in cats if i in y_true]
-
-        # Binarize the output for use with ROC-AUC.
-        y_true_bin = label_binarize(y_true, classes=cats)
-
-        if binarize_pred:
-            y_pred_bin = label_binarize(y_pred, classes=cats)
-        else:
-            y_pred_bin = y_pred
-
-        for i in range(y_true_bin.shape[1]):
-            if i not in classes:
-                y_true_bin = np.delete(y_true_bin, i, axis=-1)
-                y_pred_bin = np.delete(y_pred_bin, i, axis=-1)
-
-        n_classes = len(classes)
-
-        # Compute ROC curve and ROC area for each class.
-        fpr = dict()
-        tpr = dict()
-        roc_auc = dict()
-        for i, c in enumerate(classes):
-            fpr[c], tpr[c], _ = roc_curve(y_true_bin[:, i], y_pred_bin[:, i])
-            roc_auc[c] = auc(fpr[c], tpr[c])
-
-        # Compute micro-average ROC curve and ROC area.
-        fpr["micro"], tpr["micro"], _ = roc_curve(
-            y_true_bin.ravel(), y_pred_bin.ravel()
+        logman = LoggerManager(
+            name=__name__, prefix=prefix, debug=debug, verbose=verbose >= 1
         )
+        self.logger = logman.get_logger()

-        roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])
-
-        # Aggregate all false positive rates
-        all_fpr = np.unique(np.concatenate([fpr[i] for i in classes]))
+        if average not in {"micro", "macro", "weighted"}:
+            msg = f"Invalid average parameter: {average}. Must be one of 'micro', 'macro', or 'weighted'."
+            self.logger.error(msg)
+            raise ValueError(msg)

-        # Then interpolate all ROC curves at these points.
-        mean_tpr = np.zeros_like(all_fpr)
-        for i in classes:
-            mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])
+        self.average = average

-        # Finally, average it and compute AUC.
-        mean_tpr /= n_classes
+    def accuracy(self, y_true: np.ndarray, y_pred: np.ndarray) -> float:
+        """Calculate the accuracy of the model.

-        fpr["macro"] = all_fpr
-        tpr["macro"] = mean_tpr
+        This method calculates the accuracy of the model by comparing the ground truth labels with the predicted labels.

-        roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])
-
-        roc_auc["fpr_macro"] = fpr["macro"]
-        roc_auc["tpr_macro"] = tpr["macro"]
-        roc_auc["fpr_micro"] = fpr["micro"]
-        roc_auc["tpr_micro"] = tpr["micro"]
-
-        for i in classes:
-            roc_auc[f"fpr_{i}"] = fpr[i]
-            roc_auc[f"tpr_{i}"] = tpr[i]
+        Args:
+            y_true (np.ndarray): Ground truth labels.
+            y_pred (np.ndarray): Predicted labels.

-        return roc_auc
+        Returns:
+            float: Accuracy score.
+        """
+        return accuracy_score(y_true, y_pred)

-    @staticmethod
-    def compute_pr(y_true, y_pred, use_int_encodings=False, num_classes=4):
-        """Compute precision-recall curve with Average Precision scores.
+    def f1(self, y_true: np.ndarray, y_pred: np.ndarray) -> float:
+        """Calculate the F1 score of the model.

-        PR and AP scores are computed per-class and for micro and macro averages.
+        This method calculates the F1 score of the model by comparing the ground truth labels with the predicted labels.

         Args:
-            y_true (numpy.ndarray): Ravelled numpy array of shape (n_samples * n_features,).
-
-            y_pred (numpy.ndarray): Ravelled numpy array of shape (n_samples * n_features,). y_pred should be integer-encoded.
+            y_true (np.ndarray): Ground truth labels.
+            y_pred (np.ndarray): Predicted labels.

-            use_int_encodings (bool, optional): Whether the imputer model is a neural network model. Defaults to False.
-
-            num_classes (int, optional): How many classes to use. Defaults to 3.
-
-        Returns:
-            Dict[str, Any]: Dictionary with precision and recall curves per class and micro and macro averaged across classes, plus AP scores per-class and for micro and macro averages.
+        Returns:
+            float: F1 score.
         """
-        cats = range(num_classes)
+        return f1_score(y_true, y_pred, average=self.average, zero_division=0.0)

-        is_multiclass = True if num_classes != 4 else False
+    def precision(self, y_true: np.ndarray, y_pred: np.ndarray) -> float:
+        """Calculate the precision of the model.

-        # Get only classes that appear in y_true.
-        classes = [i for i in cats if i in y_true]
+        This method calculates the precision of the model by comparing the ground truth labels with the predicted labels.

-        # Binarize the output for use with ROC-AUC.
-        y_true_bin = label_binarize(y_true, classes=cats)
-        y_pred_proba_bin = y_pred
-
-        if is_multiclass:
-            for i in range(y_true_bin.shape[1]):
-                if i not in classes:
-                    y_true_bin = np.delete(y_true_bin, i, axis=-1)
-                    y_pred_proba_bin = np.delete(y_pred_proba_bin, i, axis=-1)
-
-        nn = NeuralNetworkMethods()
-        if len(y_true.shape) == 1 or y_true.shape[1] != num_classes:
-            y_true = label_binarize(y_true, classes=cats)
+        Args:
+            y_true (np.ndarray): Ground truth labels.
+            y_pred (np.ndarray): Predicted labels.

-        # Ensure y_pred_012 is in the multilabel format
-        if use_int_encodings:
-            y_pred_012 = nn.decode_masked(
-                y_true,
-                y_pred_proba_bin,
-                return_multilab=True,  # Ensure multilabel format is returned
-            )
-        else:
-            y_pred_012 = nn.decode_masked(
-                y_true,
-                y_pred_proba_bin,
-                is_multiclass=False,
-                return_int=False,
-                return_multilab=True,
-            )
+        Returns:
+            float: Precision score.
+        """
+        return precision_score(y_true, y_pred, average=self.average, zero_division=0.0)

-        # Make confusion matrix to get true negatives and true positives.
-        mcm = multilabel_confusion_matrix(y_true, y_pred_012)
+    def recall(self, y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray:
+        """Calculate the recall of the model.

-        tn = np.sum(mcm[:, 0, 0])
-        tn /= num_classes
+        This method calculates the recall of the model by comparing the ground truth labels with the predicted labels.

-        tp = np.sum(mcm[:, 1, 1])
-        tp /= num_classes
+        Args:
+            y_true (np.ndarray): Ground truth labels.
+            y_pred (np.ndarray): Predicted labels.

-        baseline = tp / (tn + tp)
+        Returns:
+            float: Recall score.
+        """
+        return recall_score(y_true, y_pred, average=self.average, zero_division=0.0)

-        precision = dict()
-        recall = dict()
-        average_precision = dict()
+    def roc_auc(self, y_true: np.ndarray, y_pred_proba: np.ndarray) -> float:
+        """Multiclass ROC-AUC with label targets.

-        for i, c in enumerate(classes):
-            precision[c], recall[c], _ = precision_recall_curve(
-                y_true_bin[:, i], y_pred_proba_bin[:, i]
-            )
-            average_precision[c] = average_precision_score(
-                y_true_bin[:, i], y_pred_proba_bin[:, i]
-            )
+        This method calculates the ROC-AUC score for multiclass classification problems. It handles both 1D integer labels and 2D one-hot/indicator matrices for the ground truth labels.

-        # A "micro-average": quantifying score on all classes jointly.
-        precision["micro"], recall["micro"], _ = precision_recall_curve(
-            y_true_bin.ravel(), y_pred_proba_bin.ravel()
-        )
+        Args:
+            y_true: 1D integer labels (shape: [n]).
+                If a one-hot/indicator matrix is supplied, we convert to labels.
+            y_pred_proba: 2D probabilities (shape: [n, n_classes]).
+        """
+        y_true = np.asarray(y_true)
+        y_pred_proba = np.asarray(y_pred_proba)

-        average_precision["micro"] = average_precision_score(
-            y_true_bin, y_pred_proba_bin, average="micro"
-        )
+        if y_pred_proba.ndim == 3:
+            y_pred_proba = y_pred_proba.reshape(-1, y_pred_proba.shape[-1])

-        average_precision["macro"] = average_precision_score(
-            y_true_bin, y_pred_proba_bin, average="macro"
-        )
+        # If user passed indicator/one-hot, convert to labels.
+        if y_true.ndim == 2 and y_true.shape[1] == y_pred_proba.shape[1]:
+            y_true = y_true.argmax(axis=1)

-        average_precision["weighted"] = average_precision_score(
-            y_true_bin, y_pred_proba_bin, average="weighted"
-        )
+        # Guard: need >1 class present for AUC
+        if np.unique(y_true).size < 2:
+            return 0.5

-        if use_int_encodings:
-            y_pred_012 = (
-                nn.decode_masked(
-                    y_true_bin,
-                    y_pred_proba_bin,
-                    return_multilab=True,
-                    predict_still_missing=False,
-                ),
+        return float(
+            roc_auc_score(
+                y_true,
+                y_pred_proba,
+                multi_class="ovr",
+                average=self.average,
             )
+        )

-        f1 = f1_score(y_true_bin, y_pred_012, average="macro")
-        f1_weighted = f1_score(y_true_bin, y_pred_012, average="weighted")
-
-        # Aggregate all recalls
-        all_recall = np.unique(np.concatenate([recall[i] for i in classes]))
-
-        # Then interpolate all PR curves at these points.
-        mean_precision = np.zeros_like(all_recall)
-        for i in classes:
-            mean_precision += np.interp(all_recall, precision[i], recall[i])
-
-        # Finally, average it and compute AUC.
-        mean_precision /= num_classes
-
-        recall["macro"] = all_recall
-        precision["macro"] = mean_precision
-
-        results = dict()
-
-        results["micro"] = average_precision["micro"]
-        results["macro"] = average_precision["macro"]
-        results["f1_score"] = f1
-        results["f1_score_weighted"] = f1_weighted
-        results["recall_macro"] = all_recall
-        results["precision_macro"] = mean_precision
-        results["recall_micro"] = recall["micro"]
-        results["precision_micro"] = precision["micro"]
-
-        for i in classes:
-            results[f"recall_{i}"] = recall[i]
-            results[f"precision_{i}"] = precision[i]
-            results[i] = average_precision[i]
-        results["baseline"] = baseline
-
-        return results
-
-    @staticmethod
-    def check_if_tuple(y_pred):
-        """Checks if y_pred is a tuple and if so, returns the first element of the tuple."""
-        if isinstance(y_pred, tuple):
-            y_pred = y_pred[0]
-        return y_pred
-
-    @staticmethod
-    def accuracy_scorer(y_true, y_pred, **kwargs):
-        """Get accuracy score for grid search.
-
-        If provided, only calculates score where missing_mask is True (i.e., data were missing). This is so that users can simulate missing data for known values, and then the predictions for only those known values can be evaluated.
+    def evaluate(
+        self,
+        y_true: np.ndarray | Tensor | list,
+        y_pred: np.ndarray | Tensor | list,
+        y_true_ohe: np.ndarray | Tensor | list,
+        y_pred_proba: np.ndarray | Tensor | list,
+        objective_mode: bool = False,
+        tune_metric: Literal[
+            "pr_macro",
+            "roc_auc",
+            "average_precision",
+            "accuracy",
+            "f1",
+            "precision",
+            "recall",
+        ] = "pr_macro",
+    ) -> Dict[str, float]:
+        """Evaluate the model using various metrics.
+
+        This method evaluates the performance of a model using various metrics, such as accuracy, F1 score, precision, recall, average precision, and ROC AUC. The method can be used to evaluate the performance of a model on a dataset with ground truth labels. The method can also be used to evaluate the performance of a model in objective mode for hyperparameter tuning.

         Args:
-            y_true (numpy.ndarray): 012-encoded true target values.
-
-            y_pred (tensorflow.EagerTensor): Predictions from model as probabilities. They must first be decoded to use with accuracy_score.
-
-            kwargs (Any): Keyword arguments to use with scorer. Supported options include ``missing_mask`` and ``testing``\.
+            y_true (np.ndarray | torch.Tensor): Ground truth labels.
+            y_pred (np.ndarray | torch.Tensor): Predicted labels.
+            y_true_ohe (np.ndarray | torch.Tensor): One-hot encoded ground truth labels.
+            y_pred_proba (np.ndarray | torch.Tensor): Predicted probabilities.
+            objective_mode (bool): Whether to use objective mode for evaluation. Default is False.
+            tune_metric (Literal["pr_macro", "roc_auc", "average_precision", "accuracy", "f1", "precision", "recall"]): Metric to use for tuning. Ignored if `objective_mode` is False. Default is 'pr_macro'.

         Returns:
-            float: Metric score by comparing y_true and y_pred.
-        """
-        missing_mask = kwargs.get("missing_mask")
-
-        y_pred = Scorers.check_if_tuple(y_pred)
-
-        y_true_masked = y_true[missing_mask]
-        y_pred_masked = y_pred[missing_mask]
+            Dict[str, float]: Dictionary of evaluation metrics. Keys are 'accuracy', 'f1', 'precision', 'recall', 'roc_auc', 'average_precision', and 'pr_macro'.

-        nn = NeuralNetworkMethods()
-        y_pred_masked_decoded = nn.decode_masked(
-            y_true_masked, y_pred_masked, predict_still_missing=False
-        )
+        Raises:
+            ValueError: If the input data is invalid.
+            ValueError: If an invalid tune_metric is provided.
+        """
+        if not y_true.ndim < 3:
+            msg = "y_true must have 1 or 2 dimensions."
+            self.logger.error(msg)
+            raise ValueError(msg)
+
+        if not y_pred.ndim < 3:
+            msg = "y_pred must have 1 or 2 dimensions."
+            self.logger.error(msg)
+            raise ValueError(msg)
+
+        if not y_true_ohe.ndim == 2:
+            msg = "y_true_ohe must have 2 dimensions."
+            self.logger.error(msg)
+            raise ValueError(msg)
+
+        if y_pred_proba.ndim != 2:
+            y_pred_proba = y_pred_proba.reshape(-1, y_true_ohe.shape[-1])
+            self.logger.debug(f"Reshaped y_pred_proba to {y_pred_proba.shape}")
+
+        if objective_mode:
+            if tune_metric == "pr_macro":
+                metrics = {"pr_macro": self.pr_macro(y_true_ohe, y_pred_proba)}
+            elif tune_metric == "roc_auc":
+                metrics = {"roc_auc": self.roc_auc(y_true, y_pred_proba)}
+            elif tune_metric == "average_precision":
+                metrics = {
+                    "average_precision": self.average_precision(y_true, y_pred_proba)
+                }
+            elif tune_metric == "accuracy":
+                metrics = {"accuracy": self.accuracy(y_true, y_pred)}
+            elif tune_metric == "f1":
+                metrics = {"f1": self.f1(y_true, y_pred)}
+            elif tune_metric == "precision":
+                metrics = {"precision": self.precision(y_true, y_pred)}
+            elif tune_metric == "recall":
+                metrics = {"recall": self.recall(y_true, y_pred)}
+            else:
+                msg = f"Invalid tune_metric provided: '{tune_metric}'."
+                self.logger.error(msg)
+                raise ValueError(msg)
+        else:
+            metrics = {
+                "accuracy": self.accuracy(y_true, y_pred),
+                "f1": self.f1(y_true, y_pred),
+                "precision": self.precision(y_true, y_pred),
+                "recall": self.recall(y_true, y_pred),
+                "roc_auc": self.roc_auc(y_true, y_pred_proba),
+                "average_precision": self.average_precision(y_true, y_pred_proba),
+                "pr_macro": self.pr_macro(y_true_ohe, y_pred_proba),
+            }

-        return accuracy_score(y_true_masked, y_pred_masked_decoded)
+        return {k: float(v) for k, v in metrics.items()}

-    @staticmethod
-    def hamming_scorer(y_true, y_pred, **kwargs):
-        """Get Hamming score for grid search.
+    def average_precision(self, y_true: np.ndarray, y_pred_proba: np.ndarray) -> float:
+        """Average precision with safe multiclass handling.

-        If provided, only calculates score where missing_mask is True (i.e., data were missing). This is so that users can simulate missing data for known values, and then the predictions for only those known values can be evaluated.
+        If y_true is 1D of class indices, it is binarized against the number of columns in y_pred_proba. If y_true is already one-hot or indicator, it is used as-is.

         Args:
-            y_true (numpy.ndarray): 012-encoded true target values.
-
-            y_pred (tensorflow.EagerTensor): Predictions from model as probabilities. They must first be decoded to use with hamming_scorer.
-
-            kwargs (Any): Keyword arguments to use with scorer. Supported options include ``missing_mask`` and ``testing``\.
+            y_true (np.ndarray): Ground truth labels (1D class indices or 2D one-hot/indicator).
+            y_pred_proba (np.ndarray): Predicted probabilities (2D array).

         Returns:
-            float: Metric score by comparing y_true and y_pred.
+            float: Average precision score.
         """
-        missing_mask = kwargs.get("missing_mask")
-
-        y_pred = Scorers.check_if_tuple(y_pred)
-
-        y_true_masked = y_true[missing_mask]
-        y_pred_masked = y_pred[missing_mask]
-
-        nn = NeuralNetworkMethods()
-        y_pred_masked_decoded = nn.decode_masked(
-            y_true_masked,
-            y_pred_masked,
-            predict_still_missing=False,
-        )
-
-        return hamming_loss(y_true_masked, y_pred_masked_decoded)
-
-    @staticmethod
-    def compute_metric(
-        y_true, y_pred, metric_type, scoring_function, **kwargs
-    ):
-        y_true = np.array(y_true)
-        y_pred = np.array(y_pred)
+        y_true_arr = np.asarray(y_true)
+        y_proba_arr = np.asarray(y_pred_proba)

-        num_classes = kwargs.get("num_classes", 4)
-        cats = range(num_classes)
+        if y_proba_arr.ndim == 3:
+            y_proba_arr = y_proba_arr.reshape(-1, y_proba_arr.shape[-1])

-        y_true_bin = label_binarize(y_true, classes=cats)
-
-        if scoring_function == "roc_auc":
-            return roc_auc_score(
-                y_true_bin, y_pred, multi_class="ovr", average=metric_type
-            )
-        elif scoring_function == "f1":
-            is_multiclass = num_classes != 4
-            y_pred_proba_bin = y_pred
-            nn = NeuralNetworkMethods()
-            y_pred_bin = nn.decode_masked(
-                y_true_bin,
-                y_pred_proba_bin,
-                is_multiclass=is_multiclass,
-                return_int=False,
-                return_multilab=True,
-            )
-            return f1_score(
-                y_true_bin, y_pred_bin, average=metric_type, zero_division=0
-            )
-        elif scoring_function == "average_precision":
-            return average_precision_score(
-                y_true_bin, y_pred, average=metric_type
-            )
+        # If y_true already matches proba columns (one-hot / indicator)
+        if y_true_arr.ndim == 2 and y_true_arr.shape[1] == y_proba_arr.shape[1]:
+            y_bin = y_true_arr
         else:
-            raise ValueError(
-                f"Unsupported scoring function: {scoring_function}"
-            )
-
-    @staticmethod
-    def compute_score(y_true, y_pred, metric_type, scoring_function, **kwargs):
-        missing_mask = kwargs.get("missing_mask")
-        y_pred = Scorers.check_if_tuple(y_pred)
-
-        y_true_masked = y_true[missing_mask]
-        y_pred_masked = y_pred[missing_mask]
-
-        return Scorers.compute_metric(
-            y_true_masked,
-            y_pred_masked,
-            metric_type=metric_type,
-            scoring_function=scoring_function,
-            **kwargs,
-        )
-
-    @classmethod
-    def make_multimetric_scorer(
-        cls, metrics, missing_mask, num_classes=4, testing=False
-    ):
-        if isinstance(metrics, str):
-            metrics = [metrics]
-
-        metric_map = {
-            "roc_auc_macro": ("macro", "roc_auc"),
-            "roc_auc_micro": ("micro", "roc_auc"),
-            "roc_auc_weighted": ("weighted", "roc_auc"),
-            "f1_micro": ("micro", "f1"),
-            "f1_macro": ("macro", "f1"),
-            "f1_weighted": ("weighted", "f1"),
-            "average_precision_macro": ("macro", "average_precision"),
-            "average_precision_micro": ("micro", "average_precision"),
-            "average_precision_weighted": ("weighted", "average_precision"),
-        }
-
-        default_params = {
-            "missing_mask": missing_mask,
-            "num_classes": num_classes,
-            "testing": testing,
-        }
-
-        scorers = dict()
-        for item in metrics:
-            item = item.lower()
-
-            if item in metric_map:
-                metric_type, scoring_function = metric_map[item]
-                params = default_params.copy()
-                scorers[item] = make_scorer(
-                    Scorers.compute_score,
-                    metric_type=metric_type,
-                    scoring_function=scoring_function,
-                    **params,
-                )
-            elif item == "accuracy":
-                scorers[item] = make_scorer(
-                    cls.accuracy_scorer, **default_params
-                )
-            elif item == "hamming":
-                scorers[item] = make_scorer(
-                    cls.hamming_scorer, **default_params
-                )
-            else:
-                raise ValueError(f"Unsupported metric: {item}")
+            # Interpret y_true as class indices
+            n_classes = y_proba_arr.shape[1]
+            y_bin = label_binarize(y_true_arr.ravel(), classes=np.arange(n_classes))

-        return scorers
+        return float(average_precision_score(y_bin, y_proba_arr, average=self.average))

-    @staticmethod
-    def scorer(y_true, y_pred, **kwargs):
-        # Get missing mask if provided.
-        # Otherwise default is all missing values (array all True).
-        missing_mask = kwargs.get("missing_mask")
-        nn_model = kwargs.get("nn_model", True)
-        num_classes = kwargs.get("num_classes", 3)
-        testing = kwargs.get("testing", False)
+    def pr_macro(self, y_true_ohe: np.ndarray, y_pred_proba: np.ndarray) -> float:
+        """Macro-averaged average precision (precision-recall AUC) across classes.

-        is_multiclass = True if num_classes != 4 else False
-
-        if nn_model:
-            nn = NeuralNetworkMethods()
-
-            # VAE has tuple output.
-            if isinstance(y_pred, tuple):
-                y_pred = y_pred[0]
-
-            y_true_masked = y_true[missing_mask]
-            y_pred_masked = y_pred[missing_mask]
-
-            roc_auc = Scorers.compute_roc_auc_micro_macro(
-                y_true_masked,
-                y_pred_masked,
-                num_classes=num_classes,
-                binarize_pred=False,
-            )
+        Args:
+            y_true_ohe (np.ndarray): One-hot encoded ground truth labels (2D array).
+            y_pred_proba (np.ndarray): Predicted probabilities (2D array).

-            pr_ap = Scorers.compute_pr(
-                y_true_masked,
-                y_pred_masked,
-                num_classes=num_classes,
-            )
+        Returns:
+            float: Macro-averaged average precision score.
+        """
+        y_true_arr = np.asarray(y_true_ohe)
+        y_proba_arr = np.asarray(y_pred_proba)

-            acc = accuracy_score(
-                y_true_masked,
-                nn.decode_masked(
-                    y_true_masked,
-                    y_pred_masked,
-                    is_multiclass=is_multiclass,
-                    return_int=True,
-                ),
-            )
-            ham = hamming_loss(
-                y_true_masked,
-                nn.decode_masked(
-                    y_true_masked,
-                    y_pred_masked,
-                    is_multiclass=is_multiclass,
-                    return_int=True,
-                ),
-            )
+        if y_proba_arr.ndim == 3:
+            y_proba_arr = y_proba_arr.reshape(-1, y_proba_arr.shape[-1])

-            if testing:
-                y_pred_masked_decoded = nn.decode_masked(
-                    y_true_masked,
-                    y_pred_masked,
-                    is_multiclass=is_multiclass,
-                    return_int=True,
-                )
+        # Ensure 2D indicator truth
+        if y_true_arr.ndim == 1:
+            n_classes = y_proba_arr.shape[1]
+            y_true_arr = label_binarize(y_true_arr, classes=np.arange(n_classes))

-                bin_mapping = [np.array2string(x) for x in y_pred_masked]
-
-                with open("genotype_dist.csv", "w") as fout:
-                    fout.write(
-                        "site,prob_vector,imputed_genotype,expected_genotype\n"
-                    )
-                    for i, (yt, yp, ypd) in enumerate(
-                        zip(y_true_masked, bin_mapping, y_pred_masked_decoded)
-                    ):
-                        fout.write(f"{i},{yp},{ypd},{yt}\n")
-                # np.set_printoptions(threshold=np.inf)
-                # print(y_true_masked)
-                # print(y_pred_masked_decoded)
-
-            metrics = dict()
-            metrics["accuracy"] = acc
-            metrics["roc_auc"] = roc_auc
-            metrics["precision_recall"] = pr_ap
-            metrics["hamming"] = ham
-
-            return metrics
+        return float(average_precision_score(y_true_arr, y_proba_arr, average="macro"))
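
For orientation, below is a minimal usage sketch of the new pgsui.utils.scorers.Scorer API introduced in this release, written against the class as it appears in the diff above. The random labels and probabilities are illustrative only, and the sketch assumes pg-sui 1.6.8 is installed together with its snpio and torch dependencies (imported at the top of the module).

# Illustrative sketch only -- exercises the Scorer class shown in the diff above.
import numpy as np

from pgsui.utils.scorers import Scorer

rng = np.random.default_rng(42)
n_samples, n_classes = 200, 3

y_true = rng.integers(0, n_classes, size=n_samples)                # integer labels
y_pred_proba = rng.dirichlet(np.ones(n_classes), size=n_samples)   # rows sum to 1
y_pred = y_pred_proba.argmax(axis=1)                                # hard predictions
y_true_ohe = np.eye(n_classes)[y_true]                              # one-hot truth

scorer = Scorer(prefix="example", average="macro")

# Full metric dictionary: accuracy, f1, precision, recall, roc_auc,
# average_precision, and pr_macro.
metrics = scorer.evaluate(y_true, y_pred, y_true_ohe, y_pred_proba)
print(metrics)

# Objective mode for hyperparameter tuning: only the chosen metric is computed.
tuned = scorer.evaluate(
    y_true, y_pred, y_true_ohe, y_pred_proba,
    objective_mode=True, tune_metric="roc_auc",
)
print(tuned["roc_auc"])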