dgenerate-ultralytics-headless 8.3.214__py3-none-any.whl → 8.4.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (249)
  1. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/METADATA +64 -74
  2. dgenerate_ultralytics_headless-8.4.7.dist-info/RECORD +311 -0
  3. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/WHEEL +1 -1
  4. tests/__init__.py +7 -9
  5. tests/conftest.py +8 -15
  6. tests/test_cli.py +1 -1
  7. tests/test_cuda.py +13 -10
  8. tests/test_engine.py +9 -9
  9. tests/test_exports.py +65 -13
  10. tests/test_integrations.py +13 -13
  11. tests/test_python.py +125 -69
  12. tests/test_solutions.py +161 -152
  13. ultralytics/__init__.py +1 -1
  14. ultralytics/cfg/__init__.py +86 -92
  15. ultralytics/cfg/datasets/Argoverse.yaml +7 -6
  16. ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
  17. ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
  18. ultralytics/cfg/datasets/ImageNet.yaml +1 -1
  19. ultralytics/cfg/datasets/TT100K.yaml +346 -0
  20. ultralytics/cfg/datasets/VOC.yaml +15 -16
  21. ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
  22. ultralytics/cfg/datasets/coco-pose.yaml +21 -0
  23. ultralytics/cfg/datasets/coco12-formats.yaml +101 -0
  24. ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
  25. ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
  26. ultralytics/cfg/datasets/dog-pose.yaml +28 -0
  27. ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
  28. ultralytics/cfg/datasets/dota8.yaml +2 -2
  29. ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
  30. ultralytics/cfg/datasets/kitti.yaml +27 -0
  31. ultralytics/cfg/datasets/lvis.yaml +5 -5
  32. ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
  33. ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
  34. ultralytics/cfg/datasets/xView.yaml +16 -16
  35. ultralytics/cfg/default.yaml +4 -2
  36. ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
  37. ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
  38. ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
  39. ultralytics/cfg/models/26/yolo26-cls.yaml +33 -0
  40. ultralytics/cfg/models/26/yolo26-obb.yaml +52 -0
  41. ultralytics/cfg/models/26/yolo26-p2.yaml +60 -0
  42. ultralytics/cfg/models/26/yolo26-p6.yaml +62 -0
  43. ultralytics/cfg/models/26/yolo26-pose.yaml +53 -0
  44. ultralytics/cfg/models/26/yolo26-seg.yaml +52 -0
  45. ultralytics/cfg/models/26/yolo26.yaml +52 -0
  46. ultralytics/cfg/models/26/yoloe-26-seg.yaml +53 -0
  47. ultralytics/cfg/models/26/yoloe-26.yaml +53 -0
  48. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
  49. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
  50. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
  51. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
  52. ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
  53. ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
  54. ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
  55. ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
  56. ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
  57. ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
  58. ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
  59. ultralytics/cfg/models/v6/yolov6.yaml +1 -1
  60. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
  61. ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
  62. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
  63. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
  64. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
  65. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
  66. ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
  67. ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
  68. ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
  69. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
  70. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
  71. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
  72. ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
  73. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
  74. ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
  75. ultralytics/data/__init__.py +4 -4
  76. ultralytics/data/annotator.py +5 -6
  77. ultralytics/data/augment.py +300 -475
  78. ultralytics/data/base.py +18 -26
  79. ultralytics/data/build.py +147 -25
  80. ultralytics/data/converter.py +108 -87
  81. ultralytics/data/dataset.py +47 -75
  82. ultralytics/data/loaders.py +42 -49
  83. ultralytics/data/split.py +5 -6
  84. ultralytics/data/split_dota.py +8 -15
  85. ultralytics/data/utils.py +36 -45
  86. ultralytics/engine/exporter.py +351 -263
  87. ultralytics/engine/model.py +186 -225
  88. ultralytics/engine/predictor.py +45 -54
  89. ultralytics/engine/results.py +198 -325
  90. ultralytics/engine/trainer.py +165 -106
  91. ultralytics/engine/tuner.py +41 -43
  92. ultralytics/engine/validator.py +55 -38
  93. ultralytics/hub/__init__.py +16 -19
  94. ultralytics/hub/auth.py +6 -12
  95. ultralytics/hub/google/__init__.py +7 -10
  96. ultralytics/hub/session.py +15 -25
  97. ultralytics/hub/utils.py +5 -8
  98. ultralytics/models/__init__.py +1 -1
  99. ultralytics/models/fastsam/__init__.py +1 -1
  100. ultralytics/models/fastsam/model.py +8 -10
  101. ultralytics/models/fastsam/predict.py +18 -30
  102. ultralytics/models/fastsam/utils.py +1 -2
  103. ultralytics/models/fastsam/val.py +5 -7
  104. ultralytics/models/nas/__init__.py +1 -1
  105. ultralytics/models/nas/model.py +5 -8
  106. ultralytics/models/nas/predict.py +7 -9
  107. ultralytics/models/nas/val.py +1 -2
  108. ultralytics/models/rtdetr/__init__.py +1 -1
  109. ultralytics/models/rtdetr/model.py +5 -8
  110. ultralytics/models/rtdetr/predict.py +15 -19
  111. ultralytics/models/rtdetr/train.py +10 -13
  112. ultralytics/models/rtdetr/val.py +21 -23
  113. ultralytics/models/sam/__init__.py +15 -2
  114. ultralytics/models/sam/amg.py +14 -20
  115. ultralytics/models/sam/build.py +26 -19
  116. ultralytics/models/sam/build_sam3.py +377 -0
  117. ultralytics/models/sam/model.py +29 -32
  118. ultralytics/models/sam/modules/blocks.py +83 -144
  119. ultralytics/models/sam/modules/decoders.py +19 -37
  120. ultralytics/models/sam/modules/encoders.py +44 -101
  121. ultralytics/models/sam/modules/memory_attention.py +16 -30
  122. ultralytics/models/sam/modules/sam.py +200 -73
  123. ultralytics/models/sam/modules/tiny_encoder.py +64 -83
  124. ultralytics/models/sam/modules/transformer.py +18 -28
  125. ultralytics/models/sam/modules/utils.py +174 -50
  126. ultralytics/models/sam/predict.py +2248 -350
  127. ultralytics/models/sam/sam3/__init__.py +3 -0
  128. ultralytics/models/sam/sam3/decoder.py +546 -0
  129. ultralytics/models/sam/sam3/encoder.py +529 -0
  130. ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
  131. ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
  132. ultralytics/models/sam/sam3/model_misc.py +199 -0
  133. ultralytics/models/sam/sam3/necks.py +129 -0
  134. ultralytics/models/sam/sam3/sam3_image.py +339 -0
  135. ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
  136. ultralytics/models/sam/sam3/vitdet.py +547 -0
  137. ultralytics/models/sam/sam3/vl_combiner.py +160 -0
  138. ultralytics/models/utils/loss.py +14 -26
  139. ultralytics/models/utils/ops.py +13 -17
  140. ultralytics/models/yolo/__init__.py +1 -1
  141. ultralytics/models/yolo/classify/predict.py +10 -13
  142. ultralytics/models/yolo/classify/train.py +12 -33
  143. ultralytics/models/yolo/classify/val.py +30 -29
  144. ultralytics/models/yolo/detect/predict.py +9 -12
  145. ultralytics/models/yolo/detect/train.py +17 -23
  146. ultralytics/models/yolo/detect/val.py +77 -59
  147. ultralytics/models/yolo/model.py +43 -60
  148. ultralytics/models/yolo/obb/predict.py +7 -16
  149. ultralytics/models/yolo/obb/train.py +14 -17
  150. ultralytics/models/yolo/obb/val.py +40 -37
  151. ultralytics/models/yolo/pose/__init__.py +1 -1
  152. ultralytics/models/yolo/pose/predict.py +7 -22
  153. ultralytics/models/yolo/pose/train.py +13 -16
  154. ultralytics/models/yolo/pose/val.py +39 -58
  155. ultralytics/models/yolo/segment/predict.py +17 -21
  156. ultralytics/models/yolo/segment/train.py +7 -10
  157. ultralytics/models/yolo/segment/val.py +95 -47
  158. ultralytics/models/yolo/world/train.py +8 -14
  159. ultralytics/models/yolo/world/train_world.py +11 -34
  160. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  161. ultralytics/models/yolo/yoloe/predict.py +16 -23
  162. ultralytics/models/yolo/yoloe/train.py +36 -44
  163. ultralytics/models/yolo/yoloe/train_seg.py +11 -11
  164. ultralytics/models/yolo/yoloe/val.py +15 -20
  165. ultralytics/nn/__init__.py +7 -7
  166. ultralytics/nn/autobackend.py +159 -85
  167. ultralytics/nn/modules/__init__.py +68 -60
  168. ultralytics/nn/modules/activation.py +4 -6
  169. ultralytics/nn/modules/block.py +260 -224
  170. ultralytics/nn/modules/conv.py +52 -97
  171. ultralytics/nn/modules/head.py +831 -299
  172. ultralytics/nn/modules/transformer.py +76 -88
  173. ultralytics/nn/modules/utils.py +16 -21
  174. ultralytics/nn/tasks.py +180 -195
  175. ultralytics/nn/text_model.py +45 -69
  176. ultralytics/optim/__init__.py +5 -0
  177. ultralytics/optim/muon.py +338 -0
  178. ultralytics/solutions/__init__.py +12 -12
  179. ultralytics/solutions/ai_gym.py +13 -19
  180. ultralytics/solutions/analytics.py +15 -16
  181. ultralytics/solutions/config.py +6 -7
  182. ultralytics/solutions/distance_calculation.py +10 -13
  183. ultralytics/solutions/heatmap.py +8 -14
  184. ultralytics/solutions/instance_segmentation.py +6 -9
  185. ultralytics/solutions/object_blurrer.py +7 -10
  186. ultralytics/solutions/object_counter.py +12 -19
  187. ultralytics/solutions/object_cropper.py +8 -14
  188. ultralytics/solutions/parking_management.py +34 -32
  189. ultralytics/solutions/queue_management.py +10 -12
  190. ultralytics/solutions/region_counter.py +9 -12
  191. ultralytics/solutions/security_alarm.py +15 -20
  192. ultralytics/solutions/similarity_search.py +10 -15
  193. ultralytics/solutions/solutions.py +77 -76
  194. ultralytics/solutions/speed_estimation.py +7 -10
  195. ultralytics/solutions/streamlit_inference.py +2 -4
  196. ultralytics/solutions/templates/similarity-search.html +7 -18
  197. ultralytics/solutions/trackzone.py +7 -10
  198. ultralytics/solutions/vision_eye.py +5 -8
  199. ultralytics/trackers/__init__.py +1 -1
  200. ultralytics/trackers/basetrack.py +3 -5
  201. ultralytics/trackers/bot_sort.py +10 -27
  202. ultralytics/trackers/byte_tracker.py +21 -37
  203. ultralytics/trackers/track.py +4 -7
  204. ultralytics/trackers/utils/gmc.py +11 -22
  205. ultralytics/trackers/utils/kalman_filter.py +37 -48
  206. ultralytics/trackers/utils/matching.py +12 -15
  207. ultralytics/utils/__init__.py +124 -124
  208. ultralytics/utils/autobatch.py +2 -4
  209. ultralytics/utils/autodevice.py +17 -18
  210. ultralytics/utils/benchmarks.py +57 -71
  211. ultralytics/utils/callbacks/base.py +8 -10
  212. ultralytics/utils/callbacks/clearml.py +5 -13
  213. ultralytics/utils/callbacks/comet.py +32 -46
  214. ultralytics/utils/callbacks/dvc.py +13 -18
  215. ultralytics/utils/callbacks/mlflow.py +4 -5
  216. ultralytics/utils/callbacks/neptune.py +7 -15
  217. ultralytics/utils/callbacks/platform.py +423 -38
  218. ultralytics/utils/callbacks/raytune.py +3 -4
  219. ultralytics/utils/callbacks/tensorboard.py +25 -31
  220. ultralytics/utils/callbacks/wb.py +16 -14
  221. ultralytics/utils/checks.py +127 -85
  222. ultralytics/utils/cpu.py +3 -8
  223. ultralytics/utils/dist.py +9 -12
  224. ultralytics/utils/downloads.py +25 -33
  225. ultralytics/utils/errors.py +6 -14
  226. ultralytics/utils/events.py +2 -4
  227. ultralytics/utils/export/__init__.py +4 -236
  228. ultralytics/utils/export/engine.py +246 -0
  229. ultralytics/utils/export/imx.py +117 -63
  230. ultralytics/utils/export/tensorflow.py +231 -0
  231. ultralytics/utils/files.py +26 -30
  232. ultralytics/utils/git.py +9 -11
  233. ultralytics/utils/instance.py +30 -51
  234. ultralytics/utils/logger.py +212 -114
  235. ultralytics/utils/loss.py +601 -215
  236. ultralytics/utils/metrics.py +128 -156
  237. ultralytics/utils/nms.py +13 -16
  238. ultralytics/utils/ops.py +117 -166
  239. ultralytics/utils/patches.py +75 -21
  240. ultralytics/utils/plotting.py +75 -80
  241. ultralytics/utils/tal.py +125 -59
  242. ultralytics/utils/torch_utils.py +53 -79
  243. ultralytics/utils/tqdm.py +24 -21
  244. ultralytics/utils/triton.py +13 -19
  245. ultralytics/utils/tuner.py +19 -10
  246. dgenerate_ultralytics_headless-8.3.214.dist-info/RECORD +0 -283
  247. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/entry_points.txt +0 -0
  248. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/licenses/LICENSE +0 -0
  249. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/top_level.txt +0 -0
ultralytics/models/sam/modules/sam.py
@@ -13,7 +13,7 @@ from torch.nn.init import trunc_normal_
 from ultralytics.nn.modules import MLP
 from ultralytics.utils import LOGGER
 
-from .blocks import SAM2TwoWayTransformer
+from .blocks import SAM2TwoWayTransformer, TwoWayTransformer
 from .decoders import MaskDecoder, SAM2MaskDecoder
 from .encoders import ImageEncoderViT, PromptEncoder
 from .utils import get_1d_sine_pe, select_closest_cond_frames
@@ -23,11 +23,10 @@ NO_OBJ_SCORE = -1024.0
 
 
 class SAMModel(nn.Module):
-    """
-    Segment Anything Model (SAM) for object segmentation tasks.
+    """Segment Anything Model (SAM) for object segmentation tasks.
 
-    This class combines image encoders, prompt encoders, and mask decoders to predict object masks from images
-    and input prompts.
+    This class combines image encoders, prompt encoders, and mask decoders to predict object masks from images and input
+    prompts.
 
     Attributes:
         mask_threshold (float): Threshold value for mask prediction.
@@ -61,8 +60,7 @@ class SAMModel(nn.Module):
         pixel_mean: list[float] = (123.675, 116.28, 103.53),
         pixel_std: list[float] = (58.395, 57.12, 57.375),
     ) -> None:
-        """
-        Initialize the SAMModel class to predict object masks from an image and input prompts.
+        """Initialize the SAMModel class to predict object masks from an image and input prompts.
 
         Args:
             image_encoder (ImageEncoderViT): The backbone used to encode the image into image embeddings.
@@ -71,13 +69,6 @@ class SAMModel(nn.Module):
             pixel_mean (list[float]): Mean values for normalizing pixels in the input image.
             pixel_std (list[float]): Standard deviation values for normalizing pixels in the input image.
 
-        Examples:
-            >>> image_encoder = ImageEncoderViT(...)
-            >>> prompt_encoder = PromptEncoder(...)
-            >>> mask_decoder = MaskDecoder(...)
-            >>> sam_model = SAMModel(image_encoder, prompt_encoder, mask_decoder)
-            >>> # Further usage depends on SAMPredictor class
-
         Notes:
             All forward() operations moved to SAMPredictor.
         """
@@ -98,11 +89,10 @@ class SAMModel(nn.Module):
 
 
 class SAM2Model(torch.nn.Module):
-    """
-    SAM2Model class for Segment Anything Model 2 with memory-based video object segmentation capabilities.
+    """SAM2Model class for Segment Anything Model 2 with memory-based video object segmentation capabilities.
 
-    This class extends the functionality of SAM to handle video sequences, incorporating memory mechanisms
-    for temporal consistency and efficient tracking of objects across frames.
+    This class extends the functionality of SAM to handle video sequences, incorporating memory mechanisms for temporal
+    consistency and efficient tracking of objects across frames.
 
     Attributes:
         mask_threshold (float): Threshold value for mask prediction.
@@ -136,24 +126,24 @@ class SAM2Model(torch.nn.Module):
         use_mlp_for_obj_ptr_proj (bool): Whether to use MLP for object pointer projection.
         no_obj_embed_spatial (torch.Tensor | None): No-object embedding for spatial frames.
         max_cond_frames_in_attn (int): Maximum number of conditioning frames to participate in memory attention.
-        directly_add_no_mem_embed (bool): Whether to directly add no-memory embedding to image feature on the
-            first frame.
-        multimask_output_in_sam (bool): Whether to output multiple masks for the first click on initial
-            conditioning frames.
+        directly_add_no_mem_embed (bool): Whether to directly add no-memory embedding to image feature on the first
+            frame.
+        multimask_output_in_sam (bool): Whether to output multiple masks for the first click on initial conditioning
+            frames.
         multimask_min_pt_num (int): Minimum number of clicks to use multimask output in SAM.
         multimask_max_pt_num (int): Maximum number of clicks to use multimask output in SAM.
        multimask_output_for_tracking (bool): Whether to use multimask output for tracking.
         use_multimask_token_for_obj_ptr (bool): Whether to use multimask tokens for object pointers.
         iou_prediction_use_sigmoid (bool): Whether to use sigmoid to restrict IoU prediction to [0-1].
         memory_temporal_stride_for_eval (int): Memory bank's temporal stride during evaluation.
-        non_overlap_masks_for_mem_enc (bool): Whether to apply non-overlapping constraints on object masks in
-            memory encoder during evaluation.
+        non_overlap_masks_for_mem_enc (bool): Whether to apply non-overlapping constraints on object masks in memory
+            encoder during evaluation.
         sigmoid_scale_for_mem_enc (float): Scale factor for mask sigmoid probability.
         sigmoid_bias_for_mem_enc (float): Bias factor for mask sigmoid probability.
-        binarize_mask_from_pts_for_mem_enc (bool): Whether to binarize sigmoid mask logits on interacted frames
-            with clicks during evaluation.
-        use_mask_input_as_output_without_sam (bool): Whether to directly output the input mask without using SAM
-            prompt encoder and mask decoder on frames with mask input.
+        binarize_mask_from_pts_for_mem_enc (bool): Whether to binarize sigmoid mask logits on interacted frames with
+            clicks during evaluation.
+        use_mask_input_as_output_without_sam (bool): Whether to directly output the input mask without using SAM prompt
+            encoder and mask decoder on frames with mask input.
 
     Methods:
         forward_image: Process image batch through encoder to extract multi-level features.
@@ -208,8 +198,7 @@ class SAM2Model(torch.nn.Module):
         sam_mask_decoder_extra_args=None,
         compile_image_encoder: bool = False,
     ):
-        """
-        Initialize the SAM2Model for video object segmentation with memory-based tracking.
+        """Initialize the SAM2Model for video object segmentation with memory-based tracking.
 
         Args:
            image_encoder (nn.Module): Visual encoder for extracting image features.
@@ -220,35 +209,35 @@ class SAM2Model(torch.nn.Module):
             backbone_stride (int): Stride of the image backbone output.
             sigmoid_scale_for_mem_enc (float): Scale factor for mask sigmoid probability.
             sigmoid_bias_for_mem_enc (float): Bias factor for mask sigmoid probability.
-            binarize_mask_from_pts_for_mem_enc (bool): Whether to binarize sigmoid mask logits on interacted frames
-                with clicks during evaluation.
+            binarize_mask_from_pts_for_mem_enc (bool): Whether to binarize sigmoid mask logits on interacted frames with
+                clicks during evaluation.
             use_mask_input_as_output_without_sam (bool): Whether to directly output the input mask without using SAM
                 prompt encoder and mask decoder on frames with mask input.
             max_cond_frames_in_attn (int): Maximum number of conditioning frames to participate in memory attention.
-            directly_add_no_mem_embed (bool): Whether to directly add no-memory embedding to image feature on the
-                first frame.
+            directly_add_no_mem_embed (bool): Whether to directly add no-memory embedding to image feature on the first
+                frame.
             use_high_res_features_in_sam (bool): Whether to use high-resolution feature maps in the SAM mask decoder.
-            multimask_output_in_sam (bool): Whether to output multiple masks for the first click on initial
-                conditioning frames.
+            multimask_output_in_sam (bool): Whether to output multiple masks for the first click on initial conditioning
+                frames.
             multimask_min_pt_num (int): Minimum number of clicks to use multimask output in SAM.
             multimask_max_pt_num (int): Maximum number of clicks to use multimask output in SAM.
             multimask_output_for_tracking (bool): Whether to use multimask output for tracking.
             use_multimask_token_for_obj_ptr (bool): Whether to use multimask tokens for object pointers.
             iou_prediction_use_sigmoid (bool): Whether to use sigmoid to restrict IoU prediction to [0-1].
             memory_temporal_stride_for_eval (int): Memory bank's temporal stride during evaluation.
-            non_overlap_masks_for_mem_enc (bool): Whether to apply non-overlapping constraints on object masks in
-                memory encoder during evaluation.
+            non_overlap_masks_for_mem_enc (bool): Whether to apply non-overlapping constraints on object masks in memory
+                encoder during evaluation.
             use_obj_ptrs_in_encoder (bool): Whether to cross-attend to object pointers from other frames in the encoder.
             max_obj_ptrs_in_encoder (int): Maximum number of object pointers from other frames in encoder
                 cross-attention.
-            add_tpos_enc_to_obj_ptrs (bool): Whether to add temporal positional encoding to object pointers in
-                the encoder.
+            add_tpos_enc_to_obj_ptrs (bool): Whether to add temporal positional encoding to object pointers in the
+                encoder.
             proj_tpos_enc_in_obj_ptrs (bool): Whether to add an extra linear projection layer for temporal positional
                 encoding in object pointers.
             use_signed_tpos_enc_to_obj_ptrs (bool): Whether to use signed distance in the temporal positional encoding
                 in the object pointers.
-            only_obj_ptrs_in_the_past_for_eval (bool): Whether to only attend to object pointers in the past
-                during evaluation.
+            only_obj_ptrs_in_the_past_for_eval (bool): Whether to only attend to object pointers in the past during
+                evaluation.
             pred_obj_scores (bool): Whether to predict if there is an object in the frame.
             pred_obj_scores_mlp (bool): Whether to use an MLP to predict object scores.
             fixed_no_obj_ptr (bool): Whether to have a fixed no-object pointer when there is no object present.
@@ -257,15 +246,6 @@ class SAM2Model(torch.nn.Module):
             no_obj_embed_spatial (bool): Whether add no obj embedding to spatial frames.
             sam_mask_decoder_extra_args (dict | None): Extra arguments for constructing the SAM mask decoder.
             compile_image_encoder (bool): Whether to compile the image encoder for faster inference.
-
-        Examples:
-            >>> image_encoder = ImageEncoderViT(...)
-            >>> memory_attention = SAM2TwoWayTransformer(...)
-            >>> memory_encoder = nn.Sequential(...)
-            >>> model = SAM2Model(image_encoder, memory_attention, memory_encoder)
-            >>> image_batch = torch.rand(1, 3, 512, 512)
-            >>> features = model.forward_image(image_batch)
-            >>> track_results = model.track_step(0, True, features, None, None, None, {})
         """
         super().__init__()
 
@@ -349,6 +329,7 @@ class SAM2Model(torch.nn.Module):
 
         self._build_sam_heads()
         self.max_cond_frames_in_attn = max_cond_frames_in_attn
+        self.add_all_frames_to_correct_as_cond = True
 
         # Model compilation
         if compile_image_encoder:
@@ -428,25 +409,23 @@ class SAM2Model(torch.nn.Module):
         high_res_features=None,
         multimask_output=False,
     ):
-        """
-        Forward pass through SAM prompt encoders and mask heads.
+        """Forward pass through SAM prompt encoders and mask heads.
 
         This method processes image features and optional point/mask inputs to generate object masks and scores.
 
         Args:
             backbone_features (torch.Tensor): Image features with shape (B, C, H, W).
             point_inputs (dict[str, torch.Tensor] | None): Dictionary containing point prompts.
-                'point_coords': Tensor of shape (B, P, 2) with float32 dtype, containing absolute
-                    pixel-unit coordinates in (x, y) format for P input points.
-                'point_labels': Tensor of shape (B, P) with int32 dtype, where 1 means positive clicks,
-                    0 means negative clicks, and -1 means padding.
-            mask_inputs (torch.Tensor | None): Mask of shape (B, 1, H*16, W*16), float or bool, with the
-                same spatial size as the image.
-            high_res_features (list[torch.Tensor] | None): List of two feature maps with shapes
-                (B, C, 4*H, 4*W) and (B, C, 2*H, 2*W) respectively, used as high-resolution feature maps
-                for SAM decoder.
-            multimask_output (bool): If True, output 3 candidate masks and their IoU estimates; if False,
-                output only 1 mask and its IoU estimate.
+                'point_coords': Tensor of shape (B, P, 2) with float32 dtype, containing absolute pixel-unit coordinates in
+                    (x, y) format for P input points.
+                'point_labels': Tensor of shape (B, P) with int32 dtype, where 1 means positive clicks, 0 means negative
+                    clicks, and -1 means padding.
+            mask_inputs (torch.Tensor | None): Mask of shape (B, 1, H*16, W*16), float or bool, with the same spatial
+                size as the image.
+            high_res_features (list[torch.Tensor] | None): List of two feature maps with shapes (B, C, 4*H, 4*W) and (B,
+                C, 2*H, 2*W) respectively, used as high-resolution feature maps for SAM decoder.
+            multimask_output (bool): If True, output 3 candidate masks and their IoU estimates; if False, output only 1
+                mask and its IoU estimate.
 
         Returns:
             low_res_multimasks (torch.Tensor): Tensor of shape (B, M, H*4, W*4) with SAM output mask logits.
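For reference, a minimal sketch (not part of this diff) of a point_inputs dictionary in the format the reworded docstring describes; the coordinate and label values below are arbitrary examples:

import torch

# One image (B=1) with two prompt points (P=2)
point_inputs = {
    # (B, P, 2) float32 absolute pixel (x, y) coordinates
    "point_coords": torch.tensor([[[120.0, 340.0], [400.0, 220.0]]], dtype=torch.float32),
    # (B, P) int32 labels: 1 = positive click, 0 = negative click, -1 = padding
    "point_labels": torch.tensor([[1, 0]], dtype=torch.int32),
}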
@@ -495,7 +474,7 @@ class SAM2Model(torch.nn.Module):
             assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1)
             if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size:
                 sam_mask_prompt = F.interpolate(
-                    mask_inputs.float(),
+                    mask_inputs.to(backbone_features.dtype),
                     size=self.sam_prompt_encoder.mask_input_size,
                     align_corners=False,
                     mode="bilinear",
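A hedged sketch (toy shapes, not the package's own call site) of what the mask_inputs.to(backbone_features.dtype) change does: the mask prompt is resized in whatever dtype the backbone features carry, float32 here and typically half precision under mixed-precision inference, instead of always being forced to float32:

import torch
import torch.nn.functional as F

backbone_features = torch.rand(1, 256, 64, 64)    # assume float32 here; may be float16 under AMP
mask_inputs = torch.rand(1, 1, 1024, 1024) > 0.5  # boolean mask prompt

sam_mask_prompt = F.interpolate(
    mask_inputs.to(backbone_features.dtype),  # cast the mask to the feature dtype before resizing
    size=(256, 256),                          # stand-in for sam_prompt_encoder.mask_input_size
    align_corners=False,
    mode="bilinear",
)
assert sam_mask_prompt.dtype == backbone_features.dtype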
@@ -593,7 +572,7 @@ class SAM2Model(torch.nn.Module):
         # produce an object pointer using the SAM decoder from the mask input
         _, _, _, _, _, obj_ptr, _ = self._forward_sam_heads(
             backbone_features=backbone_features,
-            mask_inputs=self.mask_downsample(mask_inputs_float),
+            mask_inputs=self.mask_downsample(mask_inputs_float.to(backbone_features.dtype)),
             high_res_features=high_res_features,
         )
         # In this method, we are treating mask_input as output, e.g. using it directly to create spatial mem;
@@ -628,8 +607,14 @@ class SAM2Model(torch.nn.Module):
             backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1(backbone_out["backbone_fpn"][1])
         return backbone_out
 
-    def _prepare_backbone_features(self, backbone_out):
+    def _prepare_backbone_features(self, backbone_out, batch=1):
         """Prepare and flatten visual features from the image backbone output for further processing."""
+        if batch > 1:  # expand features if there's more than one prompt
+            backbone_out = {
+                **backbone_out,
+                "backbone_fpn": [feat.expand(batch, -1, -1, -1) for feat in backbone_out["backbone_fpn"]],
+                "vision_pos_enc": [pos.expand(batch, -1, -1, -1) for pos in backbone_out["vision_pos_enc"]],
+            }
         assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"])
         assert len(backbone_out["backbone_fpn"]) >= self.num_feature_levels
 
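A small sketch of the broadcast behind the new batch argument above (arbitrary shapes assumed): Tensor.expand returns views of the single image's features for each prompt instead of copying them:

import torch

feat = torch.rand(1, 256, 64, 64)          # backbone feature map for one image
batch = 3                                  # hypothetical number of prompts
expanded = feat.expand(batch, -1, -1, -1)  # shape (3, 256, 64, 64); a view, no extra memory
assert expanded.shape == (3, 256, 64, 64)
assert expanded.data_ptr() == feat.data_ptr()  # same underlying storage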
@@ -640,7 +625,6 @@ class SAM2Model(torch.nn.Module):
         # flatten NxCxHxW to HWxNxC
         vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps]
         vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds]
-
         return backbone_out, vision_feats, vision_pos_embeds, feat_sizes
 
     def _prepare_memory_conditioned_features(
@@ -803,7 +787,7 @@ class SAM2Model(torch.nn.Module):
             memory_pos=memory_pos_embed,
             num_obj_ptr_tokens=num_obj_ptr_tokens,
         )
-        # reshape the output (HW)BC => BCHW
+        # Reshape output (HW)BC => BCHW
         pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
         return pix_feat_with_mem
 
@@ -840,7 +824,6 @@ class SAM2Model(torch.nn.Module):
             mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc
         maskmem_out = self.memory_encoder(pix_feat, mask_for_mem, skip_mask_sigmoid=True)  # sigmoid already applied
         maskmem_features = maskmem_out["vision_features"]
-        maskmem_pos_enc = maskmem_out["vision_pos_enc"]
         # add a no-object embedding to the spatial memory to indicate that the frame
         # is predicted to be occluded (i.e. no object is appearing in the frame)
         if self.no_obj_embed_spatial is not None:
@@ -849,7 +832,7 @@ class SAM2Model(torch.nn.Module):
                 ..., None, None
             ].expand(*maskmem_features.shape)
 
-        return maskmem_features, maskmem_pos_enc
+        return maskmem_features, maskmem_out["vision_pos_enc"]
 
     def _track_step(
         self,
@@ -881,7 +864,7 @@ class SAM2Model(torch.nn.Module):
             pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])
             sam_outputs = self._use_mask_as_output(mask_inputs, pix_feat, high_res_features)
         else:
-            # fused the visual feature with previous memory features in the memory bank
+            # Fuse visual features with previous memory features in the memory bank
             pix_feat = self._prepare_memory_conditioned_features(
                 frame_idx=frame_idx,
                 is_init_cond_frame=is_init_cond_frame,
@@ -1027,7 +1010,151 @@ class SAM2Model(torch.nn.Module):
 
     def set_imgsz(self, imgsz):
         """Set image size to make model compatible with different image sizes."""
+        if hasattr(self.image_encoder, "set_imgsz"):
+            self.image_encoder.set_imgsz(imgsz)
         self.image_size = imgsz[0]
         self.sam_prompt_encoder.input_image_size = imgsz
-        self.sam_prompt_encoder.image_embedding_size = [x // 16 for x in imgsz]  # fixed ViT patch size of 16
+        self.sam_prompt_encoder.image_embedding_size = [
+            x // self.backbone_stride for x in imgsz
+        ]  # fixed ViT patch size of 16
+        self.sam_prompt_encoder.mask_input_size = [
+            x // self.backbone_stride * 4 for x in imgsz
+        ]  # fixed ViT patch size of 16
         self.sam_image_embedding_size = self.image_size // self.backbone_stride  # update image embedding size
+
+
+class SAM3Model(SAM2Model):
+    """SAM3Model class for Segment Anything Model 3 with memory-based video object segmentation capabilities."""
+
+    def __init__(
+        self,
+        image_encoder,
+        memory_attention,
+        memory_encoder,
+        num_maskmem=7,
+        image_size=1008,
+        backbone_stride=14,
+        sigmoid_scale_for_mem_enc=1,
+        sigmoid_bias_for_mem_enc=0,
+        binarize_mask_from_pts_for_mem_enc=False,
+        use_mask_input_as_output_without_sam=False,
+        max_cond_frames_in_attn=-1,
+        directly_add_no_mem_embed=False,
+        use_high_res_features_in_sam=False,
+        multimask_output_in_sam=False,
+        multimask_min_pt_num=1,
+        multimask_max_pt_num=1,
+        multimask_output_for_tracking=False,
+        use_multimask_token_for_obj_ptr: bool = False,
+        iou_prediction_use_sigmoid=False,
+        memory_temporal_stride_for_eval=1,
+        non_overlap_masks_for_mem_enc=False,
+        use_obj_ptrs_in_encoder=False,
+        max_obj_ptrs_in_encoder=16,
+        add_tpos_enc_to_obj_ptrs=True,
+        proj_tpos_enc_in_obj_ptrs=False,
+        use_signed_tpos_enc_to_obj_ptrs=False,
+        only_obj_ptrs_in_the_past_for_eval=False,
+        pred_obj_scores: bool = False,
+        pred_obj_scores_mlp: bool = False,
+        fixed_no_obj_ptr: bool = False,
+        soft_no_obj_ptr: bool = False,
+        use_mlp_for_obj_ptr_proj: bool = False,
+        no_obj_embed_spatial: bool = False,
+        sam_mask_decoder_extra_args=None,
+        compile_image_encoder: bool = False,
+    ):
+        """SAM3Model class for Segment Anything Model 3 with memory-based video object segmentation capabilities."""
+        super().__init__(
+            image_encoder,
+            memory_attention,
+            memory_encoder,
+            num_maskmem,
+            image_size,
+            backbone_stride,
+            sigmoid_scale_for_mem_enc,
+            sigmoid_bias_for_mem_enc,
+            binarize_mask_from_pts_for_mem_enc,
+            use_mask_input_as_output_without_sam,
+            max_cond_frames_in_attn,
+            directly_add_no_mem_embed,
+            use_high_res_features_in_sam,
+            multimask_output_in_sam,
+            multimask_min_pt_num,
+            multimask_max_pt_num,
+            multimask_output_for_tracking,
+            use_multimask_token_for_obj_ptr,
+            iou_prediction_use_sigmoid,
+            memory_temporal_stride_for_eval,
+            non_overlap_masks_for_mem_enc,
+            use_obj_ptrs_in_encoder,
+            max_obj_ptrs_in_encoder,
+            add_tpos_enc_to_obj_ptrs,
+            proj_tpos_enc_in_obj_ptrs,
+            use_signed_tpos_enc_to_obj_ptrs,
+            only_obj_ptrs_in_the_past_for_eval,
+            pred_obj_scores,
+            pred_obj_scores_mlp,
+            fixed_no_obj_ptr,
+            soft_no_obj_ptr,
+            use_mlp_for_obj_ptr_proj,
+            no_obj_embed_spatial,
+            sam_mask_decoder_extra_args,
+            compile_image_encoder,
+        )
+        self.sam_mask_decoder = SAM2MaskDecoder(
+            num_multimask_outputs=3,
+            transformer=TwoWayTransformer(
+                depth=2,
+                embedding_dim=self.sam_prompt_embed_dim,
+                mlp_dim=2048,
+                num_heads=8,
+            ),
+            transformer_dim=self.sam_prompt_embed_dim,
+            iou_head_depth=3,
+            iou_head_hidden_dim=256,
+            use_high_res_features=self.use_high_res_features_in_sam,
+            iou_prediction_use_sigmoid=self.iou_prediction_use_sigmoid,
+            pred_obj_scores=self.pred_obj_scores,
+            pred_obj_scores_mlp=self.pred_obj_scores_mlp,
+            use_multimask_token_for_obj_ptr=self.use_multimask_token_for_obj_ptr,
+            **(self.sam_mask_decoder_extra_args or {}),
+        )
+
+    def forward_image(self, img_batch: torch.Tensor):
+        """Process image batch through encoder to extract multi-level features for SAM model."""
+        backbone_out = self.image_encoder.forward_image_sam2(img_batch)
+        if self.use_high_res_features_in_sam:
+            # precompute projected level 0 and level 1 features in SAM decoder
+            # to avoid running it again on every SAM click
+            backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0(backbone_out["backbone_fpn"][0])
+            backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1(backbone_out["backbone_fpn"][1])
+        return backbone_out
+
+    def set_imgsz(self, imgsz: tuple[int, int]):
+        """Set the image size for the model and mask downsampler."""
+        super().set_imgsz(imgsz)
+        self.memory_encoder.mask_downsampler.interpol_size = [size // 14 * 16 for size in imgsz]
+
+    @staticmethod
+    def _suppress_shrinked_masks(pred_masks, new_pred_masks, shrink_threshold=0.3):
+        """Suppress masks that shrink in area after applying pixelwise non-overlapping constraints."""
+        area_before = (pred_masks > 0).sum(dim=(-1, -2))
+        area_after = (new_pred_masks > 0).sum(dim=(-1, -2))
+        area_before = torch.clamp(area_before, min=1.0)
+        area_ratio = area_after / area_before
+        keep = area_ratio >= shrink_threshold
+        keep_mask = keep[..., None, None].expand_as(pred_masks)
+        pred_masks_after = torch.where(keep_mask, pred_masks, torch.clamp(pred_masks, max=-10.0))
+        return pred_masks_after
+
+    def _suppress_object_pw_area_shrinkage(self, pred_masks):
+        """This function suppresses masks that shrink in area after applying pixelwise non-overlapping constraints. Note
+        that the final output can still be overlapping.
+        """
+        # Apply pixel-wise non-overlapping constraint based on mask scores
+        pixel_level_non_overlapping_masks = self._apply_non_overlapping_constraints(pred_masks)
+        # Fully suppress masks with high shrinkage (probably noisy) based on the pixel wise non-overlapping constraints
+        # NOTE: The output of this function can be a no op if none of the masks shrink by a large factor.
+        pred_masks = self._suppress_shrinked_masks(pred_masks, pixel_level_non_overlapping_masks)
+        return pred_masks
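To illustrate the new _suppress_shrinked_masks logic, a self-contained sketch with toy tensors (a per-pixel winner-takes-all stands in for _apply_non_overlapping_constraints): an object whose positive area drops below 30% of its original size after the constraint is pushed down to strongly negative logits:

import torch

pred_masks = torch.full((2, 1, 4, 4), -5.0)
pred_masks[0, 0] = 2.0        # object 0 claims every pixel
pred_masks[1, 0, 0, 0] = 1.0  # object 1 claims one pixel, which object 0 also claims

# Stand-in for _apply_non_overlapping_constraints: keep only the top-scoring object per pixel
keep_max = pred_masks == pred_masks.max(dim=0, keepdim=True).values
new_pred_masks = torch.where(keep_max, pred_masks, torch.clamp(pred_masks, max=-10.0))

# Same arithmetic as _suppress_shrinked_masks with shrink_threshold=0.3
area_before = (pred_masks > 0).float().sum(dim=(-1, -2)).clamp(min=1.0)
area_after = (new_pred_masks > 0).float().sum(dim=(-1, -2))
keep = (area_after / area_before) >= 0.3
print(keep.squeeze())  # tensor([ True, False]) -> object 1 is suppressed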