ultralytics 8.2.71__py3-none-any.whl → 8.2.73__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ultralytics might be problematic. Click here for more details.
- tests/test_cli.py +3 -0
- ultralytics/__init__.py +2 -3
- ultralytics/models/__init__.py +1 -2
- ultralytics/models/sam/__init__.py +2 -2
- ultralytics/models/sam/amg.py +27 -21
- ultralytics/models/sam/build.py +200 -9
- ultralytics/models/sam/model.py +86 -34
- ultralytics/models/sam/modules/blocks.py +1131 -0
- ultralytics/models/sam/modules/decoders.py +390 -23
- ultralytics/models/sam/modules/encoders.py +508 -323
- ultralytics/models/{sam2 → sam}/modules/memory_attention.py +73 -6
- ultralytics/models/sam/modules/sam.py +887 -16
- ultralytics/models/sam/modules/tiny_encoder.py +376 -126
- ultralytics/models/sam/modules/transformer.py +155 -54
- ultralytics/models/{sam2 → sam}/modules/utils.py +105 -3
- ultralytics/models/sam/predict.py +382 -92
- ultralytics/nn/modules/transformer.py +2 -2
- ultralytics/utils/downloads.py +2 -2
- ultralytics/utils/ops.py +2 -2
- ultralytics/utils/plotting.py +3 -3
- {ultralytics-8.2.71.dist-info → ultralytics-8.2.73.dist-info}/METADATA +44 -44
- {ultralytics-8.2.71.dist-info → ultralytics-8.2.73.dist-info}/RECORD +26 -34
- ultralytics/models/sam2/__init__.py +0 -6
- ultralytics/models/sam2/build.py +0 -156
- ultralytics/models/sam2/model.py +0 -97
- ultralytics/models/sam2/modules/__init__.py +0 -1
- ultralytics/models/sam2/modules/decoders.py +0 -305
- ultralytics/models/sam2/modules/encoders.py +0 -332
- ultralytics/models/sam2/modules/sam2.py +0 -804
- ultralytics/models/sam2/modules/sam2_blocks.py +0 -715
- ultralytics/models/sam2/predict.py +0 -182
- {ultralytics-8.2.71.dist-info → ultralytics-8.2.73.dist-info}/LICENSE +0 -0
- {ultralytics-8.2.71.dist-info → ultralytics-8.2.73.dist-info}/WHEEL +0 -0
- {ultralytics-8.2.71.dist-info → ultralytics-8.2.73.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.2.71.dist-info → ultralytics-8.2.73.dist-info}/top_level.txt +0 -0
|
@@ -1,804 +0,0 @@
|
|
|
1
|
-
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
|
2
|
-
|
|
3
|
-
import torch
|
|
4
|
-
import torch.distributed
|
|
5
|
-
import torch.nn.functional as F
|
|
6
|
-
from torch.nn.init import trunc_normal_
|
|
7
|
-
|
|
8
|
-
from ultralytics.models.sam.modules.encoders import PromptEncoder
|
|
9
|
-
from ultralytics.nn.modules import MLP
|
|
10
|
-
|
|
11
|
-
from .decoders import MaskDecoder
|
|
12
|
-
from .sam2_blocks import TwoWayTransformer
|
|
13
|
-
from .utils import get_1d_sine_pe, select_closest_cond_frames
|
|
14
|
-
|
|
15
|
-
# A large negative value used as a placeholder score for missing objects: mask logits are
# overwritten with it wherever the predicted object score says the object is not present.
NO_OBJ_SCORE = -1024.0
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
class SAM2Model(torch.nn.Module):
|
|
20
|
-
"""SAM2Model class for Segment Anything Model 2 with memory-based video object segmentation capabilities."""
|
|
21
|
-
|
|
22
|
-
mask_threshold: float = 0.0
|
|
23
|
-
|
|
24
|
-
def __init__(
    self,
    image_encoder,
    memory_attention,
    memory_encoder,
    num_maskmem=7,
    image_size=512,
    backbone_stride=16,
    sigmoid_scale_for_mem_enc=1.0,
    sigmoid_bias_for_mem_enc=0.0,
    binarize_mask_from_pts_for_mem_enc=False,
    use_mask_input_as_output_without_sam=False,
    max_cond_frames_in_attn=-1,
    directly_add_no_mem_embed=False,
    use_high_res_features_in_sam=False,
    multimask_output_in_sam=False,
    multimask_min_pt_num=1,
    multimask_max_pt_num=1,
    multimask_output_for_tracking=False,
    use_multimask_token_for_obj_ptr: bool = False,
    iou_prediction_use_sigmoid=False,
    memory_temporal_stride_for_eval=1,
    add_all_frames_to_correct_as_cond=False,
    non_overlap_masks_for_mem_enc=False,
    use_obj_ptrs_in_encoder=False,
    max_obj_ptrs_in_encoder=16,
    add_tpos_enc_to_obj_ptrs=True,
    proj_tpos_enc_in_obj_ptrs=False,
    only_obj_ptrs_in_the_past_for_eval=False,
    pred_obj_scores: bool = False,
    pred_obj_scores_mlp: bool = False,
    fixed_no_obj_ptr: bool = False,
    soft_no_obj_ptr: bool = False,
    use_mlp_for_obj_ptr_proj: bool = False,
    sam_mask_decoder_extra_args=None,
    compile_image_encoder: bool = False,
):
    """
    Set up the SAM2 model from an image encoder, a memory attention module and a memory encoder.

    Args:
        image_encoder: Image backbone; stored as `self.image_encoder`.
        memory_attention: Module conditioning the current frame's features on memories (and object
            pointers) from past frames. Its `d_model` attribute defines `self.hidden_dim`.
        memory_encoder: Module encoding the previous frame's outputs into memories. If it has an
            `out_proj` with a `weight`, the memory dim is taken from `out_proj.weight.shape[0]`.
        num_maskmem: Number of accessible memories (default 1 input frame + 6 previous frames).
        image_size: Side length of the (square) input image.
        backbone_stride: Stride of the image backbone output.
        sigmoid_scale_for_mem_enc: Scale factor for the mask sigmoid probability.
        sigmoid_bias_for_mem_enc: Bias factor for the mask sigmoid probability.
        binarize_mask_from_pts_for_mem_enc: During evaluation, whether to binarize the sigmoid mask
            logits on interacted frames with clicks.
        use_mask_input_as_output_without_sam: On frames with mask input, whether to directly output
            the input mask without using a SAM prompt encoder + mask decoder.
        max_cond_frames_in_attn: Maximum number of conditioning frames taking part in memory
            attention (-1 means no limit). With more conditioning frames than this limit, only the
            temporally closest ones are cross-attended to, giving temporal locality and avoiding
            GPU OOM when many frames are annotated.
        directly_add_no_mem_embed: On the first frame, whether to directly add the no-memory
            embedding to the image feature (instead of using the transformer encoder).
        use_high_res_features_in_sam: Whether to use high-resolution feature maps in the SAM mask
            decoder (feature levels 0, 1, 2 instead of only level 2).
        multimask_output_in_sam: Whether to output multiple (3) masks for the first click on
            initial conditioning frames.
        multimask_min_pt_num: Minimum number of clicks for multimask output (only relevant when
            `multimask_output_in_sam=True`; a box counts as two points).
        multimask_max_pt_num: Maximum number of clicks for multimask output (same remarks as above;
            with both set to 1 only the first click gives multimask output).
        multimask_output_for_tracking: Whether to also use multimask output for tracking (only
            relevant when `multimask_output_in_sam=True`).
        use_multimask_token_for_obj_ptr: Whether to use multimask tokens for the object pointer
            (only relevant when both `use_obj_ptrs_in_encoder=True` and
            `multimask_output_for_tracking=True`).
        iou_prediction_use_sigmoid: Whether to use sigmoid to restrict IoU predictions to [0, 1].
        memory_temporal_stride_for_eval: The memory bank's temporal stride `r` during evaluation.
            For r > 1 the (num_maskmem - 1) non-conditioning memory frames consist of the
            (num_maskmem - 2) nearest frames from every r-th frame, plus the last frame.
        add_all_frames_to_correct_as_cond: If True, any frame that receives a later correction
            click is also appended to the conditioning frame list; if False, the conditioning
            frame list is restricted to the initial conditioning frames.
        non_overlap_masks_for_mem_enc: Whether to apply non-overlapping constraints on the object
            masks in the memory encoder during evaluation.
        use_obj_ptrs_in_encoder: Whether to cross-attend to object pointers from other frames
            (based on SAM output tokens) in the encoder.
        max_obj_ptrs_in_encoder: Maximum number of object pointers from other frames in encoder
            cross attention (only relevant when `use_obj_ptrs_in_encoder=True`).
        add_tpos_enc_to_obj_ptrs: Whether to add temporal positional encoding to the object
            pointers in the encoder (only relevant when `use_obj_ptrs_in_encoder=True`).
        proj_tpos_enc_in_obj_ptrs: Whether to add an extra linear projection for the temporal
            positional encoding of object pointers; requires `add_tpos_enc_to_obj_ptrs`.
        only_obj_ptrs_in_the_past_for_eval: Whether to only attend to object pointers from before
            the current frame during evaluation (only relevant when `use_obj_ptrs_in_encoder=True`).
        pred_obj_scores: Whether to predict if there is an object in the frame.
        pred_obj_scores_mlp: Whether to use an MLP to predict object scores.
        fixed_no_obj_ptr: Whether to have a fixed no-object pointer when no object is present, or
            to use it as an additive embedding with the decoder's object pointer. Requires
            `pred_obj_scores` and `use_obj_ptrs_in_encoder`.
        soft_no_obj_ptr: Mix in the no-object pointer softly, to ease recovery from mistakes and
            mitigate error accumulation.
        use_mlp_for_obj_ptr_proj: Whether the object pointer projection is an MLP instead of a
            single linear layer.
        sam_mask_decoder_extra_args: Optional dict of extra kwargs passed to `MaskDecoder`.
        compile_image_encoder: Whether to wrap the image encoder's `forward` with `torch.compile`.

    Raises:
        AssertionError: If `proj_tpos_enc_in_obj_ptrs` is set without `add_tpos_enc_to_obj_ptrs`,
            or `fixed_no_obj_ptr` is set without `pred_obj_scores` and `use_obj_ptrs_in_encoder`.
    """
    super().__init__()

    # ---- Part 1: image backbone -------------------------------------------------------------
    self.image_encoder = image_encoder
    self.use_high_res_features_in_sam = use_high_res_features_in_sam
    # High-res setting consumes levels 0, 1 and 2; the default setting only level 2
    self.num_feature_levels = 3 if use_high_res_features_in_sam else 1
    self.use_obj_ptrs_in_encoder = use_obj_ptrs_in_encoder
    self.max_obj_ptrs_in_encoder = max_obj_ptrs_in_encoder
    if use_obj_ptrs_in_encoder:
        # Conv that brings a mask prompt down to stride 4 (the stride of the low-res SAM mask
        # logits) and rescales it from 0~1 to the SAM logit scale, so that the SAM mask decoder
        # can turn it into an object pointer.
        self.mask_downsample = torch.nn.Conv2d(1, 1, kernel_size=4, stride=4)
    self.add_tpos_enc_to_obj_ptrs = add_tpos_enc_to_obj_ptrs
    if proj_tpos_enc_in_obj_ptrs:
        # The projection only makes sense when the temporal encoding itself is enabled
        assert add_tpos_enc_to_obj_ptrs
    self.proj_tpos_enc_in_obj_ptrs = proj_tpos_enc_in_obj_ptrs
    self.only_obj_ptrs_in_the_past_for_eval = only_obj_ptrs_in_the_past_for_eval

    # ---- Part 2: memory attention (conditions current features on past memories/pointers) ---
    self.memory_attention = memory_attention
    self.hidden_dim = memory_attention.d_model

    # ---- Part 3: memory encoder for the previous frame's outputs ----------------------------
    self.memory_encoder = memory_encoder
    out_proj = getattr(self.memory_encoder, "out_proj", None)
    if hasattr(out_proj, "weight"):
        # Memories are compressed along the channel dimension
        self.mem_dim = out_proj.weight.shape[0]
    else:
        self.mem_dim = self.hidden_dim
    self.num_maskmem = num_maskmem

    # Temporal encoding of the memories
    self.maskmem_tpos_enc = torch.nn.Parameter(torch.zeros(num_maskmem, 1, 1, self.mem_dim))
    trunc_normal_(self.maskmem_tpos_enc, std=0.02)
    # Single token (and its positional encoding) standing for "no memory from previous frames"
    self.no_mem_embed = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
    trunc_normal_(self.no_mem_embed, std=0.02)
    self.no_mem_pos_enc = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
    trunc_normal_(self.no_mem_pos_enc, std=0.02)
    self.directly_add_no_mem_embed = directly_add_no_mem_embed

    # Raw mask logits in (-inf, +inf) go through a sigmoid into (0, 1) before the memory encoder
    self.sigmoid_scale_for_mem_enc = sigmoid_scale_for_mem_enc
    self.sigmoid_bias_for_mem_enc = sigmoid_bias_for_mem_enc
    self.binarize_mask_from_pts_for_mem_enc = binarize_mask_from_pts_for_mem_enc
    self.non_overlap_masks_for_mem_enc = non_overlap_masks_for_mem_enc
    self.memory_temporal_stride_for_eval = memory_temporal_stride_for_eval

    # Mask-input shortcut and multimask behaviour
    self.use_mask_input_as_output_without_sam = use_mask_input_as_output_without_sam
    self.multimask_output_in_sam = multimask_output_in_sam
    self.multimask_min_pt_num = multimask_min_pt_num
    self.multimask_max_pt_num = multimask_max_pt_num
    self.multimask_output_for_tracking = multimask_output_for_tracking
    self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr
    self.iou_prediction_use_sigmoid = iou_prediction_use_sigmoid

    # ---- Part 4: SAM-style prompt encoder and mask decoder ----------------------------------
    self.image_size = image_size
    self.backbone_stride = backbone_stride
    self.sam_mask_decoder_extra_args = sam_mask_decoder_extra_args
    self.pred_obj_scores = pred_obj_scores
    self.pred_obj_scores_mlp = pred_obj_scores_mlp
    self.fixed_no_obj_ptr = fixed_no_obj_ptr
    self.soft_no_obj_ptr = soft_no_obj_ptr
    if self.fixed_no_obj_ptr:
        assert self.pred_obj_scores
        assert self.use_obj_ptrs_in_encoder
    if self.pred_obj_scores and self.use_obj_ptrs_in_encoder:
        self.no_obj_ptr = torch.nn.Parameter(torch.zeros(1, self.hidden_dim))
        trunc_normal_(self.no_obj_ptr, std=0.02)
    self.use_mlp_for_obj_ptr_proj = use_mlp_for_obj_ptr_proj

    self._build_sam_heads()
    self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond
    self.max_cond_frames_in_attn = max_cond_frames_in_attn

    # ---- Optional compilation ---------------------------------------------------------------
    if compile_image_encoder:
        # Only the forward function is compiled (not the full module) to allow loading checkpoints.
        print("Image encoder compilation is enabled. First forward pass will be slow.")
        self.image_encoder.forward = torch.compile(
            self.image_encoder.forward, mode="max-autotune", fullgraph=True, dynamic=False
        )
|
|
186
|
-
|
|
187
|
-
@property
def device(self):
    """Returns the device of the first parameter yielded by `self.parameters()`."""
    first_param = next(self.parameters())
    return first_param.device
|
|
191
|
-
|
|
192
|
-
def forward(self, *args, **kwargs):
    """
    Placeholder forward pass; this model is not meant to be called directly.

    Raises:
        NotImplementedError: Always. Inference is expected to go through the predictor methods
            named in the error message.
    """
    # The two literals are concatenated by the parser, so the first one must end with a space;
    # otherwise the message reads "...for inference.See notebooks/...".
    raise NotImplementedError(
        "Please use the corresponding methods in SAM2VideoPredictor for inference. "
        "See notebooks/video_predictor_example.ipynb for an example."
    )
|
|
198
|
-
|
|
199
|
-
def _build_sam_heads(self):
    """Creates the SAM-style prompt encoder, mask decoder and object-pointer projection layers."""
    embed_dim = self.hidden_dim
    grid_size = self.image_size // self.backbone_stride
    self.sam_prompt_embed_dim = embed_dim
    self.sam_image_embedding_size = grid_size

    # PromptEncoder and MaskDecoder come from SAM
    # (hyperparameters such as `mask_in_chans=16` are taken from the SAM code)
    self.sam_prompt_encoder = PromptEncoder(
        embed_dim=embed_dim,
        image_embedding_size=(grid_size, grid_size),
        input_image_size=(self.image_size, self.image_size),
        mask_in_chans=16,
    )
    decoder_transformer = TwoWayTransformer(
        depth=2,
        embedding_dim=embed_dim,
        mlp_dim=2048,
        num_heads=8,
    )
    extra_decoder_args = self.sam_mask_decoder_extra_args or {}
    self.sam_mask_decoder = MaskDecoder(
        num_multimask_outputs=3,
        transformer=decoder_transformer,
        transformer_dim=embed_dim,
        iou_head_depth=3,
        iou_head_hidden_dim=256,
        use_high_res_features=self.use_high_res_features_in_sam,
        iou_prediction_use_sigmoid=self.iou_prediction_use_sigmoid,
        pred_obj_scores=self.pred_obj_scores,
        pred_obj_scores_mlp=self.pred_obj_scores_mlp,
        use_multimask_token_for_obj_ptr=self.use_multimask_token_for_obj_ptr,
        **extra_decoder_args,
    )

    if not self.use_obj_ptrs_in_encoder:
        # No object pointers are used: SAM output tokens pass through untouched
        self.obj_ptr_proj = torch.nn.Identity()
    else:
        # Linear projection turning SAM output tokens into object pointers ...
        self.obj_ptr_proj = torch.nn.Linear(self.hidden_dim, self.hidden_dim)
        if self.use_mlp_for_obj_ptr_proj:
            # ... swapped for a 3-layer MLP when requested
            self.obj_ptr_proj = MLP(self.hidden_dim, self.hidden_dim, self.hidden_dim, 3)

    # Optional linear projection on the temporal positional encoding of object pointers, meant
    # to avoid potential interference with the spatial positional encoding
    self.obj_ptr_tpos_proj = (
        torch.nn.Linear(self.hidden_dim, self.mem_dim) if self.proj_tpos_enc_in_obj_ptrs else torch.nn.Identity()
    )
|
|
246
|
-
|
|
247
|
-
def _forward_sam_heads(
    self,
    backbone_features,
    point_inputs=None,
    mask_inputs=None,
    high_res_features=None,
    multimask_output=False,
):
    """
    Forward SAM prompt encoders and mask heads.

    Args:
        backbone_features (torch.Tensor): Image features with shape (B, C, H, W).
        point_inputs (Dict[str, torch.Tensor] | None): Dictionary containing point prompts.
            'point_coords': Tensor of shape (B, P, 2) with float32 dtype, containing absolute
                pixel-unit coordinates in (x, y) format for P input points.
            'point_labels': Tensor of shape (B, P) with int32 dtype, where 1 means positive clicks,
                0 means negative clicks, and -1 means padding.
        mask_inputs (torch.Tensor | None): Mask of shape (B, 1, H*16, W*16), float or bool, with the
            same spatial size as the image.
        high_res_features (List[torch.Tensor] | None): List of two feature maps with shapes
            (B, C, 4*H, 4*W) and (B, C, 2*H, 2*W) respectively, used as high-resolution feature maps
            for SAM decoder.
        multimask_output (bool): If True, output 3 candidate masks and their IoU estimates; if False,
            output only 1 mask and its IoU estimate.

    Returns:
        (Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]):
            low_res_multimasks: Tensor of shape (B, M, H*4, W*4) with SAM output mask logits.
            high_res_multimasks: Tensor of shape (B, M, H*16, W*16) with upsampled mask logits.
            ious: Tensor of shape (B, M) with estimated IoU for each output mask.
            low_res_masks: Tensor of shape (B, 1, H*4, W*4) with best low-resolution mask.
            high_res_masks: Tensor of shape (B, 1, H*16, W*16) with best high-resolution mask.
            obj_ptr: Tensor of shape (B, C) with object pointer vector for the output mask.
            object_score_logits: Object score logits as returned by the mask decoder. They are
                broadcast against `obj_ptr` below, so presumably shaped (B, 1) -- TODO confirm
                against `MaskDecoder`.

        Where M is 3 if multimask_output=True, and 1 if multimask_output=False.

    Raises:
        AssertionError: If the backbone feature, point or mask shapes do not match the configured
            embedding size / batch size.

    Examples:
        >>> backbone_features = torch.rand(1, 256, 32, 32)
        >>> point_inputs = {"point_coords": torch.rand(1, 2, 2), "point_labels": torch.tensor([[1, 0]])}
        >>> mask_inputs = torch.rand(1, 1, 512, 512)
        >>> results = model._forward_sam_heads(backbone_features, point_inputs, mask_inputs)
        >>> low_res_multimasks, high_res_multimasks, ious, low_res_masks, high_res_masks, obj_ptr, object_score_logits = results
    """
    B = backbone_features.size(0)
    device = backbone_features.device
    assert backbone_features.size(1) == self.sam_prompt_embed_dim
    assert backbone_features.size(2) == self.sam_image_embedding_size
    assert backbone_features.size(3) == self.sam_image_embedding_size

    # a) Handle point prompts
    if point_inputs is not None:
        sam_point_coords = point_inputs["point_coords"]
        sam_point_labels = point_inputs["point_labels"]
        assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B
    else:
        # If no points are provided, pad with an empty point (with label -1)
        sam_point_coords = torch.zeros(B, 1, 2, device=device)
        sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device)

    # b) Handle mask prompts
    if mask_inputs is not None:
        # If mask_inputs is provided, downsize it into low-res mask input if needed
        # and feed it as a dense mask prompt into the SAM mask encoder
        assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1)
        if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size:
            sam_mask_prompt = F.interpolate(
                mask_inputs.float(),
                size=self.sam_prompt_encoder.mask_input_size,
                align_corners=False,
                mode="bilinear",
                antialias=True,  # use antialias for downsampling
            )
        else:
            sam_mask_prompt = mask_inputs
    else:
        # Otherwise, simply feed None (and SAM's prompt encoder will add
        # a learned `no_mask_embed` to indicate no mask input in this case).
        sam_mask_prompt = None

    sparse_embeddings, dense_embeddings = self.sam_prompt_encoder(
        points=(sam_point_coords, sam_point_labels),
        boxes=None,
        masks=sam_mask_prompt,
    )
    (
        low_res_multimasks,
        ious,
        sam_output_tokens,
        object_score_logits,
    ) = self.sam_mask_decoder(
        image_embeddings=backbone_features,
        image_pe=self.sam_prompt_encoder.get_dense_pe(),
        sparse_prompt_embeddings=sparse_embeddings,
        dense_prompt_embeddings=dense_embeddings,
        multimask_output=multimask_output,
        repeat_image=False,  # the image is already batched
        high_res_features=high_res_features,
    )
    if self.pred_obj_scores:
        is_obj_appearing = object_score_logits > 0

        # Mask used for spatial memories is always a *hard* choice between obj and no obj,
        # consistent with the actual mask prediction
        low_res_multimasks = torch.where(
            is_obj_appearing[:, None, None],
            low_res_multimasks,
            NO_OBJ_SCORE,
        )

    # convert masks from possibly bfloat16 (or float16) to float32
    # (older PyTorch versions before 2.1 don't support `interpolate` on bf16)
    low_res_multimasks = low_res_multimasks.float()
    high_res_multimasks = F.interpolate(
        low_res_multimasks,
        size=(self.image_size, self.image_size),
        mode="bilinear",
        align_corners=False,
    )

    sam_output_token = sam_output_tokens[:, 0]
    if multimask_output:
        # take the best mask prediction (with the highest IoU estimation)
        best_iou_inds = torch.argmax(ious, dim=-1)
        batch_inds = torch.arange(B, device=device)
        low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
        high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
        if sam_output_tokens.size(1) > 1:
            # the decoder returned one token per candidate mask; follow the selected mask
            sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]
    else:
        low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks

    # Extract object pointer from the SAM output token (with occlusion handling)
    obj_ptr = self.obj_ptr_proj(sam_output_token)
    if self.pred_obj_scores:
        # Allow *soft* no obj ptr, unlike for masks
        if self.soft_no_obj_ptr:
            # Only hard possible with gt.
            # `teacher_force_obj_scores_for_mem` is not set by this class's `__init__`, and
            # `torch.nn.Module` raises AttributeError for unknown attributes, so read it with a
            # False default instead of crashing whenever `soft_no_obj_ptr` is enabled.
            assert not getattr(self, "teacher_force_obj_scores_for_mem", False)
            lambda_is_obj_appearing = object_score_logits.sigmoid()
        else:
            lambda_is_obj_appearing = is_obj_appearing.float()

        if self.fixed_no_obj_ptr:
            obj_ptr = lambda_is_obj_appearing * obj_ptr
        # NOTE(review): `self.no_obj_ptr` is only created when `use_obj_ptrs_in_encoder` is also
        # enabled -- confirm this path is never reached with pointers disabled.
        obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr

    return (
        low_res_multimasks,
        high_res_multimasks,
        ious,
        low_res_masks,
        high_res_masks,
        obj_ptr,
        object_score_logits,
    )
|
|
404
|
-
|
|
405
|
-
def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs):
|
|
406
|
-
"""Processes mask inputs to generate output mask logits and object pointers without using SAM."""
|
|
407
|
-
# Use -10/+10 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid).
|
|
408
|
-
out_scale, out_bias = 20.0, -10.0 # sigmoid(-10.0)=4.5398e-05
|
|
409
|
-
mask_inputs_float = mask_inputs.float()
|
|
410
|
-
high_res_masks = mask_inputs_float * out_scale + out_bias
|
|
411
|
-
low_res_masks = F.interpolate(
|
|
412
|
-
high_res_masks,
|
|
413
|
-
size=(high_res_masks.size(-2) // 4, high_res_masks.size(-1) // 4),
|
|
414
|
-
align_corners=False,
|
|
415
|
-
mode="bilinear",
|
|
416
|
-
antialias=True, # use antialias for downsampling
|
|
417
|
-
)
|
|
418
|
-
# a dummy IoU prediction of all 1's under mask input
|
|
419
|
-
ious = mask_inputs.new_ones(mask_inputs.size(0), 1).float()
|
|
420
|
-
if not self.use_obj_ptrs_in_encoder:
|
|
421
|
-
# all zeros as a dummy object pointer (of shape [B, C])
|
|
422
|
-
obj_ptr = torch.zeros(mask_inputs.size(0), self.hidden_dim, device=mask_inputs.device)
|
|
423
|
-
else:
|
|
424
|
-
# produce an object pointer using the SAM decoder from the mask input
|
|
425
|
-
_, _, _, _, _, obj_ptr, _ = self._forward_sam_heads(
|
|
426
|
-
backbone_features=backbone_features,
|
|
427
|
-
mask_inputs=self.mask_downsample(mask_inputs_float),
|
|
428
|
-
high_res_features=high_res_features,
|
|
429
|
-
)
|
|
430
|
-
# In this method, we are treating mask_input as output, e.g. using it directly to create spatial mem;
|
|
431
|
-
# Below, we follow the same design axiom to use mask_input to decide if obj appears or not instead of relying
|
|
432
|
-
# on the object_scores from the SAM decoder.
|
|
433
|
-
is_obj_appearing = torch.any(mask_inputs.flatten(1).float() > 0.0, dim=1)
|
|
434
|
-
is_obj_appearing = is_obj_appearing[..., None]
|
|
435
|
-
lambda_is_obj_appearing = is_obj_appearing.float()
|
|
436
|
-
object_score_logits = out_scale * lambda_is_obj_appearing + out_bias
|
|
437
|
-
if self.pred_obj_scores:
|
|
438
|
-
if self.fixed_no_obj_ptr:
|
|
439
|
-
obj_ptr = lambda_is_obj_appearing * obj_ptr
|
|
440
|
-
obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr
|
|
441
|
-
|
|
442
|
-
return (
|
|
443
|
-
low_res_masks,
|
|
444
|
-
high_res_masks,
|
|
445
|
-
ious,
|
|
446
|
-
low_res_masks,
|
|
447
|
-
high_res_masks,
|
|
448
|
-
obj_ptr,
|
|
449
|
-
object_score_logits,
|
|
450
|
-
)
|
|
451
|
-
|
|
452
|
-
def forward_image(self, img_batch: torch.Tensor):
    """Runs the image encoder on a batch and, in high-res mode, pre-projects FPN levels 0 and 1."""
    features = self.image_encoder(img_batch)
    if not self.use_high_res_features_in_sam:
        return features

    # Project level 0 and level 1 features with the SAM decoder's convs once here,
    # so the projection is not repeated on every SAM click
    decoder = self.sam_mask_decoder
    fpn_levels = features["backbone_fpn"]
    fpn_levels[0] = decoder.conv_s0(fpn_levels[0])
    fpn_levels[1] = decoder.conv_s1(fpn_levels[1])
    return features
|
|
461
|
-
|
|
462
|
-
def _prepare_backbone_features(self, backbone_out):
|
|
463
|
-
"""Prepare and flatten visual features from the image backbone output."""
|
|
464
|
-
backbone_out = backbone_out.copy()
|
|
465
|
-
assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"])
|
|
466
|
-
assert len(backbone_out["backbone_fpn"]) >= self.num_feature_levels
|
|
467
|
-
|
|
468
|
-
feature_maps = backbone_out["backbone_fpn"][-self.num_feature_levels :]
|
|
469
|
-
vision_pos_embeds = backbone_out["vision_pos_enc"][-self.num_feature_levels :]
|
|
470
|
-
|
|
471
|
-
feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds]
|
|
472
|
-
# flatten NxCxHxW to HWxNxC
|
|
473
|
-
vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps]
|
|
474
|
-
vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds]
|
|
475
|
-
|
|
476
|
-
return backbone_out, vision_feats, vision_pos_embeds, feat_sizes
|
|
477
|
-
|
|
478
|
-
def _prepare_memory_conditioned_features(
    self,
    frame_idx,
    is_init_cond_frame,
    current_vision_feats,
    current_vision_pos_embeds,
    feat_sizes,
    output_dict,
    num_frames,
    track_in_reverse=False,  # tracking in reverse time order (for demo usage)
):
    """
    Fuse the current frame's top-level visual features with memories of previously processed frames.

    Args:
        frame_idx (int): Index of the frame being processed.
        is_init_cond_frame (bool): If True, the frame is encoded without looking up any previous memory.
        current_vision_feats (list): Flattened per-level features; only the last entry is read for sizes and
            device, and it is reshaped from (HW, B, C) to (B, C, H, W).
        current_vision_pos_embeds (list): Positional encodings passed to `self.memory_attention` as `curr_pos`.
        feat_sizes (list): Per-level (H, W) sizes; only the last entry is used.
        output_dict (dict): Holds "cond_frame_outputs" and "non_cond_frame_outputs", both keyed by frame index.
            Entries provide "maskmem_features", "maskmem_pos_enc" and "obj_ptr".
        num_frames (int): Number of frames; bounds how many object pointers are gathered.
        track_in_reverse (bool): If True, "previous" frames are the ones with larger indices.

    Returns:
        (torch.Tensor): Features of the current frame viewed as (B, C, H, W), conditioned on memory unless
            memory is disabled (`self.num_maskmem == 0`) or skipped for an initial conditioning frame.
    """
    B = current_vision_feats[-1].size(1)  # batch size on this frame
    C = self.hidden_dim
    H, W = feat_sizes[-1]  # top-level (lowest-resolution) feature size
    device = current_vision_feats[-1].device
    # The case of `self.num_maskmem == 0` below is primarily used for reproducing SAM on images.
    # In this case, we skip the fusion with any memory.
    if self.num_maskmem == 0:  # Disable memory and skip fusion
        pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
        return pix_feat

    num_obj_ptr_tokens = 0
    # Step 1: condition the visual features of the current frame on previous memories
    if not is_init_cond_frame:
        # Retrieve the memories encoded with the maskmem backbone
        to_cat_memory, to_cat_memory_pos_embed = [], []
        # Add conditioning frames' outputs first (all cond frames have t_pos=0
        # when getting the temporal positional embedding below)
        assert len(output_dict["cond_frame_outputs"]) > 0
        # Select a maximum number of temporally closest cond frames for cross attention
        cond_outputs = output_dict["cond_frame_outputs"]
        selected_cond_outputs, unselected_cond_outputs = select_closest_cond_frames(
            frame_idx, cond_outputs, self.max_cond_frames_in_attn
        )
        t_pos_and_prevs = [(0, out) for out in selected_cond_outputs.values()]
        # Add last (self.num_maskmem - 1) frames before current frame for non-conditioning memory;
        # the earliest one has t_pos=1 and the latest one has t_pos=self.num_maskmem-1.
        # We also allow taking the memory frames non-consecutively (with r>1), in which case
        # we take (self.num_maskmem - 2) frames among every r-th frame plus the last frame.
        r = self.memory_temporal_stride_for_eval
        for t_pos in range(1, self.num_maskmem):
            t_rel = self.num_maskmem - t_pos  # how many frames before current frame
            if t_rel == 1:
                # for t_rel == 1, we take the last frame (regardless of r)
                if not track_in_reverse:
                    # the frame immediately before this frame (i.e. frame_idx - 1)
                    prev_frame_idx = frame_idx - t_rel
                else:
                    # the frame immediately after this frame (i.e. frame_idx + 1)
                    prev_frame_idx = frame_idx + t_rel
            else:
                # for t_rel >= 2, we take the memory frame from every r-th frame
                if not track_in_reverse:
                    # first find the nearest frame among every r-th frame before this frame
                    # for r=1, this would be (frame_idx - 2)
                    prev_frame_idx = ((frame_idx - 2) // r) * r
                    # then seek further among every r-th frame
                    prev_frame_idx = prev_frame_idx - (t_rel - 2) * r
                else:
                    # first find the nearest frame among every r-th frame after this frame
                    # for r=1, this would be (frame_idx + 2)
                    prev_frame_idx = -(-(frame_idx + 2) // r) * r
                    # then seek further among every r-th frame
                    prev_frame_idx = prev_frame_idx + (t_rel - 2) * r
            out = output_dict["non_cond_frame_outputs"].get(prev_frame_idx, None)
            if out is None:
                # If an unselected conditioning frame is among the last (self.num_maskmem - 1)
                # frames, we still attend to it as if it's a non-conditioning frame.
                out = unselected_cond_outputs.get(prev_frame_idx, None)
            t_pos_and_prevs.append((t_pos, out))

        for t_pos, prev in t_pos_and_prevs:
            if prev is None:
                continue  # skip padding frames
            # "maskmem_features" might have been offloaded to CPU in demo use cases, so bring it back to
            # the device of the current frame's features. Using `device` rather than a hard-coded `.cuda()`
            # keeps CPU-only (and other non-CUDA) inference working; it is a no-op if already on `device`.
            feats = prev["maskmem_features"].to(device=device, non_blocking=True)
            to_cat_memory.append(feats.flatten(2).permute(2, 0, 1))
            # Spatial positional encoding (it might have been offloaded to CPU in eval)
            maskmem_enc = prev["maskmem_pos_enc"][-1].to(device=device)
            maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1)
            # Temporal positional encoding
            maskmem_enc = maskmem_enc + self.maskmem_tpos_enc[self.num_maskmem - t_pos - 1]
            to_cat_memory_pos_embed.append(maskmem_enc)

        # Construct the list of past object pointers
        if self.use_obj_ptrs_in_encoder:
            max_obj_ptrs_in_encoder = min(num_frames, self.max_obj_ptrs_in_encoder)
            # First add those object pointers from selected conditioning frames
            # (optionally, only include object pointers in the past during evaluation)
            if not self.training and self.only_obj_ptrs_in_the_past_for_eval:
                ptr_cond_outputs = {
                    t: out
                    for t, out in selected_cond_outputs.items()
                    if (t >= frame_idx if track_in_reverse else t <= frame_idx)
                }
            else:
                ptr_cond_outputs = selected_cond_outputs
            pos_and_ptrs = [
                # Temporal pos encoding contains how far away each pointer is from current frame
                (abs(frame_idx - t), out["obj_ptr"])
                for t, out in ptr_cond_outputs.items()
            ]
            # Add up to (max_obj_ptrs_in_encoder - 1) non-conditioning frames before current frame
            for t_diff in range(1, max_obj_ptrs_in_encoder):
                t = frame_idx + t_diff if track_in_reverse else frame_idx - t_diff
                if t < 0 or (num_frames is not None and t >= num_frames):
                    break
                out = output_dict["non_cond_frame_outputs"].get(t, unselected_cond_outputs.get(t, None))
                if out is not None:
                    pos_and_ptrs.append((t_diff, out["obj_ptr"]))
            # If we have at least one object pointer, add them to the cross attention
            if len(pos_and_ptrs) > 0:
                pos_list, ptrs_list = zip(*pos_and_ptrs)
                # stack object pointers along dim=0 into [ptr_seq_len, B, C] shape
                obj_ptrs = torch.stack(ptrs_list, dim=0)
                # a temporal positional embedding based on how far each object pointer is from
                # the current frame (sine embedding normalized by the max pointer num).
                if self.add_tpos_enc_to_obj_ptrs:
                    t_diff_max = max_obj_ptrs_in_encoder - 1
                    tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim
                    obj_pos = torch.tensor(pos_list, device=device)
                    obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim)
                    obj_pos = self.obj_ptr_tpos_proj(obj_pos)
                    obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim)
                else:
                    obj_pos = obj_ptrs.new_zeros(len(pos_list), B, self.mem_dim)
                if self.mem_dim < C:
                    # split a pointer into (C // self.mem_dim) tokens for self.mem_dim < C
                    obj_ptrs = obj_ptrs.reshape(-1, B, C // self.mem_dim, self.mem_dim)
                    obj_ptrs = obj_ptrs.permute(0, 2, 1, 3).flatten(0, 1)
                    obj_pos = obj_pos.repeat_interleave(C // self.mem_dim, dim=0)
                to_cat_memory.append(obj_ptrs)
                to_cat_memory_pos_embed.append(obj_pos)
                num_obj_ptr_tokens = obj_ptrs.shape[0]
            else:
                num_obj_ptr_tokens = 0
    else:
        # for initial conditioning frames, encode them without using any previous memory
        if self.directly_add_no_mem_embed:
            # directly add no-mem embedding (instead of using the transformer encoder)
            pix_feat_with_mem = current_vision_feats[-1] + self.no_mem_embed
            pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
            return pix_feat_with_mem

        # Use a dummy token on the first frame (to avoid empty memory input to transformer encoder)
        to_cat_memory = [self.no_mem_embed.expand(1, B, self.mem_dim)]
        to_cat_memory_pos_embed = [self.no_mem_pos_enc.expand(1, B, self.mem_dim)]

    # Step 2: Concatenate the memories and forward through the transformer encoder
    memory = torch.cat(to_cat_memory, dim=0)
    memory_pos_embed = torch.cat(to_cat_memory_pos_embed, dim=0)

    pix_feat_with_mem = self.memory_attention(
        curr=current_vision_feats,
        curr_pos=current_vision_pos_embeds,
        memory=memory,
        memory_pos=memory_pos_embed,
        num_obj_ptr_tokens=num_obj_ptr_tokens,
    )
    # reshape the output (HW)BC => BCHW
    pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
    return pix_feat_with_mem
|
|
642
|
-
|
|
643
|
-
def _encode_new_memory(
    self,
    current_vision_feats,
    feat_sizes,
    pred_masks_high_res,
    is_mask_from_pts,
):
    """
    Turn the current frame's top-level features and predicted masks into a memory entry.

    Args:
        current_vision_feats (list): Flattened per-level features; the last entry is reshaped (HW)BC => BCHW.
        feat_sizes (list): Per-level (H, W) sizes; the last entry is used for the reshape.
        pred_masks_high_res (torch.Tensor): Raw mask logits that get binarized or passed through a sigmoid.
        is_mask_from_pts (bool): Whether the masks came from point prompts; together with
            `self.binarize_mask_from_pts_for_mem_enc` it enables hard binarization outside training.

    Returns:
        (tuple): The "vision_features" and "vision_pos_enc" entries of the memory encoder output.
    """
    top_feat = current_vision_feats[-1]
    batch = top_feat.size(1)  # batch size on this frame
    height, width = feat_sizes[-1]  # top-level (lowest-resolution) feature size
    pix_feat = top_feat.permute(1, 2, 0).view(batch, self.hidden_dim, height, width)

    in_eval = not self.training
    if self.non_overlap_masks_for_mem_enc and in_eval:
        # Applied along the batch dimension, so it is meant for eval where all objects
        # come from the same video under batch size 1.
        pred_masks_high_res = self._apply_non_overlapping_constraints(pred_masks_high_res)

    hard_mask = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts and in_eval
    # Either threshold the logits at 0 or squash them into (0, 1).
    mask_for_mem = (pred_masks_high_res > 0).float() if hard_mask else torch.sigmoid(pred_masks_high_res)

    # Optional affine rescaling of the mask values before they are encoded.
    scale = self.sigmoid_scale_for_mem_enc
    if scale != 1.0:
        mask_for_mem = mask_for_mem * scale
    bias = self.sigmoid_bias_for_mem_enc
    if bias != 0.0:
        mask_for_mem = mask_for_mem + bias

    # The sigmoid (or binarization) has already been applied above.
    encoded = self.memory_encoder(pix_feat, mask_for_mem, skip_mask_sigmoid=True)
    return encoded["vision_features"], encoded["vision_pos_enc"]
|
|
682
|
-
|
|
683
|
-
def track_step(
    self,
    frame_idx,
    is_init_cond_frame,
    current_vision_feats,
    current_vision_pos_embeds,
    feat_sizes,
    point_inputs,
    mask_inputs,
    output_dict,
    num_frames,
    track_in_reverse=False,  # tracking in reverse time order (for demo usage)
    # Whether to run the memory encoder on the predicted masks. It can be skipped with
    # `run_mem_encoder=False`, e.g. when `track_step` is called once per user click in a demo and
    # the memory is only encoded after the clicks are finalized, or for SAM-style training on
    # static images where no memory is needed.
    run_mem_encoder=True,
    # The previously predicted SAM mask logits (which can be fed together with new clicks in demo).
    prev_sam_mask_logits=None,
):
    """
    Run one tracking step for a frame: predict masks and, optionally, encode them into new memory.

    Returns:
        (dict): Contains "point_inputs", "mask_inputs", "pred_masks", "pred_masks_high_res", "obj_ptr",
            "maskmem_features" and "maskmem_pos_enc" (the last two are None when memory is not encoded).
    """
    current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs}

    # Every level except the last feeds the SAM head as a high-res map, reshaped (HW)BC => BCHW.
    high_res_features = None
    if len(current_vision_feats) > 1:
        high_res_features = []
        for feat, size in zip(current_vision_feats[:-1], feat_sizes[:-1]):
            high_res_features.append(feat.permute(1, 2, 0).view(feat.size(1), feat.size(2), *size))

    if mask_inputs is not None and self.use_mask_input_as_output_without_sam:
        # Output the mask input directly (treated like a GT mask), bypassing the SAM prompt
        # encoder + mask decoder.
        pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(-1, self.hidden_dim, *feat_sizes[-1])
        sam_outputs = self._use_mask_as_output(pix_feat, high_res_features, mask_inputs)
    else:
        # Fuse the visual features with what is stored in the memory bank.
        pix_feat_with_mem = self._prepare_memory_conditioned_features(
            frame_idx=frame_idx,
            is_init_cond_frame=is_init_cond_frame,
            current_vision_feats=current_vision_feats[-1:],
            current_vision_pos_embeds=current_vision_pos_embeds[-1:],
            feat_sizes=feat_sizes[-1:],
            output_dict=output_dict,
            num_frames=num_frames,
            track_in_reverse=track_in_reverse,
        )
        # Earlier low-res SAM logits may be supplied as the mask prompt (e.g. from a previous
        # interaction in a demo); real `mask_inputs` never coexist with them on this branch.
        sam_mask_prompt = mask_inputs
        if prev_sam_mask_logits is not None:
            assert point_inputs is not None and mask_inputs is None
            sam_mask_prompt = prev_sam_mask_logits
        sam_outputs = self._forward_sam_heads(
            backbone_features=pix_feat_with_mem,
            point_inputs=point_inputs,
            mask_inputs=sam_mask_prompt,
            high_res_features=high_res_features,
            multimask_output=self._use_multimask(is_init_cond_frame, point_inputs),
        )

    # Only three of the seven SAM outputs are kept.
    _, _, _, low_res_masks, high_res_masks, obj_ptr, _ = sam_outputs
    current_out["pred_masks"] = low_res_masks
    current_out["pred_masks_high_res"] = high_res_masks
    current_out["obj_ptr"] = obj_ptr

    # Encode the prediction into a memory feature usable by future frames, if requested.
    maskmem_features, maskmem_pos_enc = None, None
    if run_mem_encoder and self.num_maskmem > 0:
        maskmem_features, maskmem_pos_enc = self._encode_new_memory(
            current_vision_feats=current_vision_feats,
            feat_sizes=feat_sizes,
            pred_masks_high_res=high_res_masks,
            is_mask_from_pts=(point_inputs is not None),
        )
    current_out["maskmem_features"] = maskmem_features
    current_out["maskmem_pos_enc"] = maskmem_pos_enc

    return current_out
|
|
778
|
-
|
|
779
|
-
def _use_multimask(self, is_init_cond_frame, point_inputs):
    """
    Decide whether the SAM head should produce multiple mask candidates.

    Multimask output requires all of: it is enabled via `self.multimask_output_in_sam`; the frame is an
    initial conditioning frame or `self.multimask_output_for_tracking` is set; and the number of points
    (size of dim 1 of `point_inputs["point_labels"]`, or 0 without point inputs) lies within
    [`self.multimask_min_pt_num`, `self.multimask_max_pt_num`].
    """
    num_pts = point_inputs["point_labels"].size(1) if point_inputs is not None else 0

    enabled = self.multimask_output_in_sam
    if not enabled:
        return enabled

    frame_allows = is_init_cond_frame or self.multimask_output_for_tracking
    if not frame_allows:
        return frame_allows

    return self.multimask_min_pt_num <= num_pts <= self.multimask_max_pt_num
|
|
788
|
-
|
|
789
|
-
def _apply_non_overlapping_constraints(self, pred_masks):
    """
    Keep, at every location, only the object (index along dim 0) with the highest score.

    Scores of all other objects at that location are clamped to at most -10.0 so that foreground regions
    do not overlap (sigmoid(-10.0) = 4.5398e-05). A single-object input is returned unchanged.
    """
    num_objs = pred_masks.size(0)
    if num_objs == 1:
        return pred_masks

    # Index of the winning object at each location, compared against each slice's own index.
    winner = pred_masks.argmax(dim=0, keepdim=True)
    obj_ids = torch.arange(num_objs, device=pred_masks.device).view(-1, 1, 1, 1)
    suppressed = pred_masks.clamp(max=-10.0)
    return torch.where(winner == obj_ids, pred_masks, suppressed)
|