valor-lite 0.35.0__py3-none-any.whl → 0.36.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of valor-lite might be problematic. Click here for more details.

@@ -1,4 +1,5 @@
1
1
  import warnings
2
+ from dataclasses import asdict, dataclass
2
3
 
3
4
  import numpy as np
4
5
  from numpy.typing import NDArray
@@ -17,6 +18,7 @@ from valor_lite.object_detection.computation import (
17
18
  compute_label_metadata,
18
19
  compute_polygon_iou,
19
20
  compute_precion_recall,
21
+ filter_cache,
20
22
  rank_pairs,
21
23
  )
22
24
  from valor_lite.object_detection.metric import Metric, MetricType
@@ -46,6 +48,53 @@ filtered_metrics = evaluator.evaluate(iou_thresholds=[0.5], filter_mask=filter_m
46
48
  """
47
49
 
48
50
 
51
+ @dataclass
52
+ class Metadata:
53
+ number_of_datums: int = 0
54
+ number_of_ground_truths: int = 0
55
+ number_of_predictions: int = 0
56
+ number_of_labels: int = 0
57
+
58
+ @classmethod
59
+ def create(
60
+ cls,
61
+ detailed_pairs: NDArray[np.float64],
62
+ number_of_datums: int,
63
+ number_of_labels: int,
64
+ ):
65
+ # count number of ground truths
66
+ mask_valid_gts = detailed_pairs[:, 1] >= 0
67
+ unique_ids = np.unique(
68
+ detailed_pairs[np.ix_(mask_valid_gts, (0, 1))], axis=0 # type: ignore - np.ix_ typing
69
+ )
70
+ number_of_ground_truths = int(unique_ids.shape[0])
71
+
72
+ # count number of predictions
73
+ mask_valid_pds = detailed_pairs[:, 2] >= 0
74
+ unique_ids = np.unique(
75
+ detailed_pairs[np.ix_(mask_valid_pds, (0, 2))], axis=0 # type: ignore - np.ix_ typing
76
+ )
77
+ number_of_predictions = int(unique_ids.shape[0])
78
+
79
+ return cls(
80
+ number_of_datums=number_of_datums,
81
+ number_of_ground_truths=number_of_ground_truths,
82
+ number_of_predictions=number_of_predictions,
83
+ number_of_labels=number_of_labels,
84
+ )
85
+
86
+ def to_dict(self) -> dict[str, int | bool]:
87
+ return asdict(self)
88
+
89
+
90
+ @dataclass
91
+ class Filter:
92
+ mask_datums: NDArray[np.bool_]
93
+ mask_groundtruths: NDArray[np.bool_]
94
+ mask_predictions: NDArray[np.bool_]
95
+ metadata: Metadata
96
+
97
+
49
98
  class Evaluator:
50
99
  """
51
100
  Object Detection Evaluator
@@ -67,80 +116,19 @@ class Evaluator:
67
116
  # temporary cache
68
117
  self._temp_cache: list[NDArray[np.float64]] | None = []
69
118
 
70
- # cache
119
+ # internal cache
71
120
  self._detailed_pairs = np.array([[]], dtype=np.float64)
72
121
  self._ranked_pairs = np.array([[]], dtype=np.float64)
73
122
  self._label_metadata: NDArray[np.int32] = np.array([[]])
74
-
75
- # filter cache
76
- self._filtered_detailed_pairs: NDArray[np.float64] | None = None
77
- self._filtered_ranked_pairs: NDArray[np.float64] | None = None
78
- self._filtered_label_metadata: NDArray[np.int32] | None = None
79
-
80
- @property
81
- def is_filtered(self) -> bool:
82
- return self._filtered_detailed_pairs is not None
83
-
84
- @property
85
- def label_metadata(self) -> NDArray[np.int32]:
86
- return (
87
- self._filtered_label_metadata
88
- if self._filtered_label_metadata is not None
89
- else self._label_metadata
90
- )
91
-
92
- @property
93
- def detailed_pairs(self) -> NDArray[np.float64]:
94
- return (
95
- self._filtered_detailed_pairs
96
- if self._filtered_detailed_pairs is not None
97
- else self._detailed_pairs
98
- )
99
-
100
- @property
101
- def ranked_pairs(self) -> NDArray[np.float64]:
102
- return (
103
- self._filtered_ranked_pairs
104
- if self._filtered_ranked_pairs is not None
105
- else self._ranked_pairs
106
- )
107
-
108
- @property
109
- def n_labels(self) -> int:
110
- """Returns the total number of unique labels."""
111
- return len(self.index_to_label)
112
-
113
- @property
114
- def n_datums(self) -> int:
115
- """Returns the number of datums."""
116
- return np.unique(self.detailed_pairs[:, 0]).size
117
-
118
- @property
119
- def n_groundtruths(self) -> int:
120
- """Returns the number of ground truth annotations."""
121
- mask_valid_gts = self.detailed_pairs[:, 1] >= 0
122
- unique_ids = np.unique(
123
- self.detailed_pairs[np.ix_(mask_valid_gts, (0, 1))], axis=0 # type: ignore - np.ix_ typing
124
- )
125
- return int(unique_ids.shape[0])
126
-
127
- @property
128
- def n_predictions(self) -> int:
129
- """Returns the number of prediction annotations."""
130
- mask_valid_pds = self.detailed_pairs[:, 2] >= 0
131
- unique_ids = np.unique(
132
- self.detailed_pairs[np.ix_(mask_valid_pds, (0, 2))], axis=0 # type: ignore - np.ix_ typing
133
- )
134
- return int(unique_ids.shape[0])
123
+ self._metadata = Metadata()
135
124
 
136
125
  @property
137
126
  def ignored_prediction_labels(self) -> list[str]:
138
127
  """
139
128
  Prediction labels that are not present in the ground truth set.
140
129
  """
141
- label_metadata = self.label_metadata
142
- glabels = set(np.where(label_metadata[:, 0] > 0)[0])
143
- plabels = set(np.where(label_metadata[:, 1] > 0)[0])
130
+ glabels = set(np.where(self._label_metadata[:, 0] > 0)[0])
131
+ plabels = set(np.where(self._label_metadata[:, 1] > 0)[0])
144
132
  return [
145
133
  self.index_to_label[label_id] for label_id in (plabels - glabels)
146
134
  ]
@@ -150,137 +138,18 @@ class Evaluator:
150
138
  """
151
139
  Ground truth labels that are not present in the prediction set.
152
140
  """
153
- label_metadata = self.label_metadata
154
- glabels = set(np.where(label_metadata[:, 0] > 0)[0])
155
- plabels = set(np.where(label_metadata[:, 1] > 0)[0])
141
+ glabels = set(np.where(self._label_metadata[:, 0] > 0)[0])
142
+ plabels = set(np.where(self._label_metadata[:, 1] > 0)[0])
156
143
  return [
157
144
  self.index_to_label[label_id] for label_id in (glabels - plabels)
158
145
  ]
159
146
 
160
147
  @property
161
- def metadata(self) -> dict:
148
+ def metadata(self) -> Metadata:
162
149
  """
163
150
  Evaluation metadata.
164
151
  """
165
- return {
166
- "n_datums": self.n_datums,
167
- "n_groundtruths": self.n_groundtruths,
168
- "n_predictions": self.n_predictions,
169
- "n_labels": self.n_labels,
170
- "ignored_prediction_labels": self.ignored_prediction_labels,
171
- "missing_prediction_labels": self.missing_prediction_labels,
172
- }
173
-
174
- def compute_precision_recall(
175
- self,
176
- iou_thresholds: list[float],
177
- score_thresholds: list[float],
178
- ) -> dict[MetricType, list[Metric]]:
179
- """
180
- Computes all metrics except for ConfusionMatrix
181
-
182
- Parameters
183
- ----------
184
- iou_thresholds : list[float]
185
- A list of IOU thresholds to compute metrics over.
186
- score_thresholds : list[float]
187
- A list of score thresholds to compute metrics over.
188
-
189
- Returns
190
- -------
191
- dict[MetricType, list]
192
- A dictionary mapping MetricType enumerations to lists of computed metrics.
193
- """
194
- if not iou_thresholds:
195
- raise ValueError("At least one IOU threshold must be passed.")
196
- elif not score_thresholds:
197
- raise ValueError("At least one score threshold must be passed.")
198
- results = compute_precion_recall(
199
- ranked_pairs=self.ranked_pairs,
200
- label_metadata=self.label_metadata,
201
- iou_thresholds=np.array(iou_thresholds),
202
- score_thresholds=np.array(score_thresholds),
203
- )
204
- return unpack_precision_recall_into_metric_lists(
205
- results=results,
206
- label_metadata=self.label_metadata,
207
- iou_thresholds=iou_thresholds,
208
- score_thresholds=score_thresholds,
209
- index_to_label=self.index_to_label,
210
- )
211
-
212
- def compute_confusion_matrix(
213
- self,
214
- iou_thresholds: list[float],
215
- score_thresholds: list[float],
216
- ) -> list[Metric]:
217
- """
218
- Computes confusion matrices at various thresholds.
219
-
220
- Parameters
221
- ----------
222
- iou_thresholds : list[float]
223
- A list of IOU thresholds to compute metrics over.
224
- score_thresholds : list[float]
225
- A list of score thresholds to compute metrics over.
226
-
227
- Returns
228
- -------
229
- list[Metric]
230
- List of confusion matrices per threshold pair.
231
- """
232
- if not iou_thresholds:
233
- raise ValueError("At least one IOU threshold must be passed.")
234
- elif not score_thresholds:
235
- raise ValueError("At least one score threshold must be passed.")
236
- elif self.detailed_pairs.size == 0:
237
- warnings.warn("attempted to compute over an empty set")
238
- return []
239
- results = compute_confusion_matrix(
240
- detailed_pairs=self.detailed_pairs,
241
- iou_thresholds=np.array(iou_thresholds),
242
- score_thresholds=np.array(score_thresholds),
243
- )
244
- return unpack_confusion_matrix_into_metric_list(
245
- results=results,
246
- detailed_pairs=self.detailed_pairs,
247
- iou_thresholds=iou_thresholds,
248
- score_thresholds=score_thresholds,
249
- index_to_datum_id=self.index_to_datum_id,
250
- index_to_groundtruth_id=self.index_to_groundtruth_id,
251
- index_to_prediction_id=self.index_to_prediction_id,
252
- index_to_label=self.index_to_label,
253
- )
254
-
255
- def evaluate(
256
- self,
257
- iou_thresholds: list[float] = [0.1, 0.5, 0.75],
258
- score_thresholds: list[float] = [0.5],
259
- ) -> dict[MetricType, list[Metric]]:
260
- """
261
- Computes all available metrics.
262
-
263
- Parameters
264
- ----------
265
- iou_thresholds : list[float], default=[0.1, 0.5, 0.75]
266
- A list of IOU thresholds to compute metrics over.
267
- score_thresholds : list[float], default=[0.5]
268
- A list of score thresholds to compute metrics over.
269
-
270
- Returns
271
- -------
272
- dict[MetricType, list[Metric]]
273
- Lists of metrics organized by metric type.
274
- """
275
- metrics = self.compute_precision_recall(
276
- iou_thresholds=iou_thresholds,
277
- score_thresholds=score_thresholds,
278
- )
279
- metrics[MetricType.ConfusionMatrix] = self.compute_confusion_matrix(
280
- iou_thresholds=iou_thresholds,
281
- score_thresholds=score_thresholds,
282
- )
283
- return metrics
152
+ return self._metadata
284
153
 
285
154
  def _add_datum(self, datum_id: str) -> int:
286
155
  """
@@ -484,7 +353,6 @@ class Evaluator:
484
353
  data = np.array(pairs)
485
354
  if data.size > 0:
486
355
  # data can only be added while the temporary cache still exists (i.e. before finalization)
487
- self.clear_filter()
488
356
  if self._temp_cache is None:
489
357
  raise RuntimeError(
490
358
  "cannot add data as evaluator has already been finalized"
@@ -600,13 +468,16 @@ class Evaluator:
600
468
  Evaluator
601
469
  A ready-to-use evaluator object.
602
470
  """
471
+ n_labels = len(self.index_to_label)
472
+ n_datums = len(self.index_to_datum_id)
603
473
  if self._temp_cache is None:
604
474
  warnings.warn("evaluator is already finalized or in a bad state")
605
475
  return self
606
476
  elif not self._temp_cache:
607
477
  self._detailed_pairs = np.array([], dtype=np.float64)
608
478
  self._ranked_pairs = np.array([], dtype=np.float64)
609
- self._label_metadata = np.zeros((self.n_labels, 2), dtype=np.int32)
479
+ self._label_metadata = np.zeros((n_labels, 2), dtype=np.int32)
480
+ self._metadata = Metadata()
610
481
  warnings.warn("no valid pairs")
611
482
  return self
612
483
  else:
@@ -623,178 +494,316 @@ class Evaluator:
623
494
  self._detailed_pairs = self._detailed_pairs[indices]
624
495
  self._label_metadata = compute_label_metadata(
625
496
  ids=self._detailed_pairs[:, :5].astype(np.int32),
626
- n_labels=self.n_labels,
497
+ n_labels=n_labels,
627
498
  )
628
499
  self._ranked_pairs = rank_pairs(
629
- detailed_pairs=self.detailed_pairs,
500
+ detailed_pairs=self._detailed_pairs,
630
501
  label_metadata=self._label_metadata,
631
502
  )
503
+ self._metadata = Metadata.create(
504
+ detailed_pairs=self._detailed_pairs,
505
+ number_of_datums=n_datums,
506
+ number_of_labels=n_labels,
507
+ )
632
508
  return self
633
509
 
634
- def apply_filter(
510
+ def create_filter(
635
511
  self,
636
512
  datum_ids: list[str] | None = None,
637
513
  groundtruth_ids: list[str] | None = None,
638
514
  prediction_ids: list[str] | None = None,
639
515
  labels: list[str] | None = None,
640
- ):
516
+ ) -> Filter:
641
517
  """
642
- Apply a filter on the evaluator.
643
-
644
- Can be reset by calling 'clear_filter'.
518
+ Creates a filter object.
645
519
 
646
520
  Parameters
647
521
  ----------
648
522
  datum_ids : list[str], optional
649
- An optional list of string uids representing datums.
523
+ An optional list of string uids representing datums to keep.
650
524
  groundtruth_ids : list[str], optional
651
- An optional list of string uids representing ground truth annotations.
525
+ An optional list of string uids representing ground truth annotations to keep.
652
526
  prediction_ids : list[str], optional
653
- An optional list of string uids representing prediction annotations.
527
+ An optional list of string uids representing prediction annotations to keep.
654
528
  labels : list[str], optional
655
- An optional list of labels.
529
+ An optional list of labels to keep.
656
530
  """
657
- self._filtered_detailed_pairs = self._detailed_pairs.copy()
658
- self._filtered_ranked_pairs = np.array([], dtype=np.float64)
659
- self._filtered_label_metadata = np.zeros(
660
- (self.n_labels, 2), dtype=np.int32
661
- )
531
+ mask_datums = np.ones(self._detailed_pairs.shape[0], dtype=np.bool_)
662
532
 
663
- valid_datum_indices = None
533
+ # filter datums
664
534
  if datum_ids is not None:
665
535
  if not datum_ids:
666
- self._filtered_detailed_pairs = np.array([], dtype=np.float64)
667
- warnings.warn("no valid filtered pairs")
668
- return
536
+ warnings.warn("creating a filter that removes all datums")
537
+ return Filter(
538
+ mask_datums=np.zeros_like(mask_datums),
539
+ mask_groundtruths=np.array([], dtype=np.bool_),
540
+ mask_predictions=np.array([], dtype=np.bool_),
541
+ metadata=Metadata(),
542
+ )
669
543
  valid_datum_indices = np.array(
670
544
  [self.datum_id_to_index[uid] for uid in datum_ids],
671
545
  dtype=np.int32,
672
546
  )
547
+ mask_datums = np.isin(
548
+ self._detailed_pairs[:, 0], valid_datum_indices
549
+ )
550
+
551
+ filtered_detailed_pairs = self._detailed_pairs[mask_datums]
552
+ n_pairs = self._detailed_pairs[mask_datums].shape[0]
553
+ mask_groundtruths = np.zeros(n_pairs, dtype=np.bool_)
554
+ mask_predictions = np.zeros_like(mask_groundtruths)
673
555
 
674
- valid_groundtruth_indices = None
556
+ # filter by ground truth annotation ids
675
557
  if groundtruth_ids is not None:
558
+ if not groundtruth_ids:
559
+ warnings.warn(
560
+ "creating a filter that removes all ground truths"
561
+ )
676
562
  valid_groundtruth_indices = np.array(
677
563
  [self.groundtruth_id_to_index[uid] for uid in groundtruth_ids],
678
564
  dtype=np.int32,
679
565
  )
566
+ mask_groundtruths[
567
+ ~np.isin(
568
+ filtered_detailed_pairs[:, 1],
569
+ valid_groundtruth_indices,
570
+ )
571
+ ] = True
680
572
 
681
- valid_prediction_indices = None
573
+ # filter by prediction annotation ids
682
574
  if prediction_ids is not None:
575
+ if not prediction_ids:
576
+ warnings.warn("creating a filter that removes all predictions")
683
577
  valid_prediction_indices = np.array(
684
578
  [self.prediction_id_to_index[uid] for uid in prediction_ids],
685
579
  dtype=np.int32,
686
580
  )
581
+ mask_predictions[
582
+ ~np.isin(
583
+ filtered_detailed_pairs[:, 2],
584
+ valid_prediction_indices,
585
+ )
586
+ ] = True
687
587
 
688
- valid_label_indices = None
588
+ # filter by labels
689
589
  if labels is not None:
690
590
  if not labels:
691
- self._filtered_detailed_pairs = np.array([], dtype=np.float64)
692
- warnings.warn("no valid filtered pairs")
693
- return
591
+ warnings.warn("creating a filter that removes all labels")
592
+ return Filter(
593
+ mask_datums=mask_datums,
594
+ mask_groundtruths=np.ones_like(mask_datums),
595
+ mask_predictions=np.ones_like(mask_datums),
596
+ metadata=Metadata(),
597
+ )
694
598
  valid_label_indices = np.array(
695
599
  [self.label_to_index[label] for label in labels] + [-1]
696
600
  )
697
-
698
- # filter datums
699
- if valid_datum_indices is not None:
700
- mask_valid_datums = np.isin(
701
- self._filtered_detailed_pairs[:, 0], valid_datum_indices
702
- )
703
- self._filtered_detailed_pairs = self._filtered_detailed_pairs[
704
- mask_valid_datums
705
- ]
706
-
707
- n_rows = self._filtered_detailed_pairs.shape[0]
708
- mask_invalid_groundtruths = np.zeros(n_rows, dtype=np.bool_)
709
- mask_invalid_predictions = np.zeros_like(mask_invalid_groundtruths)
710
-
711
- # filter ground truth annotations
712
- if valid_groundtruth_indices is not None:
713
- mask_invalid_groundtruths[
714
- ~np.isin(
715
- self._filtered_detailed_pairs[:, 1],
716
- valid_groundtruth_indices,
717
- )
601
+ mask_groundtruths[
602
+ ~np.isin(filtered_detailed_pairs[:, 3], valid_label_indices)
718
603
  ] = True
719
-
720
- # filter prediction annotations
721
- if valid_prediction_indices is not None:
722
- mask_invalid_predictions[
723
- ~np.isin(
724
- self._filtered_detailed_pairs[:, 2],
725
- valid_prediction_indices,
726
- )
604
+ mask_predictions[
605
+ ~np.isin(filtered_detailed_pairs[:, 4], valid_label_indices)
727
606
  ] = True
728
607
 
729
- # filter labels
730
- if valid_label_indices is not None:
731
- mask_invalid_groundtruths[
732
- ~np.isin(
733
- self._filtered_detailed_pairs[:, 3], valid_label_indices
734
- )
735
- ] = True
736
- mask_invalid_predictions[
737
- ~np.isin(
738
- self._filtered_detailed_pairs[:, 4], valid_label_indices
739
- )
740
- ] = True
608
+ filtered_detailed_pairs, _, _ = filter_cache(
609
+ self._detailed_pairs,
610
+ mask_datums=mask_datums,
611
+ mask_ground_truths=mask_groundtruths,
612
+ mask_predictions=mask_predictions,
613
+ n_labels=len(self.index_to_label),
614
+ )
741
615
 
742
- # filter cache
743
- if mask_invalid_groundtruths.any():
744
- invalid_groundtruth_indices = np.where(mask_invalid_groundtruths)[
745
- 0
746
- ]
747
- self._filtered_detailed_pairs[
748
- invalid_groundtruth_indices[:, None], (1, 3, 5)
749
- ] = np.array([[-1, -1, 0]])
750
-
751
- if mask_invalid_predictions.any():
752
- invalid_prediction_indices = np.where(mask_invalid_predictions)[0]
753
- self._filtered_detailed_pairs[
754
- invalid_prediction_indices[:, None], (2, 4, 5, 6)
755
- ] = np.array([[-1, -1, 0, -1]])
756
-
757
- # filter null pairs
758
- mask_null_pairs = np.all(
759
- np.isclose(
760
- self._filtered_detailed_pairs[:, 1:5],
761
- np.array([-1.0, -1.0, -1.0, -1.0]),
616
+ number_of_datums = (
617
+ len(datum_ids)
618
+ if datum_ids
619
+ else np.unique(filtered_detailed_pairs[:, 0]).size
620
+ )
621
+
622
+ return Filter(
623
+ mask_datums=mask_datums,
624
+ mask_groundtruths=mask_groundtruths,
625
+ mask_predictions=mask_predictions,
626
+ metadata=Metadata.create(
627
+ detailed_pairs=filtered_detailed_pairs,
628
+ number_of_datums=number_of_datums,
629
+ number_of_labels=len(self.index_to_label),
762
630
  ),
763
- axis=1,
764
631
  )
765
- self._filtered_detailed_pairs = self._filtered_detailed_pairs[
766
- ~mask_null_pairs
767
- ]
768
632
 
769
- if self._filtered_detailed_pairs.size == 0:
770
- self._ranked_pairs = np.array([], dtype=np.float64)
771
- self._label_metadata = np.zeros((self.n_labels, 2), dtype=np.int32)
772
- warnings.warn("no valid filtered pairs")
773
- return
633
+ def filter(
634
+ self, filter_: Filter
635
+ ) -> tuple[NDArray[np.float64], NDArray[np.float64], NDArray[np.int32],]:
636
+ """
637
+ Performs filtering over the internal cache.
774
638
 
775
- # sorts by score, iou with ground truth id as a tie-breaker
776
- indices = np.lexsort(
777
- (
778
- self._filtered_detailed_pairs[:, 1], # ground truth id
779
- -self._filtered_detailed_pairs[:, 5], # iou
780
- -self._filtered_detailed_pairs[:, 6], # score
639
+ Parameters
640
+ ----------
641
+ filter_ : Filter
642
+ The filter parameterization.
643
+
644
+ Returns
645
+ -------
646
+ NDArray[float64]
647
+ Filtered detailed pairs.
648
+ NDArray[float64]
649
+ Filtered ranked pairs.
650
+ NDArray[int32]
651
+ Label metadata.
652
+ """
653
+ if not filter_.mask_datums.any():
654
+ warnings.warn("filter removed all datums")
655
+ return (
656
+ np.array([], dtype=np.float64),
657
+ np.array([], dtype=np.float64),
658
+ np.zeros((self.metadata.number_of_labels, 2), dtype=np.int32),
781
659
  )
660
+ if filter_.mask_groundtruths.all():
661
+ warnings.warn("filter removed all ground truths")
662
+ if filter_.mask_predictions.all():
663
+ warnings.warn("filter removed all predictions")
664
+ return filter_cache(
665
+ detailed_pairs=self._detailed_pairs,
666
+ mask_datums=filter_.mask_datums,
667
+ mask_ground_truths=filter_.mask_groundtruths,
668
+ mask_predictions=filter_.mask_predictions,
669
+ n_labels=len(self.index_to_label),
670
+ )
671
+
672
+ def compute_precision_recall(
673
+ self,
674
+ iou_thresholds: list[float],
675
+ score_thresholds: list[float],
676
+ filter_: Filter | None = None,
677
+ ) -> dict[MetricType, list[Metric]]:
678
+ """
679
+ Computes all metrics except for ConfusionMatrix
680
+
681
+ Parameters
682
+ ----------
683
+ iou_thresholds : list[float]
684
+ A list of IOU thresholds to compute metrics over.
685
+ score_thresholds : list[float]
686
+ A list of score thresholds to compute metrics over.
687
+ filter_ : Filter, optional
688
+ A collection of filter parameters and masks.
689
+
690
+ Returns
691
+ -------
692
+ dict[MetricType, list]
693
+ A dictionary mapping MetricType enumerations to lists of computed metrics.
694
+ """
695
+ if not iou_thresholds:
696
+ raise ValueError("At least one IOU threshold must be passed.")
697
+ elif not score_thresholds:
698
+ raise ValueError("At least one score threshold must be passed.")
699
+
700
+ if filter_ is not None:
701
+ _, ranked_pairs, label_metadata = self.filter(filter_=filter_)
702
+ else:
703
+ ranked_pairs = self._ranked_pairs
704
+ label_metadata = self._label_metadata
705
+
706
+ results = compute_precion_recall(
707
+ ranked_pairs=ranked_pairs,
708
+ label_metadata=label_metadata,
709
+ iou_thresholds=np.array(iou_thresholds),
710
+ score_thresholds=np.array(score_thresholds),
711
+ )
712
+ return unpack_precision_recall_into_metric_lists(
713
+ results=results,
714
+ label_metadata=label_metadata,
715
+ iou_thresholds=iou_thresholds,
716
+ score_thresholds=score_thresholds,
717
+ index_to_label=self.index_to_label,
782
718
  )
783
- self._filtered_detailed_pairs = self._filtered_detailed_pairs[indices]
784
- self._filtered_label_metadata = compute_label_metadata(
785
- ids=self._filtered_detailed_pairs[:, :5].astype(np.int32),
786
- n_labels=self.n_labels,
719
+
720
+ def compute_confusion_matrix(
721
+ self,
722
+ iou_thresholds: list[float],
723
+ score_thresholds: list[float],
724
+ filter_: Filter | None = None,
725
+ ) -> list[Metric]:
726
+ """
727
+ Computes confusion matrices at various thresholds.
728
+
729
+ Parameters
730
+ ----------
731
+ iou_thresholds : list[float]
732
+ A list of IOU thresholds to compute metrics over.
733
+ score_thresholds : list[float]
734
+ A list of score thresholds to compute metrics over.
735
+ filter_ : Filter, optional
736
+ A collection of filter parameters and masks.
737
+
738
+ Returns
739
+ -------
740
+ list[Metric]
741
+ List of confusion matrices per threshold pair.
742
+ """
743
+ if not iou_thresholds:
744
+ raise ValueError("At least one IOU threshold must be passed.")
745
+ elif not score_thresholds:
746
+ raise ValueError("At least one score threshold must be passed.")
747
+
748
+ if filter_ is not None:
749
+ detailed_pairs, _, _ = self.filter(filter_=filter_)
750
+ else:
751
+ detailed_pairs = self._detailed_pairs
752
+
753
+ if detailed_pairs.size == 0:
754
+ warnings.warn("attempted to compute over an empty set")
755
+ return []
756
+
757
+ results = compute_confusion_matrix(
758
+ detailed_pairs=detailed_pairs,
759
+ iou_thresholds=np.array(iou_thresholds),
760
+ score_thresholds=np.array(score_thresholds),
787
761
  )
788
- self._filtered_ranked_pairs = rank_pairs(
789
- detailed_pairs=self._filtered_detailed_pairs,
790
- label_metadata=self._filtered_label_metadata,
762
+ return unpack_confusion_matrix_into_metric_list(
763
+ results=results,
764
+ detailed_pairs=detailed_pairs,
765
+ iou_thresholds=iou_thresholds,
766
+ score_thresholds=score_thresholds,
767
+ index_to_datum_id=self.index_to_datum_id,
768
+ index_to_groundtruth_id=self.index_to_groundtruth_id,
769
+ index_to_prediction_id=self.index_to_prediction_id,
770
+ index_to_label=self.index_to_label,
791
771
  )
792
772
 
793
- def clear_filter(self):
794
- """Removes a filter if one exists."""
795
- self._filtered_detailed_pairs = None
796
- self._filtered_ranked_pairs = None
797
- self._filtered_label_metadata = None
773
+ def evaluate(
774
+ self,
775
+ iou_thresholds: list[float] = [0.1, 0.5, 0.75],
776
+ score_thresholds: list[float] = [0.5],
777
+ filter_: Filter | None = None,
778
+ ) -> dict[MetricType, list[Metric]]:
779
+ """
780
+ Computes all available metrics.
781
+
782
+ Parameters
783
+ ----------
784
+ iou_thresholds : list[float], default=[0.1, 0.5, 0.75]
785
+ A list of IOU thresholds to compute metrics over.
786
+ score_thresholds : list[float], default=[0.5]
787
+ A list of score thresholds to compute metrics over.
788
+ filter_ : Filter, optional
789
+ A collection of filter parameters and masks.
790
+
791
+ Returns
792
+ -------
793
+ dict[MetricType, list[Metric]]
794
+ Lists of metrics organized by metric type.
795
+ """
796
+ metrics = self.compute_precision_recall(
797
+ iou_thresholds=iou_thresholds,
798
+ score_thresholds=score_thresholds,
799
+ filter_=filter_,
800
+ )
801
+ metrics[MetricType.ConfusionMatrix] = self.compute_confusion_matrix(
802
+ iou_thresholds=iou_thresholds,
803
+ score_thresholds=score_thresholds,
804
+ filter_=filter_,
805
+ )
806
+ return metrics
798
807
 
799
808
 
800
809
  class DataLoader(Evaluator):