valor-lite 0.32.2a2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


@@ -0,0 +1,845 @@
+ from collections import defaultdict
+ from dataclasses import dataclass
+
+ import numpy as np
+ from numpy.typing import NDArray
+ from tqdm import tqdm
+ from valor_lite.detection.annotation import Detection
+ from valor_lite.detection.computation import (
+     compute_detailed_pr_curve,
+     compute_iou,
+     compute_metrics,
+     compute_ranked_pairs,
+ )
+ from valor_lite.detection.metric import (
+     AP,
+     AR,
+     F1,
+     Accuracy,
+     APAveragedOverIOUs,
+     ARAveragedOverScores,
+     Counts,
+     DetailedPrecisionRecallCurve,
+     DetailedPrecisionRecallPoint,
+     MetricType,
+     Precision,
+     PrecisionRecallCurve,
+     Recall,
+     mAP,
+     mAPAveragedOverIOUs,
+     mAR,
+     mARAveragedOverScores,
+ )
+
+ """
+ Usage
+ -----
+
+ manager = DataLoader()
+ manager.add_data(detections)
+ evaluator = manager.finalize()
+
+ metrics = evaluator.evaluate(iou_thresholds=[0.5])
+
+ ap_metrics = metrics[MetricType.AP]
+ ar_metrics = metrics[MetricType.AR]
+
+ filter_ = evaluator.create_filter(datum_uids=["uid1", "uid2"])
+ filtered_metrics = evaluator.evaluate(iou_thresholds=[0.5], filter_=filter_)
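+
+ Detailed PR curves are computed separately; a minimal sketch (the threshold
+ and sample values below are illustrative, not defaults):
+
+ detailed_curves = evaluator.compute_detailed_pr_curve(
+     iou_thresholds=[0.5],
+     score_thresholds=[0.25, 0.75],
+     n_samples=2,
+ )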
+ """
+
+
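+ # A Filter restricts an evaluation to a subset of the cached pairs:
+ # `indices` selects rows of the ranked-pairs array, and `label_metadata`
+ # holds one row per label of [n_groundtruths, n_predictions,
+ # label_key_index] (layout inferred from `finalize` and `create_filter`).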
+ @dataclass
+ class Filter:
+     indices: NDArray[np.int32]
+     label_metadata: NDArray[np.int32]
+
+
+ class Evaluator:
+     def __init__(self):
+
+         # metadata
+         self.n_datums = 0
+         self.n_groundtruths = 0
+         self.n_predictions = 0
+         self.n_labels = 0
+
+         # datum reference
+         self.uid_to_index: dict[str, int] = dict()
+         self.index_to_uid: dict[int, str] = dict()
+
+         # label reference
+         self.label_to_index: dict[tuple[str, str], int] = dict()
+         self.index_to_label: dict[int, tuple[str, str]] = dict()
+
+         # label key reference
+         self.index_to_label_key: dict[int, str] = dict()
+         self.label_key_to_index: dict[str, int] = dict()
+         self.label_index_to_label_key_index: dict[int, int] = dict()
+
+         # computation caches
+         self._detailed_pairs = np.array([])
+         self._ranked_pairs = np.array([])
+         self._label_metadata = np.array([])
+         self._label_metadata_per_datum = np.array([])
+
+     @property
+     def ignored_prediction_labels(self) -> list[tuple[str, str]]:
+         glabels = set(np.where(self._label_metadata[:, 0] > 0)[0])
+         plabels = set(np.where(self._label_metadata[:, 1] > 0)[0])
+         return [
+             self.index_to_label[label_id] for label_id in (plabels - glabels)
+         ]
+
+     @property
+     def missing_prediction_labels(self) -> list[tuple[str, str]]:
+         glabels = set(np.where(self._label_metadata[:, 0] > 0)[0])
+         plabels = set(np.where(self._label_metadata[:, 1] > 0)[0])
+         return [
+             self.index_to_label[label_id] for label_id in (glabels - plabels)
+         ]
+
+     @property
+     def metadata(self) -> dict:
+         return {
+             "n_datums": self.n_datums,
+             "n_groundtruths": self.n_groundtruths,
+             "n_predictions": self.n_predictions,
+             "n_labels": self.n_labels,
+             "ignored_prediction_labels": self.ignored_prediction_labels,
+             "missing_prediction_labels": self.missing_prediction_labels,
+         }
+
+     def create_filter(
+         self,
+         datum_uids: list[str] | NDArray[np.int32] | None = None,
+         labels: list[tuple[str, str]] | NDArray[np.int32] | None = None,
+         label_keys: list[str] | NDArray[np.int32] | None = None,
+     ) -> Filter:
+         """
+         Creates a filter that can be passed to an evaluation.
+
+         Parameters
+         ----------
+         datum_uids : list[str] | NDArray[np.int32], optional
+             An optional list of datum uids or a numpy array of uid indices.
+         labels : list[tuple[str, str]] | NDArray[np.int32], optional
+             An optional list of labels or a numpy array of label indices.
+         label_keys : list[str] | NDArray[np.int32], optional
+             An optional list of label keys or a numpy array of label key indices.
+
+         Returns
+         -------
+         Filter
+             A filter object that can be passed to the `evaluate` method.
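+
+         Examples
+         --------
+         A minimal sketch (the uids and label key below are illustrative):
+
+         >>> filter_ = evaluator.create_filter(
+         ...     datum_uids=["uid1"],
+         ...     label_keys=["class"],
+         ... )
+         >>> metrics = evaluator.evaluate(iou_thresholds=[0.5], filter_=filter_)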
+         """
+         n_rows = self._ranked_pairs.shape[0]
+
+         n_datums = self._label_metadata_per_datum.shape[1]
+         n_labels = self._label_metadata_per_datum.shape[2]
+
+         mask_pairs = np.ones((n_rows, 1), dtype=np.bool_)
+         mask_datums = np.ones(n_datums, dtype=np.bool_)
+         mask_labels = np.ones(n_labels, dtype=np.bool_)
+
+         if datum_uids is not None:
+             if isinstance(datum_uids, list):
+                 datum_uids = np.array(
+                     [self.uid_to_index[uid] for uid in datum_uids],
+                     dtype=np.int32,
+                 )
+             mask = np.zeros_like(mask_pairs, dtype=np.bool_)
+             mask[
+                 np.isin(self._ranked_pairs[:, 0].astype(int), datum_uids)
+             ] = True
+             mask_pairs &= mask
+
+             mask = np.zeros_like(mask_datums, dtype=np.bool_)
+             mask[datum_uids] = True
+             mask_datums &= mask
+
+         if labels is not None:
+             if isinstance(labels, list):
+                 labels = np.array(
+                     [self.label_to_index[label] for label in labels]
+                 )
+             mask = np.zeros_like(mask_pairs, dtype=np.bool_)
+             mask[np.isin(self._ranked_pairs[:, 4].astype(int), labels)] = True
+             mask_pairs &= mask
+
+             mask = np.zeros_like(mask_labels, dtype=np.bool_)
+             mask[labels] = True
+             mask_labels &= mask
+
+         if label_keys is not None:
+             if isinstance(label_keys, list):
+                 label_keys = np.array(
+                     [self.label_key_to_index[key] for key in label_keys]
+                 )
+             # membership test, not elementwise closeness; np.isclose would
+             # fail to broadcast here whenever the two shapes differ
+             label_indices = np.where(
+                 np.isin(self._label_metadata[:, 2], label_keys)
+             )[0]
+             mask = np.zeros_like(mask_pairs, dtype=np.bool_)
+             mask[
+                 np.isin(self._ranked_pairs[:, 4].astype(int), label_indices)
+             ] = True
+             mask_pairs &= mask
+
+             mask = np.zeros_like(mask_labels, dtype=np.bool_)
+             mask[label_indices] = True
+             mask_labels &= mask
+
+         mask = mask_datums[:, np.newaxis] & mask_labels[np.newaxis, :]
+         label_metadata_per_datum = self._label_metadata_per_datum.copy()
+         label_metadata_per_datum[:, ~mask] = 0
+
+         label_metadata = np.zeros_like(self._label_metadata, dtype=np.int32)
+         label_metadata[:, :2] = np.transpose(
+             np.sum(
+                 label_metadata_per_datum,
+                 axis=1,
+             )
+         )
+         label_metadata[:, 2] = self._label_metadata[:, 2]
+
+         return Filter(
+             indices=np.where(mask_pairs)[0],
+             label_metadata=label_metadata,
+         )
+
+     def evaluate(
+         self,
+         iou_thresholds: list[float] = [0.5, 0.75, 0.9],
+         score_thresholds: list[float] = [0.5],
+         filter_: Filter | None = None,
+     ) -> dict[MetricType, list]:
+         """
+         Runs an evaluation over the cached data.
+
+         Parameters
+         ----------
+         iou_thresholds : list[float]
+             A list of IOU thresholds to compute over.
+         score_thresholds : list[float]
+             A list of score thresholds to compute over.
+         filter_ : Filter, optional
+             An optional `Filter` object, created by `create_filter`, that
+             restricts the cached data.
+
+         Returns
+         -------
+         dict[MetricType, list]
+             Lists of computed metrics, keyed by metric type.
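+
+         Examples
+         --------
+         A minimal sketch (thresholds are illustrative):
+
+         >>> metrics = evaluator.evaluate(
+         ...     iou_thresholds=[0.5, 0.75],
+         ...     score_thresholds=[0.5],
+         ... )
+         >>> ap_metrics = metrics[MetricType.AP]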
+         """
+
+         data = self._ranked_pairs
+         label_metadata = self._label_metadata
+         if filter_ is not None:
+             data = data[filter_.indices]
+             label_metadata = filter_.label_metadata
+
+         (
+             (
+                 average_precision,
+                 mean_average_precision,
+                 average_precision_average_over_ious,
+                 mean_average_precision_average_over_ious,
+             ),
+             (
+                 average_recall,
+                 mean_average_recall,
+                 average_recall_averaged_over_scores,
+                 mean_average_recall_averaged_over_scores,
+             ),
+             precision_recall,
+             pr_curves,
+         ) = compute_metrics(
+             data=data,
+             label_counts=label_metadata,
+             iou_thresholds=np.array(iou_thresholds),
+             score_thresholds=np.array(score_thresholds),
+         )
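+         # Shapes inferred from the indexing below (not asserted upstream):
+         # the AP arrays are (n_ious, n_labels), the AR arrays are
+         # (n_scores, n_labels), the mean variants index by label key instead
+         # of label, and precision_recall is (n_ious, n_scores, n_labels, 7).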
+
+         metrics = defaultdict(list)
+
+         metrics[MetricType.AP] = [
+             AP(
+                 value=average_precision[iou_idx][label_idx],
+                 iou=iou_thresholds[iou_idx],
+                 label=self.index_to_label[label_idx],
+             )
+             for iou_idx in range(average_precision.shape[0])
+             for label_idx in range(average_precision.shape[1])
+             if int(label_metadata[label_idx][0]) > 0
+         ]
+
+         metrics[MetricType.mAP] = [
+             mAP(
+                 value=mean_average_precision[iou_idx][label_key_idx],
+                 iou=iou_thresholds[iou_idx],
+                 label_key=self.index_to_label_key[label_key_idx],
+             )
+             for iou_idx in range(mean_average_precision.shape[0])
+             for label_key_idx in range(mean_average_precision.shape[1])
+         ]
+
+         metrics[MetricType.APAveragedOverIOUs] = [
+             APAveragedOverIOUs(
+                 value=average_precision_average_over_ious[label_idx],
+                 ious=iou_thresholds,
+                 label=self.index_to_label[label_idx],
+             )
+             for label_idx in range(self.n_labels)
+             if int(label_metadata[label_idx][0]) > 0
+         ]
+
+         metrics[MetricType.mAPAveragedOverIOUs] = [
+             mAPAveragedOverIOUs(
+                 value=mean_average_precision_average_over_ious[label_key_idx],
+                 ious=iou_thresholds,
+                 label_key=self.index_to_label_key[label_key_idx],
+             )
+             for label_key_idx in range(
+                 mean_average_precision_average_over_ious.shape[0]
+             )
+         ]
+
+         metrics[MetricType.AR] = [
+             AR(
+                 value=average_recall[score_idx][label_idx],
+                 ious=iou_thresholds,
+                 score=score_thresholds[score_idx],
+                 label=self.index_to_label[label_idx],
+             )
+             for score_idx in range(average_recall.shape[0])
+             for label_idx in range(average_recall.shape[1])
+             if int(label_metadata[label_idx][0]) > 0
+         ]
+
+         metrics[MetricType.mAR] = [
+             mAR(
+                 value=mean_average_recall[score_idx][label_key_idx],
+                 ious=iou_thresholds,
+                 score=score_thresholds[score_idx],
+                 label_key=self.index_to_label_key[label_key_idx],
+             )
+             for score_idx in range(mean_average_recall.shape[0])
+             for label_key_idx in range(mean_average_recall.shape[1])
+         ]
+
+         metrics[MetricType.ARAveragedOverScores] = [
+             ARAveragedOverScores(
+                 value=average_recall_averaged_over_scores[label_idx],
+                 scores=score_thresholds,
+                 ious=iou_thresholds,
+                 label=self.index_to_label[label_idx],
+             )
+             for label_idx in range(self.n_labels)
+             if int(label_metadata[label_idx][0]) > 0
+         ]
+
+         metrics[MetricType.mARAveragedOverScores] = [
+             mARAveragedOverScores(
+                 value=mean_average_recall_averaged_over_scores[label_key_idx],
+                 scores=score_thresholds,
+                 ious=iou_thresholds,
+                 label_key=self.index_to_label_key[label_key_idx],
+             )
+             for label_key_idx in range(
+                 mean_average_recall_averaged_over_scores.shape[0]
+             )
+         ]
+
+         metrics[MetricType.PrecisionRecallCurve] = [
+             PrecisionRecallCurve(
+                 precision=list(pr_curves[iou_idx][label_idx]),
+                 iou=iou_threshold,
+                 label=label,
+             )
+             for iou_idx, iou_threshold in enumerate(iou_thresholds)
+             for label_idx, label in self.index_to_label.items()
+             if int(label_metadata[label_idx][0]) > 0
+         ]
+
+         for iou_idx, iou_threshold in enumerate(iou_thresholds):
+             for score_idx, score_threshold in enumerate(score_thresholds):
+                 for label_idx, label in self.index_to_label.items():
+                     row = precision_recall[iou_idx][score_idx][label_idx]
+                     kwargs = {
+                         "label": label,
+                         "iou": iou_threshold,
+                         "score": score_threshold,
+                     }
+                     metrics[MetricType.Counts].append(
+                         Counts(
+                             tp=int(row[0]),
+                             fp=int(row[1]),
+                             fn=int(row[2]),
+                             **kwargs,
+                         )
+                     )
+                     metrics[MetricType.Precision].append(
+                         Precision(
+                             value=row[3],
+                             **kwargs,
+                         )
+                     )
+                     metrics[MetricType.Recall].append(
+                         Recall(
+                             value=row[4],
+                             **kwargs,
+                         )
+                     )
+                     metrics[MetricType.F1].append(
+                         F1(
+                             value=row[5],
+                             **kwargs,
+                         )
+                     )
+                     metrics[MetricType.Accuracy].append(
+                         Accuracy(
+                             value=row[6],
+                             **kwargs,
+                         )
+                     )
+
+         return metrics
+
+     def compute_detailed_pr_curve(
+         self,
+         iou_thresholds: list[float] = [0.5],
+         score_thresholds: list[float] = [
+             score / 10.0 for score in range(1, 11)
+         ],
+         n_samples: int = 0,
+     ) -> list[DetailedPrecisionRecallCurve]:
+         """
+         Computes detailed precision-recall curves, attaching up to
+         `n_samples` example datum uids to each TP/FP/FN category.
+         """
+         if self._detailed_pairs.size == 0:
+             return list()
+
+         metrics = compute_detailed_pr_curve(
+             self._detailed_pairs,
+             label_counts=self._label_metadata,
+             iou_thresholds=np.array(iou_thresholds),
+             score_thresholds=np.array(score_thresholds),
+             n_samples=n_samples,
+         )
+
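+         # Each (iou, score, label) row of `metrics` packs, per category
+         # (TP, FP misclassification, FP hallucination, FN misclassification,
+         # FN missing prediction), one count followed by n_samples example
+         # datum indices; -1 marks an empty example slot (layout inferred
+         # from the offsets below).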
+         tp_idx = 0
+         fp_misclf_idx = tp_idx + n_samples + 1
+         fp_halluc_idx = fp_misclf_idx + n_samples + 1
+         fn_misclf_idx = fp_halluc_idx + n_samples + 1
+         fn_misprd_idx = fn_misclf_idx + n_samples + 1
+
+         results = list()
+         # labels sit on axis 2 of the metrics array; the original iterated
+         # label_idx over axis 0 (IOUs), which mixed up the curves
+         n_ious, n_scores, n_labels, _ = metrics.shape
+         for label_idx in range(n_labels):
+             for iou_idx in range(n_ious):
+                 curve = DetailedPrecisionRecallCurve(
+                     iou=iou_thresholds[iou_idx],
+                     value=list(),
+                     label=self.index_to_label[label_idx],
+                 )
+                 for score_idx in range(n_scores):
+                     row = metrics[iou_idx][score_idx][label_idx]
+
+                     def _examples(start: int, stop: int | None) -> list[str]:
+                         # resolve packed datum indices to uids, skipping
+                         # empty (-1) example slots
+                         return [
+                             self.index_to_uid[int(datum_idx)]
+                             for datum_idx in row[start:stop]
+                             if int(datum_idx) >= 0
+                         ]
+
+                     curve.value.append(
+                         DetailedPrecisionRecallPoint(
+                             score=score_thresholds[score_idx],
+                             tp=row[tp_idx],
+                             tp_examples=_examples(tp_idx + 1, fp_misclf_idx),
+                             fp_misclassification=row[fp_misclf_idx],
+                             fp_misclassification_examples=_examples(
+                                 fp_misclf_idx + 1, fp_halluc_idx
+                             ),
+                             fp_hallucination=row[fp_halluc_idx],
+                             fp_hallucination_examples=_examples(
+                                 fp_halluc_idx + 1, fn_misclf_idx
+                             ),
+                             fn_misclassification=row[fn_misclf_idx],
+                             fn_misclassification_examples=_examples(
+                                 fn_misclf_idx + 1, fn_misprd_idx
+                             ),
+                             fn_missing_prediction=row[fn_misprd_idx],
+                             fn_missing_prediction_examples=_examples(
+                                 fn_misprd_idx + 1, None
+                             ),
+                         )
+                     )
+                 results.append(curve)
+         return results
+
+
+ class DataLoader:
+     def __init__(self):
+         self._evaluator = Evaluator()
+         self.pairs = list()
+         self.groundtruth_count = defaultdict(lambda: defaultdict(int))
+         self.prediction_count = defaultdict(lambda: defaultdict(int))
+
+     def _add_datum(self, uid: str) -> int:
+         if uid not in self._evaluator.uid_to_index:
+             index = len(self._evaluator.uid_to_index)
+             self._evaluator.uid_to_index[uid] = index
+             self._evaluator.index_to_uid[index] = uid
+         return self._evaluator.uid_to_index[uid]
+
+     def _add_label(self, label: tuple[str, str]) -> tuple[int, int]:
+         if label not in self._evaluator.label_to_index:
+             label_id = len(self._evaluator.index_to_label)
+             self._evaluator.label_to_index[label] = label_id
+             self._evaluator.index_to_label[label_id] = label
+
+             # update label key index
+             if label[0] not in self._evaluator.label_key_to_index:
+                 label_key_id = len(self._evaluator.index_to_label_key)
+                 self._evaluator.label_key_to_index[label[0]] = label_key_id
+                 self._evaluator.index_to_label_key[label_key_id] = label[0]
+
+             self._evaluator.label_index_to_label_key_index[
+                 label_id
+             ] = self._evaluator.label_key_to_index[label[0]]
+
+         return (
+             self._evaluator.label_to_index[label],
+             self._evaluator.label_key_to_index[label[0]],
+         )
+
+     def add_data(
+         self,
+         detections: list[Detection],
+         show_progress: bool = False,
+     ):
+         """
+         Adds detections to the cache.
+         """
+         disable_tqdm = not show_progress
+         for detection in tqdm(detections, disable=disable_tqdm):
+
+             # update metadata
+             self._evaluator.n_datums += 1
+             self._evaluator.n_groundtruths += len(detection.groundtruths)
+             self._evaluator.n_predictions += len(detection.predictions)
+
+             # update datum uid index
+             uid_index = self._add_datum(uid=detection.uid)
+
+             # cache labels and annotations
+             keyed_groundtruths = defaultdict(list)
+             keyed_predictions = defaultdict(list)
+             for gidx, gann in enumerate(detection.groundtruths):
+                 for glabel in gann.labels:
+                     label_idx, label_key_idx = self._add_label(glabel)
+                     self.groundtruth_count[label_idx][uid_index] += 1
+                     keyed_groundtruths[label_key_idx].append(
+                         (
+                             gidx,
+                             label_idx,
+                             gann.extrema,
+                         )
+                     )
+             for pidx, pann in enumerate(detection.predictions):
+                 for plabel, pscore in zip(pann.labels, pann.scores):
+                     label_idx, label_key_idx = self._add_label(plabel)
+                     self.prediction_count[label_idx][uid_index] += 1
+                     keyed_predictions[label_key_idx].append(
+                         (
+                             pidx,
+                             label_idx,
+                             pscore,
+                             pann.extrema,
+                         )
+                     )
+
+             gt_keys = set(keyed_groundtruths.keys())
+             pd_keys = set(keyed_predictions.keys())
+             joint_keys = gt_keys.intersection(pd_keys)
+             gt_unique_keys = gt_keys - pd_keys
+             pd_unique_keys = pd_keys - gt_keys
+
+             pairs = list()
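+             # pair rows are [datum_idx, gt_idx, pd_idx, iou, gt_label,
+             # pd_label, score]; -1.0 marks a missing ground truth or
+             # prediction (layout inferred from the columns built below)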
+             for key in joint_keys:
+                 n_gts = len(keyed_groundtruths[key])
+                 boxes = np.array(
+                     [
+                         np.array([*gextrema, *pextrema])
+                         for _, _, _, pextrema in keyed_predictions[key]
+                         for _, _, gextrema in keyed_groundtruths[key]
+                     ]
+                 )
+                 ious = compute_iou(boxes)
+                 # index `ious` by position within the keyed lists (the order
+                 # `boxes` was built in); the annotation indices gidx/pidx can
+                 # diverge from these positions when an annotation carries
+                 # multiple labels
+                 pairs.extend(
+                     [
+                         np.array(
+                             [
+                                 float(uid_index),
+                                 float(gidx),
+                                 float(pidx),
+                                 ious[p_i * n_gts + g_i],
+                                 float(glabel),
+                                 float(plabel),
+                                 float(score),
+                             ]
+                         )
+                         for p_i, (pidx, plabel, score, _) in enumerate(
+                             keyed_predictions[key]
+                         )
+                         for g_i, (gidx, glabel, _) in enumerate(
+                             keyed_groundtruths[key]
+                         )
+                     ]
+                 )
+             for key in gt_unique_keys:
+                 pairs.extend(
+                     [
+                         np.array(
+                             [
+                                 float(uid_index),
+                                 float(gidx),
+                                 -1.0,
+                                 0.0,
+                                 float(glabel),
+                                 -1.0,
+                                 -1.0,
+                             ]
+                         )
+                         for gidx, glabel, _ in keyed_groundtruths[key]
+                     ]
+                 )
+             for key in pd_unique_keys:
+                 pairs.extend(
+                     [
+                         np.array(
+                             [
+                                 float(uid_index),
+                                 -1.0,
+                                 float(pidx),
+                                 0.0,
+                                 -1.0,
+                                 float(plabel),
+                                 float(score),
+                             ]
+                         )
+                         for pidx, plabel, score, _ in keyed_predictions[key]
+                     ]
+                 )
+
+             self.pairs.append(np.array(pairs))
+
+     def add_data_from_valor_dict(
+         self,
+         detections: list[tuple[dict, dict]],
+         show_progress: bool = False,
+     ):
+         """
+         Adds Valor-format (groundtruth, prediction) dictionary pairs to the
+         cache.
+         """
+         def _get_bbox_extrema(
+             data: list[list[list[float]]],
+         ) -> tuple[float, float, float, float]:
+             # returns (xmin, xmax, ymin, ymax) over all points in all shapes
+             x = [point[0] for shape in data for point in shape]
+             y = [point[1] for shape in data for point in shape]
+             return (min(x), max(x), min(y), max(y))
+
+         disable_tqdm = not show_progress
+         for groundtruth, prediction in tqdm(detections, disable=disable_tqdm):
+
+             # update metadata
+             self._evaluator.n_datums += 1
+             self._evaluator.n_groundtruths += len(groundtruth["annotations"])
+             self._evaluator.n_predictions += len(prediction["annotations"])
+
+             # update datum uid index
+             uid_index = self._add_datum(uid=groundtruth["datum"]["uid"])
+
+             # cache labels and annotations
+             keyed_groundtruths = defaultdict(list)
+             keyed_predictions = defaultdict(list)
+             for gidx, gann in enumerate(groundtruth["annotations"]):
+                 for valor_label in gann["labels"]:
+                     glabel = (valor_label["key"], valor_label["value"])
+                     label_idx, label_key_idx = self._add_label(glabel)
+                     self.groundtruth_count[label_idx][uid_index] += 1
+                     keyed_groundtruths[label_key_idx].append(
+                         (
+                             gidx,
+                             label_idx,
+                             _get_bbox_extrema(gann["bounding_box"]),
+                         )
+                     )
+             for pidx, pann in enumerate(prediction["annotations"]):
+                 for valor_label in pann["labels"]:
+                     plabel = (valor_label["key"], valor_label["value"])
+                     pscore = valor_label["score"]
+                     label_idx, label_key_idx = self._add_label(plabel)
+                     self.prediction_count[label_idx][uid_index] += 1
+                     keyed_predictions[label_key_idx].append(
+                         (
+                             pidx,
+                             label_idx,
+                             pscore,
+                             _get_bbox_extrema(pann["bounding_box"]),
+                         )
+                     )
+
+             gt_keys = set(keyed_groundtruths.keys())
+             pd_keys = set(keyed_predictions.keys())
+             joint_keys = gt_keys.intersection(pd_keys)
+             gt_unique_keys = gt_keys - pd_keys
+             pd_unique_keys = pd_keys - gt_keys
+
+             pairs = list()
+             # same pair layout and positional IOU indexing as in `add_data`
+             for key in joint_keys:
+                 n_gts = len(keyed_groundtruths[key])
+                 boxes = np.array(
+                     [
+                         np.array([*gextrema, *pextrema])
+                         for _, _, _, pextrema in keyed_predictions[key]
+                         for _, _, gextrema in keyed_groundtruths[key]
+                     ]
+                 )
+                 ious = compute_iou(boxes)
+                 pairs.extend(
+                     [
+                         np.array(
+                             [
+                                 float(uid_index),
+                                 float(gidx),
+                                 float(pidx),
+                                 ious[p_i * n_gts + g_i],
+                                 float(glabel),
+                                 float(plabel),
+                                 float(score),
+                             ]
+                         )
+                         for p_i, (pidx, plabel, score, _) in enumerate(
+                             keyed_predictions[key]
+                         )
+                         for g_i, (gidx, glabel, _) in enumerate(
+                             keyed_groundtruths[key]
+                         )
+                     ]
+                 )
+             for key in gt_unique_keys:
+                 pairs.extend(
+                     [
+                         np.array(
+                             [
+                                 float(uid_index),
+                                 float(gidx),
+                                 -1.0,
+                                 0.0,
+                                 float(glabel),
+                                 -1.0,
+                                 -1.0,
+                             ]
+                         )
+                         for gidx, glabel, _ in keyed_groundtruths[key]
+                     ]
+                 )
+             for key in pd_unique_keys:
+                 pairs.extend(
+                     [
+                         np.array(
+                             [
+                                 float(uid_index),
+                                 -1.0,
+                                 float(pidx),
+                                 0.0,
+                                 -1.0,
+                                 float(plabel),
+                                 float(score),
+                             ]
+                         )
+                         for pidx, plabel, score, _ in keyed_predictions[key]
+                     ]
+                 )
+
+             self.pairs.append(np.array(pairs))
+
+     def finalize(self) -> Evaluator:
+         """
+         Finalizes the cached data and returns the Evaluator.
+         """
+         self.pairs = [pair for pair in self.pairs if pair.size > 0]
+         if len(self.pairs) == 0:
+             raise ValueError("No data available to create evaluator.")
+
+         n_datums = self._evaluator.n_datums
+         n_labels = len(self._evaluator.index_to_label)
+
+         self._evaluator.n_labels = n_labels
+
+         self._evaluator._label_metadata_per_datum = np.zeros(
+             (2, n_datums, n_labels), dtype=np.int32
+         )
+         for datum_idx in range(n_datums):
+             for label_idx in range(n_labels):
+                 gt_count = (
+                     self.groundtruth_count[label_idx].get(datum_idx, 0)
+                     if label_idx in self.groundtruth_count
+                     else 0
+                 )
+                 pd_count = (
+                     self.prediction_count[label_idx].get(datum_idx, 0)
+                     if label_idx in self.prediction_count
+                     else 0
+                 )
+                 self._evaluator._label_metadata_per_datum[
+                     :, datum_idx, label_idx
+                 ] = np.array([gt_count, pd_count])
+
+         self._evaluator._label_metadata = np.array(
+             [
+                 [
+                     float(
+                         np.sum(
+                             self._evaluator._label_metadata_per_datum[
+                                 0, :, label_idx
+                             ]
+                         )
+                     ),
+                     float(
+                         np.sum(
+                             self._evaluator._label_metadata_per_datum[
+                                 1, :, label_idx
+                             ]
+                         )
+                     ),
+                     float(
+                         self._evaluator.label_index_to_label_key_index[
+                             label_idx
+                         ]
+                     ),
+                 ]
+                 for label_idx in range(n_labels)
+             ]
+         )
+
+         self._evaluator._detailed_pairs = np.concatenate(
+             self.pairs,
+             axis=0,
+         )
+
+         self._evaluator._ranked_pairs = compute_ranked_pairs(
+             self.pairs,
+             label_counts=self._evaluator._label_metadata,
+         )
+
+         return self._evaluator