dgenerate-ultralytics-headless 8.3.134__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (272)
  1. dgenerate_ultralytics_headless-8.3.134.dist-info/METADATA +400 -0
  2. dgenerate_ultralytics_headless-8.3.134.dist-info/RECORD +272 -0
  3. dgenerate_ultralytics_headless-8.3.134.dist-info/WHEEL +5 -0
  4. dgenerate_ultralytics_headless-8.3.134.dist-info/entry_points.txt +3 -0
  5. dgenerate_ultralytics_headless-8.3.134.dist-info/licenses/LICENSE +661 -0
  6. dgenerate_ultralytics_headless-8.3.134.dist-info/top_level.txt +1 -0
  7. tests/__init__.py +22 -0
  8. tests/conftest.py +83 -0
  9. tests/test_cli.py +138 -0
  10. tests/test_cuda.py +215 -0
  11. tests/test_engine.py +131 -0
  12. tests/test_exports.py +236 -0
  13. tests/test_integrations.py +154 -0
  14. tests/test_python.py +694 -0
  15. tests/test_solutions.py +187 -0
  16. ultralytics/__init__.py +30 -0
  17. ultralytics/assets/bus.jpg +0 -0
  18. ultralytics/assets/zidane.jpg +0 -0
  19. ultralytics/cfg/__init__.py +1023 -0
  20. ultralytics/cfg/datasets/Argoverse.yaml +77 -0
  21. ultralytics/cfg/datasets/DOTAv1.5.yaml +37 -0
  22. ultralytics/cfg/datasets/DOTAv1.yaml +36 -0
  23. ultralytics/cfg/datasets/GlobalWheat2020.yaml +68 -0
  24. ultralytics/cfg/datasets/HomeObjects-3K.yaml +33 -0
  25. ultralytics/cfg/datasets/ImageNet.yaml +2025 -0
  26. ultralytics/cfg/datasets/Objects365.yaml +443 -0
  27. ultralytics/cfg/datasets/SKU-110K.yaml +58 -0
  28. ultralytics/cfg/datasets/VOC.yaml +106 -0
  29. ultralytics/cfg/datasets/VisDrone.yaml +77 -0
  30. ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
  31. ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
  32. ultralytics/cfg/datasets/carparts-seg.yaml +44 -0
  33. ultralytics/cfg/datasets/coco-pose.yaml +42 -0
  34. ultralytics/cfg/datasets/coco.yaml +118 -0
  35. ultralytics/cfg/datasets/coco128-seg.yaml +101 -0
  36. ultralytics/cfg/datasets/coco128.yaml +101 -0
  37. ultralytics/cfg/datasets/coco8-multispectral.yaml +104 -0
  38. ultralytics/cfg/datasets/coco8-pose.yaml +26 -0
  39. ultralytics/cfg/datasets/coco8-seg.yaml +101 -0
  40. ultralytics/cfg/datasets/coco8.yaml +101 -0
  41. ultralytics/cfg/datasets/crack-seg.yaml +22 -0
  42. ultralytics/cfg/datasets/dog-pose.yaml +24 -0
  43. ultralytics/cfg/datasets/dota8-multispectral.yaml +38 -0
  44. ultralytics/cfg/datasets/dota8.yaml +35 -0
  45. ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
  46. ultralytics/cfg/datasets/lvis.yaml +1240 -0
  47. ultralytics/cfg/datasets/medical-pills.yaml +22 -0
  48. ultralytics/cfg/datasets/open-images-v7.yaml +666 -0
  49. ultralytics/cfg/datasets/package-seg.yaml +22 -0
  50. ultralytics/cfg/datasets/signature.yaml +21 -0
  51. ultralytics/cfg/datasets/tiger-pose.yaml +25 -0
  52. ultralytics/cfg/datasets/xView.yaml +155 -0
  53. ultralytics/cfg/default.yaml +127 -0
  54. ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +17 -0
  55. ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
  56. ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
  57. ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
  58. ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
  59. ultralytics/cfg/models/11/yolo11.yaml +50 -0
  60. ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
  61. ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
  62. ultralytics/cfg/models/12/yolo12-cls.yaml +32 -0
  63. ultralytics/cfg/models/12/yolo12-obb.yaml +48 -0
  64. ultralytics/cfg/models/12/yolo12-pose.yaml +49 -0
  65. ultralytics/cfg/models/12/yolo12-seg.yaml +48 -0
  66. ultralytics/cfg/models/12/yolo12.yaml +48 -0
  67. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +53 -0
  68. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +45 -0
  69. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +45 -0
  70. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +57 -0
  71. ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
  72. ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
  73. ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
  74. ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
  75. ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
  76. ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
  77. ultralytics/cfg/models/v3/yolov3-spp.yaml +49 -0
  78. ultralytics/cfg/models/v3/yolov3-tiny.yaml +40 -0
  79. ultralytics/cfg/models/v3/yolov3.yaml +49 -0
  80. ultralytics/cfg/models/v5/yolov5-p6.yaml +62 -0
  81. ultralytics/cfg/models/v5/yolov5.yaml +51 -0
  82. ultralytics/cfg/models/v6/yolov6.yaml +56 -0
  83. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +45 -0
  84. ultralytics/cfg/models/v8/yoloe-v8.yaml +45 -0
  85. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +28 -0
  86. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +28 -0
  87. ultralytics/cfg/models/v8/yolov8-cls.yaml +32 -0
  88. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +58 -0
  89. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +60 -0
  90. ultralytics/cfg/models/v8/yolov8-ghost.yaml +50 -0
  91. ultralytics/cfg/models/v8/yolov8-obb.yaml +49 -0
  92. ultralytics/cfg/models/v8/yolov8-p2.yaml +57 -0
  93. ultralytics/cfg/models/v8/yolov8-p6.yaml +59 -0
  94. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +60 -0
  95. ultralytics/cfg/models/v8/yolov8-pose.yaml +50 -0
  96. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +49 -0
  97. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +59 -0
  98. ultralytics/cfg/models/v8/yolov8-seg.yaml +49 -0
  99. ultralytics/cfg/models/v8/yolov8-world.yaml +51 -0
  100. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +49 -0
  101. ultralytics/cfg/models/v8/yolov8.yaml +49 -0
  102. ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
  103. ultralytics/cfg/models/v9/yolov9c.yaml +41 -0
  104. ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
  105. ultralytics/cfg/models/v9/yolov9e.yaml +64 -0
  106. ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
  107. ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
  108. ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
  109. ultralytics/cfg/trackers/botsort.yaml +22 -0
  110. ultralytics/cfg/trackers/bytetrack.yaml +14 -0
  111. ultralytics/data/__init__.py +26 -0
  112. ultralytics/data/annotator.py +66 -0
  113. ultralytics/data/augment.py +2945 -0
  114. ultralytics/data/base.py +438 -0
  115. ultralytics/data/build.py +258 -0
  116. ultralytics/data/converter.py +754 -0
  117. ultralytics/data/dataset.py +834 -0
  118. ultralytics/data/loaders.py +676 -0
  119. ultralytics/data/scripts/download_weights.sh +18 -0
  120. ultralytics/data/scripts/get_coco.sh +61 -0
  121. ultralytics/data/scripts/get_coco128.sh +18 -0
  122. ultralytics/data/scripts/get_imagenet.sh +52 -0
  123. ultralytics/data/split.py +125 -0
  124. ultralytics/data/split_dota.py +325 -0
  125. ultralytics/data/utils.py +777 -0
  126. ultralytics/engine/__init__.py +1 -0
  127. ultralytics/engine/exporter.py +1519 -0
  128. ultralytics/engine/model.py +1156 -0
  129. ultralytics/engine/predictor.py +502 -0
  130. ultralytics/engine/results.py +1840 -0
  131. ultralytics/engine/trainer.py +853 -0
  132. ultralytics/engine/tuner.py +243 -0
  133. ultralytics/engine/validator.py +377 -0
  134. ultralytics/hub/__init__.py +168 -0
  135. ultralytics/hub/auth.py +137 -0
  136. ultralytics/hub/google/__init__.py +176 -0
  137. ultralytics/hub/session.py +446 -0
  138. ultralytics/hub/utils.py +248 -0
  139. ultralytics/models/__init__.py +9 -0
  140. ultralytics/models/fastsam/__init__.py +7 -0
  141. ultralytics/models/fastsam/model.py +61 -0
  142. ultralytics/models/fastsam/predict.py +181 -0
  143. ultralytics/models/fastsam/utils.py +24 -0
  144. ultralytics/models/fastsam/val.py +40 -0
  145. ultralytics/models/nas/__init__.py +7 -0
  146. ultralytics/models/nas/model.py +102 -0
  147. ultralytics/models/nas/predict.py +58 -0
  148. ultralytics/models/nas/val.py +39 -0
  149. ultralytics/models/rtdetr/__init__.py +7 -0
  150. ultralytics/models/rtdetr/model.py +63 -0
  151. ultralytics/models/rtdetr/predict.py +84 -0
  152. ultralytics/models/rtdetr/train.py +85 -0
  153. ultralytics/models/rtdetr/val.py +191 -0
  154. ultralytics/models/sam/__init__.py +6 -0
  155. ultralytics/models/sam/amg.py +260 -0
  156. ultralytics/models/sam/build.py +358 -0
  157. ultralytics/models/sam/model.py +170 -0
  158. ultralytics/models/sam/modules/__init__.py +1 -0
  159. ultralytics/models/sam/modules/blocks.py +1129 -0
  160. ultralytics/models/sam/modules/decoders.py +515 -0
  161. ultralytics/models/sam/modules/encoders.py +854 -0
  162. ultralytics/models/sam/modules/memory_attention.py +299 -0
  163. ultralytics/models/sam/modules/sam.py +1006 -0
  164. ultralytics/models/sam/modules/tiny_encoder.py +1002 -0
  165. ultralytics/models/sam/modules/transformer.py +351 -0
  166. ultralytics/models/sam/modules/utils.py +394 -0
  167. ultralytics/models/sam/predict.py +1605 -0
  168. ultralytics/models/utils/__init__.py +1 -0
  169. ultralytics/models/utils/loss.py +455 -0
  170. ultralytics/models/utils/ops.py +268 -0
  171. ultralytics/models/yolo/__init__.py +7 -0
  172. ultralytics/models/yolo/classify/__init__.py +7 -0
  173. ultralytics/models/yolo/classify/predict.py +88 -0
  174. ultralytics/models/yolo/classify/train.py +233 -0
  175. ultralytics/models/yolo/classify/val.py +215 -0
  176. ultralytics/models/yolo/detect/__init__.py +7 -0
  177. ultralytics/models/yolo/detect/predict.py +124 -0
  178. ultralytics/models/yolo/detect/train.py +217 -0
  179. ultralytics/models/yolo/detect/val.py +451 -0
  180. ultralytics/models/yolo/model.py +354 -0
  181. ultralytics/models/yolo/obb/__init__.py +7 -0
  182. ultralytics/models/yolo/obb/predict.py +66 -0
  183. ultralytics/models/yolo/obb/train.py +81 -0
  184. ultralytics/models/yolo/obb/val.py +283 -0
  185. ultralytics/models/yolo/pose/__init__.py +7 -0
  186. ultralytics/models/yolo/pose/predict.py +79 -0
  187. ultralytics/models/yolo/pose/train.py +154 -0
  188. ultralytics/models/yolo/pose/val.py +394 -0
  189. ultralytics/models/yolo/segment/__init__.py +7 -0
  190. ultralytics/models/yolo/segment/predict.py +113 -0
  191. ultralytics/models/yolo/segment/train.py +123 -0
  192. ultralytics/models/yolo/segment/val.py +428 -0
  193. ultralytics/models/yolo/world/__init__.py +5 -0
  194. ultralytics/models/yolo/world/train.py +119 -0
  195. ultralytics/models/yolo/world/train_world.py +176 -0
  196. ultralytics/models/yolo/yoloe/__init__.py +22 -0
  197. ultralytics/models/yolo/yoloe/predict.py +169 -0
  198. ultralytics/models/yolo/yoloe/train.py +298 -0
  199. ultralytics/models/yolo/yoloe/train_seg.py +124 -0
  200. ultralytics/models/yolo/yoloe/val.py +191 -0
  201. ultralytics/nn/__init__.py +29 -0
  202. ultralytics/nn/autobackend.py +842 -0
  203. ultralytics/nn/modules/__init__.py +182 -0
  204. ultralytics/nn/modules/activation.py +53 -0
  205. ultralytics/nn/modules/block.py +1966 -0
  206. ultralytics/nn/modules/conv.py +712 -0
  207. ultralytics/nn/modules/head.py +880 -0
  208. ultralytics/nn/modules/transformer.py +713 -0
  209. ultralytics/nn/modules/utils.py +164 -0
  210. ultralytics/nn/tasks.py +1627 -0
  211. ultralytics/nn/text_model.py +351 -0
  212. ultralytics/solutions/__init__.py +41 -0
  213. ultralytics/solutions/ai_gym.py +116 -0
  214. ultralytics/solutions/analytics.py +252 -0
  215. ultralytics/solutions/config.py +106 -0
  216. ultralytics/solutions/distance_calculation.py +124 -0
  217. ultralytics/solutions/heatmap.py +127 -0
  218. ultralytics/solutions/instance_segmentation.py +84 -0
  219. ultralytics/solutions/object_blurrer.py +90 -0
  220. ultralytics/solutions/object_counter.py +195 -0
  221. ultralytics/solutions/object_cropper.py +84 -0
  222. ultralytics/solutions/parking_management.py +273 -0
  223. ultralytics/solutions/queue_management.py +93 -0
  224. ultralytics/solutions/region_counter.py +120 -0
  225. ultralytics/solutions/security_alarm.py +154 -0
  226. ultralytics/solutions/similarity_search.py +172 -0
  227. ultralytics/solutions/solutions.py +724 -0
  228. ultralytics/solutions/speed_estimation.py +110 -0
  229. ultralytics/solutions/streamlit_inference.py +196 -0
  230. ultralytics/solutions/templates/similarity-search.html +160 -0
  231. ultralytics/solutions/trackzone.py +88 -0
  232. ultralytics/solutions/vision_eye.py +68 -0
  233. ultralytics/trackers/__init__.py +7 -0
  234. ultralytics/trackers/basetrack.py +124 -0
  235. ultralytics/trackers/bot_sort.py +260 -0
  236. ultralytics/trackers/byte_tracker.py +480 -0
  237. ultralytics/trackers/track.py +125 -0
  238. ultralytics/trackers/utils/__init__.py +1 -0
  239. ultralytics/trackers/utils/gmc.py +376 -0
  240. ultralytics/trackers/utils/kalman_filter.py +493 -0
  241. ultralytics/trackers/utils/matching.py +157 -0
  242. ultralytics/utils/__init__.py +1435 -0
  243. ultralytics/utils/autobatch.py +106 -0
  244. ultralytics/utils/autodevice.py +174 -0
  245. ultralytics/utils/benchmarks.py +695 -0
  246. ultralytics/utils/callbacks/__init__.py +5 -0
  247. ultralytics/utils/callbacks/base.py +234 -0
  248. ultralytics/utils/callbacks/clearml.py +153 -0
  249. ultralytics/utils/callbacks/comet.py +552 -0
  250. ultralytics/utils/callbacks/dvc.py +205 -0
  251. ultralytics/utils/callbacks/hub.py +108 -0
  252. ultralytics/utils/callbacks/mlflow.py +138 -0
  253. ultralytics/utils/callbacks/neptune.py +140 -0
  254. ultralytics/utils/callbacks/raytune.py +43 -0
  255. ultralytics/utils/callbacks/tensorboard.py +132 -0
  256. ultralytics/utils/callbacks/wb.py +185 -0
  257. ultralytics/utils/checks.py +897 -0
  258. ultralytics/utils/dist.py +119 -0
  259. ultralytics/utils/downloads.py +499 -0
  260. ultralytics/utils/errors.py +43 -0
  261. ultralytics/utils/export.py +219 -0
  262. ultralytics/utils/files.py +221 -0
  263. ultralytics/utils/instance.py +499 -0
  264. ultralytics/utils/loss.py +813 -0
  265. ultralytics/utils/metrics.py +1356 -0
  266. ultralytics/utils/ops.py +885 -0
  267. ultralytics/utils/patches.py +143 -0
  268. ultralytics/utils/plotting.py +1011 -0
  269. ultralytics/utils/tal.py +416 -0
  270. ultralytics/utils/torch_utils.py +990 -0
  271. ultralytics/utils/triton.py +116 -0
  272. ultralytics/utils/tuner.py +159 -0
@@ -0,0 +1,813 @@
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from ultralytics.utils.metrics import OKS_SIGMA
+ from ultralytics.utils.ops import crop_mask, xywh2xyxy, xyxy2xywh
+ from ultralytics.utils.tal import RotatedTaskAlignedAssigner, TaskAlignedAssigner, dist2bbox, dist2rbox, make_anchors
+ from ultralytics.utils.torch_utils import autocast
+
+ from .metrics import bbox_iou, probiou
+ from .tal import bbox2dist
+
+
+ class VarifocalLoss(nn.Module):
+     """
+     Varifocal loss by Zhang et al.
+
+     https://arxiv.org/abs/2008.13367.
+
+     Args:
+         gamma (float): The focusing parameter that controls how much the loss focuses on hard-to-classify examples.
+         alpha (float): The balancing factor used to address class imbalance.
+     """
+
+     def __init__(self, gamma=2.0, alpha=0.75):
+         """Initialize the VarifocalLoss class."""
+         super().__init__()
+         self.gamma = gamma
+         self.alpha = alpha
+
+     def forward(self, pred_score, gt_score, label):
+         """Compute varifocal loss between predictions and ground truth."""
+         weight = self.alpha * pred_score.sigmoid().pow(self.gamma) * (1 - label) + gt_score * label
+         with autocast(enabled=False):
+             loss = (
+                 (F.binary_cross_entropy_with_logits(pred_score.float(), gt_score.float(), reduction="none") * weight)
+                 .mean(1)
+                 .sum()
+             )
+         return loss
+
+
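A minimal usage sketch of VarifocalLoss above; the shapes and the 0.5 threshold are illustrative assumptions, not values taken from this package:

    import torch
    from ultralytics.utils.loss import VarifocalLoss

    # Toy shapes: 2 images, 8 anchors, 4 classes (assumed for illustration).
    pred_score = torch.randn(2, 8, 4)  # raw classification logits
    gt_score = torch.rand(2, 8, 4)     # IoU-aware soft targets in [0, 1]
    label = (gt_score > 0.5).float()   # hypothetical binary positive mask

    loss = VarifocalLoss()(pred_score, gt_score, label)
    print(loss.item())  # weighted BCE, averaged over anchors and summed over batch x classes
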
+ class FocalLoss(nn.Module):
+     """
+     Applies focal-loss modulation on top of binary cross-entropy with logits.
+
+     Args:
+         gamma (float): The focusing parameter that controls how much the loss focuses on hard-to-classify examples.
+         alpha (float | list): The balancing factor used to address class imbalance.
+     """
+
+     def __init__(self, gamma=1.5, alpha=0.25):
+         """Initialize FocalLoss with focusing parameter gamma and balancing factor alpha."""
+         super().__init__()
+         self.gamma = gamma
+         self.alpha = torch.tensor(alpha)
+
+     def forward(self, pred, label):
+         """Calculate focal loss with modulating factors for class imbalance."""
+         loss = F.binary_cross_entropy_with_logits(pred, label, reduction="none")
+         # p_t = torch.exp(-loss)
+         # loss *= self.alpha * (1.000001 - p_t) ** self.gamma  # non-zero power for gradient stability
+
+         # TF implementation https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py
+         pred_prob = pred.sigmoid()  # prob from logits
+         p_t = label * pred_prob + (1 - label) * (1 - pred_prob)
+         modulating_factor = (1.0 - p_t) ** self.gamma
+         loss *= modulating_factor
+         if (self.alpha > 0).any():
+             self.alpha = self.alpha.to(device=pred.device, dtype=pred.dtype)
+             alpha_factor = label * self.alpha + (1 - label) * (1 - self.alpha)
+             loss *= alpha_factor
+         return loss.mean(1).sum()
+
+
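For reference, the modulating-factor code above implements the standard focal loss of Lin et al. (https://arxiv.org/abs/1708.02002); in the notation of the forward pass, with prediction probability p and binary label y:

    \mathrm{FL}(p) = -\alpha_t\,(1 - p_t)^{\gamma}\,\log(p_t), \qquad
    p_t = y\,p + (1 - y)(1 - p), \qquad
    \alpha_t = y\,\alpha + (1 - y)(1 - \alpha)

The forward method then reduces this with .mean(1).sum(), i.e. a mean over the anchor dimension followed by a sum over the rest.
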
+ class DFLoss(nn.Module):
+     """Criterion class for computing Distribution Focal Loss (DFL)."""
+
+     def __init__(self, reg_max=16) -> None:
+         """Initialize the DFL module with regularization maximum."""
+         super().__init__()
+         self.reg_max = reg_max
+
+     def __call__(self, pred_dist, target):
+         """Return sum of left and right DFL losses from https://ieeexplore.ieee.org/document/9792391."""
+         target = target.clamp_(0, self.reg_max - 1 - 0.01)
+         tl = target.long()  # target left
+         tr = tl + 1  # target right
+         wl = tr - target  # weight left
+         wr = 1 - wl  # weight right
+         return (
+             F.cross_entropy(pred_dist, tl.view(-1), reduction="none").view(tl.shape) * wl
+             + F.cross_entropy(pred_dist, tr.view(-1), reduction="none").view(tl.shape) * wr
+         ).mean(-1, keepdim=True)
+
+
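DFL treats each box side as a discrete distribution over reg_max bins and interpolates a continuous target between its two nearest integers. A small sanity check (sizes invented):

    import torch
    from ultralytics.utils.loss import DFLoss

    dfl = DFLoss(reg_max=16)
    pred_dist = torch.randn(3 * 4, 16)   # 3 boxes x 4 sides, logits over 16 bins
    target = torch.full((3, 4), 2.7)     # continuous distances in bin units
    print(dfl(pred_dist, target).shape)  # torch.Size([3, 1])
    # For target 2.7: tl=2, tr=3, wl=0.3, wr=0.7, so bin 3 carries most of the cross-entropy weight.
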
+ class BboxLoss(nn.Module):
+     """Criterion class for computing training losses for bounding boxes."""
+
+     def __init__(self, reg_max=16):
+         """Initialize the BboxLoss module with regularization maximum and DFL settings."""
+         super().__init__()
+         self.dfl_loss = DFLoss(reg_max) if reg_max > 1 else None
+
+     def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
+         """Compute IoU and DFL losses for bounding boxes."""
+         weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
+         iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, CIoU=True)
+         loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum
+
+         # DFL loss
+         if self.dfl_loss:
+             target_ltrb = bbox2dist(anchor_points, target_bboxes, self.dfl_loss.reg_max - 1)
+             loss_dfl = self.dfl_loss(pred_dist[fg_mask].view(-1, self.dfl_loss.reg_max), target_ltrb[fg_mask]) * weight
+             loss_dfl = loss_dfl.sum() / target_scores_sum
+         else:
+             loss_dfl = torch.tensor(0.0).to(pred_dist.device)
+
+         return loss_iou, loss_dfl
+
+
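In symbols (notation mine), the IoU term above is a score-weighted CIoU loss; the normalizer target_scores_sum is supplied by the callers below, which clamp it to at least 1:

    \mathcal{L}_{\mathrm{iou}} = \frac{\sum_{i \in \mathrm{fg}} \big(1 - \mathrm{CIoU}(\hat{b}_i, b_i)\big)\,w_i}{\sum \mathrm{target\_scores}}, \qquad
    w_i = \sum_{c} \mathrm{target\_scores}_{i,c}
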
+ class RotatedBboxLoss(BboxLoss):
+     """Criterion class for computing training losses for rotated bounding boxes."""
+
+     def __init__(self, reg_max):
+         """Initialize the RotatedBboxLoss module with regularization maximum and DFL settings."""
+         super().__init__(reg_max)
+
+     def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
+         """Compute IoU and DFL losses for rotated bounding boxes."""
+         weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
+         iou = probiou(pred_bboxes[fg_mask], target_bboxes[fg_mask])
+         loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum
+
+         # DFL loss
+         if self.dfl_loss:
+             target_ltrb = bbox2dist(anchor_points, xywh2xyxy(target_bboxes[..., :4]), self.dfl_loss.reg_max - 1)
+             loss_dfl = self.dfl_loss(pred_dist[fg_mask].view(-1, self.dfl_loss.reg_max), target_ltrb[fg_mask]) * weight
+             loss_dfl = loss_dfl.sum() / target_scores_sum
+         else:
+             loss_dfl = torch.tensor(0.0).to(pred_dist.device)
+
+         return loss_iou, loss_dfl
+
+
+ class KeypointLoss(nn.Module):
+     """Criterion class for computing keypoint losses."""
+
+     def __init__(self, sigmas) -> None:
+         """Initialize the KeypointLoss class with keypoint sigmas."""
+         super().__init__()
+         self.sigmas = sigmas
+
+     def forward(self, pred_kpts, gt_kpts, kpt_mask, area):
+         """Calculate keypoint loss factor and Euclidean distance loss for keypoints."""
+         d = (pred_kpts[..., 0] - gt_kpts[..., 0]).pow(2) + (pred_kpts[..., 1] - gt_kpts[..., 1]).pow(2)
+         kpt_loss_factor = kpt_mask.shape[1] / (torch.sum(kpt_mask != 0, dim=1) + 1e-9)
+         # e = d / (2 * (area * self.sigmas) ** 2 + 1e-9)  # from formula
+         e = d / ((2 * self.sigmas).pow(2) * (area + 1e-9) * 2)  # from cocoeval
+         return (kpt_loss_factor.view(-1, 1) * ((1 - torch.exp(-e)) * kpt_mask)).mean()
+
+
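Written out, the cocoeval-style branch above computes, for each keypoint k (d_k the squared pixel distance, a the object area, sigma_k the keypoint sigma, m_k the visibility mask, K keypoints per object):

    e_k = \frac{d_k}{(2\sigma_k)^2 \cdot 2\,(a + \varepsilon)}, \qquad
    \mathcal{L}_{\mathrm{kpt}} = \operatorname{mean}\!\Big(\frac{K}{\sum_k \mathbb{1}[m_k \neq 0] + \varepsilon}\,\big(1 - e^{-e_k}\big)\,m_k\Big)
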
+ class v8DetectionLoss:
+     """Criterion class for computing training losses for YOLOv8 object detection."""
+
+     def __init__(self, model, tal_topk=10):  # model must be de-paralleled
+         """Initialize v8DetectionLoss with model parameters and task-aligned assignment settings."""
+         device = next(model.parameters()).device  # get model device
+         h = model.args  # hyperparameters
+
+         m = model.model[-1]  # Detect() module
+         self.bce = nn.BCEWithLogitsLoss(reduction="none")
+         self.hyp = h
+         self.stride = m.stride  # model strides
+         self.nc = m.nc  # number of classes
+         self.no = m.nc + m.reg_max * 4
+         self.reg_max = m.reg_max
+         self.device = device
+
+         self.use_dfl = m.reg_max > 1
+
+         self.assigner = TaskAlignedAssigner(topk=tal_topk, num_classes=self.nc, alpha=0.5, beta=6.0)
+         self.bbox_loss = BboxLoss(m.reg_max).to(device)
+         self.proj = torch.arange(m.reg_max, dtype=torch.float, device=device)
+
+     def preprocess(self, targets, batch_size, scale_tensor):
+         """Preprocess targets by converting to tensor format and scaling coordinates."""
+         nl, ne = targets.shape
+         if nl == 0:
+             out = torch.zeros(batch_size, 0, ne - 1, device=self.device)
+         else:
+             i = targets[:, 0]  # image index
+             _, counts = i.unique(return_counts=True)
+             counts = counts.to(dtype=torch.int32)
+             out = torch.zeros(batch_size, counts.max(), ne - 1, device=self.device)
+             for j in range(batch_size):
+                 matches = i == j
+                 if n := matches.sum():
+                     out[j, :n] = targets[matches, 1:]
+             out[..., 1:5] = xywh2xyxy(out[..., 1:5].mul_(scale_tensor))
+         return out
+
+     def bbox_decode(self, anchor_points, pred_dist):
+         """Decode predicted object bounding box coordinates from anchor points and distribution."""
+         if self.use_dfl:
+             b, a, c = pred_dist.shape  # batch, anchors, channels
+             pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))
+             # pred_dist = pred_dist.view(b, a, c // 4, 4).transpose(2, 3).softmax(3).matmul(self.proj.type(pred_dist.dtype))
+             # pred_dist = (pred_dist.view(b, a, c // 4, 4).softmax(2) * self.proj.type(pred_dist.dtype).view(1, 1, -1, 1)).sum(2)
+         return dist2bbox(pred_dist, anchor_points, xywh=False)
+
+     def __call__(self, preds, batch):
+         """Calculate the sum of the loss for box, cls and dfl multiplied by batch size."""
+         loss = torch.zeros(3, device=self.device)  # box, cls, dfl
+         feats = preds[1] if isinstance(preds, tuple) else preds
+         pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
+             (self.reg_max * 4, self.nc), 1
+         )
+
+         pred_scores = pred_scores.permute(0, 2, 1).contiguous()
+         pred_distri = pred_distri.permute(0, 2, 1).contiguous()
+
+         dtype = pred_scores.dtype
+         batch_size = pred_scores.shape[0]
+         imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]  # image size (h,w)
+         anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
+
+         # Targets
+         targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
+         targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
+         gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
+         mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
+
+         # Pboxes
+         pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)
+         # dfl_conf = pred_distri.view(batch_size, -1, 4, self.reg_max).detach().softmax(-1)
+         # dfl_conf = (dfl_conf.amax(-1).mean(-1) + dfl_conf.amax(-1).amin(-1)) / 2
+
+         _, target_bboxes, target_scores, fg_mask, _ = self.assigner(
+             # pred_scores.detach().sigmoid() * 0.8 + dfl_conf.unsqueeze(-1) * 0.2,
+             pred_scores.detach().sigmoid(),
+             (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
+             anchor_points * stride_tensor,
+             gt_labels,
+             gt_bboxes,
+             mask_gt,
+         )
+
+         target_scores_sum = max(target_scores.sum(), 1)
+
+         # Cls loss
+         # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum  # VFL way
+         loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE
+
+         # Bbox loss
+         if fg_mask.sum():
+             target_bboxes /= stride_tensor
+             loss[0], loss[2] = self.bbox_loss(
+                 pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
+             )
+
+         loss[0] *= self.hyp.box  # box gain
+         loss[1] *= self.hyp.cls  # cls gain
+         loss[2] *= self.hyp.dfl  # dfl gain
+
+         return loss * batch_size, loss.detach()  # loss(box, cls, dfl)
+
+
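The preprocess method above pads a ragged list of per-image labels into a fixed-width (batch, max_targets, 5) tensor. A self-contained toy replica of just the padding logic (coordinate scaling and device handling omitted, values invented):

    import torch

    def pad_targets(targets, batch_size):
        """Toy replica of v8DetectionLoss.preprocess padding (no coordinate scaling)."""
        i = targets[:, 0]  # per-label image index
        _, counts = i.unique(return_counts=True)
        out = torch.zeros(batch_size, int(counts.max()), targets.shape[1] - 1)
        for j in range(batch_size):
            matches = i == j
            if n := int(matches.sum()):
                out[j, :n] = targets[matches, 1:]
        return out

    # Columns: (img_idx, cls, x, y, w, h); two labels in image 0, one in image 1.
    targets = torch.tensor([[0, 3, 0.5, 0.5, 0.2, 0.2],
                            [0, 1, 0.3, 0.3, 0.1, 0.1],
                            [1, 2, 0.7, 0.7, 0.4, 0.4]])
    print(pad_targets(targets, batch_size=2).shape)  # torch.Size([2, 2, 5])
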
+ class v8SegmentationLoss(v8DetectionLoss):
+     """Criterion class for computing training losses for YOLOv8 segmentation."""
+
+     def __init__(self, model):  # model must be de-paralleled
+         """Initialize the v8SegmentationLoss class with model parameters and mask overlap setting."""
+         super().__init__(model)
+         self.overlap = model.args.overlap_mask
+
+     def __call__(self, preds, batch):
+         """Calculate and return the combined loss for detection and segmentation."""
+         loss = torch.zeros(4, device=self.device)  # box, seg, cls, dfl
+         feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
+         batch_size, _, mask_h, mask_w = proto.shape  # batch size, number of masks, mask height, mask width
+         pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
+             (self.reg_max * 4, self.nc), 1
+         )
+
+         # B, grids, ..
+         pred_scores = pred_scores.permute(0, 2, 1).contiguous()
+         pred_distri = pred_distri.permute(0, 2, 1).contiguous()
+         pred_masks = pred_masks.permute(0, 2, 1).contiguous()
+
+         dtype = pred_scores.dtype
+         imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]  # image size (h,w)
+         anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
+
+         # Targets
+         try:
+             batch_idx = batch["batch_idx"].view(-1, 1)
+             targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)
+             targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
+             gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
+             mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
+         except RuntimeError as e:
+             raise TypeError(
+                 "ERROR ❌ segment dataset incorrectly formatted or not a segment dataset.\n"
+                 "This error can occur when incorrectly training a 'segment' model on a 'detect' dataset, "
+                 "i.e. 'yolo train model=yolo11n-seg.pt data=coco8.yaml'.\nVerify your dataset is a "
+                 "correctly formatted 'segment' dataset using 'data=coco8-seg.yaml' "
+                 "as an example.\nSee https://docs.ultralytics.com/datasets/segment/ for help."
+             ) from e
+
+         # Pboxes
+         pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)
+
+         _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
+             pred_scores.detach().sigmoid(),
+             (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
+             anchor_points * stride_tensor,
+             gt_labels,
+             gt_bboxes,
+             mask_gt,
+         )
+
+         target_scores_sum = max(target_scores.sum(), 1)
+
+         # Cls loss
+         # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum  # VFL way
+         loss[2] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE
+
+         if fg_mask.sum():
+             # Bbox loss
+             loss[0], loss[3] = self.bbox_loss(
+                 pred_distri,
+                 pred_bboxes,
+                 anchor_points,
+                 target_bboxes / stride_tensor,
+                 target_scores,
+                 target_scores_sum,
+                 fg_mask,
+             )
+             # Masks loss
+             masks = batch["masks"].to(self.device).float()
+             if tuple(masks.shape[-2:]) != (mask_h, mask_w):  # downsample
+                 masks = F.interpolate(masks[None], (mask_h, mask_w), mode="nearest")[0]
+
+             loss[1] = self.calculate_segmentation_loss(
+                 fg_mask, masks, target_gt_idx, target_bboxes, batch_idx, proto, pred_masks, imgsz, self.overlap
+             )
+
+         # WARNING: lines below prevent Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
+         else:
+             loss[1] += (proto * 0).sum() + (pred_masks * 0).sum()  # inf sums may lead to nan loss
+
+         loss[0] *= self.hyp.box  # box gain
+         loss[1] *= self.hyp.box  # seg gain
+         loss[2] *= self.hyp.cls  # cls gain
+         loss[3] *= self.hyp.dfl  # dfl gain
+
+         return loss * batch_size, loss.detach()  # loss(box, seg, cls, dfl)
+
+     @staticmethod
+     def single_mask_loss(
+         gt_mask: torch.Tensor, pred: torch.Tensor, proto: torch.Tensor, xyxy: torch.Tensor, area: torch.Tensor
+     ) -> torch.Tensor:
+         """
+         Compute the instance segmentation loss for a single image.
+
+         Args:
+             gt_mask (torch.Tensor): Ground truth mask of shape (n, H, W), where n is the number of objects.
+             pred (torch.Tensor): Predicted mask coefficients of shape (n, 32).
+             proto (torch.Tensor): Prototype masks of shape (32, H, W).
+             xyxy (torch.Tensor): Ground truth bounding boxes in xyxy format, normalized to [0, 1], of shape (n, 4).
+             area (torch.Tensor): Area of each ground truth bounding box of shape (n,).
+
+         Returns:
+             (torch.Tensor): The calculated mask loss for a single image.
+
+         Notes:
+             The function uses the equation pred_mask = torch.einsum('in,nhw->ihw', pred, proto) to produce the
+             predicted masks from the prototype masks and predicted mask coefficients.
+         """
+         pred_mask = torch.einsum("in,nhw->ihw", pred, proto)  # (n, 32) @ (32, 80, 80) -> (n, 80, 80)
+         loss = F.binary_cross_entropy_with_logits(pred_mask, gt_mask, reduction="none")
+         return (crop_mask(loss, xyxy).mean(dim=(1, 2)) / area).sum()
+
+     def calculate_segmentation_loss(
+         self,
+         fg_mask: torch.Tensor,
+         masks: torch.Tensor,
+         target_gt_idx: torch.Tensor,
+         target_bboxes: torch.Tensor,
+         batch_idx: torch.Tensor,
+         proto: torch.Tensor,
+         pred_masks: torch.Tensor,
+         imgsz: torch.Tensor,
+         overlap: bool,
+     ) -> torch.Tensor:
+         """
+         Calculate the loss for instance segmentation.
+
+         Args:
+             fg_mask (torch.Tensor): A binary tensor of shape (BS, N_anchors) indicating which anchors are positive.
+             masks (torch.Tensor): Ground truth masks of shape (BS, H, W) if `overlap` is False, otherwise (BS, ?, H, W).
+             target_gt_idx (torch.Tensor): Indexes of ground truth objects for each anchor of shape (BS, N_anchors).
+             target_bboxes (torch.Tensor): Ground truth bounding boxes for each anchor of shape (BS, N_anchors, 4).
+             batch_idx (torch.Tensor): Batch indices of shape (N_labels_in_batch, 1).
+             proto (torch.Tensor): Prototype masks of shape (BS, 32, H, W).
+             pred_masks (torch.Tensor): Predicted masks for each anchor of shape (BS, N_anchors, 32).
+             imgsz (torch.Tensor): Size of the input image as a tensor of shape (2), i.e., (H, W).
+             overlap (bool): Whether the masks in `masks` tensor overlap.
+
+         Returns:
+             (torch.Tensor): The calculated loss for instance segmentation.
+
+         Notes:
+             The batch loss can be computed for improved speed at higher memory usage.
+             For example, pred_mask can be computed as follows:
+                 pred_mask = torch.einsum('in,nhw->ihw', pred, proto)  # (i, 32) @ (32, 160, 160) -> (i, 160, 160)
+         """
+         _, _, mask_h, mask_w = proto.shape
+         loss = 0
+
+         # Normalize to 0-1
+         target_bboxes_normalized = target_bboxes / imgsz[[1, 0, 1, 0]]
+
+         # Areas of target bboxes
+         marea = xyxy2xywh(target_bboxes_normalized)[..., 2:].prod(2)
+
+         # Normalize to mask size
+         mxyxy = target_bboxes_normalized * torch.tensor([mask_w, mask_h, mask_w, mask_h], device=proto.device)
+
+         for i, single_i in enumerate(zip(fg_mask, target_gt_idx, pred_masks, proto, mxyxy, marea, masks)):
+             fg_mask_i, target_gt_idx_i, pred_masks_i, proto_i, mxyxy_i, marea_i, masks_i = single_i
+             if fg_mask_i.any():
+                 mask_idx = target_gt_idx_i[fg_mask_i]
+                 if overlap:
+                     gt_mask = masks_i == (mask_idx + 1).view(-1, 1, 1)
+                     gt_mask = gt_mask.float()
+                 else:
+                     gt_mask = masks[batch_idx.view(-1) == i][mask_idx]
+
+                 loss += self.single_mask_loss(
+                     gt_mask, pred_masks_i[fg_mask_i], proto_i, mxyxy_i[fg_mask_i], marea_i[fg_mask_i]
+                 )
+
+             # WARNING: lines below prevent Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
+             else:
+                 loss += (proto * 0).sum() + (pred_masks * 0).sum()  # inf sums may lead to nan loss
+
+         return loss / fg_mask.sum()
+
+
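A self-contained sketch of the prototype-mask assembly used by single_mask_loss above; the sizes are illustrative, and the box-cropping and area-normalization steps are omitted:

    import torch
    import torch.nn.functional as F

    n, c, h, w = 3, 32, 80, 80
    coeffs = torch.randn(n, c)    # per-instance mask coefficients
    proto = torch.randn(c, h, w)  # shared prototype masks
    pred_mask = torch.einsum("in,nhw->ihw", coeffs, proto)  # (3, 80, 80) instance mask logits

    gt_mask = torch.randint(0, 2, (n, h, w)).float()
    loss = F.binary_cross_entropy_with_logits(pred_mask, gt_mask, reduction="none")
    print(loss.shape)  # torch.Size([3, 80, 80]): per-pixel loss before cropping and normalization
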
+ class v8PoseLoss(v8DetectionLoss):
+     """Criterion class for computing training losses for YOLOv8 pose estimation."""
+
+     def __init__(self, model):  # model must be de-paralleled
+         """Initialize v8PoseLoss with model parameters and keypoint-specific loss functions."""
+         super().__init__(model)
+         self.kpt_shape = model.model[-1].kpt_shape
+         self.bce_pose = nn.BCEWithLogitsLoss()
+         is_pose = self.kpt_shape == [17, 3]
+         nkpt = self.kpt_shape[0]  # number of keypoints
+         sigmas = torch.from_numpy(OKS_SIGMA).to(self.device) if is_pose else torch.ones(nkpt, device=self.device) / nkpt
+         self.keypoint_loss = KeypointLoss(sigmas=sigmas)
+
+     def __call__(self, preds, batch):
+         """Calculate the total loss and detach it for pose estimation."""
+         loss = torch.zeros(5, device=self.device)  # box, kpt_location, kpt_visibility, cls, dfl
+         feats, pred_kpts = preds if isinstance(preds[0], list) else preds[1]
+         pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
+             (self.reg_max * 4, self.nc), 1
+         )
+
+         # B, grids, ..
+         pred_scores = pred_scores.permute(0, 2, 1).contiguous()
+         pred_distri = pred_distri.permute(0, 2, 1).contiguous()
+         pred_kpts = pred_kpts.permute(0, 2, 1).contiguous()
+
+         dtype = pred_scores.dtype
+         imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]  # image size (h,w)
+         anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
+
+         # Targets
+         batch_size = pred_scores.shape[0]
+         batch_idx = batch["batch_idx"].view(-1, 1)
+         targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)
+         targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
+         gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
+         mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
+
+         # Pboxes
+         pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)
+         pred_kpts = self.kpts_decode(anchor_points, pred_kpts.view(batch_size, -1, *self.kpt_shape))  # (b, h*w, 17, 3)
+
+         _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
+             pred_scores.detach().sigmoid(),
+             (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
+             anchor_points * stride_tensor,
+             gt_labels,
+             gt_bboxes,
+             mask_gt,
+         )
+
+         target_scores_sum = max(target_scores.sum(), 1)
+
+         # Cls loss
+         # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum  # VFL way
+         loss[3] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE
+
+         # Bbox loss
+         if fg_mask.sum():
+             target_bboxes /= stride_tensor
+             loss[0], loss[4] = self.bbox_loss(
+                 pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
+             )
+             keypoints = batch["keypoints"].to(self.device).float().clone()
+             keypoints[..., 0] *= imgsz[1]
+             keypoints[..., 1] *= imgsz[0]
+
+             loss[1], loss[2] = self.calculate_keypoints_loss(
+                 fg_mask, target_gt_idx, keypoints, batch_idx, stride_tensor, target_bboxes, pred_kpts
+             )
+
+         loss[0] *= self.hyp.box  # box gain
+         loss[1] *= self.hyp.pose  # pose gain
+         loss[2] *= self.hyp.kobj  # kobj gain
+         loss[3] *= self.hyp.cls  # cls gain
+         loss[4] *= self.hyp.dfl  # dfl gain
+
+         return loss * batch_size, loss.detach()  # loss(box, pose, kobj, cls, dfl)
+
+     @staticmethod
+     def kpts_decode(anchor_points, pred_kpts):
+         """Decode predicted keypoints to image coordinates."""
+         y = pred_kpts.clone()
+         y[..., :2] *= 2.0
+         y[..., 0] += anchor_points[:, [0]] - 0.5
+         y[..., 1] += anchor_points[:, [1]] - 0.5
+         return y
+
+     def calculate_keypoints_loss(
+         self, masks, target_gt_idx, keypoints, batch_idx, stride_tensor, target_bboxes, pred_kpts
+     ):
+         """
+         Calculate the keypoints loss for the model.
+
+         This function calculates the keypoints loss and keypoints object loss for a given batch. The keypoints loss is
+         based on the difference between the predicted keypoints and ground truth keypoints. The keypoints object loss is
+         a binary classification loss that classifies whether a keypoint is present or not.
+
+         Args:
+             masks (torch.Tensor): Binary mask tensor indicating object presence, shape (BS, N_anchors).
+             target_gt_idx (torch.Tensor): Index tensor mapping anchors to ground truth objects, shape (BS, N_anchors).
+             keypoints (torch.Tensor): Ground truth keypoints, shape (N_kpts_in_batch, N_kpts_per_object, kpts_dim).
+             batch_idx (torch.Tensor): Batch index tensor for keypoints, shape (N_kpts_in_batch, 1).
+             stride_tensor (torch.Tensor): Stride tensor for anchors, shape (N_anchors, 1).
+             target_bboxes (torch.Tensor): Ground truth boxes in (x1, y1, x2, y2) format, shape (BS, N_anchors, 4).
+             pred_kpts (torch.Tensor): Predicted keypoints, shape (BS, N_anchors, N_kpts_per_object, kpts_dim).
+
+         Returns:
+             kpts_loss (torch.Tensor): The keypoints loss.
+             kpts_obj_loss (torch.Tensor): The keypoints object loss.
+         """
+         batch_idx = batch_idx.flatten()
+         batch_size = len(masks)
+
+         # Find the maximum number of keypoints in a single image
+         max_kpts = torch.unique(batch_idx, return_counts=True)[1].max()
+
+         # Create a tensor to hold batched keypoints
+         batched_keypoints = torch.zeros(
+             (batch_size, max_kpts, keypoints.shape[1], keypoints.shape[2]), device=keypoints.device
+         )
+
+         # TODO: any idea how to vectorize this?
+         # Fill batched_keypoints with keypoints based on batch_idx
+         for i in range(batch_size):
+             keypoints_i = keypoints[batch_idx == i]
+             batched_keypoints[i, : keypoints_i.shape[0]] = keypoints_i
+
+         # Expand dimensions of target_gt_idx to match the shape of batched_keypoints
+         target_gt_idx_expanded = target_gt_idx.unsqueeze(-1).unsqueeze(-1)
+
+         # Use target_gt_idx_expanded to select keypoints from batched_keypoints
+         selected_keypoints = batched_keypoints.gather(
+             1, target_gt_idx_expanded.expand(-1, -1, keypoints.shape[1], keypoints.shape[2])
+         )
+
+         # Divide coordinates by stride
+         selected_keypoints[..., :2] /= stride_tensor.view(1, -1, 1, 1)
+
+         kpts_loss = 0
+         kpts_obj_loss = 0
+
+         if masks.any():
+             gt_kpt = selected_keypoints[masks]
+             area = xyxy2xywh(target_bboxes[masks])[:, 2:].prod(1, keepdim=True)
+             pred_kpt = pred_kpts[masks]
+             kpt_mask = gt_kpt[..., 2] != 0 if gt_kpt.shape[-1] == 3 else torch.full_like(gt_kpt[..., 0], True)
+             kpts_loss = self.keypoint_loss(pred_kpt, gt_kpt, kpt_mask, area)  # pose loss
+
+             if pred_kpt.shape[-1] == 3:
+                 kpts_obj_loss = self.bce_pose(pred_kpt[..., 2], kpt_mask.float())  # keypoint obj loss
+
+         return kpts_loss, kpts_obj_loss
+
+
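A quick standalone check of the kpts_decode arithmetic above: raw offsets are doubled and re-centered on each anchor point (anchor values invented):

    import torch

    def kpts_decode(anchor_points, pred_kpts):
        """Same decode as v8PoseLoss.kpts_decode."""
        y = pred_kpts.clone()
        y[..., :2] *= 2.0
        y[..., 0] += anchor_points[:, [0]] - 0.5
        y[..., 1] += anchor_points[:, [1]] - 0.5
        return y

    anchors = torch.tensor([[4.5, 4.5], [12.5, 4.5]])  # two grid-cell centers
    raw = torch.zeros(2, 2, 17, 3)                     # (batch, anchors, 17 kpts, x/y/vis)
    print(kpts_decode(anchors, raw)[0, 0, 0, :2])      # tensor([4., 4.]): zero offset lands at anchor - 0.5
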
+ class v8ClassificationLoss:
+     """Criterion class for computing training losses for classification."""
+
+     def __call__(self, preds, batch):
+         """Compute the classification loss between predictions and true labels."""
+         preds = preds[1] if isinstance(preds, (list, tuple)) else preds
+         loss = F.cross_entropy(preds, batch["cls"], reduction="mean")
+         loss_items = loss.detach()
+         return loss, loss_items
+
+
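A minimal usage sketch; the logits and labels are random stand-ins:

    import torch
    from ultralytics.utils.loss import v8ClassificationLoss

    criterion = v8ClassificationLoss()
    preds = torch.randn(4, 10)                   # 4 images, 10-class logits (toy sizes)
    batch = {"cls": torch.randint(0, 10, (4,))}  # integer class labels
    loss, loss_items = criterion(preds, batch)
    print(loss.item())
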
+ class v8OBBLoss(v8DetectionLoss):
+     """Calculates losses for object detection, classification, and box distribution in rotated YOLO models."""
+
+     def __init__(self, model):
+         """Initialize v8OBBLoss with model, assigner, and rotated bbox loss; model must be de-paralleled."""
+         super().__init__(model)
+         self.assigner = RotatedTaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0)
+         self.bbox_loss = RotatedBboxLoss(self.reg_max).to(self.device)
+
+     def preprocess(self, targets, batch_size, scale_tensor):
+         """Preprocess targets for oriented bounding box detection."""
+         if targets.shape[0] == 0:
+             out = torch.zeros(batch_size, 0, 6, device=self.device)
+         else:
+             i = targets[:, 0]  # image index
+             _, counts = i.unique(return_counts=True)
+             counts = counts.to(dtype=torch.int32)
+             out = torch.zeros(batch_size, counts.max(), 6, device=self.device)
+             for j in range(batch_size):
+                 matches = i == j
+                 if n := matches.sum():
+                     bboxes = targets[matches, 2:]
+                     bboxes[..., :4].mul_(scale_tensor)
+                     out[j, :n] = torch.cat([targets[matches, 1:2], bboxes], dim=-1)
+         return out
+
+     def __call__(self, preds, batch):
+         """Calculate and return the loss for oriented bounding box detection."""
+         loss = torch.zeros(3, device=self.device)  # box, cls, dfl
+         feats, pred_angle = preds if isinstance(preds[0], list) else preds[1]
+         batch_size = pred_angle.shape[0]  # batch size
+         pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
+             (self.reg_max * 4, self.nc), 1
+         )
+
+         # b, grids, ..
+         pred_scores = pred_scores.permute(0, 2, 1).contiguous()
+         pred_distri = pred_distri.permute(0, 2, 1).contiguous()
+         pred_angle = pred_angle.permute(0, 2, 1).contiguous()
+
+         dtype = pred_scores.dtype
+         imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]  # image size (h,w)
+         anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
+
+         # targets
+         try:
+             batch_idx = batch["batch_idx"].view(-1, 1)
+             targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"].view(-1, 5)), 1)
+             rw, rh = targets[:, 4] * imgsz[0].item(), targets[:, 5] * imgsz[1].item()
+             targets = targets[(rw >= 2) & (rh >= 2)]  # filter rboxes of tiny size to stabilize training
+             targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
+             gt_labels, gt_bboxes = targets.split((1, 5), 2)  # cls, xywhr
+             mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
+         except RuntimeError as e:
+             raise TypeError(
+                 "ERROR ❌ OBB dataset incorrectly formatted or not an OBB dataset.\n"
+                 "This error can occur when incorrectly training an 'OBB' model on a 'detect' dataset, "
+                 "i.e. 'yolo train model=yolo11n-obb.pt data=coco8.yaml'.\nVerify your dataset is a "
+                 "correctly formatted 'OBB' dataset using 'data=dota8.yaml' "
+                 "as an example.\nSee https://docs.ultralytics.com/datasets/obb/ for help."
+             ) from e
+
+         # Pboxes
+         pred_bboxes = self.bbox_decode(anchor_points, pred_distri, pred_angle)  # xywhr, (b, h*w, 5)
+
+         bboxes_for_assigner = pred_bboxes.clone().detach()
+         # Only the first four elements need to be scaled
+         bboxes_for_assigner[..., :4] *= stride_tensor
+         _, target_bboxes, target_scores, fg_mask, _ = self.assigner(
+             pred_scores.detach().sigmoid(),
+             bboxes_for_assigner.type(gt_bboxes.dtype),
+             anchor_points * stride_tensor,
+             gt_labels,
+             gt_bboxes,
+             mask_gt,
+         )
+
+         target_scores_sum = max(target_scores.sum(), 1)
+
+         # Cls loss
+         # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum  # VFL way
+         loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE
+
+         # Bbox loss
+         if fg_mask.sum():
+             target_bboxes[..., :4] /= stride_tensor
+             loss[0], loss[2] = self.bbox_loss(
+                 pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
+             )
+         else:
+             loss[0] += (pred_angle * 0).sum()
+
+         loss[0] *= self.hyp.box  # box gain
+         loss[1] *= self.hyp.cls  # cls gain
+         loss[2] *= self.hyp.dfl  # dfl gain
+
+         return loss * batch_size, loss.detach()  # loss(box, cls, dfl)
+
+     def bbox_decode(self, anchor_points, pred_dist, pred_angle):
+         """
+         Decode predicted object bounding box coordinates from anchor points and distribution.
+
+         Args:
+             anchor_points (torch.Tensor): Anchor points, (h*w, 2).
+             pred_dist (torch.Tensor): Predicted rotated distance, (bs, h*w, 4).
+             pred_angle (torch.Tensor): Predicted angle, (bs, h*w, 1).
+
+         Returns:
+             (torch.Tensor): Predicted rotated bounding boxes with angles, (bs, h*w, 5).
+         """
+         if self.use_dfl:
+             b, a, c = pred_dist.shape  # batch, anchors, channels
+             pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))
+         return torch.cat((dist2rbox(pred_dist, pred_angle, anchor_points), pred_angle), dim=-1)
+
+
+ class E2EDetectLoss:
+     """Criterion class for computing training losses for end-to-end detection."""
+
+     def __init__(self, model):
+         """Initialize E2EDetectLoss with one-to-many and one-to-one detection losses using the provided model."""
+         self.one2many = v8DetectionLoss(model, tal_topk=10)
+         self.one2one = v8DetectionLoss(model, tal_topk=1)
+
+     def __call__(self, preds, batch):
+         """Calculate the sum of the loss for box, cls and dfl multiplied by batch size."""
+         preds = preds[1] if isinstance(preds, tuple) else preds
+         one2many = preds["one2many"]
+         loss_one2many = self.one2many(one2many, batch)
+         one2one = preds["one2one"]
+         loss_one2one = self.one2one(one2one, batch)
+         return loss_one2many[0] + loss_one2one[0], loss_one2many[1] + loss_one2one[1]
+
+
+ class TVPDetectLoss:
+     """Criterion class for computing training losses for text-visual prompt detection."""
+
+     def __init__(self, model):
+         """Initialize TVPDetectLoss with task-prompt and visual-prompt criteria using the provided model."""
+         self.vp_criterion = v8DetectionLoss(model)
+         # NOTE: store following info as it's changeable in __call__
+         self.ori_nc = self.vp_criterion.nc
+         self.ori_no = self.vp_criterion.no
+         self.ori_reg_max = self.vp_criterion.reg_max
+
+     def __call__(self, preds, batch):
+         """Calculate the loss for text-visual prompt detection."""
+         feats = preds[1] if isinstance(preds, tuple) else preds
+         assert self.ori_reg_max == self.vp_criterion.reg_max  # TODO: remove it
+
+         if self.ori_reg_max * 4 + self.ori_nc == feats[0].shape[1]:
+             loss = torch.zeros(3, device=self.vp_criterion.device, requires_grad=True)
+             return loss, loss.detach()
+
+         vp_feats = self._get_vp_features(feats)
+         vp_loss = self.vp_criterion(vp_feats, batch)
+         box_loss = vp_loss[0][1]
+         return box_loss, vp_loss[1]
+
+     def _get_vp_features(self, feats):
+         """Extract visual-prompt features from the model output."""
+         vnc = feats[0].shape[1] - self.ori_reg_max * 4 - self.ori_nc
+
+         self.vp_criterion.nc = vnc
+         self.vp_criterion.no = vnc + self.vp_criterion.reg_max * 4
+         self.vp_criterion.assigner.num_classes = vnc
+
+         return [
+             torch.cat((box, cls_vp), dim=1)
+             for box, _, cls_vp in [xi.split((self.ori_reg_max * 4, self.ori_nc, vnc), dim=1) for xi in feats]
+         ]
+
+
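A toy view of the channel slicing performed by _get_vp_features above; the channel counts are assumptions for illustration:

    import torch

    reg_max, nc, vnc = 16, 80, 4  # box bins x 4 sides, text classes, visual-prompt classes
    feat = torch.randn(2, reg_max * 4 + nc + vnc, 20, 20)

    box, cls_text, cls_vp = feat.split((reg_max * 4, nc, vnc), dim=1)
    vp_feat = torch.cat((box, cls_vp), dim=1)  # keep box regression, swap in visual-prompt logits
    print(vp_feat.shape)  # torch.Size([2, 68, 20, 20]): 64 box + 4 vp channels
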
+ class TVPSegmentLoss(TVPDetectLoss):
+     """Criterion class for computing training losses for text-visual prompt segmentation."""
+
+     def __init__(self, model):
+         """Initialize TVPSegmentLoss with task-prompt and visual-prompt criteria using the provided model."""
+         super().__init__(model)
+         self.vp_criterion = v8SegmentationLoss(model)
+
+     def __call__(self, preds, batch):
+         """Calculate the loss for text-visual prompt segmentation."""
+         feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
+         assert self.ori_reg_max == self.vp_criterion.reg_max  # TODO: remove it
+
+         if self.ori_reg_max * 4 + self.ori_nc == feats[0].shape[1]:
+             loss = torch.zeros(4, device=self.vp_criterion.device, requires_grad=True)
+             return loss, loss.detach()
+
+         vp_feats = self._get_vp_features(feats)
+         vp_loss = self.vp_criterion((vp_feats, pred_masks, proto), batch)
+         cls_loss = vp_loss[0][2]
+         return cls_loss, vp_loss[1]