birder 0.3.3-py3-none-any.whl → 0.4.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (216)
  1. birder/adversarial/base.py +1 -1
  2. birder/adversarial/simba.py +4 -4
  3. birder/common/cli.py +1 -1
  4. birder/common/fs_ops.py +13 -13
  5. birder/common/lib.py +2 -2
  6. birder/common/masking.py +3 -3
  7. birder/common/training_cli.py +24 -2
  8. birder/common/training_utils.py +28 -4
  9. birder/data/collators/detection.py +9 -1
  10. birder/data/transforms/detection.py +27 -8
  11. birder/data/transforms/mosaic.py +1 -1
  12. birder/datahub/classification.py +3 -3
  13. birder/inference/classification.py +3 -3
  14. birder/inference/data_parallel.py +1 -1
  15. birder/inference/detection.py +5 -5
  16. birder/inference/wbf.py +1 -1
  17. birder/introspection/attention_rollout.py +6 -6
  18. birder/introspection/feature_pca.py +4 -4
  19. birder/introspection/gradcam.py +1 -1
  20. birder/introspection/guided_backprop.py +2 -2
  21. birder/introspection/transformer_attribution.py +4 -4
  22. birder/layers/attention_pool.py +2 -2
  23. birder/layers/layer_scale.py +1 -1
  24. birder/model_registry/model_registry.py +2 -1
  25. birder/net/__init__.py +4 -10
  26. birder/net/_rope_vit_configs.py +435 -0
  27. birder/net/_vit_configs.py +466 -0
  28. birder/net/alexnet.py +5 -5
  29. birder/net/base.py +28 -3
  30. birder/net/biformer.py +18 -17
  31. birder/net/cait.py +7 -7
  32. birder/net/cas_vit.py +1 -1
  33. birder/net/coat.py +27 -27
  34. birder/net/conv2former.py +3 -3
  35. birder/net/convmixer.py +1 -1
  36. birder/net/convnext_v1.py +3 -11
  37. birder/net/convnext_v1_iso.py +198 -0
  38. birder/net/convnext_v2.py +2 -10
  39. birder/net/crossformer.py +9 -9
  40. birder/net/crossvit.py +6 -6
  41. birder/net/cspnet.py +1 -1
  42. birder/net/cswin_transformer.py +10 -10
  43. birder/net/davit.py +11 -11
  44. birder/net/deit.py +68 -29
  45. birder/net/deit3.py +69 -204
  46. birder/net/densenet.py +9 -8
  47. birder/net/detection/__init__.py +4 -0
  48. birder/net/detection/{yolo_anchors.py → _yolo_anchors.py} +5 -5
  49. birder/net/detection/base.py +6 -5
  50. birder/net/detection/deformable_detr.py +31 -30
  51. birder/net/detection/detr.py +14 -11
  52. birder/net/detection/efficientdet.py +10 -29
  53. birder/net/detection/faster_rcnn.py +22 -22
  54. birder/net/detection/fcos.py +8 -8
  55. birder/net/detection/plain_detr.py +852 -0
  56. birder/net/detection/retinanet.py +4 -4
  57. birder/net/detection/rt_detr_v1.py +81 -25
  58. birder/net/detection/rt_detr_v2.py +1147 -0
  59. birder/net/detection/ssd.py +5 -5
  60. birder/net/detection/yolo_v2.py +12 -12
  61. birder/net/detection/yolo_v3.py +19 -19
  62. birder/net/detection/yolo_v4.py +16 -16
  63. birder/net/detection/yolo_v4_tiny.py +3 -3
  64. birder/net/dpn.py +1 -2
  65. birder/net/edgenext.py +5 -4
  66. birder/net/edgevit.py +13 -14
  67. birder/net/efficientformer_v1.py +3 -2
  68. birder/net/efficientformer_v2.py +18 -31
  69. birder/net/efficientnet_v2.py +3 -0
  70. birder/net/efficientvim.py +9 -9
  71. birder/net/efficientvit_mit.py +7 -7
  72. birder/net/efficientvit_msft.py +3 -3
  73. birder/net/fasternet.py +3 -3
  74. birder/net/fastvit.py +5 -12
  75. birder/net/flexivit.py +50 -58
  76. birder/net/focalnet.py +5 -9
  77. birder/net/gc_vit.py +11 -11
  78. birder/net/ghostnet_v1.py +1 -1
  79. birder/net/ghostnet_v2.py +1 -1
  80. birder/net/groupmixformer.py +13 -13
  81. birder/net/hgnet_v1.py +6 -6
  82. birder/net/hgnet_v2.py +4 -4
  83. birder/net/hiera.py +6 -6
  84. birder/net/hieradet.py +9 -9
  85. birder/net/hornet.py +3 -3
  86. birder/net/iformer.py +4 -4
  87. birder/net/inception_next.py +5 -15
  88. birder/net/inception_resnet_v1.py +3 -3
  89. birder/net/inception_resnet_v2.py +7 -4
  90. birder/net/inception_v3.py +3 -0
  91. birder/net/inception_v4.py +3 -0
  92. birder/net/levit.py +3 -3
  93. birder/net/lit_v1.py +13 -15
  94. birder/net/lit_v1_tiny.py +9 -9
  95. birder/net/lit_v2.py +14 -15
  96. birder/net/maxvit.py +11 -23
  97. birder/net/metaformer.py +5 -5
  98. birder/net/mim/crossmae.py +6 -6
  99. birder/net/mim/fcmae.py +3 -5
  100. birder/net/mim/mae_hiera.py +7 -7
  101. birder/net/mim/mae_vit.py +4 -6
  102. birder/net/mim/simmim.py +3 -4
  103. birder/net/mobilenet_v1.py +0 -9
  104. birder/net/mobilenet_v2.py +38 -44
  105. birder/net/{mobilenet_v3_large.py → mobilenet_v3.py} +37 -10
  106. birder/net/mobilenet_v4_hybrid.py +4 -4
  107. birder/net/mobileone.py +5 -12
  108. birder/net/mobilevit_v1.py +7 -34
  109. birder/net/mobilevit_v2.py +6 -54
  110. birder/net/moganet.py +8 -5
  111. birder/net/mvit_v2.py +30 -30
  112. birder/net/nextvit.py +2 -2
  113. birder/net/nfnet.py +4 -0
  114. birder/net/pit.py +11 -26
  115. birder/net/pvt_v1.py +9 -9
  116. birder/net/pvt_v2.py +10 -16
  117. birder/net/regionvit.py +15 -15
  118. birder/net/regnet.py +1 -1
  119. birder/net/repghost.py +5 -35
  120. birder/net/repvgg.py +3 -5
  121. birder/net/repvit.py +2 -2
  122. birder/net/resmlp.py +2 -2
  123. birder/net/resnest.py +4 -1
  124. birder/net/resnet_v1.py +125 -1
  125. birder/net/resnet_v2.py +75 -1
  126. birder/net/resnext.py +35 -1
  127. birder/net/rope_deit3.py +62 -151
  128. birder/net/rope_flexivit.py +46 -33
  129. birder/net/rope_vit.py +44 -758
  130. birder/net/sequencer2d.py +3 -4
  131. birder/net/shufflenet_v1.py +1 -1
  132. birder/net/shufflenet_v2.py +1 -1
  133. birder/net/simple_vit.py +69 -21
  134. birder/net/smt.py +8 -8
  135. birder/net/squeezenet.py +5 -12
  136. birder/net/squeezenext.py +0 -24
  137. birder/net/ssl/barlow_twins.py +1 -1
  138. birder/net/ssl/byol.py +2 -2
  139. birder/net/ssl/capi.py +4 -4
  140. birder/net/ssl/data2vec.py +1 -1
  141. birder/net/ssl/data2vec2.py +1 -1
  142. birder/net/ssl/dino_v2.py +13 -3
  143. birder/net/ssl/franca.py +28 -4
  144. birder/net/ssl/i_jepa.py +5 -5
  145. birder/net/ssl/ibot.py +1 -1
  146. birder/net/ssl/mmcr.py +1 -1
  147. birder/net/swiftformer.py +13 -3
  148. birder/net/swin_transformer_v1.py +4 -5
  149. birder/net/swin_transformer_v2.py +5 -8
  150. birder/net/tiny_vit.py +6 -19
  151. birder/net/transnext.py +19 -19
  152. birder/net/uniformer.py +4 -4
  153. birder/net/van.py +2 -2
  154. birder/net/vgg.py +1 -10
  155. birder/net/vit.py +72 -987
  156. birder/net/vit_parallel.py +35 -20
  157. birder/net/vit_sam.py +23 -48
  158. birder/net/vovnet_v2.py +1 -1
  159. birder/net/xcit.py +16 -13
  160. birder/ops/msda.py +4 -4
  161. birder/ops/swattention.py +10 -10
  162. birder/results/classification.py +3 -3
  163. birder/results/gui.py +8 -8
  164. birder/scripts/benchmark.py +37 -12
  165. birder/scripts/evaluate.py +1 -1
  166. birder/scripts/predict.py +3 -3
  167. birder/scripts/predict_detection.py +2 -2
  168. birder/scripts/train.py +63 -15
  169. birder/scripts/train_barlow_twins.py +10 -7
  170. birder/scripts/train_byol.py +10 -7
  171. birder/scripts/train_capi.py +15 -10
  172. birder/scripts/train_data2vec.py +10 -7
  173. birder/scripts/train_data2vec2.py +10 -7
  174. birder/scripts/train_detection.py +29 -14
  175. birder/scripts/train_dino_v1.py +13 -9
  176. birder/scripts/train_dino_v2.py +27 -14
  177. birder/scripts/train_dino_v2_dist.py +28 -15
  178. birder/scripts/train_franca.py +16 -9
  179. birder/scripts/train_i_jepa.py +12 -9
  180. birder/scripts/train_ibot.py +15 -11
  181. birder/scripts/train_kd.py +64 -17
  182. birder/scripts/train_mim.py +11 -8
  183. birder/scripts/train_mmcr.py +11 -8
  184. birder/scripts/train_rotnet.py +11 -7
  185. birder/scripts/train_simclr.py +10 -7
  186. birder/scripts/train_vicreg.py +10 -7
  187. birder/tools/adversarial.py +4 -4
  188. birder/tools/auto_anchors.py +5 -5
  189. birder/tools/avg_model.py +1 -1
  190. birder/tools/convert_model.py +30 -22
  191. birder/tools/det_results.py +1 -1
  192. birder/tools/download_model.py +1 -1
  193. birder/tools/ensemble_model.py +1 -1
  194. birder/tools/introspection.py +12 -3
  195. birder/tools/labelme_to_coco.py +2 -2
  196. birder/tools/model_info.py +15 -15
  197. birder/tools/pack.py +8 -8
  198. birder/tools/quantize_model.py +53 -4
  199. birder/tools/results.py +2 -2
  200. birder/tools/show_det_iterator.py +19 -6
  201. birder/tools/show_iterator.py +2 -2
  202. birder/tools/similarity.py +5 -5
  203. birder/tools/stats.py +4 -6
  204. birder/tools/voc_to_coco.py +1 -1
  205. birder/version.py +1 -1
  206. {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/METADATA +3 -3
  207. birder-0.4.1.dist-info/RECORD +300 -0
  208. {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/WHEEL +1 -1
  209. birder/net/mobilenet_v3_small.py +0 -43
  210. birder/net/se_resnet_v1.py +0 -105
  211. birder/net/se_resnet_v2.py +0 -59
  212. birder/net/se_resnext.py +0 -30
  213. birder-0.3.3.dist-info/RECORD +0 -299
  214. {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/entry_points.txt +0 -0
  215. {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/licenses/LICENSE +0 -0
  216. {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/top_level.txt +0 -0
birder/net/mobilenet_v2.py CHANGED
@@ -37,36 +37,44 @@ class InvertedResidual(nn.Module):
         num_expfilter = int(round(in_channels * expansion_factor))
 
         self.shortcut = shortcut
-        self.block = nn.Sequential(
-            Conv2dNormActivation(
-                in_channels,
-                num_expfilter,
-                kernel_size=(1, 1),
-                stride=(1, 1),
-                padding=(0, 0),
-                bias=False,
-                activation_layer=activation_layer,
-            ),
-            Conv2dNormActivation(
-                num_expfilter,
-                num_expfilter,
-                kernel_size=kernel_size,
-                stride=stride,
-                padding=padding,
-                groups=num_expfilter,
-                bias=False,
-                activation_layer=activation_layer,
-            ),
-            Conv2dNormActivation(
-                num_expfilter,
-                out_channels,
-                kernel_size=(1, 1),
-                stride=(1, 1),
-                padding=(0, 0),
-                bias=False,
-                activation_layer=None,
-            ),
+        layers = []
+        if expansion_factor != 1.0:
+            layers.append(
+                Conv2dNormActivation(
+                    in_channels,
+                    num_expfilter,
+                    kernel_size=(1, 1),
+                    stride=(1, 1),
+                    padding=(0, 0),
+                    bias=False,
+                    activation_layer=activation_layer,
+                )
+            )
+
+        layers.extend(
+            [
+                Conv2dNormActivation(
+                    num_expfilter,
+                    num_expfilter,
+                    kernel_size=kernel_size,
+                    stride=stride,
+                    padding=padding,
+                    groups=num_expfilter,
+                    bias=False,
+                    activation_layer=activation_layer,
+                ),
+                Conv2dNormActivation(
+                    num_expfilter,
+                    out_channels,
+                    kernel_size=(1, 1),
+                    stride=(1, 1),
+                    padding=(0, 0),
+                    bias=False,
+                    activation_layer=None,
+                ),
+            ]
         )
+        self.block = nn.Sequential(*layers)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         if self.shortcut is True:
@@ -171,6 +179,7 @@ class MobileNet_v2(DetectorBackbone):
             ),
             nn.AdaptiveAvgPool2d(output_size=(1, 1)),
             nn.Flatten(1),
+            nn.Dropout(0.2),
         )
         self.return_channels = return_channels[1:5]
         self.embedding_size = last_channels
@@ -230,18 +239,3 @@ registry.register_model_config("mobilenet_v2_1_25", MobileNet_v2, config={"alpha
 registry.register_model_config("mobilenet_v2_1_5", MobileNet_v2, config={"alpha": 1.5})
 registry.register_model_config("mobilenet_v2_1_75", MobileNet_v2, config={"alpha": 1.75})
 registry.register_model_config("mobilenet_v2_2_0", MobileNet_v2, config={"alpha": 2.0})
-
-registry.register_weights(
-    "mobilenet_v2_1_0_il-common",
-    {
-        "description": "MobileNet v2 (1.0 multiplier) model trained on the il-common dataset",
-        "resolution": (256, 256),
-        "formats": {
-            "pt": {
-                "file_size": 10.6,
-                "sha256": "d6182293e98c102026f7cdc0d446aaf0e511232173c4b98c1a882c9f147be6e7",
-            }
-        },
-        "net": {"network": "mobilenet_v2_1_0", "tag": "il-common"},
-    },
-)
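The `InvertedResidual` rework above skips the pointwise expansion convolution when `expansion_factor == 1`, since a 1x1 conv mapping C channels back to C channels adds parameters without widening the bottleneck (this matches the original MobileNetV2 design, whose first stage uses an expansion factor of 1). A minimal standalone sketch of the same pattern, using torchvision's `Conv2dNormActivation` rather than birder's actual module:

```python
# Minimal sketch (not birder's actual class) of the conditional expansion stage.
import torch
from torch import nn
from torchvision.ops import Conv2dNormActivation


def inverted_residual_block(in_channels: int, out_channels: int, expansion_factor: float) -> nn.Sequential:
    num_expfilter = int(round(in_channels * expansion_factor))
    layers: list[nn.Module] = []
    if expansion_factor != 1.0:
        # Pointwise expansion, only needed when the block actually widens
        layers.append(Conv2dNormActivation(in_channels, num_expfilter, kernel_size=1, bias=False))

    # Depthwise 3x3 (groups == channels)
    layers.append(Conv2dNormActivation(num_expfilter, num_expfilter, kernel_size=3, groups=num_expfilter, bias=False))
    # Linear pointwise projection (no activation)
    layers.append(Conv2dNormActivation(num_expfilter, out_channels, kernel_size=1, bias=False, activation_layer=None))
    return nn.Sequential(*layers)


blk = inverted_residual_block(16, 16, expansion_factor=1.0)  # no expansion conv created
print(blk(torch.randn(1, 16, 32, 32)).shape)  # torch.Size([1, 16, 32, 32])
```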
birder/net/{mobilenet_v3_large.py → mobilenet_v3.py} CHANGED
@@ -3,6 +3,9 @@ MobileNet v3, adapted from
 https://github.com/pytorch/vision/blob/main/torchvision/models/mobilenetv3.py
 
 Paper "Searching for MobileNetV3", https://arxiv.org/abs/1905.02244
+
+Changes from original:
+* Using nn.BatchNorm2d with eps 1e-5 instead of 1e-3
 """
 
 # Reference license: BSD 3-Clause
@@ -113,7 +116,7 @@ class InvertedResidual(nn.Module):
 
 
 # pylint: disable=invalid-name
-class MobileNet_v3_Large(DetectorBackbone):
+class MobileNet_v3(DetectorBackbone):
     def __init__(
         self,
         input_channels: int,
@@ -121,12 +124,12 @@ class MobileNet_v3_Large(DetectorBackbone):
         *,
         config: Optional[dict[str, Any]] = None,
         size: Optional[tuple[int, int]] = None,
-        large: bool = True,
     ) -> None:
         super().__init__(input_channels, num_classes, config=config, size=size)
         assert self.config is not None, "must set config"
 
         alpha: float = self.config["alpha"]
+        large: bool = self.config["large"]
 
         if large is True:
             last_channels = int(round(1280 * max(1.0, alpha)))
@@ -268,15 +271,39 @@ class MobileNet_v3_Large(DetectorBackbone):
     )
 
 
-registry.register_model_config("mobilenet_v3_large_0_25", MobileNet_v3_Large, config={"alpha": 0.25})
-registry.register_model_config("mobilenet_v3_large_0_5", MobileNet_v3_Large, config={"alpha": 0.5})
-registry.register_model_config("mobilenet_v3_large_0_75", MobileNet_v3_Large, config={"alpha": 0.75})
-registry.register_model_config("mobilenet_v3_large_1_0", MobileNet_v3_Large, config={"alpha": 1.0})
-registry.register_model_config("mobilenet_v3_large_1_25", MobileNet_v3_Large, config={"alpha": 1.25})
-registry.register_model_config("mobilenet_v3_large_1_5", MobileNet_v3_Large, config={"alpha": 1.5})
-registry.register_model_config("mobilenet_v3_large_1_75", MobileNet_v3_Large, config={"alpha": 1.75})
-registry.register_model_config("mobilenet_v3_large_2_0", MobileNet_v3_Large, config={"alpha": 2.0})
+registry.register_model_config("mobilenet_v3_small_0_25", MobileNet_v3, config={"alpha": 0.25, "large": False})
+registry.register_model_config("mobilenet_v3_small_0_5", MobileNet_v3, config={"alpha": 0.5, "large": False})
+registry.register_model_config("mobilenet_v3_small_0_75", MobileNet_v3, config={"alpha": 0.75, "large": False})
+registry.register_model_config("mobilenet_v3_small_1_0", MobileNet_v3, config={"alpha": 1.0, "large": False})
+registry.register_model_config("mobilenet_v3_small_1_25", MobileNet_v3, config={"alpha": 1.25, "large": False})
+registry.register_model_config("mobilenet_v3_small_1_5", MobileNet_v3, config={"alpha": 1.5, "large": False})
+registry.register_model_config("mobilenet_v3_small_1_75", MobileNet_v3, config={"alpha": 1.75, "large": False})
+registry.register_model_config("mobilenet_v3_small_2_0", MobileNet_v3, config={"alpha": 2.0, "large": False})
+
+registry.register_model_config("mobilenet_v3_large_0_25", MobileNet_v3, config={"alpha": 0.25, "large": True})
+registry.register_model_config("mobilenet_v3_large_0_5", MobileNet_v3, config={"alpha": 0.5, "large": True})
+registry.register_model_config("mobilenet_v3_large_0_75", MobileNet_v3, config={"alpha": 0.75, "large": True})
+registry.register_model_config("mobilenet_v3_large_1_0", MobileNet_v3, config={"alpha": 1.0, "large": True})
+registry.register_model_config("mobilenet_v3_large_1_25", MobileNet_v3, config={"alpha": 1.25, "large": True})
+registry.register_model_config("mobilenet_v3_large_1_5", MobileNet_v3, config={"alpha": 1.5, "large": True})
+registry.register_model_config("mobilenet_v3_large_1_75", MobileNet_v3, config={"alpha": 1.75, "large": True})
+registry.register_model_config("mobilenet_v3_large_2_0", MobileNet_v3, config={"alpha": 2.0, "large": True})
+
 
+registry.register_weights(
+    "mobilenet_v3_small_1_0_il-common",
+    {
+        "description": "MobileNet v3 small (1.0 multiplier) model trained on the il-common dataset",
+        "resolution": (256, 256),
+        "formats": {
+            "pt": {
+                "file_size": 7.4,
+                "sha256": "ac53227f7513fd0c0b5204ee57403de2ab6c74c4e4d1061b9168596c6b5cea48",
+            }
+        },
+        "net": {"network": "mobilenet_v3_small_1_0", "tag": "il-common"},
+    },
+)
 registry.register_weights(
     "mobilenet_v3_large_0_75_il-common",
     {
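With the small and large variants merged into one `MobileNet_v3` class, the variant choice moves from a constructor argument into the registry config dict, so a single class covers all sixteen configurations above. A rough sketch of the head-width arithmetic implied by the diff; note the small-variant value of 1024 is an assumption based on the MobileNetV3 paper, since the diff only shows the large branch:

```python
from typing import Any


def head_channels(config: dict[str, Any]) -> int:
    # Config keys mirror the registry entries above
    alpha: float = config["alpha"]
    large: bool = config["large"]
    if large:
        return int(round(1280 * max(1.0, alpha)))  # shown in the diff
    return int(round(1024 * max(1.0, alpha)))  # assumed, per the paper


print(head_channels({"alpha": 1.0, "large": True}))   # 1280
print(head_channels({"alpha": 0.5, "large": False}))  # 1024 (alpha below 1 does not shrink the head)
```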
birder/net/mobilenet_v4_hybrid.py CHANGED
@@ -142,24 +142,24 @@ class MultiQueryAttention(nn.Module):
         self.output = nn.Sequential(*output_layers)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        (B, C, H, W) = x.size()
+        B, C, H, W = x.size()
         q = self.query(x)
         q = q.reshape(B, self.num_heads, self.key_dim, -1)
         q = q.transpose(-1, -2).contiguous()
 
         k = self.key(x)
-        (B, C, _, _) = k.size()
+        B, C, _, _ = k.size()
         k = k.reshape(B, C, -1).transpose(1, 2)
         k = k.unsqueeze(1).contiguous()
 
         v = self.value(x)
-        (B, C, _, _) = v.size()
+        B, C, _, _ = v.size()
         v = v.reshape(B, C, -1).transpose(1, 2)
         v = v.unsqueeze(1).contiguous()
 
         # Calculate attention score
         attn_score = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)  # pylint: disable=not-callable
-        (B, _, _, C) = attn_score.size()
+        B, _, _, C = attn_score.size()
         feat_dim = C * self.num_heads
         attn_score = attn_score.transpose(1, 2)
         attn_score = (
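The `MultiQueryAttention.forward` shown here relies on a broadcasting trick: `q` keeps `num_heads` separate query heads, while `k` and `v` get a singleton head dimension via `unsqueeze(1)` that `F.scaled_dot_product_attention` broadcasts across all heads. A small shape-only sketch, with random tensors standing in for the `query`/`key`/`value` projections:

```python
import torch
import torch.nn.functional as F

B, num_heads, key_dim, N = 2, 4, 16, 64

q = torch.randn(B, num_heads, N, key_dim)  # one query per head
k = torch.randn(B, 1, N, key_dim)          # single shared key head
v = torch.randn(B, 1, N, key_dim)          # single shared value head

# K/V broadcast across the head dimension inside SDPA
out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)
print(out.shape)  # torch.Size([2, 4, 64, 16])
```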
birder/net/mobileone.py CHANGED
@@ -61,13 +61,7 @@ class MobileOneBlock(nn.Module):
 
         if reparameterized is True:
             self.reparam_conv = nn.Conv2d(
-                in_channels,
-                out_channels,
-                kernel_size=kernel_size,
-                stride=stride,
-                padding=padding,
-                groups=groups,
-                bias=True,
+                in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, groups=groups
             )
         else:
             self.reparam_conv = None
@@ -144,7 +138,7 @@ class MobileOneBlock(nn.Module):
         if self.reparameterized is True:
             return
 
-        (kernel, bias) = self._get_kernel_bias()
+        kernel, bias = self._get_kernel_bias()
         self.reparam_conv = nn.Conv2d(
             in_channels=self.in_channels,
             out_channels=self.out_channels,
@@ -152,7 +146,6 @@ class MobileOneBlock(nn.Module):
             stride=self.stride,
             padding=self.padding,
             groups=self.groups,
-            bias=True,
         )
         self.reparam_conv.weight.data = kernel
         self.reparam_conv.bias.data = bias
@@ -178,7 +171,7 @@ class MobileOneBlock(nn.Module):
         kernel_scale = 0
         bias_scale = 0
         if self.rbr_scale is not None:
-            (kernel_scale, bias_scale) = self._fuse_bn_tensor(self.rbr_scale)
+            kernel_scale, bias_scale = self._fuse_bn_tensor(self.rbr_scale)
             pad = self.kernel_size // 2
             kernel_scale = F.pad(kernel_scale, [pad, pad, pad, pad])
 
@@ -186,13 +179,13 @@ class MobileOneBlock(nn.Module):
         kernel_identity = 0
         bias_identity = 0
         if self.rbr_skip is not None:
-            (kernel_identity, bias_identity) = self._fuse_bn_tensor(self.rbr_skip)
+            kernel_identity, bias_identity = self._fuse_bn_tensor(self.rbr_skip)
 
         # Get weights and bias of conv branches
         kernel_conv = 0
         bias_conv = 0
         for ix in range(self.num_conv_branches):
-            (_kernel, _bias) = self._fuse_bn_tensor(self.rbr_conv[ix])
+            _kernel, _bias = self._fuse_bn_tensor(self.rbr_conv[ix])
             kernel_conv += _kernel
             bias_conv += _bias
 
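The MobileOne edits are mostly cosmetic: tuple parentheses are dropped from unpacking, and explicit `bias=True` arguments disappear because that is already `nn.Conv2d`'s default. The reparameterization itself folds each conv+BN branch into a single kernel/bias pair. A sketch of the BN-fusion math behind `_fuse_bn_tensor` (the helper name comes from the diff; this standalone version is illustrative, assuming an eval-mode BN that follows a bias-free conv):

```python
import torch
from torch import nn


def fuse_conv_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> tuple[torch.Tensor, torch.Tensor]:
    std = (bn.running_var + bn.eps).sqrt()
    scale = bn.weight / std                   # per-channel multiplier
    kernel = conv.weight * scale.reshape(-1, 1, 1, 1)
    bias = bn.bias - bn.running_mean * scale  # conv itself has no bias in the branches
    return kernel, bias


conv = nn.Conv2d(8, 8, kernel_size=3, padding=1, bias=False)
bn = nn.BatchNorm2d(8).eval()  # fusion uses the running statistics
kernel, bias = fuse_conv_bn(conv, bn)
fused = nn.Conv2d(8, 8, kernel_size=3, padding=1)  # bias=True is the default
fused.weight.data = kernel
fused.bias.data = bias

x = torch.randn(1, 8, 16, 16)
print(torch.allclose(fused(x), bn(conv(x)), atol=1e-6))  # True
```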
birder/net/mobilevit_v1.py CHANGED
@@ -1,11 +1,14 @@
 """
-MobileViT, adapted from
+MobileViT v1, adapted from
 https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/mobilevit.py
 and
 https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/mobile_vit.py
 
 Paper "MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer",
 https://arxiv.org/abs/2110.02178
+
+Changes from original:
+* Removed classifier bias
 """
 
 # Reference license: Apache-2.0 and MIT
@@ -63,6 +66,7 @@ class MobileVitBlock(nn.Module):
                     attention_dropout=attn_drop,
                     drop_path=drop_path_rate,
                     activation_layer=nn.SiLU,
+                    norm_layer_eps=1e-5,
                 )
                 for _ in range(transformer_depth)
             ]
@@ -97,8 +101,8 @@ class MobileVitBlock(nn.Module):
         x = self.conv_1x1(x)
 
         # Unfold (feature map -> patches)
-        (patch_h, patch_w) = self.patch_size
-        (B, C, H, W) = x.shape
+        patch_h, patch_w = self.patch_size
+        B, C, H, W = x.shape
         new_h, new_w = math.ceil(H / patch_h) * patch_h, math.ceil(W / patch_w) * patch_w
         num_patch_h = new_h // patch_h  # n_h, n_w
         num_patch_w = new_w // patch_w
@@ -166,7 +170,6 @@ class MobileViT_v1(BaseNet):
             stride=(2, 2),
             padding=(1, 1),
             activation_layer=nn.SiLU,
-            bias=True,
         )
 
         layers = []
@@ -231,7 +234,6 @@ class MobileViT_v1(BaseNet):
                 stride=(1, 1),
                 padding=(0, 0),
                 activation_layer=nn.SiLU,
-                bias=True,
             ),
             nn.AdaptiveAvgPool2d(output_size=(1, 1)),
             nn.Flatten(1),
@@ -290,32 +292,3 @@ registry.register_model_config(
         "expansion": 4,
     },
 )
-
-registry.register_weights(
-    "mobilevit_v1_xxs_il-common",
-    {
-        "description": "MobileViT v1 XXS model trained on the il-common dataset",
-        "resolution": (256, 256),
-        "formats": {
-            "pt": {
-                "file_size": 4.2,
-                "sha256": "2b565a768ca21fd72d5ef5090ff0f8b725f3e1165cd8e56749815041e5254d26",
-            }
-        },
-        "net": {"network": "mobilevit_v1_xxs", "tag": "il-common"},
-    },
-)
-registry.register_weights(
-    "mobilevit_v1_xs_il-common",
-    {
-        "description": "MobileViT v1 XS model trained on the il-common dataset",
-        "resolution": (256, 256),
-        "formats": {
-            "pt": {
-                "file_size": 8.1,
-                "sha256": "193bcede7f0b9f4574673e95c23c6ca3b8eeb30254a32a85e93342f1d67db31b",
-            }
-        },
-        "net": {"network": "mobilevit_v1_xs", "tag": "il-common"},
-    },
-)
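The unfold step in `MobileVitBlock.forward` (partially visible above) first rounds the spatial size up to the nearest multiple of the patch size, then derives the patch grid from it. A quick check of that arithmetic with illustrative numbers:

```python
import math

H, W = 31, 31            # incoming feature map
patch_h, patch_w = 2, 2  # self.patch_size

# Round up to the nearest patch multiple (the map is resized when this differs)
new_h = math.ceil(H / patch_h) * patch_h
new_w = math.ceil(W / patch_w) * patch_w
num_patch_h = new_h // patch_h  # n_h
num_patch_w = new_w // patch_w  # n_w
print(new_h, new_w, num_patch_h * num_patch_w)  # 32 32 256
```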
birder/net/mobilevit_v2.py CHANGED
@@ -63,7 +63,7 @@ class LinearSelfAttention(nn.Module):
         # Project x into query, key and value
         # Query --> [B, 1, P, N]
         # value, key --> [B, d, P, N]
-        (query, key, value) = qkv.split([1, self.embed_dim, self.embed_dim], dim=1)
+        query, key, value = qkv.split([1, self.embed_dim, self.embed_dim], dim=1)
 
         # apply softmax along N dimension
         context_scores = F.softmax(query, dim=-1)
@@ -98,14 +98,10 @@ class LinearTransformerBlock(nn.Module):
 
         self.norm2 = nn.GroupNorm(num_groups=1, num_channels=embed_dim)
         self.mlp = nn.Sequential(
-            nn.Conv2d(
-                embed_dim, int(embed_dim * mlp_ratio), kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True
-            ),
+            nn.Conv2d(embed_dim, int(embed_dim * mlp_ratio), kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)),
             nn.SiLU(),
             nn.Dropout(drop),
-            nn.Conv2d(
-                int(embed_dim * mlp_ratio), embed_dim, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True
-            ),
+            nn.Conv2d(int(embed_dim * mlp_ratio), embed_dim, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)),
         )
         self.drop_path2 = StochasticDepth(drop_path, mode="row")
 
@@ -159,20 +155,15 @@ class MobileVitBlock(nn.Module):
         self.norm = nn.GroupNorm(num_groups=1, num_channels=transformer_dim)
 
         self.conv_proj = Conv2dNormActivation(
-            transformer_dim,
-            channels,
-            kernel_size=(1, 1),
-            stride=(1, 1),
-            padding=(0, 0),
-            activation_layer=nn.SiLU,
+            transformer_dim, channels, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), activation_layer=None
         )
 
         self.patch_size = patch_size
         self.patch_area = self.patch_size[0] * self.patch_size[1]
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        (B, C, H, W) = x.shape
-        (patch_h, patch_w) = self.patch_size
+        B, C, H, W = x.shape
+        patch_h, patch_w = self.patch_size
         new_h = math.ceil(H / patch_h) * patch_h
         new_w = math.ceil(W / patch_w) * patch_w
         num_patch_h, num_patch_w = new_h // patch_h, new_w // patch_w  # n_h, n_w
@@ -236,7 +227,6 @@ class MobileViT_v2(DetectorBackbone):
             stride=(2, 2),
             padding=(1, 1),
             activation_layer=nn.SiLU,
-            bias=True,
        )
 
         stages: OrderedDict[str, nn.Module] = OrderedDict()
@@ -340,15 +330,6 @@ class MobileViT_v2(DetectorBackbone):
         x = self.forward_features(x)
         return self.features(x)
 
-    def create_classifier(self, embed_dim: Optional[int] = None) -> nn.Module:
-        if self.num_classes == 0:
-            return nn.Identity()
-
-        if embed_dim is None:
-            embed_dim = self.embedding_size
-
-        return nn.Linear(embed_dim, self.num_classes, bias=False)
-
 
 registry.register_model_config("mobilevit_v2_0_25", MobileViT_v2, config={"width_factor": 0.25})
 registry.register_model_config("mobilevit_v2_0_5", MobileViT_v2, config={"width_factor": 0.5})
@@ -358,32 +339,3 @@ registry.register_model_config("mobilevit_v2_1_25", MobileViT_v2, config={"width
 registry.register_model_config("mobilevit_v2_1_5", MobileViT_v2, config={"width_factor": 1.5})
 registry.register_model_config("mobilevit_v2_1_75", MobileViT_v2, config={"width_factor": 1.75})
 registry.register_model_config("mobilevit_v2_2_0", MobileViT_v2, config={"width_factor": 2.0})
-
-registry.register_weights(
-    "mobilevit_v2_1_0_il-common",
-    {
-        "description": "MobileViT v2 with width multiplier of 1.0 trained on the il-common dataset",
-        "resolution": (256, 256),
-        "formats": {
-            "pt": {
-                "file_size": 17.6,
-                "sha256": "2b45b7f2ffe3dd129d9a7e9690d2dfd0f93ac60f24d118b920a51bcb950fd95e",
-            }
-        },
-        "net": {"network": "mobilevit_v2_1_0", "tag": "il-common"},
-    },
-)
-registry.register_weights(
-    "mobilevit_v2_1_5_il-common",
-    {
-        "description": "MobileViT v2 with width multiplier of 1.5 trained on the il-common dataset",
-        "resolution": (256, 256),
-        "formats": {
-            "pt": {
-                "file_size": 38.8,
-                "sha256": "acd28c3ee653b62c69ad765c1d99827cea5051deb6dbdd7b9c8d7612782c86a3",
-            }
-        },
-        "net": {"network": "mobilevit_v2_1_5", "tag": "il-common"},
-    },
-)
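The `LinearSelfAttention` touched above implements MobileViT v2's separable self-attention: the single query channel is turned into per-token scores by a softmax over the N dimension, and the keys are pooled into one context vector, giving linear rather than quadratic cost in N. A hedged sketch following the shape comments in the diff (the final `relu(value)` gating follows the paper's formulation, not code shown here):

```python
import torch
import torch.nn.functional as F

B, d, P, N = 2, 32, 4, 64
qkv = torch.randn(B, 1 + 2 * d, P, N)  # stand-in for the fused qkv projection

# Query --> [B, 1, P, N]; key, value --> [B, d, P, N]
query, key, value = qkv.split([1, d, d], dim=1)

context_scores = F.softmax(query, dim=-1)                          # over the N tokens
context_vector = (key * context_scores).sum(dim=-1, keepdim=True)  # [B, d, P, 1]
out = F.relu(value) * context_vector                               # broadcast back over N
print(out.shape)  # torch.Size([2, 32, 4, 64])
```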
birder/net/moganet.py CHANGED
@@ -4,6 +4,9 @@ https://github.com/Westlake-AI/MogaNet/blob/main/models/moganet.py
 
 Paper "MogaNet: Multi-order Gated Aggregation Network",
 https://arxiv.org/abs/2211.03295
+
+Changes from original:
+* Removed biases before norms
 """
 
 # Reference license: Apache-2.0
@@ -30,7 +33,7 @@ from birder.net.base import TokenRetentionResultType
 class ElementScale(nn.Module):
     def __init__(self, embed_dims: int, init_value: float) -> None:
         super().__init__()
-        self.scale = nn.Parameter(init_value * torch.ones((1, embed_dims, 1, 1)), requires_grad=True)
+        self.scale = nn.Parameter(init_value * torch.ones((1, embed_dims, 1, 1)))
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x * self.scale
@@ -179,14 +182,14 @@ class MogaBlock(nn.Module):
         super().__init__()
 
         # Spatial attention
-        self.norm1 = nn.BatchNorm2d(embed_dims, eps=1e-5)
+        self.norm1 = nn.BatchNorm2d(embed_dims)
         self.attn = MultiOrderGatedAggregation(
             embed_dims, attn_dw_dilation=attn_dw_dilation, attn_channel_split=attn_channel_split
         )
-        self.layer_scale_1 = nn.Parameter(init_value * torch.ones((1, embed_dims, 1, 1)), requires_grad=True)
+        self.layer_scale_1 = nn.Parameter(init_value * torch.ones((1, embed_dims, 1, 1)))
 
         # Channel MLP
-        self.norm2 = nn.BatchNorm2d(embed_dims, eps=1e-5)
+        self.norm2 = nn.BatchNorm2d(embed_dims)
         mlp_hidden_dim = int(embed_dims * ffn_ratio)
         self.mlp = ChannelAggregationFFN(
             embed_dims=embed_dims,
@@ -194,7 +197,7 @@
             kernel_size=3,
             ffn_drop=drop_rate,
         )
-        self.layer_scale_2 = nn.Parameter(init_value * torch.ones((1, embed_dims, 1, 1)), requires_grad=True)
+        self.layer_scale_2 = nn.Parameter(init_value * torch.ones((1, embed_dims, 1, 1)))
        self.drop_path = StochasticDepth(drop_path_rate, mode="row")
 
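The MogaNet removals are behavior-preserving cleanups: `requires_grad=True` is already the default for `nn.Parameter`, and `eps=1e-5` is already the default for `nn.BatchNorm2d`, as a quick check confirms:

```python
import torch
from torch import nn

scale = nn.Parameter(1e-5 * torch.ones((1, 64, 1, 1)))
print(scale.requires_grad)     # True (default)
print(nn.BatchNorm2d(64).eps)  # 1e-05 (default)
```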