dgenerate-ultralytics-headless 8.3.248__py3-none-any.whl → 8.4.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (97)
  1. {dgenerate_ultralytics_headless-8.3.248.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/METADATA +52 -61
  2. {dgenerate_ultralytics_headless-8.3.248.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/RECORD +97 -84
  3. {dgenerate_ultralytics_headless-8.3.248.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/WHEEL +1 -1
  4. tests/__init__.py +2 -2
  5. tests/conftest.py +1 -1
  6. tests/test_cuda.py +8 -2
  7. tests/test_engine.py +8 -8
  8. tests/test_exports.py +11 -4
  9. tests/test_integrations.py +9 -9
  10. tests/test_python.py +41 -16
  11. tests/test_solutions.py +3 -3
  12. ultralytics/__init__.py +1 -1
  13. ultralytics/cfg/__init__.py +31 -31
  14. ultralytics/cfg/datasets/TT100K.yaml +346 -0
  15. ultralytics/cfg/datasets/coco12-formats.yaml +101 -0
  16. ultralytics/cfg/default.yaml +3 -1
  17. ultralytics/cfg/models/26/yolo26-cls.yaml +33 -0
  18. ultralytics/cfg/models/26/yolo26-obb.yaml +52 -0
  19. ultralytics/cfg/models/26/yolo26-p2.yaml +60 -0
  20. ultralytics/cfg/models/26/yolo26-p6.yaml +62 -0
  21. ultralytics/cfg/models/26/yolo26-pose.yaml +53 -0
  22. ultralytics/cfg/models/26/yolo26-seg.yaml +52 -0
  23. ultralytics/cfg/models/26/yolo26.yaml +52 -0
  24. ultralytics/cfg/models/26/yoloe-26-seg.yaml +53 -0
  25. ultralytics/cfg/models/26/yoloe-26.yaml +53 -0
  26. ultralytics/data/annotator.py +2 -2
  27. ultralytics/data/augment.py +15 -0
  28. ultralytics/data/converter.py +76 -45
  29. ultralytics/data/dataset.py +1 -1
  30. ultralytics/data/utils.py +2 -2
  31. ultralytics/engine/exporter.py +34 -28
  32. ultralytics/engine/model.py +38 -37
  33. ultralytics/engine/predictor.py +17 -17
  34. ultralytics/engine/results.py +22 -15
  35. ultralytics/engine/trainer.py +83 -48
  36. ultralytics/engine/tuner.py +20 -11
  37. ultralytics/engine/validator.py +16 -16
  38. ultralytics/models/fastsam/predict.py +1 -1
  39. ultralytics/models/yolo/classify/predict.py +1 -1
  40. ultralytics/models/yolo/classify/train.py +1 -1
  41. ultralytics/models/yolo/classify/val.py +1 -1
  42. ultralytics/models/yolo/detect/predict.py +2 -2
  43. ultralytics/models/yolo/detect/train.py +6 -3
  44. ultralytics/models/yolo/detect/val.py +7 -1
  45. ultralytics/models/yolo/model.py +8 -8
  46. ultralytics/models/yolo/obb/predict.py +2 -2
  47. ultralytics/models/yolo/obb/train.py +3 -3
  48. ultralytics/models/yolo/obb/val.py +1 -1
  49. ultralytics/models/yolo/pose/predict.py +1 -1
  50. ultralytics/models/yolo/pose/train.py +3 -1
  51. ultralytics/models/yolo/pose/val.py +1 -1
  52. ultralytics/models/yolo/segment/predict.py +3 -3
  53. ultralytics/models/yolo/segment/train.py +4 -4
  54. ultralytics/models/yolo/segment/val.py +2 -2
  55. ultralytics/models/yolo/yoloe/train.py +6 -1
  56. ultralytics/models/yolo/yoloe/train_seg.py +6 -1
  57. ultralytics/nn/autobackend.py +14 -8
  58. ultralytics/nn/modules/__init__.py +8 -0
  59. ultralytics/nn/modules/block.py +128 -8
  60. ultralytics/nn/modules/head.py +788 -203
  61. ultralytics/nn/tasks.py +86 -41
  62. ultralytics/nn/text_model.py +5 -2
  63. ultralytics/optim/__init__.py +5 -0
  64. ultralytics/optim/muon.py +338 -0
  65. ultralytics/solutions/ai_gym.py +3 -3
  66. ultralytics/solutions/config.py +1 -1
  67. ultralytics/solutions/heatmap.py +1 -1
  68. ultralytics/solutions/instance_segmentation.py +2 -2
  69. ultralytics/solutions/object_counter.py +1 -1
  70. ultralytics/solutions/parking_management.py +1 -1
  71. ultralytics/solutions/solutions.py +2 -2
  72. ultralytics/trackers/byte_tracker.py +7 -7
  73. ultralytics/trackers/track.py +1 -1
  74. ultralytics/utils/__init__.py +8 -8
  75. ultralytics/utils/benchmarks.py +26 -26
  76. ultralytics/utils/callbacks/platform.py +173 -64
  77. ultralytics/utils/callbacks/tensorboard.py +2 -0
  78. ultralytics/utils/callbacks/wb.py +6 -1
  79. ultralytics/utils/checks.py +28 -9
  80. ultralytics/utils/dist.py +1 -0
  81. ultralytics/utils/downloads.py +5 -3
  82. ultralytics/utils/export/engine.py +19 -10
  83. ultralytics/utils/export/imx.py +38 -20
  84. ultralytics/utils/export/tensorflow.py +21 -21
  85. ultralytics/utils/files.py +2 -2
  86. ultralytics/utils/loss.py +597 -203
  87. ultralytics/utils/metrics.py +2 -1
  88. ultralytics/utils/ops.py +11 -2
  89. ultralytics/utils/patches.py +42 -0
  90. ultralytics/utils/plotting.py +3 -0
  91. ultralytics/utils/tal.py +100 -20
  92. ultralytics/utils/torch_utils.py +1 -1
  93. ultralytics/utils/tqdm.py +4 -1
  94. ultralytics/utils/tuner.py +2 -5
  95. {dgenerate_ultralytics_headless-8.3.248.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/entry_points.txt +0 -0
  96. {dgenerate_ultralytics_headless-8.3.248.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/licenses/LICENSE +0 -0
  97. {dgenerate_ultralytics_headless-8.3.248.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/top_level.txt +0 -0
ultralytics/utils/loss.py CHANGED
@@ -2,19 +2,20 @@
 
 from __future__ import annotations
 
+import math
 from typing import Any
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from ultralytics.utils.metrics import OKS_SIGMA
+from ultralytics.utils.metrics import OKS_SIGMA, RLE_WEIGHT
 from ultralytics.utils.ops import crop_mask, xywh2xyxy, xyxy2xywh
 from ultralytics.utils.tal import RotatedTaskAlignedAssigner, TaskAlignedAssigner, dist2bbox, dist2rbox, make_anchors
 from ultralytics.utils.torch_utils import autocast
 
 from .metrics import bbox_iou, probiou
-from .tal import bbox2dist
+from .tal import bbox2dist, rbox2dist
 
 
 class VarifocalLoss(nn.Module):
@@ -122,6 +123,8 @@ class BboxLoss(nn.Module):
         target_scores: torch.Tensor,
         target_scores_sum: torch.Tensor,
         fg_mask: torch.Tensor,
+        imgsz: torch.Tensor,
+        stride: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         """Compute IoU and DFL losses for bounding boxes."""
         weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
@@ -134,11 +137,76 @@
             loss_dfl = self.dfl_loss(pred_dist[fg_mask].view(-1, self.dfl_loss.reg_max), target_ltrb[fg_mask]) * weight
             loss_dfl = loss_dfl.sum() / target_scores_sum
         else:
-            loss_dfl = torch.tensor(0.0).to(pred_dist.device)
+            target_ltrb = bbox2dist(anchor_points, target_bboxes)
+            # normalize ltrb by image size
+            target_ltrb = target_ltrb * stride
+            target_ltrb[..., 0::2] /= imgsz[1]
+            target_ltrb[..., 1::2] /= imgsz[0]
+            pred_dist = pred_dist * stride
+            pred_dist[..., 0::2] /= imgsz[1]
+            pred_dist[..., 1::2] /= imgsz[0]
+            loss_dfl = (
+                F.l1_loss(pred_dist[fg_mask], target_ltrb[fg_mask], reduction="none").mean(-1, keepdim=True) * weight
+            )
+            loss_dfl = loss_dfl.sum() / target_scores_sum
 
         return loss_iou, loss_dfl
 
 
+class RLELoss(nn.Module):
+    """Residual Log-Likelihood Estimation Loss.
+
+    Args:
+        use_target_weight (bool): Option to use weighted loss.
+        size_average (bool): Option to average the loss by the batch_size.
+        residual (bool): Option to add L1 loss and let the flow learn the residual error distribution.
+
+    References:
+        https://arxiv.org/abs/2107.11291
+        https://github.com/open-mmlab/mmpose/blob/main/mmpose/models/losses/regression_loss.py
+    """
+
+    def __init__(self, use_target_weight: bool = True, size_average: bool = True, residual: bool = True):
+        """Initialize RLELoss with target weight and residual options.
+
+        Args:
+            use_target_weight (bool): Whether to use target weights for loss calculation.
+            size_average (bool): Whether to average the loss over elements.
+            residual (bool): Whether to include residual log-likelihood term.
+        """
+        super().__init__()
+        self.size_average = size_average
+        self.use_target_weight = use_target_weight
+        self.residual = residual
+
+    def forward(
+        self, sigma: torch.Tensor, log_phi: torch.Tensor, error: torch.Tensor, target_weight: torch.Tensor = None
+    ) -> torch.Tensor:
+        """
+        Args:
+            sigma (torch.Tensor): Output sigma, shape (N, D).
+            log_phi (torch.Tensor): Output log_phi, shape (N).
+            error (torch.Tensor): Error, shape (N, D).
+            target_weight (torch.Tensor): Weights across different joint types, shape (N).
+        """
+        log_sigma = torch.log(sigma)
+        loss = log_sigma - log_phi.unsqueeze(1)
+
+        if self.residual:
+            loss += torch.log(sigma * 2) + torch.abs(error)
+
+        if self.use_target_weight:
+            assert target_weight is not None, "'target_weight' should not be None when 'use_target_weight' is True."
+            if target_weight.dim() == 1:
+                target_weight = target_weight.unsqueeze(1)
+            loss *= target_weight
+
+        if self.size_average:
+            loss /= len(loss)
+
+        return loss.sum()
+
+
 class RotatedBboxLoss(BboxLoss):
     """Criterion class for computing training losses for rotated bounding boxes."""
 
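The `RLELoss` added above composes three pieces per visible keypoint: `log(sigma)` (heteroscedastic scale), `-log_phi` (density of the normalized residual under a flow model), and, with `residual=True`, a Laplace-style term `log(2*sigma) + |error|`. A minimal standalone sketch, assuming the class as added above and substituting a unit-Gaussian log-density for the flow (the real `flow_model` is attached to the detection head elsewhere in this release, not in this file):

```python
import math

import torch

N, D = 4, 2
sigma = torch.full((N, D), 0.3)    # predicted per-coordinate scales, (N, D)
error = torch.randn(N, D)          # normalized residuals (pred - gt) / sigma
# Unit-Gaussian stand-in for flow_model.log_prob(error), shape (N,)
log_phi = -0.5 * (error**2).sum(-1) - 0.5 * D * math.log(2 * math.pi)
target_weight = torch.ones(N)      # per-keypoint weights (RLE_WEIGHT in practice)

criterion = RLELoss(use_target_weight=True, size_average=True, residual=True)
print(criterion(sigma, log_phi, error, target_weight))  # scalar loss tensor
```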
@@ -155,6 +223,8 @@ class RotatedBboxLoss(BboxLoss):
         target_scores: torch.Tensor,
         target_scores_sum: torch.Tensor,
         fg_mask: torch.Tensor,
+        imgsz: torch.Tensor,
+        stride: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         """Compute IoU and DFL losses for rotated bounding boxes."""
         weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
@@ -163,15 +233,84 @@
 
         # DFL loss
         if self.dfl_loss:
-            target_ltrb = bbox2dist(anchor_points, xywh2xyxy(target_bboxes[..., :4]), self.dfl_loss.reg_max - 1)
+            target_ltrb = rbox2dist(
+                target_bboxes[..., :4], anchor_points, target_bboxes[..., 4:5], reg_max=self.dfl_loss.reg_max - 1
+            )
             loss_dfl = self.dfl_loss(pred_dist[fg_mask].view(-1, self.dfl_loss.reg_max), target_ltrb[fg_mask]) * weight
             loss_dfl = loss_dfl.sum() / target_scores_sum
         else:
-            loss_dfl = torch.tensor(0.0).to(pred_dist.device)
+            target_ltrb = rbox2dist(target_bboxes[..., :4], anchor_points, target_bboxes[..., 4:5])
+            target_ltrb = target_ltrb * stride
+            target_ltrb[..., 0::2] /= imgsz[1]
+            target_ltrb[..., 1::2] /= imgsz[0]
+            pred_dist = pred_dist * stride
+            pred_dist[..., 0::2] /= imgsz[1]
+            pred_dist[..., 1::2] /= imgsz[0]
+            loss_dfl = (
+                F.l1_loss(pred_dist[fg_mask], target_ltrb[fg_mask], reduction="none").mean(-1, keepdim=True) * weight
+            )
+            loss_dfl = loss_dfl.sum() / target_scores_sum
 
         return loss_iou, loss_dfl
 
 
+class MultiChannelDiceLoss(nn.Module):
+    """Criterion class for computing multi-channel Dice losses."""
+
+    def __init__(self, smooth: float = 1e-6, reduction: str = "mean"):
+        """Initialize MultiChannelDiceLoss with smoothing and reduction options.
+
+        Args:
+            smooth (float): Smoothing factor to avoid division by zero.
+            reduction (str): Reduction method ('mean', 'sum', or 'none').
+        """
+        super().__init__()
+        self.smooth = smooth
+        self.reduction = reduction
+
+    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+        """Calculate multi-channel Dice loss between predictions and targets."""
+        assert pred.size() == target.size(), "the size of predict and target must be equal."
+
+        pred = pred.sigmoid()
+        intersection = (pred * target).sum(dim=(2, 3))
+        union = pred.sum(dim=(2, 3)) + target.sum(dim=(2, 3))
+        dice = (2.0 * intersection + self.smooth) / (union + self.smooth)
+        dice_loss = 1.0 - dice
+        dice_loss = dice_loss.mean(dim=1)
+
+        if self.reduction == "mean":
+            return dice_loss.mean()
+        elif self.reduction == "sum":
+            return dice_loss.sum()
+        else:
+            return dice_loss
+
+
+class BCEDiceLoss(nn.Module):
+    """Criterion class for computing combined BCE and Dice losses."""
+
+    def __init__(self, weight_bce: float = 0.5, weight_dice: float = 0.5):
+        """Initialize BCEDiceLoss with BCE and Dice weight factors.
+
+        Args:
+            weight_bce (float): Weight factor for BCE loss component.
+            weight_dice (float): Weight factor for Dice loss component.
+        """
+        super().__init__()
+        self.weight_bce = weight_bce
+        self.weight_dice = weight_dice
+        self.bce = nn.BCEWithLogitsLoss()
+        self.dice = MultiChannelDiceLoss(smooth=1)
+
+    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+        """Calculate combined BCE and Dice loss between predictions and targets."""
+        _, _, mask_h, mask_w = pred.shape
+        if tuple(target.shape[-2:]) != (mask_h, mask_w):  # downsample to the same size as pred
+            target = F.interpolate(target, (mask_h, mask_w), mode="nearest")
+        return self.weight_bce * self.bce(pred, target) + self.weight_dice * self.dice(pred, target)
+
+
 class KeypointLoss(nn.Module):
     """Criterion class for computing keypoint losses."""
 
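A quick sanity check of the new Dice pieces (illustrative only, using the classes as defined above): per channel, Dice is `(2*|P∩T| + s) / (|P| + |T| + s)`, so confident, correct logits drive the Dice term toward 0, and `BCEDiceLoss` then mixes BCE and Dice 50/50 by default:

```python
import torch

pred = torch.full((2, 3, 8, 8), 8.0)  # logits strongly predicting "on"
target = torch.ones(2, 3, 8, 8)       # matching all-ones target

print(MultiChannelDiceLoss()(pred, target).item())  # ~2e-4, near-perfect overlap
print(BCEDiceLoss()(pred, target).item())           # ~0.5*BCE + 0.5*Dice, also near 0
```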
@@ -194,7 +333,7 @@ class KeypointLoss(nn.Module):
 class v8DetectionLoss:
     """Criterion class for computing training losses for YOLOv8 object detection."""
 
-    def __init__(self, model, tal_topk: int = 10):  # model must be de-paralleled
+    def __init__(self, model, tal_topk: int = 10, tal_topk2: int | None = None):  # model must be de-paralleled
         """Initialize v8DetectionLoss with model parameters and task-aligned assignment settings."""
         device = next(model.parameters()).device  # get model device
         h = model.args  # hyperparameters
@@ -210,7 +349,14 @@
 
         self.use_dfl = m.reg_max > 1
 
-        self.assigner = TaskAlignedAssigner(topk=tal_topk, num_classes=self.nc, alpha=0.5, beta=6.0)
+        self.assigner = TaskAlignedAssigner(
+            topk=tal_topk,
+            num_classes=self.nc,
+            alpha=0.5,
+            beta=6.0,
+            stride=self.stride.tolist(),
+            topk2=tal_topk2,
+        )
         self.bbox_loss = BboxLoss(m.reg_max).to(device)
         self.proj = torch.arange(m.reg_max, dtype=torch.float, device=device)
 
@@ -240,35 +386,31 @@
         # pred_dist = (pred_dist.view(b, a, c // 4, 4).softmax(2) * self.proj.type(pred_dist.dtype).view(1, 1, -1, 1)).sum(2)
         return dist2bbox(pred_dist, anchor_points, xywh=False)
 
-    def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
-        """Calculate the sum of the loss for box, cls and dfl multiplied by batch size."""
+    def get_assigned_targets_and_loss(self, preds: dict[str, torch.Tensor], batch: dict[str, Any]) -> tuple:
+        """Calculate the sum of the loss for box, cls and dfl multiplied by batch size and return foreground mask and
+        target indices.
+        """
         loss = torch.zeros(3, device=self.device)  # box, cls, dfl
-        feats = preds[1] if isinstance(preds, tuple) else preds
-        pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
-            (self.reg_max * 4, self.nc), 1
+        pred_distri, pred_scores = (
+            preds["boxes"].permute(0, 2, 1).contiguous(),
+            preds["scores"].permute(0, 2, 1).contiguous(),
         )
-
-        pred_scores = pred_scores.permute(0, 2, 1).contiguous()
-        pred_distri = pred_distri.permute(0, 2, 1).contiguous()
+        anchor_points, stride_tensor = make_anchors(preds["feats"], self.stride, 0.5)
 
         dtype = pred_scores.dtype
         batch_size = pred_scores.shape[0]
-        imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]  # image size (h,w)
-        anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
+        imgsz = torch.tensor(preds["feats"][0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]
 
         # Targets
         targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
-        targets = self.preprocess(targets, batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
+        targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
         gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
         mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
 
         # Pboxes
         pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)
-        # dfl_conf = pred_distri.view(batch_size, -1, 4, self.reg_max).detach().softmax(-1)
-        # dfl_conf = (dfl_conf.amax(-1).mean(-1) + dfl_conf.amax(-1).amin(-1)) / 2
 
-        _, target_bboxes, target_scores, fg_mask, _ = self.assigner(
-            # pred_scores.detach().sigmoid() * 0.8 + dfl_conf.unsqueeze(-1) * 0.2,
+        _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
             pred_scores.detach().sigmoid(),
             (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
             anchor_points * stride_tensor,
@@ -280,7 +422,6 @@
         target_scores_sum = max(target_scores.sum(), 1)
 
         # Cls loss
-        # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum  # VFL way
        loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE
 
         # Bbox loss
@@ -293,105 +434,108 @@
             target_scores,
             target_scores_sum,
             fg_mask,
+            imgsz,
+            stride_tensor,
         )
 
         loss[0] *= self.hyp.box  # box gain
         loss[1] *= self.hyp.cls  # cls gain
         loss[2] *= self.hyp.dfl  # dfl gain
+        return (
+            (fg_mask, target_gt_idx, target_bboxes, anchor_points, stride_tensor),
+            loss,
+            loss.detach(),
+        )  # loss(box, cls, dfl)
 
-        return loss * batch_size, loss.detach()  # loss(box, cls, dfl)
+    def parse_output(
+        self, preds: dict[str, torch.Tensor] | tuple[torch.Tensor, dict[str, torch.Tensor]]
+    ) -> torch.Tensor:
+        """Parse model predictions to extract features."""
+        return preds[1] if isinstance(preds, tuple) else preds
+
+    def __call__(
+        self,
+        preds: dict[str, torch.Tensor] | tuple[torch.Tensor, dict[str, torch.Tensor]],
+        batch: dict[str, torch.Tensor],
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Calculate the sum of the loss for box, cls and dfl multiplied by batch size."""
+        return self.loss(self.parse_output(preds), batch)
+
+    def loss(self, preds: dict[str, torch.Tensor], batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+        """A wrapper for get_assigned_targets_and_loss and parse_output."""
+        batch_size = preds["boxes"].shape[0]
+        loss, loss_detach = self.get_assigned_targets_and_loss(preds, batch)[1:]
+        return loss * batch_size, loss_detach
 
 
 class v8SegmentationLoss(v8DetectionLoss):
     """Criterion class for computing training losses for YOLOv8 segmentation."""
 
-    def __init__(self, model):  # model must be de-paralleled
+    def __init__(self, model, tal_topk: int = 10, tal_topk2: int | None = None):  # model must be de-paralleled
         """Initialize the v8SegmentationLoss class with model parameters and mask overlap setting."""
-        super().__init__(model)
+        super().__init__(model, tal_topk, tal_topk2)
         self.overlap = model.args.overlap_mask
+        self.bcedice_loss = BCEDiceLoss(weight_bce=0.5, weight_dice=0.5)
 
-    def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+    def loss(self, preds: dict[str, torch.Tensor], batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
         """Calculate and return the combined loss for detection and segmentation."""
-        loss = torch.zeros(4, device=self.device)  # box, seg, cls, dfl
-        feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
-        batch_size, _, mask_h, mask_w = proto.shape  # batch size, number of masks, mask height, mask width
-        pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
-            (self.reg_max * 4, self.nc), 1
-        )
-
-        # B, grids, ..
-        pred_scores = pred_scores.permute(0, 2, 1).contiguous()
-        pred_distri = pred_distri.permute(0, 2, 1).contiguous()
-        pred_masks = pred_masks.permute(0, 2, 1).contiguous()
-
-        dtype = pred_scores.dtype
-        imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]  # image size (h,w)
-        anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
-
-        # Targets
-        try:
-            batch_idx = batch["batch_idx"].view(-1, 1)
-            targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)
-            targets = self.preprocess(targets, batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
-            gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
-            mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
-        except RuntimeError as e:
-            raise TypeError(
-                "ERROR ❌ segment dataset incorrectly formatted or not a segment dataset.\n"
-                "This error can occur when incorrectly training a 'segment' model on a 'detect' dataset, "
-                "i.e. 'yolo train model=yolo11n-seg.pt data=coco8.yaml'.\nVerify your dataset is a "
-                "correctly formatted 'segment' dataset using 'data=coco8-seg.yaml' "
-                "as an example.\nSee https://docs.ultralytics.com/datasets/segment/ for help."
-            ) from e
-
-        # Pboxes
-        pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)
-
-        _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
-            pred_scores.detach().sigmoid(),
-            (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
-            anchor_points * stride_tensor,
-            gt_labels,
-            gt_bboxes,
-            mask_gt,
-        )
-
-        target_scores_sum = max(target_scores.sum(), 1)
-
-        # Cls loss
-        # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum  # VFL way
-        loss[2] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE
+        pred_masks, proto = preds["mask_coefficient"].permute(0, 2, 1).contiguous(), preds["proto"]
+        loss = torch.zeros(5, device=self.device)  # box, seg, cls, dfl
+        if isinstance(proto, tuple) and len(proto) == 2:
+            proto, pred_semseg = proto
+        else:
+            pred_semseg = None
+        (fg_mask, target_gt_idx, target_bboxes, _, _), det_loss, _ = self.get_assigned_targets_and_loss(preds, batch)
+        # NOTE: re-assign index for consistency for now. Need to be removed in the future.
+        loss[0], loss[2], loss[3] = det_loss[0], det_loss[1], det_loss[2]
 
+        batch_size, _, mask_h, mask_w = proto.shape  # batch size, number of masks, mask height, mask width
         if fg_mask.sum():
-            # Bbox loss
-            loss[0], loss[3] = self.bbox_loss(
-                pred_distri,
-                pred_bboxes,
-                anchor_points,
-                target_bboxes / stride_tensor,
-                target_scores,
-                target_scores_sum,
-                fg_mask,
-            )
             # Masks loss
             masks = batch["masks"].to(self.device).float()
             if tuple(masks.shape[-2:]) != (mask_h, mask_w):  # downsample
-                masks = F.interpolate(masks[None], (mask_h, mask_w), mode="nearest")[0]
+                # masks = F.interpolate(masks[None], (mask_h, mask_w), mode="nearest")[0]
+                proto = F.interpolate(proto, masks.shape[-2:], mode="bilinear", align_corners=False)
 
+            imgsz = (
+                torch.tensor(preds["feats"][0].shape[2:], device=self.device, dtype=pred_masks.dtype) * self.stride[0]
+            )
             loss[1] = self.calculate_segmentation_loss(
-                fg_mask, masks, target_gt_idx, target_bboxes, batch_idx, proto, pred_masks, imgsz, self.overlap
+                fg_mask,
+                masks,
+                target_gt_idx,
+                target_bboxes,
+                batch["batch_idx"].view(-1, 1),
+                proto,
+                pred_masks,
+                imgsz,
             )
+            if pred_semseg is not None:
+                sem_masks = batch["sem_masks"].to(self.device)  # NxHxW
+                sem_masks = F.one_hot(sem_masks.long(), num_classes=self.nc).permute(0, 3, 1, 2).float()  # NxCxHxW
+
+                if self.overlap:
+                    mask_zero = masks == 0  # NxHxW
+                    sem_masks[mask_zero.unsqueeze(1).expand_as(sem_masks)] = 0
+                else:
+                    batch_idx = batch["batch_idx"].view(-1)  # [total_instances]
+                    for i in range(batch_size):
+                        instance_mask_i = masks[batch_idx == i]  # [num_instances_i, H, W]
+                        if len(instance_mask_i) == 0:
+                            continue
+                        sem_masks[i, :, instance_mask_i.sum(dim=0) == 0] = 0
+
+                loss[4] = self.bcedice_loss(pred_semseg, sem_masks)
+                loss[4] *= self.hyp.box  # seg gain
 
         # WARNING: lines below prevent Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
         else:
             loss[1] += (proto * 0).sum() + (pred_masks * 0).sum()  # inf sums may lead to nan loss
+            if pred_semseg is not None:
+                loss[4] += (pred_semseg * 0).sum()
 
-        loss[0] *= self.hyp.box  # box gain
         loss[1] *= self.hyp.box  # seg gain
-        loss[2] *= self.hyp.cls  # cls gain
-        loss[3] *= self.hyp.dfl  # dfl gain
-
-        return loss * batch_size, loss.detach()  # loss(box, seg, cls, dfl)
+        return loss * batch_size, loss.detach()  # loss(box, cls, dfl)
 
     @staticmethod
     def single_mask_loss(
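The hunks above switch every criterion from positional tuples of feature maps to a dict of named head outputs: `__call__` routes through `parse_output` into `loss`, and `get_assigned_targets_and_loss` also hands back the assigner state (`fg_mask`, `target_gt_idx`, ...) so task-specific losses can reuse it. A rough sketch of the expected keys and shapes, inferred from the tensor ops above (the head-side contract lives in ultralytics/nn/modules/head.py, so treat the details as an assumption):

```python
import torch

b, nc, reg_max = 2, 80, 16
feats = [torch.randn(b, 64, s, s) for s in (80, 40, 20)]  # P3/P4/P5 maps for imgsz=640
na = sum(s * s for s in (80, 40, 20))  # 8400 anchor points across the three levels

preds = {
    "feats": feats,                            # used only for anchor generation and imgsz
    "boxes": torch.randn(b, reg_max * 4, na),  # DFL logits, channels-first (permuted inside)
    "scores": torch.randn(b, nc, na),          # class logits, channels-first
}
# criterion = v8DetectionLoss(model)      # model comes from the trainer
# total, items = criterion(preds, batch)  # batch holds batch_idx / cls / bboxes
```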
@@ -427,7 +571,6 @@ class v8SegmentationLoss(v8DetectionLoss):
         proto: torch.Tensor,
         pred_masks: torch.Tensor,
         imgsz: torch.Tensor,
-        overlap: bool,
     ) -> torch.Tensor:
         """Calculate the loss for instance segmentation.
 
@@ -440,7 +583,6 @@
             proto (torch.Tensor): Prototype masks of shape (BS, 32, H, W).
             pred_masks (torch.Tensor): Predicted masks for each anchor of shape (BS, N_anchors, 32).
             imgsz (torch.Tensor): Size of the input image as a tensor of shape (2), i.e., (H, W).
-            overlap (bool): Whether the masks in `masks` tensor overlap.
 
         Returns:
             (torch.Tensor): The calculated loss for instance segmentation.
@@ -466,7 +608,7 @@
             fg_mask_i, target_gt_idx_i, pred_masks_i, proto_i, mxyxy_i, marea_i, masks_i = single_i
             if fg_mask_i.any():
                 mask_idx = target_gt_idx_i[fg_mask_i]
-                if overlap:
+                if self.overlap:
                     gt_mask = masks_i == (mask_idx + 1).view(-1, 1, 1)
                     gt_mask = gt_mask.float()
                 else:
@@ -486,9 +628,9 @@
 class v8PoseLoss(v8DetectionLoss):
     """Criterion class for computing training losses for YOLOv8 pose estimation."""
 
-    def __init__(self, model):  # model must be de-paralleled
+    def __init__(self, model, tal_topk: int = 10, tal_topk2: int = 10):  # model must be de-paralleled
         """Initialize v8PoseLoss with model parameters and keypoint-specific loss functions."""
-        super().__init__(model)
+        super().__init__(model, tal_topk, tal_topk2)
         self.kpt_shape = model.model[-1].kpt_shape
         self.bce_pose = nn.BCEWithLogitsLoss()
         is_pose = self.kpt_shape == [17, 3]
@@ -496,69 +638,40 @@
         sigmas = torch.from_numpy(OKS_SIGMA).to(self.device) if is_pose else torch.ones(nkpt, device=self.device) / nkpt
         self.keypoint_loss = KeypointLoss(sigmas=sigmas)
 
-    def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+    def loss(self, preds: dict[str, torch.Tensor], batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
         """Calculate the total loss and detach it for pose estimation."""
-        loss = torch.zeros(5, device=self.device)  # box, pose, kobj, cls, dfl
-        feats, pred_kpts = preds if isinstance(preds[0], list) else preds[1]
-        pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
-            (self.reg_max * 4, self.nc), 1
+        pred_kpts = preds["kpts"].permute(0, 2, 1).contiguous()
+        loss = torch.zeros(5, device=self.device)  # box, cls, dfl, kpt_location, kpt_visibility
+        (fg_mask, target_gt_idx, target_bboxes, anchor_points, stride_tensor), det_loss, _ = (
+            self.get_assigned_targets_and_loss(preds, batch)
         )
+        # NOTE: re-assign index for consistency for now. Need to be removed in the future.
+        loss[0], loss[3], loss[4] = det_loss[0], det_loss[1], det_loss[2]
 
-        # B, grids, ..
-        pred_scores = pred_scores.permute(0, 2, 1).contiguous()
-        pred_distri = pred_distri.permute(0, 2, 1).contiguous()
-        pred_kpts = pred_kpts.permute(0, 2, 1).contiguous()
-
-        dtype = pred_scores.dtype
-        imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]  # image size (h,w)
-        anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
-
-        # Targets
-        batch_size = pred_scores.shape[0]
-        batch_idx = batch["batch_idx"].view(-1, 1)
-        targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)
-        targets = self.preprocess(targets, batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
-        gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
-        mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
+        batch_size = pred_kpts.shape[0]
+        imgsz = torch.tensor(preds["feats"][0].shape[2:], device=self.device, dtype=pred_kpts.dtype) * self.stride[0]
 
         # Pboxes
-        pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)
         pred_kpts = self.kpts_decode(anchor_points, pred_kpts.view(batch_size, -1, *self.kpt_shape))  # (b, h*w, 17, 3)
 
-        _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
-            pred_scores.detach().sigmoid(),
-            (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
-            anchor_points * stride_tensor,
-            gt_labels,
-            gt_bboxes,
-            mask_gt,
-        )
-
-        target_scores_sum = max(target_scores.sum(), 1)
-
-        # Cls loss
-        # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum  # VFL way
-        loss[3] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE
-
         # Bbox loss
         if fg_mask.sum():
-            target_bboxes /= stride_tensor
-            loss[0], loss[4] = self.bbox_loss(
-                pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
-            )
             keypoints = batch["keypoints"].to(self.device).float().clone()
             keypoints[..., 0] *= imgsz[1]
             keypoints[..., 1] *= imgsz[0]
 
             loss[1], loss[2] = self.calculate_keypoints_loss(
-                fg_mask, target_gt_idx, keypoints, batch_idx, stride_tensor, target_bboxes, pred_kpts
+                fg_mask,
+                target_gt_idx,
+                keypoints,
+                batch["batch_idx"].view(-1, 1),
+                stride_tensor,
+                target_bboxes,
+                pred_kpts,
             )
 
-        loss[0] *= self.hyp.box  # box gain
         loss[1] *= self.hyp.pose  # pose gain
         loss[2] *= self.hyp.kobj  # kobj gain
-        loss[3] *= self.hyp.cls  # cls gain
-        loss[4] *= self.hyp.dfl  # dfl gain
 
         return loss * batch_size, loss.detach()  # loss(box, pose, kobj, cls, dfl)
 
@@ -571,34 +684,23 @@
         y[..., 1] += anchor_points[:, [1]] - 0.5
         return y
 
-    def calculate_keypoints_loss(
+    def _select_target_keypoints(
         self,
-        masks: torch.Tensor,
-        target_gt_idx: torch.Tensor,
         keypoints: torch.Tensor,
         batch_idx: torch.Tensor,
-        stride_tensor: torch.Tensor,
-        target_bboxes: torch.Tensor,
-        pred_kpts: torch.Tensor,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
-        """Calculate the keypoints loss for the model.
-
-        This function calculates the keypoints loss and keypoints object loss for a given batch. The keypoints loss is
-        based on the difference between the predicted keypoints and ground truth keypoints. The keypoints object loss is
-        a binary classification loss that classifies whether a keypoint is present or not.
+        target_gt_idx: torch.Tensor,
+        masks: torch.Tensor,
+    ) -> torch.Tensor:
+        """Select target keypoints for each anchor based on batch index and target ground truth index.
 
         Args:
-            masks (torch.Tensor): Binary mask tensor indicating object presence, shape (BS, N_anchors).
-            target_gt_idx (torch.Tensor): Index tensor mapping anchors to ground truth objects, shape (BS, N_anchors).
             keypoints (torch.Tensor): Ground truth keypoints, shape (N_kpts_in_batch, N_kpts_per_object, kpts_dim).
             batch_idx (torch.Tensor): Batch index tensor for keypoints, shape (N_kpts_in_batch, 1).
-            stride_tensor (torch.Tensor): Stride tensor for anchors, shape (N_anchors, 1).
-            target_bboxes (torch.Tensor): Ground truth boxes in (x1, y1, x2, y2) format, shape (BS, N_anchors, 4).
-            pred_kpts (torch.Tensor): Predicted keypoints, shape (BS, N_anchors, N_kpts_per_object, kpts_dim).
+            target_gt_idx (torch.Tensor): Index tensor mapping anchors to ground truth objects, shape (BS, N_anchors).
+            masks (torch.Tensor): Binary mask tensor indicating object presence, shape (BS, N_anchors).
 
         Returns:
-            kpts_loss (torch.Tensor): The keypoints loss.
-            kpts_obj_loss (torch.Tensor): The keypoints object loss.
+            (torch.Tensor): Selected keypoints tensor, shape (BS, N_anchors, N_kpts_per_object, kpts_dim).
         """
         batch_idx = batch_idx.flatten()
         batch_size = len(masks)
@@ -625,6 +727,40 @@
             1, target_gt_idx_expanded.expand(-1, -1, keypoints.shape[1], keypoints.shape[2])
         )
 
+        return selected_keypoints
+
+    def calculate_keypoints_loss(
+        self,
+        masks: torch.Tensor,
+        target_gt_idx: torch.Tensor,
+        keypoints: torch.Tensor,
+        batch_idx: torch.Tensor,
+        stride_tensor: torch.Tensor,
+        target_bboxes: torch.Tensor,
+        pred_kpts: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Calculate the keypoints loss for the model.
+
+        This function calculates the keypoints loss and keypoints object loss for a given batch. The keypoints loss is
+        based on the difference between the predicted keypoints and ground truth keypoints. The keypoints object loss is
+        a binary classification loss that classifies whether a keypoint is present or not.
+
+        Args:
+            masks (torch.Tensor): Binary mask tensor indicating object presence, shape (BS, N_anchors).
+            target_gt_idx (torch.Tensor): Index tensor mapping anchors to ground truth objects, shape (BS, N_anchors).
+            keypoints (torch.Tensor): Ground truth keypoints, shape (N_kpts_in_batch, N_kpts_per_object, kpts_dim).
+            batch_idx (torch.Tensor): Batch index tensor for keypoints, shape (N_kpts_in_batch, 1).
+            stride_tensor (torch.Tensor): Stride tensor for anchors, shape (N_anchors, 1).
+            target_bboxes (torch.Tensor): Ground truth boxes in (x1, y1, x2, y2) format, shape (BS, N_anchors, 4).
+            pred_kpts (torch.Tensor): Predicted keypoints, shape (BS, N_anchors, N_kpts_per_object, kpts_dim).
+
+        Returns:
+            kpts_loss (torch.Tensor): The keypoints loss.
+            kpts_obj_loss (torch.Tensor): The keypoints object loss.
+        """
+        # Select target keypoints using helper method
+        selected_keypoints = self._select_target_keypoints(keypoints, batch_idx, target_gt_idx, masks)
+
         # Divide coordinates by stride
         selected_keypoints[..., :2] /= stride_tensor.view(1, -1, 1, 1)
 
@@ -632,6 +768,7 @@
         kpts_obj_loss = 0
 
         if masks.any():
+            target_bboxes /= stride_tensor
            gt_kpt = selected_keypoints[masks]
             area = xyxy2xywh(target_bboxes[masks])[:, 2:].prod(1, keepdim=True)
             pred_kpt = pred_kpts[masks]
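The keypoint refactor above is behavior-preserving: the gather that matches each anchor to its assigned GT instance simply moved into the reusable `_select_target_keypoints` helper. In miniature (illustrative shapes only, not the real ones):

```python
import torch

# 1 image, 3 GT instances, 2 keypoints of (x, y, v) each, 4 anchors.
keypoints = torch.arange(18, dtype=torch.float).view(1, 3, 2, 3)  # (BS, max_gt, nk, 3)
target_gt_idx = torch.tensor([[0, 2, 1, 2]])                      # (BS, N_anchors)

idx = target_gt_idx[:, :, None, None].expand(-1, -1, 2, 3)
selected = keypoints.gather(1, idx)  # per-anchor keypoints of the matched instance
print(selected.shape)                # torch.Size([1, 4, 2, 3])
```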
@@ -644,6 +781,172 @@
         return kpts_loss, kpts_obj_loss
 
 
+class PoseLoss26(v8PoseLoss):
+    """Criterion class for computing training losses for YOLOv8 pose estimation with RLE loss support."""
+
+    def __init__(self, model, tal_topk: int = 10, tal_topk2: int | None = None):  # model must be de-paralleled
+        """Initialize PoseLoss26 with model parameters and keypoint-specific loss functions including RLE loss."""
+        super().__init__(model, tal_topk, tal_topk2)
+        is_pose = self.kpt_shape == [17, 3]
+        nkpt = self.kpt_shape[0]  # number of keypoints
+        self.rle_loss = None
+        self.flow_model = model.model[-1].flow_model if hasattr(model.model[-1], "flow_model") else None
+        if self.flow_model is not None:
+            self.rle_loss = RLELoss(use_target_weight=True).to(self.device)
+            self.target_weights = (
+                torch.from_numpy(RLE_WEIGHT).to(self.device) if is_pose else torch.ones(nkpt, device=self.device)
+            )
+
+    def loss(self, preds: dict[str, torch.Tensor], batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+        """Calculate the total loss and detach it for pose estimation."""
+        pred_kpts = preds["kpts"].permute(0, 2, 1).contiguous()
+        loss = torch.zeros(6 if self.rle_loss else 5, device=self.device)  # box, cls, dfl, kpt_location, kpt_visibility
+        (fg_mask, target_gt_idx, target_bboxes, anchor_points, stride_tensor), det_loss, _ = (
+            self.get_assigned_targets_and_loss(preds, batch)
+        )
+        # NOTE: re-assign index for consistency for now. Need to be removed in the future.
+        loss[0], loss[3], loss[4] = det_loss[0], det_loss[1], det_loss[2]
+
+        batch_size = pred_kpts.shape[0]
+        imgsz = torch.tensor(preds["feats"][0].shape[2:], device=self.device, dtype=pred_kpts.dtype) * self.stride[0]
+
+        pred_kpts = pred_kpts.view(batch_size, -1, *self.kpt_shape)  # (b, h*w, 17, 3)
+
+        if self.rle_loss and preds.get("kpts_sigma", None) is not None:
+            pred_sigma = preds["kpts_sigma"].permute(0, 2, 1).contiguous()
+            pred_sigma = pred_sigma.view(batch_size, -1, self.kpt_shape[0], 2)  # (b, h*w, 17, 2)
+            pred_kpts = torch.cat([pred_kpts, pred_sigma], dim=-1)  # (b, h*w, 17, 5)
+
+        pred_kpts = self.kpts_decode(anchor_points, pred_kpts)
+
+        # Bbox loss
+        if fg_mask.sum():
+            keypoints = batch["keypoints"].to(self.device).float().clone()
+            keypoints[..., 0] *= imgsz[1]
+            keypoints[..., 1] *= imgsz[0]
+
+            keypoints_loss = self.calculate_keypoints_loss(
+                fg_mask,
+                target_gt_idx,
+                keypoints,
+                batch["batch_idx"].view(-1, 1),
+                stride_tensor,
+                target_bboxes,
+                pred_kpts,
+            )
+            loss[1] = keypoints_loss[0]
+            loss[2] = keypoints_loss[1]
+            if self.rle_loss is not None:
+                loss[5] = keypoints_loss[2]
+
+        loss[1] *= self.hyp.pose  # pose gain
+        loss[2] *= self.hyp.kobj  # kobj gain
+        if self.rle_loss is not None:
+            loss[5] *= self.hyp.rle  # rle gain
+
+        return loss * batch_size, loss.detach()  # loss(box, cls, dfl, kpt_location, kpt_visibility)
+
+    @staticmethod
+    def kpts_decode(anchor_points: torch.Tensor, pred_kpts: torch.Tensor) -> torch.Tensor:
+        """Decode predicted keypoints to image coordinates."""
+        y = pred_kpts.clone()
+        y[..., 0] += anchor_points[:, [0]]
+        y[..., 1] += anchor_points[:, [1]]
+        return y
+
+    def calculate_rle_loss(self, pred_kpt: torch.Tensor, gt_kpt: torch.Tensor, kpt_mask: torch.Tensor) -> torch.Tensor:
+        """Calculate the RLE (Residual Log-likelihood Estimation) loss for keypoints.
+
+        Args:
+            pred_kpt (torch.Tensor): Predicted keypoints with sigma, shape (N, kpts_dim) where kpts_dim >= 4.
+            gt_kpt (torch.Tensor): Ground truth keypoints, shape (N, kpts_dim).
+            kpt_mask (torch.Tensor): Mask for valid keypoints, shape (N, num_keypoints).
+
+        Returns:
+            (torch.Tensor): The RLE loss.
+        """
+        pred_kpt_visible = pred_kpt[kpt_mask]
+        gt_kpt_visible = gt_kpt[kpt_mask]
+        pred_coords = pred_kpt_visible[:, 0:2]
+        pred_sigma = pred_kpt_visible[:, -2:]
+        gt_coords = gt_kpt_visible[:, 0:2]
+
+        target_weights = self.target_weights.unsqueeze(0).repeat(kpt_mask.shape[0], 1)
+        target_weights = target_weights[kpt_mask]
+
+        pred_sigma = pred_sigma.sigmoid()
+        error = (pred_coords - gt_coords) / (pred_sigma + 1e-9)
+
+        # Filter out NaN and Inf values to prevent MultivariateNormal validation errors
+        valid_mask = ~(torch.isnan(error) | torch.isinf(error)).any(dim=-1)
+        if not valid_mask.any():
+            return torch.tensor(0.0, device=pred_kpt.device)
+
+        error = error[valid_mask]
+        error = error.clamp(-100, 100)  # Prevent numerical instability
+        pred_sigma = pred_sigma[valid_mask]
+        target_weights = target_weights[valid_mask]
+
+        log_phi = self.flow_model.log_prob(error)
+
+        return self.rle_loss(pred_sigma, log_phi, error, target_weights)
+
+    def calculate_keypoints_loss(
+        self,
+        masks: torch.Tensor,
+        target_gt_idx: torch.Tensor,
+        keypoints: torch.Tensor,
+        batch_idx: torch.Tensor,
+        stride_tensor: torch.Tensor,
+        target_bboxes: torch.Tensor,
+        pred_kpts: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Calculate the keypoints loss for the model.
+
+        This function calculates the keypoints loss and keypoints object loss for a given batch. The keypoints loss is
+        based on the difference between the predicted keypoints and ground truth keypoints. The keypoints object loss is
+        a binary classification loss that classifies whether a keypoint is present or not.
+
+        Args:
+            masks (torch.Tensor): Binary mask tensor indicating object presence, shape (BS, N_anchors).
+            target_gt_idx (torch.Tensor): Index tensor mapping anchors to ground truth objects, shape (BS, N_anchors).
+            keypoints (torch.Tensor): Ground truth keypoints, shape (N_kpts_in_batch, N_kpts_per_object, kpts_dim).
+            batch_idx (torch.Tensor): Batch index tensor for keypoints, shape (N_kpts_in_batch, 1).
+            stride_tensor (torch.Tensor): Stride tensor for anchors, shape (N_anchors, 1).
+            target_bboxes (torch.Tensor): Ground truth boxes in (x1, y1, x2, y2) format, shape (BS, N_anchors, 4).
+            pred_kpts (torch.Tensor): Predicted keypoints, shape (BS, N_anchors, N_kpts_per_object, kpts_dim).
+
+        Returns:
+            kpts_loss (torch.Tensor): The keypoints loss.
+            kpts_obj_loss (torch.Tensor): The keypoints object loss.
+            rle_loss (torch.Tensor): The RLE loss.
+        """
+        # Select target keypoints using inherited helper method
+        selected_keypoints = self._select_target_keypoints(keypoints, batch_idx, target_gt_idx, masks)
+
+        # Divide coordinates by stride
+        selected_keypoints[..., :2] /= stride_tensor.view(1, -1, 1, 1)
+
+        kpts_loss = 0
+        kpts_obj_loss = 0
+        rle_loss = 0
+
+        if masks.any():
+            target_bboxes /= stride_tensor
+            gt_kpt = selected_keypoints[masks]
+            area = xyxy2xywh(target_bboxes[masks])[:, 2:].prod(1, keepdim=True)
+            pred_kpt = pred_kpts[masks]
+            kpt_mask = gt_kpt[..., 2] != 0 if gt_kpt.shape[-1] == 3 else torch.full_like(gt_kpt[..., 0], True)
+            kpts_loss = self.keypoint_loss(pred_kpt, gt_kpt, kpt_mask, area)  # pose loss
+
+            if self.rle_loss is not None and (pred_kpt.shape[-1] == 4 or pred_kpt.shape[-1] == 5):
+                rle_loss = self.calculate_rle_loss(pred_kpt, gt_kpt, kpt_mask)
+            if pred_kpt.shape[-1] == 3 or pred_kpt.shape[-1] == 5:
+                kpts_obj_loss = self.bce_pose(pred_kpt[..., 2], kpt_mask.float())  # keypoint obj loss
+
+        return kpts_loss, kpts_obj_loss, rle_loss
+
+
 class v8ClassificationLoss:
     """Criterion class for computing training losses for classification."""
 
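`calculate_rle_loss` above divides the raw coordinate error by the sigmoid-squashed `sigma`, clamps it, and scores it with `flow_model.log_prob`; the flow itself ships with the head (see the ultralytics/nn/modules/head.py entry in this diff), not with this file. As a mental model only (an assumed stand-in, not the shipped flow), swapping the flow for independent unit Gaussians reduces the scheme to a heteroscedastic Gaussian NLL plus the residual term of `RLELoss`:

```python
import torch
from torch.distributions import Normal

class GaussianFlowStandIn:
    """Assumed stand-in for the head's flow_model, for illustration only."""

    def log_prob(self, error: torch.Tensor) -> torch.Tensor:
        # Independent standard normals over the two normalized coordinates.
        return Normal(0.0, 1.0).log_prob(error).sum(-1)

error = torch.randn(6, 2).clamp(-100, 100)  # mirrors the clamp in calculate_rle_loss
print(GaussianFlowStandIn().log_prob(error).shape)  # torch.Size([6])
```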
@@ -657,10 +960,17 @@ class v8ClassificationLoss:
 class v8OBBLoss(v8DetectionLoss):
     """Calculates losses for object detection, classification, and box distribution in rotated YOLO models."""
 
-    def __init__(self, model):
+    def __init__(self, model, tal_topk=10, tal_topk2: int | None = None):
         """Initialize v8OBBLoss with model, assigner, and rotated bbox loss; model must be de-paralleled."""
-        super().__init__(model)
-        self.assigner = RotatedTaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0)
+        super().__init__(model, tal_topk=tal_topk)
+        self.assigner = RotatedTaskAlignedAssigner(
+            topk=tal_topk,
+            num_classes=self.nc,
+            alpha=0.5,
+            beta=6.0,
+            stride=self.stride.tolist(),
+            topk2=tal_topk2,
+        )
         self.bbox_loss = RotatedBboxLoss(self.reg_max).to(self.device)
 
     def preprocess(self, targets: torch.Tensor, batch_size: int, scale_tensor: torch.Tensor) -> torch.Tensor:
@@ -680,23 +990,19 @@
                 out[j, :n] = torch.cat([targets[matches, 1:2], bboxes], dim=-1)
         return out
 
-    def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+    def loss(self, preds: dict[str, torch.Tensor], batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
         """Calculate and return the loss for oriented bounding box detection."""
-        loss = torch.zeros(3, device=self.device)  # box, cls, dfl
-        feats, pred_angle = preds if isinstance(preds[0], list) else preds[1]
-        batch_size = pred_angle.shape[0]  # batch size
-        pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
-            (self.reg_max * 4, self.nc), 1
+        loss = torch.zeros(4, device=self.device)  # box, cls, dfl, angle
+        pred_distri, pred_scores, pred_angle = (
+            preds["boxes"].permute(0, 2, 1).contiguous(),
+            preds["scores"].permute(0, 2, 1).contiguous(),
+            preds["angle"].permute(0, 2, 1).contiguous(),
         )
-
-        # b, grids, ..
-        pred_scores = pred_scores.permute(0, 2, 1).contiguous()
-        pred_distri = pred_distri.permute(0, 2, 1).contiguous()
-        pred_angle = pred_angle.permute(0, 2, 1).contiguous()
+        anchor_points, stride_tensor = make_anchors(preds["feats"], self.stride, 0.5)
+        batch_size = pred_angle.shape[0]  # batch size, number of masks, mask height, mask width
 
         dtype = pred_scores.dtype
-        imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]  # image size (h,w)
-        anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
+        imgsz = torch.tensor(preds["feats"][0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]
 
         # targets
         try:
@@ -704,14 +1010,14 @@
             targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"].view(-1, 5)), 1)
             rw, rh = targets[:, 4] * float(imgsz[1]), targets[:, 5] * float(imgsz[0])
             targets = targets[(rw >= 2) & (rh >= 2)]  # filter rboxes of tiny size to stabilize training
-            targets = self.preprocess(targets, batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
+            targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
             gt_labels, gt_bboxes = targets.split((1, 5), 2)  # cls, xywhr
             mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
         except RuntimeError as e:
             raise TypeError(
                 "ERROR ❌ OBB dataset incorrectly formatted or not a OBB dataset.\n"
                 "This error can occur when incorrectly training a 'OBB' model on a 'detect' dataset, "
-                "i.e. 'yolo train model=yolo11n-obb.pt data=coco8.yaml'.\nVerify your dataset is a "
+                "i.e. 'yolo train model=yolo26n-obb.pt data=dota8.yaml'.\nVerify your dataset is a "
                 "correctly formatted 'OBB' dataset using 'data=dota8.yaml' "
                 "as an example.\nSee https://docs.ultralytics.com/datasets/obb/ for help."
             ) from e
@@ -741,16 +1047,29 @@
         if fg_mask.sum():
             target_bboxes[..., :4] /= stride_tensor
             loss[0], loss[2] = self.bbox_loss(
-                pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
+                pred_distri,
+                pred_bboxes,
+                anchor_points,
+                target_bboxes,
+                target_scores,
+                target_scores_sum,
+                fg_mask,
+                imgsz,
+                stride_tensor,
             )
+            weight = target_scores.sum(-1)[fg_mask]
+            loss[3] = self.calculate_angle_loss(
+                pred_bboxes, target_bboxes, fg_mask, weight, target_scores_sum
+            )  # angle loss
         else:
             loss[0] += (pred_angle * 0).sum()
 
         loss[0] *= self.hyp.box  # box gain
         loss[1] *= self.hyp.cls  # cls gain
         loss[2] *= self.hyp.dfl  # dfl gain
+        loss[3] *= self.hyp.angle  # angle gain
 
-        return loss * batch_size, loss.detach()  # loss(box, cls, dfl)
+        return loss * batch_size, loss.detach()  # loss(box, cls, dfl, angle)
 
     def bbox_decode(
         self, anchor_points: torch.Tensor, pred_dist: torch.Tensor, pred_angle: torch.Tensor
@@ -770,6 +1089,34 @@
         pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))
         return torch.cat((dist2rbox(pred_dist, pred_angle, anchor_points), pred_angle), dim=-1)
 
+    def calculate_angle_loss(self, pred_bboxes, target_bboxes, fg_mask, weight, target_scores_sum, lambda_val=3):
+        """Calculate oriented angle loss.
+
+        Args:
+            pred_bboxes: [N, 5] (x, y, w, h, theta).
+            target_bboxes: [N, 5] (x, y, w, h, theta).
+            fg_mask: Foreground mask indicating valid predictions.
+            weight: Loss weights for each prediction.
+            target_scores_sum: Sum of target scores for normalization.
+            lambda_val: control the sensitivity to aspect ratio.
+        """
+        w_gt = target_bboxes[..., 2]
+        h_gt = target_bboxes[..., 3]
+        pred_theta = pred_bboxes[..., 4]
+        target_theta = target_bboxes[..., 4]
+
+        log_ar = torch.log(w_gt / h_gt)
+        scale_weight = torch.exp(-(log_ar**2) / (lambda_val**2))
+
+        delta_theta = pred_theta - target_theta
+        delta_theta_wrapped = delta_theta - torch.round(delta_theta / math.pi) * math.pi
+        ang_loss = torch.sin(2 * delta_theta_wrapped[fg_mask]) ** 2
+
+        ang_loss = scale_weight[fg_mask] * ang_loss
+        ang_loss = ang_loss * weight
+
+        return ang_loss.sum() / target_scores_sum
+
 
 class E2EDetectLoss:
     """Criterion class for computing training losses for end-to-end detection."""
@@ -789,61 +1136,108 @@
         return loss_one2many[0] + loss_one2one[0], loss_one2many[1] + loss_one2one[1]
 
 
+class E2ELoss:
+    """Criterion class for computing training losses for end-to-end detection."""
+
+    def __init__(self, model, loss_fn=v8DetectionLoss):
+        """Initialize E2ELoss with one-to-many and one-to-one detection losses using the provided model."""
+        self.one2many = loss_fn(model, tal_topk=10)
+        self.one2one = loss_fn(model, tal_topk=7, tal_topk2=1)
+        self.updates = 0
+        self.total = 1.0
+        # init gain
+        self.o2m = 0.8
+        self.o2o = self.total - self.o2m
+        self.o2m_copy = self.o2m
+        # final gain
+        self.final_o2m = 0.1
+
+    def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+        """Calculate the sum of the loss for box, cls and dfl multiplied by batch size."""
+        preds = self.one2many.parse_output(preds)
+        one2many, one2one = preds["one2many"], preds["one2one"]
+        loss_one2many = self.one2many.loss(one2many, batch)
+        loss_one2one = self.one2one.loss(one2one, batch)
+        return loss_one2many[0] * self.o2m + loss_one2one[0] * self.o2o, loss_one2one[1]
+
+    def update(self) -> None:
+        """Update the weights for one-to-many and one-to-one losses based on the decay schedule."""
+        self.updates += 1
+        self.o2m = self.decay(self.updates)
+        self.o2o = max(self.total - self.o2m, 0)
+
+    def decay(self, x) -> float:
+        """Calculate the decayed weight for one-to-many loss based on the current update step."""
+        return max(1 - x / max(self.one2one.hyp.epochs - 1, 1), 0) * (self.o2m_copy - self.final_o2m) + self.final_o2m
+
+
 class TVPDetectLoss:
     """Criterion class for computing training losses for text-visual prompt detection."""
 
-    def __init__(self, model):
+    def __init__(self, model, tal_topk=10):
         """Initialize TVPDetectLoss with task-prompt and visual-prompt criteria using the provided model."""
-        self.vp_criterion = v8DetectionLoss(model)
+        self.vp_criterion = v8DetectionLoss(model, tal_topk)
         # NOTE: store following info as it's changeable in __call__
+        self.hyp = self.vp_criterion.hyp
         self.ori_nc = self.vp_criterion.nc
         self.ori_no = self.vp_criterion.no
         self.ori_reg_max = self.vp_criterion.reg_max
 
+    def parse_output(self, preds) -> dict[str, torch.Tensor]:
+        """Parse model predictions to extract features."""
+        return self.vp_criterion.parse_output(preds)
+
     def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
         """Calculate the loss for text-visual prompt detection."""
-        feats = preds[1] if isinstance(preds, tuple) else preds
+        return self.loss(self.parse_output(preds), batch)
+
+    def loss(self, preds: dict[str, torch.Tensor], batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+        """Calculate the loss for text-visual prompt detection."""
+        assert self.ori_reg_max == self.vp_criterion.reg_max  # TODO: remove it
 
-        if self.ori_reg_max * 4 + self.ori_nc == feats[0].shape[1]:
+        if self.ori_nc == preds["scores"].shape[1]:
             loss = torch.zeros(3, device=self.vp_criterion.device, requires_grad=True)
             return loss, loss.detach()
 
-        vp_feats = self._get_vp_features(feats)
-        vp_loss = self.vp_criterion(vp_feats, batch)
-        cls_loss = vp_loss[0][1]
-        return cls_loss, vp_loss[1]
+        preds["scores"] = self._get_vp_features(preds)
+        vp_loss = self.vp_criterion(preds, batch)
+        box_loss = vp_loss[0][1]
+        return box_loss, vp_loss[1]
 
-    def _get_vp_features(self, feats: list[torch.Tensor]) -> list[torch.Tensor]:
+    def _get_vp_features(self, preds: dict[str, torch.Tensor]) -> list[torch.Tensor]:
         """Extract visual-prompt features from the model output."""
-        vnc = feats[0].shape[1] - self.ori_reg_max * 4 - self.ori_nc
+        # NOTE: remove empty placeholder
+        scores = preds["scores"][:, self.ori_nc :, :]
+        vnc = scores.shape[1]
 
         self.vp_criterion.nc = vnc
         self.vp_criterion.no = vnc + self.vp_criterion.reg_max * 4
         self.vp_criterion.assigner.num_classes = vnc
-
-        return [
-            torch.cat((box, cls_vp), dim=1)
-            for box, _, cls_vp in [xi.split((self.ori_reg_max * 4, self.ori_nc, vnc), dim=1) for xi in feats]
-        ]
+        return scores
 
 
 class TVPSegmentLoss(TVPDetectLoss):
     """Criterion class for computing training losses for text-visual prompt segmentation."""
 
-    def __init__(self, model):
+    def __init__(self, model, tal_topk=10):
         """Initialize TVPSegmentLoss with task-prompt and visual-prompt criteria using the provided model."""
         super().__init__(model)
-        self.vp_criterion = v8SegmentationLoss(model)
+        self.vp_criterion = v8SegmentationLoss(model, tal_topk)
+        self.hyp = self.vp_criterion.hyp
 
     def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
         """Calculate the loss for text-visual prompt segmentation."""
-        feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
+        return self.loss(self.parse_output(preds), batch)
+
+    def loss(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+        """Calculate the loss for text-visual prompt detection."""
+        assert self.ori_reg_max == self.vp_criterion.reg_max  # TODO: remove it
 
-        if self.ori_reg_max * 4 + self.ori_nc == feats[0].shape[1]:
+        if self.ori_nc == preds["scores"].shape[1]:
             loss = torch.zeros(4, device=self.vp_criterion.device, requires_grad=True)
             return loss, loss.detach()
 
-        vp_feats = self._get_vp_features(feats)
-        vp_loss = self.vp_criterion((vp_feats, pred_masks, proto), batch)
+        preds["scores"] = self._get_vp_features(preds)
+        vp_loss = self.vp_criterion(preds, batch)
         cls_loss = vp_loss[0][2]
         return cls_loss, vp_loss[1]
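One closing note on the `E2ELoss` class added above: it trains a one-to-many branch (tal_topk=10) and a one-to-one branch (tal_topk=7, tal_topk2=1) jointly, and `update()` linearly anneals the one-to-many gain from 0.8 down to 0.1 across training while the one-to-one gain takes up the remainder. Tracing `decay` standalone (same arithmetic; a 100-epoch run is assumed here):

```python
o2m_init, o2m_final, epochs = 0.8, 0.1, 100

def decay(update: int) -> float:
    # Mirrors E2ELoss.decay with self.one2one.hyp.epochs == 100.
    return max(1 - update / max(epochs - 1, 1), 0) * (o2m_init - o2m_final) + o2m_final

for u in (0, 25, 50, 99):
    o2m = decay(u)
    print(u, round(o2m, 3), round(1.0 - o2m, 3))  # o2m: 0.8 -> 0.623 -> 0.446 -> 0.1
```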