ultralytics 8.1.28__py3-none-any.whl → 8.3.62__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/__init__.py +22 -0
- tests/conftest.py +83 -0
- tests/test_cli.py +122 -0
- tests/test_cuda.py +155 -0
- tests/test_engine.py +131 -0
- tests/test_exports.py +216 -0
- tests/test_integrations.py +150 -0
- tests/test_python.py +615 -0
- tests/test_solutions.py +94 -0
- ultralytics/__init__.py +11 -8
- ultralytics/cfg/__init__.py +569 -131
- ultralytics/cfg/datasets/Argoverse.yaml +2 -1
- ultralytics/cfg/datasets/DOTAv1.5.yaml +3 -2
- ultralytics/cfg/datasets/DOTAv1.yaml +3 -2
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +3 -2
- ultralytics/cfg/datasets/ImageNet.yaml +2 -1
- ultralytics/cfg/datasets/Objects365.yaml +5 -4
- ultralytics/cfg/datasets/SKU-110K.yaml +2 -1
- ultralytics/cfg/datasets/VOC.yaml +3 -2
- ultralytics/cfg/datasets/VisDrone.yaml +6 -5
- ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
- ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
- ultralytics/cfg/datasets/carparts-seg.yaml +3 -2
- ultralytics/cfg/datasets/coco-pose.yaml +7 -6
- ultralytics/cfg/datasets/coco.yaml +3 -2
- ultralytics/cfg/datasets/coco128-seg.yaml +4 -3
- ultralytics/cfg/datasets/coco128.yaml +4 -3
- ultralytics/cfg/datasets/coco8-pose.yaml +3 -2
- ultralytics/cfg/datasets/coco8-seg.yaml +3 -2
- ultralytics/cfg/datasets/coco8.yaml +3 -2
- ultralytics/cfg/datasets/crack-seg.yaml +3 -2
- ultralytics/cfg/datasets/dog-pose.yaml +24 -0
- ultralytics/cfg/datasets/dota8.yaml +3 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
- ultralytics/cfg/datasets/lvis.yaml +1236 -0
- ultralytics/cfg/datasets/medical-pills.yaml +22 -0
- ultralytics/cfg/datasets/open-images-v7.yaml +2 -1
- ultralytics/cfg/datasets/package-seg.yaml +5 -4
- ultralytics/cfg/datasets/signature.yaml +21 -0
- ultralytics/cfg/datasets/tiger-pose.yaml +3 -2
- ultralytics/cfg/datasets/xView.yaml +2 -1
- ultralytics/cfg/default.yaml +14 -11
- ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +24 -0
- ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
- ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
- ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
- ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
- ultralytics/cfg/models/11/yolo11.yaml +50 -0
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +5 -2
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +5 -2
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +5 -2
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +5 -2
- ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
- ultralytics/cfg/models/v3/yolov3-spp.yaml +5 -2
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +5 -2
- ultralytics/cfg/models/v3/yolov3.yaml +5 -2
- ultralytics/cfg/models/v5/yolov5-p6.yaml +5 -2
- ultralytics/cfg/models/v5/yolov5.yaml +5 -2
- ultralytics/cfg/models/v6/yolov6.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-cls.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +6 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +6 -2
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-p2.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-p6.yaml +10 -7
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-pose.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-seg.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-world.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8.yaml +5 -2
- ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9c.yaml +30 -25
- ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9e.yaml +46 -42
- ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
- ultralytics/cfg/solutions/default.yaml +24 -0
- ultralytics/cfg/trackers/botsort.yaml +8 -5
- ultralytics/cfg/trackers/bytetrack.yaml +8 -5
- ultralytics/data/__init__.py +14 -3
- ultralytics/data/annotator.py +37 -15
- ultralytics/data/augment.py +1783 -289
- ultralytics/data/base.py +62 -27
- ultralytics/data/build.py +36 -8
- ultralytics/data/converter.py +196 -36
- ultralytics/data/dataset.py +233 -94
- ultralytics/data/loaders.py +199 -96
- ultralytics/data/split_dota.py +39 -29
- ultralytics/data/utils.py +110 -40
- ultralytics/engine/__init__.py +1 -1
- ultralytics/engine/exporter.py +569 -242
- ultralytics/engine/model.py +604 -252
- ultralytics/engine/predictor.py +22 -11
- ultralytics/engine/results.py +1228 -218
- ultralytics/engine/trainer.py +190 -129
- ultralytics/engine/tuner.py +18 -18
- ultralytics/engine/validator.py +18 -15
- ultralytics/hub/__init__.py +31 -13
- ultralytics/hub/auth.py +11 -7
- ultralytics/hub/google/__init__.py +159 -0
- ultralytics/hub/session.py +128 -94
- ultralytics/hub/utils.py +20 -21
- ultralytics/models/__init__.py +4 -2
- ultralytics/models/fastsam/__init__.py +2 -3
- ultralytics/models/fastsam/model.py +26 -4
- ultralytics/models/fastsam/predict.py +127 -63
- ultralytics/models/fastsam/utils.py +1 -44
- ultralytics/models/fastsam/val.py +1 -1
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +21 -10
- ultralytics/models/nas/predict.py +3 -6
- ultralytics/models/nas/val.py +4 -4
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +1 -1
- ultralytics/models/rtdetr/predict.py +6 -8
- ultralytics/models/rtdetr/train.py +6 -2
- ultralytics/models/rtdetr/val.py +3 -3
- ultralytics/models/sam/__init__.py +3 -3
- ultralytics/models/sam/amg.py +29 -23
- ultralytics/models/sam/build.py +211 -13
- ultralytics/models/sam/model.py +91 -30
- ultralytics/models/sam/modules/__init__.py +1 -1
- ultralytics/models/sam/modules/blocks.py +1129 -0
- ultralytics/models/sam/modules/decoders.py +381 -53
- ultralytics/models/sam/modules/encoders.py +515 -324
- ultralytics/models/sam/modules/memory_attention.py +237 -0
- ultralytics/models/sam/modules/sam.py +969 -21
- ultralytics/models/sam/modules/tiny_encoder.py +425 -154
- ultralytics/models/sam/modules/transformer.py +159 -60
- ultralytics/models/sam/modules/utils.py +293 -0
- ultralytics/models/sam/predict.py +1263 -132
- ultralytics/models/utils/__init__.py +1 -1
- ultralytics/models/utils/loss.py +36 -24
- ultralytics/models/utils/ops.py +3 -7
- ultralytics/models/yolo/__init__.py +3 -3
- ultralytics/models/yolo/classify/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +7 -8
- ultralytics/models/yolo/classify/train.py +17 -22
- ultralytics/models/yolo/classify/val.py +8 -4
- ultralytics/models/yolo/detect/__init__.py +1 -1
- ultralytics/models/yolo/detect/predict.py +3 -5
- ultralytics/models/yolo/detect/train.py +11 -4
- ultralytics/models/yolo/detect/val.py +90 -52
- ultralytics/models/yolo/model.py +14 -9
- ultralytics/models/yolo/obb/__init__.py +1 -1
- ultralytics/models/yolo/obb/predict.py +2 -2
- ultralytics/models/yolo/obb/train.py +5 -3
- ultralytics/models/yolo/obb/val.py +41 -23
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +3 -5
- ultralytics/models/yolo/pose/train.py +2 -2
- ultralytics/models/yolo/pose/val.py +51 -17
- ultralytics/models/yolo/segment/__init__.py +1 -1
- ultralytics/models/yolo/segment/predict.py +3 -5
- ultralytics/models/yolo/segment/train.py +2 -2
- ultralytics/models/yolo/segment/val.py +60 -19
- ultralytics/models/yolo/world/__init__.py +5 -0
- ultralytics/models/yolo/world/train.py +92 -0
- ultralytics/models/yolo/world/train_world.py +109 -0
- ultralytics/nn/__init__.py +1 -1
- ultralytics/nn/autobackend.py +228 -93
- ultralytics/nn/modules/__init__.py +39 -14
- ultralytics/nn/modules/activation.py +21 -0
- ultralytics/nn/modules/block.py +527 -67
- ultralytics/nn/modules/conv.py +24 -7
- ultralytics/nn/modules/head.py +177 -34
- ultralytics/nn/modules/transformer.py +6 -5
- ultralytics/nn/modules/utils.py +1 -2
- ultralytics/nn/tasks.py +225 -77
- ultralytics/solutions/__init__.py +30 -1
- ultralytics/solutions/ai_gym.py +96 -143
- ultralytics/solutions/analytics.py +247 -0
- ultralytics/solutions/distance_calculation.py +78 -135
- ultralytics/solutions/heatmap.py +93 -247
- ultralytics/solutions/object_counter.py +184 -259
- ultralytics/solutions/parking_management.py +246 -0
- ultralytics/solutions/queue_management.py +112 -0
- ultralytics/solutions/region_counter.py +116 -0
- ultralytics/solutions/security_alarm.py +144 -0
- ultralytics/solutions/solutions.py +178 -0
- ultralytics/solutions/speed_estimation.py +86 -174
- ultralytics/solutions/streamlit_inference.py +190 -0
- ultralytics/solutions/trackzone.py +68 -0
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +32 -13
- ultralytics/trackers/bot_sort.py +61 -28
- ultralytics/trackers/byte_tracker.py +83 -51
- ultralytics/trackers/track.py +21 -6
- ultralytics/trackers/utils/__init__.py +1 -1
- ultralytics/trackers/utils/gmc.py +62 -48
- ultralytics/trackers/utils/kalman_filter.py +166 -35
- ultralytics/trackers/utils/matching.py +40 -21
- ultralytics/utils/__init__.py +511 -239
- ultralytics/utils/autobatch.py +40 -22
- ultralytics/utils/benchmarks.py +266 -85
- ultralytics/utils/callbacks/__init__.py +1 -1
- ultralytics/utils/callbacks/base.py +1 -3
- ultralytics/utils/callbacks/clearml.py +7 -6
- ultralytics/utils/callbacks/comet.py +39 -17
- ultralytics/utils/callbacks/dvc.py +1 -1
- ultralytics/utils/callbacks/hub.py +16 -16
- ultralytics/utils/callbacks/mlflow.py +28 -24
- ultralytics/utils/callbacks/neptune.py +6 -2
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +18 -18
- ultralytics/utils/callbacks/wb.py +27 -20
- ultralytics/utils/checks.py +160 -100
- ultralytics/utils/dist.py +2 -1
- ultralytics/utils/downloads.py +44 -37
- ultralytics/utils/errors.py +1 -1
- ultralytics/utils/files.py +72 -38
- ultralytics/utils/instance.py +41 -19
- ultralytics/utils/loss.py +84 -56
- ultralytics/utils/metrics.py +61 -56
- ultralytics/utils/ops.py +94 -89
- ultralytics/utils/patches.py +30 -14
- ultralytics/utils/plotting.py +600 -269
- ultralytics/utils/tal.py +67 -26
- ultralytics/utils/torch_utils.py +302 -102
- ultralytics/utils/triton.py +2 -1
- ultralytics/utils/tuner.py +21 -12
- ultralytics-8.3.62.dist-info/METADATA +370 -0
- ultralytics-8.3.62.dist-info/RECORD +241 -0
- {ultralytics-8.1.28.dist-info → ultralytics-8.3.62.dist-info}/WHEEL +1 -1
- ultralytics/data/explorer/__init__.py +0 -5
- ultralytics/data/explorer/explorer.py +0 -472
- ultralytics/data/explorer/gui/__init__.py +0 -1
- ultralytics/data/explorer/gui/dash.py +0 -268
- ultralytics/data/explorer/utils.py +0 -166
- ultralytics/models/fastsam/prompt.py +0 -357
- ultralytics-8.1.28.dist-info/METADATA +0 -373
- ultralytics-8.1.28.dist-info/RECORD +0 -197
- {ultralytics-8.1.28.dist-info → ultralytics-8.3.62.dist-info}/LICENSE +0 -0
- {ultralytics-8.1.28.dist-info → ultralytics-8.3.62.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.1.28.dist-info → ultralytics-8.3.62.dist-info}/top_level.txt +0 -0
ultralytics/utils/loss.py
CHANGED
@@ -1,4 +1,4 @@
-# Ultralytics YOLO 🚀, AGPL-3.0 license
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
 import torch
 import torch.nn as nn
@@ -7,6 +7,8 @@ import torch.nn.functional as F
 from ultralytics.utils.metrics import OKS_SIGMA
 from ultralytics.utils.ops import crop_mask, xywh2xyxy, xyxy2xywh
 from ultralytics.utils.tal import RotatedTaskAlignedAssigner, TaskAlignedAssigner, dist2bbox, dist2rbox, make_anchors
+from ultralytics.utils.torch_utils import autocast
+
 from .metrics import bbox_iou, probiou
 from .tal import bbox2dist
 
@@ -26,7 +28,7 @@ class VarifocalLoss(nn.Module):
     def forward(pred_score, gt_score, label, alpha=0.75, gamma=2.0):
         """Computes varfocal loss."""
         weight = alpha * pred_score.sigmoid().pow(gamma) * (1 - label) + gt_score * label
-        with torch.cuda.amp.autocast(enabled=False):
+        with autocast(enabled=False):
             loss = (
                 (F.binary_cross_entropy_with_logits(pred_score.float(), gt_score.float(), reduction="none") * weight)
                 .mean(1)
@@ -60,39 +62,22 @@ class FocalLoss(nn.Module):
         return loss.mean(1).sum()
 
 
-class BboxLoss(nn.Module):
-    """Criterion class for computing training losses during training."""
+class DFLoss(nn.Module):
+    """Criterion class for computing DFL losses during training."""
 
-    def __init__(self, reg_max, use_dfl=False):
-        """Initialize the BboxLoss module with regularization maximum and DFL settings."""
+    def __init__(self, reg_max=16) -> None:
+        """Initialize the DFL module."""
         super().__init__()
         self.reg_max = reg_max
-        self.use_dfl = use_dfl
-
-    def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
-        """IoU loss."""
-        weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
-        iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, CIoU=True)
-        loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum
 
-        # DFL loss
-        if self.use_dfl:
-            target_ltrb = bbox2dist(anchor_points, target_bboxes, self.reg_max)
-            loss_dfl = self._df_loss(pred_dist[fg_mask].view(-1, self.reg_max + 1), target_ltrb[fg_mask]) * weight
-            loss_dfl = loss_dfl.sum() / target_scores_sum
-        else:
-            loss_dfl = torch.tensor(0.0).to(pred_dist.device)
-
-        return loss_iou, loss_dfl
-
-    @staticmethod
-    def _df_loss(pred_dist, target):
+    def __call__(self, pred_dist, target):
         """
         Return sum of left and right DFL losses.
 
         Distribution Focal Loss (DFL) proposed in Generalized Focal Loss
         https://ieeexplore.ieee.org/document/9792391
         """
+        target = target.clamp_(0, self.reg_max - 1 - 0.01)
         tl = target.long()  # target left
         tr = tl + 1  # target right
         wl = tr - target  # weight left
@@ -103,12 +88,37 @@ class BboxLoss(nn.Module):
         ).mean(-1, keepdim=True)
 
 
+class BboxLoss(nn.Module):
+    """Criterion class for computing training losses during training."""
+
+    def __init__(self, reg_max=16):
+        """Initialize the BboxLoss module with regularization maximum and DFL settings."""
+        super().__init__()
+        self.dfl_loss = DFLoss(reg_max) if reg_max > 1 else None
+
+    def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
+        """IoU loss."""
+        weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
+        iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, CIoU=True)
+        loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum
+
+        # DFL loss
+        if self.dfl_loss:
+            target_ltrb = bbox2dist(anchor_points, target_bboxes, self.dfl_loss.reg_max - 1)
+            loss_dfl = self.dfl_loss(pred_dist[fg_mask].view(-1, self.dfl_loss.reg_max), target_ltrb[fg_mask]) * weight
+            loss_dfl = loss_dfl.sum() / target_scores_sum
+        else:
+            loss_dfl = torch.tensor(0.0).to(pred_dist.device)
+
+        return loss_iou, loss_dfl
+
+
 class RotatedBboxLoss(BboxLoss):
     """Criterion class for computing training losses during training."""
 
-    def __init__(self, reg_max, use_dfl=False):
+    def __init__(self, reg_max):
         """Initialize the BboxLoss module with regularization maximum and DFL settings."""
-        super().__init__(reg_max, use_dfl)
+        super().__init__(reg_max)
 
     def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
         """IoU loss."""
@@ -117,9 +127,9 @@ class RotatedBboxLoss(BboxLoss):
         loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum
 
         # DFL loss
-        if self.use_dfl:
-            target_ltrb = bbox2dist(anchor_points, xywh2xyxy(target_bboxes[..., :4]), self.reg_max)
-            loss_dfl = self._df_loss(pred_dist[fg_mask].view(-1, self.reg_max + 1), target_ltrb[fg_mask]) * weight
+        if self.dfl_loss:
+            target_ltrb = bbox2dist(anchor_points, xywh2xyxy(target_bboxes[..., :4]), self.dfl_loss.reg_max - 1)
+            loss_dfl = self.dfl_loss(pred_dist[fg_mask].view(-1, self.dfl_loss.reg_max), target_ltrb[fg_mask]) * weight
             loss_dfl = loss_dfl.sum() / target_scores_sum
         else:
             loss_dfl = torch.tensor(0.0).to(pred_dist.device)
@@ -140,14 +150,14 @@ class KeypointLoss(nn.Module):
         d = (pred_kpts[..., 0] - gt_kpts[..., 0]).pow(2) + (pred_kpts[..., 1] - gt_kpts[..., 1]).pow(2)
         kpt_loss_factor = kpt_mask.shape[1] / (torch.sum(kpt_mask != 0, dim=1) + 1e-9)
         # e = d / (2 * (area * self.sigmas) ** 2 + 1e-9)  # from formula
-        e = d / (2 * self.sigmas).pow(2) / (area + 1e-9) / 2  # from cocoeval
+        e = d / ((2 * self.sigmas).pow(2) * (area + 1e-9) * 2)  # from cocoeval
         return (kpt_loss_factor.view(-1, 1) * ((1 - torch.exp(-e)) * kpt_mask)).mean()
 
 
 class v8DetectionLoss:
     """Criterion class for computing training losses."""
 
-    def __init__(self, model):  # model must be de-paralleled
+    def __init__(self, model, tal_topk=10):  # model must be de-paralleled
         """Initializes v8DetectionLoss with the model, defining model-related properties and BCE loss function."""
         device = next(model.parameters()).device  # get model device
         h = model.args  # hyperparameters
@@ -157,29 +167,29 @@ class v8DetectionLoss:
         self.hyp = h
         self.stride = m.stride  # model strides
         self.nc = m.nc  # number of classes
-        self.no = m.no
+        self.no = m.nc + m.reg_max * 4
         self.reg_max = m.reg_max
         self.device = device
 
         self.use_dfl = m.reg_max > 1
 
-        self.assigner = TaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0)
-        self.bbox_loss = BboxLoss(m.reg_max - 1, use_dfl=self.use_dfl).to(device)
+        self.assigner = TaskAlignedAssigner(topk=tal_topk, num_classes=self.nc, alpha=0.5, beta=6.0)
+        self.bbox_loss = BboxLoss(m.reg_max).to(device)
         self.proj = torch.arange(m.reg_max, dtype=torch.float, device=device)
 
     def preprocess(self, targets, batch_size, scale_tensor):
         """Preprocesses the target counts and matches with the input batch size to output a tensor."""
-        if targets.shape[0] == 0:
-            out = torch.zeros(batch_size, 0, 5, device=self.device)
+        nl, ne = targets.shape
+        if nl == 0:
+            out = torch.zeros(batch_size, 0, ne - 1, device=self.device)
         else:
             i = targets[:, 0]  # image index
             _, counts = i.unique(return_counts=True)
             counts = counts.to(dtype=torch.int32)
-            out = torch.zeros(batch_size, counts.max(), 5, device=self.device)
+            out = torch.zeros(batch_size, counts.max(), ne - 1, device=self.device)
             for j in range(batch_size):
                 matches = i == j
-                n = matches.sum()
-                if n:
+                if n := matches.sum():
                     out[j, :n] = targets[matches, 1:]
             out[..., 1:5] = xywh2xyxy(out[..., 1:5].mul_(scale_tensor))
         return out
@@ -213,12 +223,15 @@ class v8DetectionLoss:
         targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
         targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
         gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
-        mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
+        mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
 
         # Pboxes
         pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)
+        # dfl_conf = pred_distri.view(batch_size, -1, 4, self.reg_max).detach().softmax(-1)
+        # dfl_conf = (dfl_conf.amax(-1).mean(-1) + dfl_conf.amax(-1).amin(-1)) / 2
 
         _, target_bboxes, target_scores, fg_mask, _ = self.assigner(
+            # pred_scores.detach().sigmoid() * 0.8 + dfl_conf.unsqueeze(-1) * 0.2,
             pred_scores.detach().sigmoid(),
             (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
             anchor_points * stride_tensor,
@@ -279,7 +292,7 @@ class v8SegmentationLoss(v8DetectionLoss):
             targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)
             targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
             gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
-            mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
+            mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
         except RuntimeError as e:
             raise TypeError(
                 "ERROR ❌ segment dataset incorrectly formatted or not a segment dataset.\n"
@@ -466,7 +479,7 @@ class v8PoseLoss(v8DetectionLoss):
         targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)
         targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
         gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
-        mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
+        mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
 
         # Pboxes
         pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)
@@ -538,9 +551,8 @@ class v8PoseLoss(v8DetectionLoss):
            pred_kpts (torch.Tensor): Predicted keypoints, shape (BS, N_anchors, N_kpts_per_object, kpts_dim).
 
        Returns:
-            (tuple): Returns a tuple containing:
-                - kpts_loss (torch.Tensor): The keypoints loss.
-                - kpts_obj_loss (torch.Tensor): The keypoints object loss.
+            kpts_loss (torch.Tensor): The keypoints loss.
+            kpts_obj_loss (torch.Tensor): The keypoints object loss.
        """
        batch_idx = batch_idx.flatten()
        batch_size = len(masks)
@@ -591,21 +603,20 @@ class v8ClassificationLoss:
 
     def __call__(self, preds, batch):
         """Compute the classification loss between predictions and true labels."""
-        loss = F.cross_entropy(preds, batch["cls"], reduction="mean")
+        preds = preds[1] if isinstance(preds, (list, tuple)) else preds
+        loss = F.cross_entropy(preds, batch["cls"], reduction="mean")
         loss_items = loss.detach()
         return loss, loss_items
 
 
 class v8OBBLoss(v8DetectionLoss):
-    def __init__(self, model):
-        """
-        Initializes v8OBBLoss with model, assigner, and rotated bbox loss.
+    """Calculates losses for object detection, classification, and box distribution in rotated YOLO models."""
 
-        Note model must be de-paralleled.
-        """
+    def __init__(self, model):
+        """Initializes v8OBBLoss with model, assigner, and rotated bbox loss; note model must be de-paralleled."""
         super().__init__(model)
         self.assigner = RotatedTaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0)
-        self.bbox_loss = RotatedBboxLoss(self.reg_max - 1, use_dfl=self.use_dfl).to(self.device)
+        self.bbox_loss = RotatedBboxLoss(self.reg_max).to(self.device)
 
     def preprocess(self, targets, batch_size, scale_tensor):
         """Preprocesses the target counts and matches with the input batch size to output a tensor."""
@@ -618,8 +629,7 @@ class v8OBBLoss(v8DetectionLoss):
                 out = torch.zeros(batch_size, counts.max(), 6, device=self.device)
                 for j in range(batch_size):
                     matches = i == j
-                    n = matches.sum()
-                    if n:
+                    if n := matches.sum():
                         bboxes = targets[matches, 2:]
                         bboxes[..., :4].mul_(scale_tensor)
                         out[j, :n] = torch.cat([targets[matches, 1:2], bboxes], dim=-1)
@@ -651,7 +661,7 @@ class v8OBBLoss(v8DetectionLoss):
                 targets = targets[(rw >= 2) & (rh >= 2)]  # filter rboxes of tiny size to stabilize training
                 targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
                 gt_labels, gt_bboxes = targets.split((1, 5), 2)  # cls, xywhr
-                mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
+                mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
             except RuntimeError as e:
                 raise TypeError(
                     "ERROR ❌ OBB dataset incorrectly formatted or not a OBB dataset.\n"
@@ -713,3 +723,21 @@ class v8OBBLoss(v8DetectionLoss):
         b, a, c = pred_dist.shape  # batch, anchors, channels
         pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))
         return torch.cat((dist2rbox(pred_dist, pred_angle, anchor_points), pred_angle), dim=-1)
+
+
+class E2EDetectLoss:
+    """Criterion class for computing training losses."""
+
+    def __init__(self, model):
+        """Initialize E2EDetectLoss with one-to-many and one-to-one detection losses using the provided model."""
+        self.one2many = v8DetectionLoss(model, tal_topk=10)
+        self.one2one = v8DetectionLoss(model, tal_topk=1)
+
+    def __call__(self, preds, batch):
+        """Calculate the sum of the loss for box, cls and dfl multiplied by batch size."""
+        preds = preds[1] if isinstance(preds, tuple) else preds
+        one2many = preds["one2many"]
+        loss_one2many = self.one2many(one2many, batch)
+        one2one = preds["one2one"]
+        loss_one2one = self.one2one(one2one, batch)
+        return loss_one2many[0] + loss_one2one[0], loss_one2many[1] + loss_one2one[1]
ultralytics/utils/metrics.py
CHANGED
@@ -1,4 +1,4 @@
-# Ultralytics YOLO 🚀, AGPL-3.0 license
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 """Model validation metrics."""
 
 import math
@@ -30,7 +30,6 @@ def bbox_ioa(box1, box2, iou=False, eps=1e-7):
     Returns:
         (np.ndarray): A numpy array of shape (n, m) representing the intersection over box2 area.
     """
-
     # Get the coordinates of bounding boxes
     b1_x1, b1_y1, b1_x2, b1_y2 = box1.T
     b2_x1, b2_y1, b2_x2, b2_y2 = box2.T
@@ -53,7 +52,7 @@ def bbox_ioa(box1, box2, iou=False, eps=1e-7):
 def box_iou(box1, box2, eps=1e-7):
     """
     Calculate intersection-over-union (IoU) of boxes. Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
-    Based on https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
+    Based on https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py.
 
     Args:
         box1 (torch.Tensor): A tensor of shape (N, 4) representing N bounding boxes.
@@ -63,9 +62,9 @@ def box_iou(box1, box2, eps=1e-7):
     Returns:
         (torch.Tensor): An NxM tensor containing the pairwise IoU values for every element in box1 and box2.
     """
-
+    # NOTE: Need .float() to get accurate iou values
     # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
-    (a1, a2), (b1, b2) = box1.unsqueeze(1).chunk(2, 2), box2.unsqueeze(0).chunk(2, 2)
+    (a1, a2), (b1, b2) = box1.float().unsqueeze(1).chunk(2, 2), box2.float().unsqueeze(0).chunk(2, 2)
     inter = (torch.min(a2, b2) - torch.max(a1, b1)).clamp_(0).prod(2)
 
     # IoU = inter / (area1 + area2 - inter)
@@ -74,11 +73,16 @@ def box_iou(box1, box2, eps=1e-7):
 
 def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
     """
-    Calculate Intersection over Union (IoU) of box1(1, 4) to box2(n, 4).
+    Calculates the Intersection over Union (IoU) between bounding boxes.
+
+    This function supports various shapes for `box1` and `box2` as long as the last dimension is 4.
+    For instance, you may pass tensors shaped like (4,), (N, 4), (B, N, 4), or (B, N, 1, 4).
+    Internally, the code will split the last dimension into (x, y, w, h) if `xywh=True`,
+    or (x1, y1, x2, y2) if `xywh=False`.
 
     Args:
-        box1 (torch.Tensor): A tensor representing a single bounding box with shape (1, 4).
-        box2 (torch.Tensor): A tensor representing n bounding boxes with shape (n, 4).
+        box1 (torch.Tensor): A tensor representing one or more bounding boxes, with the last dimension being 4.
+        box2 (torch.Tensor): A tensor representing one or more bounding boxes, with the last dimension being 4.
         xywh (bool, optional): If True, input boxes are in (x, y, w, h) format. If False, input boxes are in
             (x1, y1, x2, y2) format. Defaults to True.
         GIoU (bool, optional): If True, calculate Generalized IoU. Defaults to False.
@@ -89,7 +93,6 @@ def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7
     Returns:
         (torch.Tensor): IoU, GIoU, DIoU, or CIoU values depending on the specified flags.
     """
-
     # Get the coordinates of bounding boxes
     if xywh:  # transform from xywh to xyxy
         (x1, y1, w1, h1), (x2, y2, w2, h2) = box1.chunk(4, -1), box2.chunk(4, -1)
@@ -167,7 +170,7 @@ def kpt_iou(kpt1, kpt2, area, sigma, eps=1e-7):
     d = (kpt1[:, None, :, 0] - kpt2[..., 0]).pow(2) + (kpt1[:, None, :, 1] - kpt2[..., 1]).pow(2)  # (N, M, 17)
     sigma = torch.tensor(sigma, device=kpt1.device, dtype=kpt1.dtype)  # (17, )
     kpt_mask = kpt1[..., 2] != 0  # (N, 17)
-    e = d / (2 * sigma).pow(2) / (area[:, None, None] + eps) / 2  # from cocoeval
+    e = d / ((2 * sigma).pow(2) * (area[:, None, None] + eps) * 2)  # from cocoeval
     # e = d / ((area[None, :, None] + eps) * sigma) ** 2 / 2  # from formula
     return ((-e).exp() * kpt_mask[:, None]).sum(-1) / (kpt_mask.sum(-1)[:, None] + eps)
 
@@ -180,7 +183,7 @@ def _get_covariance_matrix(boxes):
         boxes (torch.Tensor): A tensor of shape (N, 5) representing rotated bounding boxes, with xywhr format.
 
     Returns:
-        (torch.Tensor): Covariance metrixs corresponding to original rotated bounding boxes.
+        (torch.Tensor): Covariance matrices corresponding to original rotated bounding boxes.
     """
     # Gaussian bounding boxes, ignore the center points (the first two columns) because they are not needed here.
     gbbs = torch.cat((boxes[:, 2:4].pow(2) / 12, boxes[:, 4:]), dim=-1)
@@ -194,15 +197,22 @@ def _get_covariance_matrix(boxes):
 
 def probiou(obb1, obb2, CIoU=False, eps=1e-7):
     """
-    Calculate the prob IoU between oriented bounding boxes, https://arxiv.org/pdf/2106.06072v1.pdf.
+    Calculate probabilistic IoU between oriented bounding boxes.
+
+    Implements the algorithm from https://arxiv.org/pdf/2106.06072v1.pdf.
 
     Args:
-        obb1 (torch.Tensor): A tensor of shape (N, 5) representing ground truth obbs, with xywhr format.
-        obb2 (torch.Tensor): A tensor of shape (N, 5) representing predicted obbs, with xywhr format.
-        eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.
+        obb1 (torch.Tensor): Ground truth OBBs, shape (N, 5), format xywhr.
+        obb2 (torch.Tensor): Predicted OBBs, shape (N, 5), format xywhr.
+        CIoU (bool, optional): If True, calculate CIoU. Defaults to False.
+        eps (float, optional): Small value to avoid division by zero. Defaults to 1e-7.
 
     Returns:
-        (torch.Tensor): A tensor of shape (N, ) representing obb similarities.
+        (torch.Tensor): OBB similarities, shape (N,).
+
+    Note:
+        OBB format: [center_x, center_y, width, height, rotation_angle].
+        If CIoU is True, returns CIoU instead of IoU.
     """
     x1, y1 = obb1[..., :2].split(1, dim=-1)
     x2, y2 = obb2[..., :2].split(1, dim=-1)
@@ -265,7 +275,7 @@ def batch_probiou(obb1, obb2, eps=1e-7):
     return 1 - hd
 
 
-def smooth_BCE(eps=0.1):
+def smooth_bce(eps=0.1):
     """
     Computes smoothed positive and negative Binary Cross-Entropy targets.
 
@@ -298,7 +308,7 @@ class ConfusionMatrix:
         self.task = task
         self.matrix = np.zeros((nc + 1, nc + 1)) if self.task == "detect" else np.zeros((nc, nc))
         self.nc = nc  # number of classes
-        self.conf = 0.25 if conf in (None, 0.001) else conf  # apply 0.25 if default val conf is passed
+        self.conf = 0.25 if conf in {None, 0.001} else conf  # apply 0.25 if default val conf is passed
         self.iou_thres = iou_thres
 
     def process_cls_preds(self, preds, targets):
@@ -367,10 +377,9 @@ class ConfusionMatrix:
             else:
                 self.matrix[self.nc, gc] += 1  # true background
 
-        if n:
-            for i, dc in enumerate(detection_classes):
-                if not any(m1 == i):
-                    self.matrix[dc, self.nc] += 1  # predicted background
+        for i, dc in enumerate(detection_classes):
+            if not any(m1 == i):
+                self.matrix[dc, self.nc] += 1  # predicted background
 
     def matrix(self):
         """Returns the confusion matrix."""
@@ -395,19 +404,19 @@ class ConfusionMatrix:
            names (tuple): Names of classes, used as labels on the plot.
            on_plot (func): An optional callback to pass plots path and data when they are rendered.
        """
-        import seaborn as sn
+        import seaborn  # scope for faster 'import ultralytics'
 
         array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1e-9) if normalize else 1)  # normalize columns
         array[array < 0.005] = np.nan  # don't annotate (would appear as 0.00)
 
         fig, ax = plt.subplots(1, 1, figsize=(12, 9), tight_layout=True)
         nc, nn = self.nc, len(names)  # number of classes, names
-        sn.set(font_scale=1.0 if nc < 50 else 0.8)  # for label size
+        seaborn.set_theme(font_scale=1.0 if nc < 50 else 0.8)  # for label size
         labels = (0 < nn < 99) and (nn == nc)  # apply names to ticklabels
         ticklabels = (list(names) + ["background"]) if labels else "auto"
         with warnings.catch_warnings():
             warnings.simplefilter("ignore")  # suppress empty matrix RuntimeWarning: All-NaN slice encountered
-            sn.heatmap(
+            seaborn.heatmap(
                 array,
                 ax=ax,
                 annot=nc < 30,
@@ -423,7 +432,7 @@ class ConfusionMatrix:
         ax.set_xlabel("True")
         ax.set_ylabel("Predicted")
         ax.set_title(title)
-        plot_fname = Path(save_dir) / f'{title.lower().replace(" ", "_")}.png'
+        plot_fname = Path(save_dir) / f"{title.lower().replace(' ', '_')}.png"
         fig.savefig(plot_fname, dpi=250)
         plt.close(fig)
         if on_plot:
@@ -444,7 +453,7 @@ def smooth(y, f=0.05):
 
 
 @plt_settings()
-def plot_pr_curve(px, py, ap, save_dir=Path("pr_curve.png"), names=(), on_plot=None):
+def plot_pr_curve(px, py, ap, save_dir=Path("pr_curve.png"), names={}, on_plot=None):
     """Plots a precision-recall curve."""
     fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
     py = np.stack(py, axis=1)
@@ -455,7 +464,7 @@ def plot_pr_curve(px, py, ap, save_dir=Path("pr_curve.png"), names=(), on_plot=N
     else:
         ax.plot(px, py, linewidth=1, color="grey")  # plot(recall, precision)
 
-    ax.plot(px, py.mean(1), linewidth=3, color="blue", label="all classes %.3f mAP@0.5" % ap[:, 0].mean())
+    ax.plot(px, py.mean(1), linewidth=3, color="blue", label=f"all classes {ap[:, 0].mean():.3f} mAP@0.5")
     ax.set_xlabel("Recall")
     ax.set_ylabel("Precision")
     ax.set_xlim(0, 1)
@@ -469,7 +478,7 @@ def plot_pr_curve(px, py, ap, save_dir=Path("pr_curve.png"), names=(), on_plot=N
 
 
 @plt_settings()
-def plot_mc_curve(px, py, save_dir=Path("mc_curve.png"), names=(), xlabel="Confidence", ylabel="Metric", on_plot=None):
+def plot_mc_curve(px, py, save_dir=Path("mc_curve.png"), names={}, xlabel="Confidence", ylabel="Metric", on_plot=None):
     """Plots a metric-confidence curve."""
     fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
 
@@ -506,7 +515,6 @@ def compute_ap(recall, precision):
         (np.ndarray): Precision envelope curve.
         (np.ndarray): Modified recall curve with sentinel values added at the beginning and end.
     """
-
     # Append sentinel values to beginning and end
     mrec = np.concatenate(([0.0], recall, [1.0]))
     mpre = np.concatenate(([1.0], precision, [0.0]))
@@ -527,7 +535,7 @@ def compute_ap(recall, precision):
 
 
 def ap_per_class(
-    tp, conf, pred_cls, target_cls, plot=False, on_plot=None, save_dir=Path(), names=(), eps=1e-16, prefix=""
+    tp, conf, pred_cls, target_cls, plot=False, on_plot=None, save_dir=Path(), names={}, eps=1e-16, prefix=""
 ):
     """
     Computes the average precision per class for object detection evaluation.
@@ -540,26 +548,24 @@ def ap_per_class(
        plot (bool, optional): Whether to plot PR curves or not. Defaults to False.
        on_plot (func, optional): A callback to pass plots path and data when they are rendered. Defaults to None.
        save_dir (Path, optional): Directory to save the PR curves. Defaults to an empty path.
-        names (tuple, optional): Tuple of class names to plot PR curves. Defaults to an empty tuple.
+        names (dict, optional): Dict of class names to plot PR curves. Defaults to an empty tuple.
        eps (float, optional): A small value to avoid division by zero. Defaults to 1e-16.
        prefix (str, optional): A prefix string for saving the plot files. Defaults to an empty string.
 
    Returns:
-        (tuple): A tuple of six arrays and one array of unique classes, where:
-            tp (np.ndarray): True positive counts at threshold given by max F1 metric for each class.Shape: (nc,).
-            fp (np.ndarray): False positive counts at threshold given by max F1 metric for each class. Shape: (nc,).
-            p (np.ndarray): Precision values at threshold given by max F1 metric for each class. Shape: (nc,).
-            r (np.ndarray): Recall values at threshold given by max F1 metric for each class. Shape: (nc,).
-            f1 (np.ndarray): F1-score values at threshold given by max F1 metric for each class. Shape: (nc,).
-            ap (np.ndarray): Average precision for each class at different IoU thresholds. Shape: (nc, 10).
-            unique_classes (np.ndarray): An array of unique classes that have data. Shape: (nc,).
-            p_curve (np.ndarray): Precision curves for each class. Shape: (nc, 1000).
-            r_curve (np.ndarray): Recall curves for each class. Shape: (nc, 1000).
-            f1_curve (np.ndarray): F1-score curves for each class. Shape: (nc, 1000).
-            x (np.ndarray): X-axis values for the curves. Shape: (1000,).
-            prec_values: Precision values at mAP@0.5 for each class. Shape: (nc, 1000).
+        tp (np.ndarray): True positive counts at threshold given by max F1 metric for each class.Shape: (nc,).
+        fp (np.ndarray): False positive counts at threshold given by max F1 metric for each class. Shape: (nc,).
+        p (np.ndarray): Precision values at threshold given by max F1 metric for each class. Shape: (nc,).
+        r (np.ndarray): Recall values at threshold given by max F1 metric for each class. Shape: (nc,).
+        f1 (np.ndarray): F1-score values at threshold given by max F1 metric for each class. Shape: (nc,).
+        ap (np.ndarray): Average precision for each class at different IoU thresholds. Shape: (nc, 10).
+        unique_classes (np.ndarray): An array of unique classes that have data. Shape: (nc,).
+        p_curve (np.ndarray): Precision curves for each class. Shape: (nc, 1000).
+        r_curve (np.ndarray): Recall curves for each class. Shape: (nc, 1000).
+        f1_curve (np.ndarray): F1-score curves for each class. Shape: (nc, 1000).
+        x (np.ndarray): X-axis values for the curves. Shape: (1000,).
+        prec_values (np.ndarray): Precision values at mAP@0.5 for each class. Shape: (nc, 1000).
    """
-
    # Sort by objectness
    i = np.argsort(-conf)
    tp, conf, pred_cls = tp[i], conf[i], pred_cls[i]
@@ -595,7 +601,7 @@ def ap_per_class(
        # AP from recall-precision curve
        for j in range(tp.shape[1]):
            ap[ci, j], mpre, mrec = compute_ap(recall[:, j], precision[:, j])
-            if plot and j == 0:
+            if j == 0:
                prec_values.append(np.interp(x, mrec, mpre))  # precision at mAP@0.5
 
    prec_values = np.array(prec_values)  # (nc, 1000)
@@ -791,20 +797,20 @@ class Metric(SimpleClass):
 
 class DetMetrics(SimpleClass):
    """
-    This class is a utility class for computing detection metrics such as precision, recall, and mean average precision
-    (mAP) of an object detection model.
+    Utility class for computing detection metrics such as precision, recall, and mean average precision (mAP) of an
+    object detection model.
 
    Args:
        save_dir (Path): A path to the directory where the output plots will be saved. Defaults to current directory.
        plot (bool): A flag that indicates whether to plot precision-recall curves for each class. Defaults to False.
        on_plot (func): An optional callback to pass plots path and data when they are rendered. Defaults to None.
-        names (tuple of str): A tuple of strings that represents the names of the classes. Defaults to an empty tuple.
+        names (dict of str): A dict of strings that represents the names of the classes. Defaults to an empty tuple.
 
    Attributes:
        save_dir (Path): A path to the directory where the output plots will be saved.
        plot (bool): A flag that indicates whether to plot the precision-recall curves for each class.
        on_plot (func): An optional callback to pass plots path and data when they are rendered.
-        names (tuple of str): A tuple of strings that represents the names of the classes.
+        names (dict of str): A dict of strings that represents the names of the classes.
        box (Metric): An instance of the Metric class for storing the results of the detection metrics.
        speed (dict): A dictionary for storing the execution time of different parts of the detection process.
 
@@ -821,7 +827,7 @@ class DetMetrics(SimpleClass):
        curves_results: TODO
    """
 
-    def __init__(self, save_dir=Path("."), plot=False, on_plot=None, names=()) -> None:
+    def __init__(self, save_dir=Path("."), plot=False, on_plot=None, names={}) -> None:
        """Initialize a DetMetrics instance with a save directory, plot flag, callback function, and class names."""
        self.save_dir = save_dir
        self.plot = plot
@@ -941,7 +947,6 @@ class SegmentMetrics(SimpleClass):
            pred_cls (list): List of predicted classes.
            target_cls (list): List of target classes.
        """
-
        results_mask = ap_per_class(
            tp_m,
            conf,
@@ -1083,7 +1088,6 @@ class PoseMetrics(SegmentMetrics):
            pred_cls (list): List of predicted classes.
            target_cls (list): List of target classes.
        """
-
        results_pose = ap_per_class(
            tp_p,
            conf,
@@ -1171,8 +1175,6 @@ class ClassifyMetrics(SimpleClass):
        top1 (float): The top-1 accuracy.
        top5 (float): The top-5 accuracy.
        speed (Dict[str, float]): A dictionary containing the time taken for each step in the pipeline.
-
-    Properties:
        fitness (float): The fitness of the model, which is equal to top-5 accuracy.
        results_dict (Dict[str, Union[float, str]]): A dictionary containing the classification metrics and fitness.
        keys (List[str]): A list of keys for the results_dict.
@@ -1222,7 +1224,10 @@ class ClassifyMetrics(SimpleClass):
 
 
 class OBBMetrics(SimpleClass):
+    """Metrics for evaluating oriented bounding box (OBB) detection, see https://arxiv.org/pdf/2106.06072.pdf."""
+
     def __init__(self, save_dir=Path("."), plot=False, on_plot=None, names=()) -> None:
+        """Initialize an OBBMetrics instance with directory, plotting, callback, and class names."""
         self.save_dir = save_dir
         self.plot = plot
         self.on_plot = on_plot