ultralytics 8.2.72__py3-none-any.whl → 8.2.74__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.

This version of ultralytics was flagged as potentially problematic.

Files changed (34):
  1. ultralytics/__init__.py +2 -3
  2. ultralytics/cfg/trackers/botsort.yaml +1 -1
  3. ultralytics/cfg/trackers/bytetrack.yaml +1 -1
  4. ultralytics/models/__init__.py +1 -2
  5. ultralytics/models/sam/__init__.py +2 -2
  6. ultralytics/models/sam/amg.py +27 -21
  7. ultralytics/models/sam/build.py +200 -9
  8. ultralytics/models/sam/model.py +86 -34
  9. ultralytics/models/sam/modules/blocks.py +1131 -0
  10. ultralytics/models/sam/modules/decoders.py +390 -23
  11. ultralytics/models/sam/modules/encoders.py +508 -323
  12. ultralytics/models/{sam2 → sam}/modules/memory_attention.py +73 -6
  13. ultralytics/models/sam/modules/sam.py +887 -16
  14. ultralytics/models/sam/modules/tiny_encoder.py +376 -126
  15. ultralytics/models/sam/modules/transformer.py +155 -54
  16. ultralytics/models/{sam2 → sam}/modules/utils.py +105 -3
  17. ultralytics/models/sam/predict.py +382 -92
  18. ultralytics/trackers/bot_sort.py +2 -3
  19. ultralytics/trackers/byte_tracker.py +2 -3
  20. {ultralytics-8.2.72.dist-info → ultralytics-8.2.74.dist-info}/METADATA +44 -44
  21. {ultralytics-8.2.72.dist-info → ultralytics-8.2.74.dist-info}/RECORD +25 -33
  22. ultralytics/models/sam2/__init__.py +0 -6
  23. ultralytics/models/sam2/build.py +0 -156
  24. ultralytics/models/sam2/model.py +0 -97
  25. ultralytics/models/sam2/modules/__init__.py +0 -1
  26. ultralytics/models/sam2/modules/decoders.py +0 -305
  27. ultralytics/models/sam2/modules/encoders.py +0 -332
  28. ultralytics/models/sam2/modules/sam2.py +0 -804
  29. ultralytics/models/sam2/modules/sam2_blocks.py +0 -715
  30. ultralytics/models/sam2/predict.py +0 -177
  31. {ultralytics-8.2.72.dist-info → ultralytics-8.2.74.dist-info}/LICENSE +0 -0
  32. {ultralytics-8.2.72.dist-info → ultralytics-8.2.74.dist-info}/WHEEL +0 -0
  33. {ultralytics-8.2.72.dist-info → ultralytics-8.2.74.dist-info}/entry_points.txt +0 -0
  34. {ultralytics-8.2.72.dist-info → ultralytics-8.2.74.dist-info}/top_level.txt +0 -0
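
Taken together, the file list shows that the standalone ultralytics/models/sam2 package from 8.2.72 is deleted and its functionality is merged into ultralytics/models/sam. A minimal sketch of how the import surface moves between the two versions, inferred from the diffs below (illustrative only):

    # 8.2.72 (old):
    #   from ultralytics import SAM2                          # top-level SAM2 class, removed in 8.2.74
    #   from ultralytics.models.sam2.predict import SAM2Predictor  # package deleted in 8.2.74

    # 8.2.74 (new): SAM2 weights are handled by the unified SAM interface
    from ultralytics import SAM
    from ultralytics.models.sam import Predictor, SAM2Predictor
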
ultralytics/__init__.py CHANGED
@@ -1,6 +1,6 @@
  # Ultralytics YOLO 🚀, AGPL-3.0 license

- __version__ = "8.2.72"
+ __version__ = "8.2.74"

  import os

@@ -8,7 +8,7 @@ import os
  os.environ["OMP_NUM_THREADS"] = "1" # reduce CPU utilization during training

  from ultralytics.data.explorer.explorer import Explorer
- from ultralytics.models import NAS, RTDETR, SAM, SAM2, YOLO, FastSAM, YOLOWorld
+ from ultralytics.models import NAS, RTDETR, SAM, YOLO, FastSAM, YOLOWorld
  from ultralytics.utils import ASSETS, SETTINGS
  from ultralytics.utils.checks import check_yolo as checks
  from ultralytics.utils.downloads import download
@@ -21,7 +21,6 @@ __all__ = (
  "YOLOWorld",
  "NAS",
  "SAM",
- "SAM2",
  "FastSAM",
  "RTDETR",
  "checks",
ultralytics/cfg/trackers/botsort.yaml CHANGED
@@ -7,8 +7,8 @@ track_low_thresh: 0.1 # threshold for the second association
  new_track_thresh: 0.6 # threshold for init new track if the detection does not match any tracks
  track_buffer: 30 # buffer to calculate the time when to remove tracks
  match_thresh: 0.8 # threshold for matching tracks
+ fuse_score: True # Whether to fuse confidence scores with the iou distances before matching
  # min_box_area: 10 # threshold for min box areas(for tracker evaluation, not used for now)
- # mot20: False # for tracker evaluation(not used for now)

  # BoT-SORT settings
  gmc_method: sparseOptFlow # method of global motion compensation
ultralytics/cfg/trackers/bytetrack.yaml CHANGED
@@ -7,5 +7,5 @@ track_low_thresh: 0.1 # threshold for the second association
  new_track_thresh: 0.6 # threshold for init new track if the detection does not match any tracks
  track_buffer: 30 # buffer to calculate the time when to remove tracks
  match_thresh: 0.8 # threshold for matching tracks
+ fuse_score: True # Whether to fuse confidence scores with the iou distances before matching
  # min_box_area: 10 # threshold for min box areas(for tracker evaluation, not used for now)
- # mot20: False # for tracker evaluation(not used for now)
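
Both tracker configs gain the same fuse_score switch and drop the unused mot20 placeholder. As a rough sketch of where these YAML files come into play (model.track is the usual Ultralytics tracking entry point; the custom file name below is hypothetical):

    from ultralytics import YOLO

    # A copy of bytetrack.yaml or botsort.yaml can be edited and passed by path,
    # e.g. with fuse_score: False to disable score fusion before matching.
    model = YOLO("yolov8n.pt")
    results = model.track("video.mp4", tracker="bytetrack.yaml")  # or tracker="my_custom_tracker.yaml"
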
ultralytics/models/__init__.py CHANGED
@@ -4,7 +4,6 @@ from .fastsam import FastSAM
  from .nas import NAS
  from .rtdetr import RTDETR
  from .sam import SAM
- from .sam2 import SAM2
  from .yolo import YOLO, YOLOWorld

- __all__ = "YOLO", "RTDETR", "SAM", "FastSAM", "NAS", "YOLOWorld", "SAM2" # allow simpler import
+ __all__ = "YOLO", "RTDETR", "SAM", "FastSAM", "NAS", "YOLOWorld" # allow simpler import
ultralytics/models/sam/__init__.py CHANGED
@@ -1,6 +1,6 @@
  # Ultralytics YOLO 🚀, AGPL-3.0 license

  from .model import SAM
- from .predict import Predictor
+ from .predict import Predictor, SAM2Predictor

- __all__ = "SAM", "Predictor" # tuple or list
+ __all__ = "SAM", "Predictor", "SAM2Predictor" # tuple or list
ultralytics/models/sam/amg.py CHANGED
@@ -11,7 +11,7 @@ import torch
  def is_box_near_crop_edge(
  boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0
  ) -> torch.Tensor:
- """Return a boolean tensor indicating if boxes are near the crop edge."""
+ """Determines if bounding boxes are near the edge of a cropped image region using a specified tolerance."""
  crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
  orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
  boxes = uncrop_boxes_xyxy(boxes, crop_box).float()
@@ -22,7 +22,7 @@ def is_box_near_crop_edge(


  def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
- """Yield batches of data from the input arguments."""
+ """Yields batches of data from input arguments with specified batch size for efficient processing."""
  assert args and all(len(a) == len(args[0]) for a in args), "Batched iteration must have same-size inputs."
  n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
  for b in range(n_batches):
@@ -33,12 +33,26 @@ def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, thresh
  """
  Computes the stability score for a batch of masks.

- The stability score is the IoU between the binary masks obtained by thresholding the predicted mask logits at high
- and low values.
+ The stability score is the IoU between binary masks obtained by thresholding the predicted mask logits at
+ high and low values.
+
+ Args:
+ masks (torch.Tensor): Batch of predicted mask logits.
+ mask_threshold (float): Threshold value for creating binary masks.
+ threshold_offset (float): Offset applied to the threshold for creating high and low binary masks.
+
+ Returns:
+ (torch.Tensor): Stability scores for each mask in the batch.

  Notes:
  - One mask is always contained inside the other.
- - Save memory by preventing unnecessary cast to torch.int64
+ - Memory is saved by preventing unnecessary cast to torch.int64.
+
+ Examples:
+ >>> masks = torch.rand(10, 256, 256) # Batch of 10 masks
+ >>> mask_threshold = 0.5
+ >>> threshold_offset = 0.1
+ >>> stability_scores = calculate_stability_score(masks, mask_threshold, threshold_offset)
  """
  intersections = (masks > (mask_threshold + threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
  unions = (masks > (mask_threshold - threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
@@ -46,7 +60,7 @@ def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, thresh


  def build_point_grid(n_per_side: int) -> np.ndarray:
- """Generate a 2D grid of evenly spaced points in the range [0,1]x[0,1]."""
+ """Generate a 2D grid of evenly spaced points in the range [0,1]x[0,1] for image segmentation tasks."""
  offset = 1 / (2 * n_per_side)
  points_one_side = np.linspace(offset, 1 - offset, n_per_side)
  points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
@@ -55,18 +69,14 @@ def build_point_grid(n_per_side: int) -> np.ndarray:


  def build_all_layer_point_grids(n_per_side: int, n_layers: int, scale_per_layer: int) -> List[np.ndarray]:
- """Generate point grids for all crop layers."""
+ """Generates point grids for multiple crop layers with varying scales and densities."""
  return [build_point_grid(int(n_per_side / (scale_per_layer**i))) for i in range(n_layers + 1)]


  def generate_crop_boxes(
  im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float
  ) -> Tuple[List[List[int]], List[int]]:
- """
- Generates a list of crop boxes of different sizes.
-
- Each layer has (2**i)**2 boxes for the ith layer.
- """
+ """Generates crop boxes of varying sizes for multi-scale image processing, with layered overlapping regions."""
  crop_boxes, layer_idxs = [], []
  im_h, im_w = im_size
  short_side = min(im_h, im_w)
@@ -99,7 +109,7 @@ def generate_crop_boxes(


  def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
- """Uncrop bounding boxes by adding the crop box offset."""
+ """Uncrop bounding boxes by adding the crop box offset to their coordinates."""
  x0, y0, _, _ = crop_box
  offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device)
  # Check if boxes has a channel dimension
@@ -109,7 +119,7 @@ def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor:


  def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
- """Uncrop points by adding the crop box offset."""
+ """Uncrop points by adding the crop box offset to their coordinates."""
  x0, y0, _, _ = crop_box
  offset = torch.tensor([[x0, y0]], device=points.device)
  # Check if points has a channel dimension
@@ -119,7 +129,7 @@ def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor:


  def uncrop_masks(masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int) -> torch.Tensor:
- """Uncrop masks by padding them to the original image size."""
+ """Uncrop masks by padding them to the original image size, handling coordinate transformations."""
  x0, y0, x1, y1 = crop_box
  if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h:
  return masks
@@ -130,7 +140,7 @@ def uncrop_masks(masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w:


  def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> Tuple[np.ndarray, bool]:
- """Remove small disconnected regions or holes in a mask, returning the mask and a modification indicator."""
+ """Removes small disconnected regions or holes in a mask based on area threshold and mode."""
  import cv2 # type: ignore

  assert mode in {"holes", "islands"}, f"Provided mode {mode} is invalid"
@@ -150,11 +160,7 @@ def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> Tup


  def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
- """
- Calculates boxes in XYXY format around masks.
-
- Return [0,0,0,0] for an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4.
- """
+ """Calculates bounding boxes in XYXY format around binary masks, handling empty masks and various input shapes."""
  # torch.max below raises an error on empty inputs, just skip in this case
  if torch.numel(masks) == 0:
  return torch.zeros(*masks.shape[:-2], 4, device=masks.device)
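
The amg.py changes are docstring-only, but to make the helper semantics concrete, here is a small usage sketch (shapes follow from the code shown above; treat it as illustrative):

    import torch
    from ultralytics.models.sam.amg import batch_iterator, build_point_grid

    grid = build_point_grid(4)        # numpy array of shape (16, 2), points in [0, 1] x [0, 1]

    points = torch.rand(10, 2)
    labels = torch.ones(10)
    for pts, lbls in batch_iterator(4, points, labels):
        pass                          # slices of up to 4 items, in the same order as the inputs
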
ultralytics/models/sam/build.py CHANGED
@@ -13,14 +13,15 @@ import torch
  from ultralytics.utils.downloads import attempt_download_asset

  from .modules.decoders import MaskDecoder
- from .modules.encoders import ImageEncoderViT, PromptEncoder
- from .modules.sam import SAMModel
+ from .modules.encoders import FpnNeck, Hiera, ImageEncoder, ImageEncoderViT, MemoryEncoder, PromptEncoder
+ from .modules.memory_attention import MemoryAttention, MemoryAttentionLayer
+ from .modules.sam import SAM2Model, SAMModel
  from .modules.tiny_encoder import TinyViT
  from .modules.transformer import TwoWayTransformer


  def build_sam_vit_h(checkpoint=None):
- """Build and return a Segment Anything Model (SAM) h-size model."""
+ """Builds and returns a Segment Anything Model (SAM) h-size model with specified encoder parameters."""
  return _build_sam(
  encoder_embed_dim=1280,
  encoder_depth=32,
@@ -31,7 +32,7 @@ def build_sam_vit_h(checkpoint=None):


  def build_sam_vit_l(checkpoint=None):
- """Build and return a Segment Anything Model (SAM) l-size model."""
+ """Builds and returns a Segment Anything Model (SAM) l-size model with specified encoder parameters."""
  return _build_sam(
  encoder_embed_dim=1024,
  encoder_depth=24,
@@ -42,7 +43,7 @@ def build_sam_vit_l(checkpoint=None):


  def build_sam_vit_b(checkpoint=None):
- """Build and return a Segment Anything Model (SAM) b-size model."""
+ """Constructs and returns a Segment Anything Model (SAM) with b-size architecture and optional checkpoint."""
  return _build_sam(
  encoder_embed_dim=768,
  encoder_depth=12,
@@ -53,7 +54,7 @@ def build_sam_vit_b(checkpoint=None):


  def build_mobile_sam(checkpoint=None):
- """Build and return Mobile Segment Anything Model (Mobile-SAM)."""
+ """Builds and returns a Mobile Segment Anything Model (Mobile-SAM) for efficient image segmentation."""
  return _build_sam(
  encoder_embed_dim=[64, 128, 160, 320],
  encoder_depth=[2, 2, 6, 2],
@@ -64,10 +65,85 @@ def build_mobile_sam(checkpoint=None):
  )


+ def build_sam2_t(checkpoint=None):
+ """Builds and returns a Segment Anything Model 2 (SAM2) tiny-size model with specified architecture parameters."""
+ return _build_sam2(
+ encoder_embed_dim=96,
+ encoder_stages=[1, 2, 7, 2],
+ encoder_num_heads=1,
+ encoder_global_att_blocks=[5, 7, 9],
+ encoder_window_spec=[8, 4, 14, 7],
+ encoder_backbone_channel_list=[768, 384, 192, 96],
+ checkpoint=checkpoint,
+ )
+
+
+ def build_sam2_s(checkpoint=None):
+ """Builds and returns a small-size Segment Anything Model (SAM2) with specified architecture parameters."""
+ return _build_sam2(
+ encoder_embed_dim=96,
+ encoder_stages=[1, 2, 11, 2],
+ encoder_num_heads=1,
+ encoder_global_att_blocks=[7, 10, 13],
+ encoder_window_spec=[8, 4, 14, 7],
+ encoder_backbone_channel_list=[768, 384, 192, 96],
+ checkpoint=checkpoint,
+ )
+
+
+ def build_sam2_b(checkpoint=None):
+ """Builds and returns a SAM2 base-size model with specified architecture parameters."""
+ return _build_sam2(
+ encoder_embed_dim=112,
+ encoder_stages=[2, 3, 16, 3],
+ encoder_num_heads=2,
+ encoder_global_att_blocks=[12, 16, 20],
+ encoder_window_spec=[8, 4, 14, 7],
+ encoder_window_spatial_size=[14, 14],
+ encoder_backbone_channel_list=[896, 448, 224, 112],
+ checkpoint=checkpoint,
+ )
+
+
+ def build_sam2_l(checkpoint=None):
+ """Builds and returns a large-size Segment Anything Model (SAM2) with specified architecture parameters."""
+ return _build_sam2(
+ encoder_embed_dim=144,
+ encoder_stages=[2, 6, 36, 4],
+ encoder_num_heads=2,
+ encoder_global_att_blocks=[23, 33, 43],
+ encoder_window_spec=[8, 4, 16, 8],
+ encoder_backbone_channel_list=[1152, 576, 288, 144],
+ checkpoint=checkpoint,
+ )
+
+
  def _build_sam(
- encoder_embed_dim, encoder_depth, encoder_num_heads, encoder_global_attn_indexes, checkpoint=None, mobile_sam=False
+ encoder_embed_dim,
+ encoder_depth,
+ encoder_num_heads,
+ encoder_global_attn_indexes,
+ checkpoint=None,
+ mobile_sam=False,
  ):
- """Builds the selected SAM model architecture."""
+ """
+ Builds a Segment Anything Model (SAM) with specified encoder parameters.
+
+ Args:
+ encoder_embed_dim (int | List[int]): Embedding dimension for the encoder.
+ encoder_depth (int | List[int]): Depth of the encoder.
+ encoder_num_heads (int | List[int]): Number of attention heads in the encoder.
+ encoder_global_attn_indexes (List[int] | None): Indexes for global attention in the encoder.
+ checkpoint (str | None): Path to the model checkpoint file.
+ mobile_sam (bool): Whether to build a Mobile-SAM model.
+
+ Returns:
+ (SAMModel): A Segment Anything Model instance with the specified architecture.
+
+ Examples:
+ >>> sam = _build_sam(768, 12, 12, [2, 5, 8, 11])
+ >>> sam = _build_sam([64, 128, 160, 320], [2, 2, 6, 2], [2, 4, 5, 10], None, mobile_sam=True)
+ """
  prompt_embed_dim = 256
  image_size = 1024
  vit_patch_size = 16
@@ -139,16 +215,131 @@ def _build_sam(
  return sam


+ def _build_sam2(
+ encoder_embed_dim=1280,
+ encoder_stages=[2, 6, 36, 4],
+ encoder_num_heads=2,
+ encoder_global_att_blocks=[7, 15, 23, 31],
+ encoder_backbone_channel_list=[1152, 576, 288, 144],
+ encoder_window_spatial_size=[7, 7],
+ encoder_window_spec=[8, 4, 16, 8],
+ checkpoint=None,
+ ):
+ """
+ Builds and returns a Segment Anything Model 2 (SAM2) with specified architecture parameters.
+
+ Args:
+ encoder_embed_dim (int): Embedding dimension for the encoder.
+ encoder_stages (List[int]): Number of blocks in each stage of the encoder.
+ encoder_num_heads (int): Number of attention heads in the encoder.
+ encoder_global_att_blocks (List[int]): Indices of global attention blocks in the encoder.
+ encoder_backbone_channel_list (List[int]): Channel dimensions for each level of the encoder backbone.
+ encoder_window_spatial_size (List[int]): Spatial size of the window for position embeddings.
+ encoder_window_spec (List[int]): Window specifications for each stage of the encoder.
+ checkpoint (str | None): Path to the checkpoint file for loading pre-trained weights.
+
+ Returns:
+ (SAM2Model): A configured and initialized SAM2 model.
+
+ Examples:
+ >>> sam2_model = _build_sam2(encoder_embed_dim=96, encoder_stages=[1, 2, 7, 2])
+ >>> sam2_model.eval()
+ """
+ image_encoder = ImageEncoder(
+ trunk=Hiera(
+ embed_dim=encoder_embed_dim,
+ num_heads=encoder_num_heads,
+ stages=encoder_stages,
+ global_att_blocks=encoder_global_att_blocks,
+ window_pos_embed_bkg_spatial_size=encoder_window_spatial_size,
+ window_spec=encoder_window_spec,
+ ),
+ neck=FpnNeck(
+ d_model=256,
+ backbone_channel_list=encoder_backbone_channel_list,
+ fpn_top_down_levels=[2, 3],
+ fpn_interp_model="nearest",
+ ),
+ scalp=1,
+ )
+ memory_attention = MemoryAttention(d_model=256, pos_enc_at_input=True, num_layers=4, layer=MemoryAttentionLayer())
+ memory_encoder = MemoryEncoder(out_dim=64)
+
+ sam2 = SAM2Model(
+ image_encoder=image_encoder,
+ memory_attention=memory_attention,
+ memory_encoder=memory_encoder,
+ num_maskmem=7,
+ image_size=1024,
+ sigmoid_scale_for_mem_enc=20.0,
+ sigmoid_bias_for_mem_enc=-10.0,
+ use_mask_input_as_output_without_sam=True,
+ directly_add_no_mem_embed=True,
+ use_high_res_features_in_sam=True,
+ multimask_output_in_sam=True,
+ iou_prediction_use_sigmoid=True,
+ use_obj_ptrs_in_encoder=True,
+ add_tpos_enc_to_obj_ptrs=True,
+ only_obj_ptrs_in_the_past_for_eval=True,
+ pred_obj_scores=True,
+ pred_obj_scores_mlp=True,
+ fixed_no_obj_ptr=True,
+ multimask_output_for_tracking=True,
+ use_multimask_token_for_obj_ptr=True,
+ multimask_min_pt_num=0,
+ multimask_max_pt_num=1,
+ use_mlp_for_obj_ptr_proj=True,
+ compile_image_encoder=False,
+ sam_mask_decoder_extra_args=dict(
+ dynamic_multimask_via_stability=True,
+ dynamic_multimask_stability_delta=0.05,
+ dynamic_multimask_stability_thresh=0.98,
+ ),
+ )
+
+ if checkpoint is not None:
+ checkpoint = attempt_download_asset(checkpoint)
+ with open(checkpoint, "rb") as f:
+ state_dict = torch.load(f)["model"]
+ sam2.load_state_dict(state_dict)
+ sam2.eval()
+ return sam2
+
+
  sam_model_map = {
  "sam_h.pt": build_sam_vit_h,
  "sam_l.pt": build_sam_vit_l,
  "sam_b.pt": build_sam_vit_b,
  "mobile_sam.pt": build_mobile_sam,
+ "sam2_t.pt": build_sam2_t,
+ "sam2_s.pt": build_sam2_s,
+ "sam2_b.pt": build_sam2_b,
+ "sam2_l.pt": build_sam2_l,
  }


  def build_sam(ckpt="sam_b.pt"):
- """Build a SAM model specified by ckpt."""
+ """
+ Builds and returns a Segment Anything Model (SAM) based on the provided checkpoint.
+
+ Args:
+ ckpt (str | Path): Path to the checkpoint file or name of a pre-defined SAM model.
+
+ Returns:
+ (SAMModel | SAM2Model): A configured and initialized SAM or SAM2 model instance.
+
+ Raises:
+ FileNotFoundError: If the provided checkpoint is not a supported SAM model.
+
+ Examples:
+ >>> sam_model = build_sam("sam_b.pt")
+ >>> sam_model = build_sam("path/to/custom_checkpoint.pt")
+
+ Notes:
+ Supported pre-defined models include:
+ - SAM: 'sam_h.pt', 'sam_l.pt', 'sam_b.pt', 'mobile_sam.pt'
+ - SAM2: 'sam2_t.pt', 'sam2_s.pt', 'sam2_b.pt', 'sam2_l.pt'
+ """
  model_builder = None
  ckpt = str(ckpt) # to allow Path ckpt types
  for k in sam_model_map.keys():
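
With the extra entries in sam_model_map, build_sam now dispatches to the SAM2 builders purely by checkpoint name; unrecognised names raise FileNotFoundError per the docstring above. A minimal sketch:

    from ultralytics.models.sam.build import build_sam

    sam = build_sam("sam_b.pt")    # classic ViT-B SAM -> SAMModel
    sam2 = build_sam("sam2_b.pt")  # base-size SAM2 -> SAM2Model
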
ultralytics/models/sam/model.py CHANGED
@@ -20,27 +20,46 @@ from ultralytics.engine.model import Model
  from ultralytics.utils.torch_utils import model_info

  from .build import build_sam
- from .predict import Predictor
+ from .predict import Predictor, SAM2Predictor


  class SAM(Model):
  """
- SAM (Segment Anything Model) interface class.
-
- SAM is designed for promptable real-time image segmentation. It can be used with a variety of prompts such as
- bounding boxes, points, or labels. The model has capabilities for zero-shot performance and is trained on the SA-1B
- dataset.
+ SAM (Segment Anything Model) interface class for real-time image segmentation tasks.
+
+ This class provides an interface to the Segment Anything Model (SAM) from Ultralytics, designed for
+ promptable segmentation with versatility in image analysis. It supports various prompts such as bounding
+ boxes, points, or labels, and features zero-shot performance capabilities.
+
+ Attributes:
+ model (torch.nn.Module): The loaded SAM model.
+ is_sam2 (bool): Indicates whether the model is SAM2 variant.
+ task (str): The task type, set to "segment" for SAM models.
+
+ Methods:
+ predict: Performs segmentation prediction on the given image or video source.
+ info: Logs information about the SAM model.
+
+ Examples:
+ >>> sam = SAM('sam_b.pt')
+ >>> results = sam.predict('image.jpg', points=[[500, 375]])
+ >>> for r in results:
+ >>> print(f"Detected {len(r.masks)} masks")
  """

  def __init__(self, model="sam_b.pt") -> None:
  """
- Initializes the SAM model with a pre-trained model file.
+ Initializes the SAM (Segment Anything Model) instance.

  Args:
  model (str): Path to the pre-trained SAM model file. File should have a .pt or .pth extension.

  Raises:
  NotImplementedError: If the model file extension is not .pt or .pth.
+
+ Examples:
+ >>> sam = SAM('sam_b.pt')
+ >>> print(sam.is_sam2)
  """
  if model and Path(model).suffix not in {".pt", ".pth"}:
  raise NotImplementedError("SAM prediction requires pre-trained *.pt or *.pth model.")
@@ -51,30 +70,40 @@ class SAM(Model):
  """
  Loads the specified weights into the SAM model.

+ This method initializes the SAM model with the provided weights file, setting up the model architecture
+ and loading the pre-trained parameters.
+
  Args:
- weights (str): Path to the weights file.
- task (str, optional): Task name. Defaults to None.
- """
- if self.is_sam2:
- from ..sam2.build import build_sam2
+ weights (str): Path to the weights file. Should be a .pt or .pth file containing the model parameters.
+ task (str | None): Task name. If provided, it specifies the particular task the model is being loaded for.

- self.model = build_sam2(weights)
- else:
- self.model = build_sam(weights)
+ Examples:
+ >>> sam = SAM('sam_b.pt')
+ >>> sam._load('path/to/custom_weights.pt')
+ """
+ self.model = build_sam(weights)

  def predict(self, source, stream=False, bboxes=None, points=None, labels=None, **kwargs):
  """
  Performs segmentation prediction on the given image or video source.

  Args:
- source (str): Path to the image or video file, or a PIL.Image object, or a numpy.ndarray object.
- stream (bool, optional): If True, enables real-time streaming. Defaults to False.
- bboxes (list, optional): List of bounding box coordinates for prompted segmentation. Defaults to None.
- points (list, optional): List of points for prompted segmentation. Defaults to None.
- labels (list, optional): List of labels for prompted segmentation. Defaults to None.
+ source (str | PIL.Image | numpy.ndarray): Path to the image or video file, or a PIL.Image object, or
+ a numpy.ndarray object.
+ stream (bool): If True, enables real-time streaming.
+ bboxes (List[List[float]] | None): List of bounding box coordinates for prompted segmentation.
+ points (List[List[float]] | None): List of points for prompted segmentation.
+ labels (List[int] | None): List of labels for prompted segmentation.
+ **kwargs (Any): Additional keyword arguments for prediction.

  Returns:
- (list): The model predictions.
+ (List): The model predictions.
+
+ Examples:
+ >>> sam = SAM('sam_b.pt')
+ >>> results = sam.predict('image.jpg', points=[[500, 375]])
+ >>> for r in results:
+ ... print(f"Detected {len(r.masks)} masks")
  """
  overrides = dict(conf=0.25, task="segment", mode="predict", imgsz=1024)
  kwargs.update(overrides)
@@ -83,17 +112,27 @@ class SAM(Model):

  def __call__(self, source=None, stream=False, bboxes=None, points=None, labels=None, **kwargs):
  """
- Alias for the 'predict' method.
+ Performs segmentation prediction on the given image or video source.
+
+ This method is an alias for the 'predict' method, providing a convenient way to call the SAM model
+ for segmentation tasks.

  Args:
- source (str): Path to the image or video file, or a PIL.Image object, or a numpy.ndarray object.
- stream (bool, optional): If True, enables real-time streaming. Defaults to False.
- bboxes (list, optional): List of bounding box coordinates for prompted segmentation. Defaults to None.
- points (list, optional): List of points for prompted segmentation. Defaults to None.
- labels (list, optional): List of labels for prompted segmentation. Defaults to None.
+ source (str | PIL.Image | numpy.ndarray | None): Path to the image or video file, or a PIL.Image
+ object, or a numpy.ndarray object.
+ stream (bool): If True, enables real-time streaming.
+ bboxes (List[List[float]] | None): List of bounding box coordinates for prompted segmentation.
+ points (List[List[float]] | None): List of points for prompted segmentation.
+ labels (List[int] | None): List of labels for prompted segmentation.
+ **kwargs (Any): Additional keyword arguments to be passed to the predict method.

  Returns:
- (list): The model predictions.
+ (List): The model predictions, typically containing segmentation masks and other relevant information.
+
+ Examples:
+ >>> sam = SAM('sam_b.pt')
+ >>> results = sam('image.jpg', points=[[500, 375]])
+ >>> print(f"Detected {len(results[0].masks)} masks")
  """
  return self.predict(source, stream, bboxes, points, labels, **kwargs)

@@ -101,12 +140,20 @@ class SAM(Model):
  """
  Logs information about the SAM model.

+ This method provides details about the Segment Anything Model (SAM), including its architecture,
+ parameters, and computational requirements.
+
  Args:
- detailed (bool, optional): If True, displays detailed information about the model. Defaults to False.
- verbose (bool, optional): If True, displays information on the console. Defaults to True.
+ detailed (bool): If True, displays detailed information about the model layers and operations.
+ verbose (bool): If True, prints the information to the console.

  Returns:
- (tuple): A tuple containing the model's information.
+ (Tuple): A tuple containing the model's information (string representations of the model).
+
+ Examples:
+ >>> sam = SAM('sam_b.pt')
+ >>> info = sam.info()
+ >>> print(info[0]) # Print summary information
  """
  return model_info(self.model, detailed=detailed, verbose=verbose)

@@ -116,8 +163,13 @@ class SAM(Model):
  Provides a mapping from the 'segment' task to its corresponding 'Predictor'.

  Returns:
- (dict): A dictionary mapping the 'segment' task to its corresponding 'Predictor'.
+ (Dict[str, Type[Predictor]]): A dictionary mapping the 'segment' task to its corresponding Predictor
+ class. For SAM2 models, it maps to SAM2Predictor, otherwise to the standard Predictor.
+
+ Examples:
+ >>> sam = SAM('sam_b.pt')
+ >>> task_map = sam.task_map
+ >>> print(task_map)
+ {'segment': <class 'ultralytics.models.sam.predict.Predictor'>}
  """
- from ..sam2.predict import SAM2Predictor
-
  return {"segment": {"predictor": SAM2Predictor if self.is_sam2 else Predictor}}