singlebehaviorlab-2.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sam2/__init__.py +11 -0
- sam2/automatic_mask_generator.py +454 -0
- sam2/benchmark.py +92 -0
- sam2/build_sam.py +174 -0
- sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
- sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
- sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
- sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
- sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +116 -0
- sam2/configs/sam2.1/sam2.1_hiera_l.yaml +120 -0
- sam2/configs/sam2.1/sam2.1_hiera_s.yaml +119 -0
- sam2/configs/sam2.1/sam2.1_hiera_t.yaml +121 -0
- sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +339 -0
- sam2/modeling/__init__.py +5 -0
- sam2/modeling/backbones/__init__.py +5 -0
- sam2/modeling/backbones/hieradet.py +317 -0
- sam2/modeling/backbones/image_encoder.py +134 -0
- sam2/modeling/backbones/utils.py +93 -0
- sam2/modeling/memory_attention.py +169 -0
- sam2/modeling/memory_encoder.py +181 -0
- sam2/modeling/position_encoding.py +239 -0
- sam2/modeling/sam/__init__.py +5 -0
- sam2/modeling/sam/mask_decoder.py +295 -0
- sam2/modeling/sam/prompt_encoder.py +202 -0
- sam2/modeling/sam/transformer.py +311 -0
- sam2/modeling/sam2_base.py +913 -0
- sam2/modeling/sam2_utils.py +323 -0
- sam2/sam2_hiera_b+.yaml +113 -0
- sam2/sam2_hiera_l.yaml +117 -0
- sam2/sam2_hiera_s.yaml +116 -0
- sam2/sam2_hiera_t.yaml +118 -0
- sam2/sam2_image_predictor.py +466 -0
- sam2/sam2_video_predictor.py +1388 -0
- sam2/sam2_video_predictor_legacy.py +1172 -0
- sam2/utils/__init__.py +5 -0
- sam2/utils/amg.py +348 -0
- sam2/utils/misc.py +349 -0
- sam2/utils/transforms.py +118 -0
- singlebehaviorlab/__init__.py +4 -0
- singlebehaviorlab/__main__.py +130 -0
- singlebehaviorlab/_paths.py +100 -0
- singlebehaviorlab/backend/__init__.py +2 -0
- singlebehaviorlab/backend/augmentations.py +320 -0
- singlebehaviorlab/backend/data_store.py +420 -0
- singlebehaviorlab/backend/model.py +1290 -0
- singlebehaviorlab/backend/train.py +4667 -0
- singlebehaviorlab/backend/uncertainty.py +578 -0
- singlebehaviorlab/backend/video_processor.py +688 -0
- singlebehaviorlab/backend/video_utils.py +139 -0
- singlebehaviorlab/data/config/config.yaml +85 -0
- singlebehaviorlab/data/training_profiles.json +334 -0
- singlebehaviorlab/gui/__init__.py +4 -0
- singlebehaviorlab/gui/analysis_widget.py +2291 -0
- singlebehaviorlab/gui/attention_export.py +311 -0
- singlebehaviorlab/gui/clip_extraction_widget.py +481 -0
- singlebehaviorlab/gui/clustering_widget.py +3187 -0
- singlebehaviorlab/gui/inference_popups.py +1138 -0
- singlebehaviorlab/gui/inference_widget.py +4550 -0
- singlebehaviorlab/gui/inference_worker.py +651 -0
- singlebehaviorlab/gui/labeling_widget.py +2324 -0
- singlebehaviorlab/gui/main_window.py +754 -0
- singlebehaviorlab/gui/metadata_management_widget.py +1119 -0
- singlebehaviorlab/gui/motion_tracking.py +764 -0
- singlebehaviorlab/gui/overlay_export.py +1234 -0
- singlebehaviorlab/gui/plot_integration.py +729 -0
- singlebehaviorlab/gui/qt_helpers.py +29 -0
- singlebehaviorlab/gui/registration_widget.py +1485 -0
- singlebehaviorlab/gui/review_widget.py +1330 -0
- singlebehaviorlab/gui/segmentation_tracking_widget.py +2752 -0
- singlebehaviorlab/gui/tab_tutorial_dialog.py +312 -0
- singlebehaviorlab/gui/timeline_themes.py +131 -0
- singlebehaviorlab/gui/training_profiles.py +418 -0
- singlebehaviorlab/gui/training_widget.py +3719 -0
- singlebehaviorlab/gui/video_utils.py +233 -0
- singlebehaviorlab/licenses/SAM2-LICENSE +201 -0
- singlebehaviorlab/licenses/VideoPrism-LICENSE +202 -0
- singlebehaviorlab-2.0.0.dist-info/METADATA +447 -0
- singlebehaviorlab-2.0.0.dist-info/RECORD +88 -0
- singlebehaviorlab-2.0.0.dist-info/WHEEL +5 -0
- singlebehaviorlab-2.0.0.dist-info/entry_points.txt +2 -0
- singlebehaviorlab-2.0.0.dist-info/licenses/LICENSE +21 -0
- singlebehaviorlab-2.0.0.dist-info/top_level.txt +3 -0
- videoprism/__init__.py +0 -0
- videoprism/encoders.py +910 -0
- videoprism/layers.py +1136 -0
- videoprism/models.py +407 -0
- videoprism/tokenizers.py +167 -0
- videoprism/utils.py +168 -0
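
The listing above is the wheel's file manifest. For readers who want to reproduce such a listing locally, here is a minimal sketch using only the Python standard library (the file name below is assumed from the canonical wheel naming; a wheel is just a zip archive):

import zipfile

# A wheel is a zip archive; namelist() returns every file packaged inside it.
with zipfile.ZipFile("singlebehaviorlab-2.0.0-py3-none-any.whl") as whl:
    for name in sorted(whl.namelist()):
        print(name)

Note that the two large vendored components visible in the manifest, sam2 and videoprism, ship under their own licenses (see singlebehaviorlab/licenses/).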
sam2/modeling/position_encoding.py
@@ -0,0 +1,239 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+from typing import Any, Optional, Tuple
+
+import numpy as np
+
+import torch
+from torch import nn
+
+
+class PositionEmbeddingSine(nn.Module):
+    """
+    This is a more standard version of the position embedding, very similar to the one
+    used by the Attention Is All You Need paper, generalized to work on images.
+    """
+
+    def __init__(
+        self,
+        num_pos_feats,
+        temperature: int = 10000,
+        normalize: bool = True,
+        scale: Optional[float] = None,
+        # Following settings only relevant
+        # for warming up cache for compilation
+        warmup_cache: bool = True,
+        image_size: int = 1024,
+        strides: Tuple[int, ...] = (4, 8, 16, 32),
+    ):
+        super().__init__()
+        assert num_pos_feats % 2 == 0, "Expecting even model width"
+        self.num_pos_feats = num_pos_feats // 2
+        self.temperature = temperature
+        self.normalize = normalize
+        if scale is not None and normalize is False:
+            raise ValueError("normalize should be True if scale is passed")
+        if scale is None:
+            scale = 2 * math.pi
+        self.scale = scale
+
+        self.cache = {}
+        if warmup_cache and torch.cuda.is_available():
+            # Warmup cache for cuda, to help with compilation
+            device = torch.device("cuda")
+            for stride in strides:
+                cache_key = (image_size // stride, image_size // stride)
+                self._pe(1, device, *cache_key)
+
+    def _encode_xy(self, x, y):
+        # The positions are expected to be normalized
+        assert len(x) == len(y) and x.ndim == y.ndim == 1
+        x_embed = x * self.scale
+        y_embed = y * self.scale
+
+        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+
+        pos_x = x_embed[:, None] / dim_t
+        pos_y = y_embed[:, None] / dim_t
+        pos_x = torch.stack(
+            (pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2
+        ).flatten(1)
+        pos_y = torch.stack(
+            (pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2
+        ).flatten(1)
+        return pos_x, pos_y
+
+    @torch.no_grad()
+    def encode_boxes(self, x, y, w, h):
+        pos_x, pos_y = self._encode_xy(x, y)
+        pos = torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1)
+        return pos
+
+    encode = encode_boxes  # Backwards compatibility
+
+    @torch.no_grad()
+    def encode_points(self, x, y, labels):
+        (bx, nx), (by, ny), (bl, nl) = x.shape, y.shape, labels.shape
+        assert bx == by and nx == ny and bx == bl and nx == nl
+        pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten())
+        pos_x, pos_y = pos_x.reshape(bx, nx, -1), pos_y.reshape(by, ny, -1)
+        pos = torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2)
+        return pos
+
+    @torch.no_grad()
+    def _pe(self, B, device, *cache_key):
+        H, W = cache_key
+        if cache_key in self.cache:
+            return self.cache[cache_key].to(device)[None].repeat(B, 1, 1, 1)
+
+        y_embed = (
+            torch.arange(1, H + 1, dtype=torch.float32, device=device)
+            .view(1, -1, 1)
+            .repeat(B, 1, W)
+        )
+        x_embed = (
+            torch.arange(1, W + 1, dtype=torch.float32, device=device)
+            .view(1, 1, -1)
+            .repeat(B, H, 1)
+        )
+
+        if self.normalize:
+            eps = 1e-6
+            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+
+        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=device)
+        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+
+        pos_x = x_embed[:, :, :, None] / dim_t
+        pos_y = y_embed[:, :, :, None] / dim_t
+        pos_x = torch.stack(
+            (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
+        ).flatten(3)
+        pos_y = torch.stack(
+            (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
+        ).flatten(3)
+        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+        self.cache[cache_key] = pos[0]
+        return pos
+
+    @torch.no_grad()
+    def forward(self, x: torch.Tensor):
+        B = x.shape[0]
+        cache_key = (x.shape[-2], x.shape[-1])
+        return self._pe(B, x.device, *cache_key)
+
+
+class PositionEmbeddingRandom(nn.Module):
+    """
+    Positional encoding using random spatial frequencies.
+    """
+
+    def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
+        super().__init__()
+        if scale is None or scale <= 0.0:
+            scale = 1.0
+        self.register_buffer(
+            "positional_encoding_gaussian_matrix",
+            scale * torch.randn((2, num_pos_feats)),
+        )
+
+    def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
+        """Positionally encode points that are normalized to [0,1]."""
+        # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
+        coords = 2 * coords - 1
+        coords = coords @ self.positional_encoding_gaussian_matrix
+        coords = 2 * np.pi * coords
+        # outputs d_1 x ... x d_n x C shape
+        return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
+
+    def forward(self, size: Tuple[int, int]) -> torch.Tensor:
+        """Generate positional encoding for a grid of the specified size."""
+        h, w = size
+        device: Any = self.positional_encoding_gaussian_matrix.device
+        grid = torch.ones((h, w), device=device, dtype=torch.float32)
+        y_embed = grid.cumsum(dim=0) - 0.5
+        x_embed = grid.cumsum(dim=1) - 0.5
+        y_embed = y_embed / h
+        x_embed = x_embed / w
+
+        pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
+        return pe.permute(2, 0, 1)  # C x H x W
+
+    def forward_with_coords(
+        self, coords_input: torch.Tensor, image_size: Tuple[int, int]
+    ) -> torch.Tensor:
+        """Positionally encode points that are not normalized to [0,1]."""
+        coords = coords_input.clone()
+        coords[:, :, 0] = coords[:, :, 0] / image_size[1]
+        coords[:, :, 1] = coords[:, :, 1] / image_size[0]
+        return self._pe_encoding(coords.to(torch.float))  # B x N x C
+
+
+# Rotary Positional Encoding, adapted from:
+# 1. https://github.com/meta-llama/codellama/blob/main/llama/model.py
+# 2. https://github.com/naver-ai/rope-vit
+# 3. https://github.com/lucidrains/rotary-embedding-torch
+
+
+def init_t_xy(end_x: int, end_y: int):
+    t = torch.arange(end_x * end_y, dtype=torch.float32)
+    t_x = (t % end_x).float()
+    t_y = torch.div(t, end_x, rounding_mode="floor").float()
+    return t_x, t_y
+
+
+def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
+    freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
+    freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
+
+    t_x, t_y = init_t_xy(end_x, end_y)
+    freqs_x = torch.outer(t_x, freqs_x)
+    freqs_y = torch.outer(t_y, freqs_y)
+    freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x)
+    freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y)
+    return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1)
+
+
+def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
+    ndim = x.ndim
+    assert 0 <= 1 < ndim
+    assert freqs_cis.shape == (x.shape[-2], x.shape[-1])
+    shape = [d if i >= ndim - 2 else 1 for i, d in enumerate(x.shape)]
+    return freqs_cis.view(*shape)
+
+
+def apply_rotary_enc(
+    xq: torch.Tensor,
+    xk: torch.Tensor,
+    freqs_cis: torch.Tensor,
+    repeat_freqs_k: bool = False,
+):
+    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
+    xk_ = (
+        torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
+        if xk.shape[-2] != 0
+        else None
+    )
+    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
+    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
+    if xk_ is None:
+        # no keys to rotate, due to dropout
+        return xq_out.type_as(xq).to(xq.device), xk
+    # repeat freqs along seq_len dim to match k seq_len
+    if repeat_freqs_k:
+        r = xk_.shape[-2] // xq_.shape[-2]
+        if freqs_cis.is_cuda:
+            freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
+        else:
+            # torch.repeat on complex numbers may not be supported on non-CUDA devices
+            # (freqs_cis has 4 dims and we repeat on dim 2) so we use expand + flatten
+            freqs_cis = freqs_cis.unsqueeze(2).expand(-1, -1, r, -1, -1).flatten(2, 3)
+    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
+    return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)
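
The hunk above vendors SAM2's position-encoding module wholesale. As a quick sanity check of the shapes involved, here is a minimal sketch (tensor sizes chosen purely for illustration, and assuming the vendored sam2 package from this wheel is importable): PositionEmbeddingSine.forward maps a (B, C, H, W) feature map to a (B, num_pos_feats, H, W) encoding, while compute_axial_cis plus apply_rotary_enc implement 2D axial RoPE over a flattened token grid.

import torch

from sam2.modeling.position_encoding import (
    PositionEmbeddingSine,
    apply_rotary_enc,
    compute_axial_cis,
)

# Sine embedding over a 64x64 feature map; num_pos_feats must be even.
pe = PositionEmbeddingSine(num_pos_feats=256, warmup_cache=False)
feats = torch.zeros(2, 256, 64, 64)  # (B, C, H, W)
print(pe(feats).shape)  # torch.Size([2, 256, 64, 64])

# Axial RoPE: complex rotations for an 8x8 grid of 64-dim query/key vectors.
freqs_cis = compute_axial_cis(dim=64, end_x=8, end_y=8)  # (64, 32), complex
q = torch.randn(2, 4, 64, 64)  # (B, n_heads, seq_len=8*8, head_dim)
k = torch.randn(2, 4, 64, 64)
q_rot, k_rot = apply_rotary_enc(q, k, freqs_cis=freqs_cis)
print(q_rot.shape, k_rot.shape)  # unchanged: same shapes as q and k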
sam2/modeling/sam/mask_decoder.py
@@ -0,0 +1,295 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import List, Optional, Tuple, Type
+
+import torch
+from torch import nn
+
+from sam2.modeling.sam2_utils import LayerNorm2d, MLP
+
+
+class MaskDecoder(nn.Module):
+    def __init__(
+        self,
+        *,
+        transformer_dim: int,
+        transformer: nn.Module,
+        num_multimask_outputs: int = 3,
+        activation: Type[nn.Module] = nn.GELU,
+        iou_head_depth: int = 3,
+        iou_head_hidden_dim: int = 256,
+        use_high_res_features: bool = False,
+        iou_prediction_use_sigmoid=False,
+        dynamic_multimask_via_stability=False,
+        dynamic_multimask_stability_delta=0.05,
+        dynamic_multimask_stability_thresh=0.98,
+        pred_obj_scores: bool = False,
+        pred_obj_scores_mlp: bool = False,
+        use_multimask_token_for_obj_ptr: bool = False,
+    ) -> None:
+        """
+        Predicts masks given an image and prompt embeddings, using a
+        transformer architecture.
+
+        Arguments:
+          transformer_dim (int): the channel dimension of the transformer
+          transformer (nn.Module): the transformer used to predict masks
+          num_multimask_outputs (int): the number of masks to predict
+            when disambiguating masks
+          activation (nn.Module): the type of activation to use when
+            upscaling masks
+          iou_head_depth (int): the depth of the MLP used to predict
+            mask quality
+          iou_head_hidden_dim (int): the hidden dimension of the MLP
+            used to predict mask quality
+        """
+        super().__init__()
+        self.transformer_dim = transformer_dim
+        self.transformer = transformer
+
+        self.num_multimask_outputs = num_multimask_outputs
+
+        self.iou_token = nn.Embedding(1, transformer_dim)
+        self.num_mask_tokens = num_multimask_outputs + 1
+        self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
+
+        self.pred_obj_scores = pred_obj_scores
+        if self.pred_obj_scores:
+            self.obj_score_token = nn.Embedding(1, transformer_dim)
+        self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr
+
+        self.output_upscaling = nn.Sequential(
+            nn.ConvTranspose2d(
+                transformer_dim, transformer_dim // 4, kernel_size=2, stride=2
+            ),
+            LayerNorm2d(transformer_dim // 4),
+            activation(),
+            nn.ConvTranspose2d(
+                transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2
+            ),
+            activation(),
+        )
+        self.use_high_res_features = use_high_res_features
+        if use_high_res_features:
+            self.conv_s0 = nn.Conv2d(
+                transformer_dim, transformer_dim // 8, kernel_size=1, stride=1
+            )
+            self.conv_s1 = nn.Conv2d(
+                transformer_dim, transformer_dim // 4, kernel_size=1, stride=1
+            )
+
+        self.output_hypernetworks_mlps = nn.ModuleList(
+            [
+                MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
+                for i in range(self.num_mask_tokens)
+            ]
+        )
+
+        self.iou_prediction_head = MLP(
+            transformer_dim,
+            iou_head_hidden_dim,
+            self.num_mask_tokens,
+            iou_head_depth,
+            sigmoid_output=iou_prediction_use_sigmoid,
+        )
+        if self.pred_obj_scores:
+            self.pred_obj_score_head = nn.Linear(transformer_dim, 1)
+            if pred_obj_scores_mlp:
+                self.pred_obj_score_head = MLP(transformer_dim, transformer_dim, 1, 3)
+
+        # When outputting a single mask, optionally we can dynamically fall back to the best
+        # multimask output token if the single mask output token gives low stability scores.
+        self.dynamic_multimask_via_stability = dynamic_multimask_via_stability
+        self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta
+        self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh
+
+    def forward(
+        self,
+        image_embeddings: torch.Tensor,
+        image_pe: torch.Tensor,
+        sparse_prompt_embeddings: torch.Tensor,
+        dense_prompt_embeddings: torch.Tensor,
+        multimask_output: bool,
+        repeat_image: bool,
+        high_res_features: Optional[List[torch.Tensor]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        Predict masks given image and prompt embeddings.
+
+        Arguments:
+          image_embeddings (torch.Tensor): the embeddings from the image encoder
+          image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
+          sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
+          dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
+          multimask_output (bool): Whether to return multiple masks or a single
+            mask.
+
+        Returns:
+          torch.Tensor: batched predicted masks
+          torch.Tensor: batched predictions of mask quality
+          torch.Tensor: batched SAM token for mask output
+        """
+        masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks(
+            image_embeddings=image_embeddings,
+            image_pe=image_pe,
+            sparse_prompt_embeddings=sparse_prompt_embeddings,
+            dense_prompt_embeddings=dense_prompt_embeddings,
+            repeat_image=repeat_image,
+            high_res_features=high_res_features,
+        )
+
+        # Select the correct mask or masks for output
+        if multimask_output:
+            masks = masks[:, 1:, :, :]
+            iou_pred = iou_pred[:, 1:]
+        elif self.dynamic_multimask_via_stability and not self.training:
+            masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred)
+        else:
+            masks = masks[:, 0:1, :, :]
+            iou_pred = iou_pred[:, 0:1]
+
+        if multimask_output and self.use_multimask_token_for_obj_ptr:
+            sam_tokens_out = mask_tokens_out[:, 1:]  # [b, 3, c] shape
+        else:
+            # Take the mask output token. Here we *always* use the token for single mask output.
+            # At test time, even if we track after 1-click (and using multimask_output=True),
+            # we still take the single mask token here. The rationale is that we always track
+            # after multiple clicks during training, so the past tokens seen during training
+            # are always the single mask token (and we'll let it be the object-memory token).
+            sam_tokens_out = mask_tokens_out[:, 0:1]  # [b, 1, c] shape
+
+        # Prepare output
+        return masks, iou_pred, sam_tokens_out, object_score_logits
+
+    def predict_masks(
+        self,
+        image_embeddings: torch.Tensor,
+        image_pe: torch.Tensor,
+        sparse_prompt_embeddings: torch.Tensor,
+        dense_prompt_embeddings: torch.Tensor,
+        repeat_image: bool,
+        high_res_features: Optional[List[torch.Tensor]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Predicts masks. See 'forward' for more details."""
+        # Concatenate output tokens
+        s = 0
+        if self.pred_obj_scores:
+            output_tokens = torch.cat(
+                [
+                    self.obj_score_token.weight,
+                    self.iou_token.weight,
+                    self.mask_tokens.weight,
+                ],
+                dim=0,
+            )
+            s = 1
+        else:
+            output_tokens = torch.cat(
+                [self.iou_token.weight, self.mask_tokens.weight], dim=0
+            )
+        output_tokens = output_tokens.unsqueeze(0).expand(
+            sparse_prompt_embeddings.size(0), -1, -1
+        )
+        tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
+
+        # Expand per-image data in batch direction to be per-mask
+        if repeat_image:
+            src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
+        else:
+            assert image_embeddings.shape[0] == tokens.shape[0]
+            src = image_embeddings
+        src = src + dense_prompt_embeddings
+        assert (
+            image_pe.size(0) == 1
+        ), "image_pe should have size 1 in batch dim (from `get_dense_pe()`)"
+        pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
+        b, c, h, w = src.shape
+
+        # Run the transformer
+        hs, src = self.transformer(src, pos_src, tokens)
+        iou_token_out = hs[:, s, :]
+        mask_tokens_out = hs[:, s + 1 : (s + 1 + self.num_mask_tokens), :]
+
+        # Upscale mask embeddings and predict masks using the mask tokens
+        src = src.transpose(1, 2).view(b, c, h, w)
+        if not self.use_high_res_features:
+            upscaled_embedding = self.output_upscaling(src)
+        else:
+            dc1, ln1, act1, dc2, act2 = self.output_upscaling
+            feat_s0, feat_s1 = high_res_features
+            upscaled_embedding = act1(ln1(dc1(src) + feat_s1))
+            upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0)
+
+        hyper_in_list: List[torch.Tensor] = []
+        for i in range(self.num_mask_tokens):
+            hyper_in_list.append(
+                self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])
+            )
+        hyper_in = torch.stack(hyper_in_list, dim=1)
+        b, c, h, w = upscaled_embedding.shape
+        masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
+
+        # Generate mask quality predictions
+        iou_pred = self.iou_prediction_head(iou_token_out)
+        if self.pred_obj_scores:
+            assert s == 1
+            object_score_logits = self.pred_obj_score_head(hs[:, 0, :])
+        else:
+            # Obj scores logits - default to 10.0, i.e. assuming the object is present, sigmoid(10)=1
+            object_score_logits = 10.0 * iou_pred.new_ones(iou_pred.shape[0], 1)
+
+        return masks, iou_pred, mask_tokens_out, object_score_logits
+
+    def _get_stability_scores(self, mask_logits):
+        """
+        Compute stability scores of the mask logits based on the IoU between upper and
+        lower thresholds.
+        """
+        mask_logits = mask_logits.flatten(-2)
+        stability_delta = self.dynamic_multimask_stability_delta
+        area_i = torch.sum(mask_logits > stability_delta, dim=-1).float()
+        area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float()
+        stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0)
+        return stability_scores
+
+    def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores):
+        """
+        When outputting a single mask, if the stability score from the current single-mask
+        output (based on output token 0) falls below a threshold, we instead select from
+        multi-mask outputs (based on output token 1~3) the mask with the highest predicted
+        IoU score. This is intended to ensure a valid mask for both clicking and tracking.
+        """
+        # The best mask from multimask output tokens (1~3)
+        multimask_logits = all_mask_logits[:, 1:, :, :]
+        multimask_iou_scores = all_iou_scores[:, 1:]
+        best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1)
+        batch_inds = torch.arange(
+            multimask_iou_scores.size(0), device=all_iou_scores.device
+        )
+        best_multimask_logits = multimask_logits[batch_inds, best_scores_inds]
+        best_multimask_logits = best_multimask_logits.unsqueeze(1)
+        best_multimask_iou_scores = multimask_iou_scores[batch_inds, best_scores_inds]
+        best_multimask_iou_scores = best_multimask_iou_scores.unsqueeze(1)
+
+        # The mask from singlemask output token 0 and its stability score
+        singlemask_logits = all_mask_logits[:, 0:1, :, :]
+        singlemask_iou_scores = all_iou_scores[:, 0:1]
+        stability_scores = self._get_stability_scores(singlemask_logits)
+        is_stable = stability_scores >= self.dynamic_multimask_stability_thresh
+
+        # Dynamically fall back to best multimask output upon low stability scores.
+        mask_logits_out = torch.where(
+            is_stable[..., None, None].expand_as(singlemask_logits),
+            singlemask_logits,
+            best_multimask_logits,
+        )
+        iou_scores_out = torch.where(
+            is_stable.expand_as(singlemask_iou_scores),
+            singlemask_iou_scores,
+            best_multimask_iou_scores,
+        )
+        return mask_logits_out, iou_scores_out
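
To make the stability heuristic at the bottom of the file concrete: _get_stability_scores thresholds the mask logits at +delta and -delta and takes the IoU of the two resulting binary masks, so a mask whose boundary logits hover near zero scores low and, when dynamic_multimask_via_stability is enabled at eval time, is replaced by the best multimask candidate (delta 0.05 and threshold 0.98 by default). The sketch below restates that computation as a free function on toy logits; it is illustrative only and not part of the package.

import torch

def stability_score(mask_logits: torch.Tensor, delta: float = 0.05) -> torch.Tensor:
    # IoU between the binary masks obtained at thresholds +delta and -delta,
    # mirroring MaskDecoder._get_stability_scores above (illustrative only).
    flat = mask_logits.flatten(-2)                      # (..., H*W)
    area_i = torch.sum(flat > delta, dim=-1).float()    # stricter threshold
    area_u = torch.sum(flat > -delta, dim=-1).float()   # looser threshold
    return torch.where(area_u > 0, area_i / area_u, 1.0)

confident = torch.full((1, 1, 4, 4), 5.0)  # logits far from zero -> stable
borderline = torch.zeros(1, 1, 4, 4)       # logits right at zero -> unstable
print(stability_score(confident))   # tensor([[1.]])
print(stability_score(borderline))  # tensor([[0.]])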