valor-lite 0.36.6__py3-none-any.whl → 0.37.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. valor_lite/cache/__init__.py +11 -0
  2. valor_lite/cache/compute.py +211 -0
  3. valor_lite/cache/ephemeral.py +302 -0
  4. valor_lite/cache/persistent.py +536 -0
  5. valor_lite/classification/__init__.py +5 -10
  6. valor_lite/classification/annotation.py +4 -0
  7. valor_lite/classification/computation.py +233 -251
  8. valor_lite/classification/evaluator.py +882 -0
  9. valor_lite/classification/loader.py +97 -0
  10. valor_lite/classification/metric.py +141 -4
  11. valor_lite/classification/shared.py +184 -0
  12. valor_lite/classification/utilities.py +221 -118
  13. valor_lite/exceptions.py +5 -0
  14. valor_lite/object_detection/__init__.py +5 -4
  15. valor_lite/object_detection/annotation.py +13 -1
  16. valor_lite/object_detection/computation.py +368 -299
  17. valor_lite/object_detection/evaluator.py +804 -0
  18. valor_lite/object_detection/loader.py +292 -0
  19. valor_lite/object_detection/metric.py +152 -3
  20. valor_lite/object_detection/shared.py +206 -0
  21. valor_lite/object_detection/utilities.py +182 -100
  22. valor_lite/semantic_segmentation/__init__.py +5 -4
  23. valor_lite/semantic_segmentation/annotation.py +7 -0
  24. valor_lite/semantic_segmentation/computation.py +20 -110
  25. valor_lite/semantic_segmentation/evaluator.py +414 -0
  26. valor_lite/semantic_segmentation/loader.py +205 -0
  27. valor_lite/semantic_segmentation/shared.py +149 -0
  28. valor_lite/semantic_segmentation/utilities.py +6 -23
  29. {valor_lite-0.36.6.dist-info → valor_lite-0.37.5.dist-info}/METADATA +3 -1
  30. valor_lite-0.37.5.dist-info/RECORD +49 -0
  31. {valor_lite-0.36.6.dist-info → valor_lite-0.37.5.dist-info}/WHEEL +1 -1
  32. valor_lite/classification/manager.py +0 -545
  33. valor_lite/object_detection/manager.py +0 -864
  34. valor_lite/profiling.py +0 -374
  35. valor_lite/semantic_segmentation/benchmark.py +0 -237
  36. valor_lite/semantic_segmentation/manager.py +0 -446
  37. valor_lite-0.36.6.dist-info/RECORD +0 -41
  38. {valor_lite-0.36.6.dist-info → valor_lite-0.37.5.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
1
- from enum import IntFlag, auto
1
+ from enum import IntFlag
2
2
 
3
3
  import numpy as np
4
4
  from numpy.typing import NDArray
@@ -6,256 +6,141 @@ from numpy.typing import NDArray
6
6
  import valor_lite.classification.numpy_compatibility as npc
7
7
 
8
8
 
9
- def compute_label_metadata(
10
- ids: NDArray[np.int32],
9
+ def compute_rocauc(
10
+ rocauc: NDArray[np.float64],
11
+ array: NDArray[np.float64],
12
+ gt_count_per_label: NDArray[np.uint64],
13
+ pd_count_per_label: NDArray[np.uint64],
11
14
  n_labels: int,
12
- ) -> NDArray[np.int32]:
15
+ prev: NDArray[np.uint64],
16
+ ) -> tuple[NDArray[np.float64], NDArray[np.uint64]]:
13
17
  """
14
- Computes label metadata returning a count of annotations per label.
18
+ Compute ROCAUC.
15
19
 
16
20
  Parameters
17
21
  ----------
18
- detailed_pairs : NDArray[np.int32]
19
- Detailed annotation pairings with shape (n_pairs, 3).
20
- Index 0 - Datum Index
21
- Index 1 - GroundTruth Label Index
22
- Index 2 - Prediction Label Index
22
+ rocauc : NDArray[np.float64]
23
+ The running ROCAUC.
24
+ array : NDArray[np.float64]
25
+ An sorted array of ROCAUC intermediate values with shape (n_pairs, 3).
26
+ Index 0 - Prediction Label Index
27
+ Index 1 - Cumulative FP
28
+ Index 2 - Cumulative TP
29
+ gt_count_per_label : NDArray[np.uint64]
30
+ The number of ground truth occurences per label.
31
+ pd_count_per_label : NDArray[np.uint64]
32
+ The number of prediction occurences per label.
23
33
  n_labels : int
24
- The total number of unique labels.
34
+ The total number of unqiue labels.
35
+ prev : NDArray[np.uint64]
36
+ The previous cumulative sum for FP's and TP's. Used as intermediate in chunking operations.
25
37
 
26
38
  Returns
27
39
  -------
28
- NDArray[np.int32]
29
- The label metadata array with shape (n_labels, 2).
30
- Index 0 - Ground truth label count
31
- Index 1 - Prediction label count
32
- """
33
- label_metadata = np.zeros((n_labels, 2), dtype=np.int32)
34
- ground_truth_pairs = ids[:, (0, 1)]
35
- ground_truth_pairs = ground_truth_pairs[ground_truth_pairs[:, 1] >= 0]
36
- unique_pairs = np.unique(ground_truth_pairs, axis=0)
37
- label_indices, unique_counts = np.unique(
38
- unique_pairs[:, 1], return_counts=True
39
- )
40
- label_metadata[label_indices.astype(np.int32), 0] = unique_counts
41
-
42
- prediction_pairs = ids[:, (0, 2)]
43
- prediction_pairs = prediction_pairs[prediction_pairs[:, 1] >= 0]
44
- unique_pairs = np.unique(prediction_pairs, axis=0)
45
- label_indices, unique_counts = np.unique(
46
- unique_pairs[:, 1], return_counts=True
47
- )
48
- label_metadata[label_indices.astype(np.int32), 1] = unique_counts
49
-
50
- return label_metadata
51
-
52
-
53
- def filter_cache(
54
- detailed_pairs: NDArray[np.float64],
55
- datum_mask: NDArray[np.bool_],
56
- valid_label_indices: NDArray[np.int32] | None,
57
- n_labels: int,
58
- ) -> tuple[NDArray[np.float64], NDArray[np.int32]]:
59
- # filter by datum
60
- detailed_pairs = detailed_pairs[datum_mask].copy()
61
-
62
- n_rows = detailed_pairs.shape[0]
63
- mask_invalid_groundtruths = np.zeros(n_rows, dtype=np.bool_)
64
- mask_invalid_predictions = np.zeros_like(mask_invalid_groundtruths)
65
-
66
- # filter labels
67
- if valid_label_indices is not None:
68
- mask_invalid_groundtruths[
69
- ~np.isin(detailed_pairs[:, 1], valid_label_indices)
70
- ] = True
71
- mask_invalid_predictions[
72
- ~np.isin(detailed_pairs[:, 2], valid_label_indices)
73
- ] = True
74
-
75
- # filter cache
76
- if mask_invalid_groundtruths.any():
77
- invalid_groundtruth_indices = np.where(mask_invalid_groundtruths)[0]
78
- detailed_pairs[invalid_groundtruth_indices[:, None], 1] = np.array(
79
- [[-1.0]]
80
- )
81
-
82
- if mask_invalid_predictions.any():
83
- invalid_prediction_indices = np.where(mask_invalid_predictions)[0]
84
- detailed_pairs[
85
- invalid_prediction_indices[:, None], (2, 3, 4)
86
- ] = np.array([[-1.0, -1.0, -1.0]])
87
-
88
- # filter null pairs
89
- mask_null_pairs = np.all(
90
- np.isclose(
91
- detailed_pairs[:, 1:5],
92
- np.array([-1.0, -1.0, -1.0, -1.0]),
93
- ),
94
- axis=1,
95
- )
96
- detailed_pairs = detailed_pairs[~mask_null_pairs]
97
-
98
- detailed_pairs = np.unique(detailed_pairs, axis=0)
99
- indices = np.lexsort(
100
- (
101
- detailed_pairs[:, 1], # ground truth
102
- detailed_pairs[:, 2], # prediction
103
- -detailed_pairs[:, 3], # score
104
- )
105
- )
106
- detailed_pairs = detailed_pairs[indices]
107
- label_metadata = compute_label_metadata(
108
- ids=detailed_pairs[:, :3].astype(np.int32),
109
- n_labels=n_labels,
110
- )
111
- return detailed_pairs, label_metadata
112
-
113
-
114
- def _compute_rocauc(
115
- data: NDArray[np.float64],
116
- label_metadata: NDArray[np.int32],
117
- n_datums: int,
118
- n_labels: int,
119
- mask_matching_labels: NDArray[np.bool_],
120
- pd_labels: NDArray[np.int32],
121
- ) -> tuple[NDArray[np.float64], float]:
122
- """
123
- Compute ROCAUC and mean ROCAUC.
40
+ NDArray[np.float64]
41
+ ROCAUC.
42
+ NDArray[np.uint64]
43
+ The previous cumulative sum for FP's. Used as intermediate in chunking operations.
44
+ NDArray[np.uint64]
45
+ The previous cumulative sum for TP's. Used as intermediate in chunking operations.
124
46
  """
125
- positive_count = label_metadata[:, 0]
126
- negative_count = label_metadata[:, 1] - label_metadata[:, 0]
47
+ pd_labels = array[:, 0]
48
+ cumulative_fp = array[:, 1]
49
+ cumulative_tp = array[:, 2]
127
50
 
128
- true_positives = np.zeros((n_labels, n_datums), dtype=np.int32)
129
- false_positives = np.zeros_like(true_positives)
130
- scores = np.zeros_like(true_positives, dtype=np.float64)
51
+ positive_count = gt_count_per_label
52
+ negative_count = pd_count_per_label - gt_count_per_label
131
53
 
132
54
  for label_idx in range(n_labels):
133
- if label_metadata[label_idx, 1] == 0:
134
- continue
135
-
136
55
  mask_pds = pd_labels == label_idx
137
- true_positives[label_idx] = mask_matching_labels[mask_pds]
138
- false_positives[label_idx] = ~mask_matching_labels[mask_pds]
139
- scores[label_idx] = data[mask_pds, 3]
140
-
141
- cumulative_fp = np.cumsum(false_positives, axis=1)
142
- cumulative_tp = np.cumsum(true_positives, axis=1)
56
+ n_masked_pds = mask_pds.sum()
57
+ if pd_count_per_label[label_idx] == 0 or n_masked_pds == 0:
58
+ continue
143
59
 
144
- fpr = np.zeros_like(true_positives, dtype=np.float64)
145
- np.divide(
146
- cumulative_fp,
147
- negative_count[:, np.newaxis],
148
- where=negative_count[:, np.newaxis] > 1e-9,
149
- out=fpr,
150
- )
151
- tpr = np.zeros_like(true_positives, dtype=np.float64)
152
- np.divide(
153
- cumulative_tp,
154
- positive_count[:, np.newaxis],
155
- where=positive_count[:, np.newaxis] > 1e-9,
156
- out=tpr,
157
- )
60
+ fps = cumulative_fp[mask_pds]
61
+ tps = cumulative_tp[mask_pds]
62
+ if prev[label_idx, 0] > 0 or prev[label_idx, 1] > 0:
63
+ fps = np.r_[prev[label_idx, 0], fps]
64
+ tps = np.r_[prev[label_idx, 1], tps]
158
65
 
159
- # sort by -tpr, -score
160
- indices = np.lexsort((-tpr, -scores), axis=1)
161
- fpr = np.take_along_axis(fpr, indices, axis=1)
162
- tpr = np.take_along_axis(tpr, indices, axis=1)
66
+ prev[label_idx, 0] = fps[-1]
67
+ prev[label_idx, 1] = tps[-1]
163
68
 
164
- # running max of tpr
165
- np.maximum.accumulate(tpr, axis=1, out=tpr)
69
+ if fps.size == 1:
70
+ continue
166
71
 
167
- # compute rocauc
168
- rocauc = npc.trapezoid(x=fpr, y=tpr, axis=1)
72
+ fpr = np.zeros_like(fps, dtype=np.float64)
73
+ np.divide(
74
+ fps,
75
+ negative_count[label_idx],
76
+ where=negative_count[label_idx] > 0,
77
+ out=fpr,
78
+ )
79
+ tpr = np.zeros_like(tps, dtype=np.float64)
80
+ np.divide(
81
+ tps,
82
+ positive_count[label_idx],
83
+ where=positive_count[label_idx] > 0,
84
+ out=tpr,
85
+ )
169
86
 
170
- # compute mean rocauc
171
- mean_rocauc = rocauc.mean()
87
+ # compute rocauc
88
+ rocauc[label_idx] += npc.trapezoid(x=fpr, y=tpr, axis=0)
172
89
 
173
- return rocauc, mean_rocauc # type: ignore[reportReturnType]
90
+ return rocauc, prev
174
91
 
175
92
 
176
- def compute_precision_recall_rocauc(
177
- detailed_pairs: NDArray[np.float64],
178
- label_metadata: NDArray[np.int32],
93
+ def compute_counts(
94
+ ids: NDArray[np.int64],
95
+ scores: NDArray[np.float64],
96
+ winners: NDArray[np.bool_],
179
97
  score_thresholds: NDArray[np.float64],
180
98
  hardmax: bool,
181
- n_datums: int,
182
- ) -> tuple[
183
- NDArray[np.int32],
184
- NDArray[np.float64],
185
- NDArray[np.float64],
186
- NDArray[np.float64],
187
- NDArray[np.float64],
188
- NDArray[np.float64],
189
- float,
190
- ]:
99
+ n_labels: int,
100
+ ) -> NDArray[np.uint64]:
191
101
  """
192
- Computes classification metrics.
102
+ Computes counts of TP, FP and FN's per label.
193
103
 
194
104
  Parameters
195
105
  ----------
196
- detailed_pairs : NDArray[np.float64]
197
- A sorted array of classification pairs with shape (n_pairs, 5).
106
+ ids : NDArray[np.int64]
107
+ A sorted array of classification pairs with shape (n_pairs, 3).
198
108
  Index 0 - Datum Index
199
109
  Index 1 - GroundTruth Label Index
200
110
  Index 2 - Prediction Label Index
201
- Index 3 - Score
202
- Index 4 - Hard-Max Score
203
- label_metadata : NDArray[np.int32]
204
- An array containing metadata related to labels with shape (n_labels, 2).
205
- Index 0 - GroundTruth Label Count
206
- Index 1 - Prediction Label Count
111
+ scores : NDArray[np.float64]
112
+ A sorted array of classification scores with shape (n_pairs,).
113
+ winner : NDArray[np.bool_]
114
+ Marks predictions with highest score over a datum.
207
115
  score_thresholds : NDArray[np.float64]
208
116
  A 1-D array contains score thresholds to compute metrics over.
209
117
  hardmax : bool
210
118
  Option to only allow a single positive prediction.
211
- n_datums : int
212
- The number of datums being operated over.
119
+ n_labels : int
120
+ The total number of unqiue labels.
213
121
 
214
122
  Returns
215
123
  -------
216
124
  NDArray[np.int32]
217
125
  TP, FP, FN, TN counts.
218
- NDArray[np.float64]
219
- Precision.
220
- NDArray[np.float64]
221
- Recall.
222
- NDArray[np.float64]
223
- Accuracy
224
- NDArray[np.float64]
225
- F1 Score
226
- NDArray[np.float64]
227
- ROCAUC.
228
- float
229
- mROCAUC.
230
126
  """
231
-
232
- n_labels = label_metadata.shape[0]
233
127
  n_scores = score_thresholds.shape[0]
128
+ counts = np.zeros((n_scores, n_labels, 4), dtype=np.uint64)
129
+ if ids.size == 0:
130
+ return counts
234
131
 
235
- pd_labels = detailed_pairs[:, 2].astype(int)
132
+ gt_labels = ids[:, 1]
133
+ pd_labels = ids[:, 2]
236
134
 
237
- mask_matching_labels = np.isclose(
238
- detailed_pairs[:, 1], detailed_pairs[:, 2]
239
- )
240
- mask_score_nonzero = ~np.isclose(detailed_pairs[:, 3], 0.0)
241
- mask_hardmax = detailed_pairs[:, 4] > 0.5
242
-
243
- # calculate ROCAUC
244
- rocauc, mean_rocauc = _compute_rocauc(
245
- data=detailed_pairs,
246
- label_metadata=label_metadata,
247
- n_datums=n_datums,
248
- n_labels=n_labels,
249
- mask_matching_labels=mask_matching_labels,
250
- pd_labels=pd_labels,
251
- )
135
+ mask_matching_labels = np.isclose(gt_labels, pd_labels)
136
+ mask_score_nonzero = ~np.isclose(scores, 0.0)
137
+ mask_hardmax = winners > 0.5
138
+ mask_valid_gts = gt_labels >= 0
139
+ mask_valid_pds = pd_labels >= 0
252
140
 
253
141
  # calculate metrics at various score thresholds
254
- counts = np.zeros((n_scores, n_labels, 4), dtype=np.int32)
255
142
  for score_idx in range(n_scores):
256
- mask_score_threshold = (
257
- detailed_pairs[:, 3] >= score_thresholds[score_idx]
258
- )
143
+ mask_score_threshold = scores >= score_thresholds[score_idx]
259
144
  mask_score = mask_score_nonzero & mask_score_threshold
260
145
 
261
146
  if hardmax:
@@ -266,8 +151,11 @@ def compute_precision_recall_rocauc(
266
151
  mask_fn = (mask_matching_labels & ~mask_score) | mask_fp
267
152
  mask_tn = ~mask_matching_labels & ~mask_score
268
153
 
269
- fn = np.unique(detailed_pairs[mask_fn][:, [0, 1]].astype(int), axis=0)
270
- tn = np.unique(detailed_pairs[mask_tn][:, [0, 2]].astype(int), axis=0)
154
+ mask_fn &= mask_valid_gts
155
+ mask_fp &= mask_valid_pds
156
+
157
+ fn = np.unique(ids[mask_fn][:, [0, 1]].astype(int), axis=0)
158
+ tn = np.unique(ids[mask_tn][:, [0, 2]].astype(int), axis=0)
271
159
 
272
160
  counts[score_idx, :, 0] = np.bincount(
273
161
  pd_labels[mask_tp], minlength=n_labels
@@ -278,29 +166,45 @@ def compute_precision_recall_rocauc(
278
166
  counts[score_idx, :, 2] = np.bincount(fn[:, 1], minlength=n_labels)
279
167
  counts[score_idx, :, 3] = np.bincount(tn[:, 1], minlength=n_labels)
280
168
 
281
- recall = np.zeros((n_scores, n_labels), dtype=np.float64)
282
- np.divide(
283
- counts[:, :, 0],
284
- (counts[:, :, 0] + counts[:, :, 2]),
285
- where=(counts[:, :, 0] + counts[:, :, 2]) > 1e-9,
286
- out=recall,
287
- )
169
+ return counts
170
+
288
171
 
289
- precision = np.zeros_like(recall)
172
+ def compute_precision(counts: NDArray[np.uint64]) -> NDArray[np.float64]:
173
+ """
174
+ Compute precision metric using result of compute_counts.
175
+ """
176
+ n_scores, n_labels, _ = counts.shape
177
+ precision = np.zeros((n_scores, n_labels), dtype=np.float64)
290
178
  np.divide(
291
179
  counts[:, :, 0],
292
180
  (counts[:, :, 0] + counts[:, :, 1]),
293
- where=(counts[:, :, 0] + counts[:, :, 1]) > 1e-9,
181
+ where=(counts[:, :, 0] + counts[:, :, 1]) > 0,
294
182
  out=precision,
295
183
  )
184
+ return precision
296
185
 
297
- accuracy = np.zeros(n_scores, dtype=np.float64)
186
+
187
+ def compute_recall(counts: NDArray[np.uint64]) -> NDArray[np.float64]:
188
+ """
189
+ Compute recall metric using result of compute_counts.
190
+ """
191
+ n_scores, n_labels, _ = counts.shape
192
+ recall = np.zeros((n_scores, n_labels), dtype=np.float64)
298
193
  np.divide(
299
- counts[:, :, 0].sum(axis=1),
300
- float(n_datums),
301
- out=accuracy,
194
+ counts[:, :, 0],
195
+ (counts[:, :, 0] + counts[:, :, 2]),
196
+ where=(counts[:, :, 0] + counts[:, :, 2]) > 0,
197
+ out=recall,
302
198
  )
199
+ return recall
200
+
303
201
 
202
+ def compute_f1_score(
203
+ precision: NDArray[np.float64], recall: NDArray[np.float64]
204
+ ) -> NDArray[np.float64]:
205
+ """
206
+ Compute f1 metric using result of compute_precision and compute_recall.
207
+ """
304
208
  f1_score = np.zeros_like(recall)
305
209
  np.divide(
306
210
  (2 * precision * recall),
@@ -308,68 +212,86 @@ def compute_precision_recall_rocauc(
308
212
  where=(precision + recall) > 1e-9,
309
213
  out=f1_score,
310
214
  )
215
+ return f1_score
311
216
 
312
- return (
313
- counts,
314
- precision,
315
- recall,
316
- accuracy,
317
- f1_score,
318
- rocauc,
319
- mean_rocauc,
217
+
218
+ def compute_accuracy(
219
+ counts: NDArray[np.uint64], n_datums: int
220
+ ) -> NDArray[np.float64]:
221
+ """
222
+ Compute accuracy metric using result of compute_counts.
223
+ """
224
+ n_scores, _, _ = counts.shape
225
+ accuracy = np.zeros(n_scores, dtype=np.float64)
226
+ if n_datums == 0:
227
+ return accuracy
228
+ np.divide(
229
+ counts[:, :, 0].sum(axis=1),
230
+ n_datums,
231
+ out=accuracy,
320
232
  )
233
+ return accuracy
321
234
 
322
235
 
323
236
  class PairClassification(IntFlag):
324
- TP = auto()
325
- FP_FN_MISCLF = auto()
326
- FN_UNMATCHED = auto()
237
+ TP = 1 << 0
238
+ FP_FN_MISCLF = 1 << 1
239
+ FN_UNMATCHED = 1 << 2
327
240
 
328
241
 
329
- def compute_confusion_matrix(
330
- detailed_pairs: NDArray[np.float64],
242
+ def compute_pair_classifications(
243
+ ids: NDArray[np.int64],
244
+ scores: NDArray[np.float64],
245
+ winners: NDArray[np.bool_],
331
246
  score_thresholds: NDArray[np.float64],
332
247
  hardmax: bool,
333
- ) -> NDArray[np.uint8]:
248
+ ) -> tuple[NDArray[np.bool_], NDArray[np.bool_], NDArray[np.bool_]]:
334
249
  """
335
- Compute detailed confusion matrix.
250
+ Classifiy ID pairs as TP, FP or FN.
336
251
 
337
252
  Parameters
338
253
  ----------
339
- detailed_pairs : NDArray[np.float64]
340
- A 2-D sorted array summarizing the IOU calculations of one or more pairs with shape (n_pairs, 5).
254
+ ids : NDArray[np.int64]
255
+ A sorted array of classification pairs with shape (n_pairs, 3).
341
256
  Index 0 - Datum Index
342
257
  Index 1 - GroundTruth Label Index
343
258
  Index 2 - Prediction Label Index
344
- Index 3 - Score
345
- Index 4 - Hard Max Score
346
- iou_thresholds : NDArray[np.float64]
347
- A 1-D array containing IOU thresholds.
259
+ scores : NDArray[np.float64]
260
+ A sorted array of classification scores with shape (n_pairs,).
261
+ winner : NDArray[np.bool_]
262
+ Marks predictions with highest score over a datum.
348
263
  score_thresholds : NDArray[np.float64]
349
264
  A 1-D array containing score thresholds.
265
+ hardmax : bool
266
+ Option to only allow a single positive prediction.
350
267
 
351
268
  Returns
352
269
  -------
353
- NDArray[uint8]
354
- Row-wise classification of pairs.
270
+ NDArray[bool]
271
+ True-positive mask.
272
+ NDArray[bool]
273
+ Misclassification FP, FN mask.
274
+ NDArray[bool]
275
+ Unmatched FN mask.
355
276
  """
356
- n_pairs = detailed_pairs.shape[0]
277
+ n_pairs = ids.shape[0]
357
278
  n_scores = score_thresholds.shape[0]
358
279
 
280
+ gt_labels = ids[:, 1]
281
+ pd_labels = ids[:, 2]
282
+ groundtruths = ids[:, [0, 1]]
283
+
359
284
  pair_classifications = np.zeros(
360
285
  (n_scores, n_pairs),
361
286
  dtype=np.uint8,
362
287
  )
363
288
 
364
- mask_label_match = np.isclose(detailed_pairs[:, 1], detailed_pairs[:, 2])
365
- mask_score = detailed_pairs[:, 3] > 1e-9
366
-
367
- groundtruths = detailed_pairs[:, [0, 1]].astype(int)
368
-
289
+ mask_label_match = np.isclose(gt_labels, pd_labels)
290
+ mask_score = scores > 1e-9
369
291
  for score_idx in range(n_scores):
370
- mask_score &= detailed_pairs[:, 3] >= score_thresholds[score_idx]
292
+ mask_score &= scores >= score_thresholds[score_idx]
371
293
  if hardmax:
372
- mask_score &= detailed_pairs[:, 4] > 0.5
294
+ mask_score &= winners
373
295
 
374
296
  mask_true_positives = mask_label_match & mask_score
375
297
  mask_misclassifications = ~mask_label_match & mask_score
@@ -393,4 +315,64 @@ def compute_confusion_matrix(
393
315
  score_idx, mask_unmatched_groundtruths
394
316
  ] |= np.uint8(PairClassification.FN_UNMATCHED)
395
317
 
396
- return pair_classifications
318
+ mask_tp = np.bitwise_and(pair_classifications, PairClassification.TP) > 0
319
+ mask_fp_fn_misclf = (
320
+ np.bitwise_and(pair_classifications, PairClassification.FP_FN_MISCLF)
321
+ > 0
322
+ )
323
+ mask_fn_unmatched = (
324
+ np.bitwise_and(pair_classifications, PairClassification.FN_UNMATCHED)
325
+ > 0
326
+ )
327
+
328
+ return (
329
+ mask_tp,
330
+ mask_fp_fn_misclf,
331
+ mask_fn_unmatched,
332
+ )
333
+
334
+
335
+ def compute_confusion_matrix(
336
+ ids: NDArray[np.int64],
337
+ mask_tp: NDArray[np.bool_],
338
+ mask_fp_fn_misclf: NDArray[np.bool_],
339
+ mask_fn_unmatched: NDArray[np.bool_],
340
+ score_thresholds: NDArray[np.float64],
341
+ n_labels: int,
342
+ ):
343
+ """
344
+ Compute confusion matrix using output of compute_pair_classifications.
345
+ """
346
+ n_scores = score_thresholds.size
347
+
348
+ # initialize arrays
349
+ confusion_matrices = np.zeros(
350
+ (n_scores, n_labels, n_labels), dtype=np.uint64
351
+ )
352
+ unmatched_groundtruths = np.zeros((n_scores, n_labels), dtype=np.uint64)
353
+
354
+ mask_matched = mask_tp | mask_fp_fn_misclf
355
+ for score_idx in range(n_scores):
356
+ # matched annotations
357
+ unique_pairs = np.unique(
358
+ ids[np.ix_(mask_matched[score_idx], (0, 1, 2))], # type: ignore - numpy ix_ typing
359
+ axis=0,
360
+ )
361
+ unique_labels, unique_label_counts = np.unique(
362
+ unique_pairs[:, (1, 2)], axis=0, return_counts=True
363
+ )
364
+ confusion_matrices[
365
+ score_idx, unique_labels[:, 0], unique_labels[:, 1]
366
+ ] = unique_label_counts
367
+
368
+ # unmatched groundtruths
369
+ unique_pairs = np.unique(
370
+ ids[np.ix_(mask_fn_unmatched[score_idx], (0, 1))], # type: ignore - numpy ix_ typing
371
+ axis=0,
372
+ )
373
+ unique_labels, unique_label_counts = np.unique(
374
+ unique_pairs[:, 1], return_counts=True
375
+ )
376
+ unmatched_groundtruths[score_idx, unique_labels] = unique_label_counts
377
+
378
+ return confusion_matrices, unmatched_groundtruths