ultralytics-opencv-headless 8.3.242__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (298)
  1. tests/__init__.py +23 -0
  2. tests/conftest.py +59 -0
  3. tests/test_cli.py +131 -0
  4. tests/test_cuda.py +216 -0
  5. tests/test_engine.py +157 -0
  6. tests/test_exports.py +309 -0
  7. tests/test_integrations.py +151 -0
  8. tests/test_python.py +777 -0
  9. tests/test_solutions.py +371 -0
  10. ultralytics/__init__.py +48 -0
  11. ultralytics/assets/bus.jpg +0 -0
  12. ultralytics/assets/zidane.jpg +0 -0
  13. ultralytics/cfg/__init__.py +1026 -0
  14. ultralytics/cfg/datasets/Argoverse.yaml +78 -0
  15. ultralytics/cfg/datasets/DOTAv1.5.yaml +37 -0
  16. ultralytics/cfg/datasets/DOTAv1.yaml +36 -0
  17. ultralytics/cfg/datasets/GlobalWheat2020.yaml +68 -0
  18. ultralytics/cfg/datasets/HomeObjects-3K.yaml +32 -0
  19. ultralytics/cfg/datasets/ImageNet.yaml +2025 -0
  20. ultralytics/cfg/datasets/Objects365.yaml +447 -0
  21. ultralytics/cfg/datasets/SKU-110K.yaml +58 -0
  22. ultralytics/cfg/datasets/VOC.yaml +102 -0
  23. ultralytics/cfg/datasets/VisDrone.yaml +87 -0
  24. ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
  25. ultralytics/cfg/datasets/brain-tumor.yaml +22 -0
  26. ultralytics/cfg/datasets/carparts-seg.yaml +44 -0
  27. ultralytics/cfg/datasets/coco-pose.yaml +64 -0
  28. ultralytics/cfg/datasets/coco.yaml +118 -0
  29. ultralytics/cfg/datasets/coco128-seg.yaml +101 -0
  30. ultralytics/cfg/datasets/coco128.yaml +101 -0
  31. ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
  32. ultralytics/cfg/datasets/coco8-multispectral.yaml +104 -0
  33. ultralytics/cfg/datasets/coco8-pose.yaml +47 -0
  34. ultralytics/cfg/datasets/coco8-seg.yaml +101 -0
  35. ultralytics/cfg/datasets/coco8.yaml +101 -0
  36. ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
  37. ultralytics/cfg/datasets/crack-seg.yaml +22 -0
  38. ultralytics/cfg/datasets/dog-pose.yaml +52 -0
  39. ultralytics/cfg/datasets/dota8-multispectral.yaml +38 -0
  40. ultralytics/cfg/datasets/dota8.yaml +35 -0
  41. ultralytics/cfg/datasets/hand-keypoints.yaml +50 -0
  42. ultralytics/cfg/datasets/kitti.yaml +27 -0
  43. ultralytics/cfg/datasets/lvis.yaml +1240 -0
  44. ultralytics/cfg/datasets/medical-pills.yaml +21 -0
  45. ultralytics/cfg/datasets/open-images-v7.yaml +663 -0
  46. ultralytics/cfg/datasets/package-seg.yaml +22 -0
  47. ultralytics/cfg/datasets/signature.yaml +21 -0
  48. ultralytics/cfg/datasets/tiger-pose.yaml +41 -0
  49. ultralytics/cfg/datasets/xView.yaml +155 -0
  50. ultralytics/cfg/default.yaml +130 -0
  51. ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +17 -0
  52. ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
  53. ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
  54. ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
  55. ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
  56. ultralytics/cfg/models/11/yolo11.yaml +50 -0
  57. ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
  58. ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
  59. ultralytics/cfg/models/12/yolo12-cls.yaml +32 -0
  60. ultralytics/cfg/models/12/yolo12-obb.yaml +48 -0
  61. ultralytics/cfg/models/12/yolo12-pose.yaml +49 -0
  62. ultralytics/cfg/models/12/yolo12-seg.yaml +48 -0
  63. ultralytics/cfg/models/12/yolo12.yaml +48 -0
  64. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +53 -0
  65. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +45 -0
  66. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +45 -0
  67. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +57 -0
  68. ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
  69. ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
  70. ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
  71. ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
  72. ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
  73. ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
  74. ultralytics/cfg/models/v3/yolov3-spp.yaml +49 -0
  75. ultralytics/cfg/models/v3/yolov3-tiny.yaml +40 -0
  76. ultralytics/cfg/models/v3/yolov3.yaml +49 -0
  77. ultralytics/cfg/models/v5/yolov5-p6.yaml +62 -0
  78. ultralytics/cfg/models/v5/yolov5.yaml +51 -0
  79. ultralytics/cfg/models/v6/yolov6.yaml +56 -0
  80. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +48 -0
  81. ultralytics/cfg/models/v8/yoloe-v8.yaml +48 -0
  82. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +28 -0
  83. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +28 -0
  84. ultralytics/cfg/models/v8/yolov8-cls.yaml +32 -0
  85. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +58 -0
  86. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +60 -0
  87. ultralytics/cfg/models/v8/yolov8-ghost.yaml +50 -0
  88. ultralytics/cfg/models/v8/yolov8-obb.yaml +49 -0
  89. ultralytics/cfg/models/v8/yolov8-p2.yaml +57 -0
  90. ultralytics/cfg/models/v8/yolov8-p6.yaml +59 -0
  91. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +60 -0
  92. ultralytics/cfg/models/v8/yolov8-pose.yaml +50 -0
  93. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +49 -0
  94. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +59 -0
  95. ultralytics/cfg/models/v8/yolov8-seg.yaml +49 -0
  96. ultralytics/cfg/models/v8/yolov8-world.yaml +51 -0
  97. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +49 -0
  98. ultralytics/cfg/models/v8/yolov8.yaml +49 -0
  99. ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
  100. ultralytics/cfg/models/v9/yolov9c.yaml +41 -0
  101. ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
  102. ultralytics/cfg/models/v9/yolov9e.yaml +64 -0
  103. ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
  104. ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
  105. ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
  106. ultralytics/cfg/trackers/botsort.yaml +21 -0
  107. ultralytics/cfg/trackers/bytetrack.yaml +12 -0
  108. ultralytics/data/__init__.py +26 -0
  109. ultralytics/data/annotator.py +66 -0
  110. ultralytics/data/augment.py +2801 -0
  111. ultralytics/data/base.py +435 -0
  112. ultralytics/data/build.py +437 -0
  113. ultralytics/data/converter.py +855 -0
  114. ultralytics/data/dataset.py +834 -0
  115. ultralytics/data/loaders.py +704 -0
  116. ultralytics/data/scripts/download_weights.sh +18 -0
  117. ultralytics/data/scripts/get_coco.sh +61 -0
  118. ultralytics/data/scripts/get_coco128.sh +18 -0
  119. ultralytics/data/scripts/get_imagenet.sh +52 -0
  120. ultralytics/data/split.py +138 -0
  121. ultralytics/data/split_dota.py +344 -0
  122. ultralytics/data/utils.py +798 -0
  123. ultralytics/engine/__init__.py +1 -0
  124. ultralytics/engine/exporter.py +1574 -0
  125. ultralytics/engine/model.py +1124 -0
  126. ultralytics/engine/predictor.py +508 -0
  127. ultralytics/engine/results.py +1522 -0
  128. ultralytics/engine/trainer.py +974 -0
  129. ultralytics/engine/tuner.py +448 -0
  130. ultralytics/engine/validator.py +384 -0
  131. ultralytics/hub/__init__.py +166 -0
  132. ultralytics/hub/auth.py +151 -0
  133. ultralytics/hub/google/__init__.py +174 -0
  134. ultralytics/hub/session.py +422 -0
  135. ultralytics/hub/utils.py +162 -0
  136. ultralytics/models/__init__.py +9 -0
  137. ultralytics/models/fastsam/__init__.py +7 -0
  138. ultralytics/models/fastsam/model.py +79 -0
  139. ultralytics/models/fastsam/predict.py +169 -0
  140. ultralytics/models/fastsam/utils.py +23 -0
  141. ultralytics/models/fastsam/val.py +38 -0
  142. ultralytics/models/nas/__init__.py +7 -0
  143. ultralytics/models/nas/model.py +98 -0
  144. ultralytics/models/nas/predict.py +56 -0
  145. ultralytics/models/nas/val.py +38 -0
  146. ultralytics/models/rtdetr/__init__.py +7 -0
  147. ultralytics/models/rtdetr/model.py +63 -0
  148. ultralytics/models/rtdetr/predict.py +88 -0
  149. ultralytics/models/rtdetr/train.py +89 -0
  150. ultralytics/models/rtdetr/val.py +216 -0
  151. ultralytics/models/sam/__init__.py +25 -0
  152. ultralytics/models/sam/amg.py +275 -0
  153. ultralytics/models/sam/build.py +365 -0
  154. ultralytics/models/sam/build_sam3.py +377 -0
  155. ultralytics/models/sam/model.py +169 -0
  156. ultralytics/models/sam/modules/__init__.py +1 -0
  157. ultralytics/models/sam/modules/blocks.py +1067 -0
  158. ultralytics/models/sam/modules/decoders.py +495 -0
  159. ultralytics/models/sam/modules/encoders.py +794 -0
  160. ultralytics/models/sam/modules/memory_attention.py +298 -0
  161. ultralytics/models/sam/modules/sam.py +1160 -0
  162. ultralytics/models/sam/modules/tiny_encoder.py +979 -0
  163. ultralytics/models/sam/modules/transformer.py +344 -0
  164. ultralytics/models/sam/modules/utils.py +512 -0
  165. ultralytics/models/sam/predict.py +3940 -0
  166. ultralytics/models/sam/sam3/__init__.py +3 -0
  167. ultralytics/models/sam/sam3/decoder.py +546 -0
  168. ultralytics/models/sam/sam3/encoder.py +529 -0
  169. ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
  170. ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
  171. ultralytics/models/sam/sam3/model_misc.py +199 -0
  172. ultralytics/models/sam/sam3/necks.py +129 -0
  173. ultralytics/models/sam/sam3/sam3_image.py +339 -0
  174. ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
  175. ultralytics/models/sam/sam3/vitdet.py +547 -0
  176. ultralytics/models/sam/sam3/vl_combiner.py +160 -0
  177. ultralytics/models/utils/__init__.py +1 -0
  178. ultralytics/models/utils/loss.py +466 -0
  179. ultralytics/models/utils/ops.py +315 -0
  180. ultralytics/models/yolo/__init__.py +7 -0
  181. ultralytics/models/yolo/classify/__init__.py +7 -0
  182. ultralytics/models/yolo/classify/predict.py +90 -0
  183. ultralytics/models/yolo/classify/train.py +202 -0
  184. ultralytics/models/yolo/classify/val.py +216 -0
  185. ultralytics/models/yolo/detect/__init__.py +7 -0
  186. ultralytics/models/yolo/detect/predict.py +122 -0
  187. ultralytics/models/yolo/detect/train.py +227 -0
  188. ultralytics/models/yolo/detect/val.py +507 -0
  189. ultralytics/models/yolo/model.py +430 -0
  190. ultralytics/models/yolo/obb/__init__.py +7 -0
  191. ultralytics/models/yolo/obb/predict.py +56 -0
  192. ultralytics/models/yolo/obb/train.py +79 -0
  193. ultralytics/models/yolo/obb/val.py +302 -0
  194. ultralytics/models/yolo/pose/__init__.py +7 -0
  195. ultralytics/models/yolo/pose/predict.py +65 -0
  196. ultralytics/models/yolo/pose/train.py +110 -0
  197. ultralytics/models/yolo/pose/val.py +248 -0
  198. ultralytics/models/yolo/segment/__init__.py +7 -0
  199. ultralytics/models/yolo/segment/predict.py +109 -0
  200. ultralytics/models/yolo/segment/train.py +69 -0
  201. ultralytics/models/yolo/segment/val.py +307 -0
  202. ultralytics/models/yolo/world/__init__.py +5 -0
  203. ultralytics/models/yolo/world/train.py +173 -0
  204. ultralytics/models/yolo/world/train_world.py +178 -0
  205. ultralytics/models/yolo/yoloe/__init__.py +22 -0
  206. ultralytics/models/yolo/yoloe/predict.py +162 -0
  207. ultralytics/models/yolo/yoloe/train.py +287 -0
  208. ultralytics/models/yolo/yoloe/train_seg.py +122 -0
  209. ultralytics/models/yolo/yoloe/val.py +206 -0
  210. ultralytics/nn/__init__.py +27 -0
  211. ultralytics/nn/autobackend.py +958 -0
  212. ultralytics/nn/modules/__init__.py +182 -0
  213. ultralytics/nn/modules/activation.py +54 -0
  214. ultralytics/nn/modules/block.py +1947 -0
  215. ultralytics/nn/modules/conv.py +669 -0
  216. ultralytics/nn/modules/head.py +1183 -0
  217. ultralytics/nn/modules/transformer.py +793 -0
  218. ultralytics/nn/modules/utils.py +159 -0
  219. ultralytics/nn/tasks.py +1768 -0
  220. ultralytics/nn/text_model.py +356 -0
  221. ultralytics/py.typed +1 -0
  222. ultralytics/solutions/__init__.py +41 -0
  223. ultralytics/solutions/ai_gym.py +108 -0
  224. ultralytics/solutions/analytics.py +264 -0
  225. ultralytics/solutions/config.py +107 -0
  226. ultralytics/solutions/distance_calculation.py +123 -0
  227. ultralytics/solutions/heatmap.py +125 -0
  228. ultralytics/solutions/instance_segmentation.py +86 -0
  229. ultralytics/solutions/object_blurrer.py +89 -0
  230. ultralytics/solutions/object_counter.py +190 -0
  231. ultralytics/solutions/object_cropper.py +87 -0
  232. ultralytics/solutions/parking_management.py +280 -0
  233. ultralytics/solutions/queue_management.py +93 -0
  234. ultralytics/solutions/region_counter.py +133 -0
  235. ultralytics/solutions/security_alarm.py +151 -0
  236. ultralytics/solutions/similarity_search.py +219 -0
  237. ultralytics/solutions/solutions.py +828 -0
  238. ultralytics/solutions/speed_estimation.py +114 -0
  239. ultralytics/solutions/streamlit_inference.py +260 -0
  240. ultralytics/solutions/templates/similarity-search.html +156 -0
  241. ultralytics/solutions/trackzone.py +88 -0
  242. ultralytics/solutions/vision_eye.py +67 -0
  243. ultralytics/trackers/__init__.py +7 -0
  244. ultralytics/trackers/basetrack.py +115 -0
  245. ultralytics/trackers/bot_sort.py +257 -0
  246. ultralytics/trackers/byte_tracker.py +469 -0
  247. ultralytics/trackers/track.py +116 -0
  248. ultralytics/trackers/utils/__init__.py +1 -0
  249. ultralytics/trackers/utils/gmc.py +339 -0
  250. ultralytics/trackers/utils/kalman_filter.py +482 -0
  251. ultralytics/trackers/utils/matching.py +154 -0
  252. ultralytics/utils/__init__.py +1450 -0
  253. ultralytics/utils/autobatch.py +118 -0
  254. ultralytics/utils/autodevice.py +205 -0
  255. ultralytics/utils/benchmarks.py +728 -0
  256. ultralytics/utils/callbacks/__init__.py +5 -0
  257. ultralytics/utils/callbacks/base.py +233 -0
  258. ultralytics/utils/callbacks/clearml.py +146 -0
  259. ultralytics/utils/callbacks/comet.py +625 -0
  260. ultralytics/utils/callbacks/dvc.py +197 -0
  261. ultralytics/utils/callbacks/hub.py +110 -0
  262. ultralytics/utils/callbacks/mlflow.py +134 -0
  263. ultralytics/utils/callbacks/neptune.py +126 -0
  264. ultralytics/utils/callbacks/platform.py +73 -0
  265. ultralytics/utils/callbacks/raytune.py +42 -0
  266. ultralytics/utils/callbacks/tensorboard.py +123 -0
  267. ultralytics/utils/callbacks/wb.py +188 -0
  268. ultralytics/utils/checks.py +998 -0
  269. ultralytics/utils/cpu.py +85 -0
  270. ultralytics/utils/dist.py +123 -0
  271. ultralytics/utils/downloads.py +529 -0
  272. ultralytics/utils/errors.py +35 -0
  273. ultralytics/utils/events.py +113 -0
  274. ultralytics/utils/export/__init__.py +7 -0
  275. ultralytics/utils/export/engine.py +237 -0
  276. ultralytics/utils/export/imx.py +315 -0
  277. ultralytics/utils/export/tensorflow.py +231 -0
  278. ultralytics/utils/files.py +219 -0
  279. ultralytics/utils/git.py +137 -0
  280. ultralytics/utils/instance.py +484 -0
  281. ultralytics/utils/logger.py +444 -0
  282. ultralytics/utils/loss.py +849 -0
  283. ultralytics/utils/metrics.py +1560 -0
  284. ultralytics/utils/nms.py +337 -0
  285. ultralytics/utils/ops.py +664 -0
  286. ultralytics/utils/patches.py +201 -0
  287. ultralytics/utils/plotting.py +1045 -0
  288. ultralytics/utils/tal.py +403 -0
  289. ultralytics/utils/torch_utils.py +984 -0
  290. ultralytics/utils/tqdm.py +440 -0
  291. ultralytics/utils/triton.py +112 -0
  292. ultralytics/utils/tuner.py +160 -0
  293. ultralytics_opencv_headless-8.3.242.dist-info/METADATA +374 -0
  294. ultralytics_opencv_headless-8.3.242.dist-info/RECORD +298 -0
  295. ultralytics_opencv_headless-8.3.242.dist-info/WHEEL +5 -0
  296. ultralytics_opencv_headless-8.3.242.dist-info/entry_points.txt +3 -0
  297. ultralytics_opencv_headless-8.3.242.dist-info/licenses/LICENSE +661 -0
  298. ultralytics_opencv_headless-8.3.242.dist-info/top_level.txt +1 -0
@@ -0,0 +1,160 @@
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
+
3
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
4
+
5
+ """Provides utility to combine a vision backbone with a language backbone."""
6
+
7
+ from __future__ import annotations
8
+
9
+ from copy import copy
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ from torch.nn.attention import SDPBackend, sdpa_kernel
14
+
15
+ from .necks import Sam3DualViTDetNeck
16
+
17
+
18
+ class SAM3VLBackbone(nn.Module):
19
+ """This backbone combines a vision backbone and a language backbone without fusion. As such it is more of a
20
+ convenience wrapper to handle the two backbones together.
21
+
22
+ It adds support for activation checkpointing and compilation.
23
+ """
24
+
25
+ def __init__(
26
+ self,
27
+ visual: Sam3DualViTDetNeck,
28
+ text,
29
+ compile_visual: bool = False,
30
+ act_ckpt_whole_vision_backbone: bool = False,
31
+ act_ckpt_whole_language_backbone: bool = False,
32
+ scalp=0,
33
+ ):
34
+ """Initialize the backbone combiner.
35
+
36
+ :param visual: The vision backbone to use
37
+ :param text: The text encoder to use
38
+ """
39
+ super().__init__()
40
+ self.vision_backbone: Sam3DualViTDetNeck = torch.compile(visual) if compile_visual else visual
41
+ self.language_backbone = text
42
+ self.scalp = scalp
43
+ # allow running activation checkpointing on the entire vision and language backbones
44
+ self.act_ckpt_whole_vision_backbone = act_ckpt_whole_vision_backbone
45
+ self.act_ckpt_whole_language_backbone = act_ckpt_whole_language_backbone
46
+
47
+ def forward(
48
+ self,
49
+ samples: torch.Tensor,
50
+ captions: list[str],
51
+ input_boxes: torch.Tensor = None,
52
+ additional_text: list[str] | None = None,
53
+ ):
54
+ """Forward pass of the backbone combiner.
55
+
56
+ :param samples: The input images
57
+ :param captions: The input captions
58
+ :param input_boxes: If the text contains place-holders for boxes, this
59
+ parameter contains the tensor containing their spatial features
60
+ :param additional_text: This can be used to encode some additional text
61
+ (different from the captions) in the same forward of the backbone
62
+ :return: Output dictionary with the following keys:
63
+ - vision_features: The output of the vision backbone
64
+ - language_features: The output of the language backbone
65
+ - language_mask: The attention mask of the language backbone
66
+ - vision_pos_enc: The positional encoding of the vision backbone
67
+ - (optional) additional_text_features: The output of the language
68
+ backbone for the additional text
69
+ - (optional) additional_text_mask: The attention mask of the
70
+ language backbone for the additional text
71
+ """
72
+ output = self.forward_image(samples)
73
+ output.update(self.forward_text(captions, input_boxes, additional_text))
74
+ return output
75
+
76
+ def forward_image(self, samples: torch.Tensor):
77
+ """Forward pass of the vision backbone and get both SAM3 and SAM2 features."""
78
+ # Forward through backbone
79
+ sam3_features, sam3_pos, sam2_features, sam2_pos = self.vision_backbone.forward(samples)
80
+ if self.scalp > 0:
81
+ # Discard the lowest resolution features
82
+ sam3_features, sam3_pos = (
83
+ sam3_features[: -self.scalp],
84
+ sam3_pos[: -self.scalp],
85
+ )
86
+ if sam2_features is not None and sam2_pos is not None:
87
+ sam2_features, sam2_pos = (
88
+ sam2_features[: -self.scalp],
89
+ sam2_pos[: -self.scalp],
90
+ )
91
+
92
+ sam2_output = None
93
+
94
+ if sam2_features is not None and sam2_pos is not None:
95
+ sam2_src = sam2_features[-1]
96
+ sam2_output = {
97
+ "vision_features": sam2_src,
98
+ "vision_pos_enc": sam2_pos,
99
+ "backbone_fpn": sam2_features,
100
+ }
101
+
102
+ sam3_src = sam3_features[-1]
103
+ return {
104
+ "vision_features": sam3_src,
105
+ "vision_pos_enc": sam3_pos,
106
+ "backbone_fpn": sam3_features,
107
+ "sam2_backbone_out": sam2_output,
108
+ }
109
+
110
+ def forward_image_sam2(self, samples: torch.Tensor):
111
+ """Forward pass of the vision backbone to get SAM2 features only."""
112
+ xs = self.vision_backbone.trunk(samples)
113
+ x = xs[-1] # simpleFPN
114
+
115
+ assert self.vision_backbone.sam2_convs is not None, "SAM2 neck is not available."
116
+ sam2_features, sam2_pos = self.vision_backbone.sam_forward_feature_levels(x, self.vision_backbone.sam2_convs)
117
+
118
+ if self.scalp > 0:
119
+ # Discard the lowest resolution features
120
+ sam2_features, sam2_pos = (
121
+ sam2_features[: -self.scalp],
122
+ sam2_pos[: -self.scalp],
123
+ )
124
+
125
+ return {
126
+ "vision_features": sam2_features[-1],
127
+ "vision_pos_enc": sam2_pos,
128
+ "backbone_fpn": sam2_features,
129
+ }
130
+
131
+ def forward_text(self, captions, input_boxes=None, additional_text=None):
132
+ """Forward pass of the text encoder."""
133
+ output = {}
134
+
135
+ # Forward through text_encoder
136
+ text_to_encode = copy(captions)
137
+ if additional_text is not None:
138
+ # if there are additional_text, we piggy-back them into this forward.
139
+ # They'll be used later for output alignment
140
+ text_to_encode += additional_text
141
+
142
+ with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.FLASH_ATTENTION]):
143
+ text_attention_mask, text_memory, text_embeds = self.language_backbone(text_to_encode, input_boxes)
144
+
145
+ if additional_text is not None:
146
+ output["additional_text_features"] = text_memory[:, -len(additional_text) :]
147
+ output["additional_text_mask"] = text_attention_mask[-len(additional_text) :]
148
+
149
+ text_memory = text_memory[:, : len(captions)]
150
+ text_attention_mask = text_attention_mask[: len(captions)]
151
+ text_embeds = text_embeds[:, : len(captions)]
152
+ output["language_features"] = text_memory
153
+ output["language_mask"] = text_attention_mask
154
+ output["language_embeds"] = text_embeds # Text embeddings before forward to the encoder
155
+
156
+ return output
157
+
158
+ def set_imgsz(self, imgsz: list[int] = [1008, 1008]):
159
+ """Set the image size for the vision backbone."""
160
+ self.vision_backbone.set_imgsz(imgsz)
@@ -0,0 +1 @@
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
@@ -0,0 +1,466 @@
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+ from ultralytics.utils.loss import FocalLoss, VarifocalLoss
12
+ from ultralytics.utils.metrics import bbox_iou
13
+
14
+ from .ops import HungarianMatcher
15
+
16
+
17
+ class DETRLoss(nn.Module):
18
+ """DETR (DEtection TRansformer) Loss class for calculating various loss components.
19
+
20
+ This class computes classification loss, bounding box loss, GIoU loss, and optionally auxiliary losses for the DETR
21
+ object detection model.
22
+
23
+ Attributes:
24
+ nc (int): Number of classes.
25
+ loss_gain (dict[str, float]): Coefficients for different loss components.
26
+ aux_loss (bool): Whether to compute auxiliary losses.
27
+ use_fl (bool): Whether to use FocalLoss.
28
+ use_vfl (bool): Whether to use VarifocalLoss.
29
+ use_uni_match (bool): Whether to use a fixed layer for auxiliary branch label assignment.
30
+ uni_match_ind (int): Index of fixed layer to use if use_uni_match is True.
31
+ matcher (HungarianMatcher): Object to compute matching cost and indices.
32
+ fl (FocalLoss | None): Focal Loss object if use_fl is True, otherwise None.
33
+ vfl (VarifocalLoss | None): Varifocal Loss object if use_vfl is True, otherwise None.
34
+ device (torch.device): Device on which tensors are stored.
35
+ """
36
+
37
+ def __init__(
38
+ self,
39
+ nc: int = 80,
40
+ loss_gain: dict[str, float] | None = None,
41
+ aux_loss: bool = True,
42
+ use_fl: bool = True,
43
+ use_vfl: bool = False,
44
+ use_uni_match: bool = False,
45
+ uni_match_ind: int = 0,
46
+ gamma: float = 1.5,
47
+ alpha: float = 0.25,
48
+ ):
49
+ """Initialize DETR loss function with customizable components and gains.
50
+
51
+ Uses default loss_gain if not provided. Initializes HungarianMatcher with preset cost gains. Supports auxiliary
52
+ losses and various loss types.
53
+
54
+ Args:
55
+ nc (int): Number of classes.
56
+ loss_gain (dict[str, float], optional): Coefficients for different loss components.
57
+ aux_loss (bool): Whether to use auxiliary losses from each decoder layer.
58
+ use_fl (bool): Whether to use FocalLoss.
59
+ use_vfl (bool): Whether to use VarifocalLoss.
60
+ use_uni_match (bool): Whether to use fixed layer for auxiliary branch label assignment.
61
+ uni_match_ind (int): Index of fixed layer for uni_match.
62
+ gamma (float): The focusing parameter that controls how much the loss focuses on hard-to-classify examples.
63
+ alpha (float): The balancing factor used to address class imbalance.
64
+ """
65
+ super().__init__()
66
+
67
+ if loss_gain is None:
68
+ loss_gain = {"class": 1, "bbox": 5, "giou": 2, "no_object": 0.1, "mask": 1, "dice": 1}
69
+ self.nc = nc
70
+ self.matcher = HungarianMatcher(cost_gain={"class": 2, "bbox": 5, "giou": 2})
71
+ self.loss_gain = loss_gain
72
+ self.aux_loss = aux_loss
73
+ self.fl = FocalLoss(gamma, alpha) if use_fl else None
74
+ self.vfl = VarifocalLoss(gamma, alpha) if use_vfl else None
75
+
76
+ self.use_uni_match = use_uni_match
77
+ self.uni_match_ind = uni_match_ind
78
+ self.device = None
79
+
80
+ def _get_loss_class(
81
+ self, pred_scores: torch.Tensor, targets: torch.Tensor, gt_scores: torch.Tensor, num_gts: int, postfix: str = ""
82
+ ) -> dict[str, torch.Tensor]:
83
+ """Compute classification loss based on predictions, target values, and ground truth scores.
84
+
85
+ Args:
86
+ pred_scores (torch.Tensor): Predicted class scores with shape (B, N, C).
87
+ targets (torch.Tensor): Target class indices with shape (B, N).
88
+ gt_scores (torch.Tensor): Ground truth confidence scores with shape (B, N).
89
+ num_gts (int): Number of ground truth objects.
90
+ postfix (str, optional): String to append to the loss name for identification in multi-loss scenarios.
91
+
92
+ Returns:
93
+ (dict[str, torch.Tensor]): Dictionary containing classification loss value.
94
+
95
+ Notes:
96
+ The function supports different classification loss types:
97
+ - Varifocal Loss (if self.vfl is True and num_gts > 0)
98
+ - Focal Loss (if self.fl is True)
99
+ - BCE Loss (default fallback)
100
+ """
101
+ # Logits: [b, query, num_classes], gt_class: list[[n, 1]]
102
+ name_class = f"loss_class{postfix}"
103
+ bs, nq = pred_scores.shape[:2]
104
+ # one_hot = F.one_hot(targets, self.nc + 1)[..., :-1] # (bs, num_queries, num_classes)
105
+ one_hot = torch.zeros((bs, nq, self.nc + 1), dtype=torch.int64, device=targets.device)
106
+ one_hot.scatter_(2, targets.unsqueeze(-1), 1)
107
+ one_hot = one_hot[..., :-1]
108
+ gt_scores = gt_scores.view(bs, nq, 1) * one_hot
109
+
110
+ if self.fl:
111
+ if num_gts and self.vfl:
112
+ loss_cls = self.vfl(pred_scores, gt_scores, one_hot)
113
+ else:
114
+ loss_cls = self.fl(pred_scores, one_hot.float())
115
+ loss_cls /= max(num_gts, 1) / nq
116
+ else:
117
+ loss_cls = nn.BCEWithLogitsLoss(reduction="none")(pred_scores, gt_scores).mean(1).sum() # YOLO CLS loss
118
+
119
+ return {name_class: loss_cls.squeeze() * self.loss_gain["class"]}
120
+
121
+ def _get_loss_bbox(
122
+ self, pred_bboxes: torch.Tensor, gt_bboxes: torch.Tensor, postfix: str = ""
123
+ ) -> dict[str, torch.Tensor]:
124
+ """Compute bounding box and GIoU losses for predicted and ground truth bounding boxes.
125
+
126
+ Args:
127
+ pred_bboxes (torch.Tensor): Predicted bounding boxes with shape (N, 4).
128
+ gt_bboxes (torch.Tensor): Ground truth bounding boxes with shape (N, 4).
129
+ postfix (str, optional): String to append to the loss names for identification in multi-loss scenarios.
130
+
131
+ Returns:
132
+ (dict[str, torch.Tensor]): Dictionary containing:
133
+ - loss_bbox{postfix}: L1 loss between predicted and ground truth boxes, scaled by the bbox loss gain.
134
+ - loss_giou{postfix}: GIoU loss between predicted and ground truth boxes, scaled by the giou loss gain.
135
+
136
+ Notes:
137
+ If no ground truth boxes are provided (empty list), zero-valued tensors are returned for both losses.
138
+ """
139
+ # Boxes: [b, query, 4], gt_bbox: list[[n, 4]]
140
+ name_bbox = f"loss_bbox{postfix}"
141
+ name_giou = f"loss_giou{postfix}"
142
+
143
+ loss = {}
144
+ if len(gt_bboxes) == 0:
145
+ loss[name_bbox] = torch.tensor(0.0, device=self.device)
146
+ loss[name_giou] = torch.tensor(0.0, device=self.device)
147
+ return loss
148
+
149
+ loss[name_bbox] = self.loss_gain["bbox"] * F.l1_loss(pred_bboxes, gt_bboxes, reduction="sum") / len(gt_bboxes)
150
+ loss[name_giou] = 1.0 - bbox_iou(pred_bboxes, gt_bboxes, xywh=True, GIoU=True)
151
+ loss[name_giou] = loss[name_giou].sum() / len(gt_bboxes)
152
+ loss[name_giou] = self.loss_gain["giou"] * loss[name_giou]
153
+ return {k: v.squeeze() for k, v in loss.items()}
154
+
155
+ # This function is for future RT-DETR Segment models
156
+ # def _get_loss_mask(self, masks, gt_mask, match_indices, postfix=''):
157
+ # # masks: [b, query, h, w], gt_mask: list[[n, H, W]]
158
+ # name_mask = f'loss_mask{postfix}'
159
+ # name_dice = f'loss_dice{postfix}'
160
+ #
161
+ # loss = {}
162
+ # if sum(len(a) for a in gt_mask) == 0:
163
+ # loss[name_mask] = torch.tensor(0., device=self.device)
164
+ # loss[name_dice] = torch.tensor(0., device=self.device)
165
+ # return loss
166
+ #
167
+ # num_gts = len(gt_mask)
168
+ # src_masks, target_masks = self._get_assigned_bboxes(masks, gt_mask, match_indices)
169
+ # src_masks = F.interpolate(src_masks.unsqueeze(0), size=target_masks.shape[-2:], mode='bilinear')[0]
170
+ # # TODO: torch does not have `sigmoid_focal_loss`, but it's not urgent since we don't use mask branch for now.
171
+ # loss[name_mask] = self.loss_gain['mask'] * F.sigmoid_focal_loss(src_masks, target_masks,
172
+ # torch.tensor([num_gts], dtype=torch.float32))
173
+ # loss[name_dice] = self.loss_gain['dice'] * self._dice_loss(src_masks, target_masks, num_gts)
174
+ # return loss
175
+
176
+ # This function is for future RT-DETR Segment models
177
+ # @staticmethod
178
+ # def _dice_loss(inputs, targets, num_gts):
179
+ # inputs = F.sigmoid(inputs).flatten(1)
180
+ # targets = targets.flatten(1)
181
+ # numerator = 2 * (inputs * targets).sum(1)
182
+ # denominator = inputs.sum(-1) + targets.sum(-1)
183
+ # loss = 1 - (numerator + 1) / (denominator + 1)
184
+ # return loss.sum() / num_gts
185
+
186
def _get_loss_aux(
    self,
    pred_bboxes: torch.Tensor,
    pred_scores: torch.Tensor,
    gt_bboxes: torch.Tensor,
    gt_cls: torch.Tensor,
    gt_groups: list[int],
    match_indices: list[tuple] | None = None,
    postfix: str = "",
    masks: torch.Tensor | None = None,
    gt_mask: torch.Tensor | None = None,
) -> dict[str, torch.Tensor]:
    """Compute auxiliary losses over the intermediate (non-final) decoder layers.

    Args:
        pred_bboxes (torch.Tensor): Bounding boxes predicted by the auxiliary layers.
        pred_scores (torch.Tensor): Class scores predicted by the auxiliary layers.
        gt_bboxes (torch.Tensor): Ground truth bounding boxes.
        gt_cls (torch.Tensor): Ground truth class labels.
        gt_groups (list[int]): Number of ground truths per image.
        match_indices (list[tuple], optional): Pre-computed matching indices shared by all layers.
        postfix (str, optional): Suffix appended to every loss name.
        masks (torch.Tensor, optional): Predicted masks when training segmentation.
        gt_mask (torch.Tensor, optional): Ground truth masks when training segmentation.

    Returns:
        (dict[str, torch.Tensor]): Auxiliary class, bbox and giou losses.
    """
    # Accumulator slots: class, bbox, giou (+ mask, dice reserved for future RT-DETR segmentation).
    totals = torch.zeros(5 if masks is not None else 3, device=pred_bboxes.device)
    # Optionally match once on a single reference layer and reuse those indices for every layer.
    if match_indices is None and self.use_uni_match:
        match_indices = self.matcher(
            pred_bboxes[self.uni_match_ind],
            pred_scores[self.uni_match_ind],
            gt_bboxes,
            gt_cls,
            gt_groups,
            masks=None if masks is None else masks[self.uni_match_ind],
            gt_mask=gt_mask,
        )
    for layer_i, (layer_bboxes, layer_scores) in enumerate(zip(pred_bboxes, pred_scores)):
        layer_masks = None if masks is None else masks[layer_i]
        layer_loss = self._get_loss(
            layer_bboxes,
            layer_scores,
            gt_bboxes,
            gt_cls,
            gt_groups,
            masks=layer_masks,
            gt_mask=gt_mask,
            postfix=postfix,
            match_indices=match_indices,
        )
        totals[0] += layer_loss[f"loss_class{postfix}"]
        totals[1] += layer_loss[f"loss_bbox{postfix}"]
        totals[2] += layer_loss[f"loss_giou{postfix}"]
        # Mask/dice accumulation intentionally disabled until RT-DETR segmentation models land.

    return {
        f"loss_class_aux{postfix}": totals[0],
        f"loss_bbox_aux{postfix}": totals[1],
        f"loss_giou_aux{postfix}": totals[2],
    }
256
+
257
+ @staticmethod
258
+ def _get_index(match_indices: list[tuple]) -> tuple[tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
259
+ """Extract batch indices, source indices, and destination indices from match indices.
260
+
261
+ Args:
262
+ match_indices (list[tuple]): List of tuples containing matched indices.
263
+
264
+ Returns:
265
+ batch_idx (tuple[torch.Tensor, torch.Tensor]): Tuple containing (batch_idx, src_idx).
266
+ dst_idx (torch.Tensor): Destination indices.
267
+ """
268
+ batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(match_indices)])
269
+ src_idx = torch.cat([src for (src, _) in match_indices])
270
+ dst_idx = torch.cat([dst for (_, dst) in match_indices])
271
+ return (batch_idx, src_idx), dst_idx
272
+
273
+ def _get_assigned_bboxes(
274
+ self, pred_bboxes: torch.Tensor, gt_bboxes: torch.Tensor, match_indices: list[tuple]
275
+ ) -> tuple[torch.Tensor, torch.Tensor]:
276
+ """Assign predicted bounding boxes to ground truth bounding boxes based on match indices.
277
+
278
+ Args:
279
+ pred_bboxes (torch.Tensor): Predicted bounding boxes.
280
+ gt_bboxes (torch.Tensor): Ground truth bounding boxes.
281
+ match_indices (list[tuple]): List of tuples containing matched indices.
282
+
283
+ Returns:
284
+ pred_assigned (torch.Tensor): Assigned predicted bounding boxes.
285
+ gt_assigned (torch.Tensor): Assigned ground truth bounding boxes.
286
+ """
287
+ pred_assigned = torch.cat(
288
+ [
289
+ t[i] if len(i) > 0 else torch.zeros(0, t.shape[-1], device=self.device)
290
+ for t, (i, _) in zip(pred_bboxes, match_indices)
291
+ ]
292
+ )
293
+ gt_assigned = torch.cat(
294
+ [
295
+ t[j] if len(j) > 0 else torch.zeros(0, t.shape[-1], device=self.device)
296
+ for t, (_, j) in zip(gt_bboxes, match_indices)
297
+ ]
298
+ )
299
+ return pred_assigned, gt_assigned
300
+
301
def _get_loss(
    self,
    pred_bboxes: torch.Tensor,
    pred_scores: torch.Tensor,
    gt_bboxes: torch.Tensor,
    gt_cls: torch.Tensor,
    gt_groups: list[int],
    masks: torch.Tensor | None = None,
    gt_mask: torch.Tensor | None = None,
    postfix: str = "",
    match_indices: list[tuple] | None = None,
) -> dict[str, torch.Tensor]:
    """Compute classification and box losses for a single prediction layer.

    Args:
        pred_bboxes (torch.Tensor): Predicted bounding boxes.
        pred_scores (torch.Tensor): Predicted class scores.
        gt_bboxes (torch.Tensor): Ground truth bounding boxes.
        gt_cls (torch.Tensor): Ground truth class labels.
        gt_groups (list[int]): Number of ground truths per image.
        masks (torch.Tensor, optional): Predicted masks when training segmentation.
        gt_mask (torch.Tensor, optional): Ground truth masks when training segmentation.
        postfix (str, optional): Suffix appended to every loss name.
        match_indices (list[tuple], optional): Pre-computed matching indices.

    Returns:
        (dict[str, torch.Tensor]): Dictionary of class and box losses.
    """
    # Run Hungarian matching only when the caller did not supply indices (denoising supplies its own).
    if match_indices is None:
        match_indices = self.matcher(
            pred_bboxes, pred_scores, gt_bboxes, gt_cls, gt_groups, masks=masks, gt_mask=gt_mask
        )

    pred_idx, gt_idx = self._get_index(match_indices)
    matched_preds, matched_gts = pred_bboxes[pred_idx], gt_bboxes[gt_idx]

    bs, num_queries = pred_scores.shape[:2]
    # Unmatched queries are labelled with the background class index (self.nc).
    targets = torch.full((bs, num_queries), self.nc, device=pred_scores.device, dtype=gt_cls.dtype)
    targets[pred_idx] = gt_cls[gt_idx]

    # IoU of each matched prediction with its target serves as the soft classification score.
    gt_scores = torch.zeros([bs, num_queries], device=pred_scores.device)
    if len(matched_gts):
        gt_scores[pred_idx] = bbox_iou(matched_preds.detach(), matched_gts, xywh=True).squeeze(-1)

    losses: dict[str, torch.Tensor] = {}
    losses.update(self._get_loss_class(pred_scores, targets, gt_scores, len(matched_gts), postfix))
    losses.update(self._get_loss_bbox(matched_preds, matched_gts, postfix))
    # NOTE: mask loss deliberately omitted until RT-DETR segmentation models land.
    return losses
350
+
351
def forward(
    self,
    pred_bboxes: torch.Tensor,
    pred_scores: torch.Tensor,
    batch: dict[str, Any],
    postfix: str = "",
    **kwargs: Any,
) -> dict[str, torch.Tensor]:
    """Compute the total loss from predicted boxes and scores.

    Args:
        pred_bboxes (torch.Tensor): Predicted bounding boxes, shape (L, B, N, 4).
        pred_scores (torch.Tensor): Predicted class scores, shape (L, B, N, C).
        batch (dict[str, Any]): Batch information containing cls, bboxes, and gt_groups.
        postfix (str, optional): Suffix appended to loss names.
        **kwargs (Any): Additional arguments; may include 'match_indices'.

    Returns:
        (dict[str, torch.Tensor]): Computed losses, including auxiliary ones when enabled.

    Notes:
        The last layer of pred_bboxes/pred_scores drives the main loss; earlier layers
        contribute auxiliary losses when self.aux_loss is True.
    """
    self.device = pred_bboxes.device
    match_indices = kwargs.get("match_indices")
    gt_cls = batch["cls"]
    gt_bboxes = batch["bboxes"]
    gt_groups = batch["gt_groups"]

    # Main loss from the final decoder layer.
    total_loss = self._get_loss(
        pred_bboxes[-1], pred_scores[-1], gt_bboxes, gt_cls, gt_groups, postfix=postfix, match_indices=match_indices
    )

    # Auxiliary losses from the remaining layers, when enabled.
    if self.aux_loss:
        aux_losses = self._get_loss_aux(
            pred_bboxes[:-1], pred_scores[:-1], gt_bboxes, gt_cls, gt_groups, match_indices, postfix
        )
        total_loss.update(aux_losses)

    return total_loss
391
+
392
+
393
class RTDETRDetectionLoss(DETRLoss):
    """Real-Time DEtection TRansformer (RT-DETR) detection loss, extending DETRLoss.

    Computes the standard detection loss, plus an additional denoising training loss
    whenever denoising metadata is supplied.
    """

    def forward(
        self,
        preds: tuple[torch.Tensor, torch.Tensor],
        batch: dict[str, Any],
        dn_bboxes: torch.Tensor | None = None,
        dn_scores: torch.Tensor | None = None,
        dn_meta: dict[str, Any] | None = None,
    ) -> dict[str, torch.Tensor]:
        """Compute the detection loss with an optional denoising term.

        Args:
            preds (tuple[torch.Tensor, torch.Tensor]): Predicted bounding boxes and scores.
            batch (dict[str, Any]): Batch data containing ground truth information.
            dn_bboxes (torch.Tensor, optional): Denoising bounding boxes.
            dn_scores (torch.Tensor, optional): Denoising scores.
            dn_meta (dict[str, Any], optional): Metadata for denoising.

        Returns:
            (dict[str, torch.Tensor]): Total loss, with denoising loss entries when applicable.
        """
        pred_bboxes, pred_scores = preds
        total_loss = super().forward(pred_bboxes, pred_scores, batch)

        if dn_meta is None:
            # No denoising this step: emit zero-valued "_dn" entries so the key set stays stable.
            total_loss.update({f"{k}_dn": torch.tensor(0.0, device=self.device) for k in total_loss.keys()})
            return total_loss

        dn_pos_idx, dn_num_group = dn_meta["dn_pos_idx"], dn_meta["dn_num_group"]
        assert len(batch["gt_groups"]) == len(dn_pos_idx)

        # Match denoising queries to ground truths, then reuse the detection loss with a "_dn" postfix.
        match_indices = self.get_dn_match_indices(dn_pos_idx, dn_num_group, batch["gt_groups"])
        total_loss.update(super().forward(dn_bboxes, dn_scores, batch, postfix="_dn", match_indices=match_indices))
        return total_loss

    @staticmethod
    def get_dn_match_indices(
        dn_pos_idx: list[torch.Tensor], dn_num_group: int, gt_groups: list[int]
    ) -> list[tuple[torch.Tensor, torch.Tensor]]:
        """Build per-image match indices for denoising queries.

        Args:
            dn_pos_idx (list[torch.Tensor]): Positive denoising query indices, per image.
            dn_num_group (int): Number of denoising groups.
            gt_groups (list[int]): Number of ground truths per image.

        Returns:
            (list[tuple[torch.Tensor, torch.Tensor]]): Matched (query, ground truth) index pairs per image.
        """
        indices = []
        # Offset of each image's first ground truth within the flattened gt tensor.
        offsets = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0)
        for img_i, count in enumerate(gt_groups):
            if count <= 0:
                indices.append((torch.zeros([0], dtype=torch.long), torch.zeros([0], dtype=torch.long)))
                continue
            # Each denoising group repeats the same set of ground truth targets.
            gt_idx = (torch.arange(end=count, dtype=torch.long) + offsets[img_i]).repeat(dn_num_group)
            assert len(dn_pos_idx[img_i]) == len(gt_idx), (
                f"Expected the same length, but got {len(dn_pos_idx[img_i])} and {len(gt_idx)} respectively."
            )
            indices.append((dn_pos_idx[img_i], gt_idx))
        return indices