birder 0.3.3__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
  Files changed (92)
  1. birder/common/fs_ops.py +2 -2
  2. birder/introspection/attention_rollout.py +1 -1
  3. birder/introspection/transformer_attribution.py +1 -1
  4. birder/layers/layer_scale.py +1 -1
  5. birder/net/__init__.py +2 -10
  6. birder/net/_rope_vit_configs.py +430 -0
  7. birder/net/_vit_configs.py +479 -0
  8. birder/net/biformer.py +1 -0
  9. birder/net/cait.py +5 -5
  10. birder/net/coat.py +12 -12
  11. birder/net/conv2former.py +3 -3
  12. birder/net/convmixer.py +1 -1
  13. birder/net/convnext_v1.py +1 -1
  14. birder/net/crossvit.py +5 -5
  15. birder/net/davit.py +1 -1
  16. birder/net/deit.py +12 -26
  17. birder/net/deit3.py +42 -189
  18. birder/net/densenet.py +9 -8
  19. birder/net/detection/deformable_detr.py +5 -2
  20. birder/net/detection/detr.py +5 -2
  21. birder/net/detection/efficientdet.py +1 -1
  22. birder/net/dpn.py +1 -2
  23. birder/net/edgenext.py +2 -1
  24. birder/net/edgevit.py +3 -0
  25. birder/net/efficientformer_v1.py +2 -1
  26. birder/net/efficientformer_v2.py +18 -31
  27. birder/net/efficientnet_v2.py +3 -0
  28. birder/net/efficientvit_mit.py +5 -5
  29. birder/net/fasternet.py +2 -2
  30. birder/net/flexivit.py +22 -43
  31. birder/net/groupmixformer.py +1 -1
  32. birder/net/hgnet_v1.py +5 -5
  33. birder/net/inception_next.py +1 -1
  34. birder/net/inception_resnet_v1.py +3 -3
  35. birder/net/inception_resnet_v2.py +7 -4
  36. birder/net/inception_v3.py +3 -0
  37. birder/net/inception_v4.py +3 -0
  38. birder/net/maxvit.py +1 -1
  39. birder/net/metaformer.py +3 -3
  40. birder/net/mim/crossmae.py +1 -1
  41. birder/net/mim/mae_vit.py +1 -1
  42. birder/net/mim/simmim.py +1 -1
  43. birder/net/mobilenet_v1.py +0 -9
  44. birder/net/mobilenet_v2.py +38 -44
  45. birder/net/{mobilenet_v3_large.py → mobilenet_v3.py} +37 -10
  46. birder/net/mobilevit_v1.py +5 -32
  47. birder/net/mobilevit_v2.py +1 -45
  48. birder/net/moganet.py +8 -5
  49. birder/net/mvit_v2.py +6 -6
  50. birder/net/nfnet.py +4 -0
  51. birder/net/pit.py +1 -1
  52. birder/net/pvt_v1.py +5 -5
  53. birder/net/pvt_v2.py +5 -5
  54. birder/net/repghost.py +1 -30
  55. birder/net/resmlp.py +2 -2
  56. birder/net/resnest.py +3 -0
  57. birder/net/resnet_v1.py +125 -1
  58. birder/net/resnet_v2.py +75 -1
  59. birder/net/resnext.py +35 -1
  60. birder/net/rope_deit3.py +33 -136
  61. birder/net/rope_flexivit.py +18 -18
  62. birder/net/rope_vit.py +3 -735
  63. birder/net/simple_vit.py +22 -16
  64. birder/net/smt.py +1 -1
  65. birder/net/squeezenet.py +5 -12
  66. birder/net/squeezenext.py +0 -24
  67. birder/net/ssl/capi.py +1 -1
  68. birder/net/ssl/data2vec.py +1 -1
  69. birder/net/ssl/dino_v2.py +2 -2
  70. birder/net/ssl/franca.py +2 -2
  71. birder/net/ssl/i_jepa.py +1 -1
  72. birder/net/ssl/ibot.py +1 -1
  73. birder/net/swiftformer.py +12 -2
  74. birder/net/swin_transformer_v2.py +1 -1
  75. birder/net/tiny_vit.py +3 -16
  76. birder/net/van.py +2 -2
  77. birder/net/vit.py +35 -963
  78. birder/net/vit_sam.py +13 -38
  79. birder/net/xcit.py +7 -6
  80. birder/tools/introspection.py +1 -1
  81. birder/tools/model_info.py +3 -1
  82. birder/version.py +1 -1
  83. {birder-0.3.3.dist-info → birder-0.4.0.dist-info}/METADATA +1 -1
  84. {birder-0.3.3.dist-info → birder-0.4.0.dist-info}/RECORD +88 -90
  85. birder/net/mobilenet_v3_small.py +0 -43
  86. birder/net/se_resnet_v1.py +0 -105
  87. birder/net/se_resnet_v2.py +0 -59
  88. birder/net/se_resnext.py +0 -30
  89. {birder-0.3.3.dist-info → birder-0.4.0.dist-info}/WHEEL +0 -0
  90. {birder-0.3.3.dist-info → birder-0.4.0.dist-info}/entry_points.txt +0 -0
  91. {birder-0.3.3.dist-info → birder-0.4.0.dist-info}/licenses/LICENSE +0 -0
  92. {birder-0.3.3.dist-info → birder-0.4.0.dist-info}/top_level.txt +0 -0
birder/net/simple_vit.py CHANGED
@@ -20,6 +20,12 @@ import torch
20
20
  from torch import nn
21
21
 
22
22
  from birder.model_registry import registry
23
+ from birder.net._vit_configs import BASE
24
+ from birder.net._vit_configs import GIANT
25
+ from birder.net._vit_configs import HUGE
26
+ from birder.net._vit_configs import LARGE
27
+ from birder.net._vit_configs import MEDIUM
28
+ from birder.net._vit_configs import SMALL
23
29
  from birder.net.base import MaskedTokenOmissionMixin
24
30
  from birder.net.base import PreTrainEncoder
25
31
  from birder.net.base import TokenOmissionResultType
@@ -45,12 +51,12 @@ class Simple_ViT(PreTrainEncoder, MaskedTokenOmissionMixin):
45
51
  assert self.config is not None, "must set config"
46
52
 
47
53
  image_size = self.size
48
- drop_path_rate = 0.0
49
54
  patch_size: int = self.config["patch_size"]
50
55
  num_layers: int = self.config["num_layers"]
51
56
  num_heads: int = self.config["num_heads"]
52
57
  hidden_dim: int = self.config["hidden_dim"]
53
58
  mlp_dim: int = self.config["mlp_dim"]
59
+ drop_path_rate: float = self.config["drop_path_rate"]
54
60
 
55
61
  torch._assert(image_size[0] % patch_size == 0, "Input shape indivisible by patch size!")
56
62
  torch._assert(image_size[1] % patch_size == 0, "Input shape indivisible by patch size!")
@@ -215,75 +221,75 @@ class Simple_ViT(PreTrainEncoder, MaskedTokenOmissionMixin):
215
221
  registry.register_model_config(
216
222
  "simple_vit_s32",
217
223
  Simple_ViT,
218
- config={"patch_size": 32, "num_layers": 12, "num_heads": 6, "hidden_dim": 384, "mlp_dim": 1536},
224
+ config={"patch_size": 32, **SMALL},
219
225
  )
220
226
  registry.register_model_config(
221
227
  "simple_vit_s16",
222
228
  Simple_ViT,
223
- config={"patch_size": 16, "num_layers": 12, "num_heads": 6, "hidden_dim": 384, "mlp_dim": 1536},
229
+ config={"patch_size": 16, **SMALL},
224
230
  )
225
231
  registry.register_model_config(
226
232
  "simple_vit_s14",
227
233
  Simple_ViT,
228
- config={"patch_size": 14, "num_layers": 12, "num_heads": 6, "hidden_dim": 384, "mlp_dim": 1536},
234
+ config={"patch_size": 14, **SMALL},
229
235
  )
230
236
  registry.register_model_config(
231
237
  "simple_vit_m32",
232
238
  Simple_ViT,
233
- config={"patch_size": 32, "num_layers": 12, "num_heads": 8, "hidden_dim": 512, "mlp_dim": 2048},
239
+ config={"patch_size": 32, **MEDIUM},
234
240
  )
235
241
  registry.register_model_config(
236
242
  "simple_vit_m16",
237
243
  Simple_ViT,
238
- config={"patch_size": 16, "num_layers": 12, "num_heads": 8, "hidden_dim": 512, "mlp_dim": 2048},
244
+ config={"patch_size": 16, **MEDIUM},
239
245
  )
240
246
  registry.register_model_config(
241
247
  "simple_vit_m14",
242
248
  Simple_ViT,
243
- config={"patch_size": 14, "num_layers": 12, "num_heads": 8, "hidden_dim": 512, "mlp_dim": 2048},
249
+ config={"patch_size": 14, **MEDIUM},
244
250
  )
245
251
  registry.register_model_config(
246
252
  "simple_vit_b32",
247
253
  Simple_ViT,
248
- config={"patch_size": 32, "num_layers": 12, "num_heads": 12, "hidden_dim": 768, "mlp_dim": 3072},
254
+ config={"patch_size": 32, **BASE}, # Override the BASE definition
249
255
  )
250
256
  registry.register_model_config(
251
257
  "simple_vit_b16",
252
258
  Simple_ViT,
253
- config={"patch_size": 16, "num_layers": 12, "num_heads": 12, "hidden_dim": 768, "mlp_dim": 3072},
259
+ config={"patch_size": 16, **BASE},
254
260
  )
255
261
  registry.register_model_config(
256
262
  "simple_vit_b14",
257
263
  Simple_ViT,
258
- config={"patch_size": 14, "num_layers": 12, "num_heads": 12, "hidden_dim": 768, "mlp_dim": 3072},
264
+ config={"patch_size": 14, **BASE},
259
265
  )
260
266
  registry.register_model_config(
261
267
  "simple_vit_l32",
262
268
  Simple_ViT,
263
- config={"patch_size": 32, "num_layers": 24, "num_heads": 16, "hidden_dim": 1024, "mlp_dim": 4096},
269
+ config={"patch_size": 32, **LARGE},
264
270
  )
265
271
  registry.register_model_config(
266
272
  "simple_vit_l16",
267
273
  Simple_ViT,
268
- config={"patch_size": 16, "num_layers": 24, "num_heads": 16, "hidden_dim": 1024, "mlp_dim": 4096},
274
+ config={"patch_size": 16, **LARGE},
269
275
  )
270
276
  registry.register_model_config(
271
277
  "simple_vit_l14",
272
278
  Simple_ViT,
273
- config={"patch_size": 14, "num_layers": 24, "num_heads": 16, "hidden_dim": 1024, "mlp_dim": 4096},
279
+ config={"patch_size": 14, **LARGE},
274
280
  )
275
281
  registry.register_model_config(
276
282
  "simple_vit_h16",
277
283
  Simple_ViT,
278
- config={"patch_size": 16, "num_layers": 32, "num_heads": 16, "hidden_dim": 1280, "mlp_dim": 5120},
284
+ config={"patch_size": 16, **HUGE},
279
285
  )
280
286
  registry.register_model_config(
281
287
  "simple_vit_h14",
282
288
  Simple_ViT,
283
- config={"patch_size": 14, "num_layers": 32, "num_heads": 16, "hidden_dim": 1280, "mlp_dim": 5120},
289
+ config={"patch_size": 14, **HUGE},
284
290
  )
285
291
  registry.register_model_config( # From "Scaling Vision Transformers"
286
292
  "simple_vit_g14",
287
293
  Simple_ViT,
288
- config={"patch_size": 14, "num_layers": 40, "num_heads": 16, "hidden_dim": 1408, "mlp_dim": 6144},
294
+ config={"patch_size": 14, **GIANT},
289
295
  )
birder/net/smt.py CHANGED
@@ -259,7 +259,7 @@ class Stem(nn.Module):
259
259
  embed_dim,
260
260
  kernel_size=kernel_size,
261
261
  stride=stride,
262
- padding=(kernel_size[0] // 2, kernel_size[1] // 2),
262
+ padding=((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2),
263
263
  ),
264
264
  nn.Conv2d(embed_dim, embed_dim, kernel_size=(2, 2), stride=(2, 2), padding=(0, 0)),
265
265
  )
birder/net/squeezenet.py CHANGED
@@ -20,11 +20,11 @@ from birder.net.base import BaseNet
20
20
  class Fire(nn.Module):
21
21
  def __init__(self, in_planes: int, squeeze: int, expand: int) -> None:
22
22
  super().__init__()
23
- self.squeeze = nn.Conv2d(in_planes, squeeze, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=False)
23
+ self.squeeze = nn.Conv2d(in_planes, squeeze, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))
24
24
  self.squeeze_activation = nn.ReLU(inplace=True)
25
- self.left = nn.Conv2d(squeeze, expand, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=False)
25
+ self.left = nn.Conv2d(squeeze, expand, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))
26
26
  self.left_activation = nn.ReLU(inplace=True)
27
- self.right = nn.Conv2d(squeeze, expand, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
27
+ self.right = nn.Conv2d(squeeze, expand, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
28
28
  self.right_activation = nn.ReLU(inplace=True)
29
29
 
30
30
  def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -53,7 +53,7 @@ class SqueezeNet(BaseNet):
53
53
  assert self.config is None, "config not supported"
54
54
 
55
55
  self.stem = nn.Sequential(
56
- nn.Conv2d(self.input_channels, 64, kernel_size=(3, 3), stride=(2, 2), padding=(0, 0), bias=False),
56
+ nn.Conv2d(self.input_channels, 64, kernel_size=(3, 3), stride=(2, 2), padding=(0, 0)),
57
57
  nn.ReLU(inplace=True),
58
58
  nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2), padding=(0, 0), ceil_mode=True),
59
59
  )
@@ -94,14 +94,7 @@ class SqueezeNet(BaseNet):
94
94
 
95
95
  return nn.Sequential(
96
96
  nn.Dropout(p=0.5, inplace=True),
97
- nn.Conv2d(
98
- embed_dim,
99
- self.num_classes,
100
- kernel_size=(1, 1),
101
- stride=(1, 1),
102
- padding=(0, 0),
103
- bias=False,
104
- ),
97
+ nn.Conv2d(embed_dim, self.num_classes, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)),
105
98
  nn.ReLU(inplace=True),
106
99
  nn.AdaptiveAvgPool2d(output_size=(1, 1)),
107
100
  nn.Flatten(1),
birder/net/squeezenext.py CHANGED
@@ -27,7 +27,6 @@ class SqnxtUnit(nn.Module):
27
27
  kernel_size=(1, 1),
28
28
  stride=(stride, stride),
29
29
  padding=(0, 0),
30
- bias=False,
31
30
  )
32
31
 
33
32
  elif in_channels > out_channels:
@@ -38,7 +37,6 @@ class SqnxtUnit(nn.Module):
38
37
  kernel_size=(1, 1),
39
38
  stride=(stride, stride),
40
39
  padding=(0, 0),
41
- bias=False,
42
40
  )
43
41
 
44
42
  else:
@@ -52,7 +50,6 @@ class SqnxtUnit(nn.Module):
52
50
  kernel_size=(1, 1),
53
51
  stride=(stride, stride),
54
52
  padding=(0, 0),
55
- bias=False,
56
53
  ),
57
54
  Conv2dNormActivation(
58
55
  in_channels // reduction,
@@ -60,7 +57,6 @@ class SqnxtUnit(nn.Module):
60
57
  kernel_size=(1, 1),
61
58
  stride=(1, 1),
62
59
  padding=(0, 0),
63
- bias=False,
64
60
  ),
65
61
  Conv2dNormActivation(
66
62
  in_channels // (2 * reduction),
@@ -68,7 +64,6 @@ class SqnxtUnit(nn.Module):
68
64
  kernel_size=(1, 3),
69
65
  stride=(1, 1),
70
66
  padding=(0, 1),
71
- bias=False,
72
67
  ),
73
68
  Conv2dNormActivation(
74
69
  in_channels // reduction,
@@ -76,7 +71,6 @@ class SqnxtUnit(nn.Module):
76
71
  kernel_size=(3, 1),
77
72
  stride=(1, 1),
78
73
  padding=(1, 0),
79
- bias=False,
80
74
  ),
81
75
  Conv2dNormActivation(
82
76
  in_channels // reduction,
@@ -84,7 +78,6 @@ class SqnxtUnit(nn.Module):
84
78
  kernel_size=(1, 1),
85
79
  stride=(1, 1),
86
80
  padding=(0, 0),
87
- bias=False,
88
81
  ),
89
82
  )
90
83
  self.relu = nn.ReLU(inplace=True)
@@ -124,7 +117,6 @@ class SqueezeNext(DetectorBackbone):
124
117
  kernel_size=(7, 7),
125
118
  stride=(2, 2),
126
119
  padding=(1, 1),
127
- bias=False,
128
120
  ),
129
121
  nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2), padding=(0, 0), ceil_mode=True),
130
122
  )
@@ -155,7 +147,6 @@ class SqueezeNext(DetectorBackbone):
155
147
  kernel_size=(1, 1),
156
148
  stride=(1, 1),
157
149
  padding=(0, 0),
158
- bias=False,
159
150
  ),
160
151
  nn.AdaptiveAvgPool2d(output_size=(1, 1)),
161
152
  nn.Flatten(1),
@@ -199,18 +190,3 @@ registry.register_model_config("squeezenext_0_5", SqueezeNext, config={"width_sc
199
190
  registry.register_model_config("squeezenext_1_0", SqueezeNext, config={"width_scale": 1.0})
200
191
  registry.register_model_config("squeezenext_1_5", SqueezeNext, config={"width_scale": 1.5})
201
192
  registry.register_model_config("squeezenext_2_0", SqueezeNext, config={"width_scale": 2.0})
202
-
203
- registry.register_weights(
204
- "squeezenext_1_0_il-common",
205
- {
206
- "description": "SqueezeNext v2 1.0x output channels model trained on the il-common dataset",
207
- "resolution": (259, 259),
208
- "formats": {
209
- "pt": {
210
- "file_size": 3.5,
211
- "sha256": "da01d1cd05c71b80b5e4e6ca66400f64fa3f6179d0e90834c4f6942c8095557a",
212
- }
213
- },
214
- "net": {"network": "squeezenext_1_0", "tag": "il-common"},
215
- },
216
- )
birder/net/ssl/capi.py CHANGED
@@ -306,7 +306,7 @@ class Decoder(nn.Module):
306
306
  dim=decoder_embed_dim,
307
307
  num_special_tokens=0,
308
308
  ).unsqueeze(0)
309
- self.decoder_pos_embed = nn.Parameter(pos_embedding, requires_grad=False)
309
+ self.decoder_pos_embed = nn.Buffer(pos_embedding)
310
310
 
311
311
  self.decoder_layers = nn.ModuleList()
312
312
  for _ in range(decoder_depth):
birder/net/ssl/data2vec.py CHANGED
@@ -51,7 +51,7 @@ class Data2Vec(SSLBaseNet):
51
51
  self.ema_backbone = copy.deepcopy(self.backbone)
52
52
  self.head = nn.Linear(self.backbone.embedding_size, self.backbone.embedding_size)
53
53
 
54
- self.mask_token = nn.Parameter(torch.zeros(1, 1, 1, self.backbone.stem_width), requires_grad=True)
54
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, 1, self.backbone.stem_width))
55
55
 
56
56
  # Weights initialization
57
57
  self.ema_backbone.load_state_dict(self.backbone.state_dict())
birder/net/ssl/dino_v2.py CHANGED
@@ -460,7 +460,7 @@ class DINOv2Student(SSLBaseNet):
460
460
  bottleneck_dim=head_bottleneck_dim,
461
461
  )
462
462
 
463
- self.mask_token = nn.Parameter(torch.zeros(1, 1, 1, self.backbone.stem_width), requires_grad=True)
463
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, 1, self.backbone.stem_width))
464
464
 
465
465
  # pylint: disable=arguments-differ
466
466
  def forward( # type: ignore[override]
@@ -543,7 +543,7 @@ class DINOv2Teacher(SSLBaseNet):
543
543
  )
544
544
 
545
545
  # Unused, Makes for an easier EMA update
546
- self.mask_token = nn.Parameter(torch.zeros(1, 1, 1, self.backbone.stem_width), requires_grad=True)
546
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, 1, self.backbone.stem_width))
547
547
 
548
548
  # pylint: disable=arguments-differ
549
549
  def forward( # type: ignore[override]
birder/net/ssl/franca.py CHANGED
@@ -433,7 +433,7 @@ class FrancaStudent(SSLBaseNet):
433
433
  nesting_list=nesting_list,
434
434
  )
435
435
 
436
- self.mask_token = nn.Parameter(torch.zeros(1, 1, 1, self.backbone.stem_width), requires_grad=True)
436
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, 1, self.backbone.stem_width))
437
437
 
438
438
  # pylint: disable=arguments-differ
439
439
  def forward( # type: ignore[override]
@@ -523,7 +523,7 @@ class FrancaTeacher(SSLBaseNet):
523
523
  )
524
524
 
525
525
  # Unused, Makes for an easier EMA update
526
- self.mask_token = nn.Parameter(torch.zeros(1, 1, 1, self.backbone.stem_width), requires_grad=True)
526
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, 1, self.backbone.stem_width))
527
527
 
528
528
  # pylint: disable=arguments-differ
529
529
  def forward( # type: ignore[override]
birder/net/ssl/i_jepa.py CHANGED
@@ -200,7 +200,7 @@ class VisionTransformerPredictor(nn.Module):
200
200
  self.mask_token = nn.Parameter(torch.zeros(1, 1, predictor_embed_dim))
201
201
 
202
202
  pos_embedding = pos_embedding_sin_cos_2d(h=size[0], w=size[1], dim=predictor_embed_dim, num_special_tokens=0)
203
- self.pos_embedding = nn.Parameter(pos_embedding, requires_grad=False)
203
+ self.pos_embedding = nn.Buffer(pos_embedding)
204
204
 
205
205
  self.encoder = Encoder(
206
206
  depth, num_heads, predictor_embed_dim, mlp_dim, dropout=0.0, attention_dropout=0.0, dpr=dpr
birder/net/ssl/ibot.py CHANGED
@@ -254,7 +254,7 @@ class iBOT(SSLBaseNet):
254
254
  shared_head=shared_head,
255
255
  )
256
256
 
257
- self.mask_token = nn.Parameter(torch.zeros(1, 1, 1, self.backbone.stem_width), requires_grad=True)
257
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, 1, self.backbone.stem_width))
258
258
 
259
259
  def forward( # type: ignore[override] # pylint: disable=arguments-differ
260
260
  self, x: torch.Tensor, masks: Optional[torch.Tensor], return_keys: Literal["all", "embedding"] = "all"
birder/net/swiftformer.py CHANGED
@@ -48,7 +48,12 @@ class ConvEncoder(nn.Module):
48
48
  ) -> None:
49
49
  super().__init__()
50
50
  self.dw_conv = nn.Conv2d(
51
- dim, dim, kernel_size, stride=(1, 1), padding=(kernel_size[0] // 2, kernel_size[1] // 2), groups=dim
51
+ dim,
52
+ dim,
53
+ kernel_size,
54
+ stride=(1, 1),
55
+ padding=((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2),
56
+ groups=dim,
52
57
  )
53
58
  self.norm = nn.BatchNorm2d(dim)
54
59
  self.pw_conv1 = nn.Conv2d(dim, hidden_dim, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))
@@ -125,7 +130,12 @@ class LocalRepresentation(nn.Module):
125
130
  def __init__(self, dim: int, kernel_size: tuple[int, int], drop_path: float, use_layer_scale: bool) -> None:
126
131
  super().__init__()
127
132
  self.dw_conv = nn.Conv2d(
128
- dim, dim, kernel_size, stride=(1, 1), padding=(kernel_size[0] // 2, kernel_size[1] // 2), groups=dim
133
+ dim,
134
+ dim,
135
+ kernel_size,
136
+ stride=(1, 1),
137
+ padding=((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2),
138
+ groups=dim,
129
139
  )
130
140
  self.norm = nn.BatchNorm2d(dim)
131
141
  self.pw_conv1 = nn.Conv2d(dim, dim, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))
birder/net/swin_transformer_v2.py CHANGED
@@ -72,7 +72,7 @@ class ShiftedWindowAttention(nn.Module):
72
72
  self.define_relative_position_bias_table()
73
73
  self.define_relative_position_index()
74
74
 
75
- self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True)
75
+ self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
76
76
 
77
77
  # MLP to generate continuous relative position bias
78
78
  self.cpb_mlp = nn.Sequential(
birder/net/tiny_vit.py CHANGED
@@ -77,10 +77,11 @@ class MBConv(nn.Module):
77
77
  kernel_size=(1, 1),
78
78
  stride=(1, 1),
79
79
  padding=(0, 0),
80
- activation_layer=nn.GELU,
80
+ activation_layer=None,
81
81
  inplace=None,
82
82
  )
83
83
  self.drop_path = StochasticDepth(drop_path, mode="row")
84
+ self.act = nn.GELU()
84
85
 
85
86
  def forward(self, x: torch.Tensor) -> torch.Tensor:
86
87
  shortcut = x
@@ -89,6 +90,7 @@ class MBConv(nn.Module):
89
90
  x = self.conv3(x)
90
91
  x = self.drop_path(x)
91
92
  x += shortcut
93
+ x = self.act(x)
92
94
 
93
95
  return x
94
96
 
@@ -508,18 +510,3 @@ registry.register_model_config(
508
510
  "drop_path_rate": 0.2,
509
511
  },
510
512
  )
511
-
512
- registry.register_weights(
513
- "tiny_vit_5m_il-common",
514
- {
515
- "description": "TinyViT 5M model trained on the il-common dataset",
516
- "resolution": (256, 256),
517
- "formats": {
518
- "pt": {
519
- "file_size": 20.0,
520
- "sha256": "57f84dc3144fc4e3ca39328d3a1446ca9e26ddb54e4c4d84301b7638bee2ec21",
521
- },
522
- },
523
- "net": {"network": "tiny_vit_5m", "tag": "il-common"},
524
- },
525
- )
birder/net/van.py CHANGED
@@ -116,8 +116,8 @@ class VANBlock(nn.Module):
116
116
  self.mlp = DWConvMLP(in_features=dim, hidden_features=int(dim * mlp_ratio), drop=drop)
117
117
 
118
118
  layer_scale_init_value = 1e-2
119
- self.layer_scale_1 = nn.Parameter(layer_scale_init_value * torch.ones((1, dim, 1, 1)), requires_grad=True)
120
- self.layer_scale_2 = nn.Parameter(layer_scale_init_value * torch.ones((1, dim, 1, 1)), requires_grad=True)
119
+ self.layer_scale_1 = nn.Parameter(layer_scale_init_value * torch.ones((1, dim, 1, 1)))
120
+ self.layer_scale_2 = nn.Parameter(layer_scale_init_value * torch.ones((1, dim, 1, 1)))
121
121
 
122
122
  def forward(self, x: torch.Tensor) -> torch.Tensor:
123
123
  x = x + self.drop_path(self.layer_scale_1 * self.attn(self.norm1(x)))