valor-lite 0.36.5__py3-none-any.whl → 0.37.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- valor_lite/cache/__init__.py +11 -0
- valor_lite/cache/compute.py +211 -0
- valor_lite/cache/ephemeral.py +302 -0
- valor_lite/cache/persistent.py +536 -0
- valor_lite/classification/__init__.py +5 -10
- valor_lite/classification/annotation.py +4 -0
- valor_lite/classification/computation.py +233 -251
- valor_lite/classification/evaluator.py +882 -0
- valor_lite/classification/loader.py +97 -0
- valor_lite/classification/metric.py +141 -4
- valor_lite/classification/shared.py +184 -0
- valor_lite/classification/utilities.py +221 -118
- valor_lite/exceptions.py +5 -0
- valor_lite/object_detection/__init__.py +5 -4
- valor_lite/object_detection/annotation.py +13 -1
- valor_lite/object_detection/computation.py +367 -304
- valor_lite/object_detection/evaluator.py +804 -0
- valor_lite/object_detection/loader.py +292 -0
- valor_lite/object_detection/metric.py +152 -3
- valor_lite/object_detection/shared.py +206 -0
- valor_lite/object_detection/utilities.py +182 -109
- valor_lite/semantic_segmentation/__init__.py +5 -4
- valor_lite/semantic_segmentation/annotation.py +7 -0
- valor_lite/semantic_segmentation/computation.py +20 -110
- valor_lite/semantic_segmentation/evaluator.py +414 -0
- valor_lite/semantic_segmentation/loader.py +205 -0
- valor_lite/semantic_segmentation/shared.py +149 -0
- valor_lite/semantic_segmentation/utilities.py +6 -23
- {valor_lite-0.36.5.dist-info → valor_lite-0.37.5.dist-info}/METADATA +3 -1
- valor_lite-0.37.5.dist-info/RECORD +49 -0
- {valor_lite-0.36.5.dist-info → valor_lite-0.37.5.dist-info}/WHEEL +1 -1
- valor_lite/classification/manager.py +0 -545
- valor_lite/object_detection/manager.py +0 -865
- valor_lite/profiling.py +0 -374
- valor_lite/semantic_segmentation/benchmark.py +0 -237
- valor_lite/semantic_segmentation/manager.py +0 -446
- valor_lite-0.36.5.dist-info/RECORD +0 -41
- {valor_lite-0.36.5.dist-info → valor_lite-0.37.5.dist-info}/top_level.txt +0 -0
|
@@ -1,865 +0,0 @@
|
|
|
1
|
-
import warnings
|
|
2
|
-
from dataclasses import asdict, dataclass
|
|
3
|
-
|
|
4
|
-
import numpy as np
|
|
5
|
-
from numpy.typing import NDArray
|
|
6
|
-
from tqdm import tqdm
|
|
7
|
-
|
|
8
|
-
from valor_lite.exceptions import (
|
|
9
|
-
EmptyEvaluatorError,
|
|
10
|
-
EmptyFilterError,
|
|
11
|
-
InternalCacheError,
|
|
12
|
-
)
|
|
13
|
-
from valor_lite.object_detection.annotation import (
|
|
14
|
-
Bitmask,
|
|
15
|
-
BoundingBox,
|
|
16
|
-
Detection,
|
|
17
|
-
Polygon,
|
|
18
|
-
)
|
|
19
|
-
from valor_lite.object_detection.computation import (
|
|
20
|
-
compute_bbox_iou,
|
|
21
|
-
compute_bitmask_iou,
|
|
22
|
-
compute_confusion_matrix,
|
|
23
|
-
compute_label_metadata,
|
|
24
|
-
compute_polygon_iou,
|
|
25
|
-
compute_precion_recall,
|
|
26
|
-
filter_cache,
|
|
27
|
-
rank_pairs,
|
|
28
|
-
)
|
|
29
|
-
from valor_lite.object_detection.metric import Metric, MetricType
|
|
30
|
-
from valor_lite.object_detection.utilities import (
|
|
31
|
-
unpack_confusion_matrix_into_metric_list,
|
|
32
|
-
unpack_precision_recall_into_metric_lists,
|
|
33
|
-
)
|
|
34
|
-
|
|
35
|
-
"""
|
|
36
|
-
Usage
|
|
37
|
-
-----
|
|
38
|
-
|
|
39
|
-
loader = DataLoader()
|
|
40
|
-
loader.add_bounding_boxes(
|
|
41
|
-
detections=detections,
|

42
|
-
show_progress=False,
|
|
43
|
-
)
|
|
44
|
-
evaluator = loader.finalize()
|
|
45
|
-
|
|
46
|
-
metrics = evaluator.evaluate(iou_thresholds=[0.5])
|
|
47
|
-
|
|
48
|
-
ap_metrics = metrics[MetricType.AP]
|
|
49
|
-
ar_metrics = metrics[MetricType.AR]
|
|
50
|
-
|
|
51
|
-
filter_ = evaluator.create_filter(datums=["uid1", "uid2"])
|

52
|
-
filtered_metrics = evaluator.evaluate(iou_thresholds=[0.5], filter_=filter_)
|
|
53
|
-
"""
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
@dataclass
|
|
57
|
-
class Metadata:
|
|
58
|
-
number_of_datums: int = 0
|
|
59
|
-
number_of_ground_truths: int = 0
|
|
60
|
-
number_of_predictions: int = 0
|
|
61
|
-
number_of_labels: int = 0
|
|
62
|
-
|
|
63
|
-
@classmethod
|
|
64
|
-
def create(
|
|
65
|
-
cls,
|
|
66
|
-
detailed_pairs: NDArray[np.float64],
|
|
67
|
-
number_of_datums: int,
|
|
68
|
-
number_of_labels: int,
|
|
69
|
-
):
|
|
70
|
-
# count number of ground truths
|
|
71
|
-
mask_valid_gts = detailed_pairs[:, 1] >= 0
|
|
72
|
-
unique_ids = np.unique(
|
|
73
|
-
detailed_pairs[np.ix_(mask_valid_gts, (0, 1))], axis=0 # type: ignore - np.ix_ typing
|
|
74
|
-
)
|
|
75
|
-
number_of_ground_truths = int(unique_ids.shape[0])
|
|
76
|
-
|
|
77
|
-
# count number of predictions
|
|
78
|
-
mask_valid_pds = detailed_pairs[:, 2] >= 0
|
|
79
|
-
unique_ids = np.unique(
|
|
80
|
-
detailed_pairs[np.ix_(mask_valid_pds, (0, 2))], axis=0 # type: ignore - np.ix_ typing
|
|
81
|
-
)
|
|
82
|
-
number_of_predictions = int(unique_ids.shape[0])
|
|
83
|
-
|
|
84
|
-
return cls(
|
|
85
|
-
number_of_datums=number_of_datums,
|
|
86
|
-
number_of_ground_truths=number_of_ground_truths,
|
|
87
|
-
number_of_predictions=number_of_predictions,
|
|
88
|
-
number_of_labels=number_of_labels,
|
|
89
|
-
)
|
|
90
|
-
|
|
91
|
-
def to_dict(self) -> dict[str, int | bool]:
|
|
92
|
-
return asdict(self)
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
@dataclass
class Filter:
    """
    Filter masks over the detailed pairs cache plus the filtered metadata.

    `mask_datums` is True for pairs that are kept. `mask_groundtruths` and
    `mask_predictions` are True for pairs whose ground truth / prediction
    is removed by the filter.
    """

    mask_datums: NDArray[np.bool_]
    mask_groundtruths: NDArray[np.bool_]
    mask_predictions: NDArray[np.bool_]
    metadata: Metadata

    def __post_init__(self):
        """
        Validate the masks.

        Raises
        ------
        EmptyFilterError
            If no datum survives, or if both all ground truths and all
            predictions are removed.

        Warns when only one side (ground truths or predictions) is removed.
        """
        if not self.mask_datums.any():
            raise EmptyFilterError("filter removes all datums")

        drops_every_gt = bool(self.mask_groundtruths.all())
        drops_every_pd = bool(self.mask_predictions.all())

        if drops_every_gt and drops_every_pd:
            raise EmptyFilterError("filter removes all annotations")
        if drops_every_gt:
            warnings.warn("filter removes all ground truths")
        elif drops_every_pd:
            warnings.warn("filter removes all predictions")
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
class Evaluator:
    """
    Object Detection Evaluator

    Instances are populated by `DataLoader.finalize`, which fills the
    identifier lookups and the internal pair caches.
    """

    def __init__(self):

        # external reference: string identifier <-> integer index lookups
        self.datum_id_to_index: dict[str, int] = {}
        self.groundtruth_id_to_index: dict[str, int] = {}
        self.prediction_id_to_index: dict[str, int] = {}
        self.label_to_index: dict[str, int] = {}

        self.index_to_datum_id: list[str] = []
        self.index_to_groundtruth_id: list[str] = []
        self.index_to_prediction_id: list[str] = []
        self.index_to_label: list[str] = []

        # temporary cache
        self._temp_cache: list[NDArray[np.float64]] | None = []

        # internal cache, overwritten by DataLoader.finalize
        self._detailed_pairs = np.array([[]], dtype=np.float64)
        self._ranked_pairs = np.array([[]], dtype=np.float64)
        # One row per label with two columns; the label properties read
        # column 0 as ground-truth presence and column 1 as prediction
        # presence. Starting with zero rows (instead of a (1, 0) array)
        # lets those properties return [] before finalization rather than
        # raising IndexError.
        self._label_metadata: NDArray[np.int32] = np.zeros(
            (0, 2), dtype=np.int32
        )
        self._metadata = Metadata()

    def _labels_with_positive(self, column: int) -> set[int]:
        """Indices of labels whose `_label_metadata[:, column]` entry is > 0."""
        return set(
            np.flatnonzero(self._label_metadata[:, column] > 0).tolist()
        )

    @property
    def ignored_prediction_labels(self) -> list[str]:
        """
        Prediction labels that are not present in the ground truth set.
        """
        glabels = self._labels_with_positive(0)
        plabels = self._labels_with_positive(1)
        # sorted so the result is ordered by label index deterministically
        return [
            self.index_to_label[label_id]
            for label_id in sorted(plabels - glabels)
        ]

    @property
    def missing_prediction_labels(self) -> list[str]:
        """
        Ground truth labels that are not present in the prediction set.
        """
        glabels = self._labels_with_positive(0)
        plabels = self._labels_with_positive(1)
        return [
            self.index_to_label[label_id]
            for label_id in sorted(glabels - plabels)
        ]

    @property
    def metadata(self) -> Metadata:
        """
        Evaluation metadata.
        """
        return self._metadata

    def create_filter(
        self,
        datums: list[str] | NDArray[np.int32] | None = None,
        groundtruths: list[str] | NDArray[np.int32] | None = None,
        predictions: list[str] | NDArray[np.int32] | None = None,
        labels: list[str] | NDArray[np.int32] | None = None,
    ) -> Filter:
        """
        Creates a filter object.

        Parameters
        ----------
        datums : list[str] | NDArray[int32], optional
            An optional list of string ids or indices representing datums to keep.
        groundtruths : list[str] | NDArray[int32], optional
            An optional list of string ids or indices representing ground truth annotations to keep.
        predictions : list[str] | NDArray[int32], optional
            An optional list of string ids or indices representing prediction annotations to keep.
        labels : list[str] | NDArray[int32], optional
            An optional list of labels or indices to keep.

        Returns
        -------
        Filter
            Masks over the detailed pairs plus metadata of the filtered view.

        Raises
        ------
        EmptyFilterError
            If the filter keeps no datums or no labels (or, via `Filter`,
            removes every annotation).
        ValueError
            If an index is negative or exceeds the number of known items.
        KeyError
            If a string identifier is unknown.
        """
        mask_datums = np.ones(self._detailed_pairs.shape[0], dtype=np.bool_)

        # filter datums (True in mask_datums means the pair is kept)
        if datums is not None:
            # convert to indices
            if isinstance(datums, list):
                datums = np.array(
                    [self.datum_id_to_index[uid] for uid in datums],
                    dtype=np.int32,
                )

            # validate indices
            if datums.size == 0:
                raise EmptyFilterError("filter removes all datums")
            elif datums.min() < 0:
                raise ValueError(
                    f"datum index cannot be negative '{datums.min()}'"
                )
            elif datums.max() >= len(self.index_to_datum_id):
                raise ValueError(
                    f"datum index cannot exceed total number of datums '{datums.max()}'"
                )

            # apply to mask
            mask_datums = np.isin(self._detailed_pairs[:, 0], datums)

        filtered_detailed_pairs = self._detailed_pairs[mask_datums]
        n_pairs = filtered_detailed_pairs.shape[0]
        # True in these masks means the annotation on that pair is removed
        mask_groundtruths = np.zeros(n_pairs, dtype=np.bool_)
        mask_predictions = np.zeros_like(mask_groundtruths)

        # filter by ground truth annotation ids
        if groundtruths is not None:
            # convert to indices
            if isinstance(groundtruths, list):
                groundtruths = np.array(
                    [
                        self.groundtruth_id_to_index[uid]
                        for uid in groundtruths
                    ],
                    dtype=np.int32,
                )

            # validate indices
            if groundtruths.size == 0:
                warnings.warn("filter removes all ground truths")
            elif groundtruths.min() < 0:
                raise ValueError(
                    f"groundtruth annotation index cannot be negative '{groundtruths.min()}'"
                )
            elif groundtruths.max() >= len(self.index_to_groundtruth_id):
                raise ValueError(
                    f"groundtruth annotation index cannot exceed total number of groundtruths '{groundtruths.max()}'"
                )

            # apply to mask
            mask_groundtruths[
                ~np.isin(
                    filtered_detailed_pairs[:, 1],
                    groundtruths,
                )
            ] = True

        # filter by prediction annotation ids
        if predictions is not None:
            # convert to indices
            if isinstance(predictions, list):
                predictions = np.array(
                    [self.prediction_id_to_index[uid] for uid in predictions],
                    dtype=np.int32,
                )

            # validate indices
            if predictions.size == 0:
                warnings.warn("filter removes all predictions")
            elif predictions.min() < 0:
                raise ValueError(
                    f"prediction annotation index cannot be negative '{predictions.min()}'"
                )
            elif predictions.max() >= len(self.index_to_prediction_id):
                raise ValueError(
                    f"prediction annotation index cannot exceed total number of predictions '{predictions.max()}'"
                )

            # apply to mask
            mask_predictions[
                ~np.isin(
                    filtered_detailed_pairs[:, 2],
                    predictions,
                )
            ] = True

        # filter by labels
        if labels is not None:
            # convert to indices
            if isinstance(labels, list):
                labels = np.array(
                    [self.label_to_index[label] for label in labels],
                    dtype=np.int32,
                )

            # validate indices
            if labels.size == 0:
                raise EmptyFilterError("filter removes all labels")
            elif labels.min() < 0:
                raise ValueError(
                    f"label index cannot be negative '{labels.min()}'"
                )
            elif labels.max() >= len(self.index_to_label):
                raise ValueError(
                    f"label index cannot exceed total number of labels '{labels.max()}'"
                )

            # apply to mask; -1 is the null label used by unmatched rows, so
            # keeping it avoids masking rows merely for lacking a label
            labels = np.concatenate([labels, np.array([-1])])
            mask_groundtruths[
                ~np.isin(filtered_detailed_pairs[:, 3], labels)
            ] = True
            mask_predictions[
                ~np.isin(filtered_detailed_pairs[:, 4], labels)
            ] = True

        filtered_detailed_pairs, _, _ = filter_cache(
            self._detailed_pairs,
            mask_datums=mask_datums,
            mask_ground_truths=mask_groundtruths,
            mask_predictions=mask_predictions,
            n_labels=len(self.index_to_label),
        )

        # Count distinct datums: the caller may pass the same id or index
        # more than once, which `datums.size` would double-count.
        number_of_datums = (
            np.unique(datums).size
            if datums is not None
            else np.unique(filtered_detailed_pairs[:, 0]).size
        )

        return Filter(
            mask_datums=mask_datums,
            mask_groundtruths=mask_groundtruths,
            mask_predictions=mask_predictions,
            metadata=Metadata.create(
                detailed_pairs=filtered_detailed_pairs,
                number_of_datums=number_of_datums,
                number_of_labels=len(self.index_to_label),
            ),
        )

    def filter(
        self, filter_: Filter
    ) -> tuple[NDArray[np.float64], NDArray[np.float64], NDArray[np.int32]]:
        """
        Performs filtering over the internal cache.

        Parameters
        ----------
        filter_ : Filter
            The filter parameterization.

        Returns
        -------
        NDArray[float64]
            Filtered detailed pairs.
        NDArray[float64]
            Filtered ranked pairs.
        NDArray[int32]
            Label metadata.
        """
        return filter_cache(
            detailed_pairs=self._detailed_pairs,
            mask_datums=filter_.mask_datums,
            mask_ground_truths=filter_.mask_groundtruths,
            mask_predictions=filter_.mask_predictions,
            n_labels=len(self.index_to_label),
        )

    def compute_precision_recall(
        self,
        iou_thresholds: list[float],
        score_thresholds: list[float],
        filter_: Filter | None = None,
    ) -> dict[MetricType, list[Metric]]:
        """
        Computes all metrics except for ConfusionMatrix

        Parameters
        ----------
        iou_thresholds : list[float]
            A list of IOU thresholds to compute metrics over.
        score_thresholds : list[float]
            A list of score thresholds to compute metrics over.
        filter_ : Filter, optional
            A collection of filter parameters and masks.

        Returns
        -------
        dict[MetricType, list]
            A dictionary mapping MetricType enumerations to lists of computed metrics.

        Raises
        ------
        ValueError
            If either threshold list is empty.
        """
        if not iou_thresholds:
            raise ValueError("At least one IOU threshold must be passed.")
        elif not score_thresholds:
            raise ValueError("At least one score threshold must be passed.")

        if filter_ is not None:
            _, ranked_pairs, label_metadata = self.filter(filter_=filter_)
        else:
            ranked_pairs = self._ranked_pairs
            label_metadata = self._label_metadata

        results = compute_precion_recall(
            ranked_pairs=ranked_pairs,
            label_metadata=label_metadata,
            iou_thresholds=np.array(iou_thresholds),
            score_thresholds=np.array(score_thresholds),
        )
        return unpack_precision_recall_into_metric_lists(
            results=results,
            label_metadata=label_metadata,
            iou_thresholds=iou_thresholds,
            score_thresholds=score_thresholds,
            index_to_label=self.index_to_label,
        )

    def compute_confusion_matrix(
        self,
        iou_thresholds: list[float],
        score_thresholds: list[float],
        filter_: Filter | None = None,
    ) -> list[Metric]:
        """
        Computes confusion matrices at various thresholds.

        Parameters
        ----------
        iou_thresholds : list[float]
            A list of IOU thresholds to compute metrics over.
        score_thresholds : list[float]
            A list of score thresholds to compute metrics over.
        filter_ : Filter, optional
            A collection of filter parameters and masks.

        Returns
        -------
        list[Metric]
            List of confusion matrices per threshold pair; empty when there
            are no detailed pairs to evaluate.

        Raises
        ------
        ValueError
            If either threshold list is empty.
        """
        if not iou_thresholds:
            raise ValueError("At least one IOU threshold must be passed.")
        elif not score_thresholds:
            raise ValueError("At least one score threshold must be passed.")

        if filter_ is not None:
            detailed_pairs, _, _ = self.filter(filter_=filter_)
        else:
            detailed_pairs = self._detailed_pairs

        if detailed_pairs.size == 0:
            return []

        results = compute_confusion_matrix(
            detailed_pairs=detailed_pairs,
            iou_thresholds=np.array(iou_thresholds),
            score_thresholds=np.array(score_thresholds),
        )
        return unpack_confusion_matrix_into_metric_list(
            results=results,
            detailed_pairs=detailed_pairs,
            iou_thresholds=iou_thresholds,
            score_thresholds=score_thresholds,
            index_to_datum_id=self.index_to_datum_id,
            index_to_groundtruth_id=self.index_to_groundtruth_id,
            index_to_prediction_id=self.index_to_prediction_id,
            index_to_label=self.index_to_label,
        )

    def evaluate(
        self,
        iou_thresholds: list[float] | None = None,
        score_thresholds: list[float] | None = None,
        filter_: Filter | None = None,
    ) -> dict[MetricType, list[Metric]]:
        """
        Computes all available metrics.

        Parameters
        ----------
        iou_thresholds : list[float], default=[0.1, 0.5, 0.75]
            A list of IOU thresholds to compute metrics over.
        score_thresholds : list[float], default=[0.5]
            A list of score thresholds to compute metrics over.
        filter_ : Filter, optional
            A collection of filter parameters and masks.

        Returns
        -------
        dict[MetricType, list[Metric]]
            Lists of metrics organized by metric type.
        """
        # Defaults are built per call rather than as shared mutable
        # default arguments, so downstream code can never mutate them.
        if iou_thresholds is None:
            iou_thresholds = [0.1, 0.5, 0.75]
        if score_thresholds is None:
            score_thresholds = [0.5]

        metrics = self.compute_precision_recall(
            iou_thresholds=iou_thresholds,
            score_thresholds=score_thresholds,
            filter_=filter_,
        )
        metrics[MetricType.ConfusionMatrix] = self.compute_confusion_matrix(
            iou_thresholds=iou_thresholds,
            score_thresholds=score_thresholds,
            filter_=filter_,
        )
        return metrics
|
|
505
|
-
|
|
506
|
-
|
|
507
|
-
class DataLoader:
    """
    Object Detection DataLoader

    Accumulates detections as per-datum arrays of detailed pairs and builds
    a ready-to-use `Evaluator` in `finalize`.
    """

    def __init__(self):
        self._evaluator = Evaluator()
        self.pairs: list[NDArray[np.float64]] = list()

    @staticmethod
    def _register(
        key: str,
        key_to_index: dict[str, int],
        index_to_key: list[str],
        cache_name: str,
    ) -> int:
        """
        Return the index of `key`, assigning the next free index if it is new.

        Raises
        ------
        InternalCacheError
            If the forward and reverse lookups have diverged in size.
        """
        if key not in key_to_index:
            if len(key_to_index) != len(index_to_key):
                raise InternalCacheError(f"{cache_name} cache size mismatch")
            key_to_index[key] = len(index_to_key)
            index_to_key.append(key)
        return key_to_index[key]

    def _add_datum(self, datum_id: str) -> int:
        """
        Helper function for adding a datum to the cache.

        Parameters
        ----------
        datum_id : str
            The datum identifier.

        Returns
        -------
        int
            The datum index.
        """
        return self._register(
            datum_id,
            self._evaluator.datum_id_to_index,
            self._evaluator.index_to_datum_id,
            "datum",
        )

    def _add_groundtruth(self, annotation_id: str) -> int:
        """
        Helper function for adding a ground truth annotation identifier to the cache.

        Parameters
        ----------
        annotation_id : str
            The ground truth annotation identifier.

        Returns
        -------
        int
            The ground truth annotation index.
        """
        return self._register(
            annotation_id,
            self._evaluator.groundtruth_id_to_index,
            self._evaluator.index_to_groundtruth_id,
            "ground truth",
        )

    def _add_prediction(self, annotation_id: str) -> int:
        """
        Helper function for adding a prediction annotation identifier to the cache.

        Parameters
        ----------
        annotation_id : str
            The prediction annotation identifier.

        Returns
        -------
        int
            The prediction annotation index.
        """
        return self._register(
            annotation_id,
            self._evaluator.prediction_id_to_index,
            self._evaluator.index_to_prediction_id,
            "prediction",
        )

    def _add_label(self, label: str) -> int:
        """
        Helper function for adding a label to the cache.

        Parameters
        ----------
        label : str
            The label associated with the annotation.

        Returns
        -------
        int
            Label index.
        """
        return self._register(
            label,
            self._evaluator.label_to_index,
            self._evaluator.index_to_label,
            "label",
        )

    @staticmethod
    def _pair_row(
        datum_idx: float,
        groundtruth_idx: float,
        prediction_idx: float,
        glabel_idx: float,
        plabel_idx: float,
        iou: float,
        score: float,
    ) -> NDArray[np.float64]:
        """Build one detailed-pair row in the cache's column order."""
        return np.array(
            [
                float(datum_idx),
                float(groundtruth_idx),
                float(prediction_idx),
                float(glabel_idx),
                float(plabel_idx),
                float(iou),
                float(score),
            ]
        )

    def _add_data(
        self,
        detections: list[Detection],
        detection_ious: list[NDArray[np.float64]],
        show_progress: bool = False,
    ):
        """
        Adds detections to the cache.

        Each datum contributes rows of the form
        ``[datum, groundtruth, prediction, gt label, pd label, iou, score]``
        where -1 marks a missing annotation, label or score:

        * ground truth overlapping no prediction: ``[d, g, -1, gl, -1, 0, -1]``
        * prediction overlapping no ground truth: ``[d, -1, p, -1, pl, 0, s]``
          (one row per predicted label)
        * overlapping pair: ``[d, g, p, gl, pl, iou, s]`` (one row per
          ground-truth label / predicted label combination)

        Parameters
        ----------
        detections : list[Detection]
            A list of Detection objects.
        detection_ious : list[NDArray[np.float64]]
            A list of arrays containing IOUs per detection, each shaped
            (number of predictions, number of ground truths).
        show_progress : bool, default=False
            Toggle for tqdm progress bar.
        """
        disable_tqdm = not show_progress
        for detection, ious in tqdm(
            zip(detections, detection_ious), disable=disable_tqdm
        ):
            # cache labels and annotation pairs
            pairs: list[NDArray[np.float64]] = []
            datum_idx = self._add_datum(detection.uid)
            if detection.groundtruths:
                for gidx, gann in enumerate(detection.groundtruths):
                    groundtruth_idx = self._add_groundtruth(gann.uid)
                    glabel_idx = self._add_label(gann.labels[0])
                    if (ious[:, gidx] < 1e-9).all():
                        pairs.append(
                            self._pair_row(
                                datum_idx,
                                groundtruth_idx,
                                -1,
                                glabel_idx,
                                -1,
                                0.0,
                                -1.0,
                            )
                        )
                    for pidx, pann in enumerate(detection.predictions):
                        prediction_idx = self._add_prediction(pann.uid)
                        # Whether a prediction overlaps no ground truth does
                        # not depend on `gidx`; emit that row only on the
                        # first ground-truth pass so it is not duplicated
                        # once per ground truth. Index/label registration
                        # order is unchanged since later passes only
                        # re-registered already-known keys.
                        if gidx == 0 and (ious[pidx, :] < 1e-9).all():
                            pairs.extend(
                                self._pair_row(
                                    datum_idx,
                                    -1,
                                    prediction_idx,
                                    -1,
                                    self._add_label(plabel),
                                    0.0,
                                    pscore,
                                )
                                for plabel, pscore in zip(
                                    pann.labels, pann.scores
                                )
                            )
                        if ious[pidx, gidx] >= 1e-9:
                            pairs.extend(
                                self._pair_row(
                                    datum_idx,
                                    groundtruth_idx,
                                    prediction_idx,
                                    self._add_label(glabel),
                                    self._add_label(plabel),
                                    ious[pidx, gidx],
                                    pscore,
                                )
                                for glabel in gann.labels
                                for plabel, pscore in zip(
                                    pann.labels, pann.scores
                                )
                            )
            elif detection.predictions:
                for pann in detection.predictions:
                    prediction_idx = self._add_prediction(pann.uid)
                    pairs.extend(
                        self._pair_row(
                            datum_idx,
                            -1,
                            prediction_idx,
                            -1,
                            self._add_label(plabel),
                            0.0,
                            pscore,
                        )
                        for plabel, pscore in zip(pann.labels, pann.scores)
                    )

            data = np.array(pairs)
            if data.size > 0:
                self.pairs.append(data)

    def add_bounding_boxes(
        self,
        detections: list[Detection[BoundingBox]],
        show_progress: bool = False,
    ):
        """
        Adds bounding box detections to the cache.

        Parameters
        ----------
        detections : list[Detection]
            A list of Detection objects.
        show_progress : bool, default=False
            Toggle for tqdm progress bar.
        """
        ious = [
            compute_bbox_iou(
                np.array(
                    [
                        [gt.extrema, pd.extrema]
                        for pd in detection.predictions
                        for gt in detection.groundtruths
                    ],
                    dtype=np.float64,
                )
            ).reshape(len(detection.predictions), len(detection.groundtruths))
            for detection in detections
        ]
        return self._add_data(
            detections=detections,
            detection_ious=ious,
            show_progress=show_progress,
        )

    def add_polygons(
        self,
        detections: list[Detection[Polygon]],
        show_progress: bool = False,
    ):
        """
        Adds polygon detections to the cache.

        Parameters
        ----------
        detections : list[Detection]
            A list of Detection objects.
        show_progress : bool, default=False
            Toggle for tqdm progress bar.
        """
        ious = [
            compute_polygon_iou(
                np.array(
                    [
                        [gt.shape, pd.shape]
                        for pd in detection.predictions
                        for gt in detection.groundtruths
                    ]
                )
            ).reshape(len(detection.predictions), len(detection.groundtruths))
            for detection in detections
        ]
        return self._add_data(
            detections=detections,
            detection_ious=ious,
            show_progress=show_progress,
        )

    def add_bitmasks(
        self,
        detections: list[Detection[Bitmask]],
        show_progress: bool = False,
    ):
        """
        Adds bitmask detections to the cache.

        Parameters
        ----------
        detections : list[Detection]
            A list of Detection objects.
        show_progress : bool, default=False
            Toggle for tqdm progress bar.
        """
        ious = [
            compute_bitmask_iou(
                np.array(
                    [
                        [gt.mask, pd.mask]
                        for pd in detection.predictions
                        for gt in detection.groundtruths
                    ]
                )
            ).reshape(len(detection.predictions), len(detection.groundtruths))
            for detection in detections
        ]
        return self._add_data(
            detections=detections,
            detection_ious=ious,
            show_progress=show_progress,
        )

    def finalize(self) -> Evaluator:
        """
        Performs data finalization and some preprocessing steps.

        Returns
        -------
        Evaluator
            A ready-to-use evaluator object.

        Raises
        ------
        EmptyEvaluatorError
            If no pairs were added.
        """
        if not self.pairs:
            raise EmptyEvaluatorError()

        n_labels = len(self._evaluator.index_to_label)
        n_datums = len(self._evaluator.index_to_datum_id)

        self._evaluator._detailed_pairs = np.concatenate(self.pairs, axis=0)
        if self._evaluator._detailed_pairs.size == 0:
            raise EmptyEvaluatorError()

        # order pairs by descending score, then descending iou
        # (np.lexsort treats the last key as the primary key)
        indices = np.lexsort(
            (
                -self._evaluator._detailed_pairs[:, 5],  # iou
                -self._evaluator._detailed_pairs[:, 6],  # score
            )
        )
        self._evaluator._detailed_pairs = self._evaluator._detailed_pairs[
            indices
        ]
        self._evaluator._label_metadata = compute_label_metadata(
            ids=self._evaluator._detailed_pairs[:, :5].astype(np.int32),
            n_labels=n_labels,
        )
        self._evaluator._ranked_pairs = rank_pairs(
            detailed_pairs=self._evaluator._detailed_pairs,
            label_metadata=self._evaluator._label_metadata,
        )
        self._evaluator._metadata = Metadata.create(
            detailed_pairs=self._evaluator._detailed_pairs,
            number_of_datums=n_datums,
            number_of_labels=n_labels,
        )
        return self._evaluator
|