valor-lite 0.34.3__py3-none-any.whl → 0.36.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.

This release of valor-lite was flagged as potentially problematic.

@@ -1,18 +1,25 @@
-from collections import defaultdict
-from dataclasses import dataclass
+import warnings
+from dataclasses import asdict, dataclass
 
 import numpy as np
 from numpy.typing import NDArray
 from tqdm import tqdm
 
-from valor_lite.object_detection.annotation import Detection
+from valor_lite.object_detection.annotation import (
+    Bitmask,
+    BoundingBox,
+    Detection,
+    Polygon,
+)
 from valor_lite.object_detection.computation import (
     compute_bbox_iou,
     compute_bitmask_iou,
     compute_confusion_matrix,
+    compute_label_metadata,
     compute_polygon_iou,
     compute_precion_recall,
-    compute_ranked_pairs,
+    filter_cache,
+    rank_pairs,
 )
 from valor_lite.object_detection.metric import Metric, MetricType
 from valor_lite.object_detection.utilities import (
@@ -41,11 +48,51 @@ filtered_metrics = evaluator.evaluate(iou_thresholds=[0.5], filter_mask=filter_m
 """
 
 
+@dataclass
+class Metadata:
+    number_of_datums: int = 0
+    number_of_ground_truths: int = 0
+    number_of_predictions: int = 0
+    number_of_labels: int = 0
+
+    @classmethod
+    def create(
+        cls,
+        detailed_pairs: NDArray[np.float64],
+        number_of_datums: int,
+        number_of_labels: int,
+    ):
+        # count number of ground truths
+        mask_valid_gts = detailed_pairs[:, 1] >= 0
+        unique_ids = np.unique(
+            detailed_pairs[np.ix_(mask_valid_gts, (0, 1))], axis=0  # type: ignore - np.ix_ typing
+        )
+        number_of_ground_truths = int(unique_ids.shape[0])
+
+        # count number of predictions
+        mask_valid_pds = detailed_pairs[:, 2] >= 0
+        unique_ids = np.unique(
+            detailed_pairs[np.ix_(mask_valid_pds, (0, 2))], axis=0  # type: ignore - np.ix_ typing
+        )
+        number_of_predictions = int(unique_ids.shape[0])
+
+        return cls(
+            number_of_datums=number_of_datums,
+            number_of_ground_truths=number_of_ground_truths,
+            number_of_predictions=number_of_predictions,
+            number_of_labels=number_of_labels,
+        )
+
+    def to_dict(self) -> dict[str, int | bool]:
+        return asdict(self)
+
+
 @dataclass
 class Filter:
-    ranked_indices: NDArray[np.intp]
-    detailed_indices: NDArray[np.intp]
-    label_metadata: NDArray[np.int32]
+    mask_datums: NDArray[np.bool_]
+    mask_groundtruths: NDArray[np.bool_]
+    mask_predictions: NDArray[np.bool_]
+    metadata: Metadata
 
 
 class Evaluator:
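`Metadata.create` derives its ground truth and prediction counts by deduplicating (datum index, annotation index) pairs, so an annotation that appears in several candidate pairs is only counted once. A minimal sketch of that counting with plain NumPy (the toy array below is illustrative, not taken from the package):

```python
import numpy as np

# Toy detailed-pairs array; columns 0-2 hold (datum, ground truth, prediction)
# indices, with -1 marking a missing ground truth or prediction.
detailed_pairs = np.array(
    [
        [0.0, 0.0, 0.0, 0.0, 0.0, 0.9, 0.8],     # matched pair
        [0.0, 1.0, -1.0, 1.0, -1.0, 0.0, -1.0],  # unmatched ground truth
        [1.0, -1.0, 0.0, -1.0, 1.0, 0.0, 0.5],   # unmatched prediction
    ]
)

# Same counting as Metadata.create: unique (datum, annotation) pairs among
# rows whose annotation index is non-negative.
mask_valid_gts = detailed_pairs[:, 1] >= 0
n_gts = np.unique(detailed_pairs[np.ix_(mask_valid_gts, (0, 1))], axis=0).shape[0]
mask_valid_pds = detailed_pairs[:, 2] >= 0
n_pds = np.unique(detailed_pairs[np.ix_(mask_valid_pds, (0, 2))], axis=0).shape[0]
assert (n_gts, n_pds) == (2, 2)
```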
@@ -55,29 +102,25 @@ class Evaluator:
 
     def __init__(self):
 
-        # metadata
-        self.n_datums = 0
-        self.n_groundtruths = 0
-        self.n_predictions = 0
-        self.n_labels = 0
+        # external reference
+        self.datum_id_to_index: dict[str, int] = {}
+        self.groundtruth_id_to_index: dict[str, int] = {}
+        self.prediction_id_to_index: dict[str, int] = {}
+        self.label_to_index: dict[str, int] = {}
 
-        # datum reference
-        self.uid_to_index: dict[str, int] = dict()
-        self.index_to_uid: dict[int, str] = dict()
+        self.index_to_datum_id: list[str] = []
+        self.index_to_groundtruth_id: list[str] = []
+        self.index_to_prediction_id: list[str] = []
+        self.index_to_label: list[str] = []
 
-        # annotation reference
-        self.groundtruth_examples: dict[int, NDArray[np.float16]] = dict()
-        self.prediction_examples: dict[int, NDArray[np.float16]] = dict()
+        # temporary cache
+        self._temp_cache: list[NDArray[np.float64]] | None = []
 
-        # label reference
-        self.label_to_index: dict[str, int] = dict()
-        self.index_to_label: dict[int, str] = dict()
-
-        # computation caches
-        self._detailed_pairs: NDArray[np.float64] = np.array([])
-        self._ranked_pairs: NDArray[np.float64] = np.array([])
-        self._label_metadata: NDArray[np.int32] = np.array([])
-        self._label_metadata_per_datum: NDArray[np.int32] = np.array([])
+        # internal cache
+        self._detailed_pairs = np.array([[]], dtype=np.float64)
+        self._ranked_pairs = np.array([[]], dtype=np.float64)
+        self._label_metadata: NDArray[np.int32] = np.array([[]])
+        self._metadata = Metadata()
 
     @property
     def ignored_prediction_labels(self) -> list[str]:
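Read together with `_add_data` and `finalize()` further down, `_temp_cache` doubles as a finalization sentinel: it starts as an empty list, accumulates one pair array per datum during ingestion, and becomes `None` once the arrays are concatenated into `_detailed_pairs`. A sketch of the implied lifecycle (ingestion call elided):

```python
evaluator = Evaluator()
assert evaluator._temp_cache == []  # accepting data

# ... evaluator.add_bounding_boxes(...) appends one pair array per datum ...

evaluator.finalize()  # concatenates the cache, then sets it to None
assert evaluator._temp_cache is None

evaluator.finalize()  # a second call only warns and returns self,
                      # and _add_data now raises RuntimeError for any non-empty batch
```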
@@ -102,273 +145,81 @@ class Evaluator:
         ]
 
     @property
-    def metadata(self) -> dict:
+    def metadata(self) -> Metadata:
         """
         Evaluation metadata.
         """
-        return {
-            "n_datums": self.n_datums,
-            "n_groundtruths": self.n_groundtruths,
-            "n_predictions": self.n_predictions,
-            "n_labels": self.n_labels,
-            "ignored_prediction_labels": self.ignored_prediction_labels,
-            "missing_prediction_labels": self.missing_prediction_labels,
-        }
-
-    def create_filter(
-        self,
-        datum_uids: list[str] | NDArray[np.int32] | None = None,
-        labels: list[str] | NDArray[np.int32] | None = None,
-    ) -> Filter:
-        """
-        Creates a filter that can be passed to an evaluation.
+        return self._metadata
 
-        Parameters
-        ----------
-        datum_uids : list[str] | NDArray[np.int32], optional
-            An optional list of string uids or a numpy array of uid indices.
-        labels : list[str] | NDArray[np.int32], optional
-            An optional list of labels or a numpy array of label indices.
-
-        Returns
-        -------
-        Filter
-            A filter object that can be passed to the `evaluate` method.
+    def _add_datum(self, datum_id: str) -> int:
         """
-
-        n_datums = self._label_metadata_per_datum.shape[1]
-        n_labels = self._label_metadata_per_datum.shape[2]
-
-        mask_ranked = np.ones((self._ranked_pairs.shape[0], 1), dtype=np.bool_)
-        mask_detailed = np.ones(
-            (self._detailed_pairs.shape[0], 1), dtype=np.bool_
-        )
-        mask_datums = np.ones(n_datums, dtype=np.bool_)
-        mask_labels = np.ones(n_labels, dtype=np.bool_)
-
-        if datum_uids is not None:
-            if isinstance(datum_uids, list):
-                datum_uids = np.array(
-                    [self.uid_to_index[uid] for uid in datum_uids],
-                    dtype=np.int32,
-                )
-            mask_ranked[
-                ~np.isin(self._ranked_pairs[:, 0].astype(int), datum_uids)
-            ] = False
-            mask_detailed[
-                ~np.isin(self._detailed_pairs[:, 0].astype(int), datum_uids)
-            ] = False
-            mask_datums[~np.isin(np.arange(n_datums), datum_uids)] = False
-
-        if labels is not None:
-            if isinstance(labels, list):
-                labels = np.array(
-                    [self.label_to_index[label] for label in labels]
-                )
-            mask_ranked[
-                ~np.isin(self._ranked_pairs[:, 4].astype(int), labels)
-            ] = False
-            mask_detailed[
-                ~np.isin(self._detailed_pairs[:, 4].astype(int), labels)
-            ] = False
-            mask_labels[~np.isin(np.arange(n_labels), labels)] = False
-
-        mask_label_metadata = (
-            mask_datums[:, np.newaxis] & mask_labels[np.newaxis, :]
-        )
-        label_metadata_per_datum = self._label_metadata_per_datum.copy()
-        label_metadata_per_datum[:, ~mask_label_metadata] = 0
-
-        label_metadata = np.zeros_like(self._label_metadata, dtype=np.int32)
-        label_metadata = np.transpose(
-            np.sum(
-                label_metadata_per_datum,
-                axis=1,
-            )
-        )
-
-        return Filter(
-            ranked_indices=np.where(mask_ranked)[0],
-            detailed_indices=np.where(mask_detailed)[0],
-            label_metadata=label_metadata,
-        )
-
-    def compute_precision_recall(
-        self,
-        iou_thresholds: list[float],
-        score_thresholds: list[float],
-        filter_: Filter | None = None,
-    ) -> dict[MetricType, list[Metric]]:
-        """
-        Computes all metrics except for ConfusionMatrix
+        Helper function for adding a datum to the cache.
 
         Parameters
         ----------
-        iou_thresholds : list[float]
-            A list of IOU thresholds to compute metrics over.
-        score_thresholds : list[float]
-            A list of score thresholds to compute metrics over.
-        filter_ : Filter, optional
-            An optional filter object.
+        datum_id : str
+            The datum identifier.
 
         Returns
         -------
-        dict[MetricType, list]
-            A dictionary mapping MetricType enumerations to lists of computed metrics.
-        """
-
-        ranked_pairs = self._ranked_pairs
-        label_metadata = self._label_metadata
-        if filter_ is not None:
-            ranked_pairs = ranked_pairs[filter_.ranked_indices]
-            label_metadata = filter_.label_metadata
-
-        results = compute_precion_recall(
-            data=ranked_pairs,
-            label_metadata=label_metadata,
-            iou_thresholds=np.array(iou_thresholds),
-            score_thresholds=np.array(score_thresholds),
-        )
-
-        return unpack_precision_recall_into_metric_lists(
-            results=results,
-            label_metadata=label_metadata,
-            iou_thresholds=iou_thresholds,
-            score_thresholds=score_thresholds,
-            index_to_label=self.index_to_label,
-        )
-
-    def compute_confusion_matrix(
-        self,
-        iou_thresholds: list[float],
-        score_thresholds: list[float],
-        number_of_examples: int,
-        filter_: Filter | None = None,
-    ) -> list[Metric]:
-        """
-        Computes confusion matrices at various thresholds.
-
-        Parameters
-        ----------
-        iou_thresholds : list[float]
-            A list of IOU thresholds to compute metrics over.
-        score_thresholds : list[float]
-            A list of score thresholds to compute metrics over.
-        number_of_examples : int
-            Maximum number of annotation examples to return in ConfusionMatrix.
-        filter_ : Filter, optional
-            An optional filter object.
-
-        Returns
-        -------
-        list[Metric]
-            List of confusion matrices per threshold pair.
+        int
+            The datum index.
         """
-
-        detailed_pairs = self._detailed_pairs
-        label_metadata = self._label_metadata
-        if filter_ is not None:
-            detailed_pairs = detailed_pairs[filter_.detailed_indices]
-            label_metadata = filter_.label_metadata
-
-        if detailed_pairs.size == 0:
-            return list()
-
-        results = compute_confusion_matrix(
-            data=detailed_pairs,
-            label_metadata=label_metadata,
-            iou_thresholds=np.array(iou_thresholds),
-            score_thresholds=np.array(score_thresholds),
-            n_examples=number_of_examples,
-        )
-
-        return unpack_confusion_matrix_into_metric_list(
-            results=results,
-            iou_thresholds=iou_thresholds,
-            score_thresholds=score_thresholds,
-            number_of_examples=number_of_examples,
-            index_to_uid=self.index_to_uid,
-            index_to_label=self.index_to_label,
-            groundtruth_examples=self.groundtruth_examples,
-            prediction_examples=self.prediction_examples,
-        )
-
-    def evaluate(
-        self,
-        iou_thresholds: list[float] = [0.1, 0.5, 0.75],
-        score_thresholds: list[float] = [0.5],
-        number_of_examples: int = 0,
-        filter_: Filter | None = None,
-    ) -> dict[MetricType, list[Metric]]:
+        if datum_id not in self.datum_id_to_index:
+            if len(self.datum_id_to_index) != len(self.index_to_datum_id):
+                raise RuntimeError("datum cache size mismatch")
+            idx = len(self.datum_id_to_index)
+            self.datum_id_to_index[datum_id] = idx
+            self.index_to_datum_id.append(datum_id)
+        return self.datum_id_to_index[datum_id]
+
+    def _add_groundtruth(self, annotation_id: str) -> int:
         """
-        Computes all available metrics.
+        Helper function for adding a ground truth annotation identifier to the cache.
 
         Parameters
         ----------
-        iou_thresholds : list[float], default=[0.1, 0.5, 0.75]
-            A list of IOU thresholds to compute metrics over.
-        score_thresholds : list[float], default=[0.5]
-            A list of score thresholds to compute metrics over.
-        number_of_examples : int, default=0
-            Maximum number of annotation examples to return in ConfusionMatrix.
-        filter_ : Filter, optional
-            An optional filter object.
+        annotation_id : str
+            The ground truth annotation identifier.
 
         Returns
         -------
-        dict[MetricType, list[Metric]]
-            Lists of metrics organized by metric type.
+        int
+            The ground truth annotation index.
         """
-        metrics = self.compute_precision_recall(
-            iou_thresholds=iou_thresholds,
-            score_thresholds=score_thresholds,
-            filter_=filter_,
-        )
-
-        metrics[MetricType.ConfusionMatrix] = self.compute_confusion_matrix(
-            iou_thresholds=iou_thresholds,
-            score_thresholds=score_thresholds,
-            number_of_examples=number_of_examples,
-            filter_=filter_,
-        )
-
-        return metrics
-
-
-def defaultdict_int():
-    return defaultdict(int)
-
-
-class DataLoader:
-    """
-    Object Detection DataLoader
-    """
-
-    def __init__(self):
-        self._evaluator = Evaluator()
-        self.pairs: list[NDArray[np.float64]] = list()
-        self.groundtruth_count = defaultdict(defaultdict_int)
-        self.prediction_count = defaultdict(defaultdict_int)
-
-    def _add_datum(self, uid: str) -> int:
+        if annotation_id not in self.groundtruth_id_to_index:
+            if len(self.groundtruth_id_to_index) != len(
+                self.index_to_groundtruth_id
+            ):
+                raise RuntimeError("ground truth cache size mismatch")
+            idx = len(self.groundtruth_id_to_index)
+            self.groundtruth_id_to_index[annotation_id] = idx
+            self.index_to_groundtruth_id.append(annotation_id)
+        return self.groundtruth_id_to_index[annotation_id]
+
+    def _add_prediction(self, annotation_id: str) -> int:
         """
-        Helper function for adding a datum to the cache.
+        Helper function for adding a prediction annotation identifier to the cache.
 
         Parameters
         ----------
-        uid : str
-            The datum uid.
+        annotation_id : str
+            The prediction annotation identifier.
 
         Returns
         -------
         int
-            The datum index.
+            The prediction annotation index.
         """
-        if uid not in self._evaluator.uid_to_index:
-            index = len(self._evaluator.uid_to_index)
-            self._evaluator.uid_to_index[uid] = index
-            self._evaluator.index_to_uid[index] = uid
-        return self._evaluator.uid_to_index[uid]
+        if annotation_id not in self.prediction_id_to_index:
+            if len(self.prediction_id_to_index) != len(
+                self.index_to_prediction_id
+            ):
+                raise RuntimeError("prediction cache size mismatch")
+            idx = len(self.prediction_id_to_index)
+            self.prediction_id_to_index[annotation_id] = idx
+            self.index_to_prediction_id.append(annotation_id)
+        return self.prediction_id_to_index[annotation_id]
 
     def _add_label(self, label: str) -> int:
         """
@@ -384,90 +235,14 @@ class DataLoader:
         int
             Label index.
         """
-
-        label_id = len(self._evaluator.index_to_label)
-        if label not in self._evaluator.label_to_index:
-            self._evaluator.label_to_index[label] = label_id
-            self._evaluator.index_to_label[label_id] = label
-
+        label_id = len(self.index_to_label)
+        if label not in self.label_to_index:
+            if len(self.label_to_index) != len(self.index_to_label):
+                raise RuntimeError("label cache size mismatch")
+            self.label_to_index[label] = label_id
+            self.index_to_label.append(label)
             label_id += 1
-
-        return self._evaluator.label_to_index[label]
-
-    def _cache_pairs(
-        self,
-        uid_index: int,
-        groundtruths: list,
-        predictions: list,
-        ious: NDArray[np.float64],
-    ) -> None:
-        """
-        Compute IOUs between groundtruths and preditions before storing as pairs.
-
-        Parameters
-        ----------
-        uid_index : int
-            The index of the detection.
-        groundtruths : list
-            A list of groundtruths.
-        predictions : list
-            A list of predictions.
-        ious : NDArray[np.float64]
-            An array with shape (n_preds, n_gts) containing IOUs.
-        """
-
-        predictions_with_iou_of_zero = np.where((ious < 1e-9).all(axis=1))[0]
-        groundtruths_with_iou_of_zero = np.where((ious < 1e-9).all(axis=0))[0]
-
-        pairs = [
-            np.array(
-                [
-                    float(uid_index),
-                    float(gidx),
-                    float(pidx),
-                    ious[pidx, gidx],
-                    float(glabel),
-                    float(plabel),
-                    float(score),
-                ]
-            )
-            for pidx, plabel, score in predictions
-            for gidx, glabel in groundtruths
-            if ious[pidx, gidx] >= 1e-9
-        ]
-        pairs.extend(
-            [
-                np.array(
-                    [
-                        float(uid_index),
-                        -1.0,
-                        float(predictions[index][0]),
-                        0.0,
-                        -1.0,
-                        float(predictions[index][1]),
-                        float(predictions[index][2]),
-                    ]
-                )
-                for index in predictions_with_iou_of_zero
-            ]
-        )
-        pairs.extend(
-            [
-                np.array(
-                    [
-                        float(uid_index),
-                        float(groundtruths[index][0]),
-                        -1.0,
-                        0.0,
-                        float(groundtruths[index][1]),
-                        -1.0,
-                        -1.0,
-                    ]
-                )
-                for index in groundtruths_with_iou_of_zero
-            ]
-        )
-        self.pairs.append(np.array(pairs))
+        return self.label_to_index[label]
 
     def _add_data(
@@ -491,66 +266,102 @@ class DataLoader:
         for detection, ious in tqdm(
             zip(detections, detection_ious), disable=disable_tqdm
         ):
-
-            # update metadata
-            self._evaluator.n_datums += 1
-            self._evaluator.n_groundtruths += len(detection.groundtruths)
-            self._evaluator.n_predictions += len(detection.predictions)
-
-            # update datum uid index
-            uid_index = self._add_datum(uid=detection.uid)
-
-            # initialize bounding box examples
-            self._evaluator.groundtruth_examples[uid_index] = np.zeros(
-                (len(detection.groundtruths), 4), dtype=np.float16
-            )
-            self._evaluator.prediction_examples[uid_index] = np.zeros(
-                (len(detection.predictions), 4), dtype=np.float16
-            )
-
-            # cache labels and annotations
-            groundtruths = list()
-            predictions = list()
-
-            for gidx, gann in enumerate(detection.groundtruths):
-                self._evaluator.groundtruth_examples[uid_index][
-                    gidx
-                ] = gann.extrema
-                for glabel in gann.labels:
-                    label_idx = self._add_label(glabel)
-                    self.groundtruth_count[label_idx][uid_index] += 1
-                    groundtruths.append(
-                        (
-                            gidx,
-                            label_idx,
+            # cache labels and annotation pairs
+            pairs = []
+            datum_idx = self._add_datum(detection.uid)
+            if detection.groundtruths:
+                for gidx, gann in enumerate(detection.groundtruths):
+                    groundtruth_idx = self._add_groundtruth(gann.uid)
+                    glabel_idx = self._add_label(gann.labels[0])
+                    if (ious[:, gidx] < 1e-9).all():
+                        pairs.extend(
+                            [
+                                np.array(
+                                    [
+                                        float(datum_idx),
+                                        float(groundtruth_idx),
+                                        -1.0,
+                                        float(glabel_idx),
+                                        -1.0,
+                                        0.0,
+                                        -1.0,
+                                    ]
+                                )
+                            ]
                         )
+                    for pidx, pann in enumerate(detection.predictions):
+                        prediction_idx = self._add_prediction(pann.uid)
+                        if (ious[pidx, :] < 1e-9).all():
+                            pairs.extend(
+                                [
+                                    np.array(
+                                        [
+                                            float(datum_idx),
+                                            -1.0,
+                                            float(prediction_idx),
+                                            -1.0,
+                                            float(self._add_label(plabel)),
+                                            0.0,
+                                            float(pscore),
+                                        ]
+                                    )
+                                    for plabel, pscore in zip(
+                                        pann.labels, pann.scores
+                                    )
+                                ]
+                            )
+                        if ious[pidx, gidx] >= 1e-9:
+                            pairs.extend(
+                                [
+                                    np.array(
+                                        [
+                                            float(datum_idx),
+                                            float(groundtruth_idx),
+                                            float(prediction_idx),
+                                            float(self._add_label(glabel)),
+                                            float(self._add_label(plabel)),
+                                            ious[pidx, gidx],
+                                            float(pscore),
+                                        ]
+                                    )
+                                    for glabel in gann.labels
+                                    for plabel, pscore in zip(
+                                        pann.labels, pann.scores
+                                    )
+                                ]
+                            )
+            elif detection.predictions:
+                for pidx, pann in enumerate(detection.predictions):
+                    prediction_idx = self._add_prediction(pann.uid)
+                    pairs.extend(
+                        [
+                            np.array(
+                                [
+                                    float(datum_idx),
+                                    -1.0,
+                                    float(prediction_idx),
+                                    -1.0,
+                                    float(self._add_label(plabel)),
+                                    0.0,
+                                    float(pscore),
+                                ]
+                            )
+                            for plabel, pscore in zip(pann.labels, pann.scores)
+                        ]
                     )
 
-            for pidx, pann in enumerate(detection.predictions):
-                self._evaluator.prediction_examples[uid_index][
-                    pidx
-                ] = pann.extrema
-                for plabel, pscore in zip(pann.labels, pann.scores):
-                    label_idx = self._add_label(plabel)
-                    self.prediction_count[label_idx][uid_index] += 1
-                    predictions.append(
-                        (
-                            pidx,
-                            label_idx,
-                            pscore,
-                        )
+            data = np.array(pairs)
+            if data.size > 0:
+                # reset filtered cache if it exists
+                if self._temp_cache is None:
+                    raise RuntimeError(
+                        "cannot add data as evaluator has already been finalized"
                     )
-
-            self._cache_pairs(
-                uid_index=uid_index,
-                groundtruths=groundtruths,
-                predictions=predictions,
-                ious=ious,
-            )
+                self._temp_cache.append(data)
 
     def add_bounding_boxes(
         self,
-        detections: list[Detection],
+        detections: list[Detection[BoundingBox]],
         show_progress: bool = False,
     ):
         """
@@ -584,7 +395,7 @@ class DataLoader:
 
     def add_polygons(
         self,
-        detections: list[Detection],
+        detections: list[Detection[Polygon]],
         show_progress: bool = False,
     ):
         """
@@ -601,7 +412,7 @@ class DataLoader:
             compute_polygon_iou(
                 np.array(
                     [
-                        [gt.shape, pd.shape]  # type: ignore - using the AttributeError as a validator
+                        [gt.shape, pd.shape]
                         for pd in detection.predictions
                         for gt in detection.groundtruths
                     ]
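The dropped `# type: ignore` comments here and in `add_bitmasks` line up with the new `Detection[Polygon]` / `Detection[Bitmask]` parameters: once `Detection` is generic over its annotation type, `gt.shape` and `gt.mask` type-check without suppression. The real class lives in `valor_lite.object_detection.annotation` and is outside this diff; a plausible shape, for illustration only:

```python
from dataclasses import dataclass
from typing import Generic, TypeVar

T = TypeVar("T")  # e.g. BoundingBox | Polygon | Bitmask


@dataclass
class Detection(Generic[T]):
    uid: str
    groundtruths: list[T]  # annotations expose uid and labels in this diff
    predictions: list[T]   # prediction annotations also expose scores
```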
@@ -617,7 +428,7 @@ class DataLoader:
 
     def add_bitmasks(
         self,
-        detections: list[Detection],
+        detections: list[Detection[Bitmask]],
         show_progress: bool = False,
     ):
         """
@@ -634,7 +445,7 @@ class DataLoader:
             compute_bitmask_iou(
                 np.array(
                     [
-                        [gt.mask, pd.mask]  # type: ignore - using the AttributeError as a validator
+                        [gt.mask, pd.mask]
                         for pd in detection.predictions
                         for gt in detection.groundtruths
                     ]
@@ -648,7 +459,7 @@ class DataLoader:
             show_progress=show_progress,
         )
 
-    def finalize(self) -> Evaluator:
+    def finalize(self):
         """
         Performs data finalization and some preprocessing steps.
 
@@ -657,65 +468,347 @@ class DataLoader:
         Evaluator
             A ready-to-use evaluator object.
         """
+        n_labels = len(self.index_to_label)
+        n_datums = len(self.index_to_datum_id)
+        if self._temp_cache is None:
+            warnings.warn("evaluator is already finalized or in a bad state")
+            return self
+        elif not self._temp_cache:
+            self._detailed_pairs = np.array([], dtype=np.float64)
+            self._ranked_pairs = np.array([], dtype=np.float64)
+            self._label_metadata = np.zeros((n_labels, 2), dtype=np.int32)
+            self._metadata = Metadata()
+            warnings.warn("no valid pairs")
+            return self
+        else:
+            self._detailed_pairs = np.concatenate(self._temp_cache, axis=0)
+            self._temp_cache = None
+
+        # order pairs by descending score, iou
+        indices = np.lexsort(
+            (
+                -self._detailed_pairs[:, 5],  # iou
+                -self._detailed_pairs[:, 6],  # score
+            )
+        )
+        self._detailed_pairs = self._detailed_pairs[indices]
+        self._label_metadata = compute_label_metadata(
+            ids=self._detailed_pairs[:, :5].astype(np.int32),
+            n_labels=n_labels,
+        )
+        self._ranked_pairs = rank_pairs(
+            detailed_pairs=self._detailed_pairs,
+            label_metadata=self._label_metadata,
+        )
+        self._metadata = Metadata.create(
+            detailed_pairs=self._detailed_pairs,
+            number_of_datums=n_datums,
+            number_of_labels=n_labels,
+        )
+        return self
 
-        self.pairs = [pair for pair in self.pairs if pair.size > 0]
-        if len(self.pairs) == 0:
-            raise ValueError("No data available to create evaluator.")
+    def create_filter(
+        self,
+        datum_ids: list[str] | None = None,
+        groundtruth_ids: list[str] | None = None,
+        prediction_ids: list[str] | None = None,
+        labels: list[str] | None = None,
+    ) -> Filter:
+        """
+        Creates a filter object.
 
-        n_datums = self._evaluator.n_datums
-        n_labels = len(self._evaluator.index_to_label)
+        Parameters
+        ----------
+        datum_uids : list[str], optional
+            An optional list of string uids representing datums to keep.
+        groundtruth_ids : list[str], optional
+            An optional list of string uids representing ground truth annotations to keep.
+        prediction_ids : list[str], optional
+            An optional list of string uids representing prediction annotations to keep.
+        labels : list[str], optional
+            An optional list of labels to keep.
+        """
+        mask_datums = np.ones(self._detailed_pairs.shape[0], dtype=np.bool_)
+
+        # filter datums
+        if datum_ids is not None:
+            if not datum_ids:
+                warnings.warn("creating a filter that removes all datums")
+                return Filter(
+                    mask_datums=np.zeros_like(mask_datums),
+                    mask_groundtruths=np.array([], dtype=np.bool_),
+                    mask_predictions=np.array([], dtype=np.bool_),
+                    metadata=Metadata(),
+                )
+            valid_datum_indices = np.array(
+                [self.datum_id_to_index[uid] for uid in datum_ids],
+                dtype=np.int32,
+            )
+            mask_datums = np.isin(
+                self._detailed_pairs[:, 0], valid_datum_indices
+            )
 
-        self._evaluator.n_labels = n_labels
+        filtered_detailed_pairs = self._detailed_pairs[mask_datums]
+        n_pairs = self._detailed_pairs[mask_datums].shape[0]
+        mask_groundtruths = np.zeros(n_pairs, dtype=np.bool_)
+        mask_predictions = np.zeros_like(mask_groundtruths)
 
-        self._evaluator._label_metadata_per_datum = np.zeros(
-            (2, n_datums, n_labels), dtype=np.int32
-        )
-        for datum_idx in range(n_datums):
-            for label_idx in range(n_labels):
-                gt_count = (
-                    self.groundtruth_count[label_idx].get(datum_idx, 0)
-                    if label_idx in self.groundtruth_count
-                    else 0
+        # filter by ground truth annotation ids
+        if groundtruth_ids is not None:
+            if not groundtruth_ids:
+                warnings.warn(
+                    "creating a filter that removes all ground truths"
                 )
-                pd_count = (
-                    self.prediction_count[label_idx].get(datum_idx, 0)
-                    if label_idx in self.prediction_count
-                    else 0
+            valid_groundtruth_indices = np.array(
+                [self.groundtruth_id_to_index[uid] for uid in groundtruth_ids],
+                dtype=np.int32,
+            )
+            mask_groundtruths[
+                ~np.isin(
+                    filtered_detailed_pairs[:, 1],
+                    valid_groundtruth_indices,
                 )
-                self._evaluator._label_metadata_per_datum[
-                    :, datum_idx, label_idx
-                ] = np.array([gt_count, pd_count])
-
-        self._evaluator._label_metadata = np.array(
-            [
-                [
-                    float(
-                        np.sum(
-                            self._evaluator._label_metadata_per_datum[
-                                0, :, label_idx
-                            ]
-                        )
-                    ),
-                    float(
-                        np.sum(
-                            self._evaluator._label_metadata_per_datum[
-                                1, :, label_idx
-                            ]
-                        )
-                    ),
-                ]
-                for label_idx in range(n_labels)
-            ]
+            ] = True
+
+        # filter by prediction annotation ids
+        if prediction_ids is not None:
+            if not prediction_ids:
+                warnings.warn("creating a filter that removes all predictions")
+            valid_prediction_indices = np.array(
+                [self.prediction_id_to_index[uid] for uid in prediction_ids],
+                dtype=np.int32,
+            )
+            mask_predictions[
+                ~np.isin(
+                    filtered_detailed_pairs[:, 2],
+                    valid_prediction_indices,
+                )
+            ] = True
+
+        # filter by labels
+        if labels is not None:
+            if not labels:
+                warnings.warn("creating a filter that removes all labels")
+                return Filter(
+                    mask_datums=mask_datums,
+                    mask_groundtruths=np.ones_like(mask_datums),
+                    mask_predictions=np.ones_like(mask_datums),
+                    metadata=Metadata(),
+                )
+            valid_label_indices = np.array(
+                [self.label_to_index[label] for label in labels] + [-1]
+            )
+            mask_groundtruths[
+                ~np.isin(filtered_detailed_pairs[:, 3], valid_label_indices)
+            ] = True
+            mask_predictions[
+                ~np.isin(filtered_detailed_pairs[:, 4], valid_label_indices)
+            ] = True
+
+        filtered_detailed_pairs, _, _ = filter_cache(
+            self._detailed_pairs,
+            mask_datums=mask_datums,
+            mask_ground_truths=mask_groundtruths,
+            mask_predictions=mask_predictions,
+            n_labels=len(self.index_to_label),
         )
 
-        self._evaluator._detailed_pairs = np.concatenate(
-            self.pairs,
-            axis=0,
+        number_of_datums = (
+            len(datum_ids)
+            if datum_ids
+            else np.unique(filtered_detailed_pairs[:, 0]).size
         )
 
-        self._evaluator._ranked_pairs = compute_ranked_pairs(
-            self.pairs,
-            label_metadata=self._evaluator._label_metadata,
+        return Filter(
+            mask_datums=mask_datums,
+            mask_groundtruths=mask_groundtruths,
+            mask_predictions=mask_predictions,
+            metadata=Metadata.create(
+                detailed_pairs=filtered_detailed_pairs,
+                number_of_datums=number_of_datums,
+                number_of_labels=len(self.index_to_label),
+            ),
         )
 
-        return self._evaluator
+    def filter(
+        self, filter_: Filter
+    ) -> tuple[NDArray[np.float64], NDArray[np.float64], NDArray[np.int32],]:
+        """
+        Performs filtering over the internal cache.
+
+        Parameters
+        ----------
+        filter_ : Filter
+            The filter parameterization.
+
+        Returns
+        -------
+        NDArray[float64]
+            Filtered detailed pairs.
+        NDArray[float64]
+            Filtered ranked pairs.
+        NDArray[int32]
+            Label metadata.
+        """
+        if not filter_.mask_datums.any():
+            warnings.warn("filter removed all datums")
+            return (
+                np.array([], dtype=np.float64),
+                np.array([], dtype=np.float64),
+                np.zeros((self.metadata.number_of_labels, 2), dtype=np.int32),
+            )
+        if filter_.mask_groundtruths.all():
+            warnings.warn("filter removed all ground truths")
+        if filter_.mask_predictions.all():
+            warnings.warn("filter removed all predictions")
+        return filter_cache(
+            detailed_pairs=self._detailed_pairs,
+            mask_datums=filter_.mask_datums,
+            mask_ground_truths=filter_.mask_groundtruths,
+            mask_predictions=filter_.mask_predictions,
+            n_labels=len(self.index_to_label),
+        )
+
+    def compute_precision_recall(
+        self,
+        iou_thresholds: list[float],
+        score_thresholds: list[float],
+        filter_: Filter | None = None,
+    ) -> dict[MetricType, list[Metric]]:
+        """
+        Computes all metrics except for ConfusionMatrix
+
+        Parameters
+        ----------
+        iou_thresholds : list[float]
+            A list of IOU thresholds to compute metrics over.
+        score_thresholds : list[float]
+            A list of score thresholds to compute metrics over.
+        filter_ : Filter, optional
+            A collection of filter parameters and masks.
+
+        Returns
+        -------
+        dict[MetricType, list]
+            A dictionary mapping MetricType enumerations to lists of computed metrics.
+        """
+        if not iou_thresholds:
+            raise ValueError("At least one IOU threshold must be passed.")
+        elif not score_thresholds:
+            raise ValueError("At least one score threshold must be passed.")
+
+        if filter_ is not None:
+            _, ranked_pairs, label_metadata = self.filter(filter_=filter_)
+        else:
+            ranked_pairs = self._ranked_pairs
+            label_metadata = self._label_metadata
+
+        results = compute_precion_recall(
+            ranked_pairs=ranked_pairs,
+            label_metadata=label_metadata,
+            iou_thresholds=np.array(iou_thresholds),
+            score_thresholds=np.array(score_thresholds),
+        )
+        return unpack_precision_recall_into_metric_lists(
+            results=results,
+            label_metadata=label_metadata,
+            iou_thresholds=iou_thresholds,
+            score_thresholds=score_thresholds,
+            index_to_label=self.index_to_label,
+        )
+
+    def compute_confusion_matrix(
+        self,
+        iou_thresholds: list[float],
+        score_thresholds: list[float],
+        filter_: Filter | None = None,
+    ) -> list[Metric]:
+        """
+        Computes confusion matrices at various thresholds.
+
+        Parameters
+        ----------
+        iou_thresholds : list[float]
+            A list of IOU thresholds to compute metrics over.
+        score_thresholds : list[float]
+            A list of score thresholds to compute metrics over.
+        filter_ : Filter, optional
+            A collection of filter parameters and masks.
+
+        Returns
+        -------
+        list[Metric]
+            List of confusion matrices per threshold pair.
+        """
+        if not iou_thresholds:
+            raise ValueError("At least one IOU threshold must be passed.")
+        elif not score_thresholds:
+            raise ValueError("At least one score threshold must be passed.")
+
+        if filter_ is not None:
+            detailed_pairs, _, _ = self.filter(filter_=filter_)
+        else:
+            detailed_pairs = self._detailed_pairs
+
+        if detailed_pairs.size == 0:
+            warnings.warn("attempted to compute over an empty set")
+            return []
+
+        results = compute_confusion_matrix(
+            detailed_pairs=detailed_pairs,
+            iou_thresholds=np.array(iou_thresholds),
+            score_thresholds=np.array(score_thresholds),
+        )
+        return unpack_confusion_matrix_into_metric_list(
+            results=results,
+            detailed_pairs=detailed_pairs,
+            iou_thresholds=iou_thresholds,
+            score_thresholds=score_thresholds,
+            index_to_datum_id=self.index_to_datum_id,
+            index_to_groundtruth_id=self.index_to_groundtruth_id,
+            index_to_prediction_id=self.index_to_prediction_id,
+            index_to_label=self.index_to_label,
+        )
+
+    def evaluate(
+        self,
+        iou_thresholds: list[float] = [0.1, 0.5, 0.75],
+        score_thresholds: list[float] = [0.5],
+        filter_: Filter | None = None,
+    ) -> dict[MetricType, list[Metric]]:
+        """
+        Computes all available metrics.
+
+        Parameters
+        ----------
+        iou_thresholds : list[float], default=[0.1, 0.5, 0.75]
+            A list of IOU thresholds to compute metrics over.
+        score_thresholds : list[float], default=[0.5]
+            A list of score thresholds to compute metrics over.
+        filter_ : Filter, optional
+            A collection of filter parameters and masks.
+
+        Returns
+        -------
+        dict[MetricType, list[Metric]]
+            Lists of metrics organized by metric type.
+        """
+        metrics = self.compute_precision_recall(
+            iou_thresholds=iou_thresholds,
+            score_thresholds=score_thresholds,
+            filter_=filter_,
+        )
+        metrics[MetricType.ConfusionMatrix] = self.compute_confusion_matrix(
+            iou_thresholds=iou_thresholds,
+            score_thresholds=score_thresholds,
+            filter_=filter_,
+        )
+        return metrics
+
+
+class DataLoader(Evaluator):
+    """
+    Used for backwards compatibility as the Evaluator now handles ingestion.
+    """
+
+    pass
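With `DataLoader` reduced to a compatibility alias, ingestion, finalization, filtering, and metric computation all live on `Evaluator`. A hedged end-to-end sketch; the import path and the `BoundingBox` and `Detection` constructor arguments are assumptions, since those definitions sit outside this diff:

```python
from valor_lite.object_detection import BoundingBox, Detection, Evaluator

# Field names below are illustrative; check the annotation module for the
# exact BoundingBox signature.
detection = Detection(
    uid="img_001",
    groundtruths=[
        BoundingBox(uid="gt0", labels=["cat"], xmin=0, xmax=10, ymin=0, ymax=10)
    ],
    predictions=[
        BoundingBox(
            uid="pd0", labels=["cat"], scores=[0.9], xmin=1, xmax=11, ymin=0, ymax=10
        )
    ],
)

evaluator = Evaluator()
evaluator.add_bounding_boxes([detection])
evaluator.finalize()  # returns self, so the call can be chained

# full run over everything ingested
metrics = evaluator.evaluate(iou_thresholds=[0.5], score_thresholds=[0.5])

# restricted run through the new mask-based Filter
filter_ = evaluator.create_filter(datum_ids=["img_001"], labels=["cat"])
filtered = evaluator.evaluate(iou_thresholds=[0.5], filter_=filter_)
```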