birder 0.3.3__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (216)
  1. birder/adversarial/base.py +1 -1
  2. birder/adversarial/simba.py +4 -4
  3. birder/common/cli.py +1 -1
  4. birder/common/fs_ops.py +13 -13
  5. birder/common/lib.py +2 -2
  6. birder/common/masking.py +3 -3
  7. birder/common/training_cli.py +24 -2
  8. birder/common/training_utils.py +28 -4
  9. birder/data/collators/detection.py +9 -1
  10. birder/data/transforms/detection.py +27 -8
  11. birder/data/transforms/mosaic.py +1 -1
  12. birder/datahub/classification.py +3 -3
  13. birder/inference/classification.py +3 -3
  14. birder/inference/data_parallel.py +1 -1
  15. birder/inference/detection.py +5 -5
  16. birder/inference/wbf.py +1 -1
  17. birder/introspection/attention_rollout.py +6 -6
  18. birder/introspection/feature_pca.py +4 -4
  19. birder/introspection/gradcam.py +1 -1
  20. birder/introspection/guided_backprop.py +2 -2
  21. birder/introspection/transformer_attribution.py +4 -4
  22. birder/layers/attention_pool.py +2 -2
  23. birder/layers/layer_scale.py +1 -1
  24. birder/model_registry/model_registry.py +2 -1
  25. birder/net/__init__.py +4 -10
  26. birder/net/_rope_vit_configs.py +435 -0
  27. birder/net/_vit_configs.py +466 -0
  28. birder/net/alexnet.py +5 -5
  29. birder/net/base.py +28 -3
  30. birder/net/biformer.py +18 -17
  31. birder/net/cait.py +7 -7
  32. birder/net/cas_vit.py +1 -1
  33. birder/net/coat.py +27 -27
  34. birder/net/conv2former.py +3 -3
  35. birder/net/convmixer.py +1 -1
  36. birder/net/convnext_v1.py +3 -11
  37. birder/net/convnext_v1_iso.py +198 -0
  38. birder/net/convnext_v2.py +2 -10
  39. birder/net/crossformer.py +9 -9
  40. birder/net/crossvit.py +6 -6
  41. birder/net/cspnet.py +1 -1
  42. birder/net/cswin_transformer.py +10 -10
  43. birder/net/davit.py +11 -11
  44. birder/net/deit.py +68 -29
  45. birder/net/deit3.py +69 -204
  46. birder/net/densenet.py +9 -8
  47. birder/net/detection/__init__.py +4 -0
  48. birder/net/detection/{yolo_anchors.py → _yolo_anchors.py} +5 -5
  49. birder/net/detection/base.py +6 -5
  50. birder/net/detection/deformable_detr.py +31 -30
  51. birder/net/detection/detr.py +14 -11
  52. birder/net/detection/efficientdet.py +10 -29
  53. birder/net/detection/faster_rcnn.py +22 -22
  54. birder/net/detection/fcos.py +8 -8
  55. birder/net/detection/plain_detr.py +852 -0
  56. birder/net/detection/retinanet.py +4 -4
  57. birder/net/detection/rt_detr_v1.py +81 -25
  58. birder/net/detection/rt_detr_v2.py +1147 -0
  59. birder/net/detection/ssd.py +5 -5
  60. birder/net/detection/yolo_v2.py +12 -12
  61. birder/net/detection/yolo_v3.py +19 -19
  62. birder/net/detection/yolo_v4.py +16 -16
  63. birder/net/detection/yolo_v4_tiny.py +3 -3
  64. birder/net/dpn.py +1 -2
  65. birder/net/edgenext.py +5 -4
  66. birder/net/edgevit.py +13 -14
  67. birder/net/efficientformer_v1.py +3 -2
  68. birder/net/efficientformer_v2.py +18 -31
  69. birder/net/efficientnet_v2.py +3 -0
  70. birder/net/efficientvim.py +9 -9
  71. birder/net/efficientvit_mit.py +7 -7
  72. birder/net/efficientvit_msft.py +3 -3
  73. birder/net/fasternet.py +3 -3
  74. birder/net/fastvit.py +5 -12
  75. birder/net/flexivit.py +50 -58
  76. birder/net/focalnet.py +5 -9
  77. birder/net/gc_vit.py +11 -11
  78. birder/net/ghostnet_v1.py +1 -1
  79. birder/net/ghostnet_v2.py +1 -1
  80. birder/net/groupmixformer.py +13 -13
  81. birder/net/hgnet_v1.py +6 -6
  82. birder/net/hgnet_v2.py +4 -4
  83. birder/net/hiera.py +6 -6
  84. birder/net/hieradet.py +9 -9
  85. birder/net/hornet.py +3 -3
  86. birder/net/iformer.py +4 -4
  87. birder/net/inception_next.py +5 -15
  88. birder/net/inception_resnet_v1.py +3 -3
  89. birder/net/inception_resnet_v2.py +7 -4
  90. birder/net/inception_v3.py +3 -0
  91. birder/net/inception_v4.py +3 -0
  92. birder/net/levit.py +3 -3
  93. birder/net/lit_v1.py +13 -15
  94. birder/net/lit_v1_tiny.py +9 -9
  95. birder/net/lit_v2.py +14 -15
  96. birder/net/maxvit.py +11 -23
  97. birder/net/metaformer.py +5 -5
  98. birder/net/mim/crossmae.py +6 -6
  99. birder/net/mim/fcmae.py +3 -5
  100. birder/net/mim/mae_hiera.py +7 -7
  101. birder/net/mim/mae_vit.py +4 -6
  102. birder/net/mim/simmim.py +3 -4
  103. birder/net/mobilenet_v1.py +0 -9
  104. birder/net/mobilenet_v2.py +38 -44
  105. birder/net/{mobilenet_v3_large.py → mobilenet_v3.py} +37 -10
  106. birder/net/mobilenet_v4_hybrid.py +4 -4
  107. birder/net/mobileone.py +5 -12
  108. birder/net/mobilevit_v1.py +7 -34
  109. birder/net/mobilevit_v2.py +6 -54
  110. birder/net/moganet.py +8 -5
  111. birder/net/mvit_v2.py +30 -30
  112. birder/net/nextvit.py +2 -2
  113. birder/net/nfnet.py +4 -0
  114. birder/net/pit.py +11 -26
  115. birder/net/pvt_v1.py +9 -9
  116. birder/net/pvt_v2.py +10 -16
  117. birder/net/regionvit.py +15 -15
  118. birder/net/regnet.py +1 -1
  119. birder/net/repghost.py +5 -35
  120. birder/net/repvgg.py +3 -5
  121. birder/net/repvit.py +2 -2
  122. birder/net/resmlp.py +2 -2
  123. birder/net/resnest.py +4 -1
  124. birder/net/resnet_v1.py +125 -1
  125. birder/net/resnet_v2.py +75 -1
  126. birder/net/resnext.py +35 -1
  127. birder/net/rope_deit3.py +62 -151
  128. birder/net/rope_flexivit.py +46 -33
  129. birder/net/rope_vit.py +44 -758
  130. birder/net/sequencer2d.py +3 -4
  131. birder/net/shufflenet_v1.py +1 -1
  132. birder/net/shufflenet_v2.py +1 -1
  133. birder/net/simple_vit.py +69 -21
  134. birder/net/smt.py +8 -8
  135. birder/net/squeezenet.py +5 -12
  136. birder/net/squeezenext.py +0 -24
  137. birder/net/ssl/barlow_twins.py +1 -1
  138. birder/net/ssl/byol.py +2 -2
  139. birder/net/ssl/capi.py +4 -4
  140. birder/net/ssl/data2vec.py +1 -1
  141. birder/net/ssl/data2vec2.py +1 -1
  142. birder/net/ssl/dino_v2.py +13 -3
  143. birder/net/ssl/franca.py +28 -4
  144. birder/net/ssl/i_jepa.py +5 -5
  145. birder/net/ssl/ibot.py +1 -1
  146. birder/net/ssl/mmcr.py +1 -1
  147. birder/net/swiftformer.py +13 -3
  148. birder/net/swin_transformer_v1.py +4 -5
  149. birder/net/swin_transformer_v2.py +5 -8
  150. birder/net/tiny_vit.py +6 -19
  151. birder/net/transnext.py +19 -19
  152. birder/net/uniformer.py +4 -4
  153. birder/net/van.py +2 -2
  154. birder/net/vgg.py +1 -10
  155. birder/net/vit.py +72 -987
  156. birder/net/vit_parallel.py +35 -20
  157. birder/net/vit_sam.py +23 -48
  158. birder/net/vovnet_v2.py +1 -1
  159. birder/net/xcit.py +16 -13
  160. birder/ops/msda.py +4 -4
  161. birder/ops/swattention.py +10 -10
  162. birder/results/classification.py +3 -3
  163. birder/results/gui.py +8 -8
  164. birder/scripts/benchmark.py +37 -12
  165. birder/scripts/evaluate.py +1 -1
  166. birder/scripts/predict.py +3 -3
  167. birder/scripts/predict_detection.py +2 -2
  168. birder/scripts/train.py +63 -15
  169. birder/scripts/train_barlow_twins.py +10 -7
  170. birder/scripts/train_byol.py +10 -7
  171. birder/scripts/train_capi.py +15 -10
  172. birder/scripts/train_data2vec.py +10 -7
  173. birder/scripts/train_data2vec2.py +10 -7
  174. birder/scripts/train_detection.py +29 -14
  175. birder/scripts/train_dino_v1.py +13 -9
  176. birder/scripts/train_dino_v2.py +27 -14
  177. birder/scripts/train_dino_v2_dist.py +28 -15
  178. birder/scripts/train_franca.py +16 -9
  179. birder/scripts/train_i_jepa.py +12 -9
  180. birder/scripts/train_ibot.py +15 -11
  181. birder/scripts/train_kd.py +64 -17
  182. birder/scripts/train_mim.py +11 -8
  183. birder/scripts/train_mmcr.py +11 -8
  184. birder/scripts/train_rotnet.py +11 -7
  185. birder/scripts/train_simclr.py +10 -7
  186. birder/scripts/train_vicreg.py +10 -7
  187. birder/tools/adversarial.py +4 -4
  188. birder/tools/auto_anchors.py +5 -5
  189. birder/tools/avg_model.py +1 -1
  190. birder/tools/convert_model.py +30 -22
  191. birder/tools/det_results.py +1 -1
  192. birder/tools/download_model.py +1 -1
  193. birder/tools/ensemble_model.py +1 -1
  194. birder/tools/introspection.py +12 -3
  195. birder/tools/labelme_to_coco.py +2 -2
  196. birder/tools/model_info.py +15 -15
  197. birder/tools/pack.py +8 -8
  198. birder/tools/quantize_model.py +53 -4
  199. birder/tools/results.py +2 -2
  200. birder/tools/show_det_iterator.py +19 -6
  201. birder/tools/show_iterator.py +2 -2
  202. birder/tools/similarity.py +5 -5
  203. birder/tools/stats.py +4 -6
  204. birder/tools/voc_to_coco.py +1 -1
  205. birder/version.py +1 -1
  206. {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/METADATA +3 -3
  207. birder-0.4.1.dist-info/RECORD +300 -0
  208. {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/WHEEL +1 -1
  209. birder/net/mobilenet_v3_small.py +0 -43
  210. birder/net/se_resnet_v1.py +0 -105
  211. birder/net/se_resnet_v2.py +0 -59
  212. birder/net/se_resnext.py +0 -30
  213. birder-0.3.3.dist-info/RECORD +0 -299
  214. {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/entry_points.txt +0 -0
  215. {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/licenses/LICENSE +0 -0
  216. {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/top_level.txt +0 -0
birder/net/cait.py CHANGED
@@ -47,7 +47,7 @@ class ClassAttention(nn.Module):
47
47
  self.proj_drop = nn.Dropout(proj_drop)
48
48
 
49
49
  def forward(self, x: torch.Tensor) -> torch.Tensor:
50
- (B, N, C) = x.shape
50
+ B, N, C = x.shape
51
51
  q = self.q(x[:, 0]).unsqueeze(1).reshape(B, 1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
52
52
  k = self.k(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
53
53
  v = self.v(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
@@ -66,12 +66,12 @@ class ClassAttentionBlock(nn.Module):
66
66
  self, dim: int, num_heads: int, mlp_ratio: float, qkv_bias: bool, proj_drop: float, drop_path: float, eta: float
67
67
  ) -> None:
68
68
  super().__init__()
69
- self.norm1 = nn.LayerNorm(dim)
69
+ self.norm1 = nn.LayerNorm(dim, eps=1e-6)
70
70
 
71
71
  self.attn = ClassAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias, proj_drop=proj_drop)
72
72
 
73
73
  self.drop_path = StochasticDepth(drop_path, mode="row")
74
- self.norm2 = nn.LayerNorm(dim)
74
+ self.norm2 = nn.LayerNorm(dim, eps=1e-6)
75
75
  self.mlp = MLP(dim, [int(dim * mlp_ratio), dim], activation_layer=nn.GELU, dropout=proj_drop)
76
76
 
77
77
  self.gamma1 = nn.Parameter(eta * torch.ones(dim))
@@ -103,7 +103,7 @@ class TalkingHeadAttn(nn.Module):
103
103
  self.proj_drop = nn.Dropout(proj_drop)
104
104
 
105
105
  def forward(self, x: torch.Tensor) -> torch.Tensor:
106
- (B, N, C) = x.shape
106
+ B, N, C = x.shape
107
107
  qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
108
108
  q = qkv[0] * self.scale
109
109
  k = qkv[1]
@@ -135,7 +135,7 @@ class LayerScaleBlock(nn.Module):
135
135
  init_values: float,
136
136
  ) -> None:
137
137
  super().__init__()
138
- self.norm1 = nn.LayerNorm(dim)
138
+ self.norm1 = nn.LayerNorm(dim, eps=1e-6)
139
139
  self.attn = TalkingHeadAttn(
140
140
  dim,
141
141
  num_heads=num_heads,
@@ -144,7 +144,7 @@ class LayerScaleBlock(nn.Module):
144
144
  proj_drop=proj_drop,
145
145
  )
146
146
  self.drop_path = StochasticDepth(drop_path, mode="row")
147
- self.norm2 = nn.LayerNorm(dim)
147
+ self.norm2 = nn.LayerNorm(dim, eps=1e-6)
148
148
  self.mlp = MLP(dim, [int(dim * mlp_ratio), dim], activation_layer=nn.GELU, dropout=proj_drop)
149
149
  self.gamma_1 = nn.Parameter(init_values * torch.ones(dim))
150
150
  self.gamma_2 = nn.Parameter(init_values * torch.ones(dim))
@@ -221,7 +221,7 @@ class CaiT(BaseNet):
221
221
  )
222
222
  )
223
223
 
224
- self.norm = nn.LayerNorm(embed_dim)
224
+ self.norm = nn.LayerNorm(embed_dim, eps=1e-6)
225
225
 
226
226
  self.embedding_size = embed_dim
227
227
  self.classifier = self.create_classifier()
birder/net/cas_vit.py CHANGED
@@ -122,7 +122,7 @@ class AdditiveTokenMixer(nn.Module):
122
122
  self.proj_drop = nn.Dropout(proj_drop)
123
123
 
124
124
  def forward(self, x: torch.Tensor) -> torch.Tensor:
125
- (q, k, v) = self.qkv(x).chunk(3, dim=1)
125
+ q, k, v = self.qkv(x).chunk(3, dim=1)
126
126
  q = self.op_q(q)
127
127
  k = self.op_k(k)
128
128
 
birder/net/coat.py CHANGED
@@ -21,7 +21,7 @@ from birder.net.base import DetectorBackbone
21
21
 
22
22
 
23
23
  def insert_cls(x: torch.Tensor, cls_token: torch.Tensor) -> torch.Tensor:
24
- cls_tokens = cls_token.expand(x.shape[0], -1, -1)
24
+ cls_tokens = cls_token.expand(x.size(0), -1, -1)
25
25
  x = torch.concat((cls_tokens, x), dim=1)
26
26
 
27
27
  return x
@@ -57,8 +57,8 @@ class ConvRelPosEnc(nn.Module):
57
57
  self.channel_splits = [x * head_channels for x in head_splits]
58
58
 
59
59
  def forward(self, q: torch.Tensor, v: torch.Tensor, size: tuple[int, int]) -> torch.Tensor:
60
- (B, num_heads, N, C) = q.size()
61
- (H, W) = size
60
+ B, num_heads, N, C = q.size()
61
+ H, W = size
62
62
  torch._assert(N == 1 + H * W, "size mismatch") # pylint: disable=protected-access
63
63
 
64
64
  # Convolutional relative position encoding.
@@ -102,11 +102,11 @@ class FactorAttnConvRelPosEnc(nn.Module):
102
102
  self.crpe = shared_crpe
103
103
 
104
104
  def forward(self, x: torch.Tensor, size: tuple[int, int]) -> torch.Tensor:
105
- (B, N, C) = x.size()
105
+ B, N, C = x.size()
106
106
 
107
107
  # Generate Q, K, V
108
108
  qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
109
- (q, k, v) = qkv.unbind(0) # [B, h, N, Ch]
109
+ q, k, v = qkv.unbind(0) # [B, h, N, Ch]
110
110
 
111
111
  # Factorized attention
112
112
  k_softmax = k.softmax(dim=2)
@@ -135,8 +135,8 @@ class ConvPosEnc(nn.Module):
135
135
  )
136
136
 
137
137
  def forward(self, x: torch.Tensor, size: tuple[int, int]) -> torch.Tensor:
138
- (B, N, C) = x.size()
139
- (H, W) = size
138
+ B, N, C = x.size()
139
+ H, W = size
140
140
  torch._assert(N == 1 + H * W, "size mismatch") # pylint: disable=protected-access
141
141
 
142
142
  # Extract CLS token and image tokens
@@ -170,7 +170,7 @@ class SerialBlock(nn.Module):
170
170
 
171
171
  # Conv-attention
172
172
  self.cpe = shared_cpe
173
- self.norm1 = nn.LayerNorm(dim)
173
+ self.norm1 = nn.LayerNorm(dim, eps=1e-6)
174
174
  self.factor_attn_crpe = FactorAttnConvRelPosEnc(
175
175
  dim,
176
176
  num_heads=num_heads,
@@ -181,7 +181,7 @@ class SerialBlock(nn.Module):
181
181
  self.drop_path = StochasticDepth(drop_path, mode="row")
182
182
 
183
183
  # MLP
184
- self.norm2 = nn.LayerNorm(dim)
184
+ self.norm2 = nn.LayerNorm(dim, eps=1e-6)
185
185
  self.mlp = MLP(dim, [int(dim * mlp_ratio), dim], activation_layer=nn.GELU, dropout=proj_drop)
186
186
 
187
187
  def forward(self, x: torch.Tensor, size: tuple[int, int]) -> torch.Tensor:
@@ -213,9 +213,9 @@ class ParallelBlock(nn.Module):
213
213
  super().__init__()
214
214
 
215
215
  # Conv-attention
216
- self.norm12 = nn.LayerNorm(dims[1])
217
- self.norm13 = nn.LayerNorm(dims[2])
218
- self.norm14 = nn.LayerNorm(dims[3])
216
+ self.norm12 = nn.LayerNorm(dims[1], eps=1e-6)
217
+ self.norm13 = nn.LayerNorm(dims[2], eps=1e-6)
218
+ self.norm14 = nn.LayerNorm(dims[3], eps=1e-6)
219
219
  self.factor_attn_crpe2 = FactorAttnConvRelPosEnc(
220
220
  dims[1], num_heads=num_heads, qkv_bias=qkv_bias, proj_drop=proj_drop, shared_crpe=shared_crpes[1]
221
221
  )
@@ -228,9 +228,9 @@ class ParallelBlock(nn.Module):
228
228
  self.drop_path = StochasticDepth(drop_path, mode="row")
229
229
 
230
230
  # MLP
231
- self.norm22 = nn.LayerNorm(dims[1])
232
- self.norm23 = nn.LayerNorm(dims[2])
233
- self.norm24 = nn.LayerNorm(dims[3])
231
+ self.norm22 = nn.LayerNorm(dims[1], eps=1e-6)
232
+ self.norm23 = nn.LayerNorm(dims[2], eps=1e-6)
233
+ self.norm24 = nn.LayerNorm(dims[3], eps=1e-6)
234
234
 
235
235
  # In the parallel block, we assume dimensions are the same and share the linear transformation
236
236
  assert dims[1] == dims[2] == dims[3]
@@ -244,8 +244,8 @@ class ParallelBlock(nn.Module):
244
244
  return self.interpolate(x, scale_factor=1.0 / factor, size=size)
245
245
 
246
246
  def interpolate(self, x: torch.Tensor, scale_factor: float, size: tuple[int, int]) -> torch.Tensor:
247
- (B, N, C) = x.size()
248
- (H, W) = size
247
+ B, N, C = x.size()
248
+ H, W = size
249
249
  torch._assert(N == 1 + H * W, "size mismatch") # pylint: disable=protected-access
250
250
 
251
251
  cls_token = x[:, :1, :]
@@ -268,7 +268,7 @@ class ParallelBlock(nn.Module):
268
268
  def forward(
269
269
  self, x1: torch.Tensor, x2: torch.Tensor, x3: torch.Tensor, x4: torch.Tensor, sizes: list[tuple[int, int]]
270
270
  ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
271
- (_, s2, s3, s4) = sizes
271
+ _, s2, s3, s4 = sizes
272
272
  cur2 = self.norm12(x2)
273
273
  cur3 = self.norm13(x3)
274
274
  cur4 = self.norm14(x4)
@@ -310,7 +310,7 @@ class PatchEmbed(nn.Module):
310
310
 
311
311
  def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, tuple[int, int]]:
312
312
  x = self.proj(x)
313
- (H, W) = x.shape[2:4]
313
+ H, W = x.shape[2:4]
314
314
 
315
315
  x = x.flatten(2).transpose(1, 2)
316
316
  x = self.norm(x)
@@ -447,13 +447,13 @@ class CoaT(DetectorBackbone):
447
447
 
448
448
  # Norms
449
449
  if self.parallel_blocks is not None:
450
- self.norm2 = nn.LayerNorm(embed_dims[1])
451
- self.norm3 = nn.LayerNorm(embed_dims[2])
450
+ self.norm2 = nn.LayerNorm(embed_dims[1], eps=1e-6)
451
+ self.norm3 = nn.LayerNorm(embed_dims[2], eps=1e-6)
452
452
  else:
453
453
  self.norm2 = None
454
454
  self.norm3 = None
455
455
 
456
- self.norm4 = nn.LayerNorm(embed_dims[3])
456
+ self.norm4 = nn.LayerNorm(embed_dims[3], eps=1e-6)
457
457
 
458
458
  # Head
459
459
  if parallel_depth > 0:
@@ -500,7 +500,7 @@ class CoaT(DetectorBackbone):
500
500
  B = x.shape[0]
501
501
 
502
502
  # Serial blocks 1
503
- (x1, (h1, w1)) = self.patch_embed1(x)
503
+ x1, (h1, w1) = self.patch_embed1(x)
504
504
  x1 = insert_cls(x1, self.cls_token1)
505
505
  for blk in self.serial_blocks1:
506
506
  x1 = blk(x1, size=(h1, w1))
@@ -508,7 +508,7 @@ class CoaT(DetectorBackbone):
508
508
  x1_no_cls = remove_cls(x1).reshape(B, h1, w1, -1).permute(0, 3, 1, 2).contiguous()
509
509
 
510
510
  # Serial blocks 2
511
- (x2, (h2, w2)) = self.patch_embed2(x1_no_cls)
511
+ x2, (h2, w2) = self.patch_embed2(x1_no_cls)
512
512
  x2 = insert_cls(x2, self.cls_token2)
513
513
  for blk in self.serial_blocks2:
514
514
  x2 = blk(x2, size=(h2, w2))
@@ -516,7 +516,7 @@ class CoaT(DetectorBackbone):
516
516
  x2_no_cls = remove_cls(x2).reshape(B, h2, w2, -1).permute(0, 3, 1, 2).contiguous()
517
517
 
518
518
  # Serial blocks 3
519
- (x3, (h3, w3)) = self.patch_embed3(x2_no_cls)
519
+ x3, (h3, w3) = self.patch_embed3(x2_no_cls)
520
520
  x3 = insert_cls(x3, self.cls_token3)
521
521
  for blk in self.serial_blocks3:
522
522
  x3 = blk(x3, size=(h3, w3))
@@ -524,7 +524,7 @@ class CoaT(DetectorBackbone):
524
524
  x3_no_cls = remove_cls(x3).reshape(B, h3, w3, -1).permute(0, 3, 1, 2).contiguous()
525
525
 
526
526
  # Serial blocks 4
527
- (x4, (h4, w4)) = self.patch_embed4(x3_no_cls)
527
+ x4, (h4, w4) = self.patch_embed4(x3_no_cls)
528
528
  x4 = insert_cls(x4, self.cls_token4)
529
529
  for blk in self.serial_blocks4:
530
530
  x4 = blk(x4, size=(h4, w4))
@@ -537,7 +537,7 @@ class CoaT(DetectorBackbone):
537
537
  x2 = self.cpe2(x2, (h2, w2))
538
538
  x3 = self.cpe3(x3, (h3, w3))
539
539
  x4 = self.cpe4(x4, (h4, w4))
540
- (x1, x2, x3, x4) = blk(x1, x2, x3, x4, sizes=[(h1, w1), (h2, w2), (h3, w3), (h4, w4)])
540
+ x1, x2, x3, x4 = blk(x1, x2, x3, x4, sizes=[(h1, w1), (h2, w2), (h3, w3), (h4, w4)])
541
541
 
542
542
  x1_no_cls = remove_cls(x1).reshape(B, h1, w1, -1).permute(0, 3, 1, 2).contiguous()
543
543
  x2_no_cls = remove_cls(x2).reshape(B, h2, w2, -1).permute(0, 3, 1, 2).contiguous()
birder/net/conv2former.py CHANGED
@@ -64,7 +64,7 @@ class SpatialAttention(nn.Module):
64
64
  dim,
65
65
  kernel_size=kernel_size,
66
66
  stride=(1, 1),
67
- padding=(kernel_size[0] // 2, kernel_size[1] // 2),
67
+ padding=((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2),
68
68
  groups=dim,
69
69
  ),
70
70
  )
@@ -87,8 +87,8 @@ class Conv2FormerBlock(nn.Module):
87
87
  self.mlp = MLP(dim, mlp_ratio)
88
88
 
89
89
  layer_scale_init_value = 1e-6
90
- self.layer_scale_1 = nn.Parameter(layer_scale_init_value * torch.ones((1, dim, 1, 1)), requires_grad=True)
91
- self.layer_scale_2 = nn.Parameter(layer_scale_init_value * torch.ones((1, dim, 1, 1)), requires_grad=True)
90
+ self.layer_scale_1 = nn.Parameter(layer_scale_init_value * torch.ones((1, dim, 1, 1)))
91
+ self.layer_scale_2 = nn.Parameter(layer_scale_init_value * torch.ones((1, dim, 1, 1)))
92
92
 
93
93
  def forward(self, x: torch.Tensor) -> torch.Tensor:
94
94
  x = x + self.drop_path(self.layer_scale_1 * self.attn(x))
birder/net/convmixer.py CHANGED
@@ -58,7 +58,7 @@ class ConvMixer(BaseNet):
58
58
  inplace=None,
59
59
  )
60
60
 
61
- padding = (kernel_size[0] // 2, kernel_size[1] // 2)
61
+ padding = ((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2)
62
62
  self.body = nn.Sequential(
63
63
  *[
64
64
  nn.Sequential(
birder/net/convnext_v1.py CHANGED
@@ -37,15 +37,7 @@ class ConvNeXtBlock(nn.Module):
37
37
  ) -> None:
38
38
  super().__init__()
39
39
  self.block = nn.Sequential(
40
- nn.Conv2d(
41
- channels,
42
- channels,
43
- kernel_size=(7, 7),
44
- stride=(1, 1),
45
- padding=(3, 3),
46
- groups=channels,
47
- bias=True,
48
- ),
40
+ nn.Conv2d(channels, channels, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=channels),
49
41
  Permute([0, 2, 3, 1]),
50
42
  nn.LayerNorm(channels, eps=1e-6),
51
43
  nn.Linear(channels, 4 * channels), # Same as 1x1 conv
@@ -53,7 +45,7 @@ class ConvNeXtBlock(nn.Module):
53
45
  nn.Linear(4 * channels, channels), # Same as 1x1 conv
54
46
  Permute([0, 3, 1, 2]),
55
47
  )
56
- self.layer_scale = nn.Parameter(torch.ones(channels, 1, 1) * layer_scale, requires_grad=True)
48
+ self.layer_scale = nn.Parameter(torch.ones(channels, 1, 1) * layer_scale)
57
49
  self.stochastic_depth = StochasticDepth(stochastic_depth_prob, mode="row")
58
50
 
59
51
  def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -119,7 +111,7 @@ class ConvNeXt_v1(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
119
111
  layers.append(
120
112
  nn.Sequential(
121
113
  LayerNorm2d(i, eps=1e-6),
122
- nn.Conv2d(i, out, kernel_size=(2, 2), stride=(2, 2), padding=(0, 0), bias=True),
114
+ nn.Conv2d(i, out, kernel_size=(2, 2), stride=(2, 2), padding=(0, 0)),
123
115
  )
124
116
  )
125
117
 
@@ -0,0 +1,198 @@
1
+ """
2
+ ConvNeXt v1 Isotropic, adapted from
3
+ https://github.com/facebookresearch/ConvNeXt/blob/main/models/convnext_isotropic.py
4
+
5
+ Paper "A ConvNet for the 2020s", https://arxiv.org/abs/2201.03545
6
+ """
7
+
8
+ # Reference license: MIT
9
+
10
+ from functools import partial
11
+ from typing import Any
12
+ from typing import Literal
13
+ from typing import Optional
14
+
15
+ import torch
16
+ from torch import nn
17
+ from torchvision.ops import Permute
18
+ from torchvision.ops import StochasticDepth
19
+
20
+ from birder.common.masking import mask_tensor
21
+ from birder.layers import LayerNorm2d
22
+ from birder.model_registry import registry
23
+ from birder.net.base import DetectorBackbone
24
+ from birder.net.base import MaskedTokenRetentionMixin
25
+ from birder.net.base import PreTrainEncoder
26
+ from birder.net.base import TokenRetentionResultType
27
+ from birder.net.base import normalize_out_indices
28
+
29
+
30
class ConvNeXtBlock(nn.Module):
    """Isotropic ConvNeXt block: 7x7 depthwise conv, channel-last LayerNorm,
    4x inverted MLP, residual connection with per-sample stochastic depth.
    """

    def __init__(self, channels: int, stochastic_depth_prob: float) -> None:
        super().__init__()
        stages = [
            nn.Conv2d(channels, channels, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=channels),
            Permute([0, 2, 3, 1]),  # NCHW -> NHWC so LayerNorm/Linear act on the channel dim
            nn.LayerNorm(channels, eps=1e-6),
            nn.Linear(channels, 4 * channels),  # pointwise expansion, same as a 1x1 conv
            nn.GELU(),
            nn.Linear(4 * channels, channels),  # pointwise projection back
            Permute([0, 3, 1, 2]),  # NHWC -> NCHW
        ]
        self.block = nn.Sequential(*stages)
        self.stochastic_depth = StochasticDepth(stochastic_depth_prob, mode="row")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        out = self.stochastic_depth(self.block(x))
        out += residual

        return out
51
+
52
+
53
# pylint: disable=invalid-name
class ConvNeXt_v1_Isotropic(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
    """Isotropic ConvNeXt v1: a ViT-style layout with a single 16x16 patchify stem
    followed by a stack of same-width ConvNeXt blocks (no downsampling stages).

    Required config keys: "dim", "num_layers", "drop_path_rate".
    Optional: "out_indices" — block indices to expose as detection feature stages.
    """

    # Regex grouping body blocks by index — presumably consumed by training
    # utilities for per-block parameter grouping; TODO confirm against base class
    block_group_regex = r"body\.(\d+)"

    def __init__(
        self,
        input_channels: int,
        num_classes: int,
        *,
        config: Optional[dict[str, Any]] = None,
        size: Optional[tuple[int, int]] = None,
    ) -> None:
        super().__init__(input_channels, num_classes, config=config, size=size)
        assert self.config is not None, "must set config"

        patch_size = 16
        dim: int = self.config["dim"]
        num_layers: int = self.config["num_layers"]
        out_indices: Optional[list[int]] = self.config.get("out_indices", None)
        drop_path_rate: float = self.config["drop_path_rate"]

        # Input must tile exactly into patch_size x patch_size patches
        torch._assert(self.size[0] % patch_size == 0, "Input shape indivisible by patch size!")
        torch._assert(self.size[1] % patch_size == 0, "Input shape indivisible by patch size!")
        self.patch_size = patch_size
        self.out_indices = normalize_out_indices(out_indices, num_layers)

        # Non-overlapping patchify stem (kernel == stride == patch size)
        self.stem = nn.Conv2d(
            self.input_channels,
            dim,
            kernel_size=(patch_size, patch_size),
            stride=(patch_size, patch_size),
            padding=(0, 0),
        )

        layers = []
        for idx in range(num_layers):
            # Adjust stochastic depth probability based on the depth of the stage block
            # (linear ramp from 0 at the first block to drop_path_rate at the last)
            sd_prob = drop_path_rate * idx / (num_layers - 1.0)
            layers.append(ConvNeXtBlock(dim, sd_prob))

        self.body = nn.Sequential(*layers)
        # Classification head input: global average pool -> norm -> flatten
        self.features = nn.Sequential(
            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            LayerNorm2d(dim, eps=1e-6),
            nn.Flatten(1),
        )

        # Isotropic: every returned stage has the same channel width (dim)
        num_return_stages = len(self.out_indices) if self.out_indices is not None else 1
        self.return_stages = [f"stage{stage_idx + 1}" for stage_idx in range(num_return_stages)]
        self.return_channels = [dim] * num_return_stages
        self.embedding_size = dim
        self.classifier = self.create_classifier()

        # All strides equal the patch size — the network never changes resolution
        self.max_stride = patch_size
        self.stem_stride = patch_size
        self.stem_width = dim
        self.encoding_size = dim
        # Decoder blocks for MIM pre-training reuse the same block without drop-path
        self.decoder_block = partial(ConvNeXtBlock, stochastic_depth_prob=0)

        # Weights initialization
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def detection_features(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
        """Return intermediate feature maps keyed by stage name.

        With no out_indices configured, runs the full body and returns a single
        stage; otherwise collects outputs at each configured block index.
        """
        x = self.stem(x)

        if self.out_indices is None:
            x = self.body(x)
            return {self.return_stages[0]: x}

        stage_num = 0
        out: dict[str, torch.Tensor] = {}
        for idx, module in enumerate(self.body.children()):
            x = module(x)
            if idx in self.out_indices:
                out[self.return_stages[stage_num]] = x
                stage_num += 1

        return out

    def freeze_stages(self, up_to_stage: int) -> None:
        """Freeze the stem plus the first ``up_to_stage`` body blocks."""
        for param in self.stem.parameters():
            param.requires_grad_(False)

        for idx, module in enumerate(self.body.children()):
            if idx >= up_to_stage:
                break

            for param in module.parameters():
                param.requires_grad_(False)

    def masked_encoding_retention(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
        mask_token: Optional[torch.Tensor] = None,
        return_keys: Literal["all", "features", "embedding"] = "features",
    ) -> TokenRetentionResultType:
        """Encode ``x`` with patch positions masked out after the stem.

        NOTE(review): masking happens post-stem, so patch_factor is
        max_stride // stem_stride == 1 here — every stem patch maps to one
        mask cell; mask layout is assumed to match ``mask_tensor`` — TODO confirm.
        """
        x = self.stem(x)
        x = mask_tensor(x, mask, patch_factor=self.max_stride // self.stem_stride, mask_token=mask_token)
        x = self.body(x)

        result: TokenRetentionResultType = {}
        if return_keys in ("all", "features"):
            result["features"] = x
        if return_keys in ("all", "embedding"):
            result["embedding"] = self.features(x)

        return result

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        """Run stem and body, returning the final (pre-pool) feature map."""
        x = self.stem(x)
        return self.body(x)

    def embedding(self, x: torch.Tensor) -> torch.Tensor:
        """Return the pooled embedding vector used by the classifier."""
        x = self.forward_features(x)
        return self.features(x)

    def adjust_size(self, new_size: tuple[int, int]) -> None:
        """Change the expected input size; it must stay divisible by the patch size."""
        if new_size == self.size:
            return

        assert new_size[0] % self.patch_size == 0, "Input shape indivisible by patch size!"
        assert new_size[1] % self.patch_size == 0, "Input shape indivisible by patch size!"

        super().adjust_size(new_size)
182
+
183
+
184
# Model configurations (isotropic ConvNeXt: small 384x18, base 768x18, large 1024x36).
# All three must supply the "dim" key — ConvNeXt_v1_Isotropic.__init__ reads
# self.config["dim"] and would raise KeyError otherwise.
registry.register_model_config(
    "convnext_v1_iso_small",
    ConvNeXt_v1_Isotropic,
    config={"dim": 384, "num_layers": 18, "drop_path_rate": 0.1},
)
registry.register_model_config(
    "convnext_v1_iso_base",
    ConvNeXt_v1_Isotropic,
    # BUG FIX: key was "in_channels", which the model never reads — it requires "dim"
    config={"dim": 768, "num_layers": 18, "drop_path_rate": 0.2},
)
registry.register_model_config(
    "convnext_v1_iso_large",
    ConvNeXt_v1_Isotropic,
    # BUG FIX: key was "in_channels", which the model never reads — it requires "dim"
    config={"dim": 1024, "num_layers": 36, "drop_path_rate": 0.5},
)
birder/net/convnext_v2.py CHANGED
@@ -56,15 +56,7 @@ class ConvNeXtBlock(nn.Module):
56
56
  ) -> None:
57
57
  super().__init__()
58
58
  self.block = nn.Sequential(
59
- nn.Conv2d(
60
- channels,
61
- channels,
62
- kernel_size=(7, 7),
63
- stride=(1, 1),
64
- padding=(3, 3),
65
- groups=channels,
66
- bias=True,
67
- ),
59
+ nn.Conv2d(channels, channels, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=channels),
68
60
  Permute([0, 2, 3, 1]),
69
61
  nn.LayerNorm(channels, eps=1e-6),
70
62
  nn.Linear(channels, 4 * channels), # Same as 1x1 conv
@@ -137,7 +129,7 @@ class ConvNeXt_v2(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
137
129
  layers.append(
138
130
  nn.Sequential(
139
131
  LayerNorm2d(i, eps=1e-6),
140
- nn.Conv2d(i, out, kernel_size=(2, 2), stride=(2, 2), padding=(0, 0), bias=True),
132
+ nn.Conv2d(i, out, kernel_size=(2, 2), stride=(2, 2), padding=(0, 0)),
141
133
  )
142
134
  )
143
135
 
birder/net/crossformer.py CHANGED
@@ -120,9 +120,9 @@ class Attention(nn.Module):
120
120
  self.relative_position_index = nn.Buffer(relative_position_index)
121
121
 
122
122
  def forward(self, x: torch.Tensor) -> torch.Tensor:
123
- (B, N, C) = x.size()
123
+ B, N, C = x.size()
124
124
  qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
125
- (q, k, v) = qkv.unbind(0)
125
+ q, k, v = qkv.unbind(0)
126
126
 
127
127
  q = q * self.scale
128
128
  attn = q @ k.transpose(-2, -1)
@@ -188,15 +188,15 @@ class CrossFormerBlock(nn.Module):
188
188
  self.drop_path = StochasticDepth(drop_path, mode="row")
189
189
 
190
190
  def forward(self, x: torch.Tensor) -> torch.Tensor:
191
- (H, W) = self.input_resolution
192
- (B, _, C) = x.size()
191
+ H, W = self.input_resolution
192
+ B, _, C = x.size()
193
193
 
194
194
  shortcut = x
195
195
  x = self.norm1(x)
196
196
  x = x.view(B, H, W, C)
197
197
 
198
198
  # Group embeddings
199
- (GH, GW) = self.group_size # pylint: disable=invalid-name
199
+ GH, GW = self.group_size # pylint: disable=invalid-name
200
200
  if self.use_lda is False:
201
201
  x = x.reshape(B, H // GH, GH, W // GW, GW, C).permute(0, 1, 3, 2, 4, 5)
202
202
  else:
@@ -244,8 +244,8 @@ class PatchMerging(nn.Module):
244
244
  )
245
245
 
246
246
  def forward(self, x: torch.Tensor) -> torch.Tensor:
247
- (H, W) = self.input_resolution
248
- (B, _, C) = x.shape
247
+ H, W = self.input_resolution
248
+ B, _, C = x.shape
249
249
 
250
250
  x = self.norm(x)
251
251
  x = x.view(B, H, W, C).permute(0, 3, 1, 2)
@@ -396,8 +396,8 @@ class CrossFormer(DetectorBackbone):
396
396
  for name, module in self.body.named_children():
397
397
  x = module(x)
398
398
  if name in self.return_stages:
399
- (H, W) = module.resolution
400
- (B, _, C) = x.size()
399
+ H, W = module.resolution
400
+ B, _, C = x.size()
401
401
  out[name] = x.view(B, H, W, C).permute(0, 3, 1, 2).contiguous()
402
402
 
403
403
  return out
birder/net/crossvit.py CHANGED
@@ -74,7 +74,7 @@ class CrossAttention(nn.Module):
74
74
  self.proj_drop = nn.Dropout(proj_drop)
75
75
 
76
76
  def forward(self, x: torch.Tensor) -> torch.Tensor:
77
- (B, N, C) = x.shape
77
+ B, N, C = x.shape
78
78
  # B1C -> B1H(C/H) -> BH1(C/H)
79
79
  q = self.wq(x[:, 0:1, ...]).reshape(B, 1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
80
80
  # BNC -> BNH(C/H) -> BHN(C/H)
@@ -97,7 +97,7 @@ class CrossAttentionBlock(nn.Module):
97
97
  self, dim: int, num_heads: int, qkv_bias: bool, proj_drop: float, attn_drop: float, drop_path: float
98
98
  ) -> None:
99
99
  super().__init__()
100
- self.norm1 = nn.LayerNorm(dim)
100
+ self.norm1 = nn.LayerNorm(dim, eps=1e-6)
101
101
  self.attn = CrossAttention(
102
102
  dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=proj_drop
103
103
  )
@@ -146,7 +146,7 @@ class MultiScaleBlock(nn.Module):
146
146
  for d in range(num_branches):
147
147
  self.projs.append(
148
148
  nn.Sequential(
149
- nn.LayerNorm(dim[d]),
149
+ nn.LayerNorm(dim[d], eps=1e-6),
150
150
  nn.GELU(),
151
151
  nn.Linear(dim[d], dim[(d + 1) % num_branches]),
152
152
  )
@@ -187,7 +187,7 @@ class MultiScaleBlock(nn.Module):
187
187
  for d in range(num_branches):
188
188
  self.revert_projs.append(
189
189
  nn.Sequential(
190
- nn.LayerNorm(dim[(d + 1) % num_branches]),
190
+ nn.LayerNorm(dim[(d + 1) % num_branches], eps=1e-6),
191
191
  nn.GELU(),
192
192
  nn.Linear(dim[(d + 1) % num_branches], dim[d]),
193
193
  )
@@ -290,7 +290,7 @@ class CrossViT(BaseNet):
290
290
  dpr_ptr += curr_depth
291
291
  self.blocks.append(block)
292
292
 
293
- self.norm = nn.ModuleList([nn.LayerNorm(embed_dim[i]) for i in range(self.num_branches)])
293
+ self.norm = nn.ModuleList([nn.LayerNorm(embed_dim[i], eps=1e-6) for i in range(self.num_branches)])
294
294
  self.embedding_size = sum(self.embed_dim)
295
295
  self.classifier = nn.ModuleList()
296
296
  for i in range(self.num_branches):
@@ -482,7 +482,7 @@ registry.register_weights(
482
482
  "formats": {
483
483
  "pt": {
484
484
  "file_size": 32.7,
485
- "sha256": "515265ed725adce09464bfd23ce612b1d1178bc22a57960db089d7148556149a",
485
+ "sha256": "08f674d8165dc97cc535f8188a5c5361751a8d0bb85061454986a21541a6fe8e",
486
486
  }
487
487
  },
488
488
  "net": {"network": "crossvit_9d", "tag": "il-common"},
birder/net/cspnet.py CHANGED
@@ -226,7 +226,7 @@ class CrossStage(nn.Module):
226
226
  def forward(self, x: torch.Tensor) -> torch.Tensor:
227
227
  x = self.conv_down(x)
228
228
  x = self.conv_exp(x)
229
- (xs, xb) = x.split(self.expand_channels // 2, dim=1)
229
+ xs, xb = x.split(self.expand_channels // 2, dim=1)
230
230
  xb = self.blocks(xb)
231
231
  xb = self.conv_transition_b(xb).contiguous()
232
232
  out = self.conv_transition(torch.concat([xs, xb], dim=1))