valor-lite 0.37.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of valor-lite might be problematic. Click here for more details.
- valor_lite/LICENSE +21 -0
- valor_lite/__init__.py +0 -0
- valor_lite/cache/__init__.py +11 -0
- valor_lite/cache/compute.py +154 -0
- valor_lite/cache/ephemeral.py +302 -0
- valor_lite/cache/persistent.py +529 -0
- valor_lite/classification/__init__.py +14 -0
- valor_lite/classification/annotation.py +45 -0
- valor_lite/classification/computation.py +378 -0
- valor_lite/classification/evaluator.py +879 -0
- valor_lite/classification/loader.py +97 -0
- valor_lite/classification/metric.py +535 -0
- valor_lite/classification/numpy_compatibility.py +13 -0
- valor_lite/classification/shared.py +184 -0
- valor_lite/classification/utilities.py +314 -0
- valor_lite/exceptions.py +20 -0
- valor_lite/object_detection/__init__.py +17 -0
- valor_lite/object_detection/annotation.py +238 -0
- valor_lite/object_detection/computation.py +841 -0
- valor_lite/object_detection/evaluator.py +805 -0
- valor_lite/object_detection/loader.py +292 -0
- valor_lite/object_detection/metric.py +850 -0
- valor_lite/object_detection/shared.py +185 -0
- valor_lite/object_detection/utilities.py +396 -0
- valor_lite/schemas.py +11 -0
- valor_lite/semantic_segmentation/__init__.py +15 -0
- valor_lite/semantic_segmentation/annotation.py +123 -0
- valor_lite/semantic_segmentation/computation.py +165 -0
- valor_lite/semantic_segmentation/evaluator.py +414 -0
- valor_lite/semantic_segmentation/loader.py +205 -0
- valor_lite/semantic_segmentation/metric.py +275 -0
- valor_lite/semantic_segmentation/shared.py +149 -0
- valor_lite/semantic_segmentation/utilities.py +88 -0
- valor_lite/text_generation/__init__.py +15 -0
- valor_lite/text_generation/annotation.py +56 -0
- valor_lite/text_generation/computation.py +611 -0
- valor_lite/text_generation/llm/__init__.py +0 -0
- valor_lite/text_generation/llm/exceptions.py +14 -0
- valor_lite/text_generation/llm/generation.py +903 -0
- valor_lite/text_generation/llm/instructions.py +814 -0
- valor_lite/text_generation/llm/integrations.py +226 -0
- valor_lite/text_generation/llm/utilities.py +43 -0
- valor_lite/text_generation/llm/validators.py +68 -0
- valor_lite/text_generation/manager.py +697 -0
- valor_lite/text_generation/metric.py +381 -0
- valor_lite-0.37.1.dist-info/METADATA +174 -0
- valor_lite-0.37.1.dist-info/RECORD +49 -0
- valor_lite-0.37.1.dist-info/WHEEL +5 -0
- valor_lite-0.37.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,850 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from enum import Enum
|
|
3
|
+
|
|
4
|
+
from valor_lite.schemas import BaseMetric
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class MetricType(str, Enum):
    """
    Identifiers for every object detection metric built by ``Metric``.

    Each member's value is identical to its name. Because the enum mixes in
    ``str``, every member also compares equal to that plain string.
    """

    # Per-label counts and threshold-dependent scores.
    Counts = "Counts"
    Precision = "Precision"
    Recall = "Recall"
    F1 = "F1"

    # Average precision / recall and their means across labels.
    AP = "AP"
    AR = "AR"
    mAP = "mAP"
    mAR = "mAR"

    # Variants averaged over several IOU thresholds or score thresholds.
    APAveragedOverIOUs = "APAveragedOverIOUs"
    mAPAveragedOverIOUs = "mAPAveragedOverIOUs"
    ARAveragedOverScores = "ARAveragedOverScores"
    mARAveragedOverScores = "mARAveragedOverScores"

    # Curves, confusion matrices and per-datum examples.
    PrecisionRecallCurve = "PrecisionRecallCurve"
    ConfusionMatrixWithExamples = "ConfusionMatrixWithExamples"
    ConfusionMatrix = "ConfusionMatrix"
    Examples = "Examples"
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@dataclass
class Metric(BaseMetric):
    """
    Object Detection Metric.

    The ``type``, ``value`` and ``parameters`` fields are not declared here but
    on the base class. This subclass validates them after initialization and
    provides one alternate constructor (classmethod) per ``MetricType``.

    Attributes
    ----------
    type : str
        The metric type.
    value : int | float | dict
        The metric value.
    parameters : dict[str, Any]
        A dictionary containing metric parameters.
    """

    def __post_init__(self):
        """
        Validate field types after dataclass initialization.

        Raises
        ------
        TypeError
            If ``type`` is not a ``str``, if ``value`` is not an ``int``,
            ``float`` or ``dict``, if ``parameters`` is not a ``dict``, or if
            any key of ``parameters`` is not a ``str``.
        """
        if not isinstance(self.type, str):
            raise TypeError(
                f"Metric type should be of type 'str': {self.type}"
            )
        elif not isinstance(self.value, (int, float, dict)):
            raise TypeError(
                f"Metric value must be of type 'int', 'float' or 'dict': {self.value}"
            )
        elif not isinstance(self.parameters, dict):
            raise TypeError(
                f"Metric parameters must be of type 'dict[str, Any]': {self.parameters}"
            )
        # Iterate the dict directly with a generator so all() can stop at the
        # first non-str key instead of materializing a throwaway list.
        elif not all(isinstance(k, str) for k in self.parameters):
            raise TypeError(
                f"Metric parameter dictionary should only have keys with type 'str': {self.parameters}"
            )

    @classmethod
    def precision(
        cls,
        value: float,
        label: str,
        iou_threshold: float,
        score_threshold: float,
    ):
        """
        Build a Precision metric for a specific class label.

        The value is stored together with the label, the Intersection over
        Union (IOU) threshold and the confidence score threshold it refers to.

        Parameters
        ----------
        value : float
            The metric value.
        label : str
            The class label for which the metric is calculated.
        iou_threshold : float
            The IOU threshold used to determine matches between predicted and ground truth boxes.
        score_threshold : float
            The confidence score threshold used to filter predictions.

        Returns
        -------
        Metric
            A metric of type ``"Precision"`` whose parameters hold ``label``,
            ``iou_threshold`` and ``score_threshold``.
        """
        return cls(
            type=MetricType.Precision.value,
            value=value,
            parameters={
                "label": label,
                "iou_threshold": iou_threshold,
                "score_threshold": score_threshold,
            },
        )

    @classmethod
    def recall(
        cls,
        value: float,
        label: str,
        iou_threshold: float,
        score_threshold: float,
    ):
        """
        Build a Recall metric for a specific class label.

        The value is stored together with the label, the Intersection over
        Union (IOU) threshold and the confidence score threshold it refers to.

        Parameters
        ----------
        value : float
            The metric value.
        label : str
            The class label for which the metric is calculated.
        iou_threshold : float
            The IOU threshold used to determine matches between predicted and ground truth boxes.
        score_threshold : float
            The confidence score threshold used to filter predictions.

        Returns
        -------
        Metric
            A metric of type ``"Recall"`` whose parameters hold ``label``,
            ``iou_threshold`` and ``score_threshold``.
        """
        return cls(
            type=MetricType.Recall.value,
            value=value,
            parameters={
                "label": label,
                "iou_threshold": iou_threshold,
                "score_threshold": score_threshold,
            },
        )

    @classmethod
    def f1_score(
        cls,
        value: float,
        label: str,
        iou_threshold: float,
        score_threshold: float,
    ):
        """
        Build an F1 score metric for a specific class label.

        The value is stored together with the label, the Intersection over
        Union (IOU) threshold and the confidence score threshold it refers to.

        Parameters
        ----------
        value : float
            The metric value.
        label : str
            The class label for which the metric is calculated.
        iou_threshold : float
            The IOU threshold used to determine matches between predicted and ground truth boxes.
        score_threshold : float
            The confidence score threshold used to filter predictions.

        Returns
        -------
        Metric
            A metric of type ``"F1"`` whose parameters hold ``label``,
            ``iou_threshold`` and ``score_threshold``.
        """
        return cls(
            type=MetricType.F1.value,
            value=value,
            parameters={
                "label": label,
                "iou_threshold": iou_threshold,
                "score_threshold": score_threshold,
            },
        )

    @classmethod
    def average_precision(
        cls,
        value: float,
        iou_threshold: float,
        label: str,
    ):
        """
        Build an Average Precision (AP) metric for a specific class label.

        The AP computation uses 101-point interpolation, which calculates the
        average precision by interpolating the precision-recall curve at 101
        evenly spaced recall levels from 0 to 1.

        Parameters
        ----------
        value : float
            The average precision value.
        iou_threshold : float
            The IOU threshold used to compute the AP.
        label : str
            The class label for which the AP is computed.

        Returns
        -------
        Metric
            A metric of type ``"AP"`` whose parameters hold ``iou_threshold``
            and ``label``.
        """
        return cls(
            type=MetricType.AP.value,
            value=value,
            parameters={
                "iou_threshold": iou_threshold,
                "label": label,
            },
        )

    @classmethod
    def mean_average_precision(
        cls,
        value: float,
        iou_threshold: float,
    ):
        """
        Build a Mean Average Precision (mAP) metric.

        AP is computed per class with 101-point interpolation of the
        precision-recall curve (recall levels 0.00, 0.01, ..., 1.00). The mAP
        is the average of these per-class AP values.

        Parameters
        ----------
        value : float
            The mean average precision value.
        iou_threshold : float
            The IOU threshold used to compute the mAP.

        Returns
        -------
        Metric
            A metric of type ``"mAP"`` whose parameters hold ``iou_threshold``.
        """
        return cls(
            type=MetricType.mAP.value,
            value=value,
            parameters={
                "iou_threshold": iou_threshold,
            },
        )

    @classmethod
    def average_precision_averaged_over_IOUs(
        cls,
        value: float,
        iou_thresholds: list[float],
        label: str,
    ):
        """
        Build an Average Precision (AP) metric averaged over multiple IOU thresholds.

        AP is computed with 101-point interpolation for each IOU threshold in
        ``iou_thresholds``. The APAveragedOverIOUs value is the average of
        those AP values across all of the listed thresholds.

        Parameters
        ----------
        value : float
            The average precision value averaged over the specified IOU thresholds.
        iou_thresholds : list[float]
            The list of IOU thresholds used to compute the AP values.
        label : str
            The class label for which the AP is computed.

        Returns
        -------
        Metric
            A metric of type ``"APAveragedOverIOUs"`` whose parameters hold
            ``iou_thresholds`` and ``label``.
        """
        return cls(
            type=MetricType.APAveragedOverIOUs.value,
            value=value,
            parameters={
                "iou_thresholds": iou_thresholds,
                "label": label,
            },
        )

    @classmethod
    def mean_average_precision_averaged_over_IOUs(
        cls,
        value: float,
        iou_thresholds: list[float],
    ):
        """
        Build a Mean Average Precision (mAP) metric averaged over multiple IOU thresholds.

        AP is computed with 101-point interpolation for each IOU threshold in
        ``iou_thresholds``. The mAPAveragedOverIOUs value is the average of
        those AP values across all listed thresholds and all class labels.

        Parameters
        ----------
        value : float
            The mean average precision value averaged over the specified IOU
            thresholds and all class labels.
        iou_thresholds : list[float]
            The list of IOU thresholds used to compute the AP values.

        Returns
        -------
        Metric
            A metric of type ``"mAPAveragedOverIOUs"`` whose parameters hold
            ``iou_thresholds``.
        """
        return cls(
            type=MetricType.mAPAveragedOverIOUs.value,
            value=value,
            parameters={
                "iou_thresholds": iou_thresholds,
            },
        )

    @classmethod
    def average_recall(
        cls,
        value: float,
        score_threshold: float,
        iou_thresholds: list[float],
        label: str,
    ):
        """
        Build an Average Recall (AR) metric for a specific class label.

        The AR computation considers detections with confidence scores above
        ``score_threshold`` and calculates the recall at each IOU threshold in
        ``iou_thresholds``. The AR value is the average of these recall values
        across all of the listed IOU thresholds.

        Parameters
        ----------
        value : float
            The average recall value averaged over the specified IOU thresholds.
        score_threshold : float
            The detection score threshold. Only detections with confidence
            scores above this threshold are considered.
        iou_thresholds : list[float]
            The list of IOU thresholds used to compute the recall values.
        label : str
            The class label for which the AR is computed.

        Returns
        -------
        Metric
            A metric of type ``"AR"`` whose parameters hold ``iou_thresholds``,
            ``score_threshold`` and ``label``.
        """
        return cls(
            type=MetricType.AR.value,
            value=value,
            parameters={
                "iou_thresholds": iou_thresholds,
                "score_threshold": score_threshold,
                "label": label,
            },
        )

    @classmethod
    def mean_average_recall(
        cls,
        value: float,
        score_threshold: float,
        iou_thresholds: list[float],
    ):
        """
        Build a Mean Average Recall (mAR) metric.

        Recall is calculated per label at each IOU threshold in
        ``iou_thresholds``, considering detections with confidence scores above
        ``score_threshold``. Those recalls are averaged over the IOU
        thresholds and then across all labels.

        Parameters
        ----------
        value : float
            The mean average recall value averaged over the specified IOU thresholds.
        score_threshold : float
            The detection score threshold. Only detections with confidence
            scores above this threshold are considered.
        iou_thresholds : list[float]
            The list of IOU thresholds used to compute the recall values.

        Returns
        -------
        Metric
            A metric of type ``"mAR"`` whose parameters hold ``iou_thresholds``
            and ``score_threshold``.
        """
        return cls(
            type=MetricType.mAR.value,
            value=value,
            parameters={
                "iou_thresholds": iou_thresholds,
                "score_threshold": score_threshold,
            },
        )

    @classmethod
    def average_recall_averaged_over_scores(
        cls,
        value: float,
        score_thresholds: list[float],
        iou_thresholds: list[float],
        label: str,
    ):
        """
        Build an Average Recall (AR) metric averaged over multiple score thresholds.

        Recall is calculated for one class label at each IOU threshold in
        ``iou_thresholds`` and for each threshold in ``score_thresholds``. The
        value is the average over all of those score and IOU thresholds.

        Parameters
        ----------
        value : float
            The average recall value averaged over the specified score thresholds and IOU thresholds.
        score_thresholds : list[float]
            The list of detection score thresholds. For each threshold, only
            detections with confidence scores above it are considered.
        iou_thresholds : list[float]
            The list of IOU thresholds used to compute the recall values.
        label : str
            The class label for which the AR is computed.

        Returns
        -------
        Metric
            A metric of type ``"ARAveragedOverScores"`` whose parameters hold
            ``iou_thresholds``, ``score_thresholds`` and ``label``.
        """
        return cls(
            type=MetricType.ARAveragedOverScores.value,
            value=value,
            parameters={
                "iou_thresholds": iou_thresholds,
                "score_thresholds": score_thresholds,
                "label": label,
            },
        )

    @classmethod
    def mean_average_recall_averaged_over_scores(
        cls,
        value: float,
        score_thresholds: list[float],
        iou_thresholds: list[float],
    ):
        """
        Build a Mean Average Recall (mAR) metric averaged over multiple score and IOU thresholds.

        Recall is calculated per label at each IOU threshold in
        ``iou_thresholds`` and for each threshold in ``score_thresholds``. Those
        recalls are averaged over the score and IOU thresholds, and the mean is
        then taken across all labels.

        Parameters
        ----------
        value : float
            The mean average recall value averaged over the specified score thresholds and IOU thresholds.
        score_thresholds : list[float]
            The list of detection score thresholds. For each threshold, only
            detections with confidence scores above it are considered.
        iou_thresholds : list[float]
            The list of IOU thresholds used to compute the recall values.

        Returns
        -------
        Metric
            A metric of type ``"mARAveragedOverScores"`` whose parameters hold
            ``iou_thresholds`` and ``score_thresholds``.
        """
        return cls(
            type=MetricType.mARAveragedOverScores.value,
            value=value,
            parameters={
                "iou_thresholds": iou_thresholds,
                "score_thresholds": score_thresholds,
            },
        )

    @classmethod
    def precision_recall_curve(
        cls,
        precisions: list[float],
        scores: list[float],
        iou_threshold: float,
        label: str,
    ):
        """
        Build an interpolated precision-recall curve over 101 recall points.

        The precision values are interpolated over recalls from 0.0 to 1.0 in
        steps of 0.01, which gives 101 points. The curve is a byproduct of the
        101-point interpolation used to calculate Average Precision (AP).

        Parameters
        ----------
        precisions : list[float]
            Interpolated precision values corresponding to recalls at 0.0, 0.01, ..., 1.0.
        scores : list[float]
            Maximum prediction score for each point on the interpolated curve.
        iou_threshold : float
            The Intersection over Union (IOU) threshold used to determine true positives.
        label : str
            The class label associated with this precision-recall curve.

        Returns
        -------
        Metric
            A metric of type ``"PrecisionRecallCurve"`` whose value is
            ``{"precisions": ..., "scores": ...}`` and whose parameters hold
            ``iou_threshold`` and ``label``.
        """
        return cls(
            type=MetricType.PrecisionRecallCurve.value,
            value={
                "precisions": precisions,
                "scores": scores,
            },
            parameters={
                "iou_threshold": iou_threshold,
                "label": label,
            },
        )

    @classmethod
    def counts(
        cls,
        tp: int,
        fp: int,
        fn: int,
        label: str,
        iou_threshold: float,
        score_threshold: float,
    ):
        """
        Build a Counts metric holding true positive, false positive and false negative counts.

        The counts are stored together with the class label, the Intersection
        over Union (IOU) threshold and the confidence score threshold.

        Parameters
        ----------
        tp : int
            Number of true positives.
        fp : int
            Number of false positives.
        fn : int
            Number of false negatives.
        label : str
            The class label for which the counts are calculated.
        iou_threshold : float
            The IOU threshold used to determine a match between predicted and ground truth boxes.
        score_threshold : float
            The confidence score threshold used to filter predictions.

        Returns
        -------
        Metric
            A metric of type ``"Counts"`` whose value is
            ``{"tp": ..., "fp": ..., "fn": ...}`` and whose parameters hold
            ``iou_threshold``, ``score_threshold`` and ``label``.
        """
        return cls(
            type=MetricType.Counts.value,
            value={
                "tp": tp,
                "fp": fp,
                "fn": fn,
            },
            parameters={
                "iou_threshold": iou_threshold,
                "score_threshold": score_threshold,
                "label": label,
            },
        )

    @classmethod
    def confusion_matrix(
        cls,
        confusion_matrix: dict[str, dict[str, int]],
        unmatched_predictions: dict[str, int],
        unmatched_ground_truths: dict[str, int],
        score_threshold: float,
        iou_threshold: float,
    ):
        """
        Build a confusion matrix metric for object detection.

        The metric records correct predictions and misclassifications. It also
        records unmatched predictions (a subset of false positives) and
        unmatched ground truths (a subset of false negatives).

        Confusion Matrix Format:
        {
            <ground truth label>: {
                <prediction label>: 129
                ...
            },
            ...
        }

        Unmatched Predictions Format:
        {
            <prediction label>: 11
            ...
        }

        Unmatched Ground Truths Format:
        {
            <ground truth label>: 7
            ...
        }

        Parameters
        ----------
        confusion_matrix : dict[str, dict[str, int]]
            A nested dictionary of integer occurrence counts. The first key is
            the ground truth label value and the second key is the prediction
            label value.
        unmatched_predictions : dict[str, int]
            Maps each prediction label value that has no corresponding ground
            truth (subset of false positives) to its integer count.
        unmatched_ground_truths : dict[str, int]
            Maps each ground truth label value that the model failed to predict
            (subset of false negatives) to its integer count.
        score_threshold : float
            The confidence score threshold used to filter predictions.
        iou_threshold : float
            The Intersection over Union (IOU) threshold used to determine true positives.

        Returns
        -------
        Metric
            A metric of type ``"ConfusionMatrix"`` whose value holds the three
            dictionaries and whose parameters hold ``score_threshold`` and
            ``iou_threshold``.
        """
        return cls(
            type=MetricType.ConfusionMatrix.value,
            value={
                "confusion_matrix": confusion_matrix,
                "unmatched_predictions": unmatched_predictions,
                "unmatched_ground_truths": unmatched_ground_truths,
            },
            parameters={
                "score_threshold": score_threshold,
                "iou_threshold": iou_threshold,
            },
        )

    @classmethod
    def examples(
        cls,
        datum_id: str,
        true_positives: list[tuple[str, str]],
        false_positives: list[str],
        false_negatives: list[str],
        score_threshold: float,
        iou_threshold: float,
    ):
        """
        Build a per-datum examples metric for object detection.

        For a single datum, the metric lists the annotation identifiers that
        were categorized as true positive, false positive or false negative.
        It is intended for use with an external database, where these
        identifiers can be used to retrieve the annotations.

        Examples Format:
        {
            "type": "Examples",
            "value": {
                "datum_id": "some string ID",
                "true_positives": [
                    ["groundtruth0", "prediction0"],
                    ["groundtruth123", "prediction11"],
                    ...
                ],
                "false_positives": [
                    "prediction25",
                    "prediction92",
                    ...
                ],
                "false_negatives": [
                    "groundtruth32",
                    "groundtruth24",
                    ...
                ]
            },
            "parameters": {
                "score_threshold": 0.5,
                "iou_threshold": 0.5,
            }
        }

        Parameters
        ----------
        datum_id : str
            A string identifier representing a datum.
        true_positives : list[tuple[str, str]]
            A list of (ground truth id, prediction id) pairs that are true positive matches.
        false_positives : list[str]
            A list of string identifiers representing false positive predictions.
        false_negatives : list[str]
            A list of string identifiers representing false negative ground truths.
        score_threshold : float
            The confidence score threshold used to filter predictions.
        iou_threshold : float
            The Intersection over Union (IOU) threshold used to determine true positives.

        Returns
        -------
        Metric
            A metric of type ``"Examples"`` whose value holds the datum id and
            the three identifier lists.
        """
        return cls(
            type=MetricType.Examples.value,
            value={
                "datum_id": datum_id,
                "true_positives": true_positives,
                "false_positives": false_positives,
                "false_negatives": false_negatives,
            },
            parameters={
                "score_threshold": score_threshold,
                "iou_threshold": iou_threshold,
            },
        )

    @classmethod
    def confusion_matrix_with_examples(
        cls,
        confusion_matrix: dict[
            str,  # ground truth label value
            dict[
                str,  # prediction label value
                dict[
                    str,  # either `count` or `examples`
                    int
                    | list[
                        dict[
                            str,  # either `datum_id`, `ground_truth_id`, `prediction_id`
                            str,  # string identifier
                        ]
                    ],
                ],
            ],
        ],
        unmatched_predictions: dict[
            str,  # prediction label value
            dict[
                str,  # either `count` or `examples`
                int
                | list[
                    dict[
                        str,  # either `datum_id` or `prediction_id`
                        str,  # string identifier
                    ]
                ],
            ],
        ],
        unmatched_ground_truths: dict[
            str,  # ground truth label value
            dict[
                str,  # either `count` or `examples`
                int
                | list[
                    dict[
                        str,  # either `datum_id` or `ground_truth_id`
                        str,  # string identifier
                    ]
                ],
            ],
        ],
        score_threshold: float,
        iou_threshold: float,
    ):
        """
        Build a confusion matrix metric that also carries examples.

        The metric records correct predictions and misclassifications. It also
        records unmatched predictions (a subset of false positives) and
        unmatched ground truths (a subset of false negatives). Each category
        carries a count and example identifiers for in-depth analysis.

        NOTE(review): the format blocks below spell the ground truth key
        ``groundtruth_id``, while the type comments in the signature spell it
        ``ground_truth_id``. Confirm the actual key against the code that
        produces these dictionaries.

        Confusion Matrix Format:
        {
            <ground truth label>: {
                <prediction label>: {
                    'count': int,
                    'examples': [
                        {
                            'datum_id': str,
                            'groundtruth_id': str,
                            'prediction_id': str
                        },
                        ...
                    ],
                },
                ...
            },
            ...
        }

        Unmatched Predictions Format:
        {
            <prediction label>: {
                'count': int,
                'examples': [
                    {
                        'datum_id': str,
                        'prediction_id': str
                    },
                    ...
                ],
            },
            ...
        }

        Unmatched Ground Truths Format:
        {
            <ground truth label>: {
                'count': int,
                'examples': [
                    {
                        'datum_id': str,
                        'groundtruth_id': str
                    },
                    ...
                ],
            },
            ...
        }

        Parameters
        ----------
        confusion_matrix : dict
            A nested dictionary. The first key is the ground truth label value
            and the second key is the prediction label value. The innermost
            dictionary holds a ``count`` and a list of ``examples``, and each
            example includes annotation and datum identifiers.
        unmatched_predictions : dict
            Maps each prediction label value that has no corresponding ground
            truth (subset of false positives) to a dictionary holding a
            ``count`` and a list of ``examples``. Each example includes
            annotation and datum identifiers.
        unmatched_ground_truths : dict
            Maps each ground truth label value that the model failed to predict
            (subset of false negatives) to a dictionary holding a ``count`` and
            a list of ``examples``. Each example includes annotation and datum
            identifiers.
        score_threshold : float
            The confidence score threshold used to filter predictions.
        iou_threshold : float
            The Intersection over Union (IOU) threshold used to determine true positives.

        Returns
        -------
        Metric
            A metric of type ``"ConfusionMatrixWithExamples"`` whose value holds
            the three dictionaries and whose parameters hold ``score_threshold``
            and ``iou_threshold``.
        """
        return cls(
            type=MetricType.ConfusionMatrixWithExamples.value,
            value={
                "confusion_matrix": confusion_matrix,
                "unmatched_predictions": unmatched_predictions,
                "unmatched_ground_truths": unmatched_ground_truths,
            },
            parameters={
                "score_threshold": score_threshold,
                "iou_threshold": iou_threshold,
            },
        )
|