birder 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (187)
  1. birder/adversarial/base.py +1 -1
  2. birder/adversarial/simba.py +4 -4
  3. birder/common/cli.py +1 -1
  4. birder/common/fs_ops.py +11 -11
  5. birder/common/lib.py +2 -2
  6. birder/common/masking.py +3 -3
  7. birder/common/training_cli.py +24 -2
  8. birder/common/training_utils.py +28 -4
  9. birder/data/collators/detection.py +9 -1
  10. birder/data/transforms/detection.py +27 -8
  11. birder/data/transforms/mosaic.py +1 -1
  12. birder/datahub/classification.py +3 -3
  13. birder/inference/classification.py +3 -3
  14. birder/inference/data_parallel.py +1 -1
  15. birder/inference/detection.py +5 -5
  16. birder/inference/wbf.py +1 -1
  17. birder/introspection/attention_rollout.py +5 -5
  18. birder/introspection/feature_pca.py +4 -4
  19. birder/introspection/gradcam.py +1 -1
  20. birder/introspection/guided_backprop.py +2 -2
  21. birder/introspection/transformer_attribution.py +3 -3
  22. birder/layers/attention_pool.py +2 -2
  23. birder/model_registry/model_registry.py +2 -1
  24. birder/net/__init__.py +2 -0
  25. birder/net/_rope_vit_configs.py +5 -0
  26. birder/net/_vit_configs.py +0 -13
  27. birder/net/alexnet.py +5 -5
  28. birder/net/base.py +28 -3
  29. birder/net/biformer.py +17 -17
  30. birder/net/cait.py +2 -2
  31. birder/net/cas_vit.py +1 -1
  32. birder/net/coat.py +15 -15
  33. birder/net/convnext_v1.py +2 -10
  34. birder/net/convnext_v1_iso.py +198 -0
  35. birder/net/convnext_v2.py +2 -10
  36. birder/net/crossformer.py +9 -9
  37. birder/net/crossvit.py +1 -1
  38. birder/net/cspnet.py +1 -1
  39. birder/net/cswin_transformer.py +10 -10
  40. birder/net/davit.py +10 -10
  41. birder/net/deit.py +56 -3
  42. birder/net/deit3.py +27 -15
  43. birder/net/detection/__init__.py +4 -0
  44. birder/net/detection/{yolo_anchors.py → _yolo_anchors.py} +5 -5
  45. birder/net/detection/base.py +6 -5
  46. birder/net/detection/deformable_detr.py +26 -28
  47. birder/net/detection/detr.py +9 -9
  48. birder/net/detection/efficientdet.py +9 -28
  49. birder/net/detection/faster_rcnn.py +22 -22
  50. birder/net/detection/fcos.py +8 -8
  51. birder/net/detection/plain_detr.py +852 -0
  52. birder/net/detection/retinanet.py +4 -4
  53. birder/net/detection/rt_detr_v1.py +81 -25
  54. birder/net/detection/rt_detr_v2.py +1147 -0
  55. birder/net/detection/ssd.py +5 -5
  56. birder/net/detection/yolo_v2.py +12 -12
  57. birder/net/detection/yolo_v3.py +19 -19
  58. birder/net/detection/yolo_v4.py +16 -16
  59. birder/net/detection/yolo_v4_tiny.py +3 -3
  60. birder/net/edgenext.py +3 -3
  61. birder/net/edgevit.py +10 -14
  62. birder/net/efficientformer_v1.py +1 -1
  63. birder/net/efficientvim.py +9 -9
  64. birder/net/efficientvit_mit.py +2 -2
  65. birder/net/efficientvit_msft.py +3 -3
  66. birder/net/fasternet.py +1 -1
  67. birder/net/fastvit.py +5 -12
  68. birder/net/flexivit.py +28 -15
  69. birder/net/focalnet.py +5 -9
  70. birder/net/gc_vit.py +11 -11
  71. birder/net/ghostnet_v1.py +1 -1
  72. birder/net/ghostnet_v2.py +1 -1
  73. birder/net/groupmixformer.py +12 -12
  74. birder/net/hgnet_v1.py +1 -1
  75. birder/net/hgnet_v2.py +4 -4
  76. birder/net/hiera.py +6 -6
  77. birder/net/hieradet.py +9 -9
  78. birder/net/hornet.py +3 -3
  79. birder/net/iformer.py +4 -4
  80. birder/net/inception_next.py +4 -14
  81. birder/net/levit.py +3 -3
  82. birder/net/lit_v1.py +13 -15
  83. birder/net/lit_v1_tiny.py +9 -9
  84. birder/net/lit_v2.py +14 -15
  85. birder/net/maxvit.py +10 -22
  86. birder/net/metaformer.py +2 -2
  87. birder/net/mim/crossmae.py +5 -5
  88. birder/net/mim/fcmae.py +3 -5
  89. birder/net/mim/mae_hiera.py +7 -7
  90. birder/net/mim/mae_vit.py +3 -5
  91. birder/net/mim/simmim.py +2 -3
  92. birder/net/mobilenet_v4_hybrid.py +4 -4
  93. birder/net/mobileone.py +5 -12
  94. birder/net/mobilevit_v1.py +2 -2
  95. birder/net/mobilevit_v2.py +5 -9
  96. birder/net/mvit_v2.py +24 -24
  97. birder/net/nextvit.py +2 -2
  98. birder/net/pit.py +11 -26
  99. birder/net/pvt_v1.py +4 -4
  100. birder/net/pvt_v2.py +5 -11
  101. birder/net/regionvit.py +15 -15
  102. birder/net/regnet.py +1 -1
  103. birder/net/repghost.py +4 -5
  104. birder/net/repvgg.py +3 -5
  105. birder/net/repvit.py +2 -2
  106. birder/net/resnest.py +1 -1
  107. birder/net/rope_deit3.py +29 -15
  108. birder/net/rope_flexivit.py +28 -15
  109. birder/net/rope_vit.py +41 -23
  110. birder/net/sequencer2d.py +3 -4
  111. birder/net/shufflenet_v1.py +1 -1
  112. birder/net/shufflenet_v2.py +1 -1
  113. birder/net/simple_vit.py +47 -5
  114. birder/net/smt.py +7 -7
  115. birder/net/ssl/barlow_twins.py +1 -1
  116. birder/net/ssl/byol.py +2 -2
  117. birder/net/ssl/capi.py +3 -3
  118. birder/net/ssl/data2vec2.py +1 -1
  119. birder/net/ssl/dino_v2.py +11 -1
  120. birder/net/ssl/franca.py +26 -2
  121. birder/net/ssl/i_jepa.py +4 -4
  122. birder/net/ssl/mmcr.py +1 -1
  123. birder/net/swiftformer.py +1 -1
  124. birder/net/swin_transformer_v1.py +4 -5
  125. birder/net/swin_transformer_v2.py +4 -7
  126. birder/net/tiny_vit.py +3 -3
  127. birder/net/transnext.py +19 -19
  128. birder/net/uniformer.py +4 -4
  129. birder/net/vgg.py +1 -10
  130. birder/net/vit.py +38 -25
  131. birder/net/vit_parallel.py +35 -20
  132. birder/net/vit_sam.py +10 -10
  133. birder/net/vovnet_v2.py +1 -1
  134. birder/net/xcit.py +9 -7
  135. birder/ops/msda.py +4 -4
  136. birder/ops/swattention.py +10 -10
  137. birder/results/classification.py +3 -3
  138. birder/results/gui.py +8 -8
  139. birder/scripts/benchmark.py +37 -12
  140. birder/scripts/evaluate.py +1 -1
  141. birder/scripts/predict.py +3 -3
  142. birder/scripts/predict_detection.py +2 -2
  143. birder/scripts/train.py +63 -15
  144. birder/scripts/train_barlow_twins.py +10 -7
  145. birder/scripts/train_byol.py +10 -7
  146. birder/scripts/train_capi.py +15 -10
  147. birder/scripts/train_data2vec.py +10 -7
  148. birder/scripts/train_data2vec2.py +10 -7
  149. birder/scripts/train_detection.py +29 -14
  150. birder/scripts/train_dino_v1.py +13 -9
  151. birder/scripts/train_dino_v2.py +27 -14
  152. birder/scripts/train_dino_v2_dist.py +28 -15
  153. birder/scripts/train_franca.py +16 -9
  154. birder/scripts/train_i_jepa.py +12 -9
  155. birder/scripts/train_ibot.py +15 -11
  156. birder/scripts/train_kd.py +64 -17
  157. birder/scripts/train_mim.py +11 -8
  158. birder/scripts/train_mmcr.py +11 -8
  159. birder/scripts/train_rotnet.py +11 -7
  160. birder/scripts/train_simclr.py +10 -7
  161. birder/scripts/train_vicreg.py +10 -7
  162. birder/tools/adversarial.py +4 -4
  163. birder/tools/auto_anchors.py +5 -5
  164. birder/tools/avg_model.py +1 -1
  165. birder/tools/convert_model.py +30 -22
  166. birder/tools/det_results.py +1 -1
  167. birder/tools/download_model.py +1 -1
  168. birder/tools/ensemble_model.py +1 -1
  169. birder/tools/introspection.py +11 -2
  170. birder/tools/labelme_to_coco.py +2 -2
  171. birder/tools/model_info.py +12 -14
  172. birder/tools/pack.py +8 -8
  173. birder/tools/quantize_model.py +53 -4
  174. birder/tools/results.py +2 -2
  175. birder/tools/show_det_iterator.py +19 -6
  176. birder/tools/show_iterator.py +2 -2
  177. birder/tools/similarity.py +5 -5
  178. birder/tools/stats.py +4 -6
  179. birder/tools/voc_to_coco.py +1 -1
  180. birder/version.py +1 -1
  181. {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/METADATA +3 -3
  182. birder-0.4.1.dist-info/RECORD +300 -0
  183. {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/WHEEL +1 -1
  184. birder-0.4.0.dist-info/RECORD +0 -297
  185. {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/entry_points.txt +0 -0
  186. {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/licenses/LICENSE +0 -0
  187. {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/top_level.txt +0 -0
birder/net/mim/simmim.py CHANGED
@@ -80,7 +80,6 @@ class SimMIM(MIMBaseNet):
             kernel_size=(1, 1),
             stride=(1, 1),
             padding=(0, 0),
-            bias=True,
         )
 
         self.mask_token = nn.Parameter(torch.zeros(1, 1, 1, self.encoder.stem_width))
@@ -112,7 +111,7 @@ class SimMIM(MIMBaseNet):
         """
 
         if x.ndim == 4:
-            (n, c, _, _) = x.shape
+            n, c, _, _ = x.shape
             x = x.reshape(n, c, -1)
             x = torch.einsum("ncl->nlc", x)
 
@@ -135,7 +134,7 @@ class SimMIM(MIMBaseNet):
         mask: 0 is keep, 1 is remove
         """
 
-        (N, C, _, _) = pred.shape
+        N, C, _, _ = pred.shape
         pred = pred.reshape(N, C, -1)
         pred = torch.einsum("ncl->nlc", pred)
 
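Note: the two mechanical changes that recur throughout this release are dropping the redundant `bias=True` argument (`True` is already the default for `nn.Conv2d`) and removing the optional parentheses around tuple-unpacking targets. Both are behavior-preserving; a minimal illustration (shapes chosen arbitrarily):

```python
import torch
from torch import nn

# bias defaults to True, so these two modules are equivalent
conv_a = nn.Conv2d(8, 8, kernel_size=(1, 1), bias=True)
conv_b = nn.Conv2d(8, 8, kernel_size=(1, 1))
assert conv_b.bias is not None

# parentheses around unpacking targets are optional in Python
pred = torch.randn(2, 8, 4, 4)
(N, C, _, _) = pred.shape  # old style
n, c, _, _ = pred.shape    # new style, same unpacking
assert (N, C) == (n, c)
```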
birder/net/mobilenet_v4_hybrid.py CHANGED
@@ -142,24 +142,24 @@ class MultiQueryAttention(nn.Module):
         self.output = nn.Sequential(*output_layers)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        (B, C, H, W) = x.size()
+        B, C, H, W = x.size()
         q = self.query(x)
         q = q.reshape(B, self.num_heads, self.key_dim, -1)
         q = q.transpose(-1, -2).contiguous()
 
         k = self.key(x)
-        (B, C, _, _) = k.size()
+        B, C, _, _ = k.size()
         k = k.reshape(B, C, -1).transpose(1, 2)
         k = k.unsqueeze(1).contiguous()
 
         v = self.value(x)
-        (B, C, _, _) = v.size()
+        B, C, _, _ = v.size()
         v = v.reshape(B, C, -1).transpose(1, 2)
         v = v.unsqueeze(1).contiguous()
 
         # Calculate attention score
         attn_score = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)  # pylint: disable=not-callable
-        (B, _, _, C) = attn_score.size()
+        B, _, _, C = attn_score.size()
         feat_dim = C * self.num_heads
         attn_score = attn_score.transpose(1, 2)
         attn_score = (
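The hunk above touches `MultiQueryAttention`, where a single key/value head is shared across all query heads: the `unsqueeze(1)` gives k and v a head dimension of 1 that broadcasts inside `scaled_dot_product_attention`. A standalone sketch of that broadcasting, with made-up shapes:

```python
import torch
import torch.nn.functional as F

B, num_heads, N, key_dim = 2, 8, 49, 32
q = torch.randn(B, num_heads, N, key_dim)
k = torch.randn(B, 1, N, key_dim)  # one shared key head
v = torch.randn(B, 1, N, key_dim)  # one shared value head

# The size-1 head dimension of k and v broadcasts across all query heads
out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)
print(out.shape)  # torch.Size([2, 8, 49, 32])
```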
birder/net/mobileone.py CHANGED
@@ -61,13 +61,7 @@ class MobileOneBlock(nn.Module):
 
         if reparameterized is True:
             self.reparam_conv = nn.Conv2d(
-                in_channels,
-                out_channels,
-                kernel_size=kernel_size,
-                stride=stride,
-                padding=padding,
-                groups=groups,
-                bias=True,
+                in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, groups=groups
             )
         else:
             self.reparam_conv = None
@@ -144,7 +138,7 @@ class MobileOneBlock(nn.Module):
         if self.reparameterized is True:
             return
 
-        (kernel, bias) = self._get_kernel_bias()
+        kernel, bias = self._get_kernel_bias()
         self.reparam_conv = nn.Conv2d(
             in_channels=self.in_channels,
             out_channels=self.out_channels,
@@ -152,7 +146,6 @@ class MobileOneBlock(nn.Module):
             stride=self.stride,
             padding=self.padding,
             groups=self.groups,
-            bias=True,
         )
         self.reparam_conv.weight.data = kernel
         self.reparam_conv.bias.data = bias
@@ -178,7 +171,7 @@ class MobileOneBlock(nn.Module):
         kernel_scale = 0
         bias_scale = 0
         if self.rbr_scale is not None:
-            (kernel_scale, bias_scale) = self._fuse_bn_tensor(self.rbr_scale)
+            kernel_scale, bias_scale = self._fuse_bn_tensor(self.rbr_scale)
             pad = self.kernel_size // 2
             kernel_scale = F.pad(kernel_scale, [pad, pad, pad, pad])
 
@@ -186,13 +179,13 @@ class MobileOneBlock(nn.Module):
         kernel_identity = 0
         bias_identity = 0
         if self.rbr_skip is not None:
-            (kernel_identity, bias_identity) = self._fuse_bn_tensor(self.rbr_skip)
+            kernel_identity, bias_identity = self._fuse_bn_tensor(self.rbr_skip)
 
         # Get weights and bias of conv branches
         kernel_conv = 0
         bias_conv = 0
         for ix in range(self.num_conv_branches):
-            (_kernel, _bias) = self._fuse_bn_tensor(self.rbr_conv[ix])
+            _kernel, _bias = self._fuse_bn_tensor(self.rbr_conv[ix])
             kernel_conv += _kernel
             bias_conv += _bias
 
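These hunks are part of MobileOne's structural reparameterization: at inference time the scale, identity, and conv branches are fused (via `_fuse_bn_tensor`) into a single convolution. The diff does not show `_fuse_bn_tensor` itself; the standard conv+BatchNorm folding that this kind of reparameterization relies on looks roughly like this sketch (the function name and signature here are illustrative, not the library's API):

```python
import torch
from torch import nn

def fuse_conv_bn(weight: torch.Tensor, bn: nn.BatchNorm2d) -> tuple[torch.Tensor, torch.Tensor]:
    # BN(conv(x)) = gamma * (conv(x) - mean) / sqrt(var + eps) + beta,
    # which folds into one conv with a rescaled kernel and a bias term
    std = torch.sqrt(bn.running_var + bn.eps)
    scale = bn.weight / std                          # per-output-channel gamma / std
    fused_weight = weight * scale.reshape(-1, 1, 1, 1)
    fused_bias = bn.bias - bn.running_mean * scale
    return fused_weight, fused_bias
```

Because the fused kernel and bias are assigned explicitly, the replacement `nn.Conv2d` only needs its default `bias=True`, which is why the explicit argument could be dropped.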
birder/net/mobilevit_v1.py CHANGED
@@ -101,8 +101,8 @@ class MobileVitBlock(nn.Module):
         x = self.conv_1x1(x)
 
         # Unfold (feature map -> patches)
-        (patch_h, patch_w) = self.patch_size
-        (B, C, H, W) = x.shape
+        patch_h, patch_w = self.patch_size
+        B, C, H, W = x.shape
         new_h, new_w = math.ceil(H / patch_h) * patch_h, math.ceil(W / patch_w) * patch_w
         num_patch_h = new_h // patch_h  # n_h, n_w
         num_patch_w = new_w // patch_w
birder/net/mobilevit_v2.py CHANGED
@@ -63,7 +63,7 @@ class LinearSelfAttention(nn.Module):
         # Project x into query, key and value
         # Query --> [B, 1, P, N]
         # value, key --> [B, d, P, N]
-        (query, key, value) = qkv.split([1, self.embed_dim, self.embed_dim], dim=1)
+        query, key, value = qkv.split([1, self.embed_dim, self.embed_dim], dim=1)
 
         # apply softmax along N dimension
         context_scores = F.softmax(query, dim=-1)
@@ -98,14 +98,10 @@ class LinearTransformerBlock(nn.Module):
 
         self.norm2 = nn.GroupNorm(num_groups=1, num_channels=embed_dim)
         self.mlp = nn.Sequential(
-            nn.Conv2d(
-                embed_dim, int(embed_dim * mlp_ratio), kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True
-            ),
+            nn.Conv2d(embed_dim, int(embed_dim * mlp_ratio), kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)),
             nn.SiLU(),
             nn.Dropout(drop),
-            nn.Conv2d(
-                int(embed_dim * mlp_ratio), embed_dim, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True
-            ),
+            nn.Conv2d(int(embed_dim * mlp_ratio), embed_dim, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)),
         )
         self.drop_path2 = StochasticDepth(drop_path, mode="row")
 
@@ -166,8 +162,8 @@ class MobileVitBlock(nn.Module):
         self.patch_area = self.patch_size[0] * self.patch_size[1]
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        (B, C, H, W) = x.shape
-        (patch_h, patch_w) = self.patch_size
+        B, C, H, W = x.shape
+        patch_h, patch_w = self.patch_size
         new_h = math.ceil(H / patch_h) * patch_h
         new_w = math.ceil(W / patch_w) * patch_w
         num_patch_h, num_patch_w = new_h // patch_h, new_w // patch_w  # n_h, n_w
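In the `LinearSelfAttention` hunk, `qkv.split([1, self.embed_dim, self.embed_dim], dim=1)` carves the fused projection into a 1-channel query and two `embed_dim`-channel key/value chunks, returning one tensor per listed size. A toy demonstration (`embed_dim=64` and the other shapes are arbitrary):

```python
import torch

embed_dim = 64
qkv = torch.randn(2, 1 + 2 * embed_dim, 4, 49)  # [B, 1 + 2d, P, N]
query, key, value = qkv.split([1, embed_dim, embed_dim], dim=1)
print(query.shape, key.shape, value.shape)
# torch.Size([2, 1, 4, 49]) torch.Size([2, 64, 4, 49]) torch.Size([2, 64, 4, 49])
```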
birder/net/mvit_v2.py CHANGED
@@ -36,7 +36,7 @@ from birder.net.base import TokenRetentionResultType
 def pre_pool(
     x: torch.Tensor, hw_shape: tuple[int, int], has_cls_token: bool
 ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
-    (H, W) = hw_shape
+    H, W = hw_shape
     if has_cls_token is True:
         cls_tok = x[:, :, :1, :]
         x = x[:, :, 1:, :]
@@ -68,8 +68,8 @@ def cal_rel_pos_spatial(
     rel_pos_w: torch.Tensor,
 ) -> torch.Tensor:
     sp_idx = 1 if has_cls_token is True else 0
-    (q_h, q_w) = q_shape
-    (k_h, k_w) = k_shape
+    q_h, q_w = q_shape
+    k_h, k_w = k_shape
 
     # Scale up rel pos if shapes for q and k are different.
     q_h_ratio = max(k_h / q_h, 1.0)
@@ -90,7 +90,7 @@ def cal_rel_pos_spatial(
     rel_h = rel_pos_h[dist_h.long()]
     rel_w = rel_pos_w[dist_w.long()]
 
-    (B, n_head, _, dim) = q.shape
+    B, n_head, _, dim = q.shape
 
     r_q = q[:, :, sp_idx:].reshape(B, n_head, q_h, q_w, dim)
     rel_h = torch.einsum("byhwc,hkc->byhwk", r_q, rel_h)
@@ -108,7 +108,7 @@ class SequentialWithShape(nn.Sequential):
         self, x: torch.Tensor, hw_shape: tuple[int, int]
     ) -> tuple[torch.Tensor, tuple[int, int]]:
         for module in self:
-            (x, hw_shape) = module(x, hw_shape)
+            x, hw_shape = module(x, hw_shape)
 
         return (x, hw_shape)
 
@@ -129,7 +129,7 @@ class PatchEmbed(nn.Module):
 
     def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, tuple[int, int]]:
         x = self.proj(x)
-        (H, W) = x.shape[2:4]
+        H, W = x.shape[2:4]
 
         x = x.flatten(2).transpose(1, 2)
 
@@ -227,31 +227,31 @@ class MultiScaleAttention(nn.Module):
         nn.init.trunc_normal_(self.rel_pos_w, std=0.02)
 
     def forward(self, x: torch.Tensor, hw_shape: tuple[int, int]) -> tuple[torch.Tensor, tuple[int, int]]:
-        (B, N, _) = x.size()
+        B, N, _ = x.size()
 
         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
-        (q, k, v) = qkv.unbind(dim=0)
+        q, k, v = qkv.unbind(dim=0)
 
         if self.pool_q is not None:
-            (q, q_tok) = pre_pool(q, hw_shape, self.has_cls_token)
+            q, q_tok = pre_pool(q, hw_shape, self.has_cls_token)
             q = self.pool_q(q)
-            (q, q_shape) = post_pool(q, self.num_heads, q_tok)
+            q, q_shape = post_pool(q, self.num_heads, q_tok)
             q = self.norm_q(q)
         else:
             q_shape = hw_shape
 
         if self.pool_k is not None:
-            (k, k_tok) = pre_pool(k, hw_shape, self.has_cls_token)
+            k, k_tok = pre_pool(k, hw_shape, self.has_cls_token)
             k = self.pool_k(k)
-            (k, k_shape) = post_pool(k, self.num_heads, k_tok)
+            k, k_shape = post_pool(k, self.num_heads, k_tok)
             k = self.norm_k(k)
         else:
             k_shape = hw_shape
 
         if self.pool_v is not None:
-            (v, v_tok) = pre_pool(v, hw_shape, self.has_cls_token)
+            v, v_tok = pre_pool(v, hw_shape, self.has_cls_token)
             v = self.pool_v(v)
-            (v, _) = post_pool(v, self.num_heads, v_tok)
+            v, _ = post_pool(v, self.num_heads, v_tok)
             v = self.norm_v(v)
 
         attn = (q * self.scale) @ k.transpose(-2, -1)
@@ -337,8 +337,8 @@ class MultiScaleBlock(nn.Module):
         else:
             cls_tok = None
 
-        (B, _, C) = x.size()
-        (H, W) = hw_shape
+        B, _, C = x.size()
+        H, W = hw_shape
         x = x.reshape(B, H, W, C).permute(0, 3, 1, 2).contiguous()
         x = self.pool_skip(x)
         x = x.reshape(B, C, -1).transpose(1, 2)
@@ -349,7 +349,7 @@ class MultiScaleBlock(nn.Module):
 
     def forward(self, x: torch.Tensor, hw_shape: tuple[int, int]) -> tuple[torch.Tensor, tuple[int, int]]:
         x_norm = self.norm1(x)
-        (x_block, hw_shape_new) = self.attn(x_norm, hw_shape)
+        x_block, hw_shape_new = self.attn(x_norm, hw_shape)
 
         if self.proj_attn is not None:
             x = self.proj_attn(x_norm)
@@ -421,7 +421,7 @@ class MultiScaleVitStage(nn.Module):
 
     def forward(self, x: torch.Tensor, hw_shape: tuple[int, int]) -> tuple[torch.Tensor, tuple[int, int]]:
         for blk in self.blocks:
-            (x, hw_shape) = blk(x, hw_shape)
+            x, hw_shape = blk(x, hw_shape)
 
         return (x, hw_shape)
 
@@ -523,14 +523,14 @@ class MViT_v2(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
             nn.init.zeros_(m.bias)
 
     def detection_features(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
-        (x, hw_shape) = self.patch_embed(x)
+        x, hw_shape = self.patch_embed(x)
         if self.cls_token is not None:
             cls_tokens = self.cls_token.expand(x.size(0), -1, -1)
             x = torch.concat((cls_tokens, x), dim=1)
 
         out = {}
         for name, module in self.body.named_children():
-            (x, hw_shape) = module(x, hw_shape)
+            x, hw_shape = module(x, hw_shape)
             if name in self.return_stages:
                 x_inter = x
                 if self.cls_token is not None:
@@ -561,7 +561,7 @@ class MViT_v2(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
     ) -> TokenRetentionResultType:
         B = x.size(0)
 
-        (x, hw_shape) = self.patch_embed(x)
+        x, hw_shape = self.patch_embed(x)
         x = mask_tensor(
             x.permute(0, 2, 1).reshape(B, -1, hw_shape[0], hw_shape[1]),
             mask,
@@ -574,7 +574,7 @@ class MViT_v2(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
             cls_tokens = self.cls_token.expand(B, -1, -1)
             x = torch.concat((cls_tokens, x), dim=1)
 
-        (x, _) = self.body(x, hw_shape)
+        x, _ = self.body(x, hw_shape)
         x = self.norm(x)
 
         result: TokenRetentionResultType = {}
@@ -596,12 +596,12 @@ class MViT_v2(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
         return result
 
     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
-        (x, hw_shape) = self.patch_embed(x)
+        x, hw_shape = self.patch_embed(x)
         if self.cls_token is not None:
            cls_tokens = self.cls_token.expand(x.size(0), -1, -1)
            x = torch.concat((cls_tokens, x), dim=1)
 
-        (x, _) = self.body(x, hw_shape)
+        x, _ = self.body(x, hw_shape)
         x = self.norm(x)
 
         return x
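The `pre_pool`/`post_pool` pair around each pooling call detaches the cls token so that only the spatial tokens are folded back into a 2D grid for the pooling convolution. Only fragments of `pre_pool` are visible in the hunks above; pieced together from them, it plausibly looks like this (the final grid reshape is an assumption beyond the visible context lines):

```python
from typing import Optional

import torch

def pre_pool(
    x: torch.Tensor, hw_shape: tuple[int, int], has_cls_token: bool
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    # x: [B, heads, N, dim]; set the cls token aside so only the
    # spatial tokens are reshaped into an [*, dim, H, W] grid for pooling
    H, W = hw_shape
    if has_cls_token is True:
        cls_tok = x[:, :, :1, :]
        x = x[:, :, 1:, :]
    else:
        cls_tok = None

    B, heads, _, dim = x.shape
    x = x.reshape(B * heads, H, W, dim).permute(0, 3, 1, 2).contiguous()

    return (x, cls_tok)
```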
birder/net/nextvit.py CHANGED
@@ -165,7 +165,7 @@ class E_MHSA(nn.Module):
             self.norm = nn.Identity()
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        (B, N, C) = x.size()
+        B, N, C = x.size()
         q = self.q(x)
         q = q.reshape(B, N, self.num_heads, int(C // self.num_heads)).permute(0, 2, 1, 3)
 
@@ -226,7 +226,7 @@ class NTB(nn.Module):
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.patch_embed(x)
-        (B, C, H, W) = x.size()
+        B, C, H, W = x.size()
         out = self.norm1(x)
 
         out = out.reshape(B, C, H * W).permute(0, 2, 1)
birder/net/pit.py CHANGED
@@ -29,12 +29,12 @@ class SequentialTuple(nn.Sequential):
         self, x: tuple[torch.Tensor, torch.Tensor]
     ) -> tuple[torch.Tensor, torch.Tensor]:
         for module in self:
-            x = module(x)
+            x = module(*x)
 
         return x
 
 
-class Transformer(nn.Module):
+class PiTStage(nn.Module):
     def __init__(
         self,
         base_dim: int,
@@ -59,13 +59,12 @@ class Transformer(nn.Module):
             dpr=drop_path_prob,
         )
 
-    def forward(self, xt: tuple[torch.Tensor, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
-        (x, cls_tokens) = xt
+    def forward(self, x: torch.Tensor, cls_tokens: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         token_length = cls_tokens.shape[1]
         if self.pool is not None:
-            (x, cls_tokens) = self.pool(x, cls_tokens)
+            x, cls_tokens = self.pool(x, cls_tokens)
 
-        (B, C, H, W) = x.size()
+        B, C, H, W = x.size()
         x = x.flatten(2).transpose(1, 2)
         x = torch.concat((cls_tokens, x), dim=1)
         x = self.encoder(x)
@@ -142,7 +141,7 @@ class PiT(DetectorBackbone):
             if i > 0:
                 pool = Pooling(prev_dim, embed_dim)
 
-            stages[f"stage{i+1}"] = Transformer(
+            stages[f"stage{i+1}"] = PiTStage(
                 base_dims[i],
                 depth,
                 heads=heads[i],
@@ -158,7 +157,7 @@ class PiT(DetectorBackbone):
         self.body = SequentialTuple(stages)
         self.norm = nn.LayerNorm(embed_dim, eps=1e-6)
 
-        self.return_stages = self.return_stages[: len(depths)]
+        self.return_stages = [f"stage{idx + 1}" for idx in range(len(depths))]
         self.return_channels = return_channels
         self.embedding_size = embed_dim
         self.dist_classifier = self.create_classifier()
@@ -197,7 +196,7 @@ class PiT(DetectorBackbone):
 
         out = {}
         for name, module in self.body.named_children():
-            (x, cls_tokens) = module((x, cls_tokens))
+            x, cls_tokens = module(x, cls_tokens)
             if name in self.return_stages:
                 out[name] = x
 
@@ -218,12 +217,13 @@ class PiT(DetectorBackbone):
         x = self.stem(x)
         x = x + self.pos_embed
         cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
-        (x, cls_tokens) = self.body((x, cls_tokens))
+        for stage in self.body.children():
+            x, cls_tokens = stage(x, cls_tokens)
 
         return (x, cls_tokens)
 
     def embedding(self, x: torch.Tensor) -> torch.Tensor:
-        (_, cls_tokens) = self.forward_features(x)
+        _, cls_tokens = self.forward_features(x)
         cls_tokens = self.norm(cls_tokens)
 
         return cls_tokens
@@ -312,18 +312,3 @@ registry.register_model_config(
         "drop_path_rate": 0.1,
     },
 )
-
-registry.register_weights(
-    "pit_t_il-common",
-    {
-        "description": "PiT tiny model trained on the il-common dataset",
-        "resolution": (256, 256),
-        "formats": {
-            "pt": {
-                "file_size": 18.4,
-                "sha256": "5f6bd74b09c1ee541ee2ddae4844ce501b4b3218201ea6381fce0b8fc30257f2",
-            }
-        },
-        "net": {"network": "pit_t", "tag": "il-common"},
-    },
-)
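The PiT changes go beyond renaming: stages now take `(x, cls_tokens)` as two positional arguments instead of a single tuple, and `SequentialTuple` unpacks the pair with `module(*x)`. Assembled from the hunks above (not the full file), the new threading pattern is:

```python
import torch
from torch import nn

class SequentialTuple(nn.Sequential):
    # Threads an (x, cls_tokens) pair through children that each take
    # two positional tensors and return them as a tuple
    def forward(self, x: tuple[torch.Tensor, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
        for module in self:
            x = module(*x)

        return x
```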
birder/net/pvt_v1.py CHANGED
@@ -56,7 +56,7 @@ class Attention(nn.Module):
             self.norm = None
 
     def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
-        (B, N, C) = x.shape
+        B, N, C = x.shape
         q = self.q(x).reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
 
         if self.sr is not None:
@@ -65,7 +65,7 @@ class Attention(nn.Module):
             x = self.norm(x)
 
         kv = self.kv(x).reshape(B, -1, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
-        (k, v) = kv.unbind(0)
+        k, v = kv.unbind(0)
 
         x = F.scaled_dot_product_attention(  # pylint: disable=not-callable
             q, k, v, dropout_p=self.attn_drop.p if self.training else 0.0, scale=self.scale
@@ -177,7 +177,7 @@ class PyramidVisionTransformerStage(nn.Module):
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.downsample(x)  # B, C, H, W -> B, H, W, C
-        (B, H, W, C) = x.size()
+        B, H, W, C = x.size()
         x = x.reshape(B, -1, C)
         x = x + self.pos_embed
         if self.cls_token is not None:
@@ -264,7 +264,7 @@ class PVT_v1(DetectorBackbone):
 
         out = {}
         for name, module in self.body.named_children():
-            (B, _, H, W) = x.size()
+            B, _, H, W = x.size()
            x = module(x)
             if name in self.return_stages:
                 if name == "stage4":
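The `kv` projection packs keys and values into one linear layer; the reshape moves the factor of 2 to the leading axis so `unbind(0)` yields k and v directly. A standalone check with arbitrary sizes:

```python
import torch
from torch import nn

B, N, C, num_heads = 2, 196, 256, 8
head_dim = C // num_heads
kv_proj = nn.Linear(C, 2 * C)

x = torch.randn(B, N, C)
kv = kv_proj(x).reshape(B, -1, 2, num_heads, head_dim).permute(2, 0, 3, 1, 4)
k, v = kv.unbind(0)  # each [B, num_heads, N, head_dim]
print(k.shape, v.shape)
```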
birder/net/pvt_v2.py CHANGED
@@ -29,13 +29,7 @@ class MLP(nn.Module):
         self.fc1 = nn.Linear(in_features, hidden_features)
         self.relu = nn.ReLU() if extra_relu else nn.Identity()
         self.dwconv = nn.Conv2d(
-            hidden_features,
-            hidden_features,
-            kernel_size=(3, 3),
-            stride=(1, 1),
-            padding=(1, 1),
-            groups=hidden_features,
-            bias=True,
+            hidden_features, hidden_features, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=hidden_features
         )
         self.act = nn.GELU()
         self.fc2 = nn.Linear(hidden_features, in_features)
@@ -44,7 +38,7 @@ class MLP(nn.Module):
     def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
         x = self.fc1(x)
         x = self.relu(x)
-        (B, _, C) = x.shape
+        B, _, C = x.shape
         x = x.transpose(1, 2).view(B, C, H, W)
         x = self.dwconv(x)
         x = x.flatten(2).transpose(1, 2)
@@ -98,7 +92,7 @@ class Attention(nn.Module):
         assert (self.pool is None and self.act is None) or (self.pool is not None and self.act is not None)
 
     def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
-        (B, N, C) = x.shape
+        B, N, C = x.shape
         q = self.q(x).reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
 
         if self.pool is not None and self.act is not None:
@@ -114,7 +108,7 @@ class Attention(nn.Module):
             x = self.norm(x)
 
         kv = self.kv(x).reshape(B, -1, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
-        (k, v) = kv.unbind(0)
+        k, v = kv.unbind(0)
 
         x = F.scaled_dot_product_attention(  # pylint: disable=not-callable
             q, k, v, dropout_p=self.attn_drop.p if self.training else 0.0, scale=self.scale
@@ -238,7 +232,7 @@ class PyramidVisionTransformerStage(nn.Module):
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.downsample(x)  # B, C, H, W -> B, H, W, C
-        (B, H, W, C) = x.shape
+        B, H, W, C = x.shape
         x = x.reshape(B, -1, C)
         for blk in self.blocks:
             x = blk(x, H, W)
birder/net/regionvit.py CHANGED
@@ -30,8 +30,8 @@ def convert_to_flatten_layout(
     cls_tokens: torch.Tensor, patch_tokens: torch.Tensor, ws: int
 ) -> tuple[torch.Tensor, Optional[torch.Tensor], int, int, int, int, int, int]:
     # Padding if added will be at the bottom right
-    (B, C, H, W) = patch_tokens.size()
-    (_, _, h_ks, w_ks) = cls_tokens.size()
+    B, C, H, W = patch_tokens.size()
+    _, _, h_ks, w_ks = cls_tokens.size()
     need_mask = False
     p_l = 0
     p_r = 0
@@ -43,13 +43,13 @@ def convert_to_flatten_layout(
         patch_tokens = F.pad(patch_tokens, (p_l, p_r, p_t, p_b))
         need_mask = True
 
-    (B, C, H, W) = patch_tokens.size()
+    B, C, H, W = patch_tokens.size()
     kernel_size = (H // h_ks, W // w_ks)
     tmp = F.unfold(patch_tokens, kernel_size=kernel_size, dilation=(1, 1), padding=(0, 0), stride=kernel_size)
     patch_tokens = tmp.transpose(1, 2).reshape(-1, C, kernel_size[0] * kernel_size[1]).transpose(-2, -1)
 
     if need_mask is True:
-        (bh_sk_s, ksks, C) = patch_tokens.size()
+        bh_sk_s, ksks, C = patch_tokens.size()
         h_s = H // ws
         w_s = W // ws
         mask = torch.ones(bh_sk_s // B, 1 + ksks, 1 + ksks, device=patch_tokens.device, dtype=torch.float)
@@ -116,7 +116,7 @@ class SequentialWithTwo(nn.Sequential):
         self, cls_tokens: torch.Tensor, patch_tokens: torch.Tensor
     ) -> tuple[torch.Tensor, torch.Tensor]:
         for module in self:
-            (cls_tokens, patch_tokens) = module(cls_tokens, patch_tokens)
+            cls_tokens, patch_tokens = module(cls_tokens, patch_tokens)
 
         return (cls_tokens, patch_tokens)
 
@@ -178,9 +178,9 @@ class AttentionWithRelPos(nn.Module):
         nn.init.trunc_normal_(self.rel_pos, std=0.02)
 
     def forward(self, x: torch.Tensor, patch_attn: bool = False, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
-        (B, N, C) = x.size()
+        B, N, C = x.size()
         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
-        (q, k, v) = qkv.unbind(0)
+        q, k, v = qkv.unbind(0)
 
         attn = (q @ k.transpose(-2, -1)) * self.scale
 
@@ -242,7 +242,7 @@ class PatchEmbed(nn.Module):
             raise ValueError("Unknown patch_conv_type")
 
     def forward(self, x: torch.Tensor, extra_padding: bool = False) -> torch.Tensor:
-        (_, _, H, W) = x.size()
+        _, _, H, W = x.size()
         if extra_padding and (H % self.patch_size[0] != 0 or W % self.patch_size[1] != 0):
             p_l = (self.patch_size[1] - W % self.patch_size[1]) // 2
             p_r = (self.patch_size[1] - W % self.patch_size[1]) - p_l
@@ -384,12 +384,12 @@ class ConvAttStage(nn.Module):
         self.ws = window_size
 
     def forward(self, cls_tokens: torch.Tensor, patch_tokens: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
-        (cls_tokens, patch_tokens) = self.proj(cls_tokens, patch_tokens)
-        (out, mask, p_r, p_b, B, C, H, W) = convert_to_flatten_layout(cls_tokens, patch_tokens, self.ws[0])
+        cls_tokens, patch_tokens = self.proj(cls_tokens, patch_tokens)
+        out, mask, p_r, p_b, B, C, H, W = convert_to_flatten_layout(cls_tokens, patch_tokens, self.ws[0])
         for blk in self.blocks:
             out = blk(out, mask, B)
 
-        (cls_tokens, patch_tokens) = convert_to_spatial_layout(out, B, C, H, W, self.ws, mask, p_r, p_b)
+        cls_tokens, patch_tokens = convert_to_spatial_layout(out, B, C, H, W, self.ws, mask, p_r, p_b)
 
         return (cls_tokens, patch_tokens)
 
@@ -480,7 +480,7 @@ class RegionViT(DetectorBackbone):
 
         out = {}
         for name, module in self.body.named_children():
-            (cls_tokens, x) = module(cls_tokens, x)
+            cls_tokens, x = module(cls_tokens, x)
             if name in self.return_stages:
                 out[name] = x
 
@@ -503,14 +503,14 @@ class RegionViT(DetectorBackbone):
         o_x = x
         x = self.patch_embed(x)
         cls_tokens = self.cls_token(o_x, extra_padding=True)
-        (cls_tokens, x) = self.body(cls_tokens, x)
+        cls_tokens, x = self.body(cls_tokens, x)
 
         return (cls_tokens, x)
 
     def embedding(self, x: torch.Tensor) -> torch.Tensor:
-        (cls_tokens, _) = self.forward_features(x)
+        cls_tokens, _ = self.forward_features(x)
 
-        (N, C, _, _) = cls_tokens.size()
+        N, C, _, _ = cls_tokens.size()
         cls_tokens = cls_tokens.reshape(N, C, -1).transpose(1, 2)
         cls_tokens = self.norm(cls_tokens)
         out = torch.mean(cls_tokens, dim=1)
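`convert_to_flatten_layout` uses `F.unfold` with `stride == kernel_size` as a non-overlapping window partition; the transpose/reshape chain then regroups the output into one row of tokens per window. The same chain on toy shapes:

```python
import torch
import torch.nn.functional as F

B, C, H, W = 2, 24, 28, 28
ws = 7  # window size
patch_tokens = torch.randn(B, C, H, W)

# [B, C*ws*ws, num_windows]; stride == kernel_size -> non-overlapping windows
tmp = F.unfold(patch_tokens, kernel_size=(ws, ws), dilation=(1, 1), padding=(0, 0), stride=(ws, ws))
windows = tmp.transpose(1, 2).reshape(-1, C, ws * ws).transpose(-2, -1)
print(windows.shape)  # torch.Size([32, 49, 24]) -> [B * num_windows, ws*ws, C]
```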
birder/net/regnet.py CHANGED
@@ -100,7 +100,7 @@ class BlockParams:
         group_widths = [group_width] * num_stages
 
         # Adjust the compatibility of stage widths and group widths
-        (stage_widths, group_widths) = cls._adjust_widths_groups_compatibility(
+        stage_widths, group_widths = cls._adjust_widths_groups_compatibility(
            stage_widths, bottleneck_multipliers, group_widths
         )
 
birder/net/repghost.py CHANGED
@@ -79,7 +79,7 @@ class RepGhostModule(nn.Module):
         if self.reparameterized is True:
             return
 
-        (kernel, bias) = self._get_kernel_bias()
+        kernel, bias = self._get_kernel_bias()
         self.cheap_operation = nn.Conv2d(
             in_channels=self.cheap_operation[0].in_channels,
             out_channels=self.cheap_operation[0].out_channels,
@@ -87,7 +87,6 @@ class RepGhostModule(nn.Module):
             padding=self.cheap_operation[0].padding,
             dilation=self.cheap_operation[0].dilation,
             groups=self.cheap_operation[0].groups,
-            bias=True,
         )
 
         self.cheap_operation.weight.data = kernel
@@ -98,9 +97,9 @@ class RepGhostModule(nn.Module):
         self.reparameterized = True
 
     def _get_kernel_bias(self) -> tuple[torch.Tensor, torch.Tensor]:
-        (kernel, bias) = self._fuse_bn_tensor(self.cheap_operation[0], self.cheap_operation[1])
+        kernel, bias = self._fuse_bn_tensor(self.cheap_operation[0], self.cheap_operation[1])
         if self.fusion_bn is not None:
-            (kernel1x1, bias_bn) = self._fuse_bn_tensor(nn.Identity(), self.fusion_bn, kernel.shape[0])
+            kernel1x1, bias_bn = self._fuse_bn_tensor(nn.Identity(), self.fusion_bn, kernel.shape[0])
             kernel += F.pad(kernel1x1, [1, 1, 1, 1])
             bias += bias_bn
 
@@ -299,7 +298,7 @@ class RepGhost(DetectorBackbone):
         out_channels = 1280
         self.features = nn.Sequential(
             nn.AdaptiveAvgPool2d(output_size=(1, 1)),
-            nn.Conv2d(prev_channels, out_channels, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True),
+            nn.Conv2d(prev_channels, out_channels, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)),
             nn.ReLU(inplace=True),
             nn.Flatten(1),
             nn.Dropout(p=0.2),