eye-cv 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (94):
  1. eye/__init__.py +115 -0
  2. eye/__init___supervision_original.py +120 -0
  3. eye/annotators/__init__.py +0 -0
  4. eye/annotators/base.py +22 -0
  5. eye/annotators/core.py +2699 -0
  6. eye/annotators/line.py +107 -0
  7. eye/annotators/modern.py +529 -0
  8. eye/annotators/trace.py +142 -0
  9. eye/annotators/utils.py +177 -0
  10. eye/assets/__init__.py +2 -0
  11. eye/assets/downloader.py +95 -0
  12. eye/assets/list.py +83 -0
  13. eye/classification/__init__.py +0 -0
  14. eye/classification/core.py +188 -0
  15. eye/config.py +2 -0
  16. eye/core/__init__.py +0 -0
  17. eye/core/trackers/__init__.py +1 -0
  18. eye/core/trackers/botsort_tracker.py +336 -0
  19. eye/core/trackers/bytetrack_tracker.py +284 -0
  20. eye/core/trackers/sort_tracker.py +200 -0
  21. eye/core/tracking.py +146 -0
  22. eye/dataset/__init__.py +0 -0
  23. eye/dataset/core.py +919 -0
  24. eye/dataset/formats/__init__.py +0 -0
  25. eye/dataset/formats/coco.py +258 -0
  26. eye/dataset/formats/pascal_voc.py +279 -0
  27. eye/dataset/formats/yolo.py +272 -0
  28. eye/dataset/utils.py +259 -0
  29. eye/detection/__init__.py +0 -0
  30. eye/detection/auto_convert.py +155 -0
  31. eye/detection/core.py +1529 -0
  32. eye/detection/detections_enhanced.py +392 -0
  33. eye/detection/line_zone.py +859 -0
  34. eye/detection/lmm.py +184 -0
  35. eye/detection/overlap_filter.py +270 -0
  36. eye/detection/tools/__init__.py +0 -0
  37. eye/detection/tools/csv_sink.py +181 -0
  38. eye/detection/tools/inference_slicer.py +288 -0
  39. eye/detection/tools/json_sink.py +142 -0
  40. eye/detection/tools/polygon_zone.py +202 -0
  41. eye/detection/tools/smoother.py +123 -0
  42. eye/detection/tools/smoothing.py +179 -0
  43. eye/detection/tools/smoothing_config.py +202 -0
  44. eye/detection/tools/transformers.py +247 -0
  45. eye/detection/utils.py +1175 -0
  46. eye/draw/__init__.py +0 -0
  47. eye/draw/color.py +154 -0
  48. eye/draw/utils.py +374 -0
  49. eye/filters.py +112 -0
  50. eye/geometry/__init__.py +0 -0
  51. eye/geometry/core.py +128 -0
  52. eye/geometry/utils.py +47 -0
  53. eye/keypoint/__init__.py +0 -0
  54. eye/keypoint/annotators.py +442 -0
  55. eye/keypoint/core.py +687 -0
  56. eye/keypoint/skeletons.py +2647 -0
  57. eye/metrics/__init__.py +21 -0
  58. eye/metrics/core.py +72 -0
  59. eye/metrics/detection.py +843 -0
  60. eye/metrics/f1_score.py +648 -0
  61. eye/metrics/mean_average_precision.py +628 -0
  62. eye/metrics/mean_average_recall.py +697 -0
  63. eye/metrics/precision.py +653 -0
  64. eye/metrics/recall.py +652 -0
  65. eye/metrics/utils/__init__.py +0 -0
  66. eye/metrics/utils/object_size.py +158 -0
  67. eye/metrics/utils/utils.py +9 -0
  68. eye/py.typed +0 -0
  69. eye/quick.py +104 -0
  70. eye/tracker/__init__.py +0 -0
  71. eye/tracker/byte_tracker/__init__.py +0 -0
  72. eye/tracker/byte_tracker/core.py +386 -0
  73. eye/tracker/byte_tracker/kalman_filter.py +205 -0
  74. eye/tracker/byte_tracker/matching.py +69 -0
  75. eye/tracker/byte_tracker/single_object_track.py +178 -0
  76. eye/tracker/byte_tracker/utils.py +18 -0
  77. eye/utils/__init__.py +0 -0
  78. eye/utils/conversion.py +132 -0
  79. eye/utils/file.py +159 -0
  80. eye/utils/image.py +794 -0
  81. eye/utils/internal.py +200 -0
  82. eye/utils/iterables.py +84 -0
  83. eye/utils/notebook.py +114 -0
  84. eye/utils/video.py +307 -0
  85. eye/utils_eye/__init__.py +1 -0
  86. eye/utils_eye/geometry.py +71 -0
  87. eye/utils_eye/nms.py +55 -0
  88. eye/validators/__init__.py +140 -0
  89. eye/web.py +271 -0
  90. eye_cv-1.0.0.dist-info/METADATA +319 -0
  91. eye_cv-1.0.0.dist-info/RECORD +94 -0
  92. eye_cv-1.0.0.dist-info/WHEEL +5 -0
  93. eye_cv-1.0.0.dist-info/licenses/LICENSE +21 -0
  94. eye_cv-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,653 @@
1
+ from __future__ import annotations
2
+
3
+ from copy import deepcopy
4
+ from dataclasses import dataclass
5
+ from typing import TYPE_CHECKING, List, Optional, Tuple, Union
6
+
7
+ import numpy as np
8
+ from matplotlib import pyplot as plt
9
+
10
+ from eye.config import ORIENTED_BOX_COORDINATES
11
+ from eye.detection.core import Detections
12
+ from eye.detection.utils import (
13
+ box_iou_batch,
14
+ mask_iou_batch,
15
+ oriented_box_iou_batch,
16
+ )
17
+ from eye.draw.color import LEGACY_COLOR_PALETTE
18
+ from eye.metrics.core import AveragingMethod, Metric, MetricTarget
19
+ from eye.metrics.utils.object_size import (
20
+ ObjectSizeCategory,
21
+ get_detection_size_category,
22
+ )
23
+ from eye.metrics.utils.utils import ensure_pandas_installed
24
+
25
+ if TYPE_CHECKING:
26
+ import pandas as pd
27
+
28
+
29
class Precision(Metric):
    """
    Precision is a metric used to evaluate object detection models. It is the ratio of
    true positive detections to the total number of predicted detections. We calculate
    it at different IoU thresholds.

    In simple terms, Precision is a measure of a model's accuracy, calculated as:

    `Precision = TP / (TP + FP)`

    Here, `TP` is the number of true positives (correct detections), and `FP` is the
    number of false positive detections (detected, but incorrectly).

    Predictions made on images that have no targets are counted as false positives
    (for the classes that appear among the targets) rather than being ignored.

    Example:
        ```python
        import eye as sv
        from eye.metrics import Precision

        predictions = sv.Detections(...)
        targets = sv.Detections(...)

        precision_metric = Precision()
        precision_result = precision_metric.update(predictions, targets).compute()

        print(precision_result.precision_at_50)
        # 0.8099

        print(precision_result)
        # PrecisionResult:
        # Metric target: MetricTarget.BOXES
        # Averaging method: AveragingMethod.WEIGHTED
        # P @ 50: 0.8099
        # P @ 75: 0.7969
        # P @ thresh: [0.80992 0.80905 0.80905 ...]
        # IoU thresh: [0.5 0.55 0.6 ...]
        # Precision per class:
        # 0: [0.64706 0.64706 0.64706 ...]
        # ...
        # Small objects: ...
        # Medium objects: ...
        # Large objects: ...

        print(precision_result.small_objects.precision_at_50)
        ```

    ![example_plot](\
        https://media.roboflow.com/eye-docs/metrics/precision_plot_example.png\
        ){ align=center width="800" }
    """

    def __init__(
        self,
        metric_target: MetricTarget = MetricTarget.BOXES,
        averaging_method: AveragingMethod = AveragingMethod.WEIGHTED,
    ):
        """
        Initialize the Precision metric.

        Args:
            metric_target (MetricTarget): The type of detection data to use.
            averaging_method (AveragingMethod): The averaging method used to compute the
                precision. Determines how the precision is aggregated across classes.
        """
        self._metric_target = metric_target
        self.averaging_method = averaging_method

        self._predictions_list: List[Detections] = []
        self._targets_list: List[Detections] = []

    def reset(self) -> None:
        """
        Reset the metric to its initial state, clearing all stored data.
        """
        self._predictions_list = []
        self._targets_list = []

    def update(
        self,
        predictions: Union[Detections, List[Detections]],
        targets: Union[Detections, List[Detections]],
    ) -> Precision:
        """
        Add new predictions and targets to the metric, but do not compute the result.

        Args:
            predictions (Union[Detections, List[Detections]]): The predicted detections.
            targets (Union[Detections, List[Detections]]): The target detections.

        Returns:
            (Precision): The updated metric instance.

        Raises:
            ValueError: If the number of predictions and targets differ.
        """
        if not isinstance(predictions, list):
            predictions = [predictions]
        if not isinstance(targets, list):
            targets = [targets]

        if len(predictions) != len(targets):
            raise ValueError(
                f"The number of predictions ({len(predictions)}) and"
                f" targets ({len(targets)}) during the update must be the same."
            )

        self._predictions_list.extend(predictions)
        self._targets_list.extend(targets)

        return self

    def compute(self) -> PrecisionResult:
        """
        Calculate the precision metric based on the stored predictions and ground-truth
        data, at different IoU thresholds.

        The overall result is computed first; per-size results (small, medium,
        large) are then attached to it.

        Returns:
            (PrecisionResult): The precision metric result.
        """
        result = self._compute(self._predictions_list, self._targets_list)
        result.small_objects = self._compute_for_size(ObjectSizeCategory.SMALL)
        result.medium_objects = self._compute_for_size(ObjectSizeCategory.MEDIUM)
        result.large_objects = self._compute_for_size(ObjectSizeCategory.LARGE)
        return result

    def _compute_for_size(self, size_category: ObjectSizeCategory) -> PrecisionResult:
        """Compute precision using only stored objects of the given size category."""
        sized_predictions, sized_targets = (
            self._filter_predictions_and_targets_by_size(
                self._predictions_list, self._targets_list, size_category
            )
        )
        return self._compute(sized_predictions, sized_targets)

    def _compute(
        self, predictions_list: List[Detections], targets_list: List[Detections]
    ) -> PrecisionResult:
        """
        Compute precision over paired lists of predictions and targets.

        Each (predictions, targets) pair contributes one tuple of
        (matches, prediction confidences, prediction class ids, target class ids).
        The tuples are concatenated across pairs before per-class precision is
        computed. The IoU thresholds are fixed at 0.5, 0.55, ..., 0.95.
        """
        iou_thresholds = np.linspace(0.5, 0.95, 10)
        stats = []

        for predictions, targets in zip(predictions_list, targets_list):
            prediction_contents = self._detections_content(predictions)
            target_contents = self._detections_content(targets)

            if len(targets) == 0:
                if len(predictions) > 0:
                    # No ground truth here, so nothing can be matched: every
                    # prediction is a false positive. Skipping the image would
                    # inflate precision.
                    stats.append(
                        (
                            np.zeros(
                                (len(predictions), iou_thresholds.size), dtype=bool
                            ),
                            predictions.confidence,
                            predictions.class_id,
                            np.zeros((0,), dtype=int),
                        )
                    )
                continue

            if len(predictions) == 0:
                # Targets but no predictions: only the target class counts matter.
                stats.append(
                    (
                        np.zeros((0, iou_thresholds.size), dtype=bool),
                        np.zeros((0,), dtype=np.float32),
                        np.zeros((0,), dtype=int),
                        targets.class_id,
                    )
                )
                continue

            iou = self._compute_iou(target_contents, prediction_contents)
            matches = self._match_detection_batch(
                predictions.class_id, targets.class_id, iou, iou_thresholds
            )
            stats.append(
                (
                    matches,
                    predictions.confidence,
                    predictions.class_id,
                    targets.class_id,
                )
            )

        if not stats:
            return PrecisionResult(
                metric_target=self._metric_target,
                averaging_method=self.averaging_method,
                precision_scores=np.zeros(iou_thresholds.shape[0]),
                precision_per_class=np.zeros((0, iou_thresholds.shape[0])),
                iou_thresholds=iou_thresholds,
                matched_classes=np.array([], dtype=int),
                small_objects=None,
                medium_objects=None,
                large_objects=None,
            )

        concatenated_stats = [np.concatenate(items, 0) for items in zip(*stats)]
        precision_scores, precision_per_class, unique_classes = (
            self._compute_precision_for_classes(*concatenated_stats)
        )

        return PrecisionResult(
            metric_target=self._metric_target,
            averaging_method=self.averaging_method,
            precision_scores=precision_scores,
            precision_per_class=precision_per_class,
            iou_thresholds=iou_thresholds,
            matched_classes=unique_classes,
            small_objects=None,
            medium_objects=None,
            large_objects=None,
        )

    def _compute_iou(
        self, target_contents: np.ndarray, prediction_contents: np.ndarray
    ) -> np.ndarray:
        """
        Return the IoU matrix between targets (rows) and predictions (columns),
        using the IoU function that matches the configured metric target.

        Raises:
            ValueError: If the metric target has no IoU implementation.
        """
        if self._metric_target == MetricTarget.BOXES:
            return box_iou_batch(target_contents, prediction_contents)
        if self._metric_target == MetricTarget.MASKS:
            return mask_iou_batch(target_contents, prediction_contents)
        if self._metric_target == MetricTarget.ORIENTED_BOUNDING_BOXES:
            return oriented_box_iou_batch(target_contents, prediction_contents)
        raise ValueError("Unsupported metric target for IoU calculation")

    def _compute_precision_for_classes(
        self,
        matches: np.ndarray,
        prediction_confidence: np.ndarray,
        prediction_class_ids: np.ndarray,
        true_class_ids: np.ndarray,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Compute precision per target class and the aggregated precision scores.

        Returns:
            (precision_scores, precision_per_class, unique_classes), with shapes
            (Th,), (C, Th) and (C,). C is the number of distinct target classes.

        Raises:
            ValueError: If `self.averaging_method` is not MACRO, MICRO or WEIGHTED.
        """
        num_thresholds = matches.shape[1]
        sorted_indices = np.argsort(-prediction_confidence)
        matches = matches[sorted_indices]
        prediction_class_ids = prediction_class_ids[sorted_indices]
        unique_classes, class_counts = np.unique(true_class_ids, return_counts=True)

        if unique_classes.size == 0:
            # Predictions exist but there are no targets at all: nothing to
            # average over (WEIGHTED would divide by a zero weight sum).
            return (
                np.zeros(num_thresholds),
                np.zeros((0, num_thresholds)),
                unique_classes,
            )

        # Shape: PxTh,P,C,C -> CxThx3
        confusion_matrix = self._compute_confusion_matrix(
            matches, prediction_class_ids, unique_classes, class_counts
        )

        # Shape: CxThx3 -> CxTh
        precision_per_class = self._compute_precision(confusion_matrix)

        # Shape: CxTh -> Th
        if self.averaging_method == AveragingMethod.MACRO:
            precision_scores = np.mean(precision_per_class, axis=0)
        elif self.averaging_method == AveragingMethod.MICRO:
            confusion_matrix_merged = confusion_matrix.sum(0)
            precision_scores = self._compute_precision(confusion_matrix_merged)
        elif self.averaging_method == AveragingMethod.WEIGHTED:
            class_counts = class_counts.astype(np.float32)
            precision_scores = np.average(
                precision_per_class, axis=0, weights=class_counts
            )
        else:
            raise ValueError(
                f"Unsupported averaging method: {self.averaging_method}"
            )

        return precision_scores, precision_per_class, unique_classes

    @staticmethod
    def _match_detection_batch(
        predictions_classes: np.ndarray,
        target_classes: np.ndarray,
        iou: np.ndarray,
        iou_thresholds: np.ndarray,
    ) -> np.ndarray:
        """
        Flag which predictions are true positives at each IoU threshold.

        `iou` is indexed [target, prediction]. A pair is a candidate match only
        when the classes agree and the IoU reaches the threshold. Candidates are
        reduced greedily so that each prediction and each target is used at most
        once.

        Returns:
            Boolean array of shape (num_predictions, num_thresholds).
        """
        num_predictions, num_iou_levels = (
            predictions_classes.shape[0],
            iou_thresholds.shape[0],
        )
        correct = np.zeros((num_predictions, num_iou_levels), dtype=bool)
        correct_class = target_classes[:, None] == predictions_classes

        for i, iou_level in enumerate(iou_thresholds):
            matched_indices = np.where((iou >= iou_level) & correct_class)

            if matched_indices[0].shape[0]:
                # Rows: [target_idx, prediction_idx, iou]
                combined_indices = np.stack(matched_indices, axis=1)
                iou_values = iou[matched_indices][:, None]
                matches = np.hstack([combined_indices, iou_values])

                if matched_indices[0].shape[0] > 1:
                    # Sort by IoU (descending) so the first occurrence of each
                    # prediction is its best-overlapping target.
                    matches = matches[matches[:, 2].argsort()[::-1]]
                    matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
                    # NOTE(review): after the first np.unique the rows are ordered
                    # by prediction index, so this keeps the lowest-index prediction
                    # per target rather than the highest-IoU one. The number of true
                    # positives per class is the same either way.
                    matches = matches[np.unique(matches[:, 0], return_index=True)[1]]

                correct[matches[:, 1].astype(int), i] = True

        return correct

    @staticmethod
    def _compute_confusion_matrix(
        sorted_matches: np.ndarray,
        sorted_prediction_class_ids: np.ndarray,
        unique_classes: np.ndarray,
        class_counts: np.ndarray,
    ) -> np.ndarray:
        """
        Compute the confusion matrix for each class and IoU threshold.

        Assumes the matches and prediction_class_ids are sorted by confidence
        in descending order.

        Arguments:
            sorted_matches: np.ndarray, bool, shape (P, Th), that is True
                if the prediction is a true positive at the given IoU threshold.
            sorted_prediction_class_ids: np.ndarray, int, shape (P,), containing
                the class id for each prediction.
            unique_classes: np.ndarray, int, shape (C,), containing the unique
                class ids.
            class_counts: np.ndarray, int, shape (C,), containing the number
                of true instances for each class.

        Returns:
            np.ndarray, shape (C, Th, 3), containing the true positives, false
                positives, and false negatives for each class and IoU threshold.
        """
        num_thresholds = sorted_matches.shape[1]
        num_classes = unique_classes.shape[0]

        confusion_matrix = np.zeros((num_classes, num_thresholds, 3))
        for class_idx, class_id in enumerate(unique_classes):
            is_class = sorted_prediction_class_ids == class_id
            num_true = class_counts[class_idx]
            num_predictions = is_class.sum()

            if num_predictions == 0:
                true_positives = np.zeros(num_thresholds)
                false_positives = np.zeros(num_thresholds)
                false_negatives = np.full(num_thresholds, num_true)
            elif num_true == 0:
                # Defensive: counts from np.unique(..., return_counts=True) are >= 1.
                true_positives = np.zeros(num_thresholds)
                false_positives = np.full(num_thresholds, num_predictions)
                false_negatives = np.zeros(num_thresholds)
            else:
                class_matches = sorted_matches[is_class]
                true_positives = class_matches.sum(0)
                false_positives = np.logical_not(class_matches).sum(0)
                false_negatives = num_true - true_positives
            confusion_matrix[class_idx] = np.stack(
                [true_positives, false_positives, false_negatives], axis=1
            )

        return confusion_matrix

    @staticmethod
    def _compute_precision(confusion_matrix: np.ndarray) -> np.ndarray:
        """
        Broadcastable function, computing the precision from the confusion matrix.

        Entries with no predictions (TP + FP == 0) get precision 0. No
        divide-by-zero warning is emitted for them.

        Arguments:
            confusion_matrix: np.ndarray, shape (N, ..., 3), where the last dimension
                contains the true positives, false positives, and false negatives.

        Returns:
            np.ndarray, shape (N, ...), containing the precision for each element.

        Raises:
            ValueError: If the last dimension is not of size 3.
        """
        if not confusion_matrix.shape[-1] == 3:
            raise ValueError(
                f"Confusion matrix must have shape (..., 3), got "
                f"{confusion_matrix.shape}"
            )
        true_positives = confusion_matrix[..., 0]
        false_positives = confusion_matrix[..., 1]

        denominator = true_positives + false_positives
        # Divide only where the denominator is non-zero; the rest stay 0.
        return np.divide(
            true_positives,
            denominator,
            out=np.zeros_like(denominator, dtype=float),
            where=denominator != 0,
        )

    def _detections_content(self, detections: Detections) -> np.ndarray:
        """
        Return boxes, masks or oriented bounding boxes from detections.

        Missing masks or oriented boxes yield an empty array of the expected rank.

        Raises:
            ValueError: If the metric target is not recognised.
        """
        if self._metric_target == MetricTarget.BOXES:
            return detections.xyxy
        if self._metric_target == MetricTarget.MASKS:
            return (
                detections.mask
                if detections.mask is not None
                else self._make_empty_content()
            )
        if self._metric_target == MetricTarget.ORIENTED_BOUNDING_BOXES:
            obb = detections.data.get(ORIENTED_BOX_COORDINATES)
            if obb is not None and len(obb) > 0:
                return np.array(obb, dtype=np.float32)
            return self._make_empty_content()
        raise ValueError(f"Invalid metric target: {self._metric_target}")

    def _make_empty_content(self) -> np.ndarray:
        """Return an empty array shaped like the content for the metric target."""
        if self._metric_target == MetricTarget.BOXES:
            return np.empty((0, 4), dtype=np.float32)
        if self._metric_target == MetricTarget.MASKS:
            return np.empty((0, 0, 0), dtype=bool)
        if self._metric_target == MetricTarget.ORIENTED_BOUNDING_BOXES:
            return np.empty((0, 4, 2), dtype=np.float32)
        raise ValueError(f"Invalid metric target: {self._metric_target}")

    def _filter_detections_by_size(
        self, detections: Detections, size_category: ObjectSizeCategory
    ) -> Detections:
        """
        Return a deep copy of detections keeping only objects of `size_category`.

        Every per-detection field that is present (xyxy, mask, class_id,
        confidence, tracker_id, data entries) is filtered with the same mask.
        """
        new_detections = deepcopy(detections)
        if detections.is_empty() or size_category == ObjectSizeCategory.ANY:
            return new_detections

        sizes = get_detection_size_category(new_detections, self._metric_target)
        size_mask = sizes == size_category.value

        new_detections.xyxy = new_detections.xyxy[size_mask]
        if new_detections.mask is not None:
            new_detections.mask = new_detections.mask[size_mask]
        if new_detections.class_id is not None:
            new_detections.class_id = new_detections.class_id[size_mask]
        if new_detections.confidence is not None:
            new_detections.confidence = new_detections.confidence[size_mask]
        if new_detections.tracker_id is not None:
            new_detections.tracker_id = new_detections.tracker_id[size_mask]
        if new_detections.data is not None:
            # Reassigning existing keys while iterating is safe (size is unchanged).
            for key, value in new_detections.data.items():
                new_detections.data[key] = np.array(value)[size_mask]

        return new_detections

    def _filter_predictions_and_targets_by_size(
        self,
        predictions_list: List[Detections],
        targets_list: List[Detections],
        size_category: ObjectSizeCategory,
    ) -> Tuple[List[Detections], List[Detections]]:
        """
        Filter predictions and targets by object size category.

        Predictions and targets are filtered independently of each other.
        """
        new_predictions_list = []
        new_targets_list = []
        for predictions, targets in zip(predictions_list, targets_list):
            new_predictions_list.append(
                self._filter_detections_by_size(predictions, size_category)
            )
            new_targets_list.append(
                self._filter_detections_by_size(targets, size_category)
            )
        return new_predictions_list, new_targets_list
458
+
459
+
460
@dataclass
class PrecisionResult:
    """
    The results of the precision metric calculation.

    Defaults to `0` if no detections or targets were provided.

    Attributes:
        metric_target (MetricTarget): the type of data used for the metric -
            boxes, masks or oriented bounding boxes.
        averaging_method (AveragingMethod): the averaging method used to compute the
            precision. Determines how the precision is aggregated across classes.
        precision_at_50 (float): the precision at IoU threshold of `0.5`.
        precision_at_75 (float): the precision at IoU threshold of `0.75`.
        precision_scores (np.ndarray): the precision scores at each IoU threshold.
            Shape: `(num_iou_thresholds,)`
        precision_per_class (np.ndarray): the precision scores per class and
            IoU threshold. Shape: `(num_target_classes, num_iou_thresholds)`
        iou_thresholds (np.ndarray): the IoU thresholds used in the calculations.
        matched_classes (np.ndarray): the class IDs of all matched classes.
            Corresponds to the rows of `precision_per_class`.
        small_objects (Optional[PrecisionResult]): the Precision metric results
            for small objects (area < 32²).
        medium_objects (Optional[PrecisionResult]): the Precision metric results
            for medium objects (32² ≤ area < 96²).
        large_objects (Optional[PrecisionResult]): the Precision metric results
            for large objects (area ≥ 96²).
    """

    metric_target: MetricTarget
    averaging_method: AveragingMethod

    precision_scores: np.ndarray
    precision_per_class: np.ndarray
    iou_thresholds: np.ndarray
    matched_classes: np.ndarray

    small_objects: Optional[PrecisionResult]
    medium_objects: Optional[PrecisionResult]
    large_objects: Optional[PrecisionResult]

    @property
    def precision_at_50(self) -> float:
        """Precision at IoU 0.5, i.e. the first entry of `precision_scores`."""
        return self.precision_scores[0]

    @property
    def precision_at_75(self) -> float:
        """
        Precision at IoU 0.75, i.e. index 5 of `precision_scores`.

        Assumes the thresholds are 0.5, 0.55, ..., 0.95.
        """
        return self.precision_scores[5]

    def _size_results(self):
        """Return (label, result) pairs for the small, medium and large buckets."""
        return (
            ("Small", self.small_objects),
            ("Medium", self.medium_objects),
            ("Large", self.large_objects),
        )

    def __str__(self) -> str:
        """
        Format as a pretty string.

        Example:
            ```python
            print(precision_result)
            # PrecisionResult:
            # Metric target: MetricTarget.BOXES
            # Averaging method: AveragingMethod.WEIGHTED
            # P @ 50: 0.8099
            # P @ 75: 0.7969
            # P @ thresh: [0.80992 0.80905 0.80905 ...]
            # IoU thresh: [0.5 0.55 0.6 ...]
            # Precision per class:
            # 0: [0.64706 0.64706 0.64706 ...]
            # ...
            # Small objects: ...
            # Medium objects: ...
            # Large objects: ...
            ```
        """
        out_str = (
            f"{self.__class__.__name__}:\n"
            f"Metric target: {self.metric_target}\n"
            f"Averaging method: {self.averaging_method}\n"
            f"P @ 50: {self.precision_at_50:.4f}\n"
            f"P @ 75: {self.precision_at_75:.4f}\n"
            f"P @ thresh: {self.precision_scores}\n"
            f"IoU thresh: {self.iou_thresholds}\n"
            f"Precision per class:\n"
        )
        if self.precision_per_class.size == 0:
            out_str += " No results\n"
        for class_id, precision_of_class in zip(
            self.matched_classes, self.precision_per_class
        ):
            out_str += f" {class_id}: {precision_of_class}\n"

        indent = " "
        for label, sub_result in self._size_results():
            if sub_result is None:
                continue
            # Indent every line of the nested result by one level.
            indented = indent + str(sub_result).replace("\n", f"\n{indent}")
            out_str += f"\n{label} objects:\n{indented}"

        return out_str

    def to_pandas(self) -> "pd.DataFrame":
        """
        Convert the result to a pandas DataFrame.

        Nested size results are flattened into columns prefixed with
        `small_objects_`, `medium_objects_` and `large_objects_`.

        Returns:
            (pd.DataFrame): The result as a DataFrame.
        """
        ensure_pandas_installed()
        import pandas as pd

        pandas_data = {
            "P@50": self.precision_at_50,
            "P@75": self.precision_at_75,
        }

        for label, sub_result in self._size_results():
            if sub_result is None:
                continue
            prefix = f"{label.lower()}_objects"
            for key, value in sub_result.to_pandas().items():
                pandas_data[f"{prefix}_{key}"] = value

        return pd.DataFrame(pandas_data, index=[0])

    def plot(self):
        """
        Plot the precision results.

        The global matplotlib font family is switched to monospace while the figure
        is built. The caller's previous setting is restored afterwards, even if
        plotting raises.

        ![example_plot](\
            https://media.roboflow.com/eye-docs/metrics/precision_plot_example.png\
            ){ align=center width="800" }
        """
        labels = ["Precision@50", "Precision@75"]
        values = [self.precision_at_50, self.precision_at_75]
        colors = [LEGACY_COLOR_PALETTE[0]] * 2

        palette_index = {"Small": 3, "Medium": 2, "Large": 4}
        for label, sub_result in self._size_results():
            if sub_result is None:
                continue
            labels += [f"{label}: P@50", f"{label}: P@75"]
            values += [sub_result.precision_at_50, sub_result.precision_at_75]
            colors += [LEGACY_COLOR_PALETTE[palette_index[label]]] * 2

        previous_font_family = plt.rcParams["font.family"]
        plt.rcParams["font.family"] = "monospace"
        try:
            _, ax = plt.subplots(figsize=(10, 6))
            ax.set_ylim(0, 1)
            ax.set_ylabel("Value", fontweight="bold")
            title = (
                f"Precision, by Object Size"
                f"\n(target: {self.metric_target.value},"
                f" averaging: {self.averaging_method.value})"
            )
            ax.set_title(title, fontweight="bold")

            x_positions = range(len(labels))
            bars = ax.bar(x_positions, values, color=colors, align="center")

            ax.set_xticks(x_positions)
            ax.set_xticklabels(labels, rotation=45, ha="right")

            # Write each bar's value just above it.
            for bar in bars:
                y_value = bar.get_height()
                ax.text(
                    bar.get_x() + bar.get_width() / 2,
                    y_value + 0.02,
                    f"{y_value:.2f}",
                    ha="center",
                    va="bottom",
                )
        finally:
            # Restore the caller's setting instead of forcing "sans-serif".
            plt.rcParams["font.family"] = previous_font_family

        plt.tight_layout()
        plt.show()