birder 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (187) hide show
  1. birder/adversarial/base.py +1 -1
  2. birder/adversarial/simba.py +4 -4
  3. birder/common/cli.py +1 -1
  4. birder/common/fs_ops.py +11 -11
  5. birder/common/lib.py +2 -2
  6. birder/common/masking.py +3 -3
  7. birder/common/training_cli.py +24 -2
  8. birder/common/training_utils.py +28 -4
  9. birder/data/collators/detection.py +9 -1
  10. birder/data/transforms/detection.py +27 -8
  11. birder/data/transforms/mosaic.py +1 -1
  12. birder/datahub/classification.py +3 -3
  13. birder/inference/classification.py +3 -3
  14. birder/inference/data_parallel.py +1 -1
  15. birder/inference/detection.py +5 -5
  16. birder/inference/wbf.py +1 -1
  17. birder/introspection/attention_rollout.py +5 -5
  18. birder/introspection/feature_pca.py +4 -4
  19. birder/introspection/gradcam.py +1 -1
  20. birder/introspection/guided_backprop.py +2 -2
  21. birder/introspection/transformer_attribution.py +3 -3
  22. birder/layers/attention_pool.py +2 -2
  23. birder/model_registry/model_registry.py +2 -1
  24. birder/net/__init__.py +2 -0
  25. birder/net/_rope_vit_configs.py +5 -0
  26. birder/net/_vit_configs.py +0 -13
  27. birder/net/alexnet.py +5 -5
  28. birder/net/base.py +28 -3
  29. birder/net/biformer.py +17 -17
  30. birder/net/cait.py +2 -2
  31. birder/net/cas_vit.py +1 -1
  32. birder/net/coat.py +15 -15
  33. birder/net/convnext_v1.py +2 -10
  34. birder/net/convnext_v1_iso.py +198 -0
  35. birder/net/convnext_v2.py +2 -10
  36. birder/net/crossformer.py +9 -9
  37. birder/net/crossvit.py +1 -1
  38. birder/net/cspnet.py +1 -1
  39. birder/net/cswin_transformer.py +10 -10
  40. birder/net/davit.py +10 -10
  41. birder/net/deit.py +56 -3
  42. birder/net/deit3.py +27 -15
  43. birder/net/detection/__init__.py +4 -0
  44. birder/net/detection/{yolo_anchors.py → _yolo_anchors.py} +5 -5
  45. birder/net/detection/base.py +6 -5
  46. birder/net/detection/deformable_detr.py +26 -28
  47. birder/net/detection/detr.py +9 -9
  48. birder/net/detection/efficientdet.py +9 -28
  49. birder/net/detection/faster_rcnn.py +22 -22
  50. birder/net/detection/fcos.py +8 -8
  51. birder/net/detection/plain_detr.py +852 -0
  52. birder/net/detection/retinanet.py +4 -4
  53. birder/net/detection/rt_detr_v1.py +81 -25
  54. birder/net/detection/rt_detr_v2.py +1147 -0
  55. birder/net/detection/ssd.py +5 -5
  56. birder/net/detection/yolo_v2.py +12 -12
  57. birder/net/detection/yolo_v3.py +19 -19
  58. birder/net/detection/yolo_v4.py +16 -16
  59. birder/net/detection/yolo_v4_tiny.py +3 -3
  60. birder/net/edgenext.py +3 -3
  61. birder/net/edgevit.py +10 -14
  62. birder/net/efficientformer_v1.py +1 -1
  63. birder/net/efficientvim.py +9 -9
  64. birder/net/efficientvit_mit.py +2 -2
  65. birder/net/efficientvit_msft.py +3 -3
  66. birder/net/fasternet.py +1 -1
  67. birder/net/fastvit.py +5 -12
  68. birder/net/flexivit.py +28 -15
  69. birder/net/focalnet.py +5 -9
  70. birder/net/gc_vit.py +11 -11
  71. birder/net/ghostnet_v1.py +1 -1
  72. birder/net/ghostnet_v2.py +1 -1
  73. birder/net/groupmixformer.py +12 -12
  74. birder/net/hgnet_v1.py +1 -1
  75. birder/net/hgnet_v2.py +4 -4
  76. birder/net/hiera.py +6 -6
  77. birder/net/hieradet.py +9 -9
  78. birder/net/hornet.py +3 -3
  79. birder/net/iformer.py +4 -4
  80. birder/net/inception_next.py +4 -14
  81. birder/net/levit.py +3 -3
  82. birder/net/lit_v1.py +13 -15
  83. birder/net/lit_v1_tiny.py +9 -9
  84. birder/net/lit_v2.py +14 -15
  85. birder/net/maxvit.py +10 -22
  86. birder/net/metaformer.py +2 -2
  87. birder/net/mim/crossmae.py +5 -5
  88. birder/net/mim/fcmae.py +3 -5
  89. birder/net/mim/mae_hiera.py +7 -7
  90. birder/net/mim/mae_vit.py +3 -5
  91. birder/net/mim/simmim.py +2 -3
  92. birder/net/mobilenet_v4_hybrid.py +4 -4
  93. birder/net/mobileone.py +5 -12
  94. birder/net/mobilevit_v1.py +2 -2
  95. birder/net/mobilevit_v2.py +5 -9
  96. birder/net/mvit_v2.py +24 -24
  97. birder/net/nextvit.py +2 -2
  98. birder/net/pit.py +11 -26
  99. birder/net/pvt_v1.py +4 -4
  100. birder/net/pvt_v2.py +5 -11
  101. birder/net/regionvit.py +15 -15
  102. birder/net/regnet.py +1 -1
  103. birder/net/repghost.py +4 -5
  104. birder/net/repvgg.py +3 -5
  105. birder/net/repvit.py +2 -2
  106. birder/net/resnest.py +1 -1
  107. birder/net/rope_deit3.py +29 -15
  108. birder/net/rope_flexivit.py +28 -15
  109. birder/net/rope_vit.py +41 -23
  110. birder/net/sequencer2d.py +3 -4
  111. birder/net/shufflenet_v1.py +1 -1
  112. birder/net/shufflenet_v2.py +1 -1
  113. birder/net/simple_vit.py +47 -5
  114. birder/net/smt.py +7 -7
  115. birder/net/ssl/barlow_twins.py +1 -1
  116. birder/net/ssl/byol.py +2 -2
  117. birder/net/ssl/capi.py +3 -3
  118. birder/net/ssl/data2vec2.py +1 -1
  119. birder/net/ssl/dino_v2.py +11 -1
  120. birder/net/ssl/franca.py +26 -2
  121. birder/net/ssl/i_jepa.py +4 -4
  122. birder/net/ssl/mmcr.py +1 -1
  123. birder/net/swiftformer.py +1 -1
  124. birder/net/swin_transformer_v1.py +4 -5
  125. birder/net/swin_transformer_v2.py +4 -7
  126. birder/net/tiny_vit.py +3 -3
  127. birder/net/transnext.py +19 -19
  128. birder/net/uniformer.py +4 -4
  129. birder/net/vgg.py +1 -10
  130. birder/net/vit.py +38 -25
  131. birder/net/vit_parallel.py +35 -20
  132. birder/net/vit_sam.py +10 -10
  133. birder/net/vovnet_v2.py +1 -1
  134. birder/net/xcit.py +9 -7
  135. birder/ops/msda.py +4 -4
  136. birder/ops/swattention.py +10 -10
  137. birder/results/classification.py +3 -3
  138. birder/results/gui.py +8 -8
  139. birder/scripts/benchmark.py +37 -12
  140. birder/scripts/evaluate.py +1 -1
  141. birder/scripts/predict.py +3 -3
  142. birder/scripts/predict_detection.py +2 -2
  143. birder/scripts/train.py +63 -15
  144. birder/scripts/train_barlow_twins.py +10 -7
  145. birder/scripts/train_byol.py +10 -7
  146. birder/scripts/train_capi.py +15 -10
  147. birder/scripts/train_data2vec.py +10 -7
  148. birder/scripts/train_data2vec2.py +10 -7
  149. birder/scripts/train_detection.py +29 -14
  150. birder/scripts/train_dino_v1.py +13 -9
  151. birder/scripts/train_dino_v2.py +27 -14
  152. birder/scripts/train_dino_v2_dist.py +28 -15
  153. birder/scripts/train_franca.py +16 -9
  154. birder/scripts/train_i_jepa.py +12 -9
  155. birder/scripts/train_ibot.py +15 -11
  156. birder/scripts/train_kd.py +64 -17
  157. birder/scripts/train_mim.py +11 -8
  158. birder/scripts/train_mmcr.py +11 -8
  159. birder/scripts/train_rotnet.py +11 -7
  160. birder/scripts/train_simclr.py +10 -7
  161. birder/scripts/train_vicreg.py +10 -7
  162. birder/tools/adversarial.py +4 -4
  163. birder/tools/auto_anchors.py +5 -5
  164. birder/tools/avg_model.py +1 -1
  165. birder/tools/convert_model.py +30 -22
  166. birder/tools/det_results.py +1 -1
  167. birder/tools/download_model.py +1 -1
  168. birder/tools/ensemble_model.py +1 -1
  169. birder/tools/introspection.py +11 -2
  170. birder/tools/labelme_to_coco.py +2 -2
  171. birder/tools/model_info.py +12 -14
  172. birder/tools/pack.py +8 -8
  173. birder/tools/quantize_model.py +53 -4
  174. birder/tools/results.py +2 -2
  175. birder/tools/show_det_iterator.py +19 -6
  176. birder/tools/show_iterator.py +2 -2
  177. birder/tools/similarity.py +5 -5
  178. birder/tools/stats.py +4 -6
  179. birder/tools/voc_to_coco.py +1 -1
  180. birder/version.py +1 -1
  181. {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/METADATA +3 -3
  182. birder-0.4.1.dist-info/RECORD +300 -0
  183. {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/WHEEL +1 -1
  184. birder-0.4.0.dist-info/RECORD +0 -297
  185. {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/entry_points.txt +0 -0
  186. {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/licenses/LICENSE +0 -0
  187. {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/top_level.txt +0 -0
birder/net/convnext_v2.py CHANGED
@@ -56,15 +56,7 @@ class ConvNeXtBlock(nn.Module):
56
56
  ) -> None:
57
57
  super().__init__()
58
58
  self.block = nn.Sequential(
59
- nn.Conv2d(
60
- channels,
61
- channels,
62
- kernel_size=(7, 7),
63
- stride=(1, 1),
64
- padding=(3, 3),
65
- groups=channels,
66
- bias=True,
67
- ),
59
+ nn.Conv2d(channels, channels, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=channels),
68
60
  Permute([0, 2, 3, 1]),
69
61
  nn.LayerNorm(channels, eps=1e-6),
70
62
  nn.Linear(channels, 4 * channels), # Same as 1x1 conv
@@ -137,7 +129,7 @@ class ConvNeXt_v2(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
137
129
  layers.append(
138
130
  nn.Sequential(
139
131
  LayerNorm2d(i, eps=1e-6),
140
- nn.Conv2d(i, out, kernel_size=(2, 2), stride=(2, 2), padding=(0, 0), bias=True),
132
+ nn.Conv2d(i, out, kernel_size=(2, 2), stride=(2, 2), padding=(0, 0)),
141
133
  )
142
134
  )
143
135
 
birder/net/crossformer.py CHANGED
@@ -120,9 +120,9 @@ class Attention(nn.Module):
120
120
  self.relative_position_index = nn.Buffer(relative_position_index)
121
121
 
122
122
  def forward(self, x: torch.Tensor) -> torch.Tensor:
123
- (B, N, C) = x.size()
123
+ B, N, C = x.size()
124
124
  qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
125
- (q, k, v) = qkv.unbind(0)
125
+ q, k, v = qkv.unbind(0)
126
126
 
127
127
  q = q * self.scale
128
128
  attn = q @ k.transpose(-2, -1)
@@ -188,15 +188,15 @@ class CrossFormerBlock(nn.Module):
188
188
  self.drop_path = StochasticDepth(drop_path, mode="row")
189
189
 
190
190
  def forward(self, x: torch.Tensor) -> torch.Tensor:
191
- (H, W) = self.input_resolution
192
- (B, _, C) = x.size()
191
+ H, W = self.input_resolution
192
+ B, _, C = x.size()
193
193
 
194
194
  shortcut = x
195
195
  x = self.norm1(x)
196
196
  x = x.view(B, H, W, C)
197
197
 
198
198
  # Group embeddings
199
- (GH, GW) = self.group_size # pylint: disable=invalid-name
199
+ GH, GW = self.group_size # pylint: disable=invalid-name
200
200
  if self.use_lda is False:
201
201
  x = x.reshape(B, H // GH, GH, W // GW, GW, C).permute(0, 1, 3, 2, 4, 5)
202
202
  else:
@@ -244,8 +244,8 @@ class PatchMerging(nn.Module):
244
244
  )
245
245
 
246
246
  def forward(self, x: torch.Tensor) -> torch.Tensor:
247
- (H, W) = self.input_resolution
248
- (B, _, C) = x.shape
247
+ H, W = self.input_resolution
248
+ B, _, C = x.shape
249
249
 
250
250
  x = self.norm(x)
251
251
  x = x.view(B, H, W, C).permute(0, 3, 1, 2)
@@ -396,8 +396,8 @@ class CrossFormer(DetectorBackbone):
396
396
  for name, module in self.body.named_children():
397
397
  x = module(x)
398
398
  if name in self.return_stages:
399
- (H, W) = module.resolution
400
- (B, _, C) = x.size()
399
+ H, W = module.resolution
400
+ B, _, C = x.size()
401
401
  out[name] = x.view(B, H, W, C).permute(0, 3, 1, 2).contiguous()
402
402
 
403
403
  return out
birder/net/crossvit.py CHANGED
@@ -74,7 +74,7 @@ class CrossAttention(nn.Module):
74
74
  self.proj_drop = nn.Dropout(proj_drop)
75
75
 
76
76
  def forward(self, x: torch.Tensor) -> torch.Tensor:
77
- (B, N, C) = x.shape
77
+ B, N, C = x.shape
78
78
  # B1C -> B1H(C/H) -> BH1(C/H)
79
79
  q = self.wq(x[:, 0:1, ...]).reshape(B, 1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
80
80
  # BNC -> BNH(C/H) -> BHN(C/H)
birder/net/cspnet.py CHANGED
@@ -226,7 +226,7 @@ class CrossStage(nn.Module):
226
226
  def forward(self, x: torch.Tensor) -> torch.Tensor:
227
227
  x = self.conv_down(x)
228
228
  x = self.conv_exp(x)
229
- (xs, xb) = x.split(self.expand_channels // 2, dim=1)
229
+ xs, xb = x.split(self.expand_channels // 2, dim=1)
230
230
  xb = self.blocks(xb)
231
231
  xb = self.conv_transition_b(xb).contiguous()
232
232
  out = self.conv_transition(torch.concat([xs, xb], dim=1))
@@ -29,7 +29,7 @@ from birder.net.vit import PatchEmbed
29
29
 
30
30
 
31
31
  def img2windows(img: torch.Tensor, h_sp: int, w_sp: int) -> torch.Tensor:
32
- (B, C, H, W) = img.size()
32
+ B, C, H, W = img.size()
33
33
  img_reshape = img.view(B, C, H // h_sp, h_sp, W // w_sp, w_sp)
34
34
  img_perm = img_reshape.permute(0, 2, 4, 3, 5, 1).contiguous().reshape(-1, h_sp * w_sp, C)
35
35
 
@@ -81,7 +81,7 @@ class LePEAttention(nn.Module):
81
81
  raise ValueError("unsupported idx")
82
82
 
83
83
  def im2cswin(self, x: torch.Tensor) -> torch.Tensor:
84
- (B, _, C) = x.size()
84
+ B, _, C = x.size()
85
85
  x = x.transpose(-2, -1).contiguous().view(B, C, self.resolution[0], self.resolution[1])
86
86
  x = img2windows(x, self.h_sp, self.w_sp)
87
87
  x = x.reshape(-1, self.h_sp * self.w_sp, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3).contiguous()
@@ -89,7 +89,7 @@ class LePEAttention(nn.Module):
89
89
  return x
90
90
 
91
91
  def get_lepe(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
92
- (B, _, C) = x.size()
92
+ B, _, C = x.size()
93
93
  H = self.resolution[0]
94
94
  W = self.resolution[1]
95
95
  x = x.transpose(-2, -1).contiguous().view(B, C, H, W)
@@ -107,13 +107,13 @@ class LePEAttention(nn.Module):
107
107
  return (x, lepe)
108
108
 
109
109
  def forward(self, x: torch.Tensor) -> torch.Tensor:
110
- (q, k, v) = x.unbind(0)
110
+ q, k, v = x.unbind(0)
111
111
 
112
- (B, _, C) = q.shape
112
+ B, _, C = q.shape
113
113
 
114
114
  q = self.im2cswin(q)
115
115
  k = self.im2cswin(k)
116
- (v, lepe) = self.get_lepe(v)
116
+ v, lepe = self.get_lepe(v)
117
117
 
118
118
  q = q * self.scale
119
119
  attn = q @ k.transpose(-2, -1) # B head N C @ B head C N --> B head N N
@@ -136,12 +136,12 @@ class MergeBlock(nn.Module):
136
136
  self.resolution = resolution
137
137
 
138
138
  def forward(self, x: torch.Tensor) -> torch.Tensor:
139
- (B, _, C) = x.size()
139
+ B, _, C = x.size()
140
140
  H = self.resolution[0]
141
141
  W = self.resolution[1]
142
142
  x = x.transpose(-2, -1).contiguous().view(B, C, H, W)
143
143
  x = self.conv(x)
144
- (B, C) = x.shape[:2]
144
+ B, C = x.shape[:2]
145
145
  x = x.view(B, C, -1).transpose(-2, -1).contiguous()
146
146
  x = self.norm(x)
147
147
 
@@ -206,7 +206,7 @@ class CSWinBlock(nn.Module):
206
206
  self.drop_path = StochasticDepth(drop_path, mode="row")
207
207
 
208
208
  def forward(self, x: torch.Tensor) -> torch.Tensor:
209
- (B, _, C) = x.shape
209
+ B, _, C = x.shape
210
210
 
211
211
  qkv = self.qkv(self.norm1(x)).reshape(B, -1, 3, C).permute(2, 0, 1, 3)
212
212
  if self.branch_num == 2:
@@ -350,7 +350,7 @@ class CSWin_Transformer(DetectorBackbone):
350
350
  for name, module in self.body.named_children():
351
351
  x = module(x)
352
352
  if name in self.return_stages:
353
- (B, L, C) = x.size()
353
+ B, L, C = x.size()
354
354
  H = int(math.sqrt(L))
355
355
  W = H
356
356
  out[name] = x.transpose(-2, -1).contiguous().view(B, C, H, W)
birder/net/davit.py CHANGED
@@ -31,7 +31,7 @@ from birder.net.base import TokenRetentionResultType
31
31
 
32
32
 
33
33
  def window_partition(x: torch.Tensor, window_size: tuple[int, int]) -> torch.Tensor:
34
- (B, H, W, C) = x.shape
34
+ B, H, W, C = x.shape
35
35
  x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C)
36
36
  windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C)
37
37
 
@@ -92,10 +92,10 @@ class Downsample(nn.Module):
92
92
  )
93
93
 
94
94
  def forward(self, x: torch.Tensor) -> torch.Tensor:
95
- (_, _, H, W) = x.shape
95
+ _, _, H, W = x.shape
96
96
  x = self.norm(x)
97
97
  if self.even_k is True:
98
- (k_h, k_w) = self.conv.kernel_size
98
+ k_h, k_w = self.conv.kernel_size
99
99
  pad_r = (k_w - W % k_w) % k_w
100
100
  pad_b = (k_h - H % k_h) % k_h
101
101
  x = F.pad(x, (0, pad_r, 0, pad_b))
@@ -115,10 +115,10 @@ class ChannelAttention(nn.Module):
115
115
  self.proj = nn.Linear(dim, dim)
116
116
 
117
117
  def forward(self, x: torch.Tensor) -> torch.Tensor:
118
- (B, N, C) = x.shape
118
+ B, N, C = x.shape
119
119
 
120
120
  qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
121
- (q, k, v) = qkv.unbind(0)
121
+ q, k, v = qkv.unbind(0)
122
122
 
123
123
  k = k * self.scale
124
124
  attn = k.transpose(-1, -2) @ v
@@ -151,7 +151,7 @@ class ChannelBlock(nn.Module):
151
151
  self.drop_path = StochasticDepth(drop_path, mode="row")
152
152
 
153
153
  def forward(self, x: torch.Tensor) -> torch.Tensor:
154
- (B, C, H, W) = x.shape
154
+ B, C, H, W = x.shape
155
155
  x = self.cpe1(x).flatten(2).transpose(1, 2)
156
156
 
157
157
  cur = self.norm1(x)
@@ -177,10 +177,10 @@ class WindowAttention(nn.Module):
177
177
  self.proj = nn.Linear(dim, dim)
178
178
 
179
179
  def forward(self, x: torch.Tensor) -> torch.Tensor:
180
- (B, N, C) = x.shape
180
+ B, N, C = x.shape
181
181
 
182
182
  qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
183
- (q, k, v) = qkv.unbind(0)
183
+ q, k, v = qkv.unbind(0)
184
184
 
185
185
  x = F.scaled_dot_product_attention(q, k, v, scale=self.scale) # pylint: disable=not-callable
186
186
  x = x.transpose(1, 2).reshape(B, N, C)
@@ -215,7 +215,7 @@ class SpatialBlock(nn.Module):
215
215
 
216
216
  # pylint: disable=invalid-name
217
217
  def forward(self, x: torch.Tensor) -> torch.Tensor:
218
- (B, C, H, W) = x.shape
218
+ B, C, H, W = x.shape
219
219
 
220
220
  shortcut = self.cpe1(x).flatten(2).transpose(1, 2)
221
221
 
@@ -226,7 +226,7 @@ class SpatialBlock(nn.Module):
226
226
  pad_r = (self.window_size[1] - W % self.window_size[1]) % self.window_size[1]
227
227
  pad_b = (self.window_size[0] - H % self.window_size[0]) % self.window_size[0]
228
228
  x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
229
- (_, Hp, Wp, _) = x.shape
229
+ _, Hp, Wp, _ = x.shape
230
230
 
231
231
  x_windows = window_partition(x, self.window_size)
232
232
  x_windows = x_windows.view(-1, self.window_size[0] * self.window_size[1], C)
birder/net/deit.py CHANGED
@@ -19,13 +19,15 @@ from birder.model_registry import registry
19
19
  from birder.net._vit_configs import BASE
20
20
  from birder.net._vit_configs import SMALL
21
21
  from birder.net._vit_configs import TINY
22
- from birder.net.base import BaseNet
22
+ from birder.net.base import DetectorBackbone
23
+ from birder.net.base import normalize_out_indices
23
24
  from birder.net.vit import Encoder
24
25
  from birder.net.vit import PatchEmbed
25
26
  from birder.net.vit import adjust_position_embedding
26
27
 
27
28
 
28
- class DeiT(BaseNet):
29
+ # pylint: disable=too-many-instance-attributes
30
+ class DeiT(DetectorBackbone):
29
31
  block_group_regex = r"encoder\.block\.(\d+)"
30
32
 
31
33
  def __init__(
@@ -47,6 +49,7 @@ class DeiT(BaseNet):
47
49
  num_heads: int = self.config["num_heads"]
48
50
  hidden_dim: int = self.config["hidden_dim"]
49
51
  mlp_dim: int = self.config["mlp_dim"]
52
+ out_indices: Optional[list[int]] = self.config.get("out_indices", None)
50
53
  drop_path_rate: float = self.config["drop_path_rate"]
51
54
 
52
55
  torch._assert(image_size[0] % patch_size == 0, "Input shape indivisible by patch size!")
@@ -56,6 +59,7 @@ class DeiT(BaseNet):
56
59
  self.num_layers = num_layers
57
60
  self.hidden_dim = hidden_dim
58
61
  self.num_special_tokens = 2
62
+ self.out_indices = normalize_out_indices(out_indices, num_layers)
59
63
  dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_layers)] # Stochastic depth decay rule
60
64
 
61
65
  self.conv_proj = nn.Conv2d(
@@ -64,7 +68,6 @@ class DeiT(BaseNet):
64
68
  kernel_size=(patch_size, patch_size),
65
69
  stride=(patch_size, patch_size),
66
70
  padding=(0, 0),
67
- bias=True,
68
71
  )
69
72
  self.patch_embed = PatchEmbed()
70
73
 
@@ -92,6 +95,9 @@ class DeiT(BaseNet):
92
95
  )
93
96
  self.norm = nn.LayerNorm(hidden_dim, eps=1e-6)
94
97
 
98
+ num_return_stages = len(self.out_indices) if self.out_indices is not None else 1
99
+ self.return_stages = [f"stage{stage_idx + 1}" for stage_idx in range(num_return_stages)]
100
+ self.return_channels = [hidden_dim] * num_return_stages
95
101
  self.embedding_size = hidden_dim
96
102
  self.dist_classifier = self.create_classifier()
97
103
  self.classifier = self.create_classifier()
@@ -136,6 +142,53 @@ class DeiT(BaseNet):
136
142
  def set_causal_attention(self, is_causal: bool = True) -> None:
137
143
  self.encoder.set_causal_attention(is_causal)
138
144
 
145
+ def transform_to_backbone(self) -> None:
146
+ super().transform_to_backbone()
147
+ self.norm = nn.Identity()
148
+ self.dist_classifier = nn.Identity()
149
+
150
+ def detection_features(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
151
+ H, W = x.shape[-2:]
152
+
153
+ # Reshape and permute the input tensor
154
+ x = self.conv_proj(x)
155
+ x = self.patch_embed(x)
156
+
157
+ # Expand the class token to the full batch
158
+ batch_class_token = self.class_token.expand(x.shape[0], -1, -1)
159
+ batch_dist_token = self.dist_token.expand(x.shape[0], -1, -1)
160
+
161
+ x = torch.concat([batch_class_token, batch_dist_token, x], dim=1)
162
+ x = x + self.pos_embedding
163
+
164
+ if self.out_indices is None:
165
+ xs = [self.encoder(x)]
166
+ else:
167
+ xs = self.encoder.forward_features(x, out_indices=self.out_indices)
168
+
169
+ out: dict[str, torch.Tensor] = {}
170
+ for stage_name, stage_x in zip(self.return_stages, xs):
171
+ stage_x = stage_x[:, self.num_special_tokens :]
172
+ stage_x = stage_x.permute(0, 2, 1)
173
+ B, C, _ = stage_x.size()
174
+ stage_x = stage_x.reshape(B, C, H // self.patch_size, W // self.patch_size)
175
+ out[stage_name] = stage_x
176
+
177
+ return out
178
+
179
+ def freeze_stages(self, up_to_stage: int) -> None:
180
+ for param in self.conv_proj.parameters():
181
+ param.requires_grad_(False)
182
+
183
+ self.pos_embedding.requires_grad_(False)
184
+
185
+ for idx, module in enumerate(self.encoder.children()):
186
+ if idx >= up_to_stage:
187
+ break
188
+
189
+ for param in module.parameters():
190
+ param.requires_grad_(False)
191
+
139
192
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
140
193
  # Reshape and permute the input tensor
141
194
  x = self.conv_proj(x)
birder/net/deit3.py CHANGED
@@ -27,6 +27,7 @@ from birder.net.base import MaskedTokenRetentionMixin
27
27
  from birder.net.base import PreTrainEncoder
28
28
  from birder.net.base import TokenOmissionResultType
29
29
  from birder.net.base import TokenRetentionResultType
30
+ from birder.net.base import normalize_out_indices
30
31
  from birder.net.vit import Encoder
31
32
  from birder.net.vit import EncoderBlock
32
33
  from birder.net.vit import PatchEmbed
@@ -59,6 +60,7 @@ class DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedT
59
60
  mlp_dim: int = self.config["mlp_dim"]
60
61
  layer_scale_init_value: Optional[float] = self.config.get("layer_scale_init_value", 1e-5)
61
62
  num_reg_tokens: int = self.config.get("num_reg_tokens", 0)
63
+ out_indices: Optional[list[int]] = self.config.get("out_indices", None)
62
64
  drop_path_rate: float = self.config["drop_path_rate"]
63
65
 
64
66
  torch._assert(image_size[0] % patch_size == 0, "Input shape indivisible by patch size!")
@@ -70,6 +72,7 @@ class DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedT
70
72
  self.num_reg_tokens = num_reg_tokens
71
73
  self.num_special_tokens = 1 + self.num_reg_tokens
72
74
  self.pos_embed_special_tokens = pos_embed_special_tokens
75
+ self.out_indices = normalize_out_indices(out_indices, num_layers)
73
76
  dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_layers)] # Stochastic depth decay rule
74
77
 
75
78
  self.conv_proj = nn.Conv2d(
@@ -78,7 +81,6 @@ class DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedT
78
81
  kernel_size=(patch_size, patch_size),
79
82
  stride=(patch_size, patch_size),
80
83
  padding=(0, 0),
81
- bias=True,
82
84
  )
83
85
  self.patch_embed = PatchEmbed()
84
86
 
@@ -112,8 +114,9 @@ class DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedT
112
114
  )
113
115
  self.norm = nn.LayerNorm(hidden_dim, eps=1e-6)
114
116
 
115
- self.return_stages = ["neck"] # Actually meaningless, just for completeness
116
- self.return_channels = [hidden_dim]
117
+ num_return_stages = len(self.out_indices) if self.out_indices is not None else 1
118
+ self.return_stages = [f"stage{stage_idx + 1}" for stage_idx in range(num_return_stages)]
119
+ self.return_channels = [hidden_dim] * num_return_stages
117
120
  self.embedding_size = hidden_dim
118
121
  self.classifier = self.create_classifier()
119
122
 
@@ -159,7 +162,7 @@ class DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedT
159
162
  )
160
163
 
161
164
  def detection_features(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
162
- (H, W) = x.shape[-2:]
165
+ H, W = x.shape[-2:]
163
166
 
164
167
  x = self.conv_proj(x)
165
168
  x = self.patch_embed(x)
@@ -176,15 +179,20 @@ class DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedT
176
179
  x = x + self._get_pos_embed(H, W)
177
180
  x = torch.concat([batch_special_tokens, x], dim=1)
178
181
 
179
- x = self.encoder(x)
180
- x = self.norm(x)
182
+ if self.out_indices is None:
183
+ xs = [self.encoder(x)]
184
+ else:
185
+ xs = self.encoder.forward_features(x, out_indices=self.out_indices)
181
186
 
182
- x = x[:, self.num_special_tokens :]
183
- x = x.permute(0, 2, 1)
184
- (B, C, _) = x.size()
185
- x = x.reshape(B, C, H // self.patch_size, W // self.patch_size)
187
+ out: dict[str, torch.Tensor] = {}
188
+ for stage_name, stage_x in zip(self.return_stages, xs):
189
+ stage_x = stage_x[:, self.num_special_tokens :]
190
+ stage_x = stage_x.permute(0, 2, 1)
191
+ B, C, _ = stage_x.size()
192
+ stage_x = stage_x.reshape(B, C, H // self.patch_size, W // self.patch_size)
193
+ out[stage_name] = stage_x
186
194
 
187
- return {self.return_stages[0]: x}
195
+ return out
188
196
 
189
197
  def freeze_stages(self, up_to_stage: int) -> None:
190
198
  for param in self.conv_proj.parameters():
@@ -199,6 +207,10 @@ class DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedT
199
207
  for param in module.parameters():
200
208
  param.requires_grad_(False)
201
209
 
210
+ def transform_to_backbone(self) -> None:
211
+ super().transform_to_backbone()
212
+ self.norm = nn.Identity()
213
+
202
214
  def set_causal_attention(self, is_causal: bool = True) -> None:
203
215
  self.encoder.set_causal_attention(is_causal)
204
216
 
@@ -209,7 +221,7 @@ class DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedT
209
221
  return_all_features: bool = False,
210
222
  return_keys: Literal["all", "tokens", "embedding"] = "tokens",
211
223
  ) -> TokenOmissionResultType:
212
- (H, W) = x.shape[-2:]
224
+ H, W = x.shape[-2:]
213
225
 
214
226
  # Reshape and permute the input tensor
215
227
  x = self.conv_proj(x)
@@ -272,7 +284,7 @@ class DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedT
272
284
  mask_token: Optional[torch.Tensor] = None,
273
285
  return_keys: Literal["all", "features", "embedding"] = "features",
274
286
  ) -> TokenRetentionResultType:
275
- (H, W) = x.shape[-2:]
287
+ H, W = x.shape[-2:]
276
288
 
277
289
  x = self.conv_proj(x)
278
290
  x = mask_tensor(x, mask, mask_token=mask_token, patch_factor=self.max_stride // self.stem_stride)
@@ -302,7 +314,7 @@ class DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedT
302
314
  if return_keys in ("all", "features"):
303
315
  features = x[:, self.num_special_tokens :]
304
316
  features = features.permute(0, 2, 1)
305
- (B, C, _) = features.size()
317
+ B, C, _ = features.size()
306
318
  features = features.reshape(B, C, H // self.patch_size, W // self.patch_size)
307
319
  result["features"] = features
308
320
 
@@ -312,7 +324,7 @@ class DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedT
312
324
  return result
313
325
 
314
326
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
315
- (H, W) = x.shape[-2:]
327
+ H, W = x.shape[-2:]
316
328
 
317
329
  # Reshape and permute the input tensor
318
330
  x = self.conv_proj(x)
@@ -3,8 +3,10 @@ from birder.net.detection.detr import DETR
3
3
  from birder.net.detection.efficientdet import EfficientDet
4
4
  from birder.net.detection.faster_rcnn import Faster_RCNN
5
5
  from birder.net.detection.fcos import FCOS
6
+ from birder.net.detection.plain_detr import Plain_DETR
6
7
  from birder.net.detection.retinanet import RetinaNet
7
8
  from birder.net.detection.rt_detr_v1 import RT_DETR_v1
9
+ from birder.net.detection.rt_detr_v2 import RT_DETR_v2
8
10
  from birder.net.detection.ssd import SSD
9
11
  from birder.net.detection.ssdlite import SSDLite
10
12
  from birder.net.detection.vitdet import ViTDet
@@ -19,8 +21,10 @@ __all__ = [
19
21
  "EfficientDet",
20
22
  "Faster_RCNN",
21
23
  "FCOS",
24
+ "Plain_DETR",
22
25
  "RetinaNet",
23
26
  "RT_DETR_v1",
27
+ "RT_DETR_v2",
24
28
  "SSD",
25
29
  "SSDLite",
26
30
  "ViTDet",
@@ -71,7 +71,7 @@ def scale_anchors(anchors: AnchorGroups, from_size: tuple[int, int], to_size: tu
71
71
 
72
72
 
73
73
  def scale_anchors(anchors: AnchorLike, from_size: tuple[int, int], to_size: tuple[int, int]) -> AnchorLike:
74
- (anchor_groups, single) = _normalize_anchor_groups(anchors)
74
+ anchor_groups, single = _normalize_anchor_groups(anchors)
75
75
 
76
76
  if from_size == to_size:
77
77
  # Avoid aliasing default anchors in case they are mutated later
@@ -100,7 +100,7 @@ def pixels_to_grid(anchors: AnchorGroups, strides: Sequence[int]) -> AnchorGroup
100
100
 
101
101
 
102
102
  def pixels_to_grid(anchors: AnchorLike, strides: Sequence[int]) -> AnchorLike:
103
- (anchor_groups, single) = _normalize_anchor_groups(anchors)
103
+ anchor_groups, single = _normalize_anchor_groups(anchors)
104
104
  if len(anchor_groups) != len(strides):
105
105
  raise ValueError("strides must provide one value per anchor scale")
106
106
 
@@ -123,7 +123,7 @@ def grid_to_pixels(anchors: AnchorGroups, strides: Sequence[int]) -> AnchorGroup
123
123
 
124
124
 
125
125
  def grid_to_pixels(anchors: AnchorLike, strides: Sequence[int]) -> AnchorLike:
126
- (anchor_groups, single) = _normalize_anchor_groups(anchors)
126
+ anchor_groups, single = _normalize_anchor_groups(anchors)
127
127
  if len(anchor_groups) != len(strides):
128
128
  raise ValueError("strides must provide one value per anchor scale")
129
129
 
@@ -187,7 +187,7 @@ def resolve_anchor_group(
187
187
  preset: str, *, anchor_format: str, model_size: tuple[int, int], model_strides: Sequence[int]
188
188
  ) -> AnchorGroup:
189
189
  anchors = _resolve_anchors(preset, anchor_format=anchor_format, model_size=model_size, model_strides=model_strides)
190
- (anchor_groups, single) = _normalize_anchor_groups(anchors)
190
+ anchor_groups, single = _normalize_anchor_groups(anchors)
191
191
  if single is False:
192
192
  raise ValueError("Expected a single anchor group for this model")
193
193
 
@@ -198,7 +198,7 @@ def resolve_anchor_groups(
198
198
  preset: str, *, anchor_format: str, model_size: tuple[int, int], model_strides: Sequence[int]
199
199
  ) -> AnchorGroups:
200
200
  anchors = _resolve_anchors(preset, anchor_format=anchor_format, model_size=model_size, model_strides=model_strides)
201
- (anchor_groups, single) = _normalize_anchor_groups(anchors)
201
+ anchor_groups, single = _normalize_anchor_groups(anchors)
202
202
  if single is True:
203
203
  raise ValueError("Expected multiple anchor groups for this model")
204
204
 
@@ -41,6 +41,7 @@ def get_detection_signature(input_shape: tuple[int, ...], num_outputs: int, dyna
41
41
 
42
42
  class DetectionBaseNet(nn.Module):
43
43
  default_size: tuple[int, int]
44
+ block_group_regex: Optional[str]
44
45
  auto_register = False
45
46
  scriptable = True
46
47
  task = str(Task.OBJECT_DETECTION)
@@ -308,7 +309,7 @@ class SimpleFeaturePyramidNetwork(nn.Module):
308
309
  names.append(f"stage{idx+1}")
309
310
 
310
311
  if self.extra_blocks is not None:
311
- (results, names) = self.extra_blocks(results, [x], names)
312
+ results, names = self.extra_blocks(results, [x], names)
312
313
 
313
314
  out = OrderedDict(list(zip(names, results)))
314
315
 
@@ -432,7 +433,7 @@ class BoxCoder:
432
433
  ctr_x = boxes[:, 0] + 0.5 * widths
433
434
  ctr_y = boxes[:, 1] + 0.5 * heights
434
435
 
435
- (wx, wy, ww, wh) = self.weights
436
+ wx, wy, ww, wh = self.weights
436
437
  dx = rel_codes[:, 0::4] / wx
437
438
  dy = rel_codes[:, 1::4] / wy
438
439
  dw = rel_codes[:, 2::4] / ww
@@ -510,8 +511,8 @@ class AnchorGenerator(nn.Module):
510
511
  )
511
512
 
512
513
  for size, stride, base_anchors in zip(grid_sizes, strides, cell_anchors):
513
- (grid_height, grid_width) = size
514
- (stride_height, stride_width) = stride
514
+ grid_height, grid_width = size
515
+ stride_height, stride_width = stride
515
516
  device = base_anchors.device
516
517
 
517
518
  # For output anchor, compute [x_center, y_center, x_center, y_center]
@@ -656,7 +657,7 @@ class Matcher(nn.Module):
656
657
  # tensor([39796, 32055, 32070, 39190, 40255, 40390, 41455, 45470, 45325, 46390]))
657
658
  # Each element in the first tensor is a gt index,
658
659
  # and each element in second tensor is a prediction index
659
- # Note how gt items 1, 2, 3, and 5 each have two ties
660
+ # Note how gt items 1, 2, 3 and 5 each have two ties
660
661
 
661
662
  pred_idx_to_update = gt_pred_pairs_of_highest_quality[1]
662
663
  matches[pred_idx_to_update] = all_matches[pred_idx_to_update]