frontveg-0.1.dev1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. frontveg/__init__.py +11 -0
  2. frontveg/_tests/__init__.py +0 -0
  3. frontveg/_tests/test_widget.py +66 -0
  4. frontveg/_version.py +21 -0
  5. frontveg/_widget.py +132 -0
  6. frontveg/napari.yaml +14 -0
  7. frontveg/utils.py +95 -0
  8. frontveg-0.1.dev1.dist-info/METADATA +143 -0
  9. frontveg-0.1.dev1.dist-info/RECORD +44 -0
  10. frontveg-0.1.dev1.dist-info/WHEEL +5 -0
  11. frontveg-0.1.dev1.dist-info/entry_points.txt +2 -0
  12. frontveg-0.1.dev1.dist-info/licenses/LICENSE +28 -0
  13. frontveg-0.1.dev1.dist-info/top_level.txt +2 -0
  14. sam2/__init__.py +11 -0
  15. sam2/automatic_mask_generator.py +454 -0
  16. sam2/build_sam.py +167 -0
  17. sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
  18. sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
  19. sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
  20. sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
  21. sam2/modeling/__init__.py +5 -0
  22. sam2/modeling/backbones/__init__.py +5 -0
  23. sam2/modeling/backbones/hieradet.py +317 -0
  24. sam2/modeling/backbones/image_encoder.py +134 -0
  25. sam2/modeling/backbones/utils.py +95 -0
  26. sam2/modeling/memory_attention.py +169 -0
  27. sam2/modeling/memory_encoder.py +181 -0
  28. sam2/modeling/position_encoding.py +221 -0
  29. sam2/modeling/sam/__init__.py +5 -0
  30. sam2/modeling/sam/mask_decoder.py +295 -0
  31. sam2/modeling/sam/prompt_encoder.py +182 -0
  32. sam2/modeling/sam/transformer.py +360 -0
  33. sam2/modeling/sam2_base.py +907 -0
  34. sam2/modeling/sam2_utils.py +323 -0
  35. sam2/sam2_hiera_b+.yaml +1 -0
  36. sam2/sam2_hiera_l.yaml +1 -0
  37. sam2/sam2_hiera_s.yaml +1 -0
  38. sam2/sam2_hiera_t.yaml +1 -0
  39. sam2/sam2_image_predictor.py +466 -0
  40. sam2/sam2_video_predictor.py +1172 -0
  41. sam2/utils/__init__.py +5 -0
  42. sam2/utils/amg.py +348 -0
  43. sam2/utils/misc.py +349 -0
  44. sam2/utils/transforms.py +118 -0
sam2/sam2_video_predictor.py
@@ -0,0 +1,1172 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import warnings
8
+ from collections import OrderedDict
9
+
10
+ import torch
11
+
12
+ from tqdm import tqdm
13
+
14
+ from sam2.modeling.sam2_base import NO_OBJ_SCORE, SAM2Base
15
+ from sam2.utils.misc import concat_points, fill_holes_in_mask_scores, load_video_frames
16
+
17
+
18
+ class SAM2VideoPredictor(SAM2Base):
19
+ """The predictor class to handle user interactions and manage inference states."""
20
+
21
+ def __init__(
22
+ self,
23
+ fill_hole_area=0,
24
+ # whether to apply non-overlapping constraints on the output object masks
25
+ non_overlap_masks=False,
26
+ # whether to clear non-conditioning memory of the surrounding frames (which may contain outdated information) after adding correction clicks;
+ # note that this only applies to *single-object tracking* unless `clear_non_cond_mem_for_multi_obj` is also set to True
+ clear_non_cond_mem_around_input=False,
+ # whether to also clear non-conditioning memory of the surrounding frames in multi-object tracking (only effective when `clear_non_cond_mem_around_input` is True)
+ clear_non_cond_mem_for_multi_obj=False,
+ # if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click;
+ # if `add_all_frames_to_correct_as_cond` is False, we keep the conditioning frame list restricted to the initial conditioning frames
33
+ add_all_frames_to_correct_as_cond=False,
34
+ **kwargs,
35
+ ):
36
+ super().__init__(**kwargs)
37
+ self.fill_hole_area = fill_hole_area
38
+ self.non_overlap_masks = non_overlap_masks
39
+ self.clear_non_cond_mem_around_input = clear_non_cond_mem_around_input
40
+ self.clear_non_cond_mem_for_multi_obj = clear_non_cond_mem_for_multi_obj
41
+ self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond
42
+
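These constructor flags are plain keyword arguments, so they can be overridden when a predictor is built, for example through `from_pretrained` defined further down in this file. A minimal sketch; the repo id and flag values are illustrative assumptions, not recommendations:

```python
# Sketch only: the model id and flag values below are illustrative assumptions.
from sam2.sam2_video_predictor import SAM2VideoPredictor

predictor = SAM2VideoPredictor.from_pretrained(
    "facebook/sam2-hiera-large",           # example Hugging Face repo id
    fill_hole_area=8,                       # fill holes up to this area in the output masks
    non_overlap_masks=True,                 # enforce non-overlapping masks in the final output
    clear_non_cond_mem_around_input=True,   # drop possibly outdated memory around corrected frames
)
```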
43
+ @torch.inference_mode()
44
+ def init_state(
45
+ self,
46
+ video_path,
47
+ offload_video_to_cpu=False,
48
+ offload_state_to_cpu=False,
49
+ async_loading_frames=False,
50
+ ):
51
+ """Initialize an inference state."""
52
+ compute_device = self.device # device of the model
53
+ images, video_height, video_width = load_video_frames(
54
+ video_path=video_path,
55
+ image_size=self.image_size,
56
+ offload_video_to_cpu=offload_video_to_cpu,
57
+ async_loading_frames=async_loading_frames,
58
+ compute_device=compute_device,
59
+ )
60
+ inference_state = {}
61
+ inference_state["images"] = images
62
+ inference_state["num_frames"] = len(images)
63
+ # whether to offload the video frames to CPU memory
64
+ # turning on this option saves the GPU memory with only a very small overhead
65
+ inference_state["offload_video_to_cpu"] = offload_video_to_cpu
66
+ # whether to offload the inference state to CPU memory
67
+ # turning on this option saves the GPU memory at the cost of a lower tracking fps
68
+ # (e.g. in a test case of 768x768 model, fps dropped from 27 to 24 when tracking one object
69
+ # and from 24 to 21 when tracking two objects)
70
+ inference_state["offload_state_to_cpu"] = offload_state_to_cpu
71
+ # the original video height and width, used for resizing final output scores
72
+ inference_state["video_height"] = video_height
73
+ inference_state["video_width"] = video_width
74
+ inference_state["device"] = compute_device
75
+ if offload_state_to_cpu:
76
+ inference_state["storage_device"] = torch.device("cpu")
77
+ else:
78
+ inference_state["storage_device"] = compute_device
79
+ # inputs on each frame
80
+ inference_state["point_inputs_per_obj"] = {}
81
+ inference_state["mask_inputs_per_obj"] = {}
82
+ # visual features on a small number of recently visited frames for quick interactions
83
+ inference_state["cached_features"] = {}
84
+ # values that don't change across frames (so we only need to hold one copy of them)
85
+ inference_state["constants"] = {}
86
+ # mapping between client-side object id and model-side object index
87
+ inference_state["obj_id_to_idx"] = OrderedDict()
88
+ inference_state["obj_idx_to_id"] = OrderedDict()
89
+ inference_state["obj_ids"] = []
90
+ # A storage to hold the model's tracking results and states on each frame
91
+ inference_state["output_dict"] = {
92
+ "cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
93
+ "non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
94
+ }
95
+ # Slice (view) of each object's tracking results, sharing the same memory with "output_dict"
96
+ inference_state["output_dict_per_obj"] = {}
97
+ # A temporary storage to hold new outputs when the user interacts with a frame
+ # to add clicks or a mask (it's merged into "output_dict" before propagation starts)
99
+ inference_state["temp_output_dict_per_obj"] = {}
100
+ # Frames that already hold consolidated outputs from click or mask inputs
101
+ # (we directly use their consolidated outputs during tracking)
102
+ inference_state["consolidated_frame_inds"] = {
103
+ "cond_frame_outputs": set(), # set containing frame indices
104
+ "non_cond_frame_outputs": set(), # set containing frame indices
105
+ }
106
+ # metadata for each tracking frame (e.g. which direction it's tracked)
107
+ inference_state["tracking_has_started"] = False
108
+ inference_state["frames_already_tracked"] = {}
109
+ # Warm up the visual backbone and cache the image feature on frame 0
110
+ self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
111
+ return inference_state
112
+
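A minimal sketch of calling `init_state` with the offloading options described in the comments above; `predictor` is assumed to be an already-built `SAM2VideoPredictor`, and the video path is a placeholder for whatever `load_video_frames` accepts:

```python
# Sketch only: `predictor` is assumed to be a SAM2VideoPredictor instance and
# "./videos/example" a placeholder for input that `load_video_frames` can read.
inference_state = predictor.init_state(
    video_path="./videos/example",
    offload_video_to_cpu=True,    # keep decoded frames in CPU RAM (small overhead)
    offload_state_to_cpu=True,    # keep per-frame outputs on CPU (slightly lower tracking FPS)
)
print(inference_state["num_frames"])
print(inference_state["video_height"], inference_state["video_width"])
```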
113
+ @classmethod
114
+ def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2VideoPredictor":
115
+ """
116
+ Load a pretrained model from the Hugging Face hub.
117
+
118
+ Arguments:
119
+ model_id (str): The Hugging Face repository ID.
120
+ **kwargs: Additional arguments to pass to the model constructor.
121
+
122
+ Returns:
123
+ (SAM2VideoPredictor): The loaded model.
124
+ """
125
+ from sam2.build_sam import build_sam2_video_predictor_hf
126
+
127
+ sam_model = build_sam2_video_predictor_hf(model_id, **kwargs)
128
+ return sam_model
129
+
130
+ def _obj_id_to_idx(self, inference_state, obj_id):
131
+ """Map client-side object id to model-side object index."""
132
+ obj_idx = inference_state["obj_id_to_idx"].get(obj_id, None)
133
+ if obj_idx is not None:
134
+ return obj_idx
135
+
136
+ # This is a new object id that we haven't seen before. We only allow adding
+ # new objects *before* tracking starts.
138
+ allow_new_object = not inference_state["tracking_has_started"]
139
+ if allow_new_object:
140
+ # get the next object slot
141
+ obj_idx = len(inference_state["obj_id_to_idx"])
142
+ inference_state["obj_id_to_idx"][obj_id] = obj_idx
143
+ inference_state["obj_idx_to_id"][obj_idx] = obj_id
144
+ inference_state["obj_ids"] = list(inference_state["obj_id_to_idx"])
145
+ # set up input and output structures for this object
146
+ inference_state["point_inputs_per_obj"][obj_idx] = {}
147
+ inference_state["mask_inputs_per_obj"][obj_idx] = {}
148
+ inference_state["output_dict_per_obj"][obj_idx] = {
149
+ "cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
150
+ "non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
151
+ }
152
+ inference_state["temp_output_dict_per_obj"][obj_idx] = {
153
+ "cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
154
+ "non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
155
+ }
156
+ return obj_idx
157
+ else:
158
+ raise RuntimeError(
159
+ f"Cannot add new object id {obj_id} after tracking starts. "
160
+ f"All existing object ids: {inference_state['obj_ids']}. "
161
+ f"Please call 'reset_state' to restart from scratch."
162
+ )
163
+
164
+ def _obj_idx_to_id(self, inference_state, obj_idx):
165
+ """Map model-side object index to client-side object id."""
166
+ return inference_state["obj_idx_to_id"][obj_idx]
167
+
168
+ def _get_obj_num(self, inference_state):
169
+ """Get the total number of unique object ids received so far in this session."""
170
+ return len(inference_state["obj_idx_to_id"])
171
+
172
+ @torch.inference_mode()
173
+ def add_new_points_or_box(
174
+ self,
175
+ inference_state,
176
+ frame_idx,
177
+ obj_id,
178
+ points=None,
179
+ labels=None,
180
+ clear_old_points=True,
181
+ normalize_coords=True,
182
+ box=None,
183
+ ):
184
+ """Add new points to a frame."""
185
+ obj_idx = self._obj_id_to_idx(inference_state, obj_id)
186
+ point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx]
187
+ mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx]
188
+
189
+ if (points is not None) != (labels is not None):
190
+ raise ValueError("points and labels must be provided together")
191
+ if points is None and box is None:
192
+ raise ValueError("at least one of points or box must be provided as input")
193
+
194
+ if points is None:
195
+ points = torch.zeros(0, 2, dtype=torch.float32)
196
+ elif not isinstance(points, torch.Tensor):
197
+ points = torch.tensor(points, dtype=torch.float32)
198
+ if labels is None:
199
+ labels = torch.zeros(0, dtype=torch.int32)
200
+ elif not isinstance(labels, torch.Tensor):
201
+ labels = torch.tensor(labels, dtype=torch.int32)
202
+ if points.dim() == 2:
203
+ points = points.unsqueeze(0) # add batch dimension
204
+ if labels.dim() == 1:
205
+ labels = labels.unsqueeze(0) # add batch dimension
206
+
207
+ # If `box` is provided, we add it as the first two points with labels 2 and 3
208
+ # along with the user-provided points (consistent with how SAM 2 is trained).
209
+ if box is not None:
210
+ if not clear_old_points:
211
+ raise ValueError(
212
+ "cannot add box without clearing old points, since "
213
+ "box prompt must be provided before any point prompt "
214
+ "(please use clear_old_points=True instead)"
215
+ )
216
+ if inference_state["tracking_has_started"]:
217
+ warnings.warn(
218
+ "You are adding a box after tracking starts. SAM 2 may not always be "
219
+ "able to incorporate a box prompt for *refinement*. If you intend to "
220
+ "use box prompt as an *initial* input before tracking, please call "
221
+ "'reset_state' on the inference state to restart from scratch.",
222
+ category=UserWarning,
223
+ stacklevel=2,
224
+ )
225
+ if not isinstance(box, torch.Tensor):
226
+ box = torch.tensor(box, dtype=torch.float32, device=points.device)
227
+ box_coords = box.reshape(1, 2, 2)
228
+ box_labels = torch.tensor([2, 3], dtype=torch.int32, device=labels.device)
229
+ box_labels = box_labels.reshape(1, 2)
230
+ points = torch.cat([box_coords, points], dim=1)
231
+ labels = torch.cat([box_labels, labels], dim=1)
232
+
233
+ if normalize_coords:
234
+ video_H = inference_state["video_height"]
235
+ video_W = inference_state["video_width"]
236
+ points = points / torch.tensor([video_W, video_H]).to(points.device)
237
+ # scale the (normalized) coordinates by the model's internal image size
238
+ points = points * self.image_size
239
+ points = points.to(inference_state["device"])
240
+ labels = labels.to(inference_state["device"])
241
+
242
+ if not clear_old_points:
243
+ point_inputs = point_inputs_per_frame.get(frame_idx, None)
244
+ else:
245
+ point_inputs = None
246
+ point_inputs = concat_points(point_inputs, points, labels)
247
+
248
+ point_inputs_per_frame[frame_idx] = point_inputs
249
+ mask_inputs_per_frame.pop(frame_idx, None)
250
+ # If this frame hasn't been tracked before, we treat it as an initial conditioning
+ # frame, meaning that the input points are used to generate segments on this frame without
+ # using any memory from other frames, as in SAM. Otherwise (if it has been tracked),
+ # the input points will be used to correct the already tracked masks.
254
+ is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"]
255
+ # whether to track in reverse time order
256
+ if is_init_cond_frame:
257
+ reverse = False
258
+ else:
259
+ reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"]
260
+ obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
261
+ obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
262
+ # Add a frame to conditioning output if it's an initial conditioning frame or
263
+ # if the model sees all frames receiving clicks/mask as conditioning frames.
264
+ is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond
265
+ storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
266
+
267
+ # Get any previously predicted mask logits on this object and feed it along with
268
+ # the new clicks into the SAM mask decoder.
269
+ prev_sam_mask_logits = None
270
+ # lookup temporary output dict first, which contains the most recent output
271
+ # (if not found, then lookup conditioning and non-conditioning frame output)
272
+ prev_out = obj_temp_output_dict[storage_key].get(frame_idx)
273
+ if prev_out is None:
274
+ prev_out = obj_output_dict["cond_frame_outputs"].get(frame_idx)
275
+ if prev_out is None:
276
+ prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx)
277
+
278
+ if prev_out is not None and prev_out["pred_masks"] is not None:
279
+ device = inference_state["device"]
280
+ prev_sam_mask_logits = prev_out["pred_masks"].to(device, non_blocking=True)
281
+ # Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues.
282
+ prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0)
283
+ current_out, _ = self._run_single_frame_inference(
284
+ inference_state=inference_state,
285
+ output_dict=obj_output_dict, # run on the slice of a single object
286
+ frame_idx=frame_idx,
287
+ batch_size=1, # run on the slice of a single object
288
+ is_init_cond_frame=is_init_cond_frame,
289
+ point_inputs=point_inputs,
290
+ mask_inputs=None,
291
+ reverse=reverse,
292
+ # Skip the memory encoder when adding clicks or a mask. We run the memory encoder
+ # at the beginning of `propagate_in_video` (after the user finalizes their clicks). This
+ # allows us to enforce non-overlapping constraints on all objects before encoding
+ # them into memory.
296
+ run_mem_encoder=False,
297
+ prev_sam_mask_logits=prev_sam_mask_logits,
298
+ )
299
+ # Add the output to the output dict (to be used as future memory)
300
+ obj_temp_output_dict[storage_key][frame_idx] = current_out
301
+
302
+ # Resize the output mask to the original video resolution
303
+ obj_ids = inference_state["obj_ids"]
304
+ consolidated_out = self._consolidate_temp_output_across_obj(
305
+ inference_state,
306
+ frame_idx,
307
+ is_cond=is_cond,
308
+ run_mem_encoder=False,
309
+ consolidate_at_video_res=True,
310
+ )
311
+ _, video_res_masks = self._get_orig_video_res_output(
312
+ inference_state, consolidated_out["pred_masks_video_res"]
313
+ )
314
+ return frame_idx, obj_ids, video_res_masks
315
+
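A minimal usage sketch for `add_new_points_or_box`, assuming `predictor` and `inference_state` from the earlier steps; the coordinates, object id, and the 0.0 logit threshold are illustrative choices:

```python
import numpy as np

# Sketch only: `predictor` and `inference_state` are assumed from the earlier steps;
# coordinates and ids are placeholders.
frame_idx, obj_ids, video_res_masks = predictor.add_new_points_or_box(
    inference_state,
    frame_idx=0,
    obj_id=1,                                                       # client-side object id
    points=np.array([[210.0, 350.0]], dtype=np.float32),           # (x, y) in video pixels
    labels=np.array([1], dtype=np.int32),                           # 1 = positive, 0 = negative click
    box=np.array([150.0, 300.0, 300.0, 420.0], dtype=np.float32),  # (x0, y0, x1, y1); encoded as labels 2 and 3
)
first_mask = (video_res_masks[0, 0] > 0.0).cpu().numpy()           # threshold mask logits at 0
```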
316
+ def add_new_points(self, *args, **kwargs):
317
+ """Deprecated method. Please use `add_new_points_or_box` instead."""
318
+ return self.add_new_points_or_box(*args, **kwargs)
319
+
320
+ @torch.inference_mode()
321
+ def add_new_mask(
322
+ self,
323
+ inference_state,
324
+ frame_idx,
325
+ obj_id,
326
+ mask,
327
+ ):
328
+ """Add new mask to a frame."""
329
+ obj_idx = self._obj_id_to_idx(inference_state, obj_id)
330
+ point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx]
331
+ mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx]
332
+
333
+ if not isinstance(mask, torch.Tensor):
334
+ mask = torch.tensor(mask, dtype=torch.bool)
335
+ assert mask.dim() == 2
336
+ mask_H, mask_W = mask.shape
337
+ mask_inputs_orig = mask[None, None] # add batch and channel dimension
338
+ mask_inputs_orig = mask_inputs_orig.float().to(inference_state["device"])
339
+
340
+ # resize the mask if it doesn't match the model's image size
341
+ if mask_H != self.image_size or mask_W != self.image_size:
342
+ mask_inputs = torch.nn.functional.interpolate(
343
+ mask_inputs_orig,
344
+ size=(self.image_size, self.image_size),
345
+ align_corners=False,
346
+ mode="bilinear",
347
+ antialias=True, # use antialias for downsampling
348
+ )
349
+ mask_inputs = (mask_inputs >= 0.5).float()
350
+ else:
351
+ mask_inputs = mask_inputs_orig
352
+
353
+ mask_inputs_per_frame[frame_idx] = mask_inputs
354
+ point_inputs_per_frame.pop(frame_idx, None)
355
+ # If this frame hasn't been tracked before, we treat it as an initial conditioning
+ # frame, meaning that the input mask is used to generate segments on this frame without
+ # using any memory from other frames, as in SAM. Otherwise (if it has been tracked),
+ # the input mask will be used to correct the already tracked masks.
359
+ is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"]
360
+ # whether to track in reverse time order
361
+ if is_init_cond_frame:
362
+ reverse = False
363
+ else:
364
+ reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"]
365
+ obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
366
+ obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
367
+ # Add a frame to conditioning output if it's an initial conditioning frame or
368
+ # if the model sees all frames receiving clicks/mask as conditioning frames.
369
+ is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond
370
+ storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
371
+
372
+ current_out, _ = self._run_single_frame_inference(
373
+ inference_state=inference_state,
374
+ output_dict=obj_output_dict, # run on the slice of a single object
375
+ frame_idx=frame_idx,
376
+ batch_size=1, # run on the slice of a single object
377
+ is_init_cond_frame=is_init_cond_frame,
378
+ point_inputs=None,
379
+ mask_inputs=mask_inputs,
380
+ reverse=reverse,
381
+ # Skip the memory encoder when adding clicks or a mask. We run the memory encoder
+ # at the beginning of `propagate_in_video` (after the user finalizes their clicks). This
+ # allows us to enforce non-overlapping constraints on all objects before encoding
+ # them into memory.
385
+ run_mem_encoder=False,
386
+ )
387
+ # Add the output to the output dict (to be used as future memory)
388
+ obj_temp_output_dict[storage_key][frame_idx] = current_out
389
+
390
+ # Resize the output mask to the original video resolution
391
+ obj_ids = inference_state["obj_ids"]
392
+ consolidated_out = self._consolidate_temp_output_across_obj(
393
+ inference_state,
394
+ frame_idx,
395
+ is_cond=is_cond,
396
+ run_mem_encoder=False,
397
+ consolidate_at_video_res=True,
398
+ )
399
+ _, video_res_masks = self._get_orig_video_res_output(
400
+ inference_state, consolidated_out["pred_masks_video_res"]
401
+ )
402
+ return frame_idx, obj_ids, video_res_masks
403
+
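A minimal sketch of prompting with a whole mask instead of clicks; the mask shape and contents below are placeholders:

```python
import numpy as np

# Sketch only: the mask below is a placeholder; any 2-D boolean array works and is
# resized internally to the model's square input size if needed.
mask = np.zeros((1080, 1920), dtype=bool)   # hypothetical video resolution
mask[300:420, 150:300] = True               # hand-drawn / pre-computed object region

frame_idx, obj_ids, video_res_masks = predictor.add_new_mask(
    inference_state,
    frame_idx=0,
    obj_id=1,
    mask=mask,
)
```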
404
+ def _get_orig_video_res_output(self, inference_state, any_res_masks):
405
+ """
406
+ Resize the object scores to the original video resolution (video_res_masks)
407
+ and apply non-overlapping constraints for final output.
408
+ """
409
+ device = inference_state["device"]
410
+ video_H = inference_state["video_height"]
411
+ video_W = inference_state["video_width"]
412
+ any_res_masks = any_res_masks.to(device, non_blocking=True)
413
+ if any_res_masks.shape[-2:] == (video_H, video_W):
414
+ video_res_masks = any_res_masks
415
+ else:
416
+ video_res_masks = torch.nn.functional.interpolate(
417
+ any_res_masks,
418
+ size=(video_H, video_W),
419
+ mode="bilinear",
420
+ align_corners=False,
421
+ )
422
+ if self.non_overlap_masks:
423
+ video_res_masks = self._apply_non_overlapping_constraints(video_res_masks)
424
+ return any_res_masks, video_res_masks
425
+
426
+ def _consolidate_temp_output_across_obj(
427
+ self,
428
+ inference_state,
429
+ frame_idx,
430
+ is_cond,
431
+ run_mem_encoder,
432
+ consolidate_at_video_res=False,
433
+ ):
434
+ """
435
+ Consolidate the per-object temporary outputs in `temp_output_dict_per_obj` on
+ a frame into a single output for all objects, including:
+ 1) filling in any missing objects, either from `output_dict_per_obj` (if they exist
+ in `output_dict_per_obj` for this frame) or with placeholder values
+ (if they don't exist in `output_dict_per_obj` for this frame);
+ 2) if specified, rerunning the memory encoder after applying non-overlapping
+ constraints on the object scores.
442
+ """
443
+ batch_size = self._get_obj_num(inference_state)
444
+ storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
445
+ # Optionally, we allow consolidating the temporary outputs at the original
446
+ # video resolution (to provide a better editing experience for mask prompts).
447
+ if consolidate_at_video_res:
448
+ assert not run_mem_encoder, "memory encoder cannot run at video resolution"
449
+ consolidated_H = inference_state["video_height"]
450
+ consolidated_W = inference_state["video_width"]
451
+ consolidated_mask_key = "pred_masks_video_res"
452
+ else:
453
+ consolidated_H = consolidated_W = self.image_size // 4
454
+ consolidated_mask_key = "pred_masks"
455
+
456
+ # Initialize `consolidated_out`. Its "maskmem_features" and "maskmem_pos_enc"
457
+ # will be added when rerunning the memory encoder after applying non-overlapping
458
+ # constraints to object scores. Its "pred_masks" are prefilled with a large
459
+ # negative value (NO_OBJ_SCORE) to represent missing objects.
460
+ consolidated_out = {
461
+ "maskmem_features": None,
462
+ "maskmem_pos_enc": None,
463
+ consolidated_mask_key: torch.full(
464
+ size=(batch_size, 1, consolidated_H, consolidated_W),
465
+ fill_value=NO_OBJ_SCORE,
466
+ dtype=torch.float32,
467
+ device=inference_state["storage_device"],
468
+ ),
469
+ "obj_ptr": torch.full(
470
+ size=(batch_size, self.hidden_dim),
471
+ fill_value=NO_OBJ_SCORE,
472
+ dtype=torch.float32,
473
+ device=inference_state["device"],
474
+ ),
475
+ "object_score_logits": torch.full(
476
+ size=(batch_size, 1),
477
+ # default to 10.0 for object_score_logits, i.e. assuming the object is
478
+ # present as sigmoid(10)=1, same as in `predict_masks` of `MaskDecoder`
479
+ fill_value=10.0,
480
+ dtype=torch.float32,
481
+ device=inference_state["device"],
482
+ ),
483
+ }
484
+ empty_mask_ptr = None
485
+ for obj_idx in range(batch_size):
486
+ obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
487
+ obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
488
+ out = obj_temp_output_dict[storage_key].get(frame_idx, None)
489
+ # If the object doesn't appear in "temp_output_dict_per_obj" on this frame,
490
+ # we fall back and look up its previous output in "output_dict_per_obj".
491
+ # We look up both "cond_frame_outputs" and "non_cond_frame_outputs" in
492
+ # "output_dict_per_obj" to find a previous output for this object.
493
+ if out is None:
494
+ out = obj_output_dict["cond_frame_outputs"].get(frame_idx, None)
495
+ if out is None:
496
+ out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx, None)
497
+ # If the object doesn't appear in "output_dict_per_obj" either, we skip it
498
+ # and leave its mask scores to the default scores (i.e. the NO_OBJ_SCORE
499
+ # placeholder above) and set its object pointer to be a dummy pointer.
500
+ if out is None:
501
+ # Fill in dummy object pointers for those objects without any inputs or
502
+ # tracking outcomes on this frame (only do it under `run_mem_encoder=True`,
503
+ # i.e. when we need to build the memory for tracking).
504
+ if run_mem_encoder:
505
+ if empty_mask_ptr is None:
506
+ empty_mask_ptr = self._get_empty_mask_ptr(
507
+ inference_state, frame_idx
508
+ )
509
+ # fill object pointer with a dummy pointer (based on an empty mask)
510
+ consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = empty_mask_ptr
511
+ continue
512
+ # Add the temporary object output mask to consolidated output mask
513
+ obj_mask = out["pred_masks"]
514
+ consolidated_pred_masks = consolidated_out[consolidated_mask_key]
515
+ if obj_mask.shape[-2:] == consolidated_pred_masks.shape[-2:]:
516
+ consolidated_pred_masks[obj_idx : obj_idx + 1] = obj_mask
517
+ else:
518
+ # Resize first if temporary object mask has a different resolution
519
+ resized_obj_mask = torch.nn.functional.interpolate(
520
+ obj_mask,
521
+ size=consolidated_pred_masks.shape[-2:],
522
+ mode="bilinear",
523
+ align_corners=False,
524
+ )
525
+ consolidated_pred_masks[obj_idx : obj_idx + 1] = resized_obj_mask
526
+ consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = out["obj_ptr"]
527
+ consolidated_out["object_score_logits"][obj_idx : obj_idx + 1] = out[
528
+ "object_score_logits"
529
+ ]
530
+
531
+ # Optionally, apply non-overlapping constraints on the consolidated scores
532
+ # and rerun the memory encoder
533
+ if run_mem_encoder:
534
+ device = inference_state["device"]
535
+ high_res_masks = torch.nn.functional.interpolate(
536
+ consolidated_out["pred_masks"].to(device, non_blocking=True),
537
+ size=(self.image_size, self.image_size),
538
+ mode="bilinear",
539
+ align_corners=False,
540
+ )
541
+ if self.non_overlap_masks_for_mem_enc:
542
+ high_res_masks = self._apply_non_overlapping_constraints(high_res_masks)
543
+ maskmem_features, maskmem_pos_enc = self._run_memory_encoder(
544
+ inference_state=inference_state,
545
+ frame_idx=frame_idx,
546
+ batch_size=batch_size,
547
+ high_res_masks=high_res_masks,
548
+ object_score_logits=consolidated_out["object_score_logits"],
549
+ is_mask_from_pts=True, # these frames are what the user interacted with
550
+ )
551
+ consolidated_out["maskmem_features"] = maskmem_features
552
+ consolidated_out["maskmem_pos_enc"] = maskmem_pos_enc
553
+
554
+ return consolidated_out
555
+
556
+ def _get_empty_mask_ptr(self, inference_state, frame_idx):
557
+ """Get a dummy object pointer based on an empty mask on the current frame."""
558
+ # A dummy (empty) mask with a single object
559
+ batch_size = 1
560
+ mask_inputs = torch.zeros(
561
+ (batch_size, 1, self.image_size, self.image_size),
562
+ dtype=torch.float32,
563
+ device=inference_state["device"],
564
+ )
565
+
566
+ # Retrieve correct image features
567
+ (
568
+ _,
569
+ _,
570
+ current_vision_feats,
571
+ current_vision_pos_embeds,
572
+ feat_sizes,
573
+ ) = self._get_image_feature(inference_state, frame_idx, batch_size)
574
+
575
+ # Feed the empty mask and image feature above to get a dummy object pointer
576
+ current_out = self.track_step(
577
+ frame_idx=frame_idx,
578
+ is_init_cond_frame=True,
579
+ current_vision_feats=current_vision_feats,
580
+ current_vision_pos_embeds=current_vision_pos_embeds,
581
+ feat_sizes=feat_sizes,
582
+ point_inputs=None,
583
+ mask_inputs=mask_inputs,
584
+ output_dict={},
585
+ num_frames=inference_state["num_frames"],
586
+ track_in_reverse=False,
587
+ run_mem_encoder=False,
588
+ prev_sam_mask_logits=None,
589
+ )
590
+ return current_out["obj_ptr"]
591
+
592
+ @torch.inference_mode()
593
+ def propagate_in_video_preflight(self, inference_state):
594
+ """Prepare inference_state and consolidate temporary outputs before tracking."""
595
+ # Tracking has started and we don't allow adding new objects until the session is reset.
596
+ inference_state["tracking_has_started"] = True
597
+ batch_size = self._get_obj_num(inference_state)
598
+
599
+ # Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and
600
+ # add them into "output_dict".
601
+ temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
602
+ output_dict = inference_state["output_dict"]
603
+ # "consolidated_frame_inds" contains indices of those frames where consolidated
604
+ # temporary outputs have been added (either in this call or any previous calls
605
+ # to `propagate_in_video_preflight`).
606
+ consolidated_frame_inds = inference_state["consolidated_frame_inds"]
607
+ for is_cond in [False, True]:
608
+ # Separately consolidate conditioning and non-conditioning temp outputs
609
+ storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
610
+ # Find all the frames that contain temporary outputs for any objects
611
+ # (these should be the frames that have just received clicks or mask inputs
612
+ # via `add_new_points_or_box` or `add_new_mask`)
613
+ temp_frame_inds = set()
614
+ for obj_temp_output_dict in temp_output_dict_per_obj.values():
615
+ temp_frame_inds.update(obj_temp_output_dict[storage_key].keys())
616
+ consolidated_frame_inds[storage_key].update(temp_frame_inds)
617
+ # consolidate the temporary output across all objects on this frame
618
+ for frame_idx in temp_frame_inds:
619
+ consolidated_out = self._consolidate_temp_output_across_obj(
620
+ inference_state, frame_idx, is_cond=is_cond, run_mem_encoder=True
621
+ )
622
+ # merge them into "output_dict" and also create per-object slices
623
+ output_dict[storage_key][frame_idx] = consolidated_out
624
+ self._add_output_per_object(
625
+ inference_state, frame_idx, consolidated_out, storage_key
626
+ )
627
+ clear_non_cond_mem = self.clear_non_cond_mem_around_input and (
628
+ self.clear_non_cond_mem_for_multi_obj or batch_size <= 1
629
+ )
630
+ if clear_non_cond_mem:
631
+ # clear non-conditioning memory of the surrounding frames
632
+ self._clear_non_cond_mem_around_input(inference_state, frame_idx)
633
+
634
+ # clear temporary outputs in `temp_output_dict_per_obj`
635
+ for obj_temp_output_dict in temp_output_dict_per_obj.values():
636
+ obj_temp_output_dict[storage_key].clear()
637
+
638
+ # edge case: if an output is added to "cond_frame_outputs", we remove any prior
639
+ # output on the same frame in "non_cond_frame_outputs"
640
+ for frame_idx in output_dict["cond_frame_outputs"]:
641
+ output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
642
+ for obj_output_dict in inference_state["output_dict_per_obj"].values():
643
+ for frame_idx in obj_output_dict["cond_frame_outputs"]:
644
+ obj_output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
645
+ for frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
646
+ assert frame_idx in output_dict["cond_frame_outputs"]
647
+ consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx)
648
+
649
+ # Make sure that the frame indices in "consolidated_frame_inds" are exactly those frames
650
+ # with either points or mask inputs (which should be true under a correct workflow).
651
+ all_consolidated_frame_inds = (
652
+ consolidated_frame_inds["cond_frame_outputs"]
653
+ | consolidated_frame_inds["non_cond_frame_outputs"]
654
+ )
655
+ input_frames_inds = set()
656
+ for point_inputs_per_frame in inference_state["point_inputs_per_obj"].values():
657
+ input_frames_inds.update(point_inputs_per_frame.keys())
658
+ for mask_inputs_per_frame in inference_state["mask_inputs_per_obj"].values():
659
+ input_frames_inds.update(mask_inputs_per_frame.keys())
660
+ assert all_consolidated_frame_inds == input_frames_inds
661
+
662
+ @torch.inference_mode()
663
+ def propagate_in_video(
664
+ self,
665
+ inference_state,
666
+ start_frame_idx=None,
667
+ max_frame_num_to_track=None,
668
+ reverse=False,
669
+ ):
670
+ """Propagate the input points across frames to track in the entire video."""
671
+ self.propagate_in_video_preflight(inference_state)
672
+
673
+ output_dict = inference_state["output_dict"]
674
+ consolidated_frame_inds = inference_state["consolidated_frame_inds"]
675
+ obj_ids = inference_state["obj_ids"]
676
+ num_frames = inference_state["num_frames"]
677
+ batch_size = self._get_obj_num(inference_state)
678
+ if len(output_dict["cond_frame_outputs"]) == 0:
679
+ raise RuntimeError("No points are provided; please add points first")
680
+ clear_non_cond_mem = self.clear_non_cond_mem_around_input and (
681
+ self.clear_non_cond_mem_for_multi_obj or batch_size <= 1
682
+ )
683
+
684
+ # set start index, end index, and processing order
685
+ if start_frame_idx is None:
686
+ # default: start from the earliest frame with input points
687
+ start_frame_idx = min(output_dict["cond_frame_outputs"])
688
+ if max_frame_num_to_track is None:
689
+ # default: track all the frames in the video
690
+ max_frame_num_to_track = num_frames
691
+ if reverse:
692
+ end_frame_idx = max(start_frame_idx - max_frame_num_to_track, 0)
693
+ if start_frame_idx > 0:
694
+ processing_order = range(start_frame_idx, end_frame_idx - 1, -1)
695
+ else:
696
+ processing_order = [] # skip reverse tracking if starting from frame 0
697
+ else:
698
+ end_frame_idx = min(
699
+ start_frame_idx + max_frame_num_to_track, num_frames - 1
700
+ )
701
+ processing_order = range(start_frame_idx, end_frame_idx + 1)
702
+
703
+ for frame_idx in tqdm(processing_order, desc="propagate in video"):
704
+ # We skip those frames already in consolidated outputs (these are frames
705
+ # that received input clicks or mask). Note that we cannot directly run
706
+ # batched forward on them via `_run_single_frame_inference` because the
707
+ # number of clicks on each object might be different.
708
+ if frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
709
+ storage_key = "cond_frame_outputs"
710
+ current_out = output_dict[storage_key][frame_idx]
711
+ pred_masks = current_out["pred_masks"]
712
+ if clear_non_cond_mem:
713
+ # clear non-conditioning memory of the surrounding frames
714
+ self._clear_non_cond_mem_around_input(inference_state, frame_idx)
715
+ elif frame_idx in consolidated_frame_inds["non_cond_frame_outputs"]:
716
+ storage_key = "non_cond_frame_outputs"
717
+ current_out = output_dict[storage_key][frame_idx]
718
+ pred_masks = current_out["pred_masks"]
719
+ else:
720
+ storage_key = "non_cond_frame_outputs"
721
+ current_out, pred_masks = self._run_single_frame_inference(
722
+ inference_state=inference_state,
723
+ output_dict=output_dict,
724
+ frame_idx=frame_idx,
725
+ batch_size=batch_size,
726
+ is_init_cond_frame=False,
727
+ point_inputs=None,
728
+ mask_inputs=None,
729
+ reverse=reverse,
730
+ run_mem_encoder=True,
731
+ )
732
+ output_dict[storage_key][frame_idx] = current_out
733
+ # Create slices of per-object outputs for subsequent interaction with each
734
+ # individual object after tracking.
735
+ self._add_output_per_object(
736
+ inference_state, frame_idx, current_out, storage_key
737
+ )
738
+ inference_state["frames_already_tracked"][frame_idx] = {"reverse": reverse}
739
+
740
+ # Resize the output mask to the original video resolution (we directly use
741
+ # the mask scores on GPU for output to avoid any CPU conversion in between)
742
+ _, video_res_masks = self._get_orig_video_res_output(
743
+ inference_state, pred_masks
744
+ )
745
+ yield frame_idx, obj_ids, video_res_masks
746
+
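A minimal sketch of consuming this generator, assuming `predictor` and `inference_state` as before; the dictionary layout and the 0.0 threshold on the mask logits are illustrative conventions, not requirements:

```python
# Sketch only: collect per-frame, per-object boolean masks from the propagation generator.
video_segments = {}
for frame_idx, obj_ids, video_res_masks in predictor.propagate_in_video(inference_state):
    video_segments[frame_idx] = {
        obj_id: (video_res_masks[i, 0] > 0.0).cpu().numpy()
        for i, obj_id in enumerate(obj_ids)
    }

# optionally also track backwards from the first prompted frame
for frame_idx, obj_ids, video_res_masks in predictor.propagate_in_video(
    inference_state, reverse=True
):
    pass  # handle frames before the prompted frame the same way
```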
747
+ def _add_output_per_object(
748
+ self, inference_state, frame_idx, current_out, storage_key
749
+ ):
750
+ """
751
+ Split a multi-object output into per-object output slices and add them into
752
+ `output_dict_per_obj`. The resulting slices share the same tensor storage.
753
+ """
754
+ maskmem_features = current_out["maskmem_features"]
755
+ assert maskmem_features is None or isinstance(maskmem_features, torch.Tensor)
756
+
757
+ maskmem_pos_enc = current_out["maskmem_pos_enc"]
758
+ assert maskmem_pos_enc is None or isinstance(maskmem_pos_enc, list)
759
+
760
+ output_dict_per_obj = inference_state["output_dict_per_obj"]
761
+ for obj_idx, obj_output_dict in output_dict_per_obj.items():
762
+ obj_slice = slice(obj_idx, obj_idx + 1)
763
+ obj_out = {
764
+ "maskmem_features": None,
765
+ "maskmem_pos_enc": None,
766
+ "pred_masks": current_out["pred_masks"][obj_slice],
767
+ "obj_ptr": current_out["obj_ptr"][obj_slice],
768
+ "object_score_logits": current_out["object_score_logits"][obj_slice],
769
+ }
770
+ if maskmem_features is not None:
771
+ obj_out["maskmem_features"] = maskmem_features[obj_slice]
772
+ if maskmem_pos_enc is not None:
773
+ obj_out["maskmem_pos_enc"] = [x[obj_slice] for x in maskmem_pos_enc]
774
+ obj_output_dict[storage_key][frame_idx] = obj_out
775
+
776
+ @torch.inference_mode()
777
+ def clear_all_prompts_in_frame(
778
+ self, inference_state, frame_idx, obj_id, need_output=True
779
+ ):
780
+ """Remove all input points or mask in a specific frame for a given object."""
781
+ obj_idx = self._obj_id_to_idx(inference_state, obj_id)
782
+
783
+ # Clear the conditioning information on the given frame
784
+ inference_state["point_inputs_per_obj"][obj_idx].pop(frame_idx, None)
785
+ inference_state["mask_inputs_per_obj"][obj_idx].pop(frame_idx, None)
786
+
787
+ temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
788
+ temp_output_dict_per_obj[obj_idx]["cond_frame_outputs"].pop(frame_idx, None)
789
+ temp_output_dict_per_obj[obj_idx]["non_cond_frame_outputs"].pop(frame_idx, None)
790
+
791
+ # Check and see if there are still any inputs left on this frame
792
+ batch_size = self._get_obj_num(inference_state)
793
+ frame_has_input = False
794
+ for obj_idx2 in range(batch_size):
795
+ if frame_idx in inference_state["point_inputs_per_obj"][obj_idx2]:
796
+ frame_has_input = True
797
+ break
798
+ if frame_idx in inference_state["mask_inputs_per_obj"][obj_idx2]:
799
+ frame_has_input = True
800
+ break
801
+
802
+ # If this frame has no remaining inputs for any objects, we further clear its
803
+ # conditioning frame status
804
+ if not frame_has_input:
805
+ output_dict = inference_state["output_dict"]
806
+ consolidated_frame_inds = inference_state["consolidated_frame_inds"]
807
+ consolidated_frame_inds["cond_frame_outputs"].discard(frame_idx)
808
+ consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx)
809
+ # Remove the frame's conditioning output (possibly downgrading it to non-conditioning)
810
+ out = output_dict["cond_frame_outputs"].pop(frame_idx, None)
811
+ if out is not None:
812
+ # The frame is no longer a conditioning frame since it's not receiving inputs,
+ # so we "downgrade" its output (if it exists) to a non-conditioning frame output.
814
+ output_dict["non_cond_frame_outputs"][frame_idx] = out
815
+ inference_state["frames_already_tracked"].pop(frame_idx, None)
816
+ # Similarly, do it for the sliced output on each object.
817
+ for obj_idx2 in range(batch_size):
818
+ obj_output_dict = inference_state["output_dict_per_obj"][obj_idx2]
819
+ obj_out = obj_output_dict["cond_frame_outputs"].pop(frame_idx, None)
820
+ if obj_out is not None:
821
+ obj_output_dict["non_cond_frame_outputs"][frame_idx] = obj_out
822
+
823
+ # If all the conditioning frames have been removed, we also clear the tracking outputs
824
+ if len(output_dict["cond_frame_outputs"]) == 0:
825
+ self._reset_tracking_results(inference_state)
826
+
827
+ if not need_output:
828
+ return
829
+ # Finally, output updated masks per object (after removing the inputs above)
830
+ obj_ids = inference_state["obj_ids"]
831
+ is_cond = any(
832
+ frame_idx in obj_temp_output_dict["cond_frame_outputs"]
833
+ for obj_temp_output_dict in temp_output_dict_per_obj.values()
834
+ )
835
+ consolidated_out = self._consolidate_temp_output_across_obj(
836
+ inference_state,
837
+ frame_idx,
838
+ is_cond=is_cond,
839
+ run_mem_encoder=False,
840
+ consolidate_at_video_res=True,
841
+ )
842
+ _, video_res_masks = self._get_orig_video_res_output(
843
+ inference_state, consolidated_out["pred_masks_video_res"]
844
+ )
845
+ return frame_idx, obj_ids, video_res_masks
846
+
847
+ @torch.inference_mode()
848
+ def reset_state(self, inference_state):
849
+ """Remove all input points or mask in all frames throughout the video."""
850
+ self._reset_tracking_results(inference_state)
851
+ # Remove all object ids
852
+ inference_state["obj_id_to_idx"].clear()
853
+ inference_state["obj_idx_to_id"].clear()
854
+ inference_state["obj_ids"].clear()
855
+ inference_state["point_inputs_per_obj"].clear()
856
+ inference_state["mask_inputs_per_obj"].clear()
857
+ inference_state["output_dict_per_obj"].clear()
858
+ inference_state["temp_output_dict_per_obj"].clear()
859
+
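A minimal sketch of the reset workflow, assuming the earlier `predictor`/`inference_state`; per `_obj_id_to_idx` above, new object ids are only accepted before tracking starts, so a reset is required to start over:

```python
import numpy as np

# Sketch only: per `_obj_id_to_idx` above, new object ids are rejected once tracking
# has started, so reset the session before prompting a fresh object.
predictor.reset_state(inference_state)
predictor.add_new_points_or_box(
    inference_state,
    frame_idx=0,
    obj_id=2,                                              # new ids are allowed again after reset
    points=np.array([[100.0, 100.0]], dtype=np.float32),
    labels=np.array([1], dtype=np.int32),
)
```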
860
+ def _reset_tracking_results(self, inference_state):
861
+ """Reset all tracking inputs and results across the videos."""
862
+ for v in inference_state["point_inputs_per_obj"].values():
863
+ v.clear()
864
+ for v in inference_state["mask_inputs_per_obj"].values():
865
+ v.clear()
866
+ for v in inference_state["output_dict_per_obj"].values():
867
+ v["cond_frame_outputs"].clear()
868
+ v["non_cond_frame_outputs"].clear()
869
+ for v in inference_state["temp_output_dict_per_obj"].values():
870
+ v["cond_frame_outputs"].clear()
871
+ v["non_cond_frame_outputs"].clear()
872
+ inference_state["output_dict"]["cond_frame_outputs"].clear()
873
+ inference_state["output_dict"]["non_cond_frame_outputs"].clear()
874
+ inference_state["consolidated_frame_inds"]["cond_frame_outputs"].clear()
875
+ inference_state["consolidated_frame_inds"]["non_cond_frame_outputs"].clear()
876
+ inference_state["tracking_has_started"] = False
877
+ inference_state["frames_already_tracked"].clear()
878
+
879
+ def _get_image_feature(self, inference_state, frame_idx, batch_size):
880
+ """Compute the image features on a given frame."""
881
+ # Look up in the cache first
882
+ image, backbone_out = inference_state["cached_features"].get(
883
+ frame_idx, (None, None)
884
+ )
885
+ if backbone_out is None:
886
+ # Cache miss -- we will run inference on a single image
887
+ device = inference_state["device"]
888
+ image = inference_state["images"][frame_idx].to(device).float().unsqueeze(0)
889
+ backbone_out = self.forward_image(image)
890
+ # Cache the most recent frame's feature (for repeated interactions with
891
+ # a frame; we can use an LRU cache for more frames in the future).
892
+ inference_state["cached_features"] = {frame_idx: (image, backbone_out)}
893
+
894
+ # expand the features to have the same dimension as the number of objects
895
+ expanded_image = image.expand(batch_size, -1, -1, -1)
896
+ expanded_backbone_out = {
897
+ "backbone_fpn": backbone_out["backbone_fpn"].copy(),
898
+ "vision_pos_enc": backbone_out["vision_pos_enc"].copy(),
899
+ }
900
+ for i, feat in enumerate(expanded_backbone_out["backbone_fpn"]):
901
+ expanded_backbone_out["backbone_fpn"][i] = feat.expand(
902
+ batch_size, -1, -1, -1
903
+ )
904
+ for i, pos in enumerate(expanded_backbone_out["vision_pos_enc"]):
905
+ pos = pos.expand(batch_size, -1, -1, -1)
906
+ expanded_backbone_out["vision_pos_enc"][i] = pos
907
+
908
+ features = self._prepare_backbone_features(expanded_backbone_out)
909
+ features = (expanded_image,) + features
910
+ return features
911
+
912
+ def _run_single_frame_inference(
913
+ self,
914
+ inference_state,
915
+ output_dict,
916
+ frame_idx,
917
+ batch_size,
918
+ is_init_cond_frame,
919
+ point_inputs,
920
+ mask_inputs,
921
+ reverse,
922
+ run_mem_encoder,
923
+ prev_sam_mask_logits=None,
924
+ ):
925
+ """Run tracking on a single frame based on current inputs and previous memory."""
926
+ # Retrieve correct image features
927
+ (
928
+ _,
929
+ _,
930
+ current_vision_feats,
931
+ current_vision_pos_embeds,
932
+ feat_sizes,
933
+ ) = self._get_image_feature(inference_state, frame_idx, batch_size)
934
+
935
+ # point and mask should not appear as input simultaneously on the same frame
936
+ assert point_inputs is None or mask_inputs is None
937
+ current_out = self.track_step(
938
+ frame_idx=frame_idx,
939
+ is_init_cond_frame=is_init_cond_frame,
940
+ current_vision_feats=current_vision_feats,
941
+ current_vision_pos_embeds=current_vision_pos_embeds,
942
+ feat_sizes=feat_sizes,
943
+ point_inputs=point_inputs,
944
+ mask_inputs=mask_inputs,
945
+ output_dict=output_dict,
946
+ num_frames=inference_state["num_frames"],
947
+ track_in_reverse=reverse,
948
+ run_mem_encoder=run_mem_encoder,
949
+ prev_sam_mask_logits=prev_sam_mask_logits,
950
+ )
951
+
952
+ # optionally offload the output to CPU memory to save GPU space
953
+ storage_device = inference_state["storage_device"]
954
+ maskmem_features = current_out["maskmem_features"]
955
+ if maskmem_features is not None:
956
+ maskmem_features = maskmem_features.to(torch.bfloat16)
957
+ maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
958
+ pred_masks_gpu = current_out["pred_masks"]
959
+ # potentially fill holes in the predicted masks
960
+ if self.fill_hole_area > 0:
961
+ pred_masks_gpu = fill_holes_in_mask_scores(
962
+ pred_masks_gpu, self.fill_hole_area
963
+ )
964
+ pred_masks = pred_masks_gpu.to(storage_device, non_blocking=True)
965
+ # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
966
+ maskmem_pos_enc = self._get_maskmem_pos_enc(inference_state, current_out)
967
+ # object pointer is a small tensor, so we always keep it on GPU memory for fast access
968
+ obj_ptr = current_out["obj_ptr"]
969
+ object_score_logits = current_out["object_score_logits"]
970
+ # make a compact version of this frame's output to reduce the state size
971
+ compact_current_out = {
972
+ "maskmem_features": maskmem_features,
973
+ "maskmem_pos_enc": maskmem_pos_enc,
974
+ "pred_masks": pred_masks,
975
+ "obj_ptr": obj_ptr,
976
+ "object_score_logits": object_score_logits,
977
+ }
978
+ return compact_current_out, pred_masks_gpu
979
+
980
+ def _run_memory_encoder(
981
+ self,
982
+ inference_state,
983
+ frame_idx,
984
+ batch_size,
985
+ high_res_masks,
986
+ object_score_logits,
987
+ is_mask_from_pts,
988
+ ):
989
+ """
990
+ Run the memory encoder on `high_res_masks`. This is usually after applying
991
+ non-overlapping constraints to object scores. Since their scores have changed, their
+ memory also needs to be recomputed with the memory encoder.
993
+ """
994
+ # Retrieve correct image features
995
+ _, _, current_vision_feats, _, feat_sizes = self._get_image_feature(
996
+ inference_state, frame_idx, batch_size
997
+ )
998
+ maskmem_features, maskmem_pos_enc = self._encode_new_memory(
999
+ current_vision_feats=current_vision_feats,
1000
+ feat_sizes=feat_sizes,
1001
+ pred_masks_high_res=high_res_masks,
1002
+ object_score_logits=object_score_logits,
1003
+ is_mask_from_pts=is_mask_from_pts,
1004
+ )
1005
+
1006
+ # optionally offload the output to CPU memory to save GPU space
1007
+ storage_device = inference_state["storage_device"]
1008
+ maskmem_features = maskmem_features.to(torch.bfloat16)
1009
+ maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
1010
+ # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
1011
+ maskmem_pos_enc = self._get_maskmem_pos_enc(
1012
+ inference_state, {"maskmem_pos_enc": maskmem_pos_enc}
1013
+ )
1014
+ return maskmem_features, maskmem_pos_enc
1015
+
1016
+ def _get_maskmem_pos_enc(self, inference_state, current_out):
1017
+ """
1018
+ `maskmem_pos_enc` is the same across frames and objects, so we cache it as
1019
+ a constant in the inference session to reduce session storage size.
1020
+ """
1021
+ model_constants = inference_state["constants"]
1022
+ # "out_maskmem_pos_enc" should be either a list of tensors or None
1023
+ out_maskmem_pos_enc = current_out["maskmem_pos_enc"]
1024
+ if out_maskmem_pos_enc is not None:
1025
+ if "maskmem_pos_enc" not in model_constants:
1026
+ assert isinstance(out_maskmem_pos_enc, list)
1027
+ # only take the slice for one object, since it's the same across objects
1028
+ maskmem_pos_enc = [x[0:1].clone() for x in out_maskmem_pos_enc]
1029
+ model_constants["maskmem_pos_enc"] = maskmem_pos_enc
1030
+ else:
1031
+ maskmem_pos_enc = model_constants["maskmem_pos_enc"]
1032
+ # expand the cached maskmem_pos_enc to the actual batch size
1033
+ batch_size = out_maskmem_pos_enc[0].size(0)
1034
+ expanded_maskmem_pos_enc = [
1035
+ x.expand(batch_size, -1, -1, -1) for x in maskmem_pos_enc
1036
+ ]
1037
+ else:
1038
+ expanded_maskmem_pos_enc = None
1039
+ return expanded_maskmem_pos_enc
1040
+
1041
+ @torch.inference_mode()
1042
+ def remove_object(self, inference_state, obj_id, strict=False, need_output=True):
1043
+ """
1044
+ Remove an object id from the tracking state. If strict is True, we check whether
1045
+ the object id actually exists and raise an error if it doesn't exist.
1046
+ """
1047
+ old_obj_idx_to_rm = inference_state["obj_id_to_idx"].get(obj_id, None)
1048
+ updated_frames = []
1049
+ # Check whether this object_id to remove actually exists and possibly raise an error.
1050
+ if old_obj_idx_to_rm is None:
1051
+ if not strict:
1052
+ return inference_state["obj_ids"], updated_frames
1053
+ raise RuntimeError(
1054
+ f"Cannot remove object id {obj_id} as it doesn't exist. "
1055
+ f"All existing object ids: {inference_state['obj_ids']}."
1056
+ )
1057
+
1058
+ # If this is the only remaining object id, we simply reset the state.
1059
+ if len(inference_state["obj_id_to_idx"]) == 1:
1060
+ self.reset_state(inference_state)
1061
+ return inference_state["obj_ids"], updated_frames
1062
+
1063
+ # There are still remaining objects after removing this object id. In this case,
1064
+ # we need to delete the object storage from inference state tensors.
1065
+ # Step 0: clear the input on those frames where this object id has point or mask input
1066
+ # (note that this step is required as it might downgrade conditioning frames to
1067
+ # non-conditioning ones)
1068
+ obj_input_frames_inds = set()
1069
+ obj_input_frames_inds.update(
1070
+ inference_state["point_inputs_per_obj"][old_obj_idx_to_rm]
1071
+ )
1072
+ obj_input_frames_inds.update(
1073
+ inference_state["mask_inputs_per_obj"][old_obj_idx_to_rm]
1074
+ )
1075
+ for frame_idx in obj_input_frames_inds:
1076
+ self.clear_all_prompts_in_frame(
1077
+ inference_state, frame_idx, obj_id, need_output=False
1078
+ )
1079
+
1080
+ # Step 1: Update the object id mapping (note that it must be done after Step 0,
1081
+ # since Step 0 still requires the old object id mappings in inference_state)
1082
+ old_obj_ids = inference_state["obj_ids"]
1083
+ old_obj_inds = list(range(len(old_obj_ids)))
1084
+ remain_old_obj_inds = old_obj_inds.copy()
1085
+ remain_old_obj_inds.remove(old_obj_idx_to_rm)
1086
+ new_obj_ids = [old_obj_ids[old_idx] for old_idx in remain_old_obj_inds]
1087
+ new_obj_inds = list(range(len(new_obj_ids)))
1088
+ # build new mappings
1089
+ old_idx_to_new_idx = dict(zip(remain_old_obj_inds, new_obj_inds))
1090
+ inference_state["obj_id_to_idx"] = dict(zip(new_obj_ids, new_obj_inds))
1091
+ inference_state["obj_idx_to_id"] = dict(zip(new_obj_inds, new_obj_ids))
1092
+ inference_state["obj_ids"] = new_obj_ids
1093
+
1094
+ # Step 2: For per-object tensor storage, we shift their obj_idx in the dict keys.
1095
+ # (note that "consolidated_frame_inds" doesn't need to be updated in this step as
1096
+ # it's already handled in Step 0)
1097
+ def _map_keys(container):
1098
+ new_kvs = []
1099
+ for k in old_obj_inds:
1100
+ v = container.pop(k)
1101
+ if k in old_idx_to_new_idx:
1102
+ new_kvs.append((old_idx_to_new_idx[k], v))
1103
+ container.update(new_kvs)
1104
+
1105
+ _map_keys(inference_state["point_inputs_per_obj"])
1106
+ _map_keys(inference_state["mask_inputs_per_obj"])
1107
+ _map_keys(inference_state["output_dict_per_obj"])
1108
+ _map_keys(inference_state["temp_output_dict_per_obj"])
1109
+
1110
+ # Step 3: For packed tensor storage, we index the remaining ids and rebuild the per-object slices.
1111
+ def _slice_state(output_dict, storage_key):
1112
+ for frame_idx, out in output_dict[storage_key].items():
1113
+ out["maskmem_features"] = out["maskmem_features"][remain_old_obj_inds]
1114
+ out["maskmem_pos_enc"] = [
1115
+ x[remain_old_obj_inds] for x in out["maskmem_pos_enc"]
1116
+ ]
1117
+ # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
1118
+ out["maskmem_pos_enc"] = self._get_maskmem_pos_enc(inference_state, out)
1119
+ out["pred_masks"] = out["pred_masks"][remain_old_obj_inds]
1120
+ out["obj_ptr"] = out["obj_ptr"][remain_old_obj_inds]
1121
+ out["object_score_logits"] = out["object_score_logits"][
1122
+ remain_old_obj_inds
1123
+ ]
1124
+ # also update the per-object slices
1125
+ self._add_output_per_object(
1126
+ inference_state, frame_idx, out, storage_key
1127
+ )
1128
+
1129
+ _slice_state(inference_state["output_dict"], "cond_frame_outputs")
1130
+ _slice_state(inference_state["output_dict"], "non_cond_frame_outputs")
1131
+
1132
+ # Step 4: Further collect the outputs on those frames in `obj_input_frames_inds`, which
1133
+ # could show an updated mask for objects previously occluded by the object being removed
1134
+ if need_output:
1135
+ temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
1136
+ for frame_idx in obj_input_frames_inds:
1137
+ is_cond = any(
1138
+ frame_idx in obj_temp_output_dict["cond_frame_outputs"]
1139
+ for obj_temp_output_dict in temp_output_dict_per_obj.values()
1140
+ )
1141
+ consolidated_out = self._consolidate_temp_output_across_obj(
1142
+ inference_state,
1143
+ frame_idx,
1144
+ is_cond=is_cond,
1145
+ run_mem_encoder=False,
1146
+ consolidate_at_video_res=True,
1147
+ )
1148
+ _, video_res_masks = self._get_orig_video_res_output(
1149
+ inference_state, consolidated_out["pred_masks_video_res"]
1150
+ )
1151
+ updated_frames.append((frame_idx, video_res_masks))
1152
+
1153
+ return inference_state["obj_ids"], updated_frames
1154
+
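A minimal sketch of removing one object from an ongoing session, assuming the earlier `predictor`/`inference_state`:

```python
# Sketch only: drop one object from an ongoing session. `updated_frames` carries refreshed
# masks on the frames where that object had prompts (other objects may reclaim pixels there).
remaining_obj_ids, updated_frames = predictor.remove_object(
    inference_state, obj_id=1, strict=False, need_output=True
)
for frame_idx, video_res_masks in updated_frames:
    refreshed = (video_res_masks > 0.0).cpu().numpy()   # (num_remaining_objs, 1, H, W)
```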
1155
+ def _clear_non_cond_mem_around_input(self, inference_state, frame_idx):
1156
+ """
1157
+ Remove the non-conditioning memory around the input frame. When users provide
1158
+ correction clicks, the surrounding frames' non-conditioning memories can still
1159
+ contain outdated object appearance information and could confuse the model.
1160
+
1161
+ This method clears those non-conditioning memories surrounding the interacted
1162
+ frame to avoid giving the model both old and new information about the object.
1163
+ """
1164
+ r = self.memory_temporal_stride_for_eval
1165
+ frame_idx_begin = frame_idx - r * self.num_maskmem
1166
+ frame_idx_end = frame_idx + r * self.num_maskmem
1167
+ output_dict = inference_state["output_dict"]
1168
+ non_cond_frame_outputs = output_dict["non_cond_frame_outputs"]
1169
+ for t in range(frame_idx_begin, frame_idx_end + 1):
1170
+ non_cond_frame_outputs.pop(t, None)
1171
+ for obj_output_dict in inference_state["output_dict_per_obj"].values():
1172
+ obj_output_dict["non_cond_frame_outputs"].pop(t, None)