dgenerate-ultralytics-headless 8.3.214__py3-none-any.whl → 8.3.248__py3-none-any.whl

This diff shows the content changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (236)
  1. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/METADATA +13 -14
  2. dgenerate_ultralytics_headless-8.3.248.dist-info/RECORD +298 -0
  3. tests/__init__.py +5 -7
  4. tests/conftest.py +8 -15
  5. tests/test_cli.py +1 -1
  6. tests/test_cuda.py +5 -8
  7. tests/test_engine.py +1 -1
  8. tests/test_exports.py +57 -12
  9. tests/test_integrations.py +4 -4
  10. tests/test_python.py +84 -53
  11. tests/test_solutions.py +160 -151
  12. ultralytics/__init__.py +1 -1
  13. ultralytics/cfg/__init__.py +56 -62
  14. ultralytics/cfg/datasets/Argoverse.yaml +7 -6
  15. ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
  16. ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
  17. ultralytics/cfg/datasets/ImageNet.yaml +1 -1
  18. ultralytics/cfg/datasets/VOC.yaml +15 -16
  19. ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
  20. ultralytics/cfg/datasets/coco-pose.yaml +21 -0
  21. ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
  22. ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
  23. ultralytics/cfg/datasets/dog-pose.yaml +28 -0
  24. ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
  25. ultralytics/cfg/datasets/dota8.yaml +2 -2
  26. ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
  27. ultralytics/cfg/datasets/kitti.yaml +27 -0
  28. ultralytics/cfg/datasets/lvis.yaml +5 -5
  29. ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
  30. ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
  31. ultralytics/cfg/datasets/xView.yaml +16 -16
  32. ultralytics/cfg/default.yaml +1 -1
  33. ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
  34. ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
  35. ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
  36. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
  37. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
  38. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
  39. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
  40. ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
  41. ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
  42. ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
  43. ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
  44. ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
  45. ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
  46. ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
  47. ultralytics/cfg/models/v6/yolov6.yaml +1 -1
  48. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
  49. ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
  50. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
  51. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
  52. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
  53. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
  54. ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
  55. ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
  56. ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
  57. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
  58. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
  59. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
  60. ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
  61. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
  62. ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
  63. ultralytics/data/__init__.py +4 -4
  64. ultralytics/data/annotator.py +3 -4
  65. ultralytics/data/augment.py +285 -475
  66. ultralytics/data/base.py +18 -26
  67. ultralytics/data/build.py +147 -25
  68. ultralytics/data/converter.py +36 -46
  69. ultralytics/data/dataset.py +46 -74
  70. ultralytics/data/loaders.py +42 -49
  71. ultralytics/data/split.py +5 -6
  72. ultralytics/data/split_dota.py +8 -15
  73. ultralytics/data/utils.py +34 -43
  74. ultralytics/engine/exporter.py +319 -237
  75. ultralytics/engine/model.py +148 -188
  76. ultralytics/engine/predictor.py +29 -38
  77. ultralytics/engine/results.py +177 -311
  78. ultralytics/engine/trainer.py +83 -59
  79. ultralytics/engine/tuner.py +23 -34
  80. ultralytics/engine/validator.py +39 -22
  81. ultralytics/hub/__init__.py +16 -19
  82. ultralytics/hub/auth.py +6 -12
  83. ultralytics/hub/google/__init__.py +7 -10
  84. ultralytics/hub/session.py +15 -25
  85. ultralytics/hub/utils.py +5 -8
  86. ultralytics/models/__init__.py +1 -1
  87. ultralytics/models/fastsam/__init__.py +1 -1
  88. ultralytics/models/fastsam/model.py +8 -10
  89. ultralytics/models/fastsam/predict.py +17 -29
  90. ultralytics/models/fastsam/utils.py +1 -2
  91. ultralytics/models/fastsam/val.py +5 -7
  92. ultralytics/models/nas/__init__.py +1 -1
  93. ultralytics/models/nas/model.py +5 -8
  94. ultralytics/models/nas/predict.py +7 -9
  95. ultralytics/models/nas/val.py +1 -2
  96. ultralytics/models/rtdetr/__init__.py +1 -1
  97. ultralytics/models/rtdetr/model.py +5 -8
  98. ultralytics/models/rtdetr/predict.py +15 -19
  99. ultralytics/models/rtdetr/train.py +10 -13
  100. ultralytics/models/rtdetr/val.py +21 -23
  101. ultralytics/models/sam/__init__.py +15 -2
  102. ultralytics/models/sam/amg.py +14 -20
  103. ultralytics/models/sam/build.py +26 -19
  104. ultralytics/models/sam/build_sam3.py +377 -0
  105. ultralytics/models/sam/model.py +29 -32
  106. ultralytics/models/sam/modules/blocks.py +83 -144
  107. ultralytics/models/sam/modules/decoders.py +19 -37
  108. ultralytics/models/sam/modules/encoders.py +44 -101
  109. ultralytics/models/sam/modules/memory_attention.py +16 -30
  110. ultralytics/models/sam/modules/sam.py +200 -73
  111. ultralytics/models/sam/modules/tiny_encoder.py +64 -83
  112. ultralytics/models/sam/modules/transformer.py +18 -28
  113. ultralytics/models/sam/modules/utils.py +174 -50
  114. ultralytics/models/sam/predict.py +2248 -350
  115. ultralytics/models/sam/sam3/__init__.py +3 -0
  116. ultralytics/models/sam/sam3/decoder.py +546 -0
  117. ultralytics/models/sam/sam3/encoder.py +529 -0
  118. ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
  119. ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
  120. ultralytics/models/sam/sam3/model_misc.py +199 -0
  121. ultralytics/models/sam/sam3/necks.py +129 -0
  122. ultralytics/models/sam/sam3/sam3_image.py +339 -0
  123. ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
  124. ultralytics/models/sam/sam3/vitdet.py +547 -0
  125. ultralytics/models/sam/sam3/vl_combiner.py +160 -0
  126. ultralytics/models/utils/loss.py +14 -26
  127. ultralytics/models/utils/ops.py +13 -17
  128. ultralytics/models/yolo/__init__.py +1 -1
  129. ultralytics/models/yolo/classify/predict.py +9 -12
  130. ultralytics/models/yolo/classify/train.py +11 -32
  131. ultralytics/models/yolo/classify/val.py +29 -28
  132. ultralytics/models/yolo/detect/predict.py +7 -10
  133. ultralytics/models/yolo/detect/train.py +11 -20
  134. ultralytics/models/yolo/detect/val.py +70 -58
  135. ultralytics/models/yolo/model.py +36 -53
  136. ultralytics/models/yolo/obb/predict.py +5 -14
  137. ultralytics/models/yolo/obb/train.py +11 -14
  138. ultralytics/models/yolo/obb/val.py +39 -36
  139. ultralytics/models/yolo/pose/__init__.py +1 -1
  140. ultralytics/models/yolo/pose/predict.py +6 -21
  141. ultralytics/models/yolo/pose/train.py +10 -15
  142. ultralytics/models/yolo/pose/val.py +38 -57
  143. ultralytics/models/yolo/segment/predict.py +14 -18
  144. ultralytics/models/yolo/segment/train.py +3 -6
  145. ultralytics/models/yolo/segment/val.py +93 -45
  146. ultralytics/models/yolo/world/train.py +8 -14
  147. ultralytics/models/yolo/world/train_world.py +11 -34
  148. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  149. ultralytics/models/yolo/yoloe/predict.py +16 -23
  150. ultralytics/models/yolo/yoloe/train.py +30 -43
  151. ultralytics/models/yolo/yoloe/train_seg.py +5 -10
  152. ultralytics/models/yolo/yoloe/val.py +15 -20
  153. ultralytics/nn/__init__.py +7 -7
  154. ultralytics/nn/autobackend.py +145 -77
  155. ultralytics/nn/modules/__init__.py +60 -60
  156. ultralytics/nn/modules/activation.py +4 -6
  157. ultralytics/nn/modules/block.py +132 -216
  158. ultralytics/nn/modules/conv.py +52 -97
  159. ultralytics/nn/modules/head.py +50 -103
  160. ultralytics/nn/modules/transformer.py +76 -88
  161. ultralytics/nn/modules/utils.py +16 -21
  162. ultralytics/nn/tasks.py +94 -154
  163. ultralytics/nn/text_model.py +40 -67
  164. ultralytics/solutions/__init__.py +12 -12
  165. ultralytics/solutions/ai_gym.py +11 -17
  166. ultralytics/solutions/analytics.py +15 -16
  167. ultralytics/solutions/config.py +5 -6
  168. ultralytics/solutions/distance_calculation.py +10 -13
  169. ultralytics/solutions/heatmap.py +7 -13
  170. ultralytics/solutions/instance_segmentation.py +5 -8
  171. ultralytics/solutions/object_blurrer.py +7 -10
  172. ultralytics/solutions/object_counter.py +12 -19
  173. ultralytics/solutions/object_cropper.py +8 -14
  174. ultralytics/solutions/parking_management.py +33 -31
  175. ultralytics/solutions/queue_management.py +10 -12
  176. ultralytics/solutions/region_counter.py +9 -12
  177. ultralytics/solutions/security_alarm.py +15 -20
  178. ultralytics/solutions/similarity_search.py +10 -15
  179. ultralytics/solutions/solutions.py +75 -74
  180. ultralytics/solutions/speed_estimation.py +7 -10
  181. ultralytics/solutions/streamlit_inference.py +2 -4
  182. ultralytics/solutions/templates/similarity-search.html +7 -18
  183. ultralytics/solutions/trackzone.py +7 -10
  184. ultralytics/solutions/vision_eye.py +5 -8
  185. ultralytics/trackers/__init__.py +1 -1
  186. ultralytics/trackers/basetrack.py +3 -5
  187. ultralytics/trackers/bot_sort.py +10 -27
  188. ultralytics/trackers/byte_tracker.py +14 -30
  189. ultralytics/trackers/track.py +3 -6
  190. ultralytics/trackers/utils/gmc.py +11 -22
  191. ultralytics/trackers/utils/kalman_filter.py +37 -48
  192. ultralytics/trackers/utils/matching.py +12 -15
  193. ultralytics/utils/__init__.py +116 -116
  194. ultralytics/utils/autobatch.py +2 -4
  195. ultralytics/utils/autodevice.py +17 -18
  196. ultralytics/utils/benchmarks.py +32 -46
  197. ultralytics/utils/callbacks/base.py +8 -10
  198. ultralytics/utils/callbacks/clearml.py +5 -13
  199. ultralytics/utils/callbacks/comet.py +32 -46
  200. ultralytics/utils/callbacks/dvc.py +13 -18
  201. ultralytics/utils/callbacks/mlflow.py +4 -5
  202. ultralytics/utils/callbacks/neptune.py +7 -15
  203. ultralytics/utils/callbacks/platform.py +314 -38
  204. ultralytics/utils/callbacks/raytune.py +3 -4
  205. ultralytics/utils/callbacks/tensorboard.py +23 -31
  206. ultralytics/utils/callbacks/wb.py +10 -13
  207. ultralytics/utils/checks.py +99 -76
  208. ultralytics/utils/cpu.py +3 -8
  209. ultralytics/utils/dist.py +8 -12
  210. ultralytics/utils/downloads.py +20 -30
  211. ultralytics/utils/errors.py +6 -14
  212. ultralytics/utils/events.py +2 -4
  213. ultralytics/utils/export/__init__.py +4 -236
  214. ultralytics/utils/export/engine.py +237 -0
  215. ultralytics/utils/export/imx.py +91 -55
  216. ultralytics/utils/export/tensorflow.py +231 -0
  217. ultralytics/utils/files.py +24 -28
  218. ultralytics/utils/git.py +9 -11
  219. ultralytics/utils/instance.py +30 -51
  220. ultralytics/utils/logger.py +212 -114
  221. ultralytics/utils/loss.py +14 -22
  222. ultralytics/utils/metrics.py +126 -155
  223. ultralytics/utils/nms.py +13 -16
  224. ultralytics/utils/ops.py +107 -165
  225. ultralytics/utils/patches.py +33 -21
  226. ultralytics/utils/plotting.py +72 -80
  227. ultralytics/utils/tal.py +25 -39
  228. ultralytics/utils/torch_utils.py +52 -78
  229. ultralytics/utils/tqdm.py +20 -20
  230. ultralytics/utils/triton.py +13 -19
  231. ultralytics/utils/tuner.py +17 -5
  232. dgenerate_ultralytics_headless-8.3.214.dist-info/RECORD +0 -283
  233. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/WHEEL +0 -0
  234. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/entry_points.txt +0 -0
  235. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/licenses/LICENSE +0 -0
  236. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/top_level.txt +0 -0
ultralytics/models/rtdetr/model.py

@@ -19,11 +19,10 @@ from .val import RTDETRValidator
 
 
 class RTDETR(Model):
-    """
-    Interface for Baidu's RT-DETR model, a Vision Transformer-based real-time object detector.
+    """Interface for Baidu's RT-DETR model, a Vision Transformer-based real-time object detector.
 
-    This model provides real-time performance with high accuracy. It supports efficient hybrid encoding, IoU-aware
-    query selection, and adaptable inference speed.
+    This model provides real-time performance with high accuracy. It supports efficient hybrid encoding, IoU-aware query
+    selection, and adaptable inference speed.
 
     Attributes:
         model (str): Path to the pre-trained model.
@@ -39,8 +38,7 @@ class RTDETR(Model):
     """
 
     def __init__(self, model: str = "rtdetr-l.pt") -> None:
-        """
-        Initialize the RT-DETR model with the given pre-trained model file.
+        """Initialize the RT-DETR model with the given pre-trained model file.
 
         Args:
             model (str): Path to the pre-trained model. Supports .pt, .yaml, and .yml formats.
@@ -50,8 +48,7 @@ class RTDETR(Model):
 
     @property
     def task_map(self) -> dict:
-        """
-        Return a task map for RT-DETR, associating tasks with corresponding Ultralytics classes.
+        """Return a task map for RT-DETR, associating tasks with corresponding Ultralytics classes.
 
         Returns:
             (dict): A dictionary mapping task names to Ultralytics task classes for the RT-DETR model.
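
The RTDETR class above is the user-facing entry point. As a minimal usage sketch (assuming the top-level ultralytics export and the rtdetr-l.pt weights named in the diff; the input image path is hypothetical):

    from ultralytics import RTDETR

    model = RTDETR("rtdetr-l.pt")       # load the default pre-trained RT-DETR Large weights
    results = model.predict("bus.jpg")  # run real-time detection on a single image
    results[0].show()                   # display boxes, confidence scores, and class labels
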
ultralytics/models/rtdetr/predict.py

@@ -9,11 +9,10 @@ from ultralytics.utils import ops
 
 
 class RTDETRPredictor(BasePredictor):
-    """
-    RT-DETR (Real-Time Detection Transformer) Predictor extending the BasePredictor class for making predictions.
+    """RT-DETR (Real-Time Detection Transformer) Predictor extending the BasePredictor class for making predictions.
 
-    This class leverages Vision Transformers to provide real-time object detection while maintaining high accuracy.
-    It supports key features like efficient hybrid encoding and IoU-aware query selection.
+    This class leverages Vision Transformers to provide real-time object detection while maintaining high accuracy. It
+    supports key features like efficient hybrid encoding and IoU-aware query selection.
 
     Attributes:
         imgsz (int): Image size for inference (must be square and scale-filled).
@@ -34,21 +33,20 @@ class RTDETRPredictor(BasePredictor):
     """
 
     def postprocess(self, preds, img, orig_imgs):
-        """
-        Postprocess the raw predictions from the model to generate bounding boxes and confidence scores.
+        """Postprocess the raw predictions from the model to generate bounding boxes and confidence scores.
 
-        The method filters detections based on confidence and class if specified in `self.args`. It converts
-        model predictions to Results objects containing properly scaled bounding boxes.
+        The method filters detections based on confidence and class if specified in `self.args`. It converts model
+        predictions to Results objects containing properly scaled bounding boxes.
 
         Args:
-            preds (list | tuple): List of [predictions, extra] from the model, where predictions contain
-                bounding boxes and scores.
+            preds (list | tuple): List of [predictions, extra] from the model, where predictions contain bounding boxes
+                and scores.
             img (torch.Tensor): Processed input images with shape (N, 3, H, W).
             orig_imgs (list | torch.Tensor): Original, unprocessed images.
 
         Returns:
-            results (list[Results]): A list of Results objects containing the post-processed bounding boxes,
-                confidence scores, and class labels.
+            results (list[Results]): A list of Results objects containing the post-processed bounding boxes, confidence
+                scores, and class labels.
         """
         if not isinstance(preds, (list, tuple)):  # list for PyTorch inference but list[0] Tensor for export inference
             preds = [preds, None]
@@ -57,7 +55,7 @@ class RTDETRPredictor(BasePredictor):
         bboxes, scores = preds[0].split((4, nd - 4), dim=-1)
 
         if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
-            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
+            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)[..., ::-1]
 
         results = []
         for bbox, score, orig_img, img_path in zip(bboxes, scores, orig_imgs, self.batch[0]):  # (300, 4)
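
The [..., ::-1] appended to convert_torch2numpy_batch reverses the last (channel) axis of the NumPy batch, i.e. an RGB-to-BGR flip, presumably so downstream OpenCV-based plotting receives the channel order it expects. A standalone NumPy illustration of the indexing (pixel values are made up):

    import numpy as np

    batch = np.zeros((1, 2, 2, 3), dtype=np.uint8)  # (N, H, W, C) batch
    batch[..., 0] = 255                             # pure red in RGB order
    flipped = batch[..., ::-1]                      # reverse only the channel axis
    print(flipped[0, 0, 0])                         # [  0   0 255] -> red in BGR order
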
@@ -75,15 +73,13 @@ class RTDETRPredictor(BasePredictor):
         return results
 
     def pre_transform(self, im):
-        """
-        Pre-transform input images before feeding them into the model for inference.
+        """Pre-transform input images before feeding them into the model for inference.
 
-        The input images are letterboxed to ensure a square aspect ratio and scale-filled. The size must be square
-        (640) and scale_filled.
+        The input images are letterboxed to ensure a square aspect ratio and scale-filled.
 
         Args:
-            im (list[np.ndarray] | torch.Tensor): Input images of shape (N, 3, H, W) for tensor,
-                [(H, W, 3) x N] for list.
+            im (list[np.ndarray] | torch.Tensor): Input images of shape (N, 3, H, W) for tensor, [(H, W, 3) x N] for
+                list.
 
         Returns:
             (list): List of pre-transformed images ready for model inference.
ultralytics/models/rtdetr/train.py

@@ -12,12 +12,11 @@ from .val import RTDETRDataset, RTDETRValidator
 
 
 class RTDETRTrainer(DetectionTrainer):
-    """
-    Trainer class for the RT-DETR model developed by Baidu for real-time object detection.
+    """Trainer class for the RT-DETR model developed by Baidu for real-time object detection.
 
-    This class extends the DetectionTrainer class for YOLO to adapt to the specific features and architecture of RT-DETR.
-    The model leverages Vision Transformers and has capabilities like IoU-aware query selection and adaptable inference
-    speed.
+    This class extends the DetectionTrainer class for YOLO to adapt to the specific features and architecture of
+    RT-DETR. The model leverages Vision Transformers and has capabilities like IoU-aware query selection and adaptable
+    inference speed.
 
     Attributes:
         loss_names (tuple): Names of the loss components used for training.
@@ -31,20 +30,19 @@ class RTDETRTrainer(DetectionTrainer):
         build_dataset: Build and return an RT-DETR dataset for training or validation.
         get_validator: Return a DetectionValidator suitable for RT-DETR model validation.
 
-    Notes:
-        - F.grid_sample used in RT-DETR does not support the `deterministic=True` argument.
-        - AMP training can lead to NaN outputs and may produce errors during bipartite graph matching.
-
     Examples:
         >>> from ultralytics.models.rtdetr.train import RTDETRTrainer
         >>> args = dict(model="rtdetr-l.yaml", data="coco8.yaml", imgsz=640, epochs=3)
         >>> trainer = RTDETRTrainer(overrides=args)
         >>> trainer.train()
+
+    Notes:
+        - F.grid_sample used in RT-DETR does not support the `deterministic=True` argument.
+        - AMP training can lead to NaN outputs and may produce errors during bipartite graph matching.
     """
 
     def get_model(self, cfg: dict | None = None, weights: str | None = None, verbose: bool = True):
-        """
-        Initialize and return an RT-DETR model for object detection tasks.
+        """Initialize and return an RT-DETR model for object detection tasks.
 
         Args:
             cfg (dict, optional): Model configuration.
@@ -60,8 +58,7 @@ class RTDETRTrainer(DetectionTrainer):
         return model
 
    def build_dataset(self, img_path: str, mode: str = "val", batch: int | None = None):
-        """
-        Build and return an RT-DETR dataset for training or validation.
+        """Build and return an RT-DETR dataset for training or validation.
 
         Args:
             img_path (str): Path to the folder containing images.
ultralytics/models/rtdetr/val.py

@@ -16,8 +16,7 @@ __all__ = ("RTDETRValidator",)  # tuple or list
 
 
 class RTDETRDataset(YOLODataset):
-    """
-    Real-Time DEtection and TRacking (RT-DETR) dataset class extending the base YOLODataset class.
+    """Real-Time DEtection and TRacking (RT-DETR) dataset class extending the base YOLODataset class.
 
     This specialized dataset class is designed for use with the RT-DETR object detection model and is optimized for
     real-time detection and tracking tasks.
@@ -36,12 +35,11 @@ class RTDETRDataset(YOLODataset):
     Examples:
         Initialize an RT-DETR dataset
         >>> dataset = RTDETRDataset(img_path="path/to/images", imgsz=640)
-        >>> image, hw = dataset.load_image(0)
+        >>> image, hw0, hw = dataset.load_image(0)
     """
 
     def __init__(self, *args, data=None, **kwargs):
-        """
-        Initialize the RTDETRDataset class by inheriting from the YOLODataset class.
+        """Initialize the RTDETRDataset class by inheriting from the YOLODataset class.
 
         This constructor sets up a dataset specifically optimized for the RT-DETR (Real-Time DEtection and TRacking)
         model, building upon the base YOLODataset functionality.
@@ -54,27 +52,26 @@ class RTDETRDataset(YOLODataset):
         super().__init__(*args, data=data, **kwargs)
 
     def load_image(self, i, rect_mode=False):
-        """
-        Load one image from dataset index 'i'.
+        """Load one image from dataset index 'i'.
 
         Args:
             i (int): Index of the image to load.
             rect_mode (bool, optional): Whether to use rectangular mode for batch inference.
 
         Returns:
-            im (torch.Tensor): The loaded image.
-            resized_hw (tuple): Height and width of the resized image with shape (2,).
+            im (np.ndarray): Loaded image as a NumPy array.
+            hw_original (tuple[int, int]): Original image dimensions in (height, width) format.
+            hw_resized (tuple[int, int]): Resized image dimensions in (height, width) format.
 
         Examples:
            Load an image from the dataset
            >>> dataset = RTDETRDataset(img_path="path/to/images")
-            >>> image, hw = dataset.load_image(0)
+            >>> image, hw0, hw = dataset.load_image(0)
         """
         return super().load_image(i=i, rect_mode=rect_mode)
 
     def build_transforms(self, hyp=None):
-        """
-        Build transformation pipeline for the dataset.
+        """Build transformation pipeline for the dataset.
 
         Args:
             hyp (dict, optional): Hyperparameters for transformations.
@@ -105,8 +102,7 @@ class RTDETRDataset(YOLODataset):
 
 
 class RTDETRValidator(DetectionValidator):
-    """
-    RTDETRValidator extends the DetectionValidator class to provide validation capabilities specifically tailored for
+    """RTDETRValidator extends the DetectionValidator class to provide validation capabilities specifically tailored for
     the RT-DETR (Real-Time DETR) object detection model.
 
     The class allows building of an RTDETR-specific dataset for validation, applies Non-maximum suppression for
@@ -132,8 +128,7 @@ class RTDETRValidator(DetectionValidator):
     """
 
     def build_dataset(self, img_path, mode="val", batch=None):
-        """
-        Build an RTDETR Dataset.
+        """Build an RTDETR Dataset.
 
         Args:
             img_path (str): Path to the folder containing images.
@@ -156,15 +151,19 @@ class RTDETRValidator(DetectionValidator):
             data=self.data,
         )
 
+    def scale_preds(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> dict[str, torch.Tensor]:
+        """Scales predictions to the original image size."""
+        return predn
+
     def postprocess(
         self, preds: torch.Tensor | list[torch.Tensor] | tuple[torch.Tensor]
     ) -> list[dict[str, torch.Tensor]]:
-        """
-        Apply Non-maximum suppression to prediction outputs.
+        """Apply Non-maximum suppression to prediction outputs.
 
         Args:
             preds (torch.Tensor | list | tuple): Raw predictions from the model. If tensor, should have shape
-                (batch_size, num_predictions, num_classes + 4) where last dimension contains bbox coords and class scores.
+                (batch_size, num_predictions, num_classes + 4) where last dimension contains bbox coords and
+                class scores.
 
         Returns:
             (list[dict[str, torch.Tensor]]): List of dictionaries for each image, each containing:
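
The new scale_preds override returning predn unchanged is consistent with RT-DETR emitting boxes already in original-image coordinates: the predictor's postprocess multiplies its normalized xywh outputs by the original shape, leaving the validator nothing to rescale. A hedged sketch of that normalized-to-pixel conversion (variable names are illustrative, not copied from the source):

    import torch

    bboxes = torch.tensor([[0.5, 0.5, 0.2, 0.4]])  # normalized (x, y, w, h)
    oh, ow = 480, 640                              # original image height and width
    bboxes[:, [0, 2]] *= ow                        # x and width to pixel units
    bboxes[:, [1, 3]] *= oh                        # y and height to pixel units
    print(bboxes)                                  # tensor([[320., 240., 128., 192.]])
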
@@ -190,12 +189,11 @@ class RTDETRValidator(DetectionValidator):
         return [{"bboxes": x[:, :4], "conf": x[:, 4], "cls": x[:, 5]} for x in outputs]
 
     def pred_to_json(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> None:
-        """
-        Serialize YOLO predictions to COCO json format.
+        """Serialize YOLO predictions to COCO json format.
 
         Args:
-            predn (dict[str, torch.Tensor]): Predictions dictionary containing 'bboxes', 'conf', and 'cls' keys
-                with bounding box coordinates, confidence scores, and class predictions.
+            predn (dict[str, torch.Tensor]): Predictions dictionary containing 'bboxes', 'conf', and 'cls' keys with
+                bounding box coordinates, confidence scores, and class predictions.
             pbatch (dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.
         """
         path = Path(pbatch["im_file"])
ultralytics/models/sam/__init__.py

@@ -1,12 +1,25 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
 from .model import SAM
-from .predict import Predictor, SAM2DynamicInteractivePredictor, SAM2Predictor, SAM2VideoPredictor
+from .predict import (
+    Predictor,
+    SAM2DynamicInteractivePredictor,
+    SAM2Predictor,
+    SAM2VideoPredictor,
+    SAM3Predictor,
+    SAM3SemanticPredictor,
+    SAM3VideoPredictor,
+    SAM3VideoSemanticPredictor,
+)
 
 __all__ = (
     "SAM",
     "Predictor",
+    "SAM2DynamicInteractivePredictor",
     "SAM2Predictor",
     "SAM2VideoPredictor",
-    "SAM2DynamicInteractivePredictor",
+    "SAM3Predictor",
+    "SAM3SemanticPredictor",
+    "SAM3VideoPredictor",
+    "SAM3VideoSemanticPredictor",
 )  # tuple or list of exportable items
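
This hunk widens the package's public surface with four SAM3 predictors. A minimal sketch showing they are now importable (only the names are asserted here; their construction lives in predict.py, which is not part of this excerpt):

    from ultralytics.models.sam import (
        SAM3Predictor,
        SAM3SemanticPredictor,
        SAM3VideoPredictor,
        SAM3VideoSemanticPredictor,
    )

    print(SAM3Predictor.__name__)  # "SAM3Predictor", now part of __all__
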
ultralytics/models/sam/amg.py

@@ -14,8 +14,7 @@ import torch
 def is_box_near_crop_edge(
     boxes: torch.Tensor, crop_box: list[int], orig_box: list[int], atol: float = 20.0
 ) -> torch.Tensor:
-    """
-    Determine if bounding boxes are near the edge of a cropped image region using a specified tolerance.
+    """Determine if bounding boxes are near the edge of a cropped image region using a specified tolerance.
 
     Args:
         boxes (torch.Tensor): Bounding boxes in XYXY format.
@@ -42,8 +41,7 @@ def is_box_near_crop_edge(
 
 
 def batch_iterator(batch_size: int, *args) -> Generator[list[Any]]:
-    """
-    Yield batches of data from input arguments with specified batch size for efficient processing.
+    """Yield batches of data from input arguments with specified batch size for efficient processing.
 
     This function takes a batch size and any number of iterables, then yields batches of elements from those
     iterables. All input iterables must have the same length.
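
Only the docstring changes here, but the documented contract (equal-length iterables, fixed-size batches, a final short batch) pins the behavior down. A minimal reimplementation sketch consistent with that description, not a verbatim copy of the source:

    def batch_iterator_sketch(batch_size, *args):
        """Yield fixed-size batches from any number of equal-length iterables."""
        assert args and all(len(a) == len(args[0]) for a in args), "inputs must share one length"
        n_batches = -(-len(args[0]) // batch_size)  # ceiling division
        for b in range(n_batches):
            yield [a[b * batch_size : (b + 1) * batch_size] for a in args]

    # Two parallel sequences of 5 items in batches of 2 -> batch sizes 2, 2, 1
    for xs, ys in batch_iterator_sketch(2, [1, 2, 3, 4, 5], "abcde"):
        print(xs, ys)
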
@@ -71,11 +69,10 @@ def batch_iterator(batch_size: int, *args) -> Generator[list[Any]]:
 
 
 def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, threshold_offset: float) -> torch.Tensor:
-    """
-    Compute the stability score for a batch of masks.
+    """Compute the stability score for a batch of masks.
 
-    The stability score is the IoU between binary masks obtained by thresholding the predicted mask logits at
-    high and low values.
+    The stability score is the IoU between binary masks obtained by thresholding the predicted mask logits at high and
+    low values.
 
     Args:
         masks (torch.Tensor): Batch of predicted mask logits.
@@ -85,15 +82,15 @@ def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, thresh
     Returns:
         (torch.Tensor): Stability scores for each mask in the batch.
 
-    Notes:
-        - One mask is always contained inside the other.
-        - Memory is saved by preventing unnecessary cast to torch.int64.
-
     Examples:
         >>> masks = torch.rand(10, 256, 256)  # Batch of 10 masks
         >>> mask_threshold = 0.5
         >>> threshold_offset = 0.1
         >>> stability_scores = calculate_stability_score(masks, mask_threshold, threshold_offset)
+
+    Notes:
+        - One mask is always contained inside the other.
+        - Memory is saved by preventing unnecessary cast to torch.int64.
     """
     intersections = (masks > (mask_threshold + threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
     unions = (masks > (mask_threshold - threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
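
The two lines above compute a compact IoU: because thresholding the same logits at mask_threshold + offset always produces a mask contained in the one thresholded at mask_threshold - offset (the first Note), the high-threshold area is exactly the intersection. A worked toy example (the logit values are made up):

    import torch

    masks = torch.tensor([[[0.2, 0.6], [0.7, 0.4]]])  # one 2x2 logit map
    t, off = 0.5, 0.1
    inter = (masks > t + off).sum((-1, -2)).float()   # pixels above 0.6 -> 1
    union = (masks > t - off).sum((-1, -2)).float()   # pixels above 0.4 -> 2
    print(inter / union)                              # tensor([0.5000])
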
@@ -117,8 +114,7 @@ def build_all_layer_point_grids(n_per_side: int, n_layers: int, scale_per_layer:
 def generate_crop_boxes(
     im_size: tuple[int, ...], n_layers: int, overlap_ratio: float
 ) -> tuple[list[list[int]], list[int]]:
-    """
-    Generate crop boxes of varying sizes for multiscale image processing, with layered overlapping regions.
+    """Generate crop boxes of varying sizes for multiscale image processing, with layered overlapping regions.
 
     Args:
         im_size (tuple[int, ...]): Height and width of the input image.
@@ -145,7 +141,7 @@
 
     def crop_len(orig_len, n_crops, overlap):
         """Calculate the length of each crop given the original length, number of crops, and overlap."""
-        return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops))
+        return math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops)
 
     for i_layer in range(n_layers):
         n_crops_per_side = 2 ** (i_layer + 1)
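
The change inside crop_len only drops a redundant int(): Python 3's math.ceil already returns an int. A quick worked instance of the formula with arbitrary numbers:

    import math

    orig_len, n_crops, overlap = 1000, 2, 100
    # Two crops must jointly cover 1000 px plus one 100 px overlap: ceil(1100 / 2) = 550
    length = math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops)
    print(length, type(length))  # 550 <class 'int'>
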
@@ -198,8 +194,7 @@ def uncrop_masks(masks: torch.Tensor, crop_box: list[int], orig_h: int, orig_w:
 
 
 def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> tuple[np.ndarray, bool]:
-    """
-    Remove small disconnected regions or holes in a mask based on area threshold and mode.
+    """Remove small disconnected regions or holes in a mask based on area threshold and mode.
 
     Args:
         mask (np.ndarray): Binary mask to process.
@@ -227,7 +222,7 @@ def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> tup
     small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
     if not small_regions:
         return mask, False
-    fill_labels = [0] + small_regions
+    fill_labels = [0, *small_regions]
     if not correct_holes:
         # If every region is below threshold, keep largest
         fill_labels = [i for i in range(n_labels) if i not in fill_labels] or [int(np.argmax(sizes)) + 1]
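
[0, *small_regions] is behavior-identical to [0] + small_regions; the starred form is the iterable-unpacking spelling that linters such as Ruff (rule RUF005) prefer over concatenating a throwaway literal:

    small_regions = [2, 5]
    assert [0] + small_regions == [0, *small_regions] == [0, 2, 5]
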
@@ -236,8 +231,7 @@ def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> tup
 
 
 def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
-    """
-    Calculate bounding boxes in XYXY format around binary masks.
+    """Calculate bounding boxes in XYXY format around binary masks.
 
     Args:
         masks (torch.Tensor): Binary masks with shape (B, H, W) or (B, C, H, W).
ultralytics/models/sam/build.py

@@ -11,6 +11,7 @@ from functools import partial
 import torch
 
 from ultralytics.utils.downloads import attempt_download_asset
+from ultralytics.utils.patches import torch_load
 
 from .modules.decoders import MaskDecoder
 from .modules.encoders import FpnNeck, Hiera, ImageEncoder, ImageEncoderViT, MemoryEncoder, PromptEncoder
@@ -20,6 +21,21 @@ from .modules.tiny_encoder import TinyViT
 from .modules.transformer import TwoWayTransformer
 
 
+def _load_checkpoint(model, checkpoint):
+    """Load checkpoint into model from file path."""
+    if checkpoint is None:
+        return model
+
+    checkpoint = attempt_download_asset(checkpoint)
+    with open(checkpoint, "rb") as f:
+        state_dict = torch_load(f)
+    # Handle nested "model" key
+    if "model" in state_dict and isinstance(state_dict["model"], dict):
+        state_dict = state_dict["model"]
+    model.load_state_dict(state_dict)
+    return model
+
+
 def build_sam_vit_h(checkpoint=None):
     """Build and return a Segment Anything Model (SAM) h-size model with specified encoder parameters."""
     return _build_sam(
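
_load_checkpoint centralizes loading logic that, as the later hunks show, was previously duplicated in _build_sam and _build_sam2: it routes through the torch_load patch and unwraps checkpoints saved as {"model": state_dict} (the SAM2 layout) while passing flat SAM checkpoints through untouched. A self-contained sketch of just the unwrapping step (the file name is hypothetical):

    import torch

    torch.save({"model": {"w": torch.zeros(1)}}, "ckpt.pt")  # nested SAM2-style checkpoint
    state_dict = torch.load("ckpt.pt")
    if "model" in state_dict and isinstance(state_dict["model"], dict):
        state_dict = state_dict["model"]  # unwrap before load_state_dict
    print(list(state_dict))  # ['w']
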
@@ -126,8 +142,7 @@ def _build_sam(
     checkpoint=None,
     mobile_sam=False,
 ):
-    """
-    Build a Segment Anything Model (SAM) with specified encoder parameters.
+    """Build a Segment Anything Model (SAM) with specified encoder parameters.
 
     Args:
         encoder_embed_dim (int | list[int]): Embedding dimension for the encoder.
@@ -205,26 +220,22 @@ def _build_sam(
         pixel_std=[58.395, 57.12, 57.375],
     )
     if checkpoint is not None:
-        checkpoint = attempt_download_asset(checkpoint)
-        with open(checkpoint, "rb") as f:
-            state_dict = torch.load(f)
-        sam.load_state_dict(state_dict)
+        sam = _load_checkpoint(sam, checkpoint)
     sam.eval()
     return sam
 
 
 def _build_sam2(
     encoder_embed_dim=1280,
-    encoder_stages=[2, 6, 36, 4],
+    encoder_stages=(2, 6, 36, 4),
     encoder_num_heads=2,
-    encoder_global_att_blocks=[7, 15, 23, 31],
-    encoder_backbone_channel_list=[1152, 576, 288, 144],
-    encoder_window_spatial_size=[7, 7],
-    encoder_window_spec=[8, 4, 16, 8],
+    encoder_global_att_blocks=(7, 15, 23, 31),
+    encoder_backbone_channel_list=(1152, 576, 288, 144),
+    encoder_window_spatial_size=(7, 7),
+    encoder_window_spec=(8, 4, 16, 8),
     checkpoint=None,
 ):
-    """
-    Build and return a Segment Anything Model 2 (SAM2) with specified architecture parameters.
+    """Build and return a Segment Anything Model 2 (SAM2) with specified architecture parameters.
 
     Args:
         encoder_embed_dim (int, optional): Embedding dimension for the encoder.
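
Swapping _build_sam2's list defaults for tuples removes Python's shared-mutable-default pitfall: a list default is built once at definition time, so in-place mutation leaks across calls, while a tuple is immutable. A self-contained demonstration of the hazard the tuples rule out:

    def risky(stages=[2, 6, 36, 4]):
        stages.append(99)  # mutates the one shared default list
        return stages

    def safe(stages=(2, 6, 36, 4)):
        return stages + (99,)  # immutable default forces an explicit new tuple

    print(risky())         # [2, 6, 36, 4, 99]
    print(risky())         # [2, 6, 36, 4, 99, 99]  <- state leaked between calls
    print(safe(), safe())  # identical output both times
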
@@ -300,10 +311,7 @@ def _build_sam2(
     )
 
     if checkpoint is not None:
-        checkpoint = attempt_download_asset(checkpoint)
-        with open(checkpoint, "rb") as f:
-            state_dict = torch.load(f)["model"]
-        sam2.load_state_dict(state_dict)
+        sam2 = _load_checkpoint(sam2, checkpoint)
     sam2.eval()
     return sam2
 
@@ -325,8 +333,7 @@ sam_model_map = {
 
 
 def build_sam(ckpt="sam_b.pt"):
-    """
-    Build and return a Segment Anything Model (SAM) based on the provided checkpoint.
+    """Build and return a Segment Anything Model (SAM) based on the provided checkpoint.
 
     Args:
         ckpt (str | Path, optional): Path to the checkpoint file or name of a pre-defined SAM model.