dgenerate-ultralytics-headless 8.3.137__py3-none-any.whl → 8.3.224__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215)
  1. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/METADATA +41 -34
  2. dgenerate_ultralytics_headless-8.3.224.dist-info/RECORD +285 -0
  3. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/WHEEL +1 -1
  4. tests/__init__.py +7 -6
  5. tests/conftest.py +15 -39
  6. tests/test_cli.py +17 -17
  7. tests/test_cuda.py +17 -8
  8. tests/test_engine.py +36 -10
  9. tests/test_exports.py +98 -37
  10. tests/test_integrations.py +12 -15
  11. tests/test_python.py +126 -82
  12. tests/test_solutions.py +319 -135
  13. ultralytics/__init__.py +27 -9
  14. ultralytics/cfg/__init__.py +83 -87
  15. ultralytics/cfg/datasets/Argoverse.yaml +4 -4
  16. ultralytics/cfg/datasets/DOTAv1.5.yaml +2 -2
  17. ultralytics/cfg/datasets/DOTAv1.yaml +2 -2
  18. ultralytics/cfg/datasets/GlobalWheat2020.yaml +2 -2
  19. ultralytics/cfg/datasets/HomeObjects-3K.yaml +4 -5
  20. ultralytics/cfg/datasets/ImageNet.yaml +3 -3
  21. ultralytics/cfg/datasets/Objects365.yaml +24 -20
  22. ultralytics/cfg/datasets/SKU-110K.yaml +9 -9
  23. ultralytics/cfg/datasets/VOC.yaml +10 -13
  24. ultralytics/cfg/datasets/VisDrone.yaml +43 -33
  25. ultralytics/cfg/datasets/african-wildlife.yaml +5 -5
  26. ultralytics/cfg/datasets/brain-tumor.yaml +4 -5
  27. ultralytics/cfg/datasets/carparts-seg.yaml +5 -5
  28. ultralytics/cfg/datasets/coco-pose.yaml +26 -4
  29. ultralytics/cfg/datasets/coco.yaml +4 -4
  30. ultralytics/cfg/datasets/coco128-seg.yaml +2 -2
  31. ultralytics/cfg/datasets/coco128.yaml +2 -2
  32. ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
  33. ultralytics/cfg/datasets/coco8-multispectral.yaml +2 -2
  34. ultralytics/cfg/datasets/coco8-pose.yaml +23 -2
  35. ultralytics/cfg/datasets/coco8-seg.yaml +2 -2
  36. ultralytics/cfg/datasets/coco8.yaml +2 -2
  37. ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
  38. ultralytics/cfg/datasets/crack-seg.yaml +5 -5
  39. ultralytics/cfg/datasets/dog-pose.yaml +32 -4
  40. ultralytics/cfg/datasets/dota8-multispectral.yaml +2 -2
  41. ultralytics/cfg/datasets/dota8.yaml +2 -2
  42. ultralytics/cfg/datasets/hand-keypoints.yaml +29 -4
  43. ultralytics/cfg/datasets/lvis.yaml +9 -9
  44. ultralytics/cfg/datasets/medical-pills.yaml +4 -5
  45. ultralytics/cfg/datasets/open-images-v7.yaml +7 -10
  46. ultralytics/cfg/datasets/package-seg.yaml +5 -5
  47. ultralytics/cfg/datasets/signature.yaml +4 -4
  48. ultralytics/cfg/datasets/tiger-pose.yaml +20 -4
  49. ultralytics/cfg/datasets/xView.yaml +5 -5
  50. ultralytics/cfg/default.yaml +96 -93
  51. ultralytics/cfg/trackers/botsort.yaml +16 -17
  52. ultralytics/cfg/trackers/bytetrack.yaml +9 -11
  53. ultralytics/data/__init__.py +4 -4
  54. ultralytics/data/annotator.py +12 -12
  55. ultralytics/data/augment.py +531 -564
  56. ultralytics/data/base.py +76 -81
  57. ultralytics/data/build.py +206 -42
  58. ultralytics/data/converter.py +179 -78
  59. ultralytics/data/dataset.py +121 -121
  60. ultralytics/data/loaders.py +114 -91
  61. ultralytics/data/split.py +28 -15
  62. ultralytics/data/split_dota.py +67 -48
  63. ultralytics/data/utils.py +110 -89
  64. ultralytics/engine/exporter.py +422 -460
  65. ultralytics/engine/model.py +224 -252
  66. ultralytics/engine/predictor.py +94 -89
  67. ultralytics/engine/results.py +345 -595
  68. ultralytics/engine/trainer.py +231 -134
  69. ultralytics/engine/tuner.py +279 -73
  70. ultralytics/engine/validator.py +53 -46
  71. ultralytics/hub/__init__.py +26 -28
  72. ultralytics/hub/auth.py +30 -16
  73. ultralytics/hub/google/__init__.py +34 -36
  74. ultralytics/hub/session.py +53 -77
  75. ultralytics/hub/utils.py +23 -109
  76. ultralytics/models/__init__.py +1 -1
  77. ultralytics/models/fastsam/__init__.py +1 -1
  78. ultralytics/models/fastsam/model.py +36 -18
  79. ultralytics/models/fastsam/predict.py +33 -44
  80. ultralytics/models/fastsam/utils.py +4 -5
  81. ultralytics/models/fastsam/val.py +12 -14
  82. ultralytics/models/nas/__init__.py +1 -1
  83. ultralytics/models/nas/model.py +16 -20
  84. ultralytics/models/nas/predict.py +12 -14
  85. ultralytics/models/nas/val.py +4 -5
  86. ultralytics/models/rtdetr/__init__.py +1 -1
  87. ultralytics/models/rtdetr/model.py +9 -9
  88. ultralytics/models/rtdetr/predict.py +22 -17
  89. ultralytics/models/rtdetr/train.py +20 -16
  90. ultralytics/models/rtdetr/val.py +79 -59
  91. ultralytics/models/sam/__init__.py +8 -2
  92. ultralytics/models/sam/amg.py +53 -38
  93. ultralytics/models/sam/build.py +29 -31
  94. ultralytics/models/sam/model.py +33 -38
  95. ultralytics/models/sam/modules/blocks.py +159 -182
  96. ultralytics/models/sam/modules/decoders.py +38 -47
  97. ultralytics/models/sam/modules/encoders.py +114 -133
  98. ultralytics/models/sam/modules/memory_attention.py +38 -31
  99. ultralytics/models/sam/modules/sam.py +114 -93
  100. ultralytics/models/sam/modules/tiny_encoder.py +268 -291
  101. ultralytics/models/sam/modules/transformer.py +59 -66
  102. ultralytics/models/sam/modules/utils.py +55 -72
  103. ultralytics/models/sam/predict.py +745 -341
  104. ultralytics/models/utils/loss.py +118 -107
  105. ultralytics/models/utils/ops.py +118 -71
  106. ultralytics/models/yolo/__init__.py +1 -1
  107. ultralytics/models/yolo/classify/predict.py +28 -26
  108. ultralytics/models/yolo/classify/train.py +50 -81
  109. ultralytics/models/yolo/classify/val.py +68 -61
  110. ultralytics/models/yolo/detect/predict.py +12 -15
  111. ultralytics/models/yolo/detect/train.py +56 -46
  112. ultralytics/models/yolo/detect/val.py +279 -223
  113. ultralytics/models/yolo/model.py +167 -86
  114. ultralytics/models/yolo/obb/predict.py +7 -11
  115. ultralytics/models/yolo/obb/train.py +23 -25
  116. ultralytics/models/yolo/obb/val.py +107 -99
  117. ultralytics/models/yolo/pose/__init__.py +1 -1
  118. ultralytics/models/yolo/pose/predict.py +12 -14
  119. ultralytics/models/yolo/pose/train.py +31 -69
  120. ultralytics/models/yolo/pose/val.py +119 -254
  121. ultralytics/models/yolo/segment/predict.py +21 -25
  122. ultralytics/models/yolo/segment/train.py +12 -66
  123. ultralytics/models/yolo/segment/val.py +126 -305
  124. ultralytics/models/yolo/world/train.py +53 -45
  125. ultralytics/models/yolo/world/train_world.py +51 -32
  126. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  127. ultralytics/models/yolo/yoloe/predict.py +30 -37
  128. ultralytics/models/yolo/yoloe/train.py +89 -71
  129. ultralytics/models/yolo/yoloe/train_seg.py +15 -17
  130. ultralytics/models/yolo/yoloe/val.py +56 -41
  131. ultralytics/nn/__init__.py +9 -11
  132. ultralytics/nn/autobackend.py +179 -107
  133. ultralytics/nn/modules/__init__.py +67 -67
  134. ultralytics/nn/modules/activation.py +8 -7
  135. ultralytics/nn/modules/block.py +302 -323
  136. ultralytics/nn/modules/conv.py +61 -104
  137. ultralytics/nn/modules/head.py +488 -186
  138. ultralytics/nn/modules/transformer.py +183 -123
  139. ultralytics/nn/modules/utils.py +15 -20
  140. ultralytics/nn/tasks.py +327 -203
  141. ultralytics/nn/text_model.py +81 -65
  142. ultralytics/py.typed +1 -0
  143. ultralytics/solutions/__init__.py +12 -12
  144. ultralytics/solutions/ai_gym.py +19 -27
  145. ultralytics/solutions/analytics.py +36 -26
  146. ultralytics/solutions/config.py +29 -28
  147. ultralytics/solutions/distance_calculation.py +23 -24
  148. ultralytics/solutions/heatmap.py +17 -19
  149. ultralytics/solutions/instance_segmentation.py +21 -19
  150. ultralytics/solutions/object_blurrer.py +16 -17
  151. ultralytics/solutions/object_counter.py +48 -53
  152. ultralytics/solutions/object_cropper.py +22 -16
  153. ultralytics/solutions/parking_management.py +61 -58
  154. ultralytics/solutions/queue_management.py +19 -19
  155. ultralytics/solutions/region_counter.py +63 -50
  156. ultralytics/solutions/security_alarm.py +22 -25
  157. ultralytics/solutions/similarity_search.py +107 -60
  158. ultralytics/solutions/solutions.py +343 -262
  159. ultralytics/solutions/speed_estimation.py +35 -31
  160. ultralytics/solutions/streamlit_inference.py +104 -40
  161. ultralytics/solutions/templates/similarity-search.html +31 -24
  162. ultralytics/solutions/trackzone.py +24 -24
  163. ultralytics/solutions/vision_eye.py +11 -12
  164. ultralytics/trackers/__init__.py +1 -1
  165. ultralytics/trackers/basetrack.py +18 -27
  166. ultralytics/trackers/bot_sort.py +48 -39
  167. ultralytics/trackers/byte_tracker.py +94 -94
  168. ultralytics/trackers/track.py +7 -16
  169. ultralytics/trackers/utils/gmc.py +37 -69
  170. ultralytics/trackers/utils/kalman_filter.py +68 -76
  171. ultralytics/trackers/utils/matching.py +13 -17
  172. ultralytics/utils/__init__.py +251 -275
  173. ultralytics/utils/autobatch.py +19 -7
  174. ultralytics/utils/autodevice.py +68 -38
  175. ultralytics/utils/benchmarks.py +169 -130
  176. ultralytics/utils/callbacks/base.py +12 -13
  177. ultralytics/utils/callbacks/clearml.py +14 -15
  178. ultralytics/utils/callbacks/comet.py +139 -66
  179. ultralytics/utils/callbacks/dvc.py +19 -27
  180. ultralytics/utils/callbacks/hub.py +8 -6
  181. ultralytics/utils/callbacks/mlflow.py +6 -10
  182. ultralytics/utils/callbacks/neptune.py +11 -19
  183. ultralytics/utils/callbacks/platform.py +73 -0
  184. ultralytics/utils/callbacks/raytune.py +3 -4
  185. ultralytics/utils/callbacks/tensorboard.py +9 -12
  186. ultralytics/utils/callbacks/wb.py +33 -30
  187. ultralytics/utils/checks.py +163 -114
  188. ultralytics/utils/cpu.py +89 -0
  189. ultralytics/utils/dist.py +24 -20
  190. ultralytics/utils/downloads.py +176 -146
  191. ultralytics/utils/errors.py +11 -13
  192. ultralytics/utils/events.py +113 -0
  193. ultralytics/utils/export/__init__.py +7 -0
  194. ultralytics/utils/{export.py → export/engine.py} +81 -63
  195. ultralytics/utils/export/imx.py +294 -0
  196. ultralytics/utils/export/tensorflow.py +217 -0
  197. ultralytics/utils/files.py +33 -36
  198. ultralytics/utils/git.py +137 -0
  199. ultralytics/utils/instance.py +105 -120
  200. ultralytics/utils/logger.py +404 -0
  201. ultralytics/utils/loss.py +99 -61
  202. ultralytics/utils/metrics.py +649 -478
  203. ultralytics/utils/nms.py +337 -0
  204. ultralytics/utils/ops.py +263 -451
  205. ultralytics/utils/patches.py +70 -31
  206. ultralytics/utils/plotting.py +253 -223
  207. ultralytics/utils/tal.py +48 -61
  208. ultralytics/utils/torch_utils.py +244 -251
  209. ultralytics/utils/tqdm.py +438 -0
  210. ultralytics/utils/triton.py +22 -23
  211. ultralytics/utils/tuner.py +11 -10
  212. dgenerate_ultralytics_headless-8.3.137.dist-info/RECORD +0 -272
  213. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/entry_points.txt +0 -0
  214. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/licenses/LICENSE +0 -0
  215. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/top_level.txt +0 -0
ultralytics/nn/modules/head.py +488 -186
@@ -1,6 +1,8 @@
  # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
  """Model head modules."""

+ from __future__ import annotations
+
  import copy
  import math

@@ -9,19 +11,59 @@ import torch.nn as nn
  import torch.nn.functional as F
  from torch.nn.init import constant_, xavier_uniform_

- from ultralytics.utils.tal import TORCH_1_10, dist2bbox, dist2rbox, make_anchors
- from ultralytics.utils.torch_utils import fuse_conv_and_bn, smart_inference_mode
+ from ultralytics.utils import NOT_MACOS14
+ from ultralytics.utils.tal import dist2bbox, dist2rbox, make_anchors
+ from ultralytics.utils.torch_utils import TORCH_1_11, fuse_conv_and_bn, smart_inference_mode

  from .block import DFL, SAVPE, BNContrastiveHead, ContrastiveHead, Proto, Residual, SwiGLUFFN
  from .conv import Conv, DWConv
  from .transformer import MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer
  from .utils import bias_init_with_prob, linear_init

- __all__ = "Detect", "Segment", "Pose", "Classify", "OBB", "RTDETRDecoder", "v10Detect", "YOLOEDetect", "YOLOESegment"
+ __all__ = "OBB", "Classify", "Detect", "Pose", "RTDETRDecoder", "Segment", "YOLOEDetect", "YOLOESegment", "v10Detect"


  class Detect(nn.Module):
- """YOLO Detect head for detection models."""
+ """YOLO Detect head for object detection models.
+
+ This class implements the detection head used in YOLO models for predicting bounding boxes and class probabilities.
+ It supports both training and inference modes, with optional end-to-end detection capabilities.
+
+ Attributes:
+ dynamic (bool): Force grid reconstruction.
+ export (bool): Export mode flag.
+ format (str): Export format.
+ end2end (bool): End-to-end detection mode.
+ max_det (int): Maximum detections per image.
+ shape (tuple): Input shape.
+ anchors (torch.Tensor): Anchor points.
+ strides (torch.Tensor): Feature map strides.
+ legacy (bool): Backward compatibility for v3/v5/v8/v9 models.
+ xyxy (bool): Output format, xyxy or xywh.
+ nc (int): Number of classes.
+ nl (int): Number of detection layers.
+ reg_max (int): DFL channels.
+ no (int): Number of outputs per anchor.
+ stride (torch.Tensor): Strides computed during build.
+ cv2 (nn.ModuleList): Convolution layers for box regression.
+ cv3 (nn.ModuleList): Convolution layers for classification.
+ dfl (nn.Module): Distribution Focal Loss layer.
+ one2one_cv2 (nn.ModuleList): One-to-one convolution layers for box regression.
+ one2one_cv3 (nn.ModuleList): One-to-one convolution layers for classification.
+
+ Methods:
+ forward: Perform forward pass and return predictions.
+ forward_end2end: Perform forward pass for end-to-end detection.
+ bias_init: Initialize detection head biases.
+ decode_bboxes: Decode bounding boxes from predictions.
+ postprocess: Post-process model predictions.
+
+ Examples:
+ Create a detection head for 80 classes
+ >>> detect = Detect(nc=80, ch=(256, 512, 1024))
+ >>> x = [torch.randn(1, 256, 80, 80), torch.randn(1, 512, 40, 40), torch.randn(1, 1024, 20, 20)]
+ >>> outputs = detect(x)
+ """

  dynamic = False # force grid reconstruction
  export = False # export mode
@@ -34,8 +76,13 @@ class Detect(nn.Module):
  legacy = False # backward compatibility for v3/v5/v8/v9 models
  xyxy = False # xyxy or xywh output

- def __init__(self, nc=80, ch=()):
- """Initialize the YOLO detection layer with specified number of classes and channels."""
+ def __init__(self, nc: int = 80, ch: tuple = ()):
+ """Initialize the YOLO detection layer with specified number of classes and channels.
+
+ Args:
+ nc (int): Number of classes.
+ ch (tuple): Tuple of channel sizes from backbone feature maps.
+ """
  super().__init__()
  self.nc = nc # number of classes
  self.nl = len(ch) # number of detection layers
@@ -64,8 +111,8 @@ class Detect(nn.Module):
  self.one2one_cv2 = copy.deepcopy(self.cv2)
  self.one2one_cv3 = copy.deepcopy(self.cv3)

- def forward(self, x):
- """Concatenates and returns predicted bounding boxes and class probabilities."""
+ def forward(self, x: list[torch.Tensor]) -> list[torch.Tensor] | tuple:
+ """Concatenate and return predicted bounding boxes and class probabilities."""
  if self.end2end:
  return self.forward_end2end(x)

@@ -76,18 +123,15 @@ class Detect(nn.Module):
  y = self._inference(x)
  return y if self.export else (y, x)

- def forward_end2end(self, x):
- """
- Performs forward pass of the v10Detect module.
+ def forward_end2end(self, x: list[torch.Tensor]) -> dict | tuple:
+ """Perform forward pass of the v10Detect module.

  Args:
- x (List[torch.Tensor]): Input feature maps from different levels.
+ x (list[torch.Tensor]): Input feature maps from different levels.

  Returns:
- (dict | tuple):
-
- - If in training mode, returns a dictionary containing outputs of both one2many and one2one detections.
- - If not in training mode, returns processed detections or a tuple with processed detections and raw outputs.
+ outputs (dict | tuple): Training mode returns dict with one2many and one2one outputs. Inference mode returns
+ processed detections or tuple with detections and raw outputs.
  """
  x_detach = [xi.detach() for xi in x]
  one2one = [
@@ -102,12 +146,11 @@ class Detect(nn.Module):
  y = self.postprocess(y.permute(0, 2, 1), self.max_det, self.nc)
  return y if self.export else (y, {"one2many": x, "one2one": one2one})

- def _inference(self, x):
- """
- Decode predicted bounding boxes and class probabilities based on multiple-level feature maps.
+ def _inference(self, x: list[torch.Tensor]) -> torch.Tensor:
+ """Decode predicted bounding boxes and class probabilities based on multiple-level feature maps.

  Args:
- x (List[torch.Tensor]): List of feature maps from different detection layers.
+ x (list[torch.Tensor]): List of feature maps from different detection layers.

  Returns:
  (torch.Tensor): Concatenated tensor of decoded bounding boxes and class probabilities.
@@ -115,32 +158,12 @@ class Detect(nn.Module):
  # Inference path
  shape = x[0].shape # BCHW
  x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
- if self.format != "imx" and (self.dynamic or self.shape != shape):
+ if self.dynamic or self.shape != shape:
  self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
  self.shape = shape

- if self.export and self.format in {"saved_model", "pb", "tflite", "edgetpu", "tfjs"}: # avoid TF FlexSplitV ops
- box = x_cat[:, : self.reg_max * 4]
- cls = x_cat[:, self.reg_max * 4 :]
- else:
- box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
-
- if self.export and self.format in {"tflite", "edgetpu"}:
- # Precompute normalization factor to increase numerical stability
- # See https://github.com/ultralytics/ultralytics/issues/7371
- grid_h = shape[2]
- grid_w = shape[3]
- grid_size = torch.tensor([grid_w, grid_h, grid_w, grid_h], device=box.device).reshape(1, 4, 1)
- norm = self.strides / (self.stride[0] * grid_size)
- dbox = self.decode_bboxes(self.dfl(box) * norm, self.anchors.unsqueeze(0) * norm[:, :2])
- elif self.export and self.format == "imx":
- dbox = self.decode_bboxes(
- self.dfl(box) * self.strides, self.anchors.unsqueeze(0) * self.strides, xywh=False
- )
- return dbox.transpose(1, 2), cls.sigmoid().permute(0, 2, 1)
- else:
- dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides
-
+ box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
+ dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides
  return torch.cat((dbox, cls.sigmoid()), 1)

  def bias_init(self):
@@ -156,20 +179,24 @@ class Detect(nn.Module):
  a[-1].bias.data[:] = 1.0 # box
  b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 objects, 80 classes, 640 img)

- def decode_bboxes(self, bboxes, anchors, xywh=True):
- """Decode bounding boxes."""
- return dist2bbox(bboxes, anchors, xywh=xywh and not (self.end2end or self.xyxy), dim=1)
+ def decode_bboxes(self, bboxes: torch.Tensor, anchors: torch.Tensor, xywh: bool = True) -> torch.Tensor:
+ """Decode bounding boxes from predictions."""
+ return dist2bbox(
+ bboxes,
+ anchors,
+ xywh=xywh and not self.end2end and not self.xyxy,
+ dim=1,
+ )

  @staticmethod
- def postprocess(preds: torch.Tensor, max_det: int, nc: int = 80):
- """
- Post-processes YOLO model predictions.
+ def postprocess(preds: torch.Tensor, max_det: int, nc: int = 80) -> torch.Tensor:
+ """Post-process YOLO model predictions.

  Args:
  preds (torch.Tensor): Raw predictions with shape (batch_size, num_anchors, 4 + nc) with last dimension
  format [x, y, w, h, class_probs].
  max_det (int): Maximum detections per image.
- nc (int, optional): Number of classes. Default: 80.
+ nc (int, optional): Number of classes.

  Returns:
  (torch.Tensor): Processed predictions with shape (batch_size, min(max_det, num_anchors), 6) and last
@@ -186,10 +213,35 @@ class Detect(nn.Module):


  class Segment(Detect):
- """YOLO Segment head for segmentation models."""
+ """YOLO Segment head for segmentation models.

- def __init__(self, nc=80, nm=32, npr=256, ch=()):
- """Initialize the YOLO model attributes such as the number of masks, prototypes, and the convolution layers."""
+ This class extends the Detect head to include mask prediction capabilities for instance segmentation tasks.
+
+ Attributes:
+ nm (int): Number of masks.
+ npr (int): Number of protos.
+ proto (Proto): Prototype generation module.
+ cv4 (nn.ModuleList): Convolution layers for mask coefficients.
+
+ Methods:
+ forward: Return model outputs and mask coefficients.
+
+ Examples:
+ Create a segmentation head
+ >>> segment = Segment(nc=80, nm=32, npr=256, ch=(256, 512, 1024))
+ >>> x = [torch.randn(1, 256, 80, 80), torch.randn(1, 512, 40, 40), torch.randn(1, 1024, 20, 20)]
+ >>> outputs = segment(x)
+ """
+
+ def __init__(self, nc: int = 80, nm: int = 32, npr: int = 256, ch: tuple = ()):
+ """Initialize the YOLO model attributes such as the number of masks, prototypes, and the convolution layers.
+
+ Args:
+ nc (int): Number of classes.
+ nm (int): Number of masks.
+ npr (int): Number of protos.
+ ch (tuple): Tuple of channel sizes from backbone feature maps.
+ """
  super().__init__(nc, ch)
  self.nm = nm # number of masks
  self.npr = npr # number of protos
@@ -198,7 +250,7 @@ class Segment(Detect):
  c4 = max(ch[0] // 4, self.nm)
  self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nm, 1)) for x in ch)

- def forward(self, x):
+ def forward(self, x: list[torch.Tensor]) -> tuple | list[torch.Tensor]:
  """Return model outputs and mask coefficients if training, otherwise return outputs and mask coefficients."""
  p = self.proto(x[0]) # mask protos
  bs = p.shape[0] # batch size
@@ -211,18 +263,42 @@ class Segment(Detect):


  class OBB(Detect):
- """YOLO OBB detection head for detection with rotation models."""
+ """YOLO OBB detection head for detection with rotation models.

- def __init__(self, nc=80, ne=1, ch=()):
- """Initialize OBB with number of classes `nc` and layer channels `ch`."""
+ This class extends the Detect head to include oriented bounding box prediction with rotation angles.
+
+ Attributes:
+ ne (int): Number of extra parameters.
+ cv4 (nn.ModuleList): Convolution layers for angle prediction.
+ angle (torch.Tensor): Predicted rotation angles.
+
+ Methods:
+ forward: Concatenate and return predicted bounding boxes and class probabilities.
+ decode_bboxes: Decode rotated bounding boxes.
+
+ Examples:
+ Create an OBB detection head
+ >>> obb = OBB(nc=80, ne=1, ch=(256, 512, 1024))
+ >>> x = [torch.randn(1, 256, 80, 80), torch.randn(1, 512, 40, 40), torch.randn(1, 1024, 20, 20)]
+ >>> outputs = obb(x)
+ """
+
+ def __init__(self, nc: int = 80, ne: int = 1, ch: tuple = ()):
+ """Initialize OBB with number of classes `nc` and layer channels `ch`.
+
+ Args:
+ nc (int): Number of classes.
+ ne (int): Number of extra parameters.
+ ch (tuple): Tuple of channel sizes from backbone feature maps.
+ """
  super().__init__(nc, ch)
  self.ne = ne # number of extra parameters

  c4 = max(ch[0] // 4, self.ne)
  self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.ne, 1)) for x in ch)

- def forward(self, x):
- """Concatenates and returns predicted bounding boxes and class probabilities."""
+ def forward(self, x: list[torch.Tensor]) -> torch.Tensor | tuple:
+ """Concatenate and return predicted bounding boxes and class probabilities."""
  bs = x[0].shape[0] # batch size
  angle = torch.cat([self.cv4[i](x[i]).view(bs, self.ne, -1) for i in range(self.nl)], 2) # OBB theta logits
  # NOTE: set `angle` as an attribute so that `decode_bboxes` could use it.
@@ -235,16 +311,40 @@ class OBB(Detect):
  return x, angle
  return torch.cat([x, angle], 1) if self.export else (torch.cat([x[0], angle], 1), (x[1], angle))

- def decode_bboxes(self, bboxes, anchors):
+ def decode_bboxes(self, bboxes: torch.Tensor, anchors: torch.Tensor) -> torch.Tensor:
  """Decode rotated bounding boxes."""
  return dist2rbox(bboxes, self.angle, anchors, dim=1)


  class Pose(Detect):
- """YOLO Pose head for keypoints models."""
+ """YOLO Pose head for keypoints models.
+
+ This class extends the Detect head to include keypoint prediction capabilities for pose estimation tasks.
+
+ Attributes:
+ kpt_shape (tuple): Number of keypoints and dimensions (2 for x,y or 3 for x,y,visible).
+ nk (int): Total number of keypoint values.
+ cv4 (nn.ModuleList): Convolution layers for keypoint prediction.
+
+ Methods:
+ forward: Perform forward pass through YOLO model and return predictions.
+ kpts_decode: Decode keypoints from predictions.
+
+ Examples:
+ Create a pose detection head
+ >>> pose = Pose(nc=80, kpt_shape=(17, 3), ch=(256, 512, 1024))
+ >>> x = [torch.randn(1, 256, 80, 80), torch.randn(1, 512, 40, 40), torch.randn(1, 1024, 20, 20)]
+ >>> outputs = pose(x)
+ """
+
+ def __init__(self, nc: int = 80, kpt_shape: tuple = (17, 3), ch: tuple = ()):
+ """Initialize YOLO network with default parameters and Convolutional Layers.

- def __init__(self, nc=80, kpt_shape=(17, 3), ch=()):
- """Initialize YOLO network with default parameters and Convolutional Layers."""
+ Args:
+ nc (int): Number of classes.
+ kpt_shape (tuple): Number of keypoints, number of dims (2 for x,y or 3 for x,y,visible).
+ ch (tuple): Tuple of channel sizes from backbone feature maps.
+ """
  super().__init__(nc, ch)
  self.kpt_shape = kpt_shape # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible)
  self.nk = kpt_shape[0] * kpt_shape[1] # number of keypoints total
@@ -252,7 +352,7 @@ class Pose(Detect):
  c4 = max(ch[0] // 4, self.nk)
  self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nk, 1)) for x in ch)

- def forward(self, x):
+ def forward(self, x: list[torch.Tensor]) -> torch.Tensor | tuple:
  """Perform forward pass through YOLO model and return predictions."""
  bs = x[0].shape[0] # batch size
  kpt = torch.cat([self.cv4[i](x[i]).view(bs, self.nk, -1) for i in range(self.nl)], -1) # (bs, 17*3, h*w)
@@ -262,43 +362,63 @@ class Pose(Detect):
  pred_kpt = self.kpts_decode(bs, kpt)
  return torch.cat([x, pred_kpt], 1) if self.export else (torch.cat([x[0], pred_kpt], 1), (x[1], kpt))

- def kpts_decode(self, bs, kpts):
- """Decodes keypoints."""
+ def kpts_decode(self, bs: int, kpts: torch.Tensor) -> torch.Tensor:
+ """Decode keypoints from predictions."""
  ndim = self.kpt_shape[1]
  if self.export:
- if self.format in {
- "tflite",
- "edgetpu",
- }: # required for TFLite export to avoid 'PLACEHOLDER_FOR_GREATER_OP_CODES' bug
- # Precompute normalization factor to increase numerical stability
- y = kpts.view(bs, *self.kpt_shape, -1)
- grid_h, grid_w = self.shape[2], self.shape[3]
- grid_size = torch.tensor([grid_w, grid_h], device=y.device).reshape(1, 2, 1)
- norm = self.strides / (self.stride[0] * grid_size)
- a = (y[:, :, :2] * 2.0 + (self.anchors - 0.5)) * norm
- else:
- # NCNN fix
- y = kpts.view(bs, *self.kpt_shape, -1)
- a = (y[:, :, :2] * 2.0 + (self.anchors - 0.5)) * self.strides
+ # NCNN fix
+ y = kpts.view(bs, *self.kpt_shape, -1)
+ a = (y[:, :, :2] * 2.0 + (self.anchors - 0.5)) * self.strides
  if ndim == 3:
  a = torch.cat((a, y[:, :, 2:3].sigmoid()), 2)
  return a.view(bs, self.nk, -1)
  else:
  y = kpts.clone()
  if ndim == 3:
- y[:, 2::ndim] = y[:, 2::ndim].sigmoid() # sigmoid (WARNING: inplace .sigmoid_() Apple MPS bug)
+ if NOT_MACOS14:
+ y[:, 2::ndim].sigmoid_()
+ else: # Apple macOS14 MPS bug https://github.com/ultralytics/ultralytics/pull/21878
+ y[:, 2::ndim] = y[:, 2::ndim].sigmoid()
  y[:, 0::ndim] = (y[:, 0::ndim] * 2.0 + (self.anchors[0] - 0.5)) * self.strides
  y[:, 1::ndim] = (y[:, 1::ndim] * 2.0 + (self.anchors[1] - 0.5)) * self.strides
  return y


  class Classify(nn.Module):
- """YOLO classification head, i.e. x(b,c1,20,20) to x(b,c2)."""
+ """YOLO classification head, i.e. x(b,c1,20,20) to x(b,c2).
+
+ This class implements a classification head that transforms feature maps into class predictions.
+
+ Attributes:
+ export (bool): Export mode flag.
+ conv (Conv): Convolutional layer for feature transformation.
+ pool (nn.AdaptiveAvgPool2d): Global average pooling layer.
+ drop (nn.Dropout): Dropout layer for regularization.
+ linear (nn.Linear): Linear layer for final classification.
+
+ Methods:
+ forward: Perform forward pass of the YOLO model on input image data.
+
+ Examples:
+ Create a classification head
+ >>> classify = Classify(c1=1024, c2=1000)
+ >>> x = torch.randn(1, 1024, 20, 20)
+ >>> output = classify(x)
+ """

  export = False # export mode

- def __init__(self, c1, c2, k=1, s=1, p=None, g=1):
- """Initializes YOLO classification head to transform input tensor from (b,c1,20,20) to (b,c2) shape."""
+ def __init__(self, c1: int, c2: int, k: int = 1, s: int = 1, p: int | None = None, g: int = 1):
+ """Initialize YOLO classification head to transform input tensor from (b,c1,20,20) to (b,c2) shape.
+
+ Args:
+ c1 (int): Number of input channels.
+ c2 (int): Number of output classes.
+ k (int, optional): Kernel size.
+ s (int, optional): Stride.
+ p (int, optional): Padding.
+ g (int, optional): Groups.
+ """
  super().__init__()
  c_ = 1280 # efficientnet_b0 size
  self.conv = Conv(c1, c_, k, s, p, g)
@@ -306,8 +426,8 @@ class Classify(nn.Module):
  self.drop = nn.Dropout(p=0.0, inplace=True)
  self.linear = nn.Linear(c_, c2) # to x(b,c2)

- def forward(self, x):
- """Performs a forward pass of the YOLO model on input image data."""
+ def forward(self, x: list[torch.Tensor] | torch.Tensor) -> torch.Tensor | tuple:
+ """Perform forward pass of the YOLO model on input image data."""
  if isinstance(x, list):
  x = torch.cat(x, 1)
  x = self.linear(self.drop(self.pool(self.conv(x)).flatten(1)))
@@ -318,17 +438,43 @@ class Classify(nn.Module):


  class WorldDetect(Detect):
- """Head for integrating YOLO detection models with semantic understanding from text embeddings."""
+ """Head for integrating YOLO detection models with semantic understanding from text embeddings.
+
+ This class extends the standard Detect head to incorporate text embeddings for enhanced semantic understanding in
+ object detection tasks.

- def __init__(self, nc=80, embed=512, with_bn=False, ch=()):
- """Initialize YOLO detection layer with nc classes and layer channels ch."""
+ Attributes:
+ cv3 (nn.ModuleList): Convolution layers for embedding features.
+ cv4 (nn.ModuleList): Contrastive head layers for text-vision alignment.
+
+ Methods:
+ forward: Concatenate and return predicted bounding boxes and class probabilities.
+ bias_init: Initialize detection head biases.
+
+ Examples:
+ Create a WorldDetect head
+ >>> world_detect = WorldDetect(nc=80, embed=512, with_bn=False, ch=(256, 512, 1024))
+ >>> x = [torch.randn(1, 256, 80, 80), torch.randn(1, 512, 40, 40), torch.randn(1, 1024, 20, 20)]
+ >>> text = torch.randn(1, 80, 512)
+ >>> outputs = world_detect(x, text)
+ """
+
+ def __init__(self, nc: int = 80, embed: int = 512, with_bn: bool = False, ch: tuple = ()):
+ """Initialize YOLO detection layer with nc classes and layer channels ch.
+
+ Args:
+ nc (int): Number of classes.
+ embed (int): Embedding dimension.
+ with_bn (bool): Whether to use batch normalization in contrastive head.
+ ch (tuple): Tuple of channel sizes from backbone feature maps.
+ """
  super().__init__(nc, ch)
  c3 = max(ch[0], min(self.nc, 100))
  self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, embed, 1)) for x in ch)
  self.cv4 = nn.ModuleList(BNContrastiveHead(embed) if with_bn else ContrastiveHead() for _ in ch)

- def forward(self, x, text):
- """Concatenates and returns predicted bounding boxes and class probabilities."""
+ def forward(self, x: list[torch.Tensor], text: torch.Tensor) -> list[torch.Tensor] | tuple:
+ """Concatenate and return predicted bounding boxes and class probabilities."""
  for i in range(self.nl):
  x[i] = torch.cat((self.cv2[i](x[i]), self.cv4[i](self.cv3[i](x[i]), text)), 1)
  if self.training:
@@ -348,17 +494,45 @@ class WorldDetect(Detect):


  class LRPCHead(nn.Module):
- """Lightweight Region Proposal and Classification Head for efficient object detection."""
+ """Lightweight Region Proposal and Classification Head for efficient object detection.

- def __init__(self, vocab, pf, loc, enabled=True):
- """Initialize LRPCHead with vocabulary, proposal filter, and localization components."""
+ This head combines region proposal filtering with classification to enable efficient detection with dynamic
+ vocabulary support.
+
+ Attributes:
+ vocab (nn.Module): Vocabulary/classification layer.
+ pf (nn.Module): Proposal filter module.
+ loc (nn.Module): Localization module.
+ enabled (bool): Whether the head is enabled.
+
+ Methods:
+ conv2linear: Convert a 1x1 convolutional layer to a linear layer.
+ forward: Process classification and localization features to generate detection proposals.
+
+ Examples:
+ Create an LRPC head
+ >>> vocab = nn.Conv2d(256, 80, 1)
+ >>> pf = nn.Conv2d(256, 1, 1)
+ >>> loc = nn.Conv2d(256, 4, 1)
+ >>> head = LRPCHead(vocab, pf, loc, enabled=True)
+ """
+
+ def __init__(self, vocab: nn.Module, pf: nn.Module, loc: nn.Module, enabled: bool = True):
+ """Initialize LRPCHead with vocabulary, proposal filter, and localization components.
+
+ Args:
+ vocab (nn.Module): Vocabulary/classification module.
+ pf (nn.Module): Proposal filter module.
+ loc (nn.Module): Localization module.
+ enabled (bool): Whether to enable the head functionality.
+ """
  super().__init__()
  self.vocab = self.conv2linear(vocab) if enabled else vocab
  self.pf = pf
  self.loc = loc
  self.enabled = enabled

- def conv2linear(self, conv):
+ def conv2linear(self, conv: nn.Conv2d) -> nn.Linear:
  """Convert a 1x1 convolutional layer to a linear layer."""
  assert isinstance(conv, nn.Conv2d) and conv.kernel_size == (1, 1)
  linear = nn.Linear(conv.in_channels, conv.out_channels)
@@ -366,7 +540,7 @@ class LRPCHead(nn.Module):
  linear.bias.data = conv.bias.data
  return linear

- def forward(self, cls_feat, loc_feat, conf):
+ def forward(self, cls_feat: torch.Tensor, loc_feat: torch.Tensor, conf: float) -> tuple[tuple, torch.Tensor]:
  """Process classification and localization features to generate detection proposals."""
  if self.enabled:
  pf_score = self.pf(cls_feat)[0, 0].flatten(0)
@@ -383,16 +557,50 @@ class LRPCHead(nn.Module):


  class YOLOEDetect(Detect):
- """Head for integrating YOLO detection models with semantic understanding from text embeddings."""
+ """Head for integrating YOLO detection models with semantic understanding from text embeddings.
+
+ This class extends the standard Detect head to support text-guided detection with enhanced semantic understanding
+ through text embeddings and visual prompt embeddings.
+
+ Attributes:
+ is_fused (bool): Whether the model is fused for inference.
+ cv3 (nn.ModuleList): Convolution layers for embedding features.
+ cv4 (nn.ModuleList): Contrastive head layers for text-vision alignment.
+ reprta (Residual): Residual block for text prompt embeddings.
+ savpe (SAVPE): Spatial-aware visual prompt embeddings module.
+ embed (int): Embedding dimension.
+
+ Methods:
+ fuse: Fuse text features with model weights for efficient inference.
+ get_tpe: Get text prompt embeddings with normalization.
+ get_vpe: Get visual prompt embeddings with spatial awareness.
+ forward_lrpc: Process features with fused text embeddings for prompt-free model.
+ forward: Process features with class prompt embeddings to generate detections.
+ bias_init: Initialize biases for detection heads.
+
+ Examples:
+ Create a YOLOEDetect head
+ >>> yoloe_detect = YOLOEDetect(nc=80, embed=512, with_bn=True, ch=(256, 512, 1024))
+ >>> x = [torch.randn(1, 256, 80, 80), torch.randn(1, 512, 40, 40), torch.randn(1, 1024, 20, 20)]
+ >>> cls_pe = torch.randn(1, 80, 512)
+ >>> outputs = yoloe_detect(x, cls_pe)
+ """

  is_fused = False

- def __init__(self, nc=80, embed=512, with_bn=False, ch=()):
- """Initialize YOLO detection layer with nc classes and layer channels ch."""
+ def __init__(self, nc: int = 80, embed: int = 512, with_bn: bool = False, ch: tuple = ()):
+ """Initialize YOLO detection layer with nc classes and layer channels ch.
+
+ Args:
+ nc (int): Number of classes.
+ embed (int): Embedding dimension.
+ with_bn (bool): Whether to use batch normalization in contrastive head.
+ ch (tuple): Tuple of channel sizes from backbone feature maps.
+ """
  super().__init__(nc, ch)
  c3 = max(ch[0], min(self.nc, 100))
  assert c3 <= embed
- assert with_bn is True
+ assert with_bn
  self.cv3 = (
  nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, embed, 1)) for x in ch)
  if self.legacy
@@ -413,7 +621,7 @@ class YOLOEDetect(Detect):
  self.embed = embed

  @smart_inference_mode()
- def fuse(self, txt_feats):
+ def fuse(self, txt_feats: torch.Tensor):
  """Fuse text features with model weights for efficient inference."""
  if self.is_fused:
  return
@@ -459,11 +667,11 @@ class YOLOEDetect(Detect):
  self.reprta = nn.Identity()
  self.is_fused = True

- def get_tpe(self, tpe):
+ def get_tpe(self, tpe: torch.Tensor | None) -> torch.Tensor | None:
  """Get text prompt embeddings with normalization."""
  return None if tpe is None else F.normalize(self.reprta(tpe), dim=-1, p=2)

- def get_vpe(self, x, vpe):
+ def get_vpe(self, x: list[torch.Tensor], vpe: torch.Tensor) -> torch.Tensor:
  """Get visual prompt embeddings with spatial awareness."""
  if vpe.shape[1] == 0: # no visual prompt embeddings
  return torch.zeros(x[0].shape[0], 0, self.embed, device=x[0].device)
@@ -472,7 +680,7 @@ class YOLOEDetect(Detect):
  assert vpe.ndim == 3 # (B, N, D)
  return vpe

- def forward_lrpc(self, x, return_mask=False):
+ def forward_lrpc(self, x: list[torch.Tensor], return_mask: bool = False) -> torch.Tensor | tuple:
  """Process features with fused text embeddings to generate detections for prompt-free model."""
  masks = []
  assert self.is_fused, "Prompt-free inference requires model to be fused!"
@@ -510,7 +718,7 @@ class YOLOEDetect(Detect):
  else:
  return y if self.export else (y, x)

- def forward(self, x, cls_pe, return_mask=False):
+ def forward(self, x: list[torch.Tensor], cls_pe: torch.Tensor, return_mask: bool = False) -> torch.Tensor | tuple:
  """Process features with class prompt embeddings to generate detections."""
  if hasattr(self, "lrpc"): # for prompt-free inference
  return self.forward_lrpc(x, return_mask)
@@ -535,10 +743,41 @@ class YOLOEDetect(Detect):


  class YOLOESegment(YOLOEDetect):
- """YOLO segmentation head with text embedding capabilities."""
+ """YOLO segmentation head with text embedding capabilities.

- def __init__(self, nc=80, nm=32, npr=256, embed=512, with_bn=False, ch=()):
- """Initialize YOLOESegment with class count, mask parameters, and embedding dimensions."""
+ This class extends YOLOEDetect to include mask prediction capabilities for instance segmentation tasks with
+ text-guided semantic understanding.
+
+ Attributes:
+ nm (int): Number of masks.
+ npr (int): Number of protos.
+ proto (Proto): Prototype generation module.
+ cv5 (nn.ModuleList): Convolution layers for mask coefficients.
+
+ Methods:
+ forward: Return model outputs and mask coefficients.
+
+ Examples:
+ Create a YOLOESegment head
+ >>> yoloe_segment = YOLOESegment(nc=80, nm=32, npr=256, embed=512, with_bn=True, ch=(256, 512, 1024))
+ >>> x = [torch.randn(1, 256, 80, 80), torch.randn(1, 512, 40, 40), torch.randn(1, 1024, 20, 20)]
+ >>> text = torch.randn(1, 80, 512)
+ >>> outputs = yoloe_segment(x, text)
+ """
+
+ def __init__(
+ self, nc: int = 80, nm: int = 32, npr: int = 256, embed: int = 512, with_bn: bool = False, ch: tuple = ()
+ ):
+ """Initialize YOLOESegment with class count, mask parameters, and embedding dimensions.
+
+ Args:
+ nc (int): Number of classes.
+ nm (int): Number of masks.
+ npr (int): Number of protos.
+ embed (int): Embedding dimension.
+ with_bn (bool): Whether to use batch normalization in contrastive head.
+ ch (tuple): Tuple of channel sizes from backbone feature maps.
+ """
  super().__init__(nc, embed, with_bn, ch)
  self.nm = nm
  self.npr = npr
@@ -547,7 +786,7 @@ class YOLOESegment(YOLOEDetect):
  c5 = max(ch[0] // 4, self.nm)
  self.cv5 = nn.ModuleList(nn.Sequential(Conv(x, c5, 3), Conv(c5, c5, 3), nn.Conv2d(c5, self.nm, 1)) for x in ch)

- def forward(self, x, text):
+ def forward(self, x: list[torch.Tensor], text: torch.Tensor) -> tuple | torch.Tensor:
  """Return model outputs and mask coefficients if training, otherwise return outputs and mask coefficients."""
  p = self.proto(x[0]) # mask protos
  bs = p.shape[0] # batch size
@@ -570,54 +809,88 @@ class YOLOESegment(YOLOEDetect):


  class RTDETRDecoder(nn.Module):
- """
- Real-Time Deformable Transformer Decoder (RTDETRDecoder) module for object detection.
+ """Real-Time Deformable Transformer Decoder (RTDETRDecoder) module for object detection.

  This decoder module utilizes Transformer architecture along with deformable convolutions to predict bounding boxes
  and class labels for objects in an image. It integrates features from multiple layers and runs through a series of
  Transformer decoder layers to output the final predictions.
+
+ Attributes:
+ export (bool): Export mode flag.
+ hidden_dim (int): Dimension of hidden layers.
+ nhead (int): Number of heads in multi-head attention.
+ nl (int): Number of feature levels.
+ nc (int): Number of classes.
+ num_queries (int): Number of query points.
+ num_decoder_layers (int): Number of decoder layers.
+ input_proj (nn.ModuleList): Input projection layers for backbone features.
+ decoder (DeformableTransformerDecoder): Transformer decoder module.
+ denoising_class_embed (nn.Embedding): Class embeddings for denoising.
+ num_denoising (int): Number of denoising queries.
+ label_noise_ratio (float): Label noise ratio for training.
+ box_noise_scale (float): Box noise scale for training.
+ learnt_init_query (bool): Whether to learn initial query embeddings.
+ tgt_embed (nn.Embedding): Target embeddings for queries.
+ query_pos_head (MLP): Query position head.
+ enc_output (nn.Sequential): Encoder output layers.
+ enc_score_head (nn.Linear): Encoder score prediction head.
+ enc_bbox_head (MLP): Encoder bbox prediction head.
+ dec_score_head (nn.ModuleList): Decoder score prediction heads.
+ dec_bbox_head (nn.ModuleList): Decoder bbox prediction heads.
+
+ Methods:
+ forward: Run forward pass and return bounding box and classification scores.
+
+ Examples:
+ Create an RTDETRDecoder
+ >>> decoder = RTDETRDecoder(nc=80, ch=(512, 1024, 2048), hd=256, nq=300)
+ >>> x = [torch.randn(1, 512, 64, 64), torch.randn(1, 1024, 32, 32), torch.randn(1, 2048, 16, 16)]
+ >>> outputs = decoder(x)
  """

  export = False # export mode
+ shapes = []
+ anchors = torch.empty(0)
+ valid_mask = torch.empty(0)
+ dynamic = False

  def __init__(
  self,
- nc=80,
- ch=(512, 1024, 2048),
- hd=256, # hidden dim
- nq=300, # num queries
- ndp=4, # num decoder points
- nh=8, # num head
- ndl=6, # num decoder layers
- d_ffn=1024, # dim of feedforward
- dropout=0.0,
- act=nn.ReLU(),
- eval_idx=-1,
+ nc: int = 80,
+ ch: tuple = (512, 1024, 2048),
+ hd: int = 256, # hidden dim
+ nq: int = 300, # num queries
+ ndp: int = 4, # num decoder points
+ nh: int = 8, # num head
+ ndl: int = 6, # num decoder layers
+ d_ffn: int = 1024, # dim of feedforward
+ dropout: float = 0.0,
+ act: nn.Module = nn.ReLU(),
+ eval_idx: int = -1,
  # Training args
- nd=100, # num denoising
- label_noise_ratio=0.5,
- box_noise_scale=1.0,
- learnt_init_query=False,
+ nd: int = 100, # num denoising
+ label_noise_ratio: float = 0.5,
+ box_noise_scale: float = 1.0,
+ learnt_init_query: bool = False,
  ):
- """
- Initializes the RTDETRDecoder module with the given parameters.
+ """Initialize the RTDETRDecoder module with the given parameters.

  Args:
- nc (int): Number of classes. Default is 80.
- ch (tuple): Channels in the backbone feature maps. Default is (512, 1024, 2048).
- hd (int): Dimension of hidden layers. Default is 256.
- nq (int): Number of query points. Default is 300.
- ndp (int): Number of decoder points. Default is 4.
- nh (int): Number of heads in multi-head attention. Default is 8.
- ndl (int): Number of decoder layers. Default is 6.
- d_ffn (int): Dimension of the feed-forward networks. Default is 1024.
- dropout (float): Dropout rate. Default is 0.0.
- act (nn.Module): Activation function. Default is nn.ReLU.
- eval_idx (int): Evaluation index. Default is -1.
- nd (int): Number of denoising. Default is 100.
- label_noise_ratio (float): Label noise ratio. Default is 0.5.
- box_noise_scale (float): Box noise scale. Default is 1.0.
- learnt_init_query (bool): Whether to learn initial query embeddings. Default is False.
+ nc (int): Number of classes.
+ ch (tuple): Channels in the backbone feature maps.
+ hd (int): Dimension of hidden layers.
+ nq (int): Number of query points.
+ ndp (int): Number of decoder points.
+ nh (int): Number of heads in multi-head attention.
+ ndl (int): Number of decoder layers.
+ d_ffn (int): Dimension of the feed-forward networks.
+ dropout (float): Dropout rate.
+ act (nn.Module): Activation function.
+ eval_idx (int): Evaluation index.
+ nd (int): Number of denoising.
+ label_noise_ratio (float): Label noise ratio.
+ box_noise_scale (float): Box noise scale.
+ learnt_init_query (bool): Whether to learn initial query embeddings.
  """
  super().__init__()
  self.hidden_dim = hd
@@ -659,17 +932,17 @@ class RTDETRDecoder(nn.Module):

  self._reset_parameters()

- def forward(self, x, batch=None):
- """
- Runs the forward pass of the module, returning bounding box and classification scores for the input.
+ def forward(self, x: list[torch.Tensor], batch: dict | None = None) -> tuple | torch.Tensor:
+ """Run the forward pass of the module, returning bounding box and classification scores for the input.

  Args:
- x (List[torch.Tensor]): List of feature maps from the backbone.
+ x (list[torch.Tensor]): List of feature maps from the backbone.
  batch (dict, optional): Batch information for training.

  Returns:
- (tuple | torch.Tensor): During training, returns a tuple of bounding boxes, scores, and other metadata.
- During inference, returns a tensor of shape (bs, 300, 4+nc) containing bounding boxes and class scores.
+ outputs (tuple | torch.Tensor): During training, returns a tuple of bounding boxes, scores, and other
+ metadata. During inference, returns a tensor of shape (bs, 300, 4+nc) containing bounding boxes and
+ class scores.
  """
  from ultralytics.models.utils.ops import get_cdn_group

@@ -708,25 +981,32 @@ class RTDETRDecoder(nn.Module):
  y = torch.cat((dec_bboxes.squeeze(0), dec_scores.squeeze(0).sigmoid()), -1)
  return y if self.export else (y, x)

- def _generate_anchors(self, shapes, grid_size=0.05, dtype=torch.float32, device="cpu", eps=1e-2):
- """
- Generates anchor bounding boxes for given shapes with specific grid size and validates them.
+ def _generate_anchors(
+ self,
+ shapes: list[list[int]],
+ grid_size: float = 0.05,
+ dtype: torch.dtype = torch.float32,
+ device: str = "cpu",
+ eps: float = 1e-2,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """Generate anchor bounding boxes for given shapes with specific grid size and validate them.

  Args:
  shapes (list): List of feature map shapes.
- grid_size (float, optional): Base size of grid cells. Default is 0.05.
- dtype (torch.dtype, optional): Data type for tensors. Default is torch.float32.
- device (str, optional): Device to create tensors on. Default is "cpu".
- eps (float, optional): Small value for numerical stability. Default is 1e-2.
+ grid_size (float, optional): Base size of grid cells.
+ dtype (torch.dtype, optional): Data type for tensors.
+ device (str, optional): Device to create tensors on.
+ eps (float, optional): Small value for numerical stability.

  Returns:
- (tuple): Tuple containing anchors and valid mask tensors.
+ anchors (torch.Tensor): Generated anchor boxes.
+ valid_mask (torch.Tensor): Valid mask for anchors.
  """
  anchors = []
  for i, (h, w) in enumerate(shapes):
  sy = torch.arange(end=h, dtype=dtype, device=device)
  sx = torch.arange(end=w, dtype=dtype, device=device)
- grid_y, grid_x = torch.meshgrid(sy, sx, indexing="ij") if TORCH_1_10 else torch.meshgrid(sy, sx)
+ grid_y, grid_x = torch.meshgrid(sy, sx, indexing="ij") if TORCH_1_11 else torch.meshgrid(sy, sx)
  grid_xy = torch.stack([grid_x, grid_y], -1) # (h, w, 2)

  valid_WH = torch.tensor([w, h], dtype=dtype, device=device)
@@ -740,15 +1020,15 @@ class RTDETRDecoder(nn.Module):
  anchors = anchors.masked_fill(~valid_mask, float("inf"))
  return anchors, valid_mask

- def _get_encoder_input(self, x):
- """
- Processes and returns encoder inputs by getting projection features from input and concatenating them.
+ def _get_encoder_input(self, x: list[torch.Tensor]) -> tuple[torch.Tensor, list[list[int]]]:
+ """Process and return encoder inputs by getting projection features from input and concatenating them.

  Args:
- x (List[torch.Tensor]): List of feature maps from the backbone.
+ x (list[torch.Tensor]): List of feature maps from the backbone.

  Returns:
- (tuple): Tuple containing processed features and their shapes.
+ feats (torch.Tensor): Processed features.
+ shapes (list): List of feature map shapes.
  """
  # Get projection features
  x = [self.input_proj[i](feat) for i, feat in enumerate(x)]
@@ -766,24 +1046,34 @@ class RTDETRDecoder(nn.Module):
  feats = torch.cat(feats, 1)
  return feats, shapes

- def _get_decoder_input(self, feats, shapes, dn_embed=None, dn_bbox=None):
- """
- Generates and prepares the input required for the decoder from the provided features and shapes.
+ def _get_decoder_input(
+ self,
+ feats: torch.Tensor,
+ shapes: list[list[int]],
+ dn_embed: torch.Tensor | None = None,
+ dn_bbox: torch.Tensor | None = None,
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Generate and prepare the input required for the decoder from the provided features and shapes.

  Args:
  feats (torch.Tensor): Processed features from encoder.
  shapes (list): List of feature map shapes.
- dn_embed (torch.Tensor, optional): Denoising embeddings. Default is None.
- dn_bbox (torch.Tensor, optional): Denoising bounding boxes. Default is None.
+ dn_embed (torch.Tensor, optional): Denoising embeddings.
+ dn_bbox (torch.Tensor, optional): Denoising bounding boxes.

  Returns:
- (tuple): Tuple containing embeddings, reference bounding boxes, encoded bounding boxes, and scores.
+ embeddings (torch.Tensor): Query embeddings for decoder.
+ refer_bbox (torch.Tensor): Reference bounding boxes.
+ enc_bboxes (torch.Tensor): Encoded bounding boxes.
+ enc_scores (torch.Tensor): Encoded scores.
  """
  bs = feats.shape[0]
- # Prepare input for decoder
- anchors, valid_mask = self._generate_anchors(shapes, dtype=feats.dtype, device=feats.device)
- features = self.enc_output(valid_mask * feats) # bs, h*w, 256
+ if self.dynamic or self.shapes != shapes:
+ self.anchors, self.valid_mask = self._generate_anchors(shapes, dtype=feats.dtype, device=feats.device)
+ self.shapes = shapes

+ # Prepare input for decoder
+ features = self.enc_output(self.valid_mask * feats) # bs, h*w, 256
  enc_outputs_scores = self.enc_score_head(features) # (bs, h*w, nc)

  # Query selection
@@ -795,7 +1085,7 @@ class RTDETRDecoder(nn.Module):
  # (bs, num_queries, 256)
  top_k_features = features[batch_ind, topk_ind].view(bs, self.num_queries, -1)
  # (bs, num_queries, 4)
- top_k_anchors = anchors[:, topk_ind].view(bs, self.num_queries, -1)
+ top_k_anchors = self.anchors[:, topk_ind].view(bs, self.num_queries, -1)

  # Dynamic anchors + static content
  refer_bbox = self.enc_bbox_head(top_k_features) + top_k_anchors
@@ -816,7 +1106,7 @@ class RTDETRDecoder(nn.Module):
  return embeddings, refer_bbox, enc_bboxes, enc_scores

  def _reset_parameters(self):
- """Initializes or resets the parameters of the model's various components with predefined weights and biases."""
+ """Initialize or reset the parameters of the model's various components with predefined weights and biases."""
  # Class and bbox head init
  bias_cls = bias_init_with_prob(0.01) / 80 * self.nc
  # NOTE: the weight initialization in `linear_init` would cause NaN when training with custom datasets.
@@ -841,27 +1131,39 @@ class RTDETRDecoder(nn.Module):


  class v10Detect(Detect):
- """
- v10 Detection head from https://arxiv.org/pdf/2405.14458.
+ """v10 Detection head from https://arxiv.org/pdf/2405.14458.

- Args:
- nc (int): Number of classes.
- ch (tuple): Tuple of channel sizes.
+ This class implements the YOLOv10 detection head with dual-assignment training and consistent dual predictions for
+ improved efficiency and performance.

  Attributes:
+ end2end (bool): End-to-end detection mode.
  max_det (int): Maximum number of detections.
+ cv3 (nn.ModuleList): Light classification head layers.
+ one2one_cv3 (nn.ModuleList): One-to-one classification head layers.

  Methods:
- __init__(self, nc=80, ch=()): Initializes the v10Detect object.
- forward(self, x): Performs forward pass of the v10Detect module.
- bias_init(self): Initializes biases of the Detect module.
-
+ __init__: Initialize the v10Detect object with specified number of classes and input channels.
+ forward: Perform forward pass of the v10Detect module.
+ bias_init: Initialize biases of the Detect module.
+ fuse: Remove the one2many head for inference optimization.
+
+ Examples:
+ Create a v10Detect head
+ >>> v10_detect = v10Detect(nc=80, ch=(256, 512, 1024))
+ >>> x = [torch.randn(1, 256, 80, 80), torch.randn(1, 512, 40, 40), torch.randn(1, 1024, 20, 20)]
+ >>> outputs = v10_detect(x)
  """

  end2end = True

- def __init__(self, nc=80, ch=()):
- """Initializes the v10Detect object with the specified number of classes and input channels."""
+ def __init__(self, nc: int = 80, ch: tuple = ()):
+ """Initialize the v10Detect object with the specified number of classes and input channels.
+
+ Args:
+ nc (int): Number of classes.
+ ch (tuple): Tuple of channel sizes from backbone feature maps.
+ """
  super().__init__(nc, ch)
  c3 = max(ch[0], min(self.nc, 100)) # channels
  # Light cls head
@@ -876,5 +1178,5 @@ class v10Detect(Detect):
  self.one2one_cv3 = copy.deepcopy(self.cv3)

  def fuse(self):
- """Removes the one2many head."""
+ """Remove the one2many head for inference optimization."""
  self.cv2 = self.cv3 = nn.ModuleList([nn.Identity()] * self.nl)
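
As a quick sanity check of the refactored Detect head above, the minimal sketch below builds the head and runs it in training mode. It assumes the 8.3.224 wheel is installed; the channel sizes and dummy feature-map shapes are illustrative values taken from the new class docstring, not required by the API.

import torch
from ultralytics.nn.modules.head import Detect

# Detection head for 80 classes over three feature levels (illustrative channel sizes).
detect = Detect(nc=80, ch=(256, 512, 1024))
detect.train()  # training mode returns raw per-level predictions; inference also needs strides set by the parent model

# Dummy multi-scale feature maps, mirroring the docstring example.
feats = [torch.randn(1, 256, 80, 80), torch.randn(1, 512, 40, 40), torch.randn(1, 1024, 20, 20)]
outputs = detect(feats)
print([o.shape for o in outputs])  # one tensor per level, each (1, 4 * reg_max + nc, H, W) = (1, 144, H, W)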