valor-lite 0.36.2__tar.gz → 0.36.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {valor_lite-0.36.2 → valor_lite-0.36.4}/PKG-INFO +1 -1
- {valor_lite-0.36.2 → valor_lite-0.36.4}/valor_lite/classification/computation.py +25 -132
- {valor_lite-0.36.2 → valor_lite-0.36.4}/valor_lite/classification/manager.py +72 -54
- {valor_lite-0.36.2 → valor_lite-0.36.4}/valor_lite/classification/metric.py +0 -4
- valor_lite-0.36.4/valor_lite/classification/utilities.py +211 -0
- {valor_lite-0.36.2 → valor_lite-0.36.4}/valor_lite/exceptions.py +3 -3
- {valor_lite-0.36.2 → valor_lite-0.36.4}/valor_lite/object_detection/manager.py +118 -56
- {valor_lite-0.36.2 → valor_lite-0.36.4}/valor_lite/semantic_segmentation/manager.py +55 -32
- {valor_lite-0.36.2 → valor_lite-0.36.4}/valor_lite.egg-info/PKG-INFO +1 -1
- valor_lite-0.36.2/valor_lite/classification/utilities.py +0 -229
- {valor_lite-0.36.2 → valor_lite-0.36.4}/README.md +0 -0
- {valor_lite-0.36.2 → valor_lite-0.36.4}/pyproject.toml +0 -0
- {valor_lite-0.36.2 → valor_lite-0.36.4}/setup.cfg +0 -0
- {valor_lite-0.36.2 → valor_lite-0.36.4}/valor_lite/LICENSE +0 -0
- {valor_lite-0.36.2 → valor_lite-0.36.4}/valor_lite/__init__.py +0 -0
- {valor_lite-0.36.2 → valor_lite-0.36.4}/valor_lite/classification/__init__.py +0 -0
- {valor_lite-0.36.2 → valor_lite-0.36.4}/valor_lite/classification/annotation.py +0 -0
- {valor_lite-0.36.2 → valor_lite-0.36.4}/valor_lite/classification/numpy_compatibility.py +0 -0
- {valor_lite-0.36.2 → valor_lite-0.36.4}/valor_lite/object_detection/__init__.py +0 -0
- {valor_lite-0.36.2 → valor_lite-0.36.4}/valor_lite/object_detection/annotation.py +0 -0
- {valor_lite-0.36.2 → valor_lite-0.36.4}/valor_lite/object_detection/computation.py +0 -0
- {valor_lite-0.36.2 → valor_lite-0.36.4}/valor_lite/object_detection/metric.py +0 -0
- {valor_lite-0.36.2 → valor_lite-0.36.4}/valor_lite/object_detection/utilities.py +0 -0
- {valor_lite-0.36.2 → valor_lite-0.36.4}/valor_lite/profiling.py +0 -0
- {valor_lite-0.36.2 → valor_lite-0.36.4}/valor_lite/schemas.py +0 -0
- {valor_lite-0.36.2 → valor_lite-0.36.4}/valor_lite/semantic_segmentation/__init__.py +0 -0
- {valor_lite-0.36.2 → valor_lite-0.36.4}/valor_lite/semantic_segmentation/annotation.py +0 -0
- {valor_lite-0.36.2 → valor_lite-0.36.4}/valor_lite/semantic_segmentation/benchmark.py +0 -0
- {valor_lite-0.36.2 → valor_lite-0.36.4}/valor_lite/semantic_segmentation/computation.py +0 -0
- {valor_lite-0.36.2 → valor_lite-0.36.4}/valor_lite/semantic_segmentation/metric.py +0 -0
- {valor_lite-0.36.2 → valor_lite-0.36.4}/valor_lite/semantic_segmentation/utilities.py +0 -0
- {valor_lite-0.36.2 → valor_lite-0.36.4}/valor_lite/text_generation/__init__.py +0 -0
- {valor_lite-0.36.2 → valor_lite-0.36.4}/valor_lite/text_generation/annotation.py +0 -0
- {valor_lite-0.36.2 → valor_lite-0.36.4}/valor_lite/text_generation/computation.py +0 -0
- {valor_lite-0.36.2 → valor_lite-0.36.4}/valor_lite/text_generation/llm/__init__.py +0 -0
- {valor_lite-0.36.2 → valor_lite-0.36.4}/valor_lite/text_generation/llm/exceptions.py +0 -0
- {valor_lite-0.36.2 → valor_lite-0.36.4}/valor_lite/text_generation/llm/generation.py +0 -0
- {valor_lite-0.36.2 → valor_lite-0.36.4}/valor_lite/text_generation/llm/instructions.py +0 -0
- {valor_lite-0.36.2 → valor_lite-0.36.4}/valor_lite/text_generation/llm/integrations.py +0 -0
- {valor_lite-0.36.2 → valor_lite-0.36.4}/valor_lite/text_generation/llm/utilities.py +0 -0
- {valor_lite-0.36.2 → valor_lite-0.36.4}/valor_lite/text_generation/llm/validators.py +0 -0
- {valor_lite-0.36.2 → valor_lite-0.36.4}/valor_lite/text_generation/manager.py +0 -0
- {valor_lite-0.36.2 → valor_lite-0.36.4}/valor_lite/text_generation/metric.py +0 -0
- {valor_lite-0.36.2 → valor_lite-0.36.4}/valor_lite.egg-info/SOURCES.txt +0 -0
- {valor_lite-0.36.2 → valor_lite-0.36.4}/valor_lite.egg-info/dependency_links.txt +0 -0
- {valor_lite-0.36.2 → valor_lite-0.36.4}/valor_lite.egg-info/requires.txt +0 -0
- {valor_lite-0.36.2 → valor_lite-0.36.4}/valor_lite.egg-info/top_level.txt +0 -0
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from enum import IntFlag, auto
|
|
2
|
+
|
|
1
3
|
import numpy as np
|
|
2
4
|
from numpy.typing import NDArray
|
|
3
5
|
|
|
@@ -318,56 +320,20 @@ def compute_precision_recall_rocauc(
|
|
|
318
320
|
)
|
|
319
321
|
|
|
320
322
|
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
) -> tuple[NDArray[np.float64], NDArray[np.int32], NDArray[np.intp]]:
|
|
326
|
-
"""
|
|
327
|
-
Helper function for counting occurences of unique detailed pairs.
|
|
328
|
-
|
|
329
|
-
Parameters
|
|
330
|
-
----------
|
|
331
|
-
data : NDArray[np.float64]
|
|
332
|
-
A masked portion of a detailed pairs array.
|
|
333
|
-
unique_idx : int | list[int]
|
|
334
|
-
The index or indices upon which uniqueness is constrained.
|
|
335
|
-
label_idx : int | list[int]
|
|
336
|
-
The index or indices within the unique index or indices that encode labels.
|
|
337
|
-
|
|
338
|
-
Returns
|
|
339
|
-
-------
|
|
340
|
-
NDArray[np.float64]
|
|
341
|
-
Examples drawn from the data input.
|
|
342
|
-
NDArray[np.int32]
|
|
343
|
-
Unique label indices.
|
|
344
|
-
NDArray[np.intp]
|
|
345
|
-
Counts for each unique label index.
|
|
346
|
-
"""
|
|
347
|
-
unique_rows, indices = np.unique(
|
|
348
|
-
data.astype(int)[:, unique_idx],
|
|
349
|
-
return_index=True,
|
|
350
|
-
axis=0,
|
|
351
|
-
)
|
|
352
|
-
examples = data[indices]
|
|
353
|
-
labels, counts = np.unique(
|
|
354
|
-
unique_rows[:, label_idx], return_counts=True, axis=0
|
|
355
|
-
)
|
|
356
|
-
return examples, labels, counts
|
|
323
|
+
class PairClassification(IntFlag):
|
|
324
|
+
TP = auto()
|
|
325
|
+
FP_FN_MISCLF = auto()
|
|
326
|
+
FN_UNMATCHED = auto()
|
|
357
327
|
|
|
358
328
|
|
|
359
329
|
def compute_confusion_matrix(
|
|
360
330
|
detailed_pairs: NDArray[np.float64],
|
|
361
|
-
label_metadata: NDArray[np.int32],
|
|
362
331
|
score_thresholds: NDArray[np.float64],
|
|
363
332
|
hardmax: bool,
|
|
364
|
-
|
|
365
|
-
) -> tuple[NDArray[np.float64], NDArray[np.int32]]:
|
|
333
|
+
) -> NDArray[np.uint8]:
|
|
366
334
|
"""
|
|
367
335
|
Compute detailed confusion matrix.
|
|
368
336
|
|
|
369
|
-
Takes data with shape (N, 5):
|
|
370
|
-
|
|
371
337
|
Parameters
|
|
372
338
|
----------
|
|
373
339
|
detailed_pairs : NDArray[np.float64]
|
|
@@ -377,37 +343,22 @@ def compute_confusion_matrix(
|
|
|
377
343
|
Index 2 - Prediction Label Index
|
|
378
344
|
Index 3 - Score
|
|
379
345
|
Index 4 - Hard Max Score
|
|
380
|
-
label_metadata : NDArray[np.int32]
|
|
381
|
-
A 2-D array containing metadata related to labels with shape (n_labels, 2).
|
|
382
|
-
Index 0 - GroundTruth Label Count
|
|
383
|
-
Index 1 - Prediction Label Count
|
|
384
346
|
iou_thresholds : NDArray[np.float64]
|
|
385
347
|
A 1-D array containing IOU thresholds.
|
|
386
348
|
score_thresholds : NDArray[np.float64]
|
|
387
349
|
A 1-D array containing score thresholds.
|
|
388
|
-
n_examples : int
|
|
389
|
-
The maximum number of examples to return per count.
|
|
390
350
|
|
|
391
351
|
Returns
|
|
392
352
|
-------
|
|
393
|
-
NDArray[
|
|
394
|
-
|
|
395
|
-
NDArray[np.int32]
|
|
396
|
-
Unmatched Ground Truths.
|
|
353
|
+
NDArray[uint8]
|
|
354
|
+
Row-wise classification of pairs.
|
|
397
355
|
"""
|
|
398
|
-
|
|
399
|
-
n_labels = label_metadata.shape[0]
|
|
356
|
+
n_pairs = detailed_pairs.shape[0]
|
|
400
357
|
n_scores = score_thresholds.shape[0]
|
|
401
358
|
|
|
402
|
-
|
|
403
|
-
(n_scores,
|
|
404
|
-
|
|
405
|
-
dtype=np.float32,
|
|
406
|
-
)
|
|
407
|
-
unmatched_ground_truths = np.full(
|
|
408
|
-
(n_scores, n_labels, n_examples + 1),
|
|
409
|
-
fill_value=-1,
|
|
410
|
-
dtype=np.int32,
|
|
359
|
+
pair_classifications = np.zeros(
|
|
360
|
+
(n_scores, n_pairs),
|
|
361
|
+
dtype=np.uint8,
|
|
411
362
|
)
|
|
412
363
|
|
|
413
364
|
mask_label_match = np.isclose(detailed_pairs[:, 1], detailed_pairs[:, 2])
|
|
@@ -420,9 +371,9 @@ def compute_confusion_matrix(
|
|
|
420
371
|
if hardmax:
|
|
421
372
|
mask_score &= detailed_pairs[:, 4] > 0.5
|
|
422
373
|
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
374
|
+
mask_true_positives = mask_label_match & mask_score
|
|
375
|
+
mask_misclassifications = ~mask_label_match & mask_score
|
|
376
|
+
mask_unmatched_groundtruths = ~(
|
|
426
377
|
(
|
|
427
378
|
groundtruths.reshape(-1, 1, 2)
|
|
428
379
|
== groundtruths[mask_score].reshape(1, -1, 2)
|
|
@@ -431,73 +382,15 @@ def compute_confusion_matrix(
|
|
|
431
382
|
.any(axis=1)
|
|
432
383
|
)
|
|
433
384
|
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
label_idx=1,
|
|
438
|
-
)
|
|
439
|
-
misclf_examples, misclf_labels, misclf_counts = _count_with_examples(
|
|
440
|
-
data=detailed_pairs[mask_misclf],
|
|
441
|
-
unique_idx=[0, 1, 2],
|
|
442
|
-
label_idx=[1, 2],
|
|
385
|
+
# classify pairings
|
|
386
|
+
pair_classifications[score_idx, mask_true_positives] |= np.uint8(
|
|
387
|
+
PairClassification.TP
|
|
443
388
|
)
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
unique_idx=[0, 1],
|
|
447
|
-
label_idx=1,
|
|
389
|
+
pair_classifications[score_idx, mask_misclassifications] |= np.uint8(
|
|
390
|
+
PairClassification.FP_FN_MISCLF
|
|
448
391
|
)
|
|
392
|
+
pair_classifications[
|
|
393
|
+
score_idx, mask_unmatched_groundtruths
|
|
394
|
+
] |= np.uint8(PairClassification.FN_UNMATCHED)
|
|
449
395
|
|
|
450
|
-
|
|
451
|
-
confusion_matrix[
|
|
452
|
-
score_idx, misclf_labels[:, 0], misclf_labels[:, 1], 0
|
|
453
|
-
] = misclf_counts
|
|
454
|
-
|
|
455
|
-
unmatched_ground_truths[score_idx, misprd_labels, 0] = misprd_counts
|
|
456
|
-
|
|
457
|
-
if n_examples > 0:
|
|
458
|
-
for label_idx in range(n_labels):
|
|
459
|
-
# true-positive examples
|
|
460
|
-
mask_tp_label = tp_examples[:, 2] == label_idx
|
|
461
|
-
if mask_tp_label.sum() > 0:
|
|
462
|
-
tp_label_examples = tp_examples[mask_tp_label][:n_examples]
|
|
463
|
-
confusion_matrix[
|
|
464
|
-
score_idx,
|
|
465
|
-
label_idx,
|
|
466
|
-
label_idx,
|
|
467
|
-
1 : 2 * tp_label_examples.shape[0] + 1,
|
|
468
|
-
] = tp_label_examples[:, [0, 3]].flatten()
|
|
469
|
-
|
|
470
|
-
# misclassification examples
|
|
471
|
-
mask_misclf_gt_label = misclf_examples[:, 1] == label_idx
|
|
472
|
-
if mask_misclf_gt_label.sum() > 0:
|
|
473
|
-
for pd_label_idx in range(n_labels):
|
|
474
|
-
mask_misclf_pd_label = (
|
|
475
|
-
misclf_examples[:, 2] == pd_label_idx
|
|
476
|
-
)
|
|
477
|
-
mask_misclf_label_combo = (
|
|
478
|
-
mask_misclf_gt_label & mask_misclf_pd_label
|
|
479
|
-
)
|
|
480
|
-
if mask_misclf_label_combo.sum() > 0:
|
|
481
|
-
misclf_label_examples = misclf_examples[
|
|
482
|
-
mask_misclf_label_combo
|
|
483
|
-
][:n_examples]
|
|
484
|
-
confusion_matrix[
|
|
485
|
-
score_idx,
|
|
486
|
-
label_idx,
|
|
487
|
-
pd_label_idx,
|
|
488
|
-
1 : 2 * misclf_label_examples.shape[0] + 1,
|
|
489
|
-
] = misclf_label_examples[:, [0, 3]].flatten()
|
|
490
|
-
|
|
491
|
-
# unmatched ground truth examples
|
|
492
|
-
mask_misprd_label = misprd_examples[:, 1] == label_idx
|
|
493
|
-
if misprd_examples.size > 0:
|
|
494
|
-
misprd_label_examples = misprd_examples[mask_misprd_label][
|
|
495
|
-
:n_examples
|
|
496
|
-
]
|
|
497
|
-
unmatched_ground_truths[
|
|
498
|
-
score_idx,
|
|
499
|
-
label_idx,
|
|
500
|
-
1 : misprd_label_examples.shape[0] + 1,
|
|
501
|
-
] = misprd_label_examples[:, 0].flatten()
|
|
502
|
-
|
|
503
|
-
return confusion_matrix, unmatched_ground_truths # type: ignore[reportReturnType]
|
|
396
|
+
return pair_classifications
|
|
@@ -16,7 +16,7 @@ from valor_lite.classification.utilities import (
|
|
|
16
16
|
unpack_confusion_matrix_into_metric_list,
|
|
17
17
|
unpack_precision_recall_rocauc_into_metric_lists,
|
|
18
18
|
)
|
|
19
|
-
from valor_lite.exceptions import
|
|
19
|
+
from valor_lite.exceptions import EmptyEvaluatorError, EmptyFilterError
|
|
20
20
|
|
|
21
21
|
"""
|
|
22
22
|
Usage
|
|
@@ -88,14 +88,14 @@ class Filter:
|
|
|
88
88
|
def __post_init__(self):
|
|
89
89
|
# validate datum mask
|
|
90
90
|
if not self.datum_mask.any():
|
|
91
|
-
raise
|
|
91
|
+
raise EmptyFilterError("filter removes all datums")
|
|
92
92
|
|
|
93
93
|
# validate label indices
|
|
94
94
|
if (
|
|
95
95
|
self.valid_label_indices is not None
|
|
96
96
|
and self.valid_label_indices.size == 0
|
|
97
97
|
):
|
|
98
|
-
raise
|
|
98
|
+
raise EmptyFilterError("filter removes all labels")
|
|
99
99
|
|
|
100
100
|
|
|
101
101
|
class Evaluator:
|
|
@@ -144,18 +144,18 @@ class Evaluator:
|
|
|
144
144
|
|
|
145
145
|
def create_filter(
|
|
146
146
|
self,
|
|
147
|
-
|
|
148
|
-
labels: list[str] | None = None,
|
|
147
|
+
datums: list[str] | NDArray[np.int32] | None = None,
|
|
148
|
+
labels: list[str] | NDArray[np.int32] | None = None,
|
|
149
149
|
) -> Filter:
|
|
150
150
|
"""
|
|
151
151
|
Creates a filter object.
|
|
152
152
|
|
|
153
153
|
Parameters
|
|
154
154
|
----------
|
|
155
|
-
|
|
156
|
-
An optional list of string uids representing datums.
|
|
157
|
-
labels : list[str], optional
|
|
158
|
-
An optional list of labels.
|
|
155
|
+
datums : list[str] | NDArray[int32], optional
|
|
156
|
+
An optional list of string uids or integer indices representing datums.
|
|
157
|
+
labels : list[str] | NDArray[int32], optional
|
|
158
|
+
An optional list of strings or integer indices representing labels.
|
|
159
159
|
|
|
160
160
|
Returns
|
|
161
161
|
-------
|
|
@@ -165,50 +165,72 @@ class Evaluator:
|
|
|
165
165
|
# create datum mask
|
|
166
166
|
n_pairs = self._detailed_pairs.shape[0]
|
|
167
167
|
datum_mask = np.ones(n_pairs, dtype=np.bool_)
|
|
168
|
-
if
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
168
|
+
if datums is not None:
|
|
169
|
+
# convert to array of valid datum indices
|
|
170
|
+
if isinstance(datums, list):
|
|
171
|
+
datums = np.array(
|
|
172
|
+
[self.datum_id_to_index[uid] for uid in datums],
|
|
173
|
+
dtype=np.int32,
|
|
174
174
|
)
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
)
|
|
175
|
+
|
|
176
|
+
# return early if all data removed
|
|
177
|
+
if datums.size == 0:
|
|
178
|
+
raise EmptyFilterError("filter removes all datums")
|
|
179
|
+
|
|
180
|
+
# validate indices
|
|
181
|
+
if datums.max() >= len(self.index_to_datum_id):
|
|
182
|
+
raise ValueError(
|
|
183
|
+
f"datum index '{datums.max()}' exceeds total number of datums"
|
|
184
|
+
)
|
|
185
|
+
elif datums.min() < 0:
|
|
186
|
+
raise ValueError(
|
|
187
|
+
f"datum index '{datums.min()}' is a negative value"
|
|
188
|
+
)
|
|
189
|
+
|
|
190
|
+
# create datum mask
|
|
191
|
+
datum_mask = np.isin(self._detailed_pairs[:, 0], datums)
|
|
182
192
|
|
|
183
193
|
# collect valid label indices
|
|
184
|
-
valid_label_indices = None
|
|
185
194
|
if labels is not None:
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
metadata=Metadata(),
|
|
195
|
+
# convert to array of valid label indices
|
|
196
|
+
if isinstance(labels, list):
|
|
197
|
+
labels = np.array(
|
|
198
|
+
[self.label_to_index[label] for label in labels]
|
|
191
199
|
)
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
200
|
+
|
|
201
|
+
# return early if all data removed
|
|
202
|
+
if labels.size == 0:
|
|
203
|
+
raise EmptyFilterError("filter removes all labels")
|
|
204
|
+
|
|
205
|
+
# validate indices
|
|
206
|
+
if labels.max() >= len(self.index_to_label):
|
|
207
|
+
raise ValueError(
|
|
208
|
+
f"label index '{labels.max()}' exceeds total number of labels"
|
|
209
|
+
)
|
|
210
|
+
elif labels.min() < 0:
|
|
211
|
+
raise ValueError(
|
|
212
|
+
f"label index '{labels.min()}' is a negative value"
|
|
213
|
+
)
|
|
214
|
+
|
|
215
|
+
# add -1 to represent null labels which should not be filtered
|
|
216
|
+
labels = np.concatenate([labels, np.array([-1])])
|
|
195
217
|
|
|
196
218
|
filtered_detailed_pairs, _ = filter_cache(
|
|
197
219
|
detailed_pairs=self._detailed_pairs,
|
|
198
220
|
datum_mask=datum_mask,
|
|
199
|
-
valid_label_indices=
|
|
221
|
+
valid_label_indices=labels,
|
|
200
222
|
n_labels=self.metadata.number_of_labels,
|
|
201
223
|
)
|
|
202
224
|
|
|
203
225
|
number_of_datums = (
|
|
204
|
-
|
|
205
|
-
if
|
|
226
|
+
datums.size
|
|
227
|
+
if datums is not None
|
|
206
228
|
else self.metadata.number_of_datums
|
|
207
229
|
)
|
|
208
230
|
|
|
209
231
|
return Filter(
|
|
210
232
|
datum_mask=datum_mask,
|
|
211
|
-
valid_label_indices=
|
|
233
|
+
valid_label_indices=labels,
|
|
212
234
|
metadata=Metadata.create(
|
|
213
235
|
detailed_pairs=filtered_detailed_pairs,
|
|
214
236
|
number_of_datums=number_of_datums,
|
|
@@ -292,7 +314,6 @@ class Evaluator:
|
|
|
292
314
|
self,
|
|
293
315
|
score_thresholds: list[float] = [0.0],
|
|
294
316
|
hardmax: bool = True,
|
|
295
|
-
number_of_examples: int = 0,
|
|
296
317
|
filter_: Filter | None = None,
|
|
297
318
|
) -> list[Metric]:
|
|
298
319
|
"""
|
|
@@ -304,8 +325,6 @@ class Evaluator:
|
|
|
304
325
|
A list of score thresholds to compute metrics over.
|
|
305
326
|
hardmax : bool
|
|
306
327
|
Toggles whether a hardmax is applied to predictions.
|
|
307
|
-
number_of_examples : int, default=0
|
|
308
|
-
The number of examples to return per count.
|
|
309
328
|
filter_ : Filter, optional
|
|
310
329
|
Applies a filter to the internal cache.
|
|
311
330
|
|
|
@@ -316,25 +335,22 @@ class Evaluator:
|
|
|
316
335
|
"""
|
|
317
336
|
# apply filters
|
|
318
337
|
if filter_ is not None:
|
|
319
|
-
detailed_pairs,
|
|
338
|
+
detailed_pairs, _ = self.filter(filter_=filter_)
|
|
320
339
|
else:
|
|
321
340
|
detailed_pairs = self._detailed_pairs
|
|
322
|
-
label_metadata = self._label_metadata
|
|
323
341
|
|
|
324
342
|
if detailed_pairs.size == 0:
|
|
325
343
|
return list()
|
|
326
344
|
|
|
327
|
-
|
|
345
|
+
result = compute_confusion_matrix(
|
|
328
346
|
detailed_pairs=detailed_pairs,
|
|
329
|
-
label_metadata=label_metadata,
|
|
330
347
|
score_thresholds=np.array(score_thresholds),
|
|
331
348
|
hardmax=hardmax,
|
|
332
|
-
n_examples=number_of_examples,
|
|
333
349
|
)
|
|
334
350
|
return unpack_confusion_matrix_into_metric_list(
|
|
335
|
-
|
|
351
|
+
detailed_pairs=detailed_pairs,
|
|
352
|
+
result=result,
|
|
336
353
|
score_thresholds=score_thresholds,
|
|
337
|
-
number_of_examples=number_of_examples,
|
|
338
354
|
index_to_datum_id=self.index_to_datum_id,
|
|
339
355
|
index_to_label=self.index_to_label,
|
|
340
356
|
)
|
|
@@ -343,7 +359,6 @@ class Evaluator:
|
|
|
343
359
|
self,
|
|
344
360
|
score_thresholds: list[float] = [0.0],
|
|
345
361
|
hardmax: bool = True,
|
|
346
|
-
number_of_examples: int = 0,
|
|
347
362
|
filter_: Filter | None = None,
|
|
348
363
|
) -> dict[MetricType, list[Metric]]:
|
|
349
364
|
"""
|
|
@@ -355,8 +370,6 @@ class Evaluator:
|
|
|
355
370
|
A list of score thresholds to compute metrics over.
|
|
356
371
|
hardmax : bool
|
|
357
372
|
Toggles whether a hardmax is applied to predictions.
|
|
358
|
-
number_of_examples : int, default=0
|
|
359
|
-
The number of examples to return per count.
|
|
360
373
|
filter_ : Filter, optional
|
|
361
374
|
Applies a filter to the internal cache.
|
|
362
375
|
|
|
@@ -373,7 +386,6 @@ class Evaluator:
|
|
|
373
386
|
metrics[MetricType.ConfusionMatrix] = self.compute_confusion_matrix(
|
|
374
387
|
score_thresholds=score_thresholds,
|
|
375
388
|
hardmax=hardmax,
|
|
376
|
-
number_of_examples=number_of_examples,
|
|
377
389
|
filter_=filter_,
|
|
378
390
|
)
|
|
379
391
|
return metrics
|
|
@@ -391,11 +403,17 @@ class Evaluator:
|
|
|
391
403
|
-------
|
|
392
404
|
int
|
|
393
405
|
The datum index.
|
|
406
|
+
|
|
407
|
+
Raises
|
|
408
|
+
------
|
|
409
|
+
ValueError
|
|
410
|
+
If datum id already exists.
|
|
394
411
|
"""
|
|
395
|
-
if uid
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
412
|
+
if uid in self.datum_id_to_index:
|
|
413
|
+
raise ValueError("datum with id '{uid}' already exists")
|
|
414
|
+
index = len(self.datum_id_to_index)
|
|
415
|
+
self.datum_id_to_index[uid] = index
|
|
416
|
+
self.index_to_datum_id.append(uid)
|
|
399
417
|
return self.datum_id_to_index[uid]
|
|
400
418
|
|
|
401
419
|
def _add_label(self, label: str) -> int:
|
|
@@ -497,7 +515,7 @@ class Evaluator:
|
|
|
497
515
|
A ready-to-use evaluator object.
|
|
498
516
|
"""
|
|
499
517
|
if self._detailed_pairs.size == 0:
|
|
500
|
-
raise
|
|
518
|
+
raise EmptyEvaluatorError()
|
|
501
519
|
|
|
502
520
|
self._label_metadata = compute_label_metadata(
|
|
503
521
|
ids=self._detailed_pairs[:, :3].astype(np.int32),
|
|
@@ -329,7 +329,6 @@ class Metric(BaseMetric):
|
|
|
329
329
|
],
|
|
330
330
|
],
|
|
331
331
|
score_threshold: float,
|
|
332
|
-
maximum_number_of_examples: int,
|
|
333
332
|
):
|
|
334
333
|
"""
|
|
335
334
|
The confusion matrix and related metrics for the classification task.
|
|
@@ -382,8 +381,6 @@ class Metric(BaseMetric):
|
|
|
382
381
|
Each example includes the datum UID.
|
|
383
382
|
score_threshold : float
|
|
384
383
|
The confidence score threshold used to filter predictions.
|
|
385
|
-
maximum_number_of_examples : int
|
|
386
|
-
The maximum number of examples per element.
|
|
387
384
|
|
|
388
385
|
Returns
|
|
389
386
|
-------
|
|
@@ -397,6 +394,5 @@ class Metric(BaseMetric):
|
|
|
397
394
|
},
|
|
398
395
|
parameters={
|
|
399
396
|
"score_threshold": score_threshold,
|
|
400
|
-
"maximum_number_of_examples": maximum_number_of_examples,
|
|
401
397
|
},
|
|
402
398
|
)
|