birder 0.3.3__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (216):
  1. birder/adversarial/base.py +1 -1
  2. birder/adversarial/simba.py +4 -4
  3. birder/common/cli.py +1 -1
  4. birder/common/fs_ops.py +13 -13
  5. birder/common/lib.py +2 -2
  6. birder/common/masking.py +3 -3
  7. birder/common/training_cli.py +24 -2
  8. birder/common/training_utils.py +28 -4
  9. birder/data/collators/detection.py +9 -1
  10. birder/data/transforms/detection.py +27 -8
  11. birder/data/transforms/mosaic.py +1 -1
  12. birder/datahub/classification.py +3 -3
  13. birder/inference/classification.py +3 -3
  14. birder/inference/data_parallel.py +1 -1
  15. birder/inference/detection.py +5 -5
  16. birder/inference/wbf.py +1 -1
  17. birder/introspection/attention_rollout.py +6 -6
  18. birder/introspection/feature_pca.py +4 -4
  19. birder/introspection/gradcam.py +1 -1
  20. birder/introspection/guided_backprop.py +2 -2
  21. birder/introspection/transformer_attribution.py +4 -4
  22. birder/layers/attention_pool.py +2 -2
  23. birder/layers/layer_scale.py +1 -1
  24. birder/model_registry/model_registry.py +2 -1
  25. birder/net/__init__.py +4 -10
  26. birder/net/_rope_vit_configs.py +435 -0
  27. birder/net/_vit_configs.py +466 -0
  28. birder/net/alexnet.py +5 -5
  29. birder/net/base.py +28 -3
  30. birder/net/biformer.py +18 -17
  31. birder/net/cait.py +7 -7
  32. birder/net/cas_vit.py +1 -1
  33. birder/net/coat.py +27 -27
  34. birder/net/conv2former.py +3 -3
  35. birder/net/convmixer.py +1 -1
  36. birder/net/convnext_v1.py +3 -11
  37. birder/net/convnext_v1_iso.py +198 -0
  38. birder/net/convnext_v2.py +2 -10
  39. birder/net/crossformer.py +9 -9
  40. birder/net/crossvit.py +6 -6
  41. birder/net/cspnet.py +1 -1
  42. birder/net/cswin_transformer.py +10 -10
  43. birder/net/davit.py +11 -11
  44. birder/net/deit.py +68 -29
  45. birder/net/deit3.py +69 -204
  46. birder/net/densenet.py +9 -8
  47. birder/net/detection/__init__.py +4 -0
  48. birder/net/detection/{yolo_anchors.py → _yolo_anchors.py} +5 -5
  49. birder/net/detection/base.py +6 -5
  50. birder/net/detection/deformable_detr.py +31 -30
  51. birder/net/detection/detr.py +14 -11
  52. birder/net/detection/efficientdet.py +10 -29
  53. birder/net/detection/faster_rcnn.py +22 -22
  54. birder/net/detection/fcos.py +8 -8
  55. birder/net/detection/plain_detr.py +852 -0
  56. birder/net/detection/retinanet.py +4 -4
  57. birder/net/detection/rt_detr_v1.py +81 -25
  58. birder/net/detection/rt_detr_v2.py +1147 -0
  59. birder/net/detection/ssd.py +5 -5
  60. birder/net/detection/yolo_v2.py +12 -12
  61. birder/net/detection/yolo_v3.py +19 -19
  62. birder/net/detection/yolo_v4.py +16 -16
  63. birder/net/detection/yolo_v4_tiny.py +3 -3
  64. birder/net/dpn.py +1 -2
  65. birder/net/edgenext.py +5 -4
  66. birder/net/edgevit.py +13 -14
  67. birder/net/efficientformer_v1.py +3 -2
  68. birder/net/efficientformer_v2.py +18 -31
  69. birder/net/efficientnet_v2.py +3 -0
  70. birder/net/efficientvim.py +9 -9
  71. birder/net/efficientvit_mit.py +7 -7
  72. birder/net/efficientvit_msft.py +3 -3
  73. birder/net/fasternet.py +3 -3
  74. birder/net/fastvit.py +5 -12
  75. birder/net/flexivit.py +50 -58
  76. birder/net/focalnet.py +5 -9
  77. birder/net/gc_vit.py +11 -11
  78. birder/net/ghostnet_v1.py +1 -1
  79. birder/net/ghostnet_v2.py +1 -1
  80. birder/net/groupmixformer.py +13 -13
  81. birder/net/hgnet_v1.py +6 -6
  82. birder/net/hgnet_v2.py +4 -4
  83. birder/net/hiera.py +6 -6
  84. birder/net/hieradet.py +9 -9
  85. birder/net/hornet.py +3 -3
  86. birder/net/iformer.py +4 -4
  87. birder/net/inception_next.py +5 -15
  88. birder/net/inception_resnet_v1.py +3 -3
  89. birder/net/inception_resnet_v2.py +7 -4
  90. birder/net/inception_v3.py +3 -0
  91. birder/net/inception_v4.py +3 -0
  92. birder/net/levit.py +3 -3
  93. birder/net/lit_v1.py +13 -15
  94. birder/net/lit_v1_tiny.py +9 -9
  95. birder/net/lit_v2.py +14 -15
  96. birder/net/maxvit.py +11 -23
  97. birder/net/metaformer.py +5 -5
  98. birder/net/mim/crossmae.py +6 -6
  99. birder/net/mim/fcmae.py +3 -5
  100. birder/net/mim/mae_hiera.py +7 -7
  101. birder/net/mim/mae_vit.py +4 -6
  102. birder/net/mim/simmim.py +3 -4
  103. birder/net/mobilenet_v1.py +0 -9
  104. birder/net/mobilenet_v2.py +38 -44
  105. birder/net/{mobilenet_v3_large.py → mobilenet_v3.py} +37 -10
  106. birder/net/mobilenet_v4_hybrid.py +4 -4
  107. birder/net/mobileone.py +5 -12
  108. birder/net/mobilevit_v1.py +7 -34
  109. birder/net/mobilevit_v2.py +6 -54
  110. birder/net/moganet.py +8 -5
  111. birder/net/mvit_v2.py +30 -30
  112. birder/net/nextvit.py +2 -2
  113. birder/net/nfnet.py +4 -0
  114. birder/net/pit.py +11 -26
  115. birder/net/pvt_v1.py +9 -9
  116. birder/net/pvt_v2.py +10 -16
  117. birder/net/regionvit.py +15 -15
  118. birder/net/regnet.py +1 -1
  119. birder/net/repghost.py +5 -35
  120. birder/net/repvgg.py +3 -5
  121. birder/net/repvit.py +2 -2
  122. birder/net/resmlp.py +2 -2
  123. birder/net/resnest.py +4 -1
  124. birder/net/resnet_v1.py +125 -1
  125. birder/net/resnet_v2.py +75 -1
  126. birder/net/resnext.py +35 -1
  127. birder/net/rope_deit3.py +62 -151
  128. birder/net/rope_flexivit.py +46 -33
  129. birder/net/rope_vit.py +44 -758
  130. birder/net/sequencer2d.py +3 -4
  131. birder/net/shufflenet_v1.py +1 -1
  132. birder/net/shufflenet_v2.py +1 -1
  133. birder/net/simple_vit.py +69 -21
  134. birder/net/smt.py +8 -8
  135. birder/net/squeezenet.py +5 -12
  136. birder/net/squeezenext.py +0 -24
  137. birder/net/ssl/barlow_twins.py +1 -1
  138. birder/net/ssl/byol.py +2 -2
  139. birder/net/ssl/capi.py +4 -4
  140. birder/net/ssl/data2vec.py +1 -1
  141. birder/net/ssl/data2vec2.py +1 -1
  142. birder/net/ssl/dino_v2.py +13 -3
  143. birder/net/ssl/franca.py +28 -4
  144. birder/net/ssl/i_jepa.py +5 -5
  145. birder/net/ssl/ibot.py +1 -1
  146. birder/net/ssl/mmcr.py +1 -1
  147. birder/net/swiftformer.py +13 -3
  148. birder/net/swin_transformer_v1.py +4 -5
  149. birder/net/swin_transformer_v2.py +5 -8
  150. birder/net/tiny_vit.py +6 -19
  151. birder/net/transnext.py +19 -19
  152. birder/net/uniformer.py +4 -4
  153. birder/net/van.py +2 -2
  154. birder/net/vgg.py +1 -10
  155. birder/net/vit.py +72 -987
  156. birder/net/vit_parallel.py +35 -20
  157. birder/net/vit_sam.py +23 -48
  158. birder/net/vovnet_v2.py +1 -1
  159. birder/net/xcit.py +16 -13
  160. birder/ops/msda.py +4 -4
  161. birder/ops/swattention.py +10 -10
  162. birder/results/classification.py +3 -3
  163. birder/results/gui.py +8 -8
  164. birder/scripts/benchmark.py +37 -12
  165. birder/scripts/evaluate.py +1 -1
  166. birder/scripts/predict.py +3 -3
  167. birder/scripts/predict_detection.py +2 -2
  168. birder/scripts/train.py +63 -15
  169. birder/scripts/train_barlow_twins.py +10 -7
  170. birder/scripts/train_byol.py +10 -7
  171. birder/scripts/train_capi.py +15 -10
  172. birder/scripts/train_data2vec.py +10 -7
  173. birder/scripts/train_data2vec2.py +10 -7
  174. birder/scripts/train_detection.py +29 -14
  175. birder/scripts/train_dino_v1.py +13 -9
  176. birder/scripts/train_dino_v2.py +27 -14
  177. birder/scripts/train_dino_v2_dist.py +28 -15
  178. birder/scripts/train_franca.py +16 -9
  179. birder/scripts/train_i_jepa.py +12 -9
  180. birder/scripts/train_ibot.py +15 -11
  181. birder/scripts/train_kd.py +64 -17
  182. birder/scripts/train_mim.py +11 -8
  183. birder/scripts/train_mmcr.py +11 -8
  184. birder/scripts/train_rotnet.py +11 -7
  185. birder/scripts/train_simclr.py +10 -7
  186. birder/scripts/train_vicreg.py +10 -7
  187. birder/tools/adversarial.py +4 -4
  188. birder/tools/auto_anchors.py +5 -5
  189. birder/tools/avg_model.py +1 -1
  190. birder/tools/convert_model.py +30 -22
  191. birder/tools/det_results.py +1 -1
  192. birder/tools/download_model.py +1 -1
  193. birder/tools/ensemble_model.py +1 -1
  194. birder/tools/introspection.py +12 -3
  195. birder/tools/labelme_to_coco.py +2 -2
  196. birder/tools/model_info.py +15 -15
  197. birder/tools/pack.py +8 -8
  198. birder/tools/quantize_model.py +53 -4
  199. birder/tools/results.py +2 -2
  200. birder/tools/show_det_iterator.py +19 -6
  201. birder/tools/show_iterator.py +2 -2
  202. birder/tools/similarity.py +5 -5
  203. birder/tools/stats.py +4 -6
  204. birder/tools/voc_to_coco.py +1 -1
  205. birder/version.py +1 -1
  206. {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/METADATA +3 -3
  207. birder-0.4.1.dist-info/RECORD +300 -0
  208. {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/WHEEL +1 -1
  209. birder/net/mobilenet_v3_small.py +0 -43
  210. birder/net/se_resnet_v1.py +0 -105
  211. birder/net/se_resnet_v2.py +0 -59
  212. birder/net/se_resnext.py +0 -30
  213. birder-0.3.3.dist-info/RECORD +0 -299
  214. {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/entry_points.txt +0 -0
  215. {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/licenses/LICENSE +0 -0
  216. {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/top_level.txt +0 -0
@@ -29,7 +29,7 @@ from birder.net.vit import PatchEmbed
29
29
 
30
30
 
31
31
  def img2windows(img: torch.Tensor, h_sp: int, w_sp: int) -> torch.Tensor:
32
- (B, C, H, W) = img.size()
32
+ B, C, H, W = img.size()
33
33
  img_reshape = img.view(B, C, H // h_sp, h_sp, W // w_sp, w_sp)
34
34
  img_perm = img_reshape.permute(0, 2, 4, 3, 5, 1).contiguous().reshape(-1, h_sp * w_sp, C)
35
35
 
@@ -81,7 +81,7 @@ class LePEAttention(nn.Module):
81
81
  raise ValueError("unsupported idx")
82
82
 
83
83
  def im2cswin(self, x: torch.Tensor) -> torch.Tensor:
84
- (B, _, C) = x.size()
84
+ B, _, C = x.size()
85
85
  x = x.transpose(-2, -1).contiguous().view(B, C, self.resolution[0], self.resolution[1])
86
86
  x = img2windows(x, self.h_sp, self.w_sp)
87
87
  x = x.reshape(-1, self.h_sp * self.w_sp, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3).contiguous()
@@ -89,7 +89,7 @@ class LePEAttention(nn.Module):
89
89
  return x
90
90
 
91
91
  def get_lepe(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
92
- (B, _, C) = x.size()
92
+ B, _, C = x.size()
93
93
  H = self.resolution[0]
94
94
  W = self.resolution[1]
95
95
  x = x.transpose(-2, -1).contiguous().view(B, C, H, W)
@@ -107,13 +107,13 @@ class LePEAttention(nn.Module):
107
107
  return (x, lepe)
108
108
 
109
109
  def forward(self, x: torch.Tensor) -> torch.Tensor:
110
- (q, k, v) = x.unbind(0)
110
+ q, k, v = x.unbind(0)
111
111
 
112
- (B, _, C) = q.shape
112
+ B, _, C = q.shape
113
113
 
114
114
  q = self.im2cswin(q)
115
115
  k = self.im2cswin(k)
116
- (v, lepe) = self.get_lepe(v)
116
+ v, lepe = self.get_lepe(v)
117
117
 
118
118
  q = q * self.scale
119
119
  attn = q @ k.transpose(-2, -1) # B head N C @ B head C N --> B head N N
@@ -136,12 +136,12 @@ class MergeBlock(nn.Module):
136
136
  self.resolution = resolution
137
137
 
138
138
  def forward(self, x: torch.Tensor) -> torch.Tensor:
139
- (B, _, C) = x.size()
139
+ B, _, C = x.size()
140
140
  H = self.resolution[0]
141
141
  W = self.resolution[1]
142
142
  x = x.transpose(-2, -1).contiguous().view(B, C, H, W)
143
143
  x = self.conv(x)
144
- (B, C) = x.shape[:2]
144
+ B, C = x.shape[:2]
145
145
  x = x.view(B, C, -1).transpose(-2, -1).contiguous()
146
146
  x = self.norm(x)
147
147
 
@@ -206,7 +206,7 @@ class CSWinBlock(nn.Module):
206
206
  self.drop_path = StochasticDepth(drop_path, mode="row")
207
207
 
208
208
  def forward(self, x: torch.Tensor) -> torch.Tensor:
209
- (B, _, C) = x.shape
209
+ B, _, C = x.shape
210
210
 
211
211
  qkv = self.qkv(self.norm1(x)).reshape(B, -1, 3, C).permute(2, 0, 1, 3)
212
212
  if self.branch_num == 2:
@@ -350,7 +350,7 @@ class CSWin_Transformer(DetectorBackbone):
350
350
  for name, module in self.body.named_children():
351
351
  x = module(x)
352
352
  if name in self.return_stages:
353
- (B, L, C) = x.size()
353
+ B, L, C = x.size()
354
354
  H = int(math.sqrt(L))
355
355
  W = H
356
356
  out[name] = x.transpose(-2, -1).contiguous().view(B, C, H, W)
birder/net/davit.py CHANGED
@@ -31,7 +31,7 @@ from birder.net.base import TokenRetentionResultType
31
31
 
32
32
 
33
33
  def window_partition(x: torch.Tensor, window_size: tuple[int, int]) -> torch.Tensor:
34
- (B, H, W, C) = x.shape
34
+ B, H, W, C = x.shape
35
35
  x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C)
36
36
  windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C)
37
37
 
@@ -64,7 +64,7 @@ class ConvPosEnc(nn.Module):
64
64
  dim,
65
65
  kernel_size=kernel_size,
66
66
  stride=(1, 1),
67
- padding=(kernel_size[0] // 2, kernel_size[1] // 2),
67
+ padding=((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2),
68
68
  groups=dim,
69
69
  )
70
70
  if act is True:
@@ -92,10 +92,10 @@ class Downsample(nn.Module):
92
92
  )
93
93
 
94
94
  def forward(self, x: torch.Tensor) -> torch.Tensor:
95
- (_, _, H, W) = x.shape
95
+ _, _, H, W = x.shape
96
96
  x = self.norm(x)
97
97
  if self.even_k is True:
98
- (k_h, k_w) = self.conv.kernel_size
98
+ k_h, k_w = self.conv.kernel_size
99
99
  pad_r = (k_w - W % k_w) % k_w
100
100
  pad_b = (k_h - H % k_h) % k_h
101
101
  x = F.pad(x, (0, pad_r, 0, pad_b))
@@ -115,10 +115,10 @@ class ChannelAttention(nn.Module):
115
115
  self.proj = nn.Linear(dim, dim)
116
116
 
117
117
  def forward(self, x: torch.Tensor) -> torch.Tensor:
118
- (B, N, C) = x.shape
118
+ B, N, C = x.shape
119
119
 
120
120
  qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
121
- (q, k, v) = qkv.unbind(0)
121
+ q, k, v = qkv.unbind(0)
122
122
 
123
123
  k = k * self.scale
124
124
  attn = k.transpose(-1, -2) @ v
@@ -151,7 +151,7 @@ class ChannelBlock(nn.Module):
151
151
  self.drop_path = StochasticDepth(drop_path, mode="row")
152
152
 
153
153
  def forward(self, x: torch.Tensor) -> torch.Tensor:
154
- (B, C, H, W) = x.shape
154
+ B, C, H, W = x.shape
155
155
  x = self.cpe1(x).flatten(2).transpose(1, 2)
156
156
 
157
157
  cur = self.norm1(x)
@@ -177,10 +177,10 @@ class WindowAttention(nn.Module):
177
177
  self.proj = nn.Linear(dim, dim)
178
178
 
179
179
  def forward(self, x: torch.Tensor) -> torch.Tensor:
180
- (B, N, C) = x.shape
180
+ B, N, C = x.shape
181
181
 
182
182
  qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
183
- (q, k, v) = qkv.unbind(0)
183
+ q, k, v = qkv.unbind(0)
184
184
 
185
185
  x = F.scaled_dot_product_attention(q, k, v, scale=self.scale) # pylint: disable=not-callable
186
186
  x = x.transpose(1, 2).reshape(B, N, C)
@@ -215,7 +215,7 @@ class SpatialBlock(nn.Module):
215
215
 
216
216
  # pylint: disable=invalid-name
217
217
  def forward(self, x: torch.Tensor) -> torch.Tensor:
218
- (B, C, H, W) = x.shape
218
+ B, C, H, W = x.shape
219
219
 
220
220
  shortcut = self.cpe1(x).flatten(2).transpose(1, 2)
221
221
 
@@ -226,7 +226,7 @@ class SpatialBlock(nn.Module):
226
226
  pad_r = (self.window_size[1] - W % self.window_size[1]) % self.window_size[1]
227
227
  pad_b = (self.window_size[0] - H % self.window_size[0]) % self.window_size[0]
228
228
  x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
229
- (_, Hp, Wp, _) = x.shape
229
+ _, Hp, Wp, _ = x.shape
230
230
 
231
231
  x_windows = window_partition(x, self.window_size)
232
232
  x_windows = x_windows.view(-1, self.window_size[0] * self.window_size[1], C)
birder/net/deit.py CHANGED
@@ -16,13 +16,18 @@ import torch
16
16
  from torch import nn
17
17
 
18
18
  from birder.model_registry import registry
19
- from birder.net.base import BaseNet
19
+ from birder.net._vit_configs import BASE
20
+ from birder.net._vit_configs import SMALL
21
+ from birder.net._vit_configs import TINY
22
+ from birder.net.base import DetectorBackbone
23
+ from birder.net.base import normalize_out_indices
20
24
  from birder.net.vit import Encoder
21
25
  from birder.net.vit import PatchEmbed
22
26
  from birder.net.vit import adjust_position_embedding
23
27
 
24
28
 
25
- class DeiT(BaseNet):
29
+ # pylint: disable=too-many-instance-attributes
30
+ class DeiT(DetectorBackbone):
26
31
  block_group_regex = r"encoder\.block\.(\d+)"
27
32
 
28
33
  def __init__(
@@ -44,6 +49,7 @@ class DeiT(BaseNet):
44
49
  num_heads: int = self.config["num_heads"]
45
50
  hidden_dim: int = self.config["hidden_dim"]
46
51
  mlp_dim: int = self.config["mlp_dim"]
52
+ out_indices: Optional[list[int]] = self.config.get("out_indices", None)
47
53
  drop_path_rate: float = self.config["drop_path_rate"]
48
54
 
49
55
  torch._assert(image_size[0] % patch_size == 0, "Input shape indivisible by patch size!")
@@ -53,6 +59,7 @@ class DeiT(BaseNet):
53
59
  self.num_layers = num_layers
54
60
  self.hidden_dim = hidden_dim
55
61
  self.num_special_tokens = 2
62
+ self.out_indices = normalize_out_indices(out_indices, num_layers)
56
63
  dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_layers)] # Stochastic depth decay rule
57
64
 
58
65
  self.conv_proj = nn.Conv2d(
@@ -61,7 +68,6 @@ class DeiT(BaseNet):
61
68
  kernel_size=(patch_size, patch_size),
62
69
  stride=(patch_size, patch_size),
63
70
  padding=(0, 0),
64
- bias=True,
65
71
  )
66
72
  self.patch_embed = PatchEmbed()
67
73
 
@@ -89,11 +95,18 @@ class DeiT(BaseNet):
89
95
  )
90
96
  self.norm = nn.LayerNorm(hidden_dim, eps=1e-6)
91
97
 
98
+ num_return_stages = len(self.out_indices) if self.out_indices is not None else 1
99
+ self.return_stages = [f"stage{stage_idx + 1}" for stage_idx in range(num_return_stages)]
100
+ self.return_channels = [hidden_dim] * num_return_stages
92
101
  self.embedding_size = hidden_dim
93
102
  self.dist_classifier = self.create_classifier()
94
103
  self.classifier = self.create_classifier()
95
104
  self.distillation_output = False
96
105
 
106
+ self.max_stride = patch_size
107
+ self.stem_stride = patch_size
108
+ self.stem_width = hidden_dim
109
+
97
110
  # Weight initialization
98
111
  if isinstance(self.conv_proj, nn.Conv2d):
99
112
  # Init the patchify stem
@@ -129,6 +142,53 @@ class DeiT(BaseNet):
129
142
  def set_causal_attention(self, is_causal: bool = True) -> None:
130
143
  self.encoder.set_causal_attention(is_causal)
131
144
 
145
+ def transform_to_backbone(self) -> None:
146
+ super().transform_to_backbone()
147
+ self.norm = nn.Identity()
148
+ self.dist_classifier = nn.Identity()
149
+
150
+ def detection_features(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
151
+ H, W = x.shape[-2:]
152
+
153
+ # Reshape and permute the input tensor
154
+ x = self.conv_proj(x)
155
+ x = self.patch_embed(x)
156
+
157
+ # Expand the class token to the full batch
158
+ batch_class_token = self.class_token.expand(x.shape[0], -1, -1)
159
+ batch_dist_token = self.dist_token.expand(x.shape[0], -1, -1)
160
+
161
+ x = torch.concat([batch_class_token, batch_dist_token, x], dim=1)
162
+ x = x + self.pos_embedding
163
+
164
+ if self.out_indices is None:
165
+ xs = [self.encoder(x)]
166
+ else:
167
+ xs = self.encoder.forward_features(x, out_indices=self.out_indices)
168
+
169
+ out: dict[str, torch.Tensor] = {}
170
+ for stage_name, stage_x in zip(self.return_stages, xs):
171
+ stage_x = stage_x[:, self.num_special_tokens :]
172
+ stage_x = stage_x.permute(0, 2, 1)
173
+ B, C, _ = stage_x.size()
174
+ stage_x = stage_x.reshape(B, C, H // self.patch_size, W // self.patch_size)
175
+ out[stage_name] = stage_x
176
+
177
+ return out
178
+
179
+ def freeze_stages(self, up_to_stage: int) -> None:
180
+ for param in self.conv_proj.parameters():
181
+ param.requires_grad_(False)
182
+
183
+ self.pos_embedding.requires_grad_(False)
184
+
185
+ for idx, module in enumerate(self.encoder.children()):
186
+ if idx >= up_to_stage:
187
+ break
188
+
189
+ for param in module.parameters():
190
+ param.requires_grad_(False)
191
+
132
192
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
133
193
  # Reshape and permute the input tensor
134
194
  x = self.conv_proj(x)
@@ -200,38 +260,17 @@ class DeiT(BaseNet):
200
260
  registry.register_model_config(
201
261
  "deit_t16",
202
262
  DeiT,
203
- config={
204
- "patch_size": 16,
205
- "num_layers": 12,
206
- "num_heads": 3,
207
- "hidden_dim": 192,
208
- "mlp_dim": 768,
209
- "drop_path_rate": 0.0,
210
- },
263
+ config={"patch_size": 16, **TINY},
211
264
  )
212
265
  registry.register_model_config(
213
266
  "deit_s16",
214
267
  DeiT,
215
- config={
216
- "patch_size": 16,
217
- "num_layers": 12,
218
- "num_heads": 6,
219
- "hidden_dim": 384,
220
- "mlp_dim": 1536,
221
- "drop_path_rate": 0.1,
222
- },
268
+ config={"patch_size": 16, **SMALL, "drop_path_rate": 0.1}, # Override the SMALL definition
223
269
  )
224
270
  registry.register_model_config(
225
271
  "deit_b16",
226
272
  DeiT,
227
- config={
228
- "patch_size": 16,
229
- "num_layers": 12,
230
- "num_heads": 12,
231
- "hidden_dim": 768,
232
- "mlp_dim": 3072,
233
- "drop_path_rate": 0.1,
234
- },
273
+ config={"patch_size": 16, **BASE},
235
274
  )
236
275
 
237
276
  registry.register_weights(
@@ -242,7 +281,7 @@ registry.register_weights(
242
281
  "formats": {
243
282
  "pt": {
244
283
  "file_size": 21.7,
245
- "sha256": "ac124122dec9f1bceff383a6a555ca375ca1b613caf486dac3f29d87afac03b3",
284
+ "sha256": "68b33aba0c1be5e78d4a33e74a7c1ea72b6abb232d59f0048ff9b8342e43246e",
246
285
  }
247
286
  },
248
287
  "net": {"network": "deit_t16", "tag": "il-common"},
@@ -258,7 +297,7 @@ registry.register_weights(
258
297
  "formats": {
259
298
  "pt": {
260
299
  "file_size": 21.7,
261
- "sha256": "fafd0c3c65f9c35318f449f60485f640917736ee7b44056be55c2226909ffdb8",
300
+ "sha256": "f693e89fc350341141c55152bec9f499df63738e8423071f3b8e71801c3e5415",
262
301
  }
263
302
  },
264
303
  "net": {"network": "deit_t16", "tag": "dist-il-common"},