valor-lite 0.34.2__py3-none-any.whl → 0.35.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of valor-lite might be problematic. Click here for more details.
- valor_lite/object_detection/__init__.py +0 -14
- valor_lite/object_detection/annotation.py +24 -48
- valor_lite/object_detection/computation.py +244 -407
- valor_lite/object_detection/manager.py +458 -374
- valor_lite/object_detection/metric.py +16 -70
- valor_lite/object_detection/utilities.py +134 -317
- {valor_lite-0.34.2.dist-info → valor_lite-0.35.0.dist-info}/METADATA +1 -1
- {valor_lite-0.34.2.dist-info → valor_lite-0.35.0.dist-info}/RECORD +10 -10
- {valor_lite-0.34.2.dist-info → valor_lite-0.35.0.dist-info}/WHEEL +1 -1
- {valor_lite-0.34.2.dist-info → valor_lite-0.35.0.dist-info}/top_level.txt +0 -0
|
@@ -1,18 +1,23 @@
|
|
|
1
|
-
|
|
2
|
-
from dataclasses import dataclass
|
|
1
|
+
import warnings
|
|
3
2
|
|
|
4
3
|
import numpy as np
|
|
5
4
|
from numpy.typing import NDArray
|
|
6
5
|
from tqdm import tqdm
|
|
7
6
|
|
|
8
|
-
from valor_lite.object_detection.annotation import
|
|
7
|
+
from valor_lite.object_detection.annotation import (
|
|
8
|
+
Bitmask,
|
|
9
|
+
BoundingBox,
|
|
10
|
+
Detection,
|
|
11
|
+
Polygon,
|
|
12
|
+
)
|
|
9
13
|
from valor_lite.object_detection.computation import (
|
|
10
14
|
compute_bbox_iou,
|
|
11
15
|
compute_bitmask_iou,
|
|
12
16
|
compute_confusion_matrix,
|
|
17
|
+
compute_label_metadata,
|
|
13
18
|
compute_polygon_iou,
|
|
14
19
|
compute_precion_recall,
|
|
15
|
-
|
|
20
|
+
rank_pairs,
|
|
16
21
|
)
|
|
17
22
|
from valor_lite.object_detection.metric import Metric, MetricType
|
|
18
23
|
from valor_lite.object_detection.utilities import (
|
|
@@ -41,13 +46,6 @@ filtered_metrics = evaluator.evaluate(iou_thresholds=[0.5], filter_mask=filter_m
|
|
|
41
46
|
"""
|
|
42
47
|
|
|
43
48
|
|
|
44
|
-
@dataclass
|
|
45
|
-
class Filter:
|
|
46
|
-
ranked_indices: NDArray[np.intp]
|
|
47
|
-
detailed_indices: NDArray[np.intp]
|
|
48
|
-
label_metadata: NDArray[np.int32]
|
|
49
|
-
|
|
50
|
-
|
|
51
49
|
class Evaluator:
|
|
52
50
|
"""
|
|
53
51
|
Object Detection Evaluator
|
|
@@ -55,37 +53,94 @@ class Evaluator:
|
|
|
55
53
|
|
|
56
54
|
def __init__(self):
|
|
57
55
|
|
|
58
|
-
#
|
|
59
|
-
self.
|
|
60
|
-
self.
|
|
61
|
-
self.
|
|
62
|
-
self.
|
|
56
|
+
# external reference
|
|
57
|
+
self.datum_id_to_index: dict[str, int] = {}
|
|
58
|
+
self.groundtruth_id_to_index: dict[str, int] = {}
|
|
59
|
+
self.prediction_id_to_index: dict[str, int] = {}
|
|
60
|
+
self.label_to_index: dict[str, int] = {}
|
|
61
|
+
|
|
62
|
+
self.index_to_datum_id: list[str] = []
|
|
63
|
+
self.index_to_groundtruth_id: list[str] = []
|
|
64
|
+
self.index_to_prediction_id: list[str] = []
|
|
65
|
+
self.index_to_label: list[str] = []
|
|
66
|
+
|
|
67
|
+
# temporary cache
|
|
68
|
+
self._temp_cache: list[NDArray[np.float64]] | None = []
|
|
69
|
+
|
|
70
|
+
# cache
|
|
71
|
+
self._detailed_pairs = np.array([[]], dtype=np.float64)
|
|
72
|
+
self._ranked_pairs = np.array([[]], dtype=np.float64)
|
|
73
|
+
self._label_metadata: NDArray[np.int32] = np.array([[]])
|
|
74
|
+
|
|
75
|
+
# filter cache
|
|
76
|
+
self._filtered_detailed_pairs: NDArray[np.float64] | None = None
|
|
77
|
+
self._filtered_ranked_pairs: NDArray[np.float64] | None = None
|
|
78
|
+
self._filtered_label_metadata: NDArray[np.int32] | None = None
|
|
79
|
+
|
|
80
|
+
@property
|
|
81
|
+
def is_filtered(self) -> bool:
|
|
82
|
+
return self._filtered_detailed_pairs is not None
|
|
83
|
+
|
|
84
|
+
@property
|
|
85
|
+
def label_metadata(self) -> NDArray[np.int32]:
|
|
86
|
+
return (
|
|
87
|
+
self._filtered_label_metadata
|
|
88
|
+
if self._filtered_label_metadata is not None
|
|
89
|
+
else self._label_metadata
|
|
90
|
+
)
|
|
63
91
|
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
92
|
+
@property
|
|
93
|
+
def detailed_pairs(self) -> NDArray[np.float64]:
|
|
94
|
+
return (
|
|
95
|
+
self._filtered_detailed_pairs
|
|
96
|
+
if self._filtered_detailed_pairs is not None
|
|
97
|
+
else self._detailed_pairs
|
|
98
|
+
)
|
|
67
99
|
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
100
|
+
@property
|
|
101
|
+
def ranked_pairs(self) -> NDArray[np.float64]:
|
|
102
|
+
return (
|
|
103
|
+
self._filtered_ranked_pairs
|
|
104
|
+
if self._filtered_ranked_pairs is not None
|
|
105
|
+
else self._ranked_pairs
|
|
106
|
+
)
|
|
71
107
|
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
108
|
+
@property
|
|
109
|
+
def n_labels(self) -> int:
|
|
110
|
+
"""Returns the total number of unique labels."""
|
|
111
|
+
return len(self.index_to_label)
|
|
75
112
|
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
113
|
+
@property
|
|
114
|
+
def n_datums(self) -> int:
|
|
115
|
+
"""Returns the number of datums."""
|
|
116
|
+
return np.unique(self.detailed_pairs[:, 0]).size
|
|
117
|
+
|
|
118
|
+
@property
|
|
119
|
+
def n_groundtruths(self) -> int:
|
|
120
|
+
"""Returns the number of ground truth annotations."""
|
|
121
|
+
mask_valid_gts = self.detailed_pairs[:, 1] >= 0
|
|
122
|
+
unique_ids = np.unique(
|
|
123
|
+
self.detailed_pairs[np.ix_(mask_valid_gts, (0, 1))], axis=0 # type: ignore - np.ix_ typing
|
|
124
|
+
)
|
|
125
|
+
return int(unique_ids.shape[0])
|
|
126
|
+
|
|
127
|
+
@property
|
|
128
|
+
def n_predictions(self) -> int:
|
|
129
|
+
"""Returns the number of prediction annotations."""
|
|
130
|
+
mask_valid_pds = self.detailed_pairs[:, 2] >= 0
|
|
131
|
+
unique_ids = np.unique(
|
|
132
|
+
self.detailed_pairs[np.ix_(mask_valid_pds, (0, 2))], axis=0 # type: ignore - np.ix_ typing
|
|
133
|
+
)
|
|
134
|
+
return int(unique_ids.shape[0])
|
|
81
135
|
|
|
82
136
|
@property
|
|
83
137
|
def ignored_prediction_labels(self) -> list[str]:
|
|
84
138
|
"""
|
|
85
139
|
Prediction labels that are not present in the ground truth set.
|
|
86
140
|
"""
|
|
87
|
-
|
|
88
|
-
|
|
141
|
+
label_metadata = self.label_metadata
|
|
142
|
+
glabels = set(np.where(label_metadata[:, 0] > 0)[0])
|
|
143
|
+
plabels = set(np.where(label_metadata[:, 1] > 0)[0])
|
|
89
144
|
return [
|
|
90
145
|
self.index_to_label[label_id] for label_id in (plabels - glabels)
|
|
91
146
|
]
|
|
@@ -95,8 +150,9 @@ class Evaluator:
|
|
|
95
150
|
"""
|
|
96
151
|
Ground truth labels that are not present in the prediction set.
|
|
97
152
|
"""
|
|
98
|
-
|
|
99
|
-
|
|
153
|
+
label_metadata = self.label_metadata
|
|
154
|
+
glabels = set(np.where(label_metadata[:, 0] > 0)[0])
|
|
155
|
+
plabels = set(np.where(label_metadata[:, 1] > 0)[0])
|
|
100
156
|
return [
|
|
101
157
|
self.index_to_label[label_id] for label_id in (glabels - plabels)
|
|
102
158
|
]
|
|
@@ -115,89 +171,10 @@ class Evaluator:
|
|
|
115
171
|
"missing_prediction_labels": self.missing_prediction_labels,
|
|
116
172
|
}
|
|
117
173
|
|
|
118
|
-
def create_filter(
|
|
119
|
-
self,
|
|
120
|
-
datum_uids: list[str] | NDArray[np.int32] | None = None,
|
|
121
|
-
labels: list[str] | NDArray[np.int32] | None = None,
|
|
122
|
-
) -> Filter:
|
|
123
|
-
"""
|
|
124
|
-
Creates a filter that can be passed to an evaluation.
|
|
125
|
-
|
|
126
|
-
Parameters
|
|
127
|
-
----------
|
|
128
|
-
datum_uids : list[str] | NDArray[np.int32], optional
|
|
129
|
-
An optional list of string uids or a numpy array of uid indices.
|
|
130
|
-
labels : list[str] | NDArray[np.int32], optional
|
|
131
|
-
An optional list of labels or a numpy array of label indices.
|
|
132
|
-
|
|
133
|
-
Returns
|
|
134
|
-
-------
|
|
135
|
-
Filter
|
|
136
|
-
A filter object that can be passed to the `evaluate` method.
|
|
137
|
-
"""
|
|
138
|
-
|
|
139
|
-
n_datums = self._label_metadata_per_datum.shape[1]
|
|
140
|
-
n_labels = self._label_metadata_per_datum.shape[2]
|
|
141
|
-
|
|
142
|
-
mask_ranked = np.ones((self._ranked_pairs.shape[0], 1), dtype=np.bool_)
|
|
143
|
-
mask_detailed = np.ones(
|
|
144
|
-
(self._detailed_pairs.shape[0], 1), dtype=np.bool_
|
|
145
|
-
)
|
|
146
|
-
mask_datums = np.ones(n_datums, dtype=np.bool_)
|
|
147
|
-
mask_labels = np.ones(n_labels, dtype=np.bool_)
|
|
148
|
-
|
|
149
|
-
if datum_uids is not None:
|
|
150
|
-
if isinstance(datum_uids, list):
|
|
151
|
-
datum_uids = np.array(
|
|
152
|
-
[self.uid_to_index[uid] for uid in datum_uids],
|
|
153
|
-
dtype=np.int32,
|
|
154
|
-
)
|
|
155
|
-
mask_ranked[
|
|
156
|
-
~np.isin(self._ranked_pairs[:, 0].astype(int), datum_uids)
|
|
157
|
-
] = False
|
|
158
|
-
mask_detailed[
|
|
159
|
-
~np.isin(self._detailed_pairs[:, 0].astype(int), datum_uids)
|
|
160
|
-
] = False
|
|
161
|
-
mask_datums[~np.isin(np.arange(n_datums), datum_uids)] = False
|
|
162
|
-
|
|
163
|
-
if labels is not None:
|
|
164
|
-
if isinstance(labels, list):
|
|
165
|
-
labels = np.array(
|
|
166
|
-
[self.label_to_index[label] for label in labels]
|
|
167
|
-
)
|
|
168
|
-
mask_ranked[
|
|
169
|
-
~np.isin(self._ranked_pairs[:, 4].astype(int), labels)
|
|
170
|
-
] = False
|
|
171
|
-
mask_detailed[
|
|
172
|
-
~np.isin(self._detailed_pairs[:, 4].astype(int), labels)
|
|
173
|
-
] = False
|
|
174
|
-
mask_labels[~np.isin(np.arange(n_labels), labels)] = False
|
|
175
|
-
|
|
176
|
-
mask_label_metadata = (
|
|
177
|
-
mask_datums[:, np.newaxis] & mask_labels[np.newaxis, :]
|
|
178
|
-
)
|
|
179
|
-
label_metadata_per_datum = self._label_metadata_per_datum.copy()
|
|
180
|
-
label_metadata_per_datum[:, ~mask_label_metadata] = 0
|
|
181
|
-
|
|
182
|
-
label_metadata = np.zeros_like(self._label_metadata, dtype=np.int32)
|
|
183
|
-
label_metadata = np.transpose(
|
|
184
|
-
np.sum(
|
|
185
|
-
label_metadata_per_datum,
|
|
186
|
-
axis=1,
|
|
187
|
-
)
|
|
188
|
-
)
|
|
189
|
-
|
|
190
|
-
return Filter(
|
|
191
|
-
ranked_indices=np.where(mask_ranked)[0],
|
|
192
|
-
detailed_indices=np.where(mask_detailed)[0],
|
|
193
|
-
label_metadata=label_metadata,
|
|
194
|
-
)
|
|
195
|
-
|
|
196
174
|
def compute_precision_recall(
|
|
197
175
|
self,
|
|
198
176
|
iou_thresholds: list[float],
|
|
199
177
|
score_thresholds: list[float],
|
|
200
|
-
filter_: Filter | None = None,
|
|
201
178
|
) -> dict[MetricType, list[Metric]]:
|
|
202
179
|
"""
|
|
203
180
|
Computes all metrics except for ConfusionMatrix
|
|
@@ -208,31 +185,25 @@ class Evaluator:
|
|
|
208
185
|
A list of IOU thresholds to compute metrics over.
|
|
209
186
|
score_thresholds : list[float]
|
|
210
187
|
A list of score thresholds to compute metrics over.
|
|
211
|
-
filter_ : Filter, optional
|
|
212
|
-
An optional filter object.
|
|
213
188
|
|
|
214
189
|
Returns
|
|
215
190
|
-------
|
|
216
191
|
dict[MetricType, list]
|
|
217
192
|
A dictionary mapping MetricType enumerations to lists of computed metrics.
|
|
218
193
|
"""
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
ranked_pairs = ranked_pairs[filter_.ranked_indices]
|
|
224
|
-
label_metadata = filter_.label_metadata
|
|
225
|
-
|
|
194
|
+
if not iou_thresholds:
|
|
195
|
+
raise ValueError("At least one IOU threshold must be passed.")
|
|
196
|
+
elif not score_thresholds:
|
|
197
|
+
raise ValueError("At least one score threshold must be passed.")
|
|
226
198
|
results = compute_precion_recall(
|
|
227
|
-
|
|
228
|
-
label_metadata=label_metadata,
|
|
199
|
+
ranked_pairs=self.ranked_pairs,
|
|
200
|
+
label_metadata=self.label_metadata,
|
|
229
201
|
iou_thresholds=np.array(iou_thresholds),
|
|
230
202
|
score_thresholds=np.array(score_thresholds),
|
|
231
203
|
)
|
|
232
|
-
|
|
233
204
|
return unpack_precision_recall_into_metric_lists(
|
|
234
205
|
results=results,
|
|
235
|
-
label_metadata=label_metadata,
|
|
206
|
+
label_metadata=self.label_metadata,
|
|
236
207
|
iou_thresholds=iou_thresholds,
|
|
237
208
|
score_thresholds=score_thresholds,
|
|
238
209
|
index_to_label=self.index_to_label,
|
|
@@ -242,8 +213,6 @@ class Evaluator:
|
|
|
242
213
|
self,
|
|
243
214
|
iou_thresholds: list[float],
|
|
244
215
|
score_thresholds: list[float],
|
|
245
|
-
number_of_examples: int,
|
|
246
|
-
filter_: Filter | None = None,
|
|
247
216
|
) -> list[Metric]:
|
|
248
217
|
"""
|
|
249
218
|
Computes confusion matrices at various thresholds.
|
|
@@ -254,51 +223,39 @@ class Evaluator:
|
|
|
254
223
|
A list of IOU thresholds to compute metrics over.
|
|
255
224
|
score_thresholds : list[float]
|
|
256
225
|
A list of score thresholds to compute metrics over.
|
|
257
|
-
number_of_examples : int
|
|
258
|
-
Maximum number of annotation examples to return in ConfusionMatrix.
|
|
259
|
-
filter_ : Filter, optional
|
|
260
|
-
An optional filter object.
|
|
261
226
|
|
|
262
227
|
Returns
|
|
263
228
|
-------
|
|
264
229
|
list[Metric]
|
|
265
230
|
List of confusion matrices per threshold pair.
|
|
266
231
|
"""
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
if detailed_pairs.size == 0:
|
|
275
|
-
return list()
|
|
276
|
-
|
|
232
|
+
if not iou_thresholds:
|
|
233
|
+
raise ValueError("At least one IOU threshold must be passed.")
|
|
234
|
+
elif not score_thresholds:
|
|
235
|
+
raise ValueError("At least one score threshold must be passed.")
|
|
236
|
+
elif self.detailed_pairs.size == 0:
|
|
237
|
+
warnings.warn("attempted to compute over an empty set")
|
|
238
|
+
return []
|
|
277
239
|
results = compute_confusion_matrix(
|
|
278
|
-
|
|
279
|
-
label_metadata=label_metadata,
|
|
240
|
+
detailed_pairs=self.detailed_pairs,
|
|
280
241
|
iou_thresholds=np.array(iou_thresholds),
|
|
281
242
|
score_thresholds=np.array(score_thresholds),
|
|
282
|
-
n_examples=number_of_examples,
|
|
283
243
|
)
|
|
284
|
-
|
|
285
244
|
return unpack_confusion_matrix_into_metric_list(
|
|
286
245
|
results=results,
|
|
246
|
+
detailed_pairs=self.detailed_pairs,
|
|
287
247
|
iou_thresholds=iou_thresholds,
|
|
288
248
|
score_thresholds=score_thresholds,
|
|
289
|
-
|
|
290
|
-
|
|
249
|
+
index_to_datum_id=self.index_to_datum_id,
|
|
250
|
+
index_to_groundtruth_id=self.index_to_groundtruth_id,
|
|
251
|
+
index_to_prediction_id=self.index_to_prediction_id,
|
|
291
252
|
index_to_label=self.index_to_label,
|
|
292
|
-
groundtruth_examples=self.groundtruth_examples,
|
|
293
|
-
prediction_examples=self.prediction_examples,
|
|
294
253
|
)
|
|
295
254
|
|
|
296
255
|
def evaluate(
|
|
297
256
|
self,
|
|
298
257
|
iou_thresholds: list[float] = [0.1, 0.5, 0.75],
|
|
299
258
|
score_thresholds: list[float] = [0.5],
|
|
300
|
-
number_of_examples: int = 0,
|
|
301
|
-
filter_: Filter | None = None,
|
|
302
259
|
) -> dict[MetricType, list[Metric]]:
|
|
303
260
|
"""
|
|
304
261
|
Computes all available metrics.
|
|
@@ -309,10 +266,6 @@ class Evaluator:
|
|
|
309
266
|
A list of IOU thresholds to compute metrics over.
|
|
310
267
|
score_thresholds : list[float], default=[0.5]
|
|
311
268
|
A list of score thresholds to compute metrics over.
|
|
312
|
-
number_of_examples : int, default=0
|
|
313
|
-
Maximum number of annotation examples to return in ConfusionMatrix.
|
|
314
|
-
filter_ : Filter, optional
|
|
315
|
-
An optional filter object.
|
|
316
269
|
|
|
317
270
|
Returns
|
|
318
271
|
-------
|
|
@@ -322,53 +275,82 @@ class Evaluator:
|
|
|
322
275
|
metrics = self.compute_precision_recall(
|
|
323
276
|
iou_thresholds=iou_thresholds,
|
|
324
277
|
score_thresholds=score_thresholds,
|
|
325
|
-
filter_=filter_,
|
|
326
278
|
)
|
|
327
|
-
|
|
328
279
|
metrics[MetricType.ConfusionMatrix] = self.compute_confusion_matrix(
|
|
329
280
|
iou_thresholds=iou_thresholds,
|
|
330
281
|
score_thresholds=score_thresholds,
|
|
331
|
-
number_of_examples=number_of_examples,
|
|
332
|
-
filter_=filter_,
|
|
333
282
|
)
|
|
334
|
-
|
|
335
283
|
return metrics
|
|
336
284
|
|
|
285
|
+
def _add_datum(self, datum_id: str) -> int:
|
|
286
|
+
"""
|
|
287
|
+
Helper function for adding a datum to the cache.
|
|
337
288
|
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
289
|
+
Parameters
|
|
290
|
+
----------
|
|
291
|
+
datum_id : str
|
|
292
|
+
The datum identifier.
|
|
341
293
|
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
294
|
+
Returns
|
|
295
|
+
-------
|
|
296
|
+
int
|
|
297
|
+
The datum index.
|
|
298
|
+
"""
|
|
299
|
+
if datum_id not in self.datum_id_to_index:
|
|
300
|
+
if len(self.datum_id_to_index) != len(self.index_to_datum_id):
|
|
301
|
+
raise RuntimeError("datum cache size mismatch")
|
|
302
|
+
idx = len(self.datum_id_to_index)
|
|
303
|
+
self.datum_id_to_index[datum_id] = idx
|
|
304
|
+
self.index_to_datum_id.append(datum_id)
|
|
305
|
+
return self.datum_id_to_index[datum_id]
|
|
306
|
+
|
|
307
|
+
def _add_groundtruth(self, annotation_id: str) -> int:
|
|
308
|
+
"""
|
|
309
|
+
Helper function for adding a ground truth annotation identifier to the cache.
|
|
346
310
|
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
self.prediction_count = defaultdict(defaultdict_int)
|
|
311
|
+
Parameters
|
|
312
|
+
----------
|
|
313
|
+
annotation_id : str
|
|
314
|
+
The ground truth annotation identifier.
|
|
352
315
|
|
|
353
|
-
|
|
316
|
+
Returns
|
|
317
|
+
-------
|
|
318
|
+
int
|
|
319
|
+
The ground truth annotation index.
|
|
354
320
|
"""
|
|
355
|
-
|
|
321
|
+
if annotation_id not in self.groundtruth_id_to_index:
|
|
322
|
+
if len(self.groundtruth_id_to_index) != len(
|
|
323
|
+
self.index_to_groundtruth_id
|
|
324
|
+
):
|
|
325
|
+
raise RuntimeError("ground truth cache size mismatch")
|
|
326
|
+
idx = len(self.groundtruth_id_to_index)
|
|
327
|
+
self.groundtruth_id_to_index[annotation_id] = idx
|
|
328
|
+
self.index_to_groundtruth_id.append(annotation_id)
|
|
329
|
+
return self.groundtruth_id_to_index[annotation_id]
|
|
330
|
+
|
|
331
|
+
def _add_prediction(self, annotation_id: str) -> int:
|
|
332
|
+
"""
|
|
333
|
+
Helper function for adding a prediction annotation identifier to the cache.
|
|
356
334
|
|
|
357
335
|
Parameters
|
|
358
336
|
----------
|
|
359
|
-
|
|
360
|
-
The
|
|
337
|
+
annotation_id : str
|
|
338
|
+
The prediction annotation identifier.
|
|
361
339
|
|
|
362
340
|
Returns
|
|
363
341
|
-------
|
|
364
342
|
int
|
|
365
|
-
The
|
|
343
|
+
The prediction annotation index.
|
|
366
344
|
"""
|
|
367
|
-
if
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
345
|
+
if annotation_id not in self.prediction_id_to_index:
|
|
346
|
+
if len(self.prediction_id_to_index) != len(
|
|
347
|
+
self.index_to_prediction_id
|
|
348
|
+
):
|
|
349
|
+
raise RuntimeError("prediction cache size mismatch")
|
|
350
|
+
idx = len(self.prediction_id_to_index)
|
|
351
|
+
self.prediction_id_to_index[annotation_id] = idx
|
|
352
|
+
self.index_to_prediction_id.append(annotation_id)
|
|
353
|
+
return self.prediction_id_to_index[annotation_id]
|
|
372
354
|
|
|
373
355
|
def _add_label(self, label: str) -> int:
|
|
374
356
|
"""
|
|
@@ -384,90 +366,14 @@ class DataLoader:
|
|
|
384
366
|
int
|
|
385
367
|
Label index.
|
|
386
368
|
"""
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
self.
|
|
392
|
-
|
|
369
|
+
label_id = len(self.index_to_label)
|
|
370
|
+
if label not in self.label_to_index:
|
|
371
|
+
if len(self.label_to_index) != len(self.index_to_label):
|
|
372
|
+
raise RuntimeError("label cache size mismatch")
|
|
373
|
+
self.label_to_index[label] = label_id
|
|
374
|
+
self.index_to_label.append(label)
|
|
393
375
|
label_id += 1
|
|
394
|
-
|
|
395
|
-
return self._evaluator.label_to_index[label]
|
|
396
|
-
|
|
397
|
-
def _cache_pairs(
|
|
398
|
-
self,
|
|
399
|
-
uid_index: int,
|
|
400
|
-
groundtruths: list,
|
|
401
|
-
predictions: list,
|
|
402
|
-
ious: NDArray[np.float64],
|
|
403
|
-
) -> None:
|
|
404
|
-
"""
|
|
405
|
-
Compute IOUs between groundtruths and preditions before storing as pairs.
|
|
406
|
-
|
|
407
|
-
Parameters
|
|
408
|
-
----------
|
|
409
|
-
uid_index : int
|
|
410
|
-
The index of the detection.
|
|
411
|
-
groundtruths : list
|
|
412
|
-
A list of groundtruths.
|
|
413
|
-
predictions : list
|
|
414
|
-
A list of predictions.
|
|
415
|
-
ious : NDArray[np.float64]
|
|
416
|
-
An array with shape (n_preds, n_gts) containing IOUs.
|
|
417
|
-
"""
|
|
418
|
-
|
|
419
|
-
predictions_with_iou_of_zero = np.where((ious < 1e-9).all(axis=1))[0]
|
|
420
|
-
groundtruths_with_iou_of_zero = np.where((ious < 1e-9).all(axis=0))[0]
|
|
421
|
-
|
|
422
|
-
pairs = [
|
|
423
|
-
np.array(
|
|
424
|
-
[
|
|
425
|
-
float(uid_index),
|
|
426
|
-
float(gidx),
|
|
427
|
-
float(pidx),
|
|
428
|
-
ious[pidx, gidx],
|
|
429
|
-
float(glabel),
|
|
430
|
-
float(plabel),
|
|
431
|
-
float(score),
|
|
432
|
-
]
|
|
433
|
-
)
|
|
434
|
-
for pidx, plabel, score in predictions
|
|
435
|
-
for gidx, glabel in groundtruths
|
|
436
|
-
if ious[pidx, gidx] >= 1e-9
|
|
437
|
-
]
|
|
438
|
-
pairs.extend(
|
|
439
|
-
[
|
|
440
|
-
np.array(
|
|
441
|
-
[
|
|
442
|
-
float(uid_index),
|
|
443
|
-
-1.0,
|
|
444
|
-
float(predictions[index][0]),
|
|
445
|
-
0.0,
|
|
446
|
-
-1.0,
|
|
447
|
-
float(predictions[index][1]),
|
|
448
|
-
float(predictions[index][2]),
|
|
449
|
-
]
|
|
450
|
-
)
|
|
451
|
-
for index in predictions_with_iou_of_zero
|
|
452
|
-
]
|
|
453
|
-
)
|
|
454
|
-
pairs.extend(
|
|
455
|
-
[
|
|
456
|
-
np.array(
|
|
457
|
-
[
|
|
458
|
-
float(uid_index),
|
|
459
|
-
float(groundtruths[index][0]),
|
|
460
|
-
-1.0,
|
|
461
|
-
0.0,
|
|
462
|
-
float(groundtruths[index][1]),
|
|
463
|
-
-1.0,
|
|
464
|
-
-1.0,
|
|
465
|
-
]
|
|
466
|
-
)
|
|
467
|
-
for index in groundtruths_with_iou_of_zero
|
|
468
|
-
]
|
|
469
|
-
)
|
|
470
|
-
self.pairs.append(np.array(pairs))
|
|
376
|
+
return self.label_to_index[label]
|
|
471
377
|
|
|
472
378
|
def _add_data(
|
|
473
379
|
self,
|
|
@@ -491,66 +397,103 @@ class DataLoader:
|
|
|
491
397
|
for detection, ious in tqdm(
|
|
492
398
|
zip(detections, detection_ious), disable=disable_tqdm
|
|
493
399
|
):
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
self.
|
|
497
|
-
|
|
498
|
-
|
|
499
|
-
|
|
500
|
-
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
|
|
504
|
-
|
|
505
|
-
|
|
506
|
-
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
|
|
510
|
-
|
|
511
|
-
|
|
512
|
-
|
|
513
|
-
|
|
514
|
-
|
|
515
|
-
|
|
516
|
-
self._evaluator.groundtruth_examples[uid_index][
|
|
517
|
-
gidx
|
|
518
|
-
] = gann.extrema
|
|
519
|
-
for glabel in gann.labels:
|
|
520
|
-
label_idx = self._add_label(glabel)
|
|
521
|
-
self.groundtruth_count[label_idx][uid_index] += 1
|
|
522
|
-
groundtruths.append(
|
|
523
|
-
(
|
|
524
|
-
gidx,
|
|
525
|
-
label_idx,
|
|
400
|
+
# cache labels and annotation pairs
|
|
401
|
+
pairs = []
|
|
402
|
+
datum_idx = self._add_datum(detection.uid)
|
|
403
|
+
if detection.groundtruths:
|
|
404
|
+
for gidx, gann in enumerate(detection.groundtruths):
|
|
405
|
+
groundtruth_idx = self._add_groundtruth(gann.uid)
|
|
406
|
+
glabel_idx = self._add_label(gann.labels[0])
|
|
407
|
+
if (ious[:, gidx] < 1e-9).all():
|
|
408
|
+
pairs.extend(
|
|
409
|
+
[
|
|
410
|
+
np.array(
|
|
411
|
+
[
|
|
412
|
+
float(datum_idx),
|
|
413
|
+
float(groundtruth_idx),
|
|
414
|
+
-1.0,
|
|
415
|
+
float(glabel_idx),
|
|
416
|
+
-1.0,
|
|
417
|
+
0.0,
|
|
418
|
+
-1.0,
|
|
419
|
+
]
|
|
420
|
+
)
|
|
421
|
+
]
|
|
526
422
|
)
|
|
423
|
+
for pidx, pann in enumerate(detection.predictions):
|
|
424
|
+
prediction_idx = self._add_prediction(pann.uid)
|
|
425
|
+
if (ious[pidx, :] < 1e-9).all():
|
|
426
|
+
pairs.extend(
|
|
427
|
+
[
|
|
428
|
+
np.array(
|
|
429
|
+
[
|
|
430
|
+
float(datum_idx),
|
|
431
|
+
-1.0,
|
|
432
|
+
float(prediction_idx),
|
|
433
|
+
-1.0,
|
|
434
|
+
float(self._add_label(plabel)),
|
|
435
|
+
0.0,
|
|
436
|
+
float(pscore),
|
|
437
|
+
]
|
|
438
|
+
)
|
|
439
|
+
for plabel, pscore in zip(
|
|
440
|
+
pann.labels, pann.scores
|
|
441
|
+
)
|
|
442
|
+
]
|
|
443
|
+
)
|
|
444
|
+
if ious[pidx, gidx] >= 1e-9:
|
|
445
|
+
pairs.extend(
|
|
446
|
+
[
|
|
447
|
+
np.array(
|
|
448
|
+
[
|
|
449
|
+
float(datum_idx),
|
|
450
|
+
float(groundtruth_idx),
|
|
451
|
+
float(prediction_idx),
|
|
452
|
+
float(self._add_label(glabel)),
|
|
453
|
+
float(self._add_label(plabel)),
|
|
454
|
+
ious[pidx, gidx],
|
|
455
|
+
float(pscore),
|
|
456
|
+
]
|
|
457
|
+
)
|
|
458
|
+
for glabel in gann.labels
|
|
459
|
+
for plabel, pscore in zip(
|
|
460
|
+
pann.labels, pann.scores
|
|
461
|
+
)
|
|
462
|
+
]
|
|
463
|
+
)
|
|
464
|
+
elif detection.predictions:
|
|
465
|
+
for pidx, pann in enumerate(detection.predictions):
|
|
466
|
+
prediction_idx = self._add_prediction(pann.uid)
|
|
467
|
+
pairs.extend(
|
|
468
|
+
[
|
|
469
|
+
np.array(
|
|
470
|
+
[
|
|
471
|
+
float(datum_idx),
|
|
472
|
+
-1.0,
|
|
473
|
+
float(prediction_idx),
|
|
474
|
+
-1.0,
|
|
475
|
+
float(self._add_label(plabel)),
|
|
476
|
+
0.0,
|
|
477
|
+
float(pscore),
|
|
478
|
+
]
|
|
479
|
+
)
|
|
480
|
+
for plabel, pscore in zip(pann.labels, pann.scores)
|
|
481
|
+
]
|
|
527
482
|
)
|
|
528
483
|
|
|
529
|
-
|
|
530
|
-
|
|
531
|
-
|
|
532
|
-
|
|
533
|
-
|
|
534
|
-
|
|
535
|
-
|
|
536
|
-
predictions.append(
|
|
537
|
-
(
|
|
538
|
-
pidx,
|
|
539
|
-
label_idx,
|
|
540
|
-
pscore,
|
|
541
|
-
)
|
|
484
|
+
data = np.array(pairs)
|
|
485
|
+
if data.size > 0:
|
|
486
|
+
# reset filtered cache if it exists
|
|
487
|
+
self.clear_filter()
|
|
488
|
+
if self._temp_cache is None:
|
|
489
|
+
raise RuntimeError(
|
|
490
|
+
"cannot add data as evaluator has already been finalized"
|
|
542
491
|
)
|
|
543
|
-
|
|
544
|
-
self._cache_pairs(
|
|
545
|
-
uid_index=uid_index,
|
|
546
|
-
groundtruths=groundtruths,
|
|
547
|
-
predictions=predictions,
|
|
548
|
-
ious=ious,
|
|
549
|
-
)
|
|
492
|
+
self._temp_cache.append(data)
|
|
550
493
|
|
|
551
494
|
def add_bounding_boxes(
|
|
552
495
|
self,
|
|
553
|
-
detections: list[Detection],
|
|
496
|
+
detections: list[Detection[BoundingBox]],
|
|
554
497
|
show_progress: bool = False,
|
|
555
498
|
):
|
|
556
499
|
"""
|
|
@@ -584,7 +527,7 @@ class DataLoader:
|
|
|
584
527
|
|
|
585
528
|
def add_polygons(
|
|
586
529
|
self,
|
|
587
|
-
detections: list[Detection],
|
|
530
|
+
detections: list[Detection[Polygon]],
|
|
588
531
|
show_progress: bool = False,
|
|
589
532
|
):
|
|
590
533
|
"""
|
|
@@ -601,7 +544,7 @@ class DataLoader:
|
|
|
601
544
|
compute_polygon_iou(
|
|
602
545
|
np.array(
|
|
603
546
|
[
|
|
604
|
-
[gt.shape, pd.shape]
|
|
547
|
+
[gt.shape, pd.shape]
|
|
605
548
|
for pd in detection.predictions
|
|
606
549
|
for gt in detection.groundtruths
|
|
607
550
|
]
|
|
@@ -617,7 +560,7 @@ class DataLoader:
|
|
|
617
560
|
|
|
618
561
|
def add_bitmasks(
|
|
619
562
|
self,
|
|
620
|
-
detections: list[Detection],
|
|
563
|
+
detections: list[Detection[Bitmask]],
|
|
621
564
|
show_progress: bool = False,
|
|
622
565
|
):
|
|
623
566
|
"""
|
|
@@ -634,7 +577,7 @@ class DataLoader:
|
|
|
634
577
|
compute_bitmask_iou(
|
|
635
578
|
np.array(
|
|
636
579
|
[
|
|
637
|
-
[gt.mask, pd.mask]
|
|
580
|
+
[gt.mask, pd.mask]
|
|
638
581
|
for pd in detection.predictions
|
|
639
582
|
for gt in detection.groundtruths
|
|
640
583
|
]
|
|
@@ -648,7 +591,7 @@ class DataLoader:
|
|
|
648
591
|
show_progress=show_progress,
|
|
649
592
|
)
|
|
650
593
|
|
|
651
|
-
def finalize(self)
|
|
594
|
+
def finalize(self):
|
|
652
595
|
"""
|
|
653
596
|
Performs data finalization and some preprocessing steps.
|
|
654
597
|
|
|
@@ -657,65 +600,206 @@ class DataLoader:
|
|
|
657
600
|
Evaluator
|
|
658
601
|
A ready-to-use evaluator object.
|
|
659
602
|
"""
|
|
603
|
+
if self._temp_cache is None:
|
|
604
|
+
warnings.warn("evaluator is already finalized or in a bad state")
|
|
605
|
+
return self
|
|
606
|
+
elif not self._temp_cache:
|
|
607
|
+
self._detailed_pairs = np.array([], dtype=np.float64)
|
|
608
|
+
self._ranked_pairs = np.array([], dtype=np.float64)
|
|
609
|
+
self._label_metadata = np.zeros((self.n_labels, 2), dtype=np.int32)
|
|
610
|
+
warnings.warn("no valid pairs")
|
|
611
|
+
return self
|
|
612
|
+
else:
|
|
613
|
+
self._detailed_pairs = np.concatenate(self._temp_cache, axis=0)
|
|
614
|
+
self._temp_cache = None
|
|
615
|
+
|
|
616
|
+
# order pairs by descending score, iou
|
|
617
|
+
indices = np.lexsort(
|
|
618
|
+
(
|
|
619
|
+
-self._detailed_pairs[:, 5], # iou
|
|
620
|
+
-self._detailed_pairs[:, 6], # score
|
|
621
|
+
)
|
|
622
|
+
)
|
|
623
|
+
self._detailed_pairs = self._detailed_pairs[indices]
|
|
624
|
+
self._label_metadata = compute_label_metadata(
|
|
625
|
+
ids=self._detailed_pairs[:, :5].astype(np.int32),
|
|
626
|
+
n_labels=self.n_labels,
|
|
627
|
+
)
|
|
628
|
+
self._ranked_pairs = rank_pairs(
|
|
629
|
+
detailed_pairs=self.detailed_pairs,
|
|
630
|
+
label_metadata=self._label_metadata,
|
|
631
|
+
)
|
|
632
|
+
return self
|
|
660
633
|
|
|
661
|
-
|
|
662
|
-
|
|
663
|
-
|
|
664
|
-
|
|
665
|
-
|
|
666
|
-
|
|
634
|
+
def apply_filter(
|
|
635
|
+
self,
|
|
636
|
+
datum_ids: list[str] | None = None,
|
|
637
|
+
groundtruth_ids: list[str] | None = None,
|
|
638
|
+
prediction_ids: list[str] | None = None,
|
|
639
|
+
labels: list[str] | None = None,
|
|
640
|
+
):
|
|
641
|
+
"""
|
|
642
|
+
Apply a filter on the evaluator.
|
|
667
643
|
|
|
668
|
-
|
|
644
|
+
Can be reset by calling 'clear_filter'.
|
|
669
645
|
|
|
670
|
-
|
|
671
|
-
|
|
646
|
+
Parameters
|
|
647
|
+
----------
|
|
648
|
+
datum_uids : list[str], optional
|
|
649
|
+
An optional list of string uids representing datums.
|
|
650
|
+
groundtruth_ids : list[str], optional
|
|
651
|
+
An optional list of string uids representing ground truth annotations.
|
|
652
|
+
prediction_ids : list[str], optional
|
|
653
|
+
An optional list of string uids representing prediction annotations.
|
|
654
|
+
labels : list[str], optional
|
|
655
|
+
An optional list of labels.
|
|
656
|
+
"""
|
|
657
|
+
self._filtered_detailed_pairs = self._detailed_pairs.copy()
|
|
658
|
+
self._filtered_ranked_pairs = np.array([], dtype=np.float64)
|
|
659
|
+
self._filtered_label_metadata = np.zeros(
|
|
660
|
+
(self.n_labels, 2), dtype=np.int32
|
|
672
661
|
)
|
|
673
|
-
|
|
674
|
-
|
|
675
|
-
|
|
676
|
-
|
|
677
|
-
|
|
678
|
-
|
|
662
|
+
|
|
663
|
+
valid_datum_indices = None
|
|
664
|
+
if datum_ids is not None:
|
|
665
|
+
if not datum_ids:
|
|
666
|
+
self._filtered_detailed_pairs = np.array([], dtype=np.float64)
|
|
667
|
+
warnings.warn("no valid filtered pairs")
|
|
668
|
+
return
|
|
669
|
+
valid_datum_indices = np.array(
|
|
670
|
+
[self.datum_id_to_index[uid] for uid in datum_ids],
|
|
671
|
+
dtype=np.int32,
|
|
672
|
+
)
|
|
673
|
+
|
|
674
|
+
valid_groundtruth_indices = None
|
|
675
|
+
if groundtruth_ids is not None:
|
|
676
|
+
valid_groundtruth_indices = np.array(
|
|
677
|
+
[self.groundtruth_id_to_index[uid] for uid in groundtruth_ids],
|
|
678
|
+
dtype=np.int32,
|
|
679
|
+
)
|
|
680
|
+
|
|
681
|
+
valid_prediction_indices = None
|
|
682
|
+
if prediction_ids is not None:
|
|
683
|
+
valid_prediction_indices = np.array(
|
|
684
|
+
[self.prediction_id_to_index[uid] for uid in prediction_ids],
|
|
685
|
+
dtype=np.int32,
|
|
686
|
+
)
|
|
687
|
+
|
|
688
|
+
valid_label_indices = None
|
|
689
|
+
if labels is not None:
|
|
690
|
+
if not labels:
|
|
691
|
+
self._filtered_detailed_pairs = np.array([], dtype=np.float64)
|
|
692
|
+
warnings.warn("no valid filtered pairs")
|
|
693
|
+
return
|
|
694
|
+
valid_label_indices = np.array(
|
|
695
|
+
[self.label_to_index[label] for label in labels] + [-1]
|
|
696
|
+
)
|
|
697
|
+
|
|
698
|
+
# filter datums
|
|
699
|
+
if valid_datum_indices is not None:
|
|
700
|
+
mask_valid_datums = np.isin(
|
|
701
|
+
self._filtered_detailed_pairs[:, 0], valid_datum_indices
|
|
702
|
+
)
|
|
703
|
+
self._filtered_detailed_pairs = self._filtered_detailed_pairs[
|
|
704
|
+
mask_valid_datums
|
|
705
|
+
]
|
|
706
|
+
|
|
707
|
+
n_rows = self._filtered_detailed_pairs.shape[0]
|
|
708
|
+
mask_invalid_groundtruths = np.zeros(n_rows, dtype=np.bool_)
|
|
709
|
+
mask_invalid_predictions = np.zeros_like(mask_invalid_groundtruths)
|
|
710
|
+
|
|
711
|
+
# filter ground truth annotations
|
|
712
|
+
if valid_groundtruth_indices is not None:
|
|
713
|
+
mask_invalid_groundtruths[
|
|
714
|
+
~np.isin(
|
|
715
|
+
self._filtered_detailed_pairs[:, 1],
|
|
716
|
+
valid_groundtruth_indices,
|
|
679
717
|
)
|
|
680
|
-
|
|
681
|
-
|
|
682
|
-
|
|
683
|
-
|
|
718
|
+
] = True
|
|
719
|
+
|
|
720
|
+
# filter prediction annotations
|
|
721
|
+
if valid_prediction_indices is not None:
|
|
722
|
+
mask_invalid_predictions[
|
|
723
|
+
~np.isin(
|
|
724
|
+
self._filtered_detailed_pairs[:, 2],
|
|
725
|
+
valid_prediction_indices,
|
|
684
726
|
)
|
|
685
|
-
|
|
686
|
-
|
|
687
|
-
|
|
688
|
-
|
|
689
|
-
|
|
690
|
-
|
|
691
|
-
|
|
692
|
-
|
|
693
|
-
|
|
694
|
-
|
|
695
|
-
|
|
696
|
-
|
|
697
|
-
|
|
698
|
-
|
|
699
|
-
|
|
700
|
-
|
|
701
|
-
|
|
702
|
-
|
|
703
|
-
|
|
704
|
-
)
|
|
705
|
-
),
|
|
706
|
-
]
|
|
707
|
-
for label_idx in range(n_labels)
|
|
727
|
+
] = True
|
|
728
|
+
|
|
729
|
+
# filter labels
|
|
730
|
+
if valid_label_indices is not None:
|
|
731
|
+
mask_invalid_groundtruths[
|
|
732
|
+
~np.isin(
|
|
733
|
+
self._filtered_detailed_pairs[:, 3], valid_label_indices
|
|
734
|
+
)
|
|
735
|
+
] = True
|
|
736
|
+
mask_invalid_predictions[
|
|
737
|
+
~np.isin(
|
|
738
|
+
self._filtered_detailed_pairs[:, 4], valid_label_indices
|
|
739
|
+
)
|
|
740
|
+
] = True
|
|
741
|
+
|
|
742
|
+
# filter cache
|
|
743
|
+
if mask_invalid_groundtruths.any():
|
|
744
|
+
invalid_groundtruth_indices = np.where(mask_invalid_groundtruths)[
|
|
745
|
+
0
|
|
708
746
|
]
|
|
747
|
+
self._filtered_detailed_pairs[
|
|
748
|
+
invalid_groundtruth_indices[:, None], (1, 3, 5)
|
|
749
|
+
] = np.array([[-1, -1, 0]])
|
|
750
|
+
|
|
751
|
+
if mask_invalid_predictions.any():
|
|
752
|
+
invalid_prediction_indices = np.where(mask_invalid_predictions)[0]
|
|
753
|
+
self._filtered_detailed_pairs[
|
|
754
|
+
invalid_prediction_indices[:, None], (2, 4, 5, 6)
|
|
755
|
+
] = np.array([[-1, -1, 0, -1]])
|
|
756
|
+
|
|
757
|
+
# filter null pairs
|
|
758
|
+
mask_null_pairs = np.all(
|
|
759
|
+
np.isclose(
|
|
760
|
+
self._filtered_detailed_pairs[:, 1:5],
|
|
761
|
+
np.array([-1.0, -1.0, -1.0, -1.0]),
|
|
762
|
+
),
|
|
763
|
+
axis=1,
|
|
709
764
|
)
|
|
765
|
+
self._filtered_detailed_pairs = self._filtered_detailed_pairs[
|
|
766
|
+
~mask_null_pairs
|
|
767
|
+
]
|
|
710
768
|
|
|
711
|
-
self.
|
|
712
|
-
self.
|
|
713
|
-
|
|
769
|
+
if self._filtered_detailed_pairs.size == 0:
|
|
770
|
+
self._ranked_pairs = np.array([], dtype=np.float64)
|
|
771
|
+
self._label_metadata = np.zeros((self.n_labels, 2), dtype=np.int32)
|
|
772
|
+
warnings.warn("no valid filtered pairs")
|
|
773
|
+
return
|
|
774
|
+
|
|
775
|
+
# sorts by score, iou with ground truth id as a tie-breaker
|
|
776
|
+
indices = np.lexsort(
|
|
777
|
+
(
|
|
778
|
+
self._filtered_detailed_pairs[:, 1], # ground truth id
|
|
779
|
+
-self._filtered_detailed_pairs[:, 5], # iou
|
|
780
|
+
-self._filtered_detailed_pairs[:, 6], # score
|
|
781
|
+
)
|
|
714
782
|
)
|
|
715
|
-
|
|
716
|
-
self.
|
|
717
|
-
self.
|
|
718
|
-
|
|
783
|
+
self._filtered_detailed_pairs = self._filtered_detailed_pairs[indices]
|
|
784
|
+
self._filtered_label_metadata = compute_label_metadata(
|
|
785
|
+
ids=self._filtered_detailed_pairs[:, :5].astype(np.int32),
|
|
786
|
+
n_labels=self.n_labels,
|
|
719
787
|
)
|
|
788
|
+
self._filtered_ranked_pairs = rank_pairs(
|
|
789
|
+
detailed_pairs=self._filtered_detailed_pairs,
|
|
790
|
+
label_metadata=self._filtered_label_metadata,
|
|
791
|
+
)
|
|
792
|
+
|
|
793
|
+
def clear_filter(self):
|
|
794
|
+
"""Removes a filter if one exists."""
|
|
795
|
+
self._filtered_detailed_pairs = None
|
|
796
|
+
self._filtered_ranked_pairs = None
|
|
797
|
+
self._filtered_label_metadata = None
|
|
798
|
+
|
|
799
|
+
|
|
800
|
+
class DataLoader(Evaluator):
|
|
801
|
+
"""
|
|
802
|
+
Used for backwards compatibility as the Evaluator now handles ingestion.
|
|
803
|
+
"""
|
|
720
804
|
|
|
721
|
-
|
|
805
|
+
pass
|