birder 0.3.3__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (216)
  1. birder/adversarial/base.py +1 -1
  2. birder/adversarial/simba.py +4 -4
  3. birder/common/cli.py +1 -1
  4. birder/common/fs_ops.py +13 -13
  5. birder/common/lib.py +2 -2
  6. birder/common/masking.py +3 -3
  7. birder/common/training_cli.py +24 -2
  8. birder/common/training_utils.py +28 -4
  9. birder/data/collators/detection.py +9 -1
  10. birder/data/transforms/detection.py +27 -8
  11. birder/data/transforms/mosaic.py +1 -1
  12. birder/datahub/classification.py +3 -3
  13. birder/inference/classification.py +3 -3
  14. birder/inference/data_parallel.py +1 -1
  15. birder/inference/detection.py +5 -5
  16. birder/inference/wbf.py +1 -1
  17. birder/introspection/attention_rollout.py +6 -6
  18. birder/introspection/feature_pca.py +4 -4
  19. birder/introspection/gradcam.py +1 -1
  20. birder/introspection/guided_backprop.py +2 -2
  21. birder/introspection/transformer_attribution.py +4 -4
  22. birder/layers/attention_pool.py +2 -2
  23. birder/layers/layer_scale.py +1 -1
  24. birder/model_registry/model_registry.py +2 -1
  25. birder/net/__init__.py +4 -10
  26. birder/net/_rope_vit_configs.py +435 -0
  27. birder/net/_vit_configs.py +466 -0
  28. birder/net/alexnet.py +5 -5
  29. birder/net/base.py +28 -3
  30. birder/net/biformer.py +18 -17
  31. birder/net/cait.py +7 -7
  32. birder/net/cas_vit.py +1 -1
  33. birder/net/coat.py +27 -27
  34. birder/net/conv2former.py +3 -3
  35. birder/net/convmixer.py +1 -1
  36. birder/net/convnext_v1.py +3 -11
  37. birder/net/convnext_v1_iso.py +198 -0
  38. birder/net/convnext_v2.py +2 -10
  39. birder/net/crossformer.py +9 -9
  40. birder/net/crossvit.py +6 -6
  41. birder/net/cspnet.py +1 -1
  42. birder/net/cswin_transformer.py +10 -10
  43. birder/net/davit.py +11 -11
  44. birder/net/deit.py +68 -29
  45. birder/net/deit3.py +69 -204
  46. birder/net/densenet.py +9 -8
  47. birder/net/detection/__init__.py +4 -0
  48. birder/net/detection/{yolo_anchors.py → _yolo_anchors.py} +5 -5
  49. birder/net/detection/base.py +6 -5
  50. birder/net/detection/deformable_detr.py +31 -30
  51. birder/net/detection/detr.py +14 -11
  52. birder/net/detection/efficientdet.py +10 -29
  53. birder/net/detection/faster_rcnn.py +22 -22
  54. birder/net/detection/fcos.py +8 -8
  55. birder/net/detection/plain_detr.py +852 -0
  56. birder/net/detection/retinanet.py +4 -4
  57. birder/net/detection/rt_detr_v1.py +81 -25
  58. birder/net/detection/rt_detr_v2.py +1147 -0
  59. birder/net/detection/ssd.py +5 -5
  60. birder/net/detection/yolo_v2.py +12 -12
  61. birder/net/detection/yolo_v3.py +19 -19
  62. birder/net/detection/yolo_v4.py +16 -16
  63. birder/net/detection/yolo_v4_tiny.py +3 -3
  64. birder/net/dpn.py +1 -2
  65. birder/net/edgenext.py +5 -4
  66. birder/net/edgevit.py +13 -14
  67. birder/net/efficientformer_v1.py +3 -2
  68. birder/net/efficientformer_v2.py +18 -31
  69. birder/net/efficientnet_v2.py +3 -0
  70. birder/net/efficientvim.py +9 -9
  71. birder/net/efficientvit_mit.py +7 -7
  72. birder/net/efficientvit_msft.py +3 -3
  73. birder/net/fasternet.py +3 -3
  74. birder/net/fastvit.py +5 -12
  75. birder/net/flexivit.py +50 -58
  76. birder/net/focalnet.py +5 -9
  77. birder/net/gc_vit.py +11 -11
  78. birder/net/ghostnet_v1.py +1 -1
  79. birder/net/ghostnet_v2.py +1 -1
  80. birder/net/groupmixformer.py +13 -13
  81. birder/net/hgnet_v1.py +6 -6
  82. birder/net/hgnet_v2.py +4 -4
  83. birder/net/hiera.py +6 -6
  84. birder/net/hieradet.py +9 -9
  85. birder/net/hornet.py +3 -3
  86. birder/net/iformer.py +4 -4
  87. birder/net/inception_next.py +5 -15
  88. birder/net/inception_resnet_v1.py +3 -3
  89. birder/net/inception_resnet_v2.py +7 -4
  90. birder/net/inception_v3.py +3 -0
  91. birder/net/inception_v4.py +3 -0
  92. birder/net/levit.py +3 -3
  93. birder/net/lit_v1.py +13 -15
  94. birder/net/lit_v1_tiny.py +9 -9
  95. birder/net/lit_v2.py +14 -15
  96. birder/net/maxvit.py +11 -23
  97. birder/net/metaformer.py +5 -5
  98. birder/net/mim/crossmae.py +6 -6
  99. birder/net/mim/fcmae.py +3 -5
  100. birder/net/mim/mae_hiera.py +7 -7
  101. birder/net/mim/mae_vit.py +4 -6
  102. birder/net/mim/simmim.py +3 -4
  103. birder/net/mobilenet_v1.py +0 -9
  104. birder/net/mobilenet_v2.py +38 -44
  105. birder/net/{mobilenet_v3_large.py → mobilenet_v3.py} +37 -10
  106. birder/net/mobilenet_v4_hybrid.py +4 -4
  107. birder/net/mobileone.py +5 -12
  108. birder/net/mobilevit_v1.py +7 -34
  109. birder/net/mobilevit_v2.py +6 -54
  110. birder/net/moganet.py +8 -5
  111. birder/net/mvit_v2.py +30 -30
  112. birder/net/nextvit.py +2 -2
  113. birder/net/nfnet.py +4 -0
  114. birder/net/pit.py +11 -26
  115. birder/net/pvt_v1.py +9 -9
  116. birder/net/pvt_v2.py +10 -16
  117. birder/net/regionvit.py +15 -15
  118. birder/net/regnet.py +1 -1
  119. birder/net/repghost.py +5 -35
  120. birder/net/repvgg.py +3 -5
  121. birder/net/repvit.py +2 -2
  122. birder/net/resmlp.py +2 -2
  123. birder/net/resnest.py +4 -1
  124. birder/net/resnet_v1.py +125 -1
  125. birder/net/resnet_v2.py +75 -1
  126. birder/net/resnext.py +35 -1
  127. birder/net/rope_deit3.py +62 -151
  128. birder/net/rope_flexivit.py +46 -33
  129. birder/net/rope_vit.py +44 -758
  130. birder/net/sequencer2d.py +3 -4
  131. birder/net/shufflenet_v1.py +1 -1
  132. birder/net/shufflenet_v2.py +1 -1
  133. birder/net/simple_vit.py +69 -21
  134. birder/net/smt.py +8 -8
  135. birder/net/squeezenet.py +5 -12
  136. birder/net/squeezenext.py +0 -24
  137. birder/net/ssl/barlow_twins.py +1 -1
  138. birder/net/ssl/byol.py +2 -2
  139. birder/net/ssl/capi.py +4 -4
  140. birder/net/ssl/data2vec.py +1 -1
  141. birder/net/ssl/data2vec2.py +1 -1
  142. birder/net/ssl/dino_v2.py +13 -3
  143. birder/net/ssl/franca.py +28 -4
  144. birder/net/ssl/i_jepa.py +5 -5
  145. birder/net/ssl/ibot.py +1 -1
  146. birder/net/ssl/mmcr.py +1 -1
  147. birder/net/swiftformer.py +13 -3
  148. birder/net/swin_transformer_v1.py +4 -5
  149. birder/net/swin_transformer_v2.py +5 -8
  150. birder/net/tiny_vit.py +6 -19
  151. birder/net/transnext.py +19 -19
  152. birder/net/uniformer.py +4 -4
  153. birder/net/van.py +2 -2
  154. birder/net/vgg.py +1 -10
  155. birder/net/vit.py +72 -987
  156. birder/net/vit_parallel.py +35 -20
  157. birder/net/vit_sam.py +23 -48
  158. birder/net/vovnet_v2.py +1 -1
  159. birder/net/xcit.py +16 -13
  160. birder/ops/msda.py +4 -4
  161. birder/ops/swattention.py +10 -10
  162. birder/results/classification.py +3 -3
  163. birder/results/gui.py +8 -8
  164. birder/scripts/benchmark.py +37 -12
  165. birder/scripts/evaluate.py +1 -1
  166. birder/scripts/predict.py +3 -3
  167. birder/scripts/predict_detection.py +2 -2
  168. birder/scripts/train.py +63 -15
  169. birder/scripts/train_barlow_twins.py +10 -7
  170. birder/scripts/train_byol.py +10 -7
  171. birder/scripts/train_capi.py +15 -10
  172. birder/scripts/train_data2vec.py +10 -7
  173. birder/scripts/train_data2vec2.py +10 -7
  174. birder/scripts/train_detection.py +29 -14
  175. birder/scripts/train_dino_v1.py +13 -9
  176. birder/scripts/train_dino_v2.py +27 -14
  177. birder/scripts/train_dino_v2_dist.py +28 -15
  178. birder/scripts/train_franca.py +16 -9
  179. birder/scripts/train_i_jepa.py +12 -9
  180. birder/scripts/train_ibot.py +15 -11
  181. birder/scripts/train_kd.py +64 -17
  182. birder/scripts/train_mim.py +11 -8
  183. birder/scripts/train_mmcr.py +11 -8
  184. birder/scripts/train_rotnet.py +11 -7
  185. birder/scripts/train_simclr.py +10 -7
  186. birder/scripts/train_vicreg.py +10 -7
  187. birder/tools/adversarial.py +4 -4
  188. birder/tools/auto_anchors.py +5 -5
  189. birder/tools/avg_model.py +1 -1
  190. birder/tools/convert_model.py +30 -22
  191. birder/tools/det_results.py +1 -1
  192. birder/tools/download_model.py +1 -1
  193. birder/tools/ensemble_model.py +1 -1
  194. birder/tools/introspection.py +12 -3
  195. birder/tools/labelme_to_coco.py +2 -2
  196. birder/tools/model_info.py +15 -15
  197. birder/tools/pack.py +8 -8
  198. birder/tools/quantize_model.py +53 -4
  199. birder/tools/results.py +2 -2
  200. birder/tools/show_det_iterator.py +19 -6
  201. birder/tools/show_iterator.py +2 -2
  202. birder/tools/similarity.py +5 -5
  203. birder/tools/stats.py +4 -6
  204. birder/tools/voc_to_coco.py +1 -1
  205. birder/version.py +1 -1
  206. {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/METADATA +3 -3
  207. birder-0.4.1.dist-info/RECORD +300 -0
  208. {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/WHEEL +1 -1
  209. birder/net/mobilenet_v3_small.py +0 -43
  210. birder/net/se_resnet_v1.py +0 -105
  211. birder/net/se_resnet_v2.py +0 -59
  212. birder/net/se_resnext.py +0 -30
  213. birder-0.3.3.dist-info/RECORD +0 -299
  214. {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/entry_points.txt +0 -0
  215. {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/licenses/LICENSE +0 -0
  216. {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/top_level.txt +0 -0
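The largest change below is in birder/net/vit.py: the long list of inline registry.register_model_config(...) calls moves out to birder/net/_vit_configs.py (applied via register_vit_configs(ViT)), and the ViT backbone gains an out_indices config option so detection_features() can return several intermediate encoder blocks as named stages instead of only the final output. A minimal sketch of a config using the new option follows; the model name is hypothetical and only the config keys and registration call visible in the diff are assumed.

# Sketch only, not taken from the package; "vit_b16_multi" is a made-up name.
from birder.model_registry import registry
from birder.net.vit import ViT

registry.register_model_config(
    "vit_b16_multi",
    ViT,
    config={
        "patch_size": 16,
        "num_layers": 12,
        "num_heads": 12,
        "hidden_dim": 768,
        "mlp_dim": 3072,
        "drop_path_rate": 0.1,
        # New in 0.4.1: encoder blocks to expose as backbone stages.
        # detection_features() then returns {"stage1": ..., ..., "stage4": ...},
        # each of shape (B, hidden_dim, H // patch_size, W // patch_size).
        "out_indices": [2, 5, 8, 11],
    },
)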
birder/net/vit.py CHANGED
@@ -10,8 +10,6 @@ and
  Paper "Vision Transformers Need Registers", https://arxiv.org/abs/2309.16588
  and
  Paper "Getting ViT in Shape: Scaling Laws for Compute-Optimal Model Design", https://arxiv.org/abs/2305.13035
- and
- Paper "Scaling Vision Transformers", https://arxiv.org/abs/2106.04560
  """

  # Reference license: BSD 3-Clause and Apache-2.0
@@ -35,12 +33,14 @@ from birder.layers import MultiHeadAttentionPool
  from birder.layers import SwiGLU_FFN
  from birder.layers.activations import get_activation_module
  from birder.model_registry import registry
+ from birder.net._vit_configs import register_vit_configs
  from birder.net.base import DetectorBackbone
  from birder.net.base import MaskedTokenOmissionMixin
  from birder.net.base import MaskedTokenRetentionMixin
  from birder.net.base import PreTrainEncoder
  from birder.net.base import TokenOmissionResultType
  from birder.net.base import TokenRetentionResultType
+ from birder.net.base import normalize_out_indices


  def adjust_position_embedding(
@@ -74,12 +74,10 @@ def adjust_position_embedding(
  class PatchEmbed(nn.Module):
  def forward(self, x: torch.Tensor) -> torch.Tensor:
  """
- The entire forward is equivalent to x.flatten(2).transpose(1, 2)
+ This is equivalent (in output) to: x.flatten(2).transpose(1, 2)
  """

- (n, hidden_dim, h, w) = x.size()
-
- # (n, hidden_dim, h, w) -> (n, hidden_dim, (h * w))
+ n, hidden_dim, h, w = x.size()
  x = x.reshape(n, hidden_dim, h * w)

  # (n, hidden_dim, (h * w)) -> (n, (h * w), hidden_dim)
@@ -122,14 +120,10 @@ class Attention(nn.Module):
  self.proj = nn.Linear(dim, dim)
  self.proj_drop = nn.Dropout(proj_drop)

- # Make the same interface as nn.MultiheadAttention forward for TorchScript compatibility
  def forward(
  self,
  x: torch.Tensor,
- key: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
- value: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
  need_weights: bool = False,
- attn_mask: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
  average_attn_weights: bool = False,
  is_causal: bool = False,
  ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
@@ -137,30 +131,16 @@
  Apply multi-head self-attention to the input sequence

  This module implements scaled dot-product attention over x and returns the
- projected output. The method signature intentionally matches
- torch.nn.MultiheadAttention.forward for TorchScript compatibility.
-
- Compatibility notes
- -------------------
- The following parameters are accepted for API compatibility but are ignored by this implementation:
- - key: ignored (keys are computed from x)
- - value: ignored (values are computed from x)
- - attn_mask: ignored (no external attention mask is applied)
+ projected output.

  Parameters
  ----------
  x
  Input tensor of shape (B, N, C) where B is batch size, N is sequence length,
  and C is embedding dimension.
- key
- Unused. Present for nn.MultiheadAttention-compatible signature.
- value
- Unused. Present for nn.MultiheadAttention-compatible signature.
  need_weights
  If True, also return attention weights computed explicitly. If False, uses
  torch.nn.functional.scaled_dot_product_attention and returns None for attention weights.
- attn_mask
- Unused. Present for nn.MultiheadAttention-compatible signature.
  average_attn_weights
  If True and need_weights is True, average attention weights across heads
  to shape (B, N, N). If False, return per-head weights of shape (B, num_heads, N, N).
@@ -174,9 +154,9 @@
  - attn_weights: If need_weights is True attention weights, otherwise, None.
  """

- (B, N, C) = x.size()
+ B, N, C = x.size()
  qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
- (q, k, v) = qkv.unbind(0)
+ q, k, v = qkv.unbind(0)
  q = self.q_norm(q)
  k = self.k_norm(k)

@@ -231,41 +211,32 @@ class EncoderBlock(nn.Module):
  super().__init__()
  self.need_attn = False
  self.is_causal = False
- self.use_custom_attn = qk_norm is True

  if mlp_dim is None:
  mlp_dim = hidden_dim * 4

  # Attention block
- self.ln1 = norm_layer(hidden_dim, eps=norm_layer_eps)
-
- if self.use_custom_attn is False:
- # Prefer PyTorch's built-in MultiheadAttention for the "standard" case
- self.self_attention = nn.MultiheadAttention(
- hidden_dim, num_heads, dropout=attention_dropout, bias=qkv_bias, batch_first=True
- )
- else:
- self.self_attention = Attention(
- hidden_dim,
- num_heads=num_heads,
- attn_drop=attention_dropout,
- proj_drop=0.0,
- qkv_bias=qkv_bias,
- qk_norm=qk_norm,
- norm_layer=norm_layer,
- norm_layer_eps=norm_layer_eps,
- )
+ self.norm1 = norm_layer(hidden_dim, eps=norm_layer_eps)
+ self.attn = Attention(
+ hidden_dim,
+ num_heads=num_heads,
+ attn_drop=attention_dropout,
+ proj_drop=0.0,
+ qkv_bias=qkv_bias,
+ qk_norm=qk_norm,
+ norm_layer=norm_layer,
+ norm_layer_eps=norm_layer_eps,
+ )

- self.drop_path1 = StochasticDepth(drop_path, mode="row")
+ self.drop_path = StochasticDepth(drop_path, mode="row")
  if layer_scale_init_value is not None:
  self.layer_scale_1 = LayerScale(hidden_dim, layer_scale_init_value)
  else:
  self.layer_scale_1 = nn.Identity()

  # MLP block
- self.ln2 = norm_layer(hidden_dim, eps=norm_layer_eps)
+ self.norm2 = norm_layer(hidden_dim, eps=norm_layer_eps)
  self.mlp = mlp_layer(hidden_dim, mlp_dim, act_layer=activation_layer, dropout=dropout)
- self.drop_path2 = StochasticDepth(drop_path, mode="row")
  if layer_scale_init_value is not None:
  self.layer_scale_2 = LayerScale(hidden_dim, layer_scale_init_value)
  else:
@@ -273,34 +244,14 @@

  def forward(self, x: torch.Tensor) -> torch.Tensor:
  # torch._assert(x.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {x.size()}")
- branch1 = self.ln1(x)
- if self.is_causal is True:
- seq_len = x.size(1)
- attn_mask = torch.triu(
- torch.full((seq_len, seq_len), float("-inf"), dtype=x.dtype, device=x.device),
- diagonal=1,
- )
- else:
- attn_mask = None
-
- (branch1, _) = self.self_attention(
- branch1,
- branch1,
- branch1,
+ attn_out, _ = self.attn(
+ self.norm1(x),
  need_weights=self.need_attn,
- attn_mask=attn_mask, # Ignored on the custom attention
  average_attn_weights=False,
  is_causal=self.is_causal,
  )
-
- branch1 = self.layer_scale_1(branch1)
- branch1 = self.drop_path1(branch1) + x
-
- branch2 = self.ln2(branch1)
- branch2 = self.mlp(branch2)
- branch2 = self.layer_scale_2(branch2)
-
- x = self.drop_path2(branch2) + branch1
+ x = x + self.drop_path(self.layer_scale_1(attn_out))
+ x = x + self.drop_path(self.layer_scale_2(self.mlp(self.norm2(x))))

  return x

@@ -365,13 +316,15 @@
  x = self.pre_block(x)
  return self.block(x)

- def forward_features(self, x: torch.Tensor) -> list[torch.Tensor]:
+ def forward_features(self, x: torch.Tensor, out_indices: Optional[list[int]] = None) -> list[torch.Tensor]:
  x = self.pre_block(x)

+ out_indices_set = set(out_indices) if out_indices is not None else None
  xs = []
- for blk in self.block:
+ for idx, blk in enumerate(self.block):
  x = blk(x)
- xs.append(x)
+ if out_indices_set is None or idx in out_indices_set:
+ xs.append(x)

  return xs

@@ -388,7 +341,7 @@
  class ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedTokenRetentionMixin):
  block_group_regex = r"encoder\.block\.(\d+)"

- # pylint: disable=too-many-locals,too-many-branches
+ # pylint: disable=too-many-locals,too-many-branches,too-many-statements
  def __init__(
  self,
  input_channels: int,
@@ -423,6 +376,7 @@ class ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedTok
  norm_layer_eps: float = self.config.get("norm_layer_eps", 1e-6)
  mlp_layer_type: str = self.config.get("mlp_layer_type", "FFN")
  act_layer_type: Optional[str] = self.config.get("act_layer_type", None) # Default according to mlp type
+ out_indices: Optional[list[int]] = self.config.get("out_indices", None)
  drop_path_rate: float = self.config["drop_path_rate"]

  if norm_layer_type == "LayerNorm":
@@ -453,6 +407,7 @@ class ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedTok
  self.hidden_dim = hidden_dim
  self.num_reg_tokens = num_reg_tokens
  self.attn_pool_special_tokens = attn_pool_special_tokens
+ self.out_indices = normalize_out_indices(out_indices, num_layers)
  dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_layers)] # Stochastic depth decay rule

  self.conv_proj = nn.Conv2d(
@@ -520,8 +475,9 @@ class ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedTok

  self.attn_pool = MultiHeadAttentionPool(hidden_dim, attn_pool_num_heads, mlp_dim, qkv_bias=True)

- self.return_stages = ["neck"] # Actually meaningless, just for completeness
- self.return_channels = [hidden_dim]
+ num_return_stages = len(self.out_indices) if self.out_indices is not None else 1
+ self.return_stages = [f"stage{stage_idx + 1}" for stage_idx in range(num_return_stages)]
+ self.return_channels = [hidden_dim] * num_return_stages
  self.embedding_size = hidden_dim
  self.classifier = self.create_classifier()

@@ -585,8 +541,12 @@ class ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedTok
  def set_causal_attention(self, is_causal: bool = True) -> None:
  self.encoder.set_causal_attention(is_causal)

+ def transform_to_backbone(self) -> None:
+ super().transform_to_backbone()
+ self.norm = nn.Identity()
+
  def detection_features(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
- (H, W) = x.shape[-2:]
+ H, W = x.shape[-2:]
  x = self.conv_proj(x)
  x = self.patch_embed(x)

@@ -606,15 +566,20 @@ class ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedTok
  if self.pos_embed_special_tokens is True:
  x = x + self._get_pos_embed(H, W)

- x = self.encoder(x)
- x = self.norm(x)
+ if self.out_indices is None:
+ xs = [self.encoder(x)]
+ else:
+ xs = self.encoder.forward_features(x, out_indices=self.out_indices)

- x = x[:, self.num_special_tokens :]
- x = x.permute(0, 2, 1)
- (B, C, _) = x.size()
- x = x.reshape(B, C, H // self.patch_size, W // self.patch_size)
+ out: dict[str, torch.Tensor] = {}
+ for stage_name, stage_x in zip(self.return_stages, xs):
+ stage_x = stage_x[:, self.num_special_tokens :]
+ stage_x = stage_x.permute(0, 2, 1)
+ B, C, _ = stage_x.size()
+ stage_x = stage_x.reshape(B, C, H // self.patch_size, W // self.patch_size)
+ out[stage_name] = stage_x

- return {self.return_stages[0]: x}
+ return out

  def freeze_stages(self, up_to_stage: int) -> None:
  for param in self.conv_proj.parameters():
@@ -637,7 +602,7 @@ class ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedTok
  return_all_features: bool = False,
  return_keys: Literal["all", "tokens", "embedding"] = "tokens",
  ) -> TokenOmissionResultType:
- (H, W) = x.shape[-2:]
+ H, W = x.shape[-2:]

  # Reshape and permute the input tensor
  x = self.conv_proj(x)
@@ -711,7 +676,7 @@ class ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedTok
  mask_token: Optional[torch.Tensor] = None,
  return_keys: Literal["all", "features", "embedding"] = "features",
  ) -> TokenRetentionResultType:
- (H, W) = x.shape[-2:]
+ H, W = x.shape[-2:]

  x = self.conv_proj(x)
  x = mask_tensor(x, mask, mask_token=mask_token, patch_factor=self.max_stride // self.stem_stride)
@@ -742,7 +707,7 @@ class ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedTok
  if return_keys in ("all", "features"):
  features = x[:, self.num_special_tokens :]
  features = features.permute(0, 2, 1)
- (B, C, _) = features.size()
+ B, C, _ = features.size()
  features = features.reshape(B, C, H // self.patch_size, W // self.patch_size)
  result["features"] = features

@@ -762,7 +727,7 @@ class ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedTok
  return result

  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
- (H, W) = x.shape[-2:]
+ H, W = x.shape[-2:]

  # Reshape and permute the input tensor
  x = self.conv_proj(x)
@@ -834,888 +799,8 @@ class ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedTok
834
799
  self.pos_embedding = nn.Parameter(pos_embedding)
835
800
 
836
801
 
837
- # For the model naming convention see rope_vit.py
838
-
839
- registry.register_model_config(
840
- "vit_t32",
841
- ViT,
842
- config={
843
- "patch_size": 32,
844
- "num_layers": 12,
845
- "num_heads": 3,
846
- "hidden_dim": 192,
847
- "mlp_dim": 768,
848
- "drop_path_rate": 0.0,
849
- },
850
- )
851
- registry.register_model_config(
852
- "vit_t16",
853
- ViT,
854
- config={
855
- "patch_size": 16,
856
- "num_layers": 12,
857
- "num_heads": 3,
858
- "hidden_dim": 192,
859
- "mlp_dim": 768,
860
- "drop_path_rate": 0.0,
861
- },
862
- )
863
- registry.register_model_config(
864
- "vit_s32",
865
- ViT,
866
- config={
867
- "patch_size": 32,
868
- "num_layers": 12,
869
- "num_heads": 6,
870
- "hidden_dim": 384,
871
- "mlp_dim": 1536,
872
- "drop_path_rate": 0.0,
873
- },
874
- )
875
- registry.register_model_config(
876
- "vit_s16",
877
- ViT,
878
- config={
879
- "patch_size": 16,
880
- "num_layers": 12,
881
- "num_heads": 6,
882
- "hidden_dim": 384,
883
- "mlp_dim": 1536,
884
- "drop_path_rate": 0.0,
885
- },
886
- )
887
- registry.register_model_config(
888
- "vit_s16_ls",
889
- ViT,
890
- config={
891
- "patch_size": 16,
892
- "num_layers": 12,
893
- "num_heads": 6,
894
- "hidden_dim": 384,
895
- "mlp_dim": 1536,
896
- "layer_scale_init_value": 1e-5,
897
- "drop_path_rate": 0.0,
898
- },
899
- )
900
- registry.register_model_config(
901
- "vit_s16_pn",
902
- ViT,
903
- config={
904
- "patch_size": 16,
905
- "num_layers": 12,
906
- "num_heads": 6,
907
- "hidden_dim": 384,
908
- "mlp_dim": 1536,
909
- "pre_norm": True,
910
- "norm_layer_eps": 1e-5,
911
- "drop_path_rate": 0.0,
912
- },
913
- )
914
- registry.register_model_config(
915
- "vit_s14",
916
- ViT,
917
- config={
918
- "patch_size": 14,
919
- "num_layers": 12,
920
- "num_heads": 6,
921
- "hidden_dim": 384,
922
- "mlp_dim": 1536,
923
- "drop_path_rate": 0.0,
924
- },
925
- )
926
- registry.register_model_config(
927
- "vit_m32",
928
- ViT,
929
- config={
930
- "patch_size": 32,
931
- "num_layers": 12,
932
- "num_heads": 8,
933
- "hidden_dim": 512,
934
- "mlp_dim": 2048,
935
- "drop_path_rate": 0.0,
936
- },
937
- )
938
- registry.register_model_config(
939
- "vit_m16",
940
- ViT,
941
- config={
942
- "patch_size": 16,
943
- "num_layers": 12,
944
- "num_heads": 8,
945
- "hidden_dim": 512,
946
- "mlp_dim": 2048,
947
- "drop_path_rate": 0.0,
948
- },
949
- )
950
- registry.register_model_config(
951
- "vit_m14",
952
- ViT,
953
- config={
954
- "patch_size": 14,
955
- "num_layers": 12,
956
- "num_heads": 8,
957
- "hidden_dim": 512,
958
- "mlp_dim": 2048,
959
- "drop_path_rate": 0.0,
960
- },
961
- )
962
- registry.register_model_config(
963
- "vit_b32",
964
- ViT,
965
- config={
966
- "patch_size": 32,
967
- "num_layers": 12,
968
- "num_heads": 12,
969
- "hidden_dim": 768,
970
- "mlp_dim": 3072,
971
- "drop_path_rate": 0.0,
972
- },
973
- )
974
- registry.register_model_config(
975
- "vit_b16",
976
- ViT,
977
- config={
978
- "patch_size": 16,
979
- "num_layers": 12,
980
- "num_heads": 12,
981
- "hidden_dim": 768,
982
- "mlp_dim": 3072,
983
- "drop_path_rate": 0.1,
984
- },
985
- )
986
- registry.register_model_config(
987
- "vit_b16_ls",
988
- ViT,
989
- config={
990
- "patch_size": 16,
991
- "num_layers": 12,
992
- "num_heads": 12,
993
- "hidden_dim": 768,
994
- "mlp_dim": 3072,
995
- "layer_scale_init_value": 1e-5,
996
- "drop_path_rate": 0.1,
997
- },
998
- )
999
- registry.register_model_config(
1000
- "vit_b16_qkn_ls",
1001
- ViT,
1002
- config={
1003
- "patch_size": 16,
1004
- "num_layers": 12,
1005
- "num_heads": 12,
1006
- "hidden_dim": 768,
1007
- "mlp_dim": 3072,
1008
- "layer_scale_init_value": 1e-5,
1009
- "qk_norm": True,
1010
- "drop_path_rate": 0.1,
1011
- },
1012
- )
1013
- registry.register_model_config(
1014
- "vit_b16_pn_quick_gelu",
1015
- ViT,
1016
- config={
1017
- "patch_size": 16,
1018
- "num_layers": 12,
1019
- "num_heads": 12,
1020
- "hidden_dim": 768,
1021
- "mlp_dim": 3072,
1022
- "pre_norm": True,
1023
- "norm_layer_eps": 1e-5,
1024
- "act_layer_type": "quick_gelu",
1025
- "drop_path_rate": 0.1,
1026
- },
1027
- )
1028
- registry.register_model_config(
1029
- "vit_b14",
1030
- ViT,
1031
- config={
1032
- "patch_size": 14,
1033
- "num_layers": 12,
1034
- "num_heads": 12,
1035
- "hidden_dim": 768,
1036
- "mlp_dim": 3072,
1037
- "drop_path_rate": 0.1,
1038
- },
1039
- )
1040
- registry.register_model_config(
1041
- "vit_l32",
1042
- ViT,
1043
- config={
1044
- "patch_size": 32,
1045
- "num_layers": 24,
1046
- "num_heads": 16,
1047
- "hidden_dim": 1024,
1048
- "mlp_dim": 4096,
1049
- "drop_path_rate": 0.1,
1050
- },
1051
- )
1052
- registry.register_model_config(
1053
- "vit_l16",
1054
- ViT,
1055
- config={
1056
- "patch_size": 16,
1057
- "num_layers": 24,
1058
- "num_heads": 16,
1059
- "hidden_dim": 1024,
1060
- "mlp_dim": 4096,
1061
- "drop_path_rate": 0.1,
1062
- },
1063
- )
1064
- registry.register_model_config(
1065
- "vit_l14",
1066
- ViT,
1067
- config={
1068
- "patch_size": 14,
1069
- "num_layers": 24,
1070
- "num_heads": 16,
1071
- "hidden_dim": 1024,
1072
- "mlp_dim": 4096,
1073
- "drop_path_rate": 0.1,
1074
- },
1075
- )
1076
- registry.register_model_config(
1077
- "vit_l14_pn",
1078
- ViT,
1079
- config={
1080
- "patch_size": 14,
1081
- "num_layers": 24,
1082
- "num_heads": 16,
1083
- "hidden_dim": 1024,
1084
- "mlp_dim": 4096,
1085
- "pre_norm": True,
1086
- "norm_layer_eps": 1e-5,
1087
- "drop_path_rate": 0.1,
1088
- },
1089
- )
1090
- registry.register_model_config(
1091
- "vit_l14_pn_quick_gelu",
1092
- ViT,
1093
- config={
1094
- "patch_size": 14,
1095
- "num_layers": 24,
1096
- "num_heads": 16,
1097
- "hidden_dim": 1024,
1098
- "mlp_dim": 4096,
1099
- "pre_norm": True,
1100
- "norm_layer_eps": 1e-5,
1101
- "act_layer_type": "quick_gelu",
1102
- "drop_path_rate": 0.1,
1103
- },
1104
- )
1105
- registry.register_model_config(
1106
- "vit_h16",
1107
- ViT,
1108
- config={
1109
- "patch_size": 16,
1110
- "num_layers": 32,
1111
- "num_heads": 16,
1112
- "hidden_dim": 1280,
1113
- "mlp_dim": 5120,
1114
- "drop_path_rate": 0.1,
1115
- },
1116
- )
1117
- registry.register_model_config(
1118
- "vit_h14",
1119
- ViT,
1120
- config={
1121
- "patch_size": 14,
1122
- "num_layers": 32,
1123
- "num_heads": 16,
1124
- "hidden_dim": 1280,
1125
- "mlp_dim": 5120,
1126
- "drop_path_rate": 0.1,
1127
- },
1128
- )
1129
- registry.register_model_config( # From "Scaling Vision Transformers"
1130
- "vit_g14",
1131
- ViT,
1132
- config={
1133
- "patch_size": 14,
1134
- "num_layers": 40,
1135
- "num_heads": 16,
1136
- "hidden_dim": 1408,
1137
- "mlp_dim": 6144,
1138
- "drop_path_rate": 0.1,
1139
- },
1140
- )
1141
- registry.register_model_config( # From "Scaling Vision Transformers"
1142
- "vit_gigantic14",
1143
- ViT,
1144
- config={
1145
- "patch_size": 14,
1146
- "num_layers": 48,
1147
- "num_heads": 16,
1148
- "hidden_dim": 1664,
1149
- "mlp_dim": 8192,
1150
- "drop_path_rate": 0.1,
1151
- },
1152
- )
1153
-
1154
- # With registers
1155
- registry.register_model_config(
1156
- "vit_reg1_t16",
1157
- ViT,
1158
- config={
1159
- "patch_size": 16,
1160
- "num_layers": 12,
1161
- "num_heads": 3,
1162
- "hidden_dim": 192,
1163
- "mlp_dim": 768,
1164
- "num_reg_tokens": 1,
1165
- "drop_path_rate": 0.0,
1166
- },
1167
- )
1168
- registry.register_model_config(
1169
- "vit_reg1_s32",
1170
- ViT,
1171
- config={
1172
- "patch_size": 32,
1173
- "num_layers": 12,
1174
- "num_heads": 6,
1175
- "hidden_dim": 384,
1176
- "mlp_dim": 1536,
1177
- "num_reg_tokens": 1,
1178
- "drop_path_rate": 0.0,
1179
- },
1180
- )
1181
- registry.register_model_config(
1182
- "vit_reg1_s16",
1183
- ViT,
1184
- config={
1185
- "patch_size": 16,
1186
- "num_layers": 12,
1187
- "num_heads": 6,
1188
- "hidden_dim": 384,
1189
- "mlp_dim": 1536,
1190
- "num_reg_tokens": 1,
1191
- "drop_path_rate": 0.0,
1192
- },
1193
- )
1194
- registry.register_model_config(
1195
- "vit_reg1_s16_ls",
1196
- ViT,
1197
- config={
1198
- "patch_size": 16,
1199
- "num_layers": 12,
1200
- "num_heads": 6,
1201
- "hidden_dim": 384,
1202
- "mlp_dim": 1536,
1203
- "layer_scale_init_value": 1e-5,
1204
- "num_reg_tokens": 1,
1205
- "drop_path_rate": 0.0,
1206
- },
1207
- )
1208
- registry.register_model_config(
1209
- "vit_reg1_s16_rms_ls",
1210
- ViT,
1211
- config={
1212
- "patch_size": 16,
1213
- "num_layers": 12,
1214
- "num_heads": 6,
1215
- "hidden_dim": 384,
1216
- "mlp_dim": 1536,
1217
- "layer_scale_init_value": 1e-5,
1218
- "num_reg_tokens": 1,
1219
- "norm_layer_type": "RMSNorm",
1220
- "drop_path_rate": 0.0,
1221
- },
1222
- )
1223
- registry.register_model_config(
1224
- "vit_reg1_s14",
1225
- ViT,
1226
- config={
1227
- "patch_size": 14,
1228
- "num_layers": 12,
1229
- "num_heads": 6,
1230
- "hidden_dim": 384,
1231
- "mlp_dim": 1536,
1232
- "num_reg_tokens": 1,
1233
- "drop_path_rate": 0.0,
1234
- },
1235
- )
1236
- registry.register_model_config(
1237
- "vit_reg4_m32",
1238
- ViT,
1239
- config={
1240
- "patch_size": 32,
1241
- "num_layers": 12,
1242
- "num_heads": 8,
1243
- "hidden_dim": 512,
1244
- "mlp_dim": 2048,
1245
- "num_reg_tokens": 4,
1246
- "drop_path_rate": 0.0,
1247
- },
1248
- )
1249
- registry.register_model_config(
1250
- "vit_reg4_m16",
1251
- ViT,
1252
- config={
1253
- "patch_size": 16,
1254
- "num_layers": 12,
1255
- "num_heads": 8,
1256
- "hidden_dim": 512,
1257
- "mlp_dim": 2048,
1258
- "num_reg_tokens": 4,
1259
- "drop_path_rate": 0.0,
1260
- },
1261
- )
1262
- registry.register_model_config(
1263
- "vit_reg4_m16_rms_avg",
1264
- ViT,
1265
- config={
1266
- "patch_size": 16,
1267
- "num_layers": 12,
1268
- "num_heads": 8,
1269
- "hidden_dim": 512,
1270
- "mlp_dim": 2048,
1271
- "num_reg_tokens": 4,
1272
- "class_token": False,
1273
- "norm_layer_type": "RMSNorm",
1274
- "drop_path_rate": 0.0,
1275
- },
1276
- )
1277
- registry.register_model_config(
1278
- "vit_reg4_m14",
1279
- ViT,
1280
- config={
1281
- "patch_size": 14,
1282
- "num_layers": 12,
1283
- "num_heads": 8,
1284
- "hidden_dim": 512,
1285
- "mlp_dim": 2048,
1286
- "num_reg_tokens": 4,
1287
- "drop_path_rate": 0.0,
1288
- },
1289
- )
1290
- registry.register_model_config(
1291
- "vit_reg4_b32",
1292
- ViT,
1293
- config={
1294
- "patch_size": 32,
1295
- "num_layers": 12,
1296
- "num_heads": 12,
1297
- "hidden_dim": 768,
1298
- "mlp_dim": 3072,
1299
- "num_reg_tokens": 4,
1300
- "drop_path_rate": 0.0,
1301
- },
1302
- )
1303
- registry.register_model_config(
1304
- "vit_reg4_b16",
1305
- ViT,
1306
- config={
1307
- "patch_size": 16,
1308
- "num_layers": 12,
1309
- "num_heads": 12,
1310
- "hidden_dim": 768,
1311
- "mlp_dim": 3072,
1312
- "num_reg_tokens": 4,
1313
- "drop_path_rate": 0.1,
1314
- },
1315
- )
1316
- registry.register_model_config(
1317
- "vit_reg4_b16_avg",
1318
- ViT,
1319
- config={
1320
- "patch_size": 16,
1321
- "num_layers": 12,
1322
- "num_heads": 12,
1323
- "hidden_dim": 768,
1324
- "mlp_dim": 3072,
1325
- "num_reg_tokens": 4,
1326
- "class_token": False,
1327
- "drop_path_rate": 0.1,
1328
- },
1329
- )
1330
- registry.register_model_config(
1331
- "vit_reg4_b14",
1332
- ViT,
1333
- config={
1334
- "patch_size": 14,
1335
- "num_layers": 12,
1336
- "num_heads": 12,
1337
- "hidden_dim": 768,
1338
- "mlp_dim": 3072,
1339
- "num_reg_tokens": 4,
1340
- "drop_path_rate": 0.1,
1341
- },
1342
- )
1343
- registry.register_model_config(
1344
- "vit_reg8_b14_ap",
1345
- ViT,
1346
- config={
1347
- "patch_size": 14,
1348
- "num_layers": 12,
1349
- "num_heads": 12,
1350
- "hidden_dim": 768,
1351
- "mlp_dim": 3072,
1352
- "num_reg_tokens": 8,
1353
- "class_token": False,
1354
- "attn_pool_head": True,
1355
- "drop_path_rate": 0.1,
1356
- },
1357
- )
1358
- registry.register_model_config(
1359
- "vit_reg4_l32",
1360
- ViT,
1361
- config={
1362
- "patch_size": 32,
1363
- "num_layers": 24,
1364
- "num_heads": 16,
1365
- "hidden_dim": 1024,
1366
- "mlp_dim": 4096,
1367
- "num_reg_tokens": 4,
1368
- "drop_path_rate": 0.1,
1369
- },
1370
- )
1371
- registry.register_model_config(
1372
- "vit_reg4_l16",
1373
- ViT,
1374
- config={
1375
- "patch_size": 16,
1376
- "num_layers": 24,
1377
- "num_heads": 16,
1378
- "hidden_dim": 1024,
1379
- "mlp_dim": 4096,
1380
- "num_reg_tokens": 4,
1381
- "drop_path_rate": 0.1,
1382
- },
1383
- )
1384
- registry.register_model_config(
1385
- "vit_reg8_l16_avg",
1386
- ViT,
1387
- config={
1388
- "patch_size": 16,
1389
- "num_layers": 24,
1390
- "num_heads": 16,
1391
- "hidden_dim": 1024,
1392
- "mlp_dim": 4096,
1393
- "num_reg_tokens": 8,
1394
- "class_token": False,
1395
- "drop_path_rate": 0.1,
1396
- },
1397
- )
1398
- registry.register_model_config(
1399
- "vit_reg8_l16_aps",
1400
- ViT,
1401
- config={
1402
- "patch_size": 16,
1403
- "num_layers": 24,
1404
- "num_heads": 16,
1405
- "hidden_dim": 1024,
1406
- "mlp_dim": 4096,
1407
- "num_reg_tokens": 8,
1408
- "class_token": False,
1409
- "attn_pool_head": True,
1410
- "attn_pool_special_tokens": True,
1411
- "drop_path_rate": 0.1,
1412
- },
1413
- )
1414
- registry.register_model_config(
1415
- "vit_reg4_l14",
1416
- ViT,
1417
- config={
1418
- "patch_size": 14,
1419
- "num_layers": 24,
1420
- "num_heads": 16,
1421
- "hidden_dim": 1024,
1422
- "mlp_dim": 4096,
1423
- "num_reg_tokens": 4,
1424
- "drop_path_rate": 0.1,
1425
- },
1426
- )
1427
- registry.register_model_config( # DeiT III style
1428
- "vit_reg4_l14_nps_ls",
1429
- ViT,
1430
- config={
1431
- "pos_embed_special_tokens": False,
1432
- "patch_size": 14,
1433
- "num_layers": 24,
1434
- "num_heads": 16,
1435
- "hidden_dim": 1024,
1436
- "mlp_dim": 4096,
1437
- "layer_scale_init_value": 1e-5,
1438
- "num_reg_tokens": 4,
1439
- "drop_path_rate": 0.1,
1440
- },
1441
- )
1442
- registry.register_model_config(
1443
- "vit_reg8_l14_ap",
1444
- ViT,
1445
- config={
1446
- "patch_size": 14,
1447
- "num_layers": 24,
1448
- "num_heads": 16,
1449
- "hidden_dim": 1024,
1450
- "mlp_dim": 4096,
1451
- "num_reg_tokens": 8,
1452
- "class_token": False,
1453
- "attn_pool_head": True,
1454
- "drop_path_rate": 0.1,
1455
- },
1456
- )
1457
- registry.register_model_config(
1458
- "vit_reg8_l14_rms_ap",
1459
- ViT,
1460
- config={
1461
- "patch_size": 14,
1462
- "num_layers": 24,
1463
- "num_heads": 16,
1464
- "hidden_dim": 1024,
1465
- "mlp_dim": 4096,
1466
- "num_reg_tokens": 8,
1467
- "class_token": False,
1468
- "attn_pool_head": True,
1469
- "norm_layer_type": "RMSNorm",
1470
- "drop_path_rate": 0.1,
1471
- },
1472
- )
1473
- registry.register_model_config(
1474
- "vit_reg4_h16",
1475
- ViT,
1476
- config={
1477
- "patch_size": 16,
1478
- "num_layers": 32,
1479
- "num_heads": 16,
1480
- "hidden_dim": 1280,
1481
- "mlp_dim": 5120,
1482
- "num_reg_tokens": 4,
1483
- "drop_path_rate": 0.1,
1484
- },
1485
- )
1486
- registry.register_model_config(
1487
- "vit_reg4_h14",
1488
- ViT,
1489
- config={
1490
- "patch_size": 14,
1491
- "num_layers": 32,
1492
- "num_heads": 16,
1493
- "hidden_dim": 1280,
1494
- "mlp_dim": 5120,
1495
- "num_reg_tokens": 4,
1496
- "drop_path_rate": 0.1,
1497
- },
1498
- )
1499
- registry.register_model_config( # From "Scaling Vision Transformers"
1500
- "vit_reg4_g14",
1501
- ViT,
1502
- config={
1503
- "patch_size": 14,
1504
- "num_layers": 40,
1505
- "num_heads": 16,
1506
- "hidden_dim": 1408,
1507
- "mlp_dim": 6144,
1508
- "num_reg_tokens": 4,
1509
- "drop_path_rate": 0.1,
1510
- },
1511
- )
1512
-
1513
- # Shape-optimized vision transformer (SoViT)
1514
- registry.register_model_config(
1515
- "vit_so150m_p14_ap",
1516
- ViT,
1517
- config={
1518
- "patch_size": 14,
1519
- "num_layers": 18,
1520
- "num_heads": 16,
1521
- "hidden_dim": 896, # Changed from 880 for RoPE divisibility
1522
- "mlp_dim": 2320,
1523
- "class_token": False,
1524
- "attn_pool_head": True,
1525
- "drop_path_rate": 0.1,
1526
- },
1527
- )
1528
- registry.register_model_config(
1529
- "vit_so400m_p14_ap",
1530
- ViT,
1531
- config={
1532
- "patch_size": 14,
1533
- "num_layers": 27,
1534
- "num_heads": 16,
1535
- "hidden_dim": 1152,
1536
- "mlp_dim": 4304,
1537
- "class_token": False,
1538
- "attn_pool_head": True,
1539
- "drop_path_rate": 0.1,
1540
- },
1541
- )
1542
- registry.register_model_config(
1543
- "vit_reg4_so150m_p16_avg",
1544
- ViT,
1545
- config={
1546
- "patch_size": 16,
1547
- "num_layers": 18,
1548
- "num_heads": 16,
1549
- "hidden_dim": 896, # Changed from 880 for RoPE divisibility
1550
- "mlp_dim": 2320,
1551
- "num_reg_tokens": 4,
1552
- "class_token": False,
1553
- "drop_path_rate": 0.1,
1554
- },
1555
- )
1556
- registry.register_model_config(
1557
- "vit_reg8_so150m_p16_swiglu_ap",
1558
- ViT,
1559
- config={
1560
- "patch_size": 16,
1561
- "num_layers": 18,
1562
- "num_heads": 16,
1563
- "hidden_dim": 896, # Changed from 880 for RoPE divisibility
1564
- "mlp_dim": 2320,
1565
- "num_reg_tokens": 8,
1566
- "class_token": False,
1567
- "attn_pool_head": True,
1568
- "mlp_layer_type": "SwiGLU_FFN",
1569
- "drop_path_rate": 0.1,
1570
- },
1571
- )
1572
- registry.register_model_config(
1573
- "vit_reg4_so150m_p14_avg",
1574
- ViT,
1575
- config={
1576
- "patch_size": 14,
1577
- "num_layers": 18,
1578
- "num_heads": 16,
1579
- "hidden_dim": 896, # Changed from 880 for RoPE divisibility
1580
- "mlp_dim": 2320,
1581
- "num_reg_tokens": 4,
1582
- "class_token": False,
1583
- "drop_path_rate": 0.1,
1584
- },
1585
- )
1586
- registry.register_model_config(
1587
- "vit_reg4_so150m_p14_ls",
1588
- ViT,
1589
- config={
1590
- "patch_size": 14,
1591
- "num_layers": 18,
1592
- "num_heads": 16,
1593
- "hidden_dim": 896, # Changed from 880 for RoPE divisibility
1594
- "mlp_dim": 2320,
1595
- "layer_scale_init_value": 1e-5,
1596
- "num_reg_tokens": 4,
1597
- "drop_path_rate": 0.1,
1598
- },
1599
- )
1600
- registry.register_model_config(
1601
- "vit_reg4_so150m_p14_ap",
1602
- ViT,
1603
- config={
1604
- "patch_size": 14,
1605
- "num_layers": 18,
1606
- "num_heads": 16,
1607
- "hidden_dim": 896, # Changed from 880 for RoPE divisibility
1608
- "mlp_dim": 2320,
1609
- "num_reg_tokens": 4,
1610
- "class_token": False,
1611
- "attn_pool_head": True,
1612
- "drop_path_rate": 0.1,
1613
- },
1614
- )
1615
- registry.register_model_config(
1616
- "vit_reg4_so150m_p14_aps",
1617
- ViT,
1618
- config={
1619
- "patch_size": 14,
1620
- "num_layers": 18,
1621
- "num_heads": 16,
1622
- "hidden_dim": 896, # Changed from 880 for RoPE divisibility
1623
- "mlp_dim": 2320,
1624
- "num_reg_tokens": 4,
1625
- "class_token": False,
1626
- "attn_pool_head": True,
1627
- "attn_pool_special_tokens": True,
1628
- "drop_path_rate": 0.1,
1629
- },
1630
- )
1631
- registry.register_model_config(
1632
- "vit_reg8_so150m_p14_avg",
1633
- ViT,
1634
- config={
1635
- "patch_size": 14,
1636
- "num_layers": 18,
1637
- "num_heads": 16,
1638
- "hidden_dim": 896, # Changed from 880 for RoPE divisibility
1639
- "mlp_dim": 2320,
1640
- "num_reg_tokens": 8,
1641
- "class_token": False,
1642
- "drop_path_rate": 0.1,
1643
- },
1644
- )
1645
- registry.register_model_config(
1646
- "vit_reg8_so150m_p14_swiglu",
1647
- ViT,
1648
- config={
1649
- "patch_size": 14,
1650
- "num_layers": 18,
1651
- "num_heads": 16,
1652
- "hidden_dim": 896, # Changed from 880 for RoPE divisibility
1653
- "mlp_dim": 2320,
1654
- "num_reg_tokens": 8,
1655
- "mlp_layer_type": "SwiGLU_FFN",
1656
- "drop_path_rate": 0.1,
1657
- },
1658
- )
1659
- registry.register_model_config(
1660
- "vit_reg8_so150m_p14_swiglu_avg",
1661
- ViT,
1662
- config={
1663
- "patch_size": 14,
1664
- "num_layers": 18,
1665
- "num_heads": 16,
1666
- "hidden_dim": 896, # Changed from 880 for RoPE divisibility
1667
- "mlp_dim": 2320,
1668
- "num_reg_tokens": 8,
1669
- "class_token": False,
1670
- "mlp_layer_type": "SwiGLU_FFN",
1671
- "drop_path_rate": 0.1,
1672
- },
1673
- )
1674
- registry.register_model_config(
1675
- "vit_reg8_so150m_p14_ap",
1676
- ViT,
1677
- config={
1678
- "patch_size": 14,
1679
- "num_layers": 18,
1680
- "num_heads": 16,
1681
- "hidden_dim": 896, # Changed from 880 for RoPE divisibility
1682
- "mlp_dim": 2320,
1683
- "num_reg_tokens": 8,
1684
- "class_token": False,
1685
- "attn_pool_head": True,
1686
- "drop_path_rate": 0.1,
1687
- },
1688
- )
1689
- registry.register_model_config(
1690
- "vit_reg4_so400m_p14_ap",
1691
- ViT,
1692
- config={
1693
- "patch_size": 14,
1694
- "num_layers": 27,
1695
- "num_heads": 16,
1696
- "hidden_dim": 1152,
1697
- "mlp_dim": 4304,
1698
- "num_reg_tokens": 4,
1699
- "class_token": False,
1700
- "attn_pool_head": True,
1701
- "drop_path_rate": 0.1,
1702
- },
1703
- )
1704
- registry.register_model_config(
1705
- "vit_reg8_so400m_p14_ap",
1706
- ViT,
1707
- config={
1708
- "patch_size": 14,
1709
- "num_layers": 27,
1710
- "num_heads": 16,
1711
- "hidden_dim": 1152,
1712
- "mlp_dim": 4304,
1713
- "num_reg_tokens": 8,
1714
- "class_token": False,
1715
- "attn_pool_head": True,
1716
- "drop_path_rate": 0.1,
1717
- },
1718
- )
+ # Register model configs (side effects)
+ register_vit_configs(ViT)

  registry.register_weights(
  "vit_l16_mim_200",
@@ -1729,7 +814,7 @@ registry.register_weights(
1729
814
  "formats": {
1730
815
  "pt": {
1731
816
  "file_size": 1157.1,
1732
- "sha256": "003b15a79cd528339de1b19304bbd04fd5885df36b80e19202cd6ef6f8ffbed1",
817
+ "sha256": "7fc5b342347d8349aaf5f069a47efd441b646f8542821ed2e30b47a7da72917a",
1733
818
  },
1734
819
  },
1735
820
  "net": {"network": "vit_l16", "tag": "mim"},
@@ -1747,7 +832,7 @@ registry.register_weights(
1747
832
  "formats": {
1748
833
  "pt": {
1749
834
  "file_size": 1157.1,
1750
- "sha256": "c6083c6532996addaf4efe29276aa55f9a3c77984f862f720c6131f86b847994",
835
+ "sha256": "9b5c4e2538ea40edd60d8831d3807b543290dc2db44d537e60e44a341b47e54e",
1751
836
  },
1752
837
  },
1753
838
  "net": {"network": "vit_l16", "tag": "mim"},
@@ -1765,7 +850,7 @@ registry.register_weights( # BioCLIP v2: https://arxiv.org/abs/2505.23883
1765
850
  "formats": {
1766
851
  "pt": {
1767
852
  "file_size": 1156.6,
1768
- "sha256": "cfb998d762cd2ba883964026ddfc8f2f84cf1e6ad6f7264ab33da52f57d25fab",
853
+ "sha256": "6cd7bd6993762590891fe2b41db1649cde5a0c4de5a7f341672f8856ed529d07",
1769
854
  },
1770
855
  },
1771
856
  "net": {"network": "vit_l14_pn", "tag": "bioclip-v2"},
@@ -1783,7 +868,7 @@ registry.register_weights( # OpenAI CLIP: https://arxiv.org/abs/2103.00020
1783
868
  "formats": {
1784
869
  "pt": {
1785
870
  "file_size": 1159.7,
1786
- "sha256": "e4c6ff7467608c412d35f9a4e2df18f3b8f05fc9eca3803c8fcc01558921378d",
871
+ "sha256": "2c7462390956d8942de0df21d9d1a43cf53fdbe3a3570a1add64d859313a0bee",
1787
872
  },
1788
873
  },
1789
874
  "net": {"network": "vit_l14_pn_quick_gelu", "tag": "openai-clip"},
@@ -1801,7 +886,7 @@ registry.register_weights( # SigLIP 2: https://arxiv.org/abs/2502.14786
1801
886
  "formats": {
1802
887
  "pt": {
1803
888
  "file_size": 1631.6,
1804
- "sha256": "1f9f659a7b1bdf8a6a2977140be9bb3f876f7f756bf6e13d54bf00f3b6db0b0f",
889
+ "sha256": "f8ac3bdf028d17a2ee2673f58b51cffa5c696edef44c92092299d970607c7be6",
1805
890
  },
1806
891
  },
1807
892
  "net": {"network": "vit_so400m_p14_ap", "tag": "siglip-v2-webli"},
@@ -1821,7 +906,7 @@ registry.register_weights(
1821
906
  "formats": {
1822
907
  "pt": {
1823
908
  "file_size": 146.2,
1824
- "sha256": "bc4c9e600e93322440fb68c1001216d49c54c7587cdf61544f363f9537152f4a",
909
+ "sha256": "0f5cd4e0acb44d1e429bbed342c60bf22087ecd1d7112363c3ceb909dcd9d547",
1825
910
  },
1826
911
  },
1827
912
  "net": {"network": "vit_reg4_m16_rms_avg", "tag": "i-jepa"},
@@ -1839,7 +924,7 @@ registry.register_weights(
1839
924
  "formats": {
1840
925
  "pt": {
1841
926
  "file_size": 166.8,
1842
- "sha256": "9ff659be9826bbbafbcfa85d79d0fa9d5ac383fd2442ffa36db6c4f7ab09b86a",
927
+ "sha256": "e9b83e90c284877c859e92a05a35ff25884a06d3fd006d90ee576d58f71d3251",
1843
928
  },
1844
929
  },
1845
930
  "net": {"network": "vit_reg4_m16_rms_avg", "tag": "i-jepa-inat21-256px"},
@@ -1857,7 +942,7 @@ registry.register_weights(
1857
942
  "formats": {
1858
943
  "pt": {
1859
944
  "file_size": 167.4,
1860
- "sha256": "1cfa7ebea3db95363bf9e35fc24be94e419debe5db58746fe3320fbcb8bda6dd",
945
+ "sha256": "7fde7375f5f9165114561f6288cdf086ba8b6635251304de08bd01883bb7a2da",
1861
946
  },
1862
947
  },
1863
948
  "net": {"network": "vit_reg4_m16_rms_avg", "tag": "i-jepa-inat21"},
@@ -1874,7 +959,7 @@ registry.register_weights(
1874
959
  "formats": {
1875
960
  "pt": {
1876
961
  "file_size": 184.2,
1877
- "sha256": "d6d9fc47ecbad04a83b178bcd2eeecbd77569cc2a17fbdf52e02feda54523c3f",
962
+ "sha256": "da47dc6bd4f41c347235beba92657b66148180141d0bd629169e84449b629fbb",
1878
963
  },
1879
964
  },
1880
965
  "net": {"network": "vit_reg4_m16_rms_avg", "tag": "i-jepa-imagenet21k"},
@@ -1892,7 +977,7 @@ registry.register_weights(
1892
977
  "formats": {
1893
978
  "pt": {
1894
979
  "file_size": 327.4,
1895
- "sha256": "6b044cd7834293e344309f809070db3fe9ede489478e7549ad96255f9d76b329",
980
+ "sha256": "c7ec433c01e1dc0d6100cafc29fa88155a0d65f4b42afa9cc252b77485a566a7",
1896
981
  },
1897
982
  },
1898
983
  "net": {"network": "vit_reg4_b16", "tag": "mim"},
@@ -1910,7 +995,7 @@ registry.register_weights(
1910
995
  "formats": {
1911
996
  "pt": {
1912
997
  "file_size": 327.4,
1913
- "sha256": "e0df2e79f8ed0612d12c736cc6317be1b9b354e468715a5077366f7676fdd2ce",
998
+ "sha256": "b0e5e2b24ea7a8d2be246df43c9d8092354f6ee81e88c6cdd7c52d8e38ed44a4",
1914
999
  },
1915
1000
  },
1916
1001
  "net": {"network": "vit_reg4_b16", "tag": "mim"},
@@ -1928,7 +1013,7 @@ registry.register_weights(
1928
1013
  "formats": {
1929
1014
  "pt": {
1930
1015
  "file_size": 328.7,
1931
- "sha256": "3d1564be46b23081c76aa87c7e90324214b6ced899d4b38d59d1a4154b13f01c",
1016
+ "sha256": "3a15b95285cd4435b601ef058839f422cdce8f68cca50de9353e1ac2bcb65f9a",
1932
1017
  },
1933
1018
  },
1934
1019
  "net": {"network": "vit_reg4_b16", "tag": "mim-intermediate-il-common"},
@@ -1946,7 +1031,7 @@ registry.register_weights(
1946
1031
  "formats": {
1947
1032
  "pt": {
1948
1033
  "file_size": 330.7,
1949
- "sha256": "e011f931a5a4d96ef21283d70911a55ea649eadfefa9c163a48b996797f0d9da",
1034
+ "sha256": "78dbf578ebe7d5761705231e16fef280b14905a94f18879167c96df3e59d13a5",
1950
1035
  },
1951
1036
  },
1952
1037
  "net": {"network": "vit_reg4_b16", "tag": "mim-intermediate-arabian-peninsula"},
@@ -1964,7 +1049,7 @@ registry.register_weights( # DINO v2: https://arxiv.org/abs/2304.07193
1964
1049
  "formats": {
1965
1050
  "pt": {
1966
1051
  "file_size": 1161.2,
1967
- "sha256": "56d39cbaed8b7da72175b7b3a0c9419e71aabc1e9516567703a39ba05244a44f",
1052
+ "sha256": "441721029ca0ef85582bc8822ec91d780ee442eb3d06b04fb5e4662c9317b52d",
1968
1053
  },
1969
1054
  },
1970
1055
  "net": {"network": "vit_reg4_l14_nps_ls", "tag": "dino-v2-lvd142m"},