dgenerate-ultralytics-headless 8.3.214__py3-none-any.whl → 8.3.248__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (236)
  1. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/METADATA +13 -14
  2. dgenerate_ultralytics_headless-8.3.248.dist-info/RECORD +298 -0
  3. tests/__init__.py +5 -7
  4. tests/conftest.py +8 -15
  5. tests/test_cli.py +1 -1
  6. tests/test_cuda.py +5 -8
  7. tests/test_engine.py +1 -1
  8. tests/test_exports.py +57 -12
  9. tests/test_integrations.py +4 -4
  10. tests/test_python.py +84 -53
  11. tests/test_solutions.py +160 -151
  12. ultralytics/__init__.py +1 -1
  13. ultralytics/cfg/__init__.py +56 -62
  14. ultralytics/cfg/datasets/Argoverse.yaml +7 -6
  15. ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
  16. ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
  17. ultralytics/cfg/datasets/ImageNet.yaml +1 -1
  18. ultralytics/cfg/datasets/VOC.yaml +15 -16
  19. ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
  20. ultralytics/cfg/datasets/coco-pose.yaml +21 -0
  21. ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
  22. ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
  23. ultralytics/cfg/datasets/dog-pose.yaml +28 -0
  24. ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
  25. ultralytics/cfg/datasets/dota8.yaml +2 -2
  26. ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
  27. ultralytics/cfg/datasets/kitti.yaml +27 -0
  28. ultralytics/cfg/datasets/lvis.yaml +5 -5
  29. ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
  30. ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
  31. ultralytics/cfg/datasets/xView.yaml +16 -16
  32. ultralytics/cfg/default.yaml +1 -1
  33. ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
  34. ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
  35. ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
  36. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
  37. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
  38. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
  39. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
  40. ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
  41. ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
  42. ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
  43. ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
  44. ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
  45. ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
  46. ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
  47. ultralytics/cfg/models/v6/yolov6.yaml +1 -1
  48. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
  49. ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
  50. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
  51. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
  52. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
  53. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
  54. ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
  55. ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
  56. ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
  57. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
  58. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
  59. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
  60. ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
  61. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
  62. ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
  63. ultralytics/data/__init__.py +4 -4
  64. ultralytics/data/annotator.py +3 -4
  65. ultralytics/data/augment.py +285 -475
  66. ultralytics/data/base.py +18 -26
  67. ultralytics/data/build.py +147 -25
  68. ultralytics/data/converter.py +36 -46
  69. ultralytics/data/dataset.py +46 -74
  70. ultralytics/data/loaders.py +42 -49
  71. ultralytics/data/split.py +5 -6
  72. ultralytics/data/split_dota.py +8 -15
  73. ultralytics/data/utils.py +34 -43
  74. ultralytics/engine/exporter.py +319 -237
  75. ultralytics/engine/model.py +148 -188
  76. ultralytics/engine/predictor.py +29 -38
  77. ultralytics/engine/results.py +177 -311
  78. ultralytics/engine/trainer.py +83 -59
  79. ultralytics/engine/tuner.py +23 -34
  80. ultralytics/engine/validator.py +39 -22
  81. ultralytics/hub/__init__.py +16 -19
  82. ultralytics/hub/auth.py +6 -12
  83. ultralytics/hub/google/__init__.py +7 -10
  84. ultralytics/hub/session.py +15 -25
  85. ultralytics/hub/utils.py +5 -8
  86. ultralytics/models/__init__.py +1 -1
  87. ultralytics/models/fastsam/__init__.py +1 -1
  88. ultralytics/models/fastsam/model.py +8 -10
  89. ultralytics/models/fastsam/predict.py +17 -29
  90. ultralytics/models/fastsam/utils.py +1 -2
  91. ultralytics/models/fastsam/val.py +5 -7
  92. ultralytics/models/nas/__init__.py +1 -1
  93. ultralytics/models/nas/model.py +5 -8
  94. ultralytics/models/nas/predict.py +7 -9
  95. ultralytics/models/nas/val.py +1 -2
  96. ultralytics/models/rtdetr/__init__.py +1 -1
  97. ultralytics/models/rtdetr/model.py +5 -8
  98. ultralytics/models/rtdetr/predict.py +15 -19
  99. ultralytics/models/rtdetr/train.py +10 -13
  100. ultralytics/models/rtdetr/val.py +21 -23
  101. ultralytics/models/sam/__init__.py +15 -2
  102. ultralytics/models/sam/amg.py +14 -20
  103. ultralytics/models/sam/build.py +26 -19
  104. ultralytics/models/sam/build_sam3.py +377 -0
  105. ultralytics/models/sam/model.py +29 -32
  106. ultralytics/models/sam/modules/blocks.py +83 -144
  107. ultralytics/models/sam/modules/decoders.py +19 -37
  108. ultralytics/models/sam/modules/encoders.py +44 -101
  109. ultralytics/models/sam/modules/memory_attention.py +16 -30
  110. ultralytics/models/sam/modules/sam.py +200 -73
  111. ultralytics/models/sam/modules/tiny_encoder.py +64 -83
  112. ultralytics/models/sam/modules/transformer.py +18 -28
  113. ultralytics/models/sam/modules/utils.py +174 -50
  114. ultralytics/models/sam/predict.py +2248 -350
  115. ultralytics/models/sam/sam3/__init__.py +3 -0
  116. ultralytics/models/sam/sam3/decoder.py +546 -0
  117. ultralytics/models/sam/sam3/encoder.py +529 -0
  118. ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
  119. ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
  120. ultralytics/models/sam/sam3/model_misc.py +199 -0
  121. ultralytics/models/sam/sam3/necks.py +129 -0
  122. ultralytics/models/sam/sam3/sam3_image.py +339 -0
  123. ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
  124. ultralytics/models/sam/sam3/vitdet.py +547 -0
  125. ultralytics/models/sam/sam3/vl_combiner.py +160 -0
  126. ultralytics/models/utils/loss.py +14 -26
  127. ultralytics/models/utils/ops.py +13 -17
  128. ultralytics/models/yolo/__init__.py +1 -1
  129. ultralytics/models/yolo/classify/predict.py +9 -12
  130. ultralytics/models/yolo/classify/train.py +11 -32
  131. ultralytics/models/yolo/classify/val.py +29 -28
  132. ultralytics/models/yolo/detect/predict.py +7 -10
  133. ultralytics/models/yolo/detect/train.py +11 -20
  134. ultralytics/models/yolo/detect/val.py +70 -58
  135. ultralytics/models/yolo/model.py +36 -53
  136. ultralytics/models/yolo/obb/predict.py +5 -14
  137. ultralytics/models/yolo/obb/train.py +11 -14
  138. ultralytics/models/yolo/obb/val.py +39 -36
  139. ultralytics/models/yolo/pose/__init__.py +1 -1
  140. ultralytics/models/yolo/pose/predict.py +6 -21
  141. ultralytics/models/yolo/pose/train.py +10 -15
  142. ultralytics/models/yolo/pose/val.py +38 -57
  143. ultralytics/models/yolo/segment/predict.py +14 -18
  144. ultralytics/models/yolo/segment/train.py +3 -6
  145. ultralytics/models/yolo/segment/val.py +93 -45
  146. ultralytics/models/yolo/world/train.py +8 -14
  147. ultralytics/models/yolo/world/train_world.py +11 -34
  148. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  149. ultralytics/models/yolo/yoloe/predict.py +16 -23
  150. ultralytics/models/yolo/yoloe/train.py +30 -43
  151. ultralytics/models/yolo/yoloe/train_seg.py +5 -10
  152. ultralytics/models/yolo/yoloe/val.py +15 -20
  153. ultralytics/nn/__init__.py +7 -7
  154. ultralytics/nn/autobackend.py +145 -77
  155. ultralytics/nn/modules/__init__.py +60 -60
  156. ultralytics/nn/modules/activation.py +4 -6
  157. ultralytics/nn/modules/block.py +132 -216
  158. ultralytics/nn/modules/conv.py +52 -97
  159. ultralytics/nn/modules/head.py +50 -103
  160. ultralytics/nn/modules/transformer.py +76 -88
  161. ultralytics/nn/modules/utils.py +16 -21
  162. ultralytics/nn/tasks.py +94 -154
  163. ultralytics/nn/text_model.py +40 -67
  164. ultralytics/solutions/__init__.py +12 -12
  165. ultralytics/solutions/ai_gym.py +11 -17
  166. ultralytics/solutions/analytics.py +15 -16
  167. ultralytics/solutions/config.py +5 -6
  168. ultralytics/solutions/distance_calculation.py +10 -13
  169. ultralytics/solutions/heatmap.py +7 -13
  170. ultralytics/solutions/instance_segmentation.py +5 -8
  171. ultralytics/solutions/object_blurrer.py +7 -10
  172. ultralytics/solutions/object_counter.py +12 -19
  173. ultralytics/solutions/object_cropper.py +8 -14
  174. ultralytics/solutions/parking_management.py +33 -31
  175. ultralytics/solutions/queue_management.py +10 -12
  176. ultralytics/solutions/region_counter.py +9 -12
  177. ultralytics/solutions/security_alarm.py +15 -20
  178. ultralytics/solutions/similarity_search.py +10 -15
  179. ultralytics/solutions/solutions.py +75 -74
  180. ultralytics/solutions/speed_estimation.py +7 -10
  181. ultralytics/solutions/streamlit_inference.py +2 -4
  182. ultralytics/solutions/templates/similarity-search.html +7 -18
  183. ultralytics/solutions/trackzone.py +7 -10
  184. ultralytics/solutions/vision_eye.py +5 -8
  185. ultralytics/trackers/__init__.py +1 -1
  186. ultralytics/trackers/basetrack.py +3 -5
  187. ultralytics/trackers/bot_sort.py +10 -27
  188. ultralytics/trackers/byte_tracker.py +14 -30
  189. ultralytics/trackers/track.py +3 -6
  190. ultralytics/trackers/utils/gmc.py +11 -22
  191. ultralytics/trackers/utils/kalman_filter.py +37 -48
  192. ultralytics/trackers/utils/matching.py +12 -15
  193. ultralytics/utils/__init__.py +116 -116
  194. ultralytics/utils/autobatch.py +2 -4
  195. ultralytics/utils/autodevice.py +17 -18
  196. ultralytics/utils/benchmarks.py +32 -46
  197. ultralytics/utils/callbacks/base.py +8 -10
  198. ultralytics/utils/callbacks/clearml.py +5 -13
  199. ultralytics/utils/callbacks/comet.py +32 -46
  200. ultralytics/utils/callbacks/dvc.py +13 -18
  201. ultralytics/utils/callbacks/mlflow.py +4 -5
  202. ultralytics/utils/callbacks/neptune.py +7 -15
  203. ultralytics/utils/callbacks/platform.py +314 -38
  204. ultralytics/utils/callbacks/raytune.py +3 -4
  205. ultralytics/utils/callbacks/tensorboard.py +23 -31
  206. ultralytics/utils/callbacks/wb.py +10 -13
  207. ultralytics/utils/checks.py +99 -76
  208. ultralytics/utils/cpu.py +3 -8
  209. ultralytics/utils/dist.py +8 -12
  210. ultralytics/utils/downloads.py +20 -30
  211. ultralytics/utils/errors.py +6 -14
  212. ultralytics/utils/events.py +2 -4
  213. ultralytics/utils/export/__init__.py +4 -236
  214. ultralytics/utils/export/engine.py +237 -0
  215. ultralytics/utils/export/imx.py +91 -55
  216. ultralytics/utils/export/tensorflow.py +231 -0
  217. ultralytics/utils/files.py +24 -28
  218. ultralytics/utils/git.py +9 -11
  219. ultralytics/utils/instance.py +30 -51
  220. ultralytics/utils/logger.py +212 -114
  221. ultralytics/utils/loss.py +14 -22
  222. ultralytics/utils/metrics.py +126 -155
  223. ultralytics/utils/nms.py +13 -16
  224. ultralytics/utils/ops.py +107 -165
  225. ultralytics/utils/patches.py +33 -21
  226. ultralytics/utils/plotting.py +72 -80
  227. ultralytics/utils/tal.py +25 -39
  228. ultralytics/utils/torch_utils.py +52 -78
  229. ultralytics/utils/tqdm.py +20 -20
  230. ultralytics/utils/triton.py +13 -19
  231. ultralytics/utils/tuner.py +17 -5
  232. dgenerate_ultralytics_headless-8.3.214.dist-info/RECORD +0 -283
  233. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/WHEEL +0 -0
  234. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/entry_points.txt +0 -0
  235. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/licenses/LICENSE +0 -0
  236. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/top_level.txt +0 -0
@@ -20,12 +20,11 @@ from .conv import Conv, DWConv
20
20
  from .transformer import MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer
21
21
  from .utils import bias_init_with_prob, linear_init
22
22
 
23
- __all__ = "Detect", "Segment", "Pose", "Classify", "OBB", "RTDETRDecoder", "v10Detect", "YOLOEDetect", "YOLOESegment"
23
+ __all__ = "OBB", "Classify", "Detect", "Pose", "RTDETRDecoder", "Segment", "YOLOEDetect", "YOLOESegment", "v10Detect"
24
24
 
25
25
 
26
26
  class Detect(nn.Module):
27
- """
28
- YOLO Detect head for object detection models.
27
+ """YOLO Detect head for object detection models.
29
28
 
30
29
  This class implements the detection head used in YOLO models for predicting bounding boxes and class probabilities.
31
30
  It supports both training and inference modes, with optional end-to-end detection capabilities.
@@ -78,8 +77,7 @@ class Detect(nn.Module):
78
77
  xyxy = False # xyxy or xywh output
79
78
 
80
79
  def __init__(self, nc: int = 80, ch: tuple = ()):
81
- """
82
- Initialize the YOLO detection layer with specified number of classes and channels.
80
+ """Initialize the YOLO detection layer with specified number of classes and channels.
83
81
 
84
82
  Args:
85
83
  nc (int): Number of classes.
@@ -126,15 +124,14 @@ class Detect(nn.Module):
126
124
  return y if self.export else (y, x)
127
125
 
128
126
  def forward_end2end(self, x: list[torch.Tensor]) -> dict | tuple:
129
- """
130
- Perform forward pass of the v10Detect module.
127
+ """Perform forward pass of the v10Detect module.
131
128
 
132
129
  Args:
133
130
  x (list[torch.Tensor]): Input feature maps from different levels.
134
131
 
135
132
  Returns:
136
- outputs (dict | tuple): Training mode returns dict with one2many and one2one outputs.
137
- Inference mode returns processed detections or tuple with detections and raw outputs.
133
+ outputs (dict | tuple): Training mode returns dict with one2many and one2one outputs. Inference mode returns
134
+ processed detections or tuple with detections and raw outputs.
138
135
  """
139
136
  x_detach = [xi.detach() for xi in x]
140
137
  one2one = [
@@ -150,8 +147,7 @@ class Detect(nn.Module):
150
147
  return y if self.export else (y, {"one2many": x, "one2one": one2one})
151
148
 
152
149
  def _inference(self, x: list[torch.Tensor]) -> torch.Tensor:
153
- """
154
- Decode predicted bounding boxes and class probabilities based on multiple-level feature maps.
150
+ """Decode predicted bounding boxes and class probabilities based on multiple-level feature maps.
155
151
 
156
152
  Args:
157
153
  x (list[torch.Tensor]): List of feature maps from different detection layers.
@@ -166,22 +162,8 @@ class Detect(nn.Module):
166
162
  self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
167
163
  self.shape = shape
168
164
 
169
- if self.export and self.format in {"saved_model", "pb", "tflite", "edgetpu", "tfjs"}: # avoid TF FlexSplitV ops
170
- box = x_cat[:, : self.reg_max * 4]
171
- cls = x_cat[:, self.reg_max * 4 :]
172
- else:
173
- box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
174
-
175
- if self.export and self.format in {"tflite", "edgetpu"}:
176
- # Precompute normalization factor to increase numerical stability
177
- # See https://github.com/ultralytics/ultralytics/issues/7371
178
- grid_h = shape[2]
179
- grid_w = shape[3]
180
- grid_size = torch.tensor([grid_w, grid_h, grid_w, grid_h], device=box.device).reshape(1, 4, 1)
181
- norm = self.strides / (self.stride[0] * grid_size)
182
- dbox = self.decode_bboxes(self.dfl(box) * norm, self.anchors.unsqueeze(0) * norm[:, :2])
183
- else:
184
- dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides
165
+ box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
166
+ dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides
185
167
  return torch.cat((dbox, cls.sigmoid()), 1)
186
168
 
187
169
  def bias_init(self):
@@ -208,8 +190,7 @@ class Detect(nn.Module):
208
190
 
209
191
  @staticmethod
210
192
  def postprocess(preds: torch.Tensor, max_det: int, nc: int = 80) -> torch.Tensor:
211
- """
212
- Post-process YOLO model predictions.
193
+ """Post-process YOLO model predictions.
213
194
 
214
195
  Args:
215
196
  preds (torch.Tensor): Raw predictions with shape (batch_size, num_anchors, 4 + nc) with last dimension
@@ -232,8 +213,7 @@ class Detect(nn.Module):
232
213
 
233
214
 
234
215
  class Segment(Detect):
235
- """
236
- YOLO Segment head for segmentation models.
216
+ """YOLO Segment head for segmentation models.
237
217
 
238
218
  This class extends the Detect head to include mask prediction capabilities for instance segmentation tasks.
239
219
 
@@ -254,8 +234,7 @@ class Segment(Detect):
254
234
  """
255
235
 
256
236
  def __init__(self, nc: int = 80, nm: int = 32, npr: int = 256, ch: tuple = ()):
257
- """
258
- Initialize the YOLO model attributes such as the number of masks, prototypes, and the convolution layers.
237
+ """Initialize the YOLO model attributes such as the number of masks, prototypes, and the convolution layers.
259
238
 
260
239
  Args:
261
240
  nc (int): Number of classes.
@@ -284,8 +263,7 @@ class Segment(Detect):
284
263
 
285
264
 
286
265
  class OBB(Detect):
287
- """
288
- YOLO OBB detection head for detection with rotation models.
266
+ """YOLO OBB detection head for detection with rotation models.
289
267
 
290
268
  This class extends the Detect head to include oriented bounding box prediction with rotation angles.
291
269
 
@@ -306,8 +284,7 @@ class OBB(Detect):
306
284
  """
307
285
 
308
286
  def __init__(self, nc: int = 80, ne: int = 1, ch: tuple = ()):
309
- """
310
- Initialize OBB with number of classes `nc` and layer channels `ch`.
287
+ """Initialize OBB with number of classes `nc` and layer channels `ch`.
311
288
 
312
289
  Args:
313
290
  nc (int): Number of classes.
@@ -340,8 +317,7 @@ class OBB(Detect):
340
317
 
341
318
 
342
319
  class Pose(Detect):
343
- """
344
- YOLO Pose head for keypoints models.
320
+ """YOLO Pose head for keypoints models.
345
321
 
346
322
  This class extends the Detect head to include keypoint prediction capabilities for pose estimation tasks.
347
323
 
@@ -362,8 +338,7 @@ class Pose(Detect):
362
338
  """
363
339
 
364
340
  def __init__(self, nc: int = 80, kpt_shape: tuple = (17, 3), ch: tuple = ()):
365
- """
366
- Initialize YOLO network with default parameters and Convolutional Layers.
341
+ """Initialize YOLO network with default parameters and Convolutional Layers.
367
342
 
368
343
  Args:
369
344
  nc (int): Number of classes.
@@ -391,20 +366,9 @@ class Pose(Detect):
391
366
  """Decode keypoints from predictions."""
392
367
  ndim = self.kpt_shape[1]
393
368
  if self.export:
394
- if self.format in {
395
- "tflite",
396
- "edgetpu",
397
- }: # required for TFLite export to avoid 'PLACEHOLDER_FOR_GREATER_OP_CODES' bug
398
- # Precompute normalization factor to increase numerical stability
399
- y = kpts.view(bs, *self.kpt_shape, -1)
400
- grid_h, grid_w = self.shape[2], self.shape[3]
401
- grid_size = torch.tensor([grid_w, grid_h], device=y.device).reshape(1, 2, 1)
402
- norm = self.strides / (self.stride[0] * grid_size)
403
- a = (y[:, :, :2] * 2.0 + (self.anchors - 0.5)) * norm
404
- else:
405
- # NCNN fix
406
- y = kpts.view(bs, *self.kpt_shape, -1)
407
- a = (y[:, :, :2] * 2.0 + (self.anchors - 0.5)) * self.strides
369
+ # NCNN fix
370
+ y = kpts.view(bs, *self.kpt_shape, -1)
371
+ a = (y[:, :, :2] * 2.0 + (self.anchors - 0.5)) * self.strides
408
372
  if ndim == 3:
409
373
  a = torch.cat((a, y[:, :, 2:3].sigmoid()), 2)
410
374
  return a.view(bs, self.nk, -1)
@@ -421,8 +385,7 @@ class Pose(Detect):
421
385
 
422
386
 
423
387
  class Classify(nn.Module):
424
- """
425
- YOLO classification head, i.e. x(b,c1,20,20) to x(b,c2).
388
+ """YOLO classification head, i.e. x(b,c1,20,20) to x(b,c2).
426
389
 
427
390
  This class implements a classification head that transforms feature maps into class predictions.
428
391
 
@@ -446,8 +409,7 @@ class Classify(nn.Module):
446
409
  export = False # export mode
447
410
 
448
411
  def __init__(self, c1: int, c2: int, k: int = 1, s: int = 1, p: int | None = None, g: int = 1):
449
- """
450
- Initialize YOLO classification head to transform input tensor from (b,c1,20,20) to (b,c2) shape.
412
+ """Initialize YOLO classification head to transform input tensor from (b,c1,20,20) to (b,c2) shape.
451
413
 
452
414
  Args:
453
415
  c1 (int): Number of input channels.
@@ -476,11 +438,10 @@ class Classify(nn.Module):
476
438
 
477
439
 
478
440
  class WorldDetect(Detect):
479
- """
480
- Head for integrating YOLO detection models with semantic understanding from text embeddings.
441
+ """Head for integrating YOLO detection models with semantic understanding from text embeddings.
481
442
 
482
- This class extends the standard Detect head to incorporate text embeddings for enhanced semantic understanding
483
- in object detection tasks.
443
+ This class extends the standard Detect head to incorporate text embeddings for enhanced semantic understanding in
444
+ object detection tasks.
484
445
 
485
446
  Attributes:
486
447
  cv3 (nn.ModuleList): Convolution layers for embedding features.
@@ -499,8 +460,7 @@ class WorldDetect(Detect):
499
460
  """
500
461
 
501
462
  def __init__(self, nc: int = 80, embed: int = 512, with_bn: bool = False, ch: tuple = ()):
502
- """
503
- Initialize YOLO detection layer with nc classes and layer channels ch.
463
+ """Initialize YOLO detection layer with nc classes and layer channels ch.
504
464
 
505
465
  Args:
506
466
  nc (int): Number of classes.
@@ -534,11 +494,10 @@ class WorldDetect(Detect):
534
494
 
535
495
 
536
496
  class LRPCHead(nn.Module):
537
- """
538
- Lightweight Region Proposal and Classification Head for efficient object detection.
497
+ """Lightweight Region Proposal and Classification Head for efficient object detection.
539
498
 
540
- This head combines region proposal filtering with classification to enable efficient detection with
541
- dynamic vocabulary support.
499
+ This head combines region proposal filtering with classification to enable efficient detection with dynamic
500
+ vocabulary support.
542
501
 
543
502
  Attributes:
544
503
  vocab (nn.Module): Vocabulary/classification layer.
@@ -559,8 +518,7 @@ class LRPCHead(nn.Module):
559
518
  """
560
519
 
561
520
  def __init__(self, vocab: nn.Module, pf: nn.Module, loc: nn.Module, enabled: bool = True):
562
- """
563
- Initialize LRPCHead with vocabulary, proposal filter, and localization components.
521
+ """Initialize LRPCHead with vocabulary, proposal filter, and localization components.
564
522
 
565
523
  Args:
566
524
  vocab (nn.Module): Vocabulary/classification module.
@@ -574,7 +532,8 @@ class LRPCHead(nn.Module):
574
532
  self.loc = loc
575
533
  self.enabled = enabled
576
534
 
577
- def conv2linear(self, conv: nn.Conv2d) -> nn.Linear:
535
+ @staticmethod
536
+ def conv2linear(conv: nn.Conv2d) -> nn.Linear:
578
537
  """Convert a 1x1 convolutional layer to a linear layer."""
579
538
  assert isinstance(conv, nn.Conv2d) and conv.kernel_size == (1, 1)
580
539
  linear = nn.Linear(conv.in_channels, conv.out_channels)
@@ -599,8 +558,7 @@ class LRPCHead(nn.Module):
599
558
 
600
559
 
601
560
  class YOLOEDetect(Detect):
602
- """
603
- Head for integrating YOLO detection models with semantic understanding from text embeddings.
561
+ """Head for integrating YOLO detection models with semantic understanding from text embeddings.
604
562
 
605
563
  This class extends the standard Detect head to support text-guided detection with enhanced semantic understanding
606
564
  through text embeddings and visual prompt embeddings.
@@ -632,8 +590,7 @@ class YOLOEDetect(Detect):
632
590
  is_fused = False
633
591
 
634
592
  def __init__(self, nc: int = 80, embed: int = 512, with_bn: bool = False, ch: tuple = ()):
635
- """
636
- Initialize YOLO detection layer with nc classes and layer channels ch.
593
+ """Initialize YOLO detection layer with nc classes and layer channels ch.
637
594
 
638
595
  Args:
639
596
  nc (int): Number of classes.
@@ -787,11 +744,10 @@ class YOLOEDetect(Detect):
787
744
 
788
745
 
789
746
  class YOLOESegment(YOLOEDetect):
790
- """
791
- YOLO segmentation head with text embedding capabilities.
747
+ """YOLO segmentation head with text embedding capabilities.
792
748
 
793
- This class extends YOLOEDetect to include mask prediction capabilities for instance segmentation tasks
794
- with text-guided semantic understanding.
749
+ This class extends YOLOEDetect to include mask prediction capabilities for instance segmentation tasks with
750
+ text-guided semantic understanding.
795
751
 
796
752
  Attributes:
797
753
  nm (int): Number of masks.
@@ -813,8 +769,7 @@ class YOLOESegment(YOLOEDetect):
813
769
  def __init__(
814
770
  self, nc: int = 80, nm: int = 32, npr: int = 256, embed: int = 512, with_bn: bool = False, ch: tuple = ()
815
771
  ):
816
- """
817
- Initialize YOLOESegment with class count, mask parameters, and embedding dimensions.
772
+ """Initialize YOLOESegment with class count, mask parameters, and embedding dimensions.
818
773
 
819
774
  Args:
820
775
  nc (int): Number of classes.
@@ -855,8 +810,7 @@ class YOLOESegment(YOLOEDetect):
855
810
 
856
811
 
857
812
  class RTDETRDecoder(nn.Module):
858
- """
859
- Real-Time Deformable Transformer Decoder (RTDETRDecoder) module for object detection.
813
+ """Real-Time Deformable Transformer Decoder (RTDETRDecoder) module for object detection.
860
814
 
861
815
  This decoder module utilizes Transformer architecture along with deformable convolutions to predict bounding boxes
862
816
  and class labels for objects in an image. It integrates features from multiple layers and runs through a series of
@@ -920,8 +874,7 @@ class RTDETRDecoder(nn.Module):
920
874
  box_noise_scale: float = 1.0,
921
875
  learnt_init_query: bool = False,
922
876
  ):
923
- """
924
- Initialize the RTDETRDecoder module with the given parameters.
877
+ """Initialize the RTDETRDecoder module with the given parameters.
925
878
 
926
879
  Args:
927
880
  nc (int): Number of classes.
@@ -981,8 +934,7 @@ class RTDETRDecoder(nn.Module):
981
934
  self._reset_parameters()
982
935
 
983
936
  def forward(self, x: list[torch.Tensor], batch: dict | None = None) -> tuple | torch.Tensor:
984
- """
985
- Run the forward pass of the module, returning bounding box and classification scores for the input.
937
+ """Run the forward pass of the module, returning bounding box and classification scores for the input.
986
938
 
987
939
  Args:
988
940
  x (list[torch.Tensor]): List of feature maps from the backbone.
@@ -1030,16 +982,15 @@ class RTDETRDecoder(nn.Module):
1030
982
  y = torch.cat((dec_bboxes.squeeze(0), dec_scores.squeeze(0).sigmoid()), -1)
1031
983
  return y if self.export else (y, x)
1032
984
 
985
+ @staticmethod
1033
986
  def _generate_anchors(
1034
- self,
1035
987
  shapes: list[list[int]],
1036
988
  grid_size: float = 0.05,
1037
989
  dtype: torch.dtype = torch.float32,
1038
990
  device: str = "cpu",
1039
991
  eps: float = 1e-2,
1040
992
  ) -> tuple[torch.Tensor, torch.Tensor]:
1041
- """
1042
- Generate anchor bounding boxes for given shapes with specific grid size and validate them.
993
+ """Generate anchor bounding boxes for given shapes with specific grid size and validate them.
1043
994
 
1044
995
  Args:
1045
996
  shapes (list): List of feature map shapes.
@@ -1071,8 +1022,7 @@ class RTDETRDecoder(nn.Module):
1071
1022
  return anchors, valid_mask
1072
1023
 
1073
1024
  def _get_encoder_input(self, x: list[torch.Tensor]) -> tuple[torch.Tensor, list[list[int]]]:
1074
- """
1075
- Process and return encoder inputs by getting projection features from input and concatenating them.
1025
+ """Process and return encoder inputs by getting projection features from input and concatenating them.
1076
1026
 
1077
1027
  Args:
1078
1028
  x (list[torch.Tensor]): List of feature maps from the backbone.
@@ -1104,8 +1054,7 @@ class RTDETRDecoder(nn.Module):
1104
1054
  dn_embed: torch.Tensor | None = None,
1105
1055
  dn_bbox: torch.Tensor | None = None,
1106
1056
  ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
1107
- """
1108
- Generate and prepare the input required for the decoder from the provided features and shapes.
1057
+ """Generate and prepare the input required for the decoder from the provided features and shapes.
1109
1058
 
1110
1059
  Args:
1111
1060
  feats (torch.Tensor): Processed features from encoder.
@@ -1129,9 +1078,9 @@ class RTDETRDecoder(nn.Module):
1129
1078
  enc_outputs_scores = self.enc_score_head(features) # (bs, h*w, nc)
1130
1079
 
1131
1080
  # Query selection
1132
- # (bs, num_queries)
1081
+ # (bs*num_queries,)
1133
1082
  topk_ind = torch.topk(enc_outputs_scores.max(-1).values, self.num_queries, dim=1).indices.view(-1)
1134
- # (bs, num_queries)
1083
+ # (bs*num_queries,)
1135
1084
  batch_ind = torch.arange(end=bs, dtype=topk_ind.dtype).unsqueeze(-1).repeat(1, self.num_queries).view(-1)
1136
1085
 
1137
1086
  # (bs, num_queries, 256)
@@ -1183,11 +1132,10 @@ class RTDETRDecoder(nn.Module):
1183
1132
 
1184
1133
 
1185
1134
  class v10Detect(Detect):
1186
- """
1187
- v10 Detection head from https://arxiv.org/pdf/2405.14458.
1135
+ """v10 Detection head from https://arxiv.org/pdf/2405.14458.
1188
1136
 
1189
- This class implements the YOLOv10 detection head with dual-assignment training and consistent dual predictions
1190
- for improved efficiency and performance.
1137
+ This class implements the YOLOv10 detection head with dual-assignment training and consistent dual predictions for
1138
+ improved efficiency and performance.
1191
1139
 
1192
1140
  Attributes:
1193
1141
  end2end (bool): End-to-end detection mode.
@@ -1211,8 +1159,7 @@ class v10Detect(Detect):
1211
1159
  end2end = True
1212
1160
 
1213
1161
  def __init__(self, nc: int = 80, ch: tuple = ()):
1214
- """
1215
- Initialize the v10Detect object with the specified number of classes and input channels.
1162
+ """Initialize the v10Detect object with the specified number of classes and input channels.
1216
1163
 
1217
1164
  Args:
1218
1165
  nc (int): Number of classes.