dgenerate-ultralytics-headless 8.3.214__py3-none-any.whl → 8.4.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (249) hide show
  1. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/METADATA +64 -74
  2. dgenerate_ultralytics_headless-8.4.7.dist-info/RECORD +311 -0
  3. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/WHEEL +1 -1
  4. tests/__init__.py +7 -9
  5. tests/conftest.py +8 -15
  6. tests/test_cli.py +1 -1
  7. tests/test_cuda.py +13 -10
  8. tests/test_engine.py +9 -9
  9. tests/test_exports.py +65 -13
  10. tests/test_integrations.py +13 -13
  11. tests/test_python.py +125 -69
  12. tests/test_solutions.py +161 -152
  13. ultralytics/__init__.py +1 -1
  14. ultralytics/cfg/__init__.py +86 -92
  15. ultralytics/cfg/datasets/Argoverse.yaml +7 -6
  16. ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
  17. ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
  18. ultralytics/cfg/datasets/ImageNet.yaml +1 -1
  19. ultralytics/cfg/datasets/TT100K.yaml +346 -0
  20. ultralytics/cfg/datasets/VOC.yaml +15 -16
  21. ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
  22. ultralytics/cfg/datasets/coco-pose.yaml +21 -0
  23. ultralytics/cfg/datasets/coco12-formats.yaml +101 -0
  24. ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
  25. ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
  26. ultralytics/cfg/datasets/dog-pose.yaml +28 -0
  27. ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
  28. ultralytics/cfg/datasets/dota8.yaml +2 -2
  29. ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
  30. ultralytics/cfg/datasets/kitti.yaml +27 -0
  31. ultralytics/cfg/datasets/lvis.yaml +5 -5
  32. ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
  33. ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
  34. ultralytics/cfg/datasets/xView.yaml +16 -16
  35. ultralytics/cfg/default.yaml +4 -2
  36. ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
  37. ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
  38. ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
  39. ultralytics/cfg/models/26/yolo26-cls.yaml +33 -0
  40. ultralytics/cfg/models/26/yolo26-obb.yaml +52 -0
  41. ultralytics/cfg/models/26/yolo26-p2.yaml +60 -0
  42. ultralytics/cfg/models/26/yolo26-p6.yaml +62 -0
  43. ultralytics/cfg/models/26/yolo26-pose.yaml +53 -0
  44. ultralytics/cfg/models/26/yolo26-seg.yaml +52 -0
  45. ultralytics/cfg/models/26/yolo26.yaml +52 -0
  46. ultralytics/cfg/models/26/yoloe-26-seg.yaml +53 -0
  47. ultralytics/cfg/models/26/yoloe-26.yaml +53 -0
  48. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
  49. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
  50. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
  51. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
  52. ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
  53. ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
  54. ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
  55. ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
  56. ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
  57. ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
  58. ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
  59. ultralytics/cfg/models/v6/yolov6.yaml +1 -1
  60. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
  61. ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
  62. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
  63. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
  64. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
  65. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
  66. ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
  67. ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
  68. ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
  69. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
  70. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
  71. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
  72. ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
  73. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
  74. ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
  75. ultralytics/data/__init__.py +4 -4
  76. ultralytics/data/annotator.py +5 -6
  77. ultralytics/data/augment.py +300 -475
  78. ultralytics/data/base.py +18 -26
  79. ultralytics/data/build.py +147 -25
  80. ultralytics/data/converter.py +108 -87
  81. ultralytics/data/dataset.py +47 -75
  82. ultralytics/data/loaders.py +42 -49
  83. ultralytics/data/split.py +5 -6
  84. ultralytics/data/split_dota.py +8 -15
  85. ultralytics/data/utils.py +36 -45
  86. ultralytics/engine/exporter.py +351 -263
  87. ultralytics/engine/model.py +186 -225
  88. ultralytics/engine/predictor.py +45 -54
  89. ultralytics/engine/results.py +198 -325
  90. ultralytics/engine/trainer.py +165 -106
  91. ultralytics/engine/tuner.py +41 -43
  92. ultralytics/engine/validator.py +55 -38
  93. ultralytics/hub/__init__.py +16 -19
  94. ultralytics/hub/auth.py +6 -12
  95. ultralytics/hub/google/__init__.py +7 -10
  96. ultralytics/hub/session.py +15 -25
  97. ultralytics/hub/utils.py +5 -8
  98. ultralytics/models/__init__.py +1 -1
  99. ultralytics/models/fastsam/__init__.py +1 -1
  100. ultralytics/models/fastsam/model.py +8 -10
  101. ultralytics/models/fastsam/predict.py +18 -30
  102. ultralytics/models/fastsam/utils.py +1 -2
  103. ultralytics/models/fastsam/val.py +5 -7
  104. ultralytics/models/nas/__init__.py +1 -1
  105. ultralytics/models/nas/model.py +5 -8
  106. ultralytics/models/nas/predict.py +7 -9
  107. ultralytics/models/nas/val.py +1 -2
  108. ultralytics/models/rtdetr/__init__.py +1 -1
  109. ultralytics/models/rtdetr/model.py +5 -8
  110. ultralytics/models/rtdetr/predict.py +15 -19
  111. ultralytics/models/rtdetr/train.py +10 -13
  112. ultralytics/models/rtdetr/val.py +21 -23
  113. ultralytics/models/sam/__init__.py +15 -2
  114. ultralytics/models/sam/amg.py +14 -20
  115. ultralytics/models/sam/build.py +26 -19
  116. ultralytics/models/sam/build_sam3.py +377 -0
  117. ultralytics/models/sam/model.py +29 -32
  118. ultralytics/models/sam/modules/blocks.py +83 -144
  119. ultralytics/models/sam/modules/decoders.py +19 -37
  120. ultralytics/models/sam/modules/encoders.py +44 -101
  121. ultralytics/models/sam/modules/memory_attention.py +16 -30
  122. ultralytics/models/sam/modules/sam.py +200 -73
  123. ultralytics/models/sam/modules/tiny_encoder.py +64 -83
  124. ultralytics/models/sam/modules/transformer.py +18 -28
  125. ultralytics/models/sam/modules/utils.py +174 -50
  126. ultralytics/models/sam/predict.py +2248 -350
  127. ultralytics/models/sam/sam3/__init__.py +3 -0
  128. ultralytics/models/sam/sam3/decoder.py +546 -0
  129. ultralytics/models/sam/sam3/encoder.py +529 -0
  130. ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
  131. ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
  132. ultralytics/models/sam/sam3/model_misc.py +199 -0
  133. ultralytics/models/sam/sam3/necks.py +129 -0
  134. ultralytics/models/sam/sam3/sam3_image.py +339 -0
  135. ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
  136. ultralytics/models/sam/sam3/vitdet.py +547 -0
  137. ultralytics/models/sam/sam3/vl_combiner.py +160 -0
  138. ultralytics/models/utils/loss.py +14 -26
  139. ultralytics/models/utils/ops.py +13 -17
  140. ultralytics/models/yolo/__init__.py +1 -1
  141. ultralytics/models/yolo/classify/predict.py +10 -13
  142. ultralytics/models/yolo/classify/train.py +12 -33
  143. ultralytics/models/yolo/classify/val.py +30 -29
  144. ultralytics/models/yolo/detect/predict.py +9 -12
  145. ultralytics/models/yolo/detect/train.py +17 -23
  146. ultralytics/models/yolo/detect/val.py +77 -59
  147. ultralytics/models/yolo/model.py +43 -60
  148. ultralytics/models/yolo/obb/predict.py +7 -16
  149. ultralytics/models/yolo/obb/train.py +14 -17
  150. ultralytics/models/yolo/obb/val.py +40 -37
  151. ultralytics/models/yolo/pose/__init__.py +1 -1
  152. ultralytics/models/yolo/pose/predict.py +7 -22
  153. ultralytics/models/yolo/pose/train.py +13 -16
  154. ultralytics/models/yolo/pose/val.py +39 -58
  155. ultralytics/models/yolo/segment/predict.py +17 -21
  156. ultralytics/models/yolo/segment/train.py +7 -10
  157. ultralytics/models/yolo/segment/val.py +95 -47
  158. ultralytics/models/yolo/world/train.py +8 -14
  159. ultralytics/models/yolo/world/train_world.py +11 -34
  160. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  161. ultralytics/models/yolo/yoloe/predict.py +16 -23
  162. ultralytics/models/yolo/yoloe/train.py +36 -44
  163. ultralytics/models/yolo/yoloe/train_seg.py +11 -11
  164. ultralytics/models/yolo/yoloe/val.py +15 -20
  165. ultralytics/nn/__init__.py +7 -7
  166. ultralytics/nn/autobackend.py +159 -85
  167. ultralytics/nn/modules/__init__.py +68 -60
  168. ultralytics/nn/modules/activation.py +4 -6
  169. ultralytics/nn/modules/block.py +260 -224
  170. ultralytics/nn/modules/conv.py +52 -97
  171. ultralytics/nn/modules/head.py +831 -299
  172. ultralytics/nn/modules/transformer.py +76 -88
  173. ultralytics/nn/modules/utils.py +16 -21
  174. ultralytics/nn/tasks.py +180 -195
  175. ultralytics/nn/text_model.py +45 -69
  176. ultralytics/optim/__init__.py +5 -0
  177. ultralytics/optim/muon.py +338 -0
  178. ultralytics/solutions/__init__.py +12 -12
  179. ultralytics/solutions/ai_gym.py +13 -19
  180. ultralytics/solutions/analytics.py +15 -16
  181. ultralytics/solutions/config.py +6 -7
  182. ultralytics/solutions/distance_calculation.py +10 -13
  183. ultralytics/solutions/heatmap.py +8 -14
  184. ultralytics/solutions/instance_segmentation.py +6 -9
  185. ultralytics/solutions/object_blurrer.py +7 -10
  186. ultralytics/solutions/object_counter.py +12 -19
  187. ultralytics/solutions/object_cropper.py +8 -14
  188. ultralytics/solutions/parking_management.py +34 -32
  189. ultralytics/solutions/queue_management.py +10 -12
  190. ultralytics/solutions/region_counter.py +9 -12
  191. ultralytics/solutions/security_alarm.py +15 -20
  192. ultralytics/solutions/similarity_search.py +10 -15
  193. ultralytics/solutions/solutions.py +77 -76
  194. ultralytics/solutions/speed_estimation.py +7 -10
  195. ultralytics/solutions/streamlit_inference.py +2 -4
  196. ultralytics/solutions/templates/similarity-search.html +7 -18
  197. ultralytics/solutions/trackzone.py +7 -10
  198. ultralytics/solutions/vision_eye.py +5 -8
  199. ultralytics/trackers/__init__.py +1 -1
  200. ultralytics/trackers/basetrack.py +3 -5
  201. ultralytics/trackers/bot_sort.py +10 -27
  202. ultralytics/trackers/byte_tracker.py +21 -37
  203. ultralytics/trackers/track.py +4 -7
  204. ultralytics/trackers/utils/gmc.py +11 -22
  205. ultralytics/trackers/utils/kalman_filter.py +37 -48
  206. ultralytics/trackers/utils/matching.py +12 -15
  207. ultralytics/utils/__init__.py +124 -124
  208. ultralytics/utils/autobatch.py +2 -4
  209. ultralytics/utils/autodevice.py +17 -18
  210. ultralytics/utils/benchmarks.py +57 -71
  211. ultralytics/utils/callbacks/base.py +8 -10
  212. ultralytics/utils/callbacks/clearml.py +5 -13
  213. ultralytics/utils/callbacks/comet.py +32 -46
  214. ultralytics/utils/callbacks/dvc.py +13 -18
  215. ultralytics/utils/callbacks/mlflow.py +4 -5
  216. ultralytics/utils/callbacks/neptune.py +7 -15
  217. ultralytics/utils/callbacks/platform.py +423 -38
  218. ultralytics/utils/callbacks/raytune.py +3 -4
  219. ultralytics/utils/callbacks/tensorboard.py +25 -31
  220. ultralytics/utils/callbacks/wb.py +16 -14
  221. ultralytics/utils/checks.py +127 -85
  222. ultralytics/utils/cpu.py +3 -8
  223. ultralytics/utils/dist.py +9 -12
  224. ultralytics/utils/downloads.py +25 -33
  225. ultralytics/utils/errors.py +6 -14
  226. ultralytics/utils/events.py +2 -4
  227. ultralytics/utils/export/__init__.py +4 -236
  228. ultralytics/utils/export/engine.py +246 -0
  229. ultralytics/utils/export/imx.py +117 -63
  230. ultralytics/utils/export/tensorflow.py +231 -0
  231. ultralytics/utils/files.py +26 -30
  232. ultralytics/utils/git.py +9 -11
  233. ultralytics/utils/instance.py +30 -51
  234. ultralytics/utils/logger.py +212 -114
  235. ultralytics/utils/loss.py +601 -215
  236. ultralytics/utils/metrics.py +128 -156
  237. ultralytics/utils/nms.py +13 -16
  238. ultralytics/utils/ops.py +117 -166
  239. ultralytics/utils/patches.py +75 -21
  240. ultralytics/utils/plotting.py +75 -80
  241. ultralytics/utils/tal.py +125 -59
  242. ultralytics/utils/torch_utils.py +53 -79
  243. ultralytics/utils/tqdm.py +24 -21
  244. ultralytics/utils/triton.py +13 -19
  245. ultralytics/utils/tuner.py +19 -10
  246. dgenerate_ultralytics_headless-8.3.214.dist-info/RECORD +0 -283
  247. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/entry_points.txt +0 -0
  248. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/licenses/LICENSE +0 -0
  249. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/top_level.txt +0 -0
@@ -15,17 +15,16 @@ from ultralytics.utils import NOT_MACOS14
15
15
  from ultralytics.utils.tal import dist2bbox, dist2rbox, make_anchors
16
16
  from ultralytics.utils.torch_utils import TORCH_1_11, fuse_conv_and_bn, smart_inference_mode
17
17
 
18
- from .block import DFL, SAVPE, BNContrastiveHead, ContrastiveHead, Proto, Residual, SwiGLUFFN
18
+ from .block import DFL, SAVPE, BNContrastiveHead, ContrastiveHead, Proto, Proto26, RealNVP, Residual, SwiGLUFFN
19
19
  from .conv import Conv, DWConv
20
20
  from .transformer import MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer
21
21
  from .utils import bias_init_with_prob, linear_init
22
22
 
23
- __all__ = "Detect", "Segment", "Pose", "Classify", "OBB", "RTDETRDecoder", "v10Detect", "YOLOEDetect", "YOLOESegment"
23
+ __all__ = "OBB", "Classify", "Detect", "Pose", "RTDETRDecoder", "Segment", "YOLOEDetect", "YOLOESegment", "v10Detect"
24
24
 
25
25
 
26
26
  class Detect(nn.Module):
27
- """
28
- YOLO Detect head for object detection models.
27
+ """YOLO Detect head for object detection models.
29
28
 
30
29
  This class implements the detection head used in YOLO models for predicting bounding boxes and class probabilities.
31
30
  It supports both training and inference modes, with optional end-to-end detection capabilities.
@@ -69,7 +68,6 @@ class Detect(nn.Module):
69
68
  dynamic = False # force grid reconstruction
70
69
  export = False # export mode
71
70
  format = None # export format
72
- end2end = False # end2end
73
71
  max_det = 300 # max_det
74
72
  shape = None
75
73
  anchors = torch.empty(0) # init
@@ -77,18 +75,19 @@ class Detect(nn.Module):
77
75
  legacy = False # backward compatibility for v3/v5/v8/v9 models
78
76
  xyxy = False # xyxy or xywh output
79
77
 
80
- def __init__(self, nc: int = 80, ch: tuple = ()):
81
- """
82
- Initialize the YOLO detection layer with specified number of classes and channels.
78
+ def __init__(self, nc: int = 80, reg_max=16, end2end=False, ch: tuple = ()):
79
+ """Initialize the YOLO detection layer with specified number of classes and channels.
83
80
 
84
81
  Args:
85
82
  nc (int): Number of classes.
83
+ reg_max (int): Maximum number of DFL channels.
84
+ end2end (bool): Whether to use end-to-end NMS-free detection.
86
85
  ch (tuple): Tuple of channel sizes from backbone feature maps.
87
86
  """
88
87
  super().__init__()
89
88
  self.nc = nc # number of classes
90
89
  self.nl = len(ch) # number of detection layers
91
- self.reg_max = 16 # DFL channels (ch[0] // 16 to scale 4/8/12/16/20 for n/s/m/l/x)
90
+ self.reg_max = reg_max # DFL channels (ch[0] // 16 to scale 4/8/12/16/20 for n/s/m/l/x)
92
91
  self.no = nc + self.reg_max * 4 # number of outputs per anchor
93
92
  self.stride = torch.zeros(self.nl) # strides computed during build
94
93
  c2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], min(self.nc, 100)) # channels
@@ -109,93 +108,88 @@ class Detect(nn.Module):
109
108
  )
110
109
  self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()
111
110
 
112
- if self.end2end:
111
+ if end2end:
113
112
  self.one2one_cv2 = copy.deepcopy(self.cv2)
114
113
  self.one2one_cv3 = copy.deepcopy(self.cv3)
115
114
 
116
- def forward(self, x: list[torch.Tensor]) -> list[torch.Tensor] | tuple:
117
- """Concatenate and return predicted bounding boxes and class probabilities."""
115
+ @property
116
+ def one2many(self):
117
+ """Returns the one-to-many head components, here for v5/v5/v8/v9/11 backward compatibility."""
118
+ return dict(box_head=self.cv2, cls_head=self.cv3)
119
+
120
+ @property
121
+ def one2one(self):
122
+ """Returns the one-to-one head components."""
123
+ return dict(box_head=self.one2one_cv2, cls_head=self.one2one_cv3)
124
+
125
+ @property
126
+ def end2end(self):
127
+ """Checks if the model has one2one for v5/v5/v8/v9/11 backward compatibility."""
128
+ return hasattr(self, "one2one")
129
+
130
+ def forward_head(
131
+ self, x: list[torch.Tensor], box_head: torch.nn.Module = None, cls_head: torch.nn.Module = None
132
+ ) -> dict[str, torch.Tensor]:
133
+ """Concatenates and returns predicted bounding boxes and class probabilities."""
134
+ if box_head is None or cls_head is None: # for fused inference
135
+ return dict()
136
+ bs = x[0].shape[0] # batch size
137
+ boxes = torch.cat([box_head[i](x[i]).view(bs, 4 * self.reg_max, -1) for i in range(self.nl)], dim=-1)
138
+ scores = torch.cat([cls_head[i](x[i]).view(bs, self.nc, -1) for i in range(self.nl)], dim=-1)
139
+ return dict(boxes=boxes, scores=scores, feats=x)
140
+
141
+ def forward(
142
+ self, x: list[torch.Tensor]
143
+ ) -> dict[str, torch.Tensor] | torch.Tensor | tuple[torch.Tensor, dict[str, torch.Tensor]]:
144
+ """Concatenates and returns predicted bounding boxes and class probabilities."""
145
+ preds = self.forward_head(x, **self.one2many)
118
146
  if self.end2end:
119
- return self.forward_end2end(x)
120
-
121
- for i in range(self.nl):
122
- x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
123
- if self.training: # Training path
124
- return x
125
- y = self._inference(x)
126
- return y if self.export else (y, x)
127
-
128
- def forward_end2end(self, x: list[torch.Tensor]) -> dict | tuple:
129
- """
130
- Perform forward pass of the v10Detect module.
131
-
132
- Args:
133
- x (list[torch.Tensor]): Input feature maps from different levels.
134
-
135
- Returns:
136
- outputs (dict | tuple): Training mode returns dict with one2many and one2one outputs.
137
- Inference mode returns processed detections or tuple with detections and raw outputs.
138
- """
139
- x_detach = [xi.detach() for xi in x]
140
- one2one = [
141
- torch.cat((self.one2one_cv2[i](x_detach[i]), self.one2one_cv3[i](x_detach[i])), 1) for i in range(self.nl)
142
- ]
143
- for i in range(self.nl):
144
- x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
145
- if self.training: # Training path
146
- return {"one2many": x, "one2one": one2one}
147
+ x_detach = [xi.detach() for xi in x]
148
+ one2one = self.forward_head(x_detach, **self.one2one)
149
+ preds = {"one2many": preds, "one2one": one2one}
150
+ if self.training:
151
+ return preds
152
+ y = self._inference(preds["one2one"] if self.end2end else preds)
153
+ if self.end2end:
154
+ y = self.postprocess(y.permute(0, 2, 1))
155
+ return y if self.export else (y, preds)
147
156
 
148
- y = self._inference(one2one)
149
- y = self.postprocess(y.permute(0, 2, 1), self.max_det, self.nc)
150
- return y if self.export else (y, {"one2many": x, "one2one": one2one})
151
-
152
- def _inference(self, x: list[torch.Tensor]) -> torch.Tensor:
153
- """
154
- Decode predicted bounding boxes and class probabilities based on multiple-level feature maps.
157
+ def _inference(self, x: dict[str, torch.Tensor]) -> torch.Tensor:
158
+ """Decode predicted bounding boxes and class probabilities based on multiple-level feature maps.
155
159
 
156
160
  Args:
157
- x (list[torch.Tensor]): List of feature maps from different detection layers.
161
+ x (dict[str, torch.Tensor]): List of feature maps from different detection layers.
158
162
 
159
163
  Returns:
160
164
  (torch.Tensor): Concatenated tensor of decoded bounding boxes and class probabilities.
161
165
  """
162
166
  # Inference path
163
- shape = x[0].shape # BCHW
164
- x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
167
+ dbox = self._get_decode_boxes(x)
168
+ return torch.cat((dbox, x["scores"].sigmoid()), 1)
169
+
170
+ def _get_decode_boxes(self, x: dict[str, torch.Tensor]) -> torch.Tensor:
171
+ """Get decoded boxes based on anchors and strides."""
172
+ shape = x["feats"][0].shape # BCHW
165
173
  if self.dynamic or self.shape != shape:
166
- self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
174
+ self.anchors, self.strides = (a.transpose(0, 1) for a in make_anchors(x["feats"], self.stride, 0.5))
167
175
  self.shape = shape
168
176
 
169
- if self.export and self.format in {"saved_model", "pb", "tflite", "edgetpu", "tfjs"}: # avoid TF FlexSplitV ops
170
- box = x_cat[:, : self.reg_max * 4]
171
- cls = x_cat[:, self.reg_max * 4 :]
172
- else:
173
- box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
174
-
175
- if self.export and self.format in {"tflite", "edgetpu"}:
176
- # Precompute normalization factor to increase numerical stability
177
- # See https://github.com/ultralytics/ultralytics/issues/7371
178
- grid_h = shape[2]
179
- grid_w = shape[3]
180
- grid_size = torch.tensor([grid_w, grid_h, grid_w, grid_h], device=box.device).reshape(1, 4, 1)
181
- norm = self.strides / (self.stride[0] * grid_size)
182
- dbox = self.decode_bboxes(self.dfl(box) * norm, self.anchors.unsqueeze(0) * norm[:, :2])
183
- else:
184
- dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides
185
- return torch.cat((dbox, cls.sigmoid()), 1)
177
+ dbox = self.decode_bboxes(self.dfl(x["boxes"]), self.anchors.unsqueeze(0)) * self.strides
178
+ return dbox
186
179
 
187
180
  def bias_init(self):
188
181
  """Initialize Detect() biases, WARNING: requires stride availability."""
189
- m = self # self.model[-1] # Detect() module
190
- # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
191
- # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency
192
- for a, b, s in zip(m.cv2, m.cv3, m.stride): # from
193
- a[-1].bias.data[:] = 1.0 # box
194
- b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 objects, 80 classes, 640 img)
182
+ for i, (a, b) in enumerate(zip(self.one2many["box_head"], self.one2many["cls_head"])): # from
183
+ a[-1].bias.data[:] = 2.0 # box
184
+ b[-1].bias.data[: self.nc] = math.log(
185
+ 5 / self.nc / (640 / self.stride[i]) ** 2
186
+ ) # cls (.01 objects, 80 classes, 640 img)
195
187
  if self.end2end:
196
- for a, b, s in zip(m.one2one_cv2, m.one2one_cv3, m.stride): # from
197
- a[-1].bias.data[:] = 1.0 # box
198
- b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 objects, 80 classes, 640 img)
188
+ for i, (a, b) in enumerate(zip(self.one2one["box_head"], self.one2one["cls_head"])): # from
189
+ a[-1].bias.data[:] = 2.0 # box
190
+ b[-1].bias.data[: self.nc] = math.log(
191
+ 5 / self.nc / (640 / self.stride[i]) ** 2
192
+ ) # cls (.01 objects, 80 classes, 640 img)
199
193
 
200
194
  def decode_bboxes(self, bboxes: torch.Tensor, anchors: torch.Tensor, xywh: bool = True) -> torch.Tensor:
201
195
  """Decode bounding boxes from predictions."""
@@ -206,34 +200,49 @@ class Detect(nn.Module):
206
200
  dim=1,
207
201
  )
208
202
 
209
- @staticmethod
210
- def postprocess(preds: torch.Tensor, max_det: int, nc: int = 80) -> torch.Tensor:
211
- """
212
- Post-process YOLO model predictions.
203
+ def postprocess(self, preds: torch.Tensor) -> torch.Tensor:
204
+ """Post-processes YOLO model predictions.
213
205
 
214
206
  Args:
215
207
  preds (torch.Tensor): Raw predictions with shape (batch_size, num_anchors, 4 + nc) with last dimension
216
208
  format [x, y, w, h, class_probs].
217
- max_det (int): Maximum detections per image.
218
- nc (int, optional): Number of classes.
219
209
 
220
210
  Returns:
221
211
  (torch.Tensor): Processed predictions with shape (batch_size, min(max_det, num_anchors), 6) and last
222
212
  dimension format [x, y, w, h, max_class_prob, class_index].
223
213
  """
224
- batch_size, anchors, _ = preds.shape # i.e. shape(16,8400,84)
225
- boxes, scores = preds.split([4, nc], dim=-1)
226
- index = scores.amax(dim=-1).topk(min(max_det, anchors))[1].unsqueeze(-1)
227
- boxes = boxes.gather(dim=1, index=index.repeat(1, 1, 4))
228
- scores = scores.gather(dim=1, index=index.repeat(1, 1, nc))
229
- scores, index = scores.flatten(1).topk(min(max_det, anchors))
230
- i = torch.arange(batch_size)[..., None] # batch indices
231
- return torch.cat([boxes[i, index // nc], scores[..., None], (index % nc)[..., None].float()], dim=-1)
214
+ boxes, scores = preds.split([4, self.nc], dim=-1)
215
+ scores, conf, idx = self.get_topk_index(scores, self.max_det)
216
+ boxes = boxes.gather(dim=1, index=idx.repeat(1, 1, 4))
217
+ return torch.cat([boxes, scores, conf], dim=-1)
218
+
219
+ def get_topk_index(self, scores: torch.Tensor, max_det: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
220
+ """Get top-k indices from scores.
221
+
222
+ Args:
223
+ scores (torch.Tensor): Scores tensor with shape (batch_size, num_anchors, num_classes).
224
+ max_det (int): Maximum detections per image.
225
+
226
+ Returns:
227
+ (torch.Tensor, torch.Tensor, torch.Tensor): Top scores, class indices, and filtered indices.
228
+ """
229
+ batch_size, anchors, nc = scores.shape # i.e. shape(16,8400,84)
230
+ # Use max_det directly during export for TensorRT compatibility (requires k to be constant),
231
+ # otherwise use min(max_det, anchors) for safety with small inputs during Python inference
232
+ k = max_det if self.export else min(max_det, anchors)
233
+ ori_index = scores.max(dim=-1)[0].topk(k)[1].unsqueeze(-1)
234
+ scores = scores.gather(dim=1, index=ori_index.repeat(1, 1, nc))
235
+ scores, index = scores.flatten(1).topk(k)
236
+ idx = ori_index[torch.arange(batch_size)[..., None], index // nc] # original index
237
+ return scores[..., None], (index % nc)[..., None].float(), idx
238
+
239
+ def fuse(self) -> None:
240
+ """Remove the one2many head for inference optimization."""
241
+ self.cv2 = self.cv3 = None
232
242
 
233
243
 
234
244
  class Segment(Detect):
235
- """
236
- YOLO Segment head for segmentation models.
245
+ """YOLO Segment head for segmentation models.
237
246
 
238
247
  This class extends the Detect head to include mask prediction capabilities for instance segmentation tasks.
239
248
 
@@ -253,39 +262,150 @@ class Segment(Detect):
253
262
  >>> outputs = segment(x)
254
263
  """
255
264
 
256
- def __init__(self, nc: int = 80, nm: int = 32, npr: int = 256, ch: tuple = ()):
257
- """
258
- Initialize the YOLO model attributes such as the number of masks, prototypes, and the convolution layers.
265
+ def __init__(self, nc: int = 80, nm: int = 32, npr: int = 256, reg_max=16, end2end=False, ch: tuple = ()):
266
+ """Initialize the YOLO model attributes such as the number of masks, prototypes, and the convolution layers.
259
267
 
260
268
  Args:
261
269
  nc (int): Number of classes.
262
270
  nm (int): Number of masks.
263
271
  npr (int): Number of protos.
272
+ reg_max (int): Maximum number of DFL channels.
273
+ end2end (bool): Whether to use end-to-end NMS-free detection.
264
274
  ch (tuple): Tuple of channel sizes from backbone feature maps.
265
275
  """
266
- super().__init__(nc, ch)
276
+ super().__init__(nc, reg_max, end2end, ch)
267
277
  self.nm = nm # number of masks
268
278
  self.npr = npr # number of protos
269
279
  self.proto = Proto(ch[0], self.npr, self.nm) # protos
270
280
 
271
281
  c4 = max(ch[0] // 4, self.nm)
272
282
  self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nm, 1)) for x in ch)
283
+ if end2end:
284
+ self.one2one_cv4 = copy.deepcopy(self.cv4)
285
+
286
+ @property
287
+ def one2many(self):
288
+ """Returns the one-to-many head components, here for backward compatibility."""
289
+ return dict(box_head=self.cv2, cls_head=self.cv3, mask_head=self.cv4)
290
+
291
+ @property
292
+ def one2one(self):
293
+ """Returns the one-to-one head components."""
294
+ return dict(box_head=self.one2one_cv2, cls_head=self.one2one_cv3, mask_head=self.one2one_cv4)
273
295
 
274
- def forward(self, x: list[torch.Tensor]) -> tuple | list[torch.Tensor]:
296
+ def forward(self, x: list[torch.Tensor]) -> tuple | list[torch.Tensor] | dict[str, torch.Tensor]:
275
297
  """Return model outputs and mask coefficients if training, otherwise return outputs and mask coefficients."""
276
- p = self.proto(x[0]) # mask protos
277
- bs = p.shape[0] # batch size
298
+ outputs = super().forward(x)
299
+ preds = outputs[1] if isinstance(outputs, tuple) else outputs
300
+ proto = self.proto(x[0]) # mask protos
301
+ if isinstance(preds, dict): # training and validating during training
302
+ if self.end2end:
303
+ preds["one2many"]["proto"] = proto
304
+ preds["one2one"]["proto"] = proto.detach()
305
+ else:
306
+ preds["proto"] = proto
307
+ if self.training:
308
+ return preds
309
+ return (outputs, proto) if self.export else ((outputs[0], proto), preds)
310
+
311
+ def _inference(self, x: dict[str, torch.Tensor]) -> torch.Tensor:
312
+ """Decode predicted bounding boxes and class probabilities, concatenated with mask coefficients."""
313
+ preds = super()._inference(x)
314
+ return torch.cat([preds, x["mask_coefficient"]], dim=1)
315
+
316
+ def forward_head(
317
+ self, x: list[torch.Tensor], box_head: torch.nn.Module, cls_head: torch.nn.Module, mask_head: torch.nn.Module
318
+ ) -> torch.Tensor:
319
+ """Concatenates and returns predicted bounding boxes, class probabilities, and mask coefficients."""
320
+ preds = super().forward_head(x, box_head, cls_head)
321
+ if mask_head is not None:
322
+ bs = x[0].shape[0] # batch size
323
+ preds["mask_coefficient"] = torch.cat([mask_head[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2)
324
+ return preds
325
+
326
+ def postprocess(self, preds: torch.Tensor) -> torch.Tensor:
327
+ """Post-process YOLO model predictions.
328
+
329
+ Args:
330
+ preds (torch.Tensor): Raw predictions with shape (batch_size, num_anchors, 4 + nc + nm) with last dimension
331
+ format [x, y, w, h, class_probs, mask_coefficient].
332
+
333
+ Returns:
334
+ (torch.Tensor): Processed predictions with shape (batch_size, min(max_det, num_anchors), 6 + nm) and last
335
+ dimension format [x, y, w, h, max_class_prob, class_index, mask_coefficient].
336
+ """
337
+ boxes, scores, mask_coefficient = preds.split([4, self.nc, self.nm], dim=-1)
338
+ scores, conf, idx = self.get_topk_index(scores, self.max_det)
339
+ boxes = boxes.gather(dim=1, index=idx.repeat(1, 1, 4))
340
+ mask_coefficient = mask_coefficient.gather(dim=1, index=idx.repeat(1, 1, self.nm))
341
+ return torch.cat([boxes, scores, conf, mask_coefficient], dim=-1)
342
+
343
+ def fuse(self) -> None:
344
+ """Remove the one2many head for inference optimization."""
345
+ self.cv2 = self.cv3 = self.cv4 = None
278
346
 
279
- mc = torch.cat([self.cv4[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2) # mask coefficients
280
- x = Detect.forward(self, x)
347
+
348
+ class Segment26(Segment):
349
+ """YOLO26 Segment head for segmentation models.
350
+
351
+ This class extends the Detect head to include mask prediction capabilities for instance segmentation tasks.
352
+
353
+ Attributes:
354
+ nm (int): Number of masks.
355
+ npr (int): Number of protos.
356
+ proto (Proto): Prototype generation module.
357
+ cv4 (nn.ModuleList): Convolution layers for mask coefficients.
358
+
359
+ Methods:
360
+ forward: Return model outputs and mask coefficients.
361
+
362
+ Examples:
363
+ Create a segmentation head
364
+ >>> segment = Segment26(nc=80, nm=32, npr=256, ch=(256, 512, 1024))
365
+ >>> x = [torch.randn(1, 256, 80, 80), torch.randn(1, 512, 40, 40), torch.randn(1, 1024, 20, 20)]
366
+ >>> outputs = segment(x)
367
+ """
368
+
369
+ def __init__(self, nc: int = 80, nm: int = 32, npr: int = 256, reg_max=16, end2end=False, ch: tuple = ()):
370
+ """Initialize the YOLO model attributes such as the number of masks, prototypes, and the convolution layers.
371
+
372
+ Args:
373
+ nc (int): Number of classes.
374
+ nm (int): Number of masks.
375
+ npr (int): Number of protos.
376
+ reg_max (int): Maximum number of DFL channels.
377
+ end2end (bool): Whether to use end-to-end NMS-free detection.
378
+ ch (tuple): Tuple of channel sizes from backbone feature maps.
379
+ """
380
+ super().__init__(nc, nm, npr, reg_max, end2end, ch)
381
+ self.proto = Proto26(ch, self.npr, self.nm, nc) # protos
382
+
383
+ def forward(self, x: list[torch.Tensor]) -> tuple | list[torch.Tensor] | dict[str, torch.Tensor]:
384
+ """Return model outputs and mask coefficients if training, otherwise return outputs and mask coefficients."""
385
+ outputs = Detect.forward(self, x)
386
+ preds = outputs[1] if isinstance(outputs, tuple) else outputs
387
+ proto = self.proto(x) # mask protos
388
+ if isinstance(preds, dict): # training and validating during training
389
+ if self.end2end:
390
+ preds["one2many"]["proto"] = proto
391
+ preds["one2one"]["proto"] = (
392
+ tuple(p.detach() for p in proto) if isinstance(proto, tuple) else proto.detach()
393
+ )
394
+ else:
395
+ preds["proto"] = proto
281
396
  if self.training:
282
- return x, mc, p
283
- return (torch.cat([x, mc], 1), p) if self.export else (torch.cat([x[0], mc], 1), (x[1], mc, p))
397
+ return preds
398
+ return (outputs, proto) if self.export else ((outputs[0], proto), preds)
399
+
400
+ def fuse(self) -> None:
401
+ """Remove the one2many head and extra part of proto module for inference optimization."""
402
+ super().fuse()
403
+ if hasattr(self.proto, "fuse"):
404
+ self.proto.fuse()
284
405
 
285
406
 
286
407
  class OBB(Detect):
287
- """
288
- YOLO OBB detection head for detection with rotation models.
408
+ """YOLO OBB detection head for detection with rotation models.
289
409
 
290
410
  This class extends the Detect head to include oriented bounding box prediction with rotation angles.
291
411
 
@@ -305,43 +425,117 @@ class OBB(Detect):
305
425
  >>> outputs = obb(x)
306
426
  """
307
427
 
308
- def __init__(self, nc: int = 80, ne: int = 1, ch: tuple = ()):
309
- """
310
- Initialize OBB with number of classes `nc` and layer channels `ch`.
428
+ def __init__(self, nc: int = 80, ne: int = 1, reg_max=16, end2end=False, ch: tuple = ()):
429
+ """Initialize OBB with number of classes `nc` and layer channels `ch`.
311
430
 
312
431
  Args:
313
432
  nc (int): Number of classes.
314
433
  ne (int): Number of extra parameters.
434
+ reg_max (int): Maximum number of DFL channels.
435
+ end2end (bool): Whether to use end-to-end NMS-free detection.
315
436
  ch (tuple): Tuple of channel sizes from backbone feature maps.
316
437
  """
317
- super().__init__(nc, ch)
438
+ super().__init__(nc, reg_max, end2end, ch)
318
439
  self.ne = ne # number of extra parameters
319
440
 
320
441
  c4 = max(ch[0] // 4, self.ne)
321
442
  self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.ne, 1)) for x in ch)
322
-
323
- def forward(self, x: list[torch.Tensor]) -> torch.Tensor | tuple:
324
- """Concatenate and return predicted bounding boxes and class probabilities."""
325
- bs = x[0].shape[0] # batch size
326
- angle = torch.cat([self.cv4[i](x[i]).view(bs, self.ne, -1) for i in range(self.nl)], 2) # OBB theta logits
327
- # NOTE: set `angle` as an attribute so that `decode_bboxes` could use it.
328
- angle = (angle.sigmoid() - 0.25) * math.pi # [-pi/4, 3pi/4]
329
- # angle = angle.sigmoid() * math.pi / 2 # [0, pi/2]
330
- if not self.training:
331
- self.angle = angle
332
- x = Detect.forward(self, x)
333
- if self.training:
334
- return x, angle
335
- return torch.cat([x, angle], 1) if self.export else (torch.cat([x[0], angle], 1), (x[1], angle))
443
+ if end2end:
444
+ self.one2one_cv4 = copy.deepcopy(self.cv4)
445
+
446
+ @property
447
+ def one2many(self):
448
+ """Returns the one-to-many head components, here for backward compatibility."""
449
+ return dict(box_head=self.cv2, cls_head=self.cv3, angle_head=self.cv4)
450
+
451
+ @property
452
+ def one2one(self):
453
+ """Returns the one-to-one head components."""
454
+ return dict(box_head=self.one2one_cv2, cls_head=self.one2one_cv3, angle_head=self.one2one_cv4)
455
+
456
+ def _inference(self, x: dict[str, torch.Tensor]) -> torch.Tensor:
457
+ """Decode predicted bounding boxes and class probabilities, concatenated with rotation angles."""
458
+ # For decode_bboxes convenience
459
+ self.angle = x["angle"] # TODO: need to test obb
460
+ preds = super()._inference(x)
461
+ return torch.cat([preds, x["angle"]], dim=1)
462
+
463
+ def forward_head(
464
+ self, x: list[torch.Tensor], box_head: torch.nn.Module, cls_head: torch.nn.Module, angle_head: torch.nn.Module
465
+ ) -> torch.Tensor:
466
+ """Concatenates and returns predicted bounding boxes, class probabilities, and angles."""
467
+ preds = super().forward_head(x, box_head, cls_head)
468
+ if angle_head is not None:
469
+ bs = x[0].shape[0] # batch size
470
+ angle = torch.cat(
471
+ [angle_head[i](x[i]).view(bs, self.ne, -1) for i in range(self.nl)], 2
472
+ ) # OBB theta logits
473
+ angle = (angle.sigmoid() - 0.25) * math.pi # [-pi/4, 3pi/4]
474
+ preds["angle"] = angle
475
+ return preds
336
476
 
337
477
  def decode_bboxes(self, bboxes: torch.Tensor, anchors: torch.Tensor) -> torch.Tensor:
338
478
  """Decode rotated bounding boxes."""
339
479
  return dist2rbox(bboxes, self.angle, anchors, dim=1)
340
480
 
481
+ def postprocess(self, preds: torch.Tensor) -> torch.Tensor:
482
+ """Post-process YOLO model predictions.
341
483
 
342
- class Pose(Detect):
484
+ Args:
485
+ preds (torch.Tensor): Raw predictions with shape (batch_size, num_anchors, 4 + nc + ne) with last dimension
486
+ format [x, y, w, h, class_probs, angle].
487
+
488
+ Returns:
489
+ (torch.Tensor): Processed predictions with shape (batch_size, min(max_det, num_anchors), 7) and last
490
+ dimension format [x, y, w, h, max_class_prob, class_index, angle].
491
+ """
492
+ boxes, scores, angle = preds.split([4, self.nc, self.ne], dim=-1)
493
+ scores, conf, idx = self.get_topk_index(scores, self.max_det)
494
+ boxes = boxes.gather(dim=1, index=idx.repeat(1, 1, 4))
495
+ angle = angle.gather(dim=1, index=idx.repeat(1, 1, self.ne))
496
+ return torch.cat([boxes, scores, conf, angle], dim=-1)
497
+
498
+ def fuse(self) -> None:
499
+ """Remove the one2many head for inference optimization."""
500
+ self.cv2 = self.cv3 = self.cv4 = None
501
+
502
+
503
+ class OBB26(OBB):
504
+ """YOLO26 OBB detection head for detection with rotation models. This class extends the OBB head with modified angle
505
+ processing that outputs raw angle predictions without sigmoid transformation, compared to the original
506
+ OBB class.
507
+
508
+ Attributes:
509
+ ne (int): Number of extra parameters.
510
+ cv4 (nn.ModuleList): Convolution layers for angle prediction.
511
+ angle (torch.Tensor): Predicted rotation angles.
512
+
513
+ Methods:
514
+ forward_head: Concatenate and return predicted bounding boxes, class probabilities, and raw angles.
515
+
516
+ Examples:
517
+ Create an OBB26 detection head
518
+ >>> obb26 = OBB26(nc=80, ne=1, ch=(256, 512, 1024))
519
+ >>> x = [torch.randn(1, 256, 80, 80), torch.randn(1, 512, 40, 40), torch.randn(1, 1024, 20, 20)]
520
+ >>> outputs = obb26(x).
343
521
  """
344
- YOLO Pose head for keypoints models.
522
+
523
+ def forward_head(
524
+ self, x: list[torch.Tensor], box_head: torch.nn.Module, cls_head: torch.nn.Module, angle_head: torch.nn.Module
525
+ ) -> torch.Tensor:
526
+ """Concatenates and returns predicted bounding boxes, class probabilities, and raw angles."""
527
+ preds = Detect.forward_head(self, x, box_head, cls_head)
528
+ if angle_head is not None:
529
+ bs = x[0].shape[0] # batch size
530
+ angle = torch.cat(
531
+ [angle_head[i](x[i]).view(bs, self.ne, -1) for i in range(self.nl)], 2
532
+ ) # OBB theta logits (raw output without sigmoid transformation)
533
+ preds["angle"] = angle
534
+ return preds
535
+
536
+
537
+ class Pose(Detect):
538
+ """YOLO Pose head for keypoints models.
345
539
 
346
540
  This class extends the Detect head to include keypoint prediction capabilities for pose estimation tasks.
347
541
 
@@ -361,50 +555,78 @@ class Pose(Detect):
361
555
  >>> outputs = pose(x)
362
556
  """
363
557
 
364
- def __init__(self, nc: int = 80, kpt_shape: tuple = (17, 3), ch: tuple = ()):
365
- """
366
- Initialize YOLO network with default parameters and Convolutional Layers.
558
+ def __init__(self, nc: int = 80, kpt_shape: tuple = (17, 3), reg_max=16, end2end=False, ch: tuple = ()):
559
+ """Initialize YOLO network with default parameters and Convolutional Layers.
367
560
 
368
561
  Args:
369
562
  nc (int): Number of classes.
370
563
  kpt_shape (tuple): Number of keypoints, number of dims (2 for x,y or 3 for x,y,visible).
564
+ reg_max (int): Maximum number of DFL channels.
565
+ end2end (bool): Whether to use end-to-end NMS-free detection.
371
566
  ch (tuple): Tuple of channel sizes from backbone feature maps.
372
567
  """
373
- super().__init__(nc, ch)
568
+ super().__init__(nc, reg_max, end2end, ch)
374
569
  self.kpt_shape = kpt_shape # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible)
375
570
  self.nk = kpt_shape[0] * kpt_shape[1] # number of keypoints total
376
571
 
377
572
  c4 = max(ch[0] // 4, self.nk)
378
573
  self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nk, 1)) for x in ch)
574
+ if end2end:
575
+ self.one2one_cv4 = copy.deepcopy(self.cv4)
576
+
577
+ @property
578
+ def one2many(self):
579
+ """Returns the one-to-many head components, here for backward compatibility."""
580
+ return dict(box_head=self.cv2, cls_head=self.cv3, pose_head=self.cv4)
581
+
582
+ @property
583
+ def one2one(self):
584
+ """Returns the one-to-one head components."""
585
+ return dict(box_head=self.one2one_cv2, cls_head=self.one2one_cv3, pose_head=self.one2one_cv4)
586
+
587
+ def _inference(self, x: dict[str, torch.Tensor]) -> torch.Tensor:
588
+ """Decode predicted bounding boxes and class probabilities, concatenated with keypoints."""
589
+ preds = super()._inference(x)
590
+ return torch.cat([preds, self.kpts_decode(x["kpts"])], dim=1)
591
+
592
+ def forward_head(
593
+ self, x: list[torch.Tensor], box_head: torch.nn.Module, cls_head: torch.nn.Module, pose_head: torch.nn.Module
594
+ ) -> torch.Tensor:
595
+ """Concatenates and returns predicted bounding boxes, class probabilities, and keypoints."""
596
+ preds = super().forward_head(x, box_head, cls_head)
597
+ if pose_head is not None:
598
+ bs = x[0].shape[0] # batch size
599
+ preds["kpts"] = torch.cat([pose_head[i](x[i]).view(bs, self.nk, -1) for i in range(self.nl)], 2)
600
+ return preds
601
+
602
+ def postprocess(self, preds: torch.Tensor) -> torch.Tensor:
603
+ """Post-process YOLO model predictions.
379
604
 
380
- def forward(self, x: list[torch.Tensor]) -> torch.Tensor | tuple:
381
- """Perform forward pass through YOLO model and return predictions."""
382
- bs = x[0].shape[0] # batch size
383
- kpt = torch.cat([self.cv4[i](x[i]).view(bs, self.nk, -1) for i in range(self.nl)], -1) # (bs, 17*3, h*w)
384
- x = Detect.forward(self, x)
385
- if self.training:
386
- return x, kpt
387
- pred_kpt = self.kpts_decode(bs, kpt)
388
- return torch.cat([x, pred_kpt], 1) if self.export else (torch.cat([x[0], pred_kpt], 1), (x[1], kpt))
605
+ Args:
606
+ preds (torch.Tensor): Raw predictions with shape (batch_size, num_anchors, 4 + nc + nk) with last dimension
607
+ format [x, y, w, h, class_probs, keypoints].
389
608
 
390
- def kpts_decode(self, bs: int, kpts: torch.Tensor) -> torch.Tensor:
609
+ Returns:
610
+ (torch.Tensor): Processed predictions with shape (batch_size, min(max_det, num_anchors), 6 + self.nk) and
611
+ last dimension format [x, y, w, h, max_class_prob, class_index, keypoints].
612
+ """
613
+ boxes, scores, kpts = preds.split([4, self.nc, self.nk], dim=-1)
614
+ scores, conf, idx = self.get_topk_index(scores, self.max_det)
615
+ boxes = boxes.gather(dim=1, index=idx.repeat(1, 1, 4))
616
+ kpts = kpts.gather(dim=1, index=idx.repeat(1, 1, self.nk))
617
+ return torch.cat([boxes, scores, conf, kpts], dim=-1)
618
+
619
+ def fuse(self) -> None:
620
+ """Remove the one2many head for inference optimization."""
621
+ self.cv2 = self.cv3 = self.cv4 = None
622
+
623
+ def kpts_decode(self, kpts: torch.Tensor) -> torch.Tensor:
391
624
  """Decode keypoints from predictions."""
392
625
  ndim = self.kpt_shape[1]
626
+ bs = kpts.shape[0]
393
627
  if self.export:
394
- if self.format in {
395
- "tflite",
396
- "edgetpu",
397
- }: # required for TFLite export to avoid 'PLACEHOLDER_FOR_GREATER_OP_CODES' bug
398
- # Precompute normalization factor to increase numerical stability
399
- y = kpts.view(bs, *self.kpt_shape, -1)
400
- grid_h, grid_w = self.shape[2], self.shape[3]
401
- grid_size = torch.tensor([grid_w, grid_h], device=y.device).reshape(1, 2, 1)
402
- norm = self.strides / (self.stride[0] * grid_size)
403
- a = (y[:, :, :2] * 2.0 + (self.anchors - 0.5)) * norm
404
- else:
405
- # NCNN fix
406
- y = kpts.view(bs, *self.kpt_shape, -1)
407
- a = (y[:, :, :2] * 2.0 + (self.anchors - 0.5)) * self.strides
628
+ y = kpts.view(bs, *self.kpt_shape, -1)
629
+ a = (y[:, :, :2] * 2.0 + (self.anchors - 0.5)) * self.strides
408
630
  if ndim == 3:
409
631
  a = torch.cat((a, y[:, :, 2:3].sigmoid()), 2)
410
632
  return a.view(bs, self.nk, -1)
@@ -420,9 +642,125 @@ class Pose(Detect):
420
642
  return y
421
643
 
422
644
 
423
- class Classify(nn.Module):
645
+ class Pose26(Pose):
646
+ """YOLO26 Pose head for keypoints models.
647
+
648
+ This class extends the Detect head to include keypoint prediction capabilities for pose estimation tasks.
649
+
650
+ Attributes:
651
+ kpt_shape (tuple): Number of keypoints and dimensions (2 for x,y or 3 for x,y,visible).
652
+ nk (int): Total number of keypoint values.
653
+ cv4 (nn.ModuleList): Convolution layers for keypoint prediction.
654
+
655
+ Methods:
656
+ forward: Perform forward pass through YOLO model and return predictions.
657
+ kpts_decode: Decode keypoints from predictions.
658
+
659
+ Examples:
660
+ Create a pose detection head
661
+ >>> pose = Pose(nc=80, kpt_shape=(17, 3), ch=(256, 512, 1024))
662
+ >>> x = [torch.randn(1, 256, 80, 80), torch.randn(1, 512, 40, 40), torch.randn(1, 1024, 20, 20)]
663
+ >>> outputs = pose(x)
424
664
  """
425
- YOLO classification head, i.e. x(b,c1,20,20) to x(b,c2).
665
+
666
+ def __init__(self, nc: int = 80, kpt_shape: tuple = (17, 3), reg_max=16, end2end=False, ch: tuple = ()):
667
+ """Initialize YOLO network with default parameters and Convolutional Layers.
668
+
669
+ Args:
670
+ nc (int): Number of classes.
671
+ kpt_shape (tuple): Number of keypoints, number of dims (2 for x,y or 3 for x,y,visible).
672
+ reg_max (int): Maximum number of DFL channels.
673
+ end2end (bool): Whether to use end-to-end NMS-free detection.
674
+ ch (tuple): Tuple of channel sizes from backbone feature maps.
675
+ """
676
+ super().__init__(nc, kpt_shape, reg_max, end2end, ch)
677
+ self.flow_model = RealNVP()
678
+
679
+ c4 = max(ch[0] // 4, kpt_shape[0] * (kpt_shape[1] + 2))
680
+ self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3)) for x in ch)
681
+
682
+ self.cv4_kpts = nn.ModuleList(nn.Conv2d(c4, self.nk, 1) for _ in ch)
683
+ self.nk_sigma = kpt_shape[0] * 2 # sigma_x, sigma_y for each keypoint
684
+ self.cv4_sigma = nn.ModuleList(nn.Conv2d(c4, self.nk_sigma, 1) for _ in ch)
685
+
686
+ if end2end:
687
+ self.one2one_cv4 = copy.deepcopy(self.cv4)
688
+ self.one2one_cv4_kpts = copy.deepcopy(self.cv4_kpts)
689
+ self.one2one_cv4_sigma = copy.deepcopy(self.cv4_sigma)
690
+
691
+ @property
692
+ def one2many(self):
693
+ """Returns the one-to-many head components, here for backward compatibility."""
694
+ return dict(
695
+ box_head=self.cv2,
696
+ cls_head=self.cv3,
697
+ pose_head=self.cv4,
698
+ kpts_head=self.cv4_kpts,
699
+ kpts_sigma_head=self.cv4_sigma,
700
+ )
701
+
702
+ @property
703
+ def one2one(self):
704
+ """Returns the one-to-one head components."""
705
+ return dict(
706
+ box_head=self.one2one_cv2,
707
+ cls_head=self.one2one_cv3,
708
+ pose_head=self.one2one_cv4,
709
+ kpts_head=self.one2one_cv4_kpts,
710
+ kpts_sigma_head=self.one2one_cv4_sigma,
711
+ )
712
+
713
+ def forward_head(
714
+ self,
715
+ x: list[torch.Tensor],
716
+ box_head: torch.nn.Module,
717
+ cls_head: torch.nn.Module,
718
+ pose_head: torch.nn.Module,
719
+ kpts_head: torch.nn.Module,
720
+ kpts_sigma_head: torch.nn.Module,
721
+ ) -> torch.Tensor:
722
+ """Concatenates and returns predicted bounding boxes, class probabilities, and keypoints."""
723
+ preds = Detect.forward_head(self, x, box_head, cls_head)
724
+ if pose_head is not None:
725
+ bs = x[0].shape[0] # batch size
726
+ features = [pose_head[i](x[i]) for i in range(self.nl)]
727
+ preds["kpts"] = torch.cat([kpts_head[i](features[i]).view(bs, self.nk, -1) for i in range(self.nl)], 2)
728
+ if self.training:
729
+ preds["kpts_sigma"] = torch.cat(
730
+ [kpts_sigma_head[i](features[i]).view(bs, self.nk_sigma, -1) for i in range(self.nl)], 2
731
+ )
732
+ return preds
733
+
734
+ def fuse(self) -> None:
735
+ """Remove the one2many head for inference optimization."""
736
+ super().fuse()
737
+ self.cv4_kpts = self.cv4_sigma = self.flow_model = self.one2one_cv4_sigma = None
738
+
739
+ def kpts_decode(self, kpts: torch.Tensor) -> torch.Tensor:
740
+ """Decode keypoints from predictions."""
741
+ ndim = self.kpt_shape[1]
742
+ bs = kpts.shape[0]
743
+ if self.export:
744
+ y = kpts.view(bs, *self.kpt_shape, -1)
745
+ # NCNN fix
746
+ a = (y[:, :, :2] + self.anchors) * self.strides
747
+ if ndim == 3:
748
+ a = torch.cat((a, y[:, :, 2:3].sigmoid()), 2)
749
+ return a.view(bs, self.nk, -1)
750
+ else:
751
+ y = kpts.clone()
752
+ if ndim == 3:
753
+ if NOT_MACOS14:
754
+ y[:, 2::ndim].sigmoid_()
755
+ else: # Apple macOS14 MPS bug https://github.com/ultralytics/ultralytics/pull/21878
756
+ y[:, 2::ndim] = y[:, 2::ndim].sigmoid()
757
+ y[:, 0::ndim] = (y[:, 0::ndim] + self.anchors[0]) * self.strides
758
+ y[:, 1::ndim] = (y[:, 1::ndim] + self.anchors[1]) * self.strides
759
+ return y
760
+
761
+
762
+ class Classify(nn.Module):
763
+ """YOLO classification head, i.e. x(b,c1,20,20) to x(b,c2).
426
764
 
427
765
  This class implements a classification head that transforms feature maps into class predictions.
428
766
 
@@ -446,8 +784,7 @@ class Classify(nn.Module):
446
784
  export = False # export mode
447
785
 
448
786
  def __init__(self, c1: int, c2: int, k: int = 1, s: int = 1, p: int | None = None, g: int = 1):
449
- """
450
- Initialize YOLO classification head to transform input tensor from (b,c1,20,20) to (b,c2) shape.
787
+ """Initialize YOLO classification head to transform input tensor from (b,c1,20,20) to (b,c2) shape.
451
788
 
452
789
  Args:
453
790
  c1 (int): Number of input channels.
@@ -476,11 +813,10 @@ class Classify(nn.Module):
476
813
 
477
814
 
478
815
  class WorldDetect(Detect):
479
- """
480
- Head for integrating YOLO detection models with semantic understanding from text embeddings.
816
+ """Head for integrating YOLO detection models with semantic understanding from text embeddings.
481
817
 
482
- This class extends the standard Detect head to incorporate text embeddings for enhanced semantic understanding
483
- in object detection tasks.
818
+ This class extends the standard Detect head to incorporate text embeddings for enhanced semantic understanding in
819
+ object detection tasks.
484
820
 
485
821
  Attributes:
486
822
  cv3 (nn.ModuleList): Convolution layers for embedding features.
@@ -498,30 +834,44 @@ class WorldDetect(Detect):
498
834
  >>> outputs = world_detect(x, text)
499
835
  """
500
836
 
501
- def __init__(self, nc: int = 80, embed: int = 512, with_bn: bool = False, ch: tuple = ()):
502
- """
503
- Initialize YOLO detection layer with nc classes and layer channels ch.
837
+ def __init__(
838
+ self,
839
+ nc: int = 80,
840
+ embed: int = 512,
841
+ with_bn: bool = False,
842
+ reg_max: int = 16,
843
+ end2end: bool = False,
844
+ ch: tuple = (),
845
+ ):
846
+ """Initialize YOLO detection layer with nc classes and layer channels ch.
504
847
 
505
848
  Args:
506
849
  nc (int): Number of classes.
507
850
  embed (int): Embedding dimension.
508
851
  with_bn (bool): Whether to use batch normalization in contrastive head.
852
+ reg_max (int): Maximum number of DFL channels.
853
+ end2end (bool): Whether to use end-to-end NMS-free detection.
509
854
  ch (tuple): Tuple of channel sizes from backbone feature maps.
510
855
  """
511
- super().__init__(nc, ch)
856
+ super().__init__(nc, reg_max=reg_max, end2end=end2end, ch=ch)
512
857
  c3 = max(ch[0], min(self.nc, 100))
513
858
  self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, embed, 1)) for x in ch)
514
859
  self.cv4 = nn.ModuleList(BNContrastiveHead(embed) if with_bn else ContrastiveHead() for _ in ch)
515
860
 
516
- def forward(self, x: list[torch.Tensor], text: torch.Tensor) -> list[torch.Tensor] | tuple:
861
+ def forward(self, x: list[torch.Tensor], text: torch.Tensor) -> dict[str, torch.Tensor] | tuple:
517
862
  """Concatenate and return predicted bounding boxes and class probabilities."""
863
+ feats = [xi.clone() for xi in x] # save original features for anchor generation
518
864
  for i in range(self.nl):
519
865
  x[i] = torch.cat((self.cv2[i](x[i]), self.cv4[i](self.cv3[i](x[i]), text)), 1)
520
- if self.training:
521
- return x
522
866
  self.no = self.nc + self.reg_max * 4 # self.nc could be changed when inference with different texts
523
- y = self._inference(x)
524
- return y if self.export else (y, x)
867
+ bs = x[0].shape[0]
868
+ x_cat = torch.cat([xi.view(bs, self.no, -1) for xi in x], 2)
869
+ boxes, scores = x_cat.split((self.reg_max * 4, self.nc), 1)
870
+ preds = dict(boxes=boxes, scores=scores, feats=feats)
871
+ if self.training:
872
+ return preds
873
+ y = self._inference(preds)
874
+ return y if self.export else (y, preds)
525
875
 
526
876
  def bias_init(self):
527
877
  """Initialize Detect() biases, WARNING: requires stride availability."""
@@ -534,11 +884,10 @@ class WorldDetect(Detect):
534
884
 
535
885
 
536
886
  class LRPCHead(nn.Module):
537
- """
538
- Lightweight Region Proposal and Classification Head for efficient object detection.
887
+ """Lightweight Region Proposal and Classification Head for efficient object detection.
539
888
 
540
- This head combines region proposal filtering with classification to enable efficient detection with
541
- dynamic vocabulary support.
889
+ This head combines region proposal filtering with classification to enable efficient detection with dynamic
890
+ vocabulary support.
542
891
 
543
892
  Attributes:
544
893
  vocab (nn.Module): Vocabulary/classification layer.
@@ -559,8 +908,7 @@ class LRPCHead(nn.Module):
559
908
  """
560
909
 
561
910
  def __init__(self, vocab: nn.Module, pf: nn.Module, loc: nn.Module, enabled: bool = True):
562
- """
563
- Initialize LRPCHead with vocabulary, proposal filter, and localization components.
911
+ """Initialize LRPCHead with vocabulary, proposal filter, and localization components.
564
912
 
565
913
  Args:
566
914
  vocab (nn.Module): Vocabulary/classification module.
@@ -574,7 +922,8 @@ class LRPCHead(nn.Module):
574
922
  self.loc = loc
575
923
  self.enabled = enabled
576
924
 
577
- def conv2linear(self, conv: nn.Conv2d) -> nn.Linear:
925
+ @staticmethod
926
+ def conv2linear(conv: nn.Conv2d) -> nn.Linear:
578
927
  """Convert a 1x1 convolutional layer to a linear layer."""
579
928
  assert isinstance(conv, nn.Conv2d) and conv.kernel_size == (1, 1)
580
929
  linear = nn.Linear(conv.in_channels, conv.out_channels)
@@ -589,18 +938,19 @@ class LRPCHead(nn.Module):
  mask = pf_score.sigmoid() > conf
  cls_feat = cls_feat.flatten(2).transpose(-1, -2)
  cls_feat = self.vocab(cls_feat[:, mask] if conf else cls_feat * mask.unsqueeze(-1).int())
- return (self.loc(loc_feat), cls_feat.transpose(-1, -2)), mask
+ return self.loc(loc_feat), cls_feat.transpose(-1, -2), mask
  else:
  cls_feat = self.vocab(cls_feat)
  loc_feat = self.loc(loc_feat)
- return (loc_feat, cls_feat.flatten(2)), torch.ones(
- cls_feat.shape[2] * cls_feat.shape[3], device=cls_feat.device, dtype=torch.bool
+ return (
+ loc_feat,
+ cls_feat.flatten(2),
+ torch.ones(cls_feat.shape[2] * cls_feat.shape[3], device=cls_feat.device, dtype=torch.bool),
  )


  class YOLOEDetect(Detect):
- """
- Head for integrating YOLO detection models with semantic understanding from text embeddings.
+ """Head for integrating YOLO detection models with semantic understanding from text embeddings.

  This class extends the standard Detect head to support text-guided detection with enhanced semantic understanding
  through text embeddings and visual prompt embeddings.
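
The new three-element return of `LRPCHead.forward` above separates the localization output, the vocabulary scores, and the keep-mask. A hedged sketch of the underlying proposal-filtering idea, with simplified shapes and a plain linear layer standing in for the fused vocabulary head:

```python
import torch
import torch.nn as nn

hw, c, vocab_size = 400, 256, 1203          # one 20x20 level, flattened
cls_feat = torch.randn(1, hw, c)            # (batch, positions, channels)
pf_score = torch.randn(hw)                  # objectness-style proposal-filter logits
vocab = nn.Linear(c, vocab_size)            # stands in for the fused vocabulary classifier

mask = pf_score.sigmoid() > 0.001           # keep only confident proposals
kept_scores = vocab(cls_feat[:, mask])      # classify the surviving positions only
print(int(mask.sum()), kept_scores.shape)   # far fewer than 400 positions hit the large classifier
```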
@@ -631,17 +981,20 @@ class YOLOEDetect(Detect):

  is_fused = False

- def __init__(self, nc: int = 80, embed: int = 512, with_bn: bool = False, ch: tuple = ()):
- """
- Initialize YOLO detection layer with nc classes and layer channels ch.
+ def __init__(
+ self, nc: int = 80, embed: int = 512, with_bn: bool = False, reg_max=16, end2end=False, ch: tuple = ()
+ ):
+ """Initialize YOLO detection layer with nc classes and layer channels ch.

  Args:
  nc (int): Number of classes.
  embed (int): Embedding dimension.
  with_bn (bool): Whether to use batch normalization in contrastive head.
+ reg_max (int): Maximum number of DFL channels.
+ end2end (bool): Whether to use end-to-end NMS-free detection.
  ch (tuple): Tuple of channel sizes from backbone feature maps.
  """
- super().__init__(nc, ch)
+ super().__init__(nc, reg_max, end2end, ch)
  c3 = max(ch[0], min(self.nc, 100))
  assert c3 <= embed
  assert with_bn
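
With the new signature, the text-aware head accepts the extra `reg_max`/`end2end` arguments directly. A hedged construction example; the import path and channel sizes are assumptions for illustration, not taken from this diff:

```python
# Assumes the classes live in ultralytics.nn.modules.head, as in previous releases.
from ultralytics.nn.modules.head import YOLOEDetect

head = YOLOEDetect(
    nc=80,               # number of classes
    embed=512,           # text/visual embedding width
    with_bn=True,        # required: the head asserts with_bn
    reg_max=16,          # DFL channels per box side
    end2end=True,        # builds the extra one2one_cv3/one2one_cv4 branches
    ch=(256, 512, 512),  # per-level channels; the head asserts max(ch[0], min(nc, 100)) <= embed
)
print(sum(p.numel() for p in head.parameters()))
```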
@@ -657,29 +1010,43 @@ class YOLOEDetect(Detect):
  for x in ch
  )
  )
-
  self.cv4 = nn.ModuleList(BNContrastiveHead(embed) if with_bn else ContrastiveHead() for _ in ch)
+ if end2end:
+ self.one2one_cv3 = copy.deepcopy(self.cv3) # overwrite with new cv3
+ self.one2one_cv4 = copy.deepcopy(self.cv4)

  self.reprta = Residual(SwiGLUFFN(embed, embed))
  self.savpe = SAVPE(ch, c3, embed)
  self.embed = embed

  @smart_inference_mode()
- def fuse(self, txt_feats: torch.Tensor):
+ def fuse(self, txt_feats: torch.Tensor = None):
  """Fuse text features with model weights for efficient inference."""
+ if txt_feats is None: # means eliminate one2many branch
+ self.cv2 = self.cv3 = self.cv4 = None
+ return
  if self.is_fused:
  return

  assert not self.training
  txt_feats = txt_feats.to(torch.float32).squeeze(0)
- for cls_head, bn_head in zip(self.cv3, self.cv4):
- assert isinstance(cls_head, nn.Sequential)
- assert isinstance(bn_head, BNContrastiveHead)
- conv = cls_head[-1]
+ self._fuse_tp(txt_feats, self.cv3, self.cv4)
+ if self.end2end:
+ self._fuse_tp(txt_feats, self.one2one_cv3, self.one2one_cv4)
+ del self.reprta
+ self.reprta = nn.Identity()
+ self.is_fused = True
+
+ def _fuse_tp(self, txt_feats: torch.Tensor, cls_head: torch.nn.Module, bn_head: torch.nn.Module) -> None:
+ """Fuse text prompt embeddings with model weights for efficient inference."""
+ for cls_h, bn_h in zip(cls_head, bn_head):
+ assert isinstance(cls_h, nn.Sequential)
+ assert isinstance(bn_h, BNContrastiveHead)
+ conv = cls_h[-1]
  assert isinstance(conv, nn.Conv2d)
- logit_scale = bn_head.logit_scale
- bias = bn_head.bias
- norm = bn_head.norm
+ logit_scale = bn_h.logit_scale
+ bias = bn_h.bias
+ norm = bn_h.norm

  t = txt_feats * logit_scale.exp()
  conv: nn.Conv2d = fuse_conv_and_bn(conv, norm)
@@ -703,13 +1070,9 @@ class YOLOEDetect(Detect):

  conv.weight.data.copy_(w.unsqueeze(-1).unsqueeze(-1))
  conv.bias.data.copy_(b1 + b2)
- cls_head[-1] = conv
+ cls_h[-1] = conv

- bn_head.fuse()
-
- del self.reprta
- self.reprta = nn.Identity()
- self.is_fused = True
+ bn_h.fuse()

  def get_tpe(self, tpe: torch.Tensor | None) -> torch.Tensor | None:
  """Get text prompt embeddings with normalization."""
@@ -724,74 +1087,89 @@ class YOLOEDetect(Detect):
  assert vpe.ndim == 3 # (B, N, D)
  return vpe

- def forward_lrpc(self, x: list[torch.Tensor], return_mask: bool = False) -> torch.Tensor | tuple:
+ def forward(self, x: list[torch.Tensor]) -> torch.Tensor | tuple:
+ """Process features with class prompt embeddings to generate detections."""
+ if hasattr(self, "lrpc"): # for prompt-free inference
+ return self.forward_lrpc(x[:3])
+ return super().forward(x)
+
+ def forward_lrpc(self, x: list[torch.Tensor]) -> torch.Tensor | tuple:
  """Process features with fused text embeddings to generate detections for prompt-free model."""
- masks = []
- assert self.is_fused, "Prompt-free inference requires model to be fused!"
+ boxes, scores, index = [], [], []
+ bs = x[0].shape[0]
+ cv2 = self.cv2 if not self.end2end else self.one2one_cv2
+ cv3 = self.cv3 if not self.end2end else self.one2one_cv3
  for i in range(self.nl):
- cls_feat = self.cv3[i](x[i])
- loc_feat = self.cv2[i](x[i])
+ cls_feat = cv3[i](x[i])
+ loc_feat = cv2[i](x[i])
  assert isinstance(self.lrpc[i], LRPCHead)
- x[i], mask = self.lrpc[i](
- cls_feat, loc_feat, 0 if self.export and not self.dynamic else getattr(self, "conf", 0.001)
+ box, score, idx = self.lrpc[i](
+ cls_feat,
+ loc_feat,
+ 0 if self.export and not self.dynamic else getattr(self, "conf", 0.001),
  )
- masks.append(mask)
- shape = x[0][0].shape
- if self.dynamic or self.shape != shape:
- self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors([b[0] for b in x], self.stride, 0.5))
- self.shape = shape
- box = torch.cat([xi[0].view(shape[0], self.reg_max * 4, -1) for xi in x], 2)
- cls = torch.cat([xi[1] for xi in x], 2)
-
- if self.export and self.format in {"tflite", "edgetpu"}:
- # Precompute normalization factor to increase numerical stability
- # See https://github.com/ultralytics/ultralytics/issues/7371
- grid_h = shape[2]
- grid_w = shape[3]
- grid_size = torch.tensor([grid_w, grid_h, grid_w, grid_h], device=box.device).reshape(1, 4, 1)
- norm = self.strides / (self.stride[0] * grid_size)
- dbox = self.decode_bboxes(self.dfl(box) * norm, self.anchors.unsqueeze(0) * norm[:, :2])
- else:
- dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides
-
- mask = torch.cat(masks)
- y = torch.cat((dbox if self.export and not self.dynamic else dbox[..., mask], cls.sigmoid()), 1)
-
- if return_mask:
- return (y, mask) if self.export else ((y, x), mask)
- else:
- return y if self.export else (y, x)
-
- def forward(self, x: list[torch.Tensor], cls_pe: torch.Tensor, return_mask: bool = False) -> torch.Tensor | tuple:
- """Process features with class prompt embeddings to generate detections."""
- if hasattr(self, "lrpc"): # for prompt-free inference
- return self.forward_lrpc(x, return_mask)
- for i in range(self.nl):
- x[i] = torch.cat((self.cv2[i](x[i]), self.cv4[i](self.cv3[i](x[i]), cls_pe)), 1)
- if self.training:
- return x
+ boxes.append(box.view(bs, self.reg_max * 4, -1))
+ scores.append(score)
+ index.append(idx)
+ preds = dict(boxes=torch.cat(boxes, 2), scores=torch.cat(scores, 2), feats=x, index=torch.cat(index))
+ y = self._inference(preds)
+ if self.end2end:
+ y = self.postprocess(y.permute(0, 2, 1))
+ return y if self.export else (y, preds)
+
+ def _get_decode_boxes(self, x):
+ """Decode predicted bounding boxes for inference."""
+ dbox = super()._get_decode_boxes(x)
+ if hasattr(self, "lrpc"):
+ dbox = dbox if self.export and not self.dynamic else dbox[..., x["index"]]
+ return dbox
+
+ @property
+ def one2many(self):
+ """Returns the one-to-many head components, here for v5/v8/v9/11 backward compatibility."""
+ return dict(box_head=self.cv2, cls_head=self.cv3, contrastive_head=self.cv4)
+
+ @property
+ def one2one(self):
+ """Returns the one-to-one head components."""
+ return dict(box_head=self.one2one_cv2, cls_head=self.one2one_cv3, contrastive_head=self.one2one_cv4)
+
+ def forward_head(self, x, box_head, cls_head, contrastive_head):
+ """Concatenates and returns predicted bounding boxes, class probabilities, and text embeddings."""
+ assert len(x) == 4, f"Expected 4 features including 3 feature maps and 1 text embeddings, but got {len(x)}."
+ if box_head is None or cls_head is None: # for fused inference
+ return dict()
+ bs = x[0].shape[0] # batch size
+ boxes = torch.cat([box_head[i](x[i]).view(bs, 4 * self.reg_max, -1) for i in range(self.nl)], dim=-1)
+ self.nc = x[-1].shape[1]
+ scores = torch.cat(
+ [contrastive_head[i](cls_head[i](x[i]), x[-1]).reshape(bs, self.nc, -1) for i in range(self.nl)], dim=-1
+ )
  self.no = self.nc + self.reg_max * 4 # self.nc could be changed when inference with different texts
- y = self._inference(x)
- return y if self.export else (y, x)
+ return dict(boxes=boxes, scores=scores, feats=x[:3])

  def bias_init(self):
- """Initialize biases for detection heads."""
- m = self # self.model[-1] # Detect() module
- # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
- # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency
- for a, b, c, s in zip(m.cv2, m.cv3, m.cv4, m.stride): # from
- a[-1].bias.data[:] = 1.0 # box
- # b[-1].bias.data[:] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 objects, 80 classes, 640 img)
+ """Initialize Detect() biases, WARNING: requires stride availability."""
+ for i, (a, b, c) in enumerate(
+ zip(self.one2many["box_head"], self.one2many["cls_head"], self.one2many["contrastive_head"])
+ ):
+ a[-1].bias.data[:] = 2.0 # box
  b[-1].bias.data[:] = 0.0
- c.bias.data[:] = math.log(5 / m.nc / (640 / s) ** 2)
+ c.bias.data[:] = math.log(5 / self.nc / (640 / self.stride[i]) ** 2)
+ if self.end2end:
+ for i, (a, b, c) in enumerate(
+ zip(self.one2one["box_head"], self.one2one["cls_head"], self.one2one["contrastive_head"])
+ ):
+ a[-1].bias.data[:] = 2.0 # box
+ b[-1].bias.data[:] = 0.0
+ c.bias.data[:] = math.log(5 / self.nc / (640 / self.stride[i]) ** 2)


  class YOLOESegment(YOLOEDetect):
- """
- YOLO segmentation head with text embedding capabilities.
+ """YOLO segmentation head with text embedding capabilities.

- This class extends YOLOEDetect to include mask prediction capabilities for instance segmentation tasks
- with text-guided semantic understanding.
+ This class extends YOLOEDetect to include mask prediction capabilities for instance segmentation tasks with
+ text-guided semantic understanding.

  Attributes:
  nm (int): Number of masks.
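
The rewritten `bias_init` above seeds the contrastive-head biases with the familiar objectness prior. A quick illustration of what `math.log(5 / nc / (640 / stride) ** 2)` evaluates to per level, assuming the usual strides of 8/16/32:

```python
import math

nc = 80
for stride in (8, 16, 32):
    prior = math.log(5 / nc / (640 / stride) ** 2)
    print(f"stride {stride:>2}: initial cls bias {prior:6.2f}")
```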
@@ -811,10 +1189,17 @@ class YOLOESegment(YOLOEDetect):
  """

  def __init__(
- self, nc: int = 80, nm: int = 32, npr: int = 256, embed: int = 512, with_bn: bool = False, ch: tuple = ()
+ self,
+ nc: int = 80,
+ nm: int = 32,
+ npr: int = 256,
+ embed: int = 512,
+ with_bn: bool = False,
+ reg_max=16,
+ end2end=False,
+ ch: tuple = (),
  ):
- """
- Initialize YOLOESegment with class count, mask parameters, and embedding dimensions.
+ """Initialize YOLOESegment with class count, mask parameters, and embedding dimensions.

  Args:
  nc (int): Number of classes.
@@ -822,41 +1207,195 @@ class YOLOESegment(YOLOEDetect):
  npr (int): Number of protos.
  embed (int): Embedding dimension.
  with_bn (bool): Whether to use batch normalization in contrastive head.
+ reg_max (int): Maximum number of DFL channels.
+ end2end (bool): Whether to use end-to-end NMS-free detection.
  ch (tuple): Tuple of channel sizes from backbone feature maps.
  """
- super().__init__(nc, embed, with_bn, ch)
+ super().__init__(nc, embed, with_bn, reg_max, end2end, ch)
  self.nm = nm
  self.npr = npr
  self.proto = Proto(ch[0], self.npr, self.nm)

  c5 = max(ch[0] // 4, self.nm)
  self.cv5 = nn.ModuleList(nn.Sequential(Conv(x, c5, 3), Conv(c5, c5, 3), nn.Conv2d(c5, self.nm, 1)) for x in ch)
+ if end2end:
+ self.one2one_cv5 = copy.deepcopy(self.cv5)
+
+ @property
+ def one2many(self):
+ """Returns the one-to-many head components, here for v5/v8/v9/11 backward compatibility."""
+ return dict(box_head=self.cv2, cls_head=self.cv3, mask_head=self.cv5, contrastive_head=self.cv4)
+
+ @property
+ def one2one(self):
+ """Returns the one-to-one head components."""
+ return dict(
+ box_head=self.one2one_cv2,
+ cls_head=self.one2one_cv3,
+ mask_head=self.one2one_cv5,
+ contrastive_head=self.one2one_cv4,
+ )
+
+ def forward_lrpc(self, x: list[torch.Tensor]) -> torch.Tensor | tuple:
+ """Process features with fused text embeddings to generate detections for prompt-free model."""
+ boxes, scores, index = [], [], []
+ bs = x[0].shape[0]
+ cv2 = self.cv2 if not self.end2end else self.one2one_cv2
+ cv3 = self.cv3 if not self.end2end else self.one2one_cv3
+ cv5 = self.cv5 if not self.end2end else self.one2one_cv5
+ for i in range(self.nl):
+ cls_feat = cv3[i](x[i])
+ loc_feat = cv2[i](x[i])
+ assert isinstance(self.lrpc[i], LRPCHead)
+ box, score, idx = self.lrpc[i](
+ cls_feat,
+ loc_feat,
+ 0 if self.export and not self.dynamic else getattr(self, "conf", 0.001),
+ )
+ boxes.append(box.view(bs, self.reg_max * 4, -1))
+ scores.append(score)
+ index.append(idx)
+ mc = torch.cat([cv5[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2)
+ index = torch.cat(index)
+ preds = dict(
+ boxes=torch.cat(boxes, 2),
+ scores=torch.cat(scores, 2),
+ feats=x,
+ index=index,
+ mask_coefficient=mc * index.int() if self.export and not self.dynamic else mc[..., index],
+ )
+ y = self._inference(preds)
+ if self.end2end:
+ y = self.postprocess(y.permute(0, 2, 1))
+ return y if self.export else (y, preds)

- def forward(self, x: list[torch.Tensor], text: torch.Tensor) -> tuple | torch.Tensor:
+ def forward(self, x: list[torch.Tensor]) -> tuple | list[torch.Tensor] | dict[str, torch.Tensor]:
  """Return model outputs and mask coefficients if training, otherwise return outputs and mask coefficients."""
- p = self.proto(x[0]) # mask protos
- bs = p.shape[0] # batch size
+ outputs = super().forward(x)
+ preds = outputs[1] if isinstance(outputs, tuple) else outputs
+ proto = self.proto(x[0]) # mask protos
+ if isinstance(preds, dict): # training and validating during training
+ if self.end2end:
+ preds["one2many"]["proto"] = proto
+ preds["one2one"]["proto"] = proto.detach()
+ else:
+ preds["proto"] = proto
+ if self.training:
+ return preds
+ return (outputs, proto) if self.export else ((outputs[0], proto), preds)

- mc = torch.cat([self.cv5[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2) # mask coefficients
- has_lrpc = hasattr(self, "lrpc")
+ def _inference(self, x: dict[str, torch.Tensor]) -> torch.Tensor:
+ """Decode predicted bounding boxes and class probabilities, concatenated with mask coefficients."""
+ preds = super()._inference(x)
+ return torch.cat([preds, x["mask_coefficient"]], dim=1)

- if not has_lrpc:
- x = YOLOEDetect.forward(self, x, text)
- else:
- x, mask = YOLOEDetect.forward(self, x, text, return_mask=True)
+ def forward_head(
+ self,
+ x: list[torch.Tensor],
+ box_head: torch.nn.Module,
+ cls_head: torch.nn.Module,
+ mask_head: torch.nn.Module,
+ contrastive_head: torch.nn.Module,
+ ) -> torch.Tensor:
+ """Concatenates and returns predicted bounding boxes, class probabilities, and mask coefficients."""
+ preds = super().forward_head(x, box_head, cls_head, contrastive_head)
+ if mask_head is not None:
+ bs = x[0].shape[0] # batch size
+ preds["mask_coefficient"] = torch.cat([mask_head[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2)
+ return preds
+
+ def postprocess(self, preds: torch.Tensor) -> torch.Tensor:
+ """Post-process YOLO model predictions.

- if self.training:
- return x, mc, p
+ Args:
+ preds (torch.Tensor): Raw predictions with shape (batch_size, num_anchors, 4 + nc + nm) with last dimension
+ format [x, y, w, h, class_probs, mask_coefficient].

- if has_lrpc:
- mc = (mc * mask.int()) if self.export and not self.dynamic else mc[..., mask]
+ Returns:
+ (torch.Tensor): Processed predictions with shape (batch_size, min(max_det, num_anchors), 6 + nm) and last
+ dimension format [x, y, w, h, max_class_prob, class_index, mask_coefficient].
+ """
+ boxes, scores, mask_coefficient = preds.split([4, self.nc, self.nm], dim=-1)
+ scores, conf, idx = self.get_topk_index(scores, self.max_det)
+ boxes = boxes.gather(dim=1, index=idx.repeat(1, 1, 4))
+ mask_coefficient = mask_coefficient.gather(dim=1, index=idx.repeat(1, 1, self.nm))
+ return torch.cat([boxes, scores, conf, mask_coefficient], dim=-1)

- return (torch.cat([x, mc], 1), p) if self.export else (torch.cat([x[0], mc], 1), (x[1], mc, p))
+ def fuse(self, txt_feats: torch.Tensor = None):
+ """Fuse text features with model weights for efficient inference."""
+ super().fuse(txt_feats)
+ if txt_feats is None: # means eliminate one2many branch
+ self.cv5 = None
+ if hasattr(self.proto, "fuse"):
+ self.proto.fuse()
+ return


- class RTDETRDecoder(nn.Module):
+ class YOLOESegment26(YOLOESegment):
+ """YOLOE-style segmentation head module using Proto26 for mask generation.
+
+ This class extends the YOLOEDetect functionality to include segmentation capabilities by integrating a prototype
+ generation module and convolutional layers to predict mask coefficients.
+
+ Args:
+ nc (int): Number of classes. Defaults to 80.
+ nm (int): Number of masks. Defaults to 32.
+ npr (int): Number of prototype channels. Defaults to 256.
+ embed (int): Embedding dimensionality. Defaults to 512.
+ with_bn (bool): Whether to use Batch Normalization. Defaults to False.
+ reg_max (int): Maximum regression value for bounding boxes. Defaults to 16.
+ end2end (bool): Whether to use end-to-end detection mode. Defaults to False.
+ ch (tuple[int, ...]): Input channels for each scale.
+
+ Attributes:
+ nm (int): Number of segmentation masks.
+ npr (int): Number of prototype channels.
+ proto (Proto26): Prototype generation module for segmentation.
+ cv5 (nn.ModuleList): Convolutional layers for generating mask coefficients from features.
+ one2one_cv5 (nn.ModuleList, optional): Deep copy of cv5 for end-to-end detection branches.
  """
- Real-Time Deformable Transformer Decoder (RTDETRDecoder) module for object detection.
+
+ def __init__(
+ self,
+ nc: int = 80,
+ nm: int = 32,
+ npr: int = 256,
+ embed: int = 512,
+ with_bn: bool = False,
+ reg_max=16,
+ end2end=False,
+ ch: tuple = (),
+ ):
+ """Initialize YOLOESegment26 with class count, mask parameters, and embedding dimensions."""
+ YOLOEDetect.__init__(self, nc, embed, with_bn, reg_max, end2end, ch)
+ self.nm = nm
+ self.npr = npr
+ self.proto = Proto26(ch, self.npr, self.nm, nc) # protos
+
+ c5 = max(ch[0] // 4, self.nm)
+ self.cv5 = nn.ModuleList(nn.Sequential(Conv(x, c5, 3), Conv(c5, c5, 3), nn.Conv2d(c5, self.nm, 1)) for x in ch)
+ if end2end:
+ self.one2one_cv5 = copy.deepcopy(self.cv5)
+
+ def forward(self, x: list[torch.Tensor]) -> tuple | list[torch.Tensor] | dict[str, torch.Tensor]:
+ """Return model outputs and mask coefficients if training, otherwise return outputs and mask coefficients."""
+ outputs = YOLOEDetect.forward(self, x)
+ preds = outputs[1] if isinstance(outputs, tuple) else outputs
+ proto = self.proto([xi.detach() for xi in x], return_semseg=False) # mask protos
+
+ if isinstance(preds, dict): # training and validating during training
+ if self.end2end and not hasattr(self, "lrpc"): # not prompt-free
+ preds["one2many"]["proto"] = proto
+ preds["one2one"]["proto"] = proto.detach()
+ else:
+ preds["proto"] = proto
+ if self.training:
+ return preds
+ return (outputs, proto) if self.export else ((outputs[0], proto), preds)
+
+
+ class RTDETRDecoder(nn.Module):
+ """Real-Time Deformable Transformer Decoder (RTDETRDecoder) module for object detection.

  This decoder module utilizes Transformer architecture along with deformable convolutions to predict bounding boxes
  and class labels for objects in an image. It integrates features from multiple layers and runs through a series of
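
`YOLOESegment.postprocess` above relies on the head's `get_topk_index` helper; the same select-and-gather pattern can be reproduced with plain `torch.topk`. A hedged, self-contained sketch with random inputs and illustrative sizes:

```python
import torch

bs, num_anchors, nc, nm, max_det = 1, 8400, 80, 32, 300
preds = torch.randn(bs, num_anchors, 4 + nc + nm)        # [x, y, w, h, class_probs, mask_coefficient]

boxes, scores, mask_coef = preds.split([4, nc, nm], dim=-1)
conf, cls = scores.sigmoid().max(dim=-1, keepdim=True)   # best class per anchor
idx = conf.squeeze(-1).topk(max_det, dim=1).indices.unsqueeze(-1)

boxes = boxes.gather(1, idx.expand(-1, -1, 4))
mask_coef = mask_coef.gather(1, idx.expand(-1, -1, nm))
conf = conf.gather(1, idx)
cls = cls.gather(1, idx).float()
out = torch.cat([boxes, conf, cls, mask_coef], dim=-1)   # (bs, max_det, 6 + nm)
print(out.shape)
```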
@@ -920,8 +1459,7 @@ class RTDETRDecoder(nn.Module):
  box_noise_scale: float = 1.0,
  learnt_init_query: bool = False,
  ):
- """
- Initialize the RTDETRDecoder module with the given parameters.
+ """Initialize the RTDETRDecoder module with the given parameters.

  Args:
  nc (int): Number of classes.
@@ -981,8 +1519,7 @@ class RTDETRDecoder(nn.Module):
  self._reset_parameters()

  def forward(self, x: list[torch.Tensor], batch: dict | None = None) -> tuple | torch.Tensor:
- """
- Run the forward pass of the module, returning bounding box and classification scores for the input.
+ """Run the forward pass of the module, returning bounding box and classification scores for the input.

  Args:
  x (list[torch.Tensor]): List of feature maps from the backbone.
@@ -1030,16 +1567,15 @@ class RTDETRDecoder(nn.Module):
  y = torch.cat((dec_bboxes.squeeze(0), dec_scores.squeeze(0).sigmoid()), -1)
  return y if self.export else (y, x)

+ @staticmethod
  def _generate_anchors(
- self,
  shapes: list[list[int]],
  grid_size: float = 0.05,
  dtype: torch.dtype = torch.float32,
  device: str = "cpu",
  eps: float = 1e-2,
  ) -> tuple[torch.Tensor, torch.Tensor]:
- """
- Generate anchor bounding boxes for given shapes with specific grid size and validate them.
+ """Generate anchor bounding boxes for given shapes with specific grid size and validate them.

  Args:
  shapes (list): List of feature map shapes.
@@ -1071,8 +1607,7 @@ class RTDETRDecoder(nn.Module):
  return anchors, valid_mask

  def _get_encoder_input(self, x: list[torch.Tensor]) -> tuple[torch.Tensor, list[list[int]]]:
- """
- Process and return encoder inputs by getting projection features from input and concatenating them.
+ """Process and return encoder inputs by getting projection features from input and concatenating them.

  Args:
  x (list[torch.Tensor]): List of feature maps from the backbone.
@@ -1104,8 +1639,7 @@ class RTDETRDecoder(nn.Module):
  dn_embed: torch.Tensor | None = None,
  dn_bbox: torch.Tensor | None = None,
  ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
- """
- Generate and prepare the input required for the decoder from the provided features and shapes.
+ """Generate and prepare the input required for the decoder from the provided features and shapes.

  Args:
  feats (torch.Tensor): Processed features from encoder.
@@ -1129,9 +1663,9 @@ class RTDETRDecoder(nn.Module):
  enc_outputs_scores = self.enc_score_head(features) # (bs, h*w, nc)

  # Query selection
- # (bs, num_queries)
+ # (bs*num_queries,)
  topk_ind = torch.topk(enc_outputs_scores.max(-1).values, self.num_queries, dim=1).indices.view(-1)
- # (bs, num_queries)
+ # (bs*num_queries,)
  batch_ind = torch.arange(end=bs, dtype=topk_ind.dtype).unsqueeze(-1).repeat(1, self.num_queries).view(-1)

  # (bs, num_queries, 256)
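
The corrected shape comments above describe flattened `(bs*num_queries,)` index tensors; pairing them with repeated batch indices is what allows a single advanced-indexing gather over the `(bs, h*w, c)` encoder output. A small sketch of that selection step with toy sizes:

```python
import torch

bs, hw, c, num_queries = 2, 100, 8, 5
features = torch.randn(bs, hw, c)                # flattened encoder output
scores = torch.randn(bs, hw, 4)                  # per-anchor class scores

topk_ind = torch.topk(scores.max(-1).values, num_queries, dim=1).indices.view(-1)  # (bs*num_queries,)
batch_ind = torch.arange(bs).unsqueeze(-1).repeat(1, num_queries).view(-1)         # (bs*num_queries,)

top_feats = features[batch_ind, topk_ind].view(bs, num_queries, c)                 # (bs, num_queries, c)
print(top_feats.shape)
```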
@@ -1183,11 +1717,10 @@ class v10Detect(Detect):


  class v10Detect(Detect):
- """
- v10 Detection head from https://arxiv.org/pdf/2405.14458.
+ """v10 Detection head from https://arxiv.org/pdf/2405.14458.

- This class implements the YOLOv10 detection head with dual-assignment training and consistent dual predictions
- for improved efficiency and performance.
+ This class implements the YOLOv10 detection head with dual-assignment training and consistent dual predictions for
+ improved efficiency and performance.

  Attributes:
  end2end (bool): End-to-end detection mode.
@@ -1211,14 +1744,13 @@ class v10Detect(Detect):

  end2end = True

  def __init__(self, nc: int = 80, ch: tuple = ()):
- """
- Initialize the v10Detect object with the specified number of classes and input channels.
+ """Initialize the v10Detect object with the specified number of classes and input channels.

  Args:
  nc (int): Number of classes.
  ch (tuple): Tuple of channel sizes from backbone feature maps.
  """
- super().__init__(nc, ch)
+ super().__init__(nc, end2end=True, ch=ch)
  c3 = max(ch[0], min(self.nc, 100)) # channels
  # Light cls head
  self.cv3 = nn.ModuleList(
@@ -1233,4 +1765,4 @@ class v10Detect(Detect):

  def fuse(self):
  """Remove the one2many head for inference optimization."""
- self.cv2 = self.cv3 = nn.ModuleList([nn.Identity()] * self.nl)
+ self.cv2 = self.cv3 = None