birder 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff compares publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (187)
  1. birder/adversarial/base.py +1 -1
  2. birder/adversarial/simba.py +4 -4
  3. birder/common/cli.py +1 -1
  4. birder/common/fs_ops.py +11 -11
  5. birder/common/lib.py +2 -2
  6. birder/common/masking.py +3 -3
  7. birder/common/training_cli.py +24 -2
  8. birder/common/training_utils.py +28 -4
  9. birder/data/collators/detection.py +9 -1
  10. birder/data/transforms/detection.py +27 -8
  11. birder/data/transforms/mosaic.py +1 -1
  12. birder/datahub/classification.py +3 -3
  13. birder/inference/classification.py +3 -3
  14. birder/inference/data_parallel.py +1 -1
  15. birder/inference/detection.py +5 -5
  16. birder/inference/wbf.py +1 -1
  17. birder/introspection/attention_rollout.py +5 -5
  18. birder/introspection/feature_pca.py +4 -4
  19. birder/introspection/gradcam.py +1 -1
  20. birder/introspection/guided_backprop.py +2 -2
  21. birder/introspection/transformer_attribution.py +3 -3
  22. birder/layers/attention_pool.py +2 -2
  23. birder/model_registry/model_registry.py +2 -1
  24. birder/net/__init__.py +2 -0
  25. birder/net/_rope_vit_configs.py +5 -0
  26. birder/net/_vit_configs.py +0 -13
  27. birder/net/alexnet.py +5 -5
  28. birder/net/base.py +28 -3
  29. birder/net/biformer.py +17 -17
  30. birder/net/cait.py +2 -2
  31. birder/net/cas_vit.py +1 -1
  32. birder/net/coat.py +15 -15
  33. birder/net/convnext_v1.py +2 -10
  34. birder/net/convnext_v1_iso.py +198 -0
  35. birder/net/convnext_v2.py +2 -10
  36. birder/net/crossformer.py +9 -9
  37. birder/net/crossvit.py +1 -1
  38. birder/net/cspnet.py +1 -1
  39. birder/net/cswin_transformer.py +10 -10
  40. birder/net/davit.py +10 -10
  41. birder/net/deit.py +56 -3
  42. birder/net/deit3.py +27 -15
  43. birder/net/detection/__init__.py +4 -0
  44. birder/net/detection/{yolo_anchors.py → _yolo_anchors.py} +5 -5
  45. birder/net/detection/base.py +6 -5
  46. birder/net/detection/deformable_detr.py +26 -28
  47. birder/net/detection/detr.py +9 -9
  48. birder/net/detection/efficientdet.py +9 -28
  49. birder/net/detection/faster_rcnn.py +22 -22
  50. birder/net/detection/fcos.py +8 -8
  51. birder/net/detection/plain_detr.py +852 -0
  52. birder/net/detection/retinanet.py +4 -4
  53. birder/net/detection/rt_detr_v1.py +81 -25
  54. birder/net/detection/rt_detr_v2.py +1147 -0
  55. birder/net/detection/ssd.py +5 -5
  56. birder/net/detection/yolo_v2.py +12 -12
  57. birder/net/detection/yolo_v3.py +19 -19
  58. birder/net/detection/yolo_v4.py +16 -16
  59. birder/net/detection/yolo_v4_tiny.py +3 -3
  60. birder/net/edgenext.py +3 -3
  61. birder/net/edgevit.py +10 -14
  62. birder/net/efficientformer_v1.py +1 -1
  63. birder/net/efficientvim.py +9 -9
  64. birder/net/efficientvit_mit.py +2 -2
  65. birder/net/efficientvit_msft.py +3 -3
  66. birder/net/fasternet.py +1 -1
  67. birder/net/fastvit.py +5 -12
  68. birder/net/flexivit.py +28 -15
  69. birder/net/focalnet.py +5 -9
  70. birder/net/gc_vit.py +11 -11
  71. birder/net/ghostnet_v1.py +1 -1
  72. birder/net/ghostnet_v2.py +1 -1
  73. birder/net/groupmixformer.py +12 -12
  74. birder/net/hgnet_v1.py +1 -1
  75. birder/net/hgnet_v2.py +4 -4
  76. birder/net/hiera.py +6 -6
  77. birder/net/hieradet.py +9 -9
  78. birder/net/hornet.py +3 -3
  79. birder/net/iformer.py +4 -4
  80. birder/net/inception_next.py +4 -14
  81. birder/net/levit.py +3 -3
  82. birder/net/lit_v1.py +13 -15
  83. birder/net/lit_v1_tiny.py +9 -9
  84. birder/net/lit_v2.py +14 -15
  85. birder/net/maxvit.py +10 -22
  86. birder/net/metaformer.py +2 -2
  87. birder/net/mim/crossmae.py +5 -5
  88. birder/net/mim/fcmae.py +3 -5
  89. birder/net/mim/mae_hiera.py +7 -7
  90. birder/net/mim/mae_vit.py +3 -5
  91. birder/net/mim/simmim.py +2 -3
  92. birder/net/mobilenet_v4_hybrid.py +4 -4
  93. birder/net/mobileone.py +5 -12
  94. birder/net/mobilevit_v1.py +2 -2
  95. birder/net/mobilevit_v2.py +5 -9
  96. birder/net/mvit_v2.py +24 -24
  97. birder/net/nextvit.py +2 -2
  98. birder/net/pit.py +11 -26
  99. birder/net/pvt_v1.py +4 -4
  100. birder/net/pvt_v2.py +5 -11
  101. birder/net/regionvit.py +15 -15
  102. birder/net/regnet.py +1 -1
  103. birder/net/repghost.py +4 -5
  104. birder/net/repvgg.py +3 -5
  105. birder/net/repvit.py +2 -2
  106. birder/net/resnest.py +1 -1
  107. birder/net/rope_deit3.py +29 -15
  108. birder/net/rope_flexivit.py +28 -15
  109. birder/net/rope_vit.py +41 -23
  110. birder/net/sequencer2d.py +3 -4
  111. birder/net/shufflenet_v1.py +1 -1
  112. birder/net/shufflenet_v2.py +1 -1
  113. birder/net/simple_vit.py +47 -5
  114. birder/net/smt.py +7 -7
  115. birder/net/ssl/barlow_twins.py +1 -1
  116. birder/net/ssl/byol.py +2 -2
  117. birder/net/ssl/capi.py +3 -3
  118. birder/net/ssl/data2vec2.py +1 -1
  119. birder/net/ssl/dino_v2.py +11 -1
  120. birder/net/ssl/franca.py +26 -2
  121. birder/net/ssl/i_jepa.py +4 -4
  122. birder/net/ssl/mmcr.py +1 -1
  123. birder/net/swiftformer.py +1 -1
  124. birder/net/swin_transformer_v1.py +4 -5
  125. birder/net/swin_transformer_v2.py +4 -7
  126. birder/net/tiny_vit.py +3 -3
  127. birder/net/transnext.py +19 -19
  128. birder/net/uniformer.py +4 -4
  129. birder/net/vgg.py +1 -10
  130. birder/net/vit.py +38 -25
  131. birder/net/vit_parallel.py +35 -20
  132. birder/net/vit_sam.py +10 -10
  133. birder/net/vovnet_v2.py +1 -1
  134. birder/net/xcit.py +9 -7
  135. birder/ops/msda.py +4 -4
  136. birder/ops/swattention.py +10 -10
  137. birder/results/classification.py +3 -3
  138. birder/results/gui.py +8 -8
  139. birder/scripts/benchmark.py +37 -12
  140. birder/scripts/evaluate.py +1 -1
  141. birder/scripts/predict.py +3 -3
  142. birder/scripts/predict_detection.py +2 -2
  143. birder/scripts/train.py +63 -15
  144. birder/scripts/train_barlow_twins.py +10 -7
  145. birder/scripts/train_byol.py +10 -7
  146. birder/scripts/train_capi.py +15 -10
  147. birder/scripts/train_data2vec.py +10 -7
  148. birder/scripts/train_data2vec2.py +10 -7
  149. birder/scripts/train_detection.py +29 -14
  150. birder/scripts/train_dino_v1.py +13 -9
  151. birder/scripts/train_dino_v2.py +27 -14
  152. birder/scripts/train_dino_v2_dist.py +28 -15
  153. birder/scripts/train_franca.py +16 -9
  154. birder/scripts/train_i_jepa.py +12 -9
  155. birder/scripts/train_ibot.py +15 -11
  156. birder/scripts/train_kd.py +64 -17
  157. birder/scripts/train_mim.py +11 -8
  158. birder/scripts/train_mmcr.py +11 -8
  159. birder/scripts/train_rotnet.py +11 -7
  160. birder/scripts/train_simclr.py +10 -7
  161. birder/scripts/train_vicreg.py +10 -7
  162. birder/tools/adversarial.py +4 -4
  163. birder/tools/auto_anchors.py +5 -5
  164. birder/tools/avg_model.py +1 -1
  165. birder/tools/convert_model.py +30 -22
  166. birder/tools/det_results.py +1 -1
  167. birder/tools/download_model.py +1 -1
  168. birder/tools/ensemble_model.py +1 -1
  169. birder/tools/introspection.py +11 -2
  170. birder/tools/labelme_to_coco.py +2 -2
  171. birder/tools/model_info.py +12 -14
  172. birder/tools/pack.py +8 -8
  173. birder/tools/quantize_model.py +53 -4
  174. birder/tools/results.py +2 -2
  175. birder/tools/show_det_iterator.py +19 -6
  176. birder/tools/show_iterator.py +2 -2
  177. birder/tools/similarity.py +5 -5
  178. birder/tools/stats.py +4 -6
  179. birder/tools/voc_to_coco.py +1 -1
  180. birder/version.py +1 -1
  181. {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/METADATA +3 -3
  182. birder-0.4.1.dist-info/RECORD +300 -0
  183. {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/WHEEL +1 -1
  184. birder-0.4.0.dist-info/RECORD +0 -297
  185. {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/entry_points.txt +0 -0
  186. {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/licenses/LICENSE +0 -0
  187. {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/top_level.txt +0 -0
birder/net/vit.py CHANGED
@@ -40,6 +40,7 @@ from birder.net.base import MaskedTokenRetentionMixin
 from birder.net.base import PreTrainEncoder
 from birder.net.base import TokenOmissionResultType
 from birder.net.base import TokenRetentionResultType
+from birder.net.base import normalize_out_indices
 
 
 def adjust_position_embedding(
@@ -73,12 +74,10 @@ adjust_position_embedding(
 class PatchEmbed(nn.Module):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
-        The entire forward is equivalent to x.flatten(2).transpose(1, 2)
+        This is equivalent (in output) to: x.flatten(2).transpose(1, 2)
         """
 
-        (n, hidden_dim, h, w) = x.size()
-
-        # (n, hidden_dim, h, w) -> (n, hidden_dim, (h * w))
+        n, hidden_dim, h, w = x.size()
         x = x.reshape(n, hidden_dim, h * w)
 
         # (n, hidden_dim, (h * w)) -> (n, (h * w), hidden_dim)
@@ -155,9 +154,9 @@ class Attention(nn.Module):
            - attn_weights: If need_weights is True attention weights, otherwise, None.
         """
 
-        (B, N, C) = x.size()
+        B, N, C = x.size()
         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
-        (q, k, v) = qkv.unbind(0)
+        q, k, v = qkv.unbind(0)
         q = self.q_norm(q)
         k = self.k_norm(k)
 
@@ -245,7 +244,7 @@ class EncoderBlock(nn.Module):
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         # torch._assert(x.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {x.size()}")
-        (attn_out, _) = self.attn(
+        attn_out, _ = self.attn(
            self.norm1(x),
            need_weights=self.need_attn,
            average_attn_weights=False,
@@ -317,13 +316,15 @@ class Encoder(nn.Module):
         x = self.pre_block(x)
         return self.block(x)
 
-    def forward_features(self, x: torch.Tensor) -> list[torch.Tensor]:
+    def forward_features(self, x: torch.Tensor, out_indices: Optional[list[int]] = None) -> list[torch.Tensor]:
         x = self.pre_block(x)
 
+        out_indices_set = set(out_indices) if out_indices is not None else None
         xs = []
-        for blk in self.block:
+        for idx, blk in enumerate(self.block):
            x = blk(x)
-            xs.append(x)
+            if out_indices_set is None or idx in out_indices_set:
+                xs.append(x)
 
         return xs
 
@@ -340,7 +341,7 @@
 class ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedTokenRetentionMixin):
     block_group_regex = r"encoder\.block\.(\d+)"
 
-    # pylint: disable=too-many-locals,too-many-branches
+    # pylint: disable=too-many-locals,too-many-branches,too-many-statements
     def __init__(
         self,
         input_channels: int,
@@ -375,6 +376,7 @@ class ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedTok
         norm_layer_eps: float = self.config.get("norm_layer_eps", 1e-6)
         mlp_layer_type: str = self.config.get("mlp_layer_type", "FFN")
         act_layer_type: Optional[str] = self.config.get("act_layer_type", None)  # Default according to mlp type
+        out_indices: Optional[list[int]] = self.config.get("out_indices", None)
         drop_path_rate: float = self.config["drop_path_rate"]
 
         if norm_layer_type == "LayerNorm":
@@ -405,6 +407,7 @@ class ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedTok
         self.hidden_dim = hidden_dim
         self.num_reg_tokens = num_reg_tokens
         self.attn_pool_special_tokens = attn_pool_special_tokens
+        self.out_indices = normalize_out_indices(out_indices, num_layers)
         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_layers)]  # Stochastic depth decay rule
 
         self.conv_proj = nn.Conv2d(
@@ -472,8 +475,9 @@ class ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedTok
 
            self.attn_pool = MultiHeadAttentionPool(hidden_dim, attn_pool_num_heads, mlp_dim, qkv_bias=True)
 
-        self.return_stages = ["neck"]  # Actually meaningless, just for completeness
-        self.return_channels = [hidden_dim]
+        num_return_stages = len(self.out_indices) if self.out_indices is not None else 1
+        self.return_stages = [f"stage{stage_idx + 1}" for stage_idx in range(num_return_stages)]
+        self.return_channels = [hidden_dim] * num_return_stages
         self.embedding_size = hidden_dim
         self.classifier = self.create_classifier()
 
@@ -537,8 +541,12 @@ class ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedTok
     def set_causal_attention(self, is_causal: bool = True) -> None:
         self.encoder.set_causal_attention(is_causal)
 
+    def transform_to_backbone(self) -> None:
+        super().transform_to_backbone()
+        self.norm = nn.Identity()
+
     def detection_features(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
-        (H, W) = x.shape[-2:]
+        H, W = x.shape[-2:]
         x = self.conv_proj(x)
         x = self.patch_embed(x)
 
@@ -558,15 +566,20 @@ class ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedTok
         if self.pos_embed_special_tokens is True:
            x = x + self._get_pos_embed(H, W)
 
-        x = self.encoder(x)
-        x = self.norm(x)
+        if self.out_indices is None:
+            xs = [self.encoder(x)]
+        else:
+            xs = self.encoder.forward_features(x, out_indices=self.out_indices)
 
-        x = x[:, self.num_special_tokens :]
-        x = x.permute(0, 2, 1)
-        (B, C, _) = x.size()
-        x = x.reshape(B, C, H // self.patch_size, W // self.patch_size)
+        out: dict[str, torch.Tensor] = {}
+        for stage_name, stage_x in zip(self.return_stages, xs):
+            stage_x = stage_x[:, self.num_special_tokens :]
+            stage_x = stage_x.permute(0, 2, 1)
+            B, C, _ = stage_x.size()
+            stage_x = stage_x.reshape(B, C, H // self.patch_size, W // self.patch_size)
+            out[stage_name] = stage_x
 
-        return {self.return_stages[0]: x}
+        return out
 
     def freeze_stages(self, up_to_stage: int) -> None:
         for param in self.conv_proj.parameters():
@@ -589,7 +602,7 @@ class ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedTok
         return_all_features: bool = False,
         return_keys: Literal["all", "tokens", "embedding"] = "tokens",
     ) -> TokenOmissionResultType:
-        (H, W) = x.shape[-2:]
+        H, W = x.shape[-2:]
 
         # Reshape and permute the input tensor
         x = self.conv_proj(x)
@@ -663,7 +676,7 @@ class ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedTok
         mask_token: Optional[torch.Tensor] = None,
         return_keys: Literal["all", "features", "embedding"] = "features",
     ) -> TokenRetentionResultType:
-        (H, W) = x.shape[-2:]
+        H, W = x.shape[-2:]
 
         x = self.conv_proj(x)
         x = mask_tensor(x, mask, mask_token=mask_token, patch_factor=self.max_stride // self.stem_stride)
@@ -694,7 +707,7 @@ class ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedTok
         if return_keys in ("all", "features"):
            features = x[:, self.num_special_tokens :]
            features = features.permute(0, 2, 1)
-            (B, C, _) = features.size()
+            B, C, _ = features.size()
            features = features.reshape(B, C, H // self.patch_size, W // self.patch_size)
            result["features"] = features
 
@@ -714,7 +727,7 @@ class ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedTok
         return result
 
     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
-        (H, W) = x.shape[-2:]
+        H, W = x.shape[-2:]
 
         # Reshape and permute the input tensor
         x = self.conv_proj(x)
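
The substantive change in vit.py (beyond dropping the redundant parentheses around tuple unpacking) is multi-scale feature extraction: an optional out_indices entry in the model config is normalized through the new normalize_out_indices helper, detection_features returns one entry per requested block ("stage1", "stage2", ...) instead of the single placeholder "neck" stage, and the new transform_to_backbone override swaps the final norm for nn.Identity(). The selection logic added to Encoder.forward_features is easy to isolate; here is a minimal, self-contained sketch of the same pattern using toy blocks (illustrative only, not the actual birder classes):

from typing import Optional

import torch
from torch import nn

# Illustrative sketch -- not part of the 0.4.1 diff. Toy stand-ins for the ViT encoder blocks.
blocks = nn.ModuleList([nn.Linear(8, 8) for _ in range(12)])


def forward_features(x: torch.Tensor, out_indices: Optional[list[int]] = None) -> list[torch.Tensor]:
    # Same pattern as the patched Encoder.forward_features: collect the output
    # only at the requested block indices (or at every block when None)
    out_indices_set = set(out_indices) if out_indices is not None else None
    xs = []
    for idx, blk in enumerate(blocks):
        x = blk(x)
        if out_indices_set is None or idx in out_indices_set:
            xs.append(x)

    return xs


feats = forward_features(torch.randn(2, 16, 8), out_indices=[3, 7, 11])
assert len(feats) == 3  # one feature tensor per requested block

vit_parallel.py below receives the same treatment, and xcit.py further down routes its out_indices through the same helper.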
birder/net/vit_parallel.py CHANGED
@@ -31,6 +31,7 @@ from birder.net.base import MaskedTokenRetentionMixin
 from birder.net.base import PreTrainEncoder
 from birder.net.base import TokenOmissionResultType
 from birder.net.base import TokenRetentionResultType
+from birder.net.base import normalize_out_indices
 from birder.net.vit import PatchEmbed
 from birder.net.vit import adjust_position_embedding
 
@@ -51,9 +52,9 @@ class Attention(nn.Module):
         self.proj_drop = nn.Dropout(proj_drop)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        (B, N, C) = x.size()
+        B, N, C = x.size()
         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
-        (q, k, v) = qkv.unbind(0)
+        q, k, v = qkv.unbind(0)
 
         x = F.scaled_dot_product_attention(  # pylint: disable=not-callable
            q, k, v, dropout_p=self.attn_drop.p if self.training else 0.0, is_causal=self.is_causal, scale=self.scale
@@ -172,11 +173,13 @@
 
         return x
 
-    def forward_features(self, x: torch.Tensor) -> list[torch.Tensor]:
+    def forward_features(self, x: torch.Tensor, out_indices: Optional[list[int]] = None) -> list[torch.Tensor]:
         xs = []
-        for blk in self.block:
+        out_indices_set = set(out_indices) if out_indices is not None else None
+        for idx, blk in enumerate(self.block):
            x = blk(x)
-            xs.append(x)
+            if out_indices_set is None or idx in out_indices_set:
+                xs.append(x)
 
         return xs
 
@@ -213,6 +216,7 @@ class ViT_Parallel(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin,
         num_reg_tokens: int = self.config.get("num_reg_tokens", 0)
         class_token: bool = self.config.get("class_token", True)
         norm_layer_type: str = self.config.get("norm_layer_type", "LayerNorm")
+        out_indices: Optional[list[int]] = self.config.get("out_indices", None)
         drop_path_rate: float = self.config["drop_path_rate"]
 
         if norm_layer_type == "LayerNorm":
@@ -230,6 +234,7 @@ class ViT_Parallel(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin,
         self.hidden_dim = hidden_dim
         self.layer_scale_init_value = layer_scale_init_value
         self.num_reg_tokens = num_reg_tokens
+        self.out_indices = normalize_out_indices(out_indices, num_layers)
         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_layers)]  # Stochastic depth decay rule
 
         self.conv_proj = nn.Conv2d(
@@ -238,7 +243,6 @@ class ViT_Parallel(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin,
            kernel_size=(patch_size, patch_size),
            stride=(patch_size, patch_size),
            padding=(0, 0),
-            bias=True,
        )
         self.patch_embed = PatchEmbed()
 
@@ -278,8 +282,9 @@ class ViT_Parallel(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin,
        )
         self.norm = norm_layer(hidden_dim, eps=1e-6)
 
-        self.return_stages = ["neck"]  # Actually meaningless, but for completeness
-        self.return_channels = [hidden_dim]
+        num_return_stages = len(self.out_indices) if self.out_indices is not None else 1
+        self.return_stages = [f"stage{stage_idx + 1}" for stage_idx in range(num_return_stages)]
+        self.return_channels = [hidden_dim] * num_return_stages
         self.embedding_size = hidden_dim
         self.classifier = self.create_classifier()
 
@@ -338,8 +343,12 @@ class ViT_Parallel(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin,
     def set_causal_attention(self, is_causal: bool = True) -> None:
         self.encoder.set_causal_attention(is_causal)
 
+    def transform_to_backbone(self) -> None:
+        super().transform_to_backbone()
+        self.norm = nn.Identity()
+
     def detection_features(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
-        (H, W) = x.shape[-2:]
+        H, W = x.shape[-2:]
         x = self.conv_proj(x)
         x = self.patch_embed(x)
 
@@ -354,15 +363,21 @@ class ViT_Parallel(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin,
            x = torch.concat([batch_reg_tokens, x], dim=1)
 
         x = x + self._get_pos_embed(H, W)
-        x = self.encoder(x)
-        x = self.norm(x)
 
-        x = x[:, self.num_special_tokens :]
-        x = x.permute(0, 2, 1)
-        (B, C, _) = x.size()
-        x = x.reshape(B, C, H // self.patch_size, W // self.patch_size)
+        if self.out_indices is None:
+            xs = [self.encoder(x)]
+        else:
+            xs = self.encoder.forward_features(x, out_indices=self.out_indices)
+
+        out: dict[str, torch.Tensor] = {}
+        for stage_name, stage_x in zip(self.return_stages, xs):
+            stage_x = stage_x[:, self.num_special_tokens :]
+            stage_x = stage_x.permute(0, 2, 1)
+            B, C, _ = stage_x.size()
+            stage_x = stage_x.reshape(B, C, H // self.patch_size, W // self.patch_size)
+            out[stage_name] = stage_x
 
-        return {self.return_stages[0]: x}
+        return out
 
     def freeze_stages(self, up_to_stage: int) -> None:
         for param in self.conv_proj.parameters():
@@ -384,7 +399,7 @@ class ViT_Parallel(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin,
         return_all_features: bool = False,
         return_keys: Literal["all", "tokens", "embedding"] = "tokens",
     ) -> TokenOmissionResultType:
-        (H, W) = x.shape[-2:]
+        H, W = x.shape[-2:]
 
         # Reshape and permute the input tensor
         x = self.conv_proj(x)
@@ -441,7 +456,7 @@ class ViT_Parallel(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin,
         mask_token: Optional[torch.Tensor] = None,
         return_keys: Literal["all", "features", "embedding"] = "features",
     ) -> TokenRetentionResultType:
-        (H, W) = x.shape[-2:]
+        H, W = x.shape[-2:]
 
         x = self.conv_proj(x)
         x = mask_tensor(x, mask, mask_token=mask_token, patch_factor=self.max_stride // self.stem_stride)
@@ -467,7 +482,7 @@ class ViT_Parallel(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin,
         if return_keys in ("all", "features"):
            features = x[:, self.num_special_tokens :]
            features = features.permute(0, 2, 1)
-            (B, C, _) = features.size()
+            B, C, _ = features.size()
            features = features.reshape(B, C, H // self.patch_size, W // self.patch_size)
            result["features"] = features
 
@@ -481,7 +496,7 @@ class ViT_Parallel(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin,
         return result
 
     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
-        (H, W) = x.shape[-2:]
+        H, W = x.shape[-2:]
 
         # Reshape and permute the input tensor
         x = self.conv_proj(x)
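
Both ViT and ViT_Parallel (and XCiT further down) route the configured indices through normalize_out_indices, imported from birder.net.base; base.py changes in this release (+28 -3) but its body is not shown in this diff. Judging only from the call sites, normalize_out_indices(out_indices, num_layers) taking an Optional[list[int]], a plausible sketch of such a helper follows; this is an assumption about its behavior, not the actual birder implementation:

from typing import Optional


def normalize_out_indices(out_indices: Optional[list[int]], num_layers: int) -> Optional[list[int]]:
    # Hypothetical sketch: the real helper lives in birder/net/base.py and may differ
    if out_indices is None:
        return None

    normalized = []
    for idx in out_indices:
        if idx < 0:
            idx += num_layers  # resolve Python-style negative indices, e.g. -1 -> last block

        if not 0 <= idx < num_layers:
            raise ValueError(f"out index {idx} is out of range for {num_layers} layers")

        normalized.append(idx)

    return normalized

Under this reading, a config like "out_indices": [-4, -3, -2, -1] on a 12-layer model would resolve to [8, 9, 10, 11], i.e. the last four encoder blocks become stage1 through stage4.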
birder/net/vit_sam.py CHANGED
@@ -35,7 +35,7 @@ from birder.net.vit import EncoderBlock as MAEDecoderBlock
 
 # pylint: disable=invalid-name
 def window_partition(x: torch.Tensor, window_size: int) -> tuple[torch.Tensor, tuple[int, int]]:
-    (B, H, W, C) = x.shape
+    B, H, W, C = x.shape
 
     pad_h = (window_size - H % window_size) % window_size
     pad_w = (window_size - W % window_size) % window_size
@@ -55,8 +55,8 @@ def window_partition(x: torch.Tensor, window_size: int) -> tuple[torch.Tensor, t
 def window_unpartition(
     windows: torch.Tensor, window_size: int, pad_hw: tuple[int, int], hw: tuple[int, int]
 ) -> torch.Tensor:
-    (Hp, Wp) = pad_hw
-    (H, W) = hw
+    Hp, Wp = pad_hw
+    H, W = hw
     B = windows.shape[0] // (Hp * Wp // window_size // window_size)
     x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
     x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
@@ -91,12 +91,12 @@ def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor
 def get_decomposed_rel_pos_bias(
     q: torch.Tensor, rel_pos_h: torch.Tensor, rel_pos_w: torch.Tensor, q_size: tuple[int, int], k_size: tuple[int, int]
 ) -> torch.Tensor:
-    (q_h, q_w) = q_size
-    (k_h, k_w) = k_size
+    q_h, q_w = q_size
+    k_h, k_w = k_size
     Rh = get_rel_pos(q_h, k_h, rel_pos_h)
     Rw = get_rel_pos(q_w, k_w, rel_pos_w)
 
-    (B, _, dim) = q.shape
+    B, _, dim = q.shape
     r_q = q.reshape(B, q_h, q_w, dim)
     rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
     rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
@@ -139,9 +139,9 @@ class Attention(nn.Module):
            self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        (B, H, W, _) = x.shape
+        B, H, W, _ = x.shape
         qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
-        (q, k, v) = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
+        q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
 
         if self.use_rel_pos is True:
            attn_bias = get_decomposed_rel_pos_bias(q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
@@ -216,13 +216,13 @@ class EncoderBlock(nn.Module):
            self.layer_scale_2 = nn.Identity()
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        (_, H, W, _) = x.shape
+        _, H, W, _ = x.shape
         shortcut = x
 
         x = self.norm1(x)
         pad_hw = (0, 0)
         if self.window_size > 0:
-            (x, pad_hw) = window_partition(x, self.window_size)
+            x, pad_hw = window_partition(x, self.window_size)
 
         x = self.attn(x)
         if self.window_size > 0:
birder/net/vovnet_v2.py CHANGED
@@ -27,7 +27,7 @@ class EffectiveSE(nn.Module):
 
     def __init__(self, channels: int) -> None:
         super().__init__()
-        self.fc = nn.Conv2d(channels, channels, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True)
+        self.fc = nn.Conv2d(channels, channels, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x_se = x.mean(dim=(2, 3), keepdim=True)
birder/net/xcit.py CHANGED
@@ -30,6 +30,7 @@ from birder.net.base import DetectorBackbone
 from birder.net.base import MaskedTokenRetentionMixin
 from birder.net.base import PreTrainEncoder
 from birder.net.base import TokenRetentionResultType
+from birder.net.base import normalize_out_indices
 from birder.net.cait import ClassAttention
 
 
@@ -212,7 +213,7 @@
        )
 
     def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
-        (B, N, C) = x.shape
+        B, N, C = x.shape
         x = x.permute(0, 2, 1).reshape(B, C, H, W)
         x = self.conv_bn_act(x)
         x = self.conv(x)
@@ -236,10 +237,10 @@
         self.proj_drop = nn.Dropout(proj_drop)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        (B, N, C) = x.shape
+        B, N, C = x.shape
         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
         qkv = qkv.permute(2, 0, 3, 1, 4)
-        (q, k, v) = qkv.unbind(0)
+        q, k, v = qkv.unbind(0)
 
         q = F.normalize(q, dim=-1) * self.temperature
         k = F.normalize(k, dim=-1)
@@ -311,6 +312,7 @@ class XCiT(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
         else:
            raise ValueError(f"depth={depth} is not supported")
 
+        out_indices = normalize_out_indices(out_indices, depth)
         self.patch_embed = ConvPatchEmbed(patch_size, self.input_channels, dim=embed_dim)
         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # Stochastic depth decay rule
@@ -381,7 +383,7 @@
     def detection_features(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
         B = x.size(0)
 
-        (x, H, W) = self.patch_embed(x)
+        x, H, W = self.patch_embed(x)
 
         pos_encoding = self.pos_embed(B, H, W).reshape(B, -1, x.size(1)).permute(0, 2, 1)
         x = x + pos_encoding
@@ -414,7 +416,7 @@
     ) -> TokenRetentionResultType:
         B = x.size(0)
 
-        (x, H, W) = self.patch_embed(x)
+        x, H, W = self.patch_embed(x)
         x = mask_tensor(
            x.permute(0, 2, 1).reshape(B, -1, H, W),
            mask,
@@ -435,7 +437,7 @@
         if return_keys in ("all", "features"):
            features = x[:, 1:]
            features = features.permute(0, 2, 1)
-            (B, C, _) = features.size()
+            B, C, _ = features.size()
            features = features.reshape(B, C, H, W)
            result["features"] = features
 
@@ -447,7 +449,7 @@
     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         B = x.size(0)
 
-        (x, H, W) = self.patch_embed(x)
+        x, H, W = self.patch_embed(x)
 
         pos_encoding = self.pos_embed(B, H, W).reshape(B, -1, x.size(1)).permute(0, 2, 1)
         x = x + pos_encoding
birder/ops/msda.py CHANGED
@@ -91,8 +91,8 @@ def _ms_deform_attn_setup_context(  # type: ignore[no-untyped-def] # pylint: dis
 
 
 def _ms_deform_attn_backward(ctx, grad_output):  # type: ignore[no-untyped-def]
-    (value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights) = ctx.saved_tensors
-    (grad_value, grad_sampling_loc, grad_attn_weight) = ms_deform_attn_backward_op(
+    value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors
+    grad_value, grad_sampling_loc, grad_attn_weight = ms_deform_attn_backward_op(
        value,
        value_spatial_shapes,
        value_level_start_index,
@@ -160,8 +160,8 @@ def multi_scale_deformable_attention(
     attention_weights: torch.Tensor,
     im2col_step: int,  # pylint: disable=unused-argument
 ) -> torch.Tensor:
-    (batch_size, _, num_heads, hidden_dim) = value.size()
-    (_, num_queries, num_heads, num_levels, num_points, _) = sampling_locations.size()
+    batch_size, _, num_heads, hidden_dim = value.size()
+    _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.size()
     areas: list[int] = value_spatial_shapes.prod(dim=1).tolist()
     value_list = value.split(areas, dim=1)
     sampling_grids = 2 * sampling_locations - 1
birder/ops/swattention.py CHANGED
@@ -38,7 +38,7 @@ def _swattention_qk_rpb_fake(  # pylint: disable=unused-argument
 def _swattention_qk_rpb_setup_context(  # type: ignore[no-untyped-def] # pylint: disable=unused-argument
     ctx, inputs, output
 ) -> None:
-    (query, key, _rpb, height, width, kernel_size) = inputs
+    query, key, _rpb, height, width, kernel_size = inputs
     ctx.save_for_backward(query, key)
     ctx.height = height
     ctx.width = width
@@ -46,8 +46,8 @@ def _swattention_qk_rpb_setup_context(  # type: ignore[no-untyped-def] # pylint:
 
 
 def _swattention_qk_rpb_backward(ctx, grad_output):  # type: ignore[no-untyped-def]
-    (query, key) = ctx.saved_tensors
-    (d_query, d_key, d_rpb) = swattention_qk_rpb_backward_op(
+    query, key = ctx.saved_tensors
+    d_query, d_key, d_rpb = swattention_qk_rpb_backward_op(
        grad_output.contiguous(), query, key, ctx.height, ctx.width, ctx.kernel_size
    )
     return (d_query, d_key, d_rpb, None, None, None)
@@ -107,8 +107,8 @@ def _swattention_av_setup_context(  # type: ignore[no-untyped-def] # pylint: dis
 
 
 def _swattention_av_backward(ctx, grad_output):  # type: ignore[no-untyped-def]
-    (attn_weight, value) = ctx.saved_tensors
-    (d_attn_weight, d_value) = swattention_av_backward_op(
+    attn_weight, value = ctx.saved_tensors
+    d_attn_weight, d_value = swattention_av_backward_op(
        grad_output.contiguous(), attn_weight, value, ctx.height, ctx.width, ctx.kernel_size
    )
     return (d_attn_weight, d_value, None, None, None)
@@ -184,10 +184,10 @@ class SWAttention_QK_RPB(nn.Module):
        )
 
         # Custom kernel
-        (B, N, _) = kv.size()
+        B, N, _ = kv.size()
 
         # Generate unfolded keys and values and l2-normalize them
-        (k_local, v_local) = kv.reshape(B, N, 2 * num_heads, head_dim).permute(0, 2, 1, 3).chunk(2, dim=1)
+        k_local, v_local = kv.reshape(B, N, 2 * num_heads, head_dim).permute(0, 2, 1, 3).chunk(2, dim=1)
 
         # Compute local similarity
         attn_local = swattention_qk_rpb_op(
@@ -254,14 +254,14 @@ def swattention_qk_rpb(
     H: int,
     W: int,
 ) -> tuple[torch.Tensor, torch.Tensor]:
-    (B, N, _) = kv.size()
+    B, N, _ = kv.size()
 
     # Generate unfolded keys and values and l2-normalize them
-    (k_local, v_local) = kv.chunk(2, dim=-1)
+    k_local, v_local = kv.chunk(2, dim=-1)
     k_local = F.normalize(k_local.reshape(B, N, num_heads, head_dim), dim=-1).reshape(B, N, -1)
     kv_local = torch.concat([k_local, v_local], dim=-1).permute(0, 2, 1).reshape(B, -1, H, W)
 
-    (k_local, v_local) = (
+    k_local, v_local = (
        F.unfold(kv_local, kernel_size=window_size, padding=window_size // 2, stride=1)
        .reshape(B, 2 * num_heads, head_dim, local_len, N)
        .permute(0, 1, 4, 2, 3)
birder/results/classification.py CHANGED
@@ -30,7 +30,7 @@ def top_k_accuracy_score(y_true: npt.NDArray[Any], y_pred: npt.NDArray[np.float6
     if len(y_true.shape) == 2:
         y_true = np.argmax(y_true, axis=1)
 
-    (num_samples, _num_labels) = y_pred.shape
+    num_samples, _num_labels = y_pred.shape
     indices: list[int] = []
     arg_sorted = np.argpartition(y_pred, -top_k, axis=1)[:, -top_k:]
     for i in range(num_samples):
@@ -693,7 +693,7 @@ class SparseResults(Results):
            For sparse files, this value is ignored.
         """
 
-        (label_names, detected_sparse_k) = detect_file_format(path)
+        label_names, detected_sparse_k = detect_file_format(path)
 
         if detected_sparse_k is not None:
            schema_overrides = {
@@ -817,7 +817,7 @@ def load_results(path: str, lazy: bool = True) -> Results | SparseResults:
     <class 'birder.results.classification.SparseResults'>
     """
 
-    (_, sparse_k) = detect_file_format(path)
+    _, sparse_k = detect_file_format(path)
 
     # Load using appropriate class
     if sparse_k is not None:
birder/results/gui.py CHANGED
@@ -212,7 +212,7 @@ class ConfusionMatrix:
        )
 
         offset = 0.5
-        (height, width) = cnf_matrix.shape
+        height, width = cnf_matrix.shape
         ax.hlines(
            y=np.arange(height + 1) - offset,
            xmin=-offset,
@@ -261,7 +261,7 @@ class ROC:
         roc_auc = {}
         for i in results.unique_labels:
            binary_labels = results.labels == i
-            (fpr[i], tpr[i], _) = roc_curve(binary_labels, results.output[:, i])
+            fpr[i], tpr[i], _ = roc_curve(binary_labels, results.output[:, i])
            if np.sum(binary_labels) == 0:
                tpr[i] = np.zeros_like(fpr[i])
 
@@ -324,7 +324,7 @@ class PrecisionRecall:
         labels = label_binarize(results.labels, classes=range(len(results.label_names)))
 
         # A "micro-average" quantifying score on all classes jointly
-        (precision, recall, _) = precision_recall_curve(labels.ravel(), results.output.ravel())
+        precision, recall, _ = precision_recall_curve(labels.ravel(), results.output.ravel())
         average_precision = average_precision_score(labels.ravel(), results.output.ravel(), average="micro")
 
         line = ax.step(recall, precision, linestyle=":", where="post")
@@ -334,7 +334,7 @@
         # Per selected class
         for cls in pr_classes:
            i = results.label_names.index(cls)
-            (precision, recall, _) = precision_recall_curve(labels[:, i], results.output[:, i])
+            precision, recall, _ = precision_recall_curve(labels[:, i], results.output[:, i])
            average_precision = average_precision_score(labels[:, i], results.output[:, i])
            line = ax.plot(recall, precision, lw=2)
            legend_lines.append(line[0])
@@ -372,8 +372,8 @@ class ProbabilityHistogram:
         cls_a_df = results_df.filter(pl.col("label_name") == cls_a)
         cls_b_df = results_df.filter(pl.col("label_name") == cls_b)
 
-        (cls_a_prob_a_counts, cls_a_prob_a_bins) = hist(cls_a_df[str(self.results.label_names.index(cls_a))])
-        (cls_a_prob_b_counts, cls_a_prob_b_bins) = hist(cls_b_df[str(self.results.label_names.index(cls_a))])
+        cls_a_prob_a_counts, cls_a_prob_a_bins = hist(cls_a_df[str(self.results.label_names.index(cls_a))])
+        cls_a_prob_b_counts, cls_a_prob_b_bins = hist(cls_b_df[str(self.results.label_names.index(cls_a))])
         plt.subplot(2, 1, 1)
         plt.stairs(
            cls_a_prob_a_counts,
@@ -391,8 +391,8 @@
        )
         plt.legend(loc="upper center")
 
-        (cls_b_prob_a_counts, cls_b_prob_a_bins) = hist(cls_a_df[str(self.results.label_names.index(cls_b))])
-        (cls_b_prob_b_counts, cls_b_prob_b_bins) = hist(cls_b_df[str(self.results.label_names.index(cls_b))])
+        cls_b_prob_a_counts, cls_b_prob_a_bins = hist(cls_a_df[str(self.results.label_names.index(cls_b))])
+        cls_b_prob_b_counts, cls_b_prob_b_bins = hist(cls_b_df[str(self.results.label_names.index(cls_b))])
         plt.subplot(2, 1, 2)
         plt.stairs(
            cls_b_prob_b_counts,