ultralytics-opencv-headless 8.3.253__py3-none-any.whl → 8.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. tests/__init__.py +2 -2
  2. tests/conftest.py +1 -1
  3. tests/test_cuda.py +8 -2
  4. tests/test_engine.py +6 -6
  5. tests/test_exports.py +10 -3
  6. tests/test_integrations.py +9 -9
  7. tests/test_python.py +14 -14
  8. tests/test_solutions.py +3 -3
  9. ultralytics/__init__.py +1 -1
  10. ultralytics/cfg/__init__.py +6 -6
  11. ultralytics/cfg/default.yaml +3 -1
  12. ultralytics/cfg/models/26/yolo26-cls.yaml +33 -0
  13. ultralytics/cfg/models/26/yolo26-obb.yaml +52 -0
  14. ultralytics/cfg/models/26/yolo26-p2.yaml +60 -0
  15. ultralytics/cfg/models/26/yolo26-p6.yaml +60 -0
  16. ultralytics/cfg/models/26/yolo26-pose.yaml +53 -0
  17. ultralytics/cfg/models/26/yolo26-seg.yaml +52 -0
  18. ultralytics/cfg/models/26/yolo26.yaml +52 -0
  19. ultralytics/cfg/models/26/yoloe-26-seg.yaml +53 -0
  20. ultralytics/cfg/models/26/yoloe-26.yaml +53 -0
  21. ultralytics/data/augment.py +7 -0
  22. ultralytics/data/dataset.py +1 -1
  23. ultralytics/engine/exporter.py +10 -3
  24. ultralytics/engine/model.py +1 -1
  25. ultralytics/engine/trainer.py +40 -15
  26. ultralytics/engine/tuner.py +15 -7
  27. ultralytics/models/fastsam/predict.py +1 -1
  28. ultralytics/models/yolo/detect/train.py +3 -2
  29. ultralytics/models/yolo/detect/val.py +6 -0
  30. ultralytics/models/yolo/model.py +1 -1
  31. ultralytics/models/yolo/obb/predict.py +1 -1
  32. ultralytics/models/yolo/obb/train.py +1 -1
  33. ultralytics/models/yolo/pose/train.py +1 -1
  34. ultralytics/models/yolo/segment/predict.py +1 -1
  35. ultralytics/models/yolo/segment/train.py +1 -1
  36. ultralytics/models/yolo/segment/val.py +3 -1
  37. ultralytics/models/yolo/yoloe/train.py +6 -1
  38. ultralytics/models/yolo/yoloe/train_seg.py +6 -1
  39. ultralytics/nn/autobackend.py +7 -3
  40. ultralytics/nn/modules/__init__.py +8 -0
  41. ultralytics/nn/modules/block.py +127 -8
  42. ultralytics/nn/modules/head.py +818 -205
  43. ultralytics/nn/tasks.py +74 -29
  44. ultralytics/nn/text_model.py +5 -2
  45. ultralytics/optim/__init__.py +5 -0
  46. ultralytics/optim/muon.py +338 -0
  47. ultralytics/utils/benchmarks.py +1 -0
  48. ultralytics/utils/callbacks/platform.py +9 -7
  49. ultralytics/utils/downloads.py +3 -1
  50. ultralytics/utils/export/engine.py +19 -10
  51. ultralytics/utils/export/imx.py +22 -11
  52. ultralytics/utils/export/tensorflow.py +1 -41
  53. ultralytics/utils/loss.py +584 -203
  54. ultralytics/utils/metrics.py +1 -0
  55. ultralytics/utils/ops.py +11 -2
  56. ultralytics/utils/tal.py +98 -19
  57. {ultralytics_opencv_headless-8.3.253.dist-info → ultralytics_opencv_headless-8.4.0.dist-info}/METADATA +31 -39
  58. {ultralytics_opencv_headless-8.3.253.dist-info → ultralytics_opencv_headless-8.4.0.dist-info}/RECORD +62 -51
  59. {ultralytics_opencv_headless-8.3.253.dist-info → ultralytics_opencv_headless-8.4.0.dist-info}/WHEEL +0 -0
  60. {ultralytics_opencv_headless-8.3.253.dist-info → ultralytics_opencv_headless-8.4.0.dist-info}/entry_points.txt +0 -0
  61. {ultralytics_opencv_headless-8.3.253.dist-info → ultralytics_opencv_headless-8.4.0.dist-info}/licenses/LICENSE +0 -0
  62. {ultralytics_opencv_headless-8.3.253.dist-info → ultralytics_opencv_headless-8.4.0.dist-info}/top_level.txt +0 -0
ultralytics/utils/loss.py CHANGED
@@ -2,19 +2,20 @@
  
  from __future__ import annotations
  
+ import math
  from typing import Any
  
  import torch
  import torch.nn as nn
  import torch.nn.functional as F
  
- from ultralytics.utils.metrics import OKS_SIGMA
+ from ultralytics.utils.metrics import OKS_SIGMA, RLE_WEIGHT
  from ultralytics.utils.ops import crop_mask, xywh2xyxy, xyxy2xywh
  from ultralytics.utils.tal import RotatedTaskAlignedAssigner, TaskAlignedAssigner, dist2bbox, dist2rbox, make_anchors
  from ultralytics.utils.torch_utils import autocast
  
  from .metrics import bbox_iou, probiou
- from .tal import bbox2dist
+ from .tal import bbox2dist, rbox2dist
  
  
  class VarifocalLoss(nn.Module):
@@ -122,6 +123,8 @@ class BboxLoss(nn.Module):
          target_scores: torch.Tensor,
          target_scores_sum: torch.Tensor,
          fg_mask: torch.Tensor,
+         imgsz: torch.Tensor,
+         stride: torch.Tensor,
      ) -> tuple[torch.Tensor, torch.Tensor]:
          """Compute IoU and DFL losses for bounding boxes."""
          weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
@@ -134,11 +137,75 @@
              loss_dfl = self.dfl_loss(pred_dist[fg_mask].view(-1, self.dfl_loss.reg_max), target_ltrb[fg_mask]) * weight
              loss_dfl = loss_dfl.sum() / target_scores_sum
          else:
-             loss_dfl = torch.tensor(0.0).to(pred_dist.device)
+             target_ltrb = bbox2dist(anchor_points, target_bboxes)
+             # normalize ltrb by image size
+             target_ltrb = target_ltrb * stride
+             target_ltrb[..., 0::2] /= imgsz[1]
+             target_ltrb[..., 1::2] /= imgsz[0]
+             pred_dist = pred_dist * stride
+             pred_dist[..., 0::2] /= imgsz[1]
+             pred_dist[..., 1::2] /= imgsz[0]
+             loss_dfl = (
+                 F.l1_loss(pred_dist[fg_mask], target_ltrb[fg_mask], reduction="none").mean(-1, keepdim=True) * weight
+             )
+             loss_dfl = loss_dfl.sum() / target_scores_sum
  
          return loss_iou, loss_dfl
  
  
+ class RLELoss(nn.Module):
+     """Residual Log-Likelihood Estimation Loss.
+ 
+     Args:
+         use_target_weight (bool): Option to use weighted loss.
+         size_average (bool): Option to average the loss by the batch_size.
+         residual (bool): Option to add L1 loss and let the flow learn the residual error distribution.
+ 
+     References:
+         https://arxiv.org/abs/2107.11291
+     """
+ 
+     def __init__(self, use_target_weight: bool = True, size_average: bool = True, residual: bool = True):
+         """Initialize RLELoss with target weight and residual options.
+ 
+         Args:
+             use_target_weight (bool): Whether to use target weights for loss calculation.
+             size_average (bool): Whether to average the loss over elements.
+             residual (bool): Whether to include residual log-likelihood term.
+         """
+         super().__init__()
+         self.size_average = size_average
+         self.use_target_weight = use_target_weight
+         self.residual = residual
+ 
+     def forward(
+         self, sigma: torch.Tensor, log_phi: torch.Tensor, error: torch.Tensor, target_weight: torch.Tensor = None
+     ) -> torch.Tensor:
+         """
+         Args:
+             sigma (torch.Tensor): Output sigma, shape (N, D).
+             log_phi (torch.Tensor): Output log_phi, shape (N).
+             error (torch.Tensor): Error, shape (N, D).
+             target_weight (torch.Tensor): Weights across different joint types, shape (N).
+         """
+         log_sigma = torch.log(sigma)
+         loss = log_sigma - log_phi.unsqueeze(1)
+ 
+         if self.residual:
+             loss += torch.log(sigma * 2) + torch.abs(error)
+ 
+         if self.use_target_weight:
+             assert target_weight is not None, "'target_weight' should not be None when 'use_target_weight' is True."
+             if target_weight.dim() == 1:
+                 target_weight = target_weight.unsqueeze(1)
+             loss *= target_weight
+ 
+         if self.size_average:
+             loss /= len(loss)
+ 
+         return loss.sum()
+ 
+ 
  class RotatedBboxLoss(BboxLoss):
      """Criterion class for computing training losses for rotated bounding boxes."""
  
@@ -155,6 +222,8 @@ class RotatedBboxLoss(BboxLoss):
          target_scores: torch.Tensor,
          target_scores_sum: torch.Tensor,
          fg_mask: torch.Tensor,
+         imgsz: torch.Tensor,
+         stride: torch.Tensor,
      ) -> tuple[torch.Tensor, torch.Tensor]:
          """Compute IoU and DFL losses for rotated bounding boxes."""
          weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
@@ -163,15 +232,84 @@ class RotatedBboxLoss(BboxLoss):
  
          # DFL loss
          if self.dfl_loss:
-             target_ltrb = bbox2dist(anchor_points, xywh2xyxy(target_bboxes[..., :4]), self.dfl_loss.reg_max - 1)
+             target_ltrb = rbox2dist(
+                 target_bboxes[..., :4], anchor_points, target_bboxes[..., 4:5], reg_max=self.dfl_loss.reg_max - 1
+             )
              loss_dfl = self.dfl_loss(pred_dist[fg_mask].view(-1, self.dfl_loss.reg_max), target_ltrb[fg_mask]) * weight
              loss_dfl = loss_dfl.sum() / target_scores_sum
          else:
-             loss_dfl = torch.tensor(0.0).to(pred_dist.device)
+             target_ltrb = rbox2dist(target_bboxes[..., :4], anchor_points, target_bboxes[..., 4:5])
+             target_ltrb = target_ltrb * stride
+             target_ltrb[..., 0::2] /= imgsz[1]
+             target_ltrb[..., 1::2] /= imgsz[0]
+             pred_dist = pred_dist * stride
+             pred_dist[..., 0::2] /= imgsz[1]
+             pred_dist[..., 1::2] /= imgsz[0]
+             loss_dfl = (
+                 F.l1_loss(pred_dist[fg_mask], target_ltrb[fg_mask], reduction="none").mean(-1, keepdim=True) * weight
+             )
+             loss_dfl = loss_dfl.sum() / target_scores_sum
  
          return loss_iou, loss_dfl
  
  
+ class MultiChannelDiceLoss(nn.Module):
+     """Criterion class for computing multi-channel Dice losses."""
+ 
+     def __init__(self, smooth: float = 1e-6, reduction: str = "mean"):
+         """Initialize MultiChannelDiceLoss with smoothing and reduction options.
+ 
+         Args:
+             smooth (float): Smoothing factor to avoid division by zero.
+             reduction (str): Reduction method ('mean', 'sum', or 'none').
+         """
+         super().__init__()
+         self.smooth = smooth
+         self.reduction = reduction
+ 
+     def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+         """Calculate multi-channel Dice loss between predictions and targets."""
+         assert pred.size() == target.size(), "the size of predict and target must be equal."
+ 
+         pred = pred.sigmoid()
+         intersection = (pred * target).sum(dim=(2, 3))
+         union = pred.sum(dim=(2, 3)) + target.sum(dim=(2, 3))
+         dice = (2.0 * intersection + self.smooth) / (union + self.smooth)
+         dice_loss = 1.0 - dice
+         dice_loss = dice_loss.mean(dim=1)
+ 
+         if self.reduction == "mean":
+             return dice_loss.mean()
+         elif self.reduction == "sum":
+             return dice_loss.sum()
+         else:
+             return dice_loss
+ 
+ 
+ class BCEDiceLoss(nn.Module):
+     """Criterion class for computing combined BCE and Dice losses."""
+ 
+     def __init__(self, weight_bce: float = 0.5, weight_dice: float = 0.5):
+         """Initialize BCEDiceLoss with BCE and Dice weight factors.
+ 
+         Args:
+             weight_bce (float): Weight factor for BCE loss component.
+             weight_dice (float): Weight factor for Dice loss component.
+         """
+         super().__init__()
+         self.weight_bce = weight_bce
+         self.weight_dice = weight_dice
+         self.bce = nn.BCEWithLogitsLoss()
+         self.dice = MultiChannelDiceLoss(smooth=1)
+ 
+     def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+         """Calculate combined BCE and Dice loss between predictions and targets."""
+         _, _, mask_h, mask_w = pred.shape
+         if tuple(target.shape[-2:]) != (mask_h, mask_w):  # downsample to the same size as pred
+             target = F.interpolate(target, (mask_h, mask_w), mode="nearest")
+         return self.weight_bce * self.bce(pred, target) + self.weight_dice * self.dice(pred, target)
+ 
+ 
  class KeypointLoss(nn.Module):
      """Criterion class for computing keypoint losses."""
  
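Both new criteria above are self-contained nn.Modules, so they can be exercised in isolation. A minimal usage sketch, assuming this wheel (8.4.0) is installed so the class is importable from ultralytics.utils.loss; the shapes are illustrative only, not tied to any model:

import torch

from ultralytics.utils.loss import BCEDiceLoss

criterion = BCEDiceLoss(weight_bce=0.5, weight_dice=0.5)
pred = torch.randn(2, 19, 160, 160)                   # per-class logits at prototype resolution
target = (torch.rand(2, 19, 640, 640) > 0.5).float()  # full-resolution binary per-class masks
loss = criterion(pred, target)  # target is nearest-downsampled to 160x160 inside forward()
print(loss)  # scalar: 0.5 * BCE + 0.5 * mean multi-channel Dice loss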
@@ -194,7 +332,7 @@ class KeypointLoss(nn.Module):
  class v8DetectionLoss:
      """Criterion class for computing training losses for YOLOv8 object detection."""
  
-     def __init__(self, model, tal_topk: int = 10):  # model must be de-paralleled
+     def __init__(self, model, tal_topk: int = 10, tal_topk2: int | None = None):  # model must be de-paralleled
          """Initialize v8DetectionLoss with model parameters and task-aligned assignment settings."""
          device = next(model.parameters()).device  # get model device
          h = model.args  # hyperparameters
@@ -210,7 +348,14 @@
  
          self.use_dfl = m.reg_max > 1
  
-         self.assigner = TaskAlignedAssigner(topk=tal_topk, num_classes=self.nc, alpha=0.5, beta=6.0)
+         self.assigner = TaskAlignedAssigner(
+             topk=tal_topk,
+             num_classes=self.nc,
+             alpha=0.5,
+             beta=6.0,
+             stride=self.stride.tolist(),
+             topk2=tal_topk2,
+         )
          self.bbox_loss = BboxLoss(m.reg_max).to(device)
          self.proj = torch.arange(m.reg_max, dtype=torch.float, device=device)
  
@@ -240,35 +385,31 @@
          # pred_dist = (pred_dist.view(b, a, c // 4, 4).softmax(2) * self.proj.type(pred_dist.dtype).view(1, 1, -1, 1)).sum(2)
          return dist2bbox(pred_dist, anchor_points, xywh=False)
  
-     def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
-         """Calculate the sum of the loss for box, cls and dfl multiplied by batch size."""
+     def get_assigned_targets_and_loss(self, preds: dict[str, torch.Tensor], batch: dict[str, Any]) -> tuple:
+         """Calculate the sum of the loss for box, cls and dfl multiplied by batch size and return foreground mask and
+         target indices.
+         """
          loss = torch.zeros(3, device=self.device)  # box, cls, dfl
-         feats = preds[1] if isinstance(preds, tuple) else preds
-         pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
-             (self.reg_max * 4, self.nc), 1
+         pred_distri, pred_scores = (
+             preds["boxes"].permute(0, 2, 1).contiguous(),
+             preds["scores"].permute(0, 2, 1).contiguous(),
          )
- 
-         pred_scores = pred_scores.permute(0, 2, 1).contiguous()
-         pred_distri = pred_distri.permute(0, 2, 1).contiguous()
+         anchor_points, stride_tensor = make_anchors(preds["feats"], self.stride, 0.5)
  
          dtype = pred_scores.dtype
          batch_size = pred_scores.shape[0]
-         imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]  # image size (h,w)
-         anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
+         imgsz = torch.tensor(preds["feats"][0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]
  
          # Targets
          targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
-         targets = self.preprocess(targets, batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
+         targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
          gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
          mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
  
          # Pboxes
          pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)
-         # dfl_conf = pred_distri.view(batch_size, -1, 4, self.reg_max).detach().softmax(-1)
-         # dfl_conf = (dfl_conf.amax(-1).mean(-1) + dfl_conf.amax(-1).amin(-1)) / 2
  
-         _, target_bboxes, target_scores, fg_mask, _ = self.assigner(
-             # pred_scores.detach().sigmoid() * 0.8 + dfl_conf.unsqueeze(-1) * 0.2,
+         _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
              pred_scores.detach().sigmoid(),
              (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
              anchor_points * stride_tensor,
@@ -280,7 +421,6 @@
  
          target_scores_sum = max(target_scores.sum(), 1)
  
          # Cls loss
-         # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum  # VFL way
          loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE
  
          # Bbox loss
@@ -293,105 +433,97 @@
              target_scores,
              target_scores_sum,
              fg_mask,
+             imgsz,
+             stride_tensor,
          )
  
          loss[0] *= self.hyp.box  # box gain
          loss[1] *= self.hyp.cls  # cls gain
          loss[2] *= self.hyp.dfl  # dfl gain
+         return (
+             (fg_mask, target_gt_idx, target_bboxes, anchor_points, stride_tensor),
+             loss,
+             loss.detach(),
+         )  # loss(box, cls, dfl)
  
-         return loss * batch_size, loss.detach()  # loss(box, cls, dfl)
+     def parse_output(
+         self, preds: dict[str, torch.Tensor] | tuple[torch.Tensor, dict[str, torch.Tensor]]
+     ) -> torch.Tensor:
+         """Parse model predictions to extract features."""
+         return preds[1] if isinstance(preds, tuple) else preds
+ 
+     def __call__(
+         self,
+         preds: dict[str, torch.Tensor] | tuple[torch.Tensor, dict[str, torch.Tensor]],
+         batch: dict[str, torch.Tensor],
+     ) -> tuple[torch.Tensor, torch.Tensor]:
+         """Calculate the sum of the loss for box, cls and dfl multiplied by batch size."""
+         return self.loss(self.parse_output(preds), batch)
+ 
+     def loss(self, preds: dict[str, torch.Tensor], batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+         """A wrapper for get_assigned_targets_and_loss and parse_output."""
+         batch_size = preds["boxes"].shape[0]
+         loss, loss_detach = self.get_assigned_targets_and_loss(preds, batch)[1:]
+         return loss * batch_size, loss_detach
  
  
  class v8SegmentationLoss(v8DetectionLoss):
      """Criterion class for computing training losses for YOLOv8 segmentation."""
  
-     def __init__(self, model):  # model must be de-paralleled
+     def __init__(self, model, tal_topk: int = 10, tal_topk2: int | None = None):  # model must be de-paralleled
          """Initialize the v8SegmentationLoss class with model parameters and mask overlap setting."""
-         super().__init__(model)
+         super().__init__(model, tal_topk, tal_topk2)
          self.overlap = model.args.overlap_mask
+         self.bcedice_loss = BCEDiceLoss(weight_bce=0.5, weight_dice=0.5)
  
-     def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+     def loss(self, preds: dict[str, torch.Tensor], batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
          """Calculate and return the combined loss for detection and segmentation."""
-         loss = torch.zeros(4, device=self.device)  # box, seg, cls, dfl
-         feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
-         batch_size, _, mask_h, mask_w = proto.shape  # batch size, number of masks, mask height, mask width
-         pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
-             (self.reg_max * 4, self.nc), 1
-         )
- 
-         # B, grids, ..
-         pred_scores = pred_scores.permute(0, 2, 1).contiguous()
-         pred_distri = pred_distri.permute(0, 2, 1).contiguous()
-         pred_masks = pred_masks.permute(0, 2, 1).contiguous()
- 
-         dtype = pred_scores.dtype
-         imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]  # image size (h,w)
-         anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
- 
-         # Targets
-         try:
-             batch_idx = batch["batch_idx"].view(-1, 1)
-             targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)
-             targets = self.preprocess(targets, batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
-             gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
-             mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
-         except RuntimeError as e:
-             raise TypeError(
-                 "ERROR ❌ segment dataset incorrectly formatted or not a segment dataset.\n"
-                 "This error can occur when incorrectly training a 'segment' model on a 'detect' dataset, "
-                 "i.e. 'yolo train model=yolo11n-seg.pt data=coco8.yaml'.\nVerify your dataset is a "
-                 "correctly formatted 'segment' dataset using 'data=coco8-seg.yaml' "
-                 "as an example.\nSee https://docs.ultralytics.com/datasets/segment/ for help."
-             ) from e
- 
-         # Pboxes
-         pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)
- 
-         _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
-             pred_scores.detach().sigmoid(),
-             (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
-             anchor_points * stride_tensor,
-             gt_labels,
-             gt_bboxes,
-             mask_gt,
-         )
- 
-         target_scores_sum = max(target_scores.sum(), 1)
- 
-         # Cls loss
-         # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum  # VFL way
-         loss[2] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE
+         pred_masks, proto = preds["mask_coefficient"].permute(0, 2, 1).contiguous(), preds["proto"]
+         loss = torch.zeros(5, device=self.device)  # box, seg, cls, dfl, semseg
+         if len(proto) == 2:
+             proto, pred_semseg = proto
+         else:
+             pred_semseg = None
+         (fg_mask, target_gt_idx, target_bboxes, _, _), det_loss, _ = self.get_assigned_targets_and_loss(preds, batch)
+         # NOTE: re-assign index for consistency for now. Need to be removed in the future.
+         loss[0], loss[2], loss[3] = det_loss[0], det_loss[1], det_loss[2]
  
+         batch_size, _, mask_h, mask_w = proto.shape  # batch size, number of masks, mask height, mask width
          if fg_mask.sum():
-             # Bbox loss
-             loss[0], loss[3] = self.bbox_loss(
-                 pred_distri,
-                 pred_bboxes,
-                 anchor_points,
-                 target_bboxes / stride_tensor,
-                 target_scores,
-                 target_scores_sum,
-                 fg_mask,
-             )
              # Masks loss
              masks = batch["masks"].to(self.device).float()
              if tuple(masks.shape[-2:]) != (mask_h, mask_w):  # downsample
-                 masks = F.interpolate(masks[None], (mask_h, mask_w), mode="nearest")[0]
+                 # masks = F.interpolate(masks[None], (mask_h, mask_w), mode="nearest")[0]
+                 proto = F.interpolate(proto, masks.shape[-2:], mode="bilinear", align_corners=False)
  
+             imgsz = (
+                 torch.tensor(preds["feats"][0].shape[2:], device=self.device, dtype=pred_masks.dtype) * self.stride[0]
+             )
              loss[1] = self.calculate_segmentation_loss(
-                 fg_mask, masks, target_gt_idx, target_bboxes, batch_idx, proto, pred_masks, imgsz, self.overlap
+                 fg_mask,
+                 masks,
+                 target_gt_idx,
+                 target_bboxes,
+                 batch["batch_idx"].view(-1, 1),
+                 proto,
+                 pred_masks,
+                 imgsz,
              )
+             if pred_semseg is not None:
+                 sem_masks = batch["sem_masks"].to(self.device)  # NxHxW
+                 mask_zero = sem_masks == 0  # NxHxW
+                 sem_masks = F.one_hot(sem_masks.long(), num_classes=self.nc).permute(0, 3, 1, 2).float()  # NxCxHxW
+                 sem_masks[mask_zero.unsqueeze(1).expand_as(sem_masks)] = 0
+                 loss[4] = self.bcedice_loss(pred_semseg, sem_masks)
+                 loss[4] *= self.hyp.box  # seg gain
  
          # WARNING: lines below prevent Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
          else:
              loss[1] += (proto * 0).sum() + (pred_masks * 0).sum()  # inf sums may lead to nan loss
+             loss[4] += (pred_semseg * 0).sum() + (sem_masks * 0).sum()
  
-         loss[0] *= self.hyp.box  # box gain
          loss[1] *= self.hyp.box  # seg gain
-         loss[2] *= self.hyp.cls  # cls gain
-         loss[3] *= self.hyp.dfl  # dfl gain
- 
-         return loss * batch_size, loss.detach()  # loss(box, seg, cls, dfl)
+         return loss * batch_size, loss.detach()  # loss(box, seg, cls, dfl, semseg)
  
  
      @staticmethod
@@ -427,7 +559,6 @@ class v8SegmentationLoss(v8DetectionLoss):
          proto: torch.Tensor,
          pred_masks: torch.Tensor,
          imgsz: torch.Tensor,
-         overlap: bool,
      ) -> torch.Tensor:
          """Calculate the loss for instance segmentation.
  
@@ -440,7 +571,6 @@ class v8SegmentationLoss(v8DetectionLoss):
              proto (torch.Tensor): Prototype masks of shape (BS, 32, H, W).
              pred_masks (torch.Tensor): Predicted masks for each anchor of shape (BS, N_anchors, 32).
              imgsz (torch.Tensor): Size of the input image as a tensor of shape (2), i.e., (H, W).
-             overlap (bool): Whether the masks in `masks` tensor overlap.
  
          Returns:
              (torch.Tensor): The calculated loss for instance segmentation.
@@ -466,7 +596,7 @@
              fg_mask_i, target_gt_idx_i, pred_masks_i, proto_i, mxyxy_i, marea_i, masks_i = single_i
              if fg_mask_i.any():
                  mask_idx = target_gt_idx_i[fg_mask_i]
-                 if overlap:
+                 if self.overlap:
                      gt_mask = masks_i == (mask_idx + 1).view(-1, 1, 1)
                      gt_mask = gt_mask.float()
                  else:
@@ -486,9 +616,9 @@
  class v8PoseLoss(v8DetectionLoss):
      """Criterion class for computing training losses for YOLOv8 pose estimation."""
  
-     def __init__(self, model):  # model must be de-paralleled
+     def __init__(self, model, tal_topk: int = 10, tal_topk2: int = 10):  # model must be de-paralleled
          """Initialize v8PoseLoss with model parameters and keypoint-specific loss functions."""
-         super().__init__(model)
+         super().__init__(model, tal_topk, tal_topk2)
          self.kpt_shape = model.model[-1].kpt_shape
          self.bce_pose = nn.BCEWithLogitsLoss()
          is_pose = self.kpt_shape == [17, 3]
@@ -496,69 +626,40 @@
          sigmas = torch.from_numpy(OKS_SIGMA).to(self.device) if is_pose else torch.ones(nkpt, device=self.device) / nkpt
          self.keypoint_loss = KeypointLoss(sigmas=sigmas)
  
-     def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+     def loss(self, preds: dict[str, torch.Tensor], batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
          """Calculate the total loss and detach it for pose estimation."""
-         loss = torch.zeros(5, device=self.device)  # box, pose, kobj, cls, dfl
-         feats, pred_kpts = preds if isinstance(preds[0], list) else preds[1]
-         pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
-             (self.reg_max * 4, self.nc), 1
+         pred_kpts = preds["kpts"].permute(0, 2, 1).contiguous()
+         loss = torch.zeros(5, device=self.device)  # box, pose, kobj, cls, dfl
+         (fg_mask, target_gt_idx, target_bboxes, anchor_points, stride_tensor), det_loss, _ = (
+             self.get_assigned_targets_and_loss(preds, batch)
          )
+         # NOTE: re-assign index for consistency for now. Need to be removed in the future.
+         loss[0], loss[3], loss[4] = det_loss[0], det_loss[1], det_loss[2]
  
-         # B, grids, ..
-         pred_scores = pred_scores.permute(0, 2, 1).contiguous()
-         pred_distri = pred_distri.permute(0, 2, 1).contiguous()
-         pred_kpts = pred_kpts.permute(0, 2, 1).contiguous()
- 
-         dtype = pred_scores.dtype
-         imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]  # image size (h,w)
-         anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
- 
-         # Targets
-         batch_size = pred_scores.shape[0]
-         batch_idx = batch["batch_idx"].view(-1, 1)
-         targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)
-         targets = self.preprocess(targets, batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
-         gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
-         mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
+         batch_size = pred_kpts.shape[0]
+         imgsz = torch.tensor(preds["feats"][0].shape[2:], device=self.device, dtype=pred_kpts.dtype) * self.stride[0]
  
          # Pboxes
-         pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)
          pred_kpts = self.kpts_decode(anchor_points, pred_kpts.view(batch_size, -1, *self.kpt_shape))  # (b, h*w, 17, 3)
  
-         _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
-             pred_scores.detach().sigmoid(),
-             (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
-             anchor_points * stride_tensor,
-             gt_labels,
-             gt_bboxes,
-             mask_gt,
-         )
- 
-         target_scores_sum = max(target_scores.sum(), 1)
- 
-         # Cls loss
-         # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum  # VFL way
-         loss[3] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE
- 
          # Bbox loss
          if fg_mask.sum():
-             target_bboxes /= stride_tensor
-             loss[0], loss[4] = self.bbox_loss(
-                 pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
-             )
              keypoints = batch["keypoints"].to(self.device).float().clone()
              keypoints[..., 0] *= imgsz[1]
              keypoints[..., 1] *= imgsz[0]
  
              loss[1], loss[2] = self.calculate_keypoints_loss(
-                 fg_mask, target_gt_idx, keypoints, batch_idx, stride_tensor, target_bboxes, pred_kpts
+                 fg_mask,
+                 target_gt_idx,
+                 keypoints,
+                 batch["batch_idx"].view(-1, 1),
+                 stride_tensor,
+                 target_bboxes,
+                 pred_kpts,
              )
  
-         loss[0] *= self.hyp.box  # box gain
          loss[1] *= self.hyp.pose  # pose gain
          loss[2] *= self.hyp.kobj  # kobj gain
-         loss[3] *= self.hyp.cls  # cls gain
-         loss[4] *= self.hyp.dfl  # dfl gain
  
          return loss * batch_size, loss.detach()  # loss(box, pose, kobj, cls, dfl)
  
@@ -571,34 +672,23 @@
          y[..., 1] += anchor_points[:, [1]] - 0.5
          return y
  
-     def calculate_keypoints_loss(
+     def _select_target_keypoints(
          self,
-         masks: torch.Tensor,
-         target_gt_idx: torch.Tensor,
          keypoints: torch.Tensor,
          batch_idx: torch.Tensor,
-         stride_tensor: torch.Tensor,
-         target_bboxes: torch.Tensor,
-         pred_kpts: torch.Tensor,
-     ) -> tuple[torch.Tensor, torch.Tensor]:
-         """Calculate the keypoints loss for the model.
- 
-         This function calculates the keypoints loss and keypoints object loss for a given batch. The keypoints loss is
-         based on the difference between the predicted keypoints and ground truth keypoints. The keypoints object loss is
-         a binary classification loss that classifies whether a keypoint is present or not.
+         target_gt_idx: torch.Tensor,
+         masks: torch.Tensor,
+     ) -> torch.Tensor:
+         """Select target keypoints for each anchor based on batch index and target ground truth index.
  
          Args:
-             masks (torch.Tensor): Binary mask tensor indicating object presence, shape (BS, N_anchors).
-             target_gt_idx (torch.Tensor): Index tensor mapping anchors to ground truth objects, shape (BS, N_anchors).
              keypoints (torch.Tensor): Ground truth keypoints, shape (N_kpts_in_batch, N_kpts_per_object, kpts_dim).
              batch_idx (torch.Tensor): Batch index tensor for keypoints, shape (N_kpts_in_batch, 1).
-             stride_tensor (torch.Tensor): Stride tensor for anchors, shape (N_anchors, 1).
-             target_bboxes (torch.Tensor): Ground truth boxes in (x1, y1, x2, y2) format, shape (BS, N_anchors, 4).
-             pred_kpts (torch.Tensor): Predicted keypoints, shape (BS, N_anchors, N_kpts_per_object, kpts_dim).
+             target_gt_idx (torch.Tensor): Index tensor mapping anchors to ground truth objects, shape (BS, N_anchors).
+             masks (torch.Tensor): Binary mask tensor indicating object presence, shape (BS, N_anchors).
  
          Returns:
-             kpts_loss (torch.Tensor): The keypoints loss.
-             kpts_obj_loss (torch.Tensor): The keypoints object loss.
+             (torch.Tensor): Selected keypoints tensor, shape (BS, N_anchors, N_kpts_per_object, kpts_dim).
          """
          batch_idx = batch_idx.flatten()
          batch_size = len(masks)
@@ -625,6 +715,40 @@
              1, target_gt_idx_expanded.expand(-1, -1, keypoints.shape[1], keypoints.shape[2])
          )
  
+         return selected_keypoints
+ 
+     def calculate_keypoints_loss(
+         self,
+         masks: torch.Tensor,
+         target_gt_idx: torch.Tensor,
+         keypoints: torch.Tensor,
+         batch_idx: torch.Tensor,
+         stride_tensor: torch.Tensor,
+         target_bboxes: torch.Tensor,
+         pred_kpts: torch.Tensor,
+     ) -> tuple[torch.Tensor, torch.Tensor]:
+         """Calculate the keypoints loss for the model.
+ 
+         This function calculates the keypoints loss and keypoints object loss for a given batch. The keypoints loss is
+         based on the difference between the predicted keypoints and ground truth keypoints. The keypoints object loss is
+         a binary classification loss that classifies whether a keypoint is present or not.
+ 
+         Args:
+             masks (torch.Tensor): Binary mask tensor indicating object presence, shape (BS, N_anchors).
+             target_gt_idx (torch.Tensor): Index tensor mapping anchors to ground truth objects, shape (BS, N_anchors).
+             keypoints (torch.Tensor): Ground truth keypoints, shape (N_kpts_in_batch, N_kpts_per_object, kpts_dim).
+             batch_idx (torch.Tensor): Batch index tensor for keypoints, shape (N_kpts_in_batch, 1).
+             stride_tensor (torch.Tensor): Stride tensor for anchors, shape (N_anchors, 1).
+             target_bboxes (torch.Tensor): Ground truth boxes in (x1, y1, x2, y2) format, shape (BS, N_anchors, 4).
+             pred_kpts (torch.Tensor): Predicted keypoints, shape (BS, N_anchors, N_kpts_per_object, kpts_dim).
+ 
+         Returns:
+             kpts_loss (torch.Tensor): The keypoints loss.
+             kpts_obj_loss (torch.Tensor): The keypoints object loss.
+         """
+         # Select target keypoints using helper method
+         selected_keypoints = self._select_target_keypoints(keypoints, batch_idx, target_gt_idx, masks)
+ 
          # Divide coordinates by stride
          selected_keypoints[..., :2] /= stride_tensor.view(1, -1, 1, 1)
  
@@ -632,6 +756,7 @@
          kpts_obj_loss = 0
  
          if masks.any():
+             target_bboxes /= stride_tensor
              gt_kpt = selected_keypoints[masks]
              area = xyxy2xywh(target_bboxes[masks])[:, 2:].prod(1, keepdim=True)
              pred_kpt = pred_kpts[masks]
@@ -644,6 +769,171 @@
          return kpts_loss, kpts_obj_loss
  
  
+ class PoseLoss26(v8PoseLoss):
+     """Criterion class for computing training losses for YOLOv8 pose estimation with RLE loss support."""
+ 
+     def __init__(self, model, tal_topk: int = 10, tal_topk2: int | None = None):  # model must be de-paralleled
+         """Initialize PoseLoss26 with model parameters and keypoint-specific loss functions including RLE loss."""
+         super().__init__(model, tal_topk, tal_topk2)
+         is_pose = self.kpt_shape == [17, 3]
+         nkpt = self.kpt_shape[0]  # number of keypoints
+         self.rle_loss = None
+         self.flow_model = model.model[-1].flow_model if hasattr(model.model[-1], "flow_model") else None
+         if self.flow_model is not None:
+             self.rle_loss = RLELoss(use_target_weight=True).to(self.device)
+             self.target_weights = (
+                 torch.from_numpy(RLE_WEIGHT).to(self.device) if is_pose else torch.ones(nkpt, device=self.device)
+             )
+ 
+     def loss(self, preds: dict[str, torch.Tensor], batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+         """Calculate the total loss and detach it for pose estimation."""
+         pred_kpts = preds["kpts"].permute(0, 2, 1).contiguous()
+         loss = torch.zeros(6 if self.rle_loss else 5, device=self.device)  # box, pose, kobj, cls, dfl (+ rle)
+         (fg_mask, target_gt_idx, target_bboxes, anchor_points, stride_tensor), det_loss, _ = (
+             self.get_assigned_targets_and_loss(preds, batch)
+         )
+         # NOTE: re-assign index for consistency for now. Need to be removed in the future.
+         loss[0], loss[3], loss[4] = det_loss[0], det_loss[1], det_loss[2]
+ 
+         batch_size = pred_kpts.shape[0]
+         imgsz = torch.tensor(batch["resized_shape"][0], device=self.device, dtype=pred_kpts.dtype)  # image size (h,w)
+ 
+         pred_kpts = pred_kpts.view(batch_size, -1, *self.kpt_shape)  # (b, h*w, 17, 3)
+ 
+         if self.rle_loss and preds.get("kpts_sigma", None) is not None:
+             pred_sigma = preds["kpts_sigma"].permute(0, 2, 1).contiguous()
+             pred_sigma = pred_sigma.view(batch_size, -1, self.kpt_shape[0], 2)  # (b, h*w, 17, 2)
+             pred_kpts = torch.cat([pred_kpts, pred_sigma], dim=-1)  # (b, h*w, 17, 5)
+ 
+         pred_kpts = self.kpts_decode(anchor_points, pred_kpts)
+ 
+         # Keypoint loss
+         if fg_mask.sum():
+             keypoints = batch["keypoints"].to(self.device).float().clone()
+             keypoints[..., 0] *= imgsz[1]
+             keypoints[..., 1] *= imgsz[0]
+ 
+             keypoints_loss = self.calculate_keypoints_loss(
+                 fg_mask,
+                 target_gt_idx,
+                 keypoints,
+                 batch["batch_idx"].view(-1, 1),
+                 stride_tensor,
+                 target_bboxes,
+                 pred_kpts,
+             )
+             loss[1] = keypoints_loss[0]
+             loss[2] = keypoints_loss[1]
+             if self.rle_loss is not None:
+                 loss[5] = keypoints_loss[2]
+ 
+         loss[1] *= self.hyp.pose  # pose gain
+         loss[2] *= self.hyp.kobj  # kobj gain
+         if self.rle_loss is not None:
+             loss[5] *= self.hyp.rle  # rle gain
+ 
+         return loss * batch_size, loss.detach()  # loss(box, pose, kobj, cls, dfl, rle)
+ 
+     @staticmethod
+     def kpts_decode(anchor_points: torch.Tensor, pred_kpts: torch.Tensor) -> torch.Tensor:
+         """Decode predicted keypoints to image coordinates."""
+         y = pred_kpts.clone()
+         y[..., 0] += anchor_points[:, [0]]
+         y[..., 1] += anchor_points[:, [1]]
+         return y
+ 
+     def calculate_rle_loss(self, pred_kpt: torch.Tensor, gt_kpt: torch.Tensor, kpt_mask: torch.Tensor) -> torch.Tensor:
+         """Calculate the RLE (Residual Log-likelihood Estimation) loss for keypoints.
+ 
+         Args:
+             pred_kpt (torch.Tensor): Predicted keypoints with sigma, shape (N, kpts_dim) where kpts_dim >= 4.
+             gt_kpt (torch.Tensor): Ground truth keypoints, shape (N, kpts_dim).
+             kpt_mask (torch.Tensor): Mask for valid keypoints, shape (N, num_keypoints).
+ 
+         Returns:
+             (torch.Tensor): The RLE loss.
+         """
+         pred_kpt_visible = pred_kpt[kpt_mask]
+         gt_kpt_visible = gt_kpt[kpt_mask]
+         pred_coords = pred_kpt_visible[:, 0:2]
+         pred_sigma = pred_kpt_visible[:, -2:]
+         gt_coords = gt_kpt_visible[:, 0:2]
+ 
+         target_weights = self.target_weights.unsqueeze(0).repeat(kpt_mask.shape[0], 1)
+         target_weights = target_weights[kpt_mask]
+ 
+         pred_sigma = pred_sigma.sigmoid()
+         error = (pred_coords - gt_coords) / (pred_sigma + 1e-9)
+ 
+         # Filter out NaN values to prevent MultivariateNormal validation errors (can occur with small images)
+         valid_mask = ~torch.isnan(error).any(dim=-1)
+         if not valid_mask.any():
+             return torch.tensor(0.0, device=pred_kpt.device)
+ 
+         error = error[valid_mask]
+         pred_sigma = pred_sigma[valid_mask]
+         target_weights = target_weights[valid_mask]
+ 
+         log_phi = self.flow_model.log_prob(error)
+ 
+         return self.rle_loss(pred_sigma, log_phi, error, target_weights)
+ 
+     def calculate_keypoints_loss(
+         self,
+         masks: torch.Tensor,
+         target_gt_idx: torch.Tensor,
+         keypoints: torch.Tensor,
+         batch_idx: torch.Tensor,
+         stride_tensor: torch.Tensor,
+         target_bboxes: torch.Tensor,
+         pred_kpts: torch.Tensor,
+     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         """Calculate the keypoints loss for the model.
+ 
+         This function calculates the keypoints loss and keypoints object loss for a given batch. The keypoints loss is
+         based on the difference between the predicted keypoints and ground truth keypoints. The keypoints object loss is
+         a binary classification loss that classifies whether a keypoint is present or not.
+ 
+         Args:
+             masks (torch.Tensor): Binary mask tensor indicating object presence, shape (BS, N_anchors).
+             target_gt_idx (torch.Tensor): Index tensor mapping anchors to ground truth objects, shape (BS, N_anchors).
+             keypoints (torch.Tensor): Ground truth keypoints, shape (N_kpts_in_batch, N_kpts_per_object, kpts_dim).
+             batch_idx (torch.Tensor): Batch index tensor for keypoints, shape (N_kpts_in_batch, 1).
+             stride_tensor (torch.Tensor): Stride tensor for anchors, shape (N_anchors, 1).
+             target_bboxes (torch.Tensor): Ground truth boxes in (x1, y1, x2, y2) format, shape (BS, N_anchors, 4).
+             pred_kpts (torch.Tensor): Predicted keypoints, shape (BS, N_anchors, N_kpts_per_object, kpts_dim).
+ 
+         Returns:
+             kpts_loss (torch.Tensor): The keypoints loss.
+             kpts_obj_loss (torch.Tensor): The keypoints object loss.
+             rle_loss (torch.Tensor): The RLE loss.
+         """
+         # Select target keypoints using inherited helper method
+         selected_keypoints = self._select_target_keypoints(keypoints, batch_idx, target_gt_idx, masks)
+ 
+         # Divide coordinates by stride
+         selected_keypoints[..., :2] /= stride_tensor.view(1, -1, 1, 1)
+ 
+         kpts_loss = 0
+         kpts_obj_loss = 0
+         rle_loss = 0
+ 
+         if masks.any():
+             target_bboxes /= stride_tensor
+             gt_kpt = selected_keypoints[masks]
+             area = xyxy2xywh(target_bboxes[masks])[:, 2:].prod(1, keepdim=True)
+             pred_kpt = pred_kpts[masks]
+             kpt_mask = gt_kpt[..., 2] != 0 if gt_kpt.shape[-1] == 3 else torch.full_like(gt_kpt[..., 0], True)
+             kpts_loss = self.keypoint_loss(pred_kpt, gt_kpt, kpt_mask, area)  # pose loss
+ 
+             if self.rle_loss is not None and (pred_kpt.shape[-1] == 4 or pred_kpt.shape[-1] == 5):
+                 rle_loss = self.calculate_rle_loss(pred_kpt, gt_kpt, kpt_mask)
+             if pred_kpt.shape[-1] == 3 or pred_kpt.shape[-1] == 5:
+                 kpts_obj_loss = self.bce_pose(pred_kpt[..., 2], kpt_mask.float())  # keypoint obj loss
+ 
+         return kpts_loss, kpts_obj_loss, rle_loss
+ 
+ 
  class v8ClassificationLoss:
      """Criterion class for computing training losses for classification."""
  
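For readers tracking tensor shapes through the PoseLoss26 hunk above: when a flow model is attached, the sigma head is concatenated onto the keypoint predictions, growing the last dimension from 3 to 5. A small illustrative walkthrough (the batch and anchor counts are assumptions, not values from the diff):

import torch

b, n, nkpt = 2, 8400, 17
kpts = torch.randn(b, n, nkpt, 3)         # x, y, visibility per keypoint
sigma = torch.randn(b, n, nkpt, 2)        # per-coordinate sigma, present only with a flow model
kpts5 = torch.cat([kpts, sigma], dim=-1)  # (b, n, 17, 5): x, y, vis, sigma_x, sigma_y
# Downstream, calculate_keypoints_loss slices [..., :2] for location, [..., 2] for
# visibility, and [:, -2:] (after sigmoid) to standardize the RLE error.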
@@ -657,10 +947,17 @@
  class v8OBBLoss(v8DetectionLoss):
      """Calculates losses for object detection, classification, and box distribution in rotated YOLO models."""
  
-     def __init__(self, model):
+     def __init__(self, model, tal_topk=10, tal_topk2: int | None = None):
          """Initialize v8OBBLoss with model, assigner, and rotated bbox loss; model must be de-paralleled."""
-         super().__init__(model)
-         self.assigner = RotatedTaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0)
+         super().__init__(model, tal_topk=tal_topk)
+         self.assigner = RotatedTaskAlignedAssigner(
+             topk=tal_topk,
+             num_classes=self.nc,
+             alpha=0.5,
+             beta=6.0,
+             stride=self.stride.tolist(),
+             topk2=tal_topk2,
+         )
          self.bbox_loss = RotatedBboxLoss(self.reg_max).to(self.device)
  
      def preprocess(self, targets: torch.Tensor, batch_size: int, scale_tensor: torch.Tensor) -> torch.Tensor:
@@ -680,23 +977,19 @@
                  out[j, :n] = torch.cat([targets[matches, 1:2], bboxes], dim=-1)
          return out
  
-     def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+     def loss(self, preds: dict[str, torch.Tensor], batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
          """Calculate and return the loss for oriented bounding box detection."""
-         loss = torch.zeros(3, device=self.device)  # box, cls, dfl
-         feats, pred_angle = preds if isinstance(preds[0], list) else preds[1]
-         batch_size = pred_angle.shape[0]  # batch size
-         pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
-             (self.reg_max * 4, self.nc), 1
+         loss = torch.zeros(4, device=self.device)  # box, cls, dfl, angle
+         pred_distri, pred_scores, pred_angle = (
+             preds["boxes"].permute(0, 2, 1).contiguous(),
+             preds["scores"].permute(0, 2, 1).contiguous(),
+             preds["angle"].permute(0, 2, 1).contiguous(),
          )
- 
-         # b, grids, ..
-         pred_scores = pred_scores.permute(0, 2, 1).contiguous()
-         pred_distri = pred_distri.permute(0, 2, 1).contiguous()
-         pred_angle = pred_angle.permute(0, 2, 1).contiguous()
+         anchor_points, stride_tensor = make_anchors(preds["feats"], self.stride, 0.5)
+         batch_size = pred_angle.shape[0]  # batch size
  
          dtype = pred_scores.dtype
-         imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]  # image size (h,w)
-         anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
+         imgsz = torch.tensor(batch["resized_shape"][0], device=self.device, dtype=dtype)  # image size (h,w)
  
          # targets
          try:
@@ -704,14 +997,14 @@
              targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"].view(-1, 5)), 1)
              rw, rh = targets[:, 4] * float(imgsz[1]), targets[:, 5] * float(imgsz[0])
              targets = targets[(rw >= 2) & (rh >= 2)]  # filter rboxes of tiny size to stabilize training
-             targets = self.preprocess(targets, batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
+             targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
              gt_labels, gt_bboxes = targets.split((1, 5), 2)  # cls, xywhr
              mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
          except RuntimeError as e:
              raise TypeError(
                  "ERROR ❌ OBB dataset incorrectly formatted or not a OBB dataset.\n"
                  "This error can occur when incorrectly training a 'OBB' model on a 'detect' dataset, "
-                 "i.e. 'yolo train model=yolo11n-obb.pt data=coco8.yaml'.\nVerify your dataset is a "
+                 "i.e. 'yolo train model=yolo11n-obb.pt data=dota8.yaml'.\nVerify your dataset is a "
                  "correctly formatted 'OBB' dataset using 'data=dota8.yaml' "
                  "as an example.\nSee https://docs.ultralytics.com/datasets/obb/ for help."
              ) from e
@@ -741,16 +1034,29 @@
          if fg_mask.sum():
              target_bboxes[..., :4] /= stride_tensor
              loss[0], loss[2] = self.bbox_loss(
-                 pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
+                 pred_distri,
+                 pred_bboxes,
+                 anchor_points,
+                 target_bboxes,
+                 target_scores,
+                 target_scores_sum,
+                 fg_mask,
+                 imgsz,
+                 stride_tensor,
              )
+             weight = target_scores.sum(-1)[fg_mask]
+             loss[3] = self.calculate_angle_loss(
+                 pred_bboxes, target_bboxes, fg_mask, weight, target_scores_sum
+             )  # angle loss
          else:
              loss[0] += (pred_angle * 0).sum()
  
          loss[0] *= self.hyp.box  # box gain
          loss[1] *= self.hyp.cls  # cls gain
          loss[2] *= self.hyp.dfl  # dfl gain
+         loss[3] *= self.hyp.angle  # angle gain
  
-         return loss * batch_size, loss.detach()  # loss(box, cls, dfl)
+         return loss * batch_size, loss.detach()  # loss(box, cls, dfl, angle)
  
      def bbox_decode(
          self, anchor_points: torch.Tensor, pred_dist: torch.Tensor, pred_angle: torch.Tensor
@@ -770,6 +1076,34 @@
              pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))
          return torch.cat((dist2rbox(pred_dist, pred_angle, anchor_points), pred_angle), dim=-1)
  
+     def calculate_angle_loss(self, pred_bboxes, target_bboxes, fg_mask, weight, target_scores_sum, lambda_val=3):
+         """Calculate oriented angle loss.
+ 
+         Args:
+             pred_bboxes: [N, 5] (x, y, w, h, theta).
+             target_bboxes: [N, 5] (x, y, w, h, theta).
+             fg_mask: Foreground mask indicating valid predictions.
+             weight: Loss weights for each prediction.
+             target_scores_sum: Sum of target scores for normalization.
+             lambda_val: Controls the sensitivity to aspect ratio.
+         """
+         w_gt = target_bboxes[..., 2]
+         h_gt = target_bboxes[..., 3]
+         pred_theta = pred_bboxes[..., 4]
+         target_theta = target_bboxes[..., 4]
+ 
+         log_ar = torch.log(w_gt / h_gt)
+         scale_weight = torch.exp(-(log_ar**2) / (lambda_val**2))
+ 
+         delta_theta = pred_theta - target_theta
+         delta_theta_wrapped = delta_theta - torch.round(delta_theta / math.pi) * math.pi
+         ang_loss = torch.sin(2 * delta_theta_wrapped[fg_mask]) ** 2
+ 
+         ang_loss = scale_weight[fg_mask] * ang_loss
+         ang_loss = ang_loss * weight
+ 
+         return ang_loss.sum() / target_scores_sum
+ 
  
  class E2EDetectLoss:
      """Criterion class for computing training losses for end-to-end detection."""
@@ -789,61 +1123,108 @@
          return loss_one2many[0] + loss_one2one[0], loss_one2many[1] + loss_one2one[1]
  
  
+ class E2ELoss:
+     """Criterion class for computing training losses for end-to-end detection."""
+ 
+     def __init__(self, model, loss_fn=v8DetectionLoss):
+         """Initialize E2ELoss with one-to-many and one-to-one detection losses using the provided model."""
+         self.one2many = loss_fn(model, tal_topk=10)
+         self.one2one = loss_fn(model, tal_topk=7, tal_topk2=1)
+         self.updates = 0
+         self.total = 1.0
+         # init gain
+         self.o2m = 0.8
+         self.o2o = self.total - self.o2m
+         self.o2m_copy = self.o2m
+         # final gain
+         self.final_o2m = 0.1
+ 
+     def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+         """Calculate the sum of the loss for box, cls and dfl multiplied by batch size."""
+         preds = self.one2many.parse_output(preds)
+         one2many, one2one = preds["one2many"], preds["one2one"]
+         loss_one2many = self.one2many.loss(one2many, batch)
+         loss_one2one = self.one2one.loss(one2one, batch)
+         return loss_one2many[0] * self.o2m + loss_one2one[0] * self.o2o, loss_one2one[1]
+ 
+     def update(self) -> None:
+         """Update the weights for one-to-many and one-to-one losses based on the decay schedule."""
+         self.updates += 1
+         self.o2m = self.decay(self.updates)
+         self.o2o = max(self.total - self.o2m, 0)
+ 
+     def decay(self, x) -> float:
+         """Calculate the decayed weight for one-to-many loss based on the current update step."""
+         return max(1 - x / max(self.one2one.hyp.epochs - 1, 1), 0) * (self.o2m_copy - self.final_o2m) + self.final_o2m
+ 
+ 
  class TVPDetectLoss:
      """Criterion class for computing training losses for text-visual prompt detection."""
  
-     def __init__(self, model):
+     def __init__(self, model, tal_topk=10):
          """Initialize TVPDetectLoss with task-prompt and visual-prompt criteria using the provided model."""
-         self.vp_criterion = v8DetectionLoss(model)
+         self.vp_criterion = v8DetectionLoss(model, tal_topk)
          # NOTE: store following info as it's changeable in __call__
+         self.hyp = self.vp_criterion.hyp
          self.ori_nc = self.vp_criterion.nc
          self.ori_no = self.vp_criterion.no
          self.ori_reg_max = self.vp_criterion.reg_max
  
+     def parse_output(self, preds) -> dict[str, torch.Tensor]:
+         """Parse model predictions to extract features."""
+         return self.vp_criterion.parse_output(preds)
+ 
      def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
          """Calculate the loss for text-visual prompt detection."""
-         feats = preds[1] if isinstance(preds, tuple) else preds
+         return self.loss(self.parse_output(preds), batch)
+ 
+     def loss(self, preds: dict[str, torch.Tensor], batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+         """Calculate the loss for text-visual prompt detection."""
+         assert self.ori_reg_max == self.vp_criterion.reg_max  # TODO: remove it
  
-         if self.ori_reg_max * 4 + self.ori_nc == feats[0].shape[1]:
+         if self.ori_nc == preds["scores"].shape[1]:
              loss = torch.zeros(3, device=self.vp_criterion.device, requires_grad=True)
              return loss, loss.detach()
  
-         vp_feats = self._get_vp_features(feats)
-         vp_loss = self.vp_criterion(vp_feats, batch)
-         cls_loss = vp_loss[0][1]
-         return cls_loss, vp_loss[1]
+         preds["scores"] = self._get_vp_features(preds)
+         vp_loss = self.vp_criterion(preds, batch)
+         box_loss = vp_loss[0][1]
+         return box_loss, vp_loss[1]
  
-     def _get_vp_features(self, feats: list[torch.Tensor]) -> list[torch.Tensor]:
+     def _get_vp_features(self, preds: dict[str, torch.Tensor]) -> list[torch.Tensor]:
          """Extract visual-prompt features from the model output."""
-         vnc = feats[0].shape[1] - self.ori_reg_max * 4 - self.ori_nc
+         # NOTE: remove empty placeholder
+         scores = preds["scores"][:, self.ori_nc :, :]
+         vnc = scores.shape[1]
  
          self.vp_criterion.nc = vnc
          self.vp_criterion.no = vnc + self.vp_criterion.reg_max * 4
          self.vp_criterion.assigner.num_classes = vnc
- 
-         return [
-             torch.cat((box, cls_vp), dim=1)
-             for box, _, cls_vp in [xi.split((self.ori_reg_max * 4, self.ori_nc, vnc), dim=1) for xi in feats]
-         ]
+         return scores
  
  
  class TVPSegmentLoss(TVPDetectLoss):
      """Criterion class for computing training losses for text-visual prompt segmentation."""
  
-     def __init__(self, model):
+     def __init__(self, model, tal_topk=10):
          """Initialize TVPSegmentLoss with task-prompt and visual-prompt criteria using the provided model."""
          super().__init__(model)
-         self.vp_criterion = v8SegmentationLoss(model)
+         self.vp_criterion = v8SegmentationLoss(model, tal_topk)
+         self.hyp = self.vp_criterion.hyp
  
      def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
          """Calculate the loss for text-visual prompt segmentation."""
-         feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
+         return self.loss(self.parse_output(preds), batch)
+ 
+     def loss(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+         """Calculate the loss for text-visual prompt segmentation."""
+         assert self.ori_reg_max == self.vp_criterion.reg_max  # TODO: remove it
  
-         if self.ori_reg_max * 4 + self.ori_nc == feats[0].shape[1]:
+         if self.ori_nc == preds["scores"].shape[1]:
              loss = torch.zeros(4, device=self.vp_criterion.device, requires_grad=True)
              return loss, loss.detach()
  
-         vp_feats = self._get_vp_features(feats)
-         vp_loss = self.vp_criterion((vp_feats, pred_masks, proto), batch)
+         preds["scores"] = self._get_vp_features(preds)
+         vp_loss = self.vp_criterion(preds, batch)
          cls_loss = vp_loss[0][2]
          return cls_loss, vp_loss[1]
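The E2ELoss schedule above linearly anneals the one-to-many gain from 0.8 down to a final 0.1 over training, with the one-to-one gain taking the remainder of 1.0. A standalone re-implementation of decay() to check the arithmetic (epochs = 100 is an assumption; the class reads it from self.one2one.hyp.epochs):

o2m0, o2m_final, total, epochs = 0.8, 0.1, 1.0, 100

def decay(x: int) -> float:
    """Linear ramp from o2m0 at x=0 to o2m_final at x=epochs-1, clamped thereafter."""
    return max(1 - x / max(epochs - 1, 1), 0) * (o2m0 - o2m_final) + o2m_final

for step in (0, 1, 50, 99, 120):
    o2m = decay(step)
    print(step, round(o2m, 4), round(max(total - o2m, 0), 4))
# 0 -> (0.8, 0.2); 50 -> (~0.4465, ~0.5535); 99 -> (0.1, 0.9); clamps at 0.1 past the end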