frontveg 0.1.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- frontveg/__init__.py +11 -0
- frontveg/_tests/__init__.py +0 -0
- frontveg/_tests/test_widget.py +66 -0
- frontveg/_version.py +21 -0
- frontveg/_widget.py +132 -0
- frontveg/napari.yaml +14 -0
- frontveg/utils.py +95 -0
- frontveg-0.1.dev1.dist-info/METADATA +143 -0
- frontveg-0.1.dev1.dist-info/RECORD +44 -0
- frontveg-0.1.dev1.dist-info/WHEEL +5 -0
- frontveg-0.1.dev1.dist-info/entry_points.txt +2 -0
- frontveg-0.1.dev1.dist-info/licenses/LICENSE +28 -0
- frontveg-0.1.dev1.dist-info/top_level.txt +2 -0
- sam2/__init__.py +11 -0
- sam2/automatic_mask_generator.py +454 -0
- sam2/build_sam.py +167 -0
- sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
- sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
- sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
- sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
- sam2/modeling/__init__.py +5 -0
- sam2/modeling/backbones/__init__.py +5 -0
- sam2/modeling/backbones/hieradet.py +317 -0
- sam2/modeling/backbones/image_encoder.py +134 -0
- sam2/modeling/backbones/utils.py +95 -0
- sam2/modeling/memory_attention.py +169 -0
- sam2/modeling/memory_encoder.py +181 -0
- sam2/modeling/position_encoding.py +221 -0
- sam2/modeling/sam/__init__.py +5 -0
- sam2/modeling/sam/mask_decoder.py +295 -0
- sam2/modeling/sam/prompt_encoder.py +182 -0
- sam2/modeling/sam/transformer.py +360 -0
- sam2/modeling/sam2_base.py +907 -0
- sam2/modeling/sam2_utils.py +323 -0
- sam2/sam2_hiera_b+.yaml +1 -0
- sam2/sam2_hiera_l.yaml +1 -0
- sam2/sam2_hiera_s.yaml +1 -0
- sam2/sam2_hiera_t.yaml +1 -0
- sam2/sam2_image_predictor.py +466 -0
- sam2/sam2_video_predictor.py +1172 -0
- sam2/utils/__init__.py +5 -0
- sam2/utils/amg.py +348 -0
- sam2/utils/misc.py +349 -0
- sam2/utils/transforms.py +118 -0
sam2/modeling/backbones/hieradet.py
@@ -0,0 +1,317 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+from functools import partial
+from typing import List, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from iopath.common.file_io import g_pathmgr
+
+from sam2.modeling.backbones.utils import (
+    PatchEmbed,
+    window_partition,
+    window_unpartition,
+)
+
+from sam2.modeling.sam2_utils import DropPath, MLP
+
+
+def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor:
+    if pool is None:
+        return x
+    # (B, H, W, C) -> (B, C, H, W)
+    x = x.permute(0, 3, 1, 2)
+    x = pool(x)
+    # (B, C, H', W') -> (B, H', W', C)
+    x = x.permute(0, 2, 3, 1)
+    if norm:
+        x = norm(x)
+
+    return x
+
+
+class MultiScaleAttention(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        dim_out: int,
+        num_heads: int,
+        q_pool: nn.Module = None,
+    ):
+        super().__init__()
+
+        self.dim = dim
+        self.dim_out = dim_out
+        self.num_heads = num_heads
+        self.q_pool = q_pool
+        self.qkv = nn.Linear(dim, dim_out * 3)
+        self.proj = nn.Linear(dim_out, dim_out)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        B, H, W, _ = x.shape
+        # qkv with shape (B, H * W, 3, nHead, C)
+        qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1)
+        # q, k, v with shape (B, H * W, nheads, C)
+        q, k, v = torch.unbind(qkv, 2)
+
+        # Q pooling (for downsample at stage changes)
+        if self.q_pool:
+            q = do_pool(q.reshape(B, H, W, -1), self.q_pool)
+            H, W = q.shape[1:3]  # downsampled shape
+            q = q.reshape(B, H * W, self.num_heads, -1)
+
+        # Torch's SDPA expects [B, nheads, H*W, C] so we transpose
+        x = F.scaled_dot_product_attention(
+            q.transpose(1, 2),
+            k.transpose(1, 2),
+            v.transpose(1, 2),
+        )
+        # Transpose back
+        x = x.transpose(1, 2)
+        x = x.reshape(B, H, W, -1)
+
+        x = self.proj(x)
+
+        return x
+
+
+class MultiScaleBlock(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        dim_out: int,
+        num_heads: int,
+        mlp_ratio: float = 4.0,
+        drop_path: float = 0.0,
+        norm_layer: Union[nn.Module, str] = "LayerNorm",
+        q_stride: Tuple[int, int] = None,
+        act_layer: nn.Module = nn.GELU,
+        window_size: int = 0,
+    ):
+        super().__init__()
+
+        if isinstance(norm_layer, str):
+            norm_layer = partial(getattr(nn, norm_layer), eps=1e-6)
+
+        self.dim = dim
+        self.dim_out = dim_out
+        self.norm1 = norm_layer(dim)
+
+        self.window_size = window_size
+
+        self.pool, self.q_stride = None, q_stride
+        if self.q_stride:
+            self.pool = nn.MaxPool2d(
+                kernel_size=q_stride, stride=q_stride, ceil_mode=False
+            )
+
+        self.attn = MultiScaleAttention(
+            dim,
+            dim_out,
+            num_heads=num_heads,
+            q_pool=self.pool,
+        )
+        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+        self.norm2 = norm_layer(dim_out)
+        self.mlp = MLP(
+            dim_out,
+            int(dim_out * mlp_ratio),
+            dim_out,
+            num_layers=2,
+            activation=act_layer,
+        )
+
+        if dim != dim_out:
+            self.proj = nn.Linear(dim, dim_out)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        shortcut = x  # B, H, W, C
+        x = self.norm1(x)
+
+        # Skip connection
+        if self.dim != self.dim_out:
+            shortcut = do_pool(self.proj(x), self.pool)
+
+        # Window partition
+        window_size = self.window_size
+        if window_size > 0:
+            H, W = x.shape[1], x.shape[2]
+            x, pad_hw = window_partition(x, window_size)
+
+        # Window Attention + Q Pooling (if stage change)
+        x = self.attn(x)
+        if self.q_stride:
+            # Shapes have changed due to Q pooling
+            window_size = self.window_size // self.q_stride[0]
+            H, W = shortcut.shape[1:3]
+
+            pad_h = (window_size - H % window_size) % window_size
+            pad_w = (window_size - W % window_size) % window_size
+            pad_hw = (H + pad_h, W + pad_w)
+
+        # Reverse window partition
+        if self.window_size > 0:
+            x = window_unpartition(x, window_size, pad_hw, (H, W))
+
+        x = shortcut + self.drop_path(x)
+        # MLP
+        x = x + self.drop_path(self.mlp(self.norm2(x)))
+        return x
+
+
+class Hiera(nn.Module):
+    """
+    Reference: https://arxiv.org/abs/2306.00989
+    """
+
+    def __init__(
+        self,
+        embed_dim: int = 96,  # initial embed dim
+        num_heads: int = 1,  # initial number of heads
+        drop_path_rate: float = 0.0,  # stochastic depth
+        q_pool: int = 3,  # number of q_pool stages
+        q_stride: Tuple[int, int] = (2, 2),  # downsample stride bet. stages
+        stages: Tuple[int, ...] = (2, 3, 16, 3),  # blocks per stage
+        dim_mul: float = 2.0,  # dim_mul factor at stage shift
+        head_mul: float = 2.0,  # head_mul factor at stage shift
+        window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14),
+        # window size per stage, when not using global att.
+        window_spec: Tuple[int, ...] = (
+            8,
+            4,
+            14,
+            7,
+        ),
+        # global attn in these blocks
+        global_att_blocks: Tuple[int, ...] = (
+            12,
+            16,
+            20,
+        ),
+        weights_path=None,
+        return_interm_layers=True,  # return feats from every stage
+    ):
+        super().__init__()
+
+        assert len(stages) == len(window_spec)
+        self.window_spec = window_spec
+
+        depth = sum(stages)
+        self.q_stride = q_stride
+        self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)]
+        assert 0 <= q_pool <= len(self.stage_ends[:-1])
+        self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool]
+        self.return_interm_layers = return_interm_layers
+
+        self.patch_embed = PatchEmbed(
+            embed_dim=embed_dim,
+        )
+        # Which blocks have global att?
+        self.global_att_blocks = global_att_blocks
+
+        # Windowed positional embedding (https://arxiv.org/abs/2311.05613)
+        self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size
+        self.pos_embed = nn.Parameter(
+            torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size)
+        )
+        self.pos_embed_window = nn.Parameter(
+            torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0])
+        )
+
+        dpr = [
+            x.item() for x in torch.linspace(0, drop_path_rate, depth)
+        ]  # stochastic depth decay rule
+
+        cur_stage = 1
+        self.blocks = nn.ModuleList()
+
+        for i in range(depth):
+            dim_out = embed_dim
+            # lags by a block, so first block of
+            # next stage uses an initial window size
+            # of previous stage and final window size of current stage
+            window_size = self.window_spec[cur_stage - 1]
+
+            if self.global_att_blocks is not None:
+                window_size = 0 if i in self.global_att_blocks else window_size
+
+            if i - 1 in self.stage_ends:
+                dim_out = int(embed_dim * dim_mul)
+                num_heads = int(num_heads * head_mul)
+                cur_stage += 1
+
+            block = MultiScaleBlock(
+                dim=embed_dim,
+                dim_out=dim_out,
+                num_heads=num_heads,
+                drop_path=dpr[i],
+                q_stride=self.q_stride if i in self.q_pool_blocks else None,
+                window_size=window_size,
+            )
+
+            embed_dim = dim_out
+            self.blocks.append(block)
+
+        self.channel_list = (
+            [self.blocks[i].dim_out for i in self.stage_ends[::-1]]
+            if return_interm_layers
+            else [self.blocks[-1].dim_out]
+        )
+
+        if weights_path is not None:
+            with g_pathmgr.open(weights_path, "rb") as f:
+                chkpt = torch.load(f, map_location="cpu")
+            logging.info("loading Hiera", self.load_state_dict(chkpt, strict=False))
+
+    def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:
+        h, w = hw
+        window_embed = self.pos_embed_window
+        pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
+        pos_embed = pos_embed + window_embed.tile(
+            [x // y for x, y in zip(pos_embed.shape, window_embed.shape)]
+        )
+        pos_embed = pos_embed.permute(0, 2, 3, 1)
+        return pos_embed
+
+    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
+        x = self.patch_embed(x)
+        # x: (B, H, W, C)
+
+        # Add pos embed
+        x = x + self._get_pos_embed(x.shape[1:3])
+
+        outputs = []
+        for i, blk in enumerate(self.blocks):
+            x = blk(x)
+            if (i == self.stage_ends[-1]) or (
+                i in self.stage_ends and self.return_interm_layers
+            ):
+                feats = x.permute(0, 3, 1, 2)
+                outputs.append(feats)
+
+        return outputs
+
+    def get_layer_id(self, layer_name):
+        # https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
+        num_layers = self.get_num_layers()
+
+        if layer_name.find("rel_pos") != -1:
+            return num_layers + 1
+        elif layer_name.find("pos_embed") != -1:
+            return 0
+        elif layer_name.find("patch_embed") != -1:
+            return 0
+        elif layer_name.find("blocks") != -1:
+            return int(layer_name.split("blocks")[1].split(".")[1]) + 1
+        else:
+            return num_layers + 1
+
+    def get_num_layers(self) -> int:
+        return len(self.blocks)
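The hieradet.py hunk above defines the Hiera trunk, which turns an image into one feature map per stage end. The sketch below is illustrative only (it is not part of the wheel): it runs the trunk with the default arguments shown in the signature, assuming torch and iopath are installed and the bundled sam2 package is importable.

# Illustrative sketch only: exercises the Hiera trunk above with its defaults.
import torch

from sam2.modeling.backbones.hieradet import Hiera

trunk = Hiera()  # defaults: embed_dim=96, stages=(2, 3, 16, 3), q_stride=(2, 2)
image = torch.randn(1, 3, 224, 224)  # (B, C, H, W); PatchEmbed stride is 4

with torch.no_grad():
    feats = trunk(image)

# With return_interm_layers=True this yields one (B, C, H, W) tensor per stage
# end, finest first: strides 4, 8, 16, 32 with channels 96, 192, 384, 768 here.
# trunk.channel_list reports the same channels in reverse (coarsest first).
for f in feats:
    print(tuple(f.shape))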
sam2/modeling/backbones/image_encoder.py
@@ -0,0 +1,134 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import List, Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class ImageEncoder(nn.Module):
+    def __init__(
+        self,
+        trunk: nn.Module,
+        neck: nn.Module,
+        scalp: int = 0,
+    ):
+        super().__init__()
+        self.trunk = trunk
+        self.neck = neck
+        self.scalp = scalp
+        assert (
+            self.trunk.channel_list == self.neck.backbone_channel_list
+        ), f"Channel dims of trunk and neck do not match. Trunk: {self.trunk.channel_list}, neck: {self.neck.backbone_channel_list}"
+
+    def forward(self, sample: torch.Tensor):
+        # Forward through backbone
+        features, pos = self.neck(self.trunk(sample))
+        if self.scalp > 0:
+            # Discard the lowest resolution features
+            features, pos = features[: -self.scalp], pos[: -self.scalp]
+
+        src = features[-1]
+        output = {
+            "vision_features": src,
+            "vision_pos_enc": pos,
+            "backbone_fpn": features,
+        }
+        return output
+
+
+class FpnNeck(nn.Module):
+    """
+    A modified variant of Feature Pyramid Network (FPN) neck
+    (we remove output conv and also do bicubic interpolation similar to ViT
+    pos embed interpolation)
+    """
+
+    def __init__(
+        self,
+        position_encoding: nn.Module,
+        d_model: int,
+        backbone_channel_list: List[int],
+        kernel_size: int = 1,
+        stride: int = 1,
+        padding: int = 0,
+        fpn_interp_model: str = "bilinear",
+        fuse_type: str = "sum",
+        fpn_top_down_levels: Optional[List[int]] = None,
+    ):
+        """Initialize the neck
+        :param trunk: the backbone
+        :param position_encoding: the positional encoding to use
+        :param d_model: the dimension of the model
+        :param neck_norm: the normalization to use
+        """
+        super().__init__()
+        self.position_encoding = position_encoding
+        self.convs = nn.ModuleList()
+        self.backbone_channel_list = backbone_channel_list
+        self.d_model = d_model
+        for dim in backbone_channel_list:
+            current = nn.Sequential()
+            current.add_module(
+                "conv",
+                nn.Conv2d(
+                    in_channels=dim,
+                    out_channels=d_model,
+                    kernel_size=kernel_size,
+                    stride=stride,
+                    padding=padding,
+                ),
+            )
+
+            self.convs.append(current)
+        self.fpn_interp_model = fpn_interp_model
+        assert fuse_type in ["sum", "avg"]
+        self.fuse_type = fuse_type
+
+        # levels to have top-down features in its outputs
+        # e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3
+        # have top-down propagation, while outputs of level 0 and level 1 have only
+        # lateral features from the same backbone level.
+        if fpn_top_down_levels is None:
+            # default is to have top-down features on all levels
+            fpn_top_down_levels = range(len(self.convs))
+        self.fpn_top_down_levels = list(fpn_top_down_levels)
+
+    def forward(self, xs: List[torch.Tensor]):
+
+        out = [None] * len(self.convs)
+        pos = [None] * len(self.convs)
+        assert len(xs) == len(self.convs)
+        # fpn forward pass
+        # see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py
+        prev_features = None
+        # forward in top-down order (from low to high resolution)
+        n = len(self.convs) - 1
+        for i in range(n, -1, -1):
+            x = xs[i]
+            lateral_features = self.convs[n - i](x)
+            if i in self.fpn_top_down_levels and prev_features is not None:
+                top_down_features = F.interpolate(
+                    prev_features.to(dtype=torch.float32),
+                    scale_factor=2.0,
+                    mode=self.fpn_interp_model,
+                    align_corners=(
+                        None if self.fpn_interp_model == "nearest" else False
+                    ),
+                    antialias=False,
+                )
+                prev_features = lateral_features + top_down_features
+                if self.fuse_type == "avg":
+                    prev_features /= 2
+            else:
+                prev_features = lateral_features
+            x_out = prev_features
+            out[i] = x_out
+            pos[i] = self.position_encoding(x_out).to(x_out.dtype)
+
+        return out, pos
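The image_encoder.py hunk above composes a trunk with an FpnNeck; the assert in ImageEncoder.__init__ requires neck.backbone_channel_list to equal trunk.channel_list. The following is a hedged sketch of wiring the two together, not code from the package: the PositionEmbeddingSine class and its num_pos_feats argument are assumptions based on the bundled sam2/modeling/position_encoding.py, and the package's Hydra configs under sam2/configs/sam2/ normally build this graph instead.

# Illustrative sketch only: PositionEmbeddingSine and its argument name are
# assumed from the bundled sam2/modeling/position_encoding.py (not shown here).
import torch

from sam2.modeling.backbones.hieradet import Hiera
from sam2.modeling.backbones.image_encoder import FpnNeck, ImageEncoder
from sam2.modeling.position_encoding import PositionEmbeddingSine

trunk = Hiera()
neck = FpnNeck(
    position_encoding=PositionEmbeddingSine(num_pos_feats=256),  # assumed signature
    d_model=256,
    backbone_channel_list=trunk.channel_list,  # must match the trunk, per the assert
)
encoder = ImageEncoder(trunk=trunk, neck=neck, scalp=1)  # scalp=1 drops the coarsest map

out = encoder(torch.randn(1, 3, 224, 224))
# "backbone_fpn" and "vision_pos_enc" are lists ordered fine-to-coarse;
# "vision_features" is the coarsest map that survives scalping.
print([tuple(f.shape) for f in out["backbone_fpn"]])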
sam2/modeling/backbones/utils.py
@@ -0,0 +1,95 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Some utilities for backbones, in particular for windowing"""
+
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def window_partition(x, window_size):
+    """
+    Partition into non-overlapping windows with padding if needed.
+    Args:
+        x (tensor): input tokens with [B, H, W, C].
+        window_size (int): window size.
+    Returns:
+        windows: windows after partition with [B * num_windows, window_size, window_size, C].
+        (Hp, Wp): padded height and width before partition
+    """
+    B, H, W, C = x.shape
+
+    pad_h = (window_size - H % window_size) % window_size
+    pad_w = (window_size - W % window_size) % window_size
+    if pad_h > 0 or pad_w > 0:
+        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
+    Hp, Wp = H + pad_h, W + pad_w
+
+    x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
+    windows = (
+        x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+    )
+    return windows, (Hp, Wp)
+
+
+def window_unpartition(windows, window_size, pad_hw, hw):
+    """
+    Window unpartition into original sequences and removing padding.
+    Args:
+        x (tensor): input tokens with [B * num_windows, window_size, window_size, C].
+        window_size (int): window size.
+        pad_hw (Tuple): padded height and width (Hp, Wp).
+        hw (Tuple): original height and width (H, W) before padding.
+    Returns:
+        x: unpartitioned sequences with [B, H, W, C].
+    """
+    Hp, Wp = pad_hw
+    H, W = hw
+    B = windows.shape[0] // (Hp * Wp // window_size // window_size)
+    x = windows.view(
+        B, Hp // window_size, Wp // window_size, window_size, window_size, -1
+    )
+    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
+
+    if Hp > H or Wp > W:
+        x = x[:, :H, :W, :].contiguous()
+    return x
+
+
+class PatchEmbed(nn.Module):
+    """
+    Image to Patch Embedding.
+    """
+
+    def __init__(
+        self,
+        kernel_size: Tuple[int, ...] = (7, 7),
+        stride: Tuple[int, ...] = (4, 4),
+        padding: Tuple[int, ...] = (3, 3),
+        in_chans: int = 3,
+        embed_dim: int = 768,
+    ):
+        """
+        Args:
+            kernel_size (Tuple): kernel size of the projection layer.
+            stride (Tuple): stride of the projection layer.
+            padding (Tuple): padding size of the projection layer.
+            in_chans (int): Number of input image channels.
+            embed_dim (int): embed_dim (int): Patch embedding dimension.
+        """
+        super().__init__()
+        self.proj = nn.Conv2d(
+            in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.proj(x)
+        # B C H W -> B H W C
+        x = x.permute(0, 2, 3, 1)
+        return x
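The utils.py hunk above supplies the windowing helpers used by hieradet.py: window_partition pads H and W up to multiples of window_size and reshapes the tokens into per-window batches, and window_unpartition inverts that using the padded size returned alongside the windows. A small round-trip sketch (illustrative only, not part of the package):

# Illustrative round trip for the windowing helpers above.
import torch

from sam2.modeling.backbones.utils import window_partition, window_unpartition

x = torch.randn(2, 30, 30, 96)  # (B, H, W, C); 30 is not a multiple of 8
windows, pad_hw = window_partition(x, window_size=8)
print(tuple(windows.shape), pad_hw)  # (32, 8, 8, 96), i.e. 2 * 4 * 4 windows, and (32, 32)

y = window_unpartition(windows, 8, pad_hw, (30, 30))
assert torch.equal(y, x)  # the zero padding is cropped away, so the round trip is exact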