valor-lite 0.34.3-py3-none-any.whl → 0.35.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



@@ -1,18 +1,23 @@
-from collections import defaultdict
-from dataclasses import dataclass
+import warnings
 
 import numpy as np
 from numpy.typing import NDArray
 from tqdm import tqdm
 
-from valor_lite.object_detection.annotation import Detection
+from valor_lite.object_detection.annotation import (
+    Bitmask,
+    BoundingBox,
+    Detection,
+    Polygon,
+)
 from valor_lite.object_detection.computation import (
     compute_bbox_iou,
     compute_bitmask_iou,
     compute_confusion_matrix,
+    compute_label_metadata,
     compute_polygon_iou,
     compute_precion_recall,
-    compute_ranked_pairs,
+    rank_pairs,
 )
 from valor_lite.object_detection.metric import Metric, MetricType
 from valor_lite.object_detection.utilities import (
@@ -41,13 +46,6 @@ filtered_metrics = evaluator.evaluate(iou_thresholds=[0.5], filter_mask=filter_m
 """
 
 
-@dataclass
-class Filter:
-    ranked_indices: NDArray[np.intp]
-    detailed_indices: NDArray[np.intp]
-    label_metadata: NDArray[np.int32]
-
-
 class Evaluator:
     """
     Object Detection Evaluator
@@ -55,37 +53,94 @@ class Evaluator:
 
     def __init__(self):
 
-        # metadata
-        self.n_datums = 0
-        self.n_groundtruths = 0
-        self.n_predictions = 0
-        self.n_labels = 0
+        # external reference
+        self.datum_id_to_index: dict[str, int] = {}
+        self.groundtruth_id_to_index: dict[str, int] = {}
+        self.prediction_id_to_index: dict[str, int] = {}
+        self.label_to_index: dict[str, int] = {}
+
+        self.index_to_datum_id: list[str] = []
+        self.index_to_groundtruth_id: list[str] = []
+        self.index_to_prediction_id: list[str] = []
+        self.index_to_label: list[str] = []
+
+        # temporary cache
+        self._temp_cache: list[NDArray[np.float64]] | None = []
+
+        # cache
+        self._detailed_pairs = np.array([[]], dtype=np.float64)
+        self._ranked_pairs = np.array([[]], dtype=np.float64)
+        self._label_metadata: NDArray[np.int32] = np.array([[]])
+
+        # filter cache
+        self._filtered_detailed_pairs: NDArray[np.float64] | None = None
+        self._filtered_ranked_pairs: NDArray[np.float64] | None = None
+        self._filtered_label_metadata: NDArray[np.int32] | None = None
+
+    @property
+    def is_filtered(self) -> bool:
+        return self._filtered_detailed_pairs is not None
+
+    @property
+    def label_metadata(self) -> NDArray[np.int32]:
+        return (
+            self._filtered_label_metadata
+            if self._filtered_label_metadata is not None
+            else self._label_metadata
+        )
 
-        # datum reference
-        self.uid_to_index: dict[str, int] = dict()
-        self.index_to_uid: dict[int, str] = dict()
+    @property
+    def detailed_pairs(self) -> NDArray[np.float64]:
+        return (
+            self._filtered_detailed_pairs
+            if self._filtered_detailed_pairs is not None
+            else self._detailed_pairs
+        )
 
-        # annotation reference
-        self.groundtruth_examples: dict[int, NDArray[np.float16]] = dict()
-        self.prediction_examples: dict[int, NDArray[np.float16]] = dict()
+    @property
+    def ranked_pairs(self) -> NDArray[np.float64]:
+        return (
+            self._filtered_ranked_pairs
+            if self._filtered_ranked_pairs is not None
+            else self._ranked_pairs
+        )
 
-        # label reference
-        self.label_to_index: dict[str, int] = dict()
-        self.index_to_label: dict[int, str] = dict()
+    @property
+    def n_labels(self) -> int:
+        """Returns the total number of unique labels."""
+        return len(self.index_to_label)
 
-        # computation caches
-        self._detailed_pairs: NDArray[np.float64] = np.array([])
-        self._ranked_pairs: NDArray[np.float64] = np.array([])
-        self._label_metadata: NDArray[np.int32] = np.array([])
-        self._label_metadata_per_datum: NDArray[np.int32] = np.array([])
+    @property
+    def n_datums(self) -> int:
+        """Returns the number of datums."""
+        return np.unique(self.detailed_pairs[:, 0]).size
+
+    @property
+    def n_groundtruths(self) -> int:
+        """Returns the number of ground truth annotations."""
+        mask_valid_gts = self.detailed_pairs[:, 1] >= 0
+        unique_ids = np.unique(
+            self.detailed_pairs[np.ix_(mask_valid_gts, (0, 1))], axis=0  # type: ignore - np.ix_ typing
+        )
+        return int(unique_ids.shape[0])
+
+    @property
+    def n_predictions(self) -> int:
+        """Returns the number of prediction annotations."""
+        mask_valid_pds = self.detailed_pairs[:, 2] >= 0
+        unique_ids = np.unique(
+            self.detailed_pairs[np.ix_(mask_valid_pds, (0, 2))], axis=0  # type: ignore - np.ix_ typing
+        )
+        return int(unique_ids.shape[0])
 
     @property
     def ignored_prediction_labels(self) -> list[str]:
         """
         Prediction labels that are not present in the ground truth set.
         """
-        glabels = set(np.where(self._label_metadata[:, 0] > 0)[0])
-        plabels = set(np.where(self._label_metadata[:, 1] > 0)[0])
+        label_metadata = self.label_metadata
+        glabels = set(np.where(label_metadata[:, 0] > 0)[0])
+        plabels = set(np.where(label_metadata[:, 1] > 0)[0])
         return [
             self.index_to_label[label_id] for label_id in (plabels - glabels)
         ]
@@ -95,8 +150,9 @@ class Evaluator:
         """
         Ground truth labels that are not present in the prediction set.
         """
-        glabels = set(np.where(self._label_metadata[:, 0] > 0)[0])
-        plabels = set(np.where(self._label_metadata[:, 1] > 0)[0])
+        label_metadata = self.label_metadata
+        glabels = set(np.where(label_metadata[:, 0] > 0)[0])
+        plabels = set(np.where(label_metadata[:, 1] > 0)[0])
         return [
             self.index_to_label[label_id] for label_id in (glabels - plabels)
         ]
@@ -115,89 +171,10 @@ class Evaluator:
             "missing_prediction_labels": self.missing_prediction_labels,
         }
 
-    def create_filter(
-        self,
-        datum_uids: list[str] | NDArray[np.int32] | None = None,
-        labels: list[str] | NDArray[np.int32] | None = None,
-    ) -> Filter:
-        """
-        Creates a filter that can be passed to an evaluation.
-
-        Parameters
-        ----------
-        datum_uids : list[str] | NDArray[np.int32], optional
-            An optional list of string uids or a numpy array of uid indices.
-        labels : list[str] | NDArray[np.int32], optional
-            An optional list of labels or a numpy array of label indices.
-
-        Returns
-        -------
-        Filter
-            A filter object that can be passed to the `evaluate` method.
-        """
-
-        n_datums = self._label_metadata_per_datum.shape[1]
-        n_labels = self._label_metadata_per_datum.shape[2]
-
-        mask_ranked = np.ones((self._ranked_pairs.shape[0], 1), dtype=np.bool_)
-        mask_detailed = np.ones(
-            (self._detailed_pairs.shape[0], 1), dtype=np.bool_
-        )
-        mask_datums = np.ones(n_datums, dtype=np.bool_)
-        mask_labels = np.ones(n_labels, dtype=np.bool_)
-
-        if datum_uids is not None:
-            if isinstance(datum_uids, list):
-                datum_uids = np.array(
-                    [self.uid_to_index[uid] for uid in datum_uids],
-                    dtype=np.int32,
-                )
-            mask_ranked[
-                ~np.isin(self._ranked_pairs[:, 0].astype(int), datum_uids)
-            ] = False
-            mask_detailed[
-                ~np.isin(self._detailed_pairs[:, 0].astype(int), datum_uids)
-            ] = False
-            mask_datums[~np.isin(np.arange(n_datums), datum_uids)] = False
-
-        if labels is not None:
-            if isinstance(labels, list):
-                labels = np.array(
-                    [self.label_to_index[label] for label in labels]
-                )
-            mask_ranked[
-                ~np.isin(self._ranked_pairs[:, 4].astype(int), labels)
-            ] = False
-            mask_detailed[
-                ~np.isin(self._detailed_pairs[:, 4].astype(int), labels)
-            ] = False
-            mask_labels[~np.isin(np.arange(n_labels), labels)] = False
-
-        mask_label_metadata = (
-            mask_datums[:, np.newaxis] & mask_labels[np.newaxis, :]
-        )
-        label_metadata_per_datum = self._label_metadata_per_datum.copy()
-        label_metadata_per_datum[:, ~mask_label_metadata] = 0
-
-        label_metadata = np.zeros_like(self._label_metadata, dtype=np.int32)
-        label_metadata = np.transpose(
-            np.sum(
-                label_metadata_per_datum,
-                axis=1,
-            )
-        )
-
-        return Filter(
-            ranked_indices=np.where(mask_ranked)[0],
-            detailed_indices=np.where(mask_detailed)[0],
-            label_metadata=label_metadata,
-        )
-
     def compute_precision_recall(
         self,
         iou_thresholds: list[float],
         score_thresholds: list[float],
-        filter_: Filter | None = None,
     ) -> dict[MetricType, list[Metric]]:
         """
         Computes all metrics except for ConfusionMatrix
@@ -208,31 +185,25 @@ class Evaluator:
             A list of IOU thresholds to compute metrics over.
         score_thresholds : list[float]
             A list of score thresholds to compute metrics over.
-        filter_ : Filter, optional
-            An optional filter object.
 
         Returns
         -------
         dict[MetricType, list]
             A dictionary mapping MetricType enumerations to lists of computed metrics.
         """
-
-        ranked_pairs = self._ranked_pairs
-        label_metadata = self._label_metadata
-        if filter_ is not None:
-            ranked_pairs = ranked_pairs[filter_.ranked_indices]
-            label_metadata = filter_.label_metadata
-
+        if not iou_thresholds:
+            raise ValueError("At least one IOU threshold must be passed.")
+        elif not score_thresholds:
+            raise ValueError("At least one score threshold must be passed.")
         results = compute_precion_recall(
-            data=ranked_pairs,
-            label_metadata=label_metadata,
+            ranked_pairs=self.ranked_pairs,
+            label_metadata=self.label_metadata,
             iou_thresholds=np.array(iou_thresholds),
            score_thresholds=np.array(score_thresholds),
         )
-
         return unpack_precision_recall_into_metric_lists(
             results=results,
-            label_metadata=label_metadata,
+            label_metadata=self.label_metadata,
             iou_thresholds=iou_thresholds,
             score_thresholds=score_thresholds,
             index_to_label=self.index_to_label,
@@ -242,8 +213,6 @@ class Evaluator:
         self,
         iou_thresholds: list[float],
         score_thresholds: list[float],
-        number_of_examples: int,
-        filter_: Filter | None = None,
     ) -> list[Metric]:
         """
         Computes confusion matrices at various thresholds.
@@ -254,51 +223,39 @@ class Evaluator:
             A list of IOU thresholds to compute metrics over.
         score_thresholds : list[float]
             A list of score thresholds to compute metrics over.
-        number_of_examples : int
-            Maximum number of annotation examples to return in ConfusionMatrix.
-        filter_ : Filter, optional
-            An optional filter object.
 
         Returns
         -------
         list[Metric]
             List of confusion matrices per threshold pair.
         """
-
-        detailed_pairs = self._detailed_pairs
-        label_metadata = self._label_metadata
-        if filter_ is not None:
-            detailed_pairs = detailed_pairs[filter_.detailed_indices]
-            label_metadata = filter_.label_metadata
-
-        if detailed_pairs.size == 0:
-            return list()
-
+        if not iou_thresholds:
+            raise ValueError("At least one IOU threshold must be passed.")
+        elif not score_thresholds:
+            raise ValueError("At least one score threshold must be passed.")
+        elif self.detailed_pairs.size == 0:
+            warnings.warn("attempted to compute over an empty set")
+            return []
         results = compute_confusion_matrix(
-            data=detailed_pairs,
-            label_metadata=label_metadata,
+            detailed_pairs=self.detailed_pairs,
             iou_thresholds=np.array(iou_thresholds),
             score_thresholds=np.array(score_thresholds),
-            n_examples=number_of_examples,
         )
-
         return unpack_confusion_matrix_into_metric_list(
             results=results,
+            detailed_pairs=self.detailed_pairs,
             iou_thresholds=iou_thresholds,
             score_thresholds=score_thresholds,
-            number_of_examples=number_of_examples,
-            index_to_uid=self.index_to_uid,
+            index_to_datum_id=self.index_to_datum_id,
+            index_to_groundtruth_id=self.index_to_groundtruth_id,
+            index_to_prediction_id=self.index_to_prediction_id,
             index_to_label=self.index_to_label,
-            groundtruth_examples=self.groundtruth_examples,
-            prediction_examples=self.prediction_examples,
         )
 
     def evaluate(
         self,
         iou_thresholds: list[float] = [0.1, 0.5, 0.75],
         score_thresholds: list[float] = [0.5],
-        number_of_examples: int = 0,
-        filter_: Filter | None = None,
     ) -> dict[MetricType, list[Metric]]:
         """
         Computes all available metrics.
@@ -309,10 +266,6 @@ class Evaluator:
             A list of IOU thresholds to compute metrics over.
         score_thresholds : list[float], default=[0.5]
             A list of score thresholds to compute metrics over.
-        number_of_examples : int, default=0
-            Maximum number of annotation examples to return in ConfusionMatrix.
-        filter_ : Filter, optional
-            An optional filter object.
 
         Returns
         -------
@@ -322,53 +275,82 @@ class Evaluator:
         metrics = self.compute_precision_recall(
             iou_thresholds=iou_thresholds,
             score_thresholds=score_thresholds,
-            filter_=filter_,
         )
-
         metrics[MetricType.ConfusionMatrix] = self.compute_confusion_matrix(
             iou_thresholds=iou_thresholds,
             score_thresholds=score_thresholds,
-            number_of_examples=number_of_examples,
-            filter_=filter_,
         )
-
         return metrics
 
+    def _add_datum(self, datum_id: str) -> int:
+        """
+        Helper function for adding a datum to the cache.
 
-def defaultdict_int():
-    return defaultdict(int)
-
+        Parameters
+        ----------
+        datum_id : str
+            The datum identifier.
 
-class DataLoader:
-    """
-    Object Detection DataLoader
-    """
+        Returns
+        -------
+        int
+            The datum index.
+        """
+        if datum_id not in self.datum_id_to_index:
+            if len(self.datum_id_to_index) != len(self.index_to_datum_id):
+                raise RuntimeError("datum cache size mismatch")
+            idx = len(self.datum_id_to_index)
+            self.datum_id_to_index[datum_id] = idx
+            self.index_to_datum_id.append(datum_id)
+        return self.datum_id_to_index[datum_id]
+
+    def _add_groundtruth(self, annotation_id: str) -> int:
+        """
+        Helper function for adding a ground truth annotation identifier to the cache.
 
-    def __init__(self):
-        self._evaluator = Evaluator()
-        self.pairs: list[NDArray[np.float64]] = list()
-        self.groundtruth_count = defaultdict(defaultdict_int)
-        self.prediction_count = defaultdict(defaultdict_int)
+        Parameters
+        ----------
+        annotation_id : str
+            The ground truth annotation identifier.
 
-    def _add_datum(self, uid: str) -> int:
+        Returns
+        -------
+        int
+            The ground truth annotation index.
         """
-        Helper function for adding a datum to the cache.
+        if annotation_id not in self.groundtruth_id_to_index:
+            if len(self.groundtruth_id_to_index) != len(
+                self.index_to_groundtruth_id
+            ):
+                raise RuntimeError("ground truth cache size mismatch")
+            idx = len(self.groundtruth_id_to_index)
+            self.groundtruth_id_to_index[annotation_id] = idx
+            self.index_to_groundtruth_id.append(annotation_id)
+        return self.groundtruth_id_to_index[annotation_id]
+
+    def _add_prediction(self, annotation_id: str) -> int:
+        """
+        Helper function for adding a prediction annotation identifier to the cache.
 
         Parameters
         ----------
-        uid : str
-            The datum uid.
+        annotation_id : str
+            The prediction annotation identifier.
 
         Returns
         -------
         int
-            The datum index.
+            The prediction annotation index.
         """
-        if uid not in self._evaluator.uid_to_index:
-            index = len(self._evaluator.uid_to_index)
-            self._evaluator.uid_to_index[uid] = index
-            self._evaluator.index_to_uid[index] = uid
-        return self._evaluator.uid_to_index[uid]
+        if annotation_id not in self.prediction_id_to_index:
+            if len(self.prediction_id_to_index) != len(
+                self.index_to_prediction_id
+            ):
+                raise RuntimeError("prediction cache size mismatch")
+            idx = len(self.prediction_id_to_index)
+            self.prediction_id_to_index[annotation_id] = idx
+            self.index_to_prediction_id.append(annotation_id)
+        return self.prediction_id_to_index[annotation_id]
 
     def _add_label(self, label: str) -> int:
         """
@@ -384,90 +366,14 @@ class DataLoader:
         int
             Label index.
         """
-
-        label_id = len(self._evaluator.index_to_label)
-        if label not in self._evaluator.label_to_index:
-            self._evaluator.label_to_index[label] = label_id
-            self._evaluator.index_to_label[label_id] = label
-
+        label_id = len(self.index_to_label)
+        if label not in self.label_to_index:
+            if len(self.label_to_index) != len(self.index_to_label):
+                raise RuntimeError("label cache size mismatch")
+            self.label_to_index[label] = label_id
+            self.index_to_label.append(label)
             label_id += 1
-
-        return self._evaluator.label_to_index[label]
-
-    def _cache_pairs(
-        self,
-        uid_index: int,
-        groundtruths: list,
-        predictions: list,
-        ious: NDArray[np.float64],
-    ) -> None:
-        """
-        Compute IOUs between groundtruths and preditions before storing as pairs.
-
-        Parameters
-        ----------
-        uid_index : int
-            The index of the detection.
-        groundtruths : list
-            A list of groundtruths.
-        predictions : list
-            A list of predictions.
-        ious : NDArray[np.float64]
-            An array with shape (n_preds, n_gts) containing IOUs.
-        """
-
-        predictions_with_iou_of_zero = np.where((ious < 1e-9).all(axis=1))[0]
-        groundtruths_with_iou_of_zero = np.where((ious < 1e-9).all(axis=0))[0]
-
-        pairs = [
-            np.array(
-                [
-                    float(uid_index),
-                    float(gidx),
-                    float(pidx),
-                    ious[pidx, gidx],
-                    float(glabel),
-                    float(plabel),
-                    float(score),
-                ]
-            )
-            for pidx, plabel, score in predictions
-            for gidx, glabel in groundtruths
-            if ious[pidx, gidx] >= 1e-9
-        ]
-        pairs.extend(
-            [
-                np.array(
-                    [
-                        float(uid_index),
-                        -1.0,
-                        float(predictions[index][0]),
-                        0.0,
-                        -1.0,
-                        float(predictions[index][1]),
-                        float(predictions[index][2]),
-                    ]
-                )
-                for index in predictions_with_iou_of_zero
-            ]
-        )
-        pairs.extend(
-            [
-                np.array(
-                    [
-                        float(uid_index),
-                        float(groundtruths[index][0]),
-                        -1.0,
-                        0.0,
-                        float(groundtruths[index][1]),
-                        -1.0,
-                        -1.0,
-                    ]
-                )
-                for index in groundtruths_with_iou_of_zero
-            ]
-        )
-        self.pairs.append(np.array(pairs))
+        return self.label_to_index[label]
 
     def _add_data(
         self,
@@ -491,66 +397,103 @@ class DataLoader:
         for detection, ious in tqdm(
             zip(detections, detection_ious), disable=disable_tqdm
         ):
-
-            # update metadata
-            self._evaluator.n_datums += 1
-            self._evaluator.n_groundtruths += len(detection.groundtruths)
-            self._evaluator.n_predictions += len(detection.predictions)
-
-            # update datum uid index
-            uid_index = self._add_datum(uid=detection.uid)
-
-            # initialize bounding box examples
-            self._evaluator.groundtruth_examples[uid_index] = np.zeros(
-                (len(detection.groundtruths), 4), dtype=np.float16
-            )
-            self._evaluator.prediction_examples[uid_index] = np.zeros(
-                (len(detection.predictions), 4), dtype=np.float16
-            )
-
-            # cache labels and annotations
-            groundtruths = list()
-            predictions = list()
-
-            for gidx, gann in enumerate(detection.groundtruths):
-                self._evaluator.groundtruth_examples[uid_index][
-                    gidx
-                ] = gann.extrema
-                for glabel in gann.labels:
-                    label_idx = self._add_label(glabel)
-                    self.groundtruth_count[label_idx][uid_index] += 1
-                    groundtruths.append(
-                        (
-                            gidx,
-                            label_idx,
+            # cache labels and annotation pairs
+            pairs = []
+            datum_idx = self._add_datum(detection.uid)
+            if detection.groundtruths:
+                for gidx, gann in enumerate(detection.groundtruths):
+                    groundtruth_idx = self._add_groundtruth(gann.uid)
+                    glabel_idx = self._add_label(gann.labels[0])
+                    if (ious[:, gidx] < 1e-9).all():
+                        pairs.extend(
+                            [
+                                np.array(
+                                    [
+                                        float(datum_idx),
+                                        float(groundtruth_idx),
+                                        -1.0,
+                                        float(glabel_idx),
+                                        -1.0,
+                                        0.0,
+                                        -1.0,
+                                    ]
+                                )
+                            ]
                         )
+                    for pidx, pann in enumerate(detection.predictions):
+                        prediction_idx = self._add_prediction(pann.uid)
+                        if (ious[pidx, :] < 1e-9).all():
+                            pairs.extend(
+                                [
+                                    np.array(
+                                        [
+                                            float(datum_idx),
+                                            -1.0,
+                                            float(prediction_idx),
+                                            -1.0,
+                                            float(self._add_label(plabel)),
+                                            0.0,
+                                            float(pscore),
+                                        ]
+                                    )
+                                    for plabel, pscore in zip(
+                                        pann.labels, pann.scores
+                                    )
+                                ]
+                            )
+                        if ious[pidx, gidx] >= 1e-9:
+                            pairs.extend(
+                                [
+                                    np.array(
+                                        [
+                                            float(datum_idx),
+                                            float(groundtruth_idx),
+                                            float(prediction_idx),
+                                            float(self._add_label(glabel)),
+                                            float(self._add_label(plabel)),
+                                            ious[pidx, gidx],
+                                            float(pscore),
+                                        ]
+                                    )
+                                    for glabel in gann.labels
+                                    for plabel, pscore in zip(
+                                        pann.labels, pann.scores
+                                    )
+                                ]
+                            )
+            elif detection.predictions:
+                for pidx, pann in enumerate(detection.predictions):
+                    prediction_idx = self._add_prediction(pann.uid)
+                    pairs.extend(
+                        [
+                            np.array(
+                                [
+                                    float(datum_idx),
+                                    -1.0,
+                                    float(prediction_idx),
+                                    -1.0,
+                                    float(self._add_label(plabel)),
+                                    0.0,
+                                    float(pscore),
+                                ]
+                            )
+                            for plabel, pscore in zip(pann.labels, pann.scores)
+                        ]
                     )
 
-            for pidx, pann in enumerate(detection.predictions):
-                self._evaluator.prediction_examples[uid_index][
-                    pidx
-                ] = pann.extrema
-                for plabel, pscore in zip(pann.labels, pann.scores):
-                    label_idx = self._add_label(plabel)
-                    self.prediction_count[label_idx][uid_index] += 1
-                    predictions.append(
-                        (
-                            pidx,
-                            label_idx,
-                            pscore,
-                        )
+            data = np.array(pairs)
+            if data.size > 0:
+                # reset filtered cache if it exists
+                self.clear_filter()
+                if self._temp_cache is None:
+                    raise RuntimeError(
+                        "cannot add data as evaluator has already been finalized"
                    )
-
-            self._cache_pairs(
-                uid_index=uid_index,
-                groundtruths=groundtruths,
-                predictions=predictions,
-                ious=ious,
-            )
+                self._temp_cache.append(data)
 
     def add_bounding_boxes(
         self,
-        detections: list[Detection],
+        detections: list[Detection[BoundingBox]],
         show_progress: bool = False,
     ):
         """
@@ -584,7 +527,7 @@ class DataLoader:
 
     def add_polygons(
         self,
-        detections: list[Detection],
+        detections: list[Detection[Polygon]],
         show_progress: bool = False,
     ):
         """
@@ -601,7 +544,7 @@ class DataLoader:
             compute_polygon_iou(
                 np.array(
                     [
-                        [gt.shape, pd.shape]  # type: ignore - using the AttributeError as a validator
+                        [gt.shape, pd.shape]
                        for pd in detection.predictions
                        for gt in detection.groundtruths
                    ]
@@ -617,7 +560,7 @@ class DataLoader:
 
     def add_bitmasks(
         self,
-        detections: list[Detection],
+        detections: list[Detection[Bitmask]],
         show_progress: bool = False,
     ):
         """
@@ -634,7 +577,7 @@ class DataLoader:
             compute_bitmask_iou(
                 np.array(
                     [
-                        [gt.mask, pd.mask]  # type: ignore - using the AttributeError as a validator
+                        [gt.mask, pd.mask]
                        for pd in detection.predictions
                        for gt in detection.groundtruths
                    ]
@@ -648,7 +591,7 @@ class DataLoader:
             show_progress=show_progress,
         )
 
-    def finalize(self) -> Evaluator:
+    def finalize(self):
         """
         Performs data finalization and some preprocessing steps.
 
@@ -657,65 +600,206 @@ class DataLoader:
         Evaluator
             A ready-to-use evaluator object.
         """
+        if self._temp_cache is None:
+            warnings.warn("evaluator is already finalized or in a bad state")
+            return self
+        elif not self._temp_cache:
+            self._detailed_pairs = np.array([], dtype=np.float64)
+            self._ranked_pairs = np.array([], dtype=np.float64)
+            self._label_metadata = np.zeros((self.n_labels, 2), dtype=np.int32)
+            warnings.warn("no valid pairs")
+            return self
+        else:
+            self._detailed_pairs = np.concatenate(self._temp_cache, axis=0)
+            self._temp_cache = None
+
+        # order pairs by descending score, iou
+        indices = np.lexsort(
+            (
+                -self._detailed_pairs[:, 5],  # iou
+                -self._detailed_pairs[:, 6],  # score
+            )
+        )
+        self._detailed_pairs = self._detailed_pairs[indices]
+        self._label_metadata = compute_label_metadata(
+            ids=self._detailed_pairs[:, :5].astype(np.int32),
+            n_labels=self.n_labels,
+        )
+        self._ranked_pairs = rank_pairs(
+            detailed_pairs=self.detailed_pairs,
+            label_metadata=self._label_metadata,
+        )
+        return self
 
-        self.pairs = [pair for pair in self.pairs if pair.size > 0]
-        if len(self.pairs) == 0:
-            raise ValueError("No data available to create evaluator.")
-
-        n_datums = self._evaluator.n_datums
-        n_labels = len(self._evaluator.index_to_label)
+    def apply_filter(
+        self,
+        datum_ids: list[str] | None = None,
+        groundtruth_ids: list[str] | None = None,
+        prediction_ids: list[str] | None = None,
+        labels: list[str] | None = None,
+    ):
+        """
+        Apply a filter on the evaluator.
 
-        self._evaluator.n_labels = n_labels
+        Can be reset by calling 'clear_filter'.
 
-        self._evaluator._label_metadata_per_datum = np.zeros(
-            (2, n_datums, n_labels), dtype=np.int32
+        Parameters
+        ----------
+        datum_uids : list[str], optional
+            An optional list of string uids representing datums.
+        groundtruth_ids : list[str], optional
+            An optional list of string uids representing ground truth annotations.
+        prediction_ids : list[str], optional
+            An optional list of string uids representing prediction annotations.
+        labels : list[str], optional
+            An optional list of labels.
+        """
+        self._filtered_detailed_pairs = self._detailed_pairs.copy()
+        self._filtered_ranked_pairs = np.array([], dtype=np.float64)
+        self._filtered_label_metadata = np.zeros(
+            (self.n_labels, 2), dtype=np.int32
         )
-        for datum_idx in range(n_datums):
-            for label_idx in range(n_labels):
-                gt_count = (
-                    self.groundtruth_count[label_idx].get(datum_idx, 0)
-                    if label_idx in self.groundtruth_count
-                    else 0
+
+        valid_datum_indices = None
+        if datum_ids is not None:
+            if not datum_ids:
+                self._filtered_detailed_pairs = np.array([], dtype=np.float64)
+                warnings.warn("no valid filtered pairs")
+                return
+            valid_datum_indices = np.array(
+                [self.datum_id_to_index[uid] for uid in datum_ids],
+                dtype=np.int32,
+            )
+
+        valid_groundtruth_indices = None
+        if groundtruth_ids is not None:
+            valid_groundtruth_indices = np.array(
+                [self.groundtruth_id_to_index[uid] for uid in groundtruth_ids],
+                dtype=np.int32,
+            )
+
+        valid_prediction_indices = None
+        if prediction_ids is not None:
+            valid_prediction_indices = np.array(
+                [self.prediction_id_to_index[uid] for uid in prediction_ids],
+                dtype=np.int32,
+            )
+
+        valid_label_indices = None
+        if labels is not None:
+            if not labels:
+                self._filtered_detailed_pairs = np.array([], dtype=np.float64)
+                warnings.warn("no valid filtered pairs")
+                return
+            valid_label_indices = np.array(
+                [self.label_to_index[label] for label in labels] + [-1]
+            )
+
+        # filter datums
+        if valid_datum_indices is not None:
+            mask_valid_datums = np.isin(
+                self._filtered_detailed_pairs[:, 0], valid_datum_indices
+            )
+            self._filtered_detailed_pairs = self._filtered_detailed_pairs[
+                mask_valid_datums
+            ]
+
+        n_rows = self._filtered_detailed_pairs.shape[0]
+        mask_invalid_groundtruths = np.zeros(n_rows, dtype=np.bool_)
+        mask_invalid_predictions = np.zeros_like(mask_invalid_groundtruths)
+
+        # filter ground truth annotations
+        if valid_groundtruth_indices is not None:
+            mask_invalid_groundtruths[
+                ~np.isin(
+                    self._filtered_detailed_pairs[:, 1],
+                    valid_groundtruth_indices,
                 )
-                pd_count = (
-                    self.prediction_count[label_idx].get(datum_idx, 0)
-                    if label_idx in self.prediction_count
-                    else 0
+            ] = True
+
+        # filter prediction annotations
+        if valid_prediction_indices is not None:
+            mask_invalid_predictions[
+                ~np.isin(
+                    self._filtered_detailed_pairs[:, 2],
+                    valid_prediction_indices,
                 )
-                self._evaluator._label_metadata_per_datum[
-                    :, datum_idx, label_idx
-                ] = np.array([gt_count, pd_count])
-
-        self._evaluator._label_metadata = np.array(
-            [
-                [
-                    float(
-                        np.sum(
-                            self._evaluator._label_metadata_per_datum[
-                                0, :, label_idx
-                            ]
-                        )
-                    ),
-                    float(
-                        np.sum(
-                            self._evaluator._label_metadata_per_datum[
-                                1, :, label_idx
-                            ]
-                        )
-                    ),
-                ]
-                for label_idx in range(n_labels)
+            ] = True
+
+        # filter labels
+        if valid_label_indices is not None:
+            mask_invalid_groundtruths[
+                ~np.isin(
+                    self._filtered_detailed_pairs[:, 3], valid_label_indices
+                )
+            ] = True
+            mask_invalid_predictions[
+                ~np.isin(
+                    self._filtered_detailed_pairs[:, 4], valid_label_indices
+                )
+            ] = True
+
+        # filter cache
+        if mask_invalid_groundtruths.any():
+            invalid_groundtruth_indices = np.where(mask_invalid_groundtruths)[
+                0
            ]
+            self._filtered_detailed_pairs[
+                invalid_groundtruth_indices[:, None], (1, 3, 5)
+            ] = np.array([[-1, -1, 0]])
+
+        if mask_invalid_predictions.any():
+            invalid_prediction_indices = np.where(mask_invalid_predictions)[0]
+            self._filtered_detailed_pairs[
+                invalid_prediction_indices[:, None], (2, 4, 5, 6)
+            ] = np.array([[-1, -1, 0, -1]])
+
+        # filter null pairs
+        mask_null_pairs = np.all(
+            np.isclose(
+                self._filtered_detailed_pairs[:, 1:5],
+                np.array([-1.0, -1.0, -1.0, -1.0]),
+            ),
+            axis=1,
        )
+        self._filtered_detailed_pairs = self._filtered_detailed_pairs[
+            ~mask_null_pairs
+        ]
 
-        self._evaluator._detailed_pairs = np.concatenate(
-            self.pairs,
-            axis=0,
+        if self._filtered_detailed_pairs.size == 0:
+            self._ranked_pairs = np.array([], dtype=np.float64)
+            self._label_metadata = np.zeros((self.n_labels, 2), dtype=np.int32)
+            warnings.warn("no valid filtered pairs")
+            return
+
+        # sorts by score, iou with ground truth id as a tie-breaker
+        indices = np.lexsort(
+            (
+                self._filtered_detailed_pairs[:, 1],  # ground truth id
+                -self._filtered_detailed_pairs[:, 5],  # iou
+                -self._filtered_detailed_pairs[:, 6],  # score
+            )
        )
-
-        self._evaluator._ranked_pairs = compute_ranked_pairs(
-            self.pairs,
-            label_metadata=self._evaluator._label_metadata,
+        self._filtered_detailed_pairs = self._filtered_detailed_pairs[indices]
+        self._filtered_label_metadata = compute_label_metadata(
+            ids=self._filtered_detailed_pairs[:, :5].astype(np.int32),
+            n_labels=self.n_labels,
        )
+        self._filtered_ranked_pairs = rank_pairs(
+            detailed_pairs=self._filtered_detailed_pairs,
+            label_metadata=self._filtered_label_metadata,
+        )
+
+    def clear_filter(self):
+        """Removes a filter if one exists."""
+        self._filtered_detailed_pairs = None
+        self._filtered_ranked_pairs = None
+        self._filtered_label_metadata = None
+
+
+class DataLoader(Evaluator):
+    """
+    Used for backwards compatibility as the Evaluator now handles ingestion.
+    """
 
-        return self._evaluator
+    pass
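
Taken together, this release folds ingestion and filtering into Evaluator itself: the Filter dataclass and create_filter are removed, DataLoader survives only as a backwards-compatibility subclass, and filtering is now applied in place with apply_filter and undone with clear_filter. The following is a minimal, hypothetical usage sketch inferred only from the signatures visible in this diff; the import path for Evaluator and the construction of the Detection/BoundingBox objects are assumptions and are not shown here.

# Hypothetical sketch for valor-lite 0.35.0, based on the signatures in this diff.
# Assumptions: the Evaluator import path below and the build_detections() helper,
# which would return list[Detection[BoundingBox]] objects constructed elsewhere.
from valor_lite.object_detection import Evaluator

detections = build_detections()  # hypothetical helper, not part of the package

evaluator = Evaluator()
evaluator.add_bounding_boxes(detections)
evaluator.finalize()  # concatenates the pair cache, computes label metadata, ranks pairs

# Unfiltered metrics; evaluate() now also includes the ConfusionMatrix entries.
metrics = evaluator.evaluate(iou_thresholds=[0.5, 0.75], score_thresholds=[0.5])

# 0.34.x: filter_ = evaluator.create_filter(datum_uids=[...]); evaluator.evaluate(..., filter_=filter_)
# 0.35.0: the filter is applied in place on the evaluator and cleared explicitly.
evaluator.apply_filter(datum_ids=["datum_0"], labels=["dog"])
filtered_metrics = evaluator.evaluate(iou_thresholds=[0.5], score_thresholds=[0.5])
evaluator.clear_filter()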