dgenerate-ultralytics-headless 8.3.214__py3-none-any.whl → 8.4.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/METADATA +64 -74
- dgenerate_ultralytics_headless-8.4.7.dist-info/RECORD +311 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/WHEEL +1 -1
- tests/__init__.py +7 -9
- tests/conftest.py +8 -15
- tests/test_cli.py +1 -1
- tests/test_cuda.py +13 -10
- tests/test_engine.py +9 -9
- tests/test_exports.py +65 -13
- tests/test_integrations.py +13 -13
- tests/test_python.py +125 -69
- tests/test_solutions.py +161 -152
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +86 -92
- ultralytics/cfg/datasets/Argoverse.yaml +7 -6
- ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
- ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
- ultralytics/cfg/datasets/ImageNet.yaml +1 -1
- ultralytics/cfg/datasets/TT100K.yaml +346 -0
- ultralytics/cfg/datasets/VOC.yaml +15 -16
- ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
- ultralytics/cfg/datasets/coco-pose.yaml +21 -0
- ultralytics/cfg/datasets/coco12-formats.yaml +101 -0
- ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
- ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
- ultralytics/cfg/datasets/dog-pose.yaml +28 -0
- ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
- ultralytics/cfg/datasets/kitti.yaml +27 -0
- ultralytics/cfg/datasets/lvis.yaml +5 -5
- ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
- ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
- ultralytics/cfg/datasets/xView.yaml +16 -16
- ultralytics/cfg/default.yaml +4 -2
- ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
- ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
- ultralytics/cfg/models/26/yolo26-cls.yaml +33 -0
- ultralytics/cfg/models/26/yolo26-obb.yaml +52 -0
- ultralytics/cfg/models/26/yolo26-p2.yaml +60 -0
- ultralytics/cfg/models/26/yolo26-p6.yaml +62 -0
- ultralytics/cfg/models/26/yolo26-pose.yaml +53 -0
- ultralytics/cfg/models/26/yolo26-seg.yaml +52 -0
- ultralytics/cfg/models/26/yolo26.yaml +52 -0
- ultralytics/cfg/models/26/yoloe-26-seg.yaml +53 -0
- ultralytics/cfg/models/26/yoloe-26.yaml +53 -0
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
- ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
- ultralytics/cfg/models/v6/yolov6.yaml +1 -1
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
- ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
- ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +5 -6
- ultralytics/data/augment.py +300 -475
- ultralytics/data/base.py +18 -26
- ultralytics/data/build.py +147 -25
- ultralytics/data/converter.py +108 -87
- ultralytics/data/dataset.py +47 -75
- ultralytics/data/loaders.py +42 -49
- ultralytics/data/split.py +5 -6
- ultralytics/data/split_dota.py +8 -15
- ultralytics/data/utils.py +36 -45
- ultralytics/engine/exporter.py +351 -263
- ultralytics/engine/model.py +186 -225
- ultralytics/engine/predictor.py +45 -54
- ultralytics/engine/results.py +198 -325
- ultralytics/engine/trainer.py +165 -106
- ultralytics/engine/tuner.py +41 -43
- ultralytics/engine/validator.py +55 -38
- ultralytics/hub/__init__.py +16 -19
- ultralytics/hub/auth.py +6 -12
- ultralytics/hub/google/__init__.py +7 -10
- ultralytics/hub/session.py +15 -25
- ultralytics/hub/utils.py +5 -8
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +8 -10
- ultralytics/models/fastsam/predict.py +18 -30
- ultralytics/models/fastsam/utils.py +1 -2
- ultralytics/models/fastsam/val.py +5 -7
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +5 -8
- ultralytics/models/nas/predict.py +7 -9
- ultralytics/models/nas/val.py +1 -2
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +5 -8
- ultralytics/models/rtdetr/predict.py +15 -19
- ultralytics/models/rtdetr/train.py +10 -13
- ultralytics/models/rtdetr/val.py +21 -23
- ultralytics/models/sam/__init__.py +15 -2
- ultralytics/models/sam/amg.py +14 -20
- ultralytics/models/sam/build.py +26 -19
- ultralytics/models/sam/build_sam3.py +377 -0
- ultralytics/models/sam/model.py +29 -32
- ultralytics/models/sam/modules/blocks.py +83 -144
- ultralytics/models/sam/modules/decoders.py +19 -37
- ultralytics/models/sam/modules/encoders.py +44 -101
- ultralytics/models/sam/modules/memory_attention.py +16 -30
- ultralytics/models/sam/modules/sam.py +200 -73
- ultralytics/models/sam/modules/tiny_encoder.py +64 -83
- ultralytics/models/sam/modules/transformer.py +18 -28
- ultralytics/models/sam/modules/utils.py +174 -50
- ultralytics/models/sam/predict.py +2248 -350
- ultralytics/models/sam/sam3/__init__.py +3 -0
- ultralytics/models/sam/sam3/decoder.py +546 -0
- ultralytics/models/sam/sam3/encoder.py +529 -0
- ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
- ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
- ultralytics/models/sam/sam3/model_misc.py +199 -0
- ultralytics/models/sam/sam3/necks.py +129 -0
- ultralytics/models/sam/sam3/sam3_image.py +339 -0
- ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
- ultralytics/models/sam/sam3/vitdet.py +547 -0
- ultralytics/models/sam/sam3/vl_combiner.py +160 -0
- ultralytics/models/utils/loss.py +14 -26
- ultralytics/models/utils/ops.py +13 -17
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +10 -13
- ultralytics/models/yolo/classify/train.py +12 -33
- ultralytics/models/yolo/classify/val.py +30 -29
- ultralytics/models/yolo/detect/predict.py +9 -12
- ultralytics/models/yolo/detect/train.py +17 -23
- ultralytics/models/yolo/detect/val.py +77 -59
- ultralytics/models/yolo/model.py +43 -60
- ultralytics/models/yolo/obb/predict.py +7 -16
- ultralytics/models/yolo/obb/train.py +14 -17
- ultralytics/models/yolo/obb/val.py +40 -37
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +7 -22
- ultralytics/models/yolo/pose/train.py +13 -16
- ultralytics/models/yolo/pose/val.py +39 -58
- ultralytics/models/yolo/segment/predict.py +17 -21
- ultralytics/models/yolo/segment/train.py +7 -10
- ultralytics/models/yolo/segment/val.py +95 -47
- ultralytics/models/yolo/world/train.py +8 -14
- ultralytics/models/yolo/world/train_world.py +11 -34
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +16 -23
- ultralytics/models/yolo/yoloe/train.py +36 -44
- ultralytics/models/yolo/yoloe/train_seg.py +11 -11
- ultralytics/models/yolo/yoloe/val.py +15 -20
- ultralytics/nn/__init__.py +7 -7
- ultralytics/nn/autobackend.py +159 -85
- ultralytics/nn/modules/__init__.py +68 -60
- ultralytics/nn/modules/activation.py +4 -6
- ultralytics/nn/modules/block.py +260 -224
- ultralytics/nn/modules/conv.py +52 -97
- ultralytics/nn/modules/head.py +831 -299
- ultralytics/nn/modules/transformer.py +76 -88
- ultralytics/nn/modules/utils.py +16 -21
- ultralytics/nn/tasks.py +180 -195
- ultralytics/nn/text_model.py +45 -69
- ultralytics/optim/__init__.py +5 -0
- ultralytics/optim/muon.py +338 -0
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +13 -19
- ultralytics/solutions/analytics.py +15 -16
- ultralytics/solutions/config.py +6 -7
- ultralytics/solutions/distance_calculation.py +10 -13
- ultralytics/solutions/heatmap.py +8 -14
- ultralytics/solutions/instance_segmentation.py +6 -9
- ultralytics/solutions/object_blurrer.py +7 -10
- ultralytics/solutions/object_counter.py +12 -19
- ultralytics/solutions/object_cropper.py +8 -14
- ultralytics/solutions/parking_management.py +34 -32
- ultralytics/solutions/queue_management.py +10 -12
- ultralytics/solutions/region_counter.py +9 -12
- ultralytics/solutions/security_alarm.py +15 -20
- ultralytics/solutions/similarity_search.py +10 -15
- ultralytics/solutions/solutions.py +77 -76
- ultralytics/solutions/speed_estimation.py +7 -10
- ultralytics/solutions/streamlit_inference.py +2 -4
- ultralytics/solutions/templates/similarity-search.html +7 -18
- ultralytics/solutions/trackzone.py +7 -10
- ultralytics/solutions/vision_eye.py +5 -8
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +3 -5
- ultralytics/trackers/bot_sort.py +10 -27
- ultralytics/trackers/byte_tracker.py +21 -37
- ultralytics/trackers/track.py +4 -7
- ultralytics/trackers/utils/gmc.py +11 -22
- ultralytics/trackers/utils/kalman_filter.py +37 -48
- ultralytics/trackers/utils/matching.py +12 -15
- ultralytics/utils/__init__.py +124 -124
- ultralytics/utils/autobatch.py +2 -4
- ultralytics/utils/autodevice.py +17 -18
- ultralytics/utils/benchmarks.py +57 -71
- ultralytics/utils/callbacks/base.py +8 -10
- ultralytics/utils/callbacks/clearml.py +5 -13
- ultralytics/utils/callbacks/comet.py +32 -46
- ultralytics/utils/callbacks/dvc.py +13 -18
- ultralytics/utils/callbacks/mlflow.py +4 -5
- ultralytics/utils/callbacks/neptune.py +7 -15
- ultralytics/utils/callbacks/platform.py +423 -38
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +25 -31
- ultralytics/utils/callbacks/wb.py +16 -14
- ultralytics/utils/checks.py +127 -85
- ultralytics/utils/cpu.py +3 -8
- ultralytics/utils/dist.py +9 -12
- ultralytics/utils/downloads.py +25 -33
- ultralytics/utils/errors.py +6 -14
- ultralytics/utils/events.py +2 -4
- ultralytics/utils/export/__init__.py +4 -236
- ultralytics/utils/export/engine.py +246 -0
- ultralytics/utils/export/imx.py +117 -63
- ultralytics/utils/export/tensorflow.py +231 -0
- ultralytics/utils/files.py +26 -30
- ultralytics/utils/git.py +9 -11
- ultralytics/utils/instance.py +30 -51
- ultralytics/utils/logger.py +212 -114
- ultralytics/utils/loss.py +601 -215
- ultralytics/utils/metrics.py +128 -156
- ultralytics/utils/nms.py +13 -16
- ultralytics/utils/ops.py +117 -166
- ultralytics/utils/patches.py +75 -21
- ultralytics/utils/plotting.py +75 -80
- ultralytics/utils/tal.py +125 -59
- ultralytics/utils/torch_utils.py +53 -79
- ultralytics/utils/tqdm.py +24 -21
- ultralytics/utils/triton.py +13 -19
- ultralytics/utils/tuner.py +19 -10
- dgenerate_ultralytics_headless-8.3.214.dist-info/RECORD +0 -283
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/top_level.txt +0 -0
ultralytics/utils/loss.py
CHANGED
@@ -2,24 +2,24 @@
 
 from __future__ import annotations
 
+import math
 from typing import Any
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from ultralytics.utils.metrics import OKS_SIGMA
+from ultralytics.utils.metrics import OKS_SIGMA, RLE_WEIGHT
 from ultralytics.utils.ops import crop_mask, xywh2xyxy, xyxy2xywh
 from ultralytics.utils.tal import RotatedTaskAlignedAssigner, TaskAlignedAssigner, dist2bbox, dist2rbox, make_anchors
 from ultralytics.utils.torch_utils import autocast
 
 from .metrics import bbox_iou, probiou
-from .tal import bbox2dist
+from .tal import bbox2dist, rbox2dist
 
 
 class VarifocalLoss(nn.Module):
-    """
-    Varifocal loss by Zhang et al.
+    """Varifocal loss by Zhang et al.
 
     Implements the Varifocal Loss function for addressing class imbalance in object detection by focusing on
     hard-to-classify examples and balancing positive/negative samples.
@@ -51,11 +51,10 @@ class VarifocalLoss(nn.Module):
 
 
 class FocalLoss(nn.Module):
-    """
-    Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5).
+    """Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5).
 
-    Implements the Focal Loss function for addressing class imbalance by down-weighting easy examples and focusing
-    on hard negatives during training.
+    Implements the Focal Loss function for addressing class imbalance by down-weighting easy examples and focusing on
+    hard negatives during training.
 
     Attributes:
         gamma (float): The focusing parameter that controls how much the loss focuses on hard-to-classify examples.
@@ -124,6 +123,8 @@ class BboxLoss(nn.Module):
         target_scores: torch.Tensor,
         target_scores_sum: torch.Tensor,
         fg_mask: torch.Tensor,
+        imgsz: torch.Tensor,
+        stride: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         """Compute IoU and DFL losses for bounding boxes."""
         weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
@@ -136,11 +137,76 @@ class BboxLoss(nn.Module):
             loss_dfl = self.dfl_loss(pred_dist[fg_mask].view(-1, self.dfl_loss.reg_max), target_ltrb[fg_mask]) * weight
             loss_dfl = loss_dfl.sum() / target_scores_sum
         else:
-            loss_dfl = torch.tensor(0.0).to(pred_dist.device)
+            target_ltrb = bbox2dist(anchor_points, target_bboxes)
+            # normalize ltrb by image size
+            target_ltrb = target_ltrb * stride
+            target_ltrb[..., 0::2] /= imgsz[1]
+            target_ltrb[..., 1::2] /= imgsz[0]
+            pred_dist = pred_dist * stride
+            pred_dist[..., 0::2] /= imgsz[1]
+            pred_dist[..., 1::2] /= imgsz[0]
+            loss_dfl = (
+                F.l1_loss(pred_dist[fg_mask], target_ltrb[fg_mask], reduction="none").mean(-1, keepdim=True) * weight
+            )
+            loss_dfl = loss_dfl.sum() / target_scores_sum
 
         return loss_iou, loss_dfl
 
 
+class RLELoss(nn.Module):
+    """Residual Log-Likelihood Estimation Loss.
+
+    Args:
+        use_target_weight (bool): Option to use weighted loss.
+        size_average (bool): Option to average the loss by the batch_size.
+        residual (bool): Option to add L1 loss and let the flow learn the residual error distribution.
+
+    References:
+        https://arxiv.org/abs/2107.11291
+        https://github.com/open-mmlab/mmpose/blob/main/mmpose/models/losses/regression_loss.py
+    """
+
+    def __init__(self, use_target_weight: bool = True, size_average: bool = True, residual: bool = True):
+        """Initialize RLELoss with target weight and residual options.
+
+        Args:
+            use_target_weight (bool): Whether to use target weights for loss calculation.
+            size_average (bool): Whether to average the loss over elements.
+            residual (bool): Whether to include residual log-likelihood term.
+        """
+        super().__init__()
+        self.size_average = size_average
+        self.use_target_weight = use_target_weight
+        self.residual = residual
+
+    def forward(
+        self, sigma: torch.Tensor, log_phi: torch.Tensor, error: torch.Tensor, target_weight: torch.Tensor = None
+    ) -> torch.Tensor:
+        """
+        Args:
+            sigma (torch.Tensor): Output sigma, shape (N, D).
+            log_phi (torch.Tensor): Output log_phi, shape (N).
+            error (torch.Tensor): Error, shape (N, D).
+            target_weight (torch.Tensor): Weights across different joint types, shape (N).
+        """
+        log_sigma = torch.log(sigma)
+        loss = log_sigma - log_phi.unsqueeze(1)
+
+        if self.residual:
+            loss += torch.log(sigma * 2) + torch.abs(error)
+
+        if self.use_target_weight:
+            assert target_weight is not None, "'target_weight' should not be None when 'use_target_weight' is True."
+            if target_weight.dim() == 1:
+                target_weight = target_weight.unsqueeze(1)
+            loss *= target_weight
+
+        if self.size_average:
+            loss /= len(loss)
+
+        return loss.sum()
+
+
 class RotatedBboxLoss(BboxLoss):
     """Criterion class for computing training losses for rotated bounding boxes."""
 
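A minimal, hypothetical smoke test of the RLELoss module introduced above (shapes follow its docstring; the import path assumes the 8.4.7 module layout):

    import torch
    from ultralytics.utils.loss import RLELoss

    rle = RLELoss(use_target_weight=True, size_average=True, residual=True)
    sigma = torch.rand(8, 2) + 0.1   # (N, D) predicted scales, kept positive for torch.log
    error = torch.randn(8, 2)        # (N, D) normalized coordinate errors
    log_phi = torch.randn(8)         # (N,) log-density values from a flow model
    weights = torch.ones(8)          # (N,) per-keypoint target weights
    loss = rle(sigma, log_phi, error, weights)  # scalar tensor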
@@ -157,6 +223,8 @@ class RotatedBboxLoss(BboxLoss):
         target_scores: torch.Tensor,
         target_scores_sum: torch.Tensor,
         fg_mask: torch.Tensor,
+        imgsz: torch.Tensor,
+        stride: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         """Compute IoU and DFL losses for rotated bounding boxes."""
         weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
@@ -165,15 +233,84 @@ class RotatedBboxLoss(BboxLoss):
 
         # DFL loss
         if self.dfl_loss:
-            target_ltrb = bbox2dist(anchor_points, xywh2xyxy(target_bboxes[..., :4]), self.dfl_loss.reg_max - 1)
+            target_ltrb = rbox2dist(
+                target_bboxes[..., :4], anchor_points, target_bboxes[..., 4:5], reg_max=self.dfl_loss.reg_max - 1
+            )
             loss_dfl = self.dfl_loss(pred_dist[fg_mask].view(-1, self.dfl_loss.reg_max), target_ltrb[fg_mask]) * weight
             loss_dfl = loss_dfl.sum() / target_scores_sum
         else:
-            loss_dfl = torch.tensor(0.0).to(pred_dist.device)
+            target_ltrb = rbox2dist(target_bboxes[..., :4], anchor_points, target_bboxes[..., 4:5])
+            target_ltrb = target_ltrb * stride
+            target_ltrb[..., 0::2] /= imgsz[1]
+            target_ltrb[..., 1::2] /= imgsz[0]
+            pred_dist = pred_dist * stride
+            pred_dist[..., 0::2] /= imgsz[1]
+            pred_dist[..., 1::2] /= imgsz[0]
+            loss_dfl = (
+                F.l1_loss(pred_dist[fg_mask], target_ltrb[fg_mask], reduction="none").mean(-1, keepdim=True) * weight
+            )
+            loss_dfl = loss_dfl.sum() / target_scores_sum
 
         return loss_iou, loss_dfl
 
 
+class MultiChannelDiceLoss(nn.Module):
+    """Criterion class for computing multi-channel Dice losses."""
+
+    def __init__(self, smooth: float = 1e-6, reduction: str = "mean"):
+        """Initialize MultiChannelDiceLoss with smoothing and reduction options.
+
+        Args:
+            smooth (float): Smoothing factor to avoid division by zero.
+            reduction (str): Reduction method ('mean', 'sum', or 'none').
+        """
+        super().__init__()
+        self.smooth = smooth
+        self.reduction = reduction
+
+    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+        """Calculate multi-channel Dice loss between predictions and targets."""
+        assert pred.size() == target.size(), "the size of predict and target must be equal."
+
+        pred = pred.sigmoid()
+        intersection = (pred * target).sum(dim=(2, 3))
+        union = pred.sum(dim=(2, 3)) + target.sum(dim=(2, 3))
+        dice = (2.0 * intersection + self.smooth) / (union + self.smooth)
+        dice_loss = 1.0 - dice
+        dice_loss = dice_loss.mean(dim=1)
+
+        if self.reduction == "mean":
+            return dice_loss.mean()
+        elif self.reduction == "sum":
+            return dice_loss.sum()
+        else:
+            return dice_loss
+
+
+class BCEDiceLoss(nn.Module):
+    """Criterion class for computing combined BCE and Dice losses."""
+
+    def __init__(self, weight_bce: float = 0.5, weight_dice: float = 0.5):
+        """Initialize BCEDiceLoss with BCE and Dice weight factors.
+
+        Args:
+            weight_bce (float): Weight factor for BCE loss component.
+            weight_dice (float): Weight factor for Dice loss component.
+        """
+        super().__init__()
+        self.weight_bce = weight_bce
+        self.weight_dice = weight_dice
+        self.bce = nn.BCEWithLogitsLoss()
+        self.dice = MultiChannelDiceLoss(smooth=1)
+
+    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+        """Calculate combined BCE and Dice loss between predictions and targets."""
+        _, _, mask_h, mask_w = pred.shape
+        if tuple(target.shape[-2:]) != (mask_h, mask_w):  # downsample to the same size as pred
+            target = F.interpolate(target, (mask_h, mask_w), mode="nearest")
+        return self.weight_bce * self.bce(pred, target) + self.weight_dice * self.dice(pred, target)
+
+
 class KeypointLoss(nn.Module):
     """Criterion class for computing keypoint losses."""
 
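A small, hypothetical usage sketch for the BCEDiceLoss and MultiChannelDiceLoss criteria defined above (shapes are illustrative; the import path assumes the 8.4.7 module layout):

    import torch
    from ultralytics.utils.loss import BCEDiceLoss

    criterion = BCEDiceLoss(weight_bce=0.5, weight_dice=0.5)
    pred = torch.randn(2, 3, 40, 40)                        # (N, C, H, W) mask logits
    target = torch.randint(0, 2, (2, 3, 160, 160)).float()  # binary masks, resized to pred's size internally
    loss = criterion(pred, target)                          # weighted sum of BCE and Dice terms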
@@ -196,7 +333,7 @@ class KeypointLoss(nn.Module):
 class v8DetectionLoss:
     """Criterion class for computing training losses for YOLOv8 object detection."""
 
-    def __init__(self, model, tal_topk: int = 10):  # model must be de-paralleled
+    def __init__(self, model, tal_topk: int = 10, tal_topk2: int | None = None):  # model must be de-paralleled
         """Initialize v8DetectionLoss with model parameters and task-aligned assignment settings."""
         device = next(model.parameters()).device  # get model device
         h = model.args  # hyperparameters
@@ -212,7 +349,14 @@ class v8DetectionLoss:
 
         self.use_dfl = m.reg_max > 1
 
-        self.assigner = TaskAlignedAssigner(topk=tal_topk, num_classes=self.nc, alpha=0.5, beta=6.0)
+        self.assigner = TaskAlignedAssigner(
+            topk=tal_topk,
+            num_classes=self.nc,
+            alpha=0.5,
+            beta=6.0,
+            stride=self.stride.tolist(),
+            topk2=tal_topk2,
+        )
         self.bbox_loss = BboxLoss(m.reg_max).to(device)
         self.proj = torch.arange(m.reg_max, dtype=torch.float, device=device)
 
@@ -242,35 +386,31 @@ class v8DetectionLoss:
         # pred_dist = (pred_dist.view(b, a, c // 4, 4).softmax(2) * self.proj.type(pred_dist.dtype).view(1, 1, -1, 1)).sum(2)
         return dist2bbox(pred_dist, anchor_points, xywh=False)
 
-    def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
-        """Calculate the sum of the loss for box, cls and dfl multiplied by batch size."""
+    def get_assigned_targets_and_loss(self, preds: dict[str, torch.Tensor], batch: dict[str, Any]) -> tuple:
+        """Calculate the sum of the loss for box, cls and dfl multiplied by batch size and return foreground mask and
+        target indices.
+        """
         loss = torch.zeros(3, device=self.device)  # box, cls, dfl
-        feats = preds[1] if isinstance(preds, tuple) else preds
-        pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
-            (self.reg_max * 4, self.nc), 1
+        pred_distri, pred_scores = (
+            preds["boxes"].permute(0, 2, 1).contiguous(),
+            preds["scores"].permute(0, 2, 1).contiguous(),
         )
-
-        pred_scores = pred_scores.permute(0, 2, 1).contiguous()
-        pred_distri = pred_distri.permute(0, 2, 1).contiguous()
+        anchor_points, stride_tensor = make_anchors(preds["feats"], self.stride, 0.5)
 
         dtype = pred_scores.dtype
         batch_size = pred_scores.shape[0]
-        imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]
-        anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
+        imgsz = torch.tensor(preds["feats"][0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]
 
         # Targets
         targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
-        targets = self.preprocess(targets, batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
+        targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
         gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
         mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
 
         # Pboxes
         pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)
-        # dfl_conf = pred_distri.view(batch_size, -1, 4, self.reg_max).detach().softmax(-1)
-        # dfl_conf = (dfl_conf.amax(-1).mean(-1) + dfl_conf.amax(-1).amin(-1)) / 2
 
-        _, target_bboxes, target_scores, fg_mask, _ = self.assigner(
-            # pred_scores.detach().sigmoid() * 0.8 + dfl_conf.unsqueeze(-1) * 0.2,
+        _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
             pred_scores.detach().sigmoid(),
             (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
             anchor_points * stride_tensor,
@@ -282,7 +422,6 @@ class v8DetectionLoss:
         target_scores_sum = max(target_scores.sum(), 1)
 
         # Cls loss
-        # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum  # VFL way
         loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE
 
         # Bbox loss
@@ -295,112 +434,114 @@ class v8DetectionLoss:
             target_scores,
             target_scores_sum,
             fg_mask,
+            imgsz,
+            stride_tensor,
         )
 
         loss[0] *= self.hyp.box  # box gain
         loss[1] *= self.hyp.cls  # cls gain
         loss[2] *= self.hyp.dfl  # dfl gain
+        return (
+            (fg_mask, target_gt_idx, target_bboxes, anchor_points, stride_tensor),
+            loss,
+            loss.detach(),
+        )  # loss(box, cls, dfl)
 
-        return loss * batch_size, loss.detach()  # loss(box, cls, dfl)
+    def parse_output(
+        self, preds: dict[str, torch.Tensor] | tuple[torch.Tensor, dict[str, torch.Tensor]]
+    ) -> torch.Tensor:
+        """Parse model predictions to extract features."""
+        return preds[1] if isinstance(preds, tuple) else preds
+
+    def __call__(
+        self,
+        preds: dict[str, torch.Tensor] | tuple[torch.Tensor, dict[str, torch.Tensor]],
+        batch: dict[str, torch.Tensor],
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Calculate the sum of the loss for box, cls and dfl multiplied by batch size."""
+        return self.loss(self.parse_output(preds), batch)
+
+    def loss(self, preds: dict[str, torch.Tensor], batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+        """A wrapper for get_assigned_targets_and_loss and parse_output."""
+        batch_size = preds["boxes"].shape[0]
+        loss, loss_detach = self.get_assigned_targets_and_loss(preds, batch)[1:]
+        return loss * batch_size, loss_detach
 
 
 class v8SegmentationLoss(v8DetectionLoss):
     """Criterion class for computing training losses for YOLOv8 segmentation."""
 
-    def __init__(self, model):  # model must be de-paralleled
+    def __init__(self, model, tal_topk: int = 10, tal_topk2: int | None = None):  # model must be de-paralleled
         """Initialize the v8SegmentationLoss class with model parameters and mask overlap setting."""
-        super().__init__(model)
+        super().__init__(model, tal_topk, tal_topk2)
         self.overlap = model.args.overlap_mask
+        self.bcedice_loss = BCEDiceLoss(weight_bce=0.5, weight_dice=0.5)
 
-    def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+    def loss(self, preds: dict[str, torch.Tensor], batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
         """Calculate and return the combined loss for detection and segmentation."""
-        loss = torch.zeros(4, device=self.device)  # box, seg, cls, dfl
-        feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
-        batch_size, _, mask_h, mask_w = proto.shape  # batch size, number of masks, mask height, mask width
-        pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
-            (self.reg_max * 4, self.nc), 1
-        )
-
-        # B, grids, ..
-        pred_scores = pred_scores.permute(0, 2, 1).contiguous()
-        pred_distri = pred_distri.permute(0, 2, 1).contiguous()
-        pred_masks = pred_masks.permute(0, 2, 1).contiguous()
-
-        dtype = pred_scores.dtype
-        imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]  # image size (h,w)
-        anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
-
-        # Targets
-        try:
-            batch_idx = batch["batch_idx"].view(-1, 1)
-            targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)
-            targets = self.preprocess(targets, batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
-            gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
-            mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
-        except RuntimeError as e:
-            raise TypeError(
-                "ERROR ❌ segment dataset incorrectly formatted or not a segment dataset.\n"
-                "This error can occur when incorrectly training a 'segment' model on a 'detect' dataset, "
-                "i.e. 'yolo train model=yolo11n-seg.pt data=coco8.yaml'.\nVerify your dataset is a "
-                "correctly formatted 'segment' dataset using 'data=coco8-seg.yaml' "
-                "as an example.\nSee https://docs.ultralytics.com/datasets/segment/ for help."
-            ) from e
-
-        # Pboxes
-        pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)
-
-        _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
-            pred_scores.detach().sigmoid(),
-            (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
-            anchor_points * stride_tensor,
-            gt_labels,
-            gt_bboxes,
-            mask_gt,
-        )
-
-        target_scores_sum = max(target_scores.sum(), 1)
-
-        # Cls loss
-        # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum  # VFL way
-        loss[2] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE
+        pred_masks, proto = preds["mask_coefficient"].permute(0, 2, 1).contiguous(), preds["proto"]
+        loss = torch.zeros(5, device=self.device)  # box, seg, cls, dfl
+        if isinstance(proto, tuple) and len(proto) == 2:
+            proto, pred_semseg = proto
+        else:
+            pred_semseg = None
+        (fg_mask, target_gt_idx, target_bboxes, _, _), det_loss, _ = self.get_assigned_targets_and_loss(preds, batch)
+        # NOTE: re-assign index for consistency for now. Need to be removed in the future.
+        loss[0], loss[2], loss[3] = det_loss[0], det_loss[1], det_loss[2]
 
+        batch_size, _, mask_h, mask_w = proto.shape  # batch size, number of masks, mask height, mask width
         if fg_mask.sum():
-            # Bbox loss
-            loss[0], loss[3] = self.bbox_loss(
-                pred_distri,
-                pred_bboxes,
-                anchor_points,
-                target_bboxes / stride_tensor,
-                target_scores,
-                target_scores_sum,
-                fg_mask,
-            )
             # Masks loss
             masks = batch["masks"].to(self.device).float()
             if tuple(masks.shape[-2:]) != (mask_h, mask_w):  # downsample
-                masks = F.interpolate(masks[None], (mask_h, mask_w), mode="nearest")[0]
+                # masks = F.interpolate(masks[None], (mask_h, mask_w), mode="nearest")[0]
+                proto = F.interpolate(proto, masks.shape[-2:], mode="bilinear", align_corners=False)
 
+            imgsz = (
+                torch.tensor(preds["feats"][0].shape[2:], device=self.device, dtype=pred_masks.dtype) * self.stride[0]
+            )
             loss[1] = self.calculate_segmentation_loss(
-                fg_mask, masks, target_gt_idx, target_bboxes, batch_idx, proto, pred_masks, imgsz, self.overlap
+                fg_mask,
+                masks,
+                target_gt_idx,
+                target_bboxes,
+                batch["batch_idx"].view(-1, 1),
+                proto,
+                pred_masks,
+                imgsz,
             )
+            if pred_semseg is not None:
+                sem_masks = batch["sem_masks"].to(self.device)  # NxHxW
+                sem_masks = F.one_hot(sem_masks.long(), num_classes=self.nc).permute(0, 3, 1, 2).float()  # NxCxHxW
+
+                if self.overlap:
+                    mask_zero = masks == 0  # NxHxW
+                    sem_masks[mask_zero.unsqueeze(1).expand_as(sem_masks)] = 0
+                else:
+                    batch_idx = batch["batch_idx"].view(-1)  # [total_instances]
+                    for i in range(batch_size):
+                        instance_mask_i = masks[batch_idx == i]  # [num_instances_i, H, W]
+                        if len(instance_mask_i) == 0:
+                            continue
+                        sem_masks[i, :, instance_mask_i.sum(dim=0) == 0] = 0
+
+                loss[4] = self.bcedice_loss(pred_semseg, sem_masks)
+                loss[4] *= self.hyp.box  # seg gain
 
         # WARNING: lines below prevent Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
         else:
             loss[1] += (proto * 0).sum() + (pred_masks * 0).sum()  # inf sums may lead to nan loss
+            if pred_semseg is not None:
+                loss[4] += (pred_semseg * 0).sum()
 
-        loss[0] *= self.hyp.box  # box gain
         loss[1] *= self.hyp.box  # seg gain
-        loss[2] *= self.hyp.cls  # cls gain
-        loss[3] *= self.hyp.dfl  # dfl gain
-
-        return loss * batch_size, loss.detach()  # loss(box, seg, cls, dfl)
+        return loss * batch_size, loss.detach()  # loss(box, cls, dfl)
 
     @staticmethod
     def single_mask_loss(
         gt_mask: torch.Tensor, pred: torch.Tensor, proto: torch.Tensor, xyxy: torch.Tensor, area: torch.Tensor
     ) -> torch.Tensor:
-        """
-        Compute the instance segmentation loss for a single image.
+        """Compute the instance segmentation loss for a single image.
 
         Args:
             gt_mask (torch.Tensor): Ground truth mask of shape (N, H, W), where N is the number of objects.
@@ -430,10 +571,8 @@ class v8SegmentationLoss(v8DetectionLoss):
         proto: torch.Tensor,
         pred_masks: torch.Tensor,
         imgsz: torch.Tensor,
-        overlap: bool,
     ) -> torch.Tensor:
-        """
-        Calculate the loss for instance segmentation.
+        """Calculate the loss for instance segmentation.
 
         Args:
             fg_mask (torch.Tensor): A binary tensor of shape (BS, N_anchors) indicating which anchors are positive.
@@ -444,7 +583,6 @@ class v8SegmentationLoss(v8DetectionLoss):
             proto (torch.Tensor): Prototype masks of shape (BS, 32, H, W).
             pred_masks (torch.Tensor): Predicted masks for each anchor of shape (BS, N_anchors, 32).
             imgsz (torch.Tensor): Size of the input image as a tensor of shape (2), i.e., (H, W).
-            overlap (bool): Whether the masks in `masks` tensor overlap.
 
         Returns:
             (torch.Tensor): The calculated loss for instance segmentation.
@@ -470,7 +608,7 @@ class v8SegmentationLoss(v8DetectionLoss):
             fg_mask_i, target_gt_idx_i, pred_masks_i, proto_i, mxyxy_i, marea_i, masks_i = single_i
             if fg_mask_i.any():
                 mask_idx = target_gt_idx_i[fg_mask_i]
-                if overlap:
+                if self.overlap:
                     gt_mask = masks_i == (mask_idx + 1).view(-1, 1, 1)
                     gt_mask = gt_mask.float()
                 else:
@@ -490,9 +628,9 @@ class v8SegmentationLoss(v8DetectionLoss):
 class v8PoseLoss(v8DetectionLoss):
     """Criterion class for computing training losses for YOLOv8 pose estimation."""
 
-    def __init__(self, model):  # model must be de-paralleled
+    def __init__(self, model, tal_topk: int = 10, tal_topk2: int = 10):  # model must be de-paralleled
         """Initialize v8PoseLoss with model parameters and keypoint-specific loss functions."""
-        super().__init__(model)
+        super().__init__(model, tal_topk, tal_topk2)
         self.kpt_shape = model.model[-1].kpt_shape
         self.bce_pose = nn.BCEWithLogitsLoss()
         is_pose = self.kpt_shape == [17, 3]
@@ -500,71 +638,42 @@ class v8PoseLoss(v8DetectionLoss):
         sigmas = torch.from_numpy(OKS_SIGMA).to(self.device) if is_pose else torch.ones(nkpt, device=self.device) / nkpt
         self.keypoint_loss = KeypointLoss(sigmas=sigmas)
 
-    def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+    def loss(self, preds: dict[str, torch.Tensor], batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
         """Calculate the total loss and detach it for pose estimation."""
+        pred_kpts = preds["kpts"].permute(0, 2, 1).contiguous()
         loss = torch.zeros(5, device=self.device)  # box, cls, dfl, kpt_location, kpt_visibility
-        feats, pred_kpts = preds if isinstance(preds[0], list) else preds[1]
-        pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
-            (self.reg_max * 4, self.nc), 1
+        (fg_mask, target_gt_idx, target_bboxes, anchor_points, stride_tensor), det_loss, _ = (
+            self.get_assigned_targets_and_loss(preds, batch)
         )
+        # NOTE: re-assign index for consistency for now. Need to be removed in the future.
+        loss[0], loss[3], loss[4] = det_loss[0], det_loss[1], det_loss[2]
 
-        # B, grids, ..
-        pred_scores = pred_scores.permute(0, 2, 1).contiguous()
-        pred_distri = pred_distri.permute(0, 2, 1).contiguous()
-        pred_kpts = pred_kpts.permute(0, 2, 1).contiguous()
-
-        dtype = pred_scores.dtype
-        imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]  # image size (h,w)
-        anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
-
-        # Targets
-        batch_size = pred_scores.shape[0]
-        batch_idx = batch["batch_idx"].view(-1, 1)
-        targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)
-        targets = self.preprocess(targets, batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
-        gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
-        mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
+        batch_size = pred_kpts.shape[0]
+        imgsz = torch.tensor(preds["feats"][0].shape[2:], device=self.device, dtype=pred_kpts.dtype) * self.stride[0]
 
         # Pboxes
-        pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)
         pred_kpts = self.kpts_decode(anchor_points, pred_kpts.view(batch_size, -1, *self.kpt_shape))  # (b, h*w, 17, 3)
 
-        _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
-            pred_scores.detach().sigmoid(),
-            (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
-            anchor_points * stride_tensor,
-            gt_labels,
-            gt_bboxes,
-            mask_gt,
-        )
-
-        target_scores_sum = max(target_scores.sum(), 1)
-
-        # Cls loss
-        # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum  # VFL way
-        loss[3] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE
-
         # Bbox loss
         if fg_mask.sum():
-            target_bboxes /= stride_tensor
-            loss[0], loss[4] = self.bbox_loss(
-                pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
-            )
             keypoints = batch["keypoints"].to(self.device).float().clone()
             keypoints[..., 0] *= imgsz[1]
             keypoints[..., 1] *= imgsz[0]
 
             loss[1], loss[2] = self.calculate_keypoints_loss(
-                fg_mask, target_gt_idx, keypoints, batch_idx, stride_tensor, target_bboxes, pred_kpts
+                fg_mask,
+                target_gt_idx,
+                keypoints,
+                batch["batch_idx"].view(-1, 1),
+                stride_tensor,
+                target_bboxes,
+                pred_kpts,
             )
 
-        loss[0] *= self.hyp.box  # box gain
         loss[1] *= self.hyp.pose  # pose gain
         loss[2] *= self.hyp.kobj  # kobj gain
-        loss[3] *= self.hyp.cls  # cls gain
-        loss[4] *= self.hyp.dfl  # dfl gain
 
-        return loss * batch_size, loss.detach()  # loss(box, cls, dfl)
+        return loss * batch_size, loss.detach()  # loss(box, pose, kobj, cls, dfl)
 
     @staticmethod
     def kpts_decode(anchor_points: torch.Tensor, pred_kpts: torch.Tensor) -> torch.Tensor:
@@ -575,35 +684,23 @@ class v8PoseLoss(v8DetectionLoss):
         y[..., 1] += anchor_points[:, [1]] - 0.5
         return y
 
-    def calculate_keypoints_loss(
+    def _select_target_keypoints(
         self,
-        masks: torch.Tensor,
-        target_gt_idx: torch.Tensor,
         keypoints: torch.Tensor,
         batch_idx: torch.Tensor,
-        stride_tensor: torch.Tensor,
-        target_bboxes: torch.Tensor,
-        pred_kpts: torch.Tensor,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
-        """
-        Calculate the keypoints loss for the model.
-
-        This function calculates the keypoints loss and keypoints object loss for a given batch. The keypoints loss is
-        based on the difference between the predicted keypoints and ground truth keypoints. The keypoints object loss is
-        a binary classification loss that classifies whether a keypoint is present or not.
+        target_gt_idx: torch.Tensor,
+        masks: torch.Tensor,
+    ) -> torch.Tensor:
+        """Select target keypoints for each anchor based on batch index and target ground truth index.
 
         Args:
-            masks (torch.Tensor): Binary mask tensor indicating object presence, shape (BS, N_anchors).
-            target_gt_idx (torch.Tensor): Index tensor mapping anchors to ground truth objects, shape (BS, N_anchors).
             keypoints (torch.Tensor): Ground truth keypoints, shape (N_kpts_in_batch, N_kpts_per_object, kpts_dim).
             batch_idx (torch.Tensor): Batch index tensor for keypoints, shape (N_kpts_in_batch, 1).
-            stride_tensor (torch.Tensor): Stride tensor for anchors, shape (N_anchors, 1).
-            target_bboxes (torch.Tensor): Ground truth boxes in (x1, y1, x2, y2) format, shape (BS, N_anchors, 4).
-            pred_kpts (torch.Tensor): Predicted keypoints, shape (BS, N_anchors, N_kpts_per_object, kpts_dim).
+            target_gt_idx (torch.Tensor): Index tensor mapping anchors to ground truth objects, shape (BS, N_anchors).
+            masks (torch.Tensor): Binary mask tensor indicating object presence, shape (BS, N_anchors).
 
         Returns:
-            kpts_loss (torch.Tensor): The keypoints loss.
-            kpts_obj_loss (torch.Tensor): The keypoints object loss.
+            (torch.Tensor): Selected keypoints tensor, shape (BS, N_anchors, N_kpts_per_object, kpts_dim).
         """
         batch_idx = batch_idx.flatten()
         batch_size = len(masks)
@@ -630,6 +727,40 @@ class v8PoseLoss(v8DetectionLoss):
             1, target_gt_idx_expanded.expand(-1, -1, keypoints.shape[1], keypoints.shape[2])
         )
 
+        return selected_keypoints
+
+    def calculate_keypoints_loss(
+        self,
+        masks: torch.Tensor,
+        target_gt_idx: torch.Tensor,
+        keypoints: torch.Tensor,
+        batch_idx: torch.Tensor,
+        stride_tensor: torch.Tensor,
+        target_bboxes: torch.Tensor,
+        pred_kpts: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Calculate the keypoints loss for the model.
+
+        This function calculates the keypoints loss and keypoints object loss for a given batch. The keypoints loss is
+        based on the difference between the predicted keypoints and ground truth keypoints. The keypoints object loss is
+        a binary classification loss that classifies whether a keypoint is present or not.
+
+        Args:
+            masks (torch.Tensor): Binary mask tensor indicating object presence, shape (BS, N_anchors).
+            target_gt_idx (torch.Tensor): Index tensor mapping anchors to ground truth objects, shape (BS, N_anchors).
+            keypoints (torch.Tensor): Ground truth keypoints, shape (N_kpts_in_batch, N_kpts_per_object, kpts_dim).
+            batch_idx (torch.Tensor): Batch index tensor for keypoints, shape (N_kpts_in_batch, 1).
+            stride_tensor (torch.Tensor): Stride tensor for anchors, shape (N_anchors, 1).
+            target_bboxes (torch.Tensor): Ground truth boxes in (x1, y1, x2, y2) format, shape (BS, N_anchors, 4).
+            pred_kpts (torch.Tensor): Predicted keypoints, shape (BS, N_anchors, N_kpts_per_object, kpts_dim).
+
+        Returns:
+            kpts_loss (torch.Tensor): The keypoints loss.
+            kpts_obj_loss (torch.Tensor): The keypoints object loss.
+        """
+        # Select target keypoints using helper method
+        selected_keypoints = self._select_target_keypoints(keypoints, batch_idx, target_gt_idx, masks)
+
         # Divide coordinates by stride
         selected_keypoints[..., :2] /= stride_tensor.view(1, -1, 1, 1)
 
@@ -637,6 +768,7 @@ class v8PoseLoss(v8DetectionLoss):
         kpts_obj_loss = 0
 
         if masks.any():
+            target_bboxes /= stride_tensor
             gt_kpt = selected_keypoints[masks]
             area = xyxy2xywh(target_bboxes[masks])[:, 2:].prod(1, keepdim=True)
             pred_kpt = pred_kpts[masks]
@@ -649,6 +781,172 @@ class v8PoseLoss(v8DetectionLoss):
         return kpts_loss, kpts_obj_loss
 
 
+class PoseLoss26(v8PoseLoss):
+    """Criterion class for computing training losses for YOLOv8 pose estimation with RLE loss support."""
+
+    def __init__(self, model, tal_topk: int = 10, tal_topk2: int | None = None):  # model must be de-paralleled
+        """Initialize PoseLoss26 with model parameters and keypoint-specific loss functions including RLE loss."""
+        super().__init__(model, tal_topk, tal_topk2)
+        is_pose = self.kpt_shape == [17, 3]
+        nkpt = self.kpt_shape[0]  # number of keypoints
+        self.rle_loss = None
+        self.flow_model = model.model[-1].flow_model if hasattr(model.model[-1], "flow_model") else None
+        if self.flow_model is not None:
+            self.rle_loss = RLELoss(use_target_weight=True).to(self.device)
+            self.target_weights = (
+                torch.from_numpy(RLE_WEIGHT).to(self.device) if is_pose else torch.ones(nkpt, device=self.device)
+            )
+
+    def loss(self, preds: dict[str, torch.Tensor], batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+        """Calculate the total loss and detach it for pose estimation."""
+        pred_kpts = preds["kpts"].permute(0, 2, 1).contiguous()
+        loss = torch.zeros(6 if self.rle_loss else 5, device=self.device)  # box, cls, dfl, kpt_location, kpt_visibility
+        (fg_mask, target_gt_idx, target_bboxes, anchor_points, stride_tensor), det_loss, _ = (
+            self.get_assigned_targets_and_loss(preds, batch)
+        )
+        # NOTE: re-assign index for consistency for now. Need to be removed in the future.
+        loss[0], loss[3], loss[4] = det_loss[0], det_loss[1], det_loss[2]
+
+        batch_size = pred_kpts.shape[0]
+        imgsz = torch.tensor(preds["feats"][0].shape[2:], device=self.device, dtype=pred_kpts.dtype) * self.stride[0]
+
+        pred_kpts = pred_kpts.view(batch_size, -1, *self.kpt_shape)  # (b, h*w, 17, 3)
+
+        if self.rle_loss and preds.get("kpts_sigma", None) is not None:
+            pred_sigma = preds["kpts_sigma"].permute(0, 2, 1).contiguous()
+            pred_sigma = pred_sigma.view(batch_size, -1, self.kpt_shape[0], 2)  # (b, h*w, 17, 2)
+            pred_kpts = torch.cat([pred_kpts, pred_sigma], dim=-1)  # (b, h*w, 17, 5)
+
+        pred_kpts = self.kpts_decode(anchor_points, pred_kpts)
+
+        # Bbox loss
+        if fg_mask.sum():
+            keypoints = batch["keypoints"].to(self.device).float().clone()
+            keypoints[..., 0] *= imgsz[1]
+            keypoints[..., 1] *= imgsz[0]
+
+            keypoints_loss = self.calculate_keypoints_loss(
+                fg_mask,
+                target_gt_idx,
+                keypoints,
+                batch["batch_idx"].view(-1, 1),
+                stride_tensor,
+                target_bboxes,
+                pred_kpts,
+            )
+            loss[1] = keypoints_loss[0]
+            loss[2] = keypoints_loss[1]
+            if self.rle_loss is not None:
+                loss[5] = keypoints_loss[2]
+
+        loss[1] *= self.hyp.pose  # pose gain
+        loss[2] *= self.hyp.kobj  # kobj gain
+        if self.rle_loss is not None:
+            loss[5] *= self.hyp.rle  # rle gain
+
+        return loss * batch_size, loss.detach()  # loss(box, cls, dfl, kpt_location, kpt_visibility)
+
+    @staticmethod
+    def kpts_decode(anchor_points: torch.Tensor, pred_kpts: torch.Tensor) -> torch.Tensor:
+        """Decode predicted keypoints to image coordinates."""
+        y = pred_kpts.clone()
+        y[..., 0] += anchor_points[:, [0]]
+        y[..., 1] += anchor_points[:, [1]]
+        return y
+
+    def calculate_rle_loss(self, pred_kpt: torch.Tensor, gt_kpt: torch.Tensor, kpt_mask: torch.Tensor) -> torch.Tensor:
+        """Calculate the RLE (Residual Log-likelihood Estimation) loss for keypoints.
+
+        Args:
+            pred_kpt (torch.Tensor): Predicted keypoints with sigma, shape (N, kpts_dim) where kpts_dim >= 4.
+            gt_kpt (torch.Tensor): Ground truth keypoints, shape (N, kpts_dim).
+            kpt_mask (torch.Tensor): Mask for valid keypoints, shape (N, num_keypoints).
+
+        Returns:
+            (torch.Tensor): The RLE loss.
+        """
+        pred_kpt_visible = pred_kpt[kpt_mask]
+        gt_kpt_visible = gt_kpt[kpt_mask]
+        pred_coords = pred_kpt_visible[:, 0:2]
+        pred_sigma = pred_kpt_visible[:, -2:]
+        gt_coords = gt_kpt_visible[:, 0:2]
+
+        target_weights = self.target_weights.unsqueeze(0).repeat(kpt_mask.shape[0], 1)
+        target_weights = target_weights[kpt_mask]
+
+        pred_sigma = pred_sigma.sigmoid()
+        error = (pred_coords - gt_coords) / (pred_sigma + 1e-9)
+
+        # Filter out NaN and Inf values to prevent MultivariateNormal validation errors
+        valid_mask = ~(torch.isnan(error) | torch.isinf(error)).any(dim=-1)
+        if not valid_mask.any():
+            return torch.tensor(0.0, device=pred_kpt.device)
+
+        error = error[valid_mask]
+        error = error.clamp(-100, 100)  # Prevent numerical instability
+        pred_sigma = pred_sigma[valid_mask]
+        target_weights = target_weights[valid_mask]
+
+        log_phi = self.flow_model.log_prob(error)
+
+        return self.rle_loss(pred_sigma, log_phi, error, target_weights)
+
+    def calculate_keypoints_loss(
+        self,
+        masks: torch.Tensor,
+        target_gt_idx: torch.Tensor,
+        keypoints: torch.Tensor,
+        batch_idx: torch.Tensor,
+        stride_tensor: torch.Tensor,
+        target_bboxes: torch.Tensor,
+        pred_kpts: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Calculate the keypoints loss for the model.
+
+        This function calculates the keypoints loss and keypoints object loss for a given batch. The keypoints loss is
+        based on the difference between the predicted keypoints and ground truth keypoints. The keypoints object loss is
+        a binary classification loss that classifies whether a keypoint is present or not.
+
+        Args:
+            masks (torch.Tensor): Binary mask tensor indicating object presence, shape (BS, N_anchors).
+            target_gt_idx (torch.Tensor): Index tensor mapping anchors to ground truth objects, shape (BS, N_anchors).
+            keypoints (torch.Tensor): Ground truth keypoints, shape (N_kpts_in_batch, N_kpts_per_object, kpts_dim).
+            batch_idx (torch.Tensor): Batch index tensor for keypoints, shape (N_kpts_in_batch, 1).
+            stride_tensor (torch.Tensor): Stride tensor for anchors, shape (N_anchors, 1).
+            target_bboxes (torch.Tensor): Ground truth boxes in (x1, y1, x2, y2) format, shape (BS, N_anchors, 4).
+            pred_kpts (torch.Tensor): Predicted keypoints, shape (BS, N_anchors, N_kpts_per_object, kpts_dim).
+
+        Returns:
+            kpts_loss (torch.Tensor): The keypoints loss.
+            kpts_obj_loss (torch.Tensor): The keypoints object loss.
+            rle_loss (torch.Tensor): The RLE loss.
+        """
+        # Select target keypoints using inherited helper method
+        selected_keypoints = self._select_target_keypoints(keypoints, batch_idx, target_gt_idx, masks)
+
+        # Divide coordinates by stride
+        selected_keypoints[..., :2] /= stride_tensor.view(1, -1, 1, 1)
+
+        kpts_loss = 0
+        kpts_obj_loss = 0
+        rle_loss = 0
+
+        if masks.any():
+            target_bboxes /= stride_tensor
+            gt_kpt = selected_keypoints[masks]
+            area = xyxy2xywh(target_bboxes[masks])[:, 2:].prod(1, keepdim=True)
+            pred_kpt = pred_kpts[masks]
+            kpt_mask = gt_kpt[..., 2] != 0 if gt_kpt.shape[-1] == 3 else torch.full_like(gt_kpt[..., 0], True)
+            kpts_loss = self.keypoint_loss(pred_kpt, gt_kpt, kpt_mask, area)  # pose loss
+
+            if self.rle_loss is not None and (pred_kpt.shape[-1] == 4 or pred_kpt.shape[-1] == 5):
+                rle_loss = self.calculate_rle_loss(pred_kpt, gt_kpt, kpt_mask)
+            if pred_kpt.shape[-1] == 3 or pred_kpt.shape[-1] == 5:
+                kpts_obj_loss = self.bce_pose(pred_kpt[..., 2], kpt_mask.float())  # keypoint obj loss
+
+        return kpts_loss, kpts_obj_loss, rle_loss
+
+
 class v8ClassificationLoss:
     """Criterion class for computing training losses for classification."""
 
@@ -662,10 +960,17 @@ class v8ClassificationLoss:
 class v8OBBLoss(v8DetectionLoss):
     """Calculates losses for object detection, classification, and box distribution in rotated YOLO models."""
 
-    def __init__(self, model):
+    def __init__(self, model, tal_topk=10, tal_topk2: int | None = None):
         """Initialize v8OBBLoss with model, assigner, and rotated bbox loss; model must be de-paralleled."""
-        super().__init__(model)
-        self.assigner = RotatedTaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0)
+        super().__init__(model, tal_topk=tal_topk)
+        self.assigner = RotatedTaskAlignedAssigner(
+            topk=tal_topk,
+            num_classes=self.nc,
+            alpha=0.5,
+            beta=6.0,
+            stride=self.stride.tolist(),
+            topk2=tal_topk2,
+        )
         self.bbox_loss = RotatedBboxLoss(self.reg_max).to(self.device)
 
     def preprocess(self, targets: torch.Tensor, batch_size: int, scale_tensor: torch.Tensor) -> torch.Tensor:
@@ -685,38 +990,34 @@ class v8OBBLoss(v8DetectionLoss):
                     out[j, :n] = torch.cat([targets[matches, 1:2], bboxes], dim=-1)
         return out
 
-    def 
+    def loss(self, preds: dict[str, torch.Tensor], batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
         """Calculate and return the loss for oriented bounding box detection."""
-        loss = torch.zeros(
-
-
-
-        (
+        loss = torch.zeros(4, device=self.device)  # box, cls, dfl, angle
+        pred_distri, pred_scores, pred_angle = (
+            preds["boxes"].permute(0, 2, 1).contiguous(),
+            preds["scores"].permute(0, 2, 1).contiguous(),
+            preds["angle"].permute(0, 2, 1).contiguous(),
         )
-
-        # 
-        pred_scores = pred_scores.permute(0, 2, 1).contiguous()
-        pred_distri = pred_distri.permute(0, 2, 1).contiguous()
-        pred_angle = pred_angle.permute(0, 2, 1).contiguous()
+        anchor_points, stride_tensor = make_anchors(preds["feats"], self.stride, 0.5)
+        batch_size = pred_angle.shape[0]  # batch size, number of masks, mask height, mask width
 
         dtype = pred_scores.dtype
-        imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]
-        anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
+        imgsz = torch.tensor(preds["feats"][0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]
 
         # targets
         try:
            batch_idx = batch["batch_idx"].view(-1, 1)
            targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"].view(-1, 5)), 1)
-            rw, rh = targets[:, 4] * imgsz[
+            rw, rh = targets[:, 4] * float(imgsz[1]), targets[:, 5] * float(imgsz[0])
            targets = targets[(rw >= 2) & (rh >= 2)]  # filter rboxes of tiny size to stabilize training
-            targets = self.preprocess(targets, batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
+            targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
            gt_labels, gt_bboxes = targets.split((1, 5), 2)  # cls, xywhr
            mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
        except RuntimeError as e:
            raise TypeError(
                "ERROR ❌ OBB dataset incorrectly formatted or not a OBB dataset.\n"
                "This error can occur when incorrectly training a 'OBB' model on a 'detect' dataset, "
-                "i.e. 'yolo train model=
+                "i.e. 'yolo train model=yolo26n-obb.pt data=dota8.yaml'.\nVerify your dataset is a "
                "correctly formatted 'OBB' dataset using 'data=dota8.yaml' "
                "as an example.\nSee https://docs.ultralytics.com/datasets/obb/ for help."
            ) from e
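The reworked loss reads predictions from a dict (`preds["boxes"]`, `preds["scores"]`, `preds["angle"]`, `preds["feats"]`) and drops rotated ground-truth boxes whose width or height falls under 2 px at the training resolution. A small sketch of that target assembly and tiny-box filter, using invented normalized labels purely for illustration:

```python
import torch

imgsz = torch.tensor([640.0, 640.0])  # (h, w) of the letterboxed training image

# Fake batch annotations: image index, class id, and normalized xywhr rotated boxes
batch_idx = torch.tensor([[0.0], [0.0], [1.0]])
cls = torch.tensor([[3.0], [1.0], [0.0]])
bboxes = torch.tensor([
    [0.50, 0.50, 0.200, 0.100, 0.3],   # 128 x 64 px  -> kept
    [0.25, 0.75, 0.001, 0.050, 0.0],   # 0.64 px wide -> filtered out
    [0.10, 0.10, 0.050, 0.040, 1.2],   # 32 x 25.6 px -> kept
])

targets = torch.cat((batch_idx, cls, bboxes), 1)  # (N, 7): idx, cls, x, y, w, h, r
rw = targets[:, 4] * float(imgsz[1])              # width in pixels
rh = targets[:, 5] * float(imgsz[0])              # height in pixels
targets = targets[(rw >= 2) & (rh >= 2)]          # drop tiny rboxes to stabilize training
print(targets.shape)                              # torch.Size([2, 7])
```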
@@ -746,22 +1047,34 @@ class v8OBBLoss(v8DetectionLoss):
         if fg_mask.sum():
             target_bboxes[..., :4] /= stride_tensor
             loss[0], loss[2] = self.bbox_loss(
-                pred_distri,
+                pred_distri,
+                pred_bboxes,
+                anchor_points,
+                target_bboxes,
+                target_scores,
+                target_scores_sum,
+                fg_mask,
+                imgsz,
+                stride_tensor,
             )
+            weight = target_scores.sum(-1)[fg_mask]
+            loss[3] = self.calculate_angle_loss(
+                pred_bboxes, target_bboxes, fg_mask, weight, target_scores_sum
+            )  # angle loss
         else:
             loss[0] += (pred_angle * 0).sum()
 
         loss[0] *= self.hyp.box  # box gain
         loss[1] *= self.hyp.cls  # cls gain
         loss[2] *= self.hyp.dfl  # dfl gain
+        loss[3] *= self.hyp.angle  # angle gain
 
-        return loss * batch_size, loss.detach()  # loss(box, cls, dfl)
+        return loss * batch_size, loss.detach()  # loss(box, cls, dfl, angle)
 
     def bbox_decode(
         self, anchor_points: torch.Tensor, pred_dist: torch.Tensor, pred_angle: torch.Tensor
     ) -> torch.Tensor:
-        """
-        Decode predicted object bounding box coordinates from anchor points and distribution.
+        """Decode predicted object bounding box coordinates from anchor points and distribution.
 
         Args:
             anchor_points (torch.Tensor): Anchor points, (h*w, 2).
@@ -776,6 +1089,34 @@ class v8OBBLoss(v8DetectionLoss):
         pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))
         return torch.cat((dist2rbox(pred_dist, pred_angle, anchor_points), pred_angle), dim=-1)
 
+    def calculate_angle_loss(self, pred_bboxes, target_bboxes, fg_mask, weight, target_scores_sum, lambda_val=3):
+        """Calculate oriented angle loss.
+
+        Args:
+            pred_bboxes: [N, 5] (x, y, w, h, theta).
+            target_bboxes: [N, 5] (x, y, w, h, theta).
+            fg_mask: Foreground mask indicating valid predictions.
+            weight: Loss weights for each prediction.
+            target_scores_sum: Sum of target scores for normalization.
+            lambda_val: control the sensitivity to aspect ratio.
+        """
+        w_gt = target_bboxes[..., 2]
+        h_gt = target_bboxes[..., 3]
+        pred_theta = pred_bboxes[..., 4]
+        target_theta = target_bboxes[..., 4]
+
+        log_ar = torch.log(w_gt / h_gt)
+        scale_weight = torch.exp(-(log_ar**2) / (lambda_val**2))
+
+        delta_theta = pred_theta - target_theta
+        delta_theta_wrapped = delta_theta - torch.round(delta_theta / math.pi) * math.pi
+        ang_loss = torch.sin(2 * delta_theta_wrapped[fg_mask]) ** 2
+
+        ang_loss = scale_weight[fg_mask] * ang_loss
+        ang_loss = ang_loss * weight
+
+        return ang_loss.sum() / target_scores_sum
+
 
 class E2EDetectLoss:
     """Criterion class for computing training losses for end-to-end detection."""
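The new `calculate_angle_loss` wraps the angle error to roughly a half period (rotated boxes repeat every π), penalizes it with sin²(2Δθ), and scales the result by exp(-(ln(w/h))² / λ²), which is 1 for square ground-truth boxes and shrinks as the box elongates, with λ controlling the sensitivity. A standalone numeric sketch of that per-anchor penalty, as an illustrative re-implementation rather than the package's API:

```python
import math
import torch

def angle_penalty(pred_theta, gt_theta, w_gt, h_gt, lambda_val=3.0):
    """Aspect-ratio-weighted angle penalty (illustrative re-implementation)."""
    delta = pred_theta - gt_theta
    delta = delta - torch.round(delta / math.pi) * math.pi   # wrap to about [-pi/2, pi/2]
    penalty = torch.sin(2 * delta) ** 2                      # 0 at 0 or 90 deg error, 1 at 45 deg error
    log_ar = torch.log(w_gt / h_gt)
    scale = torch.exp(-(log_ar ** 2) / (lambda_val ** 2))    # 1 for square boxes, decays as the box elongates
    return scale * penalty

theta_err = torch.tensor(math.radians(30.0))
zero = torch.tensor(0.0)
print(angle_penalty(theta_err, zero, torch.tensor(1.0), torch.tensor(1.0)))   # ~0.75 (square box, full weight)
print(angle_penalty(theta_err, zero, torch.tensor(10.0), torch.tensor(1.0)))  # ~0.42 (10:1 box, weight ~0.55)
```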
@@ -795,63 +1136,108 @@ class E2EDetectLoss:
         return loss_one2many[0] + loss_one2one[0], loss_one2many[1] + loss_one2one[1]
 
 
+class E2ELoss:
+    """Criterion class for computing training losses for end-to-end detection."""
+
+    def __init__(self, model, loss_fn=v8DetectionLoss):
+        """Initialize E2ELoss with one-to-many and one-to-one detection losses using the provided model."""
+        self.one2many = loss_fn(model, tal_topk=10)
+        self.one2one = loss_fn(model, tal_topk=7, tal_topk2=1)
+        self.updates = 0
+        self.total = 1.0
+        # init gain
+        self.o2m = 0.8
+        self.o2o = self.total - self.o2m
+        self.o2m_copy = self.o2m
+        # final gain
+        self.final_o2m = 0.1
+
+    def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+        """Calculate the sum of the loss for box, cls and dfl multiplied by batch size."""
+        preds = self.one2many.parse_output(preds)
+        one2many, one2one = preds["one2many"], preds["one2one"]
+        loss_one2many = self.one2many.loss(one2many, batch)
+        loss_one2one = self.one2one.loss(one2one, batch)
+        return loss_one2many[0] * self.o2m + loss_one2one[0] * self.o2o, loss_one2one[1]
+
+    def update(self) -> None:
+        """Update the weights for one-to-many and one-to-one losses based on the decay schedule."""
+        self.updates += 1
+        self.o2m = self.decay(self.updates)
+        self.o2o = max(self.total - self.o2m, 0)
+
+    def decay(self, x) -> float:
+        """Calculate the decayed weight for one-to-many loss based on the current update step."""
+        return max(1 - x / max(self.one2one.hyp.epochs - 1, 1), 0) * (self.o2m_copy - self.final_o2m) + self.final_o2m
+
+
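`E2ELoss` blends a one-to-many head (gain `o2m`, starting at 0.8) with a one-to-one head (gain `total - o2m`) and linearly decays `o2m` toward `final_o2m = 0.1` over training. A quick sketch of that schedule in isolation, with a made-up epoch count:

```python
def o2m_weight(epoch, epochs, init_o2m=0.8, final_o2m=0.1):
    """Linear decay of the one-to-many gain from init_o2m to final_o2m over training."""
    frac = max(1 - epoch / max(epochs - 1, 1), 0)  # 1 at epoch 0, 0 at the last epoch
    return frac * (init_o2m - final_o2m) + final_o2m

epochs = 100  # hypothetical training length
for epoch in (0, 25, 50, 99):
    o2m = o2m_weight(epoch, epochs)
    print(f"epoch {epoch:3d}: o2m={o2m:.3f}, o2o={1 - o2m:.3f}")
```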
 class TVPDetectLoss:
     """Criterion class for computing training losses for text-visual prompt detection."""
 
-    def __init__(self, model):
+    def __init__(self, model, tal_topk=10):
         """Initialize TVPDetectLoss with task-prompt and visual-prompt criteria using the provided model."""
-        self.vp_criterion = v8DetectionLoss(model)
+        self.vp_criterion = v8DetectionLoss(model, tal_topk)
         # NOTE: store following info as it's changeable in __call__
+        self.hyp = self.vp_criterion.hyp
         self.ori_nc = self.vp_criterion.nc
         self.ori_no = self.vp_criterion.no
         self.ori_reg_max = self.vp_criterion.reg_max
 
+    def parse_output(self, preds) -> dict[str, torch.Tensor]:
+        """Parse model predictions to extract features."""
+        return self.vp_criterion.parse_output(preds)
+
     def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
         """Calculate the loss for text-visual prompt detection."""
-
+        return self.loss(self.parse_output(preds), batch)
+
+    def loss(self, preds: dict[str, torch.Tensor], batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+        """Calculate the loss for text-visual prompt detection."""
         assert self.ori_reg_max == self.vp_criterion.reg_max  # TODO: remove it
 
-        if self.
+        if self.ori_nc == preds["scores"].shape[1]:
             loss = torch.zeros(3, device=self.vp_criterion.device, requires_grad=True)
             return loss, loss.detach()
 
-
-        vp_loss = self.vp_criterion(
+        preds["scores"] = self._get_vp_features(preds)
+        vp_loss = self.vp_criterion(preds, batch)
         box_loss = vp_loss[0][1]
         return box_loss, vp_loss[1]
 
-    def _get_vp_features(self, 
+    def _get_vp_features(self, preds: dict[str, torch.Tensor]) -> list[torch.Tensor]:
         """Extract visual-prompt features from the model output."""
-
+        # NOTE: remove empty placeholder
+        scores = preds["scores"][:, self.ori_nc :, :]
+        vnc = scores.shape[1]
 
         self.vp_criterion.nc = vnc
         self.vp_criterion.no = vnc + self.vp_criterion.reg_max * 4
         self.vp_criterion.assigner.num_classes = vnc
-
-        return [
-            torch.cat((box, cls_vp), dim=1)
-            for box, _, cls_vp in [xi.split((self.ori_reg_max * 4, self.ori_nc, vnc), dim=1) for xi in feats]
-        ]
+        return scores
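`_get_vp_features` now simply slices the visual-prompt class channels off the combined score tensor (the first `ori_nc` channels are the placeholder text/task classes) and retargets the criterion to the remaining `vnc` classes. A toy illustration of that slice, with invented channel counts:

```python
import torch

ori_nc = 80    # placeholder text-prompt classes
vnc = 3        # visual-prompt classes appended after them
anchors = 8400

scores = torch.rand(2, ori_nc + vnc, anchors)  # (batch, ori_nc + vnc, n_anchors)
vp_scores = scores[:, ori_nc:, :]              # keep only the visual-prompt channels
print(vp_scores.shape)                         # torch.Size([2, 3, 8400])
```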
 
 
 class TVPSegmentLoss(TVPDetectLoss):
     """Criterion class for computing training losses for text-visual prompt segmentation."""
 
-    def __init__(self, model):
+    def __init__(self, model, tal_topk=10):
         """Initialize TVPSegmentLoss with task-prompt and visual-prompt criteria using the provided model."""
         super().__init__(model)
-        self.vp_criterion = v8SegmentationLoss(model)
+        self.vp_criterion = v8SegmentationLoss(model, tal_topk)
+        self.hyp = self.vp_criterion.hyp
 
     def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
         """Calculate the loss for text-visual prompt segmentation."""
-
+        return self.loss(self.parse_output(preds), batch)
+
+    def loss(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+        """Calculate the loss for text-visual prompt detection."""
         assert self.ori_reg_max == self.vp_criterion.reg_max  # TODO: remove it
 
-        if self.
+        if self.ori_nc == preds["scores"].shape[1]:
             loss = torch.zeros(4, device=self.vp_criterion.device, requires_grad=True)
             return loss, loss.detach()
 
-
-        vp_loss = self.vp_criterion(
+        preds["scores"] = self._get_vp_features(preds)
+        vp_loss = self.vp_criterion(preds, batch)
         cls_loss = vp_loss[0][2]
         return cls_loss, vp_loss[1]