birder 0.3.3__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- birder/adversarial/base.py +1 -1
- birder/adversarial/simba.py +4 -4
- birder/common/cli.py +1 -1
- birder/common/fs_ops.py +13 -13
- birder/common/lib.py +2 -2
- birder/common/masking.py +3 -3
- birder/common/training_cli.py +24 -2
- birder/common/training_utils.py +28 -4
- birder/data/collators/detection.py +9 -1
- birder/data/transforms/detection.py +27 -8
- birder/data/transforms/mosaic.py +1 -1
- birder/datahub/classification.py +3 -3
- birder/inference/classification.py +3 -3
- birder/inference/data_parallel.py +1 -1
- birder/inference/detection.py +5 -5
- birder/inference/wbf.py +1 -1
- birder/introspection/attention_rollout.py +6 -6
- birder/introspection/feature_pca.py +4 -4
- birder/introspection/gradcam.py +1 -1
- birder/introspection/guided_backprop.py +2 -2
- birder/introspection/transformer_attribution.py +4 -4
- birder/layers/attention_pool.py +2 -2
- birder/layers/layer_scale.py +1 -1
- birder/model_registry/model_registry.py +2 -1
- birder/net/__init__.py +4 -10
- birder/net/_rope_vit_configs.py +435 -0
- birder/net/_vit_configs.py +466 -0
- birder/net/alexnet.py +5 -5
- birder/net/base.py +28 -3
- birder/net/biformer.py +18 -17
- birder/net/cait.py +7 -7
- birder/net/cas_vit.py +1 -1
- birder/net/coat.py +27 -27
- birder/net/conv2former.py +3 -3
- birder/net/convmixer.py +1 -1
- birder/net/convnext_v1.py +3 -11
- birder/net/convnext_v1_iso.py +198 -0
- birder/net/convnext_v2.py +2 -10
- birder/net/crossformer.py +9 -9
- birder/net/crossvit.py +6 -6
- birder/net/cspnet.py +1 -1
- birder/net/cswin_transformer.py +10 -10
- birder/net/davit.py +11 -11
- birder/net/deit.py +68 -29
- birder/net/deit3.py +69 -204
- birder/net/densenet.py +9 -8
- birder/net/detection/__init__.py +4 -0
- birder/net/detection/{yolo_anchors.py → _yolo_anchors.py} +5 -5
- birder/net/detection/base.py +6 -5
- birder/net/detection/deformable_detr.py +31 -30
- birder/net/detection/detr.py +14 -11
- birder/net/detection/efficientdet.py +10 -29
- birder/net/detection/faster_rcnn.py +22 -22
- birder/net/detection/fcos.py +8 -8
- birder/net/detection/plain_detr.py +852 -0
- birder/net/detection/retinanet.py +4 -4
- birder/net/detection/rt_detr_v1.py +81 -25
- birder/net/detection/rt_detr_v2.py +1147 -0
- birder/net/detection/ssd.py +5 -5
- birder/net/detection/yolo_v2.py +12 -12
- birder/net/detection/yolo_v3.py +19 -19
- birder/net/detection/yolo_v4.py +16 -16
- birder/net/detection/yolo_v4_tiny.py +3 -3
- birder/net/dpn.py +1 -2
- birder/net/edgenext.py +5 -4
- birder/net/edgevit.py +13 -14
- birder/net/efficientformer_v1.py +3 -2
- birder/net/efficientformer_v2.py +18 -31
- birder/net/efficientnet_v2.py +3 -0
- birder/net/efficientvim.py +9 -9
- birder/net/efficientvit_mit.py +7 -7
- birder/net/efficientvit_msft.py +3 -3
- birder/net/fasternet.py +3 -3
- birder/net/fastvit.py +5 -12
- birder/net/flexivit.py +50 -58
- birder/net/focalnet.py +5 -9
- birder/net/gc_vit.py +11 -11
- birder/net/ghostnet_v1.py +1 -1
- birder/net/ghostnet_v2.py +1 -1
- birder/net/groupmixformer.py +13 -13
- birder/net/hgnet_v1.py +6 -6
- birder/net/hgnet_v2.py +4 -4
- birder/net/hiera.py +6 -6
- birder/net/hieradet.py +9 -9
- birder/net/hornet.py +3 -3
- birder/net/iformer.py +4 -4
- birder/net/inception_next.py +5 -15
- birder/net/inception_resnet_v1.py +3 -3
- birder/net/inception_resnet_v2.py +7 -4
- birder/net/inception_v3.py +3 -0
- birder/net/inception_v4.py +3 -0
- birder/net/levit.py +3 -3
- birder/net/lit_v1.py +13 -15
- birder/net/lit_v1_tiny.py +9 -9
- birder/net/lit_v2.py +14 -15
- birder/net/maxvit.py +11 -23
- birder/net/metaformer.py +5 -5
- birder/net/mim/crossmae.py +6 -6
- birder/net/mim/fcmae.py +3 -5
- birder/net/mim/mae_hiera.py +7 -7
- birder/net/mim/mae_vit.py +4 -6
- birder/net/mim/simmim.py +3 -4
- birder/net/mobilenet_v1.py +0 -9
- birder/net/mobilenet_v2.py +38 -44
- birder/net/{mobilenet_v3_large.py → mobilenet_v3.py} +37 -10
- birder/net/mobilenet_v4_hybrid.py +4 -4
- birder/net/mobileone.py +5 -12
- birder/net/mobilevit_v1.py +7 -34
- birder/net/mobilevit_v2.py +6 -54
- birder/net/moganet.py +8 -5
- birder/net/mvit_v2.py +30 -30
- birder/net/nextvit.py +2 -2
- birder/net/nfnet.py +4 -0
- birder/net/pit.py +11 -26
- birder/net/pvt_v1.py +9 -9
- birder/net/pvt_v2.py +10 -16
- birder/net/regionvit.py +15 -15
- birder/net/regnet.py +1 -1
- birder/net/repghost.py +5 -35
- birder/net/repvgg.py +3 -5
- birder/net/repvit.py +2 -2
- birder/net/resmlp.py +2 -2
- birder/net/resnest.py +4 -1
- birder/net/resnet_v1.py +125 -1
- birder/net/resnet_v2.py +75 -1
- birder/net/resnext.py +35 -1
- birder/net/rope_deit3.py +62 -151
- birder/net/rope_flexivit.py +46 -33
- birder/net/rope_vit.py +44 -758
- birder/net/sequencer2d.py +3 -4
- birder/net/shufflenet_v1.py +1 -1
- birder/net/shufflenet_v2.py +1 -1
- birder/net/simple_vit.py +69 -21
- birder/net/smt.py +8 -8
- birder/net/squeezenet.py +5 -12
- birder/net/squeezenext.py +0 -24
- birder/net/ssl/barlow_twins.py +1 -1
- birder/net/ssl/byol.py +2 -2
- birder/net/ssl/capi.py +4 -4
- birder/net/ssl/data2vec.py +1 -1
- birder/net/ssl/data2vec2.py +1 -1
- birder/net/ssl/dino_v2.py +13 -3
- birder/net/ssl/franca.py +28 -4
- birder/net/ssl/i_jepa.py +5 -5
- birder/net/ssl/ibot.py +1 -1
- birder/net/ssl/mmcr.py +1 -1
- birder/net/swiftformer.py +13 -3
- birder/net/swin_transformer_v1.py +4 -5
- birder/net/swin_transformer_v2.py +5 -8
- birder/net/tiny_vit.py +6 -19
- birder/net/transnext.py +19 -19
- birder/net/uniformer.py +4 -4
- birder/net/van.py +2 -2
- birder/net/vgg.py +1 -10
- birder/net/vit.py +72 -987
- birder/net/vit_parallel.py +35 -20
- birder/net/vit_sam.py +23 -48
- birder/net/vovnet_v2.py +1 -1
- birder/net/xcit.py +16 -13
- birder/ops/msda.py +4 -4
- birder/ops/swattention.py +10 -10
- birder/results/classification.py +3 -3
- birder/results/gui.py +8 -8
- birder/scripts/benchmark.py +37 -12
- birder/scripts/evaluate.py +1 -1
- birder/scripts/predict.py +3 -3
- birder/scripts/predict_detection.py +2 -2
- birder/scripts/train.py +63 -15
- birder/scripts/train_barlow_twins.py +10 -7
- birder/scripts/train_byol.py +10 -7
- birder/scripts/train_capi.py +15 -10
- birder/scripts/train_data2vec.py +10 -7
- birder/scripts/train_data2vec2.py +10 -7
- birder/scripts/train_detection.py +29 -14
- birder/scripts/train_dino_v1.py +13 -9
- birder/scripts/train_dino_v2.py +27 -14
- birder/scripts/train_dino_v2_dist.py +28 -15
- birder/scripts/train_franca.py +16 -9
- birder/scripts/train_i_jepa.py +12 -9
- birder/scripts/train_ibot.py +15 -11
- birder/scripts/train_kd.py +64 -17
- birder/scripts/train_mim.py +11 -8
- birder/scripts/train_mmcr.py +11 -8
- birder/scripts/train_rotnet.py +11 -7
- birder/scripts/train_simclr.py +10 -7
- birder/scripts/train_vicreg.py +10 -7
- birder/tools/adversarial.py +4 -4
- birder/tools/auto_anchors.py +5 -5
- birder/tools/avg_model.py +1 -1
- birder/tools/convert_model.py +30 -22
- birder/tools/det_results.py +1 -1
- birder/tools/download_model.py +1 -1
- birder/tools/ensemble_model.py +1 -1
- birder/tools/introspection.py +12 -3
- birder/tools/labelme_to_coco.py +2 -2
- birder/tools/model_info.py +15 -15
- birder/tools/pack.py +8 -8
- birder/tools/quantize_model.py +53 -4
- birder/tools/results.py +2 -2
- birder/tools/show_det_iterator.py +19 -6
- birder/tools/show_iterator.py +2 -2
- birder/tools/similarity.py +5 -5
- birder/tools/stats.py +4 -6
- birder/tools/voc_to_coco.py +1 -1
- birder/version.py +1 -1
- {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/METADATA +3 -3
- birder-0.4.1.dist-info/RECORD +300 -0
- {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/WHEEL +1 -1
- birder/net/mobilenet_v3_small.py +0 -43
- birder/net/se_resnet_v1.py +0 -105
- birder/net/se_resnet_v2.py +0 -59
- birder/net/se_resnext.py +0 -30
- birder-0.3.3.dist-info/RECORD +0 -299
- {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/entry_points.txt +0 -0
- {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/licenses/LICENSE +0 -0
- {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/top_level.txt +0 -0
birder/net/cait.py
CHANGED
@@ -47,7 +47,7 @@ class ClassAttention(nn.Module):
         self.proj_drop = nn.Dropout(proj_drop)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        (B, N, C) = x.shape
+        B, N, C = x.shape
         q = self.q(x[:, 0]).unsqueeze(1).reshape(B, 1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
         k = self.k(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
         v = self.v(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
@@ -66,12 +66,12 @@ class ClassAttentionBlock(nn.Module):
         self, dim: int, num_heads: int, mlp_ratio: float, qkv_bias: bool, proj_drop: float, drop_path: float, eta: float
     ) -> None:
         super().__init__()
-        self.norm1 = nn.LayerNorm(dim)
+        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
 
         self.attn = ClassAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias, proj_drop=proj_drop)
 
         self.drop_path = StochasticDepth(drop_path, mode="row")
-        self.norm2 = nn.LayerNorm(dim)
+        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
         self.mlp = MLP(dim, [int(dim * mlp_ratio), dim], activation_layer=nn.GELU, dropout=proj_drop)
 
         self.gamma1 = nn.Parameter(eta * torch.ones(dim))
@@ -103,7 +103,7 @@ class TalkingHeadAttn(nn.Module):
         self.proj_drop = nn.Dropout(proj_drop)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        (B, N, C) = x.shape
+        B, N, C = x.shape
         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
         q = qkv[0] * self.scale
         k = qkv[1]
@@ -135,7 +135,7 @@ class LayerScaleBlock(nn.Module):
         init_values: float,
     ) -> None:
         super().__init__()
-        self.norm1 = nn.LayerNorm(dim)
+        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
         self.attn = TalkingHeadAttn(
             dim,
             num_heads=num_heads,
@@ -144,7 +144,7 @@ class LayerScaleBlock(nn.Module):
             proj_drop=proj_drop,
         )
         self.drop_path = StochasticDepth(drop_path, mode="row")
-        self.norm2 = nn.LayerNorm(dim)
+        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
         self.mlp = MLP(dim, [int(dim * mlp_ratio), dim], activation_layer=nn.GELU, dropout=proj_drop)
         self.gamma_1 = nn.Parameter(init_values * torch.ones(dim))
         self.gamma_2 = nn.Parameter(init_values * torch.ones(dim))
@@ -221,7 +221,7 @@ class CaiT(BaseNet):
             )
         )
 
-        self.norm = nn.LayerNorm(embed_dim)
+        self.norm = nn.LayerNorm(embed_dim, eps=1e-6)
 
         self.embedding_size = embed_dim
         self.classifier = self.create_classifier()
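Two mechanical changes recur across this release: LayerNorm layers gain an explicit eps=1e-6 (PyTorch's default is 1e-5), and redundant parentheses are stripped from tuple unpacking (the removed lines render blank in this page's source; they appear to be the parenthesized 0.3.3 form, as reconstructed above). A minimal sketch of both, with an illustrative 192-dim input:

import torch
from torch import nn

# eps change: outputs differ only marginally, but matching the reference
# implementations (which use 1e-6) keeps results numerically consistent.
x = torch.randn(2, 197, 192)
norm_default = nn.LayerNorm(192)           # eps=1e-5, the PyTorch default
norm_pinned = nn.LayerNorm(192, eps=1e-6)  # the value pinned in this release
print(torch.allclose(norm_default(x), norm_pinned(x), atol=1e-4))  # True

# Unpacking change: the two forms are semantically identical Python
(B, N, C) = x.shape  # 0.3.3 style
B, N, C = x.shape    # 0.4.1 style
print(B, N, C)  # 2 197 192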
birder/net/cas_vit.py
CHANGED
@@ -122,7 +122,7 @@ class AdditiveTokenMixer(nn.Module):
         self.proj_drop = nn.Dropout(proj_drop)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        (q, k, v) = self.qkv(x).chunk(3, dim=1)
+        q, k, v = self.qkv(x).chunk(3, dim=1)
         q = self.op_q(q)
         k = self.op_k(k)
 
birder/net/coat.py
CHANGED
@@ -21,7 +21,7 @@ from birder.net.base import DetectorBackbone
 
 
 def insert_cls(x: torch.Tensor, cls_token: torch.Tensor) -> torch.Tensor:
-    cls_tokens = cls_token.expand(x.shape[0], -1, -1)
+    cls_tokens = cls_token.expand(x.size(0), -1, -1)
     x = torch.concat((cls_tokens, x), dim=1)
 
     return x
@@ -57,8 +57,8 @@ class ConvRelPosEnc(nn.Module):
         self.channel_splits = [x * head_channels for x in head_splits]
 
     def forward(self, q: torch.Tensor, v: torch.Tensor, size: tuple[int, int]) -> torch.Tensor:
-        (B, num_heads, N, C) = q.size()
-        (H, W) = size
+        B, num_heads, N, C = q.size()
+        H, W = size
         torch._assert(N == 1 + H * W, "size mismatch")  # pylint: disable=protected-access
 
         # Convolutional relative position encoding.
@@ -102,11 +102,11 @@ class FactorAttnConvRelPosEnc(nn.Module):
         self.crpe = shared_crpe
 
     def forward(self, x: torch.Tensor, size: tuple[int, int]) -> torch.Tensor:
-        (B, N, C) = x.size()
+        B, N, C = x.size()
 
         # Generate Q, K, V
         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
-        (q, k, v) = qkv.unbind(0)  # [B, h, N, Ch]
+        q, k, v = qkv.unbind(0)  # [B, h, N, Ch]
 
         # Factorized attention
         k_softmax = k.softmax(dim=2)
@@ -135,8 +135,8 @@ class ConvPosEnc(nn.Module):
         )
 
     def forward(self, x: torch.Tensor, size: tuple[int, int]) -> torch.Tensor:
-        (B, N, C) = x.size()
-        (H, W) = size
+        B, N, C = x.size()
+        H, W = size
         torch._assert(N == 1 + H * W, "size mismatch")  # pylint: disable=protected-access
 
         # Extract CLS token and image tokens
@@ -170,7 +170,7 @@ class SerialBlock(nn.Module):
 
         # Conv-attention
         self.cpe = shared_cpe
-        self.norm1 = nn.LayerNorm(dim)
+        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
         self.factor_attn_crpe = FactorAttnConvRelPosEnc(
             dim,
             num_heads=num_heads,
@@ -181,7 +181,7 @@ class SerialBlock(nn.Module):
         self.drop_path = StochasticDepth(drop_path, mode="row")
 
         # MLP
-        self.norm2 = nn.LayerNorm(dim)
+        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
         self.mlp = MLP(dim, [int(dim * mlp_ratio), dim], activation_layer=nn.GELU, dropout=proj_drop)
 
     def forward(self, x: torch.Tensor, size: tuple[int, int]) -> torch.Tensor:
@@ -213,9 +213,9 @@ class ParallelBlock(nn.Module):
         super().__init__()
 
         # Conv-attention
-        self.norm12 = nn.LayerNorm(dims[1])
-        self.norm13 = nn.LayerNorm(dims[2])
-        self.norm14 = nn.LayerNorm(dims[3])
+        self.norm12 = nn.LayerNorm(dims[1], eps=1e-6)
+        self.norm13 = nn.LayerNorm(dims[2], eps=1e-6)
+        self.norm14 = nn.LayerNorm(dims[3], eps=1e-6)
         self.factor_attn_crpe2 = FactorAttnConvRelPosEnc(
             dims[1], num_heads=num_heads, qkv_bias=qkv_bias, proj_drop=proj_drop, shared_crpe=shared_crpes[1]
         )
@@ -228,9 +228,9 @@ class ParallelBlock(nn.Module):
         self.drop_path = StochasticDepth(drop_path, mode="row")
 
         # MLP
-        self.norm22 = nn.LayerNorm(dims[1])
-        self.norm23 = nn.LayerNorm(dims[2])
-        self.norm24 = nn.LayerNorm(dims[3])
+        self.norm22 = nn.LayerNorm(dims[1], eps=1e-6)
+        self.norm23 = nn.LayerNorm(dims[2], eps=1e-6)
+        self.norm24 = nn.LayerNorm(dims[3], eps=1e-6)
 
         # In the parallel block, we assume dimensions are the same and share the linear transformation
         assert dims[1] == dims[2] == dims[3]
@@ -244,8 +244,8 @@ class ParallelBlock(nn.Module):
         return self.interpolate(x, scale_factor=1.0 / factor, size=size)
 
     def interpolate(self, x: torch.Tensor, scale_factor: float, size: tuple[int, int]) -> torch.Tensor:
-        (B, N, C) = x.size()
-        (H, W) = size
+        B, N, C = x.size()
+        H, W = size
         torch._assert(N == 1 + H * W, "size mismatch")  # pylint: disable=protected-access
 
         cls_token = x[:, :1, :]
@@ -268,7 +268,7 @@ class ParallelBlock(nn.Module):
     def forward(
         self, x1: torch.Tensor, x2: torch.Tensor, x3: torch.Tensor, x4: torch.Tensor, sizes: list[tuple[int, int]]
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-        (_, s2, s3, s4) = sizes
+        _, s2, s3, s4 = sizes
         cur2 = self.norm12(x2)
         cur3 = self.norm13(x3)
         cur4 = self.norm14(x4)
@@ -310,7 +310,7 @@ class PatchEmbed(nn.Module):
 
     def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, tuple[int, int]]:
         x = self.proj(x)
-        (H, W) = x.shape[2:4]
+        H, W = x.shape[2:4]
 
         x = x.flatten(2).transpose(1, 2)
         x = self.norm(x)
@@ -447,13 +447,13 @@ class CoaT(DetectorBackbone):
 
         # Norms
         if self.parallel_blocks is not None:
-            self.norm2 = nn.LayerNorm(embed_dims[1])
-            self.norm3 = nn.LayerNorm(embed_dims[2])
+            self.norm2 = nn.LayerNorm(embed_dims[1], eps=1e-6)
+            self.norm3 = nn.LayerNorm(embed_dims[2], eps=1e-6)
         else:
             self.norm2 = None
             self.norm3 = None
 
-        self.norm4 = nn.LayerNorm(embed_dims[3])
+        self.norm4 = nn.LayerNorm(embed_dims[3], eps=1e-6)
 
         # Head
         if parallel_depth > 0:
@@ -500,7 +500,7 @@ class CoaT(DetectorBackbone):
         B = x.shape[0]
 
         # Serial blocks 1
-        (x1, (h1, w1)) = self.patch_embed1(x)
+        x1, (h1, w1) = self.patch_embed1(x)
         x1 = insert_cls(x1, self.cls_token1)
         for blk in self.serial_blocks1:
             x1 = blk(x1, size=(h1, w1))
@@ -508,7 +508,7 @@ class CoaT(DetectorBackbone):
         x1_no_cls = remove_cls(x1).reshape(B, h1, w1, -1).permute(0, 3, 1, 2).contiguous()
 
         # Serial blocks 2
-        (x2, (h2, w2)) = self.patch_embed2(x1_no_cls)
+        x2, (h2, w2) = self.patch_embed2(x1_no_cls)
         x2 = insert_cls(x2, self.cls_token2)
         for blk in self.serial_blocks2:
             x2 = blk(x2, size=(h2, w2))
@@ -516,7 +516,7 @@ class CoaT(DetectorBackbone):
         x2_no_cls = remove_cls(x2).reshape(B, h2, w2, -1).permute(0, 3, 1, 2).contiguous()
 
         # Serial blocks 3
-        (x3, (h3, w3)) = self.patch_embed3(x2_no_cls)
+        x3, (h3, w3) = self.patch_embed3(x2_no_cls)
         x3 = insert_cls(x3, self.cls_token3)
         for blk in self.serial_blocks3:
             x3 = blk(x3, size=(h3, w3))
@@ -524,7 +524,7 @@ class CoaT(DetectorBackbone):
         x3_no_cls = remove_cls(x3).reshape(B, h3, w3, -1).permute(0, 3, 1, 2).contiguous()
 
         # Serial blocks 4
-        (x4, (h4, w4)) = self.patch_embed4(x3_no_cls)
+        x4, (h4, w4) = self.patch_embed4(x3_no_cls)
         x4 = insert_cls(x4, self.cls_token4)
         for blk in self.serial_blocks4:
             x4 = blk(x4, size=(h4, w4))
@@ -537,7 +537,7 @@ class CoaT(DetectorBackbone):
             x2 = self.cpe2(x2, (h2, w2))
             x3 = self.cpe3(x3, (h3, w3))
             x4 = self.cpe4(x4, (h4, w4))
-            (x1, x2, x3, x4) = blk(x1, x2, x3, x4, sizes=[(h1, w1), (h2, w2), (h3, w3), (h4, w4)])
+            x1, x2, x3, x4 = blk(x1, x2, x3, x4, sizes=[(h1, w1), (h2, w2), (h3, w3), (h4, w4)])
 
         x1_no_cls = remove_cls(x1).reshape(B, h1, w1, -1).permute(0, 3, 1, 2).contiguous()
         x2_no_cls = remove_cls(x2).reshape(B, h2, w2, -1).permute(0, 3, 1, 2).contiguous()
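The reshape-permute-unbind idiom that coat.py settles on (and crossformer.py below uses as well) splits the packed QKV projection into three views without copying. A shape walk-through with illustrative sizes:

import torch

B, N, C, num_heads = 2, 50, 64, 8
qkv = torch.randn(B, N, 3 * C).reshape(B, N, 3, num_heads, C // num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)  # each is [B, h, N, Ch], matching the coat.py comment
print(q.shape)  # torch.Size([2, 8, 50, 8])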
birder/net/conv2former.py
CHANGED
@@ -64,7 +64,7 @@ class SpatialAttention(nn.Module):
                 dim,
                 kernel_size=kernel_size,
                 stride=(1, 1),
-                padding=(kernel_size[0] // 2, kernel_size[1] // 2),
+                padding=((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2),
                 groups=dim,
             ),
         )
@@ -87,8 +87,8 @@ class Conv2FormerBlock(nn.Module):
         self.mlp = MLP(dim, mlp_ratio)
 
         layer_scale_init_value = 1e-6
-        self.layer_scale_1 = nn.Parameter(layer_scale_init_value * torch.ones((1, dim, 1, 1)), requires_grad=True)
-        self.layer_scale_2 = nn.Parameter(layer_scale_init_value * torch.ones((1, dim, 1, 1)), requires_grad=True)
+        self.layer_scale_1 = nn.Parameter(layer_scale_init_value * torch.ones((1, dim, 1, 1)))
+        self.layer_scale_2 = nn.Parameter(layer_scale_init_value * torch.ones((1, dim, 1, 1)))
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = x + self.drop_path(self.layer_scale_1 * self.attn(x))
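The padding fix matters only for even kernels: k // 2 and (k - 1) // 2 agree for odd k, while for even k the old value produced an output larger than the input. An illustrative check (scalar kernel for brevity; SpatialAttention passes a tuple):

import torch
from torch import nn

x = torch.randn(1, 8, 32, 32)
k = 4  # even kernel; for odd k both paddings yield a 32x32 output
old = nn.Conv2d(8, 8, kernel_size=k, stride=1, padding=k // 2, groups=8)
new = nn.Conv2d(8, 8, kernel_size=k, stride=1, padding=(k - 1) // 2, groups=8)
print(old(x).shape)  # torch.Size([1, 8, 33, 33]) -- output grew
print(new(x).shape)  # torch.Size([1, 8, 31, 31])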
birder/net/convmixer.py
CHANGED
birder/net/convnext_v1.py
CHANGED
@@ -37,15 +37,7 @@ class ConvNeXtBlock(nn.Module):
     ) -> None:
         super().__init__()
         self.block = nn.Sequential(
-            nn.Conv2d(
-                channels,
-                channels,
-                kernel_size=(7, 7),
-                stride=(1, 1),
-                padding=(3, 3),
-                groups=channels,
-                bias=True,
-            ),
+            nn.Conv2d(channels, channels, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=channels),
             Permute([0, 2, 3, 1]),
             nn.LayerNorm(channels, eps=1e-6),
             nn.Linear(channels, 4 * channels),  # Same as 1x1 conv
@@ -53,7 +45,7 @@ class ConvNeXtBlock(nn.Module):
             nn.Linear(4 * channels, channels),  # Same as 1x1 conv
             Permute([0, 3, 1, 2]),
         )
-        self.layer_scale = nn.Parameter(torch.ones(channels, 1, 1) * layer_scale, requires_grad=True)
+        self.layer_scale = nn.Parameter(torch.ones(channels, 1, 1) * layer_scale)
         self.stochastic_depth = StochasticDepth(stochastic_depth_prob, mode="row")
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -119,7 +111,7 @@ class ConvNeXt_v1(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
             layers.append(
                 nn.Sequential(
                     LayerNorm2d(i, eps=1e-6),
-                    nn.Conv2d(i, out, kernel_size=(2, 2), stride=(2, 2), padding=(0, 0), bias=True),
+                    nn.Conv2d(i, out, kernel_size=(2, 2), stride=(2, 2), padding=(0, 0)),
                )
            )
 
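Both ConvNeXt files also drop explicit bias=True arguments (the removed lines' tails are garbled in this rendering and restored best-effort above). nn.Conv2d and nn.Parameter already default to bias=True and requires_grad=True respectively, so the change is purely cosmetic:

from torch import nn

a = nn.Conv2d(8, 16, kernel_size=(2, 2), stride=(2, 2), padding=(0, 0), bias=True)
b = nn.Conv2d(8, 16, kernel_size=(2, 2), stride=(2, 2), padding=(0, 0))
print(a.bias is not None, b.bias is not None)  # True True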
birder/net/convnext_v1_iso.py
ADDED

@@ -0,0 +1,198 @@
+"""
+ConvNeXt v1 Isotropic, adapted from
+https://github.com/facebookresearch/ConvNeXt/blob/main/models/convnext_isotropic.py
+
+Paper "A ConvNet for the 2020s", https://arxiv.org/abs/2201.03545
+"""
+
+# Reference license: MIT
+
+from functools import partial
+from typing import Any
+from typing import Literal
+from typing import Optional
+
+import torch
+from torch import nn
+from torchvision.ops import Permute
+from torchvision.ops import StochasticDepth
+
+from birder.common.masking import mask_tensor
+from birder.layers import LayerNorm2d
+from birder.model_registry import registry
+from birder.net.base import DetectorBackbone
+from birder.net.base import MaskedTokenRetentionMixin
+from birder.net.base import PreTrainEncoder
+from birder.net.base import TokenRetentionResultType
+from birder.net.base import normalize_out_indices
+
+
+class ConvNeXtBlock(nn.Module):
+    def __init__(self, channels: int, stochastic_depth_prob: float) -> None:
+        super().__init__()
+        self.block = nn.Sequential(
+            nn.Conv2d(channels, channels, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=channels),
+            Permute([0, 2, 3, 1]),
+            nn.LayerNorm(channels, eps=1e-6),
+            nn.Linear(channels, 4 * channels),
+            nn.GELU(),
+            nn.Linear(4 * channels, channels),
+            Permute([0, 3, 1, 2]),
+        )
+        self.stochastic_depth = StochasticDepth(stochastic_depth_prob, mode="row")
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        identity = x
+        x = self.block(x)
+        x = self.stochastic_depth(x)
+        x += identity
+
+        return x
+
+
+# pylint: disable=invalid-name
+class ConvNeXt_v1_Isotropic(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
+    block_group_regex = r"body\.(\d+)"
+
+    def __init__(
+        self,
+        input_channels: int,
+        num_classes: int,
+        *,
+        config: Optional[dict[str, Any]] = None,
+        size: Optional[tuple[int, int]] = None,
+    ) -> None:
+        super().__init__(input_channels, num_classes, config=config, size=size)
+        assert self.config is not None, "must set config"
+
+        patch_size = 16
+        dim: int = self.config["dim"]
+        num_layers: int = self.config["num_layers"]
+        out_indices: Optional[list[int]] = self.config.get("out_indices", None)
+        drop_path_rate: float = self.config["drop_path_rate"]
+
+        torch._assert(self.size[0] % patch_size == 0, "Input shape indivisible by patch size!")
+        torch._assert(self.size[1] % patch_size == 0, "Input shape indivisible by patch size!")
+        self.patch_size = patch_size
+        self.out_indices = normalize_out_indices(out_indices, num_layers)
+
+        self.stem = nn.Conv2d(
+            self.input_channels,
+            dim,
+            kernel_size=(patch_size, patch_size),
+            stride=(patch_size, patch_size),
+            padding=(0, 0),
+        )
+
+        layers = []
+        for idx in range(num_layers):
+            # Adjust stochastic depth probability based on the depth of the stage block
+            sd_prob = drop_path_rate * idx / (num_layers - 1.0)
+            layers.append(ConvNeXtBlock(dim, sd_prob))
+
+        self.body = nn.Sequential(*layers)
+        self.features = nn.Sequential(
+            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
+            LayerNorm2d(dim, eps=1e-6),
+            nn.Flatten(1),
+        )
+
+        num_return_stages = len(self.out_indices) if self.out_indices is not None else 1
+        self.return_stages = [f"stage{stage_idx + 1}" for stage_idx in range(num_return_stages)]
+        self.return_channels = [dim] * num_return_stages
+        self.embedding_size = dim
+        self.classifier = self.create_classifier()
+
+        self.max_stride = patch_size
+        self.stem_stride = patch_size
+        self.stem_width = dim
+        self.encoding_size = dim
+        self.decoder_block = partial(ConvNeXtBlock, stochastic_depth_prob=0)
+
+        # Weights initialization
+        for m in self.modules():
+            if isinstance(m, (nn.Conv2d, nn.Linear)):
+                nn.init.trunc_normal_(m.weight, std=0.02)
+                if m.bias is not None:
+                    nn.init.zeros_(m.bias)
+
+    def detection_features(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
+        x = self.stem(x)
+
+        if self.out_indices is None:
+            x = self.body(x)
+            return {self.return_stages[0]: x}
+
+        stage_num = 0
+        out: dict[str, torch.Tensor] = {}
+        for idx, module in enumerate(self.body.children()):
+            x = module(x)
+            if idx in self.out_indices:
+                out[self.return_stages[stage_num]] = x
+                stage_num += 1
+
+        return out
+
+    def freeze_stages(self, up_to_stage: int) -> None:
+        for param in self.stem.parameters():
+            param.requires_grad_(False)
+
+        for idx, module in enumerate(self.body.children()):
+            if idx >= up_to_stage:
+                break
+
+            for param in module.parameters():
+                param.requires_grad_(False)
+
+    def masked_encoding_retention(
+        self,
+        x: torch.Tensor,
+        mask: torch.Tensor,
+        mask_token: Optional[torch.Tensor] = None,
+        return_keys: Literal["all", "features", "embedding"] = "features",
+    ) -> TokenRetentionResultType:
+        x = self.stem(x)
+        x = mask_tensor(x, mask, patch_factor=self.max_stride // self.stem_stride, mask_token=mask_token)
+        x = self.body(x)
+
+        result: TokenRetentionResultType = {}
+        if return_keys in ("all", "features"):
+            result["features"] = x
+        if return_keys in ("all", "embedding"):
+            result["embedding"] = self.features(x)
+
+        return result
+
+    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.stem(x)
+        return self.body(x)
+
+    def embedding(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.forward_features(x)
+        return self.features(x)
+
+    def adjust_size(self, new_size: tuple[int, int]) -> None:
+        if new_size == self.size:
+            return
+
+        assert new_size[0] % self.patch_size == 0, "Input shape indivisible by patch size!"
+        assert new_size[1] % self.patch_size == 0, "Input shape indivisible by patch size!"
+
+        super().adjust_size(new_size)
+
+
+registry.register_model_config(
+    "convnext_v1_iso_small",
+    ConvNeXt_v1_Isotropic,
+    config={"dim": 384, "num_layers": 18, "drop_path_rate": 0.1},
+)
+registry.register_model_config(
+    "convnext_v1_iso_base",
+    ConvNeXt_v1_Isotropic,
+    config={"in_channels": 768, "num_layers": 18, "drop_path_rate": 0.2},
+)
+registry.register_model_config(
+    "convnext_v1_iso_large",
+    ConvNeXt_v1_Isotropic,
+    config={"in_channels": 1024, "num_layers": 36, "drop_path_rate": 0.5},
+)
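One inconsistency worth flagging in the new file: the class reads self.config["dim"], but the convnext_v1_iso_base and convnext_v1_iso_large registrations pass "in_channels" (768 and 1024, the isotropic base/large widths), which would raise a KeyError if those configs were instantiated as-is. A minimal construction sketch using the consistent small config; it assumes the usual birder BaseNet flow, where calling the network runs the embedding through the classifier:

import torch
from birder.net.convnext_v1_iso import ConvNeXt_v1_Isotropic

# Hypothetical usage; config keys taken from the "small" registration above
net = ConvNeXt_v1_Isotropic(
    3,    # input_channels
    100,  # num_classes
    config={"dim": 384, "num_layers": 18, "drop_path_rate": 0.1},
    size=(256, 256),  # must be divisible by the fixed patch size of 16
)
out = net(torch.randn(1, 3, 256, 256))
print(out.shape)  # torch.Size([1, 100]), assuming the standard BaseNet forward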
birder/net/convnext_v2.py
CHANGED
@@ -56,15 +56,7 @@ class ConvNeXtBlock(nn.Module):
     ) -> None:
         super().__init__()
         self.block = nn.Sequential(
-            nn.Conv2d(
-                channels,
-                channels,
-                kernel_size=(7, 7),
-                stride=(1, 1),
-                padding=(3, 3),
-                groups=channels,
-                bias=True,
-            ),
+            nn.Conv2d(channels, channels, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=channels),
             Permute([0, 2, 3, 1]),
             nn.LayerNorm(channels, eps=1e-6),
             nn.Linear(channels, 4 * channels),  # Same as 1x1 conv
@@ -137,7 +129,7 @@ class ConvNeXt_v2(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
             layers.append(
                 nn.Sequential(
                     LayerNorm2d(i, eps=1e-6),
-                    nn.Conv2d(i, out, kernel_size=(2, 2), stride=(2, 2), padding=(0, 0), bias=True),
+                    nn.Conv2d(i, out, kernel_size=(2, 2), stride=(2, 2), padding=(0, 0)),
                )
            )
 
birder/net/crossformer.py
CHANGED
@@ -120,9 +120,9 @@ class Attention(nn.Module):
         self.relative_position_index = nn.Buffer(relative_position_index)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        (B, N, C) = x.size()
+        B, N, C = x.size()
         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
-        (q, k, v) = qkv.unbind(0)
+        q, k, v = qkv.unbind(0)
 
         q = q * self.scale
         attn = q @ k.transpose(-2, -1)
@@ -188,15 +188,15 @@ class CrossFormerBlock(nn.Module):
         self.drop_path = StochasticDepth(drop_path, mode="row")
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        (H, W) = self.input_resolution
-        (B, _, C) = x.size()
+        H, W = self.input_resolution
+        B, _, C = x.size()
 
         shortcut = x
         x = self.norm1(x)
         x = x.view(B, H, W, C)
 
         # Group embeddings
-        (GH, GW) = self.group_size  # pylint: disable=invalid-name
+        GH, GW = self.group_size  # pylint: disable=invalid-name
         if self.use_lda is False:
             x = x.reshape(B, H // GH, GH, W // GW, GW, C).permute(0, 1, 3, 2, 4, 5)
         else:
@@ -244,8 +244,8 @@ class PatchMerging(nn.Module):
         )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        (H, W) = self.input_resolution
-        (B, _, C) = x.shape
+        H, W = self.input_resolution
+        B, _, C = x.shape
 
         x = self.norm(x)
         x = x.view(B, H, W, C).permute(0, 3, 1, 2)
@@ -396,8 +396,8 @@ class CrossFormer(DetectorBackbone):
         for name, module in self.body.named_children():
             x = module(x)
             if name in self.return_stages:
-                (H, W) = module.resolution
-                (B, _, C) = x.size()
+                H, W = module.resolution
+                B, _, C = x.size()
                 out[name] = x.view(B, H, W, C).permute(0, 3, 1, 2).contiguous()
 
         return out
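The group-embedding reshape in CrossFormerBlock partitions a (B, H, W, C) map into a grid of GH x GW groups so attention can run within each group. With illustrative sizes:

import torch

B, H, W, C, GH, GW = 2, 8, 8, 16, 4, 4
x = torch.randn(B, H, W, C)
x = x.reshape(B, H // GH, GH, W // GW, GW, C).permute(0, 1, 3, 2, 4, 5)
print(x.shape)  # torch.Size([2, 2, 2, 4, 4, 16]) -- a 2x2 grid of 4x4 groups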
birder/net/crossvit.py
CHANGED
@@ -74,7 +74,7 @@ class CrossAttention(nn.Module):
         self.proj_drop = nn.Dropout(proj_drop)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        (B, N, C) = x.shape
+        B, N, C = x.shape
         # B1C -> B1H(C/H) -> BH1(C/H)
         q = self.wq(x[:, 0:1, ...]).reshape(B, 1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
         # BNC -> BNH(C/H) -> BHN(C/H)
@@ -97,7 +97,7 @@ class CrossAttentionBlock(nn.Module):
         self, dim: int, num_heads: int, qkv_bias: bool, proj_drop: float, attn_drop: float, drop_path: float
     ) -> None:
         super().__init__()
-        self.norm1 = nn.LayerNorm(dim)
+        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
         self.attn = CrossAttention(
             dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=proj_drop
         )
@@ -146,7 +146,7 @@ class MultiScaleBlock(nn.Module):
         for d in range(num_branches):
             self.projs.append(
                 nn.Sequential(
-                    nn.LayerNorm(dim[d]),
+                    nn.LayerNorm(dim[d], eps=1e-6),
                     nn.GELU(),
                     nn.Linear(dim[d], dim[(d + 1) % num_branches]),
                 )
@@ -187,7 +187,7 @@ class MultiScaleBlock(nn.Module):
         for d in range(num_branches):
             self.revert_projs.append(
                 nn.Sequential(
-                    nn.LayerNorm(dim[(d + 1) % num_branches]),
+                    nn.LayerNorm(dim[(d + 1) % num_branches], eps=1e-6),
                     nn.GELU(),
                     nn.Linear(dim[(d + 1) % num_branches], dim[d]),
                 )
@@ -290,7 +290,7 @@ class CrossViT(BaseNet):
             dpr_ptr += curr_depth
             self.blocks.append(block)
 
-        self.norm = nn.ModuleList([nn.LayerNorm(embed_dim[i]) for i in range(self.num_branches)])
+        self.norm = nn.ModuleList([nn.LayerNorm(embed_dim[i], eps=1e-6) for i in range(self.num_branches)])
         self.embedding_size = sum(self.embed_dim)
         self.classifier = nn.ModuleList()
         for i in range(self.num_branches):
@@ -482,7 +482,7 @@ registry.register_weights(
         "formats": {
             "pt": {
                 "file_size": 32.7,
-                "sha256": "
+                "sha256": "08f674d8165dc97cc535f8188a5c5361751a8d0bb85061454986a21541a6fe8e",
             }
         },
         "net": {"network": "crossvit_9d", "tag": "il-common"},
birder/net/cspnet.py
CHANGED
@@ -226,7 +226,7 @@ class CrossStage(nn.Module):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.conv_down(x)
         x = self.conv_exp(x)
-        (xs, xb) = x.split(self.expand_channels // 2, dim=1)
+        xs, xb = x.split(self.expand_channels // 2, dim=1)
         xb = self.blocks(xb)
         xb = self.conv_transition_b(xb).contiguous()
         out = self.conv_transition(torch.concat([xs, xb], dim=1))
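CrossStage routes half of the expanded channels around the block stack and the other half through it; Tensor.split performs the channel-wise halving. A sketch with an illustrative expand_channels of 64:

import torch

x = torch.randn(2, 64, 32, 32)
xs, xb = x.split(64 // 2, dim=1)  # shortcut half and processed half
print(xs.shape, xb.shape)  # both torch.Size([2, 32, 32, 32])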
|