valor-lite 0.33.5__py3-none-any.whl → 0.33.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

@@ -0,0 +1,842 @@
from collections import defaultdict
from dataclasses import dataclass

import numpy as np
from numpy.typing import NDArray
from tqdm import tqdm
from valor_lite.classification.annotation import Classification
from valor_lite.classification.computation import (
    compute_confusion_matrix,
    compute_metrics,
)
from valor_lite.classification.metric import (
    F1,
    ROCAUC,
    Accuracy,
    ConfusionMatrix,
    Counts,
    MetricType,
    Precision,
    Recall,
    mROCAUC,
)

"""
Usage
-----

manager = DataLoader()
manager.add_data(classifications)
evaluator = manager.finalize()

metrics = evaluator.evaluate()

f1_metrics = metrics[MetricType.F1]
accuracy_metrics = metrics[MetricType.Accuracy]

filter_ = evaluator.create_filter(datum_uids=["uid1", "uid2"])
filtered_metrics = evaluator.evaluate(filter_=filter_)
"""


@dataclass
class Filter:
    indices: NDArray[np.int32]
    label_metadata: NDArray[np.int32]
    n_datums: int


class Evaluator:
    """
    Classification Evaluator
    """

    def __init__(self):

        # metadata
        self.n_datums = 0
        self.n_groundtruths = 0
        self.n_predictions = 0
        self.n_labels = 0

        # datum reference
        self.uid_to_index: dict[str, int] = dict()
        self.index_to_uid: dict[int, str] = dict()

        # label reference
        self.label_to_index: dict[tuple[str, str], int] = dict()
        self.index_to_label: dict[int, tuple[str, str]] = dict()

        # label key reference
        self.index_to_label_key: dict[int, str] = dict()
        self.label_key_to_index: dict[str, int] = dict()
        self.label_index_to_label_key_index: dict[int, int] = dict()

        # computation caches
        self._detailed_pairs = np.array([])
        self._label_metadata = np.array([], dtype=np.int32)
        self._label_metadata_per_datum = np.array([], dtype=np.int32)

    @property
    def ignored_prediction_labels(self) -> list[tuple[str, str]]:
        """
        Prediction labels that are not present in the ground truth set.
        """
        glabels = set(np.where(self._label_metadata[:, 0] > 0)[0])
        plabels = set(np.where(self._label_metadata[:, 1] > 0)[0])
        return [
            self.index_to_label[label_id] for label_id in (plabels - glabels)
        ]

    @property
    def missing_prediction_labels(self) -> list[tuple[str, str]]:
        """
        Ground truth labels that are not present in the prediction set.
        """
        glabels = set(np.where(self._label_metadata[:, 0] > 0)[0])
        plabels = set(np.where(self._label_metadata[:, 1] > 0)[0])
        return [
            self.index_to_label[label_id] for label_id in (glabels - plabels)
        ]

    @property
    def metadata(self) -> dict:
        """
        Evaluation metadata.
        """
        return {
            "n_datums": self.n_datums,
            "n_groundtruths": self.n_groundtruths,
            "n_predictions": self.n_predictions,
            "n_labels": self.n_labels,
            "ignored_prediction_labels": self.ignored_prediction_labels,
            "missing_prediction_labels": self.missing_prediction_labels,
        }

    def create_filter(
        self,
        datum_uids: list[str] | NDArray[np.int32] | None = None,
        labels: list[tuple[str, str]] | NDArray[np.int32] | None = None,
        label_keys: list[str] | NDArray[np.int32] | None = None,
    ) -> Filter:
        """
        Creates a filter object that can be passed to an evaluation.

        Parameters
        ----------
        datum_uids : list[str] | NDArray[np.int32], optional
            An optional list of string uids or a numpy array of uid indices.
        labels : list[tuple[str, str]] | NDArray[np.int32], optional
            An optional list of labels or a numpy array of label indices.
        label_keys : list[str] | NDArray[np.int32], optional
            An optional list of label keys or a numpy array of label key indices.

        Returns
        -------
        Filter
            A filter object that can be passed to the `evaluate` method.
        """
        n_rows = self._detailed_pairs.shape[0]

        n_datums = self._label_metadata_per_datum.shape[1]
        n_labels = self._label_metadata_per_datum.shape[2]

        mask_pairs = np.ones((n_rows, 1), dtype=np.bool_)
        mask_datums = np.ones(n_datums, dtype=np.bool_)
        mask_labels = np.ones(n_labels, dtype=np.bool_)

        if datum_uids is not None:
            if isinstance(datum_uids, list):
                datum_uids = np.array(
                    [self.uid_to_index[uid] for uid in datum_uids],
                    dtype=np.int32,
                )
            mask = np.zeros_like(mask_pairs, dtype=np.bool_)
            mask[
                np.isin(self._detailed_pairs[:, 0].astype(int), datum_uids)
            ] = True
            mask_pairs &= mask

            mask = np.zeros_like(mask_datums, dtype=np.bool_)
            mask[datum_uids] = True
            mask_datums &= mask

        if labels is not None:
            if isinstance(labels, list):
                labels = np.array(
                    [self.label_to_index[label] for label in labels]
                )
            mask = np.zeros_like(mask_pairs, dtype=np.bool_)
            mask[
                np.isin(self._detailed_pairs[:, 1].astype(int), labels)
            ] = True
            mask_pairs &= mask

            mask = np.zeros_like(mask_labels, dtype=np.bool_)
            mask[labels] = True
            mask_labels &= mask

        if label_keys is not None:
            if isinstance(label_keys, list):
                label_keys = np.array(
                    [self.label_key_to_index[key] for key in label_keys]
                )
            label_indices = np.where(
                np.isclose(self._label_metadata[:, 2], label_keys)
            )[0]
            mask = np.zeros_like(mask_pairs, dtype=np.bool_)
            mask[
                np.isin(self._detailed_pairs[:, 1].astype(int), label_indices)
            ] = True
            mask_pairs &= mask

            mask = np.zeros_like(mask_labels, dtype=np.bool_)
            mask[label_indices] = True
            mask_labels &= mask

        mask = mask_datums[:, np.newaxis] & mask_labels[np.newaxis, :]
        label_metadata_per_datum = self._label_metadata_per_datum.copy()
        label_metadata_per_datum[:, ~mask] = 0

        label_metadata = np.zeros_like(self._label_metadata, dtype=np.int32)
        label_metadata[:, :2] = np.transpose(
            np.sum(
                label_metadata_per_datum,
                axis=1,
            )
        )
        label_metadata[:, 2] = self._label_metadata[:, 2]
        n_datums = int(np.sum(label_metadata[:, 0]))

        return Filter(
            indices=np.where(mask_pairs)[0],
            label_metadata=label_metadata,
            n_datums=n_datums,
        )
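
    # Usage sketch for `create_filter`: the returned `Filter` carries pair indices
    # and re-aggregated label metadata, and is consumed by `evaluate(filter_=...)`.
    # The datum uids and the ("animal", ...) labels are hypothetical placeholders.
    #
    #     filter_ = evaluator.create_filter(
    #         datum_uids=["uid1", "uid2"],
    #         labels=[("animal", "dog"), ("animal", "cat")],
    #     )
    #     filtered_metrics = evaluator.evaluate(filter_=filter_)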

    def evaluate(
        self,
        metrics_to_return: list[MetricType] = MetricType.base(),
        score_thresholds: list[float] = [0.0],
        hardmax: bool = True,
        number_of_examples: int = 0,
        filter_: Filter | None = None,
        as_dict: bool = False,
    ) -> dict[MetricType, list]:
        """
        Performs an evaluation and returns metrics.

        Parameters
        ----------
        metrics_to_return : list[MetricType]
            A list of metrics to return in the results.
        score_thresholds : list[float]
            A list of score thresholds to compute metrics over.
        hardmax : bool
            Toggles whether a hardmax is applied to predictions.
        number_of_examples : int, default=0
            Maximum number of annotation examples to return in ConfusionMatrix.
        filter_ : Filter, optional
            An optional filter object.
        as_dict : bool, default=False
            Toggles whether the metrics are returned as dictionaries.

        Returns
        -------
        dict[MetricType, list]
            A dictionary mapping MetricType enumerations to lists of computed metrics.
        """

        # apply filters
        data = self._detailed_pairs
        label_metadata = self._label_metadata
        n_datums = self.n_datums
        if filter_ is not None:
            data = data[filter_.indices]
            label_metadata = filter_.label_metadata
            n_datums = filter_.n_datums

        (
            counts,
            precision,
            recall,
            accuracy,
            f1_score,
            rocauc,
            mean_rocauc,
        ) = compute_metrics(
            data=data,
            label_metadata=label_metadata,
            score_thresholds=np.array(score_thresholds),
            hardmax=hardmax,
            n_datums=n_datums,
        )

        metrics = defaultdict(list)

        metrics[MetricType.ROCAUC] = [
            ROCAUC(
                value=rocauc[label_idx],
                label=self.index_to_label[label_idx],
            )
            for label_idx in range(label_metadata.shape[0])
            if label_metadata[label_idx, 0] > 0
        ]

        metrics[MetricType.mROCAUC] = [
            mROCAUC(
                value=mean_rocauc[label_key_idx],
                label_key=self.index_to_label_key[label_key_idx],
            )
            for label_key_idx in range(len(self.label_key_to_index))
        ]

        for label_idx, label in self.index_to_label.items():

            kwargs = {
                "label": label,
                "score_thresholds": score_thresholds,
                "hardmax": hardmax,
            }
            row = counts[:, label_idx]
            metrics[MetricType.Counts].append(
                Counts(
                    tp=row[:, 0].tolist(),
                    fp=row[:, 1].tolist(),
                    fn=row[:, 2].tolist(),
                    tn=row[:, 3].tolist(),
                    **kwargs,
                )
            )

            # if no ground truths exist for a label, skip it.
            if label_metadata[label_idx, 0] == 0:
                continue

            metrics[MetricType.Precision].append(
                Precision(
                    value=precision[:, label_idx].tolist(),
                    **kwargs,
                )
            )
            metrics[MetricType.Recall].append(
                Recall(
                    value=recall[:, label_idx].tolist(),
                    **kwargs,
                )
            )
            metrics[MetricType.Accuracy].append(
                Accuracy(
                    value=accuracy[:, label_idx].tolist(),
                    **kwargs,
                )
            )
            metrics[MetricType.F1].append(
                F1(
                    value=f1_score[:, label_idx].tolist(),
                    **kwargs,
                )
            )

        if MetricType.ConfusionMatrix in metrics_to_return:
            metrics[
                MetricType.ConfusionMatrix
            ] = self._compute_confusion_matrix(
                data=data,
                label_metadata=label_metadata,
                score_thresholds=score_thresholds,
                hardmax=hardmax,
                number_of_examples=number_of_examples,
            )

        for metric in set(metrics.keys()):
            if metric not in metrics_to_return:
                del metrics[metric]

        if as_dict:
            return {
                mtype: [metric.to_dict() for metric in mvalues]
                for mtype, mvalues in metrics.items()
            }

        return metrics
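
    # Usage sketch for `evaluate`: request a subset of metric types over several
    # score thresholds and return plain dictionaries. `evaluator` is assumed to be
    # the object produced by `DataLoader.finalize()`.
    #
    #     results = evaluator.evaluate(
    #         metrics_to_return=[MetricType.Precision, MetricType.Recall],
    #         score_thresholds=[0.25, 0.5, 0.75],
    #         as_dict=True,
    #     )
    #     precision_dicts = results[MetricType.Precision]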

    def _unpack_confusion_matrix(
        self,
        confusion_matrix: NDArray[np.floating],
        label_key_idx: int,
        number_of_labels: int,
        number_of_examples: int,
    ) -> dict[
        str,
        dict[
            str,
            dict[
                str,
                int
                | list[
                    dict[
                        str,
                        str | float,
                    ]
                ],
            ],
        ],
    ]:
        """
        Unpacks a numpy array of confusion matrix counts and examples.
        """

        datum_idx = lambda gt_label_idx, pd_label_idx, example_idx: int(  # noqa: E731 - lambda fn
            confusion_matrix[
                gt_label_idx,
                pd_label_idx,
                example_idx * 2 + 1,
            ]
        )

        score_idx = lambda gt_label_idx, pd_label_idx, example_idx: float(  # noqa: E731 - lambda fn
            confusion_matrix[
                gt_label_idx,
                pd_label_idx,
                example_idx * 2 + 2,
            ]
        )

        return {
            self.index_to_label[gt_label_idx][1]: {
                self.index_to_label[pd_label_idx][1]: {
                    "count": max(
                        int(confusion_matrix[gt_label_idx, pd_label_idx, 0]),
                        0,
                    ),
                    "examples": [
                        {
                            "datum": self.index_to_uid[
                                datum_idx(
                                    gt_label_idx, pd_label_idx, example_idx
                                )
                            ],
                            "score": score_idx(
                                gt_label_idx, pd_label_idx, example_idx
                            ),
                        }
                        for example_idx in range(number_of_examples)
                        if datum_idx(gt_label_idx, pd_label_idx, example_idx)
                        >= 0
                    ],
                }
                for pd_label_idx in range(number_of_labels)
                if (
                    self.label_index_to_label_key_index[pd_label_idx]
                    == label_key_idx
                )
            }
            for gt_label_idx in range(number_of_labels)
            if (
                self.label_index_to_label_key_index[gt_label_idx]
                == label_key_idx
            )
        }

    def _unpack_missing_predictions(
        self,
        missing_predictions: NDArray[np.int32],
        label_key_idx: int,
        number_of_labels: int,
        number_of_examples: int,
    ) -> dict[str, dict[str, int | list[dict[str, str]]]]:
        """
        Unpacks a numpy array of missing prediction counts and examples.
        """

        datum_idx = (
            lambda gt_label_idx, example_idx: int(  # noqa: E731 - lambda fn
                missing_predictions[
                    gt_label_idx,
                    example_idx + 1,
                ]
            )
        )

        return {
            self.index_to_label[gt_label_idx][1]: {
                "count": max(
                    int(missing_predictions[gt_label_idx, 0]),
                    0,
                ),
                "examples": [
                    {
                        "datum": self.index_to_uid[
                            datum_idx(gt_label_idx, example_idx)
                        ]
                    }
                    for example_idx in range(number_of_examples)
                    if datum_idx(gt_label_idx, example_idx) >= 0
                ],
            }
            for gt_label_idx in range(number_of_labels)
            if (
                self.label_index_to_label_key_index[gt_label_idx]
                == label_key_idx
            )
        }

    def _compute_confusion_matrix(
        self,
        data: NDArray[np.floating],
        label_metadata: NDArray[np.int32],
        score_thresholds: list[float],
        hardmax: bool,
        number_of_examples: int,
    ) -> list[ConfusionMatrix]:
        """
        Computes a detailed confusion matrix.

        Parameters
        ----------
        data : NDArray[np.floating]
            A data array containing classification pairs.
        label_metadata : NDArray[np.int32]
            An integer array containing label metadata.
        score_thresholds : list[float]
            A list of score thresholds to compute metrics over.
        hardmax : bool
            Toggles whether a hardmax is applied to predictions.
        number_of_examples : int
            The number of examples to return per count.

        Returns
        -------
        list[ConfusionMatrix]
            A list of ConfusionMatrix per label key.
        """

        if data.size == 0:
            return list()

        confusion_matrix, missing_predictions = compute_confusion_matrix(
            data=data,
            label_metadata=label_metadata,
            score_thresholds=np.array(score_thresholds),
            hardmax=hardmax,
            n_examples=number_of_examples,
        )

        n_scores, n_labels, _, _ = confusion_matrix.shape
        return [
            ConfusionMatrix(
                score_threshold=score_thresholds[score_idx],
                label_key=label_key,
                number_of_examples=number_of_examples,
                confusion_matrix=self._unpack_confusion_matrix(
                    confusion_matrix=confusion_matrix[score_idx, :, :, :],
                    label_key_idx=label_key_idx,
                    number_of_labels=n_labels,
                    number_of_examples=number_of_examples,
                ),
                missing_predictions=self._unpack_missing_predictions(
                    missing_predictions=missing_predictions[score_idx, :, :],
                    label_key_idx=label_key_idx,
                    number_of_labels=n_labels,
                    number_of_examples=number_of_examples,
                ),
            )
            for label_key_idx, label_key in self.index_to_label_key.items()
            for score_idx in range(n_scores)
        ]
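
    # Usage sketch for the confusion-matrix path: ConfusionMatrix metrics are only
    # computed when explicitly requested, and `number_of_examples` bounds how many
    # (datum uid, score) examples are attached to each cell.
    #
    #     cm_metrics = evaluator.evaluate(
    #         metrics_to_return=[MetricType.ConfusionMatrix],
    #         score_thresholds=[0.5],
    #         number_of_examples=3,
    #     )[MetricType.ConfusionMatrix]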


class DataLoader:
    """
    Classification DataLoader.
    """

    def __init__(self):
        self._evaluator = Evaluator()
        self.groundtruth_count = defaultdict(lambda: defaultdict(int))
        self.prediction_count = defaultdict(lambda: defaultdict(int))

    def _add_datum(self, uid: str) -> int:
        """
        Helper function for adding a datum to the cache.

        Parameters
        ----------
        uid : str
            The datum uid.

        Returns
        -------
        int
            The datum index.
        """
        if uid not in self._evaluator.uid_to_index:
            index = len(self._evaluator.uid_to_index)
            self._evaluator.uid_to_index[uid] = index
            self._evaluator.index_to_uid[index] = uid
        return self._evaluator.uid_to_index[uid]

    def _add_label(self, label: tuple[str, str]) -> tuple[int, int]:
        """
        Helper function for adding a label to the cache.

        Parameters
        ----------
        label : tuple[str, str]
            The label as a tuple in format (key, value).

        Returns
        -------
        int
            Label index.
        int
            Label key index.
        """
        label_id = len(self._evaluator.index_to_label)
        label_key_id = len(self._evaluator.index_to_label_key)
        if label not in self._evaluator.label_to_index:
            self._evaluator.label_to_index[label] = label_id
            self._evaluator.index_to_label[label_id] = label

            # update label key index
            if label[0] not in self._evaluator.label_key_to_index:
                self._evaluator.label_key_to_index[label[0]] = label_key_id
                self._evaluator.index_to_label_key[label_key_id] = label[0]
                label_key_id += 1

            self._evaluator.label_index_to_label_key_index[
                label_id
            ] = self._evaluator.label_key_to_index[label[0]]
            label_id += 1

        return (
            self._evaluator.label_to_index[label],
            self._evaluator.label_key_to_index[label[0]],
        )
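
    # Indexing sketch: after `_add_label` has seen the labels below, the caches map
    # each (key, value) pair to a label index and each key to a label-key index.
    # The ("animal", ...) and ("color", ...) labels are hypothetical placeholders.
    #
    #     loader = DataLoader()
    #     loader._add_label(("animal", "dog"))   # -> (0, 0)
    #     loader._add_label(("animal", "cat"))   # -> (1, 0)
    #     loader._add_label(("color", "black"))  # -> (2, 1)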

    def _add_data(
        self,
        uid_index: int,
        keyed_groundtruths: dict[int, int],
        keyed_predictions: dict[int, list[tuple[int, float]]],
    ):
        gt_keys = set(keyed_groundtruths.keys())
        pd_keys = set(keyed_predictions.keys())
        joint_keys = gt_keys.intersection(pd_keys)

        gt_unique_keys = gt_keys - pd_keys
        pd_unique_keys = pd_keys - gt_keys
        if gt_unique_keys or pd_unique_keys:
            raise ValueError(
                "Label keys must match between ground truths and predictions."
            )

        pairs = list()
        for key in joint_keys:
            scores = np.array([score for _, score in keyed_predictions[key]])
            max_score_idx = np.argmax(scores)

            glabel = keyed_groundtruths[key]
            for idx, (plabel, score) in enumerate(keyed_predictions[key]):
                pairs.append(
                    (
                        float(uid_index),
                        float(glabel),
                        float(plabel),
                        float(score),
                        float(max_score_idx == idx),
                    )
                )

        if self._evaluator._detailed_pairs.size == 0:
            self._evaluator._detailed_pairs = np.array(pairs)
        else:
            self._evaluator._detailed_pairs = np.concatenate(
                [
                    self._evaluator._detailed_pairs,
                    np.array(pairs),
                ],
                axis=0,
            )
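
    # Layout note: each row appended to `_detailed_pairs` is
    #     [datum index, ground truth label index, prediction label index,
    #      prediction score, hardmax flag]
    # where the hardmax flag is 1.0 only for the highest-scoring prediction of a
    # label key within that datum.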

    def add_data(
        self,
        classifications: list[Classification],
        show_progress: bool = False,
    ):
        """
        Adds classifications to the cache.

        Parameters
        ----------
        classifications : list[Classification]
            A list of Classification objects.
        show_progress : bool, default=False
            Toggle for tqdm progress bar.
        """

        disable_tqdm = not show_progress
        for classification in tqdm(classifications, disable=disable_tqdm):

            # update metadata
            self._evaluator.n_datums += 1
            self._evaluator.n_groundtruths += len(classification.groundtruths)
            self._evaluator.n_predictions += len(classification.predictions)

            # update datum uid index
            uid_index = self._add_datum(uid=classification.uid)

            # cache labels and annotations
            keyed_groundtruths = defaultdict(int)
            keyed_predictions = defaultdict(list)
            for glabel in classification.groundtruths:
                label_idx, label_key_idx = self._add_label(glabel)
                self.groundtruth_count[label_idx][uid_index] += 1
                keyed_groundtruths[label_key_idx] = label_idx
            for idx, (plabel, pscore) in enumerate(
                zip(classification.predictions, classification.scores)
            ):
                label_idx, label_key_idx = self._add_label(plabel)
                self.prediction_count[label_idx][uid_index] += 1
                keyed_predictions[label_key_idx].append(
                    (
                        label_idx,
                        pscore,
                    )
                )

            self._add_data(
                uid_index=uid_index,
                keyed_groundtruths=keyed_groundtruths,
                keyed_predictions=keyed_predictions,
            )
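
    # Usage sketch for `add_data`, assuming the `Classification` type exposes the
    # `uid`, `groundtruths`, `predictions`, and `scores` fields accessed above; its
    # exact constructor lives in valor_lite.classification.annotation.
    #
    #     clf = Classification(
    #         uid="uid0",
    #         groundtruths=[("animal", "dog")],
    #         predictions=[("animal", "dog"), ("animal", "cat")],
    #         scores=[0.8, 0.2],
    #     )
    #     loader.add_data([clf], show_progress=True)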

    def add_data_from_valor_dict(
        self,
        classifications: list[tuple[dict, dict]],
        show_progress: bool = False,
    ):
        """
        Adds Valor-format classifications to the cache.

        Parameters
        ----------
        classifications : list[tuple[dict, dict]]
            A list of groundtruth, prediction pairs in Valor-format dictionaries.
        show_progress : bool, default=False
            Toggle for tqdm progress bar.
        """

        disable_tqdm = not show_progress
        for groundtruth, prediction in tqdm(
            classifications, disable=disable_tqdm
        ):

            # update metadata
            self._evaluator.n_datums += 1
            self._evaluator.n_groundtruths += len(groundtruth["annotations"])
            self._evaluator.n_predictions += len(prediction["annotations"])

            # update datum uid index
            uid_index = self._add_datum(uid=groundtruth["datum"]["uid"])

            # cache labels and annotations
            keyed_groundtruths = defaultdict(int)
            keyed_predictions = defaultdict(list)
            for gann in groundtruth["annotations"]:
                for valor_label in gann["labels"]:
                    glabel = (valor_label["key"], valor_label["value"])
                    label_idx, label_key_idx = self._add_label(glabel)
                    self.groundtruth_count[label_idx][uid_index] += 1
                    keyed_groundtruths[label_key_idx] = label_idx
            for pann in prediction["annotations"]:
                for valor_label in pann["labels"]:
                    plabel = (valor_label["key"], valor_label["value"])
                    pscore = valor_label["score"]
                    label_idx, label_key_idx = self._add_label(plabel)
                    self.prediction_count[label_idx][uid_index] += 1
                    keyed_predictions[label_key_idx].append(
                        (
                            label_idx,
                            pscore,
                        )
                    )

            self._add_data(
                uid_index=uid_index,
                keyed_groundtruths=keyed_groundtruths,
                keyed_predictions=keyed_predictions,
            )
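
    # Shape sketch for `add_data_from_valor_dict`, based on the keys read above;
    # the uid, label keys, values, and scores are hypothetical placeholders.
    #
    #     groundtruth = {
    #         "datum": {"uid": "uid0"},
    #         "annotations": [
    #             {"labels": [{"key": "animal", "value": "dog"}]},
    #         ],
    #     }
    #     prediction = {
    #         "annotations": [
    #             {"labels": [
    #                 {"key": "animal", "value": "dog", "score": 0.8},
    #                 {"key": "animal", "value": "cat", "score": 0.2},
    #             ]},
    #         ],
    #     }
    #     loader.add_data_from_valor_dict([(groundtruth, prediction)])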

    def finalize(self) -> Evaluator:
        """
        Performs data finalization and some preprocessing steps.

        Returns
        -------
        Evaluator
            A ready-to-use evaluator object.
        """

        if self._evaluator._detailed_pairs.size == 0:
            raise ValueError("No data available to create evaluator.")

        n_datums = self._evaluator.n_datums
        n_labels = len(self._evaluator.index_to_label)

        self._evaluator.n_labels = n_labels

        self._evaluator._label_metadata_per_datum = np.zeros(
            (2, n_datums, n_labels), dtype=np.int32
        )
        for datum_idx in range(n_datums):
            for label_idx in range(n_labels):
                gt_count = (
                    self.groundtruth_count[label_idx].get(datum_idx, 0)
                    if label_idx in self.groundtruth_count
                    else 0
                )
                pd_count = (
                    self.prediction_count[label_idx].get(datum_idx, 0)
                    if label_idx in self.prediction_count
                    else 0
                )
                self._evaluator._label_metadata_per_datum[
                    :, datum_idx, label_idx
                ] = np.array([gt_count, pd_count])

        self._evaluator._label_metadata = np.array(
            [
                [
                    np.sum(
                        self._evaluator._label_metadata_per_datum[
                            0, :, label_idx
                        ]
                    ),
                    np.sum(
                        self._evaluator._label_metadata_per_datum[
                            1, :, label_idx
                        ]
                    ),
                    self._evaluator.label_index_to_label_key_index[label_idx],
                ]
                for label_idx in range(n_labels)
            ],
            dtype=np.int32,
        )

        # sort pairs by descending score, then prediction label, then ground truth label
        indices = np.lexsort(
            (
                self._evaluator._detailed_pairs[:, 1],
                self._evaluator._detailed_pairs[:, 2],
                -self._evaluator._detailed_pairs[:, 3],
            )
        )
        self._evaluator._detailed_pairs = self._evaluator._detailed_pairs[
            indices
        ]

        return self._evaluator
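

# Layout note for `finalize`: `_label_metadata` has one row per label with columns
# [ground truth count, prediction count, label key index], and
# `_label_metadata_per_datum` holds the same ground truth / prediction counts with
# shape (2, n_datums, n_labels).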