brainscore-vision 2.1__py3-none-any.whl → 2.2.1__py3-none-any.whl

Files changed (143)
  1. brainscore_vision/benchmarks/coggan2024_behavior/__init__.py +2 -1
  2. brainscore_vision/benchmarks/coggan2024_behavior/test.py +2 -2
  3. brainscore_vision/benchmarks/coggan2024_fMRI/__init__.py +4 -4
  4. brainscore_vision/benchmarks/coggan2024_fMRI/test.py +2 -2
  5. brainscore_vision/benchmarks/imagenet/imagenet2012.csv +50000 -50000
  6. brainscore_vision/benchmarks/imagenet_c/benchmark.py +1 -1
  7. brainscore_vision/benchmarks/lonnqvist2024/__init__.py +8 -0
  8. brainscore_vision/benchmarks/lonnqvist2024/benchmark.py +125 -0
  9. brainscore_vision/benchmarks/lonnqvist2024/test.py +61 -0
  10. brainscore_vision/benchmarks/malania2007/benchmark.py +3 -0
  11. brainscore_vision/benchmarks/maniquet2024/benchmark.py +1 -1
  12. brainscore_vision/data/lonnqvist2024/__init__.py +47 -0
  13. brainscore_vision/data/lonnqvist2024/data_packaging/lonnqvist_data_assembly.py +53 -0
  14. brainscore_vision/data/lonnqvist2024/data_packaging/lonnqvist_stimulus_set.py +61 -0
  15. brainscore_vision/data/lonnqvist2024/test.py +127 -0
  16. brainscore_vision/model_helpers/brain_transformation/__init__.py +33 -0
  17. brainscore_vision/models/alexnet/region_layer_map/alexnet.json +1 -0
  18. brainscore_vision/models/alexnet_7be5be79/setup.py +4 -4
  19. brainscore_vision/models/alexnet_random/__init__.py +7 -0
  20. brainscore_vision/models/alexnet_random/model.py +46 -0
  21. brainscore_vision/models/alexnet_random/setup.py +26 -0
  22. brainscore_vision/models/alexnet_random/test.py +1 -0
  23. brainscore_vision/models/cvt_cvt_13_224_in1k_4/__init__.py +9 -0
  24. brainscore_vision/models/cvt_cvt_13_224_in1k_4/model.py +142 -0
  25. brainscore_vision/models/cvt_cvt_13_224_in1k_4/region_layer_map/cvt_cvt-13-224-in1k_4.json +6 -0
  26. brainscore_vision/models/cvt_cvt_13_224_in1k_4/region_layer_map/cvt_cvt-13-224-in1k_4_LucyV4.json +6 -0
  27. brainscore_vision/models/cvt_cvt_13_224_in1k_4/requirements.txt +4 -0
  28. brainscore_vision/models/cvt_cvt_13_224_in1k_4/test.py +8 -0
  29. brainscore_vision/models/cvt_cvt_13_384_in1k_4/__init__.py +9 -0
  30. brainscore_vision/models/cvt_cvt_13_384_in1k_4/model.py +142 -0
  31. brainscore_vision/models/cvt_cvt_13_384_in1k_4/region_layer_map/cvt_cvt-13-384-in1k_4_LucyV4.json +6 -0
  32. brainscore_vision/models/cvt_cvt_13_384_in1k_4/requirements.txt +4 -0
  33. brainscore_vision/models/cvt_cvt_13_384_in1k_4/test.py +8 -0
  34. brainscore_vision/models/cvt_cvt_13_384_in22k_finetuned_in1k_4/__init__.py +9 -0
  35. brainscore_vision/models/cvt_cvt_13_384_in22k_finetuned_in1k_4/model.py +142 -0
  36. brainscore_vision/models/cvt_cvt_13_384_in22k_finetuned_in1k_4/region_layer_map/cvt_cvt-13-384-in22k_finetuned-in1k_4_LucyV4.json +6 -0
  37. brainscore_vision/models/cvt_cvt_13_384_in22k_finetuned_in1k_4/requirements.txt +4 -0
  38. brainscore_vision/models/cvt_cvt_13_384_in22k_finetuned_in1k_4/test.py +8 -0
  39. brainscore_vision/models/cvt_cvt_21_224_in1k_4/__init__.py +9 -0
  40. brainscore_vision/models/cvt_cvt_21_224_in1k_4/model.py +142 -0
  41. brainscore_vision/models/cvt_cvt_21_224_in1k_4/region_layer_map/cvt_cvt-21-224-in1k_4_LucyV4.json +6 -0
  42. brainscore_vision/models/cvt_cvt_21_224_in1k_4/requirements.txt +4 -0
  43. brainscore_vision/models/cvt_cvt_21_224_in1k_4/test.py +8 -0
  44. brainscore_vision/models/cvt_cvt_21_384_in1k_4/__init__.py +9 -0
  45. brainscore_vision/models/cvt_cvt_21_384_in1k_4/model.py +142 -0
  46. brainscore_vision/models/cvt_cvt_21_384_in1k_4/region_layer_map/cvt_cvt-21-384-in1k_4_LucyV4.json +6 -0
  47. brainscore_vision/models/cvt_cvt_21_384_in1k_4/requirements.txt +4 -0
  48. brainscore_vision/models/cvt_cvt_21_384_in1k_4/test.py +8 -0
  49. brainscore_vision/models/cvt_cvt_21_384_in22k_finetuned_in1k_4/__init__.py +9 -0
  50. brainscore_vision/models/cvt_cvt_21_384_in22k_finetuned_in1k_4/model.py +142 -0
  51. brainscore_vision/models/cvt_cvt_21_384_in22k_finetuned_in1k_4/region_layer_map/cvt_cvt-21-384-in22k_finetuned-in1k_4_LucyV4.json +6 -0
  52. brainscore_vision/models/cvt_cvt_21_384_in22k_finetuned_in1k_4/requirements.txt +4 -0
  53. brainscore_vision/models/cvt_cvt_21_384_in22k_finetuned_in1k_4/test.py +8 -0
  54. brainscore_vision/models/fixres_resnext101_32x48d_wsl/__init__.py +7 -0
  55. brainscore_vision/models/fixres_resnext101_32x48d_wsl/model.py +57 -0
  56. brainscore_vision/models/fixres_resnext101_32x48d_wsl/requirements.txt +5 -0
  57. brainscore_vision/models/fixres_resnext101_32x48d_wsl/test.py +7 -0
  58. brainscore_vision/models/inception_v4_pytorch/__init__.py +7 -0
  59. brainscore_vision/models/inception_v4_pytorch/model.py +64 -0
  60. brainscore_vision/models/inception_v4_pytorch/requirements.txt +3 -0
  61. brainscore_vision/models/inception_v4_pytorch/test.py +8 -0
  62. brainscore_vision/models/mvimgnet_ms_05/__init__.py +9 -0
  63. brainscore_vision/models/mvimgnet_ms_05/model.py +64 -0
  64. brainscore_vision/models/mvimgnet_ms_05/setup.py +25 -0
  65. brainscore_vision/models/mvimgnet_ms_05/test.py +1 -0
  66. brainscore_vision/models/mvimgnet_rf/__init__.py +9 -0
  67. brainscore_vision/models/mvimgnet_rf/model.py +64 -0
  68. brainscore_vision/models/mvimgnet_rf/setup.py +25 -0
  69. brainscore_vision/models/mvimgnet_rf/test.py +1 -0
  70. brainscore_vision/models/mvimgnet_ss_00/__init__.py +9 -0
  71. brainscore_vision/models/mvimgnet_ss_00/model.py +64 -0
  72. brainscore_vision/models/mvimgnet_ss_00/setup.py +25 -0
  73. brainscore_vision/models/mvimgnet_ss_00/test.py +1 -0
  74. brainscore_vision/models/mvimgnet_ss_02/__init__.py +9 -0
  75. brainscore_vision/models/mvimgnet_ss_02/model.py +64 -0
  76. brainscore_vision/models/mvimgnet_ss_02/setup.py +25 -0
  77. brainscore_vision/models/mvimgnet_ss_02/test.py +1 -0
  78. brainscore_vision/models/mvimgnet_ss_03/__init__.py +9 -0
  79. brainscore_vision/models/mvimgnet_ss_03/model.py +64 -0
  80. brainscore_vision/models/mvimgnet_ss_03/setup.py +25 -0
  81. brainscore_vision/models/mvimgnet_ss_03/test.py +1 -0
  82. brainscore_vision/models/mvimgnet_ss_04/__init__.py +9 -0
  83. brainscore_vision/models/mvimgnet_ss_04/model.py +64 -0
  84. brainscore_vision/models/mvimgnet_ss_04/setup.py +25 -0
  85. brainscore_vision/models/mvimgnet_ss_04/test.py +1 -0
  86. brainscore_vision/models/mvimgnet_ss_05/__init__.py +9 -0
  87. brainscore_vision/models/mvimgnet_ss_05/model.py +64 -0
  88. brainscore_vision/models/mvimgnet_ss_05/setup.py +25 -0
  89. brainscore_vision/models/mvimgnet_ss_05/test.py +1 -0
  90. brainscore_vision/models/resnet50_tutorial/region_layer_map/resnet50_tutorial.json +1 -0
  91. brainscore_vision/models/sam_test_resnet/__init__.py +5 -0
  92. brainscore_vision/models/sam_test_resnet/model.py +26 -0
  93. brainscore_vision/models/sam_test_resnet/requirements.txt +2 -0
  94. brainscore_vision/models/sam_test_resnet/test.py +8 -0
  95. brainscore_vision/models/sam_test_resnet_4/__init__.py +5 -0
  96. brainscore_vision/models/sam_test_resnet_4/model.py +26 -0
  97. brainscore_vision/models/sam_test_resnet_4/requirements.txt +2 -0
  98. brainscore_vision/models/sam_test_resnet_4/test.py +8 -0
  99. brainscore_vision/models/scaling_models/__init__.py +265 -0
  100. brainscore_vision/models/scaling_models/model.py +148 -0
  101. brainscore_vision/models/scaling_models/model_configs.json +869 -0
  102. brainscore_vision/models/scaling_models/region_layer_map/convnext_base_imagenet_full_seed-0.json +6 -0
  103. brainscore_vision/models/scaling_models/region_layer_map/convnext_large_imagenet_full_seed-0.json +6 -0
  104. brainscore_vision/models/scaling_models/region_layer_map/convnext_small_imagenet_100_seed-0.json +6 -0
  105. brainscore_vision/models/scaling_models/region_layer_map/convnext_small_imagenet_10_seed-0.json +6 -0
  106. brainscore_vision/models/scaling_models/region_layer_map/convnext_small_imagenet_1_seed-0.json +6 -0
  107. brainscore_vision/models/scaling_models/region_layer_map/convnext_small_imagenet_full_seed-0.json +6 -0
  108. brainscore_vision/models/scaling_models/region_layer_map/deit_base_imagenet_full_seed-0.json +6 -0
  109. brainscore_vision/models/scaling_models/region_layer_map/deit_large_imagenet_full_seed-0.json +6 -0
  110. brainscore_vision/models/scaling_models/region_layer_map/deit_small_imagenet_100_seed-0.json +6 -0
  111. brainscore_vision/models/scaling_models/region_layer_map/deit_small_imagenet_10_seed-0.json +6 -0
  112. brainscore_vision/models/scaling_models/region_layer_map/deit_small_imagenet_1_seed-0.json +6 -0
  113. brainscore_vision/models/scaling_models/region_layer_map/deit_small_imagenet_full_seed-0.json +6 -0
  114. brainscore_vision/models/scaling_models/region_layer_map/efficientnet_b0_imagenet_full.json +6 -0
  115. brainscore_vision/models/scaling_models/region_layer_map/efficientnet_b1_imagenet_full.json +6 -0
  116. brainscore_vision/models/scaling_models/region_layer_map/efficientnet_b2_imagenet_full.json +6 -0
  117. brainscore_vision/models/scaling_models/region_layer_map/resnet101_ecoset_full.json +6 -0
  118. brainscore_vision/models/scaling_models/region_layer_map/resnet101_imagenet_full.json +6 -0
  119. brainscore_vision/models/scaling_models/region_layer_map/resnet152_ecoset_full.json +6 -0
  120. brainscore_vision/models/scaling_models/region_layer_map/resnet18_ecoset_full.json +6 -0
  121. brainscore_vision/models/scaling_models/region_layer_map/resnet18_imagenet_full.json +6 -0
  122. brainscore_vision/models/scaling_models/region_layer_map/resnet34_ecoset_full.json +6 -0
  123. brainscore_vision/models/scaling_models/region_layer_map/resnet34_imagenet_full.json +6 -0
  124. brainscore_vision/models/scaling_models/region_layer_map/resnet50_ecoset_full.json +6 -0
  125. brainscore_vision/models/scaling_models/region_layer_map/resnet50_imagenet_100_seed-0.json +6 -0
  126. brainscore_vision/models/scaling_models/region_layer_map/resnet50_imagenet_10_seed-0.json +6 -0
  127. brainscore_vision/models/scaling_models/region_layer_map/resnet50_imagenet_1_seed-0.json +6 -0
  128. brainscore_vision/models/scaling_models/region_layer_map/resnet50_imagenet_full.json +6 -0
  129. brainscore_vision/models/scaling_models/requirements.txt +4 -0
  130. brainscore_vision/models/scaling_models/test.py +0 -0
  131. brainscore_vision/models/vitb14_dinov2_imagenet1k/__init__.py +5 -0
  132. brainscore_vision/models/vitb14_dinov2_imagenet1k/model.py +852 -0
  133. brainscore_vision/models/vitb14_dinov2_imagenet1k/setup.py +25 -0
  134. brainscore_vision/models/vitb14_dinov2_imagenet1k/test.py +0 -0
  135. brainscore_vision/models/voneresnet_50_non_stochastic/region_layer_map/voneresnet-50-non_stochastic.json +1 -0
  136. brainscore_vision/submission/actions_helpers.py +2 -2
  137. brainscore_vision/submission/endpoints.py +3 -4
  138. {brainscore_vision-2.1.dist-info → brainscore_vision-2.2.1.dist-info}/METADATA +2 -2
  139. {brainscore_vision-2.1.dist-info → brainscore_vision-2.2.1.dist-info}/RECORD +143 -18
  140. {brainscore_vision-2.1.dist-info → brainscore_vision-2.2.1.dist-info}/WHEEL +1 -1
  141. tests/test_model_helpers/temporal/activations/test_inferencer.py +2 -2
  142. {brainscore_vision-2.1.dist-info → brainscore_vision-2.2.1.dist-info}/LICENSE +0 -0
  143. {brainscore_vision-2.1.dist-info → brainscore_vision-2.2.1.dist-info}/top_level.txt +0 -0
brainscore_vision/models/vitb14_dinov2_imagenet1k/model.py
@@ -0,0 +1,852 @@
+import functools
+import ssl
+import math
+import os
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.init import trunc_normal_
+from typing import Sequence, Tuple, Union, Callable, Optional, List, Any, Dict
+from brainscore_vision.model_helpers.activations.pytorch import PytorchWrapper
+from brainscore_vision.model_helpers.activations.pytorch import load_preprocess_images
+from brainscore_vision.model_helpers.check_submission import check_models
+
+
+ssl._create_default_https_context = ssl._create_unverified_context
+
+
+def drop_path(x, drop_prob: float = 0.0, training: bool = False):
+    if drop_prob == 0.0 or not training:
+        return x
+    keep_prob = 1 - drop_prob
+    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
+    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+    if keep_prob > 0.0:
+        random_tensor.div_(keep_prob)
+    output = x * random_tensor
+    return output
+
+
+class DropPath(nn.Module):
+    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+    def __init__(self, drop_prob=None):
+        super(DropPath, self).__init__()
+        self.drop_prob = drop_prob
+
+    def forward(self, x):
+        return drop_path(x, self.drop_prob, self.training)
+
+
+class LayerScale(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        init_values: Union[float, torch.Tensor] = 1e-5,
+        inplace: bool = False,
+    ) -> None:
+        super().__init__()
+        self.inplace = inplace
+        self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return x.mul_(self.gamma) if self.inplace else x * self.gamma
+
+
+class Mlp(nn.Module):
+    def __init__(
+        self,
+        in_features: int,
+        hidden_features: Optional[int] = None,
+        out_features: Optional[int] = None,
+        act_layer: Callable[..., nn.Module] = nn.GELU,
+        drop: float = 0.0,
+        bias: bool = True,
+    ) -> None:
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
+        self.act = act_layer()
+        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+
+def make_2tuple(x):
+    if isinstance(x, tuple):
+        assert len(x) == 2
+        return x
+
+    assert isinstance(x, int)
+    return (x, x)
+
+
+class PatchEmbed(nn.Module):
+    """
+    2D image to patch embedding: (B,C,H,W) -> (B,N,D)
+
+    Args:
+        img_size: Image size.
+        patch_size: Patch token size.
+        in_chans: Number of input image channels.
+        embed_dim: Number of linear projection output channels.
+        norm_layer: Normalization layer.
+    """
+
+    def __init__(
+        self,
+        img_size: Union[int, Tuple[int, int]] = 224,
+        patch_size: Union[int, Tuple[int, int]] = 16,
+        in_chans: int = 3,
+        embed_dim: int = 768,
+        norm_layer: Optional[Callable] = None,
+        flatten_embedding: bool = True,
+    ) -> None:
+        super().__init__()
+
+        image_HW = make_2tuple(img_size)
+        patch_HW = make_2tuple(patch_size)
+        patch_grid_size = (
+            image_HW[0] // patch_HW[0],
+            image_HW[1] // patch_HW[1],
+        )
+
+        self.img_size = image_HW
+        self.patch_size = patch_HW
+        self.patches_resolution = patch_grid_size
+        self.num_patches = patch_grid_size[0] * patch_grid_size[1]
+
+        self.in_chans = in_chans
+        self.embed_dim = embed_dim
+
+        self.flatten_embedding = flatten_embedding
+
+        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
+        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        _, _, H, W = x.shape
+        patch_H, patch_W = self.patch_size
+
+        assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
+        assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
+
+        x = self.proj(x)  # B C H W
+        H, W = x.size(2), x.size(3)
+        x = x.flatten(2).transpose(1, 2)  # B HW C
+        x = self.norm(x)
+        if not self.flatten_embedding:
+            x = x.reshape(-1, H, W, self.embed_dim)  # B H W C
+        return x
+
+    def flops(self) -> float:
+        Ho, Wo = self.patches_resolution
+        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
+        if self.norm is not None:
+            flops += Ho * Wo * self.embed_dim
+        return flops
+
+
+class SwiGLUFFN(nn.Module):
+    def __init__(
+        self,
+        in_features: int,
+        hidden_features: Optional[int] = None,
+        out_features: Optional[int] = None,
+        act_layer: Callable[..., nn.Module] = None,
+        drop: float = 0.0,
+        bias: bool = True,
+    ) -> None:
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
+        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x12 = self.w12(x)
+        x1, x2 = x12.chunk(2, dim=-1)
+        hidden = F.silu(x1) * x2
+        return self.w3(hidden)
+
+
+class SwiGLUFFNFused(SwiGLUFFN):
+    def __init__(
+        self,
+        in_features: int,
+        hidden_features: Optional[int] = None,
+        out_features: Optional[int] = None,
+        act_layer: Callable[..., nn.Module] = None,
+        drop: float = 0.0,
+        bias: bool = True,
+    ) -> None:
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
+        super().__init__(
+            in_features=in_features,
+            hidden_features=hidden_features,
+            out_features=out_features,
+            bias=bias,
+        )
+
+
+class Attention(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int = 8,
+        qkv_bias: bool = False,
+        proj_bias: bool = True,
+        attn_drop: float = 0.0,
+        proj_drop: float = 0.0,
+    ) -> None:
+        super().__init__()
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        self.scale = head_dim**-0.5
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(dim, dim, bias=proj_bias)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        B, N, C = x.shape
+        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+
+        q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
+        attn = q @ k.transpose(-2, -1)
+
+        attn = attn.softmax(dim=-1)
+        attn = self.attn_drop(attn)
+
+        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+
+class MemEffAttention(Attention):
+    def forward(self, x: torch.Tensor, attn_bias=None) -> torch.Tensor:
+        return super().forward(x)
+
+
+def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
+    if not depth_first and include_root:
+        fn(module=module, name=name)
+    for child_name, child_module in module.named_children():
+        child_name = ".".join((name, child_name)) if name else child_name
+        named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
+    if depth_first and include_root:
+        fn(module=module, name=name)
+    return module
+
+
+class BlockChunk(nn.ModuleList):
+    def forward(self, x):
+        for b in self:
+            x = b(x)
+        return x
+
+
+XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
+try:
+    if XFORMERS_ENABLED:
+        from xformers.ops import fmha, scaled_index_add, index_select_cat
+
+        XFORMERS_AVAILABLE = True
+    else:
+        raise ImportError
+except ImportError:
+    XFORMERS_AVAILABLE = False
+
+
+class Block(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int,
+        mlp_ratio: float = 4.0,
+        qkv_bias: bool = False,
+        proj_bias: bool = True,
+        ffn_bias: bool = True,
+        drop: float = 0.0,
+        attn_drop: float = 0.0,
+        init_values=None,
+        drop_path: float = 0.0,
+        act_layer: Callable[..., nn.Module] = nn.GELU,
+        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
+        attn_class: Callable[..., nn.Module] = Attention,
+        ffn_layer: Callable[..., nn.Module] = Mlp,
+    ) -> None:
+        super().__init__()
+        # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
+        self.norm1 = norm_layer(dim)
+        self.attn = attn_class(
+            dim,
+            num_heads=num_heads,
+            qkv_bias=qkv_bias,
+            proj_bias=proj_bias,
+            attn_drop=attn_drop,
+            proj_drop=drop,
+        )
+        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = ffn_layer(
+            in_features=dim,
+            hidden_features=mlp_hidden_dim,
+            act_layer=act_layer,
+            drop=drop,
+            bias=ffn_bias,
+        )
+        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+        self.sample_drop_ratio = drop_path
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        def attn_residual_func(x: torch.Tensor) -> torch.Tensor:
+            return self.ls1(self.attn(self.norm1(x)))
+
+        def ffn_residual_func(x: torch.Tensor) -> torch.Tensor:
+            return self.ls2(self.mlp(self.norm2(x)))
+
+        if self.training and self.sample_drop_ratio > 0.1:
+            # the overhead is compensated only for a drop path rate larger than 0.1
+            x = drop_add_residual_stochastic_depth(
+                x,
+                residual_func=attn_residual_func,
+                sample_drop_ratio=self.sample_drop_ratio,
+            )
+            x = drop_add_residual_stochastic_depth(
+                x,
+                residual_func=ffn_residual_func,
+                sample_drop_ratio=self.sample_drop_ratio,
+            )
+        elif self.training and self.sample_drop_ratio > 0.0:
+            x = x + self.drop_path1(attn_residual_func(x))
+            x = x + self.drop_path1(ffn_residual_func(x))  # FIXME: drop_path2
+        else:
+            x = x + attn_residual_func(x)
+            x = x + ffn_residual_func(x)
+        return x
+
+
+def drop_add_residual_stochastic_depth(
+    x: torch.Tensor,
+    residual_func: Callable[[torch.Tensor], torch.Tensor],
+    sample_drop_ratio: float = 0.0,
+) -> torch.Tensor:
+    # 1) extract subset using permutation
+    b, n, d = x.shape
+    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+    x_subset = x[brange]
+
+    # 2) apply residual_func to get residual
+    residual = residual_func(x_subset)
+
+    x_flat = x.flatten(1)
+    residual = residual.flatten(1)
+
+    residual_scale_factor = b / sample_subset_size
+
+    # 3) add the residual
+    x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
+    return x_plus_residual.view_as(x)
+
+
+def get_branges_scales(x, sample_drop_ratio=0.0):
+    b, n, d = x.shape
+    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+    residual_scale_factor = b / sample_subset_size
+    return brange, residual_scale_factor
+
+
+def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
+    if scaling_vector is None:
+        x_flat = x.flatten(1)
+        residual = residual.flatten(1)
+        x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
+    else:
+        x_plus_residual = scaled_index_add(
+            x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
+        )
+    return x_plus_residual
+
+
+attn_bias_cache: Dict[Tuple, Any] = {}
+
+
+def get_attn_bias_and_cat(x_list, branges=None):
+    """
+    this will perform the index select, cat the tensors, and provide the attn_bias from cache
+    """
+    batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
+    all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
+    if all_shapes not in attn_bias_cache.keys():
+        seqlens = []
+        for b, x in zip(batch_sizes, x_list):
+            for _ in range(b):
+                seqlens.append(x.shape[1])
+        attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
+        attn_bias._batch_sizes = batch_sizes
+        attn_bias_cache[all_shapes] = attn_bias
+
+    if branges is not None:
+        cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
+    else:
+        tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
+        cat_tensors = torch.cat(tensors_bs1, dim=1)
+
+    return attn_bias_cache[all_shapes], cat_tensors
+
+
+def drop_add_residual_stochastic_depth_list(
+    x_list: List[torch.Tensor],
+    residual_func: Callable[[torch.Tensor, Any], torch.Tensor],
+    sample_drop_ratio: float = 0.0,
+    scaling_vector=None,
+) -> List[torch.Tensor]:
+    # 1) generate random set of indices for dropping samples in the batch
+    branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
+    branges = [s[0] for s in branges_scales]
+    residual_scale_factors = [s[1] for s in branges_scales]
+
+    # 2) get attention bias and index+concat the tensors
+    attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
+
+    # 3) apply residual_func to get residual, and split the result
+    residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias))  # type: ignore
+
+    outputs = []
+    for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
+        outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
+    return outputs
+
+
+class NestedTensorBlock(Block):
+    def forward_nested(self, x_list: List[torch.Tensor]) -> List[torch.Tensor]:
+        """
+        x_list contains a list of tensors to nest together and run
+        """
+        assert isinstance(self.attn, MemEffAttention)
+
+        if self.training and self.sample_drop_ratio > 0.0:
+
+            def attn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:
+                return self.attn(self.norm1(x), attn_bias=attn_bias)
+
+            def ffn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:
+                return self.mlp(self.norm2(x))
+
+            x_list = drop_add_residual_stochastic_depth_list(
+                x_list,
+                residual_func=attn_residual_func,
+                sample_drop_ratio=self.sample_drop_ratio,
+                scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
+            )
+            x_list = drop_add_residual_stochastic_depth_list(
+                x_list,
+                residual_func=ffn_residual_func,
+                sample_drop_ratio=self.sample_drop_ratio,
+                scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
+            )
+            return x_list
+        else:
+
+            def attn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:
+                return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
+
+            def ffn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:
+                return self.ls2(self.mlp(self.norm2(x)))
+
+            attn_bias, x = get_attn_bias_and_cat(x_list)
+            x = x + attn_residual_func(x, attn_bias=attn_bias)
+            x = x + ffn_residual_func(x)
+            return attn_bias.split(x)
+
+    def forward(self, x_or_x_list):
+        if isinstance(x_or_x_list, torch.Tensor):
+            return super().forward(x_or_x_list)
+        elif isinstance(x_or_x_list, list):
+            if not XFORMERS_AVAILABLE:
+                raise AssertionError("xFormers is required for using nested tensors")
+            return self.forward_nested(x_or_x_list)
+        else:
+            raise AssertionError
+
+
+class DinoVisionTransformer(nn.Module):
+    def __init__(
+        self,
+        img_size=224,
+        patch_size=16,
+        in_chans=3,
+        embed_dim=768,
+        depth=12,
+        num_heads=12,
+        mlp_ratio=4.0,
+        qkv_bias=True,
+        ffn_bias=True,
+        proj_bias=True,
+        drop_path_rate=0.0,
+        drop_path_uniform=False,
+        init_values=None,  # for layerscale: None or 0 => no layerscale
+        embed_layer=PatchEmbed,
+        act_layer=nn.GELU,
+        block_fn=NestedTensorBlock,
+        ffn_layer="mlp",
+        block_chunks=1,
+        num_register_tokens=0,
+        interpolate_antialias=False,
+        interpolate_offset=0.1,
+    ):
+        """
+        Args:
+            img_size (int, tuple): input image size
+            patch_size (int, tuple): patch size
+            in_chans (int): number of input channels
+            embed_dim (int): embedding dimension
+            depth (int): depth of transformer
+            num_heads (int): number of attention heads
+            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
+            qkv_bias (bool): enable bias for qkv if True
+            proj_bias (bool): enable bias for proj in attn if True
+            ffn_bias (bool): enable bias for ffn if True
+            drop_path_rate (float): stochastic depth rate
+            drop_path_uniform (bool): apply uniform drop rate across blocks
+            weight_init (str): weight init scheme
+            init_values (float): layer-scale init values
+            embed_layer (nn.Module): patch embedding layer
+            act_layer (nn.Module): MLP activation layer
+            block_fn (nn.Module): transformer block class
+            ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
+            block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
+            num_register_tokens: (int) number of extra cls tokens (so-called "registers")
+            interpolate_antialias: (bool) flag to apply anti-aliasing when interpolating positional embeddings
+            interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
+        """
+        super().__init__()
+        norm_layer = functools.partial(nn.LayerNorm, eps=1e-6)
+
+        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
+        self.num_tokens = 1
+        self.n_blocks = depth
+        self.num_heads = num_heads
+        self.patch_size = patch_size
+        self.num_register_tokens = num_register_tokens
+        self.interpolate_antialias = interpolate_antialias
+        self.interpolate_offset = interpolate_offset
+
+        self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+        num_patches = self.patch_embed.num_patches
+
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
+        assert num_register_tokens >= 0
+        self.register_tokens = (
+            nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
+        )
+
+        if drop_path_uniform is True:
+            dpr = [drop_path_rate] * depth
+        else:
+            dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
+
+        if ffn_layer == "mlp":
+            ffn_layer = Mlp
+        elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
+            ffn_layer = SwiGLUFFNFused
+        elif ffn_layer == "identity":
+
+            def f(*args, **kwargs):
+                return nn.Identity()
+
+            ffn_layer = f
+        else:
+            raise NotImplementedError
+
+        blocks_list = [
+            block_fn(
+                dim=embed_dim,
+                num_heads=num_heads,
+                mlp_ratio=mlp_ratio,
+                qkv_bias=qkv_bias,
+                proj_bias=proj_bias,
+                ffn_bias=ffn_bias,
+                drop_path=dpr[i],
+                norm_layer=norm_layer,
+                act_layer=act_layer,
+                ffn_layer=ffn_layer,
+                init_values=init_values,
+            )
+            for i in range(depth)
+        ]
+        if block_chunks > 0:
+            self.chunked_blocks = True
+            chunked_blocks = []
+            chunksize = depth // block_chunks
+            for i in range(0, depth, chunksize):
+                # this is to keep the block index consistent if we chunk the block list
+                chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
+            self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
+        else:
+            self.chunked_blocks = False
+            self.blocks = nn.ModuleList(blocks_list)
+
+        self.norm = norm_layer(embed_dim)
+        self.head = nn.Identity()
+
+        self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
+
+        self.init_weights()
+
+    def init_weights(self):
+        trunc_normal_(self.pos_embed, std=0.02)
+        nn.init.normal_(self.cls_token, std=1e-6)
+        if self.register_tokens is not None:
+            nn.init.normal_(self.register_tokens, std=1e-6)
+        named_apply(init_weights_vit_timm, self)
+
+    def interpolate_pos_encoding(self, x, w, h):
+        previous_dtype = x.dtype
+        npatch = x.shape[1] - 1
+        N = self.pos_embed.shape[1] - 1
+        if npatch == N and w == h:
+            return self.pos_embed
+        pos_embed = self.pos_embed.float()
+        class_pos_embed = pos_embed[:, 0]
+        patch_pos_embed = pos_embed[:, 1:]
+        dim = x.shape[-1]
+        w0 = w // self.patch_size
+        h0 = h // self.patch_size
+        M = int(math.sqrt(N))  # Recover the number of patches in each dimension
+        assert N == M * M
+        kwargs = {}
+        if self.interpolate_offset:
+            # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
+            # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
+            sx = float(w0 + self.interpolate_offset) / M
+            sy = float(h0 + self.interpolate_offset) / M
+            kwargs["scale_factor"] = (sx, sy)
+        else:
+            # Simply specify an output size instead of a scale factor
+            kwargs["size"] = (w0, h0)
+        patch_pos_embed = nn.functional.interpolate(
+            patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
+            mode="bicubic",
+            antialias=self.interpolate_antialias,
+            **kwargs,
+        )
+        assert (w0, h0) == patch_pos_embed.shape[-2:]
+        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
+
+    def prepare_tokens_with_masks(self, x, masks=None):
+        B, nc, w, h = x.shape
+        x = self.patch_embed(x)
+        if masks is not None:
+            x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
+
+        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+        x = x + self.interpolate_pos_encoding(x, w, h)
+
+        if self.register_tokens is not None:
+            x = torch.cat(
+                (
+                    x[:, :1],
+                    self.register_tokens.expand(x.shape[0], -1, -1),
+                    x[:, 1:],
+                ),
+                dim=1,
+            )
+
+        return x
+
+    def forward_features_list(self, x_list, masks_list):
+        x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
+        for blk in self.blocks:
+            x = blk(x)
+
+        all_x = x
+        output = []
+        for x, masks in zip(all_x, masks_list):
+            x_norm = self.norm(x)
+            output.append(
+                {
+                    "x_norm_clstoken": x_norm[:, 0],
+                    "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
+                    "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
+                    "x_prenorm": x,
+                    "masks": masks,
+                }
+            )
+        return output
+
+    def forward_features(self, x, masks=None):
+        if isinstance(x, list):
+            return self.forward_features_list(x, masks)
+
+        x = self.prepare_tokens_with_masks(x, masks)
+
+        for blk in self.blocks:
+            x = blk(x)
+
+        x_norm = self.norm(x)
+        return {
+            "x_norm_clstoken": x_norm[:, 0],
+            "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
+            "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
+            "x_prenorm": x,
+            "masks": masks,
+        }
+
+    def _get_intermediate_layers_not_chunked(self, x, n=1):
+        x = self.prepare_tokens_with_masks(x)
+        # If n is an int, take the n last blocks. If it's a list, take them
+        output, total_block_len = [], len(self.blocks)
+        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+        for i, blk in enumerate(self.blocks):
+            x = blk(x)
+            if i in blocks_to_take:
+                output.append(x)
+        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+        return output
+
+    def _get_intermediate_layers_chunked(self, x, n=1):
+        x = self.prepare_tokens_with_masks(x)
+        output, i, total_block_len = [], 0, len(self.blocks[-1])
+        # If n is an int, take the n last blocks. If it's a list, take them
+        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+        for block_chunk in self.blocks:
+            for blk in block_chunk[i:]:  # Passing the nn.Identity()
+                x = blk(x)
+                if i in blocks_to_take:
+                    output.append(x)
+                i += 1
+        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+        return output
+
+    def get_intermediate_layers(
+        self,
+        x: torch.Tensor,
+        n: Union[int, Sequence] = 1,  # Layers or n last layers to take
+        reshape: bool = False,
+        return_class_token: bool = False,
+        norm=True,
+    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
+        if self.chunked_blocks:
+            outputs = self._get_intermediate_layers_chunked(x, n)
+        else:
+            outputs = self._get_intermediate_layers_not_chunked(x, n)
+        if norm:
+            outputs = [self.norm(out) for out in outputs]
+        class_tokens = [out[:, 0] for out in outputs]
+        outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]
+        if reshape:
+            B, _, w, h = x.shape
+            outputs = [
+                out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
+                for out in outputs
+            ]
+        if return_class_token:
+            return tuple(zip(outputs, class_tokens))
+        return tuple(outputs)
+
+    def forward(self, *args, is_training=False, **kwargs):
+        ret = self.forward_features(*args, **kwargs)
+        if is_training:
+            return ret
+        else:
+            return self.head(ret["x_norm_clstoken"])
+
+
+def init_weights_vit_timm(module: nn.Module, name: str = ""):
+    """ViT weight initialization, original timm impl (for reproducibility)"""
+    if isinstance(module, nn.Linear):
+        trunc_normal_(module.weight, std=0.02)
+        if module.bias is not None:
+            nn.init.zeros_(module.bias)
+
+
+def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
+    model = DinoVisionTransformer(
+        patch_size=patch_size,
+        embed_dim=768,
+        depth=12,
+        num_heads=12,
+        mlp_ratio=4,
+        block_fn=functools.partial(NestedTensorBlock, attn_class=MemEffAttention),
+        num_register_tokens=num_register_tokens,
+        **kwargs,
+    )
+    return model
+
+
+def load_weights(model):
+    weight_path = 'https://storage.googleapis.com/neurop/teacher_checkpoint.pth'
+    state_dict = torch.hub.load_state_dict_from_url(weight_path, map_location='cpu')["teacher"]
+    # remove `module.` prefix
+    state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
+    # remove `backbone.` prefix induced by multicrop wrapper
+    state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
+    msg = model.load_state_dict(state_dict, strict=False)
+    print("Pretrained weights found at {} and loaded with msg: {}".format(weight_path, msg))
+    return model
+
+
+def get_model(name):
+    assert name == 'vitb14_dinov2_imagenet1k'
+    vit_kwargs = dict(
+        img_size=224,
+        patch_size=14,
+        init_values=1.0e-05,
+        ffn_layer="mlp",
+        block_chunks=0,
+        qkv_bias=True,
+        proj_bias=True,
+        ffn_bias=True,
+        num_register_tokens=0,
+        interpolate_offset=0.1,
+        interpolate_antialias=False,
+    )
+    model = vit_base(**vit_kwargs)
+    model = load_weights(model)
+    preprocessing = functools.partial(load_preprocess_images, image_size=224)
+    wrapper = PytorchWrapper(identifier=name, model=model, preprocessing=preprocessing)
+    wrapper.image_size = 224
+    return wrapper
+
+
+def get_layers(name):
+    assert name == 'vitb14_dinov2_imagenet1k'
+    return [f'blocks.{i}' for i in range(12)]
+
+
+def get_bibtex(model_identifier):
+    return '''@inproceedings{
+dosovitskiy2021an,
+title={An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale},
+author={Alexey Dosovitskiy and Lucas Beyer and Alexander Kolesnikov and Dirk Weissenborn and Xiaohua Zhai and Thomas Unterthiner and Mostafa Dehghani and Matthias Minderer and Georg Heigold and Sylvain Gelly and Jakob Uszkoreit and Neil Houlsby},
+booktitle={International Conference on Learning Representations},
+year={2021},
+url={https://openreview.net/forum?id=YicbFdNTTy}
+}'''
+
+
+if __name__ == '__main__':
+    check_models.check_base_models(__name__)
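
Note: the 5-line __init__.py listed above for vitb14_dinov2_imagenet1k is not shown in this excerpt. As a hedged sketch only (following the usual brainscore_vision plugin pattern, not necessarily the wheel's exact contents), such a registration file typically looks like:

    from brainscore_vision import model_registry
    from brainscore_vision.model_helpers.brain_transformation import ModelCommitment
    from .model import get_model, get_layers

    # register lazily so the backbone is only built when a benchmark requests it
    model_registry['vitb14_dinov2_imagenet1k'] = lambda: ModelCommitment(
        identifier='vitb14_dinov2_imagenet1k',
        activations_model=get_model('vitb14_dinov2_imagenet1k'),
        layers=get_layers('vitb14_dinov2_imagenet1k'))

With a registration like this in place, the model can then be scored through the usual entry point, e.g. brainscore_vision.score(model_identifier='vitb14_dinov2_imagenet1k', benchmark_identifier=...).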