valor-lite 0.36.6__py3-none-any.whl → 0.37.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- valor_lite/cache/__init__.py +11 -0
- valor_lite/cache/compute.py +211 -0
- valor_lite/cache/ephemeral.py +302 -0
- valor_lite/cache/persistent.py +536 -0
- valor_lite/classification/__init__.py +5 -10
- valor_lite/classification/annotation.py +4 -0
- valor_lite/classification/computation.py +233 -251
- valor_lite/classification/evaluator.py +882 -0
- valor_lite/classification/loader.py +97 -0
- valor_lite/classification/metric.py +141 -4
- valor_lite/classification/shared.py +184 -0
- valor_lite/classification/utilities.py +221 -118
- valor_lite/exceptions.py +5 -0
- valor_lite/object_detection/__init__.py +5 -4
- valor_lite/object_detection/annotation.py +13 -1
- valor_lite/object_detection/computation.py +368 -299
- valor_lite/object_detection/evaluator.py +804 -0
- valor_lite/object_detection/loader.py +292 -0
- valor_lite/object_detection/metric.py +152 -3
- valor_lite/object_detection/shared.py +206 -0
- valor_lite/object_detection/utilities.py +182 -100
- valor_lite/semantic_segmentation/__init__.py +5 -4
- valor_lite/semantic_segmentation/annotation.py +7 -0
- valor_lite/semantic_segmentation/computation.py +20 -110
- valor_lite/semantic_segmentation/evaluator.py +414 -0
- valor_lite/semantic_segmentation/loader.py +205 -0
- valor_lite/semantic_segmentation/shared.py +149 -0
- valor_lite/semantic_segmentation/utilities.py +6 -23
- {valor_lite-0.36.6.dist-info → valor_lite-0.37.5.dist-info}/METADATA +3 -1
- valor_lite-0.37.5.dist-info/RECORD +49 -0
- {valor_lite-0.36.6.dist-info → valor_lite-0.37.5.dist-info}/WHEEL +1 -1
- valor_lite/classification/manager.py +0 -545
- valor_lite/object_detection/manager.py +0 -864
- valor_lite/profiling.py +0 -374
- valor_lite/semantic_segmentation/benchmark.py +0 -237
- valor_lite/semantic_segmentation/manager.py +0 -446
- valor_lite-0.36.6.dist-info/RECORD +0 -41
- {valor_lite-0.36.6.dist-info → valor_lite-0.37.5.dist-info}/top_level.txt +0 -0
valor_lite/classification/loader.py
@@ -0,0 +1,97 @@
+import numpy as np
+import pyarrow as pa
+from tqdm import tqdm
+
+from valor_lite.cache.ephemeral import MemoryCacheWriter
+from valor_lite.cache.persistent import FileCacheWriter
+from valor_lite.classification.annotation import Classification
+from valor_lite.classification.evaluator import Builder
+
+
+class Loader(Builder):
+    def __init__(
+        self,
+        writer: MemoryCacheWriter | FileCacheWriter,
+        roc_curve_writer: MemoryCacheWriter | FileCacheWriter,
+        intermediate_writer: MemoryCacheWriter | FileCacheWriter,
+        metadata_fields: list[tuple[str, str | pa.DataType]] | None = None,
+    ):
+        super().__init__(
+            writer=writer,
+            roc_curve_writer=roc_curve_writer,
+            intermediate_writer=intermediate_writer,
+            metadata_fields=metadata_fields,
+        )
+
+        # internal state
+        self._labels: dict[str, int] = {}
+        self._index_to_label: dict[int, str] = {}
+        self._datum_count = 0
+
+    def _add_label(self, value: str) -> int:
+        idx = self._labels.get(value, None)
+        if idx is None:
+            idx = len(self._labels)
+            self._labels[value] = idx
+            self._index_to_label[idx] = value
+        return idx
+
+    def add_data(
+        self,
+        classifications: list[Classification],
+        show_progress: bool = False,
+    ):
+        """
+        Adds classifications to the cache.
+
+        Parameters
+        ----------
+        classifications : list[Classification]
+            A list of Classification objects.
+        show_progress : bool, default=False
+            Toggle for tqdm progress bar.
+        """
+
+        disable_tqdm = not show_progress
+        for classification in tqdm(classifications, disable=disable_tqdm):
+            if len(classification.predictions) == 0:
+                raise ValueError(
+                    "Classifications must contain at least one prediction."
+                )
+
+            # prepare metadata
+            datum_metadata = (
+                classification.metadata if classification.metadata else {}
+            )
+
+            # write to cache
+            rows = []
+            gidx = self._add_label(classification.groundtruth)
+            max_score_idx = np.argmax(np.array(classification.scores))
+            for idx, (plabel, score) in enumerate(
+                zip(classification.predictions, classification.scores)
+            ):
+                pidx = self._add_label(plabel)
+                rows.append(
+                    {
+                        # metadata
+                        **datum_metadata,
+                        # datum
+                        "datum_uid": classification.uid,
+                        "datum_id": self._datum_count,
+                        # groundtruth
+                        "gt_label": classification.groundtruth,
+                        "gt_label_id": gidx,
+                        # prediction
+                        "pd_label": plabel,
+                        "pd_label_id": pidx,
+                        "pd_score": float(score),
+                        "pd_winner": max_score_idx == idx,
+                        # pair
+                        "match": (gidx == pidx) and pidx >= 0,
+                    }
+                )
+            self._writer.write_rows(rows)
+
+            # update datum count
+            self._datum_count += 1
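For orientation, the per-prediction rows written by Loader.add_data above can be traced with plain Python. The sketch below mirrors the loop for one hypothetical datum (the uid, labels, and scores are illustrative, not taken from the package); it reproduces only the row-building logic, not the cache writers.

import numpy as np

# Hypothetical datum: ground truth "cat", two predicted labels with scores.
uid, groundtruth = "datum-0", "cat"
predictions, scores = ["cat", "dog"], [0.8, 0.2]

labels: dict[str, int] = {}  # mirrors Loader._labels
datum_count = 0              # mirrors Loader._datum_count


def add_label(value: str) -> int:
    # assign the next integer id on first sight, as Loader._add_label does
    return labels.setdefault(value, len(labels))


rows = []
gidx = add_label(groundtruth)
max_score_idx = int(np.argmax(np.array(scores)))
for idx, (plabel, score) in enumerate(zip(predictions, scores)):
    pidx = add_label(plabel)
    rows.append(
        {
            "datum_uid": uid,
            "datum_id": datum_count,
            "gt_label": groundtruth,
            "gt_label_id": gidx,
            "pd_label": plabel,
            "pd_label_id": pidx,
            "pd_score": float(score),
            "pd_winner": idx == max_score_idx,
            "match": gidx == pidx and pidx >= 0,
        }
    )
# rows[0] is the matching "cat" pair with pd_winner=True;
# rows[1] is the non-matching "dog" pair with pd_winner=False.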
valor_lite/classification/metric.py
@@ -13,6 +13,8 @@ class MetricType(Enum):
     Accuracy = "Accuracy"
     F1 = "F1"
     ConfusionMatrix = "ConfusionMatrix"
+    Examples = "Examples"
+    ConfusionMatrixWithExamples = "ConfusionMatrixWithExamples"
 
 
 @dataclass
@@ -304,6 +306,139 @@ class Metric(BaseMetric):
 
     @classmethod
     def confusion_matrix(
+        cls,
+        confusion_matrix: dict[str, dict[str, int]],
+        unmatched_ground_truths: dict[str, int],
+        score_threshold: float,
+        hardmax: bool,
+    ):
+        """
+        Confusion matrix for object detection task.
+
+        This class encapsulates detailed information about the model's performance, including correct
+        predictions, misclassifications and unmatched ground truths (subset of false negatives).
+
+        Confusion Matrix Format:
+        {
+            <ground truth label>: {
+                <prediction label>: 129
+                ...
+            },
+            ...
+        }
+
+        Unmatched Ground Truths Format:
+        {
+            <ground truth label>: 7
+            ...
+        }
+
+        Parameters
+        ----------
+        confusion_matrix : dict
+            A nested dictionary containing integer counts of occurences where the first key is the ground truth label value
+            and the second key is the prediction label value.
+        unmatched_ground_truths : dict
+            A dictionary where each key is a ground truth label value for which the model failed to predict
+            (subset of false negatives). The value is a dictionary containing counts.
+        score_threshold : float
+            The confidence score threshold used to filter predictions.
+        hardmax : bool
+            Indicates whether hardmax thresholding was used.
+
+        Returns
+        -------
+        Metric
+        """
+        return cls(
+            type=MetricType.ConfusionMatrix.value,
+            value={
+                "confusion_matrix": confusion_matrix,
+                "unmatched_ground_truths": unmatched_ground_truths,
+            },
+            parameters={
+                "score_threshold": score_threshold,
+                "hardmax": hardmax,
+            },
+        )
+
+    @classmethod
+    def examples(
+        cls,
+        datum_id: str,
+        true_positives: list[str],
+        false_positives: list[str],
+        false_negatives: list[str],
+        score_threshold: float,
+        hardmax: bool,
+    ):
+        """
+        Per-datum examples for object detection tasks.
+
+        This metric is per-datum and contains lists of annotation identifiers that categorize them
+        as true-positive, false-positive or false-negative. This is intended to be used with an
+        external database where the identifiers can be used for retrieval.
+
+        Examples Format:
+        {
+            "type": "Examples",
+            "value": {
+                "datum_id": "some string ID",
+                "true_positives": [
+                    "label A",
+                ],
+                "false_positives": [
+                    "label 25",
+                    "label 92",
+                    ...
+                ]
+                "false_negatives": [
+                    "groundtruth32",
+                    "groundtruth24",
+                    ...
+                ]
+            },
+            "parameters": {
+                "score_threshold": 0.5,
+                "hardmax": False,
+            }
+        }
+
+        Parameters
+        ----------
+        datum_id : str
+            A string identifier representing a datum.
+        true_positives : list[str]
+            A list of string identifier representing true positive labels.
+        false_positives : list[str]
+            A list of string identifiers representing false positive predictions.
+        false_negatives : list[str]
+            A list of string identifiers representing false negative ground truths.
+        score_threshold : float
+            The confidence score threshold used to filter predictions.
+        hardmax : bool
+            Indicates whether hardmax thresholding was used.
+
+        Returns
+        -------
+        Metric
+        """
+        return cls(
+            type=MetricType.Examples.value,
+            value={
+                "datum_id": datum_id,
+                "true_positives": true_positives,
+                "false_positives": false_positives,
+                "false_negatives": false_negatives,
+            },
+            parameters={
+                "score_threshold": score_threshold,
+                "hardmax": hardmax,
+            },
+        )
+
+    @classmethod
+    def confusion_matrix_with_examples(
         cls,
         confusion_matrix: dict[
             str, # ground truth label value
@@ -329,9 +464,10 @@ class Metric(BaseMetric):
             ],
         ],
         score_threshold: float,
+        hardmax: bool,
     ):
         """
-        The confusion matrix
+        The confusion matrix with examples for the classification task.
 
         This class encapsulates detailed information about the model's performance, including correct
         predictions, misclassifications and unmatched ground truths (subset of false negatives).
@@ -379,20 +515,21 @@ class Metric(BaseMetric):
             A dictionary where each key is a ground truth label value for which the model failed to predict
             (false negatives). The value is a dictionary containing either a `count` or a list of `examples`.
             Each example includes the datum UID.
-
-
+        hardmax : bool
+            Indicates whether hardmax thresholding was used.
 
         Returns
         -------
         Metric
         """
         return cls(
-            type=MetricType.
+            type=MetricType.ConfusionMatrixWithExamples.value,
             value={
                 "confusion_matrix": confusion_matrix,
                 "unmatched_ground_truths": unmatched_ground_truths,
             },
             parameters={
                 "score_threshold": score_threshold,
+                "hardmax": hardmax,
             },
         )
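The new classmethods above only package their arguments into the metric's value and parameters payloads. A minimal usage sketch for the Examples metric, assuming Metric is importable from valor_lite.classification.metric (the import path is inferred from the file listing, not confirmed by the diff):

from valor_lite.classification.metric import Metric  # assumed import path

example = Metric.examples(
    datum_id="datum-0",
    true_positives=["cat"],
    false_positives=["dog"],
    false_negatives=[],
    score_threshold=0.5,
    hardmax=True,
)
# Per the classmethod body, this constructs a Metric with
#   type="Examples",
#   value={"datum_id": "datum-0", "true_positives": ["cat"],
#          "false_positives": ["dog"], "false_negatives": []},
#   parameters={"score_threshold": 0.5, "hardmax": True}.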
valor_lite/classification/shared.py
@@ -0,0 +1,184 @@
+from dataclasses import dataclass
+from pathlib import Path
+
+import numpy as np
+import pyarrow as pa
+import pyarrow.compute as pc
+from numpy.typing import NDArray
+
+from valor_lite.cache import FileCacheReader, MemoryCacheReader
+
+
+@dataclass
+class EvaluatorInfo:
+    number_of_rows: int = 0
+    number_of_datums: int = 0
+    number_of_labels: int = 0
+    metadata_fields: list[tuple[str, str]] | None = None
+
+
+def generate_cache_path(path: str | Path) -> Path:
+    return Path(path) / "cache"
+
+
+def generate_intermediate_cache_path(path: str | Path) -> Path:
+    return Path(path) / "intermediate"
+
+
+def generate_roc_curve_cache_path(path: str | Path) -> Path:
+    return Path(path) / "roc_curve"
+
+
+def generate_metadata_path(path: str | Path) -> Path:
+    return Path(path) / "metadata.json"
+
+
+def generate_schema(
+    metadata_fields: list[tuple[str, str | pa.DataType]] | None
+) -> pa.Schema:
+    metadata_fields = metadata_fields if metadata_fields else []
+    reserved_fields = [
+        ("datum_uid", pa.string()),
+        ("datum_id", pa.int64()),
+        # groundtruth
+        ("gt_label", pa.string()),
+        ("gt_label_id", pa.int64()),
+        # prediction
+        ("pd_label", pa.string()),
+        ("pd_label_id", pa.int64()),
+        ("pd_score", pa.float64()),
+        ("pd_winner", pa.bool_()),
+        # pair
+        ("match", pa.bool_()),
+    ]
+
+    # validate
+    reserved_field_names = {f[0] for f in reserved_fields}
+    metadata_field_names = {f[0] for f in metadata_fields}
+    if conflicting := reserved_field_names & metadata_field_names:
+        raise ValueError(
+            f"metadata fields {conflicting} conflict with reserved fields"
+        )
+    return pa.schema(reserved_fields + metadata_fields)
+
+
+def generate_intermediate_schema() -> pa.Schema:
+    return pa.schema(
+        [
+            ("pd_label_id", pa.int64()),
+            ("pd_score", pa.float64()),
+            ("match", pa.bool_()),
+        ]
+    )
+
+
+def generate_roc_curve_schema() -> pa.Schema:
+    return pa.schema(
+        [
+            ("pd_label_id", pa.int64()),
+            ("cumulative_fp", pa.uint64()),
+            ("cumulative_tp", pa.uint64()),
+        ]
+    )
+
+
+def encode_metadata_fields(
+    metadata_fields: list[tuple[str, str | pa.DataType]] | None
+) -> dict[str, str]:
+    metadata_fields = metadata_fields if metadata_fields else []
+    return {k: str(v) for k, v in metadata_fields}
+
+
+def decode_metadata_fields(
+    encoded_metadata_fields: dict[str, str]
+) -> list[tuple[str, str]]:
+    return [(k, v) for k, v in encoded_metadata_fields.items()]
+
+
+def extract_labels(
+    reader: MemoryCacheReader | FileCacheReader,
+    index_to_label_override: dict[int, str] | None = None,
+) -> dict[int, str]:
+    if index_to_label_override is not None:
+        return index_to_label_override
+
+    index_to_label = {}
+    for tbl in reader.iterate_tables(
+        columns=[
+            "gt_label_id",
+            "gt_label",
+            "pd_label_id",
+            "pd_label",
+        ]
+    ):
+
+        # get gt labels
+        gt_label_ids = tbl["gt_label_id"].to_numpy()
+        gt_label_ids, gt_indices = np.unique(gt_label_ids, return_index=True)
+        gt_labels = tbl["gt_label"].take(gt_indices).to_pylist()
+        gt_labels = dict(zip(gt_label_ids.astype(int).tolist(), gt_labels))
+        gt_labels.pop(-1, None)
+        index_to_label.update(gt_labels)
+
+        # get pd labels
+        pd_label_ids = tbl["pd_label_id"].to_numpy()
+        pd_label_ids, pd_indices = np.unique(pd_label_ids, return_index=True)
+        pd_labels = tbl["pd_label"].take(pd_indices).to_pylist()
+        pd_labels = dict(zip(pd_label_ids.astype(int).tolist(), pd_labels))
+        pd_labels.pop(-1, None)
+        index_to_label.update(pd_labels)
+
+    return index_to_label
+
+
+def extract_counts(
+    reader: MemoryCacheReader | FileCacheReader,
+    datums: pc.Expression | None = None,
+):
+    n_dts = 0
+    for tbl in reader.iterate_tables(filter=datums):
+        n_dts += int(np.unique(tbl["datum_id"].to_numpy()).shape[0])
+    return n_dts
+
+
+def extract_groundtruth_count_per_label(
+    reader: MemoryCacheReader | FileCacheReader,
+    number_of_labels: int,
+    datums: pc.Expression | None = None,
+    groundtruths: pc.Expression | None = None,
+    predictions: pc.Expression | None = None,
+) -> NDArray[np.uint64]:
+
+    # count ground truth and prediction label occurences
+    label_counts = np.zeros((number_of_labels, 2), dtype=np.uint64)
+    for tbl in reader.iterate_tables(filter=datums):
+
+        # count unique gt labels
+        gt_expr = pc.field("gt_label_id") >= 0
+        if groundtruths is not None:
+            gt_expr &= groundtruths
+        gt_tbl = tbl.filter(gt_expr)
+        gt_ids = np.column_stack(
+            [gt_tbl[col].to_numpy() for col in ["datum_id", "gt_label_id"]]
+        ).astype(np.int64)
+        unique_gts = np.unique(gt_ids, axis=0)
+        unique_gt_labels, gt_label_counts = np.unique(
+            unique_gts[:, 1], return_counts=True
+        )
+        label_counts[unique_gt_labels, 0] += gt_label_counts.astype(np.uint64)
+
+        # count unique pd labels
+        pd_expr = pc.field("pd_label_id") >= 0
+        if predictions is not None:
+            pd_expr &= predictions
+        pd_tbl = tbl.filter(pd_expr)
+        pd_ids = np.column_stack(
+            [pd_tbl[col].to_numpy() for col in ["datum_id", "pd_label_id"]]
+        ).astype(np.int64)
+        unique_pds = np.unique(pd_ids, axis=0)
+        unique_pd_labels, pd_label_counts = np.unique(
+            unique_pds[:, 1], return_counts=True
+        )
+        label_counts[unique_pd_labels, 1] += pd_label_counts.astype(np.uint64)
+
+    return label_counts
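The schema helpers above prepend a fixed set of reserved pair-table columns to any user-supplied metadata fields and reject name collisions. A minimal sketch of that behavior, assuming the module is importable as valor_lite.classification.shared (path taken from the file listing, not confirmed by the diff):

import pyarrow as pa

from valor_lite.classification.shared import generate_schema  # assumed import path

# One extra metadata column alongside the reserved pair-table columns.
schema = generate_schema([("source", pa.string())])
print(schema.names)
# ['datum_uid', 'datum_id', 'gt_label', 'gt_label_id',
#  'pd_label', 'pd_label_id', 'pd_score', 'pd_winner', 'match', 'source']

# Reusing a reserved column name raises the ValueError defined above.
try:
    generate_schema([("datum_uid", pa.string())])
except ValueError as err:
    print(err)  # metadata fields {'datum_uid'} conflict with reserved fields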