valor-lite 0.36.5__py3-none-any.whl → 0.37.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. valor_lite/cache/__init__.py +11 -0
  2. valor_lite/cache/compute.py +211 -0
  3. valor_lite/cache/ephemeral.py +302 -0
  4. valor_lite/cache/persistent.py +536 -0
  5. valor_lite/classification/__init__.py +5 -10
  6. valor_lite/classification/annotation.py +4 -0
  7. valor_lite/classification/computation.py +233 -251
  8. valor_lite/classification/evaluator.py +882 -0
  9. valor_lite/classification/loader.py +97 -0
  10. valor_lite/classification/metric.py +141 -4
  11. valor_lite/classification/shared.py +184 -0
  12. valor_lite/classification/utilities.py +221 -118
  13. valor_lite/exceptions.py +5 -0
  14. valor_lite/object_detection/__init__.py +5 -4
  15. valor_lite/object_detection/annotation.py +13 -1
  16. valor_lite/object_detection/computation.py +367 -304
  17. valor_lite/object_detection/evaluator.py +804 -0
  18. valor_lite/object_detection/loader.py +292 -0
  19. valor_lite/object_detection/metric.py +152 -3
  20. valor_lite/object_detection/shared.py +206 -0
  21. valor_lite/object_detection/utilities.py +182 -109
  22. valor_lite/semantic_segmentation/__init__.py +5 -4
  23. valor_lite/semantic_segmentation/annotation.py +7 -0
  24. valor_lite/semantic_segmentation/computation.py +20 -110
  25. valor_lite/semantic_segmentation/evaluator.py +414 -0
  26. valor_lite/semantic_segmentation/loader.py +205 -0
  27. valor_lite/semantic_segmentation/shared.py +149 -0
  28. valor_lite/semantic_segmentation/utilities.py +6 -23
  29. {valor_lite-0.36.5.dist-info → valor_lite-0.37.5.dist-info}/METADATA +3 -1
  30. valor_lite-0.37.5.dist-info/RECORD +49 -0
  31. {valor_lite-0.36.5.dist-info → valor_lite-0.37.5.dist-info}/WHEEL +1 -1
  32. valor_lite/classification/manager.py +0 -545
  33. valor_lite/object_detection/manager.py +0 -865
  34. valor_lite/profiling.py +0 -374
  35. valor_lite/semantic_segmentation/benchmark.py +0 -237
  36. valor_lite/semantic_segmentation/manager.py +0 -446
  37. valor_lite-0.36.5.dist-info/RECORD +0 -41
  38. {valor_lite-0.36.5.dist-info → valor_lite-0.37.5.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,292 @@
+ import numpy as np
+ import pyarrow as pa
+ from numpy.typing import NDArray
+ from tqdm import tqdm
+
+ from valor_lite.cache import FileCacheWriter, MemoryCacheWriter
+ from valor_lite.object_detection.annotation import (
+     Bitmask,
+     BoundingBox,
+     Detection,
+     Polygon,
+ )
+ from valor_lite.object_detection.computation import (
+     EPSILON,
+     compute_bbox_iou,
+     compute_bitmask_iou,
+     compute_polygon_iou,
+ )
+ from valor_lite.object_detection.evaluator import Builder
+
+
+ class Loader(Builder):
+     def __init__(
+         self,
+         detailed_writer: MemoryCacheWriter | FileCacheWriter,
+         ranked_writer: MemoryCacheWriter | FileCacheWriter,
+         metadata_fields: list[tuple[str, str | pa.DataType]] | None = None,
+     ):
+         super().__init__(
+             detailed_writer=detailed_writer,
+             ranked_writer=ranked_writer,
+             metadata_fields=metadata_fields,
+         )
+
+         # internal state
+         self._labels = {}
+         self._datum_count = 0
+         self._groundtruth_count = 0
+         self._prediction_count = 0
+
+     def _add_label(self, value: str) -> int:
+         """Add a label to the index mapping."""
+         idx = self._labels.get(value, None)
+         if idx is None:
+             idx = len(self._labels)
+             self._labels[value] = idx
+         return idx
+
+     def _add_data(
+         self,
+         detections: list[Detection],
+         detection_ious: list[NDArray[np.float64]],
+         show_progress: bool = False,
+     ):
+         """Adds detections to the cache."""
+         disable_tqdm = not show_progress
+         for detection, ious in tqdm(
+             zip(detections, detection_ious), disable=disable_tqdm
+         ):
+             # cache labels and annotation pairs
+             datum_idx = self._datum_count
+             datum_metadata = detection.metadata if detection.metadata else {}
+             pairs = []
+             if detection.groundtruths:
+                 for gidx, gann in enumerate(detection.groundtruths):
+                     gt_id = self._groundtruth_count + gidx
+                     glabel = gann.labels[0]
+                     glabel_idx = self._add_label(gann.labels[0])
+                     gann_metadata = gann.metadata if gann.metadata else {}
+                     if (ious[:, gidx] < EPSILON).all():
+                         pairs.append(
+                             {
+                                 # metadata
+                                 **datum_metadata,
+                                 **gann_metadata,
+                                 # datum
+                                 "datum_uid": detection.uid,
+                                 "datum_id": datum_idx,
+                                 # groundtruth
+                                 "gt_uid": gann.uid,
+                                 "gt_id": gt_id,
+                                 "gt_label": glabel,
+                                 "gt_label_id": glabel_idx,
+                                 # prediction
+                                 "pd_uid": None,
+                                 "pd_id": -1,
+                                 "pd_label": None,
+                                 "pd_label_id": -1,
+                                 "pd_score": -1,
+                                 # pair
+                                 "iou": 0.0,
+                             }
+                         )
+                     for pidx, pann in enumerate(detection.predictions):
+                         pann_id = self._prediction_count + pidx
+                         pann_metadata = pann.metadata if pann.metadata else {}
+                         if (ious[pidx, :] < EPSILON).all():
+                             pairs.extend(
+                                 [
+                                     {
+                                         # metadata
+                                         **datum_metadata,
+                                         **pann_metadata,
+                                         # datum
+                                         "datum_uid": detection.uid,
+                                         "datum_id": datum_idx,
+                                         # groundtruth
+                                         "gt_uid": None,
+                                         "gt_id": -1,
+                                         "gt_label": None,
+                                         "gt_label_id": -1,
+                                         # prediction
+                                         "pd_uid": pann.uid,
+                                         "pd_id": pann_id,
+                                         "pd_label": plabel,
+                                         "pd_label_id": self._add_label(plabel),
+                                         "pd_score": float(pscore),
+                                         # pair
+                                         "iou": 0.0,
+                                     }
+                                     for plabel, pscore in zip(
+                                         pann.labels, pann.scores
+                                     )
+                                 ]
+                             )
+                         if ious[pidx, gidx] >= EPSILON:
+                             pairs.extend(
+                                 [
+                                     {
+                                         # metadata
+                                         **datum_metadata,
+                                         **gann_metadata,
+                                         **pann_metadata,
+                                         # datum
+                                         "datum_uid": detection.uid,
+                                         "datum_id": datum_idx,
+                                         # groundtruth
+                                         "gt_uid": gann.uid,
+                                         "gt_id": gt_id,
+                                         "gt_label": glabel,
+                                         "gt_label_id": self._add_label(glabel),
+                                         # prediction
+                                         "pd_uid": pann.uid,
+                                         "pd_id": pann_id,
+                                         "pd_label": plabel,
+                                         "pd_label_id": self._add_label(plabel),
+                                         "pd_score": float(pscore),
+                                         # pair
+                                         "iou": float(ious[pidx, gidx]),
+                                     }
+                                     for glabel in gann.labels
+                                     for plabel, pscore in zip(
+                                         pann.labels, pann.scores
+                                     )
+                                 ]
+                             )
+             elif detection.predictions:
+                 for pidx, pann in enumerate(detection.predictions):
+                     pann_id = self._prediction_count + pidx
+                     pann_metadata = pann.metadata if pann.metadata else {}
+                     pairs.extend(
+                         [
+                             {
+                                 # metadata
+                                 **datum_metadata,
+                                 **pann_metadata,
+                                 # datum
+                                 "datum_uid": detection.uid,
+                                 "datum_id": datum_idx,
+                                 # groundtruth
+                                 "gt_uid": None,
+                                 "gt_id": -1,
+                                 "gt_label": None,
+                                 "gt_label_id": -1,
+                                 # prediction
+                                 "pd_uid": pann.uid,
+                                 "pd_id": pann_id,
+                                 "pd_label": plabel,
+                                 "pd_label_id": self._add_label(plabel),
+                                 "pd_score": float(pscore),
+                                 # pair
+                                 "iou": 0.0,
+                             }
+                             for plabel, pscore in zip(pann.labels, pann.scores)
+                         ]
+                     )
+
+             self._datum_count += 1
+             self._groundtruth_count += len(detection.groundtruths)
+             self._prediction_count += len(detection.predictions)
+
+             self._detailed_writer.write_rows(pairs)
+
+     def add_bounding_boxes(
+         self,
+         detections: list[Detection[BoundingBox]],
+         show_progress: bool = False,
+     ):
+         """
+         Adds bounding box detections to the cache.
+
+         Parameters
+         ----------
+         detections : list[Detection]
+             A list of Detection objects.
+         show_progress : bool, default=False
+             Toggle for tqdm progress bar.
+         """
+         ious = [
+             compute_bbox_iou(
+                 np.array(
+                     [
+                         [gt.extrema, pd.extrema]
+                         for pd in detection.predictions
+                         for gt in detection.groundtruths
+                     ],
+                     dtype=np.float64,
+                 )
+             ).reshape(len(detection.predictions), len(detection.groundtruths))
+             for detection in detections
+         ]
+         return self._add_data(
+             detections=detections,
+             detection_ious=ious,
+             show_progress=show_progress,
+         )
+
+     def add_polygons(
+         self,
+         detections: list[Detection[Polygon]],
+         show_progress: bool = False,
+     ):
+         """
+         Adds polygon detections to the cache.
+
+         Parameters
+         ----------
+         detections : list[Detection]
+             A list of Detection objects.
+         show_progress : bool, default=False
+             Toggle for tqdm progress bar.
+         """
+         ious = [
+             compute_polygon_iou(
+                 np.array(
+                     [
+                         [gt.shape, pd.shape]
+                         for pd in detection.predictions
+                         for gt in detection.groundtruths
+                     ]
+                 )
+             ).reshape(len(detection.predictions), len(detection.groundtruths))
+             for detection in detections
+         ]
+         return self._add_data(
+             detections=detections,
+             detection_ious=ious,
+             show_progress=show_progress,
+         )
+
+     def add_bitmasks(
+         self,
+         detections: list[Detection[Bitmask]],
+         show_progress: bool = False,
+     ):
+         """
+         Adds bitmask detections to the cache.
+
+         Parameters
+         ----------
+         detections : list[Detection]
+             A list of Detection objects.
+         show_progress : bool, default=False
+             Toggle for tqdm progress bar.
+         """
+         ious = [
+             compute_bitmask_iou(
+                 np.array(
+                     [
+                         [gt.mask, pd.mask]
+                         for pd in detection.predictions
+                         for gt in detection.groundtruths
+                     ]
+                 )
+             ).reshape(len(detection.predictions), len(detection.groundtruths))
+             for detection in detections
+         ]
+         return self._add_data(
+             detections=detections,
+             detection_ious=ious,
+             show_progress=show_progress,
+         )
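The _add_data routine above flattens each datum's IoU matrix into label-pair rows for the detailed cache: every overlapping (prediction, ground truth) combination becomes a row, and annotations with no overlap become unmatched rows with an id of -1. The following standalone sketch mirrors that expansion for a single datum; it is illustrative only, and expand_pairs as well as the EPSILON value are hypothetical stand-ins, not part of the valor-lite API.

import numpy as np

EPSILON = 1e-9  # assumption: stands in for valor_lite.object_detection.computation.EPSILON


def expand_pairs(ious, gt_labels, pd_labels, pd_scores):
    """Hypothetical helper mirroring the pair expansion in Loader._add_data for one datum."""
    rows = []
    n_pd, n_gt = ious.shape
    # ground truths with no overlapping prediction become unmatched rows
    for g in range(n_gt):
        if (ious[:, g] < EPSILON).all():
            rows.append(
                {"gt_id": g, "pd_id": -1, "gt_label": gt_labels[g],
                 "pd_label": None, "pd_score": -1, "iou": 0.0}
            )
    for p in range(n_pd):
        # predictions with no overlapping ground truth become unmatched rows
        if (ious[p, :] < EPSILON).all():
            rows.append(
                {"gt_id": -1, "pd_id": p, "gt_label": None,
                 "pd_label": pd_labels[p], "pd_score": pd_scores[p], "iou": 0.0}
            )
        # every overlapping (prediction, ground truth) pair becomes a matched row
        for g in range(n_gt):
            if ious[p, g] >= EPSILON:
                rows.append(
                    {"gt_id": g, "pd_id": p, "gt_label": gt_labels[g],
                     "pd_label": pd_labels[p], "pd_score": pd_scores[p],
                     "iou": float(ious[p, g])}
                )
    return rows


# one "dog" ground truth, two predictions: an overlapping "dog" box and a stray "cat" box
ious = np.array([[0.72], [0.0]])  # shape is (n_predictions, n_groundtruths)
rows = expand_pairs(ious, ["dog"], ["dog", "cat"], [0.9, 0.4])
print(len(rows))  # 2: one matched pair (iou=0.72) and one unmatched "cat" prediction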
@@ -18,7 +18,9 @@ class MetricType(str, Enum):
      ARAveragedOverScores = "ARAveragedOverScores"
      mARAveragedOverScores = "mARAveragedOverScores"
      PrecisionRecallCurve = "PrecisionRecallCurve"
+     ConfusionMatrixWithExamples = "ConfusionMatrixWithExamples"
      ConfusionMatrix = "ConfusionMatrix"
+     Examples = "Examples"


  @dataclass
@@ -562,6 +564,153 @@ class Metric(BaseMetric):

      @classmethod
      def confusion_matrix(
+         cls,
+         confusion_matrix: dict[str, dict[str, int]],
+         unmatched_predictions: dict[str, int],
+         unmatched_ground_truths: dict[str, int],
+         score_threshold: float,
+         iou_threshold: float,
+     ):
+         """
+         Confusion matrix for object detection tasks.
+
+         This class encapsulates detailed information about the model's performance, including correct
+         predictions, misclassifications, unmatched_predictions (subset of false positives), and unmatched ground truths
+         (subset of false negatives).
+
+         Confusion Matrix Format:
+         {
+             <ground truth label>: {
+                 <prediction label>: 129
+                 ...
+             },
+             ...
+         }
+
+         Unmatched Predictions Format:
+         {
+             <prediction label>: 11
+             ...
+         }
+
+         Unmatched Ground Truths Format:
+         {
+             <ground truth label>: 7
+             ...
+         }
+
+         Parameters
+         ----------
+         confusion_matrix : dict
+             A nested dictionary containing integer counts of occurrences where the first key is the ground truth label value
+             and the second key is the prediction label value.
+         unmatched_predictions : dict
+             A dictionary where each key is a prediction label value with no corresponding ground truth
+             (subset of false positives). The value is an integer count.
+         unmatched_ground_truths : dict
+             A dictionary where each key is a ground truth label value for which the model failed to predict
+             (subset of false negatives). The value is an integer count.
+         score_threshold : float
+             The confidence score threshold used to filter predictions.
+         iou_threshold : float
+             The Intersection over Union (IOU) threshold used to determine true positives.
+
+         Returns
+         -------
+         Metric
+         """
+         return cls(
+             type=MetricType.ConfusionMatrix.value,
+             value={
+                 "confusion_matrix": confusion_matrix,
+                 "unmatched_predictions": unmatched_predictions,
+                 "unmatched_ground_truths": unmatched_ground_truths,
+             },
+             parameters={
+                 "score_threshold": score_threshold,
+                 "iou_threshold": iou_threshold,
+             },
+         )
+
+     @classmethod
+     def examples(
+         cls,
+         datum_id: str,
+         true_positives: list[tuple[str, str]],
+         false_positives: list[str],
+         false_negatives: list[str],
+         score_threshold: float,
+         iou_threshold: float,
+     ):
+         """
+         Per-datum examples for object detection tasks.
+
+         This metric is computed per datum and contains lists of annotation identifiers categorized
+         as true-positive, false-positive or false-negative. It is intended to be used with an
+         external database where the identifiers can be used for retrieval.
+
+         Examples Format:
+         {
+             "type": "Examples",
+             "value": {
+                 "datum_id": "some string ID",
+                 "true_positives": [
+                     ["groundtruth0", "prediction0"],
+                     ["groundtruth123", "prediction11"],
+                     ...
+                 ],
+                 "false_positives": [
+                     "prediction25",
+                     "prediction92",
+                     ...
+                 ],
+                 "false_negatives": [
+                     "groundtruth32",
+                     "groundtruth24",
+                     ...
+                 ]
+             },
+             "parameters": {
+                 "score_threshold": 0.5,
+                 "iou_threshold": 0.5,
+             }
+         }
+
+         Parameters
+         ----------
+         datum_id : str
+             A string identifier representing a datum.
+         true_positives : list[tuple[str, str]]
+             A list of string identifier pairs representing true positive ground truth and prediction combinations.
+         false_positives : list[str]
+             A list of string identifiers representing false positive predictions.
+         false_negatives : list[str]
+             A list of string identifiers representing false negative ground truths.
+         score_threshold : float
+             The confidence score threshold used to filter predictions.
+         iou_threshold : float
+             The Intersection over Union (IOU) threshold used to determine true positives.
+
+         Returns
+         -------
+         Metric
+         """
+         return cls(
+             type=MetricType.Examples.value,
+             value={
+                 "datum_id": datum_id,
+                 "true_positives": true_positives,
+                 "false_positives": false_positives,
+                 "false_negatives": false_negatives,
+             },
+             parameters={
+                 "score_threshold": score_threshold,
+                 "iou_threshold": iou_threshold,
+             },
+         )
+
+     @classmethod
+     def confusion_matrix_with_examples(
          cls,
          confusion_matrix: dict[
              str, # ground truth label value
@@ -609,7 +758,7 @@ class Metric(BaseMetric):
          iou_threshold: float,
      ):
          """
-         Confusion matrix for object detection tasks.
+         Confusion matrix with examples for object detection tasks.

          This class encapsulates detailed information about the model's performance, including correct
          predictions, misclassifications, unmatched_predictions (subset of false positives), and unmatched ground truths
@@ -674,7 +823,7 @@ class Metric(BaseMetric):
              A dictionary where each key is a prediction label value with no corresponding ground truth
              (subset of false positives). The value is a dictionary containing either a `count` or a list of
              `examples`. Each example includes annotation and datum identifers.
-         unmatched_ground_truths : dict
+         unmatched_groundtruths : dict
              A dictionary where each key is a ground truth label value for which the model failed to predict
              (subset of false negatives). The value is a dictionary containing either a `count` or a list of `examples`.
              Each example includes annotation and datum identifers.
@@ -688,7 +837,7 @@ class Metric(BaseMetric):
          Metric
          """
          return cls(
-             type=MetricType.ConfusionMatrix.value,
+             type=MetricType.ConfusionMatrixWithExamples.value,
              value={
                  "confusion_matrix": confusion_matrix,
                  "unmatched_predictions": unmatched_predictions,
@@ -0,0 +1,206 @@
+ from dataclasses import dataclass
+ from pathlib import Path
+
+ import numpy as np
+ import pyarrow as pa
+ import pyarrow.compute as pc
+ from numpy.typing import NDArray
+
+ from valor_lite.cache import FileCacheReader, MemoryCacheReader
+
+
+ @dataclass
+ class EvaluatorInfo:
+     number_of_datums: int = 0
+     number_of_groundtruth_annotations: int = 0
+     number_of_prediction_annotations: int = 0
+     number_of_labels: int = 0
+     number_of_rows: int = 0
+     metadata_fields: list[tuple[str, str | pa.DataType]] | None = None
+
+
+ def generate_detailed_cache_path(path: str | Path) -> Path:
+     return Path(path) / "detailed"
+
+
+ def generate_ranked_cache_path(path: str | Path) -> Path:
+     return Path(path) / "ranked"
+
+
+ def generate_temporary_cache_path(path: str | Path) -> Path:
+     return Path(path) / "tmp"
+
+
+ def generate_metadata_path(path: str | Path) -> Path:
+     return Path(path) / "metadata.json"
+
+
+ def generate_detailed_schema(
+     metadata_fields: list[tuple[str, str | pa.DataType]] | None
+ ) -> pa.Schema:
+     metadata_fields = metadata_fields if metadata_fields else []
+     reserved_fields = [
+         ("datum_uid", pa.string()),
+         ("datum_id", pa.int64()),
+         # groundtruth
+         ("gt_uid", pa.string()),
+         ("gt_id", pa.int64()),
+         ("gt_label", pa.string()),
+         ("gt_label_id", pa.int64()),
+         # prediction
+         ("pd_uid", pa.string()),
+         ("pd_id", pa.int64()),
+         ("pd_label", pa.string()),
+         ("pd_label_id", pa.int64()),
+         ("pd_score", pa.float64()),
+         # pair
+         ("iou", pa.float64()),
+     ]
+
+     # validate
+     reserved_field_names = {f[0] for f in reserved_fields}
+     metadata_field_names = {f[0] for f in metadata_fields}
+     if conflicting := reserved_field_names & metadata_field_names:
+         raise ValueError(
+             f"metadata fields {conflicting} conflict with reserved fields"
+         )
+
+     return pa.schema(reserved_fields + metadata_fields)
+
+
+ def generate_ranked_schema(
+     metadata_fields: list[tuple[str, str | pa.DataType]] | None
+ ) -> pa.Schema:
+     reserved_detailed_fields = [
+         ("datum_uid", pa.string()),
+         ("datum_id", pa.int64()),
+         # groundtruth
+         ("gt_id", pa.int64()),
+         ("gt_label_id", pa.int64()),
+         # prediction
+         ("pd_id", pa.int64()),
+         ("pd_label_id", pa.int64()),
+         ("pd_score", pa.float64()),
+         # pair
+         ("iou", pa.float64()),
+     ]
+     reserved_ranking_fields = [
+         ("iou_prev", pa.float64()),
+     ]
+     metadata_fields = metadata_fields if metadata_fields else []
+
+     # validate
+     reserved_field_names = {
+         f[0] for f in reserved_detailed_fields + reserved_ranking_fields
+     }
+     metadata_field_names = {f[0] for f in metadata_fields}
+     if conflicting := reserved_field_names & metadata_field_names:
+         raise ValueError(
+             f"metadata fields {conflicting} conflict with reserved fields"
+         )
+
+     return pa.schema(
+         [
+             *reserved_detailed_fields,
+             *metadata_fields,
+             *reserved_ranking_fields,
+         ]
+     )
+
+
+ def encode_metadata_fields(
+     metadata_fields: list[tuple[str, str | pa.DataType]] | None
+ ) -> dict[str, str]:
+     metadata_fields = metadata_fields if metadata_fields else []
+     return {k: str(v) for k, v in metadata_fields}
+
+
+ def decode_metadata_fields(
+     encoded_metadata_fields: dict[str, str]
+ ) -> list[tuple[str, str]]:
+     return [(k, v) for k, v in encoded_metadata_fields.items()]
+
+
+ def extract_labels(
+     reader: MemoryCacheReader | FileCacheReader,
+     index_to_label_override: dict[int, str] | None = None,
+ ) -> dict[int, str]:
+     if index_to_label_override is not None:
+         return index_to_label_override
+
+     index_to_label = {}
+     for tbl in reader.iterate_tables(
+         columns=[
+             "gt_label_id",
+             "gt_label",
+             "pd_label_id",
+             "pd_label",
+         ]
+     ):
+
+         # get gt labels
+         gt_label_ids = tbl["gt_label_id"].to_numpy()
+         gt_label_ids, gt_indices = np.unique(gt_label_ids, return_index=True)
+         gt_labels = tbl["gt_label"].take(gt_indices).to_pylist()
+         gt_labels = dict(zip(gt_label_ids.astype(int).tolist(), gt_labels))
+         gt_labels.pop(-1, None)
+         index_to_label.update(gt_labels)
+
+         # get pd labels
+         pd_label_ids = tbl["pd_label_id"].to_numpy()
+         pd_label_ids, pd_indices = np.unique(pd_label_ids, return_index=True)
+         pd_labels = tbl["pd_label"].take(pd_indices).to_pylist()
+         pd_labels = dict(zip(pd_label_ids.astype(int).tolist(), pd_labels))
+         pd_labels.pop(-1, None)
+         index_to_label.update(pd_labels)
+
+     return index_to_label
+
+
+ def extract_counts(
+     reader: MemoryCacheReader | FileCacheReader,
+     datums: pc.Expression | None = None,
+     groundtruths: pc.Expression | None = None,
+     predictions: pc.Expression | None = None,
+ ):
+     n_dts, n_gts, n_pds = 0, 0, 0
+     for tbl in reader.iterate_tables(filter=datums):
+         # count datums
+         n_dts += int(np.unique(tbl["datum_id"].to_numpy()).shape[0])
+
+         # count groundtruths
+         if groundtruths is not None:
+             gts = tbl.filter(groundtruths)["gt_id"].to_numpy()
+         else:
+             gts = tbl["gt_id"].to_numpy()
+         n_gts += int(np.unique(gts[gts >= 0]).shape[0])
+
+         # count predictions
+         if predictions is not None:
+             pds = tbl.filter(predictions)["pd_id"].to_numpy()
+         else:
+             pds = tbl["pd_id"].to_numpy()
+         n_pds += int(np.unique(pds[pds >= 0]).shape[0])
+
+     return n_dts, n_gts, n_pds
+
+
+ def extract_groundtruth_count_per_label(
+     reader: MemoryCacheReader | FileCacheReader,
+     number_of_labels: int,
+     datums: pc.Expression | None = None,
+ ) -> NDArray[np.uint64]:
+     gt_counts_per_lbl = np.zeros(number_of_labels, dtype=np.uint64)
+     for gts in reader.iterate_arrays(
+         numeric_columns=["gt_id", "gt_label_id"],
+         filter=datums,
+     ):
+         # count gts per label
+         unique_ann = np.unique(gts[gts[:, 0] >= 0], axis=0)
+         unique_labels, label_counts = np.unique(
+             unique_ann[:, 1], return_counts=True
+         )
+         for label_id, count in zip(unique_labels, label_counts):
+             gt_counts_per_lbl[int(label_id)] += int(count)
+
+     return gt_counts_per_lbl
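The extract_counts and extract_groundtruth_count_per_label helpers above stream Arrow tables from the cache readers and count unique ids, treating -1 as "no annotation". The sketch below applies the same counting rule to an in-memory pyarrow table without the package's cache readers; the table contents are made up for illustration.

import numpy as np
import pyarrow as pa
import pyarrow.compute as pc

# Toy detailed-cache slice: ids of -1 mark missing annotations and are
# excluded before deduplication, exactly as in extract_counts above.
tbl = pa.table(
    {
        "datum_id": [0, 0, 1, 1],
        "gt_id": [0, 1, -1, 2],  # datum 1 has one unmatched prediction row
        "pd_id": [0, -1, 1, 2],
        "pd_score": [0.9, -1.0, 0.4, 0.8],
    }
)

# optional prediction filter, e.g. only high-confidence predictions
high_score = pc.field("pd_score") >= 0.5
pds = tbl.filter(high_score)["pd_id"].to_numpy()

gts = tbl["gt_id"].to_numpy()
n_datums = np.unique(tbl["datum_id"].to_numpy()).shape[0]
n_gts = np.unique(gts[gts >= 0]).shape[0]
n_pds = np.unique(pds[pds >= 0]).shape[0]
print(n_datums, n_gts, n_pds)  # 2 3 2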