dgenerate-ultralytics-headless 8.3.253__py3-none-any.whl → 8.4.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.253.dist-info → dgenerate_ultralytics_headless-8.4.3.dist-info}/METADATA +41 -49
- {dgenerate_ultralytics_headless-8.3.253.dist-info → dgenerate_ultralytics_headless-8.4.3.dist-info}/RECORD +85 -74
- tests/__init__.py +2 -2
- tests/conftest.py +1 -1
- tests/test_cuda.py +8 -2
- tests/test_engine.py +8 -8
- tests/test_exports.py +11 -4
- tests/test_integrations.py +9 -9
- tests/test_python.py +14 -14
- tests/test_solutions.py +3 -3
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +25 -27
- ultralytics/cfg/default.yaml +3 -1
- ultralytics/cfg/models/26/yolo26-cls.yaml +33 -0
- ultralytics/cfg/models/26/yolo26-obb.yaml +52 -0
- ultralytics/cfg/models/26/yolo26-p2.yaml +60 -0
- ultralytics/cfg/models/26/yolo26-p6.yaml +62 -0
- ultralytics/cfg/models/26/yolo26-pose.yaml +53 -0
- ultralytics/cfg/models/26/yolo26-seg.yaml +52 -0
- ultralytics/cfg/models/26/yolo26.yaml +52 -0
- ultralytics/cfg/models/26/yoloe-26-seg.yaml +53 -0
- ultralytics/cfg/models/26/yoloe-26.yaml +53 -0
- ultralytics/data/annotator.py +2 -2
- ultralytics/data/augment.py +7 -0
- ultralytics/data/converter.py +57 -38
- ultralytics/data/dataset.py +1 -1
- ultralytics/engine/exporter.py +31 -26
- ultralytics/engine/model.py +34 -34
- ultralytics/engine/predictor.py +17 -17
- ultralytics/engine/results.py +14 -12
- ultralytics/engine/trainer.py +59 -29
- ultralytics/engine/tuner.py +19 -11
- ultralytics/engine/validator.py +16 -16
- ultralytics/models/fastsam/predict.py +1 -1
- ultralytics/models/yolo/classify/predict.py +1 -1
- ultralytics/models/yolo/classify/train.py +1 -1
- ultralytics/models/yolo/classify/val.py +1 -1
- ultralytics/models/yolo/detect/predict.py +2 -2
- ultralytics/models/yolo/detect/train.py +4 -3
- ultralytics/models/yolo/detect/val.py +7 -1
- ultralytics/models/yolo/model.py +8 -8
- ultralytics/models/yolo/obb/predict.py +2 -2
- ultralytics/models/yolo/obb/train.py +3 -3
- ultralytics/models/yolo/obb/val.py +1 -1
- ultralytics/models/yolo/pose/predict.py +1 -1
- ultralytics/models/yolo/pose/train.py +3 -1
- ultralytics/models/yolo/pose/val.py +1 -1
- ultralytics/models/yolo/segment/predict.py +3 -3
- ultralytics/models/yolo/segment/train.py +4 -4
- ultralytics/models/yolo/segment/val.py +4 -2
- ultralytics/models/yolo/yoloe/train.py +6 -1
- ultralytics/models/yolo/yoloe/train_seg.py +6 -1
- ultralytics/nn/autobackend.py +5 -5
- ultralytics/nn/modules/__init__.py +8 -0
- ultralytics/nn/modules/block.py +128 -8
- ultralytics/nn/modules/head.py +788 -203
- ultralytics/nn/tasks.py +86 -41
- ultralytics/nn/text_model.py +5 -2
- ultralytics/optim/__init__.py +5 -0
- ultralytics/optim/muon.py +338 -0
- ultralytics/solutions/ai_gym.py +3 -3
- ultralytics/solutions/config.py +1 -1
- ultralytics/solutions/heatmap.py +1 -1
- ultralytics/solutions/instance_segmentation.py +2 -2
- ultralytics/solutions/parking_management.py +1 -1
- ultralytics/solutions/solutions.py +2 -2
- ultralytics/trackers/track.py +1 -1
- ultralytics/utils/__init__.py +8 -8
- ultralytics/utils/benchmarks.py +23 -23
- ultralytics/utils/callbacks/platform.py +11 -7
- ultralytics/utils/checks.py +6 -6
- ultralytics/utils/downloads.py +5 -3
- ultralytics/utils/export/engine.py +19 -10
- ultralytics/utils/export/imx.py +19 -13
- ultralytics/utils/export/tensorflow.py +21 -21
- ultralytics/utils/files.py +2 -2
- ultralytics/utils/loss.py +587 -203
- ultralytics/utils/metrics.py +1 -0
- ultralytics/utils/ops.py +11 -2
- ultralytics/utils/tal.py +98 -19
- ultralytics/utils/tuner.py +2 -2
- {dgenerate_ultralytics_headless-8.3.253.dist-info → dgenerate_ultralytics_headless-8.4.3.dist-info}/WHEEL +0 -0
- {dgenerate_ultralytics_headless-8.3.253.dist-info → dgenerate_ultralytics_headless-8.4.3.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.253.dist-info → dgenerate_ultralytics_headless-8.4.3.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.253.dist-info → dgenerate_ultralytics_headless-8.4.3.dist-info}/top_level.txt +0 -0
ultralytics/utils/loss.py
CHANGED
@@ -2,19 +2,20 @@
 
 from __future__ import annotations
 
+import math
 from typing import Any
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from ultralytics.utils.metrics import OKS_SIGMA
+from ultralytics.utils.metrics import OKS_SIGMA, RLE_WEIGHT
 from ultralytics.utils.ops import crop_mask, xywh2xyxy, xyxy2xywh
 from ultralytics.utils.tal import RotatedTaskAlignedAssigner, TaskAlignedAssigner, dist2bbox, dist2rbox, make_anchors
 from ultralytics.utils.torch_utils import autocast
 
 from .metrics import bbox_iou, probiou
-from .tal import bbox2dist
+from .tal import bbox2dist, rbox2dist
 
 
 class VarifocalLoss(nn.Module):
@@ -122,6 +123,8 @@ class BboxLoss(nn.Module):
         target_scores: torch.Tensor,
         target_scores_sum: torch.Tensor,
         fg_mask: torch.Tensor,
+        imgsz: torch.Tensor,
+        stride: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         """Compute IoU and DFL losses for bounding boxes."""
         weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
@@ -134,11 +137,76 @@ class BboxLoss(nn.Module):
             loss_dfl = self.dfl_loss(pred_dist[fg_mask].view(-1, self.dfl_loss.reg_max), target_ltrb[fg_mask]) * weight
             loss_dfl = loss_dfl.sum() / target_scores_sum
         else:
-
+            target_ltrb = bbox2dist(anchor_points, target_bboxes)
+            # normalize ltrb by image size
+            target_ltrb = target_ltrb * stride
+            target_ltrb[..., 0::2] /= imgsz[1]
+            target_ltrb[..., 1::2] /= imgsz[0]
+            pred_dist = pred_dist * stride
+            pred_dist[..., 0::2] /= imgsz[1]
+            pred_dist[..., 1::2] /= imgsz[0]
+            loss_dfl = (
+                F.l1_loss(pred_dist[fg_mask], target_ltrb[fg_mask], reduction="none").mean(-1, keepdim=True) * weight
+            )
+            loss_dfl = loss_dfl.sum() / target_scores_sum
 
         return loss_iou, loss_dfl
 
 
+class RLELoss(nn.Module):
+    """Residual Log-Likelihood Estimation Loss.
+
+    Args:
+        use_target_weight (bool): Option to use weighted loss.
+        size_average (bool): Option to average the loss by the batch_size.
+        residual (bool): Option to add L1 loss and let the flow learn the residual error distribution.
+
+    References:
+        https://arxiv.org/abs/2107.11291
+        https://github.com/open-mmlab/mmpose/blob/main/mmpose/models/losses/regression_loss.py
+    """
+
+    def __init__(self, use_target_weight: bool = True, size_average: bool = True, residual: bool = True):
+        """Initialize RLELoss with target weight and residual options.
+
+        Args:
+            use_target_weight (bool): Whether to use target weights for loss calculation.
+            size_average (bool): Whether to average the loss over elements.
+            residual (bool): Whether to include residual log-likelihood term.
+        """
+        super().__init__()
+        self.size_average = size_average
+        self.use_target_weight = use_target_weight
+        self.residual = residual
+
+    def forward(
+        self, sigma: torch.Tensor, log_phi: torch.Tensor, error: torch.Tensor, target_weight: torch.Tensor = None
+    ) -> torch.Tensor:
+        """
+        Args:
+            sigma (torch.Tensor): Output sigma, shape (N, D).
+            log_phi (torch.Tensor): Output log_phi, shape (N).
+            error (torch.Tensor): Error, shape (N, D).
+            target_weight (torch.Tensor): Weights across different joint types, shape (N).
+        """
+        log_sigma = torch.log(sigma)
+        loss = log_sigma - log_phi.unsqueeze(1)
+
+        if self.residual:
+            loss += torch.log(sigma * 2) + torch.abs(error)
+
+        if self.use_target_weight:
+            assert target_weight is not None, "'target_weight' should not be None when 'use_target_weight' is True."
+            if target_weight.dim() == 1:
+                target_weight = target_weight.unsqueeze(1)
+            loss *= target_weight
+
+        if self.size_average:
+            loss /= len(loss)
+
+        return loss.sum()
+
+
 class RotatedBboxLoss(BboxLoss):
     """Criterion class for computing training losses for rotated bounding boxes."""
 
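A minimal sketch of how the newly added RLELoss can be exercised on its own, assuming the 8.4.3 wheel is installed so `RLELoss` is importable from `ultralytics.utils.loss`. The unit Gaussian used for `log_phi` is a stand-in for the head's flow model and is an assumption made purely for illustration.

```python
import torch
from torch.distributions import Normal

from ultralytics.utils.loss import RLELoss  # class added in this diff

# Illustrative shapes only: N = 8 visible keypoints, D = 2 coordinates.
sigma = torch.rand(8, 2) * 0.5 + 0.1  # predicted per-coordinate scale, must be > 0
error = torch.randn(8, 2)  # (pred - gt) / sigma, as computed in PoseLoss26.calculate_rle_loss
# Stand-in density: a unit Gaussian instead of the head's flow_model (assumption for this example).
log_phi = Normal(0.0, 1.0).log_prob(error).sum(-1)
target_weight = torch.ones(8)

loss = RLELoss(use_target_weight=True)(sigma, log_phi, error, target_weight)
print(float(loss))
```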
@@ -155,6 +223,8 @@ class RotatedBboxLoss(BboxLoss):
         target_scores: torch.Tensor,
         target_scores_sum: torch.Tensor,
         fg_mask: torch.Tensor,
+        imgsz: torch.Tensor,
+        stride: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         """Compute IoU and DFL losses for rotated bounding boxes."""
         weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
@@ -163,15 +233,84 @@ class RotatedBboxLoss(BboxLoss):
 
         # DFL loss
         if self.dfl_loss:
-            target_ltrb =
+            target_ltrb = rbox2dist(
+                target_bboxes[..., :4], anchor_points, target_bboxes[..., 4:5], reg_max=self.dfl_loss.reg_max - 1
+            )
             loss_dfl = self.dfl_loss(pred_dist[fg_mask].view(-1, self.dfl_loss.reg_max), target_ltrb[fg_mask]) * weight
             loss_dfl = loss_dfl.sum() / target_scores_sum
         else:
-
+            target_ltrb = rbox2dist(target_bboxes[..., :4], anchor_points, target_bboxes[..., 4:5])
+            target_ltrb = target_ltrb * stride
+            target_ltrb[..., 0::2] /= imgsz[1]
+            target_ltrb[..., 1::2] /= imgsz[0]
+            pred_dist = pred_dist * stride
+            pred_dist[..., 0::2] /= imgsz[1]
+            pred_dist[..., 1::2] /= imgsz[0]
+            loss_dfl = (
+                F.l1_loss(pred_dist[fg_mask], target_ltrb[fg_mask], reduction="none").mean(-1, keepdim=True) * weight
+            )
+            loss_dfl = loss_dfl.sum() / target_scores_sum
 
         return loss_iou, loss_dfl
 
 
+class MultiChannelDiceLoss(nn.Module):
+    """Criterion class for computing multi-channel Dice losses."""
+
+    def __init__(self, smooth: float = 1e-6, reduction: str = "mean"):
+        """Initialize MultiChannelDiceLoss with smoothing and reduction options.
+
+        Args:
+            smooth (float): Smoothing factor to avoid division by zero.
+            reduction (str): Reduction method ('mean', 'sum', or 'none').
+        """
+        super().__init__()
+        self.smooth = smooth
+        self.reduction = reduction
+
+    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+        """Calculate multi-channel Dice loss between predictions and targets."""
+        assert pred.size() == target.size(), "the size of predict and target must be equal."
+
+        pred = pred.sigmoid()
+        intersection = (pred * target).sum(dim=(2, 3))
+        union = pred.sum(dim=(2, 3)) + target.sum(dim=(2, 3))
+        dice = (2.0 * intersection + self.smooth) / (union + self.smooth)
+        dice_loss = 1.0 - dice
+        dice_loss = dice_loss.mean(dim=1)
+
+        if self.reduction == "mean":
+            return dice_loss.mean()
+        elif self.reduction == "sum":
+            return dice_loss.sum()
+        else:
+            return dice_loss
+
+
+class BCEDiceLoss(nn.Module):
+    """Criterion class for computing combined BCE and Dice losses."""
+
+    def __init__(self, weight_bce: float = 0.5, weight_dice: float = 0.5):
+        """Initialize BCEDiceLoss with BCE and Dice weight factors.
+
+        Args:
+            weight_bce (float): Weight factor for BCE loss component.
+            weight_dice (float): Weight factor for Dice loss component.
+        """
+        super().__init__()
+        self.weight_bce = weight_bce
+        self.weight_dice = weight_dice
+        self.bce = nn.BCEWithLogitsLoss()
+        self.dice = MultiChannelDiceLoss(smooth=1)
+
+    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+        """Calculate combined BCE and Dice loss between predictions and targets."""
+        _, _, mask_h, mask_w = pred.shape
+        if tuple(target.shape[-2:]) != (mask_h, mask_w):  # downsample to the same size as pred
+            target = F.interpolate(target, (mask_h, mask_w), mode="nearest")
+        return self.weight_bce * self.bce(pred, target) + self.weight_dice * self.dice(pred, target)
+
+
 class KeypointLoss(nn.Module):
     """Criterion class for computing keypoint losses."""
 
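A short usage sketch for the new BCE+Dice criterion, assuming the 8.4.3 wheel is installed so `BCEDiceLoss` is importable; shapes are illustrative only. The higher-resolution target is downsampled inside `forward`, as the diff shows.

```python
import torch

from ultralytics.utils.loss import BCEDiceLoss  # class added in this diff

# Toy semantic-segmentation logits and binary targets (illustrative shapes).
pred = torch.randn(2, 3, 32, 32)  # N x C x H x W logits
target = (torch.rand(2, 3, 64, 64) > 0.5).float()  # higher-res masks, resized inside forward

criterion = BCEDiceLoss(weight_bce=0.5, weight_dice=0.5)
print(float(criterion(pred, target)))
```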
@@ -194,7 +333,7 @@ class KeypointLoss(nn.Module):
 class v8DetectionLoss:
     """Criterion class for computing training losses for YOLOv8 object detection."""
 
-    def __init__(self, model, tal_topk: int = 10):  # model must be de-paralleled
+    def __init__(self, model, tal_topk: int = 10, tal_topk2: int | None = None):  # model must be de-paralleled
         """Initialize v8DetectionLoss with model parameters and task-aligned assignment settings."""
         device = next(model.parameters()).device  # get model device
         h = model.args  # hyperparameters
@@ -210,7 +349,14 @@ class v8DetectionLoss:
 
         self.use_dfl = m.reg_max > 1
 
-        self.assigner = TaskAlignedAssigner(
+        self.assigner = TaskAlignedAssigner(
+            topk=tal_topk,
+            num_classes=self.nc,
+            alpha=0.5,
+            beta=6.0,
+            stride=self.stride.tolist(),
+            topk2=tal_topk2,
+        )
         self.bbox_loss = BboxLoss(m.reg_max).to(device)
         self.proj = torch.arange(m.reg_max, dtype=torch.float, device=device)
 
@@ -240,35 +386,31 @@ class v8DetectionLoss:
         # pred_dist = (pred_dist.view(b, a, c // 4, 4).softmax(2) * self.proj.type(pred_dist.dtype).view(1, 1, -1, 1)).sum(2)
         return dist2bbox(pred_dist, anchor_points, xywh=False)
 
-    def
-        """Calculate the sum of the loss for box, cls and dfl multiplied by batch size
+    def get_assigned_targets_and_loss(self, preds: dict[str, torch.Tensor], batch: dict[str, Any]) -> tuple:
+        """Calculate the sum of the loss for box, cls and dfl multiplied by batch size and return foreground mask and
+        target indices.
+        """
         loss = torch.zeros(3, device=self.device)  # box, cls, dfl
-
-
-        (
+        pred_distri, pred_scores = (
+            preds["boxes"].permute(0, 2, 1).contiguous(),
+            preds["scores"].permute(0, 2, 1).contiguous(),
         )
-
-        pred_scores = pred_scores.permute(0, 2, 1).contiguous()
-        pred_distri = pred_distri.permute(0, 2, 1).contiguous()
+        anchor_points, stride_tensor = make_anchors(preds["feats"], self.stride, 0.5)
 
         dtype = pred_scores.dtype
         batch_size = pred_scores.shape[0]
-        imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]
-        anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
+        imgsz = torch.tensor(preds["feats"][0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]
 
         # Targets
         targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
-        targets = self.preprocess(targets, batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
+        targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
         gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
         mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
 
         # Pboxes
         pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)
-        # dfl_conf = pred_distri.view(batch_size, -1, 4, self.reg_max).detach().softmax(-1)
-        # dfl_conf = (dfl_conf.amax(-1).mean(-1) + dfl_conf.amax(-1).amin(-1)) / 2
 
-        _, target_bboxes, target_scores, fg_mask,
-        # pred_scores.detach().sigmoid() * 0.8 + dfl_conf.unsqueeze(-1) * 0.2,
+        _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
             pred_scores.detach().sigmoid(),
             (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
             anchor_points * stride_tensor,
@@ -280,7 +422,6 @@ class v8DetectionLoss:
         target_scores_sum = max(target_scores.sum(), 1)
 
         # Cls loss
-        # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum  # VFL way
         loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE
 
         # Bbox loss
@@ -293,105 +434,98 @@ class v8DetectionLoss:
             target_scores,
             target_scores_sum,
             fg_mask,
+            imgsz,
+            stride_tensor,
         )
 
         loss[0] *= self.hyp.box  # box gain
         loss[1] *= self.hyp.cls  # cls gain
         loss[2] *= self.hyp.dfl  # dfl gain
+        return (
+            (fg_mask, target_gt_idx, target_bboxes, anchor_points, stride_tensor),
+            loss,
+            loss.detach(),
+        )  # loss(box, cls, dfl)
 
-
+    def parse_output(
+        self, preds: dict[str, torch.Tensor] | tuple[torch.Tensor, dict[str, torch.Tensor]]
+    ) -> torch.Tensor:
+        """Parse model predictions to extract features."""
+        return preds[1] if isinstance(preds, tuple) else preds
+
+    def __call__(
+        self,
+        preds: dict[str, torch.Tensor] | tuple[torch.Tensor, dict[str, torch.Tensor]],
+        batch: dict[str, torch.Tensor],
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Calculate the sum of the loss for box, cls and dfl multiplied by batch size."""
+        return self.loss(self.parse_output(preds), batch)
+
+    def loss(self, preds: dict[str, torch.Tensor], batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+        """A wrapper for get_assigned_targets_and_loss and parse_output."""
+        batch_size = preds["boxes"].shape[0]
+        loss, loss_detach = self.get_assigned_targets_and_loss(preds, batch)[1:]
+        return loss * batch_size, loss_detach
 
 
 class v8SegmentationLoss(v8DetectionLoss):
     """Criterion class for computing training losses for YOLOv8 segmentation."""
 
-    def __init__(self, model):  # model must be de-paralleled
+    def __init__(self, model, tal_topk: int = 10, tal_topk2: int | None = None):  # model must be de-paralleled
         """Initialize the v8SegmentationLoss class with model parameters and mask overlap setting."""
-        super().__init__(model)
+        super().__init__(model, tal_topk, tal_topk2)
         self.overlap = model.args.overlap_mask
+        self.bcedice_loss = BCEDiceLoss(weight_bce=0.5, weight_dice=0.5)
 
-    def
+    def loss(self, preds: dict[str, torch.Tensor], batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
         """Calculate and return the combined loss for detection and segmentation."""
-
-
-
-
-
-
-
-        #
-
-        pred_distri = pred_distri.permute(0, 2, 1).contiguous()
-        pred_masks = pred_masks.permute(0, 2, 1).contiguous()
-
-        dtype = pred_scores.dtype
-        imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]  # image size (h,w)
-        anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
-
-        # Targets
-        try:
-            batch_idx = batch["batch_idx"].view(-1, 1)
-            targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)
-            targets = self.preprocess(targets, batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
-            gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
-            mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
-        except RuntimeError as e:
-            raise TypeError(
-                "ERROR ❌ segment dataset incorrectly formatted or not a segment dataset.\n"
-                "This error can occur when incorrectly training a 'segment' model on a 'detect' dataset, "
-                "i.e. 'yolo train model=yolo11n-seg.pt data=coco8.yaml'.\nVerify your dataset is a "
-                "correctly formatted 'segment' dataset using 'data=coco8-seg.yaml' "
-                "as an example.\nSee https://docs.ultralytics.com/datasets/segment/ for help."
-            ) from e
-
-        # Pboxes
-        pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)
-
-        _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
-            pred_scores.detach().sigmoid(),
-            (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
-            anchor_points * stride_tensor,
-            gt_labels,
-            gt_bboxes,
-            mask_gt,
-        )
-
-        target_scores_sum = max(target_scores.sum(), 1)
-
-        # Cls loss
-        # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum  # VFL way
-        loss[2] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE
+        pred_masks, proto = preds["mask_coefficient"].permute(0, 2, 1).contiguous(), preds["proto"]
+        loss = torch.zeros(5, device=self.device)  # box, seg, cls, dfl
+        if isinstance(proto, tuple) and len(proto) == 2:
+            proto, pred_semseg = proto
+        else:
+            pred_semseg = None
+        (fg_mask, target_gt_idx, target_bboxes, _, _), det_loss, _ = self.get_assigned_targets_and_loss(preds, batch)
+        # NOTE: re-assign index for consistency for now. Need to be removed in the future.
+        loss[0], loss[2], loss[3] = det_loss[0], det_loss[1], det_loss[2]
 
+        batch_size, _, mask_h, mask_w = proto.shape  # batch size, number of masks, mask height, mask width
         if fg_mask.sum():
-            # Bbox loss
-            loss[0], loss[3] = self.bbox_loss(
-                pred_distri,
-                pred_bboxes,
-                anchor_points,
-                target_bboxes / stride_tensor,
-                target_scores,
-                target_scores_sum,
-                fg_mask,
-            )
             # Masks loss
             masks = batch["masks"].to(self.device).float()
             if tuple(masks.shape[-2:]) != (mask_h, mask_w):  # downsample
-                masks = F.interpolate(masks[None], (mask_h, mask_w), mode="nearest")[0]
+                # masks = F.interpolate(masks[None], (mask_h, mask_w), mode="nearest")[0]
+                proto = F.interpolate(proto, masks.shape[-2:], mode="bilinear", align_corners=False)
 
+            imgsz = (
+                torch.tensor(preds["feats"][0].shape[2:], device=self.device, dtype=pred_masks.dtype) * self.stride[0]
+            )
             loss[1] = self.calculate_segmentation_loss(
-                fg_mask,
+                fg_mask,
+                masks,
+                target_gt_idx,
+                target_bboxes,
+                batch["batch_idx"].view(-1, 1),
+                proto,
+                pred_masks,
+                imgsz,
             )
+            if pred_semseg is not None:
+                sem_masks = batch["sem_masks"].to(self.device)  # NxHxW
+                mask_zero = sem_masks == 0  # NxHxW
+                sem_masks = F.one_hot(sem_masks.long(), num_classes=self.nc).permute(0, 3, 1, 2).float()  # NxCxHxW
+                sem_masks[mask_zero.unsqueeze(1).expand_as(sem_masks)] = 0
+                loss[4] = self.bcedice_loss(pred_semseg, sem_masks)
+                loss[4] *= self.hyp.box  # seg gain
 
         # WARNING: lines below prevent Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
         else:
             loss[1] += (proto * 0).sum() + (pred_masks * 0).sum()  # inf sums may lead to nan loss
+            if pred_semseg is not None:
+                loss[4] += (pred_semseg * 0).sum()
 
-        loss[0] *= self.hyp.box  # box gain
         loss[1] *= self.hyp.box  # seg gain
-        loss
-        loss[3] *= self.hyp.dfl  # dfl gain
-
-        return loss * batch_size, loss.detach()  # loss(box, seg, cls, dfl)
+        return loss * batch_size, loss.detach()  # loss(box, cls, dfl)
 
     @staticmethod
     def single_mask_loss(
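The refactor above makes every loss consume a dict of raw head outputs instead of a tuple of feature maps. A hedged sketch of that contract follows; the channel counts and anchor total are hypothetical values for a 640x640 input with strides (8, 16, 32), reg_max=16 and nc=80, chosen only to illustrate the keys the diff reads ("feats", "boxes", "scores").

```python
import torch

# Hypothetical shapes only; a real preds dict comes from the detection head at train time.
na = 80 * 80 + 40 * 40 + 20 * 20  # 8400 anchors across the three strides
preds = {
    "feats": [torch.zeros(1, 144, s, s) for s in (80, 40, 20)],  # only .shape[2:] is used for imgsz/anchors
    "boxes": torch.zeros(1, 64, na),   # 4 * reg_max distribution logits per anchor
    "scores": torch.zeros(1, 80, na),  # nc class logits per anchor
}
for k, v in preds.items():
    print(k, v[0].shape if isinstance(v, list) else v.shape)
```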
@@ -427,7 +561,6 @@ class v8SegmentationLoss(v8DetectionLoss):
         proto: torch.Tensor,
         pred_masks: torch.Tensor,
         imgsz: torch.Tensor,
-        overlap: bool,
     ) -> torch.Tensor:
         """Calculate the loss for instance segmentation.
 
@@ -440,7 +573,6 @@ class v8SegmentationLoss(v8DetectionLoss):
             proto (torch.Tensor): Prototype masks of shape (BS, 32, H, W).
             pred_masks (torch.Tensor): Predicted masks for each anchor of shape (BS, N_anchors, 32).
             imgsz (torch.Tensor): Size of the input image as a tensor of shape (2), i.e., (H, W).
-            overlap (bool): Whether the masks in `masks` tensor overlap.
 
         Returns:
             (torch.Tensor): The calculated loss for instance segmentation.
@@ -466,7 +598,7 @@ class v8SegmentationLoss(v8DetectionLoss):
             fg_mask_i, target_gt_idx_i, pred_masks_i, proto_i, mxyxy_i, marea_i, masks_i = single_i
             if fg_mask_i.any():
                 mask_idx = target_gt_idx_i[fg_mask_i]
-                if overlap:
+                if self.overlap:
                     gt_mask = masks_i == (mask_idx + 1).view(-1, 1, 1)
                     gt_mask = gt_mask.float()
                 else:
@@ -486,9 +618,9 @@ class v8SegmentationLoss(v8DetectionLoss):
 class v8PoseLoss(v8DetectionLoss):
     """Criterion class for computing training losses for YOLOv8 pose estimation."""
 
-    def __init__(self, model):  # model must be de-paralleled
+    def __init__(self, model, tal_topk: int = 10, tal_topk2: int = 10):  # model must be de-paralleled
         """Initialize v8PoseLoss with model parameters and keypoint-specific loss functions."""
-        super().__init__(model)
+        super().__init__(model, tal_topk, tal_topk2)
         self.kpt_shape = model.model[-1].kpt_shape
         self.bce_pose = nn.BCEWithLogitsLoss()
         is_pose = self.kpt_shape == [17, 3]
@@ -496,69 +628,40 @@ class v8PoseLoss(v8DetectionLoss):
         sigmas = torch.from_numpy(OKS_SIGMA).to(self.device) if is_pose else torch.ones(nkpt, device=self.device) / nkpt
         self.keypoint_loss = KeypointLoss(sigmas=sigmas)
 
-    def
+    def loss(self, preds: dict[str, torch.Tensor], batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
         """Calculate the total loss and detach it for pose estimation."""
-
-
-
-
+        pred_kpts = preds["kpts"].permute(0, 2, 1).contiguous()
+        loss = torch.zeros(5, device=self.device)  # box, cls, dfl, kpt_location, kpt_visibility
+        (fg_mask, target_gt_idx, target_bboxes, anchor_points, stride_tensor), det_loss, _ = (
+            self.get_assigned_targets_and_loss(preds, batch)
         )
+        # NOTE: re-assign index for consistency for now. Need to be removed in the future.
+        loss[0], loss[3], loss[4] = det_loss[0], det_loss[1], det_loss[2]
 
-
-
-        pred_distri = pred_distri.permute(0, 2, 1).contiguous()
-        pred_kpts = pred_kpts.permute(0, 2, 1).contiguous()
-
-        dtype = pred_scores.dtype
-        imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]  # image size (h,w)
-        anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
-
-        # Targets
-        batch_size = pred_scores.shape[0]
-        batch_idx = batch["batch_idx"].view(-1, 1)
-        targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)
-        targets = self.preprocess(targets, batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
-        gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
-        mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
+        batch_size = pred_kpts.shape[0]
+        imgsz = torch.tensor(preds["feats"][0].shape[2:], device=self.device, dtype=pred_kpts.dtype) * self.stride[0]
 
         # Pboxes
-        pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)
         pred_kpts = self.kpts_decode(anchor_points, pred_kpts.view(batch_size, -1, *self.kpt_shape))  # (b, h*w, 17, 3)
 
-        _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
-            pred_scores.detach().sigmoid(),
-            (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
-            anchor_points * stride_tensor,
-            gt_labels,
-            gt_bboxes,
-            mask_gt,
-        )
-
-        target_scores_sum = max(target_scores.sum(), 1)
-
-        # Cls loss
-        # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum  # VFL way
-        loss[3] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE
-
         # Bbox loss
         if fg_mask.sum():
-            target_bboxes /= stride_tensor
-            loss[0], loss[4] = self.bbox_loss(
-                pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
-            )
             keypoints = batch["keypoints"].to(self.device).float().clone()
             keypoints[..., 0] *= imgsz[1]
             keypoints[..., 1] *= imgsz[0]
 
             loss[1], loss[2] = self.calculate_keypoints_loss(
-                fg_mask,
+                fg_mask,
+                target_gt_idx,
+                keypoints,
+                batch["batch_idx"].view(-1, 1),
+                stride_tensor,
+                target_bboxes,
+                pred_kpts,
             )
 
-        loss[0] *= self.hyp.box  # box gain
         loss[1] *= self.hyp.pose  # pose gain
         loss[2] *= self.hyp.kobj  # kobj gain
-        loss[3] *= self.hyp.cls  # cls gain
-        loss[4] *= self.hyp.dfl  # dfl gain
 
         return loss * batch_size, loss.detach()  # loss(box, pose, kobj, cls, dfl)
 
@@ -571,34 +674,23 @@ class v8PoseLoss(v8DetectionLoss):
         y[..., 1] += anchor_points[:, [1]] - 0.5
         return y
 
-    def
+    def _select_target_keypoints(
         self,
-        masks: torch.Tensor,
-        target_gt_idx: torch.Tensor,
         keypoints: torch.Tensor,
         batch_idx: torch.Tensor,
-
-
-
-
-        """Calculate the keypoints loss for the model.
-
-        This function calculates the keypoints loss and keypoints object loss for a given batch. The keypoints loss is
-        based on the difference between the predicted keypoints and ground truth keypoints. The keypoints object loss is
-        a binary classification loss that classifies whether a keypoint is present or not.
+        target_gt_idx: torch.Tensor,
+        masks: torch.Tensor,
+    ) -> torch.Tensor:
+        """Select target keypoints for each anchor based on batch index and target ground truth index.
 
         Args:
-            masks (torch.Tensor): Binary mask tensor indicating object presence, shape (BS, N_anchors).
-            target_gt_idx (torch.Tensor): Index tensor mapping anchors to ground truth objects, shape (BS, N_anchors).
            keypoints (torch.Tensor): Ground truth keypoints, shape (N_kpts_in_batch, N_kpts_per_object, kpts_dim).
            batch_idx (torch.Tensor): Batch index tensor for keypoints, shape (N_kpts_in_batch, 1).
-
-
-            pred_kpts (torch.Tensor): Predicted keypoints, shape (BS, N_anchors, N_kpts_per_object, kpts_dim).
+            target_gt_idx (torch.Tensor): Index tensor mapping anchors to ground truth objects, shape (BS, N_anchors).
+            masks (torch.Tensor): Binary mask tensor indicating object presence, shape (BS, N_anchors).
 
         Returns:
-
-            kpts_obj_loss (torch.Tensor): The keypoints object loss.
+            (torch.Tensor): Selected keypoints tensor, shape (BS, N_anchors, N_kpts_per_object, kpts_dim).
         """
         batch_idx = batch_idx.flatten()
         batch_size = len(masks)
@@ -625,6 +717,40 @@ class v8PoseLoss(v8DetectionLoss):
             1, target_gt_idx_expanded.expand(-1, -1, keypoints.shape[1], keypoints.shape[2])
         )
 
+        return selected_keypoints
+
+    def calculate_keypoints_loss(
+        self,
+        masks: torch.Tensor,
+        target_gt_idx: torch.Tensor,
+        keypoints: torch.Tensor,
+        batch_idx: torch.Tensor,
+        stride_tensor: torch.Tensor,
+        target_bboxes: torch.Tensor,
+        pred_kpts: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Calculate the keypoints loss for the model.
+
+        This function calculates the keypoints loss and keypoints object loss for a given batch. The keypoints loss is
+        based on the difference between the predicted keypoints and ground truth keypoints. The keypoints object loss is
+        a binary classification loss that classifies whether a keypoint is present or not.
+
+        Args:
+            masks (torch.Tensor): Binary mask tensor indicating object presence, shape (BS, N_anchors).
+            target_gt_idx (torch.Tensor): Index tensor mapping anchors to ground truth objects, shape (BS, N_anchors).
+            keypoints (torch.Tensor): Ground truth keypoints, shape (N_kpts_in_batch, N_kpts_per_object, kpts_dim).
+            batch_idx (torch.Tensor): Batch index tensor for keypoints, shape (N_kpts_in_batch, 1).
+            stride_tensor (torch.Tensor): Stride tensor for anchors, shape (N_anchors, 1).
+            target_bboxes (torch.Tensor): Ground truth boxes in (x1, y1, x2, y2) format, shape (BS, N_anchors, 4).
+            pred_kpts (torch.Tensor): Predicted keypoints, shape (BS, N_anchors, N_kpts_per_object, kpts_dim).
+
+        Returns:
+            kpts_loss (torch.Tensor): The keypoints loss.
+            kpts_obj_loss (torch.Tensor): The keypoints object loss.
+        """
+        # Select target keypoints using helper method
+        selected_keypoints = self._select_target_keypoints(keypoints, batch_idx, target_gt_idx, masks)
+
         # Divide coordinates by stride
         selected_keypoints[..., :2] /= stride_tensor.view(1, -1, 1, 1)
 
@@ -632,6 +758,7 @@ class v8PoseLoss(v8DetectionLoss):
         kpts_obj_loss = 0
 
         if masks.any():
+            target_bboxes /= stride_tensor
             gt_kpt = selected_keypoints[masks]
             area = xyxy2xywh(target_bboxes[masks])[:, 2:].prod(1, keepdim=True)
             pred_kpt = pred_kpts[masks]
@@ -644,6 +771,172 @@ class v8PoseLoss(v8DetectionLoss):
         return kpts_loss, kpts_obj_loss
 
 
+class PoseLoss26(v8PoseLoss):
+    """Criterion class for computing training losses for YOLOv8 pose estimation with RLE loss support."""
+
+    def __init__(self, model, tal_topk: int = 10, tal_topk2: int | None = None):  # model must be de-paralleled
+        """Initialize PoseLoss26 with model parameters and keypoint-specific loss functions including RLE loss."""
+        super().__init__(model, tal_topk, tal_topk2)
+        is_pose = self.kpt_shape == [17, 3]
+        nkpt = self.kpt_shape[0]  # number of keypoints
+        self.rle_loss = None
+        self.flow_model = model.model[-1].flow_model if hasattr(model.model[-1], "flow_model") else None
+        if self.flow_model is not None:
+            self.rle_loss = RLELoss(use_target_weight=True).to(self.device)
+        self.target_weights = (
+            torch.from_numpy(RLE_WEIGHT).to(self.device) if is_pose else torch.ones(nkpt, device=self.device)
+        )
+
+    def loss(self, preds: dict[str, torch.Tensor], batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+        """Calculate the total loss and detach it for pose estimation."""
+        pred_kpts = preds["kpts"].permute(0, 2, 1).contiguous()
+        loss = torch.zeros(6 if self.rle_loss else 5, device=self.device)  # box, cls, dfl, kpt_location, kpt_visibility
+        (fg_mask, target_gt_idx, target_bboxes, anchor_points, stride_tensor), det_loss, _ = (
+            self.get_assigned_targets_and_loss(preds, batch)
+        )
+        # NOTE: re-assign index for consistency for now. Need to be removed in the future.
+        loss[0], loss[3], loss[4] = det_loss[0], det_loss[1], det_loss[2]
+
+        batch_size = pred_kpts.shape[0]
+        imgsz = torch.tensor(batch["resized_shape"][0], device=self.device, dtype=pred_kpts.dtype)  # image size (h,w)
+
+        pred_kpts = pred_kpts.view(batch_size, -1, *self.kpt_shape)  # (b, h*w, 17, 3)
+
+        if self.rle_loss and preds.get("kpts_sigma", None) is not None:
+            pred_sigma = preds["kpts_sigma"].permute(0, 2, 1).contiguous()
+            pred_sigma = pred_sigma.view(batch_size, -1, self.kpt_shape[0], 2)  # (b, h*w, 17, 2)
+            pred_kpts = torch.cat([pred_kpts, pred_sigma], dim=-1)  # (b, h*w, 17, 5)
+
+        pred_kpts = self.kpts_decode(anchor_points, pred_kpts)
+
+        # Bbox loss
+        if fg_mask.sum():
+            keypoints = batch["keypoints"].to(self.device).float().clone()
+            keypoints[..., 0] *= imgsz[1]
+            keypoints[..., 1] *= imgsz[0]
+
+            keypoints_loss = self.calculate_keypoints_loss(
+                fg_mask,
+                target_gt_idx,
+                keypoints,
+                batch["batch_idx"].view(-1, 1),
+                stride_tensor,
+                target_bboxes,
+                pred_kpts,
+            )
+            loss[1] = keypoints_loss[0]
+            loss[2] = keypoints_loss[1]
+            if self.rle_loss is not None:
+                loss[5] = keypoints_loss[2]
+
+        loss[1] *= self.hyp.pose  # pose gain
+        loss[2] *= self.hyp.kobj  # kobj gain
+        if self.rle_loss is not None:
+            loss[5] *= self.hyp.rle  # rle gain
+
+        return loss * batch_size, loss.detach()  # loss(box, cls, dfl, kpt_location, kpt_visibility)
+
+    @staticmethod
+    def kpts_decode(anchor_points: torch.Tensor, pred_kpts: torch.Tensor) -> torch.Tensor:
+        """Decode predicted keypoints to image coordinates."""
+        y = pred_kpts.clone()
+        y[..., 0] += anchor_points[:, [0]]
+        y[..., 1] += anchor_points[:, [1]]
+        return y
+
+    def calculate_rle_loss(self, pred_kpt: torch.Tensor, gt_kpt: torch.Tensor, kpt_mask: torch.Tensor) -> torch.Tensor:
+        """Calculate the RLE (Residual Log-likelihood Estimation) loss for keypoints.
+
+        Args:
+            pred_kpt (torch.Tensor): Predicted keypoints with sigma, shape (N, kpts_dim) where kpts_dim >= 4.
+            gt_kpt (torch.Tensor): Ground truth keypoints, shape (N, kpts_dim).
+            kpt_mask (torch.Tensor): Mask for valid keypoints, shape (N, num_keypoints).
+
+        Returns:
+            (torch.Tensor): The RLE loss.
+        """
+        pred_kpt_visible = pred_kpt[kpt_mask]
+        gt_kpt_visible = gt_kpt[kpt_mask]
+        pred_coords = pred_kpt_visible[:, 0:2]
+        pred_sigma = pred_kpt_visible[:, -2:]
+        gt_coords = gt_kpt_visible[:, 0:2]
+
+        target_weights = self.target_weights.unsqueeze(0).repeat(kpt_mask.shape[0], 1)
+        target_weights = target_weights[kpt_mask]
+
+        pred_sigma = pred_sigma.sigmoid()
+        error = (pred_coords - gt_coords) / (pred_sigma + 1e-9)
+
+        # Filter out NaN and Inf values to prevent MultivariateNormal validation errors
+        valid_mask = ~(torch.isnan(error) | torch.isinf(error)).any(dim=-1)
+        if not valid_mask.any():
+            return torch.tensor(0.0, device=pred_kpt.device)
+
+        error = error[valid_mask]
+        error = error.clamp(-100, 100)  # Prevent numerical instability
+        pred_sigma = pred_sigma[valid_mask]
+        target_weights = target_weights[valid_mask]
+
+        log_phi = self.flow_model.log_prob(error)
+
+        return self.rle_loss(pred_sigma, log_phi, error, target_weights)
+
+    def calculate_keypoints_loss(
+        self,
+        masks: torch.Tensor,
+        target_gt_idx: torch.Tensor,
+        keypoints: torch.Tensor,
+        batch_idx: torch.Tensor,
+        stride_tensor: torch.Tensor,
+        target_bboxes: torch.Tensor,
+        pred_kpts: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Calculate the keypoints loss for the model.
+
+        This function calculates the keypoints loss and keypoints object loss for a given batch. The keypoints loss is
+        based on the difference between the predicted keypoints and ground truth keypoints. The keypoints object loss is
+        a binary classification loss that classifies whether a keypoint is present or not.
+
+        Args:
+            masks (torch.Tensor): Binary mask tensor indicating object presence, shape (BS, N_anchors).
+            target_gt_idx (torch.Tensor): Index tensor mapping anchors to ground truth objects, shape (BS, N_anchors).
+            keypoints (torch.Tensor): Ground truth keypoints, shape (N_kpts_in_batch, N_kpts_per_object, kpts_dim).
+            batch_idx (torch.Tensor): Batch index tensor for keypoints, shape (N_kpts_in_batch, 1).
+            stride_tensor (torch.Tensor): Stride tensor for anchors, shape (N_anchors, 1).
+            target_bboxes (torch.Tensor): Ground truth boxes in (x1, y1, x2, y2) format, shape (BS, N_anchors, 4).
+            pred_kpts (torch.Tensor): Predicted keypoints, shape (BS, N_anchors, N_kpts_per_object, kpts_dim).
+
+        Returns:
+            kpts_loss (torch.Tensor): The keypoints loss.
+            kpts_obj_loss (torch.Tensor): The keypoints object loss.
+            rle_loss (torch.Tensor): The RLE loss.
+        """
+        # Select target keypoints using inherited helper method
+        selected_keypoints = self._select_target_keypoints(keypoints, batch_idx, target_gt_idx, masks)
+
+        # Divide coordinates by stride
+        selected_keypoints[..., :2] /= stride_tensor.view(1, -1, 1, 1)
+
+        kpts_loss = 0
+        kpts_obj_loss = 0
+        rle_loss = 0
+
+        if masks.any():
+            target_bboxes /= stride_tensor
+            gt_kpt = selected_keypoints[masks]
+            area = xyxy2xywh(target_bboxes[masks])[:, 2:].prod(1, keepdim=True)
+            pred_kpt = pred_kpts[masks]
+            kpt_mask = gt_kpt[..., 2] != 0 if gt_kpt.shape[-1] == 3 else torch.full_like(gt_kpt[..., 0], True)
+            kpts_loss = self.keypoint_loss(pred_kpt, gt_kpt, kpt_mask, area)  # pose loss
+
+            if self.rle_loss is not None and (pred_kpt.shape[-1] == 4 or pred_kpt.shape[-1] == 5):
+                rle_loss = self.calculate_rle_loss(pred_kpt, gt_kpt, kpt_mask)
+            if pred_kpt.shape[-1] == 3 or pred_kpt.shape[-1] == 5:
+                kpts_obj_loss = self.bce_pose(pred_kpt[..., 2], kpt_mask.float())  # keypoint obj loss
+
+        return kpts_loss, kpts_obj_loss, rle_loss
+
+
 class v8ClassificationLoss:
     """Criterion class for computing training losses for classification."""
 
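A minimal numeric sketch of the residual normalization that PoseLoss26.calculate_rle_loss applies before querying the flow model; the keypoint values below are invented for illustration, and the layout (x, y, visibility, sigma_x, sigma_y) follows the (b, h*w, 17, 5) concatenation shown in the diff.

```python
import torch

# Core of PoseLoss26.calculate_rle_loss, reproduced standalone (illustrative values).
pred = torch.tensor([[10.0, 12.0, 0.9, 0.0, 0.0]])  # x, y, visibility, sigma_x, sigma_y (pre-sigmoid)
gt = torch.tensor([[11.0, 12.5]])

sigma = pred[:, -2:].sigmoid()              # per-coordinate scale in (0, 1)
error = (pred[:, :2] - gt) / (sigma + 1e-9)  # normalized residual fed to flow_model.log_prob
print(sigma, error)
```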
@@ -657,10 +950,17 @@ class v8ClassificationLoss:
 class v8OBBLoss(v8DetectionLoss):
     """Calculates losses for object detection, classification, and box distribution in rotated YOLO models."""
 
-    def __init__(self, model):
+    def __init__(self, model, tal_topk=10, tal_topk2: int | None = None):
         """Initialize v8OBBLoss with model, assigner, and rotated bbox loss; model must be de-paralleled."""
-        super().__init__(model)
-        self.assigner = RotatedTaskAlignedAssigner(
+        super().__init__(model, tal_topk=tal_topk)
+        self.assigner = RotatedTaskAlignedAssigner(
+            topk=tal_topk,
+            num_classes=self.nc,
+            alpha=0.5,
+            beta=6.0,
+            stride=self.stride.tolist(),
+            topk2=tal_topk2,
+        )
         self.bbox_loss = RotatedBboxLoss(self.reg_max).to(self.device)
 
     def preprocess(self, targets: torch.Tensor, batch_size: int, scale_tensor: torch.Tensor) -> torch.Tensor:
@@ -680,23 +980,19 @@ class v8OBBLoss(v8DetectionLoss):
                 out[j, :n] = torch.cat([targets[matches, 1:2], bboxes], dim=-1)
         return out
 
-    def
+    def loss(self, preds: dict[str, torch.Tensor], batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
         """Calculate and return the loss for oriented bounding box detection."""
-        loss = torch.zeros(
-
-
-
-        (
+        loss = torch.zeros(4, device=self.device)  # box, cls, dfl, angle
+        pred_distri, pred_scores, pred_angle = (
+            preds["boxes"].permute(0, 2, 1).contiguous(),
+            preds["scores"].permute(0, 2, 1).contiguous(),
+            preds["angle"].permute(0, 2, 1).contiguous(),
         )
-
-        #
-        pred_scores = pred_scores.permute(0, 2, 1).contiguous()
-        pred_distri = pred_distri.permute(0, 2, 1).contiguous()
-        pred_angle = pred_angle.permute(0, 2, 1).contiguous()
+        anchor_points, stride_tensor = make_anchors(preds["feats"], self.stride, 0.5)
+        batch_size = pred_angle.shape[0]  # batch size, number of masks, mask height, mask width
 
         dtype = pred_scores.dtype
-        imgsz = torch.tensor(
-        anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
+        imgsz = torch.tensor(batch["resized_shape"][0], device=self.device, dtype=dtype)  # image size (h,w)
 
         # targets
         try:
@@ -704,14 +1000,14 @@ class v8OBBLoss(v8DetectionLoss):
             targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"].view(-1, 5)), 1)
             rw, rh = targets[:, 4] * float(imgsz[1]), targets[:, 5] * float(imgsz[0])
             targets = targets[(rw >= 2) & (rh >= 2)]  # filter rboxes of tiny size to stabilize training
-            targets = self.preprocess(targets, batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
+            targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
             gt_labels, gt_bboxes = targets.split((1, 5), 2)  # cls, xywhr
             mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
         except RuntimeError as e:
             raise TypeError(
                 "ERROR ❌ OBB dataset incorrectly formatted or not a OBB dataset.\n"
                 "This error can occur when incorrectly training a 'OBB' model on a 'detect' dataset, "
-                "i.e. 'yolo train model=
+                "i.e. 'yolo train model=yolo26n-obb.pt data=dota8.yaml'.\nVerify your dataset is a "
                 "correctly formatted 'OBB' dataset using 'data=dota8.yaml' "
                 "as an example.\nSee https://docs.ultralytics.com/datasets/obb/ for help."
             ) from e
@@ -741,16 +1037,29 @@ class v8OBBLoss(v8DetectionLoss):
         if fg_mask.sum():
             target_bboxes[..., :4] /= stride_tensor
             loss[0], loss[2] = self.bbox_loss(
-                pred_distri,
+                pred_distri,
+                pred_bboxes,
+                anchor_points,
+                target_bboxes,
+                target_scores,
+                target_scores_sum,
+                fg_mask,
+                imgsz,
+                stride_tensor,
             )
+            weight = target_scores.sum(-1)[fg_mask]
+            loss[3] = self.calculate_angle_loss(
+                pred_bboxes, target_bboxes, fg_mask, weight, target_scores_sum
+            )  # angle loss
         else:
             loss[0] += (pred_angle * 0).sum()
 
         loss[0] *= self.hyp.box  # box gain
         loss[1] *= self.hyp.cls  # cls gain
         loss[2] *= self.hyp.dfl  # dfl gain
+        loss[3] *= self.hyp.angle  # angle gain
 
-        return loss * batch_size, loss.detach()  # loss(box, cls, dfl)
+        return loss * batch_size, loss.detach()  # loss(box, cls, dfl, angle)
 
     def bbox_decode(
         self, anchor_points: torch.Tensor, pred_dist: torch.Tensor, pred_angle: torch.Tensor
@@ -770,6 +1079,34 @@ class v8OBBLoss(v8DetectionLoss):
         pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))
         return torch.cat((dist2rbox(pred_dist, pred_angle, anchor_points), pred_angle), dim=-1)
 
+    def calculate_angle_loss(self, pred_bboxes, target_bboxes, fg_mask, weight, target_scores_sum, lambda_val=3):
+        """Calculate oriented angle loss.
+
+        Args:
+            pred_bboxes: [N, 5] (x, y, w, h, theta).
+            target_bboxes: [N, 5] (x, y, w, h, theta).
+            fg_mask: Foreground mask indicating valid predictions.
+            weight: Loss weights for each prediction.
+            target_scores_sum: Sum of target scores for normalization.
+            lambda_val: control the sensitivity to aspect ratio.
+        """
+        w_gt = target_bboxes[..., 2]
+        h_gt = target_bboxes[..., 3]
+        pred_theta = pred_bboxes[..., 4]
+        target_theta = target_bboxes[..., 4]
+
+        log_ar = torch.log(w_gt / h_gt)
+        scale_weight = torch.exp(-(log_ar**2) / (lambda_val**2))
+
+        delta_theta = pred_theta - target_theta
+        delta_theta_wrapped = delta_theta - torch.round(delta_theta / math.pi) * math.pi
+        ang_loss = torch.sin(2 * delta_theta_wrapped[fg_mask]) ** 2
+
+        ang_loss = scale_weight[fg_mask] * ang_loss
+        ang_loss = ang_loss * weight
+
+        return ang_loss.sum() / target_scores_sum
+
 
 class E2EDetectLoss:
     """Criterion class for computing training losses for end-to-end detection."""
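A standalone worked computation of the terms inside the new calculate_angle_loss, for a single hypothetical box; it only reproduces the formula from the diff (aspect-ratio weight, angle wrapping to a half-period, sin-squared penalty) with made-up numbers.

```python
import math
import torch

# Reproducing the angle-loss terms from calculate_angle_loss for one box (illustration only).
w_gt, h_gt = torch.tensor(4.0), torch.tensor(1.0)  # elongated box -> nonzero log aspect ratio
delta = torch.tensor(math.pi / 2 + 0.1)            # raw angle error in radians
lambda_val = 3

scale_weight = torch.exp(-(torch.log(w_gt / h_gt) ** 2) / lambda_val**2)
delta_wrapped = delta - torch.round(delta / math.pi) * math.pi  # wrapped into [-pi/2, pi/2]
ang_loss = torch.sin(2 * delta_wrapped) ** 2
print(float(scale_weight), float(delta_wrapped), float(ang_loss))
```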
@@ -789,61 +1126,108 @@ class E2EDetectLoss:
|
|
|
789
1126
|
return loss_one2many[0] + loss_one2one[0], loss_one2many[1] + loss_one2one[1]
|
|
790
1127
|
|
|
791
1128
|
|
|
1129
|
+
class E2ELoss:
|
|
1130
|
+
"""Criterion class for computing training losses for end-to-end detection."""
|
|
1131
|
+
|
|
1132
|
+
def __init__(self, model, loss_fn=v8DetectionLoss):
|
|
1133
|
+
"""Initialize E2ELoss with one-to-many and one-to-one detection losses using the provided model."""
|
|
1134
|
+
self.one2many = loss_fn(model, tal_topk=10)
|
|
1135
|
+
self.one2one = loss_fn(model, tal_topk=7, tal_topk2=1)
|
|
1136
|
+
self.updates = 0
|
|
1137
|
+
self.total = 1.0
|
|
1138
|
+
# init gain
|
|
1139
|
+
self.o2m = 0.8
|
|
1140
|
+
self.o2o = self.total - self.o2m
|
|
1141
|
+
self.o2m_copy = self.o2m
|
|
1142
|
+
# final gain
|
|
1143
|
+
self.final_o2m = 0.1
|
|
1144
|
+
|
|
1145
|
+
def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
|
|
1146
|
+
"""Calculate the sum of the loss for box, cls and dfl multiplied by batch size."""
|
|
1147
|
+
preds = self.one2many.parse_output(preds)
|
|
1148
|
+
one2many, one2one = preds["one2many"], preds["one2one"]
|
|
1149
|
+
loss_one2many = self.one2many.loss(one2many, batch)
|
|
1150
|
+
loss_one2one = self.one2one.loss(one2one, batch)
|
|
1151
|
+
return loss_one2many[0] * self.o2m + loss_one2one[0] * self.o2o, loss_one2one[1]
|
|
1152
|
+
|
|
1153
|
+
def update(self) -> None:
|
|
1154
|
+
"""Update the weights for one-to-many and one-to-one losses based on the decay schedule."""
|
|
1155
|
+
self.updates += 1
|
|
1156
|
+
self.o2m = self.decay(self.updates)
|
|
1157
|
+
self.o2o = max(self.total - self.o2m, 0)
|
|
1158
|
+
|
|
1159
|
+
def decay(self, x) -> float:
|
|
1160
|
+
"""Calculate the decayed weight for one-to-many loss based on the current update step."""
|
|
1161
|
+
return max(1 - x / max(self.one2one.hyp.epochs - 1, 1), 0) * (self.o2m_copy - self.final_o2m) + self.final_o2m
|
|
1162
|
+
|
|
1163
|
+
|
|
792
1164
|
class TVPDetectLoss:
|
|
793
1165
|
"""Criterion class for computing training losses for text-visual prompt detection."""
|
|
794
1166
|
|
|
795
|
-
def __init__(self, model):
|
|
1167
|
+
def __init__(self, model, tal_topk=10):
|
|
796
1168
|
"""Initialize TVPDetectLoss with task-prompt and visual-prompt criteria using the provided model."""
|
|
797
|
-
self.vp_criterion = v8DetectionLoss(model)
|
|
1169
|
+
self.vp_criterion = v8DetectionLoss(model, tal_topk)
|
|
798
1170
|
# NOTE: store following info as it's changeable in __call__
|
|
1171
|
+
self.hyp = self.vp_criterion.hyp
|
|
799
1172
|
self.ori_nc = self.vp_criterion.nc
|
|
800
1173
|
self.ori_no = self.vp_criterion.no
|
|
801
1174
|
self.ori_reg_max = self.vp_criterion.reg_max
|
|
802
1175
|
|
|
1176
|
+
def parse_output(self, preds) -> dict[str, torch.Tensor]:
|
|
1177
|
+
"""Parse model predictions to extract features."""
|
|
1178
|
+
return self.vp_criterion.parse_output(preds)
|
|
1179
|
+
|
|
803
1180
|
def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
|
|
804
1181
|
"""Calculate the loss for text-visual prompt detection."""
|
|
805
|
-
|
|
1182
|
+
return self.loss(self.parse_output(preds), batch)
|
|
1183
|
+
|
|
1184
|
+
def loss(self, preds: dict[str, torch.Tensor], batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
|
|
1185
|
+
"""Calculate the loss for text-visual prompt detection."""
|
|
1186
|
+
assert self.ori_reg_max == self.vp_criterion.reg_max # TODO: remove it
|
|
806
1187
|
|
|
807
|
-
if self.
|
|
1188
|
+
if self.ori_nc == preds["scores"].shape[1]:
|
|
808
1189
|
loss = torch.zeros(3, device=self.vp_criterion.device, requires_grad=True)
|
|
809
1190
|
return loss, loss.detach()
|
|
810
1191
|
|
|
811
|
-
|
|
812
|
-
vp_loss = self.vp_criterion(
|
|
813
|
-
|
|
814
|
-
return
|
|
1192
|
+
preds["scores"] = self._get_vp_features(preds)
|
|
1193
|
+
vp_loss = self.vp_criterion(preds, batch)
|
|
1194
|
+
box_loss = vp_loss[0][1]
|
|
1195
|
+
return box_loss, vp_loss[1]
|
|
815
1196
|
|
|
816
|
-
def _get_vp_features(self,
|
|
1197
|
+
def _get_vp_features(self, preds: dict[str, torch.Tensor]) -> list[torch.Tensor]:
|
|
817
1198
|
"""Extract visual-prompt features from the model output."""
|
|
818
|
-
|
|
1199
|
+
# NOTE: remove empty placeholder
|
|
1200
|
+
scores = preds["scores"][:, self.ori_nc :, :]
|
|
1201
|
+
vnc = scores.shape[1]
|
|
819
1202
|
|
|
820
1203
|
self.vp_criterion.nc = vnc
|
|
821
1204
|
self.vp_criterion.no = vnc + self.vp_criterion.reg_max * 4
|
|
822
1205
|
self.vp_criterion.assigner.num_classes = vnc
|
|
823
|
-
|
|
824
|
-
return [
|
|
825
|
-
torch.cat((box, cls_vp), dim=1)
|
|
826
|
-
for box, _, cls_vp in [xi.split((self.ori_reg_max * 4, self.ori_nc, vnc), dim=1) for xi in feats]
|
|
827
|
-
]
|
|
1206
|
+
return scores
|
|
828
1207
|
|
|
829
1208
|
|
|
830
1209
|
class TVPSegmentLoss(TVPDetectLoss):
|
|
831
1210
|
"""Criterion class for computing training losses for text-visual prompt segmentation."""
|
|
832
1211
|
|
|
833
|
-
def __init__(self, model):
|
|
1212
|
+
def __init__(self, model, tal_topk=10):
|
|
834
1213
|
"""Initialize TVPSegmentLoss with task-prompt and visual-prompt criteria using the provided model."""
|
|
835
1214
|
super().__init__(model)
|
|
836
|
-
self.vp_criterion = v8SegmentationLoss(model)
|
|
1215
|
+
self.vp_criterion = v8SegmentationLoss(model, tal_topk)
|
|
1216
|
+
self.hyp = self.vp_criterion.hyp
|
|
837
1217
|
|
|
838
1218
|
def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
|
|
839
1219
|
"""Calculate the loss for text-visual prompt segmentation."""
|
|
840
|
-
|
|
1220
|
+
return self.loss(self.parse_output(preds), batch)
|
|
1221
|
+
|
|
1222
|
+
def loss(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
|
|
1223
|
+
"""Calculate the loss for text-visual prompt detection."""
|
|
1224
|
+
assert self.ori_reg_max == self.vp_criterion.reg_max # TODO: remove it
|
|
841
1225
|
|
|
842
|
-
if self.
|
|
1226
|
+
if self.ori_nc == preds["scores"].shape[1]:
|
|
843
1227
|
loss = torch.zeros(4, device=self.vp_criterion.device, requires_grad=True)
|
|
844
1228
|
return loss, loss.detach()
|
|
845
1229
|
|
|
846
|
-
|
|
847
|
-
vp_loss = self.vp_criterion(
|
|
1230
|
+
preds["scores"] = self._get_vp_features(preds)
|
|
1231
|
+
vp_loss = self.vp_criterion(preds, batch)
|
|
848
1232
|
cls_loss = vp_loss[0][2]
|
|
849
1233
|
return cls_loss, vp_loss[1]
|