nnInteractive 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. nnInteractive/__init__.py +3 -0
  2. nnInteractive/inference/__init__.py +0 -0
  3. nnInteractive/inference/cvpr2025_challenge_baseline/__init__.py +0 -0
  4. nnInteractive/inference/cvpr2025_challenge_baseline/predict.py +173 -0
  5. nnInteractive/inference/inference_session.py +1400 -0
  6. nnInteractive/interaction/__init__.py +0 -0
  7. nnInteractive/interaction/point.py +166 -0
  8. nnInteractive/supervoxel/setup.py +4 -0
  9. nnInteractive/supervoxel/src/metadata.py +118 -0
  10. nnInteractive/supervoxel/src/reader.py +175 -0
  11. nnInteractive/supervoxel/src/run.py +136 -0
  12. nnInteractive/supervoxel/src/sam2/__init__.py +2 -0
  13. nnInteractive/supervoxel/src/sam2/sam2/__init__.py +11 -0
  14. nnInteractive/supervoxel/src/sam2/sam2/automatic_mask_generator.py +434 -0
  15. nnInteractive/supervoxel/src/sam2/sam2/benchmark.py +86 -0
  16. nnInteractive/supervoxel/src/sam2/sam2/build_sam.py +172 -0
  17. nnInteractive/supervoxel/src/sam2/sam2/modeling/__init__.py +5 -0
  18. nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/__init__.py +5 -0
  19. nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/hieradet.py +305 -0
  20. nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/image_encoder.py +132 -0
  21. nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/utils.py +89 -0
  22. nnInteractive/supervoxel/src/sam2/sam2/modeling/memory_attention.py +167 -0
  23. nnInteractive/supervoxel/src/sam2/sam2/modeling/memory_encoder.py +179 -0
  24. nnInteractive/supervoxel/src/sam2/sam2/modeling/position_encoding.py +217 -0
  25. nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/__init__.py +5 -0
  26. nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/mask_decoder.py +274 -0
  27. nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/prompt_encoder.py +194 -0
  28. nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/transformer.py +293 -0
  29. nnInteractive/supervoxel/src/sam2/sam2/modeling/sam2_base.py +879 -0
  30. nnInteractive/supervoxel/src/sam2/sam2/modeling/sam2_utils.py +315 -0
  31. nnInteractive/supervoxel/src/sam2/sam2/sam2_image_predictor.py +433 -0
  32. nnInteractive/supervoxel/src/sam2/sam2/sam2_video_predictor.py +1171 -0
  33. nnInteractive/supervoxel/src/sam2/sam2/sam2_video_predictor_legacy.py +1125 -0
  34. nnInteractive/supervoxel/src/sam2/sam2/utils/__init__.py +5 -0
  35. nnInteractive/supervoxel/src/sam2/sam2/utils/amg.py +332 -0
  36. nnInteractive/supervoxel/src/sam2/sam2/utils/misc.py +488 -0
  37. nnInteractive/supervoxel/src/sam2/sam2/utils/transforms.py +108 -0
  38. nnInteractive/supervoxel/src/sam2/setup.py +174 -0
  39. nnInteractive/supervoxel/src/sam2/training/__init__.py +5 -0
  40. nnInteractive/supervoxel/src/sam2/training/dataset/__init__.py +5 -0
  41. nnInteractive/supervoxel/src/sam2/training/dataset/sam2_datasets.py +176 -0
  42. nnInteractive/supervoxel/src/sam2/training/dataset/transforms.py +481 -0
  43. nnInteractive/supervoxel/src/sam2/training/dataset/utils.py +102 -0
  44. nnInteractive/supervoxel/src/sam2/training/dataset/vos_dataset.py +154 -0
  45. nnInteractive/supervoxel/src/sam2/training/dataset/vos_raw_dataset.py +290 -0
  46. nnInteractive/supervoxel/src/sam2/training/dataset/vos_sampler.py +103 -0
  47. nnInteractive/supervoxel/src/sam2/training/dataset/vos_segment_loader.py +289 -0
  48. nnInteractive/supervoxel/src/sam2/training/loss_fns.py +290 -0
  49. nnInteractive/supervoxel/src/sam2/training/model/__init__.py +5 -0
  50. nnInteractive/supervoxel/src/sam2/training/model/sam2.py +515 -0
  51. nnInteractive/supervoxel/src/sam2/training/optimizer.py +462 -0
  52. nnInteractive/supervoxel/src/sam2/training/scripts/sav_frame_extraction_submitit.py +157 -0
  53. nnInteractive/supervoxel/src/sam2/training/train.py +232 -0
  54. nnInteractive/supervoxel/src/sam2/training/trainer.py +1051 -0
  55. nnInteractive/supervoxel/src/sam2/training/utils/__init__.py +5 -0
  56. nnInteractive/supervoxel/src/sam2/training/utils/checkpoint_utils.py +328 -0
  57. nnInteractive/supervoxel/src/sam2/training/utils/data_utils.py +166 -0
  58. nnInteractive/supervoxel/src/sam2/training/utils/distributed.py +560 -0
  59. nnInteractive/supervoxel/src/sam2/training/utils/logger.py +236 -0
  60. nnInteractive/supervoxel/src/sam2/training/utils/train_utils.py +275 -0
  61. nnInteractive/supervoxel/src/supervoxel.py +198 -0
  62. nnInteractive/trainer/__init__.py +0 -0
  63. nnInteractive/trainer/nnInteractiveTrainer.py +24 -0
  64. nnInteractive/utils/__init__.py +0 -0
  65. nnInteractive/utils/bboxes.py +217 -0
  66. nnInteractive/utils/checkpoint_cleansing.py +9 -0
  67. nnInteractive/utils/crop.py +268 -0
  68. nnInteractive/utils/erosion_dilation.py +48 -0
  69. nnInteractive/utils/inference_helpers.py +45 -0
  70. nnInteractive/utils/os_shennanigans.py +16 -0
  71. nnInteractive/utils/rounding.py +13 -0
  72. nninteractive-2.0.0.dist-info/METADATA +511 -0
  73. nninteractive-2.0.0.dist-info/RECORD +76 -0
  74. nninteractive-2.0.0.dist-info/WHEEL +5 -0
  75. nninteractive-2.0.0.dist-info/licenses/LICENSE +201 -0
  76. nninteractive-2.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,515 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import logging
8
+
9
+ import numpy as np
10
+ import torch
11
+ import torch.distributed
12
+ from sam2.modeling.sam2_base import SAM2Base
13
+ from sam2.modeling.sam2_utils import (
14
+ get_1d_sine_pe,
15
+ get_next_point,
16
+ sample_box_points,
17
+ select_closest_cond_frames,
18
+ )
19
+
20
+ from sam2.utils.misc import concat_points
21
+
22
+ from training.utils.data_utils import BatchedVideoDatapoint
23
+
24
+
25
class SAM2Train(SAM2Base):
    """SAM2 model used for training / evaluation on batched video datapoints.

    On top of ``SAM2Base`` this class:
      * chooses, per datapoint, whether initial conditioning frames are prompted with
        ground-truth masks or with points/boxes sampled from those masks,
      * tracks every frame of the clip (conditioning frames first), and
      * optionally simulates interactive correction clicks sampled from the error
        between the prediction and the ground-truth mask.

    All random decisions go through ``self.rng`` (a NumPy generator created with a fixed
    seed in ``__init__``), so the ORDER of ``self.rng`` calls in this class is significant:
    reordering them changes which frames/points get sampled.
    """

    def __init__(
        self,
        image_encoder,
        memory_attention=None,
        memory_encoder=None,
        prob_to_use_pt_input_for_train=0.0,
        prob_to_use_pt_input_for_eval=0.0,
        prob_to_use_box_input_for_train=0.0,
        prob_to_use_box_input_for_eval=0.0,
        # if it is greater than 1, we do interactive point sampling in the 1st frame and other randomly selected frames
        num_frames_to_correct_for_train=1,  # default: only iteratively sample on first frame
        num_frames_to_correct_for_eval=1,  # default: only iteratively sample on first frame
        rand_frames_to_correct_for_train=False,
        rand_frames_to_correct_for_eval=False,
        # how many frames to use as initial conditioning frames (for both point input and mask input; the first frame is always used as an initial conditioning frame)
        # - if `rand_init_cond_frames` below is True, we randomly sample 1~num_init_cond_frames initial conditioning frames
        # - otherwise we sample a fixed number of num_init_cond_frames initial conditioning frames
        # note: for point input, we sample correction points on all such initial conditioning frames, and we require that `num_frames_to_correct` >= `num_init_cond_frames`;
        # these are initial conditioning frames because as we track the video, more conditioning frames might be added
        # when a frame receives correction clicks under point input if `add_all_frames_to_correct_as_cond=True`
        num_init_cond_frames_for_train=1,  # default: only use the first frame as initial conditioning frame
        num_init_cond_frames_for_eval=1,  # default: only use the first frame as initial conditioning frame
        rand_init_cond_frames_for_train=True,  # default: random 1~num_init_cond_frames_for_train cond frames (to be consistent w/ previous TA data loader)
        rand_init_cond_frames_for_eval=False,
        # if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click
        # if `add_all_frames_to_correct_as_cond` is False, we restrict the conditioning frame list to only those initial conditioning frames
        add_all_frames_to_correct_as_cond=False,
        # how many additional correction points to sample (on each frame selected to be corrected)
        # note that the first frame receives an initial input click (in addition to any correction clicks)
        num_correction_pt_per_frame=7,
        # method for point sampling during evaluation
        # "uniform" (sample uniformly from error region) or "center" (use the point with the largest distance to error region boundary)
        # default to "center" to be consistent with evaluation in the SAM paper
        pt_sampling_for_eval="center",
        # During training, we optionally allow sampling the correction points from GT regions
        # instead of the prediction error regions with a small probability. This might allow the
        # model to overfit less to the error regions in training datasets
        prob_to_sample_from_gt_for_train=0.0,
        use_act_ckpt_iterative_pt_sampling=False,
        # whether to forward image features per frame (as it's being tracked) during evaluation, instead of forwarding image features
        # of all frames at once. This avoids backbone OOM errors on very long videos in evaluation, but could be slightly slower.
        forward_backbone_per_frame_for_eval=False,
        freeze_image_encoder=False,
        **kwargs,
    ):
        """Store the prompt-sampling configuration; remaining ``kwargs`` go to ``SAM2Base``.

        If ``freeze_image_encoder`` is True, gradients are disabled for every parameter
        of ``self.image_encoder``.
        """
        super().__init__(image_encoder, memory_attention, memory_encoder, **kwargs)
        self.use_act_ckpt_iterative_pt_sampling = use_act_ckpt_iterative_pt_sampling
        self.forward_backbone_per_frame_for_eval = forward_backbone_per_frame_for_eval

        # Point sampler and conditioning frames
        self.prob_to_use_pt_input_for_train = prob_to_use_pt_input_for_train
        self.prob_to_use_box_input_for_train = prob_to_use_box_input_for_train
        self.prob_to_use_pt_input_for_eval = prob_to_use_pt_input_for_eval
        self.prob_to_use_box_input_for_eval = prob_to_use_box_input_for_eval
        if prob_to_use_pt_input_for_train > 0 or prob_to_use_pt_input_for_eval > 0:
            # Report both probabilities: this branch is also entered when only the
            # eval probability is non-zero.
            logging.info(
                "Using points (sampled from masks) as inputs with p=%s (train), p=%s (eval)",
                prob_to_use_pt_input_for_train,
                prob_to_use_pt_input_for_eval,
            )
            assert num_frames_to_correct_for_train >= num_init_cond_frames_for_train
            assert num_frames_to_correct_for_eval >= num_init_cond_frames_for_eval

        self.num_frames_to_correct_for_train = num_frames_to_correct_for_train
        self.num_frames_to_correct_for_eval = num_frames_to_correct_for_eval
        self.rand_frames_to_correct_for_train = rand_frames_to_correct_for_train
        self.rand_frames_to_correct_for_eval = rand_frames_to_correct_for_eval
        # Initial multi-conditioning frames
        self.num_init_cond_frames_for_train = num_init_cond_frames_for_train
        self.num_init_cond_frames_for_eval = num_init_cond_frames_for_eval
        self.rand_init_cond_frames_for_train = rand_init_cond_frames_for_train
        self.rand_init_cond_frames_for_eval = rand_init_cond_frames_for_eval
        self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond
        self.num_correction_pt_per_frame = num_correction_pt_per_frame
        self.pt_sampling_for_eval = pt_sampling_for_eval
        self.prob_to_sample_from_gt_for_train = prob_to_sample_from_gt_for_train
        # A random number generator with a fixed initial seed across GPUs
        self.rng = np.random.default_rng(seed=42)

        if freeze_image_encoder:
            for p in self.image_encoder.parameters():
                p.requires_grad = False

    def forward(self, input: BatchedVideoDatapoint):
        """Run prompt preparation and tracking on a batched video datapoint.

        Returns the per-frame output list produced by ``forward_tracking``.
        """
        if self.training or not self.forward_backbone_per_frame_for_eval:
            # precompute image features on all frames before tracking
            backbone_out = self.forward_image(input.flat_img_batch)
        else:
            # defer image feature computation on a frame until it's being tracked
            backbone_out = {"backbone_fpn": None, "vision_pos_enc": None}
        backbone_out = self.prepare_prompt_inputs(backbone_out, input)
        previous_stages_out = self.forward_tracking(backbone_out, input)

        return previous_stages_out

    def _prepare_backbone_features_per_frame(self, img_batch, img_ids):
        """Compute the image backbone features on the fly for the given img_ids.

        Returns ``(image, vision_feats, vision_pos_embeds, feat_sizes)``, with the
        features re-expanded (along dim 1) so they line up with the original ``img_ids``.
        """
        # Only forward backbone on unique image ids to avoid repetitive computation
        # (if `img_ids` has only one element, it's already unique so we skip this step).
        if img_ids.numel() > 1:
            unique_img_ids, inv_ids = torch.unique(img_ids, return_inverse=True)
        else:
            unique_img_ids, inv_ids = img_ids, None

        # Compute the image features on those unique image ids
        image = img_batch[unique_img_ids]
        backbone_out = self.forward_image(image)
        (
            _,
            vision_feats,
            vision_pos_embeds,
            feat_sizes,
        ) = self._prepare_backbone_features(backbone_out)
        # Inverse-map image features for `unique_img_ids` to the final image features
        # for the original input `img_ids`.
        if inv_ids is not None:
            image = image[inv_ids]
            vision_feats = [x[:, inv_ids] for x in vision_feats]
            vision_pos_embeds = [x[:, inv_ids] for x in vision_pos_embeds]

        return image, vision_feats, vision_pos_embeds, feat_sizes

    def prepare_prompt_inputs(self, backbone_out, input, start_frame_idx=0):
        """
        Prepare input mask, point or box prompts. Optionally, we allow tracking from
        a custom `start_frame_idx` to the end of the video (for evaluation purposes).

        Adds these keys to ``backbone_out`` and returns it:
        ``gt_masks_per_frame``, ``num_frames``, ``use_pt_input``, ``init_cond_frames``,
        ``frames_not_in_init_cond``, ``mask_inputs_per_frame``, ``point_inputs_per_frame``,
        ``frames_to_add_correction_pt``.
        """
        # Load the ground-truth masks on all frames (so that we can later
        # sample correction points from them)
        gt_masks_per_frame = {
            stage_id: masks.unsqueeze(1)  # [B, 1, H_im, W_im]
            for stage_id, masks in enumerate(input.masks)
        }
        backbone_out["gt_masks_per_frame"] = gt_masks_per_frame
        num_frames = input.num_frames
        backbone_out["num_frames"] = num_frames

        # Randomly decide whether to use point inputs or mask inputs
        if self.training:
            prob_to_use_pt_input = self.prob_to_use_pt_input_for_train
            prob_to_use_box_input = self.prob_to_use_box_input_for_train
            num_frames_to_correct = self.num_frames_to_correct_for_train
            rand_frames_to_correct = self.rand_frames_to_correct_for_train
            num_init_cond_frames = self.num_init_cond_frames_for_train
            rand_init_cond_frames = self.rand_init_cond_frames_for_train
        else:
            prob_to_use_pt_input = self.prob_to_use_pt_input_for_eval
            prob_to_use_box_input = self.prob_to_use_box_input_for_eval
            num_frames_to_correct = self.num_frames_to_correct_for_eval
            rand_frames_to_correct = self.rand_frames_to_correct_for_eval
            num_init_cond_frames = self.num_init_cond_frames_for_eval
            rand_init_cond_frames = self.rand_init_cond_frames_for_eval
        if num_frames == 1:
            # here we handle a special case for mixing video + SAM on image training,
            # where we force using point input for the SAM task on static images
            prob_to_use_pt_input = 1.0
            num_frames_to_correct = 1
            num_init_cond_frames = 1
        assert num_init_cond_frames >= 1
        # (here `self.rng.random()` returns value in range 0.0 <= X < 1.0)
        use_pt_input = self.rng.random() < prob_to_use_pt_input
        if rand_init_cond_frames and num_init_cond_frames > 1:
            # randomly select 1 to `num_init_cond_frames` frames as initial conditioning frames
            num_init_cond_frames = self.rng.integers(1, num_init_cond_frames, endpoint=True)
        if use_pt_input and rand_frames_to_correct and num_frames_to_correct > num_init_cond_frames:
            # randomly select `num_init_cond_frames` to `num_frames_to_correct` frames to sample
            # correction clicks (only for the case of point input)
            num_frames_to_correct = self.rng.integers(num_init_cond_frames, num_frames_to_correct, endpoint=True)
        backbone_out["use_pt_input"] = use_pt_input

        # Sample initial conditioning frames
        if num_init_cond_frames == 1:
            init_cond_frames = [start_frame_idx]  # starting frame
        else:
            # starting frame + randomly selected remaining frames (without replacement)
            init_cond_frames = [start_frame_idx] + self.rng.choice(
                range(start_frame_idx + 1, num_frames),
                num_init_cond_frames - 1,
                replace=False,
            ).tolist()
        backbone_out["init_cond_frames"] = init_cond_frames
        backbone_out["frames_not_in_init_cond"] = [
            t for t in range(start_frame_idx, num_frames) if t not in init_cond_frames
        ]
        # Prepare mask or point inputs on initial conditioning frames
        backbone_out["mask_inputs_per_frame"] = {}  # {frame_idx: <input_masks>}
        backbone_out["point_inputs_per_frame"] = {}  # {frame_idx: <input_points>}
        for t in init_cond_frames:
            if not use_pt_input:
                backbone_out["mask_inputs_per_frame"][t] = gt_masks_per_frame[t]
            else:
                # P(box) = prob_to_use_pt_input * prob_to_use_box_input
                use_box_input = self.rng.random() < prob_to_use_box_input
                if use_box_input:
                    points, labels = sample_box_points(
                        gt_masks_per_frame[t],
                    )
                else:
                    # (here we only sample **one initial point** on initial conditioning frames from the
                    # ground-truth mask; we may sample more correction points on the fly)
                    points, labels = get_next_point(
                        gt_masks=gt_masks_per_frame[t],
                        pred_masks=None,
                        method=("uniform" if self.training else self.pt_sampling_for_eval),
                    )

                point_inputs = {"point_coords": points, "point_labels": labels}
                backbone_out["point_inputs_per_frame"][t] = point_inputs

        # Sample frames where we will add correction clicks on the fly
        # based on the error between prediction and ground-truth masks
        if not use_pt_input:
            # no correction points will be sampled when using mask inputs
            frames_to_add_correction_pt = []
        elif num_frames_to_correct == num_init_cond_frames:
            frames_to_add_correction_pt = init_cond_frames
        else:
            assert num_frames_to_correct > num_init_cond_frames
            # initial cond frame + randomly selected remaining frames (without replacement)
            extra_num = num_frames_to_correct - num_init_cond_frames
            frames_to_add_correction_pt = (
                init_cond_frames
                + self.rng.choice(backbone_out["frames_not_in_init_cond"], extra_num, replace=False).tolist()
            )
        backbone_out["frames_to_add_correction_pt"] = frames_to_add_correction_pt

        return backbone_out

    def forward_tracking(self, backbone_out, input: BatchedVideoDatapoint, return_dict=False):
        """Forward video tracking on each frame (and sample correction clicks).

        Initial conditioning frames are processed first, then the remaining frames.
        If ``return_dict`` is True, returns ``{"cond_frame_outputs": {...},
        "non_cond_frame_outputs": {...}}`` keyed by frame index. Otherwise returns a list
        of per-frame output dicts for frames ``0 .. num_frames - 1`` with the
        ``"obj_ptr"`` entry removed.
        """
        img_feats_already_computed = backbone_out["backbone_fpn"] is not None
        if img_feats_already_computed:
            # Prepare the backbone features
            # - vision_feats and vision_pos_embeds are in (HW)BC format
            (
                _,
                vision_feats,
                vision_pos_embeds,
                feat_sizes,
            ) = self._prepare_backbone_features(backbone_out)

        # Starting the stage loop
        num_frames = backbone_out["num_frames"]
        init_cond_frames = backbone_out["init_cond_frames"]
        frames_to_add_correction_pt = backbone_out["frames_to_add_correction_pt"]
        # first process all the initial conditioning frames to encode them as memory,
        # and then conditioning on them to track the remaining frames
        processing_order = init_cond_frames + backbone_out["frames_not_in_init_cond"]
        output_dict = {
            "cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
            "non_cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
        }
        for stage_id in processing_order:
            # Get the image features for the current frames
            img_ids = input.flat_obj_to_img_idx[stage_id]
            if img_feats_already_computed:
                # Retrieve image features according to img_ids (if they are already computed).
                current_vision_feats = [x[:, img_ids] for x in vision_feats]
                current_vision_pos_embeds = [x[:, img_ids] for x in vision_pos_embeds]
            else:
                # Otherwise, compute the image features on the fly for the given img_ids
                # (this might be used for evaluation on long videos to avoid backbone OOM).
                (
                    _,
                    current_vision_feats,
                    current_vision_pos_embeds,
                    feat_sizes,
                ) = self._prepare_backbone_features_per_frame(input.flat_img_batch, img_ids)

            # Get output masks based on this frame's prompts and previous memory
            current_out = self.track_step(
                frame_idx=stage_id,
                is_init_cond_frame=stage_id in init_cond_frames,
                current_vision_feats=current_vision_feats,
                current_vision_pos_embeds=current_vision_pos_embeds,
                feat_sizes=feat_sizes,
                point_inputs=backbone_out["point_inputs_per_frame"].get(stage_id, None),
                mask_inputs=backbone_out["mask_inputs_per_frame"].get(stage_id, None),
                gt_masks=backbone_out["gt_masks_per_frame"].get(stage_id, None),
                frames_to_add_correction_pt=frames_to_add_correction_pt,
                output_dict=output_dict,
                num_frames=num_frames,
            )
            # Append the output, depending on whether it's a conditioning frame
            add_output_as_cond_frame = stage_id in init_cond_frames or (
                self.add_all_frames_to_correct_as_cond and stage_id in frames_to_add_correction_pt
            )
            if add_output_as_cond_frame:
                output_dict["cond_frame_outputs"][stage_id] = current_out
            else:
                output_dict["non_cond_frame_outputs"][stage_id] = current_out

        if return_dict:
            return output_dict
        # turn `output_dict` into a list for loss function
        # NOTE(review): this expects an output for every frame from 0, i.e. it assumes
        # tracking started at frame 0 (as `forward` does).
        all_frame_outputs = {}
        all_frame_outputs.update(output_dict["cond_frame_outputs"])
        all_frame_outputs.update(output_dict["non_cond_frame_outputs"])
        all_frame_outputs = [all_frame_outputs[t] for t in range(num_frames)]
        # Make DDP happy with activation checkpointing by removing unused keys
        all_frame_outputs = [{k: v for k, v in d.items() if k != "obj_ptr"} for d in all_frame_outputs]

        return all_frame_outputs

    def track_step(
        self,
        frame_idx,
        is_init_cond_frame,
        current_vision_feats,
        current_vision_pos_embeds,
        feat_sizes,
        point_inputs,
        mask_inputs,
        output_dict,
        num_frames,
        track_in_reverse=False,  # tracking in reverse time order (for demo usage)
        run_mem_encoder=True,  # Whether to run the memory encoder on the predicted masks.
        prev_sam_mask_logits=None,  # The previously predicted SAM mask logits.
        frames_to_add_correction_pt=None,
        gt_masks=None,
    ):
        """Track one frame and optionally refine it with simulated correction clicks.

        Runs ``_track_step``, records the first prediction under ``multistep_*`` keys,
        and, if ``frame_idx`` is in ``frames_to_add_correction_pt`` and
        ``self.num_correction_pt_per_frame > 0``, adds correction clicks via
        ``_iter_correct_pt_sampling``. The final prediction is stored under
        ``pred_masks`` / ``pred_masks_high_res`` / ``obj_ptr``, then
        ``_encode_memory_in_output`` is called with ``current_out``. Returns ``current_out``.
        """
        if frames_to_add_correction_pt is None:
            frames_to_add_correction_pt = []
        current_out, sam_outputs, high_res_features, pix_feat = self._track_step(
            frame_idx,
            is_init_cond_frame,
            current_vision_feats,
            current_vision_pos_embeds,
            feat_sizes,
            point_inputs,
            mask_inputs,
            output_dict,
            num_frames,
            track_in_reverse,
            prev_sam_mask_logits,
        )

        (
            low_res_multimasks,
            high_res_multimasks,
            ious,
            low_res_masks,
            high_res_masks,
            obj_ptr,
            object_score_logits,
        ) = sam_outputs

        current_out["multistep_pred_masks"] = low_res_masks
        current_out["multistep_pred_masks_high_res"] = high_res_masks
        current_out["multistep_pred_multimasks"] = [low_res_multimasks]
        current_out["multistep_pred_multimasks_high_res"] = [high_res_multimasks]
        current_out["multistep_pred_ious"] = [ious]
        current_out["multistep_point_inputs"] = [point_inputs]
        current_out["multistep_object_score_logits"] = [object_score_logits]

        # Optionally, sample correction points iteratively to correct the mask.
        # `_iter_correct_pt_sampling` only produces its outputs inside its correction
        # loop, so with zero correction points there is nothing to refine: skip it and
        # keep the initial prediction (calling it would fail on an unbound result).
        if frame_idx in frames_to_add_correction_pt and self.num_correction_pt_per_frame > 0:
            point_inputs, final_sam_outputs = self._iter_correct_pt_sampling(
                is_init_cond_frame,
                point_inputs,
                gt_masks,
                high_res_features,
                pix_feat,
                low_res_multimasks,
                high_res_multimasks,
                ious,
                low_res_masks,
                high_res_masks,
                object_score_logits,
                current_out,
            )
            (
                _,
                _,
                _,
                low_res_masks,
                high_res_masks,
                obj_ptr,
                object_score_logits,
            ) = final_sam_outputs

        # Use the final prediction (after all correction steps for output and eval)
        current_out["pred_masks"] = low_res_masks
        current_out["pred_masks_high_res"] = high_res_masks
        current_out["obj_ptr"] = obj_ptr

        # Finally run the memory encoder on the predicted mask to encode
        # it into a new memory feature (that can be used in future frames)
        self._encode_memory_in_output(
            current_vision_feats,
            feat_sizes,
            point_inputs,
            run_mem_encoder,
            high_res_masks,
            object_score_logits,
            current_out,
        )
        return current_out

    def _iter_correct_pt_sampling(
        self,
        is_init_cond_frame,
        point_inputs,
        gt_masks,
        high_res_features,
        pix_feat_with_mem,
        low_res_multimasks,
        high_res_multimasks,
        ious,
        low_res_masks,
        high_res_masks,
        object_score_logits,
        current_out,
    ):
        """Add ``self.num_correction_pt_per_frame`` correction clicks one at a time.

        After each new click (sampled from the prediction/GT error, or occasionally from
        the GT mask alone during training), the SAM heads are re-run with the accumulated
        points and the previous low-res mask logits. The per-step histories overwrite the
        ``multistep_*`` entries of ``current_out``.

        Returns ``(point_inputs, sam_outputs)`` from the last correction step.
        NOTE: requires ``self.num_correction_pt_per_frame >= 1``; otherwise ``sam_outputs``
        is never bound.
        """

        assert gt_masks is not None
        all_pred_masks = [low_res_masks]
        all_pred_high_res_masks = [high_res_masks]
        all_pred_multimasks = [low_res_multimasks]
        all_pred_high_res_multimasks = [high_res_multimasks]
        all_pred_ious = [ious]
        all_point_inputs = [point_inputs]
        all_object_score_logits = [object_score_logits]
        for _ in range(self.num_correction_pt_per_frame):
            # sample a new point from the error between prediction and ground-truth
            # (with a small probability, directly sample from GT masks instead of errors)
            if self.training and self.prob_to_sample_from_gt_for_train > 0:
                sample_from_gt = self.rng.random() < self.prob_to_sample_from_gt_for_train
            else:
                sample_from_gt = False
            # if `pred_for_new_pt` is None, only GT masks will be used for point sampling
            pred_for_new_pt = None if sample_from_gt else (high_res_masks > 0)
            new_points, new_labels = get_next_point(
                gt_masks=gt_masks,
                pred_masks=pred_for_new_pt,
                method="uniform" if self.training else self.pt_sampling_for_eval,
            )
            point_inputs = concat_points(point_inputs, new_points, new_labels)
            # Feed the mask logits of the previous SAM outputs in the next SAM decoder step.
            # For tracking, this means that when the user adds a correction click, we also feed
            # the tracking output mask logits along with the click as input to the SAM decoder.
            mask_inputs = low_res_masks
            multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
            if self.use_act_ckpt_iterative_pt_sampling and not multimask_output:
                sam_outputs = torch.utils.checkpoint.checkpoint(
                    self._forward_sam_heads,
                    backbone_features=pix_feat_with_mem,
                    point_inputs=point_inputs,
                    mask_inputs=mask_inputs,
                    high_res_features=high_res_features,
                    multimask_output=multimask_output,
                    use_reentrant=False,
                )
            else:
                sam_outputs = self._forward_sam_heads(
                    backbone_features=pix_feat_with_mem,
                    point_inputs=point_inputs,
                    mask_inputs=mask_inputs,
                    high_res_features=high_res_features,
                    multimask_output=multimask_output,
                )
            (
                low_res_multimasks,
                high_res_multimasks,
                ious,
                low_res_masks,
                high_res_masks,
                _,
                object_score_logits,
            ) = sam_outputs
            all_pred_masks.append(low_res_masks)
            all_pred_high_res_masks.append(high_res_masks)
            all_pred_multimasks.append(low_res_multimasks)
            all_pred_high_res_multimasks.append(high_res_multimasks)
            all_pred_ious.append(ious)
            all_point_inputs.append(point_inputs)
            all_object_score_logits.append(object_score_logits)

        # Concatenate the masks along the channel dim so that losses can be computed
        # on every correction step.
        current_out["multistep_pred_masks"] = torch.cat(all_pred_masks, dim=1)
        current_out["multistep_pred_masks_high_res"] = torch.cat(all_pred_high_res_masks, dim=1)
        current_out["multistep_pred_multimasks"] = all_pred_multimasks
        current_out["multistep_pred_multimasks_high_res"] = all_pred_high_res_multimasks
        current_out["multistep_pred_ious"] = all_pred_ious
        current_out["multistep_point_inputs"] = all_point_inputs
        current_out["multistep_object_score_logits"] = all_object_score_logits

        return point_inputs, sam_outputs