valor-lite 0.36.2__py3-none-any.whl → 0.36.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of valor-lite might be problematic. Click here for more details.
- valor_lite/classification/computation.py +25 -132
- valor_lite/classification/manager.py +72 -54
- valor_lite/classification/metric.py +0 -4
- valor_lite/classification/utilities.py +85 -103
- valor_lite/exceptions.py +3 -3
- valor_lite/object_detection/manager.py +118 -56
- valor_lite/semantic_segmentation/manager.py +55 -32
- {valor_lite-0.36.2.dist-info → valor_lite-0.36.4.dist-info}/METADATA +1 -1
- {valor_lite-0.36.2.dist-info → valor_lite-0.36.4.dist-info}/RECORD +11 -11
- {valor_lite-0.36.2.dist-info → valor_lite-0.36.4.dist-info}/WHEEL +0 -0
- {valor_lite-0.36.2.dist-info → valor_lite-0.36.4.dist-info}/top_level.txt +0 -0
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from enum import IntFlag, auto
|
|
2
|
+
|
|
1
3
|
import numpy as np
|
|
2
4
|
from numpy.typing import NDArray
|
|
3
5
|
|
|
@@ -318,56 +320,20 @@ def compute_precision_recall_rocauc(
|
|
|
318
320
|
)
|
|
319
321
|
|
|
320
322
|
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
) -> tuple[NDArray[np.float64], NDArray[np.int32], NDArray[np.intp]]:
|
|
326
|
-
"""
|
|
327
|
-
Helper function for counting occurrences of unique detailed pairs.
|
|
328
|
-
|
|
329
|
-
Parameters
|
|
330
|
-
----------
|
|
331
|
-
data : NDArray[np.float64]
|
|
332
|
-
A masked portion of a detailed pairs array.
|
|
333
|
-
unique_idx : int | list[int]
|
|
334
|
-
The index or indices upon which uniqueness is constrained.
|
|
335
|
-
label_idx : int | list[int]
|
|
336
|
-
The index or indices within the unique index or indices that encode labels.
|
|
337
|
-
|
|
338
|
-
Returns
|
|
339
|
-
-------
|
|
340
|
-
NDArray[np.float64]
|
|
341
|
-
Examples drawn from the data input.
|
|
342
|
-
NDArray[np.int32]
|
|
343
|
-
Unique label indices.
|
|
344
|
-
NDArray[np.intp]
|
|
345
|
-
Counts for each unique label index.
|
|
346
|
-
"""
|
|
347
|
-
unique_rows, indices = np.unique(
|
|
348
|
-
data.astype(int)[:, unique_idx],
|
|
349
|
-
return_index=True,
|
|
350
|
-
axis=0,
|
|
351
|
-
)
|
|
352
|
-
examples = data[indices]
|
|
353
|
-
labels, counts = np.unique(
|
|
354
|
-
unique_rows[:, label_idx], return_counts=True, axis=0
|
|
355
|
-
)
|
|
356
|
-
return examples, labels, counts
|
|
323
|
+
class PairClassification(IntFlag):
|
|
324
|
+
TP = auto()
|
|
325
|
+
FP_FN_MISCLF = auto()
|
|
326
|
+
FN_UNMATCHED = auto()
|
|
357
327
|
|
|
358
328
|
|
|
359
329
|
def compute_confusion_matrix(
|
|
360
330
|
detailed_pairs: NDArray[np.float64],
|
|
361
|
-
label_metadata: NDArray[np.int32],
|
|
362
331
|
score_thresholds: NDArray[np.float64],
|
|
363
332
|
hardmax: bool,
|
|
364
|
-
|
|
365
|
-
) -> tuple[NDArray[np.float64], NDArray[np.int32]]:
|
|
333
|
+
) -> NDArray[np.uint8]:
|
|
366
334
|
"""
|
|
367
335
|
Compute detailed confusion matrix.
|
|
368
336
|
|
|
369
|
-
Takes data with shape (N, 5):
|
|
370
|
-
|
|
371
337
|
Parameters
|
|
372
338
|
----------
|
|
373
339
|
detailed_pairs : NDArray[np.float64]
|
|
@@ -377,37 +343,22 @@ def compute_confusion_matrix(
|
|
|
377
343
|
Index 2 - Prediction Label Index
|
|
378
344
|
Index 3 - Score
|
|
379
345
|
Index 4 - Hard Max Score
|
|
380
|
-
label_metadata : NDArray[np.int32]
|
|
381
|
-
A 2-D array containing metadata related to labels with shape (n_labels, 2).
|
|
382
|
-
Index 0 - GroundTruth Label Count
|
|
383
|
-
Index 1 - Prediction Label Count
|
|
384
346
|
iou_thresholds : NDArray[np.float64]
|
|
385
347
|
A 1-D array containing IOU thresholds.
|
|
386
348
|
score_thresholds : NDArray[np.float64]
|
|
387
349
|
A 1-D array containing score thresholds.
|
|
388
|
-
n_examples : int
|
|
389
|
-
The maximum number of examples to return per count.
|
|
390
350
|
|
|
391
351
|
Returns
|
|
392
352
|
-------
|
|
393
|
-
NDArray[
|
|
394
|
-
|
|
395
|
-
NDArray[np.int32]
|
|
396
|
-
Unmatched Ground Truths.
|
|
353
|
+
NDArray[uint8]
|
|
354
|
+
Row-wise classification of pairs.
|
|
397
355
|
"""
|
|
398
|
-
|
|
399
|
-
n_labels = label_metadata.shape[0]
|
|
356
|
+
n_pairs = detailed_pairs.shape[0]
|
|
400
357
|
n_scores = score_thresholds.shape[0]
|
|
401
358
|
|
|
402
|
-
|
|
403
|
-
(n_scores,
|
|
404
|
-
|
|
405
|
-
dtype=np.float32,
|
|
406
|
-
)
|
|
407
|
-
unmatched_ground_truths = np.full(
|
|
408
|
-
(n_scores, n_labels, n_examples + 1),
|
|
409
|
-
fill_value=-1,
|
|
410
|
-
dtype=np.int32,
|
|
359
|
+
pair_classifications = np.zeros(
|
|
360
|
+
(n_scores, n_pairs),
|
|
361
|
+
dtype=np.uint8,
|
|
411
362
|
)
|
|
412
363
|
|
|
413
364
|
mask_label_match = np.isclose(detailed_pairs[:, 1], detailed_pairs[:, 2])
|
|
@@ -420,9 +371,9 @@ def compute_confusion_matrix(
|
|
|
420
371
|
if hardmax:
|
|
421
372
|
mask_score &= detailed_pairs[:, 4] > 0.5
|
|
422
373
|
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
374
|
+
mask_true_positives = mask_label_match & mask_score
|
|
375
|
+
mask_misclassifications = ~mask_label_match & mask_score
|
|
376
|
+
mask_unmatched_groundtruths = ~(
|
|
426
377
|
(
|
|
427
378
|
groundtruths.reshape(-1, 1, 2)
|
|
428
379
|
== groundtruths[mask_score].reshape(1, -1, 2)
|
|
@@ -431,73 +382,15 @@ def compute_confusion_matrix(
|
|
|
431
382
|
.any(axis=1)
|
|
432
383
|
)
|
|
433
384
|
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
label_idx=1,
|
|
438
|
-
)
|
|
439
|
-
misclf_examples, misclf_labels, misclf_counts = _count_with_examples(
|
|
440
|
-
data=detailed_pairs[mask_misclf],
|
|
441
|
-
unique_idx=[0, 1, 2],
|
|
442
|
-
label_idx=[1, 2],
|
|
385
|
+
# classify pairings
|
|
386
|
+
pair_classifications[score_idx, mask_true_positives] |= np.uint8(
|
|
387
|
+
PairClassification.TP
|
|
443
388
|
)
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
unique_idx=[0, 1],
|
|
447
|
-
label_idx=1,
|
|
389
|
+
pair_classifications[score_idx, mask_misclassifications] |= np.uint8(
|
|
390
|
+
PairClassification.FP_FN_MISCLF
|
|
448
391
|
)
|
|
392
|
+
pair_classifications[
|
|
393
|
+
score_idx, mask_unmatched_groundtruths
|
|
394
|
+
] |= np.uint8(PairClassification.FN_UNMATCHED)
|
|
449
395
|
|
|
450
|
-
|
|
451
|
-
confusion_matrix[
|
|
452
|
-
score_idx, misclf_labels[:, 0], misclf_labels[:, 1], 0
|
|
453
|
-
] = misclf_counts
|
|
454
|
-
|
|
455
|
-
unmatched_ground_truths[score_idx, misprd_labels, 0] = misprd_counts
|
|
456
|
-
|
|
457
|
-
if n_examples > 0:
|
|
458
|
-
for label_idx in range(n_labels):
|
|
459
|
-
# true-positive examples
|
|
460
|
-
mask_tp_label = tp_examples[:, 2] == label_idx
|
|
461
|
-
if mask_tp_label.sum() > 0:
|
|
462
|
-
tp_label_examples = tp_examples[mask_tp_label][:n_examples]
|
|
463
|
-
confusion_matrix[
|
|
464
|
-
score_idx,
|
|
465
|
-
label_idx,
|
|
466
|
-
label_idx,
|
|
467
|
-
1 : 2 * tp_label_examples.shape[0] + 1,
|
|
468
|
-
] = tp_label_examples[:, [0, 3]].flatten()
|
|
469
|
-
|
|
470
|
-
# misclassification examples
|
|
471
|
-
mask_misclf_gt_label = misclf_examples[:, 1] == label_idx
|
|
472
|
-
if mask_misclf_gt_label.sum() > 0:
|
|
473
|
-
for pd_label_idx in range(n_labels):
|
|
474
|
-
mask_misclf_pd_label = (
|
|
475
|
-
misclf_examples[:, 2] == pd_label_idx
|
|
476
|
-
)
|
|
477
|
-
mask_misclf_label_combo = (
|
|
478
|
-
mask_misclf_gt_label & mask_misclf_pd_label
|
|
479
|
-
)
|
|
480
|
-
if mask_misclf_label_combo.sum() > 0:
|
|
481
|
-
misclf_label_examples = misclf_examples[
|
|
482
|
-
mask_misclf_label_combo
|
|
483
|
-
][:n_examples]
|
|
484
|
-
confusion_matrix[
|
|
485
|
-
score_idx,
|
|
486
|
-
label_idx,
|
|
487
|
-
pd_label_idx,
|
|
488
|
-
1 : 2 * misclf_label_examples.shape[0] + 1,
|
|
489
|
-
] = misclf_label_examples[:, [0, 3]].flatten()
|
|
490
|
-
|
|
491
|
-
# unmatched ground truth examples
|
|
492
|
-
mask_misprd_label = misprd_examples[:, 1] == label_idx
|
|
493
|
-
if misprd_examples.size > 0:
|
|
494
|
-
misprd_label_examples = misprd_examples[mask_misprd_label][
|
|
495
|
-
:n_examples
|
|
496
|
-
]
|
|
497
|
-
unmatched_ground_truths[
|
|
498
|
-
score_idx,
|
|
499
|
-
label_idx,
|
|
500
|
-
1 : misprd_label_examples.shape[0] + 1,
|
|
501
|
-
] = misprd_label_examples[:, 0].flatten()
|
|
502
|
-
|
|
503
|
-
return confusion_matrix, unmatched_ground_truths # type: ignore[reportReturnType]
|
|
396
|
+
return pair_classifications
|
|
@@ -16,7 +16,7 @@ from valor_lite.classification.utilities import (
|
|
|
16
16
|
unpack_confusion_matrix_into_metric_list,
|
|
17
17
|
unpack_precision_recall_rocauc_into_metric_lists,
|
|
18
18
|
)
|
|
19
|
-
from valor_lite.exceptions import
|
|
19
|
+
from valor_lite.exceptions import EmptyEvaluatorError, EmptyFilterError
|
|
20
20
|
|
|
21
21
|
"""
|
|
22
22
|
Usage
|
|
@@ -88,14 +88,14 @@ class Filter:
|
|
|
88
88
|
def __post_init__(self):
|
|
89
89
|
# validate datum mask
|
|
90
90
|
if not self.datum_mask.any():
|
|
91
|
-
raise
|
|
91
|
+
raise EmptyFilterError("filter removes all datums")
|
|
92
92
|
|
|
93
93
|
# validate label indices
|
|
94
94
|
if (
|
|
95
95
|
self.valid_label_indices is not None
|
|
96
96
|
and self.valid_label_indices.size == 0
|
|
97
97
|
):
|
|
98
|
-
raise
|
|
98
|
+
raise EmptyFilterError("filter removes all labels")
|
|
99
99
|
|
|
100
100
|
|
|
101
101
|
class Evaluator:
|
|
@@ -144,18 +144,18 @@ class Evaluator:
|
|
|
144
144
|
|
|
145
145
|
def create_filter(
|
|
146
146
|
self,
|
|
147
|
-
|
|
148
|
-
labels: list[str] | None = None,
|
|
147
|
+
datums: list[str] | NDArray[np.int32] | None = None,
|
|
148
|
+
labels: list[str] | NDArray[np.int32] | None = None,
|
|
149
149
|
) -> Filter:
|
|
150
150
|
"""
|
|
151
151
|
Creates a filter object.
|
|
152
152
|
|
|
153
153
|
Parameters
|
|
154
154
|
----------
|
|
155
|
-
|
|
156
|
-
An optional list of string uids representing datums.
|
|
157
|
-
labels : list[str], optional
|
|
158
|
-
An optional list of labels.
|
|
155
|
+
datums : list[str] | NDArray[int32], optional
|
|
156
|
+
An optional list of string uids or integer indices representing datums.
|
|
157
|
+
labels : list[str] | NDArray[int32], optional
|
|
158
|
+
An optional list of strings or integer indices representing labels.
|
|
159
159
|
|
|
160
160
|
Returns
|
|
161
161
|
-------
|
|
@@ -165,50 +165,72 @@ class Evaluator:
|
|
|
165
165
|
# create datum mask
|
|
166
166
|
n_pairs = self._detailed_pairs.shape[0]
|
|
167
167
|
datum_mask = np.ones(n_pairs, dtype=np.bool_)
|
|
168
|
-
if
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
168
|
+
if datums is not None:
|
|
169
|
+
# convert to array of valid datum indices
|
|
170
|
+
if isinstance(datums, list):
|
|
171
|
+
datums = np.array(
|
|
172
|
+
[self.datum_id_to_index[uid] for uid in datums],
|
|
173
|
+
dtype=np.int32,
|
|
174
174
|
)
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
)
|
|
175
|
+
|
|
176
|
+
# return early if all data removed
|
|
177
|
+
if datums.size == 0:
|
|
178
|
+
raise EmptyFilterError("filter removes all datums")
|
|
179
|
+
|
|
180
|
+
# validate indices
|
|
181
|
+
if datums.max() >= len(self.index_to_datum_id):
|
|
182
|
+
raise ValueError(
|
|
183
|
+
f"datum index '{datums.max()}' exceeds total number of datums"
|
|
184
|
+
)
|
|
185
|
+
elif datums.min() < 0:
|
|
186
|
+
raise ValueError(
|
|
187
|
+
f"datum index '{datums.min()}' is a negative value"
|
|
188
|
+
)
|
|
189
|
+
|
|
190
|
+
# create datum mask
|
|
191
|
+
datum_mask = np.isin(self._detailed_pairs[:, 0], datums)
|
|
182
192
|
|
|
183
193
|
# collect valid label indices
|
|
184
|
-
valid_label_indices = None
|
|
185
194
|
if labels is not None:
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
metadata=Metadata(),
|
|
195
|
+
# convert to array of valid label indices
|
|
196
|
+
if isinstance(labels, list):
|
|
197
|
+
labels = np.array(
|
|
198
|
+
[self.label_to_index[label] for label in labels]
|
|
191
199
|
)
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
200
|
+
|
|
201
|
+
# return early if all data removed
|
|
202
|
+
if labels.size == 0:
|
|
203
|
+
raise EmptyFilterError("filter removes all labels")
|
|
204
|
+
|
|
205
|
+
# validate indices
|
|
206
|
+
if labels.max() >= len(self.index_to_label):
|
|
207
|
+
raise ValueError(
|
|
208
|
+
f"label index '{labels.max()}' exceeds total number of labels"
|
|
209
|
+
)
|
|
210
|
+
elif labels.min() < 0:
|
|
211
|
+
raise ValueError(
|
|
212
|
+
f"label index '{labels.min()}' is a negative value"
|
|
213
|
+
)
|
|
214
|
+
|
|
215
|
+
# add -1 to represent null labels which should not be filtered
|
|
216
|
+
labels = np.concatenate([labels, np.array([-1])])
|
|
195
217
|
|
|
196
218
|
filtered_detailed_pairs, _ = filter_cache(
|
|
197
219
|
detailed_pairs=self._detailed_pairs,
|
|
198
220
|
datum_mask=datum_mask,
|
|
199
|
-
valid_label_indices=
|
|
221
|
+
valid_label_indices=labels,
|
|
200
222
|
n_labels=self.metadata.number_of_labels,
|
|
201
223
|
)
|
|
202
224
|
|
|
203
225
|
number_of_datums = (
|
|
204
|
-
|
|
205
|
-
if
|
|
226
|
+
datums.size
|
|
227
|
+
if datums is not None
|
|
206
228
|
else self.metadata.number_of_datums
|
|
207
229
|
)
|
|
208
230
|
|
|
209
231
|
return Filter(
|
|
210
232
|
datum_mask=datum_mask,
|
|
211
|
-
valid_label_indices=
|
|
233
|
+
valid_label_indices=labels,
|
|
212
234
|
metadata=Metadata.create(
|
|
213
235
|
detailed_pairs=filtered_detailed_pairs,
|
|
214
236
|
number_of_datums=number_of_datums,
|
|
@@ -292,7 +314,6 @@ class Evaluator:
|
|
|
292
314
|
self,
|
|
293
315
|
score_thresholds: list[float] = [0.0],
|
|
294
316
|
hardmax: bool = True,
|
|
295
|
-
number_of_examples: int = 0,
|
|
296
317
|
filter_: Filter | None = None,
|
|
297
318
|
) -> list[Metric]:
|
|
298
319
|
"""
|
|
@@ -304,8 +325,6 @@ class Evaluator:
|
|
|
304
325
|
A list of score thresholds to compute metrics over.
|
|
305
326
|
hardmax : bool
|
|
306
327
|
Toggles whether a hardmax is applied to predictions.
|
|
307
|
-
number_of_examples : int, default=0
|
|
308
|
-
The number of examples to return per count.
|
|
309
328
|
filter_ : Filter, optional
|
|
310
329
|
Applies a filter to the internal cache.
|
|
311
330
|
|
|
@@ -316,25 +335,22 @@ class Evaluator:
|
|
|
316
335
|
"""
|
|
317
336
|
# apply filters
|
|
318
337
|
if filter_ is not None:
|
|
319
|
-
detailed_pairs,
|
|
338
|
+
detailed_pairs, _ = self.filter(filter_=filter_)
|
|
320
339
|
else:
|
|
321
340
|
detailed_pairs = self._detailed_pairs
|
|
322
|
-
label_metadata = self._label_metadata
|
|
323
341
|
|
|
324
342
|
if detailed_pairs.size == 0:
|
|
325
343
|
return list()
|
|
326
344
|
|
|
327
|
-
|
|
345
|
+
result = compute_confusion_matrix(
|
|
328
346
|
detailed_pairs=detailed_pairs,
|
|
329
|
-
label_metadata=label_metadata,
|
|
330
347
|
score_thresholds=np.array(score_thresholds),
|
|
331
348
|
hardmax=hardmax,
|
|
332
|
-
n_examples=number_of_examples,
|
|
333
349
|
)
|
|
334
350
|
return unpack_confusion_matrix_into_metric_list(
|
|
335
|
-
|
|
351
|
+
detailed_pairs=detailed_pairs,
|
|
352
|
+
result=result,
|
|
336
353
|
score_thresholds=score_thresholds,
|
|
337
|
-
number_of_examples=number_of_examples,
|
|
338
354
|
index_to_datum_id=self.index_to_datum_id,
|
|
339
355
|
index_to_label=self.index_to_label,
|
|
340
356
|
)
|
|
@@ -343,7 +359,6 @@ class Evaluator:
|
|
|
343
359
|
self,
|
|
344
360
|
score_thresholds: list[float] = [0.0],
|
|
345
361
|
hardmax: bool = True,
|
|
346
|
-
number_of_examples: int = 0,
|
|
347
362
|
filter_: Filter | None = None,
|
|
348
363
|
) -> dict[MetricType, list[Metric]]:
|
|
349
364
|
"""
|
|
@@ -355,8 +370,6 @@ class Evaluator:
|
|
|
355
370
|
A list of score thresholds to compute metrics over.
|
|
356
371
|
hardmax : bool
|
|
357
372
|
Toggles whether a hardmax is applied to predictions.
|
|
358
|
-
number_of_examples : int, default=0
|
|
359
|
-
The number of examples to return per count.
|
|
360
373
|
filter_ : Filter, optional
|
|
361
374
|
Applies a filter to the internal cache.
|
|
362
375
|
|
|
@@ -373,7 +386,6 @@ class Evaluator:
|
|
|
373
386
|
metrics[MetricType.ConfusionMatrix] = self.compute_confusion_matrix(
|
|
374
387
|
score_thresholds=score_thresholds,
|
|
375
388
|
hardmax=hardmax,
|
|
376
|
-
number_of_examples=number_of_examples,
|
|
377
389
|
filter_=filter_,
|
|
378
390
|
)
|
|
379
391
|
return metrics
|
|
@@ -391,11 +403,17 @@ class Evaluator:
|
|
|
391
403
|
-------
|
|
392
404
|
int
|
|
393
405
|
The datum index.
|
|
406
|
+
|
|
407
|
+
Raises
|
|
408
|
+
------
|
|
409
|
+
ValueError
|
|
410
|
+
If datum id already exists.
|
|
394
411
|
"""
|
|
395
|
-
if uid
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
412
|
+
if uid in self.datum_id_to_index:
|
|
413
|
+
raise ValueError("datum with id '{uid}' already exists")
|
|
414
|
+
index = len(self.datum_id_to_index)
|
|
415
|
+
self.datum_id_to_index[uid] = index
|
|
416
|
+
self.index_to_datum_id.append(uid)
|
|
399
417
|
return self.datum_id_to_index[uid]
|
|
400
418
|
|
|
401
419
|
def _add_label(self, label: str) -> int:
|
|
@@ -497,7 +515,7 @@ class Evaluator:
|
|
|
497
515
|
A ready-to-use evaluator object.
|
|
498
516
|
"""
|
|
499
517
|
if self._detailed_pairs.size == 0:
|
|
500
|
-
raise
|
|
518
|
+
raise EmptyEvaluatorError()
|
|
501
519
|
|
|
502
520
|
self._label_metadata = compute_label_metadata(
|
|
503
521
|
ids=self._detailed_pairs[:, :3].astype(np.int32),
|
|
@@ -329,7 +329,6 @@ class Metric(BaseMetric):
|
|
|
329
329
|
],
|
|
330
330
|
],
|
|
331
331
|
score_threshold: float,
|
|
332
|
-
maximum_number_of_examples: int,
|
|
333
332
|
):
|
|
334
333
|
"""
|
|
335
334
|
The confusion matrix and related metrics for the classification task.
|
|
@@ -382,8 +381,6 @@ class Metric(BaseMetric):
|
|
|
382
381
|
Each example includes the datum UID.
|
|
383
382
|
score_threshold : float
|
|
384
383
|
The confidence score threshold used to filter predictions.
|
|
385
|
-
maximum_number_of_examples : int
|
|
386
|
-
The maximum number of examples per element.
|
|
387
384
|
|
|
388
385
|
Returns
|
|
389
386
|
-------
|
|
@@ -397,6 +394,5 @@ class Metric(BaseMetric):
|
|
|
397
394
|
},
|
|
398
395
|
parameters={
|
|
399
396
|
"score_threshold": score_threshold,
|
|
400
|
-
"maximum_number_of_examples": maximum_number_of_examples,
|
|
401
397
|
},
|
|
402
398
|
)
|
|
@@ -3,6 +3,7 @@ from collections import defaultdict
|
|
|
3
3
|
import numpy as np
|
|
4
4
|
from numpy.typing import NDArray
|
|
5
5
|
|
|
6
|
+
from valor_lite.classification.computation import PairClassification
|
|
6
7
|
from valor_lite.classification.metric import Metric, MetricType
|
|
7
8
|
|
|
8
9
|
|
|
@@ -101,129 +102,110 @@ def unpack_precision_recall_rocauc_into_metric_lists(
|
|
|
101
102
|
return metrics
|
|
102
103
|
|
|
103
104
|
|
|
104
|
-
def
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
datum_idx = lambda gt_label_idx, pd_label_idx, example_idx: int( # noqa: E731 - lambda fn
|
|
116
|
-
confusion_matrix[
|
|
117
|
-
gt_label_idx,
|
|
118
|
-
pd_label_idx,
|
|
119
|
-
example_idx * 2 + 1,
|
|
120
|
-
]
|
|
121
|
-
)
|
|
122
|
-
|
|
123
|
-
score_idx = lambda gt_label_idx, pd_label_idx, example_idx: float( # noqa: E731 - lambda fn
|
|
124
|
-
confusion_matrix[
|
|
125
|
-
gt_label_idx,
|
|
126
|
-
pd_label_idx,
|
|
127
|
-
example_idx * 2 + 2,
|
|
128
|
-
]
|
|
105
|
+
def _create_empty_confusion_matrix(index_to_labels: list[str]):
|
|
106
|
+
unmatched_ground_truths = dict()
|
|
107
|
+
confusion_matrix = dict()
|
|
108
|
+
for label in index_to_labels:
|
|
109
|
+
unmatched_ground_truths[label] = {"count": 0, "examples": []}
|
|
110
|
+
confusion_matrix[label] = {}
|
|
111
|
+
for plabel in index_to_labels:
|
|
112
|
+
confusion_matrix[label][plabel] = {"count": 0, "examples": []}
|
|
113
|
+
return (
|
|
114
|
+
confusion_matrix,
|
|
115
|
+
unmatched_ground_truths,
|
|
129
116
|
)
|
|
130
117
|
|
|
131
|
-
return {
|
|
132
|
-
index_to_label[gt_label_idx]: {
|
|
133
|
-
index_to_label[pd_label_idx]: {
|
|
134
|
-
"count": max(
|
|
135
|
-
int(confusion_matrix[gt_label_idx, pd_label_idx, 0]),
|
|
136
|
-
0,
|
|
137
|
-
),
|
|
138
|
-
"examples": [
|
|
139
|
-
{
|
|
140
|
-
"datum_id": index_to_datum_id[
|
|
141
|
-
datum_idx(gt_label_idx, pd_label_idx, example_idx)
|
|
142
|
-
],
|
|
143
|
-
"score": score_idx(
|
|
144
|
-
gt_label_idx, pd_label_idx, example_idx
|
|
145
|
-
),
|
|
146
|
-
}
|
|
147
|
-
for example_idx in range(number_of_examples)
|
|
148
|
-
if datum_idx(gt_label_idx, pd_label_idx, example_idx) >= 0
|
|
149
|
-
],
|
|
150
|
-
}
|
|
151
|
-
for pd_label_idx in range(number_of_labels)
|
|
152
|
-
}
|
|
153
|
-
for gt_label_idx in range(number_of_labels)
|
|
154
|
-
}
|
|
155
|
-
|
|
156
118
|
|
|
157
|
-
def
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
119
|
+
def _unpack_confusion_matrix(
|
|
120
|
+
ids: NDArray[np.int32],
|
|
121
|
+
scores: NDArray[np.float64],
|
|
122
|
+
mask_matched: NDArray[np.bool_],
|
|
123
|
+
mask_fn_unmatched: NDArray[np.bool_],
|
|
161
124
|
index_to_datum_id: list[str],
|
|
162
125
|
index_to_label: list[str],
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
)
|
|
126
|
+
score_threshold: float,
|
|
127
|
+
):
|
|
128
|
+
(
|
|
129
|
+
confusion_matrix,
|
|
130
|
+
unmatched_ground_truths,
|
|
131
|
+
) = _create_empty_confusion_matrix(index_to_label)
|
|
132
|
+
|
|
133
|
+
unique_matches, unique_match_indices = np.unique(
|
|
134
|
+
ids[np.ix_(mask_matched, (0, 1, 2))], # type: ignore - numpy ix_ typing
|
|
135
|
+
axis=0,
|
|
136
|
+
return_index=True,
|
|
175
137
|
)
|
|
138
|
+
(
|
|
139
|
+
unique_unmatched_groundtruths,
|
|
140
|
+
unique_unmatched_groundtruth_indices,
|
|
141
|
+
) = np.unique(
|
|
142
|
+
ids[np.ix_(mask_fn_unmatched, (0, 1))], # type: ignore - numpy ix_ typing
|
|
143
|
+
axis=0,
|
|
144
|
+
return_index=True,
|
|
145
|
+
)
|
|
146
|
+
|
|
147
|
+
n_matched = unique_matches.shape[0]
|
|
148
|
+
n_unmatched_groundtruths = unique_unmatched_groundtruths.shape[0]
|
|
149
|
+
n_max = max(n_matched, n_unmatched_groundtruths)
|
|
176
150
|
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
151
|
+
for idx in range(n_max):
|
|
152
|
+
if idx < n_matched:
|
|
153
|
+
glabel = index_to_label[unique_matches[idx, 1]]
|
|
154
|
+
plabel = index_to_label[unique_matches[idx, 2]]
|
|
155
|
+
confusion_matrix[glabel][plabel]["count"] += 1
|
|
156
|
+
confusion_matrix[glabel][plabel]["examples"].append(
|
|
157
|
+
{
|
|
158
|
+
"datum_id": index_to_datum_id[unique_matches[idx, 0]],
|
|
159
|
+
"score": float(scores[unique_match_indices[idx]]),
|
|
160
|
+
}
|
|
161
|
+
)
|
|
162
|
+
if idx < n_unmatched_groundtruths:
|
|
163
|
+
label = index_to_label[unique_unmatched_groundtruths[idx, 1]]
|
|
164
|
+
unmatched_ground_truths[label]["count"] += 1
|
|
165
|
+
unmatched_ground_truths[label]["examples"].append(
|
|
184
166
|
{
|
|
185
167
|
"datum_id": index_to_datum_id[
|
|
186
|
-
|
|
187
|
-
]
|
|
168
|
+
unique_unmatched_groundtruths[idx, 0]
|
|
169
|
+
],
|
|
188
170
|
}
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
171
|
+
)
|
|
172
|
+
|
|
173
|
+
return Metric.confusion_matrix(
|
|
174
|
+
confusion_matrix=confusion_matrix,
|
|
175
|
+
unmatched_ground_truths=unmatched_ground_truths,
|
|
176
|
+
score_threshold=score_threshold,
|
|
177
|
+
)
|
|
195
178
|
|
|
196
179
|
|
|
197
180
|
def unpack_confusion_matrix_into_metric_list(
|
|
198
|
-
|
|
181
|
+
result: NDArray[np.uint8],
|
|
182
|
+
detailed_pairs: NDArray[np.float64],
|
|
199
183
|
score_thresholds: list[float],
|
|
200
|
-
number_of_examples: int,
|
|
201
184
|
index_to_datum_id: list[str],
|
|
202
185
|
index_to_label: list[str],
|
|
203
186
|
) -> list[Metric]:
|
|
204
187
|
|
|
205
|
-
|
|
206
|
-
|
|
188
|
+
ids = detailed_pairs[:, :3].astype(np.int32)
|
|
189
|
+
|
|
190
|
+
mask_matched = (
|
|
191
|
+
np.bitwise_and(
|
|
192
|
+
result, PairClassification.TP | PairClassification.FP_FN_MISCLF
|
|
193
|
+
)
|
|
194
|
+
> 0
|
|
195
|
+
)
|
|
196
|
+
mask_fn_unmatched = (
|
|
197
|
+
np.bitwise_and(result, PairClassification.FN_UNMATCHED) > 0
|
|
198
|
+
)
|
|
199
|
+
|
|
207
200
|
return [
|
|
208
|
-
|
|
201
|
+
_unpack_confusion_matrix(
|
|
202
|
+
ids=ids,
|
|
203
|
+
scores=detailed_pairs[:, 3],
|
|
204
|
+
mask_matched=mask_matched[score_idx, :],
|
|
205
|
+
mask_fn_unmatched=mask_fn_unmatched[score_idx, :],
|
|
206
|
+
index_to_datum_id=index_to_datum_id,
|
|
207
|
+
index_to_label=index_to_label,
|
|
209
208
|
score_threshold=score_threshold,
|
|
210
|
-
maximum_number_of_examples=number_of_examples,
|
|
211
|
-
confusion_matrix=_unpack_confusion_matrix_value(
|
|
212
|
-
confusion_matrix=confusion_matrix[score_idx, :, :, :],
|
|
213
|
-
number_of_labels=n_labels,
|
|
214
|
-
number_of_examples=number_of_examples,
|
|
215
|
-
index_to_label=index_to_label,
|
|
216
|
-
index_to_datum_id=index_to_datum_id,
|
|
217
|
-
),
|
|
218
|
-
unmatched_ground_truths=_unpack_unmatched_ground_truths_value(
|
|
219
|
-
unmatched_ground_truths=unmatched_ground_truths[
|
|
220
|
-
score_idx, :, :
|
|
221
|
-
],
|
|
222
|
-
number_of_labels=n_labels,
|
|
223
|
-
number_of_examples=number_of_examples,
|
|
224
|
-
index_to_label=index_to_label,
|
|
225
|
-
index_to_datum_id=index_to_datum_id,
|
|
226
|
-
),
|
|
227
209
|
)
|
|
228
210
|
for score_idx, score_threshold in enumerate(score_thresholds)
|
|
229
211
|
]
|
valor_lite/exceptions.py
CHANGED
|
@@ -1,15 +1,15 @@
|
|
|
1
|
-
class
|
|
1
|
+
class EmptyEvaluatorError(Exception):
|
|
2
2
|
def __init__(self):
|
|
3
3
|
super().__init__(
|
|
4
4
|
"evaluator cannot be finalized as it contains no data"
|
|
5
5
|
)
|
|
6
6
|
|
|
7
7
|
|
|
8
|
-
class
|
|
8
|
+
class EmptyFilterError(Exception):
|
|
9
9
|
def __init__(self, message: str):
|
|
10
10
|
super().__init__(message)
|
|
11
11
|
|
|
12
12
|
|
|
13
|
-
class
|
|
13
|
+
class InternalCacheError(Exception):
|
|
14
14
|
def __init__(self, message: str):
|
|
15
15
|
super().__init__(message)
|
|
@@ -6,9 +6,9 @@ from numpy.typing import NDArray
|
|
|
6
6
|
from tqdm import tqdm
|
|
7
7
|
|
|
8
8
|
from valor_lite.exceptions import (
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
9
|
+
EmptyEvaluatorError,
|
|
10
|
+
EmptyFilterError,
|
|
11
|
+
InternalCacheError,
|
|
12
12
|
)
|
|
13
13
|
from valor_lite.object_detection.annotation import (
|
|
14
14
|
Bitmask,
|
|
@@ -102,13 +102,13 @@ class Filter:
|
|
|
102
102
|
def __post_init__(self):
|
|
103
103
|
# validate datums mask
|
|
104
104
|
if not self.mask_datums.any():
|
|
105
|
-
raise
|
|
105
|
+
raise EmptyFilterError("filter removes all datums")
|
|
106
106
|
|
|
107
107
|
# validate annotation masks
|
|
108
108
|
no_gts = self.mask_groundtruths.all()
|
|
109
109
|
no_pds = self.mask_predictions.all()
|
|
110
110
|
if no_gts and no_pds:
|
|
111
|
-
raise
|
|
111
|
+
raise EmptyFilterError("filter removes all annotations")
|
|
112
112
|
elif no_gts:
|
|
113
113
|
warnings.warn("filter removes all ground truths")
|
|
114
114
|
elif no_pds:
|
|
@@ -173,38 +173,52 @@ class Evaluator:
|
|
|
173
173
|
|
|
174
174
|
def create_filter(
|
|
175
175
|
self,
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
labels: list[str] | None = None,
|
|
176
|
+
datums: list[str] | NDArray[np.int32] | None = None,
|
|
177
|
+
groundtruths: list[str] | NDArray[np.int32] | None = None,
|
|
178
|
+
predictions: list[str] | NDArray[np.int32] | None = None,
|
|
179
|
+
labels: list[str] | NDArray[np.int32] | None = None,
|
|
180
180
|
) -> Filter:
|
|
181
181
|
"""
|
|
182
182
|
Creates a filter object.
|
|
183
183
|
|
|
184
184
|
Parameters
|
|
185
185
|
----------
|
|
186
|
-
|
|
187
|
-
An optional list of string
|
|
188
|
-
|
|
189
|
-
An optional list of string
|
|
190
|
-
|
|
191
|
-
An optional list of string
|
|
192
|
-
labels : list[str], optional
|
|
193
|
-
An optional list of labels to keep.
|
|
186
|
+
datums : list[str] | NDArray[int32], optional
|
|
187
|
+
An optional list of string ids or indices representing datums to keep.
|
|
188
|
+
groundtruths : list[str] | NDArray[int32], optional
|
|
189
|
+
An optional list of string ids or indices representing ground truth annotations to keep.
|
|
190
|
+
predictions : list[str] | NDArray[int32], optional
|
|
191
|
+
An optional list of string ids or indices representing prediction annotations to keep.
|
|
192
|
+
labels : list[str] | NDArray[int32], optional
|
|
193
|
+
An optional list of labels or indices to keep.
|
|
194
194
|
"""
|
|
195
195
|
mask_datums = np.ones(self._detailed_pairs.shape[0], dtype=np.bool_)
|
|
196
196
|
|
|
197
197
|
# filter datums
|
|
198
|
-
if
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
198
|
+
if datums is not None:
|
|
199
|
+
# convert to indices
|
|
200
|
+
if isinstance(datums, list):
|
|
201
|
+
datums = np.array(
|
|
202
|
+
[self.datum_id_to_index[uid] for uid in datums],
|
|
203
|
+
dtype=np.int32,
|
|
204
|
+
)
|
|
205
|
+
|
|
206
|
+
# validate indices
|
|
207
|
+
if datums.size == 0:
|
|
208
|
+
raise EmptyFilterError(
|
|
209
|
+
"filter removes all datums"
|
|
210
|
+
) # validate indices
|
|
211
|
+
elif datums.min() < 0:
|
|
212
|
+
raise ValueError(
|
|
213
|
+
f"datum index cannot be negative '{datums.min()}'"
|
|
214
|
+
)
|
|
215
|
+
elif datums.max() >= len(self.index_to_datum_id):
|
|
216
|
+
raise ValueError(
|
|
217
|
+
f"datum index cannot exceed total number of datums '{datums.max()}'"
|
|
218
|
+
)
|
|
219
|
+
|
|
220
|
+
# apply to mask
|
|
221
|
+
mask_datums = np.isin(self._detailed_pairs[:, 0], datums)
|
|
208
222
|
|
|
209
223
|
filtered_detailed_pairs = self._detailed_pairs[mask_datums]
|
|
210
224
|
n_pairs = self._detailed_pairs[mask_datums].shape[0]
|
|
@@ -212,43 +226,93 @@ class Evaluator:
|
|
|
212
226
|
mask_predictions = np.zeros_like(mask_groundtruths)
|
|
213
227
|
|
|
214
228
|
# filter by ground truth annotation ids
|
|
215
|
-
if
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
229
|
+
if groundtruths is not None:
|
|
230
|
+
# convert to indices
|
|
231
|
+
if isinstance(groundtruths, list):
|
|
232
|
+
groundtruths = np.array(
|
|
233
|
+
[
|
|
234
|
+
self.groundtruth_id_to_index[uid]
|
|
235
|
+
for uid in groundtruths
|
|
236
|
+
],
|
|
237
|
+
dtype=np.int32,
|
|
238
|
+
)
|
|
239
|
+
|
|
240
|
+
# validate indices
|
|
241
|
+
if groundtruths.size == 0:
|
|
242
|
+
warnings.warn("filter removes all ground truths")
|
|
243
|
+
elif groundtruths.min() < 0:
|
|
244
|
+
raise ValueError(
|
|
245
|
+
f"groundtruth annotation index cannot be negative '{groundtruths.min()}'"
|
|
246
|
+
)
|
|
247
|
+
elif groundtruths.max() >= len(self.index_to_groundtruth_id):
|
|
248
|
+
raise ValueError(
|
|
249
|
+
f"groundtruth annotation index cannot exceed total number of groundtruths '{groundtruths.max()}'"
|
|
250
|
+
)
|
|
251
|
+
|
|
252
|
+
# apply to mask
|
|
220
253
|
mask_groundtruths[
|
|
221
254
|
~np.isin(
|
|
222
255
|
filtered_detailed_pairs[:, 1],
|
|
223
|
-
|
|
256
|
+
groundtruths,
|
|
224
257
|
)
|
|
225
258
|
] = True
|
|
226
259
|
|
|
227
260
|
# filter by prediction annotation ids
|
|
228
|
-
if
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
261
|
+
if predictions is not None:
|
|
262
|
+
# convert to indices
|
|
263
|
+
if isinstance(predictions, list):
|
|
264
|
+
predictions = np.array(
|
|
265
|
+
[self.prediction_id_to_index[uid] for uid in predictions],
|
|
266
|
+
dtype=np.int32,
|
|
267
|
+
)
|
|
268
|
+
|
|
269
|
+
# validate indices
|
|
270
|
+
if predictions.size == 0:
|
|
271
|
+
warnings.warn("filter removes all predictions")
|
|
272
|
+
elif predictions.min() < 0:
|
|
273
|
+
raise ValueError(
|
|
274
|
+
f"prediction annotation index cannot be negative '{predictions.min()}'"
|
|
275
|
+
)
|
|
276
|
+
elif predictions.max() >= len(self.index_to_prediction_id):
|
|
277
|
+
raise ValueError(
|
|
278
|
+
f"prediction annotation index cannot exceed total number of predictions '{predictions.max()}'"
|
|
279
|
+
)
|
|
280
|
+
|
|
281
|
+
# apply to mask
|
|
233
282
|
mask_predictions[
|
|
234
283
|
~np.isin(
|
|
235
284
|
filtered_detailed_pairs[:, 2],
|
|
236
|
-
|
|
285
|
+
predictions,
|
|
237
286
|
)
|
|
238
287
|
] = True
|
|
239
288
|
|
|
240
289
|
# filter by labels
|
|
241
290
|
if labels is not None:
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
291
|
+
# convert to indices
|
|
292
|
+
if isinstance(labels, list):
|
|
293
|
+
labels = np.array(
|
|
294
|
+
[self.label_to_index[label] for label in labels]
|
|
295
|
+
)
|
|
296
|
+
|
|
297
|
+
# validate indices
|
|
298
|
+
if labels.size == 0:
|
|
299
|
+
raise EmptyFilterError("filter removes all labels")
|
|
300
|
+
elif labels.min() < 0:
|
|
301
|
+
raise ValueError(
|
|
302
|
+
f"label index cannot be negative '{labels.min()}'"
|
|
303
|
+
)
|
|
304
|
+
elif labels.max() >= len(self.index_to_label):
|
|
305
|
+
raise ValueError(
|
|
306
|
+
f"label index cannot exceed total number of labels '{labels.max()}'"
|
|
307
|
+
)
|
|
308
|
+
|
|
309
|
+
# apply to mask
|
|
310
|
+
labels = np.concatenate([labels, np.array([-1])]) # add null label
|
|
247
311
|
mask_groundtruths[
|
|
248
|
-
~np.isin(filtered_detailed_pairs[:, 3],
|
|
312
|
+
~np.isin(filtered_detailed_pairs[:, 3], labels)
|
|
249
313
|
] = True
|
|
250
314
|
mask_predictions[
|
|
251
|
-
~np.isin(filtered_detailed_pairs[:, 4],
|
|
315
|
+
~np.isin(filtered_detailed_pairs[:, 4], labels)
|
|
252
316
|
] = True
|
|
253
317
|
|
|
254
318
|
filtered_detailed_pairs, _, _ = filter_cache(
|
|
@@ -260,8 +324,8 @@ class Evaluator:
|
|
|
260
324
|
)
|
|
261
325
|
|
|
262
326
|
number_of_datums = (
|
|
263
|
-
|
|
264
|
-
if
|
|
327
|
+
datums.size
|
|
328
|
+
if datums is not None
|
|
265
329
|
else np.unique(filtered_detailed_pairs[:, 0]).size
|
|
266
330
|
)
|
|
267
331
|
|
|
@@ -467,7 +531,7 @@ class DataLoader:
|
|
|
467
531
|
if len(self._evaluator.datum_id_to_index) != len(
|
|
468
532
|
self._evaluator.index_to_datum_id
|
|
469
533
|
):
|
|
470
|
-
raise
|
|
534
|
+
raise InternalCacheError("datum cache size mismatch")
|
|
471
535
|
idx = len(self._evaluator.datum_id_to_index)
|
|
472
536
|
self._evaluator.datum_id_to_index[datum_id] = idx
|
|
473
537
|
self._evaluator.index_to_datum_id.append(datum_id)
|
|
@@ -491,9 +555,7 @@ class DataLoader:
|
|
|
491
555
|
if len(self._evaluator.groundtruth_id_to_index) != len(
|
|
492
556
|
self._evaluator.index_to_groundtruth_id
|
|
493
557
|
):
|
|
494
|
-
raise
|
|
495
|
-
"ground truth cache size mismatch"
|
|
496
|
-
)
|
|
558
|
+
raise InternalCacheError("ground truth cache size mismatch")
|
|
497
559
|
idx = len(self._evaluator.groundtruth_id_to_index)
|
|
498
560
|
self._evaluator.groundtruth_id_to_index[annotation_id] = idx
|
|
499
561
|
self._evaluator.index_to_groundtruth_id.append(annotation_id)
|
|
@@ -517,7 +579,7 @@ class DataLoader:
|
|
|
517
579
|
if len(self._evaluator.prediction_id_to_index) != len(
|
|
518
580
|
self._evaluator.index_to_prediction_id
|
|
519
581
|
):
|
|
520
|
-
raise
|
|
582
|
+
raise InternalCacheError("prediction cache size mismatch")
|
|
521
583
|
idx = len(self._evaluator.prediction_id_to_index)
|
|
522
584
|
self._evaluator.prediction_id_to_index[annotation_id] = idx
|
|
523
585
|
self._evaluator.index_to_prediction_id.append(annotation_id)
|
|
@@ -542,7 +604,7 @@ class DataLoader:
|
|
|
542
604
|
if len(self._evaluator.label_to_index) != len(
|
|
543
605
|
self._evaluator.index_to_label
|
|
544
606
|
):
|
|
545
|
-
raise
|
|
607
|
+
raise InternalCacheError("label cache size mismatch")
|
|
546
608
|
self._evaluator.label_to_index[label] = label_id
|
|
547
609
|
self._evaluator.index_to_label.append(label)
|
|
548
610
|
label_id += 1
|
|
@@ -768,14 +830,14 @@ class DataLoader:
|
|
|
768
830
|
A ready-to-use evaluator object.
|
|
769
831
|
"""
|
|
770
832
|
if not self.pairs:
|
|
771
|
-
raise
|
|
833
|
+
raise EmptyEvaluatorError()
|
|
772
834
|
|
|
773
835
|
n_labels = len(self._evaluator.index_to_label)
|
|
774
836
|
n_datums = len(self._evaluator.index_to_datum_id)
|
|
775
837
|
|
|
776
838
|
self._evaluator._detailed_pairs = np.concatenate(self.pairs, axis=0)
|
|
777
839
|
if self._evaluator._detailed_pairs.size == 0:
|
|
778
|
-
raise
|
|
840
|
+
raise EmptyEvaluatorError()
|
|
779
841
|
|
|
780
842
|
# order pairs by descending score, iou
|
|
781
843
|
indices = np.lexsort(
|
|
@@ -4,7 +4,7 @@ import numpy as np
|
|
|
4
4
|
from numpy.typing import NDArray
|
|
5
5
|
from tqdm import tqdm
|
|
6
6
|
|
|
7
|
-
from valor_lite.exceptions import
|
|
7
|
+
from valor_lite.exceptions import EmptyEvaluatorError, EmptyFilterError
|
|
8
8
|
from valor_lite.semantic_segmentation.annotation import Segmentation
|
|
9
9
|
from valor_lite.semantic_segmentation.computation import (
|
|
10
10
|
compute_intermediate_confusion_matrices,
|
|
@@ -74,11 +74,11 @@ class Filter:
|
|
|
74
74
|
def __post_init__(self):
|
|
75
75
|
# validate datum mask
|
|
76
76
|
if not self.datum_mask.any():
|
|
77
|
-
raise
|
|
77
|
+
raise EmptyFilterError("filter removes all datums")
|
|
78
78
|
|
|
79
79
|
# validate label mask
|
|
80
80
|
if self.label_mask.all():
|
|
81
|
-
raise
|
|
81
|
+
raise EmptyFilterError("filter removes all labels")
|
|
82
82
|
|
|
83
83
|
|
|
84
84
|
class Evaluator:
|
|
@@ -127,18 +127,18 @@ class Evaluator:
|
|
|
127
127
|
|
|
128
128
|
def create_filter(
|
|
129
129
|
self,
|
|
130
|
-
|
|
131
|
-
labels: list[str] | None = None,
|
|
130
|
+
datums: list[str] | NDArray[np.int64] | None = None,
|
|
131
|
+
labels: list[str] | NDArray[np.int64] | None = None,
|
|
132
132
|
) -> Filter:
|
|
133
133
|
"""
|
|
134
134
|
Creates a filter for use with the evaluator.
|
|
135
135
|
|
|
136
136
|
Parameters
|
|
137
137
|
----------
|
|
138
|
-
|
|
139
|
-
An optional list of string
|
|
140
|
-
labels : list[str], optional
|
|
141
|
-
An optional list of labels.
|
|
138
|
+
datums : list[str] | NDArray[int64], optional
|
|
139
|
+
An optional list of string ids or array of indices representing datums.
|
|
140
|
+
labels : list[str] | NDArray[int64], optional
|
|
141
|
+
An optional list of labels or array of indices.
|
|
142
142
|
|
|
143
143
|
Returns
|
|
144
144
|
-------
|
|
@@ -150,38 +150,61 @@ class Evaluator:
|
|
|
150
150
|
self.metadata.number_of_labels + 1, dtype=np.bool_
|
|
151
151
|
)
|
|
152
152
|
|
|
153
|
-
if
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
153
|
+
if datums is not None:
|
|
154
|
+
# convert to indices
|
|
155
|
+
if isinstance(datums, list):
|
|
156
|
+
datums = np.array(
|
|
157
|
+
[self.datum_id_to_index[uid] for uid in datums],
|
|
158
|
+
dtype=np.int64,
|
|
159
159
|
)
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
160
|
+
|
|
161
|
+
# validate indices
|
|
162
|
+
if datums.size == 0:
|
|
163
|
+
raise EmptyFilterError(
|
|
164
|
+
"filter removes all datums"
|
|
165
|
+
) # validate indices
|
|
166
|
+
elif datums.min() < 0:
|
|
167
|
+
raise ValueError(
|
|
168
|
+
f"datum index cannot be negative '{datums.min()}'"
|
|
169
|
+
)
|
|
170
|
+
elif datums.max() >= len(self.index_to_datum_id):
|
|
171
|
+
raise ValueError(
|
|
172
|
+
f"datum index cannot exceed total number of datums '{datums.max()}'"
|
|
173
|
+
)
|
|
174
|
+
|
|
175
|
+
# apply to mask
|
|
176
|
+
datums.sort()
|
|
165
177
|
mask_valid_datums = (
|
|
166
178
|
np.arange(self._confusion_matrices.shape[0]).reshape(-1, 1)
|
|
167
|
-
==
|
|
179
|
+
== datums.reshape(1, -1)
|
|
168
180
|
).any(axis=1)
|
|
169
181
|
datum_mask[~mask_valid_datums] = False
|
|
170
182
|
|
|
171
183
|
if labels is not None:
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
184
|
+
# convert to indices
|
|
185
|
+
if isinstance(labels, list):
|
|
186
|
+
labels = np.array(
|
|
187
|
+
[self.label_to_index[label] for label in labels],
|
|
188
|
+
dtype=np.int64,
|
|
177
189
|
)
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
190
|
+
|
|
191
|
+
# validate indices
|
|
192
|
+
if labels.size == 0:
|
|
193
|
+
raise EmptyFilterError("filter removes all labels")
|
|
194
|
+
elif labels.min() < 0:
|
|
195
|
+
raise ValueError(
|
|
196
|
+
f"label index cannot be negative '{labels.min()}'"
|
|
197
|
+
)
|
|
198
|
+
elif labels.max() >= len(self.index_to_label):
|
|
199
|
+
raise ValueError(
|
|
200
|
+
f"label index cannot exceed total number of labels '{labels.max()}'"
|
|
201
|
+
)
|
|
202
|
+
|
|
203
|
+
# apply to mask
|
|
204
|
+
labels = np.concatenate([labels, np.array([-1])])
|
|
182
205
|
label_range = np.arange(self.metadata.number_of_labels + 1) - 1
|
|
183
206
|
mask_valid_labels = (
|
|
184
|
-
label_range.reshape(-1, 1) ==
|
|
207
|
+
label_range.reshape(-1, 1) == labels.reshape(1, -1)
|
|
185
208
|
).any(axis=1)
|
|
186
209
|
label_mask[~mask_valid_labels] = True
|
|
187
210
|
|
|
@@ -403,7 +426,7 @@ class DataLoader:
|
|
|
403
426
|
"""
|
|
404
427
|
|
|
405
428
|
if len(self.matrices) == 0:
|
|
406
|
-
raise
|
|
429
|
+
raise EmptyEvaluatorError()
|
|
407
430
|
|
|
408
431
|
n_labels = len(self._evaluator.index_to_label)
|
|
409
432
|
n_datums = len(self._evaluator.index_to_datum_id)
|
|
@@ -1,26 +1,26 @@
|
|
|
1
1
|
valor_lite/LICENSE,sha256=M0L53VuwfEEqezhHb7NPeYcO_glw7-k4DMLZQ3eRN64,1068
|
|
2
2
|
valor_lite/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
3
|
-
valor_lite/exceptions.py,sha256=
|
|
3
|
+
valor_lite/exceptions.py,sha256=Q0PLMu0PnCPBx438iEPzpOQyMOcMOA3lOf5xQZP_yYU,385
|
|
4
4
|
valor_lite/profiling.py,sha256=TLIROA1qccFw9NoEkMeQcrvvGGO75c4K5yTIWoCUix8,11746
|
|
5
5
|
valor_lite/schemas.py,sha256=pB0MrPx5qFLbwBWDiOUUm-vmXdWvbJLFCBmKgbcbI5g,198
|
|
6
6
|
valor_lite/classification/__init__.py,sha256=KXaVwyqAbeeeEq7bzNPyt4GTpbxhrABjV7lR58KR6Y4,440
|
|
7
7
|
valor_lite/classification/annotation.py,sha256=0aUOvcwBAZgiNOJuyh-pXyNTG7vP7r8CUfnU3OmpUwQ,1113
|
|
8
|
-
valor_lite/classification/computation.py,sha256=
|
|
9
|
-
valor_lite/classification/manager.py,sha256=
|
|
10
|
-
valor_lite/classification/metric.py,sha256=
|
|
8
|
+
valor_lite/classification/computation.py,sha256=kB5n-RHzDsKG75Guvgg25xAOeLEQCq1TgjwHwfwbQ60,12010
|
|
9
|
+
valor_lite/classification/manager.py,sha256=JZwA9sf-OG7p7uK5qIo-D711kSpBDDeTcXsPr1uuIBI,16884
|
|
10
|
+
valor_lite/classification/metric.py,sha256=nSNWjoxQ1ou7gxTPOYxLNoUYf7avKQzJq3NHR9jzM48,11693
|
|
11
11
|
valor_lite/classification/numpy_compatibility.py,sha256=roqtTetsm1_HxuaejrthQdydjsRIy-FpXpGb86cLh_E,365
|
|
12
|
-
valor_lite/classification/utilities.py,sha256=
|
|
12
|
+
valor_lite/classification/utilities.py,sha256=jAcir7dW-o4I2gk_NEmlRr8j8Iniyyq9QT5j3PMxVHk,6435
|
|
13
13
|
valor_lite/object_detection/__init__.py,sha256=eSrVAOpSykk1CfHXIKy1necplonUGxjyVKyDQ5UZoBQ,343
|
|
14
14
|
valor_lite/object_detection/annotation.py,sha256=LVec-rIk408LuFxcOoIkPk0QZMWSSxbmsady4wapC1s,7007
|
|
15
15
|
valor_lite/object_detection/computation.py,sha256=njLN-1_yql56NSVxY4KGKohxJUIStPYczVTpEpj5geA,24478
|
|
16
|
-
valor_lite/object_detection/manager.py,sha256=
|
|
16
|
+
valor_lite/object_detection/manager.py,sha256=HfSbq4vfKv2Q3kBRIqpBbq7VCrOxCl7_Pd80yUl6TKQ,30053
|
|
17
17
|
valor_lite/object_detection/metric.py,sha256=sUYSZwXYfIyfmXG6_7Tje1_ZL_QwvecPq85jrGmwOWE,22739
|
|
18
18
|
valor_lite/object_detection/utilities.py,sha256=tNdv5dL7JhzOamGQkZ8x3ocZoTwPI6K8rcRAGMhp2nc,11217
|
|
19
19
|
valor_lite/semantic_segmentation/__init__.py,sha256=3YdItCThY_tW23IChCBm-R0zahnbZ06JDVjs-gQLVes,293
|
|
20
20
|
valor_lite/semantic_segmentation/annotation.py,sha256=XRMV32Sx9A1bAVMFQdBGc3tN5Xz2RfmlyKGXCzdee7A,3705
|
|
21
21
|
valor_lite/semantic_segmentation/benchmark.py,sha256=uxd0SiDY3npsgU5pdeT4HvNP_au9GVRWzoqT6br9DtM,5961
|
|
22
22
|
valor_lite/semantic_segmentation/computation.py,sha256=ZO0qAFmq8lN73UjCyiynSv18qQDtn35FNOmvuXY4rOw,7380
|
|
23
|
-
valor_lite/semantic_segmentation/manager.py,sha256=
|
|
23
|
+
valor_lite/semantic_segmentation/manager.py,sha256=h5w8Xl-O9gZxAzqT-ESofVE2th7d3cYahx4hHBic3pw,14256
|
|
24
24
|
valor_lite/semantic_segmentation/metric.py,sha256=T9RfPJf4WgqGQTXYvSy08vJG5bjXXJnyYZeW0mlxMa8,7132
|
|
25
25
|
valor_lite/semantic_segmentation/utilities.py,sha256=zgVmV8nyKWQK-T4Ov8cZFQzOmTKc5EL7errKFvc2H0g,2957
|
|
26
26
|
valor_lite/text_generation/__init__.py,sha256=pGhpWCSZjLM0pPHCtPykAfos55B8ie3mi9EzbNxfj-U,356
|
|
@@ -35,7 +35,7 @@ valor_lite/text_generation/llm/instructions.py,sha256=fz2onBZZWcl5W8iy7zEWkPGU9N
|
|
|
35
35
|
valor_lite/text_generation/llm/integrations.py,sha256=-rTfdAjq1zH-4ixwYuMQEOQ80pIFzMTe0BYfroVx3Pg,6974
|
|
36
36
|
valor_lite/text_generation/llm/utilities.py,sha256=bjqatGgtVTcl1PrMwiDKTYPGJXKrBrx7PDtzIblGSys,1178
|
|
37
37
|
valor_lite/text_generation/llm/validators.py,sha256=Wzr5RlfF58_2wOU-uTw7C8skan_fYdhy4Gfn0jSJ8HM,2700
|
|
38
|
-
valor_lite-0.36.
|
|
39
|
-
valor_lite-0.36.
|
|
40
|
-
valor_lite-0.36.
|
|
41
|
-
valor_lite-0.36.
|
|
38
|
+
valor_lite-0.36.4.dist-info/METADATA,sha256=2UmPknazuM-lpHiYGqEDGl_JBr7dX-HJPCGFJKY5kck,5071
|
|
39
|
+
valor_lite-0.36.4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
40
|
+
valor_lite-0.36.4.dist-info/top_level.txt,sha256=9ujykxSwpl2Hu0_R95UQTR_l07k9UUTSdrpiqmq6zc4,11
|
|
41
|
+
valor_lite-0.36.4.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|