valor-lite 0.33.5__py3-none-any.whl → 0.33.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,844 @@
+ from collections import defaultdict
+ from dataclasses import dataclass
+
+ import numpy as np
+ from numpy.typing import NDArray
+ from tqdm import tqdm
+ from valor_lite.classification.annotation import Classification
+ from valor_lite.classification.computation import (
+     compute_confusion_matrix,
+     compute_metrics,
+ )
+ from valor_lite.classification.metric import (
+     F1,
+     ROCAUC,
+     Accuracy,
+     ConfusionMatrix,
+     Counts,
+     MetricType,
+     Precision,
+     Recall,
+     mROCAUC,
+ )
+
+ """
+ Usage
+ -----
+
+ manager = DataLoader()
+ manager.add_data(classifications)
+ evaluator = manager.finalize()
+
+ metrics = evaluator.evaluate()
+
+ f1_metrics = metrics[MetricType.F1]
+ accuracy_metrics = metrics[MetricType.Accuracy]
+
+ filter_ = evaluator.create_filter(datum_uids=["uid1", "uid2"])
+ filtered_metrics = evaluator.evaluate(filter_=filter_)
+ """
+
+
+ @dataclass
+ class Filter:
+     """
+     Filter selection produced by `Evaluator.create_filter` and consumed by
+     `Evaluator.evaluate`.
+
+     Attributes
+     ----------
+     indices : NDArray[np.int32]
+         Row indices into the detailed pair array that survive the filter.
+     label_metadata : NDArray[np.int32]
+         Label metadata recomputed over the filtered datums and labels.
+     n_datums : int
+         The number of datums remaining after filtering.
+     """
+
+     indices: NDArray[np.int32]
+     label_metadata: NDArray[np.int32]
+     n_datums: int
+
+
+ class Evaluator:
+     """
+     Classification Evaluator
+     """
+
+     def __init__(self):
+
+         # metadata
+         self.n_datums = 0
+         self.n_groundtruths = 0
+         self.n_predictions = 0
+         self.n_labels = 0
+
+         # datum reference
+         self.uid_to_index: dict[str, int] = dict()
+         self.index_to_uid: dict[int, str] = dict()
+
+         # label reference
+         self.label_to_index: dict[tuple[str, str], int] = dict()
+         self.index_to_label: dict[int, tuple[str, str]] = dict()
+
+         # label key reference
+         self.index_to_label_key: dict[int, str] = dict()
+         self.label_key_to_index: dict[str, int] = dict()
+         self.label_index_to_label_key_index: dict[int, int] = dict()
+
+         # computation caches
+         self._detailed_pairs = np.array([])
+         self._label_metadata = np.array([], dtype=np.int32)
+         self._label_metadata_per_datum = np.array([], dtype=np.int32)
+
+     @property
+     def ignored_prediction_labels(self) -> list[tuple[str, str]]:
+         """
+         Prediction labels that are not present in the ground truth set.
+         """
+         glabels = set(np.where(self._label_metadata[:, 0] > 0)[0])
+         plabels = set(np.where(self._label_metadata[:, 1] > 0)[0])
+         return [
+             self.index_to_label[label_id] for label_id in (plabels - glabels)
+         ]
+
+     @property
+     def missing_prediction_labels(self) -> list[tuple[str, str]]:
+         """
+         Ground truth labels that are not present in the prediction set.
+         """
+         glabels = set(np.where(self._label_metadata[:, 0] > 0)[0])
+         plabels = set(np.where(self._label_metadata[:, 1] > 0)[0])
+         return [
+             self.index_to_label[label_id] for label_id in (glabels - plabels)
+         ]
+
+     @property
+     def metadata(self) -> dict:
+         """
+         Evaluation metadata.
+         """
+         return {
+             "n_datums": self.n_datums,
+             "n_groundtruths": self.n_groundtruths,
+             "n_predictions": self.n_predictions,
+             "n_labels": self.n_labels,
+             "ignored_prediction_labels": self.ignored_prediction_labels,
+             "missing_prediction_labels": self.missing_prediction_labels,
+         }
+
+     def create_filter(
+         self,
+         datum_uids: list[str] | NDArray[np.int32] | None = None,
+         labels: list[tuple[str, str]] | NDArray[np.int32] | None = None,
+         label_keys: list[str] | NDArray[np.int32] | None = None,
+     ) -> Filter:
+         """
+         Creates a filter that can be passed to an evaluation.
+
+         Parameters
+         ----------
+         datum_uids : list[str] | NDArray[np.int32], optional
+             An optional list of string uids or a numpy array of uid indices.
+         labels : list[tuple[str, str]] | NDArray[np.int32], optional
+             An optional list of labels or a numpy array of label indices.
+         label_keys : list[str] | NDArray[np.int32], optional
+             An optional list of label keys or a numpy array of label key indices.
+
+         Returns
+         -------
+         Filter
+             A filter object that can be passed to the `evaluate` method.
+         """
+         n_rows = self._detailed_pairs.shape[0]
+
+         n_datums = self._label_metadata_per_datum.shape[1]
+         n_labels = self._label_metadata_per_datum.shape[2]
+
+         mask_pairs = np.ones((n_rows, 1), dtype=np.bool_)
+         mask_datums = np.ones(n_datums, dtype=np.bool_)
+         mask_labels = np.ones(n_labels, dtype=np.bool_)
+
+         if datum_uids is not None:
+             if isinstance(datum_uids, list):
+                 datum_uids = np.array(
+                     [self.uid_to_index[uid] for uid in datum_uids],
+                     dtype=np.int32,
+                 )
+             mask = np.zeros_like(mask_pairs, dtype=np.bool_)
+             mask[
+                 np.isin(self._detailed_pairs[:, 0].astype(int), datum_uids)
+             ] = True
+             mask_pairs &= mask
+
+             mask = np.zeros_like(mask_datums, dtype=np.bool_)
+             mask[datum_uids] = True
+             mask_datums &= mask
+
+         if labels is not None:
+             if isinstance(labels, list):
+                 labels = np.array(
+                     [self.label_to_index[label] for label in labels]
+                 )
+             mask = np.zeros_like(mask_pairs, dtype=np.bool_)
+             mask[
+                 np.isin(self._detailed_pairs[:, 1].astype(int), labels)
+             ] = True
+             mask_pairs &= mask
+
+             mask = np.zeros_like(mask_labels, dtype=np.bool_)
+             mask[labels] = True
+             mask_labels &= mask
+
+         if label_keys is not None:
+             if isinstance(label_keys, list):
+                 label_keys = np.array(
+                     [self.label_key_to_index[key] for key in label_keys]
+                 )
+             # use isin so the comparison works for any number of label keys
+             label_indices = np.where(
+                 np.isin(self._label_metadata[:, 2], label_keys)
+             )[0]
+             mask = np.zeros_like(mask_pairs, dtype=np.bool_)
+             mask[
+                 np.isin(self._detailed_pairs[:, 1].astype(int), label_indices)
+             ] = True
+             mask_pairs &= mask
+
+             mask = np.zeros_like(mask_labels, dtype=np.bool_)
+             mask[label_indices] = True
+             mask_labels &= mask
+
+         mask = mask_datums[:, np.newaxis] & mask_labels[np.newaxis, :]
+         label_metadata_per_datum = self._label_metadata_per_datum.copy()
+         label_metadata_per_datum[:, ~mask] = 0
+
+         label_metadata = np.zeros_like(self._label_metadata, dtype=np.int32)
+         label_metadata[:, :2] = np.transpose(
+             np.sum(
+                 label_metadata_per_datum,
+                 axis=1,
+             )
+         )
+         label_metadata[:, 2] = self._label_metadata[:, 2]
+         n_datums = int(np.sum(label_metadata[:, 0]))
+
+         return Filter(
+             indices=np.where(mask_pairs)[0],
+             label_metadata=label_metadata,
+             n_datums=n_datums,
+         )
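For example, restricting metrics to a handful of datums and a single label key might look like the following sketch (the uids and the "class" key are illustrative placeholders, not values from the package):

filter_ = evaluator.create_filter(
    datum_uids=["uid1", "uid2"],  # or a numpy array of datum indices
    label_keys=["class"],         # optionally restrict to one label key
)
filtered_metrics = evaluator.evaluate(filter_=filter_)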
+
+     def evaluate(
+         self,
+         metrics_to_return: list[MetricType] = MetricType.base(),
+         score_thresholds: list[float] = [0.0],
+         hardmax: bool = True,
+         number_of_examples: int = 0,
+         filter_: Filter | None = None,
+         as_dict: bool = False,
+     ) -> dict[MetricType, list]:
+         """
+         Performs an evaluation and returns metrics.
+
+         Parameters
+         ----------
+         metrics_to_return : list[MetricType]
+             A list of metrics to return in the results.
+         score_thresholds : list[float]
+             A list of score thresholds to compute metrics over.
+         hardmax : bool
+             Toggles whether a hardmax is applied to predictions.
+         number_of_examples : int, default=0
+             Maximum number of annotation examples to return in ConfusionMatrix.
+         filter_ : Filter, optional
+             An optional filter object.
+         as_dict : bool, default=False
+             An option to return metrics as dictionaries.
+
+         Returns
+         -------
+         dict[MetricType, list]
+             A dictionary mapping MetricType enumerations to lists of computed metrics.
+         """
+
+         # apply filters
+         data = self._detailed_pairs
+         label_metadata = self._label_metadata
+         n_datums = self.n_datums
+         if filter_ is not None:
+             data = data[filter_.indices]
+             label_metadata = filter_.label_metadata
+             n_datums = filter_.n_datums
+
+         (
+             counts,
+             precision,
+             recall,
+             accuracy,
+             f1_score,
+             rocauc,
+             mean_rocauc,
+         ) = compute_metrics(
+             data=data,
+             label_metadata=label_metadata,
+             score_thresholds=np.array(score_thresholds),
+             hardmax=hardmax,
+             n_datums=n_datums,
+         )
+
+         metrics = defaultdict(list)
+
+         metrics[MetricType.ROCAUC] = [
+             ROCAUC(
+                 value=rocauc[label_idx],
+                 label=self.index_to_label[label_idx],
+             )
+             for label_idx in range(label_metadata.shape[0])
+             if label_metadata[label_idx, 0] > 0
+         ]
+
+         metrics[MetricType.mROCAUC] = [
+             mROCAUC(
+                 value=mean_rocauc[label_key_idx],
+                 label_key=self.index_to_label_key[label_key_idx],
+             )
+             for label_key_idx in range(len(self.label_key_to_index))
+         ]
+
+         for label_idx, label in self.index_to_label.items():
+
+             kwargs = {
+                 "label": label,
+                 "score_thresholds": score_thresholds,
+                 "hardmax": hardmax,
+             }
+             row = counts[:, label_idx]
+             metrics[MetricType.Counts].append(
+                 Counts(
+                     tp=row[:, 0].tolist(),
+                     fp=row[:, 1].tolist(),
+                     fn=row[:, 2].tolist(),
+                     tn=row[:, 3].tolist(),
+                     **kwargs,
+                 )
+             )
+
+             # if no ground truths exist for a label, skip it
+             if label_metadata[label_idx, 0] == 0:
+                 continue
+
+             metrics[MetricType.Precision].append(
+                 Precision(
+                     value=precision[:, label_idx].tolist(),
+                     **kwargs,
+                 )
+             )
+             metrics[MetricType.Recall].append(
+                 Recall(
+                     value=recall[:, label_idx].tolist(),
+                     **kwargs,
+                 )
+             )
+             metrics[MetricType.Accuracy].append(
+                 Accuracy(
+                     value=accuracy[:, label_idx].tolist(),
+                     **kwargs,
+                 )
+             )
+             metrics[MetricType.F1].append(
+                 F1(
+                     value=f1_score[:, label_idx].tolist(),
+                     **kwargs,
+                 )
+             )
+
+         if MetricType.ConfusionMatrix in metrics_to_return:
+             metrics[
+                 MetricType.ConfusionMatrix
+             ] = self._compute_confusion_matrix(
+                 data=data,
+                 label_metadata=label_metadata,
+                 score_thresholds=score_thresholds,
+                 hardmax=hardmax,
+                 number_of_examples=number_of_examples,
+             )
+
+         for metric in set(metrics.keys()):
+             if metric not in metrics_to_return:
+                 del metrics[metric]
+
+         if as_dict:
+             return {
+                 mtype: [metric.to_dict() for metric in mvalues]
+                 for mtype, mvalues in metrics.items()
+             }
+
+         return metrics
+
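A typical call, assuming the evaluator came from `DataLoader.finalize()` (the thresholds below are illustrative):

metrics = evaluator.evaluate(
    metrics_to_return=[MetricType.Precision, MetricType.Recall, MetricType.F1],
    score_thresholds=[0.25, 0.5, 0.75],
    hardmax=True,
    as_dict=True,  # each metric is serialized via its to_dict() method
)
for f1 in metrics[MetricType.F1]:
    print(f1)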
+     def _unpack_confusion_matrix(
+         self,
+         confusion_matrix: NDArray[np.floating],
+         label_key_idx: int,
+         number_of_labels: int,
+         number_of_examples: int,
+     ) -> dict[
+         str,
+         dict[
+             str,
+             dict[
+                 str,
+                 int
+                 | list[
+                     dict[
+                         str,
+                         str | float,
+                     ]
+                 ],
+             ],
+         ],
+     ]:
+         """
+         Unpacks a numpy array of confusion matrix counts and examples.
+         """
+
+         datum_idx = lambda gt_label_idx, pd_label_idx, example_idx: int(  # noqa: E731 - lambda fn
+             confusion_matrix[
+                 gt_label_idx,
+                 pd_label_idx,
+                 example_idx * 2 + 1,
+             ]
+         )
+
+         score_idx = lambda gt_label_idx, pd_label_idx, example_idx: float(  # noqa: E731 - lambda fn
+             confusion_matrix[
+                 gt_label_idx,
+                 pd_label_idx,
+                 example_idx * 2 + 2,
+             ]
+         )
+
+         return {
+             self.index_to_label[gt_label_idx][1]: {
+                 self.index_to_label[pd_label_idx][1]: {
+                     "count": max(
+                         int(confusion_matrix[gt_label_idx, pd_label_idx, 0]),
+                         0,
+                     ),
+                     "examples": [
+                         {
+                             "datum": self.index_to_uid[
+                                 datum_idx(
+                                     gt_label_idx, pd_label_idx, example_idx
+                                 )
+                             ],
+                             "score": score_idx(
+                                 gt_label_idx, pd_label_idx, example_idx
+                             ),
+                         }
+                         for example_idx in range(number_of_examples)
+                         if datum_idx(gt_label_idx, pd_label_idx, example_idx)
+                         >= 0
+                     ],
+                 }
+                 for pd_label_idx in range(number_of_labels)
+                 if (
+                     self.label_index_to_label_key_index[pd_label_idx]
+                     == label_key_idx
+                 )
+             }
+             for gt_label_idx in range(number_of_labels)
+             if (
+                 self.label_index_to_label_key_index[gt_label_idx]
+                 == label_key_idx
+             )
+         }
+
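For orientation, the dictionary built above is keyed by ground-truth label value, then prediction label value, with a count and a list of example datums per cell; roughly (all values purely illustrative):

{
    "dog": {
        "dog": {"count": 3, "examples": [{"datum": "uid1", "score": 0.92}]},
        "cat": {"count": 1, "examples": [{"datum": "uid4", "score": 0.55}]},
    },
    "cat": {
        "dog": {"count": 0, "examples": []},
        "cat": {"count": 2, "examples": [{"datum": "uid7", "score": 0.81}]},
    },
}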
+     def _unpack_missing_predictions(
+         self,
+         missing_predictions: NDArray[np.int32],
+         label_key_idx: int,
+         number_of_labels: int,
+         number_of_examples: int,
+     ) -> dict[str, dict[str, int | list[dict[str, str]]]]:
+         """
+         Unpacks a numpy array of missing prediction counts and examples.
+         """
+
+         datum_idx = (
+             lambda gt_label_idx, example_idx: int(  # noqa: E731 - lambda fn
+                 missing_predictions[
+                     gt_label_idx,
+                     example_idx + 1,
+                 ]
+             )
+         )
+
+         return {
+             self.index_to_label[gt_label_idx][1]: {
+                 "count": max(
+                     int(missing_predictions[gt_label_idx, 0]),
+                     0,
+                 ),
+                 "examples": [
+                     {
+                         "datum": self.index_to_uid[
+                             datum_idx(gt_label_idx, example_idx)
+                         ]
+                     }
+                     for example_idx in range(number_of_examples)
+                     if datum_idx(gt_label_idx, example_idx) >= 0
+                 ],
+             }
+             for gt_label_idx in range(number_of_labels)
+             if (
+                 self.label_index_to_label_key_index[gt_label_idx]
+                 == label_key_idx
+             )
+         }
+
+     def _compute_confusion_matrix(
+         self,
+         data: NDArray[np.floating],
+         label_metadata: NDArray[np.int32],
+         score_thresholds: list[float],
+         hardmax: bool,
+         number_of_examples: int,
+     ) -> list[ConfusionMatrix]:
+         """
+         Computes a detailed confusion matrix.
+
+         Parameters
+         ----------
+         data : NDArray[np.floating]
+             A data array containing classification pairs.
+         label_metadata : NDArray[np.int32]
+             An integer array containing label metadata.
+         score_thresholds : list[float]
+             A list of score thresholds to compute metrics over.
+         hardmax : bool
+             Toggles whether a hardmax is applied to predictions.
+         number_of_examples : int
+             The number of examples to return per count.
+
+         Returns
+         -------
+         list[ConfusionMatrix]
+             A list of ConfusionMatrix objects, one per label key and score threshold.
+         """
+
+         if data.size == 0:
+             return list()
+
+         confusion_matrix, missing_predictions = compute_confusion_matrix(
+             data=data,
+             label_metadata=label_metadata,
+             score_thresholds=np.array(score_thresholds),
+             hardmax=hardmax,
+             n_examples=number_of_examples,
+         )
+
+         n_scores, n_labels, _, _ = confusion_matrix.shape
+         return [
+             ConfusionMatrix(
+                 score_threshold=score_thresholds[score_idx],
+                 label_key=label_key,
+                 number_of_examples=number_of_examples,
+                 confusion_matrix=self._unpack_confusion_matrix(
+                     confusion_matrix=confusion_matrix[score_idx, :, :, :],
+                     label_key_idx=label_key_idx,
+                     number_of_labels=n_labels,
+                     number_of_examples=number_of_examples,
+                 ),
+                 missing_predictions=self._unpack_missing_predictions(
+                     missing_predictions=missing_predictions[score_idx, :, :],
+                     label_key_idx=label_key_idx,
+                     number_of_labels=n_labels,
+                     number_of_examples=number_of_examples,
+                 ),
+             )
+             for label_key_idx, label_key in self.index_to_label_key.items()
+             for score_idx in range(n_scores)
+         ]
+
+
+ class DataLoader:
+     """
+     Classification DataLoader.
+     """
+
+     def __init__(self):
+         self._evaluator = Evaluator()
+         self.groundtruth_count = defaultdict(lambda: defaultdict(int))
+         self.prediction_count = defaultdict(lambda: defaultdict(int))
+
+     def _add_datum(self, uid: str) -> int:
+         """
+         Helper function for adding a datum to the cache.
+
+         Parameters
+         ----------
+         uid : str
+             The datum uid.
+
+         Returns
+         -------
+         int
+             The datum index.
+         """
+         if uid not in self._evaluator.uid_to_index:
+             index = len(self._evaluator.uid_to_index)
+             self._evaluator.uid_to_index[uid] = index
+             self._evaluator.index_to_uid[index] = uid
+         return self._evaluator.uid_to_index[uid]
+
+     def _add_label(self, label: tuple[str, str]) -> tuple[int, int]:
+         """
+         Helper function for adding a label to the cache.
+
+         Parameters
+         ----------
+         label : tuple[str, str]
+             The label as a tuple in format (key, value).
+
+         Returns
+         -------
+         int
+             Label index.
+         int
+             Label key index.
+         """
+         label_id = len(self._evaluator.index_to_label)
+         label_key_id = len(self._evaluator.index_to_label_key)
+         if label not in self._evaluator.label_to_index:
+             self._evaluator.label_to_index[label] = label_id
+             self._evaluator.index_to_label[label_id] = label
+
+             # update label key index
+             if label[0] not in self._evaluator.label_key_to_index:
+                 self._evaluator.label_key_to_index[label[0]] = label_key_id
+                 self._evaluator.index_to_label_key[label_key_id] = label[0]
+                 label_key_id += 1
+
+             self._evaluator.label_index_to_label_key_index[
+                 label_id
+             ] = self._evaluator.label_key_to_index[label[0]]
+             label_id += 1
+
+         return (
+             self._evaluator.label_to_index[label],
+             self._evaluator.label_key_to_index[label[0]],
+         )
+
+     def _add_data(
+         self,
+         uid_index: int,
+         keyed_groundtruths: dict[int, int],
+         keyed_predictions: dict[int, list[tuple[int, float]]],
+     ):
+         gt_keys = set(keyed_groundtruths.keys())
+         pd_keys = set(keyed_predictions.keys())
+         joint_keys = gt_keys.intersection(pd_keys)
+
+         gt_unique_keys = gt_keys - pd_keys
+         pd_unique_keys = pd_keys - gt_keys
+         if gt_unique_keys or pd_unique_keys:
+             raise ValueError(
+                 "Label keys must match between ground truths and predictions."
+             )
+
+         pairs = list()
+         for key in joint_keys:
+             scores = np.array([score for _, score in keyed_predictions[key]])
+             max_score_idx = np.argmax(scores)
+
+             glabel = keyed_groundtruths[key]
+             for idx, (plabel, score) in enumerate(keyed_predictions[key]):
+                 pairs.append(
+                     (
+                         float(uid_index),
+                         float(glabel),
+                         float(plabel),
+                         float(score),
+                         float(max_score_idx == idx),
+                     )
+                 )
+
+         if self._evaluator._detailed_pairs.size == 0:
+             self._evaluator._detailed_pairs = np.array(pairs)
+         else:
+             self._evaluator._detailed_pairs = np.concatenate(
+                 [
+                     self._evaluator._detailed_pairs,
+                     np.array(pairs),
+                 ],
+                 axis=0,
+             )
+
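For reference, each tuple appended above becomes one row of the detailed-pairs cache; a single row might look like this (values illustrative):

# columns: datum index, ground-truth label index, prediction label index,
#          prediction score, hardmax flag (1.0 if this is the argmax prediction)
row = (0.0, 1.0, 2.0, 0.35, 0.0)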
+     def add_data(
+         self,
+         classifications: list[Classification],
+         show_progress: bool = False,
+     ):
+         """
+         Adds classifications to the cache.
+
+         Parameters
+         ----------
+         classifications : list[Classification]
+             A list of Classification objects.
+         show_progress : bool, default=False
+             Toggle for tqdm progress bar.
+         """
+
+         disable_tqdm = not show_progress
+         for classification in tqdm(classifications, disable=disable_tqdm):
+
+             # update metadata
+             self._evaluator.n_datums += 1
+             self._evaluator.n_groundtruths += len(classification.groundtruths)
+             self._evaluator.n_predictions += len(classification.predictions)
+
+             # update datum uid index
+             uid_index = self._add_datum(uid=classification.uid)
+
+             # cache labels and annotations
+             keyed_groundtruths = defaultdict(int)
+             keyed_predictions = defaultdict(list)
+             for glabel in classification.groundtruths:
+                 label_idx, label_key_idx = self._add_label(glabel)
+                 self.groundtruth_count[label_idx][uid_index] += 1
+                 keyed_groundtruths[label_key_idx] = label_idx
+             for idx, (plabel, pscore) in enumerate(
+                 zip(classification.predictions, classification.scores)
+             ):
+                 label_idx, label_key_idx = self._add_label(plabel)
+                 self.prediction_count[label_idx][uid_index] += 1
+                 keyed_predictions[label_key_idx].append(
+                     (
+                         label_idx,
+                         pscore,
+                     )
+                 )
+
+             self._add_data(
+                 uid_index=uid_index,
+                 keyed_groundtruths=keyed_groundtruths,
+                 keyed_predictions=keyed_predictions,
+             )
+
+     def add_data_from_valor_dict(
+         self,
+         classifications: list[tuple[dict, dict]],
+         show_progress: bool = False,
+     ):
+         """
+         Adds Valor-format classifications to the cache.
+
+         Parameters
+         ----------
+         classifications : list[tuple[dict, dict]]
+             A list of groundtruth, prediction pairs in Valor-format dictionaries.
+         show_progress : bool, default=False
+             Toggle for tqdm progress bar.
+         """
+
+         disable_tqdm = not show_progress
+         for groundtruth, prediction in tqdm(
+             classifications, disable=disable_tqdm
+         ):
+
+             # update metadata
+             self._evaluator.n_datums += 1
+             self._evaluator.n_groundtruths += len(groundtruth["annotations"])
+             self._evaluator.n_predictions += len(prediction["annotations"])
+
+             # update datum uid index
+             uid_index = self._add_datum(uid=groundtruth["datum"]["uid"])
+
+             # cache labels and annotations
+             keyed_groundtruths = defaultdict(int)
+             keyed_predictions = defaultdict(list)
+             for gann in groundtruth["annotations"]:
+                 for valor_label in gann["labels"]:
+                     glabel = (valor_label["key"], valor_label["value"])
+                     label_idx, label_key_idx = self._add_label(glabel)
+                     self.groundtruth_count[label_idx][uid_index] += 1
+                     keyed_groundtruths[label_key_idx] = label_idx
+             for pann in prediction["annotations"]:
+                 for valor_label in pann["labels"]:
+                     plabel = (valor_label["key"], valor_label["value"])
+                     pscore = valor_label["score"]
+                     label_idx, label_key_idx = self._add_label(plabel)
+                     self.prediction_count[label_idx][uid_index] += 1
+                     keyed_predictions[label_key_idx].append(
+                         (
+                             label_idx,
+                             pscore,
+                         )
+                     )
+
+             self._add_data(
+                 uid_index=uid_index,
+                 keyed_groundtruths=keyed_groundtruths,
+                 keyed_predictions=keyed_predictions,
+             )
+
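The Valor dictionaries are only read for the fields accessed in the loop above, so a minimal pair looks roughly like the following sketch (uids, keys, values, and scores are illustrative):

loader = DataLoader()

groundtruth = {
    "datum": {"uid": "uid1"},
    "annotations": [{"labels": [{"key": "class", "value": "dog"}]}],
}
prediction = {
    "datum": {"uid": "uid1"},
    "annotations": [
        {
            "labels": [
                {"key": "class", "value": "dog", "score": 0.9},
                {"key": "class", "value": "cat", "score": 0.1},
            ]
        }
    ],
}

loader.add_data_from_valor_dict([(groundtruth, prediction)])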
775
+ def finalize(self) -> Evaluator:
776
+ """
777
+ Performs data finalization and some preprocessing steps.
778
+
779
+ Returns
780
+ -------
781
+ Evaluator
782
+ A ready-to-use evaluator object.
783
+ """
784
+
785
+ if self._evaluator._detailed_pairs.size == 0:
786
+ raise ValueError("No data available to create evaluator.")
787
+
788
+ n_datums = self._evaluator.n_datums
789
+ n_labels = len(self._evaluator.index_to_label)
790
+
791
+ self._evaluator.n_labels = n_labels
792
+
793
+ self._evaluator._label_metadata_per_datum = np.zeros(
794
+ (2, n_datums, n_labels), dtype=np.int32
795
+ )
796
+ for datum_idx in range(n_datums):
797
+ for label_idx in range(n_labels):
798
+ gt_count = (
799
+ self.groundtruth_count[label_idx].get(datum_idx, 0)
800
+ if label_idx in self.groundtruth_count
801
+ else 0
802
+ )
803
+ pd_count = (
804
+ self.prediction_count[label_idx].get(datum_idx, 0)
805
+ if label_idx in self.prediction_count
806
+ else 0
807
+ )
808
+ self._evaluator._label_metadata_per_datum[
809
+ :, datum_idx, label_idx
810
+ ] = np.array([gt_count, pd_count])
811
+
812
+ self._evaluator._label_metadata = np.array(
813
+ [
814
+ [
815
+ np.sum(
816
+ self._evaluator._label_metadata_per_datum[
817
+ 0, :, label_idx
818
+ ]
819
+ ),
820
+ np.sum(
821
+ self._evaluator._label_metadata_per_datum[
822
+ 1, :, label_idx
823
+ ]
824
+ ),
825
+ self._evaluator.label_index_to_label_key_index[label_idx],
826
+ ]
827
+ for label_idx in range(n_labels)
828
+ ],
829
+ dtype=np.int32,
830
+ )
831
+
832
+ # sort pairs by groundtruth, prediction, score
833
+ indices = np.lexsort(
834
+ (
835
+ self._evaluator._detailed_pairs[:, 1],
836
+ self._evaluator._detailed_pairs[:, 2],
837
+ -self._evaluator._detailed_pairs[:, 3],
838
+ )
839
+ )
840
+ self._evaluator._detailed_pairs = self._evaluator._detailed_pairs[
841
+ indices
842
+ ]
843
+
844
+ return self._evaluator
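To summarize the caches built by finalize(): _label_metadata holds one row per label with columns (ground-truth count, prediction count, label-key index), and _detailed_pairs is left sorted by descending prediction score. A tiny illustration, with made-up counts for two labels sharing one key:

# columns: n_groundtruths, n_predictions, label_key_index
label_metadata = np.array(
    [
        [3, 3, 0],  # e.g. ("class", "dog")
        [2, 3, 0],  # e.g. ("class", "cat")
    ],
    dtype=np.int32,
)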