frontveg-0.1.dev1-py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
Files changed (44)
  1. frontveg/__init__.py +11 -0
  2. frontveg/_tests/__init__.py +0 -0
  3. frontveg/_tests/test_widget.py +66 -0
  4. frontveg/_version.py +21 -0
  5. frontveg/_widget.py +132 -0
  6. frontveg/napari.yaml +14 -0
  7. frontveg/utils.py +95 -0
  8. frontveg-0.1.dev1.dist-info/METADATA +143 -0
  9. frontveg-0.1.dev1.dist-info/RECORD +44 -0
  10. frontveg-0.1.dev1.dist-info/WHEEL +5 -0
  11. frontveg-0.1.dev1.dist-info/entry_points.txt +2 -0
  12. frontveg-0.1.dev1.dist-info/licenses/LICENSE +28 -0
  13. frontveg-0.1.dev1.dist-info/top_level.txt +2 -0
  14. sam2/__init__.py +11 -0
  15. sam2/automatic_mask_generator.py +454 -0
  16. sam2/build_sam.py +167 -0
  17. sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
  18. sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
  19. sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
  20. sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
  21. sam2/modeling/__init__.py +5 -0
  22. sam2/modeling/backbones/__init__.py +5 -0
  23. sam2/modeling/backbones/hieradet.py +317 -0
  24. sam2/modeling/backbones/image_encoder.py +134 -0
  25. sam2/modeling/backbones/utils.py +95 -0
  26. sam2/modeling/memory_attention.py +169 -0
  27. sam2/modeling/memory_encoder.py +181 -0
  28. sam2/modeling/position_encoding.py +221 -0
  29. sam2/modeling/sam/__init__.py +5 -0
  30. sam2/modeling/sam/mask_decoder.py +295 -0
  31. sam2/modeling/sam/prompt_encoder.py +182 -0
  32. sam2/modeling/sam/transformer.py +360 -0
  33. sam2/modeling/sam2_base.py +907 -0
  34. sam2/modeling/sam2_utils.py +323 -0
  35. sam2/sam2_hiera_b+.yaml +1 -0
  36. sam2/sam2_hiera_l.yaml +1 -0
  37. sam2/sam2_hiera_s.yaml +1 -0
  38. sam2/sam2_hiera_t.yaml +1 -0
  39. sam2/sam2_image_predictor.py +466 -0
  40. sam2/sam2_video_predictor.py +1172 -0
  41. sam2/utils/__init__.py +5 -0
  42. sam2/utils/amg.py +348 -0
  43. sam2/utils/misc.py +349 -0
  44. sam2/utils/transforms.py +118 -0
sam2/modeling/backbones/hieradet.py
@@ -0,0 +1,317 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+from functools import partial
+from typing import List, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from iopath.common.file_io import g_pathmgr
+
+from sam2.modeling.backbones.utils import (
+    PatchEmbed,
+    window_partition,
+    window_unpartition,
+)
+
+from sam2.modeling.sam2_utils import DropPath, MLP
+
+
+def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor:
+    if pool is None:
+        return x
+    # (B, H, W, C) -> (B, C, H, W)
+    x = x.permute(0, 3, 1, 2)
+    x = pool(x)
+    # (B, C, H', W') -> (B, H', W', C)
+    x = x.permute(0, 2, 3, 1)
+    if norm:
+        x = norm(x)
+
+    return x
+
+
+class MultiScaleAttention(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        dim_out: int,
+        num_heads: int,
+        q_pool: nn.Module = None,
+    ):
+        super().__init__()
+
+        self.dim = dim
+        self.dim_out = dim_out
+        self.num_heads = num_heads
+        self.q_pool = q_pool
+        self.qkv = nn.Linear(dim, dim_out * 3)
+        self.proj = nn.Linear(dim_out, dim_out)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        B, H, W, _ = x.shape
+        # qkv with shape (B, H * W, 3, nHead, C)
+        qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1)
+        # q, k, v with shape (B, H * W, nheads, C)
+        q, k, v = torch.unbind(qkv, 2)
+
+        # Q pooling (for downsample at stage changes)
+        if self.q_pool:
+            q = do_pool(q.reshape(B, H, W, -1), self.q_pool)
+            H, W = q.shape[1:3]  # downsampled shape
+            q = q.reshape(B, H * W, self.num_heads, -1)
+
+        # Torch's SDPA expects [B, nheads, H*W, C] so we transpose
+        x = F.scaled_dot_product_attention(
+            q.transpose(1, 2),
+            k.transpose(1, 2),
+            v.transpose(1, 2),
+        )
+        # Transpose back
+        x = x.transpose(1, 2)
+        x = x.reshape(B, H, W, -1)
+
+        x = self.proj(x)
+
+        return x
+
+
+class MultiScaleBlock(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        dim_out: int,
+        num_heads: int,
+        mlp_ratio: float = 4.0,
+        drop_path: float = 0.0,
+        norm_layer: Union[nn.Module, str] = "LayerNorm",
+        q_stride: Tuple[int, int] = None,
+        act_layer: nn.Module = nn.GELU,
+        window_size: int = 0,
+    ):
+        super().__init__()
+
+        if isinstance(norm_layer, str):
+            norm_layer = partial(getattr(nn, norm_layer), eps=1e-6)
+
+        self.dim = dim
+        self.dim_out = dim_out
+        self.norm1 = norm_layer(dim)
+
+        self.window_size = window_size
+
+        self.pool, self.q_stride = None, q_stride
+        if self.q_stride:
+            self.pool = nn.MaxPool2d(
+                kernel_size=q_stride, stride=q_stride, ceil_mode=False
+            )
+
+        self.attn = MultiScaleAttention(
+            dim,
+            dim_out,
+            num_heads=num_heads,
+            q_pool=self.pool,
+        )
+        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+        self.norm2 = norm_layer(dim_out)
+        self.mlp = MLP(
+            dim_out,
+            int(dim_out * mlp_ratio),
+            dim_out,
+            num_layers=2,
+            activation=act_layer,
+        )
+
+        if dim != dim_out:
+            self.proj = nn.Linear(dim, dim_out)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        shortcut = x  # B, H, W, C
+        x = self.norm1(x)
+
+        # Skip connection
+        if self.dim != self.dim_out:
+            shortcut = do_pool(self.proj(x), self.pool)
+
+        # Window partition
+        window_size = self.window_size
+        if window_size > 0:
+            H, W = x.shape[1], x.shape[2]
+            x, pad_hw = window_partition(x, window_size)
+
+        # Window Attention + Q Pooling (if stage change)
+        x = self.attn(x)
+        if self.q_stride:
+            # Shapes have changed due to Q pooling
+            window_size = self.window_size // self.q_stride[0]
+            H, W = shortcut.shape[1:3]
+
+            pad_h = (window_size - H % window_size) % window_size
+            pad_w = (window_size - W % window_size) % window_size
+            pad_hw = (H + pad_h, W + pad_w)
+
+        # Reverse window partition
+        if self.window_size > 0:
+            x = window_unpartition(x, window_size, pad_hw, (H, W))
+
+        x = shortcut + self.drop_path(x)
+        # MLP
+        x = x + self.drop_path(self.mlp(self.norm2(x)))
+        return x
+
+
+class Hiera(nn.Module):
+    """
+    Reference: https://arxiv.org/abs/2306.00989
+    """
+
+    def __init__(
+        self,
+        embed_dim: int = 96,  # initial embed dim
+        num_heads: int = 1,  # initial number of heads
+        drop_path_rate: float = 0.0,  # stochastic depth
+        q_pool: int = 3,  # number of q_pool stages
+        q_stride: Tuple[int, int] = (2, 2),  # downsample stride bet. stages
+        stages: Tuple[int, ...] = (2, 3, 16, 3),  # blocks per stage
+        dim_mul: float = 2.0,  # dim_mul factor at stage shift
+        head_mul: float = 2.0,  # head_mul factor at stage shift
+        window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14),
+        # window size per stage, when not using global att.
+        window_spec: Tuple[int, ...] = (
+            8,
+            4,
+            14,
+            7,
+        ),
+        # global attn in these blocks
+        global_att_blocks: Tuple[int, ...] = (
+            12,
+            16,
+            20,
+        ),
+        weights_path=None,
+        return_interm_layers=True,  # return feats from every stage
+    ):
+        super().__init__()
+
+        assert len(stages) == len(window_spec)
+        self.window_spec = window_spec
+
+        depth = sum(stages)
+        self.q_stride = q_stride
+        self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)]
+        assert 0 <= q_pool <= len(self.stage_ends[:-1])
+        self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool]
+        self.return_interm_layers = return_interm_layers
+
+        self.patch_embed = PatchEmbed(
+            embed_dim=embed_dim,
+        )
+        # Which blocks have global att?
+        self.global_att_blocks = global_att_blocks
+
+        # Windowed positional embedding (https://arxiv.org/abs/2311.05613)
+        self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size
+        self.pos_embed = nn.Parameter(
+            torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size)
+        )
+        self.pos_embed_window = nn.Parameter(
+            torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0])
+        )
+
+        dpr = [
+            x.item() for x in torch.linspace(0, drop_path_rate, depth)
+        ]  # stochastic depth decay rule
+
+        cur_stage = 1
+        self.blocks = nn.ModuleList()
+
+        for i in range(depth):
+            dim_out = embed_dim
+            # lags by a block, so first block of
+            # next stage uses an initial window size
+            # of previous stage and final window size of current stage
+            window_size = self.window_spec[cur_stage - 1]
+
+            if self.global_att_blocks is not None:
+                window_size = 0 if i in self.global_att_blocks else window_size
+
+            if i - 1 in self.stage_ends:
+                dim_out = int(embed_dim * dim_mul)
+                num_heads = int(num_heads * head_mul)
+                cur_stage += 1
+
+            block = MultiScaleBlock(
+                dim=embed_dim,
+                dim_out=dim_out,
+                num_heads=num_heads,
+                drop_path=dpr[i],
+                q_stride=self.q_stride if i in self.q_pool_blocks else None,
+                window_size=window_size,
+            )
+
+            embed_dim = dim_out
+            self.blocks.append(block)
+
+        self.channel_list = (
+            [self.blocks[i].dim_out for i in self.stage_ends[::-1]]
+            if return_interm_layers
+            else [self.blocks[-1].dim_out]
+        )
+
+        if weights_path is not None:
+            with g_pathmgr.open(weights_path, "rb") as f:
+                chkpt = torch.load(f, map_location="cpu")
+            logging.info("loading Hiera", self.load_state_dict(chkpt, strict=False))
+
+    def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:
+        h, w = hw
+        window_embed = self.pos_embed_window
+        pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
+        pos_embed = pos_embed + window_embed.tile(
+            [x // y for x, y in zip(pos_embed.shape, window_embed.shape)]
+        )
+        pos_embed = pos_embed.permute(0, 2, 3, 1)
+        return pos_embed
+
+    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
+        x = self.patch_embed(x)
+        # x: (B, H, W, C)
+
+        # Add pos embed
+        x = x + self._get_pos_embed(x.shape[1:3])
+
+        outputs = []
+        for i, blk in enumerate(self.blocks):
+            x = blk(x)
+            if (i == self.stage_ends[-1]) or (
+                i in self.stage_ends and self.return_interm_layers
+            ):
+                feats = x.permute(0, 3, 1, 2)
+                outputs.append(feats)
+
+        return outputs
+
+    def get_layer_id(self, layer_name):
+        # https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
+        num_layers = self.get_num_layers()
+
+        if layer_name.find("rel_pos") != -1:
+            return num_layers + 1
+        elif layer_name.find("pos_embed") != -1:
+            return 0
+        elif layer_name.find("patch_embed") != -1:
+            return 0
+        elif layer_name.find("blocks") != -1:
+            return int(layer_name.split("blocks")[1].split(".")[1]) + 1
+        else:
+            return num_layers + 1
+
+    def get_num_layers(self) -> int:
+        return len(self.blocks)
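
For orientation (a sketch, not part of the package): with the default `stages=(2, 3, 16, 3)`, the stage boundaries fall at blocks 1, 4, 20 and 23, Q-pooling halves the spatial resolution at blocks 2, 5 and 21, and the trunk returns one feature map per stage end at strides 4, 8, 16 and 32 with 96, 192, 384 and 768 channels. The snippet below instantiates `Hiera` with its defaults and checks those shapes; the 224x224 input size is an arbitrary choice for illustration.

```python
import torch
from sam2.modeling.backbones.hieradet import Hiera

trunk = Hiera()  # defaults: embed_dim=96, stages=(2, 3, 16, 3), q_stride=(2, 2)
print(trunk.channel_list)            # [768, 384, 192, 96] (coarsest stage listed first)

x = torch.randn(1, 3, 224, 224)      # (B, C, H, W) dummy image
with torch.no_grad():
    feats = trunk(x)                 # one (B, C_i, H_i, W_i) map per stage end

for f in feats:
    print(tuple(f.shape))
# (1, 96, 56, 56), (1, 192, 28, 28), (1, 384, 14, 14), (1, 768, 7, 7)
```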
sam2/modeling/backbones/image_encoder.py
@@ -0,0 +1,134 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import List, Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class ImageEncoder(nn.Module):
+    def __init__(
+        self,
+        trunk: nn.Module,
+        neck: nn.Module,
+        scalp: int = 0,
+    ):
+        super().__init__()
+        self.trunk = trunk
+        self.neck = neck
+        self.scalp = scalp
+        assert (
+            self.trunk.channel_list == self.neck.backbone_channel_list
+        ), f"Channel dims of trunk and neck do not match. Trunk: {self.trunk.channel_list}, neck: {self.neck.backbone_channel_list}"
+
+    def forward(self, sample: torch.Tensor):
+        # Forward through backbone
+        features, pos = self.neck(self.trunk(sample))
+        if self.scalp > 0:
+            # Discard the lowest resolution features
+            features, pos = features[: -self.scalp], pos[: -self.scalp]
+
+        src = features[-1]
+        output = {
+            "vision_features": src,
+            "vision_pos_enc": pos,
+            "backbone_fpn": features,
+        }
+        return output
+
+
+class FpnNeck(nn.Module):
+    """
+    A modified variant of Feature Pyramid Network (FPN) neck
+    (we remove output conv and also do bicubic interpolation similar to ViT
+    pos embed interpolation)
+    """
+
+    def __init__(
+        self,
+        position_encoding: nn.Module,
+        d_model: int,
+        backbone_channel_list: List[int],
+        kernel_size: int = 1,
+        stride: int = 1,
+        padding: int = 0,
+        fpn_interp_model: str = "bilinear",
+        fuse_type: str = "sum",
+        fpn_top_down_levels: Optional[List[int]] = None,
+    ):
+        """Initialize the neck
+        :param trunk: the backbone
+        :param position_encoding: the positional encoding to use
+        :param d_model: the dimension of the model
+        :param neck_norm: the normalization to use
+        """
+        super().__init__()
+        self.position_encoding = position_encoding
+        self.convs = nn.ModuleList()
+        self.backbone_channel_list = backbone_channel_list
+        self.d_model = d_model
+        for dim in backbone_channel_list:
+            current = nn.Sequential()
+            current.add_module(
+                "conv",
+                nn.Conv2d(
+                    in_channels=dim,
+                    out_channels=d_model,
+                    kernel_size=kernel_size,
+                    stride=stride,
+                    padding=padding,
+                ),
+            )
+
+            self.convs.append(current)
+        self.fpn_interp_model = fpn_interp_model
+        assert fuse_type in ["sum", "avg"]
+        self.fuse_type = fuse_type
+
+        # levels to have top-down features in its outputs
+        # e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3
+        # have top-down propagation, while outputs of level 0 and level 1 have only
+        # lateral features from the same backbone level.
+        if fpn_top_down_levels is None:
+            # default is to have top-down features on all levels
+            fpn_top_down_levels = range(len(self.convs))
+        self.fpn_top_down_levels = list(fpn_top_down_levels)
+
+    def forward(self, xs: List[torch.Tensor]):
+
+        out = [None] * len(self.convs)
+        pos = [None] * len(self.convs)
+        assert len(xs) == len(self.convs)
+        # fpn forward pass
+        # see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py
+        prev_features = None
+        # forward in top-down order (from low to high resolution)
+        n = len(self.convs) - 1
+        for i in range(n, -1, -1):
+            x = xs[i]
+            lateral_features = self.convs[n - i](x)
+            if i in self.fpn_top_down_levels and prev_features is not None:
+                top_down_features = F.interpolate(
+                    prev_features.to(dtype=torch.float32),
+                    scale_factor=2.0,
+                    mode=self.fpn_interp_model,
+                    align_corners=(
+                        None if self.fpn_interp_model == "nearest" else False
+                    ),
+                    antialias=False,
+                )
+                prev_features = lateral_features + top_down_features
+                if self.fuse_type == "avg":
+                    prev_features /= 2
+            else:
+                prev_features = lateral_features
+            x_out = prev_features
+            out[i] = x_out
+            pos[i] = self.position_encoding(x_out).to(x_out.dtype)
+
+        return out, pos
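
Again for orientation (a sketch, not package code): `ImageEncoder` runs the trunk, hands the per-stage maps to `FpnNeck`, and optionally drops the coarsest `scalp` levels, so `"vision_features"` ends up being the coarsest level that is kept. The wiring below mirrors the shipped `sam2_hiera_*.yaml` configs; `PositionEmbeddingSine` comes from `sam2/modeling/position_encoding.py` (listed above but not shown here), and the exact arguments passed to it are an assumption based on those configs.

```python
import torch
from sam2.modeling.backbones.hieradet import Hiera
from sam2.modeling.backbones.image_encoder import FpnNeck, ImageEncoder
from sam2.modeling.position_encoding import PositionEmbeddingSine  # assumed constructor args below

trunk = Hiera()  # channel_list == [768, 384, 192, 96] with the defaults
neck = FpnNeck(
    position_encoding=PositionEmbeddingSine(num_pos_feats=256),  # assumption from the configs
    d_model=256,
    backbone_channel_list=trunk.channel_list,  # must equal the trunk's list (asserted in ImageEncoder)
)
encoder = ImageEncoder(trunk=trunk, neck=neck, scalp=1)  # scalp=1 drops the stride-32 level

with torch.no_grad():
    out = encoder(torch.randn(1, 3, 224, 224))

print(tuple(out["vision_features"].shape))              # (1, 256, 14, 14): the stride-16 map
print([tuple(f.shape) for f in out["backbone_fpn"]])    # strides 4, 8, 16, all with d_model=256 channels
```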
sam2/modeling/backbones/utils.py
@@ -0,0 +1,95 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Some utilities for backbones, in particular for windowing"""
+
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def window_partition(x, window_size):
+    """
+    Partition into non-overlapping windows with padding if needed.
+    Args:
+        x (tensor): input tokens with [B, H, W, C].
+        window_size (int): window size.
+    Returns:
+        windows: windows after partition with [B * num_windows, window_size, window_size, C].
+        (Hp, Wp): padded height and width before partition
+    """
+    B, H, W, C = x.shape
+
+    pad_h = (window_size - H % window_size) % window_size
+    pad_w = (window_size - W % window_size) % window_size
+    if pad_h > 0 or pad_w > 0:
+        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
+    Hp, Wp = H + pad_h, W + pad_w
+
+    x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
+    windows = (
+        x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+    )
+    return windows, (Hp, Wp)
+
+
+def window_unpartition(windows, window_size, pad_hw, hw):
+    """
+    Window unpartition into original sequences and removing padding.
+    Args:
+        x (tensor): input tokens with [B * num_windows, window_size, window_size, C].
+        window_size (int): window size.
+        pad_hw (Tuple): padded height and width (Hp, Wp).
+        hw (Tuple): original height and width (H, W) before padding.
+    Returns:
+        x: unpartitioned sequences with [B, H, W, C].
+    """
+    Hp, Wp = pad_hw
+    H, W = hw
+    B = windows.shape[0] // (Hp * Wp // window_size // window_size)
+    x = windows.view(
+        B, Hp // window_size, Wp // window_size, window_size, window_size, -1
+    )
+    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
+
+    if Hp > H or Wp > W:
+        x = x[:, :H, :W, :].contiguous()
+    return x
+
+
+class PatchEmbed(nn.Module):
+    """
+    Image to Patch Embedding.
+    """
+
+    def __init__(
+        self,
+        kernel_size: Tuple[int, ...] = (7, 7),
+        stride: Tuple[int, ...] = (4, 4),
+        padding: Tuple[int, ...] = (3, 3),
+        in_chans: int = 3,
+        embed_dim: int = 768,
+    ):
+        """
+        Args:
+            kernel_size (Tuple): kernel size of the projection layer.
+            stride (Tuple): stride of the projection layer.
+            padding (Tuple): padding size of the projection layer.
+            in_chans (int): Number of input image channels.
+            embed_dim (int): embed_dim (int): Patch embedding dimension.
+        """
+        super().__init__()
+        self.proj = nn.Conv2d(
+            in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.proj(x)
+        # B C H W -> B H W C
+        x = x.permute(0, 2, 3, 1)
+        return x
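
A small round-trip sketch (not part of the package) showing how `window_partition` pads and tiles a feature map and how `window_unpartition` undoes it; the sizes below are arbitrary and chosen so that padding is actually needed.

```python
import torch
from sam2.modeling.backbones.utils import window_partition, window_unpartition

x = torch.randn(2, 50, 50, 96)                 # (B, H, W, C); 50 is not a multiple of 8
windows, pad_hw = window_partition(x, window_size=8)
print(windows.shape, pad_hw)                   # torch.Size([98, 8, 8, 96]) (56, 56)

y = window_unpartition(windows, 8, pad_hw, (50, 50))
assert torch.equal(x, y)                       # padding is cropped off, values are unchanged
```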