birder 0.3.3__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (216)
  1. birder/adversarial/base.py +1 -1
  2. birder/adversarial/simba.py +4 -4
  3. birder/common/cli.py +1 -1
  4. birder/common/fs_ops.py +13 -13
  5. birder/common/lib.py +2 -2
  6. birder/common/masking.py +3 -3
  7. birder/common/training_cli.py +24 -2
  8. birder/common/training_utils.py +28 -4
  9. birder/data/collators/detection.py +9 -1
  10. birder/data/transforms/detection.py +27 -8
  11. birder/data/transforms/mosaic.py +1 -1
  12. birder/datahub/classification.py +3 -3
  13. birder/inference/classification.py +3 -3
  14. birder/inference/data_parallel.py +1 -1
  15. birder/inference/detection.py +5 -5
  16. birder/inference/wbf.py +1 -1
  17. birder/introspection/attention_rollout.py +6 -6
  18. birder/introspection/feature_pca.py +4 -4
  19. birder/introspection/gradcam.py +1 -1
  20. birder/introspection/guided_backprop.py +2 -2
  21. birder/introspection/transformer_attribution.py +4 -4
  22. birder/layers/attention_pool.py +2 -2
  23. birder/layers/layer_scale.py +1 -1
  24. birder/model_registry/model_registry.py +2 -1
  25. birder/net/__init__.py +4 -10
  26. birder/net/_rope_vit_configs.py +435 -0
  27. birder/net/_vit_configs.py +466 -0
  28. birder/net/alexnet.py +5 -5
  29. birder/net/base.py +28 -3
  30. birder/net/biformer.py +18 -17
  31. birder/net/cait.py +7 -7
  32. birder/net/cas_vit.py +1 -1
  33. birder/net/coat.py +27 -27
  34. birder/net/conv2former.py +3 -3
  35. birder/net/convmixer.py +1 -1
  36. birder/net/convnext_v1.py +3 -11
  37. birder/net/convnext_v1_iso.py +198 -0
  38. birder/net/convnext_v2.py +2 -10
  39. birder/net/crossformer.py +9 -9
  40. birder/net/crossvit.py +6 -6
  41. birder/net/cspnet.py +1 -1
  42. birder/net/cswin_transformer.py +10 -10
  43. birder/net/davit.py +11 -11
  44. birder/net/deit.py +68 -29
  45. birder/net/deit3.py +69 -204
  46. birder/net/densenet.py +9 -8
  47. birder/net/detection/__init__.py +4 -0
  48. birder/net/detection/{yolo_anchors.py → _yolo_anchors.py} +5 -5
  49. birder/net/detection/base.py +6 -5
  50. birder/net/detection/deformable_detr.py +31 -30
  51. birder/net/detection/detr.py +14 -11
  52. birder/net/detection/efficientdet.py +10 -29
  53. birder/net/detection/faster_rcnn.py +22 -22
  54. birder/net/detection/fcos.py +8 -8
  55. birder/net/detection/plain_detr.py +852 -0
  56. birder/net/detection/retinanet.py +4 -4
  57. birder/net/detection/rt_detr_v1.py +81 -25
  58. birder/net/detection/rt_detr_v2.py +1147 -0
  59. birder/net/detection/ssd.py +5 -5
  60. birder/net/detection/yolo_v2.py +12 -12
  61. birder/net/detection/yolo_v3.py +19 -19
  62. birder/net/detection/yolo_v4.py +16 -16
  63. birder/net/detection/yolo_v4_tiny.py +3 -3
  64. birder/net/dpn.py +1 -2
  65. birder/net/edgenext.py +5 -4
  66. birder/net/edgevit.py +13 -14
  67. birder/net/efficientformer_v1.py +3 -2
  68. birder/net/efficientformer_v2.py +18 -31
  69. birder/net/efficientnet_v2.py +3 -0
  70. birder/net/efficientvim.py +9 -9
  71. birder/net/efficientvit_mit.py +7 -7
  72. birder/net/efficientvit_msft.py +3 -3
  73. birder/net/fasternet.py +3 -3
  74. birder/net/fastvit.py +5 -12
  75. birder/net/flexivit.py +50 -58
  76. birder/net/focalnet.py +5 -9
  77. birder/net/gc_vit.py +11 -11
  78. birder/net/ghostnet_v1.py +1 -1
  79. birder/net/ghostnet_v2.py +1 -1
  80. birder/net/groupmixformer.py +13 -13
  81. birder/net/hgnet_v1.py +6 -6
  82. birder/net/hgnet_v2.py +4 -4
  83. birder/net/hiera.py +6 -6
  84. birder/net/hieradet.py +9 -9
  85. birder/net/hornet.py +3 -3
  86. birder/net/iformer.py +4 -4
  87. birder/net/inception_next.py +5 -15
  88. birder/net/inception_resnet_v1.py +3 -3
  89. birder/net/inception_resnet_v2.py +7 -4
  90. birder/net/inception_v3.py +3 -0
  91. birder/net/inception_v4.py +3 -0
  92. birder/net/levit.py +3 -3
  93. birder/net/lit_v1.py +13 -15
  94. birder/net/lit_v1_tiny.py +9 -9
  95. birder/net/lit_v2.py +14 -15
  96. birder/net/maxvit.py +11 -23
  97. birder/net/metaformer.py +5 -5
  98. birder/net/mim/crossmae.py +6 -6
  99. birder/net/mim/fcmae.py +3 -5
  100. birder/net/mim/mae_hiera.py +7 -7
  101. birder/net/mim/mae_vit.py +4 -6
  102. birder/net/mim/simmim.py +3 -4
  103. birder/net/mobilenet_v1.py +0 -9
  104. birder/net/mobilenet_v2.py +38 -44
  105. birder/net/{mobilenet_v3_large.py → mobilenet_v3.py} +37 -10
  106. birder/net/mobilenet_v4_hybrid.py +4 -4
  107. birder/net/mobileone.py +5 -12
  108. birder/net/mobilevit_v1.py +7 -34
  109. birder/net/mobilevit_v2.py +6 -54
  110. birder/net/moganet.py +8 -5
  111. birder/net/mvit_v2.py +30 -30
  112. birder/net/nextvit.py +2 -2
  113. birder/net/nfnet.py +4 -0
  114. birder/net/pit.py +11 -26
  115. birder/net/pvt_v1.py +9 -9
  116. birder/net/pvt_v2.py +10 -16
  117. birder/net/regionvit.py +15 -15
  118. birder/net/regnet.py +1 -1
  119. birder/net/repghost.py +5 -35
  120. birder/net/repvgg.py +3 -5
  121. birder/net/repvit.py +2 -2
  122. birder/net/resmlp.py +2 -2
  123. birder/net/resnest.py +4 -1
  124. birder/net/resnet_v1.py +125 -1
  125. birder/net/resnet_v2.py +75 -1
  126. birder/net/resnext.py +35 -1
  127. birder/net/rope_deit3.py +62 -151
  128. birder/net/rope_flexivit.py +46 -33
  129. birder/net/rope_vit.py +44 -758
  130. birder/net/sequencer2d.py +3 -4
  131. birder/net/shufflenet_v1.py +1 -1
  132. birder/net/shufflenet_v2.py +1 -1
  133. birder/net/simple_vit.py +69 -21
  134. birder/net/smt.py +8 -8
  135. birder/net/squeezenet.py +5 -12
  136. birder/net/squeezenext.py +0 -24
  137. birder/net/ssl/barlow_twins.py +1 -1
  138. birder/net/ssl/byol.py +2 -2
  139. birder/net/ssl/capi.py +4 -4
  140. birder/net/ssl/data2vec.py +1 -1
  141. birder/net/ssl/data2vec2.py +1 -1
  142. birder/net/ssl/dino_v2.py +13 -3
  143. birder/net/ssl/franca.py +28 -4
  144. birder/net/ssl/i_jepa.py +5 -5
  145. birder/net/ssl/ibot.py +1 -1
  146. birder/net/ssl/mmcr.py +1 -1
  147. birder/net/swiftformer.py +13 -3
  148. birder/net/swin_transformer_v1.py +4 -5
  149. birder/net/swin_transformer_v2.py +5 -8
  150. birder/net/tiny_vit.py +6 -19
  151. birder/net/transnext.py +19 -19
  152. birder/net/uniformer.py +4 -4
  153. birder/net/van.py +2 -2
  154. birder/net/vgg.py +1 -10
  155. birder/net/vit.py +72 -987
  156. birder/net/vit_parallel.py +35 -20
  157. birder/net/vit_sam.py +23 -48
  158. birder/net/vovnet_v2.py +1 -1
  159. birder/net/xcit.py +16 -13
  160. birder/ops/msda.py +4 -4
  161. birder/ops/swattention.py +10 -10
  162. birder/results/classification.py +3 -3
  163. birder/results/gui.py +8 -8
  164. birder/scripts/benchmark.py +37 -12
  165. birder/scripts/evaluate.py +1 -1
  166. birder/scripts/predict.py +3 -3
  167. birder/scripts/predict_detection.py +2 -2
  168. birder/scripts/train.py +63 -15
  169. birder/scripts/train_barlow_twins.py +10 -7
  170. birder/scripts/train_byol.py +10 -7
  171. birder/scripts/train_capi.py +15 -10
  172. birder/scripts/train_data2vec.py +10 -7
  173. birder/scripts/train_data2vec2.py +10 -7
  174. birder/scripts/train_detection.py +29 -14
  175. birder/scripts/train_dino_v1.py +13 -9
  176. birder/scripts/train_dino_v2.py +27 -14
  177. birder/scripts/train_dino_v2_dist.py +28 -15
  178. birder/scripts/train_franca.py +16 -9
  179. birder/scripts/train_i_jepa.py +12 -9
  180. birder/scripts/train_ibot.py +15 -11
  181. birder/scripts/train_kd.py +64 -17
  182. birder/scripts/train_mim.py +11 -8
  183. birder/scripts/train_mmcr.py +11 -8
  184. birder/scripts/train_rotnet.py +11 -7
  185. birder/scripts/train_simclr.py +10 -7
  186. birder/scripts/train_vicreg.py +10 -7
  187. birder/tools/adversarial.py +4 -4
  188. birder/tools/auto_anchors.py +5 -5
  189. birder/tools/avg_model.py +1 -1
  190. birder/tools/convert_model.py +30 -22
  191. birder/tools/det_results.py +1 -1
  192. birder/tools/download_model.py +1 -1
  193. birder/tools/ensemble_model.py +1 -1
  194. birder/tools/introspection.py +12 -3
  195. birder/tools/labelme_to_coco.py +2 -2
  196. birder/tools/model_info.py +15 -15
  197. birder/tools/pack.py +8 -8
  198. birder/tools/quantize_model.py +53 -4
  199. birder/tools/results.py +2 -2
  200. birder/tools/show_det_iterator.py +19 -6
  201. birder/tools/show_iterator.py +2 -2
  202. birder/tools/similarity.py +5 -5
  203. birder/tools/stats.py +4 -6
  204. birder/tools/voc_to_coco.py +1 -1
  205. birder/version.py +1 -1
  206. {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/METADATA +3 -3
  207. birder-0.4.1.dist-info/RECORD +300 -0
  208. {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/WHEEL +1 -1
  209. birder/net/mobilenet_v3_small.py +0 -43
  210. birder/net/se_resnet_v1.py +0 -105
  211. birder/net/se_resnet_v2.py +0 -59
  212. birder/net/se_resnext.py +0 -30
  213. birder-0.3.3.dist-info/RECORD +0 -299
  214. {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/entry_points.txt +0 -0
  215. {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/licenses/LICENSE +0 -0
  216. {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/top_level.txt +0 -0
birder/net/deit3.py CHANGED
@@ -15,12 +15,19 @@ from torch import nn
15
15
 
16
16
  from birder.common.masking import mask_tensor
17
17
  from birder.model_registry import registry
18
+ from birder.net._vit_configs import BASE
19
+ from birder.net._vit_configs import HUGE
20
+ from birder.net._vit_configs import LARGE
21
+ from birder.net._vit_configs import MEDIUM
22
+ from birder.net._vit_configs import SMALL
23
+ from birder.net._vit_configs import TINY
18
24
  from birder.net.base import DetectorBackbone
19
25
  from birder.net.base import MaskedTokenOmissionMixin
20
26
  from birder.net.base import MaskedTokenRetentionMixin
21
27
  from birder.net.base import PreTrainEncoder
22
28
  from birder.net.base import TokenOmissionResultType
23
29
  from birder.net.base import TokenRetentionResultType
30
+ from birder.net.base import normalize_out_indices
24
31
  from birder.net.vit import Encoder
25
32
  from birder.net.vit import EncoderBlock
26
33
  from birder.net.vit import PatchEmbed
@@ -53,6 +60,7 @@ class DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedT
53
60
  mlp_dim: int = self.config["mlp_dim"]
54
61
  layer_scale_init_value: Optional[float] = self.config.get("layer_scale_init_value", 1e-5)
55
62
  num_reg_tokens: int = self.config.get("num_reg_tokens", 0)
63
+ out_indices: Optional[list[int]] = self.config.get("out_indices", None)
56
64
  drop_path_rate: float = self.config["drop_path_rate"]
57
65
 
58
66
  torch._assert(image_size[0] % patch_size == 0, "Input shape indivisible by patch size!")
@@ -64,6 +72,7 @@ class DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedT
64
72
  self.num_reg_tokens = num_reg_tokens
65
73
  self.num_special_tokens = 1 + self.num_reg_tokens
66
74
  self.pos_embed_special_tokens = pos_embed_special_tokens
75
+ self.out_indices = normalize_out_indices(out_indices, num_layers)
67
76
  dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_layers)] # Stochastic depth decay rule
68
77
 
69
78
  self.conv_proj = nn.Conv2d(
@@ -72,7 +81,6 @@ class DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedT
72
81
  kernel_size=(patch_size, patch_size),
73
82
  stride=(patch_size, patch_size),
74
83
  padding=(0, 0),
75
- bias=True,
76
84
  )
77
85
  self.patch_embed = PatchEmbed()
78
86
 
@@ -106,8 +114,9 @@ class DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedT
106
114
  )
107
115
  self.norm = nn.LayerNorm(hidden_dim, eps=1e-6)
108
116
 
109
- self.return_stages = ["neck"] # Actually meaningless, just for completeness
110
- self.return_channels = [hidden_dim]
117
+ num_return_stages = len(self.out_indices) if self.out_indices is not None else 1
118
+ self.return_stages = [f"stage{stage_idx + 1}" for stage_idx in range(num_return_stages)]
119
+ self.return_channels = [hidden_dim] * num_return_stages
111
120
  self.embedding_size = hidden_dim
112
121
  self.classifier = self.create_classifier()
113
122
 
@@ -153,7 +162,7 @@ class DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedT
153
162
  )
154
163
 
155
164
  def detection_features(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
156
- (H, W) = x.shape[-2:]
165
+ H, W = x.shape[-2:]
157
166
 
158
167
  x = self.conv_proj(x)
159
168
  x = self.patch_embed(x)
@@ -170,15 +179,20 @@ class DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedT
170
179
  x = x + self._get_pos_embed(H, W)
171
180
  x = torch.concat([batch_special_tokens, x], dim=1)
172
181
 
173
- x = self.encoder(x)
174
- x = self.norm(x)
182
+ if self.out_indices is None:
183
+ xs = [self.encoder(x)]
184
+ else:
185
+ xs = self.encoder.forward_features(x, out_indices=self.out_indices)
175
186
 
176
- x = x[:, self.num_special_tokens :]
177
- x = x.permute(0, 2, 1)
178
- (B, C, _) = x.size()
179
- x = x.reshape(B, C, H // self.patch_size, W // self.patch_size)
187
+ out: dict[str, torch.Tensor] = {}
188
+ for stage_name, stage_x in zip(self.return_stages, xs):
189
+ stage_x = stage_x[:, self.num_special_tokens :]
190
+ stage_x = stage_x.permute(0, 2, 1)
191
+ B, C, _ = stage_x.size()
192
+ stage_x = stage_x.reshape(B, C, H // self.patch_size, W // self.patch_size)
193
+ out[stage_name] = stage_x
180
194
 
181
- return {self.return_stages[0]: x}
195
+ return out
182
196
 
183
197
  def freeze_stages(self, up_to_stage: int) -> None:
184
198
  for param in self.conv_proj.parameters():
@@ -193,6 +207,10 @@ class DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedT
193
207
  for param in module.parameters():
194
208
  param.requires_grad_(False)
195
209
 
210
+ def transform_to_backbone(self) -> None:
211
+ super().transform_to_backbone()
212
+ self.norm = nn.Identity()
213
+
196
214
  def set_causal_attention(self, is_causal: bool = True) -> None:
197
215
  self.encoder.set_causal_attention(is_causal)
198
216
 
@@ -203,7 +221,7 @@ class DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedT
203
221
  return_all_features: bool = False,
204
222
  return_keys: Literal["all", "tokens", "embedding"] = "tokens",
205
223
  ) -> TokenOmissionResultType:
206
- (H, W) = x.shape[-2:]
224
+ H, W = x.shape[-2:]
207
225
 
208
226
  # Reshape and permute the input tensor
209
227
  x = self.conv_proj(x)
@@ -266,7 +284,7 @@ class DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedT
266
284
  mask_token: Optional[torch.Tensor] = None,
267
285
  return_keys: Literal["all", "features", "embedding"] = "features",
268
286
  ) -> TokenRetentionResultType:
269
- (H, W) = x.shape[-2:]
287
+ H, W = x.shape[-2:]
270
288
 
271
289
  x = self.conv_proj(x)
272
290
  x = mask_tensor(x, mask, mask_token=mask_token, patch_factor=self.max_stride // self.stem_stride)
@@ -296,7 +314,7 @@ class DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedT
296
314
  if return_keys in ("all", "features"):
297
315
  features = x[:, self.num_special_tokens :]
298
316
  features = features.permute(0, 2, 1)
299
- (B, C, _) = features.size()
317
+ B, C, _ = features.size()
300
318
  features = features.reshape(B, C, H // self.patch_size, W // self.patch_size)
301
319
  result["features"] = features
302
320
 
@@ -306,7 +324,7 @@ class DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedT
306
324
  return result
307
325
 
308
326
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
309
- (H, W) = x.shape[-2:]
327
+ H, W = x.shape[-2:]
310
328
 
311
329
  # Reshape and permute the input tensor
312
330
  x = self.conv_proj(x)
@@ -368,279 +386,126 @@ class DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedT
368
386
  registry.register_model_config(
369
387
  "deit3_t16",
370
388
  DeiT3,
371
- config={
372
- "patch_size": 16,
373
- "num_layers": 12,
374
- "num_heads": 3,
375
- "hidden_dim": 192,
376
- "mlp_dim": 768,
377
- "drop_path_rate": 0.0,
378
- },
389
+ config={"patch_size": 16, **TINY},
390
+ )
391
+ registry.register_model_config(
392
+ "deit3_t14",
393
+ DeiT3,
394
+ config={"patch_size": 14, **TINY},
379
395
  )
380
396
  registry.register_model_config(
381
397
  "deit3_s16",
382
398
  DeiT3,
383
- config={
384
- "patch_size": 16,
385
- "num_layers": 12,
386
- "num_heads": 6,
387
- "hidden_dim": 384,
388
- "mlp_dim": 1536,
389
- "drop_path_rate": 0.05,
390
- },
399
+ config={"patch_size": 16, **SMALL, "drop_path_rate": 0.05},
391
400
  )
392
401
  registry.register_model_config(
393
402
  "deit3_s14",
394
403
  DeiT3,
395
- config={
396
- "patch_size": 14,
397
- "num_layers": 12,
398
- "num_heads": 6,
399
- "hidden_dim": 384,
400
- "mlp_dim": 1536,
401
- "drop_path_rate": 0.05,
402
- },
404
+ config={"patch_size": 14, **SMALL, "drop_path_rate": 0.05},
403
405
  )
404
406
  registry.register_model_config(
405
407
  "deit3_m16",
406
408
  DeiT3,
407
- config={
408
- "patch_size": 16,
409
- "num_layers": 12,
410
- "num_heads": 8,
411
- "hidden_dim": 512,
412
- "mlp_dim": 2048,
413
- "drop_path_rate": 0.1,
414
- },
409
+ config={"patch_size": 16, **MEDIUM, "drop_path_rate": 0.1},
415
410
  )
416
411
  registry.register_model_config(
417
412
  "deit3_m14",
418
413
  DeiT3,
419
- config={
420
- "patch_size": 14,
421
- "num_layers": 12,
422
- "num_heads": 8,
423
- "hidden_dim": 512,
424
- "mlp_dim": 2048,
425
- "drop_path_rate": 0.1,
426
- },
414
+ config={"patch_size": 14, **MEDIUM, "drop_path_rate": 0.1},
427
415
  )
428
416
  registry.register_model_config(
429
417
  "deit3_b16",
430
418
  DeiT3,
431
- config={
432
- "patch_size": 16,
433
- "num_layers": 12,
434
- "num_heads": 12,
435
- "hidden_dim": 768,
436
- "mlp_dim": 3072,
437
- "drop_path_rate": 0.2,
438
- },
419
+ config={"patch_size": 16, **BASE, "drop_path_rate": 0.2},
439
420
  )
440
421
  registry.register_model_config(
441
422
  "deit3_b14",
442
423
  DeiT3,
443
- config={
444
- "patch_size": 14,
445
- "num_layers": 12,
446
- "num_heads": 12,
447
- "hidden_dim": 768,
448
- "mlp_dim": 3072,
449
- "drop_path_rate": 0.2,
450
- },
424
+ config={"patch_size": 14, **BASE, "drop_path_rate": 0.2},
451
425
  )
452
426
  registry.register_model_config(
453
427
  "deit3_l16",
454
428
  DeiT3,
455
- config={
456
- "patch_size": 16,
457
- "num_layers": 24,
458
- "num_heads": 16,
459
- "hidden_dim": 1024,
460
- "mlp_dim": 4096,
461
- "drop_path_rate": 0.45,
462
- },
429
+ config={"patch_size": 16, **LARGE, "drop_path_rate": 0.45},
463
430
  )
464
431
  registry.register_model_config(
465
432
  "deit3_l14",
466
433
  DeiT3,
467
- config={
468
- "patch_size": 14,
469
- "num_layers": 24,
470
- "num_heads": 16,
471
- "hidden_dim": 1024,
472
- "mlp_dim": 4096,
473
- "drop_path_rate": 0.45,
474
- },
434
+ config={"patch_size": 14, **LARGE, "drop_path_rate": 0.45},
475
435
  )
476
436
  registry.register_model_config(
477
437
  "deit3_h16",
478
438
  DeiT3,
479
- config={
480
- "patch_size": 16,
481
- "num_layers": 32,
482
- "num_heads": 16,
483
- "hidden_dim": 1280,
484
- "mlp_dim": 5120,
485
- "drop_path_rate": 0.55,
486
- },
439
+ config={"patch_size": 16, **HUGE, "drop_path_rate": 0.55},
487
440
  )
488
441
  registry.register_model_config(
489
442
  "deit3_h14",
490
443
  DeiT3,
491
- config={
492
- "patch_size": 14,
493
- "num_layers": 32,
494
- "num_heads": 16,
495
- "hidden_dim": 1280,
496
- "mlp_dim": 5120,
497
- "drop_path_rate": 0.55,
498
- },
444
+ config={"patch_size": 14, **HUGE, "drop_path_rate": 0.55},
499
445
  )
500
446
 
501
447
  # With registers
448
+ ####################
449
+
502
450
  registry.register_model_config(
503
451
  "deit3_reg4_t16",
504
452
  DeiT3,
505
- config={
506
- "patch_size": 16,
507
- "num_layers": 12,
508
- "num_heads": 3,
509
- "hidden_dim": 192,
510
- "mlp_dim": 768,
511
- "num_reg_tokens": 4,
512
- "drop_path_rate": 0.0,
513
- },
453
+ config={"patch_size": 16, **TINY, "num_reg_tokens": 4},
454
+ )
455
+ registry.register_model_config(
456
+ "deit3_reg4_t14",
457
+ DeiT3,
458
+ config={"patch_size": 14, **TINY, "num_reg_tokens": 4},
514
459
  )
515
460
  registry.register_model_config(
516
461
  "deit3_reg4_s16",
517
462
  DeiT3,
518
- config={
519
- "patch_size": 16,
520
- "num_layers": 12,
521
- "num_heads": 6,
522
- "hidden_dim": 384,
523
- "mlp_dim": 1536,
524
- "num_reg_tokens": 4,
525
- "drop_path_rate": 0.05,
526
- },
463
+ config={"patch_size": 16, **SMALL, "num_reg_tokens": 4, "drop_path_rate": 0.05},
527
464
  )
528
465
  registry.register_model_config(
529
466
  "deit3_reg4_s14",
530
467
  DeiT3,
531
- config={
532
- "patch_size": 14,
533
- "num_layers": 12,
534
- "num_heads": 6,
535
- "hidden_dim": 384,
536
- "mlp_dim": 1536,
537
- "num_reg_tokens": 4,
538
- "drop_path_rate": 0.05,
539
- },
468
+ config={"patch_size": 14, **SMALL, "num_reg_tokens": 4, "drop_path_rate": 0.05},
540
469
  )
541
470
  registry.register_model_config(
542
471
  "deit3_reg4_m16",
543
472
  DeiT3,
544
- config={
545
- "patch_size": 16,
546
- "num_layers": 12,
547
- "num_heads": 8,
548
- "hidden_dim": 512,
549
- "mlp_dim": 2048,
550
- "num_reg_tokens": 4,
551
- "drop_path_rate": 0.1,
552
- },
473
+ config={"patch_size": 16, **MEDIUM, "num_reg_tokens": 4, "drop_path_rate": 0.1},
553
474
  )
554
475
  registry.register_model_config(
555
476
  "deit3_reg4_m14",
556
477
  DeiT3,
557
- config={
558
- "patch_size": 14,
559
- "num_layers": 12,
560
- "num_heads": 8,
561
- "hidden_dim": 512,
562
- "mlp_dim": 2048,
563
- "num_reg_tokens": 4,
564
- "drop_path_rate": 0.1,
565
- },
478
+ config={"patch_size": 14, **MEDIUM, "num_reg_tokens": 4, "drop_path_rate": 0.1},
566
479
  )
567
480
  registry.register_model_config(
568
481
  "deit3_reg4_b16",
569
482
  DeiT3,
570
- config={
571
- "patch_size": 16,
572
- "num_layers": 12,
573
- "num_heads": 12,
574
- "hidden_dim": 768,
575
- "mlp_dim": 3072,
576
- "num_reg_tokens": 4,
577
- "drop_path_rate": 0.2,
578
- },
483
+ config={"patch_size": 16, **BASE, "num_reg_tokens": 4, "drop_path_rate": 0.2},
579
484
  )
580
485
  registry.register_model_config(
581
486
  "deit3_reg4_b14",
582
487
  DeiT3,
583
- config={
584
- "patch_size": 14,
585
- "num_layers": 12,
586
- "num_heads": 12,
587
- "hidden_dim": 768,
588
- "mlp_dim": 3072,
589
- "num_reg_tokens": 4,
590
- "drop_path_rate": 0.2,
591
- },
488
+ config={"patch_size": 14, **BASE, "num_reg_tokens": 4, "drop_path_rate": 0.2},
592
489
  )
593
490
  registry.register_model_config(
594
491
  "deit3_reg4_l16",
595
492
  DeiT3,
596
- config={
597
- "patch_size": 16,
598
- "num_layers": 24,
599
- "num_heads": 16,
600
- "hidden_dim": 1024,
601
- "mlp_dim": 4096,
602
- "num_reg_tokens": 4,
603
- "drop_path_rate": 0.45,
604
- },
493
+ config={"patch_size": 16, **LARGE, "num_reg_tokens": 4, "drop_path_rate": 0.45},
605
494
  )
606
495
  registry.register_model_config(
607
496
  "deit3_reg4_l14",
608
497
  DeiT3,
609
- config={
610
- "patch_size": 14,
611
- "num_layers": 24,
612
- "num_heads": 16,
613
- "hidden_dim": 1024,
614
- "mlp_dim": 4096,
615
- "num_reg_tokens": 4,
616
- "drop_path_rate": 0.45,
617
- },
498
+ config={"patch_size": 14, **LARGE, "num_reg_tokens": 4, "drop_path_rate": 0.45},
618
499
  )
619
500
  registry.register_model_config(
620
501
  "deit3_reg4_h16",
621
502
  DeiT3,
622
- config={
623
- "patch_size": 16,
624
- "num_layers": 32,
625
- "num_heads": 16,
626
- "hidden_dim": 1280,
627
- "mlp_dim": 5120,
628
- "num_reg_tokens": 4,
629
- "drop_path_rate": 0.55,
630
- },
503
+ config={"patch_size": 16, **HUGE, "num_reg_tokens": 4, "drop_path_rate": 0.55},
631
504
  )
632
505
  registry.register_model_config(
633
506
  "deit3_reg4_h14",
634
507
  DeiT3,
635
- config={
636
- "patch_size": 14,
637
- "num_layers": 32,
638
- "num_heads": 16,
639
- "hidden_dim": 1280,
640
- "mlp_dim": 5120,
641
- "num_reg_tokens": 4,
642
- "drop_path_rate": 0.55,
643
- },
508
+ config={"patch_size": 14, **HUGE, "num_reg_tokens": 4, "drop_path_rate": 0.55},
644
509
  )
645
510
 
646
511
  registry.register_weights(
@@ -651,7 +516,7 @@ registry.register_weights(
651
516
  "formats": {
652
517
  "pt": {
653
518
  "file_size": 21.5,
654
- "sha256": "6cd9749a9522f8ff61088e38702553fb1c4d2547b417c499652e3bfa6a81e77a",
519
+ "sha256": "a04141c7f6c459ae075a48ccdee5b82d191bbaa82337673140c06ef82f0a8dc5",
655
520
  }
656
521
  },
657
522
  "net": {"network": "deit3_t16", "tag": "il-common"},
@@ -665,7 +530,7 @@ registry.register_weights(
665
530
  "formats": {
666
531
  "pt": {
667
532
  "file_size": 21.5,
668
- "sha256": "6806a5ae7d45f1c84b25e9869a9cbc7de94368fe9573dc3777acf2da8c83dc4e",
533
+ "sha256": "d26320462da64df6d62b307f7fb35d09c86a5f073002dfb24a51f014074e65c3",
669
534
  }
670
535
  },
671
536
  "net": {"network": "deit3_reg4_t16", "tag": "il-common"},
birder/net/densenet.py CHANGED
@@ -104,19 +104,20 @@ class DenseNet(DetectorBackbone):
104
104
  num_features = num_init_features
105
105
  stages: OrderedDict[str, nn.Module] = OrderedDict()
106
106
  return_channels: list[int] = []
107
- layers = []
108
107
  for i, num_layers in enumerate(layer_list):
108
+ stage_layers = []
109
+ if i != 0:
110
+ stage_layers.append(TransitionBlock(num_features, num_features // 2))
111
+ num_features = num_features // 2
109
112
 
110
- layers.append(DenseBlock(num_features, num_layers=num_layers, growth_rate=growth_rate))
113
+ stage_layers.append(DenseBlock(num_features, num_layers=num_layers, growth_rate=growth_rate))
111
114
  num_features = num_features + (num_layers * growth_rate)
115
+ if i == len(layer_list) - 1:
116
+ stage_layers.append(nn.BatchNorm2d(num_features))
117
+ stage_layers.append(nn.ReLU(inplace=True))
112
118
 
113
- stages[f"stage{i+1}"] = nn.Sequential(*layers)
119
+ stages[f"stage{i+1}"] = nn.Sequential(*stage_layers)
114
120
  return_channels.append(num_features)
115
- layers = []
116
-
117
- if i != len(layer_list) - 1:
118
- layers.append(TransitionBlock(num_features, num_features // 2))
119
- num_features = num_features // 2
120
121
 
121
122
  self.body = nn.Sequential(stages)
122
123
  self.features = nn.Sequential(
@@ -3,8 +3,10 @@ from birder.net.detection.detr import DETR
3
3
  from birder.net.detection.efficientdet import EfficientDet
4
4
  from birder.net.detection.faster_rcnn import Faster_RCNN
5
5
  from birder.net.detection.fcos import FCOS
6
+ from birder.net.detection.plain_detr import Plain_DETR
6
7
  from birder.net.detection.retinanet import RetinaNet
7
8
  from birder.net.detection.rt_detr_v1 import RT_DETR_v1
9
+ from birder.net.detection.rt_detr_v2 import RT_DETR_v2
8
10
  from birder.net.detection.ssd import SSD
9
11
  from birder.net.detection.ssdlite import SSDLite
10
12
  from birder.net.detection.vitdet import ViTDet
@@ -19,8 +21,10 @@ __all__ = [
19
21
  "EfficientDet",
20
22
  "Faster_RCNN",
21
23
  "FCOS",
24
+ "Plain_DETR",
22
25
  "RetinaNet",
23
26
  "RT_DETR_v1",
27
+ "RT_DETR_v2",
24
28
  "SSD",
25
29
  "SSDLite",
26
30
  "ViTDet",
@@ -71,7 +71,7 @@ def scale_anchors(anchors: AnchorGroups, from_size: tuple[int, int], to_size: tu
71
71
 
72
72
 
73
73
  def scale_anchors(anchors: AnchorLike, from_size: tuple[int, int], to_size: tuple[int, int]) -> AnchorLike:
74
- (anchor_groups, single) = _normalize_anchor_groups(anchors)
74
+ anchor_groups, single = _normalize_anchor_groups(anchors)
75
75
 
76
76
  if from_size == to_size:
77
77
  # Avoid aliasing default anchors in case they are mutated later
@@ -100,7 +100,7 @@ def pixels_to_grid(anchors: AnchorGroups, strides: Sequence[int]) -> AnchorGroup
100
100
 
101
101
 
102
102
  def pixels_to_grid(anchors: AnchorLike, strides: Sequence[int]) -> AnchorLike:
103
- (anchor_groups, single) = _normalize_anchor_groups(anchors)
103
+ anchor_groups, single = _normalize_anchor_groups(anchors)
104
104
  if len(anchor_groups) != len(strides):
105
105
  raise ValueError("strides must provide one value per anchor scale")
106
106
 
@@ -123,7 +123,7 @@ def grid_to_pixels(anchors: AnchorGroups, strides: Sequence[int]) -> AnchorGroup
123
123
 
124
124
 
125
125
  def grid_to_pixels(anchors: AnchorLike, strides: Sequence[int]) -> AnchorLike:
126
- (anchor_groups, single) = _normalize_anchor_groups(anchors)
126
+ anchor_groups, single = _normalize_anchor_groups(anchors)
127
127
  if len(anchor_groups) != len(strides):
128
128
  raise ValueError("strides must provide one value per anchor scale")
129
129
 
@@ -187,7 +187,7 @@ def resolve_anchor_group(
187
187
  preset: str, *, anchor_format: str, model_size: tuple[int, int], model_strides: Sequence[int]
188
188
  ) -> AnchorGroup:
189
189
  anchors = _resolve_anchors(preset, anchor_format=anchor_format, model_size=model_size, model_strides=model_strides)
190
- (anchor_groups, single) = _normalize_anchor_groups(anchors)
190
+ anchor_groups, single = _normalize_anchor_groups(anchors)
191
191
  if single is False:
192
192
  raise ValueError("Expected a single anchor group for this model")
193
193
 
@@ -198,7 +198,7 @@ def resolve_anchor_groups(
198
198
  preset: str, *, anchor_format: str, model_size: tuple[int, int], model_strides: Sequence[int]
199
199
  ) -> AnchorGroups:
200
200
  anchors = _resolve_anchors(preset, anchor_format=anchor_format, model_size=model_size, model_strides=model_strides)
201
- (anchor_groups, single) = _normalize_anchor_groups(anchors)
201
+ anchor_groups, single = _normalize_anchor_groups(anchors)
202
202
  if single is True:
203
203
  raise ValueError("Expected multiple anchor groups for this model")
204
204
 
@@ -41,6 +41,7 @@ def get_detection_signature(input_shape: tuple[int, ...], num_outputs: int, dyna
41
41
 
42
42
  class DetectionBaseNet(nn.Module):
43
43
  default_size: tuple[int, int]
44
+ block_group_regex: Optional[str]
44
45
  auto_register = False
45
46
  scriptable = True
46
47
  task = str(Task.OBJECT_DETECTION)
@@ -308,7 +309,7 @@ class SimpleFeaturePyramidNetwork(nn.Module):
308
309
  names.append(f"stage{idx+1}")
309
310
 
310
311
  if self.extra_blocks is not None:
311
- (results, names) = self.extra_blocks(results, [x], names)
312
+ results, names = self.extra_blocks(results, [x], names)
312
313
 
313
314
  out = OrderedDict(list(zip(names, results)))
314
315
 
@@ -432,7 +433,7 @@ class BoxCoder:
432
433
  ctr_x = boxes[:, 0] + 0.5 * widths
433
434
  ctr_y = boxes[:, 1] + 0.5 * heights
434
435
 
435
- (wx, wy, ww, wh) = self.weights
436
+ wx, wy, ww, wh = self.weights
436
437
  dx = rel_codes[:, 0::4] / wx
437
438
  dy = rel_codes[:, 1::4] / wy
438
439
  dw = rel_codes[:, 2::4] / ww
@@ -510,8 +511,8 @@ class AnchorGenerator(nn.Module):
510
511
  )
511
512
 
512
513
  for size, stride, base_anchors in zip(grid_sizes, strides, cell_anchors):
513
- (grid_height, grid_width) = size
514
- (stride_height, stride_width) = stride
514
+ grid_height, grid_width = size
515
+ stride_height, stride_width = stride
515
516
  device = base_anchors.device
516
517
 
517
518
  # For output anchor, compute [x_center, y_center, x_center, y_center]
@@ -656,7 +657,7 @@ class Matcher(nn.Module):
656
657
  # tensor([39796, 32055, 32070, 39190, 40255, 40390, 41455, 45470, 45325, 46390]))
657
658
  # Each element in the first tensor is a gt index,
658
659
  # and each element in second tensor is a prediction index
659
- # Note how gt items 1, 2, 3, and 5 each have two ties
660
+ # Note how gt items 1, 2, 3 and 5 each have two ties
660
661
 
661
662
  pred_idx_to_update = gt_pred_pairs_of_highest_quality[1]
662
663
  matches[pred_idx_to_update] = all_matches[pred_idx_to_update]