ultralytics-opencv-headless 8.3.253__py3-none-any.whl → 8.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/__init__.py +2 -2
- tests/conftest.py +1 -1
- tests/test_cuda.py +8 -2
- tests/test_engine.py +6 -6
- tests/test_exports.py +10 -3
- tests/test_integrations.py +9 -9
- tests/test_python.py +14 -14
- tests/test_solutions.py +3 -3
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +6 -6
- ultralytics/cfg/default.yaml +3 -1
- ultralytics/cfg/models/26/yolo26-cls.yaml +33 -0
- ultralytics/cfg/models/26/yolo26-obb.yaml +52 -0
- ultralytics/cfg/models/26/yolo26-p2.yaml +60 -0
- ultralytics/cfg/models/26/yolo26-p6.yaml +60 -0
- ultralytics/cfg/models/26/yolo26-pose.yaml +53 -0
- ultralytics/cfg/models/26/yolo26-seg.yaml +52 -0
- ultralytics/cfg/models/26/yolo26.yaml +52 -0
- ultralytics/cfg/models/26/yoloe-26-seg.yaml +53 -0
- ultralytics/cfg/models/26/yoloe-26.yaml +53 -0
- ultralytics/data/augment.py +7 -0
- ultralytics/data/dataset.py +1 -1
- ultralytics/engine/exporter.py +10 -3
- ultralytics/engine/model.py +1 -1
- ultralytics/engine/trainer.py +40 -15
- ultralytics/engine/tuner.py +15 -7
- ultralytics/models/fastsam/predict.py +1 -1
- ultralytics/models/yolo/detect/train.py +3 -2
- ultralytics/models/yolo/detect/val.py +6 -0
- ultralytics/models/yolo/model.py +1 -1
- ultralytics/models/yolo/obb/predict.py +1 -1
- ultralytics/models/yolo/obb/train.py +1 -1
- ultralytics/models/yolo/pose/train.py +1 -1
- ultralytics/models/yolo/segment/predict.py +1 -1
- ultralytics/models/yolo/segment/train.py +1 -1
- ultralytics/models/yolo/segment/val.py +3 -1
- ultralytics/models/yolo/yoloe/train.py +6 -1
- ultralytics/models/yolo/yoloe/train_seg.py +6 -1
- ultralytics/nn/autobackend.py +7 -3
- ultralytics/nn/modules/__init__.py +8 -0
- ultralytics/nn/modules/block.py +127 -8
- ultralytics/nn/modules/head.py +818 -205
- ultralytics/nn/tasks.py +74 -29
- ultralytics/nn/text_model.py +5 -2
- ultralytics/optim/__init__.py +5 -0
- ultralytics/optim/muon.py +338 -0
- ultralytics/utils/benchmarks.py +1 -0
- ultralytics/utils/callbacks/platform.py +9 -7
- ultralytics/utils/downloads.py +3 -1
- ultralytics/utils/export/engine.py +19 -10
- ultralytics/utils/export/imx.py +22 -11
- ultralytics/utils/export/tensorflow.py +1 -41
- ultralytics/utils/loss.py +584 -203
- ultralytics/utils/metrics.py +1 -0
- ultralytics/utils/ops.py +11 -2
- ultralytics/utils/tal.py +98 -19
- {ultralytics_opencv_headless-8.3.253.dist-info → ultralytics_opencv_headless-8.4.0.dist-info}/METADATA +31 -39
- {ultralytics_opencv_headless-8.3.253.dist-info → ultralytics_opencv_headless-8.4.0.dist-info}/RECORD +62 -51
- {ultralytics_opencv_headless-8.3.253.dist-info → ultralytics_opencv_headless-8.4.0.dist-info}/WHEEL +0 -0
- {ultralytics_opencv_headless-8.3.253.dist-info → ultralytics_opencv_headless-8.4.0.dist-info}/entry_points.txt +0 -0
- {ultralytics_opencv_headless-8.3.253.dist-info → ultralytics_opencv_headless-8.4.0.dist-info}/licenses/LICENSE +0 -0
- {ultralytics_opencv_headless-8.3.253.dist-info → ultralytics_opencv_headless-8.4.0.dist-info}/top_level.txt +0 -0
ultralytics/utils/loss.py
CHANGED
|
@@ -2,19 +2,20 @@
|
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
|
+
import math
|
|
5
6
|
from typing import Any
|
|
6
7
|
|
|
7
8
|
import torch
|
|
8
9
|
import torch.nn as nn
|
|
9
10
|
import torch.nn.functional as F
|
|
10
11
|
|
|
11
|
-
from ultralytics.utils.metrics import OKS_SIGMA
|
|
12
|
+
from ultralytics.utils.metrics import OKS_SIGMA, RLE_WEIGHT
|
|
12
13
|
from ultralytics.utils.ops import crop_mask, xywh2xyxy, xyxy2xywh
|
|
13
14
|
from ultralytics.utils.tal import RotatedTaskAlignedAssigner, TaskAlignedAssigner, dist2bbox, dist2rbox, make_anchors
|
|
14
15
|
from ultralytics.utils.torch_utils import autocast
|
|
15
16
|
|
|
16
17
|
from .metrics import bbox_iou, probiou
|
|
17
|
-
from .tal import bbox2dist
|
|
18
|
+
from .tal import bbox2dist, rbox2dist
|
|
18
19
|
|
|
19
20
|
|
|
20
21
|
class VarifocalLoss(nn.Module):
|
|
@@ -122,6 +123,8 @@ class BboxLoss(nn.Module):
|
|
|
122
123
|
target_scores: torch.Tensor,
|
|
123
124
|
target_scores_sum: torch.Tensor,
|
|
124
125
|
fg_mask: torch.Tensor,
|
|
126
|
+
imgsz: torch.Tensor,
|
|
127
|
+
stride: torch.Tensor,
|
|
125
128
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
126
129
|
"""Compute IoU and DFL losses for bounding boxes."""
|
|
127
130
|
weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
|
|
@@ -134,11 +137,75 @@ class BboxLoss(nn.Module):
|
|
|
134
137
|
loss_dfl = self.dfl_loss(pred_dist[fg_mask].view(-1, self.dfl_loss.reg_max), target_ltrb[fg_mask]) * weight
|
|
135
138
|
loss_dfl = loss_dfl.sum() / target_scores_sum
|
|
136
139
|
else:
|
|
137
|
-
|
|
140
|
+
target_ltrb = bbox2dist(anchor_points, target_bboxes)
|
|
141
|
+
# normalize ltrb by image size
|
|
142
|
+
target_ltrb = target_ltrb * stride
|
|
143
|
+
target_ltrb[..., 0::2] /= imgsz[1]
|
|
144
|
+
target_ltrb[..., 1::2] /= imgsz[0]
|
|
145
|
+
pred_dist = pred_dist * stride
|
|
146
|
+
pred_dist[..., 0::2] /= imgsz[1]
|
|
147
|
+
pred_dist[..., 1::2] /= imgsz[0]
|
|
148
|
+
loss_dfl = (
|
|
149
|
+
F.l1_loss(pred_dist[fg_mask], target_ltrb[fg_mask], reduction="none").mean(-1, keepdim=True) * weight
|
|
150
|
+
)
|
|
151
|
+
loss_dfl = loss_dfl.sum() / target_scores_sum
|
|
138
152
|
|
|
139
153
|
return loss_iou, loss_dfl
|
|
140
154
|
|
|
141
155
|
|
|
156
|
+
class RLELoss(nn.Module):
    """Residual Log-Likelihood Estimation (RLE) loss.

    The element-wise loss is ``log(sigma) - log_phi``. When ``residual`` is enabled, the term
    ``log(2 * sigma) + |error|`` is added to it. The result is optionally weighted, optionally divided by the size
    of the first dimension, and finally summed to a scalar.

    Attributes:
        size_average (bool): Divide the element-wise loss by the size of its first dimension before summing.
        use_target_weight (bool): Multiply the element-wise loss by ``target_weight``.
        residual (bool): Add the ``log(2 * sigma) + |error|`` term.

    References:
        https://arxiv.org/abs/2107.11291
    """

    def __init__(self, use_target_weight: bool = True, size_average: bool = True, residual: bool = True):
        """Initialize RLELoss with target weight and residual options.

        Args:
            use_target_weight (bool): Whether to use target weights for loss calculation.
            size_average (bool): Whether to average the loss over the first dimension.
            residual (bool): Whether to include the residual log-likelihood term.
        """
        super().__init__()
        self.size_average = size_average
        self.use_target_weight = use_target_weight
        self.residual = residual

    def forward(
        self,
        sigma: torch.Tensor,
        log_phi: torch.Tensor,
        error: torch.Tensor,
        target_weight: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute the RLE loss.

        Args:
            sigma (torch.Tensor): Output sigma, shape (N, D).
            log_phi (torch.Tensor): Output log_phi, shape (N).
            error (torch.Tensor): Error, shape (N, D).
            target_weight (torch.Tensor | None): Weights across different joint types, shape (N) or (N, 1).
                Required when ``use_target_weight`` is True.

        Returns:
            (torch.Tensor): Scalar loss, the sum over all elements.
        """
        loss = torch.log(sigma) - log_phi.unsqueeze(1)

        if self.residual:
            loss = loss + (torch.log(sigma * 2) + torch.abs(error))

        if self.use_target_weight:
            assert target_weight is not None, "'target_weight' should not be None when 'use_target_weight' is True."
            if target_weight.dim() == 1:
                target_weight = target_weight.unsqueeze(1)  # (N,) -> (N, 1) so it broadcasts across D
            # Out-of-place on purpose: an in-place multiply fails if the weight broadcasts to a larger shape.
            loss = loss * target_weight

        if self.size_average:
            loss = loss / len(loss)  # len() is the size of the first dimension (N)

        return loss.sum()
|
|
207
|
+
|
|
208
|
+
|
|
142
209
|
class RotatedBboxLoss(BboxLoss):
|
|
143
210
|
"""Criterion class for computing training losses for rotated bounding boxes."""
|
|
144
211
|
|
|
@@ -155,6 +222,8 @@ class RotatedBboxLoss(BboxLoss):
|
|
|
155
222
|
target_scores: torch.Tensor,
|
|
156
223
|
target_scores_sum: torch.Tensor,
|
|
157
224
|
fg_mask: torch.Tensor,
|
|
225
|
+
imgsz: torch.Tensor,
|
|
226
|
+
stride: torch.Tensor,
|
|
158
227
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
159
228
|
"""Compute IoU and DFL losses for rotated bounding boxes."""
|
|
160
229
|
weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
|
|
@@ -163,15 +232,84 @@ class RotatedBboxLoss(BboxLoss):
|
|
|
163
232
|
|
|
164
233
|
# DFL loss
|
|
165
234
|
if self.dfl_loss:
|
|
166
|
-
target_ltrb =
|
|
235
|
+
target_ltrb = rbox2dist(
|
|
236
|
+
target_bboxes[..., :4], anchor_points, target_bboxes[..., 4:5], reg_max=self.dfl_loss.reg_max - 1
|
|
237
|
+
)
|
|
167
238
|
loss_dfl = self.dfl_loss(pred_dist[fg_mask].view(-1, self.dfl_loss.reg_max), target_ltrb[fg_mask]) * weight
|
|
168
239
|
loss_dfl = loss_dfl.sum() / target_scores_sum
|
|
169
240
|
else:
|
|
170
|
-
|
|
241
|
+
target_ltrb = rbox2dist(target_bboxes[..., :4], anchor_points, target_bboxes[..., 4:5])
|
|
242
|
+
target_ltrb = target_ltrb * stride
|
|
243
|
+
target_ltrb[..., 0::2] /= imgsz[1]
|
|
244
|
+
target_ltrb[..., 1::2] /= imgsz[0]
|
|
245
|
+
pred_dist = pred_dist * stride
|
|
246
|
+
pred_dist[..., 0::2] /= imgsz[1]
|
|
247
|
+
pred_dist[..., 1::2] /= imgsz[0]
|
|
248
|
+
loss_dfl = (
|
|
249
|
+
F.l1_loss(pred_dist[fg_mask], target_ltrb[fg_mask], reduction="none").mean(-1, keepdim=True) * weight
|
|
250
|
+
)
|
|
251
|
+
loss_dfl = loss_dfl.sum() / target_scores_sum
|
|
171
252
|
|
|
172
253
|
return loss_iou, loss_dfl
|
|
173
254
|
|
|
174
255
|
|
|
256
|
+
class MultiChannelDiceLoss(nn.Module):
    """Criterion class for computing multi-channel Dice losses.

    Predictions are passed through a sigmoid, a Dice score is computed per sample and channel over all spatial
    dimensions, and ``1 - dice`` is averaged over the channel dimension.

    Attributes:
        smooth (float): Value added to numerator and denominator of the Dice ratio.
        reduction (str): 'mean' or 'sum' over the batch; any other value returns the per-sample loss.
    """

    def __init__(self, smooth: float = 1e-6, reduction: str = "mean"):
        """Initialize MultiChannelDiceLoss with smoothing and reduction options.

        Args:
            smooth (float): Smoothing factor to avoid division by zero.
            reduction (str): Reduction method ('mean', 'sum', or 'none').
        """
        super().__init__()
        self.smooth = smooth
        self.reduction = reduction

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Calculate multi-channel Dice loss between predictions and targets.

        Args:
            pred (torch.Tensor): Raw logits of shape (N, C, *spatial); sigmoid is applied here.
            target (torch.Tensor): Targets with the same shape as ``pred``.

        Returns:
            (torch.Tensor): Scalar for 'mean'/'sum' reduction, otherwise a tensor of shape (N,).

        Raises:
            IndexError: If ``pred`` has no spatial dimension (fewer than 3 dims).
        """
        assert pred.size() == target.size(), "the size of predict and target must be equal."
        if pred.dim() < 3:
            raise IndexError(f"expected input of shape (N, C, *spatial), got {pred.dim()} dimension(s).")

        # Reduce over every spatial dim; this is (2, 3) for the usual (N, C, H, W) input.
        spatial_dims = tuple(range(2, pred.dim()))

        pred = pred.sigmoid()
        intersection = (pred * target).sum(dim=spatial_dims)
        union = pred.sum(dim=spatial_dims) + target.sum(dim=spatial_dims)
        dice = (2.0 * intersection + self.smooth) / (union + self.smooth)
        dice_loss = (1.0 - dice).mean(dim=1)  # average over channels -> (N,)

        if self.reduction == "mean":
            return dice_loss.mean()
        if self.reduction == "sum":
            return dice_loss.sum()
        return dice_loss
|
|
287
|
+
|
|
288
|
+
|
|
289
|
+
class BCEDiceLoss(nn.Module):
    """Criterion class for computing combined BCE and Dice losses.

    The result is ``weight_bce * BCEWithLogits(pred, target) + weight_dice * Dice(pred, target)``.
    """

    def __init__(self, weight_bce: float = 0.5, weight_dice: float = 0.5):
        """Initialize BCEDiceLoss with BCE and Dice weight factors.

        Args:
            weight_bce (float): Weight factor for BCE loss component.
            weight_dice (float): Weight factor for Dice loss component.
        """
        super().__init__()
        self.weight_bce, self.weight_dice = weight_bce, weight_dice
        self.bce = nn.BCEWithLogitsLoss()
        self.dice = MultiChannelDiceLoss(smooth=1)

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Calculate combined BCE and Dice loss between predictions and targets.

        Args:
            pred (torch.Tensor): 4-D logits; the last two dims are the mask height and width.
            target (torch.Tensor): Targets; resized with nearest-neighbor interpolation when their spatial size
                differs from ``pred``.

        Returns:
            (torch.Tensor): Weighted sum of the BCE and Dice terms.
        """
        _, _, mask_h, mask_w = pred.shape
        pred_hw = (mask_h, mask_w)
        if tuple(target.shape[-2:]) != pred_hw:
            # bring the target to the prediction resolution
            target = F.interpolate(target, pred_hw, mode="nearest")

        bce_term = self.weight_bce * self.bce(pred, target)
        dice_term = self.weight_dice * self.dice(pred, target)
        return bce_term + dice_term
|
|
311
|
+
|
|
312
|
+
|
|
175
313
|
class KeypointLoss(nn.Module):
|
|
176
314
|
"""Criterion class for computing keypoint losses."""
|
|
177
315
|
|
|
@@ -194,7 +332,7 @@ class KeypointLoss(nn.Module):
|
|
|
194
332
|
class v8DetectionLoss:
|
|
195
333
|
"""Criterion class for computing training losses for YOLOv8 object detection."""
|
|
196
334
|
|
|
197
|
-
def __init__(self, model, tal_topk: int = 10): # model must be de-paralleled
|
|
335
|
+
def __init__(self, model, tal_topk: int = 10, tal_topk2: int | None = None): # model must be de-paralleled
|
|
198
336
|
"""Initialize v8DetectionLoss with model parameters and task-aligned assignment settings."""
|
|
199
337
|
device = next(model.parameters()).device # get model device
|
|
200
338
|
h = model.args # hyperparameters
|
|
@@ -210,7 +348,14 @@ class v8DetectionLoss:
|
|
|
210
348
|
|
|
211
349
|
self.use_dfl = m.reg_max > 1
|
|
212
350
|
|
|
213
|
-
self.assigner = TaskAlignedAssigner(
|
|
351
|
+
self.assigner = TaskAlignedAssigner(
|
|
352
|
+
topk=tal_topk,
|
|
353
|
+
num_classes=self.nc,
|
|
354
|
+
alpha=0.5,
|
|
355
|
+
beta=6.0,
|
|
356
|
+
stride=self.stride.tolist(),
|
|
357
|
+
topk2=tal_topk2,
|
|
358
|
+
)
|
|
214
359
|
self.bbox_loss = BboxLoss(m.reg_max).to(device)
|
|
215
360
|
self.proj = torch.arange(m.reg_max, dtype=torch.float, device=device)
|
|
216
361
|
|
|
@@ -240,35 +385,31 @@ class v8DetectionLoss:
|
|
|
240
385
|
# pred_dist = (pred_dist.view(b, a, c // 4, 4).softmax(2) * self.proj.type(pred_dist.dtype).view(1, 1, -1, 1)).sum(2)
|
|
241
386
|
return dist2bbox(pred_dist, anchor_points, xywh=False)
|
|
242
387
|
|
|
243
|
-
def
|
|
244
|
-
"""Calculate the sum of the loss for box, cls and dfl multiplied by batch size
|
|
388
|
+
def get_assigned_targets_and_loss(self, preds: dict[str, torch.Tensor], batch: dict[str, Any]) -> tuple:
|
|
389
|
+
"""Calculate the sum of the loss for box, cls and dfl multiplied by batch size and return foreground mask and
|
|
390
|
+
target indices.
|
|
391
|
+
"""
|
|
245
392
|
loss = torch.zeros(3, device=self.device) # box, cls, dfl
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
(
|
|
393
|
+
pred_distri, pred_scores = (
|
|
394
|
+
preds["boxes"].permute(0, 2, 1).contiguous(),
|
|
395
|
+
preds["scores"].permute(0, 2, 1).contiguous(),
|
|
249
396
|
)
|
|
250
|
-
|
|
251
|
-
pred_scores = pred_scores.permute(0, 2, 1).contiguous()
|
|
252
|
-
pred_distri = pred_distri.permute(0, 2, 1).contiguous()
|
|
397
|
+
anchor_points, stride_tensor = make_anchors(preds["feats"], self.stride, 0.5)
|
|
253
398
|
|
|
254
399
|
dtype = pred_scores.dtype
|
|
255
400
|
batch_size = pred_scores.shape[0]
|
|
256
|
-
imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]
|
|
257
|
-
anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
|
|
401
|
+
imgsz = torch.tensor(preds["feats"][0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]
|
|
258
402
|
|
|
259
403
|
# Targets
|
|
260
404
|
targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
|
|
261
|
-
targets = self.preprocess(targets, batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
|
|
405
|
+
targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
|
|
262
406
|
gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
|
|
263
407
|
mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
|
|
264
408
|
|
|
265
409
|
# Pboxes
|
|
266
410
|
pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4)
|
|
267
|
-
# dfl_conf = pred_distri.view(batch_size, -1, 4, self.reg_max).detach().softmax(-1)
|
|
268
|
-
# dfl_conf = (dfl_conf.amax(-1).mean(-1) + dfl_conf.amax(-1).amin(-1)) / 2
|
|
269
411
|
|
|
270
|
-
_, target_bboxes, target_scores, fg_mask,
|
|
271
|
-
# pred_scores.detach().sigmoid() * 0.8 + dfl_conf.unsqueeze(-1) * 0.2,
|
|
412
|
+
_, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
|
|
272
413
|
pred_scores.detach().sigmoid(),
|
|
273
414
|
(pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
|
|
274
415
|
anchor_points * stride_tensor,
|
|
@@ -280,7 +421,6 @@ class v8DetectionLoss:
|
|
|
280
421
|
target_scores_sum = max(target_scores.sum(), 1)
|
|
281
422
|
|
|
282
423
|
# Cls loss
|
|
283
|
-
# loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum # VFL way
|
|
284
424
|
loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE
|
|
285
425
|
|
|
286
426
|
# Bbox loss
|
|
@@ -293,105 +433,97 @@ class v8DetectionLoss:
|
|
|
293
433
|
target_scores,
|
|
294
434
|
target_scores_sum,
|
|
295
435
|
fg_mask,
|
|
436
|
+
imgsz,
|
|
437
|
+
stride_tensor,
|
|
296
438
|
)
|
|
297
439
|
|
|
298
440
|
loss[0] *= self.hyp.box # box gain
|
|
299
441
|
loss[1] *= self.hyp.cls # cls gain
|
|
300
442
|
loss[2] *= self.hyp.dfl # dfl gain
|
|
443
|
+
return (
|
|
444
|
+
(fg_mask, target_gt_idx, target_bboxes, anchor_points, stride_tensor),
|
|
445
|
+
loss,
|
|
446
|
+
loss.detach(),
|
|
447
|
+
) # loss(box, cls, dfl)
|
|
301
448
|
|
|
302
|
-
|
|
449
|
+
def parse_output(
    self, preds: dict[str, torch.Tensor] | tuple[torch.Tensor, dict[str, torch.Tensor]]
) -> dict[str, torch.Tensor]:
    """Return the prediction dict used for loss computation.

    Args:
        preds (dict | tuple): Either the prediction dict itself, or a tuple whose second element is that dict.

    Returns:
        (dict[str, torch.Tensor]): ``preds[1]`` when ``preds`` is a tuple, otherwise ``preds`` unchanged.
    """
    return preds[1] if isinstance(preds, tuple) else preds
|
|
454
|
+
|
|
455
|
+
def __call__(
    self,
    preds: dict[str, torch.Tensor] | tuple[torch.Tensor, dict[str, torch.Tensor]],
    batch: dict[str, torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor]:
    """Calculate the sum of the loss for box, cls and dfl multiplied by batch size."""
    # Normalize the model output first, then hand it to the (overridable) loss implementation.
    parsed_preds = self.parse_output(preds)
    return self.loss(parsed_preds, batch)
|
|
462
|
+
|
|
463
|
+
def loss(self, preds: dict[str, torch.Tensor], batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute the detection loss via get_assigned_targets_and_loss and scale it by the batch size.

    Args:
        preds (dict[str, torch.Tensor]): Parsed predictions; ``preds["boxes"]`` supplies the batch size.
        batch (dict[str, torch.Tensor]): Batch data forwarded to get_assigned_targets_and_loss.

    Returns:
        (tuple[torch.Tensor, torch.Tensor]): The loss multiplied by the batch size, and the detached loss
            (not multiplied).
    """
    batch_size = preds["boxes"].shape[0]
    # The first returned element (assignment info) is not needed here.
    _, loss, loss_detach = self.get_assigned_targets_and_loss(preds, batch)
    return loss * batch_size, loss_detach
|
|
303
468
|
|
|
304
469
|
|
|
305
470
|
class v8SegmentationLoss(v8DetectionLoss):
|
|
306
471
|
"""Criterion class for computing training losses for YOLOv8 segmentation."""
|
|
307
472
|
|
|
308
|
-
def __init__(self, model): # model must be de-paralleled
|
|
473
|
+
def __init__(self, model, tal_topk: int = 10, tal_topk2: int | None = None):  # model must be de-paralleled
    """Initialize the v8SegmentationLoss class with model parameters and mask overlap setting.

    Args:
        model: Model passed on to the detection loss; ``model.args.overlap_mask`` is read here.
        tal_topk (int): Forwarded to the detection loss constructor.
        tal_topk2 (int | None): Forwarded to the detection loss constructor.
    """
    super().__init__(model, tal_topk=tal_topk, tal_topk2=tal_topk2)
    self.bcedice_loss = BCEDiceLoss(weight_bce=0.5, weight_dice=0.5)
    self.overlap = model.args.overlap_mask
|
|
312
478
|
|
|
313
|
-
def
|
|
479
|
+
def loss(self, preds: dict[str, torch.Tensor], batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
    """Calculate and return the combined loss for detection and segmentation.

    Args:
        preds (dict[str, torch.Tensor]): Parsed predictions. Reads ``"mask_coefficient"``, ``"proto"`` and
            ``"feats"`` here. ``"proto"`` is either the prototype tensor or a ``(proto, pred_semseg)`` pair.
        batch (dict[str, torch.Tensor]): Batch data. Reads ``"masks"``, ``"batch_idx"`` and, when a semantic
            segmentation prediction is present, ``"sem_masks"``.

    Returns:
        (tuple[torch.Tensor, torch.Tensor]): Loss vector of 5 entries (box, seg, cls, dfl, semseg) multiplied by
            the batch size, and the detached, unmultiplied loss vector.
    """
    pred_masks, proto = preds["mask_coefficient"].permute(0, 2, 1).contiguous(), preds["proto"]
    loss = torch.zeros(5, device=self.device)  # box, seg, cls, dfl, semseg

    # Test the container type, not len(): a bare proto tensor with batch size 2 also has len() == 2.
    if isinstance(proto, (tuple, list)) and len(proto) == 2:
        proto, pred_semseg = proto
    else:
        pred_semseg = None

    (fg_mask, target_gt_idx, target_bboxes, _, _), det_loss, _ = self.get_assigned_targets_and_loss(preds, batch)
    # NOTE: re-assign index for consistency for now. Need to be removed in the future.
    loss[0], loss[2], loss[3] = det_loss[0], det_loss[1], det_loss[2]

    batch_size, _, mask_h, mask_w = proto.shape  # batch size, number of masks, mask height, mask width
    if fg_mask.sum():
        # Masks loss
        masks = batch["masks"].to(self.device).float()
        if tuple(masks.shape[-2:]) != (mask_h, mask_w):
            # Resize the prototypes to the ground-truth mask resolution instead of resizing the masks.
            proto = F.interpolate(proto, masks.shape[-2:], mode="bilinear", align_corners=False)

        imgsz = (
            torch.tensor(preds["feats"][0].shape[2:], device=self.device, dtype=pred_masks.dtype) * self.stride[0]
        )
        loss[1] = self.calculate_segmentation_loss(
            fg_mask,
            masks,
            target_gt_idx,
            target_bboxes,
            batch["batch_idx"].view(-1, 1),
            proto,
            pred_masks,
            imgsz,
        )
        if pred_semseg is not None:
            sem_masks = batch["sem_masks"].to(self.device)  # NxHxW
            mask_zero = sem_masks == 0  # NxHxW
            sem_masks = F.one_hot(sem_masks.long(), num_classes=self.nc).permute(0, 3, 1, 2).float()  # NxCxHxW
            sem_masks[mask_zero.unsqueeze(1).expand_as(sem_masks)] = 0
            loss[4] = self.bcedice_loss(pred_semseg, sem_masks)
            loss[4] *= self.hyp.box  # seg gain

    # WARNING: lines below prevent Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
    else:
        loss[1] += (proto * 0).sum() + (pred_masks * 0).sum()  # inf sums may lead to nan loss
        # No semantic targets are built in this branch, so only the prediction joins the graph.
        if pred_semseg is not None:
            loss[4] += (pred_semseg * 0).sum()

    loss[1] *= self.hyp.box  # seg gain
    return loss * batch_size, loss.detach()  # loss(box, seg, cls, dfl, semseg)
|
|
395
527
|
|
|
396
528
|
@staticmethod
|
|
397
529
|
def single_mask_loss(
|
|
@@ -427,7 +559,6 @@ class v8SegmentationLoss(v8DetectionLoss):
|
|
|
427
559
|
proto: torch.Tensor,
|
|
428
560
|
pred_masks: torch.Tensor,
|
|
429
561
|
imgsz: torch.Tensor,
|
|
430
|
-
overlap: bool,
|
|
431
562
|
) -> torch.Tensor:
|
|
432
563
|
"""Calculate the loss for instance segmentation.
|
|
433
564
|
|
|
@@ -440,7 +571,6 @@ class v8SegmentationLoss(v8DetectionLoss):
|
|
|
440
571
|
proto (torch.Tensor): Prototype masks of shape (BS, 32, H, W).
|
|
441
572
|
pred_masks (torch.Tensor): Predicted masks for each anchor of shape (BS, N_anchors, 32).
|
|
442
573
|
imgsz (torch.Tensor): Size of the input image as a tensor of shape (2), i.e., (H, W).
|
|
443
|
-
overlap (bool): Whether the masks in `masks` tensor overlap.
|
|
444
574
|
|
|
445
575
|
Returns:
|
|
446
576
|
(torch.Tensor): The calculated loss for instance segmentation.
|
|
@@ -466,7 +596,7 @@ class v8SegmentationLoss(v8DetectionLoss):
|
|
|
466
596
|
fg_mask_i, target_gt_idx_i, pred_masks_i, proto_i, mxyxy_i, marea_i, masks_i = single_i
|
|
467
597
|
if fg_mask_i.any():
|
|
468
598
|
mask_idx = target_gt_idx_i[fg_mask_i]
|
|
469
|
-
if overlap:
|
|
599
|
+
if self.overlap:
|
|
470
600
|
gt_mask = masks_i == (mask_idx + 1).view(-1, 1, 1)
|
|
471
601
|
gt_mask = gt_mask.float()
|
|
472
602
|
else:
|
|
@@ -486,9 +616,9 @@ class v8SegmentationLoss(v8DetectionLoss):
|
|
|
486
616
|
class v8PoseLoss(v8DetectionLoss):
|
|
487
617
|
"""Criterion class for computing training losses for YOLOv8 pose estimation."""
|
|
488
618
|
|
|
489
|
-
def __init__(self, model): # model must be de-paralleled
|
|
619
|
+
def __init__(self, model, tal_topk: int = 10, tal_topk2: int = 10): # model must be de-paralleled
|
|
490
620
|
"""Initialize v8PoseLoss with model parameters and keypoint-specific loss functions."""
|
|
491
|
-
super().__init__(model)
|
|
621
|
+
super().__init__(model, tal_topk, tal_topk2)
|
|
492
622
|
self.kpt_shape = model.model[-1].kpt_shape
|
|
493
623
|
self.bce_pose = nn.BCEWithLogitsLoss()
|
|
494
624
|
is_pose = self.kpt_shape == [17, 3]
|
|
@@ -496,69 +626,40 @@ class v8PoseLoss(v8DetectionLoss):
|
|
|
496
626
|
sigmas = torch.from_numpy(OKS_SIGMA).to(self.device) if is_pose else torch.ones(nkpt, device=self.device) / nkpt
|
|
497
627
|
self.keypoint_loss = KeypointLoss(sigmas=sigmas)
|
|
498
628
|
|
|
499
|
-
def
|
|
629
|
+
def loss(self, preds: dict[str, torch.Tensor], batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
    """Calculate the total loss and detach it for pose estimation.

    Args:
        preds (dict[str, torch.Tensor]): Parsed predictions. Reads ``"kpts"`` and ``"feats"`` here.
        batch (dict[str, torch.Tensor]): Batch data. Reads ``"keypoints"`` and ``"batch_idx"`` when at least one
            anchor is assigned to a target.

    Returns:
        (tuple[torch.Tensor, torch.Tensor]): Loss vector of 5 entries (box, pose, kobj, cls, dfl) multiplied by
            the batch size, and the detached, unmultiplied loss vector.
    """
    pred_kpts = preds["kpts"].permute(0, 2, 1).contiguous()
    loss = torch.zeros(5, device=self.device)  # box, pose, kobj, cls, dfl
    (fg_mask, target_gt_idx, target_bboxes, anchor_points, stride_tensor), det_loss, _ = (
        self.get_assigned_targets_and_loss(preds, batch)
    )
    # NOTE: re-assign index for consistency for now. Need to be removed in the future.
    loss[0], loss[3], loss[4] = det_loss[0], det_loss[1], det_loss[2]

    batch_size = pred_kpts.shape[0]
    imgsz = torch.tensor(preds["feats"][0].shape[2:], device=self.device, dtype=pred_kpts.dtype) * self.stride[0]

    # Pboxes
    pred_kpts = self.kpts_decode(anchor_points, pred_kpts.view(batch_size, -1, *self.kpt_shape))  # (b, h*w, 17, 3)

    # Keypoint loss, only when at least one anchor is assigned to a target
    if fg_mask.sum():
        keypoints = batch["keypoints"].to(self.device).float().clone()
        # imgsz is (h, w): x is scaled by the width, y by the height
        keypoints[..., 0] *= imgsz[1]
        keypoints[..., 1] *= imgsz[0]

        loss[1], loss[2] = self.calculate_keypoints_loss(
            fg_mask,
            target_gt_idx,
            keypoints,
            batch["batch_idx"].view(-1, 1),
            stride_tensor,
            target_bboxes,
            pred_kpts,
        )

    loss[1] *= self.hyp.pose  # pose gain
    loss[2] *= self.hyp.kobj  # kobj gain

    return loss * batch_size, loss.detach()  # loss(box, pose, kobj, cls, dfl)
|
|
564
665
|
|
|
@@ -571,34 +672,23 @@ class v8PoseLoss(v8DetectionLoss):
|
|
|
571
672
|
y[..., 1] += anchor_points[:, [1]] - 0.5
|
|
572
673
|
return y
|
|
573
674
|
|
|
574
|
-
def
|
|
675
|
+
def _select_target_keypoints(
|
|
575
676
|
self,
|
|
576
|
-
masks: torch.Tensor,
|
|
577
|
-
target_gt_idx: torch.Tensor,
|
|
578
677
|
keypoints: torch.Tensor,
|
|
579
678
|
batch_idx: torch.Tensor,
|
|
580
|
-
|
|
581
|
-
|
|
582
|
-
|
|
583
|
-
|
|
584
|
-
"""Calculate the keypoints loss for the model.
|
|
585
|
-
|
|
586
|
-
This function calculates the keypoints loss and keypoints object loss for a given batch. The keypoints loss is
|
|
587
|
-
based on the difference between the predicted keypoints and ground truth keypoints. The keypoints object loss is
|
|
588
|
-
a binary classification loss that classifies whether a keypoint is present or not.
|
|
679
|
+
target_gt_idx: torch.Tensor,
|
|
680
|
+
masks: torch.Tensor,
|
|
681
|
+
) -> torch.Tensor:
|
|
682
|
+
"""Select target keypoints for each anchor based on batch index and target ground truth index.
|
|
589
683
|
|
|
590
684
|
Args:
|
|
591
|
-
masks (torch.Tensor): Binary mask tensor indicating object presence, shape (BS, N_anchors).
|
|
592
|
-
target_gt_idx (torch.Tensor): Index tensor mapping anchors to ground truth objects, shape (BS, N_anchors).
|
|
593
685
|
keypoints (torch.Tensor): Ground truth keypoints, shape (N_kpts_in_batch, N_kpts_per_object, kpts_dim).
|
|
594
686
|
batch_idx (torch.Tensor): Batch index tensor for keypoints, shape (N_kpts_in_batch, 1).
|
|
595
|
-
|
|
596
|
-
|
|
597
|
-
pred_kpts (torch.Tensor): Predicted keypoints, shape (BS, N_anchors, N_kpts_per_object, kpts_dim).
|
|
687
|
+
target_gt_idx (torch.Tensor): Index tensor mapping anchors to ground truth objects, shape (BS, N_anchors).
|
|
688
|
+
masks (torch.Tensor): Binary mask tensor indicating object presence, shape (BS, N_anchors).
|
|
598
689
|
|
|
599
690
|
Returns:
|
|
600
|
-
|
|
601
|
-
kpts_obj_loss (torch.Tensor): The keypoints object loss.
|
|
691
|
+
(torch.Tensor): Selected keypoints tensor, shape (BS, N_anchors, N_kpts_per_object, kpts_dim).
|
|
602
692
|
"""
|
|
603
693
|
batch_idx = batch_idx.flatten()
|
|
604
694
|
batch_size = len(masks)
|
|
@@ -625,6 +715,40 @@ class v8PoseLoss(v8DetectionLoss):
|
|
|
625
715
|
1, target_gt_idx_expanded.expand(-1, -1, keypoints.shape[1], keypoints.shape[2])
|
|
626
716
|
)
|
|
627
717
|
|
|
718
|
+
return selected_keypoints
|
|
719
|
+
|
|
720
|
+
def calculate_keypoints_loss(
|
|
721
|
+
self,
|
|
722
|
+
masks: torch.Tensor,
|
|
723
|
+
target_gt_idx: torch.Tensor,
|
|
724
|
+
keypoints: torch.Tensor,
|
|
725
|
+
batch_idx: torch.Tensor,
|
|
726
|
+
stride_tensor: torch.Tensor,
|
|
727
|
+
target_bboxes: torch.Tensor,
|
|
728
|
+
pred_kpts: torch.Tensor,
|
|
729
|
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
730
|
+
"""Calculate the keypoints loss for the model.
|
|
731
|
+
|
|
732
|
+
This function calculates the keypoints loss and keypoints object loss for a given batch. The keypoints loss is
|
|
733
|
+
based on the difference between the predicted keypoints and ground truth keypoints. The keypoints object loss is
|
|
734
|
+
a binary classification loss that classifies whether a keypoint is present or not.
|
|
735
|
+
|
|
736
|
+
Args:
|
|
737
|
+
masks (torch.Tensor): Binary mask tensor indicating object presence, shape (BS, N_anchors).
|
|
738
|
+
target_gt_idx (torch.Tensor): Index tensor mapping anchors to ground truth objects, shape (BS, N_anchors).
|
|
739
|
+
keypoints (torch.Tensor): Ground truth keypoints, shape (N_kpts_in_batch, N_kpts_per_object, kpts_dim).
|
|
740
|
+
batch_idx (torch.Tensor): Batch index tensor for keypoints, shape (N_kpts_in_batch, 1).
|
|
741
|
+
stride_tensor (torch.Tensor): Stride tensor for anchors, shape (N_anchors, 1).
|
|
742
|
+
target_bboxes (torch.Tensor): Ground truth boxes in (x1, y1, x2, y2) format, shape (BS, N_anchors, 4).
|
|
743
|
+
pred_kpts (torch.Tensor): Predicted keypoints, shape (BS, N_anchors, N_kpts_per_object, kpts_dim).
|
|
744
|
+
|
|
745
|
+
Returns:
|
|
746
|
+
kpts_loss (torch.Tensor): The keypoints loss.
|
|
747
|
+
kpts_obj_loss (torch.Tensor): The keypoints object loss.
|
|
748
|
+
"""
|
|
749
|
+
# Select target keypoints using helper method
|
|
750
|
+
selected_keypoints = self._select_target_keypoints(keypoints, batch_idx, target_gt_idx, masks)
|
|
751
|
+
|
|
628
752
|
# Divide coordinates by stride
|
|
629
753
|
selected_keypoints[..., :2] /= stride_tensor.view(1, -1, 1, 1)
|
|
630
754
|
|
|
@@ -632,6 +756,7 @@ class v8PoseLoss(v8DetectionLoss):
|
|
|
632
756
|
kpts_obj_loss = 0
|
|
633
757
|
|
|
634
758
|
if masks.any():
|
|
759
|
+
target_bboxes /= stride_tensor
|
|
635
760
|
gt_kpt = selected_keypoints[masks]
|
|
636
761
|
area = xyxy2xywh(target_bboxes[masks])[:, 2:].prod(1, keepdim=True)
|
|
637
762
|
pred_kpt = pred_kpts[masks]
|
|
@@ -644,6 +769,171 @@ class v8PoseLoss(v8DetectionLoss):
|
|
|
644
769
|
return kpts_loss, kpts_obj_loss
|
|
645
770
|
|
|
646
771
|
|
|
772
|
+
class PoseLoss26(v8PoseLoss):
|
|
773
|
+
"""Criterion class for computing training losses for YOLOv8 pose estimation with RLE loss support."""
|
|
774
|
+
|
|
775
|
+
def __init__(self, model, tal_topk: int = 10, tal_topk2: int | None = None): # model must be de-paralleled
|
|
776
|
+
"""Initialize PoseLoss26 with model parameters and keypoint-specific loss functions including RLE loss."""
|
|
777
|
+
super().__init__(model, tal_topk, tal_topk2)
|
|
778
|
+
is_pose = self.kpt_shape == [17, 3]
|
|
779
|
+
nkpt = self.kpt_shape[0] # number of keypoints
|
|
780
|
+
self.rle_loss = None
|
|
781
|
+
self.flow_model = model.model[-1].flow_model if hasattr(model.model[-1], "flow_model") else None
|
|
782
|
+
if self.flow_model is not None:
|
|
783
|
+
self.rle_loss = RLELoss(use_target_weight=True).to(self.device)
|
|
784
|
+
self.target_weights = (
|
|
785
|
+
torch.from_numpy(RLE_WEIGHT).to(self.device) if is_pose else torch.ones(nkpt, device=self.device)
|
|
786
|
+
)
|
|
787
|
+
|
|
788
|
+
def loss(self, preds: dict[str, torch.Tensor], batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
|
|
789
|
+
"""Calculate the total loss and detach it for pose estimation."""
|
|
790
|
+
pred_kpts = preds["kpts"].permute(0, 2, 1).contiguous()
|
|
791
|
+
loss = torch.zeros(6 if self.rle_loss else 5, device=self.device) # box, cls, dfl, kpt_location, kpt_visibility
|
|
792
|
+
(fg_mask, target_gt_idx, target_bboxes, anchor_points, stride_tensor), det_loss, _ = (
|
|
793
|
+
self.get_assigned_targets_and_loss(preds, batch)
|
|
794
|
+
)
|
|
795
|
+
# NOTE: re-assign index for consistency for now. Need to be removed in the future.
|
|
796
|
+
loss[0], loss[3], loss[4] = det_loss[0], det_loss[1], det_loss[2]
|
|
797
|
+
|
|
798
|
+
batch_size = pred_kpts.shape[0]
|
|
799
|
+
imgsz = torch.tensor(batch["resized_shape"][0], device=self.device, dtype=pred_kpts.dtype) # image size (h,w)
|
|
800
|
+
|
|
801
|
+
pred_kpts = pred_kpts.view(batch_size, -1, *self.kpt_shape) # (b, h*w, 17, 3)
|
|
802
|
+
|
|
803
|
+
if self.rle_loss and preds.get("kpts_sigma", None) is not None:
|
|
804
|
+
pred_sigma = preds["kpts_sigma"].permute(0, 2, 1).contiguous()
|
|
805
|
+
pred_sigma = pred_sigma.view(batch_size, -1, self.kpt_shape[0], 2) # (b, h*w, 17, 2)
|
|
806
|
+
pred_kpts = torch.cat([pred_kpts, pred_sigma], dim=-1) # (b, h*w, 17, 5)
|
|
807
|
+
|
|
808
|
+
pred_kpts = self.kpts_decode(anchor_points, pred_kpts)
|
|
809
|
+
|
|
810
|
+
# Bbox loss
|
|
811
|
+
if fg_mask.sum():
|
|
812
|
+
keypoints = batch["keypoints"].to(self.device).float().clone()
|
|
813
|
+
keypoints[..., 0] *= imgsz[1]
|
|
814
|
+
keypoints[..., 1] *= imgsz[0]
|
|
815
|
+
|
|
816
|
+
keypoints_loss = self.calculate_keypoints_loss(
|
|
817
|
+
fg_mask,
|
|
818
|
+
target_gt_idx,
|
|
819
|
+
keypoints,
|
|
820
|
+
batch["batch_idx"].view(-1, 1),
|
|
821
|
+
stride_tensor,
|
|
822
|
+
target_bboxes,
|
|
823
|
+
pred_kpts,
|
|
824
|
+
)
|
|
825
|
+
loss[1] = keypoints_loss[0]
|
|
826
|
+
loss[2] = keypoints_loss[1]
|
|
827
|
+
if self.rle_loss is not None:
|
|
828
|
+
loss[5] = keypoints_loss[2]
|
|
829
|
+
|
|
830
|
+
loss[1] *= self.hyp.pose # pose gain
|
|
831
|
+
loss[2] *= self.hyp.kobj # kobj gain
|
|
832
|
+
if self.rle_loss is not None:
|
|
833
|
+
loss[5] *= self.hyp.rle # rle gain
|
|
834
|
+
|
|
835
|
+
return loss * batch_size, loss.detach() # loss(box, cls, dfl)
|
|
836
|
+
|
|
837
|
+
@staticmethod
|
|
838
|
+
def kpts_decode(anchor_points: torch.Tensor, pred_kpts: torch.Tensor) -> torch.Tensor:
|
|
839
|
+
"""Decode predicted keypoints to image coordinates."""
|
|
840
|
+
y = pred_kpts.clone()
|
|
841
|
+
y[..., 0] += anchor_points[:, [0]]
|
|
842
|
+
y[..., 1] += anchor_points[:, [1]]
|
|
843
|
+
return y
|
|
844
|
+
|
|
845
|
+
def calculate_rle_loss(self, pred_kpt: torch.Tensor, gt_kpt: torch.Tensor, kpt_mask: torch.Tensor) -> torch.Tensor:
|
|
846
|
+
"""Calculate the RLE (Residual Log-likelihood Estimation) loss for keypoints.
|
|
847
|
+
|
|
848
|
+
Args:
|
|
849
|
+
pred_kpt (torch.Tensor): Predicted keypoints with sigma, shape (N, kpts_dim) where kpts_dim >= 4.
|
|
850
|
+
gt_kpt (torch.Tensor): Ground truth keypoints, shape (N, kpts_dim).
|
|
851
|
+
kpt_mask (torch.Tensor): Mask for valid keypoints, shape (N, num_keypoints).
|
|
852
|
+
|
|
853
|
+
Returns:
|
|
854
|
+
(torch.Tensor): The RLE loss.
|
|
855
|
+
"""
|
|
856
|
+
pred_kpt_visible = pred_kpt[kpt_mask]
|
|
857
|
+
gt_kpt_visible = gt_kpt[kpt_mask]
|
|
858
|
+
pred_coords = pred_kpt_visible[:, 0:2]
|
|
859
|
+
pred_sigma = pred_kpt_visible[:, -2:]
|
|
860
|
+
gt_coords = gt_kpt_visible[:, 0:2]
|
|
861
|
+
|
|
862
|
+
target_weights = self.target_weights.unsqueeze(0).repeat(kpt_mask.shape[0], 1)
|
|
863
|
+
target_weights = target_weights[kpt_mask]
|
|
864
|
+
|
|
865
|
+
pred_sigma = pred_sigma.sigmoid()
|
|
866
|
+
error = (pred_coords - gt_coords) / (pred_sigma + 1e-9)
|
|
867
|
+
|
|
868
|
+
# Filter out NaN values to prevent MultivariateNormal validation errors (can occur with small images)
|
|
869
|
+
valid_mask = ~torch.isnan(error).any(dim=-1)
|
|
870
|
+
if not valid_mask.any():
|
|
871
|
+
return torch.tensor(0.0, device=pred_kpt.device)
|
|
872
|
+
|
|
873
|
+
error = error[valid_mask]
|
|
874
|
+
pred_sigma = pred_sigma[valid_mask]
|
|
875
|
+
target_weights = target_weights[valid_mask]
|
|
876
|
+
|
|
877
|
+
log_phi = self.flow_model.log_prob(error)
|
|
878
|
+
|
|
879
|
+
return self.rle_loss(pred_sigma, log_phi, error, target_weights)
|
|
880
|
+
|
|
881
|
+
def calculate_keypoints_loss(
|
|
882
|
+
self,
|
|
883
|
+
masks: torch.Tensor,
|
|
884
|
+
target_gt_idx: torch.Tensor,
|
|
885
|
+
keypoints: torch.Tensor,
|
|
886
|
+
batch_idx: torch.Tensor,
|
|
887
|
+
stride_tensor: torch.Tensor,
|
|
888
|
+
target_bboxes: torch.Tensor,
|
|
889
|
+
pred_kpts: torch.Tensor,
|
|
890
|
+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
|
891
|
+
"""Calculate the keypoints loss for the model.
|
|
892
|
+
|
|
893
|
+
This function calculates the keypoints loss and keypoints object loss for a given batch. The keypoints loss is
|
|
894
|
+
based on the difference between the predicted keypoints and ground truth keypoints. The keypoints object loss is
|
|
895
|
+
a binary classification loss that classifies whether a keypoint is present or not.
|
|
896
|
+
|
|
897
|
+
Args:
|
|
898
|
+
masks (torch.Tensor): Binary mask tensor indicating object presence, shape (BS, N_anchors).
|
|
899
|
+
target_gt_idx (torch.Tensor): Index tensor mapping anchors to ground truth objects, shape (BS, N_anchors).
|
|
900
|
+
keypoints (torch.Tensor): Ground truth keypoints, shape (N_kpts_in_batch, N_kpts_per_object, kpts_dim).
|
|
901
|
+
batch_idx (torch.Tensor): Batch index tensor for keypoints, shape (N_kpts_in_batch, 1).
|
|
902
|
+
stride_tensor (torch.Tensor): Stride tensor for anchors, shape (N_anchors, 1).
|
|
903
|
+
target_bboxes (torch.Tensor): Ground truth boxes in (x1, y1, x2, y2) format, shape (BS, N_anchors, 4).
|
|
904
|
+
pred_kpts (torch.Tensor): Predicted keypoints, shape (BS, N_anchors, N_kpts_per_object, kpts_dim).
|
|
905
|
+
|
|
906
|
+
Returns:
|
|
907
|
+
kpts_loss (torch.Tensor): The keypoints loss.
|
|
908
|
+
kpts_obj_loss (torch.Tensor): The keypoints object loss.
|
|
909
|
+
rle_loss (torch.Tensor): The RLE loss.
|
|
910
|
+
"""
|
|
911
|
+
# Select target keypoints using inherited helper method
|
|
912
|
+
selected_keypoints = self._select_target_keypoints(keypoints, batch_idx, target_gt_idx, masks)
|
|
913
|
+
|
|
914
|
+
# Divide coordinates by stride
|
|
915
|
+
selected_keypoints[..., :2] /= stride_tensor.view(1, -1, 1, 1)
|
|
916
|
+
|
|
917
|
+
kpts_loss = 0
|
|
918
|
+
kpts_obj_loss = 0
|
|
919
|
+
rle_loss = 0
|
|
920
|
+
|
|
921
|
+
if masks.any():
|
|
922
|
+
target_bboxes /= stride_tensor
|
|
923
|
+
gt_kpt = selected_keypoints[masks]
|
|
924
|
+
area = xyxy2xywh(target_bboxes[masks])[:, 2:].prod(1, keepdim=True)
|
|
925
|
+
pred_kpt = pred_kpts[masks]
|
|
926
|
+
kpt_mask = gt_kpt[..., 2] != 0 if gt_kpt.shape[-1] == 3 else torch.full_like(gt_kpt[..., 0], True)
|
|
927
|
+
kpts_loss = self.keypoint_loss(pred_kpt, gt_kpt, kpt_mask, area) # pose loss
|
|
928
|
+
|
|
929
|
+
if self.rle_loss is not None and (pred_kpt.shape[-1] == 4 or pred_kpt.shape[-1] == 5):
|
|
930
|
+
rle_loss = self.calculate_rle_loss(pred_kpt, gt_kpt, kpt_mask)
|
|
931
|
+
if pred_kpt.shape[-1] == 3 or pred_kpt.shape[-1] == 5:
|
|
932
|
+
kpts_obj_loss = self.bce_pose(pred_kpt[..., 2], kpt_mask.float()) # keypoint obj loss
|
|
933
|
+
|
|
934
|
+
return kpts_loss, kpts_obj_loss, rle_loss
|
|
935
|
+
|
|
936
|
+
|
|
647
937
|
class v8ClassificationLoss:
|
|
648
938
|
"""Criterion class for computing training losses for classification."""
|
|
649
939
|
|
|
@@ -657,10 +947,17 @@ class v8ClassificationLoss:
|
|
|
657
947
|
class v8OBBLoss(v8DetectionLoss):
|
|
658
948
|
"""Calculates losses for object detection, classification, and box distribution in rotated YOLO models."""
|
|
659
949
|
|
|
660
|
-
def __init__(self, model):
|
|
950
|
+
def __init__(self, model, tal_topk=10, tal_topk2: int | None = None):
|
|
661
951
|
"""Initialize v8OBBLoss with model, assigner, and rotated bbox loss; model must be de-paralleled."""
|
|
662
|
-
super().__init__(model)
|
|
663
|
-
self.assigner = RotatedTaskAlignedAssigner(
|
|
952
|
+
super().__init__(model, tal_topk=tal_topk)
|
|
953
|
+
self.assigner = RotatedTaskAlignedAssigner(
|
|
954
|
+
topk=tal_topk,
|
|
955
|
+
num_classes=self.nc,
|
|
956
|
+
alpha=0.5,
|
|
957
|
+
beta=6.0,
|
|
958
|
+
stride=self.stride.tolist(),
|
|
959
|
+
topk2=tal_topk2,
|
|
960
|
+
)
|
|
664
961
|
self.bbox_loss = RotatedBboxLoss(self.reg_max).to(self.device)
|
|
665
962
|
|
|
666
963
|
def preprocess(self, targets: torch.Tensor, batch_size: int, scale_tensor: torch.Tensor) -> torch.Tensor:
|
|
@@ -680,23 +977,19 @@ class v8OBBLoss(v8DetectionLoss):
|
|
|
680
977
|
out[j, :n] = torch.cat([targets[matches, 1:2], bboxes], dim=-1)
|
|
681
978
|
return out
|
|
682
979
|
|
|
683
|
-
def
|
|
980
|
+
def loss(self, preds: dict[str, torch.Tensor], batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
|
|
684
981
|
"""Calculate and return the loss for oriented bounding box detection."""
|
|
685
|
-
loss = torch.zeros(
|
|
686
|
-
|
|
687
|
-
|
|
688
|
-
|
|
689
|
-
(
|
|
982
|
+
loss = torch.zeros(4, device=self.device) # box, cls, dfl
|
|
983
|
+
pred_distri, pred_scores, pred_angle = (
|
|
984
|
+
preds["boxes"].permute(0, 2, 1).contiguous(),
|
|
985
|
+
preds["scores"].permute(0, 2, 1).contiguous(),
|
|
986
|
+
preds["angle"].permute(0, 2, 1).contiguous(),
|
|
690
987
|
)
|
|
691
|
-
|
|
692
|
-
#
|
|
693
|
-
pred_scores = pred_scores.permute(0, 2, 1).contiguous()
|
|
694
|
-
pred_distri = pred_distri.permute(0, 2, 1).contiguous()
|
|
695
|
-
pred_angle = pred_angle.permute(0, 2, 1).contiguous()
|
|
988
|
+
anchor_points, stride_tensor = make_anchors(preds["feats"], self.stride, 0.5)
|
|
989
|
+
batch_size = pred_angle.shape[0] # batch size, number of masks, mask height, mask width
|
|
696
990
|
|
|
697
991
|
dtype = pred_scores.dtype
|
|
698
|
-
imgsz = torch.tensor(
|
|
699
|
-
anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
|
|
992
|
+
imgsz = torch.tensor(batch["resized_shape"][0], device=self.device, dtype=dtype) # image size (h,w)
|
|
700
993
|
|
|
701
994
|
# targets
|
|
702
995
|
try:
|
|
@@ -704,14 +997,14 @@ class v8OBBLoss(v8DetectionLoss):
|
|
|
704
997
|
targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"].view(-1, 5)), 1)
|
|
705
998
|
rw, rh = targets[:, 4] * float(imgsz[1]), targets[:, 5] * float(imgsz[0])
|
|
706
999
|
targets = targets[(rw >= 2) & (rh >= 2)] # filter rboxes of tiny size to stabilize training
|
|
707
|
-
targets = self.preprocess(targets, batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
|
|
1000
|
+
targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
|
|
708
1001
|
gt_labels, gt_bboxes = targets.split((1, 5), 2) # cls, xywhr
|
|
709
1002
|
mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
|
|
710
1003
|
except RuntimeError as e:
|
|
711
1004
|
raise TypeError(
|
|
712
1005
|
"ERROR ❌ OBB dataset incorrectly formatted or not a OBB dataset.\n"
|
|
713
1006
|
"This error can occur when incorrectly training a 'OBB' model on a 'detect' dataset, "
|
|
714
|
-
"i.e. 'yolo train model=yolo11n-obb.pt data=
|
|
1007
|
+
"i.e. 'yolo train model=yolo11n-obb.pt data=dota8.yaml'.\nVerify your dataset is a "
|
|
715
1008
|
"correctly formatted 'OBB' dataset using 'data=dota8.yaml' "
|
|
716
1009
|
"as an example.\nSee https://docs.ultralytics.com/datasets/obb/ for help."
|
|
717
1010
|
) from e
|
|
@@ -741,16 +1034,29 @@ class v8OBBLoss(v8DetectionLoss):
|
|
|
741
1034
|
if fg_mask.sum():
|
|
742
1035
|
target_bboxes[..., :4] /= stride_tensor
|
|
743
1036
|
loss[0], loss[2] = self.bbox_loss(
|
|
744
|
-
pred_distri,
|
|
1037
|
+
pred_distri,
|
|
1038
|
+
pred_bboxes,
|
|
1039
|
+
anchor_points,
|
|
1040
|
+
target_bboxes,
|
|
1041
|
+
target_scores,
|
|
1042
|
+
target_scores_sum,
|
|
1043
|
+
fg_mask,
|
|
1044
|
+
imgsz,
|
|
1045
|
+
stride_tensor,
|
|
745
1046
|
)
|
|
1047
|
+
weight = target_scores.sum(-1)[fg_mask]
|
|
1048
|
+
loss[3] = self.calculate_angle_loss(
|
|
1049
|
+
pred_bboxes, target_bboxes, fg_mask, weight, target_scores_sum
|
|
1050
|
+
) # angle loss
|
|
746
1051
|
else:
|
|
747
1052
|
loss[0] += (pred_angle * 0).sum()
|
|
748
1053
|
|
|
749
1054
|
loss[0] *= self.hyp.box # box gain
|
|
750
1055
|
loss[1] *= self.hyp.cls # cls gain
|
|
751
1056
|
loss[2] *= self.hyp.dfl # dfl gain
|
|
1057
|
+
loss[3] *= self.hyp.angle # angle gain
|
|
752
1058
|
|
|
753
|
-
return loss * batch_size, loss.detach() # loss(box, cls, dfl)
|
|
1059
|
+
return loss * batch_size, loss.detach() # loss(box, cls, dfl, angle)
|
|
754
1060
|
|
|
755
1061
|
def bbox_decode(
|
|
756
1062
|
self, anchor_points: torch.Tensor, pred_dist: torch.Tensor, pred_angle: torch.Tensor
|
|
@@ -770,6 +1076,34 @@ class v8OBBLoss(v8DetectionLoss):
|
|
|
770
1076
|
pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))
|
|
771
1077
|
return torch.cat((dist2rbox(pred_dist, pred_angle, anchor_points), pred_angle), dim=-1)
|
|
772
1078
|
|
|
1079
|
+
def calculate_angle_loss(self, pred_bboxes, target_bboxes, fg_mask, weight, target_scores_sum, lambda_val=3):
|
|
1080
|
+
"""Calculate oriented angle loss.
|
|
1081
|
+
|
|
1082
|
+
Args:
|
|
1083
|
+
pred_bboxes: [N, 5] (x, y, w, h, theta).
|
|
1084
|
+
target_bboxes: [N, 5] (x, y, w, h, theta).
|
|
1085
|
+
fg_mask: Foreground mask indicating valid predictions.
|
|
1086
|
+
weight: Loss weights for each prediction.
|
|
1087
|
+
target_scores_sum: Sum of target scores for normalization.
|
|
1088
|
+
lambda_val: control the sensitivity to aspect ratio.
|
|
1089
|
+
"""
|
|
1090
|
+
w_gt = target_bboxes[..., 2]
|
|
1091
|
+
h_gt = target_bboxes[..., 3]
|
|
1092
|
+
pred_theta = pred_bboxes[..., 4]
|
|
1093
|
+
target_theta = target_bboxes[..., 4]
|
|
1094
|
+
|
|
1095
|
+
log_ar = torch.log(w_gt / h_gt)
|
|
1096
|
+
scale_weight = torch.exp(-(log_ar**2) / (lambda_val**2))
|
|
1097
|
+
|
|
1098
|
+
delta_theta = pred_theta - target_theta
|
|
1099
|
+
delta_theta_wrapped = delta_theta - torch.round(delta_theta / math.pi) * math.pi
|
|
1100
|
+
ang_loss = torch.sin(2 * delta_theta_wrapped[fg_mask]) ** 2
|
|
1101
|
+
|
|
1102
|
+
ang_loss = scale_weight[fg_mask] * ang_loss
|
|
1103
|
+
ang_loss = ang_loss * weight
|
|
1104
|
+
|
|
1105
|
+
return ang_loss.sum() / target_scores_sum
|
|
1106
|
+
|
|
773
1107
|
|
|
774
1108
|
class E2EDetectLoss:
|
|
775
1109
|
"""Criterion class for computing training losses for end-to-end detection."""
|
|
@@ -789,61 +1123,108 @@ class E2EDetectLoss:
|
|
|
789
1123
|
return loss_one2many[0] + loss_one2one[0], loss_one2many[1] + loss_one2one[1]
|
|
790
1124
|
|
|
791
1125
|
|
|
1126
|
+
class E2ELoss:
|
|
1127
|
+
"""Criterion class for computing training losses for end-to-end detection."""
|
|
1128
|
+
|
|
1129
|
+
def __init__(self, model, loss_fn=v8DetectionLoss):
|
|
1130
|
+
"""Initialize E2ELoss with one-to-many and one-to-one detection losses using the provided model."""
|
|
1131
|
+
self.one2many = loss_fn(model, tal_topk=10)
|
|
1132
|
+
self.one2one = loss_fn(model, tal_topk=7, tal_topk2=1)
|
|
1133
|
+
self.updates = 0
|
|
1134
|
+
self.total = 1.0
|
|
1135
|
+
# init gain
|
|
1136
|
+
self.o2m = 0.8
|
|
1137
|
+
self.o2o = self.total - self.o2m
|
|
1138
|
+
self.o2m_copy = self.o2m
|
|
1139
|
+
# final gain
|
|
1140
|
+
self.final_o2m = 0.1
|
|
1141
|
+
|
|
1142
|
+
def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
|
|
1143
|
+
"""Calculate the sum of the loss for box, cls and dfl multiplied by batch size."""
|
|
1144
|
+
preds = self.one2many.parse_output(preds)
|
|
1145
|
+
one2many, one2one = preds["one2many"], preds["one2one"]
|
|
1146
|
+
loss_one2many = self.one2many.loss(one2many, batch)
|
|
1147
|
+
loss_one2one = self.one2one.loss(one2one, batch)
|
|
1148
|
+
return loss_one2many[0] * self.o2m + loss_one2one[0] * self.o2o, loss_one2one[1]
|
|
1149
|
+
|
|
1150
|
+
def update(self) -> None:
|
|
1151
|
+
"""Update the weights for one-to-many and one-to-one losses based on the decay schedule."""
|
|
1152
|
+
self.updates += 1
|
|
1153
|
+
self.o2m = self.decay(self.updates)
|
|
1154
|
+
self.o2o = max(self.total - self.o2m, 0)
|
|
1155
|
+
|
|
1156
|
+
def decay(self, x) -> float:
|
|
1157
|
+
"""Calculate the decayed weight for one-to-many loss based on the current update step."""
|
|
1158
|
+
return max(1 - x / max(self.one2one.hyp.epochs - 1, 1), 0) * (self.o2m_copy - self.final_o2m) + self.final_o2m
|
|
1159
|
+
|
|
1160
|
+
|
|
792
1161
|
class TVPDetectLoss:
|
|
793
1162
|
"""Criterion class for computing training losses for text-visual prompt detection."""
|
|
794
1163
|
|
|
795
|
-
def __init__(self, model):
|
|
1164
|
+
def __init__(self, model, tal_topk=10):
|
|
796
1165
|
"""Initialize TVPDetectLoss with task-prompt and visual-prompt criteria using the provided model."""
|
|
797
|
-
self.vp_criterion = v8DetectionLoss(model)
|
|
1166
|
+
self.vp_criterion = v8DetectionLoss(model, tal_topk)
|
|
798
1167
|
# NOTE: store following info as it's changeable in __call__
|
|
1168
|
+
self.hyp = self.vp_criterion.hyp
|
|
799
1169
|
self.ori_nc = self.vp_criterion.nc
|
|
800
1170
|
self.ori_no = self.vp_criterion.no
|
|
801
1171
|
self.ori_reg_max = self.vp_criterion.reg_max
|
|
802
1172
|
|
|
1173
|
+
def parse_output(self, preds) -> dict[str, torch.Tensor]:
|
|
1174
|
+
"""Parse model predictions to extract features."""
|
|
1175
|
+
return self.vp_criterion.parse_output(preds)
|
|
1176
|
+
|
|
803
1177
|
def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
|
|
804
1178
|
"""Calculate the loss for text-visual prompt detection."""
|
|
805
|
-
|
|
1179
|
+
return self.loss(self.parse_output(preds), batch)
|
|
1180
|
+
|
|
1181
|
+
def loss(self, preds: dict[str, torch.Tensor], batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
|
|
1182
|
+
"""Calculate the loss for text-visual prompt detection."""
|
|
1183
|
+
assert self.ori_reg_max == self.vp_criterion.reg_max # TODO: remove it
|
|
806
1184
|
|
|
807
|
-
if self.
|
|
1185
|
+
if self.ori_nc == preds["scores"].shape[1]:
|
|
808
1186
|
loss = torch.zeros(3, device=self.vp_criterion.device, requires_grad=True)
|
|
809
1187
|
return loss, loss.detach()
|
|
810
1188
|
|
|
811
|
-
|
|
812
|
-
vp_loss = self.vp_criterion(
|
|
813
|
-
|
|
814
|
-
return
|
|
1189
|
+
preds["scores"] = self._get_vp_features(preds)
|
|
1190
|
+
vp_loss = self.vp_criterion(preds, batch)
|
|
1191
|
+
box_loss = vp_loss[0][1]
|
|
1192
|
+
return box_loss, vp_loss[1]
|
|
815
1193
|
|
|
816
|
-
def _get_vp_features(self,
|
|
1194
|
+
def _get_vp_features(self, preds: dict[str, torch.Tensor]) -> list[torch.Tensor]:
|
|
817
1195
|
"""Extract visual-prompt features from the model output."""
|
|
818
|
-
|
|
1196
|
+
# NOTE: remove empty placeholder
|
|
1197
|
+
scores = preds["scores"][:, self.ori_nc :, :]
|
|
1198
|
+
vnc = scores.shape[1]
|
|
819
1199
|
|
|
820
1200
|
self.vp_criterion.nc = vnc
|
|
821
1201
|
self.vp_criterion.no = vnc + self.vp_criterion.reg_max * 4
|
|
822
1202
|
self.vp_criterion.assigner.num_classes = vnc
|
|
823
|
-
|
|
824
|
-
return [
|
|
825
|
-
torch.cat((box, cls_vp), dim=1)
|
|
826
|
-
for box, _, cls_vp in [xi.split((self.ori_reg_max * 4, self.ori_nc, vnc), dim=1) for xi in feats]
|
|
827
|
-
]
|
|
1203
|
+
return scores
|
|
828
1204
|
|
|
829
1205
|
|
|
830
1206
|
class TVPSegmentLoss(TVPDetectLoss):
|
|
831
1207
|
"""Criterion class for computing training losses for text-visual prompt segmentation."""
|
|
832
1208
|
|
|
833
|
-
def __init__(self, model):
|
|
1209
|
+
def __init__(self, model, tal_topk=10):
|
|
834
1210
|
"""Initialize TVPSegmentLoss with task-prompt and visual-prompt criteria using the provided model."""
|
|
835
1211
|
super().__init__(model)
|
|
836
|
-
self.vp_criterion = v8SegmentationLoss(model)
|
|
1212
|
+
self.vp_criterion = v8SegmentationLoss(model, tal_topk)
|
|
1213
|
+
self.hyp = self.vp_criterion.hyp
|
|
837
1214
|
|
|
838
1215
|
def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
|
|
839
1216
|
"""Calculate the loss for text-visual prompt segmentation."""
|
|
840
|
-
|
|
1217
|
+
return self.loss(self.parse_output(preds), batch)
|
|
1218
|
+
|
|
1219
|
+
def loss(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
|
|
1220
|
+
"""Calculate the loss for text-visual prompt detection."""
|
|
1221
|
+
assert self.ori_reg_max == self.vp_criterion.reg_max # TODO: remove it
|
|
841
1222
|
|
|
842
|
-
if self.
|
|
1223
|
+
if self.ori_nc == preds["scores"].shape[1]:
|
|
843
1224
|
loss = torch.zeros(4, device=self.vp_criterion.device, requires_grad=True)
|
|
844
1225
|
return loss, loss.detach()
|
|
845
1226
|
|
|
846
|
-
|
|
847
|
-
vp_loss = self.vp_criterion(
|
|
1227
|
+
preds["scores"] = self._get_vp_features(preds)
|
|
1228
|
+
vp_loss = self.vp_criterion(preds, batch)
|
|
848
1229
|
cls_loss = vp_loss[0][2]
|
|
849
1230
|
return cls_loss, vp_loss[1]
|