birder 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- birder/adversarial/base.py +1 -1
- birder/adversarial/simba.py +4 -4
- birder/common/cli.py +1 -1
- birder/common/fs_ops.py +11 -11
- birder/common/lib.py +2 -2
- birder/common/masking.py +3 -3
- birder/common/training_cli.py +24 -2
- birder/common/training_utils.py +28 -4
- birder/data/collators/detection.py +9 -1
- birder/data/transforms/detection.py +27 -8
- birder/data/transforms/mosaic.py +1 -1
- birder/datahub/classification.py +3 -3
- birder/inference/classification.py +3 -3
- birder/inference/data_parallel.py +1 -1
- birder/inference/detection.py +5 -5
- birder/inference/wbf.py +1 -1
- birder/introspection/attention_rollout.py +5 -5
- birder/introspection/feature_pca.py +4 -4
- birder/introspection/gradcam.py +1 -1
- birder/introspection/guided_backprop.py +2 -2
- birder/introspection/transformer_attribution.py +3 -3
- birder/layers/attention_pool.py +2 -2
- birder/model_registry/model_registry.py +2 -1
- birder/net/__init__.py +2 -0
- birder/net/_rope_vit_configs.py +5 -0
- birder/net/_vit_configs.py +0 -13
- birder/net/alexnet.py +5 -5
- birder/net/base.py +28 -3
- birder/net/biformer.py +17 -17
- birder/net/cait.py +2 -2
- birder/net/cas_vit.py +1 -1
- birder/net/coat.py +15 -15
- birder/net/convnext_v1.py +2 -10
- birder/net/convnext_v1_iso.py +198 -0
- birder/net/convnext_v2.py +2 -10
- birder/net/crossformer.py +9 -9
- birder/net/crossvit.py +1 -1
- birder/net/cspnet.py +1 -1
- birder/net/cswin_transformer.py +10 -10
- birder/net/davit.py +10 -10
- birder/net/deit.py +56 -3
- birder/net/deit3.py +27 -15
- birder/net/detection/__init__.py +4 -0
- birder/net/detection/{yolo_anchors.py → _yolo_anchors.py} +5 -5
- birder/net/detection/base.py +6 -5
- birder/net/detection/deformable_detr.py +26 -28
- birder/net/detection/detr.py +9 -9
- birder/net/detection/efficientdet.py +9 -28
- birder/net/detection/faster_rcnn.py +22 -22
- birder/net/detection/fcos.py +8 -8
- birder/net/detection/plain_detr.py +852 -0
- birder/net/detection/retinanet.py +4 -4
- birder/net/detection/rt_detr_v1.py +81 -25
- birder/net/detection/rt_detr_v2.py +1147 -0
- birder/net/detection/ssd.py +5 -5
- birder/net/detection/yolo_v2.py +12 -12
- birder/net/detection/yolo_v3.py +19 -19
- birder/net/detection/yolo_v4.py +16 -16
- birder/net/detection/yolo_v4_tiny.py +3 -3
- birder/net/edgenext.py +3 -3
- birder/net/edgevit.py +10 -14
- birder/net/efficientformer_v1.py +1 -1
- birder/net/efficientvim.py +9 -9
- birder/net/efficientvit_mit.py +2 -2
- birder/net/efficientvit_msft.py +3 -3
- birder/net/fasternet.py +1 -1
- birder/net/fastvit.py +5 -12
- birder/net/flexivit.py +28 -15
- birder/net/focalnet.py +5 -9
- birder/net/gc_vit.py +11 -11
- birder/net/ghostnet_v1.py +1 -1
- birder/net/ghostnet_v2.py +1 -1
- birder/net/groupmixformer.py +12 -12
- birder/net/hgnet_v1.py +1 -1
- birder/net/hgnet_v2.py +4 -4
- birder/net/hiera.py +6 -6
- birder/net/hieradet.py +9 -9
- birder/net/hornet.py +3 -3
- birder/net/iformer.py +4 -4
- birder/net/inception_next.py +4 -14
- birder/net/levit.py +3 -3
- birder/net/lit_v1.py +13 -15
- birder/net/lit_v1_tiny.py +9 -9
- birder/net/lit_v2.py +14 -15
- birder/net/maxvit.py +10 -22
- birder/net/metaformer.py +2 -2
- birder/net/mim/crossmae.py +5 -5
- birder/net/mim/fcmae.py +3 -5
- birder/net/mim/mae_hiera.py +7 -7
- birder/net/mim/mae_vit.py +3 -5
- birder/net/mim/simmim.py +2 -3
- birder/net/mobilenet_v4_hybrid.py +4 -4
- birder/net/mobileone.py +5 -12
- birder/net/mobilevit_v1.py +2 -2
- birder/net/mobilevit_v2.py +5 -9
- birder/net/mvit_v2.py +24 -24
- birder/net/nextvit.py +2 -2
- birder/net/pit.py +11 -26
- birder/net/pvt_v1.py +4 -4
- birder/net/pvt_v2.py +5 -11
- birder/net/regionvit.py +15 -15
- birder/net/regnet.py +1 -1
- birder/net/repghost.py +4 -5
- birder/net/repvgg.py +3 -5
- birder/net/repvit.py +2 -2
- birder/net/resnest.py +1 -1
- birder/net/rope_deit3.py +29 -15
- birder/net/rope_flexivit.py +28 -15
- birder/net/rope_vit.py +41 -23
- birder/net/sequencer2d.py +3 -4
- birder/net/shufflenet_v1.py +1 -1
- birder/net/shufflenet_v2.py +1 -1
- birder/net/simple_vit.py +47 -5
- birder/net/smt.py +7 -7
- birder/net/ssl/barlow_twins.py +1 -1
- birder/net/ssl/byol.py +2 -2
- birder/net/ssl/capi.py +3 -3
- birder/net/ssl/data2vec2.py +1 -1
- birder/net/ssl/dino_v2.py +11 -1
- birder/net/ssl/franca.py +26 -2
- birder/net/ssl/i_jepa.py +4 -4
- birder/net/ssl/mmcr.py +1 -1
- birder/net/swiftformer.py +1 -1
- birder/net/swin_transformer_v1.py +4 -5
- birder/net/swin_transformer_v2.py +4 -7
- birder/net/tiny_vit.py +3 -3
- birder/net/transnext.py +19 -19
- birder/net/uniformer.py +4 -4
- birder/net/vgg.py +1 -10
- birder/net/vit.py +38 -25
- birder/net/vit_parallel.py +35 -20
- birder/net/vit_sam.py +10 -10
- birder/net/vovnet_v2.py +1 -1
- birder/net/xcit.py +9 -7
- birder/ops/msda.py +4 -4
- birder/ops/swattention.py +10 -10
- birder/results/classification.py +3 -3
- birder/results/gui.py +8 -8
- birder/scripts/benchmark.py +37 -12
- birder/scripts/evaluate.py +1 -1
- birder/scripts/predict.py +3 -3
- birder/scripts/predict_detection.py +2 -2
- birder/scripts/train.py +63 -15
- birder/scripts/train_barlow_twins.py +10 -7
- birder/scripts/train_byol.py +10 -7
- birder/scripts/train_capi.py +15 -10
- birder/scripts/train_data2vec.py +10 -7
- birder/scripts/train_data2vec2.py +10 -7
- birder/scripts/train_detection.py +29 -14
- birder/scripts/train_dino_v1.py +13 -9
- birder/scripts/train_dino_v2.py +27 -14
- birder/scripts/train_dino_v2_dist.py +28 -15
- birder/scripts/train_franca.py +16 -9
- birder/scripts/train_i_jepa.py +12 -9
- birder/scripts/train_ibot.py +15 -11
- birder/scripts/train_kd.py +64 -17
- birder/scripts/train_mim.py +11 -8
- birder/scripts/train_mmcr.py +11 -8
- birder/scripts/train_rotnet.py +11 -7
- birder/scripts/train_simclr.py +10 -7
- birder/scripts/train_vicreg.py +10 -7
- birder/tools/adversarial.py +4 -4
- birder/tools/auto_anchors.py +5 -5
- birder/tools/avg_model.py +1 -1
- birder/tools/convert_model.py +30 -22
- birder/tools/det_results.py +1 -1
- birder/tools/download_model.py +1 -1
- birder/tools/ensemble_model.py +1 -1
- birder/tools/introspection.py +11 -2
- birder/tools/labelme_to_coco.py +2 -2
- birder/tools/model_info.py +12 -14
- birder/tools/pack.py +8 -8
- birder/tools/quantize_model.py +53 -4
- birder/tools/results.py +2 -2
- birder/tools/show_det_iterator.py +19 -6
- birder/tools/show_iterator.py +2 -2
- birder/tools/similarity.py +5 -5
- birder/tools/stats.py +4 -6
- birder/tools/voc_to_coco.py +1 -1
- birder/version.py +1 -1
- {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/METADATA +3 -3
- birder-0.4.1.dist-info/RECORD +300 -0
- {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/WHEEL +1 -1
- birder-0.4.0.dist-info/RECORD +0 -297
- {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/entry_points.txt +0 -0
- {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/licenses/LICENSE +0 -0
- {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/top_level.txt +0 -0
birder/net/iformer.py
CHANGED
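The LowMixer hunk below computes attention over pooled tokens with a single fused QKV projection: the projection output is reshaped to (3, B, heads, N, head_dim), split into q, k, v with unbind, and handed to F.scaled_dot_product_attention. A minimal standalone sketch of that pattern (toy dimensions, not the exact birder module):

import torch
import torch.nn.functional as F

B, N, C, num_heads = 2, 16, 64, 4
x = torch.randn(B, N, C)
qkv_proj = torch.nn.Linear(C, C * 3)

# Fused projection -> (3, B, heads, N, head_dim), then split into q, k, v
qkv = qkv_proj(x).reshape(B, N, 3, num_heads, C // num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)

# One call covers softmax(QK^T / sqrt(d)) V, with optional dropout and scale
out = F.scaled_dot_product_attention(q, k, v)  # (B, heads, N, head_dim)
out = out.transpose(1, 2).reshape(B, N, C)     # back to (B, N, C)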
@@ -113,12 +113,12 @@ class LowMixer(nn.Module):

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.pool(x)
-
+        B, _, H, W = x.size()
         x = x.permute(0, 2, 3, 1).view(B, -1, self.dim)

-
+        B, N, C = x.size()
         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
-
+        q, k, v = qkv.unbind(0)
         x = F.scaled_dot_product_attention( # pylint: disable=not-callable
             q, k, v, dropout_p=self.attn_drop if self.training else 0.0, scale=self.scale
         )
@@ -301,7 +301,7 @@ class InceptionTransformerStage(nn.Module):

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.downsample(x)
-
+        H, W = x.shape[1:3]

         x = x + self._get_pos_embed(H, W)
         x = self.blocks(x)
birder/net/inception_next.py
CHANGED
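Most of the removals below drop an explicit `bias=True` argument from `nn.Conv2d` and `nn.Linear` calls. In PyTorch both layers default to `bias=True`, so the shorter call builds an identical module; a quick illustrative check (not birder code):

import torch.nn as nn

a = nn.Conv2d(8, 16, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True)
b = nn.Conv2d(8, 16, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))
assert (a.bias is not None) and (b.bias is not None)  # both layers keep a bias term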
@@ -33,7 +33,6 @@ class InceptionDWConv2d(nn.Module):
             stride=(1, 1),
             padding=square_kernel_size // 2,
             groups=branch_channels,
-            bias=True,
         )
         self.dwconv_w = nn.Conv2d(
             branch_channels,
@@ -42,7 +41,6 @@ class InceptionDWConv2d(nn.Module):
             stride=(1, 1),
             padding=(0, band_kernel_size // 2),
             groups=branch_channels,
-            bias=True,
         )
         self.dwconv_h = nn.Conv2d(
             branch_channels,
@@ -51,7 +49,6 @@ class InceptionDWConv2d(nn.Module):
             stride=(1, 1),
             padding=(band_kernel_size // 2, 0),
             groups=branch_channels,
-            bias=True,
         )
         self.split_indexes = (
             in_channels - (3 * branch_channels),
@@ -61,7 +58,7 @@ class InceptionDWConv2d(nn.Module):
         )

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-
+        x_id, x_hw, x_w, x_h = torch.split(x, self.split_indexes, dim=1)
         x_hw = self.dwconv_hw(x_hw)
         x_w = self.dwconv_w(x_w)
         x_h = self.dwconv_h(x_h)
@@ -78,11 +75,9 @@ class ConvMLP(nn.Module):
         act_layer: Callable[..., nn.Module] = nn.GELU,
     ) -> None:
         super().__init__()
-        self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)
+        self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))
         self.act = act_layer()
-        self.fc2 = nn.Conv2d(
-            hidden_features, out_features, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True
-        )
+        self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.fc1(x)
@@ -139,12 +134,7 @@ class InceptionNeXtStage(nn.Module):
         self.downsample = nn.Sequential(
             nn.BatchNorm2d(in_channels),
             nn.Conv2d(
-                in_channels,
-                out_channels,
-                kernel_size=(stride, stride),
-                stride=(stride, stride),
-                padding=(0, 0),
-                bias=True,
+                in_channels, out_channels, kernel_size=(stride, stride), stride=(stride, stride), padding=(0, 0)
             ),
         )

birder/net/levit.py
CHANGED
@@ -45,7 +45,7 @@ class Subsample(nn.Module):
         self.resolution = resolution

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-
+        B, _, C = x.shape
         x = x.view(B, self.resolution[0], self.resolution[1], C)
         x = x[:, :: self.stride, :: self.stride]
         return x.reshape(B, -1, C)
@@ -84,7 +84,7 @@ class Attention(nn.Module):
         self.attention_bias_idxs = nn.Buffer(rel_pos, persistent=False)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-
+        B, N, _ = x.shape
         q, k, v = self.qkv(x).view(B, N, self.num_heads, -1).split([self.key_dim, self.key_dim, self.val_dim], dim=3)
         q = q.permute(0, 2, 1, 3)
         k = k.permute(0, 2, 3, 1)
@@ -144,7 +144,7 @@ class AttentionSubsample(nn.Module):
         self.attention_bias_idxs = nn.Buffer(rel_pos, persistent=False)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-
+        B, N, _ = x.shape
         k, v = self.kv(x).view(B, N, self.num_heads, -1).split([self.key_dim, self.val_dim], dim=3)
         k = k.permute(0, 2, 3, 1) # BHCN
         v = v.permute(0, 2, 1, 3) # BHNC
birder/net/lit_v1.py
CHANGED
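The first hunk below touches `interpolate_rel_pos_bias_table`, which resizes a relative position bias table when the attention resolution changes. A common way to do this (a rough sketch with assumed shapes, not necessarily the exact birder implementation) is to view the table as a 2D grid per head and resample it bicubically:

import torch
import torch.nn.functional as F

def resize_rel_pos_bias(table: torch.Tensor, base: tuple[int, int], new: tuple[int, int]) -> torch.Tensor:
    # Assumed table layout: ((2*base_h - 1) * (2*base_w - 1), num_heads)
    base_h, base_w = base
    new_h, new_w = new
    num_heads = table.size(1)
    t = table.float().reshape(2 * base_h - 1, 2 * base_w - 1, num_heads)
    t = t.permute(2, 0, 1).unsqueeze(0)  # (1, num_heads, 2*base_h-1, 2*base_w-1)
    t = F.interpolate(t, size=(2 * new_h - 1, 2 * new_w - 1), mode="bicubic", align_corners=False)
    return t.squeeze(0).permute(1, 2, 0).reshape(-1, num_heads).to(table.dtype)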
@@ -43,7 +43,7 @@ def interpolate_rel_pos_bias_table(
     if new_resolution == base_resolution:
         return rel_pos_bias_table

-
+    base_h, base_w = base_resolution
     num_heads = rel_pos_bias_table.size(1)
     orig_dtype = rel_pos_bias_table.dtype
     bias_table = rel_pos_bias_table.float()
@@ -104,7 +104,7 @@ class RelPosAttention(nn.Module):
         relative_position_index = build_relative_position_index(input_resolution, device=bias_table.device)
         self.relative_position_index = nn.Buffer(relative_position_index)

-        self.qkv = nn.Linear(dim, dim * 3
+        self.qkv = nn.Linear(dim, dim * 3)
         self.proj = nn.Linear(dim, dim)

         # Weight initialization
@@ -130,9 +130,9 @@ class RelPosAttention(nn.Module):
         return relative_position_bias.unsqueeze(0)

     def forward(self, x: torch.Tensor, resolution: tuple[int, int]) -> torch.Tensor:
-
+        B, N, C = x.size()
         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
-
+        q, k, v = qkv.unbind(0)

         attn = (q * self.scale) @ k.transpose(-2, -1)
         attn = attn + self._get_rel_pos_bias(resolution)
@@ -177,7 +177,6 @@ class DeformablePatchMerging(nn.Module):
             kernel_size=(kernel_size, kernel_size),
             stride=(kernel_size, kernel_size),
             padding=(0, 0),
-            bias=True,
         )
         self.deform_conv = DeformConv2d(
             in_dim,
@@ -195,8 +194,8 @@ class DeformablePatchMerging(nn.Module):
         nn.init.zeros_(self.offset_conv.bias)

     def forward(self, x: torch.Tensor, resolution: tuple[int, int]) -> tuple[torch.Tensor, int, int]:
-
-
+        H, W = resolution
+        B, _, C = x.size()

         x = x.reshape(B, H, W, C).permute(0, 3, 1, 2).contiguous()

@@ -206,7 +205,7 @@ class DeformablePatchMerging(nn.Module):
         x = self.norm(x)
         x = self.act(x)

-
+        B, C, H, W = x.size()
         x = x.permute(0, 2, 3, 1).reshape(B, H * W, C)

         return (x, H, W)
@@ -252,7 +251,7 @@ class LITStage(nn.Module):
            block.set_dynamic_size(dynamic_size)

     def forward(self, x: torch.Tensor, input_resolution: tuple[int, int]) -> tuple[torch.Tensor, int, int]:
-
+        x, H, W = self.downsample(x, input_resolution)
         for block in self.blocks:
             x = block(x, (H, W))

@@ -291,7 +290,6 @@ class LIT_v1(DetectorBackbone):
                 kernel_size=(patch_size, patch_size),
                 stride=(patch_size, patch_size),
                 padding=(0, 0),
-                bias=True,
             ),
             Permute([0, 2, 3, 1]),
             nn.LayerNorm(embed_dim),
@@ -361,12 +359,12 @@ class LIT_v1(DetectorBackbone):

     def detection_features(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
         x = self.stem(x)
-
+        B, H, W, C = x.size()
         x = x.reshape(B, H * W, C)

         out = {}
         for name, stage in self.body.items():
-
+            x, H, W = stage(x, (H, W))
             if name in self.return_stages:
                 features = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
                 out[name] = features
@@ -386,10 +384,10 @@ class LIT_v1(DetectorBackbone):

     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         x = self.stem(x)
-
+        B, H, W, C = x.size()
         x = x.reshape(B, H * W, C)
         for stage in self.body.values():
-
+            x, H, W = stage(x, (H, W))

         return x

@@ -410,7 +408,7 @@ class LIT_v1(DetectorBackbone):

         new_patches_resolution = (new_size[0] // self.patch_size, new_size[1] // self.patch_size)

-
+        h, w = new_patches_resolution
         for stage in self.body.values():
             if not isinstance(stage.downsample, IdentityDownsample):
                 h = h // 2
birder/net/lit_v1_tiny.py
CHANGED
@@ -44,13 +44,13 @@ class Attention(nn.Module):
         super().__init__()
         self.num_heads = num_heads
         self.scale = (dim // num_heads) ** -0.5
-        self.qkv = nn.Linear(dim, dim * 3
+        self.qkv = nn.Linear(dim, dim * 3)
         self.proj = nn.Linear(dim, dim)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-
+        B, N, C = x.size()
         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
-
+        q, k, v = qkv.unbind(0)

         attn = (q @ k.transpose(-2, -1)) * self.scale
         attn = F.softmax(attn, dim=-1)
@@ -139,7 +139,7 @@ class LITStage(nn.Module):
         )

     def forward(self, x: torch.Tensor, input_resolution: tuple[int, int]) -> tuple[torch.Tensor, int, int]:
-
+        x, H, W = self.downsample(x, input_resolution)

         if self.cls_token is not None:
             cls_tokens = self.cls_token.expand(x.size(0), -1, -1)
@@ -247,12 +247,12 @@ class LIT_v1_Tiny(DetectorBackbone):

     def detection_features(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
         x = self.stem(x)
-
+        B, H, W, C = x.size()
         x = x.reshape(B, H * W, C)

         out = {}
         for name, stage in self.body.items():
-
+            x, H, W = stage(x, (H, W))
             if name in self.return_stages:
                 if stage.cls_token is not None:
                     spatial_x = x[:, 1:]
@@ -276,10 +276,10 @@ class LIT_v1_Tiny(DetectorBackbone):

     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         x = self.stem(x)
-
+        B, H, W, C = x.size()
         x = x.reshape(B, H * W, C)
         for stage in self.body.values():
-
+            x, H, W = stage(x, (H, W))

         return x

@@ -301,7 +301,7 @@ class LIT_v1_Tiny(DetectorBackbone):

         new_patches_resolution = (new_size[0] // self.patch_size, new_size[1] // self.patch_size)

-
+        h, w = new_patches_resolution
         for stage in self.body.values():
             if not isinstance(stage.downsample, IdentityDownsample):
                 h = h // 2
birder/net/lit_v2.py
CHANGED
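In the `_hifi` hunks below, HiLo attention pads the feature map so its height and width become multiples of the window size before it is chunked into `ws x ws` groups. A small sketch of one common way to compute the pad amounts (standalone, made-up sizes; the surrounding birder code computes `pad_h`/`pad_w` above the shown hunk):

import torch
import torch.nn.functional as F

ws = 7                                  # window size
x = torch.randn(2, 30, 45, 64)          # (B, H, W, C), channels-last

pad_h = (ws - x.size(1) % ws) % ws      # rows to add so H becomes a multiple of ws
pad_w = (ws - x.size(2) % ws) % ws      # cols to add so W becomes a multiple of ws
if pad_h > 0 or pad_w > 0:
    # F.pad pads the trailing dims first: (C_left, C_right, W_left, W_right, H_top, H_bottom)
    x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))

_, h_pad, w_pad, _ = x.size()
assert h_pad % ws == 0 and w_pad % ws == 0
h_groups, w_groups = h_pad // ws, w_pad // ws  # 5 x 7 windows in this example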
@@ -39,7 +39,7 @@ class DepthwiseMLP(nn.Module):
     def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
         x = self.fc1(x)

-
+        B, N, C = x.size()
         x = x.reshape(B, H, W, C).permute(0, 3, 1, 2).contiguous()
         x = self.dwconv(x)
         x = x.permute(0, 2, 3, 1).reshape(B, N, C)
@@ -57,7 +57,7 @@ class DepthwiseMLPBlock(nn.Module):
         self.drop_path = StochasticDepth(drop_path, mode="row")

     def forward(self, x: torch.Tensor, resolution: tuple[int, int]) -> torch.Tensor:
-
+        H, W = resolution
         return x + self.drop_path(self.mlp(self.norm(x), H, W))


@@ -121,7 +121,7 @@ class HiLoAttention(nn.Module):
            self.h_proj = nn.Identity()

     def _lofi(self, x: torch.Tensor) -> torch.Tensor:
-
+        B, H, W, C = x.size()

         q = self.l_q(x).reshape(B, H * W, self.l_heads, self.head_dim).permute(0, 2, 1, 3)

@@ -133,7 +133,7 @@ class HiLoAttention(nn.Module):
         else:
             kv = self.l_kv(x).reshape(B, -1, 2, self.l_heads, self.head_dim).permute(2, 0, 3, 1, 4)

-
+        k, v = kv.unbind(0)

         attn = (q @ k.transpose(-2, -1)) * self.scale
         attn = F.softmax(attn, dim=-1)
@@ -144,7 +144,7 @@ class HiLoAttention(nn.Module):
         return x

     def _hifi(self, x: torch.Tensor) -> torch.Tensor:
-
+        B, H, W, _ = x.size()
         ws = self.window_size

         # Pad if needed
@@ -153,7 +153,7 @@ class HiLoAttention(nn.Module):
         if pad_h > 0 or pad_w > 0:
             x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))

-
+        _, h_pad, w_pad, _ = x.size()
         h_groups = h_pad // ws
         w_groups = w_pad // ws
         total_groups = h_groups * w_groups
@@ -161,7 +161,7 @@ class HiLoAttention(nn.Module):
         x = x.reshape(B, h_groups, ws, w_groups, ws, -1).transpose(2, 3)

         qkv = self.h_qkv(x).reshape(B, total_groups, -1, 3, self.h_heads, self.head_dim).permute(3, 0, 1, 4, 2, 5)
-
+        q, k, v = qkv.unbind(0)

         attn = (q @ k.transpose(-2, -1)) * self.scale
         attn = F.softmax(attn, dim=-1)
@@ -177,7 +177,7 @@ class HiLoAttention(nn.Module):
         return x

     def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
-
+        B, N, C = x.size()
         x = x.reshape(B, H, W, C)

         if self.h_heads == 0:
@@ -215,7 +215,7 @@ class HiLoBlock(nn.Module):
         self.drop_path2 = StochasticDepth(drop_path, mode="row")

     def forward(self, x: torch.Tensor, resolution: tuple[int, int]) -> torch.Tensor:
-
+        H, W = resolution
         x = x + self.drop_path1(self.attn(self.norm1(x), H, W))
         x = x + self.drop_path2(self.mlp(self.norm2(x), H, W))
         return x
@@ -252,7 +252,7 @@ class LITStage(nn.Module):
         self.blocks = nn.ModuleList(blocks)

     def forward(self, x: torch.Tensor, input_resolution: tuple[int, int]) -> tuple[torch.Tensor, int, int]:
-
+        x, H, W = self.downsample(x, input_resolution)
         for block in self.blocks:
             x = block(x, (H, W))

@@ -292,7 +292,6 @@ class LIT_v2(DetectorBackbone):
                 kernel_size=(patch_size, patch_size),
                 stride=(patch_size, patch_size),
                 padding=(0, 0),
-                bias=True,
             ),
             Permute([0, 2, 3, 1]),
             nn.LayerNorm(embed_dim),
@@ -361,12 +360,12 @@ class LIT_v2(DetectorBackbone):

     def detection_features(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
         x = self.stem(x)
-
+        B, H, W, C = x.size()
         x = x.reshape(B, H * W, C)

         out = {}
         for name, stage in self.body.items():
-
+            x, H, W = stage(x, (H, W))
             if name in self.return_stages:
                 features = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
                 out[name] = features
@@ -386,10 +385,10 @@ class LIT_v2(DetectorBackbone):

     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         x = self.stem(x)
-
+        B, H, W, C = x.size()
         x = x.reshape(B, H * W, C)
         for stage in self.body.values():
-
+            x, H, W = stage(x, (H, W))

         return x

birder/net/maxvit.py
CHANGED
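The WindowPartition/WindowDepartition hunks below reshape a (B, C, H, W) map into non-overlapping PH x PW windows and back. A minimal round-trip sketch of that kind of reshape (simplified; the exact permutation order birder uses may differ):

import torch

B, C, H, W = 2, 8, 16, 16
PH, PW = 4, 4
x = torch.randn(B, C, H, W)

# Partition: chunk H and W into (H//PH, PH) and (W//PW, PW), then group the windows together
parts = x.reshape(B, C, H // PH, PH, W // PW, PW)
parts = parts.permute(0, 2, 4, 3, 5, 1).reshape(B, (H // PH) * (W // PW), PH * PW, C)

# Departition: invert the same reshapes/permutes to recover the original layout
y = parts.reshape(B, H // PH, W // PW, PH, PW, C)
y = y.permute(0, 5, 1, 3, 2, 4).reshape(B, C, H, W)
assert torch.equal(x, y)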
@@ -83,7 +83,7 @@ class MBConv(nn.Module):
         if stride[0] != 1 or stride[1] != 1 or in_channels != out_channels:
             self.proj = nn.Sequential(
                 nn.AvgPool2d(kernel_size=(3, 3), stride=stride, padding=(1, 1)),
-                nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)
+                nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)),
             )
         else:
             self.proj = nn.Identity()
@@ -119,12 +119,7 @@ class MBConv(nn.Module):
             ),
             SqueezeExcitation(mid_channels, sqz_channels, activation=nn.SiLU),
             nn.Conv2d(
-                in_channels=mid_channels,
-                out_channels=out_channels,
-                kernel_size=(1, 1),
-                stride=(1, 1),
-                padding=(0, 0),
-                bias=True,
+                in_channels=mid_channels, out_channels=out_channels, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)
             ),
         )

@@ -169,12 +164,12 @@ class RelativePositionalMultiHeadAttention(nn.Module):

     # pylint: disable=invalid-name
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-
+        B, G, P, D = x.size()
         H = self.n_heads
         DH = self.head_dim

         qkv = self.to_qkv(x)
-
+        q, k, v = torch.chunk(qkv, 3, dim=-1)

         q = q.reshape(B, G, P, H, DH).permute(0, 1, 3, 2, 4)
         k = k.reshape(B, G, P, H, DH).permute(0, 1, 3, 2, 4)
@@ -206,8 +201,8 @@ class SwapAxes(nn.Module):

 class WindowPartition(nn.Module):
     def forward(self, x: torch.Tensor, p: tuple[int, int]) -> torch.Tensor:
-
-
+        B, C, H, W = x.size()
+        PH, PW = p # pylint: disable=invalid-name

         # Chunk up H and W dimensions
         x = x.reshape(B, C, H // PH, PH, W // PW, PW)
@@ -222,8 +217,8 @@ class WindowDepartition(nn.Module):
 class WindowDepartition(nn.Module):
     # pylint: disable=invalid-name
     def forward(self, x: torch.Tensor, p: tuple[int, int], h_partitions: int, w_partitions: int) -> torch.Tensor:
-
-
+        B, _G, _PP, C = x.size()
+        PH, PW = p # pylint: disable=invalid-name
         HP = h_partitions
         WP = w_partitions

@@ -500,14 +495,7 @@ class MaxViT(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
                 activation_layer=nn.GELU,
                 inplace=None,
             ),
-            nn.Conv2d(
-                stem_channels,
-                stem_channels,
-                kernel_size=(3, 3),
-                stride=(1, 1),
-                padding=(1, 1),
-                bias=True,
-            ),
+            nn.Conv2d(stem_channels, stem_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
         )

         # Account for stem stride
@@ -706,7 +694,7 @@ class MaxViT(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
         src_size = (2 * old_attn_size[0] - 1, 2 * old_attn_size[1] - 1)

         def _calc(src: int, dst: int) -> list[float]:
-
+            left, right = 1.01, 1.5
             while right - left > 1e-6:
                 q = (left + right) / 2.0
                 gp = (1.0 - q ** (src // 2)) / (1.0 - q) # Geometric progression
birder/net/metaformer.py
CHANGED
@@ -127,10 +127,10 @@ class Attention(nn.Module):

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
-
+        B, H, W, _ = x.shape
         N = H * W
         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
-
+        q, k, v = qkv.unbind(0)

         x = F.scaled_dot_product_attention( # pylint: disable=not-callable
             q, k, v, dropout_p=self.attn_drop.p if self.training else 0.0, scale=self.scale
birder/net/mim/crossmae.py
CHANGED
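The CrossAttention hunk below is a standard cross-attention layer: queries come from the decoder tokens (`tgt`), while keys and values are projected from the encoder memory and split with `unbind`. A compact standalone sketch of that flow (toy sizes, not the exact birder module):

import torch
import torch.nn.functional as F

B, N, M, C, heads = 2, 8, 32, 64, 4        # N decoder tokens attend over M memory tokens
tgt, memory = torch.randn(B, N, C), torch.randn(B, M, C)
q_proj, kv_proj = torch.nn.Linear(C, C), torch.nn.Linear(C, C * 2)

q = q_proj(tgt).reshape(B, N, heads, C // heads).permute(0, 2, 1, 3)             # (B, heads, N, hd)
kv = kv_proj(memory).reshape(B, M, 2, heads, C // heads).permute(2, 0, 3, 1, 4)  # (2, B, heads, M, hd)
k, v = kv.unbind(0)

out = F.scaled_dot_product_attention(q, k, v)  # (B, heads, N, hd): each query mixes memory values
out = out.transpose(1, 2).reshape(B, N, C)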
@@ -46,11 +46,11 @@ class CrossAttention(nn.Module):
         self.proj = nn.Linear(decoder_dim, decoder_dim)

     def forward(self, tgt: torch.Tensor, memory: torch.Tensor) -> torch.Tensor:
-
+        B, N, C = tgt.size()
         n_kv = memory.size(1)
         q = self.q(tgt).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
         kv = self.kv(memory).reshape(B, n_kv, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
-
+        k, v = kv.unbind(0)

         attn = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0) # pylint: disable=not-callable
         x = attn.transpose(1, 2).reshape(B, N, C)
@@ -120,7 +120,7 @@ class CrossMAE(MIMBaseNet):
         self.decoder_layers.append(CrossAttentionBlock(encoder_dim, decoder_embed_dim, num_heads=16, mlp_ratio=4.0))

         self.decoder_norm = nn.LayerNorm(decoder_embed_dim, eps=1e-6)
-        self.pred = nn.Linear(decoder_embed_dim, self.patch_size**2 * self.input_channels
+        self.pred = nn.Linear(decoder_embed_dim, self.patch_size**2 * self.input_channels)

         # Weight initialization
         for m in self.modules():
@@ -170,7 +170,7 @@ class CrossMAE(MIMBaseNet):
         return imgs

     def fill_pred(self, mask: torch.Tensor, pred: torch.Tensor) -> torch.Tensor:
-
+        N, L = mask.shape[0:2]
         combined = torch.zeros(N, L, pred.shape[2], device=pred.device, dtype=pred.dtype)
         combined[mask.bool()] = pred.view(-1, pred.shape[2])

@@ -213,7 +213,7 @@ class CrossMAE(MIMBaseNet):
     def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
         h = self.size[0] // self.encoder.stem_stride
         w = self.size[1] // self.encoder.stem_stride
-
+        mask, ids_keep, _ = uniform_mask(
             x.size(0), h, w, self.mask_ratio, self.kept_mask_ratio, min_mask_size=self.min_mask_size, device=x.device
         )
birder/net/mim/fcmae.py
CHANGED
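In the hunks below, FCMAE replaces masked positions of the projected feature map with a learnable mask token: the binary mask is broadcast over channels and the token is blended in with `x * (1 - mask) + mask_token * mask`. A tiny sketch of that substitution (assumed shapes only; in the real model the token is an nn.Parameter):

import torch

B, D, H, W = 2, 512, 7, 7
x = torch.randn(B, D, H, W)                        # projected encoder features
mask = torch.randint(0, 2, (B, 1, H, W)).float()   # 1 = masked position, 0 = visible
mask_token = torch.zeros(1, D, 1, 1)               # learnable in the real model

filled = x * (1.0 - mask) + mask_token.repeat(B, 1, H, W) * mask
visible = mask[0, 0] == 0
assert torch.equal(filled[0, :, visible], x[0, :, visible])  # visible positions are unchanged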
@@ -48,7 +48,6 @@ class FCMAE(MIMBaseNet):
             kernel_size=(1, 1),
             stride=(1, 1),
             padding=(0, 0),
-            bias=True,
         )

         self.mask_token = nn.Parameter(torch.zeros(1, self.decoder_embed_dim, 1, 1))
@@ -65,7 +64,6 @@ class FCMAE(MIMBaseNet):
             kernel_size=(1, 1),
             stride=(1, 1),
             padding=(0, 0),
-            bias=True,
         )

         # Weights initialization
@@ -106,7 +104,7 @@ class FCMAE(MIMBaseNet):
         """

         if x.ndim == 4:
-
+            n, c, _, _ = x.shape
             x = x.reshape(n, c, -1)
             x = torch.einsum("ncl->nlc", x)

@@ -125,7 +123,7 @@ class FCMAE(MIMBaseNet):
         x = self.proj(x)

         # Append mask token
-
+        B, _, H, W = x.shape
         mask = mask.reshape(-1, H, W).unsqueeze(1).type_as(x)
         mask_token = self.mask_token.repeat(B, 1, H, W)
         x = x * (1.0 - mask) + (mask_token * mask)
@@ -141,7 +139,7 @@ class FCMAE(MIMBaseNet):
         mask: 0 is keep, 1 is remove
         """

-
+        n, c, _, _ = pred.shape
         pred = pred.reshape(n, c, -1)
         pred = torch.einsum("ncl->nlc", pred)

birder/net/mim/mae_hiera.py
CHANGED
@@ -26,7 +26,7 @@ def apply_fusion_head(head: nn.Module, x: torch.Tensor) -> torch.Tensor:
    if isinstance(head, nn.Identity):
        return x

-
+    B, num_mask_units = x.shape[0:2]

    # Apply head, e.g [B, #MUs, My, Mx, C] -> head([B * #MUs, C, My, Mx])
    permute = [0] + [len(x.shape) - 2] + list(range(1, len(x.shape) - 2))
@@ -169,7 +169,7 @@ class MAE_Hiera(MIMBaseNet):

     def forward_encoder(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         # Tokens selected for masking at mask unit level
-
+        mask, _, _ = uniform_mask(
             x.size(0),
             self.encoder.mask_spatial_shape[0],
             self.encoder.mask_spatial_shape[1],
@@ -179,7 +179,7 @@ class MAE_Hiera(MIMBaseNet):
         )

         # Get multi-scale representations from encoder
-
+        intermediates, mask = self.encoder.masked_encoding(x, mask)

         # Resolution unchanged after q_pool stages, so skip those features
         intermediates = intermediates[: self.encoder.q_pool] + intermediates[-1:]
@@ -206,12 +206,12 @@ class MAE_Hiera(MIMBaseNet):
         # Get back spatial order
         x = undo_windowing(
             x_dec,
-            self.tokens_spatial_shape_final, # type:ignore[arg-type]
+            self.tokens_spatial_shape_final, # type: ignore[arg-type]
             self.mask_unit_spatial_shape_final,
         )
         mask = undo_windowing(
             mask[..., 0:1],
-            self.tokens_spatial_shape_final, # type:ignore[arg-type]
+            self.tokens_spatial_shape_final, # type: ignore[arg-type]
             self.mask_unit_spatial_shape_final,
         )

@@ -240,8 +240,8 @@ class MAE_Hiera(MIMBaseNet):
         return loss.mean()

     def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
-
-
+        latent, mask = self.forward_encoder(x)
+        pred, pred_mask = self.forward_decoder(latent, mask)
         loss = self.forward_loss(x, pred, ~pred_mask)

         return {"loss": loss, "pred": pred, "mask": mask}
birder/net/mim/mae_vit.py
CHANGED
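The decoder head touched below maps each decoder token to `patch_size**2 * input_channels` values, i.e. one flattened patch of pixels. A hedged sketch of the usual MAE-style patchify/unpatchify bookkeeping behind that (not copied from birder):

import torch

B, C, H, W, p = 2, 3, 224, 224, 16
imgs = torch.randn(B, C, H, W)

# patchify: (B, C, H, W) -> (B, L, p*p*C) with L = (H//p) * (W//p) tokens
x = imgs.reshape(B, C, H // p, p, W // p, p)
x = x.permute(0, 2, 4, 3, 5, 1).reshape(B, (H // p) * (W // p), p * p * C)

# unpatchify: invert the reshape to recover the image
y = x.reshape(B, H // p, W // p, p, p, C).permute(0, 5, 1, 3, 2, 4).reshape(B, C, H, W)
assert torch.equal(imgs, y)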
@@ -52,7 +52,7 @@ class MAE_ViT(MIMBaseNet):

         self.norm_pix_loss = norm_pix_loss

-        self.decoder_embed = nn.Linear(encoder_dim, decoder_embed_dim
+        self.decoder_embed = nn.Linear(encoder_dim, decoder_embed_dim)
         self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))

         if learnable_pos_embed is True:
@@ -74,9 +74,7 @@ class MAE_ViT(MIMBaseNet):
             layers.append(self.encoder.decoder_block(decoder_embed_dim))

         layers.append(nn.LayerNorm(decoder_embed_dim, eps=1e-6))
-        layers.append(
-            nn.Linear(decoder_embed_dim, self.patch_size**2 * self.input_channels, bias=True)
-        ) # Decoder to patch
+        layers.append(nn.Linear(decoder_embed_dim, self.patch_size**2 * self.input_channels)) # Decoder to patch
         self.decoder = nn.Sequential(*layers)

     def patchify(self, imgs: torch.Tensor) -> torch.Tensor:
@@ -153,7 +151,7 @@ class MAE_ViT(MIMBaseNet):
     def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
         h = self.size[0] // self.encoder.max_stride
         w = self.size[1] // self.encoder.max_stride
-
+        mask, ids_keep, ids_restore = uniform_mask(
             x.size(0), h, w, self.mask_ratio, min_mask_size=self.min_mask_size, device=x.device
         )