valor-lite 0.37.1 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of valor-lite might be problematic.

Files changed (49)
  1. valor_lite/LICENSE +21 -0
  2. valor_lite/__init__.py +0 -0
  3. valor_lite/cache/__init__.py +11 -0
  4. valor_lite/cache/compute.py +154 -0
  5. valor_lite/cache/ephemeral.py +302 -0
  6. valor_lite/cache/persistent.py +529 -0
  7. valor_lite/classification/__init__.py +14 -0
  8. valor_lite/classification/annotation.py +45 -0
  9. valor_lite/classification/computation.py +378 -0
  10. valor_lite/classification/evaluator.py +879 -0
  11. valor_lite/classification/loader.py +97 -0
  12. valor_lite/classification/metric.py +535 -0
  13. valor_lite/classification/numpy_compatibility.py +13 -0
  14. valor_lite/classification/shared.py +184 -0
  15. valor_lite/classification/utilities.py +314 -0
  16. valor_lite/exceptions.py +20 -0
  17. valor_lite/object_detection/__init__.py +17 -0
  18. valor_lite/object_detection/annotation.py +238 -0
  19. valor_lite/object_detection/computation.py +841 -0
  20. valor_lite/object_detection/evaluator.py +805 -0
  21. valor_lite/object_detection/loader.py +292 -0
  22. valor_lite/object_detection/metric.py +850 -0
  23. valor_lite/object_detection/shared.py +185 -0
  24. valor_lite/object_detection/utilities.py +396 -0
  25. valor_lite/schemas.py +11 -0
  26. valor_lite/semantic_segmentation/__init__.py +15 -0
  27. valor_lite/semantic_segmentation/annotation.py +123 -0
  28. valor_lite/semantic_segmentation/computation.py +165 -0
  29. valor_lite/semantic_segmentation/evaluator.py +414 -0
  30. valor_lite/semantic_segmentation/loader.py +205 -0
  31. valor_lite/semantic_segmentation/metric.py +275 -0
  32. valor_lite/semantic_segmentation/shared.py +149 -0
  33. valor_lite/semantic_segmentation/utilities.py +88 -0
  34. valor_lite/text_generation/__init__.py +15 -0
  35. valor_lite/text_generation/annotation.py +56 -0
  36. valor_lite/text_generation/computation.py +611 -0
  37. valor_lite/text_generation/llm/__init__.py +0 -0
  38. valor_lite/text_generation/llm/exceptions.py +14 -0
  39. valor_lite/text_generation/llm/generation.py +903 -0
  40. valor_lite/text_generation/llm/instructions.py +814 -0
  41. valor_lite/text_generation/llm/integrations.py +226 -0
  42. valor_lite/text_generation/llm/utilities.py +43 -0
  43. valor_lite/text_generation/llm/validators.py +68 -0
  44. valor_lite/text_generation/manager.py +697 -0
  45. valor_lite/text_generation/metric.py +381 -0
  46. valor_lite-0.37.1.dist-info/METADATA +174 -0
  47. valor_lite-0.37.1.dist-info/RECORD +49 -0
  48. valor_lite-0.37.1.dist-info/WHEEL +5 -0
  49. valor_lite-0.37.1.dist-info/top_level.txt +1 -0
valor_lite/object_detection/metric.py
@@ -0,0 +1,850 @@
+ from dataclasses import dataclass
+ from enum import Enum
+
+ from valor_lite.schemas import BaseMetric
+
+
+ class MetricType(str, Enum):
+     Counts = "Counts"
+     Precision = "Precision"
+     Recall = "Recall"
+     F1 = "F1"
+     AP = "AP"
+     AR = "AR"
+     mAP = "mAP"
+     mAR = "mAR"
+     APAveragedOverIOUs = "APAveragedOverIOUs"
+     mAPAveragedOverIOUs = "mAPAveragedOverIOUs"
+     ARAveragedOverScores = "ARAveragedOverScores"
+     mARAveragedOverScores = "mARAveragedOverScores"
+     PrecisionRecallCurve = "PrecisionRecallCurve"
+     ConfusionMatrixWithExamples = "ConfusionMatrixWithExamples"
+     ConfusionMatrix = "ConfusionMatrix"
+     Examples = "Examples"
+
+
+ @dataclass
+ class Metric(BaseMetric):
+     """
+     Object Detection Metric.
+
+     Attributes
+     ----------
+     type : str
+         The metric type.
+     value : int | float | dict
+         The metric value.
+     parameters : dict[str, Any]
+         A dictionary containing metric parameters.
+     """
+
+     def __post_init__(self):
+         if not isinstance(self.type, str):
+             raise TypeError(
+                 f"Metric type should be of type 'str': {self.type}"
+             )
+         elif not isinstance(self.value, (int, float, dict)):
+             raise TypeError(
+                 f"Metric value must be of type 'int', 'float' or 'dict': {self.value}"
+             )
+         elif not isinstance(self.parameters, dict):
+             raise TypeError(
+                 f"Metric parameters must be of type 'dict[str, Any]': {self.parameters}"
+             )
+         elif not all([isinstance(k, str) for k in self.parameters.keys()]):
+             raise TypeError(
+                 f"Metric parameter dictionary should only have keys with type 'str': {self.parameters}"
+             )
+
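For illustration only (not part of the packaged file), a minimal sketch of how this validation behaves, assuming `BaseMetric` in `valor_lite/schemas.py` is a plain dataclass exposing the `type`, `value`, and `parameters` fields that `__post_init__` checks:

from valor_lite.object_detection.metric import Metric

# Valid: type is a str, value is numeric, and all parameter keys are str.
metric = Metric(type="Precision", value=0.9, parameters={"label": "cat"})

# Invalid: a non-string type trips the first TypeError branch above.
try:
    Metric(type=123, value=0.9, parameters={})
except TypeError as error:
    print(error)
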
+     @classmethod
+     def precision(
+         cls,
+         value: float,
+         label: str,
+         iou_threshold: float,
+         score_threshold: float,
+     ):
+         """
+         Precision metric for a specific class label in object detection.
+
+         This class encapsulates a metric value for a particular class label,
+         along with the associated Intersection over Union (IOU) threshold and
+         confidence score threshold.
+
+         Parameters
+         ----------
+         value : float
+             The metric value.
+         label : str
+             The class label for which the metric is calculated.
+         iou_threshold : float
+             The IOU threshold used to determine matches between predicted and ground truth boxes.
+         score_threshold : float
+             The confidence score threshold above which predictions are considered.
+
+         Returns
+         -------
+         Metric
+         """
+         return cls(
+             type=MetricType.Precision.value,
+             value=value,
+             parameters={
+                 "label": label,
+                 "iou_threshold": iou_threshold,
+                 "score_threshold": score_threshold,
+             },
+         )
+
+     @classmethod
+     def recall(
+         cls,
+         value: float,
+         label: str,
+         iou_threshold: float,
+         score_threshold: float,
+     ):
+         """
+         Recall metric for a specific class label in object detection.
+
+         This class encapsulates a metric value for a particular class label,
+         along with the associated Intersection over Union (IOU) threshold and
+         confidence score threshold.
+
+         Parameters
+         ----------
+         value : float
+             The metric value.
+         label : str
+             The class label for which the metric is calculated.
+         iou_threshold : float
+             The IOU threshold used to determine matches between predicted and ground truth boxes.
+         score_threshold : float
+             The confidence score threshold above which predictions are considered.
+
+         Returns
+         -------
+         Metric
+         """
+         return cls(
+             type=MetricType.Recall.value,
+             value=value,
+             parameters={
+                 "label": label,
+                 "iou_threshold": iou_threshold,
+                 "score_threshold": score_threshold,
+             },
+         )
+
+     @classmethod
+     def f1_score(
+         cls,
+         value: float,
+         label: str,
+         iou_threshold: float,
+         score_threshold: float,
+     ):
+         """
+         F1 score for a specific class label in object detection.
+
+         This class encapsulates a metric value for a particular class label,
+         along with the associated Intersection over Union (IOU) threshold and
+         confidence score threshold.
+
+         Parameters
+         ----------
+         value : float
+             The metric value.
+         label : str
+             The class label for which the metric is calculated.
+         iou_threshold : float
+             The IOU threshold used to determine matches between predicted and ground truth boxes.
+         score_threshold : float
+             The confidence score threshold above which predictions are considered.
+
+         Returns
+         -------
+         Metric
+         """
+         return cls(
+             type=MetricType.F1.value,
+             value=value,
+             parameters={
+                 "label": label,
+                 "iou_threshold": iou_threshold,
+                 "score_threshold": score_threshold,
+             },
+         )
+
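A short usage sketch for the three factories above (illustrative values, not taken from the package): each one wraps an already computed score together with its label and thresholds.

from valor_lite.object_detection.metric import Metric

metric = Metric.precision(
    value=0.84, label="car", iou_threshold=0.5, score_threshold=0.25
)
assert metric.type == "Precision"
assert metric.parameters == {
    "label": "car",
    "iou_threshold": 0.5,
    "score_threshold": 0.25,
}
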
+     @classmethod
+     def average_precision(
+         cls,
+         value: float,
+         iou_threshold: float,
+         label: str,
+     ):
+         """
+         Average Precision (AP) metric for object detection tasks.
+
+         The AP computation uses 101-point interpolation, which calculates the average
+         precision by interpolating the precision-recall curve at 101 evenly spaced recall
+         levels from 0 to 1.
+
+         Parameters
+         ----------
+         value : float
+             The average precision value.
+         iou_threshold : float
+             The IOU threshold used to compute the AP.
+         label : str
+             The class label for which the AP is computed.
+
+         Returns
+         -------
+         Metric
+         """
+         return cls(
+             type=MetricType.AP.value,
+             value=value,
+             parameters={
+                 "iou_threshold": iou_threshold,
+                 "label": label,
+             },
+         )
+
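A generic sketch of the 101-point interpolation described in the docstring above; the package's real computation lives in valor_lite/object_detection/computation.py and may differ in detail, so treat this only as an illustration of the averaging scheme.

from valor_lite.object_detection.metric import Metric

def interpolated_ap(recalls: list[float], precisions: list[float]) -> float:
    # For each of the 101 recall levels 0.00, 0.01, ..., 1.00, take the best
    # precision achieved at a recall >= that level, then average the results.
    total = 0.0
    for i in range(101):
        level = i / 100
        total += max(
            (p for r, p in zip(recalls, precisions) if r >= level),
            default=0.0,
        )
    return total / 101

ap = interpolated_ap(recalls=[0.2, 0.4, 0.6], precisions=[1.0, 0.8, 0.6])
metric = Metric.average_precision(value=ap, iou_threshold=0.5, label="car")
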
+     @classmethod
+     def mean_average_precision(
+         cls,
+         value: float,
+         iou_threshold: float,
+     ):
+         """
+         Mean Average Precision (mAP) metric for object detection tasks.
+
+         The AP computation uses 101-point interpolation, which calculates the average
+         precision for each class by interpolating the precision-recall curve at 101 evenly
+         spaced recall levels from 0 to 1. The mAP is then calculated by averaging these
+         values across all class labels.
+
+         Parameters
+         ----------
+         value : float
+             The mean average precision value.
+         iou_threshold : float
+             The IOU threshold used to compute the mAP.
+
+         Returns
+         -------
+         Metric
+         """
+         return cls(
+             type=MetricType.mAP.value,
+             value=value,
+             parameters={
+                 "iou_threshold": iou_threshold,
+             },
+         )
+
+     @classmethod
+     def average_precision_averaged_over_IOUs(
+         cls,
+         value: float,
+         iou_thresholds: list[float],
+         label: str,
+     ):
+         """
+         Average Precision (AP) metric averaged over multiple IOU thresholds.
+
+         The AP computation uses 101-point interpolation, which calculates the average precision
+         by interpolating the precision-recall curve at 101 evenly spaced recall levels from 0 to 1
+         for each IOU threshold specified in `iou_thresholds`. The final APAveragedOverIOUs value is
+         obtained by averaging these AP values across all specified IOU thresholds.
+
+         Parameters
+         ----------
+         value : float
+             The average precision value averaged over the specified IOU thresholds.
+         iou_thresholds : list[float]
+             The list of IOU thresholds used to compute the AP values.
+         label : str
+             The class label for which the AP is computed.
+
+         Returns
+         -------
+         Metric
+         """
+         return cls(
+             type=MetricType.APAveragedOverIOUs.value,
+             value=value,
+             parameters={
+                 "iou_thresholds": iou_thresholds,
+                 "label": label,
+             },
+         )
+
+     @classmethod
+     def mean_average_precision_averaged_over_IOUs(
+         cls,
+         value: float,
+         iou_thresholds: list[float],
+     ):
+         """
+         Mean Average Precision (mAP) metric averaged over multiple IOU thresholds.
+
+         The AP computation uses 101-point interpolation, which calculates the average precision
+         by interpolating the precision-recall curve at 101 evenly spaced recall levels from 0 to 1
+         for each IOU threshold specified in `iou_thresholds`. The final mAPAveragedOverIOUs value is
+         obtained by averaging these AP values across all specified IOU thresholds and all class labels.
+
+         Parameters
+         ----------
+         value : float
+             The average precision value averaged over the specified IOU thresholds.
+         iou_thresholds : list[float]
+             The list of IOU thresholds used to compute the AP values.
+
+         Returns
+         -------
+         Metric
+         """
+         return cls(
+             type=MetricType.mAPAveragedOverIOUs.value,
+             value=value,
+             parameters={
+                 "iou_thresholds": iou_thresholds,
+             },
+         )
+
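To make the averaging relationships concrete, a sketch with made-up numbers (illustrative only): mAP is the mean of per-label AP values at one IOU threshold, while APAveragedOverIOUs is the mean of one label's AP values across several IOU thresholds.

from statistics import mean

from valor_lite.object_detection.metric import Metric

ap_by_label = {"car": 0.62, "person": 0.48, "dog": 0.71}
map_at_050 = Metric.mean_average_precision(
    value=mean(ap_by_label.values()), iou_threshold=0.5
)

ap_by_iou = {0.5: 0.62, 0.75: 0.41, 0.9: 0.15}
ap_over_ious = Metric.average_precision_averaged_over_IOUs(
    value=mean(ap_by_iou.values()),
    iou_thresholds=list(ap_by_iou.keys()),
    label="car",
)
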
+     @classmethod
+     def average_recall(
+         cls,
+         value: float,
+         score_threshold: float,
+         iou_thresholds: list[float],
+         label: str,
+     ):
+         """
+         Average Recall (AR) metric for object detection tasks.
+
+         The AR computation considers detections with confidence scores above the specified
+         `score_threshold` and calculates the recall at each IOU threshold in `iou_thresholds`.
+         The final AR value is the average of these recall values across all specified IOU
+         thresholds.
+
+         Parameters
+         ----------
+         value : float
+             The average recall value averaged over the specified IOU thresholds.
+         score_threshold : float
+             The detection score threshold; only detections with confidence scores above this
+             threshold are considered.
+         iou_thresholds : list[float]
+             The list of IOU thresholds used to compute the recall values.
+         label : str
+             The class label for which the AR is computed.
+
+         Returns
+         -------
+         Metric
+         """
+         return cls(
+             type=MetricType.AR.value,
+             value=value,
+             parameters={
+                 "iou_thresholds": iou_thresholds,
+                 "score_threshold": score_threshold,
+                 "label": label,
+             },
+         )
+
+     @classmethod
+     def mean_average_recall(
+         cls,
+         value: float,
+         score_threshold: float,
+         iou_thresholds: list[float],
+     ):
+         """
+         Mean Average Recall (mAR) metric for object detection tasks.
+
+         The mAR computation considers detections with confidence scores above the specified
+         `score_threshold` and calculates recall at each IOU threshold in `iou_thresholds` for
+         each label. The final mAR value is obtained by averaging these recall values over the
+         specified IOU thresholds and then averaging across all labels.
+
+         Parameters
+         ----------
+         value : float
+             The mean average recall value averaged over the specified IOU thresholds.
+         score_threshold : float
+             The detection score threshold; only detections with confidence scores above this
+             threshold are considered.
+         iou_thresholds : list[float]
+             The list of IOU thresholds used to compute the recall values.
+
+         Returns
+         -------
+         Metric
+         """
+         return cls(
+             type=MetricType.mAR.value,
+             value=value,
+             parameters={
+                 "iou_thresholds": iou_thresholds,
+                 "score_threshold": score_threshold,
+             },
+         )
+
+     @classmethod
+     def average_recall_averaged_over_scores(
+         cls,
+         value: float,
+         score_thresholds: list[float],
+         iou_thresholds: list[float],
+         label: str,
+     ):
+         """
+         Average Recall (AR) metric averaged over multiple score thresholds for a specific object class label.
+
+         The AR computation considers detections across multiple `score_thresholds` and calculates
+         recall at each IOU threshold in `iou_thresholds`. The final AR value is obtained by averaging
+         the recall values over all specified score thresholds and IOU thresholds.
+
+         Parameters
+         ----------
+         value : float
+             The average recall value averaged over the specified score thresholds and IOU thresholds.
+         score_thresholds : list[float]
+             The list of detection score thresholds; detections with confidence scores above each threshold are considered.
+         iou_thresholds : list[float]
+             The list of IOU thresholds used to compute the recall values.
+         label : str
+             The class label for which the AR is computed.
+
+         Returns
+         -------
+         Metric
+         """
+         return cls(
+             type=MetricType.ARAveragedOverScores.value,
+             value=value,
+             parameters={
+                 "iou_thresholds": iou_thresholds,
+                 "score_thresholds": score_thresholds,
+                 "label": label,
+             },
+         )
+
+     @classmethod
+     def mean_average_recall_averaged_over_scores(
+         cls,
+         value: float,
+         score_thresholds: list[float],
+         iou_thresholds: list[float],
+     ):
+         """
+         Mean Average Recall (mAR) metric averaged over multiple score thresholds and IOU thresholds.
+
+         The mAR computation considers detections across multiple `score_thresholds`, calculates recall
+         at each IOU threshold in `iou_thresholds` for each label, averages these recall values over all
+         specified score thresholds and IOU thresholds, and then computes the mean across all labels to
+         obtain the final mAR value.
+
+         Parameters
+         ----------
+         value : float
+             The mean average recall value averaged over the specified score thresholds and IOU thresholds.
+         score_thresholds : list[float]
+             The list of detection score thresholds; detections with confidence scores above each threshold are considered.
+         iou_thresholds : list[float]
+             The list of IOU thresholds used to compute the recall values.
+
+         Returns
+         -------
+         Metric
+         """
+         return cls(
+             type=MetricType.mARAveragedOverScores.value,
+             value=value,
+             parameters={
+                 "iou_thresholds": iou_thresholds,
+                 "score_thresholds": score_thresholds,
+             },
+         )
+
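The same pattern applies to the recall family; a sketch with made-up numbers (illustrative only) showing AR as the mean of one label's recall across IOU thresholds at a fixed score threshold:

from statistics import mean

from valor_lite.object_detection.metric import Metric

recall_by_iou = {0.5: 0.80, 0.75: 0.55, 0.9: 0.20}
ar = Metric.average_recall(
    value=mean(recall_by_iou.values()),
    score_threshold=0.25,
    iou_thresholds=list(recall_by_iou.keys()),
    label="car",
)
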
+     @classmethod
+     def precision_recall_curve(
+         cls,
+         precisions: list[float],
+         scores: list[float],
+         iou_threshold: float,
+         label: str,
+     ):
+         """
+         Interpolated precision-recall curve over 101 recall points.
+
+         The precision values are interpolated over recalls ranging from 0.0 to 1.0 in steps of 0.01,
+         resulting in 101 points. This is a byproduct of the 101-point interpolation used in calculating
+         the Average Precision (AP) metric in object detection tasks.
+
+         Parameters
+         ----------
+         precisions : list[float]
+             Interpolated precision values corresponding to recalls at 0.0, 0.01, ..., 1.0.
+         scores : list[float]
+             Maximum prediction score for each point on the interpolated curve.
+         iou_threshold : float
+             The Intersection over Union (IOU) threshold used to determine true positives.
+         label : str
+             The class label associated with this precision-recall curve.
+
+         Returns
+         -------
+         Metric
+         """
+         return cls(
+             type=MetricType.PrecisionRecallCurve.value,
+             value={
+                 "precisions": precisions,
+                 "scores": scores,
+             },
+             parameters={
+                 "iou_threshold": iou_threshold,
+                 "label": label,
+             },
+         )
+
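An illustrative sketch (made-up values) of how the stored curve can be indexed, given the 101-point layout described above where index i corresponds to recall i / 100:

from valor_lite.object_detection.metric import Metric

curve = Metric.precision_recall_curve(
    precisions=[1.0] * 50 + [0.5] * 51,  # 101 interpolated precision values
    scores=[0.9] * 50 + [0.4] * 51,      # 101 matching maximum prediction scores
    iou_threshold=0.5,
    label="car",
)
precision_at_recall_050 = curve.value["precisions"][50]
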
+     @classmethod
+     def counts(
+         cls,
+         tp: int,
+         fp: int,
+         fn: int,
+         label: str,
+         iou_threshold: float,
+         score_threshold: float,
+     ):
+         """
+         `Counts` encapsulates the counts of true positives (`tp`), false positives (`fp`),
+         and false negatives (`fn`) for object detection evaluation, along with the associated
+         class label, Intersection over Union (IOU) threshold, and confidence score threshold.
+
+         Parameters
+         ----------
+         tp : int
+             Number of true positives.
+         fp : int
+             Number of false positives.
+         fn : int
+             Number of false negatives.
+         label : str
+             The class label for which the counts are calculated.
+         iou_threshold : float
+             The IOU threshold used to determine a match between predicted and ground truth boxes.
+         score_threshold : float
+             The confidence score threshold above which predictions are considered.
+
+         Returns
+         -------
+         Metric
+         """
+         return cls(
+             type=MetricType.Counts.value,
+             value={
+                 "tp": tp,
+                 "fp": fp,
+                 "fn": fn,
+             },
+             parameters={
+                 "iou_threshold": iou_threshold,
+                 "score_threshold": score_threshold,
+                 "label": label,
+             },
+         )
+
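A small sketch (made-up counts) showing how the standard precision and recall definitions relate to a Counts metric:

from valor_lite.object_detection.metric import Metric

counts = Metric.counts(
    tp=90, fp=10, fn=30, label="car", iou_threshold=0.5, score_threshold=0.25
)
tp, fp, fn = counts.value["tp"], counts.value["fp"], counts.value["fn"]
precision = tp / (tp + fp)  # 0.9
recall = tp / (tp + fn)     # 0.75
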
+     @classmethod
+     def confusion_matrix(
+         cls,
+         confusion_matrix: dict[str, dict[str, int]],
+         unmatched_predictions: dict[str, int],
+         unmatched_ground_truths: dict[str, int],
+         score_threshold: float,
+         iou_threshold: float,
+     ):
+         """
+         Confusion matrix for object detection tasks.
+
+         This class encapsulates detailed information about the model's performance, including correct
+         predictions, misclassifications, unmatched_predictions (subset of false positives), and unmatched ground truths
+         (subset of false negatives).
+
+         Confusion Matrix Format:
+         {
+             <ground truth label>: {
+                 <prediction label>: 129
+                 ...
+             },
+             ...
+         }
+
+         Unmatched Predictions Format:
+         {
+             <prediction label>: 11
+             ...
+         }
+
+         Unmatched Ground Truths Format:
+         {
+             <ground truth label>: 7
+             ...
+         }
+
+         Parameters
+         ----------
+         confusion_matrix : dict
+             A nested dictionary containing integer counts of occurrences where the first key is the ground truth label value
+             and the second key is the prediction label value.
+         unmatched_predictions : dict
+             A dictionary where each key is a prediction label value with no corresponding ground truth
+             (subset of false positives). The value is an integer count.
+         unmatched_ground_truths : dict
+             A dictionary where each key is a ground truth label value for which the model failed to predict
+             (subset of false negatives). The value is an integer count.
+         score_threshold : float
+             The confidence score threshold used to filter predictions.
+         iou_threshold : float
+             The Intersection over Union (IOU) threshold used to determine true positives.
+
+         Returns
+         -------
+         Metric
+         """
+         return cls(
+             type=MetricType.ConfusionMatrix.value,
+             value={
+                 "confusion_matrix": confusion_matrix,
+                 "unmatched_predictions": unmatched_predictions,
+                 "unmatched_ground_truths": unmatched_ground_truths,
+             },
+             parameters={
+                 "score_threshold": score_threshold,
+                 "iou_threshold": iou_threshold,
+             },
+         )
+
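A sketch (made-up counts) of the payload shapes the docstring describes: plain integer counts keyed by ground truth and prediction label.

from valor_lite.object_detection.metric import Metric

cm = Metric.confusion_matrix(
    confusion_matrix={"car": {"car": 129, "truck": 4}},
    unmatched_predictions={"truck": 11},
    unmatched_ground_truths={"car": 7},
    score_threshold=0.5,
    iou_threshold=0.5,
)
cars_predicted_as_truck = cm.value["confusion_matrix"]["car"]["truck"]  # 4
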
+     @classmethod
+     def examples(
+         cls,
+         datum_id: str,
+         true_positives: list[tuple[str, str]],
+         false_positives: list[str],
+         false_negatives: list[str],
+         score_threshold: float,
+         iou_threshold: float,
+     ):
+         """
+         Per-datum examples for object detection tasks.
+
+         This metric is per-datum and contains lists of annotation identifiers that categorize them
+         as true-positive, false-positive or false-negative. This is intended to be used with an
+         external database where the identifiers can be used for retrieval.
+
+         Examples Format:
+         {
+             "type": "Examples",
+             "value": {
+                 "datum_id": "some string ID",
+                 "true_positives": [
+                     ["groundtruth0", "prediction0"],
+                     ["groundtruth123", "prediction11"],
+                     ...
+                 ],
+                 "false_positives": [
+                     "prediction25",
+                     "prediction92",
+                     ...
+                 ],
+                 "false_negatives": [
+                     "groundtruth32",
+                     "groundtruth24",
+                     ...
+                 ]
+             },
+             "parameters": {
+                 "score_threshold": 0.5,
+                 "iou_threshold": 0.5,
+             }
+         }
+
+         Parameters
+         ----------
+         datum_id : str
+             A string identifier representing a datum.
+         true_positives : list[tuple[str, str]]
+             A list of string identifier pairs representing true positive ground truth and prediction combinations.
+         false_positives : list[str]
+             A list of string identifiers representing false positive predictions.
+         false_negatives : list[str]
+             A list of string identifiers representing false negative ground truths.
+         score_threshold : float
+             The confidence score threshold used to filter predictions.
+         iou_threshold : float
+             The Intersection over Union (IOU) threshold used to determine true positives.
+
+         Returns
+         -------
+         Metric
+         """
+         return cls(
+             type=MetricType.Examples.value,
+             value={
+                 "datum_id": datum_id,
+                 "true_positives": true_positives,
+                 "false_positives": false_positives,
+                 "false_negatives": false_negatives,
+             },
+             parameters={
+                 "score_threshold": score_threshold,
+                 "iou_threshold": iou_threshold,
+             },
+         )
+
+     @classmethod
+     def confusion_matrix_with_examples(
+         cls,
+         confusion_matrix: dict[
+             str,  # ground truth label value
+             dict[
+                 str,  # prediction label value
+                 dict[
+                     str,  # either `count` or `examples`
+                     int
+                     | list[
+                         dict[
+                             str,  # either `datum_id`, `ground_truth_id`, `prediction_id`
+                             str,  # string identifier
+                         ]
+                     ],
+                 ],
+             ],
+         ],
+         unmatched_predictions: dict[
+             str,  # prediction label value
+             dict[
+                 str,  # either `count` or `examples`
+                 int
+                 | list[
+                     dict[
+                         str,  # either `datum_id` or `prediction_id`
+                         str,  # string identifier
+                     ]
+                 ],
+             ],
+         ],
+         unmatched_ground_truths: dict[
+             str,  # ground truth label value
+             dict[
+                 str,  # either `count` or `examples`
+                 int
+                 | list[
+                     dict[
+                         str,  # either `datum_id` or `ground_truth_id`
+                         str,  # string identifier
+                     ]
+                 ],
+             ],
+         ],
+         score_threshold: float,
+         iou_threshold: float,
+     ):
+         """
+         Confusion matrix with examples for object detection tasks.
+
+         This class encapsulates detailed information about the model's performance, including correct
+         predictions, misclassifications, unmatched_predictions (subset of false positives), and unmatched ground truths
+         (subset of false negatives). It provides counts and examples for each category to facilitate in-depth analysis.
+
+         Confusion Matrix Format:
+         {
+             <ground truth label>: {
+                 <prediction label>: {
+                     'count': int,
+                     'examples': [
+                         {
+                             'datum_id': str,
+                             'groundtruth_id': str,
+                             'prediction_id': str
+                         },
+                         ...
+                     ],
+                 },
+                 ...
+             },
+             ...
+         }
+
+         Unmatched Predictions Format:
+         {
+             <prediction label>: {
+                 'count': int,
+                 'examples': [
+                     {
+                         'datum_id': str,
+                         'prediction_id': str
+                     },
+                     ...
+                 ],
+             },
+             ...
+         }
+
+         Unmatched Ground Truths Format:
+         {
+             <ground truth label>: {
+                 'count': int,
+                 'examples': [
+                     {
+                         'datum_id': str,
+                         'groundtruth_id': str
+                     },
+                     ...
+                 ],
+             },
+             ...
+         }
+
+         Parameters
+         ----------
+         confusion_matrix : dict
+             A nested dictionary where the first key is the ground truth label value, the second key
+             is the prediction label value, and the innermost dictionary contains either a `count`
+             or a list of `examples`. Each example includes annotation and datum identifiers.
+         unmatched_predictions : dict
+             A dictionary where each key is a prediction label value with no corresponding ground truth
+             (subset of false positives). The value is a dictionary containing either a `count` or a list of
+             `examples`. Each example includes annotation and datum identifiers.
+         unmatched_ground_truths : dict
+             A dictionary where each key is a ground truth label value for which the model failed to predict
+             (subset of false negatives). The value is a dictionary containing either a `count` or a list of `examples`.
+             Each example includes annotation and datum identifiers.
+         score_threshold : float
+             The confidence score threshold used to filter predictions.
+         iou_threshold : float
+             The Intersection over Union (IOU) threshold used to determine true positives.
+
+         Returns
+         -------
+         Metric
+         """
+         return cls(
+             type=MetricType.ConfusionMatrixWithExamples.value,
+             value={
+                 "confusion_matrix": confusion_matrix,
+                 "unmatched_predictions": unmatched_predictions,
+                 "unmatched_ground_truths": unmatched_ground_truths,
+             },
+             parameters={
+                 "score_threshold": score_threshold,
+                 "iou_threshold": iou_threshold,
+             },
+         )
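
Finally, a note on consuming any metric built by these factories: `__post_init__` guarantees that `type` is a string, `value` is an int, float, or dict, and `parameters` is a dict with string keys. Assuming `BaseMetric` is a plain dataclass (as the field access in `__post_init__` suggests), a metric can therefore be serialized directly; a sketch:

import dataclasses
import json

from valor_lite.object_detection.metric import Metric

metric = Metric.f1_score(
    value=0.77, label="car", iou_threshold=0.5, score_threshold=0.25
)
payload = json.dumps(dataclasses.asdict(metric))
# e.g. '{"type": "F1", "value": 0.77, "parameters": {"label": "car", ...}}'
# (exact field order depends on how BaseMetric declares its fields)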