valor-lite 0.36.6__py3-none-any.whl → 0.37.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. valor_lite/cache/__init__.py +11 -0
  2. valor_lite/cache/compute.py +211 -0
  3. valor_lite/cache/ephemeral.py +302 -0
  4. valor_lite/cache/persistent.py +536 -0
  5. valor_lite/classification/__init__.py +5 -10
  6. valor_lite/classification/annotation.py +4 -0
  7. valor_lite/classification/computation.py +233 -251
  8. valor_lite/classification/evaluator.py +882 -0
  9. valor_lite/classification/loader.py +97 -0
  10. valor_lite/classification/metric.py +141 -4
  11. valor_lite/classification/shared.py +184 -0
  12. valor_lite/classification/utilities.py +221 -118
  13. valor_lite/exceptions.py +5 -0
  14. valor_lite/object_detection/__init__.py +5 -4
  15. valor_lite/object_detection/annotation.py +13 -1
  16. valor_lite/object_detection/computation.py +368 -299
  17. valor_lite/object_detection/evaluator.py +804 -0
  18. valor_lite/object_detection/loader.py +292 -0
  19. valor_lite/object_detection/metric.py +152 -3
  20. valor_lite/object_detection/shared.py +206 -0
  21. valor_lite/object_detection/utilities.py +182 -100
  22. valor_lite/semantic_segmentation/__init__.py +5 -4
  23. valor_lite/semantic_segmentation/annotation.py +7 -0
  24. valor_lite/semantic_segmentation/computation.py +20 -110
  25. valor_lite/semantic_segmentation/evaluator.py +414 -0
  26. valor_lite/semantic_segmentation/loader.py +205 -0
  27. valor_lite/semantic_segmentation/shared.py +149 -0
  28. valor_lite/semantic_segmentation/utilities.py +6 -23
  29. {valor_lite-0.36.6.dist-info → valor_lite-0.37.5.dist-info}/METADATA +3 -1
  30. valor_lite-0.37.5.dist-info/RECORD +49 -0
  31. {valor_lite-0.36.6.dist-info → valor_lite-0.37.5.dist-info}/WHEEL +1 -1
  32. valor_lite/classification/manager.py +0 -545
  33. valor_lite/object_detection/manager.py +0 -864
  34. valor_lite/profiling.py +0 -374
  35. valor_lite/semantic_segmentation/benchmark.py +0 -237
  36. valor_lite/semantic_segmentation/manager.py +0 -446
  37. valor_lite-0.36.6.dist-info/RECORD +0 -41
  38. {valor_lite-0.36.6.dist-info → valor_lite-0.37.5.dist-info}/top_level.txt +0 -0
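
The largest single change is the removal of valor_lite/object_detection/manager.py, shown in full below; its DataLoader/Evaluator workflow appears to be superseded by the new evaluator.py and loader.py modules. For orientation, here is a minimal sketch of the 0.36.6 workflow that the hunk deletes, based only on the removed module's own usage docstring and method signatures; `detections` is a hypothetical, already-built list of Detection[BoundingBox] objects:

# Sketch of the 0.36.6 API removed below; `detections` is assumed to exist.
from valor_lite.object_detection.manager import DataLoader
from valor_lite.object_detection.metric import MetricType

loader = DataLoader()
loader.add_bounding_boxes(detections=detections)  # detections: list[Detection[BoundingBox]]
evaluator = loader.finalize()

metrics = evaluator.evaluate(iou_thresholds=[0.5], score_thresholds=[0.5])
ap_metrics = metrics[MetricType.AP]
confusion_matrices = metrics[MetricType.ConfusionMatrix]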
valor_lite/object_detection/manager.py (deleted in 0.37.5)
@@ -1,864 +0,0 @@
- import warnings
- from dataclasses import asdict, dataclass
-
- import numpy as np
- from numpy.typing import NDArray
- from tqdm import tqdm
-
- from valor_lite.exceptions import (
-     EmptyEvaluatorError,
-     EmptyFilterError,
-     InternalCacheError,
- )
- from valor_lite.object_detection.annotation import (
-     Bitmask,
-     BoundingBox,
-     Detection,
-     Polygon,
- )
- from valor_lite.object_detection.computation import (
-     compute_bbox_iou,
-     compute_bitmask_iou,
-     compute_confusion_matrix,
-     compute_label_metadata,
-     compute_polygon_iou,
-     compute_precion_recall,
-     filter_cache,
-     rank_pairs,
- )
- from valor_lite.object_detection.metric import Metric, MetricType
- from valor_lite.object_detection.utilities import (
-     unpack_confusion_matrix_into_metric_list,
-     unpack_precision_recall_into_metric_lists,
- )
-
- """
- Usage
- -----
-
- loader = DataLoader()
- loader.add_bounding_boxes(
-     groundtruths=groundtruths,
-     predictions=predictions,
- )
- evaluator = loader.finalize()
-
- metrics = evaluator.evaluate(iou_thresholds=[0.5])
-
- ap_metrics = metrics[MetricType.AP]
- ar_metrics = metrics[MetricType.AR]
-
- filter_mask = evaluator.create_filter(datum_uids=["uid1", "uid2"])
- filtered_metrics = evaluator.evaluate(iou_thresholds=[0.5], filter_mask=filter_mask)
- """
-
-
- @dataclass
- class Metadata:
-     number_of_datums: int = 0
-     number_of_ground_truths: int = 0
-     number_of_predictions: int = 0
-     number_of_labels: int = 0
-
-     @classmethod
-     def create(
-         cls,
-         detailed_pairs: NDArray[np.float64],
-         number_of_datums: int,
-         number_of_labels: int,
-     ):
-         # count number of ground truths
-         mask_valid_gts = detailed_pairs[:, 1] >= 0
-         unique_ids = np.unique(
-             detailed_pairs[np.ix_(mask_valid_gts, (0, 1))], axis=0  # type: ignore - np.ix_ typing
-         )
-         number_of_ground_truths = int(unique_ids.shape[0])
-
-         # count number of predictions
-         mask_valid_pds = detailed_pairs[:, 2] >= 0
-         unique_ids = np.unique(
-             detailed_pairs[np.ix_(mask_valid_pds, (0, 2))], axis=0  # type: ignore - np.ix_ typing
-         )
-         number_of_predictions = int(unique_ids.shape[0])
-
-         return cls(
-             number_of_datums=number_of_datums,
-             number_of_ground_truths=number_of_ground_truths,
-             number_of_predictions=number_of_predictions,
-             number_of_labels=number_of_labels,
-         )
-
-     def to_dict(self) -> dict[str, int | bool]:
-         return asdict(self)
-
-
- @dataclass
- class Filter:
-     mask_datums: NDArray[np.bool_]
-     mask_groundtruths: NDArray[np.bool_]
-     mask_predictions: NDArray[np.bool_]
-     metadata: Metadata
-
-     def __post_init__(self):
-         # validate datums mask
-         if not self.mask_datums.any():
-             raise EmptyFilterError("filter removes all datums")
-
-         # validate annotation masks
-         no_gts = self.mask_groundtruths.all()
-         no_pds = self.mask_predictions.all()
-         if no_gts and no_pds:
-             raise EmptyFilterError("filter removes all annotations")
-         elif no_gts:
-             warnings.warn("filter removes all ground truths")
-         elif no_pds:
-             warnings.warn("filter removes all predictions")
-
-
- class Evaluator:
-     """
-     Object Detection Evaluator
-     """
-
-     def __init__(self):
-
-         # external reference
-         self.datum_id_to_index: dict[str, int] = {}
-         self.groundtruth_id_to_index: dict[str, int] = {}
-         self.prediction_id_to_index: dict[str, int] = {}
-         self.label_to_index: dict[str, int] = {}
-
-         self.index_to_datum_id: list[str] = []
-         self.index_to_groundtruth_id: list[str] = []
-         self.index_to_prediction_id: list[str] = []
-         self.index_to_label: list[str] = []
-
-         # temporary cache
-         self._temp_cache: list[NDArray[np.float64]] | None = []
-
-         # internal cache
-         self._detailed_pairs = np.array([[]], dtype=np.float64)
-         self._ranked_pairs = np.array([[]], dtype=np.float64)
-         self._label_metadata: NDArray[np.int32] = np.array([[]])
-         self._metadata = Metadata()
-
-     @property
-     def ignored_prediction_labels(self) -> list[str]:
-         """
-         Prediction labels that are not present in the ground truth set.
-         """
-         glabels = set(np.where(self._label_metadata[:, 0] > 0)[0])
-         plabels = set(np.where(self._label_metadata[:, 1] > 0)[0])
-         return [
-             self.index_to_label[label_id] for label_id in (plabels - glabels)
-         ]
-
-     @property
-     def missing_prediction_labels(self) -> list[str]:
-         """
-         Ground truth labels that are not present in the prediction set.
-         """
-         glabels = set(np.where(self._label_metadata[:, 0] > 0)[0])
-         plabels = set(np.where(self._label_metadata[:, 1] > 0)[0])
-         return [
-             self.index_to_label[label_id] for label_id in (glabels - plabels)
-         ]
-
-     @property
-     def metadata(self) -> Metadata:
-         """
-         Evaluation metadata.
-         """
-         return self._metadata
-
-     def create_filter(
-         self,
-         datums: list[str] | NDArray[np.int32] | None = None,
-         groundtruths: list[str] | NDArray[np.int32] | None = None,
-         predictions: list[str] | NDArray[np.int32] | None = None,
-         labels: list[str] | NDArray[np.int32] | None = None,
-     ) -> Filter:
-         """
-         Creates a filter object.
-
-         Parameters
-         ----------
-         datum : list[str] | NDArray[int32], optional
-             An optional list of string ids or indices representing datums to keep.
-         groundtruth : list[str] | NDArray[int32], optional
-             An optional list of string ids or indices representing ground truth annotations to keep.
-         prediction : list[str] | NDArray[int32], optional
-             An optional list of string ids or indices representing prediction annotations to keep.
-         labels : list[str] | NDArray[int32], optional
-             An optional list of labels or indices to keep.
-         """
-         mask_datums = np.ones(self._detailed_pairs.shape[0], dtype=np.bool_)
-
-         # filter datums
-         if datums is not None:
-             # convert to indices
-             if isinstance(datums, list):
-                 datums = np.array(
-                     [self.datum_id_to_index[uid] for uid in datums],
-                     dtype=np.int32,
-                 )
-
-             # validate indices
-             if datums.size == 0:
-                 raise EmptyFilterError(
-                     "filter removes all datums"
-                 )  # validate indices
-             elif datums.min() < 0:
-                 raise ValueError(
-                     f"datum index cannot be negative '{datums.min()}'"
-                 )
-             elif datums.max() >= len(self.index_to_datum_id):
-                 raise ValueError(
-                     f"datum index cannot exceed total number of datums '{datums.max()}'"
-                 )
-
-             # apply to mask
-             mask_datums = np.isin(self._detailed_pairs[:, 0], datums)
-
-         filtered_detailed_pairs = self._detailed_pairs[mask_datums]
-         n_pairs = self._detailed_pairs[mask_datums].shape[0]
-         mask_groundtruths = np.zeros(n_pairs, dtype=np.bool_)
-         mask_predictions = np.zeros_like(mask_groundtruths)
-
-         # filter by ground truth annotation ids
-         if groundtruths is not None:
-             # convert to indices
-             if isinstance(groundtruths, list):
-                 groundtruths = np.array(
-                     [
-                         self.groundtruth_id_to_index[uid]
-                         for uid in groundtruths
-                     ],
-                     dtype=np.int32,
-                 )
-
-             # validate indices
-             if groundtruths.size == 0:
-                 warnings.warn("filter removes all ground truths")
-             elif groundtruths.min() < 0:
-                 raise ValueError(
-                     f"groundtruth annotation index cannot be negative '{groundtruths.min()}'"
-                 )
-             elif groundtruths.max() >= len(self.index_to_groundtruth_id):
-                 raise ValueError(
-                     f"groundtruth annotation index cannot exceed total number of groundtruths '{groundtruths.max()}'"
-                 )
-
-             # apply to mask
-             mask_groundtruths[
-                 ~np.isin(
-                     filtered_detailed_pairs[:, 1],
-                     groundtruths,
-                 )
-             ] = True
-
-         # filter by prediction annotation ids
-         if predictions is not None:
-             # convert to indices
-             if isinstance(predictions, list):
-                 predictions = np.array(
-                     [self.prediction_id_to_index[uid] for uid in predictions],
-                     dtype=np.int32,
-                 )
-
-             # validate indices
-             if predictions.size == 0:
-                 warnings.warn("filter removes all predictions")
-             elif predictions.min() < 0:
-                 raise ValueError(
-                     f"prediction annotation index cannot be negative '{predictions.min()}'"
-                 )
-             elif predictions.max() >= len(self.index_to_prediction_id):
-                 raise ValueError(
-                     f"prediction annotation index cannot exceed total number of predictions '{predictions.max()}'"
-                 )
-
-             # apply to mask
-             mask_predictions[
-                 ~np.isin(
-                     filtered_detailed_pairs[:, 2],
-                     predictions,
-                 )
-             ] = True
-
-         # filter by labels
-         if labels is not None:
-             # convert to indices
-             if isinstance(labels, list):
-                 labels = np.array(
-                     [self.label_to_index[label] for label in labels]
-                 )
-
-             # validate indices
-             if labels.size == 0:
-                 raise EmptyFilterError("filter removes all labels")
-             elif labels.min() < 0:
-                 raise ValueError(
-                     f"label index cannot be negative '{labels.min()}'"
-                 )
-             elif labels.max() >= len(self.index_to_label):
-                 raise ValueError(
-                     f"label index cannot exceed total number of labels '{labels.max()}'"
-                 )
-
-             # apply to mask
-             labels = np.concatenate([labels, np.array([-1])])  # add null label
-             mask_groundtruths[
-                 ~np.isin(filtered_detailed_pairs[:, 3], labels)
-             ] = True
-             mask_predictions[
-                 ~np.isin(filtered_detailed_pairs[:, 4], labels)
-             ] = True
-
-         filtered_detailed_pairs, _, _ = filter_cache(
-             self._detailed_pairs,
-             mask_datums=mask_datums,
-             mask_ground_truths=mask_groundtruths,
-             mask_predictions=mask_predictions,
-             n_labels=len(self.index_to_label),
-         )
-
-         number_of_datums = (
-             datums.size
-             if datums is not None
-             else np.unique(filtered_detailed_pairs[:, 0]).size
-         )
-
-         return Filter(
-             mask_datums=mask_datums,
-             mask_groundtruths=mask_groundtruths,
-             mask_predictions=mask_predictions,
-             metadata=Metadata.create(
-                 detailed_pairs=filtered_detailed_pairs,
-                 number_of_datums=number_of_datums,
-                 number_of_labels=len(self.index_to_label),
-             ),
-         )
-
-     def filter(
-         self, filter_: Filter
-     ) -> tuple[NDArray[np.float64], NDArray[np.float64], NDArray[np.int32],]:
-         """
-         Performs filtering over the internal cache.
-
-         Parameters
-         ----------
-         filter_ : Filter
-             The filter parameterization.
-
-         Returns
-         -------
-         NDArray[float64]
-             Filtered detailed pairs.
-         NDArray[float64]
-             Filtered ranked pairs.
-         NDArray[int32]
-             Label metadata.
-         """
-         return filter_cache(
-             detailed_pairs=self._detailed_pairs,
-             mask_datums=filter_.mask_datums,
-             mask_ground_truths=filter_.mask_groundtruths,
-             mask_predictions=filter_.mask_predictions,
-             n_labels=len(self.index_to_label),
-         )
-
-     def compute_precision_recall(
-         self,
-         iou_thresholds: list[float],
-         score_thresholds: list[float],
-         filter_: Filter | None = None,
-     ) -> dict[MetricType, list[Metric]]:
-         """
-         Computes all metrics except for ConfusionMatrix
-
-         Parameters
-         ----------
-         iou_thresholds : list[float]
-             A list of IOU thresholds to compute metrics over.
-         score_thresholds : list[float]
-             A list of score thresholds to compute metrics over.
-         filter_ : Filter, optional
-             A collection of filter parameters and masks.
-
-         Returns
-         -------
-         dict[MetricType, list]
-             A dictionary mapping MetricType enumerations to lists of computed metrics.
-         """
-         if not iou_thresholds:
-             raise ValueError("At least one IOU threshold must be passed.")
-         elif not score_thresholds:
-             raise ValueError("At least one score threshold must be passed.")
-
-         if filter_ is not None:
-             _, ranked_pairs, label_metadata = self.filter(filter_=filter_)
-         else:
-             ranked_pairs = self._ranked_pairs
-             label_metadata = self._label_metadata
-
-         results = compute_precion_recall(
-             ranked_pairs=ranked_pairs,
-             label_metadata=label_metadata,
-             iou_thresholds=np.array(iou_thresholds),
-             score_thresholds=np.array(score_thresholds),
-         )
-         return unpack_precision_recall_into_metric_lists(
-             results=results,
-             iou_thresholds=iou_thresholds,
-             score_thresholds=score_thresholds,
-             index_to_label=self.index_to_label,
-         )
-
-     def compute_confusion_matrix(
-         self,
-         iou_thresholds: list[float],
-         score_thresholds: list[float],
-         filter_: Filter | None = None,
-     ) -> list[Metric]:
-         """
-         Computes confusion matrices at various thresholds.
-
-         Parameters
-         ----------
-         iou_thresholds : list[float]
-             A list of IOU thresholds to compute metrics over.
-         score_thresholds : list[float]
-             A list of score thresholds to compute metrics over.
-         filter_ : Filter, optional
-             A collection of filter parameters and masks.
-
-         Returns
-         -------
-         list[Metric]
-             List of confusion matrices per threshold pair.
-         """
-         if not iou_thresholds:
-             raise ValueError("At least one IOU threshold must be passed.")
-         elif not score_thresholds:
-             raise ValueError("At least one score threshold must be passed.")
-
-         if filter_ is not None:
-             detailed_pairs, _, _ = self.filter(filter_=filter_)
-         else:
-             detailed_pairs = self._detailed_pairs
-
-         if detailed_pairs.size == 0:
-             return []
-
-         results = compute_confusion_matrix(
-             detailed_pairs=detailed_pairs,
-             iou_thresholds=np.array(iou_thresholds),
-             score_thresholds=np.array(score_thresholds),
-         )
-         return unpack_confusion_matrix_into_metric_list(
-             results=results,
-             detailed_pairs=detailed_pairs,
-             iou_thresholds=iou_thresholds,
-             score_thresholds=score_thresholds,
-             index_to_datum_id=self.index_to_datum_id,
-             index_to_groundtruth_id=self.index_to_groundtruth_id,
-             index_to_prediction_id=self.index_to_prediction_id,
-             index_to_label=self.index_to_label,
-         )
-
-     def evaluate(
-         self,
-         iou_thresholds: list[float] = [0.1, 0.5, 0.75],
-         score_thresholds: list[float] = [0.5],
-         filter_: Filter | None = None,
-     ) -> dict[MetricType, list[Metric]]:
-         """
-         Computes all available metrics.
-
-         Parameters
-         ----------
-         iou_thresholds : list[float], default=[0.1, 0.5, 0.75]
-             A list of IOU thresholds to compute metrics over.
-         score_thresholds : list[float], default=[0.5]
-             A list of score thresholds to compute metrics over.
-         filter_ : Filter, optional
-             A collection of filter parameters and masks.
-
-         Returns
-         -------
-         dict[MetricType, list[Metric]]
-             Lists of metrics organized by metric type.
-         """
-         metrics = self.compute_precision_recall(
-             iou_thresholds=iou_thresholds,
-             score_thresholds=score_thresholds,
-             filter_=filter_,
-         )
-         metrics[MetricType.ConfusionMatrix] = self.compute_confusion_matrix(
-             iou_thresholds=iou_thresholds,
-             score_thresholds=score_thresholds,
-             filter_=filter_,
-         )
-         return metrics
-
-
- class DataLoader:
-     """
-     Object Detection DataLoader
-     """
-
-     def __init__(self):
-         self._evaluator = Evaluator()
-         self.pairs: list[NDArray[np.float64]] = list()
-
-     def _add_datum(self, datum_id: str) -> int:
-         """
-         Helper function for adding a datum to the cache.
-
-         Parameters
-         ----------
-         datum_id : str
-             The datum identifier.
-
-         Returns
-         -------
-         int
-             The datum index.
-         """
-         if datum_id not in self._evaluator.datum_id_to_index:
-             if len(self._evaluator.datum_id_to_index) != len(
-                 self._evaluator.index_to_datum_id
-             ):
-                 raise InternalCacheError("datum cache size mismatch")
-             idx = len(self._evaluator.datum_id_to_index)
-             self._evaluator.datum_id_to_index[datum_id] = idx
-             self._evaluator.index_to_datum_id.append(datum_id)
-         return self._evaluator.datum_id_to_index[datum_id]
-
-     def _add_groundtruth(self, annotation_id: str) -> int:
-         """
-         Helper function for adding a ground truth annotation identifier to the cache.
-
-         Parameters
-         ----------
-         annotation_id : str
-             The ground truth annotation identifier.
-
-         Returns
-         -------
-         int
-             The ground truth annotation index.
-         """
-         if annotation_id not in self._evaluator.groundtruth_id_to_index:
-             if len(self._evaluator.groundtruth_id_to_index) != len(
-                 self._evaluator.index_to_groundtruth_id
-             ):
-                 raise InternalCacheError("ground truth cache size mismatch")
-             idx = len(self._evaluator.groundtruth_id_to_index)
-             self._evaluator.groundtruth_id_to_index[annotation_id] = idx
-             self._evaluator.index_to_groundtruth_id.append(annotation_id)
-         return self._evaluator.groundtruth_id_to_index[annotation_id]
-
-     def _add_prediction(self, annotation_id: str) -> int:
-         """
-         Helper function for adding a prediction annotation identifier to the cache.
-
-         Parameters
-         ----------
-         annotation_id : str
-             The prediction annotation identifier.
-
-         Returns
-         -------
-         int
-             The prediction annotation index.
-         """
-         if annotation_id not in self._evaluator.prediction_id_to_index:
-             if len(self._evaluator.prediction_id_to_index) != len(
-                 self._evaluator.index_to_prediction_id
-             ):
-                 raise InternalCacheError("prediction cache size mismatch")
-             idx = len(self._evaluator.prediction_id_to_index)
-             self._evaluator.prediction_id_to_index[annotation_id] = idx
-             self._evaluator.index_to_prediction_id.append(annotation_id)
-         return self._evaluator.prediction_id_to_index[annotation_id]
-
-     def _add_label(self, label: str) -> int:
-         """
-         Helper function for adding a label to the cache.
-
-         Parameters
-         ----------
-         label : str
-             The label associated with the annotation.
-
-         Returns
-         -------
-         int
-             Label index.
-         """
-         label_id = len(self._evaluator.index_to_label)
-         if label not in self._evaluator.label_to_index:
-             if len(self._evaluator.label_to_index) != len(
-                 self._evaluator.index_to_label
-             ):
-                 raise InternalCacheError("label cache size mismatch")
-             self._evaluator.label_to_index[label] = label_id
-             self._evaluator.index_to_label.append(label)
-             label_id += 1
-         return self._evaluator.label_to_index[label]
-
-     def _add_data(
-         self,
-         detections: list[Detection],
-         detection_ious: list[NDArray[np.float64]],
-         show_progress: bool = False,
-     ):
-         """
-         Adds detections to the cache.
-
-         Parameters
-         ----------
-         detections : list[Detection]
-             A list of Detection objects.
-         detection_ious : list[NDArray[np.float64]]
-             A list of arrays containing IOUs per detection.
-         show_progress : bool, default=False
-             Toggle for tqdm progress bar.
-         """
-         disable_tqdm = not show_progress
-         for detection, ious in tqdm(
-             zip(detections, detection_ious), disable=disable_tqdm
-         ):
-             # cache labels and annotation pairs
-             pairs = []
-             datum_idx = self._add_datum(detection.uid)
-             if detection.groundtruths:
-                 for gidx, gann in enumerate(detection.groundtruths):
-                     groundtruth_idx = self._add_groundtruth(gann.uid)
-                     glabel_idx = self._add_label(gann.labels[0])
-                     if (ious[:, gidx] < 1e-9).all():
-                         pairs.extend(
-                             [
-                                 np.array(
-                                     [
-                                         float(datum_idx),
-                                         float(groundtruth_idx),
-                                         -1.0,
-                                         float(glabel_idx),
-                                         -1.0,
-                                         0.0,
-                                         -1.0,
-                                     ]
-                                 )
-                             ]
-                         )
-                     for pidx, pann in enumerate(detection.predictions):
-                         prediction_idx = self._add_prediction(pann.uid)
-                         if (ious[pidx, :] < 1e-9).all():
-                             pairs.extend(
-                                 [
-                                     np.array(
-                                         [
-                                             float(datum_idx),
-                                             -1.0,
-                                             float(prediction_idx),
-                                             -1.0,
-                                             float(self._add_label(plabel)),
-                                             0.0,
-                                             float(pscore),
-                                         ]
-                                     )
-                                     for plabel, pscore in zip(
-                                         pann.labels, pann.scores
-                                     )
-                                 ]
-                             )
-                         if ious[pidx, gidx] >= 1e-9:
-                             pairs.extend(
-                                 [
-                                     np.array(
-                                         [
-                                             float(datum_idx),
-                                             float(groundtruth_idx),
-                                             float(prediction_idx),
-                                             float(self._add_label(glabel)),
-                                             float(self._add_label(plabel)),
-                                             ious[pidx, gidx],
-                                             float(pscore),
-                                         ]
-                                     )
-                                     for glabel in gann.labels
-                                     for plabel, pscore in zip(
-                                         pann.labels, pann.scores
-                                     )
-                                 ]
-                             )
-             elif detection.predictions:
-                 for pidx, pann in enumerate(detection.predictions):
-                     prediction_idx = self._add_prediction(pann.uid)
-                     pairs.extend(
-                         [
-                             np.array(
-                                 [
-                                     float(datum_idx),
-                                     -1.0,
-                                     float(prediction_idx),
-                                     -1.0,
-                                     float(self._add_label(plabel)),
-                                     0.0,
-                                     float(pscore),
-                                 ]
-                             )
-                             for plabel, pscore in zip(pann.labels, pann.scores)
-                         ]
-                     )
-
-             data = np.array(pairs)
-             if data.size > 0:
-                 self.pairs.append(data)
-
-     def add_bounding_boxes(
-         self,
-         detections: list[Detection[BoundingBox]],
-         show_progress: bool = False,
-     ):
-         """
-         Adds bounding box detections to the cache.
-
-         Parameters
-         ----------
-         detections : list[Detection]
-             A list of Detection objects.
-         show_progress : bool, default=False
-             Toggle for tqdm progress bar.
-         """
-         ious = [
-             compute_bbox_iou(
-                 np.array(
-                     [
-                         [gt.extrema, pd.extrema]
-                         for pd in detection.predictions
-                         for gt in detection.groundtruths
-                     ],
-                     dtype=np.float64,
-                 )
-             ).reshape(len(detection.predictions), len(detection.groundtruths))
-             for detection in detections
-         ]
-         return self._add_data(
-             detections=detections,
-             detection_ious=ious,
-             show_progress=show_progress,
-         )
-
-     def add_polygons(
-         self,
-         detections: list[Detection[Polygon]],
-         show_progress: bool = False,
-     ):
-         """
-         Adds polygon detections to the cache.
-
-         Parameters
-         ----------
-         detections : list[Detection]
-             A list of Detection objects.
-         show_progress : bool, default=False
-             Toggle for tqdm progress bar.
-         """
-         ious = [
-             compute_polygon_iou(
-                 np.array(
-                     [
-                         [gt.shape, pd.shape]
-                         for pd in detection.predictions
-                         for gt in detection.groundtruths
-                     ]
-                 )
-             ).reshape(len(detection.predictions), len(detection.groundtruths))
-             for detection in detections
-         ]
-         return self._add_data(
-             detections=detections,
-             detection_ious=ious,
-             show_progress=show_progress,
-         )
-
-     def add_bitmasks(
-         self,
-         detections: list[Detection[Bitmask]],
-         show_progress: bool = False,
-     ):
-         """
-         Adds bitmask detections to the cache.
-
-         Parameters
-         ----------
-         detections : list[Detection]
-             A list of Detection objects.
-         show_progress : bool, default=False
-             Toggle for tqdm progress bar.
-         """
-         ious = [
-             compute_bitmask_iou(
-                 np.array(
-                     [
-                         [gt.mask, pd.mask]
-                         for pd in detection.predictions
-                         for gt in detection.groundtruths
-                     ]
-                 )
-             ).reshape(len(detection.predictions), len(detection.groundtruths))
-             for detection in detections
-         ]
-         return self._add_data(
-             detections=detections,
-             detection_ious=ious,
-             show_progress=show_progress,
-         )
-
-     def finalize(self) -> Evaluator:
-         """
-         Performs data finalization and some preprocessing steps.
-
-         Returns
-         -------
-         Evaluator
-             A ready-to-use evaluator object.
-         """
-         if not self.pairs:
-             raise EmptyEvaluatorError()
-
-         n_labels = len(self._evaluator.index_to_label)
-         n_datums = len(self._evaluator.index_to_datum_id)
-
-         self._evaluator._detailed_pairs = np.concatenate(self.pairs, axis=0)
-         if self._evaluator._detailed_pairs.size == 0:
-             raise EmptyEvaluatorError()
-
-         # order pairs by descending score, iou
-         indices = np.lexsort(
-             (
-                 -self._evaluator._detailed_pairs[:, 5],  # iou
-                 -self._evaluator._detailed_pairs[:, 6],  # score
-             )
-         )
-         self._evaluator._detailed_pairs = self._evaluator._detailed_pairs[
-             indices
-         ]
-         self._evaluator._label_metadata = compute_label_metadata(
-             ids=self._evaluator._detailed_pairs[:, :5].astype(np.int32),
-             n_labels=n_labels,
-         )
-         self._evaluator._ranked_pairs = rank_pairs(
-             detailed_pairs=self._evaluator._detailed_pairs,
-             label_metadata=self._evaluator._label_metadata,
-         )
-         self._evaluator._metadata = Metadata.create(
-             detailed_pairs=self._evaluator._detailed_pairs,
-             number_of_datums=n_datums,
-             number_of_labels=n_labels,
-         )
-         return self._evaluator