birder 0.3.3__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- birder/common/fs_ops.py +2 -2
- birder/introspection/attention_rollout.py +1 -1
- birder/introspection/transformer_attribution.py +1 -1
- birder/layers/layer_scale.py +1 -1
- birder/net/__init__.py +2 -10
- birder/net/_rope_vit_configs.py +430 -0
- birder/net/_vit_configs.py +479 -0
- birder/net/biformer.py +1 -0
- birder/net/cait.py +5 -5
- birder/net/coat.py +12 -12
- birder/net/conv2former.py +3 -3
- birder/net/convmixer.py +1 -1
- birder/net/convnext_v1.py +1 -1
- birder/net/crossvit.py +5 -5
- birder/net/davit.py +1 -1
- birder/net/deit.py +12 -26
- birder/net/deit3.py +42 -189
- birder/net/densenet.py +9 -8
- birder/net/detection/deformable_detr.py +5 -2
- birder/net/detection/detr.py +5 -2
- birder/net/detection/efficientdet.py +1 -1
- birder/net/dpn.py +1 -2
- birder/net/edgenext.py +2 -1
- birder/net/edgevit.py +3 -0
- birder/net/efficientformer_v1.py +2 -1
- birder/net/efficientformer_v2.py +18 -31
- birder/net/efficientnet_v2.py +3 -0
- birder/net/efficientvit_mit.py +5 -5
- birder/net/fasternet.py +2 -2
- birder/net/flexivit.py +22 -43
- birder/net/groupmixformer.py +1 -1
- birder/net/hgnet_v1.py +5 -5
- birder/net/inception_next.py +1 -1
- birder/net/inception_resnet_v1.py +3 -3
- birder/net/inception_resnet_v2.py +7 -4
- birder/net/inception_v3.py +3 -0
- birder/net/inception_v4.py +3 -0
- birder/net/maxvit.py +1 -1
- birder/net/metaformer.py +3 -3
- birder/net/mim/crossmae.py +1 -1
- birder/net/mim/mae_vit.py +1 -1
- birder/net/mim/simmim.py +1 -1
- birder/net/mobilenet_v1.py +0 -9
- birder/net/mobilenet_v2.py +38 -44
- birder/net/{mobilenet_v3_large.py → mobilenet_v3.py} +37 -10
- birder/net/mobilevit_v1.py +5 -32
- birder/net/mobilevit_v2.py +1 -45
- birder/net/moganet.py +8 -5
- birder/net/mvit_v2.py +6 -6
- birder/net/nfnet.py +4 -0
- birder/net/pit.py +1 -1
- birder/net/pvt_v1.py +5 -5
- birder/net/pvt_v2.py +5 -5
- birder/net/repghost.py +1 -30
- birder/net/resmlp.py +2 -2
- birder/net/resnest.py +3 -0
- birder/net/resnet_v1.py +125 -1
- birder/net/resnet_v2.py +75 -1
- birder/net/resnext.py +35 -1
- birder/net/rope_deit3.py +33 -136
- birder/net/rope_flexivit.py +18 -18
- birder/net/rope_vit.py +3 -735
- birder/net/simple_vit.py +22 -16
- birder/net/smt.py +1 -1
- birder/net/squeezenet.py +5 -12
- birder/net/squeezenext.py +0 -24
- birder/net/ssl/capi.py +1 -1
- birder/net/ssl/data2vec.py +1 -1
- birder/net/ssl/dino_v2.py +2 -2
- birder/net/ssl/franca.py +2 -2
- birder/net/ssl/i_jepa.py +1 -1
- birder/net/ssl/ibot.py +1 -1
- birder/net/swiftformer.py +12 -2
- birder/net/swin_transformer_v2.py +1 -1
- birder/net/tiny_vit.py +3 -16
- birder/net/van.py +2 -2
- birder/net/vit.py +35 -963
- birder/net/vit_sam.py +13 -38
- birder/net/xcit.py +7 -6
- birder/tools/introspection.py +1 -1
- birder/tools/model_info.py +3 -1
- birder/version.py +1 -1
- {birder-0.3.3.dist-info → birder-0.4.0.dist-info}/METADATA +1 -1
- {birder-0.3.3.dist-info → birder-0.4.0.dist-info}/RECORD +88 -90
- birder/net/mobilenet_v3_small.py +0 -43
- birder/net/se_resnet_v1.py +0 -105
- birder/net/se_resnet_v2.py +0 -59
- birder/net/se_resnext.py +0 -30
- {birder-0.3.3.dist-info → birder-0.4.0.dist-info}/WHEEL +0 -0
- {birder-0.3.3.dist-info → birder-0.4.0.dist-info}/entry_points.txt +0 -0
- {birder-0.3.3.dist-info → birder-0.4.0.dist-info}/licenses/LICENSE +0 -0
- {birder-0.3.3.dist-info → birder-0.4.0.dist-info}/top_level.txt +0 -0
birder/net/simple_vit.py
CHANGED
|
@@ -20,6 +20,12 @@ import torch
|
|
|
20
20
|
from torch import nn
|
|
21
21
|
|
|
22
22
|
from birder.model_registry import registry
|
|
23
|
+
from birder.net._vit_configs import BASE
|
|
24
|
+
from birder.net._vit_configs import GIANT
|
|
25
|
+
from birder.net._vit_configs import HUGE
|
|
26
|
+
from birder.net._vit_configs import LARGE
|
|
27
|
+
from birder.net._vit_configs import MEDIUM
|
|
28
|
+
from birder.net._vit_configs import SMALL
|
|
23
29
|
from birder.net.base import MaskedTokenOmissionMixin
|
|
24
30
|
from birder.net.base import PreTrainEncoder
|
|
25
31
|
from birder.net.base import TokenOmissionResultType
|
|
@@ -45,12 +51,12 @@ class Simple_ViT(PreTrainEncoder, MaskedTokenOmissionMixin):
|
|
|
45
51
|
assert self.config is not None, "must set config"
|
|
46
52
|
|
|
47
53
|
image_size = self.size
|
|
48
|
-
drop_path_rate = 0.0
|
|
49
54
|
patch_size: int = self.config["patch_size"]
|
|
50
55
|
num_layers: int = self.config["num_layers"]
|
|
51
56
|
num_heads: int = self.config["num_heads"]
|
|
52
57
|
hidden_dim: int = self.config["hidden_dim"]
|
|
53
58
|
mlp_dim: int = self.config["mlp_dim"]
|
|
59
|
+
drop_path_rate: float = self.config["drop_path_rate"]
|
|
54
60
|
|
|
55
61
|
torch._assert(image_size[0] % patch_size == 0, "Input shape indivisible by patch size!")
|
|
56
62
|
torch._assert(image_size[1] % patch_size == 0, "Input shape indivisible by patch size!")
|
|
@@ -215,75 +221,75 @@ class Simple_ViT(PreTrainEncoder, MaskedTokenOmissionMixin):
|
|
|
215
221
|
registry.register_model_config(
|
|
216
222
|
"simple_vit_s32",
|
|
217
223
|
Simple_ViT,
|
|
218
|
-
config={"patch_size": 32,
|
|
224
|
+
config={"patch_size": 32, **SMALL},
|
|
219
225
|
)
|
|
220
226
|
registry.register_model_config(
|
|
221
227
|
"simple_vit_s16",
|
|
222
228
|
Simple_ViT,
|
|
223
|
-
config={"patch_size": 16,
|
|
229
|
+
config={"patch_size": 16, **SMALL},
|
|
224
230
|
)
|
|
225
231
|
registry.register_model_config(
|
|
226
232
|
"simple_vit_s14",
|
|
227
233
|
Simple_ViT,
|
|
228
|
-
config={"patch_size": 14,
|
|
234
|
+
config={"patch_size": 14, **SMALL},
|
|
229
235
|
)
|
|
230
236
|
registry.register_model_config(
|
|
231
237
|
"simple_vit_m32",
|
|
232
238
|
Simple_ViT,
|
|
233
|
-
config={"patch_size": 32,
|
|
239
|
+
config={"patch_size": 32, **MEDIUM},
|
|
234
240
|
)
|
|
235
241
|
registry.register_model_config(
|
|
236
242
|
"simple_vit_m16",
|
|
237
243
|
Simple_ViT,
|
|
238
|
-
config={"patch_size": 16,
|
|
244
|
+
config={"patch_size": 16, **MEDIUM},
|
|
239
245
|
)
|
|
240
246
|
registry.register_model_config(
|
|
241
247
|
"simple_vit_m14",
|
|
242
248
|
Simple_ViT,
|
|
243
|
-
config={"patch_size": 14,
|
|
249
|
+
config={"patch_size": 14, **MEDIUM},
|
|
244
250
|
)
|
|
245
251
|
registry.register_model_config(
|
|
246
252
|
"simple_vit_b32",
|
|
247
253
|
Simple_ViT,
|
|
248
|
-
config={"patch_size": 32,
|
|
254
|
+
config={"patch_size": 32, **BASE}, # Override the BASE definition
|
|
249
255
|
)
|
|
250
256
|
registry.register_model_config(
|
|
251
257
|
"simple_vit_b16",
|
|
252
258
|
Simple_ViT,
|
|
253
|
-
config={"patch_size": 16,
|
|
259
|
+
config={"patch_size": 16, **BASE},
|
|
254
260
|
)
|
|
255
261
|
registry.register_model_config(
|
|
256
262
|
"simple_vit_b14",
|
|
257
263
|
Simple_ViT,
|
|
258
|
-
config={"patch_size": 14,
|
|
264
|
+
config={"patch_size": 14, **BASE},
|
|
259
265
|
)
|
|
260
266
|
registry.register_model_config(
|
|
261
267
|
"simple_vit_l32",
|
|
262
268
|
Simple_ViT,
|
|
263
|
-
config={"patch_size": 32,
|
|
269
|
+
config={"patch_size": 32, **LARGE},
|
|
264
270
|
)
|
|
265
271
|
registry.register_model_config(
|
|
266
272
|
"simple_vit_l16",
|
|
267
273
|
Simple_ViT,
|
|
268
|
-
config={"patch_size": 16,
|
|
274
|
+
config={"patch_size": 16, **LARGE},
|
|
269
275
|
)
|
|
270
276
|
registry.register_model_config(
|
|
271
277
|
"simple_vit_l14",
|
|
272
278
|
Simple_ViT,
|
|
273
|
-
config={"patch_size": 14,
|
|
279
|
+
config={"patch_size": 14, **LARGE},
|
|
274
280
|
)
|
|
275
281
|
registry.register_model_config(
|
|
276
282
|
"simple_vit_h16",
|
|
277
283
|
Simple_ViT,
|
|
278
|
-
config={"patch_size": 16,
|
|
284
|
+
config={"patch_size": 16, **HUGE},
|
|
279
285
|
)
|
|
280
286
|
registry.register_model_config(
|
|
281
287
|
"simple_vit_h14",
|
|
282
288
|
Simple_ViT,
|
|
283
|
-
config={"patch_size": 14,
|
|
289
|
+
config={"patch_size": 14, **HUGE},
|
|
284
290
|
)
|
|
285
291
|
registry.register_model_config( # From "Scaling Vision Transformers"
|
|
286
292
|
"simple_vit_g14",
|
|
287
293
|
Simple_ViT,
|
|
288
|
-
config={"patch_size": 14,
|
|
294
|
+
config={"patch_size": 14, **GIANT},
|
|
289
295
|
)
|
birder/net/smt.py
CHANGED
|
@@ -259,7 +259,7 @@ class Stem(nn.Module):
|
|
|
259
259
|
embed_dim,
|
|
260
260
|
kernel_size=kernel_size,
|
|
261
261
|
stride=stride,
|
|
262
|
-
padding=(kernel_size[0] // 2, kernel_size[1] // 2),
|
|
262
|
+
padding=((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2),
|
|
263
263
|
),
|
|
264
264
|
nn.Conv2d(embed_dim, embed_dim, kernel_size=(2, 2), stride=(2, 2), padding=(0, 0)),
|
|
265
265
|
)
|
birder/net/squeezenet.py
CHANGED
|
@@ -20,11 +20,11 @@ from birder.net.base import BaseNet
|
|
|
20
20
|
class Fire(nn.Module):
|
|
21
21
|
def __init__(self, in_planes: int, squeeze: int, expand: int) -> None:
|
|
22
22
|
super().__init__()
|
|
23
|
-
self.squeeze = nn.Conv2d(in_planes, squeeze, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)
|
|
23
|
+
self.squeeze = nn.Conv2d(in_planes, squeeze, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))
|
|
24
24
|
self.squeeze_activation = nn.ReLU(inplace=True)
|
|
25
|
-
self.left = nn.Conv2d(squeeze, expand, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)
|
|
25
|
+
self.left = nn.Conv2d(squeeze, expand, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))
|
|
26
26
|
self.left_activation = nn.ReLU(inplace=True)
|
|
27
|
-
self.right = nn.Conv2d(squeeze, expand, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
|
|
27
|
+
self.right = nn.Conv2d(squeeze, expand, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
|
28
28
|
self.right_activation = nn.ReLU(inplace=True)
|
|
29
29
|
|
|
30
30
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
@@ -53,7 +53,7 @@ class SqueezeNet(BaseNet):
|
|
|
53
53
|
assert self.config is None, "config not supported"
|
|
54
54
|
|
|
55
55
|
self.stem = nn.Sequential(
|
|
56
|
-
nn.Conv2d(self.input_channels, 64, kernel_size=(3, 3), stride=(2, 2), padding=(0, 0)
|
|
56
|
+
nn.Conv2d(self.input_channels, 64, kernel_size=(3, 3), stride=(2, 2), padding=(0, 0)),
|
|
57
57
|
nn.ReLU(inplace=True),
|
|
58
58
|
nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2), padding=(0, 0), ceil_mode=True),
|
|
59
59
|
)
|
|
@@ -94,14 +94,7 @@ class SqueezeNet(BaseNet):
|
|
|
94
94
|
|
|
95
95
|
return nn.Sequential(
|
|
96
96
|
nn.Dropout(p=0.5, inplace=True),
|
|
97
|
-
nn.Conv2d(
|
|
98
|
-
embed_dim,
|
|
99
|
-
self.num_classes,
|
|
100
|
-
kernel_size=(1, 1),
|
|
101
|
-
stride=(1, 1),
|
|
102
|
-
padding=(0, 0),
|
|
103
|
-
bias=False,
|
|
104
|
-
),
|
|
97
|
+
nn.Conv2d(embed_dim, self.num_classes, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)),
|
|
105
98
|
nn.ReLU(inplace=True),
|
|
106
99
|
nn.AdaptiveAvgPool2d(output_size=(1, 1)),
|
|
107
100
|
nn.Flatten(1),
|
birder/net/squeezenext.py
CHANGED
|
@@ -27,7 +27,6 @@ class SqnxtUnit(nn.Module):
|
|
|
27
27
|
kernel_size=(1, 1),
|
|
28
28
|
stride=(stride, stride),
|
|
29
29
|
padding=(0, 0),
|
|
30
|
-
bias=False,
|
|
31
30
|
)
|
|
32
31
|
|
|
33
32
|
elif in_channels > out_channels:
|
|
@@ -38,7 +37,6 @@ class SqnxtUnit(nn.Module):
|
|
|
38
37
|
kernel_size=(1, 1),
|
|
39
38
|
stride=(stride, stride),
|
|
40
39
|
padding=(0, 0),
|
|
41
|
-
bias=False,
|
|
42
40
|
)
|
|
43
41
|
|
|
44
42
|
else:
|
|
@@ -52,7 +50,6 @@ class SqnxtUnit(nn.Module):
|
|
|
52
50
|
kernel_size=(1, 1),
|
|
53
51
|
stride=(stride, stride),
|
|
54
52
|
padding=(0, 0),
|
|
55
|
-
bias=False,
|
|
56
53
|
),
|
|
57
54
|
Conv2dNormActivation(
|
|
58
55
|
in_channels // reduction,
|
|
@@ -60,7 +57,6 @@ class SqnxtUnit(nn.Module):
|
|
|
60
57
|
kernel_size=(1, 1),
|
|
61
58
|
stride=(1, 1),
|
|
62
59
|
padding=(0, 0),
|
|
63
|
-
bias=False,
|
|
64
60
|
),
|
|
65
61
|
Conv2dNormActivation(
|
|
66
62
|
in_channels // (2 * reduction),
|
|
@@ -68,7 +64,6 @@ class SqnxtUnit(nn.Module):
|
|
|
68
64
|
kernel_size=(1, 3),
|
|
69
65
|
stride=(1, 1),
|
|
70
66
|
padding=(0, 1),
|
|
71
|
-
bias=False,
|
|
72
67
|
),
|
|
73
68
|
Conv2dNormActivation(
|
|
74
69
|
in_channels // reduction,
|
|
@@ -76,7 +71,6 @@ class SqnxtUnit(nn.Module):
|
|
|
76
71
|
kernel_size=(3, 1),
|
|
77
72
|
stride=(1, 1),
|
|
78
73
|
padding=(1, 0),
|
|
79
|
-
bias=False,
|
|
80
74
|
),
|
|
81
75
|
Conv2dNormActivation(
|
|
82
76
|
in_channels // reduction,
|
|
@@ -84,7 +78,6 @@ class SqnxtUnit(nn.Module):
|
|
|
84
78
|
kernel_size=(1, 1),
|
|
85
79
|
stride=(1, 1),
|
|
86
80
|
padding=(0, 0),
|
|
87
|
-
bias=False,
|
|
88
81
|
),
|
|
89
82
|
)
|
|
90
83
|
self.relu = nn.ReLU(inplace=True)
|
|
@@ -124,7 +117,6 @@ class SqueezeNext(DetectorBackbone):
|
|
|
124
117
|
kernel_size=(7, 7),
|
|
125
118
|
stride=(2, 2),
|
|
126
119
|
padding=(1, 1),
|
|
127
|
-
bias=False,
|
|
128
120
|
),
|
|
129
121
|
nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2), padding=(0, 0), ceil_mode=True),
|
|
130
122
|
)
|
|
@@ -155,7 +147,6 @@ class SqueezeNext(DetectorBackbone):
|
|
|
155
147
|
kernel_size=(1, 1),
|
|
156
148
|
stride=(1, 1),
|
|
157
149
|
padding=(0, 0),
|
|
158
|
-
bias=False,
|
|
159
150
|
),
|
|
160
151
|
nn.AdaptiveAvgPool2d(output_size=(1, 1)),
|
|
161
152
|
nn.Flatten(1),
|
|
@@ -199,18 +190,3 @@ registry.register_model_config("squeezenext_0_5", SqueezeNext, config={"width_sc
|
|
|
199
190
|
registry.register_model_config("squeezenext_1_0", SqueezeNext, config={"width_scale": 1.0})
|
|
200
191
|
registry.register_model_config("squeezenext_1_5", SqueezeNext, config={"width_scale": 1.5})
|
|
201
192
|
registry.register_model_config("squeezenext_2_0", SqueezeNext, config={"width_scale": 2.0})
|
|
202
|
-
|
|
203
|
-
registry.register_weights(
|
|
204
|
-
"squeezenext_1_0_il-common",
|
|
205
|
-
{
|
|
206
|
-
"description": "SqueezeNext v2 1.0x output channels model trained on the il-common dataset",
|
|
207
|
-
"resolution": (259, 259),
|
|
208
|
-
"formats": {
|
|
209
|
-
"pt": {
|
|
210
|
-
"file_size": 3.5,
|
|
211
|
-
"sha256": "da01d1cd05c71b80b5e4e6ca66400f64fa3f6179d0e90834c4f6942c8095557a",
|
|
212
|
-
}
|
|
213
|
-
},
|
|
214
|
-
"net": {"network": "squeezenext_1_0", "tag": "il-common"},
|
|
215
|
-
},
|
|
216
|
-
)
|
birder/net/ssl/capi.py
CHANGED
|
@@ -306,7 +306,7 @@ class Decoder(nn.Module):
|
|
|
306
306
|
dim=decoder_embed_dim,
|
|
307
307
|
num_special_tokens=0,
|
|
308
308
|
).unsqueeze(0)
|
|
309
|
-
self.decoder_pos_embed = nn.
|
|
309
|
+
self.decoder_pos_embed = nn.Buffer(pos_embedding)
|
|
310
310
|
|
|
311
311
|
self.decoder_layers = nn.ModuleList()
|
|
312
312
|
for _ in range(decoder_depth):
|
birder/net/ssl/data2vec.py
CHANGED
|
@@ -51,7 +51,7 @@ class Data2Vec(SSLBaseNet):
|
|
|
51
51
|
self.ema_backbone = copy.deepcopy(self.backbone)
|
|
52
52
|
self.head = nn.Linear(self.backbone.embedding_size, self.backbone.embedding_size)
|
|
53
53
|
|
|
54
|
-
self.mask_token = nn.Parameter(torch.zeros(1, 1, 1, self.backbone.stem_width)
|
|
54
|
+
self.mask_token = nn.Parameter(torch.zeros(1, 1, 1, self.backbone.stem_width))
|
|
55
55
|
|
|
56
56
|
# Weights initialization
|
|
57
57
|
self.ema_backbone.load_state_dict(self.backbone.state_dict())
|
birder/net/ssl/dino_v2.py
CHANGED
|
@@ -460,7 +460,7 @@ class DINOv2Student(SSLBaseNet):
|
|
|
460
460
|
bottleneck_dim=head_bottleneck_dim,
|
|
461
461
|
)
|
|
462
462
|
|
|
463
|
-
self.mask_token = nn.Parameter(torch.zeros(1, 1, 1, self.backbone.stem_width)
|
|
463
|
+
self.mask_token = nn.Parameter(torch.zeros(1, 1, 1, self.backbone.stem_width))
|
|
464
464
|
|
|
465
465
|
# pylint: disable=arguments-differ
|
|
466
466
|
def forward( # type: ignore[override]
|
|
@@ -543,7 +543,7 @@ class DINOv2Teacher(SSLBaseNet):
|
|
|
543
543
|
)
|
|
544
544
|
|
|
545
545
|
# Unused, Makes for an easier EMA update
|
|
546
|
-
self.mask_token = nn.Parameter(torch.zeros(1, 1, 1, self.backbone.stem_width)
|
|
546
|
+
self.mask_token = nn.Parameter(torch.zeros(1, 1, 1, self.backbone.stem_width))
|
|
547
547
|
|
|
548
548
|
# pylint: disable=arguments-differ
|
|
549
549
|
def forward( # type: ignore[override]
|
birder/net/ssl/franca.py
CHANGED
|
@@ -433,7 +433,7 @@ class FrancaStudent(SSLBaseNet):
|
|
|
433
433
|
nesting_list=nesting_list,
|
|
434
434
|
)
|
|
435
435
|
|
|
436
|
-
self.mask_token = nn.Parameter(torch.zeros(1, 1, 1, self.backbone.stem_width)
|
|
436
|
+
self.mask_token = nn.Parameter(torch.zeros(1, 1, 1, self.backbone.stem_width))
|
|
437
437
|
|
|
438
438
|
# pylint: disable=arguments-differ
|
|
439
439
|
def forward( # type: ignore[override]
|
|
@@ -523,7 +523,7 @@ class FrancaTeacher(SSLBaseNet):
|
|
|
523
523
|
)
|
|
524
524
|
|
|
525
525
|
# Unused, Makes for an easier EMA update
|
|
526
|
-
self.mask_token = nn.Parameter(torch.zeros(1, 1, 1, self.backbone.stem_width)
|
|
526
|
+
self.mask_token = nn.Parameter(torch.zeros(1, 1, 1, self.backbone.stem_width))
|
|
527
527
|
|
|
528
528
|
# pylint: disable=arguments-differ
|
|
529
529
|
def forward( # type: ignore[override]
|
birder/net/ssl/i_jepa.py
CHANGED
|
@@ -200,7 +200,7 @@ class VisionTransformerPredictor(nn.Module):
|
|
|
200
200
|
self.mask_token = nn.Parameter(torch.zeros(1, 1, predictor_embed_dim))
|
|
201
201
|
|
|
202
202
|
pos_embedding = pos_embedding_sin_cos_2d(h=size[0], w=size[1], dim=predictor_embed_dim, num_special_tokens=0)
|
|
203
|
-
self.pos_embedding = nn.
|
|
203
|
+
self.pos_embedding = nn.Buffer(pos_embedding)
|
|
204
204
|
|
|
205
205
|
self.encoder = Encoder(
|
|
206
206
|
depth, num_heads, predictor_embed_dim, mlp_dim, dropout=0.0, attention_dropout=0.0, dpr=dpr
|
birder/net/ssl/ibot.py
CHANGED
|
@@ -254,7 +254,7 @@ class iBOT(SSLBaseNet):
|
|
|
254
254
|
shared_head=shared_head,
|
|
255
255
|
)
|
|
256
256
|
|
|
257
|
-
self.mask_token = nn.Parameter(torch.zeros(1, 1, 1, self.backbone.stem_width)
|
|
257
|
+
self.mask_token = nn.Parameter(torch.zeros(1, 1, 1, self.backbone.stem_width))
|
|
258
258
|
|
|
259
259
|
def forward( # type: ignore[override] # pylint: disable=arguments-differ
|
|
260
260
|
self, x: torch.Tensor, masks: Optional[torch.Tensor], return_keys: Literal["all", "embedding"] = "all"
|
birder/net/swiftformer.py
CHANGED
|
@@ -48,7 +48,12 @@ class ConvEncoder(nn.Module):
|
|
|
48
48
|
) -> None:
|
|
49
49
|
super().__init__()
|
|
50
50
|
self.dw_conv = nn.Conv2d(
|
|
51
|
-
dim,
|
|
51
|
+
dim,
|
|
52
|
+
dim,
|
|
53
|
+
kernel_size,
|
|
54
|
+
stride=(1, 1),
|
|
55
|
+
padding=((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2),
|
|
56
|
+
groups=dim,
|
|
52
57
|
)
|
|
53
58
|
self.norm = nn.BatchNorm2d(dim)
|
|
54
59
|
self.pw_conv1 = nn.Conv2d(dim, hidden_dim, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))
|
|
@@ -125,7 +130,12 @@ class LocalRepresentation(nn.Module):
|
|
|
125
130
|
def __init__(self, dim: int, kernel_size: tuple[int, int], drop_path: float, use_layer_scale: bool) -> None:
|
|
126
131
|
super().__init__()
|
|
127
132
|
self.dw_conv = nn.Conv2d(
|
|
128
|
-
dim,
|
|
133
|
+
dim,
|
|
134
|
+
dim,
|
|
135
|
+
kernel_size,
|
|
136
|
+
stride=(1, 1),
|
|
137
|
+
padding=((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2),
|
|
138
|
+
groups=dim,
|
|
129
139
|
)
|
|
130
140
|
self.norm = nn.BatchNorm2d(dim)
|
|
131
141
|
self.pw_conv1 = nn.Conv2d(dim, dim, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))
|
birder/net/swin_transformer_v2.py
CHANGED
|
@@ -72,7 +72,7 @@ class ShiftedWindowAttention(nn.Module):
|
|
|
72
72
|
self.define_relative_position_bias_table()
|
|
73
73
|
self.define_relative_position_index()
|
|
74
74
|
|
|
75
|
-
self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1)))
|
|
75
|
+
self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
|
|
76
76
|
|
|
77
77
|
# MLP to generate continuous relative position bias
|
|
78
78
|
self.cpb_mlp = nn.Sequential(
|
birder/net/tiny_vit.py
CHANGED
|
@@ -77,10 +77,11 @@ class MBConv(nn.Module):
|
|
|
77
77
|
kernel_size=(1, 1),
|
|
78
78
|
stride=(1, 1),
|
|
79
79
|
padding=(0, 0),
|
|
80
|
-
activation_layer=
|
|
80
|
+
activation_layer=None,
|
|
81
81
|
inplace=None,
|
|
82
82
|
)
|
|
83
83
|
self.drop_path = StochasticDepth(drop_path, mode="row")
|
|
84
|
+
self.act = nn.GELU()
|
|
84
85
|
|
|
85
86
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
86
87
|
shortcut = x
|
|
@@ -89,6 +90,7 @@ class MBConv(nn.Module):
|
|
|
89
90
|
x = self.conv3(x)
|
|
90
91
|
x = self.drop_path(x)
|
|
91
92
|
x += shortcut
|
|
93
|
+
x = self.act(x)
|
|
92
94
|
|
|
93
95
|
return x
|
|
94
96
|
|
|
@@ -508,18 +510,3 @@ registry.register_model_config(
|
|
|
508
510
|
"drop_path_rate": 0.2,
|
|
509
511
|
},
|
|
510
512
|
)
|
|
511
|
-
|
|
512
|
-
registry.register_weights(
|
|
513
|
-
"tiny_vit_5m_il-common",
|
|
514
|
-
{
|
|
515
|
-
"description": "TinyViT 5M model trained on the il-common dataset",
|
|
516
|
-
"resolution": (256, 256),
|
|
517
|
-
"formats": {
|
|
518
|
-
"pt": {
|
|
519
|
-
"file_size": 20.0,
|
|
520
|
-
"sha256": "57f84dc3144fc4e3ca39328d3a1446ca9e26ddb54e4c4d84301b7638bee2ec21",
|
|
521
|
-
},
|
|
522
|
-
},
|
|
523
|
-
"net": {"network": "tiny_vit_5m", "tag": "il-common"},
|
|
524
|
-
},
|
|
525
|
-
)
|
birder/net/van.py
CHANGED
|
@@ -116,8 +116,8 @@ class VANBlock(nn.Module):
|
|
|
116
116
|
self.mlp = DWConvMLP(in_features=dim, hidden_features=int(dim * mlp_ratio), drop=drop)
|
|
117
117
|
|
|
118
118
|
layer_scale_init_value = 1e-2
|
|
119
|
-
self.layer_scale_1 = nn.Parameter(layer_scale_init_value * torch.ones((1, dim, 1, 1))
|
|
120
|
-
self.layer_scale_2 = nn.Parameter(layer_scale_init_value * torch.ones((1, dim, 1, 1))
|
|
119
|
+
self.layer_scale_1 = nn.Parameter(layer_scale_init_value * torch.ones((1, dim, 1, 1)))
|
|
120
|
+
self.layer_scale_2 = nn.Parameter(layer_scale_init_value * torch.ones((1, dim, 1, 1)))
|
|
121
121
|
|
|
122
122
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
123
123
|
x = x + self.drop_path(self.layer_scale_1 * self.attn(self.norm1(x)))
|