dgenerate-ultralytics-headless 8.3.137-py3-none-any.whl → 8.3.224-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215)
  1. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/METADATA +41 -34
  2. dgenerate_ultralytics_headless-8.3.224.dist-info/RECORD +285 -0
  3. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/WHEEL +1 -1
  4. tests/__init__.py +7 -6
  5. tests/conftest.py +15 -39
  6. tests/test_cli.py +17 -17
  7. tests/test_cuda.py +17 -8
  8. tests/test_engine.py +36 -10
  9. tests/test_exports.py +98 -37
  10. tests/test_integrations.py +12 -15
  11. tests/test_python.py +126 -82
  12. tests/test_solutions.py +319 -135
  13. ultralytics/__init__.py +27 -9
  14. ultralytics/cfg/__init__.py +83 -87
  15. ultralytics/cfg/datasets/Argoverse.yaml +4 -4
  16. ultralytics/cfg/datasets/DOTAv1.5.yaml +2 -2
  17. ultralytics/cfg/datasets/DOTAv1.yaml +2 -2
  18. ultralytics/cfg/datasets/GlobalWheat2020.yaml +2 -2
  19. ultralytics/cfg/datasets/HomeObjects-3K.yaml +4 -5
  20. ultralytics/cfg/datasets/ImageNet.yaml +3 -3
  21. ultralytics/cfg/datasets/Objects365.yaml +24 -20
  22. ultralytics/cfg/datasets/SKU-110K.yaml +9 -9
  23. ultralytics/cfg/datasets/VOC.yaml +10 -13
  24. ultralytics/cfg/datasets/VisDrone.yaml +43 -33
  25. ultralytics/cfg/datasets/african-wildlife.yaml +5 -5
  26. ultralytics/cfg/datasets/brain-tumor.yaml +4 -5
  27. ultralytics/cfg/datasets/carparts-seg.yaml +5 -5
  28. ultralytics/cfg/datasets/coco-pose.yaml +26 -4
  29. ultralytics/cfg/datasets/coco.yaml +4 -4
  30. ultralytics/cfg/datasets/coco128-seg.yaml +2 -2
  31. ultralytics/cfg/datasets/coco128.yaml +2 -2
  32. ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
  33. ultralytics/cfg/datasets/coco8-multispectral.yaml +2 -2
  34. ultralytics/cfg/datasets/coco8-pose.yaml +23 -2
  35. ultralytics/cfg/datasets/coco8-seg.yaml +2 -2
  36. ultralytics/cfg/datasets/coco8.yaml +2 -2
  37. ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
  38. ultralytics/cfg/datasets/crack-seg.yaml +5 -5
  39. ultralytics/cfg/datasets/dog-pose.yaml +32 -4
  40. ultralytics/cfg/datasets/dota8-multispectral.yaml +2 -2
  41. ultralytics/cfg/datasets/dota8.yaml +2 -2
  42. ultralytics/cfg/datasets/hand-keypoints.yaml +29 -4
  43. ultralytics/cfg/datasets/lvis.yaml +9 -9
  44. ultralytics/cfg/datasets/medical-pills.yaml +4 -5
  45. ultralytics/cfg/datasets/open-images-v7.yaml +7 -10
  46. ultralytics/cfg/datasets/package-seg.yaml +5 -5
  47. ultralytics/cfg/datasets/signature.yaml +4 -4
  48. ultralytics/cfg/datasets/tiger-pose.yaml +20 -4
  49. ultralytics/cfg/datasets/xView.yaml +5 -5
  50. ultralytics/cfg/default.yaml +96 -93
  51. ultralytics/cfg/trackers/botsort.yaml +16 -17
  52. ultralytics/cfg/trackers/bytetrack.yaml +9 -11
  53. ultralytics/data/__init__.py +4 -4
  54. ultralytics/data/annotator.py +12 -12
  55. ultralytics/data/augment.py +531 -564
  56. ultralytics/data/base.py +76 -81
  57. ultralytics/data/build.py +206 -42
  58. ultralytics/data/converter.py +179 -78
  59. ultralytics/data/dataset.py +121 -121
  60. ultralytics/data/loaders.py +114 -91
  61. ultralytics/data/split.py +28 -15
  62. ultralytics/data/split_dota.py +67 -48
  63. ultralytics/data/utils.py +110 -89
  64. ultralytics/engine/exporter.py +422 -460
  65. ultralytics/engine/model.py +224 -252
  66. ultralytics/engine/predictor.py +94 -89
  67. ultralytics/engine/results.py +345 -595
  68. ultralytics/engine/trainer.py +231 -134
  69. ultralytics/engine/tuner.py +279 -73
  70. ultralytics/engine/validator.py +53 -46
  71. ultralytics/hub/__init__.py +26 -28
  72. ultralytics/hub/auth.py +30 -16
  73. ultralytics/hub/google/__init__.py +34 -36
  74. ultralytics/hub/session.py +53 -77
  75. ultralytics/hub/utils.py +23 -109
  76. ultralytics/models/__init__.py +1 -1
  77. ultralytics/models/fastsam/__init__.py +1 -1
  78. ultralytics/models/fastsam/model.py +36 -18
  79. ultralytics/models/fastsam/predict.py +33 -44
  80. ultralytics/models/fastsam/utils.py +4 -5
  81. ultralytics/models/fastsam/val.py +12 -14
  82. ultralytics/models/nas/__init__.py +1 -1
  83. ultralytics/models/nas/model.py +16 -20
  84. ultralytics/models/nas/predict.py +12 -14
  85. ultralytics/models/nas/val.py +4 -5
  86. ultralytics/models/rtdetr/__init__.py +1 -1
  87. ultralytics/models/rtdetr/model.py +9 -9
  88. ultralytics/models/rtdetr/predict.py +22 -17
  89. ultralytics/models/rtdetr/train.py +20 -16
  90. ultralytics/models/rtdetr/val.py +79 -59
  91. ultralytics/models/sam/__init__.py +8 -2
  92. ultralytics/models/sam/amg.py +53 -38
  93. ultralytics/models/sam/build.py +29 -31
  94. ultralytics/models/sam/model.py +33 -38
  95. ultralytics/models/sam/modules/blocks.py +159 -182
  96. ultralytics/models/sam/modules/decoders.py +38 -47
  97. ultralytics/models/sam/modules/encoders.py +114 -133
  98. ultralytics/models/sam/modules/memory_attention.py +38 -31
  99. ultralytics/models/sam/modules/sam.py +114 -93
  100. ultralytics/models/sam/modules/tiny_encoder.py +268 -291
  101. ultralytics/models/sam/modules/transformer.py +59 -66
  102. ultralytics/models/sam/modules/utils.py +55 -72
  103. ultralytics/models/sam/predict.py +745 -341
  104. ultralytics/models/utils/loss.py +118 -107
  105. ultralytics/models/utils/ops.py +118 -71
  106. ultralytics/models/yolo/__init__.py +1 -1
  107. ultralytics/models/yolo/classify/predict.py +28 -26
  108. ultralytics/models/yolo/classify/train.py +50 -81
  109. ultralytics/models/yolo/classify/val.py +68 -61
  110. ultralytics/models/yolo/detect/predict.py +12 -15
  111. ultralytics/models/yolo/detect/train.py +56 -46
  112. ultralytics/models/yolo/detect/val.py +279 -223
  113. ultralytics/models/yolo/model.py +167 -86
  114. ultralytics/models/yolo/obb/predict.py +7 -11
  115. ultralytics/models/yolo/obb/train.py +23 -25
  116. ultralytics/models/yolo/obb/val.py +107 -99
  117. ultralytics/models/yolo/pose/__init__.py +1 -1
  118. ultralytics/models/yolo/pose/predict.py +12 -14
  119. ultralytics/models/yolo/pose/train.py +31 -69
  120. ultralytics/models/yolo/pose/val.py +119 -254
  121. ultralytics/models/yolo/segment/predict.py +21 -25
  122. ultralytics/models/yolo/segment/train.py +12 -66
  123. ultralytics/models/yolo/segment/val.py +126 -305
  124. ultralytics/models/yolo/world/train.py +53 -45
  125. ultralytics/models/yolo/world/train_world.py +51 -32
  126. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  127. ultralytics/models/yolo/yoloe/predict.py +30 -37
  128. ultralytics/models/yolo/yoloe/train.py +89 -71
  129. ultralytics/models/yolo/yoloe/train_seg.py +15 -17
  130. ultralytics/models/yolo/yoloe/val.py +56 -41
  131. ultralytics/nn/__init__.py +9 -11
  132. ultralytics/nn/autobackend.py +179 -107
  133. ultralytics/nn/modules/__init__.py +67 -67
  134. ultralytics/nn/modules/activation.py +8 -7
  135. ultralytics/nn/modules/block.py +302 -323
  136. ultralytics/nn/modules/conv.py +61 -104
  137. ultralytics/nn/modules/head.py +488 -186
  138. ultralytics/nn/modules/transformer.py +183 -123
  139. ultralytics/nn/modules/utils.py +15 -20
  140. ultralytics/nn/tasks.py +327 -203
  141. ultralytics/nn/text_model.py +81 -65
  142. ultralytics/py.typed +1 -0
  143. ultralytics/solutions/__init__.py +12 -12
  144. ultralytics/solutions/ai_gym.py +19 -27
  145. ultralytics/solutions/analytics.py +36 -26
  146. ultralytics/solutions/config.py +29 -28
  147. ultralytics/solutions/distance_calculation.py +23 -24
  148. ultralytics/solutions/heatmap.py +17 -19
  149. ultralytics/solutions/instance_segmentation.py +21 -19
  150. ultralytics/solutions/object_blurrer.py +16 -17
  151. ultralytics/solutions/object_counter.py +48 -53
  152. ultralytics/solutions/object_cropper.py +22 -16
  153. ultralytics/solutions/parking_management.py +61 -58
  154. ultralytics/solutions/queue_management.py +19 -19
  155. ultralytics/solutions/region_counter.py +63 -50
  156. ultralytics/solutions/security_alarm.py +22 -25
  157. ultralytics/solutions/similarity_search.py +107 -60
  158. ultralytics/solutions/solutions.py +343 -262
  159. ultralytics/solutions/speed_estimation.py +35 -31
  160. ultralytics/solutions/streamlit_inference.py +104 -40
  161. ultralytics/solutions/templates/similarity-search.html +31 -24
  162. ultralytics/solutions/trackzone.py +24 -24
  163. ultralytics/solutions/vision_eye.py +11 -12
  164. ultralytics/trackers/__init__.py +1 -1
  165. ultralytics/trackers/basetrack.py +18 -27
  166. ultralytics/trackers/bot_sort.py +48 -39
  167. ultralytics/trackers/byte_tracker.py +94 -94
  168. ultralytics/trackers/track.py +7 -16
  169. ultralytics/trackers/utils/gmc.py +37 -69
  170. ultralytics/trackers/utils/kalman_filter.py +68 -76
  171. ultralytics/trackers/utils/matching.py +13 -17
  172. ultralytics/utils/__init__.py +251 -275
  173. ultralytics/utils/autobatch.py +19 -7
  174. ultralytics/utils/autodevice.py +68 -38
  175. ultralytics/utils/benchmarks.py +169 -130
  176. ultralytics/utils/callbacks/base.py +12 -13
  177. ultralytics/utils/callbacks/clearml.py +14 -15
  178. ultralytics/utils/callbacks/comet.py +139 -66
  179. ultralytics/utils/callbacks/dvc.py +19 -27
  180. ultralytics/utils/callbacks/hub.py +8 -6
  181. ultralytics/utils/callbacks/mlflow.py +6 -10
  182. ultralytics/utils/callbacks/neptune.py +11 -19
  183. ultralytics/utils/callbacks/platform.py +73 -0
  184. ultralytics/utils/callbacks/raytune.py +3 -4
  185. ultralytics/utils/callbacks/tensorboard.py +9 -12
  186. ultralytics/utils/callbacks/wb.py +33 -30
  187. ultralytics/utils/checks.py +163 -114
  188. ultralytics/utils/cpu.py +89 -0
  189. ultralytics/utils/dist.py +24 -20
  190. ultralytics/utils/downloads.py +176 -146
  191. ultralytics/utils/errors.py +11 -13
  192. ultralytics/utils/events.py +113 -0
  193. ultralytics/utils/export/__init__.py +7 -0
  194. ultralytics/utils/{export.py → export/engine.py} +81 -63
  195. ultralytics/utils/export/imx.py +294 -0
  196. ultralytics/utils/export/tensorflow.py +217 -0
  197. ultralytics/utils/files.py +33 -36
  198. ultralytics/utils/git.py +137 -0
  199. ultralytics/utils/instance.py +105 -120
  200. ultralytics/utils/logger.py +404 -0
  201. ultralytics/utils/loss.py +99 -61
  202. ultralytics/utils/metrics.py +649 -478
  203. ultralytics/utils/nms.py +337 -0
  204. ultralytics/utils/ops.py +263 -451
  205. ultralytics/utils/patches.py +70 -31
  206. ultralytics/utils/plotting.py +253 -223
  207. ultralytics/utils/tal.py +48 -61
  208. ultralytics/utils/torch_utils.py +244 -251
  209. ultralytics/utils/tqdm.py +438 -0
  210. ultralytics/utils/triton.py +22 -23
  211. ultralytics/utils/tuner.py +11 -10
  212. dgenerate_ultralytics_headless-8.3.137.dist-info/RECORD +0 -272
  213. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/entry_points.txt +0 -0
  214. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/licenses/LICENSE +0 -0
  215. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/top_level.txt +0 -0
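The rest of this page reproduces the diff for one of the files listed above. For readers who want to regenerate this kind of comparison locally, the sketch below is one minimal, hedged way to do it with the standard library; it assumes both wheels have already been fetched (for example with `pip download dgenerate-ultralytics-headless==8.3.137 --no-deps -d wheels/` and the same for 8.3.224), and the wheel filenames and output directory are illustrative placeholders rather than paths taken from this page.

    # Hypothetical local reproduction of this comparison (filenames are illustrative).
    import difflib
    import zipfile

    def diff_wheel_member(old_wheel: str, new_wheel: str, member: str) -> str:
        """Return a unified diff for one file present in both wheel archives."""
        def read(wheel: str) -> list[str]:
            with zipfile.ZipFile(wheel) as zf:  # a .whl file is a zip archive
                return zf.read(member).decode("utf-8").splitlines(keepends=True)

        return "".join(
            difflib.unified_diff(
                read(old_wheel), read(new_wheel), fromfile=f"8.3.137/{member}", tofile=f"8.3.224/{member}"
            )
        )

    print(
        diff_wheel_member(
            "wheels/dgenerate_ultralytics_headless-8.3.137-py3-none-any.whl",
            "wheels/dgenerate_ultralytics_headless-8.3.224-py3-none-any.whl",
            "ultralytics/models/sam/predict.py",
        )
    )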
@@ -8,8 +8,12 @@ using SAM. It forms an integral part of the Ultralytics framework and is designe
8
8
  segmentation tasks.
9
9
  """
10
10
 
11
+ from __future__ import annotations
12
+
11
13
  from collections import OrderedDict
14
+ from typing import Any
12
15
 
16
+ import cv2
13
17
  import numpy as np
14
18
  import torch
15
19
  import torch.nn.functional as F
@@ -34,12 +38,11 @@ from .amg import (
34
38
 
35
39
 
36
40
  class Predictor(BasePredictor):
37
- """
38
- Predictor class for SAM, enabling real-time image segmentation with promptable capabilities.
41
+ """Predictor class for SAM, enabling real-time image segmentation with promptable capabilities.
39
42
 
40
- This class extends BasePredictor and implements the Segment Anything Model (SAM) for advanced image
41
- segmentation tasks. It supports various input prompts like points, bounding boxes, and masks for
42
- fine-grained control over segmentation results.
43
+ This class extends BasePredictor and implements the Segment Anything Model (SAM) for advanced image segmentation
44
+ tasks. It supports various input prompts like points, bounding boxes, and masks for fine-grained control over
45
+ segmentation results.
43
46
 
44
47
  Attributes:
45
48
  args (SimpleNamespace): Configuration arguments for the predictor.
@@ -47,26 +50,26 @@ class Predictor(BasePredictor):
47
50
  device (torch.device): The device (CPU or GPU) on which the model is loaded.
48
51
  im (torch.Tensor): The preprocessed input image.
49
52
  features (torch.Tensor): Extracted image features.
50
- prompts (dict): Dictionary to store various types of prompts (e.g., bboxes, points, masks).
53
+ prompts (dict[str, Any]): Dictionary to store various types of prompts (e.g., bboxes, points, masks).
51
54
  segment_all (bool): Flag to indicate if full image segmentation should be performed.
52
55
  mean (torch.Tensor): Mean values for image normalization.
53
56
  std (torch.Tensor): Standard deviation values for image normalization.
54
57
 
55
58
  Methods:
56
- preprocess: Prepares input images for model inference.
57
- pre_transform: Performs initial transformations on the input image.
58
- inference: Performs segmentation inference based on input prompts.
59
+ preprocess: Prepare input images for model inference.
60
+ pre_transform: Perform initial transformations on the input image.
61
+ inference: Perform segmentation inference based on input prompts.
59
62
  prompt_inference: Internal function for prompt-based segmentation inference.
60
- generate: Generates segmentation masks for an entire image.
61
- setup_model: Initializes the SAM model for inference.
62
- get_model: Builds and returns a SAM model.
63
- postprocess: Post-processes model outputs to generate final results.
64
- setup_source: Sets up the data source for inference.
65
- set_image: Sets and preprocesses a single image for inference.
66
- get_im_features: Extracts image features using the SAM image encoder.
67
- set_prompts: Sets prompts for subsequent inference.
68
- reset_image: Resets the current image and its features.
69
- remove_small_regions: Removes small disconnected regions and holes from masks.
63
+ generate: Generate segmentation masks for an entire image.
64
+ setup_model: Initialize the SAM model for inference.
65
+ get_model: Build and return a SAM model.
66
+ postprocess: Post-process model outputs to generate final results.
67
+ setup_source: Set up the data source for inference.
68
+ set_image: Set and preprocess a single image for inference.
69
+ get_im_features: Extract image features using the SAM image encoder.
70
+ set_prompts: Set prompts for subsequent inference.
71
+ reset_image: Reset the current image and its features.
72
+ remove_small_regions: Remove small disconnected regions and holes from masks.
70
73
 
71
74
  Examples:
72
75
  >>> predictor = Predictor()
@@ -77,17 +80,16 @@ class Predictor(BasePredictor):
77
80
  """
78
81
 
79
82
  def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
80
- """
81
- Initialize the Predictor with configuration, overrides, and callbacks.
83
+ """Initialize the Predictor with configuration, overrides, and callbacks.
82
84
 
83
85
  Sets up the Predictor object for SAM (Segment Anything Model) and applies any configuration overrides or
84
- callbacks provided. Initializes task-specific settings for SAM, such as retina_masks being set to True
85
- for optimal results.
86
+ callbacks provided. Initializes task-specific settings for SAM, such as retina_masks being set to True for
87
+ optimal results.
86
88
 
87
89
  Args:
88
90
  cfg (dict): Configuration dictionary containing default settings.
89
- overrides (Dict | None): Dictionary of values to override default configuration.
90
- _callbacks (Dict | None): Dictionary of callback functions to customize behavior.
91
+ overrides (dict | None): Dictionary of values to override default configuration.
92
+ _callbacks (dict | None): Dictionary of callback functions to customize behavior.
91
93
 
92
94
  Examples:
93
95
  >>> predictor_example = Predictor(cfg=DEFAULT_CFG)
@@ -105,17 +107,16 @@ class Predictor(BasePredictor):
105
107
  self.segment_all = False
106
108
 
107
109
  def preprocess(self, im):
108
- """
109
- Preprocess the input image for model inference.
110
+ """Preprocess the input image for model inference.
110
111
 
111
112
  This method prepares the input image by applying transformations and normalization. It supports both
112
113
  torch.Tensor and list of np.ndarray as input formats.
113
114
 
114
115
  Args:
115
- im (torch.Tensor | List[np.ndarray]): Input image(s) in BCHW tensor format or list of HWC numpy arrays.
116
+ im (torch.Tensor | list[np.ndarray]): Input image(s) in BCHW tensor format or list of HWC numpy arrays.
116
117
 
117
118
  Returns:
118
- im (torch.Tensor): The preprocessed image tensor, normalized and converted to the appropriate dtype.
119
+ (torch.Tensor): The preprocessed image tensor, normalized and converted to the appropriate dtype.
119
120
 
120
121
  Examples:
121
122
  >>> predictor = Predictor()
@@ -132,23 +133,22 @@ class Predictor(BasePredictor):
132
133
  im = torch.from_numpy(im)
133
134
 
134
135
  im = im.to(self.device)
135
- im = im.half() if self.model.fp16 else im.float()
136
136
  if not_tensor:
137
137
  im = (im - self.mean) / self.std
138
+ im = im.half() if self.model.fp16 else im.float()
138
139
  return im
139
140
 
140
141
  def pre_transform(self, im):
141
- """
142
- Perform initial transformations on the input image for preprocessing.
142
+ """Perform initial transformations on the input image for preprocessing.
143
143
 
144
- This method applies transformations such as resizing to prepare the image for further preprocessing.
145
- Currently, batched inference is not supported; hence the list length should be 1.
144
+ This method applies transformations such as resizing to prepare the image for further preprocessing. Currently,
145
+ batched inference is not supported; hence the list length should be 1.
146
146
 
147
147
  Args:
148
- im (List[np.ndarray]): List containing a single image in HWC numpy array format.
148
+ im (list[np.ndarray]): List containing a single image in HWC numpy array format.
149
149
 
150
150
  Returns:
151
- (List[np.ndarray]): List containing the transformed image.
151
+ (list[np.ndarray]): List containing the transformed image.
152
152
 
153
153
  Raises:
154
154
  AssertionError: If the input list contains more than one image.
@@ -165,26 +165,25 @@ class Predictor(BasePredictor):
165
165
  return [letterbox(image=x) for x in im]
166
166
 
167
167
  def inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False, *args, **kwargs):
168
- """
169
- Perform image segmentation inference based on the given input cues, using the currently loaded image.
168
+ """Perform image segmentation inference based on the given input cues, using the currently loaded image.
170
169
 
171
- This method leverages SAM's (Segment Anything Model) architecture consisting of image encoder, prompt
172
- encoder, and mask decoder for real-time and promptable segmentation tasks.
170
+ This method leverages SAM's (Segment Anything Model) architecture consisting of image encoder, prompt encoder,
171
+ and mask decoder for real-time and promptable segmentation tasks.
173
172
 
174
173
  Args:
175
174
  im (torch.Tensor): The preprocessed input image in tensor format, with shape (N, C, H, W).
176
- bboxes (np.ndarray | List | None): Bounding boxes with shape (N, 4), in XYXY format.
177
- points (np.ndarray | List | None): Points indicating object locations with shape (N, 2), in pixels.
178
- labels (np.ndarray | List | None): Labels for point prompts, shape (N,). 1 = foreground, 0 = background.
175
+ bboxes (np.ndarray | list | None): Bounding boxes with shape (N, 4), in XYXY format.
176
+ points (np.ndarray | list | None): Points indicating object locations with shape (N, 2), in pixels.
177
+ labels (np.ndarray | list | None): Labels for point prompts, shape (N,). 1 = foreground, 0 = background.
179
178
  masks (np.ndarray | None): Low-resolution masks from previous predictions, shape (N, H, W). For SAM H=W=256.
180
179
  multimask_output (bool): Flag to return multiple masks. Helpful for ambiguous prompts.
181
180
  *args (Any): Additional positional arguments.
182
181
  **kwargs (Any): Additional keyword arguments.
183
182
 
184
183
  Returns:
185
- (np.ndarray): The output masks in shape (C, H, W), where C is the number of generated masks.
186
- (np.ndarray): An array of length C containing quality scores predicted by the model for each mask.
187
- (np.ndarray): Low-resolution logits of shape (C, H, W) for subsequent inference, where H=W=256.
184
+ pred_masks (torch.Tensor): The output masks in shape (C, H, W), where C is the number of generated masks.
185
+ pred_scores (torch.Tensor): An array of length C containing quality scores predicted by the model for each
186
+ mask.
188
187
 
189
188
  Examples:
190
189
  >>> predictor = Predictor()
@@ -204,26 +203,24 @@ class Predictor(BasePredictor):
204
203
  return self.prompt_inference(im, bboxes, points, labels, masks, multimask_output)
205
204
 
206
205
  def prompt_inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False):
207
- """
208
- Performs image segmentation inference based on input cues using SAM's specialized architecture.
206
+ """Perform image segmentation inference based on input cues using SAM's specialized architecture.
209
207
 
210
- This internal function leverages the Segment Anything Model (SAM) for prompt-based, real-time segmentation.
211
- It processes various input prompts such as bounding boxes, points, and masks to generate segmentation masks.
208
+ This internal function leverages the Segment Anything Model (SAM) for prompt-based, real-time segmentation. It
209
+ processes various input prompts such as bounding boxes, points, and masks to generate segmentation masks.
212
210
 
213
211
  Args:
214
212
  im (torch.Tensor): Preprocessed input image tensor with shape (N, C, H, W).
215
- bboxes (np.ndarray | List | None): Bounding boxes in XYXY format with shape (N, 4).
216
- points (np.ndarray | List | None): Points indicating object locations with shape (N, 2) or (N, num_points, 2), in pixels.
217
- labels (np.ndarray | List | None): Point prompt labels with shape (N) or (N, num_points). 1 for foreground, 0 for background.
213
+ bboxes (np.ndarray | list | None): Bounding boxes in XYXY format with shape (N, 4).
214
+ points (np.ndarray | list | None): Points indicating object locations with shape (N, 2) or (N, num_points,
215
+ 2), in pixels.
216
+ labels (np.ndarray | list | None): Point prompt labels with shape (N) or (N, num_points). 1 for foreground,
217
+ 0 for background.
218
218
  masks (np.ndarray | None): Low-res masks from previous predictions with shape (N, H, W). For SAM, H=W=256.
219
219
  multimask_output (bool): Flag to return multiple masks for ambiguous prompts.
220
220
 
221
- Raises:
222
- AssertionError: If the number of points don't match the number of labels, in case labels were passed.
223
-
224
221
  Returns:
225
- (np.ndarray): Output masks with shape (C, H, W), where C is the number of generated masks.
226
- (np.ndarray): Quality scores predicted by the model for each mask, with length C.
222
+ pred_masks (torch.Tensor): Output masks with shape (C, H, W), where C is the number of generated masks.
223
+ pred_scores (torch.Tensor): Quality scores predicted by the model for each mask, with length C.
227
224
 
228
225
  Examples:
229
226
  >>> predictor = Predictor()
@@ -233,7 +230,32 @@ class Predictor(BasePredictor):
233
230
  """
234
231
  features = self.get_im_features(im) if self.features is None else self.features
235
232
 
236
- bboxes, points, labels, masks = self._prepare_prompts(im.shape[2:], bboxes, points, labels, masks)
233
+ prompts = self._prepare_prompts(im.shape[2:], self.batch[1][0].shape[:2], bboxes, points, labels, masks)
234
+ return self._inference_features(features, *prompts, multimask_output)
235
+
236
+ def _inference_features(
237
+ self,
238
+ features,
239
+ bboxes=None,
240
+ points=None,
241
+ labels=None,
242
+ masks=None,
243
+ multimask_output=False,
244
+ ):
245
+ """Perform inference on image features using the SAM model.
246
+
247
+ Args:
248
+ features (torch.Tensor): Extracted image features with shape (B, C, H, W) from the SAM model image encoder.
249
+ bboxes (np.ndarray | list[list[float]] | None): Bounding boxes in XYXY format with shape (N, 4).
250
+ points (np.ndarray | list[list[float]] | None): Object location points with shape (N, 2), in pixels.
251
+ labels (np.ndarray | list[int] | None): Point prompt labels with shape (N,). 1 = foreground, 0 = background.
252
+ masks (list[np.ndarray] | np.ndarray | None): Masks for the objects, where each mask is a 2D array.
253
+ multimask_output (bool): Flag to return multiple masks for ambiguous prompts.
254
+
255
+ Returns:
256
+ pred_masks (torch.Tensor): Output masks with shape (C, H, W), where C is the number of generated masks.
257
+ pred_scores (torch.Tensor): Quality scores for each mask, with length C.
258
+ """
237
259
  points = (points, labels) if points is not None else None
238
260
  # Embed prompts
239
261
  sparse_embeddings, dense_embeddings = self.model.prompt_encoder(points=points, boxes=bboxes, masks=masks)
@@ -251,28 +273,33 @@ class Predictor(BasePredictor):
251
273
  # `d` could be 1 or 3 depends on `multimask_output`.
252
274
  return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1)
253
275
 
254
- def _prepare_prompts(self, dst_shape, bboxes=None, points=None, labels=None, masks=None):
255
- """
256
- Prepares and transforms the input prompts for processing based on the destination shape.
276
+ def _prepare_prompts(self, dst_shape, src_shape, bboxes=None, points=None, labels=None, masks=None):
277
+ """Prepare and transform the input prompts for processing based on the destination shape.
257
278
 
258
279
  Args:
259
- dst_shape (tuple): The target shape (height, width) for the prompts.
260
- bboxes (np.ndarray | List | None): Bounding boxes in XYXY format with shape (N, 4).
261
- points (np.ndarray | List | None): Points indicating object locations with shape (N, 2) or (N, num_points, 2), in pixels.
262
- labels (np.ndarray | List | None): Point prompt labels with shape (N) or (N, num_points). 1 for foreground, 0 for background.
263
- masks (List | np.ndarray, Optional): Masks for the objects, where each mask is a 2D array.
280
+ dst_shape (tuple[int, int]): The target shape (height, width) for the prompts.
281
+ src_shape (tuple[int, int]): The source shape (height, width) of the input image.
282
+ bboxes (np.ndarray | list | None): Bounding boxes in XYXY format with shape (N, 4).
283
+ points (np.ndarray | list | None): Points indicating object locations with shape (N, 2) or (N, num_points,
284
+ 2), in pixels.
285
+ labels (np.ndarray | list | None): Point prompt labels with shape (N) or (N, num_points). 1 for foreground,
286
+ 0 for background.
287
+ masks (list[np.ndarray] | np.ndarray | None): Masks for the objects, where each mask is a 2D array with
288
+ shape (H, W).
289
+
290
+ Returns:
291
+ bboxes (torch.Tensor | None): Transformed bounding boxes.
292
+ points (torch.Tensor | None): Transformed points.
293
+ labels (torch.Tensor | None): Transformed labels.
294
+ masks (torch.Tensor | None): Transformed masks.
264
295
 
265
296
  Raises:
266
297
  AssertionError: If the number of points don't match the number of labels, in case labels were passed.
267
-
268
- Returns:
269
- (tuple): A tuple containing transformed bounding boxes, points, labels, and masks.
270
298
  """
271
- src_shape = self.batch[1][0].shape[:2]
272
299
  r = 1.0 if self.segment_all else min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1])
273
300
  # Transform input prompts
274
301
  if points is not None:
275
- points = torch.as_tensor(points, dtype=torch.float32, device=self.device)
302
+ points = torch.as_tensor(points, dtype=self.torch_dtype, device=self.device)
276
303
  points = points[None] if points.ndim == 1 else points
277
304
  # Assuming labels are all positive if users don't pass labels.
278
305
  if labels is None:
@@ -286,11 +313,15 @@ class Predictor(BasePredictor):
286
313
  # (N, 2) --> (N, 1, 2), (N, ) --> (N, 1)
287
314
  points, labels = points[:, None, :], labels[:, None]
288
315
  if bboxes is not None:
289
- bboxes = torch.as_tensor(bboxes, dtype=torch.float32, device=self.device)
316
+ bboxes = torch.as_tensor(bboxes, dtype=self.torch_dtype, device=self.device)
290
317
  bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes
291
318
  bboxes *= r
292
319
  if masks is not None:
293
- masks = torch.as_tensor(masks, dtype=torch.float32, device=self.device).unsqueeze(1)
320
+ masks = np.asarray(masks, dtype=np.uint8)
321
+ masks = masks[None] if masks.ndim == 2 else masks
322
+ letterbox = LetterBox(dst_shape, auto=False, center=False, padding_value=0, interpolation=cv2.INTER_NEAREST)
323
+ masks = np.stack([letterbox(image=x).squeeze() for x in masks], axis=0)
324
+ masks = torch.tensor(masks, dtype=self.torch_dtype, device=self.device)
294
325
  return bboxes, points, labels, masks
295
326
 
296
327
  def generate(
@@ -307,18 +338,17 @@ class Predictor(BasePredictor):
307
338
  stability_score_offset=0.95,
308
339
  crop_nms_thresh=0.7,
309
340
  ):
310
- """
311
- Perform image segmentation using the Segment Anything Model (SAM).
341
+ """Perform image segmentation using the Segment Anything Model (SAM).
312
342
 
313
- This method segments an entire image into constituent parts by leveraging SAM's advanced architecture
314
- and real-time performance capabilities. It can optionally work on image crops for finer segmentation.
343
+ This method segments an entire image into constituent parts by leveraging SAM's advanced architecture and
344
+ real-time performance capabilities. It can optionally work on image crops for finer segmentation.
315
345
 
316
346
  Args:
317
347
  im (torch.Tensor): Input tensor representing the preprocessed image with shape (N, C, H, W).
318
348
  crop_n_layers (int): Number of layers for additional mask predictions on image crops.
319
349
  crop_overlap_ratio (float): Overlap between crops, scaled down in subsequent layers.
320
350
  crop_downscale_factor (int): Scaling factor for sampled points-per-side in each layer.
321
- point_grids (List[np.ndarray] | None): Custom grids for point sampling normalized to [0,1].
351
+ point_grids (list[np.ndarray] | None): Custom grids for point sampling normalized to [0,1].
322
352
  points_stride (int): Number of points to sample along each side of the image.
323
353
  points_batch_size (int): Batch size for the number of points processed simultaneously.
324
354
  conf_thres (float): Confidence threshold [0,1] for filtering based on mask quality prediction.
@@ -390,7 +420,7 @@ class Predictor(BasePredictor):
390
420
  pred_masks.append(crop_masks)
391
421
  pred_bboxes.append(crop_bboxes)
392
422
  pred_scores.append(crop_scores)
393
- region_areas.append(area.expand(len(crop_masks)))
423
+ region_areas.append(area.expand(crop_masks.shape[0]))
394
424
 
395
425
  pred_masks = torch.cat(pred_masks)
396
426
  pred_bboxes = torch.cat(pred_bboxes)
@@ -406,8 +436,7 @@ class Predictor(BasePredictor):
406
436
  return pred_masks, pred_scores, pred_bboxes
407
437
 
408
438
  def setup_model(self, model=None, verbose=True):
409
- """
410
- Initializes the Segment Anything Model (SAM) for inference.
439
+ """Initialize the Segment Anything Model (SAM) for inference.
411
440
 
412
441
  This method sets up the SAM model by allocating it to the appropriate device and initializing the necessary
413
442
  parameters for image normalization and other Ultralytics compatibility settings.
@@ -424,7 +453,8 @@ class Predictor(BasePredictor):
424
453
  if model is None:
425
454
  model = self.get_model()
426
455
  model.eval()
427
- self.model = model.to(device)
456
+ model = model.to(device)
457
+ self.model = model.half() if self.args.half else model.float()
428
458
  self.device = device
429
459
  self.mean = torch.tensor([123.675, 116.28, 103.53]).view(-1, 1, 1).to(device)
430
460
  self.std = torch.tensor([58.395, 57.12, 57.375]).view(-1, 1, 1).to(device)
@@ -433,33 +463,33 @@ class Predictor(BasePredictor):
433
463
  self.model.pt = False
434
464
  self.model.triton = False
435
465
  self.model.stride = 32
436
- self.model.fp16 = False
466
+ self.model.fp16 = self.args.half
437
467
  self.done_warmup = True
468
+ self.torch_dtype = torch.float16 if self.model.fp16 else torch.float32
438
469
 
439
470
  def get_model(self):
440
- """Retrieves or builds the Segment Anything Model (SAM) for image segmentation tasks."""
471
+ """Retrieve or build the Segment Anything Model (SAM) for image segmentation tasks."""
441
472
  from .build import build_sam # slow import
442
473
 
443
474
  return build_sam(self.args.model)
444
475
 
445
476
  def postprocess(self, preds, img, orig_imgs):
446
- """
447
- Post-processes SAM's inference outputs to generate object detection masks and bounding boxes.
477
+ """Post-process SAM's inference outputs to generate object detection masks and bounding boxes.
448
478
 
449
479
  This method scales masks and boxes to the original image size and applies a threshold to the mask
450
480
  predictions. It leverages SAM's advanced architecture for real-time, promptable segmentation tasks.
451
481
 
452
482
  Args:
453
- preds (Tuple[torch.Tensor]): The output from SAM model inference, containing:
483
+ preds (tuple): The output from SAM model inference, containing:
454
484
  - pred_masks (torch.Tensor): Predicted masks with shape (N, 1, H, W).
455
485
  - pred_scores (torch.Tensor): Confidence scores for each mask with shape (N, 1).
456
486
  - pred_bboxes (torch.Tensor, optional): Predicted bounding boxes if segment_all is True.
457
487
  img (torch.Tensor): The processed input image tensor with shape (C, H, W).
458
- orig_imgs (List[np.ndarray] | torch.Tensor): The original, unprocessed images.
488
+ orig_imgs (list[np.ndarray] | torch.Tensor): The original, unprocessed images.
459
489
 
460
490
  Returns:
461
- results (List[Results]): List of Results objects containing detection masks, bounding boxes, and other
462
- metadata for each processed image.
491
+ (list[Results]): List of Results objects containing detection masks, bounding boxes, and other metadata for
492
+ each processed image.
463
493
 
464
494
  Examples:
465
495
  >>> predictor = Predictor()
@@ -469,14 +499,14 @@ class Predictor(BasePredictor):
469
499
  # (N, 1, H, W), (N, 1)
470
500
  pred_masks, pred_scores = preds[:2]
471
501
  pred_bboxes = preds[2] if self.segment_all else None
472
- names = dict(enumerate(str(i) for i in range(len(pred_masks))))
502
+ names = dict(enumerate(str(i) for i in range(pred_masks.shape[0])))
473
503
 
474
504
  if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list
475
505
  orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
476
506
 
477
507
  results = []
478
508
  for masks, orig_img, img_path in zip([pred_masks], orig_imgs, self.batch[0]):
479
- if len(masks) == 0:
509
+ if masks.shape[0] == 0:
480
510
  masks, pred_bboxes = None, torch.zeros((0, 6), device=pred_masks.device)
481
511
  else:
482
512
  masks = ops.scale_masks(masks[None].float(), orig_img.shape[:2], padding=False)[0]
@@ -486,23 +516,24 @@ class Predictor(BasePredictor):
486
516
  else:
487
517
  pred_bboxes = batched_mask_to_box(masks)
488
518
  # NOTE: SAM models do not return cls info. This `cls` here is just a placeholder for consistency.
489
- cls = torch.arange(len(pred_masks), dtype=torch.int32, device=pred_masks.device)
490
- pred_bboxes = torch.cat([pred_bboxes, pred_scores[:, None], cls[:, None]], dim=-1)
519
+ cls = torch.arange(pred_masks.shape[0], dtype=torch.int32, device=pred_masks.device)
520
+ idx = pred_scores > self.args.conf
521
+ pred_bboxes = torch.cat([pred_bboxes, pred_scores[:, None], cls[:, None]], dim=-1)[idx]
522
+ masks = masks[idx]
491
523
  results.append(Results(orig_img, path=img_path, names=names, masks=masks, boxes=pred_bboxes))
492
524
  # Reset segment-all mode.
493
525
  self.segment_all = False
494
526
  return results
495
527
 
496
528
  def setup_source(self, source):
497
- """
498
- Sets up the data source for inference.
529
+ """Set up the data source for inference.
499
530
 
500
- This method configures the data source from which images will be fetched for inference. It supports
501
- various input types such as image files, directories, video files, and other compatible data sources.
531
+ This method configures the data source from which images will be fetched for inference. It supports various
532
+ input types such as image files, directories, video files, and other compatible data sources.
502
533
 
503
534
  Args:
504
- source (str | Path | None): The path or identifier for the image data source. Can be a file path,
505
- directory path, URL, or other supported source types.
535
+ source (str | Path | None): The path or identifier for the image data source. Can be a file path, directory
536
+ path, URL, or other supported source types.
506
537
 
507
538
  Examples:
508
539
  >>> predictor = Predictor()
@@ -519,16 +550,15 @@ class Predictor(BasePredictor):
519
550
  super().setup_source(source)
520
551
 
521
552
  def set_image(self, image):
522
- """
523
- Preprocesses and sets a single image for inference.
553
+ """Preprocess and set a single image for inference.
524
554
 
525
555
  This method prepares the model for inference on a single image by setting up the model if not already
526
- initialized, configuring the data source, and preprocessing the image for feature extraction. It
527
- ensures that only one image is set at a time and extracts image features for subsequent use.
556
+ initialized, configuring the data source, and preprocessing the image for feature extraction. It ensures that
557
+ only one image is set at a time and extracts image features for subsequent use.
528
558
 
529
559
  Args:
530
- image (str | np.ndarray): Path to the image file as a string, or a numpy array representing
531
- an image read by cv2.
560
+ image (str | np.ndarray): Path to the image file as a string, or a numpy array representing an image read by
561
+ cv2.
532
562
 
533
563
  Raises:
534
564
  AssertionError: If more than one image is attempted to be set.
@@ -543,7 +573,7 @@ class Predictor(BasePredictor):
543
573
  - The extracted features are stored in the `self.features` attribute for later use.
544
574
  """
545
575
  if self.model is None:
546
- self.setup_model(model=None)
576
+ self.setup_model()
547
577
  self.setup_source(image)
548
578
  assert len(self.dataset) == 1, "`set_image` only supports setting one image!"
549
579
  for batch in self.dataset:
@@ -552,7 +582,7 @@ class Predictor(BasePredictor):
552
582
  break
553
583
 
554
584
  def get_im_features(self, im):
555
- """Extracts image features using the SAM model's image encoder for subsequent mask prediction."""
585
+ """Extract image features using the SAM model's image encoder for subsequent mask prediction."""
556
586
  assert isinstance(self.imgsz, (tuple, list)) and self.imgsz[0] == self.imgsz[1], (
557
587
  f"SAM models only support square image size, but got {self.imgsz}."
558
588
  )
@@ -560,22 +590,21 @@ class Predictor(BasePredictor):
560
590
  return self.model.image_encoder(im)
561
591
 
562
592
  def set_prompts(self, prompts):
563
- """Sets prompts for subsequent inference operations."""
593
+ """Set prompts for subsequent inference operations."""
564
594
  self.prompts = prompts
565
595
 
566
596
  def reset_image(self):
567
- """Resets the current image and its features, clearing them for subsequent inference."""
597
+ """Reset the current image and its features, clearing them for subsequent inference."""
568
598
  self.im = None
569
599
  self.features = None
570
600
 
571
601
  @staticmethod
572
602
  def remove_small_regions(masks, min_area=0, nms_thresh=0.7):
573
- """
574
- Remove small disconnected regions and holes from segmentation masks.
603
+ """Remove small disconnected regions and holes from segmentation masks.
575
604
 
576
- This function performs post-processing on segmentation masks generated by the Segment Anything Model (SAM).
577
- It removes small disconnected regions and holes from the input masks, and then performs Non-Maximum
578
- Suppression (NMS) to eliminate any newly created duplicate boxes.
605
+ This function performs post-processing on segmentation masks generated by the Segment Anything Model (SAM). It
606
+ removes small disconnected regions and holes from the input masks, and then performs Non-Maximum Suppression
607
+ (NMS) to eliminate any newly created duplicate boxes.
579
608
 
580
609
  Args:
581
610
  masks (torch.Tensor): Segmentation masks to be processed, with shape (N, H, W) where N is the number of
@@ -586,7 +615,7 @@ class Predictor(BasePredictor):
586
615
 
587
616
  Returns:
588
617
  new_masks (torch.Tensor): Processed masks with small regions removed, shape (N, H, W).
589
- keep (List[int]): Indices of remaining masks after NMS, for filtering corresponding boxes.
618
+ keep (list[int]): Indices of remaining masks after NMS, for filtering corresponding boxes.
590
619
 
591
620
  Examples:
592
621
  >>> masks = torch.rand(5, 640, 640) > 0.5 # 5 random binary masks
@@ -596,7 +625,7 @@ class Predictor(BasePredictor):
596
625
  """
597
626
  import torchvision # scope for faster 'import ultralytics'
598
627
 
599
- if len(masks) == 0:
628
+ if masks.shape[0] == 0:
600
629
  return masks
601
630
 
602
631
  # Filter small disconnected regions and holes
@@ -620,28 +649,74 @@ class Predictor(BasePredictor):
620
649
 
621
650
  return new_masks[keep].to(device=masks.device, dtype=masks.dtype), keep
622
651
 
652
+ @smart_inference_mode()
653
+ def inference_features(
654
+ self,
655
+ features,
656
+ src_shape,
657
+ dst_shape=None,
658
+ bboxes=None,
659
+ points=None,
660
+ labels=None,
661
+ masks=None,
662
+ multimask_output=False,
663
+ ):
664
+ """Perform prompts preprocessing and inference on provided image features using the SAM model.
665
+
666
+ Args:
667
+ features (torch.Tensor | dict[str, Any]): Extracted image features from the SAM/SAM2 model image encoder.
668
+ src_shape (tuple[int, int]): The source shape (height, width) of the input image.
669
+ dst_shape (tuple[int, int] | None): The target shape (height, width) for the prompts. If None, defaults to
670
+ (imgsz, imgsz).
671
+ bboxes (np.ndarray | list[list[float]] | None): Bounding boxes in xyxy format with shape (N, 4).
672
+ points (np.ndarray | list[list[float]] | None): Points indicating object locations with shape (N, 2), in
673
+ pixels.
674
+ labels (np.ndarray | list[int] | None): Point prompt labels with shape (N, ).
675
+ masks (list[np.ndarray] | np.ndarray | None): Masks for the objects, where each mask is a 2D array.
676
+ multimask_output (bool): Flag to return multiple masks for ambiguous prompts.
677
+
678
+ Returns:
679
+ pred_masks (torch.Tensor): The output masks in shape (C, H, W), where C is the number of generated masks.
680
+ pred_bboxes (torch.Tensor): Bounding boxes for each mask with shape (N, 6), where N is the number of boxes.
681
+ Each box is in xyxy format with additional columns for score and class.
682
+
683
+ Notes:
684
+ - The input features is a torch.Tensor of shape (B, C, H, W) if performing on SAM, or a dict[str, Any] if performing on SAM2.
685
+ """
686
+ dst_shape = dst_shape or (self.args.imgsz, self.args.imgsz)
687
+ prompts = self._prepare_prompts(dst_shape, src_shape, bboxes, points, labels, masks)
688
+ pred_masks, pred_scores = self._inference_features(features, *prompts, multimask_output)
689
+ if pred_masks.shape[0] == 0:
690
+ pred_masks, pred_bboxes = None, torch.zeros((0, 6), device=pred_masks.device)
691
+ else:
692
+ pred_masks = ops.scale_masks(pred_masks[None].float(), src_shape, padding=False)[0]
693
+ pred_masks = pred_masks > self.model.mask_threshold # to bool
694
+ pred_bboxes = batched_mask_to_box(pred_masks)
695
+ # NOTE: SAM models do not return cls info. This `cls` here is just a placeholder for consistency.
696
+ cls = torch.arange(pred_masks.shape[0], dtype=torch.int32, device=pred_masks.device)
697
+ pred_bboxes = torch.cat([pred_bboxes, pred_scores[:, None], cls[:, None]], dim=-1)
698
+ return pred_masks, pred_bboxes
699
+
623
700
 
624
701
  class SAM2Predictor(Predictor):
625
- """
626
- SAM2Predictor class for advanced image segmentation using Segment Anything Model 2 architecture.
702
+ """SAM2Predictor class for advanced image segmentation using Segment Anything Model 2 architecture.
627
703
 
628
- This class extends the base Predictor class to implement SAM2-specific functionality for image
629
- segmentation tasks. It provides methods for model initialization, feature extraction, and
630
- prompt-based inference.
704
+ This class extends the base Predictor class to implement SAM2-specific functionality for image segmentation tasks.
705
+ It provides methods for model initialization, feature extraction, and prompt-based inference.
631
706
 
632
707
  Attributes:
633
- _bb_feat_sizes (List[Tuple[int, int]]): Feature sizes for different backbone levels.
708
+ _bb_feat_sizes (list[tuple]): Feature sizes for different backbone levels.
634
709
  model (torch.nn.Module): The loaded SAM2 model.
635
710
  device (torch.device): The device (CPU or GPU) on which the model is loaded.
636
- features (Dict[str, torch.Tensor]): Cached image features for efficient inference.
711
+ features (dict): Cached image features for efficient inference.
637
712
  segment_all (bool): Flag to indicate if all segments should be predicted.
638
- prompts (dict): Dictionary to store various types of prompts for inference.
713
+ prompts (dict[str, Any]): Dictionary to store various types of prompts for inference.
639
714
 
640
715
  Methods:
641
- get_model: Retrieves and initializes the SAM2 model.
642
- prompt_inference: Performs image segmentation inference based on various prompts.
643
- set_image: Preprocesses and sets a single image for inference.
644
- get_im_features: Extracts and processes image features using SAM2's image encoder.
716
+ get_model: Retrieve and initialize the SAM2 model.
717
+ prompt_inference: Perform image segmentation inference based on various prompts.
718
+ set_image: Preprocess and set a single image for inference.
719
+ get_im_features: Extract and process image features using SAM2's image encoder.
645
720
 
646
721
  Examples:
647
722
  >>> predictor = SAM2Predictor(cfg)
@@ -658,100 +733,36 @@ class SAM2Predictor(Predictor):
658
733
  ]
659
734
 
660
735
  def get_model(self):
661
- """Retrieves and initializes the Segment Anything Model 2 (SAM2) for image segmentation tasks."""
736
+ """Retrieve and initialize the Segment Anything Model 2 (SAM2) for image segmentation tasks."""
662
737
  from .build import build_sam # slow import
663
738
 
664
739
  return build_sam(self.args.model)
665
740
 
666
- def prompt_inference(
667
- self,
668
- im,
669
- bboxes=None,
670
- points=None,
671
- labels=None,
672
- masks=None,
673
- multimask_output=False,
674
- img_idx=-1,
675
- ):
676
- """
677
- Performs image segmentation inference based on various prompts using SAM2 architecture.
678
-
679
- This method leverages the Segment Anything Model 2 (SAM2) to generate segmentation masks for input images
680
- based on provided prompts such as bounding boxes, points, or existing masks. It supports both single and
681
- multi-object prediction scenarios.
741
+ def _prepare_prompts(self, dst_shape, src_shape, bboxes=None, points=None, labels=None, masks=None):
742
+ """Prepare and transform the input prompts for processing based on the destination shape.
682
743
 
683
744
  Args:
684
- im (torch.Tensor): Preprocessed input image tensor with shape (N, C, H, W).
685
- bboxes (np.ndarray | List[List[float]] | None): Bounding boxes in XYXY format with shape (N, 4).
686
- points (np.ndarray | List[List[float]] | None): Object location points with shape (N, 2), in pixels.
687
- labels (np.ndarray | List[int] | None): Point prompt labels with shape (N,). 1 = foreground, 0 = background.
688
- masks (np.ndarray | None): Low-resolution masks from previous predictions with shape (N, H, W).
689
- multimask_output (bool): Flag to return multiple masks for ambiguous prompts.
690
- img_idx (int): Index of the image in the batch to process.
745
+ dst_shape (tuple[int, int]): The target shape (height, width) for the prompts.
746
+ src_shape (tuple[int, int]): The source shape (height, width) of the input image.
747
+ bboxes (np.ndarray | list | None): Bounding boxes in XYXY format with shape (N, 4).
748
+ points (np.ndarray | list | None): Points indicating object locations with shape (N, 2) or (N, num_points,
749
+ 2), in pixels.
750
+ labels (np.ndarray | list | None): Point prompt labels with shape (N,) or (N, num_points). 1 for foreground,
751
+ 0 for background.
752
+ masks (list | np.ndarray | None): Masks for the objects, where each mask is a 2D array.
691
753
 
692
754
  Returns:
693
- (np.ndarray): Output masks with shape (C, H, W), where C is the number of generated masks.
694
- (np.ndarray): Quality scores for each mask, with length C.
695
-
696
- Examples:
697
- >>> predictor = SAM2Predictor(cfg)
698
- >>> image = torch.rand(1, 3, 640, 640)
699
- >>> bboxes = [[100, 100, 200, 200]]
700
- >>> result = predictor(image, bboxes=bboxes)[0]
701
- >>> print(f"Generated {result.masks.shape[0]} masks with average score {result.boxes.conf.mean():.2f}")
702
-
703
- Notes:
704
- - The method supports batched inference for multiple objects when points or bboxes are provided.
705
- - Input prompts (bboxes, points) are automatically scaled to match the input image dimensions.
706
- - When both bboxes and points are provided, they are merged into a single 'points' input for the model.
707
- """
708
- features = self.get_im_features(im) if self.features is None else self.features
709
-
710
- points, labels, masks = self._prepare_prompts(im.shape[2:], bboxes, points, labels, masks)
711
- points = (points, labels) if points is not None else None
712
-
713
- sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder(
714
- points=points,
715
- boxes=None,
716
- masks=masks,
717
- )
718
- # Predict masks
719
- batched_mode = points is not None and points[0].shape[0] > 1 # multi object prediction
720
- high_res_features = [feat_level[img_idx].unsqueeze(0) for feat_level in features["high_res_feats"]]
721
- pred_masks, pred_scores, _, _ = self.model.sam_mask_decoder(
722
- image_embeddings=features["image_embed"][img_idx].unsqueeze(0),
723
- image_pe=self.model.sam_prompt_encoder.get_dense_pe(),
724
- sparse_prompt_embeddings=sparse_embeddings,
725
- dense_prompt_embeddings=dense_embeddings,
726
- multimask_output=multimask_output,
727
- repeat_image=batched_mode,
728
- high_res_features=high_res_features,
729
- )
730
- # (N, d, H, W) --> (N*d, H, W), (N, d) --> (N*d, )
731
- # `d` could be 1 or 3 depends on `multimask_output`.
732
- return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1)
733
-
734
- def _prepare_prompts(self, dst_shape, bboxes=None, points=None, labels=None, masks=None):
735
- """
736
- Prepares and transforms the input prompts for processing based on the destination shape.
737
-
738
- Args:
739
- dst_shape (tuple): The target shape (height, width) for the prompts.
740
- bboxes (np.ndarray | List | None): Bounding boxes in XYXY format with shape (N, 4).
741
- points (np.ndarray | List | None): Points indicating object locations with shape (N, 2) or (N, num_points, 2), in pixels.
742
- labels (np.ndarray | List | None): Point prompt labels with shape (N,) or (N, num_points). 1 for foreground, 0 for background.
743
- masks (List | np.ndarray, Optional): Masks for the objects, where each mask is a 2D array.
755
+ points (torch.Tensor | None): Transformed points.
756
+ labels (torch.Tensor | None): Transformed labels.
757
+ masks (torch.Tensor | None): Transformed masks.
744
758
 
745
759
  Raises:
746
760
  AssertionError: If the number of points don't match the number of labels, in case labels were passed.
747
-
748
- Returns:
749
- (tuple): A tuple containing transformed points, labels, and masks.
750
761
  """
751
- bboxes, points, labels, masks = super()._prepare_prompts(dst_shape, bboxes, points, labels, masks)
762
+ bboxes, points, labels, masks = super()._prepare_prompts(dst_shape, src_shape, bboxes, points, labels, masks)
752
763
  if bboxes is not None:
753
764
  bboxes = bboxes.view(-1, 2, 2)
754
- bbox_labels = torch.tensor([[2, 3]], dtype=torch.int32, device=bboxes.device).expand(len(bboxes), -1)
765
+ bbox_labels = torch.tensor([[2, 3]], dtype=torch.int32, device=bboxes.device).expand(bboxes.shape[0], -1)
755
766
  # NOTE: merge "boxes" and "points" into a single "points" input
756
767
  # (where boxes are added at the beginning) to model.sam_prompt_encoder
757
768
  if points is not None:
@@ -762,11 +773,10 @@ class SAM2Predictor(Predictor):
762
773
  return points, labels, masks
763
774
 
764
775
  def set_image(self, image):
765
- """
766
- Preprocesses and sets a single image for inference using the SAM2 model.
776
+ """Preprocess and set a single image for inference using the SAM2 model.
767
777
 
768
- This method initializes the model if not already done, configures the data source to the specified image,
769
- and preprocesses the image for feature extraction. It supports setting only one image at a time.
778
+ This method initializes the model if not already done, configures the data source to the specified image, and
779
+ preprocesses the image for feature extraction. It supports setting only one image at a time.
770
780
 
771
781
  Args:
772
782
  image (str | np.ndarray): Path to the image file as a string, or a numpy array representing the image.
@@ -794,7 +804,7 @@ class SAM2Predictor(Predictor):
794
804
  break
795
805
 
796
806
  def get_im_features(self, im):
797
- """Extracts image features from the SAM image encoder for subsequent processing."""
807
+ """Extract image features from the SAM image encoder for subsequent processing."""
798
808
  assert isinstance(self.imgsz, (tuple, list)) and self.imgsz[0] == self.imgsz[1], (
799
809
  f"SAM 2 models only support square image size, but got {self.imgsz}."
800
810
  )
@@ -806,50 +816,108 @@ class SAM2Predictor(Predictor):
806
816
  if self.model.directly_add_no_mem_embed:
807
817
  vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed
808
818
  feats = [
809
- feat.permute(1, 2, 0).view(1, -1, *feat_size)
810
- for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
811
- ][::-1]
819
+ feat.permute(1, 2, 0).view(1, -1, *feat_size) for feat, feat_size in zip(vision_feats, self._bb_feat_sizes)
820
+ ]
812
821
  return {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
813
822
 
823
+ def _inference_features(
824
+ self,
825
+ features,
826
+ points=None,
827
+ labels=None,
828
+ masks=None,
829
+ multimask_output=False,
830
+ img_idx=-1,
831
+ ):
832
+ """Perform inference on image features using the SAM2 model.
833
+
834
+ Args:
835
+ features (torch.Tensor | dict[str, Any]): Extracted image features with shape (B, C, H, W) from the SAM2
836
+ model image encoder, it could also be a dictionary including:
837
+ - image_embed (torch.Tensor): Image embedding with shape (B, C, H, W).
838
+ - high_res_feats (list[torch.Tensor]): List of high-resolution feature maps from the backbone, each with shape (B, C, H, W).
839
+ points (np.ndarray | list[list[float]] | None): Object location points with shape (N, 2), in pixels.
840
+ labels (np.ndarray | list[int] | None): Point prompt labels with shape (N,). 1 = foreground, 0 = background.
841
+ masks (list[np.ndarray] | np.ndarray | None): Masks for the objects, where each mask is a 2D array.
842
+ multimask_output (bool): Flag to return multiple masks for ambiguous prompts.
843
+ img_idx (int): Index of the image in the batch to process.
844
+
845
+ Returns:
846
+ pred_masks (torch.Tensor): Output masks with shape (C, H, W), where C is the number of generated masks.
847
+ pred_scores (torch.Tensor): Quality scores for each mask, with length C.
848
+ """
849
+ points = (points, labels) if points is not None else None
850
+ sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder(
851
+ points=points,
852
+ boxes=None,
853
+ masks=masks,
854
+ )
855
+ # Predict masks
856
+ batched_mode = points is not None and points[0].shape[0] > 1 # multi object prediction
857
+ high_res_features = None
858
+ if isinstance(features, dict):
859
+ high_res_features = [feat_level[img_idx].unsqueeze(0) for feat_level in features["high_res_feats"]]
860
+ features = features["image_embed"][[img_idx]]
861
+ pred_masks, pred_scores, _, _ = self.model.sam_mask_decoder(
862
+ image_embeddings=features,
863
+ image_pe=self.model.sam_prompt_encoder.get_dense_pe(),
864
+ sparse_prompt_embeddings=sparse_embeddings,
865
+ dense_prompt_embeddings=dense_embeddings,
866
+ multimask_output=multimask_output,
867
+ repeat_image=batched_mode,
868
+ high_res_features=high_res_features,
869
+ )
870
+ # (N, d, H, W) --> (N*d, H, W), (N, d) --> (N*d, )
871
+ # `d` could be 1 or 3 depending on `multimask_output`.
872
+ return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1)
873
+
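Because get_im_features caches the image embedding and _inference_features only re-runs the prompt encoder and mask decoder, a single encoded image can serve many prompts. A minimal sketch of that pattern through the public predictor interface (the weights name and image path are placeholders, following the documented SAM2Predictor.set_image workflow):

from ultralytics.models.sam import SAM2Predictor

overrides = dict(conf=0.25, task="segment", mode="predict", imgsz=1024, model="sam2_b.pt")
predictor = SAM2Predictor(overrides=overrides)

predictor.set_image("path/to/image.jpg")                   # image encoder runs once here
results_box = predictor(bboxes=[[100, 100, 200, 200]])     # decoder-only call reusing cached features
results_pts = predictor(points=[[150, 150]], labels=[1])   # another prompt against the same features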
814
874
 
815
875
  class SAM2VideoPredictor(SAM2Predictor):
816
- """
817
- SAM2VideoPredictor to handle user interactions with videos and manage inference states.
876
+ """SAM2VideoPredictor to handle user interactions with videos and manage inference states.
818
877
 
819
- This class extends the functionality of SAM2Predictor to support video processing and maintains
820
- the state of inference operations. It includes configurations for managing non-overlapping masks,
821
- clearing memory for non-conditional inputs, and setting up callbacks for prediction events.
878
+ This class extends the functionality of SAM2Predictor to support video processing and maintains the state of
879
+ inference operations. It includes configurations for managing non-overlapping masks, clearing memory for
880
+ non-conditional inputs, and setting up callbacks for prediction events.
822
881
 
823
882
  Attributes:
824
883
  inference_state (dict): A dictionary to store the current state of inference operations.
825
884
  non_overlap_masks (bool): A flag indicating whether masks should be non-overlapping.
826
885
  clear_non_cond_mem_around_input (bool): A flag to control clearing non-conditional memory around inputs.
827
- clear_non_cond_mem_for_multi_obj (bool): A flag to control clearing non-conditional memory for multi-object scenarios.
886
+ clear_non_cond_mem_for_multi_obj (bool): A flag to control clearing non-conditional memory for multi-object
887
+ scenarios.
828
888
  callbacks (dict): A dictionary of callbacks for various prediction lifecycle events.
829
889
 
830
- Args:
831
- cfg (dict, Optional): Configuration settings for the predictor. Defaults to DEFAULT_CFG.
832
- overrides (dict, Optional): Additional configuration overrides. Defaults to None.
833
- _callbacks (list, Optional): Custom callbacks to be added. Defaults to None.
890
+ Methods:
891
+ get_model: Retrieve and configure the model with binarization enabled.
892
+ inference: Perform image segmentation inference based on the given input cues.
893
+ postprocess: Post-process the predictions to apply non-overlapping constraints if required.
894
+ add_new_prompts: Add new points or masks to a specific frame for a given object ID.
895
+ propagate_in_video_preflight: Prepare inference_state and consolidate temporary outputs before tracking.
896
+ init_state: Initialize an inference state for the predictor.
897
+ get_im_features: Extract image features using SAM2's image encoder for subsequent segmentation tasks.
898
+
899
+ Examples:
900
+ >>> predictor = SAM2VideoPredictor(cfg=DEFAULT_CFG)
901
+ >>> predictor.set_image("path/to/video_frame.jpg")
902
+ >>> bboxes = [[100, 100, 200, 200]]
903
+ >>> results = predictor(bboxes=bboxes)
834
904
 
835
- Note:
905
+ Notes:
836
906
  The `fill_hole_area` attribute is defined but not used in the current implementation.
837
907
  """
838
908
 
839
909
  # fill_hole_area = 8 # not used
840
910
 
841
911
  def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
842
- """
843
- Initialize the predictor with configuration and optional overrides.
912
+ """Initialize the predictor with configuration and optional overrides.
844
913
 
845
- This constructor initializes the SAM2VideoPredictor with a given configuration, applies any
846
- specified overrides, and sets up the inference state along with certain flags
847
- that control the behavior of the predictor.
914
+ This constructor initializes the SAM2VideoPredictor with a given configuration, applies any specified overrides,
915
+ and sets up the inference state along with certain flags that control the behavior of the predictor.
848
916
 
849
917
  Args:
850
918
  cfg (dict): Configuration dictionary containing default settings.
851
- overrides (Dict | None): Dictionary of values to override default configuration.
852
- _callbacks (Dict | None): Dictionary of callback functions to customize behavior.
919
+ overrides (dict | None): Dictionary of values to override default configuration.
920
+ _callbacks (dict | None): Dictionary of callback functions to customize behavior.
853
921
 
854
922
  Examples:
855
923
  >>> predictor = SAM2VideoPredictor(cfg=DEFAULT_CFG)
@@ -864,10 +932,9 @@ class SAM2VideoPredictor(SAM2Predictor):
864
932
  self.callbacks["on_predict_start"].append(self.init_state)
865
933
 
866
934
  def get_model(self):
867
- """
868
- Retrieves and configures the model with binarization enabled.
935
+ """Retrieve and configure the model with binarization enabled.
869
936
 
870
- Note:
937
+ Notes:
871
938
  This method overrides the base class implementation to set the binarize flag to True.
872
939
  """
873
940
  model = super().get_model()
@@ -875,21 +942,20 @@ class SAM2VideoPredictor(SAM2Predictor):
875
942
  return model
876
943
 
877
944
  def inference(self, im, bboxes=None, points=None, labels=None, masks=None):
878
- """
879
- Perform image segmentation inference based on the given input cues, using the currently loaded image. This
880
- method leverages SAM's (Segment Anything Model) architecture consisting of image encoder, prompt encoder, and
881
- mask decoder for real-time and promptable segmentation tasks.
945
+ """Perform image segmentation inference based on the given input cues, using the currently loaded image. This
946
+ method leverages SAM's (Segment Anything Model) architecture consisting of image encoder, prompt
947
+ encoder, and mask decoder for real-time and promptable segmentation tasks.
882
948
 
883
949
  Args:
884
950
  im (torch.Tensor): The preprocessed input image in tensor format, with shape (N, C, H, W).
885
- bboxes (np.ndarray | List, optional): Bounding boxes with shape (N, 4), in XYXY format.
886
- points (np.ndarray | List, optional): Points indicating object locations with shape (N, 2), in pixels.
887
- labels (np.ndarray | List, optional): Labels for point prompts, shape (N, ). 1 = foreground, 0 = background.
951
+ bboxes (np.ndarray | list, optional): Bounding boxes with shape (N, 4), in XYXY format.
952
+ points (np.ndarray | list, optional): Points indicating object locations with shape (N, 2), in pixels.
953
+ labels (np.ndarray | list, optional): Labels for point prompts, shape (N, ). 1 = foreground, 0 = background.
888
954
  masks (np.ndarray, optional): Low-resolution masks from previous predictions shape (N,H,W). For SAM H=W=256.
889
955
 
890
956
  Returns:
891
- (np.ndarray): The output masks in shape CxHxW, where C is the number of generated masks.
892
- (np.ndarray): An array of length C containing quality scores predicted by the model for each mask.
957
+ pred_masks (torch.Tensor): The output masks in shape CxHxW, where C is the number of generated masks.
958
+ pred_scores (torch.Tensor): An array of length C containing predicted quality scores for each mask.
893
959
  """
894
960
  # Override prompts if any stored in self.prompts
895
961
  bboxes = self.prompts.pop("bboxes", bboxes)
@@ -900,7 +966,9 @@ class SAM2VideoPredictor(SAM2Predictor):
900
966
  self.inference_state["im"] = im
901
967
  output_dict = self.inference_state["output_dict"]
902
968
  if len(output_dict["cond_frame_outputs"]) == 0: # initialize prompts
903
- points, labels, masks = self._prepare_prompts(im.shape[2:], bboxes, points, labels, masks)
969
+ points, labels, masks = self._prepare_prompts(
970
+ im.shape[2:], self.batch[1][0].shape[:2], bboxes, points, labels, masks
971
+ )
904
972
  if points is not None:
905
973
  for i in range(len(points)):
906
974
  self.add_new_prompts(obj_id=i, points=points[[i]], labels=labels[[i]], frame_idx=frame)
@@ -943,25 +1011,24 @@ class SAM2VideoPredictor(SAM2Predictor):
943
1011
  pred_masks = current_out["pred_masks"].flatten(0, 1)
944
1012
  pred_masks = pred_masks[(pred_masks > self.model.mask_threshold).sum((1, 2)) > 0] # filter blank masks
945
1013
 
946
- return pred_masks, torch.ones(len(pred_masks), dtype=pred_masks.dtype, device=pred_masks.device)
1014
+ return pred_masks, torch.ones(pred_masks.shape[0], dtype=pred_masks.dtype, device=pred_masks.device)
947
1015
 
948
1016
  def postprocess(self, preds, img, orig_imgs):
949
- """
950
- Post-processes the predictions to apply non-overlapping constraints if required.
1017
+ """Post-process the predictions to apply non-overlapping constraints if required.
951
1018
 
952
- This method extends the post-processing functionality by applying non-overlapping constraints
953
- to the predicted masks if the `non_overlap_masks` flag is set to True. This ensures that
954
- the masks do not overlap, which can be useful for certain applications.
1019
+ This method extends the post-processing functionality by applying non-overlapping constraints to the predicted
1020
+ masks if the `non_overlap_masks` flag is set to True. This ensures that the masks do not overlap, which can be
1021
+ useful for certain applications.
955
1022
 
956
1023
  Args:
957
- preds (Tuple[torch.Tensor]): The predictions from the model.
1024
+ preds (tuple[torch.Tensor, torch.Tensor]): The predicted masks and scores from the model.
958
1025
  img (torch.Tensor): The processed image tensor.
959
- orig_imgs (List[np.ndarray]): The original images before processing.
1026
+ orig_imgs (list[np.ndarray]): The original images before processing.
960
1027
 
961
1028
  Returns:
962
- results (list): The post-processed predictions.
1029
+ (list): The post-processed predictions.
963
1030
 
964
- Note:
1031
+ Notes:
965
1032
  If `non_overlap_masks` is True, the method applies constraints to ensure non-overlapping masks.
966
1033
  """
967
1034
  results = super().postprocess(preds, img, orig_imgs)
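The Notes above reference the non-overlapping constraint. One common way to realize it (a sketch, not necessarily the exact op the model applies) is a per-pixel argmax over the stacked mask logits, so every pixel is kept only in its highest-scoring mask:

import torch

def non_overlapping(mask_logits: torch.Tensor) -> torch.Tensor:
    """Suppress all but the highest-scoring mask at each pixel; mask_logits has shape (N, H, W)."""
    winners = mask_logits.argmax(dim=0, keepdim=True)  # (1, H, W)
    keep = winners == torch.arange(mask_logits.shape[0], device=mask_logits.device).view(-1, 1, 1)
    # Pixels a mask does not win are pushed below any reasonable threshold.
    return torch.where(keep, mask_logits, mask_logits.clamp(max=-10.0))

print(non_overlapping(torch.randn(3, 64, 64)).shape)  # torch.Size([3, 64, 64])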
@@ -981,28 +1048,28 @@ class SAM2VideoPredictor(SAM2Predictor):
981
1048
  masks=None,
982
1049
  frame_idx=0,
983
1050
  ):
984
- """
985
- Adds new points or masks to a specific frame for a given object ID.
1051
+ """Add new points or masks to a specific frame for a given object ID.
986
1052
 
987
- This method updates the inference state with new prompts (points or masks) for a specified
988
- object and frame index. It ensures that the prompts are either points or masks, but not both,
989
- and updates the internal state accordingly. It also handles the generation of new segmentations
990
- based on the provided prompts and the existing state.
1053
+ This method updates the inference state with new prompts (points or masks) for a specified object and frame
1054
+ index. It ensures that the prompts are either points or masks, but not both, and updates the internal state
1055
+ accordingly. It also handles the generation of new segmentations based on the provided prompts and the existing
1056
+ state.
991
1057
 
992
1058
  Args:
993
1059
  obj_id (int): The ID of the object to which the prompts are associated.
994
- points (torch.Tensor, Optional): The coordinates of the points of interest. Defaults to None.
995
- labels (torch.Tensor, Optional): The labels corresponding to the points. Defaults to None.
996
- masks (torch.Tensor, optional): Binary masks for the object. Defaults to None.
997
- frame_idx (int, optional): The index of the frame to which the prompts are applied. Defaults to 0.
1060
+ points (torch.Tensor, optional): The coordinates of the points of interest.
1061
+ labels (torch.Tensor, optional): The labels corresponding to the points.
1062
+ masks (torch.Tensor, optional): Binary masks for the object.
1063
+ frame_idx (int, optional): The index of the frame to which the prompts are applied.
998
1064
 
999
1065
  Returns:
1000
- (tuple): A tuple containing the flattened predicted masks and a tensor of ones indicating the number of objects.
1066
+ pred_masks (torch.Tensor): The flattened predicted masks.
1067
+ pred_scores (torch.Tensor): A tensor of ones, one entry per predicted mask (placeholder confidence scores).
1001
1068
 
1002
1069
  Raises:
1003
1070
  AssertionError: If both `masks` and `points` are provided, or neither is provided.
1004
1071
 
1005
- Note:
1072
+ Notes:
1006
1073
  - Only one type of prompt (either points or masks) can be added per call.
1007
1074
  - If the frame is being tracked for the first time, it is treated as an initial conditioning frame.
1008
1075
  - The method handles the consolidation of outputs and resizing of masks to the original video resolution.
@@ -1043,7 +1110,9 @@ class SAM2VideoPredictor(SAM2Predictor):
1043
1110
  )
1044
1111
 
1045
1112
  if prev_out is not None and prev_out.get("pred_masks") is not None:
1046
- prev_sam_mask_logits = prev_out["pred_masks"].to(device=self.device, non_blocking=True)
1113
+ prev_sam_mask_logits = prev_out["pred_masks"].to(
1114
+ device=self.device, non_blocking=self.device.type == "cuda"
1115
+ )
1047
1116
  # Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues.
1048
1117
  prev_sam_mask_logits.clamp_(-32.0, 32.0)
1049
1118
  current_out = self._run_single_frame_inference(
@@ -1075,13 +1144,12 @@ class SAM2VideoPredictor(SAM2Predictor):
1075
1144
 
1076
1145
  @smart_inference_mode()
1077
1146
  def propagate_in_video_preflight(self):
1078
- """
1079
- Prepare inference_state and consolidate temporary outputs before tracking.
1147
+ """Prepare inference_state and consolidate temporary outputs before tracking.
1080
1148
 
1081
- This method marks the start of tracking, disallowing the addition of new objects until the session is reset.
1082
- It consolidates temporary outputs from `temp_output_dict_per_obj` and merges them into `output_dict`.
1083
- Additionally, it clears non-conditioning memory around input frames and ensures that the state is consistent
1084
- with the provided inputs.
1149
+ This method marks the start of tracking, disallowing the addition of new objects until the session is reset. It
1150
+ consolidates temporary outputs from `temp_output_dict_per_obj` and merges them into `output_dict`. Additionally,
1151
+ it clears non-conditioning memory around input frames and ensures that the state is consistent with the provided
1152
+ inputs.
1085
1153
  """
1086
1154
  # Tracking has started and we don't allow adding new objects until session is reset.
1087
1155
  self.inference_state["tracking_has_started"] = True
@@ -1146,12 +1214,11 @@ class SAM2VideoPredictor(SAM2Predictor):
1146
1214
 
1147
1215
  @staticmethod
1148
1216
  def init_state(predictor):
1149
- """
1150
- Initialize an inference state for the predictor.
1217
+ """Initialize an inference state for the predictor.
1151
1218
 
1152
- This function sets up the initial state required for performing inference on video data.
1153
- It includes initializing various dictionaries and ordered dictionaries that will store
1154
- inputs, outputs, and other metadata relevant to the tracking process.
1219
+ This function sets up the initial state required for performing inference on video data. It includes
1220
+ initializing various dictionaries and ordered dictionaries that will store inputs, outputs, and other metadata
1221
+ relevant to the tracking process.
1155
1222
 
1156
1223
  Args:
1157
1224
  predictor (SAM2VideoPredictor): The predictor object for which to initialize the state.
@@ -1193,22 +1260,22 @@ class SAM2VideoPredictor(SAM2Predictor):
1193
1260
  predictor.inference_state = inference_state
1194
1261
 
1195
1262
  def get_im_features(self, im, batch=1):
1196
- """
1197
- Extracts and processes image features using SAM2's image encoder for subsequent segmentation tasks.
1263
+ """Extract and process image features using SAM2's image encoder for subsequent segmentation tasks.
1198
1264
 
1199
1265
  Args:
1200
1266
  im (torch.Tensor): The input image tensor.
1201
- batch (int, optional): The batch size for expanding features if there are multiple prompts. Defaults to 1.
1267
+ batch (int, optional): The batch size for expanding features if there are multiple prompts.
1202
1268
 
1203
1269
  Returns:
1204
1270
  vis_feats (torch.Tensor): The visual features extracted from the image.
1205
1271
  vis_pos_embed (torch.Tensor): The positional embeddings for the visual features.
1206
- feat_sizes (List(Tuple[int])): A list containing the sizes of the extracted features.
1272
+ feat_sizes (list[tuple]): A list containing the sizes of the extracted features.
1207
1273
 
1208
- Note:
1274
+ Notes:
1209
1275
  - If `batch` is greater than 1, the features are expanded to fit the batch size.
1210
1276
  - The method leverages the model's `_prepare_backbone_features` method to prepare the backbone features.
1211
1277
  """
1278
+ self.model.set_imgsz(self.imgsz)
1212
1279
  backbone_out = self.model.forward_image(im)
1213
1280
  if batch > 1: # expand features if there's more than one prompt
1214
1281
  for i, feat in enumerate(backbone_out["backbone_fpn"]):
@@ -1220,19 +1287,18 @@ class SAM2VideoPredictor(SAM2Predictor):
1220
1287
  return vis_feats, vis_pos_embed, feat_sizes
1221
1288
 
1222
1289
  def _obj_id_to_idx(self, obj_id):
1223
- """
1224
- Map client-side object id to model-side object index.
1290
+ """Map client-side object id to model-side object index.
1225
1291
 
1226
1292
  Args:
1227
1293
  obj_id (int): The unique identifier of the object provided by the client side.
1228
1294
 
1229
1295
  Returns:
1230
- obj_idx (int): The index of the object on the model side.
1296
+ (int): The index of the object on the model side.
1231
1297
 
1232
1298
  Raises:
1233
1299
  RuntimeError: If an attempt is made to add a new object after tracking has started.
1234
1300
 
1235
- Note:
1301
+ Notes:
1236
1302
  - The method updates or retrieves mappings between object IDs and indices stored in
1237
1303
  `inference_state`.
1238
1304
  - It ensures that new objects can only be added before tracking commences.
@@ -1283,27 +1349,26 @@ class SAM2VideoPredictor(SAM2Predictor):
1283
1349
  run_mem_encoder,
1284
1350
  prev_sam_mask_logits=None,
1285
1351
  ):
1286
- """
1287
- Run tracking on a single frame based on current inputs and previous memory.
1352
+ """Run tracking on a single frame based on current inputs and previous memory.
1288
1353
 
1289
1354
  Args:
1290
1355
  output_dict (dict): The dictionary containing the output states of the tracking process.
1291
1356
  frame_idx (int): The index of the current frame.
1292
1357
  batch_size (int): The batch size for processing the frame.
1293
1358
  is_init_cond_frame (bool): Indicates if the current frame is an initial conditioning frame.
1294
- point_inputs (dict, Optional): Input points and their labels. Defaults to None.
1295
- mask_inputs (torch.Tensor, Optional): Input binary masks. Defaults to None.
1359
+ point_inputs (dict | None): Input points and their labels.
1360
+ mask_inputs (torch.Tensor | None): Input binary masks.
1296
1361
  reverse (bool): Indicates if the tracking should be performed in reverse order.
1297
1362
  run_mem_encoder (bool): Indicates if the memory encoder should be executed.
1298
- prev_sam_mask_logits (torch.Tensor, Optional): Previous mask logits for the current object. Defaults to None.
1363
+ prev_sam_mask_logits (torch.Tensor | None): Previous mask logits for the current object.
1299
1364
 
1300
1365
  Returns:
1301
- current_out (dict): A dictionary containing the output of the tracking step, including updated features and predictions.
1366
+ (dict): A dictionary containing the output of the tracking step, including updated features and predictions.
1302
1367
 
1303
1368
  Raises:
1304
1369
  AssertionError: If both `point_inputs` and `mask_inputs` are provided, or neither is provided.
1305
1370
 
1306
- Note:
1371
+ Notes:
1307
1372
  - The method assumes that `point_inputs` and `mask_inputs` are mutually exclusive.
1308
1373
  - The method retrieves image features using the `get_im_features` method.
1309
1374
  - The `maskmem_pos_enc` is assumed to be constant across frames, hence only one copy is stored.
@@ -1334,12 +1399,12 @@ class SAM2VideoPredictor(SAM2Predictor):
1334
1399
  maskmem_features = current_out["maskmem_features"]
1335
1400
  if maskmem_features is not None:
1336
1401
  current_out["maskmem_features"] = maskmem_features.to(
1337
- dtype=torch.float16, device=self.device, non_blocking=True
1402
+ dtype=torch.float16, device=self.device, non_blocking=self.device.type == "cuda"
1338
1403
  )
1339
1404
  # NOTE: Do not support the `fill_holes_in_mask_scores` function since it needs cuda extensions
1340
1405
  # potentially fill holes in the predicted masks
1341
1406
  # if self.fill_hole_area > 0:
1342
- # pred_masks = current_out["pred_masks"].to(self.device, non_blocking=True)
1407
+ # pred_masks = current_out["pred_masks"].to(self.device, non_blocking=self.device.type == "cuda")
1343
1408
  # pred_masks = fill_holes_in_mask_scores(pred_masks, self.fill_hole_area)
1344
1409
 
1345
1410
  # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
@@ -1347,24 +1412,22 @@ class SAM2VideoPredictor(SAM2Predictor):
1347
1412
  return current_out
1348
1413
 
1349
1414
  def _get_maskmem_pos_enc(self, out_maskmem_pos_enc):
1350
- """
1351
- Caches and manages the positional encoding for mask memory across frames and objects.
1415
+ """Cache and manage the positional encoding for mask memory across frames and objects.
1352
1416
 
1353
- This method optimizes storage by caching the positional encoding (`maskmem_pos_enc`) for
1354
- mask memory, which is constant across frames and objects, thus reducing the amount of
1355
- redundant information stored during an inference session. It checks if the positional
1356
- encoding has already been cached; if not, it caches a slice of the provided encoding.
1357
- If the batch size is greater than one, it expands the cached positional encoding to match
1358
- the current batch size.
1417
+ This method optimizes storage by caching the positional encoding (`maskmem_pos_enc`) for mask memory, which is
1418
+ constant across frames and objects, thus reducing the amount of redundant information stored during an inference
1419
+ session. It checks if the positional encoding has already been cached; if not, it caches a slice of the provided
1420
+ encoding. If the batch size is greater than one, it expands the cached positional encoding to match the current
1421
+ batch size.
1359
1422
 
1360
1423
  Args:
1361
- out_maskmem_pos_enc (List[torch.Tensor] or None): The positional encoding for mask memory.
1362
- Should be a list of tensors or None.
1424
+ out_maskmem_pos_enc (list[torch.Tensor] | None): The positional encoding for mask memory. Should be a list
1425
+ of tensors or None.
1363
1426
 
1364
1427
  Returns:
1365
- out_maskmem_pos_enc (List[torch.Tensor]): The positional encoding for mask memory, either cached or expanded.
1428
+ (list[torch.Tensor]): The positional encoding for mask memory, either cached or expanded.
1366
1429
 
1367
- Note:
1430
+ Notes:
1368
1431
  - The method assumes that `out_maskmem_pos_enc` is a list of tensors or None.
1369
1432
  - Only a single object's slice is cached since the encoding is the same across objects.
1370
1433
  - The method checks if the positional encoding has already been cached in the session's constants.
@@ -1381,7 +1444,7 @@ class SAM2VideoPredictor(SAM2Predictor):
1381
1444
  else:
1382
1445
  maskmem_pos_enc = model_constants["maskmem_pos_enc"]
1383
1446
  # expand the cached maskmem_pos_enc to the actual batch size
1384
- batch_size = out_maskmem_pos_enc[0].size(0)
1447
+ batch_size = out_maskmem_pos_enc[0].shape[0]
1385
1448
  if batch_size > 1:
1386
1449
  out_maskmem_pos_enc = [x.expand(batch_size, -1, -1, -1) for x in maskmem_pos_enc]
1387
1450
  return out_maskmem_pos_enc
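A small illustration of the batch expansion above: the cached single-slice positional encoding is broadcast to the live batch size with expand, which creates a view rather than copying data (shapes are made up for the example):

import torch

cached = torch.randn(1, 64, 32, 32)               # one cached maskmem_pos_enc slice
expanded = cached.expand(4, -1, -1, -1)           # batch size 4, no data copy
assert expanded.shape == (4, 64, 32, 32)
assert expanded.data_ptr() == cached.data_ptr()   # same underlying storage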
@@ -1392,25 +1455,23 @@ class SAM2VideoPredictor(SAM2Predictor):
1392
1455
  is_cond=False,
1393
1456
  run_mem_encoder=False,
1394
1457
  ):
1395
- """
1396
- Consolidates per-object temporary outputs into a single output for all objects.
1458
+ """Consolidate per-object temporary outputs into a single output for all objects.
1397
1459
 
1398
1460
  This method combines the temporary outputs for each object on a given frame into a unified
1399
1461
  output. It fills in any missing objects either from the main output dictionary or leaves
1400
- placeholders if they do not exist in the main output. Optionally, it can re-run the memory
1401
- encoder after applying non-overlapping constraints to the object scores.
1462
+ placeholders if they do not exist in the main output. Optionally, it can re-run the memory encoder after
1463
+ applying non-overlapping constraints to the object scores.
1402
1464
 
1403
1465
  Args:
1404
1466
  frame_idx (int): The index of the frame for which to consolidate outputs.
1405
- is_cond (bool, Optional): Indicates if the frame is considered a conditioning frame.
1406
- Defaults to False.
1407
- run_mem_encoder (bool, Optional): Specifies whether to run the memory encoder after
1408
- consolidating the outputs. Defaults to False.
1467
+ is_cond (bool, optional): Indicates if the frame is considered a conditioning frame.
1468
+ run_mem_encoder (bool, optional): Specifies whether to run the memory encoder after consolidating the
1469
+ outputs.
1409
1470
 
1410
1471
  Returns:
1411
- consolidated_out (dict): A consolidated output dictionary containing the combined results for all objects.
1472
+ (dict): A consolidated output dictionary containing the combined results for all objects.
1412
1473
 
1413
- Note:
1474
+ Notes:
1414
1475
  - The method initializes the consolidated output with placeholder values for missing objects.
1415
1476
  - It searches for outputs in both the temporary and main output dictionaries.
1416
1477
  - If `run_mem_encoder` is True, it applies non-overlapping constraints and re-runs the memory encoder.
@@ -1429,13 +1490,13 @@ class SAM2VideoPredictor(SAM2Predictor):
1429
1490
  "pred_masks": torch.full(
1430
1491
  size=(batch_size, 1, self.imgsz[0] // 4, self.imgsz[1] // 4),
1431
1492
  fill_value=-1024.0,
1432
- dtype=torch.float32,
1493
+ dtype=self.torch_dtype,
1433
1494
  device=self.device,
1434
1495
  ),
1435
1496
  "obj_ptr": torch.full(
1436
1497
  size=(batch_size, self.model.hidden_dim),
1437
1498
  fill_value=-1024.0,
1438
- dtype=torch.float32,
1499
+ dtype=self.torch_dtype,
1439
1500
  device=self.device,
1440
1501
  ),
1441
1502
  "object_score_logits": torch.full(
@@ -1443,7 +1504,7 @@ class SAM2VideoPredictor(SAM2Predictor):
1443
1504
  # default to 10.0 for object_score_logits, i.e. assuming the object is
1444
1505
  # present as sigmoid(10)=1, same as in `predict_masks` of `MaskDecoder`
1445
1506
  fill_value=10.0,
1446
- dtype=torch.float32,
1507
+ dtype=self.torch_dtype,
1447
1508
  device=self.device,
1448
1509
  ),
1449
1510
  }
@@ -1494,8 +1555,7 @@ class SAM2VideoPredictor(SAM2Predictor):
1494
1555
  return consolidated_out
1495
1556
 
1496
1557
  def _get_empty_mask_ptr(self, frame_idx):
1497
- """
1498
- Get a dummy object pointer based on an empty mask on the current frame.
1558
+ """Get a dummy object pointer based on an empty mask on the current frame.
1499
1559
 
1500
1560
  Args:
1501
1561
  frame_idx (int): The index of the current frame for which to generate the dummy object pointer.
@@ -1515,7 +1575,7 @@ class SAM2VideoPredictor(SAM2Predictor):
1515
1575
  feat_sizes=feat_sizes,
1516
1576
  point_inputs=None,
1517
1577
  # A dummy (empty) mask with a single object
1518
- mask_inputs=torch.zeros((1, 1, *self.imgsz), dtype=torch.float32, device=self.device),
1578
+ mask_inputs=torch.zeros((1, 1, *self.imgsz), dtype=self.torch_dtype, device=self.device),
1519
1579
  output_dict={},
1520
1580
  num_frames=self.inference_state["num_frames"],
1521
1581
  track_in_reverse=False,
@@ -1525,8 +1585,7 @@ class SAM2VideoPredictor(SAM2Predictor):
1525
1585
  return current_out["obj_ptr"]
1526
1586
 
1527
1587
  def _run_memory_encoder(self, batch_size, high_res_masks, object_score_logits, is_mask_from_pts):
1528
- """
1529
- Run the memory encoder on masks.
1588
+ """Run the memory encoder on masks.
1530
1589
 
1531
1590
  This is usually after applying non-overlapping constraints to object scores. Since their scores changed, their
1532
1591
  memory also needs to be computed again with the memory encoder.
@@ -1538,7 +1597,8 @@ class SAM2VideoPredictor(SAM2Predictor):
1538
1597
  is_mask_from_pts (bool): Indicates if the mask is derived from point interactions.
1539
1598
 
1540
1599
  Returns:
1541
- (tuple[torch.Tensor, torch.Tensor]): A tuple containing the encoded mask features and positional encoding.
1600
+ maskmem_features (torch.Tensor): The encoded mask features.
1601
+ maskmem_pos_enc (torch.Tensor): The positional encoding.
1542
1602
  """
1543
1603
  # Retrieve correct image features
1544
1604
  current_vision_feats, _, feat_sizes = self.get_im_features(self.inference_state["im"], batch_size)
@@ -1552,11 +1612,12 @@ class SAM2VideoPredictor(SAM2Predictor):
1552
1612
 
1553
1613
  # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
1554
1614
  maskmem_pos_enc = self._get_maskmem_pos_enc(maskmem_pos_enc)
1555
- return maskmem_features.to(dtype=torch.float16, device=self.device, non_blocking=True), maskmem_pos_enc
1615
+ return maskmem_features.to(
1616
+ dtype=torch.float16, device=self.device, non_blocking=self.device.type == "cuda"
1617
+ ), maskmem_pos_enc
1556
1618
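Several hunks in this file gate non_blocking on the device type, as in the return just above. A standalone illustration of the pattern in plain PyTorch (no SAM2-specific assumptions): asynchronous copies only pay off for host-to-CUDA transfers, so the flag is simply switched off on other devices.

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
x = torch.randn(8, 256, 64, 64)
# non_blocking only has an effect for copies onto a CUDA device (ideally from pinned host memory);
# gating it on the device type keeps the identical call safe on CPU or MPS.
y = x.to(device=device, dtype=torch.float16, non_blocking=(device.type == "cuda"))
print(y.device, y.dtype)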
 
1557
1619
  def _add_output_per_object(self, frame_idx, current_out, storage_key):
1558
- """
1559
- Split a multi-object output into per-object output slices and add them into Output_Dict_Per_Obj.
1620
+ """Split a multi-object output into per-object output slices and add them into Output_Dict_Per_Obj.
1560
1621
 
1561
1622
  The resulting slices share the same tensor storage.
1562
1623
 
@@ -1586,12 +1647,12 @@ class SAM2VideoPredictor(SAM2Predictor):
1586
1647
  obj_output_dict[storage_key][frame_idx] = obj_out
1587
1648
 
1588
1649
  def _clear_non_cond_mem_around_input(self, frame_idx):
1589
- """
1590
- Remove the non-conditioning memory around the input frame.
1650
+ """Remove the non-conditioning memory around the input frame.
1591
1651
 
1592
- When users provide correction clicks, the surrounding frames' non-conditioning memories can still contain outdated
1593
- object appearance information and could confuse the model. This method clears those non-conditioning memories
1594
- surrounding the interacted frame to avoid giving the model both old and new information about the object.
1652
+ When users provide correction clicks, the surrounding frames' non-conditioning memories can still contain
1653
+ outdated object appearance information and could confuse the model. This method clears those non-conditioning
1654
+ memories surrounding the interacted frame to avoid giving the model both old and new information about the
1655
+ object.
1595
1656
 
1596
1657
  Args:
1597
1658
  frame_idx (int): The index of the current frame where user interaction occurred.
@@ -1603,3 +1664,346 @@ class SAM2VideoPredictor(SAM2Predictor):
1603
1664
  self.inference_state["output_dict"]["non_cond_frame_outputs"].pop(t, None)
1604
1665
  for obj_output_dict in self.inference_state["output_dict_per_obj"].values():
1605
1666
  obj_output_dict["non_cond_frame_outputs"].pop(t, None)
1667
+
1668
+
1669
+ class SAM2DynamicInteractivePredictor(SAM2Predictor):
1670
+ """SAM2DynamicInteractivePredictor extends SAM2Predictor to support dynamic interactions with video frames or a
1671
+ sequence of images.
1672
+
1673
+ Attributes:
1674
+ memory_bank (list): Stores the consolidated state of each prompted image.
1675
+ obj_idx_set (set): A set to keep track of the object indices that have been added.
1676
+ obj_id_to_idx (OrderedDict): Maps object IDs to their corresponding indices.
1677
+ obj_idx_to_id (OrderedDict): Maps object indices to their corresponding IDs.
1678
+
1679
+ Methods:
1680
+ get_model: Retrieve and configure the model with binarization enabled.
1681
+ inference: Perform inference on a single image with optional prompts and object IDs.
1682
+ postprocess: Post-process the predictions to apply non-overlapping constraints if required.
1683
+ update_memory: Append the image state to the memory_bank and update the model memory.
1684
+ track_step: Run a tracking step on the current image state to predict masks.
1685
+ get_maskmem_enc: Get memory and positional encoding from the memory bank.
1686
+
1687
+ Examples:
1688
+ >>> predictor = SAM2DynamicInteractivePredictor(cfg=DEFAULT_CFG)
1689
+ >>> predictor(source=support_img1, bboxes=bboxes1, obj_ids=labels1, update_memory=True)
1690
+ >>> results1 = predictor(source=query_img1)
1691
+ >>> predictor(source=support_img2, bboxes=bboxes2, obj_ids=labels2, update_memory=True)
1692
+ >>> results2 = predictor(source=query_img2)
1693
+ """
1694
+
1695
+ def __init__(
1696
+ self,
1697
+ cfg: Any = DEFAULT_CFG,
1698
+ overrides: dict[str, Any] | None = None,
1699
+ max_obj_num: int = 3,
1700
+ _callbacks: dict[str, Any] | None = None,
1701
+ ) -> None:
1702
+ """Initialize the predictor with configuration and optional overrides.
1703
+
1704
+ This constructor initializes the SAM2DynamicInteractivePredictor with a given configuration, applies any
1705
+ specified overrides, and initializes the memory bank and object ID mappings.
1706
+
1707
+ Args:
1708
+ cfg (dict[str, Any]): Configuration dictionary containing default settings.
1709
+ overrides (dict[str, Any] | None): Dictionary of values to override default configuration.
1710
+ max_obj_num (int): Maximum number of objects to track (default 3). This is kept fixed so that the feature
1711
+ size for the model stays constant.
1712
+ _callbacks (dict[str, Any] | None): Dictionary of callback functions to customize behavior.
1713
+
1714
+ Examples:
1715
+ >>> predictor = SAM2DynamicInteractivePredictor(cfg=DEFAULT_CFG)
1716
+ >>> predictor_example_with_imgsz = SAM2DynamicInteractivePredictor(overrides={"imgsz": 640})
1717
+ >>> predictor_example_with_callback = SAM2DynamicInteractivePredictor(
1718
+ ... _callbacks={"on_predict_start": custom_callback}
1719
+ ... )
1720
+ """
1721
+ super().__init__(cfg, overrides, _callbacks)
1722
+ self.non_overlap_masks = True
1723
+
1724
+ # Initialize the memory bank to store image states
1725
+ # NOTE: probably need to use a dict here for faster lookups
1726
+ self.memory_bank = []
1727
+
1728
+ # Initialize the object index set and mappings
1729
+ self.obj_idx_set = set()
1730
+ self.obj_id_to_idx = OrderedDict()
1731
+ self.obj_idx_to_id = OrderedDict()
1732
+ self._max_obj_num = max_obj_num
1733
+ for i in range(self._max_obj_num):
1734
+ self.obj_id_to_idx[i + 1] = i
1735
+ self.obj_idx_to_id[i] = i + 1
1736
+
1737
+ @smart_inference_mode()
1738
+ def inference(
1739
+ self,
1740
+ im: torch.Tensor | np.ndarray,
1741
+ bboxes: list[list[float]] | None = None,
1742
+ masks: torch.Tensor | np.ndarray | None = None,
1743
+ points: list[list[float]] | None = None,
1744
+ labels: list[int] | None = None,
1745
+ obj_ids: list[int] | None = None,
1746
+ update_memory: bool = False,
1747
+ ) -> tuple[torch.Tensor, torch.Tensor]:
1748
+ """Perform inference on a single image with optional bounding boxes, masks, points and object IDs. It has two
1749
+ modes: one is to run inference on a single image without updating the memory, and the other is to update
1750
+ the memory with the provided prompts and object IDs. When update_memory is True, it will update the
1751
+ memory with the provided prompts and obj_ids. When update_memory is False, it will only run inference on
1752
+ the provided image without updating the memory.
1753
+
1754
+ Args:
1755
+ im (torch.Tensor | np.ndarray): The input image tensor or numpy array.
1756
+ bboxes (list[list[float]] | None): Optional list of bounding boxes to update the memory.
1757
+ masks (list[torch.Tensor | np.ndarray] | None): Optional masks to update the memory.
1758
+ points (list[list[float]] | None): Optional list of points to update the memory, each point is [x, y].
1759
+ labels (list[int] | None): Optional list of point labels corresponding to the points (>0 for positive, 0 for
1760
+ negative).
1761
+ obj_ids (list[int] | None): Optional list of object IDs corresponding to the prompts.
1762
+ update_memory (bool): Flag to indicate whether to update the memory with new objects.
1763
+
1764
+ Returns:
1765
+ pred_masks (torch.Tensor): The output masks with shape (C, H, W).
1766
+ pred_scores (torch.Tensor): Quality scores for each mask.
1767
+ """
1768
+ self.get_im_features(im)
1769
+ points, labels, masks = self._prepare_prompts(
1770
+ dst_shape=self.imgsz,
1771
+ src_shape=self.batch[1][0].shape[:2],
1772
+ points=points,
1773
+ bboxes=bboxes,
1774
+ labels=labels,
1775
+ masks=masks,
1776
+ )
1777
+
1778
+ if update_memory:
1779
+ if isinstance(obj_ids, int):
1780
+ obj_ids = [obj_ids]
1781
+ assert obj_ids is not None, "obj_ids must be provided when update_memory is True"
1782
+ assert masks is not None or points is not None, (
1783
+ "bboxes, masks, or points must be provided when update_memory is True"
1784
+ )
1785
+ if points is None: # placeholder
1786
+ points = torch.zeros((len(obj_ids), 0, 2), dtype=self.torch_dtype, device=self.device)
1787
+ labels = torch.zeros((len(obj_ids), 0), dtype=torch.int32, device=self.device)
1788
+ if masks is not None:
1789
+ assert len(masks) == len(obj_ids), "masks and obj_ids must have the same length."
1790
+ assert len(points) == len(obj_ids), "points and obj_ids must have the same length."
1791
+ self.update_memory(obj_ids, points, labels, masks)
1792
+
1793
+ current_out = self.track_step()
1794
+ pred_masks, pred_scores = current_out["pred_masks"], current_out["object_score_logits"]
1795
+ # filter the masks and logits based on the object indices
1796
+ if len(self.obj_idx_set) == 0:
1797
+ raise RuntimeError("No objects have been added to the state. Please add objects before inference.")
1798
+ idx = list(self.obj_idx_set)  # model-side indices of objects that have prompts
1799
+ pred_masks, pred_scores = pred_masks[idx], pred_scores[idx]
1800
+ # The original scores are in [-32, 32], and an object score above 0 means the object is present. Map them to
1801
+ # the [-1, 1] range and clamp to non-negative values so they can be used as mask confidence scores.
1802
+ pred_scores = torch.clamp_(pred_scores / 32, min=0)
1803
+ return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1)
1804
+
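A minimal sketch of the two modes described in the docstring above (the import path mirrors the other SAM predictors and the file names are placeholders; treat both as assumptions rather than a guaranteed public API):

from ultralytics.models.sam import SAM2DynamicInteractivePredictor  # assumed export

predictor = SAM2DynamicInteractivePredictor(overrides=dict(imgsz=1024, model="sam2_b.pt"))

# Mode 1: teach the predictor an object from a support image (prompts + obj_ids, update_memory=True).
predictor(source="support.jpg", bboxes=[[100, 100, 200, 200]], obj_ids=[1], update_memory=True)

# Mode 2: segment new images against the accumulated memory, with no prompts.
results = predictor(source="query.jpg")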
1805
+ def get_im_features(self, img: torch.Tensor | np.ndarray) -> None:
1806
+ """Initialize the image state by processing the input image and extracting features.
1807
+
1808
+ Args:
1809
+ img (torch.Tensor | np.ndarray): The input image tensor or numpy array.
1810
+ """
1811
+ vis_feats, vis_pos_embed, feat_sizes = SAM2VideoPredictor.get_im_features(self, img, batch=self._max_obj_num)
1812
+ self.high_res_features = [
1813
+ feat.permute(1, 2, 0).view(*feat.shape[1:], *feat_size)
1814
+ for feat, feat_size in zip(vis_feats[:-1], feat_sizes[:-1])
1815
+ ]
1816
+
1817
+ self.vision_feats = vis_feats
1818
+ self.vision_pos_embeds = vis_pos_embed
1819
+ self.feat_sizes = feat_sizes
1820
+
1821
+ @smart_inference_mode()
1822
+ def update_memory(
1823
+ self,
1824
+ obj_ids: list[int] | None = None,
1825
+ points: torch.Tensor | None = None,
1826
+ labels: torch.Tensor | None = None,
1827
+ masks: torch.Tensor | None = None,
1828
+ ) -> None:
1829
+ """Append the imgState to the memory_bank and update the memory for the model.
1830
+
1831
+ Args:
1832
+ obj_ids (list[int]): List of object IDs corresponding to the prompts.
1833
+ points (torch.Tensor | None): Tensor of shape (B, N, 2) representing the input points for N objects.
1834
+ labels (torch.Tensor | None): Tensor of shape (B, N) representing the labels for the input points.
1835
+ masks (torch.Tensor | None): Optional tensor of shape (N, H, W) representing the input masks for N objects.
1836
+ """
1837
+ consolidated_out = {
1838
+ "maskmem_features": None,
1839
+ "maskmem_pos_enc": None,
1840
+ "pred_masks": torch.full(
1841
+ size=(self._max_obj_num, 1, self.imgsz[0] // 4, self.imgsz[1] // 4),
1842
+ fill_value=-1024.0,
1843
+ dtype=self.torch_dtype,
1844
+ device=self.device,
1845
+ ),
1846
+ "obj_ptr": torch.full(
1847
+ size=(self._max_obj_num, self.model.hidden_dim),
1848
+ fill_value=-1024.0,
1849
+ dtype=self.torch_dtype,
1850
+ device=self.device,
1851
+ ),
1852
+ "object_score_logits": torch.full(
1853
+ size=(self._max_obj_num, 1),
1854
+ # default object_score_logits to -32 (sigmoid(-32) ~= 0), i.e. objects are assumed absent until a prompt
1855
+ # fills their slot; the original `predict_masks` of `MaskDecoder` used 10.0 (object present) instead
1856
+ fill_value=-32,
1857
+ dtype=self.torch_dtype,
1858
+ device=self.device,
1859
+ ),
1860
+ }
1861
+
1862
+ for i, obj_id in enumerate(obj_ids):
1863
+ assert obj_id < self._max_obj_num
1864
+ obj_idx = self._obj_id_to_idx(int(obj_id))
1865
+ self.obj_idx_set.add(obj_idx)
1866
+ point, label = points[[i]], labels[[i]]
1867
+ mask = masks[[i]][None] if masks is not None else None
1868
+ # Currently only point/box or mask prompts are supported, so assert that at least one of them is provided.
1869
+ assert point is not None or mask is not None, "Either bbox, points or mask is required"
1870
+ out = self.track_step(obj_idx, point, label, mask)
1871
+ if out is not None:
1872
+ obj_mask = out["pred_masks"]
1873
+ assert obj_mask.shape[-2:] == consolidated_out["pred_masks"].shape[-2:], (
1874
+ f"Expected mask shape {consolidated_out['pred_masks'].shape[-2:]} but got {obj_mask.shape[-2:]} for object {obj_idx}."
1875
+ )
1876
+ consolidated_out["pred_masks"][obj_idx : obj_idx + 1] = obj_mask
1877
+ consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = out["obj_ptr"]
1878
+
1879
+ if "object_score_logits" in out.keys():
1880
+ consolidated_out["object_score_logits"][obj_idx : obj_idx + 1] = out["object_score_logits"]
1881
+
1882
+ high_res_masks = F.interpolate(
1883
+ consolidated_out["pred_masks"].to(self.device, non_blocking=self.device.type == "cuda"),
1884
+ size=self.imgsz,
1885
+ mode="bilinear",
1886
+ align_corners=False,
1887
+ )
1888
+
1889
+ if self.model.non_overlap_masks_for_mem_enc:
1890
+ high_res_masks = self.model._apply_non_overlapping_constraints(high_res_masks)
1891
+ maskmem_features, maskmem_pos_enc = self.model._encode_new_memory(
1892
+ current_vision_feats=self.vision_feats,
1893
+ feat_sizes=self.feat_sizes,
1894
+ pred_masks_high_res=high_res_masks,
1895
+ object_score_logits=consolidated_out["object_score_logits"],
1896
+ is_mask_from_pts=True,
1897
+ )
1898
+ consolidated_out["maskmem_features"] = maskmem_features
1899
+ consolidated_out["maskmem_pos_enc"] = maskmem_pos_enc
1900
+ self.memory_bank.append(consolidated_out)
1901
+
1902
+ def _prepare_memory_conditioned_features(self, obj_idx: int | None) -> torch.Tensor:
1903
+ """Prepare the memory-conditioned features for the current image state. If obj_idx is provided, it supposes to
1904
+ prepare features for a specific prompted object in the image. If obj_idx is None, it prepares features
1905
+ for all objects in the image. If there is no memory, it will directly add a no-memory embedding to the
1906
+ current vision features. If there is memory, it will use the memory features from previous frames to
1907
+ condition the current vision features using a transformer attention mechanism.
1908
+
1909
+ Args:
1910
+ obj_idx (int | None): The index of the object for which to prepare the features.
1911
+
1912
+ Returns:
1913
+ pix_feat_with_mem (torch.Tensor): The memory-conditioned pixel features.
1914
+ """
1915
+ if len(self.memory_bank) == 0 or isinstance(obj_idx, int):
1916
+ # for initial conditioning frames or a newly prompted object, encode without using any previous memory
1917
+ # directly add no-mem embedding (instead of using the transformer encoder)
1918
+ pix_feat_with_mem = self.vision_feats[-1] + self.model.no_mem_embed
1919
+ else:
1920
+ # for inference frames, use the memory features from previous frames
1921
+ memory, memory_pos_embed = self.get_maskmem_enc()
1922
+ pix_feat_with_mem = self.model.memory_attention(
1923
+ curr=self.vision_feats[-1:],
1924
+ curr_pos=self.vision_pos_embeds[-1:],
1925
+ memory=memory,
1926
+ memory_pos=memory_pos_embed,
1927
+ num_obj_ptr_tokens=0, # num_obj_ptr_tokens
1928
+ )
1929
+ # reshape the output (HW)BC => BCHW
1930
+ return pix_feat_with_mem.permute(1, 2, 0).view(
1931
+ self._max_obj_num,
1932
+ self.model.memory_attention.d_model,
1933
+ *self.feat_sizes[-1],
1934
+ )
1935
+
1936
+ def get_maskmem_enc(self) -> tuple[torch.Tensor, torch.Tensor]:
1937
+ """Get memory and positional encoding from memory, which is used to condition the current image features."""
1938
+ to_cat_memory, to_cat_memory_pos_embed = [], []
1939
+ for consolidated_out in self.memory_bank:
1940
+ to_cat_memory.append(consolidated_out["maskmem_features"].flatten(2).permute(2, 0, 1)) # (H*W, B, C)
1941
+ maskmem_enc = consolidated_out["maskmem_pos_enc"][-1].flatten(2).permute(2, 0, 1)
1942
+ maskmem_enc = maskmem_enc + self.model.maskmem_tpos_enc[self.model.num_maskmem - 1]
1943
+ to_cat_memory_pos_embed.append(maskmem_enc)
1944
+
1945
+ memory = torch.cat(to_cat_memory, dim=0)
1946
+ memory_pos_embed = torch.cat(to_cat_memory_pos_embed, dim=0)
1947
+ return memory, memory_pos_embed
1948
+
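Shape bookkeeping for the concatenation above, with made-up sizes: each stored maskmem_features tensor is (B, C, H, W); flatten(2).permute(2, 0, 1) turns it into (H*W, B, C) tokens, and entries from successive prompted images are concatenated along the token axis.

import torch

feat = torch.randn(3, 64, 16, 16)              # (B, C, H, W) for one memory entry
tokens = feat.flatten(2).permute(2, 0, 1)      # (H*W, B, C) == (256, 3, 64)
memory = torch.cat([tokens, tokens], dim=0)    # two memory entries -> (512, 3, 64)
assert memory.shape == (512, 3, 64)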
1949
+ def _obj_id_to_idx(self, obj_id: int) -> int | None:
1950
+ """Map client-side object id to model-side object index.
1951
+
1952
+ Args:
1953
+ obj_id (int): The client-side object ID.
1954
+
1955
+ Returns:
1956
+ (int): The model-side object index, or None if not found.
1957
+ """
1958
+ return self.obj_id_to_idx.get(obj_id, None)
1959
+
1960
+ def track_step(
1961
+ self,
1962
+ obj_idx: int | None = None,
1963
+ point: torch.Tensor | None = None,
1964
+ label: torch.Tensor | None = None,
1965
+ mask: torch.Tensor | None = None,
1966
+ ) -> dict[str, Any]:
1967
+ """Tracking step for the current image state to predict masks.
1968
+
1969
+ This method processes the image features and runs the SAM heads to predict masks. If obj_idx is provided, it
1970
+ processes the features for a specific prompted object in the image. If obj_idx is None, it processes the
1971
+ features for all objects in the image. The method supports both mask-based output without SAM and full SAM
1972
+ processing with memory-conditioned features.
1973
+
1974
+ Args:
1975
+ obj_idx (int | None): The index of the object for which to predict masks. If None, it processes all objects.
1976
+ point (torch.Tensor | None): The coordinates of the points of interest with shape (N, 2).
1977
+ label (torch.Tensor | None): The labels corresponding to the points where 1 means positive clicks, 0 means
1978
+ negative clicks.
1979
+ mask (torch.Tensor | None): The mask input for the object with shape (H, W).
1980
+
1981
+ Returns:
1982
+ current_out (dict[str, Any]): A dictionary containing the current output with mask predictions and object
1983
+ pointers. Keys include 'point_inputs', 'mask_inputs', 'pred_masks', 'pred_masks_high_res',
1984
+ 'obj_ptr', 'object_score_logits'.
1985
+ """
1986
+ if mask is not None and self.model.use_mask_input_as_output_without_sam:
1987
+ # When use_mask_input_as_output_without_sam=True, we directly output the mask input
1988
+ # (see it as a GT mask) without using a SAM prompt encoder + mask decoder.
1989
+ pix_feat = self.vision_feats[-1].permute(1, 2, 0)
1990
+ pix_feat = pix_feat.view(-1, self.model.memory_attention.d_model, *self.feat_sizes[-1])
1991
+ _, _, _, low_res_masks, high_res_masks, obj_ptr, object_score_logits = self.model._use_mask_as_output(mask)
1992
+ else:
1993
+ # fuse the visual features with previous memory features from the memory bank
1994
+ pix_feat_with_mem = self._prepare_memory_conditioned_features(obj_idx)
1995
+ # keep only the first feature slot when obj_idx is provided (i.e. when adding prompts)
1996
+ pix_feat_with_mem = pix_feat_with_mem[:1] if obj_idx is not None else pix_feat_with_mem
1997
+ _, _, _, low_res_masks, high_res_masks, obj_ptr, object_score_logits = self.model._forward_sam_heads(
1998
+ backbone_features=pix_feat_with_mem,
1999
+ point_inputs={"point_coords": point, "point_labels": label} if obj_idx is not None else None,
2000
+ mask_inputs=mask,
2001
+ multimask_output=False,
2002
+ high_res_features=[feat[: pix_feat_with_mem.shape[0]] for feat in self.high_res_features],
2003
+ )
2004
+ return {
2005
+ "pred_masks": low_res_masks,
2006
+ "pred_masks_high_res": high_res_masks,
2007
+ "obj_ptr": obj_ptr,
2008
+ "object_score_logits": object_score_logits,
2009
+ }