birder 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (187) hide show
  1. birder/adversarial/base.py +1 -1
  2. birder/adversarial/simba.py +4 -4
  3. birder/common/cli.py +1 -1
  4. birder/common/fs_ops.py +11 -11
  5. birder/common/lib.py +2 -2
  6. birder/common/masking.py +3 -3
  7. birder/common/training_cli.py +24 -2
  8. birder/common/training_utils.py +28 -4
  9. birder/data/collators/detection.py +9 -1
  10. birder/data/transforms/detection.py +27 -8
  11. birder/data/transforms/mosaic.py +1 -1
  12. birder/datahub/classification.py +3 -3
  13. birder/inference/classification.py +3 -3
  14. birder/inference/data_parallel.py +1 -1
  15. birder/inference/detection.py +5 -5
  16. birder/inference/wbf.py +1 -1
  17. birder/introspection/attention_rollout.py +5 -5
  18. birder/introspection/feature_pca.py +4 -4
  19. birder/introspection/gradcam.py +1 -1
  20. birder/introspection/guided_backprop.py +2 -2
  21. birder/introspection/transformer_attribution.py +3 -3
  22. birder/layers/attention_pool.py +2 -2
  23. birder/model_registry/model_registry.py +2 -1
  24. birder/net/__init__.py +2 -0
  25. birder/net/_rope_vit_configs.py +5 -0
  26. birder/net/_vit_configs.py +0 -13
  27. birder/net/alexnet.py +5 -5
  28. birder/net/base.py +28 -3
  29. birder/net/biformer.py +17 -17
  30. birder/net/cait.py +2 -2
  31. birder/net/cas_vit.py +1 -1
  32. birder/net/coat.py +15 -15
  33. birder/net/convnext_v1.py +2 -10
  34. birder/net/convnext_v1_iso.py +198 -0
  35. birder/net/convnext_v2.py +2 -10
  36. birder/net/crossformer.py +9 -9
  37. birder/net/crossvit.py +1 -1
  38. birder/net/cspnet.py +1 -1
  39. birder/net/cswin_transformer.py +10 -10
  40. birder/net/davit.py +10 -10
  41. birder/net/deit.py +56 -3
  42. birder/net/deit3.py +27 -15
  43. birder/net/detection/__init__.py +4 -0
  44. birder/net/detection/{yolo_anchors.py → _yolo_anchors.py} +5 -5
  45. birder/net/detection/base.py +6 -5
  46. birder/net/detection/deformable_detr.py +26 -28
  47. birder/net/detection/detr.py +9 -9
  48. birder/net/detection/efficientdet.py +9 -28
  49. birder/net/detection/faster_rcnn.py +22 -22
  50. birder/net/detection/fcos.py +8 -8
  51. birder/net/detection/plain_detr.py +852 -0
  52. birder/net/detection/retinanet.py +4 -4
  53. birder/net/detection/rt_detr_v1.py +81 -25
  54. birder/net/detection/rt_detr_v2.py +1147 -0
  55. birder/net/detection/ssd.py +5 -5
  56. birder/net/detection/yolo_v2.py +12 -12
  57. birder/net/detection/yolo_v3.py +19 -19
  58. birder/net/detection/yolo_v4.py +16 -16
  59. birder/net/detection/yolo_v4_tiny.py +3 -3
  60. birder/net/edgenext.py +3 -3
  61. birder/net/edgevit.py +10 -14
  62. birder/net/efficientformer_v1.py +1 -1
  63. birder/net/efficientvim.py +9 -9
  64. birder/net/efficientvit_mit.py +2 -2
  65. birder/net/efficientvit_msft.py +3 -3
  66. birder/net/fasternet.py +1 -1
  67. birder/net/fastvit.py +5 -12
  68. birder/net/flexivit.py +28 -15
  69. birder/net/focalnet.py +5 -9
  70. birder/net/gc_vit.py +11 -11
  71. birder/net/ghostnet_v1.py +1 -1
  72. birder/net/ghostnet_v2.py +1 -1
  73. birder/net/groupmixformer.py +12 -12
  74. birder/net/hgnet_v1.py +1 -1
  75. birder/net/hgnet_v2.py +4 -4
  76. birder/net/hiera.py +6 -6
  77. birder/net/hieradet.py +9 -9
  78. birder/net/hornet.py +3 -3
  79. birder/net/iformer.py +4 -4
  80. birder/net/inception_next.py +4 -14
  81. birder/net/levit.py +3 -3
  82. birder/net/lit_v1.py +13 -15
  83. birder/net/lit_v1_tiny.py +9 -9
  84. birder/net/lit_v2.py +14 -15
  85. birder/net/maxvit.py +10 -22
  86. birder/net/metaformer.py +2 -2
  87. birder/net/mim/crossmae.py +5 -5
  88. birder/net/mim/fcmae.py +3 -5
  89. birder/net/mim/mae_hiera.py +7 -7
  90. birder/net/mim/mae_vit.py +3 -5
  91. birder/net/mim/simmim.py +2 -3
  92. birder/net/mobilenet_v4_hybrid.py +4 -4
  93. birder/net/mobileone.py +5 -12
  94. birder/net/mobilevit_v1.py +2 -2
  95. birder/net/mobilevit_v2.py +5 -9
  96. birder/net/mvit_v2.py +24 -24
  97. birder/net/nextvit.py +2 -2
  98. birder/net/pit.py +11 -26
  99. birder/net/pvt_v1.py +4 -4
  100. birder/net/pvt_v2.py +5 -11
  101. birder/net/regionvit.py +15 -15
  102. birder/net/regnet.py +1 -1
  103. birder/net/repghost.py +4 -5
  104. birder/net/repvgg.py +3 -5
  105. birder/net/repvit.py +2 -2
  106. birder/net/resnest.py +1 -1
  107. birder/net/rope_deit3.py +29 -15
  108. birder/net/rope_flexivit.py +28 -15
  109. birder/net/rope_vit.py +41 -23
  110. birder/net/sequencer2d.py +3 -4
  111. birder/net/shufflenet_v1.py +1 -1
  112. birder/net/shufflenet_v2.py +1 -1
  113. birder/net/simple_vit.py +47 -5
  114. birder/net/smt.py +7 -7
  115. birder/net/ssl/barlow_twins.py +1 -1
  116. birder/net/ssl/byol.py +2 -2
  117. birder/net/ssl/capi.py +3 -3
  118. birder/net/ssl/data2vec2.py +1 -1
  119. birder/net/ssl/dino_v2.py +11 -1
  120. birder/net/ssl/franca.py +26 -2
  121. birder/net/ssl/i_jepa.py +4 -4
  122. birder/net/ssl/mmcr.py +1 -1
  123. birder/net/swiftformer.py +1 -1
  124. birder/net/swin_transformer_v1.py +4 -5
  125. birder/net/swin_transformer_v2.py +4 -7
  126. birder/net/tiny_vit.py +3 -3
  127. birder/net/transnext.py +19 -19
  128. birder/net/uniformer.py +4 -4
  129. birder/net/vgg.py +1 -10
  130. birder/net/vit.py +38 -25
  131. birder/net/vit_parallel.py +35 -20
  132. birder/net/vit_sam.py +10 -10
  133. birder/net/vovnet_v2.py +1 -1
  134. birder/net/xcit.py +9 -7
  135. birder/ops/msda.py +4 -4
  136. birder/ops/swattention.py +10 -10
  137. birder/results/classification.py +3 -3
  138. birder/results/gui.py +8 -8
  139. birder/scripts/benchmark.py +37 -12
  140. birder/scripts/evaluate.py +1 -1
  141. birder/scripts/predict.py +3 -3
  142. birder/scripts/predict_detection.py +2 -2
  143. birder/scripts/train.py +63 -15
  144. birder/scripts/train_barlow_twins.py +10 -7
  145. birder/scripts/train_byol.py +10 -7
  146. birder/scripts/train_capi.py +15 -10
  147. birder/scripts/train_data2vec.py +10 -7
  148. birder/scripts/train_data2vec2.py +10 -7
  149. birder/scripts/train_detection.py +29 -14
  150. birder/scripts/train_dino_v1.py +13 -9
  151. birder/scripts/train_dino_v2.py +27 -14
  152. birder/scripts/train_dino_v2_dist.py +28 -15
  153. birder/scripts/train_franca.py +16 -9
  154. birder/scripts/train_i_jepa.py +12 -9
  155. birder/scripts/train_ibot.py +15 -11
  156. birder/scripts/train_kd.py +64 -17
  157. birder/scripts/train_mim.py +11 -8
  158. birder/scripts/train_mmcr.py +11 -8
  159. birder/scripts/train_rotnet.py +11 -7
  160. birder/scripts/train_simclr.py +10 -7
  161. birder/scripts/train_vicreg.py +10 -7
  162. birder/tools/adversarial.py +4 -4
  163. birder/tools/auto_anchors.py +5 -5
  164. birder/tools/avg_model.py +1 -1
  165. birder/tools/convert_model.py +30 -22
  166. birder/tools/det_results.py +1 -1
  167. birder/tools/download_model.py +1 -1
  168. birder/tools/ensemble_model.py +1 -1
  169. birder/tools/introspection.py +11 -2
  170. birder/tools/labelme_to_coco.py +2 -2
  171. birder/tools/model_info.py +12 -14
  172. birder/tools/pack.py +8 -8
  173. birder/tools/quantize_model.py +53 -4
  174. birder/tools/results.py +2 -2
  175. birder/tools/show_det_iterator.py +19 -6
  176. birder/tools/show_iterator.py +2 -2
  177. birder/tools/similarity.py +5 -5
  178. birder/tools/stats.py +4 -6
  179. birder/tools/voc_to_coco.py +1 -1
  180. birder/version.py +1 -1
  181. {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/METADATA +3 -3
  182. birder-0.4.1.dist-info/RECORD +300 -0
  183. {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/WHEEL +1 -1
  184. birder-0.4.0.dist-info/RECORD +0 -297
  185. {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/entry_points.txt +0 -0
  186. {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/licenses/LICENSE +0 -0
  187. {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/top_level.txt +0 -0
birder/net/repvgg.py CHANGED
@@ -56,7 +56,6 @@ class RepVggBlock(nn.Module):
56
56
  stride=(stride, stride),
57
57
  padding=(padding, padding),
58
58
  groups=groups,
59
- bias=True,
60
59
  )
61
60
  else:
62
61
  self.reparam_conv = None
@@ -113,7 +112,7 @@ class RepVggBlock(nn.Module):
113
112
  if self.reparameterized is True:
114
113
  return
115
114
 
116
- (kernel, bias) = self._get_kernel_bias()
115
+ kernel, bias = self._get_kernel_bias()
117
116
  self.reparam_conv = nn.Conv2d(
118
117
  in_channels=self.conv_kxk.conv.in_channels,
119
118
  out_channels=self.conv_kxk.conv.out_channels,
@@ -122,7 +121,6 @@ class RepVggBlock(nn.Module):
122
121
  padding=self.conv_kxk.conv.padding,
123
122
  dilation=self.conv_kxk.conv.dilation,
124
123
  groups=self.conv_kxk.conv.groups,
125
- bias=True,
126
124
  )
127
125
  self.reparam_conv.weight.data = kernel
128
126
  self.reparam_conv.bias.data = bias
@@ -151,10 +149,10 @@ class RepVggBlock(nn.Module):
151
149
  kernel_identity = 0
152
150
  bias_identity = 0
153
151
  if self.rbr_identity is not None:
154
- (kernel_identity, bias_identity) = self._fuse_bn_tensor(self.rbr_identity)
152
+ kernel_identity, bias_identity = self._fuse_bn_tensor(self.rbr_identity)
155
153
 
156
154
  # Get weights and bias of conv branches
157
- (kernel_conv, bias_conv) = self._fuse_bn_tensor(self.conv_kxk)
155
+ kernel_conv, bias_conv = self._fuse_bn_tensor(self.conv_kxk)
158
156
 
159
157
  kernel_final = kernel_conv + kernel_1x1 + kernel_identity
160
158
  bias_final = bias_conv + bias_1x1 + bias_identity
birder/net/repvit.py CHANGED
@@ -60,7 +60,7 @@ class RepConvBN(nn.Sequential):
60
60
  if self.reparameterized is True:
61
61
  return
62
62
 
63
- (c, bn) = self._modules.values()
63
+ c, bn = self._modules.values()
64
64
  w = bn.weight / (bn.running_var + bn.eps) ** 0.5
65
65
  w = c.weight * w[:, None, None, None]
66
66
  b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5
@@ -101,7 +101,7 @@ class RepNormLinear(nn.Sequential):
101
101
  if self.reparameterized is True:
102
102
  return
103
103
 
104
- (bn, li) = self._modules.values()
104
+ bn, li = self._modules.values()
105
105
  w = bn.weight / (bn.running_var + bn.eps) ** 0.5
106
106
  b = bn.bias - self.bn.running_mean * self.bn.weight / (bn.running_var + bn.eps) ** 0.5
107
107
  w = li.weight * w[None, :]
birder/net/resnest.py CHANGED
@@ -85,7 +85,7 @@ class SplitAttn(nn.Module):
85
85
  def forward(self, x: torch.Tensor) -> torch.Tensor:
86
86
  x = self.conv(x)
87
87
 
88
- (B, RC, H, W) = x.size() # pylint: disable=invalid-name
88
+ B, RC, H, W = x.size() # pylint: disable=invalid-name
89
89
  if self.radix > 1:
90
90
  x = x.reshape((B, self.radix, RC // self.radix, H, W))
91
91
  x_gap = x.sum(dim=1)
birder/net/rope_deit3.py CHANGED
@@ -34,6 +34,7 @@ from birder.net.base import MaskedTokenRetentionMixin
34
34
  from birder.net.base import PreTrainEncoder
35
35
  from birder.net.base import TokenOmissionResultType
36
36
  from birder.net.base import TokenRetentionResultType
37
+ from birder.net.base import normalize_out_indices
37
38
  from birder.net.rope_vit import Encoder
38
39
  from birder.net.rope_vit import MAEDecoderBlock
39
40
  from birder.net.rope_vit import RoPE
@@ -46,6 +47,7 @@ from birder.net.vit import adjust_position_embedding
46
47
  class RoPE_DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedTokenRetentionMixin):
47
48
  block_group_regex = r"encoder\.block\.(\d+)"
48
49
 
50
+ # pylint: disable=too-many-locals
49
51
  def __init__(
50
52
  self,
51
53
  input_channels: int,
@@ -68,6 +70,7 @@ class RoPE_DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Ma
68
70
  mlp_dim: int = self.config["mlp_dim"]
69
71
  layer_scale_init_value: Optional[float] = self.config.get("layer_scale_init_value", 1e-5)
70
72
  num_reg_tokens: int = self.config.get("num_reg_tokens", 0)
73
+ out_indices: Optional[list[int]] = self.config.get("out_indices", None)
71
74
  rope_rot_type: Literal["standard", "interleaved"] = self.config.get("rope_rot_type", "standard")
72
75
  rope_grid_indexing: Literal["ij", "xy"] = self.config.get("rope_grid_indexing", "ij")
73
76
  rope_grid_offset: int = self.config.get("rope_grid_offset", 0)
@@ -86,6 +89,7 @@ class RoPE_DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Ma
86
89
  self.num_reg_tokens = num_reg_tokens
87
90
  self.num_special_tokens = 1 + self.num_reg_tokens
88
91
  self.pos_embed_special_tokens = pos_embed_special_tokens
92
+ self.out_indices = normalize_out_indices(out_indices, num_layers)
89
93
  self.rope_rot_type = rope_rot_type
90
94
  self.rope_grid_indexing = rope_grid_indexing
91
95
  self.rope_grid_offset = rope_grid_offset
@@ -105,7 +109,6 @@ class RoPE_DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Ma
105
109
  kernel_size=(patch_size, patch_size),
106
110
  stride=(patch_size, patch_size),
107
111
  padding=(0, 0),
108
- bias=True,
109
112
  )
110
113
  self.patch_embed = PatchEmbed()
111
114
 
@@ -153,8 +156,9 @@ class RoPE_DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Ma
153
156
  )
154
157
  self.norm = nn.LayerNorm(hidden_dim, eps=1e-6)
155
158
 
156
- self.return_stages = ["neck"] # Actually meaningless, but for completeness
157
- self.return_channels = [hidden_dim]
159
+ num_return_stages = len(self.out_indices) if self.out_indices is not None else 1
160
+ self.return_stages = [f"stage{stage_idx + 1}" for stage_idx in range(num_return_stages)]
161
+ self.return_channels = [hidden_dim] * num_return_stages
158
162
  self.embedding_size = hidden_dim
159
163
  self.classifier = self.create_classifier()
160
164
 
@@ -222,7 +226,7 @@ class RoPE_DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Ma
222
226
  ).to(self.rope.pos_embed.device)
223
227
 
224
228
  def detection_features(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
225
- (H, W) = x.shape[-2:]
229
+ H, W = x.shape[-2:]
226
230
  x = self.conv_proj(x)
227
231
  x = self.patch_embed(x)
228
232
 
@@ -238,15 +242,21 @@ class RoPE_DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Ma
238
242
  x = x + self._get_pos_embed(H, W)
239
243
  x = torch.concat([batch_special_tokens, x], dim=1)
240
244
 
241
- x = self.encoder(x, self._get_rope_embed(H, W))
242
- x = self.norm(x)
245
+ rope = self._get_rope_embed(H, W)
246
+ if self.out_indices is None:
247
+ xs = [self.encoder(x, rope)]
248
+ else:
249
+ xs = self.encoder.forward_features(x, rope, out_indices=self.out_indices)
243
250
 
244
- x = x[:, self.num_special_tokens :]
245
- x = x.permute(0, 2, 1)
246
- (B, C, _) = x.size()
247
- x = x.reshape(B, C, H // self.patch_size, W // self.patch_size)
251
+ out: dict[str, torch.Tensor] = {}
252
+ for stage_name, stage_x in zip(self.return_stages, xs):
253
+ stage_x = stage_x[:, self.num_special_tokens :]
254
+ stage_x = stage_x.permute(0, 2, 1)
255
+ B, C, _ = stage_x.size()
256
+ stage_x = stage_x.reshape(B, C, H // self.patch_size, W // self.patch_size)
257
+ out[stage_name] = stage_x
248
258
 
249
- return {self.return_stages[0]: x}
259
+ return out
250
260
 
251
261
  def freeze_stages(self, up_to_stage: int) -> None:
252
262
  for param in self.conv_proj.parameters():
@@ -261,6 +271,10 @@ class RoPE_DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Ma
261
271
  for param in module.parameters():
262
272
  param.requires_grad_(False)
263
273
 
274
+ def transform_to_backbone(self) -> None:
275
+ super().transform_to_backbone()
276
+ self.norm = nn.Identity()
277
+
264
278
  def set_causal_attention(self, is_causal: bool = True) -> None:
265
279
  self.encoder.set_causal_attention(is_causal)
266
280
 
@@ -271,7 +285,7 @@ class RoPE_DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Ma
271
285
  return_all_features: bool = False,
272
286
  return_keys: Literal["all", "tokens", "embedding"] = "tokens",
273
287
  ) -> TokenOmissionResultType:
274
- (H, W) = x.shape[-2:]
288
+ H, W = x.shape[-2:]
275
289
 
276
290
  # Reshape and permute the input tensor
277
291
  x = self.conv_proj(x)
@@ -340,7 +354,7 @@ class RoPE_DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Ma
340
354
  mask_token: Optional[torch.Tensor] = None,
341
355
  return_keys: Literal["all", "features", "embedding"] = "features",
342
356
  ) -> TokenRetentionResultType:
343
- (H, W) = x.shape[-2:]
357
+ H, W = x.shape[-2:]
344
358
 
345
359
  x = self.conv_proj(x)
346
360
  x = mask_tensor(x, mask, mask_token=mask_token, patch_factor=self.max_stride // self.stem_stride)
@@ -370,7 +384,7 @@ class RoPE_DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Ma
370
384
  if return_keys in ("all", "features"):
371
385
  features = x[:, self.num_special_tokens :]
372
386
  features = features.permute(0, 2, 1)
373
- (B, C, _) = features.size()
387
+ B, C, _ = features.size()
374
388
  features = features.reshape(B, C, H // self.patch_size, W // self.patch_size)
375
389
  result["features"] = features
376
390
 
@@ -380,7 +394,7 @@ class RoPE_DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Ma
380
394
  return result
381
395
 
382
396
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
383
- (H, W) = x.shape[-2:]
397
+ H, W = x.shape[-2:]
384
398
 
385
399
  # Reshape and permute the input tensor
386
400
  x = self.conv_proj(x)
@@ -29,6 +29,7 @@ from birder.net.base import MaskedTokenRetentionMixin
29
29
  from birder.net.base import PreTrainEncoder
30
30
  from birder.net.base import TokenOmissionResultType
31
31
  from birder.net.base import TokenRetentionResultType
32
+ from birder.net.base import normalize_out_indices
32
33
  from birder.net.flexivit import flex_proj
33
34
  from birder.net.flexivit import get_patch_sizes
34
35
  from birder.net.flexivit import interpolate_proj
@@ -82,6 +83,7 @@ class RoPE_FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin,
82
83
  norm_layer_eps: float = self.config.get("norm_layer_eps", 1e-6)
83
84
  mlp_layer_type: str = self.config.get("mlp_layer_type", "FFN")
84
85
  act_layer_type: Optional[str] = self.config.get("act_layer_type", None) # Default according to mlp type
86
+ out_indices: Optional[list[int]] = self.config.get("out_indices", None)
85
87
  rope_rot_type: Literal["standard", "interleaved"] = self.config.get("rope_rot_type", "standard")
86
88
  rope_grid_indexing: Literal["ij", "xy"] = self.config.get("rope_grid_indexing", "ij")
87
89
  rope_grid_offset: int = self.config.get("rope_grid_offset", 0)
@@ -125,6 +127,7 @@ class RoPE_FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin,
125
127
  self.norm_layer_eps = norm_layer_eps
126
128
  self.mlp_layer = mlp_layer
127
129
  self.act_layer = act_layer
130
+ self.out_indices = normalize_out_indices(out_indices, num_layers)
128
131
  self.rope_rot_type = rope_rot_type
129
132
  self.rope_grid_indexing = rope_grid_indexing
130
133
  self.rope_grid_offset = rope_grid_offset
@@ -145,7 +148,6 @@ class RoPE_FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin,
145
148
  kernel_size=(patch_size, patch_size),
146
149
  stride=(patch_size, patch_size),
147
150
  padding=(0, 0),
148
- bias=True,
149
151
  )
150
152
  self.patch_embed = PatchEmbed()
151
153
 
@@ -218,8 +220,9 @@ class RoPE_FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin,
218
220
 
219
221
  self.attn_pool = MultiHeadAttentionPool(hidden_dim, attn_pool_num_heads, mlp_dim, qkv_bias=True)
220
222
 
221
- self.return_stages = ["neck"] # Actually meaningless, just for completeness
222
- self.return_channels = [hidden_dim]
223
+ num_return_stages = len(self.out_indices) if self.out_indices is not None else 1
224
+ self.return_stages = [f"stage{stage_idx + 1}" for stage_idx in range(num_return_stages)]
225
+ self.return_channels = [hidden_dim] * num_return_stages
223
226
  self.embedding_size = hidden_dim
224
227
  self.classifier = self.create_classifier()
225
228
 
@@ -307,8 +310,12 @@ class RoPE_FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin,
307
310
  def set_causal_attention(self, is_causal: bool = True) -> None:
308
311
  self.encoder.set_causal_attention(is_causal)
309
312
 
313
+ def transform_to_backbone(self) -> None:
314
+ super().transform_to_backbone()
315
+ self.norm = nn.Identity()
316
+
310
317
  def detection_features(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
311
- (H, W) = x.shape[-2:]
318
+ H, W = x.shape[-2:]
312
319
  x = self.conv_proj(x)
313
320
  x = self.patch_embed(x)
314
321
 
@@ -328,15 +335,21 @@ class RoPE_FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin,
328
335
  if self.pos_embed_special_tokens is True:
329
336
  x = x + self._get_pos_embed(H, W)
330
337
 
331
- x = self.encoder(x, self._get_rope_embed(H, W))
332
- x = self.norm(x)
338
+ rope = self._get_rope_embed(H, W)
339
+ if self.out_indices is None:
340
+ xs = [self.encoder(x, rope)]
341
+ else:
342
+ xs = self.encoder.forward_features(x, rope, out_indices=self.out_indices)
333
343
 
334
- x = x[:, self.num_special_tokens :]
335
- x = x.permute(0, 2, 1)
336
- (B, C, _) = x.size()
337
- x = x.reshape(B, C, H // self.patch_size, W // self.patch_size)
344
+ out: dict[str, torch.Tensor] = {}
345
+ for stage_name, stage_x in zip(self.return_stages, xs):
346
+ stage_x = stage_x[:, self.num_special_tokens :]
347
+ stage_x = stage_x.permute(0, 2, 1)
348
+ B, C, _ = stage_x.size()
349
+ stage_x = stage_x.reshape(B, C, H // self.patch_size, W // self.patch_size)
350
+ out[stage_name] = stage_x
338
351
 
339
- return {self.return_stages[0]: x}
352
+ return out
340
353
 
341
354
  def freeze_stages(self, up_to_stage: int) -> None:
342
355
  for param in self.conv_proj.parameters():
@@ -359,7 +372,7 @@ class RoPE_FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin,
359
372
  return_all_features: bool = False,
360
373
  return_keys: Literal["all", "tokens", "embedding"] = "tokens",
361
374
  ) -> TokenOmissionResultType:
362
- (H, W) = x.shape[-2:]
375
+ H, W = x.shape[-2:]
363
376
 
364
377
  # Reshape and permute the input tensor
365
378
  x = self.conv_proj(x)
@@ -439,7 +452,7 @@ class RoPE_FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin,
439
452
  mask_token: Optional[torch.Tensor] = None,
440
453
  return_keys: Literal["all", "features", "embedding"] = "features",
441
454
  ) -> TokenRetentionResultType:
442
- (H, W) = x.shape[-2:]
455
+ H, W = x.shape[-2:]
443
456
 
444
457
  x = self.conv_proj(x)
445
458
  x = mask_tensor(x, mask, mask_token=mask_token, patch_factor=self.max_stride // self.stem_stride)
@@ -470,7 +483,7 @@ class RoPE_FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin,
470
483
  if return_keys in ("all", "features"):
471
484
  features = x[:, self.num_special_tokens :]
472
485
  features = features.permute(0, 2, 1)
473
- (B, C, _) = features.size()
486
+ B, C, _ = features.size()
474
487
  features = features.reshape(B, C, H // self.patch_size, W // self.patch_size)
475
488
  result["features"] = features
476
489
 
@@ -490,7 +503,7 @@ class RoPE_FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin,
490
503
  return result
491
504
 
492
505
  def forward_features(self, x: torch.Tensor, patch_size: Optional[int] = None) -> torch.Tensor:
493
- (H, W) = x.shape[-2:]
506
+ H, W = x.shape[-2:]
494
507
 
495
508
  # Reshape and permute the input tensor
496
509
  x = flex_proj(x, self.conv_proj.weight, self.conv_proj.bias, patch_size)
birder/net/rope_vit.py CHANGED
@@ -38,6 +38,7 @@ from birder.net.base import MaskedTokenRetentionMixin
38
38
  from birder.net.base import PreTrainEncoder
39
39
  from birder.net.base import TokenOmissionResultType
40
40
  from birder.net.base import TokenRetentionResultType
41
+ from birder.net.base import normalize_out_indices
41
42
  from birder.net.vit import PatchEmbed
42
43
  from birder.net.vit import adjust_position_embedding
43
44
 
@@ -76,7 +77,7 @@ def build_rotary_pos_embed(
76
77
 
77
78
  def rotate_half(x: torch.Tensor) -> torch.Tensor:
78
79
  # Taken from: https://github.com/facebookresearch/capi/blob/main/model.py
79
- (x1, x2) = x.chunk(2, dim=-1)
80
+ x1, x2 = x.chunk(2, dim=-1)
80
81
  return torch.concat((-x2, x1), dim=-1)
81
82
 
82
83
 
@@ -85,7 +86,7 @@ def rotate_half_interleaved(x: torch.Tensor) -> torch.Tensor:
85
86
 
86
87
 
87
88
  def apply_rotary_pos_embed(x: torch.Tensor, embed: torch.Tensor) -> torch.Tensor:
88
- (sin_emb, cos_emb) = embed.tensor_split(2, dim=-1)
89
+ sin_emb, cos_emb = embed.tensor_split(2, dim=-1)
89
90
  if cos_emb.ndim == 3:
90
91
  return x * cos_emb.unsqueeze(1).expand_as(x) + rotate_half(x) * sin_emb.unsqueeze(1).expand_as(x)
91
92
 
@@ -93,7 +94,7 @@ def apply_rotary_pos_embed(x: torch.Tensor, embed: torch.Tensor) -> torch.Tensor
93
94
 
94
95
 
95
96
  def apply_interleaved_rotary_pos_embed(x: torch.Tensor, embed: torch.Tensor) -> torch.Tensor:
96
- (sin_emb, cos_emb) = embed.tensor_split(2, dim=-1)
97
+ sin_emb, cos_emb = embed.tensor_split(2, dim=-1)
97
98
  if cos_emb.ndim == 3:
98
99
  return x * cos_emb.unsqueeze(1).expand_as(x) + rotate_half_interleaved(x) * sin_emb.unsqueeze(1).expand_as(x)
99
100
 
@@ -128,7 +129,7 @@ class RoPE(nn.Module):
128
129
  else:
129
130
  raise ValueError(f"Unknown rope_rot_type, got '{rope_rot_type}'")
130
131
 
131
- (sin_emb, cos_emb) = build_rotary_pos_embed(
132
+ sin_emb, cos_emb = build_rotary_pos_embed(
132
133
  dim,
133
134
  temperature,
134
135
  grid_size=grid_size,
@@ -185,9 +186,9 @@ class RoPEAttention(nn.Module):
185
186
  self.proj_drop = nn.Dropout(proj_drop)
186
187
 
187
188
  def forward(self, x: torch.Tensor, rope: torch.Tensor) -> torch.Tensor:
188
- (B, N, C) = x.size()
189
+ B, N, C = x.size()
189
190
  qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
190
- (q, k, v) = qkv.unbind(0)
191
+ q, k, v = qkv.unbind(0)
191
192
  q = self.q_norm(q)
192
193
  k = self.k_norm(k)
193
194
 
@@ -326,13 +327,17 @@ class Encoder(nn.Module):
326
327
  x = self.pre_block(x)
327
328
  return self.block(x, rope)
328
329
 
329
- def forward_features(self, x: torch.Tensor, rope: torch.Tensor) -> list[torch.Tensor]:
330
+ def forward_features(
331
+ self, x: torch.Tensor, rope: torch.Tensor, out_indices: Optional[list[int]] = None
332
+ ) -> list[torch.Tensor]:
330
333
  x = self.pre_block(x)
331
334
 
335
+ out_indices_set = set(out_indices) if out_indices is not None else None
332
336
  xs = []
333
- for blk in self.block:
337
+ for idx, blk in enumerate(self.block):
334
338
  x = blk(x, rope)
335
- xs.append(x)
339
+ if out_indices_set is None or idx in out_indices_set:
340
+ xs.append(x)
336
341
 
337
342
  return xs
338
343
 
@@ -438,6 +443,7 @@ class RoPE_ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Mask
438
443
  norm_layer_eps: float = self.config.get("norm_layer_eps", 1e-6)
439
444
  mlp_layer_type: str = self.config.get("mlp_layer_type", "FFN")
440
445
  act_layer_type: Optional[str] = self.config.get("act_layer_type", None) # Default according to mlp type
446
+ out_indices: Optional[list[int]] = self.config.get("out_indices", None)
441
447
  rope_rot_type: Literal["standard", "interleaved"] = self.config.get("rope_rot_type", "standard")
442
448
  rope_grid_indexing: Literal["ij", "xy"] = self.config.get("rope_grid_indexing", "ij")
443
449
  rope_grid_offset: int = self.config.get("rope_grid_offset", 0)
@@ -479,6 +485,7 @@ class RoPE_ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Mask
479
485
  self.norm_layer_eps = norm_layer_eps
480
486
  self.mlp_layer = mlp_layer
481
487
  self.act_layer = act_layer
488
+ self.out_indices = normalize_out_indices(out_indices, num_layers)
482
489
  self.rope_rot_type = rope_rot_type
483
490
  self.rope_grid_indexing = rope_grid_indexing
484
491
  self.rope_grid_offset = rope_grid_offset
@@ -571,8 +578,9 @@ class RoPE_ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Mask
571
578
 
572
579
  self.attn_pool = MultiHeadAttentionPool(hidden_dim, attn_pool_num_heads, mlp_dim, qkv_bias=True)
573
580
 
574
- self.return_stages = ["neck"] # Actually meaningless, just for completeness
575
- self.return_channels = [hidden_dim]
581
+ num_return_stages = len(self.out_indices) if self.out_indices is not None else 1
582
+ self.return_stages = [f"stage{stage_idx + 1}" for stage_idx in range(num_return_stages)]
583
+ self.return_channels = [hidden_dim] * num_return_stages
576
584
  self.embedding_size = hidden_dim
577
585
  self.classifier = self.create_classifier()
578
586
 
@@ -658,8 +666,12 @@ class RoPE_ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Mask
658
666
  def set_causal_attention(self, is_causal: bool = True) -> None:
659
667
  self.encoder.set_causal_attention(is_causal)
660
668
 
669
+ def transform_to_backbone(self) -> None:
670
+ super().transform_to_backbone()
671
+ self.norm = nn.Identity()
672
+
661
673
  def detection_features(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
662
- (H, W) = x.shape[-2:]
674
+ H, W = x.shape[-2:]
663
675
  x = self.conv_proj(x)
664
676
  x = self.patch_embed(x)
665
677
 
@@ -679,15 +691,21 @@ class RoPE_ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Mask
679
691
  if self.pos_embed_special_tokens is True:
680
692
  x = x + self._get_pos_embed(H, W)
681
693
 
682
- x = self.encoder(x, self._get_rope_embed(H, W))
683
- x = self.norm(x)
694
+ rope = self._get_rope_embed(H, W)
695
+ if self.out_indices is None:
696
+ xs = [self.encoder(x, rope)]
697
+ else:
698
+ xs = self.encoder.forward_features(x, rope, out_indices=self.out_indices)
684
699
 
685
- x = x[:, self.num_special_tokens :]
686
- x = x.permute(0, 2, 1)
687
- (B, C, _) = x.size()
688
- x = x.reshape(B, C, H // self.patch_size, W // self.patch_size)
700
+ out: dict[str, torch.Tensor] = {}
701
+ for stage_name, stage_x in zip(self.return_stages, xs):
702
+ stage_x = stage_x[:, self.num_special_tokens :]
703
+ stage_x = stage_x.permute(0, 2, 1)
704
+ B, C, _ = stage_x.size()
705
+ stage_x = stage_x.reshape(B, C, H // self.patch_size, W // self.patch_size)
706
+ out[stage_name] = stage_x
689
707
 
690
- return {self.return_stages[0]: x}
708
+ return out
691
709
 
692
710
  def freeze_stages(self, up_to_stage: int) -> None:
693
711
  for param in self.conv_proj.parameters():
@@ -709,7 +727,7 @@ class RoPE_ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Mask
709
727
  return_all_features: bool = False,
710
728
  return_keys: Literal["all", "tokens", "embedding"] = "tokens",
711
729
  ) -> TokenOmissionResultType:
712
- (H, W) = x.shape[-2:]
730
+ H, W = x.shape[-2:]
713
731
 
714
732
  # Reshape and permute the input tensor
715
733
  x = self.conv_proj(x)
@@ -789,7 +807,7 @@ class RoPE_ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Mask
789
807
  mask_token: Optional[torch.Tensor] = None,
790
808
  return_keys: Literal["all", "features", "embedding"] = "features",
791
809
  ) -> TokenRetentionResultType:
792
- (H, W) = x.shape[-2:]
810
+ H, W = x.shape[-2:]
793
811
 
794
812
  x = self.conv_proj(x)
795
813
  x = mask_tensor(x, mask, mask_token=mask_token, patch_factor=self.max_stride // self.stem_stride)
@@ -820,7 +838,7 @@ class RoPE_ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Mask
820
838
  if return_keys in ("all", "features"):
821
839
  features = x[:, self.num_special_tokens :]
822
840
  features = features.permute(0, 2, 1)
823
- (B, C, _) = features.size()
841
+ B, C, _ = features.size()
824
842
  features = features.reshape(B, C, H // self.patch_size, W // self.patch_size)
825
843
  result["features"] = features
826
844
 
@@ -840,7 +858,7 @@ class RoPE_ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Mask
840
858
  return result
841
859
 
842
860
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
843
- (H, W) = x.shape[-2:]
861
+ H, W = x.shape[-2:]
844
862
 
845
863
  # Reshape and permute the input tensor
846
864
  x = self.conv_proj(x)
birder/net/sequencer2d.py CHANGED
@@ -57,16 +57,16 @@ class LSTM2d(nn.Module):
57
57
  )
58
58
 
59
59
  def forward(self, x: torch.Tensor) -> torch.Tensor:
60
- (B, H, W, C) = x.shape
60
+ B, H, W, C = x.shape
61
61
 
62
62
  v = x.permute(0, 2, 1, 3)
63
63
  v = v.reshape(-1, H, C)
64
- (v, _) = self.rnn_v(v)
64
+ v, _ = self.rnn_v(v)
65
65
  v = v.reshape(B, W, H, -1)
66
66
  v = v.permute(0, 2, 1, 3)
67
67
 
68
68
  h = x.reshape(-1, W, C)
69
- (h, _) = self.rnn_h(h)
69
+ h, _ = self.rnn_h(h)
70
70
  h = h.reshape(B, H, W, -1)
71
71
 
72
72
  x = torch.concat([v, h], dim=-1)
@@ -187,7 +187,6 @@ class Sequencer2d(BaseNet):
187
187
  kernel_size=(patch_sizes[0], patch_sizes[0]),
188
188
  stride=(patch_sizes[0], patch_sizes[0]),
189
189
  padding=(0, 0),
190
- bias=True,
191
190
  ),
192
191
  Permute([0, 2, 3, 1]),
193
192
  )
@@ -22,7 +22,7 @@ from birder.net.base import DetectorBackbone
22
22
 
23
23
 
24
24
  def channel_shuffle(x: torch.Tensor, groups: int) -> torch.Tensor:
25
- (batch_size, num_channels, height, width) = x.size()
25
+ batch_size, num_channels, height, width = x.size()
26
26
  channels_per_group = num_channels // groups
27
27
 
28
28
  # Reshape
@@ -85,7 +85,7 @@ class ShuffleUnit(nn.Module):
85
85
 
86
86
  def forward(self, x: torch.Tensor) -> torch.Tensor:
87
87
  if self.dw_conv_stride == 1:
88
- (branch1, branch2) = x.chunk(2, dim=1)
88
+ branch1, branch2 = x.chunk(2, dim=1)
89
89
  x = torch.concat((branch1, self.branch2(branch2)), dim=1)
90
90
  else:
91
91
  x = torch.concat((self.branch1(x), self.branch2(x)), dim=1)
birder/net/simple_vit.py CHANGED
@@ -26,17 +26,19 @@ from birder.net._vit_configs import HUGE
26
26
  from birder.net._vit_configs import LARGE
27
27
  from birder.net._vit_configs import MEDIUM
28
28
  from birder.net._vit_configs import SMALL
29
+ from birder.net.base import DetectorBackbone
29
30
  from birder.net.base import MaskedTokenOmissionMixin
30
31
  from birder.net.base import PreTrainEncoder
31
32
  from birder.net.base import TokenOmissionResultType
33
+ from birder.net.base import normalize_out_indices
32
34
  from birder.net.base import pos_embedding_sin_cos_2d
33
35
  from birder.net.vit import Encoder
34
36
  from birder.net.vit import EncoderBlock
35
37
  from birder.net.vit import PatchEmbed
36
38
 
37
39
 
38
- # pylint: disable=invalid-name
39
- class Simple_ViT(PreTrainEncoder, MaskedTokenOmissionMixin):
40
+ # pylint: disable=invalid-name,too-many-instance-attributes
41
+ class Simple_ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin):
40
42
  block_group_regex = r"encoder\.block\.(\d+)"
41
43
 
42
44
  def __init__(
@@ -56,6 +58,7 @@ class Simple_ViT(PreTrainEncoder, MaskedTokenOmissionMixin):
56
58
  num_heads: int = self.config["num_heads"]
57
59
  hidden_dim: int = self.config["hidden_dim"]
58
60
  mlp_dim: int = self.config["mlp_dim"]
61
+ out_indices: Optional[list[int]] = self.config.get("out_indices", None)
59
62
  drop_path_rate: float = self.config["drop_path_rate"]
60
63
 
61
64
  torch._assert(image_size[0] % patch_size == 0, "Input shape indivisible by patch size!")
@@ -66,6 +69,7 @@ class Simple_ViT(PreTrainEncoder, MaskedTokenOmissionMixin):
66
69
  self.hidden_dim = hidden_dim
67
70
  self.mlp_dim = mlp_dim
68
71
  self.num_special_tokens = 0
72
+ self.out_indices = normalize_out_indices(out_indices, num_layers)
69
73
  dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_layers)] # Stochastic depth decay rule
70
74
 
71
75
  self.conv_proj = nn.Conv2d(
@@ -74,7 +78,6 @@ class Simple_ViT(PreTrainEncoder, MaskedTokenOmissionMixin):
74
78
  kernel_size=(patch_size, patch_size),
75
79
  stride=(patch_size, patch_size),
76
80
  padding=(0, 0),
77
- bias=True,
78
81
  )
79
82
  self.patch_embed = PatchEmbed()
80
83
 
@@ -94,6 +97,9 @@ class Simple_ViT(PreTrainEncoder, MaskedTokenOmissionMixin):
94
97
  nn.Flatten(1),
95
98
  )
96
99
 
100
+ num_return_stages = len(self.out_indices) if self.out_indices is not None else 1
101
+ self.return_stages = [f"stage{stage_idx + 1}" for stage_idx in range(num_return_stages)]
102
+ self.return_channels = [hidden_dim] * num_return_stages
97
103
  self.embedding_size = hidden_dim
98
104
  self.classifier = self.create_classifier()
99
105
 
@@ -144,7 +150,7 @@ class Simple_ViT(PreTrainEncoder, MaskedTokenOmissionMixin):
144
150
  return_all_features: bool = False,
145
151
  return_keys: Literal["all", "tokens", "embedding"] = "tokens",
146
152
  ) -> TokenOmissionResultType:
147
- (H, W) = x.shape[-2:]
153
+ H, W = x.shape[-2:]
148
154
 
149
155
  # Reshape and permute the input tensor
150
156
  x = self.conv_proj(x)
@@ -179,7 +185,7 @@ class Simple_ViT(PreTrainEncoder, MaskedTokenOmissionMixin):
179
185
  return result
180
186
 
181
187
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
182
- (H, W) = x.shape[-2:]
188
+ H, W = x.shape[-2:]
183
189
  x = self.conv_proj(x)
184
190
  x = self.patch_embed(x)
185
191
  x = x + self._get_pos_embed(H, W)
@@ -193,6 +199,42 @@ class Simple_ViT(PreTrainEncoder, MaskedTokenOmissionMixin):
193
199
  x = x.permute(0, 2, 1)
194
200
  return self.features(x)
195
201
 
202
+ def transform_to_backbone(self) -> None:
203
+ super().transform_to_backbone()
204
+ self.norm = nn.Identity()
205
+
206
+ def detection_features(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
207
+ H, W = x.shape[-2:]
208
+ x = self.conv_proj(x)
209
+ x = self.patch_embed(x)
210
+ x = x + self._get_pos_embed(H, W)
211
+
212
+ if self.out_indices is None:
213
+ xs = [self.encoder(x)]
214
+ else:
215
+ xs = self.encoder.forward_features(x, out_indices=self.out_indices)
216
+
217
+ out: dict[str, torch.Tensor] = {}
218
+ for stage_name, stage_x in zip(self.return_stages, xs):
219
+ stage_x = stage_x[:, self.num_special_tokens :]
220
+ stage_x = stage_x.permute(0, 2, 1)
221
+ B, C, _ = stage_x.size()
222
+ stage_x = stage_x.reshape(B, C, H // self.patch_size, W // self.patch_size)
223
+ out[stage_name] = stage_x
224
+
225
+ return out
226
+
227
+ def freeze_stages(self, up_to_stage: int) -> None:
228
+ for param in self.conv_proj.parameters():
229
+ param.requires_grad_(False)
230
+
231
+ for idx, module in enumerate(self.encoder.children()):
232
+ if idx >= up_to_stage:
233
+ break
234
+
235
+ for param in module.parameters():
236
+ param.requires_grad_(False)
237
+
196
238
  def adjust_size(self, new_size: tuple[int, int]) -> None:
197
239
  if new_size == self.size:
198
240
  return