valor-lite 0.37.1 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of valor-lite might be problematic.
- valor_lite/LICENSE +21 -0
- valor_lite/__init__.py +0 -0
- valor_lite/cache/__init__.py +11 -0
- valor_lite/cache/compute.py +154 -0
- valor_lite/cache/ephemeral.py +302 -0
- valor_lite/cache/persistent.py +529 -0
- valor_lite/classification/__init__.py +14 -0
- valor_lite/classification/annotation.py +45 -0
- valor_lite/classification/computation.py +378 -0
- valor_lite/classification/evaluator.py +879 -0
- valor_lite/classification/loader.py +97 -0
- valor_lite/classification/metric.py +535 -0
- valor_lite/classification/numpy_compatibility.py +13 -0
- valor_lite/classification/shared.py +184 -0
- valor_lite/classification/utilities.py +314 -0
- valor_lite/exceptions.py +20 -0
- valor_lite/object_detection/__init__.py +17 -0
- valor_lite/object_detection/annotation.py +238 -0
- valor_lite/object_detection/computation.py +841 -0
- valor_lite/object_detection/evaluator.py +805 -0
- valor_lite/object_detection/loader.py +292 -0
- valor_lite/object_detection/metric.py +850 -0
- valor_lite/object_detection/shared.py +185 -0
- valor_lite/object_detection/utilities.py +396 -0
- valor_lite/schemas.py +11 -0
- valor_lite/semantic_segmentation/__init__.py +15 -0
- valor_lite/semantic_segmentation/annotation.py +123 -0
- valor_lite/semantic_segmentation/computation.py +165 -0
- valor_lite/semantic_segmentation/evaluator.py +414 -0
- valor_lite/semantic_segmentation/loader.py +205 -0
- valor_lite/semantic_segmentation/metric.py +275 -0
- valor_lite/semantic_segmentation/shared.py +149 -0
- valor_lite/semantic_segmentation/utilities.py +88 -0
- valor_lite/text_generation/__init__.py +15 -0
- valor_lite/text_generation/annotation.py +56 -0
- valor_lite/text_generation/computation.py +611 -0
- valor_lite/text_generation/llm/__init__.py +0 -0
- valor_lite/text_generation/llm/exceptions.py +14 -0
- valor_lite/text_generation/llm/generation.py +903 -0
- valor_lite/text_generation/llm/instructions.py +814 -0
- valor_lite/text_generation/llm/integrations.py +226 -0
- valor_lite/text_generation/llm/utilities.py +43 -0
- valor_lite/text_generation/llm/validators.py +68 -0
- valor_lite/text_generation/manager.py +697 -0
- valor_lite/text_generation/metric.py +381 -0
- valor_lite-0.37.1.dist-info/METADATA +174 -0
- valor_lite-0.37.1.dist-info/RECORD +49 -0
- valor_lite-0.37.1.dist-info/WHEEL +5 -0
- valor_lite-0.37.1.dist-info/top_level.txt +1 -0
valor_lite/classification/loader.py
@@ -0,0 +1,97 @@
import numpy as np
import pyarrow as pa
from tqdm import tqdm

from valor_lite.cache.ephemeral import MemoryCacheWriter
from valor_lite.cache.persistent import FileCacheWriter
from valor_lite.classification.annotation import Classification
from valor_lite.classification.evaluator import Builder


class Loader(Builder):
    def __init__(
        self,
        writer: MemoryCacheWriter | FileCacheWriter,
        roc_curve_writer: MemoryCacheWriter | FileCacheWriter,
        intermediate_writer: MemoryCacheWriter | FileCacheWriter,
        metadata_fields: list[tuple[str, str | pa.DataType]] | None = None,
    ):
        super().__init__(
            writer=writer,
            roc_curve_writer=roc_curve_writer,
            intermediate_writer=intermediate_writer,
            metadata_fields=metadata_fields,
        )

        # internal state
        self._labels: dict[str, int] = {}
        self._index_to_label: dict[int, str] = {}
        self._datum_count = 0

    def _add_label(self, value: str) -> int:
        idx = self._labels.get(value, None)
        if idx is None:
            idx = len(self._labels)
            self._labels[value] = idx
            self._index_to_label[idx] = value
        return idx

    def add_data(
        self,
        classifications: list[Classification],
        show_progress: bool = False,
    ):
        """
        Adds classifications to the cache.

        Parameters
        ----------
        classifications : list[Classification]
            A list of Classification objects.
        show_progress : bool, default=False
            Toggle for tqdm progress bar.
        """

        disable_tqdm = not show_progress
        for classification in tqdm(classifications, disable=disable_tqdm):
            if len(classification.predictions) == 0:
                raise ValueError(
                    "Classifications must contain at least one prediction."
                )

            # prepare metadata
            datum_metadata = (
                classification.metadata if classification.metadata else {}
            )

            # write to cache
            rows = []
            gidx = self._add_label(classification.groundtruth)
            max_score_idx = np.argmax(np.array(classification.scores))
            for idx, (plabel, score) in enumerate(
                zip(classification.predictions, classification.scores)
            ):
                pidx = self._add_label(plabel)
                rows.append(
                    {
                        # metadata
                        **datum_metadata,
                        # datum
                        "datum_uid": classification.uid,
                        "datum_id": self._datum_count,
                        # groundtruth
                        "gt_label": classification.groundtruth,
                        "gt_label_id": gidx,
                        # prediction
                        "pd_label": plabel,
                        "pd_label_id": pidx,
                        "pd_score": float(score),
                        "pd_winner": max_score_idx == idx,
                        # pair
                        "match": (gidx == pidx) and pidx >= 0,
                    }
                )
            self._writer.write_rows(rows)

            # update datum count
            self._datum_count += 1
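The loader above flattens each Classification into one cache row per predicted label: labels get incremental integer indices, the argmax prediction is flagged as `pd_winner`, and `match` records agreement with the ground truth. The sketch below reproduces that row layout with plain dictionaries so it can run outside the package; the datum, labels, and scores are made-up illustration values, not part of the diff.

import numpy as np

# made-up sample: one datum with a ground truth and three scored predictions
uid, groundtruth = "img_001", "cat"
predictions = ["cat", "dog", "bird"]
scores = [0.70, 0.20, 0.10]

labels: dict[str, int] = {}


def add_label(value: str) -> int:
    # assign incremental integer indices, mirroring Loader._add_label
    if value not in labels:
        labels[value] = len(labels)
    return labels[value]


gidx = add_label(groundtruth)
winner = int(np.argmax(np.array(scores)))  # hardmax prediction index

rows = [
    {
        "datum_uid": uid,
        "datum_id": 0,
        "gt_label": groundtruth,
        "gt_label_id": gidx,
        "pd_label": plabel,
        "pd_label_id": add_label(plabel),
        "pd_score": float(score),
        "pd_winner": idx == winner,
        "match": add_label(plabel) == gidx,
    }
    for idx, (plabel, score) in enumerate(zip(predictions, scores))
]
# rows[0]: pd_label "cat", pd_winner True, match True; rows[1] and rows[2] are non-matches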
valor_lite/classification/metric.py
@@ -0,0 +1,535 @@
from dataclasses import dataclass
from enum import Enum

from valor_lite.schemas import BaseMetric


class MetricType(Enum):
    Counts = "Counts"
    ROCAUC = "ROCAUC"
    mROCAUC = "mROCAUC"
    Precision = "Precision"
    Recall = "Recall"
    Accuracy = "Accuracy"
    F1 = "F1"
    ConfusionMatrix = "ConfusionMatrix"
    Examples = "Examples"
    ConfusionMatrixWithExamples = "ConfusionMatrixWithExamples"


@dataclass
class Metric(BaseMetric):
    """
    Classification Metric.

    Attributes
    ----------
    type : str
        The metric type.
    value : int | float | dict
        The metric value.
    parameters : dict[str, Any]
        A dictionary containing metric parameters.
    """

    def __post_init__(self):
        if not isinstance(self.type, str):
            raise TypeError(
                f"Metric type should be of type 'str': {self.type}"
            )
        elif not isinstance(self.value, (int, float, dict)):
            raise TypeError(
                f"Metric value must be of type 'int', 'float' or 'dict': {self.value}"
            )
        elif not isinstance(self.parameters, dict):
            raise TypeError(
                f"Metric parameters must be of type 'dict[str, Any]': {self.parameters}"
            )
        elif not all([isinstance(k, str) for k in self.parameters.keys()]):
            raise TypeError(
                f"Metric parameter dictionary should only have keys with type 'str': {self.parameters}"
            )

    @classmethod
    def precision(
        cls,
        value: float,
        score_threshold: float,
        hardmax: bool,
        label: str,
    ):
        """
        Precision metric for a specific class label.

        This class calculates the precision at a specific score threshold.
        Precision is defined as the ratio of true positives to the sum of
        true positives and false positives.

        Parameters
        ----------
        value : float
            Precision value computed at a specific score threshold.
        score_threshold : float
            Score threshold at which the precision value is computed.
        hardmax : bool
            Indicates whether hardmax thresholding was used.
        label : str
            The class label for which the precision is computed.

        Returns
        -------
        Metric
        """
        return cls(
            type=MetricType.Precision.value,
            value=value,
            parameters={
                "score_threshold": score_threshold,
                "hardmax": hardmax,
                "label": label,
            },
        )

    @classmethod
    def recall(
        cls,
        value: float,
        score_threshold: float,
        hardmax: bool,
        label: str,
    ):
        """
        Recall metric for a specific class label.

        This class calculates the recall at a specific score threshold.
        Recall is defined as the ratio of true positives to the sum of
        true positives and false negatives.

        Parameters
        ----------
        value : float
            Recall value computed at a specific score threshold.
        score_threshold : float
            Score threshold at which the recall value is computed.
        hardmax : bool
            Indicates whether hardmax thresholding was used.
        label : str
            The class label for which the recall is computed.

        Returns
        -------
        Metric
        """
        return cls(
            type=MetricType.Recall.value,
            value=value,
            parameters={
                "score_threshold": score_threshold,
                "hardmax": hardmax,
                "label": label,
            },
        )

    @classmethod
    def f1_score(
        cls,
        value: float,
        score_threshold: float,
        hardmax: bool,
        label: str,
    ):
        """
        F1 score for a specific class label and confidence score threshold.

        Parameters
        ----------
        value : float
            F1 score computed at a specific score threshold.
        score_threshold : float
            Score threshold at which the F1 score is computed.
        hardmax : bool
            Indicates whether hardmax thresholding was used.
        label : str
            The class label for which the F1 score is computed.

        Returns
        -------
        Metric
        """
        return cls(
            type=MetricType.F1.value,
            value=value,
            parameters={
                "score_threshold": score_threshold,
                "hardmax": hardmax,
                "label": label,
            },
        )

    @classmethod
    def accuracy(
        cls,
        value: float,
        score_threshold: float,
        hardmax: bool,
    ):
        """
        Multiclass accuracy metric.

        This class calculates the accuracy at various score thresholds.

        Parameters
        ----------
        value : float
            Accuracy value computed at a specific score threshold.
        score_threshold : float
            Score threshold at which the accuracy value is computed.
        hardmax : bool
            Indicates whether hardmax thresholding was used.

        Returns
        -------
        Metric
        """
        return cls(
            type=MetricType.Accuracy.value,
            value=value,
            parameters={
                "score_threshold": score_threshold,
                "hardmax": hardmax,
            },
        )

    @classmethod
    def roc_auc(
        cls,
        value: float,
        label: str,
    ):
        """
        Receiver Operating Characteristic Area Under the Curve (ROC AUC).

        This class calculates the ROC AUC score for a specific class label in a multiclass classification task.
        ROC AUC is a performance measurement for classification problems at various threshold settings.
        It reflects the ability of the classifier to distinguish between the positive and negative classes.

        Parameters
        ----------
        value : float
            The computed ROC AUC score.
        label : str
            The class label for which the ROC AUC is computed.

        Returns
        -------
        Metric
        """
        return cls(
            type=MetricType.ROCAUC.value,
            value=value,
            parameters={
                "label": label,
            },
        )

    @classmethod
    def mean_roc_auc(cls, value: float):
        """
        Mean Receiver Operating Characteristic Area Under the Curve (mROC AUC).

        This class calculates the mean ROC AUC score over all classes in a multiclass classification task.
        It provides an aggregate measure of the model's ability to distinguish between classes.

        Parameters
        ----------
        value : float
            The computed mean ROC AUC score.

        Returns
        -------
        Metric
        """
        return cls(type=MetricType.mROCAUC.value, value=value, parameters={})

    @classmethod
    def counts(
        cls,
        tp: int,
        fp: int,
        fn: int,
        tn: int,
        score_threshold: float,
        hardmax: bool,
        label: str,
    ):
        """
        Confusion matrix counts at specified score thresholds for binary classification.

        This class stores the true positive (`tp`), false positive (`fp`), false negative (`fn`), and true
        negative (`tn`) counts computed at various score thresholds for a binary classification task.

        Parameters
        ----------
        tp : int
            True positive counts at each score threshold.
        fp : int
            False positive counts at each score threshold.
        fn : int
            False negative counts at each score threshold.
        tn : int
            True negative counts at each score threshold.
        score_threshold : float
            Score thresholds at which the counts are computed.
        hardmax : bool
            Indicates whether hardmax thresholding was used.
        label : str
            The class label for which the counts are computed.

        Returns
        -------
        Metric
        """
        return cls(
            type=MetricType.Counts.value,
            value={
                "tp": tp,
                "fp": fp,
                "fn": fn,
                "tn": tn,
            },
            parameters={
                "score_threshold": score_threshold,
                "hardmax": hardmax,
                "label": label,
            },
        )

    @classmethod
    def confusion_matrix(
        cls,
        confusion_matrix: dict[str, dict[str, int]],
        unmatched_ground_truths: dict[str, int],
        score_threshold: float,
        hardmax: bool,
    ):
        """
        Confusion matrix for object detection task.

        This class encapsulates detailed information about the model's performance, including correct
        predictions, misclassifications and unmatched ground truths (subset of false negatives).

        Confusion Matrix Format:
        {
            <ground truth label>: {
                <prediction label>: 129
                ...
            },
            ...
        }

        Unmatched Ground Truths Format:
        {
            <ground truth label>: 7
            ...
        }

        Parameters
        ----------
        confusion_matrix : dict
            A nested dictionary containing integer counts of occurences where the first key is the ground truth label value
            and the second key is the prediction label value.
        unmatched_ground_truths : dict
            A dictionary where each key is a ground truth label value for which the model failed to predict
            (subset of false negatives). The value is a dictionary containing counts.
        score_threshold : float
            The confidence score threshold used to filter predictions.
        hardmax : bool
            Indicates whether hardmax thresholding was used.

        Returns
        -------
        Metric
        """
        return cls(
            type=MetricType.ConfusionMatrix.value,
            value={
                "confusion_matrix": confusion_matrix,
                "unmatched_ground_truths": unmatched_ground_truths,
            },
            parameters={
                "score_threshold": score_threshold,
                "hardmax": hardmax,
            },
        )

    @classmethod
    def examples(
        cls,
        datum_id: str,
        true_positives: list[str],
        false_positives: list[str],
        false_negatives: list[str],
        score_threshold: float,
        hardmax: bool,
    ):
        """
        Per-datum examples for object detection tasks.

        This metric is per-datum and contains lists of annotation identifiers that categorize them
        as true-positive, false-positive or false-negative. This is intended to be used with an
        external database where the identifiers can be used for retrieval.

        Examples Format:
        {
            "type": "Examples",
            "value": {
                "datum_id": "some string ID",
                "true_positives": [
                    "label A",
                ],
                "false_positives": [
                    "label 25",
                    "label 92",
                    ...
                ]
                "false_negatives": [
                    "groundtruth32",
                    "groundtruth24",
                    ...
                ]
            },
            "parameters": {
                "score_threshold": 0.5,
                "hardmax": False,
            }
        }

        Parameters
        ----------
        datum_id : str
            A string identifier representing a datum.
        true_positives : list[str]
            A list of string identifier representing true positive labels.
        false_positives : list[str]
            A list of string identifiers representing false positive predictions.
        false_negatives : list[str]
            A list of string identifiers representing false negative ground truths.
        score_threshold : float
            The confidence score threshold used to filter predictions.
        hardmax : bool
            Indicates whether hardmax thresholding was used.

        Returns
        -------
        Metric
        """
        return cls(
            type=MetricType.Examples.value,
            value={
                "datum_id": datum_id,
                "true_positives": true_positives,
                "false_positives": false_positives,
                "false_negatives": false_negatives,
            },
            parameters={
                "score_threshold": score_threshold,
                "hardmax": hardmax,
            },
        )

    @classmethod
    def confusion_matrix_with_examples(
        cls,
        confusion_matrix: dict[
            str,  # ground truth label value
            dict[
                str,  # prediction label value
                dict[
                    str,  # either `count` or `examples`
                    int
                    | list[
                        dict[
                            str,  # either `datum` or `score`
                            str | float,  # datum uid  # prediction score
                        ]
                    ],
                ],
            ],
        ],
        unmatched_ground_truths: dict[
            str,  # ground truth label value
            dict[
                str,  # either `count` or `examples`
                int | list[dict[str, str]],  # count or datum examples
            ],
        ],
        score_threshold: float,
        hardmax: bool,
    ):
        """
        The confusion matrix with examples for the classification task.

        This class encapsulates detailed information about the model's performance, including correct
        predictions, misclassifications and unmatched ground truths (subset of false negatives).
        It provides counts and examples for each category to facilitate in-depth analysis.

        Confusion Matrix Structure:
        {
            ground_truth_label: {
                predicted_label: {
                    'count': int,
                    'examples': [
                        {
                            "datum_id": str,
                            "score": float
                        },
                        ...
                    ],
                },
                ...
            },
            ...
        }

        Unmatched Ground Truths Structure:
        {
            ground_truth_label: {
                'count': int,
                'examples': [
                    {
                        "datum_id": str
                    },
                    ...
                ],
            },
            ...
        }

        Parameters
        ----------
        confusion_matrix : dict
            A nested dictionary where the first key is the ground truth label value, the second key
            is the prediction label value, and the innermost dictionary contains either a `count`
            or a list of `examples`. Each example includes the datum UID and prediction score.
        unmatched_ground_truths : dict
            A dictionary where each key is a ground truth label value for which the model failed to predict
            (false negatives). The value is a dictionary containing either a `count` or a list of `examples`.
            Each example includes the datum UID.
        hardmax : bool
            Indicates whether hardmax thresholding was used.

        Returns
        -------
        Metric
        """
        return cls(
            type=MetricType.ConfusionMatrixWithExamples.value,
            value={
                "confusion_matrix": confusion_matrix,
                "unmatched_ground_truths": unmatched_ground_truths,
            },
            parameters={
                "score_threshold": score_threshold,
                "hardmax": hardmax,
            },
        )
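The factory methods above all wrap values that were computed elsewhere into a uniform type / value / parameters record. A minimal sketch of how they fit together, using the import path listed in this diff and assuming the wheel is installed; the counts are made up, and precision, recall, and F1 are derived from them with the standard formulas (precision = tp / (tp + fp), recall = tp / (tp + fn), F1 = 2 * precision * recall / (precision + recall)):

from valor_lite.classification.metric import Metric

# made-up confusion-matrix counts for label "cat" at a 0.5 score threshold
tp, fp, fn, tn = 40, 10, 5, 45

precision = tp / (tp + fp)                          # 0.8
recall = tp / (tp + fn)                             # ~0.889
f1 = 2 * precision * recall / (precision + recall)  # ~0.842

metrics = [
    Metric.counts(
        tp=tp, fp=fp, fn=fn, tn=tn,
        score_threshold=0.5, hardmax=True, label="cat",
    ),
    Metric.precision(value=precision, score_threshold=0.5, hardmax=True, label="cat"),
    Metric.recall(value=recall, score_threshold=0.5, hardmax=True, label="cat"),
    Metric.f1_score(value=f1, score_threshold=0.5, hardmax=True, label="cat"),
]

# every Metric carries the same three fields, e.g.:
# metrics[1].type == "Precision"
# metrics[1].value == 0.8
# metrics[1].parameters == {"score_threshold": 0.5, "hardmax": True, "label": "cat"}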
valor_lite/classification/numpy_compatibility.py
@@ -0,0 +1,13 @@
import numpy as np
from numpy.typing import NDArray

try:
    _numpy_trapezoid = np.trapezoid  # numpy v2
except AttributeError:
    _numpy_trapezoid = np.trapz  # numpy v1


def trapezoid(
    x: NDArray[np.float64], y: NDArray[np.float64], axis: int
) -> NDArray[np.float64]:
    return _numpy_trapezoid(x=x, y=y, axis=axis)  # type: ignore - NumPy compatibility
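This shim exists because NumPy 2 exposes trapezoidal integration as np.trapezoid while NumPy 1 only provides np.trapz; the trapezoid wrapper lets the rest of the package call one name on either major version. A minimal usage sketch, assuming the wheel is installed as valor_lite; the sample curve is made up for illustration:

import numpy as np

from valor_lite.classification.numpy_compatibility import trapezoid

# approximate the area under y = x^2 on [0, 1]; the exact value is 1/3
x = np.linspace(0.0, 1.0, 101)
y = x**2
area = trapezoid(x=x, y=y, axis=0)
print(float(area))  # ~0.33335, trapezoidal rule on a 101-point grid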