birder 0.3.2-py3-none-any.whl → 0.4.0-py3-none-any.whl

This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
Files changed (97)
  1. birder/common/fs_ops.py +2 -2
  2. birder/introspection/attention_rollout.py +1 -1
  3. birder/introspection/transformer_attribution.py +1 -1
  4. birder/layers/layer_scale.py +1 -1
  5. birder/net/__init__.py +2 -10
  6. birder/net/_rope_vit_configs.py +430 -0
  7. birder/net/_vit_configs.py +479 -0
  8. birder/net/biformer.py +1 -0
  9. birder/net/cait.py +5 -5
  10. birder/net/coat.py +12 -12
  11. birder/net/conv2former.py +3 -3
  12. birder/net/convmixer.py +1 -1
  13. birder/net/convnext_v1.py +1 -1
  14. birder/net/crossvit.py +5 -5
  15. birder/net/davit.py +1 -1
  16. birder/net/deit.py +12 -26
  17. birder/net/deit3.py +42 -189
  18. birder/net/densenet.py +9 -8
  19. birder/net/detection/deformable_detr.py +5 -2
  20. birder/net/detection/detr.py +5 -2
  21. birder/net/detection/efficientdet.py +1 -1
  22. birder/net/dpn.py +1 -2
  23. birder/net/edgenext.py +2 -1
  24. birder/net/edgevit.py +3 -0
  25. birder/net/efficientformer_v1.py +2 -1
  26. birder/net/efficientformer_v2.py +18 -31
  27. birder/net/efficientnet_v2.py +3 -0
  28. birder/net/efficientvit_mit.py +5 -5
  29. birder/net/fasternet.py +2 -2
  30. birder/net/flexivit.py +22 -43
  31. birder/net/groupmixformer.py +1 -1
  32. birder/net/hgnet_v1.py +5 -5
  33. birder/net/hiera.py +3 -3
  34. birder/net/hieradet.py +116 -28
  35. birder/net/inception_next.py +1 -1
  36. birder/net/inception_resnet_v1.py +3 -3
  37. birder/net/inception_resnet_v2.py +7 -4
  38. birder/net/inception_v3.py +3 -0
  39. birder/net/inception_v4.py +3 -0
  40. birder/net/maxvit.py +1 -1
  41. birder/net/metaformer.py +3 -3
  42. birder/net/mim/crossmae.py +1 -1
  43. birder/net/mim/mae_vit.py +1 -1
  44. birder/net/mim/simmim.py +1 -1
  45. birder/net/mobilenet_v1.py +0 -9
  46. birder/net/mobilenet_v2.py +38 -44
  47. birder/net/{mobilenet_v3_large.py → mobilenet_v3.py} +37 -10
  48. birder/net/mobilevit_v1.py +5 -32
  49. birder/net/mobilevit_v2.py +1 -45
  50. birder/net/moganet.py +8 -5
  51. birder/net/mvit_v2.py +6 -6
  52. birder/net/nfnet.py +4 -0
  53. birder/net/pit.py +1 -1
  54. birder/net/pvt_v1.py +5 -5
  55. birder/net/pvt_v2.py +5 -5
  56. birder/net/repghost.py +1 -30
  57. birder/net/resmlp.py +2 -2
  58. birder/net/resnest.py +3 -0
  59. birder/net/resnet_v1.py +125 -1
  60. birder/net/resnet_v2.py +75 -1
  61. birder/net/resnext.py +35 -1
  62. birder/net/rope_deit3.py +33 -136
  63. birder/net/rope_flexivit.py +18 -18
  64. birder/net/rope_vit.py +3 -735
  65. birder/net/simple_vit.py +22 -16
  66. birder/net/smt.py +1 -1
  67. birder/net/squeezenet.py +5 -12
  68. birder/net/squeezenext.py +0 -24
  69. birder/net/ssl/capi.py +1 -1
  70. birder/net/ssl/data2vec.py +1 -1
  71. birder/net/ssl/dino_v2.py +2 -2
  72. birder/net/ssl/franca.py +2 -2
  73. birder/net/ssl/i_jepa.py +1 -1
  74. birder/net/ssl/ibot.py +1 -1
  75. birder/net/swiftformer.py +12 -2
  76. birder/net/swin_transformer_v2.py +1 -1
  77. birder/net/tiny_vit.py +3 -16
  78. birder/net/van.py +2 -2
  79. birder/net/vit.py +35 -963
  80. birder/net/vit_sam.py +13 -38
  81. birder/net/xcit.py +7 -6
  82. birder/scripts/train.py +17 -15
  83. birder/scripts/train_kd.py +17 -16
  84. birder/tools/introspection.py +1 -1
  85. birder/tools/model_info.py +3 -1
  86. birder/tools/show_iterator.py +16 -2
  87. birder/version.py +1 -1
  88. {birder-0.3.2.dist-info → birder-0.4.0.dist-info}/METADATA +1 -1
  89. {birder-0.3.2.dist-info → birder-0.4.0.dist-info}/RECORD +93 -95
  90. birder/net/mobilenet_v3_small.py +0 -43
  91. birder/net/se_resnet_v1.py +0 -105
  92. birder/net/se_resnet_v2.py +0 -59
  93. birder/net/se_resnext.py +0 -30
  94. {birder-0.3.2.dist-info → birder-0.4.0.dist-info}/WHEEL +0 -0
  95. {birder-0.3.2.dist-info → birder-0.4.0.dist-info}/entry_points.txt +0 -0
  96. {birder-0.3.2.dist-info → birder-0.4.0.dist-info}/licenses/LICENSE +0 -0
  97. {birder-0.3.2.dist-info → birder-0.4.0.dist-info}/top_level.txt +0 -0
birder/common/fs_ops.py CHANGED
@@ -158,7 +158,7 @@ def model_path(
         file_name = f"{file_name}_quantized"
 
     if states is True:
-        file_name = f"{file_name}_states"
+        file_name = f"{file_name}_states.pt"
     elif lite is True:
         file_name = f"{file_name}.ptl"
     elif pt2 is True:
@@ -254,7 +254,7 @@ def clean_checkpoints(network_name: str, keep_last: int) -> None:
     models_glob = str(model_path(network_name, epoch=epoch))
     states_glob = str(model_path(network_name, epoch=epoch, states=True))
     model_pattern = re.compile(r".*_([1-9][0-9]*)\.pt$")
-    states_pattern = re.compile(r".*_([1-9][0-9]*)_states$")
+    states_pattern = re.compile(r".*_([1-9][0-9]*)_states\.pt$")
 
     model_paths = list(settings.BASE_DIR.glob(models_glob))
     for p in sorted(model_paths, key=lambda p: p.stat().st_mtime)[:-keep_last]:
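
Training-state checkpoints now carry an explicit .pt extension ("_states.pt" rather than a bare "_states" suffix), and the cleanup regex follows suit. A minimal sketch of the new matching behavior; the file names below are illustrative, not taken from the package:

import re

# Updated pattern from clean_checkpoints: training states now end in "_states.pt"
states_pattern = re.compile(r".*_([1-9][0-9]*)_states\.pt$")

for name in ("net_50_states.pt", "net_50_states", "net_50.pt"):
    m = states_pattern.match(name)
    print(name, "->", m.group(1) if m else "no match")
# net_50_states.pt -> 50
# net_50_states -> no match (old-style name, no longer matched)
# net_50.pt -> no match (a weights file, handled by model_pattern instead)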
birder/introspection/attention_rollout.py CHANGED
@@ -141,7 +141,7 @@ class AttentionRollout:
         net: nn.Module,
         device: torch.device,
         transform: Callable[..., torch.Tensor],
-        attention_layer_name: str = "self_attention",
+        attention_layer_name: str = "attn",
         discard_ratio: float = 0.9,
         head_fusion: Literal["mean", "max", "min"] = "max",
     ) -> None:
birder/introspection/transformer_attribution.py CHANGED
@@ -132,7 +132,7 @@ class TransformerAttribution:
         net: nn.Module,
         device: torch.device,
         transform: Callable[..., torch.Tensor],
-        attention_layer_name: str = "self_attention",
+        attention_layer_name: str = "attn",
     ) -> None:
         self.net = net.eval()
         self.device = device
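
Both introspection helpers locate attention modules by name, so the default moves from "self_attention" to "attn", presumably tracking a rename of the attention submodules in the refactored ViT implementations. A rough sketch of how such a default is typically consumed; the hook logic below is an illustrative assumption, not birder's exact implementation:

import torch
from torch import nn

def register_attention_hooks(net: nn.Module, attention_layer_name: str = "attn") -> list[torch.Tensor]:
    # Capture the output of every module whose qualified name contains
    # attention_layer_name, e.g. "encoder.block.0.attn"
    captured: list[torch.Tensor] = []

    def hook(_module: nn.Module, _inputs: tuple, output: torch.Tensor) -> None:
        captured.append(output.detach().cpu())

    for name, module in net.named_modules():
        if attention_layer_name in name:
            module.register_forward_hook(hook)

    return captured  # populated after each forward pass through net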
birder/layers/layer_scale.py CHANGED
@@ -23,7 +23,7 @@ class LayerScale2d(nn.Module):
     def __init__(self, dim: int, init_value: float, inplace: bool = False) -> None:
         super().__init__()
         self.inplace = inplace
-        self.gamma = nn.Parameter(init_value * torch.ones(dim, 1, 1), requires_grad=True)
+        self.gamma = nn.Parameter(init_value * torch.ones(dim, 1, 1))
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         if self.inplace is True:
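
requires_grad=True is the default for nn.Parameter, so dropping the explicit argument is a pure cleanup with no behavioral change. A quick check confirming the equivalence:

import torch
from torch import nn

dim, init_value = 64, 1e-5
old_gamma = nn.Parameter(init_value * torch.ones(dim, 1, 1), requires_grad=True)
new_gamma = nn.Parameter(init_value * torch.ones(dim, 1, 1))

# Identical values, and both trainable by default
assert torch.equal(old_gamma, new_gamma)
assert old_gamma.requires_grad and new_gamma.requires_grad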
birder/net/__init__.py CHANGED
@@ -55,8 +55,7 @@ from birder.net.metaformer import MetaFormer
 from birder.net.mnasnet import MNASNet
 from birder.net.mobilenet_v1 import MobileNet_v1
 from birder.net.mobilenet_v2 import MobileNet_v2
-from birder.net.mobilenet_v3_large import MobileNet_v3_Large
-from birder.net.mobilenet_v3_small import MobileNet_v3_Small
+from birder.net.mobilenet_v3 import MobileNet_v3
 from birder.net.mobilenet_v4 import MobileNet_v4
 from birder.net.mobilenet_v4_hybrid import MobileNet_v4_Hybrid
 from birder.net.mobileone import MobileOne
@@ -84,9 +83,6 @@ from birder.net.resnext import ResNeXt
 from birder.net.rope_deit3 import RoPE_DeiT3
 from birder.net.rope_flexivit import RoPE_FlexiViT
 from birder.net.rope_vit import RoPE_ViT
-from birder.net.se_resnet_v1 import SE_ResNet_v1
-from birder.net.se_resnet_v2 import SE_ResNet_v2
-from birder.net.se_resnext import SE_ResNeXt
 from birder.net.sequencer2d import Sequencer2d
 from birder.net.shufflenet_v1 import ShuffleNet_v1
 from birder.net.shufflenet_v2 import ShuffleNet_v2
@@ -171,8 +167,7 @@ __all__ = [
     "MNASNet",
     "MobileNet_v1",
     "MobileNet_v2",
-    "MobileNet_v3_Large",
-    "MobileNet_v3_Small",
+    "MobileNet_v3",
     "MobileNet_v4",
     "MobileNet_v4_Hybrid",
     "MobileOne",
@@ -200,9 +195,6 @@ __all__ = [
     "RoPE_DeiT3",
     "RoPE_FlexiViT",
     "RoPE_ViT",
-    "SE_ResNet_v1",
-    "SE_ResNet_v2",
-    "SE_ResNeXt",
     "Sequencer2d",
     "ShuffleNet_v1",
     "ShuffleNet_v2",
birder/net/_rope_vit_configs.py ADDED
@@ -0,0 +1,430 @@
+"""
+RoPE ViT model configuration registrations
+
+This file contains *only* model variant definitions and their registration
+with the global model registry. The actual RoPE ViT implementation lives in rope_vit.py.
+"""
+
+from birder.model_registry import registry
+from birder.net._vit_configs import BASE
+from birder.net._vit_configs import GIANT
+from birder.net._vit_configs import GIGANTIC
+from birder.net._vit_configs import HUGE
+from birder.net._vit_configs import LARGE
+from birder.net._vit_configs import MEDIUM
+from birder.net._vit_configs import SMALL
+from birder.net._vit_configs import SO150
+from birder.net._vit_configs import SO400
+from birder.net._vit_configs import TINY
+from birder.net.base import BaseNet
+
+# Vision Transformer Model Naming Convention
+# ==========================================
+#
+# Model names follow a structured pattern to encode architectural choices:
+# [rope_]vit_[reg{N}_][size][patch_size][_components][_pooling][_c{N}]
+#
+# Core Components:
+# - rope_      : Rotary Position Embedding (RoPE) enabled
+# - rope_i_    : Rotary Position Embedding (RoPE) enabled with interleaved rotation - implies different temp, indexing
+# - vit_       : Vision Transformer base architecture
+# - reg{N}_    : Register tokens (N = number of register tokens, e.g., reg4, reg8)
+# - size       : Model size (s=small, b=base, l=large, or specific like so150m)
+# - patch_size : Patch size (e.g., 14, 16, 32 for 14x14, 16x16, 32x32 patches)
+#
+# Optional Components:
+# Position Embeddings:
+# - nps : No Position embedding on Special tokens
+#
+# Normalization:
+# - rms : RMSNorm (instead of LayerNorm)
+# - pn  : Pre-Norm (layer norm before the encoder) - implies norm eps of 1e-5
+# - npn : No Post Norm (disables post-normalization layer)
+# - qkn : QK Norm
+#
+# Feed-Forward Network:
+# - swiglu : SwiGLU FFN layer type (instead of standard FFN)
+#
+# Activation:
+# - quick_gelu : QuickGELU activation type
+# - ...
+#
+# Regularization:
+# - ls : Layer Scaling applied
+#
+# Pooling/Reduction:
+# - avg : Average pooling for sequence reduction
+# - ap  : Attention Pooling for sequence reduction
+# - aps : Attention Pooling inc. Special tokens for sequence reduction
+#
+# Custom Variants:
+# - c{N} : Custom variant (N = version number) for models with fine-grained or non-standard
+#          modifications not fully reflected in the name
+
+
+def register_rope_vit_configs(rope_vit: type[BaseNet]) -> None:
+    registry.register_model_config(
+        "rope_vit_t32",
+        rope_vit,
+        config={"patch_size": 32, **TINY},
+    )
+    registry.register_model_config(
+        "rope_vit_t16",
+        rope_vit,
+        config={"patch_size": 16, **TINY},
+    )
+    registry.register_model_config(
+        "rope_vit_t14",
+        rope_vit,
+        config={"patch_size": 14, **TINY},
+    )
+    registry.register_model_config(
+        "rope_vit_s32",
+        rope_vit,
+        config={"patch_size": 32, **SMALL},
+    )
+    registry.register_model_config(
+        "rope_vit_s16",
+        rope_vit,
+        config={"patch_size": 16, **SMALL},
+    )
+    registry.register_model_config(
+        "rope_i_vit_s16_pn_aps_c1",  # For PE Core - https://arxiv.org/abs/2504.13181
+        rope_vit,
+        config={
+            "patch_size": 16,
+            **SMALL,
+            "pre_norm": True,
+            "attn_pool_head": True,
+            "attn_pool_num_heads": 8,
+            "attn_pool_special_tokens": True,
+            "norm_layer_eps": 1e-5,
+            "rope_rot_type": "interleaved",
+            "rope_grid_indexing": "xy",
+            "rope_grid_offset": 1,
+            "rope_temperature": 10000.0,
+        },
+    )
+    registry.register_model_config(
+        "rope_vit_s14",
+        rope_vit,
+        config={"patch_size": 14, **SMALL},
+    )
+    registry.register_model_config(
+        "rope_vit_m32",
+        rope_vit,
+        config={"patch_size": 32, **MEDIUM},
+    )
+    registry.register_model_config(
+        "rope_vit_m16",
+        rope_vit,
+        config={"patch_size": 16, **MEDIUM},
+    )
+    registry.register_model_config(
+        "rope_vit_m14",
+        rope_vit,
+        config={"patch_size": 14, **MEDIUM},
+    )
+    registry.register_model_config(
+        "rope_vit_b32",
+        rope_vit,
+        config={"patch_size": 32, **BASE, "drop_path_rate": 0.0},  # Override the BASE definition
+    )
+    registry.register_model_config(
+        "rope_vit_b16",
+        rope_vit,
+        config={"patch_size": 16, **BASE},
+    )
+    registry.register_model_config(
+        "rope_vit_b16_qkn_ls",
+        rope_vit,
+        config={"patch_size": 16, **BASE, "layer_scale_init_value": 1e-5, "qk_norm": True},
+    )
+    registry.register_model_config(
+        "rope_i_vit_b16_pn_aps_c1",  # For PE Core - https://arxiv.org/abs/2504.13181
+        rope_vit,
+        config={
+            "patch_size": 16,
+            **BASE,
+            "pre_norm": True,
+            "attn_pool_head": True,
+            "attn_pool_num_heads": 8,
+            "attn_pool_special_tokens": True,
+            "norm_layer_eps": 1e-5,
+            "rope_rot_type": "interleaved",
+            "rope_grid_indexing": "xy",
+            "rope_grid_offset": 1,
+            "rope_temperature": 10000.0,
+        },
+    )
+    registry.register_model_config(
+        "rope_vit_b14",
+        rope_vit,
+        config={"patch_size": 14, **BASE},
+    )
+    registry.register_model_config(
+        "rope_vit_so150m_p14_ap",
+        rope_vit,
+        config={"patch_size": 14, **SO150, "class_token": False, "attn_pool_head": True},
+    )
+    registry.register_model_config(
+        "rope_vit_l32",
+        rope_vit,
+        config={"patch_size": 32, **LARGE},
+    )
+    registry.register_model_config(
+        "rope_vit_l16",
+        rope_vit,
+        config={"patch_size": 16, **LARGE},
+    )
+    registry.register_model_config(
+        "rope_vit_l14",
+        rope_vit,
+        config={"patch_size": 14, **LARGE},
+    )
+    registry.register_model_config(
+        "rope_i_vit_l14_pn_aps_c1",  # For PE Core - https://arxiv.org/abs/2504.13181
+        rope_vit,
+        config={
+            "patch_size": 14,
+            **LARGE,
+            "pre_norm": True,
+            "attn_pool_head": True,
+            "attn_pool_num_heads": 8,
+            "attn_pool_special_tokens": True,
+            "norm_layer_eps": 1e-5,
+            "rope_rot_type": "interleaved",
+            "rope_grid_indexing": "xy",
+            "rope_grid_offset": 1,
+            "rope_temperature": 10000.0,
+        },
+    )
+    registry.register_model_config(
+        "rope_vit_so400m_p14_ap",
+        rope_vit,
+        config={"patch_size": 14, **SO400, "class_token": False, "attn_pool_head": True},
+    )
+    registry.register_model_config(
+        "rope_vit_h16",
+        rope_vit,
+        config={"patch_size": 16, **HUGE},
+    )
+    registry.register_model_config(
+        "rope_vit_h14",
+        rope_vit,
+        config={"patch_size": 14, **HUGE},
+    )
+    registry.register_model_config(  # From "Scaling Vision Transformers"
+        "rope_vit_g16",
+        rope_vit,
+        config={"patch_size": 16, **GIANT},
+    )
+    registry.register_model_config(  # From "Scaling Vision Transformers"
+        "rope_vit_g14",
+        rope_vit,
+        config={"patch_size": 14, **GIANT},
+    )
+    registry.register_model_config(  # From "Scaling Vision Transformers"
+        "rope_vit_gigantic14",
+        rope_vit,
+        config={"patch_size": 14, **GIGANTIC},
+    )
+
+    # With registers
+    ####################
+
+    registry.register_model_config(
+        "rope_vit_reg1_t32",
+        rope_vit,
+        config={"patch_size": 32, **TINY, "num_reg_tokens": 1},
+    )
+    registry.register_model_config(
+        "rope_vit_reg1_t16",
+        rope_vit,
+        config={"patch_size": 16, **TINY, "num_reg_tokens": 1},
+    )
+    registry.register_model_config(
+        "rope_vit_reg1_t14",
+        rope_vit,
+        config={"patch_size": 14, **TINY, "num_reg_tokens": 1},
+    )
+    registry.register_model_config(
+        "rope_vit_reg1_s32",
+        rope_vit,
+        config={"patch_size": 32, **SMALL, "num_reg_tokens": 1},
+    )
+    registry.register_model_config(
+        "rope_vit_reg1_s16",
+        rope_vit,
+        config={"patch_size": 16, **SMALL, "num_reg_tokens": 1},
+    )
+    registry.register_model_config(
+        "rope_i_vit_reg1_s16_pn_npn_avg_c1",  # For PE Spatial - https://arxiv.org/abs/2504.13181
+        rope_vit,
+        config={
+            "patch_size": 16,
+            **SMALL,
+            "num_reg_tokens": 1,
+            "class_token": False,
+            "pre_norm": True,
+            "post_norm": False,
+            "norm_layer_eps": 1e-5,
+            "rope_rot_type": "interleaved",
+            "rope_grid_indexing": "xy",
+            "rope_grid_offset": 1,
+            "rope_temperature": 10000.0,
+        },
+    )
+    registry.register_model_config(
+        "rope_vit_reg1_s14",
+        rope_vit,
+        config={"patch_size": 14, **SMALL, "num_reg_tokens": 1},
+    )
+    registry.register_model_config(
+        "rope_vit_reg4_m32",
+        rope_vit,
+        config={"patch_size": 32, **MEDIUM, "num_reg_tokens": 4},
+    )
+    registry.register_model_config(
+        "rope_vit_reg4_m16",
+        rope_vit,
+        config={"patch_size": 16, **MEDIUM, "num_reg_tokens": 4},
+    )
+    registry.register_model_config(
+        "rope_vit_reg4_m16_rms_avg",
+        rope_vit,
+        config={"patch_size": 16, **MEDIUM, "num_reg_tokens": 4, "class_token": False, "norm_layer_type": "RMSNorm"},
+    )
+    registry.register_model_config(
+        "rope_vit_reg4_m14",
+        rope_vit,
+        config={"patch_size": 14, **MEDIUM, "num_reg_tokens": 4},
+    )
+    registry.register_model_config(
+        "rope_vit_reg4_m14_avg",
+        rope_vit,
+        config={"patch_size": 14, **MEDIUM, "num_reg_tokens": 4, "class_token": False},
+    )
+    registry.register_model_config(
+        "rope_vit_reg4_b32",
+        rope_vit,
+        config={"patch_size": 32, **BASE, "num_reg_tokens": 4, "drop_path_rate": 0.0},  # Override the BASE definition
+    )
+    registry.register_model_config(
+        "rope_vit_reg4_b16",
+        rope_vit,
+        config={"patch_size": 16, **BASE, "num_reg_tokens": 4},
+    )
+    registry.register_model_config(
+        "rope_vit_reg4_b14",
+        rope_vit,
+        config={"patch_size": 14, **BASE, "num_reg_tokens": 4},
+    )
+    registry.register_model_config(
+        "rope_vit_reg8_nps_b14_ap",
+        rope_vit,
+        config={
+            "pos_embed_special_tokens": False,
+            "patch_size": 14,
+            **BASE,
+            "num_reg_tokens": 8,
+            "class_token": False,
+            "attn_pool_head": True,
+        },
+    )
+    registry.register_model_config(
+        "rope_vit_reg4_so150m_p14_ap",
+        rope_vit,
+        config={"patch_size": 14, **SO150, "num_reg_tokens": 4, "class_token": False, "attn_pool_head": True},
+    )
+    registry.register_model_config(
+        "rope_vit_reg8_so150m_p14_ap",
+        rope_vit,
+        config={"patch_size": 14, **SO150, "num_reg_tokens": 8, "class_token": False, "attn_pool_head": True},
+    )
+    registry.register_model_config(
+        "rope_vit_reg8_so150m_p14_swiglu_rms_avg",
+        rope_vit,
+        config={
+            "patch_size": 14,
+            **SO150,
+            "num_reg_tokens": 8,
+            "class_token": False,
+            "norm_layer_type": "RMSNorm",
+            "mlp_layer_type": "SwiGLU_FFN",
+        },
+    )
+    registry.register_model_config(
+        "rope_vit_reg8_so150m_p14_swiglu_rms_ap",
+        rope_vit,
+        config={
+            "patch_size": 14,
+            **SO150,
+            "num_reg_tokens": 8,
+            "class_token": False,
+            "attn_pool_head": True,
+            "norm_layer_type": "RMSNorm",
+            "mlp_layer_type": "SwiGLU_FFN",
+        },
+    )
+    registry.register_model_config(
+        "rope_vit_reg8_so150m_p14_swiglu_rms_aps",
+        rope_vit,
+        config={
+            "patch_size": 14,
+            **SO150,
+            "num_reg_tokens": 8,
+            "class_token": False,
+            "attn_pool_head": True,
+            "attn_pool_special_tokens": True,
+            "norm_layer_type": "RMSNorm",
+            "mlp_layer_type": "SwiGLU_FFN",
+        },
+    )
+    registry.register_model_config(
+        "rope_vit_reg4_l32",
+        rope_vit,
+        config={"patch_size": 32, **LARGE, "num_reg_tokens": 4},
+    )
+    registry.register_model_config(
+        "rope_vit_reg4_l16",
+        rope_vit,
+        config={"patch_size": 16, **LARGE, "num_reg_tokens": 4},
+    )
+    registry.register_model_config(
+        "rope_vit_reg4_l14",
+        rope_vit,
+        config={"patch_size": 14, **LARGE, "num_reg_tokens": 4},
+    )
+    registry.register_model_config(
+        "rope_vit_reg8_l14_rms_ap",
+        rope_vit,
+        config={
+            "patch_size": 14,
+            **LARGE,
+            "num_reg_tokens": 8,
+            "class_token": False,
+            "attn_pool_head": True,
+            "norm_layer_type": "RMSNorm",
+        },
+    )
+    registry.register_model_config(
+        "rope_vit_reg8_so400m_p14_ap",
+        rope_vit,
+        config={"patch_size": 14, **SO400, "num_reg_tokens": 8, "class_token": False, "attn_pool_head": True},
+    )
+    registry.register_model_config(
+        "rope_vit_reg4_h16",
+        rope_vit,
+        config={"patch_size": 16, **HUGE, "num_reg_tokens": 4},
+    )
+    registry.register_model_config(
+        "rope_vit_reg4_h14",
+        rope_vit,
+        config={"patch_size": 14, **HUGE, "num_reg_tokens": 4},
+    )
+    registry.register_model_config(  # From "Scaling Vision Transformers"
+        "rope_vit_reg4_g14",
+        rope_vit,
+        config={"patch_size": 14, **GIANT, "num_reg_tokens": 4},
+    )
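
Splitting the variant tables out of the model files accounts for the large reductions elsewhere in the diff (vit.py -963, rope_vit.py -735): the architectures stay put while their configurations move into _vit_configs.py and _rope_vit_configs.py. The naming convention documented at the top of the new file can also be read off mechanically; here is a hypothetical helper (not part of birder) that unpacks a few of the documented markers:

def describe(name: str) -> dict[str, object]:
    # Illustrative only: interpret name markers per the convention above
    parts = name.split("_")
    return {
        "rope": parts[0] == "rope",
        "interleaved_rope": parts[:2] == ["rope", "i"],
        "register_tokens": next((int(p[3:]) for p in parts if p.startswith("reg") and p[3:].isdigit()), 0),
        "rms_norm": "rms" in parts,
        "swiglu_ffn": "swiglu" in parts,
        "attention_pooling": "ap" in parts or "aps" in parts,
    }

print(describe("rope_vit_reg8_so150m_p14_swiglu_rms_aps"))
# {'rope': True, 'interleaved_rope': False, 'register_tokens': 8,
#  'rms_norm': True, 'swiglu_ffn': True, 'attention_pooling': True}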