birder 0.3.2__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (97)
  1. birder/common/fs_ops.py +2 -2
  2. birder/introspection/attention_rollout.py +1 -1
  3. birder/introspection/transformer_attribution.py +1 -1
  4. birder/layers/layer_scale.py +1 -1
  5. birder/net/__init__.py +2 -10
  6. birder/net/_rope_vit_configs.py +430 -0
  7. birder/net/_vit_configs.py +479 -0
  8. birder/net/biformer.py +1 -0
  9. birder/net/cait.py +5 -5
  10. birder/net/coat.py +12 -12
  11. birder/net/conv2former.py +3 -3
  12. birder/net/convmixer.py +1 -1
  13. birder/net/convnext_v1.py +1 -1
  14. birder/net/crossvit.py +5 -5
  15. birder/net/davit.py +1 -1
  16. birder/net/deit.py +12 -26
  17. birder/net/deit3.py +42 -189
  18. birder/net/densenet.py +9 -8
  19. birder/net/detection/deformable_detr.py +5 -2
  20. birder/net/detection/detr.py +5 -2
  21. birder/net/detection/efficientdet.py +1 -1
  22. birder/net/dpn.py +1 -2
  23. birder/net/edgenext.py +2 -1
  24. birder/net/edgevit.py +3 -0
  25. birder/net/efficientformer_v1.py +2 -1
  26. birder/net/efficientformer_v2.py +18 -31
  27. birder/net/efficientnet_v2.py +3 -0
  28. birder/net/efficientvit_mit.py +5 -5
  29. birder/net/fasternet.py +2 -2
  30. birder/net/flexivit.py +22 -43
  31. birder/net/groupmixformer.py +1 -1
  32. birder/net/hgnet_v1.py +5 -5
  33. birder/net/hiera.py +3 -3
  34. birder/net/hieradet.py +116 -28
  35. birder/net/inception_next.py +1 -1
  36. birder/net/inception_resnet_v1.py +3 -3
  37. birder/net/inception_resnet_v2.py +7 -4
  38. birder/net/inception_v3.py +3 -0
  39. birder/net/inception_v4.py +3 -0
  40. birder/net/maxvit.py +1 -1
  41. birder/net/metaformer.py +3 -3
  42. birder/net/mim/crossmae.py +1 -1
  43. birder/net/mim/mae_vit.py +1 -1
  44. birder/net/mim/simmim.py +1 -1
  45. birder/net/mobilenet_v1.py +0 -9
  46. birder/net/mobilenet_v2.py +38 -44
  47. birder/net/{mobilenet_v3_large.py → mobilenet_v3.py} +37 -10
  48. birder/net/mobilevit_v1.py +5 -32
  49. birder/net/mobilevit_v2.py +1 -45
  50. birder/net/moganet.py +8 -5
  51. birder/net/mvit_v2.py +6 -6
  52. birder/net/nfnet.py +4 -0
  53. birder/net/pit.py +1 -1
  54. birder/net/pvt_v1.py +5 -5
  55. birder/net/pvt_v2.py +5 -5
  56. birder/net/repghost.py +1 -30
  57. birder/net/resmlp.py +2 -2
  58. birder/net/resnest.py +3 -0
  59. birder/net/resnet_v1.py +125 -1
  60. birder/net/resnet_v2.py +75 -1
  61. birder/net/resnext.py +35 -1
  62. birder/net/rope_deit3.py +33 -136
  63. birder/net/rope_flexivit.py +18 -18
  64. birder/net/rope_vit.py +3 -735
  65. birder/net/simple_vit.py +22 -16
  66. birder/net/smt.py +1 -1
  67. birder/net/squeezenet.py +5 -12
  68. birder/net/squeezenext.py +0 -24
  69. birder/net/ssl/capi.py +1 -1
  70. birder/net/ssl/data2vec.py +1 -1
  71. birder/net/ssl/dino_v2.py +2 -2
  72. birder/net/ssl/franca.py +2 -2
  73. birder/net/ssl/i_jepa.py +1 -1
  74. birder/net/ssl/ibot.py +1 -1
  75. birder/net/swiftformer.py +12 -2
  76. birder/net/swin_transformer_v2.py +1 -1
  77. birder/net/tiny_vit.py +3 -16
  78. birder/net/van.py +2 -2
  79. birder/net/vit.py +35 -963
  80. birder/net/vit_sam.py +13 -38
  81. birder/net/xcit.py +7 -6
  82. birder/scripts/train.py +17 -15
  83. birder/scripts/train_kd.py +17 -16
  84. birder/tools/introspection.py +1 -1
  85. birder/tools/model_info.py +3 -1
  86. birder/tools/show_iterator.py +16 -2
  87. birder/version.py +1 -1
  88. {birder-0.3.2.dist-info → birder-0.4.0.dist-info}/METADATA +1 -1
  89. {birder-0.3.2.dist-info → birder-0.4.0.dist-info}/RECORD +93 -95
  90. birder/net/mobilenet_v3_small.py +0 -43
  91. birder/net/se_resnet_v1.py +0 -105
  92. birder/net/se_resnet_v2.py +0 -59
  93. birder/net/se_resnext.py +0 -30
  94. {birder-0.3.2.dist-info → birder-0.4.0.dist-info}/WHEEL +0 -0
  95. {birder-0.3.2.dist-info → birder-0.4.0.dist-info}/entry_points.txt +0 -0
  96. {birder-0.3.2.dist-info → birder-0.4.0.dist-info}/licenses/LICENSE +0 -0
  97. {birder-0.3.2.dist-info → birder-0.4.0.dist-info}/top_level.txt +0 -0
birder/net/_vit_configs.py ADDED
@@ -0,0 +1,479 @@
+ """
+ ViT model configuration registrations
+
+ This file contains *only* model variant definitions and their registration
+ with the global model registry. The actual ViT implementation lives in vit.py.
+
+ Naming:
+ - All model names must follow the ViT / RoPE ViT naming convention documented in rope_vit_configs.py.
+ """
+
+ from birder.model_registry import registry
+ from birder.net.base import BaseNet
+
+ TINY = {"num_layers": 12, "num_heads": 3, "hidden_dim": 192, "mlp_dim": 768, "drop_path_rate": 0.0}
+ SMALL = {"num_layers": 12, "num_heads": 6, "hidden_dim": 384, "mlp_dim": 1536, "drop_path_rate": 0.0}
+ MEDIUM = {"num_layers": 12, "num_heads": 8, "hidden_dim": 512, "mlp_dim": 2048, "drop_path_rate": 0.0}
+ BASE = {"num_layers": 12, "num_heads": 12, "hidden_dim": 768, "mlp_dim": 3072, "drop_path_rate": 0.1}
+ LARGE = {"num_layers": 24, "num_heads": 16, "hidden_dim": 1024, "mlp_dim": 4096, "drop_path_rate": 0.1}
+ HUGE = {"num_layers": 32, "num_heads": 16, "hidden_dim": 1280, "mlp_dim": 5120, "drop_path_rate": 0.1}
+
+ # From "Getting ViT in Shape: Scaling Laws for Compute-Optimal Model Design"
+ # Shape-optimized vision transformer (SoViT)
+ SO150 = {
+     "num_layers": 18,
+     "num_heads": 16,
+     "hidden_dim": 896,  # Changed from 880 for RoPE divisibility
+     "mlp_dim": 2320,
+     "drop_path_rate": 0.1,
+ }
+ SO400 = {
+     "num_layers": 27,
+     "num_heads": 16,
+     "hidden_dim": 1152,
+     "mlp_dim": 4304,
+     "drop_path_rate": 0.1,
+ }
+
+ # From "Scaling Vision Transformers"
+ GIANT = {"num_layers": 40, "num_heads": 16, "hidden_dim": 1408, "mlp_dim": 6144, "drop_path_rate": 0.1}
+ GIGANTIC = {"num_layers": 48, "num_heads": 16, "hidden_dim": 1664, "mlp_dim": 8192, "drop_path_rate": 0.1}
+
+
+ def register_vit_configs(vit: type[BaseNet]) -> None:
+     registry.register_model_config(
+         "vit_t32",
+         vit,
+         config={"patch_size": 32, **TINY},
+     )
+     registry.register_model_config(
+         "vit_t16",
+         vit,
+         config={"patch_size": 16, **TINY},
+     )
+     registry.register_model_config(
+         "vit_t14",
+         vit,
+         config={"patch_size": 14, **TINY},
+     )
+     registry.register_model_config(
+         "vit_s32",
+         vit,
+         config={"patch_size": 32, **SMALL},
+     )
+     registry.register_model_config(
+         "vit_s16",
+         vit,
+         config={"patch_size": 16, **SMALL},
+     )
+     registry.register_model_config(
+         "vit_s16_ls",
+         vit,
+         config={"patch_size": 16, **SMALL, "layer_scale_init_value": 1e-5},
+     )
+     registry.register_model_config(
+         "vit_s16_pn",
+         vit,
+         config={"patch_size": 16, **SMALL, "pre_norm": True, "norm_layer_eps": 1e-5},
+     )
+     registry.register_model_config(
+         "vit_s14",
+         vit,
+         config={"patch_size": 14, **SMALL},
+     )
+     registry.register_model_config(
+         "vit_m32",
+         vit,
+         config={"patch_size": 32, **MEDIUM},
+     )
+     registry.register_model_config(
+         "vit_m16",
+         vit,
+         config={"patch_size": 16, **MEDIUM},
+     )
+     registry.register_model_config(
+         "vit_m14",
+         vit,
+         config={"patch_size": 14, **MEDIUM},
+     )
+     registry.register_model_config(
+         "vit_b32",
+         vit,
+         config={"patch_size": 32, **BASE, "drop_path_rate": 0.0},  # Override the BASE definition
+     )
+     registry.register_model_config(
+         "vit_b16",
+         vit,
+         config={"patch_size": 16, **BASE},
+     )
+     registry.register_model_config(
+         "vit_b16_ls",
+         vit,
+         config={"patch_size": 16, **BASE, "layer_scale_init_value": 1e-5},
+     )
+     registry.register_model_config(
+         "vit_b16_qkn_ls",
+         vit,
+         config={"patch_size": 16, **BASE, "layer_scale_init_value": 1e-5, "qk_norm": True},
+     )
+     registry.register_model_config(
+         "vit_b16_pn_quick_gelu",
+         vit,
+         config={"patch_size": 16, **BASE, "pre_norm": True, "norm_layer_eps": 1e-5, "act_layer_type": "quick_gelu"},
+     )
+     registry.register_model_config(
+         "vit_b14",
+         vit,
+         config={"patch_size": 14, **BASE},
+     )
+     registry.register_model_config(
+         "vit_so150m_p14_avg",
+         vit,
+         config={"patch_size": 14, **SO150, "class_token": False},
+     )
+     registry.register_model_config(
+         "vit_so150m_p14_ap",
+         vit,
+         config={"patch_size": 14, **SO150, "class_token": False, "attn_pool_head": True},
+     )
+     registry.register_model_config(
+         "vit_l32",
+         vit,
+         config={"patch_size": 32, **LARGE},
+     )
+     registry.register_model_config(
+         "vit_l16",
+         vit,
+         config={"patch_size": 16, **LARGE},
+     )
+     registry.register_model_config(
+         "vit_l14",
+         vit,
+         config={"patch_size": 14, **LARGE},
+     )
+     registry.register_model_config(
+         "vit_l14_pn",
+         vit,
+         config={"patch_size": 14, **LARGE, "pre_norm": True, "norm_layer_eps": 1e-5},
+     )
+     registry.register_model_config(
+         "vit_l14_pn_quick_gelu",
+         vit,
+         config={"patch_size": 14, **LARGE, "pre_norm": True, "norm_layer_eps": 1e-5, "act_layer_type": "quick_gelu"},
+     )
+     registry.register_model_config(
+         "vit_so400m_p14_ap",
+         vit,
+         config={"patch_size": 14, **SO400, "class_token": False, "attn_pool_head": True},
+     )
+     registry.register_model_config(
+         "vit_h16",
+         vit,
+         config={"patch_size": 16, **HUGE},
+     )
+     registry.register_model_config(
+         "vit_h14",
+         vit,
+         config={"patch_size": 14, **HUGE},
+     )
+     registry.register_model_config(  # From "Scaling Vision Transformers"
+         "vit_g16",
+         vit,
+         config={"patch_size": 16, **GIANT},
+     )
+     registry.register_model_config(  # From "Scaling Vision Transformers"
+         "vit_g14",
+         vit,
+         config={"patch_size": 14, **GIANT},
+     )
+     registry.register_model_config(  # From "Scaling Vision Transformers"
+         "vit_gigantic14",
+         vit,
+         config={"patch_size": 14, **GIGANTIC},
+     )
+     registry.register_model_config(  # From "PaLI: A Jointly-Scaled Multilingual Language-Image Model"
+         "vit_e14",
+         vit,
+         config={
+             "patch_size": 14,
+             "num_layers": 56,
+             "num_heads": 16,
+             "hidden_dim": 1792,
+             "mlp_dim": 15360,
+             "drop_path_rate": 0.1,
+         },
+     )
+     registry.register_model_config(  # From "Scaling Language-Free Visual Representation Learning"
+         "vit_1b_p16",  # AKA vit_giant2 from DINOv2
+         vit,
+         config={
+             "patch_size": 16,
+             "num_layers": 40,
+             "num_heads": 24,
+             "hidden_dim": 1536,
+             "mlp_dim": 6144,
+             "drop_path_rate": 0.1,
+         },
+     )
+     registry.register_model_config(  # From "Scaling Vision Transformers to 22 Billion Parameters"
+         "vit_22b_p16_qkn",
+         vit,
+         config={
+             "patch_size": 16,
+             "num_layers": 48,
+             "num_heads": 48,
+             "hidden_dim": 6144,
+             "mlp_dim": 24576,
+             "qk_norm": True,
+             "drop_path_rate": 0.1,
+         },
+     )
+
+     # With registers
+     ####################
+
+     registry.register_model_config(
+         "vit_reg1_t32",
+         vit,
+         config={"patch_size": 32, **TINY, "num_reg_tokens": 1},
+     )
+     registry.register_model_config(
+         "vit_reg1_t16",
+         vit,
+         config={"patch_size": 16, **TINY, "num_reg_tokens": 1},
+     )
+     registry.register_model_config(
+         "vit_reg1_t14",
+         vit,
+         config={"patch_size": 14, **TINY, "num_reg_tokens": 1},
+     )
+     registry.register_model_config(
+         "vit_reg1_s32",
+         vit,
+         config={"patch_size": 32, **SMALL, "num_reg_tokens": 1},
+     )
+     registry.register_model_config(
+         "vit_reg1_s16",
+         vit,
+         config={"patch_size": 16, **SMALL, "num_reg_tokens": 1},
+     )
+     registry.register_model_config(
+         "vit_reg1_s16_ls",
+         vit,
+         config={"patch_size": 16, **SMALL, "layer_scale_init_value": 1e-5, "num_reg_tokens": 1},
+     )
+     registry.register_model_config(
+         "vit_reg1_s16_rms_ls",
+         vit,
+         config={
+             "patch_size": 16,
+             **SMALL,
+             "layer_scale_init_value": 1e-5,
+             "num_reg_tokens": 1,
+             "norm_layer_type": "RMSNorm",
+         },
+     )
+     registry.register_model_config(
+         "vit_reg1_s14",
+         vit,
+         config={"patch_size": 14, **SMALL, "num_reg_tokens": 1},
+     )
+     registry.register_model_config(
+         "vit_reg4_m32",
+         vit,
+         config={"patch_size": 32, **MEDIUM, "num_reg_tokens": 4},
+     )
+     registry.register_model_config(
+         "vit_reg4_m16",
+         vit,
+         config={"patch_size": 16, **MEDIUM, "num_reg_tokens": 4},
+     )
+     registry.register_model_config(
+         "vit_reg4_m16_rms_avg",
+         vit,
+         config={"patch_size": 16, **MEDIUM, "num_reg_tokens": 4, "class_token": False, "norm_layer_type": "RMSNorm"},
+     )
+     registry.register_model_config(
+         "vit_reg4_m14",
+         vit,
+         config={"patch_size": 14, **MEDIUM, "num_reg_tokens": 4},
+     )
+     registry.register_model_config(
+         "vit_reg4_b32",
+         vit,
+         config={"patch_size": 32, **BASE, "num_reg_tokens": 4, "drop_path_rate": 0.0},  # Override the BASE definition
+     )
+     registry.register_model_config(
+         "vit_reg4_b16",
+         vit,
+         config={"patch_size": 16, **BASE, "num_reg_tokens": 4},
+     )
+     registry.register_model_config(
+         "vit_reg4_b16_avg",
+         vit,
+         config={"patch_size": 16, **BASE, "num_reg_tokens": 4, "class_token": False},
+     )
+     registry.register_model_config(
+         "vit_reg4_b14",
+         vit,
+         config={"patch_size": 14, **BASE, "num_reg_tokens": 4},
+     )
+     registry.register_model_config(
+         "vit_reg8_b14_ap",
+         vit,
+         config={"patch_size": 14, **BASE, "num_reg_tokens": 8, "class_token": False, "attn_pool_head": True},
+     )
+     registry.register_model_config(
+         "vit_reg4_so150m_p16_avg",
+         vit,
+         config={"patch_size": 16, **SO150, "num_reg_tokens": 4, "class_token": False},
+     )
+     registry.register_model_config(
+         "vit_reg8_so150m_p16_swiglu_ap",
+         vit,
+         config={
+             "patch_size": 16,
+             **SO150,
+             "num_reg_tokens": 8,
+             "class_token": False,
+             "attn_pool_head": True,
+             "mlp_layer_type": "SwiGLU_FFN",
+         },
+     )
+     registry.register_model_config(
+         "vit_reg4_so150m_p14_avg",
+         vit,
+         config={"patch_size": 14, **SO150, "num_reg_tokens": 4, "class_token": False},
+     )
+     registry.register_model_config(
+         "vit_reg4_so150m_p14_ls",
+         vit,
+         config={"patch_size": 14, **SO150, "layer_scale_init_value": 1e-5, "num_reg_tokens": 4},
+     )
+     registry.register_model_config(
+         "vit_reg4_so150m_p14_ap",
+         vit,
+         config={"patch_size": 14, **SO150, "num_reg_tokens": 4, "class_token": False, "attn_pool_head": True},
+     )
+     registry.register_model_config(
+         "vit_reg4_so150m_p14_aps",
+         vit,
+         config={
+             "patch_size": 14,
+             **SO150,
+             "num_reg_tokens": 4,
+             "class_token": False,
+             "attn_pool_head": True,
+             "attn_pool_special_tokens": True,
+         },
+     )
+     registry.register_model_config(
+         "vit_reg8_so150m_p14_avg",
+         vit,
+         config={"patch_size": 14, **SO150, "num_reg_tokens": 8, "class_token": False},
+     )
+     registry.register_model_config(
+         "vit_reg8_so150m_p14_swiglu",
+         vit,
+         config={"patch_size": 14, **SO150, "num_reg_tokens": 8, "mlp_layer_type": "SwiGLU_FFN"},
+     )
+     registry.register_model_config(
+         "vit_reg8_so150m_p14_swiglu_avg",
+         vit,
+         config={"patch_size": 14, **SO150, "num_reg_tokens": 8, "class_token": False, "mlp_layer_type": "SwiGLU_FFN"},
+     )
+     registry.register_model_config(
+         "vit_reg8_so150m_p14_ap",
+         vit,
+         config={"patch_size": 14, **SO150, "num_reg_tokens": 8, "class_token": False, "attn_pool_head": True},
+     )
+     registry.register_model_config(
+         "vit_reg4_l32",
+         vit,
+         config={"patch_size": 32, **LARGE, "num_reg_tokens": 4},
+     )
+     registry.register_model_config(
+         "vit_reg4_l16",
+         vit,
+         config={"patch_size": 16, **LARGE, "num_reg_tokens": 4},
+     )
+     registry.register_model_config(
+         "vit_reg8_l16_avg",
+         vit,
+         config={"patch_size": 16, **LARGE, "num_reg_tokens": 8, "class_token": False},
+     )
+     registry.register_model_config(
+         "vit_reg8_l16_aps",
+         vit,
+         config={
+             "patch_size": 16,
+             **LARGE,
+             "num_reg_tokens": 8,
+             "class_token": False,
+             "attn_pool_head": True,
+             "attn_pool_special_tokens": True,
+         },
+     )
+     registry.register_model_config(
+         "vit_reg4_l14",
+         vit,
+         config={"patch_size": 14, **LARGE, "num_reg_tokens": 4},
+     )
+     registry.register_model_config(  # DeiT III style
+         "vit_reg4_l14_nps_ls",
+         vit,
+         config={
+             "pos_embed_special_tokens": False,
+             "patch_size": 14,
+             **LARGE,
+             "layer_scale_init_value": 1e-5,
+             "num_reg_tokens": 4,
+         },
+     )
+     registry.register_model_config(
+         "vit_reg8_l14_ap",
+         vit,
+         config={"patch_size": 14, **LARGE, "num_reg_tokens": 8, "class_token": False, "attn_pool_head": True},
+     )
+     registry.register_model_config(
+         "vit_reg8_l14_rms_ap",
+         vit,
+         config={
+             "patch_size": 14,
+             **LARGE,
+             "num_reg_tokens": 8,
+             "class_token": False,
+             "attn_pool_head": True,
+             "norm_layer_type": "RMSNorm",
+         },
+     )
+     registry.register_model_config(
+         "vit_reg8_so400m_p14_ap",
+         vit,
+         config={"patch_size": 14, **SO400, "num_reg_tokens": 8, "class_token": False, "attn_pool_head": True},
+     )
+     registry.register_model_config(
+         "vit_reg4_h16",
+         vit,
+         config={"patch_size": 16, **HUGE, "num_reg_tokens": 4},
+     )
+     registry.register_model_config(
+         "vit_reg4_h14",
+         vit,
+         config={"patch_size": 14, **HUGE, "num_reg_tokens": 4},
+     )
+     registry.register_model_config(  # From "Scaling Vision Transformers"
+         "vit_reg4_g16",
+         vit,
+         config={"patch_size": 16, **GIANT, "num_reg_tokens": 4},
+     )
+     registry.register_model_config(  # From "Scaling Vision Transformers"
+         "vit_reg4_g14",
+         vit,
+         config={"patch_size": 14, **GIANT, "num_reg_tokens": 4},
+     )
+     registry.register_model_config(  # From "Scaling Vision Transformers"
+         "vit_reg4_gigantic14",
+         vit,
+         config={"patch_size": 14, **GIGANTIC, "num_reg_tokens": 4},
+     )
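A note on how these variant tables compose: the config dicts are merged with plain unpacking, so a key written after the **BASE splat overrides the shared default, which is what the "Override the BASE definition" comments rely on. A minimal sketch of that mechanic, my illustration using only values quoted above:

    BASE = {"num_layers": 12, "num_heads": 12, "hidden_dim": 768, "mlp_dim": 3072, "drop_path_rate": 0.1}

    # Mirrors the "vit_b32" registration: the trailing key wins over the unpacked one
    config = {"patch_size": 32, **BASE, "drop_path_rate": 0.0}
    assert config["drop_path_rate"] == 0.0
    assert config["hidden_dim"] == 768  # all other BASE values pass through unchanged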
birder/net/biformer.py CHANGED
@@ -8,6 +8,7 @@ Changes from original:
  * All attention types are in (B, C, H, W)
  * Using the newer Bi-Level Routing Attention implementation
  * Dynamic n_win size (image size // 32)
+ * Stem bias term removed
  """

  # Reference license: Apache-2.0
birder/net/cait.py CHANGED
@@ -66,12 +66,12 @@ class ClassAttentionBlock(nn.Module):
          self, dim: int, num_heads: int, mlp_ratio: float, qkv_bias: bool, proj_drop: float, drop_path: float, eta: float
      ) -> None:
          super().__init__()
-         self.norm1 = nn.LayerNorm(dim)
+         self.norm1 = nn.LayerNorm(dim, eps=1e-6)

          self.attn = ClassAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias, proj_drop=proj_drop)

          self.drop_path = StochasticDepth(drop_path, mode="row")
-         self.norm2 = nn.LayerNorm(dim)
+         self.norm2 = nn.LayerNorm(dim, eps=1e-6)
          self.mlp = MLP(dim, [int(dim * mlp_ratio), dim], activation_layer=nn.GELU, dropout=proj_drop)

          self.gamma1 = nn.Parameter(eta * torch.ones(dim))
@@ -135,7 +135,7 @@ class LayerScaleBlock(nn.Module):
          init_values: float,
      ) -> None:
          super().__init__()
-         self.norm1 = nn.LayerNorm(dim)
+         self.norm1 = nn.LayerNorm(dim, eps=1e-6)
          self.attn = TalkingHeadAttn(
              dim,
              num_heads=num_heads,
@@ -144,7 +144,7 @@ class LayerScaleBlock(nn.Module):
              proj_drop=proj_drop,
          )
          self.drop_path = StochasticDepth(drop_path, mode="row")
-         self.norm2 = nn.LayerNorm(dim)
+         self.norm2 = nn.LayerNorm(dim, eps=1e-6)
          self.mlp = MLP(dim, [int(dim * mlp_ratio), dim], activation_layer=nn.GELU, dropout=proj_drop)
          self.gamma_1 = nn.Parameter(init_values * torch.ones(dim))
          self.gamma_2 = nn.Parameter(init_values * torch.ones(dim))
@@ -221,7 +221,7 @@ class CaiT(BaseNet):
              )
          )

-         self.norm = nn.LayerNorm(embed_dim)
+         self.norm = nn.LayerNorm(embed_dim, eps=1e-6)

          self.embedding_size = embed_dim
          self.classifier = self.create_classifier()
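The recurring nn.LayerNorm edits in this release (here and in coat.py, crossvit.py and others below) pin eps to 1e-6 instead of relying on PyTorch's default of 1e-5, presumably to match the upstream reference implementations these ports follow. A quick check of the two constructions, a sketch using stock PyTorch only:

    import torch.nn as nn

    ln_old = nn.LayerNorm(768)            # eps defaults to 1e-5
    ln_new = nn.LayerNorm(768, eps=1e-6)  # what this release switches to
    print(ln_old.eps, ln_new.eps)         # 1e-05 1e-06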
birder/net/coat.py CHANGED
@@ -21,7 +21,7 @@ from birder.net.base import DetectorBackbone


  def insert_cls(x: torch.Tensor, cls_token: torch.Tensor) -> torch.Tensor:
-     cls_tokens = cls_token.expand(x.shape[0], -1, -1)
+     cls_tokens = cls_token.expand(x.size(0), -1, -1)
      x = torch.concat((cls_tokens, x), dim=1)

      return x
@@ -170,7 +170,7 @@ class SerialBlock(nn.Module):

          # Conv-attention
          self.cpe = shared_cpe
-         self.norm1 = nn.LayerNorm(dim)
+         self.norm1 = nn.LayerNorm(dim, eps=1e-6)
          self.factor_attn_crpe = FactorAttnConvRelPosEnc(
              dim,
              num_heads=num_heads,
@@ -181,7 +181,7 @@ class SerialBlock(nn.Module):
          self.drop_path = StochasticDepth(drop_path, mode="row")

          # MLP
-         self.norm2 = nn.LayerNorm(dim)
+         self.norm2 = nn.LayerNorm(dim, eps=1e-6)
          self.mlp = MLP(dim, [int(dim * mlp_ratio), dim], activation_layer=nn.GELU, dropout=proj_drop)

      def forward(self, x: torch.Tensor, size: tuple[int, int]) -> torch.Tensor:
@@ -213,9 +213,9 @@ class ParallelBlock(nn.Module):
          super().__init__()

          # Conv-attention
-         self.norm12 = nn.LayerNorm(dims[1])
-         self.norm13 = nn.LayerNorm(dims[2])
-         self.norm14 = nn.LayerNorm(dims[3])
+         self.norm12 = nn.LayerNorm(dims[1], eps=1e-6)
+         self.norm13 = nn.LayerNorm(dims[2], eps=1e-6)
+         self.norm14 = nn.LayerNorm(dims[3], eps=1e-6)
          self.factor_attn_crpe2 = FactorAttnConvRelPosEnc(
              dims[1], num_heads=num_heads, qkv_bias=qkv_bias, proj_drop=proj_drop, shared_crpe=shared_crpes[1]
          )
@@ -228,9 +228,9 @@ class ParallelBlock(nn.Module):
          self.drop_path = StochasticDepth(drop_path, mode="row")

          # MLP
-         self.norm22 = nn.LayerNorm(dims[1])
-         self.norm23 = nn.LayerNorm(dims[2])
-         self.norm24 = nn.LayerNorm(dims[3])
+         self.norm22 = nn.LayerNorm(dims[1], eps=1e-6)
+         self.norm23 = nn.LayerNorm(dims[2], eps=1e-6)
+         self.norm24 = nn.LayerNorm(dims[3], eps=1e-6)

          # In the parallel block, we assume dimensions are the same and share the linear transformation
          assert dims[1] == dims[2] == dims[3]
@@ -447,13 +447,13 @@ class CoaT(DetectorBackbone):

          # Norms
          if self.parallel_blocks is not None:
-             self.norm2 = nn.LayerNorm(embed_dims[1])
-             self.norm3 = nn.LayerNorm(embed_dims[2])
+             self.norm2 = nn.LayerNorm(embed_dims[1], eps=1e-6)
+             self.norm3 = nn.LayerNorm(embed_dims[2], eps=1e-6)
          else:
              self.norm2 = None
              self.norm3 = None

-         self.norm4 = nn.LayerNorm(embed_dims[3])
+         self.norm4 = nn.LayerNorm(embed_dims[3], eps=1e-6)

          # Head
          if parallel_depth > 0:
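A small aside on the insert_cls change above: x.size(0) and x.shape[0] return the same value in PyTorch, so this is a stylistic swap rather than a behavioral one. A one-line check (my sketch):

    import torch

    x = torch.zeros(2, 3, 4)
    assert x.size(0) == x.shape[0] == 2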
birder/net/conv2former.py CHANGED
@@ -64,7 +64,7 @@ class SpatialAttention(nn.Module):
                  dim,
                  kernel_size=kernel_size,
                  stride=(1, 1),
-                 padding=(kernel_size[0] // 2, kernel_size[1] // 2),
+                 padding=((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2),
                  groups=dim,
              ),
          )
@@ -87,8 +87,8 @@ class Conv2FormerBlock(nn.Module):
          self.mlp = MLP(dim, mlp_ratio)

          layer_scale_init_value = 1e-6
-         self.layer_scale_1 = nn.Parameter(layer_scale_init_value * torch.ones((1, dim, 1, 1)), requires_grad=True)
-         self.layer_scale_2 = nn.Parameter(layer_scale_init_value * torch.ones((1, dim, 1, 1)), requires_grad=True)
+         self.layer_scale_1 = nn.Parameter(layer_scale_init_value * torch.ones((1, dim, 1, 1)))
+         self.layer_scale_2 = nn.Parameter(layer_scale_init_value * torch.ones((1, dim, 1, 1)))

      def forward(self, x: torch.Tensor) -> torch.Tensor:
          x = x + self.drop_path(self.layer_scale_1 * self.attn(x))
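The padding change in SpatialAttention above (repeated in convmixer.py and davit.py below) only matters for even kernel sizes, since k // 2 == (k - 1) // 2 whenever k is odd. A sketch of the output-size effect at stride 1, using stock PyTorch and made-up shapes:

    import torch
    import torch.nn as nn

    x = torch.randn(1, 8, 14, 14)
    for k in (7, 8):
        old = nn.Conv2d(8, 8, kernel_size=k, padding=k // 2, groups=8)
        new = nn.Conv2d(8, 8, kernel_size=k, padding=(k - 1) // 2, groups=8)
        print(k, tuple(old(x).shape[-2:]), tuple(new(x).shape[-2:]))
    # k=7: both stay (14, 14); k=8: k // 2 grows the map to (15, 15), (k - 1) // 2 shrinks it to (13, 13)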
birder/net/convmixer.py CHANGED
@@ -58,7 +58,7 @@ class ConvMixer(BaseNet):
              inplace=None,
          )

-         padding = (kernel_size[0] // 2, kernel_size[1] // 2)
+         padding = ((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2)
          self.body = nn.Sequential(
              *[
                  nn.Sequential(
birder/net/convnext_v1.py CHANGED
@@ -53,7 +53,7 @@ class ConvNeXtBlock(nn.Module):
              nn.Linear(4 * channels, channels),  # Same as 1x1 conv
              Permute([0, 3, 1, 2]),
          )
-         self.layer_scale = nn.Parameter(torch.ones(channels, 1, 1) * layer_scale, requires_grad=True)
+         self.layer_scale = nn.Parameter(torch.ones(channels, 1, 1) * layer_scale)
          self.stochastic_depth = StochasticDepth(stochastic_depth_prob, mode="row")

      def forward(self, x: torch.Tensor) -> torch.Tensor:
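The dropped requires_grad=True here (and in conv2former.py above) is cosmetic: nn.Parameter already requires gradients by default, so these tensors train exactly as before. A one-line check (my sketch):

    import torch
    import torch.nn as nn

    assert nn.Parameter(torch.ones(4)).requires_grad  # True by default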
birder/net/crossvit.py CHANGED
@@ -97,7 +97,7 @@ class CrossAttentionBlock(nn.Module):
          self, dim: int, num_heads: int, qkv_bias: bool, proj_drop: float, attn_drop: float, drop_path: float
      ) -> None:
          super().__init__()
-         self.norm1 = nn.LayerNorm(dim)
+         self.norm1 = nn.LayerNorm(dim, eps=1e-6)
          self.attn = CrossAttention(
              dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=proj_drop
          )
@@ -146,7 +146,7 @@ class MultiScaleBlock(nn.Module):
          for d in range(num_branches):
              self.projs.append(
                  nn.Sequential(
-                     nn.LayerNorm(dim[d]),
+                     nn.LayerNorm(dim[d], eps=1e-6),
                      nn.GELU(),
                      nn.Linear(dim[d], dim[(d + 1) % num_branches]),
                  )
@@ -187,7 +187,7 @@ class MultiScaleBlock(nn.Module):
          for d in range(num_branches):
              self.revert_projs.append(
                  nn.Sequential(
-                     nn.LayerNorm(dim[(d + 1) % num_branches]),
+                     nn.LayerNorm(dim[(d + 1) % num_branches], eps=1e-6),
                      nn.GELU(),
                      nn.Linear(dim[(d + 1) % num_branches], dim[d]),
                  )
@@ -290,7 +290,7 @@ class CrossViT(BaseNet):
              dpr_ptr += curr_depth
              self.blocks.append(block)

-         self.norm = nn.ModuleList([nn.LayerNorm(embed_dim[i]) for i in range(self.num_branches)])
+         self.norm = nn.ModuleList([nn.LayerNorm(embed_dim[i], eps=1e-6) for i in range(self.num_branches)])
          self.embedding_size = sum(self.embed_dim)
          self.classifier = nn.ModuleList()
          for i in range(self.num_branches):
@@ -482,7 +482,7 @@ registry.register_weights(
          "formats": {
              "pt": {
                  "file_size": 32.7,
-                 "sha256": "515265ed725adce09464bfd23ce612b1d1178bc22a57960db089d7148556149a",
+                 "sha256": "08f674d8165dc97cc535f8188a5c5361751a8d0bb85061454986a21541a6fe8e",
              }
          },
          "net": {"network": "crossvit_9d", "tag": "il-common"},
birder/net/davit.py CHANGED
@@ -64,7 +64,7 @@ class ConvPosEnc(nn.Module):
              dim,
              kernel_size=kernel_size,
              stride=(1, 1),
-             padding=(kernel_size[0] // 2, kernel_size[1] // 2),
+             padding=((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2),
              groups=dim,
          )
          if act is True: