ultralytics-8.2.72-py3-none-any.whl → ultralytics-8.2.74-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ultralytics might be problematic. See the registry's page for this release for more details.

Files changed (34):
  1. ultralytics/__init__.py +2 -3
  2. ultralytics/cfg/trackers/botsort.yaml +1 -1
  3. ultralytics/cfg/trackers/bytetrack.yaml +1 -1
  4. ultralytics/models/__init__.py +1 -2
  5. ultralytics/models/sam/__init__.py +2 -2
  6. ultralytics/models/sam/amg.py +27 -21
  7. ultralytics/models/sam/build.py +200 -9
  8. ultralytics/models/sam/model.py +86 -34
  9. ultralytics/models/sam/modules/blocks.py +1131 -0
  10. ultralytics/models/sam/modules/decoders.py +390 -23
  11. ultralytics/models/sam/modules/encoders.py +508 -323
  12. ultralytics/models/{sam2 → sam}/modules/memory_attention.py +73 -6
  13. ultralytics/models/sam/modules/sam.py +887 -16
  14. ultralytics/models/sam/modules/tiny_encoder.py +376 -126
  15. ultralytics/models/sam/modules/transformer.py +155 -54
  16. ultralytics/models/{sam2 → sam}/modules/utils.py +105 -3
  17. ultralytics/models/sam/predict.py +382 -92
  18. ultralytics/trackers/bot_sort.py +2 -3
  19. ultralytics/trackers/byte_tracker.py +2 -3
  20. {ultralytics-8.2.72.dist-info → ultralytics-8.2.74.dist-info}/METADATA +44 -44
  21. {ultralytics-8.2.72.dist-info → ultralytics-8.2.74.dist-info}/RECORD +25 -33
  22. ultralytics/models/sam2/__init__.py +0 -6
  23. ultralytics/models/sam2/build.py +0 -156
  24. ultralytics/models/sam2/model.py +0 -97
  25. ultralytics/models/sam2/modules/__init__.py +0 -1
  26. ultralytics/models/sam2/modules/decoders.py +0 -305
  27. ultralytics/models/sam2/modules/encoders.py +0 -332
  28. ultralytics/models/sam2/modules/sam2.py +0 -804
  29. ultralytics/models/sam2/modules/sam2_blocks.py +0 -715
  30. ultralytics/models/sam2/predict.py +0 -177
  31. {ultralytics-8.2.72.dist-info → ultralytics-8.2.74.dist-info}/LICENSE +0 -0
  32. {ultralytics-8.2.72.dist-info → ultralytics-8.2.74.dist-info}/WHEEL +0 -0
  33. {ultralytics-8.2.72.dist-info → ultralytics-8.2.74.dist-info}/entry_points.txt +0 -0
  34. {ultralytics-8.2.72.dist-info → ultralytics-8.2.74.dist-info}/top_level.txt +0 -0
@@ -34,35 +34,64 @@ from .build import build_sam
34
34
 
35
35
  class Predictor(BasePredictor):
36
36
  """
37
- Predictor class for the Segment Anything Model (SAM), extending BasePredictor.
37
+ Predictor class for SAM, enabling real-time image segmentation with promptable capabilities.
38
38
 
39
- The class provides an interface for model inference tailored to image segmentation tasks.
40
- With advanced architecture and promptable segmentation capabilities, it facilitates flexible and real-time
41
- mask generation. The class is capable of working with various types of prompts such as bounding boxes,
42
- points, and low-resolution masks.
39
+ This class extends BasePredictor and implements the Segment Anything Model (SAM) for advanced image
40
+ segmentation tasks. It supports various input prompts like points, bounding boxes, and masks for
41
+ fine-grained control over segmentation results.
43
42
 
44
43
  Attributes:
45
- cfg (dict): Configuration dictionary specifying model and task-related parameters.
46
- overrides (dict): Dictionary containing values that override the default configuration.
47
- _callbacks (dict): Dictionary of user-defined callback functions to augment behavior.
48
- args (namespace): Namespace to hold command-line arguments or other operational variables.
49
- im (torch.Tensor): Preprocessed input image tensor.
50
- features (torch.Tensor): Extracted image features used for inference.
51
- prompts (dict): Collection of various prompt types, such as bounding boxes and points.
52
- segment_all (bool): Flag to control whether to segment all objects in the image or only specified ones.
44
+ args (SimpleNamespace): Configuration arguments for the predictor.
45
+ model (torch.nn.Module): The loaded SAM model.
46
+ device (torch.device): The device (CPU or GPU) on which the model is loaded.
47
+ im (torch.Tensor): The preprocessed input image.
48
+ features (torch.Tensor): Extracted image features.
49
+ prompts (Dict): Dictionary to store various types of prompts (e.g., bboxes, points, masks).
50
+ segment_all (bool): Flag to indicate if full image segmentation should be performed.
51
+ mean (torch.Tensor): Mean values for image normalization.
52
+ std (torch.Tensor): Standard deviation values for image normalization.
53
+
54
+ Methods:
55
+ preprocess: Prepares input images for model inference.
56
+ pre_transform: Performs initial transformations on the input image.
57
+ inference: Performs segmentation inference based on input prompts.
58
+ prompt_inference: Internal function for prompt-based segmentation inference.
59
+ generate: Generates segmentation masks for an entire image.
60
+ setup_model: Initializes the SAM model for inference.
61
+ get_model: Builds and returns a SAM model.
62
+ postprocess: Post-processes model outputs to generate final results.
63
+ setup_source: Sets up the data source for inference.
64
+ set_image: Sets and preprocesses a single image for inference.
65
+ get_im_features: Extracts image features using the SAM image encoder.
66
+ set_prompts: Sets prompts for subsequent inference.
67
+ reset_image: Resets the current image and its features.
68
+ remove_small_regions: Removes small disconnected regions and holes from masks.
69
+
70
+ Examples:
71
+ >>> predictor = Predictor()
72
+ >>> predictor.setup_model(model_path='sam_model.pt')
73
+ >>> predictor.set_image('image.jpg')
74
+ >>> masks, scores, boxes = predictor.generate()
75
+ >>> results = predictor.postprocess((masks, scores, boxes), im, orig_img)
53
76
  """
54
77
 
55
78
  def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
56
79
  """
57
80
  Initialize the Predictor with configuration, overrides, and callbacks.
58
81
 
59
- The method sets up the Predictor object and applies any configuration overrides or callbacks provided. It
60
- initializes task-specific settings for SAM, such as retina_masks being set to True for optimal results.
82
+ Sets up the Predictor object for SAM (Segment Anything Model) and applies any configuration overrides or
83
+ callbacks provided. Initializes task-specific settings for SAM, such as retina_masks being set to True
84
+ for optimal results.
61
85
 
62
86
  Args:
63
- cfg (dict): Configuration dictionary.
64
- overrides (dict, optional): Dictionary of values to override default configuration.
65
- _callbacks (dict, optional): Dictionary of callback functions to customize behavior.
87
+ cfg (Dict): Configuration dictionary containing default settings.
88
+ overrides (Dict | None): Dictionary of values to override default configuration.
89
+ _callbacks (Dict | None): Dictionary of callback functions to customize behavior.
90
+
91
+ Examples:
92
+ >>> predictor = Predictor(cfg=DEFAULT_CFG)
93
+ >>> predictor = Predictor(overrides={'imgsz': 640})
94
+ >>> predictor = Predictor(_callbacks={'on_predict_start': custom_callback})
66
95
  """
67
96
  if overrides is None:
68
97
  overrides = {}
@@ -78,14 +107,19 @@ class Predictor(BasePredictor):
78
107
  """
79
108
  Preprocess the input image for model inference.
80
109
 
81
- The method prepares the input image by applying transformations and normalization.
82
- It supports both torch.Tensor and list of np.ndarray as input formats.
110
+ This method prepares the input image by applying transformations and normalization. It supports both
111
+ torch.Tensor and list of np.ndarray as input formats.
83
112
 
84
113
  Args:
85
- im (torch.Tensor | List[np.ndarray]): BCHW tensor format or list of HWC numpy arrays.
114
+ im (torch.Tensor | List[np.ndarray]): Input image(s) in BCHW tensor format or list of HWC numpy arrays.
86
115
 
87
116
  Returns:
88
- (torch.Tensor): The preprocessed image tensor.
117
+ (torch.Tensor): The preprocessed image tensor, normalized and converted to the appropriate dtype.
118
+
119
+ Examples:
120
+ >>> predictor = Predictor()
121
+ >>> image = torch.rand(1, 3, 640, 640)
122
+ >>> preprocessed_image = predictor.preprocess(image)
89
123
  """
90
124
  if self.im is not None:
91
125
  return self.im
@@ -106,14 +140,24 @@ class Predictor(BasePredictor):
106
140
  """
107
141
  Perform initial transformations on the input image for preprocessing.
108
142
 
109
- The method applies transformations such as resizing to prepare the image for further preprocessing.
143
+ This method applies transformations such as resizing to prepare the image for further preprocessing.
110
144
  Currently, batched inference is not supported; hence the list length should be 1.
111
145
 
112
146
  Args:
113
- im (List[np.ndarray]): List containing images in HWC numpy array format.
147
+ im (List[np.ndarray]): List containing a single image in HWC numpy array format.
114
148
 
115
149
  Returns:
116
- (List[np.ndarray]): List of transformed images.
150
+ (List[np.ndarray]): List containing the transformed image.
151
+
152
+ Raises:
153
+ AssertionError: If the input list contains more than one image.
154
+
155
+ Examples:
156
+ >>> predictor = Predictor()
157
+ >>> image = np.random.rand(480, 640, 3) # Single HWC image
158
+ >>> transformed = predictor.pre_transform([image])
159
+ >>> print(len(transformed))
160
+ 1
117
161
  """
118
162
  assert len(im) == 1, "SAM model does not currently support batched inference"
119
163
  letterbox = LetterBox(self.args.imgsz, auto=False, center=False)
@@ -121,23 +165,32 @@ class Predictor(BasePredictor):
121
165
 
122
166
  def inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False, *args, **kwargs):
123
167
  """
124
- Perform image segmentation inference based on the given input cues, using the currently loaded image. This
125
- method leverages SAM's (Segment Anything Model) architecture consisting of image encoder, prompt encoder, and
126
- mask decoder for real-time and promptable segmentation tasks.
168
+ Perform image segmentation inference based on the given input cues, using the currently loaded image.
169
+
170
+ This method leverages SAM's (Segment Anything Model) architecture consisting of image encoder, prompt
171
+ encoder, and mask decoder for real-time and promptable segmentation tasks.
127
172
 
128
173
  Args:
129
174
  im (torch.Tensor): The preprocessed input image in tensor format, with shape (N, C, H, W).
130
- bboxes (np.ndarray | List, optional): Bounding boxes with shape (N, 4), in XYXY format.
131
- points (np.ndarray | List, optional): Points indicating object locations with shape (N, 2), in pixels.
132
- labels (np.ndarray | List, optional): Labels for point prompts, shape (N, ). 1 = foreground, 0 = background.
133
- masks (np.ndarray, optional): Low-resolution masks from previous predictions shape (N,H,W). For SAM H=W=256.
134
- multimask_output (bool, optional): Flag to return multiple masks. Helpful for ambiguous prompts.
175
+ bboxes (np.ndarray | List | None): Bounding boxes with shape (N, 4), in XYXY format.
176
+ points (np.ndarray | List | None): Points indicating object locations with shape (N, 2), in pixels.
177
+ labels (np.ndarray | List | None): Labels for point prompts, shape (N,). 1 = foreground, 0 = background.
178
+ masks (np.ndarray | None): Low-resolution masks from previous predictions, shape (N, H, W). For SAM H=W=256.
179
+ multimask_output (bool): Flag to return multiple masks. Helpful for ambiguous prompts.
180
+ *args (Any): Additional positional arguments.
181
+ **kwargs (Any): Additional keyword arguments.
135
182
 
136
183
  Returns:
137
- (tuple): Contains the following three elements.
138
- - np.ndarray: The output masks in shape CxHxW, where C is the number of generated masks.
184
+ (tuple): Contains the following three elements:
185
+ - np.ndarray: The output masks in shape (C, H, W), where C is the number of generated masks.
139
186
  - np.ndarray: An array of length C containing quality scores predicted by the model for each mask.
140
- - np.ndarray: Low-resolution logits of shape CxHxW for subsequent inference, where H=W=256.
187
+ - np.ndarray: Low-resolution logits of shape (C, H, W) for subsequent inference, where H=W=256.
188
+
189
+ Examples:
190
+ >>> predictor = Predictor()
191
+ >>> predictor.setup_model(model_path='sam_model.pt')
192
+ >>> predictor.set_image('image.jpg')
193
+ >>> masks, scores, logits = predictor.inference(im, bboxes=[[0, 0, 100, 100]])
141
194
  """
142
195
  # Override prompts if any stored in self.prompts
143
196
  bboxes = self.prompts.pop("bboxes", bboxes)
@@ -151,22 +204,30 @@ class Predictor(BasePredictor):
151
204
 
152
205
  def prompt_inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False):
153
206
  """
154
- Internal function for image segmentation inference based on cues like bounding boxes, points, and masks.
155
- Leverages SAM's specialized architecture for prompt-based, real-time segmentation.
207
+ Performs image segmentation inference based on input cues using SAM's specialized architecture.
208
+
209
+ This internal function leverages the Segment Anything Model (SAM) for prompt-based, real-time segmentation.
210
+ It processes various input prompts such as bounding boxes, points, and masks to generate segmentation masks.
156
211
 
157
212
  Args:
158
- im (torch.Tensor): The preprocessed input image in tensor format, with shape (N, C, H, W).
159
- bboxes (np.ndarray | List, optional): Bounding boxes with shape (N, 4), in XYXY format.
160
- points (np.ndarray | List, optional): Points indicating object locations with shape (N, 2), in pixels.
161
- labels (np.ndarray | List, optional): Labels for point prompts, shape (N, ). 1 = foreground, 0 = background.
162
- masks (np.ndarray, optional): Low-resolution masks from previous predictions shape (N,H,W). For SAM H=W=256.
163
- multimask_output (bool, optional): Flag to return multiple masks. Helpful for ambiguous prompts.
213
+ im (torch.Tensor): Preprocessed input image tensor with shape (N, C, H, W).
214
+ bboxes (np.ndarray | List | None): Bounding boxes in XYXY format with shape (N, 4).
215
+ points (np.ndarray | List | None): Points indicating object locations with shape (N, 2), in pixels.
216
+ labels (np.ndarray | List | None): Point prompt labels with shape (N,). 1 for foreground, 0 for background.
217
+ masks (np.ndarray | None): Low-res masks from previous predictions with shape (N, H, W). For SAM, H=W=256.
218
+ multimask_output (bool): Flag to return multiple masks for ambiguous prompts.
164
219
 
165
220
  Returns:
166
- (tuple): Contains the following three elements.
167
- - np.ndarray: The output masks in shape CxHxW, where C is the number of generated masks.
168
- - np.ndarray: An array of length C containing quality scores predicted by the model for each mask.
169
- - np.ndarray: Low-resolution logits of shape CxHxW for subsequent inference, where H=W=256.
221
+ (tuple): Tuple containing:
222
+ - np.ndarray: Output masks with shape (C, H, W), where C is the number of generated masks.
223
+ - np.ndarray: Quality scores predicted by the model for each mask, with length C.
224
+ - np.ndarray: Low-resolution logits with shape (C, H, W) for subsequent inference, where H=W=256.
225
+
226
+ Examples:
227
+ >>> predictor = Predictor()
228
+ >>> im = torch.rand(1, 3, 1024, 1024)
229
+ >>> bboxes = [[100, 100, 200, 200]]
230
+ >>> masks, scores, logits = predictor.prompt_inference(im, bboxes=bboxes)
170
231
  """
171
232
  features = self.get_im_features(im) if self.features is None else self.features
172
233
 
@@ -224,27 +285,32 @@ class Predictor(BasePredictor):
224
285
  """
225
286
  Perform image segmentation using the Segment Anything Model (SAM).
226
287
 
227
- This function segments an entire image into constituent parts by leveraging SAM's advanced architecture
288
+ This method segments an entire image into constituent parts by leveraging SAM's advanced architecture
228
289
  and real-time performance capabilities. It can optionally work on image crops for finer segmentation.
229
290
 
230
291
  Args:
231
- im (torch.Tensor): Input tensor representing the preprocessed image with dimensions (N, C, H, W).
232
- crop_n_layers (int): Specifies the number of layers for additional mask predictions on image crops.
233
- Each layer produces 2**i_layer number of image crops.
234
- crop_overlap_ratio (float): Determines the overlap between crops. Scaled down in subsequent layers.
235
- crop_downscale_factor (int): Scaling factor for the number of sampled points-per-side in each layer.
236
- point_grids (list[np.ndarray], optional): Custom grids for point sampling normalized to [0,1].
237
- Used in the nth crop layer.
238
- points_stride (int, optional): Number of points to sample along each side of the image.
239
- Exclusive with 'point_grids'.
292
+ im (torch.Tensor): Input tensor representing the preprocessed image with shape (N, C, H, W).
293
+ crop_n_layers (int): Number of layers for additional mask predictions on image crops.
294
+ crop_overlap_ratio (float): Overlap between crops, scaled down in subsequent layers.
295
+ crop_downscale_factor (int): Scaling factor for sampled points-per-side in each layer.
296
+ point_grids (List[np.ndarray] | None): Custom grids for point sampling normalized to [0,1].
297
+ points_stride (int): Number of points to sample along each side of the image.
240
298
  points_batch_size (int): Batch size for the number of points processed simultaneously.
241
- conf_thres (float): Confidence threshold [0,1] for filtering based on the model's mask quality prediction.
242
- stability_score_thresh (float): Stability threshold [0,1] for mask filtering based on mask stability.
299
+ conf_thres (float): Confidence threshold [0,1] for filtering based on mask quality prediction.
300
+ stability_score_thresh (float): Stability threshold [0,1] for mask filtering based on stability.
243
301
  stability_score_offset (float): Offset value for calculating stability score.
244
302
  crop_nms_thresh (float): IoU cutoff for NMS to remove duplicate masks between crops.
245
303
 
246
304
  Returns:
247
- (tuple): A tuple containing segmented masks, confidence scores, and bounding boxes.
305
+ (Tuple[torch.Tensor, torch.Tensor, torch.Tensor]): A tuple containing:
306
+ - pred_masks (torch.Tensor): Segmented masks with shape (N, H, W).
307
+ - pred_scores (torch.Tensor): Confidence scores for each mask with shape (N,).
308
+ - pred_bboxes (torch.Tensor): Bounding boxes for each mask with shape (N, 4).
309
+
310
+ Examples:
311
+ >>> predictor = Predictor()
312
+ >>> im = torch.rand(1, 3, 1024, 1024) # Example input image
313
+ >>> masks, scores, boxes = predictor.generate(im)
248
314
  """
249
315
  import torchvision # scope for faster 'import ultralytics'
250
316
 
@@ -326,11 +392,9 @@ class Predictor(BasePredictor):
326
392
  model (torch.nn.Module): A pre-trained SAM model. If None, a model will be built based on configuration.
327
393
  verbose (bool): If True, prints selected device information.
328
394
 
329
- Attributes:
330
- model (torch.nn.Module): The SAM model allocated to the chosen device for inference.
331
- device (torch.device): The device to which the model and tensors are allocated.
332
- mean (torch.Tensor): The mean values for image normalization.
333
- std (torch.Tensor): The standard deviation values for image normalization.
395
+ Examples:
396
+ >>> predictor = Predictor()
397
+ >>> predictor.setup_model(model=sam_model, verbose=True)
334
398
  """
335
399
  device = select_device(self.args.device, verbose=verbose)
336
400
  if model is None:
@@ -349,23 +413,32 @@ class Predictor(BasePredictor):
349
413
  self.done_warmup = True
350
414
 
351
415
  def get_model(self):
352
- """Built Segment Anything Model (SAM) model."""
416
+ """Retrieves or builds the Segment Anything Model (SAM) for image segmentation tasks."""
353
417
  return build_sam(self.args.model)
354
418
 
355
419
  def postprocess(self, preds, img, orig_imgs):
356
420
  """
357
421
  Post-processes SAM's inference outputs to generate object detection masks and bounding boxes.
358
422
 
359
- The method scales masks and boxes to the original image size and applies a threshold to the mask predictions.
360
- The SAM model uses advanced architecture and promptable segmentation tasks to achieve real-time performance.
423
+ This method scales masks and boxes to the original image size and applies a threshold to the mask
424
+ predictions. It leverages SAM's advanced architecture for real-time, promptable segmentation tasks.
361
425
 
362
426
  Args:
363
- preds (tuple): The output from SAM model inference, containing masks, scores, and optional bounding boxes.
364
- img (torch.Tensor): The processed input image tensor.
365
- orig_imgs (list | torch.Tensor): The original, unprocessed images.
427
+ preds (Tuple[torch.Tensor]): The output from SAM model inference, containing:
428
+ - pred_masks (torch.Tensor): Predicted masks with shape (N, 1, H, W).
429
+ - pred_scores (torch.Tensor): Confidence scores for each mask with shape (N, 1).
430
+ - pred_bboxes (torch.Tensor, optional): Predicted bounding boxes if segment_all is True.
431
+ img (torch.Tensor): The processed input image tensor with shape (C, H, W).
432
+ orig_imgs (List[np.ndarray] | torch.Tensor): The original, unprocessed images.
366
433
 
367
434
  Returns:
368
- (list): List of Results objects containing detection masks, bounding boxes, and other metadata.
435
+ (List[Results]): List of Results objects containing detection masks, bounding boxes, and other
436
+ metadata for each processed image.
437
+
438
+ Examples:
439
+ >>> predictor = Predictor()
440
+ >>> preds = predictor.inference(img)
441
+ >>> results = predictor.postprocess(preds, img, orig_imgs)
369
442
  """
370
443
  # (N, 1, H, W), (N, 1)
371
444
  pred_masks, pred_scores = preds[:2]
@@ -393,11 +466,23 @@ class Predictor(BasePredictor):
393
466
  """
394
467
  Sets up the data source for inference.
395
468
 
396
- This method configures the data source from which images will be fetched for inference. The source could be a
397
- directory, a video file, or other types of image data sources.
469
+ This method configures the data source from which images will be fetched for inference. It supports
470
+ various input types such as image files, directories, video files, and other compatible data sources.
398
471
 
399
472
  Args:
400
- source (str | Path): The path to the image data source for inference.
473
+ source (str | Path | None): The path or identifier for the image data source. Can be a file path,
474
+ directory path, URL, or other supported source types.
475
+
476
+ Examples:
477
+ >>> predictor = Predictor()
478
+ >>> predictor.setup_source('path/to/images')
479
+ >>> predictor.setup_source('video.mp4')
480
+ >>> predictor.setup_source(None) # Uses default source if available
481
+
482
+ Notes:
483
+ - If source is None, the method may use a default source if configured.
484
+ - The method adapts to different source types and prepares them for subsequent inference steps.
485
+ - Supported source types may include local files, directories, URLs, and video streams.
401
486
  """
402
487
  if source is not None:
403
488
  super().setup_source(source)
@@ -406,14 +491,25 @@ class Predictor(BasePredictor):
406
491
  """
407
492
  Preprocesses and sets a single image for inference.
408
493
 
409
- This function sets up the model if not already initialized, configures the data source to the specified image,
410
- and preprocesses the image for feature extraction. Only one image can be set at a time.
494
+ This method prepares the model for inference on a single image by setting up the model if not already
495
+ initialized, configuring the data source, and preprocessing the image for feature extraction. It
496
+ ensures that only one image is set at a time and extracts image features for subsequent use.
411
497
 
412
498
  Args:
413
- image (str | np.ndarray): Image file path as a string, or a np.ndarray image read by cv2.
499
+ image (str | np.ndarray): Path to the image file as a string, or a numpy array representing
500
+ an image read by cv2.
414
501
 
415
502
  Raises:
416
- AssertionError: If more than one image is set.
503
+ AssertionError: If more than one image is attempted to be set.
504
+
505
+ Examples:
506
+ >>> predictor = Predictor()
507
+ >>> predictor.set_image('path/to/image.jpg')
508
+ >>> predictor.set_image(cv2.imread('path/to/image.jpg'))
509
+
510
+ Notes:
511
+ - This method should be called before performing inference on a new image.
512
+ - The extracted features are stored in the `self.features` attribute for later use.
417
513
  """
418
514
  if self.model is None:
419
515
  self.setup_model(model=None)
@@ -425,35 +521,44 @@ class Predictor(BasePredictor):
425
521
  break
426
522
 
427
523
  def get_im_features(self, im):
428
- """Get image features from the SAM image encoder."""
524
+ """Extracts image features using the SAM model's image encoder for subsequent mask prediction."""
429
525
  return self.model.image_encoder(im)
430
526
 
431
527
  def set_prompts(self, prompts):
432
- """Set prompts in advance."""
528
+ """Sets prompts for subsequent inference operations."""
433
529
  self.prompts = prompts
434
530
 
435
531
  def reset_image(self):
436
- """Resets the image and its features to None."""
532
+ """Resets the current image and its features, clearing them for subsequent inference."""
437
533
  self.im = None
438
534
  self.features = None
439
535
 
440
536
  @staticmethod
441
537
  def remove_small_regions(masks, min_area=0, nms_thresh=0.7):
442
538
  """
443
- Perform post-processing on segmentation masks generated by the Segment Anything Model (SAM). Specifically, this
444
- function removes small disconnected regions and holes from the input masks, and then performs Non-Maximum
539
+ Remove small disconnected regions and holes from segmentation masks.
540
+
541
+ This function performs post-processing on segmentation masks generated by the Segment Anything Model (SAM).
542
+ It removes small disconnected regions and holes from the input masks, and then performs Non-Maximum
445
543
  Suppression (NMS) to eliminate any newly created duplicate boxes.
446
544
 
447
545
  Args:
448
- masks (torch.Tensor): A tensor containing the masks to be processed. Shape should be (N, H, W), where N is
449
- the number of masks, H is height, and W is width.
450
- min_area (int): The minimum area below which disconnected regions and holes will be removed. Defaults to 0.
451
- nms_thresh (float): The IoU threshold for the NMS algorithm. Defaults to 0.7.
546
+ masks (torch.Tensor): Segmentation masks to be processed, with shape (N, H, W) where N is the number of
547
+ masks, H is height, and W is width.
548
+ min_area (int): Minimum area threshold for removing disconnected regions and holes. Regions smaller than
549
+ this will be removed.
550
+ nms_thresh (float): IoU threshold for the NMS algorithm to remove duplicate boxes.
452
551
 
453
552
  Returns:
454
- (tuple([torch.Tensor, List[int]])):
455
- - new_masks (torch.Tensor): The processed masks with small regions removed. Shape is (N, H, W).
456
- - keep (List[int]): The indices of the remaining masks post-NMS, which can be used to filter the boxes.
553
+ (tuple):
554
+ - new_masks (torch.Tensor): Processed masks with small regions removed, shape (N, H, W).
555
+ - keep (List[int]): Indices of remaining masks after NMS, for filtering corresponding boxes.
556
+
557
+ Examples:
558
+ >>> masks = torch.rand(5, 640, 640) > 0.5 # 5 random binary masks
559
+ >>> new_masks, keep = remove_small_regions(masks, min_area=100, nms_thresh=0.7)
560
+ >>> print(f"Original masks: {masks.shape}, Processed masks: {new_masks.shape}")
561
+ >>> print(f"Indices of kept masks: {keep}")
457
562
  """
458
563
  import torchvision # scope for faster 'import ultralytics'
459
564
 
@@ -480,3 +585,188 @@ class Predictor(BasePredictor):
480
585
  keep = torchvision.ops.nms(boxes.float(), torch.as_tensor(scores), nms_thresh)
481
586
 
482
587
  return new_masks[keep].to(device=masks.device, dtype=masks.dtype), keep
588
+
589
+
590
+ class SAM2Predictor(Predictor):
591
+ """
592
+ SAM2Predictor class for advanced image segmentation using Segment Anything Model 2 architecture.
593
+
594
+ This class extends the base Predictor class to implement SAM2-specific functionality for image
595
+ segmentation tasks. It provides methods for model initialization, feature extraction, and
596
+ prompt-based inference.
597
+
598
+ Attributes:
599
+ _bb_feat_sizes (List[Tuple[int, int]]): Feature sizes for different backbone levels.
600
+ model (torch.nn.Module): The loaded SAM2 model.
601
+ device (torch.device): The device (CPU or GPU) on which the model is loaded.
602
+ features (Dict[str, torch.Tensor]): Cached image features for efficient inference.
603
+ segment_all (bool): Flag to indicate if all segments should be predicted.
604
+ prompts (Dict): Dictionary to store various types of prompts for inference.
605
+
606
+ Methods:
607
+ get_model: Retrieves and initializes the SAM2 model.
608
+ prompt_inference: Performs image segmentation inference based on various prompts.
609
+ set_image: Preprocesses and sets a single image for inference.
610
+ get_im_features: Extracts and processes image features using SAM2's image encoder.
611
+
612
+ Examples:
613
+ >>> predictor = SAM2Predictor(cfg)
614
+ >>> predictor.set_image("path/to/image.jpg")
615
+ >>> bboxes = [[100, 100, 200, 200]]
616
+ >>> masks, scores, _ = predictor.prompt_inference(predictor.im, bboxes=bboxes)
617
+ >>> print(f"Predicted {len(masks)} masks with average score {scores.mean():.2f}")
618
+ """
619
+
620
+ _bb_feat_sizes = [
621
+ (256, 256),
622
+ (128, 128),
623
+ (64, 64),
624
+ ]
625
+
626
+ def get_model(self):
627
+ """Retrieves and initializes the Segment Anything Model 2 (SAM2) for image segmentation tasks."""
628
+ return build_sam(self.args.model)
629
+
630
+ def prompt_inference(
631
+ self,
632
+ im,
633
+ bboxes=None,
634
+ points=None,
635
+ labels=None,
636
+ masks=None,
637
+ multimask_output=False,
638
+ img_idx=-1,
639
+ ):
640
+ """
641
+ Performs image segmentation inference based on various prompts using SAM2 architecture.
642
+
643
+ This method leverages the Segment Anything Model 2 (SAM2) to generate segmentation masks for input images
644
+ based on provided prompts such as bounding boxes, points, or existing masks. It supports both single and
645
+ multi-object prediction scenarios.
646
+
647
+ Args:
648
+ im (torch.Tensor): Preprocessed input image tensor with shape (N, C, H, W).
649
+ bboxes (np.ndarray | List[List[float]] | None): Bounding boxes in XYXY format with shape (N, 4).
650
+ points (np.ndarray | List[List[float]] | None): Object location points with shape (N, 2), in pixels.
651
+ labels (np.ndarray | List[int] | None): Point prompt labels with shape (N,). 1 = foreground, 0 = background.
652
+ masks (np.ndarray | None): Low-resolution masks from previous predictions with shape (N, H, W).
653
+ multimask_output (bool): Flag to return multiple masks for ambiguous prompts.
654
+ img_idx (int): Index of the image in the batch to process.
655
+
656
+ Returns:
657
+ (tuple): Tuple containing:
658
+ - np.ndarray: Output masks with shape (C, H, W), where C is the number of generated masks.
659
+ - np.ndarray: Quality scores for each mask, with length C.
660
+ - np.ndarray: Low-resolution logits with shape (C, 256, 256) for subsequent inference.
661
+
662
+ Examples:
663
+ >>> predictor = SAM2Predictor(cfg)
664
+ >>> image = torch.rand(1, 3, 640, 640)
665
+ >>> bboxes = [[100, 100, 200, 200]]
666
+ >>> masks, scores, logits = predictor.prompt_inference(image, bboxes=bboxes)
667
+ >>> print(f"Generated {masks.shape[0]} masks with average score {scores.mean():.2f}")
668
+
669
+ Notes:
670
+ - The method supports batched inference for multiple objects when points or bboxes are provided.
671
+ - Input prompts (bboxes, points) are automatically scaled to match the input image dimensions.
672
+ - When both bboxes and points are provided, they are merged into a single 'points' input for the model.
673
+
674
+ References:
675
+ - SAM2 Paper: [Add link to SAM2 paper when available]
676
+ """
677
+ features = self.get_im_features(im) if self.features is None else self.features
678
+
679
+ src_shape, dst_shape = self.batch[1][0].shape[:2], im.shape[2:]
680
+ r = 1.0 if self.segment_all else min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1])
681
+ # Transform input prompts
682
+ if points is not None:
683
+ points = torch.as_tensor(points, dtype=torch.float32, device=self.device)
684
+ points = points[None] if points.ndim == 1 else points
685
+ # Assuming labels are all positive if users don't pass labels.
686
+ if labels is None:
687
+ labels = torch.ones(points.shape[0])
688
+ labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device)
689
+ points *= r
690
+ # (N, 2) --> (N, 1, 2), (N, ) --> (N, 1)
691
+ points, labels = points[:, None], labels[:, None]
692
+ if bboxes is not None:
693
+ bboxes = torch.as_tensor(bboxes, dtype=torch.float32, device=self.device)
694
+ bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes
695
+ bboxes = bboxes.view(-1, 2, 2) * r
696
+ bbox_labels = torch.tensor([[2, 3]], dtype=torch.int32, device=bboxes.device).expand(len(bboxes), -1)
697
+ # NOTE: merge "boxes" and "points" into a single "points" input
698
+ # (where boxes are added at the beginning) to model.sam_prompt_encoder
699
+ if points is not None:
700
+ points = torch.cat([bboxes, points], dim=1)
701
+ labels = torch.cat([bbox_labels, labels], dim=1)
702
+ else:
703
+ points, labels = bboxes, bbox_labels
704
+ if masks is not None:
705
+ masks = torch.as_tensor(masks, dtype=torch.float32, device=self.device).unsqueeze(1)
706
+
707
+ points = (points, labels) if points is not None else None
708
+
709
+ sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder(
710
+ points=points,
711
+ boxes=None,
712
+ masks=masks,
713
+ )
714
+ # Predict masks
715
+ batched_mode = points is not None and points[0].shape[0] > 1 # multi object prediction
716
+ high_res_features = [feat_level[img_idx].unsqueeze(0) for feat_level in features["high_res_feats"]]
717
+ pred_masks, pred_scores, _, _ = self.model.sam_mask_decoder(
718
+ image_embeddings=features["image_embed"][img_idx].unsqueeze(0),
719
+ image_pe=self.model.sam_prompt_encoder.get_dense_pe(),
720
+ sparse_prompt_embeddings=sparse_embeddings,
721
+ dense_prompt_embeddings=dense_embeddings,
722
+ multimask_output=multimask_output,
723
+ repeat_image=batched_mode,
724
+ high_res_features=high_res_features,
725
+ )
726
+ # (N, d, H, W) --> (N*d, H, W), (N, d) --> (N*d, )
727
+ # `d` could be 1 or 3 depends on `multimask_output`.
728
+ return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1)
729
+
730
+ def set_image(self, image):
731
+ """
732
+ Preprocesses and sets a single image for inference using the SAM2 model.
733
+
734
+ This method initializes the model if not already done, configures the data source to the specified image,
735
+ and preprocesses the image for feature extraction. It supports setting only one image at a time.
736
+
737
+ Args:
738
+ image (str | np.ndarray): Path to the image file as a string, or a numpy array representing the image.
739
+
740
+ Raises:
741
+ AssertionError: If more than one image is attempted to be set.
742
+
743
+ Examples:
744
+ >>> predictor = SAM2Predictor()
745
+ >>> predictor.set_image("path/to/image.jpg")
746
+ >>> predictor.set_image(np.array([...])) # Using a numpy array
747
+
748
+ Notes:
749
+ - This method must be called before performing any inference on a new image.
750
+ - The method caches the extracted features for efficient subsequent inferences on the same image.
751
+ - Only one image can be set at a time. To process multiple images, call this method for each new image.
752
+ """
753
+ if self.model is None:
754
+ self.setup_model(model=None)
755
+ self.setup_source(image)
756
+ assert len(self.dataset) == 1, "`set_image` only supports setting one image!"
757
+ for batch in self.dataset:
758
+ im = self.preprocess(batch[1])
759
+ self.features = self.get_im_features(im)
760
+ break
761
+
762
+ def get_im_features(self, im):
763
+ """Extracts image features from the SAM image encoder for subsequent processing."""
764
+ backbone_out = self.model.forward_image(im)
765
+ _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out)
766
+ if self.model.directly_add_no_mem_embed:
767
+ vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed
768
+ feats = [
769
+ feat.permute(1, 2, 0).view(1, -1, *feat_size)
770
+ for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
771
+ ][::-1]
772
+ return {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
@@ -179,9 +179,8 @@ class BOTSORT(BYTETracker):
179
179
  dists = matching.iou_distance(tracks, detections)
180
180
  dists_mask = dists > self.proximity_thresh
181
181
 
182
- # TODO: mot20
183
- # if not self.args.mot20:
184
- dists = matching.fuse_score(dists, detections)
182
+ if self.args.fuse_score:
183
+ dists = matching.fuse_score(dists, detections)
185
184
 
186
185
  if self.args.with_reid and self.encoder is not None:
187
186
  emb_dists = matching.embedding_distance(tracks, detections) / 2.0