valor-lite 0.34.3__py3-none-any.whl → 0.36.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of valor-lite might be problematic. Click here for more details.
- valor_lite/classification/computation.py +147 -38
- valor_lite/classification/manager.py +229 -236
- valor_lite/classification/metric.py +5 -8
- valor_lite/classification/utilities.py +18 -14
- valor_lite/object_detection/__init__.py +2 -15
- valor_lite/object_detection/annotation.py +24 -48
- valor_lite/object_detection/computation.py +324 -384
- valor_lite/object_detection/manager.py +549 -456
- valor_lite/object_detection/metric.py +16 -34
- valor_lite/object_detection/utilities.py +134 -305
- valor_lite/semantic_segmentation/__init__.py +3 -3
- valor_lite/semantic_segmentation/annotation.py +32 -103
- valor_lite/semantic_segmentation/benchmark.py +87 -1
- valor_lite/semantic_segmentation/computation.py +96 -14
- valor_lite/semantic_segmentation/manager.py +199 -222
- valor_lite/semantic_segmentation/utilities.py +3 -3
- {valor_lite-0.34.3.dist-info → valor_lite-0.36.0.dist-info}/METADATA +2 -2
- {valor_lite-0.34.3.dist-info → valor_lite-0.36.0.dist-info}/RECORD +20 -20
- {valor_lite-0.34.3.dist-info → valor_lite-0.36.0.dist-info}/WHEEL +1 -1
- {valor_lite-0.34.3.dist-info → valor_lite-0.36.0.dist-info}/top_level.txt +0 -0
|
@@ -1,18 +1,25 @@
|
|
|
1
|
-
|
|
2
|
-
from dataclasses import dataclass
|
|
1
|
+
import warnings
|
|
2
|
+
from dataclasses import asdict, dataclass
|
|
3
3
|
|
|
4
4
|
import numpy as np
|
|
5
5
|
from numpy.typing import NDArray
|
|
6
6
|
from tqdm import tqdm
|
|
7
7
|
|
|
8
|
-
from valor_lite.object_detection.annotation import
|
|
8
|
+
from valor_lite.object_detection.annotation import (
|
|
9
|
+
Bitmask,
|
|
10
|
+
BoundingBox,
|
|
11
|
+
Detection,
|
|
12
|
+
Polygon,
|
|
13
|
+
)
|
|
9
14
|
from valor_lite.object_detection.computation import (
|
|
10
15
|
compute_bbox_iou,
|
|
11
16
|
compute_bitmask_iou,
|
|
12
17
|
compute_confusion_matrix,
|
|
18
|
+
compute_label_metadata,
|
|
13
19
|
compute_polygon_iou,
|
|
14
20
|
compute_precion_recall,
|
|
15
|
-
|
|
21
|
+
filter_cache,
|
|
22
|
+
rank_pairs,
|
|
16
23
|
)
|
|
17
24
|
from valor_lite.object_detection.metric import Metric, MetricType
|
|
18
25
|
from valor_lite.object_detection.utilities import (
|
|
@@ -41,11 +48,51 @@ filtered_metrics = evaluator.evaluate(iou_thresholds=[0.5], filter_mask=filter_m
|
|
|
41
48
|
"""
|
|
42
49
|
|
|
43
50
|
|
|
51
|
+
@dataclass
|
|
52
|
+
class Metadata:
|
|
53
|
+
number_of_datums: int = 0
|
|
54
|
+
number_of_ground_truths: int = 0
|
|
55
|
+
number_of_predictions: int = 0
|
|
56
|
+
number_of_labels: int = 0
|
|
57
|
+
|
|
58
|
+
@classmethod
|
|
59
|
+
def create(
|
|
60
|
+
cls,
|
|
61
|
+
detailed_pairs: NDArray[np.float64],
|
|
62
|
+
number_of_datums: int,
|
|
63
|
+
number_of_labels: int,
|
|
64
|
+
):
|
|
65
|
+
# count number of ground truths
|
|
66
|
+
mask_valid_gts = detailed_pairs[:, 1] >= 0
|
|
67
|
+
unique_ids = np.unique(
|
|
68
|
+
detailed_pairs[np.ix_(mask_valid_gts, (0, 1))], axis=0 # type: ignore - np.ix_ typing
|
|
69
|
+
)
|
|
70
|
+
number_of_ground_truths = int(unique_ids.shape[0])
|
|
71
|
+
|
|
72
|
+
# count number of predictions
|
|
73
|
+
mask_valid_pds = detailed_pairs[:, 2] >= 0
|
|
74
|
+
unique_ids = np.unique(
|
|
75
|
+
detailed_pairs[np.ix_(mask_valid_pds, (0, 2))], axis=0 # type: ignore - np.ix_ typing
|
|
76
|
+
)
|
|
77
|
+
number_of_predictions = int(unique_ids.shape[0])
|
|
78
|
+
|
|
79
|
+
return cls(
|
|
80
|
+
number_of_datums=number_of_datums,
|
|
81
|
+
number_of_ground_truths=number_of_ground_truths,
|
|
82
|
+
number_of_predictions=number_of_predictions,
|
|
83
|
+
number_of_labels=number_of_labels,
|
|
84
|
+
)
|
|
85
|
+
|
|
86
|
+
def to_dict(self) -> dict[str, int | bool]:
|
|
87
|
+
return asdict(self)
|
|
88
|
+
|
|
89
|
+
|
|
44
90
|
@dataclass
|
|
45
91
|
class Filter:
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
92
|
+
mask_datums: NDArray[np.bool_]
|
|
93
|
+
mask_groundtruths: NDArray[np.bool_]
|
|
94
|
+
mask_predictions: NDArray[np.bool_]
|
|
95
|
+
metadata: Metadata
|
|
49
96
|
|
|
50
97
|
|
|
51
98
|
class Evaluator:
|
|
@@ -55,29 +102,25 @@ class Evaluator:
|
|
|
55
102
|
|
|
56
103
|
def __init__(self):
|
|
57
104
|
|
|
58
|
-
#
|
|
59
|
-
self.
|
|
60
|
-
self.
|
|
61
|
-
self.
|
|
62
|
-
self.
|
|
105
|
+
# external reference
|
|
106
|
+
self.datum_id_to_index: dict[str, int] = {}
|
|
107
|
+
self.groundtruth_id_to_index: dict[str, int] = {}
|
|
108
|
+
self.prediction_id_to_index: dict[str, int] = {}
|
|
109
|
+
self.label_to_index: dict[str, int] = {}
|
|
63
110
|
|
|
64
|
-
|
|
65
|
-
self.
|
|
66
|
-
self.
|
|
111
|
+
self.index_to_datum_id: list[str] = []
|
|
112
|
+
self.index_to_groundtruth_id: list[str] = []
|
|
113
|
+
self.index_to_prediction_id: list[str] = []
|
|
114
|
+
self.index_to_label: list[str] = []
|
|
67
115
|
|
|
68
|
-
#
|
|
69
|
-
self.
|
|
70
|
-
self.prediction_examples: dict[int, NDArray[np.float16]] = dict()
|
|
116
|
+
# temporary cache
|
|
117
|
+
self._temp_cache: list[NDArray[np.float64]] | None = []
|
|
71
118
|
|
|
72
|
-
#
|
|
73
|
-
self.
|
|
74
|
-
self.
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
self._detailed_pairs: NDArray[np.float64] = np.array([])
|
|
78
|
-
self._ranked_pairs: NDArray[np.float64] = np.array([])
|
|
79
|
-
self._label_metadata: NDArray[np.int32] = np.array([])
|
|
80
|
-
self._label_metadata_per_datum: NDArray[np.int32] = np.array([])
|
|
119
|
+
# internal cache
|
|
120
|
+
self._detailed_pairs = np.array([[]], dtype=np.float64)
|
|
121
|
+
self._ranked_pairs = np.array([[]], dtype=np.float64)
|
|
122
|
+
self._label_metadata: NDArray[np.int32] = np.array([[]])
|
|
123
|
+
self._metadata = Metadata()
|
|
81
124
|
|
|
82
125
|
@property
|
|
83
126
|
def ignored_prediction_labels(self) -> list[str]:
|
|
@@ -102,273 +145,81 @@ class Evaluator:
|
|
|
102
145
|
]
|
|
103
146
|
|
|
104
147
|
@property
|
|
105
|
-
def metadata(self) ->
|
|
148
|
+
def metadata(self) -> Metadata:
|
|
106
149
|
"""
|
|
107
150
|
Evaluation metadata.
|
|
108
151
|
"""
|
|
109
|
-
return
|
|
110
|
-
"n_datums": self.n_datums,
|
|
111
|
-
"n_groundtruths": self.n_groundtruths,
|
|
112
|
-
"n_predictions": self.n_predictions,
|
|
113
|
-
"n_labels": self.n_labels,
|
|
114
|
-
"ignored_prediction_labels": self.ignored_prediction_labels,
|
|
115
|
-
"missing_prediction_labels": self.missing_prediction_labels,
|
|
116
|
-
}
|
|
117
|
-
|
|
118
|
-
def create_filter(
|
|
119
|
-
self,
|
|
120
|
-
datum_uids: list[str] | NDArray[np.int32] | None = None,
|
|
121
|
-
labels: list[str] | NDArray[np.int32] | None = None,
|
|
122
|
-
) -> Filter:
|
|
123
|
-
"""
|
|
124
|
-
Creates a filter that can be passed to an evaluation.
|
|
152
|
+
return self._metadata
|
|
125
153
|
|
|
126
|
-
|
|
127
|
-
----------
|
|
128
|
-
datum_uids : list[str] | NDArray[np.int32], optional
|
|
129
|
-
An optional list of string uids or a numpy array of uid indices.
|
|
130
|
-
labels : list[str] | NDArray[np.int32], optional
|
|
131
|
-
An optional list of labels or a numpy array of label indices.
|
|
132
|
-
|
|
133
|
-
Returns
|
|
134
|
-
-------
|
|
135
|
-
Filter
|
|
136
|
-
A filter object that can be passed to the `evaluate` method.
|
|
154
|
+
def _add_datum(self, datum_id: str) -> int:
|
|
137
155
|
"""
|
|
138
|
-
|
|
139
|
-
n_datums = self._label_metadata_per_datum.shape[1]
|
|
140
|
-
n_labels = self._label_metadata_per_datum.shape[2]
|
|
141
|
-
|
|
142
|
-
mask_ranked = np.ones((self._ranked_pairs.shape[0], 1), dtype=np.bool_)
|
|
143
|
-
mask_detailed = np.ones(
|
|
144
|
-
(self._detailed_pairs.shape[0], 1), dtype=np.bool_
|
|
145
|
-
)
|
|
146
|
-
mask_datums = np.ones(n_datums, dtype=np.bool_)
|
|
147
|
-
mask_labels = np.ones(n_labels, dtype=np.bool_)
|
|
148
|
-
|
|
149
|
-
if datum_uids is not None:
|
|
150
|
-
if isinstance(datum_uids, list):
|
|
151
|
-
datum_uids = np.array(
|
|
152
|
-
[self.uid_to_index[uid] for uid in datum_uids],
|
|
153
|
-
dtype=np.int32,
|
|
154
|
-
)
|
|
155
|
-
mask_ranked[
|
|
156
|
-
~np.isin(self._ranked_pairs[:, 0].astype(int), datum_uids)
|
|
157
|
-
] = False
|
|
158
|
-
mask_detailed[
|
|
159
|
-
~np.isin(self._detailed_pairs[:, 0].astype(int), datum_uids)
|
|
160
|
-
] = False
|
|
161
|
-
mask_datums[~np.isin(np.arange(n_datums), datum_uids)] = False
|
|
162
|
-
|
|
163
|
-
if labels is not None:
|
|
164
|
-
if isinstance(labels, list):
|
|
165
|
-
labels = np.array(
|
|
166
|
-
[self.label_to_index[label] for label in labels]
|
|
167
|
-
)
|
|
168
|
-
mask_ranked[
|
|
169
|
-
~np.isin(self._ranked_pairs[:, 4].astype(int), labels)
|
|
170
|
-
] = False
|
|
171
|
-
mask_detailed[
|
|
172
|
-
~np.isin(self._detailed_pairs[:, 4].astype(int), labels)
|
|
173
|
-
] = False
|
|
174
|
-
mask_labels[~np.isin(np.arange(n_labels), labels)] = False
|
|
175
|
-
|
|
176
|
-
mask_label_metadata = (
|
|
177
|
-
mask_datums[:, np.newaxis] & mask_labels[np.newaxis, :]
|
|
178
|
-
)
|
|
179
|
-
label_metadata_per_datum = self._label_metadata_per_datum.copy()
|
|
180
|
-
label_metadata_per_datum[:, ~mask_label_metadata] = 0
|
|
181
|
-
|
|
182
|
-
label_metadata = np.zeros_like(self._label_metadata, dtype=np.int32)
|
|
183
|
-
label_metadata = np.transpose(
|
|
184
|
-
np.sum(
|
|
185
|
-
label_metadata_per_datum,
|
|
186
|
-
axis=1,
|
|
187
|
-
)
|
|
188
|
-
)
|
|
189
|
-
|
|
190
|
-
return Filter(
|
|
191
|
-
ranked_indices=np.where(mask_ranked)[0],
|
|
192
|
-
detailed_indices=np.where(mask_detailed)[0],
|
|
193
|
-
label_metadata=label_metadata,
|
|
194
|
-
)
|
|
195
|
-
|
|
196
|
-
def compute_precision_recall(
|
|
197
|
-
self,
|
|
198
|
-
iou_thresholds: list[float],
|
|
199
|
-
score_thresholds: list[float],
|
|
200
|
-
filter_: Filter | None = None,
|
|
201
|
-
) -> dict[MetricType, list[Metric]]:
|
|
202
|
-
"""
|
|
203
|
-
Computes all metrics except for ConfusionMatrix
|
|
156
|
+
Helper function for adding a datum to the cache.
|
|
204
157
|
|
|
205
158
|
Parameters
|
|
206
159
|
----------
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
score_thresholds : list[float]
|
|
210
|
-
A list of score thresholds to compute metrics over.
|
|
211
|
-
filter_ : Filter, optional
|
|
212
|
-
An optional filter object.
|
|
160
|
+
datum_id : str
|
|
161
|
+
The datum identifier.
|
|
213
162
|
|
|
214
163
|
Returns
|
|
215
164
|
-------
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
"""
|
|
219
|
-
|
|
220
|
-
ranked_pairs = self._ranked_pairs
|
|
221
|
-
label_metadata = self._label_metadata
|
|
222
|
-
if filter_ is not None:
|
|
223
|
-
ranked_pairs = ranked_pairs[filter_.ranked_indices]
|
|
224
|
-
label_metadata = filter_.label_metadata
|
|
225
|
-
|
|
226
|
-
results = compute_precion_recall(
|
|
227
|
-
data=ranked_pairs,
|
|
228
|
-
label_metadata=label_metadata,
|
|
229
|
-
iou_thresholds=np.array(iou_thresholds),
|
|
230
|
-
score_thresholds=np.array(score_thresholds),
|
|
231
|
-
)
|
|
232
|
-
|
|
233
|
-
return unpack_precision_recall_into_metric_lists(
|
|
234
|
-
results=results,
|
|
235
|
-
label_metadata=label_metadata,
|
|
236
|
-
iou_thresholds=iou_thresholds,
|
|
237
|
-
score_thresholds=score_thresholds,
|
|
238
|
-
index_to_label=self.index_to_label,
|
|
239
|
-
)
|
|
240
|
-
|
|
241
|
-
def compute_confusion_matrix(
|
|
242
|
-
self,
|
|
243
|
-
iou_thresholds: list[float],
|
|
244
|
-
score_thresholds: list[float],
|
|
245
|
-
number_of_examples: int,
|
|
246
|
-
filter_: Filter | None = None,
|
|
247
|
-
) -> list[Metric]:
|
|
248
|
-
"""
|
|
249
|
-
Computes confusion matrices at various thresholds.
|
|
250
|
-
|
|
251
|
-
Parameters
|
|
252
|
-
----------
|
|
253
|
-
iou_thresholds : list[float]
|
|
254
|
-
A list of IOU thresholds to compute metrics over.
|
|
255
|
-
score_thresholds : list[float]
|
|
256
|
-
A list of score thresholds to compute metrics over.
|
|
257
|
-
number_of_examples : int
|
|
258
|
-
Maximum number of annotation examples to return in ConfusionMatrix.
|
|
259
|
-
filter_ : Filter, optional
|
|
260
|
-
An optional filter object.
|
|
261
|
-
|
|
262
|
-
Returns
|
|
263
|
-
-------
|
|
264
|
-
list[Metric]
|
|
265
|
-
List of confusion matrices per threshold pair.
|
|
165
|
+
int
|
|
166
|
+
The datum index.
|
|
266
167
|
"""
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
results = compute_confusion_matrix(
|
|
278
|
-
data=detailed_pairs,
|
|
279
|
-
label_metadata=label_metadata,
|
|
280
|
-
iou_thresholds=np.array(iou_thresholds),
|
|
281
|
-
score_thresholds=np.array(score_thresholds),
|
|
282
|
-
n_examples=number_of_examples,
|
|
283
|
-
)
|
|
284
|
-
|
|
285
|
-
return unpack_confusion_matrix_into_metric_list(
|
|
286
|
-
results=results,
|
|
287
|
-
iou_thresholds=iou_thresholds,
|
|
288
|
-
score_thresholds=score_thresholds,
|
|
289
|
-
number_of_examples=number_of_examples,
|
|
290
|
-
index_to_uid=self.index_to_uid,
|
|
291
|
-
index_to_label=self.index_to_label,
|
|
292
|
-
groundtruth_examples=self.groundtruth_examples,
|
|
293
|
-
prediction_examples=self.prediction_examples,
|
|
294
|
-
)
|
|
295
|
-
|
|
296
|
-
def evaluate(
|
|
297
|
-
self,
|
|
298
|
-
iou_thresholds: list[float] = [0.1, 0.5, 0.75],
|
|
299
|
-
score_thresholds: list[float] = [0.5],
|
|
300
|
-
number_of_examples: int = 0,
|
|
301
|
-
filter_: Filter | None = None,
|
|
302
|
-
) -> dict[MetricType, list[Metric]]:
|
|
168
|
+
if datum_id not in self.datum_id_to_index:
|
|
169
|
+
if len(self.datum_id_to_index) != len(self.index_to_datum_id):
|
|
170
|
+
raise RuntimeError("datum cache size mismatch")
|
|
171
|
+
idx = len(self.datum_id_to_index)
|
|
172
|
+
self.datum_id_to_index[datum_id] = idx
|
|
173
|
+
self.index_to_datum_id.append(datum_id)
|
|
174
|
+
return self.datum_id_to_index[datum_id]
|
|
175
|
+
|
|
176
|
+
def _add_groundtruth(self, annotation_id: str) -> int:
|
|
303
177
|
"""
|
|
304
|
-
|
|
178
|
+
Helper function for adding a ground truth annotation identifier to the cache.
|
|
305
179
|
|
|
306
180
|
Parameters
|
|
307
181
|
----------
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
score_thresholds : list[float], default=[0.5]
|
|
311
|
-
A list of score thresholds to compute metrics over.
|
|
312
|
-
number_of_examples : int, default=0
|
|
313
|
-
Maximum number of annotation examples to return in ConfusionMatrix.
|
|
314
|
-
filter_ : Filter, optional
|
|
315
|
-
An optional filter object.
|
|
182
|
+
annotation_id : str
|
|
183
|
+
The ground truth annotation identifier.
|
|
316
184
|
|
|
317
185
|
Returns
|
|
318
186
|
-------
|
|
319
|
-
|
|
320
|
-
|
|
187
|
+
int
|
|
188
|
+
The ground truth annotation index.
|
|
321
189
|
"""
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
)
|
|
334
|
-
|
|
335
|
-
return metrics
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
def defaultdict_int():
|
|
339
|
-
return defaultdict(int)
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
class DataLoader:
|
|
343
|
-
"""
|
|
344
|
-
Object Detection DataLoader
|
|
345
|
-
"""
|
|
346
|
-
|
|
347
|
-
def __init__(self):
|
|
348
|
-
self._evaluator = Evaluator()
|
|
349
|
-
self.pairs: list[NDArray[np.float64]] = list()
|
|
350
|
-
self.groundtruth_count = defaultdict(defaultdict_int)
|
|
351
|
-
self.prediction_count = defaultdict(defaultdict_int)
|
|
352
|
-
|
|
353
|
-
def _add_datum(self, uid: str) -> int:
|
|
190
|
+
if annotation_id not in self.groundtruth_id_to_index:
|
|
191
|
+
if len(self.groundtruth_id_to_index) != len(
|
|
192
|
+
self.index_to_groundtruth_id
|
|
193
|
+
):
|
|
194
|
+
raise RuntimeError("ground truth cache size mismatch")
|
|
195
|
+
idx = len(self.groundtruth_id_to_index)
|
|
196
|
+
self.groundtruth_id_to_index[annotation_id] = idx
|
|
197
|
+
self.index_to_groundtruth_id.append(annotation_id)
|
|
198
|
+
return self.groundtruth_id_to_index[annotation_id]
|
|
199
|
+
|
|
200
|
+
def _add_prediction(self, annotation_id: str) -> int:
|
|
354
201
|
"""
|
|
355
|
-
Helper function for adding a
|
|
202
|
+
Helper function for adding a prediction annotation identifier to the cache.
|
|
356
203
|
|
|
357
204
|
Parameters
|
|
358
205
|
----------
|
|
359
|
-
|
|
360
|
-
The
|
|
206
|
+
annotation_id : str
|
|
207
|
+
The prediction annotation identifier.
|
|
361
208
|
|
|
362
209
|
Returns
|
|
363
210
|
-------
|
|
364
211
|
int
|
|
365
|
-
The
|
|
212
|
+
The prediction annotation index.
|
|
366
213
|
"""
|
|
367
|
-
if
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
214
|
+
if annotation_id not in self.prediction_id_to_index:
|
|
215
|
+
if len(self.prediction_id_to_index) != len(
|
|
216
|
+
self.index_to_prediction_id
|
|
217
|
+
):
|
|
218
|
+
raise RuntimeError("prediction cache size mismatch")
|
|
219
|
+
idx = len(self.prediction_id_to_index)
|
|
220
|
+
self.prediction_id_to_index[annotation_id] = idx
|
|
221
|
+
self.index_to_prediction_id.append(annotation_id)
|
|
222
|
+
return self.prediction_id_to_index[annotation_id]
|
|
372
223
|
|
|
373
224
|
def _add_label(self, label: str) -> int:
|
|
374
225
|
"""
|
|
@@ -384,90 +235,14 @@ class DataLoader:
|
|
|
384
235
|
int
|
|
385
236
|
Label index.
|
|
386
237
|
"""
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
self.
|
|
392
|
-
|
|
238
|
+
label_id = len(self.index_to_label)
|
|
239
|
+
if label not in self.label_to_index:
|
|
240
|
+
if len(self.label_to_index) != len(self.index_to_label):
|
|
241
|
+
raise RuntimeError("label cache size mismatch")
|
|
242
|
+
self.label_to_index[label] = label_id
|
|
243
|
+
self.index_to_label.append(label)
|
|
393
244
|
label_id += 1
|
|
394
|
-
|
|
395
|
-
return self._evaluator.label_to_index[label]
|
|
396
|
-
|
|
397
|
-
def _cache_pairs(
|
|
398
|
-
self,
|
|
399
|
-
uid_index: int,
|
|
400
|
-
groundtruths: list,
|
|
401
|
-
predictions: list,
|
|
402
|
-
ious: NDArray[np.float64],
|
|
403
|
-
) -> None:
|
|
404
|
-
"""
|
|
405
|
-
Compute IOUs between groundtruths and preditions before storing as pairs.
|
|
406
|
-
|
|
407
|
-
Parameters
|
|
408
|
-
----------
|
|
409
|
-
uid_index : int
|
|
410
|
-
The index of the detection.
|
|
411
|
-
groundtruths : list
|
|
412
|
-
A list of groundtruths.
|
|
413
|
-
predictions : list
|
|
414
|
-
A list of predictions.
|
|
415
|
-
ious : NDArray[np.float64]
|
|
416
|
-
An array with shape (n_preds, n_gts) containing IOUs.
|
|
417
|
-
"""
|
|
418
|
-
|
|
419
|
-
predictions_with_iou_of_zero = np.where((ious < 1e-9).all(axis=1))[0]
|
|
420
|
-
groundtruths_with_iou_of_zero = np.where((ious < 1e-9).all(axis=0))[0]
|
|
421
|
-
|
|
422
|
-
pairs = [
|
|
423
|
-
np.array(
|
|
424
|
-
[
|
|
425
|
-
float(uid_index),
|
|
426
|
-
float(gidx),
|
|
427
|
-
float(pidx),
|
|
428
|
-
ious[pidx, gidx],
|
|
429
|
-
float(glabel),
|
|
430
|
-
float(plabel),
|
|
431
|
-
float(score),
|
|
432
|
-
]
|
|
433
|
-
)
|
|
434
|
-
for pidx, plabel, score in predictions
|
|
435
|
-
for gidx, glabel in groundtruths
|
|
436
|
-
if ious[pidx, gidx] >= 1e-9
|
|
437
|
-
]
|
|
438
|
-
pairs.extend(
|
|
439
|
-
[
|
|
440
|
-
np.array(
|
|
441
|
-
[
|
|
442
|
-
float(uid_index),
|
|
443
|
-
-1.0,
|
|
444
|
-
float(predictions[index][0]),
|
|
445
|
-
0.0,
|
|
446
|
-
-1.0,
|
|
447
|
-
float(predictions[index][1]),
|
|
448
|
-
float(predictions[index][2]),
|
|
449
|
-
]
|
|
450
|
-
)
|
|
451
|
-
for index in predictions_with_iou_of_zero
|
|
452
|
-
]
|
|
453
|
-
)
|
|
454
|
-
pairs.extend(
|
|
455
|
-
[
|
|
456
|
-
np.array(
|
|
457
|
-
[
|
|
458
|
-
float(uid_index),
|
|
459
|
-
float(groundtruths[index][0]),
|
|
460
|
-
-1.0,
|
|
461
|
-
0.0,
|
|
462
|
-
float(groundtruths[index][1]),
|
|
463
|
-
-1.0,
|
|
464
|
-
-1.0,
|
|
465
|
-
]
|
|
466
|
-
)
|
|
467
|
-
for index in groundtruths_with_iou_of_zero
|
|
468
|
-
]
|
|
469
|
-
)
|
|
470
|
-
self.pairs.append(np.array(pairs))
|
|
245
|
+
return self.label_to_index[label]
|
|
471
246
|
|
|
472
247
|
def _add_data(
|
|
473
248
|
self,
|
|
@@ -491,66 +266,102 @@ class DataLoader:
|
|
|
491
266
|
for detection, ious in tqdm(
|
|
492
267
|
zip(detections, detection_ious), disable=disable_tqdm
|
|
493
268
|
):
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
self.
|
|
497
|
-
|
|
498
|
-
|
|
499
|
-
|
|
500
|
-
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
|
|
504
|
-
|
|
505
|
-
|
|
506
|
-
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
|
|
510
|
-
|
|
511
|
-
|
|
512
|
-
|
|
513
|
-
|
|
514
|
-
|
|
515
|
-
|
|
516
|
-
self._evaluator.groundtruth_examples[uid_index][
|
|
517
|
-
gidx
|
|
518
|
-
] = gann.extrema
|
|
519
|
-
for glabel in gann.labels:
|
|
520
|
-
label_idx = self._add_label(glabel)
|
|
521
|
-
self.groundtruth_count[label_idx][uid_index] += 1
|
|
522
|
-
groundtruths.append(
|
|
523
|
-
(
|
|
524
|
-
gidx,
|
|
525
|
-
label_idx,
|
|
269
|
+
# cache labels and annotation pairs
|
|
270
|
+
pairs = []
|
|
271
|
+
datum_idx = self._add_datum(detection.uid)
|
|
272
|
+
if detection.groundtruths:
|
|
273
|
+
for gidx, gann in enumerate(detection.groundtruths):
|
|
274
|
+
groundtruth_idx = self._add_groundtruth(gann.uid)
|
|
275
|
+
glabel_idx = self._add_label(gann.labels[0])
|
|
276
|
+
if (ious[:, gidx] < 1e-9).all():
|
|
277
|
+
pairs.extend(
|
|
278
|
+
[
|
|
279
|
+
np.array(
|
|
280
|
+
[
|
|
281
|
+
float(datum_idx),
|
|
282
|
+
float(groundtruth_idx),
|
|
283
|
+
-1.0,
|
|
284
|
+
float(glabel_idx),
|
|
285
|
+
-1.0,
|
|
286
|
+
0.0,
|
|
287
|
+
-1.0,
|
|
288
|
+
]
|
|
289
|
+
)
|
|
290
|
+
]
|
|
526
291
|
)
|
|
292
|
+
for pidx, pann in enumerate(detection.predictions):
|
|
293
|
+
prediction_idx = self._add_prediction(pann.uid)
|
|
294
|
+
if (ious[pidx, :] < 1e-9).all():
|
|
295
|
+
pairs.extend(
|
|
296
|
+
[
|
|
297
|
+
np.array(
|
|
298
|
+
[
|
|
299
|
+
float(datum_idx),
|
|
300
|
+
-1.0,
|
|
301
|
+
float(prediction_idx),
|
|
302
|
+
-1.0,
|
|
303
|
+
float(self._add_label(plabel)),
|
|
304
|
+
0.0,
|
|
305
|
+
float(pscore),
|
|
306
|
+
]
|
|
307
|
+
)
|
|
308
|
+
for plabel, pscore in zip(
|
|
309
|
+
pann.labels, pann.scores
|
|
310
|
+
)
|
|
311
|
+
]
|
|
312
|
+
)
|
|
313
|
+
if ious[pidx, gidx] >= 1e-9:
|
|
314
|
+
pairs.extend(
|
|
315
|
+
[
|
|
316
|
+
np.array(
|
|
317
|
+
[
|
|
318
|
+
float(datum_idx),
|
|
319
|
+
float(groundtruth_idx),
|
|
320
|
+
float(prediction_idx),
|
|
321
|
+
float(self._add_label(glabel)),
|
|
322
|
+
float(self._add_label(plabel)),
|
|
323
|
+
ious[pidx, gidx],
|
|
324
|
+
float(pscore),
|
|
325
|
+
]
|
|
326
|
+
)
|
|
327
|
+
for glabel in gann.labels
|
|
328
|
+
for plabel, pscore in zip(
|
|
329
|
+
pann.labels, pann.scores
|
|
330
|
+
)
|
|
331
|
+
]
|
|
332
|
+
)
|
|
333
|
+
elif detection.predictions:
|
|
334
|
+
for pidx, pann in enumerate(detection.predictions):
|
|
335
|
+
prediction_idx = self._add_prediction(pann.uid)
|
|
336
|
+
pairs.extend(
|
|
337
|
+
[
|
|
338
|
+
np.array(
|
|
339
|
+
[
|
|
340
|
+
float(datum_idx),
|
|
341
|
+
-1.0,
|
|
342
|
+
float(prediction_idx),
|
|
343
|
+
-1.0,
|
|
344
|
+
float(self._add_label(plabel)),
|
|
345
|
+
0.0,
|
|
346
|
+
float(pscore),
|
|
347
|
+
]
|
|
348
|
+
)
|
|
349
|
+
for plabel, pscore in zip(pann.labels, pann.scores)
|
|
350
|
+
]
|
|
527
351
|
)
|
|
528
352
|
|
|
529
|
-
|
|
530
|
-
|
|
531
|
-
|
|
532
|
-
|
|
533
|
-
|
|
534
|
-
|
|
535
|
-
self.prediction_count[label_idx][uid_index] += 1
|
|
536
|
-
predictions.append(
|
|
537
|
-
(
|
|
538
|
-
pidx,
|
|
539
|
-
label_idx,
|
|
540
|
-
pscore,
|
|
541
|
-
)
|
|
353
|
+
data = np.array(pairs)
|
|
354
|
+
if data.size > 0:
|
|
355
|
+
# reset filtered cache if it exists
|
|
356
|
+
if self._temp_cache is None:
|
|
357
|
+
raise RuntimeError(
|
|
358
|
+
"cannot add data as evaluator has already been finalized"
|
|
542
359
|
)
|
|
543
|
-
|
|
544
|
-
self._cache_pairs(
|
|
545
|
-
uid_index=uid_index,
|
|
546
|
-
groundtruths=groundtruths,
|
|
547
|
-
predictions=predictions,
|
|
548
|
-
ious=ious,
|
|
549
|
-
)
|
|
360
|
+
self._temp_cache.append(data)
|
|
550
361
|
|
|
551
362
|
def add_bounding_boxes(
|
|
552
363
|
self,
|
|
553
|
-
detections: list[Detection],
|
|
364
|
+
detections: list[Detection[BoundingBox]],
|
|
554
365
|
show_progress: bool = False,
|
|
555
366
|
):
|
|
556
367
|
"""
|
|
@@ -584,7 +395,7 @@ class DataLoader:
|
|
|
584
395
|
|
|
585
396
|
def add_polygons(
|
|
586
397
|
self,
|
|
587
|
-
detections: list[Detection],
|
|
398
|
+
detections: list[Detection[Polygon]],
|
|
588
399
|
show_progress: bool = False,
|
|
589
400
|
):
|
|
590
401
|
"""
|
|
@@ -601,7 +412,7 @@ class DataLoader:
|
|
|
601
412
|
compute_polygon_iou(
|
|
602
413
|
np.array(
|
|
603
414
|
[
|
|
604
|
-
[gt.shape, pd.shape]
|
|
415
|
+
[gt.shape, pd.shape]
|
|
605
416
|
for pd in detection.predictions
|
|
606
417
|
for gt in detection.groundtruths
|
|
607
418
|
]
|
|
@@ -617,7 +428,7 @@ class DataLoader:
|
|
|
617
428
|
|
|
618
429
|
def add_bitmasks(
|
|
619
430
|
self,
|
|
620
|
-
detections: list[Detection],
|
|
431
|
+
detections: list[Detection[Bitmask]],
|
|
621
432
|
show_progress: bool = False,
|
|
622
433
|
):
|
|
623
434
|
"""
|
|
@@ -634,7 +445,7 @@ class DataLoader:
|
|
|
634
445
|
compute_bitmask_iou(
|
|
635
446
|
np.array(
|
|
636
447
|
[
|
|
637
|
-
[gt.mask, pd.mask]
|
|
448
|
+
[gt.mask, pd.mask]
|
|
638
449
|
for pd in detection.predictions
|
|
639
450
|
for gt in detection.groundtruths
|
|
640
451
|
]
|
|
@@ -648,7 +459,7 @@ class DataLoader:
|
|
|
648
459
|
show_progress=show_progress,
|
|
649
460
|
)
|
|
650
461
|
|
|
651
|
-
def finalize(self)
|
|
462
|
+
def finalize(self):
|
|
652
463
|
"""
|
|
653
464
|
Performs data finalization and some preprocessing steps.
|
|
654
465
|
|
|
@@ -657,65 +468,347 @@ class DataLoader:
|
|
|
657
468
|
Evaluator
|
|
658
469
|
A ready-to-use evaluator object.
|
|
659
470
|
"""
|
|
471
|
+
n_labels = len(self.index_to_label)
|
|
472
|
+
n_datums = len(self.index_to_datum_id)
|
|
473
|
+
if self._temp_cache is None:
|
|
474
|
+
warnings.warn("evaluator is already finalized or in a bad state")
|
|
475
|
+
return self
|
|
476
|
+
elif not self._temp_cache:
|
|
477
|
+
self._detailed_pairs = np.array([], dtype=np.float64)
|
|
478
|
+
self._ranked_pairs = np.array([], dtype=np.float64)
|
|
479
|
+
self._label_metadata = np.zeros((n_labels, 2), dtype=np.int32)
|
|
480
|
+
self._metadata = Metadata()
|
|
481
|
+
warnings.warn("no valid pairs")
|
|
482
|
+
return self
|
|
483
|
+
else:
|
|
484
|
+
self._detailed_pairs = np.concatenate(self._temp_cache, axis=0)
|
|
485
|
+
self._temp_cache = None
|
|
486
|
+
|
|
487
|
+
# order pairs by descending score, iou
|
|
488
|
+
indices = np.lexsort(
|
|
489
|
+
(
|
|
490
|
+
-self._detailed_pairs[:, 5], # iou
|
|
491
|
+
-self._detailed_pairs[:, 6], # score
|
|
492
|
+
)
|
|
493
|
+
)
|
|
494
|
+
self._detailed_pairs = self._detailed_pairs[indices]
|
|
495
|
+
self._label_metadata = compute_label_metadata(
|
|
496
|
+
ids=self._detailed_pairs[:, :5].astype(np.int32),
|
|
497
|
+
n_labels=n_labels,
|
|
498
|
+
)
|
|
499
|
+
self._ranked_pairs = rank_pairs(
|
|
500
|
+
detailed_pairs=self._detailed_pairs,
|
|
501
|
+
label_metadata=self._label_metadata,
|
|
502
|
+
)
|
|
503
|
+
self._metadata = Metadata.create(
|
|
504
|
+
detailed_pairs=self._detailed_pairs,
|
|
505
|
+
number_of_datums=n_datums,
|
|
506
|
+
number_of_labels=n_labels,
|
|
507
|
+
)
|
|
508
|
+
return self
|
|
660
509
|
|
|
661
|
-
|
|
662
|
-
|
|
663
|
-
|
|
510
|
+
def create_filter(
|
|
511
|
+
self,
|
|
512
|
+
datum_ids: list[str] | None = None,
|
|
513
|
+
groundtruth_ids: list[str] | None = None,
|
|
514
|
+
prediction_ids: list[str] | None = None,
|
|
515
|
+
labels: list[str] | None = None,
|
|
516
|
+
) -> Filter:
|
|
517
|
+
"""
|
|
518
|
+
Creates a filter object.
|
|
664
519
|
|
|
665
|
-
|
|
666
|
-
|
|
520
|
+
Parameters
|
|
521
|
+
----------
|
|
522
|
+
datum_uids : list[str], optional
|
|
523
|
+
An optional list of string uids representing datums to keep.
|
|
524
|
+
groundtruth_ids : list[str], optional
|
|
525
|
+
An optional list of string uids representing ground truth annotations to keep.
|
|
526
|
+
prediction_ids : list[str], optional
|
|
527
|
+
An optional list of string uids representing prediction annotations to keep.
|
|
528
|
+
labels : list[str], optional
|
|
529
|
+
An optional list of labels to keep.
|
|
530
|
+
"""
|
|
531
|
+
mask_datums = np.ones(self._detailed_pairs.shape[0], dtype=np.bool_)
|
|
532
|
+
|
|
533
|
+
# filter datums
|
|
534
|
+
if datum_ids is not None:
|
|
535
|
+
if not datum_ids:
|
|
536
|
+
warnings.warn("creating a filter that removes all datums")
|
|
537
|
+
return Filter(
|
|
538
|
+
mask_datums=np.zeros_like(mask_datums),
|
|
539
|
+
mask_groundtruths=np.array([], dtype=np.bool_),
|
|
540
|
+
mask_predictions=np.array([], dtype=np.bool_),
|
|
541
|
+
metadata=Metadata(),
|
|
542
|
+
)
|
|
543
|
+
valid_datum_indices = np.array(
|
|
544
|
+
[self.datum_id_to_index[uid] for uid in datum_ids],
|
|
545
|
+
dtype=np.int32,
|
|
546
|
+
)
|
|
547
|
+
mask_datums = np.isin(
|
|
548
|
+
self._detailed_pairs[:, 0], valid_datum_indices
|
|
549
|
+
)
|
|
667
550
|
|
|
668
|
-
|
|
551
|
+
filtered_detailed_pairs = self._detailed_pairs[mask_datums]
|
|
552
|
+
n_pairs = self._detailed_pairs[mask_datums].shape[0]
|
|
553
|
+
mask_groundtruths = np.zeros(n_pairs, dtype=np.bool_)
|
|
554
|
+
mask_predictions = np.zeros_like(mask_groundtruths)
|
|
669
555
|
|
|
670
|
-
|
|
671
|
-
|
|
672
|
-
|
|
673
|
-
|
|
674
|
-
|
|
675
|
-
gt_count = (
|
|
676
|
-
self.groundtruth_count[label_idx].get(datum_idx, 0)
|
|
677
|
-
if label_idx in self.groundtruth_count
|
|
678
|
-
else 0
|
|
556
|
+
# filter by ground truth annotation ids
|
|
557
|
+
if groundtruth_ids is not None:
|
|
558
|
+
if not groundtruth_ids:
|
|
559
|
+
warnings.warn(
|
|
560
|
+
"creating a filter that removes all ground truths"
|
|
679
561
|
)
|
|
680
|
-
|
|
681
|
-
|
|
682
|
-
|
|
683
|
-
|
|
562
|
+
valid_groundtruth_indices = np.array(
|
|
563
|
+
[self.groundtruth_id_to_index[uid] for uid in groundtruth_ids],
|
|
564
|
+
dtype=np.int32,
|
|
565
|
+
)
|
|
566
|
+
mask_groundtruths[
|
|
567
|
+
~np.isin(
|
|
568
|
+
filtered_detailed_pairs[:, 1],
|
|
569
|
+
valid_groundtruth_indices,
|
|
684
570
|
)
|
|
685
|
-
|
|
686
|
-
|
|
687
|
-
|
|
688
|
-
|
|
689
|
-
|
|
690
|
-
|
|
691
|
-
|
|
692
|
-
|
|
693
|
-
|
|
694
|
-
|
|
695
|
-
|
|
696
|
-
|
|
697
|
-
|
|
698
|
-
|
|
699
|
-
|
|
700
|
-
|
|
701
|
-
|
|
702
|
-
|
|
703
|
-
|
|
704
|
-
|
|
705
|
-
|
|
706
|
-
|
|
707
|
-
|
|
708
|
-
|
|
571
|
+
] = True
|
|
572
|
+
|
|
573
|
+
# filter by prediction annotation ids
|
|
574
|
+
if prediction_ids is not None:
|
|
575
|
+
if not prediction_ids:
|
|
576
|
+
warnings.warn("creating a filter that removes all predictions")
|
|
577
|
+
valid_prediction_indices = np.array(
|
|
578
|
+
[self.prediction_id_to_index[uid] for uid in prediction_ids],
|
|
579
|
+
dtype=np.int32,
|
|
580
|
+
)
|
|
581
|
+
mask_predictions[
|
|
582
|
+
~np.isin(
|
|
583
|
+
filtered_detailed_pairs[:, 2],
|
|
584
|
+
valid_prediction_indices,
|
|
585
|
+
)
|
|
586
|
+
] = True
|
|
587
|
+
|
|
588
|
+
# filter by labels
|
|
589
|
+
if labels is not None:
|
|
590
|
+
if not labels:
|
|
591
|
+
warnings.warn("creating a filter that removes all labels")
|
|
592
|
+
return Filter(
|
|
593
|
+
mask_datums=mask_datums,
|
|
594
|
+
mask_groundtruths=np.ones_like(mask_datums),
|
|
595
|
+
mask_predictions=np.ones_like(mask_datums),
|
|
596
|
+
metadata=Metadata(),
|
|
597
|
+
)
|
|
598
|
+
valid_label_indices = np.array(
|
|
599
|
+
[self.label_to_index[label] for label in labels] + [-1]
|
|
600
|
+
)
|
|
601
|
+
mask_groundtruths[
|
|
602
|
+
~np.isin(filtered_detailed_pairs[:, 3], valid_label_indices)
|
|
603
|
+
] = True
|
|
604
|
+
mask_predictions[
|
|
605
|
+
~np.isin(filtered_detailed_pairs[:, 4], valid_label_indices)
|
|
606
|
+
] = True
|
|
607
|
+
|
|
608
|
+
filtered_detailed_pairs, _, _ = filter_cache(
|
|
609
|
+
self._detailed_pairs,
|
|
610
|
+
mask_datums=mask_datums,
|
|
611
|
+
mask_ground_truths=mask_groundtruths,
|
|
612
|
+
mask_predictions=mask_predictions,
|
|
613
|
+
n_labels=len(self.index_to_label),
|
|
709
614
|
)
|
|
710
615
|
|
|
711
|
-
|
|
712
|
-
|
|
713
|
-
|
|
616
|
+
number_of_datums = (
|
|
617
|
+
len(datum_ids)
|
|
618
|
+
if datum_ids
|
|
619
|
+
else np.unique(filtered_detailed_pairs[:, 0]).size
|
|
714
620
|
)
|
|
715
621
|
|
|
716
|
-
|
|
717
|
-
|
|
718
|
-
|
|
622
|
+
return Filter(
|
|
623
|
+
mask_datums=mask_datums,
|
|
624
|
+
mask_groundtruths=mask_groundtruths,
|
|
625
|
+
mask_predictions=mask_predictions,
|
|
626
|
+
metadata=Metadata.create(
|
|
627
|
+
detailed_pairs=filtered_detailed_pairs,
|
|
628
|
+
number_of_datums=number_of_datums,
|
|
629
|
+
number_of_labels=len(self.index_to_label),
|
|
630
|
+
),
|
|
719
631
|
)
|
|
720
632
|
|
|
721
|
-
|
|
633
|
+
def filter(
    self, filter_: Filter
) -> tuple[NDArray[np.float64], NDArray[np.float64], NDArray[np.int32],]:
    """
    Apply a precomputed filter to the evaluator's internal cache.

    Parameters
    ----------
    filter_ : Filter
        The filter parameterization (datum/ground-truth/prediction masks).

    Returns
    -------
    NDArray[float64]
        Filtered detailed pairs.
    NDArray[float64]
        Filtered ranked pairs.
    NDArray[int32]
        Label metadata.
    """
    # An all-False datum mask means nothing survives; short-circuit with
    # empty results rather than running the cache filter.
    if not filter_.mask_datums.any():
        warnings.warn("filter removed all datums")
        empty = np.array([], dtype=np.float64)
        return (
            empty,
            np.array([], dtype=np.float64),
            np.zeros((self.metadata.number_of_labels, 2), dtype=np.int32),
        )

    # Warn (but proceed) when one side of the pairing is fully masked out.
    if filter_.mask_groundtruths.all():
        warnings.warn("filter removed all ground truths")
    if filter_.mask_predictions.all():
        warnings.warn("filter removed all predictions")

    return filter_cache(
        detailed_pairs=self._detailed_pairs,
        mask_datums=filter_.mask_datums,
        mask_ground_truths=filter_.mask_groundtruths,
        mask_predictions=filter_.mask_predictions,
        n_labels=len(self.index_to_label),
    )
def compute_precision_recall(
    self,
    iou_thresholds: list[float],
    score_thresholds: list[float],
    filter_: Filter | None = None,
) -> dict[MetricType, list[Metric]]:
    """
    Computes all metrics except for ConfusionMatrix.

    Parameters
    ----------
    iou_thresholds : list[float]
        A list of IOU thresholds to compute metrics over.
    score_thresholds : list[float]
        A list of score thresholds to compute metrics over.
    filter_ : Filter, optional
        A collection of filter parameters and masks.

    Returns
    -------
    dict[MetricType, list]
        A dictionary mapping MetricType enumerations to lists of computed metrics.

    Raises
    ------
    ValueError
        If either threshold list is empty.
    """
    if not iou_thresholds:
        raise ValueError("At least one IOU threshold must be passed.")
    elif not score_thresholds:
        raise ValueError("At least one score threshold must be passed.")

    # Use filtered pairs when a filter is supplied; otherwise fall back to
    # the unfiltered internal cache.
    if filter_ is None:
        ranked_pairs = self._ranked_pairs
        label_metadata = self._label_metadata
    else:
        _, ranked_pairs, label_metadata = self.filter(filter_=filter_)

    # NOTE(review): `compute_precion_recall` is the imported helper's actual
    # (misspelled) name upstream — do not "fix" it here without renaming the
    # import as well.
    results = compute_precion_recall(
        ranked_pairs=ranked_pairs,
        label_metadata=label_metadata,
        iou_thresholds=np.array(iou_thresholds),
        score_thresholds=np.array(score_thresholds),
    )
    return unpack_precision_recall_into_metric_lists(
        results=results,
        label_metadata=label_metadata,
        iou_thresholds=iou_thresholds,
        score_thresholds=score_thresholds,
        index_to_label=self.index_to_label,
    )
def compute_confusion_matrix(
    self,
    iou_thresholds: list[float],
    score_thresholds: list[float],
    filter_: Filter | None = None,
) -> list[Metric]:
    """
    Computes confusion matrices at various thresholds.

    Parameters
    ----------
    iou_thresholds : list[float]
        A list of IOU thresholds to compute metrics over.
    score_thresholds : list[float]
        A list of score thresholds to compute metrics over.
    filter_ : Filter, optional
        A collection of filter parameters and masks.

    Returns
    -------
    list[Metric]
        List of confusion matrices per threshold pair.

    Raises
    ------
    ValueError
        If either threshold list is empty.
    """
    if not iou_thresholds:
        raise ValueError("At least one IOU threshold must be passed.")
    elif not score_thresholds:
        raise ValueError("At least one score threshold must be passed.")

    # Select the detailed-pair cache, filtered if requested.
    if filter_ is None:
        detailed_pairs = self._detailed_pairs
    else:
        detailed_pairs, _, _ = self.filter(filter_=filter_)

    # Nothing to compute over; warn instead of raising so callers that
    # over-filter still get a well-formed (empty) result.
    if detailed_pairs.size == 0:
        warnings.warn("attempted to compute over an empty set")
        return []

    # This method shares its name with the imported module-level function;
    # the call below resolves to the global, not recursion.
    results = compute_confusion_matrix(
        detailed_pairs=detailed_pairs,
        iou_thresholds=np.array(iou_thresholds),
        score_thresholds=np.array(score_thresholds),
    )
    return unpack_confusion_matrix_into_metric_list(
        results=results,
        detailed_pairs=detailed_pairs,
        iou_thresholds=iou_thresholds,
        score_thresholds=score_thresholds,
        index_to_datum_id=self.index_to_datum_id,
        index_to_groundtruth_id=self.index_to_groundtruth_id,
        index_to_prediction_id=self.index_to_prediction_id,
        index_to_label=self.index_to_label,
    )
def evaluate(
    self,
    iou_thresholds: list[float] | None = None,
    score_thresholds: list[float] | None = None,
    filter_: Filter | None = None,
) -> dict[MetricType, list[Metric]]:
    """
    Computes all available metrics.

    Parameters
    ----------
    iou_thresholds : list[float], default=[0.1, 0.5, 0.75]
        A list of IOU thresholds to compute metrics over.
    score_thresholds : list[float], default=[0.5]
        A list of score thresholds to compute metrics over.
    filter_ : Filter, optional
        A collection of filter parameters and masks.

    Returns
    -------
    dict[MetricType, list[Metric]]
        Lists of metrics organized by metric type.
    """
    # Avoid mutable default arguments (shared list objects across calls);
    # None sentinels preserve the original effective defaults.
    if iou_thresholds is None:
        iou_thresholds = [0.1, 0.5, 0.75]
    if score_thresholds is None:
        score_thresholds = [0.5]

    metrics = self.compute_precision_recall(
        iou_thresholds=iou_thresholds,
        score_thresholds=score_thresholds,
        filter_=filter_,
    )
    # Confusion matrices are computed separately and merged into the result.
    metrics[MetricType.ConfusionMatrix] = self.compute_confusion_matrix(
        iou_thresholds=iou_thresholds,
        score_thresholds=score_thresholds,
        filter_=filter_,
    )
    return metrics
class DataLoader(Evaluator):
    """
    Used for backwards compatibility as the Evaluator now handles ingestion.

    This alias exists so callers that previously imported ``DataLoader``
    continue to work; it adds no behavior of its own.
    """

    pass