birder 0.3.3__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (216)
  1. birder/adversarial/base.py +1 -1
  2. birder/adversarial/simba.py +4 -4
  3. birder/common/cli.py +1 -1
  4. birder/common/fs_ops.py +13 -13
  5. birder/common/lib.py +2 -2
  6. birder/common/masking.py +3 -3
  7. birder/common/training_cli.py +24 -2
  8. birder/common/training_utils.py +28 -4
  9. birder/data/collators/detection.py +9 -1
  10. birder/data/transforms/detection.py +27 -8
  11. birder/data/transforms/mosaic.py +1 -1
  12. birder/datahub/classification.py +3 -3
  13. birder/inference/classification.py +3 -3
  14. birder/inference/data_parallel.py +1 -1
  15. birder/inference/detection.py +5 -5
  16. birder/inference/wbf.py +1 -1
  17. birder/introspection/attention_rollout.py +6 -6
  18. birder/introspection/feature_pca.py +4 -4
  19. birder/introspection/gradcam.py +1 -1
  20. birder/introspection/guided_backprop.py +2 -2
  21. birder/introspection/transformer_attribution.py +4 -4
  22. birder/layers/attention_pool.py +2 -2
  23. birder/layers/layer_scale.py +1 -1
  24. birder/model_registry/model_registry.py +2 -1
  25. birder/net/__init__.py +4 -10
  26. birder/net/_rope_vit_configs.py +435 -0
  27. birder/net/_vit_configs.py +466 -0
  28. birder/net/alexnet.py +5 -5
  29. birder/net/base.py +28 -3
  30. birder/net/biformer.py +18 -17
  31. birder/net/cait.py +7 -7
  32. birder/net/cas_vit.py +1 -1
  33. birder/net/coat.py +27 -27
  34. birder/net/conv2former.py +3 -3
  35. birder/net/convmixer.py +1 -1
  36. birder/net/convnext_v1.py +3 -11
  37. birder/net/convnext_v1_iso.py +198 -0
  38. birder/net/convnext_v2.py +2 -10
  39. birder/net/crossformer.py +9 -9
  40. birder/net/crossvit.py +6 -6
  41. birder/net/cspnet.py +1 -1
  42. birder/net/cswin_transformer.py +10 -10
  43. birder/net/davit.py +11 -11
  44. birder/net/deit.py +68 -29
  45. birder/net/deit3.py +69 -204
  46. birder/net/densenet.py +9 -8
  47. birder/net/detection/__init__.py +4 -0
  48. birder/net/detection/{yolo_anchors.py → _yolo_anchors.py} +5 -5
  49. birder/net/detection/base.py +6 -5
  50. birder/net/detection/deformable_detr.py +31 -30
  51. birder/net/detection/detr.py +14 -11
  52. birder/net/detection/efficientdet.py +10 -29
  53. birder/net/detection/faster_rcnn.py +22 -22
  54. birder/net/detection/fcos.py +8 -8
  55. birder/net/detection/plain_detr.py +852 -0
  56. birder/net/detection/retinanet.py +4 -4
  57. birder/net/detection/rt_detr_v1.py +81 -25
  58. birder/net/detection/rt_detr_v2.py +1147 -0
  59. birder/net/detection/ssd.py +5 -5
  60. birder/net/detection/yolo_v2.py +12 -12
  61. birder/net/detection/yolo_v3.py +19 -19
  62. birder/net/detection/yolo_v4.py +16 -16
  63. birder/net/detection/yolo_v4_tiny.py +3 -3
  64. birder/net/dpn.py +1 -2
  65. birder/net/edgenext.py +5 -4
  66. birder/net/edgevit.py +13 -14
  67. birder/net/efficientformer_v1.py +3 -2
  68. birder/net/efficientformer_v2.py +18 -31
  69. birder/net/efficientnet_v2.py +3 -0
  70. birder/net/efficientvim.py +9 -9
  71. birder/net/efficientvit_mit.py +7 -7
  72. birder/net/efficientvit_msft.py +3 -3
  73. birder/net/fasternet.py +3 -3
  74. birder/net/fastvit.py +5 -12
  75. birder/net/flexivit.py +50 -58
  76. birder/net/focalnet.py +5 -9
  77. birder/net/gc_vit.py +11 -11
  78. birder/net/ghostnet_v1.py +1 -1
  79. birder/net/ghostnet_v2.py +1 -1
  80. birder/net/groupmixformer.py +13 -13
  81. birder/net/hgnet_v1.py +6 -6
  82. birder/net/hgnet_v2.py +4 -4
  83. birder/net/hiera.py +6 -6
  84. birder/net/hieradet.py +9 -9
  85. birder/net/hornet.py +3 -3
  86. birder/net/iformer.py +4 -4
  87. birder/net/inception_next.py +5 -15
  88. birder/net/inception_resnet_v1.py +3 -3
  89. birder/net/inception_resnet_v2.py +7 -4
  90. birder/net/inception_v3.py +3 -0
  91. birder/net/inception_v4.py +3 -0
  92. birder/net/levit.py +3 -3
  93. birder/net/lit_v1.py +13 -15
  94. birder/net/lit_v1_tiny.py +9 -9
  95. birder/net/lit_v2.py +14 -15
  96. birder/net/maxvit.py +11 -23
  97. birder/net/metaformer.py +5 -5
  98. birder/net/mim/crossmae.py +6 -6
  99. birder/net/mim/fcmae.py +3 -5
  100. birder/net/mim/mae_hiera.py +7 -7
  101. birder/net/mim/mae_vit.py +4 -6
  102. birder/net/mim/simmim.py +3 -4
  103. birder/net/mobilenet_v1.py +0 -9
  104. birder/net/mobilenet_v2.py +38 -44
  105. birder/net/{mobilenet_v3_large.py → mobilenet_v3.py} +37 -10
  106. birder/net/mobilenet_v4_hybrid.py +4 -4
  107. birder/net/mobileone.py +5 -12
  108. birder/net/mobilevit_v1.py +7 -34
  109. birder/net/mobilevit_v2.py +6 -54
  110. birder/net/moganet.py +8 -5
  111. birder/net/mvit_v2.py +30 -30
  112. birder/net/nextvit.py +2 -2
  113. birder/net/nfnet.py +4 -0
  114. birder/net/pit.py +11 -26
  115. birder/net/pvt_v1.py +9 -9
  116. birder/net/pvt_v2.py +10 -16
  117. birder/net/regionvit.py +15 -15
  118. birder/net/regnet.py +1 -1
  119. birder/net/repghost.py +5 -35
  120. birder/net/repvgg.py +3 -5
  121. birder/net/repvit.py +2 -2
  122. birder/net/resmlp.py +2 -2
  123. birder/net/resnest.py +4 -1
  124. birder/net/resnet_v1.py +125 -1
  125. birder/net/resnet_v2.py +75 -1
  126. birder/net/resnext.py +35 -1
  127. birder/net/rope_deit3.py +62 -151
  128. birder/net/rope_flexivit.py +46 -33
  129. birder/net/rope_vit.py +44 -758
  130. birder/net/sequencer2d.py +3 -4
  131. birder/net/shufflenet_v1.py +1 -1
  132. birder/net/shufflenet_v2.py +1 -1
  133. birder/net/simple_vit.py +69 -21
  134. birder/net/smt.py +8 -8
  135. birder/net/squeezenet.py +5 -12
  136. birder/net/squeezenext.py +0 -24
  137. birder/net/ssl/barlow_twins.py +1 -1
  138. birder/net/ssl/byol.py +2 -2
  139. birder/net/ssl/capi.py +4 -4
  140. birder/net/ssl/data2vec.py +1 -1
  141. birder/net/ssl/data2vec2.py +1 -1
  142. birder/net/ssl/dino_v2.py +13 -3
  143. birder/net/ssl/franca.py +28 -4
  144. birder/net/ssl/i_jepa.py +5 -5
  145. birder/net/ssl/ibot.py +1 -1
  146. birder/net/ssl/mmcr.py +1 -1
  147. birder/net/swiftformer.py +13 -3
  148. birder/net/swin_transformer_v1.py +4 -5
  149. birder/net/swin_transformer_v2.py +5 -8
  150. birder/net/tiny_vit.py +6 -19
  151. birder/net/transnext.py +19 -19
  152. birder/net/uniformer.py +4 -4
  153. birder/net/van.py +2 -2
  154. birder/net/vgg.py +1 -10
  155. birder/net/vit.py +72 -987
  156. birder/net/vit_parallel.py +35 -20
  157. birder/net/vit_sam.py +23 -48
  158. birder/net/vovnet_v2.py +1 -1
  159. birder/net/xcit.py +16 -13
  160. birder/ops/msda.py +4 -4
  161. birder/ops/swattention.py +10 -10
  162. birder/results/classification.py +3 -3
  163. birder/results/gui.py +8 -8
  164. birder/scripts/benchmark.py +37 -12
  165. birder/scripts/evaluate.py +1 -1
  166. birder/scripts/predict.py +3 -3
  167. birder/scripts/predict_detection.py +2 -2
  168. birder/scripts/train.py +63 -15
  169. birder/scripts/train_barlow_twins.py +10 -7
  170. birder/scripts/train_byol.py +10 -7
  171. birder/scripts/train_capi.py +15 -10
  172. birder/scripts/train_data2vec.py +10 -7
  173. birder/scripts/train_data2vec2.py +10 -7
  174. birder/scripts/train_detection.py +29 -14
  175. birder/scripts/train_dino_v1.py +13 -9
  176. birder/scripts/train_dino_v2.py +27 -14
  177. birder/scripts/train_dino_v2_dist.py +28 -15
  178. birder/scripts/train_franca.py +16 -9
  179. birder/scripts/train_i_jepa.py +12 -9
  180. birder/scripts/train_ibot.py +15 -11
  181. birder/scripts/train_kd.py +64 -17
  182. birder/scripts/train_mim.py +11 -8
  183. birder/scripts/train_mmcr.py +11 -8
  184. birder/scripts/train_rotnet.py +11 -7
  185. birder/scripts/train_simclr.py +10 -7
  186. birder/scripts/train_vicreg.py +10 -7
  187. birder/tools/adversarial.py +4 -4
  188. birder/tools/auto_anchors.py +5 -5
  189. birder/tools/avg_model.py +1 -1
  190. birder/tools/convert_model.py +30 -22
  191. birder/tools/det_results.py +1 -1
  192. birder/tools/download_model.py +1 -1
  193. birder/tools/ensemble_model.py +1 -1
  194. birder/tools/introspection.py +12 -3
  195. birder/tools/labelme_to_coco.py +2 -2
  196. birder/tools/model_info.py +15 -15
  197. birder/tools/pack.py +8 -8
  198. birder/tools/quantize_model.py +53 -4
  199. birder/tools/results.py +2 -2
  200. birder/tools/show_det_iterator.py +19 -6
  201. birder/tools/show_iterator.py +2 -2
  202. birder/tools/similarity.py +5 -5
  203. birder/tools/stats.py +4 -6
  204. birder/tools/voc_to_coco.py +1 -1
  205. birder/version.py +1 -1
  206. {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/METADATA +3 -3
  207. birder-0.4.1.dist-info/RECORD +300 -0
  208. {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/WHEEL +1 -1
  209. birder/net/mobilenet_v3_small.py +0 -43
  210. birder/net/se_resnet_v1.py +0 -105
  211. birder/net/se_resnet_v2.py +0 -59
  212. birder/net/se_resnext.py +0 -30
  213. birder-0.3.3.dist-info/RECORD +0 -299
  214. {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/entry_points.txt +0 -0
  215. {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/licenses/LICENSE +0 -0
  216. {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/top_level.txt +0 -0
@@ -120,7 +120,7 @@ class RetinaNetClassificationHead(nn.Module):
120
120
  cls_logits = self.cls_logits(cls_logits)
121
121
 
122
122
  # Permute classification output from (N, A * K, H, W) to (N, HWA, K).
123
- (N, _, H, W) = cls_logits.shape
123
+ N, _, H, W = cls_logits.shape
124
124
  cls_logits = cls_logits.view(N, -1, self.num_classes, H, W)
125
125
  cls_logits = cls_logits.permute(0, 3, 4, 1, 2)
126
126
  cls_logits = cls_logits.reshape(N, -1, self.num_classes) # Size=(N, HWA, K)
@@ -202,7 +202,7 @@ class RetinaNetRegressionHead(nn.Module):
202
202
  bbox_regression = self.bbox_reg(bbox_regression)
203
203
 
204
204
  # Permute bbox regression output from (N, 4 * A, H, W) to (N, HWA, 4).
205
- (N, _, H, W) = bbox_regression.size()
205
+ N, _, H, W = bbox_regression.size()
206
206
  bbox_regression = bbox_regression.view(N, -1, 4, H, W)
207
207
  bbox_regression = bbox_regression.permute(0, 3, 4, 1, 2)
208
208
  bbox_regression = bbox_regression.reshape(N, -1, 4) # Size=(N, HWA, 4)
@@ -395,7 +395,7 @@ class RetinaNet(DetectionBaseNet):
395
395
 
396
396
  # Keep only topk scoring predictions
397
397
  num_topk = min(self.topk_candidates, int(topk_idxs.size(0)))
398
- (scores_per_level, idxs) = scores_per_level.topk(num_topk)
398
+ scores_per_level, idxs = scores_per_level.topk(num_topk)
399
399
  topk_idxs = topk_idxs[idxs]
400
400
 
401
401
  anchor_idxs = torch.div(topk_idxs, num_classes, rounding_mode="floor")
@@ -417,7 +417,7 @@ class RetinaNet(DetectionBaseNet):
417
417
 
418
418
  # Non-maximum suppression
419
419
  if self.soft_nms is not None:
420
- (soft_scores, keep) = self.soft_nms(image_boxes, image_scores, image_labels, score_threshold=0.001)
420
+ soft_scores, keep = self.soft_nms(image_boxes, image_scores, image_labels, score_threshold=0.001)
421
421
  image_scores[keep] = soft_scores
422
422
  else:
423
423
  keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh)
@@ -1,5 +1,5 @@
1
1
  """
2
- RT-DETR (Real-Time DEtection TRansformer), adapted from
2
+ RT-DETR v1 (Real-Time DEtection TRansformer), adapted from
3
3
  https://github.com/lyuwenyu/RT-DETR/tree/main/rtdetr_pytorch
4
4
 
5
5
  Paper "DETRs Beat YOLOs on Real-time Object Detection", https://arxiv.org/abs/2304.08069
@@ -114,7 +114,7 @@ def get_contrastive_denoising_training_group( # pylint: disable=too-many-locals
114
114
  # Embed class labels
115
115
  input_query_class = class_embed(input_query_class)
116
116
 
117
- # Create attention mask
117
+ # Create attention mask (True = masked)
118
118
  target_size = total_denoising_queries + num_queries
119
119
  attn_mask = torch.zeros([target_size, target_size], dtype=torch.bool, device=device)
120
120
  attn_mask[total_denoising_queries:, :total_denoising_queries] = True
@@ -212,10 +212,69 @@ class CSPRepLayer(nn.Module):
212
212
  return self.conv3(x1 + x2)
213
213
 
214
214
 
215
+ class MultiheadAttention(nn.Module):
216
+ def __init__(self, d_model: int, num_heads: int, attn_drop: float = 0.0, proj_drop: float = 0.0) -> None:
217
+ super().__init__()
218
+ assert d_model % num_heads == 0, "d_model should be divisible by num_heads"
219
+
220
+ self.num_heads = num_heads
221
+ self.head_dim = d_model // num_heads
222
+ self.scale = self.head_dim**-0.5
223
+
224
+ self.q_proj = nn.Linear(d_model, d_model)
225
+ self.k_proj = nn.Linear(d_model, d_model)
226
+ self.v_proj = nn.Linear(d_model, d_model)
227
+ self.attn_drop = nn.Dropout(attn_drop)
228
+ self.proj = nn.Linear(d_model, d_model)
229
+ self.proj_drop = nn.Dropout(proj_drop)
230
+
231
+ self.reset_parameters()
232
+
233
+ def reset_parameters(self) -> None:
234
+ nn.init.xavier_uniform_(self.q_proj.weight)
235
+ nn.init.xavier_uniform_(self.k_proj.weight)
236
+ nn.init.xavier_uniform_(self.v_proj.weight)
237
+ nn.init.xavier_uniform_(self.proj.weight)
238
+ if self.q_proj.bias is not None:
239
+ nn.init.zeros_(self.q_proj.bias)
240
+ nn.init.zeros_(self.k_proj.bias)
241
+ nn.init.zeros_(self.v_proj.bias)
242
+ nn.init.zeros_(self.proj.bias)
243
+
244
+ def forward(
245
+ self,
246
+ query: torch.Tensor,
247
+ key: torch.Tensor,
248
+ value: torch.Tensor,
249
+ key_padding_mask: Optional[torch.Tensor] = None,
250
+ ) -> torch.Tensor:
251
+ B, l_q, C = query.shape
252
+ q = self.q_proj(query).reshape(B, l_q, self.num_heads, self.head_dim).transpose(1, 2)
253
+ k = self.k_proj(key).reshape(B, key.size(1), self.num_heads, self.head_dim).transpose(1, 2)
254
+ v = self.v_proj(value).reshape(B, value.size(1), self.num_heads, self.head_dim).transpose(1, 2)
255
+
256
+ if key_padding_mask is not None:
257
+ # key_padding_mask is expected to be boolean (True = masked)
258
+ # SDPA expects True = attend, so we invert
259
+ attn_mask = ~key_padding_mask[:, None, None, :]
260
+ else:
261
+ attn_mask = None
262
+
263
+ attn = F.scaled_dot_product_attention( # pylint: disable=not-callable
264
+ q, k, v, attn_mask=attn_mask, dropout_p=self.attn_drop.p if self.training else 0.0, scale=self.scale
265
+ )
266
+
267
+ attn = attn.transpose(1, 2).reshape(B, l_q, C)
268
+ x = self.proj(attn)
269
+ x = self.proj_drop(x)
270
+
271
+ return x
272
+
273
+
215
274
  class TransformerEncoderLayer(nn.Module):
216
275
  def __init__(self, d_model: int, num_heads: int, dim_feedforward: int, dropout: float) -> None:
217
276
  super().__init__()
218
- self.self_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout, batch_first=True)
277
+ self.self_attn = MultiheadAttention(d_model, num_heads, attn_drop=dropout)
219
278
 
220
279
  self.linear1 = nn.Linear(d_model, dim_feedforward)
221
280
  self.dropout = nn.Dropout(dropout)
@@ -231,10 +290,8 @@ class TransformerEncoderLayer(nn.Module):
231
290
  def forward(
232
291
  self, src: torch.Tensor, pos: torch.Tensor, key_padding_mask: Optional[torch.Tensor] = None
233
292
  ) -> torch.Tensor:
234
- q = src + pos
235
- k = src + pos
236
-
237
- (src2, _) = self.self_attn(q, k, value=src, key_padding_mask=key_padding_mask, need_weights=False)
293
+ qk = src + pos
294
+ src2 = self.self_attn(qk, qk, value=src, key_padding_mask=key_padding_mask)
238
295
  src = src + self.dropout1(src2)
239
296
  src = self.norm1(src)
240
297
 
@@ -268,7 +325,7 @@ class AIFI(nn.Module):
268
325
  self._pos_cache.clear()
269
326
 
270
327
  def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
271
- (B, C, H, W) = x.size()
328
+ B, C, H, W = x.size()
272
329
  x = x.flatten(2).permute(0, 2, 1)
273
330
 
274
331
  use_cache = self.use_cache is True and torch.jit.is_tracing() is False and torch.jit.is_scripting() is False
@@ -522,7 +579,7 @@ class RT_DETRDecoder(nn.Module):
522
579
  spatial_shapes: list[list[int]],
523
580
  memory_padding_mask: Optional[torch.Tensor] = None,
524
581
  ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
525
- (anchors, valid_mask) = self._generate_anchors(spatial_shapes, device=memory.device, dtype=memory.dtype)
582
+ anchors, valid_mask = self._generate_anchors(spatial_shapes, device=memory.device, dtype=memory.dtype)
526
583
  if memory_padding_mask is not None:
527
584
  valid_mask = valid_mask & ~memory_padding_mask.unsqueeze(-1)
528
585
 
@@ -535,7 +592,7 @@ class RT_DETRDecoder(nn.Module):
535
592
  enc_outputs_coord_unact = self.enc_bbox_head(output_memory) + anchors
536
593
 
537
594
  # Select top-k queries based on classification confidence
538
- (_, topk_ind) = torch.topk(enc_outputs_class.max(dim=-1).values, self.num_queries, dim=1)
595
+ _, topk_ind = torch.topk(enc_outputs_class.max(dim=-1).values, self.num_queries, dim=1)
539
596
 
540
597
  # Gather reference points
541
598
  reference_points_unact = enc_outputs_coord_unact.gather(
@@ -577,7 +634,7 @@ class RT_DETRDecoder(nn.Module):
577
634
  memory_padding_mask = torch.concat(mask_flatten, dim=1) if mask_flatten else None
578
635
 
579
636
  # Get decoder input (query selection)
580
- (target, init_ref_points_unact, enc_topk_bboxes, enc_topk_logits) = self._get_decoder_input(
637
+ target, init_ref_points_unact, enc_topk_bboxes, enc_topk_logits = self._get_decoder_input(
581
638
  memory, spatial_shapes, memory_padding_mask
582
639
  )
583
640
 
@@ -858,7 +915,7 @@ class RT_DETR_v1(DetectionBaseNet):
858
915
  loss_ce = self._class_loss(
859
916
  dn_out_logits[layer_idx], dn_out_bboxes[layer_idx], targets, indices, dn_num_boxes
860
917
  )
861
- (loss_bbox, loss_giou) = self._box_loss(dn_out_bboxes[layer_idx], targets, indices, dn_num_boxes)
918
+ loss_bbox, loss_giou = self._box_loss(dn_out_bboxes[layer_idx], targets, indices, dn_num_boxes)
862
919
 
863
920
  loss_ce_list.append(loss_ce)
864
921
  loss_bbox_list.append(loss_bbox)
@@ -899,7 +956,7 @@ class RT_DETR_v1(DetectionBaseNet):
899
956
  for layer_idx in range(out_logits.shape[0]):
900
957
  indices = self.matcher(out_logits[layer_idx], out_bboxes[layer_idx], targets)
901
958
  loss_ce = self._class_loss(out_logits[layer_idx], out_bboxes[layer_idx], targets, indices, num_boxes)
902
- (loss_bbox, loss_giou) = self._box_loss(out_bboxes[layer_idx], targets, indices, num_boxes)
959
+ loss_bbox, loss_giou = self._box_loss(out_bboxes[layer_idx], targets, indices, num_boxes)
903
960
  loss_ce_list.append(loss_ce)
904
961
  loss_bbox_list.append(loss_bbox)
905
962
  loss_giou_list.append(loss_giou)
@@ -907,7 +964,7 @@ class RT_DETR_v1(DetectionBaseNet):
907
964
  # Encoder auxiliary loss
908
965
  enc_indices = self.matcher(enc_topk_logits, enc_topk_bboxes, targets)
909
966
  loss_ce_enc = self._class_loss(enc_topk_logits, enc_topk_bboxes, targets, enc_indices, num_boxes)
910
- (loss_bbox_enc, loss_giou_enc) = self._box_loss(enc_topk_bboxes, targets, enc_indices, num_boxes)
967
+ loss_bbox_enc, loss_giou_enc = self._box_loss(enc_topk_bboxes, targets, enc_indices, num_boxes)
911
968
  loss_ce_list.append(loss_ce_enc)
912
969
  loss_bbox_list.append(loss_bbox_enc)
913
970
  loss_giou_list.append(loss_giou_enc)
@@ -918,7 +975,7 @@ class RT_DETR_v1(DetectionBaseNet):
918
975
 
919
976
  # Add denoising loss if available
920
977
  if dn_out_bboxes is not None and dn_out_logits is not None and dn_meta is not None:
921
- (loss_ce_dn, loss_bbox_dn, loss_giou_dn) = self._compute_denoising_loss(
978
+ loss_ce_dn, loss_bbox_dn, loss_giou_dn = self._compute_denoising_loss(
922
979
  dn_out_bboxes, dn_out_logits, targets, dn_meta, num_boxes
923
980
  )
924
981
  loss_ce = loss_ce + loss_ce_dn
@@ -952,9 +1009,9 @@ class RT_DETR_v1(DetectionBaseNet):
952
1009
  targets[idx]["boxes"] = boxes
953
1010
  targets[idx]["labels"] = target["labels"] - 1 # No background
954
1011
 
955
- (denoising_class, denoising_bbox_unact, attn_mask, dn_meta) = self._prepare_cdn_queries(targets)
1012
+ denoising_class, denoising_bbox_unact, attn_mask, dn_meta = self._prepare_cdn_queries(targets)
956
1013
 
957
- (out_bboxes, out_logits, enc_topk_bboxes, enc_topk_logits) = self.decoder(
1014
+ out_bboxes, out_logits, enc_topk_bboxes, enc_topk_logits = self.decoder(
958
1015
  encoder_features,
959
1016
  spatial_shapes,
960
1017
  level_start_index,
@@ -965,7 +1022,7 @@ class RT_DETR_v1(DetectionBaseNet):
965
1022
  )
966
1023
 
967
1024
  if dn_meta is not None:
968
- (dn_num_split, _num_queries) = dn_meta["dn_num_split"]
1025
+ dn_num_split, _num_queries = dn_meta["dn_num_split"]
969
1026
  dn_out_bboxes = out_bboxes[:, :, :dn_num_split]
970
1027
  dn_out_logits = out_logits[:, :, :dn_num_split]
971
1028
  out_bboxes = out_bboxes[:, :, dn_num_split:]
@@ -984,9 +1041,7 @@ class RT_DETR_v1(DetectionBaseNet):
984
1041
  self, class_logits: torch.Tensor, box_regression: torch.Tensor, image_shapes: list[tuple[int, int]]
985
1042
  ) -> list[dict[str, torch.Tensor]]:
986
1043
  prob = class_logits.sigmoid()
987
- (topk_values, topk_indexes) = torch.topk(
988
- prob.view(class_logits.shape[0], -1), k=self.decoder.num_queries, dim=1
989
- )
1044
+ topk_values, topk_indexes = torch.topk(prob.view(class_logits.shape[0], -1), k=self.decoder.num_queries, dim=1)
990
1045
  scores = topk_values
991
1046
  topk_boxes = topk_indexes // class_logits.shape[2]
992
1047
  labels = topk_indexes % class_logits.shape[2]
@@ -999,7 +1054,7 @@ class RT_DETR_v1(DetectionBaseNet):
999
1054
  boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4))
1000
1055
 
1001
1056
  # Convert from relative [0, 1] to absolute [0, height] coordinates
1002
- (img_h, img_w) = target_sizes.unbind(1)
1057
+ img_h, img_w = target_sizes.unbind(1)
1003
1058
  scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
1004
1059
  boxes = boxes * scale_fct[:, None, :]
1005
1060
 
@@ -1056,7 +1111,7 @@ class RT_DETR_v1(DetectionBaseNet):
1056
1111
  mask_size = feat.shape[-2:]
1057
1112
  m = F.interpolate(masks[None].float(), size=mask_size, mode="nearest").to(torch.bool)[0]
1058
1113
  else:
1059
- (B, _, H, W) = feat.size()
1114
+ B, _, H, W = feat.size()
1060
1115
  m = torch.zeros(B, H, W, dtype=torch.bool, device=x.device)
1061
1116
  mask_list.append(m)
1062
1117
 
@@ -1080,7 +1135,7 @@ class RT_DETR_v1(DetectionBaseNet):
1080
1135
  losses = self.compute_loss(encoder_features, spatial_shapes, level_start_index, targets, images, mask_list)
1081
1136
  else:
1082
1137
  # Inference path - no CDN
1083
- (out_bboxes, out_logits, _, _) = self.decoder(
1138
+ out_bboxes, out_logits, _, _ = self.decoder(
1084
1139
  encoder_features, spatial_shapes, level_start_index, padding_mask=mask_list
1085
1140
  )
1086
1141
  detections = self.postprocess_detections(out_logits[-1], out_bboxes[-1], images.image_sizes)
@@ -1100,6 +1155,7 @@ class RT_DETR_v1(DetectionBaseNet):
1100
1155
 
1101
1156
 
1102
1157
  registry.register_model_config(
1103
- "rt_detr_v1_s", RT_DETR_v1, config={"num_decoder_layers": 3, "expansion": 0.5, "depth_multiplier": 0.33}
1158
+ "rt_detr_v1_t", RT_DETR_v1, config={"num_decoder_layers": 3, "expansion": 0.5, "depth_multiplier": 0.33}
1104
1159
  )
1160
+ registry.register_model_config("rt_detr_v1_s", RT_DETR_v1, config={"num_decoder_layers": 3, "expansion": 0.5})
1105
1161
  registry.register_model_config("rt_detr_v1", RT_DETR_v1, config={"num_decoder_layers": 6})