valor-lite 0.36.5-py3-none-any.whl → 0.37.5-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- valor_lite/cache/__init__.py +11 -0
- valor_lite/cache/compute.py +211 -0
- valor_lite/cache/ephemeral.py +302 -0
- valor_lite/cache/persistent.py +536 -0
- valor_lite/classification/__init__.py +5 -10
- valor_lite/classification/annotation.py +4 -0
- valor_lite/classification/computation.py +233 -251
- valor_lite/classification/evaluator.py +882 -0
- valor_lite/classification/loader.py +97 -0
- valor_lite/classification/metric.py +141 -4
- valor_lite/classification/shared.py +184 -0
- valor_lite/classification/utilities.py +221 -118
- valor_lite/exceptions.py +5 -0
- valor_lite/object_detection/__init__.py +5 -4
- valor_lite/object_detection/annotation.py +13 -1
- valor_lite/object_detection/computation.py +367 -304
- valor_lite/object_detection/evaluator.py +804 -0
- valor_lite/object_detection/loader.py +292 -0
- valor_lite/object_detection/metric.py +152 -3
- valor_lite/object_detection/shared.py +206 -0
- valor_lite/object_detection/utilities.py +182 -109
- valor_lite/semantic_segmentation/__init__.py +5 -4
- valor_lite/semantic_segmentation/annotation.py +7 -0
- valor_lite/semantic_segmentation/computation.py +20 -110
- valor_lite/semantic_segmentation/evaluator.py +414 -0
- valor_lite/semantic_segmentation/loader.py +205 -0
- valor_lite/semantic_segmentation/shared.py +149 -0
- valor_lite/semantic_segmentation/utilities.py +6 -23
- {valor_lite-0.36.5.dist-info → valor_lite-0.37.5.dist-info}/METADATA +3 -1
- valor_lite-0.37.5.dist-info/RECORD +49 -0
- {valor_lite-0.36.5.dist-info → valor_lite-0.37.5.dist-info}/WHEEL +1 -1
- valor_lite/classification/manager.py +0 -545
- valor_lite/object_detection/manager.py +0 -865
- valor_lite/profiling.py +0 -374
- valor_lite/semantic_segmentation/benchmark.py +0 -237
- valor_lite/semantic_segmentation/manager.py +0 -446
- valor_lite-0.36.5.dist-info/RECORD +0 -41
- {valor_lite-0.36.5.dist-info → valor_lite-0.37.5.dist-info}/top_level.txt +0 -0
valor_lite/object_detection/loader.py (new file, +292 lines):

```python
import numpy as np
import pyarrow as pa
from numpy.typing import NDArray
from tqdm import tqdm

from valor_lite.cache import FileCacheWriter, MemoryCacheWriter
from valor_lite.object_detection.annotation import (
    Bitmask,
    BoundingBox,
    Detection,
    Polygon,
)
from valor_lite.object_detection.computation import (
    EPSILON,
    compute_bbox_iou,
    compute_bitmask_iou,
    compute_polygon_iou,
)
from valor_lite.object_detection.evaluator import Builder


class Loader(Builder):
    def __init__(
        self,
        detailed_writer: MemoryCacheWriter | FileCacheWriter,
        ranked_writer: MemoryCacheWriter | FileCacheWriter,
        metadata_fields: list[tuple[str, str | pa.DataType]] | None = None,
    ):
        super().__init__(
            detailed_writer=detailed_writer,
            ranked_writer=ranked_writer,
            metadata_fields=metadata_fields,
        )

        # internal state
        self._labels = {}
        self._datum_count = 0
        self._groundtruth_count = 0
        self._prediction_count = 0

    def _add_label(self, value: str) -> int:
        """Add a label to the index mapping."""
        idx = self._labels.get(value, None)
        if idx is None:
            idx = len(self._labels)
            self._labels[value] = idx
        return idx

    def _add_data(
        self,
        detections: list[Detection],
        detection_ious: list[NDArray[np.float64]],
        show_progress: bool = False,
    ):
        """Adds detections to the cache."""
        disable_tqdm = not show_progress
        for detection, ious in tqdm(
            zip(detections, detection_ious), disable=disable_tqdm
        ):
            # cache labels and annotation pairs
            datum_idx = self._datum_count
            datum_metadata = detection.metadata if detection.metadata else {}
            pairs = []
            if detection.groundtruths:
                for gidx, gann in enumerate(detection.groundtruths):
                    gt_id = self._groundtruth_count + gidx
                    glabel = gann.labels[0]
                    glabel_idx = self._add_label(gann.labels[0])
                    gann_metadata = gann.metadata if gann.metadata else {}
                    if (ious[:, gidx] < EPSILON).all():
                        pairs.append(
                            {
                                # metadata
                                **datum_metadata,
                                **gann_metadata,
                                # datum
                                "datum_uid": detection.uid,
                                "datum_id": datum_idx,
                                # groundtruth
                                "gt_uid": gann.uid,
                                "gt_id": gt_id,
                                "gt_label": glabel,
                                "gt_label_id": glabel_idx,
                                # prediction
                                "pd_uid": None,
                                "pd_id": -1,
                                "pd_label": None,
                                "pd_label_id": -1,
                                "pd_score": -1,
                                # pair
                                "iou": 0.0,
                            }
                        )
                    for pidx, pann in enumerate(detection.predictions):
                        pann_id = self._prediction_count + pidx
                        pann_metadata = pann.metadata if pann.metadata else {}
                        if (ious[pidx, :] < EPSILON).all():
                            pairs.extend(
                                [
                                    {
                                        # metadata
                                        **datum_metadata,
                                        **pann_metadata,
                                        # datum
                                        "datum_uid": detection.uid,
                                        "datum_id": datum_idx,
                                        # groundtruth
                                        "gt_uid": None,
                                        "gt_id": -1,
                                        "gt_label": None,
                                        "gt_label_id": -1,
                                        # prediction
                                        "pd_uid": pann.uid,
                                        "pd_id": pann_id,
                                        "pd_label": plabel,
                                        "pd_label_id": self._add_label(plabel),
                                        "pd_score": float(pscore),
                                        # pair
                                        "iou": 0.0,
                                    }
                                    for plabel, pscore in zip(
                                        pann.labels, pann.scores
                                    )
                                ]
                            )
                        if ious[pidx, gidx] >= EPSILON:
                            pairs.extend(
                                [
                                    {
                                        # metadata
                                        **datum_metadata,
                                        **gann_metadata,
                                        **pann_metadata,
                                        # datum
                                        "datum_uid": detection.uid,
                                        "datum_id": datum_idx,
                                        # groundtruth
                                        "gt_uid": gann.uid,
                                        "gt_id": gt_id,
                                        "gt_label": glabel,
                                        "gt_label_id": self._add_label(glabel),
                                        # prediction
                                        "pd_uid": pann.uid,
                                        "pd_id": pann_id,
                                        "pd_label": plabel,
                                        "pd_label_id": self._add_label(plabel),
                                        "pd_score": float(pscore),
                                        # pair
                                        "iou": float(ious[pidx, gidx]),
                                    }
                                    for glabel in gann.labels
                                    for plabel, pscore in zip(
                                        pann.labels, pann.scores
                                    )
                                ]
                            )
            elif detection.predictions:
                for pidx, pann in enumerate(detection.predictions):
                    pann_id = self._prediction_count + pidx
                    pann_metadata = pann.metadata if pann.metadata else {}
                    pairs.extend(
                        [
                            {
                                # metadata
                                **datum_metadata,
                                **pann_metadata,
                                # datum
                                "datum_uid": detection.uid,
                                "datum_id": datum_idx,
                                # groundtruth
                                "gt_uid": None,
                                "gt_id": -1,
                                "gt_label": None,
                                "gt_label_id": -1,
                                # prediction
                                "pd_uid": pann.uid,
                                "pd_id": pann_id,
                                "pd_label": plabel,
                                "pd_label_id": self._add_label(plabel),
                                "pd_score": float(pscore),
                                # pair
                                "iou": 0.0,
                            }
                            for plabel, pscore in zip(pann.labels, pann.scores)
                        ]
                    )

            self._datum_count += 1
            self._groundtruth_count += len(detection.groundtruths)
            self._prediction_count += len(detection.predictions)

            self._detailed_writer.write_rows(pairs)

    def add_bounding_boxes(
        self,
        detections: list[Detection[BoundingBox]],
        show_progress: bool = False,
    ):
        """
        Adds bounding box detections to the cache.

        Parameters
        ----------
        detections : list[Detection]
            A list of Detection objects.
        show_progress : bool, default=False
            Toggle for tqdm progress bar.
        """
        ious = [
            compute_bbox_iou(
                np.array(
                    [
                        [gt.extrema, pd.extrema]
                        for pd in detection.predictions
                        for gt in detection.groundtruths
                    ],
                    dtype=np.float64,
                )
            ).reshape(len(detection.predictions), len(detection.groundtruths))
            for detection in detections
        ]
        return self._add_data(
            detections=detections,
            detection_ious=ious,
            show_progress=show_progress,
        )

    def add_polygons(
        self,
        detections: list[Detection[Polygon]],
        show_progress: bool = False,
    ):
        """
        Adds polygon detections to the cache.

        Parameters
        ----------
        detections : list[Detection]
            A list of Detection objects.
        show_progress : bool, default=False
            Toggle for tqdm progress bar.
        """
        ious = [
            compute_polygon_iou(
                np.array(
                    [
                        [gt.shape, pd.shape]
                        for pd in detection.predictions
                        for gt in detection.groundtruths
                    ]
                )
            ).reshape(len(detection.predictions), len(detection.groundtruths))
            for detection in detections
        ]
        return self._add_data(
            detections=detections,
            detection_ious=ious,
            show_progress=show_progress,
        )

    def add_bitmasks(
        self,
        detections: list[Detection[Bitmask]],
        show_progress: bool = False,
    ):
        """
        Adds bitmask detections to the cache.

        Parameters
        ----------
        detections : list[Detection]
            A list of Detection objects.
        show_progress : bool, default=False
            Toggle for tqdm progress bar.
        """
        ious = [
            compute_bitmask_iou(
                np.array(
                    [
                        [gt.mask, pd.mask]
                        for pd in detection.predictions
                        for gt in detection.groundtruths
                    ]
                )
            ).reshape(len(detection.predictions), len(detection.groundtruths))
            for detection in detections
        ]
        return self._add_data(
            detections=detections,
            detection_ious=ious,
            show_progress=show_progress,
        )
```
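The `add_*` methods above flatten every prediction–groundtruth pair into one array before reshaping the IoU results into a `(n_predictions, n_groundtruths)` matrix, which is what `ious[pidx, gidx]` indexes in `_add_data`. The sketch below mirrors that layout in plain NumPy. It is illustrative only: `pairwise_bbox_iou` is a stand-in written for this note, and the `(xmin, xmax, ymin, ymax)` ordering of `BoundingBox.extrema` is an assumption; the library's own `compute_bbox_iou` is the authoritative implementation.

```python
import numpy as np


def pairwise_bbox_iou(pairs: np.ndarray) -> np.ndarray:
    """IoU for an array of shape (n_pairs, 2, 4) holding [gt, pd] extrema."""
    gt, pd = pairs[:, 0, :], pairs[:, 1, :]
    # overlap extent along x and y, clamped at zero
    ix = np.maximum(
        np.minimum(gt[:, 1], pd[:, 1]) - np.maximum(gt[:, 0], pd[:, 0]), 0.0
    )
    iy = np.maximum(
        np.minimum(gt[:, 3], pd[:, 3]) - np.maximum(gt[:, 2], pd[:, 2]), 0.0
    )
    intersection = ix * iy
    union = (
        (gt[:, 1] - gt[:, 0]) * (gt[:, 3] - gt[:, 2])
        + (pd[:, 1] - pd[:, 0]) * (pd[:, 3] - pd[:, 2])
        - intersection
    )
    return np.divide(
        intersection, union, out=np.zeros_like(union), where=union > 0
    )


# Two ground truths, one prediction -> a (1, 2) matrix, matching the
# loader's `ious[pidx, gidx]` indexing (predictions outer, groundtruths inner).
gts = [(0.0, 10.0, 0.0, 10.0), (20.0, 30.0, 20.0, 30.0)]
pds = [(0.0, 10.0, 0.0, 5.0)]
pairs = np.array([[g, p] for p in pds for g in gts], dtype=np.float64)
ious = pairwise_bbox_iou(pairs).reshape(len(pds), len(gts))
print(ious)  # [[0.5 0. ]]
```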
valor_lite/object_detection/metric.py:

```diff
@@ -18,7 +18,9 @@ class MetricType(str, Enum):
     ARAveragedOverScores = "ARAveragedOverScores"
     mARAveragedOverScores = "mARAveragedOverScores"
     PrecisionRecallCurve = "PrecisionRecallCurve"
+    ConfusionMatrixWithExamples = "ConfusionMatrixWithExamples"
     ConfusionMatrix = "ConfusionMatrix"
+    Examples = "Examples"
 
 
 @dataclass
```
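A quick sanity check of the two new members; this assumes `MetricType` is importable from `valor_lite.object_detection.metric`, the file this hunk applies to.

```python
from valor_lite.object_detection.metric import MetricType

# both variants round-trip to their string names via str-Enum inheritance
assert MetricType.ConfusionMatrixWithExamples.value == "ConfusionMatrixWithExamples"
assert MetricType.Examples.value == "Examples"
assert MetricType("Examples") is MetricType.Examples
```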
```diff
@@ -562,6 +564,153 @@ class Metric(BaseMetric):
 
     @classmethod
     def confusion_matrix(
+        cls,
+        confusion_matrix: dict[str, dict[str, int]],
+        unmatched_predictions: dict[str, int],
+        unmatched_ground_truths: dict[str, int],
+        score_threshold: float,
+        iou_threshold: float,
+    ):
+        """
+        Confusion matrix for object detection task.
+
+        This class encapsulates detailed information about the model's performance, including correct
+        predictions, misclassifications, unmatched_predictions (subset of false positives), and unmatched ground truths
+        (subset of false negatives).
+
+        Confusion Matrix Format:
+        {
+            <ground truth label>: {
+                <prediction label>: 129
+                ...
+            },
+            ...
+        }
+
+        Unmatched Predictions Format:
+        {
+            <prediction label>: 11
+            ...
+        }
+
+        Unmatched Ground Truths Format:
+        {
+            <ground truth label>: 7
+            ...
+        }
+
+        Parameters
+        ----------
+        confusion_matrix : dict
+            A nested dictionary containing integer counts of occurences where the first key is the ground truth label value
+            and the second key is the prediction label value.
+        unmatched_predictions : dict
+            A dictionary where each key is a prediction label value with no corresponding ground truth
+            (subset of false positives). The value is a dictionary containing counts.
+        unmatched_ground_truths : dict
+            A dictionary where each key is a ground truth label value for which the model failed to predict
+            (subset of false negatives). The value is a dictionary containing counts.
+        score_threshold : float
+            The confidence score threshold used to filter predictions.
+        iou_threshold : float
+            The Intersection over Union (IOU) threshold used to determine true positives.
+
+        Returns
+        -------
+        Metric
+        """
+        return cls(
+            type=MetricType.ConfusionMatrix.value,
+            value={
+                "confusion_matrix": confusion_matrix,
+                "unmatched_predictions": unmatched_predictions,
+                "unmatched_ground_truths": unmatched_ground_truths,
+            },
+            parameters={
+                "score_threshold": score_threshold,
+                "iou_threshold": iou_threshold,
+            },
+        )
+
+    @classmethod
+    def examples(
+        cls,
+        datum_id: str,
+        true_positives: list[tuple[str, str]],
+        false_positives: list[str],
+        false_negatives: list[str],
+        score_threshold: float,
+        iou_threshold: float,
+    ):
+        """
+        Per-datum examples for object detection tasks.
+
+        This metric is per-datum and contains lists of annotation identifiers that categorize them
+        as true-positive, false-positive or false-negative. This is intended to be used with an
+        external database where the identifiers can be used for retrieval.
+
+        Examples Format:
+        {
+            "type": "Examples",
+            "value": {
+                "datum_id": "some string ID",
+                "true_positives": [
+                    ["groundtruth0", "prediction0"],
+                    ["groundtruth123", "prediction11"],
+                    ...
+                ],
+                "false_positives": [
+                    "prediction25",
+                    "prediction92",
+                    ...
+                ]
+                "false_negatives": [
+                    "groundtruth32",
+                    "groundtruth24",
+                    ...
+                ]
+            },
+            "parameters": {
+                "score_threshold": 0.5,
+                "iou_threshold": 0.5,
+            }
+        }
+
+        Parameters
+        ----------
+        datum_id : str
+            A string identifier representing a datum.
+        true_positives : list[tuple[str, str]]
+            A list of string identifier pairs representing true positive ground truth and prediction combinations.
+        false_positives : list[str]
+            A list of string identifiers representing false positive predictions.
+        false_negatives : list[str]
+            A list of string identifiers representing false negative ground truths.
+        score_threshold : float
+            The confidence score threshold used to filter predictions.
+        iou_threshold : float
+            The Intersection over Union (IOU) threshold used to determine true positives.
+
+        Returns
+        -------
+        Metric
+        """
+        return cls(
+            type=MetricType.Examples.value,
+            value={
+                "datum_id": datum_id,
+                "true_positives": true_positives,
+                "false_positives": false_positives,
+                "false_negatives": false_negatives,
+            },
+            parameters={
+                "score_threshold": score_threshold,
+                "iou_threshold": iou_threshold,
+            },
+        )
+
+    @classmethod
+    def confusion_matrix_with_examples(
         cls,
         confusion_matrix: dict[
             str,  # ground truth label value
```
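The two new constructors can be exercised directly. The argument values below are illustrative only; the import path is assumed from this file's location in the wheel, and the call signatures come straight from the hunk above.

```python
from valor_lite.object_detection.metric import Metric

# count-only confusion matrix at fixed score/IoU thresholds
cm = Metric.confusion_matrix(
    confusion_matrix={"dog": {"dog": 129, "cat": 3}},
    unmatched_predictions={"cat": 11},
    unmatched_ground_truths={"dog": 7},
    score_threshold=0.5,
    iou_threshold=0.5,
)

# per-datum examples keyed by external annotation identifiers
ex = Metric.examples(
    datum_id="img_001",
    true_positives=[("groundtruth0", "prediction0")],
    false_positives=["prediction25"],
    false_negatives=["groundtruth32"],
    score_threshold=0.5,
    iou_threshold=0.5,
)
```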
```diff
@@ -609,7 +758,7 @@ class Metric(BaseMetric):
         iou_threshold: float,
     ):
         """
-        Confusion matrix for object detection tasks.
+        Confusion matrix with examples for object detection tasks.
 
         This class encapsulates detailed information about the model's performance, including correct
         predictions, misclassifications, unmatched_predictions (subset of false positives), and unmatched ground truths
```
```diff
@@ -674,7 +823,7 @@ class Metric(BaseMetric):
             A dictionary where each key is a prediction label value with no corresponding ground truth
             (subset of false positives). The value is a dictionary containing either a `count` or a list of
             `examples`. Each example includes annotation and datum identifers.
-
+        unmatched_groundtruths : dict
             A dictionary where each key is a ground truth label value for which the model failed to predict
             (subset of false negatives). The value is a dictionary containing either a `count` or a list of `examples`.
             Each example includes annotation and datum identifers.
```
```diff
@@ -688,7 +837,7 @@ class Metric(BaseMetric):
         Metric
         """
         return cls(
-            type=MetricType.ConfusionMatrix.value,
+            type=MetricType.ConfusionMatrixWithExamples.value,
             value={
                 "confusion_matrix": confusion_matrix,
                 "unmatched_predictions": unmatched_predictions,
```
valor_lite/object_detection/shared.py (new file, +206 lines):

```python
from dataclasses import dataclass
from pathlib import Path

import numpy as np
import pyarrow as pa
import pyarrow.compute as pc
from numpy.typing import NDArray

from valor_lite.cache import FileCacheReader, MemoryCacheReader


@dataclass
class EvaluatorInfo:
    number_of_datums: int = 0
    number_of_groundtruth_annotations: int = 0
    number_of_prediction_annotations: int = 0
    number_of_labels: int = 0
    number_of_rows: int = 0
    metadata_fields: list[tuple[str, str | pa.DataType]] | None = None


def generate_detailed_cache_path(path: str | Path) -> Path:
    return Path(path) / "detailed"


def generate_ranked_cache_path(path: str | Path) -> Path:
    return Path(path) / "ranked"


def generate_temporary_cache_path(path: str | Path) -> Path:
    return Path(path) / "tmp"


def generate_metadata_path(path: str | Path) -> Path:
    return Path(path) / "metadata.json"


def generate_detailed_schema(
    metadata_fields: list[tuple[str, str | pa.DataType]] | None
) -> pa.Schema:
    metadata_fields = metadata_fields if metadata_fields else []
    reserved_fields = [
        ("datum_uid", pa.string()),
        ("datum_id", pa.int64()),
        # groundtruth
        ("gt_uid", pa.string()),
        ("gt_id", pa.int64()),
        ("gt_label", pa.string()),
        ("gt_label_id", pa.int64()),
        # prediction
        ("pd_uid", pa.string()),
        ("pd_id", pa.int64()),
        ("pd_label", pa.string()),
        ("pd_label_id", pa.int64()),
        ("pd_score", pa.float64()),
        # pair
        ("iou", pa.float64()),
    ]

    # validate
    reserved_field_names = {f[0] for f in reserved_fields}
    metadata_field_names = {f[0] for f in metadata_fields}
    if conflicting := reserved_field_names & metadata_field_names:
        raise ValueError(
            f"metadata fields {conflicting} conflict with reserved fields"
        )

    return pa.schema(reserved_fields + metadata_fields)


def generate_ranked_schema(
    metadata_fields: list[tuple[str, str | pa.DataType]] | None
) -> pa.Schema:
    reserved_detailed_fields = [
        ("datum_uid", pa.string()),
        ("datum_id", pa.int64()),
        # groundtruth
        ("gt_id", pa.int64()),
        ("gt_label_id", pa.int64()),
        # prediction
        ("pd_id", pa.int64()),
        ("pd_label_id", pa.int64()),
        ("pd_score", pa.float64()),
        # pair
        ("iou", pa.float64()),
    ]
    reserved_ranking_fields = [
        ("iou_prev", pa.float64()),
    ]
    metadata_fields = metadata_fields if metadata_fields else []

    # validate
    reserved_field_names = {
        f[0] for f in reserved_detailed_fields + reserved_ranking_fields
    }
    metadata_field_names = {f[0] for f in metadata_fields}
    if conflicting := reserved_field_names & metadata_field_names:
        raise ValueError(
            f"metadata fields {conflicting} conflict with reserved fields"
        )

    return pa.schema(
        [
            *reserved_detailed_fields,
            *metadata_fields,
            *reserved_ranking_fields,
        ]
    )


def encode_metadata_fields(
    metadata_fields: list[tuple[str, str | pa.DataType]] | None
) -> dict[str, str]:
    metadata_fields = metadata_fields if metadata_fields else []
    return {k: str(v) for k, v in metadata_fields}


def decode_metadata_fields(
    encoded_metadata_fields: dict[str, str]
) -> list[tuple[str, str]]:
    return [(k, v) for k, v in encoded_metadata_fields.items()]


def extract_labels(
    reader: MemoryCacheReader | FileCacheReader,
    index_to_label_override: dict[int, str] | None = None,
) -> dict[int, str]:
    if index_to_label_override is not None:
        return index_to_label_override

    index_to_label = {}
    for tbl in reader.iterate_tables(
        columns=[
            "gt_label_id",
            "gt_label",
            "pd_label_id",
            "pd_label",
        ]
    ):

        # get gt labels
        gt_label_ids = tbl["gt_label_id"].to_numpy()
        gt_label_ids, gt_indices = np.unique(gt_label_ids, return_index=True)
        gt_labels = tbl["gt_label"].take(gt_indices).to_pylist()
        gt_labels = dict(zip(gt_label_ids.astype(int).tolist(), gt_labels))
        gt_labels.pop(-1, None)
        index_to_label.update(gt_labels)

        # get pd labels
        pd_label_ids = tbl["pd_label_id"].to_numpy()
        pd_label_ids, pd_indices = np.unique(pd_label_ids, return_index=True)
        pd_labels = tbl["pd_label"].take(pd_indices).to_pylist()
        pd_labels = dict(zip(pd_label_ids.astype(int).tolist(), pd_labels))
        pd_labels.pop(-1, None)
        index_to_label.update(pd_labels)

    return index_to_label


def extract_counts(
    reader: MemoryCacheReader | FileCacheReader,
    datums: pc.Expression | None = None,
    groundtruths: pc.Expression | None = None,
    predictions: pc.Expression | None = None,
):
    n_dts, n_gts, n_pds = 0, 0, 0
    for tbl in reader.iterate_tables(filter=datums):
        # count datums
        n_dts += int(np.unique(tbl["datum_id"].to_numpy()).shape[0])

        # count groundtruths
        if groundtruths is not None:
            gts = tbl.filter(groundtruths)["gt_id"].to_numpy()
        else:
            gts = tbl["gt_id"].to_numpy()
        n_gts += int(np.unique(gts[gts >= 0]).shape[0])

        # count predictions
        if predictions is not None:
            pds = tbl.filter(predictions)["pd_id"].to_numpy()
        else:
            pds = tbl["pd_id"].to_numpy()
        n_pds += int(np.unique(pds[pds >= 0]).shape[0])

    return n_dts, n_gts, n_pds


def extract_groundtruth_count_per_label(
    reader: MemoryCacheReader | FileCacheReader,
    number_of_labels: int,
    datums: pc.Expression | None = None,
) -> NDArray[np.uint64]:
    gt_counts_per_lbl = np.zeros(number_of_labels, dtype=np.uint64)
    for gts in reader.iterate_arrays(
        numeric_columns=["gt_id", "gt_label_id"],
        filter=datums,
    ):
        # count gts per label
        unique_ann = np.unique(gts[gts[:, 0] >= 0], axis=0)
        unique_labels, label_counts = np.unique(
            unique_ann[:, 1], return_counts=True
        )
        for label_id, count in zip(unique_labels, label_counts):
            gt_counts_per_lbl[int(label_id)] += int(count)

    return gt_counts_per_lbl
```
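A short sketch of the new schema helpers above. The behavior shown (reserved-column ordering, conflict validation, string-encoding of types for the metadata.json sidecar) follows directly from the code in this file; only the example field name `camera_id` is invented for illustration.

```python
import pyarrow as pa

from valor_lite.object_detection.shared import (
    encode_metadata_fields,
    generate_detailed_schema,
)

# metadata columns are appended after the twelve reserved pair columns
schema = generate_detailed_schema([("camera_id", pa.string())])
assert schema.names[-1] == "camera_id"
assert schema.field("iou").type == pa.float64()

# reserved names are rejected up front
try:
    generate_detailed_schema([("iou", pa.string())])
except ValueError as err:
    print(err)  # metadata fields {'iou'} conflict with reserved fields

# pyarrow types are stringified for serialization
assert encode_metadata_fields([("camera_id", pa.string())]) == {"camera_id": "string"}
```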