valor-lite 0.35.0__py3-none-any.whl → 0.36.1__py3-none-any.whl
This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.
- valor_lite/classification/computation.py +147 -38
- valor_lite/classification/manager.py +221 -235
- valor_lite/classification/metric.py +5 -8
- valor_lite/classification/utilities.py +18 -14
- valor_lite/exceptions.py +15 -0
- valor_lite/object_detection/__init__.py +2 -1
- valor_lite/object_detection/computation.py +83 -10
- valor_lite/object_detection/manager.py +313 -315
- valor_lite/semantic_segmentation/__init__.py +3 -3
- valor_lite/semantic_segmentation/annotation.py +32 -103
- valor_lite/semantic_segmentation/benchmark.py +87 -1
- valor_lite/semantic_segmentation/computation.py +96 -14
- valor_lite/semantic_segmentation/manager.py +193 -221
- valor_lite/semantic_segmentation/utilities.py +3 -3
- {valor_lite-0.35.0.dist-info → valor_lite-0.36.1.dist-info}/METADATA +2 -2
- {valor_lite-0.35.0.dist-info → valor_lite-0.36.1.dist-info}/RECORD +18 -17
- {valor_lite-0.35.0.dist-info → valor_lite-0.36.1.dist-info}/WHEEL +1 -1
- {valor_lite-0.35.0.dist-info → valor_lite-0.36.1.dist-info}/top_level.txt +0 -0
|
@@ -1,9 +1,15 @@
|
|
|
1
1
|
import warnings
|
|
2
|
+
from dataclasses import asdict, dataclass
|
|
2
3
|
|
|
3
4
|
import numpy as np
|
|
4
5
|
from numpy.typing import NDArray
|
|
5
6
|
from tqdm import tqdm
|
|
6
7
|
|
|
8
|
+
from valor_lite.exceptions import (
|
|
9
|
+
EmptyEvaluatorException,
|
|
10
|
+
EmptyFilterException,
|
|
11
|
+
InternalCacheException,
|
|
12
|
+
)
|
|
7
13
|
from valor_lite.object_detection.annotation import (
|
|
8
14
|
Bitmask,
|
|
9
15
|
BoundingBox,
|
|
@@ -17,6 +23,7 @@ from valor_lite.object_detection.computation import (
|
|
|
17
23
|
compute_label_metadata,
|
|
18
24
|
compute_polygon_iou,
|
|
19
25
|
compute_precion_recall,
|
|
26
|
+
filter_cache,
|
|
20
27
|
rank_pairs,
|
|
21
28
|
)
|
|
22
29
|
from valor_lite.object_detection.metric import Metric, MetricType
|
|
@@ -46,6 +53,68 @@ filtered_metrics = evaluator.evaluate(iou_thresholds=[0.5], filter_mask=filter_m
|
|
|
46
53
|
"""
|
|
47
54
|
|
|
48
55
|
|
|
56
|
+
@dataclass
|
|
57
|
+
class Metadata:
|
|
58
|
+
number_of_datums: int = 0
|
|
59
|
+
number_of_ground_truths: int = 0
|
|
60
|
+
number_of_predictions: int = 0
|
|
61
|
+
number_of_labels: int = 0
|
|
62
|
+
|
|
63
|
+
@classmethod
|
|
64
|
+
def create(
|
|
65
|
+
cls,
|
|
66
|
+
detailed_pairs: NDArray[np.float64],
|
|
67
|
+
number_of_datums: int,
|
|
68
|
+
number_of_labels: int,
|
|
69
|
+
):
|
|
70
|
+
# count number of ground truths
|
|
71
|
+
mask_valid_gts = detailed_pairs[:, 1] >= 0
|
|
72
|
+
unique_ids = np.unique(
|
|
73
|
+
detailed_pairs[np.ix_(mask_valid_gts, (0, 1))], axis=0 # type: ignore - np.ix_ typing
|
|
74
|
+
)
|
|
75
|
+
number_of_ground_truths = int(unique_ids.shape[0])
|
|
76
|
+
|
|
77
|
+
# count number of predictions
|
|
78
|
+
mask_valid_pds = detailed_pairs[:, 2] >= 0
|
|
79
|
+
unique_ids = np.unique(
|
|
80
|
+
detailed_pairs[np.ix_(mask_valid_pds, (0, 2))], axis=0 # type: ignore - np.ix_ typing
|
|
81
|
+
)
|
|
82
|
+
number_of_predictions = int(unique_ids.shape[0])
|
|
83
|
+
|
|
84
|
+
return cls(
|
|
85
|
+
number_of_datums=number_of_datums,
|
|
86
|
+
number_of_ground_truths=number_of_ground_truths,
|
|
87
|
+
number_of_predictions=number_of_predictions,
|
|
88
|
+
number_of_labels=number_of_labels,
|
|
89
|
+
)
|
|
90
|
+
|
|
91
|
+
def to_dict(self) -> dict[str, int | bool]:
|
|
92
|
+
return asdict(self)
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
@dataclass
|
|
96
|
+
class Filter:
|
|
97
|
+
mask_datums: NDArray[np.bool_]
|
|
98
|
+
mask_groundtruths: NDArray[np.bool_]
|
|
99
|
+
mask_predictions: NDArray[np.bool_]
|
|
100
|
+
metadata: Metadata
|
|
101
|
+
|
|
102
|
+
def __post_init__(self):
|
|
103
|
+
# validate datums mask
|
|
104
|
+
if not self.mask_datums.any():
|
|
105
|
+
raise EmptyFilterException("filter removes all datums")
|
|
106
|
+
|
|
107
|
+
# validate annotation masks
|
|
108
|
+
no_gts = self.mask_groundtruths.all()
|
|
109
|
+
no_pds = self.mask_predictions.all()
|
|
110
|
+
if no_gts and no_pds:
|
|
111
|
+
raise EmptyFilterException("filter removes all annotations")
|
|
112
|
+
elif no_gts:
|
|
113
|
+
warnings.warn("filter removes all ground truths")
|
|
114
|
+
elif no_pds:
|
|
115
|
+
warnings.warn("filter removes all predictions")
|
|
116
|
+
|
|
117
|
+
|
|
49
118
|
class Evaluator:
|
|
50
119
|
"""
|
|
51
120
|
Object Detection Evaluator
|
|
@@ -67,80 +136,19 @@ class Evaluator:
|
|
|
67
136
|
# temporary cache
|
|
68
137
|
self._temp_cache: list[NDArray[np.float64]] | None = []
|
|
69
138
|
|
|
70
|
-
# cache
|
|
139
|
+
# internal cache
|
|
71
140
|
self._detailed_pairs = np.array([[]], dtype=np.float64)
|
|
72
141
|
self._ranked_pairs = np.array([[]], dtype=np.float64)
|
|
73
142
|
self._label_metadata: NDArray[np.int32] = np.array([[]])
|
|
74
|
-
|
|
75
|
-
# filter cache
|
|
76
|
-
self._filtered_detailed_pairs: NDArray[np.float64] | None = None
|
|
77
|
-
self._filtered_ranked_pairs: NDArray[np.float64] | None = None
|
|
78
|
-
self._filtered_label_metadata: NDArray[np.int32] | None = None
|
|
79
|
-
|
|
80
|
-
@property
|
|
81
|
-
def is_filtered(self) -> bool:
|
|
82
|
-
return self._filtered_detailed_pairs is not None
|
|
83
|
-
|
|
84
|
-
@property
|
|
85
|
-
def label_metadata(self) -> NDArray[np.int32]:
|
|
86
|
-
return (
|
|
87
|
-
self._filtered_label_metadata
|
|
88
|
-
if self._filtered_label_metadata is not None
|
|
89
|
-
else self._label_metadata
|
|
90
|
-
)
|
|
91
|
-
|
|
92
|
-
@property
|
|
93
|
-
def detailed_pairs(self) -> NDArray[np.float64]:
|
|
94
|
-
return (
|
|
95
|
-
self._filtered_detailed_pairs
|
|
96
|
-
if self._filtered_detailed_pairs is not None
|
|
97
|
-
else self._detailed_pairs
|
|
98
|
-
)
|
|
99
|
-
|
|
100
|
-
@property
|
|
101
|
-
def ranked_pairs(self) -> NDArray[np.float64]:
|
|
102
|
-
return (
|
|
103
|
-
self._filtered_ranked_pairs
|
|
104
|
-
if self._filtered_ranked_pairs is not None
|
|
105
|
-
else self._ranked_pairs
|
|
106
|
-
)
|
|
107
|
-
|
|
108
|
-
@property
|
|
109
|
-
def n_labels(self) -> int:
|
|
110
|
-
"""Returns the total number of unique labels."""
|
|
111
|
-
return len(self.index_to_label)
|
|
112
|
-
|
|
113
|
-
@property
|
|
114
|
-
def n_datums(self) -> int:
|
|
115
|
-
"""Returns the number of datums."""
|
|
116
|
-
return np.unique(self.detailed_pairs[:, 0]).size
|
|
117
|
-
|
|
118
|
-
@property
|
|
119
|
-
def n_groundtruths(self) -> int:
|
|
120
|
-
"""Returns the number of ground truth annotations."""
|
|
121
|
-
mask_valid_gts = self.detailed_pairs[:, 1] >= 0
|
|
122
|
-
unique_ids = np.unique(
|
|
123
|
-
self.detailed_pairs[np.ix_(mask_valid_gts, (0, 1))], axis=0 # type: ignore - np.ix_ typing
|
|
124
|
-
)
|
|
125
|
-
return int(unique_ids.shape[0])
|
|
126
|
-
|
|
127
|
-
@property
|
|
128
|
-
def n_predictions(self) -> int:
|
|
129
|
-
"""Returns the number of prediction annotations."""
|
|
130
|
-
mask_valid_pds = self.detailed_pairs[:, 2] >= 0
|
|
131
|
-
unique_ids = np.unique(
|
|
132
|
-
self.detailed_pairs[np.ix_(mask_valid_pds, (0, 2))], axis=0 # type: ignore - np.ix_ typing
|
|
133
|
-
)
|
|
134
|
-
return int(unique_ids.shape[0])
|
|
143
|
+
self._metadata = Metadata()
|
|
135
144
|
|
|
136
145
|
@property
|
|
137
146
|
def ignored_prediction_labels(self) -> list[str]:
|
|
138
147
|
"""
|
|
139
148
|
Prediction labels that are not present in the ground truth set.
|
|
140
149
|
"""
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
plabels = set(np.where(label_metadata[:, 1] > 0)[0])
|
|
150
|
+
glabels = set(np.where(self._label_metadata[:, 0] > 0)[0])
|
|
151
|
+
plabels = set(np.where(self._label_metadata[:, 1] > 0)[0])
|
|
144
152
|
return [
|
|
145
153
|
self.index_to_label[label_id] for label_id in (plabels - glabels)
|
|
146
154
|
]
|
|
@@ -150,31 +158,157 @@ class Evaluator:
|
|
|
150
158
|
"""
|
|
151
159
|
Ground truth labels that are not present in the prediction set.
|
|
152
160
|
"""
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
plabels = set(np.where(label_metadata[:, 1] > 0)[0])
|
|
161
|
+
glabels = set(np.where(self._label_metadata[:, 0] > 0)[0])
|
|
162
|
+
plabels = set(np.where(self._label_metadata[:, 1] > 0)[0])
|
|
156
163
|
return [
|
|
157
164
|
self.index_to_label[label_id] for label_id in (glabels - plabels)
|
|
158
165
|
]
|
|
159
166
|
|
|
160
167
|
@property
|
|
161
|
-
def metadata(self) ->
|
|
168
|
+
def metadata(self) -> Metadata:
|
|
162
169
|
"""
|
|
163
170
|
Evaluation metadata.
|
|
164
171
|
"""
|
|
165
|
-
return
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
172
|
+
return self._metadata
|
|
173
|
+
|
|
174
|
+
def create_filter(
|
|
175
|
+
self,
|
|
176
|
+
datum_ids: list[str] | None = None,
|
|
177
|
+
groundtruth_ids: list[str] | None = None,
|
|
178
|
+
prediction_ids: list[str] | None = None,
|
|
179
|
+
labels: list[str] | None = None,
|
|
180
|
+
) -> Filter:
|
|
181
|
+
"""
|
|
182
|
+
Creates a filter object.
|
|
183
|
+
|
|
184
|
+
Parameters
|
|
185
|
+
----------
|
|
186
|
+
datum_ids : list[str], optional
|
|
187
|
+
An optional list of string uids representing datums to keep.
|
|
188
|
+
groundtruth_ids : list[str], optional
|
|
189
|
+
An optional list of string uids representing ground truth annotations to keep.
|
|
190
|
+
prediction_ids : list[str], optional
|
|
191
|
+
An optional list of string uids representing prediction annotations to keep.
|
|
192
|
+
labels : list[str], optional
|
|
193
|
+
An optional list of labels to keep.
|
|
194
|
+
"""
|
|
195
|
+
mask_datums = np.ones(self._detailed_pairs.shape[0], dtype=np.bool_)
|
|
196
|
+
|
|
197
|
+
# filter datums
|
|
198
|
+
if datum_ids is not None:
|
|
199
|
+
if not datum_ids:
|
|
200
|
+
raise EmptyFilterException("filter removes all datums")
|
|
201
|
+
valid_datum_indices = np.array(
|
|
202
|
+
[self.datum_id_to_index[uid] for uid in datum_ids],
|
|
203
|
+
dtype=np.int32,
|
|
204
|
+
)
|
|
205
|
+
mask_datums = np.isin(
|
|
206
|
+
self._detailed_pairs[:, 0], valid_datum_indices
|
|
207
|
+
)
|
|
208
|
+
|
|
209
|
+
filtered_detailed_pairs = self._detailed_pairs[mask_datums]
|
|
210
|
+
n_pairs = self._detailed_pairs[mask_datums].shape[0]
|
|
211
|
+
mask_groundtruths = np.zeros(n_pairs, dtype=np.bool_)
|
|
212
|
+
mask_predictions = np.zeros_like(mask_groundtruths)
|
|
213
|
+
|
|
214
|
+
# filter by ground truth annotation ids
|
|
215
|
+
if groundtruth_ids is not None:
|
|
216
|
+
valid_groundtruth_indices = np.array(
|
|
217
|
+
[self.groundtruth_id_to_index[uid] for uid in groundtruth_ids],
|
|
218
|
+
dtype=np.int32,
|
|
219
|
+
)
|
|
220
|
+
mask_groundtruths[
|
|
221
|
+
~np.isin(
|
|
222
|
+
filtered_detailed_pairs[:, 1],
|
|
223
|
+
valid_groundtruth_indices,
|
|
224
|
+
)
|
|
225
|
+
] = True
|
|
226
|
+
|
|
227
|
+
# filter by prediction annotation ids
|
|
228
|
+
if prediction_ids is not None:
|
|
229
|
+
valid_prediction_indices = np.array(
|
|
230
|
+
[self.prediction_id_to_index[uid] for uid in prediction_ids],
|
|
231
|
+
dtype=np.int32,
|
|
232
|
+
)
|
|
233
|
+
mask_predictions[
|
|
234
|
+
~np.isin(
|
|
235
|
+
filtered_detailed_pairs[:, 2],
|
|
236
|
+
valid_prediction_indices,
|
|
237
|
+
)
|
|
238
|
+
] = True
|
|
239
|
+
|
|
240
|
+
# filter by labels
|
|
241
|
+
if labels is not None:
|
|
242
|
+
if not labels:
|
|
243
|
+
raise EmptyFilterException("filter removes all labels")
|
|
244
|
+
valid_label_indices = np.array(
|
|
245
|
+
[self.label_to_index[label] for label in labels] + [-1]
|
|
246
|
+
)
|
|
247
|
+
mask_groundtruths[
|
|
248
|
+
~np.isin(filtered_detailed_pairs[:, 3], valid_label_indices)
|
|
249
|
+
] = True
|
|
250
|
+
mask_predictions[
|
|
251
|
+
~np.isin(filtered_detailed_pairs[:, 4], valid_label_indices)
|
|
252
|
+
] = True
|
|
253
|
+
|
|
254
|
+
filtered_detailed_pairs, _, _ = filter_cache(
|
|
255
|
+
self._detailed_pairs,
|
|
256
|
+
mask_datums=mask_datums,
|
|
257
|
+
mask_ground_truths=mask_groundtruths,
|
|
258
|
+
mask_predictions=mask_predictions,
|
|
259
|
+
n_labels=len(self.index_to_label),
|
|
260
|
+
)
|
|
261
|
+
|
|
262
|
+
number_of_datums = (
|
|
263
|
+
len(datum_ids)
|
|
264
|
+
if datum_ids
|
|
265
|
+
else np.unique(filtered_detailed_pairs[:, 0]).size
|
|
266
|
+
)
|
|
267
|
+
|
|
268
|
+
return Filter(
|
|
269
|
+
mask_datums=mask_datums,
|
|
270
|
+
mask_groundtruths=mask_groundtruths,
|
|
271
|
+
mask_predictions=mask_predictions,
|
|
272
|
+
metadata=Metadata.create(
|
|
273
|
+
detailed_pairs=filtered_detailed_pairs,
|
|
274
|
+
number_of_datums=number_of_datums,
|
|
275
|
+
number_of_labels=len(self.index_to_label),
|
|
276
|
+
),
|
|
277
|
+
)
|
|
278
|
+
|
|
279
|
+
def filter(
|
|
280
|
+
self, filter_: Filter
|
|
281
|
+
) -> tuple[NDArray[np.float64], NDArray[np.float64], NDArray[np.int32],]:
|
|
282
|
+
"""
|
|
283
|
+
Performs filtering over the internal cache.
|
|
284
|
+
|
|
285
|
+
Parameters
|
|
286
|
+
----------
|
|
287
|
+
filter_ : Filter
|
|
288
|
+
The filter parameterization.
|
|
289
|
+
|
|
290
|
+
Returns
|
|
291
|
+
-------
|
|
292
|
+
NDArray[float64]
|
|
293
|
+
Filtered detailed pairs.
|
|
294
|
+
NDArray[float64]
|
|
295
|
+
Filtered ranked pairs.
|
|
296
|
+
NDArray[int32]
|
|
297
|
+
Label metadata.
|
|
298
|
+
"""
|
|
299
|
+
return filter_cache(
|
|
300
|
+
detailed_pairs=self._detailed_pairs,
|
|
301
|
+
mask_datums=filter_.mask_datums,
|
|
302
|
+
mask_ground_truths=filter_.mask_groundtruths,
|
|
303
|
+
mask_predictions=filter_.mask_predictions,
|
|
304
|
+
n_labels=len(self.index_to_label),
|
|
305
|
+
)
|
|
173
306
|
|
|
174
307
|
def compute_precision_recall(
|
|
175
308
|
self,
|
|
176
309
|
iou_thresholds: list[float],
|
|
177
310
|
score_thresholds: list[float],
|
|
311
|
+
filter_: Filter | None = None,
|
|
178
312
|
) -> dict[MetricType, list[Metric]]:
|
|
179
313
|
"""
|
|
180
314
|
Computes all metrics except for ConfusionMatrix
|
|
@@ -185,6 +319,8 @@ class Evaluator:
|
|
|
185
319
|
A list of IOU thresholds to compute metrics over.
|
|
186
320
|
score_thresholds : list[float]
|
|
187
321
|
A list of score thresholds to compute metrics over.
|
|
322
|
+
filter_ : Filter, optional
|
|
323
|
+
A collection of filter parameters and masks.
|
|
188
324
|
|
|
189
325
|
Returns
|
|
190
326
|
-------
|
|
@@ -195,15 +331,22 @@ class Evaluator:
|
|
|
195
331
|
raise ValueError("At least one IOU threshold must be passed.")
|
|
196
332
|
elif not score_thresholds:
|
|
197
333
|
raise ValueError("At least one score threshold must be passed.")
|
|
334
|
+
|
|
335
|
+
if filter_ is not None:
|
|
336
|
+
_, ranked_pairs, label_metadata = self.filter(filter_=filter_)
|
|
337
|
+
else:
|
|
338
|
+
ranked_pairs = self._ranked_pairs
|
|
339
|
+
label_metadata = self._label_metadata
|
|
340
|
+
|
|
198
341
|
results = compute_precion_recall(
|
|
199
|
-
ranked_pairs=
|
|
200
|
-
label_metadata=
|
|
342
|
+
ranked_pairs=ranked_pairs,
|
|
343
|
+
label_metadata=label_metadata,
|
|
201
344
|
iou_thresholds=np.array(iou_thresholds),
|
|
202
345
|
score_thresholds=np.array(score_thresholds),
|
|
203
346
|
)
|
|
204
347
|
return unpack_precision_recall_into_metric_lists(
|
|
205
348
|
results=results,
|
|
206
|
-
label_metadata=
|
|
349
|
+
label_metadata=label_metadata,
|
|
207
350
|
iou_thresholds=iou_thresholds,
|
|
208
351
|
score_thresholds=score_thresholds,
|
|
209
352
|
index_to_label=self.index_to_label,
|
|
@@ -213,6 +356,7 @@ class Evaluator:
|
|
|
213
356
|
self,
|
|
214
357
|
iou_thresholds: list[float],
|
|
215
358
|
score_thresholds: list[float],
|
|
359
|
+
filter_: Filter | None = None,
|
|
216
360
|
) -> list[Metric]:
|
|
217
361
|
"""
|
|
218
362
|
Computes confusion matrices at various thresholds.
|
|
@@ -223,6 +367,8 @@ class Evaluator:
|
|
|
223
367
|
A list of IOU thresholds to compute metrics over.
|
|
224
368
|
score_thresholds : list[float]
|
|
225
369
|
A list of score thresholds to compute metrics over.
|
|
370
|
+
filter_ : Filter, optional
|
|
371
|
+
A collection of filter parameters and masks.
|
|
226
372
|
|
|
227
373
|
Returns
|
|
228
374
|
-------
|
|
@@ -233,17 +379,23 @@ class Evaluator:
|
|
|
233
379
|
raise ValueError("At least one IOU threshold must be passed.")
|
|
234
380
|
elif not score_thresholds:
|
|
235
381
|
raise ValueError("At least one score threshold must be passed.")
|
|
236
|
-
|
|
237
|
-
|
|
382
|
+
|
|
383
|
+
if filter_ is not None:
|
|
384
|
+
detailed_pairs, _, _ = self.filter(filter_=filter_)
|
|
385
|
+
else:
|
|
386
|
+
detailed_pairs = self._detailed_pairs
|
|
387
|
+
|
|
388
|
+
if detailed_pairs.size == 0:
|
|
238
389
|
return []
|
|
390
|
+
|
|
239
391
|
results = compute_confusion_matrix(
|
|
240
|
-
detailed_pairs=
|
|
392
|
+
detailed_pairs=detailed_pairs,
|
|
241
393
|
iou_thresholds=np.array(iou_thresholds),
|
|
242
394
|
score_thresholds=np.array(score_thresholds),
|
|
243
395
|
)
|
|
244
396
|
return unpack_confusion_matrix_into_metric_list(
|
|
245
397
|
results=results,
|
|
246
|
-
detailed_pairs=
|
|
398
|
+
detailed_pairs=detailed_pairs,
|
|
247
399
|
iou_thresholds=iou_thresholds,
|
|
248
400
|
score_thresholds=score_thresholds,
|
|
249
401
|
index_to_datum_id=self.index_to_datum_id,
|
|
@@ -256,6 +408,7 @@ class Evaluator:
|
|
|
256
408
|
self,
|
|
257
409
|
iou_thresholds: list[float] = [0.1, 0.5, 0.75],
|
|
258
410
|
score_thresholds: list[float] = [0.5],
|
|
411
|
+
filter_: Filter | None = None,
|
|
259
412
|
) -> dict[MetricType, list[Metric]]:
|
|
260
413
|
"""
|
|
261
414
|
Computes all available metrics.
|
|
@@ -266,6 +419,8 @@ class Evaluator:
|
|
|
266
419
|
A list of IOU thresholds to compute metrics over.
|
|
267
420
|
score_thresholds : list[float], default=[0.5]
|
|
268
421
|
A list of score thresholds to compute metrics over.
|
|
422
|
+
filter_ : Filter, optional
|
|
423
|
+
A collection of filter parameters and masks.
|
|
269
424
|
|
|
270
425
|
Returns
|
|
271
426
|
-------
|
|
@@ -275,13 +430,25 @@ class Evaluator:
|
|
|
275
430
|
metrics = self.compute_precision_recall(
|
|
276
431
|
iou_thresholds=iou_thresholds,
|
|
277
432
|
score_thresholds=score_thresholds,
|
|
433
|
+
filter_=filter_,
|
|
278
434
|
)
|
|
279
435
|
metrics[MetricType.ConfusionMatrix] = self.compute_confusion_matrix(
|
|
280
436
|
iou_thresholds=iou_thresholds,
|
|
281
437
|
score_thresholds=score_thresholds,
|
|
438
|
+
filter_=filter_,
|
|
282
439
|
)
|
|
283
440
|
return metrics
|
|
284
441
|
|
|
442
|
+
|
|
443
|
+
class DataLoader:
|
|
444
|
+
"""
|
|
445
|
+
Object Detection DataLoader
|
|
446
|
+
"""
|
|
447
|
+
|
|
448
|
+
def __init__(self):
|
|
449
|
+
self._evaluator = Evaluator()
|
|
450
|
+
self.pairs: list[NDArray[np.float64]] = list()
|
|
451
|
+
|
|
285
452
|
def _add_datum(self, datum_id: str) -> int:
|
|
286
453
|
"""
|
|
287
454
|
Helper function for adding a datum to the cache.
|
|
@@ -296,13 +463,15 @@ class Evaluator:
|
|
|
296
463
|
int
|
|
297
464
|
The datum index.
|
|
298
465
|
"""
|
|
299
|
-
if datum_id not in self.datum_id_to_index:
|
|
300
|
-
if len(self.datum_id_to_index) != len(
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
self.
|
|
305
|
-
|
|
466
|
+
if datum_id not in self._evaluator.datum_id_to_index:
|
|
467
|
+
if len(self._evaluator.datum_id_to_index) != len(
|
|
468
|
+
self._evaluator.index_to_datum_id
|
|
469
|
+
):
|
|
470
|
+
raise InternalCacheException("datum cache size mismatch")
|
|
471
|
+
idx = len(self._evaluator.datum_id_to_index)
|
|
472
|
+
self._evaluator.datum_id_to_index[datum_id] = idx
|
|
473
|
+
self._evaluator.index_to_datum_id.append(datum_id)
|
|
474
|
+
return self._evaluator.datum_id_to_index[datum_id]
|
|
306
475
|
|
|
307
476
|
def _add_groundtruth(self, annotation_id: str) -> int:
|
|
308
477
|
"""
|
|
@@ -318,15 +487,17 @@ class Evaluator:
|
|
|
318
487
|
int
|
|
319
488
|
The ground truth annotation index.
|
|
320
489
|
"""
|
|
321
|
-
if annotation_id not in self.groundtruth_id_to_index:
|
|
322
|
-
if len(self.groundtruth_id_to_index) != len(
|
|
323
|
-
self.index_to_groundtruth_id
|
|
490
|
+
if annotation_id not in self._evaluator.groundtruth_id_to_index:
|
|
491
|
+
if len(self._evaluator.groundtruth_id_to_index) != len(
|
|
492
|
+
self._evaluator.index_to_groundtruth_id
|
|
324
493
|
):
|
|
325
|
-
raise
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
self.
|
|
329
|
-
|
|
494
|
+
raise InternalCacheException(
|
|
495
|
+
"ground truth cache size mismatch"
|
|
496
|
+
)
|
|
497
|
+
idx = len(self._evaluator.groundtruth_id_to_index)
|
|
498
|
+
self._evaluator.groundtruth_id_to_index[annotation_id] = idx
|
|
499
|
+
self._evaluator.index_to_groundtruth_id.append(annotation_id)
|
|
500
|
+
return self._evaluator.groundtruth_id_to_index[annotation_id]
|
|
330
501
|
|
|
331
502
|
def _add_prediction(self, annotation_id: str) -> int:
|
|
332
503
|
"""
|
|
@@ -342,15 +513,15 @@ class Evaluator:
|
|
|
342
513
|
int
|
|
343
514
|
The prediction annotation index.
|
|
344
515
|
"""
|
|
345
|
-
if annotation_id not in self.prediction_id_to_index:
|
|
346
|
-
if len(self.prediction_id_to_index) != len(
|
|
347
|
-
self.index_to_prediction_id
|
|
516
|
+
if annotation_id not in self._evaluator.prediction_id_to_index:
|
|
517
|
+
if len(self._evaluator.prediction_id_to_index) != len(
|
|
518
|
+
self._evaluator.index_to_prediction_id
|
|
348
519
|
):
|
|
349
|
-
raise
|
|
350
|
-
idx = len(self.prediction_id_to_index)
|
|
351
|
-
self.prediction_id_to_index[annotation_id] = idx
|
|
352
|
-
self.index_to_prediction_id.append(annotation_id)
|
|
353
|
-
return self.prediction_id_to_index[annotation_id]
|
|
520
|
+
raise InternalCacheException("prediction cache size mismatch")
|
|
521
|
+
idx = len(self._evaluator.prediction_id_to_index)
|
|
522
|
+
self._evaluator.prediction_id_to_index[annotation_id] = idx
|
|
523
|
+
self._evaluator.index_to_prediction_id.append(annotation_id)
|
|
524
|
+
return self._evaluator.prediction_id_to_index[annotation_id]
|
|
354
525
|
|
|
355
526
|
def _add_label(self, label: str) -> int:
|
|
356
527
|
"""
|
|
@@ -366,14 +537,16 @@ class Evaluator:
|
|
|
366
537
|
int
|
|
367
538
|
Label index.
|
|
368
539
|
"""
|
|
369
|
-
label_id = len(self.index_to_label)
|
|
370
|
-
if label not in self.label_to_index:
|
|
371
|
-
if len(self.label_to_index) != len(
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
540
|
+
label_id = len(self._evaluator.index_to_label)
|
|
541
|
+
if label not in self._evaluator.label_to_index:
|
|
542
|
+
if len(self._evaluator.label_to_index) != len(
|
|
543
|
+
self._evaluator.index_to_label
|
|
544
|
+
):
|
|
545
|
+
raise InternalCacheException("label cache size mismatch")
|
|
546
|
+
self._evaluator.label_to_index[label] = label_id
|
|
547
|
+
self._evaluator.index_to_label.append(label)
|
|
375
548
|
label_id += 1
|
|
376
|
-
return self.label_to_index[label]
|
|
549
|
+
return self._evaluator.label_to_index[label]
|
|
377
550
|
|
|
378
551
|
def _add_data(
|
|
379
552
|
self,
|
|
@@ -483,13 +656,7 @@ class Evaluator:
|
|
|
483
656
|
|
|
484
657
|
data = np.array(pairs)
|
|
485
658
|
if data.size > 0:
|
|
486
|
-
|
|
487
|
-
self.clear_filter()
|
|
488
|
-
if self._temp_cache is None:
|
|
489
|
-
raise RuntimeError(
|
|
490
|
-
"cannot add data as evaluator has already been finalized"
|
|
491
|
-
)
|
|
492
|
-
self._temp_cache.append(data)
|
|
659
|
+
self.pairs.append(data)
|
|
493
660
|
|
|
494
661
|
def add_bounding_boxes(
|
|
495
662
|
self,
|
|
@@ -591,7 +758,7 @@ class Evaluator:
|
|
|
591
758
|
show_progress=show_progress,
|
|
592
759
|
)
|
|
593
760
|
|
|
594
|
-
def finalize(self):
|
|
761
|
+
def finalize(self) -> Evaluator:
|
|
595
762
|
"""
|
|
596
763
|
Performs data finalization and some preprocessing steps.
|
|
597
764
|
|
|
@@ -600,206 +767,37 @@ class Evaluator:
|
|
|
600
767
|
Evaluator
|
|
601
768
|
A ready-to-use evaluator object.
|
|
602
769
|
"""
|
|
603
|
-
if self.
|
|
604
|
-
|
|
605
|
-
|
|
606
|
-
|
|
607
|
-
|
|
608
|
-
|
|
609
|
-
|
|
610
|
-
|
|
611
|
-
|
|
612
|
-
else:
|
|
613
|
-
self._detailed_pairs = np.concatenate(self._temp_cache, axis=0)
|
|
614
|
-
self._temp_cache = None
|
|
770
|
+
if not self.pairs:
|
|
771
|
+
raise EmptyEvaluatorException()
|
|
772
|
+
|
|
773
|
+
n_labels = len(self._evaluator.index_to_label)
|
|
774
|
+
n_datums = len(self._evaluator.index_to_datum_id)
|
|
775
|
+
|
|
776
|
+
self._evaluator._detailed_pairs = np.concatenate(self.pairs, axis=0)
|
|
777
|
+
if self._evaluator._detailed_pairs.size == 0:
|
|
778
|
+
raise EmptyEvaluatorException()
|
|
615
779
|
|
|
616
780
|
# order pairs by descending score, iou
|
|
617
781
|
indices = np.lexsort(
|
|
618
782
|
(
|
|
619
|
-
-self._detailed_pairs[:, 5], # iou
|
|
620
|
-
-self._detailed_pairs[:, 6], # score
|
|
783
|
+
-self._evaluator._detailed_pairs[:, 5], # iou
|
|
784
|
+
-self._evaluator._detailed_pairs[:, 6], # score
|
|
621
785
|
)
|
|
622
786
|
)
|
|
623
|
-
self._detailed_pairs = self._detailed_pairs[
|
|
624
|
-
|
|
625
|
-
ids=self._detailed_pairs[:, :5].astype(np.int32),
|
|
626
|
-
n_labels=self.n_labels,
|
|
627
|
-
)
|
|
628
|
-
self._ranked_pairs = rank_pairs(
|
|
629
|
-
detailed_pairs=self.detailed_pairs,
|
|
630
|
-
label_metadata=self._label_metadata,
|
|
631
|
-
)
|
|
632
|
-
return self
|
|
633
|
-
|
|
634
|
-
def apply_filter(
|
|
635
|
-
self,
|
|
636
|
-
datum_ids: list[str] | None = None,
|
|
637
|
-
groundtruth_ids: list[str] | None = None,
|
|
638
|
-
prediction_ids: list[str] | None = None,
|
|
639
|
-
labels: list[str] | None = None,
|
|
640
|
-
):
|
|
641
|
-
"""
|
|
642
|
-
Apply a filter on the evaluator.
|
|
643
|
-
|
|
644
|
-
Can be reset by calling 'clear_filter'.
|
|
645
|
-
|
|
646
|
-
Parameters
|
|
647
|
-
----------
|
|
648
|
-
datum_uids : list[str], optional
|
|
649
|
-
An optional list of string uids representing datums.
|
|
650
|
-
groundtruth_ids : list[str], optional
|
|
651
|
-
An optional list of string uids representing ground truth annotations.
|
|
652
|
-
prediction_ids : list[str], optional
|
|
653
|
-
An optional list of string uids representing prediction annotations.
|
|
654
|
-
labels : list[str], optional
|
|
655
|
-
An optional list of labels.
|
|
656
|
-
"""
|
|
657
|
-
self._filtered_detailed_pairs = self._detailed_pairs.copy()
|
|
658
|
-
self._filtered_ranked_pairs = np.array([], dtype=np.float64)
|
|
659
|
-
self._filtered_label_metadata = np.zeros(
|
|
660
|
-
(self.n_labels, 2), dtype=np.int32
|
|
661
|
-
)
|
|
662
|
-
|
|
663
|
-
valid_datum_indices = None
|
|
664
|
-
if datum_ids is not None:
|
|
665
|
-
if not datum_ids:
|
|
666
|
-
self._filtered_detailed_pairs = np.array([], dtype=np.float64)
|
|
667
|
-
warnings.warn("no valid filtered pairs")
|
|
668
|
-
return
|
|
669
|
-
valid_datum_indices = np.array(
|
|
670
|
-
[self.datum_id_to_index[uid] for uid in datum_ids],
|
|
671
|
-
dtype=np.int32,
|
|
672
|
-
)
|
|
673
|
-
|
|
674
|
-
valid_groundtruth_indices = None
|
|
675
|
-
if groundtruth_ids is not None:
|
|
676
|
-
valid_groundtruth_indices = np.array(
|
|
677
|
-
[self.groundtruth_id_to_index[uid] for uid in groundtruth_ids],
|
|
678
|
-
dtype=np.int32,
|
|
679
|
-
)
|
|
680
|
-
|
|
681
|
-
valid_prediction_indices = None
|
|
682
|
-
if prediction_ids is not None:
|
|
683
|
-
valid_prediction_indices = np.array(
|
|
684
|
-
[self.prediction_id_to_index[uid] for uid in prediction_ids],
|
|
685
|
-
dtype=np.int32,
|
|
686
|
-
)
|
|
687
|
-
|
|
688
|
-
valid_label_indices = None
|
|
689
|
-
if labels is not None:
|
|
690
|
-
if not labels:
|
|
691
|
-
self._filtered_detailed_pairs = np.array([], dtype=np.float64)
|
|
692
|
-
warnings.warn("no valid filtered pairs")
|
|
693
|
-
return
|
|
694
|
-
valid_label_indices = np.array(
|
|
695
|
-
[self.label_to_index[label] for label in labels] + [-1]
|
|
696
|
-
)
|
|
697
|
-
|
|
698
|
-
# filter datums
|
|
699
|
-
if valid_datum_indices is not None:
|
|
700
|
-
mask_valid_datums = np.isin(
|
|
701
|
-
self._filtered_detailed_pairs[:, 0], valid_datum_indices
|
|
702
|
-
)
|
|
703
|
-
self._filtered_detailed_pairs = self._filtered_detailed_pairs[
|
|
704
|
-
mask_valid_datums
|
|
705
|
-
]
|
|
706
|
-
|
|
707
|
-
n_rows = self._filtered_detailed_pairs.shape[0]
|
|
708
|
-
mask_invalid_groundtruths = np.zeros(n_rows, dtype=np.bool_)
|
|
709
|
-
mask_invalid_predictions = np.zeros_like(mask_invalid_groundtruths)
|
|
710
|
-
|
|
711
|
-
# filter ground truth annotations
|
|
712
|
-
if valid_groundtruth_indices is not None:
|
|
713
|
-
mask_invalid_groundtruths[
|
|
714
|
-
~np.isin(
|
|
715
|
-
self._filtered_detailed_pairs[:, 1],
|
|
716
|
-
valid_groundtruth_indices,
|
|
717
|
-
)
|
|
718
|
-
] = True
|
|
719
|
-
|
|
720
|
-
# filter prediction annotations
|
|
721
|
-
if valid_prediction_indices is not None:
|
|
722
|
-
mask_invalid_predictions[
|
|
723
|
-
~np.isin(
|
|
724
|
-
self._filtered_detailed_pairs[:, 2],
|
|
725
|
-
valid_prediction_indices,
|
|
726
|
-
)
|
|
727
|
-
] = True
|
|
728
|
-
|
|
729
|
-
# filter labels
|
|
730
|
-
if valid_label_indices is not None:
|
|
731
|
-
mask_invalid_groundtruths[
|
|
732
|
-
~np.isin(
|
|
733
|
-
self._filtered_detailed_pairs[:, 3], valid_label_indices
|
|
734
|
-
)
|
|
735
|
-
] = True
|
|
736
|
-
mask_invalid_predictions[
|
|
737
|
-
~np.isin(
|
|
738
|
-
self._filtered_detailed_pairs[:, 4], valid_label_indices
|
|
739
|
-
)
|
|
740
|
-
] = True
|
|
741
|
-
|
|
742
|
-
# filter cache
|
|
743
|
-
if mask_invalid_groundtruths.any():
|
|
744
|
-
invalid_groundtruth_indices = np.where(mask_invalid_groundtruths)[
|
|
745
|
-
0
|
|
746
|
-
]
|
|
747
|
-
self._filtered_detailed_pairs[
|
|
748
|
-
invalid_groundtruth_indices[:, None], (1, 3, 5)
|
|
749
|
-
] = np.array([[-1, -1, 0]])
|
|
750
|
-
|
|
751
|
-
if mask_invalid_predictions.any():
|
|
752
|
-
invalid_prediction_indices = np.where(mask_invalid_predictions)[0]
|
|
753
|
-
self._filtered_detailed_pairs[
|
|
754
|
-
invalid_prediction_indices[:, None], (2, 4, 5, 6)
|
|
755
|
-
] = np.array([[-1, -1, 0, -1]])
|
|
756
|
-
|
|
757
|
-
# filter null pairs
|
|
758
|
-
mask_null_pairs = np.all(
|
|
759
|
-
np.isclose(
|
|
760
|
-
self._filtered_detailed_pairs[:, 1:5],
|
|
761
|
-
np.array([-1.0, -1.0, -1.0, -1.0]),
|
|
762
|
-
),
|
|
763
|
-
axis=1,
|
|
764
|
-
)
|
|
765
|
-
self._filtered_detailed_pairs = self._filtered_detailed_pairs[
|
|
766
|
-
~mask_null_pairs
|
|
787
|
+
self._evaluator._detailed_pairs = self._evaluator._detailed_pairs[
|
|
788
|
+
indices
|
|
767
789
|
]
|
|
768
|
-
|
|
769
|
-
|
|
770
|
-
|
|
771
|
-
self._label_metadata = np.zeros((self.n_labels, 2), dtype=np.int32)
|
|
772
|
-
warnings.warn("no valid filtered pairs")
|
|
773
|
-
return
|
|
774
|
-
|
|
775
|
-
# sorts by score, iou with ground truth id as a tie-breaker
|
|
776
|
-
indices = np.lexsort(
|
|
777
|
-
(
|
|
778
|
-
self._filtered_detailed_pairs[:, 1], # ground truth id
|
|
779
|
-
-self._filtered_detailed_pairs[:, 5], # iou
|
|
780
|
-
-self._filtered_detailed_pairs[:, 6], # score
|
|
781
|
-
)
|
|
790
|
+
self._evaluator._label_metadata = compute_label_metadata(
|
|
791
|
+
ids=self._evaluator._detailed_pairs[:, :5].astype(np.int32),
|
|
792
|
+
n_labels=n_labels,
|
|
782
793
|
)
|
|
783
|
-
self.
|
|
784
|
-
|
|
785
|
-
|
|
786
|
-
n_labels=self.n_labels,
|
|
794
|
+
self._evaluator._ranked_pairs = rank_pairs(
|
|
795
|
+
detailed_pairs=self._evaluator._detailed_pairs,
|
|
796
|
+
label_metadata=self._evaluator._label_metadata,
|
|
787
797
|
)
|
|
788
|
-
self.
|
|
789
|
-
detailed_pairs=self.
|
|
790
|
-
|
|
798
|
+
self._evaluator._metadata = Metadata.create(
|
|
799
|
+
detailed_pairs=self._evaluator._detailed_pairs,
|
|
800
|
+
number_of_datums=n_datums,
|
|
801
|
+
number_of_labels=n_labels,
|
|
791
802
|
)
|
|
792
|
-
|
|
793
|
-
def clear_filter(self):
|
|
794
|
-
"""Removes a filter if one exists."""
|
|
795
|
-
self._filtered_detailed_pairs = None
|
|
796
|
-
self._filtered_ranked_pairs = None
|
|
797
|
-
self._filtered_label_metadata = None
|
|
798
|
-
|
|
799
|
-
|
|
800
|
-
class DataLoader(Evaluator):
|
|
801
|
-
"""
|
|
802
|
-
Used for backwards compatibility as the Evaluator now handles ingestion.
|
|
803
|
-
"""
|
|
804
|
-
|
|
805
|
-
pass
|
|
803
|
+
return self._evaluator
|