dgenerate-ultralytics-headless 8.3.214__py3-none-any.whl → 8.4.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (249)
  1. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/METADATA +64 -74
  2. dgenerate_ultralytics_headless-8.4.7.dist-info/RECORD +311 -0
  3. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/WHEEL +1 -1
  4. tests/__init__.py +7 -9
  5. tests/conftest.py +8 -15
  6. tests/test_cli.py +1 -1
  7. tests/test_cuda.py +13 -10
  8. tests/test_engine.py +9 -9
  9. tests/test_exports.py +65 -13
  10. tests/test_integrations.py +13 -13
  11. tests/test_python.py +125 -69
  12. tests/test_solutions.py +161 -152
  13. ultralytics/__init__.py +1 -1
  14. ultralytics/cfg/__init__.py +86 -92
  15. ultralytics/cfg/datasets/Argoverse.yaml +7 -6
  16. ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
  17. ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
  18. ultralytics/cfg/datasets/ImageNet.yaml +1 -1
  19. ultralytics/cfg/datasets/TT100K.yaml +346 -0
  20. ultralytics/cfg/datasets/VOC.yaml +15 -16
  21. ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
  22. ultralytics/cfg/datasets/coco-pose.yaml +21 -0
  23. ultralytics/cfg/datasets/coco12-formats.yaml +101 -0
  24. ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
  25. ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
  26. ultralytics/cfg/datasets/dog-pose.yaml +28 -0
  27. ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
  28. ultralytics/cfg/datasets/dota8.yaml +2 -2
  29. ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
  30. ultralytics/cfg/datasets/kitti.yaml +27 -0
  31. ultralytics/cfg/datasets/lvis.yaml +5 -5
  32. ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
  33. ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
  34. ultralytics/cfg/datasets/xView.yaml +16 -16
  35. ultralytics/cfg/default.yaml +4 -2
  36. ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
  37. ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
  38. ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
  39. ultralytics/cfg/models/26/yolo26-cls.yaml +33 -0
  40. ultralytics/cfg/models/26/yolo26-obb.yaml +52 -0
  41. ultralytics/cfg/models/26/yolo26-p2.yaml +60 -0
  42. ultralytics/cfg/models/26/yolo26-p6.yaml +62 -0
  43. ultralytics/cfg/models/26/yolo26-pose.yaml +53 -0
  44. ultralytics/cfg/models/26/yolo26-seg.yaml +52 -0
  45. ultralytics/cfg/models/26/yolo26.yaml +52 -0
  46. ultralytics/cfg/models/26/yoloe-26-seg.yaml +53 -0
  47. ultralytics/cfg/models/26/yoloe-26.yaml +53 -0
  48. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
  49. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
  50. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
  51. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
  52. ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
  53. ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
  54. ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
  55. ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
  56. ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
  57. ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
  58. ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
  59. ultralytics/cfg/models/v6/yolov6.yaml +1 -1
  60. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
  61. ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
  62. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
  63. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
  64. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
  65. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
  66. ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
  67. ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
  68. ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
  69. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
  70. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
  71. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
  72. ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
  73. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
  74. ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
  75. ultralytics/data/__init__.py +4 -4
  76. ultralytics/data/annotator.py +5 -6
  77. ultralytics/data/augment.py +300 -475
  78. ultralytics/data/base.py +18 -26
  79. ultralytics/data/build.py +147 -25
  80. ultralytics/data/converter.py +108 -87
  81. ultralytics/data/dataset.py +47 -75
  82. ultralytics/data/loaders.py +42 -49
  83. ultralytics/data/split.py +5 -6
  84. ultralytics/data/split_dota.py +8 -15
  85. ultralytics/data/utils.py +36 -45
  86. ultralytics/engine/exporter.py +351 -263
  87. ultralytics/engine/model.py +186 -225
  88. ultralytics/engine/predictor.py +45 -54
  89. ultralytics/engine/results.py +198 -325
  90. ultralytics/engine/trainer.py +165 -106
  91. ultralytics/engine/tuner.py +41 -43
  92. ultralytics/engine/validator.py +55 -38
  93. ultralytics/hub/__init__.py +16 -19
  94. ultralytics/hub/auth.py +6 -12
  95. ultralytics/hub/google/__init__.py +7 -10
  96. ultralytics/hub/session.py +15 -25
  97. ultralytics/hub/utils.py +5 -8
  98. ultralytics/models/__init__.py +1 -1
  99. ultralytics/models/fastsam/__init__.py +1 -1
  100. ultralytics/models/fastsam/model.py +8 -10
  101. ultralytics/models/fastsam/predict.py +18 -30
  102. ultralytics/models/fastsam/utils.py +1 -2
  103. ultralytics/models/fastsam/val.py +5 -7
  104. ultralytics/models/nas/__init__.py +1 -1
  105. ultralytics/models/nas/model.py +5 -8
  106. ultralytics/models/nas/predict.py +7 -9
  107. ultralytics/models/nas/val.py +1 -2
  108. ultralytics/models/rtdetr/__init__.py +1 -1
  109. ultralytics/models/rtdetr/model.py +5 -8
  110. ultralytics/models/rtdetr/predict.py +15 -19
  111. ultralytics/models/rtdetr/train.py +10 -13
  112. ultralytics/models/rtdetr/val.py +21 -23
  113. ultralytics/models/sam/__init__.py +15 -2
  114. ultralytics/models/sam/amg.py +14 -20
  115. ultralytics/models/sam/build.py +26 -19
  116. ultralytics/models/sam/build_sam3.py +377 -0
  117. ultralytics/models/sam/model.py +29 -32
  118. ultralytics/models/sam/modules/blocks.py +83 -144
  119. ultralytics/models/sam/modules/decoders.py +19 -37
  120. ultralytics/models/sam/modules/encoders.py +44 -101
  121. ultralytics/models/sam/modules/memory_attention.py +16 -30
  122. ultralytics/models/sam/modules/sam.py +200 -73
  123. ultralytics/models/sam/modules/tiny_encoder.py +64 -83
  124. ultralytics/models/sam/modules/transformer.py +18 -28
  125. ultralytics/models/sam/modules/utils.py +174 -50
  126. ultralytics/models/sam/predict.py +2248 -350
  127. ultralytics/models/sam/sam3/__init__.py +3 -0
  128. ultralytics/models/sam/sam3/decoder.py +546 -0
  129. ultralytics/models/sam/sam3/encoder.py +529 -0
  130. ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
  131. ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
  132. ultralytics/models/sam/sam3/model_misc.py +199 -0
  133. ultralytics/models/sam/sam3/necks.py +129 -0
  134. ultralytics/models/sam/sam3/sam3_image.py +339 -0
  135. ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
  136. ultralytics/models/sam/sam3/vitdet.py +547 -0
  137. ultralytics/models/sam/sam3/vl_combiner.py +160 -0
  138. ultralytics/models/utils/loss.py +14 -26
  139. ultralytics/models/utils/ops.py +13 -17
  140. ultralytics/models/yolo/__init__.py +1 -1
  141. ultralytics/models/yolo/classify/predict.py +10 -13
  142. ultralytics/models/yolo/classify/train.py +12 -33
  143. ultralytics/models/yolo/classify/val.py +30 -29
  144. ultralytics/models/yolo/detect/predict.py +9 -12
  145. ultralytics/models/yolo/detect/train.py +17 -23
  146. ultralytics/models/yolo/detect/val.py +77 -59
  147. ultralytics/models/yolo/model.py +43 -60
  148. ultralytics/models/yolo/obb/predict.py +7 -16
  149. ultralytics/models/yolo/obb/train.py +14 -17
  150. ultralytics/models/yolo/obb/val.py +40 -37
  151. ultralytics/models/yolo/pose/__init__.py +1 -1
  152. ultralytics/models/yolo/pose/predict.py +7 -22
  153. ultralytics/models/yolo/pose/train.py +13 -16
  154. ultralytics/models/yolo/pose/val.py +39 -58
  155. ultralytics/models/yolo/segment/predict.py +17 -21
  156. ultralytics/models/yolo/segment/train.py +7 -10
  157. ultralytics/models/yolo/segment/val.py +95 -47
  158. ultralytics/models/yolo/world/train.py +8 -14
  159. ultralytics/models/yolo/world/train_world.py +11 -34
  160. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  161. ultralytics/models/yolo/yoloe/predict.py +16 -23
  162. ultralytics/models/yolo/yoloe/train.py +36 -44
  163. ultralytics/models/yolo/yoloe/train_seg.py +11 -11
  164. ultralytics/models/yolo/yoloe/val.py +15 -20
  165. ultralytics/nn/__init__.py +7 -7
  166. ultralytics/nn/autobackend.py +159 -85
  167. ultralytics/nn/modules/__init__.py +68 -60
  168. ultralytics/nn/modules/activation.py +4 -6
  169. ultralytics/nn/modules/block.py +260 -224
  170. ultralytics/nn/modules/conv.py +52 -97
  171. ultralytics/nn/modules/head.py +831 -299
  172. ultralytics/nn/modules/transformer.py +76 -88
  173. ultralytics/nn/modules/utils.py +16 -21
  174. ultralytics/nn/tasks.py +180 -195
  175. ultralytics/nn/text_model.py +45 -69
  176. ultralytics/optim/__init__.py +5 -0
  177. ultralytics/optim/muon.py +338 -0
  178. ultralytics/solutions/__init__.py +12 -12
  179. ultralytics/solutions/ai_gym.py +13 -19
  180. ultralytics/solutions/analytics.py +15 -16
  181. ultralytics/solutions/config.py +6 -7
  182. ultralytics/solutions/distance_calculation.py +10 -13
  183. ultralytics/solutions/heatmap.py +8 -14
  184. ultralytics/solutions/instance_segmentation.py +6 -9
  185. ultralytics/solutions/object_blurrer.py +7 -10
  186. ultralytics/solutions/object_counter.py +12 -19
  187. ultralytics/solutions/object_cropper.py +8 -14
  188. ultralytics/solutions/parking_management.py +34 -32
  189. ultralytics/solutions/queue_management.py +10 -12
  190. ultralytics/solutions/region_counter.py +9 -12
  191. ultralytics/solutions/security_alarm.py +15 -20
  192. ultralytics/solutions/similarity_search.py +10 -15
  193. ultralytics/solutions/solutions.py +77 -76
  194. ultralytics/solutions/speed_estimation.py +7 -10
  195. ultralytics/solutions/streamlit_inference.py +2 -4
  196. ultralytics/solutions/templates/similarity-search.html +7 -18
  197. ultralytics/solutions/trackzone.py +7 -10
  198. ultralytics/solutions/vision_eye.py +5 -8
  199. ultralytics/trackers/__init__.py +1 -1
  200. ultralytics/trackers/basetrack.py +3 -5
  201. ultralytics/trackers/bot_sort.py +10 -27
  202. ultralytics/trackers/byte_tracker.py +21 -37
  203. ultralytics/trackers/track.py +4 -7
  204. ultralytics/trackers/utils/gmc.py +11 -22
  205. ultralytics/trackers/utils/kalman_filter.py +37 -48
  206. ultralytics/trackers/utils/matching.py +12 -15
  207. ultralytics/utils/__init__.py +124 -124
  208. ultralytics/utils/autobatch.py +2 -4
  209. ultralytics/utils/autodevice.py +17 -18
  210. ultralytics/utils/benchmarks.py +57 -71
  211. ultralytics/utils/callbacks/base.py +8 -10
  212. ultralytics/utils/callbacks/clearml.py +5 -13
  213. ultralytics/utils/callbacks/comet.py +32 -46
  214. ultralytics/utils/callbacks/dvc.py +13 -18
  215. ultralytics/utils/callbacks/mlflow.py +4 -5
  216. ultralytics/utils/callbacks/neptune.py +7 -15
  217. ultralytics/utils/callbacks/platform.py +423 -38
  218. ultralytics/utils/callbacks/raytune.py +3 -4
  219. ultralytics/utils/callbacks/tensorboard.py +25 -31
  220. ultralytics/utils/callbacks/wb.py +16 -14
  221. ultralytics/utils/checks.py +127 -85
  222. ultralytics/utils/cpu.py +3 -8
  223. ultralytics/utils/dist.py +9 -12
  224. ultralytics/utils/downloads.py +25 -33
  225. ultralytics/utils/errors.py +6 -14
  226. ultralytics/utils/events.py +2 -4
  227. ultralytics/utils/export/__init__.py +4 -236
  228. ultralytics/utils/export/engine.py +246 -0
  229. ultralytics/utils/export/imx.py +117 -63
  230. ultralytics/utils/export/tensorflow.py +231 -0
  231. ultralytics/utils/files.py +26 -30
  232. ultralytics/utils/git.py +9 -11
  233. ultralytics/utils/instance.py +30 -51
  234. ultralytics/utils/logger.py +212 -114
  235. ultralytics/utils/loss.py +601 -215
  236. ultralytics/utils/metrics.py +128 -156
  237. ultralytics/utils/nms.py +13 -16
  238. ultralytics/utils/ops.py +117 -166
  239. ultralytics/utils/patches.py +75 -21
  240. ultralytics/utils/plotting.py +75 -80
  241. ultralytics/utils/tal.py +125 -59
  242. ultralytics/utils/torch_utils.py +53 -79
  243. ultralytics/utils/tqdm.py +24 -21
  244. ultralytics/utils/triton.py +13 -19
  245. ultralytics/utils/tuner.py +19 -10
  246. dgenerate_ultralytics_headless-8.3.214.dist-info/RECORD +0 -283
  247. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/entry_points.txt +0 -0
  248. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/licenses/LICENSE +0 -0
  249. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/top_level.txt +0 -0
@@ -16,8 +16,7 @@ __all__ = ("RTDETRValidator",) # tuple or list
16
16
 
17
17
 
18
18
  class RTDETRDataset(YOLODataset):
19
- """
20
- Real-Time DEtection and TRacking (RT-DETR) dataset class extending the base YOLODataset class.
19
+ """Real-Time DEtection and TRacking (RT-DETR) dataset class extending the base YOLODataset class.
21
20
 
22
21
  This specialized dataset class is designed for use with the RT-DETR object detection model and is optimized for
23
22
  real-time detection and tracking tasks.
@@ -36,12 +35,11 @@ class RTDETRDataset(YOLODataset):
36
35
  Examples:
37
36
  Initialize an RT-DETR dataset
38
37
  >>> dataset = RTDETRDataset(img_path="path/to/images", imgsz=640)
39
- >>> image, hw = dataset.load_image(0)
38
+ >>> image, hw0, hw = dataset.load_image(0)
40
39
  """
41
40
 
42
41
  def __init__(self, *args, data=None, **kwargs):
43
- """
44
- Initialize the RTDETRDataset class by inheriting from the YOLODataset class.
42
+ """Initialize the RTDETRDataset class by inheriting from the YOLODataset class.
45
43
 
46
44
  This constructor sets up a dataset specifically optimized for the RT-DETR (Real-Time DEtection and TRacking)
47
45
  model, building upon the base YOLODataset functionality.
@@ -54,27 +52,26 @@ class RTDETRDataset(YOLODataset):
54
52
  super().__init__(*args, data=data, **kwargs)
55
53
 
56
54
  def load_image(self, i, rect_mode=False):
57
- """
58
- Load one image from dataset index 'i'.
55
+ """Load one image from dataset index 'i'.
59
56
 
60
57
  Args:
61
58
  i (int): Index of the image to load.
62
59
  rect_mode (bool, optional): Whether to use rectangular mode for batch inference.
63
60
 
64
61
  Returns:
65
- im (torch.Tensor): The loaded image.
66
- resized_hw (tuple): Height and width of the resized image with shape (2,).
62
+ im (np.ndarray): Loaded image as a NumPy array.
63
+ hw_original (tuple[int, int]): Original image dimensions in (height, width) format.
64
+ hw_resized (tuple[int, int]): Resized image dimensions in (height, width) format.
67
65
 
68
66
  Examples:
69
67
  Load an image from the dataset
70
68
  >>> dataset = RTDETRDataset(img_path="path/to/images")
71
- >>> image, hw = dataset.load_image(0)
69
+ >>> image, hw0, hw = dataset.load_image(0)
72
70
  """
73
71
  return super().load_image(i=i, rect_mode=rect_mode)
74
72
 
75
73
  def build_transforms(self, hyp=None):
76
- """
77
- Build transformation pipeline for the dataset.
74
+ """Build transformation pipeline for the dataset.
78
75
 
79
76
  Args:
80
77
  hyp (dict, optional): Hyperparameters for transformations.
@@ -105,8 +102,7 @@ class RTDETRDataset(YOLODataset):
105
102
 
106
103
 
107
104
  class RTDETRValidator(DetectionValidator):
108
- """
109
- RTDETRValidator extends the DetectionValidator class to provide validation capabilities specifically tailored for
105
+ """RTDETRValidator extends the DetectionValidator class to provide validation capabilities specifically tailored for
110
106
  the RT-DETR (Real-Time DETR) object detection model.
111
107
 
112
108
  The class allows building of an RTDETR-specific dataset for validation, applies Non-maximum suppression for
@@ -132,8 +128,7 @@ class RTDETRValidator(DetectionValidator):
132
128
  """
133
129
 
134
130
  def build_dataset(self, img_path, mode="val", batch=None):
135
- """
136
- Build an RTDETR Dataset.
131
+ """Build an RTDETR Dataset.
137
132
 
138
133
  Args:
139
134
  img_path (str): Path to the folder containing images.
@@ -156,15 +151,19 @@ class RTDETRValidator(DetectionValidator):
156
151
  data=self.data,
157
152
  )
158
153
 
154
+ def scale_preds(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> dict[str, torch.Tensor]:
155
+ """Scales predictions to the original image size."""
156
+ return predn
157
+
159
158
  def postprocess(
160
159
  self, preds: torch.Tensor | list[torch.Tensor] | tuple[torch.Tensor]
161
160
  ) -> list[dict[str, torch.Tensor]]:
162
- """
163
- Apply Non-maximum suppression to prediction outputs.
161
+ """Apply Non-maximum suppression to prediction outputs.
164
162
 
165
163
  Args:
166
164
  preds (torch.Tensor | list | tuple): Raw predictions from the model. If tensor, should have shape
167
- (batch_size, num_predictions, num_classes + 4) where last dimension contains bbox coords and class scores.
165
+ (batch_size, num_predictions, num_classes + 4) where last dimension contains bbox coords and
166
+ class scores.
168
167
 
169
168
  Returns:
170
169
  (list[dict[str, torch.Tensor]]): List of dictionaries for each image, each containing:
@@ -190,12 +189,11 @@ class RTDETRValidator(DetectionValidator):
190
189
  return [{"bboxes": x[:, :4], "conf": x[:, 4], "cls": x[:, 5]} for x in outputs]
191
190
 
192
191
  def pred_to_json(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> None:
193
- """
194
- Serialize YOLO predictions to COCO json format.
192
+ """Serialize YOLO predictions to COCO json format.
195
193
 
196
194
  Args:
197
- predn (dict[str, torch.Tensor]): Predictions dictionary containing 'bboxes', 'conf', and 'cls' keys
198
- with bounding box coordinates, confidence scores, and class predictions.
195
+ predn (dict[str, torch.Tensor]): Predictions dictionary containing 'bboxes', 'conf', and 'cls' keys with
196
+ bounding box coordinates, confidence scores, and class predictions.
199
197
  pbatch (dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.
200
198
  """
201
199
  path = Path(pbatch["im_file"])
@@ -1,12 +1,25 @@
1
1
  # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
2
 
3
3
  from .model import SAM
4
- from .predict import Predictor, SAM2DynamicInteractivePredictor, SAM2Predictor, SAM2VideoPredictor
4
+ from .predict import (
5
+ Predictor,
6
+ SAM2DynamicInteractivePredictor,
7
+ SAM2Predictor,
8
+ SAM2VideoPredictor,
9
+ SAM3Predictor,
10
+ SAM3SemanticPredictor,
11
+ SAM3VideoPredictor,
12
+ SAM3VideoSemanticPredictor,
13
+ )
5
14
 
6
15
  __all__ = (
7
16
  "SAM",
8
17
  "Predictor",
18
+ "SAM2DynamicInteractivePredictor",
9
19
  "SAM2Predictor",
10
20
  "SAM2VideoPredictor",
11
- "SAM2DynamicInteractivePredictor",
21
+ "SAM3Predictor",
22
+ "SAM3SemanticPredictor",
23
+ "SAM3VideoPredictor",
24
+ "SAM3VideoSemanticPredictor",
12
25
  ) # tuple or list of exportable items
@@ -14,8 +14,7 @@ import torch
14
14
  def is_box_near_crop_edge(
15
15
  boxes: torch.Tensor, crop_box: list[int], orig_box: list[int], atol: float = 20.0
16
16
  ) -> torch.Tensor:
17
- """
18
- Determine if bounding boxes are near the edge of a cropped image region using a specified tolerance.
17
+ """Determine if bounding boxes are near the edge of a cropped image region using a specified tolerance.
19
18
 
20
19
  Args:
21
20
  boxes (torch.Tensor): Bounding boxes in XYXY format.
@@ -42,8 +41,7 @@ def is_box_near_crop_edge(
42
41
 
43
42
 
44
43
  def batch_iterator(batch_size: int, *args) -> Generator[list[Any]]:
45
- """
46
- Yield batches of data from input arguments with specified batch size for efficient processing.
44
+ """Yield batches of data from input arguments with specified batch size for efficient processing.
47
45
 
48
46
  This function takes a batch size and any number of iterables, then yields batches of elements from those
49
47
  iterables. All input iterables must have the same length.
@@ -71,11 +69,10 @@ def batch_iterator(batch_size: int, *args) -> Generator[list[Any]]:
71
69
 
72
70
 
73
71
  def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, threshold_offset: float) -> torch.Tensor:
74
- """
75
- Compute the stability score for a batch of masks.
72
+ """Compute the stability score for a batch of masks.
76
73
 
77
- The stability score is the IoU between binary masks obtained by thresholding the predicted mask logits at
78
- high and low values.
74
+ The stability score is the IoU between binary masks obtained by thresholding the predicted mask logits at high and
75
+ low values.
79
76
 
80
77
  Args:
81
78
  masks (torch.Tensor): Batch of predicted mask logits.
@@ -85,15 +82,15 @@ def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, thresh
85
82
  Returns:
86
83
  (torch.Tensor): Stability scores for each mask in the batch.
87
84
 
88
- Notes:
89
- - One mask is always contained inside the other.
90
- - Memory is saved by preventing unnecessary cast to torch.int64.
91
-
92
85
  Examples:
93
86
  >>> masks = torch.rand(10, 256, 256) # Batch of 10 masks
94
87
  >>> mask_threshold = 0.5
95
88
  >>> threshold_offset = 0.1
96
89
  >>> stability_scores = calculate_stability_score(masks, mask_threshold, threshold_offset)
90
+
91
+ Notes:
92
+ - One mask is always contained inside the other.
93
+ - Memory is saved by preventing unnecessary cast to torch.int64.
97
94
  """
98
95
  intersections = (masks > (mask_threshold + threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
99
96
  unions = (masks > (mask_threshold - threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
@@ -117,8 +114,7 @@ def build_all_layer_point_grids(n_per_side: int, n_layers: int, scale_per_layer:
117
114
  def generate_crop_boxes(
118
115
  im_size: tuple[int, ...], n_layers: int, overlap_ratio: float
119
116
  ) -> tuple[list[list[int]], list[int]]:
120
- """
121
- Generate crop boxes of varying sizes for multiscale image processing, with layered overlapping regions.
117
+ """Generate crop boxes of varying sizes for multiscale image processing, with layered overlapping regions.
122
118
 
123
119
  Args:
124
120
  im_size (tuple[int, ...]): Height and width of the input image.
@@ -145,7 +141,7 @@ def generate_crop_boxes(
145
141
 
146
142
  def crop_len(orig_len, n_crops, overlap):
147
143
  """Calculate the length of each crop given the original length, number of crops, and overlap."""
148
- return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops))
144
+ return math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops)
149
145
 
150
146
  for i_layer in range(n_layers):
151
147
  n_crops_per_side = 2 ** (i_layer + 1)
@@ -198,8 +194,7 @@ def uncrop_masks(masks: torch.Tensor, crop_box: list[int], orig_h: int, orig_w:
198
194
 
199
195
 
200
196
  def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> tuple[np.ndarray, bool]:
201
- """
202
- Remove small disconnected regions or holes in a mask based on area threshold and mode.
197
+ """Remove small disconnected regions or holes in a mask based on area threshold and mode.
203
198
 
204
199
  Args:
205
200
  mask (np.ndarray): Binary mask to process.
@@ -227,7 +222,7 @@ def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> tup
227
222
  small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
228
223
  if not small_regions:
229
224
  return mask, False
230
- fill_labels = [0] + small_regions
225
+ fill_labels = [0, *small_regions]
231
226
  if not correct_holes:
232
227
  # If every region is below threshold, keep largest
233
228
  fill_labels = [i for i in range(n_labels) if i not in fill_labels] or [int(np.argmax(sizes)) + 1]
@@ -236,8 +231,7 @@ def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> tup
236
231
 
237
232
 
238
233
  def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
239
- """
240
- Calculate bounding boxes in XYXY format around binary masks.
234
+ """Calculate bounding boxes in XYXY format around binary masks.
241
235
 
242
236
  Args:
243
237
  masks (torch.Tensor): Binary masks with shape (B, H, W) or (B, C, H, W).
@@ -11,6 +11,7 @@ from functools import partial
11
11
  import torch
12
12
 
13
13
  from ultralytics.utils.downloads import attempt_download_asset
14
+ from ultralytics.utils.patches import torch_load
14
15
 
15
16
  from .modules.decoders import MaskDecoder
16
17
  from .modules.encoders import FpnNeck, Hiera, ImageEncoder, ImageEncoderViT, MemoryEncoder, PromptEncoder
@@ -20,6 +21,21 @@ from .modules.tiny_encoder import TinyViT
20
21
  from .modules.transformer import TwoWayTransformer
21
22
 
22
23
 
24
+ def _load_checkpoint(model, checkpoint):
25
+ """Load checkpoint into model from file path."""
26
+ if checkpoint is None:
27
+ return model
28
+
29
+ checkpoint = attempt_download_asset(checkpoint)
30
+ with open(checkpoint, "rb") as f:
31
+ state_dict = torch_load(f)
32
+ # Handle nested "model" key
33
+ if "model" in state_dict and isinstance(state_dict["model"], dict):
34
+ state_dict = state_dict["model"]
35
+ model.load_state_dict(state_dict)
36
+ return model
37
+
38
+
23
39
  def build_sam_vit_h(checkpoint=None):
24
40
  """Build and return a Segment Anything Model (SAM) h-size model with specified encoder parameters."""
25
41
  return _build_sam(
@@ -126,8 +142,7 @@ def _build_sam(
126
142
  checkpoint=None,
127
143
  mobile_sam=False,
128
144
  ):
129
- """
130
- Build a Segment Anything Model (SAM) with specified encoder parameters.
145
+ """Build a Segment Anything Model (SAM) with specified encoder parameters.
131
146
 
132
147
  Args:
133
148
  encoder_embed_dim (int | list[int]): Embedding dimension for the encoder.
@@ -205,26 +220,22 @@ def _build_sam(
205
220
  pixel_std=[58.395, 57.12, 57.375],
206
221
  )
207
222
  if checkpoint is not None:
208
- checkpoint = attempt_download_asset(checkpoint)
209
- with open(checkpoint, "rb") as f:
210
- state_dict = torch.load(f)
211
- sam.load_state_dict(state_dict)
223
+ sam = _load_checkpoint(sam, checkpoint)
212
224
  sam.eval()
213
225
  return sam
214
226
 
215
227
 
216
228
  def _build_sam2(
217
229
  encoder_embed_dim=1280,
218
- encoder_stages=[2, 6, 36, 4],
230
+ encoder_stages=(2, 6, 36, 4),
219
231
  encoder_num_heads=2,
220
- encoder_global_att_blocks=[7, 15, 23, 31],
221
- encoder_backbone_channel_list=[1152, 576, 288, 144],
222
- encoder_window_spatial_size=[7, 7],
223
- encoder_window_spec=[8, 4, 16, 8],
232
+ encoder_global_att_blocks=(7, 15, 23, 31),
233
+ encoder_backbone_channel_list=(1152, 576, 288, 144),
234
+ encoder_window_spatial_size=(7, 7),
235
+ encoder_window_spec=(8, 4, 16, 8),
224
236
  checkpoint=None,
225
237
  ):
226
- """
227
- Build and return a Segment Anything Model 2 (SAM2) with specified architecture parameters.
238
+ """Build and return a Segment Anything Model 2 (SAM2) with specified architecture parameters.
228
239
 
229
240
  Args:
230
241
  encoder_embed_dim (int, optional): Embedding dimension for the encoder.
@@ -300,10 +311,7 @@ def _build_sam2(
300
311
  )
301
312
 
302
313
  if checkpoint is not None:
303
- checkpoint = attempt_download_asset(checkpoint)
304
- with open(checkpoint, "rb") as f:
305
- state_dict = torch.load(f)["model"]
306
- sam2.load_state_dict(state_dict)
314
+ sam2 = _load_checkpoint(sam2, checkpoint)
307
315
  sam2.eval()
308
316
  return sam2
309
317
 
@@ -325,8 +333,7 @@ sam_model_map = {
325
333
 
326
334
 
327
335
  def build_sam(ckpt="sam_b.pt"):
328
- """
329
- Build and return a Segment Anything Model (SAM) based on the provided checkpoint.
336
+ """Build and return a Segment Anything Model (SAM) based on the provided checkpoint.
330
337
 
331
338
  Args:
332
339
  ckpt (str | Path, optional): Path to the checkpoint file or name of a pre-defined SAM model.