dgenerate-ultralytics-headless 8.3.214__py3-none-any.whl → 8.4.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (249)
  1. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/METADATA +64 -74
  2. dgenerate_ultralytics_headless-8.4.7.dist-info/RECORD +311 -0
  3. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/WHEEL +1 -1
  4. tests/__init__.py +7 -9
  5. tests/conftest.py +8 -15
  6. tests/test_cli.py +1 -1
  7. tests/test_cuda.py +13 -10
  8. tests/test_engine.py +9 -9
  9. tests/test_exports.py +65 -13
  10. tests/test_integrations.py +13 -13
  11. tests/test_python.py +125 -69
  12. tests/test_solutions.py +161 -152
  13. ultralytics/__init__.py +1 -1
  14. ultralytics/cfg/__init__.py +86 -92
  15. ultralytics/cfg/datasets/Argoverse.yaml +7 -6
  16. ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
  17. ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
  18. ultralytics/cfg/datasets/ImageNet.yaml +1 -1
  19. ultralytics/cfg/datasets/TT100K.yaml +346 -0
  20. ultralytics/cfg/datasets/VOC.yaml +15 -16
  21. ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
  22. ultralytics/cfg/datasets/coco-pose.yaml +21 -0
  23. ultralytics/cfg/datasets/coco12-formats.yaml +101 -0
  24. ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
  25. ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
  26. ultralytics/cfg/datasets/dog-pose.yaml +28 -0
  27. ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
  28. ultralytics/cfg/datasets/dota8.yaml +2 -2
  29. ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
  30. ultralytics/cfg/datasets/kitti.yaml +27 -0
  31. ultralytics/cfg/datasets/lvis.yaml +5 -5
  32. ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
  33. ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
  34. ultralytics/cfg/datasets/xView.yaml +16 -16
  35. ultralytics/cfg/default.yaml +4 -2
  36. ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
  37. ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
  38. ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
  39. ultralytics/cfg/models/26/yolo26-cls.yaml +33 -0
  40. ultralytics/cfg/models/26/yolo26-obb.yaml +52 -0
  41. ultralytics/cfg/models/26/yolo26-p2.yaml +60 -0
  42. ultralytics/cfg/models/26/yolo26-p6.yaml +62 -0
  43. ultralytics/cfg/models/26/yolo26-pose.yaml +53 -0
  44. ultralytics/cfg/models/26/yolo26-seg.yaml +52 -0
  45. ultralytics/cfg/models/26/yolo26.yaml +52 -0
  46. ultralytics/cfg/models/26/yoloe-26-seg.yaml +53 -0
  47. ultralytics/cfg/models/26/yoloe-26.yaml +53 -0
  48. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
  49. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
  50. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
  51. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
  52. ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
  53. ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
  54. ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
  55. ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
  56. ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
  57. ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
  58. ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
  59. ultralytics/cfg/models/v6/yolov6.yaml +1 -1
  60. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
  61. ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
  62. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
  63. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
  64. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
  65. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
  66. ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
  67. ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
  68. ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
  69. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
  70. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
  71. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
  72. ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
  73. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
  74. ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
  75. ultralytics/data/__init__.py +4 -4
  76. ultralytics/data/annotator.py +5 -6
  77. ultralytics/data/augment.py +300 -475
  78. ultralytics/data/base.py +18 -26
  79. ultralytics/data/build.py +147 -25
  80. ultralytics/data/converter.py +108 -87
  81. ultralytics/data/dataset.py +47 -75
  82. ultralytics/data/loaders.py +42 -49
  83. ultralytics/data/split.py +5 -6
  84. ultralytics/data/split_dota.py +8 -15
  85. ultralytics/data/utils.py +36 -45
  86. ultralytics/engine/exporter.py +351 -263
  87. ultralytics/engine/model.py +186 -225
  88. ultralytics/engine/predictor.py +45 -54
  89. ultralytics/engine/results.py +198 -325
  90. ultralytics/engine/trainer.py +165 -106
  91. ultralytics/engine/tuner.py +41 -43
  92. ultralytics/engine/validator.py +55 -38
  93. ultralytics/hub/__init__.py +16 -19
  94. ultralytics/hub/auth.py +6 -12
  95. ultralytics/hub/google/__init__.py +7 -10
  96. ultralytics/hub/session.py +15 -25
  97. ultralytics/hub/utils.py +5 -8
  98. ultralytics/models/__init__.py +1 -1
  99. ultralytics/models/fastsam/__init__.py +1 -1
  100. ultralytics/models/fastsam/model.py +8 -10
  101. ultralytics/models/fastsam/predict.py +18 -30
  102. ultralytics/models/fastsam/utils.py +1 -2
  103. ultralytics/models/fastsam/val.py +5 -7
  104. ultralytics/models/nas/__init__.py +1 -1
  105. ultralytics/models/nas/model.py +5 -8
  106. ultralytics/models/nas/predict.py +7 -9
  107. ultralytics/models/nas/val.py +1 -2
  108. ultralytics/models/rtdetr/__init__.py +1 -1
  109. ultralytics/models/rtdetr/model.py +5 -8
  110. ultralytics/models/rtdetr/predict.py +15 -19
  111. ultralytics/models/rtdetr/train.py +10 -13
  112. ultralytics/models/rtdetr/val.py +21 -23
  113. ultralytics/models/sam/__init__.py +15 -2
  114. ultralytics/models/sam/amg.py +14 -20
  115. ultralytics/models/sam/build.py +26 -19
  116. ultralytics/models/sam/build_sam3.py +377 -0
  117. ultralytics/models/sam/model.py +29 -32
  118. ultralytics/models/sam/modules/blocks.py +83 -144
  119. ultralytics/models/sam/modules/decoders.py +19 -37
  120. ultralytics/models/sam/modules/encoders.py +44 -101
  121. ultralytics/models/sam/modules/memory_attention.py +16 -30
  122. ultralytics/models/sam/modules/sam.py +200 -73
  123. ultralytics/models/sam/modules/tiny_encoder.py +64 -83
  124. ultralytics/models/sam/modules/transformer.py +18 -28
  125. ultralytics/models/sam/modules/utils.py +174 -50
  126. ultralytics/models/sam/predict.py +2248 -350
  127. ultralytics/models/sam/sam3/__init__.py +3 -0
  128. ultralytics/models/sam/sam3/decoder.py +546 -0
  129. ultralytics/models/sam/sam3/encoder.py +529 -0
  130. ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
  131. ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
  132. ultralytics/models/sam/sam3/model_misc.py +199 -0
  133. ultralytics/models/sam/sam3/necks.py +129 -0
  134. ultralytics/models/sam/sam3/sam3_image.py +339 -0
  135. ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
  136. ultralytics/models/sam/sam3/vitdet.py +547 -0
  137. ultralytics/models/sam/sam3/vl_combiner.py +160 -0
  138. ultralytics/models/utils/loss.py +14 -26
  139. ultralytics/models/utils/ops.py +13 -17
  140. ultralytics/models/yolo/__init__.py +1 -1
  141. ultralytics/models/yolo/classify/predict.py +10 -13
  142. ultralytics/models/yolo/classify/train.py +12 -33
  143. ultralytics/models/yolo/classify/val.py +30 -29
  144. ultralytics/models/yolo/detect/predict.py +9 -12
  145. ultralytics/models/yolo/detect/train.py +17 -23
  146. ultralytics/models/yolo/detect/val.py +77 -59
  147. ultralytics/models/yolo/model.py +43 -60
  148. ultralytics/models/yolo/obb/predict.py +7 -16
  149. ultralytics/models/yolo/obb/train.py +14 -17
  150. ultralytics/models/yolo/obb/val.py +40 -37
  151. ultralytics/models/yolo/pose/__init__.py +1 -1
  152. ultralytics/models/yolo/pose/predict.py +7 -22
  153. ultralytics/models/yolo/pose/train.py +13 -16
  154. ultralytics/models/yolo/pose/val.py +39 -58
  155. ultralytics/models/yolo/segment/predict.py +17 -21
  156. ultralytics/models/yolo/segment/train.py +7 -10
  157. ultralytics/models/yolo/segment/val.py +95 -47
  158. ultralytics/models/yolo/world/train.py +8 -14
  159. ultralytics/models/yolo/world/train_world.py +11 -34
  160. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  161. ultralytics/models/yolo/yoloe/predict.py +16 -23
  162. ultralytics/models/yolo/yoloe/train.py +36 -44
  163. ultralytics/models/yolo/yoloe/train_seg.py +11 -11
  164. ultralytics/models/yolo/yoloe/val.py +15 -20
  165. ultralytics/nn/__init__.py +7 -7
  166. ultralytics/nn/autobackend.py +159 -85
  167. ultralytics/nn/modules/__init__.py +68 -60
  168. ultralytics/nn/modules/activation.py +4 -6
  169. ultralytics/nn/modules/block.py +260 -224
  170. ultralytics/nn/modules/conv.py +52 -97
  171. ultralytics/nn/modules/head.py +831 -299
  172. ultralytics/nn/modules/transformer.py +76 -88
  173. ultralytics/nn/modules/utils.py +16 -21
  174. ultralytics/nn/tasks.py +180 -195
  175. ultralytics/nn/text_model.py +45 -69
  176. ultralytics/optim/__init__.py +5 -0
  177. ultralytics/optim/muon.py +338 -0
  178. ultralytics/solutions/__init__.py +12 -12
  179. ultralytics/solutions/ai_gym.py +13 -19
  180. ultralytics/solutions/analytics.py +15 -16
  181. ultralytics/solutions/config.py +6 -7
  182. ultralytics/solutions/distance_calculation.py +10 -13
  183. ultralytics/solutions/heatmap.py +8 -14
  184. ultralytics/solutions/instance_segmentation.py +6 -9
  185. ultralytics/solutions/object_blurrer.py +7 -10
  186. ultralytics/solutions/object_counter.py +12 -19
  187. ultralytics/solutions/object_cropper.py +8 -14
  188. ultralytics/solutions/parking_management.py +34 -32
  189. ultralytics/solutions/queue_management.py +10 -12
  190. ultralytics/solutions/region_counter.py +9 -12
  191. ultralytics/solutions/security_alarm.py +15 -20
  192. ultralytics/solutions/similarity_search.py +10 -15
  193. ultralytics/solutions/solutions.py +77 -76
  194. ultralytics/solutions/speed_estimation.py +7 -10
  195. ultralytics/solutions/streamlit_inference.py +2 -4
  196. ultralytics/solutions/templates/similarity-search.html +7 -18
  197. ultralytics/solutions/trackzone.py +7 -10
  198. ultralytics/solutions/vision_eye.py +5 -8
  199. ultralytics/trackers/__init__.py +1 -1
  200. ultralytics/trackers/basetrack.py +3 -5
  201. ultralytics/trackers/bot_sort.py +10 -27
  202. ultralytics/trackers/byte_tracker.py +21 -37
  203. ultralytics/trackers/track.py +4 -7
  204. ultralytics/trackers/utils/gmc.py +11 -22
  205. ultralytics/trackers/utils/kalman_filter.py +37 -48
  206. ultralytics/trackers/utils/matching.py +12 -15
  207. ultralytics/utils/__init__.py +124 -124
  208. ultralytics/utils/autobatch.py +2 -4
  209. ultralytics/utils/autodevice.py +17 -18
  210. ultralytics/utils/benchmarks.py +57 -71
  211. ultralytics/utils/callbacks/base.py +8 -10
  212. ultralytics/utils/callbacks/clearml.py +5 -13
  213. ultralytics/utils/callbacks/comet.py +32 -46
  214. ultralytics/utils/callbacks/dvc.py +13 -18
  215. ultralytics/utils/callbacks/mlflow.py +4 -5
  216. ultralytics/utils/callbacks/neptune.py +7 -15
  217. ultralytics/utils/callbacks/platform.py +423 -38
  218. ultralytics/utils/callbacks/raytune.py +3 -4
  219. ultralytics/utils/callbacks/tensorboard.py +25 -31
  220. ultralytics/utils/callbacks/wb.py +16 -14
  221. ultralytics/utils/checks.py +127 -85
  222. ultralytics/utils/cpu.py +3 -8
  223. ultralytics/utils/dist.py +9 -12
  224. ultralytics/utils/downloads.py +25 -33
  225. ultralytics/utils/errors.py +6 -14
  226. ultralytics/utils/events.py +2 -4
  227. ultralytics/utils/export/__init__.py +4 -236
  228. ultralytics/utils/export/engine.py +246 -0
  229. ultralytics/utils/export/imx.py +117 -63
  230. ultralytics/utils/export/tensorflow.py +231 -0
  231. ultralytics/utils/files.py +26 -30
  232. ultralytics/utils/git.py +9 -11
  233. ultralytics/utils/instance.py +30 -51
  234. ultralytics/utils/logger.py +212 -114
  235. ultralytics/utils/loss.py +601 -215
  236. ultralytics/utils/metrics.py +128 -156
  237. ultralytics/utils/nms.py +13 -16
  238. ultralytics/utils/ops.py +117 -166
  239. ultralytics/utils/patches.py +75 -21
  240. ultralytics/utils/plotting.py +75 -80
  241. ultralytics/utils/tal.py +125 -59
  242. ultralytics/utils/torch_utils.py +53 -79
  243. ultralytics/utils/tqdm.py +24 -21
  244. ultralytics/utils/triton.py +13 -19
  245. ultralytics/utils/tuner.py +19 -10
  246. dgenerate_ultralytics_headless-8.3.214.dist-info/RECORD +0 -283
  247. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/entry_points.txt +0 -0
  248. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/licenses/LICENSE +0 -0
  249. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/top_level.txt +0 -0
ultralytics/utils/loss.py CHANGED
@@ -2,24 +2,24 @@

  from __future__ import annotations

+ import math
  from typing import Any

  import torch
  import torch.nn as nn
  import torch.nn.functional as F

- from ultralytics.utils.metrics import OKS_SIGMA
+ from ultralytics.utils.metrics import OKS_SIGMA, RLE_WEIGHT
  from ultralytics.utils.ops import crop_mask, xywh2xyxy, xyxy2xywh
  from ultralytics.utils.tal import RotatedTaskAlignedAssigner, TaskAlignedAssigner, dist2bbox, dist2rbox, make_anchors
  from ultralytics.utils.torch_utils import autocast

  from .metrics import bbox_iou, probiou
- from .tal import bbox2dist
+ from .tal import bbox2dist, rbox2dist


  class VarifocalLoss(nn.Module):
- """
- Varifocal loss by Zhang et al.
+ """Varifocal loss by Zhang et al.

  Implements the Varifocal Loss function for addressing class imbalance in object detection by focusing on
  hard-to-classify examples and balancing positive/negative samples.
@@ -51,11 +51,10 @@ class VarifocalLoss(nn.Module):


  class FocalLoss(nn.Module):
- """
- Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5).
+ """Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5).

- Implements the Focal Loss function for addressing class imbalance by down-weighting easy examples and focusing
- on hard negatives during training.
+ Implements the Focal Loss function for addressing class imbalance by down-weighting easy examples and focusing on
+ hard negatives during training.

  Attributes:
  gamma (float): The focusing parameter that controls how much the loss focuses on hard-to-classify examples.
@@ -124,6 +123,8 @@ class BboxLoss(nn.Module):
  target_scores: torch.Tensor,
  target_scores_sum: torch.Tensor,
  fg_mask: torch.Tensor,
+ imgsz: torch.Tensor,
+ stride: torch.Tensor,
  ) -> tuple[torch.Tensor, torch.Tensor]:
  """Compute IoU and DFL losses for bounding boxes."""
  weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
@@ -136,11 +137,76 @@ class BboxLoss(nn.Module):
  loss_dfl = self.dfl_loss(pred_dist[fg_mask].view(-1, self.dfl_loss.reg_max), target_ltrb[fg_mask]) * weight
  loss_dfl = loss_dfl.sum() / target_scores_sum
  else:
- loss_dfl = torch.tensor(0.0).to(pred_dist.device)
+ target_ltrb = bbox2dist(anchor_points, target_bboxes)
+ # normalize ltrb by image size
+ target_ltrb = target_ltrb * stride
+ target_ltrb[..., 0::2] /= imgsz[1]
+ target_ltrb[..., 1::2] /= imgsz[0]
+ pred_dist = pred_dist * stride
+ pred_dist[..., 0::2] /= imgsz[1]
+ pred_dist[..., 1::2] /= imgsz[0]
+ loss_dfl = (
+ F.l1_loss(pred_dist[fg_mask], target_ltrb[fg_mask], reduction="none").mean(-1, keepdim=True) * weight
+ )
+ loss_dfl = loss_dfl.sum() / target_scores_sum

  return loss_iou, loss_dfl


+ class RLELoss(nn.Module):
+ """Residual Log-Likelihood Estimation Loss.
+
+ Args:
+ use_target_weight (bool): Option to use weighted loss.
+ size_average (bool): Option to average the loss by the batch_size.
+ residual (bool): Option to add L1 loss and let the flow learn the residual error distribution.
+
+ References:
+ https://arxiv.org/abs/2107.11291
+ https://github.com/open-mmlab/mmpose/blob/main/mmpose/models/losses/regression_loss.py
+ """
+
+ def __init__(self, use_target_weight: bool = True, size_average: bool = True, residual: bool = True):
+ """Initialize RLELoss with target weight and residual options.
+
+ Args:
+ use_target_weight (bool): Whether to use target weights for loss calculation.
+ size_average (bool): Whether to average the loss over elements.
+ residual (bool): Whether to include residual log-likelihood term.
+ """
+ super().__init__()
+ self.size_average = size_average
+ self.use_target_weight = use_target_weight
+ self.residual = residual
+
+ def forward(
+ self, sigma: torch.Tensor, log_phi: torch.Tensor, error: torch.Tensor, target_weight: torch.Tensor = None
+ ) -> torch.Tensor:
+ """
+ Args:
+ sigma (torch.Tensor): Output sigma, shape (N, D).
+ log_phi (torch.Tensor): Output log_phi, shape (N).
+ error (torch.Tensor): Error, shape (N, D).
+ target_weight (torch.Tensor): Weights across different joint types, shape (N).
+ """
+ log_sigma = torch.log(sigma)
+ loss = log_sigma - log_phi.unsqueeze(1)
+
+ if self.residual:
+ loss += torch.log(sigma * 2) + torch.abs(error)
+
+ if self.use_target_weight:
+ assert target_weight is not None, "'target_weight' should not be None when 'use_target_weight' is True."
+ if target_weight.dim() == 1:
+ target_weight = target_weight.unsqueeze(1)
+ loss *= target_weight
+
+ if self.size_average:
+ loss /= len(loss)
+
+ return loss.sum()
+
+
  class RotatedBboxLoss(BboxLoss):
  """Criterion class for computing training losses for rotated bounding boxes."""

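Note: the RLELoss added in the hunk above follows the residual log-likelihood estimation formulation: per coordinate it accumulates log(sigma) - log_phi and, when residual=True, adds log(2*sigma) + |error|. The sketch below is illustrative only and exercises the class in isolation; the unit-Gaussian log-density used for log_phi is an assumption standing in for the flow model that normally supplies it, and the shapes follow the docstring (sigma and error are (N, D), log_phi is (N,)).

    import math
    import torch
    from ultralytics.utils.loss import RLELoss  # class added in the 8.4.x wheel per this diff

    sigma = torch.rand(8, 2) * 0.5 + 0.1   # predicted per-coordinate scales, shape (N, D)
    error = torch.randn(8, 2)              # normalized residuals (mu_hat - mu) / sigma, shape (N, D)
    # Stand-in for flow_model.log_prob(error): unit-Gaussian log-density (illustrative assumption only)
    log_phi = -0.5 * (error**2).sum(-1) - 0.5 * error.shape[-1] * math.log(2 * math.pi)
    target_weight = torch.ones(8)          # uniform joint weights, shape (N,)
    loss = RLELoss(use_target_weight=True, size_average=True, residual=True)(sigma, log_phi, error, target_weight)
    print(float(loss))  # scalar: the N x D per-term losses are divided by N (size_average) and summed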
@@ -157,6 +223,8 @@ class RotatedBboxLoss(BboxLoss):
  target_scores: torch.Tensor,
  target_scores_sum: torch.Tensor,
  fg_mask: torch.Tensor,
+ imgsz: torch.Tensor,
+ stride: torch.Tensor,
  ) -> tuple[torch.Tensor, torch.Tensor]:
  """Compute IoU and DFL losses for rotated bounding boxes."""
  weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
@@ -165,15 +233,84 @@ class RotatedBboxLoss(BboxLoss):

  # DFL loss
  if self.dfl_loss:
- target_ltrb = bbox2dist(anchor_points, xywh2xyxy(target_bboxes[..., :4]), self.dfl_loss.reg_max - 1)
+ target_ltrb = rbox2dist(
+ target_bboxes[..., :4], anchor_points, target_bboxes[..., 4:5], reg_max=self.dfl_loss.reg_max - 1
+ )
  loss_dfl = self.dfl_loss(pred_dist[fg_mask].view(-1, self.dfl_loss.reg_max), target_ltrb[fg_mask]) * weight
  loss_dfl = loss_dfl.sum() / target_scores_sum
  else:
- loss_dfl = torch.tensor(0.0).to(pred_dist.device)
+ target_ltrb = rbox2dist(target_bboxes[..., :4], anchor_points, target_bboxes[..., 4:5])
+ target_ltrb = target_ltrb * stride
+ target_ltrb[..., 0::2] /= imgsz[1]
+ target_ltrb[..., 1::2] /= imgsz[0]
+ pred_dist = pred_dist * stride
+ pred_dist[..., 0::2] /= imgsz[1]
+ pred_dist[..., 1::2] /= imgsz[0]
+ loss_dfl = (
+ F.l1_loss(pred_dist[fg_mask], target_ltrb[fg_mask], reduction="none").mean(-1, keepdim=True) * weight
+ )
+ loss_dfl = loss_dfl.sum() / target_scores_sum

  return loss_iou, loss_dfl


+ class MultiChannelDiceLoss(nn.Module):
+ """Criterion class for computing multi-channel Dice losses."""
+
+ def __init__(self, smooth: float = 1e-6, reduction: str = "mean"):
+ """Initialize MultiChannelDiceLoss with smoothing and reduction options.
+
+ Args:
+ smooth (float): Smoothing factor to avoid division by zero.
+ reduction (str): Reduction method ('mean', 'sum', or 'none').
+ """
+ super().__init__()
+ self.smooth = smooth
+ self.reduction = reduction
+
+ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+ """Calculate multi-channel Dice loss between predictions and targets."""
+ assert pred.size() == target.size(), "the size of predict and target must be equal."
+
+ pred = pred.sigmoid()
+ intersection = (pred * target).sum(dim=(2, 3))
+ union = pred.sum(dim=(2, 3)) + target.sum(dim=(2, 3))
+ dice = (2.0 * intersection + self.smooth) / (union + self.smooth)
+ dice_loss = 1.0 - dice
+ dice_loss = dice_loss.mean(dim=1)
+
+ if self.reduction == "mean":
+ return dice_loss.mean()
+ elif self.reduction == "sum":
+ return dice_loss.sum()
+ else:
+ return dice_loss
+
+
+ class BCEDiceLoss(nn.Module):
+ """Criterion class for computing combined BCE and Dice losses."""
+
+ def __init__(self, weight_bce: float = 0.5, weight_dice: float = 0.5):
+ """Initialize BCEDiceLoss with BCE and Dice weight factors.
+
+ Args:
+ weight_bce (float): Weight factor for BCE loss component.
+ weight_dice (float): Weight factor for Dice loss component.
+ """
+ super().__init__()
+ self.weight_bce = weight_bce
+ self.weight_dice = weight_dice
+ self.bce = nn.BCEWithLogitsLoss()
+ self.dice = MultiChannelDiceLoss(smooth=1)
+
+ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+ """Calculate combined BCE and Dice loss between predictions and targets."""
+ _, _, mask_h, mask_w = pred.shape
+ if tuple(target.shape[-2:]) != (mask_h, mask_w): # downsample to the same size as pred
+ target = F.interpolate(target, (mask_h, mask_w), mode="nearest")
+ return self.weight_bce * self.bce(pred, target) + self.weight_dice * self.dice(pred, target)
+
+
  class KeypointLoss(nn.Module):
  """Criterion class for computing keypoint losses."""

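Note: the MultiChannelDiceLoss and BCEDiceLoss added above back the optional semantic-segmentation branch that v8SegmentationLoss uses further down in this diff (loss[4] = self.bcedice_loss(pred_semseg, sem_masks)). A minimal illustrative call with toy shapes, not real model outputs:

    import torch
    from ultralytics.utils.loss import BCEDiceLoss  # class added in the 8.4.x wheel per this diff

    pred = torch.randn(1, 2, 8, 8)                      # per-class mask logits (B, C, H, W)
    target = (torch.rand(1, 2, 16, 16) > 0.5).float()   # binary ground-truth masks at a larger resolution
    criterion = BCEDiceLoss(weight_bce=0.5, weight_dice=0.5)
    loss = criterion(pred, target)  # target is nearest-downsampled to 8x8 internally, then BCE and Dice are combined 0.5/0.5
    print(float(loss))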
@@ -196,7 +333,7 @@ class KeypointLoss(nn.Module):
  class v8DetectionLoss:
  """Criterion class for computing training losses for YOLOv8 object detection."""

- def __init__(self, model, tal_topk: int = 10): # model must be de-paralleled
+ def __init__(self, model, tal_topk: int = 10, tal_topk2: int | None = None): # model must be de-paralleled
  """Initialize v8DetectionLoss with model parameters and task-aligned assignment settings."""
  device = next(model.parameters()).device # get model device
  h = model.args # hyperparameters
@@ -212,7 +349,14 @@ class v8DetectionLoss:

  self.use_dfl = m.reg_max > 1

- self.assigner = TaskAlignedAssigner(topk=tal_topk, num_classes=self.nc, alpha=0.5, beta=6.0)
+ self.assigner = TaskAlignedAssigner(
+ topk=tal_topk,
+ num_classes=self.nc,
+ alpha=0.5,
+ beta=6.0,
+ stride=self.stride.tolist(),
+ topk2=tal_topk2,
+ )
  self.bbox_loss = BboxLoss(m.reg_max).to(device)
  self.proj = torch.arange(m.reg_max, dtype=torch.float, device=device)

@@ -242,35 +386,31 @@ class v8DetectionLoss:
  # pred_dist = (pred_dist.view(b, a, c // 4, 4).softmax(2) * self.proj.type(pred_dist.dtype).view(1, 1, -1, 1)).sum(2)
  return dist2bbox(pred_dist, anchor_points, xywh=False)

- def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
- """Calculate the sum of the loss for box, cls and dfl multiplied by batch size."""
+ def get_assigned_targets_and_loss(self, preds: dict[str, torch.Tensor], batch: dict[str, Any]) -> tuple:
+ """Calculate the sum of the loss for box, cls and dfl multiplied by batch size and return foreground mask and
+ target indices.
+ """
  loss = torch.zeros(3, device=self.device) # box, cls, dfl
- feats = preds[1] if isinstance(preds, tuple) else preds
- pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
- (self.reg_max * 4, self.nc), 1
+ pred_distri, pred_scores = (
+ preds["boxes"].permute(0, 2, 1).contiguous(),
+ preds["scores"].permute(0, 2, 1).contiguous(),
  )
-
- pred_scores = pred_scores.permute(0, 2, 1).contiguous()
- pred_distri = pred_distri.permute(0, 2, 1).contiguous()
+ anchor_points, stride_tensor = make_anchors(preds["feats"], self.stride, 0.5)

  dtype = pred_scores.dtype
  batch_size = pred_scores.shape[0]
- imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w)
- anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
+ imgsz = torch.tensor(preds["feats"][0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]

  # Targets
  targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
- targets = self.preprocess(targets, batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
+ targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
  gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
  mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)

  # Pboxes
  pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4)
- # dfl_conf = pred_distri.view(batch_size, -1, 4, self.reg_max).detach().softmax(-1)
- # dfl_conf = (dfl_conf.amax(-1).mean(-1) + dfl_conf.amax(-1).amin(-1)) / 2

- _, target_bboxes, target_scores, fg_mask, _ = self.assigner(
- # pred_scores.detach().sigmoid() * 0.8 + dfl_conf.unsqueeze(-1) * 0.2,
+ _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
  pred_scores.detach().sigmoid(),
  (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
  anchor_points * stride_tensor,
@@ -282,7 +422,6 @@ class v8DetectionLoss:
  target_scores_sum = max(target_scores.sum(), 1)

  # Cls loss
- # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum # VFL way
  loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE

  # Bbox loss
@@ -295,112 +434,114 @@
  target_scores,
  target_scores_sum,
  fg_mask,
+ imgsz,
+ stride_tensor,
  )

  loss[0] *= self.hyp.box # box gain
  loss[1] *= self.hyp.cls # cls gain
  loss[2] *= self.hyp.dfl # dfl gain
+ return (
+ (fg_mask, target_gt_idx, target_bboxes, anchor_points, stride_tensor),
+ loss,
+ loss.detach(),
+ ) # loss(box, cls, dfl)

- return loss * batch_size, loss.detach() # loss(box, cls, dfl)
+ def parse_output(
+ self, preds: dict[str, torch.Tensor] | tuple[torch.Tensor, dict[str, torch.Tensor]]
+ ) -> torch.Tensor:
+ """Parse model predictions to extract features."""
+ return preds[1] if isinstance(preds, tuple) else preds
+
+ def __call__(
+ self,
+ preds: dict[str, torch.Tensor] | tuple[torch.Tensor, dict[str, torch.Tensor]],
+ batch: dict[str, torch.Tensor],
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """Calculate the sum of the loss for box, cls and dfl multiplied by batch size."""
+ return self.loss(self.parse_output(preds), batch)
+
+ def loss(self, preds: dict[str, torch.Tensor], batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+ """A wrapper for get_assigned_targets_and_loss and parse_output."""
+ batch_size = preds["boxes"].shape[0]
+ loss, loss_detach = self.get_assigned_targets_and_loss(preds, batch)[1:]
+ return loss * batch_size, loss_detach


  class v8SegmentationLoss(v8DetectionLoss):
  """Criterion class for computing training losses for YOLOv8 segmentation."""

- def __init__(self, model): # model must be de-paralleled
+ def __init__(self, model, tal_topk: int = 10, tal_topk2: int | None = None): # model must be de-paralleled
  """Initialize the v8SegmentationLoss class with model parameters and mask overlap setting."""
- super().__init__(model)
+ super().__init__(model, tal_topk, tal_topk2)
  self.overlap = model.args.overlap_mask
+ self.bcedice_loss = BCEDiceLoss(weight_bce=0.5, weight_dice=0.5)

- def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+ def loss(self, preds: dict[str, torch.Tensor], batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
  """Calculate and return the combined loss for detection and segmentation."""
- loss = torch.zeros(4, device=self.device) # box, seg, cls, dfl
- feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
- batch_size, _, mask_h, mask_w = proto.shape # batch size, number of masks, mask height, mask width
- pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
- (self.reg_max * 4, self.nc), 1
- )
-
- # B, grids, ..
- pred_scores = pred_scores.permute(0, 2, 1).contiguous()
- pred_distri = pred_distri.permute(0, 2, 1).contiguous()
- pred_masks = pred_masks.permute(0, 2, 1).contiguous()
-
- dtype = pred_scores.dtype
- imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w)
- anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
-
- # Targets
- try:
- batch_idx = batch["batch_idx"].view(-1, 1)
- targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)
- targets = self.preprocess(targets, batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
- gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
- mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
- except RuntimeError as e:
- raise TypeError(
- "ERROR ❌ segment dataset incorrectly formatted or not a segment dataset.\n"
- "This error can occur when incorrectly training a 'segment' model on a 'detect' dataset, "
- "i.e. 'yolo train model=yolo11n-seg.pt data=coco8.yaml'.\nVerify your dataset is a "
- "correctly formatted 'segment' dataset using 'data=coco8-seg.yaml' "
- "as an example.\nSee https://docs.ultralytics.com/datasets/segment/ for help."
- ) from e
-
- # Pboxes
- pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4)
-
- _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
- pred_scores.detach().sigmoid(),
- (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
- anchor_points * stride_tensor,
- gt_labels,
- gt_bboxes,
- mask_gt,
- )
-
- target_scores_sum = max(target_scores.sum(), 1)
-
- # Cls loss
- # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum # VFL way
- loss[2] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE
+ pred_masks, proto = preds["mask_coefficient"].permute(0, 2, 1).contiguous(), preds["proto"]
+ loss = torch.zeros(5, device=self.device) # box, seg, cls, dfl
+ if isinstance(proto, tuple) and len(proto) == 2:
+ proto, pred_semseg = proto
+ else:
+ pred_semseg = None
+ (fg_mask, target_gt_idx, target_bboxes, _, _), det_loss, _ = self.get_assigned_targets_and_loss(preds, batch)
+ # NOTE: re-assign index for consistency for now. Need to be removed in the future.
+ loss[0], loss[2], loss[3] = det_loss[0], det_loss[1], det_loss[2]

+ batch_size, _, mask_h, mask_w = proto.shape # batch size, number of masks, mask height, mask width
  if fg_mask.sum():
- # Bbox loss
- loss[0], loss[3] = self.bbox_loss(
- pred_distri,
- pred_bboxes,
- anchor_points,
- target_bboxes / stride_tensor,
- target_scores,
- target_scores_sum,
- fg_mask,
- )
  # Masks loss
  masks = batch["masks"].to(self.device).float()
  if tuple(masks.shape[-2:]) != (mask_h, mask_w): # downsample
- masks = F.interpolate(masks[None], (mask_h, mask_w), mode="nearest")[0]
+ # masks = F.interpolate(masks[None], (mask_h, mask_w), mode="nearest")[0]
+ proto = F.interpolate(proto, masks.shape[-2:], mode="bilinear", align_corners=False)

+ imgsz = (
+ torch.tensor(preds["feats"][0].shape[2:], device=self.device, dtype=pred_masks.dtype) * self.stride[0]
+ )
  loss[1] = self.calculate_segmentation_loss(
- fg_mask, masks, target_gt_idx, target_bboxes, batch_idx, proto, pred_masks, imgsz, self.overlap
+ fg_mask,
+ masks,
+ target_gt_idx,
+ target_bboxes,
+ batch["batch_idx"].view(-1, 1),
+ proto,
+ pred_masks,
+ imgsz,
  )
+ if pred_semseg is not None:
+ sem_masks = batch["sem_masks"].to(self.device) # NxHxW
+ sem_masks = F.one_hot(sem_masks.long(), num_classes=self.nc).permute(0, 3, 1, 2).float() # NxCxHxW
+
+ if self.overlap:
+ mask_zero = masks == 0 # NxHxW
+ sem_masks[mask_zero.unsqueeze(1).expand_as(sem_masks)] = 0
+ else:
+ batch_idx = batch["batch_idx"].view(-1) # [total_instances]
+ for i in range(batch_size):
+ instance_mask_i = masks[batch_idx == i] # [num_instances_i, H, W]
+ if len(instance_mask_i) == 0:
+ continue
+ sem_masks[i, :, instance_mask_i.sum(dim=0) == 0] = 0
+
+ loss[4] = self.bcedice_loss(pred_semseg, sem_masks)
+ loss[4] *= self.hyp.box # seg gain

  # WARNING: lines below prevent Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
  else:
  loss[1] += (proto * 0).sum() + (pred_masks * 0).sum() # inf sums may lead to nan loss
+ if pred_semseg is not None:
+ loss[4] += (pred_semseg * 0).sum()

- loss[0] *= self.hyp.box # box gain
  loss[1] *= self.hyp.box # seg gain
- loss[2] *= self.hyp.cls # cls gain
- loss[3] *= self.hyp.dfl # dfl gain
-
- return loss * batch_size, loss.detach() # loss(box, seg, cls, dfl)
+ return loss * batch_size, loss.detach() # loss(box, cls, dfl)

  @staticmethod
  def single_mask_loss(
  gt_mask: torch.Tensor, pred: torch.Tensor, proto: torch.Tensor, xyxy: torch.Tensor, area: torch.Tensor
  ) -> torch.Tensor:
- """
- Compute the instance segmentation loss for a single image.
+ """Compute the instance segmentation loss for a single image.

  Args:
  gt_mask (torch.Tensor): Ground truth mask of shape (N, H, W), where N is the number of objects.
@@ -430,10 +571,8 @@ class v8SegmentationLoss(v8DetectionLoss):
  proto: torch.Tensor,
  pred_masks: torch.Tensor,
  imgsz: torch.Tensor,
- overlap: bool,
  ) -> torch.Tensor:
- """
- Calculate the loss for instance segmentation.
+ """Calculate the loss for instance segmentation.

  Args:
  fg_mask (torch.Tensor): A binary tensor of shape (BS, N_anchors) indicating which anchors are positive.
@@ -444,7 +583,6 @@ class v8SegmentationLoss(v8DetectionLoss):
  proto (torch.Tensor): Prototype masks of shape (BS, 32, H, W).
  pred_masks (torch.Tensor): Predicted masks for each anchor of shape (BS, N_anchors, 32).
  imgsz (torch.Tensor): Size of the input image as a tensor of shape (2), i.e., (H, W).
- overlap (bool): Whether the masks in `masks` tensor overlap.

  Returns:
  (torch.Tensor): The calculated loss for instance segmentation.
@@ -470,7 +608,7 @@ class v8SegmentationLoss(v8DetectionLoss):
  fg_mask_i, target_gt_idx_i, pred_masks_i, proto_i, mxyxy_i, marea_i, masks_i = single_i
  if fg_mask_i.any():
  mask_idx = target_gt_idx_i[fg_mask_i]
- if overlap:
+ if self.overlap:
  gt_mask = masks_i == (mask_idx + 1).view(-1, 1, 1)
  gt_mask = gt_mask.float()
  else:
@@ -490,9 +628,9 @@ class v8SegmentationLoss(v8DetectionLoss):
  class v8PoseLoss(v8DetectionLoss):
  """Criterion class for computing training losses for YOLOv8 pose estimation."""

- def __init__(self, model): # model must be de-paralleled
+ def __init__(self, model, tal_topk: int = 10, tal_topk2: int = 10): # model must be de-paralleled
  """Initialize v8PoseLoss with model parameters and keypoint-specific loss functions."""
- super().__init__(model)
+ super().__init__(model, tal_topk, tal_topk2)
  self.kpt_shape = model.model[-1].kpt_shape
  self.bce_pose = nn.BCEWithLogitsLoss()
  is_pose = self.kpt_shape == [17, 3]
@@ -500,71 +638,42 @@ class v8PoseLoss(v8DetectionLoss):
  sigmas = torch.from_numpy(OKS_SIGMA).to(self.device) if is_pose else torch.ones(nkpt, device=self.device) / nkpt
  self.keypoint_loss = KeypointLoss(sigmas=sigmas)

- def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+ def loss(self, preds: dict[str, torch.Tensor], batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
  """Calculate the total loss and detach it for pose estimation."""
+ pred_kpts = preds["kpts"].permute(0, 2, 1).contiguous()
  loss = torch.zeros(5, device=self.device) # box, cls, dfl, kpt_location, kpt_visibility
- feats, pred_kpts = preds if isinstance(preds[0], list) else preds[1]
- pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
- (self.reg_max * 4, self.nc), 1
+ (fg_mask, target_gt_idx, target_bboxes, anchor_points, stride_tensor), det_loss, _ = (
+ self.get_assigned_targets_and_loss(preds, batch)
  )
+ # NOTE: re-assign index for consistency for now. Need to be removed in the future.
+ loss[0], loss[3], loss[4] = det_loss[0], det_loss[1], det_loss[2]

- # B, grids, ..
- pred_scores = pred_scores.permute(0, 2, 1).contiguous()
- pred_distri = pred_distri.permute(0, 2, 1).contiguous()
- pred_kpts = pred_kpts.permute(0, 2, 1).contiguous()
-
- dtype = pred_scores.dtype
- imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w)
- anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
-
- # Targets
- batch_size = pred_scores.shape[0]
- batch_idx = batch["batch_idx"].view(-1, 1)
- targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)
- targets = self.preprocess(targets, batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
- gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
- mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
+ batch_size = pred_kpts.shape[0]
+ imgsz = torch.tensor(preds["feats"][0].shape[2:], device=self.device, dtype=pred_kpts.dtype) * self.stride[0]

  # Pboxes
- pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4)
  pred_kpts = self.kpts_decode(anchor_points, pred_kpts.view(batch_size, -1, *self.kpt_shape)) # (b, h*w, 17, 3)

- _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
- pred_scores.detach().sigmoid(),
- (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
- anchor_points * stride_tensor,
- gt_labels,
- gt_bboxes,
- mask_gt,
- )
-
- target_scores_sum = max(target_scores.sum(), 1)
-
- # Cls loss
- # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum # VFL way
- loss[3] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE
-
  # Bbox loss
  if fg_mask.sum():
- target_bboxes /= stride_tensor
- loss[0], loss[4] = self.bbox_loss(
- pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
- )
  keypoints = batch["keypoints"].to(self.device).float().clone()
  keypoints[..., 0] *= imgsz[1]
  keypoints[..., 1] *= imgsz[0]

  loss[1], loss[2] = self.calculate_keypoints_loss(
- fg_mask, target_gt_idx, keypoints, batch_idx, stride_tensor, target_bboxes, pred_kpts
+ fg_mask,
+ target_gt_idx,
+ keypoints,
+ batch["batch_idx"].view(-1, 1),
+ stride_tensor,
+ target_bboxes,
+ pred_kpts,
  )

- loss[0] *= self.hyp.box # box gain
  loss[1] *= self.hyp.pose # pose gain
  loss[2] *= self.hyp.kobj # kobj gain
- loss[3] *= self.hyp.cls # cls gain
- loss[4] *= self.hyp.dfl # dfl gain

- return loss * batch_size, loss.detach() # loss(box, cls, dfl)
+ return loss * batch_size, loss.detach() # loss(box, pose, kobj, cls, dfl)

  @staticmethod
  def kpts_decode(anchor_points: torch.Tensor, pred_kpts: torch.Tensor) -> torch.Tensor:
@@ -575,35 +684,23 @@ class v8PoseLoss(v8DetectionLoss):
  y[..., 1] += anchor_points[:, [1]] - 0.5
  return y

- def calculate_keypoints_loss(
+ def _select_target_keypoints(
  self,
- masks: torch.Tensor,
- target_gt_idx: torch.Tensor,
  keypoints: torch.Tensor,
  batch_idx: torch.Tensor,
- stride_tensor: torch.Tensor,
- target_bboxes: torch.Tensor,
- pred_kpts: torch.Tensor,
- ) -> tuple[torch.Tensor, torch.Tensor]:
- """
- Calculate the keypoints loss for the model.
-
- This function calculates the keypoints loss and keypoints object loss for a given batch. The keypoints loss is
- based on the difference between the predicted keypoints and ground truth keypoints. The keypoints object loss is
- a binary classification loss that classifies whether a keypoint is present or not.
+ target_gt_idx: torch.Tensor,
+ masks: torch.Tensor,
+ ) -> torch.Tensor:
+ """Select target keypoints for each anchor based on batch index and target ground truth index.

  Args:
- masks (torch.Tensor): Binary mask tensor indicating object presence, shape (BS, N_anchors).
- target_gt_idx (torch.Tensor): Index tensor mapping anchors to ground truth objects, shape (BS, N_anchors).
  keypoints (torch.Tensor): Ground truth keypoints, shape (N_kpts_in_batch, N_kpts_per_object, kpts_dim).
  batch_idx (torch.Tensor): Batch index tensor for keypoints, shape (N_kpts_in_batch, 1).
- stride_tensor (torch.Tensor): Stride tensor for anchors, shape (N_anchors, 1).
- target_bboxes (torch.Tensor): Ground truth boxes in (x1, y1, x2, y2) format, shape (BS, N_anchors, 4).
- pred_kpts (torch.Tensor): Predicted keypoints, shape (BS, N_anchors, N_kpts_per_object, kpts_dim).
+ target_gt_idx (torch.Tensor): Index tensor mapping anchors to ground truth objects, shape (BS, N_anchors).
+ masks (torch.Tensor): Binary mask tensor indicating object presence, shape (BS, N_anchors).

  Returns:
- kpts_loss (torch.Tensor): The keypoints loss.
- kpts_obj_loss (torch.Tensor): The keypoints object loss.
+ (torch.Tensor): Selected keypoints tensor, shape (BS, N_anchors, N_kpts_per_object, kpts_dim).
  """
  batch_idx = batch_idx.flatten()
  batch_size = len(masks)
@@ -630,6 +727,40 @@ class v8PoseLoss(v8DetectionLoss):
  1, target_gt_idx_expanded.expand(-1, -1, keypoints.shape[1], keypoints.shape[2])
  )

+ return selected_keypoints
+
+ def calculate_keypoints_loss(
+ self,
+ masks: torch.Tensor,
+ target_gt_idx: torch.Tensor,
+ keypoints: torch.Tensor,
+ batch_idx: torch.Tensor,
+ stride_tensor: torch.Tensor,
+ target_bboxes: torch.Tensor,
+ pred_kpts: torch.Tensor,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """Calculate the keypoints loss for the model.
+
+ This function calculates the keypoints loss and keypoints object loss for a given batch. The keypoints loss is
+ based on the difference between the predicted keypoints and ground truth keypoints. The keypoints object loss is
+ a binary classification loss that classifies whether a keypoint is present or not.
+
+ Args:
+ masks (torch.Tensor): Binary mask tensor indicating object presence, shape (BS, N_anchors).
+ target_gt_idx (torch.Tensor): Index tensor mapping anchors to ground truth objects, shape (BS, N_anchors).
+ keypoints (torch.Tensor): Ground truth keypoints, shape (N_kpts_in_batch, N_kpts_per_object, kpts_dim).
+ batch_idx (torch.Tensor): Batch index tensor for keypoints, shape (N_kpts_in_batch, 1).
+ stride_tensor (torch.Tensor): Stride tensor for anchors, shape (N_anchors, 1).
+ target_bboxes (torch.Tensor): Ground truth boxes in (x1, y1, x2, y2) format, shape (BS, N_anchors, 4).
+ pred_kpts (torch.Tensor): Predicted keypoints, shape (BS, N_anchors, N_kpts_per_object, kpts_dim).
+
+ Returns:
+ kpts_loss (torch.Tensor): The keypoints loss.
+ kpts_obj_loss (torch.Tensor): The keypoints object loss.
+ """
+ # Select target keypoints using helper method
+ selected_keypoints = self._select_target_keypoints(keypoints, batch_idx, target_gt_idx, masks)
+
  # Divide coordinates by stride
  selected_keypoints[..., :2] /= stride_tensor.view(1, -1, 1, 1)

@@ -637,6 +768,7 @@ class v8PoseLoss(v8DetectionLoss):
  kpts_obj_loss = 0

  if masks.any():
+ target_bboxes /= stride_tensor
  gt_kpt = selected_keypoints[masks]
  area = xyxy2xywh(target_bboxes[masks])[:, 2:].prod(1, keepdim=True)
  pred_kpt = pred_kpts[masks]
@@ -649,6 +781,172 @@ class v8PoseLoss(v8DetectionLoss):
  return kpts_loss, kpts_obj_loss


+ class PoseLoss26(v8PoseLoss):
+ """Criterion class for computing training losses for YOLOv8 pose estimation with RLE loss support."""
+
+ def __init__(self, model, tal_topk: int = 10, tal_topk2: int | None = None): # model must be de-paralleled
+ """Initialize PoseLoss26 with model parameters and keypoint-specific loss functions including RLE loss."""
+ super().__init__(model, tal_topk, tal_topk2)
+ is_pose = self.kpt_shape == [17, 3]
+ nkpt = self.kpt_shape[0] # number of keypoints
+ self.rle_loss = None
+ self.flow_model = model.model[-1].flow_model if hasattr(model.model[-1], "flow_model") else None
+ if self.flow_model is not None:
+ self.rle_loss = RLELoss(use_target_weight=True).to(self.device)
+ self.target_weights = (
+ torch.from_numpy(RLE_WEIGHT).to(self.device) if is_pose else torch.ones(nkpt, device=self.device)
+ )
+
+ def loss(self, preds: dict[str, torch.Tensor], batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+ """Calculate the total loss and detach it for pose estimation."""
+ pred_kpts = preds["kpts"].permute(0, 2, 1).contiguous()
+ loss = torch.zeros(6 if self.rle_loss else 5, device=self.device) # box, cls, dfl, kpt_location, kpt_visibility
+ (fg_mask, target_gt_idx, target_bboxes, anchor_points, stride_tensor), det_loss, _ = (
+ self.get_assigned_targets_and_loss(preds, batch)
+ )
+ # NOTE: re-assign index for consistency for now. Need to be removed in the future.
+ loss[0], loss[3], loss[4] = det_loss[0], det_loss[1], det_loss[2]
+
+ batch_size = pred_kpts.shape[0]
+ imgsz = torch.tensor(preds["feats"][0].shape[2:], device=self.device, dtype=pred_kpts.dtype) * self.stride[0]
+
+ pred_kpts = pred_kpts.view(batch_size, -1, *self.kpt_shape) # (b, h*w, 17, 3)
+
+ if self.rle_loss and preds.get("kpts_sigma", None) is not None:
+ pred_sigma = preds["kpts_sigma"].permute(0, 2, 1).contiguous()
+ pred_sigma = pred_sigma.view(batch_size, -1, self.kpt_shape[0], 2) # (b, h*w, 17, 2)
+ pred_kpts = torch.cat([pred_kpts, pred_sigma], dim=-1) # (b, h*w, 17, 5)
+
+ pred_kpts = self.kpts_decode(anchor_points, pred_kpts)
+
+ # Bbox loss
+ if fg_mask.sum():
+ keypoints = batch["keypoints"].to(self.device).float().clone()
+ keypoints[..., 0] *= imgsz[1]
+ keypoints[..., 1] *= imgsz[0]
+
+ keypoints_loss = self.calculate_keypoints_loss(
+ fg_mask,
+ target_gt_idx,
+ keypoints,
+ batch["batch_idx"].view(-1, 1),
+ stride_tensor,
+ target_bboxes,
+ pred_kpts,
+ )
+ loss[1] = keypoints_loss[0]
+ loss[2] = keypoints_loss[1]
+ if self.rle_loss is not None:
+ loss[5] = keypoints_loss[2]
+
+ loss[1] *= self.hyp.pose # pose gain
+ loss[2] *= self.hyp.kobj # kobj gain
+ if self.rle_loss is not None:
+ loss[5] *= self.hyp.rle # rle gain
+
+ return loss * batch_size, loss.detach() # loss(box, cls, dfl, kpt_location, kpt_visibility)
+
+ @staticmethod
+ def kpts_decode(anchor_points: torch.Tensor, pred_kpts: torch.Tensor) -> torch.Tensor:
+ """Decode predicted keypoints to image coordinates."""
+ y = pred_kpts.clone()
+ y[..., 0] += anchor_points[:, [0]]
+ y[..., 1] += anchor_points[:, [1]]
+ return y
+
+ def calculate_rle_loss(self, pred_kpt: torch.Tensor, gt_kpt: torch.Tensor, kpt_mask: torch.Tensor) -> torch.Tensor:
+ """Calculate the RLE (Residual Log-likelihood Estimation) loss for keypoints.
+
+ Args:
+ pred_kpt (torch.Tensor): Predicted keypoints with sigma, shape (N, kpts_dim) where kpts_dim >= 4.
+ gt_kpt (torch.Tensor): Ground truth keypoints, shape (N, kpts_dim).
+ kpt_mask (torch.Tensor): Mask for valid keypoints, shape (N, num_keypoints).
+
+ Returns:
+ (torch.Tensor): The RLE loss.
+ """
+ pred_kpt_visible = pred_kpt[kpt_mask]
+ gt_kpt_visible = gt_kpt[kpt_mask]
+ pred_coords = pred_kpt_visible[:, 0:2]
+ pred_sigma = pred_kpt_visible[:, -2:]
+ gt_coords = gt_kpt_visible[:, 0:2]
+
+ target_weights = self.target_weights.unsqueeze(0).repeat(kpt_mask.shape[0], 1)
+ target_weights = target_weights[kpt_mask]
+
+ pred_sigma = pred_sigma.sigmoid()
+ error = (pred_coords - gt_coords) / (pred_sigma + 1e-9)
+
+ # Filter out NaN and Inf values to prevent MultivariateNormal validation errors
+ valid_mask = ~(torch.isnan(error) | torch.isinf(error)).any(dim=-1)
+ if not valid_mask.any():
+ return torch.tensor(0.0, device=pred_kpt.device)
+
+ error = error[valid_mask]
+ error = error.clamp(-100, 100) # Prevent numerical instability
+ pred_sigma = pred_sigma[valid_mask]
+ target_weights = target_weights[valid_mask]
+
+ log_phi = self.flow_model.log_prob(error)
+
+ return self.rle_loss(pred_sigma, log_phi, error, target_weights)
+
+ def calculate_keypoints_loss(
+ self,
+ masks: torch.Tensor,
+ target_gt_idx: torch.Tensor,
+ keypoints: torch.Tensor,
+ batch_idx: torch.Tensor,
+ stride_tensor: torch.Tensor,
+ target_bboxes: torch.Tensor,
+ pred_kpts: torch.Tensor,
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Calculate the keypoints loss for the model.
+
+ This function calculates the keypoints loss and keypoints object loss for a given batch. The keypoints loss is
+ based on the difference between the predicted keypoints and ground truth keypoints. The keypoints object loss is
+ a binary classification loss that classifies whether a keypoint is present or not.
+
+ Args:
+ masks (torch.Tensor): Binary mask tensor indicating object presence, shape (BS, N_anchors).
+ target_gt_idx (torch.Tensor): Index tensor mapping anchors to ground truth objects, shape (BS, N_anchors).
+ keypoints (torch.Tensor): Ground truth keypoints, shape (N_kpts_in_batch, N_kpts_per_object, kpts_dim).
+ batch_idx (torch.Tensor): Batch index tensor for keypoints, shape (N_kpts_in_batch, 1).
+ stride_tensor (torch.Tensor): Stride tensor for anchors, shape (N_anchors, 1).
+ target_bboxes (torch.Tensor): Ground truth boxes in (x1, y1, x2, y2) format, shape (BS, N_anchors, 4).
+ pred_kpts (torch.Tensor): Predicted keypoints, shape (BS, N_anchors, N_kpts_per_object, kpts_dim).
+
+ Returns:
+ kpts_loss (torch.Tensor): The keypoints loss.
+ kpts_obj_loss (torch.Tensor): The keypoints object loss.
+ rle_loss (torch.Tensor): The RLE loss.
+ """
+ # Select target keypoints using inherited helper method
+ selected_keypoints = self._select_target_keypoints(keypoints, batch_idx, target_gt_idx, masks)
+
+ # Divide coordinates by stride
+ selected_keypoints[..., :2] /= stride_tensor.view(1, -1, 1, 1)
+
+ kpts_loss = 0
+ kpts_obj_loss = 0
+ rle_loss = 0
+
+ if masks.any():
+ target_bboxes /= stride_tensor
+ gt_kpt = selected_keypoints[masks]
+ area = xyxy2xywh(target_bboxes[masks])[:, 2:].prod(1, keepdim=True)
+ pred_kpt = pred_kpts[masks]
+ kpt_mask = gt_kpt[..., 2] != 0 if gt_kpt.shape[-1] == 3 else torch.full_like(gt_kpt[..., 0], True)
+ kpts_loss = self.keypoint_loss(pred_kpt, gt_kpt, kpt_mask, area) # pose loss
+
+ if self.rle_loss is not None and (pred_kpt.shape[-1] == 4 or pred_kpt.shape[-1] == 5):
+ rle_loss = self.calculate_rle_loss(pred_kpt, gt_kpt, kpt_mask)
+ if pred_kpt.shape[-1] == 3 or pred_kpt.shape[-1] == 5:
+ kpts_obj_loss = self.bce_pose(pred_kpt[..., 2], kpt_mask.float()) # keypoint obj loss
+
+ return kpts_loss, kpts_obj_loss, rle_loss
+
+
  class v8ClassificationLoss:
  """Criterion class for computing training losses for classification."""

@@ -662,10 +960,17 @@ class v8ClassificationLoss:
  class v8OBBLoss(v8DetectionLoss):
  """Calculates losses for object detection, classification, and box distribution in rotated YOLO models."""

- def __init__(self, model):
+ def __init__(self, model, tal_topk=10, tal_topk2: int | None = None):
  """Initialize v8OBBLoss with model, assigner, and rotated bbox loss; model must be de-paralleled."""
- super().__init__(model)
- self.assigner = RotatedTaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0)
+ super().__init__(model, tal_topk=tal_topk)
+ self.assigner = RotatedTaskAlignedAssigner(
+ topk=tal_topk,
+ num_classes=self.nc,
+ alpha=0.5,
+ beta=6.0,
+ stride=self.stride.tolist(),
+ topk2=tal_topk2,
+ )
  self.bbox_loss = RotatedBboxLoss(self.reg_max).to(self.device)

  def preprocess(self, targets: torch.Tensor, batch_size: int, scale_tensor: torch.Tensor) -> torch.Tensor:
@@ -685,38 +990,34 @@ class v8OBBLoss(v8DetectionLoss):
  out[j, :n] = torch.cat([targets[matches, 1:2], bboxes], dim=-1)
  return out

- def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+ def loss(self, preds: dict[str, torch.Tensor], batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
  """Calculate and return the loss for oriented bounding box detection."""
- loss = torch.zeros(3, device=self.device) # box, cls, dfl
- feats, pred_angle = preds if isinstance(preds[0], list) else preds[1]
- batch_size = pred_angle.shape[0] # batch size, number of masks, mask height, mask width
- pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
- (self.reg_max * 4, self.nc), 1
+ loss = torch.zeros(4, device=self.device) # box, cls, dfl, angle
+ pred_distri, pred_scores, pred_angle = (
+ preds["boxes"].permute(0, 2, 1).contiguous(),
+ preds["scores"].permute(0, 2, 1).contiguous(),
+ preds["angle"].permute(0, 2, 1).contiguous(),
  )
-
- # b, grids, ..
- pred_scores = pred_scores.permute(0, 2, 1).contiguous()
- pred_distri = pred_distri.permute(0, 2, 1).contiguous()
- pred_angle = pred_angle.permute(0, 2, 1).contiguous()
+ anchor_points, stride_tensor = make_anchors(preds["feats"], self.stride, 0.5)
+ batch_size = pred_angle.shape[0] # batch size, number of masks, mask height, mask width

  dtype = pred_scores.dtype
- imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w)
- anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
+ imgsz = torch.tensor(preds["feats"][0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]

  # targets
  try:
  batch_idx = batch["batch_idx"].view(-1, 1)
  targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"].view(-1, 5)), 1)
- rw, rh = targets[:, 4] * imgsz[0].item(), targets[:, 5] * imgsz[1].item()
+ rw, rh = targets[:, 4] * float(imgsz[1]), targets[:, 5] * float(imgsz[0])
  targets = targets[(rw >= 2) & (rh >= 2)] # filter rboxes of tiny size to stabilize training
- targets = self.preprocess(targets, batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
+ targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
  gt_labels, gt_bboxes = targets.split((1, 5), 2) # cls, xywhr
  mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
  except RuntimeError as e:
  raise TypeError(
  "ERROR ❌ OBB dataset incorrectly formatted or not a OBB dataset.\n"
  "This error can occur when incorrectly training a 'OBB' model on a 'detect' dataset, "
- "i.e. 'yolo train model=yolo11n-obb.pt data=coco8.yaml'.\nVerify your dataset is a "
+ "i.e. 'yolo train model=yolo26n-obb.pt data=dota8.yaml'.\nVerify your dataset is a "
  "correctly formatted 'OBB' dataset using 'data=dota8.yaml' "
  "as an example.\nSee https://docs.ultralytics.com/datasets/obb/ for help."
  ) from e
@@ -746,22 +1047,34 @@ class v8OBBLoss(v8DetectionLoss):
  if fg_mask.sum():
  target_bboxes[..., :4] /= stride_tensor
  loss[0], loss[2] = self.bbox_loss(
- pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
+ pred_distri,
+ pred_bboxes,
+ anchor_points,
+ target_bboxes,
+ target_scores,
+ target_scores_sum,
+ fg_mask,
+ imgsz,
+ stride_tensor,
  )
+ weight = target_scores.sum(-1)[fg_mask]
+ loss[3] = self.calculate_angle_loss(
+ pred_bboxes, target_bboxes, fg_mask, weight, target_scores_sum
+ ) # angle loss
  else:
  loss[0] += (pred_angle * 0).sum()

  loss[0] *= self.hyp.box # box gain
  loss[1] *= self.hyp.cls # cls gain
  loss[2] *= self.hyp.dfl # dfl gain
+ loss[3] *= self.hyp.angle # angle gain

- return loss * batch_size, loss.detach() # loss(box, cls, dfl)
+ return loss * batch_size, loss.detach() # loss(box, cls, dfl, angle)

  def bbox_decode(
  self, anchor_points: torch.Tensor, pred_dist: torch.Tensor, pred_angle: torch.Tensor
  ) -> torch.Tensor:
- """
- Decode predicted object bounding box coordinates from anchor points and distribution.
+ """Decode predicted object bounding box coordinates from anchor points and distribution.
765
1078
 
766
1079
  Args:
767
1080
  anchor_points (torch.Tensor): Anchor points, (h*w, 2).
@@ -776,6 +1089,34 @@ class v8OBBLoss(v8DetectionLoss):
776
1089
  pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))
777
1090
  return torch.cat((dist2rbox(pred_dist, pred_angle, anchor_points), pred_angle), dim=-1)
778
1091
 
1092
+ def calculate_angle_loss(self, pred_bboxes, target_bboxes, fg_mask, weight, target_scores_sum, lambda_val=3):
1093
+ """Calculate oriented angle loss.
1094
+
1095
+ Args:
1096
+ pred_bboxes: [N, 5] (x, y, w, h, theta).
1097
+ target_bboxes: [N, 5] (x, y, w, h, theta).
1098
+ fg_mask: Foreground mask indicating valid predictions.
1099
+ weight: Loss weights for each prediction.
1100
+ target_scores_sum: Sum of target scores for normalization.
1101
+ lambda_val: control the sensitivity to aspect ratio.
1102
+ """
1103
+ w_gt = target_bboxes[..., 2]
1104
+ h_gt = target_bboxes[..., 3]
1105
+ pred_theta = pred_bboxes[..., 4]
1106
+ target_theta = target_bboxes[..., 4]
1107
+
1108
+ log_ar = torch.log(w_gt / h_gt)
1109
+ scale_weight = torch.exp(-(log_ar**2) / (lambda_val**2))
1110
+
1111
+ delta_theta = pred_theta - target_theta
1112
+ delta_theta_wrapped = delta_theta - torch.round(delta_theta / math.pi) * math.pi
1113
+ ang_loss = torch.sin(2 * delta_theta_wrapped[fg_mask]) ** 2
1114
+
1115
+ ang_loss = scale_weight[fg_mask] * ang_loss
1116
+ ang_loss = ang_loss * weight
1117
+
1118
+ return ang_loss.sum() / target_scores_sum
1119
+
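The new angle term penalizes sin²(2Δθ) after wrapping the angle error into a half-period of π, so a prediction that is off by a multiple of π is not penalized (a rotated box has period-π symmetry), and it is scaled by exp(-(ln(w/h))²/λ²) so near-square targets, whose orientation is ambiguous, contribute little. A standalone numeric sketch of that computation with illustrative values (not taken from the package):

```python
import math
import torch

lambda_val = 3
w_gt, h_gt = torch.tensor([4.0]), torch.tensor([1.0])        # elongated target -> weight close to 1
pred_theta, target_theta = torch.tensor([3.10]), torch.tensor([0.0])

scale_weight = torch.exp(-(torch.log(w_gt / h_gt) ** 2) / lambda_val**2)        # ~0.81

delta = pred_theta - target_theta                                               # 3.10 rad
delta_wrapped = delta - torch.round(delta / math.pi) * math.pi                  # ~-0.04 rad: error of ~pi wraps to ~0
ang_loss = scale_weight * torch.sin(2 * delta_wrapped) ** 2                     # ~0.006

print(float(delta_wrapped), float(ang_loss))
```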

 class E2EDetectLoss:
     """Criterion class for computing training losses for end-to-end detection."""
@@ -795,63 +1136,108 @@ class E2EDetectLoss:
         return loss_one2many[0] + loss_one2one[0], loss_one2many[1] + loss_one2one[1]


+class E2ELoss:
+    """Criterion class for computing training losses for end-to-end detection."""
+
+    def __init__(self, model, loss_fn=v8DetectionLoss):
+        """Initialize E2ELoss with one-to-many and one-to-one detection losses using the provided model."""
+        self.one2many = loss_fn(model, tal_topk=10)
+        self.one2one = loss_fn(model, tal_topk=7, tal_topk2=1)
+        self.updates = 0
+        self.total = 1.0
+        # init gain
+        self.o2m = 0.8
+        self.o2o = self.total - self.o2m
+        self.o2m_copy = self.o2m
+        # final gain
+        self.final_o2m = 0.1
+
+    def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+        """Calculate the sum of the loss for box, cls and dfl multiplied by batch size."""
+        preds = self.one2many.parse_output(preds)
+        one2many, one2one = preds["one2many"], preds["one2one"]
+        loss_one2many = self.one2many.loss(one2many, batch)
+        loss_one2one = self.one2one.loss(one2one, batch)
+        return loss_one2many[0] * self.o2m + loss_one2one[0] * self.o2o, loss_one2one[1]
+
+    def update(self) -> None:
+        """Update the weights for one-to-many and one-to-one losses based on the decay schedule."""
+        self.updates += 1
+        self.o2m = self.decay(self.updates)
+        self.o2o = max(self.total - self.o2m, 0)
+
+    def decay(self, x) -> float:
+        """Calculate the decayed weight for one-to-many loss based on the current update step."""
+        return max(1 - x / max(self.one2one.hyp.epochs - 1, 1), 0) * (self.o2m_copy - self.final_o2m) + self.final_o2m
+
+
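`E2ELoss` blends the two branch losses with gains that sum to 1: the one-to-many gain starts at 0.8 and decays linearly to 0.1 over training, with the one-to-one gain taking the remainder. A minimal sketch of that schedule, assuming `update()` is called once per epoch (the division by `epochs - 1` suggests per-epoch updates, but the call site is not shown in this diff):

```python
# Sketch of the E2ELoss gain schedule (assumes one update() call per epoch)
def o2m_gain(step: int, epochs: int, init: float = 0.8, final: float = 0.1) -> float:
    """Linearly decay the one-to-many gain from `init` to `final` over training."""
    return max(1 - step / max(epochs - 1, 1), 0) * (init - final) + final

epochs = 100
for step in (1, 50, 99):
    o2m = o2m_gain(step, epochs)
    print(step, round(o2m, 3), round(1.0 - o2m, 3))
# 1 -> 0.793 / 0.207, 50 -> 0.446 / 0.554, 99 -> 0.1 / 0.9
```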
 class TVPDetectLoss:
     """Criterion class for computing training losses for text-visual prompt detection."""

-    def __init__(self, model):
+    def __init__(self, model, tal_topk=10):
         """Initialize TVPDetectLoss with task-prompt and visual-prompt criteria using the provided model."""
-        self.vp_criterion = v8DetectionLoss(model)
+        self.vp_criterion = v8DetectionLoss(model, tal_topk)
         # NOTE: store following info as it's changeable in __call__
+        self.hyp = self.vp_criterion.hyp
         self.ori_nc = self.vp_criterion.nc
         self.ori_no = self.vp_criterion.no
         self.ori_reg_max = self.vp_criterion.reg_max

+    def parse_output(self, preds) -> dict[str, torch.Tensor]:
+        """Parse model predictions to extract features."""
+        return self.vp_criterion.parse_output(preds)
+
     def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
         """Calculate the loss for text-visual prompt detection."""
-        feats = preds[1] if isinstance(preds, tuple) else preds
+        return self.loss(self.parse_output(preds), batch)
+
+    def loss(self, preds: dict[str, torch.Tensor], batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+        """Calculate the loss for text-visual prompt detection."""
         assert self.ori_reg_max == self.vp_criterion.reg_max  # TODO: remove it

-        if self.ori_reg_max * 4 + self.ori_nc == feats[0].shape[1]:
+        if self.ori_nc == preds["scores"].shape[1]:
             loss = torch.zeros(3, device=self.vp_criterion.device, requires_grad=True)
             return loss, loss.detach()

-        vp_feats = self._get_vp_features(feats)
-        vp_loss = self.vp_criterion(vp_feats, batch)
+        preds["scores"] = self._get_vp_features(preds)
+        vp_loss = self.vp_criterion(preds, batch)
         box_loss = vp_loss[0][1]
         return box_loss, vp_loss[1]

-    def _get_vp_features(self, feats: list[torch.Tensor]) -> list[torch.Tensor]:
+    def _get_vp_features(self, preds: dict[str, torch.Tensor]) -> list[torch.Tensor]:
         """Extract visual-prompt features from the model output."""
-        vnc = feats[0].shape[1] - self.ori_reg_max * 4 - self.ori_nc
+        # NOTE: remove empty placeholder
+        scores = preds["scores"][:, self.ori_nc :, :]
+        vnc = scores.shape[1]

         self.vp_criterion.nc = vnc
         self.vp_criterion.no = vnc + self.vp_criterion.reg_max * 4
         self.vp_criterion.assigner.num_classes = vnc

-
-        return [
-            torch.cat((box, cls_vp), dim=1)
-            for box, _, cls_vp in [xi.split((self.ori_reg_max * 4, self.ori_nc, vnc), dim=1) for xi in feats]
-        ]
+        return scores


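With predictions now passed as a dict, the visual-prompt criterion no longer re-splits the raw `(reg_max*4 + nc + vnc)` feature maps: the class-score tensor stacks the original text-prompt classes first, and `_get_vp_features` simply drops those placeholder rows before re-pointing the criterion's class count at the visual-prompt channels. A small shape sketch with hypothetical sizes:

```python
import torch

ori_nc, vnc, anchors = 80, 3, 8400               # hypothetical: 80 text classes, 3 visual prompts
scores = torch.randn(2, ori_nc + vnc, anchors)   # stand-in for preds["scores"]

vp_scores = scores[:, ori_nc:, :]                # drop the text-class placeholder rows
assert vp_scores.shape == (2, vnc, anchors)      # criterion nc / assigner.num_classes then follow vnc
```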
 class TVPSegmentLoss(TVPDetectLoss):
     """Criterion class for computing training losses for text-visual prompt segmentation."""

-    def __init__(self, model):
+    def __init__(self, model, tal_topk=10):
         """Initialize TVPSegmentLoss with task-prompt and visual-prompt criteria using the provided model."""
         super().__init__(model)
-        self.vp_criterion = v8SegmentationLoss(model)
+        self.vp_criterion = v8SegmentationLoss(model, tal_topk)
+        self.hyp = self.vp_criterion.hyp

     def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
         """Calculate the loss for text-visual prompt segmentation."""
-        feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
+        return self.loss(self.parse_output(preds), batch)
+
+    def loss(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+        """Calculate the loss for text-visual prompt detection."""
         assert self.ori_reg_max == self.vp_criterion.reg_max  # TODO: remove it

-        if self.ori_reg_max * 4 + self.ori_nc == feats[0].shape[1]:
+        if self.ori_nc == preds["scores"].shape[1]:
             loss = torch.zeros(4, device=self.vp_criterion.device, requires_grad=True)
             return loss, loss.detach()

-        vp_feats = self._get_vp_features(feats)
-        vp_loss = self.vp_criterion((vp_feats, pred_masks, proto), batch)
+        preds["scores"] = self._get_vp_features(preds)
+        vp_loss = self.vp_criterion(preds, batch)
         cls_loss = vp_loss[0][2]
         return cls_loss, vp_loss[1]