dgenerate-ultralytics-headless 8.3.248__py3-none-any.whl → 8.4.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.248.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/METADATA +52 -61
- {dgenerate_ultralytics_headless-8.3.248.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/RECORD +97 -84
- {dgenerate_ultralytics_headless-8.3.248.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/WHEEL +1 -1
- tests/__init__.py +2 -2
- tests/conftest.py +1 -1
- tests/test_cuda.py +8 -2
- tests/test_engine.py +8 -8
- tests/test_exports.py +11 -4
- tests/test_integrations.py +9 -9
- tests/test_python.py +41 -16
- tests/test_solutions.py +3 -3
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +31 -31
- ultralytics/cfg/datasets/TT100K.yaml +346 -0
- ultralytics/cfg/datasets/coco12-formats.yaml +101 -0
- ultralytics/cfg/default.yaml +3 -1
- ultralytics/cfg/models/26/yolo26-cls.yaml +33 -0
- ultralytics/cfg/models/26/yolo26-obb.yaml +52 -0
- ultralytics/cfg/models/26/yolo26-p2.yaml +60 -0
- ultralytics/cfg/models/26/yolo26-p6.yaml +62 -0
- ultralytics/cfg/models/26/yolo26-pose.yaml +53 -0
- ultralytics/cfg/models/26/yolo26-seg.yaml +52 -0
- ultralytics/cfg/models/26/yolo26.yaml +52 -0
- ultralytics/cfg/models/26/yoloe-26-seg.yaml +53 -0
- ultralytics/cfg/models/26/yoloe-26.yaml +53 -0
- ultralytics/data/annotator.py +2 -2
- ultralytics/data/augment.py +15 -0
- ultralytics/data/converter.py +76 -45
- ultralytics/data/dataset.py +1 -1
- ultralytics/data/utils.py +2 -2
- ultralytics/engine/exporter.py +34 -28
- ultralytics/engine/model.py +38 -37
- ultralytics/engine/predictor.py +17 -17
- ultralytics/engine/results.py +22 -15
- ultralytics/engine/trainer.py +83 -48
- ultralytics/engine/tuner.py +20 -11
- ultralytics/engine/validator.py +16 -16
- ultralytics/models/fastsam/predict.py +1 -1
- ultralytics/models/yolo/classify/predict.py +1 -1
- ultralytics/models/yolo/classify/train.py +1 -1
- ultralytics/models/yolo/classify/val.py +1 -1
- ultralytics/models/yolo/detect/predict.py +2 -2
- ultralytics/models/yolo/detect/train.py +6 -3
- ultralytics/models/yolo/detect/val.py +7 -1
- ultralytics/models/yolo/model.py +8 -8
- ultralytics/models/yolo/obb/predict.py +2 -2
- ultralytics/models/yolo/obb/train.py +3 -3
- ultralytics/models/yolo/obb/val.py +1 -1
- ultralytics/models/yolo/pose/predict.py +1 -1
- ultralytics/models/yolo/pose/train.py +3 -1
- ultralytics/models/yolo/pose/val.py +1 -1
- ultralytics/models/yolo/segment/predict.py +3 -3
- ultralytics/models/yolo/segment/train.py +4 -4
- ultralytics/models/yolo/segment/val.py +2 -2
- ultralytics/models/yolo/yoloe/train.py +6 -1
- ultralytics/models/yolo/yoloe/train_seg.py +6 -1
- ultralytics/nn/autobackend.py +14 -8
- ultralytics/nn/modules/__init__.py +8 -0
- ultralytics/nn/modules/block.py +128 -8
- ultralytics/nn/modules/head.py +788 -203
- ultralytics/nn/tasks.py +86 -41
- ultralytics/nn/text_model.py +5 -2
- ultralytics/optim/__init__.py +5 -0
- ultralytics/optim/muon.py +338 -0
- ultralytics/solutions/ai_gym.py +3 -3
- ultralytics/solutions/config.py +1 -1
- ultralytics/solutions/heatmap.py +1 -1
- ultralytics/solutions/instance_segmentation.py +2 -2
- ultralytics/solutions/object_counter.py +1 -1
- ultralytics/solutions/parking_management.py +1 -1
- ultralytics/solutions/solutions.py +2 -2
- ultralytics/trackers/byte_tracker.py +7 -7
- ultralytics/trackers/track.py +1 -1
- ultralytics/utils/__init__.py +8 -8
- ultralytics/utils/benchmarks.py +26 -26
- ultralytics/utils/callbacks/platform.py +173 -64
- ultralytics/utils/callbacks/tensorboard.py +2 -0
- ultralytics/utils/callbacks/wb.py +6 -1
- ultralytics/utils/checks.py +28 -9
- ultralytics/utils/dist.py +1 -0
- ultralytics/utils/downloads.py +5 -3
- ultralytics/utils/export/engine.py +19 -10
- ultralytics/utils/export/imx.py +38 -20
- ultralytics/utils/export/tensorflow.py +21 -21
- ultralytics/utils/files.py +2 -2
- ultralytics/utils/loss.py +597 -203
- ultralytics/utils/metrics.py +2 -1
- ultralytics/utils/ops.py +11 -2
- ultralytics/utils/patches.py +42 -0
- ultralytics/utils/plotting.py +3 -0
- ultralytics/utils/tal.py +100 -20
- ultralytics/utils/torch_utils.py +1 -1
- ultralytics/utils/tqdm.py +4 -1
- ultralytics/utils/tuner.py +2 -5
- {dgenerate_ultralytics_headless-8.3.248.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.248.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.248.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/top_level.txt +0 -0
ultralytics/utils/loss.py
CHANGED
@@ -2,19 +2,20 @@

 from __future__ import annotations

+import math
 from typing import Any

 import torch
 import torch.nn as nn
 import torch.nn.functional as F

-from ultralytics.utils.metrics import OKS_SIGMA
+from ultralytics.utils.metrics import OKS_SIGMA, RLE_WEIGHT
 from ultralytics.utils.ops import crop_mask, xywh2xyxy, xyxy2xywh
 from ultralytics.utils.tal import RotatedTaskAlignedAssigner, TaskAlignedAssigner, dist2bbox, dist2rbox, make_anchors
 from ultralytics.utils.torch_utils import autocast

 from .metrics import bbox_iou, probiou
-from .tal import bbox2dist
+from .tal import bbox2dist, rbox2dist


 class VarifocalLoss(nn.Module):
@@ -122,6 +123,8 @@ class BboxLoss(nn.Module):
         target_scores: torch.Tensor,
         target_scores_sum: torch.Tensor,
         fg_mask: torch.Tensor,
+        imgsz: torch.Tensor,
+        stride: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         """Compute IoU and DFL losses for bounding boxes."""
         weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
@@ -134,11 +137,76 @@ class BboxLoss(nn.Module):
             loss_dfl = self.dfl_loss(pred_dist[fg_mask].view(-1, self.dfl_loss.reg_max), target_ltrb[fg_mask]) * weight
             loss_dfl = loss_dfl.sum() / target_scores_sum
         else:
-            loss_dfl = torch.tensor(0.0).to(pred_dist.device)
+            target_ltrb = bbox2dist(anchor_points, target_bboxes)
+            # normalize ltrb by image size
+            target_ltrb = target_ltrb * stride
+            target_ltrb[..., 0::2] /= imgsz[1]
+            target_ltrb[..., 1::2] /= imgsz[0]
+            pred_dist = pred_dist * stride
+            pred_dist[..., 0::2] /= imgsz[1]
+            pred_dist[..., 1::2] /= imgsz[0]
+            loss_dfl = (
+                F.l1_loss(pred_dist[fg_mask], target_ltrb[fg_mask], reduction="none").mean(-1, keepdim=True) * weight
+            )
+            loss_dfl = loss_dfl.sum() / target_scores_sum

         return loss_iou, loss_dfl


+class RLELoss(nn.Module):
+    """Residual Log-Likelihood Estimation Loss.
+
+    Args:
+        use_target_weight (bool): Option to use weighted loss.
+        size_average (bool): Option to average the loss by the batch_size.
+        residual (bool): Option to add L1 loss and let the flow learn the residual error distribution.
+
+    References:
+        https://arxiv.org/abs/2107.11291
+        https://github.com/open-mmlab/mmpose/blob/main/mmpose/models/losses/regression_loss.py
+    """
+
+    def __init__(self, use_target_weight: bool = True, size_average: bool = True, residual: bool = True):
+        """Initialize RLELoss with target weight and residual options.
+
+        Args:
+            use_target_weight (bool): Whether to use target weights for loss calculation.
+            size_average (bool): Whether to average the loss over elements.
+            residual (bool): Whether to include residual log-likelihood term.
+        """
+        super().__init__()
+        self.size_average = size_average
+        self.use_target_weight = use_target_weight
+        self.residual = residual
+
+    def forward(
+        self, sigma: torch.Tensor, log_phi: torch.Tensor, error: torch.Tensor, target_weight: torch.Tensor = None
+    ) -> torch.Tensor:
+        """
+        Args:
+            sigma (torch.Tensor): Output sigma, shape (N, D).
+            log_phi (torch.Tensor): Output log_phi, shape (N).
+            error (torch.Tensor): Error, shape (N, D).
+            target_weight (torch.Tensor): Weights across different joint types, shape (N).
+        """
+        log_sigma = torch.log(sigma)
+        loss = log_sigma - log_phi.unsqueeze(1)
+
+        if self.residual:
+            loss += torch.log(sigma * 2) + torch.abs(error)
+
+        if self.use_target_weight:
+            assert target_weight is not None, "'target_weight' should not be None when 'use_target_weight' is True."
+            if target_weight.dim() == 1:
+                target_weight = target_weight.unsqueeze(1)
+            loss *= target_weight
+
+        if self.size_average:
+            loss /= len(loss)
+
+        return loss.sum()
+
+
 class RotatedBboxLoss(BboxLoss):
     """Criterion class for computing training losses for rotated bounding boxes."""

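For orientation, here is a minimal, self-contained sketch of how the new `RLELoss` can be exercised on its own. The unit-Gaussian `log_phi` below is only a stand-in for the flow density (in training, `PoseLoss26` obtains it from the head's `flow_model.log_prob`); shapes and values are illustrative.

```python
import torch

from ultralytics.utils.loss import RLELoss  # class added in this release

N, D = 8, 2                                   # 8 visible keypoints, 2 coordinates each
sigma = torch.rand(N, D) * 0.5 + 0.1          # predicted scale, must stay positive
error = torch.randn(N, D)                     # (pred - gt) / sigma residuals
# Stand-in for the flow log-density; the real loss uses flow_model.log_prob(error)
log_phi = torch.distributions.Normal(0.0, 1.0).log_prob(error).sum(-1)  # shape (N,)
target_weight = torch.ones(N)                 # per-keypoint weights

criterion = RLELoss(use_target_weight=True, size_average=True, residual=True)
print(float(criterion(sigma, log_phi, error, target_weight)))  # scalar loss
```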
@@ -155,6 +223,8 @@ class RotatedBboxLoss(BboxLoss):
         target_scores: torch.Tensor,
         target_scores_sum: torch.Tensor,
         fg_mask: torch.Tensor,
+        imgsz: torch.Tensor,
+        stride: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         """Compute IoU and DFL losses for rotated bounding boxes."""
         weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
@@ -163,15 +233,84 @@ class RotatedBboxLoss(BboxLoss):

         # DFL loss
         if self.dfl_loss:
-            target_ltrb = bbox2dist(anchor_points, xywh2xyxy(target_bboxes[..., :4]), self.dfl_loss.reg_max - 1)
+            target_ltrb = rbox2dist(
+                target_bboxes[..., :4], anchor_points, target_bboxes[..., 4:5], reg_max=self.dfl_loss.reg_max - 1
+            )
             loss_dfl = self.dfl_loss(pred_dist[fg_mask].view(-1, self.dfl_loss.reg_max), target_ltrb[fg_mask]) * weight
             loss_dfl = loss_dfl.sum() / target_scores_sum
         else:
-            loss_dfl = torch.tensor(0.0).to(pred_dist.device)
+            target_ltrb = rbox2dist(target_bboxes[..., :4], anchor_points, target_bboxes[..., 4:5])
+            target_ltrb = target_ltrb * stride
+            target_ltrb[..., 0::2] /= imgsz[1]
+            target_ltrb[..., 1::2] /= imgsz[0]
+            pred_dist = pred_dist * stride
+            pred_dist[..., 0::2] /= imgsz[1]
+            pred_dist[..., 1::2] /= imgsz[0]
+            loss_dfl = (
+                F.l1_loss(pred_dist[fg_mask], target_ltrb[fg_mask], reduction="none").mean(-1, keepdim=True) * weight
+            )
+            loss_dfl = loss_dfl.sum() / target_scores_sum

         return loss_iou, loss_dfl


+class MultiChannelDiceLoss(nn.Module):
+    """Criterion class for computing multi-channel Dice losses."""
+
+    def __init__(self, smooth: float = 1e-6, reduction: str = "mean"):
+        """Initialize MultiChannelDiceLoss with smoothing and reduction options.
+
+        Args:
+            smooth (float): Smoothing factor to avoid division by zero.
+            reduction (str): Reduction method ('mean', 'sum', or 'none').
+        """
+        super().__init__()
+        self.smooth = smooth
+        self.reduction = reduction
+
+    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+        """Calculate multi-channel Dice loss between predictions and targets."""
+        assert pred.size() == target.size(), "the size of predict and target must be equal."
+
+        pred = pred.sigmoid()
+        intersection = (pred * target).sum(dim=(2, 3))
+        union = pred.sum(dim=(2, 3)) + target.sum(dim=(2, 3))
+        dice = (2.0 * intersection + self.smooth) / (union + self.smooth)
+        dice_loss = 1.0 - dice
+        dice_loss = dice_loss.mean(dim=1)
+
+        if self.reduction == "mean":
+            return dice_loss.mean()
+        elif self.reduction == "sum":
+            return dice_loss.sum()
+        else:
+            return dice_loss
+
+
+class BCEDiceLoss(nn.Module):
+    """Criterion class for computing combined BCE and Dice losses."""
+
+    def __init__(self, weight_bce: float = 0.5, weight_dice: float = 0.5):
+        """Initialize BCEDiceLoss with BCE and Dice weight factors.
+
+        Args:
+            weight_bce (float): Weight factor for BCE loss component.
+            weight_dice (float): Weight factor for Dice loss component.
+        """
+        super().__init__()
+        self.weight_bce = weight_bce
+        self.weight_dice = weight_dice
+        self.bce = nn.BCEWithLogitsLoss()
+        self.dice = MultiChannelDiceLoss(smooth=1)
+
+    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+        """Calculate combined BCE and Dice loss between predictions and targets."""
+        _, _, mask_h, mask_w = pred.shape
+        if tuple(target.shape[-2:]) != (mask_h, mask_w):  # downsample to the same size as pred
+            target = F.interpolate(target, (mask_h, mask_w), mode="nearest")
+        return self.weight_bce * self.bce(pred, target) + self.weight_dice * self.dice(pred, target)
+
+
 class KeypointLoss(nn.Module):
     """Criterion class for computing keypoint losses."""

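A short, illustrative usage of the new `BCEDiceLoss` (tensor shapes are hypothetical; in training, `v8SegmentationLoss.loss` feeds it semantic-segmentation logits and the one-hot masks it builds from `batch["sem_masks"]`, as shown further below):

```python
import torch

from ultralytics.utils.loss import BCEDiceLoss  # class added in this release

pred = torch.randn(2, 3, 64, 64)                     # raw per-class logits (N, C, H, W)
target = (torch.rand(2, 3, 128, 128) > 0.5).float()  # one-hot masks; resized to pred's size inside forward()

criterion = BCEDiceLoss(weight_bce=0.5, weight_dice=0.5)
print(float(criterion(pred, target)))  # 0.5 * BCE-with-logits + 0.5 * multi-channel Dice
```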
@@ -194,7 +333,7 @@ class KeypointLoss(nn.Module):
 class v8DetectionLoss:
     """Criterion class for computing training losses for YOLOv8 object detection."""

-    def __init__(self, model, tal_topk: int = 10):  # model must be de-paralleled
+    def __init__(self, model, tal_topk: int = 10, tal_topk2: int | None = None):  # model must be de-paralleled
         """Initialize v8DetectionLoss with model parameters and task-aligned assignment settings."""
         device = next(model.parameters()).device  # get model device
         h = model.args  # hyperparameters
@@ -210,7 +349,14 @@ class v8DetectionLoss:

         self.use_dfl = m.reg_max > 1

-        self.assigner = TaskAlignedAssigner(topk=tal_topk, num_classes=self.nc, alpha=0.5, beta=6.0)
+        self.assigner = TaskAlignedAssigner(
+            topk=tal_topk,
+            num_classes=self.nc,
+            alpha=0.5,
+            beta=6.0,
+            stride=self.stride.tolist(),
+            topk2=tal_topk2,
+        )
         self.bbox_loss = BboxLoss(m.reg_max).to(device)
         self.proj = torch.arange(m.reg_max, dtype=torch.float, device=device)

@@ -240,35 +386,31 @@ class v8DetectionLoss:
         # pred_dist = (pred_dist.view(b, a, c // 4, 4).softmax(2) * self.proj.type(pred_dist.dtype).view(1, 1, -1, 1)).sum(2)
         return dist2bbox(pred_dist, anchor_points, xywh=False)

-    def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
-        """Calculate the sum of the loss for box, cls and dfl multiplied by batch size."""
+    def get_assigned_targets_and_loss(self, preds: dict[str, torch.Tensor], batch: dict[str, Any]) -> tuple:
+        """Calculate the sum of the loss for box, cls and dfl multiplied by batch size and return foreground mask and
+        target indices.
+        """
         loss = torch.zeros(3, device=self.device)  # box, cls, dfl
-        feats = preds[1] if isinstance(preds, tuple) else preds
-        pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
-            (self.reg_max * 4, self.nc), 1
+        pred_distri, pred_scores = (
+            preds["boxes"].permute(0, 2, 1).contiguous(),
+            preds["scores"].permute(0, 2, 1).contiguous(),
         )
-
-        pred_scores = pred_scores.permute(0, 2, 1).contiguous()
-        pred_distri = pred_distri.permute(0, 2, 1).contiguous()
+        anchor_points, stride_tensor = make_anchors(preds["feats"], self.stride, 0.5)

         dtype = pred_scores.dtype
         batch_size = pred_scores.shape[0]
-        imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]
-        anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
+        imgsz = torch.tensor(preds["feats"][0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]

         # Targets
         targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
-        targets = self.preprocess(targets, batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
+        targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
         gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
         mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)

         # Pboxes
         pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)
-        # dfl_conf = pred_distri.view(batch_size, -1, 4, self.reg_max).detach().softmax(-1)
-        # dfl_conf = (dfl_conf.amax(-1).mean(-1) + dfl_conf.amax(-1).amin(-1)) / 2

-        _, target_bboxes, target_scores, fg_mask, _ = self.assigner(
-            # pred_scores.detach().sigmoid() * 0.8 + dfl_conf.unsqueeze(-1) * 0.2,
+        _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
             pred_scores.detach().sigmoid(),
             (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
             anchor_points * stride_tensor,
@@ -280,7 +422,6 @@ class v8DetectionLoss:
         target_scores_sum = max(target_scores.sum(), 1)

         # Cls loss
-        # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum  # VFL way
         loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE

         # Bbox loss
@@ -293,105 +434,108 @@ class v8DetectionLoss:
             target_scores,
             target_scores_sum,
             fg_mask,
+            imgsz,
+            stride_tensor,
         )

         loss[0] *= self.hyp.box  # box gain
         loss[1] *= self.hyp.cls  # cls gain
         loss[2] *= self.hyp.dfl  # dfl gain
+        return (
+            (fg_mask, target_gt_idx, target_bboxes, anchor_points, stride_tensor),
+            loss,
+            loss.detach(),
+        )  # loss(box, cls, dfl)

-        return loss * batch_size, loss.detach()  # loss(box, cls, dfl)
+    def parse_output(
+        self, preds: dict[str, torch.Tensor] | tuple[torch.Tensor, dict[str, torch.Tensor]]
+    ) -> torch.Tensor:
+        """Parse model predictions to extract features."""
+        return preds[1] if isinstance(preds, tuple) else preds
+
+    def __call__(
+        self,
+        preds: dict[str, torch.Tensor] | tuple[torch.Tensor, dict[str, torch.Tensor]],
+        batch: dict[str, torch.Tensor],
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Calculate the sum of the loss for box, cls and dfl multiplied by batch size."""
+        return self.loss(self.parse_output(preds), batch)
+
+    def loss(self, preds: dict[str, torch.Tensor], batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+        """A wrapper for get_assigned_targets_and_loss and parse_output."""
+        batch_size = preds["boxes"].shape[0]
+        loss, loss_detach = self.get_assigned_targets_and_loss(preds, batch)[1:]
+        return loss * batch_size, loss_detach


 class v8SegmentationLoss(v8DetectionLoss):
     """Criterion class for computing training losses for YOLOv8 segmentation."""

-    def __init__(self, model):  # model must be de-paralleled
+    def __init__(self, model, tal_topk: int = 10, tal_topk2: int | None = None):  # model must be de-paralleled
         """Initialize the v8SegmentationLoss class with model parameters and mask overlap setting."""
-        super().__init__(model)
+        super().__init__(model, tal_topk, tal_topk2)
         self.overlap = model.args.overlap_mask
+        self.bcedice_loss = BCEDiceLoss(weight_bce=0.5, weight_dice=0.5)

-    def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+    def loss(self, preds: dict[str, torch.Tensor], batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
         """Calculate and return the combined loss for detection and segmentation."""
-        loss = torch.zeros(4, device=self.device)  # box, seg, cls, dfl
-        feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
-        batch_size, _, mask_h, mask_w = proto.shape  # batch size, number of masks, mask height, mask width
-        pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
-            (self.reg_max * 4, self.nc), 1
-        )
-
-        # B, grids, ..
-        pred_scores = pred_scores.permute(0, 2, 1).contiguous()
-        pred_distri = pred_distri.permute(0, 2, 1).contiguous()
-        pred_masks = pred_masks.permute(0, 2, 1).contiguous()
-
-        dtype = pred_scores.dtype
-        imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]  # image size (h,w)
-        anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
-
-        # Targets
-        try:
-            batch_idx = batch["batch_idx"].view(-1, 1)
-            targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)
-            targets = self.preprocess(targets, batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
-            gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
-            mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
-        except RuntimeError as e:
-            raise TypeError(
-                "ERROR ❌ segment dataset incorrectly formatted or not a segment dataset.\n"
-                "This error can occur when incorrectly training a 'segment' model on a 'detect' dataset, "
-                "i.e. 'yolo train model=yolo11n-seg.pt data=coco8.yaml'.\nVerify your dataset is a "
-                "correctly formatted 'segment' dataset using 'data=coco8-seg.yaml' "
-                "as an example.\nSee https://docs.ultralytics.com/datasets/segment/ for help."
-            ) from e
-
-        # Pboxes
-        pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)
-
-        _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
-            pred_scores.detach().sigmoid(),
-            (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
-            anchor_points * stride_tensor,
-            gt_labels,
-            gt_bboxes,
-            mask_gt,
-        )
-
-        target_scores_sum = max(target_scores.sum(), 1)
-
-        # Cls loss
-        # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum  # VFL way
-        loss[2] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE
+        pred_masks, proto = preds["mask_coefficient"].permute(0, 2, 1).contiguous(), preds["proto"]
+        loss = torch.zeros(5, device=self.device)  # box, seg, cls, dfl
+        if isinstance(proto, tuple) and len(proto) == 2:
+            proto, pred_semseg = proto
+        else:
+            pred_semseg = None
+        (fg_mask, target_gt_idx, target_bboxes, _, _), det_loss, _ = self.get_assigned_targets_and_loss(preds, batch)
+        # NOTE: re-assign index for consistency for now. Need to be removed in the future.
+        loss[0], loss[2], loss[3] = det_loss[0], det_loss[1], det_loss[2]

+        batch_size, _, mask_h, mask_w = proto.shape  # batch size, number of masks, mask height, mask width
         if fg_mask.sum():
-            # Bbox loss
-            loss[0], loss[3] = self.bbox_loss(
-                pred_distri,
-                pred_bboxes,
-                anchor_points,
-                target_bboxes / stride_tensor,
-                target_scores,
-                target_scores_sum,
-                fg_mask,
-            )
             # Masks loss
             masks = batch["masks"].to(self.device).float()
             if tuple(masks.shape[-2:]) != (mask_h, mask_w):  # downsample
-                masks = F.interpolate(masks[None], (mask_h, mask_w), mode="nearest")[0]
+                # masks = F.interpolate(masks[None], (mask_h, mask_w), mode="nearest")[0]
+                proto = F.interpolate(proto, masks.shape[-2:], mode="bilinear", align_corners=False)

+            imgsz = (
+                torch.tensor(preds["feats"][0].shape[2:], device=self.device, dtype=pred_masks.dtype) * self.stride[0]
+            )
             loss[1] = self.calculate_segmentation_loss(
-                fg_mask, masks, target_gt_idx, target_bboxes, batch_idx, proto, pred_masks, imgsz, self.overlap
+                fg_mask,
+                masks,
+                target_gt_idx,
+                target_bboxes,
+                batch["batch_idx"].view(-1, 1),
+                proto,
+                pred_masks,
+                imgsz,
             )
+            if pred_semseg is not None:
+                sem_masks = batch["sem_masks"].to(self.device)  # NxHxW
+                sem_masks = F.one_hot(sem_masks.long(), num_classes=self.nc).permute(0, 3, 1, 2).float()  # NxCxHxW
+
+                if self.overlap:
+                    mask_zero = masks == 0  # NxHxW
+                    sem_masks[mask_zero.unsqueeze(1).expand_as(sem_masks)] = 0
+                else:
+                    batch_idx = batch["batch_idx"].view(-1)  # [total_instances]
+                    for i in range(batch_size):
+                        instance_mask_i = masks[batch_idx == i]  # [num_instances_i, H, W]
+                        if len(instance_mask_i) == 0:
+                            continue
+                        sem_masks[i, :, instance_mask_i.sum(dim=0) == 0] = 0
+
+                loss[4] = self.bcedice_loss(pred_semseg, sem_masks)
+                loss[4] *= self.hyp.box  # seg gain

         # WARNING: lines below prevent Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
         else:
             loss[1] += (proto * 0).sum() + (pred_masks * 0).sum()  # inf sums may lead to nan loss
+            if pred_semseg is not None:
+                loss[4] += (pred_semseg * 0).sum()

-        loss[0] *= self.hyp.box  # box gain
         loss[1] *= self.hyp.box  # seg gain
-        loss[2] *= self.hyp.cls  # cls gain
-        loss[3] *= self.hyp.dfl  # dfl gain
-
-        return loss * batch_size, loss.detach()  # loss(box, seg, cls, dfl)
+        return loss * batch_size, loss.detach()  # loss(box, cls, dfl)

     @staticmethod
     def single_mask_loss(
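The refactor above switches the detection-family losses from consuming a list of raw feature maps to consuming a dict of named head outputs, with `parse_output()`/`loss()`/`get_assigned_targets_and_loss()` replacing the monolithic `__call__`. A sketch of the expected keys; the key names are read from the diff, while the tensor sizes below are hypothetical (640x640 input, reg_max=16, nc=80):

```python
import torch

# Illustrative only: key names come from the code above, shapes are made up.
preds = {
    "feats": [torch.zeros(2, 128, 80, 80), torch.zeros(2, 256, 40, 40), torch.zeros(2, 512, 20, 20)],
    "boxes": torch.zeros(2, 64, 8400),   # reg_max * 4 distribution logits per anchor (80^2 + 40^2 + 20^2 = 8400)
    "scores": torch.zeros(2, 80, 8400),  # class logits per anchor
}
# Task-specific losses expect extra keys: "mask_coefficient"/"proto" (segment),
# "kpts" and optionally "kpts_sigma" (pose), "angle" (OBB), "one2many"/"one2one" (end-to-end).
```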
@@ -427,7 +571,6 @@ class v8SegmentationLoss(v8DetectionLoss):
         proto: torch.Tensor,
         pred_masks: torch.Tensor,
         imgsz: torch.Tensor,
-        overlap: bool,
     ) -> torch.Tensor:
         """Calculate the loss for instance segmentation.

@@ -440,7 +583,6 @@ class v8SegmentationLoss(v8DetectionLoss):
             proto (torch.Tensor): Prototype masks of shape (BS, 32, H, W).
             pred_masks (torch.Tensor): Predicted masks for each anchor of shape (BS, N_anchors, 32).
             imgsz (torch.Tensor): Size of the input image as a tensor of shape (2), i.e., (H, W).
-            overlap (bool): Whether the masks in `masks` tensor overlap.

         Returns:
             (torch.Tensor): The calculated loss for instance segmentation.
@@ -466,7 +608,7 @@ class v8SegmentationLoss(v8DetectionLoss):
             fg_mask_i, target_gt_idx_i, pred_masks_i, proto_i, mxyxy_i, marea_i, masks_i = single_i
             if fg_mask_i.any():
                 mask_idx = target_gt_idx_i[fg_mask_i]
-                if overlap:
+                if self.overlap:
                     gt_mask = masks_i == (mask_idx + 1).view(-1, 1, 1)
                     gt_mask = gt_mask.float()
                 else:
@@ -486,9 +628,9 @@ class v8SegmentationLoss(v8DetectionLoss):
 class v8PoseLoss(v8DetectionLoss):
     """Criterion class for computing training losses for YOLOv8 pose estimation."""

-    def __init__(self, model):  # model must be de-paralleled
+    def __init__(self, model, tal_topk: int = 10, tal_topk2: int = 10):  # model must be de-paralleled
         """Initialize v8PoseLoss with model parameters and keypoint-specific loss functions."""
-        super().__init__(model)
+        super().__init__(model, tal_topk, tal_topk2)
         self.kpt_shape = model.model[-1].kpt_shape
         self.bce_pose = nn.BCEWithLogitsLoss()
         is_pose = self.kpt_shape == [17, 3]
@@ -496,69 +638,40 @@ class v8PoseLoss(v8DetectionLoss):
         sigmas = torch.from_numpy(OKS_SIGMA).to(self.device) if is_pose else torch.ones(nkpt, device=self.device) / nkpt
         self.keypoint_loss = KeypointLoss(sigmas=sigmas)

-    def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+    def loss(self, preds: dict[str, torch.Tensor], batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
         """Calculate the total loss and detach it for pose estimation."""
-        loss = torch.zeros(5, device=self.device)  # box, cls, dfl, kpt_location, kpt_visibility
-        feats, pred_kpts = preds if isinstance(preds[0], list) else preds[1]
-        pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
-            (self.reg_max * 4, self.nc), 1
+        pred_kpts = preds["kpts"].permute(0, 2, 1).contiguous()
+        loss = torch.zeros(5, device=self.device)  # box, cls, dfl, kpt_location, kpt_visibility
+        (fg_mask, target_gt_idx, target_bboxes, anchor_points, stride_tensor), det_loss, _ = (
+            self.get_assigned_targets_and_loss(preds, batch)
         )
+        # NOTE: re-assign index for consistency for now. Need to be removed in the future.
+        loss[0], loss[3], loss[4] = det_loss[0], det_loss[1], det_loss[2]

-        # B, grids, ..
-        pred_scores = pred_scores.permute(0, 2, 1).contiguous()
-        pred_distri = pred_distri.permute(0, 2, 1).contiguous()
-        pred_kpts = pred_kpts.permute(0, 2, 1).contiguous()
-
-        dtype = pred_scores.dtype
-        imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]  # image size (h,w)
-        anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
-
-        # Targets
-        batch_size = pred_scores.shape[0]
-        batch_idx = batch["batch_idx"].view(-1, 1)
-        targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)
-        targets = self.preprocess(targets, batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
-        gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
-        mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
+        batch_size = pred_kpts.shape[0]
+        imgsz = torch.tensor(preds["feats"][0].shape[2:], device=self.device, dtype=pred_kpts.dtype) * self.stride[0]

         # Pboxes
-        pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)
         pred_kpts = self.kpts_decode(anchor_points, pred_kpts.view(batch_size, -1, *self.kpt_shape))  # (b, h*w, 17, 3)

-        _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
-            pred_scores.detach().sigmoid(),
-            (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
-            anchor_points * stride_tensor,
-            gt_labels,
-            gt_bboxes,
-            mask_gt,
-        )
-
-        target_scores_sum = max(target_scores.sum(), 1)
-
-        # Cls loss
-        # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum  # VFL way
-        loss[3] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE
-
         # Bbox loss
         if fg_mask.sum():
-            target_bboxes /= stride_tensor
-            loss[0], loss[4] = self.bbox_loss(
-                pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
-            )
             keypoints = batch["keypoints"].to(self.device).float().clone()
             keypoints[..., 0] *= imgsz[1]
             keypoints[..., 1] *= imgsz[0]

             loss[1], loss[2] = self.calculate_keypoints_loss(
-                fg_mask, target_gt_idx, keypoints, batch_idx, stride_tensor, target_bboxes, pred_kpts
+                fg_mask,
+                target_gt_idx,
+                keypoints,
+                batch["batch_idx"].view(-1, 1),
+                stride_tensor,
+                target_bboxes,
+                pred_kpts,
             )

-        loss[0] *= self.hyp.box  # box gain
         loss[1] *= self.hyp.pose  # pose gain
         loss[2] *= self.hyp.kobj  # kobj gain
-        loss[3] *= self.hyp.cls  # cls gain
-        loss[4] *= self.hyp.dfl  # dfl gain

         return loss * batch_size, loss.detach()  # loss(box, pose, kobj, cls, dfl)

@@ -571,34 +684,23 @@ class v8PoseLoss(v8DetectionLoss):
         y[..., 1] += anchor_points[:, [1]] - 0.5
         return y

-    def calculate_keypoints_loss(
+    def _select_target_keypoints(
         self,
-        masks: torch.Tensor,
-        target_gt_idx: torch.Tensor,
         keypoints: torch.Tensor,
         batch_idx: torch.Tensor,
-        stride_tensor: torch.Tensor,
-        target_bboxes: torch.Tensor,
-        pred_kpts: torch.Tensor,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
-        """Calculate the keypoints loss for the model.
-
-        This function calculates the keypoints loss and keypoints object loss for a given batch. The keypoints loss is
-        based on the difference between the predicted keypoints and ground truth keypoints. The keypoints object loss is
-        a binary classification loss that classifies whether a keypoint is present or not.
+        target_gt_idx: torch.Tensor,
+        masks: torch.Tensor,
+    ) -> torch.Tensor:
+        """Select target keypoints for each anchor based on batch index and target ground truth index.

         Args:
-            masks (torch.Tensor): Binary mask tensor indicating object presence, shape (BS, N_anchors).
-            target_gt_idx (torch.Tensor): Index tensor mapping anchors to ground truth objects, shape (BS, N_anchors).
             keypoints (torch.Tensor): Ground truth keypoints, shape (N_kpts_in_batch, N_kpts_per_object, kpts_dim).
             batch_idx (torch.Tensor): Batch index tensor for keypoints, shape (N_kpts_in_batch, 1).
-            stride_tensor (torch.Tensor): Stride tensor for anchors, shape (N_anchors, 1).
-            target_bboxes (torch.Tensor): Ground truth boxes in (x1, y1, x2, y2) format, shape (BS, N_anchors, 4).
-            pred_kpts (torch.Tensor): Predicted keypoints, shape (BS, N_anchors, N_kpts_per_object, kpts_dim).
+            target_gt_idx (torch.Tensor): Index tensor mapping anchors to ground truth objects, shape (BS, N_anchors).
+            masks (torch.Tensor): Binary mask tensor indicating object presence, shape (BS, N_anchors).

         Returns:
-            kpts_loss (torch.Tensor): The keypoints loss.
-            kpts_obj_loss (torch.Tensor): The keypoints object loss.
+            (torch.Tensor): Selected keypoints tensor, shape (BS, N_anchors, N_kpts_per_object, kpts_dim).
         """
         batch_idx = batch_idx.flatten()
         batch_size = len(masks)
@@ -625,6 +727,40 @@ class v8PoseLoss(v8DetectionLoss):
             1, target_gt_idx_expanded.expand(-1, -1, keypoints.shape[1], keypoints.shape[2])
         )

+        return selected_keypoints
+
+    def calculate_keypoints_loss(
+        self,
+        masks: torch.Tensor,
+        target_gt_idx: torch.Tensor,
+        keypoints: torch.Tensor,
+        batch_idx: torch.Tensor,
+        stride_tensor: torch.Tensor,
+        target_bboxes: torch.Tensor,
+        pred_kpts: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Calculate the keypoints loss for the model.
+
+        This function calculates the keypoints loss and keypoints object loss for a given batch. The keypoints loss is
+        based on the difference between the predicted keypoints and ground truth keypoints. The keypoints object loss is
+        a binary classification loss that classifies whether a keypoint is present or not.
+
+        Args:
+            masks (torch.Tensor): Binary mask tensor indicating object presence, shape (BS, N_anchors).
+            target_gt_idx (torch.Tensor): Index tensor mapping anchors to ground truth objects, shape (BS, N_anchors).
+            keypoints (torch.Tensor): Ground truth keypoints, shape (N_kpts_in_batch, N_kpts_per_object, kpts_dim).
+            batch_idx (torch.Tensor): Batch index tensor for keypoints, shape (N_kpts_in_batch, 1).
+            stride_tensor (torch.Tensor): Stride tensor for anchors, shape (N_anchors, 1).
+            target_bboxes (torch.Tensor): Ground truth boxes in (x1, y1, x2, y2) format, shape (BS, N_anchors, 4).
+            pred_kpts (torch.Tensor): Predicted keypoints, shape (BS, N_anchors, N_kpts_per_object, kpts_dim).
+
+        Returns:
+            kpts_loss (torch.Tensor): The keypoints loss.
+            kpts_obj_loss (torch.Tensor): The keypoints object loss.
+        """
+        # Select target keypoints using helper method
+        selected_keypoints = self._select_target_keypoints(keypoints, batch_idx, target_gt_idx, masks)
+
         # Divide coordinates by stride
         selected_keypoints[..., :2] /= stride_tensor.view(1, -1, 1, 1)

@@ -632,6 +768,7 @@ class v8PoseLoss(v8DetectionLoss):
         kpts_obj_loss = 0

         if masks.any():
+            target_bboxes /= stride_tensor
             gt_kpt = selected_keypoints[masks]
             area = xyxy2xywh(target_bboxes[masks])[:, 2:].prod(1, keepdim=True)
             pred_kpt = pred_kpts[masks]
@@ -644,6 +781,172 @@ class v8PoseLoss(v8DetectionLoss):
         return kpts_loss, kpts_obj_loss


+class PoseLoss26(v8PoseLoss):
+    """Criterion class for computing training losses for YOLOv8 pose estimation with RLE loss support."""
+
+    def __init__(self, model, tal_topk: int = 10, tal_topk2: int | None = None):  # model must be de-paralleled
+        """Initialize PoseLoss26 with model parameters and keypoint-specific loss functions including RLE loss."""
+        super().__init__(model, tal_topk, tal_topk2)
+        is_pose = self.kpt_shape == [17, 3]
+        nkpt = self.kpt_shape[0]  # number of keypoints
+        self.rle_loss = None
+        self.flow_model = model.model[-1].flow_model if hasattr(model.model[-1], "flow_model") else None
+        if self.flow_model is not None:
+            self.rle_loss = RLELoss(use_target_weight=True).to(self.device)
+            self.target_weights = (
+                torch.from_numpy(RLE_WEIGHT).to(self.device) if is_pose else torch.ones(nkpt, device=self.device)
+            )
+
+    def loss(self, preds: dict[str, torch.Tensor], batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+        """Calculate the total loss and detach it for pose estimation."""
+        pred_kpts = preds["kpts"].permute(0, 2, 1).contiguous()
+        loss = torch.zeros(6 if self.rle_loss else 5, device=self.device)  # box, cls, dfl, kpt_location, kpt_visibility
+        (fg_mask, target_gt_idx, target_bboxes, anchor_points, stride_tensor), det_loss, _ = (
+            self.get_assigned_targets_and_loss(preds, batch)
+        )
+        # NOTE: re-assign index for consistency for now. Need to be removed in the future.
+        loss[0], loss[3], loss[4] = det_loss[0], det_loss[1], det_loss[2]
+
+        batch_size = pred_kpts.shape[0]
+        imgsz = torch.tensor(preds["feats"][0].shape[2:], device=self.device, dtype=pred_kpts.dtype) * self.stride[0]
+
+        pred_kpts = pred_kpts.view(batch_size, -1, *self.kpt_shape)  # (b, h*w, 17, 3)
+
+        if self.rle_loss and preds.get("kpts_sigma", None) is not None:
+            pred_sigma = preds["kpts_sigma"].permute(0, 2, 1).contiguous()
+            pred_sigma = pred_sigma.view(batch_size, -1, self.kpt_shape[0], 2)  # (b, h*w, 17, 2)
+            pred_kpts = torch.cat([pred_kpts, pred_sigma], dim=-1)  # (b, h*w, 17, 5)
+
+        pred_kpts = self.kpts_decode(anchor_points, pred_kpts)
+
+        # Bbox loss
+        if fg_mask.sum():
+            keypoints = batch["keypoints"].to(self.device).float().clone()
+            keypoints[..., 0] *= imgsz[1]
+            keypoints[..., 1] *= imgsz[0]
+
+            keypoints_loss = self.calculate_keypoints_loss(
+                fg_mask,
+                target_gt_idx,
+                keypoints,
+                batch["batch_idx"].view(-1, 1),
+                stride_tensor,
+                target_bboxes,
+                pred_kpts,
+            )
+            loss[1] = keypoints_loss[0]
+            loss[2] = keypoints_loss[1]
+            if self.rle_loss is not None:
+                loss[5] = keypoints_loss[2]
+
+        loss[1] *= self.hyp.pose  # pose gain
+        loss[2] *= self.hyp.kobj  # kobj gain
+        if self.rle_loss is not None:
+            loss[5] *= self.hyp.rle  # rle gain
+
+        return loss * batch_size, loss.detach()  # loss(box, cls, dfl, kpt_location, kpt_visibility)
+
+    @staticmethod
+    def kpts_decode(anchor_points: torch.Tensor, pred_kpts: torch.Tensor) -> torch.Tensor:
+        """Decode predicted keypoints to image coordinates."""
+        y = pred_kpts.clone()
+        y[..., 0] += anchor_points[:, [0]]
+        y[..., 1] += anchor_points[:, [1]]
+        return y
+
+    def calculate_rle_loss(self, pred_kpt: torch.Tensor, gt_kpt: torch.Tensor, kpt_mask: torch.Tensor) -> torch.Tensor:
+        """Calculate the RLE (Residual Log-likelihood Estimation) loss for keypoints.
+
+        Args:
+            pred_kpt (torch.Tensor): Predicted keypoints with sigma, shape (N, kpts_dim) where kpts_dim >= 4.
+            gt_kpt (torch.Tensor): Ground truth keypoints, shape (N, kpts_dim).
+            kpt_mask (torch.Tensor): Mask for valid keypoints, shape (N, num_keypoints).
+
+        Returns:
+            (torch.Tensor): The RLE loss.
+        """
+        pred_kpt_visible = pred_kpt[kpt_mask]
+        gt_kpt_visible = gt_kpt[kpt_mask]
+        pred_coords = pred_kpt_visible[:, 0:2]
+        pred_sigma = pred_kpt_visible[:, -2:]
+        gt_coords = gt_kpt_visible[:, 0:2]
+
+        target_weights = self.target_weights.unsqueeze(0).repeat(kpt_mask.shape[0], 1)
+        target_weights = target_weights[kpt_mask]
+
+        pred_sigma = pred_sigma.sigmoid()
+        error = (pred_coords - gt_coords) / (pred_sigma + 1e-9)
+
+        # Filter out NaN and Inf values to prevent MultivariateNormal validation errors
+        valid_mask = ~(torch.isnan(error) | torch.isinf(error)).any(dim=-1)
+        if not valid_mask.any():
+            return torch.tensor(0.0, device=pred_kpt.device)
+
+        error = error[valid_mask]
+        error = error.clamp(-100, 100)  # Prevent numerical instability
+        pred_sigma = pred_sigma[valid_mask]
+        target_weights = target_weights[valid_mask]
+
+        log_phi = self.flow_model.log_prob(error)
+
+        return self.rle_loss(pred_sigma, log_phi, error, target_weights)
+
+    def calculate_keypoints_loss(
+        self,
+        masks: torch.Tensor,
+        target_gt_idx: torch.Tensor,
+        keypoints: torch.Tensor,
+        batch_idx: torch.Tensor,
+        stride_tensor: torch.Tensor,
+        target_bboxes: torch.Tensor,
+        pred_kpts: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Calculate the keypoints loss for the model.
+
+        This function calculates the keypoints loss and keypoints object loss for a given batch. The keypoints loss is
+        based on the difference between the predicted keypoints and ground truth keypoints. The keypoints object loss is
+        a binary classification loss that classifies whether a keypoint is present or not.
+
+        Args:
+            masks (torch.Tensor): Binary mask tensor indicating object presence, shape (BS, N_anchors).
+            target_gt_idx (torch.Tensor): Index tensor mapping anchors to ground truth objects, shape (BS, N_anchors).
+            keypoints (torch.Tensor): Ground truth keypoints, shape (N_kpts_in_batch, N_kpts_per_object, kpts_dim).
+            batch_idx (torch.Tensor): Batch index tensor for keypoints, shape (N_kpts_in_batch, 1).
+            stride_tensor (torch.Tensor): Stride tensor for anchors, shape (N_anchors, 1).
+            target_bboxes (torch.Tensor): Ground truth boxes in (x1, y1, x2, y2) format, shape (BS, N_anchors, 4).
+            pred_kpts (torch.Tensor): Predicted keypoints, shape (BS, N_anchors, N_kpts_per_object, kpts_dim).
+
+        Returns:
+            kpts_loss (torch.Tensor): The keypoints loss.
+            kpts_obj_loss (torch.Tensor): The keypoints object loss.
+            rle_loss (torch.Tensor): The RLE loss.
+        """
+        # Select target keypoints using inherited helper method
+        selected_keypoints = self._select_target_keypoints(keypoints, batch_idx, target_gt_idx, masks)
+
+        # Divide coordinates by stride
+        selected_keypoints[..., :2] /= stride_tensor.view(1, -1, 1, 1)
+
+        kpts_loss = 0
+        kpts_obj_loss = 0
+        rle_loss = 0
+
+        if masks.any():
+            target_bboxes /= stride_tensor
+            gt_kpt = selected_keypoints[masks]
+            area = xyxy2xywh(target_bboxes[masks])[:, 2:].prod(1, keepdim=True)
+            pred_kpt = pred_kpts[masks]
+            kpt_mask = gt_kpt[..., 2] != 0 if gt_kpt.shape[-1] == 3 else torch.full_like(gt_kpt[..., 0], True)
+            kpts_loss = self.keypoint_loss(pred_kpt, gt_kpt, kpt_mask, area)  # pose loss
+
+            if self.rle_loss is not None and (pred_kpt.shape[-1] == 4 or pred_kpt.shape[-1] == 5):
+                rle_loss = self.calculate_rle_loss(pred_kpt, gt_kpt, kpt_mask)
+            if pred_kpt.shape[-1] == 3 or pred_kpt.shape[-1] == 5:
+                kpts_obj_loss = self.bce_pose(pred_kpt[..., 2], kpt_mask.float())  # keypoint obj loss
+
+        return kpts_loss, kpts_obj_loss, rle_loss
+
+
 class v8ClassificationLoss:
     """Criterion class for computing training losses for classification."""

@@ -657,10 +960,17 @@ class v8ClassificationLoss:
 class v8OBBLoss(v8DetectionLoss):
     """Calculates losses for object detection, classification, and box distribution in rotated YOLO models."""

-    def __init__(self, model):
+    def __init__(self, model, tal_topk=10, tal_topk2: int | None = None):
         """Initialize v8OBBLoss with model, assigner, and rotated bbox loss; model must be de-paralleled."""
-        super().__init__(model)
-        self.assigner = RotatedTaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0)
+        super().__init__(model, tal_topk=tal_topk)
+        self.assigner = RotatedTaskAlignedAssigner(
+            topk=tal_topk,
+            num_classes=self.nc,
+            alpha=0.5,
+            beta=6.0,
+            stride=self.stride.tolist(),
+            topk2=tal_topk2,
+        )
         self.bbox_loss = RotatedBboxLoss(self.reg_max).to(self.device)

     def preprocess(self, targets: torch.Tensor, batch_size: int, scale_tensor: torch.Tensor) -> torch.Tensor:
@@ -680,23 +990,19 @@ class v8OBBLoss(v8DetectionLoss):
                 out[j, :n] = torch.cat([targets[matches, 1:2], bboxes], dim=-1)
         return out

-    def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+    def loss(self, preds: dict[str, torch.Tensor], batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
         """Calculate and return the loss for oriented bounding box detection."""
-        loss = torch.zeros(3, device=self.device)  # box, cls, dfl
-        feats, pred_angle = preds if isinstance(preds[0], list) else preds[1]
-        batch_size = pred_angle.shape[0]  # batch size, number of masks, mask height, mask width
-        pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
-            (self.reg_max * 4, self.nc), 1
+        loss = torch.zeros(4, device=self.device)  # box, cls, dfl, angle
+        pred_distri, pred_scores, pred_angle = (
+            preds["boxes"].permute(0, 2, 1).contiguous(),
+            preds["scores"].permute(0, 2, 1).contiguous(),
+            preds["angle"].permute(0, 2, 1).contiguous(),
         )
-
-        # b, grids, ..
-        pred_scores = pred_scores.permute(0, 2, 1).contiguous()
-        pred_distri = pred_distri.permute(0, 2, 1).contiguous()
-        pred_angle = pred_angle.permute(0, 2, 1).contiguous()
+        anchor_points, stride_tensor = make_anchors(preds["feats"], self.stride, 0.5)
+        batch_size = pred_angle.shape[0]  # batch size, number of masks, mask height, mask width

         dtype = pred_scores.dtype
-        imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]
-        anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
+        imgsz = torch.tensor(preds["feats"][0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]

         # targets
         try:
@@ -704,14 +1010,14 @@ class v8OBBLoss(v8DetectionLoss):
             targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"].view(-1, 5)), 1)
             rw, rh = targets[:, 4] * float(imgsz[1]), targets[:, 5] * float(imgsz[0])
             targets = targets[(rw >= 2) & (rh >= 2)]  # filter rboxes of tiny size to stabilize training
-            targets = self.preprocess(targets, batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
+            targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
             gt_labels, gt_bboxes = targets.split((1, 5), 2)  # cls, xywhr
             mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
         except RuntimeError as e:
             raise TypeError(
                 "ERROR ❌ OBB dataset incorrectly formatted or not a OBB dataset.\n"
                 "This error can occur when incorrectly training a 'OBB' model on a 'detect' dataset, "
-                "i.e. 'yolo train model=yolo11n-obb.pt data=dota8.yaml'.\nVerify your dataset is a "
+                "i.e. 'yolo train model=yolo26n-obb.pt data=dota8.yaml'.\nVerify your dataset is a "
                 "correctly formatted 'OBB' dataset using 'data=dota8.yaml' "
                 "as an example.\nSee https://docs.ultralytics.com/datasets/obb/ for help."
             ) from e
@@ -741,16 +1047,29 @@ class v8OBBLoss(v8DetectionLoss):
         if fg_mask.sum():
             target_bboxes[..., :4] /= stride_tensor
             loss[0], loss[2] = self.bbox_loss(
-                pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
+                pred_distri,
+                pred_bboxes,
+                anchor_points,
+                target_bboxes,
+                target_scores,
+                target_scores_sum,
+                fg_mask,
+                imgsz,
+                stride_tensor,
             )
+            weight = target_scores.sum(-1)[fg_mask]
+            loss[3] = self.calculate_angle_loss(
+                pred_bboxes, target_bboxes, fg_mask, weight, target_scores_sum
+            )  # angle loss
         else:
             loss[0] += (pred_angle * 0).sum()

         loss[0] *= self.hyp.box  # box gain
         loss[1] *= self.hyp.cls  # cls gain
         loss[2] *= self.hyp.dfl  # dfl gain
+        loss[3] *= self.hyp.angle  # angle gain

-        return loss * batch_size, loss.detach()  # loss(box, cls, dfl)
+        return loss * batch_size, loss.detach()  # loss(box, cls, dfl, angle)

     def bbox_decode(
         self, anchor_points: torch.Tensor, pred_dist: torch.Tensor, pred_angle: torch.Tensor
@@ -770,6 +1089,34 @@ class v8OBBLoss(v8DetectionLoss):
         pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))
         return torch.cat((dist2rbox(pred_dist, pred_angle, anchor_points), pred_angle), dim=-1)

+    def calculate_angle_loss(self, pred_bboxes, target_bboxes, fg_mask, weight, target_scores_sum, lambda_val=3):
+        """Calculate oriented angle loss.
+
+        Args:
+            pred_bboxes: [N, 5] (x, y, w, h, theta).
+            target_bboxes: [N, 5] (x, y, w, h, theta).
+            fg_mask: Foreground mask indicating valid predictions.
+            weight: Loss weights for each prediction.
+            target_scores_sum: Sum of target scores for normalization.
+            lambda_val: control the sensitivity to aspect ratio.
+        """
+        w_gt = target_bboxes[..., 2]
+        h_gt = target_bboxes[..., 3]
+        pred_theta = pred_bboxes[..., 4]
+        target_theta = target_bboxes[..., 4]
+
+        log_ar = torch.log(w_gt / h_gt)
+        scale_weight = torch.exp(-(log_ar**2) / (lambda_val**2))
+
+        delta_theta = pred_theta - target_theta
+        delta_theta_wrapped = delta_theta - torch.round(delta_theta / math.pi) * math.pi
+        ang_loss = torch.sin(2 * delta_theta_wrapped[fg_mask]) ** 2
+
+        ang_loss = scale_weight[fg_mask] * ang_loss
+        ang_loss = ang_loss * weight
+
+        return ang_loss.sum() / target_scores_sum
+

 class E2EDetectLoss:
     """Criterion class for computing training losses for end-to-end detection."""
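To make the new angle term concrete, here is a minimal standalone recomputation of the penalty inside `calculate_angle_loss` (input values are made up): the wrap keeps the angle error within a pi period, and the aspect-ratio weight suppresses the term for near-square boxes whose orientation is ill-defined.

```python
import math

import torch

# Illustrative only: per-box angle penalty from calculate_angle_loss, before score weighting.
lambda_val = 3
w_gt, h_gt = torch.tensor([40.0, 10.0]), torch.tensor([38.0, 60.0])  # hypothetical GT box sizes
pred_theta = torch.tensor([0.10, 1.50])
target_theta = torch.tensor([0.05, -1.60])                           # second pair only differs by ~pi after wrapping

scale_weight = torch.exp(-(torch.log(w_gt / h_gt) ** 2) / lambda_val**2)  # near-square boxes get down-weighted
delta = pred_theta - target_theta
delta_wrapped = delta - torch.round(delta / math.pi) * math.pi
print(scale_weight * torch.sin(2 * delta_wrapped) ** 2)
```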
@@ -789,61 +1136,108 @@ class E2EDetectLoss:
         return loss_one2many[0] + loss_one2one[0], loss_one2many[1] + loss_one2one[1]


+class E2ELoss:
+    """Criterion class for computing training losses for end-to-end detection."""
+
+    def __init__(self, model, loss_fn=v8DetectionLoss):
+        """Initialize E2ELoss with one-to-many and one-to-one detection losses using the provided model."""
+        self.one2many = loss_fn(model, tal_topk=10)
+        self.one2one = loss_fn(model, tal_topk=7, tal_topk2=1)
+        self.updates = 0
+        self.total = 1.0
+        # init gain
+        self.o2m = 0.8
+        self.o2o = self.total - self.o2m
+        self.o2m_copy = self.o2m
+        # final gain
+        self.final_o2m = 0.1
+
+    def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+        """Calculate the sum of the loss for box, cls and dfl multiplied by batch size."""
+        preds = self.one2many.parse_output(preds)
+        one2many, one2one = preds["one2many"], preds["one2one"]
+        loss_one2many = self.one2many.loss(one2many, batch)
+        loss_one2one = self.one2one.loss(one2one, batch)
+        return loss_one2many[0] * self.o2m + loss_one2one[0] * self.o2o, loss_one2one[1]
+
+    def update(self) -> None:
+        """Update the weights for one-to-many and one-to-one losses based on the decay schedule."""
+        self.updates += 1
+        self.o2m = self.decay(self.updates)
+        self.o2o = max(self.total - self.o2m, 0)
+
+    def decay(self, x) -> float:
+        """Calculate the decayed weight for one-to-many loss based on the current update step."""
+        return max(1 - x / max(self.one2one.hyp.epochs - 1, 1), 0) * (self.o2m_copy - self.final_o2m) + self.final_o2m
+
+
 class TVPDetectLoss:
     """Criterion class for computing training losses for text-visual prompt detection."""

-    def __init__(self, model):
+    def __init__(self, model, tal_topk=10):
         """Initialize TVPDetectLoss with task-prompt and visual-prompt criteria using the provided model."""
-        self.vp_criterion = v8DetectionLoss(model)
+        self.vp_criterion = v8DetectionLoss(model, tal_topk)
         # NOTE: store following info as it's changeable in __call__
+        self.hyp = self.vp_criterion.hyp
         self.ori_nc = self.vp_criterion.nc
         self.ori_no = self.vp_criterion.no
         self.ori_reg_max = self.vp_criterion.reg_max

+    def parse_output(self, preds) -> dict[str, torch.Tensor]:
+        """Parse model predictions to extract features."""
+        return self.vp_criterion.parse_output(preds)
+
     def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
         """Calculate the loss for text-visual prompt detection."""
-        feats = preds[1] if isinstance(preds, tuple) else preds
+        return self.loss(self.parse_output(preds), batch)
+
+    def loss(self, preds: dict[str, torch.Tensor], batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+        """Calculate the loss for text-visual prompt detection."""
+        assert self.ori_reg_max == self.vp_criterion.reg_max  # TODO: remove it

-        if self.ori_reg_max * 4 + self.ori_nc == feats[0].shape[1]:
+        if self.ori_nc == preds["scores"].shape[1]:
             loss = torch.zeros(3, device=self.vp_criterion.device, requires_grad=True)
             return loss, loss.detach()

-        vp_feats = self._get_vp_features(feats)
-        vp_loss = self.vp_criterion(vp_feats, batch)
-        box_loss = vp_loss[0][1]
-        return box_loss, vp_loss[1]
+        preds["scores"] = self._get_vp_features(preds)
+        vp_loss = self.vp_criterion(preds, batch)
+        box_loss = vp_loss[0][1]
+        return box_loss, vp_loss[1]

-    def _get_vp_features(self, feats: list[torch.Tensor]) -> list[torch.Tensor]:
+    def _get_vp_features(self, preds: dict[str, torch.Tensor]) -> list[torch.Tensor]:
         """Extract visual-prompt features from the model output."""
-        vnc = feats[0].shape[1] - self.ori_reg_max * 4 - self.ori_nc
+        # NOTE: remove empty placeholder
+        scores = preds["scores"][:, self.ori_nc :, :]
+        vnc = scores.shape[1]

         self.vp_criterion.nc = vnc
         self.vp_criterion.no = vnc + self.vp_criterion.reg_max * 4
         self.vp_criterion.assigner.num_classes = vnc
-
-        return [
-            torch.cat((box, cls_vp), dim=1)
-            for box, _, cls_vp in [xi.split((self.ori_reg_max * 4, self.ori_nc, vnc), dim=1) for xi in feats]
-        ]
+        return scores


 class TVPSegmentLoss(TVPDetectLoss):
     """Criterion class for computing training losses for text-visual prompt segmentation."""

-    def __init__(self, model):
+    def __init__(self, model, tal_topk=10):
         """Initialize TVPSegmentLoss with task-prompt and visual-prompt criteria using the provided model."""
         super().__init__(model)
-        self.vp_criterion = v8SegmentationLoss(model)
+        self.vp_criterion = v8SegmentationLoss(model, tal_topk)
+        self.hyp = self.vp_criterion.hyp

     def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
         """Calculate the loss for text-visual prompt segmentation."""
-        feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
+        return self.loss(self.parse_output(preds), batch)
+
+    def loss(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+        """Calculate the loss for text-visual prompt detection."""
+        assert self.ori_reg_max == self.vp_criterion.reg_max  # TODO: remove it

-        if self.ori_reg_max * 4 + self.ori_nc == feats[0].shape[1]:
+        if self.ori_nc == preds["scores"].shape[1]:
             loss = torch.zeros(4, device=self.vp_criterion.device, requires_grad=True)
             return loss, loss.detach()

-        vp_feats = self._get_vp_features(feats)
-        vp_loss = self.vp_criterion((vp_feats, pred_masks, proto), batch)
+        preds["scores"] = self._get_vp_features(preds)
+        vp_loss = self.vp_criterion(preds, batch)
         cls_loss = vp_loss[0][2]
         return cls_loss, vp_loss[1]
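A quick illustration of the `E2ELoss` gain schedule: the one-to-many weight decays linearly from 0.8 to 0.1 over training while the one-to-one weight takes up the remainder. The `epochs` value here is hypothetical; the real criterion reads it from the one2one loss's hyperparameters.

```python
# Illustrative only: replicate E2ELoss.decay()/update() outside the trainer.
o2m_init, o2m_final, total, epochs = 0.8, 0.1, 1.0, 100  # epochs is a made-up value


def decay(update: int) -> float:
    """Linear schedule from o2m_init down to o2m_final, mirroring E2ELoss.decay."""
    return max(1 - update / max(epochs - 1, 1), 0) * (o2m_init - o2m_final) + o2m_final


for update in (0, 1, 50, 99, 150):
    o2m = decay(update)
    print(f"update={update:3d}  o2m={o2m:.3f}  o2o={max(total - o2m, 0):.3f}")
```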