eye_cv-1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (94)
  1. eye/__init__.py +115 -0
  2. eye/__init___supervision_original.py +120 -0
  3. eye/annotators/__init__.py +0 -0
  4. eye/annotators/base.py +22 -0
  5. eye/annotators/core.py +2699 -0
  6. eye/annotators/line.py +107 -0
  7. eye/annotators/modern.py +529 -0
  8. eye/annotators/trace.py +142 -0
  9. eye/annotators/utils.py +177 -0
  10. eye/assets/__init__.py +2 -0
  11. eye/assets/downloader.py +95 -0
  12. eye/assets/list.py +83 -0
  13. eye/classification/__init__.py +0 -0
  14. eye/classification/core.py +188 -0
  15. eye/config.py +2 -0
  16. eye/core/__init__.py +0 -0
  17. eye/core/trackers/__init__.py +1 -0
  18. eye/core/trackers/botsort_tracker.py +336 -0
  19. eye/core/trackers/bytetrack_tracker.py +284 -0
  20. eye/core/trackers/sort_tracker.py +200 -0
  21. eye/core/tracking.py +146 -0
  22. eye/dataset/__init__.py +0 -0
  23. eye/dataset/core.py +919 -0
  24. eye/dataset/formats/__init__.py +0 -0
  25. eye/dataset/formats/coco.py +258 -0
  26. eye/dataset/formats/pascal_voc.py +279 -0
  27. eye/dataset/formats/yolo.py +272 -0
  28. eye/dataset/utils.py +259 -0
  29. eye/detection/__init__.py +0 -0
  30. eye/detection/auto_convert.py +155 -0
  31. eye/detection/core.py +1529 -0
  32. eye/detection/detections_enhanced.py +392 -0
  33. eye/detection/line_zone.py +859 -0
  34. eye/detection/lmm.py +184 -0
  35. eye/detection/overlap_filter.py +270 -0
  36. eye/detection/tools/__init__.py +0 -0
  37. eye/detection/tools/csv_sink.py +181 -0
  38. eye/detection/tools/inference_slicer.py +288 -0
  39. eye/detection/tools/json_sink.py +142 -0
  40. eye/detection/tools/polygon_zone.py +202 -0
  41. eye/detection/tools/smoother.py +123 -0
  42. eye/detection/tools/smoothing.py +179 -0
  43. eye/detection/tools/smoothing_config.py +202 -0
  44. eye/detection/tools/transformers.py +247 -0
  45. eye/detection/utils.py +1175 -0
  46. eye/draw/__init__.py +0 -0
  47. eye/draw/color.py +154 -0
  48. eye/draw/utils.py +374 -0
  49. eye/filters.py +112 -0
  50. eye/geometry/__init__.py +0 -0
  51. eye/geometry/core.py +128 -0
  52. eye/geometry/utils.py +47 -0
  53. eye/keypoint/__init__.py +0 -0
  54. eye/keypoint/annotators.py +442 -0
  55. eye/keypoint/core.py +687 -0
  56. eye/keypoint/skeletons.py +2647 -0
  57. eye/metrics/__init__.py +21 -0
  58. eye/metrics/core.py +72 -0
  59. eye/metrics/detection.py +843 -0
  60. eye/metrics/f1_score.py +648 -0
  61. eye/metrics/mean_average_precision.py +628 -0
  62. eye/metrics/mean_average_recall.py +697 -0
  63. eye/metrics/precision.py +653 -0
  64. eye/metrics/recall.py +652 -0
  65. eye/metrics/utils/__init__.py +0 -0
  66. eye/metrics/utils/object_size.py +158 -0
  67. eye/metrics/utils/utils.py +9 -0
  68. eye/py.typed +0 -0
  69. eye/quick.py +104 -0
  70. eye/tracker/__init__.py +0 -0
  71. eye/tracker/byte_tracker/__init__.py +0 -0
  72. eye/tracker/byte_tracker/core.py +386 -0
  73. eye/tracker/byte_tracker/kalman_filter.py +205 -0
  74. eye/tracker/byte_tracker/matching.py +69 -0
  75. eye/tracker/byte_tracker/single_object_track.py +178 -0
  76. eye/tracker/byte_tracker/utils.py +18 -0
  77. eye/utils/__init__.py +0 -0
  78. eye/utils/conversion.py +132 -0
  79. eye/utils/file.py +159 -0
  80. eye/utils/image.py +794 -0
  81. eye/utils/internal.py +200 -0
  82. eye/utils/iterables.py +84 -0
  83. eye/utils/notebook.py +114 -0
  84. eye/utils/video.py +307 -0
  85. eye/utils_eye/__init__.py +1 -0
  86. eye/utils_eye/geometry.py +71 -0
  87. eye/utils_eye/nms.py +55 -0
  88. eye/validators/__init__.py +140 -0
  89. eye/web.py +271 -0
  90. eye_cv-1.0.0.dist-info/METADATA +319 -0
  91. eye_cv-1.0.0.dist-info/RECORD +94 -0
  92. eye_cv-1.0.0.dist-info/WHEEL +5 -0
  93. eye_cv-1.0.0.dist-info/licenses/LICENSE +21 -0
  94. eye_cv-1.0.0.dist-info/top_level.txt +1 -0
eye/metrics/f1_score.py
@@ -0,0 +1,648 @@
+ from __future__ import annotations
+
+ from copy import deepcopy
+ from dataclasses import dataclass
+ from typing import TYPE_CHECKING, List, Optional, Tuple, Union
+
+ import numpy as np
+ from matplotlib import pyplot as plt
+
+ from eye.config import ORIENTED_BOX_COORDINATES
+ from eye.detection.core import Detections
+ from eye.detection.utils import (
+     box_iou_batch,
+     mask_iou_batch,
+     oriented_box_iou_batch,
+ )
+ from eye.draw.color import LEGACY_COLOR_PALETTE
+ from eye.metrics.core import AveragingMethod, Metric, MetricTarget
+ from eye.metrics.utils.object_size import (
+     ObjectSizeCategory,
+     get_detection_size_category,
+ )
+ from eye.metrics.utils.utils import ensure_pandas_installed
+
+ if TYPE_CHECKING:
+     import pandas as pd
+
+
+ class F1Score(Metric):
+     """
+     F1 Score is a metric used to evaluate object detection models. It is the harmonic
+     mean of precision and recall, calculated at different IoU thresholds.
+
+     In simple terms, F1 Score is a measure of a model's balance between precision and
+     recall (accuracy and completeness), calculated as:
+
+     `F1 = 2 * (precision * recall) / (precision + recall)`
+
+     Example:
+         ```python
+         import eye as sv
+         from eye.metrics import F1Score
+
+         predictions = sv.Detections(...)
+         targets = sv.Detections(...)
+
+         f1_metric = F1Score()
+         f1_result = f1_metric.update(predictions, targets).compute()
+
+         print(f1_result.f1_50)
+         # 0.7618
+
+         print(f1_result)
+         # F1ScoreResult:
+         # Metric target: MetricTarget.BOXES
+         # Averaging method: AveragingMethod.WEIGHTED
+         # F1 @ 50: 0.7618
+         # F1 @ 75: 0.7487
+         # F1 @ thresh: [0.76175 0.76068 0.76068]
+         # IoU thresh: [0.5 0.55 0.6 ...]
+         # F1 per class:
+         #   0: [0.70968 0.70968 0.70968 ...]
+         #   ...
+         # Small objects: ...
+         # Medium objects: ...
+         # Large objects: ...
+
+         f1_result.plot()
+         ```
+
+     ![example_plot](\
+     https://media.roboflow.com/eye-docs/metrics/f1_plot_example.png\
+     ){ align=center width="800" }
+     """
+
+     def __init__(
+         self,
+         metric_target: MetricTarget = MetricTarget.BOXES,
+         averaging_method: AveragingMethod = AveragingMethod.WEIGHTED,
+     ):
+         """
+         Initialize the F1Score metric.
+
+         Args:
+             metric_target (MetricTarget): The type of detection data to use.
+             averaging_method (AveragingMethod): The averaging method used to compute the
+                 F1 scores. Determines how the F1 scores are aggregated across classes.
+         """
+         self._metric_target = metric_target
+         self.averaging_method = averaging_method
+
+         self._predictions_list: List[Detections] = []
+         self._targets_list: List[Detections] = []
+
+     def reset(self) -> None:
+         """
+         Reset the metric to its initial state, clearing all stored data.
+         """
+         self._predictions_list = []
+         self._targets_list = []
+
+     def update(
+         self,
+         predictions: Union[Detections, List[Detections]],
+         targets: Union[Detections, List[Detections]],
+     ) -> F1Score:
+         """
+         Add new predictions and targets to the metric, but do not compute the result.
+
+         Args:
+             predictions (Union[Detections, List[Detections]]): The predicted detections.
+             targets (Union[Detections, List[Detections]]): The target detections.
+
+         Returns:
+             (F1Score): The updated metric instance.
+         """
+         if not isinstance(predictions, list):
+             predictions = [predictions]
+         if not isinstance(targets, list):
+             targets = [targets]
+
+         if len(predictions) != len(targets):
+             raise ValueError(
+                 f"The number of predictions ({len(predictions)}) and"
+                 f" targets ({len(targets)}) during the update must be the same."
+             )
+
+         self._predictions_list.extend(predictions)
+         self._targets_list.extend(targets)
+
+         return self
+
+     def compute(self) -> F1ScoreResult:
+         """
+         Calculate the F1 score metric based on the stored predictions and ground-truth
+         data, at different IoU thresholds.
+
+         Returns:
+             (F1ScoreResult): The F1 score metric result.
+         """
+         result = self._compute(self._predictions_list, self._targets_list)
+
+         small_predictions, small_targets = self._filter_predictions_and_targets_by_size(
+             self._predictions_list, self._targets_list, ObjectSizeCategory.SMALL
+         )
+         result.small_objects = self._compute(small_predictions, small_targets)
+
+         medium_predictions, medium_targets = (
+             self._filter_predictions_and_targets_by_size(
+                 self._predictions_list, self._targets_list, ObjectSizeCategory.MEDIUM
+             )
+         )
+         result.medium_objects = self._compute(medium_predictions, medium_targets)
+
+         large_predictions, large_targets = self._filter_predictions_and_targets_by_size(
+             self._predictions_list, self._targets_list, ObjectSizeCategory.LARGE
+         )
+         result.large_objects = self._compute(large_predictions, large_targets)
+
+         return result
+
+     def _compute(
+         self, predictions_list: List[Detections], targets_list: List[Detections]
+     ) -> F1ScoreResult:
+         iou_thresholds = np.linspace(0.5, 0.95, 10)
+         stats = []
+
+         for predictions, targets in zip(predictions_list, targets_list):
+             prediction_contents = self._detections_content(predictions)
+             target_contents = self._detections_content(targets)
+
+             if len(targets) > 0:
+                 if len(predictions) == 0:
+                     stats.append(
+                         (
+                             np.zeros((0, iou_thresholds.size), dtype=bool),
+                             np.zeros((0,), dtype=np.float32),
+                             np.zeros((0,), dtype=int),
+                             targets.class_id,
+                         )
+                     )
+
+                 else:
+                     if self._metric_target == MetricTarget.BOXES:
+                         iou = box_iou_batch(target_contents, prediction_contents)
+                     elif self._metric_target == MetricTarget.MASKS:
+                         iou = mask_iou_batch(target_contents, prediction_contents)
+                     elif self._metric_target == MetricTarget.ORIENTED_BOUNDING_BOXES:
+                         iou = oriented_box_iou_batch(
+                             target_contents, prediction_contents
+                         )
+                     else:
+                         raise ValueError(
+                             "Unsupported metric target for IoU calculation"
+                         )
+
+                     matches = self._match_detection_batch(
+                         predictions.class_id, targets.class_id, iou, iou_thresholds
+                     )
+                     stats.append(
+                         (
+                             matches,
+                             predictions.confidence,
+                             predictions.class_id,
+                             targets.class_id,
+                         )
+                     )
+
+         if not stats:
+             return F1ScoreResult(
+                 metric_target=self._metric_target,
+                 averaging_method=self.averaging_method,
+                 f1_scores=np.zeros(iou_thresholds.shape[0]),
+                 f1_per_class=np.zeros((0, iou_thresholds.shape[0])),
+                 iou_thresholds=iou_thresholds,
+                 matched_classes=np.array([], dtype=int),
+                 small_objects=None,
+                 medium_objects=None,
+                 large_objects=None,
+             )
+
+         concatenated_stats = [np.concatenate(items, 0) for items in zip(*stats)]
+         f1_scores, f1_per_class, unique_classes = self._compute_f1_for_classes(
+             *concatenated_stats
+         )
+
+         return F1ScoreResult(
+             metric_target=self._metric_target,
+             averaging_method=self.averaging_method,
+             f1_scores=f1_scores,
+             f1_per_class=f1_per_class,
+             iou_thresholds=iou_thresholds,
+             matched_classes=unique_classes,
+             small_objects=None,
+             medium_objects=None,
+             large_objects=None,
+         )
+
+     def _compute_f1_for_classes(
+         self,
+         matches: np.ndarray,
+         prediction_confidence: np.ndarray,
+         prediction_class_ids: np.ndarray,
+         true_class_ids: np.ndarray,
+     ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+         sorted_indices = np.argsort(-prediction_confidence)
+         matches = matches[sorted_indices]
+         prediction_class_ids = prediction_class_ids[sorted_indices]
+         unique_classes, class_counts = np.unique(true_class_ids, return_counts=True)
+
+         # Shape: PxTh,P,C,C -> CxThx3
+         confusion_matrix = self._compute_confusion_matrix(
+             matches, prediction_class_ids, unique_classes, class_counts
+         )
+
+         # Shape: CxThx3 -> CxTh
+         f1_per_class = self._compute_f1(confusion_matrix)
+
+         # Shape: CxTh -> Th
+         if self.averaging_method == AveragingMethod.MACRO:
+             f1_scores = np.mean(f1_per_class, axis=0)
+         elif self.averaging_method == AveragingMethod.MICRO:
+             confusion_matrix_merged = confusion_matrix.sum(0)
+             f1_scores = self._compute_f1(confusion_matrix_merged)
+         elif self.averaging_method == AveragingMethod.WEIGHTED:
+             class_counts = class_counts.astype(np.float32)
+             f1_scores = np.average(f1_per_class, axis=0, weights=class_counts)
+
+         return f1_scores, f1_per_class, unique_classes
+
+     @staticmethod
+     def _match_detection_batch(
+         predictions_classes: np.ndarray,
+         target_classes: np.ndarray,
+         iou: np.ndarray,
+         iou_thresholds: np.ndarray,
+     ) -> np.ndarray:
+         num_predictions, num_iou_levels = (
+             predictions_classes.shape[0],
+             iou_thresholds.shape[0],
+         )
+         correct = np.zeros((num_predictions, num_iou_levels), dtype=bool)
+         correct_class = target_classes[:, None] == predictions_classes
+
+         for i, iou_level in enumerate(iou_thresholds):
+             matched_indices = np.where((iou >= iou_level) & correct_class)
+
+             if matched_indices[0].shape[0]:
+                 combined_indices = np.stack(matched_indices, axis=1)
+                 iou_values = iou[matched_indices][:, None]
+                 matches = np.hstack([combined_indices, iou_values])
+
+                 if matched_indices[0].shape[0] > 1:
+                     matches = matches[matches[:, 2].argsort()[::-1]]
+                     matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
+                     matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
+
+                 correct[matches[:, 1].astype(int), i] = True
+
+         return correct
+
+     @staticmethod
+     def _compute_confusion_matrix(
+         sorted_matches: np.ndarray,
+         sorted_prediction_class_ids: np.ndarray,
+         unique_classes: np.ndarray,
+         class_counts: np.ndarray,
+     ) -> np.ndarray:
+         """
+         Compute the confusion matrix for each class and IoU threshold.
+
+         Assumes the matches and prediction_class_ids are sorted by confidence
+         in descending order.
+
+         Arguments:
+             sorted_matches: np.ndarray, bool, shape (P, Th), that is True
+                 if the prediction is a true positive at the given IoU threshold.
+             sorted_prediction_class_ids: np.ndarray, int, shape (P,), containing
+                 the class id for each prediction.
+             unique_classes: np.ndarray, int, shape (C,), containing the unique
+                 class ids.
+             class_counts: np.ndarray, int, shape (C,), containing the number
+                 of true instances for each class.
+
+         Returns:
+             np.ndarray, shape (C, Th, 3), containing the true positives, false
+                 positives, and false negatives for each class and IoU threshold.
+         """
+
+         num_thresholds = sorted_matches.shape[1]
+         num_classes = unique_classes.shape[0]
+
+         confusion_matrix = np.zeros((num_classes, num_thresholds, 3))
+         for class_idx, class_id in enumerate(unique_classes):
+             is_class = sorted_prediction_class_ids == class_id
+             num_true = class_counts[class_idx]
+             num_predictions = is_class.sum()
+
+             if num_predictions == 0:
+                 true_positives = np.zeros(num_thresholds)
+                 false_positives = np.zeros(num_thresholds)
+                 false_negatives = np.full(num_thresholds, num_true)
+             elif num_true == 0:
+                 true_positives = np.zeros(num_thresholds)
+                 false_positives = np.full(num_thresholds, num_predictions)
+                 false_negatives = np.zeros(num_thresholds)
+             else:
+                 true_positives = sorted_matches[is_class].sum(0)
+                 false_positives = (1 - sorted_matches[is_class]).sum(0)
+                 false_negatives = num_true - true_positives
+             confusion_matrix[class_idx] = np.stack(
+                 [true_positives, false_positives, false_negatives], axis=1
+             )
+
+         return confusion_matrix
+
+     @staticmethod
+     def _compute_f1(confusion_matrix: np.ndarray) -> np.ndarray:
+         """
+         Broadcastable function, computing the F1 score from the confusion matrix.
+
+         Arguments:
+             confusion_matrix: np.ndarray, shape (N, ..., 3), where the last dimension
+                 contains the true positives, false positives, and false negatives.
+
+         Returns:
+             np.ndarray, shape (N, ...), containing the F1 score for each element.
+         """
+         if not confusion_matrix.shape[-1] == 3:
+             raise ValueError(
+                 f"Confusion matrix must have shape (..., 3), got "
+                 f"{confusion_matrix.shape}"
+             )
+         true_positives = confusion_matrix[..., 0]
+         false_positives = confusion_matrix[..., 1]
+         false_negatives = confusion_matrix[..., 2]
+
+         # Alternate formula, avoids multiple zero division checks
+         denominator = 2 * true_positives + false_positives + false_negatives
+         f1_score = np.where(denominator == 0, 0, 2 * true_positives / denominator)
+
+         return f1_score
+
+     def _detections_content(self, detections: Detections) -> np.ndarray:
+         """Return boxes, masks or oriented bounding boxes from detections."""
+         if self._metric_target == MetricTarget.BOXES:
+             return detections.xyxy
+         if self._metric_target == MetricTarget.MASKS:
+             return (
+                 detections.mask
+                 if detections.mask is not None
+                 else self._make_empty_content()
+             )
+         if self._metric_target == MetricTarget.ORIENTED_BOUNDING_BOXES:
+             obb = detections.data.get(ORIENTED_BOX_COORDINATES)
+             if obb is not None and len(obb) > 0:
+                 return np.array(obb, dtype=np.float32)
+             return self._make_empty_content()
+         raise ValueError(f"Invalid metric target: {self._metric_target}")
+
+     def _make_empty_content(self) -> np.ndarray:
+         if self._metric_target == MetricTarget.BOXES:
+             return np.empty((0, 4), dtype=np.float32)
+         if self._metric_target == MetricTarget.MASKS:
+             return np.empty((0, 0, 0), dtype=bool)
+         if self._metric_target == MetricTarget.ORIENTED_BOUNDING_BOXES:
+             return np.empty((0, 4, 2), dtype=np.float32)
+         raise ValueError(f"Invalid metric target: {self._metric_target}")
+
+     def _filter_detections_by_size(
+         self, detections: Detections, size_category: ObjectSizeCategory
+     ) -> Detections:
+         """Return a copy of detections with contents filtered by object size."""
+         new_detections = deepcopy(detections)
+         if detections.is_empty() or size_category == ObjectSizeCategory.ANY:
+             return new_detections
+
+         sizes = get_detection_size_category(new_detections, self._metric_target)
+         size_mask = sizes == size_category.value
+
+         new_detections.xyxy = new_detections.xyxy[size_mask]
+         if new_detections.mask is not None:
+             new_detections.mask = new_detections.mask[size_mask]
+         if new_detections.class_id is not None:
+             new_detections.class_id = new_detections.class_id[size_mask]
+         if new_detections.confidence is not None:
+             new_detections.confidence = new_detections.confidence[size_mask]
+         if new_detections.tracker_id is not None:
+             new_detections.tracker_id = new_detections.tracker_id[size_mask]
+         if new_detections.data is not None:
+             for key, value in new_detections.data.items():
+                 new_detections.data[key] = np.array(value)[size_mask]
+
+         return new_detections
+
+     def _filter_predictions_and_targets_by_size(
+         self,
+         predictions_list: List[Detections],
+         targets_list: List[Detections],
+         size_category: ObjectSizeCategory,
+     ) -> Tuple[List[Detections], List[Detections]]:
+         """
+         Filter predictions and targets by object size category.
+         """
+         new_predictions_list = []
+         new_targets_list = []
+         for predictions, targets in zip(predictions_list, targets_list):
+             new_predictions_list.append(
+                 self._filter_detections_by_size(predictions, size_category)
+             )
+             new_targets_list.append(
+                 self._filter_detections_by_size(targets, size_category)
+             )
+         return new_predictions_list, new_targets_list
+
+
+ @dataclass
+ class F1ScoreResult:
+     """
+     The results of the F1 score metric calculation.
+
+     Defaults to `0` if no detections or targets were provided.
+
+     Attributes:
+         metric_target (MetricTarget): the type of data used for the metric -
+             boxes, masks or oriented bounding boxes.
+         averaging_method (AveragingMethod): the averaging method used to compute the
+             F1 scores. Determines how the F1 scores are aggregated across classes.
+         f1_50 (float): the F1 score at IoU threshold of `0.5`.
+         f1_75 (float): the F1 score at IoU threshold of `0.75`.
+         f1_scores (np.ndarray): the F1 scores at each IoU threshold.
+             Shape: `(num_iou_thresholds,)`
+         f1_per_class (np.ndarray): the F1 scores per class and IoU threshold.
+             Shape: `(num_target_classes, num_iou_thresholds)`
+         iou_thresholds (np.ndarray): the IoU thresholds used in the calculations.
+         matched_classes (np.ndarray): the class IDs of all matched classes.
+             Corresponds to the rows of `f1_per_class`.
+         small_objects (Optional[F1ScoreResult]): the F1 metric results
+             for small objects (area < 32²).
+         medium_objects (Optional[F1ScoreResult]): the F1 metric results
+             for medium objects (32² ≤ area < 96²).
+         large_objects (Optional[F1ScoreResult]): the F1 metric results
+             for large objects (area ≥ 96²).
+     """
+
+     metric_target: MetricTarget
+     averaging_method: AveragingMethod
+
+     @property
+     def f1_50(self) -> float:
+         return self.f1_scores[0]
+
+     @property
+     def f1_75(self) -> float:
+         return self.f1_scores[5]
+
+     f1_scores: np.ndarray
+     f1_per_class: np.ndarray
+     iou_thresholds: np.ndarray
+     matched_classes: np.ndarray
+
+     small_objects: Optional[F1ScoreResult]
+     medium_objects: Optional[F1ScoreResult]
+     large_objects: Optional[F1ScoreResult]
+
+     def __str__(self) -> str:
+         """
+         Format as a pretty string.
+
+         Example:
+             ```python
+             print(f1_result)
+             # F1ScoreResult:
+             # Metric target: MetricTarget.BOXES
+             # Averaging method: AveragingMethod.WEIGHTED
+             # F1 @ 50: 0.7618
+             # F1 @ 75: 0.7487
+             # F1 @ thresh: [0.76175 0.76068 0.76068]
+             # IoU thresh: [0.5 0.55 0.6 ...]
+             # F1 per class:
+             #   0: [0.70968 0.70968 0.70968 ...]
+             #   ...
+             # Small objects: ...
+             # Medium objects: ...
+             # Large objects: ...
+             ```
+         """
+         out_str = (
+             f"{self.__class__.__name__}:\n"
+             f"Metric target: {self.metric_target}\n"
+             f"Averaging method: {self.averaging_method}\n"
+             f"F1 @ 50: {self.f1_50:.4f}\n"
+             f"F1 @ 75: {self.f1_75:.4f}\n"
+             f"F1 @ thresh: {self.f1_scores}\n"
+             f"IoU thresh: {self.iou_thresholds}\n"
+             f"F1 per class:\n"
+         )
+         if self.f1_per_class.size == 0:
+             out_str += "  No results\n"
+         for class_id, f1_of_class in zip(self.matched_classes, self.f1_per_class):
+             out_str += f"  {class_id}: {f1_of_class}\n"
+
+         indent = "  "
+         if self.small_objects is not None:
+             indented = indent + str(self.small_objects).replace("\n", f"\n{indent}")
+             out_str += f"\nSmall objects:\n{indented}"
+         if self.medium_objects is not None:
+             indented = indent + str(self.medium_objects).replace("\n", f"\n{indent}")
+             out_str += f"\nMedium objects:\n{indented}"
+         if self.large_objects is not None:
+             indented = indent + str(self.large_objects).replace("\n", f"\n{indent}")
+             out_str += f"\nLarge objects:\n{indented}"
+
+         return out_str
+
+     def to_pandas(self) -> "pd.DataFrame":
+         """
+         Convert the result to a pandas DataFrame.
+
+         Returns:
+             (pd.DataFrame): The result as a DataFrame.
+         """
+         ensure_pandas_installed()
+         import pandas as pd
+
+         pandas_data = {
+             "F1@50": self.f1_50,
+             "F1@75": self.f1_75,
+         }
+
+         if self.small_objects is not None:
+             small_objects_df = self.small_objects.to_pandas()
+             for key, value in small_objects_df.items():
+                 pandas_data[f"small_objects_{key}"] = value
+         if self.medium_objects is not None:
+             medium_objects_df = self.medium_objects.to_pandas()
+             for key, value in medium_objects_df.items():
+                 pandas_data[f"medium_objects_{key}"] = value
+         if self.large_objects is not None:
+             large_objects_df = self.large_objects.to_pandas()
+             for key, value in large_objects_df.items():
+                 pandas_data[f"large_objects_{key}"] = value
+
+         return pd.DataFrame(pandas_data, index=[0])
+
+     def plot(self):
+         """
+         Plot the F1 results.
+
+         ![example_plot](\
+         https://media.roboflow.com/eye-docs/metrics/f1_plot_example.png\
+         ){ align=center width="800" }
+         """
+
+         labels = ["F1@50", "F1@75"]
+         values = [self.f1_50, self.f1_75]
+         colors = [LEGACY_COLOR_PALETTE[0]] * 2
+
+         if self.small_objects is not None:
+             small_objects = self.small_objects
+             labels += ["Small: F1@50", "Small: F1@75"]
+             values += [small_objects.f1_50, small_objects.f1_75]
+             colors += [LEGACY_COLOR_PALETTE[3]] * 2
+
+         if self.medium_objects is not None:
+             medium_objects = self.medium_objects
+             labels += ["Medium: F1@50", "Medium: F1@75"]
+             values += [medium_objects.f1_50, medium_objects.f1_75]
+             colors += [LEGACY_COLOR_PALETTE[2]] * 2
+
+         if self.large_objects is not None:
+             large_objects = self.large_objects
+             labels += ["Large: F1@50", "Large: F1@75"]
+             values += [large_objects.f1_50, large_objects.f1_75]
+             colors += [LEGACY_COLOR_PALETTE[4]] * 2
+
+         plt.rcParams["font.family"] = "monospace"
+
+         _, ax = plt.subplots(figsize=(10, 6))
+         ax.set_ylim(0, 1)
+         ax.set_ylabel("Value", fontweight="bold")
+         title = (
+             f"F1 Score, by Object Size"
+             f"\n(target: {self.metric_target.value},"
+             f" averaging: {self.averaging_method.value})"
+         )
+         ax.set_title(title, fontweight="bold")
+
+         x_positions = range(len(labels))
+         bars = ax.bar(x_positions, values, color=colors, align="center")
+
+         ax.set_xticks(x_positions)
+         ax.set_xticklabels(labels, rotation=45, ha="right")
+
+         for bar in bars:
+             y_value = bar.get_height()
+             ax.text(
+                 bar.get_x() + bar.get_width() / 2,
+                 y_value + 0.02,
+                 f"{y_value:.2f}",
+                 ha="center",
+                 va="bottom",
+             )
+
+         plt.rcParams["font.family"] = "sans-serif"
+
+         plt.tight_layout()
+         plt.show()
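
The hunk above reproduces, under the `eye` namespace, the `F1Score` metric found in the upstream `supervision` library (compare `eye/__init___supervision_original.py` in the file list). For context, a minimal usage sketch of the API shown in the diff; the boxes, class IDs, and confidences are invented illustration data, and it assumes `sv.Detections` accepts `xyxy`/`class_id`/`confidence` arrays exactly as the upstream API does:

```python
import numpy as np

import eye as sv
from eye.metrics import F1Score

# Hypothetical ground truth: two class-0 boxes.
targets = sv.Detections(
    xyxy=np.array([[10, 10, 50, 50], [60, 60, 120, 120]], dtype=np.float32),
    class_id=np.array([0, 0]),
)

# Hypothetical predictions: one tight match (IoU ~0.86) and one stray box.
predictions = sv.Detections(
    xyxy=np.array([[12, 11, 49, 52], [200, 200, 260, 260]], dtype=np.float32),
    class_id=np.array([0, 0]),
    confidence=np.array([0.9, 0.4]),
)

f1_result = F1Score().update(predictions, targets).compute()

# At IoU 0.5 this yields TP=1, FP=1, FN=1, so the zero-division-safe
# formula in _compute_f1 gives F1 = 2*TP / (2*TP + FP + FN) = 2/4 = 0.5.
print(f1_result.f1_50)
# f1_result.plot()  # bar chart of F1@50/75, overall and by object size
```

The identity used in `_compute_f1`, `2TP / (2TP + FP + FN)`, equals the harmonic mean `2PR / (P + R)` whenever the denominator is non-zero, which is why the implementation needs only a single zero-division guard instead of separate checks for precision and recall.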