dgenerate-ultralytics-headless 8.3.137__py3-none-any.whl → 8.3.224__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215)
  1. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/METADATA +41 -34
  2. dgenerate_ultralytics_headless-8.3.224.dist-info/RECORD +285 -0
  3. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/WHEEL +1 -1
  4. tests/__init__.py +7 -6
  5. tests/conftest.py +15 -39
  6. tests/test_cli.py +17 -17
  7. tests/test_cuda.py +17 -8
  8. tests/test_engine.py +36 -10
  9. tests/test_exports.py +98 -37
  10. tests/test_integrations.py +12 -15
  11. tests/test_python.py +126 -82
  12. tests/test_solutions.py +319 -135
  13. ultralytics/__init__.py +27 -9
  14. ultralytics/cfg/__init__.py +83 -87
  15. ultralytics/cfg/datasets/Argoverse.yaml +4 -4
  16. ultralytics/cfg/datasets/DOTAv1.5.yaml +2 -2
  17. ultralytics/cfg/datasets/DOTAv1.yaml +2 -2
  18. ultralytics/cfg/datasets/GlobalWheat2020.yaml +2 -2
  19. ultralytics/cfg/datasets/HomeObjects-3K.yaml +4 -5
  20. ultralytics/cfg/datasets/ImageNet.yaml +3 -3
  21. ultralytics/cfg/datasets/Objects365.yaml +24 -20
  22. ultralytics/cfg/datasets/SKU-110K.yaml +9 -9
  23. ultralytics/cfg/datasets/VOC.yaml +10 -13
  24. ultralytics/cfg/datasets/VisDrone.yaml +43 -33
  25. ultralytics/cfg/datasets/african-wildlife.yaml +5 -5
  26. ultralytics/cfg/datasets/brain-tumor.yaml +4 -5
  27. ultralytics/cfg/datasets/carparts-seg.yaml +5 -5
  28. ultralytics/cfg/datasets/coco-pose.yaml +26 -4
  29. ultralytics/cfg/datasets/coco.yaml +4 -4
  30. ultralytics/cfg/datasets/coco128-seg.yaml +2 -2
  31. ultralytics/cfg/datasets/coco128.yaml +2 -2
  32. ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
  33. ultralytics/cfg/datasets/coco8-multispectral.yaml +2 -2
  34. ultralytics/cfg/datasets/coco8-pose.yaml +23 -2
  35. ultralytics/cfg/datasets/coco8-seg.yaml +2 -2
  36. ultralytics/cfg/datasets/coco8.yaml +2 -2
  37. ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
  38. ultralytics/cfg/datasets/crack-seg.yaml +5 -5
  39. ultralytics/cfg/datasets/dog-pose.yaml +32 -4
  40. ultralytics/cfg/datasets/dota8-multispectral.yaml +2 -2
  41. ultralytics/cfg/datasets/dota8.yaml +2 -2
  42. ultralytics/cfg/datasets/hand-keypoints.yaml +29 -4
  43. ultralytics/cfg/datasets/lvis.yaml +9 -9
  44. ultralytics/cfg/datasets/medical-pills.yaml +4 -5
  45. ultralytics/cfg/datasets/open-images-v7.yaml +7 -10
  46. ultralytics/cfg/datasets/package-seg.yaml +5 -5
  47. ultralytics/cfg/datasets/signature.yaml +4 -4
  48. ultralytics/cfg/datasets/tiger-pose.yaml +20 -4
  49. ultralytics/cfg/datasets/xView.yaml +5 -5
  50. ultralytics/cfg/default.yaml +96 -93
  51. ultralytics/cfg/trackers/botsort.yaml +16 -17
  52. ultralytics/cfg/trackers/bytetrack.yaml +9 -11
  53. ultralytics/data/__init__.py +4 -4
  54. ultralytics/data/annotator.py +12 -12
  55. ultralytics/data/augment.py +531 -564
  56. ultralytics/data/base.py +76 -81
  57. ultralytics/data/build.py +206 -42
  58. ultralytics/data/converter.py +179 -78
  59. ultralytics/data/dataset.py +121 -121
  60. ultralytics/data/loaders.py +114 -91
  61. ultralytics/data/split.py +28 -15
  62. ultralytics/data/split_dota.py +67 -48
  63. ultralytics/data/utils.py +110 -89
  64. ultralytics/engine/exporter.py +422 -460
  65. ultralytics/engine/model.py +224 -252
  66. ultralytics/engine/predictor.py +94 -89
  67. ultralytics/engine/results.py +345 -595
  68. ultralytics/engine/trainer.py +231 -134
  69. ultralytics/engine/tuner.py +279 -73
  70. ultralytics/engine/validator.py +53 -46
  71. ultralytics/hub/__init__.py +26 -28
  72. ultralytics/hub/auth.py +30 -16
  73. ultralytics/hub/google/__init__.py +34 -36
  74. ultralytics/hub/session.py +53 -77
  75. ultralytics/hub/utils.py +23 -109
  76. ultralytics/models/__init__.py +1 -1
  77. ultralytics/models/fastsam/__init__.py +1 -1
  78. ultralytics/models/fastsam/model.py +36 -18
  79. ultralytics/models/fastsam/predict.py +33 -44
  80. ultralytics/models/fastsam/utils.py +4 -5
  81. ultralytics/models/fastsam/val.py +12 -14
  82. ultralytics/models/nas/__init__.py +1 -1
  83. ultralytics/models/nas/model.py +16 -20
  84. ultralytics/models/nas/predict.py +12 -14
  85. ultralytics/models/nas/val.py +4 -5
  86. ultralytics/models/rtdetr/__init__.py +1 -1
  87. ultralytics/models/rtdetr/model.py +9 -9
  88. ultralytics/models/rtdetr/predict.py +22 -17
  89. ultralytics/models/rtdetr/train.py +20 -16
  90. ultralytics/models/rtdetr/val.py +79 -59
  91. ultralytics/models/sam/__init__.py +8 -2
  92. ultralytics/models/sam/amg.py +53 -38
  93. ultralytics/models/sam/build.py +29 -31
  94. ultralytics/models/sam/model.py +33 -38
  95. ultralytics/models/sam/modules/blocks.py +159 -182
  96. ultralytics/models/sam/modules/decoders.py +38 -47
  97. ultralytics/models/sam/modules/encoders.py +114 -133
  98. ultralytics/models/sam/modules/memory_attention.py +38 -31
  99. ultralytics/models/sam/modules/sam.py +114 -93
  100. ultralytics/models/sam/modules/tiny_encoder.py +268 -291
  101. ultralytics/models/sam/modules/transformer.py +59 -66
  102. ultralytics/models/sam/modules/utils.py +55 -72
  103. ultralytics/models/sam/predict.py +745 -341
  104. ultralytics/models/utils/loss.py +118 -107
  105. ultralytics/models/utils/ops.py +118 -71
  106. ultralytics/models/yolo/__init__.py +1 -1
  107. ultralytics/models/yolo/classify/predict.py +28 -26
  108. ultralytics/models/yolo/classify/train.py +50 -81
  109. ultralytics/models/yolo/classify/val.py +68 -61
  110. ultralytics/models/yolo/detect/predict.py +12 -15
  111. ultralytics/models/yolo/detect/train.py +56 -46
  112. ultralytics/models/yolo/detect/val.py +279 -223
  113. ultralytics/models/yolo/model.py +167 -86
  114. ultralytics/models/yolo/obb/predict.py +7 -11
  115. ultralytics/models/yolo/obb/train.py +23 -25
  116. ultralytics/models/yolo/obb/val.py +107 -99
  117. ultralytics/models/yolo/pose/__init__.py +1 -1
  118. ultralytics/models/yolo/pose/predict.py +12 -14
  119. ultralytics/models/yolo/pose/train.py +31 -69
  120. ultralytics/models/yolo/pose/val.py +119 -254
  121. ultralytics/models/yolo/segment/predict.py +21 -25
  122. ultralytics/models/yolo/segment/train.py +12 -66
  123. ultralytics/models/yolo/segment/val.py +126 -305
  124. ultralytics/models/yolo/world/train.py +53 -45
  125. ultralytics/models/yolo/world/train_world.py +51 -32
  126. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  127. ultralytics/models/yolo/yoloe/predict.py +30 -37
  128. ultralytics/models/yolo/yoloe/train.py +89 -71
  129. ultralytics/models/yolo/yoloe/train_seg.py +15 -17
  130. ultralytics/models/yolo/yoloe/val.py +56 -41
  131. ultralytics/nn/__init__.py +9 -11
  132. ultralytics/nn/autobackend.py +179 -107
  133. ultralytics/nn/modules/__init__.py +67 -67
  134. ultralytics/nn/modules/activation.py +8 -7
  135. ultralytics/nn/modules/block.py +302 -323
  136. ultralytics/nn/modules/conv.py +61 -104
  137. ultralytics/nn/modules/head.py +488 -186
  138. ultralytics/nn/modules/transformer.py +183 -123
  139. ultralytics/nn/modules/utils.py +15 -20
  140. ultralytics/nn/tasks.py +327 -203
  141. ultralytics/nn/text_model.py +81 -65
  142. ultralytics/py.typed +1 -0
  143. ultralytics/solutions/__init__.py +12 -12
  144. ultralytics/solutions/ai_gym.py +19 -27
  145. ultralytics/solutions/analytics.py +36 -26
  146. ultralytics/solutions/config.py +29 -28
  147. ultralytics/solutions/distance_calculation.py +23 -24
  148. ultralytics/solutions/heatmap.py +17 -19
  149. ultralytics/solutions/instance_segmentation.py +21 -19
  150. ultralytics/solutions/object_blurrer.py +16 -17
  151. ultralytics/solutions/object_counter.py +48 -53
  152. ultralytics/solutions/object_cropper.py +22 -16
  153. ultralytics/solutions/parking_management.py +61 -58
  154. ultralytics/solutions/queue_management.py +19 -19
  155. ultralytics/solutions/region_counter.py +63 -50
  156. ultralytics/solutions/security_alarm.py +22 -25
  157. ultralytics/solutions/similarity_search.py +107 -60
  158. ultralytics/solutions/solutions.py +343 -262
  159. ultralytics/solutions/speed_estimation.py +35 -31
  160. ultralytics/solutions/streamlit_inference.py +104 -40
  161. ultralytics/solutions/templates/similarity-search.html +31 -24
  162. ultralytics/solutions/trackzone.py +24 -24
  163. ultralytics/solutions/vision_eye.py +11 -12
  164. ultralytics/trackers/__init__.py +1 -1
  165. ultralytics/trackers/basetrack.py +18 -27
  166. ultralytics/trackers/bot_sort.py +48 -39
  167. ultralytics/trackers/byte_tracker.py +94 -94
  168. ultralytics/trackers/track.py +7 -16
  169. ultralytics/trackers/utils/gmc.py +37 -69
  170. ultralytics/trackers/utils/kalman_filter.py +68 -76
  171. ultralytics/trackers/utils/matching.py +13 -17
  172. ultralytics/utils/__init__.py +251 -275
  173. ultralytics/utils/autobatch.py +19 -7
  174. ultralytics/utils/autodevice.py +68 -38
  175. ultralytics/utils/benchmarks.py +169 -130
  176. ultralytics/utils/callbacks/base.py +12 -13
  177. ultralytics/utils/callbacks/clearml.py +14 -15
  178. ultralytics/utils/callbacks/comet.py +139 -66
  179. ultralytics/utils/callbacks/dvc.py +19 -27
  180. ultralytics/utils/callbacks/hub.py +8 -6
  181. ultralytics/utils/callbacks/mlflow.py +6 -10
  182. ultralytics/utils/callbacks/neptune.py +11 -19
  183. ultralytics/utils/callbacks/platform.py +73 -0
  184. ultralytics/utils/callbacks/raytune.py +3 -4
  185. ultralytics/utils/callbacks/tensorboard.py +9 -12
  186. ultralytics/utils/callbacks/wb.py +33 -30
  187. ultralytics/utils/checks.py +163 -114
  188. ultralytics/utils/cpu.py +89 -0
  189. ultralytics/utils/dist.py +24 -20
  190. ultralytics/utils/downloads.py +176 -146
  191. ultralytics/utils/errors.py +11 -13
  192. ultralytics/utils/events.py +113 -0
  193. ultralytics/utils/export/__init__.py +7 -0
  194. ultralytics/utils/{export.py → export/engine.py} +81 -63
  195. ultralytics/utils/export/imx.py +294 -0
  196. ultralytics/utils/export/tensorflow.py +217 -0
  197. ultralytics/utils/files.py +33 -36
  198. ultralytics/utils/git.py +137 -0
  199. ultralytics/utils/instance.py +105 -120
  200. ultralytics/utils/logger.py +404 -0
  201. ultralytics/utils/loss.py +99 -61
  202. ultralytics/utils/metrics.py +649 -478
  203. ultralytics/utils/nms.py +337 -0
  204. ultralytics/utils/ops.py +263 -451
  205. ultralytics/utils/patches.py +70 -31
  206. ultralytics/utils/plotting.py +253 -223
  207. ultralytics/utils/tal.py +48 -61
  208. ultralytics/utils/torch_utils.py +244 -251
  209. ultralytics/utils/tqdm.py +438 -0
  210. ultralytics/utils/triton.py +22 -23
  211. ultralytics/utils/tuner.py +11 -10
  212. dgenerate_ultralytics_headless-8.3.137.dist-info/RECORD +0 -272
  213. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/entry_points.txt +0 -0
  214. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/licenses/LICENSE +0 -0
  215. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/top_level.txt +0 -0
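A recurring pattern in the hunks below is the modernization of type annotations: `typing.List`/`Tuple`/`Dict` imports are replaced by built-in generics (`list[int]`, `tuple[int, ...]`) and `X | None` unions, enabled by adding `from __future__ import annotations` at the top of each module. A minimal sketch of the convention, with names that are illustrative rather than taken from the package:

```python
# Sketch of the annotation style adopted in 8.3.224 (function and argument names are
# illustrative, not from the package). `from __future__ import annotations` defers
# annotation evaluation, so built-in generics and `X | Y` unions work on older Python 3.
from __future__ import annotations


def clip_xyxy(boxes: list[list[float]], shape: tuple[int, int] | None = None) -> list[list[float]]:
    """Clip xyxy boxes to an optional (height, width) image shape."""
    if shape is None:
        return boxes
    h, w = shape
    return [
        [min(max(x1, 0), w), min(max(y1, 0), h), min(max(x2, 0), w), min(max(y2, 0), h)]
        for x1, y1, x2, y2 in boxes
    ]
```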
@@ -1,5 +1,10 @@
  # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

+ from __future__ import annotations
+
+ from pathlib import Path
+ from typing import Any
+
  import torch

  from ultralytics.data import YOLODataset
@@ -11,48 +16,61 @@ __all__ = ("RTDETRValidator",) # tuple or list


  class RTDETRDataset(YOLODataset):
- """
- Real-Time DEtection and TRacking (RT-DETR) dataset class extending the base YOLODataset class.
+ """Real-Time DEtection and TRacking (RT-DETR) dataset class extending the base YOLODataset class.

  This specialized dataset class is designed for use with the RT-DETR object detection model and is optimized for
  real-time detection and tracking tasks.
+
+ Attributes:
+ augment (bool): Whether to apply data augmentation.
+ rect (bool): Whether to use rectangular training.
+ use_segments (bool): Whether to use segmentation masks.
+ use_keypoints (bool): Whether to use keypoint annotations.
+ imgsz (int): Target image size for training.
+
+ Methods:
+ load_image: Load one image from dataset index.
+ build_transforms: Build transformation pipeline for the dataset.
+
+ Examples:
+ Initialize an RT-DETR dataset
+ >>> dataset = RTDETRDataset(img_path="path/to/images", imgsz=640)
+ >>> image, hw = dataset.load_image(0)
  """

  def __init__(self, *args, data=None, **kwargs):
- """
- Initialize the RTDETRDataset class by inheriting from the YOLODataset class.
+ """Initialize the RTDETRDataset class by inheriting from the YOLODataset class.

  This constructor sets up a dataset specifically optimized for the RT-DETR (Real-Time DEtection and TRacking)
  model, building upon the base YOLODataset functionality.

  Args:
  *args (Any): Variable length argument list passed to the parent YOLODataset class.
- data (Dict | None): Dictionary containing dataset information. If None, default values will be used.
+ data (dict | None): Dictionary containing dataset information. If None, default values will be used.
  **kwargs (Any): Additional keyword arguments passed to the parent YOLODataset class.
  """
  super().__init__(*args, data=data, **kwargs)

  def load_image(self, i, rect_mode=False):
- """
- Load one image from dataset index 'i'.
+ """Load one image from dataset index 'i'.

  Args:
  i (int): Index of the image to load.
  rect_mode (bool, optional): Whether to use rectangular mode for batch inference.

  Returns:
- im (numpy.ndarray): The loaded image.
+ im (torch.Tensor): The loaded image.
  resized_hw (tuple): Height and width of the resized image with shape (2,).

  Examples:
- >>> dataset = RTDETRDataset(...)
+ Load an image from the dataset
+ >>> dataset = RTDETRDataset(img_path="path/to/images")
  >>> image, hw = dataset.load_image(0)
  """
  return super().load_image(i=i, rect_mode=rect_mode)

  def build_transforms(self, hyp=None):
- """
- Build transformation pipeline for the dataset.
+ """Build transformation pipeline for the dataset.

  Args:
  hyp (dict, optional): Hyperparameters for transformations.
@@ -67,7 +85,7 @@ class RTDETRDataset(YOLODataset):
  transforms = v8_transforms(self, self.imgsz, hyp, stretch=True)
  else:
  # transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), auto=False, scale_fill=True)])
- transforms = Compose([])
+ transforms = Compose([lambda x: {**x, **{"ratio_pad": [x["ratio_pad"], [0, 0]]}}])
  transforms.append(
  Format(
  bbox_format="xywh",
@@ -83,30 +101,38 @@ class RTDETRDataset(YOLODataset):


  class RTDETRValidator(DetectionValidator):
- """
- RTDETRValidator extends the DetectionValidator class to provide validation capabilities specifically tailored for
+ """RTDETRValidator extends the DetectionValidator class to provide validation capabilities specifically tailored for
  the RT-DETR (Real-Time DETR) object detection model.

  The class allows building of an RTDETR-specific dataset for validation, applies Non-maximum suppression for
  post-processing, and updates evaluation metrics accordingly.

+ Attributes:
+ args (Namespace): Configuration arguments for validation.
+ data (dict): Dataset configuration dictionary.
+
+ Methods:
+ build_dataset: Build an RTDETR Dataset for validation.
+ postprocess: Apply Non-maximum suppression to prediction outputs.
+
  Examples:
+ Initialize and run RT-DETR validation
  >>> from ultralytics.models.rtdetr import RTDETRValidator
  >>> args = dict(model="rtdetr-l.pt", data="coco8.yaml")
  >>> validator = RTDETRValidator(args=args)
  >>> validator()

- Note:
+ Notes:
  For further details on the attributes and methods, refer to the parent DetectionValidator class.
  """

  def build_dataset(self, img_path, mode="val", batch=None):
- """
- Build an RTDETR Dataset.
+ """Build an RTDETR Dataset.

  Args:
  img_path (str): Path to the folder containing images.
- mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
+ mode (str, optional): `train` mode or `val` mode, users are able to customize different augmentations for
+ each mode.
  batch (int, optional): Size of batches, this is for `rect`.

  Returns:
@@ -124,15 +150,21 @@ class RTDETRValidator(DetectionValidator):
  data=self.data,
  )

- def postprocess(self, preds):
- """
- Apply Non-maximum suppression to prediction outputs.
+ def postprocess(
+ self, preds: torch.Tensor | list[torch.Tensor] | tuple[torch.Tensor]
+ ) -> list[dict[str, torch.Tensor]]:
+ """Apply Non-maximum suppression to prediction outputs.

  Args:
- preds (List | Tuple | torch.Tensor): Raw predictions from the model.
+ preds (torch.Tensor | list | tuple): Raw predictions from the model. If tensor, should have shape
+ (batch_size, num_predictions, num_classes + 4) where last dimension contains bbox coords and
+ class scores.

  Returns:
- (List[torch.Tensor]): List of processed predictions for each image in batch.
+ (list[dict[str, torch.Tensor]]): List of dictionaries for each image, each containing:
+ - 'bboxes': Tensor of shape (N, 4) with bounding box coordinates
+ - 'conf': Tensor of shape (N,) with confidence scores
+ - 'cls': Tensor of shape (N,) with class indices
  """
  if not isinstance(preds, (list, tuple)): # list for PyTorch inference but list[0] Tensor for export inference
  preds = [preds, None]
@@ -149,43 +181,31 @@ class RTDETRValidator(DetectionValidator):
  pred = pred[score.argsort(descending=True)]
  outputs[i] = pred[score > self.args.conf]

- return outputs
-
- def _prepare_batch(self, si, batch):
- """
- Prepares a batch for validation by applying necessary transformations.
-
- Args:
- si (int): Batch index.
- batch (dict): Batch data containing images and annotations.
+ return [{"bboxes": x[:, :4], "conf": x[:, 4], "cls": x[:, 5]} for x in outputs]

- Returns:
- (dict): Prepared batch with transformed annotations.
- """
- idx = batch["batch_idx"] == si
- cls = batch["cls"][idx].squeeze(-1)
- bbox = batch["bboxes"][idx]
- ori_shape = batch["ori_shape"][si]
- imgsz = batch["img"].shape[2:]
- ratio_pad = batch["ratio_pad"][si]
- if len(cls):
- bbox = ops.xywh2xyxy(bbox) # target boxes
- bbox[..., [0, 2]] *= ori_shape[1] # native-space pred
- bbox[..., [1, 3]] *= ori_shape[0] # native-space pred
- return {"cls": cls, "bbox": bbox, "ori_shape": ori_shape, "imgsz": imgsz, "ratio_pad": ratio_pad}
-
- def _prepare_pred(self, pred, pbatch):
- """
- Prepares predictions by scaling bounding boxes to original image dimensions.
+ def pred_to_json(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> None:
+ """Serialize YOLO predictions to COCO json format.

  Args:
- pred (torch.Tensor): Raw predictions.
- pbatch (dict): Prepared batch information.
-
- Returns:
- (torch.Tensor): Predictions scaled to original image dimensions.
+ predn (dict[str, torch.Tensor]): Predictions dictionary containing 'bboxes', 'conf', and 'cls' keys with
+ bounding box coordinates, confidence scores, and class predictions.
+ pbatch (dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.
  """
- predn = pred.clone()
- predn[..., [0, 2]] *= pbatch["ori_shape"][1] / self.args.imgsz # native-space pred
- predn[..., [1, 3]] *= pbatch["ori_shape"][0] / self.args.imgsz # native-space pred
- return predn.float()
+ path = Path(pbatch["im_file"])
+ stem = path.stem
+ image_id = int(stem) if stem.isnumeric() else stem
+ box = predn["bboxes"].clone()
+ box[..., [0, 2]] *= pbatch["ori_shape"][1] / self.args.imgsz # native-space pred
+ box[..., [1, 3]] *= pbatch["ori_shape"][0] / self.args.imgsz # native-space pred
+ box = ops.xyxy2xywh(box) # xywh
+ box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner
+ for b, s, c in zip(box.tolist(), predn["conf"].tolist(), predn["cls"].tolist()):
+ self.jdict.append(
+ {
+ "image_id": image_id,
+ "file_name": path.name,
+ "category_id": self.class_map[int(c)],
+ "bbox": [round(x, 3) for x in b],
+ "score": round(s, 5),
+ }
+ )
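The hunks above (apparently ultralytics/models/rtdetr/val.py in the file list) change RTDETRValidator.postprocess to return one dictionary per image instead of a list of raw tensors. A minimal sketch of consuming that output shape; the tensors below are fabricated stand-ins, not real model predictions:

```python
# Each postprocess result is now a dict with 'bboxes' (N, 4), 'conf' (N,) and 'cls' (N,)
# tensors, mirroring the return statement in the hunk above. Values are illustrative.
import torch

results = [
    {
        "bboxes": torch.tensor([[10.0, 20.0, 110.0, 220.0], [5.0, 5.0, 50.0, 60.0]]),
        "conf": torch.tensor([0.87, 0.31]),
        "cls": torch.tensor([0.0, 2.0]),
    }
]

for i, r in enumerate(results):
    keep = r["conf"] > 0.5  # confidence filter, analogous to self.args.conf in the validator
    print(f"image {i}: kept {int(keep.sum())} of {len(r['conf'])} detections")
    print(r["bboxes"][keep].tolist(), r["cls"][keep].tolist())
```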
@@ -1,6 +1,12 @@
  # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

  from .model import SAM
- from .predict import Predictor, SAM2Predictor, SAM2VideoPredictor
+ from .predict import Predictor, SAM2DynamicInteractivePredictor, SAM2Predictor, SAM2VideoPredictor

- __all__ = "SAM", "Predictor", "SAM2Predictor", "SAM2VideoPredictor" # tuple or list of exportable items
+ __all__ = (
+ "SAM",
+ "Predictor",
+ "SAM2DynamicInteractivePredictor",
+ "SAM2Predictor",
+ "SAM2VideoPredictor",
+ ) # tuple or list of exportable items
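The expanded export list above adds SAM2DynamicInteractivePredictor to the sam package's public API. A minimal import check against the new `__all__`; no weights are downloaded or loaded here:

```python
# Verify the new predictor is exported from ultralytics.models.sam in 8.3.224, as declared
# in the __all__ tuple in the hunk above.
from ultralytics.models.sam import (
    SAM,
    Predictor,
    SAM2DynamicInteractivePredictor,
    SAM2Predictor,
    SAM2VideoPredictor,
)

print(SAM2DynamicInteractivePredictor.__mro__[:2])  # the new class and its immediate base
```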
@@ -1,17 +1,36 @@
  # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

+ from __future__ import annotations
+
  import math
+ from collections.abc import Generator
  from itertools import product
- from typing import Any, Generator, List, Tuple
+ from typing import Any

  import numpy as np
  import torch


  def is_box_near_crop_edge(
- boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0
+ boxes: torch.Tensor, crop_box: list[int], orig_box: list[int], atol: float = 20.0
  ) -> torch.Tensor:
- """Determines if bounding boxes are near the edge of a cropped image region using a specified tolerance."""
+ """Determine if bounding boxes are near the edge of a cropped image region using a specified tolerance.
+
+ Args:
+ boxes (torch.Tensor): Bounding boxes in XYXY format.
+ crop_box (list[int]): Crop box coordinates in [x0, y0, x1, y1] format.
+ orig_box (list[int]): Original image box coordinates in [x0, y0, x1, y1] format.
+ atol (float, optional): Absolute tolerance for edge proximity detection.
+
+ Returns:
+ (torch.Tensor): Boolean tensor indicating which boxes are near crop edges.
+
+ Examples:
+ >>> boxes = torch.tensor([[10, 10, 50, 50], [100, 100, 150, 150]])
+ >>> crop_box = [0, 0, 200, 200]
+ >>> orig_box = [0, 0, 300, 300]
+ >>> near_edge = is_box_near_crop_edge(boxes, crop_box, orig_box, atol=20.0)
+ """
  crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
  orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
  boxes = uncrop_boxes_xyxy(boxes, crop_box).float()
@@ -21,9 +40,8 @@ def is_box_near_crop_edge(
  return torch.any(near_crop_edge, dim=1)


- def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
- """
- Yield batches of data from input arguments with specified batch size for efficient processing.
+ def batch_iterator(batch_size: int, *args) -> Generator[list[Any]]:
+ """Yield batches of data from input arguments with specified batch size for efficient processing.

  This function takes a batch size and any number of iterables, then yields batches of elements from those
  iterables. All input iterables must have the same length.
@@ -33,7 +51,7 @@ def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
  *args (Any): Variable length input iterables to batch. All iterables must have the same length.

  Yields:
- (List[Any]): A list of batched elements from each input iterable.
+ (list[Any]): A list of batched elements from each input iterable.

  Examples:
  >>> data = [1, 2, 3, 4, 5]
@@ -51,11 +69,10 @@


  def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, threshold_offset: float) -> torch.Tensor:
- """
- Computes the stability score for a batch of masks.
+ """Compute the stability score for a batch of masks.

- The stability score is the IoU between binary masks obtained by thresholding the predicted mask logits at
- high and low values.
+ The stability score is the IoU between binary masks obtained by thresholding the predicted mask logits at high and
+ low values.

  Args:
  masks (torch.Tensor): Batch of predicted mask logits.
@@ -65,15 +82,15 @@ def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, thresh
  Returns:
  (torch.Tensor): Stability scores for each mask in the batch.

- Notes:
- - One mask is always contained inside the other.
- - Memory is saved by preventing unnecessary cast to torch.int64.
-
  Examples:
  >>> masks = torch.rand(10, 256, 256) # Batch of 10 masks
  >>> mask_threshold = 0.5
  >>> threshold_offset = 0.1
  >>> stability_scores = calculate_stability_score(masks, mask_threshold, threshold_offset)
+
+ Notes:
+ - One mask is always contained inside the other.
+ - Memory is saved by preventing unnecessary cast to torch.int64.
  """
  intersections = (masks > (mask_threshold + threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
  unions = (masks > (mask_threshold - threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
@@ -89,25 +106,24 @@ def build_point_grid(n_per_side: int) -> np.ndarray:
  return np.stack([points_x, points_y], axis=-1).reshape(-1, 2)


- def build_all_layer_point_grids(n_per_side: int, n_layers: int, scale_per_layer: int) -> List[np.ndarray]:
- """Generates point grids for multiple crop layers with varying scales and densities."""
+ def build_all_layer_point_grids(n_per_side: int, n_layers: int, scale_per_layer: int) -> list[np.ndarray]:
+ """Generate point grids for multiple crop layers with varying scales and densities."""
  return [build_point_grid(int(n_per_side / (scale_per_layer**i))) for i in range(n_layers + 1)]


  def generate_crop_boxes(
- im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float
- ) -> Tuple[List[List[int]], List[int]]:
- """
- Generates crop boxes of varying sizes for multiscale image processing, with layered overlapping regions.
+ im_size: tuple[int, ...], n_layers: int, overlap_ratio: float
+ ) -> tuple[list[list[int]], list[int]]:
+ """Generate crop boxes of varying sizes for multiscale image processing, with layered overlapping regions.

  Args:
- im_size (Tuple[int, ...]): Height and width of the input image.
+ im_size (tuple[int, ...]): Height and width of the input image.
  n_layers (int): Number of layers to generate crop boxes for.
  overlap_ratio (float): Ratio of overlap between adjacent crop boxes.

  Returns:
- (List[List[int]]): List of crop boxes in [x0, y0, x1, y1] format.
- (List[int]): List of layer indices corresponding to each crop box.
+ crop_boxes (list[list[int]]): List of crop boxes in [x0, y0, x1, y1] format.
+ layer_idxs (list[int]): List of layer indices corresponding to each crop box.

  Examples:
  >>> im_size = (800, 1200) # Height, width
@@ -124,8 +140,8 @@ def generate_crop_boxes(
  layer_idxs.append(0)

  def crop_len(orig_len, n_crops, overlap):
- """Calculates the length of each crop given the original length, number of crops, and overlap."""
- return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops))
+ """Calculate the length of each crop given the original length, number of crops, and overlap."""
+ return math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops)

  for i_layer in range(n_layers):
  n_crops_per_side = 2 ** (i_layer + 1)
@@ -146,7 +162,7 @@ def generate_crop_boxes(
  return crop_boxes, layer_idxs


- def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
+ def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: list[int]) -> torch.Tensor:
  """Uncrop bounding boxes by adding the crop box offset to their coordinates."""
  x0, y0, _, _ = crop_box
  offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device)
@@ -156,7 +172,7 @@ def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
  return boxes + offset


- def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
+ def uncrop_points(points: torch.Tensor, crop_box: list[int]) -> torch.Tensor:
  """Uncrop points by adding the crop box offset to their coordinates."""
  x0, y0, _, _ = crop_box
  offset = torch.tensor([[x0, y0]], device=points.device)
@@ -166,7 +182,7 @@ def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
  return points + offset


- def uncrop_masks(masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int) -> torch.Tensor:
+ def uncrop_masks(masks: torch.Tensor, crop_box: list[int], orig_h: int, orig_w: int) -> torch.Tensor:
  """Uncrop masks by padding them to the original image size, handling coordinate transformations."""
  x0, y0, x1, y1 = crop_box
  if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h:
@@ -177,18 +193,18 @@ def uncrop_masks(masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w:
  return torch.nn.functional.pad(masks, pad, value=0)


- def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> Tuple[np.ndarray, bool]:
- """
- Removes small disconnected regions or holes in a mask based on area threshold and mode.
+ def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> tuple[np.ndarray, bool]:
+ """Remove small disconnected regions or holes in a mask based on area threshold and mode.

  Args:
  mask (np.ndarray): Binary mask to process.
  area_thresh (float): Area threshold below which regions will be removed.
- mode (str): Processing mode, either 'holes' to fill small holes or 'islands' to remove small disconnected regions.
+ mode (str): Processing mode, either 'holes' to fill small holes or 'islands' to remove small disconnected
+ regions.

  Returns:
- (np.ndarray): Processed binary mask with small regions removed.
- (bool): Whether any regions were modified.
+ processed_mask (np.ndarray): Processed binary mask with small regions removed.
+ modified (bool): Whether any regions were modified.

  Examples:
  >>> mask = np.zeros((100, 100), dtype=np.bool_)
@@ -206,7 +222,7 @@ def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> Tup
  small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
  if not small_regions:
  return mask, False
- fill_labels = [0] + small_regions
+ fill_labels = [0, *small_regions]
  if not correct_holes:
  # If every region is below threshold, keep largest
  fill_labels = [i for i in range(n_labels) if i not in fill_labels] or [int(np.argmax(sizes)) + 1]
@@ -215,8 +231,7 @@ def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> Tup


  def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
- """
- Calculates bounding boxes in XYXY format around binary masks.
+ """Calculate bounding boxes in XYXY format around binary masks.

  Args:
  masks (torch.Tensor): Binary masks with shape (B, H, W) or (B, C, H, W).
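The updated `batch_iterator` signature above still yields one list of batch slices per input iterable. A small usage sketch consistent with the docstring's `data = [1, 2, 3, 4, 5]` example; the second iterable is added here purely for illustration:

```python
# batch_iterator yields, per step, a list with one slice per input iterable; all inputs
# must have equal length, as the docstring above states.
from ultralytics.models.sam.amg import batch_iterator

data = [1, 2, 3, 4, 5]
labels = ["a", "b", "c", "d", "e"]  # illustrative second iterable

for batch in batch_iterator(2, data, labels):
    print(batch)  # e.g. first iteration -> [[1, 2], ['a', 'b']]
```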
@@ -11,6 +11,7 @@ from functools import partial
  import torch

  from ultralytics.utils.downloads import attempt_download_asset
+ from ultralytics.utils.torch_utils import TORCH_1_13

  from .modules.decoders import MaskDecoder
  from .modules.encoders import FpnNeck, Hiera, ImageEncoder, ImageEncoderViT, MemoryEncoder, PromptEncoder
@@ -21,7 +22,7 @@ from .modules.transformer import TwoWayTransformer


  def build_sam_vit_h(checkpoint=None):
- """Builds and returns a Segment Anything Model (SAM) h-size model with specified encoder parameters."""
+ """Build and return a Segment Anything Model (SAM) h-size model with specified encoder parameters."""
  return _build_sam(
  encoder_embed_dim=1280,
  encoder_depth=32,
@@ -32,7 +33,7 @@


  def build_sam_vit_l(checkpoint=None):
- """Builds and returns a Segment Anything Model (SAM) l-size model with specified encoder parameters."""
+ """Build and return a Segment Anything Model (SAM) l-size model with specified encoder parameters."""
  return _build_sam(
  encoder_embed_dim=1024,
  encoder_depth=24,
@@ -43,7 +44,7 @@


  def build_sam_vit_b(checkpoint=None):
- """Constructs and returns a Segment Anything Model (SAM) with b-size architecture and optional checkpoint."""
+ """Build and return a Segment Anything Model (SAM) b-size model with specified encoder parameters."""
  return _build_sam(
  encoder_embed_dim=768,
  encoder_depth=12,
@@ -54,7 +55,7 @@


  def build_mobile_sam(checkpoint=None):
- """Builds and returns a Mobile Segment Anything Model (Mobile-SAM) for efficient image segmentation."""
+ """Build and return a Mobile Segment Anything Model (Mobile-SAM) for efficient image segmentation."""
  return _build_sam(
  encoder_embed_dim=[64, 128, 160, 320],
  encoder_depth=[2, 2, 6, 2],
@@ -66,7 +67,7 @@


  def build_sam2_t(checkpoint=None):
- """Builds and returns a Segment Anything Model 2 (SAM2) tiny-size model with specified architecture parameters."""
+ """Build and return a Segment Anything Model 2 (SAM2) tiny-size model with specified architecture parameters."""
  return _build_sam2(
  encoder_embed_dim=96,
  encoder_stages=[1, 2, 7, 2],
@@ -79,7 +80,7 @@


  def build_sam2_s(checkpoint=None):
- """Builds and returns a small-size Segment Anything Model (SAM2) with specified architecture parameters."""
+ """Build and return a small-size Segment Anything Model 2 (SAM2) with specified architecture parameters."""
  return _build_sam2(
  encoder_embed_dim=96,
  encoder_stages=[1, 2, 11, 2],
@@ -92,7 +93,7 @@


  def build_sam2_b(checkpoint=None):
- """Builds and returns a SAM2 base-size model with specified architecture parameters."""
+ """Build and return a Segment Anything Model 2 (SAM2) base-size model with specified architecture parameters."""
  return _build_sam2(
  encoder_embed_dim=112,
  encoder_stages=[2, 3, 16, 3],
@@ -106,7 +107,7 @@


  def build_sam2_l(checkpoint=None):
- """Builds and returns a large-size Segment Anything Model (SAM2) with specified architecture parameters."""
+ """Build and return a large-size Segment Anything Model 2 (SAM2) with specified architecture parameters."""
  return _build_sam2(
  encoder_embed_dim=144,
  encoder_stages=[2, 6, 36, 4],
@@ -126,16 +127,15 @@ def _build_sam(
  checkpoint=None,
  mobile_sam=False,
  ):
- """
- Builds a Segment Anything Model (SAM) with specified encoder parameters.
+ """Build a Segment Anything Model (SAM) with specified encoder parameters.

  Args:
- encoder_embed_dim (int | List[int]): Embedding dimension for the encoder.
- encoder_depth (int | List[int]): Depth of the encoder.
- encoder_num_heads (int | List[int]): Number of attention heads in the encoder.
- encoder_global_attn_indexes (List[int] | None): Indexes for global attention in the encoder.
- checkpoint (str | None): Path to the model checkpoint file.
- mobile_sam (bool): Whether to build a Mobile-SAM model.
+ encoder_embed_dim (int | list[int]): Embedding dimension for the encoder.
+ encoder_depth (int | list[int]): Depth of the encoder.
+ encoder_num_heads (int | list[int]): Number of attention heads in the encoder.
+ encoder_global_attn_indexes (list[int] | None): Indexes for global attention in the encoder.
+ checkpoint (str | None, optional): Path to the model checkpoint file.
+ mobile_sam (bool, optional): Whether to build a Mobile-SAM model.

  Returns:
  (SAMModel): A Segment Anything Model instance with the specified architecture.
@@ -207,7 +207,7 @@ def _build_sam(
  if checkpoint is not None:
  checkpoint = attempt_download_asset(checkpoint)
  with open(checkpoint, "rb") as f:
- state_dict = torch.load(f)
+ state_dict = torch.load(f, weights_only=False) if TORCH_1_13 else torch.load(f)
  sam.load_state_dict(state_dict)
  sam.eval()
  return sam
@@ -223,18 +223,17 @@ def _build_sam2(
  encoder_window_spec=[8, 4, 16, 8],
  checkpoint=None,
  ):
- """
- Builds and returns a Segment Anything Model 2 (SAM2) with specified architecture parameters.
+ """Build and return a Segment Anything Model 2 (SAM2) with specified architecture parameters.

  Args:
- encoder_embed_dim (int): Embedding dimension for the encoder.
- encoder_stages (List[int]): Number of blocks in each stage of the encoder.
- encoder_num_heads (int): Number of attention heads in the encoder.
- encoder_global_att_blocks (List[int]): Indices of global attention blocks in the encoder.
- encoder_backbone_channel_list (List[int]): Channel dimensions for each level of the encoder backbone.
- encoder_window_spatial_size (List[int]): Spatial size of the window for position embeddings.
- encoder_window_spec (List[int]): Window specifications for each stage of the encoder.
- checkpoint (str | None): Path to the checkpoint file for loading pre-trained weights.
+ encoder_embed_dim (int, optional): Embedding dimension for the encoder.
+ encoder_stages (list[int], optional): Number of blocks in each stage of the encoder.
+ encoder_num_heads (int, optional): Number of attention heads in the encoder.
+ encoder_global_att_blocks (list[int], optional): Indices of global attention blocks in the encoder.
+ encoder_backbone_channel_list (list[int], optional): Channel dimensions for each level of the encoder backbone.
+ encoder_window_spatial_size (list[int], optional): Spatial size of the window for position embeddings.
+ encoder_window_spec (list[int], optional): Window specifications for each stage of the encoder.
+ checkpoint (str | None, optional): Path to the checkpoint file for loading pre-trained weights.

  Returns:
  (SAM2Model): A configured and initialized SAM2 model.
@@ -302,7 +301,7 @@ def _build_sam2(
  if checkpoint is not None:
  checkpoint = attempt_download_asset(checkpoint)
  with open(checkpoint, "rb") as f:
- state_dict = torch.load(f)["model"]
+ state_dict = (torch.load(f, weights_only=False) if TORCH_1_13 else torch.load(f))["model"]
  sam2.load_state_dict(state_dict)
  sam2.eval()
  return sam2
@@ -325,11 +324,10 @@ sam_model_map = {


  def build_sam(ckpt="sam_b.pt"):
- """
- Builds and returns a Segment Anything Model (SAM) based on the provided checkpoint.
+ """Build and return a Segment Anything Model (SAM) based on the provided checkpoint.

  Args:
- ckpt (str | Path): Path to the checkpoint file or name of a pre-defined SAM model.
+ ckpt (str | Path, optional): Path to the checkpoint file or name of a pre-defined SAM model.

  Returns:
  (SAMModel | SAM2Model): A configured and initialized SAM or SAM2 model instance.
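Both checkpoint-loading hunks above gate `weights_only=False` on the new TORCH_1_13 flag imported from ultralytics.utils.torch_utils. A sketch of the same pattern in isolation; the version check and checkpoint path below are illustrative rather than copied from torch_utils:

```python
# torch.load gained a weights_only argument in PyTorch 1.13, and later releases flipped its
# default, so full (non-weights-only) checkpoints must request weights_only=False explicitly.
from pathlib import Path

import torch

TORCH_1_13 = tuple(int(v) for v in torch.__version__.split("+")[0].split(".")[:2]) >= (1, 13)

ckpt = Path("sam_b.pt")  # illustrative checkpoint path
if ckpt.exists():
    with open(ckpt, "rb") as f:
        state_dict = torch.load(f, weights_only=False) if TORCH_1_13 else torch.load(f)
```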