birder 0.3.3__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (216)
  1. birder/adversarial/base.py +1 -1
  2. birder/adversarial/simba.py +4 -4
  3. birder/common/cli.py +1 -1
  4. birder/common/fs_ops.py +13 -13
  5. birder/common/lib.py +2 -2
  6. birder/common/masking.py +3 -3
  7. birder/common/training_cli.py +24 -2
  8. birder/common/training_utils.py +28 -4
  9. birder/data/collators/detection.py +9 -1
  10. birder/data/transforms/detection.py +27 -8
  11. birder/data/transforms/mosaic.py +1 -1
  12. birder/datahub/classification.py +3 -3
  13. birder/inference/classification.py +3 -3
  14. birder/inference/data_parallel.py +1 -1
  15. birder/inference/detection.py +5 -5
  16. birder/inference/wbf.py +1 -1
  17. birder/introspection/attention_rollout.py +6 -6
  18. birder/introspection/feature_pca.py +4 -4
  19. birder/introspection/gradcam.py +1 -1
  20. birder/introspection/guided_backprop.py +2 -2
  21. birder/introspection/transformer_attribution.py +4 -4
  22. birder/layers/attention_pool.py +2 -2
  23. birder/layers/layer_scale.py +1 -1
  24. birder/model_registry/model_registry.py +2 -1
  25. birder/net/__init__.py +4 -10
  26. birder/net/_rope_vit_configs.py +435 -0
  27. birder/net/_vit_configs.py +466 -0
  28. birder/net/alexnet.py +5 -5
  29. birder/net/base.py +28 -3
  30. birder/net/biformer.py +18 -17
  31. birder/net/cait.py +7 -7
  32. birder/net/cas_vit.py +1 -1
  33. birder/net/coat.py +27 -27
  34. birder/net/conv2former.py +3 -3
  35. birder/net/convmixer.py +1 -1
  36. birder/net/convnext_v1.py +3 -11
  37. birder/net/convnext_v1_iso.py +198 -0
  38. birder/net/convnext_v2.py +2 -10
  39. birder/net/crossformer.py +9 -9
  40. birder/net/crossvit.py +6 -6
  41. birder/net/cspnet.py +1 -1
  42. birder/net/cswin_transformer.py +10 -10
  43. birder/net/davit.py +11 -11
  44. birder/net/deit.py +68 -29
  45. birder/net/deit3.py +69 -204
  46. birder/net/densenet.py +9 -8
  47. birder/net/detection/__init__.py +4 -0
  48. birder/net/detection/{yolo_anchors.py → _yolo_anchors.py} +5 -5
  49. birder/net/detection/base.py +6 -5
  50. birder/net/detection/deformable_detr.py +31 -30
  51. birder/net/detection/detr.py +14 -11
  52. birder/net/detection/efficientdet.py +10 -29
  53. birder/net/detection/faster_rcnn.py +22 -22
  54. birder/net/detection/fcos.py +8 -8
  55. birder/net/detection/plain_detr.py +852 -0
  56. birder/net/detection/retinanet.py +4 -4
  57. birder/net/detection/rt_detr_v1.py +81 -25
  58. birder/net/detection/rt_detr_v2.py +1147 -0
  59. birder/net/detection/ssd.py +5 -5
  60. birder/net/detection/yolo_v2.py +12 -12
  61. birder/net/detection/yolo_v3.py +19 -19
  62. birder/net/detection/yolo_v4.py +16 -16
  63. birder/net/detection/yolo_v4_tiny.py +3 -3
  64. birder/net/dpn.py +1 -2
  65. birder/net/edgenext.py +5 -4
  66. birder/net/edgevit.py +13 -14
  67. birder/net/efficientformer_v1.py +3 -2
  68. birder/net/efficientformer_v2.py +18 -31
  69. birder/net/efficientnet_v2.py +3 -0
  70. birder/net/efficientvim.py +9 -9
  71. birder/net/efficientvit_mit.py +7 -7
  72. birder/net/efficientvit_msft.py +3 -3
  73. birder/net/fasternet.py +3 -3
  74. birder/net/fastvit.py +5 -12
  75. birder/net/flexivit.py +50 -58
  76. birder/net/focalnet.py +5 -9
  77. birder/net/gc_vit.py +11 -11
  78. birder/net/ghostnet_v1.py +1 -1
  79. birder/net/ghostnet_v2.py +1 -1
  80. birder/net/groupmixformer.py +13 -13
  81. birder/net/hgnet_v1.py +6 -6
  82. birder/net/hgnet_v2.py +4 -4
  83. birder/net/hiera.py +6 -6
  84. birder/net/hieradet.py +9 -9
  85. birder/net/hornet.py +3 -3
  86. birder/net/iformer.py +4 -4
  87. birder/net/inception_next.py +5 -15
  88. birder/net/inception_resnet_v1.py +3 -3
  89. birder/net/inception_resnet_v2.py +7 -4
  90. birder/net/inception_v3.py +3 -0
  91. birder/net/inception_v4.py +3 -0
  92. birder/net/levit.py +3 -3
  93. birder/net/lit_v1.py +13 -15
  94. birder/net/lit_v1_tiny.py +9 -9
  95. birder/net/lit_v2.py +14 -15
  96. birder/net/maxvit.py +11 -23
  97. birder/net/metaformer.py +5 -5
  98. birder/net/mim/crossmae.py +6 -6
  99. birder/net/mim/fcmae.py +3 -5
  100. birder/net/mim/mae_hiera.py +7 -7
  101. birder/net/mim/mae_vit.py +4 -6
  102. birder/net/mim/simmim.py +3 -4
  103. birder/net/mobilenet_v1.py +0 -9
  104. birder/net/mobilenet_v2.py +38 -44
  105. birder/net/{mobilenet_v3_large.py → mobilenet_v3.py} +37 -10
  106. birder/net/mobilenet_v4_hybrid.py +4 -4
  107. birder/net/mobileone.py +5 -12
  108. birder/net/mobilevit_v1.py +7 -34
  109. birder/net/mobilevit_v2.py +6 -54
  110. birder/net/moganet.py +8 -5
  111. birder/net/mvit_v2.py +30 -30
  112. birder/net/nextvit.py +2 -2
  113. birder/net/nfnet.py +4 -0
  114. birder/net/pit.py +11 -26
  115. birder/net/pvt_v1.py +9 -9
  116. birder/net/pvt_v2.py +10 -16
  117. birder/net/regionvit.py +15 -15
  118. birder/net/regnet.py +1 -1
  119. birder/net/repghost.py +5 -35
  120. birder/net/repvgg.py +3 -5
  121. birder/net/repvit.py +2 -2
  122. birder/net/resmlp.py +2 -2
  123. birder/net/resnest.py +4 -1
  124. birder/net/resnet_v1.py +125 -1
  125. birder/net/resnet_v2.py +75 -1
  126. birder/net/resnext.py +35 -1
  127. birder/net/rope_deit3.py +62 -151
  128. birder/net/rope_flexivit.py +46 -33
  129. birder/net/rope_vit.py +44 -758
  130. birder/net/sequencer2d.py +3 -4
  131. birder/net/shufflenet_v1.py +1 -1
  132. birder/net/shufflenet_v2.py +1 -1
  133. birder/net/simple_vit.py +69 -21
  134. birder/net/smt.py +8 -8
  135. birder/net/squeezenet.py +5 -12
  136. birder/net/squeezenext.py +0 -24
  137. birder/net/ssl/barlow_twins.py +1 -1
  138. birder/net/ssl/byol.py +2 -2
  139. birder/net/ssl/capi.py +4 -4
  140. birder/net/ssl/data2vec.py +1 -1
  141. birder/net/ssl/data2vec2.py +1 -1
  142. birder/net/ssl/dino_v2.py +13 -3
  143. birder/net/ssl/franca.py +28 -4
  144. birder/net/ssl/i_jepa.py +5 -5
  145. birder/net/ssl/ibot.py +1 -1
  146. birder/net/ssl/mmcr.py +1 -1
  147. birder/net/swiftformer.py +13 -3
  148. birder/net/swin_transformer_v1.py +4 -5
  149. birder/net/swin_transformer_v2.py +5 -8
  150. birder/net/tiny_vit.py +6 -19
  151. birder/net/transnext.py +19 -19
  152. birder/net/uniformer.py +4 -4
  153. birder/net/van.py +2 -2
  154. birder/net/vgg.py +1 -10
  155. birder/net/vit.py +72 -987
  156. birder/net/vit_parallel.py +35 -20
  157. birder/net/vit_sam.py +23 -48
  158. birder/net/vovnet_v2.py +1 -1
  159. birder/net/xcit.py +16 -13
  160. birder/ops/msda.py +4 -4
  161. birder/ops/swattention.py +10 -10
  162. birder/results/classification.py +3 -3
  163. birder/results/gui.py +8 -8
  164. birder/scripts/benchmark.py +37 -12
  165. birder/scripts/evaluate.py +1 -1
  166. birder/scripts/predict.py +3 -3
  167. birder/scripts/predict_detection.py +2 -2
  168. birder/scripts/train.py +63 -15
  169. birder/scripts/train_barlow_twins.py +10 -7
  170. birder/scripts/train_byol.py +10 -7
  171. birder/scripts/train_capi.py +15 -10
  172. birder/scripts/train_data2vec.py +10 -7
  173. birder/scripts/train_data2vec2.py +10 -7
  174. birder/scripts/train_detection.py +29 -14
  175. birder/scripts/train_dino_v1.py +13 -9
  176. birder/scripts/train_dino_v2.py +27 -14
  177. birder/scripts/train_dino_v2_dist.py +28 -15
  178. birder/scripts/train_franca.py +16 -9
  179. birder/scripts/train_i_jepa.py +12 -9
  180. birder/scripts/train_ibot.py +15 -11
  181. birder/scripts/train_kd.py +64 -17
  182. birder/scripts/train_mim.py +11 -8
  183. birder/scripts/train_mmcr.py +11 -8
  184. birder/scripts/train_rotnet.py +11 -7
  185. birder/scripts/train_simclr.py +10 -7
  186. birder/scripts/train_vicreg.py +10 -7
  187. birder/tools/adversarial.py +4 -4
  188. birder/tools/auto_anchors.py +5 -5
  189. birder/tools/avg_model.py +1 -1
  190. birder/tools/convert_model.py +30 -22
  191. birder/tools/det_results.py +1 -1
  192. birder/tools/download_model.py +1 -1
  193. birder/tools/ensemble_model.py +1 -1
  194. birder/tools/introspection.py +12 -3
  195. birder/tools/labelme_to_coco.py +2 -2
  196. birder/tools/model_info.py +15 -15
  197. birder/tools/pack.py +8 -8
  198. birder/tools/quantize_model.py +53 -4
  199. birder/tools/results.py +2 -2
  200. birder/tools/show_det_iterator.py +19 -6
  201. birder/tools/show_iterator.py +2 -2
  202. birder/tools/similarity.py +5 -5
  203. birder/tools/stats.py +4 -6
  204. birder/tools/voc_to_coco.py +1 -1
  205. birder/version.py +1 -1
  206. {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/METADATA +3 -3
  207. birder-0.4.1.dist-info/RECORD +300 -0
  208. {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/WHEEL +1 -1
  209. birder/net/mobilenet_v3_small.py +0 -43
  210. birder/net/se_resnet_v1.py +0 -105
  211. birder/net/se_resnet_v2.py +0 -59
  212. birder/net/se_resnext.py +0 -30
  213. birder-0.3.3.dist-info/RECORD +0 -299
  214. {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/entry_points.txt +0 -0
  215. {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/licenses/LICENSE +0 -0
  216. {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,466 @@
+"""
+ViT model configuration registrations
+
+This file contains *only* model variant definitions and their registration
+with the global model registry. The actual ViT implementation lives in vit.py.
+
+Naming:
+- All model names must follow the ViT / RoPE ViT naming convention documented in _rope_vit_configs.py.
+"""
+
+from birder.model_registry import registry
+from birder.net.base import BaseNet
+
+TINY = {"num_layers": 12, "num_heads": 3, "hidden_dim": 192, "mlp_dim": 768, "drop_path_rate": 0.0}
+SMALL = {"num_layers": 12, "num_heads": 6, "hidden_dim": 384, "mlp_dim": 1536, "drop_path_rate": 0.0}
+MEDIUM = {"num_layers": 12, "num_heads": 8, "hidden_dim": 512, "mlp_dim": 2048, "drop_path_rate": 0.0}
+BASE = {"num_layers": 12, "num_heads": 12, "hidden_dim": 768, "mlp_dim": 3072, "drop_path_rate": 0.1}
+LARGE = {"num_layers": 24, "num_heads": 16, "hidden_dim": 1024, "mlp_dim": 4096, "drop_path_rate": 0.1}
+HUGE = {"num_layers": 32, "num_heads": 16, "hidden_dim": 1280, "mlp_dim": 5120, "drop_path_rate": 0.1}
+
+# From "Getting ViT in Shape: Scaling Laws for Compute-Optimal Model Design"
+# Shape-optimized vision transformer (SoViT)
+SO150 = {
+    "num_layers": 18,
+    "num_heads": 16,
+    "hidden_dim": 896,  # Changed from 880 for RoPE divisibility
+    "mlp_dim": 2320,
+    "drop_path_rate": 0.1,
+}
+SO400 = {
+    "num_layers": 27,
+    "num_heads": 16,
+    "hidden_dim": 1152,
+    "mlp_dim": 4304,
+    "drop_path_rate": 0.1,
+}
+
+# From "Scaling Vision Transformers"
+GIANT = {"num_layers": 40, "num_heads": 16, "hidden_dim": 1408, "mlp_dim": 6144, "drop_path_rate": 0.1}
+GIGANTIC = {"num_layers": 48, "num_heads": 16, "hidden_dim": 1664, "mlp_dim": 8192, "drop_path_rate": 0.1}
+
+
+def register_vit_configs(vit: type[BaseNet]) -> None:
+    registry.register_model_config(
+        "vit_t32",
+        vit,
+        config={"patch_size": 32, **TINY},
+    )
+    registry.register_model_config(
+        "vit_t16",
+        vit,
+        config={"patch_size": 16, **TINY},
+    )
+    registry.register_model_config(
+        "vit_t14",
+        vit,
+        config={"patch_size": 14, **TINY},
+    )
+    registry.register_model_config(
+        "vit_s32",
+        vit,
+        config={"patch_size": 32, **SMALL},
+    )
+    registry.register_model_config(
+        "vit_s16",
+        vit,
+        config={"patch_size": 16, **SMALL},
+    )
+    registry.register_model_config(
+        "vit_s16_ls",
+        vit,
+        config={"patch_size": 16, **SMALL, "layer_scale_init_value": 1e-5},
+    )
+    registry.register_model_config(
+        "vit_s16_pn",
+        vit,
+        config={"patch_size": 16, **SMALL, "pre_norm": True, "norm_layer_eps": 1e-5},
+    )
+    registry.register_model_config(
+        "vit_s14",
+        vit,
+        config={"patch_size": 14, **SMALL},
+    )
+    registry.register_model_config(
+        "vit_m32",
+        vit,
+        config={"patch_size": 32, **MEDIUM},
+    )
+    registry.register_model_config(
+        "vit_m16",
+        vit,
+        config={"patch_size": 16, **MEDIUM},
+    )
+    registry.register_model_config(
+        "vit_m14",
+        vit,
+        config={"patch_size": 14, **MEDIUM},
+    )
+    registry.register_model_config(
+        "vit_b32",
+        vit,
+        config={"patch_size": 32, **BASE, "drop_path_rate": 0.0},  # Override the BASE definition
+    )
+    registry.register_model_config(
+        "vit_b16",
+        vit,
+        config={"patch_size": 16, **BASE},
+    )
+    registry.register_model_config(
+        "vit_b16_ls",
+        vit,
+        config={"patch_size": 16, **BASE, "layer_scale_init_value": 1e-5},
+    )
+    registry.register_model_config(
+        "vit_b16_qkn_ls",
+        vit,
+        config={"patch_size": 16, **BASE, "layer_scale_init_value": 1e-5, "qk_norm": True},
+    )
+    registry.register_model_config(
+        "vit_b16_pn_quick_gelu",
+        vit,
+        config={"patch_size": 16, **BASE, "pre_norm": True, "norm_layer_eps": 1e-5, "act_layer_type": "quick_gelu"},
+    )
+    registry.register_model_config(
+        "vit_b14",
+        vit,
+        config={"patch_size": 14, **BASE},
+    )
+    registry.register_model_config(
+        "vit_so150m_p14_avg",
+        vit,
+        config={"patch_size": 14, **SO150, "class_token": False},
+    )
+    registry.register_model_config(
+        "vit_so150m_p14_ap",
+        vit,
+        config={"patch_size": 14, **SO150, "class_token": False, "attn_pool_head": True},
+    )
+    registry.register_model_config(
+        "vit_l32",
+        vit,
+        config={"patch_size": 32, **LARGE},
+    )
+    registry.register_model_config(
+        "vit_l16",
+        vit,
+        config={"patch_size": 16, **LARGE},
+    )
+    registry.register_model_config(
+        "vit_l14",
+        vit,
+        config={"patch_size": 14, **LARGE},
+    )
+    registry.register_model_config(
+        "vit_l14_pn",
+        vit,
+        config={"patch_size": 14, **LARGE, "pre_norm": True, "norm_layer_eps": 1e-5},
+    )
+    registry.register_model_config(
+        "vit_l14_pn_quick_gelu",
+        vit,
+        config={"patch_size": 14, **LARGE, "pre_norm": True, "norm_layer_eps": 1e-5, "act_layer_type": "quick_gelu"},
+    )
+    registry.register_model_config(
+        "vit_so400m_p14_ap",
+        vit,
+        config={"patch_size": 14, **SO400, "class_token": False, "attn_pool_head": True},
+    )
+    registry.register_model_config(
+        "vit_h16",
+        vit,
+        config={"patch_size": 16, **HUGE},
+    )
+    registry.register_model_config(
+        "vit_h14",
+        vit,
+        config={"patch_size": 14, **HUGE},
+    )
+    registry.register_model_config(  # From "Scaling Vision Transformers"
+        "vit_g16",
+        vit,
+        config={"patch_size": 16, **GIANT},
+    )
+    registry.register_model_config(  # From "Scaling Vision Transformers"
+        "vit_g14",
+        vit,
+        config={"patch_size": 14, **GIANT},
+    )
+    registry.register_model_config(  # From "Scaling Vision Transformers"
+        "vit_gigantic14",
+        vit,
+        config={"patch_size": 14, **GIGANTIC},
+    )
+    registry.register_model_config(  # From "PaLI: A Jointly-Scaled Multilingual Language-Image Model"
+        "vit_e14",
+        vit,
+        config={
+            "patch_size": 14,
+            "num_layers": 56,
+            "num_heads": 16,
+            "hidden_dim": 1792,
+            "mlp_dim": 15360,
+            "drop_path_rate": 0.1,
+        },
+    )
+    registry.register_model_config(  # From "Scaling Language-Free Visual Representation Learning"
+        "vit_1b_p16",  # AKA vit_giant2 from DINOv2
+        vit,
+        config={
+            "patch_size": 16,
+            "num_layers": 40,
+            "num_heads": 24,
+            "hidden_dim": 1536,
+            "mlp_dim": 6144,
+            "drop_path_rate": 0.1,
+        },
+    )
+
+    # With registers
+    ####################
+
+    registry.register_model_config(
+        "vit_reg1_t32",
+        vit,
+        config={"patch_size": 32, **TINY, "num_reg_tokens": 1},
+    )
+    registry.register_model_config(
+        "vit_reg1_t16",
+        vit,
+        config={"patch_size": 16, **TINY, "num_reg_tokens": 1},
+    )
+    registry.register_model_config(
+        "vit_reg1_t14",
+        vit,
+        config={"patch_size": 14, **TINY, "num_reg_tokens": 1},
+    )
+    registry.register_model_config(
+        "vit_reg1_s32",
+        vit,
+        config={"patch_size": 32, **SMALL, "num_reg_tokens": 1},
+    )
+    registry.register_model_config(
+        "vit_reg1_s16",
+        vit,
+        config={"patch_size": 16, **SMALL, "num_reg_tokens": 1},
+    )
+    registry.register_model_config(
+        "vit_reg1_s16_ls",
+        vit,
+        config={"patch_size": 16, **SMALL, "layer_scale_init_value": 1e-5, "num_reg_tokens": 1},
+    )
+    registry.register_model_config(
+        "vit_reg1_s16_rms_ls",
+        vit,
+        config={
+            "patch_size": 16,
+            **SMALL,
+            "layer_scale_init_value": 1e-5,
+            "num_reg_tokens": 1,
+            "norm_layer_type": "RMSNorm",
+        },
+    )
+    registry.register_model_config(
+        "vit_reg1_s14",
+        vit,
+        config={"patch_size": 14, **SMALL, "num_reg_tokens": 1},
+    )
+    registry.register_model_config(
+        "vit_reg4_m32",
+        vit,
+        config={"patch_size": 32, **MEDIUM, "num_reg_tokens": 4},
+    )
+    registry.register_model_config(
+        "vit_reg4_m16",
+        vit,
+        config={"patch_size": 16, **MEDIUM, "num_reg_tokens": 4},
+    )
+    registry.register_model_config(
+        "vit_reg4_m16_rms_avg",
+        vit,
+        config={"patch_size": 16, **MEDIUM, "num_reg_tokens": 4, "class_token": False, "norm_layer_type": "RMSNorm"},
+    )
+    registry.register_model_config(
+        "vit_reg4_m14",
+        vit,
+        config={"patch_size": 14, **MEDIUM, "num_reg_tokens": 4},
+    )
+    registry.register_model_config(
+        "vit_reg4_b32",
+        vit,
+        config={"patch_size": 32, **BASE, "num_reg_tokens": 4, "drop_path_rate": 0.0},  # Override the BASE definition
+    )
+    registry.register_model_config(
+        "vit_reg4_b16",
+        vit,
+        config={"patch_size": 16, **BASE, "num_reg_tokens": 4},
+    )
+    registry.register_model_config(
+        "vit_reg4_b16_avg",
+        vit,
+        config={"patch_size": 16, **BASE, "num_reg_tokens": 4, "class_token": False},
+    )
+    registry.register_model_config(
+        "vit_reg4_b14",
+        vit,
+        config={"patch_size": 14, **BASE, "num_reg_tokens": 4},
+    )
+    registry.register_model_config(
+        "vit_reg8_b14_ap",
+        vit,
+        config={"patch_size": 14, **BASE, "num_reg_tokens": 8, "class_token": False, "attn_pool_head": True},
+    )
+    registry.register_model_config(
+        "vit_reg4_so150m_p16_avg",
+        vit,
+        config={"patch_size": 16, **SO150, "num_reg_tokens": 4, "class_token": False},
+    )
+    registry.register_model_config(
+        "vit_reg8_so150m_p16_swiglu_ap",
+        vit,
+        config={
+            "patch_size": 16,
+            **SO150,
+            "num_reg_tokens": 8,
+            "class_token": False,
+            "attn_pool_head": True,
+            "mlp_layer_type": "SwiGLU_FFN",
+        },
+    )
+    registry.register_model_config(
+        "vit_reg4_so150m_p14_avg",
+        vit,
+        config={"patch_size": 14, **SO150, "num_reg_tokens": 4, "class_token": False},
+    )
+    registry.register_model_config(
+        "vit_reg4_so150m_p14_ls",
+        vit,
+        config={"patch_size": 14, **SO150, "layer_scale_init_value": 1e-5, "num_reg_tokens": 4},
+    )
+    registry.register_model_config(
+        "vit_reg4_so150m_p14_ap",
+        vit,
+        config={"patch_size": 14, **SO150, "num_reg_tokens": 4, "class_token": False, "attn_pool_head": True},
+    )
+    registry.register_model_config(
+        "vit_reg4_so150m_p14_aps",
+        vit,
+        config={
+            "patch_size": 14,
+            **SO150,
+            "num_reg_tokens": 4,
+            "class_token": False,
+            "attn_pool_head": True,
+            "attn_pool_special_tokens": True,
+        },
+    )
+    registry.register_model_config(
+        "vit_reg8_so150m_p14_avg",
+        vit,
+        config={"patch_size": 14, **SO150, "num_reg_tokens": 8, "class_token": False},
+    )
+    registry.register_model_config(
+        "vit_reg8_so150m_p14_swiglu",
+        vit,
+        config={"patch_size": 14, **SO150, "num_reg_tokens": 8, "mlp_layer_type": "SwiGLU_FFN"},
+    )
+    registry.register_model_config(
+        "vit_reg8_so150m_p14_swiglu_avg",
+        vit,
+        config={"patch_size": 14, **SO150, "num_reg_tokens": 8, "class_token": False, "mlp_layer_type": "SwiGLU_FFN"},
+    )
+    registry.register_model_config(
+        "vit_reg8_so150m_p14_ap",
+        vit,
+        config={"patch_size": 14, **SO150, "num_reg_tokens": 8, "class_token": False, "attn_pool_head": True},
+    )
+    registry.register_model_config(
+        "vit_reg4_l32",
+        vit,
+        config={"patch_size": 32, **LARGE, "num_reg_tokens": 4},
+    )
+    registry.register_model_config(
+        "vit_reg4_l16",
+        vit,
+        config={"patch_size": 16, **LARGE, "num_reg_tokens": 4},
+    )
+    registry.register_model_config(
+        "vit_reg8_l16_avg",
+        vit,
+        config={"patch_size": 16, **LARGE, "num_reg_tokens": 8, "class_token": False},
+    )
+    registry.register_model_config(
+        "vit_reg8_l16_aps",
+        vit,
+        config={
+            "patch_size": 16,
+            **LARGE,
+            "num_reg_tokens": 8,
+            "class_token": False,
+            "attn_pool_head": True,
+            "attn_pool_special_tokens": True,
+        },
+    )
+    registry.register_model_config(
+        "vit_reg4_l14",
+        vit,
+        config={"patch_size": 14, **LARGE, "num_reg_tokens": 4},
+    )
+    registry.register_model_config(  # DeiT III style
+        "vit_reg4_l14_nps_ls",
+        vit,
+        config={
+            "pos_embed_special_tokens": False,
+            "patch_size": 14,
+            **LARGE,
+            "layer_scale_init_value": 1e-5,
+            "num_reg_tokens": 4,
+        },
+    )
+    registry.register_model_config(
+        "vit_reg8_l14_ap",
+        vit,
+        config={"patch_size": 14, **LARGE, "num_reg_tokens": 8, "class_token": False, "attn_pool_head": True},
+    )
+    registry.register_model_config(
+        "vit_reg8_l14_rms_ap",
+        vit,
+        config={
+            "patch_size": 14,
+            **LARGE,
+            "num_reg_tokens": 8,
+            "class_token": False,
+            "attn_pool_head": True,
+            "norm_layer_type": "RMSNorm",
+        },
+    )
+    registry.register_model_config(
+        "vit_reg8_so400m_p14_ap",
+        vit,
+        config={"patch_size": 14, **SO400, "num_reg_tokens": 8, "class_token": False, "attn_pool_head": True},
+    )
+    registry.register_model_config(
+        "vit_reg4_h16",
+        vit,
+        config={"patch_size": 16, **HUGE, "num_reg_tokens": 4},
+    )
+    registry.register_model_config(
+        "vit_reg4_h14",
+        vit,
+        config={"patch_size": 14, **HUGE, "num_reg_tokens": 4},
+    )
+    registry.register_model_config(  # From "Scaling Vision Transformers"
+        "vit_reg4_g16",
+        vit,
+        config={"patch_size": 16, **GIANT, "num_reg_tokens": 4},
+    )
+    registry.register_model_config(  # From "Scaling Vision Transformers"
+        "vit_reg4_g14",
+        vit,
+        config={"patch_size": 14, **GIANT, "num_reg_tokens": 4},
+    )
+    registry.register_model_config(  # From "Scaling Vision Transformers"
+        "vit_reg4_gigantic14",
+        vit,
+        config={"patch_size": 14, **GIGANTIC, "num_reg_tokens": 4},
+    )
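The pattern above factors every ViT variant out of vit.py (note the matching -987 for birder/net/vit.py in the file list) into a helper that takes the implementation class as an argument. A minimal sketch of how it composes, assuming a ViT class in birder/net/vit.py (the class name is an assumption; only register_vit_configs and registry.register_model_config appear in the diff):

    from birder.net._vit_configs import BASE, register_vit_configs
    from birder.net.vit import ViT  # hypothetical class name; vit.py holds the implementation

    # Registers every variant, "vit_t32" through "vit_reg4_gigantic14", against the class
    register_vit_configs(ViT)

    # Later keys win in dict merges, which is how per-variant overrides work:
    config = {"patch_size": 32, **BASE, "drop_path_rate": 0.0}
    assert config["drop_path_rate"] == 0.0  # vit_b32 overrides BASE's 0.1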
birder/net/alexnet.py CHANGED
@@ -27,17 +27,17 @@ class AlexNet(BaseNet):
         assert self.config is None, "config not supported"

         self.body = nn.Sequential(
-            nn.Conv2d(self.input_channels, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2), bias=True),
+            nn.Conv2d(self.input_channels, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2)),
             nn.ReLU(inplace=True),
             nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2), padding=(0, 0)),
-            nn.Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=True),
+            nn.Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2)),
             nn.ReLU(inplace=True),
             nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2), padding=(0, 0)),
-            nn.Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=True),
+            nn.Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
             nn.ReLU(inplace=True),
-            nn.Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=True),
+            nn.Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
             nn.ReLU(inplace=True),
-            nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=True),
+            nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
             nn.ReLU(inplace=True),
             nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2), padding=(0, 0)),
             nn.AdaptiveAvgPool2d(output_size=(6, 6)),
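Every edit in this hunk drops an explicit bias=True from nn.Conv2d; bias=True is already PyTorch's default, so the modules are functionally unchanged. A quick standalone check (not from the diff):

    from torch import nn

    conv = nn.Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    assert conv.bias is not None  # nn.Conv2d creates a bias term by default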
birder/net/base.py CHANGED
@@ -5,6 +5,7 @@ from typing import Literal
 from typing import NotRequired
 from typing import Optional
 from typing import TypedDict
+from typing import overload

 import torch
 import torch.nn.functional as F
@@ -54,6 +55,30 @@ def make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> i
     return new_v


+@overload
+def normalize_out_indices(out_indices: None, num_layers: int) -> None: ...
+
+
+@overload
+def normalize_out_indices(out_indices: list[int], num_layers: int) -> list[int]: ...
+
+
+def normalize_out_indices(out_indices: Optional[list[int]], num_layers: int) -> Optional[list[int]]:
+    if out_indices is None:
+        return None
+
+    normalized_indices = []
+    for idx in out_indices:
+        if idx < 0:
+            idx = num_layers + idx
+        if idx < 0 or idx >= num_layers:
+            raise ValueError(f"out_indices contains invalid index for num_layers={num_layers}")
+
+        normalized_indices.append(idx)
+
+    return normalized_indices
+
+
 # class MiscNet(nn.Module):
 #     """
 #     Base class for general-purpose neural networks with automatic model registration
@@ -137,8 +162,8 @@ class BaseNet(nn.Module):

         self.dynamic_size = False

-        self.classifier: nn.Module
         self.embedding_size: int
+        self.classifier: nn.Module

     def create_classifier(self, embed_dim: Optional[int] = None) -> nn.Module:
         if self.num_classes == 0:
@@ -274,7 +299,7 @@ def pos_embedding_sin_cos_2d(
 ) -> torch.Tensor:
     # assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sin-cos emb"

-    (y, x) = torch.meshgrid(torch.arange(h, device=device), torch.arange(w, device=device), indexing="ij")
+    y, x = torch.meshgrid(torch.arange(h, device=device), torch.arange(w, device=device), indexing="ij")
     omega = torch.arange(dim // 4, device=device) / (dim // 4 - 1)
     omega = 1.0 / (temperature**omega)

@@ -294,7 +319,7 @@ def interpolate_attention_bias(
     new_resolution: tuple[int, int],
     mode: Literal["bilinear", "bicubic"] = "bicubic",
 ) -> torch.Tensor:
-    (H, _) = attention_bias.size()
+    H, _ = attention_bias.size()

     # Interpolate
     orig_dtype = attention_bias.dtype
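The new normalize_out_indices helper maps negative indices to absolute layer positions and validates the range, with the overloads preserving the None passthrough for type checkers. A usage sketch (the input values are illustrative):

    from birder.net.base import normalize_out_indices

    print(normalize_out_indices([-1, -2, 0], num_layers=12))  # [11, 10, 0]
    print(normalize_out_indices(None, num_layers=12))  # None
    try:
        normalize_out_indices([12], num_layers=12)  # 12 is out of range for 12 layers
    except ValueError as exc:
        print(exc)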
birder/net/biformer.py CHANGED
@@ -8,6 +8,7 @@ Changes from original:
 * All attention types are in (B, C, H, W)
 * Using the newer Bi-Level Routing Attention implementation
 * Dynamic n_win size (image size // 32)
+* Stem bias term removed
 """

 # Reference license: Apache-2.0
@@ -29,7 +30,7 @@ from birder.net.base import DetectorBackbone


 def _grid2seq(x: torch.Tensor, region_size: tuple[int, int], num_heads: int) -> tuple[torch.Tensor, int, int]:
-    (B, C, H, W) = x.size()
+    B, C, H, W = x.size()
     region_h = H // region_size[0]
     region_w = W // region_size[1]
     x = x.view(B, num_heads, C // num_heads, region_h, region_size[0], region_w, region_size[1])
@@ -39,7 +40,7 @@ def _grid2seq(x: torch.Tensor, region_size: tuple[int, int], num_heads: int) ->


 def _seq2grid(x: torch.Tensor, region_h: int, region_w: int, region_size: tuple[int, int]) -> torch.Tensor:
-    (bs, n_head, _, _, head_dim) = x.size()
+    bs, n_head, _, _, head_dim = x.size()
     x = x.view(bs, n_head, region_h, region_w, region_size[0], region_size[1], head_dim)
     x = torch.einsum("bmhwpqd->bmdhpwq", x).reshape(
         bs, n_head * head_dim, region_h * region_size[0], region_w * region_size[1]
@@ -59,7 +60,7 @@ def regional_routing_attention_torch(
     auto_pad: bool,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     kv_region_size = region_size
-    (bs, n_head, q_nregion, topk) = region_graph.size()
+    bs, n_head, q_nregion, topk = region_graph.size()

     # Pad to deal with any input size
     q_pad_b = 0
@@ -67,13 +68,13 @@ def regional_routing_attention_torch(
     kv_pad_b = 0
     kv_pad_r = 0
     if auto_pad is True:
-        (_, _, h_q, w_q) = query.size()
+        _, _, h_q, w_q = query.size()
         q_pad_b = (region_size[0] - h_q % region_size[0]) % region_size[0]
         q_pad_r = (region_size[1] - w_q % region_size[1]) % region_size[1]
         if q_pad_b > 0 or q_pad_r > 0:
             query = F.pad(query, (0, q_pad_r, 0, q_pad_b))

-        (_, _, h_k, w_k) = key.size()
+        _, _, h_k, w_k = key.size()
         kv_pad_b = (kv_region_size[0] - h_k % kv_region_size[0]) % kv_region_size[0]
         kv_pad_r = (kv_region_size[1] - w_k % kv_region_size[1]) % kv_region_size[1]
         if kv_pad_r > 0 or kv_pad_b > 0:
@@ -86,12 +87,12 @@ def regional_routing_attention_torch(
     w_k = None

     # To sequence format
-    (query, q_region_h, q_region_w) = _grid2seq(query, region_size=region_size, num_heads=n_head)
-    (key, _, _) = _grid2seq(key, region_size=kv_region_size, num_heads=n_head)
-    (value, _, _) = _grid2seq(value, region_size=kv_region_size, num_heads=n_head)
+    query, q_region_h, q_region_w = _grid2seq(query, region_size=region_size, num_heads=n_head)
+    key, _, _ = _grid2seq(key, region_size=kv_region_size, num_heads=n_head)
+    value, _, _ = _grid2seq(value, region_size=kv_region_size, num_heads=n_head)

     # Gather key and values
-    (bs, n_head, kv_nregion, kv_region_size, head_dim) = key.size()
+    bs, n_head, kv_nregion, kv_region_size, head_dim = key.size()
     broadcasted_region_graph = region_graph.view(bs, n_head, q_nregion, topk, 1, 1).expand(
         -1, -1, -1, -1, kv_region_size, head_dim
     )
@@ -145,12 +146,12 @@ class BiLevelRoutingAttention(nn.Module):
         self.output_linear = nn.Conv2d(dim, dim, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        (_, _, H, W) = x.size()
+        _, _, H, W = x.size()
         region_size = (H // self.n_win_h, W // self.n_win_w)

         # Linear projection
         qkv = self.qkv_linear(x)
-        (q, k, v) = qkv.chunk(3, dim=1)
+        q, k, v = qkv.chunk(3, dim=1)

         # Region-to-region routing
         q_r = F.avg_pool2d(  # pylint: disable=not-callable
@@ -162,11 +163,11 @@ class BiLevelRoutingAttention(nn.Module):
         q_r = q_r.permute(0, 2, 3, 1).flatten(1, 2)  # (n, (hw), c)
         k_r = k_r.flatten(2, 3)  # (n, c, (hw))
         a_r = q_r @ k_r
-        (_, idx_r) = torch.topk(a_r, k=self.topk, dim=-1)
+        _, idx_r = torch.topk(a_r, k=self.topk, dim=-1)
         idx_r = idx_r.unsqueeze_(1).expand(-1, self.num_heads, -1, -1)

         # Token to token attention
-        (output, _) = regional_routing_attention_torch(
+        output, _ = regional_routing_attention_torch(
             q, k, v, scale=self.scale, region_graph=idx_r, region_size=region_size, auto_pad=True
         )

@@ -189,12 +190,12 @@ class Attention(nn.Module):
         self.proj_drop = nn.Dropout(proj_drop)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        (B, C, H, W) = x.size()
+        B, C, H, W = x.size()
         x = x.permute(0, 2, 3, 1).reshape(B, H * W, C)

         N = H * W
         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
-        (q, k, v) = qkv.unbind(0)
+        q, k, v = qkv.unbind(0)

         x = F.scaled_dot_product_attention(  # pylint: disable=not-callable
             q, k, v, dropout_p=self.attn_drop.p if self.training else 0.0, scale=self.scale
@@ -236,8 +237,8 @@ class AttentionLePE(nn.Module):
         )

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        (B, C, H, W) = x.size()
-        (q, k, v) = self.qkv(x).chunk(3, dim=1)
+        B, C, H, W = x.size()
+        q, k, v = self.qkv(x).chunk(3, dim=1)

         attn = q.view(B, self.num_heads, self.head_dim, H * W).transpose(-1, -2) @ k.view(
             B, self.num_heads, self.head_dim, H * W
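Aside from the docstring note, the edits in these hunks are cosmetic: `(B, C, H, W) = x.size()` and `B, C, H, W = x.size()` behave identically in Python, so the routing logic shown here is unchanged. The core routing step pools queries and keys per region, scores all region pairs, and keeps the top-k key regions per query region. A toy illustration of that step (all shapes are made up):

    import torch

    n, num_regions, c = 2, 16, 32
    q_r = torch.randn(n, num_regions, c)  # region-pooled queries
    k_r = torch.randn(n, c, num_regions)  # region-pooled keys, transposed

    a_r = q_r @ k_r  # (n, num_regions, num_regions) region-to-region affinity
    _, idx_r = torch.topk(a_r, k=4, dim=-1)  # top-4 key regions per query region
    print(idx_r.shape)  # torch.Size([2, 16, 4])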