frontveg 0.1.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. frontveg/__init__.py +11 -0
  2. frontveg/_tests/__init__.py +0 -0
  3. frontveg/_tests/test_widget.py +66 -0
  4. frontveg/_version.py +21 -0
  5. frontveg/_widget.py +132 -0
  6. frontveg/napari.yaml +14 -0
  7. frontveg/utils.py +95 -0
  8. frontveg-0.1.dev1.dist-info/METADATA +143 -0
  9. frontveg-0.1.dev1.dist-info/RECORD +44 -0
  10. frontveg-0.1.dev1.dist-info/WHEEL +5 -0
  11. frontveg-0.1.dev1.dist-info/entry_points.txt +2 -0
  12. frontveg-0.1.dev1.dist-info/licenses/LICENSE +28 -0
  13. frontveg-0.1.dev1.dist-info/top_level.txt +2 -0
  14. sam2/__init__.py +11 -0
  15. sam2/automatic_mask_generator.py +454 -0
  16. sam2/build_sam.py +167 -0
  17. sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
  18. sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
  19. sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
  20. sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
  21. sam2/modeling/__init__.py +5 -0
  22. sam2/modeling/backbones/__init__.py +5 -0
  23. sam2/modeling/backbones/hieradet.py +317 -0
  24. sam2/modeling/backbones/image_encoder.py +134 -0
  25. sam2/modeling/backbones/utils.py +95 -0
  26. sam2/modeling/memory_attention.py +169 -0
  27. sam2/modeling/memory_encoder.py +181 -0
  28. sam2/modeling/position_encoding.py +221 -0
  29. sam2/modeling/sam/__init__.py +5 -0
  30. sam2/modeling/sam/mask_decoder.py +295 -0
  31. sam2/modeling/sam/prompt_encoder.py +182 -0
  32. sam2/modeling/sam/transformer.py +360 -0
  33. sam2/modeling/sam2_base.py +907 -0
  34. sam2/modeling/sam2_utils.py +323 -0
  35. sam2/sam2_hiera_b+.yaml +1 -0
  36. sam2/sam2_hiera_l.yaml +1 -0
  37. sam2/sam2_hiera_s.yaml +1 -0
  38. sam2/sam2_hiera_t.yaml +1 -0
  39. sam2/sam2_image_predictor.py +466 -0
  40. sam2/sam2_video_predictor.py +1172 -0
  41. sam2/utils/__init__.py +5 -0
  42. sam2/utils/amg.py +348 -0
  43. sam2/utils/misc.py +349 -0
  44. sam2/utils/transforms.py +118 -0
sam2/modeling/sam/transformer.py (new file)
@@ -0,0 +1,360 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import contextlib
+import math
+import warnings
+from functools import partial
+from typing import Tuple, Type
+
+import torch
+import torch.nn.functional as F
+from torch import nn, Tensor
+
+from sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis
+from sam2.modeling.sam2_utils import MLP
+from sam2.utils.misc import get_sdpa_settings
+
+warnings.simplefilter(action="ignore", category=FutureWarning)
+# Check whether Flash Attention is available (and use it by default)
+OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings()
+# A fallback setting to allow all available kernels if Flash Attention fails
+ALLOW_ALL_KERNELS = False
+
+
+def sdp_kernel_context(dropout_p):
+    """
+    Get the context for the attention scaled dot-product kernel. We use Flash Attention
+    by default, but fall back to all available kernels if Flash Attention fails.
+    """
+    if ALLOW_ALL_KERNELS:
+        return contextlib.nullcontext()
+
+    return torch.backends.cuda.sdp_kernel(
+        enable_flash=USE_FLASH_ATTN,
+        # if Flash attention kernel is off, then math kernel needs to be enabled
+        enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON,
+        enable_mem_efficient=OLD_GPU,
+    )
+
+
+class TwoWayTransformer(nn.Module):
+    def __init__(
+        self,
+        depth: int,
+        embedding_dim: int,
+        num_heads: int,
+        mlp_dim: int,
+        activation: Type[nn.Module] = nn.ReLU,
+        attention_downsample_rate: int = 2,
+    ) -> None:
+        """
+        A transformer decoder that attends to an input image using
+        queries whose positional embedding is supplied.
+
+        Args:
+          depth (int): number of layers in the transformer
+          embedding_dim (int): the channel dimension for the input embeddings
+          num_heads (int): the number of heads for multihead attention. Must
+            divide embedding_dim
+          mlp_dim (int): the channel dimension internal to the MLP block
+          activation (nn.Module): the activation to use in the MLP block
+        """
+        super().__init__()
+        self.depth = depth
+        self.embedding_dim = embedding_dim
+        self.num_heads = num_heads
+        self.mlp_dim = mlp_dim
+        self.layers = nn.ModuleList()
+
+        for i in range(depth):
+            self.layers.append(
+                TwoWayAttentionBlock(
+                    embedding_dim=embedding_dim,
+                    num_heads=num_heads,
+                    mlp_dim=mlp_dim,
+                    activation=activation,
+                    attention_downsample_rate=attention_downsample_rate,
+                    skip_first_layer_pe=(i == 0),
+                )
+            )
+
+        self.final_attn_token_to_image = Attention(
+            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
+        )
+        self.norm_final_attn = nn.LayerNorm(embedding_dim)
+
+    def forward(
+        self,
+        image_embedding: Tensor,
+        image_pe: Tensor,
+        point_embedding: Tensor,
+    ) -> Tuple[Tensor, Tensor]:
+        """
+        Args:
+          image_embedding (torch.Tensor): image to attend to. Should be shape
+            B x embedding_dim x h x w for any h and w.
+          image_pe (torch.Tensor): the positional encoding to add to the image. Must
+            have the same shape as image_embedding.
+          point_embedding (torch.Tensor): the embedding to add to the query points.
+            Must have shape B x N_points x embedding_dim for any N_points.
+
+        Returns:
+          torch.Tensor: the processed point_embedding
+          torch.Tensor: the processed image_embedding
+        """
+        # BxCxHxW -> BxHWxC == B x N_image_tokens x C
+        bs, c, h, w = image_embedding.shape
+        image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
+        image_pe = image_pe.flatten(2).permute(0, 2, 1)
+
+        # Prepare queries
+        queries = point_embedding
+        keys = image_embedding
+
+        # Apply transformer blocks and final layernorm
+        for layer in self.layers:
+            queries, keys = layer(
+                queries=queries,
+                keys=keys,
+                query_pe=point_embedding,
+                key_pe=image_pe,
+            )
+
+        # Apply the final attention layer from the points to the image
+        q = queries + point_embedding
+        k = keys + image_pe
+        attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
+        queries = queries + attn_out
+        queries = self.norm_final_attn(queries)
+
+        return queries, keys
+
+
+class TwoWayAttentionBlock(nn.Module):
+    def __init__(
+        self,
+        embedding_dim: int,
+        num_heads: int,
+        mlp_dim: int = 2048,
+        activation: Type[nn.Module] = nn.ReLU,
+        attention_downsample_rate: int = 2,
+        skip_first_layer_pe: bool = False,
+    ) -> None:
+        """
+        A transformer block with four layers: (1) self-attention of sparse
+        inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
+        block on sparse inputs, and (4) cross attention of dense inputs to sparse
+        inputs.
+
+        Arguments:
+          embedding_dim (int): the channel dimension of the embeddings
+          num_heads (int): the number of heads in the attention layers
+          mlp_dim (int): the hidden dimension of the mlp block
+          activation (nn.Module): the activation of the mlp block
+          skip_first_layer_pe (bool): skip the PE on the first layer
+        """
+        super().__init__()
+        self.self_attn = Attention(embedding_dim, num_heads)
+        self.norm1 = nn.LayerNorm(embedding_dim)
+
+        self.cross_attn_token_to_image = Attention(
+            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
+        )
+        self.norm2 = nn.LayerNorm(embedding_dim)
+
+        self.mlp = MLP(
+            embedding_dim, mlp_dim, embedding_dim, num_layers=2, activation=activation
+        )
+        self.norm3 = nn.LayerNorm(embedding_dim)
+
+        self.norm4 = nn.LayerNorm(embedding_dim)
+        self.cross_attn_image_to_token = Attention(
+            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
+        )
+
+        self.skip_first_layer_pe = skip_first_layer_pe
+
+    def forward(
+        self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
+    ) -> Tuple[Tensor, Tensor]:
+        # Self attention block
+        if self.skip_first_layer_pe:
+            queries = self.self_attn(q=queries, k=queries, v=queries)
+        else:
+            q = queries + query_pe
+            attn_out = self.self_attn(q=q, k=q, v=queries)
+            queries = queries + attn_out
+        queries = self.norm1(queries)
+
+        # Cross attention block, tokens attending to image embedding
+        q = queries + query_pe
+        k = keys + key_pe
+        attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
+        queries = queries + attn_out
+        queries = self.norm2(queries)
+
+        # MLP block
+        mlp_out = self.mlp(queries)
+        queries = queries + mlp_out
+        queries = self.norm3(queries)
+
+        # Cross attention block, image embedding attending to tokens
+        q = queries + query_pe
+        k = keys + key_pe
+        attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
+        keys = keys + attn_out
+        keys = self.norm4(keys)
+
+        return queries, keys
+
+
+class Attention(nn.Module):
+    """
+    An attention layer that allows for downscaling the size of the embedding
+    after projection to queries, keys, and values.
+    """
+
+    def __init__(
+        self,
+        embedding_dim: int,
+        num_heads: int,
+        downsample_rate: int = 1,
+        dropout: float = 0.0,
+        kv_in_dim: int = None,
+    ) -> None:
+        super().__init__()
+        self.embedding_dim = embedding_dim
+        self.kv_in_dim = kv_in_dim if kv_in_dim is not None else embedding_dim
+        self.internal_dim = embedding_dim // downsample_rate
+        self.num_heads = num_heads
+        assert (
+            self.internal_dim % num_heads == 0
+        ), "num_heads must divide embedding_dim."
+
+        self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
+        self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
+        self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
+        self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
+
+        self.dropout_p = dropout
+
+    def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
+        b, n, c = x.shape
+        x = x.reshape(b, n, num_heads, c // num_heads)
+        return x.transpose(1, 2)  # B x N_heads x N_tokens x C_per_head
+
+    def _recombine_heads(self, x: Tensor) -> Tensor:
+        b, n_heads, n_tokens, c_per_head = x.shape
+        x = x.transpose(1, 2)
+        return x.reshape(b, n_tokens, n_heads * c_per_head)  # B x N_tokens x C
+
+    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
+        # Input projections
+        q = self.q_proj(q)
+        k = self.k_proj(k)
+        v = self.v_proj(v)
+
+        # Separate into heads
+        q = self._separate_heads(q, self.num_heads)
+        k = self._separate_heads(k, self.num_heads)
+        v = self._separate_heads(v, self.num_heads)
+
+        dropout_p = self.dropout_p if self.training else 0.0
+        # Attention
+        try:
+            with sdp_kernel_context(dropout_p):
+                out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
+        except Exception as e:
+            # Fall back to all kernels if the Flash attention kernel fails
+            warnings.warn(
+                f"Flash Attention kernel failed due to: {e}\nFalling back to all available "
+                f"kernels for scaled_dot_product_attention (which may have a slower speed).",
+                category=UserWarning,
+                stacklevel=2,
+            )
+            global ALLOW_ALL_KERNELS
+            ALLOW_ALL_KERNELS = True
+            out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
+
+        out = self._recombine_heads(out)
+        out = self.out_proj(out)
+
+        return out
+
+
+class RoPEAttention(Attention):
+    """Attention with rotary position encoding."""
+
+    def __init__(
+        self,
+        *args,
+        rope_theta=10000.0,
+        # whether to repeat q rope to match k length
+        # this is needed for cross-attention to memories
+        rope_k_repeat=False,
+        feat_sizes=(32, 32),  # [w, h] for stride 16 feats at 512 resolution
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+
+        self.compute_cis = partial(
+            compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta
+        )
+        freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1])
+        self.freqs_cis = freqs_cis
+        self.rope_k_repeat = rope_k_repeat
+
+    def forward(
+        self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int = 0
+    ) -> Tensor:
+        # Input projections
+        q = self.q_proj(q)
+        k = self.k_proj(k)
+        v = self.v_proj(v)
+
+        # Separate into heads
+        q = self._separate_heads(q, self.num_heads)
+        k = self._separate_heads(k, self.num_heads)
+        v = self._separate_heads(v, self.num_heads)
+
+        # Apply rotary position encoding
+        w = h = math.sqrt(q.shape[-2])
+        self.freqs_cis = self.freqs_cis.to(q.device)
+        if self.freqs_cis.shape[0] != q.shape[-2]:
+            self.freqs_cis = self.compute_cis(end_x=w, end_y=h).to(q.device)
+        if q.shape[-2] != k.shape[-2]:
+            assert self.rope_k_repeat
+
+        num_k_rope = k.size(-2) - num_k_exclude_rope
+        q, k[:, :, :num_k_rope] = apply_rotary_enc(
+            q,
+            k[:, :, :num_k_rope],
+            freqs_cis=self.freqs_cis,
+            repeat_freqs_k=self.rope_k_repeat,
+        )
+
+        dropout_p = self.dropout_p if self.training else 0.0
+        # Attention
+        try:
+            with sdp_kernel_context(dropout_p):
+                out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
+        except Exception as e:
+            # Fall back to all kernels if the Flash attention kernel fails
+            warnings.warn(
+                f"Flash Attention kernel failed due to: {e}\nFalling back to all available "
+                f"kernels for scaled_dot_product_attention (which may have a slower speed).",
+                category=UserWarning,
+                stacklevel=2,
+            )
+            global ALLOW_ALL_KERNELS
+            ALLOW_ALL_KERNELS = True
+            out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
+
+        out = self._recombine_heads(out)
+        out = self.out_proj(out)
+
+        return out
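
For orientation, below is a minimal usage sketch of the TwoWayTransformer defined in this file. It is not part of the package: the hyperparameters (depth=2, embedding_dim=256, num_heads=8, mlp_dim=2048) and the tensor shapes are illustrative assumptions, not values read from the frontveg or sam2 configs. It assumes torch and this wheel (which ships the top-level sam2 package) are installed.

# Illustrative only: shapes and hyperparameters are assumptions chosen so that
# num_heads divides internal_dim; they are not taken from this package's configs.
import torch

from sam2.modeling.sam.transformer import TwoWayTransformer

transformer = TwoWayTransformer(
    depth=2,
    embedding_dim=256,
    num_heads=8,
    mlp_dim=2048,
)
transformer.eval()

B, C, H, W = 1, 256, 64, 64                  # batch, channels, feature map size
image_embedding = torch.randn(B, C, H, W)    # B x embedding_dim x h x w
image_pe = torch.randn(B, C, H, W)           # same shape as image_embedding
point_embedding = torch.randn(B, 5, C)       # B x N_points x embedding_dim

with torch.no_grad():
    queries, keys = transformer(image_embedding, image_pe, point_embedding)

print(queries.shape)  # torch.Size([1, 5, 256])    -- processed point embeddings
print(keys.shape)     # torch.Size([1, 4096, 256]) -- processed image embeddings

If the restricted kernel selection in sdp_kernel_context fails for the given device or dtype, Attention.forward emits the UserWarning defined above, sets ALLOW_ALL_KERNELS to True, and retries scaled_dot_product_attention with all kernels enabled.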