birder 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (187)
  1. birder/adversarial/base.py +1 -1
  2. birder/adversarial/simba.py +4 -4
  3. birder/common/cli.py +1 -1
  4. birder/common/fs_ops.py +11 -11
  5. birder/common/lib.py +2 -2
  6. birder/common/masking.py +3 -3
  7. birder/common/training_cli.py +24 -2
  8. birder/common/training_utils.py +28 -4
  9. birder/data/collators/detection.py +9 -1
  10. birder/data/transforms/detection.py +27 -8
  11. birder/data/transforms/mosaic.py +1 -1
  12. birder/datahub/classification.py +3 -3
  13. birder/inference/classification.py +3 -3
  14. birder/inference/data_parallel.py +1 -1
  15. birder/inference/detection.py +5 -5
  16. birder/inference/wbf.py +1 -1
  17. birder/introspection/attention_rollout.py +5 -5
  18. birder/introspection/feature_pca.py +4 -4
  19. birder/introspection/gradcam.py +1 -1
  20. birder/introspection/guided_backprop.py +2 -2
  21. birder/introspection/transformer_attribution.py +3 -3
  22. birder/layers/attention_pool.py +2 -2
  23. birder/model_registry/model_registry.py +2 -1
  24. birder/net/__init__.py +2 -0
  25. birder/net/_rope_vit_configs.py +5 -0
  26. birder/net/_vit_configs.py +0 -13
  27. birder/net/alexnet.py +5 -5
  28. birder/net/base.py +28 -3
  29. birder/net/biformer.py +17 -17
  30. birder/net/cait.py +2 -2
  31. birder/net/cas_vit.py +1 -1
  32. birder/net/coat.py +15 -15
  33. birder/net/convnext_v1.py +2 -10
  34. birder/net/convnext_v1_iso.py +198 -0
  35. birder/net/convnext_v2.py +2 -10
  36. birder/net/crossformer.py +9 -9
  37. birder/net/crossvit.py +1 -1
  38. birder/net/cspnet.py +1 -1
  39. birder/net/cswin_transformer.py +10 -10
  40. birder/net/davit.py +10 -10
  41. birder/net/deit.py +56 -3
  42. birder/net/deit3.py +27 -15
  43. birder/net/detection/__init__.py +4 -0
  44. birder/net/detection/{yolo_anchors.py → _yolo_anchors.py} +5 -5
  45. birder/net/detection/base.py +6 -5
  46. birder/net/detection/deformable_detr.py +26 -28
  47. birder/net/detection/detr.py +9 -9
  48. birder/net/detection/efficientdet.py +9 -28
  49. birder/net/detection/faster_rcnn.py +22 -22
  50. birder/net/detection/fcos.py +8 -8
  51. birder/net/detection/plain_detr.py +852 -0
  52. birder/net/detection/retinanet.py +4 -4
  53. birder/net/detection/rt_detr_v1.py +81 -25
  54. birder/net/detection/rt_detr_v2.py +1147 -0
  55. birder/net/detection/ssd.py +5 -5
  56. birder/net/detection/yolo_v2.py +12 -12
  57. birder/net/detection/yolo_v3.py +19 -19
  58. birder/net/detection/yolo_v4.py +16 -16
  59. birder/net/detection/yolo_v4_tiny.py +3 -3
  60. birder/net/edgenext.py +3 -3
  61. birder/net/edgevit.py +10 -14
  62. birder/net/efficientformer_v1.py +1 -1
  63. birder/net/efficientvim.py +9 -9
  64. birder/net/efficientvit_mit.py +2 -2
  65. birder/net/efficientvit_msft.py +3 -3
  66. birder/net/fasternet.py +1 -1
  67. birder/net/fastvit.py +5 -12
  68. birder/net/flexivit.py +28 -15
  69. birder/net/focalnet.py +5 -9
  70. birder/net/gc_vit.py +11 -11
  71. birder/net/ghostnet_v1.py +1 -1
  72. birder/net/ghostnet_v2.py +1 -1
  73. birder/net/groupmixformer.py +12 -12
  74. birder/net/hgnet_v1.py +1 -1
  75. birder/net/hgnet_v2.py +4 -4
  76. birder/net/hiera.py +6 -6
  77. birder/net/hieradet.py +9 -9
  78. birder/net/hornet.py +3 -3
  79. birder/net/iformer.py +4 -4
  80. birder/net/inception_next.py +4 -14
  81. birder/net/levit.py +3 -3
  82. birder/net/lit_v1.py +13 -15
  83. birder/net/lit_v1_tiny.py +9 -9
  84. birder/net/lit_v2.py +14 -15
  85. birder/net/maxvit.py +10 -22
  86. birder/net/metaformer.py +2 -2
  87. birder/net/mim/crossmae.py +5 -5
  88. birder/net/mim/fcmae.py +3 -5
  89. birder/net/mim/mae_hiera.py +7 -7
  90. birder/net/mim/mae_vit.py +3 -5
  91. birder/net/mim/simmim.py +2 -3
  92. birder/net/mobilenet_v4_hybrid.py +4 -4
  93. birder/net/mobileone.py +5 -12
  94. birder/net/mobilevit_v1.py +2 -2
  95. birder/net/mobilevit_v2.py +5 -9
  96. birder/net/mvit_v2.py +24 -24
  97. birder/net/nextvit.py +2 -2
  98. birder/net/pit.py +11 -26
  99. birder/net/pvt_v1.py +4 -4
  100. birder/net/pvt_v2.py +5 -11
  101. birder/net/regionvit.py +15 -15
  102. birder/net/regnet.py +1 -1
  103. birder/net/repghost.py +4 -5
  104. birder/net/repvgg.py +3 -5
  105. birder/net/repvit.py +2 -2
  106. birder/net/resnest.py +1 -1
  107. birder/net/rope_deit3.py +29 -15
  108. birder/net/rope_flexivit.py +28 -15
  109. birder/net/rope_vit.py +41 -23
  110. birder/net/sequencer2d.py +3 -4
  111. birder/net/shufflenet_v1.py +1 -1
  112. birder/net/shufflenet_v2.py +1 -1
  113. birder/net/simple_vit.py +47 -5
  114. birder/net/smt.py +7 -7
  115. birder/net/ssl/barlow_twins.py +1 -1
  116. birder/net/ssl/byol.py +2 -2
  117. birder/net/ssl/capi.py +3 -3
  118. birder/net/ssl/data2vec2.py +1 -1
  119. birder/net/ssl/dino_v2.py +11 -1
  120. birder/net/ssl/franca.py +26 -2
  121. birder/net/ssl/i_jepa.py +4 -4
  122. birder/net/ssl/mmcr.py +1 -1
  123. birder/net/swiftformer.py +1 -1
  124. birder/net/swin_transformer_v1.py +4 -5
  125. birder/net/swin_transformer_v2.py +4 -7
  126. birder/net/tiny_vit.py +3 -3
  127. birder/net/transnext.py +19 -19
  128. birder/net/uniformer.py +4 -4
  129. birder/net/vgg.py +1 -10
  130. birder/net/vit.py +38 -25
  131. birder/net/vit_parallel.py +35 -20
  132. birder/net/vit_sam.py +10 -10
  133. birder/net/vovnet_v2.py +1 -1
  134. birder/net/xcit.py +9 -7
  135. birder/ops/msda.py +4 -4
  136. birder/ops/swattention.py +10 -10
  137. birder/results/classification.py +3 -3
  138. birder/results/gui.py +8 -8
  139. birder/scripts/benchmark.py +37 -12
  140. birder/scripts/evaluate.py +1 -1
  141. birder/scripts/predict.py +3 -3
  142. birder/scripts/predict_detection.py +2 -2
  143. birder/scripts/train.py +63 -15
  144. birder/scripts/train_barlow_twins.py +10 -7
  145. birder/scripts/train_byol.py +10 -7
  146. birder/scripts/train_capi.py +15 -10
  147. birder/scripts/train_data2vec.py +10 -7
  148. birder/scripts/train_data2vec2.py +10 -7
  149. birder/scripts/train_detection.py +29 -14
  150. birder/scripts/train_dino_v1.py +13 -9
  151. birder/scripts/train_dino_v2.py +27 -14
  152. birder/scripts/train_dino_v2_dist.py +28 -15
  153. birder/scripts/train_franca.py +16 -9
  154. birder/scripts/train_i_jepa.py +12 -9
  155. birder/scripts/train_ibot.py +15 -11
  156. birder/scripts/train_kd.py +64 -17
  157. birder/scripts/train_mim.py +11 -8
  158. birder/scripts/train_mmcr.py +11 -8
  159. birder/scripts/train_rotnet.py +11 -7
  160. birder/scripts/train_simclr.py +10 -7
  161. birder/scripts/train_vicreg.py +10 -7
  162. birder/tools/adversarial.py +4 -4
  163. birder/tools/auto_anchors.py +5 -5
  164. birder/tools/avg_model.py +1 -1
  165. birder/tools/convert_model.py +30 -22
  166. birder/tools/det_results.py +1 -1
  167. birder/tools/download_model.py +1 -1
  168. birder/tools/ensemble_model.py +1 -1
  169. birder/tools/introspection.py +11 -2
  170. birder/tools/labelme_to_coco.py +2 -2
  171. birder/tools/model_info.py +12 -14
  172. birder/tools/pack.py +8 -8
  173. birder/tools/quantize_model.py +53 -4
  174. birder/tools/results.py +2 -2
  175. birder/tools/show_det_iterator.py +19 -6
  176. birder/tools/show_iterator.py +2 -2
  177. birder/tools/similarity.py +5 -5
  178. birder/tools/stats.py +4 -6
  179. birder/tools/voc_to_coco.py +1 -1
  180. birder/version.py +1 -1
  181. {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/METADATA +3 -3
  182. birder-0.4.1.dist-info/RECORD +300 -0
  183. {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/WHEEL +1 -1
  184. birder-0.4.0.dist-info/RECORD +0 -297
  185. {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/entry_points.txt +0 -0
  186. {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/licenses/LICENSE +0 -0
  187. {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/top_level.txt +0 -0
birder/net/iformer.py CHANGED
@@ -113,12 +113,12 @@ class LowMixer(nn.Module):
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.pool(x)
-        (B, _, H, W) = x.size()
+        B, _, H, W = x.size()
         x = x.permute(0, 2, 3, 1).view(B, -1, self.dim)
 
-        (B, N, C) = x.size()
+        B, N, C = x.size()
         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
-        (q, k, v) = qkv.unbind(0)
+        q, k, v = qkv.unbind(0)
         x = F.scaled_dot_product_attention(  # pylint: disable=not-callable
             q, k, v, dropout_p=self.attn_drop if self.training else 0.0, scale=self.scale
         )
@@ -301,7 +301,7 @@ class InceptionTransformerStage(nn.Module):
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.downsample(x)
-        (H, W) = x.shape[1:3]
+        H, W = x.shape[1:3]
 
         x = x + self._get_pos_embed(H, W)
         x = self.blocks(x)
birder/net/inception_next.py CHANGED
@@ -33,7 +33,6 @@ class InceptionDWConv2d(nn.Module):
             stride=(1, 1),
             padding=square_kernel_size // 2,
             groups=branch_channels,
-            bias=True,
         )
         self.dwconv_w = nn.Conv2d(
             branch_channels,
@@ -42,7 +41,6 @@ class InceptionDWConv2d(nn.Module):
             stride=(1, 1),
             padding=(0, band_kernel_size // 2),
             groups=branch_channels,
-            bias=True,
         )
         self.dwconv_h = nn.Conv2d(
             branch_channels,
@@ -51,7 +49,6 @@ class InceptionDWConv2d(nn.Module):
             stride=(1, 1),
             padding=(band_kernel_size // 2, 0),
             groups=branch_channels,
-            bias=True,
         )
         self.split_indexes = (
             in_channels - (3 * branch_channels),
@@ -61,7 +58,7 @@ class InceptionDWConv2d(nn.Module):
         )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        (x_id, x_hw, x_w, x_h) = torch.split(x, self.split_indexes, dim=1)
+        x_id, x_hw, x_w, x_h = torch.split(x, self.split_indexes, dim=1)
         x_hw = self.dwconv_hw(x_hw)
         x_w = self.dwconv_w(x_w)
         x_h = self.dwconv_h(x_h)
@@ -78,11 +75,9 @@ class ConvMLP(nn.Module):
         act_layer: Callable[..., nn.Module] = nn.GELU,
     ) -> None:
         super().__init__()
-        self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True)
+        self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))
         self.act = act_layer()
-        self.fc2 = nn.Conv2d(
-            hidden_features, out_features, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True
-        )
+        self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.fc1(x)
@@ -139,12 +134,7 @@ class InceptionNeXtStage(nn.Module):
         self.downsample = nn.Sequential(
             nn.BatchNorm2d(in_channels),
             nn.Conv2d(
-                in_channels,
-                out_channels,
-                kernel_size=(stride, stride),
-                stride=(stride, stride),
-                padding=(0, 0),
-                bias=True,
+                in_channels, out_channels, kernel_size=(stride, stride), stride=(stride, stride), padding=(0, 0)
             ),
         )
 
birder/net/levit.py CHANGED
@@ -45,7 +45,7 @@ class Subsample(nn.Module):
         self.resolution = resolution
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        (B, _, C) = x.shape
+        B, _, C = x.shape
         x = x.view(B, self.resolution[0], self.resolution[1], C)
         x = x[:, :: self.stride, :: self.stride]
         return x.reshape(B, -1, C)
@@ -84,7 +84,7 @@ class Attention(nn.Module):
         self.attention_bias_idxs = nn.Buffer(rel_pos, persistent=False)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        (B, N, _) = x.shape
+        B, N, _ = x.shape
         q, k, v = self.qkv(x).view(B, N, self.num_heads, -1).split([self.key_dim, self.key_dim, self.val_dim], dim=3)
         q = q.permute(0, 2, 1, 3)
         k = k.permute(0, 2, 3, 1)
@@ -144,7 +144,7 @@ class AttentionSubsample(nn.Module):
         self.attention_bias_idxs = nn.Buffer(rel_pos, persistent=False)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        (B, N, _) = x.shape
+        B, N, _ = x.shape
         k, v = self.kv(x).view(B, N, self.num_heads, -1).split([self.key_dim, self.val_dim], dim=3)
         k = k.permute(0, 2, 3, 1)  # BHCN
         v = v.permute(0, 2, 1, 3)  # BHNC
birder/net/lit_v1.py CHANGED
@@ -43,7 +43,7 @@ def interpolate_rel_pos_bias_table(
     if new_resolution == base_resolution:
         return rel_pos_bias_table
 
-    (base_h, base_w) = base_resolution
+    base_h, base_w = base_resolution
     num_heads = rel_pos_bias_table.size(1)
     orig_dtype = rel_pos_bias_table.dtype
     bias_table = rel_pos_bias_table.float()
@@ -104,7 +104,7 @@ class RelPosAttention(nn.Module):
         relative_position_index = build_relative_position_index(input_resolution, device=bias_table.device)
         self.relative_position_index = nn.Buffer(relative_position_index)
 
-        self.qkv = nn.Linear(dim, dim * 3, bias=True)
+        self.qkv = nn.Linear(dim, dim * 3)
         self.proj = nn.Linear(dim, dim)
 
         # Weight initialization
@@ -130,9 +130,9 @@
         return relative_position_bias.unsqueeze(0)
 
     def forward(self, x: torch.Tensor, resolution: tuple[int, int]) -> torch.Tensor:
-        (B, N, C) = x.size()
+        B, N, C = x.size()
         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
-        (q, k, v) = qkv.unbind(0)
+        q, k, v = qkv.unbind(0)
 
         attn = (q * self.scale) @ k.transpose(-2, -1)
         attn = attn + self._get_rel_pos_bias(resolution)
@@ -177,7 +177,6 @@ class DeformablePatchMerging(nn.Module):
             kernel_size=(kernel_size, kernel_size),
             stride=(kernel_size, kernel_size),
             padding=(0, 0),
-            bias=True,
         )
         self.deform_conv = DeformConv2d(
             in_dim,
@@ -195,8 +194,8 @@
         nn.init.zeros_(self.offset_conv.bias)
 
     def forward(self, x: torch.Tensor, resolution: tuple[int, int]) -> tuple[torch.Tensor, int, int]:
-        (H, W) = resolution
-        (B, _, C) = x.size()
+        H, W = resolution
+        B, _, C = x.size()
 
         x = x.reshape(B, H, W, C).permute(0, 3, 1, 2).contiguous()
 
@@ -206,7 +205,7 @@
         x = self.norm(x)
         x = self.act(x)
 
-        (B, C, H, W) = x.size()
+        B, C, H, W = x.size()
         x = x.permute(0, 2, 3, 1).reshape(B, H * W, C)
 
         return (x, H, W)
@@ -252,7 +251,7 @@ class LITStage(nn.Module):
             block.set_dynamic_size(dynamic_size)
 
     def forward(self, x: torch.Tensor, input_resolution: tuple[int, int]) -> tuple[torch.Tensor, int, int]:
-        (x, H, W) = self.downsample(x, input_resolution)
+        x, H, W = self.downsample(x, input_resolution)
         for block in self.blocks:
             x = block(x, (H, W))
 
@@ -291,7 +290,6 @@ class LIT_v1(DetectorBackbone):
                 kernel_size=(patch_size, patch_size),
                 stride=(patch_size, patch_size),
                 padding=(0, 0),
-                bias=True,
             ),
             Permute([0, 2, 3, 1]),
             nn.LayerNorm(embed_dim),
@@ -361,12 +359,12 @@
 
     def detection_features(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
         x = self.stem(x)
-        (B, H, W, C) = x.size()
+        B, H, W, C = x.size()
         x = x.reshape(B, H * W, C)
 
         out = {}
         for name, stage in self.body.items():
-            (x, H, W) = stage(x, (H, W))
+            x, H, W = stage(x, (H, W))
             if name in self.return_stages:
                 features = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
                 out[name] = features
@@ -386,10 +384,10 @@
 
     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         x = self.stem(x)
-        (B, H, W, C) = x.size()
+        B, H, W, C = x.size()
         x = x.reshape(B, H * W, C)
         for stage in self.body.values():
-            (x, H, W) = stage(x, (H, W))
+            x, H, W = stage(x, (H, W))
 
         return x
 
@@ -410,7 +408,7 @@
 
         new_patches_resolution = (new_size[0] // self.patch_size, new_size[1] // self.patch_size)
 
-        (h, w) = new_patches_resolution
+        h, w = new_patches_resolution
         for stage in self.body.values():
             if not isinstance(stage.downsample, IdentityDownsample):
                 h = h // 2
birder/net/lit_v1_tiny.py CHANGED
@@ -44,13 +44,13 @@ class Attention(nn.Module):
         super().__init__()
         self.num_heads = num_heads
         self.scale = (dim // num_heads) ** -0.5
-        self.qkv = nn.Linear(dim, dim * 3, bias=True)
+        self.qkv = nn.Linear(dim, dim * 3)
         self.proj = nn.Linear(dim, dim)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        (B, N, C) = x.size()
+        B, N, C = x.size()
         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
-        (q, k, v) = qkv.unbind(0)
+        q, k, v = qkv.unbind(0)
 
         attn = (q @ k.transpose(-2, -1)) * self.scale
         attn = F.softmax(attn, dim=-1)
@@ -139,7 +139,7 @@ class LITStage(nn.Module):
         )
 
     def forward(self, x: torch.Tensor, input_resolution: tuple[int, int]) -> tuple[torch.Tensor, int, int]:
-        (x, H, W) = self.downsample(x, input_resolution)
+        x, H, W = self.downsample(x, input_resolution)
 
         if self.cls_token is not None:
             cls_tokens = self.cls_token.expand(x.size(0), -1, -1)
@@ -247,12 +247,12 @@ class LIT_v1_Tiny(DetectorBackbone):
 
     def detection_features(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
         x = self.stem(x)
-        (B, H, W, C) = x.size()
+        B, H, W, C = x.size()
         x = x.reshape(B, H * W, C)
 
         out = {}
         for name, stage in self.body.items():
-            (x, H, W) = stage(x, (H, W))
+            x, H, W = stage(x, (H, W))
             if name in self.return_stages:
                 if stage.cls_token is not None:
                     spatial_x = x[:, 1:]
@@ -276,10 +276,10 @@
 
     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         x = self.stem(x)
-        (B, H, W, C) = x.size()
+        B, H, W, C = x.size()
         x = x.reshape(B, H * W, C)
         for stage in self.body.values():
-            (x, H, W) = stage(x, (H, W))
+            x, H, W = stage(x, (H, W))
 
         return x
 
@@ -301,7 +301,7 @@ class LIT_v1_Tiny(DetectorBackbone):
 
         new_patches_resolution = (new_size[0] // self.patch_size, new_size[1] // self.patch_size)
 
-        (h, w) = new_patches_resolution
+        h, w = new_patches_resolution
         for stage in self.body.values():
             if not isinstance(stage.downsample, IdentityDownsample):
                 h = h // 2
birder/net/lit_v2.py CHANGED
@@ -39,7 +39,7 @@ class DepthwiseMLP(nn.Module):
     def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
         x = self.fc1(x)
 
-        (B, N, C) = x.size()
+        B, N, C = x.size()
         x = x.reshape(B, H, W, C).permute(0, 3, 1, 2).contiguous()
         x = self.dwconv(x)
         x = x.permute(0, 2, 3, 1).reshape(B, N, C)
@@ -57,7 +57,7 @@ class DepthwiseMLPBlock(nn.Module):
         self.drop_path = StochasticDepth(drop_path, mode="row")
 
     def forward(self, x: torch.Tensor, resolution: tuple[int, int]) -> torch.Tensor:
-        (H, W) = resolution
+        H, W = resolution
         return x + self.drop_path(self.mlp(self.norm(x), H, W))
 
 
@@ -121,7 +121,7 @@ class HiLoAttention(nn.Module):
             self.h_proj = nn.Identity()
 
     def _lofi(self, x: torch.Tensor) -> torch.Tensor:
-        (B, H, W, C) = x.size()
+        B, H, W, C = x.size()
 
         q = self.l_q(x).reshape(B, H * W, self.l_heads, self.head_dim).permute(0, 2, 1, 3)
 
@@ -133,7 +133,7 @@
         else:
             kv = self.l_kv(x).reshape(B, -1, 2, self.l_heads, self.head_dim).permute(2, 0, 3, 1, 4)
 
-        (k, v) = kv.unbind(0)
+        k, v = kv.unbind(0)
 
         attn = (q @ k.transpose(-2, -1)) * self.scale
         attn = F.softmax(attn, dim=-1)
@@ -144,7 +144,7 @@
         return x
 
     def _hifi(self, x: torch.Tensor) -> torch.Tensor:
-        (B, H, W, _) = x.size()
+        B, H, W, _ = x.size()
         ws = self.window_size
 
         # Pad if needed
@@ -153,7 +153,7 @@
         if pad_h > 0 or pad_w > 0:
             x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
 
-        (_, h_pad, w_pad, _) = x.size()
+        _, h_pad, w_pad, _ = x.size()
         h_groups = h_pad // ws
         w_groups = w_pad // ws
         total_groups = h_groups * w_groups
@@ -161,7 +161,7 @@
         x = x.reshape(B, h_groups, ws, w_groups, ws, -1).transpose(2, 3)
 
         qkv = self.h_qkv(x).reshape(B, total_groups, -1, 3, self.h_heads, self.head_dim).permute(3, 0, 1, 4, 2, 5)
-        (q, k, v) = qkv.unbind(0)
+        q, k, v = qkv.unbind(0)
 
         attn = (q @ k.transpose(-2, -1)) * self.scale
         attn = F.softmax(attn, dim=-1)
@@ -177,7 +177,7 @@
         return x
 
     def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
-        (B, N, C) = x.size()
+        B, N, C = x.size()
         x = x.reshape(B, H, W, C)
 
         if self.h_heads == 0:
@@ -215,7 +215,7 @@ class HiLoBlock(nn.Module):
         self.drop_path2 = StochasticDepth(drop_path, mode="row")
 
     def forward(self, x: torch.Tensor, resolution: tuple[int, int]) -> torch.Tensor:
-        (H, W) = resolution
+        H, W = resolution
         x = x + self.drop_path1(self.attn(self.norm1(x), H, W))
         x = x + self.drop_path2(self.mlp(self.norm2(x), H, W))
         return x
@@ -252,7 +252,7 @@ class LITStage(nn.Module):
         self.blocks = nn.ModuleList(blocks)
 
     def forward(self, x: torch.Tensor, input_resolution: tuple[int, int]) -> tuple[torch.Tensor, int, int]:
-        (x, H, W) = self.downsample(x, input_resolution)
+        x, H, W = self.downsample(x, input_resolution)
         for block in self.blocks:
             x = block(x, (H, W))
 
@@ -292,7 +292,6 @@ class LIT_v2(DetectorBackbone):
                 kernel_size=(patch_size, patch_size),
                 stride=(patch_size, patch_size),
                 padding=(0, 0),
-                bias=True,
             ),
             Permute([0, 2, 3, 1]),
             nn.LayerNorm(embed_dim),
@@ -361,12 +360,12 @@
 
    def detection_features(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
        x = self.stem(x)
-        (B, H, W, C) = x.size()
+        B, H, W, C = x.size()
        x = x.reshape(B, H * W, C)
 
        out = {}
        for name, stage in self.body.items():
-            (x, H, W) = stage(x, (H, W))
+            x, H, W = stage(x, (H, W))
            if name in self.return_stages:
                features = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
                out[name] = features
@@ -386,10 +385,10 @@
 
    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        x = self.stem(x)
-        (B, H, W, C) = x.size()
+        B, H, W, C = x.size()
        x = x.reshape(B, H * W, C)
        for stage in self.body.values():
-            (x, H, W) = stage(x, (H, W))
+            x, H, W = stage(x, (H, W))
 
        return x
 
birder/net/maxvit.py CHANGED
@@ -83,7 +83,7 @@ class MBConv(nn.Module):
         if stride[0] != 1 or stride[1] != 1 or in_channels != out_channels:
             self.proj = nn.Sequential(
                 nn.AvgPool2d(kernel_size=(3, 3), stride=stride, padding=(1, 1)),
-                nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True),
+                nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)),
             )
         else:
             self.proj = nn.Identity()
@@ -119,12 +119,7 @@
             ),
             SqueezeExcitation(mid_channels, sqz_channels, activation=nn.SiLU),
             nn.Conv2d(
-                in_channels=mid_channels,
-                out_channels=out_channels,
-                kernel_size=(1, 1),
-                stride=(1, 1),
-                padding=(0, 0),
-                bias=True,
+                in_channels=mid_channels, out_channels=out_channels, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)
             ),
         )
 
@@ -169,12 +164,12 @@ class RelativePositionalMultiHeadAttention(nn.Module):
 
     # pylint: disable=invalid-name
    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        (B, G, P, D) = x.size()
+        B, G, P, D = x.size()
        H = self.n_heads
        DH = self.head_dim
 
        qkv = self.to_qkv(x)
-        (q, k, v) = torch.chunk(qkv, 3, dim=-1)
+        q, k, v = torch.chunk(qkv, 3, dim=-1)
 
        q = q.reshape(B, G, P, H, DH).permute(0, 1, 3, 2, 4)
        k = k.reshape(B, G, P, H, DH).permute(0, 1, 3, 2, 4)
@@ -206,8 +201,8 @@ class SwapAxes(nn.Module):
 
 class WindowPartition(nn.Module):
     def forward(self, x: torch.Tensor, p: tuple[int, int]) -> torch.Tensor:
-        (B, C, H, W) = x.size()
-        (PH, PW) = p  # pylint: disable=invalid-name
+        B, C, H, W = x.size()
+        PH, PW = p  # pylint: disable=invalid-name
 
         # Chunk up H and W dimensions
         x = x.reshape(B, C, H // PH, PH, W // PW, PW)
@@ -222,8 +217,8 @@
 class WindowDepartition(nn.Module):
     # pylint: disable=invalid-name
     def forward(self, x: torch.Tensor, p: tuple[int, int], h_partitions: int, w_partitions: int) -> torch.Tensor:
-        (B, _G, _PP, C) = x.size()
-        (PH, PW) = p  # pylint: disable=invalid-name
+        B, _G, _PP, C = x.size()
+        PH, PW = p  # pylint: disable=invalid-name
         HP = h_partitions
         WP = w_partitions
 
@@ -500,14 +495,7 @@ class MaxViT(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
                 activation_layer=nn.GELU,
                 inplace=None,
             ),
-            nn.Conv2d(
-                stem_channels,
-                stem_channels,
-                kernel_size=(3, 3),
-                stride=(1, 1),
-                padding=(1, 1),
-                bias=True,
-            ),
+            nn.Conv2d(stem_channels, stem_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
         )
 
         # Account for stem stride
@@ -706,7 +694,7 @@
             src_size = (2 * old_attn_size[0] - 1, 2 * old_attn_size[1] - 1)
 
             def _calc(src: int, dst: int) -> list[float]:
-                (left, right) = 1.01, 1.5
+                left, right = 1.01, 1.5
                 while right - left > 1e-6:
                     q = (left + right) / 2.0
                     gp = (1.0 - q ** (src // 2)) / (1.0 - q)  # Geometric progression
birder/net/metaformer.py CHANGED
@@ -127,10 +127,10 @@ class Attention(nn.Module):
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
-        (B, H, W, _) = x.shape
+        B, H, W, _ = x.shape
         N = H * W
         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
-        (q, k, v) = qkv.unbind(0)
+        q, k, v = qkv.unbind(0)
 
         x = F.scaled_dot_product_attention(  # pylint: disable=not-callable
             q, k, v, dropout_p=self.attn_drop.p if self.training else 0.0, scale=self.scale
birder/net/mim/crossmae.py CHANGED
@@ -46,11 +46,11 @@ class CrossAttention(nn.Module):
         self.proj = nn.Linear(decoder_dim, decoder_dim)
 
     def forward(self, tgt: torch.Tensor, memory: torch.Tensor) -> torch.Tensor:
-        (B, N, C) = tgt.size()
+        B, N, C = tgt.size()
         n_kv = memory.size(1)
         q = self.q(tgt).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
         kv = self.kv(memory).reshape(B, n_kv, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
-        (k, v) = kv.unbind(0)
+        k, v = kv.unbind(0)
 
         attn = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)  # pylint: disable=not-callable
         x = attn.transpose(1, 2).reshape(B, N, C)
@@ -120,7 +120,7 @@ class CrossMAE(MIMBaseNet):
             self.decoder_layers.append(CrossAttentionBlock(encoder_dim, decoder_embed_dim, num_heads=16, mlp_ratio=4.0))
 
         self.decoder_norm = nn.LayerNorm(decoder_embed_dim, eps=1e-6)
-        self.pred = nn.Linear(decoder_embed_dim, self.patch_size**2 * self.input_channels, bias=True)
+        self.pred = nn.Linear(decoder_embed_dim, self.patch_size**2 * self.input_channels)
 
         # Weight initialization
         for m in self.modules():
@@ -170,7 +170,7 @@
         return imgs
 
     def fill_pred(self, mask: torch.Tensor, pred: torch.Tensor) -> torch.Tensor:
-        (N, L) = mask.shape[0:2]
+        N, L = mask.shape[0:2]
         combined = torch.zeros(N, L, pred.shape[2], device=pred.device, dtype=pred.dtype)
         combined[mask.bool()] = pred.view(-1, pred.shape[2])
 
@@ -213,7 +213,7 @@
     def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
         h = self.size[0] // self.encoder.stem_stride
         w = self.size[1] // self.encoder.stem_stride
-        (mask, ids_keep, _) = uniform_mask(
+        mask, ids_keep, _ = uniform_mask(
             x.size(0), h, w, self.mask_ratio, self.kept_mask_ratio, min_mask_size=self.min_mask_size, device=x.device
         )
 
birder/net/mim/fcmae.py CHANGED
@@ -48,7 +48,6 @@ class FCMAE(MIMBaseNet):
             kernel_size=(1, 1),
             stride=(1, 1),
             padding=(0, 0),
-            bias=True,
         )
 
         self.mask_token = nn.Parameter(torch.zeros(1, self.decoder_embed_dim, 1, 1))
@@ -65,7 +64,6 @@
             kernel_size=(1, 1),
             stride=(1, 1),
             padding=(0, 0),
-            bias=True,
         )
 
         # Weights initialization
@@ -106,7 +104,7 @@
         """
 
         if x.ndim == 4:
-            (n, c, _, _) = x.shape
+            n, c, _, _ = x.shape
             x = x.reshape(n, c, -1)
             x = torch.einsum("ncl->nlc", x)
 
@@ -125,7 +123,7 @@
         x = self.proj(x)
 
         # Append mask token
-        (B, _, H, W) = x.shape
+        B, _, H, W = x.shape
         mask = mask.reshape(-1, H, W).unsqueeze(1).type_as(x)
         mask_token = self.mask_token.repeat(B, 1, H, W)
         x = x * (1.0 - mask) + (mask_token * mask)
@@ -141,7 +139,7 @@
         mask: 0 is keep, 1 is remove
         """
 
-        (n, c, _, _) = pred.shape
+        n, c, _, _ = pred.shape
         pred = pred.reshape(n, c, -1)
         pred = torch.einsum("ncl->nlc", pred)
 
birder/net/mim/mae_hiera.py CHANGED
@@ -26,7 +26,7 @@ def apply_fusion_head(head: nn.Module, x: torch.Tensor) -> torch.Tensor:
     if isinstance(head, nn.Identity):
         return x
 
-    (B, num_mask_units) = x.shape[0:2]
+    B, num_mask_units = x.shape[0:2]
 
     # Apply head, e.g [B, #MUs, My, Mx, C] -> head([B * #MUs, C, My, Mx])
     permute = [0] + [len(x.shape) - 2] + list(range(1, len(x.shape) - 2))
@@ -169,7 +169,7 @@ class MAE_Hiera(MIMBaseNet):
 
     def forward_encoder(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         # Tokens selected for masking at mask unit level
-        (mask, _, _) = uniform_mask(
+        mask, _, _ = uniform_mask(
             x.size(0),
             self.encoder.mask_spatial_shape[0],
             self.encoder.mask_spatial_shape[1],
@@ -179,7 +179,7 @@
         )
 
         # Get multi-scale representations from encoder
-        (intermediates, mask) = self.encoder.masked_encoding(x, mask)
+        intermediates, mask = self.encoder.masked_encoding(x, mask)
 
         # Resolution unchanged after q_pool stages, so skip those features
         intermediates = intermediates[: self.encoder.q_pool] + intermediates[-1:]
@@ -206,12 +206,12 @@
         # Get back spatial order
         x = undo_windowing(
             x_dec,
-            self.tokens_spatial_shape_final,  # type:ignore[arg-type]
+            self.tokens_spatial_shape_final,  # type: ignore[arg-type]
             self.mask_unit_spatial_shape_final,
         )
         mask = undo_windowing(
             mask[..., 0:1],
-            self.tokens_spatial_shape_final,  # type:ignore[arg-type]
+            self.tokens_spatial_shape_final,  # type: ignore[arg-type]
             self.mask_unit_spatial_shape_final,
         )
 
@@ -240,8 +240,8 @@
         return loss.mean()
 
     def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
-        (latent, mask) = self.forward_encoder(x)
-        (pred, pred_mask) = self.forward_decoder(latent, mask)
+        latent, mask = self.forward_encoder(x)
+        pred, pred_mask = self.forward_decoder(latent, mask)
         loss = self.forward_loss(x, pred, ~pred_mask)
 
         return {"loss": loss, "pred": pred, "mask": mask}
birder/net/mim/mae_vit.py CHANGED
@@ -52,7 +52,7 @@ class MAE_ViT(MIMBaseNet):
 
         self.norm_pix_loss = norm_pix_loss
 
-        self.decoder_embed = nn.Linear(encoder_dim, decoder_embed_dim, bias=True)
+        self.decoder_embed = nn.Linear(encoder_dim, decoder_embed_dim)
         self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
 
         if learnable_pos_embed is True:
@@ -74,9 +74,7 @@
             layers.append(self.encoder.decoder_block(decoder_embed_dim))
 
         layers.append(nn.LayerNorm(decoder_embed_dim, eps=1e-6))
-        layers.append(
-            nn.Linear(decoder_embed_dim, self.patch_size**2 * self.input_channels, bias=True)
-        )  # Decoder to patch
+        layers.append(nn.Linear(decoder_embed_dim, self.patch_size**2 * self.input_channels))  # Decoder to patch
         self.decoder = nn.Sequential(*layers)
 
     def patchify(self, imgs: torch.Tensor) -> torch.Tensor:
@@ -153,7 +151,7 @@
     def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
         h = self.size[0] // self.encoder.max_stride
         w = self.size[1] // self.encoder.max_stride
-        (mask, ids_keep, ids_restore) = uniform_mask(
+        mask, ids_keep, ids_restore = uniform_mask(
             x.size(0), h, w, self.mask_ratio, min_mask_size=self.min_mask_size, device=x.device
         )