nnInteractive 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nnInteractive/__init__.py +3 -0
- nnInteractive/inference/__init__.py +0 -0
- nnInteractive/inference/cvpr2025_challenge_baseline/__init__.py +0 -0
- nnInteractive/inference/cvpr2025_challenge_baseline/predict.py +173 -0
- nnInteractive/inference/inference_session.py +1400 -0
- nnInteractive/interaction/__init__.py +0 -0
- nnInteractive/interaction/point.py +166 -0
- nnInteractive/supervoxel/setup.py +4 -0
- nnInteractive/supervoxel/src/metadata.py +118 -0
- nnInteractive/supervoxel/src/reader.py +175 -0
- nnInteractive/supervoxel/src/run.py +136 -0
- nnInteractive/supervoxel/src/sam2/__init__.py +2 -0
- nnInteractive/supervoxel/src/sam2/sam2/__init__.py +11 -0
- nnInteractive/supervoxel/src/sam2/sam2/automatic_mask_generator.py +434 -0
- nnInteractive/supervoxel/src/sam2/sam2/benchmark.py +86 -0
- nnInteractive/supervoxel/src/sam2/sam2/build_sam.py +172 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/__init__.py +5 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/__init__.py +5 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/hieradet.py +305 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/image_encoder.py +132 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/utils.py +89 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/memory_attention.py +167 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/memory_encoder.py +179 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/position_encoding.py +217 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/__init__.py +5 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/mask_decoder.py +274 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/prompt_encoder.py +194 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/transformer.py +293 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/sam2_base.py +879 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/sam2_utils.py +315 -0
- nnInteractive/supervoxel/src/sam2/sam2/sam2_image_predictor.py +433 -0
- nnInteractive/supervoxel/src/sam2/sam2/sam2_video_predictor.py +1171 -0
- nnInteractive/supervoxel/src/sam2/sam2/sam2_video_predictor_legacy.py +1125 -0
- nnInteractive/supervoxel/src/sam2/sam2/utils/__init__.py +5 -0
- nnInteractive/supervoxel/src/sam2/sam2/utils/amg.py +332 -0
- nnInteractive/supervoxel/src/sam2/sam2/utils/misc.py +488 -0
- nnInteractive/supervoxel/src/sam2/sam2/utils/transforms.py +108 -0
- nnInteractive/supervoxel/src/sam2/setup.py +174 -0
- nnInteractive/supervoxel/src/sam2/training/__init__.py +5 -0
- nnInteractive/supervoxel/src/sam2/training/dataset/__init__.py +5 -0
- nnInteractive/supervoxel/src/sam2/training/dataset/sam2_datasets.py +176 -0
- nnInteractive/supervoxel/src/sam2/training/dataset/transforms.py +481 -0
- nnInteractive/supervoxel/src/sam2/training/dataset/utils.py +102 -0
- nnInteractive/supervoxel/src/sam2/training/dataset/vos_dataset.py +154 -0
- nnInteractive/supervoxel/src/sam2/training/dataset/vos_raw_dataset.py +290 -0
- nnInteractive/supervoxel/src/sam2/training/dataset/vos_sampler.py +103 -0
- nnInteractive/supervoxel/src/sam2/training/dataset/vos_segment_loader.py +289 -0
- nnInteractive/supervoxel/src/sam2/training/loss_fns.py +290 -0
- nnInteractive/supervoxel/src/sam2/training/model/__init__.py +5 -0
- nnInteractive/supervoxel/src/sam2/training/model/sam2.py +515 -0
- nnInteractive/supervoxel/src/sam2/training/optimizer.py +462 -0
- nnInteractive/supervoxel/src/sam2/training/scripts/sav_frame_extraction_submitit.py +157 -0
- nnInteractive/supervoxel/src/sam2/training/train.py +232 -0
- nnInteractive/supervoxel/src/sam2/training/trainer.py +1051 -0
- nnInteractive/supervoxel/src/sam2/training/utils/__init__.py +5 -0
- nnInteractive/supervoxel/src/sam2/training/utils/checkpoint_utils.py +328 -0
- nnInteractive/supervoxel/src/sam2/training/utils/data_utils.py +166 -0
- nnInteractive/supervoxel/src/sam2/training/utils/distributed.py +560 -0
- nnInteractive/supervoxel/src/sam2/training/utils/logger.py +236 -0
- nnInteractive/supervoxel/src/sam2/training/utils/train_utils.py +275 -0
- nnInteractive/supervoxel/src/supervoxel.py +198 -0
- nnInteractive/trainer/__init__.py +0 -0
- nnInteractive/trainer/nnInteractiveTrainer.py +24 -0
- nnInteractive/utils/__init__.py +0 -0
- nnInteractive/utils/bboxes.py +217 -0
- nnInteractive/utils/checkpoint_cleansing.py +9 -0
- nnInteractive/utils/crop.py +268 -0
- nnInteractive/utils/erosion_dilation.py +48 -0
- nnInteractive/utils/inference_helpers.py +45 -0
- nnInteractive/utils/os_shennanigans.py +16 -0
- nnInteractive/utils/rounding.py +13 -0
- nninteractive-2.0.0.dist-info/METADATA +511 -0
- nninteractive-2.0.0.dist-info/RECORD +76 -0
- nninteractive-2.0.0.dist-info/WHEEL +5 -0
- nninteractive-2.0.0.dist-info/licenses/LICENSE +201 -0
- nninteractive-2.0.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,515 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
|
|
4
|
+
# This source code is licensed under the license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
import logging
|
|
8
|
+
|
|
9
|
+
import numpy as np
|
|
10
|
+
import torch
|
|
11
|
+
import torch.distributed
|
|
12
|
+
from sam2.modeling.sam2_base import SAM2Base
|
|
13
|
+
from sam2.modeling.sam2_utils import (
|
|
14
|
+
get_1d_sine_pe,
|
|
15
|
+
get_next_point,
|
|
16
|
+
sample_box_points,
|
|
17
|
+
select_closest_cond_frames,
|
|
18
|
+
)
|
|
19
|
+
|
|
20
|
+
from sam2.utils.misc import concat_points
|
|
21
|
+
|
|
22
|
+
from training.utils.data_utils import BatchedVideoDatapoint
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class SAM2Train(SAM2Base):
|
|
26
|
+
def __init__(
    self,
    image_encoder,
    memory_attention=None,
    memory_encoder=None,
    prob_to_use_pt_input_for_train=0.0,
    prob_to_use_pt_input_for_eval=0.0,
    prob_to_use_box_input_for_train=0.0,
    prob_to_use_box_input_for_eval=0.0,
    # if greater than 1, we do interactive point sampling on the 1st frame and
    # on other randomly selected frames
    num_frames_to_correct_for_train=1,  # default: only iteratively sample on first frame
    num_frames_to_correct_for_eval=1,  # default: only iteratively sample on first frame
    rand_frames_to_correct_for_train=False,
    rand_frames_to_correct_for_eval=False,
    # Number of frames used as initial conditioning frames (for both point and
    # mask input; the first frame is always an initial conditioning frame):
    # - if `rand_init_cond_frames` below is True, 1~num_init_cond_frames initial
    #   conditioning frames are sampled at random
    # - otherwise exactly num_init_cond_frames of them are sampled
    # For point input, correction points are sampled on all initial conditioning
    # frames, which is why `num_frames_to_correct` >= `num_init_cond_frames` is
    # required. They are "initial" because more conditioning frames may be added
    # while tracking, when a frame receives correction clicks under point input
    # and `add_all_frames_to_correct_as_cond=True`.
    num_init_cond_frames_for_train=1,  # default: only use the first frame as initial conditioning frame
    num_init_cond_frames_for_eval=1,  # default: only use the first frame as initial conditioning frame
    rand_init_cond_frames_for_train=True,  # default: random 1~num_init_cond_frames_for_train cond frames (to be consistent w/ previous TA data loader)
    rand_init_cond_frames_for_eval=False,
    # True: any frame that receives a later correction click is also appended to
    #       the conditioning frame list
    # False: the conditioning frame list is restricted to the initial
    #        conditioning frames
    add_all_frames_to_correct_as_cond=False,
    # Number of additional correction points sampled on each frame selected for
    # correction (the first frame also receives an initial input click on top of
    # any correction clicks)
    num_correction_pt_per_frame=7,
    # Point sampling method during evaluation: "uniform" (sample uniformly from
    # the error region) or "center" (the point farthest from the error region
    # boundary); "center" is the default for consistency with the SAM paper
    pt_sampling_for_eval="center",
    # During training, correction points may be sampled from GT regions instead
    # of prediction error regions with this (small) probability, which might let
    # the model overfit less to the error regions of the training datasets
    prob_to_sample_from_gt_for_train=0.0,
    use_act_ckpt_iterative_pt_sampling=False,
    # Forward image features per frame (as it is being tracked) during evaluation
    # instead of for all frames at once. This avoids backbone OOM errors on very
    # long videos in evaluation, but could be slightly slower.
    forward_backbone_per_frame_for_eval=False,
    freeze_image_encoder=False,
    **kwargs,
):
    """Store the prompt-sampling configuration on top of the base model.

    `image_encoder`, `memory_attention`, `memory_encoder` and `**kwargs` are
    forwarded unchanged to the parent constructor. Every other argument is
    stored on the instance under its own name; see the comments in the
    signature for what each one controls.

    Raises:
        AssertionError: if point input is enabled (train or eval probability
            greater than zero) and `num_frames_to_correct_for_*` is smaller
            than the matching `num_init_cond_frames_for_*`.
    """
    super().__init__(image_encoder, memory_attention, memory_encoder, **kwargs)

    # Probabilities of using point / box prompts instead of mask prompts
    self.prob_to_use_pt_input_for_train = prob_to_use_pt_input_for_train
    self.prob_to_use_pt_input_for_eval = prob_to_use_pt_input_for_eval
    self.prob_to_use_box_input_for_train = prob_to_use_box_input_for_train
    self.prob_to_use_box_input_for_eval = prob_to_use_box_input_for_eval

    point_input_enabled = any(
        prob > 0 for prob in (prob_to_use_pt_input_for_train, prob_to_use_pt_input_for_eval)
    )
    if point_input_enabled:
        logging.info(f"Training with points (sampled from masks) as inputs with p={prob_to_use_pt_input_for_train}")
        # correction clicks are sampled on every initial conditioning frame
        assert num_frames_to_correct_for_train >= num_init_cond_frames_for_train
        assert num_frames_to_correct_for_eval >= num_init_cond_frames_for_eval

    # Frames that receive correction clicks
    self.num_frames_to_correct_for_train = num_frames_to_correct_for_train
    self.num_frames_to_correct_for_eval = num_frames_to_correct_for_eval
    self.rand_frames_to_correct_for_train = rand_frames_to_correct_for_train
    self.rand_frames_to_correct_for_eval = rand_frames_to_correct_for_eval

    # Initial conditioning frames
    self.num_init_cond_frames_for_train = num_init_cond_frames_for_train
    self.num_init_cond_frames_for_eval = num_init_cond_frames_for_eval
    self.rand_init_cond_frames_for_train = rand_init_cond_frames_for_train
    self.rand_init_cond_frames_for_eval = rand_init_cond_frames_for_eval
    self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond

    # Correction-point sampling
    self.num_correction_pt_per_frame = num_correction_pt_per_frame
    self.pt_sampling_for_eval = pt_sampling_for_eval
    self.prob_to_sample_from_gt_for_train = prob_to_sample_from_gt_for_train

    # Execution options
    self.use_act_ckpt_iterative_pt_sampling = use_act_ckpt_iterative_pt_sampling
    self.forward_backbone_per_frame_for_eval = forward_backbone_per_frame_for_eval

    # Random number generator with a fixed initial seed (the same on every GPU)
    self.rng = np.random.default_rng(seed=42)

    if freeze_image_encoder:
        for param in self.image_encoder.parameters():
            param.requires_grad = False
|
|
104
|
+
|
|
105
|
+
def forward(self, input: BatchedVideoDatapoint):
    """Prepare prompts for a batched video datapoint and track through its frames.

    Image features for all frames are computed up front (via `forward_image`
    on `input.flat_img_batch`) when the module is in training mode or when
    `forward_backbone_per_frame_for_eval` is off. Otherwise a placeholder with
    `None` features is used and features are computed per frame while tracking.

    Returns:
        The result of `forward_tracking` on the prepared inputs.
    """
    defer_backbone = self.forward_backbone_per_frame_for_eval and not self.training
    if defer_backbone:
        # image features of a frame are computed only once it is being tracked
        backbone_out = {"backbone_fpn": None, "vision_pos_enc": None}
    else:
        # image features of all frames are computed before tracking starts
        backbone_out = self.forward_image(input.flat_img_batch)

    backbone_out = self.prepare_prompt_inputs(backbone_out, input)
    return self.forward_tracking(backbone_out, input)
|
|
116
|
+
|
|
117
|
+
def _prepare_backbone_features_per_frame(self, img_batch, img_ids):
|
|
118
|
+
"""Compute the image backbone features on the fly for the given img_ids."""
|
|
119
|
+
# Only forward backbone on unique image ids to avoid repetitive computation
|
|
120
|
+
# (if `img_ids` has only one element, it's already unique so we skip this step).
|
|
121
|
+
if img_ids.numel() > 1:
|
|
122
|
+
unique_img_ids, inv_ids = torch.unique(img_ids, return_inverse=True)
|
|
123
|
+
else:
|
|
124
|
+
unique_img_ids, inv_ids = img_ids, None
|
|
125
|
+
|
|
126
|
+
# Compute the image features on those unique image ids
|
|
127
|
+
image = img_batch[unique_img_ids]
|
|
128
|
+
backbone_out = self.forward_image(image)
|
|
129
|
+
(
|
|
130
|
+
_,
|
|
131
|
+
vision_feats,
|
|
132
|
+
vision_pos_embeds,
|
|
133
|
+
feat_sizes,
|
|
134
|
+
) = self._prepare_backbone_features(backbone_out)
|
|
135
|
+
# Inverse-map image features for `unique_img_ids` to the final image features
|
|
136
|
+
# for the original input `img_ids`.
|
|
137
|
+
if inv_ids is not None:
|
|
138
|
+
image = image[inv_ids]
|
|
139
|
+
vision_feats = [x[:, inv_ids] for x in vision_feats]
|
|
140
|
+
vision_pos_embeds = [x[:, inv_ids] for x in vision_pos_embeds]
|
|
141
|
+
|
|
142
|
+
return image, vision_feats, vision_pos_embeds, feat_sizes
|
|
143
|
+
|
|
144
|
+
def prepare_prompt_inputs(self, backbone_out, input, start_frame_idx=0):
    """
    Prepare input mask, point or box prompts. Optionally, tracking may start from
    a custom `start_frame_idx` and run to the end of the video (for evaluation).

    The following keys are written into `backbone_out`, which is also returned:
    `gt_masks_per_frame`, `num_frames`, `use_pt_input`, `init_cond_frames`,
    `frames_not_in_init_cond`, `mask_inputs_per_frame`, `point_inputs_per_frame`
    and `frames_to_add_correction_pt`.

    All random decisions are drawn from `self.rng`, in a fixed order.
    """
    # Ground-truth masks of every frame, with a singleton dim inserted at index 1
    # (correction points are sampled from them later).
    gt_masks_per_frame = {}
    for frame_idx, frame_masks in enumerate(input.masks):
        gt_masks_per_frame[frame_idx] = frame_masks.unsqueeze(1)
    backbone_out["gt_masks_per_frame"] = gt_masks_per_frame
    num_frames = input.num_frames
    backbone_out["num_frames"] = num_frames

    # Pick the train or eval flavour of every sampling setting
    if self.training:
        (
            prob_to_use_pt_input,
            prob_to_use_box_input,
            num_frames_to_correct,
            rand_frames_to_correct,
            num_init_cond_frames,
            rand_init_cond_frames,
        ) = (
            self.prob_to_use_pt_input_for_train,
            self.prob_to_use_box_input_for_train,
            self.num_frames_to_correct_for_train,
            self.rand_frames_to_correct_for_train,
            self.num_init_cond_frames_for_train,
            self.rand_init_cond_frames_for_train,
        )
    else:
        (
            prob_to_use_pt_input,
            prob_to_use_box_input,
            num_frames_to_correct,
            rand_frames_to_correct,
            num_init_cond_frames,
            rand_init_cond_frames,
        ) = (
            self.prob_to_use_pt_input_for_eval,
            self.prob_to_use_box_input_for_eval,
            self.num_frames_to_correct_for_eval,
            self.rand_frames_to_correct_for_eval,
            self.num_init_cond_frames_for_eval,
            self.rand_init_cond_frames_for_eval,
        )

    if num_frames == 1:
        # Special case for mixing video + SAM-on-image training: static images
        # are always prompted with points on their single frame.
        prob_to_use_pt_input = 1.0
        num_frames_to_correct = 1
        num_init_cond_frames = 1
    assert num_init_cond_frames >= 1

    # Point input vs. mask input (`self.rng.random()` lies in 0.0 <= X < 1.0)
    use_pt_input = self.rng.random() < prob_to_use_pt_input
    if rand_init_cond_frames and num_init_cond_frames > 1:
        # use between 1 and `num_init_cond_frames` initial conditioning frames
        num_init_cond_frames = self.rng.integers(1, num_init_cond_frames, endpoint=True)
    if use_pt_input and rand_frames_to_correct and num_frames_to_correct > num_init_cond_frames:
        # correct between `num_init_cond_frames` and `num_frames_to_correct`
        # frames (point input only)
        num_frames_to_correct = self.rng.integers(num_init_cond_frames, num_frames_to_correct, endpoint=True)
    backbone_out["use_pt_input"] = use_pt_input

    # Initial conditioning frames: the starting frame, plus randomly chosen
    # later frames (without replacement) when more than one is requested
    init_cond_frames = [start_frame_idx]
    if num_init_cond_frames != 1:
        later_frames = range(start_frame_idx + 1, num_frames)
        init_cond_frames = init_cond_frames + self.rng.choice(
            later_frames, num_init_cond_frames - 1, replace=False
        ).tolist()
    backbone_out["init_cond_frames"] = init_cond_frames
    backbone_out["frames_not_in_init_cond"] = [
        frame_idx for frame_idx in range(start_frame_idx, num_frames) if frame_idx not in init_cond_frames
    ]

    # Mask or point prompts on the initial conditioning frames, keyed by frame index
    mask_inputs_per_frame = {}
    point_inputs_per_frame = {}
    backbone_out["mask_inputs_per_frame"] = mask_inputs_per_frame
    backbone_out["point_inputs_per_frame"] = point_inputs_per_frame
    for frame_idx in init_cond_frames:
        frame_gt = gt_masks_per_frame[frame_idx]
        if not use_pt_input:
            mask_inputs_per_frame[frame_idx] = frame_gt
            continue
        # During training, P(box) = prob_to_use_pt_input * prob_to_use_box_input
        if self.rng.random() < prob_to_use_box_input:
            points, labels = sample_box_points(frame_gt)
        else:
            # only **one initial point** is sampled from the ground-truth mask
            # here; more correction points may be sampled on the fly
            points, labels = get_next_point(
                gt_masks=frame_gt,
                pred_masks=None,
                method=("uniform" if self.training else self.pt_sampling_for_eval),
            )
        point_inputs_per_frame[frame_idx] = {"point_coords": points, "point_labels": labels}

    # Frames that get correction clicks on the fly, based on the error between
    # the prediction and the ground-truth masks
    if not use_pt_input:
        # mask inputs never receive correction points
        frames_to_add_correction_pt = []
    elif num_frames_to_correct == num_init_cond_frames:
        frames_to_add_correction_pt = init_cond_frames
    else:
        assert num_frames_to_correct > num_init_cond_frames
        # initial conditioning frames + randomly chosen remaining frames
        # (without replacement)
        num_extra_frames = num_frames_to_correct - num_init_cond_frames
        extra_frames = self.rng.choice(
            backbone_out["frames_not_in_init_cond"], num_extra_frames, replace=False
        ).tolist()
        frames_to_add_correction_pt = init_cond_frames + extra_frames
    backbone_out["frames_to_add_correction_pt"] = frames_to_add_correction_pt

    return backbone_out
|
|
253
|
+
|
|
254
|
+
def forward_tracking(self, backbone_out, input: BatchedVideoDatapoint, return_dict=False):
    """Forward video tracking on each frame (and sample correction clicks).

    The initial conditioning frames are processed first so that they are
    encoded as memory; the remaining frames are then tracked conditioned on
    them. `backbone_out` must already carry the keys written by
    `prepare_prompt_inputs`.

    Returns:
        With `return_dict=True`, a dict with `cond_frame_outputs` and
        `non_cond_frame_outputs`, each mapping frame index to that frame's
        output. Otherwise a list with one output dict per frame index in
        `range(num_frames)`, with the `obj_ptr` entry removed.

    NOTE(review): the list form looks up every index in `range(num_frames)`;
    if prompts were prepared with a non-zero `start_frame_idx`, earlier frames
    have no output and that lookup would raise `KeyError` -- confirm such
    callers use `return_dict=True`.
    """
    features_precomputed = backbone_out["backbone_fpn"] is not None
    if features_precomputed:
        # vision_feats and vision_pos_embeds are in (HW)BC format
        _, vision_feats, vision_pos_embeds, feat_sizes = self._prepare_backbone_features(backbone_out)

    num_frames = backbone_out["num_frames"]
    init_cond_frames = backbone_out["init_cond_frames"]
    frames_to_add_correction_pt = backbone_out["frames_to_add_correction_pt"]

    cond_frame_outputs = {}  # frame index -> output of a conditioning frame
    non_cond_frame_outputs = {}  # frame index -> output of any other frame
    output_dict = {
        "cond_frame_outputs": cond_frame_outputs,
        "non_cond_frame_outputs": non_cond_frame_outputs,
    }

    for stage_id in init_cond_frames + backbone_out["frames_not_in_init_cond"]:
        img_ids = input.flat_obj_to_img_idx[stage_id]
        if features_precomputed:
            # select the already computed features of this frame's images
            current_vision_feats = [feat[:, img_ids] for feat in vision_feats]
            current_vision_pos_embeds = [pos[:, img_ids] for pos in vision_pos_embeds]
        else:
            # compute them on the fly (may be used for evaluation on long
            # videos to avoid backbone OOM)
            _, current_vision_feats, current_vision_pos_embeds, feat_sizes = (
                self._prepare_backbone_features_per_frame(input.flat_img_batch, img_ids)
            )

        is_init_cond_frame = stage_id in init_cond_frames
        # Output masks from this frame's prompts and the memory of previous frames
        current_out = self.track_step(
            frame_idx=stage_id,
            is_init_cond_frame=is_init_cond_frame,
            current_vision_feats=current_vision_feats,
            current_vision_pos_embeds=current_vision_pos_embeds,
            feat_sizes=feat_sizes,
            point_inputs=backbone_out["point_inputs_per_frame"].get(stage_id, None),
            mask_inputs=backbone_out["mask_inputs_per_frame"].get(stage_id, None),
            gt_masks=backbone_out["gt_masks_per_frame"].get(stage_id, None),
            frames_to_add_correction_pt=frames_to_add_correction_pt,
            output_dict=output_dict,
            num_frames=num_frames,
        )

        # A frame counts as conditioning if it is an initial conditioning frame,
        # or if corrected frames are promoted and this frame gets corrections.
        if is_init_cond_frame or (
            self.add_all_frames_to_correct_as_cond and stage_id in frames_to_add_correction_pt
        ):
            cond_frame_outputs[stage_id] = current_out
        else:
            non_cond_frame_outputs[stage_id] = current_out

    if return_dict:
        return output_dict

    # Flatten into a per-frame list for the loss function, dropping `obj_ptr`
    # (unused key; keeps DDP happy with activation checkpointing).
    outputs_by_frame = {**cond_frame_outputs, **non_cond_frame_outputs}
    return [
        {key: value for key, value in outputs_by_frame[frame_idx].items() if key != "obj_ptr"}
        for frame_idx in range(num_frames)
    ]
|
|
330
|
+
|
|
331
|
+
def track_step(
    self,
    frame_idx,
    is_init_cond_frame,
    current_vision_feats,
    current_vision_pos_embeds,
    feat_sizes,
    point_inputs,
    mask_inputs,
    output_dict,
    num_frames,
    track_in_reverse=False,  # tracking in reverse time order (for demo usage)
    run_mem_encoder=True,  # Whether to run the memory encoder on the predicted masks.
    prev_sam_mask_logits=None,  # The previously predicted SAM mask logits.
    frames_to_add_correction_pt=None,
    gt_masks=None,
):
    """Track one frame, optionally refine it with correction clicks, and encode memory.

    Steps:
      1. `_track_step` produces the first prediction for the frame.
      2. If `frame_idx` is listed in `frames_to_add_correction_pt` and at least
         one correction point per frame is configured, the prediction is
         refined by `_iter_correct_pt_sampling` (which needs `gt_masks`).
      3. The final masks and object pointer are stored in the output and
         `_encode_memory_in_output` is called on them.

    Args:
        frames_to_add_correction_pt: frame indices that receive correction
            clicks; `None` is treated as an empty list.
        gt_masks: ground-truth masks of this frame, used to sample correction
            clicks.

    Returns:
        The output dict of this frame. It holds the `multistep_*` entries of
        every prediction step, plus `pred_masks`, `pred_masks_high_res` and
        `obj_ptr` of the final step.
    """
    if frames_to_add_correction_pt is None:
        frames_to_add_correction_pt = []

    current_out, sam_outputs, high_res_features, pix_feat = self._track_step(
        frame_idx,
        is_init_cond_frame,
        current_vision_feats,
        current_vision_pos_embeds,
        feat_sizes,
        point_inputs,
        mask_inputs,
        output_dict,
        num_frames,
        track_in_reverse,
        prev_sam_mask_logits,
    )
    (
        low_res_multimasks,
        high_res_multimasks,
        ious,
        low_res_masks,
        high_res_masks,
        obj_ptr,
        object_score_logits,
    ) = sam_outputs

    # Record the first prediction as a single-step history; a refinement below
    # overwrites these entries with the full multi-step history.
    current_out["multistep_pred_masks"] = low_res_masks
    current_out["multistep_pred_masks_high_res"] = high_res_masks
    current_out["multistep_pred_multimasks"] = [low_res_multimasks]
    current_out["multistep_pred_multimasks_high_res"] = [high_res_multimasks]
    current_out["multistep_pred_ious"] = [ious]
    current_out["multistep_point_inputs"] = [point_inputs]
    current_out["multistep_object_score_logits"] = [object_score_logits]

    # Optionally, sample correction points iteratively to correct the mask.
    # With no correction points configured, the refinement loop would run zero
    # times and have no SAM outputs to return, so it is skipped; the single-step
    # entries recorded above already describe that case.
    wants_correction = frame_idx in frames_to_add_correction_pt
    if wants_correction and self.num_correction_pt_per_frame > 0:
        point_inputs, final_sam_outputs = self._iter_correct_pt_sampling(
            is_init_cond_frame,
            point_inputs,
            gt_masks,
            high_res_features,
            pix_feat,
            low_res_multimasks,
            high_res_multimasks,
            ious,
            low_res_masks,
            high_res_masks,
            object_score_logits,
            current_out,
        )
        (
            _,
            _,
            _,
            low_res_masks,
            high_res_masks,
            obj_ptr,
            object_score_logits,
        ) = final_sam_outputs

    # Use the final prediction (after all correction steps) for output and eval
    current_out["pred_masks"] = low_res_masks
    current_out["pred_masks_high_res"] = high_res_masks
    current_out["obj_ptr"] = obj_ptr

    # Finally run the memory encoder on the predicted mask to encode it into a
    # new memory feature (that can be used in future frames)
    self._encode_memory_in_output(
        current_vision_feats,
        feat_sizes,
        point_inputs,
        run_mem_encoder,
        high_res_masks,
        object_score_logits,
        current_out,
    )
    return current_out
|
|
425
|
+
|
|
426
|
+
def _iter_correct_pt_sampling(
    self,
    is_init_cond_frame,
    point_inputs,
    gt_masks,
    high_res_features,
    pix_feat_with_mem,
    low_res_multimasks,
    high_res_multimasks,
    ious,
    low_res_masks,
    high_res_masks,
    object_score_logits,
    current_out,
):
    """Iteratively add correction clicks on one frame and re-run the SAM heads.

    For each of `self.num_correction_pt_per_frame` steps, a new point is
    sampled with `get_next_point`, appended to `point_inputs`, and the SAM
    heads are run again with the previous low-res mask logits as mask input.
    The predictions of every step (including the initial one passed in) are
    written to the `multistep_*` entries of `current_out`.

    Returns:
        `(point_inputs, sam_outputs)`: the accumulated point prompts and the
        SAM outputs of the last correction step.

    NOTE: `sam_outputs` is only assigned inside the loop, so this expects
    `self.num_correction_pt_per_frame` to be at least 1.
    """
    assert gt_masks is not None

    # Per-step history, seeded with the prediction made before any correction
    step_masks = [low_res_masks]
    step_high_res_masks = [high_res_masks]
    step_multimasks = [low_res_multimasks]
    step_high_res_multimasks = [high_res_multimasks]
    step_ious = [ious]
    step_point_inputs = [point_inputs]
    step_object_score_logits = [object_score_logits]

    for _ in range(self.num_correction_pt_per_frame):
        # The new point comes from the error between prediction and ground
        # truth; in training it is, with a small probability, sampled from the
        # GT masks alone instead.
        sample_from_gt = False
        if self.training and self.prob_to_sample_from_gt_for_train > 0:
            sample_from_gt = self.rng.random() < self.prob_to_sample_from_gt_for_train
        # `pred_masks=None` makes the sampler use only the GT masks
        pred_for_new_pt = None if sample_from_gt else (high_res_masks > 0)
        sampling_method = "uniform" if self.training else self.pt_sampling_for_eval
        new_points, new_labels = get_next_point(
            gt_masks=gt_masks,
            pred_masks=pred_for_new_pt,
            method=sampling_method,
        )
        point_inputs = concat_points(point_inputs, new_points, new_labels)

        # The mask logits of the previous SAM output are fed into the next
        # decoder step: when a correction click is added while tracking, the
        # tracking output mask logits go to the SAM decoder along with the click.
        mask_inputs = low_res_masks
        multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
        sam_head_kwargs = dict(
            backbone_features=pix_feat_with_mem,
            point_inputs=point_inputs,
            mask_inputs=mask_inputs,
            high_res_features=high_res_features,
            multimask_output=multimask_output,
        )
        if self.use_act_ckpt_iterative_pt_sampling and not multimask_output:
            sam_outputs = torch.utils.checkpoint.checkpoint(
                self._forward_sam_heads,
                use_reentrant=False,
                **sam_head_kwargs,
            )
        else:
            sam_outputs = self._forward_sam_heads(**sam_head_kwargs)

        (
            low_res_multimasks,
            high_res_multimasks,
            ious,
            low_res_masks,
            high_res_masks,
            _,
            object_score_logits,
        ) = sam_outputs
        step_masks.append(low_res_masks)
        step_high_res_masks.append(high_res_masks)
        step_multimasks.append(low_res_multimasks)
        step_high_res_multimasks.append(high_res_multimasks)
        step_ious.append(ious)
        step_point_inputs.append(point_inputs)
        step_object_score_logits.append(object_score_logits)

    # Masks are concatenated along the channel dim (so that losses can be
    # computed on all of them, using `MultiStepIteractiveMasks`)
    current_out["multistep_pred_masks"] = torch.cat(step_masks, dim=1)
    current_out["multistep_pred_masks_high_res"] = torch.cat(step_high_res_masks, dim=1)
    current_out["multistep_pred_multimasks"] = step_multimasks
    current_out["multistep_pred_multimasks_high_res"] = step_high_res_multimasks
    current_out["multistep_pred_ious"] = step_ious
    current_out["multistep_point_inputs"] = step_point_inputs
    current_out["multistep_object_score_logits"] = step_object_score_logits

    return point_inputs, sam_outputs
|