frontveg-0.1.dev1-py3-none-any.whl
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.
- frontveg/__init__.py +11 -0
- frontveg/_tests/__init__.py +0 -0
- frontveg/_tests/test_widget.py +66 -0
- frontveg/_version.py +21 -0
- frontveg/_widget.py +132 -0
- frontveg/napari.yaml +14 -0
- frontveg/utils.py +95 -0
- frontveg-0.1.dev1.dist-info/METADATA +143 -0
- frontveg-0.1.dev1.dist-info/RECORD +44 -0
- frontveg-0.1.dev1.dist-info/WHEEL +5 -0
- frontveg-0.1.dev1.dist-info/entry_points.txt +2 -0
- frontveg-0.1.dev1.dist-info/licenses/LICENSE +28 -0
- frontveg-0.1.dev1.dist-info/top_level.txt +2 -0
- sam2/__init__.py +11 -0
- sam2/automatic_mask_generator.py +454 -0
- sam2/build_sam.py +167 -0
- sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
- sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
- sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
- sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
- sam2/modeling/__init__.py +5 -0
- sam2/modeling/backbones/__init__.py +5 -0
- sam2/modeling/backbones/hieradet.py +317 -0
- sam2/modeling/backbones/image_encoder.py +134 -0
- sam2/modeling/backbones/utils.py +95 -0
- sam2/modeling/memory_attention.py +169 -0
- sam2/modeling/memory_encoder.py +181 -0
- sam2/modeling/position_encoding.py +221 -0
- sam2/modeling/sam/__init__.py +5 -0
- sam2/modeling/sam/mask_decoder.py +295 -0
- sam2/modeling/sam/prompt_encoder.py +182 -0
- sam2/modeling/sam/transformer.py +360 -0
- sam2/modeling/sam2_base.py +907 -0
- sam2/modeling/sam2_utils.py +323 -0
- sam2/sam2_hiera_b+.yaml +1 -0
- sam2/sam2_hiera_l.yaml +1 -0
- sam2/sam2_hiera_s.yaml +1 -0
- sam2/sam2_hiera_t.yaml +1 -0
- sam2/sam2_image_predictor.py +466 -0
- sam2/sam2_video_predictor.py +1172 -0
- sam2/utils/__init__.py +5 -0
- sam2/utils/amg.py +348 -0
- sam2/utils/misc.py +349 -0
- sam2/utils/transforms.py +118 -0
sam2/modeling/sam/transformer.py
@@ -0,0 +1,360 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
import math
import warnings
from functools import partial
from typing import Tuple, Type

import torch
import torch.nn.functional as F
from torch import nn, Tensor

from sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis
from sam2.modeling.sam2_utils import MLP
from sam2.utils.misc import get_sdpa_settings

warnings.simplefilter(action="ignore", category=FutureWarning)
# Check whether Flash Attention is available (and use it by default)
OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings()
# A fallback setting to allow all available kernels if Flash Attention fails
ALLOW_ALL_KERNELS = False


def sdp_kernel_context(dropout_p):
    """
    Get the context for the attention scaled dot-product kernel. We use Flash Attention
    by default, but fall back to all available kernels if Flash Attention fails.
    """
    if ALLOW_ALL_KERNELS:
        return contextlib.nullcontext()

    return torch.backends.cuda.sdp_kernel(
        enable_flash=USE_FLASH_ATTN,
        # if Flash attention kernel is off, then math kernel needs to be enabled
        enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON,
        enable_mem_efficient=OLD_GPU,
    )

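The helper above exists to guard `F.scaled_dot_product_attention`, exactly as the `Attention.forward` methods later in this file do. A minimal REPL sketch of that pattern (illustrative only; the tensor shapes are arbitrary and it assumes the bundled sam2 package imports cleanly):

>>> q = k = v = torch.randn(1, 8, 16, 32)  # B x heads x tokens x head_dim
>>> with sdp_kernel_context(dropout_p=0.0):
...     out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)
>>> out.shape
torch.Size([1, 8, 16, 32])
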
class TwoWayTransformer(nn.Module):
    def __init__(
        self,
        depth: int,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
    ) -> None:
        """
        A transformer decoder that attends to an input image using
        queries whose positional embedding is supplied.

        Args:
          depth (int): number of layers in the transformer
          embedding_dim (int): the channel dimension for the input embeddings
          num_heads (int): the number of heads for multihead attention. Must
            divide embedding_dim
          mlp_dim (int): the channel dimension internal to the MLP block
          activation (nn.Module): the activation to use in the MLP block
        """
        super().__init__()
        self.depth = depth
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.mlp_dim = mlp_dim
        self.layers = nn.ModuleList()

        for i in range(depth):
            self.layers.append(
                TwoWayAttentionBlock(
                    embedding_dim=embedding_dim,
                    num_heads=num_heads,
                    mlp_dim=mlp_dim,
                    activation=activation,
                    attention_downsample_rate=attention_downsample_rate,
                    skip_first_layer_pe=(i == 0),
                )
            )

        self.final_attn_token_to_image = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )
        self.norm_final_attn = nn.LayerNorm(embedding_dim)

    def forward(
        self,
        image_embedding: Tensor,
        image_pe: Tensor,
        point_embedding: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        """
        Args:
          image_embedding (torch.Tensor): image to attend to. Should be shape
            B x embedding_dim x h x w for any h and w.
          image_pe (torch.Tensor): the positional encoding to add to the image. Must
            have the same shape as image_embedding.
          point_embedding (torch.Tensor): the embedding to add to the query points.
            Must have shape B x N_points x embedding_dim for any N_points.

        Returns:
          torch.Tensor: the processed point_embedding
          torch.Tensor: the processed image_embedding
        """
        # BxCxHxW -> BxHWxC == B x N_image_tokens x C
        bs, c, h, w = image_embedding.shape
        image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
        image_pe = image_pe.flatten(2).permute(0, 2, 1)

        # Prepare queries
        queries = point_embedding
        keys = image_embedding

        # Apply transformer blocks and final layernorm
        for layer in self.layers:
            queries, keys = layer(
                queries=queries,
                keys=keys,
                query_pe=point_embedding,
                key_pe=image_pe,
            )

        # Apply the final attention layer from the points to the image
        q = queries + point_embedding
        k = keys + image_pe
        attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
        queries = queries + attn_out
        queries = self.norm_final_attn(queries)

        return queries, keys

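A minimal usage sketch of TwoWayTransformer, following the shape contract in the docstring above; it assumes the bundled sam2 package imports cleanly, and the depth/width numbers are illustrative rather than a shipped configuration:

>>> transformer = TwoWayTransformer(depth=2, embedding_dim=256, num_heads=8, mlp_dim=2048)
>>> image_embedding = torch.randn(1, 256, 64, 64)  # B x embedding_dim x h x w
>>> image_pe = torch.randn(1, 256, 64, 64)         # same shape as image_embedding
>>> point_embedding = torch.randn(1, 5, 256)       # B x N_points x embedding_dim
>>> queries, keys = transformer(image_embedding, image_pe, point_embedding)
>>> queries.shape, keys.shape                      # processed points, flattened image tokens
(torch.Size([1, 5, 256]), torch.Size([1, 4096, 256]))
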
class TwoWayAttentionBlock(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int = 2048,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
        skip_first_layer_pe: bool = False,
    ) -> None:
        """
        A transformer block with four layers: (1) self-attention of sparse
        inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
        block on sparse inputs, and (4) cross attention of dense inputs to sparse
        inputs.

        Arguments:
          embedding_dim (int): the channel dimension of the embeddings
          num_heads (int): the number of heads in the attention layers
          mlp_dim (int): the hidden dimension of the mlp block
          activation (nn.Module): the activation of the mlp block
          skip_first_layer_pe (bool): skip the PE on the first layer
        """
        super().__init__()
        self.self_attn = Attention(embedding_dim, num_heads)
        self.norm1 = nn.LayerNorm(embedding_dim)

        self.cross_attn_token_to_image = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )
        self.norm2 = nn.LayerNorm(embedding_dim)

        self.mlp = MLP(
            embedding_dim, mlp_dim, embedding_dim, num_layers=2, activation=activation
        )
        self.norm3 = nn.LayerNorm(embedding_dim)

        self.norm4 = nn.LayerNorm(embedding_dim)
        self.cross_attn_image_to_token = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )

        self.skip_first_layer_pe = skip_first_layer_pe

    def forward(
        self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
    ) -> Tuple[Tensor, Tensor]:
        # Self attention block
        if self.skip_first_layer_pe:
            queries = self.self_attn(q=queries, k=queries, v=queries)
        else:
            q = queries + query_pe
            attn_out = self.self_attn(q=q, k=q, v=queries)
            queries = queries + attn_out
        queries = self.norm1(queries)

        # Cross attention block, tokens attending to image embedding
        q = queries + query_pe
        k = keys + key_pe
        attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
        queries = queries + attn_out
        queries = self.norm2(queries)

        # MLP block
        mlp_out = self.mlp(queries)
        queries = queries + mlp_out
        queries = self.norm3(queries)

        # Cross attention block, image embedding attending to tokens
        q = queries + query_pe
        k = keys + key_pe
        attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
        keys = keys + attn_out
        keys = self.norm4(keys)

        return queries, keys

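A standalone sketch of one TwoWayAttentionBlock on already-flattened token inputs (B x N x C); the sizes are illustrative:

>>> block = TwoWayAttentionBlock(embedding_dim=256, num_heads=8)
>>> queries, query_pe = torch.randn(1, 5, 256), torch.randn(1, 5, 256)    # sparse tokens + PE
>>> keys, key_pe = torch.randn(1, 4096, 256), torch.randn(1, 4096, 256)   # dense image tokens + PE
>>> queries, keys = block(queries=queries, keys=keys, query_pe=query_pe, key_pe=key_pe)
>>> queries.shape, keys.shape   # shapes are preserved; each side has attended to the other
(torch.Size([1, 5, 256]), torch.Size([1, 4096, 256]))
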
class Attention(nn.Module):
    """
    An attention layer that allows for downscaling the size of the embedding
    after projection to queries, keys, and values.
    """

    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        downsample_rate: int = 1,
        dropout: float = 0.0,
        kv_in_dim: int = None,
    ) -> None:
        super().__init__()
        self.embedding_dim = embedding_dim
        self.kv_in_dim = kv_in_dim if kv_in_dim is not None else embedding_dim
        self.internal_dim = embedding_dim // downsample_rate
        self.num_heads = num_heads
        assert (
            self.internal_dim % num_heads == 0
        ), "num_heads must divide embedding_dim."

        self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
        self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
        self.out_proj = nn.Linear(self.internal_dim, embedding_dim)

        self.dropout_p = dropout

    def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
        b, n, c = x.shape
        x = x.reshape(b, n, num_heads, c // num_heads)
        return x.transpose(1, 2)  # B x N_heads x N_tokens x C_per_head

    def _recombine_heads(self, x: Tensor) -> Tensor:
        b, n_heads, n_tokens, c_per_head = x.shape
        x = x.transpose(1, 2)
        return x.reshape(b, n_tokens, n_heads * c_per_head)  # B x N_tokens x C

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        # Input projections
        q = self.q_proj(q)
        k = self.k_proj(k)
        v = self.v_proj(v)

        # Separate into heads
        q = self._separate_heads(q, self.num_heads)
        k = self._separate_heads(k, self.num_heads)
        v = self._separate_heads(v, self.num_heads)

        dropout_p = self.dropout_p if self.training else 0.0
        # Attention
        try:
            with sdp_kernel_context(dropout_p):
                out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
        except Exception as e:
            # Fall back to all kernels if the Flash attention kernel fails
            warnings.warn(
                f"Flash Attention kernel failed due to: {e}\nFalling back to all available "
                f"kernels for scaled_dot_product_attention (which may have a slower speed).",
                category=UserWarning,
                stacklevel=2,
            )
            global ALLOW_ALL_KERNELS
            ALLOW_ALL_KERNELS = True
            out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)

        out = self._recombine_heads(out)
        out = self.out_proj(out)

        return out

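A short sketch of Attention with downsample_rate=2: the 256-dim inputs are projected to a 128-dim internal space for multi-head attention and projected back to 256 on output (numbers are illustrative):

>>> attn = Attention(embedding_dim=256, num_heads=8, downsample_rate=2)
>>> attn.internal_dim
128
>>> q = torch.randn(1, 5, 256)           # B x N_q x embedding_dim
>>> k = v = torch.randn(1, 4096, 256)    # B x N_k x embedding_dim
>>> attn(q=q, k=k, v=v).shape            # output is projected back to embedding_dim
torch.Size([1, 5, 256])
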
class RoPEAttention(Attention):
    """Attention with rotary position encoding."""

    def __init__(
        self,
        *args,
        rope_theta=10000.0,
        # whether to repeat q rope to match k length
        # this is needed for cross-attention to memories
        rope_k_repeat=False,
        feat_sizes=(32, 32),  # [w, h] for stride 16 feats at 512 resolution
        **kwargs,
    ):
        super().__init__(*args, **kwargs)

        self.compute_cis = partial(
            compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta
        )
        freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1])
        self.freqs_cis = freqs_cis
        self.rope_k_repeat = rope_k_repeat

    def forward(
        self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int = 0
    ) -> Tensor:
        # Input projections
        q = self.q_proj(q)
        k = self.k_proj(k)
        v = self.v_proj(v)

        # Separate into heads
        q = self._separate_heads(q, self.num_heads)
        k = self._separate_heads(k, self.num_heads)
        v = self._separate_heads(v, self.num_heads)

        # Apply rotary position encoding
        w = h = math.sqrt(q.shape[-2])
        self.freqs_cis = self.freqs_cis.to(q.device)
        if self.freqs_cis.shape[0] != q.shape[-2]:
            self.freqs_cis = self.compute_cis(end_x=w, end_y=h).to(q.device)
        if q.shape[-2] != k.shape[-2]:
            assert self.rope_k_repeat

        num_k_rope = k.size(-2) - num_k_exclude_rope
        q, k[:, :, :num_k_rope] = apply_rotary_enc(
            q,
            k[:, :, :num_k_rope],
            freqs_cis=self.freqs_cis,
            repeat_freqs_k=self.rope_k_repeat,
        )

        dropout_p = self.dropout_p if self.training else 0.0
        # Attention
        try:
            with sdp_kernel_context(dropout_p):
                out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
        except Exception as e:
            # Fall back to all kernels if the Flash attention kernel fails
            warnings.warn(
                f"Flash Attention kernel failed due to: {e}\nFalling back to all available "
                f"kernels for scaled_dot_product_attention (which may have a slower speed).",
                category=UserWarning,
                stacklevel=2,
            )
            global ALLOW_ALL_KERNELS
            ALLOW_ALL_KERNELS = True
            out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)

        out = self._recombine_heads(out)
        out = self.out_proj(out)

        return out
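Closing the file, RoPEAttention rotates queries and keys with axial rotary position encoding before scaled dot-product attention. With the default feat_sizes=(32, 32) the precomputed frequencies cover a 32x32 grid, so self-attention over 1024 tokens needs no recompute; when keys outnumber queries (cross-attention to memories), rope_k_repeat=True is required. A minimal sketch, with illustrative sizes and assuming the bundled sam2 package imports cleanly:

>>> rope_attn = RoPEAttention(embedding_dim=256, num_heads=8)   # default feat_sizes=(32, 32)
>>> q = torch.randn(1, 1024, 256)   # 32*32 grid tokens, B x N x C
>>> k, v = torch.randn(1, 1024, 256), torch.randn(1, 1024, 256)
>>> rope_attn(q=q, k=k, v=v).shape
torch.Size([1, 1024, 256])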