valor-lite 0.37.1 (valor_lite-0.37.1-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of valor-lite might be problematic.
- valor_lite/LICENSE +21 -0
- valor_lite/__init__.py +0 -0
- valor_lite/cache/__init__.py +11 -0
- valor_lite/cache/compute.py +154 -0
- valor_lite/cache/ephemeral.py +302 -0
- valor_lite/cache/persistent.py +529 -0
- valor_lite/classification/__init__.py +14 -0
- valor_lite/classification/annotation.py +45 -0
- valor_lite/classification/computation.py +378 -0
- valor_lite/classification/evaluator.py +879 -0
- valor_lite/classification/loader.py +97 -0
- valor_lite/classification/metric.py +535 -0
- valor_lite/classification/numpy_compatibility.py +13 -0
- valor_lite/classification/shared.py +184 -0
- valor_lite/classification/utilities.py +314 -0
- valor_lite/exceptions.py +20 -0
- valor_lite/object_detection/__init__.py +17 -0
- valor_lite/object_detection/annotation.py +238 -0
- valor_lite/object_detection/computation.py +841 -0
- valor_lite/object_detection/evaluator.py +805 -0
- valor_lite/object_detection/loader.py +292 -0
- valor_lite/object_detection/metric.py +850 -0
- valor_lite/object_detection/shared.py +185 -0
- valor_lite/object_detection/utilities.py +396 -0
- valor_lite/schemas.py +11 -0
- valor_lite/semantic_segmentation/__init__.py +15 -0
- valor_lite/semantic_segmentation/annotation.py +123 -0
- valor_lite/semantic_segmentation/computation.py +165 -0
- valor_lite/semantic_segmentation/evaluator.py +414 -0
- valor_lite/semantic_segmentation/loader.py +205 -0
- valor_lite/semantic_segmentation/metric.py +275 -0
- valor_lite/semantic_segmentation/shared.py +149 -0
- valor_lite/semantic_segmentation/utilities.py +88 -0
- valor_lite/text_generation/__init__.py +15 -0
- valor_lite/text_generation/annotation.py +56 -0
- valor_lite/text_generation/computation.py +611 -0
- valor_lite/text_generation/llm/__init__.py +0 -0
- valor_lite/text_generation/llm/exceptions.py +14 -0
- valor_lite/text_generation/llm/generation.py +903 -0
- valor_lite/text_generation/llm/instructions.py +814 -0
- valor_lite/text_generation/llm/integrations.py +226 -0
- valor_lite/text_generation/llm/utilities.py +43 -0
- valor_lite/text_generation/llm/validators.py +68 -0
- valor_lite/text_generation/manager.py +697 -0
- valor_lite/text_generation/metric.py +381 -0
- valor_lite-0.37.1.dist-info/METADATA +174 -0
- valor_lite-0.37.1.dist-info/RECORD +49 -0
- valor_lite-0.37.1.dist-info/WHEEL +5 -0
- valor_lite-0.37.1.dist-info/top_level.txt +1 -0
valor_lite/semantic_segmentation/loader.py
@@ -0,0 +1,205 @@
import numpy as np
import pyarrow as pa
from tqdm import tqdm

from valor_lite.cache import FileCacheWriter, MemoryCacheWriter
from valor_lite.semantic_segmentation.annotation import Segmentation
from valor_lite.semantic_segmentation.computation import compute_intermediates
from valor_lite.semantic_segmentation.evaluator import Builder


class Loader(Builder):
    def __init__(
        self,
        writer: MemoryCacheWriter | FileCacheWriter,
        metadata_fields: list[tuple[str, str | pa.DataType]] | None = None,
    ):
        super().__init__(
            writer=writer,
            metadata_fields=metadata_fields,
        )

        # internal state
        self._labels: dict[str, int] = {}
        self._index_to_label: dict[int, str] = {}
        self._datum_count = 0

    def _add_label(self, value: str) -> int:
        idx = self._labels.get(value, None)
        if idx is None:
            idx = len(self._labels)
            self._labels[value] = idx
            self._index_to_label[idx] = value
        return idx

    def add_data(
        self,
        segmentations: list[Segmentation],
        show_progress: bool = False,
    ):
        """
        Adds segmentations to the cache.

        Parameters
        ----------
        segmentations : list[Segmentation]
            A list of Segmentation objects.
        show_progress : bool, default=False
            Toggle for tqdm progress bar.
        """

        disable_tqdm = not show_progress
        for segmentation in tqdm(segmentations, disable=disable_tqdm):

            groundtruth_labels = -1 * np.ones(
                len(segmentation.groundtruths), dtype=np.int64
            )
            for idx, groundtruth in enumerate(segmentation.groundtruths):
                label_idx = self._add_label(groundtruth.label)
                groundtruth_labels[idx] = label_idx

            prediction_labels = -1 * np.ones(
                len(segmentation.predictions), dtype=np.int64
            )
            for idx, prediction in enumerate(segmentation.predictions):
                label_idx = self._add_label(prediction.label)
                prediction_labels[idx] = label_idx

            if segmentation.groundtruths:
                combined_groundtruths = np.stack(
                    [
                        groundtruth.mask.flatten()
                        for groundtruth in segmentation.groundtruths
                    ],
                    axis=0,
                )
            else:
                combined_groundtruths = np.zeros(
                    (1, segmentation.shape[0] * segmentation.shape[1]),
                    dtype=np.bool_,
                )

            if segmentation.predictions:
                combined_predictions = np.stack(
                    [
                        prediction.mask.flatten()
                        for prediction in segmentation.predictions
                    ],
                    axis=0,
                )
            else:
                combined_predictions = np.zeros(
                    (1, segmentation.shape[0] * segmentation.shape[1]),
                    dtype=np.bool_,
                )

            n_labels = len(self._labels)
            counts = compute_intermediates(
                groundtruths=combined_groundtruths,
                predictions=combined_predictions,
                groundtruth_labels=groundtruth_labels,
                prediction_labels=prediction_labels,
                n_labels=n_labels,
            )

            # prepare metadata
            datum_metadata = (
                segmentation.metadata if segmentation.metadata else {}
            )
            gt_metadata = {
                self._labels[gt.label]: gt.metadata
                for gt in segmentation.groundtruths
                if gt.metadata
            }
            pd_metadata = {
                self._labels[pd.label]: pd.metadata
                for pd in segmentation.predictions
                if pd.metadata
            }

            # cache formatting
            rows = []
            for idx in range(n_labels):
                label = self._index_to_label[idx]
                for pidx in range(n_labels):
                    # write non-zero intersections to cache
                    if counts[idx + 1, pidx + 1] > 0:
                        plabel = self._index_to_label[pidx]
                        rows.append(
                            {
                                # metadata
                                **datum_metadata,
                                **gt_metadata.get(idx, {}),
                                **pd_metadata.get(pidx, {}),
                                # datum
                                "datum_uid": segmentation.uid,
                                "datum_id": self._datum_count,
                                # groundtruth
                                "gt_label": label,
                                "gt_label_id": idx,
                                # prediction
                                "pd_label": plabel,
                                "pd_label_id": pidx,
                                # pair
                                "count": counts[idx + 1, pidx + 1],
                            }
                        )
                # write all unmatched to preserve labels
                rows.extend(
                    [
                        {
                            # metadata
                            **datum_metadata,
                            **gt_metadata.get(idx, {}),
                            # datum
                            "datum_uid": segmentation.uid,
                            "datum_id": self._datum_count,
                            # groundtruth
                            "gt_label": label,
                            "gt_label_id": idx,
                            # prediction
                            "pd_label": None,
                            "pd_label_id": -1,
                            # pair
                            "count": counts[idx + 1, 0],
                        },
                        {
                            # metadata
                            **datum_metadata,
                            **gt_metadata.get(idx, {}),
                            **pd_metadata.get(idx, {}),
                            # datum
                            "datum_uid": segmentation.uid,
                            "datum_id": self._datum_count,
                            # groundtruth
                            "gt_label": None,
                            "gt_label_id": -1,
                            # prediction
                            "pd_label": label,
                            "pd_label_id": idx,
                            # pair
                            "count": counts[0, idx + 1],
                        },
                    ]
                )
            rows.append(
                {
                    # metadata
                    **datum_metadata,
                    # datum
                    "datum_uid": segmentation.uid,
                    "datum_id": self._datum_count,
                    # groundtruth
                    "gt_label": None,
                    "gt_label_id": -1,
                    # prediction
                    "pd_label": None,
                    "pd_label_id": -1,
                    # pair
                    "count": counts[0, 0],
                }
            )
            self._writer.write_rows(rows)

            # update datum count
            self._datum_count += 1
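Note on the layout above: Loader.add_data indexes the counts array from compute_intermediates with a one-bucket offset, so row/column 0 holds unmatched (background) pixels and entry (i + 1, j + 1) holds the pixel intersection between ground-truth label i and predicted label j. The real implementation lives in valor_lite/semantic_segmentation/computation.py (not shown in this hunk); the following is only an illustrative NumPy sketch of that assumed layout, not the package's code.

import numpy as np

def confusion_counts_sketch(gt_masks, gt_labels, pd_masks, pd_labels, n_labels):
    # gt_masks / pd_masks: boolean arrays of shape (n_annotations, n_pixels)
    # gt_labels / pd_labels: label index for each annotation row
    n_pixels = gt_masks.shape[1]
    counts = np.zeros((n_labels + 1, n_labels + 1), dtype=np.int64)

    # per-pixel label id, with -1 meaning "no label" for that pixel
    gt_pixel = np.full(n_pixels, -1, dtype=np.int64)
    for row, label in enumerate(gt_labels):
        gt_pixel[gt_masks[row]] = label
    pd_pixel = np.full(n_pixels, -1, dtype=np.int64)
    for row, label in enumerate(pd_labels):
        pd_pixel[pd_masks[row]] = label

    # offset by +1 so index 0 is the unmatched/background bucket
    np.add.at(counts, (gt_pixel + 1, pd_pixel + 1), 1)
    return counts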
valor_lite/semantic_segmentation/metric.py
@@ -0,0 +1,275 @@
from dataclasses import dataclass
from enum import Enum

from valor_lite.schemas import BaseMetric


class MetricType(Enum):
    Precision = "Precision"
    Recall = "Recall"
    Accuracy = "Accuracy"
    F1 = "F1"
    IOU = "IOU"
    mIOU = "mIOU"
    ConfusionMatrix = "ConfusionMatrix"


@dataclass
class Metric(BaseMetric):
    """
    Semantic Segmentation Metric.

    Attributes
    ----------
    type : str
        The metric type.
    value : int | float | dict
        The metric value.
    parameters : dict[str, Any]
        A dictionary containing metric parameters.
    """

    def __post_init__(self):
        if not isinstance(self.type, str):
            raise TypeError(
                f"Metric type should be of type 'str': {self.type}"
            )
        elif not isinstance(self.value, (int, float, dict)):
            raise TypeError(
                f"Metric value must be of type 'int', 'float' or 'dict': {self.value}"
            )
        elif not isinstance(self.parameters, dict):
            raise TypeError(
                f"Metric parameters must be of type 'dict[str, Any]': {self.parameters}"
            )
        elif not all([isinstance(k, str) for k in self.parameters.keys()]):
            raise TypeError(
                f"Metric parameter dictionary should only have keys with type 'str': {self.parameters}"
            )

    @classmethod
    def precision(
        cls,
        value: float,
        label: str,
    ):
        """
        Precision metric for a specific class label.

        Precision is calculated using the number of true-positive pixels divided by
        the sum of all true-positive and false-positive pixels.

        Parameters
        ----------
        value : float
            The computed precision value.
        label : str
            The label for which the precision is calculated.

        Returns
        -------
        Metric
        """
        return cls(
            type=MetricType.Precision.value,
            value=value,
            parameters={
                "label": label,
            },
        )

    @classmethod
    def recall(
        cls,
        value: float,
        label: str,
    ):
        """
        Recall metric for a specific class label.

        Recall is calculated using the number of true-positive pixels divided by
        the sum of all true-positive and false-negative pixels.

        Parameters
        ----------
        value : float
            The computed recall value.
        label : str
            The label for which the recall is calculated.

        Returns
        -------
        Metric
        """
        return cls(
            type=MetricType.Recall.value,
            value=value,
            parameters={
                "label": label,
            },
        )

    @classmethod
    def f1_score(
        cls,
        value: float,
        label: str,
    ):
        """
        F1 score for a specific class label.

        Parameters
        ----------
        value : float
            The computed F1 score.
        label : str
            The label for which the F1 score is calculated.

        Returns
        -------
        Metric
        """
        return cls(
            type=MetricType.F1.value,
            value=value,
            parameters={
                "label": label,
            },
        )

    @classmethod
    def iou(
        cls,
        value: float,
        label: str,
    ):
        """
        Intersection over Union (IOU) ratio for a specific class label.

        Parameters
        ----------
        value : float
            The computed IOU ratio.
        label : str
            The label for which the IOU is calculated.

        Returns
        -------
        Metric
        """
        return cls(
            type=MetricType.IOU.value,
            value=value,
            parameters={
                "label": label,
            },
        )

    @classmethod
    def mean_iou(cls, value: float):
        """
        Mean Intersection over Union (mIOU) ratio.

        The mIOU value is computed by averaging IOU over all labels.

        Parameters
        ----------
        value : float
            The mIOU value.

        Returns
        -------
        Metric
        """
        return cls(type=MetricType.mIOU.value, value=value, parameters={})

    @classmethod
    def accuracy(cls, value: float):
        """
        Accuracy metric computed over all labels.

        Parameters
        ----------
        value : float
            The accuracy value.

        Returns
        -------
        Metric
        """
        return cls(type=MetricType.Accuracy.value, value=value, parameters={})

    @classmethod
    def confusion_matrix(
        cls,
        confusion_matrix: dict[
            str,  # ground truth label value
            dict[
                str,  # prediction label value
                dict[str, float],  # iou
            ],
        ],
        unmatched_predictions: dict[
            str,  # prediction label value
            dict[str, float],  # pixel ratio
        ],
        unmatched_ground_truths: dict[
            str,  # ground truth label value
            dict[str, float],  # pixel ratio
        ],
    ):
        """
        The confusion matrix and related metrics for semantic segmentation tasks.

        This metric encapsulates detailed information about the model's performance, including
        correct predictions, misclassifications, unmatched predictions (a subset of false
        positives), and unmatched ground truths (a subset of false negatives), to facilitate
        in-depth analysis.

        Confusion Matrix Format:
        {
            <ground truth label>: {
                <prediction label>: {
                    'iou': <float>,
                },
            },
        }

        Unmatched Predictions Format:
        {
            <prediction label>: {
                'ratio': <float>,
            },
        }

        Unmatched Ground Truths Format:
        {
            <ground truth label>: {
                'ratio': <float>,
            },
        }

        Parameters
        ----------
        confusion_matrix : dict
            Nested dictionaries representing the Intersection over Union (IOU) scores for each
            ground truth label and prediction label pair.
        unmatched_predictions : dict
            Dictionary representing the pixel ratios for predicted labels that do not correspond
            to any ground truth labels (false positives).
        unmatched_ground_truths : dict
            Dictionary representing the pixel ratios for ground truth labels that were not predicted
            (false negatives).

        Returns
        -------
        Metric
        """
        return cls(
            type=MetricType.ConfusionMatrix.value,
            value={
                "confusion_matrix": confusion_matrix,
                "unmatched_predictions": unmatched_predictions,
                "unmatched_ground_truths": unmatched_ground_truths,
            },
            parameters={},
        )
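Each factory above simply wraps a precomputed value in a Metric whose type string matches the corresponding MetricType member; the type, value and parameters fields checked in __post_init__ are supplied by BaseMetric in valor_lite/schemas.py (not shown in this hunk). A minimal usage sketch, assuming those field names and an installed package:

from valor_lite.semantic_segmentation.metric import Metric, MetricType

# per-label metric: a value plus the label it was computed for
p = Metric.precision(value=0.8, label="dog")
assert p.type == MetricType.Precision.value
assert p.parameters == {"label": "dog"}

# aggregate metrics carry no parameters
acc = Metric.accuracy(value=0.95)
assert acc.parameters == {}

# __post_init__ rejects malformed values, e.g. a string value
Metric(type="Precision", value="0.8", parameters={})  # raises TypeError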
valor_lite/semantic_segmentation/shared.py
@@ -0,0 +1,149 @@
from dataclasses import dataclass
from pathlib import Path

import numpy as np
import pyarrow as pa
import pyarrow.compute as pc

from valor_lite.cache import FileCacheReader, MemoryCacheReader


@dataclass
class EvaluatorInfo:
    number_of_rows: int = 0
    number_of_datums: int = 0
    number_of_labels: int = 0
    number_of_pixels: int = 0
    number_of_groundtruth_pixels: int = 0
    number_of_prediction_pixels: int = 0
    metadata_fields: list[tuple[str, str | pa.DataType]] | None = None


def generate_cache_path(path: str | Path) -> Path:
    """Generate cache path from parent directory."""
    return Path(path) / "counts"


def generate_metadata_path(path: str | Path) -> Path:
    """Generate metadata path from parent directory."""
    return Path(path) / "metadata.json"


def generate_schema(
    metadata_fields: list[tuple[str, str | pa.DataType]] | None
) -> pa.Schema:
    """Generate PyArrow schema from metadata fields."""

    metadata_fields = metadata_fields if metadata_fields else []
    reserved_fields = [
        ("datum_uid", pa.string()),
        ("datum_id", pa.int64()),
        # groundtruth
        ("gt_label", pa.string()),
        ("gt_label_id", pa.int64()),
        # prediction
        ("pd_label", pa.string()),
        ("pd_label_id", pa.int64()),
        # pair
        ("count", pa.uint64()),
    ]

    # validate
    reserved_field_names = {f[0] for f in reserved_fields}
    metadata_field_names = {f[0] for f in metadata_fields}
    if conflicting := reserved_field_names & metadata_field_names:
        raise ValueError(
            f"metadata fields {conflicting} conflict with reserved fields"
        )

    return pa.schema(
        [
            *reserved_fields,
            *metadata_fields,
        ]
    )


def encode_metadata_fields(
    metadata_fields: list[tuple[str, str | pa.DataType]] | None
) -> dict[str, str]:
    """Encode metadata fields into JSON format."""
    metadata_fields = metadata_fields if metadata_fields else []
    return {k: str(v) for k, v in metadata_fields}


def decode_metadata_fields(
    encoded_metadata_fields: dict[str, str]
) -> list[tuple[str, str | pa.DataType]]:
    """Decode metadata fields from JSON format."""
    return [(k, v) for k, v in encoded_metadata_fields.items()]


def extract_labels(
    reader: MemoryCacheReader | FileCacheReader,
    index_to_label_override: dict[int, str] | None = None,
) -> dict[int, str]:
    if index_to_label_override is not None:
        return index_to_label_override

    index_to_label = {}
    for tbl in reader.iterate_tables(
        columns=[
            "gt_label_id",
            "gt_label",
            "pd_label_id",
            "pd_label",
        ]
    ):

        # get gt labels
        gt_label_ids = tbl["gt_label_id"].to_numpy()
        gt_label_ids, gt_indices = np.unique(gt_label_ids, return_index=True)
        gt_labels = tbl["gt_label"].take(gt_indices).to_pylist()
        gt_labels = dict(zip(gt_label_ids.astype(int).tolist(), gt_labels))
        gt_labels.pop(-1, None)
        index_to_label.update(gt_labels)

        # get pd labels
        pd_label_ids = tbl["pd_label_id"].to_numpy()
        pd_label_ids, pd_indices = np.unique(pd_label_ids, return_index=True)
        pd_labels = tbl["pd_label"].take(pd_indices).to_pylist()
        pd_labels = dict(zip(pd_label_ids.astype(int).tolist(), pd_labels))
        pd_labels.pop(-1, None)
        index_to_label.update(pd_labels)

    return index_to_label


def extract_counts(
    reader: MemoryCacheReader | FileCacheReader,
    datums: pc.Expression | None = None,
    groundtruths: pc.Expression | None = None,
    predictions: pc.Expression | None = None,
):
    n_dts, n_total, n_gts, n_pds = 0, 0, 0, 0
    for tbl in reader.iterate_tables(filter=datums):

        # count datums
        n_dts += int(np.unique(tbl["datum_id"].to_numpy()).shape[0])

        # count pixels
        n_total += int(tbl["count"].to_numpy().sum())

        # count groundtruth pixels
        gt_tbl = tbl
        gt_expr = pc.field("gt_label_id") >= 0
        if groundtruths is not None:
            gt_expr &= groundtruths
        gt_tbl = tbl.filter(gt_expr)
        n_gts += int(gt_tbl["count"].to_numpy().sum())

        # count prediction pixels
        pd_tbl = tbl
        pd_expr = pc.field("pd_label_id") >= 0
        if predictions is not None:
            pd_expr &= predictions
        pd_tbl = tbl.filter(pd_expr)
        n_pds += int(pd_tbl["count"].to_numpy().sum())

    return n_dts, n_total, n_gts, n_pds
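The cache schema is fixed: seven reserved columns plus any user-supplied metadata columns, with name collisions against the reserved set rejected. A short usage sketch, assuming the import path matches the file listing above:

import pyarrow as pa
from valor_lite.semantic_segmentation.shared import generate_schema

# reserved columns followed by one user metadata column
schema = generate_schema(metadata_fields=[("weather", pa.string())])
assert schema.names[-1] == "weather"

# reusing a reserved column name is rejected
try:
    generate_schema(metadata_fields=[("datum_uid", pa.string())])
except ValueError as error:
    print(error)  # metadata fields {'datum_uid'} conflict with reserved fields

Note that encode_metadata_fields stores str(DataType) while decode_metadata_fields returns those strings unchanged, so a round trip yields type names as strings rather than pa.DataType objects.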
valor_lite/semantic_segmentation/utilities.py
@@ -0,0 +1,88 @@
from collections import defaultdict

from valor_lite.semantic_segmentation.metric import Metric, MetricType


def unpack_precision_recall_iou_into_metric_lists(
    results: tuple,
    index_to_label: dict[int, str],
) -> dict[MetricType, list[Metric]]:

    n_labels = len(index_to_label)
    (
        precision,
        recall,
        f1_score,
        accuracy,
        ious,
        unmatched_prediction_ratios,
        unmatched_ground_truth_ratios,
    ) = results

    metrics = defaultdict(list)

    metrics[MetricType.Accuracy] = [
        Metric.accuracy(
            value=float(accuracy),
        )
    ]

    metrics[MetricType.ConfusionMatrix] = [
        Metric.confusion_matrix(
            confusion_matrix={
                index_to_label[gt_label_idx]: {
                    index_to_label[pd_label_idx]: {
                        "iou": float(ious[gt_label_idx, pd_label_idx])
                    }
                    for pd_label_idx in range(n_labels)
                }
                for gt_label_idx in range(n_labels)
            },
            unmatched_predictions={
                index_to_label[pd_label_idx]: {
                    "ratio": float(unmatched_prediction_ratios[pd_label_idx])
                }
                for pd_label_idx in range(n_labels)
            },
            unmatched_ground_truths={
                index_to_label[gt_label_idx]: {
                    "ratio": float(unmatched_ground_truth_ratios[gt_label_idx])
                }
                for gt_label_idx in range(n_labels)
            },
        )
    ]

    metrics[MetricType.mIOU] = [
        Metric.mean_iou(
            value=float(ious.diagonal().mean()),
        )
    ]

    for label_idx, label in index_to_label.items():
        metrics[MetricType.Precision].append(
            Metric.precision(
                value=float(precision[label_idx]),
                label=label,
            )
        )
        metrics[MetricType.Recall].append(
            Metric.recall(
                value=float(recall[label_idx]),
                label=label,
            )
        )
        metrics[MetricType.F1].append(
            Metric.f1_score(
                value=float(f1_score[label_idx]),
                label=label,
            )
        )
        metrics[MetricType.IOU].append(
            Metric.iou(
                value=float(ious[label_idx, label_idx]),
                label=label,
            )
        )

    return metrics
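The results tuple is positional, so its layout has to match what the evaluator's compute step produces: per-label 1-D arrays for precision, recall and F1, a scalar accuracy, an (n_labels, n_labels) IOU matrix, and per-label unmatched ratios. A sketch with dummy NumPy values, assuming an installed package (real values come from the computation module):

import numpy as np
from valor_lite.semantic_segmentation.metric import MetricType
from valor_lite.semantic_segmentation.utilities import (
    unpack_precision_recall_iou_into_metric_lists,
)

index_to_label = {0: "sky", 1: "road"}
n = len(index_to_label)

results = (
    np.array([0.90, 0.80]),   # precision per label
    np.array([0.85, 0.75]),   # recall per label
    np.array([0.87, 0.77]),   # f1 per label
    0.9,                      # accuracy (scalar)
    np.eye(n) * 0.8,          # ious, shape (n, n)
    np.array([0.05, 0.10]),   # unmatched prediction ratios per label
    np.array([0.02, 0.08]),   # unmatched ground truth ratios per label
)

metrics = unpack_precision_recall_iou_into_metric_lists(results, index_to_label)
print(metrics[MetricType.mIOU][0].value)   # 0.8
print(len(metrics[MetricType.Precision]))  # 2, one per label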