eye-cv 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eye/__init__.py +115 -0
- eye/__init___supervision_original.py +120 -0
- eye/annotators/__init__.py +0 -0
- eye/annotators/base.py +22 -0
- eye/annotators/core.py +2699 -0
- eye/annotators/line.py +107 -0
- eye/annotators/modern.py +529 -0
- eye/annotators/trace.py +142 -0
- eye/annotators/utils.py +177 -0
- eye/assets/__init__.py +2 -0
- eye/assets/downloader.py +95 -0
- eye/assets/list.py +83 -0
- eye/classification/__init__.py +0 -0
- eye/classification/core.py +188 -0
- eye/config.py +2 -0
- eye/core/__init__.py +0 -0
- eye/core/trackers/__init__.py +1 -0
- eye/core/trackers/botsort_tracker.py +336 -0
- eye/core/trackers/bytetrack_tracker.py +284 -0
- eye/core/trackers/sort_tracker.py +200 -0
- eye/core/tracking.py +146 -0
- eye/dataset/__init__.py +0 -0
- eye/dataset/core.py +919 -0
- eye/dataset/formats/__init__.py +0 -0
- eye/dataset/formats/coco.py +258 -0
- eye/dataset/formats/pascal_voc.py +279 -0
- eye/dataset/formats/yolo.py +272 -0
- eye/dataset/utils.py +259 -0
- eye/detection/__init__.py +0 -0
- eye/detection/auto_convert.py +155 -0
- eye/detection/core.py +1529 -0
- eye/detection/detections_enhanced.py +392 -0
- eye/detection/line_zone.py +859 -0
- eye/detection/lmm.py +184 -0
- eye/detection/overlap_filter.py +270 -0
- eye/detection/tools/__init__.py +0 -0
- eye/detection/tools/csv_sink.py +181 -0
- eye/detection/tools/inference_slicer.py +288 -0
- eye/detection/tools/json_sink.py +142 -0
- eye/detection/tools/polygon_zone.py +202 -0
- eye/detection/tools/smoother.py +123 -0
- eye/detection/tools/smoothing.py +179 -0
- eye/detection/tools/smoothing_config.py +202 -0
- eye/detection/tools/transformers.py +247 -0
- eye/detection/utils.py +1175 -0
- eye/draw/__init__.py +0 -0
- eye/draw/color.py +154 -0
- eye/draw/utils.py +374 -0
- eye/filters.py +112 -0
- eye/geometry/__init__.py +0 -0
- eye/geometry/core.py +128 -0
- eye/geometry/utils.py +47 -0
- eye/keypoint/__init__.py +0 -0
- eye/keypoint/annotators.py +442 -0
- eye/keypoint/core.py +687 -0
- eye/keypoint/skeletons.py +2647 -0
- eye/metrics/__init__.py +21 -0
- eye/metrics/core.py +72 -0
- eye/metrics/detection.py +843 -0
- eye/metrics/f1_score.py +648 -0
- eye/metrics/mean_average_precision.py +628 -0
- eye/metrics/mean_average_recall.py +697 -0
- eye/metrics/precision.py +653 -0
- eye/metrics/recall.py +652 -0
- eye/metrics/utils/__init__.py +0 -0
- eye/metrics/utils/object_size.py +158 -0
- eye/metrics/utils/utils.py +9 -0
- eye/py.typed +0 -0
- eye/quick.py +104 -0
- eye/tracker/__init__.py +0 -0
- eye/tracker/byte_tracker/__init__.py +0 -0
- eye/tracker/byte_tracker/core.py +386 -0
- eye/tracker/byte_tracker/kalman_filter.py +205 -0
- eye/tracker/byte_tracker/matching.py +69 -0
- eye/tracker/byte_tracker/single_object_track.py +178 -0
- eye/tracker/byte_tracker/utils.py +18 -0
- eye/utils/__init__.py +0 -0
- eye/utils/conversion.py +132 -0
- eye/utils/file.py +159 -0
- eye/utils/image.py +794 -0
- eye/utils/internal.py +200 -0
- eye/utils/iterables.py +84 -0
- eye/utils/notebook.py +114 -0
- eye/utils/video.py +307 -0
- eye/utils_eye/__init__.py +1 -0
- eye/utils_eye/geometry.py +71 -0
- eye/utils_eye/nms.py +55 -0
- eye/validators/__init__.py +140 -0
- eye/web.py +271 -0
- eye_cv-1.0.0.dist-info/METADATA +319 -0
- eye_cv-1.0.0.dist-info/RECORD +94 -0
- eye_cv-1.0.0.dist-info/WHEEL +5 -0
- eye_cv-1.0.0.dist-info/licenses/LICENSE +21 -0
- eye_cv-1.0.0.dist-info/top_level.txt +1 -0
eye/metrics/recall.py
ADDED
|
@@ -0,0 +1,652 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from copy import deepcopy
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
from matplotlib import pyplot as plt
|
|
9
|
+
|
|
10
|
+
from eye.config import ORIENTED_BOX_COORDINATES
|
|
11
|
+
from eye.detection.core import Detections
|
|
12
|
+
from eye.detection.utils import (
|
|
13
|
+
box_iou_batch,
|
|
14
|
+
mask_iou_batch,
|
|
15
|
+
oriented_box_iou_batch,
|
|
16
|
+
)
|
|
17
|
+
from eye.draw.color import LEGACY_COLOR_PALETTE
|
|
18
|
+
from eye.metrics.core import AveragingMethod, Metric, MetricTarget
|
|
19
|
+
from eye.metrics.utils.object_size import (
|
|
20
|
+
ObjectSizeCategory,
|
|
21
|
+
get_detection_size_category,
|
|
22
|
+
)
|
|
23
|
+
from eye.metrics.utils.utils import ensure_pandas_installed
|
|
24
|
+
|
|
25
|
+
if TYPE_CHECKING:
|
|
26
|
+
import pandas as pd
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class Recall(Metric):
    """
    Recall is a metric used to evaluate object detection models. It is the ratio of
    true positive detections to the total number of ground truth instances. We calculate
    it at different IoU thresholds.

    In simple terms, Recall is a measure of a model's completeness, calculated as:

    `Recall = TP / (TP + FN)`

    Here, `TP` is the number of true positives (correct detections), and `FN` is the
    number of false negatives (missed detections).

    Example:
        ```python
        import eye as sv
        from eye.metrics import Recall

        predictions = sv.Detections(...)
        targets = sv.Detections(...)

        recall_metric = Recall()
        recall_result = recall_metric.update(predictions, targets).compute()

        print(recall_result.recall_at_50)
        # 0.7615

        print(recall_result)
        # RecallResult:
        # Metric target: MetricTarget.BOXES
        # Averaging method: AveragingMethod.WEIGHTED
        # R @ 50: 0.7615
        # R @ 75: 0.7462
        # R @ thresh: [0.76151 0.76011 0.76011 0.75732 ...]
        # IoU thresh: [0.5 0.55 0.6 ...]
        # Recall per class:
        # 0: [0.78571 0.78571 0.78571 ...]
        # ...
        # Small objects: ...
        # Medium objects: ...
        # Large objects: ...

        recall_result.plot()
        ```
    """

    def __init__(
        self,
        metric_target: MetricTarget = MetricTarget.BOXES,
        averaging_method: AveragingMethod = AveragingMethod.WEIGHTED,
    ):
        """
        Initialize the Recall metric.

        Args:
            metric_target (MetricTarget): The type of detection data to use.
            averaging_method (AveragingMethod): The averaging method used to compute the
                recall. Determines how the recall is aggregated across classes.
        """
        self._metric_target = metric_target
        self.averaging_method = averaging_method

        self._predictions_list: List[Detections] = []
        self._targets_list: List[Detections] = []

    def reset(self) -> None:
        """
        Reset the metric to its initial state, clearing all stored data.
        """
        self._predictions_list = []
        self._targets_list = []

    def update(
        self,
        predictions: Union[Detections, List[Detections]],
        targets: Union[Detections, List[Detections]],
    ) -> Recall:
        """
        Add new predictions and targets to the metric, but do not compute the result.

        Args:
            predictions (Union[Detections, List[Detections]]): The predicted detections.
            targets (Union[Detections, List[Detections]]): The target detections.

        Returns:
            (Recall): The updated metric instance.

        Raises:
            ValueError: If the number of predictions and targets differ.
        """
        if not isinstance(predictions, list):
            predictions = [predictions]
        if not isinstance(targets, list):
            targets = [targets]

        if len(predictions) != len(targets):
            raise ValueError(
                f"The number of predictions ({len(predictions)}) and"
                f" targets ({len(targets)}) during the update must be the same."
            )

        self._predictions_list.extend(predictions)
        self._targets_list.extend(targets)

        return self

    def compute(self) -> RecallResult:
        """
        Calculate the recall metric based on the stored predictions and ground-truth
        data, at different IoU thresholds.

        The overall result is extended with per-size results for small, medium
        and large objects.

        Returns:
            (RecallResult): The recall metric result.
        """
        result = self._compute(self._predictions_list, self._targets_list)
        result.small_objects = self._compute_for_size(ObjectSizeCategory.SMALL)
        result.medium_objects = self._compute_for_size(ObjectSizeCategory.MEDIUM)
        result.large_objects = self._compute_for_size(ObjectSizeCategory.LARGE)
        return result

    def _compute_for_size(self, size_category: ObjectSizeCategory) -> RecallResult:
        """Compute recall using only predictions/targets of one size category."""
        predictions_list, targets_list = self._filter_predictions_and_targets_by_size(
            self._predictions_list, self._targets_list, size_category
        )
        return self._compute(predictions_list, targets_list)

    def _compute(
        self, predictions_list: List[Detections], targets_list: List[Detections]
    ) -> RecallResult:
        """
        Compute recall over paired lists of predictions and targets.

        Args:
            predictions_list (List[Detections]): Predictions, one entry per image.
            targets_list (List[Detections]): Targets, aligned with
                `predictions_list`.

        Returns:
            (RecallResult): Recall at IoU thresholds 0.50:0.05:0.95, without
                per-size sub-results. All-zero scores if no image has targets.
        """
        iou_thresholds = np.linspace(0.5, 0.95, 10)
        stats = []

        for predictions, targets in zip(predictions_list, targets_list):
            # Without targets there are no TPs or FNs, so recall is unaffected.
            if len(targets) == 0:
                continue

            if len(predictions) == 0:
                # Every target is missed: contribute only the target classes.
                stats.append(
                    (
                        np.zeros((0, iou_thresholds.size), dtype=bool),
                        np.zeros((0,), dtype=np.float32),
                        np.zeros((0,), dtype=int),
                        targets.class_id,
                    )
                )
                continue

            iou = self._compute_iou(
                self._detections_content(targets),
                self._detections_content(predictions),
            )
            matches = self._match_detection_batch(
                predictions.class_id, targets.class_id, iou, iou_thresholds
            )

            confidence = predictions.confidence
            if confidence is None:
                # Confidence only orders predictions; recall sums over all of
                # them, so a constant placeholder gives the same result.
                confidence = np.ones(len(predictions), dtype=np.float32)

            stats.append((matches, confidence, predictions.class_id, targets.class_id))

        if not stats:
            return RecallResult(
                metric_target=self._metric_target,
                averaging_method=self.averaging_method,
                recall_scores=np.zeros(iou_thresholds.shape[0]),
                recall_per_class=np.zeros((0, iou_thresholds.shape[0])),
                iou_thresholds=iou_thresholds,
                matched_classes=np.array([], dtype=int),
                small_objects=None,
                medium_objects=None,
                large_objects=None,
            )

        concatenated_stats = [np.concatenate(items, 0) for items in zip(*stats)]
        recall_scores, recall_per_class, unique_classes = (
            self._compute_recall_for_classes(*concatenated_stats)
        )

        return RecallResult(
            metric_target=self._metric_target,
            averaging_method=self.averaging_method,
            recall_scores=recall_scores,
            recall_per_class=recall_per_class,
            iou_thresholds=iou_thresholds,
            matched_classes=unique_classes,
            small_objects=None,
            medium_objects=None,
            large_objects=None,
        )

    def _compute_iou(
        self, target_contents: np.ndarray, prediction_contents: np.ndarray
    ) -> np.ndarray:
        """
        Compute the IoU matrix between targets (rows) and predictions (columns)
        using the IoU function matching the configured metric target.

        Raises:
            ValueError: If the metric target is not supported.
        """
        if self._metric_target == MetricTarget.BOXES:
            return box_iou_batch(target_contents, prediction_contents)
        if self._metric_target == MetricTarget.MASKS:
            return mask_iou_batch(target_contents, prediction_contents)
        if self._metric_target == MetricTarget.ORIENTED_BOUNDING_BOXES:
            return oriented_box_iou_batch(target_contents, prediction_contents)
        raise ValueError("Unsupported metric target for IoU calculation")

    def _compute_recall_for_classes(
        self,
        matches: np.ndarray,
        prediction_confidence: np.ndarray,
        prediction_class_ids: np.ndarray,
        true_class_ids: np.ndarray,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Aggregate per-prediction match flags into recall scores.

        Args:
            matches (np.ndarray): Bool array (P, Th), True where a prediction is a
                true positive at the given IoU threshold.
            prediction_confidence (np.ndarray): Confidence per prediction (P,).
            prediction_class_ids (np.ndarray): Class id per prediction (P,).
            true_class_ids (np.ndarray): Class id per target instance.

        Returns:
            Tuple of recall per threshold (Th,), recall per class (C, Th), and
            the unique target class ids (C,).

        Raises:
            ValueError: If `averaging_method` is not MACRO, MICRO or WEIGHTED.
        """
        sorted_indices = np.argsort(-prediction_confidence)
        matches = matches[sorted_indices]
        prediction_class_ids = prediction_class_ids[sorted_indices]
        unique_classes, class_counts = np.unique(true_class_ids, return_counts=True)

        # Shape: PxTh,P,C,C -> CxThx3
        confusion_matrix = self._compute_confusion_matrix(
            matches, prediction_class_ids, unique_classes, class_counts
        )

        # Shape: CxThx3 -> CxTh
        recall_per_class = self._compute_recall(confusion_matrix)

        # Shape: CxTh -> Th
        if self.averaging_method == AveragingMethod.MACRO:
            recall_scores = np.mean(recall_per_class, axis=0)
        elif self.averaging_method == AveragingMethod.MICRO:
            confusion_matrix_merged = confusion_matrix.sum(0)
            recall_scores = self._compute_recall(confusion_matrix_merged)
        elif self.averaging_method == AveragingMethod.WEIGHTED:
            class_counts = class_counts.astype(np.float32)
            recall_scores = np.average(recall_per_class, axis=0, weights=class_counts)
        else:
            raise ValueError(f"Unsupported averaging method: {self.averaging_method}")

        return recall_scores, recall_per_class, unique_classes

    @staticmethod
    def _match_detection_batch(
        predictions_classes: np.ndarray,
        target_classes: np.ndarray,
        iou: np.ndarray,
        iou_thresholds: np.ndarray,
    ) -> np.ndarray:
        """
        Mark which predictions are true positives at each IoU threshold.

        Args:
            predictions_classes (np.ndarray): Class id per prediction (P,).
            target_classes (np.ndarray): Class id per target (T,).
            iou (np.ndarray): IoU matrix of shape (T, P).
            iou_thresholds (np.ndarray): IoU thresholds (Th,).

        Returns:
            (np.ndarray): Bool array (P, Th).
        """
        num_predictions, num_iou_levels = (
            predictions_classes.shape[0],
            iou_thresholds.shape[0],
        )
        correct = np.zeros((num_predictions, num_iou_levels), dtype=bool)
        # (T, P): only same-class target/prediction pairs may match.
        correct_class = target_classes[:, None] == predictions_classes

        for i, iou_level in enumerate(iou_thresholds):
            matched_indices = np.where((iou >= iou_level) & correct_class)

            if matched_indices[0].shape[0]:
                # Rows of (target_idx, prediction_idx, iou).
                combined_indices = np.stack(matched_indices, axis=1)
                iou_values = iou[matched_indices][:, None]
                matches = np.hstack([combined_indices, iou_values])

                if matched_indices[0].shape[0] > 1:
                    # Sort by IoU descending so each prediction keeps its
                    # highest-IoU target, then keep one pair per target.
                    matches = matches[matches[:, 2].argsort()[::-1]]
                    matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
                    matches = matches[np.unique(matches[:, 0], return_index=True)[1]]

                correct[matches[:, 1].astype(int), i] = True

        return correct

    @staticmethod
    def _compute_confusion_matrix(
        sorted_matches: np.ndarray,
        sorted_prediction_class_ids: np.ndarray,
        unique_classes: np.ndarray,
        class_counts: np.ndarray,
    ) -> np.ndarray:
        """
        Compute the confusion matrix for each class and IoU threshold.

        Assumes the matches and prediction_class_ids are sorted by confidence
        in descending order.

        Arguments:
            sorted_matches: np.ndarray, bool, shape (P, Th), that is True
                if the prediction is a true positive at the given IoU threshold.
            sorted_prediction_class_ids: np.ndarray, int, shape (P,), containing
                the class id for each prediction.
            unique_classes: np.ndarray, int, shape (C,), containing the unique
                class ids.
            class_counts: np.ndarray, int, shape (C,), containing the number
                of true instances for each class.

        Returns:
            np.ndarray, shape (C, Th, 3), containing the true positives, false
                positives, and false negatives for each class and IoU threshold.
        """
        num_thresholds = sorted_matches.shape[1]
        num_classes = unique_classes.shape[0]

        confusion_matrix = np.zeros((num_classes, num_thresholds, 3))
        for class_idx, class_id in enumerate(unique_classes):
            is_class = sorted_prediction_class_ids == class_id
            num_true = class_counts[class_idx]
            num_predictions = is_class.sum()

            if num_predictions == 0:
                true_positives = np.zeros(num_thresholds)
                false_positives = np.zeros(num_thresholds)
                false_negatives = np.full(num_thresholds, num_true)
            elif num_true == 0:
                true_positives = np.zeros(num_thresholds)
                false_positives = np.full(num_thresholds, num_predictions)
                false_negatives = np.zeros(num_thresholds)
            else:
                true_positives = sorted_matches[is_class].sum(0)
                false_positives = (1 - sorted_matches[is_class]).sum(0)
                false_negatives = num_true - true_positives
            confusion_matrix[class_idx] = np.stack(
                [true_positives, false_positives, false_negatives], axis=1
            )

        return confusion_matrix

    @staticmethod
    def _compute_recall(confusion_matrix: np.ndarray) -> np.ndarray:
        """
        Broadcastable function, computing the recall from the confusion matrix.

        Arguments:
            confusion_matrix: np.ndarray, shape (N, ..., 3), where the last dimension
                contains the true positives, false positives, and false negatives.

        Returns:
            np.ndarray, shape (N, ...), containing the recall for each element.
                Entries with no true positives and no false negatives are 0.

        Raises:
            ValueError: If the last dimension of `confusion_matrix` is not 3.
        """
        if not confusion_matrix.shape[-1] == 3:
            raise ValueError(
                f"Confusion matrix must have shape (..., 3), got "
                f"{confusion_matrix.shape}"
            )
        true_positives = confusion_matrix[..., 0]
        false_negatives = confusion_matrix[..., 2]

        denominator = true_positives + false_negatives
        # Divide only where the denominator is non-zero to avoid RuntimeWarnings;
        # the remaining entries keep the initial 0.
        recall = np.zeros(np.shape(denominator), dtype=np.float64)
        np.divide(true_positives, denominator, out=recall, where=denominator != 0)

        return recall

    def _detections_content(self, detections: Detections) -> np.ndarray:
        """
        Return boxes, masks or oriented bounding boxes from detections.

        An empty array is returned for masks / oriented boxes that are missing.

        Raises:
            ValueError: If the metric target is invalid.
        """
        if self._metric_target == MetricTarget.BOXES:
            return detections.xyxy
        if self._metric_target == MetricTarget.MASKS:
            return (
                detections.mask
                if detections.mask is not None
                else self._make_empty_content()
            )
        if self._metric_target == MetricTarget.ORIENTED_BOUNDING_BOXES:
            obb = detections.data.get(ORIENTED_BOX_COORDINATES)
            if obb is not None and len(obb) > 0:
                return np.array(obb, dtype=np.float32)
            return self._make_empty_content()
        raise ValueError(f"Invalid metric target: {self._metric_target}")

    def _make_empty_content(self) -> np.ndarray:
        """Return an empty content array shaped for the current metric target."""
        if self._metric_target == MetricTarget.BOXES:
            return np.empty((0, 4), dtype=np.float32)
        if self._metric_target == MetricTarget.MASKS:
            return np.empty((0, 0, 0), dtype=bool)
        if self._metric_target == MetricTarget.ORIENTED_BOUNDING_BOXES:
            return np.empty((0, 4, 2), dtype=np.float32)
        raise ValueError(f"Invalid metric target: {self._metric_target}")

    def _filter_detections_by_size(
        self, detections: Detections, size_category: ObjectSizeCategory
    ) -> Detections:
        """Return a copy of detections with contents filtered by object size."""
        new_detections = deepcopy(detections)
        if detections.is_empty() or size_category == ObjectSizeCategory.ANY:
            return new_detections

        sizes = get_detection_size_category(new_detections, self._metric_target)
        size_mask = sizes == size_category.value

        new_detections.xyxy = new_detections.xyxy[size_mask]
        if new_detections.mask is not None:
            new_detections.mask = new_detections.mask[size_mask]
        if new_detections.class_id is not None:
            new_detections.class_id = new_detections.class_id[size_mask]
        if new_detections.confidence is not None:
            new_detections.confidence = new_detections.confidence[size_mask]
        if new_detections.tracker_id is not None:
            new_detections.tracker_id = new_detections.tracker_id[size_mask]
        if new_detections.data is not None:
            for key, value in new_detections.data.items():
                new_detections.data[key] = np.array(value)[size_mask]

        return new_detections

    def _filter_predictions_and_targets_by_size(
        self,
        predictions_list: List[Detections],
        targets_list: List[Detections],
        size_category: ObjectSizeCategory,
    ) -> Tuple[List[Detections], List[Detections]]:
        """
        Filter predictions and targets by object size category.

        Returns filtered copies; the input lists are not modified.
        """
        new_predictions_list = [
            self._filter_detections_by_size(predictions, size_category)
            for predictions in predictions_list
        ]
        new_targets_list = [
            self._filter_detections_by_size(targets, size_category)
            for targets in targets_list
        ]
        return new_predictions_list, new_targets_list
|
|
457
|
+
|
|
458
|
+
|
|
459
|
+
@dataclass
class RecallResult:
    """
    The results of the recall metric calculation.

    Defaults to `0` if no detections or targets were provided.

    Attributes:
        metric_target (MetricTarget): the type of data used for the metric -
            boxes, masks or oriented bounding boxes.
        averaging_method (AveragingMethod): the averaging method used to compute the
            recall. Determines how the recall is aggregated across classes.
        recall_at_50 (float): the recall at IoU threshold of `0.5`.
        recall_at_75 (float): the recall at IoU threshold of `0.75`.
        recall_scores (np.ndarray): the recall scores at each IoU threshold.
            Shape: `(num_iou_thresholds,)`
        recall_per_class (np.ndarray): the recall scores per class and IoU threshold.
            Shape: `(num_target_classes, num_iou_thresholds)`
        iou_thresholds (np.ndarray): the IoU thresholds used in the calculations.
        matched_classes (np.ndarray): the class IDs of all matched classes.
            Corresponds to the rows of `recall_per_class`.
        small_objects (Optional[RecallResult]): the Recall metric results
            for small objects (area < 32²).
        medium_objects (Optional[RecallResult]): the Recall metric results
            for medium objects (32² ≤ area < 96²).
        large_objects (Optional[RecallResult]): the Recall metric results
            for large objects (area ≥ 96²).
    """

    metric_target: MetricTarget
    averaging_method: AveragingMethod
    recall_scores: np.ndarray
    recall_per_class: np.ndarray
    iou_thresholds: np.ndarray
    matched_classes: np.ndarray

    small_objects: Optional[RecallResult]
    medium_objects: Optional[RecallResult]
    large_objects: Optional[RecallResult]

    @property
    def recall_at_50(self) -> float:
        """Recall at the first IoU threshold (0.5 with the default thresholds)."""
        return self.recall_scores[0]

    @property
    def recall_at_75(self) -> float:
        """Recall at the sixth IoU threshold (0.75 with the default thresholds)."""
        return self.recall_scores[5]

    def __str__(self) -> str:
        """
        Format as a pretty string.

        Example:
            ```python
            print(recall_result)
            # RecallResult:
            # Metric target: MetricTarget.BOXES
            # Averaging method: AveragingMethod.WEIGHTED
            # R @ 50: 0.7615
            # R @ 75: 0.7462
            # R @ thresh: [0.76151 0.76011 0.76011 0.75732 ...]
            # IoU thresh: [0.5 0.55 0.6 ...]
            # Recall per class:
            # 0: [0.78571 0.78571 0.78571 ...]
            # ...
            # Small objects: ...
            # Medium objects: ...
            # Large objects: ...
            ```
        """
        out_str = (
            f"{self.__class__.__name__}:\n"
            f"Metric target: {self.metric_target}\n"
            f"Averaging method: {self.averaging_method}\n"
            f"R @ 50: {self.recall_at_50:.4f}\n"
            f"R @ 75: {self.recall_at_75:.4f}\n"
            f"R @ thresh: {self.recall_scores}\n"
            f"IoU thresh: {self.iou_thresholds}\n"
            f"Recall per class:\n"
        )
        if self.recall_per_class.size == 0:
            out_str += " No results\n"
        for class_id, recall_of_class in zip(
            self.matched_classes, self.recall_per_class
        ):
            out_str += f" {class_id}: {recall_of_class}\n"

        indent = " "
        size_results = (
            ("Small objects", self.small_objects),
            ("Medium objects", self.medium_objects),
            ("Large objects", self.large_objects),
        )
        for title, size_result in size_results:
            if size_result is None:
                continue
            # Indent every line of the nested result under its heading.
            indented = indent + str(size_result).replace("\n", f"\n{indent}")
            out_str += f"\n{title}:\n{indented}"

        return out_str

    def to_pandas(self) -> "pd.DataFrame":
        """
        Convert the result to a pandas DataFrame.

        Per-size results, when present, are added as columns prefixed with
        `small_objects_`, `medium_objects_` and `large_objects_`.

        Returns:
            (pd.DataFrame): The result as a DataFrame.
        """
        ensure_pandas_installed()
        import pandas as pd

        pandas_data = {
            "R@50": self.recall_at_50,
            "R@75": self.recall_at_75,
        }

        size_results = (
            ("small_objects", self.small_objects),
            ("medium_objects", self.medium_objects),
            ("large_objects", self.large_objects),
        )
        for prefix, size_result in size_results:
            if size_result is None:
                continue
            for key, value in size_result.to_pandas().items():
                pandas_data[f"{prefix}_{key}"] = value

        return pd.DataFrame(pandas_data, index=[0])

    def plot(self):
        """
        Plot the recall results as a bar chart.

        Shows recall at IoU 0.5 and 0.75 overall and, when available, for
        small, medium and large objects. The matplotlib `font.family` setting
        is temporarily switched to monospace while the figure is built, then
        restored to the caller's previous value.
        """
        labels = ["Recall@50", "Recall@75"]
        values = [self.recall_at_50, self.recall_at_75]
        colors = [LEGACY_COLOR_PALETTE[0]] * 2

        size_results = (
            ("Small", self.small_objects, 3),
            ("Medium", self.medium_objects, 2),
            ("Large", self.large_objects, 4),
        )
        for prefix, size_result, color_index in size_results:
            if size_result is None:
                continue
            labels += [f"{prefix}: R@50", f"{prefix}: R@75"]
            values += [size_result.recall_at_50, size_result.recall_at_75]
            colors += [LEGACY_COLOR_PALETTE[color_index]] * 2

        # Save and restore the global font setting instead of overwriting it,
        # even if building the figure fails.
        previous_font_family = plt.rcParams["font.family"]
        plt.rcParams["font.family"] = "monospace"
        try:
            _, ax = plt.subplots(figsize=(10, 6))
            ax.set_ylim(0, 1)
            ax.set_ylabel("Value", fontweight="bold")
            title = (
                f"Recall, by Object Size"
                f"\n(target: {self.metric_target.value},"
                f" averaging: {self.averaging_method.value})"
            )
            ax.set_title(title, fontweight="bold")

            x_positions = range(len(labels))
            bars = ax.bar(x_positions, values, color=colors, align="center")

            ax.set_xticks(x_positions)
            ax.set_xticklabels(labels, rotation=45, ha="right")

            for bar in bars:
                y_value = bar.get_height()
                ax.text(
                    bar.get_x() + bar.get_width() / 2,
                    y_value + 0.02,
                    f"{y_value:.2f}",
                    ha="center",
                    va="bottom",
                )
        finally:
            plt.rcParams["font.family"] = previous_font_family

        plt.tight_layout()
        plt.show()
|
|
File without changes
|