ultralytics 8.2.69__py3-none-any.whl → 8.2.71__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (35)
  1. ultralytics/__init__.py +3 -2
  2. ultralytics/cfg/__init__.py +4 -0
  3. ultralytics/data/converter.py +81 -0
  4. ultralytics/engine/trainer.py +3 -2
  5. ultralytics/engine/validator.py +2 -2
  6. ultralytics/models/__init__.py +2 -1
  7. ultralytics/models/fastsam/predict.py +1 -0
  8. ultralytics/models/sam/build.py +2 -2
  9. ultralytics/models/sam/model.py +10 -2
  10. ultralytics/models/sam/modules/decoders.py +1 -42
  11. ultralytics/models/sam/modules/encoders.py +3 -1
  12. ultralytics/models/sam/modules/sam.py +5 -7
  13. ultralytics/models/sam/modules/transformer.py +4 -3
  14. ultralytics/models/sam/predict.py +12 -6
  15. ultralytics/models/sam2/__init__.py +6 -0
  16. ultralytics/models/sam2/build.py +156 -0
  17. ultralytics/models/sam2/model.py +97 -0
  18. ultralytics/models/sam2/modules/__init__.py +1 -0
  19. ultralytics/models/sam2/modules/decoders.py +305 -0
  20. ultralytics/models/sam2/modules/encoders.py +332 -0
  21. ultralytics/models/sam2/modules/memory_attention.py +170 -0
  22. ultralytics/models/sam2/modules/sam2.py +804 -0
  23. ultralytics/models/sam2/modules/sam2_blocks.py +715 -0
  24. ultralytics/models/sam2/modules/utils.py +191 -0
  25. ultralytics/models/sam2/predict.py +182 -0
  26. ultralytics/nn/modules/transformer.py +5 -3
  27. ultralytics/utils/__init__.py +9 -9
  28. ultralytics/utils/plotting.py +1 -1
  29. ultralytics/utils/torch_utils.py +11 -7
  30. {ultralytics-8.2.69.dist-info → ultralytics-8.2.71.dist-info}/METADATA +1 -1
  31. {ultralytics-8.2.69.dist-info → ultralytics-8.2.71.dist-info}/RECORD +35 -24
  32. {ultralytics-8.2.69.dist-info → ultralytics-8.2.71.dist-info}/LICENSE +0 -0
  33. {ultralytics-8.2.69.dist-info → ultralytics-8.2.71.dist-info}/WHEEL +0 -0
  34. {ultralytics-8.2.69.dist-info → ultralytics-8.2.71.dist-info}/entry_points.txt +0 -0
  35. {ultralytics-8.2.69.dist-info → ultralytics-8.2.71.dist-info}/top_level.txt +0 -0
ultralytics/__init__.py CHANGED
@@ -1,6 +1,6 @@
  # Ultralytics YOLO 🚀, AGPL-3.0 license

- __version__ = "8.2.69"
+ __version__ = "8.2.71"

  import os

@@ -8,7 +8,7 @@ import os
  os.environ["OMP_NUM_THREADS"] = "1" # reduce CPU utilization during training

  from ultralytics.data.explorer.explorer import Explorer
- from ultralytics.models import NAS, RTDETR, SAM, YOLO, FastSAM, YOLOWorld
+ from ultralytics.models import NAS, RTDETR, SAM, SAM2, YOLO, FastSAM, YOLOWorld
  from ultralytics.utils import ASSETS, SETTINGS
  from ultralytics.utils.checks import check_yolo as checks
  from ultralytics.utils.downloads import download
@@ -21,6 +21,7 @@ __all__ = (
  "YOLOWorld",
  "NAS",
  "SAM",
+ "SAM2",
  "FastSAM",
  "RTDETR",
  "checks",
ultralytics/cfg/__init__.py CHANGED
@@ -793,6 +793,10 @@ def entrypoint(debug=""):
  from ultralytics import FastSAM

  model = FastSAM(model)
+ elif "sam2" in stem:
+ from ultralytics import SAM2
+
+ model = SAM2(model)
  elif "sam" in stem:
  from ultralytics import SAM

ultralytics/data/converter.py CHANGED
@@ -334,6 +334,87 @@ def convert_coco(
  LOGGER.info(f"{'LVIS' if lvis else 'COCO'} data converted successfully.\nResults saved to {save_dir.resolve()}")


+ def convert_segment_masks_to_yolo_seg(masks_dir, output_dir, classes):
+ """
+ Converts a dataset of segmentation mask images to the YOLO segmentation format.
+
+ This function takes the directory containing the binary format mask images and converts them into YOLO segmentation format.
+ The converted masks are saved in the specified output directory.
+
+ Args:
+ masks_dir (str): The path to the directory where all mask images (png, jpg) are stored.
+ output_dir (str): The path to the directory where the converted YOLO segmentation masks will be stored.
+ classes (int): Total classes in the dataset i.e for COCO classes=80
+
+ Example:
+ ```python
+ from ultralytics.data.converter import convert_segment_masks_to_yolo_seg
+
+ # for coco dataset, we have 80 classes
+ convert_segment_masks_to_yolo_seg('path/to/masks_directory', 'path/to/output/directory', classes=80)
+ ```
+
+ Notes:
+ The expected directory structure for the masks is:
+
+ - masks
+ ├─ mask_image_01.png or mask_image_01.jpg
+ ├─ mask_image_02.png or mask_image_02.jpg
+ ├─ mask_image_03.png or mask_image_03.jpg
+ └─ mask_image_04.png or mask_image_04.jpg
+
+ After execution, the labels will be organized in the following structure:
+
+ - output_dir
+ ├─ mask_yolo_01.txt
+ ├─ mask_yolo_02.txt
+ ├─ mask_yolo_03.txt
+ └─ mask_yolo_04.txt
+ """
+ import os
+
+ pixel_to_class_mapping = {i + 1: i for i in range(80)}
+ for mask_filename in os.listdir(masks_dir):
+ if mask_filename.endswith(".png"):
+ mask_path = os.path.join(masks_dir, mask_filename)
+ mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE) # Read the mask image in grayscale
+ img_height, img_width = mask.shape # Get image dimensions
+ LOGGER.info(f"Processing {mask_path} imgsz = {img_height} x {img_width}")
+
+ unique_values = np.unique(mask) # Get unique pixel values representing different classes
+ yolo_format_data = []
+
+ for value in unique_values:
+ if value == 0:
+ continue # Skip background
+ class_index = pixel_to_class_mapping.get(value, -1)
+ if class_index == -1:
+ LOGGER.warning(f"Unknown class for pixel value {value} in file {mask_filename}, skipping.")
+ continue
+
+ # Create a binary mask for the current class and find contours
+ contours, _ = cv2.findContours(
+ (mask == value).astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
+ ) # Find contours
+
+ for contour in contours:
+ if len(contour) >= 3: # YOLO requires at least 3 points for a valid segmentation
+ contour = contour.squeeze() # Remove single-dimensional entries
+ yolo_format = [class_index]
+ for point in contour:
+ # Normalize the coordinates
+ yolo_format.append(round(point[0] / img_width, 6)) # Rounding to 6 decimal places
+ yolo_format.append(round(point[1] / img_height, 6))
+ yolo_format_data.append(yolo_format)
+ # Save Ultralytics YOLO format data to file
+ output_path = os.path.join(output_dir, os.path.splitext(mask_filename)[0] + ".txt")
+ with open(output_path, "w") as file:
+ for item in yolo_format_data:
+ line = " ".join(map(str, item))
+ file.write(line + "\n")
+ LOGGER.info(f"Processed and stored at {output_path} imgsz = {img_height} x {img_width}")
+
+
  def convert_dota_to_yolo_obb(dota_root_path: str):
  """
  Converts DOTA dataset annotations to YOLO OBB (Oriented Bounding Box) format.
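
For a concrete picture of what the new converter emits, here is a self-contained sketch (directory names are placeholders; note that in this release the pixel-to-class mapping is hardcoded to 80 entries, pixel value N -> class index N-1, and the `classes` argument is not otherwise used in the function body):

```python
import os

import cv2
import numpy as np

from ultralytics.data.converter import convert_segment_masks_to_yolo_seg

os.makedirs("demo_masks", exist_ok=True)
os.makedirs("demo_labels", exist_ok=True)

# Dummy 640x640 mask: pixel value 1 -> class index 0, pixel value 2 -> class index 1
mask = np.zeros((640, 640), dtype=np.uint8)
mask[100:200, 100:200] = 1
mask[300:450, 300:500] = 2
cv2.imwrite("demo_masks/mask_image_01.png", mask)

convert_segment_masks_to_yolo_seg("demo_masks", "demo_labels", classes=2)
# demo_labels/mask_image_01.txt now holds one line per polygon:
# "<class-index> x1 y1 x2 y2 ..." with coordinates normalized to [0, 1]
```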
ultralytics/engine/trainer.py CHANGED
@@ -26,6 +26,7 @@ from ultralytics.data.utils import check_cls_dataset, check_det_dataset
  from ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights
  from ultralytics.utils import (
  DEFAULT_CFG,
+ LOCAL_RANK,
  LOGGER,
  RANK,
  TQDM,
@@ -129,7 +130,7 @@ class BaseTrainer:

  # Model and Dataset
  self.model = check_model_file_from_stem(self.args.model) # add suffix, i.e. yolov8n -> yolov8n.pt
- with torch_distributed_zero_first(RANK): # avoid auto-downloading dataset multiple times
+ with torch_distributed_zero_first(LOCAL_RANK): # avoid auto-downloading dataset multiple times
  self.trainset, self.testset = self.get_dataset()
  self.ema = None

@@ -285,7 +286,7 @@ class BaseTrainer:

  # Dataloaders
  batch_size = self.batch_size // max(world_size, 1)
- self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=RANK, mode="train")
+ self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=LOCAL_RANK, mode="train")
  if RANK in {-1, 0}:
  # Note: When training DOTA dataset, double batch size could get OOM on images with >2000 objects.
  self.test_loader = self.get_dataloader(
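
The trainer now gates dataset preparation and the train dataloader on LOCAL_RANK rather than the global RANK. Both values come from the standard torchrun environment variables: RANK is the process index across all nodes, LOCAL_RANK the index on the current node, so rank 0 on each node runs the block first while its node-local peers wait at a barrier. A minimal sketch of the pattern (`prepare_dataset` is a hypothetical stand-in, not an Ultralytics API):

```python
import os

from ultralytics.utils.torch_utils import torch_distributed_zero_first

RANK = int(os.getenv("RANK", -1))              # global process index across all nodes
LOCAL_RANK = int(os.getenv("LOCAL_RANK", -1))  # process index within the current node


def prepare_dataset():
    """Hypothetical stand-in for dataset download/verification work."""
    return "coco8.yaml"


with torch_distributed_zero_first(LOCAL_RANK):
    # local rank 0 on each node runs this first; other local ranks wait, then reuse its files
    data = prepare_dataset()
```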
ultralytics/engine/validator.py CHANGED
@@ -136,8 +136,8 @@ class BaseValidator:
  if engine:
  self.args.batch = model.batch_size
  elif not pt and not jit:
- self.args.batch = 1 # export.py models default to batch-size 1
- LOGGER.info(f"Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models")
+ self.args.batch = model.metadata.get("batch", 1) # export.py models default to batch-size 1
+ LOGGER.info(f"Setting batch={self.args.batch} input of shape ({self.args.batch}, 3, {imgsz}, {imgsz})")

  if str(self.args.data).split(".")[-1] in {"yaml", "yml"}:
  self.data = check_det_dataset(self.args.data)
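
Validation of non-PyTorch backends is no longer pinned to batch=1; when the exported model's metadata carries a batch value, the validator adopts it. A sketch of the round trip (assumes an ONNX export; `coco8.yaml` is the small built-in demo dataset):

```python
from ultralytics import YOLO

YOLO("yolov8n.pt").export(format="onnx", batch=2)  # export records batch=2 in the model metadata
metrics = YOLO("yolov8n.onnx").val(data="coco8.yaml", imgsz=640)
# the validator now reads batch=2 from metadata instead of forcing batch=1
```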
ultralytics/models/__init__.py CHANGED
@@ -4,6 +4,7 @@ from .fastsam import FastSAM
  from .nas import NAS
  from .rtdetr import RTDETR
  from .sam import SAM
+ from .sam2 import SAM2
  from .yolo import YOLO, YOLOWorld

- __all__ = "YOLO", "RTDETR", "SAM", "FastSAM", "NAS", "YOLOWorld" # allow simpler import
+ __all__ = "YOLO", "RTDETR", "SAM", "FastSAM", "NAS", "YOLOWorld", "SAM2" # allow simpler import
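
With these exports, SAM2 is importable from both `ultralytics.models` and the top-level `ultralytics` package. A minimal usage sketch (the checkpoint name matches the entries registered in `ultralytics/models/sam2/build.py`; the prompt keywords are assumed to mirror the existing SAM wrapper shown further down, and `image.jpg` is a placeholder):

```python
from ultralytics import SAM2

model = SAM2("sam2_b.pt")                                    # built via build_sam2()
results = model("image.jpg", bboxes=[100, 150, 300, 400])    # box prompt
results = model("image.jpg", points=[250, 300], labels=[1])  # point prompt
```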
ultralytics/models/fastsam/predict.py CHANGED
@@ -21,6 +21,7 @@ class FastSAMPredictor(SegmentationPredictor):
  """

  def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+ """Initializes a FastSAMPredictor for fast SAM segmentation tasks in Ultralytics YOLO framework."""
  super().__init__(cfg, overrides, _callbacks)
  self.prompts = {}

ultralytics/models/sam/build.py CHANGED
@@ -14,7 +14,7 @@ from ultralytics.utils.downloads import attempt_download_asset

  from .modules.decoders import MaskDecoder
  from .modules.encoders import ImageEncoderViT, PromptEncoder
- from .modules.sam import Sam
+ from .modules.sam import SAMModel
  from .modules.tiny_encoder import TinyViT
  from .modules.transformer import TwoWayTransformer

@@ -105,7 +105,7 @@ def _build_sam(
  out_chans=prompt_embed_dim,
  )
  )
- sam = Sam(
+ sam = SAMModel(
  image_encoder=image_encoder,
  prompt_encoder=PromptEncoder(
  embed_dim=prompt_embed_dim,
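
Note the `Sam` -> `SAMModel` rename (the class itself is in modules/sam.py below): code that imported the class directly needs the new name, while the public `SAM` wrapper keeps working unchanged.

```python
# from ultralytics.models.sam.modules.sam import Sam   # pre-8.2.71 name
from ultralytics.models.sam.modules.sam import SAMModel  # 8.2.71+
```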
ultralytics/models/sam/model.py CHANGED
@@ -44,6 +44,7 @@ class SAM(Model):
  """
  if model and Path(model).suffix not in {".pt", ".pth"}:
  raise NotImplementedError("SAM prediction requires pre-trained *.pt or *.pth model.")
+ self.is_sam2 = "sam2" in Path(model).stem
  super().__init__(model=model, task="segment")

  def _load(self, weights: str, task=None):
@@ -54,7 +55,12 @@ class SAM(Model):
  weights (str): Path to the weights file.
  task (str, optional): Task name. Defaults to None.
  """
- self.model = build_sam(weights)
+ if self.is_sam2:
+ from ..sam2.build import build_sam2
+
+ self.model = build_sam2(weights)
+ else:
+ self.model = build_sam(weights)

  def predict(self, source, stream=False, bboxes=None, points=None, labels=None, **kwargs):
  """
@@ -112,4 +118,6 @@ class SAM(Model):
  Returns:
  (dict): A dictionary mapping the 'segment' task to its corresponding 'Predictor'.
  """
- return {"segment": {"predictor": Predictor}}
+ from ..sam2.predict import SAM2Predictor
+
+ return {"segment": {"predictor": SAM2Predictor if self.is_sam2 else Predictor}}
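
Because the wrapper now inspects the weight stem, SAM2 checkpoints can also be loaded through the familiar `SAM` class, and the task map hands prediction to `SAM2Predictor`. Sketch (same prompt keywords as classic SAM; the image path is a placeholder):

```python
from ultralytics import SAM

model = SAM("sam2_b.pt")  # "sam2" in the stem -> build_sam2() + SAM2Predictor
results = model("image.jpg", points=[[250, 300]], labels=[1])
```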
ultralytics/models/sam/modules/decoders.py CHANGED
@@ -4,9 +4,8 @@ from typing import List, Tuple, Type

  import torch
  from torch import nn
- from torch.nn import functional as F

- from ultralytics.nn.modules import LayerNorm2d
+ from ultralytics.nn.modules import MLP, LayerNorm2d


  class MaskDecoder(nn.Module):
@@ -28,7 +27,6 @@ class MaskDecoder(nn.Module):

  def __init__(
  self,
- *,
  transformer_dim: int,
  transformer: nn.Module,
  num_multimask_outputs: int = 3,
@@ -149,42 +147,3 @@ class MaskDecoder(nn.Module):
  iou_pred = self.iou_prediction_head(iou_token_out)

  return masks, iou_pred
-
-
- class MLP(nn.Module):
- """
- MLP (Multi-Layer Perceptron) model lightly adapted from
- https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py
- """
-
- def __init__(
- self,
- input_dim: int,
- hidden_dim: int,
- output_dim: int,
- num_layers: int,
- sigmoid_output: bool = False,
- ) -> None:
- """
- Initializes the MLP (Multi-Layer Perceptron) model.
-
- Args:
- input_dim (int): The dimensionality of the input features.
- hidden_dim (int): The dimensionality of the hidden layers.
- output_dim (int): The dimensionality of the output layer.
- num_layers (int): The number of hidden layers.
- sigmoid_output (bool, optional): Apply a sigmoid activation to the output layer. Defaults to False.
- """
- super().__init__()
- self.num_layers = num_layers
- h = [hidden_dim] * (num_layers - 1)
- self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
- self.sigmoid_output = sigmoid_output
-
- def forward(self, x):
- """Executes feedforward within the neural network module and applies activation."""
- for i, layer in enumerate(self.layers):
- x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
- if self.sigmoid_output:
- x = torch.sigmoid(x)
- return x
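
The decoder-local MLP class is gone; the decoder now pulls the shared implementation from `ultralytics.nn.modules`, which this release exports. For downstream code the switch looks roughly like this (positional arguments shown because they match the removed class; keyword names in the shared module may differ):

```python
from ultralytics.nn.modules import MLP

# (input_dim, hidden_dim, output_dim, num_layers), as in the removed decoder-local class
head = MLP(256, 256, 4, 3)
```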
ultralytics/models/sam/modules/encoders.py CHANGED
@@ -211,6 +211,8 @@ class PromptEncoder(nn.Module):
  point_embedding[labels == -1] += self.not_a_point_embed.weight
  point_embedding[labels == 0] += self.point_embeddings[0].weight
  point_embedding[labels == 1] += self.point_embeddings[1].weight
+ point_embedding[labels == 2] += self.point_embeddings[2].weight
+ point_embedding[labels == 3] += self.point_embeddings[3].weight
  return point_embedding

  def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
@@ -226,8 +228,8 @@ class PromptEncoder(nn.Module):
  """Embeds mask inputs."""
  return self.mask_downscaling(masks)

+ @staticmethod
  def _get_batch_size(
- self,
  points: Optional[Tuple[torch.Tensor, torch.Tensor]],
  boxes: Optional[torch.Tensor],
  masks: Optional[torch.Tensor],
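
_embed_points now also recognizes labels 2 and 3. Reading the diff against SAM2 conventions, these appear to be the two corners of a box prompt encoded as labelled points rather than through `_embed_boxes`; the layout would look like this (an illustration, not an API shown in the diff):

```python
import torch

# one box expressed as two labelled points: top-left (label 2) and bottom-right (label 3)
points = torch.tensor([[[100.0, 150.0], [300.0, 400.0]]])  # (B, N, 2)
labels = torch.tensor([[2, 3]])                            # (B, N)
```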
ultralytics/models/sam/modules/sam.py CHANGED
@@ -15,15 +15,14 @@ from .decoders import MaskDecoder
  from .encoders import ImageEncoderViT, PromptEncoder


- class Sam(nn.Module):
+ class SAMModel(nn.Module):
  """
- Sam (Segment Anything Model) is designed for object segmentation tasks. It uses image encoders to generate image
- embeddings, and prompt encoders to encode various types of input prompts. These embeddings are then used by the mask
- decoder to predict object masks.
+ SAMModel (Segment Anything Model) is designed for object segmentation tasks. It uses image encoders to generate
+ image embeddings, and prompt encoders to encode various types of input prompts. These embeddings are then used by
+ the mask decoder to predict object masks.

  Attributes:
  mask_threshold (float): Threshold value for mask prediction.
- image_format (str): Format of the input image, default is 'RGB'.
  image_encoder (ImageEncoderViT): The backbone used to encode the image into embeddings.
  prompt_encoder (PromptEncoder): Encodes various types of input prompts.
  mask_decoder (MaskDecoder): Predicts object masks from the image and prompt embeddings.
@@ -32,7 +31,6 @@ class Sam(nn.Module):
  """

  mask_threshold: float = 0.0
- image_format: str = "RGB"

  def __init__(
  self,
@@ -43,7 +41,7 @@ class Sam(nn.Module):
  pixel_std: List[float] = (58.395, 57.12, 57.375),
  ) -> None:
  """
- Initialize the Sam class to predict object masks from an image and input prompts.
+ Initialize the SAMModel class to predict object masks from an image and input prompts.

  Note:
  All forward() operations moved to SAMPredictor.
ultralytics/models/sam/modules/transformer.py CHANGED
@@ -86,7 +86,6 @@ class TwoWayTransformer(nn.Module):
  (torch.Tensor): the processed image_embedding
  """
  # BxCxHxW -> BxHWxC == B x N_image_tokens x C
- bs, c, h, w = image_embedding.shape
  image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
  image_pe = image_pe.flatten(2).permute(0, 2, 1)

@@ -212,6 +211,7 @@ class Attention(nn.Module):
  embedding_dim: int,
  num_heads: int,
  downsample_rate: int = 1,
+ kv_in_dim: int = None,
  ) -> None:
  """
  Initializes the Attention model with the given dimensions and settings.
@@ -226,13 +226,14 @@ class Attention(nn.Module):
  """
  super().__init__()
  self.embedding_dim = embedding_dim
+ self.kv_in_dim = kv_in_dim if kv_in_dim is not None else embedding_dim
  self.internal_dim = embedding_dim // downsample_rate
  self.num_heads = num_heads
  assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."

  self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
- self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
- self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
+ self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
+ self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
  self.out_proj = nn.Linear(self.internal_dim, embedding_dim)

  @staticmethod
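
`Attention` can now project keys and values from a channel width different from the query embedding, the kind of asymmetric cross-attention SAM2's memory attention needs (its memory features are 64-dim, matching `MemoryEncoder(out_dim=64)` in the new build.py). A quick sketch of the new keyword (tensor shapes are illustrative):

```python
import torch

from ultralytics.models.sam.modules.transformer import Attention

attn = Attention(embedding_dim=256, num_heads=8, kv_in_dim=64)
q = torch.randn(1, 4096, 256)  # query tokens at the transformer width
kv = torch.randn(1, 4096, 64)  # e.g. 64-dim memory features
out = attn(q, kv, kv)          # -> (1, 4096, 256)
```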
ultralytics/models/sam/predict.py CHANGED
@@ -168,7 +168,7 @@ class Predictor(BasePredictor):
  - np.ndarray: An array of length C containing quality scores predicted by the model for each mask.
  - np.ndarray: Low-resolution logits of shape CxHxW for subsequent inference, where H=W=256.
  """
- features = self.model.image_encoder(im) if self.features is None else self.features
+ features = self.get_im_features(im) if self.features is None else self.features

  src_shape, dst_shape = self.batch[1][0].shape[:2], im.shape[2:]
  r = 1.0 if self.segment_all else min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1])
@@ -334,7 +334,7 @@ class Predictor(BasePredictor):
  """
  device = select_device(self.args.device, verbose=verbose)
  if model is None:
- model = build_sam(self.args.model)
+ model = self.get_model()
  model.eval()
  self.model = model.to(device)
  self.device = device
@@ -348,6 +348,10 @@ class Predictor(BasePredictor):
  self.model.fp16 = False
  self.done_warmup = True

+ def get_model(self):
+ """Built Segment Anything Model (SAM) model."""
+ return build_sam(self.args.model)
+
  def postprocess(self, preds, img, orig_imgs):
  """
  Post-processes SAM's inference outputs to generate object detection masks and bounding boxes.
@@ -412,16 +416,18 @@ class Predictor(BasePredictor):
  AssertionError: If more than one image is set.
  """
  if self.model is None:
- model = build_sam(self.args.model)
- self.setup_model(model)
+ self.setup_model(model=None)
  self.setup_source(image)
  assert len(self.dataset) == 1, "`set_image` only supports setting one image!"
  for batch in self.dataset:
  im = self.preprocess(batch[1])
- self.features = self.model.image_encoder(im)
- self.im = im
+ self.features = self.get_im_features(im)
  break

+ def get_im_features(self, im):
+ """Get image features from the SAM image encoder."""
+ return self.model.image_encoder(im)
+
  def set_prompts(self, prompts):
  """Set prompts in advance."""
  self.prompts = prompts
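
`Predictor` now routes model construction and image-feature extraction through two small hooks, `get_model()` and `get_im_features()`, so a subclass only overrides those instead of re-implementing `setup_model`/`set_image`; this is presumably the seam the new `SAM2Predictor` plugs into. A sketch of the extension pattern (`MyPredictor` and its bodies are illustrative only):

```python
from ultralytics.models.sam.predict import Predictor


class MyPredictor(Predictor):
    def get_model(self):
        """Return a custom SAM-style model instead of build_sam(self.args.model)."""
        ...

    def get_im_features(self, im):
        """Return whatever embeddings the custom model's image encoder produces."""
        ...
```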
ultralytics/models/sam2/__init__.py ADDED
@@ -0,0 +1,6 @@
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
+
+ from .model import SAM2
+ from .predict import SAM2Predictor
+
+ __all__ = "SAM2", "SAM2Predictor" # tuple or list
ultralytics/models/sam2/build.py ADDED
@@ -0,0 +1,156 @@
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
+
+ import torch
+
+ from ultralytics.utils.downloads import attempt_download_asset
+
+ from .modules.encoders import FpnNeck, Hiera, ImageEncoder, MemoryEncoder
+ from .modules.memory_attention import MemoryAttention, MemoryAttentionLayer
+ from .modules.sam2 import SAM2Model
+
+
+ def build_sam2_t(checkpoint=None):
+ """Build and return a Segment Anything Model (SAM2) tiny-size model with specified architecture parameters."""
+ return _build_sam2(
+ encoder_embed_dim=96,
+ encoder_stages=[1, 2, 7, 2],
+ encoder_num_heads=1,
+ encoder_global_att_blocks=[5, 7, 9],
+ encoder_window_spec=[8, 4, 14, 7],
+ encoder_backbone_channel_list=[768, 384, 192, 96],
+ checkpoint=checkpoint,
+ )
+
+
+ def build_sam2_s(checkpoint=None):
+ """Builds and returns a small-size Segment Anything Model (SAM2) with specified architecture parameters."""
+ return _build_sam2(
+ encoder_embed_dim=96,
+ encoder_stages=[1, 2, 11, 2],
+ encoder_num_heads=1,
+ encoder_global_att_blocks=[7, 10, 13],
+ encoder_window_spec=[8, 4, 14, 7],
+ encoder_backbone_channel_list=[768, 384, 192, 96],
+ checkpoint=checkpoint,
+ )
+
+
+ def build_sam2_b(checkpoint=None):
+ """Builds and returns a Segment Anything Model (SAM2) base-size model with specified architecture parameters."""
+ return _build_sam2(
+ encoder_embed_dim=112,
+ encoder_stages=[2, 3, 16, 3],
+ encoder_num_heads=2,
+ encoder_global_att_blocks=[12, 16, 20],
+ encoder_window_spec=[8, 4, 14, 7],
+ encoder_window_spatial_size=[14, 14],
+ encoder_backbone_channel_list=[896, 448, 224, 112],
+ checkpoint=checkpoint,
+ )
+
+
+ def build_sam2_l(checkpoint=None):
+ """Build and return a Segment Anything Model (SAM2) large-size model with specified architecture parameters."""
+ return _build_sam2(
+ encoder_embed_dim=144,
+ encoder_stages=[2, 6, 36, 4],
+ encoder_num_heads=2,
+ encoder_global_att_blocks=[23, 33, 43],
+ encoder_window_spec=[8, 4, 16, 8],
+ encoder_backbone_channel_list=[1152, 576, 288, 144],
+ checkpoint=checkpoint,
+ )
+
+
+ def _build_sam2(
+ encoder_embed_dim=1280,
+ encoder_stages=[2, 6, 36, 4],
+ encoder_num_heads=2,
+ encoder_global_att_blocks=[7, 15, 23, 31],
+ encoder_backbone_channel_list=[1152, 576, 288, 144],
+ encoder_window_spatial_size=[7, 7],
+ encoder_window_spec=[8, 4, 16, 8],
+ checkpoint=None,
+ ):
+ """Builds a SAM2 model with specified architecture parameters and optional checkpoint loading."""
+ image_encoder = ImageEncoder(
+ trunk=Hiera(
+ embed_dim=encoder_embed_dim,
+ num_heads=encoder_num_heads,
+ stages=encoder_stages,
+ global_att_blocks=encoder_global_att_blocks,
+ window_pos_embed_bkg_spatial_size=encoder_window_spatial_size,
+ window_spec=encoder_window_spec,
+ ),
+ neck=FpnNeck(
+ d_model=256,
+ backbone_channel_list=encoder_backbone_channel_list,
+ fpn_top_down_levels=[2, 3],
+ fpn_interp_model="nearest",
+ ),
+ scalp=1,
+ )
+ memory_attention = MemoryAttention(d_model=256, pos_enc_at_input=True, num_layers=4, layer=MemoryAttentionLayer())
+ memory_encoder = MemoryEncoder(out_dim=64)
+
+ sam2 = SAM2Model(
+ image_encoder=image_encoder,
+ memory_attention=memory_attention,
+ memory_encoder=memory_encoder,
+ num_maskmem=7,
+ image_size=1024,
+ sigmoid_scale_for_mem_enc=20.0,
+ sigmoid_bias_for_mem_enc=-10.0,
+ use_mask_input_as_output_without_sam=True,
+ directly_add_no_mem_embed=True,
+ use_high_res_features_in_sam=True,
+ multimask_output_in_sam=True,
+ iou_prediction_use_sigmoid=True,
+ use_obj_ptrs_in_encoder=True,
+ add_tpos_enc_to_obj_ptrs=True,
+ only_obj_ptrs_in_the_past_for_eval=True,
+ pred_obj_scores=True,
+ pred_obj_scores_mlp=True,
+ fixed_no_obj_ptr=True,
+ multimask_output_for_tracking=True,
+ use_multimask_token_for_obj_ptr=True,
+ multimask_min_pt_num=0,
+ multimask_max_pt_num=1,
+ use_mlp_for_obj_ptr_proj=True,
+ compile_image_encoder=False,
+ sam_mask_decoder_extra_args=dict(
+ dynamic_multimask_via_stability=True,
+ dynamic_multimask_stability_delta=0.05,
+ dynamic_multimask_stability_thresh=0.98,
+ ),
+ )
+
+ if checkpoint is not None:
+ checkpoint = attempt_download_asset(checkpoint)
+ with open(checkpoint, "rb") as f:
+ state_dict = torch.load(f)["model"]
+ sam2.load_state_dict(state_dict)
+ sam2.eval()
+ return sam2
+
+
+ sam_model_map = {
+ "sam2_t.pt": build_sam2_t,
+ "sam2_s.pt": build_sam2_s,
+ "sam2_b.pt": build_sam2_b,
+ "sam2_l.pt": build_sam2_l,
+ }
+
+
+ def build_sam2(ckpt="sam_b.pt"):
+ """Constructs a Segment Anything Model (SAM2) based on the specified checkpoint, with various size options."""
+ model_builder = None
+ ckpt = str(ckpt) # to allow Path ckpt types
+ for k in sam_model_map.keys():
+ if ckpt.endswith(k):
+ model_builder = sam_model_map.get(k)
+
+ if not model_builder:
+ raise FileNotFoundError(f"{ckpt} is not a supported SAM model. Available models are: \n {sam_model_map.keys()}")
+
+ return model_builder(ckpt)
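
The builder can also be called directly when the `SAM2`/`SAM` wrappers are not wanted. A minimal sketch (note that the default `ckpt="sam_b.pt"` does not match any key in `sam_model_map`, so in this release one of the four `sam2_*.pt` names must be passed explicitly):

```python
from ultralytics.models.sam2.build import build_sam2

model = build_sam2("sam2_b.pt")  # downloads the asset if missing, loads weights, calls eval()
print(sum(p.numel() for p in model.parameters()))
```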