valor-lite 0.33.0__py3-none-any.whl → 0.33.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- valor_lite/detection/__init__.py +4 -6
- valor_lite/detection/computation.py +243 -64
- valor_lite/detection/manager.py +216 -98
- valor_lite/detection/metric.py +77 -77
- {valor_lite-0.33.0.dist-info → valor_lite-0.33.2.dist-info}/METADATA +1 -1
- valor_lite-0.33.2.dist-info/RECORD +12 -0
- valor_lite-0.33.0.dist-info/RECORD +0 -12
- {valor_lite-0.33.0.dist-info → valor_lite-0.33.2.dist-info}/LICENSE +0 -0
- {valor_lite-0.33.0.dist-info → valor_lite-0.33.2.dist-info}/WHEEL +0 -0
- {valor_lite-0.33.0.dist-info → valor_lite-0.33.2.dist-info}/top_level.txt +0 -0
valor_lite/detection/__init__.py
CHANGED
@@ -1,6 +1,6 @@
 from .annotation import Bitmask, BoundingBox, Detection
 from .computation import (
-    compute_detailed_pr_curve,
+    compute_detailed_counts,
     compute_iou,
     compute_metrics,
     compute_ranked_pairs,
@@ -14,8 +14,7 @@ from .metric import (
     APAveragedOverIOUs,
     ARAveragedOverScores,
     Counts,
-    DetailedPrecisionRecallCurve,
-    DetailedPrecisionRecallPoint,
+    DetailedCounts,
     MetricType,
     Precision,
     PrecisionRecallCurve,
@@ -45,12 +44,11 @@ __all__ = [
     "ARAveragedOverScores",
     "mARAveragedOverScores",
     "PrecisionRecallCurve",
-    "DetailedPrecisionRecallPoint",
-    "DetailedPrecisionRecallCurve",
+    "DetailedCounts",
     "compute_iou",
     "compute_ranked_pairs",
     "compute_metrics",
-    "compute_detailed_pr_curve",
+    "compute_detailed_counts",
     "DataLoader",
     "Evaluator",
 ]
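The net effect in this file is a rename of the detailed precision-recall exports: DetailedCounts and compute_detailed_counts replace the previous detailed PR-curve names. A minimal import sketch under the updated __all__ (assuming the package is installed):

# Names exported by valor_lite.detection after this change.
from valor_lite.detection import DetailedCounts, compute_detailed_counts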
valor_lite/detection/computation.py
CHANGED
@@ -1,16 +1,38 @@
 import numpy as np
 from numpy.typing import NDArray

-# datum id 0
-# gt 1
-# pd 2
-# iou 3
-# gt label 4
-# pd label 5
-# score 6
-

 def compute_iou(data: NDArray[np.floating]) -> NDArray[np.floating]:
+    """
+    Computes intersection-over-union (IoU) for axis-aligned bounding boxes.
+
+    Takes data with shape (N, 8):
+
+    Index 0 - xmin for Box 1
+    Index 1 - xmax for Box 1
+    Index 2 - ymin for Box 1
+    Index 3 - ymax for Box 1
+    Index 4 - xmin for Box 2
+    Index 5 - xmax for Box 2
+    Index 6 - ymin for Box 2
+    Index 7 - ymax for Box 2
+
+    Returns data with shape (N, 1):
+
+    Index 0 - IoU
+
+    Parameters
+    ----------
+    data : NDArray[np.floating]
+        A sorted array of classification pairs.
+    label_metadata : NDArray[np.int32]
+        An array containing metadata related to labels.
+
+    Returns
+    -------
+    NDArray[np.floating]
+        Compute IoU's.
+    """

     xmin1, xmax1, ymin1, ymax1 = (
         data[:, 0],
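For orientation, here is a minimal, self-contained sketch of IoU over the (N, 8) box-pair layout documented above. It illustrates the quantity the docstring describes rather than reproducing the package's implementation; the function name iou_from_pairs and the sample values are hypothetical.

import numpy as np

def iou_from_pairs(data: np.ndarray) -> np.ndarray:
    # Columns follow the documented layout:
    # [xmin1, xmax1, ymin1, ymax1, xmin2, xmax2, ymin2, ymax2].
    xmin1, xmax1, ymin1, ymax1 = data[:, 0], data[:, 1], data[:, 2], data[:, 3]
    xmin2, xmax2, ymin2, ymax2 = data[:, 4], data[:, 5], data[:, 6], data[:, 7]

    # Clip the intersection extents at zero so disjoint boxes contribute 0.
    iw = np.maximum(np.minimum(xmax1, xmax2) - np.maximum(xmin1, xmin2), 0.0)
    ih = np.maximum(np.minimum(ymax1, ymax2) - np.maximum(ymin1, ymin2), 0.0)
    intersection = iw * ih

    area1 = (xmax1 - xmin1) * (ymax1 - ymin1)
    area2 = (xmax2 - xmin2) * (ymax2 - ymin2)
    union = area1 + area2 - intersection
    return intersection / np.maximum(union, 1e-9)  # guard against degenerate boxes

pairs = np.array([[0.0, 10.0, 0.0, 10.0, 5.0, 15.0, 0.0, 10.0]])
print(iou_from_pairs(pairs))  # [0.33333333]: 50 / (100 + 100 - 50)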
@@ -49,7 +71,7 @@ def compute_iou(data: NDArray[np.floating]) -> NDArray[np.floating]:

 def _compute_ranked_pairs_for_datum(
     data: np.ndarray,
-
+    label_metadata: np.ndarray,
 ) -> np.ndarray:
     """
     Computes ranked pairs for a datum.
@@ -58,6 +80,12 @@ def _compute_ranked_pairs_for_datum(
     # remove null predictions
     data = data[data[:, 2] >= 0.0]

+    # find best fits for prediction
+    mask_label_match = data[:, 4] == data[:, 5]
+    matched_predicitons = np.unique(data[mask_label_match, 2].astype(int))
+    mask_unmatched_predictions = ~np.isin(data[:, 2], matched_predicitons)
+    data = data[mask_label_match | mask_unmatched_predictions]
+
     # sort by gt_id, iou, score
     indices = np.lexsort(
         (
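The new "find best fits for prediction" block keeps only the label-matching pairs of any prediction that has at least one such pair, while predictions that never match on label keep all of their pairs. A small demonstration with hypothetical (N, 7) rows, reading only columns 2, 4 and 5 (prediction index, ground-truth label, prediction label):

import numpy as np

# Hypothetical pairs; only columns 2, 4 and 5 matter for this filter.
pairs = np.array(
    [
        # prediction 0 has a label-matching pair, so its mismatched pair is dropped
        [0, 0, 0, 0.9, 1, 1, 0.8],
        [0, 1, 0, 0.4, 2, 1, 0.8],
        # prediction 1 never matches on label, so both of its pairs are kept
        [0, 0, 1, 0.3, 1, 2, 0.6],
        [0, 1, 1, 0.2, 2, 3, 0.6],
    ]
)

mask_label_match = pairs[:, 4] == pairs[:, 5]
matched_predictions = np.unique(pairs[mask_label_match, 2].astype(int))
mask_unmatched_predictions = ~np.isin(pairs[:, 2], matched_predictions)
kept = pairs[mask_label_match | mask_unmatched_predictions]
print(kept[:, [2, 4, 5]])  # prediction 0 keeps only its matching pair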
@@ -69,7 +97,7 @@ def _compute_ranked_pairs_for_datum(
     data = data[indices]

     # remove ignored predictions
-    for label_idx, count in enumerate(
+    for label_idx, count in enumerate(label_metadata[:, 0]):
         if count > 0:
             continue
         data = data[data[:, 5] != label_idx]
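The loop above drops pairs whose predicted label has no ground truths, i.e. labels whose count in label_metadata[:, 0] is zero. A vectorized sketch of the same filtering, using hypothetical values:

import numpy as np

# Hypothetical label_metadata; column 0 (ground-truth count per label index)
# is the only column this step reads.
label_metadata = np.array([[3, 5, 0], [0, 2, 0], [1, 0, 1]])
data = np.array(
    [
        [0, 0, 0, 0.9, 0, 0, 0.8],
        [0, 1, 1, 0.7, 2, 1, 0.6],  # predicted label 1 has no ground truths
        [0, 2, 2, 0.5, 2, 2, 0.4],
    ]
)

# Labels with zero ground truths are "ignored": drop pairs predicted as them.
ignored_labels = np.flatnonzero(label_metadata[:, 0] == 0)
data = data[~np.isin(data[:, 5], ignored_labels)]
print(data[:, 5])  # the pair predicted as label 1 is removed -> [0. 2.]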
@@ -85,13 +113,40 @@ def _compute_ranked_pairs_for_datum(

 def compute_ranked_pairs(
     data: list[NDArray[np.floating]],
-
+    label_metadata: NDArray[np.integer],
 ) -> NDArray[np.floating]:
+    """
+    Performs pair ranking on input data.
+
+    Takes data with shape (N, 7):
+
+    Index 0 - Datum Index
+    Index 1 - GroundTruth Index
+    Index 2 - Prediction Index
+    Index 3 - IoU
+    Index 4 - GroundTruth Label Index
+    Index 5 - Prediction Label Index
+    Index 6 - Score
+
+    Returns data with shape (N - M, 7)
+
+    Parameters
+    ----------
+    data : NDArray[np.floating]
+        A sorted array of classification pairs.
+    label_metadata : NDArray[np.int32]
+        An array containing metadata related to labels.
+
+    Returns
+    -------
+    NDArray[np.floating]
+        A filtered array containing only ranked pairs.
+    """
     pairs = np.concatenate(
         [
             _compute_ranked_pairs_for_datum(
                 datum,
-
+                label_metadata=label_metadata,
             )
             for datum in data
         ],
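The (N, 7) pair layout documented here recurs throughout this module. Purely as an editorial convenience, not package API, named column indices make masks over these rows easier to read:

import numpy as np

# Named column indices for the documented (N, 7) pair layout.
DATUM, GT, PD, IOU, GT_LABEL, PD_LABEL, SCORE = range(7)

pairs = np.array([[0, 0, 0, 0.82, 1, 1, 0.91]])  # hypothetical single pair
mask = (pairs[:, GT_LABEL] == pairs[:, PD_LABEL]) & (pairs[:, IOU] >= 0.5)
print(pairs[mask][:, [DATUM, SCORE]])  # [[0.   0.91]]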
@@ -108,7 +163,7 @@ def compute_ranked_pairs(

 def compute_metrics(
     data: np.ndarray,
-
+    label_metadata: np.ndarray,
     iou_thresholds: np.ndarray,
     score_thresholds: np.ndarray,
 ) -> tuple[
@@ -130,6 +185,27 @@ def compute_metrics(
     """
     Computes Object Detection metrics.

+    Takes data with shape (N, 7):
+
+    Index 0 - Datum Index
+    Index 1 - GroundTruth Index
+    Index 2 - Prediction Index
+    Index 3 - IoU
+    Index 4 - GroundTruth Label Index
+    Index 5 - Prediction Label Index
+    Index 6 - Score
+
+    Parameters
+    ----------
+    data : NDArray[np.floating]
+        A sorted array of classification pairs.
+    label_metadata : NDArray[np.int32]
+        An array containing metadata related to labels.
+    iou_thresholds : NDArray[np.floating]
+        A 1-D array containing IoU thresholds.
+    score_thresholds : NDArray[np.floating]
+        A 1-D array containing score thresholds.
+
     Returns
     -------
     tuple[NDArray, NDArray, NDArray NDArray]
@@ -143,17 +219,17 @@ def compute_metrics(
     """

     n_rows = data.shape[0]
-    n_labels =
+    n_labels = label_metadata.shape[0]
     n_ious = iou_thresholds.shape[0]
     n_scores = score_thresholds.shape[0]

     average_precision = np.zeros((n_ious, n_labels))
     average_recall = np.zeros((n_scores, n_labels))
-
+    counts = np.zeros((n_ious, n_scores, n_labels, 7))

     pd_labels = data[:, 5].astype(int)
     unique_pd_labels = np.unique(pd_labels)
-    gt_count =
+    gt_count = label_metadata[:, 0]
     running_total_count = np.zeros(
         (n_ious, n_rows),
         dtype=np.float64,
@@ -239,7 +315,7 @@ def compute_metrics(
                 out=accuracy,
             )

-
+        counts[iou_idx][score_idx] = np.concatenate(
             (
                 tp_count[:, np.newaxis],
                 fp_count[:, np.newaxis],
@@ -313,8 +389,8 @@ def compute_metrics(
     average_recall /= n_ious

     # calculate mAP and mAR
-    label_key_mapping =
-    label_keys = np.unique(
+    label_key_mapping = label_metadata[unique_pd_labels, 2]
+    label_keys = np.unique(label_metadata[:, 2])
     mAP = np.ones((n_ious, label_keys.shape[0])) * -1.0
     mAR = np.ones((n_scores, label_keys.shape[0])) * -1.0
     for key in np.unique(label_key_mapping):
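In the compute_metrics hunks shown here, label_metadata is read for its length (n_labels), its column 0 (per-label ground-truth counts) and its column 2 (a label-key index used to group labels when averaging into mAP and mAR). A hypothetical array consistent with that usage:

import numpy as np

# Hypothetical label_metadata: column 0 = ground-truth count per label,
# column 2 = label-key index. Column 1 is not read in these hunks.
label_metadata = np.array(
    [
        [10, 12, 0],  # label 0, key 0
        [4, 3, 0],    # label 1, key 0
        [7, 9, 1],    # label 2, key 1
    ]
)

gt_count = label_metadata[:, 0]             # ground truths per label
label_keys = np.unique(label_metadata[:, 2])
for key in label_keys:
    labels_in_key = np.flatnonzero(label_metadata[:, 2] == key)
    print(key, labels_in_key)  # key 0 -> labels [0 1], key 1 -> labels [2]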
@@ -347,30 +423,64 @@ def compute_metrics(
     return (
         ap_results,
         ar_results,
-
+        counts,
         pr_curve,
     )


-def
+def compute_detailed_counts(
     data: np.ndarray,
-
+    label_metadata: np.ndarray,
     iou_thresholds: np.ndarray,
     score_thresholds: np.ndarray,
     n_samples: int,
 ) -> np.ndarray:
-
     """
-
-
-
-
-
-
-
+    Compute detailed counts.
+
+    Takes data with shape (N, 7):
+
+    Index 0 - Datum Index
+    Index 1 - GroundTruth Index
+    Index 2 - Prediction Index
+    Index 3 - IoU
+    Index 4 - GroundTruth Label Index
+    Index 5 - Prediction Label Index
+    Index 6 - Score
+
+    Outputs an array with shape (N_IoUs, N_Score, N_Labels, 5 * n_samples + 5):
+
+    Index 0 - True Positive Count
+    ... Datum ID Examples
+    Index n_samples + 1 - False Positive Misclassification Count
+    ... Datum ID Examples
+    Index 2 * n_samples + 2 - False Positive Hallucination Count
+    ... Datum ID Examples
+    Index 3 * n_samples + 3 - False Negative Misclassification Count
+    ... Datum ID Examples
+    Index 4 * n_samples + 4 - False Negative Missing Prediction Count
+    ... Datum ID Examples
+
+    Parameters
+    ----------
+    data : NDArray[np.floating]
+        A sorted array of classification pairs.
+    label_metadata : NDArray[np.int32]
+        An array containing metadata related to labels.
+    iou_thresholds : NDArray[np.floating]
+        A 1-D array containing IoU thresholds.
+    score_thresholds : NDArray[np.floating]
+        A 1-D array containing score thresholds.
+    n_samples : int
+        The number of examples to return per count.
+
+    Returns
+    -------
+    NDArray[np.floating]
+        The detailed counts with optional examples.
     """

-    n_labels =
+    n_labels = label_metadata.shape[0]
     n_ious = iou_thresholds.shape[0]
     n_scores = score_thresholds.shape[0]
     n_metrics = 5 * (n_samples + 1)
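The documented output layout of compute_detailed_counts packs five counts into the last axis, each count immediately followed by its n_samples datum-ID example slots, for a total length of 5 * n_samples + 5. A hypothetical indexing helper that mirrors those documented offsets:

# Hypothetical helper reflecting the documented layout: five counts, each
# followed by n_samples datum-ID example slots.
def detailed_count_slots(n_samples: int) -> dict[str, slice]:
    width = n_samples + 1  # one count plus its examples
    names = (
        "tp",
        "fp_misclassification",
        "fp_hallucination",
        "fn_misclassification",
        "fn_missing_prediction",
    )
    return {
        name: slice(i * width, (i + 1) * width) for i, name in enumerate(names)
    }

slots = detailed_count_slots(n_samples=2)
print(slots["fp_hallucination"])  # slice(6, 9, None): count at index 6, examples at 7-8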
@@ -386,46 +496,115 @@ def compute_detailed_pr_curve(
     mask_gt_exists = data[:, 1] > -0.5
     mask_pd_exists = data[:, 2] > -0.5
     mask_label_match = np.isclose(data[:, 4], data[:, 5])
+    mask_score_nonzero = data[:, 6] > 1e-9
+    mask_iou_nonzero = data[:, 3] > 1e-9

     mask_gt_pd_exists = mask_gt_exists & mask_pd_exists
     mask_gt_pd_match = mask_gt_pd_exists & mask_label_match
     mask_gt_pd_mismatch = mask_gt_pd_exists & ~mask_label_match

+    groundtruths = data[:, [0, 1]].astype(int)
+    predictions = data[:, [0, 2]].astype(int)
     for iou_idx in range(n_ious):
-
-
-
+        mask_iou_threshold = data[:, 3] >= iou_thresholds[iou_idx]
+        mask_iou = mask_iou_nonzero & mask_iou_threshold
+
+        groundtruths_with_pairs = np.unique(groundtruths[mask_iou], axis=0)
+        mask_groundtruths_with_passing_ious = (
+            groundtruths.reshape(-1, 1, 2)
+            == groundtruths_with_pairs.reshape(1, -1, 2)
+        ).all(axis=2)
+        mask_groundtruths_with_passing_ious = (
+            mask_groundtruths_with_passing_ious.any(axis=1)
+        )
+        mask_groundtruths_without_passing_ious = (
+            ~mask_groundtruths_with_passing_ious & mask_gt_exists
+        )
+
+        predictions_with_passing_ious = np.unique(
+            predictions[mask_iou], axis=0
+        )
+        mask_predictions_with_passing_ious = (
+            predictions.reshape(-1, 1, 2)
+            == predictions_with_passing_ious.reshape(1, -1, 2)
+        ).all(axis=2)
+        mask_predictions_with_passing_ious = (
+            mask_predictions_with_passing_ious.any(axis=1)
+        )
+        mask_predictions_without_passing_ious = (
+            ~mask_predictions_with_passing_ious & mask_pd_exists
+        )
+
         for score_idx in range(n_scores):
-
-
-
-
-
-
+            mask_score_threshold = data[:, 6] >= score_thresholds[score_idx]
+            mask_score = mask_score_nonzero & mask_score_threshold
+
+            groundtruths_with_passing_score = np.unique(
+                groundtruths[mask_iou & mask_score], axis=0
+            )
+            mask_groundtruths_with_passing_score = (
+                groundtruths.reshape(-1, 1, 2)
+                == groundtruths_with_passing_score.reshape(1, -1, 2)
+            ).all(axis=2)
+            mask_groundtruths_with_passing_score = (
+                mask_groundtruths_with_passing_score.any(axis=1)
+            )
+            mask_groundtruths_without_passing_score = (
+                ~mask_groundtruths_with_passing_score & mask_gt_exists
             )
-            mask_fp_halluc = mask_halluc_missing & mask_pd_exists
-            mask_fn_misprd = mask_halluc_missing & mask_gt_exists

-
-
-
-
-
+            mask_tp = mask_score & mask_iou & mask_gt_pd_match
+            mask_fp_misclf = mask_score & mask_iou & mask_gt_pd_mismatch
+            mask_fn_misclf = mask_iou & (
+                (
+                    ~mask_score
+                    & mask_gt_pd_match
+                    & mask_groundtruths_with_passing_score
+                )
+                | (mask_score & mask_gt_pd_mismatch)
+            )
+            mask_fp_halluc = mask_score & mask_predictions_without_passing_ious
+            mask_fn_misprd = (
+                mask_groundtruths_without_passing_ious
+                | mask_groundtruths_without_passing_score
+            )

-
-
+            tp_pds = np.unique(data[mask_tp][:, [0, 2, 5]], axis=0)
+            tp_gts = np.unique(data[mask_tp][:, [0, 1, 4]], axis=0)
+            fp_misclf = np.unique(data[mask_fp_misclf][:, [0, 2, 5]], axis=0)
+            fp_halluc = np.unique(data[mask_fp_halluc][:, [0, 2, 5]], axis=0)
+            fn_misclf = np.unique(data[mask_fn_misclf][:, [0, 1, 4]], axis=0)
+            fn_misprd = np.unique(data[mask_fn_misprd][:, [0, 1, 4]], axis=0)
+
+            mask_fp_misclf_is_tp = (
+                (fp_misclf.reshape(-1, 1, 3) == tp_pds.reshape(1, -1, 3))
+                .all(axis=2)
+                .any(axis=1)
+            )
+            mask_fn_misclf_is_tp = (
+                (fn_misclf.reshape(-1, 1, 3) == tp_gts.reshape(1, -1, 3))
+                .all(axis=2)
+                .any(axis=1)
             )
+
+            tp = tp_pds
+            fp_misclf = fp_misclf[~mask_fp_misclf_is_tp]
+            fp_halluc = fp_halluc
+            fn_misclf = fn_misclf[~mask_fn_misclf_is_tp]
+            fn_misprd = fn_misprd
+
+            tp_count = np.bincount(tp[:, 2].astype(int), minlength=n_labels)
             fp_misclf_count = np.bincount(
-
+                fp_misclf[:, 2].astype(int), minlength=n_labels
             )
             fp_halluc_count = np.bincount(
-
+                fp_halluc[:, 2].astype(int), minlength=n_labels
             )
             fn_misclf_count = np.bincount(
-
+                fn_misclf[:, 2].astype(int), minlength=n_labels
             )
             fn_misprd_count = np.bincount(
-
+                fn_misprd[:, 2].astype(int), minlength=n_labels
             )

             detailed_pr_curve[iou_idx, score_idx, :, tp_idx] = tp_count
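The reshape / == / all / any pattern used repeatedly in this hunk is a row-membership test: it marks which rows of one integer array also appear as rows of another. A toy demonstration with illustrative values:

import numpy as np

A = np.array([[0, 1], [0, 2], [1, 3]])  # e.g. (datum, annotation) pairs to test
B = np.array([[0, 2], [1, 3]])          # rows known to pass some filter

# Compare every row of A against every row of B, then reduce.
mask_in_B = (A.reshape(-1, 1, 2) == B.reshape(1, -1, 2)).all(axis=2).any(axis=1)
print(mask_in_B)  # [False  True  True]

The comparison materializes a len(A) by len(B) boolean intermediate, so it trades memory for avoiding an explicit Python loop.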
@@ -444,20 +623,20 @@ def compute_detailed_pr_curve(

             if n_samples > 0:
                 for label_idx in range(n_labels):
-                    tp_examples =
-
-                    ]
-                    fp_misclf_examples =
-
+                    tp_examples = tp[tp[:, 2].astype(int) == label_idx][
+                        :n_samples, 0
+                    ]
+                    fp_misclf_examples = fp_misclf[
+                        fp_misclf[:, 2].astype(int) == label_idx
                     ][:n_samples, 0]
-                    fp_halluc_examples =
-
+                    fp_halluc_examples = fp_halluc[
+                        fp_halluc[:, 2].astype(int) == label_idx
                     ][:n_samples, 0]
-                    fn_misclf_examples =
-
+                    fn_misclf_examples = fn_misclf[
+                        fn_misclf[:, 2].astype(int) == label_idx
                     ][:n_samples, 0]
-                    fn_misprd_examples =
-
+                    fn_misprd_examples = fn_misprd[
+                        fn_misprd[:, 2].astype(int) == label_idx
                     ][:n_samples, 0]

                     detailed_pr_curve[
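The example-gathering step above selects, for each label, the first n_samples datum IDs from rows shaped (datum index, annotation index, label index). A small sketch with hypothetical values:

import numpy as np

# Hypothetical true-positive rows: (datum index, annotation index, label index).
tp = np.array([[0, 3, 1], [1, 7, 1], [2, 5, 0], [3, 2, 1]])
n_samples, label_idx = 2, 1
tp_examples = tp[tp[:, 2].astype(int) == label_idx][:n_samples, 0]
print(tp_examples)  # [0 1]: datum IDs of the first two label-1 true positives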