singlebehaviorlab 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. sam2/__init__.py +11 -0
  2. sam2/automatic_mask_generator.py +454 -0
  3. sam2/benchmark.py +92 -0
  4. sam2/build_sam.py +174 -0
  5. sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
  6. sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
  7. sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
  8. sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
  9. sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +116 -0
  10. sam2/configs/sam2.1/sam2.1_hiera_l.yaml +120 -0
  11. sam2/configs/sam2.1/sam2.1_hiera_s.yaml +119 -0
  12. sam2/configs/sam2.1/sam2.1_hiera_t.yaml +121 -0
  13. sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +339 -0
  14. sam2/modeling/__init__.py +5 -0
  15. sam2/modeling/backbones/__init__.py +5 -0
  16. sam2/modeling/backbones/hieradet.py +317 -0
  17. sam2/modeling/backbones/image_encoder.py +134 -0
  18. sam2/modeling/backbones/utils.py +93 -0
  19. sam2/modeling/memory_attention.py +169 -0
  20. sam2/modeling/memory_encoder.py +181 -0
  21. sam2/modeling/position_encoding.py +239 -0
  22. sam2/modeling/sam/__init__.py +5 -0
  23. sam2/modeling/sam/mask_decoder.py +295 -0
  24. sam2/modeling/sam/prompt_encoder.py +202 -0
  25. sam2/modeling/sam/transformer.py +311 -0
  26. sam2/modeling/sam2_base.py +913 -0
  27. sam2/modeling/sam2_utils.py +323 -0
  28. sam2/sam2_hiera_b+.yaml +113 -0
  29. sam2/sam2_hiera_l.yaml +117 -0
  30. sam2/sam2_hiera_s.yaml +116 -0
  31. sam2/sam2_hiera_t.yaml +118 -0
  32. sam2/sam2_image_predictor.py +466 -0
  33. sam2/sam2_video_predictor.py +1388 -0
  34. sam2/sam2_video_predictor_legacy.py +1172 -0
  35. sam2/utils/__init__.py +5 -0
  36. sam2/utils/amg.py +348 -0
  37. sam2/utils/misc.py +349 -0
  38. sam2/utils/transforms.py +118 -0
  39. singlebehaviorlab/__init__.py +4 -0
  40. singlebehaviorlab/__main__.py +130 -0
  41. singlebehaviorlab/_paths.py +100 -0
  42. singlebehaviorlab/backend/__init__.py +2 -0
  43. singlebehaviorlab/backend/augmentations.py +320 -0
  44. singlebehaviorlab/backend/data_store.py +420 -0
  45. singlebehaviorlab/backend/model.py +1290 -0
  46. singlebehaviorlab/backend/train.py +4667 -0
  47. singlebehaviorlab/backend/uncertainty.py +578 -0
  48. singlebehaviorlab/backend/video_processor.py +688 -0
  49. singlebehaviorlab/backend/video_utils.py +139 -0
  50. singlebehaviorlab/data/config/config.yaml +85 -0
  51. singlebehaviorlab/data/training_profiles.json +334 -0
  52. singlebehaviorlab/gui/__init__.py +4 -0
  53. singlebehaviorlab/gui/analysis_widget.py +2291 -0
  54. singlebehaviorlab/gui/attention_export.py +311 -0
  55. singlebehaviorlab/gui/clip_extraction_widget.py +481 -0
  56. singlebehaviorlab/gui/clustering_widget.py +3187 -0
  57. singlebehaviorlab/gui/inference_popups.py +1138 -0
  58. singlebehaviorlab/gui/inference_widget.py +4550 -0
  59. singlebehaviorlab/gui/inference_worker.py +651 -0
  60. singlebehaviorlab/gui/labeling_widget.py +2324 -0
  61. singlebehaviorlab/gui/main_window.py +754 -0
  62. singlebehaviorlab/gui/metadata_management_widget.py +1119 -0
  63. singlebehaviorlab/gui/motion_tracking.py +764 -0
  64. singlebehaviorlab/gui/overlay_export.py +1234 -0
  65. singlebehaviorlab/gui/plot_integration.py +729 -0
  66. singlebehaviorlab/gui/qt_helpers.py +29 -0
  67. singlebehaviorlab/gui/registration_widget.py +1485 -0
  68. singlebehaviorlab/gui/review_widget.py +1330 -0
  69. singlebehaviorlab/gui/segmentation_tracking_widget.py +2752 -0
  70. singlebehaviorlab/gui/tab_tutorial_dialog.py +312 -0
  71. singlebehaviorlab/gui/timeline_themes.py +131 -0
  72. singlebehaviorlab/gui/training_profiles.py +418 -0
  73. singlebehaviorlab/gui/training_widget.py +3719 -0
  74. singlebehaviorlab/gui/video_utils.py +233 -0
  75. singlebehaviorlab/licenses/SAM2-LICENSE +201 -0
  76. singlebehaviorlab/licenses/VideoPrism-LICENSE +202 -0
  77. singlebehaviorlab-2.0.0.dist-info/METADATA +447 -0
  78. singlebehaviorlab-2.0.0.dist-info/RECORD +88 -0
  79. singlebehaviorlab-2.0.0.dist-info/WHEEL +5 -0
  80. singlebehaviorlab-2.0.0.dist-info/entry_points.txt +2 -0
  81. singlebehaviorlab-2.0.0.dist-info/licenses/LICENSE +21 -0
  82. singlebehaviorlab-2.0.0.dist-info/top_level.txt +3 -0
  83. videoprism/__init__.py +0 -0
  84. videoprism/encoders.py +910 -0
  85. videoprism/layers.py +1136 -0
  86. videoprism/models.py +407 -0
  87. videoprism/tokenizers.py +167 -0
  88. videoprism/utils.py +168 -0
@@ -0,0 +1,1388 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import warnings
8
+ from collections import OrderedDict
9
+
10
+ import torch
11
+ import torch.nn.functional as F
12
+ import numpy as np
13
+ import cv2
14
+
15
+ from tqdm import tqdm
16
+
17
+ from sam2.modeling.sam2_base import NO_OBJ_SCORE, SAM2Base
18
+ from sam2.utils.misc import concat_points, fill_holes_in_mask_scores, load_video_frames
19
+
20
+
21
+ class SAM2VideoPredictor(SAM2Base):
22
+ """The predictor class to handle user interactions and manage inference states."""
23
+
24
def __init__(
    self,
    fill_hole_area=0,
    # apply non-overlapping constraints to the output object masks
    non_overlap_masks=False,
    # after correction clicks, clear the non-conditioning memory of the surrounding
    # frames (which may hold outdated information); this only affects
    # *single-object tracking* unless `clear_non_cond_mem_for_multi_obj` is also True
    clear_non_cond_mem_around_input=False,
    # True: any frame that later receives a correction click is also appended to the
    # conditioning frame list; False: the conditioning frame list is restricted to the
    # initial conditioning frames
    add_all_frames_to_correct_as_cond=False,
    **kwargs,
):
    """Construct the video predictor on top of `SAM2Base`.

    Arguments:
        fill_hole_area (int): stored as `self.fill_hole_area`; presumably the hole
            area threshold used with `fill_holes_in_mask_scores` (TODO confirm).
        non_overlap_masks (bool): stored as `self.non_overlap_masks`.
        clear_non_cond_mem_around_input (bool): stored as
            `self.clear_non_cond_mem_around_input`.
        add_all_frames_to_correct_as_cond (bool): stored as
            `self.add_all_frames_to_correct_as_cond`.
        **kwargs: forwarded unchanged to the base-class constructor.
    """
    super().__init__(**kwargs)
    # Predictor-specific options are recorded once the base model exists.
    self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond
    self.clear_non_cond_mem_around_input = clear_non_cond_mem_around_input
    self.non_overlap_masks = non_overlap_masks
    self.fill_hole_area = fill_hole_area
42
+
43
@torch.inference_mode()
def init_state(
    self,
    video_path,
    offload_video_to_cpu=False,
    offload_state_to_cpu=False,
    async_loading_frames=False,
):
    """Load a video and create a fresh inference-state dictionary for it.

    Arguments:
        video_path: forwarded to `load_video_frames`.
        offload_video_to_cpu (bool): forwarded to `load_video_frames` and recorded
            in the state.
        offload_state_to_cpu (bool): if True, the state's "storage_device" is the
            CPU; otherwise it is the model's device.
        async_loading_frames (bool): forwarded to `load_video_frames`.

    Returns:
        dict: the new inference state. Before returning, the image feature of
        frame 0 is computed through `_get_image_feature` (backbone warm-up).
    """
    device = self.device  # device of the model
    frames, height, width = load_video_frames(
        video_path=video_path,
        image_size=self.image_size,
        offload_video_to_cpu=offload_video_to_cpu,
        async_loading_frames=async_loading_frames,
        compute_device=device,
    )
    storage_device = torch.device("cpu") if offload_state_to_cpu else device
    inference_state = {
        "images": frames,
        "num_frames": len(frames),
        # Offloading the video frames to CPU memory saves GPU memory with only a
        # very small overhead.
        "offload_video_to_cpu": offload_video_to_cpu,
        # Offloading the inference state to CPU memory saves GPU memory at the cost
        # of a lower tracking fps (e.g. for a 768x768 model, fps dropped from 27 to
        # 24 when tracking one object and from 24 to 21 when tracking two objects).
        "offload_state_to_cpu": offload_state_to_cpu,
        # original video resolution, used when resizing the final output scores
        "video_height": height,
        "video_width": width,
        "device": device,
        "storage_device": storage_device,
        # per-object inputs on each frame
        "point_inputs_per_obj": {},
        "mask_inputs_per_obj": {},
        # visual features of a small number of recently visited frames
        "cached_features": {},
        # values that do not change across frames (a single copy is kept)
        "constants": {},
        # client-side object id <-> model-side object index
        "obj_id_to_idx": OrderedDict(),
        "obj_idx_to_id": OrderedDict(),
        "obj_ids": [],
        # per-object tracking results
        "output_dict_per_obj": {},
        # per-object outputs of the latest clicks/masks, merged into
        # "output_dict_per_obj" before propagation starts
        "temp_output_dict_per_obj": {},
        # per-object metadata of tracked frames (e.g. tracking direction)
        "frames_tracked_per_obj": {},
    }
    # Warm up the visual backbone and cache the image feature of frame 0.
    self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
    return inference_state
102
+
103
@classmethod
def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2VideoPredictor":
    """
    Build a video predictor from a pretrained checkpoint on the Hugging Face hub.

    Arguments:
      model_id (str): The Hugging Face repository ID.
      **kwargs: Forwarded unchanged to `build_sam2_video_predictor_hf`.

    Returns:
      (SAM2VideoPredictor): Whatever `build_sam2_video_predictor_hf` builds.
    """
    # Local import: the builder module is only needed on this code path.
    from sam2.build_sam import build_sam2_video_predictor_hf

    return build_sam2_video_predictor_hf(model_id, **kwargs)
119
+
120
+ def _obj_id_to_idx(self, inference_state, obj_id):
121
+ """Map client-side object id to model-side object index."""
122
+ obj_idx = inference_state["obj_id_to_idx"].get(obj_id, None)
123
+ if obj_idx is not None:
124
+ return obj_idx
125
+
126
+ # We always allow adding new objects (including after tracking starts).
127
+ allow_new_object = True
128
+ if allow_new_object:
129
+ # get the next object slot
130
+ obj_idx = len(inference_state["obj_id_to_idx"])
131
+ inference_state["obj_id_to_idx"][obj_id] = obj_idx
132
+ inference_state["obj_idx_to_id"][obj_idx] = obj_id
133
+ inference_state["obj_ids"] = list(inference_state["obj_id_to_idx"])
134
+ # set up input and output structures for this object
135
+ inference_state["point_inputs_per_obj"][obj_idx] = {}
136
+ inference_state["mask_inputs_per_obj"][obj_idx] = {}
137
+ inference_state["output_dict_per_obj"][obj_idx] = {
138
+ "cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
139
+ "non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
140
+ }
141
+ inference_state["temp_output_dict_per_obj"][obj_idx] = {
142
+ "cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
143
+ "non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
144
+ }
145
+ inference_state["frames_tracked_per_obj"][obj_idx] = {}
146
+ return obj_idx
147
+ else:
148
+ raise RuntimeError(
149
+ f"Cannot add new object id {obj_id} after tracking starts. "
150
+ f"All existing object ids: {inference_state['obj_ids']}. "
151
+ f"Please call 'reset_state' to restart from scratch."
152
+ )
153
+
154
+ def _obj_idx_to_id(self, inference_state, obj_idx):
155
+ """Map model-side object index to client-side object id."""
156
+ return inference_state["obj_idx_to_id"][obj_idx]
157
+
158
+ def _get_obj_num(self, inference_state):
159
+ """Get the total number of unique object ids received so far in this session."""
160
+ return len(inference_state["obj_idx_to_id"])
161
+
162
@torch.inference_mode()
def add_new_points_or_box(
    self,
    inference_state,
    frame_idx,
    obj_id,
    points=None,
    labels=None,
    clear_old_points=True,
    normalize_coords=True,
    box=None,
):
    """Add point and/or box prompts for one object on one frame and predict its mask.

    All prompt arguments are validated and converted *before* `obj_id` is
    registered. A call that raises ``ValueError`` therefore leaves no
    prompt-less object slot behind in `inference_state`; such a slot would make
    every later `propagate_in_video_preflight` call fail for that object.

    Arguments:
        inference_state (dict): state created by `init_state`; updated in place.
        frame_idx (int): index of the frame receiving the prompts.
        obj_id: client-side object id; registered on first use.
        points: point coordinates (tensor or array-like, converted to float32).
            A 2-D input gets a leading batch dimension.
        labels: point labels (tensor or array-like, converted to int32). Must be
            given together with `points`; a 1-D input gets a leading batch dimension.
        clear_old_points (bool): if False, the new points are concatenated to
            the points already stored for this object on this frame.
        normalize_coords (bool): if True, coordinates are divided by
            (video_width, video_height) first. In both cases they are then
            multiplied by `self.image_size`.
        box: optional box of four values. It is reshaped to (1, 2, 2) and
            prepended to the points with labels 2 and 3. Requires
            `clear_old_points=True`.

    Returns:
        tuple: (frame_idx, obj_ids, video_res_masks).
            - `obj_ids` is the list of all registered object ids.
            - `video_res_masks` holds the consolidated mask scores of all
              registered objects on this frame, at the original video resolution.

    Raises:
        ValueError: if only one of points/labels is given, if neither points
            nor box is given, or if a box is given with clear_old_points=False.
    """
    # --- Validate and convert the prompts (no side effects on the state yet) ---
    if (points is not None) != (labels is not None):
        raise ValueError("points and labels must be provided together")
    if points is None and box is None:
        raise ValueError("at least one of points or box must be provided as input")
    if box is not None and not clear_old_points:
        raise ValueError(
            "cannot add box without clearing old points, since "
            "box prompt must be provided before any point prompt "
            "(please use clear_old_points=True instead)"
        )

    if points is None:
        points = torch.zeros(0, 2, dtype=torch.float32)
    elif not isinstance(points, torch.Tensor):
        points = torch.tensor(points, dtype=torch.float32)
    if labels is None:
        labels = torch.zeros(0, dtype=torch.int32)
    elif not isinstance(labels, torch.Tensor):
        labels = torch.tensor(labels, dtype=torch.int32)
    if points.dim() == 2:
        points = points.unsqueeze(0)  # add batch dimension
    if labels.dim() == 1:
        labels = labels.unsqueeze(0)  # add batch dimension

    # A box is encoded as its two corner points with labels 2 and 3, placed before
    # any user-provided points (consistent with how SAM 2 is trained).
    if box is not None:
        if not isinstance(box, torch.Tensor):
            box = torch.tensor(box, dtype=torch.float32, device=points.device)
        box_coords = box.reshape(1, 2, 2)
        box_labels = torch.tensor([2, 3], dtype=torch.int32, device=labels.device)
        box_labels = box_labels.reshape(1, 2)
        points = torch.cat([box_coords, points], dim=1)
        labels = torch.cat([box_labels, labels], dim=1)

    if normalize_coords:
        video_H = inference_state["video_height"]
        video_W = inference_state["video_width"]
        points = points / torch.tensor([video_W, video_H]).to(points.device)
    # scale the (normalized) coordinates by the model's internal image size
    points = points * self.image_size
    points = points.to(inference_state["device"])
    labels = labels.to(inference_state["device"])

    # --- The prompts are valid: register the object (if new) ---
    obj_idx = self._obj_id_to_idx(inference_state, obj_id)
    point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx]
    mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx]

    if not clear_old_points:
        point_inputs = point_inputs_per_frame.get(frame_idx, None)
    else:
        point_inputs = None
    point_inputs = concat_points(point_inputs, points, labels)

    point_inputs_per_frame[frame_idx] = point_inputs
    # A point prompt replaces any mask prompt previously given on this frame.
    mask_inputs_per_frame.pop(frame_idx, None)
    # If this frame hasn't been tracked before, it is treated as an initial
    # conditioning frame: the points segment this frame without using memory from
    # other frames, like in SAM. Otherwise the points correct the tracked mask.
    obj_frames_tracked = inference_state["frames_tracked_per_obj"][obj_idx]
    is_init_cond_frame = frame_idx not in obj_frames_tracked
    # whether to track in reverse time order
    if is_init_cond_frame:
        reverse = False
    else:
        reverse = obj_frames_tracked[frame_idx]["reverse"]
    obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
    obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
    # A frame goes to the conditioning outputs if it is an initial conditioning
    # frame, or if every frame receiving clicks/masks is treated as conditioning.
    is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond
    storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"

    # Feed any previously predicted mask logits of this object on this frame,
    # together with the new clicks, into the SAM mask decoder. The temporary
    # output dict is checked first since it holds the most recent output.
    prev_sam_mask_logits = None
    prev_out = obj_temp_output_dict[storage_key].get(frame_idx)
    if prev_out is None:
        prev_out = obj_output_dict["cond_frame_outputs"].get(frame_idx)
        if prev_out is None:
            prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx)

    if prev_out is not None and prev_out["pred_masks"] is not None:
        device = inference_state["device"]
        prev_sam_mask_logits = prev_out["pred_masks"].to(device, non_blocking=True)
        # Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues.
        prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0)
    current_out, _ = self._run_single_frame_inference(
        inference_state=inference_state,
        output_dict=obj_output_dict,  # run on the slice of a single object
        frame_idx=frame_idx,
        batch_size=1,  # run on the slice of a single object
        is_init_cond_frame=is_init_cond_frame,
        point_inputs=point_inputs,
        mask_inputs=None,
        reverse=reverse,
        # The memory encoder is skipped while adding clicks or masks. It runs at
        # the beginning of `propagate_in_video` instead (after the user finalizes
        # their clicks), so that non-overlapping constraints can be enforced on
        # all objects before they are encoded into memory.
        run_mem_encoder=False,
        prev_sam_mask_logits=prev_sam_mask_logits,
    )
    # Keep the output in the temporary dict; it is merged before propagation.
    obj_temp_output_dict[storage_key][frame_idx] = current_out

    # Consolidate all objects on this frame and resize to the video resolution.
    obj_ids = inference_state["obj_ids"]
    consolidated_out = self._consolidate_temp_output_across_obj(
        inference_state,
        frame_idx,
        is_cond=is_cond,
        consolidate_at_video_res=True,
    )
    _, video_res_masks = self._get_orig_video_res_output(
        inference_state, consolidated_out["pred_masks_video_res"]
    )
    return frame_idx, obj_ids, video_res_masks
296
+
297
def add_new_points(self, *args, **kwargs):
    """Deprecated alias of `add_new_points_or_box`.

    Emits a ``DeprecationWarning`` attributed to the caller. It then forwards all
    positional and keyword arguments unchanged and returns whatever
    `add_new_points_or_box` returns.
    """
    warnings.warn(
        "`add_new_points` is deprecated; please use `add_new_points_or_box` instead.",
        DeprecationWarning,
        stacklevel=2,  # point the warning at the caller, not at this wrapper
    )
    return self.add_new_points_or_box(*args, **kwargs)
300
+
301
@torch.inference_mode()
def add_new_mask(
    self,
    inference_state,
    frame_idx,
    obj_id,
    mask,
):
    """Add a mask prompt for one object on one frame and predict its output.

    The mask is validated and resized *before* `obj_id` is registered. An
    invalid mask therefore leaves no prompt-less object slot behind in
    `inference_state`.

    Arguments:
        inference_state (dict): state created by `init_state`; updated in place.
        frame_idx (int): index of the frame receiving the mask.
        obj_id: client-side object id; registered on first use.
        mask: 2-D mask (tensor or array-like; non-tensors are converted with
            dtype=torch.bool). If its size differs from
            (self.image_size, self.image_size), it is resized with antialiased
            bilinear interpolation and re-binarized at 0.5.

    Returns:
        tuple: (frame_idx, obj_ids, video_res_masks).
            - `obj_ids` is the list of all registered object ids.
            - `video_res_masks` holds the consolidated mask scores of all
              registered objects on this frame, at the original video resolution.

    Raises:
        ValueError: if `mask` is not 2-D.
    """
    # --- Validate and prepare the mask (no side effects on the state yet) ---
    if not isinstance(mask, torch.Tensor):
        mask = torch.tensor(mask, dtype=torch.bool)
    if mask.dim() != 2:
        raise ValueError(
            f"mask must be a 2-D (H, W) array, got shape {tuple(mask.shape)}"
        )
    mask_H, mask_W = mask.shape
    mask_inputs_orig = mask[None, None]  # add batch and channel dimension
    mask_inputs_orig = mask_inputs_orig.float().to(inference_state["device"])

    # resize the mask if it doesn't match the model's image size
    if mask_H != self.image_size or mask_W != self.image_size:
        mask_inputs = torch.nn.functional.interpolate(
            mask_inputs_orig,
            size=(self.image_size, self.image_size),
            align_corners=False,
            mode="bilinear",
            antialias=True,  # use antialias for downsampling
        )
        mask_inputs = (mask_inputs >= 0.5).float()
    else:
        mask_inputs = mask_inputs_orig

    # --- The mask is valid: register the object (if new) ---
    obj_idx = self._obj_id_to_idx(inference_state, obj_id)
    point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx]
    mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx]

    mask_inputs_per_frame[frame_idx] = mask_inputs
    # A mask prompt replaces any point prompt previously given on this frame.
    point_inputs_per_frame.pop(frame_idx, None)
    # If this frame hasn't been tracked before, it is treated as an initial
    # conditioning frame: the mask is used without memory from other frames, like
    # in SAM. Otherwise the mask corrects the already tracked result.
    obj_frames_tracked = inference_state["frames_tracked_per_obj"][obj_idx]
    is_init_cond_frame = frame_idx not in obj_frames_tracked
    # whether to track in reverse time order
    if is_init_cond_frame:
        reverse = False
    else:
        reverse = obj_frames_tracked[frame_idx]["reverse"]
    obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
    obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
    # A frame goes to the conditioning outputs if it is an initial conditioning
    # frame, or if every frame receiving clicks/masks is treated as conditioning.
    is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond
    storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"

    current_out, _ = self._run_single_frame_inference(
        inference_state=inference_state,
        output_dict=obj_output_dict,  # run on the slice of a single object
        frame_idx=frame_idx,
        batch_size=1,  # run on the slice of a single object
        is_init_cond_frame=is_init_cond_frame,
        point_inputs=None,
        mask_inputs=mask_inputs,
        reverse=reverse,
        # The memory encoder is skipped while adding clicks or masks. It runs at
        # the beginning of `propagate_in_video` instead (after the user finalizes
        # their inputs), so that non-overlapping constraints can be enforced on
        # all objects before they are encoded into memory.
        run_mem_encoder=False,
    )
    # Keep the output in the temporary dict; it is merged before propagation.
    obj_temp_output_dict[storage_key][frame_idx] = current_out

    # Consolidate all objects on this frame and resize to the video resolution.
    obj_ids = inference_state["obj_ids"]
    consolidated_out = self._consolidate_temp_output_across_obj(
        inference_state,
        frame_idx,
        is_cond=is_cond,
        consolidate_at_video_res=True,
    )
    _, video_res_masks = self._get_orig_video_res_output(
        inference_state, consolidated_out["pred_masks_video_res"]
    )
    return frame_idx, obj_ids, video_res_masks
384
+
385
+ def _get_orig_video_res_output(self, inference_state, any_res_masks):
386
+ """
387
+ Resize the object scores to the original video resolution (video_res_masks)
388
+ and apply non-overlapping constraints for final output.
389
+ """
390
+ device = inference_state["device"]
391
+ video_H = inference_state["video_height"]
392
+ video_W = inference_state["video_width"]
393
+ any_res_masks = any_res_masks.to(device, non_blocking=True)
394
+ if any_res_masks.shape[-2:] == (video_H, video_W):
395
+ video_res_masks = any_res_masks
396
+ else:
397
+ video_res_masks = torch.nn.functional.interpolate(
398
+ any_res_masks,
399
+ size=(video_H, video_W),
400
+ mode="bilinear",
401
+ align_corners=False,
402
+ )
403
+ if self.non_overlap_masks:
404
+ video_res_masks = self._apply_non_overlapping_constraints(video_res_masks)
405
+ return any_res_masks, video_res_masks
406
+
407
def _consolidate_temp_output_across_obj(
    self,
    inference_state,
    frame_idx,
    is_cond,
    consolidate_at_video_res=False,
):
    """
    Merge the per-object mask outputs on `frame_idx` into one batched tensor.

    Let `storage_key` be "cond_frame_outputs" if `is_cond`, else
    "non_cond_frame_outputs". Each registered object's output is looked up in
    this order:
      1) `temp_output_dict_per_obj[obj_idx][storage_key]` (most recent interaction);
      2) `output_dict_per_obj[obj_idx]["cond_frame_outputs"]`;
      3) `output_dict_per_obj[obj_idx]["non_cond_frame_outputs"]`.
    An object with no output on this frame keeps the NO_OBJ_SCORE placeholder.
    An object mask whose spatial size differs from the consolidated resolution
    is bilinearly resized first.

    Only mask scores are consolidated. No memory encoder is run and no object
    pointers are produced here.

    Arguments:
        inference_state (dict): state created by `init_state`.
        frame_idx (int): frame to consolidate.
        is_cond (bool): selects which temporary bucket is read (see above).
        consolidate_at_video_res (bool): if True, consolidate at the original
            video resolution under key "pred_masks_video_res". Otherwise
            consolidate at `self.image_size // 4` under key "pred_masks".

    Returns:
        dict: a single-entry dict mapping the chosen key to a float32 tensor of
        shape (num_objects, 1, H, W) on `inference_state["storage_device"]`.
    """
    batch_size = self._get_obj_num(inference_state)
    storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
    # Consolidating at the original video resolution gives a better editing
    # experience for mask prompts.
    if consolidate_at_video_res:
        consolidated_H = inference_state["video_height"]
        consolidated_W = inference_state["video_width"]
        consolidated_mask_key = "pred_masks_video_res"
    else:
        consolidated_H = consolidated_W = self.image_size // 4
        consolidated_mask_key = "pred_masks"

    # Prefilled with the large negative NO_OBJ_SCORE to represent missing objects.
    consolidated_pred_masks = torch.full(
        size=(batch_size, 1, consolidated_H, consolidated_W),
        fill_value=NO_OBJ_SCORE,
        dtype=torch.float32,
        device=inference_state["storage_device"],
    )
    for obj_idx in range(batch_size):
        obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
        obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
        out = obj_temp_output_dict[storage_key].get(frame_idx, None)
        if out is None:
            out = obj_output_dict["cond_frame_outputs"].get(frame_idx, None)
        if out is None:
            out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx, None)
        if out is None:
            # No output anywhere for this object: keep the placeholder scores.
            continue
        obj_mask = out["pred_masks"]
        if obj_mask.shape[-2:] != consolidated_pred_masks.shape[-2:]:
            # The stored mask has a different resolution; resize it first.
            obj_mask = torch.nn.functional.interpolate(
                obj_mask,
                size=consolidated_pred_masks.shape[-2:],
                mode="bilinear",
                align_corners=False,
            )
        consolidated_pred_masks[obj_idx : obj_idx + 1] = obj_mask

    return {consolidated_mask_key: consolidated_pred_masks}
480
+
481
@torch.inference_mode()
def propagate_in_video_preflight(self, inference_state):
    """Move pending interaction outputs into the tracking memory before propagation.

    Each object's temporary outputs are processed in order: the
    non-conditioning bucket first, then the conditioning bucket. For each
    pending output:
      - its memory features are computed with `_run_memory_encoder` if they are
        still missing;
      - it is stored in the matching bucket of `output_dict_per_obj`;
      - if `self.clear_non_cond_mem_around_input` is set,
        `_clear_obj_non_cond_mem_around_input` is called for its frame.
    The temporary buckets are then emptied. Finally, any non-conditioning
    output on a frame that also has a conditioning output is dropped.

    Raises:
        RuntimeError: if no object is registered, or if an object ends up with
            neither conditioning nor non-conditioning outputs. The per-object
            check is deliberately relaxed: an object with only tracking history
            (non-conditioning outputs) is accepted. This supports
            streaming/sliding-window use where initial clicks may have been
            trimmed from memory.
    """
    num_objects = self._get_obj_num(inference_state)
    if num_objects == 0:
        raise RuntimeError(
            "No input points or masks are provided for any object; please add inputs first."
        )

    for obj_idx in range(num_objects):
        per_obj_outputs = inference_state["output_dict_per_obj"][obj_idx]
        per_obj_temp = inference_state["temp_output_dict_per_obj"][obj_idx]
        # Non-conditioning outputs are merged before conditioning ones.
        for key in ("non_cond_frame_outputs", "cond_frame_outputs"):
            # Frames that just received clicks or masks via
            # `add_new_points_or_box` / `add_new_mask`.
            pending = per_obj_temp[key]
            for t, out in pending.items():
                if out["maskmem_features"] is None:
                    upsampled = torch.nn.functional.interpolate(
                        out["pred_masks"].to(inference_state["device"]),
                        size=(self.image_size, self.image_size),
                        mode="bilinear",
                        align_corners=False,
                    )
                    out["maskmem_features"], out["maskmem_pos_enc"] = (
                        self._run_memory_encoder(
                            inference_state=inference_state,
                            frame_idx=t,
                            batch_size=1,  # slice of a single object
                            high_res_masks=upsampled,
                            object_score_logits=out["object_score_logits"],
                            # these are the frames the user interacted with
                            is_mask_from_pts=True,
                        )
                    )
                per_obj_outputs[key][t] = out
                if self.clear_non_cond_mem_around_input:
                    # drop non-conditioning memory of the surrounding frames
                    self._clear_obj_non_cond_mem_around_input(
                        inference_state, t, obj_idx
                    )
            pending.clear()

        per_obj_outputs = inference_state["output_dict_per_obj"][obj_idx]
        cond_outputs = per_obj_outputs["cond_frame_outputs"]
        non_cond_outputs = per_obj_outputs["non_cond_frame_outputs"]
        if not cond_outputs and not non_cond_outputs:
            obj_id = self._obj_idx_to_id(inference_state, obj_idx)
            raise RuntimeError(
                f"No input points or masks are provided for object id {obj_id}; please add inputs first."
            )
        # A conditioning output supersedes any non-conditioning one on that frame.
        for t in cond_outputs:
            non_cond_outputs.pop(t, None)
548
+
549
def _get_optical_flow_prediction(self, prev_img, curr_img, prev_mask):
    """
    Estimate where the masked object's centroid moves between two frames.

    Uses sparse Lucas-Kanade optical flow:
      1) corners are detected inside the binarized previous mask (logits > 0);
      2) they are tracked into the current frame;
      3) their mean displacement is added to the centroid of the previous mask.

    Arguments:
        prev_img, curr_img: image tensors, assumed (3, H, W) and normalized
            (i.e. not in [0, 255]); TODO confirm against the caller. They are
            converted to float32 before going to numpy, so half/bfloat16 inputs
            also work.
        prev_mask: mask logits of the previous frame. After `squeeze()` this
            must be 2-D. If its size differs from the image size (H, W), it is
            resized to (H, W) with nearest-neighbour interpolation, as required
            by `cv2.goodFeaturesToTrack`.

    Returns:
        torch.Tensor | None: float32 tensor [[x, y]] in image pixel coordinates,
        clamped to the image bounds. Returns None when:
          - the mask is empty;
          - no features are found, or none are tracked successfully;
          - any step fails (this is a best-effort heuristic).
    """
    try:
        # Channel mean -> grayscale.
        prev_np = prev_img.float().mean(0).cpu().numpy()
        curr_np = curr_img.float().mean(0).cpu().numpy()

        # Map both frames to uint8 with ONE shared min/max. Lucas-Kanade assumes
        # brightness constancy between the two frames; stretching each frame by
        # its own min/max would shift the gray level of unchanged scene points.
        lo = float(min(prev_np.min(), curr_np.min()))
        hi = float(max(prev_np.max(), curr_np.max()))
        scale = 255.0 / (hi - lo) if hi > lo else 0.0
        prev_gray = np.clip((prev_np - lo) * scale, 0, 255).astype(np.uint8)
        curr_gray = np.clip((curr_np - lo) * scale, 0, 255).astype(np.uint8)
        H, W = prev_gray.shape

        # Binary uint8 mask of the object on the previous frame.
        mask_np = (prev_mask.float().squeeze().cpu().numpy() > 0.0).astype(np.uint8)
        if mask_np.sum() == 0:
            return None
        if mask_np.shape != (H, W):
            # goodFeaturesToTrack requires the mask to match the image size.
            mask_np = cv2.resize(mask_np, (W, H), interpolation=cv2.INTER_NEAREST)
            if not mask_np.any():
                return None

        # Find features to track inside the previous mask.
        p0 = cv2.goodFeaturesToTrack(
            prev_gray, mask=mask_np, maxCorners=50, qualityLevel=0.1,
            minDistance=5, blockSize=7,
        )
        if p0 is None:
            return None

        p1, status, _err = cv2.calcOpticalFlowPyrLK(
            prev_gray, curr_gray, p0, None, winSize=(15, 15), maxLevel=2,
            criteria=(cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, 10, 0.03),
        )
        if p1 is None:
            return None
        tracked = status.reshape(-1) == 1
        if not tracked.any():
            return None

        # Average motion vector of the successfully tracked features.
        motion = p1.reshape(-1, 2)[tracked] - p0.reshape(-1, 2)[tracked]
        dx, dy = motion.mean(axis=0)

        # Centroid of the previous mask, shifted by the average motion.
        M = cv2.moments(mask_np)
        if M["m00"] == 0:
            return None
        new_cX = M["m10"] / M["m00"] + float(dx)
        new_cY = M["m01"] / M["m00"] + float(dy)

        # Keep the prediction inside the image.
        new_cX = max(0, min(new_cX, W - 1))
        new_cY = max(0, min(new_cY, H - 1))
        return torch.tensor([[new_cX, new_cY]], dtype=torch.float32)
    except Exception:
        # Deliberate best-effort fallback: this is only a motion heuristic, so
        # any failure (e.g. cv2.error on unexpected shapes) means "no prediction".
        return None
609
+
610
@torch.inference_mode()
def propagate_in_video(
    self,
    inference_state,
    start_frame_idx=None,
    max_frame_num_to_track=None,
    reverse=False,
    use_motion_heuristics=False,
):
    """
    Propagate the input prompts across frames to track every object in the video.

    This is a generator. For each processed frame it yields
    ``(frame_idx, obj_ids, video_res_masks)``, where ``video_res_masks`` are the
    per-object mask logits resized by ``_get_orig_video_res_output``.

    Args:
        inference_state: the tracking session state dict.
        start_frame_idx: first frame to process. Defaults to the earliest frame that
            holds a conditioning output for any object.
        max_frame_num_to_track: how many frames to step away from ``start_frame_idx``;
            the start frame itself is processed in addition. Defaults to the number
            of frames in the video.
        reverse: if True, track backwards towards frame 0.
        use_motion_heuristics: if True, frames without user inputs get a synthesized
            positive point prompt, but only when the previous frame's object score
            logit was positive. The point comes from Lucas-Kanade optical flow first,
            then from a smoothed constant-velocity extrapolation of the mask centroid.
            Centroids are tracked in the pixel coordinates of
            ``inference_state["images"]`` (the same frame used to clamp the point to
            the image bounds), not in the resolution of the stored masks.
    """
    self.propagate_in_video_preflight(inference_state)

    obj_ids = inference_state["obj_ids"]
    num_frames = inference_state["num_frames"]
    batch_size = self._get_obj_num(inference_state)
    device = inference_state["device"]

    # set start index, end index, and processing order
    if start_frame_idx is None:
        # default: start from the earliest frame with input points
        start_frame_idx = min(
            t
            for obj_output_dict in inference_state["output_dict_per_obj"].values()
            for t in obj_output_dict["cond_frame_outputs"]
        )
    if max_frame_num_to_track is None:
        # default: track all the frames in the video
        max_frame_num_to_track = num_frames
    if reverse:
        end_frame_idx = max(start_frame_idx - max_frame_num_to_track, 0)
        if start_frame_idx > 0:
            processing_order = range(start_frame_idx, end_frame_idx - 1, -1)
        else:
            processing_order = []  # skip reverse tracking if starting from frame 0
    else:
        end_frame_idx = min(start_frame_idx + max_frame_num_to_track, num_frames - 1)
        processing_order = range(start_frame_idx, end_frame_idx + 1)

    # Per-object motion state, in image pixel coordinates:
    # {"last_centroid": (x, y) or None, "velocity": (vx, vy)}
    obj_motion_state = {
        obj_idx: {"last_centroid": None, "velocity": (0.0, 0.0)}
        for obj_idx in range(batch_size)
    }

    prev_frame_idx = None
    for frame_idx in tqdm(processing_order, desc="propagate in video"):
        # Resolution of the model-input frames; centroids are expressed in it.
        image_hw = (
            tuple(inference_state["images"][frame_idx].shape[-2:])
            if use_motion_heuristics
            else None
        )
        pred_masks_per_obj = [None] * batch_size
        for obj_idx in range(batch_size):
            obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
            motion_state = obj_motion_state[obj_idx]

            # We skip those frames already in consolidated outputs (these are frames
            # that received input clicks or mask). Note that we cannot directly run
            # batched forward on them via `_run_single_frame_inference` because the
            # number of clicks on each object might be different.
            if frame_idx in obj_output_dict["cond_frame_outputs"]:
                current_out = obj_output_dict["cond_frame_outputs"][frame_idx]
                pred_masks = current_out["pred_masks"].to(device, non_blocking=True)
                if self.clear_non_cond_mem_around_input:
                    # clear non-conditioning memory of the surrounding frames
                    self._clear_obj_non_cond_mem_around_input(
                        inference_state, frame_idx, obj_idx
                    )
            else:
                point_inputs = None
                if use_motion_heuristics and prev_frame_idx is not None:
                    pred_point = self._predict_motion_point(
                        inference_state,
                        obj_output_dict,
                        motion_state,
                        prev_frame_idx,
                        frame_idx,
                    )
                    if pred_point is not None:
                        point_inputs = {
                            # [1, 1, 2]
                            "point_coords": pred_point.unsqueeze(0).to(device),
                            # [1, 1]
                            "point_labels": torch.ones(
                                1, 1, dtype=torch.int32, device=device
                            ),
                        }
                current_out, pred_masks = self._run_single_frame_inference(
                    inference_state=inference_state,
                    output_dict=obj_output_dict,
                    frame_idx=frame_idx,
                    batch_size=1,  # run on the slice of a single object
                    is_init_cond_frame=False,
                    point_inputs=point_inputs,
                    mask_inputs=None,
                    reverse=reverse,
                    run_mem_encoder=True,
                )
                obj_output_dict["non_cond_frame_outputs"][frame_idx] = current_out

            # Update the motion state from user inputs or from the new prediction.
            if use_motion_heuristics:
                self._update_motion_state_from_mask(motion_state, pred_masks, image_hw)

            inference_state["frames_tracked_per_obj"][obj_idx][frame_idx] = {
                "reverse": reverse
            }
            pred_masks_per_obj[obj_idx] = pred_masks

        # Resize the output mask to the original video resolution (we directly use
        # the mask scores on GPU for output to avoid any CPU conversion in between)
        if len(pred_masks_per_obj) > 1:
            all_pred_masks = torch.cat(pred_masks_per_obj, dim=0)
        else:
            all_pred_masks = pred_masks_per_obj[0]
        _, video_res_masks = self._get_orig_video_res_output(
            inference_state, all_pred_masks
        )
        prev_frame_idx = frame_idx
        yield frame_idx, obj_ids, video_res_masks


def _predict_motion_point(
    self, inference_state, obj_output_dict, motion_state, prev_frame_idx, frame_idx
):
    """
    Synthesize a positive point prompt for ``frame_idx`` from the previous frame.

    Returns None unless the previous frame has a stored output with a positive
    ``object_score_logits`` entry. Otherwise it tries ``_get_optical_flow_prediction``
    and falls back to extrapolating ``motion_state["last_centroid"]`` by the smoothed
    velocity.

    Returns:
        A [1, 2] float32 tensor of (x, y) clamped to the image bounds, or None.
    """
    prev_out = obj_output_dict["non_cond_frame_outputs"].get(prev_frame_idx)
    if prev_out is None:
        prev_out = obj_output_dict["cond_frame_outputs"].get(prev_frame_idx)
    # Confidence check: only propagate if the model thought the object was
    # present on the previous frame (positive object score logit).
    if prev_out is None or "object_score_logits" not in prev_out:
        return None
    if not (prev_out["object_score_logits"] > 0.0).any():
        return None

    prev_img = inference_state["images"][prev_frame_idx]
    curr_img = inference_state["images"][frame_idx]

    # 1. Optical flow
    pred_point = self._get_optical_flow_prediction(
        prev_img, curr_img, prev_out["pred_masks"]
    )
    if pred_point is not None or motion_state["last_centroid"] is None:
        return pred_point

    # 2. Constant-velocity fallback
    last_cx, last_cy = motion_state["last_centroid"]
    vx, vy = motion_state["velocity"]
    H, W = prev_img.shape[1], prev_img.shape[2]
    new_cx = max(0, min(last_cx + vx, W - 1))
    new_cy = max(0, min(last_cy + vy, H - 1))
    return torch.tensor([[new_cx, new_cy]], dtype=torch.float32)


def _update_motion_state_from_mask(self, motion_state, pred_masks, image_hw):
    """
    Update ``motion_state`` in place from a predicted mask.

    The centroid of the binarized mask (logits > 0) is mapped from mask resolution
    to ``image_hw = (H, W)``. This is needed because the stored masks can be smaller
    than the images, and the centroid is later used as a point prompt clamped to the
    image size. Velocity is smoothed as 0.7 * new + 0.3 * old. An empty or non-2-D
    mask leaves the state unchanged.
    """
    mask = (pred_masks.detach().float().squeeze() > 0.0).cpu().numpy()
    if mask.ndim != 2:
        return
    ys, xs = np.nonzero(mask)
    if xs.size == 0:
        return
    mask_h, mask_w = mask.shape
    img_h, img_w = image_hw
    # Map pixel centres: mask index i covers [i * s, (i + 1) * s) in the image.
    cx = (float(xs.mean()) + 0.5) * (img_w / mask_w) - 0.5
    cy = (float(ys.mean()) + 0.5) * (img_h / mask_h) - 0.5

    if motion_state["last_centroid"] is not None:
        last_cx, last_cy = motion_state["last_centroid"]
        old_vx, old_vy = motion_state["velocity"]
        # Exponential moving average (0.7 current, 0.3 history)
        motion_state["velocity"] = (
            0.7 * (cx - last_cx) + 0.3 * old_vx,
            0.7 * (cy - last_cy) + 0.3 * old_vy,
        )
    motion_state["last_centroid"] = (cx, cy)
796
+
797
@torch.inference_mode()
def clear_all_prompts_in_frame(
    self, inference_state, frame_idx, obj_id, need_output=True
):
    """
    Drop every point/mask prompt that object ``obj_id`` has on ``frame_idx``.

    Any conditioning output the object had on that frame is kept but demoted to the
    non-conditioning outputs, and the frame is un-marked as tracked for the object.

    Returns:
        None when ``need_output`` is False. Otherwise ``(frame_idx, obj_ids,
        video_res_masks)``, with the masks of all objects on that frame
        re-consolidated after the removal.
    """
    obj_idx = self._obj_id_to_idx(inference_state, obj_id)

    # Forget the raw prompts on this frame.
    for inputs_key in ("point_inputs_per_obj", "mask_inputs_per_obj"):
        inference_state[inputs_key][obj_idx].pop(frame_idx, None)

    # Forget any not-yet-consolidated outputs on this frame.
    temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
    obj_temp_outputs = temp_output_dict_per_obj[obj_idx]
    for output_key in ("cond_frame_outputs", "non_cond_frame_outputs"):
        obj_temp_outputs[output_key].pop(frame_idx, None)

    # Without inputs the frame can no longer be a conditioning frame, so its
    # output (if any) is demoted to a non-conditioning one.
    obj_outputs = inference_state["output_dict_per_obj"][obj_idx]
    demoted = obj_outputs["cond_frame_outputs"].pop(frame_idx, None)
    if demoted is not None:
        obj_outputs["non_cond_frame_outputs"][frame_idx] = demoted
        inference_state["frames_tracked_per_obj"][obj_idx].pop(frame_idx, None)

    if need_output:
        obj_ids = inference_state["obj_ids"]
        still_cond = any(
            frame_idx in per_obj["cond_frame_outputs"]
            for per_obj in temp_output_dict_per_obj.values()
        )
        consolidated = self._consolidate_temp_output_across_obj(
            inference_state,
            frame_idx,
            is_cond=still_cond,
            consolidate_at_video_res=True,
        )
        _, video_res_masks = self._get_orig_video_res_output(
            inference_state, consolidated["pred_masks_video_res"]
        )
        return frame_idx, obj_ids, video_res_masks
839
+
840
@torch.inference_mode()
def reset_state(self, inference_state):
    """
    Wipe every prompt and tracking result and forget all registered objects.

    Tracking results are cleared first through ``_reset_tracking_results``. Then the
    object-id bookkeeping and all per-object containers are emptied in place, so
    existing references to these containers see them empty.
    """
    self._reset_tracking_results(inference_state)
    for state_key in (
        "obj_id_to_idx",
        "obj_idx_to_id",
        "obj_ids",
        "point_inputs_per_obj",
        "mask_inputs_per_obj",
        "output_dict_per_obj",
        "temp_output_dict_per_obj",
        "frames_tracked_per_obj",
    ):
        inference_state[state_key].clear()
853
+
854
+ def _reset_tracking_results(self, inference_state):
855
+ """Reset all tracking inputs and results across the videos."""
856
+ for v in inference_state["point_inputs_per_obj"].values():
857
+ v.clear()
858
+ for v in inference_state["mask_inputs_per_obj"].values():
859
+ v.clear()
860
+ for v in inference_state["output_dict_per_obj"].values():
861
+ v["cond_frame_outputs"].clear()
862
+ v["non_cond_frame_outputs"].clear()
863
+ for v in inference_state["temp_output_dict_per_obj"].values():
864
+ v["cond_frame_outputs"].clear()
865
+ v["non_cond_frame_outputs"].clear()
866
+ for v in inference_state["frames_tracked_per_obj"].values():
867
+ v.clear()
868
+
869
+ def _get_image_feature(self, inference_state, frame_idx, batch_size):
870
+ """Compute the image features on a given frame."""
871
+ # Look up in the cache first
872
+ image, backbone_out = inference_state["cached_features"].get(
873
+ frame_idx, (None, None)
874
+ )
875
+ if backbone_out is None:
876
+ # Cache miss -- we will run inference on a single image
877
+ device = inference_state["device"]
878
+ image = inference_state["images"][frame_idx].to(device).float().unsqueeze(0)
879
+ backbone_out = self.forward_image(image)
880
+ # Cache the most recent frame's feature (for repeated interactions with
881
+ # a frame; we can use an LRU cache for more frames in the future).
882
+ inference_state["cached_features"] = {frame_idx: (image, backbone_out)}
883
+
884
+ # expand the features to have the same dimension as the number of objects
885
+ expanded_image = image.expand(batch_size, -1, -1, -1)
886
+ expanded_backbone_out = {
887
+ "backbone_fpn": backbone_out["backbone_fpn"].copy(),
888
+ "vision_pos_enc": backbone_out["vision_pos_enc"].copy(),
889
+ }
890
+ for i, feat in enumerate(expanded_backbone_out["backbone_fpn"]):
891
+ expanded_backbone_out["backbone_fpn"][i] = feat.expand(
892
+ batch_size, -1, -1, -1
893
+ )
894
+ for i, pos in enumerate(expanded_backbone_out["vision_pos_enc"]):
895
+ pos = pos.expand(batch_size, -1, -1, -1)
896
+ expanded_backbone_out["vision_pos_enc"][i] = pos
897
+
898
+ features = self._prepare_backbone_features(expanded_backbone_out)
899
+ features = (expanded_image,) + features
900
+ return features
901
+
902
+ def _run_single_frame_inference(
903
+ self,
904
+ inference_state,
905
+ output_dict,
906
+ frame_idx,
907
+ batch_size,
908
+ is_init_cond_frame,
909
+ point_inputs,
910
+ mask_inputs,
911
+ reverse,
912
+ run_mem_encoder,
913
+ prev_sam_mask_logits=None,
914
+ ):
915
+ """Run tracking on a single frame based on current inputs and previous memory."""
916
+ # Retrieve correct image features
917
+ (
918
+ _,
919
+ _,
920
+ current_vision_feats,
921
+ current_vision_pos_embeds,
922
+ feat_sizes,
923
+ ) = self._get_image_feature(inference_state, frame_idx, batch_size)
924
+
925
+ # point and mask should not appear as input simultaneously on the same frame
926
+ assert point_inputs is None or mask_inputs is None
927
+ current_out = self.track_step(
928
+ frame_idx=frame_idx,
929
+ is_init_cond_frame=is_init_cond_frame,
930
+ current_vision_feats=current_vision_feats,
931
+ current_vision_pos_embeds=current_vision_pos_embeds,
932
+ feat_sizes=feat_sizes,
933
+ point_inputs=point_inputs,
934
+ mask_inputs=mask_inputs,
935
+ output_dict=output_dict,
936
+ num_frames=inference_state["num_frames"],
937
+ track_in_reverse=reverse,
938
+ run_mem_encoder=run_mem_encoder,
939
+ prev_sam_mask_logits=prev_sam_mask_logits,
940
+ )
941
+
942
+ # optionally offload the output to CPU memory to save GPU space
943
+ storage_device = inference_state["storage_device"]
944
+ maskmem_features = current_out["maskmem_features"]
945
+ if maskmem_features is not None:
946
+ maskmem_features = maskmem_features.to(torch.bfloat16)
947
+ maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
948
+ pred_masks_gpu = current_out["pred_masks"]
949
+ # potentially fill holes in the predicted masks
950
+ if self.fill_hole_area > 0:
951
+ pred_masks_gpu = fill_holes_in_mask_scores(
952
+ pred_masks_gpu, self.fill_hole_area
953
+ )
954
+ pred_masks = pred_masks_gpu.to(storage_device, non_blocking=True)
955
+ # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
956
+ maskmem_pos_enc = self._get_maskmem_pos_enc(inference_state, current_out)
957
+ # object pointer is a small tensor, so we always keep it on GPU memory for fast access
958
+ obj_ptr = current_out["obj_ptr"]
959
+ object_score_logits = current_out["object_score_logits"]
960
+ # make a compact version of this frame's output to reduce the state size
961
+ compact_current_out = {
962
+ "maskmem_features": maskmem_features,
963
+ "maskmem_pos_enc": maskmem_pos_enc,
964
+ "pred_masks": pred_masks,
965
+ "obj_ptr": obj_ptr,
966
+ "object_score_logits": object_score_logits,
967
+ }
968
+ return compact_current_out, pred_masks_gpu
969
+
970
+ def _run_memory_encoder(
971
+ self,
972
+ inference_state,
973
+ frame_idx,
974
+ batch_size,
975
+ high_res_masks,
976
+ object_score_logits,
977
+ is_mask_from_pts,
978
+ ):
979
+ """
980
+ Run the memory encoder on `high_res_masks`. This is usually after applying
981
+ non-overlapping constraints to object scores. Since their scores changed, their
982
+ memory also need to be computed again with the memory encoder.
983
+ """
984
+ # Retrieve correct image features
985
+ _, _, current_vision_feats, _, feat_sizes = self._get_image_feature(
986
+ inference_state, frame_idx, batch_size
987
+ )
988
+ maskmem_features, maskmem_pos_enc = self._encode_new_memory(
989
+ current_vision_feats=current_vision_feats,
990
+ feat_sizes=feat_sizes,
991
+ pred_masks_high_res=high_res_masks,
992
+ object_score_logits=object_score_logits,
993
+ is_mask_from_pts=is_mask_from_pts,
994
+ )
995
+
996
+ # optionally offload the output to CPU memory to save GPU space
997
+ storage_device = inference_state["storage_device"]
998
+ maskmem_features = maskmem_features.to(torch.bfloat16)
999
+ maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
1000
+ # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
1001
+ maskmem_pos_enc = self._get_maskmem_pos_enc(
1002
+ inference_state, {"maskmem_pos_enc": maskmem_pos_enc}
1003
+ )
1004
+ return maskmem_features, maskmem_pos_enc
1005
+
1006
+ def _get_maskmem_pos_enc(self, inference_state, current_out):
1007
+ """
1008
+ `maskmem_pos_enc` is the same across frames and objects, so we cache it as
1009
+ a constant in the inference session to reduce session storage size.
1010
+ """
1011
+ model_constants = inference_state["constants"]
1012
+ # "out_maskmem_pos_enc" should be either a list of tensors or None
1013
+ out_maskmem_pos_enc = current_out["maskmem_pos_enc"]
1014
+ if out_maskmem_pos_enc is not None:
1015
+ if "maskmem_pos_enc" not in model_constants:
1016
+ assert isinstance(out_maskmem_pos_enc, list)
1017
+ # only take the slice for one object, since it's same across objects
1018
+ maskmem_pos_enc = [x[0:1].clone() for x in out_maskmem_pos_enc]
1019
+ model_constants["maskmem_pos_enc"] = maskmem_pos_enc
1020
+ else:
1021
+ maskmem_pos_enc = model_constants["maskmem_pos_enc"]
1022
+ # expand the cached maskmem_pos_enc to the actual batch size
1023
+ batch_size = out_maskmem_pos_enc[0].size(0)
1024
+ expanded_maskmem_pos_enc = [
1025
+ x.expand(batch_size, -1, -1, -1) for x in maskmem_pos_enc
1026
+ ]
1027
+ else:
1028
+ expanded_maskmem_pos_enc = None
1029
+ return expanded_maskmem_pos_enc
1030
+
1031
@torch.inference_mode()
def remove_object(self, inference_state, obj_id, strict=False, need_output=True):
    """
    Remove an object id from the tracking state.

    Args:
        inference_state: the tracking session state dict (mutated in place).
        obj_id: the object id to remove.
        strict: if True, raise RuntimeError when ``obj_id`` is not registered;
            otherwise such a call returns without changing anything.
        need_output: if True, re-consolidate the masks of the remaining objects on
            every frame where the removed object had point or mask inputs.

    Returns:
        ``(obj_ids, updated_frames)``: the remaining object ids, and a list of
        ``(frame_idx, video_res_masks)`` for every re-consolidated frame (empty when
        nothing was recomputed).

    Raises:
        RuntimeError: if ``strict`` is True and ``obj_id`` does not exist.
    """
    old_obj_idx_to_rm = inference_state["obj_id_to_idx"].get(obj_id, None)
    updated_frames = []
    # Check whether this object_id to remove actually exists and possibly raise an error.
    if old_obj_idx_to_rm is None:
        if not strict:
            return inference_state["obj_ids"], updated_frames
        raise RuntimeError(
            f"Cannot remove object id {obj_id} as it doesn't exist. "
            f"All existing object ids: {inference_state['obj_ids']}."
        )

    # If this is the only remaining object id, we simply reset the state.
    if len(inference_state["obj_id_to_idx"]) == 1:
        self.reset_state(inference_state)
        return inference_state["obj_ids"], updated_frames

    # There are still remaining objects after removing this object id. In this case,
    # we need to delete the object storage from inference state tensors.
    # Step 0: clear the input on those frames where this object id has point or mask input
    # (note that this step is required as it might downgrade conditioning frames to
    # non-conditioning ones)
    obj_input_frames_inds = set()
    obj_input_frames_inds.update(
        inference_state["point_inputs_per_obj"][old_obj_idx_to_rm]
    )
    obj_input_frames_inds.update(
        inference_state["mask_inputs_per_obj"][old_obj_idx_to_rm]
    )
    for frame_idx in obj_input_frames_inds:
        self.clear_all_prompts_in_frame(
            inference_state, frame_idx, obj_id, need_output=False
        )

    # Step 1: Update the object id mapping (note that it must be done after Step 0,
    # since Step 0 still requires the old object id mappings in inference_state)
    old_obj_ids = inference_state["obj_ids"]
    old_obj_inds = list(range(len(old_obj_ids)))
    remain_old_obj_inds = old_obj_inds.copy()
    remain_old_obj_inds.remove(old_obj_idx_to_rm)
    new_obj_ids = [old_obj_ids[old_idx] for old_idx in remain_old_obj_inds]
    # Remaining objects are renumbered densely as 0..len(new_obj_ids)-1, keeping their order.
    new_obj_inds = list(range(len(new_obj_ids)))
    # build new mappings
    old_idx_to_new_idx = dict(zip(remain_old_obj_inds, new_obj_inds))
    inference_state["obj_id_to_idx"] = dict(zip(new_obj_ids, new_obj_inds))
    inference_state["obj_idx_to_id"] = dict(zip(new_obj_inds, new_obj_ids))
    inference_state["obj_ids"] = new_obj_ids

    # Step 2: For per-object tensor storage, we shift their obj_idx in the dict keys.
    def _map_keys(container):
        # Pop every old key first and re-insert afterwards, so a renumbered entry
        # cannot overwrite an old key that has not been moved yet. The removed
        # object's entry is popped and dropped.
        new_kvs = []
        for k in old_obj_inds:
            v = container.pop(k)
            if k in old_idx_to_new_idx:
                new_kvs.append((old_idx_to_new_idx[k], v))
        container.update(new_kvs)

    _map_keys(inference_state["point_inputs_per_obj"])
    _map_keys(inference_state["mask_inputs_per_obj"])
    _map_keys(inference_state["output_dict_per_obj"])
    _map_keys(inference_state["temp_output_dict_per_obj"])
    _map_keys(inference_state["frames_tracked_per_obj"])

    # Step 3: Further collect the outputs on those frames in `obj_input_frames_inds`, which
    # could show an updated mask for objects previously occluded by the object being removed
    if need_output:
        temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
        for frame_idx in obj_input_frames_inds:
            # The frame stays "conditioning" if any remaining object still has a
            # temporary conditioning output on it.
            is_cond = any(
                frame_idx in obj_temp_output_dict["cond_frame_outputs"]
                for obj_temp_output_dict in temp_output_dict_per_obj.values()
            )
            consolidated_out = self._consolidate_temp_output_across_obj(
                inference_state,
                frame_idx,
                is_cond=is_cond,
                consolidate_at_video_res=True,
            )
            _, video_res_masks = self._get_orig_video_res_output(
                inference_state, consolidated_out["pred_masks_video_res"]
            )
            updated_frames.append((frame_idx, video_res_masks))

    return inference_state["obj_ids"], updated_frames
1120
+
1121
+ def _clear_non_cond_mem_around_input(self, inference_state, frame_idx):
1122
+ """
1123
+ Remove the non-conditioning memory around the input frame. When users provide
1124
+ correction clicks, the surrounding frames' non-conditioning memories can still
1125
+ contain outdated object appearance information and could confuse the model.
1126
+
1127
+ This method clears those non-conditioning memories surrounding the interacted
1128
+ frame to avoid giving the model both old and new information about the object.
1129
+ """
1130
+ r = self.memory_temporal_stride_for_eval
1131
+ frame_idx_begin = frame_idx - r * self.num_maskmem
1132
+ frame_idx_end = frame_idx + r * self.num_maskmem
1133
+ batch_size = self._get_obj_num(inference_state)
1134
+ for obj_idx in range(batch_size):
1135
+ obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
1136
+ non_cond_frame_outputs = obj_output_dict["non_cond_frame_outputs"]
1137
+ for t in range(frame_idx_begin, frame_idx_end + 1):
1138
+ non_cond_frame_outputs.pop(t, None)
1139
+
1140
+
1141
class SAM2VideoPredictorVOS(SAM2VideoPredictor):
    """
    Optimized for the VOS setting.

    On construction, the ``forward`` methods of the memory encoder, memory attention,
    SAM prompt encoder and SAM mask decoder are replaced by ``torch.compile``
    wrappers (``mode="max-autotune"``, ``fullgraph=True``). ``forward_image``,
    ``_forward_sam_heads`` and ``_encode_new_memory`` are overridden so that tensors
    produced by compiled components are cloned before further use.
    """

    def __init__(self, *args, **kwargs):
        # All arguments are forwarded unchanged to SAM2VideoPredictor; compilation
        # is applied after the parent constructor has built the submodules.
        super().__init__(*args, **kwargs)
        self._compile_all_components()

    def _compile_all_components(self):
        """
        Replace the ``forward`` of four submodules with ``torch.compile`` wrappers.

        Only memory attention is compiled with ``dynamic=True``, because the number
        of memories varies. The prompt encoder and mask decoder stay static because
        of an accuracy regression with ``dynamic=True`` (per the inline notes).
        """
        print("Compiling all components for VOS setting. First time may be very slow.")
        self.memory_encoder.forward = torch.compile(
            self.memory_encoder.forward,
            mode="max-autotune",
            fullgraph=True,
            dynamic=False,
        )

        self.memory_attention.forward = torch.compile(
            self.memory_attention.forward,
            mode="max-autotune",
            fullgraph=True,
            dynamic=True,  # Num. of memories varies
        )

        self.sam_prompt_encoder.forward = torch.compile(
            self.sam_prompt_encoder.forward,
            mode="max-autotune",
            fullgraph=True,
            dynamic=False,  # Accuracy regression on True
        )

        self.sam_mask_decoder.forward = torch.compile(
            self.sam_mask_decoder.forward,
            mode="max-autotune",
            fullgraph=True,
            dynamic=False,  # Accuracy regression on True
        )

    def forward_image(self, img_batch: torch.Tensor):
        """
        Identical to the corresponding method in the parent (SAM2VideoPredictor), but
        cloning the backbone features and pos encoding to enable compilation.

        Returns the image encoder's output dict. When ``use_high_res_features_in_sam``
        is set, levels 0 and 1 of ``"backbone_fpn"`` are already projected by the
        mask decoder's ``conv_s0`` / ``conv_s1``.
        """
        backbone_out = self.image_encoder(img_batch)
        if self.use_high_res_features_in_sam:
            # precompute projected level 0 and level 1 features in SAM decoder
            # to avoid running it again on every SAM click
            backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0(
                backbone_out["backbone_fpn"][0]
            )
            backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1(
                backbone_out["backbone_fpn"][1]
            )
        # Clone to help torch.compile
        # NOTE(review): indexes "vision_pos_enc" with the range of "backbone_fpn";
        # assumes both lists have the same length - confirm.
        for i in range(len(backbone_out["backbone_fpn"])):
            backbone_out["backbone_fpn"][i] = backbone_out["backbone_fpn"][i].clone()
            backbone_out["vision_pos_enc"][i] = backbone_out["vision_pos_enc"][
                i
            ].clone()
        return backbone_out

    def _forward_sam_heads(
        self,
        backbone_features,
        point_inputs=None,
        mask_inputs=None,
        high_res_features=None,
        multimask_output=False,
    ):
        """
        Identical to the corresponding method in the parent (SAM2VideoPredictor), but
        cloning the outputs of prompt_encoder and mask_decoder to enable compilation.

        Args:
            backbone_features: image embeddings of shape
                [B, sam_prompt_embed_dim, sam_image_embedding_size,
                sam_image_embedding_size] (checked by the asserts below).
            point_inputs: optional dict with "point_coords" and "point_labels", each
                with batch size B.
            mask_inputs: optional [B, 1, H, W] mask prompt; it is resized to the
                prompt encoder's mask input size if needed.
            high_res_features: optional high-resolution features passed to the mask
                decoder.
            multimask_output: if True, the decoder predicts several masks and the one
                with the highest predicted IoU is selected per batch element.

        Returns:
            ``(low_res_multimasks, high_res_multimasks, ious, low_res_masks,
            high_res_masks, obj_ptr, object_score_logits)``
        """
        B = backbone_features.size(0)
        device = backbone_features.device
        assert backbone_features.size(1) == self.sam_prompt_embed_dim
        assert backbone_features.size(2) == self.sam_image_embedding_size
        assert backbone_features.size(3) == self.sam_image_embedding_size

        # a) Handle point prompts
        if point_inputs is not None:
            sam_point_coords = point_inputs["point_coords"]
            sam_point_labels = point_inputs["point_labels"]
            assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B
        else:
            # If no points are provided, pad with an empty point (with label -1)
            sam_point_coords = torch.zeros(B, 1, 2, device=device)
            sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device)

        # b) Handle mask prompts
        if mask_inputs is not None:
            # If mask_inputs is provided, downsize it into low-res mask input if needed
            # and feed it as a dense mask prompt into the SAM mask encoder
            assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1)
            if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size:
                sam_mask_prompt = F.interpolate(
                    mask_inputs.float(),
                    size=self.sam_prompt_encoder.mask_input_size,
                    align_corners=False,
                    mode="bilinear",
                    antialias=True,  # use antialias for downsampling
                )
            else:
                sam_mask_prompt = mask_inputs
        else:
            # Otherwise, simply feed None (and SAM's prompt encoder will add
            # a learned `no_mask_embed` to indicate no mask input in this case).
            sam_mask_prompt = None

        sparse_embeddings, dense_embeddings = self.sam_prompt_encoder(
            points=(sam_point_coords, sam_point_labels),
            boxes=None,
            masks=sam_mask_prompt,
        )
        # Clone image_pe and the outputs of sam_prompt_encoder
        # to enable compilation
        sparse_embeddings = sparse_embeddings.clone()
        dense_embeddings = dense_embeddings.clone()
        image_pe = self.sam_prompt_encoder.get_dense_pe().clone()
        (
            low_res_multimasks,
            ious,
            sam_output_tokens,
            object_score_logits,
        ) = self.sam_mask_decoder(
            image_embeddings=backbone_features,
            image_pe=image_pe,
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=multimask_output,
            repeat_image=False,  # the image is already batched
            high_res_features=high_res_features,
        )
        # Clone the output of sam_mask_decoder
        # to enable compilation
        low_res_multimasks = low_res_multimasks.clone()
        ious = ious.clone()
        sam_output_tokens = sam_output_tokens.clone()
        object_score_logits = object_score_logits.clone()

        if self.pred_obj_scores:
            is_obj_appearing = object_score_logits > 0

            # Mask used for spatial memories is always a *hard* choice between obj and no obj,
            # consistent with the actual mask prediction
            low_res_multimasks = torch.where(
                is_obj_appearing[:, None, None],
                low_res_multimasks,
                NO_OBJ_SCORE,
            )

        # convert masks from possibly bfloat16 (or float16) to float32
        # (older PyTorch versions before 2.1 don't support `interpolate` on bf16)
        low_res_multimasks = low_res_multimasks.float()
        high_res_multimasks = F.interpolate(
            low_res_multimasks,
            size=(self.image_size, self.image_size),
            mode="bilinear",
            align_corners=False,
        )

        sam_output_token = sam_output_tokens[:, 0]
        if multimask_output:
            # take the best mask prediction (with the highest IoU estimation)
            best_iou_inds = torch.argmax(ious, dim=-1)
            batch_inds = torch.arange(B, device=device)
            low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
            high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
            if sam_output_tokens.size(1) > 1:
                sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]
        else:
            low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks

        # Extract object pointer from the SAM output token (with occlusion handling)
        obj_ptr = self.obj_ptr_proj(sam_output_token)
        if self.pred_obj_scores:
            # Allow *soft* no obj ptr, unlike for masks
            if self.soft_no_obj_ptr:
                lambda_is_obj_appearing = object_score_logits.sigmoid()
            else:
                lambda_is_obj_appearing = is_obj_appearing.float()

            if self.fixed_no_obj_ptr:
                obj_ptr = lambda_is_obj_appearing * obj_ptr
            # Blend towards the learned no-object pointer as the appearance weight drops.
            obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr

        return (
            low_res_multimasks,
            high_res_multimasks,
            ious,
            low_res_masks,
            high_res_masks,
            obj_ptr,
            object_score_logits,
        )

    def _encode_new_memory(
        self,
        current_vision_feats,
        feat_sizes,
        pred_masks_high_res,
        object_score_logits,
        is_mask_from_pts,
    ):
        """
        Identical to the corresponding method in the parent (SAM2VideoPredictor), but
        cloning the memories and their pos enc to enable compilation.

        Returns:
            ``(maskmem_features, maskmem_pos_enc)`` from the memory encoder. When
            ``no_obj_embed_spatial`` is set, that embedding is added to the features
            of batch elements whose object score logit is not positive.
        """
        B = current_vision_feats[-1].size(1)  # batch size on this frame
        C = self.hidden_dim
        H, W = feat_sizes[-1]  # top-level (lowest-resolution) feature size
        # top-level feature, (HW)BC => BCHW
        pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
        if self.non_overlap_masks_for_mem_enc and not self.training:
            # optionally, apply non-overlapping constraints to the masks (it's applied
            # in the batch dimension and should only be used during eval, where all
            # the objects come from the same video under batch size 1).
            pred_masks_high_res = self._apply_non_overlapping_constraints(
                pred_masks_high_res
            )
        # Masks that came from point prompts are optionally binarized (eval only);
        # otherwise the raw logits go through a sigmoid.
        binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts
        if binarize and not self.training:
            mask_for_mem = (pred_masks_high_res > 0).float()
        else:
            # apply sigmoid on the raw mask logits to turn them into range (0, 1)
            mask_for_mem = torch.sigmoid(pred_masks_high_res)
        # apply scale and bias terms to the sigmoid probabilities
        if self.sigmoid_scale_for_mem_enc != 1.0:
            mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc
        if self.sigmoid_bias_for_mem_enc != 0.0:
            mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc
        maskmem_out = self.memory_encoder(
            pix_feat, mask_for_mem, skip_mask_sigmoid=True  # sigmoid already applied
        )
        # Clone the feats and pos_enc to enable compilation
        maskmem_features = maskmem_out["vision_features"].clone()
        maskmem_pos_enc = [m.clone() for m in maskmem_out["vision_pos_enc"]]
        # add a no-object embedding to the spatial memory to indicate that the frame
        # is predicted to be occluded (i.e. no object is appearing in the frame)
        if self.no_obj_embed_spatial is not None:
            is_obj_appearing = (object_score_logits > 0).float()
            maskmem_features += (
                1 - is_obj_appearing[..., None, None]
            ) * self.no_obj_embed_spatial[..., None, None].expand(
                *maskmem_features.shape
            )

        return maskmem_features, maskmem_pos_enc