valor-lite 0.36.0__tar.gz → 0.36.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {valor_lite-0.36.0 → valor_lite-0.36.2}/PKG-INFO +1 -1
- {valor_lite-0.36.0 → valor_lite-0.36.2}/valor_lite/classification/__init__.py +3 -1
- {valor_lite-0.36.0 → valor_lite-0.36.2}/valor_lite/classification/manager.py +14 -21
- valor_lite-0.36.2/valor_lite/exceptions.py +15 -0
- {valor_lite-0.36.0 → valor_lite-0.36.2}/valor_lite/object_detection/__init__.py +2 -1
- {valor_lite-0.36.0 → valor_lite-0.36.2}/valor_lite/object_detection/computation.py +0 -18
- {valor_lite-0.36.0 → valor_lite-0.36.2}/valor_lite/object_detection/manager.py +407 -418
- {valor_lite-0.36.0 → valor_lite-0.36.2}/valor_lite/semantic_segmentation/__init__.py +2 -1
- {valor_lite-0.36.0 → valor_lite-0.36.2}/valor_lite/semantic_segmentation/manager.py +13 -18
- {valor_lite-0.36.0 → valor_lite-0.36.2}/valor_lite.egg-info/PKG-INFO +1 -1
- {valor_lite-0.36.0 → valor_lite-0.36.2}/valor_lite.egg-info/SOURCES.txt +1 -0
- {valor_lite-0.36.0 → valor_lite-0.36.2}/README.md +0 -0
- {valor_lite-0.36.0 → valor_lite-0.36.2}/pyproject.toml +0 -0
- {valor_lite-0.36.0 → valor_lite-0.36.2}/setup.cfg +0 -0
- {valor_lite-0.36.0 → valor_lite-0.36.2}/valor_lite/LICENSE +0 -0
- {valor_lite-0.36.0 → valor_lite-0.36.2}/valor_lite/__init__.py +0 -0
- {valor_lite-0.36.0 → valor_lite-0.36.2}/valor_lite/classification/annotation.py +0 -0
- {valor_lite-0.36.0 → valor_lite-0.36.2}/valor_lite/classification/computation.py +0 -0
- {valor_lite-0.36.0 → valor_lite-0.36.2}/valor_lite/classification/metric.py +0 -0
- {valor_lite-0.36.0 → valor_lite-0.36.2}/valor_lite/classification/numpy_compatibility.py +0 -0
- {valor_lite-0.36.0 → valor_lite-0.36.2}/valor_lite/classification/utilities.py +0 -0
- {valor_lite-0.36.0 → valor_lite-0.36.2}/valor_lite/object_detection/annotation.py +0 -0
- {valor_lite-0.36.0 → valor_lite-0.36.2}/valor_lite/object_detection/metric.py +0 -0
- {valor_lite-0.36.0 → valor_lite-0.36.2}/valor_lite/object_detection/utilities.py +0 -0
- {valor_lite-0.36.0 → valor_lite-0.36.2}/valor_lite/profiling.py +0 -0
- {valor_lite-0.36.0 → valor_lite-0.36.2}/valor_lite/schemas.py +0 -0
- {valor_lite-0.36.0 → valor_lite-0.36.2}/valor_lite/semantic_segmentation/annotation.py +0 -0
- {valor_lite-0.36.0 → valor_lite-0.36.2}/valor_lite/semantic_segmentation/benchmark.py +0 -0
- {valor_lite-0.36.0 → valor_lite-0.36.2}/valor_lite/semantic_segmentation/computation.py +0 -0
- {valor_lite-0.36.0 → valor_lite-0.36.2}/valor_lite/semantic_segmentation/metric.py +0 -0
- {valor_lite-0.36.0 → valor_lite-0.36.2}/valor_lite/semantic_segmentation/utilities.py +0 -0
- {valor_lite-0.36.0 → valor_lite-0.36.2}/valor_lite/text_generation/__init__.py +0 -0
- {valor_lite-0.36.0 → valor_lite-0.36.2}/valor_lite/text_generation/annotation.py +0 -0
- {valor_lite-0.36.0 → valor_lite-0.36.2}/valor_lite/text_generation/computation.py +0 -0
- {valor_lite-0.36.0 → valor_lite-0.36.2}/valor_lite/text_generation/llm/__init__.py +0 -0
- {valor_lite-0.36.0 → valor_lite-0.36.2}/valor_lite/text_generation/llm/exceptions.py +0 -0
- {valor_lite-0.36.0 → valor_lite-0.36.2}/valor_lite/text_generation/llm/generation.py +0 -0
- {valor_lite-0.36.0 → valor_lite-0.36.2}/valor_lite/text_generation/llm/instructions.py +0 -0
- {valor_lite-0.36.0 → valor_lite-0.36.2}/valor_lite/text_generation/llm/integrations.py +0 -0
- {valor_lite-0.36.0 → valor_lite-0.36.2}/valor_lite/text_generation/llm/utilities.py +0 -0
- {valor_lite-0.36.0 → valor_lite-0.36.2}/valor_lite/text_generation/llm/validators.py +0 -0
- {valor_lite-0.36.0 → valor_lite-0.36.2}/valor_lite/text_generation/manager.py +0 -0
- {valor_lite-0.36.0 → valor_lite-0.36.2}/valor_lite/text_generation/metric.py +0 -0
- {valor_lite-0.36.0 → valor_lite-0.36.2}/valor_lite.egg-info/dependency_links.txt +0 -0
- {valor_lite-0.36.0 → valor_lite-0.36.2}/valor_lite.egg-info/requires.txt +0 -0
- {valor_lite-0.36.0 → valor_lite-0.36.2}/valor_lite.egg-info/top_level.txt +0 -0
valor_lite/classification/__init__.py

@@ -3,7 +3,7 @@ from .computation import (
     compute_confusion_matrix,
     compute_precision_recall_rocauc,
 )
-from .manager import DataLoader, Evaluator
+from .manager import DataLoader, Evaluator, Filter, Metadata
 from .metric import Metric, MetricType

 __all__ = [
@@ -14,4 +14,6 @@ __all__ = [
     "DataLoader",
     "Evaluator",
     "Metric",
+    "Metadata",
+    "Filter",
 ]
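The classification subpackage now re-exports Filter and Metadata at the top level. A minimal import sketch of the new surface (the names are confirmed by the __all__ change above):

from valor_lite.classification import DataLoader, Evaluator, Filter, Metadata  # Filter, Metadata new in 0.36.2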
valor_lite/classification/manager.py

@@ -1,4 +1,3 @@
-import warnings
 from dataclasses import asdict, dataclass

 import numpy as np
@@ -17,6 +16,7 @@ from valor_lite.classification.utilities import (
     unpack_confusion_matrix_into_metric_list,
     unpack_precision_recall_rocauc_into_metric_lists,
 )
+from valor_lite.exceptions import EmptyEvaluatorException, EmptyFilterException

 """
 Usage
@@ -85,6 +85,18 @@ class Filter:
     valid_label_indices: NDArray[np.int32] | None
     metadata: Metadata

+    def __post_init__(self):
+        # validate datum mask
+        if not self.datum_mask.any():
+            raise EmptyFilterException("filter removes all datums")
+
+        # validate label indices
+        if (
+            self.valid_label_indices is not None
+            and self.valid_label_indices.size == 0
+        ):
+            raise EmptyFilterException("filter removes all labels")
+

 class Evaluator:
     """
@@ -155,7 +167,6 @@ class Evaluator:
         datum_mask = np.ones(n_pairs, dtype=np.bool_)
         if datum_ids is not None:
             if not datum_ids:
-                warnings.warn("no valid filtered pairs")
                 return Filter(
                     datum_mask=np.zeros_like(datum_mask),
                     valid_label_indices=None,
@@ -173,7 +184,6 @@ class Evaluator:
         valid_label_indices = None
         if labels is not None:
             if not labels:
-                warnings.warn("no valid filtered pairs")
                 return Filter(
                     datum_mask=datum_mask,
                     valid_label_indices=np.array([], dtype=np.int32),
@@ -224,21 +234,6 @@ class Evaluator:
         NDArray[int32]
             The filtered label metadata.
         """
-        empty_datum_mask = not filter_.datum_mask.any()
-        empty_label_mask = (
-            filter_.valid_label_indices.size == 0
-            if filter_.valid_label_indices is not None
-            else False
-        )
-        if empty_datum_mask or empty_label_mask:
-            if empty_datum_mask:
-                warnings.warn("filter removes all datums")
-            if empty_label_mask:
-                warnings.warn("filter removes all labels")
-            return (
-                np.array([], dtype=np.float64),
-                np.zeros((self.metadata.number_of_labels, 2), dtype=np.int32),
-            )
         return filter_cache(
             detailed_pairs=self._detailed_pairs,
             datum_mask=filter_.datum_mask,
@@ -502,9 +497,7 @@ class Evaluator:
             A ready-to-use evaluator object.
         """
         if self._detailed_pairs.size == 0:
-
-            warnings.warn("evaluator is empty")
-            return self
+            raise EmptyEvaluatorException()

         self._label_metadata = compute_label_metadata(
             ids=self._detailed_pairs[:, :3].astype(np.int32),
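The classification module now fails loudly instead of warning: an all-removing Filter raises at construction (the __post_init__ above) and finalizing an evaluator with no data raises EmptyEvaluatorException. A hedged sketch of the new failure mode; the finalize() entry point is assumed from earlier valor-lite releases rather than shown in this hunk.

from valor_lite.classification import DataLoader
from valor_lite.exceptions import EmptyEvaluatorException

loader = DataLoader()
# (no classifications added)
try:
    evaluator = loader.finalize()  # assumed ingestion entry point
except EmptyEvaluatorException as err:
    print(err)  # "evaluator cannot be finalized as it contains no data"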
valor_lite/exceptions.py (new file)

@@ -0,0 +1,15 @@
+class EmptyEvaluatorException(Exception):
+    def __init__(self):
+        super().__init__(
+            "evaluator cannot be finalized as it contains no data"
+        )
+
+
+class EmptyFilterException(Exception):
+    def __init__(self, message: str):
+        super().__init__(message)
+
+
+class InternalCacheException(Exception):
+    def __init__(self, message: str):
+        super().__init__(message)
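The new module centralizes the exception types that replace the removed warnings. A small sketch of their behavior, taken directly from the definitions above:

from valor_lite.exceptions import EmptyEvaluatorException, EmptyFilterException

try:
    raise EmptyEvaluatorException()
except EmptyEvaluatorException as err:
    assert "contains no data" in str(err)            # fixed message

try:
    raise EmptyFilterException("filter removes all labels")
except EmptyFilterException as err:
    assert str(err) == "filter removes all labels"   # caller-supplied message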
valor_lite/object_detection/__init__.py

@@ -1,5 +1,5 @@
 from .annotation import Bitmask, BoundingBox, Detection, Polygon
-from .manager import DataLoader, Evaluator, Filter
+from .manager import DataLoader, Evaluator, Filter, Metadata
 from .metric import Metric, MetricType

 __all__ = [
@@ -12,4 +12,5 @@ __all__ = [
     "DataLoader",
     "Evaluator",
     "Filter",
+    "Metadata",
 ]
valor_lite/object_detection/computation.py

@@ -1,4 +1,3 @@
-import warnings
 from enum import IntFlag, auto

 import numpy as np
@@ -280,14 +279,6 @@ def filter_cache(
     )
     detailed_pairs = detailed_pairs[~mask_null_pairs]

-    if detailed_pairs.size == 0:
-        warnings.warn("no valid filtered pairs")
-        return (
-            np.array([], dtype=np.float64),
-            np.array([], dtype=np.float64),
-            np.zeros((n_labels, 2), dtype=np.int32),
-        )
-
     # sorts by score, iou with ground truth id as a tie-breaker
     indices = np.lexsort(
         (
@@ -441,15 +432,6 @@ def compute_precion_recall(
     counts = np.zeros((n_ious, n_scores, n_labels, 6), dtype=np.float64)
     pr_curve = np.zeros((n_ious, n_labels, 101, 2))

-    if ranked_pairs.size == 0:
-        warnings.warn("no valid ranked pairs")
-        return (
-            (average_precision, mAP),
-            (average_recall, mAR),
-            counts,
-            pr_curve,
-        )
-
     # start computation
     ids = ranked_pairs[:, :5].astype(np.int32)
     gt_ids = ids[:, 1]
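The computation kernels no longer short-circuit on empty inputs; validation moves upstream to Filter.__post_init__ in manager.py (next file). A minimal numpy sketch of the assumption behind dropping the guards: the sort that follows tolerates empty key arrays, so only the warning is lost.

import numpy as np

detailed_pairs = np.zeros((0, 7), dtype=np.float64)                   # empty pair cache
indices = np.lexsort((-detailed_pairs[:, 5], -detailed_pairs[:, 6]))  # lexsort handles empty keys
assert indices.size == 0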
valor_lite/object_detection/manager.py

@@ -5,6 +5,11 @@ import numpy as np
 from numpy.typing import NDArray
 from tqdm import tqdm

+from valor_lite.exceptions import (
+    EmptyEvaluatorException,
+    EmptyFilterException,
+    InternalCacheException,
+)
 from valor_lite.object_detection.annotation import (
     Bitmask,
     BoundingBox,
@@ -94,6 +99,21 @@ class Filter:
     mask_predictions: NDArray[np.bool_]
     metadata: Metadata

+    def __post_init__(self):
+        # validate datums mask
+        if not self.mask_datums.any():
+            raise EmptyFilterException("filter removes all datums")
+
+        # validate annotation masks
+        no_gts = self.mask_groundtruths.all()
+        no_pds = self.mask_predictions.all()
+        if no_gts and no_pds:
+            raise EmptyFilterException("filter removes all annotations")
+        elif no_gts:
+            warnings.warn("filter removes all ground truths")
+        elif no_pds:
+            warnings.warn("filter removes all predictions")
+

 class Evaluator:
     """
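With the added __post_init__, an all-removing object-detection filter now fails at construction time instead of producing warnings downstream. A hedged sketch; it assumes Metadata() can still be default-constructed, as the pre-0.36.2 code relied on.

import numpy as np
from valor_lite.object_detection import Filter, Metadata
from valor_lite.exceptions import EmptyFilterException

try:
    Filter(
        mask_datums=np.zeros(4, dtype=np.bool_),        # keeps no datums
        mask_groundtruths=np.zeros(4, dtype=np.bool_),
        mask_predictions=np.zeros(4, dtype=np.bool_),
        metadata=Metadata(),                            # assumed default-constructible
    )
except EmptyFilterException as err:
    print(err)  # "filter removes all datums"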
@@ -151,123 +171,407 @@ class Evaluator:
         """
         return self._metadata

-    def
+    def create_filter(
+        self,
+        datum_ids: list[str] | None = None,
+        groundtruth_ids: list[str] | None = None,
+        prediction_ids: list[str] | None = None,
+        labels: list[str] | None = None,
+    ) -> Filter:
         """
-
+        Creates a filter object.

         Parameters
         ----------
-
-
-
-
-
-
-
+        datum_uids : list[str], optional
+            An optional list of string uids representing datums to keep.
+        groundtruth_ids : list[str], optional
+            An optional list of string uids representing ground truth annotations to keep.
+        prediction_ids : list[str], optional
+            An optional list of string uids representing prediction annotations to keep.
+        labels : list[str], optional
+            An optional list of labels to keep.
         """
-
-        if len(self.datum_id_to_index) != len(self.index_to_datum_id):
-            raise RuntimeError("datum cache size mismatch")
-        idx = len(self.datum_id_to_index)
-        self.datum_id_to_index[datum_id] = idx
-        self.index_to_datum_id.append(datum_id)
-        return self.datum_id_to_index[datum_id]
+        mask_datums = np.ones(self._detailed_pairs.shape[0], dtype=np.bool_)

-
+        # filter datums
+        if datum_ids is not None:
+            if not datum_ids:
+                raise EmptyFilterException("filter removes all datums")
+            valid_datum_indices = np.array(
+                [self.datum_id_to_index[uid] for uid in datum_ids],
+                dtype=np.int32,
+            )
+            mask_datums = np.isin(
+                self._detailed_pairs[:, 0], valid_datum_indices
+            )
+
+        filtered_detailed_pairs = self._detailed_pairs[mask_datums]
+        n_pairs = self._detailed_pairs[mask_datums].shape[0]
+        mask_groundtruths = np.zeros(n_pairs, dtype=np.bool_)
+        mask_predictions = np.zeros_like(mask_groundtruths)
+
+        # filter by ground truth annotation ids
+        if groundtruth_ids is not None:
+            valid_groundtruth_indices = np.array(
+                [self.groundtruth_id_to_index[uid] for uid in groundtruth_ids],
+                dtype=np.int32,
+            )
+            mask_groundtruths[
+                ~np.isin(
+                    filtered_detailed_pairs[:, 1],
+                    valid_groundtruth_indices,
+                )
+            ] = True
+
+        # filter by prediction annotation ids
+        if prediction_ids is not None:
+            valid_prediction_indices = np.array(
+                [self.prediction_id_to_index[uid] for uid in prediction_ids],
+                dtype=np.int32,
+            )
+            mask_predictions[
+                ~np.isin(
+                    filtered_detailed_pairs[:, 2],
+                    valid_prediction_indices,
+                )
+            ] = True
+
+        # filter by labels
+        if labels is not None:
+            if not labels:
+                raise EmptyFilterException("filter removes all labels")
+            valid_label_indices = np.array(
+                [self.label_to_index[label] for label in labels] + [-1]
+            )
+            mask_groundtruths[
+                ~np.isin(filtered_detailed_pairs[:, 3], valid_label_indices)
+            ] = True
+            mask_predictions[
+                ~np.isin(filtered_detailed_pairs[:, 4], valid_label_indices)
+            ] = True
+
+        filtered_detailed_pairs, _, _ = filter_cache(
+            self._detailed_pairs,
+            mask_datums=mask_datums,
+            mask_ground_truths=mask_groundtruths,
+            mask_predictions=mask_predictions,
+            n_labels=len(self.index_to_label),
+        )
+
+        number_of_datums = (
+            len(datum_ids)
+            if datum_ids
+            else np.unique(filtered_detailed_pairs[:, 0]).size
+        )
+
+        return Filter(
+            mask_datums=mask_datums,
+            mask_groundtruths=mask_groundtruths,
+            mask_predictions=mask_predictions,
+            metadata=Metadata.create(
+                detailed_pairs=filtered_detailed_pairs,
+                number_of_datums=number_of_datums,
+                number_of_labels=len(self.index_to_label),
+            ),
+        )
+
+    def filter(
+        self, filter_: Filter
+    ) -> tuple[NDArray[np.float64], NDArray[np.float64], NDArray[np.int32],]:
         """
-
+        Performs filtering over the internal cache.

         Parameters
         ----------
-
-        The
+        filter_ : Filter
+            The filter parameterization.

         Returns
         -------
-
-
+        NDArray[float64]
+            Filtered detailed pairs.
+        NDArray[float64]
+            Filtered ranked pairs.
+        NDArray[int32]
+            Label metadata.
         """
-
-
-
-
-
-
-
-        self.index_to_groundtruth_id.append(annotation_id)
-        return self.groundtruth_id_to_index[annotation_id]
+        return filter_cache(
+            detailed_pairs=self._detailed_pairs,
+            mask_datums=filter_.mask_datums,
+            mask_ground_truths=filter_.mask_groundtruths,
+            mask_predictions=filter_.mask_predictions,
+            n_labels=len(self.index_to_label),
+        )

-    def
+    def compute_precision_recall(
+        self,
+        iou_thresholds: list[float],
+        score_thresholds: list[float],
+        filter_: Filter | None = None,
+    ) -> dict[MetricType, list[Metric]]:
         """
-
+        Computes all metrics except for ConfusionMatrix

         Parameters
         ----------
-
-
+        iou_thresholds : list[float]
+            A list of IOU thresholds to compute metrics over.
+        score_thresholds : list[float]
+            A list of score thresholds to compute metrics over.
+        filter_ : Filter, optional
+            A collection of filter parameters and masks.

         Returns
         -------
-
-
+        dict[MetricType, list]
+            A dictionary mapping MetricType enumerations to lists of computed metrics.
         """
-        if
-
-
-        )
-            raise RuntimeError("prediction cache size mismatch")
-        idx = len(self.prediction_id_to_index)
-        self.prediction_id_to_index[annotation_id] = idx
-        self.index_to_prediction_id.append(annotation_id)
-        return self.prediction_id_to_index[annotation_id]
+        if not iou_thresholds:
+            raise ValueError("At least one IOU threshold must be passed.")
+        elif not score_thresholds:
+            raise ValueError("At least one score threshold must be passed.")

-
+        if filter_ is not None:
+            _, ranked_pairs, label_metadata = self.filter(filter_=filter_)
+        else:
+            ranked_pairs = self._ranked_pairs
+            label_metadata = self._label_metadata
+
+        results = compute_precion_recall(
+            ranked_pairs=ranked_pairs,
+            label_metadata=label_metadata,
+            iou_thresholds=np.array(iou_thresholds),
+            score_thresholds=np.array(score_thresholds),
+        )
+        return unpack_precision_recall_into_metric_lists(
+            results=results,
+            label_metadata=label_metadata,
+            iou_thresholds=iou_thresholds,
+            score_thresholds=score_thresholds,
+            index_to_label=self.index_to_label,
+        )
+
+    def compute_confusion_matrix(
+        self,
+        iou_thresholds: list[float],
+        score_thresholds: list[float],
+        filter_: Filter | None = None,
+    ) -> list[Metric]:
         """
-
+        Computes confusion matrices at various thresholds.

         Parameters
         ----------
-
-
+        iou_thresholds : list[float]
+            A list of IOU thresholds to compute metrics over.
+        score_thresholds : list[float]
+            A list of score thresholds to compute metrics over.
+        filter_ : Filter, optional
+            A collection of filter parameters and masks.

         Returns
         -------
-
-
+        list[Metric]
+            List of confusion matrices per threshold pair.
         """
-
-
-
-
-        self.label_to_index[label] = label_id
-        self.index_to_label.append(label)
-        label_id += 1
-        return self.label_to_index[label]
+        if not iou_thresholds:
+            raise ValueError("At least one IOU threshold must be passed.")
+        elif not score_thresholds:
+            raise ValueError("At least one score threshold must be passed.")

-
+        if filter_ is not None:
+            detailed_pairs, _, _ = self.filter(filter_=filter_)
+        else:
+            detailed_pairs = self._detailed_pairs
+
+        if detailed_pairs.size == 0:
+            return []
+
+        results = compute_confusion_matrix(
+            detailed_pairs=detailed_pairs,
+            iou_thresholds=np.array(iou_thresholds),
+            score_thresholds=np.array(score_thresholds),
+        )
+        return unpack_confusion_matrix_into_metric_list(
+            results=results,
+            detailed_pairs=detailed_pairs,
+            iou_thresholds=iou_thresholds,
+            score_thresholds=score_thresholds,
+            index_to_datum_id=self.index_to_datum_id,
+            index_to_groundtruth_id=self.index_to_groundtruth_id,
+            index_to_prediction_id=self.index_to_prediction_id,
+            index_to_label=self.index_to_label,
+        )
+
+    def evaluate(
         self,
-
-
-
-    ):
+        iou_thresholds: list[float] = [0.1, 0.5, 0.75],
+        score_thresholds: list[float] = [0.5],
+        filter_: Filter | None = None,
+    ) -> dict[MetricType, list[Metric]]:
         """
-
+        Computes all available metrics.

         Parameters
         ----------
-
-        A list of
-
-        A list of
-
-
-
-
-
-
-
-
-
+        iou_thresholds : list[float], default=[0.1, 0.5, 0.75]
+            A list of IOU thresholds to compute metrics over.
+        score_thresholds : list[float], default=[0.5]
+            A list of score thresholds to compute metrics over.
+        filter_ : Filter, optional
+            A collection of filter parameters and masks.
+
+        Returns
+        -------
+        dict[MetricType, list[Metric]]
+            Lists of metrics organized by metric type.
+        """
+        metrics = self.compute_precision_recall(
+            iou_thresholds=iou_thresholds,
+            score_thresholds=score_thresholds,
+            filter_=filter_,
+        )
+        metrics[MetricType.ConfusionMatrix] = self.compute_confusion_matrix(
+            iou_thresholds=iou_thresholds,
+            score_thresholds=score_thresholds,
+            filter_=filter_,
+        )
+        return metrics
+
+
+class DataLoader:
+    """
+    Object Detection DataLoader
+    """
+
+    def __init__(self):
+        self._evaluator = Evaluator()
+        self.pairs: list[NDArray[np.float64]] = list()
+
+    def _add_datum(self, datum_id: str) -> int:
+        """
+        Helper function for adding a datum to the cache.
+
+        Parameters
+        ----------
+        datum_id : str
+            The datum identifier.
+
+        Returns
+        -------
+        int
+            The datum index.
+        """
+        if datum_id not in self._evaluator.datum_id_to_index:
+            if len(self._evaluator.datum_id_to_index) != len(
+                self._evaluator.index_to_datum_id
+            ):
+                raise InternalCacheException("datum cache size mismatch")
+            idx = len(self._evaluator.datum_id_to_index)
+            self._evaluator.datum_id_to_index[datum_id] = idx
+            self._evaluator.index_to_datum_id.append(datum_id)
+        return self._evaluator.datum_id_to_index[datum_id]
+
+    def _add_groundtruth(self, annotation_id: str) -> int:
+        """
+        Helper function for adding a ground truth annotation identifier to the cache.
+
+        Parameters
+        ----------
+        annotation_id : str
+            The ground truth annotation identifier.
+
+        Returns
+        -------
+        int
+            The ground truth annotation index.
+        """
+        if annotation_id not in self._evaluator.groundtruth_id_to_index:
+            if len(self._evaluator.groundtruth_id_to_index) != len(
+                self._evaluator.index_to_groundtruth_id
+            ):
+                raise InternalCacheException(
+                    "ground truth cache size mismatch"
+                )
+            idx = len(self._evaluator.groundtruth_id_to_index)
+            self._evaluator.groundtruth_id_to_index[annotation_id] = idx
+            self._evaluator.index_to_groundtruth_id.append(annotation_id)
+        return self._evaluator.groundtruth_id_to_index[annotation_id]
+
+    def _add_prediction(self, annotation_id: str) -> int:
+        """
+        Helper function for adding a prediction annotation identifier to the cache.
+
+        Parameters
+        ----------
+        annotation_id : str
+            The prediction annotation identifier.
+
+        Returns
+        -------
+        int
+            The prediction annotation index.
+        """
+        if annotation_id not in self._evaluator.prediction_id_to_index:
+            if len(self._evaluator.prediction_id_to_index) != len(
+                self._evaluator.index_to_prediction_id
+            ):
+                raise InternalCacheException("prediction cache size mismatch")
+            idx = len(self._evaluator.prediction_id_to_index)
+            self._evaluator.prediction_id_to_index[annotation_id] = idx
+            self._evaluator.index_to_prediction_id.append(annotation_id)
+        return self._evaluator.prediction_id_to_index[annotation_id]
+
+    def _add_label(self, label: str) -> int:
+        """
+        Helper function for adding a label to the cache.
+
+        Parameters
+        ----------
+        label : str
+            The label associated with the annotation.
+
+        Returns
+        -------
+        int
+            Label index.
+        """
+        label_id = len(self._evaluator.index_to_label)
+        if label not in self._evaluator.label_to_index:
+            if len(self._evaluator.label_to_index) != len(
+                self._evaluator.index_to_label
+            ):
+                raise InternalCacheException("label cache size mismatch")
+            self._evaluator.label_to_index[label] = label_id
+            self._evaluator.index_to_label.append(label)
+            label_id += 1
+        return self._evaluator.label_to_index[label]
+
+    def _add_data(
+        self,
+        detections: list[Detection],
+        detection_ious: list[NDArray[np.float64]],
+        show_progress: bool = False,
+    ):
+        """
+        Adds detections to the cache.
+
+        Parameters
+        ----------
+        detections : list[Detection]
+            A list of Detection objects.
+        detection_ious : list[NDArray[np.float64]]
+            A list of arrays containing IOUs per detection.
+        show_progress : bool, default=False
+            Toggle for tqdm progress bar.
+        """
+        disable_tqdm = not show_progress
+        for detection, ious in tqdm(
+            zip(detections, detection_ious), disable=disable_tqdm
+        ):
+            # cache labels and annotation pairs
+            pairs = []
             datum_idx = self._add_datum(detection.uid)
             if detection.groundtruths:
                 for gidx, gann in enumerate(detection.groundtruths):
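Taken together, the hunk above moves filtering and metric computation onto Evaluator and introduces a standalone DataLoader for ingestion. A hedged end-to-end sketch of the 0.36.2 workflow; building the Detection objects is elided because their constructor signature is not shown in this diff, and the datum uid and label strings are hypothetical.

from valor_lite.object_detection import DataLoader

loader = DataLoader()
# loader.add_bounding_boxes(detections)   # ingest Detection objects here
evaluator = loader.finalize()             # finalize() now returns the Evaluator

# Narrow the evaluation; an empty selection raises EmptyFilterException.
filter_ = evaluator.create_filter(datum_ids=["datum0"], labels=["dog"])

metrics = evaluator.evaluate(
    iou_thresholds=[0.5, 0.75],
    score_thresholds=[0.5],
    filter_=filter_,
)
for metric_type, values in metrics.items():
    print(metric_type, len(values))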
@@ -352,12 +656,7 @@ class Evaluator:

             data = np.array(pairs)
             if data.size > 0:
-
-                if self._temp_cache is None:
-                    raise RuntimeError(
-                        "cannot add data as evaluator has already been finalized"
-                    )
-                self._temp_cache.append(data)
+                self.pairs.append(data)

     def add_bounding_boxes(
         self,
@@ -459,7 +758,7 @@
             show_progress=show_progress,
         )

-    def finalize(self):
+    def finalize(self) -> Evaluator:
         """
         Performs data finalization and some preprocessing steps.

@@ -468,347 +767,37 @@
         Evaluator
             A ready-to-use evaluator object.
         """
-
-
-
-
-
-
-
-
-
-            self._metadata = Metadata()
-            warnings.warn("no valid pairs")
-            return self
-        else:
-            self._detailed_pairs = np.concatenate(self._temp_cache, axis=0)
-            self._temp_cache = None
+        if not self.pairs:
+            raise EmptyEvaluatorException()
+
+        n_labels = len(self._evaluator.index_to_label)
+        n_datums = len(self._evaluator.index_to_datum_id)
+
+        self._evaluator._detailed_pairs = np.concatenate(self.pairs, axis=0)
+        if self._evaluator._detailed_pairs.size == 0:
+            raise EmptyEvaluatorException()

         # order pairs by descending score, iou
         indices = np.lexsort(
             (
-                -self._detailed_pairs[:, 5],  # iou
-                -self._detailed_pairs[:, 6],  # score
+                -self._evaluator._detailed_pairs[:, 5],  # iou
+                -self._evaluator._detailed_pairs[:, 6],  # score
             )
         )
-        self._detailed_pairs = self._detailed_pairs[
-
-
+        self._evaluator._detailed_pairs = self._evaluator._detailed_pairs[
+            indices
+        ]
+        self._evaluator._label_metadata = compute_label_metadata(
+            ids=self._evaluator._detailed_pairs[:, :5].astype(np.int32),
             n_labels=n_labels,
         )
-        self._ranked_pairs = rank_pairs(
-            detailed_pairs=self._detailed_pairs,
-            label_metadata=self._label_metadata,
+        self._evaluator._ranked_pairs = rank_pairs(
+            detailed_pairs=self._evaluator._detailed_pairs,
+            label_metadata=self._evaluator._label_metadata,
         )
-        self._metadata = Metadata.create(
-            detailed_pairs=self._detailed_pairs,
+        self._evaluator._metadata = Metadata.create(
+            detailed_pairs=self._evaluator._detailed_pairs,
             number_of_datums=n_datums,
             number_of_labels=n_labels,
         )
-        return self
-
-    def create_filter(
-        self,
-        datum_ids: list[str] | None = None,
-        groundtruth_ids: list[str] | None = None,
-        prediction_ids: list[str] | None = None,
-        labels: list[str] | None = None,
-    ) -> Filter:
-        """
-        Creates a filter object.
-
-        Parameters
-        ----------
-        datum_uids : list[str], optional
-            An optional list of string uids representing datums to keep.
-        groundtruth_ids : list[str], optional
-            An optional list of string uids representing ground truth annotations to keep.
-        prediction_ids : list[str], optional
-            An optional list of string uids representing prediction annotations to keep.
-        labels : list[str], optional
-            An optional list of labels to keep.
-        """
-        mask_datums = np.ones(self._detailed_pairs.shape[0], dtype=np.bool_)
-
-        # filter datums
-        if datum_ids is not None:
-            if not datum_ids:
-                warnings.warn("creating a filter that removes all datums")
-                return Filter(
-                    mask_datums=np.zeros_like(mask_datums),
-                    mask_groundtruths=np.array([], dtype=np.bool_),
-                    mask_predictions=np.array([], dtype=np.bool_),
-                    metadata=Metadata(),
-                )
-            valid_datum_indices = np.array(
-                [self.datum_id_to_index[uid] for uid in datum_ids],
-                dtype=np.int32,
-            )
-            mask_datums = np.isin(
-                self._detailed_pairs[:, 0], valid_datum_indices
-            )
-
-        filtered_detailed_pairs = self._detailed_pairs[mask_datums]
-        n_pairs = self._detailed_pairs[mask_datums].shape[0]
-        mask_groundtruths = np.zeros(n_pairs, dtype=np.bool_)
-        mask_predictions = np.zeros_like(mask_groundtruths)
-
-        # filter by ground truth annotation ids
-        if groundtruth_ids is not None:
-            if not groundtruth_ids:
-                warnings.warn(
-                    "creating a filter that removes all ground truths"
-                )
-            valid_groundtruth_indices = np.array(
-                [self.groundtruth_id_to_index[uid] for uid in groundtruth_ids],
-                dtype=np.int32,
-            )
-            mask_groundtruths[
-                ~np.isin(
-                    filtered_detailed_pairs[:, 1],
-                    valid_groundtruth_indices,
-                )
-            ] = True
-
-        # filter by prediction annotation ids
-        if prediction_ids is not None:
-            if not prediction_ids:
-                warnings.warn("creating a filter that removes all predictions")
-            valid_prediction_indices = np.array(
-                [self.prediction_id_to_index[uid] for uid in prediction_ids],
-                dtype=np.int32,
-            )
-            mask_predictions[
-                ~np.isin(
-                    filtered_detailed_pairs[:, 2],
-                    valid_prediction_indices,
-                )
-            ] = True
-
-        # filter by labels
-        if labels is not None:
-            if not labels:
-                warnings.warn("creating a filter that removes all labels")
-                return Filter(
-                    mask_datums=mask_datums,
-                    mask_groundtruths=np.ones_like(mask_datums),
-                    mask_predictions=np.ones_like(mask_datums),
-                    metadata=Metadata(),
-                )
-            valid_label_indices = np.array(
-                [self.label_to_index[label] for label in labels] + [-1]
-            )
-            mask_groundtruths[
-                ~np.isin(filtered_detailed_pairs[:, 3], valid_label_indices)
-            ] = True
-            mask_predictions[
-                ~np.isin(filtered_detailed_pairs[:, 4], valid_label_indices)
-            ] = True
-
-        filtered_detailed_pairs, _, _ = filter_cache(
-            self._detailed_pairs,
-            mask_datums=mask_datums,
-            mask_ground_truths=mask_groundtruths,
-            mask_predictions=mask_predictions,
-            n_labels=len(self.index_to_label),
-        )
-
-        number_of_datums = (
-            len(datum_ids)
-            if datum_ids
-            else np.unique(filtered_detailed_pairs[:, 0]).size
-        )
-
-        return Filter(
-            mask_datums=mask_datums,
-            mask_groundtruths=mask_groundtruths,
-            mask_predictions=mask_predictions,
-            metadata=Metadata.create(
-                detailed_pairs=filtered_detailed_pairs,
-                number_of_datums=number_of_datums,
-                number_of_labels=len(self.index_to_label),
-            ),
-        )
-
-    def filter(
-        self, filter_: Filter
-    ) -> tuple[NDArray[np.float64], NDArray[np.float64], NDArray[np.int32],]:
-        """
-        Performs filtering over the internal cache.
-
-        Parameters
-        ----------
-        filter_ : Filter
-            The filter parameterization.
-
-        Returns
-        -------
-        NDArray[float64]
-            Filtered detailed pairs.
-        NDArray[float64]
-            Filtered ranked pairs.
-        NDArray[int32]
-            Label metadata.
-        """
-        if not filter_.mask_datums.any():
-            warnings.warn("filter removed all datums")
-            return (
-                np.array([], dtype=np.float64),
-                np.array([], dtype=np.float64),
-                np.zeros((self.metadata.number_of_labels, 2), dtype=np.int32),
-            )
-        if filter_.mask_groundtruths.all():
-            warnings.warn("filter removed all ground truths")
-        if filter_.mask_predictions.all():
-            warnings.warn("filter removed all predictions")
-        return filter_cache(
-            detailed_pairs=self._detailed_pairs,
-            mask_datums=filter_.mask_datums,
-            mask_ground_truths=filter_.mask_groundtruths,
-            mask_predictions=filter_.mask_predictions,
-            n_labels=len(self.index_to_label),
-        )
-
-    def compute_precision_recall(
-        self,
-        iou_thresholds: list[float],
-        score_thresholds: list[float],
-        filter_: Filter | None = None,
-    ) -> dict[MetricType, list[Metric]]:
-        """
-        Computes all metrics except for ConfusionMatrix
-
-        Parameters
-        ----------
-        iou_thresholds : list[float]
-            A list of IOU thresholds to compute metrics over.
-        score_thresholds : list[float]
-            A list of score thresholds to compute metrics over.
-        filter_ : Filter, optional
-            A collection of filter parameters and masks.
-
-        Returns
-        -------
-        dict[MetricType, list]
-            A dictionary mapping MetricType enumerations to lists of computed metrics.
-        """
-        if not iou_thresholds:
-            raise ValueError("At least one IOU threshold must be passed.")
-        elif not score_thresholds:
-            raise ValueError("At least one score threshold must be passed.")
-
-        if filter_ is not None:
-            _, ranked_pairs, label_metadata = self.filter(filter_=filter_)
-        else:
-            ranked_pairs = self._ranked_pairs
-            label_metadata = self._label_metadata
-
-        results = compute_precion_recall(
-            ranked_pairs=ranked_pairs,
-            label_metadata=label_metadata,
-            iou_thresholds=np.array(iou_thresholds),
-            score_thresholds=np.array(score_thresholds),
-        )
-        return unpack_precision_recall_into_metric_lists(
-            results=results,
-            label_metadata=label_metadata,
-            iou_thresholds=iou_thresholds,
-            score_thresholds=score_thresholds,
-            index_to_label=self.index_to_label,
-        )
-
-    def compute_confusion_matrix(
-        self,
-        iou_thresholds: list[float],
-        score_thresholds: list[float],
-        filter_: Filter | None = None,
-    ) -> list[Metric]:
-        """
-        Computes confusion matrices at various thresholds.
-
-        Parameters
-        ----------
-        iou_thresholds : list[float]
-            A list of IOU thresholds to compute metrics over.
-        score_thresholds : list[float]
-            A list of score thresholds to compute metrics over.
-        filter_ : Filter, optional
-            A collection of filter parameters and masks.
-
-        Returns
-        -------
-        list[Metric]
-            List of confusion matrices per threshold pair.
-        """
-        if not iou_thresholds:
-            raise ValueError("At least one IOU threshold must be passed.")
-        elif not score_thresholds:
-            raise ValueError("At least one score threshold must be passed.")
-
-        if filter_ is not None:
-            detailed_pairs, _, _ = self.filter(filter_=filter_)
-        else:
-            detailed_pairs = self._detailed_pairs
-
-        if detailed_pairs.size == 0:
-            warnings.warn("attempted to compute over an empty set")
-            return []
-
-        results = compute_confusion_matrix(
-            detailed_pairs=detailed_pairs,
-            iou_thresholds=np.array(iou_thresholds),
-            score_thresholds=np.array(score_thresholds),
-        )
-        return unpack_confusion_matrix_into_metric_list(
-            results=results,
-            detailed_pairs=detailed_pairs,
-            iou_thresholds=iou_thresholds,
-            score_thresholds=score_thresholds,
-            index_to_datum_id=self.index_to_datum_id,
-            index_to_groundtruth_id=self.index_to_groundtruth_id,
-            index_to_prediction_id=self.index_to_prediction_id,
-            index_to_label=self.index_to_label,
-        )
-
-    def evaluate(
-        self,
-        iou_thresholds: list[float] = [0.1, 0.5, 0.75],
-        score_thresholds: list[float] = [0.5],
-        filter_: Filter | None = None,
-    ) -> dict[MetricType, list[Metric]]:
-        """
-        Computes all available metrics.
-
-        Parameters
-        ----------
-        iou_thresholds : list[float], default=[0.1, 0.5, 0.75]
-            A list of IOU thresholds to compute metrics over.
-        score_thresholds : list[float], default=[0.5]
-            A list of score thresholds to compute metrics over.
-        filter_ : Filter, optional
-            A collection of filter parameters and masks.
-
-        Returns
-        -------
-        dict[MetricType, list[Metric]]
-            Lists of metrics organized by metric type.
-        """
-        metrics = self.compute_precision_recall(
-            iou_thresholds=iou_thresholds,
-            score_thresholds=score_thresholds,
-            filter_=filter_,
-        )
-        metrics[MetricType.ConfusionMatrix] = self.compute_confusion_matrix(
-            iou_thresholds=iou_thresholds,
-            score_thresholds=score_thresholds,
-            filter_=filter_,
-        )
-        return metrics
-
-
-class DataLoader(Evaluator):
-    """
-    Used for backwards compatibility as the Evaluator now handles ingestion.
-    """
-
-    pass
+        return self._evaluator
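The net effect of the manager.py changes: DataLoader is no longer a backwards-compatibility alias of Evaluator. Ingestion lives on DataLoader, finalize() hands back the Evaluator, and empty caches or filters raise instead of warning. A hedged migration sketch:

# 0.36.0 (removed): one object handled both ingestion and evaluation.
#   evaluator = DataLoader()              # DataLoader(Evaluator) alias
#   evaluator.add_bounding_boxes(detections)
#   evaluator.finalize()
#   evaluator.evaluate(...)

# 0.36.2: ingestion and evaluation are separate objects.
from valor_lite.object_detection import DataLoader

loader = DataLoader()
# loader.add_bounding_boxes(detections)
evaluator = loader.finalize()              # raises EmptyEvaluatorException if nothing was added
metrics = evaluator.evaluate()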
valor_lite/semantic_segmentation/__init__.py

@@ -1,5 +1,5 @@
 from .annotation import Bitmask, Segmentation
-from .manager import DataLoader, Evaluator, Filter
+from .manager import DataLoader, Evaluator, Filter, Metadata
 from .metric import Metric, MetricType

 __all__ = [
@@ -10,4 +10,5 @@ __all__ = [
     "Metric",
     "MetricType",
     "Filter",
+    "Metadata",
 ]
valor_lite/semantic_segmentation/manager.py

@@ -1,10 +1,10 @@
-import warnings
 from dataclasses import asdict, dataclass

 import numpy as np
 from numpy.typing import NDArray
 from tqdm import tqdm

+from valor_lite.exceptions import EmptyEvaluatorException, EmptyFilterException
 from valor_lite.semantic_segmentation.annotation import Segmentation
 from valor_lite.semantic_segmentation.computation import (
     compute_intermediate_confusion_matrices,
@@ -71,6 +71,15 @@ class Filter:
     label_mask: NDArray[np.bool_]
     metadata: Metadata

+    def __post_init__(self):
+        # validate datum mask
+        if not self.datum_mask.any():
+            raise EmptyFilterException("filter removes all datums")
+
+        # validate label mask
+        if self.label_mask.all():
+            raise EmptyFilterException("filter removes all labels")
+

 class Evaluator:
     """
@@ -140,10 +149,9 @@ class Evaluator:
         label_mask = np.zeros(
             self.metadata.number_of_labels + 1, dtype=np.bool_
         )
+
         if datum_ids is not None:
             if not datum_ids:
-                filtered_confusion_matrices = np.array([], dtype=np.int64)
-                warnings.warn("datum filter results in empty data array")
                 return Filter(
                     datum_mask=np.zeros_like(datum_mask),
                     label_mask=label_mask,
@@ -159,10 +167,9 @@ class Evaluator:
             == datum_id_array.reshape(1, -1)
         ).any(axis=1)
         datum_mask[~mask_valid_datums] = False
+
         if labels is not None:
             if not labels:
-                filtered_confusion_matrices = np.array([], dtype=np.int64)
-                warnings.warn("label filter results in empty data array")
                 return Filter(
                     datum_mask=datum_mask,
                     label_mask=np.ones_like(label_mask),
@@ -211,18 +218,6 @@ class Evaluator:
         NDArray[int64]
             Filtered label metadata
         """
-        empty_datum_mask = not filter_.datum_mask.any()
-        empty_label_mask = filter_.label_mask.all()
-        if empty_datum_mask or empty_label_mask:
-            if empty_datum_mask:
-                warnings.warn("filter does not allow any datum")
-            if empty_label_mask:
-                warnings.warn("filter removes all labels")
-            return (
-                np.array([], dtype=np.int64),
-                np.zeros((self.metadata.number_of_labels, 2), dtype=np.int64),
-            )
         return filter_cache(
             confusion_matrices=self._confusion_matrices.copy(),
             datum_mask=filter_.datum_mask,
@@ -408,7 +403,7 @@ class DataLoader:
         """

         if len(self.matrices) == 0:
-            raise
+            raise EmptyEvaluatorException()

         n_labels = len(self._evaluator.index_to_label)
         n_datums = len(self._evaluator.index_to_datum_id)
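Semantic segmentation follows the same pattern: empty filters fail in Filter.__post_init__ and finalizing with no ingested segmentations raises a descriptive exception instead of the bare raise used previously. A short sketch, assuming the changed DataLoader method is finalize(), consistent with the other modules:

from valor_lite.semantic_segmentation import DataLoader
from valor_lite.exceptions import EmptyEvaluatorException

loader = DataLoader()
try:
    loader.finalize()                  # no segmentations were added
except EmptyEvaluatorException as err:
    print(err)                         # "evaluator cannot be finalized as it contains no data"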