dgenerate-ultralytics-headless 8.3.222__py3-none-any.whl → 8.3.225__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
  Files changed (158)
  1. {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/METADATA +2 -2
  2. dgenerate_ultralytics_headless-8.3.225.dist-info/RECORD +286 -0
  3. tests/conftest.py +5 -8
  4. tests/test_cli.py +1 -8
  5. tests/test_python.py +1 -2
  6. ultralytics/__init__.py +1 -1
  7. ultralytics/cfg/__init__.py +34 -49
  8. ultralytics/cfg/datasets/ImageNet.yaml +1 -1
  9. ultralytics/cfg/datasets/kitti.yaml +27 -0
  10. ultralytics/cfg/datasets/lvis.yaml +5 -5
  11. ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
  12. ultralytics/data/annotator.py +3 -4
  13. ultralytics/data/augment.py +244 -323
  14. ultralytics/data/base.py +12 -22
  15. ultralytics/data/build.py +47 -40
  16. ultralytics/data/converter.py +32 -42
  17. ultralytics/data/dataset.py +43 -71
  18. ultralytics/data/loaders.py +22 -34
  19. ultralytics/data/split.py +5 -6
  20. ultralytics/data/split_dota.py +8 -15
  21. ultralytics/data/utils.py +27 -36
  22. ultralytics/engine/exporter.py +49 -116
  23. ultralytics/engine/model.py +144 -180
  24. ultralytics/engine/predictor.py +18 -29
  25. ultralytics/engine/results.py +165 -231
  26. ultralytics/engine/trainer.py +11 -19
  27. ultralytics/engine/tuner.py +13 -23
  28. ultralytics/engine/validator.py +6 -10
  29. ultralytics/hub/__init__.py +7 -12
  30. ultralytics/hub/auth.py +6 -12
  31. ultralytics/hub/google/__init__.py +7 -10
  32. ultralytics/hub/session.py +15 -25
  33. ultralytics/hub/utils.py +3 -6
  34. ultralytics/models/fastsam/model.py +6 -8
  35. ultralytics/models/fastsam/predict.py +5 -10
  36. ultralytics/models/fastsam/utils.py +1 -2
  37. ultralytics/models/fastsam/val.py +2 -4
  38. ultralytics/models/nas/model.py +5 -8
  39. ultralytics/models/nas/predict.py +7 -9
  40. ultralytics/models/nas/val.py +1 -2
  41. ultralytics/models/rtdetr/model.py +5 -8
  42. ultralytics/models/rtdetr/predict.py +15 -18
  43. ultralytics/models/rtdetr/train.py +10 -13
  44. ultralytics/models/rtdetr/val.py +13 -20
  45. ultralytics/models/sam/amg.py +12 -18
  46. ultralytics/models/sam/build.py +6 -9
  47. ultralytics/models/sam/model.py +16 -23
  48. ultralytics/models/sam/modules/blocks.py +62 -84
  49. ultralytics/models/sam/modules/decoders.py +17 -24
  50. ultralytics/models/sam/modules/encoders.py +40 -56
  51. ultralytics/models/sam/modules/memory_attention.py +10 -16
  52. ultralytics/models/sam/modules/sam.py +41 -47
  53. ultralytics/models/sam/modules/tiny_encoder.py +64 -83
  54. ultralytics/models/sam/modules/transformer.py +17 -27
  55. ultralytics/models/sam/modules/utils.py +31 -42
  56. ultralytics/models/sam/predict.py +172 -209
  57. ultralytics/models/utils/loss.py +14 -26
  58. ultralytics/models/utils/ops.py +13 -17
  59. ultralytics/models/yolo/classify/predict.py +8 -11
  60. ultralytics/models/yolo/classify/train.py +8 -16
  61. ultralytics/models/yolo/classify/val.py +13 -20
  62. ultralytics/models/yolo/detect/predict.py +4 -8
  63. ultralytics/models/yolo/detect/train.py +11 -20
  64. ultralytics/models/yolo/detect/val.py +38 -48
  65. ultralytics/models/yolo/model.py +35 -47
  66. ultralytics/models/yolo/obb/predict.py +5 -8
  67. ultralytics/models/yolo/obb/train.py +11 -14
  68. ultralytics/models/yolo/obb/val.py +20 -28
  69. ultralytics/models/yolo/pose/predict.py +5 -8
  70. ultralytics/models/yolo/pose/train.py +4 -8
  71. ultralytics/models/yolo/pose/val.py +31 -39
  72. ultralytics/models/yolo/segment/predict.py +9 -14
  73. ultralytics/models/yolo/segment/train.py +3 -6
  74. ultralytics/models/yolo/segment/val.py +16 -26
  75. ultralytics/models/yolo/world/train.py +8 -14
  76. ultralytics/models/yolo/world/train_world.py +11 -16
  77. ultralytics/models/yolo/yoloe/predict.py +16 -23
  78. ultralytics/models/yolo/yoloe/train.py +30 -43
  79. ultralytics/models/yolo/yoloe/train_seg.py +5 -10
  80. ultralytics/models/yolo/yoloe/val.py +15 -20
  81. ultralytics/nn/autobackend.py +10 -18
  82. ultralytics/nn/modules/activation.py +4 -6
  83. ultralytics/nn/modules/block.py +99 -185
  84. ultralytics/nn/modules/conv.py +45 -90
  85. ultralytics/nn/modules/head.py +44 -98
  86. ultralytics/nn/modules/transformer.py +44 -76
  87. ultralytics/nn/modules/utils.py +14 -19
  88. ultralytics/nn/tasks.py +86 -146
  89. ultralytics/nn/text_model.py +25 -40
  90. ultralytics/solutions/ai_gym.py +10 -16
  91. ultralytics/solutions/analytics.py +7 -10
  92. ultralytics/solutions/config.py +4 -5
  93. ultralytics/solutions/distance_calculation.py +9 -12
  94. ultralytics/solutions/heatmap.py +7 -13
  95. ultralytics/solutions/instance_segmentation.py +5 -8
  96. ultralytics/solutions/object_blurrer.py +7 -10
  97. ultralytics/solutions/object_counter.py +8 -12
  98. ultralytics/solutions/object_cropper.py +5 -8
  99. ultralytics/solutions/parking_management.py +12 -14
  100. ultralytics/solutions/queue_management.py +4 -6
  101. ultralytics/solutions/region_counter.py +7 -10
  102. ultralytics/solutions/security_alarm.py +14 -19
  103. ultralytics/solutions/similarity_search.py +7 -12
  104. ultralytics/solutions/solutions.py +31 -53
  105. ultralytics/solutions/speed_estimation.py +6 -9
  106. ultralytics/solutions/streamlit_inference.py +2 -4
  107. ultralytics/solutions/trackzone.py +7 -10
  108. ultralytics/solutions/vision_eye.py +5 -8
  109. ultralytics/trackers/basetrack.py +2 -4
  110. ultralytics/trackers/bot_sort.py +6 -11
  111. ultralytics/trackers/byte_tracker.py +10 -15
  112. ultralytics/trackers/track.py +3 -6
  113. ultralytics/trackers/utils/gmc.py +6 -12
  114. ultralytics/trackers/utils/kalman_filter.py +35 -43
  115. ultralytics/trackers/utils/matching.py +6 -10
  116. ultralytics/utils/__init__.py +61 -100
  117. ultralytics/utils/autobatch.py +2 -4
  118. ultralytics/utils/autodevice.py +11 -13
  119. ultralytics/utils/benchmarks.py +25 -35
  120. ultralytics/utils/callbacks/base.py +8 -10
  121. ultralytics/utils/callbacks/clearml.py +2 -4
  122. ultralytics/utils/callbacks/comet.py +30 -44
  123. ultralytics/utils/callbacks/dvc.py +13 -18
  124. ultralytics/utils/callbacks/mlflow.py +4 -5
  125. ultralytics/utils/callbacks/neptune.py +4 -6
  126. ultralytics/utils/callbacks/raytune.py +3 -4
  127. ultralytics/utils/callbacks/tensorboard.py +4 -6
  128. ultralytics/utils/callbacks/wb.py +10 -13
  129. ultralytics/utils/checks.py +29 -56
  130. ultralytics/utils/cpu.py +1 -2
  131. ultralytics/utils/dist.py +8 -12
  132. ultralytics/utils/downloads.py +17 -27
  133. ultralytics/utils/errors.py +6 -8
  134. ultralytics/utils/events.py +2 -4
  135. ultralytics/utils/export/__init__.py +4 -239
  136. ultralytics/utils/export/engine.py +237 -0
  137. ultralytics/utils/export/imx.py +11 -17
  138. ultralytics/utils/export/tensorflow.py +217 -0
  139. ultralytics/utils/files.py +10 -15
  140. ultralytics/utils/git.py +5 -7
  141. ultralytics/utils/instance.py +30 -51
  142. ultralytics/utils/logger.py +11 -15
  143. ultralytics/utils/loss.py +8 -14
  144. ultralytics/utils/metrics.py +98 -138
  145. ultralytics/utils/nms.py +13 -16
  146. ultralytics/utils/ops.py +47 -74
  147. ultralytics/utils/patches.py +11 -18
  148. ultralytics/utils/plotting.py +29 -42
  149. ultralytics/utils/tal.py +25 -39
  150. ultralytics/utils/torch_utils.py +45 -73
  151. ultralytics/utils/tqdm.py +6 -8
  152. ultralytics/utils/triton.py +9 -12
  153. ultralytics/utils/tuner.py +1 -2
  154. dgenerate_ultralytics_headless-8.3.222.dist-info/RECORD +0 -283
  155. {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/WHEEL +0 -0
  156. {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/entry_points.txt +0 -0
  157. {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/licenses/LICENSE +0 -0
  158. {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/top_level.txt +0 -0
@@ -24,8 +24,7 @@ __all__ = "OBB", "Classify", "Detect", "Pose", "RTDETRDecoder", "Segment", "YOLO
24
24
 
25
25
 
26
26
  class Detect(nn.Module):
27
- """
28
- YOLO Detect head for object detection models.
27
+ """YOLO Detect head for object detection models.
29
28
 
30
29
  This class implements the detection head used in YOLO models for predicting bounding boxes and class probabilities.
31
30
  It supports both training and inference modes, with optional end-to-end detection capabilities.
@@ -78,8 +77,7 @@ class Detect(nn.Module):
78
77
  xyxy = False # xyxy or xywh output
79
78
 
80
79
  def __init__(self, nc: int = 80, ch: tuple = ()):
81
- """
82
- Initialize the YOLO detection layer with specified number of classes and channels.
80
+ """Initialize the YOLO detection layer with specified number of classes and channels.
83
81
 
84
82
  Args:
85
83
  nc (int): Number of classes.
@@ -126,15 +124,14 @@ class Detect(nn.Module):
126
124
  return y if self.export else (y, x)
127
125
 
128
126
  def forward_end2end(self, x: list[torch.Tensor]) -> dict | tuple:
129
- """
130
- Perform forward pass of the v10Detect module.
127
+ """Perform forward pass of the v10Detect module.
131
128
 
132
129
  Args:
133
130
  x (list[torch.Tensor]): Input feature maps from different levels.
134
131
 
135
132
  Returns:
136
- outputs (dict | tuple): Training mode returns dict with one2many and one2one outputs.
137
- Inference mode returns processed detections or tuple with detections and raw outputs.
133
+ outputs (dict | tuple): Training mode returns dict with one2many and one2one outputs. Inference mode returns
134
+ processed detections or tuple with detections and raw outputs.
138
135
  """
139
136
  x_detach = [xi.detach() for xi in x]
140
137
  one2one = [
@@ -150,8 +147,7 @@ class Detect(nn.Module):
150
147
  return y if self.export else (y, {"one2many": x, "one2one": one2one})
151
148
 
152
149
  def _inference(self, x: list[torch.Tensor]) -> torch.Tensor:
153
- """
154
- Decode predicted bounding boxes and class probabilities based on multiple-level feature maps.
150
+ """Decode predicted bounding boxes and class probabilities based on multiple-level feature maps.
155
151
 
156
152
  Args:
157
153
  x (list[torch.Tensor]): List of feature maps from different detection layers.
@@ -166,22 +162,8 @@ class Detect(nn.Module):
166
162
  self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
167
163
  self.shape = shape
168
164
 
169
- if self.export and self.format in {"saved_model", "pb", "tflite", "edgetpu", "tfjs"}: # avoid TF FlexSplitV ops
170
- box = x_cat[:, : self.reg_max * 4]
171
- cls = x_cat[:, self.reg_max * 4 :]
172
- else:
173
- box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
174
-
175
- if self.export and self.format in {"tflite", "edgetpu"}:
176
- # Precompute normalization factor to increase numerical stability
177
- # See https://github.com/ultralytics/ultralytics/issues/7371
178
- grid_h = shape[2]
179
- grid_w = shape[3]
180
- grid_size = torch.tensor([grid_w, grid_h, grid_w, grid_h], device=box.device).reshape(1, 4, 1)
181
- norm = self.strides / (self.stride[0] * grid_size)
182
- dbox = self.decode_bboxes(self.dfl(box) * norm, self.anchors.unsqueeze(0) * norm[:, :2])
183
- else:
184
- dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides
165
+ box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
166
+ dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides
185
167
  return torch.cat((dbox, cls.sigmoid()), 1)
186
168
 
187
169
  def bias_init(self):
@@ -208,8 +190,7 @@ class Detect(nn.Module):
208
190
 
209
191
  @staticmethod
210
192
  def postprocess(preds: torch.Tensor, max_det: int, nc: int = 80) -> torch.Tensor:
211
- """
212
- Post-process YOLO model predictions.
193
+ """Post-process YOLO model predictions.
213
194
 
214
195
  Args:
215
196
  preds (torch.Tensor): Raw predictions with shape (batch_size, num_anchors, 4 + nc) with last dimension
@@ -232,8 +213,7 @@ class Detect(nn.Module):
232
213
 
233
214
 
234
215
  class Segment(Detect):
235
- """
236
- YOLO Segment head for segmentation models.
216
+ """YOLO Segment head for segmentation models.
237
217
 
238
218
  This class extends the Detect head to include mask prediction capabilities for instance segmentation tasks.
239
219
 
@@ -254,8 +234,7 @@ class Segment(Detect):
254
234
  """
255
235
 
256
236
  def __init__(self, nc: int = 80, nm: int = 32, npr: int = 256, ch: tuple = ()):
257
- """
258
- Initialize the YOLO model attributes such as the number of masks, prototypes, and the convolution layers.
237
+ """Initialize the YOLO model attributes such as the number of masks, prototypes, and the convolution layers.
259
238
 
260
239
  Args:
261
240
  nc (int): Number of classes.
@@ -284,8 +263,7 @@ class Segment(Detect):
284
263
 
285
264
 
286
265
  class OBB(Detect):
287
- """
288
- YOLO OBB detection head for detection with rotation models.
266
+ """YOLO OBB detection head for detection with rotation models.
289
267
 
290
268
  This class extends the Detect head to include oriented bounding box prediction with rotation angles.
291
269
 
@@ -306,8 +284,7 @@ class OBB(Detect):
306
284
  """
307
285
 
308
286
  def __init__(self, nc: int = 80, ne: int = 1, ch: tuple = ()):
309
- """
310
- Initialize OBB with number of classes `nc` and layer channels `ch`.
287
+ """Initialize OBB with number of classes `nc` and layer channels `ch`.
311
288
 
312
289
  Args:
313
290
  nc (int): Number of classes.
@@ -340,8 +317,7 @@ class OBB(Detect):
340
317
 
341
318
 
342
319
  class Pose(Detect):
343
- """
344
- YOLO Pose head for keypoints models.
320
+ """YOLO Pose head for keypoints models.
345
321
 
346
322
  This class extends the Detect head to include keypoint prediction capabilities for pose estimation tasks.
347
323
 
@@ -362,8 +338,7 @@ class Pose(Detect):
362
338
  """
363
339
 
364
340
  def __init__(self, nc: int = 80, kpt_shape: tuple = (17, 3), ch: tuple = ()):
365
- """
366
- Initialize YOLO network with default parameters and Convolutional Layers.
341
+ """Initialize YOLO network with default parameters and Convolutional Layers.
367
342
 
368
343
  Args:
369
344
  nc (int): Number of classes.
@@ -391,20 +366,9 @@ class Pose(Detect):
391
366
  """Decode keypoints from predictions."""
392
367
  ndim = self.kpt_shape[1]
393
368
  if self.export:
394
- if self.format in {
395
- "tflite",
396
- "edgetpu",
397
- }: # required for TFLite export to avoid 'PLACEHOLDER_FOR_GREATER_OP_CODES' bug
398
- # Precompute normalization factor to increase numerical stability
399
- y = kpts.view(bs, *self.kpt_shape, -1)
400
- grid_h, grid_w = self.shape[2], self.shape[3]
401
- grid_size = torch.tensor([grid_w, grid_h], device=y.device).reshape(1, 2, 1)
402
- norm = self.strides / (self.stride[0] * grid_size)
403
- a = (y[:, :, :2] * 2.0 + (self.anchors - 0.5)) * norm
404
- else:
405
- # NCNN fix
406
- y = kpts.view(bs, *self.kpt_shape, -1)
407
- a = (y[:, :, :2] * 2.0 + (self.anchors - 0.5)) * self.strides
369
+ # NCNN fix
370
+ y = kpts.view(bs, *self.kpt_shape, -1)
371
+ a = (y[:, :, :2] * 2.0 + (self.anchors - 0.5)) * self.strides
408
372
  if ndim == 3:
409
373
  a = torch.cat((a, y[:, :, 2:3].sigmoid()), 2)
410
374
  return a.view(bs, self.nk, -1)
@@ -421,8 +385,7 @@ class Pose(Detect):
421
385
 
422
386
 
423
387
  class Classify(nn.Module):
424
- """
425
- YOLO classification head, i.e. x(b,c1,20,20) to x(b,c2).
388
+ """YOLO classification head, i.e. x(b,c1,20,20) to x(b,c2).
426
389
 
427
390
  This class implements a classification head that transforms feature maps into class predictions.
428
391
 
@@ -446,8 +409,7 @@ class Classify(nn.Module):
446
409
  export = False # export mode
447
410
 
448
411
  def __init__(self, c1: int, c2: int, k: int = 1, s: int = 1, p: int | None = None, g: int = 1):
449
- """
450
- Initialize YOLO classification head to transform input tensor from (b,c1,20,20) to (b,c2) shape.
412
+ """Initialize YOLO classification head to transform input tensor from (b,c1,20,20) to (b,c2) shape.
451
413
 
452
414
  Args:
453
415
  c1 (int): Number of input channels.
@@ -476,11 +438,10 @@ class Classify(nn.Module):
476
438
 
477
439
 
478
440
  class WorldDetect(Detect):
479
- """
480
- Head for integrating YOLO detection models with semantic understanding from text embeddings.
441
+ """Head for integrating YOLO detection models with semantic understanding from text embeddings.
481
442
 
482
- This class extends the standard Detect head to incorporate text embeddings for enhanced semantic understanding
483
- in object detection tasks.
443
+ This class extends the standard Detect head to incorporate text embeddings for enhanced semantic understanding in
444
+ object detection tasks.
484
445
 
485
446
  Attributes:
486
447
  cv3 (nn.ModuleList): Convolution layers for embedding features.
@@ -499,8 +460,7 @@ class WorldDetect(Detect):
499
460
  """
500
461
 
501
462
  def __init__(self, nc: int = 80, embed: int = 512, with_bn: bool = False, ch: tuple = ()):
502
- """
503
- Initialize YOLO detection layer with nc classes and layer channels ch.
463
+ """Initialize YOLO detection layer with nc classes and layer channels ch.
504
464
 
505
465
  Args:
506
466
  nc (int): Number of classes.
@@ -534,11 +494,10 @@ class WorldDetect(Detect):
534
494
 
535
495
 
536
496
  class LRPCHead(nn.Module):
537
- """
538
- Lightweight Region Proposal and Classification Head for efficient object detection.
497
+ """Lightweight Region Proposal and Classification Head for efficient object detection.
539
498
 
540
- This head combines region proposal filtering with classification to enable efficient detection with
541
- dynamic vocabulary support.
499
+ This head combines region proposal filtering with classification to enable efficient detection with dynamic
500
+ vocabulary support.
542
501
 
543
502
  Attributes:
544
503
  vocab (nn.Module): Vocabulary/classification layer.
@@ -559,8 +518,7 @@ class LRPCHead(nn.Module):
559
518
  """
560
519
 
561
520
  def __init__(self, vocab: nn.Module, pf: nn.Module, loc: nn.Module, enabled: bool = True):
562
- """
563
- Initialize LRPCHead with vocabulary, proposal filter, and localization components.
521
+ """Initialize LRPCHead with vocabulary, proposal filter, and localization components.
564
522
 
565
523
  Args:
566
524
  vocab (nn.Module): Vocabulary/classification module.
@@ -599,8 +557,7 @@ class LRPCHead(nn.Module):
599
557
 
600
558
 
601
559
  class YOLOEDetect(Detect):
602
- """
603
- Head for integrating YOLO detection models with semantic understanding from text embeddings.
560
+ """Head for integrating YOLO detection models with semantic understanding from text embeddings.
604
561
 
605
562
  This class extends the standard Detect head to support text-guided detection with enhanced semantic understanding
606
563
  through text embeddings and visual prompt embeddings.
@@ -632,8 +589,7 @@ class YOLOEDetect(Detect):
632
589
  is_fused = False
633
590
 
634
591
  def __init__(self, nc: int = 80, embed: int = 512, with_bn: bool = False, ch: tuple = ()):
635
- """
636
- Initialize YOLO detection layer with nc classes and layer channels ch.
592
+ """Initialize YOLO detection layer with nc classes and layer channels ch.
637
593
 
638
594
  Args:
639
595
  nc (int): Number of classes.
@@ -787,11 +743,10 @@ class YOLOEDetect(Detect):
787
743
 
788
744
 
789
745
  class YOLOESegment(YOLOEDetect):
790
- """
791
- YOLO segmentation head with text embedding capabilities.
746
+ """YOLO segmentation head with text embedding capabilities.
792
747
 
793
- This class extends YOLOEDetect to include mask prediction capabilities for instance segmentation tasks
794
- with text-guided semantic understanding.
748
+ This class extends YOLOEDetect to include mask prediction capabilities for instance segmentation tasks with
749
+ text-guided semantic understanding.
795
750
 
796
751
  Attributes:
797
752
  nm (int): Number of masks.
@@ -813,8 +768,7 @@ class YOLOESegment(YOLOEDetect):
813
768
  def __init__(
814
769
  self, nc: int = 80, nm: int = 32, npr: int = 256, embed: int = 512, with_bn: bool = False, ch: tuple = ()
815
770
  ):
816
- """
817
- Initialize YOLOESegment with class count, mask parameters, and embedding dimensions.
771
+ """Initialize YOLOESegment with class count, mask parameters, and embedding dimensions.
818
772
 
819
773
  Args:
820
774
  nc (int): Number of classes.
@@ -855,8 +809,7 @@ class YOLOESegment(YOLOEDetect):
855
809
 
856
810
 
857
811
  class RTDETRDecoder(nn.Module):
858
- """
859
- Real-Time Deformable Transformer Decoder (RTDETRDecoder) module for object detection.
812
+ """Real-Time Deformable Transformer Decoder (RTDETRDecoder) module for object detection.
860
813
 
861
814
  This decoder module utilizes Transformer architecture along with deformable convolutions to predict bounding boxes
862
815
  and class labels for objects in an image. It integrates features from multiple layers and runs through a series of
@@ -920,8 +873,7 @@ class RTDETRDecoder(nn.Module):
920
873
  box_noise_scale: float = 1.0,
921
874
  learnt_init_query: bool = False,
922
875
  ):
923
- """
924
- Initialize the RTDETRDecoder module with the given parameters.
876
+ """Initialize the RTDETRDecoder module with the given parameters.
925
877
 
926
878
  Args:
927
879
  nc (int): Number of classes.
@@ -981,8 +933,7 @@ class RTDETRDecoder(nn.Module):
981
933
  self._reset_parameters()
982
934
 
983
935
  def forward(self, x: list[torch.Tensor], batch: dict | None = None) -> tuple | torch.Tensor:
984
- """
985
- Run the forward pass of the module, returning bounding box and classification scores for the input.
936
+ """Run the forward pass of the module, returning bounding box and classification scores for the input.
986
937
 
987
938
  Args:
988
939
  x (list[torch.Tensor]): List of feature maps from the backbone.
@@ -1038,8 +989,7 @@ class RTDETRDecoder(nn.Module):
1038
989
  device: str = "cpu",
1039
990
  eps: float = 1e-2,
1040
991
  ) -> tuple[torch.Tensor, torch.Tensor]:
1041
- """
1042
- Generate anchor bounding boxes for given shapes with specific grid size and validate them.
992
+ """Generate anchor bounding boxes for given shapes with specific grid size and validate them.
1043
993
 
1044
994
  Args:
1045
995
  shapes (list): List of feature map shapes.
@@ -1071,8 +1021,7 @@ class RTDETRDecoder(nn.Module):
1071
1021
  return anchors, valid_mask
1072
1022
 
1073
1023
  def _get_encoder_input(self, x: list[torch.Tensor]) -> tuple[torch.Tensor, list[list[int]]]:
1074
- """
1075
- Process and return encoder inputs by getting projection features from input and concatenating them.
1024
+ """Process and return encoder inputs by getting projection features from input and concatenating them.
1076
1025
 
1077
1026
  Args:
1078
1027
  x (list[torch.Tensor]): List of feature maps from the backbone.
@@ -1104,8 +1053,7 @@ class RTDETRDecoder(nn.Module):
1104
1053
  dn_embed: torch.Tensor | None = None,
1105
1054
  dn_bbox: torch.Tensor | None = None,
1106
1055
  ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
1107
- """
1108
- Generate and prepare the input required for the decoder from the provided features and shapes.
1056
+ """Generate and prepare the input required for the decoder from the provided features and shapes.
1109
1057
 
1110
1058
  Args:
1111
1059
  feats (torch.Tensor): Processed features from encoder.
@@ -1183,11 +1131,10 @@ class RTDETRDecoder(nn.Module):
1183
1131
 
1184
1132
 
1185
1133
  class v10Detect(Detect):
1186
- """
1187
- v10 Detection head from https://arxiv.org/pdf/2405.14458.
1134
+ """v10 Detection head from https://arxiv.org/pdf/2405.14458.
1188
1135
 
1189
- This class implements the YOLOv10 detection head with dual-assignment training and consistent dual predictions
1190
- for improved efficiency and performance.
1136
+ This class implements the YOLOv10 detection head with dual-assignment training and consistent dual predictions for
1137
+ improved efficiency and performance.
1191
1138
 
1192
1139
  Attributes:
1193
1140
  end2end (bool): End-to-end detection mode.
@@ -1211,8 +1158,7 @@ class v10Detect(Detect):
1211
1158
  end2end = True
1212
1159
 
1213
1160
  def __init__(self, nc: int = 80, ch: tuple = ()):
1214
- """
1215
- Initialize the v10Detect object with the specified number of classes and input channels.
1161
+ """Initialize the v10Detect object with the specified number of classes and input channels.
1216
1162
 
1217
1163
  Args:
1218
1164
  nc (int): Number of classes.