ultralytics 8.3.89__py3-none-any.whl → 8.3.91__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/conftest.py +2 -2
- tests/test_cli.py +13 -11
- tests/test_cuda.py +10 -1
- tests/test_exports.py +2 -2
- tests/test_integrations.py +1 -5
- tests/test_python.py +16 -16
- tests/test_solutions.py +9 -9
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +3 -1
- ultralytics/cfg/models/11/yolo11-cls.yaml +5 -5
- ultralytics/cfg/models/11/yolo11-obb.yaml +5 -5
- ultralytics/cfg/models/11/yolo11-pose.yaml +5 -5
- ultralytics/cfg/models/11/yolo11-seg.yaml +5 -5
- ultralytics/cfg/models/11/yolo11.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-p6.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-world.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8.yaml +5 -5
- ultralytics/cfg/models/v9/yolov9c-seg.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9c.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9e-seg.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9e.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9m.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9t.yaml +1 -1
- ultralytics/data/annotator.py +9 -14
- ultralytics/data/base.py +118 -30
- ultralytics/data/build.py +63 -24
- ultralytics/data/converter.py +5 -5
- ultralytics/data/dataset.py +207 -53
- ultralytics/data/loaders.py +1 -0
- ultralytics/data/split_dota.py +39 -12
- ultralytics/data/utils.py +15 -19
- ultralytics/engine/exporter.py +24 -23
- ultralytics/engine/model.py +67 -88
- ultralytics/engine/predictor.py +106 -21
- ultralytics/engine/trainer.py +32 -23
- ultralytics/engine/tuner.py +21 -18
- ultralytics/engine/validator.py +75 -41
- ultralytics/hub/__init__.py +12 -13
- ultralytics/hub/auth.py +9 -12
- ultralytics/hub/session.py +76 -21
- ultralytics/hub/utils.py +19 -17
- ultralytics/models/fastsam/model.py +20 -11
- ultralytics/models/fastsam/predict.py +36 -16
- ultralytics/models/fastsam/utils.py +5 -5
- ultralytics/models/fastsam/val.py +6 -6
- ultralytics/models/nas/model.py +22 -11
- ultralytics/models/nas/predict.py +9 -4
- ultralytics/models/nas/val.py +5 -5
- ultralytics/models/rtdetr/model.py +20 -11
- ultralytics/models/rtdetr/predict.py +18 -15
- ultralytics/models/rtdetr/train.py +20 -16
- ultralytics/models/rtdetr/val.py +42 -6
- ultralytics/models/sam/__init__.py +1 -1
- ultralytics/models/sam/amg.py +50 -4
- ultralytics/models/sam/model.py +8 -14
- ultralytics/models/sam/modules/decoders.py +18 -21
- ultralytics/models/sam/modules/encoders.py +25 -46
- ultralytics/models/sam/modules/memory_attention.py +19 -15
- ultralytics/models/sam/modules/sam.py +18 -25
- ultralytics/models/sam/modules/tiny_encoder.py +19 -29
- ultralytics/models/sam/modules/transformer.py +35 -57
- ultralytics/models/sam/modules/utils.py +15 -15
- ultralytics/models/sam/predict.py +0 -3
- ultralytics/models/utils/loss.py +87 -36
- ultralytics/models/utils/ops.py +26 -31
- ultralytics/models/yolo/classify/predict.py +24 -3
- ultralytics/models/yolo/classify/train.py +77 -10
- ultralytics/models/yolo/classify/val.py +40 -15
- ultralytics/models/yolo/detect/predict.py +23 -10
- ultralytics/models/yolo/detect/train.py +85 -15
- ultralytics/models/yolo/detect/val.py +145 -21
- ultralytics/models/yolo/model.py +1 -2
- ultralytics/models/yolo/obb/predict.py +12 -4
- ultralytics/models/yolo/obb/train.py +7 -0
- ultralytics/models/yolo/obb/val.py +25 -7
- ultralytics/models/yolo/pose/predict.py +22 -6
- ultralytics/models/yolo/pose/train.py +17 -1
- ultralytics/models/yolo/pose/val.py +46 -21
- ultralytics/models/yolo/segment/predict.py +22 -8
- ultralytics/models/yolo/segment/train.py +6 -0
- ultralytics/models/yolo/segment/val.py +100 -14
- ultralytics/models/yolo/world/train.py +38 -8
- ultralytics/models/yolo/world/train_world.py +39 -10
- ultralytics/nn/autobackend.py +28 -14
- ultralytics/nn/modules/__init__.py +3 -0
- ultralytics/nn/modules/activation.py +12 -3
- ultralytics/nn/modules/block.py +587 -84
- ultralytics/nn/modules/conv.py +418 -54
- ultralytics/nn/modules/head.py +3 -4
- ultralytics/nn/modules/transformer.py +320 -34
- ultralytics/nn/modules/utils.py +17 -3
- ultralytics/nn/tasks.py +221 -69
- ultralytics/solutions/ai_gym.py +2 -2
- ultralytics/solutions/analytics.py +4 -4
- ultralytics/solutions/heatmap.py +4 -4
- ultralytics/solutions/instance_segmentation.py +10 -4
- ultralytics/solutions/object_blurrer.py +2 -2
- ultralytics/solutions/object_counter.py +2 -2
- ultralytics/solutions/object_cropper.py +2 -2
- ultralytics/solutions/parking_management.py +9 -9
- ultralytics/solutions/queue_management.py +1 -1
- ultralytics/solutions/region_counter.py +2 -2
- ultralytics/solutions/security_alarm.py +7 -7
- ultralytics/solutions/solutions.py +7 -4
- ultralytics/solutions/speed_estimation.py +2 -2
- ultralytics/solutions/streamlit_inference.py +6 -6
- ultralytics/solutions/trackzone.py +9 -2
- ultralytics/solutions/vision_eye.py +4 -4
- ultralytics/trackers/basetrack.py +1 -1
- ultralytics/trackers/bot_sort.py +23 -22
- ultralytics/trackers/byte_tracker.py +4 -4
- ultralytics/trackers/track.py +2 -1
- ultralytics/trackers/utils/gmc.py +26 -27
- ultralytics/trackers/utils/kalman_filter.py +31 -29
- ultralytics/trackers/utils/matching.py +7 -7
- ultralytics/utils/__init__.py +32 -27
- ultralytics/utils/autobatch.py +5 -5
- ultralytics/utils/benchmarks.py +111 -18
- ultralytics/utils/callbacks/base.py +3 -3
- ultralytics/utils/callbacks/clearml.py +11 -11
- ultralytics/utils/callbacks/comet.py +42 -24
- ultralytics/utils/callbacks/dvc.py +11 -10
- ultralytics/utils/callbacks/hub.py +8 -8
- ultralytics/utils/callbacks/mlflow.py +1 -1
- ultralytics/utils/callbacks/neptune.py +12 -10
- ultralytics/utils/callbacks/raytune.py +1 -1
- ultralytics/utils/callbacks/tensorboard.py +6 -6
- ultralytics/utils/callbacks/wb.py +16 -16
- ultralytics/utils/checks.py +116 -35
- ultralytics/utils/dist.py +15 -2
- ultralytics/utils/downloads.py +13 -9
- ultralytics/utils/files.py +12 -13
- ultralytics/utils/instance.py +112 -45
- ultralytics/utils/loss.py +28 -33
- ultralytics/utils/metrics.py +246 -181
- ultralytics/utils/ops.py +61 -53
- ultralytics/utils/patches.py +8 -6
- ultralytics/utils/plotting.py +65 -45
- ultralytics/utils/tal.py +88 -57
- ultralytics/utils/torch_utils.py +181 -33
- ultralytics/utils/triton.py +13 -3
- ultralytics/utils/tuner.py +8 -16
- {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/METADATA +1 -1
- ultralytics-8.3.91.dist-info/RECORD +250 -0
- ultralytics-8.3.89.dist-info/RECORD +0 -250
- {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/LICENSE +0 -0
- {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/WHEEL +0 -0
- {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/top_level.txt +0 -0
ultralytics/utils/metrics.py
CHANGED
@@ -25,7 +25,7 @@ def bbox_ioa(box1, box2, iou=False, eps=1e-7):
|
|
25
25
|
box1 (np.ndarray): A numpy array of shape (n, 4) representing n bounding boxes.
|
26
26
|
box2 (np.ndarray): A numpy array of shape (m, 4) representing m bounding boxes.
|
27
27
|
iou (bool): Calculate the standard IoU if True else return inter_area/box2_area.
|
28
|
-
eps (float, optional): A small value to avoid division by zero.
|
28
|
+
eps (float, optional): A small value to avoid division by zero.
|
29
29
|
|
30
30
|
Returns:
|
31
31
|
(np.ndarray): A numpy array of shape (n, m) representing the intersection over box2 area.
|
@@ -57,7 +57,7 @@ def box_iou(box1, box2, eps=1e-7):
|
|
57
57
|
Args:
|
58
58
|
box1 (torch.Tensor): A tensor of shape (N, 4) representing N bounding boxes.
|
59
59
|
box2 (torch.Tensor): A tensor of shape (M, 4) representing M bounding boxes.
|
60
|
-
eps (float, optional): A small value to avoid division by zero.
|
60
|
+
eps (float, optional): A small value to avoid division by zero.
|
61
61
|
|
62
62
|
Returns:
|
63
63
|
(torch.Tensor): An NxM tensor containing the pairwise IoU values for every element in box1 and box2.
|
@@ -73,7 +73,7 @@ def box_iou(box1, box2, eps=1e-7):
|
|
73
73
|
|
74
74
|
def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
|
75
75
|
"""
|
76
|
-
|
76
|
+
Calculate the Intersection over Union (IoU) between bounding boxes.
|
77
77
|
|
78
78
|
This function supports various shapes for `box1` and `box2` as long as the last dimension is 4.
|
79
79
|
For instance, you may pass tensors shaped like (4,), (N, 4), (B, N, 4), or (B, N, 1, 4).
|
@@ -84,11 +84,11 @@ def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7
|
|
84
84
|
box1 (torch.Tensor): A tensor representing one or more bounding boxes, with the last dimension being 4.
|
85
85
|
box2 (torch.Tensor): A tensor representing one or more bounding boxes, with the last dimension being 4.
|
86
86
|
xywh (bool, optional): If True, input boxes are in (x, y, w, h) format. If False, input boxes are in
|
87
|
-
(x1, y1, x2, y2) format.
|
88
|
-
GIoU (bool, optional): If True, calculate Generalized IoU.
|
89
|
-
DIoU (bool, optional): If True, calculate Distance IoU.
|
90
|
-
CIoU (bool, optional): If True, calculate Complete IoU.
|
91
|
-
eps (float, optional): A small value to avoid division by zero.
|
87
|
+
(x1, y1, x2, y2) format.
|
88
|
+
GIoU (bool, optional): If True, calculate Generalized IoU.
|
89
|
+
DIoU (bool, optional): If True, calculate Distance IoU.
|
90
|
+
CIoU (bool, optional): If True, calculate Complete IoU.
|
91
|
+
eps (float, optional): A small value to avoid division by zero.
|
92
92
|
|
93
93
|
Returns:
|
94
94
|
(torch.Tensor): IoU, GIoU, DIoU, or CIoU values depending on the specified flags.
|
@@ -143,7 +143,7 @@ def mask_iou(mask1, mask2, eps=1e-7):
|
|
143
143
|
product of image width and height.
|
144
144
|
mask2 (torch.Tensor): A tensor of shape (M, n) where M is the number of predicted objects and n is the
|
145
145
|
product of image width and height.
|
146
|
-
eps (float, optional): A small value to avoid division by zero.
|
146
|
+
eps (float, optional): A small value to avoid division by zero.
|
147
147
|
|
148
148
|
Returns:
|
149
149
|
(torch.Tensor): A tensor of shape (N, M) representing masks IoU.
|
@@ -162,7 +162,7 @@ def kpt_iou(kpt1, kpt2, area, sigma, eps=1e-7):
|
|
162
162
|
kpt2 (torch.Tensor): A tensor of shape (M, 17, 3) representing predicted keypoints.
|
163
163
|
area (torch.Tensor): A tensor of shape (N,) representing areas from ground truth.
|
164
164
|
sigma (list): A list containing 17 values representing keypoint scales.
|
165
|
-
eps (float, optional): A small value to avoid division by zero.
|
165
|
+
eps (float, optional): A small value to avoid division by zero.
|
166
166
|
|
167
167
|
Returns:
|
168
168
|
(torch.Tensor): A tensor of shape (N, M) representing keypoint similarities.
|
@@ -177,7 +177,7 @@ def kpt_iou(kpt1, kpt2, area, sigma, eps=1e-7):
|
|
177
177
|
|
178
178
|
def _get_covariance_matrix(boxes):
|
179
179
|
"""
|
180
|
-
|
180
|
+
Generate covariance matrix from oriented bounding boxes.
|
181
181
|
|
182
182
|
Args:
|
183
183
|
boxes (torch.Tensor): A tensor of shape (N, 5) representing rotated bounding boxes, with xywhr format.
|
@@ -199,20 +199,18 @@ def probiou(obb1, obb2, CIoU=False, eps=1e-7):
|
|
199
199
|
"""
|
200
200
|
Calculate probabilistic IoU between oriented bounding boxes.
|
201
201
|
|
202
|
-
Implements the algorithm from https://arxiv.org/pdf/2106.06072v1.pdf.
|
203
|
-
|
204
202
|
Args:
|
205
203
|
obb1 (torch.Tensor): Ground truth OBBs, shape (N, 5), format xywhr.
|
206
204
|
obb2 (torch.Tensor): Predicted OBBs, shape (N, 5), format xywhr.
|
207
|
-
CIoU (bool, optional): If True, calculate CIoU.
|
208
|
-
eps (float, optional): Small value to avoid division by zero.
|
205
|
+
CIoU (bool, optional): If True, calculate CIoU.
|
206
|
+
eps (float, optional): Small value to avoid division by zero.
|
209
207
|
|
210
208
|
Returns:
|
211
209
|
(torch.Tensor): OBB similarities, shape (N,).
|
212
210
|
|
213
|
-
|
214
|
-
OBB format: [center_x, center_y, width, height, rotation_angle].
|
215
|
-
|
211
|
+
Notes:
|
212
|
+
- OBB format: [center_x, center_y, width, height, rotation_angle].
|
213
|
+
- Implements the algorithm from https://arxiv.org/pdf/2106.06072v1.pdf.
|
216
214
|
"""
|
217
215
|
x1, y1 = obb1[..., :2].split(1, dim=-1)
|
218
216
|
x2, y2 = obb2[..., :2].split(1, dim=-1)
|
@@ -243,15 +241,18 @@ def probiou(obb1, obb2, CIoU=False, eps=1e-7):
|
|
243
241
|
|
244
242
|
def batch_probiou(obb1, obb2, eps=1e-7):
|
245
243
|
"""
|
246
|
-
Calculate the
|
244
|
+
Calculate the probabilistic IoU between oriented bounding boxes.
|
247
245
|
|
248
246
|
Args:
|
249
247
|
obb1 (torch.Tensor | np.ndarray): A tensor of shape (N, 5) representing ground truth obbs, with xywhr format.
|
250
248
|
obb2 (torch.Tensor | np.ndarray): A tensor of shape (M, 5) representing predicted obbs, with xywhr format.
|
251
|
-
eps (float, optional): A small value to avoid division by zero.
|
249
|
+
eps (float, optional): A small value to avoid division by zero.
|
252
250
|
|
253
251
|
Returns:
|
254
252
|
(torch.Tensor): A tensor of shape (N, M) representing obb similarities.
|
253
|
+
|
254
|
+
References:
|
255
|
+
https://arxiv.org/pdf/2106.06072v1.pdf
|
255
256
|
"""
|
256
257
|
obb1 = torch.from_numpy(obb1) if isinstance(obb1, np.ndarray) else obb1
|
257
258
|
obb2 = torch.from_numpy(obb2) if isinstance(obb2, np.ndarray) else obb2
|
@@ -277,16 +278,16 @@ def batch_probiou(obb1, obb2, eps=1e-7):
|
|
277
278
|
|
278
279
|
def smooth_bce(eps=0.1):
|
279
280
|
"""
|
280
|
-
|
281
|
-
|
282
|
-
This function calculates positive and negative label smoothing BCE targets based on a given epsilon value.
|
283
|
-
For implementation details, refer to https://github.com/ultralytics/yolov3/issues/238#issuecomment-598028441.
|
281
|
+
Compute smoothed positive and negative Binary Cross-Entropy targets.
|
284
282
|
|
285
283
|
Args:
|
286
|
-
eps (float, optional): The epsilon value for label smoothing.
|
284
|
+
eps (float, optional): The epsilon value for label smoothing.
|
287
285
|
|
288
286
|
Returns:
|
289
287
|
(tuple): A tuple containing the positive and negative label smoothing BCE targets.
|
288
|
+
|
289
|
+
References:
|
290
|
+
https://github.com/ultralytics/yolov3/issues/238#issuecomment-598028441
|
290
291
|
"""
|
291
292
|
return 1.0 - 0.5 * eps, 0.5 * eps
|
292
293
|
|
@@ -304,7 +305,15 @@ class ConfusionMatrix:
|
|
304
305
|
"""
|
305
306
|
|
306
307
|
def __init__(self, nc, conf=0.25, iou_thres=0.45, task="detect"):
|
307
|
-
"""
|
308
|
+
"""
|
309
|
+
Initialize a ConfusionMatrix instance.
|
310
|
+
|
311
|
+
Args:
|
312
|
+
nc (int): Number of classes.
|
313
|
+
conf (float, optional): Confidence threshold for detections.
|
314
|
+
iou_thres (float, optional): IoU threshold for matching detections to ground truth.
|
315
|
+
task (str, optional): Type of task, either 'detect' or 'classify'.
|
316
|
+
"""
|
308
317
|
self.task = task
|
309
318
|
self.matrix = np.zeros((nc + 1, nc + 1)) if self.task == "detect" else np.zeros((nc, nc))
|
310
319
|
self.nc = nc # number of classes
|
@@ -382,11 +391,16 @@ class ConfusionMatrix:
|
|
382
391
|
self.matrix[dc, self.nc] += 1 # predicted background
|
383
392
|
|
384
393
|
def matrix(self):
|
385
|
-
"""
|
394
|
+
"""Return the confusion matrix."""
|
386
395
|
return self.matrix
|
387
396
|
|
388
397
|
def tp_fp(self):
|
389
|
-
"""
|
398
|
+
"""
|
399
|
+
Return true positives and false positives.
|
400
|
+
|
401
|
+
Returns:
|
402
|
+
(tuple): True positives and false positives.
|
403
|
+
"""
|
390
404
|
tp = self.matrix.diagonal() # true positives
|
391
405
|
fp = self.matrix.sum(1) - tp # false positives
|
392
406
|
# fn = self.matrix.sum(0) - tp # false negatives (missed detections)
|
@@ -454,7 +468,17 @@ def smooth(y, f=0.05):
|
|
454
468
|
|
455
469
|
@plt_settings()
|
456
470
|
def plot_pr_curve(px, py, ap, save_dir=Path("pr_curve.png"), names={}, on_plot=None):
|
457
|
-
"""
|
471
|
+
"""
|
472
|
+
Plot precision-recall curve.
|
473
|
+
|
474
|
+
Args:
|
475
|
+
px (np.ndarray): X values for the PR curve.
|
476
|
+
py (np.ndarray): Y values for the PR curve.
|
477
|
+
ap (np.ndarray): Average precision values.
|
478
|
+
save_dir (Path, optional): Path to save the plot.
|
479
|
+
names (dict, optional): Dictionary mapping class indices to class names.
|
480
|
+
on_plot (callable, optional): Function to call after plot is saved.
|
481
|
+
"""
|
458
482
|
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
|
459
483
|
py = np.stack(py, axis=1)
|
460
484
|
|
@@ -479,7 +503,18 @@ def plot_pr_curve(px, py, ap, save_dir=Path("pr_curve.png"), names={}, on_plot=N
|
|
479
503
|
|
480
504
|
@plt_settings()
|
481
505
|
def plot_mc_curve(px, py, save_dir=Path("mc_curve.png"), names={}, xlabel="Confidence", ylabel="Metric", on_plot=None):
|
482
|
-
"""
|
506
|
+
"""
|
507
|
+
Plot metric-confidence curve.
|
508
|
+
|
509
|
+
Args:
|
510
|
+
px (np.ndarray): X values for the metric-confidence curve.
|
511
|
+
py (np.ndarray): Y values for the metric-confidence curve.
|
512
|
+
save_dir (Path, optional): Path to save the plot.
|
513
|
+
names (dict, optional): Dictionary mapping class indices to class names.
|
514
|
+
xlabel (str, optional): X-axis label.
|
515
|
+
ylabel (str, optional): Y-axis label.
|
516
|
+
on_plot (callable, optional): Function to call after plot is saved.
|
517
|
+
"""
|
483
518
|
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
|
484
519
|
|
485
520
|
if 0 < len(names) < 21: # display per-class legend if < 21 classes
|
@@ -538,33 +573,33 @@ def ap_per_class(
|
|
538
573
|
tp, conf, pred_cls, target_cls, plot=False, on_plot=None, save_dir=Path(), names={}, eps=1e-16, prefix=""
|
539
574
|
):
|
540
575
|
"""
|
541
|
-
|
576
|
+
Compute the average precision per class for object detection evaluation.
|
542
577
|
|
543
578
|
Args:
|
544
579
|
tp (np.ndarray): Binary array indicating whether the detection is correct (True) or not (False).
|
545
580
|
conf (np.ndarray): Array of confidence scores of the detections.
|
546
581
|
pred_cls (np.ndarray): Array of predicted classes of the detections.
|
547
582
|
target_cls (np.ndarray): Array of true classes of the detections.
|
548
|
-
plot (bool, optional): Whether to plot PR curves or not.
|
549
|
-
on_plot (func, optional): A callback to pass plots path and data when they are rendered.
|
550
|
-
save_dir (Path, optional): Directory to save the PR curves.
|
551
|
-
names (dict, optional): Dict of class names to plot PR curves.
|
552
|
-
eps (float, optional): A small value to avoid division by zero.
|
553
|
-
prefix (str, optional): A prefix string for saving the plot files.
|
583
|
+
plot (bool, optional): Whether to plot PR curves or not.
|
584
|
+
on_plot (func, optional): A callback to pass plots path and data when they are rendered.
|
585
|
+
save_dir (Path, optional): Directory to save the PR curves.
|
586
|
+
names (dict, optional): Dict of class names to plot PR curves.
|
587
|
+
eps (float, optional): A small value to avoid division by zero.
|
588
|
+
prefix (str, optional): A prefix string for saving the plot files.
|
554
589
|
|
555
590
|
Returns:
|
556
|
-
tp (np.ndarray): True positive counts at threshold given by max F1 metric for each class.
|
557
|
-
fp (np.ndarray): False positive counts at threshold given by max F1 metric for each class.
|
558
|
-
p (np.ndarray): Precision values at threshold given by max F1 metric for each class.
|
559
|
-
r (np.ndarray): Recall values at threshold given by max F1 metric for each class.
|
560
|
-
f1 (np.ndarray): F1-score values at threshold given by max F1 metric for each class.
|
561
|
-
ap (np.ndarray): Average precision for each class at different IoU thresholds.
|
562
|
-
unique_classes (np.ndarray): An array of unique classes that have data.
|
563
|
-
p_curve (np.ndarray): Precision curves for each class.
|
564
|
-
r_curve (np.ndarray): Recall curves for each class.
|
565
|
-
f1_curve (np.ndarray): F1-score curves for each class.
|
566
|
-
x (np.ndarray): X-axis values for the curves.
|
567
|
-
prec_values (np.ndarray): Precision values at mAP@0.5 for each class.
|
591
|
+
tp (np.ndarray): True positive counts at threshold given by max F1 metric for each class.
|
592
|
+
fp (np.ndarray): False positive counts at threshold given by max F1 metric for each class.
|
593
|
+
p (np.ndarray): Precision values at threshold given by max F1 metric for each class.
|
594
|
+
r (np.ndarray): Recall values at threshold given by max F1 metric for each class.
|
595
|
+
f1 (np.ndarray): F1-score values at threshold given by max F1 metric for each class.
|
596
|
+
ap (np.ndarray): Average precision for each class at different IoU thresholds.
|
597
|
+
unique_classes (np.ndarray): An array of unique classes that have data.
|
598
|
+
p_curve (np.ndarray): Precision curves for each class.
|
599
|
+
r_curve (np.ndarray): Recall curves for each class.
|
600
|
+
f1_curve (np.ndarray): F1-score curves for each class.
|
601
|
+
x (np.ndarray): X-axis values for the curves.
|
602
|
+
prec_values (np.ndarray): Precision values at mAP@0.5 for each class.
|
568
603
|
"""
|
569
604
|
# Sort by objectness
|
570
605
|
i = np.argsort(-conf)
|
@@ -651,7 +686,7 @@ class Metric(SimpleClass):
|
|
651
686
|
"""
|
652
687
|
|
653
688
|
def __init__(self) -> None:
|
654
|
-
"""
|
689
|
+
"""Initialize a Metric instance for computing evaluation metrics for the YOLOv8 model."""
|
655
690
|
self.p = [] # (nc, )
|
656
691
|
self.r = [] # (nc, )
|
657
692
|
self.f1 = [] # (nc, )
|
@@ -662,7 +697,7 @@ class Metric(SimpleClass):
|
|
662
697
|
@property
|
663
698
|
def ap50(self):
|
664
699
|
"""
|
665
|
-
|
700
|
+
Return the Average Precision (AP) at an IoU threshold of 0.5 for all classes.
|
666
701
|
|
667
702
|
Returns:
|
668
703
|
(np.ndarray, list): Array of shape (nc,) with AP50 values per class, or an empty list if not available.
|
@@ -672,7 +707,7 @@ class Metric(SimpleClass):
|
|
672
707
|
@property
|
673
708
|
def ap(self):
|
674
709
|
"""
|
675
|
-
|
710
|
+
Return the Average Precision (AP) at an IoU threshold of 0.5-0.95 for all classes.
|
676
711
|
|
677
712
|
Returns:
|
678
713
|
(np.ndarray, list): Array of shape (nc,) with AP50-95 values per class, or an empty list if not available.
|
@@ -682,7 +717,7 @@ class Metric(SimpleClass):
|
|
682
717
|
@property
|
683
718
|
def mp(self):
|
684
719
|
"""
|
685
|
-
|
720
|
+
Return the Mean Precision of all classes.
|
686
721
|
|
687
722
|
Returns:
|
688
723
|
(float): The mean precision of all classes.
|
@@ -692,7 +727,7 @@ class Metric(SimpleClass):
|
|
692
727
|
@property
|
693
728
|
def mr(self):
|
694
729
|
"""
|
695
|
-
|
730
|
+
Return the Mean Recall of all classes.
|
696
731
|
|
697
732
|
Returns:
|
698
733
|
(float): The mean recall of all classes.
|
@@ -702,7 +737,7 @@ class Metric(SimpleClass):
|
|
702
737
|
@property
|
703
738
|
def map50(self):
|
704
739
|
"""
|
705
|
-
|
740
|
+
Return the mean Average Precision (mAP) at an IoU threshold of 0.5.
|
706
741
|
|
707
742
|
Returns:
|
708
743
|
(float): The mAP at an IoU threshold of 0.5.
|
@@ -712,7 +747,7 @@ class Metric(SimpleClass):
|
|
712
747
|
@property
|
713
748
|
def map75(self):
|
714
749
|
"""
|
715
|
-
|
750
|
+
Return the mean Average Precision (mAP) at an IoU threshold of 0.75.
|
716
751
|
|
717
752
|
Returns:
|
718
753
|
(float): The mAP at an IoU threshold of 0.75.
|
@@ -722,7 +757,7 @@ class Metric(SimpleClass):
|
|
722
757
|
@property
|
723
758
|
def map(self):
|
724
759
|
"""
|
725
|
-
|
760
|
+
Return the mean Average Precision (mAP) over IoU thresholds of 0.5 - 0.95 in steps of 0.05.
|
726
761
|
|
727
762
|
Returns:
|
728
763
|
(float): The mAP over IoU thresholds of 0.5 - 0.95 in steps of 0.05.
|
@@ -730,41 +765,42 @@ class Metric(SimpleClass):
|
|
730
765
|
return self.all_ap.mean() if len(self.all_ap) else 0.0
|
731
766
|
|
732
767
|
def mean_results(self):
|
733
|
-
"""
|
768
|
+
"""Return mean of results, mp, mr, map50, map."""
|
734
769
|
return [self.mp, self.mr, self.map50, self.map]
|
735
770
|
|
736
771
|
def class_result(self, i):
|
737
|
-
"""
|
772
|
+
"""Return class-aware result, p[i], r[i], ap50[i], ap[i]."""
|
738
773
|
return self.p[i], self.r[i], self.ap50[i], self.ap[i]
|
739
774
|
|
740
775
|
@property
|
741
776
|
def maps(self):
|
742
|
-
"""
|
777
|
+
"""Return mAP of each class."""
|
743
778
|
maps = np.zeros(self.nc) + self.map
|
744
779
|
for i, c in enumerate(self.ap_class_index):
|
745
780
|
maps[c] = self.ap[i]
|
746
781
|
return maps
|
747
782
|
|
748
783
|
def fitness(self):
|
749
|
-
"""
|
784
|
+
"""Return model fitness as a weighted combination of metrics."""
|
750
785
|
w = [0.0, 0.0, 0.1, 0.9] # weights for [P, R, mAP@0.5, mAP@0.5:0.95]
|
751
786
|
return (np.array(self.mean_results()) * w).sum()
|
752
787
|
|
753
788
|
def update(self, results):
|
754
789
|
"""
|
755
|
-
|
790
|
+
Update the evaluation metrics with a new set of results.
|
756
791
|
|
757
792
|
Args:
|
758
|
-
results (tuple): A tuple containing
|
759
|
-
- p (list): Precision for each class.
|
760
|
-
- r (list): Recall for each class.
|
761
|
-
- f1 (list): F1 score for each class.
|
762
|
-
- all_ap (list): AP scores for all classes and all IoU thresholds.
|
763
|
-
- ap_class_index (list): Index of class for each AP score.
|
764
|
-
|
765
|
-
|
766
|
-
|
767
|
-
|
793
|
+
results (tuple): A tuple containing evaluation metrics:
|
794
|
+
- p (list): Precision for each class.
|
795
|
+
- r (list): Recall for each class.
|
796
|
+
- f1 (list): F1 score for each class.
|
797
|
+
- all_ap (list): AP scores for all classes and all IoU thresholds.
|
798
|
+
- ap_class_index (list): Index of class for each AP score.
|
799
|
+
- p_curve (list): Precision curve for each class.
|
800
|
+
- r_curve (list): Recall curve for each class.
|
801
|
+
- f1_curve (list): F1 curve for each class.
|
802
|
+
- px (list): X values for the curves.
|
803
|
+
- prec_values (list): Precision values for each class.
|
768
804
|
"""
|
769
805
|
(
|
770
806
|
self.p,
|
@@ -781,12 +817,12 @@ class Metric(SimpleClass):
|
|
781
817
|
|
782
818
|
@property
|
783
819
|
def curves(self):
|
784
|
-
"""
|
820
|
+
"""Return a list of curves for accessing specific metrics curves."""
|
785
821
|
return []
|
786
822
|
|
787
823
|
@property
|
788
824
|
def curves_results(self):
|
789
|
-
"""
|
825
|
+
"""Return a list of curves for accessing specific metrics curves."""
|
790
826
|
return [
|
791
827
|
[self.px, self.prec_values, "Recall", "Precision"],
|
792
828
|
[self.px, self.f1_curve, "Confidence", "F1"],
|
@@ -797,36 +833,26 @@ class Metric(SimpleClass):
|
|
797
833
|
|
798
834
|
class DetMetrics(SimpleClass):
|
799
835
|
"""
|
800
|
-
Utility class for computing detection metrics such as precision, recall, and mean average precision (mAP)
|
801
|
-
object detection model.
|
802
|
-
|
803
|
-
Args:
|
804
|
-
save_dir (Path): A path to the directory where the output plots will be saved. Defaults to current directory.
|
805
|
-
plot (bool): A flag that indicates whether to plot precision-recall curves for each class. Defaults to False.
|
806
|
-
names (dict of str): A dict of strings that represents the names of the classes. Defaults to an empty tuple.
|
836
|
+
Utility class for computing detection metrics such as precision, recall, and mean average precision (mAP).
|
807
837
|
|
808
838
|
Attributes:
|
809
839
|
save_dir (Path): A path to the directory where the output plots will be saved.
|
810
|
-
plot (bool): A flag that indicates whether to plot
|
811
|
-
names (dict
|
812
|
-
box (Metric): An instance of the Metric class for storing
|
813
|
-
speed (dict): A dictionary for storing
|
814
|
-
|
815
|
-
Methods:
|
816
|
-
process(tp, conf, pred_cls, target_cls): Updates the metric results with the latest batch of predictions.
|
817
|
-
keys: Returns a list of keys for accessing the computed detection metrics.
|
818
|
-
mean_results: Returns a list of mean values for the computed detection metrics.
|
819
|
-
class_result(i): Returns a list of values for the computed detection metrics for a specific class.
|
820
|
-
maps: Returns a dictionary of mean average precision (mAP) values for different IoU thresholds.
|
821
|
-
fitness: Computes the fitness score based on the computed detection metrics.
|
822
|
-
ap_class_index: Returns a list of class indices sorted by their average precision (AP) values.
|
823
|
-
results_dict: Returns a dictionary that maps detection metric keys to their computed values.
|
824
|
-
curves: TODO
|
825
|
-
curves_results: TODO
|
840
|
+
plot (bool): A flag that indicates whether to plot precision-recall curves for each class.
|
841
|
+
names (dict): A dictionary of class names.
|
842
|
+
box (Metric): An instance of the Metric class for storing detection results.
|
843
|
+
speed (dict): A dictionary for storing execution times of different parts of the detection process.
|
844
|
+
task (str): The task type, set to 'detect'.
|
826
845
|
"""
|
827
846
|
|
828
847
|
def __init__(self, save_dir=Path("."), plot=False, names={}) -> None:
|
829
|
-
"""
|
848
|
+
"""
|
849
|
+
Initialize a DetMetrics instance with a save directory, plot flag, and class names.
|
850
|
+
|
851
|
+
Args:
|
852
|
+
save_dir (Path, optional): Directory to save plots.
|
853
|
+
plot (bool, optional): Whether to plot precision-recall curves.
|
854
|
+
names (dict, optional): Dictionary mapping class indices to names.
|
855
|
+
"""
|
830
856
|
self.save_dir = save_dir
|
831
857
|
self.plot = plot
|
832
858
|
self.names = names
|
@@ -835,7 +861,16 @@ class DetMetrics(SimpleClass):
|
|
835
861
|
self.task = "detect"
|
836
862
|
|
837
863
|
def process(self, tp, conf, pred_cls, target_cls, on_plot=None):
|
838
|
-
"""
|
864
|
+
"""
|
865
|
+
Process predicted results for object detection and update metrics.
|
866
|
+
|
867
|
+
Args:
|
868
|
+
tp (np.ndarray): True positive array.
|
869
|
+
conf (np.ndarray): Confidence array.
|
870
|
+
pred_cls (np.ndarray): Predicted class indices array.
|
871
|
+
target_cls (np.ndarray): Target class indices array.
|
872
|
+
on_plot (callable, optional): Function to call after plots are generated.
|
873
|
+
"""
|
839
874
|
results = ap_per_class(
|
840
875
|
tp,
|
841
876
|
conf,
|
@@ -851,7 +886,7 @@ class DetMetrics(SimpleClass):
|
|
851
886
|
|
852
887
|
@property
|
853
888
|
def keys(self):
|
854
|
-
"""
|
889
|
+
"""Return a list of keys for accessing specific metrics."""
|
855
890
|
return ["metrics/precision(B)", "metrics/recall(B)", "metrics/mAP50(B)", "metrics/mAP50-95(B)"]
|
856
891
|
|
857
892
|
def mean_results(self):
|
@@ -864,32 +899,32 @@ class DetMetrics(SimpleClass):
|
|
864
899
|
|
865
900
|
@property
|
866
901
|
def maps(self):
|
867
|
-
"""
|
902
|
+
"""Return mean Average Precision (mAP) scores per class."""
|
868
903
|
return self.box.maps
|
869
904
|
|
870
905
|
@property
|
871
906
|
def fitness(self):
|
872
|
-
"""
|
907
|
+
"""Return the fitness of box object."""
|
873
908
|
return self.box.fitness()
|
874
909
|
|
875
910
|
@property
|
876
911
|
def ap_class_index(self):
|
877
|
-
"""
|
912
|
+
"""Return the average precision index per class."""
|
878
913
|
return self.box.ap_class_index
|
879
914
|
|
880
915
|
@property
|
881
916
|
def results_dict(self):
|
882
|
-
"""
|
917
|
+
"""Return dictionary of computed performance metrics and statistics."""
|
883
918
|
return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness]))
|
884
919
|
|
885
920
|
@property
|
886
921
|
def curves(self):
|
887
|
-
"""
|
922
|
+
"""Return a list of curves for accessing specific metrics curves."""
|
888
923
|
return ["Precision-Recall(B)", "F1-Confidence(B)", "Precision-Confidence(B)", "Recall-Confidence(B)"]
|
889
924
|
|
890
925
|
@property
|
891
926
|
def curves_results(self):
|
892
|
-
"""
|
927
|
+
"""Return dictionary of computed performance metrics and statistics."""
|
893
928
|
return self.box.curves_results
|
894
929
|
|
895
930
|
|
@@ -897,31 +932,25 @@ class SegmentMetrics(SimpleClass):
|
|
897
932
|
"""
|
898
933
|
Calculates and aggregates detection and segmentation metrics over a given set of classes.
|
899
934
|
|
900
|
-
Args:
|
901
|
-
save_dir (Path): Path to the directory where the output plots should be saved. Default is the current directory.
|
902
|
-
plot (bool): Whether to save the detection and segmentation plots. Default is False.
|
903
|
-
names (list): List of class names. Default is an empty list.
|
904
|
-
|
905
935
|
Attributes:
|
906
936
|
save_dir (Path): Path to the directory where the output plots should be saved.
|
907
937
|
plot (bool): Whether to save the detection and segmentation plots.
|
908
|
-
names (
|
938
|
+
names (dict): Dictionary of class names.
|
909
939
|
box (Metric): An instance of the Metric class to calculate box detection metrics.
|
910
940
|
seg (Metric): An instance of the Metric class to calculate mask segmentation metrics.
|
911
941
|
speed (dict): Dictionary to store the time taken in different phases of inference.
|
912
|
-
|
913
|
-
Methods:
|
914
|
-
process(tp_m, tp_b, conf, pred_cls, target_cls): Processes metrics over the given set of predictions.
|
915
|
-
mean_results(): Returns the mean of the detection and segmentation metrics over all the classes.
|
916
|
-
class_result(i): Returns the detection and segmentation metrics of class `i`.
|
917
|
-
maps: Returns the mean Average Precision (mAP) scores for IoU thresholds ranging from 0.50 to 0.95.
|
918
|
-
fitness: Returns the fitness scores, which are a single weighted combination of metrics.
|
919
|
-
ap_class_index: Returns the list of indices of classes used to compute Average Precision (AP).
|
920
|
-
results_dict: Returns the dictionary containing all the detection and segmentation metrics and fitness score.
|
942
|
+
task (str): The task type, set to 'segment'.
|
921
943
|
"""
|
922
944
|
|
923
945
|
def __init__(self, save_dir=Path("."), plot=False, names=()) -> None:
|
924
|
-
"""
|
946
|
+
"""
|
947
|
+
Initialize a SegmentMetrics instance with a save directory, plot flag, and class names.
|
948
|
+
|
949
|
+
Args:
|
950
|
+
save_dir (Path, optional): Directory to save plots.
|
951
|
+
plot (bool, optional): Whether to plot precision-recall curves.
|
952
|
+
names (dict, optional): Dictionary mapping class indices to names.
|
953
|
+
"""
|
925
954
|
self.save_dir = save_dir
|
926
955
|
self.plot = plot
|
927
956
|
self.names = names
|
@@ -932,15 +961,15 @@ class SegmentMetrics(SimpleClass):
|
|
932
961
|
|
933
962
|
def process(self, tp, tp_m, conf, pred_cls, target_cls, on_plot=None):
|
934
963
|
"""
|
935
|
-
|
964
|
+
Process the detection and segmentation metrics over the given set of predictions.
|
936
965
|
|
937
966
|
Args:
|
938
|
-
tp (
|
939
|
-
tp_m (
|
940
|
-
conf (
|
941
|
-
pred_cls (
|
942
|
-
target_cls (
|
943
|
-
on_plot (
|
967
|
+
tp (np.ndarray): True positive array for boxes.
|
968
|
+
tp_m (np.ndarray): True positive array for masks.
|
969
|
+
conf (np.ndarray): Confidence array.
|
970
|
+
pred_cls (np.ndarray): Predicted class indices array.
|
971
|
+
target_cls (np.ndarray): Target class indices array.
|
972
|
+
on_plot (callable, optional): Function to call after plots are generated.
|
944
973
|
"""
|
945
974
|
results_mask = ap_per_class(
|
946
975
|
tp_m,
|
@@ -971,7 +1000,7 @@ class SegmentMetrics(SimpleClass):
|
|
971
1000
|
|
972
1001
|
@property
|
973
1002
|
def keys(self):
|
974
|
-
"""
|
1003
|
+
"""Return a list of keys for accessing metrics."""
|
975
1004
|
return [
|
976
1005
|
"metrics/precision(B)",
|
977
1006
|
"metrics/recall(B)",
|
@@ -988,32 +1017,36 @@ class SegmentMetrics(SimpleClass):
|
|
988
1017
|
return self.box.mean_results() + self.seg.mean_results()
|
989
1018
|
|
990
1019
|
def class_result(self, i):
|
991
|
-
"""
|
1020
|
+
"""Return classification results for a specified class index."""
|
992
1021
|
return self.box.class_result(i) + self.seg.class_result(i)
|
993
1022
|
|
994
1023
|
@property
|
995
1024
|
def maps(self):
|
996
|
-
"""
|
1025
|
+
"""Return mAP scores for object detection and semantic segmentation models."""
|
997
1026
|
return self.box.maps + self.seg.maps
|
998
1027
|
|
999
1028
|
@property
|
1000
1029
|
def fitness(self):
|
1001
|
-
"""
|
1030
|
+
"""Return the fitness score for both segmentation and bounding box models."""
|
1002
1031
|
return self.seg.fitness() + self.box.fitness()
|
1003
1032
|
|
1004
1033
|
@property
|
1005
1034
|
def ap_class_index(self):
|
1006
|
-
"""
|
1035
|
+
"""
|
1036
|
+
Return the class indices.
|
1037
|
+
|
1038
|
+
Boxes and masks have the same ap_class_index.
|
1039
|
+
"""
|
1007
1040
|
return self.box.ap_class_index
|
1008
1041
|
|
1009
1042
|
@property
|
1010
1043
|
def results_dict(self):
|
1011
|
-
"""
|
1044
|
+
"""Return results of object detection model for evaluation."""
|
1012
1045
|
return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness]))
|
1013
1046
|
|
1014
1047
|
@property
|
1015
1048
|
def curves(self):
|
1016
|
-
"""
|
1049
|
+
"""Return a list of curves for accessing specific metrics curves."""
|
1017
1050
|
return [
|
1018
1051
|
"Precision-Recall(B)",
|
1019
1052
|
"F1-Confidence(B)",
|
@@ -1027,7 +1060,7 @@ class SegmentMetrics(SimpleClass):
|
|
1027
1060
|
|
1028
1061
|
@property
|
1029
1062
|
def curves_results(self):
|
1030
|
-
"""
|
1063
|
+
"""Return dictionary of computed performance metrics and statistics."""
|
1031
1064
|
return self.box.curves_results + self.seg.curves_results
|
1032
1065
|
|
1033
1066
|
|
@@ -1035,18 +1068,14 @@ class PoseMetrics(SegmentMetrics):
|
|
1035
1068
|
"""
|
1036
1069
|
Calculates and aggregates detection and pose metrics over a given set of classes.
|
1037
1070
|
|
1038
|
-
Args:
|
1039
|
-
save_dir (Path): Path to the directory where the output plots should be saved. Default is the current directory.
|
1040
|
-
plot (bool): Whether to save the detection and segmentation plots. Default is False.
|
1041
|
-
names (list): List of class names. Default is an empty list.
|
1042
|
-
|
1043
1071
|
Attributes:
|
1044
1072
|
save_dir (Path): Path to the directory where the output plots should be saved.
|
1045
|
-
plot (bool): Whether to save the detection and
|
1046
|
-
names (
|
1073
|
+
plot (bool): Whether to save the detection and pose plots.
|
1074
|
+
names (dict): Dictionary of class names.
|
1047
1075
|
box (Metric): An instance of the Metric class to calculate box detection metrics.
|
1048
|
-
pose (Metric): An instance of the Metric class to calculate
|
1076
|
+
pose (Metric): An instance of the Metric class to calculate pose metrics.
|
1049
1077
|
speed (dict): Dictionary to store the time taken in different phases of inference.
|
1078
|
+
task (str): The task type, set to 'pose'.
|
1050
1079
|
|
1051
1080
|
Methods:
|
1052
1081
|
process(tp_m, tp_b, conf, pred_cls, target_cls): Processes metrics over the given set of predictions.
|
@@ -1059,7 +1088,14 @@ class PoseMetrics(SegmentMetrics):
|
|
1059
1088
|
"""
|
1060
1089
|
|
1061
1090
|
def __init__(self, save_dir=Path("."), plot=False, names=()) -> None:
|
1062
|
-
"""
|
1091
|
+
"""
|
1092
|
+
Initialize the PoseMetrics class with directory path, class names, and plotting options.
|
1093
|
+
|
1094
|
+
Args:
|
1095
|
+
save_dir (Path, optional): Directory to save plots.
|
1096
|
+
plot (bool, optional): Whether to plot precision-recall curves.
|
1097
|
+
names (dict, optional): Dictionary mapping class indices to names.
|
1098
|
+
"""
|
1063
1099
|
super().__init__(save_dir, plot, names)
|
1064
1100
|
self.save_dir = save_dir
|
1065
1101
|
self.plot = plot
|
@@ -1071,15 +1107,15 @@ class PoseMetrics(SegmentMetrics):
|
|
1071
1107
|
|
1072
1108
|
def process(self, tp, tp_p, conf, pred_cls, target_cls, on_plot=None):
|
1073
1109
|
"""
|
1074
|
-
|
1110
|
+
Process the detection and pose metrics over the given set of predictions.
|
1075
1111
|
|
1076
1112
|
Args:
|
1077
|
-
tp (
|
1078
|
-
tp_p (
|
1079
|
-
conf (
|
1080
|
-
pred_cls (
|
1081
|
-
target_cls (
|
1082
|
-
on_plot (
|
1113
|
+
tp (np.ndarray): True positive array for boxes.
|
1114
|
+
tp_p (np.ndarray): True positive array for keypoints.
|
1115
|
+
conf (np.ndarray): Confidence array.
|
1116
|
+
pred_cls (np.ndarray): Predicted class indices array.
|
1117
|
+
target_cls (np.ndarray): Target class indices array.
|
1118
|
+
on_plot (callable, optional): Function to call after plots are generated.
|
1083
1119
|
"""
|
1084
1120
|
results_pose = ap_per_class(
|
1085
1121
|
tp_p,
|
@@ -1110,7 +1146,7 @@ class PoseMetrics(SegmentMetrics):
|
|
1110
1146
|
|
1111
1147
|
@property
|
1112
1148
|
def keys(self):
|
1113
|
-
"""
|
1149
|
+
"""Return list of evaluation metric keys."""
|
1114
1150
|
return [
|
1115
1151
|
"metrics/precision(B)",
|
1116
1152
|
"metrics/recall(B)",
|
@@ -1132,17 +1168,17 @@ class PoseMetrics(SegmentMetrics):
|
|
1132
1168
|
|
1133
1169
|
@property
|
1134
1170
|
def maps(self):
|
1135
|
-
"""
|
1171
|
+
"""Return the mean average precision (mAP) per class for both box and pose detections."""
|
1136
1172
|
return self.box.maps + self.pose.maps
|
1137
1173
|
|
1138
1174
|
@property
|
1139
1175
|
def fitness(self):
|
1140
|
-
"""
|
1176
|
+
"""Return combined fitness score for pose and box detection."""
|
1141
1177
|
return self.pose.fitness() + self.box.fitness()
|
1142
1178
|
|
1143
1179
|
@property
|
1144
1180
|
def curves(self):
|
1145
|
-
"""
|
1181
|
+
"""Return a list of curves for accessing specific metrics curves."""
|
1146
1182
|
return [
|
1147
1183
|
"Precision-Recall(B)",
|
1148
1184
|
"F1-Confidence(B)",
|
@@ -1156,7 +1192,7 @@ class PoseMetrics(SegmentMetrics):
|
|
1156
1192
|
|
1157
1193
|
@property
|
1158
1194
|
def curves_results(self):
|
1159
|
-
"""
|
1195
|
+
"""Return dictionary of computed performance metrics and statistics."""
|
1160
1196
|
return self.box.curves_results + self.pose.curves_results
|
1161
1197
|
|
1162
1198
|
|
@@ -1167,13 +1203,8 @@ class ClassifyMetrics(SimpleClass):
|
|
1167
1203
|
Attributes:
|
1168
1204
|
top1 (float): The top-1 accuracy.
|
1169
1205
|
top5 (float): The top-5 accuracy.
|
1170
|
-
speed (
|
1171
|
-
|
1172
|
-
results_dict (Dict[str, Union[float, str]]): A dictionary containing the classification metrics and fitness.
|
1173
|
-
keys (List[str]): A list of keys for the results_dict.
|
1174
|
-
|
1175
|
-
Methods:
|
1176
|
-
process(targets, pred): Processes the targets and predictions to compute classification metrics.
|
1206
|
+
speed (dict): A dictionary containing the time taken for each step in the pipeline.
|
1207
|
+
task (str): The task type, set to 'classify'.
|
1177
1208
|
"""
|
1178
1209
|
|
1179
1210
|
def __init__(self) -> None:
|
@@ -1184,7 +1215,13 @@ class ClassifyMetrics(SimpleClass):
|
|
1184
1215
|
self.task = "classify"
|
1185
1216
|
|
1186
1217
|
def process(self, targets, pred):
|
1187
|
-
"""
|
1218
|
+
"""
|
1219
|
+
Process target classes and predicted classes to compute metrics.
|
1220
|
+
|
1221
|
+
Args:
|
1222
|
+
targets (torch.Tensor): Target classes.
|
1223
|
+
pred (torch.Tensor): Predicted classes.
|
1224
|
+
"""
|
1188
1225
|
pred, targets = torch.cat(pred), torch.cat(targets)
|
1189
1226
|
correct = (targets[:, None] == pred).float()
|
1190
1227
|
acc = torch.stack((correct[:, 0], correct.max(1).values), dim=1) # (top1, top5) accuracy
|
@@ -1192,35 +1229,54 @@ class ClassifyMetrics(SimpleClass):
|
|
1192
1229
|
|
1193
1230
|
@property
|
1194
1231
|
def fitness(self):
|
1195
|
-
"""
|
1232
|
+
"""Return mean of top-1 and top-5 accuracies as fitness score."""
|
1196
1233
|
return (self.top1 + self.top5) / 2
|
1197
1234
|
|
1198
1235
|
@property
|
1199
1236
|
def results_dict(self):
|
1200
|
-
"""
|
1237
|
+
"""Return a dictionary with model's performance metrics and fitness score."""
|
1201
1238
|
return dict(zip(self.keys + ["fitness"], [self.top1, self.top5, self.fitness]))
|
1202
1239
|
|
1203
1240
|
@property
|
1204
1241
|
def keys(self):
|
1205
|
-
"""
|
1242
|
+
"""Return a list of keys for the results_dict property."""
|
1206
1243
|
return ["metrics/accuracy_top1", "metrics/accuracy_top5"]
|
1207
1244
|
|
1208
1245
|
@property
|
1209
1246
|
def curves(self):
|
1210
|
-
"""
|
1247
|
+
"""Return a list of curves for accessing specific metrics curves."""
|
1211
1248
|
return []
|
1212
1249
|
|
1213
1250
|
@property
|
1214
1251
|
def curves_results(self):
|
1215
|
-
"""
|
1252
|
+
"""Return a list of curves for accessing specific metrics curves."""
|
1216
1253
|
return []
|
1217
1254
|
|
1218
1255
|
|
1219
1256
|
class OBBMetrics(SimpleClass):
|
1220
|
-
"""
|
1257
|
+
"""
|
1258
|
+
Metrics for evaluating oriented bounding box (OBB) detection.
|
1259
|
+
|
1260
|
+
Attributes:
|
1261
|
+
save_dir (Path): Path to the directory where the output plots should be saved.
|
1262
|
+
plot (bool): Whether to save the detection plots.
|
1263
|
+
names (dict): Dictionary of class names.
|
1264
|
+
box (Metric): An instance of the Metric class for storing detection results.
|
1265
|
+
speed (dict): A dictionary for storing execution times of different parts of the detection process.
|
1266
|
+
|
1267
|
+
References:
|
1268
|
+
https://arxiv.org/pdf/2106.06072.pdf
|
1269
|
+
"""
|
1221
1270
|
|
1222
1271
|
def __init__(self, save_dir=Path("."), plot=False, names=()) -> None:
|
1223
|
-
"""
|
1272
|
+
"""
|
1273
|
+
Initialize an OBBMetrics instance with directory, plotting, and class names.
|
1274
|
+
|
1275
|
+
Args:
|
1276
|
+
save_dir (Path, optional): Directory to save plots.
|
1277
|
+
plot (bool, optional): Whether to plot precision-recall curves.
|
1278
|
+
names (dict, optional): Dictionary mapping class indices to names.
|
1279
|
+
"""
|
1224
1280
|
self.save_dir = save_dir
|
1225
1281
|
self.plot = plot
|
1226
1282
|
self.names = names
|
@@ -1228,7 +1284,16 @@ class OBBMetrics(SimpleClass):
|
|
1228
1284
|
self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
|
1229
1285
|
|
1230
1286
|
def process(self, tp, conf, pred_cls, target_cls, on_plot=None):
|
1231
|
-
"""
|
1287
|
+
"""
|
1288
|
+
Process predicted results for object detection and update metrics.
|
1289
|
+
|
1290
|
+
Args:
|
1291
|
+
tp (np.ndarray): True positive array.
|
1292
|
+
conf (np.ndarray): Confidence array.
|
1293
|
+
pred_cls (np.ndarray): Predicted class indices array.
|
1294
|
+
target_cls (np.ndarray): Target class indices array.
|
1295
|
+
on_plot (callable, optional): Function to call after plots are generated.
|
1296
|
+
"""
|
1232
1297
|
results = ap_per_class(
|
1233
1298
|
tp,
|
1234
1299
|
conf,
|
@@ -1244,7 +1309,7 @@ class OBBMetrics(SimpleClass):
|
|
1244
1309
|
|
1245
1310
|
@property
|
1246
1311
|
def keys(self):
|
1247
|
-
"""
|
1312
|
+
"""Return a list of keys for accessing specific metrics."""
|
1248
1313
|
return ["metrics/precision(B)", "metrics/recall(B)", "metrics/mAP50(B)", "metrics/mAP50-95(B)"]
|
1249
1314
|
|
1250
1315
|
def mean_results(self):
|
@@ -1257,30 +1322,30 @@ class OBBMetrics(SimpleClass):
|
|
1257
1322
|
|
1258
1323
|
@property
|
1259
1324
|
def maps(self):
|
1260
|
-
"""
|
1325
|
+
"""Return mean Average Precision (mAP) scores per class."""
|
1261
1326
|
return self.box.maps
|
1262
1327
|
|
1263
1328
|
@property
|
1264
1329
|
def fitness(self):
|
1265
|
-
"""
|
1330
|
+
"""Return the fitness of box object."""
|
1266
1331
|
return self.box.fitness()
|
1267
1332
|
|
1268
1333
|
@property
|
1269
1334
|
def ap_class_index(self):
|
1270
|
-
"""
|
1335
|
+
"""Return the average precision index per class."""
|
1271
1336
|
return self.box.ap_class_index
|
1272
1337
|
|
1273
1338
|
@property
|
1274
1339
|
def results_dict(self):
|
1275
|
-
"""
|
1340
|
+
"""Return dictionary of computed performance metrics and statistics."""
|
1276
1341
|
return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness]))
|
1277
1342
|
|
1278
1343
|
@property
|
1279
1344
|
def curves(self):
|
1280
|
-
"""
|
1345
|
+
"""Return a list of curves for accessing specific metrics curves."""
|
1281
1346
|
return []
|
1282
1347
|
|
1283
1348
|
@property
|
1284
1349
|
def curves_results(self):
|
1285
|
-
"""
|
1350
|
+
"""Return a list of curves for accessing specific metrics curves."""
|
1286
1351
|
return []
|