dgenerate-ultralytics-headless 8.3.253__py3-none-any.whl → 8.4.3__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (85)
  1. {dgenerate_ultralytics_headless-8.3.253.dist-info → dgenerate_ultralytics_headless-8.4.3.dist-info}/METADATA +41 -49
  2. {dgenerate_ultralytics_headless-8.3.253.dist-info → dgenerate_ultralytics_headless-8.4.3.dist-info}/RECORD +85 -74
  3. tests/__init__.py +2 -2
  4. tests/conftest.py +1 -1
  5. tests/test_cuda.py +8 -2
  6. tests/test_engine.py +8 -8
  7. tests/test_exports.py +11 -4
  8. tests/test_integrations.py +9 -9
  9. tests/test_python.py +14 -14
  10. tests/test_solutions.py +3 -3
  11. ultralytics/__init__.py +1 -1
  12. ultralytics/cfg/__init__.py +25 -27
  13. ultralytics/cfg/default.yaml +3 -1
  14. ultralytics/cfg/models/26/yolo26-cls.yaml +33 -0
  15. ultralytics/cfg/models/26/yolo26-obb.yaml +52 -0
  16. ultralytics/cfg/models/26/yolo26-p2.yaml +60 -0
  17. ultralytics/cfg/models/26/yolo26-p6.yaml +62 -0
  18. ultralytics/cfg/models/26/yolo26-pose.yaml +53 -0
  19. ultralytics/cfg/models/26/yolo26-seg.yaml +52 -0
  20. ultralytics/cfg/models/26/yolo26.yaml +52 -0
  21. ultralytics/cfg/models/26/yoloe-26-seg.yaml +53 -0
  22. ultralytics/cfg/models/26/yoloe-26.yaml +53 -0
  23. ultralytics/data/annotator.py +2 -2
  24. ultralytics/data/augment.py +7 -0
  25. ultralytics/data/converter.py +57 -38
  26. ultralytics/data/dataset.py +1 -1
  27. ultralytics/engine/exporter.py +31 -26
  28. ultralytics/engine/model.py +34 -34
  29. ultralytics/engine/predictor.py +17 -17
  30. ultralytics/engine/results.py +14 -12
  31. ultralytics/engine/trainer.py +59 -29
  32. ultralytics/engine/tuner.py +19 -11
  33. ultralytics/engine/validator.py +16 -16
  34. ultralytics/models/fastsam/predict.py +1 -1
  35. ultralytics/models/yolo/classify/predict.py +1 -1
  36. ultralytics/models/yolo/classify/train.py +1 -1
  37. ultralytics/models/yolo/classify/val.py +1 -1
  38. ultralytics/models/yolo/detect/predict.py +2 -2
  39. ultralytics/models/yolo/detect/train.py +4 -3
  40. ultralytics/models/yolo/detect/val.py +7 -1
  41. ultralytics/models/yolo/model.py +8 -8
  42. ultralytics/models/yolo/obb/predict.py +2 -2
  43. ultralytics/models/yolo/obb/train.py +3 -3
  44. ultralytics/models/yolo/obb/val.py +1 -1
  45. ultralytics/models/yolo/pose/predict.py +1 -1
  46. ultralytics/models/yolo/pose/train.py +3 -1
  47. ultralytics/models/yolo/pose/val.py +1 -1
  48. ultralytics/models/yolo/segment/predict.py +3 -3
  49. ultralytics/models/yolo/segment/train.py +4 -4
  50. ultralytics/models/yolo/segment/val.py +4 -2
  51. ultralytics/models/yolo/yoloe/train.py +6 -1
  52. ultralytics/models/yolo/yoloe/train_seg.py +6 -1
  53. ultralytics/nn/autobackend.py +5 -5
  54. ultralytics/nn/modules/__init__.py +8 -0
  55. ultralytics/nn/modules/block.py +128 -8
  56. ultralytics/nn/modules/head.py +788 -203
  57. ultralytics/nn/tasks.py +86 -41
  58. ultralytics/nn/text_model.py +5 -2
  59. ultralytics/optim/__init__.py +5 -0
  60. ultralytics/optim/muon.py +338 -0
  61. ultralytics/solutions/ai_gym.py +3 -3
  62. ultralytics/solutions/config.py +1 -1
  63. ultralytics/solutions/heatmap.py +1 -1
  64. ultralytics/solutions/instance_segmentation.py +2 -2
  65. ultralytics/solutions/parking_management.py +1 -1
  66. ultralytics/solutions/solutions.py +2 -2
  67. ultralytics/trackers/track.py +1 -1
  68. ultralytics/utils/__init__.py +8 -8
  69. ultralytics/utils/benchmarks.py +23 -23
  70. ultralytics/utils/callbacks/platform.py +11 -7
  71. ultralytics/utils/checks.py +6 -6
  72. ultralytics/utils/downloads.py +5 -3
  73. ultralytics/utils/export/engine.py +19 -10
  74. ultralytics/utils/export/imx.py +19 -13
  75. ultralytics/utils/export/tensorflow.py +21 -21
  76. ultralytics/utils/files.py +2 -2
  77. ultralytics/utils/loss.py +587 -203
  78. ultralytics/utils/metrics.py +1 -0
  79. ultralytics/utils/ops.py +11 -2
  80. ultralytics/utils/tal.py +98 -19
  81. ultralytics/utils/tuner.py +2 -2
  82. {dgenerate_ultralytics_headless-8.3.253.dist-info → dgenerate_ultralytics_headless-8.4.3.dist-info}/WHEEL +0 -0
  83. {dgenerate_ultralytics_headless-8.3.253.dist-info → dgenerate_ultralytics_headless-8.4.3.dist-info}/entry_points.txt +0 -0
  84. {dgenerate_ultralytics_headless-8.3.253.dist-info → dgenerate_ultralytics_headless-8.4.3.dist-info}/licenses/LICENSE +0 -0
  85. {dgenerate_ultralytics_headless-8.3.253.dist-info → dgenerate_ultralytics_headless-8.4.3.dist-info}/top_level.txt +0 -0
ultralytics/utils/loss.py CHANGED
@@ -2,19 +2,20 @@

  from __future__ import annotations

+ import math
  from typing import Any

  import torch
  import torch.nn as nn
  import torch.nn.functional as F

- from ultralytics.utils.metrics import OKS_SIGMA
+ from ultralytics.utils.metrics import OKS_SIGMA, RLE_WEIGHT
  from ultralytics.utils.ops import crop_mask, xywh2xyxy, xyxy2xywh
  from ultralytics.utils.tal import RotatedTaskAlignedAssigner, TaskAlignedAssigner, dist2bbox, dist2rbox, make_anchors
  from ultralytics.utils.torch_utils import autocast

  from .metrics import bbox_iou, probiou
- from .tal import bbox2dist
+ from .tal import bbox2dist, rbox2dist


  class VarifocalLoss(nn.Module):
@@ -122,6 +123,8 @@ class BboxLoss(nn.Module):
  target_scores: torch.Tensor,
  target_scores_sum: torch.Tensor,
  fg_mask: torch.Tensor,
+ imgsz: torch.Tensor,
+ stride: torch.Tensor,
  ) -> tuple[torch.Tensor, torch.Tensor]:
  """Compute IoU and DFL losses for bounding boxes."""
  weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
@@ -134,11 +137,76 @@
  loss_dfl = self.dfl_loss(pred_dist[fg_mask].view(-1, self.dfl_loss.reg_max), target_ltrb[fg_mask]) * weight
  loss_dfl = loss_dfl.sum() / target_scores_sum
  else:
- loss_dfl = torch.tensor(0.0).to(pred_dist.device)
+ target_ltrb = bbox2dist(anchor_points, target_bboxes)
+ # normalize ltrb by image size
+ target_ltrb = target_ltrb * stride
+ target_ltrb[..., 0::2] /= imgsz[1]
+ target_ltrb[..., 1::2] /= imgsz[0]
+ pred_dist = pred_dist * stride
+ pred_dist[..., 0::2] /= imgsz[1]
+ pred_dist[..., 1::2] /= imgsz[0]
+ loss_dfl = (
+ F.l1_loss(pred_dist[fg_mask], target_ltrb[fg_mask], reduction="none").mean(-1, keepdim=True) * weight
+ )
+ loss_dfl = loss_dfl.sum() / target_scores_sum

  return loss_iou, loss_dfl


+ class RLELoss(nn.Module):
+ """Residual Log-Likelihood Estimation Loss.
+
+ Args:
+ use_target_weight (bool): Option to use weighted loss.
+ size_average (bool): Option to average the loss by the batch_size.
+ residual (bool): Option to add L1 loss and let the flow learn the residual error distribution.
+
+ References:
+ https://arxiv.org/abs/2107.11291
+ https://github.com/open-mmlab/mmpose/blob/main/mmpose/models/losses/regression_loss.py
+ """
+
+ def __init__(self, use_target_weight: bool = True, size_average: bool = True, residual: bool = True):
+ """Initialize RLELoss with target weight and residual options.
+
+ Args:
+ use_target_weight (bool): Whether to use target weights for loss calculation.
+ size_average (bool): Whether to average the loss over elements.
+ residual (bool): Whether to include residual log-likelihood term.
+ """
+ super().__init__()
+ self.size_average = size_average
+ self.use_target_weight = use_target_weight
+ self.residual = residual
+
+ def forward(
+ self, sigma: torch.Tensor, log_phi: torch.Tensor, error: torch.Tensor, target_weight: torch.Tensor = None
+ ) -> torch.Tensor:
+ """
+ Args:
+ sigma (torch.Tensor): Output sigma, shape (N, D).
+ log_phi (torch.Tensor): Output log_phi, shape (N).
+ error (torch.Tensor): Error, shape (N, D).
+ target_weight (torch.Tensor): Weights across different joint types, shape (N).
+ """
+ log_sigma = torch.log(sigma)
+ loss = log_sigma - log_phi.unsqueeze(1)
+
+ if self.residual:
+ loss += torch.log(sigma * 2) + torch.abs(error)
+
+ if self.use_target_weight:
+ assert target_weight is not None, "'target_weight' should not be None when 'use_target_weight' is True."
+ if target_weight.dim() == 1:
+ target_weight = target_weight.unsqueeze(1)
+ loss *= target_weight
+
+ if self.size_average:
+ loss /= len(loss)
+
+ return loss.sum()
+
+
  class RotatedBboxLoss(BboxLoss):
  """Criterion class for computing training losses for rotated bounding boxes."""

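Note: a minimal standalone sketch of how the new RLELoss combines its terms. The standard-normal log-density below is only a stand-in for the learned flow (flow_model.log_prob in PoseLoss26), and all shapes and values are invented for illustration.

import math
import torch
from ultralytics.utils.loss import RLELoss  # class added in this release

n, d = 8, 2                                  # 8 visible keypoints, (x, y) error per keypoint
sigma = torch.rand(n, d) * 0.5 + 0.1         # predicted per-coordinate scale
error = torch.randn(n, d)                    # (pred - gt) / sigma, already normalized
log_phi = -0.5 * (error ** 2).sum(-1) - d * 0.5 * math.log(2 * math.pi)  # stand-in for the flow log-prob
target_weight = torch.ones(n)                # per-keypoint weights (RLE_WEIGHT in the real criterion)
loss = RLELoss(use_target_weight=True)(sigma, log_phi, error, target_weight)
print(loss)                                  # scalar: log(sigma) - log_phi plus the residual |error| term, averaged over n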
@@ -155,6 +223,8 @@ class RotatedBboxLoss(BboxLoss):
  target_scores: torch.Tensor,
  target_scores_sum: torch.Tensor,
  fg_mask: torch.Tensor,
+ imgsz: torch.Tensor,
+ stride: torch.Tensor,
  ) -> tuple[torch.Tensor, torch.Tensor]:
  """Compute IoU and DFL losses for rotated bounding boxes."""
  weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
@@ -163,15 +233,84 @@

  # DFL loss
  if self.dfl_loss:
- target_ltrb = bbox2dist(anchor_points, xywh2xyxy(target_bboxes[..., :4]), self.dfl_loss.reg_max - 1)
+ target_ltrb = rbox2dist(
+ target_bboxes[..., :4], anchor_points, target_bboxes[..., 4:5], reg_max=self.dfl_loss.reg_max - 1
+ )
  loss_dfl = self.dfl_loss(pred_dist[fg_mask].view(-1, self.dfl_loss.reg_max), target_ltrb[fg_mask]) * weight
  loss_dfl = loss_dfl.sum() / target_scores_sum
  else:
- loss_dfl = torch.tensor(0.0).to(pred_dist.device)
+ target_ltrb = rbox2dist(target_bboxes[..., :4], anchor_points, target_bboxes[..., 4:5])
+ target_ltrb = target_ltrb * stride
+ target_ltrb[..., 0::2] /= imgsz[1]
+ target_ltrb[..., 1::2] /= imgsz[0]
+ pred_dist = pred_dist * stride
+ pred_dist[..., 0::2] /= imgsz[1]
+ pred_dist[..., 1::2] /= imgsz[0]
+ loss_dfl = (
+ F.l1_loss(pred_dist[fg_mask], target_ltrb[fg_mask], reduction="none").mean(-1, keepdim=True) * weight
+ )
+ loss_dfl = loss_dfl.sum() / target_scores_sum

  return loss_iou, loss_dfl


+ class MultiChannelDiceLoss(nn.Module):
+ """Criterion class for computing multi-channel Dice losses."""
+
+ def __init__(self, smooth: float = 1e-6, reduction: str = "mean"):
+ """Initialize MultiChannelDiceLoss with smoothing and reduction options.
+
+ Args:
+ smooth (float): Smoothing factor to avoid division by zero.
+ reduction (str): Reduction method ('mean', 'sum', or 'none').
+ """
+ super().__init__()
+ self.smooth = smooth
+ self.reduction = reduction
+
+ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+ """Calculate multi-channel Dice loss between predictions and targets."""
+ assert pred.size() == target.size(), "the size of predict and target must be equal."
+
+ pred = pred.sigmoid()
+ intersection = (pred * target).sum(dim=(2, 3))
+ union = pred.sum(dim=(2, 3)) + target.sum(dim=(2, 3))
+ dice = (2.0 * intersection + self.smooth) / (union + self.smooth)
+ dice_loss = 1.0 - dice
+ dice_loss = dice_loss.mean(dim=1)
+
+ if self.reduction == "mean":
+ return dice_loss.mean()
+ elif self.reduction == "sum":
+ return dice_loss.sum()
+ else:
+ return dice_loss
+
+
+ class BCEDiceLoss(nn.Module):
+ """Criterion class for computing combined BCE and Dice losses."""
+
+ def __init__(self, weight_bce: float = 0.5, weight_dice: float = 0.5):
+ """Initialize BCEDiceLoss with BCE and Dice weight factors.
+
+ Args:
+ weight_bce (float): Weight factor for BCE loss component.
+ weight_dice (float): Weight factor for Dice loss component.
+ """
+ super().__init__()
+ self.weight_bce = weight_bce
+ self.weight_dice = weight_dice
+ self.bce = nn.BCEWithLogitsLoss()
+ self.dice = MultiChannelDiceLoss(smooth=1)
+
+ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+ """Calculate combined BCE and Dice loss between predictions and targets."""
+ _, _, mask_h, mask_w = pred.shape
+ if tuple(target.shape[-2:]) != (mask_h, mask_w): # downsample to the same size as pred
+ target = F.interpolate(target, (mask_h, mask_w), mode="nearest")
+ return self.weight_bce * self.bce(pred, target) + self.weight_dice * self.dice(pred, target)
+
+
  class KeypointLoss(nn.Module):
  """Criterion class for computing keypoint losses."""

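Note: the Dice term computed per channel above is 1 - (2*intersection + smooth) / (|pred| + |target| + smooth). A small illustrative sketch of the new BCEDiceLoss on dummy masks (N x C x H x W shapes and random values are assumptions, not taken from the package):

import torch
from ultralytics.utils.loss import BCEDiceLoss  # class added in this release

pred = torch.randn(2, 3, 32, 32)                    # semantic-seg logits, N x C x H x W
target = (torch.rand(2, 3, 64, 64) > 0.5).float()   # binary masks at a larger resolution
criterion = BCEDiceLoss(weight_bce=0.5, weight_dice=0.5)
loss = criterion(pred, target)                      # target is nearest-downsampled to 32x32 internally
print(loss)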
@@ -194,7 +333,7 @@ class KeypointLoss(nn.Module):
  class v8DetectionLoss:
  """Criterion class for computing training losses for YOLOv8 object detection."""

- def __init__(self, model, tal_topk: int = 10): # model must be de-paralleled
+ def __init__(self, model, tal_topk: int = 10, tal_topk2: int | None = None): # model must be de-paralleled
  """Initialize v8DetectionLoss with model parameters and task-aligned assignment settings."""
  device = next(model.parameters()).device # get model device
  h = model.args # hyperparameters
@@ -210,7 +349,14 @@ class v8DetectionLoss:

  self.use_dfl = m.reg_max > 1

- self.assigner = TaskAlignedAssigner(topk=tal_topk, num_classes=self.nc, alpha=0.5, beta=6.0)
+ self.assigner = TaskAlignedAssigner(
+ topk=tal_topk,
+ num_classes=self.nc,
+ alpha=0.5,
+ beta=6.0,
+ stride=self.stride.tolist(),
+ topk2=tal_topk2,
+ )
  self.bbox_loss = BboxLoss(m.reg_max).to(device)
  self.proj = torch.arange(m.reg_max, dtype=torch.float, device=device)

@@ -240,35 +386,31 @@
  # pred_dist = (pred_dist.view(b, a, c // 4, 4).softmax(2) * self.proj.type(pred_dist.dtype).view(1, 1, -1, 1)).sum(2)
  return dist2bbox(pred_dist, anchor_points, xywh=False)

- def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
- """Calculate the sum of the loss for box, cls and dfl multiplied by batch size."""
+ def get_assigned_targets_and_loss(self, preds: dict[str, torch.Tensor], batch: dict[str, Any]) -> tuple:
+ """Calculate the sum of the loss for box, cls and dfl multiplied by batch size and return foreground mask and
+ target indices.
+ """
  loss = torch.zeros(3, device=self.device) # box, cls, dfl
- feats = preds[1] if isinstance(preds, tuple) else preds
- pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
- (self.reg_max * 4, self.nc), 1
+ pred_distri, pred_scores = (
+ preds["boxes"].permute(0, 2, 1).contiguous(),
+ preds["scores"].permute(0, 2, 1).contiguous(),
  )
-
- pred_scores = pred_scores.permute(0, 2, 1).contiguous()
- pred_distri = pred_distri.permute(0, 2, 1).contiguous()
+ anchor_points, stride_tensor = make_anchors(preds["feats"], self.stride, 0.5)

  dtype = pred_scores.dtype
  batch_size = pred_scores.shape[0]
- imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w)
- anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
+ imgsz = torch.tensor(preds["feats"][0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]

  # Targets
  targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
- targets = self.preprocess(targets, batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
+ targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
  gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
  mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)

  # Pboxes
  pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4)
- # dfl_conf = pred_distri.view(batch_size, -1, 4, self.reg_max).detach().softmax(-1)
- # dfl_conf = (dfl_conf.amax(-1).mean(-1) + dfl_conf.amax(-1).amin(-1)) / 2

- _, target_bboxes, target_scores, fg_mask, _ = self.assigner(
- # pred_scores.detach().sigmoid() * 0.8 + dfl_conf.unsqueeze(-1) * 0.2,
+ _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
  pred_scores.detach().sigmoid(),
  (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
  anchor_points * stride_tensor,
@@ -280,7 +422,6 @@
  target_scores_sum = max(target_scores.sum(), 1)

  # Cls loss
- # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum # VFL way
  loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE

  # Bbox loss
@@ -293,105 +434,98 @@
  target_scores,
  target_scores_sum,
  fg_mask,
+ imgsz,
+ stride_tensor,
  )

  loss[0] *= self.hyp.box # box gain
  loss[1] *= self.hyp.cls # cls gain
  loss[2] *= self.hyp.dfl # dfl gain
+ return (
+ (fg_mask, target_gt_idx, target_bboxes, anchor_points, stride_tensor),
+ loss,
+ loss.detach(),
+ ) # loss(box, cls, dfl)

- return loss * batch_size, loss.detach() # loss(box, cls, dfl)
+ def parse_output(
+ self, preds: dict[str, torch.Tensor] | tuple[torch.Tensor, dict[str, torch.Tensor]]
+ ) -> torch.Tensor:
+ """Parse model predictions to extract features."""
+ return preds[1] if isinstance(preds, tuple) else preds
+
+ def __call__(
+ self,
+ preds: dict[str, torch.Tensor] | tuple[torch.Tensor, dict[str, torch.Tensor]],
+ batch: dict[str, torch.Tensor],
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """Calculate the sum of the loss for box, cls and dfl multiplied by batch size."""
+ return self.loss(self.parse_output(preds), batch)
+
+ def loss(self, preds: dict[str, torch.Tensor], batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+ """A wrapper for get_assigned_targets_and_loss and parse_output."""
+ batch_size = preds["boxes"].shape[0]
+ loss, loss_detach = self.get_assigned_targets_and_loss(preds, batch)[1:]
+ return loss * batch_size, loss_detach


  class v8SegmentationLoss(v8DetectionLoss):
  """Criterion class for computing training losses for YOLOv8 segmentation."""

- def __init__(self, model): # model must be de-paralleled
+ def __init__(self, model, tal_topk: int = 10, tal_topk2: int | None = None): # model must be de-paralleled
  """Initialize the v8SegmentationLoss class with model parameters and mask overlap setting."""
- super().__init__(model)
+ super().__init__(model, tal_topk, tal_topk2)
  self.overlap = model.args.overlap_mask
+ self.bcedice_loss = BCEDiceLoss(weight_bce=0.5, weight_dice=0.5)

- def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+ def loss(self, preds: dict[str, torch.Tensor], batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
  """Calculate and return the combined loss for detection and segmentation."""
- loss = torch.zeros(4, device=self.device) # box, seg, cls, dfl
- feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
- batch_size, _, mask_h, mask_w = proto.shape # batch size, number of masks, mask height, mask width
- pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
- (self.reg_max * 4, self.nc), 1
- )
-
- # B, grids, ..
- pred_scores = pred_scores.permute(0, 2, 1).contiguous()
- pred_distri = pred_distri.permute(0, 2, 1).contiguous()
- pred_masks = pred_masks.permute(0, 2, 1).contiguous()
-
- dtype = pred_scores.dtype
- imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w)
- anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
-
- # Targets
- try:
- batch_idx = batch["batch_idx"].view(-1, 1)
- targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)
- targets = self.preprocess(targets, batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
- gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
- mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
- except RuntimeError as e:
- raise TypeError(
- "ERROR ❌ segment dataset incorrectly formatted or not a segment dataset.\n"
- "This error can occur when incorrectly training a 'segment' model on a 'detect' dataset, "
- "i.e. 'yolo train model=yolo11n-seg.pt data=coco8.yaml'.\nVerify your dataset is a "
- "correctly formatted 'segment' dataset using 'data=coco8-seg.yaml' "
- "as an example.\nSee https://docs.ultralytics.com/datasets/segment/ for help."
- ) from e
-
- # Pboxes
- pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4)
-
- _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
- pred_scores.detach().sigmoid(),
- (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
- anchor_points * stride_tensor,
- gt_labels,
- gt_bboxes,
- mask_gt,
- )
-
- target_scores_sum = max(target_scores.sum(), 1)
-
- # Cls loss
- # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum # VFL way
- loss[2] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE
+ pred_masks, proto = preds["mask_coefficient"].permute(0, 2, 1).contiguous(), preds["proto"]
+ loss = torch.zeros(5, device=self.device) # box, seg, cls, dfl
+ if isinstance(proto, tuple) and len(proto) == 2:
+ proto, pred_semseg = proto
+ else:
+ pred_semseg = None
+ (fg_mask, target_gt_idx, target_bboxes, _, _), det_loss, _ = self.get_assigned_targets_and_loss(preds, batch)
+ # NOTE: re-assign index for consistency for now. Need to be removed in the future.
+ loss[0], loss[2], loss[3] = det_loss[0], det_loss[1], det_loss[2]

+ batch_size, _, mask_h, mask_w = proto.shape # batch size, number of masks, mask height, mask width
  if fg_mask.sum():
- # Bbox loss
- loss[0], loss[3] = self.bbox_loss(
- pred_distri,
- pred_bboxes,
- anchor_points,
- target_bboxes / stride_tensor,
- target_scores,
- target_scores_sum,
- fg_mask,
- )
  # Masks loss
  masks = batch["masks"].to(self.device).float()
  if tuple(masks.shape[-2:]) != (mask_h, mask_w): # downsample
- masks = F.interpolate(masks[None], (mask_h, mask_w), mode="nearest")[0]
+ # masks = F.interpolate(masks[None], (mask_h, mask_w), mode="nearest")[0]
+ proto = F.interpolate(proto, masks.shape[-2:], mode="bilinear", align_corners=False)

+ imgsz = (
+ torch.tensor(preds["feats"][0].shape[2:], device=self.device, dtype=pred_masks.dtype) * self.stride[0]
+ )
  loss[1] = self.calculate_segmentation_loss(
- fg_mask, masks, target_gt_idx, target_bboxes, batch_idx, proto, pred_masks, imgsz, self.overlap
+ fg_mask,
+ masks,
+ target_gt_idx,
+ target_bboxes,
+ batch["batch_idx"].view(-1, 1),
+ proto,
+ pred_masks,
+ imgsz,
  )
+ if pred_semseg is not None:
+ sem_masks = batch["sem_masks"].to(self.device) # NxHxW
+ mask_zero = sem_masks == 0 # NxHxW
+ sem_masks = F.one_hot(sem_masks.long(), num_classes=self.nc).permute(0, 3, 1, 2).float() # NxCxHxW
+ sem_masks[mask_zero.unsqueeze(1).expand_as(sem_masks)] = 0
+ loss[4] = self.bcedice_loss(pred_semseg, sem_masks)
+ loss[4] *= self.hyp.box # seg gain

  # WARNING: lines below prevent Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
  else:
  loss[1] += (proto * 0).sum() + (pred_masks * 0).sum() # inf sums may lead to nan loss
+ if pred_semseg is not None:
+ loss[4] += (pred_semseg * 0).sum()

- loss[0] *= self.hyp.box # box gain
  loss[1] *= self.hyp.box # seg gain
- loss[2] *= self.hyp.cls # cls gain
- loss[3] *= self.hyp.dfl # dfl gain
-
- return loss * batch_size, loss.detach() # loss(box, seg, cls, dfl)
+ return loss * batch_size, loss.detach() # loss(box, cls, dfl)

  @staticmethod
  def single_mask_loss(
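Note: the refactored criteria above consume a dict of raw head outputs rather than a tuple of feature maps. A rough sketch of the assumed layout, with key names taken from this diff and purely illustrative shapes (imgsz=640, strides 8/16/32, reg_max=16, nc=80, hence 8400 anchors):

import torch

b, reg_max, nc = 2, 16, 80
feats = [torch.randn(b, 256, s, s) for s in (80, 40, 20)]   # raw FPN maps, only used for anchor generation
preds = {
    "feats": feats,
    "boxes": torch.randn(b, 4 * reg_max, 8400),             # DFL box distributions
    "scores": torch.randn(b, nc, 8400),                     # class logits
    # task heads add extras: "kpts", "kpts_sigma", "mask_coefficient", "proto", "angle", ...
}
# Inside get_assigned_targets_and_loss the tensors are permuted to (batch, n_anchors, channels):
pred_distri = preds["boxes"].permute(0, 2, 1).contiguous()   # (2, 8400, 64)
pred_scores = preds["scores"].permute(0, 2, 1).contiguous()  # (2, 8400, 80)
print(pred_distri.shape, pred_scores.shape)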
@@ -427,7 +561,6 @@ class v8SegmentationLoss(v8DetectionLoss):
  proto: torch.Tensor,
  pred_masks: torch.Tensor,
  imgsz: torch.Tensor,
- overlap: bool,
  ) -> torch.Tensor:
  """Calculate the loss for instance segmentation.

@@ -440,7 +573,6 @@
  proto (torch.Tensor): Prototype masks of shape (BS, 32, H, W).
  pred_masks (torch.Tensor): Predicted masks for each anchor of shape (BS, N_anchors, 32).
  imgsz (torch.Tensor): Size of the input image as a tensor of shape (2), i.e., (H, W).
- overlap (bool): Whether the masks in `masks` tensor overlap.

  Returns:
  (torch.Tensor): The calculated loss for instance segmentation.
@@ -466,7 +598,7 @@
  fg_mask_i, target_gt_idx_i, pred_masks_i, proto_i, mxyxy_i, marea_i, masks_i = single_i
  if fg_mask_i.any():
  mask_idx = target_gt_idx_i[fg_mask_i]
- if overlap:
+ if self.overlap:
  gt_mask = masks_i == (mask_idx + 1).view(-1, 1, 1)
  gt_mask = gt_mask.float()
  else:
@@ -486,9 +618,9 @@
  class v8PoseLoss(v8DetectionLoss):
  """Criterion class for computing training losses for YOLOv8 pose estimation."""

- def __init__(self, model): # model must be de-paralleled
+ def __init__(self, model, tal_topk: int = 10, tal_topk2: int = 10): # model must be de-paralleled
  """Initialize v8PoseLoss with model parameters and keypoint-specific loss functions."""
- super().__init__(model)
+ super().__init__(model, tal_topk, tal_topk2)
  self.kpt_shape = model.model[-1].kpt_shape
  self.bce_pose = nn.BCEWithLogitsLoss()
  is_pose = self.kpt_shape == [17, 3]
@@ -496,69 +628,40 @@
  sigmas = torch.from_numpy(OKS_SIGMA).to(self.device) if is_pose else torch.ones(nkpt, device=self.device) / nkpt
  self.keypoint_loss = KeypointLoss(sigmas=sigmas)

- def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+ def loss(self, preds: dict[str, torch.Tensor], batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
  """Calculate the total loss and detach it for pose estimation."""
- loss = torch.zeros(5, device=self.device) # box, pose, kobj, cls, dfl
- feats, pred_kpts = preds if isinstance(preds[0], list) else preds[1]
- pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
- (self.reg_max * 4, self.nc), 1
+ pred_kpts = preds["kpts"].permute(0, 2, 1).contiguous()
+ loss = torch.zeros(5, device=self.device) # box, cls, dfl, kpt_location, kpt_visibility
+ (fg_mask, target_gt_idx, target_bboxes, anchor_points, stride_tensor), det_loss, _ = (
+ self.get_assigned_targets_and_loss(preds, batch)
  )
+ # NOTE: re-assign index for consistency for now. Need to be removed in the future.
+ loss[0], loss[3], loss[4] = det_loss[0], det_loss[1], det_loss[2]

- # B, grids, ..
- pred_scores = pred_scores.permute(0, 2, 1).contiguous()
- pred_distri = pred_distri.permute(0, 2, 1).contiguous()
- pred_kpts = pred_kpts.permute(0, 2, 1).contiguous()
-
- dtype = pred_scores.dtype
- imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w)
- anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
-
- # Targets
- batch_size = pred_scores.shape[0]
- batch_idx = batch["batch_idx"].view(-1, 1)
- targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)
- targets = self.preprocess(targets, batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
- gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
- mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
+ batch_size = pred_kpts.shape[0]
+ imgsz = torch.tensor(preds["feats"][0].shape[2:], device=self.device, dtype=pred_kpts.dtype) * self.stride[0]

  # Pboxes
- pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4)
  pred_kpts = self.kpts_decode(anchor_points, pred_kpts.view(batch_size, -1, *self.kpt_shape)) # (b, h*w, 17, 3)

- _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
- pred_scores.detach().sigmoid(),
- (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
- anchor_points * stride_tensor,
- gt_labels,
- gt_bboxes,
- mask_gt,
- )
-
- target_scores_sum = max(target_scores.sum(), 1)
-
- # Cls loss
- # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum # VFL way
- loss[3] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE
-
  # Bbox loss
  if fg_mask.sum():
- target_bboxes /= stride_tensor
- loss[0], loss[4] = self.bbox_loss(
- pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
- )
  keypoints = batch["keypoints"].to(self.device).float().clone()
  keypoints[..., 0] *= imgsz[1]
  keypoints[..., 1] *= imgsz[0]

  loss[1], loss[2] = self.calculate_keypoints_loss(
- fg_mask, target_gt_idx, keypoints, batch_idx, stride_tensor, target_bboxes, pred_kpts
+ fg_mask,
+ target_gt_idx,
+ keypoints,
+ batch["batch_idx"].view(-1, 1),
+ stride_tensor,
+ target_bboxes,
+ pred_kpts,
  )

- loss[0] *= self.hyp.box # box gain
  loss[1] *= self.hyp.pose # pose gain
  loss[2] *= self.hyp.kobj # kobj gain
- loss[3] *= self.hyp.cls # cls gain
- loss[4] *= self.hyp.dfl # dfl gain

  return loss * batch_size, loss.detach() # loss(box, pose, kobj, cls, dfl)

@@ -571,34 +674,23 @@
  y[..., 1] += anchor_points[:, [1]] - 0.5
  return y

- def calculate_keypoints_loss(
+ def _select_target_keypoints(
  self,
- masks: torch.Tensor,
- target_gt_idx: torch.Tensor,
  keypoints: torch.Tensor,
  batch_idx: torch.Tensor,
- stride_tensor: torch.Tensor,
- target_bboxes: torch.Tensor,
- pred_kpts: torch.Tensor,
- ) -> tuple[torch.Tensor, torch.Tensor]:
- """Calculate the keypoints loss for the model.
-
- This function calculates the keypoints loss and keypoints object loss for a given batch. The keypoints loss is
- based on the difference between the predicted keypoints and ground truth keypoints. The keypoints object loss is
- a binary classification loss that classifies whether a keypoint is present or not.
+ target_gt_idx: torch.Tensor,
+ masks: torch.Tensor,
+ ) -> torch.Tensor:
+ """Select target keypoints for each anchor based on batch index and target ground truth index.

  Args:
- masks (torch.Tensor): Binary mask tensor indicating object presence, shape (BS, N_anchors).
- target_gt_idx (torch.Tensor): Index tensor mapping anchors to ground truth objects, shape (BS, N_anchors).
  keypoints (torch.Tensor): Ground truth keypoints, shape (N_kpts_in_batch, N_kpts_per_object, kpts_dim).
  batch_idx (torch.Tensor): Batch index tensor for keypoints, shape (N_kpts_in_batch, 1).
- stride_tensor (torch.Tensor): Stride tensor for anchors, shape (N_anchors, 1).
- target_bboxes (torch.Tensor): Ground truth boxes in (x1, y1, x2, y2) format, shape (BS, N_anchors, 4).
- pred_kpts (torch.Tensor): Predicted keypoints, shape (BS, N_anchors, N_kpts_per_object, kpts_dim).
+ target_gt_idx (torch.Tensor): Index tensor mapping anchors to ground truth objects, shape (BS, N_anchors).
+ masks (torch.Tensor): Binary mask tensor indicating object presence, shape (BS, N_anchors).

  Returns:
- kpts_loss (torch.Tensor): The keypoints loss.
- kpts_obj_loss (torch.Tensor): The keypoints object loss.
+ (torch.Tensor): Selected keypoints tensor, shape (BS, N_anchors, N_kpts_per_object, kpts_dim).
  """
  batch_idx = batch_idx.flatten()
  batch_size = len(masks)
@@ -625,6 +717,40 @@
  1, target_gt_idx_expanded.expand(-1, -1, keypoints.shape[1], keypoints.shape[2])
  )

+ return selected_keypoints
+
+ def calculate_keypoints_loss(
+ self,
+ masks: torch.Tensor,
+ target_gt_idx: torch.Tensor,
+ keypoints: torch.Tensor,
+ batch_idx: torch.Tensor,
+ stride_tensor: torch.Tensor,
+ target_bboxes: torch.Tensor,
+ pred_kpts: torch.Tensor,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """Calculate the keypoints loss for the model.
+
+ This function calculates the keypoints loss and keypoints object loss for a given batch. The keypoints loss is
+ based on the difference between the predicted keypoints and ground truth keypoints. The keypoints object loss is
+ a binary classification loss that classifies whether a keypoint is present or not.
+
+ Args:
+ masks (torch.Tensor): Binary mask tensor indicating object presence, shape (BS, N_anchors).
+ target_gt_idx (torch.Tensor): Index tensor mapping anchors to ground truth objects, shape (BS, N_anchors).
+ keypoints (torch.Tensor): Ground truth keypoints, shape (N_kpts_in_batch, N_kpts_per_object, kpts_dim).
+ batch_idx (torch.Tensor): Batch index tensor for keypoints, shape (N_kpts_in_batch, 1).
+ stride_tensor (torch.Tensor): Stride tensor for anchors, shape (N_anchors, 1).
+ target_bboxes (torch.Tensor): Ground truth boxes in (x1, y1, x2, y2) format, shape (BS, N_anchors, 4).
+ pred_kpts (torch.Tensor): Predicted keypoints, shape (BS, N_anchors, N_kpts_per_object, kpts_dim).
+
+ Returns:
+ kpts_loss (torch.Tensor): The keypoints loss.
+ kpts_obj_loss (torch.Tensor): The keypoints object loss.
+ """
+ # Select target keypoints using helper method
+ selected_keypoints = self._select_target_keypoints(keypoints, batch_idx, target_gt_idx, masks)
+
  # Divide coordinates by stride
  selected_keypoints[..., :2] /= stride_tensor.view(1, -1, 1, 1)

@@ -632,6 +758,7 @@
  kpts_obj_loss = 0

  if masks.any():
+ target_bboxes /= stride_tensor
  gt_kpt = selected_keypoints[masks]
  area = xyxy2xywh(target_bboxes[masks])[:, 2:].prod(1, keepdim=True)
  pred_kpt = pred_kpts[masks]
@@ -644,6 +771,172 @@
  return kpts_loss, kpts_obj_loss


+ class PoseLoss26(v8PoseLoss):
+ """Criterion class for computing training losses for YOLOv8 pose estimation with RLE loss support."""
+
+ def __init__(self, model, tal_topk: int = 10, tal_topk2: int | None = None): # model must be de-paralleled
+ """Initialize PoseLoss26 with model parameters and keypoint-specific loss functions including RLE loss."""
+ super().__init__(model, tal_topk, tal_topk2)
+ is_pose = self.kpt_shape == [17, 3]
+ nkpt = self.kpt_shape[0] # number of keypoints
+ self.rle_loss = None
+ self.flow_model = model.model[-1].flow_model if hasattr(model.model[-1], "flow_model") else None
+ if self.flow_model is not None:
+ self.rle_loss = RLELoss(use_target_weight=True).to(self.device)
+ self.target_weights = (
+ torch.from_numpy(RLE_WEIGHT).to(self.device) if is_pose else torch.ones(nkpt, device=self.device)
+ )
+
+ def loss(self, preds: dict[str, torch.Tensor], batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+ """Calculate the total loss and detach it for pose estimation."""
+ pred_kpts = preds["kpts"].permute(0, 2, 1).contiguous()
+ loss = torch.zeros(6 if self.rle_loss else 5, device=self.device) # box, cls, dfl, kpt_location, kpt_visibility
+ (fg_mask, target_gt_idx, target_bboxes, anchor_points, stride_tensor), det_loss, _ = (
+ self.get_assigned_targets_and_loss(preds, batch)
+ )
+ # NOTE: re-assign index for consistency for now. Need to be removed in the future.
+ loss[0], loss[3], loss[4] = det_loss[0], det_loss[1], det_loss[2]
+
+ batch_size = pred_kpts.shape[0]
+ imgsz = torch.tensor(batch["resized_shape"][0], device=self.device, dtype=pred_kpts.dtype) # image size (h,w)
+
+ pred_kpts = pred_kpts.view(batch_size, -1, *self.kpt_shape) # (b, h*w, 17, 3)
+
+ if self.rle_loss and preds.get("kpts_sigma", None) is not None:
+ pred_sigma = preds["kpts_sigma"].permute(0, 2, 1).contiguous()
+ pred_sigma = pred_sigma.view(batch_size, -1, self.kpt_shape[0], 2) # (b, h*w, 17, 2)
+ pred_kpts = torch.cat([pred_kpts, pred_sigma], dim=-1) # (b, h*w, 17, 5)
+
+ pred_kpts = self.kpts_decode(anchor_points, pred_kpts)
+
+ # Bbox loss
+ if fg_mask.sum():
+ keypoints = batch["keypoints"].to(self.device).float().clone()
+ keypoints[..., 0] *= imgsz[1]
+ keypoints[..., 1] *= imgsz[0]
+
+ keypoints_loss = self.calculate_keypoints_loss(
+ fg_mask,
+ target_gt_idx,
+ keypoints,
+ batch["batch_idx"].view(-1, 1),
+ stride_tensor,
+ target_bboxes,
+ pred_kpts,
+ )
+ loss[1] = keypoints_loss[0]
+ loss[2] = keypoints_loss[1]
+ if self.rle_loss is not None:
+ loss[5] = keypoints_loss[2]
+
+ loss[1] *= self.hyp.pose # pose gain
+ loss[2] *= self.hyp.kobj # kobj gain
+ if self.rle_loss is not None:
+ loss[5] *= self.hyp.rle # rle gain
+
+ return loss * batch_size, loss.detach() # loss(box, cls, dfl, kpt_location, kpt_visibility)
+
+ @staticmethod
+ def kpts_decode(anchor_points: torch.Tensor, pred_kpts: torch.Tensor) -> torch.Tensor:
+ """Decode predicted keypoints to image coordinates."""
+ y = pred_kpts.clone()
+ y[..., 0] += anchor_points[:, [0]]
+ y[..., 1] += anchor_points[:, [1]]
+ return y
+
+ def calculate_rle_loss(self, pred_kpt: torch.Tensor, gt_kpt: torch.Tensor, kpt_mask: torch.Tensor) -> torch.Tensor:
+ """Calculate the RLE (Residual Log-likelihood Estimation) loss for keypoints.
+
+ Args:
+ pred_kpt (torch.Tensor): Predicted keypoints with sigma, shape (N, kpts_dim) where kpts_dim >= 4.
+ gt_kpt (torch.Tensor): Ground truth keypoints, shape (N, kpts_dim).
+ kpt_mask (torch.Tensor): Mask for valid keypoints, shape (N, num_keypoints).
+
+ Returns:
+ (torch.Tensor): The RLE loss.
+ """
+ pred_kpt_visible = pred_kpt[kpt_mask]
+ gt_kpt_visible = gt_kpt[kpt_mask]
+ pred_coords = pred_kpt_visible[:, 0:2]
+ pred_sigma = pred_kpt_visible[:, -2:]
+ gt_coords = gt_kpt_visible[:, 0:2]
+
+ target_weights = self.target_weights.unsqueeze(0).repeat(kpt_mask.shape[0], 1)
+ target_weights = target_weights[kpt_mask]
+
+ pred_sigma = pred_sigma.sigmoid()
+ error = (pred_coords - gt_coords) / (pred_sigma + 1e-9)
+
+ # Filter out NaN and Inf values to prevent MultivariateNormal validation errors
+ valid_mask = ~(torch.isnan(error) | torch.isinf(error)).any(dim=-1)
+ if not valid_mask.any():
+ return torch.tensor(0.0, device=pred_kpt.device)
+
+ error = error[valid_mask]
+ error = error.clamp(-100, 100) # Prevent numerical instability
+ pred_sigma = pred_sigma[valid_mask]
+ target_weights = target_weights[valid_mask]
+
+ log_phi = self.flow_model.log_prob(error)
+
+ return self.rle_loss(pred_sigma, log_phi, error, target_weights)
+
+ def calculate_keypoints_loss(
+ self,
+ masks: torch.Tensor,
+ target_gt_idx: torch.Tensor,
+ keypoints: torch.Tensor,
+ batch_idx: torch.Tensor,
+ stride_tensor: torch.Tensor,
+ target_bboxes: torch.Tensor,
+ pred_kpts: torch.Tensor,
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Calculate the keypoints loss for the model.
+
+ This function calculates the keypoints loss and keypoints object loss for a given batch. The keypoints loss is
+ based on the difference between the predicted keypoints and ground truth keypoints. The keypoints object loss is
+ a binary classification loss that classifies whether a keypoint is present or not.
+
+ Args:
+ masks (torch.Tensor): Binary mask tensor indicating object presence, shape (BS, N_anchors).
+ target_gt_idx (torch.Tensor): Index tensor mapping anchors to ground truth objects, shape (BS, N_anchors).
+ keypoints (torch.Tensor): Ground truth keypoints, shape (N_kpts_in_batch, N_kpts_per_object, kpts_dim).
+ batch_idx (torch.Tensor): Batch index tensor for keypoints, shape (N_kpts_in_batch, 1).
+ stride_tensor (torch.Tensor): Stride tensor for anchors, shape (N_anchors, 1).
+ target_bboxes (torch.Tensor): Ground truth boxes in (x1, y1, x2, y2) format, shape (BS, N_anchors, 4).
+ pred_kpts (torch.Tensor): Predicted keypoints, shape (BS, N_anchors, N_kpts_per_object, kpts_dim).
+
+ Returns:
+ kpts_loss (torch.Tensor): The keypoints loss.
+ kpts_obj_loss (torch.Tensor): The keypoints object loss.
+ rle_loss (torch.Tensor): The RLE loss.
+ """
+ # Select target keypoints using inherited helper method
+ selected_keypoints = self._select_target_keypoints(keypoints, batch_idx, target_gt_idx, masks)
+
+ # Divide coordinates by stride
+ selected_keypoints[..., :2] /= stride_tensor.view(1, -1, 1, 1)
+
+ kpts_loss = 0
+ kpts_obj_loss = 0
+ rle_loss = 0
+
+ if masks.any():
+ target_bboxes /= stride_tensor
+ gt_kpt = selected_keypoints[masks]
+ area = xyxy2xywh(target_bboxes[masks])[:, 2:].prod(1, keepdim=True)
+ pred_kpt = pred_kpts[masks]
+ kpt_mask = gt_kpt[..., 2] != 0 if gt_kpt.shape[-1] == 3 else torch.full_like(gt_kpt[..., 0], True)
+ kpts_loss = self.keypoint_loss(pred_kpt, gt_kpt, kpt_mask, area) # pose loss
+
+ if self.rle_loss is not None and (pred_kpt.shape[-1] == 4 or pred_kpt.shape[-1] == 5):
+ rle_loss = self.calculate_rle_loss(pred_kpt, gt_kpt, kpt_mask)
+ if pred_kpt.shape[-1] == 3 or pred_kpt.shape[-1] == 5:
+ kpts_obj_loss = self.bce_pose(pred_kpt[..., 2], kpt_mask.float()) # keypoint obj loss
+
+ return kpts_loss, kpts_obj_loss, rle_loss
+
+
  class v8ClassificationLoss:
  """Criterion class for computing training losses for classification."""

@@ -657,10 +950,17 @@
  class v8OBBLoss(v8DetectionLoss):
  """Calculates losses for object detection, classification, and box distribution in rotated YOLO models."""

- def __init__(self, model):
+ def __init__(self, model, tal_topk=10, tal_topk2: int | None = None):
  """Initialize v8OBBLoss with model, assigner, and rotated bbox loss; model must be de-paralleled."""
- super().__init__(model)
- self.assigner = RotatedTaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0)
+ super().__init__(model, tal_topk=tal_topk)
+ self.assigner = RotatedTaskAlignedAssigner(
+ topk=tal_topk,
+ num_classes=self.nc,
+ alpha=0.5,
+ beta=6.0,
+ stride=self.stride.tolist(),
+ topk2=tal_topk2,
+ )
  self.bbox_loss = RotatedBboxLoss(self.reg_max).to(self.device)

  def preprocess(self, targets: torch.Tensor, batch_size: int, scale_tensor: torch.Tensor) -> torch.Tensor:
@@ -680,23 +980,19 @@
  out[j, :n] = torch.cat([targets[matches, 1:2], bboxes], dim=-1)
  return out

- def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+ def loss(self, preds: dict[str, torch.Tensor], batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
  """Calculate and return the loss for oriented bounding box detection."""
- loss = torch.zeros(3, device=self.device) # box, cls, dfl
- feats, pred_angle = preds if isinstance(preds[0], list) else preds[1]
- batch_size = pred_angle.shape[0] # batch size
- pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
- (self.reg_max * 4, self.nc), 1
+ loss = torch.zeros(4, device=self.device) # box, cls, dfl, angle
+ pred_distri, pred_scores, pred_angle = (
+ preds["boxes"].permute(0, 2, 1).contiguous(),
+ preds["scores"].permute(0, 2, 1).contiguous(),
+ preds["angle"].permute(0, 2, 1).contiguous(),
  )
-
- # b, grids, ..
- pred_scores = pred_scores.permute(0, 2, 1).contiguous()
- pred_distri = pred_distri.permute(0, 2, 1).contiguous()
- pred_angle = pred_angle.permute(0, 2, 1).contiguous()
+ anchor_points, stride_tensor = make_anchors(preds["feats"], self.stride, 0.5)
+ batch_size = pred_angle.shape[0] # batch size, number of masks, mask height, mask width

  dtype = pred_scores.dtype
- imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w)
- anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
+ imgsz = torch.tensor(batch["resized_shape"][0], device=self.device, dtype=dtype) # image size (h,w)

  # targets
  try:
@@ -704,14 +1000,14 @@
  targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"].view(-1, 5)), 1)
  rw, rh = targets[:, 4] * float(imgsz[1]), targets[:, 5] * float(imgsz[0])
  targets = targets[(rw >= 2) & (rh >= 2)] # filter rboxes of tiny size to stabilize training
- targets = self.preprocess(targets, batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
+ targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
  gt_labels, gt_bboxes = targets.split((1, 5), 2) # cls, xywhr
  mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
  except RuntimeError as e:
  raise TypeError(
  "ERROR ❌ OBB dataset incorrectly formatted or not a OBB dataset.\n"
  "This error can occur when incorrectly training a 'OBB' model on a 'detect' dataset, "
- "i.e. 'yolo train model=yolo11n-obb.pt data=coco8.yaml'.\nVerify your dataset is a "
+ "i.e. 'yolo train model=yolo26n-obb.pt data=dota8.yaml'.\nVerify your dataset is a "
  "correctly formatted 'OBB' dataset using 'data=dota8.yaml' "
  "as an example.\nSee https://docs.ultralytics.com/datasets/obb/ for help."
  ) from e
@@ -741,16 +1037,29 @@
  if fg_mask.sum():
  target_bboxes[..., :4] /= stride_tensor
  loss[0], loss[2] = self.bbox_loss(
- pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
+ pred_distri,
+ pred_bboxes,
+ anchor_points,
+ target_bboxes,
+ target_scores,
+ target_scores_sum,
+ fg_mask,
+ imgsz,
+ stride_tensor,
  )
+ weight = target_scores.sum(-1)[fg_mask]
+ loss[3] = self.calculate_angle_loss(
+ pred_bboxes, target_bboxes, fg_mask, weight, target_scores_sum
+ ) # angle loss
  else:
  loss[0] += (pred_angle * 0).sum()

  loss[0] *= self.hyp.box # box gain
  loss[1] *= self.hyp.cls # cls gain
  loss[2] *= self.hyp.dfl # dfl gain
+ loss[3] *= self.hyp.angle # angle gain

- return loss * batch_size, loss.detach() # loss(box, cls, dfl)
+ return loss * batch_size, loss.detach() # loss(box, cls, dfl, angle)

  def bbox_decode(
  self, anchor_points: torch.Tensor, pred_dist: torch.Tensor, pred_angle: torch.Tensor
@@ -770,6 +1079,34 @@
  pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))
  return torch.cat((dist2rbox(pred_dist, pred_angle, anchor_points), pred_angle), dim=-1)

+ def calculate_angle_loss(self, pred_bboxes, target_bboxes, fg_mask, weight, target_scores_sum, lambda_val=3):
+ """Calculate oriented angle loss.
+
+ Args:
+ pred_bboxes: [N, 5] (x, y, w, h, theta).
+ target_bboxes: [N, 5] (x, y, w, h, theta).
+ fg_mask: Foreground mask indicating valid predictions.
+ weight: Loss weights for each prediction.
+ target_scores_sum: Sum of target scores for normalization.
+ lambda_val: control the sensitivity to aspect ratio.
+ """
+ w_gt = target_bboxes[..., 2]
+ h_gt = target_bboxes[..., 3]
+ pred_theta = pred_bboxes[..., 4]
+ target_theta = target_bboxes[..., 4]
+
+ log_ar = torch.log(w_gt / h_gt)
+ scale_weight = torch.exp(-(log_ar**2) / (lambda_val**2))
+
+ delta_theta = pred_theta - target_theta
+ delta_theta_wrapped = delta_theta - torch.round(delta_theta / math.pi) * math.pi
+ ang_loss = torch.sin(2 * delta_theta_wrapped[fg_mask]) ** 2
+
+ ang_loss = scale_weight[fg_mask] * ang_loss
+ ang_loss = ang_loss * weight
+
+ return ang_loss.sum() / target_scores_sum
+

  class E2EDetectLoss:
  """Criterion class for computing training losses for end-to-end detection."""
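Note: a standalone re-computation of the calculate_angle_loss term added above, for a single assigned box with invented values. The penalty is sin²(2Δθ), zero for 0° and 90° errors and maximal at 45°, down-weighted for near-square boxes via exp(-log²(w/h)/λ²), with Δθ wrapped into (-π/2, π/2].

import math

w_gt, h_gt = 40.0, 10.0                  # ground-truth width/height, aspect ratio 4:1
pred_theta, target_theta = 0.45, 0.05    # radians
lambda_val = 3

scale_weight = math.exp(-(math.log(w_gt / h_gt) ** 2) / lambda_val ** 2)   # ~0.81
delta = pred_theta - target_theta
delta_wrapped = delta - round(delta / math.pi) * math.pi                   # wrap into (-pi/2, pi/2]
ang_loss = scale_weight * math.sin(2 * delta_wrapped) ** 2                 # ~0.42 for this 23-degree error
print(scale_weight, delta_wrapped, ang_loss)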
@@ -789,61 +1126,108 @@
  return loss_one2many[0] + loss_one2one[0], loss_one2many[1] + loss_one2one[1]


+ class E2ELoss:
+ """Criterion class for computing training losses for end-to-end detection."""
+
+ def __init__(self, model, loss_fn=v8DetectionLoss):
+ """Initialize E2ELoss with one-to-many and one-to-one detection losses using the provided model."""
+ self.one2many = loss_fn(model, tal_topk=10)
+ self.one2one = loss_fn(model, tal_topk=7, tal_topk2=1)
+ self.updates = 0
+ self.total = 1.0
+ # init gain
+ self.o2m = 0.8
+ self.o2o = self.total - self.o2m
+ self.o2m_copy = self.o2m
+ # final gain
+ self.final_o2m = 0.1
+
+ def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+ """Calculate the sum of the loss for box, cls and dfl multiplied by batch size."""
+ preds = self.one2many.parse_output(preds)
+ one2many, one2one = preds["one2many"], preds["one2one"]
+ loss_one2many = self.one2many.loss(one2many, batch)
+ loss_one2one = self.one2one.loss(one2one, batch)
+ return loss_one2many[0] * self.o2m + loss_one2one[0] * self.o2o, loss_one2one[1]
+
+ def update(self) -> None:
+ """Update the weights for one-to-many and one-to-one losses based on the decay schedule."""
+ self.updates += 1
+ self.o2m = self.decay(self.updates)
+ self.o2o = max(self.total - self.o2m, 0)
+
+ def decay(self, x) -> float:
+ """Calculate the decayed weight for one-to-many loss based on the current update step."""
+ return max(1 - x / max(self.one2one.hyp.epochs - 1, 1), 0) * (self.o2m_copy - self.final_o2m) + self.final_o2m
+
+
  class TVPDetectLoss:
  """Criterion class for computing training losses for text-visual prompt detection."""

- def __init__(self, model):
+ def __init__(self, model, tal_topk=10):
  """Initialize TVPDetectLoss with task-prompt and visual-prompt criteria using the provided model."""
- self.vp_criterion = v8DetectionLoss(model)
+ self.vp_criterion = v8DetectionLoss(model, tal_topk)
  # NOTE: store following info as it's changeable in __call__
+ self.hyp = self.vp_criterion.hyp
  self.ori_nc = self.vp_criterion.nc
  self.ori_no = self.vp_criterion.no
  self.ori_reg_max = self.vp_criterion.reg_max

+ def parse_output(self, preds) -> dict[str, torch.Tensor]:
+ """Parse model predictions to extract features."""
+ return self.vp_criterion.parse_output(preds)
+
  def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
  """Calculate the loss for text-visual prompt detection."""
- feats = preds[1] if isinstance(preds, tuple) else preds
+ return self.loss(self.parse_output(preds), batch)
+
+ def loss(self, preds: dict[str, torch.Tensor], batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+ """Calculate the loss for text-visual prompt detection."""
+ assert self.ori_reg_max == self.vp_criterion.reg_max # TODO: remove it

- if self.ori_reg_max * 4 + self.ori_nc == feats[0].shape[1]:
+ if self.ori_nc == preds["scores"].shape[1]:
  loss = torch.zeros(3, device=self.vp_criterion.device, requires_grad=True)
  return loss, loss.detach()

- vp_feats = self._get_vp_features(feats)
- vp_loss = self.vp_criterion(vp_feats, batch)
- cls_loss = vp_loss[0][1]
- return cls_loss, vp_loss[1]
+ preds["scores"] = self._get_vp_features(preds)
+ vp_loss = self.vp_criterion(preds, batch)
+ box_loss = vp_loss[0][1]
+ return box_loss, vp_loss[1]

- def _get_vp_features(self, feats: list[torch.Tensor]) -> list[torch.Tensor]:
+ def _get_vp_features(self, preds: dict[str, torch.Tensor]) -> list[torch.Tensor]:
  """Extract visual-prompt features from the model output."""
- vnc = feats[0].shape[1] - self.ori_reg_max * 4 - self.ori_nc
+ # NOTE: remove empty placeholder
+ scores = preds["scores"][:, self.ori_nc :, :]
+ vnc = scores.shape[1]

  self.vp_criterion.nc = vnc
  self.vp_criterion.no = vnc + self.vp_criterion.reg_max * 4
  self.vp_criterion.assigner.num_classes = vnc
-
- return [
- torch.cat((box, cls_vp), dim=1)
- for box, _, cls_vp in [xi.split((self.ori_reg_max * 4, self.ori_nc, vnc), dim=1) for xi in feats]
- ]
+ return scores


  class TVPSegmentLoss(TVPDetectLoss):
  """Criterion class for computing training losses for text-visual prompt segmentation."""

- def __init__(self, model):
+ def __init__(self, model, tal_topk=10):
  """Initialize TVPSegmentLoss with task-prompt and visual-prompt criteria using the provided model."""
  super().__init__(model)
- self.vp_criterion = v8SegmentationLoss(model)
+ self.vp_criterion = v8SegmentationLoss(model, tal_topk)
+ self.hyp = self.vp_criterion.hyp

  def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
  """Calculate the loss for text-visual prompt segmentation."""
- feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
+ return self.loss(self.parse_output(preds), batch)
+
+ def loss(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+ """Calculate the loss for text-visual prompt detection."""
+ assert self.ori_reg_max == self.vp_criterion.reg_max # TODO: remove it

- if self.ori_reg_max * 4 + self.ori_nc == feats[0].shape[1]:
+ if self.ori_nc == preds["scores"].shape[1]:
  loss = torch.zeros(4, device=self.vp_criterion.device, requires_grad=True)
  return loss, loss.detach()

- vp_feats = self._get_vp_features(feats)
- vp_loss = self.vp_criterion((vp_feats, pred_masks, proto), batch)
+ preds["scores"] = self._get_vp_features(preds)
+ vp_loss = self.vp_criterion(preds, batch)
  cls_loss = vp_loss[0][2]
  return cls_loss, vp_loss[1]
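Note: the new E2ELoss blends the one-to-many and one-to-one branches with weights that decay linearly over training (initial o2m 0.8, final 0.1, total 1.0; constants copied from this diff). A standalone trace of the schedule, assuming a hypothetical 10-epoch run:

epochs = 10
o2m_init, o2m_final, total = 0.8, 0.1, 1.0

def decay(x: int) -> float:
    # Same formula as E2ELoss.decay, with the trainer's epoch count substituted directly
    return max(1 - x / max(epochs - 1, 1), 0) * (o2m_init - o2m_final) + o2m_final

for x in range(epochs):
    o2m = decay(x)
    print(f"epoch {x}: o2m={o2m:.3f} o2o={max(total - o2m, 0):.3f}")
# epoch 0: o2m=0.800 o2o=0.200 ... epoch 9: o2m=0.100 o2o=0.900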