valor-lite 0.36.5__py3-none-any.whl → 0.37.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- valor_lite/cache/__init__.py +11 -0
- valor_lite/cache/compute.py +211 -0
- valor_lite/cache/ephemeral.py +302 -0
- valor_lite/cache/persistent.py +536 -0
- valor_lite/classification/__init__.py +5 -10
- valor_lite/classification/annotation.py +4 -0
- valor_lite/classification/computation.py +233 -251
- valor_lite/classification/evaluator.py +882 -0
- valor_lite/classification/loader.py +97 -0
- valor_lite/classification/metric.py +141 -4
- valor_lite/classification/shared.py +184 -0
- valor_lite/classification/utilities.py +221 -118
- valor_lite/exceptions.py +5 -0
- valor_lite/object_detection/__init__.py +5 -4
- valor_lite/object_detection/annotation.py +13 -1
- valor_lite/object_detection/computation.py +367 -304
- valor_lite/object_detection/evaluator.py +804 -0
- valor_lite/object_detection/loader.py +292 -0
- valor_lite/object_detection/metric.py +152 -3
- valor_lite/object_detection/shared.py +206 -0
- valor_lite/object_detection/utilities.py +182 -109
- valor_lite/semantic_segmentation/__init__.py +5 -4
- valor_lite/semantic_segmentation/annotation.py +7 -0
- valor_lite/semantic_segmentation/computation.py +20 -110
- valor_lite/semantic_segmentation/evaluator.py +414 -0
- valor_lite/semantic_segmentation/loader.py +205 -0
- valor_lite/semantic_segmentation/shared.py +149 -0
- valor_lite/semantic_segmentation/utilities.py +6 -23
- {valor_lite-0.36.5.dist-info → valor_lite-0.37.5.dist-info}/METADATA +3 -1
- valor_lite-0.37.5.dist-info/RECORD +49 -0
- {valor_lite-0.36.5.dist-info → valor_lite-0.37.5.dist-info}/WHEEL +1 -1
- valor_lite/classification/manager.py +0 -545
- valor_lite/object_detection/manager.py +0 -865
- valor_lite/profiling.py +0 -374
- valor_lite/semantic_segmentation/benchmark.py +0 -237
- valor_lite/semantic_segmentation/manager.py +0 -446
- valor_lite-0.36.5.dist-info/RECORD +0 -41
- {valor_lite-0.36.5.dist-info → valor_lite-0.37.5.dist-info}/top_level.txt +0 -0
|
@@ -1,545 +0,0 @@
|
|
|
1
|
-
from dataclasses import asdict, dataclass
|
|
2
|
-
|
|
3
|
-
import numpy as np
|
|
4
|
-
from numpy.typing import NDArray
|
|
5
|
-
from tqdm import tqdm
|
|
6
|
-
|
|
7
|
-
from valor_lite.classification.annotation import Classification
|
|
8
|
-
from valor_lite.classification.computation import (
|
|
9
|
-
compute_confusion_matrix,
|
|
10
|
-
compute_label_metadata,
|
|
11
|
-
compute_precision_recall_rocauc,
|
|
12
|
-
filter_cache,
|
|
13
|
-
)
|
|
14
|
-
from valor_lite.classification.metric import Metric, MetricType
|
|
15
|
-
from valor_lite.classification.utilities import (
|
|
16
|
-
unpack_confusion_matrix_into_metric_list,
|
|
17
|
-
unpack_precision_recall_rocauc_into_metric_lists,
|
|
18
|
-
)
|
|
19
|
-
from valor_lite.exceptions import EmptyEvaluatorError, EmptyFilterError
|
|
20
|
-
|
|
21
|
-
"""
|
|
22
|
-
Usage
|
|
23
|
-
-----
|
|
24
|
-
|
|
25
|
-
manager = DataLoader()
|
|
26
|
-
manager.add_data(
    classifications=classifications,
)
|
|
30
|
-
evaluator = manager.finalize()
|
|
31
|
-
|
|
32
|
-
metrics = evaluator.evaluate()
|
|
33
|
-
|
|
34
|
-
f1_metrics = metrics[MetricType.F1]
|
|
35
|
-
accuracy_metrics = metrics[MetricType.Accuracy]
|
|
36
|
-
|
|
37
|
-
filter_ = evaluator.create_filter(datums=["uid1", "uid2"])
filtered_metrics = evaluator.evaluate(filter_=filter_)
|
|
39
|
-
"""
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
@dataclass
class Metadata:
    """
    Summary counts describing the contents of an evaluator cache.

    Attributes
    ----------
    number_of_datums : int
        Number of datums represented.
    number_of_ground_truths : int
        Number of distinct (datum, ground truth label) pairs.
    number_of_predictions : int
        Number of distinct (datum, prediction label) pairs.
    number_of_labels : int
        Number of labels known to the evaluator.
    """

    number_of_datums: int = 0
    number_of_ground_truths: int = 0
    number_of_predictions: int = 0
    number_of_labels: int = 0

    @classmethod
    def create(
        cls,
        detailed_pairs: NDArray[np.float64],
        number_of_datums: int,
        number_of_labels: int,
    ):
        """
        Builds metadata by counting annotations in a detailed-pairs array.

        Parameters
        ----------
        detailed_pairs : NDArray[float64]
            2-D array whose column 0 is the datum index, column 1 the ground
            truth label index and column 2 the prediction label index.
            Negative label indices mark rows without that annotation.
        number_of_datums : int
            Datum count to record (supplied by the caller, not derived).
        number_of_labels : int
            Label count to record (supplied by the caller, not derived).

        Returns
        -------
        Metadata
        """
        return cls(
            number_of_datums=number_of_datums,
            number_of_ground_truths=cls._count_unique_pairs(
                detailed_pairs, column=1
            ),
            number_of_predictions=cls._count_unique_pairs(
                detailed_pairs, column=2
            ),
            number_of_labels=number_of_labels,
        )

    @staticmethod
    def _count_unique_pairs(
        detailed_pairs: NDArray[np.float64], column: int
    ) -> int:
        """Counts distinct (datum, label-in-`column`) pairs with a valid label."""
        # rows carrying a negative label index in `column` have no annotation
        valid_rows = detailed_pairs[detailed_pairs[:, column] >= 0]
        unique_ids = np.unique(valid_rows[:, [0, column]], axis=0)
        return int(unique_ids.shape[0])

    def to_dict(self) -> dict[str, int]:
        """Returns the metadata as a plain dictionary of counts."""
        return asdict(self)
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
@dataclass
class Filter:
    """
    A precomputed restriction of an evaluator's internal cache.

    Attributes
    ----------
    datum_mask : NDArray[bool_]
        Row mask over the detailed-pairs cache selecting the kept datums.
    valid_label_indices : NDArray[int32] | None
        Label indices to keep, or None to keep every label.
    metadata : Metadata
        Counts describing the filtered cache.

    Raises
    ------
    EmptyFilterError
        If the mask keeps no rows or the label selection is empty.
    """

    datum_mask: NDArray[np.bool_]
    valid_label_indices: NDArray[np.int32] | None
    metadata: Metadata

    def __post_init__(self):
        # reject a mask that selects nothing
        keeps_any_datum = self.datum_mask.any()
        if not keeps_any_datum:
            raise EmptyFilterError("filter removes all datums")

        # a missing label selection means "all labels" and needs no check
        selected_labels = self.valid_label_indices
        if selected_labels is None:
            return
        if selected_labels.size == 0:
            raise EmptyFilterError("filter removes all labels")
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
class Evaluator:
    """
    Classification evaluator.

    Classifications are ingested with `add_data`, the internal cache is sorted
    and summarized by `finalize`, and metrics are computed with `evaluate` (or
    the narrower `compute_*` methods). An optional `Filter`, built with
    `create_filter`, restricts a computation to a subset of datums/labels.
    """

    def __init__(self):
        # external references: user-facing ids <-> integer cache indices
        self.datum_id_to_index: dict[str, int] = {}
        self.label_to_index: dict[str, int] = {}

        self.index_to_datum_id: list[str] = []
        self.index_to_label: list[str] = []

        # internal caches
        # `_detailed_pairs` gets one float64 row per (datum, prediction) with
        # columns: datum index, ground truth label index, prediction label
        # index, score, and 1.0 if this is the datum's top-scoring prediction
        # (else 0.0). Rows are built in `add_data`.
        self._detailed_pairs = np.array([])
        # filled by `compute_label_metadata` in `finalize`; the label
        # properties below read column 0 as ground truth presence and column 1
        # as prediction presence.
        self._label_metadata = np.array([], dtype=np.int32)
        self._metadata = Metadata()

    @property
    def metadata(self) -> Metadata:
        """Summary counts for the full (unfiltered) cache, set by `finalize`."""
        return self._metadata

    def _present_label_indices(self) -> tuple[set[int], set[int]]:
        """
        Returns the label indices present as ground truths and as predictions.

        Requires `finalize` to have populated `_label_metadata`.
        """
        groundtruth_labels = set(
            np.flatnonzero(self._label_metadata[:, 0] > 0).tolist()
        )
        prediction_labels = set(
            np.flatnonzero(self._label_metadata[:, 1] > 0).tolist()
        )
        return groundtruth_labels, prediction_labels

    @property
    def ignored_prediction_labels(self) -> list[str]:
        """
        Prediction labels that are not present in the ground truth set.

        Returned in label-index (first-seen) order.
        """
        glabels, plabels = self._present_label_indices()
        return [self.index_to_label[i] for i in sorted(plabels - glabels)]

    @property
    def missing_prediction_labels(self) -> list[str]:
        """
        Ground truth labels that are not present in the prediction set.

        Returned in label-index (first-seen) order.
        """
        glabels, plabels = self._present_label_indices()
        return [self.index_to_label[i] for i in sorted(glabels - plabels)]

    def create_filter(
        self,
        datums: list[str] | NDArray[np.int32] | None = None,
        labels: list[str] | NDArray[np.int32] | None = None,
    ) -> Filter:
        """
        Creates a filter object.

        Parameters
        ----------
        datums : list[str] | NDArray[int32], optional
            An optional list of string uids or integer indices representing datums.
        labels : list[str] | NDArray[int32], optional
            An optional list of strings or integer indices representing labels.

        Returns
        -------
        Filter
            The filter object representing the input parameters.

        Raises
        ------
        EmptyFilterError
            If the datum or label selection is empty.
        ValueError
            If an integer index is negative or out of range.
        KeyError
            If a string uid or label is unknown to the evaluator.
        """
        # create datum mask
        n_pairs = self._detailed_pairs.shape[0]
        datum_mask = np.ones(n_pairs, dtype=np.bool_)
        number_of_datums = self.metadata.number_of_datums
        if datums is not None:
            # convert to array of valid datum indices
            if isinstance(datums, list):
                datums = np.array(
                    [self.datum_id_to_index[uid] for uid in datums],
                    dtype=np.int32,
                )

            # return early if all data removed
            if datums.size == 0:
                raise EmptyFilterError("filter removes all datums")

            # validate indices
            if datums.max() >= len(self.index_to_datum_id):
                raise ValueError(
                    f"datum index '{datums.max()}' exceeds total number of datums"
                )
            elif datums.min() < 0:
                raise ValueError(
                    f"datum index '{datums.min()}' is a negative value"
                )

            # create datum mask
            datum_mask = np.isin(self._detailed_pairs[:, 0], datums)

            # repeated datums in the request must not inflate the datum count,
            # which is used as a denominator downstream
            number_of_datums = int(np.unique(datums).size)

        # collect valid label indices
        if labels is not None:
            # convert to array of valid label indices
            if isinstance(labels, list):
                labels = np.array(
                    [self.label_to_index[label] for label in labels],
                    dtype=np.int32,
                )

            # return early if all data removed
            if labels.size == 0:
                raise EmptyFilterError("filter removes all labels")

            # validate indices
            if labels.max() >= len(self.index_to_label):
                raise ValueError(
                    f"label index '{labels.max()}' exceeds total number of labels"
                )
            elif labels.min() < 0:
                raise ValueError(
                    f"label index '{labels.min()}' is a negative value"
                )

            # add -1 to represent null labels which should not be filtered
            labels = np.concatenate([labels, np.array([-1])])

        filtered_detailed_pairs, _ = filter_cache(
            detailed_pairs=self._detailed_pairs,
            datum_mask=datum_mask,
            valid_label_indices=labels,
            n_labels=self.metadata.number_of_labels,
        )

        return Filter(
            datum_mask=datum_mask,
            valid_label_indices=labels,
            metadata=Metadata.create(
                detailed_pairs=filtered_detailed_pairs,
                number_of_datums=number_of_datums,
                number_of_labels=self.metadata.number_of_labels,
            ),
        )

    def filter(
        self, filter_: Filter
    ) -> tuple[NDArray[np.float64], NDArray[np.int32]]:
        """
        Performs filtering over the internal cache.

        Parameters
        ----------
        filter_ : Filter
            The filter object representation.

        Returns
        -------
        NDArray[float64]
            The filtered detailed pairs.
        NDArray[int32]
            The filtered label metadata.
        """
        return filter_cache(
            detailed_pairs=self._detailed_pairs,
            datum_mask=filter_.datum_mask,
            valid_label_indices=filter_.valid_label_indices,
            n_labels=self.metadata.number_of_labels,
        )

    def compute_precision_recall_rocauc(
        self,
        score_thresholds: list[float] = [0.0],
        hardmax: bool = True,
        filter_: Filter | None = None,
    ) -> dict[MetricType, list]:
        """
        Computes precision/recall/ROC AUC family metrics.

        Parameters
        ----------
        score_thresholds : list[float]
            A list of score thresholds to compute metrics over.
        hardmax : bool
            Toggles whether a hardmax is applied to predictions.
        filter_ : Filter, optional
            Applies a filter to the internal cache.

        Returns
        -------
        dict[MetricType, list]
            A dictionary mapping MetricType enumerations to lists of computed metrics.
        """
        # apply filters
        if filter_ is not None:
            detailed_pairs, label_metadata = self.filter(filter_=filter_)
            n_datums = filter_.metadata.number_of_datums
        else:
            detailed_pairs = self._detailed_pairs
            label_metadata = self._label_metadata
            n_datums = self.metadata.number_of_datums

        results = compute_precision_recall_rocauc(
            detailed_pairs=detailed_pairs,
            label_metadata=label_metadata,
            score_thresholds=np.array(score_thresholds),
            hardmax=hardmax,
            n_datums=n_datums,
        )
        return unpack_precision_recall_rocauc_into_metric_lists(
            results=results,
            score_thresholds=score_thresholds,
            hardmax=hardmax,
            label_metadata=label_metadata,
            index_to_label=self.index_to_label,
        )

    def compute_confusion_matrix(
        self,
        score_thresholds: list[float] = [0.0],
        hardmax: bool = True,
        filter_: Filter | None = None,
    ) -> list[Metric]:
        """
        Computes a detailed confusion matrix.

        Parameters
        ----------
        score_thresholds : list[float]
            A list of score thresholds to compute metrics over.
        hardmax : bool
            Toggles whether a hardmax is applied to predictions.
        filter_ : Filter, optional
            Applies a filter to the internal cache.

        Returns
        -------
        list[Metric]
            A list of confusion matrices; empty if the (filtered) cache is empty.
        """
        # apply filters
        if filter_ is not None:
            detailed_pairs, _ = self.filter(filter_=filter_)
        else:
            detailed_pairs = self._detailed_pairs

        if detailed_pairs.size == 0:
            return list()

        result = compute_confusion_matrix(
            detailed_pairs=detailed_pairs,
            score_thresholds=np.array(score_thresholds),
            hardmax=hardmax,
        )
        return unpack_confusion_matrix_into_metric_list(
            detailed_pairs=detailed_pairs,
            result=result,
            score_thresholds=score_thresholds,
            index_to_datum_id=self.index_to_datum_id,
            index_to_label=self.index_to_label,
        )

    def evaluate(
        self,
        score_thresholds: list[float] = [0.0],
        hardmax: bool = True,
        filter_: Filter | None = None,
    ) -> dict[MetricType, list[Metric]]:
        """
        Computes all classification metrics.

        Combines the output of `compute_precision_recall_rocauc` with the
        confusion matrices from `compute_confusion_matrix`, stored under
        `MetricType.ConfusionMatrix`.

        Parameters
        ----------
        score_thresholds : list[float]
            A list of score thresholds to compute metrics over.
        hardmax : bool
            Toggles whether a hardmax is applied to predictions.
        filter_ : Filter, optional
            Applies a filter to the internal cache.

        Returns
        -------
        dict[MetricType, list[Metric]]
            Lists of metrics organized by metric type.
        """
        metrics = self.compute_precision_recall_rocauc(
            score_thresholds=score_thresholds,
            hardmax=hardmax,
            filter_=filter_,
        )
        metrics[MetricType.ConfusionMatrix] = self.compute_confusion_matrix(
            score_thresholds=score_thresholds,
            hardmax=hardmax,
            filter_=filter_,
        )
        return metrics

    def _add_datum(self, uid: str) -> int:
        """
        Helper function for adding a datum to the cache.

        Parameters
        ----------
        uid : str
            The datum uid.

        Returns
        -------
        int
            The datum index.

        Raises
        ------
        ValueError
            If datum id already exists.
        """
        if uid in self.datum_id_to_index:
            raise ValueError(f"datum with id '{uid}' already exists")
        index = len(self.datum_id_to_index)
        self.datum_id_to_index[uid] = index
        self.index_to_datum_id.append(uid)
        return index

    def _add_label(self, label: str) -> int:
        """
        Helper function for adding a label to the cache.

        Labels are indexed in first-seen order; re-adding a known label
        returns its existing index.

        Parameters
        ----------
        label : str
            A string representing a label.

        Returns
        -------
        int
            Label index.
        """
        if label not in self.label_to_index:
            self.label_to_index[label] = len(self.index_to_label)
            self.index_to_label.append(label)
        return self.label_to_index[label]

    def _append_pairs(self, pairs: list[tuple[float, ...]]) -> None:
        """Appends detailed-pair rows to the cache in a single concatenation."""
        if not pairs:
            return
        block = np.array(pairs)
        if self._detailed_pairs.size == 0:
            self._detailed_pairs = block
        else:
            self._detailed_pairs = np.concatenate(
                [self._detailed_pairs, block], axis=0
            )

    def add_data(
        self,
        classifications: list[Classification],
        show_progress: bool = False,
    ):
        """
        Adds classifications to the cache.

        Parameters
        ----------
        classifications : list[Classification]
            A list of Classification objects.
        show_progress : bool, default=False
            Toggle for tqdm progress bar.

        Raises
        ------
        ValueError
            If a classification has no predictions or reuses an existing datum uid.
        """
        # Rows for the whole call are collected and concatenated once rather
        # than reallocating the cache per classification (quadratic overall).
        # The `finally` keeps rows already built if a later item raises, which
        # matches the previous incremental behaviour.
        new_pairs: list[tuple[float, float, float, float, float]] = []
        try:
            for classification in tqdm(
                classifications, disable=not show_progress
            ):
                if len(classification.predictions) == 0:
                    raise ValueError(
                        "Classifications must contain at least one prediction."
                    )

                # update datum uid index
                uid_index = self._add_datum(uid=classification.uid)

                # cache labels and annotations (ground truth label first)
                groundtruth = self._add_label(classification.groundtruth)
                predictions = [
                    (self._add_label(plabel), pscore)
                    for plabel, pscore in zip(
                        classification.predictions, classification.scores
                    )
                ]

                # np.argmax picks the first index on ties
                max_score_idx = np.argmax(
                    np.array([score for _, score in predictions])
                )
                # build this datum's rows fully before committing them
                rows = [
                    (
                        float(uid_index),
                        float(groundtruth),
                        float(plabel),
                        float(score),
                        float(max_score_idx == idx),
                    )
                    for idx, (plabel, score) in enumerate(predictions)
                ]
                new_pairs.extend(rows)
        finally:
            self._append_pairs(new_pairs)

    def finalize(self):
        """
        Performs data finalization and some preprocessing steps.

        Computes per-label metadata, sorts the cache (see below) and records
        summary counts in `metadata`.

        Returns
        -------
        Evaluator
            A ready-to-use evaluator object.

        Raises
        ------
        EmptyEvaluatorError
            If no data has been added.
        """
        if self._detailed_pairs.size == 0:
            raise EmptyEvaluatorError()

        self._label_metadata = compute_label_metadata(
            ids=self._detailed_pairs[:, :3].astype(np.int32),
            n_labels=len(self.index_to_label),
        )
        # np.lexsort uses the LAST key as primary: rows end up ordered by
        # descending score, then prediction label, then ground truth label
        indices = np.lexsort(
            (
                self._detailed_pairs[:, 1],  # ground truth
                self._detailed_pairs[:, 2],  # prediction
                -self._detailed_pairs[:, 3],  # score
            )
        )
        self._detailed_pairs = self._detailed_pairs[indices]
        self._metadata = Metadata.create(
            detailed_pairs=self._detailed_pairs,
            number_of_datums=len(self.index_to_datum_id),
            number_of_labels=len(self.index_to_label),
        )
        return self
|
|
538
|
-
|
|
539
|
-
|
|
540
|
-
class DataLoader(Evaluator):
    """
    Backwards-compatible alias for `Evaluator`.

    Ingestion now lives on `Evaluator` itself. This subclass adds nothing; it
    exists so that code written against the older `DataLoader` name still works.
    """
|