valor_lite-0.37.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of valor-lite might be problematic.
- valor_lite/LICENSE +21 -0
- valor_lite/__init__.py +0 -0
- valor_lite/cache/__init__.py +11 -0
- valor_lite/cache/compute.py +154 -0
- valor_lite/cache/ephemeral.py +302 -0
- valor_lite/cache/persistent.py +529 -0
- valor_lite/classification/__init__.py +14 -0
- valor_lite/classification/annotation.py +45 -0
- valor_lite/classification/computation.py +378 -0
- valor_lite/classification/evaluator.py +879 -0
- valor_lite/classification/loader.py +97 -0
- valor_lite/classification/metric.py +535 -0
- valor_lite/classification/numpy_compatibility.py +13 -0
- valor_lite/classification/shared.py +184 -0
- valor_lite/classification/utilities.py +314 -0
- valor_lite/exceptions.py +20 -0
- valor_lite/object_detection/__init__.py +17 -0
- valor_lite/object_detection/annotation.py +238 -0
- valor_lite/object_detection/computation.py +841 -0
- valor_lite/object_detection/evaluator.py +805 -0
- valor_lite/object_detection/loader.py +292 -0
- valor_lite/object_detection/metric.py +850 -0
- valor_lite/object_detection/shared.py +185 -0
- valor_lite/object_detection/utilities.py +396 -0
- valor_lite/schemas.py +11 -0
- valor_lite/semantic_segmentation/__init__.py +15 -0
- valor_lite/semantic_segmentation/annotation.py +123 -0
- valor_lite/semantic_segmentation/computation.py +165 -0
- valor_lite/semantic_segmentation/evaluator.py +414 -0
- valor_lite/semantic_segmentation/loader.py +205 -0
- valor_lite/semantic_segmentation/metric.py +275 -0
- valor_lite/semantic_segmentation/shared.py +149 -0
- valor_lite/semantic_segmentation/utilities.py +88 -0
- valor_lite/text_generation/__init__.py +15 -0
- valor_lite/text_generation/annotation.py +56 -0
- valor_lite/text_generation/computation.py +611 -0
- valor_lite/text_generation/llm/__init__.py +0 -0
- valor_lite/text_generation/llm/exceptions.py +14 -0
- valor_lite/text_generation/llm/generation.py +903 -0
- valor_lite/text_generation/llm/instructions.py +814 -0
- valor_lite/text_generation/llm/integrations.py +226 -0
- valor_lite/text_generation/llm/utilities.py +43 -0
- valor_lite/text_generation/llm/validators.py +68 -0
- valor_lite/text_generation/manager.py +697 -0
- valor_lite/text_generation/metric.py +381 -0
- valor_lite-0.37.1.dist-info/METADATA +174 -0
- valor_lite-0.37.1.dist-info/RECORD +49 -0
- valor_lite-0.37.1.dist-info/WHEEL +5 -0
- valor_lite-0.37.1.dist-info/top_level.txt +1 -0
valor_lite/classification/shared.py
ADDED

@@ -0,0 +1,184 @@
from dataclasses import dataclass
from pathlib import Path

import numpy as np
import pyarrow as pa
import pyarrow.compute as pc
from numpy.typing import NDArray

from valor_lite.cache import FileCacheReader, MemoryCacheReader


@dataclass
class EvaluatorInfo:
    number_of_rows: int = 0
    number_of_datums: int = 0
    number_of_labels: int = 0
    metadata_fields: list[tuple[str, str]] | None = None


def generate_cache_path(path: str | Path) -> Path:
    return Path(path) / "cache"


def generate_intermediate_cache_path(path: str | Path) -> Path:
    return Path(path) / "intermediate"


def generate_roc_curve_cache_path(path: str | Path) -> Path:
    return Path(path) / "roc_curve"


def generate_metadata_path(path: str | Path) -> Path:
    return Path(path) / "metadata.json"


def generate_schema(
    metadata_fields: list[tuple[str, str | pa.DataType]] | None
) -> pa.Schema:
    metadata_fields = metadata_fields if metadata_fields else []
    reserved_fields = [
        ("datum_uid", pa.string()),
        ("datum_id", pa.int64()),
        # groundtruth
        ("gt_label", pa.string()),
        ("gt_label_id", pa.int64()),
        # prediction
        ("pd_label", pa.string()),
        ("pd_label_id", pa.int64()),
        ("pd_score", pa.float64()),
        ("pd_winner", pa.bool_()),
        # pair
        ("match", pa.bool_()),
    ]

    # validate
    reserved_field_names = {f[0] for f in reserved_fields}
    metadata_field_names = {f[0] for f in metadata_fields}
    if conflicting := reserved_field_names & metadata_field_names:
        raise ValueError(
            f"metadata fields {conflicting} conflict with reserved fields"
        )
    return pa.schema(reserved_fields + metadata_fields)


def generate_intermediate_schema() -> pa.Schema:
    return pa.schema(
        [
            ("pd_label_id", pa.int64()),
            ("pd_score", pa.float64()),
            ("match", pa.bool_()),
        ]
    )


def generate_roc_curve_schema() -> pa.Schema:
    return pa.schema(
        [
            ("pd_label_id", pa.int64()),
            ("cumulative_fp", pa.uint64()),
            ("cumulative_tp", pa.uint64()),
        ]
    )


def encode_metadata_fields(
    metadata_fields: list[tuple[str, str | pa.DataType]] | None
) -> dict[str, str]:
    metadata_fields = metadata_fields if metadata_fields else []
    return {k: str(v) for k, v in metadata_fields}


def decode_metadata_fields(
    encoded_metadata_fields: dict[str, str]
) -> list[tuple[str, str]]:
    return [(k, v) for k, v in encoded_metadata_fields.items()]


def extract_labels(
    reader: MemoryCacheReader | FileCacheReader,
    index_to_label_override: dict[int, str] | None = None,
) -> dict[int, str]:
    if index_to_label_override is not None:
        return index_to_label_override

    index_to_label = {}
    for tbl in reader.iterate_tables(
        columns=[
            "gt_label_id",
            "gt_label",
            "pd_label_id",
            "pd_label",
        ]
    ):

        # get gt labels
        gt_label_ids = tbl["gt_label_id"].to_numpy()
        gt_label_ids, gt_indices = np.unique(gt_label_ids, return_index=True)
        gt_labels = tbl["gt_label"].take(gt_indices).to_pylist()
        gt_labels = dict(zip(gt_label_ids.astype(int).tolist(), gt_labels))
        gt_labels.pop(-1, None)
        index_to_label.update(gt_labels)

        # get pd labels
        pd_label_ids = tbl["pd_label_id"].to_numpy()
        pd_label_ids, pd_indices = np.unique(pd_label_ids, return_index=True)
        pd_labels = tbl["pd_label"].take(pd_indices).to_pylist()
        pd_labels = dict(zip(pd_label_ids.astype(int).tolist(), pd_labels))
        pd_labels.pop(-1, None)
        index_to_label.update(pd_labels)

    return index_to_label


def extract_counts(
    reader: MemoryCacheReader | FileCacheReader,
    datums: pc.Expression | None = None,
):
    n_dts = 0
    for tbl in reader.iterate_tables(filter=datums):
        n_dts += int(np.unique(tbl["datum_id"].to_numpy()).shape[0])
    return n_dts


def extract_groundtruth_count_per_label(
    reader: MemoryCacheReader | FileCacheReader,
    number_of_labels: int,
    datums: pc.Expression | None = None,
    groundtruths: pc.Expression | None = None,
    predictions: pc.Expression | None = None,
) -> NDArray[np.uint64]:

    # count ground truth and prediction label occurrences
    label_counts = np.zeros((number_of_labels, 2), dtype=np.uint64)
    for tbl in reader.iterate_tables(filter=datums):

        # count unique gt labels
        gt_expr = pc.field("gt_label_id") >= 0
        if groundtruths is not None:
            gt_expr &= groundtruths
        gt_tbl = tbl.filter(gt_expr)
        gt_ids = np.column_stack(
            [gt_tbl[col].to_numpy() for col in ["datum_id", "gt_label_id"]]
        ).astype(np.int64)
        unique_gts = np.unique(gt_ids, axis=0)
        unique_gt_labels, gt_label_counts = np.unique(
            unique_gts[:, 1], return_counts=True
        )
        label_counts[unique_gt_labels, 0] += gt_label_counts.astype(np.uint64)

        # count unique pd labels
        pd_expr = pc.field("pd_label_id") >= 0
        if predictions is not None:
            pd_expr &= predictions
        pd_tbl = tbl.filter(pd_expr)
        pd_ids = np.column_stack(
            [pd_tbl[col].to_numpy() for col in ["datum_id", "pd_label_id"]]
        ).astype(np.int64)
        unique_pds = np.unique(pd_ids, axis=0)
        unique_pd_labels, pd_label_counts = np.unique(
            unique_pds[:, 1], return_counts=True
        )
        label_counts[unique_pd_labels, 1] += pd_label_counts.astype(np.uint64)

    return label_counts
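Taken together, the helpers above define the classification cache layout: a fixed set of reserved columns plus optional user metadata columns whose pyarrow types are serialized to strings for metadata.json. A minimal usage sketch, not part of the package diff; the "weather" field is a made-up example and the valor_lite.classification.shared import path is an assumption based on the file listing:

import pyarrow as pa

from valor_lite.classification.shared import (
    decode_metadata_fields,
    encode_metadata_fields,
    generate_schema,
)

metadata_fields = [("weather", pa.string())]  # hypothetical metadata column

# build the cache schema: reserved columns plus the user metadata column
schema = generate_schema(metadata_fields)
assert "weather" in schema.names

# metadata field types round-trip through JSON-friendly strings
encoded = encode_metadata_fields(metadata_fields)   # {"weather": "string"}
decoded = decode_metadata_fields(encoded)           # [("weather", "string")]

# reusing a reserved column name such as "datum_uid" raises a ValueError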
valor_lite/classification/utilities.py
ADDED

@@ -0,0 +1,314 @@
from collections import defaultdict

import numpy as np
import pyarrow as pa
from numpy.typing import NDArray

from valor_lite.classification.metric import Metric, MetricType


def unpack_precision_recall(
    counts: NDArray[np.uint64],
    precision: NDArray[np.float64],
    recall: NDArray[np.float64],
    accuracy: NDArray[np.float64],
    f1_score: NDArray[np.float64],
    score_thresholds: list[float],
    hardmax: bool,
    index_to_label: dict[int, str],
) -> dict[MetricType, list[Metric]]:

    metrics = defaultdict(list)

    metrics[MetricType.Accuracy] = [
        Metric.accuracy(
            value=float(accuracy[score_idx]),
            score_threshold=score_threshold,
            hardmax=hardmax,
        )
        for score_idx, score_threshold in enumerate(score_thresholds)
    ]

    for label_idx, label in index_to_label.items():
        for score_idx, score_threshold in enumerate(score_thresholds):
            kwargs = {
                "label": label,
                "hardmax": hardmax,
                "score_threshold": score_threshold,
            }
            row = counts[:, label_idx]
            metrics[MetricType.Counts].append(
                Metric.counts(
                    tp=int(row[score_idx, 0]),
                    fp=int(row[score_idx, 1]),
                    fn=int(row[score_idx, 2]),
                    tn=int(row[score_idx, 3]),
                    **kwargs,
                )
            )

            metrics[MetricType.Precision].append(
                Metric.precision(
                    value=float(precision[score_idx, label_idx]),
                    **kwargs,
                )
            )
            metrics[MetricType.Recall].append(
                Metric.recall(
                    value=float(recall[score_idx, label_idx]),
                    **kwargs,
                )
            )
            metrics[MetricType.F1].append(
                Metric.f1_score(
                    value=float(f1_score[score_idx, label_idx]),
                    **kwargs,
                )
            )
    return metrics


def unpack_rocauc(
    rocauc: NDArray[np.float64],
    mean_rocauc: float,
    index_to_label: dict[int, str],
) -> dict[MetricType, list[Metric]]:
    metrics = {}
    metrics[MetricType.ROCAUC] = [
        Metric.roc_auc(
            value=float(rocauc[label_idx]),
            label=label,
        )
        for label_idx, label in index_to_label.items()
    ]
    metrics[MetricType.mROCAUC] = [
        Metric.mean_roc_auc(
            value=float(mean_rocauc),
        )
    ]
    return metrics


def unpack_confusion_matrix(
    confusion_matrices: NDArray[np.uint64],
    unmatched_groundtruths: NDArray[np.uint64],
    index_to_label: dict[int, str],
    score_thresholds: list[float],
    hardmax: bool,
) -> list[Metric]:
    metrics = []
    for score_idx, score_thresh in enumerate(score_thresholds):
        cm_dict = {}
        ugt_dict = {}
        for idx, label in index_to_label.items():
            ugt_dict[label] = int(unmatched_groundtruths[score_idx, idx])
            for pidx, plabel in index_to_label.items():
                if label not in cm_dict:
                    cm_dict[label] = {}
                cm_dict[label][plabel] = int(
                    confusion_matrices[score_idx, idx, pidx]
                )
        metrics.append(
            Metric.confusion_matrix(
                confusion_matrix=cm_dict,
                unmatched_ground_truths=ugt_dict,
                score_threshold=score_thresh,
                hardmax=hardmax,
            )
        )
    return metrics


def create_mapping(
    tbl: pa.Table,
    pairs: NDArray[np.float64],
    index: int,
    id_col: str,
    uid_col: str,
) -> dict[int, str]:
    col = pairs[:, index].astype(np.int64)
    values, indices = np.unique(col, return_index=True)
    indices = indices[values >= 0]
    return {
        tbl[id_col][idx].as_py(): tbl[uid_col][idx].as_py() for idx in indices
    }


def unpack_examples(
    ids: NDArray[np.int64],
    mask_tp: NDArray[np.bool_],
    mask_fn: NDArray[np.bool_],
    mask_fp: NDArray[np.bool_],
    score_thresholds: list[float],
    hardmax: bool,
    index_to_datum_id: dict[int, str],
    index_to_label: dict[int, str],
) -> list[Metric]:
    metrics = []
    unique_datums = np.unique(ids[:, 0])
    for datum_index in unique_datums:
        mask_datum = ids[:, 0] == datum_index
        mask_datum_tp = mask_tp & mask_datum
        mask_datum_fp = mask_fp & mask_datum
        mask_datum_fn = mask_fn & mask_datum

        datum_id = index_to_datum_id[datum_index]
        for score_idx, score_thresh in enumerate(score_thresholds):

            unique_tp = np.unique(
                # extract true-positive (datum_id, gt_id, pd_id) pairs
                ids[np.ix_(mask_datum_tp[score_idx], (0, 1, 2))],
                axis=0,
            )
            unique_fp = np.unique(
                # extract false-positive (datum_id, pd_id) pairs
                ids[np.ix_(mask_datum_fp[score_idx], (0, 2))],
                axis=0,
            )
            unique_fn = np.unique(
                # extract false-negative (datum_id, gt_id)
                ids[np.ix_(mask_datum_fn[score_idx], (0, 1))],
                axis=0,
            )

            tp = [index_to_label[row[1]] for row in unique_tp]
            fp = [
                index_to_label[row[1]]
                for row in unique_fp
                if index_to_label[row[1]] not in tp
            ]
            fn = [
                index_to_label[row[1]]
                for row in unique_fn
                if index_to_label[row[1]] not in tp
            ]
            metrics.append(
                Metric.examples(
                    datum_id=datum_id,
                    true_positives=tp,
                    false_negatives=fn,
                    false_positives=fp,
                    score_threshold=score_thresh,
                    hardmax=hardmax,
                )
            )
    return metrics


def create_empty_confusion_matrix_with_examples(
    score_threshold: float,
    hardmax: bool,
    index_to_label: dict[int, str],
) -> Metric:
    unmatched_groundtruths = dict()
    confusion_matrix = dict()
    for label in index_to_label.values():
        unmatched_groundtruths[label] = {"count": 0, "examples": []}
        confusion_matrix[label] = {}
        for plabel in index_to_label.values():
            confusion_matrix[label][plabel] = {"count": 0, "examples": []}

    return Metric.confusion_matrix_with_examples(
        confusion_matrix=confusion_matrix,
        unmatched_ground_truths=unmatched_groundtruths,
        score_threshold=score_threshold,
        hardmax=hardmax,
    )


def _unpack_confusion_matrix_with_examples(
    metric: Metric,
    ids: NDArray[np.int64],
    scores: NDArray[np.float64],
    winners: NDArray[np.bool_],
    mask_matched: NDArray[np.bool_],
    mask_unmatched_fn: NDArray[np.bool_],
    index_to_datum_id: dict[int, str],
    index_to_label: dict[int, str],
):
    if not isinstance(metric.value, dict):
        raise TypeError("expected metric to contain a dictionary value")

    mask_valid_gts = ids[:, 1] >= 0
    mask_valid_pds = ids[:, 2] >= 0

    valid_matches = ids[mask_valid_gts & mask_valid_pds]
    valid_gts = ids[mask_valid_gts]

    n_matched = 0
    unique_matches = np.empty((1, 3))
    if valid_matches.size > 0:
        unique_matches, unique_match_indices = np.unique(
            # extract matched (datum_id, gt_id, pd_id) pairs
            valid_matches[np.ix_(mask_matched, (0, 1, 2))],  # type: ignore[reportArgumentType]
            axis=0,
            return_index=True,
        )
        scores = scores[mask_matched][unique_match_indices]
        n_matched = unique_matches.shape[0]

    n_unmatched_groundtruths = 0
    unique_unmatched_groundtruths = np.empty((1, 2))
    if valid_gts.size > 0:
        unique_unmatched_groundtruths = np.unique(
            # extract unmatched false-negative (datum_id, gt_id) pairs
            valid_gts[np.ix_(mask_unmatched_fn, (0, 1))],  # type: ignore[reportArgumentType]
            axis=0,
        )
        unique_unmatched_groundtruths = unique_unmatched_groundtruths[
            unique_unmatched_groundtruths[:, 1] >= 0
        ]
        n_unmatched_groundtruths = unique_unmatched_groundtruths.shape[0]

    n_max = max(n_matched, n_unmatched_groundtruths)
    for idx in range(n_max):
        if idx < n_matched:
            datum_id = index_to_datum_id[unique_matches[idx, 0]]
            glabel = index_to_label[unique_matches[idx, 1]]
            plabel = index_to_label[unique_matches[idx, 2]]
            score = float(scores[idx])

            metric.value["confusion_matrix"][glabel][plabel]["count"] += 1
            metric.value["confusion_matrix"][glabel][plabel][
                "examples"
            ].append(
                {
                    "datum_id": datum_id,
                    "score": score,
                }
            )
        if idx < n_unmatched_groundtruths:
            datum_id = index_to_datum_id[unique_unmatched_groundtruths[idx, 0]]
            label = index_to_label[unique_unmatched_groundtruths[idx, 1]]

            metric.value["unmatched_ground_truths"][label]["count"] += 1
            metric.value["unmatched_ground_truths"][label]["examples"].append(
                {"datum_id": datum_id}
            )

    return metric


def unpack_confusion_matrix_with_examples(
    metrics: dict[int, Metric],
    ids: NDArray[np.int64],
    scores: NDArray[np.float64],
    winners: NDArray[np.bool_],
    mask_matched: NDArray[np.bool_],
    mask_unmatched_fn: NDArray[np.bool_],
    index_to_datum_id: dict[int, str],
    index_to_label: dict[int, str],
) -> list[Metric]:
    return [
        _unpack_confusion_matrix_with_examples(
            metric,
            ids=ids,
            scores=scores,
            winners=winners,
            mask_matched=mask_matched[score_idx, :],
            mask_unmatched_fn=mask_unmatched_fn[score_idx, :],
            index_to_datum_id=index_to_datum_id,
            index_to_label=index_to_label,
        )
        for score_idx, metric in metrics.items()
    ]
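For orientation, unpack_rocauc simply wraps precomputed per-label ROC AUC values into Metric objects. A minimal sketch under the assumption that the file above is valor_lite/classification/utilities.py (per the +314 entry in the file listing); the label map and AUC values are made-up example inputs:

import numpy as np

from valor_lite.classification.metric import MetricType
from valor_lite.classification.utilities import unpack_rocauc

index_to_label = {0: "cat", 1: "dog"}
rocauc = np.array([0.92, 0.81], dtype=np.float64)  # one AUC per label index

metrics = unpack_rocauc(
    rocauc=rocauc,
    mean_rocauc=float(rocauc.mean()),
    index_to_label=index_to_label,
)
# metrics[MetricType.ROCAUC] holds one Metric per label
# metrics[MetricType.mROCAUC] holds a single mean-ROC-AUC Metric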
valor_lite/exceptions.py
ADDED
@@ -0,0 +1,20 @@
class EmptyEvaluatorError(Exception):
    def __init__(self):
        super().__init__(
            "evaluator cannot be finalized as it contains no data"
        )


class EmptyCacheError(Exception):
    def __init__(self):
        super().__init__("cache contains no data")


class EmptyFilterError(Exception):
    def __init__(self, message: str):
        super().__init__(message)


class InternalCacheError(Exception):
    def __init__(self, message: str):
        super().__init__(message)
valor_lite/object_detection/__init__.py
ADDED

@@ -0,0 +1,17 @@
from .annotation import Bitmask, BoundingBox, Detection, Polygon
from .evaluator import Evaluator
from .loader import Loader
from .metric import Metric, MetricType
from .shared import EvaluatorInfo

__all__ = [
    "Bitmask",
    "BoundingBox",
    "Detection",
    "Polygon",
    "Metric",
    "MetricType",
    "Loader",
    "Evaluator",
    "EvaluatorInfo",
]
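Assuming this __init__.py is valor_lite/object_detection/__init__.py (the only 17-line __init__ in the file listing, and the one that re-exports the detection annotation types), the re-exports give the subpackage a flat public API, for example:

# illustrative import of the re-exported names; the subpackage path is inferred
from valor_lite.object_detection import BoundingBox, Detection, Evaluator, Loader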