ultralytics-opencv-headless 8.3.242__py3-none-any.whl

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (298)
  1. tests/__init__.py +23 -0
  2. tests/conftest.py +59 -0
  3. tests/test_cli.py +131 -0
  4. tests/test_cuda.py +216 -0
  5. tests/test_engine.py +157 -0
  6. tests/test_exports.py +309 -0
  7. tests/test_integrations.py +151 -0
  8. tests/test_python.py +777 -0
  9. tests/test_solutions.py +371 -0
  10. ultralytics/__init__.py +48 -0
  11. ultralytics/assets/bus.jpg +0 -0
  12. ultralytics/assets/zidane.jpg +0 -0
  13. ultralytics/cfg/__init__.py +1026 -0
  14. ultralytics/cfg/datasets/Argoverse.yaml +78 -0
  15. ultralytics/cfg/datasets/DOTAv1.5.yaml +37 -0
  16. ultralytics/cfg/datasets/DOTAv1.yaml +36 -0
  17. ultralytics/cfg/datasets/GlobalWheat2020.yaml +68 -0
  18. ultralytics/cfg/datasets/HomeObjects-3K.yaml +32 -0
  19. ultralytics/cfg/datasets/ImageNet.yaml +2025 -0
  20. ultralytics/cfg/datasets/Objects365.yaml +447 -0
  21. ultralytics/cfg/datasets/SKU-110K.yaml +58 -0
  22. ultralytics/cfg/datasets/VOC.yaml +102 -0
  23. ultralytics/cfg/datasets/VisDrone.yaml +87 -0
  24. ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
  25. ultralytics/cfg/datasets/brain-tumor.yaml +22 -0
  26. ultralytics/cfg/datasets/carparts-seg.yaml +44 -0
  27. ultralytics/cfg/datasets/coco-pose.yaml +64 -0
  28. ultralytics/cfg/datasets/coco.yaml +118 -0
  29. ultralytics/cfg/datasets/coco128-seg.yaml +101 -0
  30. ultralytics/cfg/datasets/coco128.yaml +101 -0
  31. ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
  32. ultralytics/cfg/datasets/coco8-multispectral.yaml +104 -0
  33. ultralytics/cfg/datasets/coco8-pose.yaml +47 -0
  34. ultralytics/cfg/datasets/coco8-seg.yaml +101 -0
  35. ultralytics/cfg/datasets/coco8.yaml +101 -0
  36. ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
  37. ultralytics/cfg/datasets/crack-seg.yaml +22 -0
  38. ultralytics/cfg/datasets/dog-pose.yaml +52 -0
  39. ultralytics/cfg/datasets/dota8-multispectral.yaml +38 -0
  40. ultralytics/cfg/datasets/dota8.yaml +35 -0
  41. ultralytics/cfg/datasets/hand-keypoints.yaml +50 -0
  42. ultralytics/cfg/datasets/kitti.yaml +27 -0
  43. ultralytics/cfg/datasets/lvis.yaml +1240 -0
  44. ultralytics/cfg/datasets/medical-pills.yaml +21 -0
  45. ultralytics/cfg/datasets/open-images-v7.yaml +663 -0
  46. ultralytics/cfg/datasets/package-seg.yaml +22 -0
  47. ultralytics/cfg/datasets/signature.yaml +21 -0
  48. ultralytics/cfg/datasets/tiger-pose.yaml +41 -0
  49. ultralytics/cfg/datasets/xView.yaml +155 -0
  50. ultralytics/cfg/default.yaml +130 -0
  51. ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +17 -0
  52. ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
  53. ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
  54. ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
  55. ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
  56. ultralytics/cfg/models/11/yolo11.yaml +50 -0
  57. ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
  58. ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
  59. ultralytics/cfg/models/12/yolo12-cls.yaml +32 -0
  60. ultralytics/cfg/models/12/yolo12-obb.yaml +48 -0
  61. ultralytics/cfg/models/12/yolo12-pose.yaml +49 -0
  62. ultralytics/cfg/models/12/yolo12-seg.yaml +48 -0
  63. ultralytics/cfg/models/12/yolo12.yaml +48 -0
  64. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +53 -0
  65. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +45 -0
  66. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +45 -0
  67. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +57 -0
  68. ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
  69. ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
  70. ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
  71. ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
  72. ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
  73. ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
  74. ultralytics/cfg/models/v3/yolov3-spp.yaml +49 -0
  75. ultralytics/cfg/models/v3/yolov3-tiny.yaml +40 -0
  76. ultralytics/cfg/models/v3/yolov3.yaml +49 -0
  77. ultralytics/cfg/models/v5/yolov5-p6.yaml +62 -0
  78. ultralytics/cfg/models/v5/yolov5.yaml +51 -0
  79. ultralytics/cfg/models/v6/yolov6.yaml +56 -0
  80. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +48 -0
  81. ultralytics/cfg/models/v8/yoloe-v8.yaml +48 -0
  82. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +28 -0
  83. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +28 -0
  84. ultralytics/cfg/models/v8/yolov8-cls.yaml +32 -0
  85. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +58 -0
  86. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +60 -0
  87. ultralytics/cfg/models/v8/yolov8-ghost.yaml +50 -0
  88. ultralytics/cfg/models/v8/yolov8-obb.yaml +49 -0
  89. ultralytics/cfg/models/v8/yolov8-p2.yaml +57 -0
  90. ultralytics/cfg/models/v8/yolov8-p6.yaml +59 -0
  91. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +60 -0
  92. ultralytics/cfg/models/v8/yolov8-pose.yaml +50 -0
  93. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +49 -0
  94. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +59 -0
  95. ultralytics/cfg/models/v8/yolov8-seg.yaml +49 -0
  96. ultralytics/cfg/models/v8/yolov8-world.yaml +51 -0
  97. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +49 -0
  98. ultralytics/cfg/models/v8/yolov8.yaml +49 -0
  99. ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
  100. ultralytics/cfg/models/v9/yolov9c.yaml +41 -0
  101. ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
  102. ultralytics/cfg/models/v9/yolov9e.yaml +64 -0
  103. ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
  104. ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
  105. ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
  106. ultralytics/cfg/trackers/botsort.yaml +21 -0
  107. ultralytics/cfg/trackers/bytetrack.yaml +12 -0
  108. ultralytics/data/__init__.py +26 -0
  109. ultralytics/data/annotator.py +66 -0
  110. ultralytics/data/augment.py +2801 -0
  111. ultralytics/data/base.py +435 -0
  112. ultralytics/data/build.py +437 -0
  113. ultralytics/data/converter.py +855 -0
  114. ultralytics/data/dataset.py +834 -0
  115. ultralytics/data/loaders.py +704 -0
  116. ultralytics/data/scripts/download_weights.sh +18 -0
  117. ultralytics/data/scripts/get_coco.sh +61 -0
  118. ultralytics/data/scripts/get_coco128.sh +18 -0
  119. ultralytics/data/scripts/get_imagenet.sh +52 -0
  120. ultralytics/data/split.py +138 -0
  121. ultralytics/data/split_dota.py +344 -0
  122. ultralytics/data/utils.py +798 -0
  123. ultralytics/engine/__init__.py +1 -0
  124. ultralytics/engine/exporter.py +1574 -0
  125. ultralytics/engine/model.py +1124 -0
  126. ultralytics/engine/predictor.py +508 -0
  127. ultralytics/engine/results.py +1522 -0
  128. ultralytics/engine/trainer.py +974 -0
  129. ultralytics/engine/tuner.py +448 -0
  130. ultralytics/engine/validator.py +384 -0
  131. ultralytics/hub/__init__.py +166 -0
  132. ultralytics/hub/auth.py +151 -0
  133. ultralytics/hub/google/__init__.py +174 -0
  134. ultralytics/hub/session.py +422 -0
  135. ultralytics/hub/utils.py +162 -0
  136. ultralytics/models/__init__.py +9 -0
  137. ultralytics/models/fastsam/__init__.py +7 -0
  138. ultralytics/models/fastsam/model.py +79 -0
  139. ultralytics/models/fastsam/predict.py +169 -0
  140. ultralytics/models/fastsam/utils.py +23 -0
  141. ultralytics/models/fastsam/val.py +38 -0
  142. ultralytics/models/nas/__init__.py +7 -0
  143. ultralytics/models/nas/model.py +98 -0
  144. ultralytics/models/nas/predict.py +56 -0
  145. ultralytics/models/nas/val.py +38 -0
  146. ultralytics/models/rtdetr/__init__.py +7 -0
  147. ultralytics/models/rtdetr/model.py +63 -0
  148. ultralytics/models/rtdetr/predict.py +88 -0
  149. ultralytics/models/rtdetr/train.py +89 -0
  150. ultralytics/models/rtdetr/val.py +216 -0
  151. ultralytics/models/sam/__init__.py +25 -0
  152. ultralytics/models/sam/amg.py +275 -0
  153. ultralytics/models/sam/build.py +365 -0
  154. ultralytics/models/sam/build_sam3.py +377 -0
  155. ultralytics/models/sam/model.py +169 -0
  156. ultralytics/models/sam/modules/__init__.py +1 -0
  157. ultralytics/models/sam/modules/blocks.py +1067 -0
  158. ultralytics/models/sam/modules/decoders.py +495 -0
  159. ultralytics/models/sam/modules/encoders.py +794 -0
  160. ultralytics/models/sam/modules/memory_attention.py +298 -0
  161. ultralytics/models/sam/modules/sam.py +1160 -0
  162. ultralytics/models/sam/modules/tiny_encoder.py +979 -0
  163. ultralytics/models/sam/modules/transformer.py +344 -0
  164. ultralytics/models/sam/modules/utils.py +512 -0
  165. ultralytics/models/sam/predict.py +3940 -0
  166. ultralytics/models/sam/sam3/__init__.py +3 -0
  167. ultralytics/models/sam/sam3/decoder.py +546 -0
  168. ultralytics/models/sam/sam3/encoder.py +529 -0
  169. ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
  170. ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
  171. ultralytics/models/sam/sam3/model_misc.py +199 -0
  172. ultralytics/models/sam/sam3/necks.py +129 -0
  173. ultralytics/models/sam/sam3/sam3_image.py +339 -0
  174. ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
  175. ultralytics/models/sam/sam3/vitdet.py +547 -0
  176. ultralytics/models/sam/sam3/vl_combiner.py +160 -0
  177. ultralytics/models/utils/__init__.py +1 -0
  178. ultralytics/models/utils/loss.py +466 -0
  179. ultralytics/models/utils/ops.py +315 -0
  180. ultralytics/models/yolo/__init__.py +7 -0
  181. ultralytics/models/yolo/classify/__init__.py +7 -0
  182. ultralytics/models/yolo/classify/predict.py +90 -0
  183. ultralytics/models/yolo/classify/train.py +202 -0
  184. ultralytics/models/yolo/classify/val.py +216 -0
  185. ultralytics/models/yolo/detect/__init__.py +7 -0
  186. ultralytics/models/yolo/detect/predict.py +122 -0
  187. ultralytics/models/yolo/detect/train.py +227 -0
  188. ultralytics/models/yolo/detect/val.py +507 -0
  189. ultralytics/models/yolo/model.py +430 -0
  190. ultralytics/models/yolo/obb/__init__.py +7 -0
  191. ultralytics/models/yolo/obb/predict.py +56 -0
  192. ultralytics/models/yolo/obb/train.py +79 -0
  193. ultralytics/models/yolo/obb/val.py +302 -0
  194. ultralytics/models/yolo/pose/__init__.py +7 -0
  195. ultralytics/models/yolo/pose/predict.py +65 -0
  196. ultralytics/models/yolo/pose/train.py +110 -0
  197. ultralytics/models/yolo/pose/val.py +248 -0
  198. ultralytics/models/yolo/segment/__init__.py +7 -0
  199. ultralytics/models/yolo/segment/predict.py +109 -0
  200. ultralytics/models/yolo/segment/train.py +69 -0
  201. ultralytics/models/yolo/segment/val.py +307 -0
  202. ultralytics/models/yolo/world/__init__.py +5 -0
  203. ultralytics/models/yolo/world/train.py +173 -0
  204. ultralytics/models/yolo/world/train_world.py +178 -0
  205. ultralytics/models/yolo/yoloe/__init__.py +22 -0
  206. ultralytics/models/yolo/yoloe/predict.py +162 -0
  207. ultralytics/models/yolo/yoloe/train.py +287 -0
  208. ultralytics/models/yolo/yoloe/train_seg.py +122 -0
  209. ultralytics/models/yolo/yoloe/val.py +206 -0
  210. ultralytics/nn/__init__.py +27 -0
  211. ultralytics/nn/autobackend.py +958 -0
  212. ultralytics/nn/modules/__init__.py +182 -0
  213. ultralytics/nn/modules/activation.py +54 -0
  214. ultralytics/nn/modules/block.py +1947 -0
  215. ultralytics/nn/modules/conv.py +669 -0
  216. ultralytics/nn/modules/head.py +1183 -0
  217. ultralytics/nn/modules/transformer.py +793 -0
  218. ultralytics/nn/modules/utils.py +159 -0
  219. ultralytics/nn/tasks.py +1768 -0
  220. ultralytics/nn/text_model.py +356 -0
  221. ultralytics/py.typed +1 -0
  222. ultralytics/solutions/__init__.py +41 -0
  223. ultralytics/solutions/ai_gym.py +108 -0
  224. ultralytics/solutions/analytics.py +264 -0
  225. ultralytics/solutions/config.py +107 -0
  226. ultralytics/solutions/distance_calculation.py +123 -0
  227. ultralytics/solutions/heatmap.py +125 -0
  228. ultralytics/solutions/instance_segmentation.py +86 -0
  229. ultralytics/solutions/object_blurrer.py +89 -0
  230. ultralytics/solutions/object_counter.py +190 -0
  231. ultralytics/solutions/object_cropper.py +87 -0
  232. ultralytics/solutions/parking_management.py +280 -0
  233. ultralytics/solutions/queue_management.py +93 -0
  234. ultralytics/solutions/region_counter.py +133 -0
  235. ultralytics/solutions/security_alarm.py +151 -0
  236. ultralytics/solutions/similarity_search.py +219 -0
  237. ultralytics/solutions/solutions.py +828 -0
  238. ultralytics/solutions/speed_estimation.py +114 -0
  239. ultralytics/solutions/streamlit_inference.py +260 -0
  240. ultralytics/solutions/templates/similarity-search.html +156 -0
  241. ultralytics/solutions/trackzone.py +88 -0
  242. ultralytics/solutions/vision_eye.py +67 -0
  243. ultralytics/trackers/__init__.py +7 -0
  244. ultralytics/trackers/basetrack.py +115 -0
  245. ultralytics/trackers/bot_sort.py +257 -0
  246. ultralytics/trackers/byte_tracker.py +469 -0
  247. ultralytics/trackers/track.py +116 -0
  248. ultralytics/trackers/utils/__init__.py +1 -0
  249. ultralytics/trackers/utils/gmc.py +339 -0
  250. ultralytics/trackers/utils/kalman_filter.py +482 -0
  251. ultralytics/trackers/utils/matching.py +154 -0
  252. ultralytics/utils/__init__.py +1450 -0
  253. ultralytics/utils/autobatch.py +118 -0
  254. ultralytics/utils/autodevice.py +205 -0
  255. ultralytics/utils/benchmarks.py +728 -0
  256. ultralytics/utils/callbacks/__init__.py +5 -0
  257. ultralytics/utils/callbacks/base.py +233 -0
  258. ultralytics/utils/callbacks/clearml.py +146 -0
  259. ultralytics/utils/callbacks/comet.py +625 -0
  260. ultralytics/utils/callbacks/dvc.py +197 -0
  261. ultralytics/utils/callbacks/hub.py +110 -0
  262. ultralytics/utils/callbacks/mlflow.py +134 -0
  263. ultralytics/utils/callbacks/neptune.py +126 -0
  264. ultralytics/utils/callbacks/platform.py +73 -0
  265. ultralytics/utils/callbacks/raytune.py +42 -0
  266. ultralytics/utils/callbacks/tensorboard.py +123 -0
  267. ultralytics/utils/callbacks/wb.py +188 -0
  268. ultralytics/utils/checks.py +998 -0
  269. ultralytics/utils/cpu.py +85 -0
  270. ultralytics/utils/dist.py +123 -0
  271. ultralytics/utils/downloads.py +529 -0
  272. ultralytics/utils/errors.py +35 -0
  273. ultralytics/utils/events.py +113 -0
  274. ultralytics/utils/export/__init__.py +7 -0
  275. ultralytics/utils/export/engine.py +237 -0
  276. ultralytics/utils/export/imx.py +315 -0
  277. ultralytics/utils/export/tensorflow.py +231 -0
  278. ultralytics/utils/files.py +219 -0
  279. ultralytics/utils/git.py +137 -0
  280. ultralytics/utils/instance.py +484 -0
  281. ultralytics/utils/logger.py +444 -0
  282. ultralytics/utils/loss.py +849 -0
  283. ultralytics/utils/metrics.py +1560 -0
  284. ultralytics/utils/nms.py +337 -0
  285. ultralytics/utils/ops.py +664 -0
  286. ultralytics/utils/patches.py +201 -0
  287. ultralytics/utils/plotting.py +1045 -0
  288. ultralytics/utils/tal.py +403 -0
  289. ultralytics/utils/torch_utils.py +984 -0
  290. ultralytics/utils/tqdm.py +440 -0
  291. ultralytics/utils/triton.py +112 -0
  292. ultralytics/utils/tuner.py +160 -0
  293. ultralytics_opencv_headless-8.3.242.dist-info/METADATA +374 -0
  294. ultralytics_opencv_headless-8.3.242.dist-info/RECORD +298 -0
  295. ultralytics_opencv_headless-8.3.242.dist-info/WHEEL +5 -0
  296. ultralytics_opencv_headless-8.3.242.dist-info/entry_points.txt +3 -0
  297. ultralytics_opencv_headless-8.3.242.dist-info/licenses/LICENSE +661 -0
  298. ultralytics_opencv_headless-8.3.242.dist-info/top_level.txt +1 -0
ultralytics/utils/loss.py (new file)
@@ -0,0 +1,849 @@
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+ from __future__ import annotations
+
+ from typing import Any
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from ultralytics.utils.metrics import OKS_SIGMA
+ from ultralytics.utils.ops import crop_mask, xywh2xyxy, xyxy2xywh
+ from ultralytics.utils.tal import RotatedTaskAlignedAssigner, TaskAlignedAssigner, dist2bbox, dist2rbox, make_anchors
+ from ultralytics.utils.torch_utils import autocast
+
+ from .metrics import bbox_iou, probiou
+ from .tal import bbox2dist
+
+
+ class VarifocalLoss(nn.Module):
+     """Varifocal loss by Zhang et al.
+
+     Implements the Varifocal Loss function for addressing class imbalance in object detection by focusing on
+     hard-to-classify examples and balancing positive/negative samples.
+
+     Attributes:
+         gamma (float): The focusing parameter that controls how much the loss focuses on hard-to-classify examples.
+         alpha (float): The balancing factor used to address class imbalance.
+
+     References:
+         https://arxiv.org/abs/2008.13367
+     """
+
+     def __init__(self, gamma: float = 2.0, alpha: float = 0.75):
+         """Initialize the VarifocalLoss class with focusing and balancing parameters."""
+         super().__init__()
+         self.gamma = gamma
+         self.alpha = alpha
+
+     def forward(self, pred_score: torch.Tensor, gt_score: torch.Tensor, label: torch.Tensor) -> torch.Tensor:
+         """Compute varifocal loss between predictions and ground truth."""
+         weight = self.alpha * pred_score.sigmoid().pow(self.gamma) * (1 - label) + gt_score * label
+         with autocast(enabled=False):
+             loss = (
+                 (F.binary_cross_entropy_with_logits(pred_score.float(), gt_score.float(), reduction="none") * weight)
+                 .mean(1)
+                 .sum()
+             )
+         return loss
+
+
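# --- Editor's note: illustrative sketch, not part of the package diff. ---
# A minimal, hypothetical usage of the VarifocalLoss defined above; the shapes
# (2 images, 8400 anchors, 80 classes) and the 0.7 soft target are made up.
# Positives are weighted by their IoU-aware gt_score, negatives by
# alpha * p**gamma, which down-weights confident easy negatives.
import torch

vfl = VarifocalLoss(gamma=2.0, alpha=0.75)
pred_score = torch.randn(2, 8400, 80)        # raw class logits
label = torch.zeros(2, 8400, 80)             # one-hot positive anchors
label[0, 0, 3] = 1.0
gt_score = label * 0.7                       # e.g. IoU of the matched box as a soft target
loss = vfl(pred_score, gt_score, label)      # scalar tensor
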
+ class FocalLoss(nn.Module):
+     """Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5).
+
+     Implements the Focal Loss function for addressing class imbalance by down-weighting easy examples and focusing on
+     hard negatives during training.
+
+     Attributes:
+         gamma (float): The focusing parameter that controls how much the loss focuses on hard-to-classify examples.
+         alpha (torch.Tensor): The balancing factor used to address class imbalance.
+     """
+
+     def __init__(self, gamma: float = 1.5, alpha: float = 0.25):
+         """Initialize FocalLoss class with focusing and balancing parameters."""
+         super().__init__()
+         self.gamma = gamma
+         self.alpha = torch.tensor(alpha)
+
+     def forward(self, pred: torch.Tensor, label: torch.Tensor) -> torch.Tensor:
+         """Calculate focal loss with modulating factors for class imbalance."""
+         loss = F.binary_cross_entropy_with_logits(pred, label, reduction="none")
+         # p_t = torch.exp(-loss)
+         # loss *= self.alpha * (1.000001 - p_t) ** self.gamma  # non-zero power for gradient stability
+
+         # TF implementation https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py
+         pred_prob = pred.sigmoid()  # prob from logits
+         p_t = label * pred_prob + (1 - label) * (1 - pred_prob)
+         modulating_factor = (1.0 - p_t) ** self.gamma
+         loss *= modulating_factor
+         if (self.alpha > 0).any():
+             self.alpha = self.alpha.to(device=pred.device, dtype=pred.dtype)
+             alpha_factor = label * self.alpha + (1 - label) * (1 - self.alpha)
+             loss *= alpha_factor
+         return loss.mean(1).sum()
+
+
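# --- Editor's note: illustrative sketch, not part of the package diff. ---
# How the focal modulating factor behaves for gamma=1.5: easy examples
# (p_t close to 1) are strongly down-weighted, hard ones keep most of their weight.
import torch

gamma = 1.5
for p_t in (0.9, 0.5, 0.1):
    print(p_t, ((1 - torch.tensor(p_t)) ** gamma).item())  # ~0.032, ~0.354, ~0.854
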
+ class DFLoss(nn.Module):
+     """Criterion class for computing Distribution Focal Loss (DFL)."""
+
+     def __init__(self, reg_max: int = 16) -> None:
+         """Initialize the DFL module with regularization maximum."""
+         super().__init__()
+         self.reg_max = reg_max
+
+     def __call__(self, pred_dist: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+         """Return sum of left and right DFL losses from https://ieeexplore.ieee.org/document/9792391."""
+         target = target.clamp_(0, self.reg_max - 1 - 0.01)
+         tl = target.long()  # target left
+         tr = tl + 1  # target right
+         wl = tr - target  # weight left
+         wr = 1 - wl  # weight right
+         return (
+             F.cross_entropy(pred_dist, tl.view(-1), reduction="none").view(tl.shape) * wl
+             + F.cross_entropy(pred_dist, tr.view(-1), reduction="none").view(tl.shape) * wr
+         ).mean(-1, keepdim=True)
+
+
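# --- Editor's note: illustrative sketch, not part of the package diff. ---
# DFL treats a continuous box-distance target as a soft mix of its two nearest
# integer bins; a worked example with a hypothetical target of 5.3:
import torch

target = torch.tensor(5.3)
tl, tr = target.long(), target.long() + 1    # left bin 5, right bin 6
wl = tr - target                             # weight on the left bin, ~0.7
wr = 1 - wl                                  # weight on the right bin, ~0.3
print(tl.item(), tr.item(), wl.item(), wr.item())
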
+ class BboxLoss(nn.Module):
+     """Criterion class for computing training losses for bounding boxes."""
+
+     def __init__(self, reg_max: int = 16):
+         """Initialize the BboxLoss module with regularization maximum and DFL settings."""
+         super().__init__()
+         self.dfl_loss = DFLoss(reg_max) if reg_max > 1 else None
+
+     def forward(
+         self,
+         pred_dist: torch.Tensor,
+         pred_bboxes: torch.Tensor,
+         anchor_points: torch.Tensor,
+         target_bboxes: torch.Tensor,
+         target_scores: torch.Tensor,
+         target_scores_sum: torch.Tensor,
+         fg_mask: torch.Tensor,
+     ) -> tuple[torch.Tensor, torch.Tensor]:
+         """Compute IoU and DFL losses for bounding boxes."""
+         weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
+         iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, CIoU=True)
+         loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum
+
+         # DFL loss
+         if self.dfl_loss:
+             target_ltrb = bbox2dist(anchor_points, target_bboxes, self.dfl_loss.reg_max - 1)
+             loss_dfl = self.dfl_loss(pred_dist[fg_mask].view(-1, self.dfl_loss.reg_max), target_ltrb[fg_mask]) * weight
+             loss_dfl = loss_dfl.sum() / target_scores_sum
+         else:
+             loss_dfl = torch.tensor(0.0).to(pred_dist.device)
+
+         return loss_iou, loss_dfl
+
+
+ class RotatedBboxLoss(BboxLoss):
+     """Criterion class for computing training losses for rotated bounding boxes."""
+
+     def __init__(self, reg_max: int):
+         """Initialize the RotatedBboxLoss module with regularization maximum and DFL settings."""
+         super().__init__(reg_max)
+
+     def forward(
+         self,
+         pred_dist: torch.Tensor,
+         pred_bboxes: torch.Tensor,
+         anchor_points: torch.Tensor,
+         target_bboxes: torch.Tensor,
+         target_scores: torch.Tensor,
+         target_scores_sum: torch.Tensor,
+         fg_mask: torch.Tensor,
+     ) -> tuple[torch.Tensor, torch.Tensor]:
+         """Compute IoU and DFL losses for rotated bounding boxes."""
+         weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
+         iou = probiou(pred_bboxes[fg_mask], target_bboxes[fg_mask])
+         loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum
+
+         # DFL loss
+         if self.dfl_loss:
+             target_ltrb = bbox2dist(anchor_points, xywh2xyxy(target_bboxes[..., :4]), self.dfl_loss.reg_max - 1)
+             loss_dfl = self.dfl_loss(pred_dist[fg_mask].view(-1, self.dfl_loss.reg_max), target_ltrb[fg_mask]) * weight
+             loss_dfl = loss_dfl.sum() / target_scores_sum
+         else:
+             loss_dfl = torch.tensor(0.0).to(pred_dist.device)
+
+         return loss_iou, loss_dfl
+
+
+ class KeypointLoss(nn.Module):
+     """Criterion class for computing keypoint losses."""
+
+     def __init__(self, sigmas: torch.Tensor) -> None:
+         """Initialize the KeypointLoss class with keypoint sigmas."""
+         super().__init__()
+         self.sigmas = sigmas
+
+     def forward(
+         self, pred_kpts: torch.Tensor, gt_kpts: torch.Tensor, kpt_mask: torch.Tensor, area: torch.Tensor
+     ) -> torch.Tensor:
+         """Calculate keypoint loss factor and Euclidean distance loss for keypoints."""
+         d = (pred_kpts[..., 0] - gt_kpts[..., 0]).pow(2) + (pred_kpts[..., 1] - gt_kpts[..., 1]).pow(2)
+         kpt_loss_factor = kpt_mask.shape[1] / (torch.sum(kpt_mask != 0, dim=1) + 1e-9)
+         # e = d / (2 * (area * self.sigmas) ** 2 + 1e-9)  # from formula
+         e = d / ((2 * self.sigmas).pow(2) * (area + 1e-9) * 2)  # from cocoeval
+         return (kpt_loss_factor.view(-1, 1) * ((1 - torch.exp(-e)) * kpt_mask)).mean()
+
+
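# --- Editor's note: illustrative sketch, not part of the package diff. ---
# A hypothetical sanity check of the OKS-style term above: identical predicted
# and ground-truth keypoints give a loss of ~0. Shapes, sigmas, and areas are invented.
import torch

kpt_loss = KeypointLoss(sigmas=torch.ones(17) / 17)  # uniform sigmas, illustration only
gt = torch.rand(4, 17, 3) * 64               # 4 objects, 17 keypoints, (x, y, visibility)
pred = gt[..., :2].clone()                   # perfect predictions
kpt_mask = torch.ones(4, 17)                 # all keypoints labeled
area = torch.full((4, 1), 64.0 * 64.0)       # box areas in pixels^2
print(kpt_loss(pred, gt, kpt_mask, area))    # tensor(0.)
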
+ class v8DetectionLoss:
+     """Criterion class for computing training losses for YOLOv8 object detection."""
+
+     def __init__(self, model, tal_topk: int = 10):  # model must be de-paralleled
+         """Initialize v8DetectionLoss with model parameters and task-aligned assignment settings."""
+         device = next(model.parameters()).device  # get model device
+         h = model.args  # hyperparameters
+
+         m = model.model[-1]  # Detect() module
+         self.bce = nn.BCEWithLogitsLoss(reduction="none")
+         self.hyp = h
+         self.stride = m.stride  # model strides
+         self.nc = m.nc  # number of classes
+         self.no = m.nc + m.reg_max * 4
+         self.reg_max = m.reg_max
+         self.device = device
+
+         self.use_dfl = m.reg_max > 1
+
+         self.assigner = TaskAlignedAssigner(topk=tal_topk, num_classes=self.nc, alpha=0.5, beta=6.0)
+         self.bbox_loss = BboxLoss(m.reg_max).to(device)
+         self.proj = torch.arange(m.reg_max, dtype=torch.float, device=device)
+
+     def preprocess(self, targets: torch.Tensor, batch_size: int, scale_tensor: torch.Tensor) -> torch.Tensor:
+         """Preprocess targets by converting to tensor format and scaling coordinates."""
+         nl, ne = targets.shape
+         if nl == 0:
+             out = torch.zeros(batch_size, 0, ne - 1, device=self.device)
+         else:
+             i = targets[:, 0]  # image index
+             _, counts = i.unique(return_counts=True)
+             counts = counts.to(dtype=torch.int32)
+             out = torch.zeros(batch_size, counts.max(), ne - 1, device=self.device)
+             for j in range(batch_size):
+                 matches = i == j
+                 if n := matches.sum():
+                     out[j, :n] = targets[matches, 1:]
+             out[..., 1:5] = xywh2xyxy(out[..., 1:5].mul_(scale_tensor))
+         return out
+
+     def bbox_decode(self, anchor_points: torch.Tensor, pred_dist: torch.Tensor) -> torch.Tensor:
+         """Decode predicted object bounding box coordinates from anchor points and distribution."""
+         if self.use_dfl:
+             b, a, c = pred_dist.shape  # batch, anchors, channels
+             pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))
+             # pred_dist = pred_dist.view(b, a, c // 4, 4).transpose(2,3).softmax(3).matmul(self.proj.type(pred_dist.dtype))
+             # pred_dist = (pred_dist.view(b, a, c // 4, 4).softmax(2) * self.proj.type(pred_dist.dtype).view(1, 1, -1, 1)).sum(2)
+         return dist2bbox(pred_dist, anchor_points, xywh=False)
+
+     def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+         """Calculate the sum of the loss for box, cls and dfl multiplied by batch size."""
+         loss = torch.zeros(3, device=self.device)  # box, cls, dfl
+         feats = preds[1] if isinstance(preds, tuple) else preds
+         pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
+             (self.reg_max * 4, self.nc), 1
+         )
+
+         pred_scores = pred_scores.permute(0, 2, 1).contiguous()
+         pred_distri = pred_distri.permute(0, 2, 1).contiguous()
+
+         dtype = pred_scores.dtype
+         batch_size = pred_scores.shape[0]
+         imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]  # image size (h,w)
+         anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
+
+         # Targets
+         targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
+         targets = self.preprocess(targets, batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
+         gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
+         mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
+
+         # Pboxes
+         pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)
+         # dfl_conf = pred_distri.view(batch_size, -1, 4, self.reg_max).detach().softmax(-1)
+         # dfl_conf = (dfl_conf.amax(-1).mean(-1) + dfl_conf.amax(-1).amin(-1)) / 2
+
+         _, target_bboxes, target_scores, fg_mask, _ = self.assigner(
+             # pred_scores.detach().sigmoid() * 0.8 + dfl_conf.unsqueeze(-1) * 0.2,
+             pred_scores.detach().sigmoid(),
+             (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
+             anchor_points * stride_tensor,
+             gt_labels,
+             gt_bboxes,
+             mask_gt,
+         )
+
+         target_scores_sum = max(target_scores.sum(), 1)
+
+         # Cls loss
+         # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum  # VFL way
+         loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE
+
+         # Bbox loss
+         if fg_mask.sum():
+             loss[0], loss[2] = self.bbox_loss(
+                 pred_distri,
+                 pred_bboxes,
+                 anchor_points,
+                 target_bboxes / stride_tensor,
+                 target_scores,
+                 target_scores_sum,
+                 fg_mask,
+             )
+
+         loss[0] *= self.hyp.box  # box gain
+         loss[1] *= self.hyp.cls  # cls gain
+         loss[2] *= self.hyp.dfl  # dfl gain
+
+         return loss * batch_size, loss.detach()  # loss(box, cls, dfl)
+
+
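# --- Editor's note: illustrative sketch, not part of the package diff. ---
# What bbox_decode does when DFL is enabled: each of the 4 box sides is a
# categorical distribution over reg_max bins, and the decoded distance is its
# expectation (softmax dotted with arange(reg_max)); dist2bbox then turns the
# four distances into xyxy boxes around the anchor points. Shapes are made up.
import torch

reg_max = 16
proj = torch.arange(reg_max, dtype=torch.float)
pred_dist = torch.randn(2, 8400, 4 * reg_max)                        # (batch, anchors, 4 * reg_max)
dist = pred_dist.view(2, 8400, 4, reg_max).softmax(3).matmul(proj)   # (2, 8400, 4), values in [0, reg_max - 1]
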
+ class v8SegmentationLoss(v8DetectionLoss):
+     """Criterion class for computing training losses for YOLOv8 segmentation."""
+
+     def __init__(self, model):  # model must be de-paralleled
+         """Initialize the v8SegmentationLoss class with model parameters and mask overlap setting."""
+         super().__init__(model)
+         self.overlap = model.args.overlap_mask
+
+     def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+         """Calculate and return the combined loss for detection and segmentation."""
+         loss = torch.zeros(4, device=self.device)  # box, seg, cls, dfl
+         feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
+         batch_size, _, mask_h, mask_w = proto.shape  # batch size, number of masks, mask height, mask width
+         pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
+             (self.reg_max * 4, self.nc), 1
+         )
+
+         # B, grids, ..
+         pred_scores = pred_scores.permute(0, 2, 1).contiguous()
+         pred_distri = pred_distri.permute(0, 2, 1).contiguous()
+         pred_masks = pred_masks.permute(0, 2, 1).contiguous()
+
+         dtype = pred_scores.dtype
+         imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]  # image size (h,w)
+         anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
+
+         # Targets
+         try:
+             batch_idx = batch["batch_idx"].view(-1, 1)
+             targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)
+             targets = self.preprocess(targets, batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
+             gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
+             mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
+         except RuntimeError as e:
+             raise TypeError(
+                 "ERROR ❌ segment dataset incorrectly formatted or not a segment dataset.\n"
+                 "This error can occur when incorrectly training a 'segment' model on a 'detect' dataset, "
+                 "i.e. 'yolo train model=yolo11n-seg.pt data=coco8.yaml'.\nVerify your dataset is a "
+                 "correctly formatted 'segment' dataset using 'data=coco8-seg.yaml' "
+                 "as an example.\nSee https://docs.ultralytics.com/datasets/segment/ for help."
+             ) from e
+
+         # Pboxes
+         pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)
+
+         _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
+             pred_scores.detach().sigmoid(),
+             (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
+             anchor_points * stride_tensor,
+             gt_labels,
+             gt_bboxes,
+             mask_gt,
+         )
+
+         target_scores_sum = max(target_scores.sum(), 1)
+
+         # Cls loss
+         # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum  # VFL way
+         loss[2] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE
+
+         if fg_mask.sum():
+             # Bbox loss
+             loss[0], loss[3] = self.bbox_loss(
+                 pred_distri,
+                 pred_bboxes,
+                 anchor_points,
+                 target_bboxes / stride_tensor,
+                 target_scores,
+                 target_scores_sum,
+                 fg_mask,
+             )
+             # Masks loss
+             masks = batch["masks"].to(self.device).float()
+             if tuple(masks.shape[-2:]) != (mask_h, mask_w):  # downsample
+                 masks = F.interpolate(masks[None], (mask_h, mask_w), mode="nearest")[0]
+
+             loss[1] = self.calculate_segmentation_loss(
+                 fg_mask, masks, target_gt_idx, target_bboxes, batch_idx, proto, pred_masks, imgsz, self.overlap
+             )
+
+         # WARNING: lines below prevent Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
+         else:
+             loss[1] += (proto * 0).sum() + (pred_masks * 0).sum()  # inf sums may lead to nan loss
+
+         loss[0] *= self.hyp.box  # box gain
+         loss[1] *= self.hyp.box  # seg gain
+         loss[2] *= self.hyp.cls  # cls gain
+         loss[3] *= self.hyp.dfl  # dfl gain
+
+         return loss * batch_size, loss.detach()  # loss(box, seg, cls, dfl)
+
+     @staticmethod
+     def single_mask_loss(
+         gt_mask: torch.Tensor, pred: torch.Tensor, proto: torch.Tensor, xyxy: torch.Tensor, area: torch.Tensor
+     ) -> torch.Tensor:
+         """Compute the instance segmentation loss for a single image.
+
+         Args:
+             gt_mask (torch.Tensor): Ground truth mask of shape (N, H, W), where N is the number of objects.
+             pred (torch.Tensor): Predicted mask coefficients of shape (N, 32).
+             proto (torch.Tensor): Prototype masks of shape (32, H, W).
+             xyxy (torch.Tensor): Ground truth bounding boxes in xyxy format, normalized to [0, 1], of shape (N, 4).
+             area (torch.Tensor): Area of each ground truth bounding box of shape (N,).
+
+         Returns:
+             (torch.Tensor): The calculated mask loss for a single image.
+
+         Notes:
+             The function uses the equation pred_mask = torch.einsum('in,nhw->ihw', pred, proto) to produce the
+             predicted masks from the prototype masks and predicted mask coefficients.
+         """
+         pred_mask = torch.einsum("in,nhw->ihw", pred, proto)  # (n, 32) @ (32, 80, 80) -> (n, 80, 80)
+         loss = F.binary_cross_entropy_with_logits(pred_mask, gt_mask, reduction="none")
+         return (crop_mask(loss, xyxy).mean(dim=(1, 2)) / area).sum()
+
+     def calculate_segmentation_loss(
+         self,
+         fg_mask: torch.Tensor,
+         masks: torch.Tensor,
+         target_gt_idx: torch.Tensor,
+         target_bboxes: torch.Tensor,
+         batch_idx: torch.Tensor,
+         proto: torch.Tensor,
+         pred_masks: torch.Tensor,
+         imgsz: torch.Tensor,
+         overlap: bool,
+     ) -> torch.Tensor:
+         """Calculate the loss for instance segmentation.
+
+         Args:
+             fg_mask (torch.Tensor): A binary tensor of shape (BS, N_anchors) indicating which anchors are positive.
+             masks (torch.Tensor): Ground truth masks of shape (BS, H, W) if `overlap` is False, otherwise (BS, ?, H, W).
+             target_gt_idx (torch.Tensor): Indexes of ground truth objects for each anchor of shape (BS, N_anchors).
+             target_bboxes (torch.Tensor): Ground truth bounding boxes for each anchor of shape (BS, N_anchors, 4).
+             batch_idx (torch.Tensor): Batch indices of shape (N_labels_in_batch, 1).
+             proto (torch.Tensor): Prototype masks of shape (BS, 32, H, W).
+             pred_masks (torch.Tensor): Predicted masks for each anchor of shape (BS, N_anchors, 32).
+             imgsz (torch.Tensor): Size of the input image as a tensor of shape (2), i.e., (H, W).
+             overlap (bool): Whether the masks in `masks` tensor overlap.
+
+         Returns:
+             (torch.Tensor): The calculated loss for instance segmentation.
+
+         Notes:
+             The batch loss can be computed for improved speed at higher memory usage.
+             For example, pred_mask can be computed as follows:
+                 pred_mask = torch.einsum('in,nhw->ihw', pred, proto)  # (i, 32) @ (32, 160, 160) -> (i, 160, 160)
+         """
+         _, _, mask_h, mask_w = proto.shape
+         loss = 0
+
+         # Normalize to 0-1
+         target_bboxes_normalized = target_bboxes / imgsz[[1, 0, 1, 0]]
+
+         # Areas of target bboxes
+         marea = xyxy2xywh(target_bboxes_normalized)[..., 2:].prod(2)
+
+         # Normalize to mask size
+         mxyxy = target_bboxes_normalized * torch.tensor([mask_w, mask_h, mask_w, mask_h], device=proto.device)
+
+         for i, single_i in enumerate(zip(fg_mask, target_gt_idx, pred_masks, proto, mxyxy, marea, masks)):
+             fg_mask_i, target_gt_idx_i, pred_masks_i, proto_i, mxyxy_i, marea_i, masks_i = single_i
+             if fg_mask_i.any():
+                 mask_idx = target_gt_idx_i[fg_mask_i]
+                 if overlap:
+                     gt_mask = masks_i == (mask_idx + 1).view(-1, 1, 1)
+                     gt_mask = gt_mask.float()
+                 else:
+                     gt_mask = masks[batch_idx.view(-1) == i][mask_idx]
+
+                 loss += self.single_mask_loss(
+                     gt_mask, pred_masks_i[fg_mask_i], proto_i, mxyxy_i[fg_mask_i], marea_i[fg_mask_i]
+                 )
+
+             # WARNING: lines below prevents Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
+             else:
+                 loss += (proto * 0).sum() + (pred_masks * 0).sum()  # inf sums may lead to nan loss
+
+         return loss / fg_mask.sum()
+
+
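# --- Editor's note: illustrative sketch, not part of the package diff. ---
# The mask assembly used by single_mask_loss in miniature: each instance's 32
# coefficients linearly combine the 32 prototype masks into one logit map.
# The 5 instances and 160x160 prototype size are hypothetical.
import torch

proto = torch.randn(32, 160, 160)                         # prototype masks for one image
coeffs = torch.randn(5, 32)                               # coefficients for 5 predicted instances
mask_logits = torch.einsum("in,nhw->ihw", coeffs, proto)  # (5, 160, 160)
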
+ class v8PoseLoss(v8DetectionLoss):
+     """Criterion class for computing training losses for YOLOv8 pose estimation."""
+
+     def __init__(self, model):  # model must be de-paralleled
+         """Initialize v8PoseLoss with model parameters and keypoint-specific loss functions."""
+         super().__init__(model)
+         self.kpt_shape = model.model[-1].kpt_shape
+         self.bce_pose = nn.BCEWithLogitsLoss()
+         is_pose = self.kpt_shape == [17, 3]
+         nkpt = self.kpt_shape[0]  # number of keypoints
+         sigmas = torch.from_numpy(OKS_SIGMA).to(self.device) if is_pose else torch.ones(nkpt, device=self.device) / nkpt
+         self.keypoint_loss = KeypointLoss(sigmas=sigmas)
+
+     def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+         """Calculate the total loss and detach it for pose estimation."""
+         loss = torch.zeros(5, device=self.device)  # box, cls, dfl, kpt_location, kpt_visibility
+         feats, pred_kpts = preds if isinstance(preds[0], list) else preds[1]
+         pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
+             (self.reg_max * 4, self.nc), 1
+         )
+
+         # B, grids, ..
+         pred_scores = pred_scores.permute(0, 2, 1).contiguous()
+         pred_distri = pred_distri.permute(0, 2, 1).contiguous()
+         pred_kpts = pred_kpts.permute(0, 2, 1).contiguous()
+
+         dtype = pred_scores.dtype
+         imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]  # image size (h,w)
+         anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
+
+         # Targets
+         batch_size = pred_scores.shape[0]
+         batch_idx = batch["batch_idx"].view(-1, 1)
+         targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)
+         targets = self.preprocess(targets, batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
+         gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
+         mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
+
+         # Pboxes
+         pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)
+         pred_kpts = self.kpts_decode(anchor_points, pred_kpts.view(batch_size, -1, *self.kpt_shape))  # (b, h*w, 17, 3)
+
+         _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
+             pred_scores.detach().sigmoid(),
+             (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
+             anchor_points * stride_tensor,
+             gt_labels,
+             gt_bboxes,
+             mask_gt,
+         )
+
+         target_scores_sum = max(target_scores.sum(), 1)
+
+         # Cls loss
+         # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum  # VFL way
+         loss[3] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE
+
+         # Bbox loss
+         if fg_mask.sum():
+             target_bboxes /= stride_tensor
+             loss[0], loss[4] = self.bbox_loss(
+                 pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
+             )
+             keypoints = batch["keypoints"].to(self.device).float().clone()
+             keypoints[..., 0] *= imgsz[1]
+             keypoints[..., 1] *= imgsz[0]
+
+             loss[1], loss[2] = self.calculate_keypoints_loss(
+                 fg_mask, target_gt_idx, keypoints, batch_idx, stride_tensor, target_bboxes, pred_kpts
+             )
+
+         loss[0] *= self.hyp.box  # box gain
+         loss[1] *= self.hyp.pose  # pose gain
+         loss[2] *= self.hyp.kobj  # kobj gain
+         loss[3] *= self.hyp.cls  # cls gain
+         loss[4] *= self.hyp.dfl  # dfl gain
+
+         return loss * batch_size, loss.detach()  # loss(box, cls, dfl)
+
+     @staticmethod
+     def kpts_decode(anchor_points: torch.Tensor, pred_kpts: torch.Tensor) -> torch.Tensor:
+         """Decode predicted keypoints to image coordinates."""
+         y = pred_kpts.clone()
+         y[..., :2] *= 2.0
+         y[..., 0] += anchor_points[:, [0]] - 0.5
+         y[..., 1] += anchor_points[:, [1]] - 0.5
+         return y
+
+     def calculate_keypoints_loss(
+         self,
+         masks: torch.Tensor,
+         target_gt_idx: torch.Tensor,
+         keypoints: torch.Tensor,
+         batch_idx: torch.Tensor,
+         stride_tensor: torch.Tensor,
+         target_bboxes: torch.Tensor,
+         pred_kpts: torch.Tensor,
+     ) -> tuple[torch.Tensor, torch.Tensor]:
+         """Calculate the keypoints loss for the model.
+
+         This function calculates the keypoints loss and keypoints object loss for a given batch. The keypoints loss is
+         based on the difference between the predicted keypoints and ground truth keypoints. The keypoints object loss is
+         a binary classification loss that classifies whether a keypoint is present or not.
+
+         Args:
+             masks (torch.Tensor): Binary mask tensor indicating object presence, shape (BS, N_anchors).
+             target_gt_idx (torch.Tensor): Index tensor mapping anchors to ground truth objects, shape (BS, N_anchors).
+             keypoints (torch.Tensor): Ground truth keypoints, shape (N_kpts_in_batch, N_kpts_per_object, kpts_dim).
+             batch_idx (torch.Tensor): Batch index tensor for keypoints, shape (N_kpts_in_batch, 1).
+             stride_tensor (torch.Tensor): Stride tensor for anchors, shape (N_anchors, 1).
+             target_bboxes (torch.Tensor): Ground truth boxes in (x1, y1, x2, y2) format, shape (BS, N_anchors, 4).
+             pred_kpts (torch.Tensor): Predicted keypoints, shape (BS, N_anchors, N_kpts_per_object, kpts_dim).
+
+         Returns:
+             kpts_loss (torch.Tensor): The keypoints loss.
+             kpts_obj_loss (torch.Tensor): The keypoints object loss.
+         """
+         batch_idx = batch_idx.flatten()
+         batch_size = len(masks)
+
+         # Find the maximum number of keypoints in a single image
+         max_kpts = torch.unique(batch_idx, return_counts=True)[1].max()
+
+         # Create a tensor to hold batched keypoints
+         batched_keypoints = torch.zeros(
+             (batch_size, max_kpts, keypoints.shape[1], keypoints.shape[2]), device=keypoints.device
+         )
+
+         # TODO: any idea how to vectorize this?
+         # Fill batched_keypoints with keypoints based on batch_idx
+         for i in range(batch_size):
+             keypoints_i = keypoints[batch_idx == i]
+             batched_keypoints[i, : keypoints_i.shape[0]] = keypoints_i
+
+         # Expand dimensions of target_gt_idx to match the shape of batched_keypoints
+         target_gt_idx_expanded = target_gt_idx.unsqueeze(-1).unsqueeze(-1)
+
+         # Use target_gt_idx_expanded to select keypoints from batched_keypoints
+         selected_keypoints = batched_keypoints.gather(
+             1, target_gt_idx_expanded.expand(-1, -1, keypoints.shape[1], keypoints.shape[2])
+         )
+
+         # Divide coordinates by stride
+         selected_keypoints[..., :2] /= stride_tensor.view(1, -1, 1, 1)
+
+         kpts_loss = 0
+         kpts_obj_loss = 0
+
+         if masks.any():
+             gt_kpt = selected_keypoints[masks]
+             area = xyxy2xywh(target_bboxes[masks])[:, 2:].prod(1, keepdim=True)
+             pred_kpt = pred_kpts[masks]
+             kpt_mask = gt_kpt[..., 2] != 0 if gt_kpt.shape[-1] == 3 else torch.full_like(gt_kpt[..., 0], True)
+             kpts_loss = self.keypoint_loss(pred_kpt, gt_kpt, kpt_mask, area)  # pose loss
+
+             if pred_kpt.shape[-1] == 3:
+                 kpts_obj_loss = self.bce_pose(pred_kpt[..., 2], kpt_mask.float())  # keypoint obj loss
+
+         return kpts_loss, kpts_obj_loss
+
+
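# --- Editor's note: illustrative sketch, not part of the package diff. ---
# The gather in calculate_keypoints_loss in miniature: for every anchor, pick
# the keypoints of its assigned ground-truth object. Sizes are hypothetical.
import torch

batched_kpts = torch.rand(2, 3, 17, 3)               # (batch, max objects, keypoints, xyv)
target_gt_idx = torch.randint(0, 3, (2, 8400))       # assigned object index per anchor
idx = target_gt_idx.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, 17, 3)
selected = batched_kpts.gather(1, idx)               # (2, 8400, 17, 3)
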
+ class v8ClassificationLoss:
+     """Criterion class for computing training losses for classification."""
+
+     def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+         """Compute the classification loss between predictions and true labels."""
+         preds = preds[1] if isinstance(preds, (list, tuple)) else preds
+         loss = F.cross_entropy(preds, batch["cls"], reduction="mean")
+         return loss, loss.detach()
+
+
+ class v8OBBLoss(v8DetectionLoss):
+     """Calculates losses for object detection, classification, and box distribution in rotated YOLO models."""
+
+     def __init__(self, model):
+         """Initialize v8OBBLoss with model, assigner, and rotated bbox loss; model must be de-paralleled."""
+         super().__init__(model)
+         self.assigner = RotatedTaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0)
+         self.bbox_loss = RotatedBboxLoss(self.reg_max).to(self.device)
+
+     def preprocess(self, targets: torch.Tensor, batch_size: int, scale_tensor: torch.Tensor) -> torch.Tensor:
+         """Preprocess targets for oriented bounding box detection."""
+         if targets.shape[0] == 0:
+             out = torch.zeros(batch_size, 0, 6, device=self.device)
+         else:
+             i = targets[:, 0]  # image index
+             _, counts = i.unique(return_counts=True)
+             counts = counts.to(dtype=torch.int32)
+             out = torch.zeros(batch_size, counts.max(), 6, device=self.device)
+             for j in range(batch_size):
+                 matches = i == j
+                 if n := matches.sum():
+                     bboxes = targets[matches, 2:]
+                     bboxes[..., :4].mul_(scale_tensor)
+                     out[j, :n] = torch.cat([targets[matches, 1:2], bboxes], dim=-1)
+         return out
+
+     def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+         """Calculate and return the loss for oriented bounding box detection."""
+         loss = torch.zeros(3, device=self.device)  # box, cls, dfl
+         feats, pred_angle = preds if isinstance(preds[0], list) else preds[1]
+         batch_size = pred_angle.shape[0]  # batch size
+         pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
+             (self.reg_max * 4, self.nc), 1
+         )
+
+         # b, grids, ..
+         pred_scores = pred_scores.permute(0, 2, 1).contiguous()
+         pred_distri = pred_distri.permute(0, 2, 1).contiguous()
+         pred_angle = pred_angle.permute(0, 2, 1).contiguous()
+
+         dtype = pred_scores.dtype
+         imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]  # image size (h,w)
+         anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
+
+         # targets
+         try:
+             batch_idx = batch["batch_idx"].view(-1, 1)
+             targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"].view(-1, 5)), 1)
+             rw, rh = targets[:, 4] * float(imgsz[1]), targets[:, 5] * float(imgsz[0])
+             targets = targets[(rw >= 2) & (rh >= 2)]  # filter rboxes of tiny size to stabilize training
+             targets = self.preprocess(targets, batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
+             gt_labels, gt_bboxes = targets.split((1, 5), 2)  # cls, xywhr
+             mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
+         except RuntimeError as e:
+             raise TypeError(
+                 "ERROR ❌ OBB dataset incorrectly formatted or not a OBB dataset.\n"
+                 "This error can occur when incorrectly training a 'OBB' model on a 'detect' dataset, "
+                 "i.e. 'yolo train model=yolo11n-obb.pt data=coco8.yaml'.\nVerify your dataset is a "
+                 "correctly formatted 'OBB' dataset using 'data=dota8.yaml' "
+                 "as an example.\nSee https://docs.ultralytics.com/datasets/obb/ for help."
+             ) from e
+
+         # Pboxes
+         pred_bboxes = self.bbox_decode(anchor_points, pred_distri, pred_angle)  # xyxy, (b, h*w, 4)
+
+         bboxes_for_assigner = pred_bboxes.clone().detach()
+         # Only the first four elements need to be scaled
+         bboxes_for_assigner[..., :4] *= stride_tensor
+         _, target_bboxes, target_scores, fg_mask, _ = self.assigner(
+             pred_scores.detach().sigmoid(),
+             bboxes_for_assigner.type(gt_bboxes.dtype),
+             anchor_points * stride_tensor,
+             gt_labels,
+             gt_bboxes,
+             mask_gt,
+         )
+
+         target_scores_sum = max(target_scores.sum(), 1)
+
+         # Cls loss
+         # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum  # VFL way
+         loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE
+
+         # Bbox loss
+         if fg_mask.sum():
+             target_bboxes[..., :4] /= stride_tensor
+             loss[0], loss[2] = self.bbox_loss(
+                 pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
+             )
+         else:
+             loss[0] += (pred_angle * 0).sum()
+
+         loss[0] *= self.hyp.box  # box gain
+         loss[1] *= self.hyp.cls  # cls gain
+         loss[2] *= self.hyp.dfl  # dfl gain
+
+         return loss * batch_size, loss.detach()  # loss(box, cls, dfl)
+
+     def bbox_decode(
+         self, anchor_points: torch.Tensor, pred_dist: torch.Tensor, pred_angle: torch.Tensor
+     ) -> torch.Tensor:
+         """Decode predicted object bounding box coordinates from anchor points and distribution.
+
+         Args:
+             anchor_points (torch.Tensor): Anchor points, (h*w, 2).
+             pred_dist (torch.Tensor): Predicted rotated distance, (bs, h*w, 4).
+             pred_angle (torch.Tensor): Predicted angle, (bs, h*w, 1).
+
+         Returns:
+             (torch.Tensor): Predicted rotated bounding boxes with angles, (bs, h*w, 5).
+         """
+         if self.use_dfl:
+             b, a, c = pred_dist.shape  # batch, anchors, channels
+             pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))
+         return torch.cat((dist2rbox(pred_dist, pred_angle, anchor_points), pred_angle), dim=-1)
+
+
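# --- Editor's note: illustrative sketch, not part of the package diff. ---
# The tiny-rbox filter used in v8OBBLoss.__call__ in isolation: normalized
# (w, h) are rescaled to pixels and boxes under 2 px on either side are dropped.
# The 640x640 image size and random targets are hypothetical.
import torch

imgsz = torch.tensor([640.0, 640.0])                 # (h, w)
targets = torch.rand(10, 7)                          # batch_idx, cls, x, y, w, h, angle (normalized)
rw, rh = targets[:, 4] * imgsz[1], targets[:, 5] * imgsz[0]
targets = targets[(rw >= 2) & (rh >= 2)]
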
+ class E2EDetectLoss:
+     """Criterion class for computing training losses for end-to-end detection."""
+
+     def __init__(self, model):
+         """Initialize E2EDetectLoss with one-to-many and one-to-one detection losses using the provided model."""
+         self.one2many = v8DetectionLoss(model, tal_topk=10)
+         self.one2one = v8DetectionLoss(model, tal_topk=1)
+
+     def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+         """Calculate the sum of the loss for box, cls and dfl multiplied by batch size."""
+         preds = preds[1] if isinstance(preds, tuple) else preds
+         one2many = preds["one2many"]
+         loss_one2many = self.one2many(one2many, batch)
+         one2one = preds["one2one"]
+         loss_one2one = self.one2one(one2one, batch)
+         return loss_one2many[0] + loss_one2one[0], loss_one2many[1] + loss_one2one[1]
+
+
+ class TVPDetectLoss:
+     """Criterion class for computing training losses for text-visual prompt detection."""
+
+     def __init__(self, model):
+         """Initialize TVPDetectLoss with task-prompt and visual-prompt criteria using the provided model."""
+         self.vp_criterion = v8DetectionLoss(model)
+         # NOTE: store following info as it's changeable in __call__
+         self.ori_nc = self.vp_criterion.nc
+         self.ori_no = self.vp_criterion.no
+         self.ori_reg_max = self.vp_criterion.reg_max
+
+     def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+         """Calculate the loss for text-visual prompt detection."""
+         feats = preds[1] if isinstance(preds, tuple) else preds
+
+         if self.ori_reg_max * 4 + self.ori_nc == feats[0].shape[1]:
+             loss = torch.zeros(3, device=self.vp_criterion.device, requires_grad=True)
+             return loss, loss.detach()
+
+         vp_feats = self._get_vp_features(feats)
+         vp_loss = self.vp_criterion(vp_feats, batch)
+         cls_loss = vp_loss[0][1]
+         return cls_loss, vp_loss[1]
+
+     def _get_vp_features(self, feats: list[torch.Tensor]) -> list[torch.Tensor]:
+         """Extract visual-prompt features from the model output."""
+         vnc = feats[0].shape[1] - self.ori_reg_max * 4 - self.ori_nc
+
+         self.vp_criterion.nc = vnc
+         self.vp_criterion.no = vnc + self.vp_criterion.reg_max * 4
+         self.vp_criterion.assigner.num_classes = vnc
+
+         return [
+             torch.cat((box, cls_vp), dim=1)
+             for box, _, cls_vp in [xi.split((self.ori_reg_max * 4, self.ori_nc, vnc), dim=1) for xi in feats]
+         ]
+
+
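# --- Editor's note: illustrative sketch, not part of the package diff. ---
# The channel split behind _get_vp_features: each feature map stacks box
# distribution, text-prompt class, and visual-prompt class channels; the text
# block is dropped and box + visual-prompt channels are re-joined. The values
# reg_max=16, ori_nc=80, vnc=10 and the 80x80 grid are hypothetical.
import torch

reg_max, ori_nc, vnc = 16, 80, 10
feat = torch.randn(2, reg_max * 4 + ori_nc + vnc, 80, 80)
box, cls_text, cls_vp = feat.split((reg_max * 4, ori_nc, vnc), dim=1)
vp_feat = torch.cat((box, cls_vp), dim=1)            # (2, 74, 80, 80)
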
+ class TVPSegmentLoss(TVPDetectLoss):
+     """Criterion class for computing training losses for text-visual prompt segmentation."""
+
+     def __init__(self, model):
+         """Initialize TVPSegmentLoss with task-prompt and visual-prompt criteria using the provided model."""
+         super().__init__(model)
+         self.vp_criterion = v8SegmentationLoss(model)
+
+     def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+         """Calculate the loss for text-visual prompt segmentation."""
+         feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
+
+         if self.ori_reg_max * 4 + self.ori_nc == feats[0].shape[1]:
+             loss = torch.zeros(4, device=self.vp_criterion.device, requires_grad=True)
+             return loss, loss.detach()
+
+         vp_feats = self._get_vp_features(feats)
+         vp_loss = self.vp_criterion((vp_feats, pred_masks, proto), batch)
+         cls_loss = vp_loss[0][2]
+         return cls_loss, vp_loss[1]