valor-lite 0.35.0__py3-none-any.whl → 0.36.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- valor_lite/classification/computation.py +147 -38
- valor_lite/classification/manager.py +221 -235
- valor_lite/classification/metric.py +5 -8
- valor_lite/classification/utilities.py +18 -14
- valor_lite/exceptions.py +15 -0
- valor_lite/object_detection/__init__.py +2 -1
- valor_lite/object_detection/computation.py +83 -10
- valor_lite/object_detection/manager.py +313 -315
- valor_lite/semantic_segmentation/__init__.py +3 -3
- valor_lite/semantic_segmentation/annotation.py +32 -103
- valor_lite/semantic_segmentation/benchmark.py +87 -1
- valor_lite/semantic_segmentation/computation.py +96 -14
- valor_lite/semantic_segmentation/manager.py +193 -221
- valor_lite/semantic_segmentation/utilities.py +3 -3
- {valor_lite-0.35.0.dist-info → valor_lite-0.36.1.dist-info}/METADATA +2 -2
- {valor_lite-0.35.0.dist-info → valor_lite-0.36.1.dist-info}/RECORD +18 -17
- {valor_lite-0.35.0.dist-info → valor_lite-0.36.1.dist-info}/WHEEL +1 -1
- {valor_lite-0.35.0.dist-info → valor_lite-0.36.1.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,4 @@
-from
-from dataclasses import dataclass
+from dataclasses import asdict, dataclass
 
 import numpy as np
 from numpy.typing import NDArray
@@ -8,13 +7,16 @@ from tqdm import tqdm
 from valor_lite.classification.annotation import Classification
 from valor_lite.classification.computation import (
     compute_confusion_matrix,
+    compute_label_metadata,
     compute_precision_recall_rocauc,
+    filter_cache,
 )
 from valor_lite.classification.metric import Metric, MetricType
 from valor_lite.classification.utilities import (
     unpack_confusion_matrix_into_metric_list,
     unpack_precision_recall_rocauc_into_metric_lists,
 )
+from valor_lite.exceptions import EmptyEvaluatorException, EmptyFilterException
 
 """
 Usage
@@ -37,11 +39,63 @@ filtered_metrics = evaluator.evaluate(filter_mask=filter_mask)
 """
 
 
+@dataclass
+class Metadata:
+    number_of_datums: int = 0
+    number_of_ground_truths: int = 0
+    number_of_predictions: int = 0
+    number_of_labels: int = 0
+
+    @classmethod
+    def create(
+        cls,
+        detailed_pairs: NDArray[np.float64],
+        number_of_datums: int,
+        number_of_labels: int,
+    ):
+        # count number of unique ground truths
+        mask_valid_gts = detailed_pairs[:, 1] >= 0
+        unique_ids = np.unique(
+            detailed_pairs[np.ix_(mask_valid_gts, (0, 1))],  # type: ignore - np.ix_ typing
+            axis=0,
+        )
+        number_of_ground_truths = int(unique_ids.shape[0])
+
+        # count number of unqiue predictions
+        mask_valid_pds = detailed_pairs[:, 2] >= 0
+        unique_ids = np.unique(
+            detailed_pairs[np.ix_(mask_valid_pds, (0, 2))], axis=0  # type: ignore - np.ix_ typing
+        )
+        number_of_predictions = int(unique_ids.shape[0])
+
+        return cls(
+            number_of_datums=number_of_datums,
+            number_of_ground_truths=number_of_ground_truths,
+            number_of_predictions=number_of_predictions,
+            number_of_labels=number_of_labels,
+        )
+
+    def to_dict(self) -> dict[str, int | bool]:
+        return asdict(self)
+
+
 @dataclass
 class Filter:
-
-
-
+    datum_mask: NDArray[np.bool_]
+    valid_label_indices: NDArray[np.int32] | None
+    metadata: Metadata
+
+    def __post_init__(self):
+        # validate datum mask
+        if not self.datum_mask.any():
+            raise EmptyFilterException("filter removes all datums")
+
+        # validate label indices
+        if (
+            self.valid_label_indices is not None
+            and self.valid_label_indices.size == 0
+        ):
+            raise EmptyFilterException("filter removes all labels")
 
 
 class Evaluator:
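The added `Metadata.create` classmethod derives ground-truth and prediction counts directly from the detailed pair array (columns: datum index, ground-truth label, prediction label, score, hardmax flag) instead of relying on separately tracked counters. A minimal sketch of that counting step on a hypothetical toy array (the values below are illustrative, not from the package):

```python
import numpy as np

# Hypothetical toy cache; columns are (datum, ground truth, prediction, score, hardmax).
# Negative label indices are treated as missing, matching the ">= 0" masks above.
detailed_pairs = np.array(
    [
        [0.0, 1.0, 1.0, 0.9, 1.0],
        [0.0, 1.0, 0.0, 0.1, 0.0],
        [1.0, 0.0, -1.0, -1.0, 0.0],
    ]
)

# Unique (datum, ground truth) pairs -> number of ground truths.
mask_valid_gts = detailed_pairs[:, 1] >= 0
n_groundtruths = np.unique(
    detailed_pairs[np.ix_(mask_valid_gts, (0, 1))], axis=0
).shape[0]

# Unique (datum, prediction) pairs -> number of predictions.
mask_valid_pds = detailed_pairs[:, 2] >= 0
n_predictions = np.unique(
    detailed_pairs[np.ix_(mask_valid_pds, (0, 2))], axis=0
).shape[0]

print(n_groundtruths, n_predictions)  # 2 2
```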
@@ -50,25 +104,21 @@ class Evaluator:
     """
 
     def __init__(self):
+        # external references
+        self.datum_id_to_index: dict[str, int] = {}
+        self.label_to_index: dict[str, int] = {}
 
-
-        self.
-        self.n_groundtruths = 0
-        self.n_predictions = 0
-        self.n_labels = 0
-
-        # datum reference
-        self.uid_to_index: dict[str, int] = dict()
-        self.index_to_uid: dict[int, str] = dict()
-
-        # label reference
-        self.label_to_index: dict[str, int] = dict()
-        self.index_to_label: dict[int, str] = dict()
+        self.index_to_datum_id: list[str] = []
+        self.index_to_label: list[str] = []
 
-        #
+        # internal caches
         self._detailed_pairs = np.array([])
         self._label_metadata = np.array([], dtype=np.int32)
-        self.
+        self._metadata = Metadata()
+
+    @property
+    def metadata(self) -> Metadata:
+        return self._metadata
 
     @property
     def ignored_prediction_labels(self) -> list[str]:
@@ -92,97 +142,103 @@ class Evaluator:
             self.index_to_label[label_id] for label_id in (glabels - plabels)
         ]
 
-    @property
-    def metadata(self) -> dict:
-        """
-        Evaluation metadata.
-        """
-        return {
-            "n_datums": self.n_datums,
-            "n_groundtruths": self.n_groundtruths,
-            "n_predictions": self.n_predictions,
-            "n_labels": self.n_labels,
-            "ignored_prediction_labels": self.ignored_prediction_labels,
-            "missing_prediction_labels": self.missing_prediction_labels,
-        }
-
     def create_filter(
         self,
-
-        labels: list[str] |
+        datum_ids: list[str] | None = None,
+        labels: list[str] | None = None,
     ) -> Filter:
         """
-        Creates a
+        Creates a filter object.
 
         Parameters
         ----------
-        datum_uids : list[str]
-            An optional list of string uids
-        labels : list[str]
-            An optional list of labels
+        datum_uids : list[str], optional
+            An optional list of string uids representing datums.
+        labels : list[str], optional
+            An optional list of labels.
 
         Returns
         -------
         Filter
-
+            The filter object representing the input parameters.
         """
-
-
-
-
-
-
-
-
-
-        if datum_uids is not None:
-            if isinstance(datum_uids, list):
-                datum_uids = np.array(
-                    [self.uid_to_index[uid] for uid in datum_uids],
-                    dtype=np.int32,
+        # create datum mask
+        n_pairs = self._detailed_pairs.shape[0]
+        datum_mask = np.ones(n_pairs, dtype=np.bool_)
+        if datum_ids is not None:
+            if not datum_ids:
+                return Filter(
+                    datum_mask=np.zeros_like(datum_mask),
+                    valid_label_indices=None,
+                    metadata=Metadata(),
                 )
-
-
-            np.
-
-
-
-
-            mask[datum_uids] = True
-            mask_datums &= mask
+            valid_datum_indices = np.array(
+                [self.datum_id_to_index[uid] for uid in datum_ids],
+                dtype=np.int32,
+            )
+            datum_mask = np.isin(
+                self._detailed_pairs[:, 0], valid_datum_indices
+            )
 
+        # collect valid label indices
+        valid_label_indices = None
         if labels is not None:
-            if
-
-
+            if not labels:
+                return Filter(
+                    datum_mask=datum_mask,
+                    valid_label_indices=np.array([], dtype=np.int32),
+                    metadata=Metadata(),
                 )
-
-
-                np.isin(self._detailed_pairs[:, 1].astype(int), labels)
-            ] = True
-            mask_pairs &= mask
-
-            mask = np.zeros_like(mask_labels, dtype=np.bool_)
-            mask[labels] = True
-            mask_labels &= mask
-
-        mask = mask_datums[:, np.newaxis] & mask_labels[np.newaxis, :]
-        label_metadata_per_datum = self._label_metadata_per_datum.copy()
-        label_metadata_per_datum[:, ~mask] = 0
-
-        label_metadata: NDArray[np.int32] = np.transpose(
-            np.sum(
-                label_metadata_per_datum,
-                axis=1,
+            valid_label_indices = np.array(
+                [self.label_to_index[label] for label in labels] + [-1]
             )
+
+        filtered_detailed_pairs, _ = filter_cache(
+            detailed_pairs=self._detailed_pairs,
+            datum_mask=datum_mask,
+            valid_label_indices=valid_label_indices,
+            n_labels=self.metadata.number_of_labels,
         )
 
-
+        number_of_datums = (
+            len(datum_ids)
+            if datum_ids is not None
+            else self.metadata.number_of_datums
+        )
 
         return Filter(
-
-
-
+            datum_mask=datum_mask,
+            valid_label_indices=valid_label_indices,
+            metadata=Metadata.create(
+                detailed_pairs=filtered_detailed_pairs,
+                number_of_datums=number_of_datums,
+                number_of_labels=self.metadata.number_of_labels,
+            ),
+        )
+
+    def filter(
+        self, filter_: Filter
+    ) -> tuple[NDArray[np.float64], NDArray[np.int32]]:
+        """
+        Performs filtering over the internal cache.
+
+        Parameters
+        ----------
+        filter_ : Filter
+            The filter object representation.
+
+        Returns
+        -------
+        NDArray[float64]
+            The filtered detailed pairs.
+        NDArray[int32]
+            The filtered label metadata.
+        """
+        return filter_cache(
+            detailed_pairs=self._detailed_pairs,
+            datum_mask=filter_.datum_mask,
+            valid_label_indices=filter_.valid_label_indices,
+            n_labels=self.metadata.number_of_labels,
        )
 
     def compute_precision_recall_rocauc(
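The rewritten `create_filter` reduces datum filtering to a single boolean row mask built with `np.isin` over the datum-index column, which `filter_cache` then applies. A rough illustration of that masking step (array and variable names are illustrative):

```python
import numpy as np

# Hypothetical pair cache; column 0 holds the datum index.
detailed_pairs = np.array(
    [
        [0.0, 1.0, 1.0, 0.9, 1.0],
        [1.0, 0.0, 1.0, 0.4, 1.0],
        [2.0, 0.0, 0.0, 0.8, 1.0],
    ]
)

# Keep only rows belonging to the selected datums, as create_filter does.
valid_datum_indices = np.array([0, 2], dtype=np.int32)
datum_mask = np.isin(detailed_pairs[:, 0], valid_datum_indices)

print(datum_mask)                  # [ True False  True]
print(detailed_pairs[datum_mask])  # rows for datums 0 and 2
```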
@@ -201,31 +257,29 @@ class Evaluator:
         hardmax : bool
             Toggles whether a hardmax is applied to predictions.
         filter_ : Filter, optional
-
+            Applies a filter to the internal cache.
 
         Returns
         -------
         dict[MetricType, list]
             A dictionary mapping MetricType enumerations to lists of computed metrics.
         """
-
         # apply filters
-        data = self._detailed_pairs
-        label_metadata = self._label_metadata
-        n_datums = self.n_datums
         if filter_ is not None:
-
-
-
+            detailed_pairs, label_metadata = self.filter(filter_=filter_)
+            n_datums = filter_.metadata.number_of_datums
+        else:
+            detailed_pairs = self._detailed_pairs
+            label_metadata = self._label_metadata
+            n_datums = self.metadata.number_of_datums
 
         results = compute_precision_recall_rocauc(
-
+            detailed_pairs=detailed_pairs,
             label_metadata=label_metadata,
             score_thresholds=np.array(score_thresholds),
             hardmax=hardmax,
             n_datums=n_datums,
         )
-
         return unpack_precision_recall_rocauc_into_metric_lists(
             results=results,
             score_thresholds=score_thresholds,
@@ -253,37 +307,35 @@ class Evaluator:
         number_of_examples : int, default=0
             The number of examples to return per count.
         filter_ : Filter, optional
-
+            Applies a filter to the internal cache.
 
         Returns
         -------
         list[Metric]
             A list of confusion matrices.
         """
-
         # apply filters
-        data = self._detailed_pairs
-        label_metadata = self._label_metadata
         if filter_ is not None:
-
-
+            detailed_pairs, label_metadata = self.filter(filter_=filter_)
+        else:
+            detailed_pairs = self._detailed_pairs
+            label_metadata = self._label_metadata
 
-        if
+        if detailed_pairs.size == 0:
             return list()
 
         results = compute_confusion_matrix(
-
+            detailed_pairs=detailed_pairs,
             label_metadata=label_metadata,
             score_thresholds=np.array(score_thresholds),
             hardmax=hardmax,
             n_examples=number_of_examples,
         )
-
         return unpack_confusion_matrix_into_metric_list(
             results=results,
             score_thresholds=score_thresholds,
             number_of_examples=number_of_examples,
-
+            index_to_datum_id=self.index_to_datum_id,
             index_to_label=self.index_to_label,
         )
 
@@ -306,40 +358,26 @@ class Evaluator:
         number_of_examples : int, default=0
             The number of examples to return per count.
         filter_ : Filter, optional
-
+            Applies a filter to the internal cache.
 
         Returns
         -------
         dict[MetricType, list[Metric]]
             Lists of metrics organized by metric type.
         """
-
         metrics = self.compute_precision_recall_rocauc(
             score_thresholds=score_thresholds,
             hardmax=hardmax,
             filter_=filter_,
         )
-
         metrics[MetricType.ConfusionMatrix] = self.compute_confusion_matrix(
             score_thresholds=score_thresholds,
             hardmax=hardmax,
             number_of_examples=number_of_examples,
             filter_=filter_,
         )
-
         return metrics
 
-
-class DataLoader:
-    """
-    Classification DataLoader.
-    """
-
-    def __init__(self):
-        self._evaluator = Evaluator()
-        self.groundtruth_count = defaultdict(lambda: defaultdict(int))
-        self.prediction_count = defaultdict(lambda: defaultdict(int))
-
     def _add_datum(self, uid: str) -> int:
         """
         Helper function for adding a datum to the cache.
@@ -354,11 +392,11 @@ class DataLoader:
         int
             The datum index.
         """
-        if uid not in self.
-            index = len(self.
-            self.
-            self.
-        return self.
+        if uid not in self.datum_id_to_index:
+            index = len(self.datum_id_to_index)
+            self.datum_id_to_index[uid] = index
+            self.index_to_datum_id.append(uid)
+        return self.datum_id_to_index[uid]
 
     def _add_label(self, label: str) -> int:
         """
@@ -374,47 +412,12 @@ class DataLoader:
         int
             Label index.
         """
-        label_id = len(self.
-        if label not in self.
-            self.
-            self.
-
+        label_id = len(self.index_to_label)
+        if label not in self.label_to_index:
+            self.label_to_index[label] = label_id
+            self.index_to_label.append(label)
             label_id += 1
-
-        return self._evaluator.label_to_index[label]
-
-    def _add_data(
-        self,
-        uid_index: int,
-        groundtruth: int,
-        predictions: list[tuple[int, float]],
-    ):
-
-        pairs = list()
-        scores = np.array([score for _, score in predictions])
-        max_score_idx = np.argmax(scores)
-
-        for idx, (plabel, score) in enumerate(predictions):
-            pairs.append(
-                (
-                    float(uid_index),
-                    float(groundtruth),
-                    float(plabel),
-                    float(score),
-                    float(max_score_idx == idx),
-                )
-            )
-
-        if self._evaluator._detailed_pairs.size == 0:
-            self._evaluator._detailed_pairs = np.array(pairs)
-        else:
-            self._evaluator._detailed_pairs = np.concatenate(
-                [
-                    self._evaluator._detailed_pairs,
-                    np.array(pairs),
-                ],
-                axis=0,
-            )
+        return self.label_to_index[label]
 
     def add_data(
         self,
@@ -439,24 +442,18 @@ class DataLoader:
             raise ValueError(
                 "Classifications must contain at least one prediction."
             )
-        # update metadata
-        self._evaluator.n_datums += 1
-        self._evaluator.n_groundtruths += 1
-        self._evaluator.n_predictions += len(classification.predictions)
 
         # update datum uid index
         uid_index = self._add_datum(uid=classification.uid)
 
         # cache labels and annotations
         groundtruth = self._add_label(classification.groundtruth)
-        self.groundtruth_count[groundtruth][uid_index] += 1
 
         predictions = list()
         for plabel, pscore in zip(
             classification.predictions, classification.scores
         ):
             label_idx = self._add_label(plabel)
-            self.prediction_count[label_idx][uid_index] += 1
             predictions.append(
                 (
                     label_idx,
@@ -464,13 +461,33 @@ class DataLoader:
                 )
             )
 
-
-
-
-
-        )
+        pairs = list()
+        scores = np.array([score for _, score in predictions])
+        max_score_idx = np.argmax(scores)
+
+        for idx, (plabel, score) in enumerate(predictions):
+            pairs.append(
+                (
+                    float(uid_index),
+                    float(groundtruth),
+                    float(plabel),
+                    float(score),
+                    float(max_score_idx == idx),
+                )
+            )
 
-
+        if self._detailed_pairs.size == 0:
+            self._detailed_pairs = np.array(pairs)
+        else:
+            self._detailed_pairs = np.concatenate(
+                [
+                    self._detailed_pairs,
+                    np.array(pairs),
+                ],
+                axis=0,
+            )
+
+    def finalize(self):
         """
         Performs data finalization and some preprocessing steps.
 
@@ -479,63 +496,32 @@ class DataLoader:
         Evaluator
             A ready-to-use evaluator object.
         """
+        if self._detailed_pairs.size == 0:
+            raise EmptyEvaluatorException()
 
-
-
-
-        n_datums = self._evaluator.n_datums
-        n_labels = len(self._evaluator.index_to_label)
-
-        self._evaluator.n_labels = n_labels
-
-        self._evaluator._label_metadata_per_datum = np.zeros(
-            (2, n_datums, n_labels), dtype=np.int32
+        self._label_metadata = compute_label_metadata(
+            ids=self._detailed_pairs[:, :3].astype(np.int32),
+            n_labels=len(self.index_to_label),
         )
-        for datum_idx in range(n_datums):
-            for label_idx in range(n_labels):
-                gt_count = (
-                    self.groundtruth_count[label_idx].get(datum_idx, 0)
-                    if label_idx in self.groundtruth_count
-                    else 0
-                )
-                pd_count = (
-                    self.prediction_count[label_idx].get(datum_idx, 0)
-                    if label_idx in self.prediction_count
-                    else 0
-                )
-                self._evaluator._label_metadata_per_datum[
-                    :, datum_idx, label_idx
-                ] = np.array([gt_count, pd_count])
-
-        self._evaluator._label_metadata = np.array(
-            [
-                [
-                    np.sum(
-                        self._evaluator._label_metadata_per_datum[
-                            0, :, label_idx
-                        ]
-                    ),
-                    np.sum(
-                        self._evaluator._label_metadata_per_datum[
-                            1, :, label_idx
-                        ]
-                    ),
-                ]
-                for label_idx in range(n_labels)
-            ],
-            dtype=np.int32,
-        )
-
-        # sort pairs by groundtruth, prediction, score
         indices = np.lexsort(
             (
-                self.
-                self.
-                -self.
+                self._detailed_pairs[:, 1],  # ground truth
+                self._detailed_pairs[:, 2],  # prediction
+                -self._detailed_pairs[:, 3],  # score
             )
         )
-        self.
-
-
+        self._detailed_pairs = self._detailed_pairs[indices]
+        self._metadata = Metadata.create(
+            detailed_pairs=self._detailed_pairs,
+            number_of_datums=len(self.index_to_datum_id),
+            number_of_labels=len(self.index_to_label),
+        )
+        return self
+
+
+class DataLoader(Evaluator):
+    """
+    Used for backwards compatibility as the Evaluator now handles ingestion.
+    """
 
-
+    pass
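With ingestion folded into `Evaluator` (and `DataLoader` kept only as a backwards-compatible alias), the end-to-end flow looks roughly like the sketch below. The `Classification` keyword arguments, the list argument to `add_data`, the module paths, and the `evaluate` method name are inferred from the diff context and the old usage docstring rather than confirmed signatures:

```python
from valor_lite.classification.annotation import Classification  # path shown in the diff
from valor_lite.classification.manager import Evaluator  # module path assumed

evaluator = Evaluator()
evaluator.add_data(  # assumed to accept a list of Classification objects
    [
        Classification(
            uid="uid0",
            groundtruth="dog",
            predictions=["dog", "cat"],
            scores=[0.8, 0.2],
        )
    ]
)
evaluator.finalize()  # sorts the pair cache and populates evaluator.metadata

# Restrict evaluation to a subset of datums and/or labels, then evaluate.
filter_ = evaluator.create_filter(datum_ids=["uid0"], labels=["dog", "cat"])
metrics = evaluator.evaluate(score_thresholds=[0.5], hardmax=True, filter_=filter_)
```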
@@ -335,8 +335,8 @@ class Metric(BaseMetric):
     The confusion matrix and related metrics for the classification task.
 
     This class encapsulates detailed information about the model's performance, including correct
-    predictions, misclassifications
-
+    predictions, misclassifications and unmatched ground truths (subset of false negatives).
+    It provides counts and examples for each category to facilitate in-depth analysis.
 
     Confusion Matrix Structure:
     {
@@ -345,10 +345,8 @@ class Metric(BaseMetric):
                 'count': int,
                 'examples': [
                     {
-
-
-                        'prediction': dict,  # {'xmin': float, 'xmax': float, 'ymin': float, 'ymax': float}
-                        'score': float,
+                        "datum_id": str,
+                        "score": float
                     },
                     ...
                 ],
@@ -364,8 +362,7 @@ class Metric(BaseMetric):
                 'count': int,
                 'examples': [
                     {
-
-                        'groundtruth': dict,  # {'xmin': float, 'xmax': float, 'ymin': float, 'ymax': float}
+                        "datum_id": str
                     },
                     ...
                 ],