valor-lite 0.35.0__py3-none-any.whl → 0.36.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,9 +1,15 @@
1
1
  import warnings
2
+ from dataclasses import asdict, dataclass
2
3
 
3
4
  import numpy as np
4
5
  from numpy.typing import NDArray
5
6
  from tqdm import tqdm
6
7
 
8
+ from valor_lite.exceptions import (
9
+ EmptyEvaluatorException,
10
+ EmptyFilterException,
11
+ InternalCacheException,
12
+ )
7
13
  from valor_lite.object_detection.annotation import (
8
14
  Bitmask,
9
15
  BoundingBox,
@@ -17,6 +23,7 @@ from valor_lite.object_detection.computation import (
17
23
  compute_label_metadata,
18
24
  compute_polygon_iou,
19
25
  compute_precion_recall,
26
+ filter_cache,
20
27
  rank_pairs,
21
28
  )
22
29
  from valor_lite.object_detection.metric import Metric, MetricType
@@ -46,6 +53,68 @@ filtered_metrics = evaluator.evaluate(iou_thresholds=[0.5], filter_mask=filter_m
46
53
  """
47
54
 
48
55
 
56
+ @dataclass
57
+ class Metadata:
58
+ number_of_datums: int = 0
59
+ number_of_ground_truths: int = 0
60
+ number_of_predictions: int = 0
61
+ number_of_labels: int = 0
62
+
63
+ @classmethod
64
+ def create(
65
+ cls,
66
+ detailed_pairs: NDArray[np.float64],
67
+ number_of_datums: int,
68
+ number_of_labels: int,
69
+ ):
70
+ # count number of ground truths
71
+ mask_valid_gts = detailed_pairs[:, 1] >= 0
72
+ unique_ids = np.unique(
73
+ detailed_pairs[np.ix_(mask_valid_gts, (0, 1))], axis=0 # type: ignore - np.ix_ typing
74
+ )
75
+ number_of_ground_truths = int(unique_ids.shape[0])
76
+
77
+ # count number of predictions
78
+ mask_valid_pds = detailed_pairs[:, 2] >= 0
79
+ unique_ids = np.unique(
80
+ detailed_pairs[np.ix_(mask_valid_pds, (0, 2))], axis=0 # type: ignore - np.ix_ typing
81
+ )
82
+ number_of_predictions = int(unique_ids.shape[0])
83
+
84
+ return cls(
85
+ number_of_datums=number_of_datums,
86
+ number_of_ground_truths=number_of_ground_truths,
87
+ number_of_predictions=number_of_predictions,
88
+ number_of_labels=number_of_labels,
89
+ )
90
+
91
+ def to_dict(self) -> dict[str, int | bool]:
92
+ return asdict(self)
93
+
94
+
95
@dataclass
class Filter:
    """
    Boolean masks plus the metadata that results from applying them.

    The masks are validated on construction: a filter that leaves no
    datums or no annotations at all is rejected, while a filter that
    removes only one side (ground truths or predictions) emits a warning.

    Raises
    ------
    EmptyFilterException
        If `mask_datums` has no True entry, or if both annotation masks
        are entirely True.
    """

    mask_datums: NDArray[np.bool_]
    mask_groundtruths: NDArray[np.bool_]
    mask_predictions: NDArray[np.bool_]
    metadata: Metadata

    def __post_init__(self):
        # the datum mask must select at least one entry
        if not self.mask_datums.any():
            raise EmptyFilterException("filter removes all datums")

        # annotation masks flag removals, so all-True means nothing is left
        all_gts_removed = self.mask_groundtruths.all()
        all_pds_removed = self.mask_predictions.all()

        if all_gts_removed and all_pds_removed:
            raise EmptyFilterException("filter removes all annotations")
        if all_gts_removed:
            warnings.warn("filter removes all ground truths")
        elif all_pds_removed:
            warnings.warn("filter removes all predictions")
116
+
117
+
49
118
  class Evaluator:
50
119
  """
51
120
  Object Detection Evaluator
@@ -67,80 +136,19 @@ class Evaluator:
67
136
  # temporary cache
68
137
  self._temp_cache: list[NDArray[np.float64]] | None = []
69
138
 
70
- # cache
139
+ # internal cache
71
140
  self._detailed_pairs = np.array([[]], dtype=np.float64)
72
141
  self._ranked_pairs = np.array([[]], dtype=np.float64)
73
142
  self._label_metadata: NDArray[np.int32] = np.array([[]])
74
-
75
- # filter cache
76
- self._filtered_detailed_pairs: NDArray[np.float64] | None = None
77
- self._filtered_ranked_pairs: NDArray[np.float64] | None = None
78
- self._filtered_label_metadata: NDArray[np.int32] | None = None
79
-
80
- @property
81
- def is_filtered(self) -> bool:
82
- return self._filtered_detailed_pairs is not None
83
-
84
- @property
85
- def label_metadata(self) -> NDArray[np.int32]:
86
- return (
87
- self._filtered_label_metadata
88
- if self._filtered_label_metadata is not None
89
- else self._label_metadata
90
- )
91
-
92
- @property
93
- def detailed_pairs(self) -> NDArray[np.float64]:
94
- return (
95
- self._filtered_detailed_pairs
96
- if self._filtered_detailed_pairs is not None
97
- else self._detailed_pairs
98
- )
99
-
100
- @property
101
- def ranked_pairs(self) -> NDArray[np.float64]:
102
- return (
103
- self._filtered_ranked_pairs
104
- if self._filtered_ranked_pairs is not None
105
- else self._ranked_pairs
106
- )
107
-
108
- @property
109
- def n_labels(self) -> int:
110
- """Returns the total number of unique labels."""
111
- return len(self.index_to_label)
112
-
113
- @property
114
- def n_datums(self) -> int:
115
- """Returns the number of datums."""
116
- return np.unique(self.detailed_pairs[:, 0]).size
117
-
118
- @property
119
- def n_groundtruths(self) -> int:
120
- """Returns the number of ground truth annotations."""
121
- mask_valid_gts = self.detailed_pairs[:, 1] >= 0
122
- unique_ids = np.unique(
123
- self.detailed_pairs[np.ix_(mask_valid_gts, (0, 1))], axis=0 # type: ignore - np.ix_ typing
124
- )
125
- return int(unique_ids.shape[0])
126
-
127
- @property
128
- def n_predictions(self) -> int:
129
- """Returns the number of prediction annotations."""
130
- mask_valid_pds = self.detailed_pairs[:, 2] >= 0
131
- unique_ids = np.unique(
132
- self.detailed_pairs[np.ix_(mask_valid_pds, (0, 2))], axis=0 # type: ignore - np.ix_ typing
133
- )
134
- return int(unique_ids.shape[0])
143
+ self._metadata = Metadata()
135
144
 
136
145
  @property
137
146
  def ignored_prediction_labels(self) -> list[str]:
138
147
  """
139
148
  Prediction labels that are not present in the ground truth set.
140
149
  """
141
- label_metadata = self.label_metadata
142
- glabels = set(np.where(label_metadata[:, 0] > 0)[0])
143
- plabels = set(np.where(label_metadata[:, 1] > 0)[0])
150
+ glabels = set(np.where(self._label_metadata[:, 0] > 0)[0])
151
+ plabels = set(np.where(self._label_metadata[:, 1] > 0)[0])
144
152
  return [
145
153
  self.index_to_label[label_id] for label_id in (plabels - glabels)
146
154
  ]
@@ -150,31 +158,157 @@ class Evaluator:
150
158
  """
151
159
  Ground truth labels that are not present in the prediction set.
152
160
  """
153
- label_metadata = self.label_metadata
154
- glabels = set(np.where(label_metadata[:, 0] > 0)[0])
155
- plabels = set(np.where(label_metadata[:, 1] > 0)[0])
161
+ glabels = set(np.where(self._label_metadata[:, 0] > 0)[0])
162
+ plabels = set(np.where(self._label_metadata[:, 1] > 0)[0])
156
163
  return [
157
164
  self.index_to_label[label_id] for label_id in (glabels - plabels)
158
165
  ]
159
166
 
160
167
  @property
161
- def metadata(self) -> dict:
168
+ def metadata(self) -> Metadata:
162
169
  """
163
170
  Evaluation metadata.
164
171
  """
165
- return {
166
- "n_datums": self.n_datums,
167
- "n_groundtruths": self.n_groundtruths,
168
- "n_predictions": self.n_predictions,
169
- "n_labels": self.n_labels,
170
- "ignored_prediction_labels": self.ignored_prediction_labels,
171
- "missing_prediction_labels": self.missing_prediction_labels,
172
- }
172
+ return self._metadata
173
+
174
def create_filter(
    self,
    datum_ids: list[str] | None = None,
    groundtruth_ids: list[str] | None = None,
    prediction_ids: list[str] | None = None,
    labels: list[str] | None = None,
) -> Filter:
    """
    Creates a filter object.

    Parameters
    ----------
    datum_ids : list[str], optional
        An optional list of string uids representing datums to keep.
    groundtruth_ids : list[str], optional
        An optional list of string uids representing ground truth annotations to keep.
    prediction_ids : list[str], optional
        An optional list of string uids representing prediction annotations to keep.
    labels : list[str], optional
        An optional list of labels to keep.

    Returns
    -------
    Filter
        The datum and annotation masks together with the metadata of the
        filtered cache.

    Raises
    ------
    EmptyFilterException
        If `datum_ids` or `labels` is an empty list, or if the resulting
        masks remove every datum or every annotation.

    Notes
    -----
    Ids and labels are looked up directly in the evaluator's index maps;
    failed lookups are not handled here.
    """
    n_labels = len(self.index_to_label)

    # datum mask spans the full cache; True marks rows to keep
    mask_datums = np.ones(self._detailed_pairs.shape[0], dtype=np.bool_)

    # filter datums
    if datum_ids is not None:
        if not datum_ids:
            raise EmptyFilterException("filter removes all datums")
        valid_datum_indices = np.array(
            [self.datum_id_to_index[uid] for uid in datum_ids],
            dtype=np.int32,
        )
        mask_datums = np.isin(
            self._detailed_pairs[:, 0], valid_datum_indices
        )

    # annotation masks span only the datum-filtered rows; True marks removal
    filtered_detailed_pairs = self._detailed_pairs[mask_datums]
    n_pairs = filtered_detailed_pairs.shape[0]
    mask_groundtruths = np.zeros(n_pairs, dtype=np.bool_)
    mask_predictions = np.zeros_like(mask_groundtruths)

    # filter by ground truth annotation ids
    if groundtruth_ids is not None:
        valid_groundtruth_indices = np.array(
            [self.groundtruth_id_to_index[uid] for uid in groundtruth_ids],
            dtype=np.int32,
        )
        mask_groundtruths[
            ~np.isin(
                filtered_detailed_pairs[:, 1],
                valid_groundtruth_indices,
            )
        ] = True

    # filter by prediction annotation ids
    if prediction_ids is not None:
        valid_prediction_indices = np.array(
            [self.prediction_id_to_index[uid] for uid in prediction_ids],
            dtype=np.int32,
        )
        mask_predictions[
            ~np.isin(
                filtered_detailed_pairs[:, 2],
                valid_prediction_indices,
            )
        ] = True

    # filter by labels
    if labels is not None:
        if not labels:
            raise EmptyFilterException("filter removes all labels")
        # -1 is appended so rows whose label column is -1 are not flagged
        valid_label_indices = np.array(
            [self.label_to_index[label] for label in labels] + [-1]
        )
        mask_groundtruths[
            ~np.isin(filtered_detailed_pairs[:, 3], valid_label_indices)
        ] = True
        mask_predictions[
            ~np.isin(filtered_detailed_pairs[:, 4], valid_label_indices)
        ] = True

    filtered_detailed_pairs, _, _ = filter_cache(
        self._detailed_pairs,
        mask_datums=mask_datums,
        mask_ground_truths=mask_groundtruths,
        mask_predictions=mask_predictions,
        n_labels=n_labels,
    )

    if datum_ids is not None:
        # de-duplicate so a repeated id is not counted twice
        number_of_datums = len(set(datum_ids))
    else:
        number_of_datums = np.unique(filtered_detailed_pairs[:, 0]).size

    return Filter(
        mask_datums=mask_datums,
        mask_groundtruths=mask_groundtruths,
        mask_predictions=mask_predictions,
        metadata=Metadata.create(
            detailed_pairs=filtered_detailed_pairs,
            number_of_datums=number_of_datums,
            number_of_labels=n_labels,
        ),
    )
278
+
279
def filter(
    self, filter_: Filter
) -> tuple[NDArray[np.float64], NDArray[np.float64], NDArray[np.int32],]:
    """
    Applies a filter to the internal cache and returns the filtered arrays.

    The internal cache itself is only read here; the result comes
    straight from `filter_cache`.

    Parameters
    ----------
    filter_ : Filter
        The filter parameterization.

    Returns
    -------
    NDArray[float64]
        Filtered detailed pairs.
    NDArray[float64]
        Filtered ranked pairs.
    NDArray[int32]
        Label metadata.
    """
    number_of_labels = len(self.index_to_label)
    filtered = filter_cache(
        detailed_pairs=self._detailed_pairs,
        mask_datums=filter_.mask_datums,
        mask_ground_truths=filter_.mask_groundtruths,
        mask_predictions=filter_.mask_predictions,
        n_labels=number_of_labels,
    )
    return filtered
173
306
 
174
307
  def compute_precision_recall(
175
308
  self,
176
309
  iou_thresholds: list[float],
177
310
  score_thresholds: list[float],
311
+ filter_: Filter | None = None,
178
312
  ) -> dict[MetricType, list[Metric]]:
179
313
  """
180
314
  Computes all metrics except for ConfusionMatrix
@@ -185,6 +319,8 @@ class Evaluator:
185
319
  A list of IOU thresholds to compute metrics over.
186
320
  score_thresholds : list[float]
187
321
  A list of score thresholds to compute metrics over.
322
+ filter_ : Filter, optional
323
+ A collection of filter parameters and masks.
188
324
 
189
325
  Returns
190
326
  -------
@@ -195,15 +331,22 @@ class Evaluator:
195
331
  raise ValueError("At least one IOU threshold must be passed.")
196
332
  elif not score_thresholds:
197
333
  raise ValueError("At least one score threshold must be passed.")
334
+
335
+ if filter_ is not None:
336
+ _, ranked_pairs, label_metadata = self.filter(filter_=filter_)
337
+ else:
338
+ ranked_pairs = self._ranked_pairs
339
+ label_metadata = self._label_metadata
340
+
198
341
  results = compute_precion_recall(
199
- ranked_pairs=self.ranked_pairs,
200
- label_metadata=self.label_metadata,
342
+ ranked_pairs=ranked_pairs,
343
+ label_metadata=label_metadata,
201
344
  iou_thresholds=np.array(iou_thresholds),
202
345
  score_thresholds=np.array(score_thresholds),
203
346
  )
204
347
  return unpack_precision_recall_into_metric_lists(
205
348
  results=results,
206
- label_metadata=self.label_metadata,
349
+ label_metadata=label_metadata,
207
350
  iou_thresholds=iou_thresholds,
208
351
  score_thresholds=score_thresholds,
209
352
  index_to_label=self.index_to_label,
@@ -213,6 +356,7 @@ class Evaluator:
213
356
  self,
214
357
  iou_thresholds: list[float],
215
358
  score_thresholds: list[float],
359
+ filter_: Filter | None = None,
216
360
  ) -> list[Metric]:
217
361
  """
218
362
  Computes confusion matrices at various thresholds.
@@ -223,6 +367,8 @@ class Evaluator:
223
367
  A list of IOU thresholds to compute metrics over.
224
368
  score_thresholds : list[float]
225
369
  A list of score thresholds to compute metrics over.
370
+ filter_ : Filter, optional
371
+ A collection of filter parameters and masks.
226
372
 
227
373
  Returns
228
374
  -------
@@ -233,17 +379,23 @@ class Evaluator:
233
379
  raise ValueError("At least one IOU threshold must be passed.")
234
380
  elif not score_thresholds:
235
381
  raise ValueError("At least one score threshold must be passed.")
236
- elif self.detailed_pairs.size == 0:
237
- warnings.warn("attempted to compute over an empty set")
382
+
383
+ if filter_ is not None:
384
+ detailed_pairs, _, _ = self.filter(filter_=filter_)
385
+ else:
386
+ detailed_pairs = self._detailed_pairs
387
+
388
+ if detailed_pairs.size == 0:
238
389
  return []
390
+
239
391
  results = compute_confusion_matrix(
240
- detailed_pairs=self.detailed_pairs,
392
+ detailed_pairs=detailed_pairs,
241
393
  iou_thresholds=np.array(iou_thresholds),
242
394
  score_thresholds=np.array(score_thresholds),
243
395
  )
244
396
  return unpack_confusion_matrix_into_metric_list(
245
397
  results=results,
246
- detailed_pairs=self.detailed_pairs,
398
+ detailed_pairs=detailed_pairs,
247
399
  iou_thresholds=iou_thresholds,
248
400
  score_thresholds=score_thresholds,
249
401
  index_to_datum_id=self.index_to_datum_id,
@@ -256,6 +408,7 @@ class Evaluator:
256
408
  self,
257
409
  iou_thresholds: list[float] = [0.1, 0.5, 0.75],
258
410
  score_thresholds: list[float] = [0.5],
411
+ filter_: Filter | None = None,
259
412
  ) -> dict[MetricType, list[Metric]]:
260
413
  """
261
414
  Computes all available metrics.
@@ -266,6 +419,8 @@ class Evaluator:
266
419
  A list of IOU thresholds to compute metrics over.
267
420
  score_thresholds : list[float], default=[0.5]
268
421
  A list of score thresholds to compute metrics over.
422
+ filter_ : Filter, optional
423
+ A collection of filter parameters and masks.
269
424
 
270
425
  Returns
271
426
  -------
@@ -275,13 +430,25 @@ class Evaluator:
275
430
  metrics = self.compute_precision_recall(
276
431
  iou_thresholds=iou_thresholds,
277
432
  score_thresholds=score_thresholds,
433
+ filter_=filter_,
278
434
  )
279
435
  metrics[MetricType.ConfusionMatrix] = self.compute_confusion_matrix(
280
436
  iou_thresholds=iou_thresholds,
281
437
  score_thresholds=score_thresholds,
438
+ filter_=filter_,
282
439
  )
283
440
  return metrics
284
441
 
442
+
443
+ class DataLoader:
444
+ """
445
+ Object Detection DataLoader
446
+ """
447
+
448
def __init__(self):
    """Sets up a fresh evaluator and an empty list of pair arrays."""
    # evaluator whose lookup tables and cache this loader populates
    self._evaluator = Evaluator()
    # per-call pair arrays collected as data is added
    self.pairs: list[NDArray[np.float64]] = []
451
+
285
452
  def _add_datum(self, datum_id: str) -> int:
286
453
  """
287
454
  Helper function for adding a datum to the cache.
@@ -296,13 +463,15 @@ class Evaluator:
296
463
  int
297
464
  The datum index.
298
465
  """
299
- if datum_id not in self.datum_id_to_index:
300
- if len(self.datum_id_to_index) != len(self.index_to_datum_id):
301
- raise RuntimeError("datum cache size mismatch")
302
- idx = len(self.datum_id_to_index)
303
- self.datum_id_to_index[datum_id] = idx
304
- self.index_to_datum_id.append(datum_id)
305
- return self.datum_id_to_index[datum_id]
466
+ if datum_id not in self._evaluator.datum_id_to_index:
467
+ if len(self._evaluator.datum_id_to_index) != len(
468
+ self._evaluator.index_to_datum_id
469
+ ):
470
+ raise InternalCacheException("datum cache size mismatch")
471
+ idx = len(self._evaluator.datum_id_to_index)
472
+ self._evaluator.datum_id_to_index[datum_id] = idx
473
+ self._evaluator.index_to_datum_id.append(datum_id)
474
+ return self._evaluator.datum_id_to_index[datum_id]
306
475
 
307
476
  def _add_groundtruth(self, annotation_id: str) -> int:
308
477
  """
@@ -318,15 +487,17 @@ class Evaluator:
318
487
  int
319
488
  The ground truth annotation index.
320
489
  """
321
- if annotation_id not in self.groundtruth_id_to_index:
322
- if len(self.groundtruth_id_to_index) != len(
323
- self.index_to_groundtruth_id
490
+ if annotation_id not in self._evaluator.groundtruth_id_to_index:
491
+ if len(self._evaluator.groundtruth_id_to_index) != len(
492
+ self._evaluator.index_to_groundtruth_id
324
493
  ):
325
- raise RuntimeError("ground truth cache size mismatch")
326
- idx = len(self.groundtruth_id_to_index)
327
- self.groundtruth_id_to_index[annotation_id] = idx
328
- self.index_to_groundtruth_id.append(annotation_id)
329
- return self.groundtruth_id_to_index[annotation_id]
494
+ raise InternalCacheException(
495
+ "ground truth cache size mismatch"
496
+ )
497
+ idx = len(self._evaluator.groundtruth_id_to_index)
498
+ self._evaluator.groundtruth_id_to_index[annotation_id] = idx
499
+ self._evaluator.index_to_groundtruth_id.append(annotation_id)
500
+ return self._evaluator.groundtruth_id_to_index[annotation_id]
330
501
 
331
502
  def _add_prediction(self, annotation_id: str) -> int:
332
503
  """
@@ -342,15 +513,15 @@ class Evaluator:
342
513
  int
343
514
  The prediction annotation index.
344
515
  """
345
- if annotation_id not in self.prediction_id_to_index:
346
- if len(self.prediction_id_to_index) != len(
347
- self.index_to_prediction_id
516
+ if annotation_id not in self._evaluator.prediction_id_to_index:
517
+ if len(self._evaluator.prediction_id_to_index) != len(
518
+ self._evaluator.index_to_prediction_id
348
519
  ):
349
- raise RuntimeError("prediction cache size mismatch")
350
- idx = len(self.prediction_id_to_index)
351
- self.prediction_id_to_index[annotation_id] = idx
352
- self.index_to_prediction_id.append(annotation_id)
353
- return self.prediction_id_to_index[annotation_id]
520
+ raise InternalCacheException("prediction cache size mismatch")
521
+ idx = len(self._evaluator.prediction_id_to_index)
522
+ self._evaluator.prediction_id_to_index[annotation_id] = idx
523
+ self._evaluator.index_to_prediction_id.append(annotation_id)
524
+ return self._evaluator.prediction_id_to_index[annotation_id]
354
525
 
355
526
  def _add_label(self, label: str) -> int:
356
527
  """
@@ -366,14 +537,16 @@ class Evaluator:
366
537
  int
367
538
  Label index.
368
539
  """
369
- label_id = len(self.index_to_label)
370
- if label not in self.label_to_index:
371
- if len(self.label_to_index) != len(self.index_to_label):
372
- raise RuntimeError("label cache size mismatch")
373
- self.label_to_index[label] = label_id
374
- self.index_to_label.append(label)
540
+ label_id = len(self._evaluator.index_to_label)
541
+ if label not in self._evaluator.label_to_index:
542
+ if len(self._evaluator.label_to_index) != len(
543
+ self._evaluator.index_to_label
544
+ ):
545
+ raise InternalCacheException("label cache size mismatch")
546
+ self._evaluator.label_to_index[label] = label_id
547
+ self._evaluator.index_to_label.append(label)
375
548
  label_id += 1
376
- return self.label_to_index[label]
549
+ return self._evaluator.label_to_index[label]
377
550
 
378
551
  def _add_data(
379
552
  self,
@@ -483,13 +656,7 @@ class Evaluator:
483
656
 
484
657
  data = np.array(pairs)
485
658
  if data.size > 0:
486
- # reset filtered cache if it exists
487
- self.clear_filter()
488
- if self._temp_cache is None:
489
- raise RuntimeError(
490
- "cannot add data as evaluator has already been finalized"
491
- )
492
- self._temp_cache.append(data)
659
+ self.pairs.append(data)
493
660
 
494
661
  def add_bounding_boxes(
495
662
  self,
@@ -591,7 +758,7 @@ class Evaluator:
591
758
  show_progress=show_progress,
592
759
  )
593
760
 
594
- def finalize(self):
761
+ def finalize(self) -> Evaluator:
595
762
  """
596
763
  Performs data finalization and some preprocessing steps.
597
764
 
@@ -600,206 +767,37 @@ class Evaluator:
600
767
  Evaluator
601
768
  A ready-to-use evaluator object.
602
769
  """
603
- if self._temp_cache is None:
604
- warnings.warn("evaluator is already finalized or in a bad state")
605
- return self
606
- elif not self._temp_cache:
607
- self._detailed_pairs = np.array([], dtype=np.float64)
608
- self._ranked_pairs = np.array([], dtype=np.float64)
609
- self._label_metadata = np.zeros((self.n_labels, 2), dtype=np.int32)
610
- warnings.warn("no valid pairs")
611
- return self
612
- else:
613
- self._detailed_pairs = np.concatenate(self._temp_cache, axis=0)
614
- self._temp_cache = None
770
+ if not self.pairs:
771
+ raise EmptyEvaluatorException()
772
+
773
+ n_labels = len(self._evaluator.index_to_label)
774
+ n_datums = len(self._evaluator.index_to_datum_id)
775
+
776
+ self._evaluator._detailed_pairs = np.concatenate(self.pairs, axis=0)
777
+ if self._evaluator._detailed_pairs.size == 0:
778
+ raise EmptyEvaluatorException()
615
779
 
616
780
  # order pairs by descending score, iou
617
781
  indices = np.lexsort(
618
782
  (
619
- -self._detailed_pairs[:, 5], # iou
620
- -self._detailed_pairs[:, 6], # score
783
+ -self._evaluator._detailed_pairs[:, 5], # iou
784
+ -self._evaluator._detailed_pairs[:, 6], # score
621
785
  )
622
786
  )
623
- self._detailed_pairs = self._detailed_pairs[indices]
624
- self._label_metadata = compute_label_metadata(
625
- ids=self._detailed_pairs[:, :5].astype(np.int32),
626
- n_labels=self.n_labels,
627
- )
628
- self._ranked_pairs = rank_pairs(
629
- detailed_pairs=self.detailed_pairs,
630
- label_metadata=self._label_metadata,
631
- )
632
- return self
633
-
634
- def apply_filter(
635
- self,
636
- datum_ids: list[str] | None = None,
637
- groundtruth_ids: list[str] | None = None,
638
- prediction_ids: list[str] | None = None,
639
- labels: list[str] | None = None,
640
- ):
641
- """
642
- Apply a filter on the evaluator.
643
-
644
- Can be reset by calling 'clear_filter'.
645
-
646
- Parameters
647
- ----------
648
- datum_uids : list[str], optional
649
- An optional list of string uids representing datums.
650
- groundtruth_ids : list[str], optional
651
- An optional list of string uids representing ground truth annotations.
652
- prediction_ids : list[str], optional
653
- An optional list of string uids representing prediction annotations.
654
- labels : list[str], optional
655
- An optional list of labels.
656
- """
657
- self._filtered_detailed_pairs = self._detailed_pairs.copy()
658
- self._filtered_ranked_pairs = np.array([], dtype=np.float64)
659
- self._filtered_label_metadata = np.zeros(
660
- (self.n_labels, 2), dtype=np.int32
661
- )
662
-
663
- valid_datum_indices = None
664
- if datum_ids is not None:
665
- if not datum_ids:
666
- self._filtered_detailed_pairs = np.array([], dtype=np.float64)
667
- warnings.warn("no valid filtered pairs")
668
- return
669
- valid_datum_indices = np.array(
670
- [self.datum_id_to_index[uid] for uid in datum_ids],
671
- dtype=np.int32,
672
- )
673
-
674
- valid_groundtruth_indices = None
675
- if groundtruth_ids is not None:
676
- valid_groundtruth_indices = np.array(
677
- [self.groundtruth_id_to_index[uid] for uid in groundtruth_ids],
678
- dtype=np.int32,
679
- )
680
-
681
- valid_prediction_indices = None
682
- if prediction_ids is not None:
683
- valid_prediction_indices = np.array(
684
- [self.prediction_id_to_index[uid] for uid in prediction_ids],
685
- dtype=np.int32,
686
- )
687
-
688
- valid_label_indices = None
689
- if labels is not None:
690
- if not labels:
691
- self._filtered_detailed_pairs = np.array([], dtype=np.float64)
692
- warnings.warn("no valid filtered pairs")
693
- return
694
- valid_label_indices = np.array(
695
- [self.label_to_index[label] for label in labels] + [-1]
696
- )
697
-
698
- # filter datums
699
- if valid_datum_indices is not None:
700
- mask_valid_datums = np.isin(
701
- self._filtered_detailed_pairs[:, 0], valid_datum_indices
702
- )
703
- self._filtered_detailed_pairs = self._filtered_detailed_pairs[
704
- mask_valid_datums
705
- ]
706
-
707
- n_rows = self._filtered_detailed_pairs.shape[0]
708
- mask_invalid_groundtruths = np.zeros(n_rows, dtype=np.bool_)
709
- mask_invalid_predictions = np.zeros_like(mask_invalid_groundtruths)
710
-
711
- # filter ground truth annotations
712
- if valid_groundtruth_indices is not None:
713
- mask_invalid_groundtruths[
714
- ~np.isin(
715
- self._filtered_detailed_pairs[:, 1],
716
- valid_groundtruth_indices,
717
- )
718
- ] = True
719
-
720
- # filter prediction annotations
721
- if valid_prediction_indices is not None:
722
- mask_invalid_predictions[
723
- ~np.isin(
724
- self._filtered_detailed_pairs[:, 2],
725
- valid_prediction_indices,
726
- )
727
- ] = True
728
-
729
- # filter labels
730
- if valid_label_indices is not None:
731
- mask_invalid_groundtruths[
732
- ~np.isin(
733
- self._filtered_detailed_pairs[:, 3], valid_label_indices
734
- )
735
- ] = True
736
- mask_invalid_predictions[
737
- ~np.isin(
738
- self._filtered_detailed_pairs[:, 4], valid_label_indices
739
- )
740
- ] = True
741
-
742
- # filter cache
743
- if mask_invalid_groundtruths.any():
744
- invalid_groundtruth_indices = np.where(mask_invalid_groundtruths)[
745
- 0
746
- ]
747
- self._filtered_detailed_pairs[
748
- invalid_groundtruth_indices[:, None], (1, 3, 5)
749
- ] = np.array([[-1, -1, 0]])
750
-
751
- if mask_invalid_predictions.any():
752
- invalid_prediction_indices = np.where(mask_invalid_predictions)[0]
753
- self._filtered_detailed_pairs[
754
- invalid_prediction_indices[:, None], (2, 4, 5, 6)
755
- ] = np.array([[-1, -1, 0, -1]])
756
-
757
- # filter null pairs
758
- mask_null_pairs = np.all(
759
- np.isclose(
760
- self._filtered_detailed_pairs[:, 1:5],
761
- np.array([-1.0, -1.0, -1.0, -1.0]),
762
- ),
763
- axis=1,
764
- )
765
- self._filtered_detailed_pairs = self._filtered_detailed_pairs[
766
- ~mask_null_pairs
787
+ self._evaluator._detailed_pairs = self._evaluator._detailed_pairs[
788
+ indices
767
789
  ]
768
-
769
- if self._filtered_detailed_pairs.size == 0:
770
- self._ranked_pairs = np.array([], dtype=np.float64)
771
- self._label_metadata = np.zeros((self.n_labels, 2), dtype=np.int32)
772
- warnings.warn("no valid filtered pairs")
773
- return
774
-
775
- # sorts by score, iou with ground truth id as a tie-breaker
776
- indices = np.lexsort(
777
- (
778
- self._filtered_detailed_pairs[:, 1], # ground truth id
779
- -self._filtered_detailed_pairs[:, 5], # iou
780
- -self._filtered_detailed_pairs[:, 6], # score
781
- )
790
+ self._evaluator._label_metadata = compute_label_metadata(
791
+ ids=self._evaluator._detailed_pairs[:, :5].astype(np.int32),
792
+ n_labels=n_labels,
782
793
  )
783
- self._filtered_detailed_pairs = self._filtered_detailed_pairs[indices]
784
- self._filtered_label_metadata = compute_label_metadata(
785
- ids=self._filtered_detailed_pairs[:, :5].astype(np.int32),
786
- n_labels=self.n_labels,
794
+ self._evaluator._ranked_pairs = rank_pairs(
795
+ detailed_pairs=self._evaluator._detailed_pairs,
796
+ label_metadata=self._evaluator._label_metadata,
787
797
  )
788
- self._filtered_ranked_pairs = rank_pairs(
789
- detailed_pairs=self._filtered_detailed_pairs,
790
- label_metadata=self._filtered_label_metadata,
798
+ self._evaluator._metadata = Metadata.create(
799
+ detailed_pairs=self._evaluator._detailed_pairs,
800
+ number_of_datums=n_datums,
801
+ number_of_labels=n_labels,
791
802
  )
792
-
793
- def clear_filter(self):
794
- """Removes a filter if one exists."""
795
- self._filtered_detailed_pairs = None
796
- self._filtered_ranked_pairs = None
797
- self._filtered_label_metadata = None
798
-
799
-
800
- class DataLoader(Evaluator):
801
- """
802
- Used for backwards compatibility as the Evaluator now handles ingestion.
803
- """
804
-
805
- pass
803
+ return self._evaluator