birder 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (187)
  1. birder/adversarial/base.py +1 -1
  2. birder/adversarial/simba.py +4 -4
  3. birder/common/cli.py +1 -1
  4. birder/common/fs_ops.py +11 -11
  5. birder/common/lib.py +2 -2
  6. birder/common/masking.py +3 -3
  7. birder/common/training_cli.py +24 -2
  8. birder/common/training_utils.py +28 -4
  9. birder/data/collators/detection.py +9 -1
  10. birder/data/transforms/detection.py +27 -8
  11. birder/data/transforms/mosaic.py +1 -1
  12. birder/datahub/classification.py +3 -3
  13. birder/inference/classification.py +3 -3
  14. birder/inference/data_parallel.py +1 -1
  15. birder/inference/detection.py +5 -5
  16. birder/inference/wbf.py +1 -1
  17. birder/introspection/attention_rollout.py +5 -5
  18. birder/introspection/feature_pca.py +4 -4
  19. birder/introspection/gradcam.py +1 -1
  20. birder/introspection/guided_backprop.py +2 -2
  21. birder/introspection/transformer_attribution.py +3 -3
  22. birder/layers/attention_pool.py +2 -2
  23. birder/model_registry/model_registry.py +2 -1
  24. birder/net/__init__.py +2 -0
  25. birder/net/_rope_vit_configs.py +5 -0
  26. birder/net/_vit_configs.py +0 -13
  27. birder/net/alexnet.py +5 -5
  28. birder/net/base.py +28 -3
  29. birder/net/biformer.py +17 -17
  30. birder/net/cait.py +2 -2
  31. birder/net/cas_vit.py +1 -1
  32. birder/net/coat.py +15 -15
  33. birder/net/convnext_v1.py +2 -10
  34. birder/net/convnext_v1_iso.py +198 -0
  35. birder/net/convnext_v2.py +2 -10
  36. birder/net/crossformer.py +9 -9
  37. birder/net/crossvit.py +1 -1
  38. birder/net/cspnet.py +1 -1
  39. birder/net/cswin_transformer.py +10 -10
  40. birder/net/davit.py +10 -10
  41. birder/net/deit.py +56 -3
  42. birder/net/deit3.py +27 -15
  43. birder/net/detection/__init__.py +4 -0
  44. birder/net/detection/{yolo_anchors.py → _yolo_anchors.py} +5 -5
  45. birder/net/detection/base.py +6 -5
  46. birder/net/detection/deformable_detr.py +26 -28
  47. birder/net/detection/detr.py +9 -9
  48. birder/net/detection/efficientdet.py +9 -28
  49. birder/net/detection/faster_rcnn.py +22 -22
  50. birder/net/detection/fcos.py +8 -8
  51. birder/net/detection/plain_detr.py +852 -0
  52. birder/net/detection/retinanet.py +4 -4
  53. birder/net/detection/rt_detr_v1.py +81 -25
  54. birder/net/detection/rt_detr_v2.py +1147 -0
  55. birder/net/detection/ssd.py +5 -5
  56. birder/net/detection/yolo_v2.py +12 -12
  57. birder/net/detection/yolo_v3.py +19 -19
  58. birder/net/detection/yolo_v4.py +16 -16
  59. birder/net/detection/yolo_v4_tiny.py +3 -3
  60. birder/net/edgenext.py +3 -3
  61. birder/net/edgevit.py +10 -14
  62. birder/net/efficientformer_v1.py +1 -1
  63. birder/net/efficientvim.py +9 -9
  64. birder/net/efficientvit_mit.py +2 -2
  65. birder/net/efficientvit_msft.py +3 -3
  66. birder/net/fasternet.py +1 -1
  67. birder/net/fastvit.py +5 -12
  68. birder/net/flexivit.py +28 -15
  69. birder/net/focalnet.py +5 -9
  70. birder/net/gc_vit.py +11 -11
  71. birder/net/ghostnet_v1.py +1 -1
  72. birder/net/ghostnet_v2.py +1 -1
  73. birder/net/groupmixformer.py +12 -12
  74. birder/net/hgnet_v1.py +1 -1
  75. birder/net/hgnet_v2.py +4 -4
  76. birder/net/hiera.py +6 -6
  77. birder/net/hieradet.py +9 -9
  78. birder/net/hornet.py +3 -3
  79. birder/net/iformer.py +4 -4
  80. birder/net/inception_next.py +4 -14
  81. birder/net/levit.py +3 -3
  82. birder/net/lit_v1.py +13 -15
  83. birder/net/lit_v1_tiny.py +9 -9
  84. birder/net/lit_v2.py +14 -15
  85. birder/net/maxvit.py +10 -22
  86. birder/net/metaformer.py +2 -2
  87. birder/net/mim/crossmae.py +5 -5
  88. birder/net/mim/fcmae.py +3 -5
  89. birder/net/mim/mae_hiera.py +7 -7
  90. birder/net/mim/mae_vit.py +3 -5
  91. birder/net/mim/simmim.py +2 -3
  92. birder/net/mobilenet_v4_hybrid.py +4 -4
  93. birder/net/mobileone.py +5 -12
  94. birder/net/mobilevit_v1.py +2 -2
  95. birder/net/mobilevit_v2.py +5 -9
  96. birder/net/mvit_v2.py +24 -24
  97. birder/net/nextvit.py +2 -2
  98. birder/net/pit.py +11 -26
  99. birder/net/pvt_v1.py +4 -4
  100. birder/net/pvt_v2.py +5 -11
  101. birder/net/regionvit.py +15 -15
  102. birder/net/regnet.py +1 -1
  103. birder/net/repghost.py +4 -5
  104. birder/net/repvgg.py +3 -5
  105. birder/net/repvit.py +2 -2
  106. birder/net/resnest.py +1 -1
  107. birder/net/rope_deit3.py +29 -15
  108. birder/net/rope_flexivit.py +28 -15
  109. birder/net/rope_vit.py +41 -23
  110. birder/net/sequencer2d.py +3 -4
  111. birder/net/shufflenet_v1.py +1 -1
  112. birder/net/shufflenet_v2.py +1 -1
  113. birder/net/simple_vit.py +47 -5
  114. birder/net/smt.py +7 -7
  115. birder/net/ssl/barlow_twins.py +1 -1
  116. birder/net/ssl/byol.py +2 -2
  117. birder/net/ssl/capi.py +3 -3
  118. birder/net/ssl/data2vec2.py +1 -1
  119. birder/net/ssl/dino_v2.py +11 -1
  120. birder/net/ssl/franca.py +26 -2
  121. birder/net/ssl/i_jepa.py +4 -4
  122. birder/net/ssl/mmcr.py +1 -1
  123. birder/net/swiftformer.py +1 -1
  124. birder/net/swin_transformer_v1.py +4 -5
  125. birder/net/swin_transformer_v2.py +4 -7
  126. birder/net/tiny_vit.py +3 -3
  127. birder/net/transnext.py +19 -19
  128. birder/net/uniformer.py +4 -4
  129. birder/net/vgg.py +1 -10
  130. birder/net/vit.py +38 -25
  131. birder/net/vit_parallel.py +35 -20
  132. birder/net/vit_sam.py +10 -10
  133. birder/net/vovnet_v2.py +1 -1
  134. birder/net/xcit.py +9 -7
  135. birder/ops/msda.py +4 -4
  136. birder/ops/swattention.py +10 -10
  137. birder/results/classification.py +3 -3
  138. birder/results/gui.py +8 -8
  139. birder/scripts/benchmark.py +37 -12
  140. birder/scripts/evaluate.py +1 -1
  141. birder/scripts/predict.py +3 -3
  142. birder/scripts/predict_detection.py +2 -2
  143. birder/scripts/train.py +63 -15
  144. birder/scripts/train_barlow_twins.py +10 -7
  145. birder/scripts/train_byol.py +10 -7
  146. birder/scripts/train_capi.py +15 -10
  147. birder/scripts/train_data2vec.py +10 -7
  148. birder/scripts/train_data2vec2.py +10 -7
  149. birder/scripts/train_detection.py +29 -14
  150. birder/scripts/train_dino_v1.py +13 -9
  151. birder/scripts/train_dino_v2.py +27 -14
  152. birder/scripts/train_dino_v2_dist.py +28 -15
  153. birder/scripts/train_franca.py +16 -9
  154. birder/scripts/train_i_jepa.py +12 -9
  155. birder/scripts/train_ibot.py +15 -11
  156. birder/scripts/train_kd.py +64 -17
  157. birder/scripts/train_mim.py +11 -8
  158. birder/scripts/train_mmcr.py +11 -8
  159. birder/scripts/train_rotnet.py +11 -7
  160. birder/scripts/train_simclr.py +10 -7
  161. birder/scripts/train_vicreg.py +10 -7
  162. birder/tools/adversarial.py +4 -4
  163. birder/tools/auto_anchors.py +5 -5
  164. birder/tools/avg_model.py +1 -1
  165. birder/tools/convert_model.py +30 -22
  166. birder/tools/det_results.py +1 -1
  167. birder/tools/download_model.py +1 -1
  168. birder/tools/ensemble_model.py +1 -1
  169. birder/tools/introspection.py +11 -2
  170. birder/tools/labelme_to_coco.py +2 -2
  171. birder/tools/model_info.py +12 -14
  172. birder/tools/pack.py +8 -8
  173. birder/tools/quantize_model.py +53 -4
  174. birder/tools/results.py +2 -2
  175. birder/tools/show_det_iterator.py +19 -6
  176. birder/tools/show_iterator.py +2 -2
  177. birder/tools/similarity.py +5 -5
  178. birder/tools/stats.py +4 -6
  179. birder/tools/voc_to_coco.py +1 -1
  180. birder/version.py +1 -1
  181. {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/METADATA +3 -3
  182. birder-0.4.1.dist-info/RECORD +300 -0
  183. {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/WHEEL +1 -1
  184. birder-0.4.0.dist-info/RECORD +0 -297
  185. {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/entry_points.txt +0 -0
  186. {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/licenses/LICENSE +0 -0
  187. {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/top_level.txt +0 -0
birder/introspection/transformer_attribution.py CHANGED
@@ -66,7 +66,7 @@ def compute_attribution_rollout(
 
     mask = mask / (mask.max() + 1e-8)
 
-    (grid_h, grid_w) = patch_grid_shape
+    grid_h, grid_w = patch_grid_shape
     mask = mask.reshape(grid_h, grid_w)
 
     return mask
@@ -140,7 +140,7 @@ class TransformerAttribution:
         self.gatherer = AttributionGatherer(net, attention_layer_name)
 
     def __call__(self, image: str | Path | Image.Image, target_class: Optional[int] = None) -> InterpretabilityResult:
-        (input_tensor, rgb_img) = preprocess_image(image, self.transform, self.device)
+        input_tensor, rgb_img = preprocess_image(image, self.transform, self.device)
         input_tensor.requires_grad_(True)
 
         self.net.zero_grad()
@@ -156,7 +156,7 @@ class TransformerAttribution:
 
     attribution_data = self.gatherer.get_captured_data()
 
-        (_, _, H, W) = input_tensor.shape
+        _, _, H, W = input_tensor.shape
         patch_grid_shape = (H // self.net.stem_stride, W // self.net.stem_stride)
 
         attribution_map = compute_attribution_rollout(
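Most of the churn in this release is this mechanical style cleanup: parenthesized tuple unpacking such as (a, b) = f() is rewritten as a, b = f() throughout the code base. The two spellings are semantically identical, so none of these hunks change behavior; a quick illustration (ours, not from the diff):

# Parenthesized and bare tuple-unpacking targets compile to identical CPython bytecode
assert compile("(a, b) = t", "<s>", "exec").co_code == compile("a, b = t", "<s>", "exec").co_code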
birder/layers/attention_pool.py CHANGED
@@ -39,13 +39,13 @@ class MultiHeadAttentionPool(nn.Module):
         nn.init.trunc_normal_(self.latent, std=dim**-0.5)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        (B, N, C) = x.size()
+        B, N, C = x.size()
 
         q_latent = self.latent.expand(B, self.latent_len, -1)
         q = self.q(q_latent).reshape(B, self.latent_len, self.num_heads, self.head_dim).transpose(1, 2)
 
         kv = self.kv(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
-        (k, v) = kv.unbind(0)
+        k, v = kv.unbind(0)
 
         x = F.scaled_dot_product_attention(q, k, v, scale=self.scale)  # pylint: disable=not-callable
         x = x.transpose(1, 2).reshape(B, self.latent_len, C)
birder/model_registry/model_registry.py CHANGED
@@ -6,6 +6,7 @@ from typing import Any
 from typing import Literal
 from typing import Optional
 
+from birder.conf.settings import DEFAULT_NUM_CHANNELS
 from birder.model_registry import manifest
 
 if TYPE_CHECKING is True:
@@ -229,8 +230,8 @@ class ModelRegistry:
     def net_factory(
         self,
         name: str,
-        input_channels: int,
         num_classes: int,
+        input_channels: int = DEFAULT_NUM_CHANNELS,
         *,
         config: Optional[dict[str, Any]] = None,
         size: Optional[tuple[int, int]] = None,
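Note the subtle API break in net_factory(): num_classes moves up to the second positional slot and input_channels becomes optional, defaulting to DEFAULT_NUM_CHANNELS. Positional callers written against 0.4.0 need updating; a minimal sketch (our example, assuming the default is the usual 3 RGB channels):

from birder.model_registry import registry

# 0.4.0 signature: net_factory(name, input_channels, num_classes, *, ...)
# 0.4.1 signature: net_factory(name, num_classes, input_channels=DEFAULT_NUM_CHANNELS, *, ...)
net = registry.net_factory("convnext_v1_iso_small", 10)  # input_channels can now be omitted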
birder/net/__init__.py CHANGED
@@ -6,6 +6,7 @@ from birder.net.coat import CoaT
 from birder.net.conv2former import Conv2Former
 from birder.net.convmixer import ConvMixer
 from birder.net.convnext_v1 import ConvNeXt_v1
+from birder.net.convnext_v1_iso import ConvNeXt_v1_Isotropic
 from birder.net.convnext_v2 import ConvNeXt_v2
 from birder.net.crossformer import CrossFormer
 from birder.net.crossvit import CrossViT
@@ -118,6 +119,7 @@ __all__ = [
     "Conv2Former",
     "ConvMixer",
     "ConvNeXt_v1",
+    "ConvNeXt_v1_Isotropic",
     "ConvNeXt_v2",
     "CrossFormer",
     "CrossViT",
birder/net/_rope_vit_configs.py CHANGED
@@ -88,6 +88,11 @@ def register_rope_vit_configs(rope_vit: type[BaseNet]) -> None:
         rope_vit,
         config={"patch_size": 16, **SMALL},
     )
+    registry.register_model_config(
+        "rope_vit_s16_avg",
+        rope_vit,
+        config={"patch_size": 16, **SMALL, "class_token": False},
+    )
     registry.register_model_config(
         "rope_i_vit_s16_pn_aps_c1",  # For PE Core - https://arxiv.org/abs/2504.13181
         rope_vit,
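The new config differs from rope_vit_s16 only in class_token=False, i.e. the head works off pooled patch tokens rather than a CLS token (our reading of the _avg suffix). Instantiation goes through the registry as usual; a sketch:

from birder.model_registry import registry

net = registry.net_factory("rope_vit_s16_avg", 100)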
birder/net/_vit_configs.py CHANGED
@@ -215,19 +215,6 @@ def register_vit_configs(vit: type[BaseNet]) -> None:
             "drop_path_rate": 0.1,
         },
     )
-    registry.register_model_config(  # From "Scaling Vision Transformers to 22 Billion Parameters"
-        "vit_22b_p16_qkn",
-        vit,
-        config={
-            "patch_size": 16,
-            "num_layers": 48,
-            "num_heads": 48,
-            "hidden_dim": 6144,
-            "mlp_dim": 24576,
-            "qk_norm": True,
-            "drop_path_rate": 0.1,
-        },
-    )
 
     # With registers
     ####################
birder/net/alexnet.py CHANGED
@@ -27,17 +27,17 @@ class AlexNet(BaseNet):
         assert self.config is None, "config not supported"
 
         self.body = nn.Sequential(
-            nn.Conv2d(self.input_channels, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2), bias=True),
+            nn.Conv2d(self.input_channels, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2)),
             nn.ReLU(inplace=True),
             nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2), padding=(0, 0)),
-            nn.Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=True),
+            nn.Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2)),
             nn.ReLU(inplace=True),
             nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2), padding=(0, 0)),
-            nn.Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=True),
+            nn.Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
             nn.ReLU(inplace=True),
-            nn.Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=True),
+            nn.Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
             nn.ReLU(inplace=True),
-            nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=True),
+            nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
             nn.ReLU(inplace=True),
             nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2), padding=(0, 0)),
             nn.AdaptiveAvgPool2d(output_size=(6, 6)),
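These AlexNet edits only drop the explicit bias=True arguments; that is nn.Conv2d's default, so the built modules are unchanged. The convnext_v1.py hunks further down are the same cleanup. A quick check (illustrative, not from the diff):

from torch import nn

a = nn.Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2), bias=True)
b = nn.Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
assert (a.bias is not None) and (b.bias is not None)  # bias present either way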
birder/net/base.py CHANGED
@@ -5,6 +5,7 @@ from typing import Literal
 from typing import NotRequired
 from typing import Optional
 from typing import TypedDict
+from typing import overload
 
 import torch
 import torch.nn.functional as F
@@ -54,6 +55,30 @@ def make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> i
54
55
  return new_v
55
56
 
56
57
 
58
+ @overload
59
+ def normalize_out_indices(out_indices: None, num_layers: int) -> None: ...
60
+
61
+
62
+ @overload
63
+ def normalize_out_indices(out_indices: list[int], num_layers: int) -> list[int]: ...
64
+
65
+
66
+ def normalize_out_indices(out_indices: Optional[list[int]], num_layers: int) -> Optional[list[int]]:
67
+ if out_indices is None:
68
+ return None
69
+
70
+ normalized_indices = []
71
+ for idx in out_indices:
72
+ if idx < 0:
73
+ idx = num_layers + idx
74
+ if idx < 0 or idx >= num_layers:
75
+ raise ValueError(f"out_indices contains invalid index for num_layers={num_layers}")
76
+
77
+ normalized_indices.append(idx)
78
+
79
+ return normalized_indices
80
+
81
+
57
82
  # class MiscNet(nn.Module):
58
83
  # """
59
84
  # Base class for general-purpose neural networks with automatic model registration
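The new normalize_out_indices helper lets feature-extraction configs use Python-style negative indices, validated against the layer count, while the typing overloads keep the None passthrough type-safe. Its behavior, illustrated directly from the implementation above:

from birder.net.base import normalize_out_indices

assert normalize_out_indices(None, num_layers=18) is None
assert normalize_out_indices([-3, -2, -1], num_layers=18) == [15, 16, 17]
# normalize_out_indices([18], num_layers=18) raises ValueError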
@@ -137,8 +162,8 @@ class BaseNet(nn.Module):
 
         self.dynamic_size = False
 
-        self.classifier: nn.Module
         self.embedding_size: int
+        self.classifier: nn.Module
 
     def create_classifier(self, embed_dim: Optional[int] = None) -> nn.Module:
         if self.num_classes == 0:
@@ -274,7 +299,7 @@ def pos_embedding_sin_cos_2d(
 ) -> torch.Tensor:
     # assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sin-cos emb"
 
-    (y, x) = torch.meshgrid(torch.arange(h, device=device), torch.arange(w, device=device), indexing="ij")
+    y, x = torch.meshgrid(torch.arange(h, device=device), torch.arange(w, device=device), indexing="ij")
     omega = torch.arange(dim // 4, device=device) / (dim // 4 - 1)
     omega = 1.0 / (temperature**omega)
 
@@ -294,7 +319,7 @@ def interpolate_attention_bias(
     new_resolution: tuple[int, int],
     mode: Literal["bilinear", "bicubic"] = "bicubic",
 ) -> torch.Tensor:
-    (H, _) = attention_bias.size()
+    H, _ = attention_bias.size()
 
     # Interpolate
     orig_dtype = attention_bias.dtype
birder/net/biformer.py CHANGED
@@ -30,7 +30,7 @@ from birder.net.base import DetectorBackbone
 
 
 def _grid2seq(x: torch.Tensor, region_size: tuple[int, int], num_heads: int) -> tuple[torch.Tensor, int, int]:
-    (B, C, H, W) = x.size()
+    B, C, H, W = x.size()
     region_h = H // region_size[0]
     region_w = W // region_size[1]
     x = x.view(B, num_heads, C // num_heads, region_h, region_size[0], region_w, region_size[1])
@@ -40,7 +40,7 @@ def _grid2seq(x: torch.Tensor, region_size: tuple[int, int], num_heads: int) ->
 
 
 def _seq2grid(x: torch.Tensor, region_h: int, region_w: int, region_size: tuple[int, int]) -> torch.Tensor:
-    (bs, n_head, _, _, head_dim) = x.size()
+    bs, n_head, _, _, head_dim = x.size()
     x = x.view(bs, n_head, region_h, region_w, region_size[0], region_size[1], head_dim)
     x = torch.einsum("bmhwpqd->bmdhpwq", x).reshape(
         bs, n_head * head_dim, region_h * region_size[0], region_w * region_size[1]
@@ -60,7 +60,7 @@ def regional_routing_attention_torch(
     auto_pad: bool,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     kv_region_size = region_size
-    (bs, n_head, q_nregion, topk) = region_graph.size()
+    bs, n_head, q_nregion, topk = region_graph.size()
 
     # Pad to deal with any input size
     q_pad_b = 0
@@ -68,13 +68,13 @@ def regional_routing_attention_torch(
     kv_pad_b = 0
     kv_pad_r = 0
     if auto_pad is True:
-        (_, _, h_q, w_q) = query.size()
+        _, _, h_q, w_q = query.size()
         q_pad_b = (region_size[0] - h_q % region_size[0]) % region_size[0]
         q_pad_r = (region_size[1] - w_q % region_size[1]) % region_size[1]
         if q_pad_b > 0 or q_pad_r > 0:
             query = F.pad(query, (0, q_pad_r, 0, q_pad_b))
 
-        (_, _, h_k, w_k) = key.size()
+        _, _, h_k, w_k = key.size()
         kv_pad_b = (kv_region_size[0] - h_k % kv_region_size[0]) % kv_region_size[0]
         kv_pad_r = (kv_region_size[1] - w_k % kv_region_size[1]) % kv_region_size[1]
         if kv_pad_r > 0 or kv_pad_b > 0:
@@ -87,12 +87,12 @@
         w_k = None
 
     # To sequence format
-    (query, q_region_h, q_region_w) = _grid2seq(query, region_size=region_size, num_heads=n_head)
-    (key, _, _) = _grid2seq(key, region_size=kv_region_size, num_heads=n_head)
-    (value, _, _) = _grid2seq(value, region_size=kv_region_size, num_heads=n_head)
+    query, q_region_h, q_region_w = _grid2seq(query, region_size=region_size, num_heads=n_head)
+    key, _, _ = _grid2seq(key, region_size=kv_region_size, num_heads=n_head)
+    value, _, _ = _grid2seq(value, region_size=kv_region_size, num_heads=n_head)
 
     # Gather key and values
-    (bs, n_head, kv_nregion, kv_region_size, head_dim) = key.size()
+    bs, n_head, kv_nregion, kv_region_size, head_dim = key.size()
     broadcasted_region_graph = region_graph.view(bs, n_head, q_nregion, topk, 1, 1).expand(
         -1, -1, -1, -1, kv_region_size, head_dim
     )
@@ -146,12 +146,12 @@ class BiLevelRoutingAttention(nn.Module):
         self.output_linear = nn.Conv2d(dim, dim, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        (_, _, H, W) = x.size()
+        _, _, H, W = x.size()
         region_size = (H // self.n_win_h, W // self.n_win_w)
 
         # Linear projection
         qkv = self.qkv_linear(x)
-        (q, k, v) = qkv.chunk(3, dim=1)
+        q, k, v = qkv.chunk(3, dim=1)
 
         # Region-to-region routing
         q_r = F.avg_pool2d(  # pylint: disable=not-callable
@@ -163,11 +163,11 @@
         q_r = q_r.permute(0, 2, 3, 1).flatten(1, 2)  # (n, (hw), c)
         k_r = k_r.flatten(2, 3)  # (n, c, (hw))
         a_r = q_r @ k_r
-        (_, idx_r) = torch.topk(a_r, k=self.topk, dim=-1)
+        _, idx_r = torch.topk(a_r, k=self.topk, dim=-1)
         idx_r = idx_r.unsqueeze_(1).expand(-1, self.num_heads, -1, -1)
 
         # Token to token attention
-        (output, _) = regional_routing_attention_torch(
+        output, _ = regional_routing_attention_torch(
             q, k, v, scale=self.scale, region_graph=idx_r, region_size=region_size, auto_pad=True
         )
 
@@ -190,12 +190,12 @@ class Attention(nn.Module):
         self.proj_drop = nn.Dropout(proj_drop)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        (B, C, H, W) = x.size()
+        B, C, H, W = x.size()
         x = x.permute(0, 2, 3, 1).reshape(B, H * W, C)
 
         N = H * W
         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
-        (q, k, v) = qkv.unbind(0)
+        q, k, v = qkv.unbind(0)
 
         x = F.scaled_dot_product_attention(  # pylint: disable=not-callable
             q, k, v, dropout_p=self.attn_drop.p if self.training else 0.0, scale=self.scale
@@ -237,8 +237,8 @@ class AttentionLePE(nn.Module):
         )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        (B, C, H, W) = x.size()
-        (q, k, v) = self.qkv(x).chunk(3, dim=1)
+        B, C, H, W = x.size()
+        q, k, v = self.qkv(x).chunk(3, dim=1)
 
         attn = q.view(B, self.num_heads, self.head_dim, H * W).transpose(-1, -2) @ k.view(
             B, self.num_heads, self.head_dim, H * W
birder/net/cait.py CHANGED
@@ -47,7 +47,7 @@ class ClassAttention(nn.Module):
         self.proj_drop = nn.Dropout(proj_drop)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        (B, N, C) = x.shape
+        B, N, C = x.shape
         q = self.q(x[:, 0]).unsqueeze(1).reshape(B, 1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
         k = self.k(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
         v = self.v(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
@@ -103,7 +103,7 @@ class TalkingHeadAttn(nn.Module):
         self.proj_drop = nn.Dropout(proj_drop)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        (B, N, C) = x.shape
+        B, N, C = x.shape
         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
         q = qkv[0] * self.scale
         k = qkv[1]
birder/net/cas_vit.py CHANGED
@@ -122,7 +122,7 @@ class AdditiveTokenMixer(nn.Module):
         self.proj_drop = nn.Dropout(proj_drop)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        (q, k, v) = self.qkv(x).chunk(3, dim=1)
+        q, k, v = self.qkv(x).chunk(3, dim=1)
         q = self.op_q(q)
         k = self.op_k(k)
 
birder/net/coat.py CHANGED
@@ -57,8 +57,8 @@ class ConvRelPosEnc(nn.Module):
         self.channel_splits = [x * head_channels for x in head_splits]
 
     def forward(self, q: torch.Tensor, v: torch.Tensor, size: tuple[int, int]) -> torch.Tensor:
-        (B, num_heads, N, C) = q.size()
-        (H, W) = size
+        B, num_heads, N, C = q.size()
+        H, W = size
         torch._assert(N == 1 + H * W, "size mismatch")  # pylint: disable=protected-access
 
         # Convolutional relative position encoding.
@@ -102,11 +102,11 @@ class FactorAttnConvRelPosEnc(nn.Module):
         self.crpe = shared_crpe
 
     def forward(self, x: torch.Tensor, size: tuple[int, int]) -> torch.Tensor:
-        (B, N, C) = x.size()
+        B, N, C = x.size()
 
         # Generate Q, K, V
         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
-        (q, k, v) = qkv.unbind(0)  # [B, h, N, Ch]
+        q, k, v = qkv.unbind(0)  # [B, h, N, Ch]
 
         # Factorized attention
         k_softmax = k.softmax(dim=2)
@@ -135,8 +135,8 @@ class ConvPosEnc(nn.Module):
         )
 
     def forward(self, x: torch.Tensor, size: tuple[int, int]) -> torch.Tensor:
-        (B, N, C) = x.size()
-        (H, W) = size
+        B, N, C = x.size()
+        H, W = size
         torch._assert(N == 1 + H * W, "size mismatch")  # pylint: disable=protected-access
 
         # Extract CLS token and image tokens
@@ -244,8 +244,8 @@ class ParallelBlock(nn.Module):
         return self.interpolate(x, scale_factor=1.0 / factor, size=size)
 
     def interpolate(self, x: torch.Tensor, scale_factor: float, size: tuple[int, int]) -> torch.Tensor:
-        (B, N, C) = x.size()
-        (H, W) = size
+        B, N, C = x.size()
+        H, W = size
         torch._assert(N == 1 + H * W, "size mismatch")  # pylint: disable=protected-access
 
         cls_token = x[:, :1, :]
@@ -268,7 +268,7 @@
     def forward(
         self, x1: torch.Tensor, x2: torch.Tensor, x3: torch.Tensor, x4: torch.Tensor, sizes: list[tuple[int, int]]
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-        (_, s2, s3, s4) = sizes
+        _, s2, s3, s4 = sizes
         cur2 = self.norm12(x2)
         cur3 = self.norm13(x3)
         cur4 = self.norm14(x4)
@@ -310,7 +310,7 @@ class PatchEmbed(nn.Module):
 
     def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, tuple[int, int]]:
         x = self.proj(x)
-        (H, W) = x.shape[2:4]
+        H, W = x.shape[2:4]
 
         x = x.flatten(2).transpose(1, 2)
         x = self.norm(x)
@@ -500,7 +500,7 @@ class CoaT(DetectorBackbone):
         B = x.shape[0]
 
         # Serial blocks 1
-        (x1, (h1, w1)) = self.patch_embed1(x)
+        x1, (h1, w1) = self.patch_embed1(x)
         x1 = insert_cls(x1, self.cls_token1)
         for blk in self.serial_blocks1:
             x1 = blk(x1, size=(h1, w1))
@@ -508,7 +508,7 @@ class CoaT(DetectorBackbone):
         x1_no_cls = remove_cls(x1).reshape(B, h1, w1, -1).permute(0, 3, 1, 2).contiguous()
 
         # Serial blocks 2
-        (x2, (h2, w2)) = self.patch_embed2(x1_no_cls)
+        x2, (h2, w2) = self.patch_embed2(x1_no_cls)
         x2 = insert_cls(x2, self.cls_token2)
         for blk in self.serial_blocks2:
             x2 = blk(x2, size=(h2, w2))
@@ -516,7 +516,7 @@ class CoaT(DetectorBackbone):
         x2_no_cls = remove_cls(x2).reshape(B, h2, w2, -1).permute(0, 3, 1, 2).contiguous()
 
         # Serial blocks 3
-        (x3, (h3, w3)) = self.patch_embed3(x2_no_cls)
+        x3, (h3, w3) = self.patch_embed3(x2_no_cls)
         x3 = insert_cls(x3, self.cls_token3)
         for blk in self.serial_blocks3:
             x3 = blk(x3, size=(h3, w3))
@@ -524,7 +524,7 @@ class CoaT(DetectorBackbone):
         x3_no_cls = remove_cls(x3).reshape(B, h3, w3, -1).permute(0, 3, 1, 2).contiguous()
 
         # Serial blocks 4
-        (x4, (h4, w4)) = self.patch_embed4(x3_no_cls)
+        x4, (h4, w4) = self.patch_embed4(x3_no_cls)
         x4 = insert_cls(x4, self.cls_token4)
         for blk in self.serial_blocks4:
             x4 = blk(x4, size=(h4, w4))
@@ -537,7 +537,7 @@ class CoaT(DetectorBackbone):
             x2 = self.cpe2(x2, (h2, w2))
             x3 = self.cpe3(x3, (h3, w3))
             x4 = self.cpe4(x4, (h4, w4))
-            (x1, x2, x3, x4) = blk(x1, x2, x3, x4, sizes=[(h1, w1), (h2, w2), (h3, w3), (h4, w4)])
+            x1, x2, x3, x4 = blk(x1, x2, x3, x4, sizes=[(h1, w1), (h2, w2), (h3, w3), (h4, w4)])
 
         x1_no_cls = remove_cls(x1).reshape(B, h1, w1, -1).permute(0, 3, 1, 2).contiguous()
         x2_no_cls = remove_cls(x2).reshape(B, h2, w2, -1).permute(0, 3, 1, 2).contiguous()
birder/net/convnext_v1.py CHANGED
@@ -37,15 +37,7 @@ class ConvNeXtBlock(nn.Module):
     ) -> None:
         super().__init__()
         self.block = nn.Sequential(
-            nn.Conv2d(
-                channels,
-                channels,
-                kernel_size=(7, 7),
-                stride=(1, 1),
-                padding=(3, 3),
-                groups=channels,
-                bias=True,
-            ),
+            nn.Conv2d(channels, channels, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=channels),
             Permute([0, 2, 3, 1]),
             nn.LayerNorm(channels, eps=1e-6),
             nn.Linear(channels, 4 * channels),  # Same as 1x1 conv
@@ -119,7 +111,7 @@ class ConvNeXt_v1(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
             layers.append(
                 nn.Sequential(
                     LayerNorm2d(i, eps=1e-6),
-                    nn.Conv2d(i, out, kernel_size=(2, 2), stride=(2, 2), padding=(0, 0), bias=True),
+                    nn.Conv2d(i, out, kernel_size=(2, 2), stride=(2, 2), padding=(0, 0)),
                 )
             )
 
birder/net/convnext_v1_iso.py ADDED
@@ -0,0 +1,198 @@
+"""
+ConvNeXt v1 Isotropic, adapted from
+https://github.com/facebookresearch/ConvNeXt/blob/main/models/convnext_isotropic.py
+
+Paper "A ConvNet for the 2020s", https://arxiv.org/abs/2201.03545
+"""
+
+# Reference license: MIT
+
+from functools import partial
+from typing import Any
+from typing import Literal
+from typing import Optional
+
+import torch
+from torch import nn
+from torchvision.ops import Permute
+from torchvision.ops import StochasticDepth
+
+from birder.common.masking import mask_tensor
+from birder.layers import LayerNorm2d
+from birder.model_registry import registry
+from birder.net.base import DetectorBackbone
+from birder.net.base import MaskedTokenRetentionMixin
+from birder.net.base import PreTrainEncoder
+from birder.net.base import TokenRetentionResultType
+from birder.net.base import normalize_out_indices
+
+
+class ConvNeXtBlock(nn.Module):
+    def __init__(self, channels: int, stochastic_depth_prob: float) -> None:
+        super().__init__()
+        self.block = nn.Sequential(
+            nn.Conv2d(channels, channels, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=channels),
+            Permute([0, 2, 3, 1]),
+            nn.LayerNorm(channels, eps=1e-6),
+            nn.Linear(channels, 4 * channels),
+            nn.GELU(),
+            nn.Linear(4 * channels, channels),
+            Permute([0, 3, 1, 2]),
+        )
+        self.stochastic_depth = StochasticDepth(stochastic_depth_prob, mode="row")
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        identity = x
+        x = self.block(x)
+        x = self.stochastic_depth(x)
+        x += identity
+
+        return x
+
+
+# pylint: disable=invalid-name
+class ConvNeXt_v1_Isotropic(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
+    block_group_regex = r"body\.(\d+)"
+
+    def __init__(
+        self,
+        input_channels: int,
+        num_classes: int,
+        *,
+        config: Optional[dict[str, Any]] = None,
+        size: Optional[tuple[int, int]] = None,
+    ) -> None:
+        super().__init__(input_channels, num_classes, config=config, size=size)
+        assert self.config is not None, "must set config"
+
+        patch_size = 16
+        dim: int = self.config["dim"]
+        num_layers: int = self.config["num_layers"]
+        out_indices: Optional[list[int]] = self.config.get("out_indices", None)
+        drop_path_rate: float = self.config["drop_path_rate"]
+
+        torch._assert(self.size[0] % patch_size == 0, "Input shape indivisible by patch size!")
+        torch._assert(self.size[1] % patch_size == 0, "Input shape indivisible by patch size!")
+        self.patch_size = patch_size
+        self.out_indices = normalize_out_indices(out_indices, num_layers)
+
+        self.stem = nn.Conv2d(
+            self.input_channels,
+            dim,
+            kernel_size=(patch_size, patch_size),
+            stride=(patch_size, patch_size),
+            padding=(0, 0),
+        )
+
+        layers = []
+        for idx in range(num_layers):
+            # Adjust stochastic depth probability based on the depth of the stage block
+            sd_prob = drop_path_rate * idx / (num_layers - 1.0)
+            layers.append(ConvNeXtBlock(dim, sd_prob))
+
+        self.body = nn.Sequential(*layers)
+        self.features = nn.Sequential(
+            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
+            LayerNorm2d(dim, eps=1e-6),
+            nn.Flatten(1),
+        )
+
+        num_return_stages = len(self.out_indices) if self.out_indices is not None else 1
+        self.return_stages = [f"stage{stage_idx + 1}" for stage_idx in range(num_return_stages)]
+        self.return_channels = [dim] * num_return_stages
+        self.embedding_size = dim
+        self.classifier = self.create_classifier()
+
+        self.max_stride = patch_size
+        self.stem_stride = patch_size
+        self.stem_width = dim
+        self.encoding_size = dim
+        self.decoder_block = partial(ConvNeXtBlock, stochastic_depth_prob=0)
+
+        # Weights initialization
+        for m in self.modules():
+            if isinstance(m, (nn.Conv2d, nn.Linear)):
+                nn.init.trunc_normal_(m.weight, std=0.02)
+                if m.bias is not None:
+                    nn.init.zeros_(m.bias)
+
+    def detection_features(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
+        x = self.stem(x)
+
+        if self.out_indices is None:
+            x = self.body(x)
+            return {self.return_stages[0]: x}
+
+        stage_num = 0
+        out: dict[str, torch.Tensor] = {}
+        for idx, module in enumerate(self.body.children()):
+            x = module(x)
+            if idx in self.out_indices:
+                out[self.return_stages[stage_num]] = x
+                stage_num += 1
+
+        return out
+
+    def freeze_stages(self, up_to_stage: int) -> None:
+        for param in self.stem.parameters():
+            param.requires_grad_(False)
+
+        for idx, module in enumerate(self.body.children()):
+            if idx >= up_to_stage:
+                break
+
+            for param in module.parameters():
+                param.requires_grad_(False)
+
+    def masked_encoding_retention(
+        self,
+        x: torch.Tensor,
+        mask: torch.Tensor,
+        mask_token: Optional[torch.Tensor] = None,
+        return_keys: Literal["all", "features", "embedding"] = "features",
+    ) -> TokenRetentionResultType:
+        x = self.stem(x)
+        x = mask_tensor(x, mask, patch_factor=self.max_stride // self.stem_stride, mask_token=mask_token)
+        x = self.body(x)
+
+        result: TokenRetentionResultType = {}
+        if return_keys in ("all", "features"):
+            result["features"] = x
+        if return_keys in ("all", "embedding"):
+            result["embedding"] = self.features(x)
+
+        return result
+
+    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.stem(x)
+        return self.body(x)
+
+    def embedding(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.forward_features(x)
+        return self.features(x)
+
+    def adjust_size(self, new_size: tuple[int, int]) -> None:
+        if new_size == self.size:
+            return
+
+        assert new_size[0] % self.patch_size == 0, "Input shape indivisible by patch size!"
+        assert new_size[1] % self.patch_size == 0, "Input shape indivisible by patch size!"
+
+        super().adjust_size(new_size)
+
+
+registry.register_model_config(
+    "convnext_v1_iso_small",
+    ConvNeXt_v1_Isotropic,
+    config={"dim": 384, "num_layers": 18, "drop_path_rate": 0.1},
+)
+registry.register_model_config(
+    "convnext_v1_iso_base",
+    ConvNeXt_v1_Isotropic,
+    config={"dim": 768, "num_layers": 18, "drop_path_rate": 0.2},
+)
+registry.register_model_config(
+    "convnext_v1_iso_large",
+    ConvNeXt_v1_Isotropic,
+    config={"dim": 1024, "num_layers": 36, "drop_path_rate": 0.5},
+)