ultralytics-opencv-headless 8.3.246__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (298) hide show
  1. tests/__init__.py +23 -0
  2. tests/conftest.py +59 -0
  3. tests/test_cli.py +131 -0
  4. tests/test_cuda.py +216 -0
  5. tests/test_engine.py +157 -0
  6. tests/test_exports.py +309 -0
  7. tests/test_integrations.py +151 -0
  8. tests/test_python.py +777 -0
  9. tests/test_solutions.py +371 -0
  10. ultralytics/__init__.py +48 -0
  11. ultralytics/assets/bus.jpg +0 -0
  12. ultralytics/assets/zidane.jpg +0 -0
  13. ultralytics/cfg/__init__.py +1026 -0
  14. ultralytics/cfg/datasets/Argoverse.yaml +78 -0
  15. ultralytics/cfg/datasets/DOTAv1.5.yaml +37 -0
  16. ultralytics/cfg/datasets/DOTAv1.yaml +36 -0
  17. ultralytics/cfg/datasets/GlobalWheat2020.yaml +68 -0
  18. ultralytics/cfg/datasets/HomeObjects-3K.yaml +32 -0
  19. ultralytics/cfg/datasets/ImageNet.yaml +2025 -0
  20. ultralytics/cfg/datasets/Objects365.yaml +447 -0
  21. ultralytics/cfg/datasets/SKU-110K.yaml +58 -0
  22. ultralytics/cfg/datasets/VOC.yaml +102 -0
  23. ultralytics/cfg/datasets/VisDrone.yaml +87 -0
  24. ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
  25. ultralytics/cfg/datasets/brain-tumor.yaml +22 -0
  26. ultralytics/cfg/datasets/carparts-seg.yaml +44 -0
  27. ultralytics/cfg/datasets/coco-pose.yaml +64 -0
  28. ultralytics/cfg/datasets/coco.yaml +118 -0
  29. ultralytics/cfg/datasets/coco128-seg.yaml +101 -0
  30. ultralytics/cfg/datasets/coco128.yaml +101 -0
  31. ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
  32. ultralytics/cfg/datasets/coco8-multispectral.yaml +104 -0
  33. ultralytics/cfg/datasets/coco8-pose.yaml +47 -0
  34. ultralytics/cfg/datasets/coco8-seg.yaml +101 -0
  35. ultralytics/cfg/datasets/coco8.yaml +101 -0
  36. ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
  37. ultralytics/cfg/datasets/crack-seg.yaml +22 -0
  38. ultralytics/cfg/datasets/dog-pose.yaml +52 -0
  39. ultralytics/cfg/datasets/dota8-multispectral.yaml +38 -0
  40. ultralytics/cfg/datasets/dota8.yaml +35 -0
  41. ultralytics/cfg/datasets/hand-keypoints.yaml +50 -0
  42. ultralytics/cfg/datasets/kitti.yaml +27 -0
  43. ultralytics/cfg/datasets/lvis.yaml +1240 -0
  44. ultralytics/cfg/datasets/medical-pills.yaml +21 -0
  45. ultralytics/cfg/datasets/open-images-v7.yaml +663 -0
  46. ultralytics/cfg/datasets/package-seg.yaml +22 -0
  47. ultralytics/cfg/datasets/signature.yaml +21 -0
  48. ultralytics/cfg/datasets/tiger-pose.yaml +41 -0
  49. ultralytics/cfg/datasets/xView.yaml +155 -0
  50. ultralytics/cfg/default.yaml +130 -0
  51. ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +17 -0
  52. ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
  53. ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
  54. ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
  55. ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
  56. ultralytics/cfg/models/11/yolo11.yaml +50 -0
  57. ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
  58. ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
  59. ultralytics/cfg/models/12/yolo12-cls.yaml +32 -0
  60. ultralytics/cfg/models/12/yolo12-obb.yaml +48 -0
  61. ultralytics/cfg/models/12/yolo12-pose.yaml +49 -0
  62. ultralytics/cfg/models/12/yolo12-seg.yaml +48 -0
  63. ultralytics/cfg/models/12/yolo12.yaml +48 -0
  64. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +53 -0
  65. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +45 -0
  66. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +45 -0
  67. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +57 -0
  68. ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
  69. ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
  70. ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
  71. ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
  72. ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
  73. ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
  74. ultralytics/cfg/models/v3/yolov3-spp.yaml +49 -0
  75. ultralytics/cfg/models/v3/yolov3-tiny.yaml +40 -0
  76. ultralytics/cfg/models/v3/yolov3.yaml +49 -0
  77. ultralytics/cfg/models/v5/yolov5-p6.yaml +62 -0
  78. ultralytics/cfg/models/v5/yolov5.yaml +51 -0
  79. ultralytics/cfg/models/v6/yolov6.yaml +56 -0
  80. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +48 -0
  81. ultralytics/cfg/models/v8/yoloe-v8.yaml +48 -0
  82. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +28 -0
  83. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +28 -0
  84. ultralytics/cfg/models/v8/yolov8-cls.yaml +32 -0
  85. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +58 -0
  86. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +60 -0
  87. ultralytics/cfg/models/v8/yolov8-ghost.yaml +50 -0
  88. ultralytics/cfg/models/v8/yolov8-obb.yaml +49 -0
  89. ultralytics/cfg/models/v8/yolov8-p2.yaml +57 -0
  90. ultralytics/cfg/models/v8/yolov8-p6.yaml +59 -0
  91. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +60 -0
  92. ultralytics/cfg/models/v8/yolov8-pose.yaml +50 -0
  93. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +49 -0
  94. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +59 -0
  95. ultralytics/cfg/models/v8/yolov8-seg.yaml +49 -0
  96. ultralytics/cfg/models/v8/yolov8-world.yaml +51 -0
  97. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +49 -0
  98. ultralytics/cfg/models/v8/yolov8.yaml +49 -0
  99. ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
  100. ultralytics/cfg/models/v9/yolov9c.yaml +41 -0
  101. ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
  102. ultralytics/cfg/models/v9/yolov9e.yaml +64 -0
  103. ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
  104. ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
  105. ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
  106. ultralytics/cfg/trackers/botsort.yaml +21 -0
  107. ultralytics/cfg/trackers/bytetrack.yaml +12 -0
  108. ultralytics/data/__init__.py +26 -0
  109. ultralytics/data/annotator.py +66 -0
  110. ultralytics/data/augment.py +2801 -0
  111. ultralytics/data/base.py +435 -0
  112. ultralytics/data/build.py +437 -0
  113. ultralytics/data/converter.py +855 -0
  114. ultralytics/data/dataset.py +834 -0
  115. ultralytics/data/loaders.py +704 -0
  116. ultralytics/data/scripts/download_weights.sh +18 -0
  117. ultralytics/data/scripts/get_coco.sh +61 -0
  118. ultralytics/data/scripts/get_coco128.sh +18 -0
  119. ultralytics/data/scripts/get_imagenet.sh +52 -0
  120. ultralytics/data/split.py +138 -0
  121. ultralytics/data/split_dota.py +344 -0
  122. ultralytics/data/utils.py +798 -0
  123. ultralytics/engine/__init__.py +1 -0
  124. ultralytics/engine/exporter.py +1578 -0
  125. ultralytics/engine/model.py +1124 -0
  126. ultralytics/engine/predictor.py +508 -0
  127. ultralytics/engine/results.py +1522 -0
  128. ultralytics/engine/trainer.py +974 -0
  129. ultralytics/engine/tuner.py +448 -0
  130. ultralytics/engine/validator.py +384 -0
  131. ultralytics/hub/__init__.py +166 -0
  132. ultralytics/hub/auth.py +151 -0
  133. ultralytics/hub/google/__init__.py +174 -0
  134. ultralytics/hub/session.py +422 -0
  135. ultralytics/hub/utils.py +162 -0
  136. ultralytics/models/__init__.py +9 -0
  137. ultralytics/models/fastsam/__init__.py +7 -0
  138. ultralytics/models/fastsam/model.py +79 -0
  139. ultralytics/models/fastsam/predict.py +169 -0
  140. ultralytics/models/fastsam/utils.py +23 -0
  141. ultralytics/models/fastsam/val.py +38 -0
  142. ultralytics/models/nas/__init__.py +7 -0
  143. ultralytics/models/nas/model.py +98 -0
  144. ultralytics/models/nas/predict.py +56 -0
  145. ultralytics/models/nas/val.py +38 -0
  146. ultralytics/models/rtdetr/__init__.py +7 -0
  147. ultralytics/models/rtdetr/model.py +63 -0
  148. ultralytics/models/rtdetr/predict.py +88 -0
  149. ultralytics/models/rtdetr/train.py +89 -0
  150. ultralytics/models/rtdetr/val.py +216 -0
  151. ultralytics/models/sam/__init__.py +25 -0
  152. ultralytics/models/sam/amg.py +275 -0
  153. ultralytics/models/sam/build.py +365 -0
  154. ultralytics/models/sam/build_sam3.py +377 -0
  155. ultralytics/models/sam/model.py +169 -0
  156. ultralytics/models/sam/modules/__init__.py +1 -0
  157. ultralytics/models/sam/modules/blocks.py +1067 -0
  158. ultralytics/models/sam/modules/decoders.py +495 -0
  159. ultralytics/models/sam/modules/encoders.py +794 -0
  160. ultralytics/models/sam/modules/memory_attention.py +298 -0
  161. ultralytics/models/sam/modules/sam.py +1160 -0
  162. ultralytics/models/sam/modules/tiny_encoder.py +979 -0
  163. ultralytics/models/sam/modules/transformer.py +344 -0
  164. ultralytics/models/sam/modules/utils.py +512 -0
  165. ultralytics/models/sam/predict.py +3940 -0
  166. ultralytics/models/sam/sam3/__init__.py +3 -0
  167. ultralytics/models/sam/sam3/decoder.py +546 -0
  168. ultralytics/models/sam/sam3/encoder.py +529 -0
  169. ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
  170. ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
  171. ultralytics/models/sam/sam3/model_misc.py +199 -0
  172. ultralytics/models/sam/sam3/necks.py +129 -0
  173. ultralytics/models/sam/sam3/sam3_image.py +339 -0
  174. ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
  175. ultralytics/models/sam/sam3/vitdet.py +547 -0
  176. ultralytics/models/sam/sam3/vl_combiner.py +160 -0
  177. ultralytics/models/utils/__init__.py +1 -0
  178. ultralytics/models/utils/loss.py +466 -0
  179. ultralytics/models/utils/ops.py +315 -0
  180. ultralytics/models/yolo/__init__.py +7 -0
  181. ultralytics/models/yolo/classify/__init__.py +7 -0
  182. ultralytics/models/yolo/classify/predict.py +90 -0
  183. ultralytics/models/yolo/classify/train.py +202 -0
  184. ultralytics/models/yolo/classify/val.py +216 -0
  185. ultralytics/models/yolo/detect/__init__.py +7 -0
  186. ultralytics/models/yolo/detect/predict.py +122 -0
  187. ultralytics/models/yolo/detect/train.py +227 -0
  188. ultralytics/models/yolo/detect/val.py +507 -0
  189. ultralytics/models/yolo/model.py +430 -0
  190. ultralytics/models/yolo/obb/__init__.py +7 -0
  191. ultralytics/models/yolo/obb/predict.py +56 -0
  192. ultralytics/models/yolo/obb/train.py +79 -0
  193. ultralytics/models/yolo/obb/val.py +302 -0
  194. ultralytics/models/yolo/pose/__init__.py +7 -0
  195. ultralytics/models/yolo/pose/predict.py +65 -0
  196. ultralytics/models/yolo/pose/train.py +110 -0
  197. ultralytics/models/yolo/pose/val.py +248 -0
  198. ultralytics/models/yolo/segment/__init__.py +7 -0
  199. ultralytics/models/yolo/segment/predict.py +109 -0
  200. ultralytics/models/yolo/segment/train.py +69 -0
  201. ultralytics/models/yolo/segment/val.py +307 -0
  202. ultralytics/models/yolo/world/__init__.py +5 -0
  203. ultralytics/models/yolo/world/train.py +173 -0
  204. ultralytics/models/yolo/world/train_world.py +178 -0
  205. ultralytics/models/yolo/yoloe/__init__.py +22 -0
  206. ultralytics/models/yolo/yoloe/predict.py +162 -0
  207. ultralytics/models/yolo/yoloe/train.py +287 -0
  208. ultralytics/models/yolo/yoloe/train_seg.py +122 -0
  209. ultralytics/models/yolo/yoloe/val.py +206 -0
  210. ultralytics/nn/__init__.py +27 -0
  211. ultralytics/nn/autobackend.py +958 -0
  212. ultralytics/nn/modules/__init__.py +182 -0
  213. ultralytics/nn/modules/activation.py +54 -0
  214. ultralytics/nn/modules/block.py +1947 -0
  215. ultralytics/nn/modules/conv.py +669 -0
  216. ultralytics/nn/modules/head.py +1183 -0
  217. ultralytics/nn/modules/transformer.py +793 -0
  218. ultralytics/nn/modules/utils.py +159 -0
  219. ultralytics/nn/tasks.py +1768 -0
  220. ultralytics/nn/text_model.py +356 -0
  221. ultralytics/py.typed +1 -0
  222. ultralytics/solutions/__init__.py +41 -0
  223. ultralytics/solutions/ai_gym.py +108 -0
  224. ultralytics/solutions/analytics.py +264 -0
  225. ultralytics/solutions/config.py +107 -0
  226. ultralytics/solutions/distance_calculation.py +123 -0
  227. ultralytics/solutions/heatmap.py +125 -0
  228. ultralytics/solutions/instance_segmentation.py +86 -0
  229. ultralytics/solutions/object_blurrer.py +89 -0
  230. ultralytics/solutions/object_counter.py +190 -0
  231. ultralytics/solutions/object_cropper.py +87 -0
  232. ultralytics/solutions/parking_management.py +280 -0
  233. ultralytics/solutions/queue_management.py +93 -0
  234. ultralytics/solutions/region_counter.py +133 -0
  235. ultralytics/solutions/security_alarm.py +151 -0
  236. ultralytics/solutions/similarity_search.py +219 -0
  237. ultralytics/solutions/solutions.py +828 -0
  238. ultralytics/solutions/speed_estimation.py +114 -0
  239. ultralytics/solutions/streamlit_inference.py +260 -0
  240. ultralytics/solutions/templates/similarity-search.html +156 -0
  241. ultralytics/solutions/trackzone.py +88 -0
  242. ultralytics/solutions/vision_eye.py +67 -0
  243. ultralytics/trackers/__init__.py +7 -0
  244. ultralytics/trackers/basetrack.py +115 -0
  245. ultralytics/trackers/bot_sort.py +257 -0
  246. ultralytics/trackers/byte_tracker.py +469 -0
  247. ultralytics/trackers/track.py +116 -0
  248. ultralytics/trackers/utils/__init__.py +1 -0
  249. ultralytics/trackers/utils/gmc.py +339 -0
  250. ultralytics/trackers/utils/kalman_filter.py +482 -0
  251. ultralytics/trackers/utils/matching.py +154 -0
  252. ultralytics/utils/__init__.py +1450 -0
  253. ultralytics/utils/autobatch.py +118 -0
  254. ultralytics/utils/autodevice.py +205 -0
  255. ultralytics/utils/benchmarks.py +728 -0
  256. ultralytics/utils/callbacks/__init__.py +5 -0
  257. ultralytics/utils/callbacks/base.py +233 -0
  258. ultralytics/utils/callbacks/clearml.py +146 -0
  259. ultralytics/utils/callbacks/comet.py +625 -0
  260. ultralytics/utils/callbacks/dvc.py +197 -0
  261. ultralytics/utils/callbacks/hub.py +110 -0
  262. ultralytics/utils/callbacks/mlflow.py +134 -0
  263. ultralytics/utils/callbacks/neptune.py +126 -0
  264. ultralytics/utils/callbacks/platform.py +313 -0
  265. ultralytics/utils/callbacks/raytune.py +42 -0
  266. ultralytics/utils/callbacks/tensorboard.py +123 -0
  267. ultralytics/utils/callbacks/wb.py +188 -0
  268. ultralytics/utils/checks.py +1006 -0
  269. ultralytics/utils/cpu.py +85 -0
  270. ultralytics/utils/dist.py +123 -0
  271. ultralytics/utils/downloads.py +529 -0
  272. ultralytics/utils/errors.py +35 -0
  273. ultralytics/utils/events.py +113 -0
  274. ultralytics/utils/export/__init__.py +7 -0
  275. ultralytics/utils/export/engine.py +237 -0
  276. ultralytics/utils/export/imx.py +315 -0
  277. ultralytics/utils/export/tensorflow.py +231 -0
  278. ultralytics/utils/files.py +219 -0
  279. ultralytics/utils/git.py +137 -0
  280. ultralytics/utils/instance.py +484 -0
  281. ultralytics/utils/logger.py +501 -0
  282. ultralytics/utils/loss.py +849 -0
  283. ultralytics/utils/metrics.py +1563 -0
  284. ultralytics/utils/nms.py +337 -0
  285. ultralytics/utils/ops.py +664 -0
  286. ultralytics/utils/patches.py +201 -0
  287. ultralytics/utils/plotting.py +1045 -0
  288. ultralytics/utils/tal.py +403 -0
  289. ultralytics/utils/torch_utils.py +984 -0
  290. ultralytics/utils/tqdm.py +440 -0
  291. ultralytics/utils/triton.py +112 -0
  292. ultralytics/utils/tuner.py +160 -0
  293. ultralytics_opencv_headless-8.3.246.dist-info/METADATA +374 -0
  294. ultralytics_opencv_headless-8.3.246.dist-info/RECORD +298 -0
  295. ultralytics_opencv_headless-8.3.246.dist-info/WHEEL +5 -0
  296. ultralytics_opencv_headless-8.3.246.dist-info/entry_points.txt +3 -0
  297. ultralytics_opencv_headless-8.3.246.dist-info/licenses/LICENSE +661 -0
  298. ultralytics_opencv_headless-8.3.246.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1160 @@
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
+
3
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
4
+ # All rights reserved.
5
+
6
+ from __future__ import annotations
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from torch import nn
11
+ from torch.nn.init import trunc_normal_
12
+
13
+ from ultralytics.nn.modules import MLP
14
+ from ultralytics.utils import LOGGER
15
+
16
+ from .blocks import SAM2TwoWayTransformer, TwoWayTransformer
17
+ from .decoders import MaskDecoder, SAM2MaskDecoder
18
+ from .encoders import ImageEncoderViT, PromptEncoder
19
+ from .utils import get_1d_sine_pe, select_closest_cond_frames
20
+
21
+ # a large negative value as a placeholder score for missing objects
22
+ NO_OBJ_SCORE = -1024.0
23
+
24
+
25
+ class SAMModel(nn.Module):
26
+ """Segment Anything Model (SAM) for object segmentation tasks.
27
+
28
+ This class combines image encoders, prompt encoders, and mask decoders to predict object masks from images and input
29
+ prompts.
30
+
31
+ Attributes:
32
+ mask_threshold (float): Threshold value for mask prediction.
33
+ image_encoder (ImageEncoderViT): Backbone for encoding images into embeddings.
34
+ prompt_encoder (PromptEncoder): Encoder for various types of input prompts.
35
+ mask_decoder (MaskDecoder): Predicts object masks from image and prompt embeddings.
36
+ pixel_mean (torch.Tensor): Mean values for normalizing pixels in the input image.
37
+ pixel_std (torch.Tensor): Standard deviation values for normalizing pixels in the input image.
38
+
39
+ Methods:
40
+ set_imgsz: Set image size to make model compatible with different image sizes.
41
+
42
+ Examples:
43
+ >>> image_encoder = ImageEncoderViT(...)
44
+ >>> prompt_encoder = PromptEncoder(...)
45
+ >>> mask_decoder = MaskDecoder(...)
46
+ >>> sam_model = SAMModel(image_encoder, prompt_encoder, mask_decoder)
47
+ >>> # Further usage depends on SAMPredictor class
48
+
49
+ Notes:
50
+ All forward() operations are implemented in the SAMPredictor class.
51
+ """
52
+
53
+ mask_threshold: float = 0.0
54
+
55
+ def __init__(
56
+ self,
57
+ image_encoder: ImageEncoderViT,
58
+ prompt_encoder: PromptEncoder,
59
+ mask_decoder: MaskDecoder,
60
+ pixel_mean: list[float] = (123.675, 116.28, 103.53),
61
+ pixel_std: list[float] = (58.395, 57.12, 57.375),
62
+ ) -> None:
63
+ """Initialize the SAMModel class to predict object masks from an image and input prompts.
64
+
65
+ Args:
66
+ image_encoder (ImageEncoderViT): The backbone used to encode the image into image embeddings.
67
+ prompt_encoder (PromptEncoder): Encodes various types of input prompts.
68
+ mask_decoder (MaskDecoder): Predicts masks from the image embeddings and encoded prompts.
69
+ pixel_mean (list[float]): Mean values for normalizing pixels in the input image.
70
+ pixel_std (list[float]): Standard deviation values for normalizing pixels in the input image.
71
+
72
+ Notes:
73
+ All forward() operations moved to SAMPredictor.
74
+ """
75
+ super().__init__()
76
+ self.image_encoder = image_encoder
77
+ self.prompt_encoder = prompt_encoder
78
+ self.mask_decoder = mask_decoder
79
+ self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
80
+ self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
81
+
82
+ def set_imgsz(self, imgsz):
83
+ """Set image size to make model compatible with different image sizes."""
84
+ if hasattr(self.image_encoder, "set_imgsz"):
85
+ self.image_encoder.set_imgsz(imgsz)
86
+ self.prompt_encoder.input_image_size = imgsz
87
+ self.prompt_encoder.image_embedding_size = [x // 16 for x in imgsz] # 16 is fixed as patch size of ViT model
88
+ self.image_encoder.img_size = imgsz[0]
89
+
90
+
91
+ class SAM2Model(torch.nn.Module):
92
+ """SAM2Model class for Segment Anything Model 2 with memory-based video object segmentation capabilities.
93
+
94
+ This class extends the functionality of SAM to handle video sequences, incorporating memory mechanisms for temporal
95
+ consistency and efficient tracking of objects across frames.
96
+
97
+ Attributes:
98
+ mask_threshold (float): Threshold value for mask prediction.
99
+ image_encoder (ImageEncoderViT): Visual encoder for extracting image features.
100
+ memory_attention (nn.Module): Module for attending to memory features.
101
+ memory_encoder (nn.Module): Encoder for generating memory representations.
102
+ num_maskmem (int): Number of accessible memory frames.
103
+ image_size (int): Size of input images.
104
+ backbone_stride (int): Stride of the backbone network output.
105
+ sam_prompt_embed_dim (int): Dimension of SAM prompt embeddings.
106
+ sam_image_embedding_size (int): Size of SAM image embeddings.
107
+ sam_prompt_encoder (PromptEncoder): Encoder for processing input prompts.
108
+ sam_mask_decoder (SAM2MaskDecoder): Decoder for generating object masks.
109
+ obj_ptr_proj (nn.Module): Projection layer for object pointers.
110
+ obj_ptr_tpos_proj (nn.Module): Projection for temporal positional encoding in object pointers.
111
+ hidden_dim (int): Hidden dimension of the model.
112
+ mem_dim (int): Memory dimension for encoding features.
113
+ use_high_res_features_in_sam (bool): Whether to use high-resolution feature maps in the SAM mask decoder.
114
+ use_obj_ptrs_in_encoder (bool): Whether to cross-attend to object pointers from other frames in the encoder.
115
+ max_obj_ptrs_in_encoder (int): Maximum number of object pointers from other frames in encoder cross-attention.
116
+ add_tpos_enc_to_obj_ptrs (bool): Whether to add temporal positional encoding to object pointers.
117
+ proj_tpos_enc_in_obj_ptrs (bool): Whether to add an extra linear projection layer for temporal positional
118
+ encoding in object pointers.
119
+ use_signed_tpos_enc_to_obj_ptrs (bool): Whether to use signed distance in temporal positional encoding.
120
+ only_obj_ptrs_in_the_past_for_eval (bool): Whether to only attend to object pointers in the past during
121
+ evaluation.
122
+ pred_obj_scores (bool): Whether to predict if there is an object in the frame.
123
+ pred_obj_scores_mlp (bool): Whether to use an MLP to predict object scores.
124
+ fixed_no_obj_ptr (bool): Whether to have a fixed no-object pointer when there is no object present.
125
+ soft_no_obj_ptr (bool): Whether to mix in no-object pointer softly for easier recovery and error mitigation.
126
+ use_mlp_for_obj_ptr_proj (bool): Whether to use MLP for object pointer projection.
127
+ no_obj_embed_spatial (torch.Tensor | None): No-object embedding for spatial frames.
128
+ max_cond_frames_in_attn (int): Maximum number of conditioning frames to participate in memory attention.
129
+ directly_add_no_mem_embed (bool): Whether to directly add no-memory embedding to image feature on the first
130
+ frame.
131
+ multimask_output_in_sam (bool): Whether to output multiple masks for the first click on initial conditioning
132
+ frames.
133
+ multimask_min_pt_num (int): Minimum number of clicks to use multimask output in SAM.
134
+ multimask_max_pt_num (int): Maximum number of clicks to use multimask output in SAM.
135
+ multimask_output_for_tracking (bool): Whether to use multimask output for tracking.
136
+ use_multimask_token_for_obj_ptr (bool): Whether to use multimask tokens for object pointers.
137
+ iou_prediction_use_sigmoid (bool): Whether to use sigmoid to restrict IoU prediction to [0-1].
138
+ memory_temporal_stride_for_eval (int): Memory bank's temporal stride during evaluation.
139
+ non_overlap_masks_for_mem_enc (bool): Whether to apply non-overlapping constraints on object masks in memory
140
+ encoder during evaluation.
141
+ sigmoid_scale_for_mem_enc (float): Scale factor for mask sigmoid probability.
142
+ sigmoid_bias_for_mem_enc (float): Bias factor for mask sigmoid probability.
143
+ binarize_mask_from_pts_for_mem_enc (bool): Whether to binarize sigmoid mask logits on interacted frames with
144
+ clicks during evaluation.
145
+ use_mask_input_as_output_without_sam (bool): Whether to directly output the input mask without using SAM prompt
146
+ encoder and mask decoder on frames with mask input.
147
+
148
+ Methods:
149
+ forward_image: Process image batch through encoder to extract multi-level features.
150
+ track_step: Perform a single tracking step, updating object masks and memory features.
151
+ set_binarize: Set binarize for VideoPredictor.
152
+ set_imgsz: Set image size to make model compatible with different image sizes.
153
+
154
+ Examples:
155
+ >>> model = SAM2Model(image_encoder, memory_attention, memory_encoder)
156
+ >>> image_batch = torch.rand(1, 3, 512, 512)
157
+ >>> features = model.forward_image(image_batch)
158
+ >>> track_results = model.track_step(0, True, features, None, None, None, {})
159
+ """
160
+
161
+ mask_threshold: float = 0.0
162
+
163
+ def __init__(
164
+ self,
165
+ image_encoder,
166
+ memory_attention,
167
+ memory_encoder,
168
+ num_maskmem=7,
169
+ image_size=512,
170
+ backbone_stride=16,
171
+ sigmoid_scale_for_mem_enc=1.0,
172
+ sigmoid_bias_for_mem_enc=0.0,
173
+ binarize_mask_from_pts_for_mem_enc=False,
174
+ use_mask_input_as_output_without_sam=False,
175
+ max_cond_frames_in_attn=-1,
176
+ directly_add_no_mem_embed=False,
177
+ use_high_res_features_in_sam=False,
178
+ multimask_output_in_sam=False,
179
+ multimask_min_pt_num=1,
180
+ multimask_max_pt_num=1,
181
+ multimask_output_for_tracking=False,
182
+ use_multimask_token_for_obj_ptr: bool = False,
183
+ iou_prediction_use_sigmoid=False,
184
+ memory_temporal_stride_for_eval=1,
185
+ non_overlap_masks_for_mem_enc=False,
186
+ use_obj_ptrs_in_encoder=False,
187
+ max_obj_ptrs_in_encoder=16,
188
+ add_tpos_enc_to_obj_ptrs=True,
189
+ proj_tpos_enc_in_obj_ptrs=False,
190
+ use_signed_tpos_enc_to_obj_ptrs=False,
191
+ only_obj_ptrs_in_the_past_for_eval=False,
192
+ pred_obj_scores: bool = False,
193
+ pred_obj_scores_mlp: bool = False,
194
+ fixed_no_obj_ptr: bool = False,
195
+ soft_no_obj_ptr: bool = False,
196
+ use_mlp_for_obj_ptr_proj: bool = False,
197
+ no_obj_embed_spatial: bool = False,
198
+ sam_mask_decoder_extra_args=None,
199
+ compile_image_encoder: bool = False,
200
+ ):
201
+ """Initialize the SAM2Model for video object segmentation with memory-based tracking.
202
+
203
+ Args:
204
+ image_encoder (nn.Module): Visual encoder for extracting image features.
205
+ memory_attention (nn.Module): Module for attending to memory features.
206
+ memory_encoder (nn.Module): Encoder for generating memory representations.
207
+ num_maskmem (int): Number of accessible memory frames.
208
+ image_size (int): Size of input images.
209
+ backbone_stride (int): Stride of the image backbone output.
210
+ sigmoid_scale_for_mem_enc (float): Scale factor for mask sigmoid probability.
211
+ sigmoid_bias_for_mem_enc (float): Bias factor for mask sigmoid probability.
212
+ binarize_mask_from_pts_for_mem_enc (bool): Whether to binarize sigmoid mask logits on interacted frames with
213
+ clicks during evaluation.
214
+ use_mask_input_as_output_without_sam (bool): Whether to directly output the input mask without using SAM
215
+ prompt encoder and mask decoder on frames with mask input.
216
+ max_cond_frames_in_attn (int): Maximum number of conditioning frames to participate in memory attention.
217
+ directly_add_no_mem_embed (bool): Whether to directly add no-memory embedding to image feature on the first
218
+ frame.
219
+ use_high_res_features_in_sam (bool): Whether to use high-resolution feature maps in the SAM mask decoder.
220
+ multimask_output_in_sam (bool): Whether to output multiple masks for the first click on initial conditioning
221
+ frames.
222
+ multimask_min_pt_num (int): Minimum number of clicks to use multimask output in SAM.
223
+ multimask_max_pt_num (int): Maximum number of clicks to use multimask output in SAM.
224
+ multimask_output_for_tracking (bool): Whether to use multimask output for tracking.
225
+ use_multimask_token_for_obj_ptr (bool): Whether to use multimask tokens for object pointers.
226
+ iou_prediction_use_sigmoid (bool): Whether to use sigmoid to restrict IoU prediction to [0-1].
227
+ memory_temporal_stride_for_eval (int): Memory bank's temporal stride during evaluation.
228
+ non_overlap_masks_for_mem_enc (bool): Whether to apply non-overlapping constraints on object masks in memory
229
+ encoder during evaluation.
230
+ use_obj_ptrs_in_encoder (bool): Whether to cross-attend to object pointers from other frames in the encoder.
231
+ max_obj_ptrs_in_encoder (int): Maximum number of object pointers from other frames in encoder
232
+ cross-attention.
233
+ add_tpos_enc_to_obj_ptrs (bool): Whether to add temporal positional encoding to object pointers in the
234
+ encoder.
235
+ proj_tpos_enc_in_obj_ptrs (bool): Whether to add an extra linear projection layer for temporal positional
236
+ encoding in object pointers.
237
+ use_signed_tpos_enc_to_obj_ptrs (bool): Whether to use signed distance in the temporal positional encoding
238
+ in the object pointers.
239
+ only_obj_ptrs_in_the_past_for_eval (bool): Whether to only attend to object pointers in the past during
240
+ evaluation.
241
+ pred_obj_scores (bool): Whether to predict if there is an object in the frame.
242
+ pred_obj_scores_mlp (bool): Whether to use an MLP to predict object scores.
243
+ fixed_no_obj_ptr (bool): Whether to have a fixed no-object pointer when there is no object present.
244
+ soft_no_obj_ptr (bool): Whether to mix in no-object pointer softly for easier recovery and error mitigation.
245
+ use_mlp_for_obj_ptr_proj (bool): Whether to use MLP for object pointer projection.
246
+ no_obj_embed_spatial (bool): Whether add no obj embedding to spatial frames.
247
+ sam_mask_decoder_extra_args (dict | None): Extra arguments for constructing the SAM mask decoder.
248
+ compile_image_encoder (bool): Whether to compile the image encoder for faster inference.
249
+ """
250
+ super().__init__()
251
+
252
+ # Part 1: the image backbone
253
+ self.image_encoder = image_encoder
254
+ # Use level 0, 1, 2 for high-res setting, or just level 2 for the default setting
255
+ self.use_high_res_features_in_sam = use_high_res_features_in_sam
256
+ self.num_feature_levels = 3 if use_high_res_features_in_sam else 1
257
+ self.use_obj_ptrs_in_encoder = use_obj_ptrs_in_encoder
258
+ self.max_obj_ptrs_in_encoder = max_obj_ptrs_in_encoder
259
+ if use_obj_ptrs_in_encoder:
260
+ # A conv layer to downsample the mask prompt to stride 4 (the same stride as
261
+ # low-res SAM mask logits) and to change its scales from 0~1 to SAM logit scale,
262
+ # so that it can be fed into the SAM mask decoder to generate a pointer.
263
+ self.mask_downsample = torch.nn.Conv2d(1, 1, kernel_size=4, stride=4)
264
+ self.add_tpos_enc_to_obj_ptrs = add_tpos_enc_to_obj_ptrs
265
+ if proj_tpos_enc_in_obj_ptrs:
266
+ assert add_tpos_enc_to_obj_ptrs # these options need to be used together
267
+ self.proj_tpos_enc_in_obj_ptrs = proj_tpos_enc_in_obj_ptrs
268
+ self.use_signed_tpos_enc_to_obj_ptrs = use_signed_tpos_enc_to_obj_ptrs
269
+ self.only_obj_ptrs_in_the_past_for_eval = only_obj_ptrs_in_the_past_for_eval
270
+
271
+ # Part 2: memory attention to condition current frame's visual features
272
+ # with memories (and obj ptrs) from past frames
273
+ self.memory_attention = memory_attention
274
+ self.hidden_dim = memory_attention.d_model
275
+
276
+ # Part 3: memory encoder for the previous frame's outputs
277
+ self.memory_encoder = memory_encoder
278
+ self.mem_dim = self.hidden_dim
279
+ if hasattr(self.memory_encoder, "out_proj") and hasattr(self.memory_encoder.out_proj, "weight"):
280
+ # if there is compression of memories along channel dim
281
+ self.mem_dim = self.memory_encoder.out_proj.weight.shape[0]
282
+ self.num_maskmem = num_maskmem # Number of memories accessible
283
+ # Temporal encoding of the memories
284
+ self.maskmem_tpos_enc = torch.nn.Parameter(torch.zeros(num_maskmem, 1, 1, self.mem_dim))
285
+ trunc_normal_(self.maskmem_tpos_enc, std=0.02)
286
+ # a single token to indicate no memory embedding from previous frames
287
+ self.no_mem_embed = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
288
+ self.no_mem_pos_enc = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
289
+ trunc_normal_(self.no_mem_embed, std=0.02)
290
+ trunc_normal_(self.no_mem_pos_enc, std=0.02)
291
+ self.directly_add_no_mem_embed = directly_add_no_mem_embed
292
+ # Apply sigmoid to the output raw mask logits (to turn them from
293
+ # range (-inf, +inf) to range (0, 1)) before feeding them into the memory encoder
294
+ self.sigmoid_scale_for_mem_enc = sigmoid_scale_for_mem_enc
295
+ self.sigmoid_bias_for_mem_enc = sigmoid_bias_for_mem_enc
296
+ self.binarize_mask_from_pts_for_mem_enc = binarize_mask_from_pts_for_mem_enc
297
+ self.non_overlap_masks_for_mem_enc = non_overlap_masks_for_mem_enc
298
+ self.memory_temporal_stride_for_eval = memory_temporal_stride_for_eval
299
+ # On frames with mask input, whether to directly output the input mask without
300
+ # using a SAM prompt encoder + mask decoder
301
+ self.use_mask_input_as_output_without_sam = use_mask_input_as_output_without_sam
302
+ self.multimask_output_in_sam = multimask_output_in_sam
303
+ self.multimask_min_pt_num = multimask_min_pt_num
304
+ self.multimask_max_pt_num = multimask_max_pt_num
305
+ self.multimask_output_for_tracking = multimask_output_for_tracking
306
+ self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr
307
+ self.iou_prediction_use_sigmoid = iou_prediction_use_sigmoid
308
+
309
+ # Part 4: SAM-style prompt encoder (for both mask and point inputs)
310
+ # and SAM-style mask decoder for the final mask output
311
+ self.image_size = image_size
312
+ self.backbone_stride = backbone_stride
313
+ self.sam_mask_decoder_extra_args = sam_mask_decoder_extra_args
314
+ self.pred_obj_scores = pred_obj_scores
315
+ self.pred_obj_scores_mlp = pred_obj_scores_mlp
316
+ self.fixed_no_obj_ptr = fixed_no_obj_ptr
317
+ self.soft_no_obj_ptr = soft_no_obj_ptr
318
+ if self.fixed_no_obj_ptr:
319
+ assert self.pred_obj_scores
320
+ assert self.use_obj_ptrs_in_encoder
321
+ if self.pred_obj_scores and self.use_obj_ptrs_in_encoder:
322
+ self.no_obj_ptr = torch.nn.Parameter(torch.zeros(1, self.hidden_dim))
323
+ trunc_normal_(self.no_obj_ptr, std=0.02)
324
+ self.use_mlp_for_obj_ptr_proj = use_mlp_for_obj_ptr_proj
325
+ self.no_obj_embed_spatial = None
326
+ if no_obj_embed_spatial:
327
+ self.no_obj_embed_spatial = torch.nn.Parameter(torch.zeros(1, self.mem_dim))
328
+ trunc_normal_(self.no_obj_embed_spatial, std=0.02)
329
+
330
+ self._build_sam_heads()
331
+ self.max_cond_frames_in_attn = max_cond_frames_in_attn
332
+ self.add_all_frames_to_correct_as_cond = True
333
+
334
+ # Model compilation
335
+ if compile_image_encoder:
336
+ # Compile the forward function (not the full module) to allow loading checkpoints.
337
+ LOGGER.info("Image encoder compilation is enabled. First forward pass will be slow.")
338
+ self.image_encoder.forward = torch.compile(
339
+ self.image_encoder.forward,
340
+ mode="max-autotune",
341
+ fullgraph=True,
342
+ dynamic=False,
343
+ )
344
+
345
+ @property
346
+ def device(self):
347
+ """Return the device on which the model's parameters are stored."""
348
+ return next(self.parameters()).device
349
+
350
+ def forward(self, *args, **kwargs):
351
+ """Process image and prompt inputs to generate object masks and scores in video sequences."""
352
+ raise NotImplementedError(
353
+ "Please use the corresponding methods in SAM2VideoPredictor for inference."
354
+ "See notebooks/video_predictor_example.ipynb for an example."
355
+ )
356
+
357
+ def _build_sam_heads(self):
358
+ """Build SAM-style prompt encoder and mask decoder for image segmentation tasks."""
359
+ self.sam_prompt_embed_dim = self.hidden_dim
360
+ self.sam_image_embedding_size = self.image_size // self.backbone_stride
361
+
362
+ # Build PromptEncoder and MaskDecoder from SAM (hyperparameters like `mask_in_chans=16` are from SAM code)
363
+ self.sam_prompt_encoder = PromptEncoder(
364
+ embed_dim=self.sam_prompt_embed_dim,
365
+ image_embedding_size=(
366
+ self.sam_image_embedding_size,
367
+ self.sam_image_embedding_size,
368
+ ),
369
+ input_image_size=(self.image_size, self.image_size),
370
+ mask_in_chans=16,
371
+ )
372
+ self.sam_mask_decoder = SAM2MaskDecoder(
373
+ num_multimask_outputs=3,
374
+ transformer=SAM2TwoWayTransformer(
375
+ depth=2,
376
+ embedding_dim=self.sam_prompt_embed_dim,
377
+ mlp_dim=2048,
378
+ num_heads=8,
379
+ ),
380
+ transformer_dim=self.sam_prompt_embed_dim,
381
+ iou_head_depth=3,
382
+ iou_head_hidden_dim=256,
383
+ use_high_res_features=self.use_high_res_features_in_sam,
384
+ iou_prediction_use_sigmoid=self.iou_prediction_use_sigmoid,
385
+ pred_obj_scores=self.pred_obj_scores,
386
+ pred_obj_scores_mlp=self.pred_obj_scores_mlp,
387
+ use_multimask_token_for_obj_ptr=self.use_multimask_token_for_obj_ptr,
388
+ **(self.sam_mask_decoder_extra_args or {}),
389
+ )
390
+ if self.use_obj_ptrs_in_encoder:
391
+ # a linear projection on SAM output tokens to turn them into object pointers
392
+ self.obj_ptr_proj = torch.nn.Linear(self.hidden_dim, self.hidden_dim)
393
+ if self.use_mlp_for_obj_ptr_proj:
394
+ self.obj_ptr_proj = MLP(self.hidden_dim, self.hidden_dim, self.hidden_dim, 3)
395
+ else:
396
+ self.obj_ptr_proj = torch.nn.Identity()
397
+ if self.proj_tpos_enc_in_obj_ptrs:
398
+ # a linear projection on temporal positional encoding in object pointers to
399
+ # avoid potential interference with spatial positional encoding
400
+ self.obj_ptr_tpos_proj = torch.nn.Linear(self.hidden_dim, self.mem_dim)
401
+ else:
402
+ self.obj_ptr_tpos_proj = torch.nn.Identity()
403
+
404
+ def _forward_sam_heads(
405
+ self,
406
+ backbone_features,
407
+ point_inputs=None,
408
+ mask_inputs=None,
409
+ high_res_features=None,
410
+ multimask_output=False,
411
+ ):
412
+ """Forward pass through SAM prompt encoders and mask heads.
413
+
414
+ This method processes image features and optional point/mask inputs to generate object masks and scores.
415
+
416
+ Args:
417
+ backbone_features (torch.Tensor): Image features with shape (B, C, H, W).
418
+ point_inputs (dict[str, torch.Tensor] | None): Dictionary containing point prompts.
419
+ 'point_coords': Tensor of shape (B, P, 2) with float32 dtype, containing absolute pixel-unit coordinates in
420
+ (x, y) format for P input points.
421
+ 'point_labels': Tensor of shape (B, P) with int32 dtype, where 1 means positive clicks, 0 means negative
422
+ clicks, and -1 means padding.
423
+ mask_inputs (torch.Tensor | None): Mask of shape (B, 1, H*16, W*16), float or bool, with the same spatial
424
+ size as the image.
425
+ high_res_features (list[torch.Tensor] | None): List of two feature maps with shapes (B, C, 4*H, 4*W) and (B,
426
+ C, 2*H, 2*W) respectively, used as high-resolution feature maps for SAM decoder.
427
+ multimask_output (bool): If True, output 3 candidate masks and their IoU estimates; if False, output only 1
428
+ mask and its IoU estimate.
429
+
430
+ Returns:
431
+ low_res_multimasks (torch.Tensor): Tensor of shape (B, M, H*4, W*4) with SAM output mask logits.
432
+ high_res_multimasks (torch.Tensor): Tensor of shape (B, M, H*16, W*16) with upsampled mask logits.
433
+ ious (torch.Tensor): Tensor of shape (B, M) with estimated IoU for each output mask.
434
+ low_res_masks (torch.Tensor): Tensor of shape (B, 1, H*4, W*4) with the best low-resolution mask.
435
+ high_res_masks (torch.Tensor): Tensor of shape (B, 1, H*16, W*16) with the best high-resolution mask.
436
+ obj_ptr (torch.Tensor): Tensor of shape (B, C) with object pointer vector for the output mask.
437
+ object_score_logits (torch.Tensor): Tensor of shape (B) with object score logits.
438
+
439
+ Examples:
440
+ >>> backbone_features = torch.rand(1, 256, 32, 32)
441
+ >>> point_inputs = {"point_coords": torch.rand(1, 2, 2), "point_labels": torch.tensor([[1, 0]])}
442
+ >>> mask_inputs = torch.rand(1, 1, 512, 512)
443
+ >>> results = model._forward_sam_heads(backbone_features, point_inputs, mask_inputs)
444
+ >>> (
445
+ ... low_res_multimasks,
446
+ ... high_res_multimasks,
447
+ ... ious,
448
+ ... low_res_masks,
449
+ ... high_res_masks,
450
+ ... obj_ptr,
451
+ ... object_score_logits,
452
+ ... ) = results
453
+ """
454
+ B = backbone_features.shape[0]
455
+ device = backbone_features.device
456
+ assert backbone_features.size(1) == self.sam_prompt_embed_dim
457
+ assert backbone_features.size(2) == self.sam_image_embedding_size
458
+ assert backbone_features.size(3) == self.sam_image_embedding_size
459
+
460
+ # a) Handle point prompts
461
+ if point_inputs is not None:
462
+ sam_point_coords = point_inputs["point_coords"]
463
+ sam_point_labels = point_inputs["point_labels"]
464
+ assert sam_point_coords.shape[0] == B and sam_point_labels.shape[0] == B
465
+ else:
466
+ # If no points are provide, pad with an empty point (with label -1)
467
+ sam_point_coords = torch.zeros(B, 1, 2, device=device, dtype=backbone_features.dtype)
468
+ sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device)
469
+
470
+ # b) Handle mask prompts
471
+ if mask_inputs is not None:
472
+ # If mask_inputs is provided, downsize it into low-res mask input if needed
473
+ # and feed it as a dense mask prompt into the SAM mask encoder
474
+ assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1)
475
+ if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size:
476
+ sam_mask_prompt = F.interpolate(
477
+ mask_inputs.to(backbone_features.dtype),
478
+ size=self.sam_prompt_encoder.mask_input_size,
479
+ align_corners=False,
480
+ mode="bilinear",
481
+ antialias=True, # use antialias for downsampling
482
+ )
483
+ else:
484
+ sam_mask_prompt = mask_inputs
485
+ else:
486
+ # Otherwise, simply feed None (and SAM's prompt encoder will add
487
+ # a learned `no_mask_embed` to indicate no mask input in this case).
488
+ sam_mask_prompt = None
489
+
490
+ sparse_embeddings, dense_embeddings = self.sam_prompt_encoder(
491
+ points=(sam_point_coords, sam_point_labels),
492
+ boxes=None,
493
+ masks=sam_mask_prompt,
494
+ )
495
+ low_res_multimasks, ious, sam_output_tokens, object_score_logits = self.sam_mask_decoder(
496
+ image_embeddings=backbone_features,
497
+ image_pe=self.sam_prompt_encoder.get_dense_pe(),
498
+ sparse_prompt_embeddings=sparse_embeddings,
499
+ dense_prompt_embeddings=dense_embeddings,
500
+ multimask_output=multimask_output,
501
+ repeat_image=False, # the image is already batched
502
+ high_res_features=high_res_features,
503
+ )
504
+ if self.pred_obj_scores:
505
+ is_obj_appearing = object_score_logits > 0
506
+
507
+ # Spatial memory mask is a *hard* choice between obj and no obj, consistent with actual mask prediction
508
+ low_res_multimasks = torch.where(is_obj_appearing[:, None, None], low_res_multimasks, NO_OBJ_SCORE)
509
+
510
+ # convert masks from possibly bfloat16 (or float16) to float32
511
+ # (older PyTorch versions before 2.1 don't support `interpolate` on bf16)
512
+ high_res_multimasks = F.interpolate(
513
+ low_res_multimasks,
514
+ size=(self.image_size, self.image_size),
515
+ mode="bilinear",
516
+ align_corners=False,
517
+ )
518
+
519
+ sam_output_token = sam_output_tokens[:, 0]
520
+ if multimask_output:
521
+ # take the best mask prediction (with the highest IoU estimation)
522
+ best_iou_inds = torch.argmax(ious, dim=-1)
523
+ batch_inds = torch.arange(B, device=device)
524
+ low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
525
+ high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
526
+ if sam_output_tokens.size(1) > 1:
527
+ sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]
528
+ else:
529
+ low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks
530
+
531
+ # Extract object pointer from the SAM output token (with occlusion handling)
532
+ obj_ptr = self.obj_ptr_proj(sam_output_token)
533
+ if self.pred_obj_scores:
534
+ # Allow *soft* no obj ptr, unlike for masks
535
+ if self.soft_no_obj_ptr:
536
+ lambda_is_obj_appearing = object_score_logits.sigmoid()
537
+ else:
538
+ lambda_is_obj_appearing = is_obj_appearing.to(obj_ptr.dtype)
539
+
540
+ if self.fixed_no_obj_ptr:
541
+ obj_ptr = lambda_is_obj_appearing * obj_ptr
542
+ obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr
543
+ return (
544
+ low_res_multimasks,
545
+ high_res_multimasks,
546
+ ious,
547
+ low_res_masks,
548
+ high_res_masks,
549
+ obj_ptr,
550
+ object_score_logits,
551
+ )
552
+
553
+ def _use_mask_as_output(self, mask_inputs, backbone_features=None, high_res_features=None):
554
+ """Process mask inputs directly as output, bypassing SAM encoder/decoder."""
555
+ # Use -10/+10 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid).
556
+ out_scale, out_bias = 20.0, -10.0 # sigmoid(-10.0)=4.5398e-05
557
+ mask_inputs_float = mask_inputs.float()
558
+ high_res_masks = mask_inputs_float * out_scale + out_bias
559
+ low_res_masks = F.interpolate(
560
+ high_res_masks,
561
+ size=(high_res_masks.size(-2) // 4, high_res_masks.size(-1) // 4),
562
+ align_corners=False,
563
+ mode="bilinear",
564
+ antialias=True, # use antialias for downsampling
565
+ )
566
+ # a dummy IoU prediction of all 1's under mask input
567
+ ious = mask_inputs.new_ones(mask_inputs.shape[0], 1).float()
568
+ if not self.use_obj_ptrs_in_encoder or backbone_features is None or high_res_features is None:
569
+ # all zeros as a dummy object pointer (of shape [B, C])
570
+ obj_ptr = torch.zeros(mask_inputs.shape[0], self.hidden_dim, device=mask_inputs.device)
571
+ else:
572
+ # produce an object pointer using the SAM decoder from the mask input
573
+ _, _, _, _, _, obj_ptr, _ = self._forward_sam_heads(
574
+ backbone_features=backbone_features,
575
+ mask_inputs=self.mask_downsample(mask_inputs_float.to(backbone_features.dtype)),
576
+ high_res_features=high_res_features,
577
+ )
578
+ # In this method, we are treating mask_input as output, e.g. using it directly to create spatial mem;
579
+ # Below, we follow the same design axiom to use mask_input to decide if obj appears or not instead of relying
580
+ # on the object_scores from the SAM decoder.
581
+ is_obj_appearing = torch.any(mask_inputs.flatten(1).float() > 0.0, dim=1)
582
+ is_obj_appearing = is_obj_appearing[..., None]
583
+ lambda_is_obj_appearing = is_obj_appearing.float()
584
+ object_score_logits = out_scale * lambda_is_obj_appearing + out_bias
585
+ if self.pred_obj_scores:
586
+ if self.fixed_no_obj_ptr:
587
+ obj_ptr = lambda_is_obj_appearing * obj_ptr
588
+ obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr
589
+
590
+ return (
591
+ low_res_masks,
592
+ high_res_masks,
593
+ ious,
594
+ low_res_masks,
595
+ high_res_masks,
596
+ obj_ptr,
597
+ object_score_logits,
598
+ )
599
+
600
+ def forward_image(self, img_batch: torch.Tensor):
601
+ """Process image batch through encoder to extract multi-level features for SAM model."""
602
+ backbone_out = self.image_encoder(img_batch)
603
+ if self.use_high_res_features_in_sam:
604
+ # precompute projected level 0 and level 1 features in SAM decoder
605
+ # to avoid running it again on every SAM click
606
+ backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0(backbone_out["backbone_fpn"][0])
607
+ backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1(backbone_out["backbone_fpn"][1])
608
+ return backbone_out
609
+
610
+ def _prepare_backbone_features(self, backbone_out, batch=1):
611
+ """Prepare and flatten visual features from the image backbone output for further processing."""
612
+ if batch > 1: # expand features if there's more than one prompt
613
+ backbone_out = {
614
+ **backbone_out,
615
+ "backbone_fpn": [feat.expand(batch, -1, -1, -1) for feat in backbone_out["backbone_fpn"]],
616
+ "vision_pos_enc": [pos.expand(batch, -1, -1, -1) for pos in backbone_out["vision_pos_enc"]],
617
+ }
618
+ assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"])
619
+ assert len(backbone_out["backbone_fpn"]) >= self.num_feature_levels
620
+
621
+ feature_maps = backbone_out["backbone_fpn"][-self.num_feature_levels :]
622
+ vision_pos_embeds = backbone_out["vision_pos_enc"][-self.num_feature_levels :]
623
+
624
+ feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds]
625
+ # flatten NxCxHxW to HWxNxC
626
+ vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps]
627
+ vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds]
628
+ return backbone_out, vision_feats, vision_pos_embeds, feat_sizes
629
+
630
+ def _prepare_memory_conditioned_features(
631
+ self,
632
+ frame_idx,
633
+ is_init_cond_frame,
634
+ current_vision_feats,
635
+ current_vision_pos_embeds,
636
+ feat_sizes,
637
+ output_dict,
638
+ num_frames,
639
+ track_in_reverse=False, # tracking in reverse time order (for demo usage)
640
+ ):
641
+ """Prepare memory-conditioned features by fusing current frame's visual features with previous memories."""
642
+ B = current_vision_feats[-1].size(1) # batch size on this frame
643
+ C = self.hidden_dim
644
+ H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size
645
+ device = current_vision_feats[-1].device
646
+ # The case of `self.num_maskmem == 0` below is primarily used for reproducing SAM on images.
647
+ # In this case, we skip the fusion with any memory.
648
+ if self.num_maskmem == 0: # Disable memory and skip fusion
649
+ return current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
650
+ num_obj_ptr_tokens = 0
651
+ tpos_sign_mul = -1 if track_in_reverse else 1
652
+ # Step 1: condition the visual features of the current frame on previous memories
653
+ if not is_init_cond_frame:
654
+ # Retrieve the memories encoded with the maskmem backbone
655
+ to_cat_memory, to_cat_memory_pos_embed = [], []
656
+ # Add conditioning frame's output first (all cond frames have t_pos=0 for
657
+ # when getting temporal positional embedding below)
658
+ assert len(output_dict["cond_frame_outputs"]) > 0
659
+ # Select a maximum number of temporally closest cond frames for cross attention
660
+ cond_outputs = output_dict["cond_frame_outputs"]
661
+ selected_cond_outputs, unselected_cond_outputs = select_closest_cond_frames(
662
+ frame_idx, cond_outputs, self.max_cond_frames_in_attn
663
+ )
664
+ t_pos_and_prevs = [(0, out) for out in selected_cond_outputs.values()]
665
+ # Add last (self.num_maskmem - 1) frames before current frame for non-conditioning memory
666
+ # the earliest one has t_pos=1 and the latest one has t_pos=self.num_maskmem-1
667
+ # We also allow taking the memory frame non-consecutively (with r>1), in which case
668
+ # we take (self.num_maskmem - 2) frames among every r-th frames plus the last frame.
669
+ r = 1 if self.training else self.memory_temporal_stride_for_eval
670
+ for t_pos in range(1, self.num_maskmem):
671
+ t_rel = self.num_maskmem - t_pos # how many frames before current frame
672
+ if t_rel == 1:
673
+ # for t_rel == 1, we take the last frame (regardless of r)
674
+ prev_frame_idx = frame_idx + t_rel if track_in_reverse else frame_idx - t_rel
675
+ elif not track_in_reverse:
676
+ # first find the nearest frame among every r-th frames before this frame
677
+ # for r=1, this would be (frame_idx - 2)
678
+ prev_frame_idx = ((frame_idx - 2) // r) * r
679
+ # then seek further among every r-th frames
680
+ prev_frame_idx = prev_frame_idx - (t_rel - 2) * r
681
+ else:
682
+ # first find the nearest frame among every r-th frames after this frame
683
+ # for r=1, this would be (frame_idx + 2)
684
+ prev_frame_idx = -(-(frame_idx + 2) // r) * r
685
+ # then seek further among every r-th frames
686
+ prev_frame_idx = prev_frame_idx + (t_rel - 2) * r
687
+ out = output_dict["non_cond_frame_outputs"].get(prev_frame_idx, None)
688
+ if out is None:
689
+ # If an unselected conditioning frame is among the last (self.num_maskmem - 1)
690
+ # frames, we still attend to it as if it's a non-conditioning frame.
691
+ out = unselected_cond_outputs.get(prev_frame_idx, None)
692
+ t_pos_and_prevs.append((t_pos, out))
693
+
694
+ for t_pos, prev in t_pos_and_prevs:
695
+ if prev is None:
696
+ continue # skip padding frames
697
+ # "maskmem_features" might have been offloaded to CPU in demo use cases,
698
+ # so we load it back to inference device (it's a no-op if it's already on device).
699
+ feats = prev["maskmem_features"].to(device=device, non_blocking=device.type == "cuda")
700
+ to_cat_memory.append(feats.flatten(2).permute(2, 0, 1))
701
+ # Spatial positional encoding (it might have been offloaded to CPU in eval)
702
+ maskmem_enc = prev["maskmem_pos_enc"][-1].to(device=device)
703
+ maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1)
704
+ # Temporal positional encoding
705
+ maskmem_enc = maskmem_enc + self.maskmem_tpos_enc[self.num_maskmem - t_pos - 1]
706
+ to_cat_memory_pos_embed.append(maskmem_enc)
707
+
708
+ # Construct the list of past object pointers
709
+ if self.use_obj_ptrs_in_encoder:
710
+ max_obj_ptrs_in_encoder = min(num_frames, self.max_obj_ptrs_in_encoder)
711
+ # First add those object pointers from selected conditioning frames
712
+ # (optionally, only include object pointers in the past during evaluation)
713
+ if not self.training and self.only_obj_ptrs_in_the_past_for_eval:
714
+ ptr_cond_outputs = {
715
+ t: out
716
+ for t, out in selected_cond_outputs.items()
717
+                         if (t >= frame_idx if track_in_reverse else t <= frame_idx)
+                     }
+                 else:
+                     ptr_cond_outputs = selected_cond_outputs
+                 pos_and_ptrs = [
+                     # Temporal pos encoding encodes how far away each pointer is from the current frame
+                     (
+                         (
+                             (frame_idx - t) * tpos_sign_mul
+                             if self.use_signed_tpos_enc_to_obj_ptrs
+                             else abs(frame_idx - t)
+                         ),
+                         out["obj_ptr"],
+                     )
+                     for t, out in ptr_cond_outputs.items()
+                 ]
+                 # Add up to (max_obj_ptrs_in_encoder - 1) non-conditioning frames before the current frame
+                 for t_diff in range(1, max_obj_ptrs_in_encoder):
+                     t = frame_idx + t_diff if track_in_reverse else frame_idx - t_diff
+                     if t < 0 or (num_frames is not None and t >= num_frames):
+                         break
+                     out = output_dict["non_cond_frame_outputs"].get(t, unselected_cond_outputs.get(t, None))
+                     if out is not None:
+                         pos_and_ptrs.append((t_diff, out["obj_ptr"]))
+                 # If we have at least one object pointer, add them to the cross attention
+                 if pos_and_ptrs:
+                     pos_list, ptrs_list = zip(*pos_and_ptrs)
+                     # stack object pointers along dim=0 into [ptr_seq_len, B, C] shape
+                     obj_ptrs = torch.stack(ptrs_list, dim=0)
+                     # a temporal positional embedding based on how far each object pointer is from
+                     # the current frame (sine embedding normalized by the max pointer num).
+                     if self.add_tpos_enc_to_obj_ptrs:
+                         t_diff_max = max_obj_ptrs_in_encoder - 1
+                         tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim
+                         obj_pos = torch.tensor(pos_list, device=device, dtype=current_vision_feats[-1].dtype)
+                         obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim)
+                         obj_pos = self.obj_ptr_tpos_proj(obj_pos)
+                         obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim)
+                     else:
+                         obj_pos = obj_ptrs.new_zeros(len(pos_list), B, self.mem_dim)
+                     if self.mem_dim < C:
+                         # split a pointer into (C // self.mem_dim) tokens for self.mem_dim < C
+                         obj_ptrs = obj_ptrs.reshape(-1, B, C // self.mem_dim, self.mem_dim)
+                         obj_ptrs = obj_ptrs.permute(0, 2, 1, 3).flatten(0, 1)
+                         obj_pos = obj_pos.repeat_interleave(C // self.mem_dim, dim=0)
+                     to_cat_memory.append(obj_ptrs)
+                     to_cat_memory_pos_embed.append(obj_pos)
+                     num_obj_ptr_tokens = obj_ptrs.shape[0]
+                 else:
+                     num_obj_ptr_tokens = 0
+         else:
+             # for initial conditioning frames, encode them without using any previous memory
+             if self.directly_add_no_mem_embed:
+                 # directly add the no-memory embedding (instead of using the transformer encoder)
+                 pix_feat_with_mem = current_vision_feats[-1] + self.no_mem_embed
+                 pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
+                 return pix_feat_with_mem
+
+             # Use a dummy token on the first frame (to avoid empty memory input to transformer encoder)
+             to_cat_memory = [self.no_mem_embed.expand(1, B, self.mem_dim)]
+             to_cat_memory_pos_embed = [self.no_mem_pos_enc.expand(1, B, self.mem_dim)]
+
+         # Step 2: Concatenate the memories and forward through the transformer encoder
+         memory = torch.cat(to_cat_memory, dim=0)
+         memory_pos_embed = torch.cat(to_cat_memory_pos_embed, dim=0)
+
+         pix_feat_with_mem = self.memory_attention(
+             curr=current_vision_feats,
+             curr_pos=current_vision_pos_embeds,
+             memory=memory,
+             memory_pos=memory_pos_embed,
+             num_obj_ptr_tokens=num_obj_ptr_tokens,
+         )
+         # Reshape output (HW)BC => BCHW
+         pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
+         return pix_feat_with_mem
+
+     def _encode_new_memory(
+         self,
+         current_vision_feats,
+         feat_sizes,
+         pred_masks_high_res,
+         object_score_logits,
+         is_mask_from_pts,
+     ):
+         """Encode frame features and masks into a new memory representation for video segmentation."""
+         B = current_vision_feats[-1].size(1)  # batch size on this frame
+         C = self.hidden_dim
+         H, W = feat_sizes[-1]  # top-level (lowest-resolution) feature size
+         # top-level feature, (HW)BC => BCHW
+         pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
+         if self.non_overlap_masks_for_mem_enc and not self.training:
+             # optionally, apply non-overlapping constraints to the masks (it's applied
+             # in the batch dimension and should only be used during eval, where all
+             # the objects come from the same video under batch size 1).
+             pred_masks_high_res = self._apply_non_overlapping_constraints(pred_masks_high_res)
+         # scale the raw mask logits with a temperature before applying sigmoid
+         binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts
+         if binarize and not self.training:
+             mask_for_mem = (pred_masks_high_res > 0).to(pix_feat.dtype)
+         else:
+             # apply sigmoid on the raw mask logits to turn them into range (0, 1)
+             mask_for_mem = torch.sigmoid(pred_masks_high_res)
+         # apply scale and bias terms to the sigmoid probabilities
+         if self.sigmoid_scale_for_mem_enc != 1.0:
+             mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc
+         if self.sigmoid_bias_for_mem_enc != 0.0:
+             mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc
+         maskmem_out = self.memory_encoder(pix_feat, mask_for_mem, skip_mask_sigmoid=True)  # sigmoid already applied
+         maskmem_features = maskmem_out["vision_features"]
+         # add a no-object embedding to the spatial memory to indicate that the frame
+         # is predicted to be occluded (i.e. no object is appearing in the frame)
+         if self.no_obj_embed_spatial is not None:
+             is_obj_appearing = (object_score_logits > 0).float()
+             maskmem_features += (1 - is_obj_appearing[..., None, None]) * self.no_obj_embed_spatial[
+                 ..., None, None
+             ].expand(*maskmem_features.shape)
+
+         return maskmem_features, maskmem_out["vision_pos_enc"]
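For reference, the mask preprocessing that `_encode_new_memory` performs before calling the memory encoder can be read as a small standalone transform: either hard-binarize the logits (for point-prompted frames during eval) or squash them with a sigmoid, then apply the optional scale and bias. The helper below is a hypothetical sketch of that step only; `prepare_mask_for_memory` and the example scale/bias values are illustrative and not part of the package.

# Sketch only: the mask preprocessing applied before the memory encoder.
import torch


def prepare_mask_for_memory(
    mask_logits: torch.Tensor, binarize: bool = False, scale: float = 1.0, bias: float = 0.0
) -> torch.Tensor:
    """Turn raw mask logits into the (optionally binarized, scaled, biased) mask fed to the memory encoder."""
    if binarize:
        mask = (mask_logits > 0).to(mask_logits.dtype)  # hard 0/1 mask
    else:
        mask = torch.sigmoid(mask_logits)  # soft probabilities in (0, 1)
    if scale != 1.0:
        mask = mask * scale
    if bias != 0.0:
        mask = mask + bias
    return mask


logits = torch.randn(1, 1, 256, 256)  # [B, 1, H, W] mask logits, illustrative size
soft = prepare_mask_for_memory(logits, binarize=False, scale=20.0, bias=-10.0)
hard = prepare_mask_for_memory(logits, binarize=True)
print(soft.min().item(), soft.max().item(), hard.unique())  # soft spans roughly (-10, 10); hard is {0., 1.}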
+
+     def _track_step(
+         self,
+         frame_idx,
+         is_init_cond_frame,
+         current_vision_feats,
+         current_vision_pos_embeds,
+         feat_sizes,
+         point_inputs,
+         mask_inputs,
+         output_dict,
+         num_frames,
+         track_in_reverse,
+         prev_sam_mask_logits,
+     ):
+         """Perform a single tracking step, updating object masks and memory features based on current frame inputs."""
+         # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW
+         if len(current_vision_feats) > 1:
+             high_res_features = [
+                 x.permute(1, 2, 0).view(x.size(1), x.size(2), *s)
+                 for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1])
+             ]
+         else:
+             high_res_features = None
+         if mask_inputs is not None and self.use_mask_input_as_output_without_sam:
+             # When use_mask_input_as_output_without_sam=True, we directly output the mask input
+             # (treating it as a GT mask) without using a SAM prompt encoder + mask decoder.
+             pix_feat = current_vision_feats[-1].permute(1, 2, 0)
+             pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])
+             sam_outputs = self._use_mask_as_output(mask_inputs, pix_feat, high_res_features)
+         else:
+             # Fuse visual features with previous memory features in the memory bank
+             pix_feat = self._prepare_memory_conditioned_features(
+                 frame_idx=frame_idx,
+                 is_init_cond_frame=is_init_cond_frame,
+                 current_vision_feats=current_vision_feats[-1:],
+                 current_vision_pos_embeds=current_vision_pos_embeds[-1:],
+                 feat_sizes=feat_sizes[-1:],
+                 output_dict=output_dict,
+                 num_frames=num_frames,
+                 track_in_reverse=track_in_reverse,
+             )
+             # Apply the SAM-style segmentation head.
+             # Here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder,
+             # e.g. in a demo where such logits come from an earlier interaction instead of correction sampling
+             # (in this case, any `mask_inputs` shouldn't reach here, as they are sent to _use_mask_as_output instead).
+             if prev_sam_mask_logits is not None:
+                 assert point_inputs is not None and mask_inputs is None
+                 mask_inputs = prev_sam_mask_logits
+             multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
+             sam_outputs = self._forward_sam_heads(
+                 backbone_features=pix_feat,
+                 point_inputs=point_inputs,
+                 mask_inputs=mask_inputs,
+                 high_res_features=high_res_features,
+                 multimask_output=multimask_output,
+             )
+         return sam_outputs, high_res_features, pix_feat
+
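A recurring pattern in `_track_step` and its callees is converting between the backbone's flattened token layout `(HW, B, C)` and the spatial layout `(B, C, H, W)` expected by the SAM heads and the memory encoder. A small sketch of that round trip, with illustrative sizes:

# Illustrative sketch of the (HW)BC <=> BCHW reshaping used above.
import torch

B, C, H, W = 2, 256, 64, 64
tokens = torch.randn(H * W, B, C)                       # (HW)BC: token-major layout

spatial = tokens.permute(1, 2, 0).view(B, C, H, W)      # (HW)BC => BCHW
tokens_back = spatial.flatten(2).permute(2, 0, 1)       # BCHW => (HW)BC

assert torch.equal(tokens, tokens_back)
print(spatial.shape)  # torch.Size([2, 256, 64, 64])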
+     def _encode_memory_in_output(
+         self,
+         current_vision_feats,
+         feat_sizes,
+         point_inputs,
+         run_mem_encoder,
+         high_res_masks,
+         object_score_logits,
+         current_out,
+     ):
+         """Run the memory encoder on the predicted mask to encode it into a new memory feature for future frames."""
+         if run_mem_encoder and self.num_maskmem > 0:
+             maskmem_features, maskmem_pos_enc = self._encode_new_memory(
+                 current_vision_feats=current_vision_feats,
+                 feat_sizes=feat_sizes,
+                 pred_masks_high_res=high_res_masks,
+                 object_score_logits=object_score_logits,
+                 is_mask_from_pts=(point_inputs is not None),
+             )
+             current_out["maskmem_features"] = maskmem_features
+             current_out["maskmem_pos_enc"] = maskmem_pos_enc
+         else:
+             current_out["maskmem_features"] = None
+             current_out["maskmem_pos_enc"] = None
+
+     def track_step(
+         self,
+         frame_idx,
+         is_init_cond_frame,
+         current_vision_feats,
+         current_vision_pos_embeds,
+         feat_sizes,
+         point_inputs,
+         mask_inputs,
+         output_dict,
+         num_frames,
+         track_in_reverse=False,  # tracking in reverse time order (for demo usage)
+         # Whether to run the memory encoder on the predicted masks. Sometimes we might want
+         # to skip the memory encoder with `run_mem_encoder=False`. For example,
+         # in a demo we might call `track_step` multiple times for each user click,
+         # and only encode the memory when the user finalizes their clicks. In ablation
+         # settings like SAM training on static images, the memory encoder isn't needed either.
+         run_mem_encoder=True,
+         # The previously predicted SAM mask logits (which can be fed together with new clicks in demo).
+         prev_sam_mask_logits=None,
+     ):
+         """Perform a single tracking step, updating object masks and memory features based on current frame inputs."""
+         sam_outputs, _, _ = self._track_step(
+             frame_idx,
+             is_init_cond_frame,
+             current_vision_feats,
+             current_vision_pos_embeds,
+             feat_sizes,
+             point_inputs,
+             mask_inputs,
+             output_dict,
+             num_frames,
+             track_in_reverse,
+             prev_sam_mask_logits,
+         )
+         _, _, _, low_res_masks, high_res_masks, obj_ptr, object_score_logits = sam_outputs
+
+         current_out = {
+             "pred_masks": low_res_masks,
+             "pred_masks_high_res": high_res_masks,
+             "obj_ptr": obj_ptr,
+         }
+         if not self.training:
+             # Only add this in inference (to avoid unused param in activation checkpointing;
+             # it's mainly used in the demo to encode spatial memories w/ consolidated masks)
+             current_out["object_score_logits"] = object_score_logits
+
+         # Run the memory encoder on the predicted mask to encode it into a new memory feature (for use in future frames)
+         self._encode_memory_in_output(
+             current_vision_feats,
+             feat_sizes,
+             point_inputs,
+             run_mem_encoder,
+             high_res_masks,
+             object_score_logits,
+             current_out,
+         )
+
+         return current_out
+
+     def _use_multimask(self, is_init_cond_frame, point_inputs):
+         """Determine whether to use multiple mask outputs in the SAM head based on configuration and inputs."""
+         num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(1)
+         return (
+             self.multimask_output_in_sam
+             and (is_init_cond_frame or self.multimask_output_for_tracking)
+             and (self.multimask_min_pt_num <= num_pts <= self.multimask_max_pt_num)
+         )
+
+     @staticmethod
+     def _apply_non_overlapping_constraints(pred_masks):
+         """Apply non-overlapping constraints to masks, keeping the highest-scoring object per location."""
+         batch_size = pred_masks.shape[0]
+         if batch_size == 1:
+             return pred_masks
+
+         device = pred_masks.device
+         # "max_obj_inds": object index of the object with the highest score at each location
+         max_obj_inds = torch.argmax(pred_masks, dim=0, keepdim=True)
+         # "batch_obj_inds": object index of each object slice (along dim 0) in `pred_masks`
+         batch_obj_inds = torch.arange(batch_size, device=device)[:, None, None, None]
+         keep = max_obj_inds == batch_obj_inds
+         # suppress overlapping regions' scores below -10.0 so that the foreground regions
+         # don't overlap (here sigmoid(-10.0)=4.5398e-05)
+         pred_masks = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0))
+         return pred_masks
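To make the winner-takes-all behaviour of `_apply_non_overlapping_constraints` concrete, here is a standalone toy example that mirrors its logic on two overlapping one-row masks; the tensor values are made up for illustration.

# Toy example: only the higher-scoring object keeps its logit at each pixel,
# the other is clamped to at most -10.0 so its sigmoid score becomes negligible.
import torch

pred_masks = torch.tensor(
    [
        [[[2.0, 1.0, -5.0]]],   # object 0 logits, shape [1, 1, 3]
        [[[0.5, 3.0, 4.0]]],    # object 1 logits
    ]
)  # overall shape [num_objects=2, 1, 1, 3]

max_obj_inds = torch.argmax(pred_masks, dim=0, keepdim=True)              # winner per pixel
batch_obj_inds = torch.arange(pred_masks.shape[0])[:, None, None, None]   # each slice's own index
keep = max_obj_inds == batch_obj_inds
suppressed = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0))

print(suppressed[0, 0, 0])  # tensor([  2., -10., -10.]) -> object 0 only wins pixel 0
print(suppressed[1, 0, 0])  # tensor([-10.,   3.,   4.]) -> object 1 wins pixels 1 and 2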
+
+     def set_binarize(self, binarize=False):
+         """Set the binarize flag used by the VideoPredictor when encoding memories."""
+         self.binarize_mask_from_pts_for_mem_enc = binarize
+
+     def set_imgsz(self, imgsz):
+         """Set the image size so the model can run at a different input resolution."""
+         if hasattr(self.image_encoder, "set_imgsz"):
+             self.image_encoder.set_imgsz(imgsz)
+         self.image_size = imgsz[0]
+         self.sam_prompt_encoder.input_image_size = imgsz
+         self.sam_prompt_encoder.image_embedding_size = [
+             x // self.backbone_stride for x in imgsz
+         ]  # fixed ViT patch size of 16
+         self.sam_prompt_encoder.mask_input_size = [
+             x // self.backbone_stride * 4 for x in imgsz
+         ]  # mask input is 4x the image embedding size
+         self.sam_image_embedding_size = self.image_size // self.backbone_stride  # update image embedding size
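As a quick sanity check on the arithmetic in `set_imgsz` (illustrative values only): with the usual stride-16 backbone the prompt encoder's embedding grid is `imgsz // 16` and the mask input is four times that, while the stride-14, 1008-pixel defaults of `SAM3Model.__init__` below give a 72x72 embedding grid.

# Illustrative arithmetic only; the imgsz values are examples.
imgsz, backbone_stride = (1024, 1024), 16
image_embedding_size = [x // backbone_stride for x in imgsz]  # [64, 64]
mask_input_size = [x // backbone_stride * 4 for x in imgsz]   # [256, 256]
print(image_embedding_size, mask_input_size)

imgsz, backbone_stride = (1008, 1008), 14  # SAM3Model defaults (image_size=1008, backbone_stride=14)
print([x // backbone_stride for x in imgsz])  # [72, 72]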
+
+
+ class SAM3Model(SAM2Model):
+     """SAM3Model class for Segment Anything Model 3 with memory-based video object segmentation capabilities."""
+
+     def __init__(
+         self,
+         image_encoder,
+         memory_attention,
+         memory_encoder,
+         num_maskmem=7,
+         image_size=1008,
+         backbone_stride=14,
+         sigmoid_scale_for_mem_enc=1,
+         sigmoid_bias_for_mem_enc=0,
+         binarize_mask_from_pts_for_mem_enc=False,
+         use_mask_input_as_output_without_sam=False,
+         max_cond_frames_in_attn=-1,
+         directly_add_no_mem_embed=False,
+         use_high_res_features_in_sam=False,
+         multimask_output_in_sam=False,
+         multimask_min_pt_num=1,
+         multimask_max_pt_num=1,
+         multimask_output_for_tracking=False,
+         use_multimask_token_for_obj_ptr: bool = False,
+         iou_prediction_use_sigmoid=False,
+         memory_temporal_stride_for_eval=1,
+         non_overlap_masks_for_mem_enc=False,
+         use_obj_ptrs_in_encoder=False,
+         max_obj_ptrs_in_encoder=16,
+         add_tpos_enc_to_obj_ptrs=True,
+         proj_tpos_enc_in_obj_ptrs=False,
+         use_signed_tpos_enc_to_obj_ptrs=False,
+         only_obj_ptrs_in_the_past_for_eval=False,
+         pred_obj_scores: bool = False,
+         pred_obj_scores_mlp: bool = False,
+         fixed_no_obj_ptr: bool = False,
+         soft_no_obj_ptr: bool = False,
+         use_mlp_for_obj_ptr_proj: bool = False,
+         no_obj_embed_spatial: bool = False,
+         sam_mask_decoder_extra_args=None,
+         compile_image_encoder: bool = False,
+     ):
+         """Initialize SAM3Model with memory-based video object segmentation components and a SAM2-style mask decoder."""
+         super().__init__(
+             image_encoder,
+             memory_attention,
+             memory_encoder,
+             num_maskmem,
+             image_size,
+             backbone_stride,
+             sigmoid_scale_for_mem_enc,
+             sigmoid_bias_for_mem_enc,
+             binarize_mask_from_pts_for_mem_enc,
+             use_mask_input_as_output_without_sam,
+             max_cond_frames_in_attn,
+             directly_add_no_mem_embed,
+             use_high_res_features_in_sam,
+             multimask_output_in_sam,
+             multimask_min_pt_num,
+             multimask_max_pt_num,
+             multimask_output_for_tracking,
+             use_multimask_token_for_obj_ptr,
+             iou_prediction_use_sigmoid,
+             memory_temporal_stride_for_eval,
+             non_overlap_masks_for_mem_enc,
+             use_obj_ptrs_in_encoder,
+             max_obj_ptrs_in_encoder,
+             add_tpos_enc_to_obj_ptrs,
+             proj_tpos_enc_in_obj_ptrs,
+             use_signed_tpos_enc_to_obj_ptrs,
+             only_obj_ptrs_in_the_past_for_eval,
+             pred_obj_scores,
+             pred_obj_scores_mlp,
+             fixed_no_obj_ptr,
+             soft_no_obj_ptr,
+             use_mlp_for_obj_ptr_proj,
+             no_obj_embed_spatial,
+             sam_mask_decoder_extra_args,
+             compile_image_encoder,
+         )
+         self.sam_mask_decoder = SAM2MaskDecoder(
+             num_multimask_outputs=3,
+             transformer=TwoWayTransformer(
+                 depth=2,
+                 embedding_dim=self.sam_prompt_embed_dim,
+                 mlp_dim=2048,
+                 num_heads=8,
+             ),
+             transformer_dim=self.sam_prompt_embed_dim,
+             iou_head_depth=3,
+             iou_head_hidden_dim=256,
+             use_high_res_features=self.use_high_res_features_in_sam,
+             iou_prediction_use_sigmoid=self.iou_prediction_use_sigmoid,
+             pred_obj_scores=self.pred_obj_scores,
+             pred_obj_scores_mlp=self.pred_obj_scores_mlp,
+             use_multimask_token_for_obj_ptr=self.use_multimask_token_for_obj_ptr,
+             **(self.sam_mask_decoder_extra_args or {}),
+         )
+
+     def forward_image(self, img_batch: torch.Tensor):
+         """Process an image batch through the encoder to extract multi-level features for the SAM model."""
+         backbone_out = self.image_encoder.forward_image_sam2(img_batch)
+         if self.use_high_res_features_in_sam:
+             # precompute projected level 0 and level 1 features in the SAM decoder
+             # to avoid running them again on every SAM click
+             backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0(backbone_out["backbone_fpn"][0])
+             backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1(backbone_out["backbone_fpn"][1])
+         return backbone_out
+
+     def set_imgsz(self, imgsz: tuple[int, int]):
+         """Set the image size for the model and the memory encoder's mask downsampler."""
+         super().set_imgsz(imgsz)
+         self.memory_encoder.mask_downsampler.interpol_size = [size // 14 * 16 for size in imgsz]
+
+     @staticmethod
+     def _suppress_shrinked_masks(pred_masks, new_pred_masks, shrink_threshold=0.3):
+         """Suppress masks that shrink in area after applying pixelwise non-overlapping constraints."""
+         area_before = (pred_masks > 0).sum(dim=(-1, -2))
+         area_after = (new_pred_masks > 0).sum(dim=(-1, -2))
+         area_before = torch.clamp(area_before, min=1.0)
+         area_ratio = area_after / area_before
+         keep = area_ratio >= shrink_threshold
+         keep_mask = keep[..., None, None].expand_as(pred_masks)
+         pred_masks_after = torch.where(keep_mask, pred_masks, torch.clamp(pred_masks, max=-10.0))
+         return pred_masks_after
+
+     def _suppress_object_pw_area_shrinkage(self, pred_masks):
+         """Suppress masks that shrink in area after applying pixelwise non-overlapping constraints.
+
+         Note that the final output can still be overlapping.
+         """
+         # Apply the pixel-wise non-overlapping constraint based on mask scores
+         pixel_level_non_overlapping_masks = self._apply_non_overlapping_constraints(pred_masks)
+         # Fully suppress masks with high shrinkage (probably noisy) based on the pixel-wise non-overlapping constraints.
+         # NOTE: This can be a no-op if none of the masks shrink by a large factor.
+         pred_masks = self._suppress_shrinked_masks(pred_masks, pixel_level_non_overlapping_masks)
+         return pred_masks
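The two-stage suppression above first applies the pixel-wise non-overlapping constraint and then drops any object whose positive area shrank below `shrink_threshold` of its original size. The toy example below replicates both stages outside the class; the tensor values are invented for illustration and the 0.3 threshold mirrors the default shown above.

# Standalone toy example mirroring the two-stage suppression: the pixel-wise
# winner-takes-all step shrinks object 0 badly, so the shrinkage check then clamps
# object 0 entirely while object 1 keeps its original (still possibly overlapping) logits.
import torch

pred_masks = torch.tensor(
    [
        [[[1.0, 1.0, 1.0, 1.0]]],   # object 0: positive everywhere, but always the weaker score
        [[[2.0, 2.0, 2.0, -4.0]]],  # object 1: wins the first three pixels
    ]
)  # shape [num_objects=2, 1, 1, 4]

# Stage 1: pixel-wise non-overlapping constraint (same logic as _apply_non_overlapping_constraints)
max_obj_inds = torch.argmax(pred_masks, dim=0, keepdim=True)
keep = max_obj_inds == torch.arange(pred_masks.shape[0])[:, None, None, None]
pixelwise = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0))

# Stage 2: suppress objects whose positive area shrank below the threshold (default 0.3)
area_before = torch.clamp((pred_masks > 0).sum(dim=(-1, -2)).float(), min=1.0)
area_after = (pixelwise > 0).sum(dim=(-1, -2)).float()
keep_obj = (area_after / area_before) >= 0.3
final = torch.where(keep_obj[..., None, None].expand_as(pred_masks), pred_masks, torch.clamp(pred_masks, max=-10.0))

print(area_after / area_before)   # tensor([[0.2500], [1.0000]]) -> object 0 shrank to 25% of its area
print(final[0, 0, 0])             # tensor([-10., -10., -10., -10.]) -> object 0 fully suppressed
print(final[1, 0, 0])             # tensor([ 2.,  2.,  2., -4.]) -> object 1 unchanged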