valor-lite 0.33.5__py3-none-any.whl → 0.33.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- valor_lite/classification/__init__.py +30 -0
- valor_lite/classification/annotation.py +13 -0
- valor_lite/classification/computation.py +411 -0
- valor_lite/classification/manager.py +844 -0
- valor_lite/classification/metric.py +191 -0
- valor_lite/detection/manager.py +19 -8
- {valor_lite-0.33.5.dist-info → valor_lite-0.33.7.dist-info}/METADATA +1 -1
- valor_lite-0.33.7.dist-info/RECORD +17 -0
- valor_lite-0.33.5.dist-info/RECORD +0 -12
- {valor_lite-0.33.5.dist-info → valor_lite-0.33.7.dist-info}/LICENSE +0 -0
- {valor_lite-0.33.5.dist-info → valor_lite-0.33.7.dist-info}/WHEEL +0 -0
- {valor_lite-0.33.5.dist-info → valor_lite-0.33.7.dist-info}/top_level.txt +0 -0
valor_lite/classification/manager.py

@@ -0,0 +1,844 @@
+from collections import defaultdict
+from dataclasses import dataclass
+
+import numpy as np
+from numpy.typing import NDArray
+from tqdm import tqdm
+from valor_lite.classification.annotation import Classification
+from valor_lite.classification.computation import (
+    compute_confusion_matrix,
+    compute_metrics,
+)
+from valor_lite.classification.metric import (
+    F1,
+    ROCAUC,
+    Accuracy,
+    ConfusionMatrix,
+    Counts,
+    MetricType,
+    Precision,
+    Recall,
+    mROCAUC,
+)
+
+"""
+Usage
+-----
+
+manager = DataLoader()
+manager.add_data(
+    groundtruths=groundtruths,
+    predictions=predictions,
+)
+evaluator = manager.finalize()
+
+metrics = evaluator.evaluate()
+
+f1_metrics = metrics[MetricType.F1]
+accuracy_metrics = metrics[MetricType.Accuracy]
+
+filter_mask = evaluator.create_filter(datum_uids=["uid1", "uid2"])
+filtered_metrics = evaluator.evaluate(filter_mask=filter_mask)
+"""
+
+
+@dataclass
+class Filter:
+    indices: NDArray[np.int32]
+    label_metadata: NDArray[np.int32]
+    n_datums: int
+
+
+class Evaluator:
+    """
+    Classification Evaluator
+    """
+
+    def __init__(self):
+
+        # metadata
+        self.n_datums = 0
+        self.n_groundtruths = 0
+        self.n_predictions = 0
+        self.n_labels = 0
+
+        # datum reference
+        self.uid_to_index: dict[str, int] = dict()
+        self.index_to_uid: dict[int, str] = dict()
+
+        # label reference
+        self.label_to_index: dict[tuple[str, str], int] = dict()
+        self.index_to_label: dict[int, tuple[str, str]] = dict()
+
+        # label key reference
+        self.index_to_label_key: dict[int, str] = dict()
+        self.label_key_to_index: dict[str, int] = dict()
+        self.label_index_to_label_key_index: dict[int, int] = dict()
+
+        # computation caches
+        self._detailed_pairs = np.array([])
+        self._label_metadata = np.array([], dtype=np.int32)
+        self._label_metadata_per_datum = np.array([], dtype=np.int32)
+
+    @property
+    def ignored_prediction_labels(self) -> list[tuple[str, str]]:
+        """
+        Prediction labels that are not present in the ground truth set.
+        """
+        glabels = set(np.where(self._label_metadata[:, 0] > 0)[0])
+        plabels = set(np.where(self._label_metadata[:, 1] > 0)[0])
+        return [
+            self.index_to_label[label_id] for label_id in (plabels - glabels)
+        ]
+
+    @property
+    def missing_prediction_labels(self) -> list[tuple[str, str]]:
+        """
+        Ground truth labels that are not present in the prediction set.
+        """
+        glabels = set(np.where(self._label_metadata[:, 0] > 0)[0])
+        plabels = set(np.where(self._label_metadata[:, 1] > 0)[0])
+        return [
+            self.index_to_label[label_id] for label_id in (glabels - plabels)
+        ]
+
+    @property
+    def metadata(self) -> dict:
+        """
+        Evaluation metadata.
+        """
+        return {
+            "n_datums": self.n_datums,
+            "n_groundtruths": self.n_groundtruths,
+            "n_predictions": self.n_predictions,
+            "n_labels": self.n_labels,
+            "ignored_prediction_labels": self.ignored_prediction_labels,
+            "missing_prediction_labels": self.missing_prediction_labels,
+        }
+
+    def create_filter(
+        self,
+        datum_uids: list[str] | NDArray[np.int32] | None = None,
+        labels: list[tuple[str, str]] | NDArray[np.int32] | None = None,
+        label_keys: list[str] | NDArray[np.int32] | None = None,
+    ) -> Filter:
+        """
+        Creates a boolean mask that can be passed to an evaluation.
+
+        Parameters
+        ----------
+        datum_uids : list[str] | NDArray[np.int32], optional
+            An optional list of string uids or a numpy array of uid indices.
+        labels : list[tuple[str, str]] | NDArray[np.int32], optional
+            An optional list of labels or a numpy array of label indices.
+        label_keys : list[str] | NDArray[np.int32], optional
+            An optional list of label keys or a numpy array of label key indices.
+
+        Returns
+        -------
+        Filter
+            A filter object that can be passed to the `evaluate` method.
+        """
+        n_rows = self._detailed_pairs.shape[0]
+
+        n_datums = self._label_metadata_per_datum.shape[1]
+        n_labels = self._label_metadata_per_datum.shape[2]
+
+        mask_pairs = np.ones((n_rows, 1), dtype=np.bool_)
+        mask_datums = np.ones(n_datums, dtype=np.bool_)
+        mask_labels = np.ones(n_labels, dtype=np.bool_)
+
+        if datum_uids is not None:
+            if isinstance(datum_uids, list):
+                datum_uids = np.array(
+                    [self.uid_to_index[uid] for uid in datum_uids],
+                    dtype=np.int32,
+                )
+            mask = np.zeros_like(mask_pairs, dtype=np.bool_)
+            mask[
+                np.isin(self._detailed_pairs[:, 0].astype(int), datum_uids)
+            ] = True
+            mask_pairs &= mask
+
+            mask = np.zeros_like(mask_datums, dtype=np.bool_)
+            mask[datum_uids] = True
+            mask_datums &= mask
+
+        if labels is not None:
+            if isinstance(labels, list):
+                labels = np.array(
+                    [self.label_to_index[label] for label in labels]
+                )
+            mask = np.zeros_like(mask_pairs, dtype=np.bool_)
+            mask[
+                np.isin(self._detailed_pairs[:, 1].astype(int), labels)
+            ] = True
+            mask_pairs &= mask
+
+            mask = np.zeros_like(mask_labels, dtype=np.bool_)
+            mask[labels] = True
+            mask_labels &= mask
+
+        if label_keys is not None:
+            if isinstance(label_keys, list):
+                label_keys = np.array(
+                    [self.label_key_to_index[key] for key in label_keys]
+                )
+            label_indices = np.where(
+                np.isclose(self._label_metadata[:, 2], label_keys)
+            )[0]
+            mask = np.zeros_like(mask_pairs, dtype=np.bool_)
+            mask[
+                np.isin(self._detailed_pairs[:, 1].astype(int), label_indices)
+            ] = True
+            mask_pairs &= mask
+
+            mask = np.zeros_like(mask_labels, dtype=np.bool_)
+            mask[label_indices] = True
+            mask_labels &= mask
+
+        mask = mask_datums[:, np.newaxis] & mask_labels[np.newaxis, :]
+        label_metadata_per_datum = self._label_metadata_per_datum.copy()
+        label_metadata_per_datum[:, ~mask] = 0
+
+        label_metadata = np.zeros_like(self._label_metadata, dtype=np.int32)
+        label_metadata[:, :2] = np.transpose(
+            np.sum(
+                label_metadata_per_datum,
+                axis=1,
+            )
+        )
+        label_metadata[:, 2] = self._label_metadata[:, 2]
+        n_datums = int(np.sum(label_metadata[:, 0]))
+
+        return Filter(
+            indices=np.where(mask_pairs)[0],
+            label_metadata=label_metadata,
+            n_datums=n_datums,
+        )
+
+    def evaluate(
+        self,
+        metrics_to_return: list[MetricType] = MetricType.base(),
+        score_thresholds: list[float] = [0.0],
+        hardmax: bool = True,
+        number_of_examples: int = 0,
+        filter_: Filter | None = None,
+        as_dict: bool = False,
+    ) -> dict[MetricType, list]:
+        """
+        Performs an evaluation and returns metrics.
+
+        Parameters
+        ----------
+        metrics_to_return : list[MetricType]
+            A list of metrics to return in the results.
+        score_thresholds : list[float]
+            A list of score thresholds to compute metrics over.
+        hardmax : bool
+            Toggles whether a hardmax is applied to predictions.
+        number_of_examples : int, default=0
+            Maximum number of annotation examples to return in ConfusionMatrix.
+        filter_ : Filter, optional
+            An optional filter object.
+        as_dict : bool, default=False
+            An option to return metrics as dictionaries.
+
+        Returns
+        -------
+        dict[MetricType, list]
+            A dictionary mapping MetricType enumerations to lists of computed metrics.
+        """
+
+        # apply filters
+        data = self._detailed_pairs
+        label_metadata = self._label_metadata
+        n_datums = self.n_datums
+        if filter_ is not None:
+            data = data[filter_.indices]
+            label_metadata = filter_.label_metadata
+            n_datums = filter_.n_datums
+
+        (
+            counts,
+            precision,
+            recall,
+            accuracy,
+            f1_score,
+            rocauc,
+            mean_rocauc,
+        ) = compute_metrics(
+            data=data,
+            label_metadata=label_metadata,
+            score_thresholds=np.array(score_thresholds),
+            hardmax=hardmax,
+            n_datums=n_datums,
+        )
+
+        metrics = defaultdict(list)
+
+        metrics[MetricType.ROCAUC] = [
+            ROCAUC(
+                value=rocauc[label_idx],
+                label=self.index_to_label[label_idx],
+            )
+            for label_idx in range(label_metadata.shape[0])
+            if label_metadata[label_idx, 0] > 0
+        ]
+
+        metrics[MetricType.mROCAUC] = [
+            mROCAUC(
+                value=mean_rocauc[label_key_idx],
+                label_key=self.index_to_label_key[label_key_idx],
+            )
+            for label_key_idx in range(len(self.label_key_to_index))
+        ]
+
+        for label_idx, label in self.index_to_label.items():
+
+            kwargs = {
+                "label": label,
+                "score_thresholds": score_thresholds,
+                "hardmax": hardmax,
+            }
+            row = counts[:, label_idx]
+            metrics[MetricType.Counts].append(
+                Counts(
+                    tp=row[:, 0].tolist(),
+                    fp=row[:, 1].tolist(),
+                    fn=row[:, 2].tolist(),
+                    tn=row[:, 3].tolist(),
+                    **kwargs,
+                )
+            )
+
+            # if no groundtruths exists for a label, skip it.
+            if label_metadata[label_idx, 0] == 0:
+                continue
+
+            metrics[MetricType.Precision].append(
+                Precision(
+                    value=precision[:, label_idx].tolist(),
+                    **kwargs,
+                )
+            )
+            metrics[MetricType.Recall].append(
+                Recall(
+                    value=recall[:, label_idx].tolist(),
+                    **kwargs,
+                )
+            )
+            metrics[MetricType.Accuracy].append(
+                Accuracy(
+                    value=accuracy[:, label_idx].tolist(),
+                    **kwargs,
+                )
+            )
+            metrics[MetricType.F1].append(
+                F1(
+                    value=f1_score[:, label_idx].tolist(),
+                    **kwargs,
+                )
+            )
+
+        if MetricType.ConfusionMatrix in metrics_to_return:
+            metrics[
+                MetricType.ConfusionMatrix
+            ] = self._compute_confusion_matrix(
+                data=data,
+                label_metadata=label_metadata,
+                score_thresholds=score_thresholds,
+                hardmax=hardmax,
+                number_of_examples=number_of_examples,
+            )
+
+        for metric in set(metrics.keys()):
+            if metric not in metrics_to_return:
+                del metrics[metric]
+
+        if as_dict:
+            return {
+                mtype: [metric.to_dict() for metric in mvalues]
+                for mtype, mvalues in metrics.items()
+            }
+
+        return metrics
+
+    def _unpack_confusion_matrix(
+        self,
+        confusion_matrix: NDArray[np.floating],
+        label_key_idx: int,
+        number_of_labels: int,
+        number_of_examples: int,
+    ) -> dict[
+        str,
+        dict[
+            str,
+            dict[
+                str,
+                int
+                | list[
+                    dict[
+                        str,
+                        str | float,
+                    ]
+                ],
+            ],
+        ],
+    ]:
+        """
+        Unpacks a numpy array of confusion matrix counts and examples.
+        """
+
+        datum_idx = lambda gt_label_idx, pd_label_idx, example_idx: int( # noqa: E731 - lambda fn
+            confusion_matrix[
+                gt_label_idx,
+                pd_label_idx,
+                example_idx * 2 + 1,
+            ]
+        )
+
+        score_idx = lambda gt_label_idx, pd_label_idx, example_idx: float( # noqa: E731 - lambda fn
+            confusion_matrix[
+                gt_label_idx,
+                pd_label_idx,
+                example_idx * 2 + 2,
+            ]
+        )
+
+        return {
+            self.index_to_label[gt_label_idx][1]: {
+                self.index_to_label[pd_label_idx][1]: {
+                    "count": max(
+                        int(confusion_matrix[gt_label_idx, pd_label_idx, 0]),
+                        0,
+                    ),
+                    "examples": [
+                        {
+                            "datum": self.index_to_uid[
+                                datum_idx(
+                                    gt_label_idx, pd_label_idx, example_idx
+                                )
+                            ],
+                            "score": score_idx(
+                                gt_label_idx, pd_label_idx, example_idx
+                            ),
+                        }
+                        for example_idx in range(number_of_examples)
+                        if datum_idx(gt_label_idx, pd_label_idx, example_idx)
+                        >= 0
+                    ],
+                }
+                for pd_label_idx in range(number_of_labels)
+                if (
+                    self.label_index_to_label_key_index[pd_label_idx]
+                    == label_key_idx
+                )
+            }
+            for gt_label_idx in range(number_of_labels)
+            if (
+                self.label_index_to_label_key_index[gt_label_idx]
+                == label_key_idx
+            )
+        }
+
+    def _unpack_missing_predictions(
+        self,
+        missing_predictions: NDArray[np.int32],
+        label_key_idx: int,
+        number_of_labels: int,
+        number_of_examples: int,
+    ) -> dict[str, dict[str, int | list[dict[str, str]]]]:
+        """
+        Unpacks a numpy array of missing prediction counts and examples.
+        """
+
+        datum_idx = (
+            lambda gt_label_idx, example_idx: int( # noqa: E731 - lambda fn
+                missing_predictions[
+                    gt_label_idx,
+                    example_idx + 1,
+                ]
+            )
+        )
+
+        return {
+            self.index_to_label[gt_label_idx][1]: {
+                "count": max(
+                    int(missing_predictions[gt_label_idx, 0]),
+                    0,
+                ),
+                "examples": [
+                    {
+                        "datum": self.index_to_uid[
+                            datum_idx(gt_label_idx, example_idx)
+                        ]
+                    }
+                    for example_idx in range(number_of_examples)
+                    if datum_idx(gt_label_idx, example_idx) >= 0
+                ],
+            }
+            for gt_label_idx in range(number_of_labels)
+            if (
+                self.label_index_to_label_key_index[gt_label_idx]
+                == label_key_idx
+            )
+        }
+
+    def _compute_confusion_matrix(
+        self,
+        data: NDArray[np.floating],
+        label_metadata: NDArray[np.int32],
+        score_thresholds: list[float],
+        hardmax: bool,
+        number_of_examples: int,
+    ) -> list[ConfusionMatrix]:
+        """
+        Computes a detailed confusion matrix..
+
+        Parameters
+        ----------
+        data : NDArray[np.floating]
+            A data array containing classification pairs.
+        label_metadata : NDArray[np.int32]
+            An integer array containing label metadata.
+        score_thresholds : list[float]
+            A list of score thresholds to compute metrics over.
+        hardmax : bool
+            Toggles whether a hardmax is applied to predictions.
+        number_of_examples : int, default=0
+            The number of examples to return per count.
+
+        Returns
+        -------
+        list[ConfusionMatrix]
+            A list of ConfusionMatrix per label key.
+        """
+
+        if data.size == 0:
+            return list()
+
+        confusion_matrix, missing_predictions = compute_confusion_matrix(
+            data=data,
+            label_metadata=label_metadata,
+            score_thresholds=np.array(score_thresholds),
+            hardmax=hardmax,
+            n_examples=number_of_examples,
+        )
+
+        n_scores, n_labels, _, _ = confusion_matrix.shape
+        return [
+            ConfusionMatrix(
+                score_threshold=score_thresholds[score_idx],
+                label_key=label_key,
+                number_of_examples=number_of_examples,
+                confusion_matrix=self._unpack_confusion_matrix(
+                    confusion_matrix=confusion_matrix[score_idx, :, :, :],
+                    label_key_idx=label_key_idx,
+                    number_of_labels=n_labels,
+                    number_of_examples=number_of_examples,
+                ),
+                missing_predictions=self._unpack_missing_predictions(
+                    missing_predictions=missing_predictions[score_idx, :, :],
+                    label_key_idx=label_key_idx,
+                    number_of_labels=n_labels,
+                    number_of_examples=number_of_examples,
+                ),
+            )
+            for label_key_idx, label_key in self.index_to_label_key.items()
+            for score_idx in range(n_scores)
+        ]
+
+
+class DataLoader:
+    """
+    Classification DataLoader.
+    """
+
+    def __init__(self):
+        self._evaluator = Evaluator()
+        self.groundtruth_count = defaultdict(lambda: defaultdict(int))
+        self.prediction_count = defaultdict(lambda: defaultdict(int))
+
+    def _add_datum(self, uid: str) -> int:
+        """
+        Helper function for adding a datum to the cache.
+
+        Parameters
+        ----------
+        uid : str
+            The datum uid.
+
+        Returns
+        -------
+        int
+            The datum index.
+        """
+        if uid not in self._evaluator.uid_to_index:
+            index = len(self._evaluator.uid_to_index)
+            self._evaluator.uid_to_index[uid] = index
+            self._evaluator.index_to_uid[index] = uid
+        return self._evaluator.uid_to_index[uid]
+
+    def _add_label(self, label: tuple[str, str]) -> tuple[int, int]:
+        """
+        Helper function for adding a label to the cache.
+
+        Parameters
+        ----------
+        label : tuple[str, str]
+            The label as a tuple in format (key, value).
+
+        Returns
+        -------
+        int
+            Label index.
+        int
+            Label key index.
+        """
+        label_id = len(self._evaluator.index_to_label)
+        label_key_id = len(self._evaluator.index_to_label_key)
+        if label not in self._evaluator.label_to_index:
+            self._evaluator.label_to_index[label] = label_id
+            self._evaluator.index_to_label[label_id] = label
+
+            # update label key index
+            if label[0] not in self._evaluator.label_key_to_index:
+                self._evaluator.label_key_to_index[label[0]] = label_key_id
+                self._evaluator.index_to_label_key[label_key_id] = label[0]
+                label_key_id += 1
+
+            self._evaluator.label_index_to_label_key_index[
+                label_id
+            ] = self._evaluator.label_key_to_index[label[0]]
+            label_id += 1
+
+        return (
+            self._evaluator.label_to_index[label],
+            self._evaluator.label_key_to_index[label[0]],
+        )
+
+    def _add_data(
+        self,
+        uid_index: int,
+        keyed_groundtruths: dict[int, int],
+        keyed_predictions: dict[int, list[tuple[int, float]]],
+    ):
+        gt_keys = set(keyed_groundtruths.keys())
+        pd_keys = set(keyed_predictions.keys())
+        joint_keys = gt_keys.intersection(pd_keys)
+
+        gt_unique_keys = gt_keys - pd_keys
+        pd_unique_keys = pd_keys - gt_keys
+        if gt_unique_keys or pd_unique_keys:
+            raise ValueError(
+                "Label keys must match between ground truths and predictions."
+            )
+
+        pairs = list()
+        for key in joint_keys:
+            scores = np.array([score for _, score in keyed_predictions[key]])
+            max_score_idx = np.argmax(scores)
+
+            glabel = keyed_groundtruths[key]
+            for idx, (plabel, score) in enumerate(keyed_predictions[key]):
+                pairs.append(
+                    (
+                        float(uid_index),
+                        float(glabel),
+                        float(plabel),
+                        float(score),
+                        float(max_score_idx == idx),
+                    )
+                )
+
+        if self._evaluator._detailed_pairs.size == 0:
+            self._evaluator._detailed_pairs = np.array(pairs)
+        else:
+            self._evaluator._detailed_pairs = np.concatenate(
+                [
+                    self._evaluator._detailed_pairs,
+                    np.array(pairs),
+                ],
+                axis=0,
+            )
+
+    def add_data(
+        self,
+        classifications: list[Classification],
+        show_progress: bool = False,
+    ):
+        """
+        Adds classifications to the cache.
+
+        Parameters
+        ----------
+        classifications : list[Classification]
+            A list of Classification objects.
+        show_progress : bool, default=False
+            Toggle for tqdm progress bar.
+        """
+
+        disable_tqdm = not show_progress
+        for classification in tqdm(classifications, disable=disable_tqdm):
+
+            # update metadata
+            self._evaluator.n_datums += 1
+            self._evaluator.n_groundtruths += len(classification.groundtruths)
+            self._evaluator.n_predictions += len(classification.predictions)
+
+            # update datum uid index
+            uid_index = self._add_datum(uid=classification.uid)
+
+            # cache labels and annotations
+            keyed_groundtruths = defaultdict(int)
+            keyed_predictions = defaultdict(list)
+            for glabel in classification.groundtruths:
+                label_idx, label_key_idx = self._add_label(glabel)
+                self.groundtruth_count[label_idx][uid_index] += 1
+                keyed_groundtruths[label_key_idx] = label_idx
+            for idx, (plabel, pscore) in enumerate(
+                zip(classification.predictions, classification.scores)
+            ):
+                label_idx, label_key_idx = self._add_label(plabel)
+                self.prediction_count[label_idx][uid_index] += 1
+                keyed_predictions[label_key_idx].append(
+                    (
+                        label_idx,
+                        pscore,
+                    )
+                )
+
+            self._add_data(
+                uid_index=uid_index,
+                keyed_groundtruths=keyed_groundtruths,
+                keyed_predictions=keyed_predictions,
+            )
+
+    def add_data_from_valor_dict(
+        self,
+        classifications: list[tuple[dict, dict]],
+        show_progress: bool = False,
+    ):
+        """
+        Adds Valor-format classifications to the cache.
+
+        Parameters
+        ----------
+        classifications : list[tuple[dict, dict]]
+            A list of groundtruth, prediction pairs in Valor-format dictionaries.
+        show_progress : bool, default=False
+            Toggle for tqdm progress bar.
+        """
+
+        disable_tqdm = not show_progress
+        for groundtruth, prediction in tqdm(
+            classifications, disable=disable_tqdm
+        ):
+
+            # update metadata
+            self._evaluator.n_datums += 1
+            self._evaluator.n_groundtruths += len(groundtruth["annotations"])
+            self._evaluator.n_predictions += len(prediction["annotations"])
+
+            # update datum uid index
+            uid_index = self._add_datum(uid=groundtruth["datum"]["uid"])
+
+            # cache labels and annotations
+            keyed_groundtruths = defaultdict(int)
+            keyed_predictions = defaultdict(list)
+            for gann in groundtruth["annotations"]:
+                for valor_label in gann["labels"]:
+                    glabel = (valor_label["key"], valor_label["value"])
+                    label_idx, label_key_idx = self._add_label(glabel)
+                    self.groundtruth_count[label_idx][uid_index] += 1
+                    keyed_groundtruths[label_key_idx] = label_idx
+            for pann in prediction["annotations"]:
+                for valor_label in pann["labels"]:
+                    plabel = (valor_label["key"], valor_label["value"])
+                    pscore = valor_label["score"]
+                    label_idx, label_key_idx = self._add_label(plabel)
+                    self.prediction_count[label_idx][uid_index] += 1
+                    keyed_predictions[label_key_idx].append(
+                        (
+                            label_idx,
+                            pscore,
+                        )
+                    )
+
+            self._add_data(
+                uid_index=uid_index,
+                keyed_groundtruths=keyed_groundtruths,
+                keyed_predictions=keyed_predictions,
+            )
+
+    def finalize(self) -> Evaluator:
+        """
+        Performs data finalization and some preprocessing steps.
+
+        Returns
+        -------
+        Evaluator
+            A ready-to-use evaluator object.
+        """
+
+        if self._evaluator._detailed_pairs.size == 0:
+            raise ValueError("No data available to create evaluator.")
+
+        n_datums = self._evaluator.n_datums
+        n_labels = len(self._evaluator.index_to_label)
+
+        self._evaluator.n_labels = n_labels
+
+        self._evaluator._label_metadata_per_datum = np.zeros(
+            (2, n_datums, n_labels), dtype=np.int32
+        )
+        for datum_idx in range(n_datums):
+            for label_idx in range(n_labels):
+                gt_count = (
+                    self.groundtruth_count[label_idx].get(datum_idx, 0)
+                    if label_idx in self.groundtruth_count
+                    else 0
+                )
+                pd_count = (
+                    self.prediction_count[label_idx].get(datum_idx, 0)
+                    if label_idx in self.prediction_count
+                    else 0
+                )
+                self._evaluator._label_metadata_per_datum[
+                    :, datum_idx, label_idx
+                ] = np.array([gt_count, pd_count])
+
+        self._evaluator._label_metadata = np.array(
+            [
+                [
+                    np.sum(
+                        self._evaluator._label_metadata_per_datum[
+                            0, :, label_idx
+                        ]
+                    ),
+                    np.sum(
+                        self._evaluator._label_metadata_per_datum[
+                            1, :, label_idx
+                        ]
+                    ),
+                    self._evaluator.label_index_to_label_key_index[label_idx],
+                ]
+                for label_idx in range(n_labels)
+            ],
+            dtype=np.int32,
+        )
+
+        # sort pairs by groundtruth, prediction, score
+        indices = np.lexsort(
+            (
+                self._evaluator._detailed_pairs[:, 1],
+                self._evaluator._detailed_pairs[:, 2],
+                -self._evaluator._detailed_pairs[:, 3],
+            )
+        )
+        self._evaluator._detailed_pairs = self._evaluator._detailed_pairs[
+            indices
+        ]
+
+        return self._evaluator
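
The module-level Usage docstring above appears to predate the final API: it passes `groundtruths=`/`predictions=` keywords to `add_data` and a `filter_mask=` keyword to `evaluate`, while the methods defined in this file take a list of `Classification` objects and a `filter_` argument. The sketch below follows the signatures actually shown in the diff; it is a hypothetical example, not part of the package, and the `Classification(uid=..., groundtruths=..., predictions=..., scores=...)` constructor is inferred from the attribute accesses in `DataLoader.add_data` since `annotation.py` is not included here.

```python
# Hypothetical end-to-end usage of the new classification evaluator.
# Classification's fields are inferred from DataLoader.add_data and may
# not match the actual definition in valor_lite/classification/annotation.py.
from valor_lite.classification.annotation import Classification
from valor_lite.classification.manager import DataLoader
from valor_lite.classification.metric import MetricType

clf = Classification(
    uid="uid1",
    groundtruths=[("class", "dog")],                   # (label key, label value)
    predictions=[("class", "dog"), ("class", "cat")],  # candidate labels
    scores=[0.8, 0.2],                                  # parallel confidence scores
)

loader = DataLoader()
loader.add_data([clf])
evaluator = loader.finalize()

# base metrics at a single score threshold
metrics = evaluator.evaluate(score_thresholds=[0.5])
f1_metrics = metrics[MetricType.F1]

# restrict a second evaluation to a subset of datums
subset = evaluator.create_filter(datum_uids=["uid1"])
filtered_metrics = evaluator.evaluate(score_thresholds=[0.5], filter_=subset)
```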
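
`add_data_from_valor_dict` walks nested ground-truth/prediction dictionaries instead of `Classification` objects. The pair below is a sketch reconstructed only from the keys that method reads (`datum.uid`, `annotations[].labels[].key/value`, plus `score` on predictions); a real Valor export may carry additional fields.

```python
# Hypothetical (groundtruth, prediction) pair for add_data_from_valor_dict,
# shaped to match the key accesses in that method.
groundtruth = {
    "datum": {"uid": "uid1"},
    "annotations": [
        {"labels": [{"key": "class", "value": "dog"}]},
    ],
}
prediction = {
    "annotations": [
        {
            "labels": [
                {"key": "class", "value": "dog", "score": 0.8},
                {"key": "class", "value": "cat", "score": 0.2},
            ]
        },
    ],
}

loader = DataLoader()
loader.add_data_from_valor_dict([(groundtruth, prediction)])
evaluator = loader.finalize()
```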
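
Both loaders ultimately flatten each ground-truth/prediction pairing into one row of the evaluator's `_detailed_pairs` cache. The column layout below is read directly from `DataLoader._add_data`; the small array is illustrative only.

```python
import numpy as np

# Row layout produced by DataLoader._add_data:
#   col 0: datum uid index
#   col 1: ground truth label index
#   col 2: predicted label index
#   col 3: prediction score
#   col 4: 1.0 if this is the highest-scoring prediction for its label key, else 0.0
pairs = np.array(
    [
        [0.0, 1.0, 1.0, 0.8, 1.0],
        [0.0, 1.0, 2.0, 0.2, 0.0],
    ]
)

# finalize() orders the cache with np.lexsort over columns 1, 2 and the
# negated score before compute_metrics and compute_confusion_matrix consume it.
order = np.lexsort((pairs[:, 1], pairs[:, 2], -pairs[:, 3]))
pairs = pairs[order]
```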