birder 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (187)
  1. birder/adversarial/base.py +1 -1
  2. birder/adversarial/simba.py +4 -4
  3. birder/common/cli.py +1 -1
  4. birder/common/fs_ops.py +11 -11
  5. birder/common/lib.py +2 -2
  6. birder/common/masking.py +3 -3
  7. birder/common/training_cli.py +24 -2
  8. birder/common/training_utils.py +28 -4
  9. birder/data/collators/detection.py +9 -1
  10. birder/data/transforms/detection.py +27 -8
  11. birder/data/transforms/mosaic.py +1 -1
  12. birder/datahub/classification.py +3 -3
  13. birder/inference/classification.py +3 -3
  14. birder/inference/data_parallel.py +1 -1
  15. birder/inference/detection.py +5 -5
  16. birder/inference/wbf.py +1 -1
  17. birder/introspection/attention_rollout.py +5 -5
  18. birder/introspection/feature_pca.py +4 -4
  19. birder/introspection/gradcam.py +1 -1
  20. birder/introspection/guided_backprop.py +2 -2
  21. birder/introspection/transformer_attribution.py +3 -3
  22. birder/layers/attention_pool.py +2 -2
  23. birder/model_registry/model_registry.py +2 -1
  24. birder/net/__init__.py +2 -0
  25. birder/net/_rope_vit_configs.py +5 -0
  26. birder/net/_vit_configs.py +0 -13
  27. birder/net/alexnet.py +5 -5
  28. birder/net/base.py +28 -3
  29. birder/net/biformer.py +17 -17
  30. birder/net/cait.py +2 -2
  31. birder/net/cas_vit.py +1 -1
  32. birder/net/coat.py +15 -15
  33. birder/net/convnext_v1.py +2 -10
  34. birder/net/convnext_v1_iso.py +198 -0
  35. birder/net/convnext_v2.py +2 -10
  36. birder/net/crossformer.py +9 -9
  37. birder/net/crossvit.py +1 -1
  38. birder/net/cspnet.py +1 -1
  39. birder/net/cswin_transformer.py +10 -10
  40. birder/net/davit.py +10 -10
  41. birder/net/deit.py +56 -3
  42. birder/net/deit3.py +27 -15
  43. birder/net/detection/__init__.py +4 -0
  44. birder/net/detection/{yolo_anchors.py → _yolo_anchors.py} +5 -5
  45. birder/net/detection/base.py +6 -5
  46. birder/net/detection/deformable_detr.py +26 -28
  47. birder/net/detection/detr.py +9 -9
  48. birder/net/detection/efficientdet.py +9 -28
  49. birder/net/detection/faster_rcnn.py +22 -22
  50. birder/net/detection/fcos.py +8 -8
  51. birder/net/detection/plain_detr.py +852 -0
  52. birder/net/detection/retinanet.py +4 -4
  53. birder/net/detection/rt_detr_v1.py +81 -25
  54. birder/net/detection/rt_detr_v2.py +1147 -0
  55. birder/net/detection/ssd.py +5 -5
  56. birder/net/detection/yolo_v2.py +12 -12
  57. birder/net/detection/yolo_v3.py +19 -19
  58. birder/net/detection/yolo_v4.py +16 -16
  59. birder/net/detection/yolo_v4_tiny.py +3 -3
  60. birder/net/edgenext.py +3 -3
  61. birder/net/edgevit.py +10 -14
  62. birder/net/efficientformer_v1.py +1 -1
  63. birder/net/efficientvim.py +9 -9
  64. birder/net/efficientvit_mit.py +2 -2
  65. birder/net/efficientvit_msft.py +3 -3
  66. birder/net/fasternet.py +1 -1
  67. birder/net/fastvit.py +5 -12
  68. birder/net/flexivit.py +28 -15
  69. birder/net/focalnet.py +5 -9
  70. birder/net/gc_vit.py +11 -11
  71. birder/net/ghostnet_v1.py +1 -1
  72. birder/net/ghostnet_v2.py +1 -1
  73. birder/net/groupmixformer.py +12 -12
  74. birder/net/hgnet_v1.py +1 -1
  75. birder/net/hgnet_v2.py +4 -4
  76. birder/net/hiera.py +6 -6
  77. birder/net/hieradet.py +9 -9
  78. birder/net/hornet.py +3 -3
  79. birder/net/iformer.py +4 -4
  80. birder/net/inception_next.py +4 -14
  81. birder/net/levit.py +3 -3
  82. birder/net/lit_v1.py +13 -15
  83. birder/net/lit_v1_tiny.py +9 -9
  84. birder/net/lit_v2.py +14 -15
  85. birder/net/maxvit.py +10 -22
  86. birder/net/metaformer.py +2 -2
  87. birder/net/mim/crossmae.py +5 -5
  88. birder/net/mim/fcmae.py +3 -5
  89. birder/net/mim/mae_hiera.py +7 -7
  90. birder/net/mim/mae_vit.py +3 -5
  91. birder/net/mim/simmim.py +2 -3
  92. birder/net/mobilenet_v4_hybrid.py +4 -4
  93. birder/net/mobileone.py +5 -12
  94. birder/net/mobilevit_v1.py +2 -2
  95. birder/net/mobilevit_v2.py +5 -9
  96. birder/net/mvit_v2.py +24 -24
  97. birder/net/nextvit.py +2 -2
  98. birder/net/pit.py +11 -26
  99. birder/net/pvt_v1.py +4 -4
  100. birder/net/pvt_v2.py +5 -11
  101. birder/net/regionvit.py +15 -15
  102. birder/net/regnet.py +1 -1
  103. birder/net/repghost.py +4 -5
  104. birder/net/repvgg.py +3 -5
  105. birder/net/repvit.py +2 -2
  106. birder/net/resnest.py +1 -1
  107. birder/net/rope_deit3.py +29 -15
  108. birder/net/rope_flexivit.py +28 -15
  109. birder/net/rope_vit.py +41 -23
  110. birder/net/sequencer2d.py +3 -4
  111. birder/net/shufflenet_v1.py +1 -1
  112. birder/net/shufflenet_v2.py +1 -1
  113. birder/net/simple_vit.py +47 -5
  114. birder/net/smt.py +7 -7
  115. birder/net/ssl/barlow_twins.py +1 -1
  116. birder/net/ssl/byol.py +2 -2
  117. birder/net/ssl/capi.py +3 -3
  118. birder/net/ssl/data2vec2.py +1 -1
  119. birder/net/ssl/dino_v2.py +11 -1
  120. birder/net/ssl/franca.py +26 -2
  121. birder/net/ssl/i_jepa.py +4 -4
  122. birder/net/ssl/mmcr.py +1 -1
  123. birder/net/swiftformer.py +1 -1
  124. birder/net/swin_transformer_v1.py +4 -5
  125. birder/net/swin_transformer_v2.py +4 -7
  126. birder/net/tiny_vit.py +3 -3
  127. birder/net/transnext.py +19 -19
  128. birder/net/uniformer.py +4 -4
  129. birder/net/vgg.py +1 -10
  130. birder/net/vit.py +38 -25
  131. birder/net/vit_parallel.py +35 -20
  132. birder/net/vit_sam.py +10 -10
  133. birder/net/vovnet_v2.py +1 -1
  134. birder/net/xcit.py +9 -7
  135. birder/ops/msda.py +4 -4
  136. birder/ops/swattention.py +10 -10
  137. birder/results/classification.py +3 -3
  138. birder/results/gui.py +8 -8
  139. birder/scripts/benchmark.py +37 -12
  140. birder/scripts/evaluate.py +1 -1
  141. birder/scripts/predict.py +3 -3
  142. birder/scripts/predict_detection.py +2 -2
  143. birder/scripts/train.py +63 -15
  144. birder/scripts/train_barlow_twins.py +10 -7
  145. birder/scripts/train_byol.py +10 -7
  146. birder/scripts/train_capi.py +15 -10
  147. birder/scripts/train_data2vec.py +10 -7
  148. birder/scripts/train_data2vec2.py +10 -7
  149. birder/scripts/train_detection.py +29 -14
  150. birder/scripts/train_dino_v1.py +13 -9
  151. birder/scripts/train_dino_v2.py +27 -14
  152. birder/scripts/train_dino_v2_dist.py +28 -15
  153. birder/scripts/train_franca.py +16 -9
  154. birder/scripts/train_i_jepa.py +12 -9
  155. birder/scripts/train_ibot.py +15 -11
  156. birder/scripts/train_kd.py +64 -17
  157. birder/scripts/train_mim.py +11 -8
  158. birder/scripts/train_mmcr.py +11 -8
  159. birder/scripts/train_rotnet.py +11 -7
  160. birder/scripts/train_simclr.py +10 -7
  161. birder/scripts/train_vicreg.py +10 -7
  162. birder/tools/adversarial.py +4 -4
  163. birder/tools/auto_anchors.py +5 -5
  164. birder/tools/avg_model.py +1 -1
  165. birder/tools/convert_model.py +30 -22
  166. birder/tools/det_results.py +1 -1
  167. birder/tools/download_model.py +1 -1
  168. birder/tools/ensemble_model.py +1 -1
  169. birder/tools/introspection.py +11 -2
  170. birder/tools/labelme_to_coco.py +2 -2
  171. birder/tools/model_info.py +12 -14
  172. birder/tools/pack.py +8 -8
  173. birder/tools/quantize_model.py +53 -4
  174. birder/tools/results.py +2 -2
  175. birder/tools/show_det_iterator.py +19 -6
  176. birder/tools/show_iterator.py +2 -2
  177. birder/tools/similarity.py +5 -5
  178. birder/tools/stats.py +4 -6
  179. birder/tools/voc_to_coco.py +1 -1
  180. birder/version.py +1 -1
  181. {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/METADATA +3 -3
  182. birder-0.4.1.dist-info/RECORD +300 -0
  183. {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/WHEEL +1 -1
  184. birder-0.4.0.dist-info/RECORD +0 -297
  185. {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/entry_points.txt +0 -0
  186. {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/licenses/LICENSE +0 -0
  187. {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/top_level.txt +0 -0
birder/net/smt.py CHANGED
@@ -36,7 +36,7 @@ class DWConv(nn.Module):
         self.dwconv = nn.Conv2d(dim, dim, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=dim)
 
     def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
-        (B, _, C) = x.size()
+        B, _, C = x.size()
         x = x.transpose(1, 2).view(B, C, H, W)
         x = self.dwconv(x)
         x = x.flatten(2).transpose(1, 2)
@@ -94,7 +94,7 @@ class CAAttention(nn.Module):
         self.proj1 = nn.Conv2d(dim * expand_ratio, dim, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))
 
     def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
-        (B, N, C) = x.size()
+        B, N, C = x.size()
 
         v = self.v(x)
         s = self.s(x).reshape(B, H, W, self.ca_num_heads, C // self.ca_num_heads).permute(3, 0, 4, 1, 2)
@@ -140,11 +140,11 @@ class SAAttention(nn.Module):
         self.conv = nn.Conv2d(dim, dim, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=dim)
 
     def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
-        (B, N, C) = x.size()
+        B, N, C = x.size()
 
         q = self.q(x).reshape(B, N, self.sa_num_heads, C // self.sa_num_heads).permute(0, 2, 1, 3)
         kv = self.kv(x).reshape(B, -1, 2, self.sa_num_heads, C // self.sa_num_heads).permute(2, 0, 3, 1, 4)
-        (k, v) = kv.unbind(0)
+        k, v = kv.unbind(0)
         attn = (q @ k.transpose(-2, -1)) * self.scale
         attn = attn.softmax(dim=-1)
         attn = self.attn_drop(attn)
@@ -243,7 +243,7 @@ class OverlapPatchEmbed(nn.Module):
 
     def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, int, int]:
         x = self.proj(x)
-        (_, _, H, W) = x.size()
+        _, _, H, W = x.size()
         x = x.flatten(2).transpose(1, 2)
         x = self.norm(x)
 
@@ -267,7 +267,7 @@ class Stem(nn.Module):
 
    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, int, int]:
        x = self.conv(x)
-        (_, _, H, W) = x.size()
+        _, _, H, W = x.size()
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)
 
@@ -329,7 +329,7 @@ class SMTStage(nn.Module):
 
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B = x.size(0)
-        (x, H, W) = self.downsample_block(x)
+        x, H, W = self.downsample_block(x)
        x = self.blocks(x, H, W)
        x = self.norm(x)
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
birder/net/ssl/barlow_twins.py CHANGED
@@ -21,7 +21,7 @@ from birder.net.ssl.base import SSLBaseNet
 
 def off_diagonal(x: torch.Tensor) -> torch.Tensor:
     # Return a flattened view of the off-diagonal elements of a square matrix
-    (n, _) = x.size()
+    n, _ = x.size()
     # assert n == m
     return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()
 
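A note on the `off_diagonal` trick touched above: dropping the last element of the flattened n×n matrix and viewing the rest as (n-1)×(n+1) lands every diagonal element in column 0, so `[:, 1:]` keeps exactly the n² − n off-diagonal entries. A minimal self-contained check (the 4×4 input is purely illustrative):

```python
import torch

n = 4
x = torch.arange(n * n, dtype=torch.float32).reshape(n, n)

# The flatten/view trick from off_diagonal()
trick = x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()

# Reference: boolean-mask out the diagonal
reference = x[~torch.eye(n, dtype=torch.bool)]

assert torch.equal(trick, reference)  # both yield the off-diagonal entries in row-major order
```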
birder/net/ssl/byol.py CHANGED
@@ -80,11 +80,11 @@ class BYOL(SSLBaseNet):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         projection = self.online_encoder(x)
         online_predictions = self.online_predictor(projection)
-        (online_pred_one, online_pred_two) = online_predictions.chunk(2, dim=0)
+        online_pred_one, online_pred_two = online_predictions.chunk(2, dim=0)
 
         with torch.no_grad():
             target_projections = self.target_encoder(x)
-            (target_proj_one, target_proj_two) = target_projections.chunk(2, dim=0)
+            target_proj_one, target_proj_two = target_projections.chunk(2, dim=0)
 
         loss_one = loss_fn(online_pred_one, target_proj_two.detach())
         loss_two = loss_fn(online_pred_two, target_proj_one.detach())
birder/net/ssl/capi.py CHANGED
@@ -263,11 +263,11 @@ class CrossAttention(nn.Module):
         self.proj = nn.Linear(decoder_dim, decoder_dim)
 
     def forward(self, tgt: torch.Tensor, memory: torch.Tensor) -> torch.Tensor:
-        (B, N, C) = tgt.size()
+        B, N, C = tgt.size()
         n_kv = memory.size(1)
         q = self.q(tgt).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
         kv = self.kv(memory).reshape(B, n_kv, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
-        (k, v) = kv.unbind(0)
+        k, v = kv.unbind(0)
 
         attn = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)  # pylint: disable=not-callable
         x = attn.transpose(1, 2).reshape(B, N, C)
@@ -419,7 +419,7 @@ class CAPITeacher(SSLBaseNet):
         x = self.backbone.masked_encoding_omission(x, ids_keep)["tokens"]
 
         x = x[:, self.backbone.num_special_tokens :, :]
-        (assignments, clustering_loss) = self.head(x.transpose(0, 1))
+        assignments, clustering_loss = self.head(x.transpose(0, 1))
 
         assignments = assignments.detach().transpose(0, 1)
         row_indices = torch.arange(B).unsqueeze(1).expand_as(ids_predict)
birder/net/ssl/data2vec2.py CHANGED
@@ -68,7 +68,7 @@ class Decoder2d(nn.Module):
         self.proj = nn.Linear(embed_dim, in_channels)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        (B, _, C) = x.size()  # B, N, C
+        B, _, C = x.size()  # B, N, C
 
         x = x.transpose(1, 2).reshape(B, C, self.H, self.W)
 
birder/net/ssl/dino_v2.py CHANGED
@@ -148,7 +148,17 @@ class DINOLoss(nn.Module):
 
     def forward(
         self, student_output_list: list[torch.Tensor], teacher_out_softmax_centered_list: list[torch.Tensor]
-    ) -> float:
+    ) -> torch.Tensor:
+        s = torch.stack(student_output_list, 0)
+        t = torch.stack(teacher_out_softmax_centered_list, 0)
+        lsm = F.log_softmax(s / self.student_temp, dim=-1)
+        loss = -(torch.einsum("tbk,sbk->tsb", t, lsm).mean(-1).sum())
+
+        return loss
+
+    def forward_reference(
+        self, student_output_list: list[torch.Tensor], teacher_out_softmax_centered_list: list[torch.Tensor]
+    ) -> torch.Tensor:
         total_loss = 0.0
         for s in student_output_list:
             lsm = F.log_softmax(s / self.student_temp, dim=-1)
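The rewritten `DINOLoss.forward` collapses the reference double loop (kept as `forward_reference`, truncated in this diff) into one einsum over stacked crops: with teacher probabilities `t` of shape (T, B, K) and student log-softmax `lsm` of shape (S, B, K), `einsum("tbk,sbk->tsb", ...)` produces every teacher/student cross term at once. A quick equivalence sketch, assuming the standard DINO all-pairs loop for the reference; crop counts, batch size, and temperature are invented for illustration:

```python
import torch
import torch.nn.functional as F

student_temp = 0.1  # illustrative value
students = [torch.randn(4, 8) for _ in range(2)]  # 2 student crops, batch 4, 8 prototypes
teachers = [torch.softmax(torch.randn(4, 8), dim=-1) for _ in range(3)]  # 3 teacher crops

# Loop form, as in forward_reference (inner teacher loop assumed)
ref = torch.zeros(())
for s in students:
    lsm = F.log_softmax(s / student_temp, dim=-1)
    for t in teachers:
        ref -= torch.sum(t * lsm, dim=-1).mean()

# Vectorized form, as in the new forward
s = torch.stack(students, 0)  # (S, B, K)
t = torch.stack(teachers, 0)  # (T, B, K)
lsm = F.log_softmax(s / student_temp, dim=-1)
vec = -(torch.einsum("tbk,sbk->tsb", t, lsm).mean(-1).sum())

assert torch.allclose(ref, vec, atol=1e-5)
```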
birder/net/ssl/franca.py CHANGED
@@ -69,7 +69,7 @@ class DINOHeadMRL(nn.Module):
     ) -> None:
         super().__init__()
         self.nesting_list = nesting_list
-        self.matryoshka_projections = nn.ModuleList([nn.Linear(dim, dim, bias=True) for dim in self.nesting_list])
+        self.matryoshka_projections = nn.ModuleList([nn.Linear(dim, dim) for dim in self.nesting_list])
 
         self.mlps = nn.ModuleList(
             [
@@ -197,7 +197,31 @@ class DINOLossMRL(nn.Module):
         teacher_out_softmax_centered_list: list[torch.Tensor],
         n_crops: int | tuple[int, int],
         teacher_global: bool,
-    ) -> float:
+    ) -> torch.Tensor:
+        total_loss = 0.0
+        if teacher_global is False:
+            for student_outputs, teacher_outputs in zip(student_output_list, teacher_out_softmax_centered_list):
+                s = torch.stack(student_outputs.chunk(n_crops[0]), 0)  # type: ignore[index]
+                t = teacher_outputs.view(n_crops[1], -1, teacher_outputs.shape[-1])  # type: ignore[index]
+                lsm = F.log_softmax(s / self.student_temp, dim=-1)
+                total_loss -= torch.einsum("tbk,sbk->tsb", t, lsm).mean(-1).sum()
+
+        else:
+            for student_outputs, teacher_outputs in zip(student_output_list, teacher_out_softmax_centered_list):
+                teacher_outputs = teacher_outputs.view(n_crops, -1, teacher_outputs.shape[-1])
+                lsm = F.log_softmax(student_outputs / self.student_temp, dim=-1)
+                loss = torch.sum(teacher_outputs.flatten(0, 1) * lsm, dim=-1)
+                total_loss -= loss.mean()
+
+        return total_loss
+
+    def forward_reference(
+        self,
+        student_output_list: list[torch.Tensor],
+        teacher_out_softmax_centered_list: list[torch.Tensor],
+        n_crops: int | tuple[int, int],
+        teacher_global: bool,
+    ) -> torch.Tensor:
         total_loss = 0.0
         if teacher_global is False:
             for student_outputs, teacher_outputs in zip(student_output_list, teacher_out_softmax_centered_list):
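`DINOLossMRL.forward` gets the same einsum treatment per Matryoshka nesting level; here the crops arrive stacked along the batch axis, so the student side is first chunked back into per-crop tensors. A shape-only sketch of the non-global branch (all sizes invented for illustration):

```python
import torch
import torch.nn.functional as F

n_crops = (8, 2)  # (student local crops, teacher global crops)
B, K = 4, 16      # batch size, number of prototypes
student_outputs = torch.randn(n_crops[0] * B, K)  # crops stacked along the batch axis
teacher_outputs = torch.softmax(torch.randn(n_crops[1] * B, K), dim=-1)

s = torch.stack(student_outputs.chunk(n_crops[0]), 0)                # (8, B, K)
t = teacher_outputs.view(n_crops[1], -1, teacher_outputs.shape[-1])  # (2, B, K)
lsm = F.log_softmax(s / 0.1, dim=-1)
loss = -torch.einsum("tbk,sbk->tsb", t, lsm).mean(-1).sum()
print(loss.shape)  # scalar: all 2 x 8 teacher/student pairs reduced at once
```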
birder/net/ssl/i_jepa.py CHANGED
@@ -69,11 +69,11 @@ class MultiBlockMasking:
     ) -> tuple[int, int]:
         _rand = torch.rand(1).item()
 
-        (min_s, max_s) = scale
+        min_s, max_s = scale
         mask_scale = min_s + _rand * (max_s - min_s)
         max_keep = int(self.height * self.width * mask_scale)
 
-        (min_ar, max_ar) = aspect_ratio_scale
+        min_ar, max_ar = aspect_ratio_scale
         aspect_ratio = min_ar + _rand * (max_ar - min_ar)
 
         # Compute block height and width (given scale and aspect-ratio)
@@ -154,7 +154,7 @@ class MultiBlockMasking:
         masks_p = []
         masks_c = []
         for _ in range(self.n_pred):
-            (mask, mask_c) = self._sample_block_mask(p_size)
+            mask, mask_c = self._sample_block_mask(p_size)
             masks_p.append(mask)
             masks_c.append(mask_c)
             min_keep_pred = min(min_keep_pred, len(mask))
@@ -167,7 +167,7 @@ class MultiBlockMasking:
 
         masks_e = []
         for _ in range(self.n_enc):
-            (mask, _) = self._sample_block_mask(e_size, acceptable_regions=acceptable_regions)
+            mask, _ = self._sample_block_mask(e_size, acceptable_regions=acceptable_regions)
             masks_e.append(mask)
             min_keep_enc = min(min_keep_enc, len(mask))
 
birder/net/ssl/mmcr.py CHANGED
@@ -125,7 +125,7 @@ class MMCR(SSLBaseNet):
         self.momentum_encoder.load_state_dict(self.encoder.state_dict())
 
     def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
-        (C, H, W) = x.shape[-3:]  # B, num_views, C, H, W
+        C, H, W = x.shape[-3:]  # B, num_views, C, H, W
         x = x.reshape(-1, C, H, W)
         z = self.encoder(x)
 
birder/net/swiftformer.py CHANGED
@@ -111,7 +111,7 @@ class EfficientAdditiveAttention(nn.Module):
         self.final = nn.Linear(token_dim * num_heads, token_dim)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        (B, _, H, W) = x.size()
+        B, _, H, W = x.size()
         x = x.flatten(2).permute(0, 2, 1)
 
         query = F.normalize(self.to_query(x), dim=-1)
birder/net/swin_transformer_v1.py CHANGED
@@ -30,7 +30,7 @@ from birder.net.base import DetectorBackbone
 
 
 def patch_merging_pad(x: torch.Tensor) -> torch.Tensor:
-    (H, W, _) = x.shape[-3:]
+    H, W, _ = x.shape[-3:]
     x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
     x0 = x[..., 0::2, 0::2, :]  # ... H/2 W/2 C
     x1 = x[..., 1::2, 0::2, :]  # ... H/2 W/2 C
@@ -73,13 +73,13 @@ def shifted_window_attention(
     proj_bias: Optional[torch.Tensor] = None,
     logit_scale: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
-    (B, H, W, C) = x.size()
+    B, H, W, C = x.size()
 
     # Pad feature maps to multiples of window size
     pad_b = (window_size[0] - H % window_size[0]) % window_size[0]
     pad_r = (window_size[1] - W % window_size[1]) % window_size[1]
     x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))
-    (_, pad_h, pad_w, _) = x.size()
+    _, pad_h, pad_w, _ = x.size()
 
     # If window size is larger than feature size, there is no need to shift window
     shift_size_w = shift_size[0]
@@ -309,7 +309,6 @@ class Swin_Transformer_v1(DetectorBackbone):
                 kernel_size=(patch_size, patch_size),
                 stride=(patch_size, patch_size),
                 padding=(0, 0),
-                bias=True,
             ),
             Permute([0, 2, 3, 1]),
             nn.LayerNorm(embed_dim, eps=1e-5),
@@ -434,7 +433,7 @@ class Swin_Transformer_v1(DetectorBackbone):
        num_attn_heads = rel_pos_bias.size(1)
 
        def _calc(src: int, dst: int) -> list[float]:
-            (left, right) = 1.01, 1.5
+            left, right = 1.01, 1.5
            while right - left > 1e-6:
                q = (left + right) / 2.0
                gp = (1.0 - q ** (src // 2)) / (1.0 - q)  # Geometric progression
birder/net/swin_transformer_v2.py CHANGED
@@ -76,7 +76,9 @@ class ShiftedWindowAttention(nn.Module):
 
         # MLP to generate continuous relative position bias
         self.cpb_mlp = nn.Sequential(
-            nn.Linear(2, 512, bias=True), nn.ReLU(inplace=True), nn.Linear(512, num_heads, bias=False)
+            nn.Linear(2, 512),
+            nn.ReLU(inplace=True),
+            nn.Linear(512, num_heads, bias=False),
         )
         if qkv_bias is True:
             length = self.qkv.bias.numel() // 3
@@ -224,12 +226,7 @@ class Swin_Transformer_v2(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentio
 
         self.stem = nn.Sequential(
             nn.Conv2d(
-                self.input_channels,
-                embed_dim,
-                kernel_size=(patch_size, patch_size),
-                stride=patch_size,
-                padding=(0, 0),
-                bias=True,
+                self.input_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, padding=(0, 0)
             ),
             Permute([0, 2, 3, 1]),
             nn.LayerNorm(embed_dim, eps=1e-5),
birder/net/tiny_vit.py CHANGED
@@ -201,12 +201,12 @@ class Attention(nn.Module):
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         attn_bias = self.attention_biases[:, self.attention_bias_idxs]
-        (B, N, _) = x.shape
+        B, N, _ = x.shape
 
         # Normalization
         x = self.norm(x)
         qkv = self.qkv(x)
-        (q, k, v) = qkv.view(B, N, self.num_heads, -1).split([self.key_dim, self.key_dim, self.val_dim], dim=3)
+        q, k, v = qkv.view(B, N, self.num_heads, -1).split([self.key_dim, self.key_dim, self.val_dim], dim=3)
 
         q = q.permute(0, 2, 1, 3)
         k = k.permute(0, 2, 1, 3)
@@ -252,7 +252,7 @@ class TinyVitBlock(nn.Module):
         )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        (B, H, W, C) = x.shape
+        B, H, W, C = x.shape
         L = H * W
 
         shortcut = x
birder/net/transnext.py CHANGED
@@ -32,8 +32,8 @@ def get_relative_position_cpb(
     axis_kh = F.adaptive_avg_pool1d(axis_qh.unsqueeze(0), key_size[0]).squeeze(0)  # pylint: disable=not-callable
     axis_qw = torch.arange(query_size[1], dtype=torch.float32, device=device)
     axis_kw = F.adaptive_avg_pool1d(axis_qw.unsqueeze(0), key_size[1]).squeeze(0)  # pylint: disable=not-callable
-    (axis_kh, axis_kw) = torch.meshgrid(axis_kh, axis_kw, indexing="ij")
-    (axis_qh, axis_qw) = torch.meshgrid(axis_qh, axis_qw, indexing="ij")
+    axis_kh, axis_kw = torch.meshgrid(axis_kh, axis_kw, indexing="ij")
+    axis_qh, axis_qw = torch.meshgrid(axis_qh, axis_qw, indexing="ij")
 
     axis_kh = torch.reshape(axis_kh, [-1])
     axis_kw = torch.reshape(axis_kw, [-1])
@@ -44,7 +44,7 @@ def get_relative_position_cpb(
     relative_w = (axis_qw[:, None] - axis_kw[None, :]) / (pretrain_size[1] - 1) * 8
     relative_hw = torch.stack([relative_h, relative_w], dim=-1).view(-1, 2)
 
-    (relative_coords_table, idx_map) = torch.unique(relative_hw, return_inverse=True, dim=0)
+    relative_coords_table, idx_map = torch.unique(relative_hw, return_inverse=True, dim=0)
 
     relative_coords_table = (
         torch.sign(relative_coords_table)
@@ -86,9 +86,9 @@ class ConvolutionalGLU(nn.Module):
         self.drop = nn.Dropout(drop)
 
     def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
-        (x, v) = self.fc1(x).chunk(2, dim=-1)
+        x, v = self.fc1(x).chunk(2, dim=-1)
 
-        (B, _, C) = x.size()
+        B, _, C = x.size()
         x = x.transpose(1, 2).view(B, C, H, W).contiguous()
         x = self.dwconv(x)
         x = x.flatten(2).transpose(1, 2)
@@ -143,9 +143,9 @@ class Attention(nn.Module):
     def forward(
         self, x: torch.Tensor, _h: int, _w: int, relative_pos_index: torch.Tensor, relative_coords_table: torch.Tensor
     ) -> torch.Tensor:
-        (B, N, C) = x.size()
+        B, N, C = x.size()
         qkv = self.qkv(x).reshape(B, -1, 3 * self.num_heads, self.head_dim).permute(0, 2, 1, 3)
-        (q, k, v) = qkv.chunk(3, dim=1)
+        q, k, v = qkv.chunk(3, dim=1)
 
         # Use MLP to generate continuous relative positional bias
         rel_bias = (
@@ -217,9 +217,9 @@ class AggregatedAttention(nn.Module):
         self.act = nn.GELU()
 
         # MLP to generate continuous relative position bias
-        self.cpb_fc1 = nn.Linear(2, 512, bias=True)
+        self.cpb_fc1 = nn.Linear(2, 512)
         self.cpb_act = nn.ReLU(inplace=True)
-        self.cpb_fc2 = nn.Linear(512, num_heads, bias=True)
+        self.cpb_fc2 = nn.Linear(512, num_heads)
 
         # relative bias for local features
         self.relative_pos_bias_local = nn.Parameter(
@@ -227,7 +227,7 @@ class AggregatedAttention(nn.Module):
         )
 
         # Generate padding_mask and sequence length scale
-        (local_seq_length, padding_mask) = get_seqlen_and_mask(input_resolution, self.window_size)
+        local_seq_length, padding_mask = get_seqlen_and_mask(input_resolution, self.window_size)
         self.seq_length_scale = nn.Buffer(torch.log(local_seq_length + self.pool_len), persistent=False)
         self.padding_mask = nn.Buffer(padding_mask, persistent=False)
 
@@ -240,7 +240,7 @@ class AggregatedAttention(nn.Module):
     def forward(
         self, x: torch.Tensor, H: int, W: int, relative_pos_index: torch.Tensor, relative_coords_table: torch.Tensor
     ) -> torch.Tensor:
-        (B, N, C) = x.size()
+        B, N, C = x.size()
 
         # Generate queries, normalize them with L2, add query embedding,
         # and then magnify with sequence length scale and temperature.
@@ -252,7 +252,7 @@ class AggregatedAttention(nn.Module):
             * self.seq_length_scale
         )
 
-        (attn_local, v_local) = self.swa_qk_rpb(
+        attn_local, v_local = self.swa_qk_rpb(
             self.kv(x),
             q_norm_scaled.contiguous(),
             self.relative_pos_bias_local,
@@ -272,7 +272,7 @@ class AggregatedAttention(nn.Module):
 
         # Generate pooled keys and values
         kv_pool = self.kv(x_).reshape(B, self.pool_len, 2 * self.num_heads, self.head_dim).permute(0, 2, 1, 3)
-        (k_pool, v_pool) = kv_pool.chunk(2, dim=1)
+        k_pool, v_pool = kv_pool.chunk(2, dim=1)
 
         # Use MLP to generate continuous relative positional bias for pooled features.
         pool_bias = (
@@ -288,7 +288,7 @@ class AggregatedAttention(nn.Module):
         attn = self.attn_drop(attn)
 
         # Split the attention weights and separately aggregate the values of local & pooled features
-        (attn_local, attn_pool) = torch.split(attn, [self.local_len, self.pool_len], dim=-1)
+        attn_local, attn_pool = torch.split(attn, [self.local_len, self.pool_len], dim=-1)
 
         x_local = self.swa_av(
             q_norm, attn_local, v_local.contiguous(), self.learnable_tokens, self.learnable_bias, self.window_size, H, W
@@ -367,7 +367,7 @@ class OverlapPatchEmbed(nn.Module):
 
     def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, int, int]:
         x = self.proj(x)
-        (_, _, H, W) = x.size()
+        _, _, H, W = x.size()
         x = x.flatten(2).transpose(1, 2)
         x = self.norm(x)
 
@@ -396,7 +396,7 @@ class TransNeXtStage(nn.Module):
 
         # Generate relative positional coordinate table and index for each stage
         # to compute continuous relative positional bias
-        (relative_pos_index, relative_coords_table) = get_relative_position_cpb(
+        relative_pos_index, relative_coords_table = get_relative_position_cpb(
             query_size=input_resolution, key_size=(input_resolution[0] // sr_ratio, input_resolution[1] // sr_ratio)
         )
         self.relative_pos_index = nn.Buffer(relative_pos_index, persistent=False)
@@ -430,7 +430,7 @@ class TransNeXtStage(nn.Module):
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         B = x.size(0)
-        (x, H, W) = self.patch_embed(x)
+        x, H, W = self.patch_embed(x)
         for blk in self.blocks:
             x = blk(x, H, W, self.relative_pos_index, self.relative_coords_table)
 
@@ -553,7 +553,7 @@ class TransNeXt(DetectorBackbone):
             sr_ratio = self.sr_ratio[i]
             with torch.no_grad():
                 device = next(m.parameters()).device
-                (relative_pos_index, relative_coords_table) = get_relative_position_cpb(
+                relative_pos_index, relative_coords_table = get_relative_position_cpb(
                     query_size=input_resolution,
                     key_size=(input_resolution[0] // sr_ratio, input_resolution[1] // sr_ratio),
                     device=device,
@@ -574,7 +574,7 @@ class TransNeXt(DetectorBackbone):
                 blk.pool_len = pool_h * pool_w
                 blk.pool = nn.AdaptiveAvgPool2d((pool_h, pool_w))
 
-                (local_seq_length, padding_mask) = get_seqlen_and_mask(
+                local_seq_length, padding_mask = get_seqlen_and_mask(
                     input_resolution, blk.window_size, device=device
                 )
                 blk.seq_length_scale = nn.Buffer(
birder/net/uniformer.py CHANGED
@@ -71,9 +71,9 @@ class Attention(nn.Module):
         self.proj_drop = nn.Dropout(proj_drop)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        (B, N, C) = x.shape
+        B, N, C = x.shape
         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
-        (q, k, v) = qkv.unbind(0)
+        q, k, v = qkv.unbind(0)
 
         x = F.scaled_dot_product_attention(  # pylint: disable=not-callable
             q, k, v, dropout_p=self.attn_drop.p if self.training else 0.0, scale=self.scale
@@ -137,7 +137,7 @@ class AttentionBlock(nn.Module):
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = x + self.pos_embed(x)
-        (B, N, H, W) = x.shape
+        B, N, H, W = x.shape
         x = x.flatten(2).transpose(1, 2)
         x = x + self.drop_path(self.layer_scale_1(self.attn(self.norm1(x))))
         x = x + self.drop_path(self.layer_scale_2(self.mlp(self.norm2(x))))
@@ -155,7 +155,7 @@ class PatchEmbed(nn.Module):
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.proj(x)
-        (B, _, H, W) = x.size()  # B, C, H, W
+        B, _, H, W = x.size()  # B, C, H, W
         x = x.flatten(2).transpose(1, 2)
         x = self.norm(x)
         x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
birder/net/vgg.py CHANGED
@@ -40,16 +40,7 @@ class Vgg(DetectorBackbone):
             else:
                 in_channels = filters[i]
 
-            layers.append(
-                nn.Conv2d(
-                    in_channels,
-                    filters[i],
-                    kernel_size=(3, 3),
-                    stride=(1, 1),
-                    padding=(1, 1),
-                    bias=True,
-                )
-            )
+            layers.append(nn.Conv2d(in_channels, filters[i], kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)))
             layers.append(nn.ReLU(inplace=True))
 
         layers.append(nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=(0, 0)))