ultralytics 8.1.29__py3-none-any.whl → 8.3.62__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (247)
  1. tests/__init__.py +22 -0
  2. tests/conftest.py +83 -0
  3. tests/test_cli.py +122 -0
  4. tests/test_cuda.py +155 -0
  5. tests/test_engine.py +131 -0
  6. tests/test_exports.py +216 -0
  7. tests/test_integrations.py +150 -0
  8. tests/test_python.py +615 -0
  9. tests/test_solutions.py +94 -0
  10. ultralytics/__init__.py +11 -8
  11. ultralytics/cfg/__init__.py +569 -131
  12. ultralytics/cfg/datasets/Argoverse.yaml +2 -1
  13. ultralytics/cfg/datasets/DOTAv1.5.yaml +3 -2
  14. ultralytics/cfg/datasets/DOTAv1.yaml +3 -2
  15. ultralytics/cfg/datasets/GlobalWheat2020.yaml +3 -2
  16. ultralytics/cfg/datasets/ImageNet.yaml +2 -1
  17. ultralytics/cfg/datasets/Objects365.yaml +5 -4
  18. ultralytics/cfg/datasets/SKU-110K.yaml +2 -1
  19. ultralytics/cfg/datasets/VOC.yaml +3 -2
  20. ultralytics/cfg/datasets/VisDrone.yaml +6 -5
  21. ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
  22. ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
  23. ultralytics/cfg/datasets/carparts-seg.yaml +3 -2
  24. ultralytics/cfg/datasets/coco-pose.yaml +7 -6
  25. ultralytics/cfg/datasets/coco.yaml +3 -2
  26. ultralytics/cfg/datasets/coco128-seg.yaml +4 -3
  27. ultralytics/cfg/datasets/coco128.yaml +4 -3
  28. ultralytics/cfg/datasets/coco8-pose.yaml +3 -2
  29. ultralytics/cfg/datasets/coco8-seg.yaml +3 -2
  30. ultralytics/cfg/datasets/coco8.yaml +3 -2
  31. ultralytics/cfg/datasets/crack-seg.yaml +3 -2
  32. ultralytics/cfg/datasets/dog-pose.yaml +24 -0
  33. ultralytics/cfg/datasets/dota8.yaml +3 -2
  34. ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
  35. ultralytics/cfg/datasets/lvis.yaml +1236 -0
  36. ultralytics/cfg/datasets/medical-pills.yaml +22 -0
  37. ultralytics/cfg/datasets/open-images-v7.yaml +2 -1
  38. ultralytics/cfg/datasets/package-seg.yaml +5 -4
  39. ultralytics/cfg/datasets/signature.yaml +21 -0
  40. ultralytics/cfg/datasets/tiger-pose.yaml +3 -2
  41. ultralytics/cfg/datasets/xView.yaml +2 -1
  42. ultralytics/cfg/default.yaml +14 -11
  43. ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +24 -0
  44. ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
  45. ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
  46. ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
  47. ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
  48. ultralytics/cfg/models/11/yolo11.yaml +50 -0
  49. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +5 -2
  50. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +5 -2
  51. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +5 -2
  52. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +5 -2
  53. ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
  54. ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
  55. ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
  56. ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
  57. ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
  58. ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
  59. ultralytics/cfg/models/v3/yolov3-spp.yaml +5 -2
  60. ultralytics/cfg/models/v3/yolov3-tiny.yaml +5 -2
  61. ultralytics/cfg/models/v3/yolov3.yaml +5 -2
  62. ultralytics/cfg/models/v5/yolov5-p6.yaml +5 -2
  63. ultralytics/cfg/models/v5/yolov5.yaml +5 -2
  64. ultralytics/cfg/models/v6/yolov6.yaml +5 -2
  65. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +5 -2
  66. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +5 -2
  67. ultralytics/cfg/models/v8/yolov8-cls.yaml +5 -2
  68. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +6 -2
  69. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +6 -2
  70. ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -2
  71. ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -2
  72. ultralytics/cfg/models/v8/yolov8-p2.yaml +5 -2
  73. ultralytics/cfg/models/v8/yolov8-p6.yaml +10 -7
  74. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +5 -2
  75. ultralytics/cfg/models/v8/yolov8-pose.yaml +5 -2
  76. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -2
  77. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +5 -2
  78. ultralytics/cfg/models/v8/yolov8-seg.yaml +5 -2
  79. ultralytics/cfg/models/v8/yolov8-world.yaml +5 -2
  80. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -2
  81. ultralytics/cfg/models/v8/yolov8.yaml +5 -2
  82. ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
  83. ultralytics/cfg/models/v9/yolov9c.yaml +30 -25
  84. ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
  85. ultralytics/cfg/models/v9/yolov9e.yaml +46 -42
  86. ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
  87. ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
  88. ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
  89. ultralytics/cfg/solutions/default.yaml +24 -0
  90. ultralytics/cfg/trackers/botsort.yaml +8 -5
  91. ultralytics/cfg/trackers/bytetrack.yaml +8 -5
  92. ultralytics/data/__init__.py +14 -3
  93. ultralytics/data/annotator.py +37 -15
  94. ultralytics/data/augment.py +1783 -289
  95. ultralytics/data/base.py +62 -27
  96. ultralytics/data/build.py +36 -8
  97. ultralytics/data/converter.py +196 -36
  98. ultralytics/data/dataset.py +233 -94
  99. ultralytics/data/loaders.py +199 -96
  100. ultralytics/data/split_dota.py +39 -29
  101. ultralytics/data/utils.py +110 -40
  102. ultralytics/engine/__init__.py +1 -1
  103. ultralytics/engine/exporter.py +569 -242
  104. ultralytics/engine/model.py +604 -252
  105. ultralytics/engine/predictor.py +22 -11
  106. ultralytics/engine/results.py +1228 -218
  107. ultralytics/engine/trainer.py +190 -129
  108. ultralytics/engine/tuner.py +18 -18
  109. ultralytics/engine/validator.py +18 -15
  110. ultralytics/hub/__init__.py +31 -13
  111. ultralytics/hub/auth.py +11 -7
  112. ultralytics/hub/google/__init__.py +159 -0
  113. ultralytics/hub/session.py +128 -94
  114. ultralytics/hub/utils.py +20 -21
  115. ultralytics/models/__init__.py +4 -2
  116. ultralytics/models/fastsam/__init__.py +2 -3
  117. ultralytics/models/fastsam/model.py +26 -4
  118. ultralytics/models/fastsam/predict.py +127 -63
  119. ultralytics/models/fastsam/utils.py +1 -44
  120. ultralytics/models/fastsam/val.py +1 -1
  121. ultralytics/models/nas/__init__.py +1 -1
  122. ultralytics/models/nas/model.py +21 -10
  123. ultralytics/models/nas/predict.py +3 -6
  124. ultralytics/models/nas/val.py +4 -4
  125. ultralytics/models/rtdetr/__init__.py +1 -1
  126. ultralytics/models/rtdetr/model.py +1 -1
  127. ultralytics/models/rtdetr/predict.py +6 -8
  128. ultralytics/models/rtdetr/train.py +6 -2
  129. ultralytics/models/rtdetr/val.py +3 -3
  130. ultralytics/models/sam/__init__.py +3 -3
  131. ultralytics/models/sam/amg.py +29 -23
  132. ultralytics/models/sam/build.py +211 -13
  133. ultralytics/models/sam/model.py +91 -30
  134. ultralytics/models/sam/modules/__init__.py +1 -1
  135. ultralytics/models/sam/modules/blocks.py +1129 -0
  136. ultralytics/models/sam/modules/decoders.py +381 -53
  137. ultralytics/models/sam/modules/encoders.py +515 -324
  138. ultralytics/models/sam/modules/memory_attention.py +237 -0
  139. ultralytics/models/sam/modules/sam.py +969 -21
  140. ultralytics/models/sam/modules/tiny_encoder.py +425 -154
  141. ultralytics/models/sam/modules/transformer.py +159 -60
  142. ultralytics/models/sam/modules/utils.py +293 -0
  143. ultralytics/models/sam/predict.py +1263 -132
  144. ultralytics/models/utils/__init__.py +1 -1
  145. ultralytics/models/utils/loss.py +36 -24
  146. ultralytics/models/utils/ops.py +3 -7
  147. ultralytics/models/yolo/__init__.py +3 -3
  148. ultralytics/models/yolo/classify/__init__.py +1 -1
  149. ultralytics/models/yolo/classify/predict.py +7 -8
  150. ultralytics/models/yolo/classify/train.py +17 -22
  151. ultralytics/models/yolo/classify/val.py +8 -4
  152. ultralytics/models/yolo/detect/__init__.py +1 -1
  153. ultralytics/models/yolo/detect/predict.py +3 -5
  154. ultralytics/models/yolo/detect/train.py +11 -4
  155. ultralytics/models/yolo/detect/val.py +90 -52
  156. ultralytics/models/yolo/model.py +14 -9
  157. ultralytics/models/yolo/obb/__init__.py +1 -1
  158. ultralytics/models/yolo/obb/predict.py +2 -2
  159. ultralytics/models/yolo/obb/train.py +5 -3
  160. ultralytics/models/yolo/obb/val.py +41 -23
  161. ultralytics/models/yolo/pose/__init__.py +1 -1
  162. ultralytics/models/yolo/pose/predict.py +3 -5
  163. ultralytics/models/yolo/pose/train.py +2 -2
  164. ultralytics/models/yolo/pose/val.py +51 -17
  165. ultralytics/models/yolo/segment/__init__.py +1 -1
  166. ultralytics/models/yolo/segment/predict.py +3 -5
  167. ultralytics/models/yolo/segment/train.py +2 -2
  168. ultralytics/models/yolo/segment/val.py +60 -19
  169. ultralytics/models/yolo/world/__init__.py +5 -0
  170. ultralytics/models/yolo/world/train.py +92 -0
  171. ultralytics/models/yolo/world/train_world.py +109 -0
  172. ultralytics/nn/__init__.py +1 -1
  173. ultralytics/nn/autobackend.py +228 -93
  174. ultralytics/nn/modules/__init__.py +39 -14
  175. ultralytics/nn/modules/activation.py +21 -0
  176. ultralytics/nn/modules/block.py +526 -66
  177. ultralytics/nn/modules/conv.py +24 -7
  178. ultralytics/nn/modules/head.py +177 -34
  179. ultralytics/nn/modules/transformer.py +6 -5
  180. ultralytics/nn/modules/utils.py +1 -2
  181. ultralytics/nn/tasks.py +225 -77
  182. ultralytics/solutions/__init__.py +30 -1
  183. ultralytics/solutions/ai_gym.py +96 -143
  184. ultralytics/solutions/analytics.py +247 -0
  185. ultralytics/solutions/distance_calculation.py +78 -135
  186. ultralytics/solutions/heatmap.py +93 -247
  187. ultralytics/solutions/object_counter.py +184 -259
  188. ultralytics/solutions/parking_management.py +246 -0
  189. ultralytics/solutions/queue_management.py +112 -0
  190. ultralytics/solutions/region_counter.py +116 -0
  191. ultralytics/solutions/security_alarm.py +144 -0
  192. ultralytics/solutions/solutions.py +178 -0
  193. ultralytics/solutions/speed_estimation.py +86 -174
  194. ultralytics/solutions/streamlit_inference.py +190 -0
  195. ultralytics/solutions/trackzone.py +68 -0
  196. ultralytics/trackers/__init__.py +1 -1
  197. ultralytics/trackers/basetrack.py +32 -13
  198. ultralytics/trackers/bot_sort.py +61 -28
  199. ultralytics/trackers/byte_tracker.py +83 -51
  200. ultralytics/trackers/track.py +21 -6
  201. ultralytics/trackers/utils/__init__.py +1 -1
  202. ultralytics/trackers/utils/gmc.py +62 -48
  203. ultralytics/trackers/utils/kalman_filter.py +166 -35
  204. ultralytics/trackers/utils/matching.py +40 -21
  205. ultralytics/utils/__init__.py +511 -239
  206. ultralytics/utils/autobatch.py +40 -22
  207. ultralytics/utils/benchmarks.py +266 -85
  208. ultralytics/utils/callbacks/__init__.py +1 -1
  209. ultralytics/utils/callbacks/base.py +1 -3
  210. ultralytics/utils/callbacks/clearml.py +7 -6
  211. ultralytics/utils/callbacks/comet.py +39 -17
  212. ultralytics/utils/callbacks/dvc.py +1 -1
  213. ultralytics/utils/callbacks/hub.py +16 -16
  214. ultralytics/utils/callbacks/mlflow.py +28 -24
  215. ultralytics/utils/callbacks/neptune.py +6 -2
  216. ultralytics/utils/callbacks/raytune.py +3 -4
  217. ultralytics/utils/callbacks/tensorboard.py +18 -18
  218. ultralytics/utils/callbacks/wb.py +27 -20
  219. ultralytics/utils/checks.py +160 -100
  220. ultralytics/utils/dist.py +2 -1
  221. ultralytics/utils/downloads.py +40 -34
  222. ultralytics/utils/errors.py +1 -1
  223. ultralytics/utils/files.py +72 -38
  224. ultralytics/utils/instance.py +41 -19
  225. ultralytics/utils/loss.py +83 -55
  226. ultralytics/utils/metrics.py +61 -56
  227. ultralytics/utils/ops.py +94 -89
  228. ultralytics/utils/patches.py +30 -14
  229. ultralytics/utils/plotting.py +600 -269
  230. ultralytics/utils/tal.py +67 -26
  231. ultralytics/utils/torch_utils.py +302 -102
  232. ultralytics/utils/triton.py +2 -1
  233. ultralytics/utils/tuner.py +21 -12
  234. ultralytics-8.3.62.dist-info/METADATA +370 -0
  235. ultralytics-8.3.62.dist-info/RECORD +241 -0
  236. {ultralytics-8.1.29.dist-info → ultralytics-8.3.62.dist-info}/WHEEL +1 -1
  237. ultralytics/data/explorer/__init__.py +0 -5
  238. ultralytics/data/explorer/explorer.py +0 -472
  239. ultralytics/data/explorer/gui/__init__.py +0 -1
  240. ultralytics/data/explorer/gui/dash.py +0 -268
  241. ultralytics/data/explorer/utils.py +0 -166
  242. ultralytics/models/fastsam/prompt.py +0 -357
  243. ultralytics-8.1.29.dist-info/METADATA +0 -373
  244. ultralytics-8.1.29.dist-info/RECORD +0 -197
  245. {ultralytics-8.1.29.dist-info → ultralytics-8.3.62.dist-info}/LICENSE +0 -0
  246. {ultralytics-8.1.29.dist-info → ultralytics-8.3.62.dist-info}/entry_points.txt +0 -0
  247. {ultralytics-8.1.29.dist-info → ultralytics-8.3.62.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
-# Ultralytics YOLO 🚀, AGPL-3.0 license
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
 import math
 from itertools import product
@@ -11,7 +11,7 @@ import torch
 def is_box_near_crop_edge(
     boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0
 ) -> torch.Tensor:
-    """Return a boolean tensor indicating if boxes are near the crop edge."""
+    """Determines if bounding boxes are near the edge of a cropped image region using a specified tolerance."""
     crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
     orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
     boxes = uncrop_boxes_xyxy(boxes, crop_box).float()
@@ -22,7 +22,7 @@ def is_box_near_crop_edge(
 
 
 def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
-    """Yield batches of data from the input arguments."""
+    """Yields batches of data from input arguments with specified batch size for efficient processing."""
     assert args and all(len(a) == len(args[0]) for a in args), "Batched iteration must have same-size inputs."
     n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
     for b in range(n_batches):
@@ -33,12 +33,26 @@ def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, thresh
     """
     Computes the stability score for a batch of masks.
 
-    The stability score is the IoU between the binary masks obtained by thresholding the predicted mask logits at high
-    and low values.
+    The stability score is the IoU between binary masks obtained by thresholding the predicted mask logits at
+    high and low values.
+
+    Args:
+        masks (torch.Tensor): Batch of predicted mask logits.
+        mask_threshold (float): Threshold value for creating binary masks.
+        threshold_offset (float): Offset applied to the threshold for creating high and low binary masks.
+
+    Returns:
+        (torch.Tensor): Stability scores for each mask in the batch.
 
     Notes:
         - One mask is always contained inside the other.
-        - Save memory by preventing unnecessary cast to torch.int64
+        - Memory is saved by preventing unnecessary cast to torch.int64.
+
+    Examples:
+        >>> masks = torch.rand(10, 256, 256)  # Batch of 10 masks
+        >>> mask_threshold = 0.5
+        >>> threshold_offset = 0.1
+        >>> stability_scores = calculate_stability_score(masks, mask_threshold, threshold_offset)
     """
     intersections = (masks > (mask_threshold + threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
     unions = (masks > (mask_threshold - threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
@@ -46,7 +60,7 @@ def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, thresh
 
 
 def build_point_grid(n_per_side: int) -> np.ndarray:
-    """Generate a 2D grid of evenly spaced points in the range [0,1]x[0,1]."""
+    """Generate a 2D grid of evenly spaced points in the range [0,1]x[0,1] for image segmentation tasks."""
     offset = 1 / (2 * n_per_side)
     points_one_side = np.linspace(offset, 1 - offset, n_per_side)
     points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
@@ -55,18 +69,14 @@ def build_point_grid(n_per_side: int) -> np.ndarray:
 
 
 def build_all_layer_point_grids(n_per_side: int, n_layers: int, scale_per_layer: int) -> List[np.ndarray]:
-    """Generate point grids for all crop layers."""
+    """Generates point grids for multiple crop layers with varying scales and densities."""
     return [build_point_grid(int(n_per_side / (scale_per_layer**i))) for i in range(n_layers + 1)]
 
 
 def generate_crop_boxes(
     im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float
 ) -> Tuple[List[List[int]], List[int]]:
-    """
-    Generates a list of crop boxes of different sizes.
-
-    Each layer has (2**i)**2 boxes for the ith layer.
-    """
+    """Generates crop boxes of varying sizes for multiscale image processing, with layered overlapping regions."""
     crop_boxes, layer_idxs = [], []
     im_h, im_w = im_size
     short_side = min(im_h, im_w)
@@ -99,7 +109,7 @@ def generate_crop_boxes(
 
 
 def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
-    """Uncrop bounding boxes by adding the crop box offset."""
+    """Uncrop bounding boxes by adding the crop box offset to their coordinates."""
     x0, y0, _, _ = crop_box
     offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device)
     # Check if boxes has a channel dimension
@@ -109,7 +119,7 @@ def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
 
 
 def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
-    """Uncrop points by adding the crop box offset."""
+    """Uncrop points by adding the crop box offset to their coordinates."""
     x0, y0, _, _ = crop_box
     offset = torch.tensor([[x0, y0]], device=points.device)
     # Check if points has a channel dimension
@@ -119,7 +129,7 @@ def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
 
 
 def uncrop_masks(masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int) -> torch.Tensor:
-    """Uncrop masks by padding them to the original image size."""
+    """Uncrop masks by padding them to the original image size, handling coordinate transformations."""
     x0, y0, x1, y1 = crop_box
     if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h:
         return masks
@@ -130,10 +140,10 @@ def uncrop_masks(masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w:
 
 
 def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> Tuple[np.ndarray, bool]:
-    """Remove small disconnected regions or holes in a mask, returning the mask and a modification indicator."""
+    """Removes small disconnected regions or holes in a mask based on area threshold and mode."""
     import cv2  # type: ignore
 
-    assert mode in {"holes", "islands"}
+    assert mode in {"holes", "islands"}, f"Provided mode {mode} is invalid"
     correct_holes = mode == "holes"
     working_mask = (correct_holes ^ mask).astype(np.uint8)
     n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
@@ -150,11 +160,7 @@ def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> Tup
 
 
 def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
-    """
-    Calculates boxes in XYXY format around masks.
-
-    Return [0,0,0,0] for an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4.
-    """
+    """Calculates bounding boxes in XYXY format around binary masks, handling empty masks and various input shapes."""
     # torch.max below raises an error on empty inputs, just skip in this case
     if torch.numel(masks) == 0:
         return torch.zeros(*masks.shape[:-2], 4, device=masks.device)
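The expanded docstring above spells out what calculate_stability_score computes: the IoU between the mask binarized at (mask_threshold + threshold_offset) and at (mask_threshold - threshold_offset). A minimal usage sketch follows, assuming random logits and illustrative threshold values (not taken from this diff):

import torch
from ultralytics.models.sam.amg import calculate_stability_score

masks = torch.randn(8, 256, 256)  # hypothetical batch of predicted mask logits
scores = calculate_stability_score(masks, mask_threshold=0.0, threshold_offset=1.0)
print(scores.shape)  # torch.Size([8]); values near 1.0 mean the mask is insensitive to the exact threshold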
@@ -1,4 +1,4 @@
-# Ultralytics YOLO 🚀, AGPL-3.0 license
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
@@ -11,15 +11,17 @@ from functools import partial
 import torch
 
 from ultralytics.utils.downloads import attempt_download_asset
+
 from .modules.decoders import MaskDecoder
-from .modules.encoders import ImageEncoderViT, PromptEncoder
-from .modules.sam import Sam
+from .modules.encoders import FpnNeck, Hiera, ImageEncoder, ImageEncoderViT, MemoryEncoder, PromptEncoder
+from .modules.memory_attention import MemoryAttention, MemoryAttentionLayer
+from .modules.sam import SAM2Model, SAMModel
 from .modules.tiny_encoder import TinyViT
 from .modules.transformer import TwoWayTransformer
 
 
 def build_sam_vit_h(checkpoint=None):
-    """Build and return a Segment Anything Model (SAM) h-size model."""
+    """Builds and returns a Segment Anything Model (SAM) h-size model with specified encoder parameters."""
     return _build_sam(
         encoder_embed_dim=1280,
         encoder_depth=32,
@@ -30,7 +32,7 @@ def build_sam_vit_h(checkpoint=None):
 
 
 def build_sam_vit_l(checkpoint=None):
-    """Build and return a Segment Anything Model (SAM) l-size model."""
+    """Builds and returns a Segment Anything Model (SAM) l-size model with specified encoder parameters."""
     return _build_sam(
         encoder_embed_dim=1024,
         encoder_depth=24,
@@ -41,7 +43,7 @@ def build_sam_vit_l(checkpoint=None):
 
 
 def build_sam_vit_b(checkpoint=None):
-    """Build and return a Segment Anything Model (SAM) b-size model."""
+    """Constructs and returns a Segment Anything Model (SAM) with b-size architecture and optional checkpoint."""
     return _build_sam(
         encoder_embed_dim=768,
         encoder_depth=12,
@@ -52,7 +54,7 @@ def build_sam_vit_b(checkpoint=None):
 
 
 def build_mobile_sam(checkpoint=None):
-    """Build and return Mobile Segment Anything Model (Mobile-SAM)."""
+    """Builds and returns a Mobile Segment Anything Model (Mobile-SAM) for efficient image segmentation."""
     return _build_sam(
         encoder_embed_dim=[64, 128, 160, 320],
         encoder_depth=[2, 2, 6, 2],
@@ -63,10 +65,85 @@ def build_mobile_sam(checkpoint=None):
     )
 
 
+def build_sam2_t(checkpoint=None):
+    """Builds and returns a Segment Anything Model 2 (SAM2) tiny-size model with specified architecture parameters."""
+    return _build_sam2(
+        encoder_embed_dim=96,
+        encoder_stages=[1, 2, 7, 2],
+        encoder_num_heads=1,
+        encoder_global_att_blocks=[5, 7, 9],
+        encoder_window_spec=[8, 4, 14, 7],
+        encoder_backbone_channel_list=[768, 384, 192, 96],
+        checkpoint=checkpoint,
+    )
+
+
+def build_sam2_s(checkpoint=None):
+    """Builds and returns a small-size Segment Anything Model (SAM2) with specified architecture parameters."""
+    return _build_sam2(
+        encoder_embed_dim=96,
+        encoder_stages=[1, 2, 11, 2],
+        encoder_num_heads=1,
+        encoder_global_att_blocks=[7, 10, 13],
+        encoder_window_spec=[8, 4, 14, 7],
+        encoder_backbone_channel_list=[768, 384, 192, 96],
+        checkpoint=checkpoint,
+    )
+
+
+def build_sam2_b(checkpoint=None):
+    """Builds and returns a SAM2 base-size model with specified architecture parameters."""
+    return _build_sam2(
+        encoder_embed_dim=112,
+        encoder_stages=[2, 3, 16, 3],
+        encoder_num_heads=2,
+        encoder_global_att_blocks=[12, 16, 20],
+        encoder_window_spec=[8, 4, 14, 7],
+        encoder_window_spatial_size=[14, 14],
+        encoder_backbone_channel_list=[896, 448, 224, 112],
+        checkpoint=checkpoint,
+    )
+
+
+def build_sam2_l(checkpoint=None):
+    """Builds and returns a large-size Segment Anything Model (SAM2) with specified architecture parameters."""
+    return _build_sam2(
+        encoder_embed_dim=144,
+        encoder_stages=[2, 6, 36, 4],
+        encoder_num_heads=2,
+        encoder_global_att_blocks=[23, 33, 43],
+        encoder_window_spec=[8, 4, 16, 8],
+        encoder_backbone_channel_list=[1152, 576, 288, 144],
+        checkpoint=checkpoint,
+    )
+
+
 def _build_sam(
-    encoder_embed_dim, encoder_depth, encoder_num_heads, encoder_global_attn_indexes, checkpoint=None, mobile_sam=False
+    encoder_embed_dim,
+    encoder_depth,
+    encoder_num_heads,
+    encoder_global_attn_indexes,
+    checkpoint=None,
+    mobile_sam=False,
 ):
-    """Builds the selected SAM model architecture."""
+    """
+    Builds a Segment Anything Model (SAM) with specified encoder parameters.
+
+    Args:
+        encoder_embed_dim (int | List[int]): Embedding dimension for the encoder.
+        encoder_depth (int | List[int]): Depth of the encoder.
+        encoder_num_heads (int | List[int]): Number of attention heads in the encoder.
+        encoder_global_attn_indexes (List[int] | None): Indexes for global attention in the encoder.
+        checkpoint (str | None): Path to the model checkpoint file.
+        mobile_sam (bool): Whether to build a Mobile-SAM model.

+    Returns:
+        (SAMModel): A Segment Anything Model instance with the specified architecture.
+
+    Examples:
+        >>> sam = _build_sam(768, 12, 12, [2, 5, 8, 11])
+        >>> sam = _build_sam([64, 128, 160, 320], [2, 2, 6, 2], [2, 4, 5, 10], None, mobile_sam=True)
+    """
     prompt_embed_dim = 256
     image_size = 1024
     vit_patch_size = 16
@@ -104,7 +181,7 @@ def _build_sam(
             out_chans=prompt_embed_dim,
         )
     )
-    sam = Sam(
+    sam = SAMModel(
         image_encoder=image_encoder,
         prompt_encoder=PromptEncoder(
             embed_dim=prompt_embed_dim,
@@ -133,21 +210,142 @@ def _build_sam(
             state_dict = torch.load(f)
         sam.load_state_dict(state_dict)
     sam.eval()
-    # sam.load_state_dict(torch.load(checkpoint), strict=True)
-    # sam.eval()
     return sam
 
 
+def _build_sam2(
+    encoder_embed_dim=1280,
+    encoder_stages=[2, 6, 36, 4],
+    encoder_num_heads=2,
+    encoder_global_att_blocks=[7, 15, 23, 31],
+    encoder_backbone_channel_list=[1152, 576, 288, 144],
+    encoder_window_spatial_size=[7, 7],
+    encoder_window_spec=[8, 4, 16, 8],
+    checkpoint=None,
+):
+    """
+    Builds and returns a Segment Anything Model 2 (SAM2) with specified architecture parameters.
+
+    Args:
+        encoder_embed_dim (int): Embedding dimension for the encoder.
+        encoder_stages (List[int]): Number of blocks in each stage of the encoder.
+        encoder_num_heads (int): Number of attention heads in the encoder.
+        encoder_global_att_blocks (List[int]): Indices of global attention blocks in the encoder.
+        encoder_backbone_channel_list (List[int]): Channel dimensions for each level of the encoder backbone.
+        encoder_window_spatial_size (List[int]): Spatial size of the window for position embeddings.
+        encoder_window_spec (List[int]): Window specifications for each stage of the encoder.
+        checkpoint (str | None): Path to the checkpoint file for loading pre-trained weights.
+
+    Returns:
+        (SAM2Model): A configured and initialized SAM2 model.
+
+    Examples:
+        >>> sam2_model = _build_sam2(encoder_embed_dim=96, encoder_stages=[1, 2, 7, 2])
+        >>> sam2_model.eval()
+    """
+    image_encoder = ImageEncoder(
+        trunk=Hiera(
+            embed_dim=encoder_embed_dim,
+            num_heads=encoder_num_heads,
+            stages=encoder_stages,
+            global_att_blocks=encoder_global_att_blocks,
+            window_pos_embed_bkg_spatial_size=encoder_window_spatial_size,
+            window_spec=encoder_window_spec,
+        ),
+        neck=FpnNeck(
+            d_model=256,
+            backbone_channel_list=encoder_backbone_channel_list,
+            fpn_top_down_levels=[2, 3],
+            fpn_interp_model="nearest",
+        ),
+        scalp=1,
+    )
+    memory_attention = MemoryAttention(d_model=256, pos_enc_at_input=True, num_layers=4, layer=MemoryAttentionLayer())
+    memory_encoder = MemoryEncoder(out_dim=64)
+
+    is_sam2_1 = checkpoint is not None and "sam2.1" in checkpoint
+    sam2 = SAM2Model(
+        image_encoder=image_encoder,
+        memory_attention=memory_attention,
+        memory_encoder=memory_encoder,
+        num_maskmem=7,
+        image_size=1024,
+        sigmoid_scale_for_mem_enc=20.0,
+        sigmoid_bias_for_mem_enc=-10.0,
+        use_mask_input_as_output_without_sam=True,
+        directly_add_no_mem_embed=True,
+        use_high_res_features_in_sam=True,
+        multimask_output_in_sam=True,
+        iou_prediction_use_sigmoid=True,
+        use_obj_ptrs_in_encoder=True,
+        add_tpos_enc_to_obj_ptrs=True,
+        only_obj_ptrs_in_the_past_for_eval=True,
+        pred_obj_scores=True,
+        pred_obj_scores_mlp=True,
+        fixed_no_obj_ptr=True,
+        multimask_output_for_tracking=True,
+        use_multimask_token_for_obj_ptr=True,
+        multimask_min_pt_num=0,
+        multimask_max_pt_num=1,
+        use_mlp_for_obj_ptr_proj=True,
+        compile_image_encoder=False,
+        no_obj_embed_spatial=is_sam2_1,
+        proj_tpos_enc_in_obj_ptrs=is_sam2_1,
+        use_signed_tpos_enc_to_obj_ptrs=is_sam2_1,
+        sam_mask_decoder_extra_args=dict(
+            dynamic_multimask_via_stability=True,
+            dynamic_multimask_stability_delta=0.05,
+            dynamic_multimask_stability_thresh=0.98,
+        ),
+    )
+
+    if checkpoint is not None:
+        checkpoint = attempt_download_asset(checkpoint)
+        with open(checkpoint, "rb") as f:
+            state_dict = torch.load(f)["model"]
+        sam2.load_state_dict(state_dict)
+    sam2.eval()
+    return sam2
+
+
 sam_model_map = {
     "sam_h.pt": build_sam_vit_h,
     "sam_l.pt": build_sam_vit_l,
     "sam_b.pt": build_sam_vit_b,
     "mobile_sam.pt": build_mobile_sam,
+    "sam2_t.pt": build_sam2_t,
+    "sam2_s.pt": build_sam2_s,
+    "sam2_b.pt": build_sam2_b,
+    "sam2_l.pt": build_sam2_l,
+    "sam2.1_t.pt": build_sam2_t,
+    "sam2.1_s.pt": build_sam2_s,
+    "sam2.1_b.pt": build_sam2_b,
+    "sam2.1_l.pt": build_sam2_l,
 }
 
 
 def build_sam(ckpt="sam_b.pt"):
-    """Build a SAM model specified by ckpt."""
+    """
+    Builds and returns a Segment Anything Model (SAM) based on the provided checkpoint.
+
+    Args:
+        ckpt (str | Path): Path to the checkpoint file or name of a pre-defined SAM model.
+
+    Returns:
+        (SAMModel | SAM2Model): A configured and initialized SAM or SAM2 model instance.
+
+    Raises:
+        FileNotFoundError: If the provided checkpoint is not a supported SAM model.
+
+    Examples:
+        >>> sam_model = build_sam("sam_b.pt")
+        >>> sam_model = build_sam("path/to/custom_checkpoint.pt")
+
+    Notes:
+        Supported pre-defined models include:
+        - SAM: 'sam_h.pt', 'sam_l.pt', 'sam_b.pt', 'mobile_sam.pt'
+        - SAM2: 'sam2_t.pt', 'sam2_s.pt', 'sam2_b.pt', 'sam2_l.pt'
+    """
     model_builder = None
     ckpt = str(ckpt)  # to allow Path ckpt types
     for k in sam_model_map.keys():
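The hunks above (apparently from the SAM build module) add the _build_sam2 constructor, the four build_sam2_* size variants, and register both "sam2_*.pt" and "sam2.1_*.pt" checkpoint names in sam_model_map, so build_sam can route a checkpoint name to the matching builder. A minimal sketch of that routing, assuming the named checkpoints can be resolved or downloaded (the weight files themselves are not part of this diff):

from ultralytics.models.sam.build import build_sam

sam_b = build_sam("sam_b.pt")    # matched in sam_model_map -> _build_sam (ViT-B SAM)
sam2_l = build_sam("sam2_l.pt")  # matched in sam_model_map -> _build_sam2 (SAM2 large)
print(type(sam_b).__name__, type(sam2_l).__name__)  # expected: SAMModel SAM2Model

A checkpoint name not found in sam_model_map raises FileNotFoundError, per the new build_sam docstring.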
@@ -1,4 +1,4 @@
-# Ultralytics YOLO 🚀, AGPL-3.0 license
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 """
 SAM model interface.
 
@@ -18,40 +18,68 @@ from pathlib import Path
 
 from ultralytics.engine.model import Model
 from ultralytics.utils.torch_utils import model_info
+
 from .build import build_sam
-from .predict import Predictor
+from .predict import Predictor, SAM2Predictor
 
 
 class SAM(Model):
     """
-    SAM (Segment Anything Model) interface class.
-
-    SAM is designed for promptable real-time image segmentation. It can be used with a variety of prompts such as
-    bounding boxes, points, or labels. The model has capabilities for zero-shot performance and is trained on the SA-1B
-    dataset.
+    SAM (Segment Anything Model) interface class for real-time image segmentation tasks.
+
+    This class provides an interface to the Segment Anything Model (SAM) from Ultralytics, designed for
+    promptable segmentation with versatility in image analysis. It supports various prompts such as bounding
+    boxes, points, or labels, and features zero-shot performance capabilities.
+
+    Attributes:
+        model (torch.nn.Module): The loaded SAM model.
+        is_sam2 (bool): Indicates whether the model is SAM2 variant.
+        task (str): The task type, set to "segment" for SAM models.
+
+    Methods:
+        predict: Performs segmentation prediction on the given image or video source.
+        info: Logs information about the SAM model.
+
+    Examples:
+        >>> sam = SAM("sam_b.pt")
+        >>> results = sam.predict("image.jpg", points=[[500, 375]])
+        >>> for r in results:
+        >>>     print(f"Detected {len(r.masks)} masks")
     """
 
     def __init__(self, model="sam_b.pt") -> None:
        """
-        Initializes the SAM model with a pre-trained model file.
+        Initializes the SAM (Segment Anything Model) instance.
 
        Args:
            model (str): Path to the pre-trained SAM model file. File should have a .pt or .pth extension.
 
        Raises:
            NotImplementedError: If the model file extension is not .pt or .pth.
+
+        Examples:
+            >>> sam = SAM("sam_b.pt")
+            >>> print(sam.is_sam2)
        """
-        if model and Path(model).suffix not in (".pt", ".pth"):
+        if model and Path(model).suffix not in {".pt", ".pth"}:
            raise NotImplementedError("SAM prediction requires pre-trained *.pt or *.pth model.")
+        self.is_sam2 = "sam2" in Path(model).stem
        super().__init__(model=model, task="segment")
 
    def _load(self, weights: str, task=None):
        """
        Loads the specified weights into the SAM model.
 
+        This method initializes the SAM model with the provided weights file, setting up the model architecture
+        and loading the pre-trained parameters.
+
        Args:
-            weights (str): Path to the weights file.
-            task (str, optional): Task name. Defaults to None.
+            weights (str): Path to the weights file. Should be a .pt or .pth file containing the model parameters.
+            task (str | None): Task name. If provided, it specifies the particular task the model is being loaded for.
+
+        Examples:
+            >>> sam = SAM("sam_b.pt")
+            >>> sam._load("path/to/custom_weights.pt")
        """
        self.model = build_sam(weights)
 
@@ -60,33 +88,51 @@ class SAM(Model):
        Performs segmentation prediction on the given image or video source.
 
        Args:
-            source (str): Path to the image or video file, or a PIL.Image object, or a numpy.ndarray object.
-            stream (bool, optional): If True, enables real-time streaming. Defaults to False.
-            bboxes (list, optional): List of bounding box coordinates for prompted segmentation. Defaults to None.
-            points (list, optional): List of points for prompted segmentation. Defaults to None.
-            labels (list, optional): List of labels for prompted segmentation. Defaults to None.
+            source (str | PIL.Image | numpy.ndarray): Path to the image or video file, or a PIL.Image object, or
+                a numpy.ndarray object.
+            stream (bool): If True, enables real-time streaming.
+            bboxes (List[List[float]] | None): List of bounding box coordinates for prompted segmentation.
+            points (List[List[float]] | None): List of points for prompted segmentation.
+            labels (List[int] | None): List of labels for prompted segmentation.
+            **kwargs (Any): Additional keyword arguments for prediction.
 
        Returns:
-            (list): The model predictions.
+            (List): The model predictions.
+
+        Examples:
+            >>> sam = SAM("sam_b.pt")
+            >>> results = sam.predict("image.jpg", points=[[500, 375]])
+            >>> for r in results:
+            ...     print(f"Detected {len(r.masks)} masks")
        """
        overrides = dict(conf=0.25, task="segment", mode="predict", imgsz=1024)
-        kwargs.update(overrides)
+        kwargs = {**overrides, **kwargs}
        prompts = dict(bboxes=bboxes, points=points, labels=labels)
        return super().predict(source, stream, prompts=prompts, **kwargs)
 
    def __call__(self, source=None, stream=False, bboxes=None, points=None, labels=None, **kwargs):
        """
-        Alias for the 'predict' method.
+        Performs segmentation prediction on the given image or video source.
+
+        This method is an alias for the 'predict' method, providing a convenient way to call the SAM model
+        for segmentation tasks.
 
        Args:
-            source (str): Path to the image or video file, or a PIL.Image object, or a numpy.ndarray object.
-            stream (bool, optional): If True, enables real-time streaming. Defaults to False.
-            bboxes (list, optional): List of bounding box coordinates for prompted segmentation. Defaults to None.
-            points (list, optional): List of points for prompted segmentation. Defaults to None.
-            labels (list, optional): List of labels for prompted segmentation. Defaults to None.
+            source (str | PIL.Image | numpy.ndarray | None): Path to the image or video file, or a PIL.Image
+                object, or a numpy.ndarray object.
+            stream (bool): If True, enables real-time streaming.
+            bboxes (List[List[float]] | None): List of bounding box coordinates for prompted segmentation.
+            points (List[List[float]] | None): List of points for prompted segmentation.
+            labels (List[int] | None): List of labels for prompted segmentation.
+            **kwargs (Any): Additional keyword arguments to be passed to the predict method.
 
        Returns:
-            (list): The model predictions.
+            (List): The model predictions, typically containing segmentation masks and other relevant information.
+
+        Examples:
+            >>> sam = SAM("sam_b.pt")
+            >>> results = sam("image.jpg", points=[[500, 375]])
+            >>> print(f"Detected {len(results[0].masks)} masks")
        """
        return self.predict(source, stream, bboxes, points, labels, **kwargs)
 
@@ -94,12 +140,20 @@ class SAM(Model):
        """
        Logs information about the SAM model.
 
+        This method provides details about the Segment Anything Model (SAM), including its architecture,
+        parameters, and computational requirements.
+
        Args:
-            detailed (bool, optional): If True, displays detailed information about the model. Defaults to False.
-            verbose (bool, optional): If True, displays information on the console. Defaults to True.
+            detailed (bool): If True, displays detailed information about the model layers and operations.
+            verbose (bool): If True, prints the information to the console.
 
        Returns:
-            (tuple): A tuple containing the model's information.
+            (tuple): A tuple containing the model's information (string representations of the model).
+
+        Examples:
+            >>> sam = SAM("sam_b.pt")
+            >>> info = sam.info()
+            >>> print(info[0])  # Print summary information
        """
        return model_info(self.model, detailed=detailed, verbose=verbose)
 
@@ -109,6 +163,13 @@ class SAM(Model):
        Provides a mapping from the 'segment' task to its corresponding 'Predictor'.
 
        Returns:
-            (dict): A dictionary mapping the 'segment' task to its corresponding 'Predictor'.
+            (Dict[str, Type[Predictor]]): A dictionary mapping the 'segment' task to its corresponding Predictor
+                class. For SAM2 models, it maps to SAM2Predictor, otherwise to the standard Predictor.
+
+        Examples:
+            >>> sam = SAM("sam_b.pt")
+            >>> task_map = sam.task_map
+            >>> print(task_map)
+            {'segment': <class 'ultralytics.models.sam.predict.Predictor'>}
        """
-        return {"segment": {"predictor": Predictor}}
+        return {"segment": {"predictor": SAM2Predictor if self.is_sam2 else Predictor}}
@@ -1 +1 @@
-# Ultralytics YOLO 🚀, AGPL-3.0 license
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license