frontveg-0.1.dev1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. frontveg/__init__.py +11 -0
  2. frontveg/_tests/__init__.py +0 -0
  3. frontveg/_tests/test_widget.py +66 -0
  4. frontveg/_version.py +21 -0
  5. frontveg/_widget.py +132 -0
  6. frontveg/napari.yaml +14 -0
  7. frontveg/utils.py +95 -0
  8. frontveg-0.1.dev1.dist-info/METADATA +143 -0
  9. frontveg-0.1.dev1.dist-info/RECORD +44 -0
  10. frontveg-0.1.dev1.dist-info/WHEEL +5 -0
  11. frontveg-0.1.dev1.dist-info/entry_points.txt +2 -0
  12. frontveg-0.1.dev1.dist-info/licenses/LICENSE +28 -0
  13. frontveg-0.1.dev1.dist-info/top_level.txt +2 -0
  14. sam2/__init__.py +11 -0
  15. sam2/automatic_mask_generator.py +454 -0
  16. sam2/build_sam.py +167 -0
  17. sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
  18. sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
  19. sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
  20. sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
  21. sam2/modeling/__init__.py +5 -0
  22. sam2/modeling/backbones/__init__.py +5 -0
  23. sam2/modeling/backbones/hieradet.py +317 -0
  24. sam2/modeling/backbones/image_encoder.py +134 -0
  25. sam2/modeling/backbones/utils.py +95 -0
  26. sam2/modeling/memory_attention.py +169 -0
  27. sam2/modeling/memory_encoder.py +181 -0
  28. sam2/modeling/position_encoding.py +221 -0
  29. sam2/modeling/sam/__init__.py +5 -0
  30. sam2/modeling/sam/mask_decoder.py +295 -0
  31. sam2/modeling/sam/prompt_encoder.py +182 -0
  32. sam2/modeling/sam/transformer.py +360 -0
  33. sam2/modeling/sam2_base.py +907 -0
  34. sam2/modeling/sam2_utils.py +323 -0
  35. sam2/sam2_hiera_b+.yaml +1 -0
  36. sam2/sam2_hiera_l.yaml +1 -0
  37. sam2/sam2_hiera_s.yaml +1 -0
  38. sam2/sam2_hiera_t.yaml +1 -0
  39. sam2/sam2_image_predictor.py +466 -0
  40. sam2/sam2_video_predictor.py +1172 -0
  41. sam2/utils/__init__.py +5 -0
  42. sam2/utils/amg.py +348 -0
  43. sam2/utils/misc.py +349 -0
  44. sam2/utils/transforms.py +118 -0
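sam2/modeling/sam2_base.py (new file, +907 lines)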
@@ -0,0 +1,907 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.distributed
9
+ import torch.nn.functional as F
10
+
11
+ from torch.nn.init import trunc_normal_
12
+
13
+ from sam2.modeling.sam.mask_decoder import MaskDecoder
14
+ from sam2.modeling.sam.prompt_encoder import PromptEncoder
15
+ from sam2.modeling.sam.transformer import TwoWayTransformer
16
+ from sam2.modeling.sam2_utils import get_1d_sine_pe, MLP, select_closest_cond_frames
17
+
18
+ # a large negative value as a placeholder score for missing objects
19
+ NO_OBJ_SCORE = -1024.0
20
+
21
+
22
+ class SAM2Base(torch.nn.Module):
23
+ def __init__(
24
+ self,
25
+ image_encoder,
26
+ memory_attention,
27
+ memory_encoder,
28
+ num_maskmem=7, # default 1 input frame + 6 previous frames
29
+ image_size=512,
30
+ backbone_stride=16, # stride of the image backbone output
31
+ sigmoid_scale_for_mem_enc=1.0, # scale factor for mask sigmoid prob
32
+ sigmoid_bias_for_mem_enc=0.0, # bias factor for mask sigmoid prob
33
+ # During evaluation, whether to binarize the sigmoid mask logits on interacted frames with clicks
34
+ binarize_mask_from_pts_for_mem_enc=False,
35
+ use_mask_input_as_output_without_sam=False, # on frames with mask input, whether to directly output the input mask without using a SAM prompt encoder + mask decoder
36
+ # The maximum number of conditioning frames to participate in the memory attention (-1 means no limit; if there are more conditioning frames than this limit,
37
+ # we only cross-attend to the temporally closest `max_cond_frames_in_attn` conditioning frames in the encoder when tracking each frame). This gives the model
38
+ # a temporal locality when handling a large number of annotated frames (since closer frames should be more important) and also avoids GPU OOM.
39
+ max_cond_frames_in_attn=-1,
40
+ # on the first frame, whether to directly add the no-memory embedding to the image feature
41
+ # (instead of using the transformer encoder)
42
+ directly_add_no_mem_embed=False,
43
+ # whether to use high-resolution feature maps in the SAM mask decoder
44
+ use_high_res_features_in_sam=False,
45
+ # whether to output multiple (3) masks for the first click on initial conditioning frames
46
+ multimask_output_in_sam=False,
47
+ # the minimum and maximum number of clicks to use multimask_output_in_sam (only relevant when `multimask_output_in_sam=True`;
48
+ # default is 1 for both, meaning that only the first click gives multimask output; also note that a box counts as two points)
49
+ multimask_min_pt_num=1,
50
+ multimask_max_pt_num=1,
51
+ # whether to also use multimask output for tracking (not just for the first click on initial conditioning frames; only relevant when `multimask_output_in_sam=True`)
52
+ multimask_output_for_tracking=False,
53
+ # Whether to use multimask tokens for obj ptr; Only relevant when both
54
+ # use_obj_ptrs_in_encoder=True and multimask_output_for_tracking=True
55
+ use_multimask_token_for_obj_ptr: bool = False,
56
+ # whether to use sigmoid to restrict IoU predictions to [0, 1]
57
+ iou_prediction_use_sigmoid=False,
58
+ # The memory bank's temporal stride during evaluation (i.e. the `r` parameter in XMem and Cutie; XMem and Cutie use r=5).
59
+ # For r>1, the (self.num_maskmem - 1) non-conditioning memory frames consist of
60
+ # (self.num_maskmem - 2) nearest frames from every r-th frames, plus the last frame.
61
+ memory_temporal_stride_for_eval=1,
62
+ # whether to apply non-overlapping constraints on the object masks in the memory encoder during evaluation (to avoid/alleviate superposing masks)
63
+ non_overlap_masks_for_mem_enc=False,
64
+ # whether to cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
65
+ use_obj_ptrs_in_encoder=False,
66
+ # the maximum number of object pointers from other frames in encoder cross attention (only relevant when `use_obj_ptrs_in_encoder=True`)
67
+ max_obj_ptrs_in_encoder=16,
68
+ # whether to add temporal positional encoding to the object pointers in the encoder (only relevant when `use_obj_ptrs_in_encoder=True`)
69
+ add_tpos_enc_to_obj_ptrs=True,
70
+ # whether to add an extra linear projection layer for the temporal positional encoding in the object pointers to avoid potential interference
71
+ # with spatial positional encoding (only relevant when both `use_obj_ptrs_in_encoder=True` and `add_tpos_enc_to_obj_ptrs=True`)
72
+ proj_tpos_enc_in_obj_ptrs=False,
73
+ # whether to use signed distance (instead of unsigned absolute distance) in the temporal positional encoding in the object pointers
74
+ # (only relevant when both `use_obj_ptrs_in_encoder=True` and `add_tpos_enc_to_obj_ptrs=True`)
75
+ use_signed_tpos_enc_to_obj_ptrs=False,
76
+ # whether to only attend to object pointers in the past (before the current frame) in the encoder during evaluation
77
+ # (only relevant when `use_obj_ptrs_in_encoder=True`; this might avoid pointer information too far in the future to distract the initial tracking)
78
+ only_obj_ptrs_in_the_past_for_eval=False,
79
+ # Whether to predict if there is an object in the frame
80
+ pred_obj_scores: bool = False,
81
+ # Whether to use an MLP to predict object scores
82
+ pred_obj_scores_mlp: bool = False,
83
+ # Only relevant if pred_obj_scores=True and use_obj_ptrs_in_encoder=True;
84
+ # Whether to have a fixed no obj pointer when there is no object present
85
+ # or to use it as an additive embedding with obj_ptr produced by decoder
86
+ fixed_no_obj_ptr: bool = False,
87
+ # Soft no object, i.e. mix in no_obj_ptr softly,
88
+ # hoping to make recovery easier if there is a mistake and to mitigate accumulation of errors
89
+ soft_no_obj_ptr: bool = False,
90
+ use_mlp_for_obj_ptr_proj: bool = False,
91
+ # add no obj embedding to spatial frames
92
+ no_obj_embed_spatial: bool = False,
93
+ # extra arguments used to construct the SAM mask decoder; if not None, it should be a dict of kwargs to be passed into `MaskDecoder` class.
94
+ sam_mask_decoder_extra_args=None,
95
+ compile_image_encoder: bool = False,
96
+ ):
97
+ super().__init__()
98
+
99
+ # Part 1: the image backbone
100
+ self.image_encoder = image_encoder
101
+ # Use level 0, 1, 2 for high-res setting, or just level 2 for the default setting
102
+ self.use_high_res_features_in_sam = use_high_res_features_in_sam
103
+ self.num_feature_levels = 3 if use_high_res_features_in_sam else 1
104
+ self.use_obj_ptrs_in_encoder = use_obj_ptrs_in_encoder
105
+ self.max_obj_ptrs_in_encoder = max_obj_ptrs_in_encoder
106
+ if use_obj_ptrs_in_encoder:
107
+ # A conv layer to downsample the mask prompt to stride 4 (the same stride as
108
+ # low-res SAM mask logits) and to change its scales from 0~1 to SAM logit scale,
109
+ # so that it can be fed into the SAM mask decoder to generate a pointer.
110
+ self.mask_downsample = torch.nn.Conv2d(1, 1, kernel_size=4, stride=4)
111
+ self.add_tpos_enc_to_obj_ptrs = add_tpos_enc_to_obj_ptrs
112
+ if proj_tpos_enc_in_obj_ptrs:
113
+ assert add_tpos_enc_to_obj_ptrs # these options need to be used together
114
+ self.proj_tpos_enc_in_obj_ptrs = proj_tpos_enc_in_obj_ptrs
115
+ self.use_signed_tpos_enc_to_obj_ptrs = use_signed_tpos_enc_to_obj_ptrs
116
+ self.only_obj_ptrs_in_the_past_for_eval = only_obj_ptrs_in_the_past_for_eval
117
+
118
+ # Part 2: memory attention to condition current frame's visual features
119
+ # with memories (and obj ptrs) from past frames
120
+ self.memory_attention = memory_attention
121
+ self.hidden_dim = image_encoder.neck.d_model
122
+
123
+ # Part 3: memory encoder for the previous frame's outputs
124
+ self.memory_encoder = memory_encoder
125
+ self.mem_dim = self.hidden_dim
126
+ if hasattr(self.memory_encoder, "out_proj") and hasattr(
127
+ self.memory_encoder.out_proj, "weight"
128
+ ):
129
+ # if there is compression of memories along channel dim
130
+ self.mem_dim = self.memory_encoder.out_proj.weight.shape[0]
131
+ self.num_maskmem = num_maskmem # Number of memories accessible
132
+ # Temporal encoding of the memories
133
+ self.maskmem_tpos_enc = torch.nn.Parameter(
134
+ torch.zeros(num_maskmem, 1, 1, self.mem_dim)
135
+ )
136
+ trunc_normal_(self.maskmem_tpos_enc, std=0.02)
137
+ # a single token to indicate no memory embedding from previous frames
138
+ self.no_mem_embed = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
139
+ self.no_mem_pos_enc = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
140
+ trunc_normal_(self.no_mem_embed, std=0.02)
141
+ trunc_normal_(self.no_mem_pos_enc, std=0.02)
142
+ self.directly_add_no_mem_embed = directly_add_no_mem_embed
143
+ # Apply sigmoid to the output raw mask logits (to turn them from
144
+ # range (-inf, +inf) to range (0, 1)) before feeding them into the memory encoder
145
+ self.sigmoid_scale_for_mem_enc = sigmoid_scale_for_mem_enc
146
+ self.sigmoid_bias_for_mem_enc = sigmoid_bias_for_mem_enc
147
+ self.binarize_mask_from_pts_for_mem_enc = binarize_mask_from_pts_for_mem_enc
148
+ self.non_overlap_masks_for_mem_enc = non_overlap_masks_for_mem_enc
149
+ self.memory_temporal_stride_for_eval = memory_temporal_stride_for_eval
150
+ # On frames with mask input, whether to directly output the input mask without
151
+ # using a SAM prompt encoder + mask decoder
152
+ self.use_mask_input_as_output_without_sam = use_mask_input_as_output_without_sam
153
+ self.multimask_output_in_sam = multimask_output_in_sam
154
+ self.multimask_min_pt_num = multimask_min_pt_num
155
+ self.multimask_max_pt_num = multimask_max_pt_num
156
+ self.multimask_output_for_tracking = multimask_output_for_tracking
157
+ self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr
158
+ self.iou_prediction_use_sigmoid = iou_prediction_use_sigmoid
159
+
160
+ # Part 4: SAM-style prompt encoder (for both mask and point inputs)
161
+ # and SAM-style mask decoder for the final mask output
162
+ self.image_size = image_size
163
+ self.backbone_stride = backbone_stride
164
+ self.sam_mask_decoder_extra_args = sam_mask_decoder_extra_args
165
+ self.pred_obj_scores = pred_obj_scores
166
+ self.pred_obj_scores_mlp = pred_obj_scores_mlp
167
+ self.fixed_no_obj_ptr = fixed_no_obj_ptr
168
+ self.soft_no_obj_ptr = soft_no_obj_ptr
169
+ if self.fixed_no_obj_ptr:
170
+ assert self.pred_obj_scores
171
+ assert self.use_obj_ptrs_in_encoder
172
+ if self.pred_obj_scores and self.use_obj_ptrs_in_encoder:
173
+ self.no_obj_ptr = torch.nn.Parameter(torch.zeros(1, self.hidden_dim))
174
+ trunc_normal_(self.no_obj_ptr, std=0.02)
175
+ self.use_mlp_for_obj_ptr_proj = use_mlp_for_obj_ptr_proj
176
+ self.no_obj_embed_spatial = None
177
+ if no_obj_embed_spatial:
178
+ self.no_obj_embed_spatial = torch.nn.Parameter(torch.zeros(1, self.mem_dim))
179
+ trunc_normal_(self.no_obj_embed_spatial, std=0.02)
180
+
181
+ self._build_sam_heads()
182
+ self.max_cond_frames_in_attn = max_cond_frames_in_attn
183
+
184
+ # Model compilation
185
+ if compile_image_encoder:
186
+ # Compile the forward function (not the full module) to allow loading checkpoints.
187
+ print(
188
+ "Image encoder compilation is enabled. First forward pass will be slow."
189
+ )
190
+ self.image_encoder.forward = torch.compile(
191
+ self.image_encoder.forward,
192
+ mode="max-autotune",
193
+ fullgraph=True,
194
+ dynamic=False,
195
+ )
196
+
197
+ @property
198
+ def device(self):
199
+ return next(self.parameters()).device
200
+
201
+ def forward(self, *args, **kwargs):
202
+ raise NotImplementedError(
203
+ "Please use the corresponding methods in SAM2VideoPredictor for inference or SAM2Train for training/fine-tuning"
204
+ "See notebooks/video_predictor_example.ipynb for an inference example."
205
+ )
206
+
207
+ def _build_sam_heads(self):
208
+ """Build SAM-style prompt encoder and mask decoder."""
209
+ self.sam_prompt_embed_dim = self.hidden_dim
210
+ self.sam_image_embedding_size = self.image_size // self.backbone_stride
211
+
212
+ # build PromptEncoder and MaskDecoder from SAM
213
+ # (their hyperparameters like `mask_in_chans=16` are from SAM code)
214
+ self.sam_prompt_encoder = PromptEncoder(
215
+ embed_dim=self.sam_prompt_embed_dim,
216
+ image_embedding_size=(
217
+ self.sam_image_embedding_size,
218
+ self.sam_image_embedding_size,
219
+ ),
220
+ input_image_size=(self.image_size, self.image_size),
221
+ mask_in_chans=16,
222
+ )
223
+ self.sam_mask_decoder = MaskDecoder(
224
+ num_multimask_outputs=3,
225
+ transformer=TwoWayTransformer(
226
+ depth=2,
227
+ embedding_dim=self.sam_prompt_embed_dim,
228
+ mlp_dim=2048,
229
+ num_heads=8,
230
+ ),
231
+ transformer_dim=self.sam_prompt_embed_dim,
232
+ iou_head_depth=3,
233
+ iou_head_hidden_dim=256,
234
+ use_high_res_features=self.use_high_res_features_in_sam,
235
+ iou_prediction_use_sigmoid=self.iou_prediction_use_sigmoid,
236
+ pred_obj_scores=self.pred_obj_scores,
237
+ pred_obj_scores_mlp=self.pred_obj_scores_mlp,
238
+ use_multimask_token_for_obj_ptr=self.use_multimask_token_for_obj_ptr,
239
+ **(self.sam_mask_decoder_extra_args or {}),
240
+ )
241
+ if self.use_obj_ptrs_in_encoder:
242
+ # a linear projection on SAM output tokens to turn them into object pointers
243
+ self.obj_ptr_proj = torch.nn.Linear(self.hidden_dim, self.hidden_dim)
244
+ if self.use_mlp_for_obj_ptr_proj:
245
+ self.obj_ptr_proj = MLP(
246
+ self.hidden_dim, self.hidden_dim, self.hidden_dim, 3
247
+ )
248
+ else:
249
+ self.obj_ptr_proj = torch.nn.Identity()
250
+ if self.proj_tpos_enc_in_obj_ptrs:
251
+ # a linear projection on temporal positional encoding in object pointers to
252
+ # avoid potential interference with spatial positional encoding
253
+ self.obj_ptr_tpos_proj = torch.nn.Linear(self.hidden_dim, self.mem_dim)
254
+ else:
255
+ self.obj_ptr_tpos_proj = torch.nn.Identity()
256
+
257
+ def _forward_sam_heads(
258
+ self,
259
+ backbone_features,
260
+ point_inputs=None,
261
+ mask_inputs=None,
262
+ high_res_features=None,
263
+ multimask_output=False,
264
+ ):
265
+ """
266
+ Forward SAM prompt encoders and mask heads.
267
+
268
+ Inputs:
269
+ - backbone_features: image features of [B, C, H, W] shape
270
+ - point_inputs: a dictionary with "point_coords" and "point_labels", where
271
+ 1) "point_coords" has [B, P, 2] shape and float32 dtype and contains the
272
+ absolute pixel-unit coordinate in (x, y) format of the P input points
273
+ 2) "point_labels" has shape [B, P] and int32 dtype, where 1 means
274
+ positive clicks, 0 means negative clicks, and -1 means padding
275
+ - mask_inputs: a mask of [B, 1, H*16, W*16] shape, float or bool, with the
276
+ same spatial size as the image.
277
+ - high_res_features: either 1) None or 2) a list of length 2 containing
278
+ two feature maps of [B, C, 4*H, 4*W] and [B, C, 2*H, 2*W] shapes respectively,
279
+ which will be used as high-resolution feature maps for SAM decoder.
280
+ - multimask_output: if it's True, we output 3 candidate masks and their 3
281
+ corresponding IoU estimates, and if it's False, we output only 1 mask and
282
+ its corresponding IoU estimate.
283
+
284
+ Outputs:
285
+ - low_res_multimasks: [B, M, H*4, W*4] shape (where M = 3 if
286
+ `multimask_output=True` and M = 1 if `multimask_output=False`), the SAM
287
+ output mask logits (before sigmoid) for the low-resolution masks, with 4x
288
+ the resolution (1/4 stride) of the input backbone_features.
289
+ - high_res_multimasks: [B, M, H*16, W*16] shape (where M = 3
290
+ if `multimask_output=True` and M = 1 if `multimask_output=False`),
291
+ upsampled from the low-resolution masks, with the same spatial size as the image
292
+ (stride is 1 pixel).
293
+ - ious: [B, M] shape (where M = 3 if `multimask_output=True` and M = 1
294
+ if `multimask_output=False`), the estimated IoU of each output mask.
295
+ - low_res_masks: [B, 1, H*4, W*4] shape, the best mask in `low_res_multimasks`.
296
+ If `multimask_output=True`, it's the mask with the highest IoU estimate.
297
+ If `multimask_output=False`, it's the same as `low_res_multimasks`.
298
+ - high_res_masks: [B, 1, H*16, W*16] shape, the best mask in `high_res_multimasks`.
299
+ If `multimask_output=True`, it's the mask with the highest IoU estimate.
300
+ If `multimask_output=False`, it's the same as `high_res_multimasks`.
301
+ - obj_ptr: [B, C] shape, the object pointer vector for the output mask, extracted
302
+ based on the output token from the SAM mask decoder.
303
+ """
304
+ B = backbone_features.size(0)
305
+ device = backbone_features.device
306
+ assert backbone_features.size(1) == self.sam_prompt_embed_dim
307
+ assert backbone_features.size(2) == self.sam_image_embedding_size
308
+ assert backbone_features.size(3) == self.sam_image_embedding_size
309
+
310
+ # a) Handle point prompts
311
+ if point_inputs is not None:
312
+ sam_point_coords = point_inputs["point_coords"]
313
+ sam_point_labels = point_inputs["point_labels"]
314
+ assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B
315
+ else:
316
+ # If no points are provided, pad with an empty point (with label -1)
317
+ sam_point_coords = torch.zeros(B, 1, 2, device=device)
318
+ sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device)
319
+
320
+ # b) Handle mask prompts
321
+ if mask_inputs is not None:
322
+ # If mask_inputs is provided, downsize it into low-res mask input if needed
323
+ # and feed it as a dense mask prompt into the SAM mask encoder
324
+ assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1)
325
+ if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size:
326
+ sam_mask_prompt = F.interpolate(
327
+ mask_inputs.float(),
328
+ size=self.sam_prompt_encoder.mask_input_size,
329
+ align_corners=False,
330
+ mode="bilinear",
331
+ antialias=True, # use antialias for downsampling
332
+ )
333
+ else:
334
+ sam_mask_prompt = mask_inputs
335
+ else:
336
+ # Otherwise, simply feed None (and SAM's prompt encoder will add
337
+ # a learned `no_mask_embed` to indicate no mask input in this case).
338
+ sam_mask_prompt = None
339
+
340
+ sparse_embeddings, dense_embeddings = self.sam_prompt_encoder(
341
+ points=(sam_point_coords, sam_point_labels),
342
+ boxes=None,
343
+ masks=sam_mask_prompt,
344
+ )
345
+ (
346
+ low_res_multimasks,
347
+ ious,
348
+ sam_output_tokens,
349
+ object_score_logits,
350
+ ) = self.sam_mask_decoder(
351
+ image_embeddings=backbone_features,
352
+ image_pe=self.sam_prompt_encoder.get_dense_pe(),
353
+ sparse_prompt_embeddings=sparse_embeddings,
354
+ dense_prompt_embeddings=dense_embeddings,
355
+ multimask_output=multimask_output,
356
+ repeat_image=False, # the image is already batched
357
+ high_res_features=high_res_features,
358
+ )
359
+ if self.pred_obj_scores:
360
+ is_obj_appearing = object_score_logits > 0
361
+
362
+ # Mask used for spatial memories is always a *hard* choice between obj and no obj,
363
+ # consistent with the actual mask prediction
364
+ low_res_multimasks = torch.where(
365
+ is_obj_appearing[:, None, None],
366
+ low_res_multimasks,
367
+ NO_OBJ_SCORE,
368
+ )
369
+
370
+ # convert masks from possibly bfloat16 (or float16) to float32
371
+ # (older PyTorch versions before 2.1 don't support `interpolate` on bf16)
372
+ low_res_multimasks = low_res_multimasks.float()
373
+ high_res_multimasks = F.interpolate(
374
+ low_res_multimasks,
375
+ size=(self.image_size, self.image_size),
376
+ mode="bilinear",
377
+ align_corners=False,
378
+ )
379
+
380
+ sam_output_token = sam_output_tokens[:, 0]
381
+ if multimask_output:
382
+ # take the best mask prediction (with the highest IoU estimation)
383
+ best_iou_inds = torch.argmax(ious, dim=-1)
384
+ batch_inds = torch.arange(B, device=device)
385
+ low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
386
+ high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
387
+ if sam_output_tokens.size(1) > 1:
388
+ sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]
389
+ else:
390
+ low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks
391
+
392
+ # Extract object pointer from the SAM output token (with occlusion handling)
393
+ obj_ptr = self.obj_ptr_proj(sam_output_token)
394
+ if self.pred_obj_scores:
395
+ # Allow *soft* no obj ptr, unlike for masks
396
+ if self.soft_no_obj_ptr:
397
+ lambda_is_obj_appearing = object_score_logits.sigmoid()
398
+ else:
399
+ lambda_is_obj_appearing = is_obj_appearing.float()
400
+
401
+ if self.fixed_no_obj_ptr:
402
+ obj_ptr = lambda_is_obj_appearing * obj_ptr
403
+ obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr
404
+
405
+ return (
406
+ low_res_multimasks,
407
+ high_res_multimasks,
408
+ ious,
409
+ low_res_masks,
410
+ high_res_masks,
411
+ obj_ptr,
412
+ object_score_logits,
413
+ )
414
+
415
+ def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs):
416
+ """
417
+ Directly turn binary `mask_inputs` into output mask logits without using SAM.
418
+ (same input and output shapes as in _forward_sam_heads above).
419
+ """
420
+ # Use -10/+10 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid).
421
+ out_scale, out_bias = 20.0, -10.0 # sigmoid(-10.0)=4.5398e-05
422
+ mask_inputs_float = mask_inputs.float()
423
+ high_res_masks = mask_inputs_float * out_scale + out_bias
424
+ low_res_masks = F.interpolate(
425
+ high_res_masks,
426
+ size=(high_res_masks.size(-2) // 4, high_res_masks.size(-1) // 4),
427
+ align_corners=False,
428
+ mode="bilinear",
429
+ antialias=True, # use antialias for downsampling
430
+ )
431
+ # a dummy IoU prediction of all 1's under mask input
432
+ ious = mask_inputs.new_ones(mask_inputs.size(0), 1).float()
433
+ if not self.use_obj_ptrs_in_encoder:
434
+ # all zeros as a dummy object pointer (of shape [B, C])
435
+ obj_ptr = torch.zeros(
436
+ mask_inputs.size(0), self.hidden_dim, device=mask_inputs.device
437
+ )
438
+ else:
439
+ # produce an object pointer using the SAM decoder from the mask input
440
+ _, _, _, _, _, obj_ptr, _ = self._forward_sam_heads(
441
+ backbone_features=backbone_features,
442
+ mask_inputs=self.mask_downsample(mask_inputs_float),
443
+ high_res_features=high_res_features,
444
+ )
445
+ # In this method, we are treating mask_input as output, e.g. using it directly to create spatial mem;
446
+ # Below, we follow the same design axiom to use mask_input to decide if obj appears or not instead of relying
447
+ # on the object_scores from the SAM decoder.
448
+ is_obj_appearing = torch.any(mask_inputs.flatten(1).float() > 0.0, dim=1)
449
+ is_obj_appearing = is_obj_appearing[..., None]
450
+ lambda_is_obj_appearing = is_obj_appearing.float()
451
+ object_score_logits = out_scale * lambda_is_obj_appearing + out_bias
452
+ if self.pred_obj_scores:
453
+ if self.fixed_no_obj_ptr:
454
+ obj_ptr = lambda_is_obj_appearing * obj_ptr
455
+ obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr
456
+
457
+ return (
458
+ low_res_masks,
459
+ high_res_masks,
460
+ ious,
461
+ low_res_masks,
462
+ high_res_masks,
463
+ obj_ptr,
464
+ object_score_logits,
465
+ )
466
+
467
+ def forward_image(self, img_batch: torch.Tensor):
468
+ """Get the image feature on the input batch."""
469
+ backbone_out = self.image_encoder(img_batch)
470
+ if self.use_high_res_features_in_sam:
471
+ # precompute projected level 0 and level 1 features in SAM decoder
472
+ # to avoid running it again on every SAM click
473
+ backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0(
474
+ backbone_out["backbone_fpn"][0]
475
+ )
476
+ backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1(
477
+ backbone_out["backbone_fpn"][1]
478
+ )
479
+ return backbone_out
480
+
481
+ def _prepare_backbone_features(self, backbone_out):
482
+ """Prepare and flatten visual features."""
483
+ backbone_out = backbone_out.copy()
484
+ assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"])
485
+ assert len(backbone_out["backbone_fpn"]) >= self.num_feature_levels
486
+
487
+ feature_maps = backbone_out["backbone_fpn"][-self.num_feature_levels :]
488
+ vision_pos_embeds = backbone_out["vision_pos_enc"][-self.num_feature_levels :]
489
+
490
+ feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds]
491
+ # flatten NxCxHxW to HWxNxC
492
+ vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps]
493
+ vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds]
494
+
495
+ return backbone_out, vision_feats, vision_pos_embeds, feat_sizes
496
+
497
+ def _prepare_memory_conditioned_features(
498
+ self,
499
+ frame_idx,
500
+ is_init_cond_frame,
501
+ current_vision_feats,
502
+ current_vision_pos_embeds,
503
+ feat_sizes,
504
+ output_dict,
505
+ num_frames,
506
+ track_in_reverse=False, # tracking in reverse time order (for demo usage)
507
+ ):
508
+ """Fuse the current frame's visual feature map with previous memory."""
509
+ B = current_vision_feats[-1].size(1) # batch size on this frame
510
+ C = self.hidden_dim
511
+ H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size
512
+ device = current_vision_feats[-1].device
513
+ # The case of `self.num_maskmem == 0` below is primarily used for reproducing SAM on images.
514
+ # In this case, we skip the fusion with any memory.
515
+ if self.num_maskmem == 0: # Disable memory and skip fusion
516
+ pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
517
+ return pix_feat
518
+
519
+ num_obj_ptr_tokens = 0
520
+ tpos_sign_mul = -1 if track_in_reverse else 1
521
+ # Step 1: condition the visual features of the current frame on previous memories
522
+ if not is_init_cond_frame:
523
+ # Retrieve the memories encoded with the maskmem backbone
524
+ to_cat_memory, to_cat_memory_pos_embed = [], []
525
+ # Add the conditioning frames' outputs first (all cond frames have t_pos=0 for
526
+ # when getting temporal positional embedding below)
527
+ assert len(output_dict["cond_frame_outputs"]) > 0
528
+ # Select a maximum number of temporally closest cond frames for cross attention
529
+ cond_outputs = output_dict["cond_frame_outputs"]
530
+ selected_cond_outputs, unselected_cond_outputs = select_closest_cond_frames(
531
+ frame_idx, cond_outputs, self.max_cond_frames_in_attn
532
+ )
533
+ t_pos_and_prevs = [(0, out) for out in selected_cond_outputs.values()]
534
+ # Add last (self.num_maskmem - 1) frames before current frame for non-conditioning memory
535
+ # the earliest one has t_pos=1 and the latest one has t_pos=self.num_maskmem-1
536
+ # We also allow taking the memory frame non-consecutively (with stride>1), in which case
537
+ # we take (self.num_maskmem - 2) frames among every stride-th frames plus the last frame.
538
+ stride = 1 if self.training else self.memory_temporal_stride_for_eval
539
+ for t_pos in range(1, self.num_maskmem):
540
+ t_rel = self.num_maskmem - t_pos # how many frames before current frame
541
+ if t_rel == 1:
542
+ # for t_rel == 1, we take the last frame (regardless of r)
543
+ if not track_in_reverse:
544
+ # the frame immediately before this frame (i.e. frame_idx - 1)
545
+ prev_frame_idx = frame_idx - t_rel
546
+ else:
547
+ # the frame immediately after this frame (i.e. frame_idx + 1)
548
+ prev_frame_idx = frame_idx + t_rel
549
+ else:
550
+ # for t_rel >= 2, we take the memory frame from every r-th frames
551
+ if not track_in_reverse:
552
+ # first find the nearest frame among every r-th frames before this frame
553
+ # for r=1, this would be (frame_idx - 2)
554
+ prev_frame_idx = ((frame_idx - 2) // stride) * stride
555
+ # then seek further among every r-th frames
556
+ prev_frame_idx = prev_frame_idx - (t_rel - 2) * stride
557
+ else:
558
+ # first find the nearest frame among every r-th frames after this frame
559
+ # for r=1, this would be (frame_idx + 2)
560
+ prev_frame_idx = -(-(frame_idx + 2) // stride) * stride
561
+ # then seek further among every r-th frames
562
+ prev_frame_idx = prev_frame_idx + (t_rel - 2) * stride
563
+ out = output_dict["non_cond_frame_outputs"].get(prev_frame_idx, None)
564
+ if out is None:
565
+ # If an unselected conditioning frame is among the last (self.num_maskmem - 1)
566
+ # frames, we still attend to it as if it's a non-conditioning frame.
567
+ out = unselected_cond_outputs.get(prev_frame_idx, None)
568
+ t_pos_and_prevs.append((t_pos, out))
569
+
570
+ for t_pos, prev in t_pos_and_prevs:
571
+ if prev is None:
572
+ continue # skip padding frames
573
+ # "maskmem_features" might have been offloaded to CPU in demo use cases,
574
+ # so we load it back to GPU (it's a no-op if it's already on GPU).
575
+ feats = prev["maskmem_features"].to(device, non_blocking=True)
576
+ to_cat_memory.append(feats.flatten(2).permute(2, 0, 1))
577
+ # Spatial positional encoding (it might have been offloaded to CPU in eval)
578
+ maskmem_enc = prev["maskmem_pos_enc"][-1].to(device)
579
+ maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1)
580
+ # Temporal positional encoding
581
+ maskmem_enc = (
582
+ maskmem_enc + self.maskmem_tpos_enc[self.num_maskmem - t_pos - 1]
583
+ )
584
+ to_cat_memory_pos_embed.append(maskmem_enc)
585
+
586
+ # Construct the list of past object pointers
587
+ if self.use_obj_ptrs_in_encoder:
588
+ max_obj_ptrs_in_encoder = min(num_frames, self.max_obj_ptrs_in_encoder)
589
+ # First add those object pointers from selected conditioning frames
590
+ # (optionally, only include object pointers in the past during evaluation)
591
+ if not self.training and self.only_obj_ptrs_in_the_past_for_eval:
592
+ ptr_cond_outputs = {
593
+ t: out
594
+ for t, out in selected_cond_outputs.items()
595
+ if (t >= frame_idx if track_in_reverse else t <= frame_idx)
596
+ }
597
+ else:
598
+ ptr_cond_outputs = selected_cond_outputs
599
+ pos_and_ptrs = [
600
+ # Temporal pos encoding contains how far away each pointer is from current frame
601
+ (
602
+ (
603
+ (frame_idx - t) * tpos_sign_mul
604
+ if self.use_signed_tpos_enc_to_obj_ptrs
605
+ else abs(frame_idx - t)
606
+ ),
607
+ out["obj_ptr"],
608
+ )
609
+ for t, out in ptr_cond_outputs.items()
610
+ ]
611
+ # Add up to (max_obj_ptrs_in_encoder - 1) non-conditioning frames before current frame
612
+ for t_diff in range(1, max_obj_ptrs_in_encoder):
613
+ t = frame_idx + t_diff if track_in_reverse else frame_idx - t_diff
614
+ if t < 0 or (num_frames is not None and t >= num_frames):
615
+ break
616
+ out = output_dict["non_cond_frame_outputs"].get(
617
+ t, unselected_cond_outputs.get(t, None)
618
+ )
619
+ if out is not None:
620
+ pos_and_ptrs.append((t_diff, out["obj_ptr"]))
621
+ # If we have at least one object pointer, add them to the cross attention
622
+ if len(pos_and_ptrs) > 0:
623
+ pos_list, ptrs_list = zip(*pos_and_ptrs)
624
+ # stack object pointers along dim=0 into [ptr_seq_len, B, C] shape
625
+ obj_ptrs = torch.stack(ptrs_list, dim=0)
626
+ # a temporal positional embedding based on how far each object pointer is from
627
+ # the current frame (sine embedding normalized by the max pointer num).
628
+ if self.add_tpos_enc_to_obj_ptrs:
629
+ t_diff_max = max_obj_ptrs_in_encoder - 1
630
+ tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim
631
+ obj_pos = torch.tensor(pos_list, device=device)
632
+ obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim)
633
+ obj_pos = self.obj_ptr_tpos_proj(obj_pos)
634
+ obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim)
635
+ else:
636
+ obj_pos = obj_ptrs.new_zeros(len(pos_list), B, self.mem_dim)
637
+ if self.mem_dim < C:
638
+ # split a pointer into (C // self.mem_dim) tokens for self.mem_dim < C
639
+ obj_ptrs = obj_ptrs.reshape(
640
+ -1, B, C // self.mem_dim, self.mem_dim
641
+ )
642
+ obj_ptrs = obj_ptrs.permute(0, 2, 1, 3).flatten(0, 1)
643
+ obj_pos = obj_pos.repeat_interleave(C // self.mem_dim, dim=0)
644
+ to_cat_memory.append(obj_ptrs)
645
+ to_cat_memory_pos_embed.append(obj_pos)
646
+ num_obj_ptr_tokens = obj_ptrs.shape[0]
647
+ else:
648
+ num_obj_ptr_tokens = 0
649
+ else:
650
+ # for initial conditioning frames, encode them without using any previous memory
651
+ if self.directly_add_no_mem_embed:
652
+ # directly add no-mem embedding (instead of using the transformer encoder)
653
+ pix_feat_with_mem = current_vision_feats[-1] + self.no_mem_embed
654
+ pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
655
+ return pix_feat_with_mem
656
+
657
+ # Use a dummy token on the first frame (to avoid empty memory input to the transformer encoder)
658
+ to_cat_memory = [self.no_mem_embed.expand(1, B, self.mem_dim)]
659
+ to_cat_memory_pos_embed = [self.no_mem_pos_enc.expand(1, B, self.mem_dim)]
660
+
661
+ # Step 2: Concatenate the memories and forward through the transformer encoder
662
+ memory = torch.cat(to_cat_memory, dim=0)
663
+ memory_pos_embed = torch.cat(to_cat_memory_pos_embed, dim=0)
664
+
665
+ pix_feat_with_mem = self.memory_attention(
666
+ curr=current_vision_feats,
667
+ curr_pos=current_vision_pos_embeds,
668
+ memory=memory,
669
+ memory_pos=memory_pos_embed,
670
+ num_obj_ptr_tokens=num_obj_ptr_tokens,
671
+ )
672
+ # reshape the output (HW)BC => BCHW
673
+ pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
674
+ return pix_feat_with_mem
675
+
676
+ def _encode_new_memory(
677
+ self,
678
+ current_vision_feats,
679
+ feat_sizes,
680
+ pred_masks_high_res,
681
+ object_score_logits,
682
+ is_mask_from_pts,
683
+ ):
684
+ """Encode the current image and its prediction into a memory feature."""
685
+ B = current_vision_feats[-1].size(1) # batch size on this frame
686
+ C = self.hidden_dim
687
+ H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size
688
+ # top-level feature, (HW)BC => BCHW
689
+ pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
690
+ if self.non_overlap_masks_for_mem_enc and not self.training:
691
+ # optionally, apply non-overlapping constraints to the masks (it's applied
692
+ # in the batch dimension and should only be used during eval, where all
693
+ # the objects come from the same video under batch size 1).
694
+ pred_masks_high_res = self._apply_non_overlapping_constraints(
695
+ pred_masks_high_res
696
+ )
697
+ # scale the raw mask logits with a temperature before applying sigmoid
698
+ binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts
699
+ if binarize and not self.training:
700
+ mask_for_mem = (pred_masks_high_res > 0).float()
701
+ else:
702
+ # apply sigmoid on the raw mask logits to turn them into range (0, 1)
703
+ mask_for_mem = torch.sigmoid(pred_masks_high_res)
704
+ # apply scale and bias terms to the sigmoid probabilities
705
+ if self.sigmoid_scale_for_mem_enc != 1.0:
706
+ mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc
707
+ if self.sigmoid_bias_for_mem_enc != 0.0:
708
+ mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc
709
+ maskmem_out = self.memory_encoder(
710
+ pix_feat, mask_for_mem, skip_mask_sigmoid=True # sigmoid already applied
711
+ )
712
+ maskmem_features = maskmem_out["vision_features"]
713
+ maskmem_pos_enc = maskmem_out["vision_pos_enc"]
714
+ # add a no-object embedding to the spatial memory to indicate that the frame
715
+ # is predicted to be occluded (i.e. no object is appearing in the frame)
716
+ if self.no_obj_embed_spatial is not None:
717
+ is_obj_appearing = (object_score_logits > 0).float()
718
+ maskmem_features += (
719
+ 1 - is_obj_appearing[..., None, None]
720
+ ) * self.no_obj_embed_spatial[..., None, None].expand(
721
+ *maskmem_features.shape
722
+ )
723
+
724
+ return maskmem_features, maskmem_pos_enc
725
+
726
+ def _track_step(
727
+ self,
728
+ frame_idx,
729
+ is_init_cond_frame,
730
+ current_vision_feats,
731
+ current_vision_pos_embeds,
732
+ feat_sizes,
733
+ point_inputs,
734
+ mask_inputs,
735
+ output_dict,
736
+ num_frames,
737
+ track_in_reverse,
738
+ prev_sam_mask_logits,
739
+ ):
740
+ current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs}
741
+ # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW
742
+ if len(current_vision_feats) > 1:
743
+ high_res_features = [
744
+ x.permute(1, 2, 0).view(x.size(1), x.size(2), *s)
745
+ for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1])
746
+ ]
747
+ else:
748
+ high_res_features = None
749
+ if mask_inputs is not None and self.use_mask_input_as_output_without_sam:
750
+ # When use_mask_input_as_output_without_sam=True, we directly output the mask input
751
+ # (see it as a GT mask) without using a SAM prompt encoder + mask decoder.
752
+ pix_feat = current_vision_feats[-1].permute(1, 2, 0)
753
+ pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])
754
+ sam_outputs = self._use_mask_as_output(
755
+ pix_feat, high_res_features, mask_inputs
756
+ )
757
+ else:
758
+ # fuse the visual feature with previous memory features in the memory bank
759
+ pix_feat = self._prepare_memory_conditioned_features(
760
+ frame_idx=frame_idx,
761
+ is_init_cond_frame=is_init_cond_frame,
762
+ current_vision_feats=current_vision_feats[-1:],
763
+ current_vision_pos_embeds=current_vision_pos_embeds[-1:],
764
+ feat_sizes=feat_sizes[-1:],
765
+ output_dict=output_dict,
766
+ num_frames=num_frames,
767
+ track_in_reverse=track_in_reverse,
768
+ )
769
+ # apply SAM-style segmentation head
770
+ # here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder,
771
+ # e.g. in the demo, where such logits come from earlier interactions instead of correction sampling
772
+ # (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead)
773
+ if prev_sam_mask_logits is not None:
774
+ assert point_inputs is not None and mask_inputs is None
775
+ mask_inputs = prev_sam_mask_logits
776
+ multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
777
+ sam_outputs = self._forward_sam_heads(
778
+ backbone_features=pix_feat,
779
+ point_inputs=point_inputs,
780
+ mask_inputs=mask_inputs,
781
+ high_res_features=high_res_features,
782
+ multimask_output=multimask_output,
783
+ )
784
+
785
+ return current_out, sam_outputs, high_res_features, pix_feat
786
+
787
+ def _encode_memory_in_output(
788
+ self,
789
+ current_vision_feats,
790
+ feat_sizes,
791
+ point_inputs,
792
+ run_mem_encoder,
793
+ high_res_masks,
794
+ object_score_logits,
795
+ current_out,
796
+ ):
797
+ if run_mem_encoder and self.num_maskmem > 0:
798
+ high_res_masks_for_mem_enc = high_res_masks
799
+ maskmem_features, maskmem_pos_enc = self._encode_new_memory(
800
+ current_vision_feats=current_vision_feats,
801
+ feat_sizes=feat_sizes,
802
+ pred_masks_high_res=high_res_masks_for_mem_enc,
803
+ object_score_logits=object_score_logits,
804
+ is_mask_from_pts=(point_inputs is not None),
805
+ )
806
+ current_out["maskmem_features"] = maskmem_features
807
+ current_out["maskmem_pos_enc"] = maskmem_pos_enc
808
+ else:
809
+ current_out["maskmem_features"] = None
810
+ current_out["maskmem_pos_enc"] = None
811
+
812
+ def track_step(
813
+ self,
814
+ frame_idx,
815
+ is_init_cond_frame,
816
+ current_vision_feats,
817
+ current_vision_pos_embeds,
818
+ feat_sizes,
819
+ point_inputs,
820
+ mask_inputs,
821
+ output_dict,
822
+ num_frames,
823
+ track_in_reverse=False, # tracking in reverse time order (for demo usage)
824
+ # Whether to run the memory encoder on the predicted masks. Sometimes we might want
825
+ # to skip the memory encoder with `run_mem_encoder=False`. For example,
826
+ # in demo we might call `track_step` multiple times for each user click,
827
+ # and only encode the memory when the user finalizes their clicks. And in ablation
828
+ # settings like SAM training on static images, we don't need the memory encoder.
829
+ run_mem_encoder=True,
830
+ # The previously predicted SAM mask logits (which can be fed together with new clicks in demo).
831
+ prev_sam_mask_logits=None,
832
+ ):
833
+ current_out, sam_outputs, _, _ = self._track_step(
834
+ frame_idx,
835
+ is_init_cond_frame,
836
+ current_vision_feats,
837
+ current_vision_pos_embeds,
838
+ feat_sizes,
839
+ point_inputs,
840
+ mask_inputs,
841
+ output_dict,
842
+ num_frames,
843
+ track_in_reverse,
844
+ prev_sam_mask_logits,
845
+ )
846
+
847
+ (
848
+ _,
849
+ _,
850
+ _,
851
+ low_res_masks,
852
+ high_res_masks,
853
+ obj_ptr,
854
+ object_score_logits,
855
+ ) = sam_outputs
856
+
857
+ current_out["pred_masks"] = low_res_masks
858
+ current_out["pred_masks_high_res"] = high_res_masks
859
+ current_out["obj_ptr"] = obj_ptr
860
+ if not self.training:
861
+ # Only add this in inference (to avoid unused param in activation checkpointing;
862
+ # it's mainly used in the demo to encode spatial memories w/ consolidated masks)
863
+ current_out["object_score_logits"] = object_score_logits
864
+
865
+ # Finally run the memory encoder on the predicted mask to encode
866
+ # it into a new memory feature (that can be used in future frames)
867
+ self._encode_memory_in_output(
868
+ current_vision_feats,
869
+ feat_sizes,
870
+ point_inputs,
871
+ run_mem_encoder,
872
+ high_res_masks,
873
+ object_score_logits,
874
+ current_out,
875
+ )
876
+
877
+ return current_out
878
+
879
+ def _use_multimask(self, is_init_cond_frame, point_inputs):
880
+ """Whether to use multimask output in the SAM head."""
881
+ num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(1)
882
+ multimask_output = (
883
+ self.multimask_output_in_sam
884
+ and (is_init_cond_frame or self.multimask_output_for_tracking)
885
+ and (self.multimask_min_pt_num <= num_pts <= self.multimask_max_pt_num)
886
+ )
887
+ return multimask_output
888
+
889
+ def _apply_non_overlapping_constraints(self, pred_masks):
890
+ """
891
+ Apply non-overlapping constraints to the object scores in pred_masks. Here we
892
+ keep only the highest scoring object at each spatial location in pred_masks.
893
+ """
894
+ batch_size = pred_masks.size(0)
895
+ if batch_size == 1:
896
+ return pred_masks
897
+
898
+ device = pred_masks.device
899
+ # "max_obj_inds": object index of the object with the highest score at each location
900
+ max_obj_inds = torch.argmax(pred_masks, dim=0, keepdim=True)
901
+ # "batch_obj_inds": object index of each object slice (along dim 0) in `pred_masks`
902
+ batch_obj_inds = torch.arange(batch_size, device=device)[:, None, None, None]
903
+ keep = max_obj_inds == batch_obj_inds
904
+ # suppress overlapping regions' scores below -10.0 so that the foreground regions
905
+ # don't overlap (here sigmoid(-10.0)=4.5398e-05)
906
+ pred_masks = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0))
907
+ return pred_masks
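Note: `SAM2Base.forward` raises `NotImplementedError` by design; inference is meant to go through the wrappers bundled alongside it in this wheel (`sam2/build_sam.py`, `sam2/sam2_image_predictor.py`, `sam2/sam2_video_predictor.py`). Below is a minimal single-image sketch, assuming the Hiera-L config stub shipped in the wheel and a separately downloaded checkpoint at a hypothetical local path; exact config and checkpoint names may differ for your install.

import numpy as np
import torch

from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

# "sam2_hiera_l.yaml" matches the config stub bundled in this wheel;
# the checkpoint path is a placeholder and is NOT part of the package.
model = build_sam2("sam2_hiera_l.yaml", "checkpoints/sam2_hiera_large.pt", device="cpu")
predictor = SAM2ImagePredictor(model)

image = np.zeros((1024, 1024, 3), dtype=np.uint8)  # placeholder RGB image of shape (H, W, 3)
with torch.inference_mode():
    predictor.set_image(image)
    masks, ious, low_res_logits = predictor.predict(
        point_coords=np.array([[512, 512]]),  # one positive click in absolute (x, y) pixels
        point_labels=np.array([1]),           # 1 = positive click, 0 = negative click
        multimask_output=True,  # 3 candidate masks + IoU estimates, as in _forward_sam_heads above
    )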