birder-0.3.2-py3-none-any.whl → birder-0.4.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- birder/common/fs_ops.py +2 -2
- birder/introspection/attention_rollout.py +1 -1
- birder/introspection/transformer_attribution.py +1 -1
- birder/layers/layer_scale.py +1 -1
- birder/net/__init__.py +2 -10
- birder/net/_rope_vit_configs.py +430 -0
- birder/net/_vit_configs.py +479 -0
- birder/net/biformer.py +1 -0
- birder/net/cait.py +5 -5
- birder/net/coat.py +12 -12
- birder/net/conv2former.py +3 -3
- birder/net/convmixer.py +1 -1
- birder/net/convnext_v1.py +1 -1
- birder/net/crossvit.py +5 -5
- birder/net/davit.py +1 -1
- birder/net/deit.py +12 -26
- birder/net/deit3.py +42 -189
- birder/net/densenet.py +9 -8
- birder/net/detection/deformable_detr.py +5 -2
- birder/net/detection/detr.py +5 -2
- birder/net/detection/efficientdet.py +1 -1
- birder/net/dpn.py +1 -2
- birder/net/edgenext.py +2 -1
- birder/net/edgevit.py +3 -0
- birder/net/efficientformer_v1.py +2 -1
- birder/net/efficientformer_v2.py +18 -31
- birder/net/efficientnet_v2.py +3 -0
- birder/net/efficientvit_mit.py +5 -5
- birder/net/fasternet.py +2 -2
- birder/net/flexivit.py +22 -43
- birder/net/groupmixformer.py +1 -1
- birder/net/hgnet_v1.py +5 -5
- birder/net/hiera.py +3 -3
- birder/net/hieradet.py +116 -28
- birder/net/inception_next.py +1 -1
- birder/net/inception_resnet_v1.py +3 -3
- birder/net/inception_resnet_v2.py +7 -4
- birder/net/inception_v3.py +3 -0
- birder/net/inception_v4.py +3 -0
- birder/net/maxvit.py +1 -1
- birder/net/metaformer.py +3 -3
- birder/net/mim/crossmae.py +1 -1
- birder/net/mim/mae_vit.py +1 -1
- birder/net/mim/simmim.py +1 -1
- birder/net/mobilenet_v1.py +0 -9
- birder/net/mobilenet_v2.py +38 -44
- birder/net/{mobilenet_v3_large.py → mobilenet_v3.py} +37 -10
- birder/net/mobilevit_v1.py +5 -32
- birder/net/mobilevit_v2.py +1 -45
- birder/net/moganet.py +8 -5
- birder/net/mvit_v2.py +6 -6
- birder/net/nfnet.py +4 -0
- birder/net/pit.py +1 -1
- birder/net/pvt_v1.py +5 -5
- birder/net/pvt_v2.py +5 -5
- birder/net/repghost.py +1 -30
- birder/net/resmlp.py +2 -2
- birder/net/resnest.py +3 -0
- birder/net/resnet_v1.py +125 -1
- birder/net/resnet_v2.py +75 -1
- birder/net/resnext.py +35 -1
- birder/net/rope_deit3.py +33 -136
- birder/net/rope_flexivit.py +18 -18
- birder/net/rope_vit.py +3 -735
- birder/net/simple_vit.py +22 -16
- birder/net/smt.py +1 -1
- birder/net/squeezenet.py +5 -12
- birder/net/squeezenext.py +0 -24
- birder/net/ssl/capi.py +1 -1
- birder/net/ssl/data2vec.py +1 -1
- birder/net/ssl/dino_v2.py +2 -2
- birder/net/ssl/franca.py +2 -2
- birder/net/ssl/i_jepa.py +1 -1
- birder/net/ssl/ibot.py +1 -1
- birder/net/swiftformer.py +12 -2
- birder/net/swin_transformer_v2.py +1 -1
- birder/net/tiny_vit.py +3 -16
- birder/net/van.py +2 -2
- birder/net/vit.py +35 -963
- birder/net/vit_sam.py +13 -38
- birder/net/xcit.py +7 -6
- birder/scripts/train.py +17 -15
- birder/scripts/train_kd.py +17 -16
- birder/tools/introspection.py +1 -1
- birder/tools/model_info.py +3 -1
- birder/tools/show_iterator.py +16 -2
- birder/version.py +1 -1
- {birder-0.3.2.dist-info → birder-0.4.0.dist-info}/METADATA +1 -1
- {birder-0.3.2.dist-info → birder-0.4.0.dist-info}/RECORD +93 -95
- birder/net/mobilenet_v3_small.py +0 -43
- birder/net/se_resnet_v1.py +0 -105
- birder/net/se_resnet_v2.py +0 -59
- birder/net/se_resnext.py +0 -30
- {birder-0.3.2.dist-info → birder-0.4.0.dist-info}/WHEEL +0 -0
- {birder-0.3.2.dist-info → birder-0.4.0.dist-info}/entry_points.txt +0 -0
- {birder-0.3.2.dist-info → birder-0.4.0.dist-info}/licenses/LICENSE +0 -0
- {birder-0.3.2.dist-info → birder-0.4.0.dist-info}/top_level.txt +0 -0
birder/common/fs_ops.py
CHANGED

@@ -158,7 +158,7 @@ def model_path(
         file_name = f"{file_name}_quantized"

     if states is True:
-        file_name = f"{file_name}_states"
+        file_name = f"{file_name}_states.pt"
     elif lite is True:
         file_name = f"{file_name}.ptl"
     elif pt2 is True:
@@ -254,7 +254,7 @@ def clean_checkpoints(network_name: str, keep_last: int) -> None:
     models_glob = str(model_path(network_name, epoch=epoch))
     states_glob = str(model_path(network_name, epoch=epoch, states=True))
     model_pattern = re.compile(r".*_([1-9][0-9]*)\.pt$")
-    states_pattern = re.compile(r".*_([1-9][0-9]*)_states$")
+    states_pattern = re.compile(r".*_([1-9][0-9]*)_states\.pt$")

     model_paths = list(settings.BASE_DIR.glob(models_glob))
     for p in sorted(model_paths, key=lambda p: p.stat().st_mtime)[:-keep_last]:
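The state-checkpoint change above is symmetric: model_path now appends a .pt suffix to state files, and clean_checkpoints matches the new suffix. A quick sanity check of the updated regex (the pattern is copied verbatim from the hunk; the file names are made up):

import re

# Updated pattern from clean_checkpoints(): state checkpoints now end in .pt
states_pattern = re.compile(r".*_([1-9][0-9]*)_states\.pt$")

assert states_pattern.match("net_100_states.pt") is not None  # 0.4.0 naming
assert states_pattern.match("net_100_states") is None         # 0.3.2 naming no longer matches

A consequence worth noting: state files written by 0.3.2 (without the suffix) would no longer match the cleanup pattern.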
birder/introspection/attention_rollout.py
CHANGED

@@ -141,7 +141,7 @@ class AttentionRollout:
         net: nn.Module,
         device: torch.device,
         transform: Callable[..., torch.Tensor],
-        attention_layer_name: str = "
+        attention_layer_name: str = "attn",
         discard_ratio: float = 0.9,
         head_fusion: Literal["mean", "max", "min"] = "max",
     ) -> None:
birder/introspection/transformer_attribution.py
CHANGED

@@ -132,7 +132,7 @@ class TransformerAttribution:
         net: nn.Module,
         device: torch.device,
         transform: Callable[..., torch.Tensor],
-        attention_layer_name: str = "
+        attention_layer_name: str = "attn",
     ) -> None:
         self.net = net.eval()
         self.device = device
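Both introspection helpers now default attention_layer_name to "attn". The hook-registration internals are not part of this diff; the sketch below only illustrates the kind of name-based module selection such a default implies (the module layout and the substring-matching rule are assumptions, not birder code):

from torch import nn

# Toy network whose attention submodule is named "attn", mirroring ViT-style blocks
net = nn.Sequential()
net.add_module("block0", nn.ModuleDict({"attn": nn.MultiheadAttention(64, 4)}))

attention_layer_name = "attn"
matches = [name for name, _ in net.named_modules() if attention_layer_name in name]
print(matches)  # ['block0.attn']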
birder/layers/layer_scale.py
CHANGED

@@ -23,7 +23,7 @@ class LayerScale2d(nn.Module):
     def __init__(self, dim: int, init_value: float, inplace: bool = False) -> None:
         super().__init__()
         self.inplace = inplace
-        self.gamma = nn.Parameter(init_value * torch.ones(dim, 1, 1)
+        self.gamma = nn.Parameter(init_value * torch.ones(dim, 1, 1))

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         if self.inplace is True:
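LayerScale2d keeps its learnable scale at shape (dim, 1, 1) so that it broadcasts over NCHW feature maps. A minimal re-implementation sketch, assuming the forward body follows the inplace branch visible above (only the constructor line is confirmed by this diff):

import torch
from torch import nn


class LayerScale2dSketch(nn.Module):
    def __init__(self, dim: int, init_value: float, inplace: bool = False) -> None:
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_value * torch.ones(dim, 1, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (dim, 1, 1) broadcasts along the channel axis of (N, C, H, W)
        if self.inplace is True:
            return x.mul_(self.gamma)
        return x * self.gamma


x = torch.randn(2, 8, 4, 4)
print(LayerScale2dSketch(dim=8, init_value=1e-5)(x).shape)  # torch.Size([2, 8, 4, 4])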
birder/net/__init__.py
CHANGED

@@ -55,8 +55,7 @@ from birder.net.metaformer import MetaFormer
 from birder.net.mnasnet import MNASNet
 from birder.net.mobilenet_v1 import MobileNet_v1
 from birder.net.mobilenet_v2 import MobileNet_v2
-from birder.net.mobilenet_v3_large import MobileNet_v3_Large
-from birder.net.mobilenet_v3_small import MobileNet_v3_Small
+from birder.net.mobilenet_v3 import MobileNet_v3
 from birder.net.mobilenet_v4 import MobileNet_v4
 from birder.net.mobilenet_v4_hybrid import MobileNet_v4_Hybrid
 from birder.net.mobileone import MobileOne
@@ -84,9 +83,6 @@ from birder.net.resnext import ResNeXt
 from birder.net.rope_deit3 import RoPE_DeiT3
 from birder.net.rope_flexivit import RoPE_FlexiViT
 from birder.net.rope_vit import RoPE_ViT
-from birder.net.se_resnet_v1 import SE_ResNet_v1
-from birder.net.se_resnet_v2 import SE_ResNet_v2
-from birder.net.se_resnext import SE_ResNeXt
 from birder.net.sequencer2d import Sequencer2d
 from birder.net.shufflenet_v1 import ShuffleNet_v1
 from birder.net.shufflenet_v2 import ShuffleNet_v2
@@ -171,8 +167,7 @@ __all__ = [
     "MNASNet",
     "MobileNet_v1",
     "MobileNet_v2",
-    "MobileNet_v3_Large",
-    "MobileNet_v3_Small",
+    "MobileNet_v3",
     "MobileNet_v4",
     "MobileNet_v4_Hybrid",
     "MobileOne",
@@ -200,9 +195,6 @@ __all__ = [
     "RoPE_DeiT3",
     "RoPE_FlexiViT",
     "RoPE_ViT",
-    "SE_ResNet_v1",
-    "SE_ResNet_v2",
-    "SE_ResNeXt",
     "Sequencer2d",
     "ShuffleNet_v1",
     "ShuffleNet_v2",
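For downstream code, the import surface shrinks accordingly. A before/after sketch (the 0.3.2 names are taken from the removed lines above):

# birder 0.3.2
# from birder.net import MobileNet_v3_Large, MobileNet_v3_Small, SE_ResNeXt

# birder 0.4.0 - one unified MobileNet v3 class; SE variants are no longer separate exports
from birder.net import MobileNet_v3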
birder/net/_rope_vit_configs.py
ADDED

@@ -0,0 +1,430 @@
+"""
+RoPE ViT model configuration registrations
+
+This file contains *only* model variant definitions and their registration
+with the global model registry. The actual RoPE ViT implementation lives in rope_vit.py.
+"""
+
+from birder.model_registry import registry
+from birder.net._vit_configs import BASE
+from birder.net._vit_configs import GIANT
+from birder.net._vit_configs import GIGANTIC
+from birder.net._vit_configs import HUGE
+from birder.net._vit_configs import LARGE
+from birder.net._vit_configs import MEDIUM
+from birder.net._vit_configs import SMALL
+from birder.net._vit_configs import SO150
+from birder.net._vit_configs import SO400
+from birder.net._vit_configs import TINY
+from birder.net.base import BaseNet
+
+# Vision Transformer Model Naming Convention
+# ==========================================
+#
+# Model names follow a structured pattern to encode architectural choices:
+# [rope_]vit_[reg{N}_][size][patch_size][_components][_pooling][_c{N}]
+#
+# Core Components:
+# - rope_      : Rotary Position Embedding (RoPE) enabled
+# - rope_i_    : Rotary Position Embedding (RoPE) enabled with interleaved rotation - implies different temp, indexing
+# - vit_       : Vision Transformer base architecture
+# - reg{N}_    : Register tokens (N = number of register tokens, e.g., reg4, reg8)
+# - size       : Model size (s=small, b=base, l=large, or specific like so150m)
+# - patch_size : Patch size (e.g., 14, 16, 32 for 14x14, 16x16, 32x32 patches)
+#
+# Optional Components:
+# Position Embeddings:
+# - nps : No Position embedding on Special tokens
+#
+# Normalization:
+# - rms : RMSNorm (instead of LayerNorm)
+# - pn  : Pre-Norm (layer norm before the encoder) - implies norm eps of 1e-5
+# - npn : No Post Norm (disables post-normalization layer)
+# - qkn : QK Norm
+#
+# Feed-Forward Network:
+# - swiglu : SwiGLU FFN layer type (instead of standard FFN)
+#
+# Activation:
+# - quick_gelu : QuickGELU activation type
+# - ...
+#
+# Regularization:
+# - ls : Layer Scaling applied
+#
+# Pooling/Reduction:
+# - avg : Average pooling for sequence reduction
+# - ap  : Attention Pooling for sequence reduction
+# - aps : Attention Pooling inc. Special tokens for sequence reduction
+#
+# Custom Variants:
+# - c{N} : Custom variant (N = version number) for models with fine-grained or non-standard
+#          modifications not fully reflected in the name
+
+
+def register_rope_vit_configs(rope_vit: type[BaseNet]) -> None:
+    registry.register_model_config(
+        "rope_vit_t32",
+        rope_vit,
+        config={"patch_size": 32, **TINY},
+    )
+    registry.register_model_config(
+        "rope_vit_t16",
+        rope_vit,
+        config={"patch_size": 16, **TINY},
+    )
+    registry.register_model_config(
+        "rope_vit_t14",
+        rope_vit,
+        config={"patch_size": 14, **TINY},
+    )
+    registry.register_model_config(
+        "rope_vit_s32",
+        rope_vit,
+        config={"patch_size": 32, **SMALL},
+    )
+    registry.register_model_config(
+        "rope_vit_s16",
+        rope_vit,
+        config={"patch_size": 16, **SMALL},
+    )
+    registry.register_model_config(
+        "rope_i_vit_s16_pn_aps_c1",  # For PE Core - https://arxiv.org/abs/2504.13181
+        rope_vit,
+        config={
+            "patch_size": 16,
+            **SMALL,
+            "pre_norm": True,
+            "attn_pool_head": True,
+            "attn_pool_num_heads": 8,
+            "attn_pool_special_tokens": True,
+            "norm_layer_eps": 1e-5,
+            "rope_rot_type": "interleaved",
+            "rope_grid_indexing": "xy",
+            "rope_grid_offset": 1,
+            "rope_temperature": 10000.0,
+        },
+    )
+    registry.register_model_config(
+        "rope_vit_s14",
+        rope_vit,
+        config={"patch_size": 14, **SMALL},
+    )
+    registry.register_model_config(
+        "rope_vit_m32",
+        rope_vit,
+        config={"patch_size": 32, **MEDIUM},
+    )
+    registry.register_model_config(
+        "rope_vit_m16",
+        rope_vit,
+        config={"patch_size": 16, **MEDIUM},
+    )
+    registry.register_model_config(
+        "rope_vit_m14",
+        rope_vit,
+        config={"patch_size": 14, **MEDIUM},
+    )
+    registry.register_model_config(
+        "rope_vit_b32",
+        rope_vit,
+        config={"patch_size": 32, **BASE, "drop_path_rate": 0.0},  # Override the BASE definition
+    )
+    registry.register_model_config(
+        "rope_vit_b16",
+        rope_vit,
+        config={"patch_size": 16, **BASE},
+    )
+    registry.register_model_config(
+        "rope_vit_b16_qkn_ls",
+        rope_vit,
+        config={"patch_size": 16, **BASE, "layer_scale_init_value": 1e-5, "qk_norm": True},
+    )
+    registry.register_model_config(
+        "rope_i_vit_b16_pn_aps_c1",  # For PE Core - https://arxiv.org/abs/2504.13181
+        rope_vit,
+        config={
+            "patch_size": 16,
+            **BASE,
+            "pre_norm": True,
+            "attn_pool_head": True,
+            "attn_pool_num_heads": 8,
+            "attn_pool_special_tokens": True,
+            "norm_layer_eps": 1e-5,
+            "rope_rot_type": "interleaved",
+            "rope_grid_indexing": "xy",
+            "rope_grid_offset": 1,
+            "rope_temperature": 10000.0,
+        },
+    )
+    registry.register_model_config(
+        "rope_vit_b14",
+        rope_vit,
+        config={"patch_size": 14, **BASE},
+    )
+    registry.register_model_config(
+        "rope_vit_so150m_p14_ap",
+        rope_vit,
+        config={"patch_size": 14, **SO150, "class_token": False, "attn_pool_head": True},
+    )
+    registry.register_model_config(
+        "rope_vit_l32",
+        rope_vit,
+        config={"patch_size": 32, **LARGE},
+    )
+    registry.register_model_config(
+        "rope_vit_l16",
+        rope_vit,
+        config={"patch_size": 16, **LARGE},
+    )
+    registry.register_model_config(
+        "rope_vit_l14",
+        rope_vit,
+        config={"patch_size": 14, **LARGE},
+    )
+    registry.register_model_config(
+        "rope_i_vit_l14_pn_aps_c1",  # For PE Core - https://arxiv.org/abs/2504.13181
+        rope_vit,
+        config={
+            "patch_size": 14,
+            **LARGE,
+            "pre_norm": True,
+            "attn_pool_head": True,
+            "attn_pool_num_heads": 8,
+            "attn_pool_special_tokens": True,
+            "norm_layer_eps": 1e-5,
+            "rope_rot_type": "interleaved",
+            "rope_grid_indexing": "xy",
+            "rope_grid_offset": 1,
+            "rope_temperature": 10000.0,
+        },
+    )
+    registry.register_model_config(
+        "rope_vit_so400m_p14_ap",
+        rope_vit,
+        config={"patch_size": 14, **SO400, "class_token": False, "attn_pool_head": True},
+    )
+    registry.register_model_config(
+        "rope_vit_h16",
+        rope_vit,
+        config={"patch_size": 16, **HUGE},
+    )
+    registry.register_model_config(
+        "rope_vit_h14",
+        rope_vit,
+        config={"patch_size": 14, **HUGE},
+    )
+    registry.register_model_config(  # From "Scaling Vision Transformers"
+        "rope_vit_g16",
+        rope_vit,
+        config={"patch_size": 16, **GIANT},
+    )
+    registry.register_model_config(  # From "Scaling Vision Transformers"
+        "rope_vit_g14",
+        rope_vit,
+        config={"patch_size": 14, **GIANT},
+    )
+    registry.register_model_config(  # From "Scaling Vision Transformers"
+        "rope_vit_gigantic14",
+        rope_vit,
+        config={"patch_size": 14, **GIGANTIC},
+    )
+
+    # With registers
+    ####################
+
+    registry.register_model_config(
+        "rope_vit_reg1_t32",
+        rope_vit,
+        config={"patch_size": 32, **TINY, "num_reg_tokens": 1},
+    )
+    registry.register_model_config(
+        "rope_vit_reg1_t16",
+        rope_vit,
+        config={"patch_size": 16, **TINY, "num_reg_tokens": 1},
+    )
+    registry.register_model_config(
+        "rope_vit_reg1_t14",
+        rope_vit,
+        config={"patch_size": 14, **TINY, "num_reg_tokens": 1},
+    )
+    registry.register_model_config(
+        "rope_vit_reg1_s32",
+        rope_vit,
+        config={"patch_size": 32, **SMALL, "num_reg_tokens": 1},
+    )
+    registry.register_model_config(
+        "rope_vit_reg1_s16",
+        rope_vit,
+        config={"patch_size": 16, **SMALL, "num_reg_tokens": 1},
+    )
+    registry.register_model_config(
+        "rope_i_vit_reg1_s16_pn_npn_avg_c1",  # For PE Spatial - https://arxiv.org/abs/2504.13181
+        rope_vit,
+        config={
+            "patch_size": 16,
+            **SMALL,
+            "num_reg_tokens": 1,
+            "class_token": False,
+            "pre_norm": True,
+            "post_norm": False,
+            "norm_layer_eps": 1e-5,
+            "rope_rot_type": "interleaved",
+            "rope_grid_indexing": "xy",
+            "rope_grid_offset": 1,
+            "rope_temperature": 10000.0,
+        },
+    )
+    registry.register_model_config(
+        "rope_vit_reg1_s14",
+        rope_vit,
+        config={"patch_size": 14, **SMALL, "num_reg_tokens": 1},
+    )
+    registry.register_model_config(
+        "rope_vit_reg4_m32",
+        rope_vit,
+        config={"patch_size": 32, **MEDIUM, "num_reg_tokens": 4},
+    )
+    registry.register_model_config(
+        "rope_vit_reg4_m16",
+        rope_vit,
+        config={"patch_size": 16, **MEDIUM, "num_reg_tokens": 4},
+    )
+    registry.register_model_config(
+        "rope_vit_reg4_m16_rms_avg",
+        rope_vit,
+        config={"patch_size": 16, **MEDIUM, "num_reg_tokens": 4, "class_token": False, "norm_layer_type": "RMSNorm"},
+    )
+    registry.register_model_config(
+        "rope_vit_reg4_m14",
+        rope_vit,
+        config={"patch_size": 14, **MEDIUM, "num_reg_tokens": 4},
+    )
+    registry.register_model_config(
+        "rope_vit_reg4_m14_avg",
+        rope_vit,
+        config={"patch_size": 14, **MEDIUM, "num_reg_tokens": 4, "class_token": False},
+    )
+    registry.register_model_config(
+        "rope_vit_reg4_b32",
+        rope_vit,
+        config={"patch_size": 32, **BASE, "num_reg_tokens": 4, "drop_path_rate": 0.0},  # Override the BASE definition
+    )
+    registry.register_model_config(
+        "rope_vit_reg4_b16",
+        rope_vit,
+        config={"patch_size": 16, **BASE, "num_reg_tokens": 4},
+    )
+    registry.register_model_config(
+        "rope_vit_reg4_b14",
+        rope_vit,
+        config={"patch_size": 14, **BASE, "num_reg_tokens": 4},
+    )
+    registry.register_model_config(
+        "rope_vit_reg8_nps_b14_ap",
+        rope_vit,
+        config={
+            "pos_embed_special_tokens": False,
+            "patch_size": 14,
+            **BASE,
+            "num_reg_tokens": 8,
+            "class_token": False,
+            "attn_pool_head": True,
+        },
+    )
+    registry.register_model_config(
+        "rope_vit_reg4_so150m_p14_ap",
+        rope_vit,
+        config={"patch_size": 14, **SO150, "num_reg_tokens": 4, "class_token": False, "attn_pool_head": True},
+    )
+    registry.register_model_config(
+        "rope_vit_reg8_so150m_p14_ap",
+        rope_vit,
+        config={"patch_size": 14, **SO150, "num_reg_tokens": 8, "class_token": False, "attn_pool_head": True},
+    )
+    registry.register_model_config(
+        "rope_vit_reg8_so150m_p14_swiglu_rms_avg",
+        rope_vit,
+        config={
+            "patch_size": 14,
+            **SO150,
+            "num_reg_tokens": 8,
+            "class_token": False,
+            "norm_layer_type": "RMSNorm",
+            "mlp_layer_type": "SwiGLU_FFN",
+        },
+    )
+    registry.register_model_config(
+        "rope_vit_reg8_so150m_p14_swiglu_rms_ap",
+        rope_vit,
+        config={
+            "patch_size": 14,
+            **SO150,
+            "num_reg_tokens": 8,
+            "class_token": False,
+            "attn_pool_head": True,
+            "norm_layer_type": "RMSNorm",
+            "mlp_layer_type": "SwiGLU_FFN",
+        },
+    )
+    registry.register_model_config(
+        "rope_vit_reg8_so150m_p14_swiglu_rms_aps",
+        rope_vit,
+        config={
+            "patch_size": 14,
+            **SO150,
+            "num_reg_tokens": 8,
+            "class_token": False,
+            "attn_pool_head": True,
+            "attn_pool_special_tokens": True,
+            "norm_layer_type": "RMSNorm",
+            "mlp_layer_type": "SwiGLU_FFN",
+        },
+    )
+    registry.register_model_config(
+        "rope_vit_reg4_l32",
+        rope_vit,
+        config={"patch_size": 32, **LARGE, "num_reg_tokens": 4},
+    )
+    registry.register_model_config(
+        "rope_vit_reg4_l16",
+        rope_vit,
+        config={"patch_size": 16, **LARGE, "num_reg_tokens": 4},
+    )
+    registry.register_model_config(
+        "rope_vit_reg4_l14",
+        rope_vit,
+        config={"patch_size": 14, **LARGE, "num_reg_tokens": 4},
+    )
+    registry.register_model_config(
+        "rope_vit_reg8_l14_rms_ap",
+        rope_vit,
+        config={
+            "patch_size": 14,
+            **LARGE,
+            "num_reg_tokens": 8,
+            "class_token": False,
+            "attn_pool_head": True,
+            "norm_layer_type": "RMSNorm",
+        },
+    )
+    registry.register_model_config(
+        "rope_vit_reg8_so400m_p14_ap",
+        rope_vit,
+        config={"patch_size": 14, **SO400, "num_reg_tokens": 8, "class_token": False, "attn_pool_head": True},
+    )
+    registry.register_model_config(
+        "rope_vit_reg4_h16",
+        rope_vit,
+        config={"patch_size": 16, **HUGE, "num_reg_tokens": 4},
+    )
+    registry.register_model_config(
+        "rope_vit_reg4_h14",
+        rope_vit,
+        config={"patch_size": 14, **HUGE, "num_reg_tokens": 4},
+    )
+    registry.register_model_config(  # From "Scaling Vision Transformers"
+        "rope_vit_reg4_g14",
+        rope_vit,
+        config={"patch_size": 14, **GIANT, "num_reg_tokens": 4},
+    )
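Every registration above follows the same pattern: spread a shared size preset from _vit_configs.py into the config dict, then override or extend it per variant. A toy sketch of the dict-merge semantics (the preset values here are hypothetical, not birder's):

# Spread-then-override idiom used throughout _rope_vit_configs.py
SMALL = {"embed_dim": 384, "depth": 12, "num_heads": 6}  # hypothetical preset values

config = {
    "patch_size": 16,
    **SMALL,
    "num_reg_tokens": 1,  # keys written after the spread win on collision
}
print(config["embed_dim"], config["num_reg_tokens"])  # 384 1

Read together with the naming convention at the top of the file, a name like rope_i_vit_reg1_s16_pn_npn_avg_c1 decodes directly into its config: interleaved RoPE, one register token, small size, 16x16 patches, pre-norm, no post-norm, average pooling, custom variant 1.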