valor-lite 0.36.6__py3-none-any.whl → 0.37.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. valor_lite/cache/__init__.py +11 -0
  2. valor_lite/cache/compute.py +211 -0
  3. valor_lite/cache/ephemeral.py +302 -0
  4. valor_lite/cache/persistent.py +536 -0
  5. valor_lite/classification/__init__.py +5 -10
  6. valor_lite/classification/annotation.py +4 -0
  7. valor_lite/classification/computation.py +233 -251
  8. valor_lite/classification/evaluator.py +882 -0
  9. valor_lite/classification/loader.py +97 -0
  10. valor_lite/classification/metric.py +141 -4
  11. valor_lite/classification/shared.py +184 -0
  12. valor_lite/classification/utilities.py +221 -118
  13. valor_lite/exceptions.py +5 -0
  14. valor_lite/object_detection/__init__.py +5 -4
  15. valor_lite/object_detection/annotation.py +13 -1
  16. valor_lite/object_detection/computation.py +368 -299
  17. valor_lite/object_detection/evaluator.py +804 -0
  18. valor_lite/object_detection/loader.py +292 -0
  19. valor_lite/object_detection/metric.py +152 -3
  20. valor_lite/object_detection/shared.py +206 -0
  21. valor_lite/object_detection/utilities.py +182 -100
  22. valor_lite/semantic_segmentation/__init__.py +5 -4
  23. valor_lite/semantic_segmentation/annotation.py +7 -0
  24. valor_lite/semantic_segmentation/computation.py +20 -110
  25. valor_lite/semantic_segmentation/evaluator.py +414 -0
  26. valor_lite/semantic_segmentation/loader.py +205 -0
  27. valor_lite/semantic_segmentation/shared.py +149 -0
  28. valor_lite/semantic_segmentation/utilities.py +6 -23
  29. {valor_lite-0.36.6.dist-info → valor_lite-0.37.5.dist-info}/METADATA +3 -1
  30. valor_lite-0.37.5.dist-info/RECORD +49 -0
  31. {valor_lite-0.36.6.dist-info → valor_lite-0.37.5.dist-info}/WHEEL +1 -1
  32. valor_lite/classification/manager.py +0 -545
  33. valor_lite/object_detection/manager.py +0 -864
  34. valor_lite/profiling.py +0 -374
  35. valor_lite/semantic_segmentation/benchmark.py +0 -237
  36. valor_lite/semantic_segmentation/manager.py +0 -446
  37. valor_lite-0.36.6.dist-info/RECORD +0 -41
  38. {valor_lite-0.36.6.dist-info → valor_lite-0.37.5.dist-info}/top_level.txt +0 -0

valor_lite/object_detection/utilities.py
@@ -1,42 +1,24 @@
  from collections import defaultdict
 
  import numpy as np
+ import pyarrow as pa
  from numpy.typing import NDArray
 
- from valor_lite.object_detection.computation import PairClassification
  from valor_lite.object_detection.metric import Metric, MetricType
 
 
  def unpack_precision_recall_into_metric_lists(
- results: tuple[
- tuple[
- NDArray[np.float64],
- NDArray[np.float64],
- ],
- tuple[
- NDArray[np.float64],
- NDArray[np.float64],
- ],
- NDArray[np.float64],
- NDArray[np.float64],
- ],
+ counts: NDArray[np.uint64],
+ precision_recall_f1: NDArray[np.float64],
+ average_precision: NDArray[np.float64],
+ mean_average_precision: NDArray[np.float64],
+ average_recall: NDArray[np.float64],
+ mean_average_recall: NDArray[np.float64],
+ pr_curve: NDArray[np.float64],
  iou_thresholds: list[float],
  score_thresholds: list[float],
- index_to_label: list[str],
+ index_to_label: dict[int, str],
  ):
- (
- (
- average_precision,
- mean_average_precision,
- ),
- (
- average_recall,
- mean_average_recall,
- ),
- precision_recall,
- pr_curves,
- ) = results
-
  metrics = defaultdict(list)
 
  metrics[MetricType.AP] = [
@@ -46,7 +28,7 @@ def unpack_precision_recall_into_metric_lists(
  label=label,
  )
  for iou_idx, iou_threshold in enumerate(iou_thresholds)
- for label_idx, label in enumerate(index_to_label)
+ for label_idx, label in index_to_label.items()
  ]
 
  metrics[MetricType.mAP] = [
@@ -64,7 +46,7 @@ def unpack_precision_recall_into_metric_lists(
  iou_thresholds=iou_thresholds,
  label=label,
  )
- for label_idx, label in enumerate(index_to_label)
+ for label_idx, label in index_to_label.items()
  ]
 
  # TODO - (c.zaloom) will be removed in the future
@@ -83,7 +65,7 @@ def unpack_precision_recall_into_metric_lists(
  label=label,
  )
  for score_idx, score_threshold in enumerate(score_thresholds)
- for label_idx, label in enumerate(index_to_label)
+ for label_idx, label in index_to_label.items()
  ]
 
  metrics[MetricType.mAR] = [
@@ -103,7 +85,7 @@ def unpack_precision_recall_into_metric_lists(
  iou_thresholds=iou_thresholds,
  label=label,
  )
- for label_idx, label in enumerate(index_to_label)
+ for label_idx, label in index_to_label.items()
  ]
 
  # TODO - (c.zaloom) will be removed in the future
@@ -117,20 +99,20 @@ def unpack_precision_recall_into_metric_lists(
 
  metrics[MetricType.PrecisionRecallCurve] = [
  Metric.precision_recall_curve(
- precisions=pr_curves[iou_idx, label_idx, :, 0].tolist(), # type: ignore[reportArgumentType]
- scores=pr_curves[iou_idx, label_idx, :, 1].tolist(), # type: ignore[reportArgumentType]
+ precisions=pr_curve[iou_idx, label_idx, :, 0].tolist(),
+ scores=pr_curve[iou_idx, label_idx, :, 1].tolist(),
  iou_threshold=iou_threshold,
  label=label,
  )
  for iou_idx, iou_threshold in enumerate(iou_thresholds)
- for label_idx, label in enumerate(index_to_label)
+ for label_idx, label in index_to_label.items()
  ]
 
- for label_idx, label in enumerate(index_to_label):
+ for label_idx, label in index_to_label.items():
  for score_idx, score_threshold in enumerate(score_thresholds):
  for iou_idx, iou_threshold in enumerate(iou_thresholds):
 
- row = precision_recall[iou_idx, score_idx, label_idx, :]
+ row = counts[iou_idx, score_idx, :, label_idx]
  kwargs = {
  "label": label,
  "iou_threshold": iou_threshold,
@@ -145,21 +127,22 @@ def unpack_precision_recall_into_metric_lists(
  )
  )
 
+ row = precision_recall_f1[iou_idx, score_idx, :, label_idx]
  metrics[MetricType.Precision].append(
  Metric.precision(
- value=float(row[3]),
+ value=float(row[0]),
  **kwargs,
  )
  )
  metrics[MetricType.Recall].append(
  Metric.recall(
- value=float(row[4]),
+ value=float(row[1]),
  **kwargs,
  )
  )
  metrics[MetricType.F1].append(
  Metric.f1_score(
- value=float(row[5]),
+ value=float(row[2]),
  **kwargs,
  )
  )
@@ -167,40 +150,153 @@ def unpack_precision_recall_into_metric_lists(
  return metrics
 
 
- def _create_empty_confusion_matrix(index_to_labels: list[str]):
- unmatched_ground_truths = dict()
+ def unpack_confusion_matrix(
+ confusion_matrices: NDArray[np.uint64],
+ unmatched_groundtruths: NDArray[np.uint64],
+ unmatched_predictions: NDArray[np.uint64],
+ index_to_label: dict[int, str],
+ iou_thresholds: list[float],
+ score_thresholds: list[float],
+ ) -> list[Metric]:
+ metrics = []
+ for iou_idx, iou_thresh in enumerate(iou_thresholds):
+ for score_idx, score_thresh in enumerate(score_thresholds):
+ cm_dict = {}
+ ugt_dict = {}
+ upd_dict = {}
+ for idx, label in index_to_label.items():
+ ugt_dict[label] = int(
+ unmatched_groundtruths[iou_idx, score_idx, idx]
+ )
+ upd_dict[label] = int(
+ unmatched_predictions[iou_idx, score_idx, idx]
+ )
+ for pidx, plabel in index_to_label.items():
+ if label not in cm_dict:
+ cm_dict[label] = {}
+ cm_dict[label][plabel] = int(
+ confusion_matrices[iou_idx, score_idx, idx, pidx]
+ )
+ metrics.append(
+ Metric.confusion_matrix(
+ confusion_matrix=cm_dict,
+ unmatched_ground_truths=ugt_dict,
+ unmatched_predictions=upd_dict,
+ iou_threshold=iou_thresh,
+ score_threshold=score_thresh,
+ )
+ )
+ return metrics
+
+
+ def create_mapping(
+ tbl: pa.Table,
+ pairs: NDArray[np.float64],
+ index: int,
+ id_col: str,
+ uid_col: str,
+ ) -> dict[int, str]:
+ col = pairs[:, index].astype(np.int64)
+ values, indices = np.unique(col, return_index=True)
+ indices = indices[values >= 0]
+ return {
+ tbl[id_col][idx].as_py(): tbl[uid_col][idx].as_py() for idx in indices
+ }
+
+
+ def unpack_examples(
+ detailed_pairs: NDArray[np.float64],
+ mask_tp: NDArray[np.bool_],
+ mask_fn: NDArray[np.bool_],
+ mask_fp: NDArray[np.bool_],
+ iou_thresholds: list[float],
+ score_thresholds: list[float],
+ index_to_datum_id: dict[int, str],
+ index_to_groundtruth_id: dict[int, str],
+ index_to_prediction_id: dict[int, str],
+ ) -> list[Metric]:
+ metrics = []
+ ids = detailed_pairs[:, :5].astype(np.int64)
+ unique_datums = np.unique(detailed_pairs[:, 0].astype(np.int64))
+ for datum_index in unique_datums:
+ mask_datum = detailed_pairs[:, 0] == datum_index
+ mask_datum_tp = mask_tp & mask_datum
+ mask_datum_fp = mask_fp & mask_datum
+ mask_datum_fn = mask_fn & mask_datum
+
+ datum_id = index_to_datum_id[datum_index]
+ for iou_idx, iou_thresh in enumerate(iou_thresholds):
+ for score_idx, score_thresh in enumerate(score_thresholds):
+
+ unique_tp = np.unique(
+ ids[np.ix_(mask_datum_tp[iou_idx, score_idx], (0, 1, 2, 3, 4))], axis=0 # type: ignore - numpy ix_ typing
+ )
+ unique_fp = np.unique(
+ ids[np.ix_(mask_datum_fp[iou_idx, score_idx], (0, 2, 4))], axis=0 # type: ignore - numpy ix_ typing
+ )
+ unique_fn = np.unique(
+ ids[np.ix_(mask_datum_fn[iou_idx, score_idx], (0, 1, 3))], axis=0 # type: ignore - numpy ix_ typing
+ )
+
+ tp = [
+ (
+ index_to_groundtruth_id[row[1]],
+ index_to_prediction_id[row[2]],
+ )
+ for row in unique_tp
+ ]
+ fp = [index_to_prediction_id[row[1]] for row in unique_fp]
+ fn = [index_to_groundtruth_id[row[1]] for row in unique_fn]
+ metrics.append(
+ Metric.examples(
+ datum_id=datum_id,
+ true_positives=tp,
+ false_negatives=fn,
+ false_positives=fp,
+ iou_threshold=iou_thresh,
+ score_threshold=score_thresh,
+ )
+ )
+ return metrics
+
+
+ def create_empty_confusion_matrix_with_examples(
+ iou_threhsold: float,
+ score_threshold: float,
+ index_to_label: dict[int, str],
+ ) -> Metric:
+ unmatched_groundtruths = dict()
  unmatched_predictions = dict()
  confusion_matrix = dict()
- for label in index_to_labels:
- unmatched_ground_truths[label] = {"count": 0, "examples": []}
+ for label in index_to_label.values():
+ unmatched_groundtruths[label] = {"count": 0, "examples": []}
  unmatched_predictions[label] = {"count": 0, "examples": []}
  confusion_matrix[label] = {}
- for plabel in index_to_labels:
+ for plabel in index_to_label.values():
  confusion_matrix[label][plabel] = {"count": 0, "examples": []}
- return (
- confusion_matrix,
- unmatched_predictions,
- unmatched_ground_truths,
+
+ return Metric.confusion_matrix_with_examples(
+ confusion_matrix=confusion_matrix,
+ unmatched_ground_truths=unmatched_groundtruths,
+ unmatched_predictions=unmatched_predictions,
+ iou_threshold=iou_threhsold,
+ score_threshold=score_threshold,
  )
 
 
- def _unpack_confusion_matrix(
+ def _unpack_confusion_matrix_with_examples(
+ metric: Metric,
  ids: NDArray[np.int32],
  mask_matched: NDArray[np.bool_],
  mask_fp_unmatched: NDArray[np.bool_],
  mask_fn_unmatched: NDArray[np.bool_],
- index_to_datum_id: list[str],
- index_to_groundtruth_id: list[str],
- index_to_prediction_id: list[str],
- index_to_label: list[str],
- iou_threhsold: float,
- score_threshold: float,
+ index_to_datum_id: dict[int, str],
+ index_to_groundtruth_id: dict[int, str],
+ index_to_prediction_id: dict[int, str],
+ index_to_label: dict[int, str],
  ):
- (
- confusion_matrix,
- unmatched_predictions,
- unmatched_ground_truths,
- ) = _create_empty_confusion_matrix(index_to_label)
+ if not isinstance(metric.value, dict):
+ raise TypeError("expected metric to contain a dictionary value")
 
  unique_matches = np.unique(
  ids[np.ix_(mask_matched, (0, 1, 2, 3, 4))], axis=0 # type: ignore - numpy ix_ typing
@@ -220,8 +316,8 @@ def _unpack_confusion_matrix(
  for idx in range(n_max):
  if idx < n_unmatched_groundtruths:
  label = index_to_label[unique_unmatched_groundtruths[idx, 2]]
- unmatched_ground_truths[label]["count"] += 1
- unmatched_ground_truths[label]["examples"].append(
+ metric.value["unmatched_ground_truths"][label]["count"] += 1
+ metric.value["unmatched_ground_truths"][label]["examples"].append(
  {
  "datum_id": index_to_datum_id[
  unique_unmatched_groundtruths[idx, 0]
@@ -232,9 +328,10 @@ def _unpack_confusion_matrix(
  }
  )
  if idx < n_unmatched_predictions:
- label = index_to_label[unique_unmatched_predictions[idx, 2]]
- unmatched_predictions[label]["count"] += 1
- unmatched_predictions[label]["examples"].append(
+ label_id = unique_unmatched_predictions[idx, 2]
+ label = index_to_label[label_id]
+ metric.value["unmatched_predictions"][label]["count"] += 1
+ metric.value["unmatched_predictions"][label]["examples"].append(
  {
  "datum_id": index_to_datum_id[
  unique_unmatched_predictions[idx, 0]
@@ -247,8 +344,10 @@ def _unpack_confusion_matrix(
  if idx < n_matched:
  glabel = index_to_label[unique_matches[idx, 3]]
  plabel = index_to_label[unique_matches[idx, 4]]
- confusion_matrix[glabel][plabel]["count"] += 1
- confusion_matrix[glabel][plabel]["examples"].append(
+ metric.value["confusion_matrix"][glabel][plabel]["count"] += 1
+ metric.value["confusion_matrix"][glabel][plabel][
+ "examples"
+ ].append(
  {
  "datum_id": index_to_datum_id[unique_matches[idx, 0]],
  "ground_truth_id": index_to_groundtruth_id[
@@ -260,43 +359,29 @@ def _unpack_confusion_matrix(
  }
  )
 
- return Metric.confusion_matrix(
- confusion_matrix=confusion_matrix,
- unmatched_ground_truths=unmatched_ground_truths,
- unmatched_predictions=unmatched_predictions,
- iou_threshold=iou_threhsold,
- score_threshold=score_threshold,
- )
+ return metric
 
 
- def unpack_confusion_matrix_into_metric_list(
- results: NDArray[np.uint8],
+ def unpack_confusion_matrix_with_examples(
+ metrics: dict[int, dict[int, Metric]],
  detailed_pairs: NDArray[np.float64],
- iou_thresholds: list[float],
- score_thresholds: list[float],
- index_to_datum_id: list[str],
- index_to_groundtruth_id: list[str],
- index_to_prediction_id: list[str],
- index_to_label: list[str],
+ mask_tp: NDArray[np.bool_],
+ mask_fp_fn_misclf: NDArray[np.bool_],
+ mask_fp_unmatched: NDArray[np.bool_],
+ mask_fn_unmatched: NDArray[np.bool_],
+ index_to_datum_id: dict[int, str],
+ index_to_groundtruth_id: dict[int, str],
+ index_to_prediction_id: dict[int, str],
+ index_to_label: dict[int, str],
  ) -> list[Metric]:
 
  ids = detailed_pairs[:, :5].astype(np.int32)
 
- mask_matched = (
- np.bitwise_and(
- results, PairClassification.TP | PairClassification.FP_FN_MISCLF
- )
- > 0
- )
- mask_fp_unmatched = (
- np.bitwise_and(results, PairClassification.FP_UNMATCHED) > 0
- )
- mask_fn_unmatched = (
- np.bitwise_and(results, PairClassification.FN_UNMATCHED) > 0
- )
+ mask_matched = mask_tp | mask_fp_fn_misclf
 
  return [
- _unpack_confusion_matrix(
+ _unpack_confusion_matrix_with_examples(
+ metric=metric,
  ids=ids,
  mask_matched=mask_matched[iou_idx, score_idx],
  mask_fp_unmatched=mask_fp_unmatched[iou_idx, score_idx],
@@ -305,10 +390,7 @@ def unpack_confusion_matrix_into_metric_list(
  index_to_groundtruth_id=index_to_groundtruth_id,
  index_to_prediction_id=index_to_prediction_id,
  index_to_label=index_to_label,
- iou_threhsold=iou_threshold,
- score_threshold=score_threshold,
  )
- for iou_idx, iou_threshold in enumerate(iou_thresholds)
- for score_idx, score_threshold in enumerate(score_thresholds)
- if (results[iou_idx, score_idx] != -1).any()
+ for iou_idx, inner in metrics.items()
+ for score_idx, metric in inner.items()
  ]
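
The new unpack_examples helper above reduces each datum's detailed pairs to unique identifier rows by combining a boolean mask with np.ix_ before looking up datum, ground-truth, and prediction ids. A standalone numpy sketch of that indexing pattern (the toy ids table, column meanings, and mask below are invented for illustration and are not part of the package):

    import numpy as np

    # hypothetical pair table: columns (datum, groundtruth, prediction, gt_label, pd_label)
    ids = np.array(
        [
            [0, 0, 0, 2, 2],
            [0, 0, 0, 2, 2],  # duplicate pair, collapsed by np.unique
            [0, 1, 2, 1, 1],
        ],
        dtype=np.int64,
    )
    mask_tp = np.array([True, True, False])  # rows treated as true positives

    # np.ix_ turns the boolean row mask and the column tuple into an open mesh,
    # selecting only the masked rows restricted to the listed columns
    unique_tp = np.unique(ids[np.ix_(mask_tp, (0, 1, 2, 3, 4))], axis=0)
    print(unique_tp)  # [[0 0 0 2 2]]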

valor_lite/semantic_segmentation/__init__.py
@@ -1,14 +1,15 @@
  from .annotation import Bitmask, Segmentation
- from .manager import DataLoader, Evaluator, Filter, Metadata
+ from .evaluator import Builder, Evaluator, EvaluatorInfo
+ from .loader import Loader
  from .metric import Metric, MetricType
 
  __all__ = [
- "DataLoader",
+ "Builder",
+ "Loader",
  "Evaluator",
  "Segmentation",
  "Bitmask",
  "Metric",
  "MetricType",
- "Filter",
- "Metadata",
+ "EvaluatorInfo",
  ]
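
The __init__ hunk above swaps the old DataLoader/Filter/Metadata exports for Loader, Builder, and EvaluatorInfo. A minimal import sketch using only the names re-exported in the updated __all__ (how these classes are wired together is not shown in this diff):

    from valor_lite.semantic_segmentation import (
        Bitmask,
        Builder,
        Evaluator,
        EvaluatorInfo,
        Loader,
        Metric,
        MetricType,
        Segmentation,
    )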

valor_lite/semantic_segmentation/annotation.py
@@ -1,5 +1,6 @@
  import warnings
  from dataclasses import dataclass, field
+ from typing import Any
 
  import numpy as np
  from numpy.typing import NDArray
@@ -16,6 +17,8 @@ class Bitmask:
  A NumPy array of boolean values representing the mask.
  label : str
  The semantic label associated with the mask.
+ metadata : dict[str, Any], optional
+ A dictionary containing any metadata to be used within filtering operations.
 
  Examples
  --------
@@ -26,6 +29,7 @@ class Bitmask:
 
  mask: NDArray[np.bool_]
  label: str
+ metadata: dict[str, Any] | None = None
 
  def __post_init__(self):
  if self.mask.dtype != np.bool_:
@@ -51,6 +55,8 @@ class Segmentation:
  The shape of the segmentation masks. This is set automatically after initialization.
  size : int, optional
  The total number of pixels in the masks. This is set automatically after initialization.
+ metadata : dict[str, Any], optional
+ A dictionary containing any metadata to be used within filtering operations.
 
  Examples
  --------
@@ -71,6 +77,7 @@ class Segmentation:
  predictions: list[Bitmask]
  shape: tuple[int, ...]
  size: int = field(default=0)
+ metadata: dict[str, Any] | None = None
 
  def __post_init__(self):
 
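
Both Bitmask and Segmentation gain an optional metadata dict intended for filtering operations. A small sketch constructing a Bitmask with the new field (the mask contents, label, and metadata keys are arbitrary examples, not package defaults):

    import numpy as np

    from valor_lite.semantic_segmentation import Bitmask

    # the mask must be boolean; __post_init__ rejects other dtypes per the check above
    mask = np.zeros((32, 32), dtype=np.bool_)
    mask[8:16, 8:16] = True

    bitmask = Bitmask(
        mask=mask,
        label="vegetation",
        metadata={"region": "north"},
    )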

valor_lite/semantic_segmentation/computation.py
@@ -2,93 +2,13 @@ import numpy as np
  from numpy.typing import NDArray
 
 
- def compute_label_metadata(
- confusion_matrices: NDArray[np.int64],
- n_labels: int,
- ) -> NDArray[np.int64]:
- """
- Computes label metadata returning a count of annotations per label.
-
- Parameters
- ----------
- confusion_matrices : NDArray[np.int64]
- Confusion matrices per datum with shape (n_datums, n_labels + 1, n_labels + 1).
- n_labels : int
- The total number of unique labels.
-
- Returns
- -------
- NDArray[np.int64]
- The label metadata array with shape (n_labels, 2).
- Index 0 - Ground truth label count
- Index 1 - Prediction label count
- """
- label_metadata = np.zeros((n_labels, 2), dtype=np.int64)
- label_metadata[:, 0] = confusion_matrices[:, 1:, :].sum(axis=(0, 2))
- label_metadata[:, 1] = confusion_matrices[:, :, 1:].sum(axis=(0, 1))
- return label_metadata
-
-
- def filter_cache(
- confusion_matrices: NDArray[np.int64],
- datum_mask: NDArray[np.bool_],
- label_mask: NDArray[np.bool_],
- number_of_labels: int,
- ) -> tuple[NDArray[np.int64], NDArray[np.int64]]:
- """
- Performs the filter operation over the internal cache.
-
- Parameters
- ----------
- confusion_matrices : NDArray[int64]
- The internal evaluator cache.
- datum_mask : NDArray[bool]
- A mask that filters out datums.
- datum_mask : NDArray[bool]
- A mask that filters out labels.
-
- Returns
- -------
- NDArray[int64]
- Filtered confusion matrices.
- NDArray[int64]
- Filtered label metadata.
- """
- if label_mask.any():
- # add filtered labels to background
- null_predictions = confusion_matrices[:, label_mask, :].sum(
- axis=(1, 2)
- )
- null_groundtruths = confusion_matrices[:, :, label_mask].sum(
- axis=(1, 2)
- )
- null_intersection = (
- confusion_matrices[:, label_mask, label_mask]
- .reshape(confusion_matrices.shape[0], -1)
- .sum(axis=1)
- )
- confusion_matrices[:, 0, 0] += (
- null_groundtruths + null_predictions - null_intersection
- )
- confusion_matrices[:, label_mask, :] = 0
- confusion_matrices[:, :, label_mask] = 0
-
- confusion_matrices = confusion_matrices[datum_mask]
-
- label_metadata = compute_label_metadata(
- confusion_matrices=confusion_matrices,
- n_labels=number_of_labels,
- )
- return confusion_matrices, label_metadata
-
-
- def compute_intermediate_confusion_matrices(
+ def compute_intermediates(
  groundtruths: NDArray[np.bool_],
  predictions: NDArray[np.bool_],
  groundtruth_labels: NDArray[np.int64],
  prediction_labels: NDArray[np.int64],
  n_labels: int,
- ) -> NDArray[np.int64]:
+ ) -> NDArray[np.uint64]:
  """
  Computes an intermediate confusion matrix containing label counts.
 
@@ -99,15 +19,15 @@ def compute_intermediate_confusion_matrices(
  predictions : NDArray[np.bool_]
  A 2-D array containing flattened bitmasks for each label.
  groundtruth_labels : NDArray[np.int64]
- A 1-D array containing label indices.
- groundtruth_labels : NDArray[np.int64]
- A 1-D array containing label indices.
+ A 1-D array containing ground truth label indices.
+ prediction_labels : NDArray[np.int64]
+ A 1-D array containing prediction label indices.
  n_labels : int
  The number of unique labels.
 
  Returns
  -------
- NDArray[np.int64]
+ NDArray[np.uint64]
  A 2-D confusion matrix with shape (n_labels + 1, n_labels + 1).
  """
 
@@ -125,7 +45,7 @@ def compute_intermediate_confusion_matrices(
  intersected_groundtruth_counts = intersection_counts.sum(axis=1)
  intersected_prediction_counts = intersection_counts.sum(axis=0)
 
- confusion_matrix = np.zeros((n_labels + 1, n_labels + 1), dtype=np.int64)
+ confusion_matrix = np.zeros((n_labels + 1, n_labels + 1), dtype=np.uint64)
  confusion_matrix[0, 0] = background_counts
  confusion_matrix[
  np.ix_(groundtruth_labels + 1, prediction_labels + 1)
@@ -136,14 +56,11 @@ def compute_intermediate_confusion_matrices(
  confusion_matrix[groundtruth_labels + 1, 0] = (
  groundtruth_counts - intersected_groundtruth_counts
  )
-
  return confusion_matrix
 
 
  def compute_metrics(
- confusion_matrices: NDArray[np.int64],
- label_metadata: NDArray[np.int64],
- n_pixels: int,
+ confusion_matrix: NDArray[np.uint64],
  ) -> tuple[
  NDArray[np.float64],
  NDArray[np.float64],
@@ -156,16 +73,10 @@ def compute_metrics(
  """
  Computes semantic segmentation metrics.
 
- Takes data with shape (3, N).
-
  Parameters
  ----------
- confusion_matrices : NDArray[np.int64]
- A 3-D array containing confusion matrices for each datum with shape (n_datums, n_labels + 1, n_labels + 1).
- label_metadata : NDArray[np.int64]
- A 2-D array containing label metadata with shape (n_labels, 2).
- Index 0: Ground Truth Label Count
- Index 1: Prediction Label Count
+ counts : NDArray[np.uint64]
+ A 2-D confusion matrix with shape (n_labels + 1, n_labels + 1).
 
  Returns
  -------
@@ -184,14 +95,13 @@ def compute_metrics(
  NDArray[np.float64]
  Unmatched ground truth ratios.
  """
- n_labels = label_metadata.shape[0]
- gt_counts = label_metadata[:, 0]
- pd_counts = label_metadata[:, 1]
-
- counts = confusion_matrices.sum(axis=0)
+ n_labels = confusion_matrix.shape[0] - 1
+ n_pixels = confusion_matrix.sum()
+ gt_counts = confusion_matrix[1:, :].sum(axis=1)
+ pd_counts = confusion_matrix[:, 1:].sum(axis=0)
 
  # compute iou, unmatched_ground_truth and unmatched predictions
- intersection_ = counts[1:, 1:]
+ intersection_ = confusion_matrix[1:, 1:]
  union_ = (
  gt_counts[:, np.newaxis] + pd_counts[np.newaxis, :] - intersection_
  )
@@ -206,7 +116,7 @@ def compute_metrics(
 
  unmatched_prediction_ratio = np.zeros((n_labels), dtype=np.float64)
  np.divide(
- counts[0, 1:],
+ confusion_matrix[0, 1:],
  pd_counts,
  where=pd_counts > 1e-9,
  out=unmatched_prediction_ratio,
@@ -214,14 +124,14 @@ def compute_metrics(
 
  unmatched_ground_truth_ratio = np.zeros((n_labels), dtype=np.float64)
  np.divide(
- counts[1:, 0],
+ confusion_matrix[1:, 0],
  gt_counts,
  where=gt_counts > 1e-9,
  out=unmatched_ground_truth_ratio,
  )
 
  # compute precision, recall, f1
- tp_counts = counts.diagonal()[1:]
+ tp_counts = confusion_matrix.diagonal()[1:]
 
  precision = np.zeros(n_labels, dtype=np.float64)
  np.divide(tp_counts, pd_counts, where=pd_counts > 1e-9, out=precision)
@@ -238,8 +148,8 @@ def compute_metrics(
  )
 
  # compute accuracy
- tp_count = counts[1:, 1:].diagonal().sum()
- background_count = counts[0, 0]
+ tp_count = confusion_matrix[1:, 1:].diagonal().sum()
+ background_count = confusion_matrix[0, 0]
  accuracy = (
  (tp_count + background_count) / n_pixels if n_pixels > 0 else 0.0
  )
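
The reworked compute_metrics above consumes a single (n_labels + 1, n_labels + 1) confusion matrix whose row/column 0 is background, instead of per-datum matrices plus separate label metadata. A standalone numpy sketch of the same per-label arithmetic (the 2-label matrix is invented for illustration, and the zero-denominator guards that the package applies via np.divide(where=...) are omitted):

    import numpy as np

    # rows = ground truth, columns = prediction; index 0 is the background slot
    cm = np.array(
        [
            [50, 2, 1],
            [3, 40, 4],
            [1, 5, 30],
        ],
        dtype=np.uint64,
    )

    n_labels = cm.shape[0] - 1
    n_pixels = cm.sum()
    gt_counts = cm[1:, :].sum(axis=1).astype(np.float64)  # ground-truth pixels per label
    pd_counts = cm[:, 1:].sum(axis=0).astype(np.float64)  # predicted pixels per label

    tp = cm[1:, 1:].diagonal().astype(np.float64)
    precision = tp / pd_counts
    recall = tp / gt_counts
    f1 = 2 * precision * recall / (precision + recall)

    # per-label IoU along the diagonal of the intersection block
    iou = tp / (gt_counts + pd_counts - tp)

    # unmatched ratios come from the background column and row
    unmatched_prediction_ratio = cm[0, 1:] / pd_counts
    unmatched_ground_truth_ratio = cm[1:, 0] / gt_counts

    accuracy = (tp.sum() + cm[0, 0]) / n_pixels
    print(precision, recall, f1, iou, accuracy)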