birder 0.4.1__py3-none-any.whl → 0.4.4__py3-none-any.whl

This diff compares publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
Files changed (100)
  1. birder/__init__.py +2 -0
  2. birder/common/fs_ops.py +81 -1
  3. birder/common/training_cli.py +12 -2
  4. birder/common/training_utils.py +73 -12
  5. birder/data/collators/detection.py +3 -1
  6. birder/datahub/_lib.py +15 -6
  7. birder/datahub/evaluation.py +591 -0
  8. birder/eval/__init__.py +0 -0
  9. birder/eval/__main__.py +74 -0
  10. birder/eval/_embeddings.py +50 -0
  11. birder/eval/adversarial.py +315 -0
  12. birder/eval/benchmarks/__init__.py +0 -0
  13. birder/eval/benchmarks/awa2.py +357 -0
  14. birder/eval/benchmarks/bioscan5m.py +198 -0
  15. birder/eval/benchmarks/fishnet.py +318 -0
  16. birder/eval/benchmarks/flowers102.py +210 -0
  17. birder/eval/benchmarks/fungiclef.py +261 -0
  18. birder/eval/benchmarks/nabirds.py +202 -0
  19. birder/eval/benchmarks/newt.py +262 -0
  20. birder/eval/benchmarks/plankton.py +255 -0
  21. birder/eval/benchmarks/plantdoc.py +259 -0
  22. birder/eval/benchmarks/plantnet.py +252 -0
  23. birder/eval/classification.py +235 -0
  24. birder/eval/methods/__init__.py +0 -0
  25. birder/eval/methods/ami.py +78 -0
  26. birder/eval/methods/knn.py +71 -0
  27. birder/eval/methods/linear.py +152 -0
  28. birder/eval/methods/mlp.py +178 -0
  29. birder/eval/methods/simpleshot.py +100 -0
  30. birder/eval/methods/svm.py +92 -0
  31. birder/inference/classification.py +23 -2
  32. birder/inference/detection.py +35 -15
  33. birder/net/_vit_configs.py +5 -0
  34. birder/net/cait.py +3 -3
  35. birder/net/coat.py +3 -3
  36. birder/net/cswin_transformer.py +2 -1
  37. birder/net/deit.py +1 -1
  38. birder/net/deit3.py +1 -1
  39. birder/net/detection/__init__.py +2 -0
  40. birder/net/detection/base.py +41 -18
  41. birder/net/detection/deformable_detr.py +74 -50
  42. birder/net/detection/detr.py +29 -26
  43. birder/net/detection/efficientdet.py +42 -25
  44. birder/net/detection/faster_rcnn.py +53 -21
  45. birder/net/detection/fcos.py +42 -23
  46. birder/net/detection/lw_detr.py +1204 -0
  47. birder/net/detection/plain_detr.py +60 -47
  48. birder/net/detection/retinanet.py +47 -35
  49. birder/net/detection/rt_detr_v1.py +49 -46
  50. birder/net/detection/rt_detr_v2.py +95 -102
  51. birder/net/detection/ssd.py +47 -31
  52. birder/net/detection/ssdlite.py +2 -2
  53. birder/net/detection/yolo_v2.py +33 -18
  54. birder/net/detection/yolo_v3.py +35 -33
  55. birder/net/detection/yolo_v4.py +35 -20
  56. birder/net/detection/yolo_v4_tiny.py +1 -2
  57. birder/net/edgevit.py +3 -3
  58. birder/net/efficientvit_msft.py +1 -1
  59. birder/net/flexivit.py +1 -1
  60. birder/net/hiera.py +44 -67
  61. birder/net/hieradet.py +2 -2
  62. birder/net/maxvit.py +2 -2
  63. birder/net/mim/fcmae.py +2 -2
  64. birder/net/mim/mae_hiera.py +9 -16
  65. birder/net/mnasnet.py +2 -2
  66. birder/net/nextvit.py +4 -4
  67. birder/net/resnext.py +2 -2
  68. birder/net/rope_deit3.py +2 -2
  69. birder/net/rope_flexivit.py +2 -2
  70. birder/net/rope_vit.py +2 -2
  71. birder/net/simple_vit.py +1 -1
  72. birder/net/squeezenet.py +1 -1
  73. birder/net/ssl/capi.py +32 -25
  74. birder/net/ssl/dino_v2.py +12 -15
  75. birder/net/ssl/franca.py +26 -19
  76. birder/net/van.py +2 -2
  77. birder/net/vit.py +21 -3
  78. birder/net/vit_parallel.py +1 -1
  79. birder/net/vit_sam.py +62 -16
  80. birder/net/xcit.py +1 -1
  81. birder/ops/msda.py +46 -16
  82. birder/scripts/benchmark.py +35 -8
  83. birder/scripts/predict.py +14 -1
  84. birder/scripts/predict_detection.py +7 -1
  85. birder/scripts/train.py +27 -11
  86. birder/scripts/train_capi.py +13 -10
  87. birder/scripts/train_detection.py +18 -7
  88. birder/scripts/train_franca.py +10 -2
  89. birder/scripts/train_kd.py +28 -11
  90. birder/tools/adversarial.py +5 -0
  91. birder/tools/convert_model.py +101 -43
  92. birder/tools/quantize_model.py +33 -16
  93. birder/version.py +1 -1
  94. {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/METADATA +17 -10
  95. {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/RECORD +99 -75
  96. {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/WHEEL +1 -1
  97. birder/scripts/evaluate.py +0 -176
  98. {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/entry_points.txt +0 -0
  99. {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/licenses/LICENSE +0 -0
  100. {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/top_level.txt +0 -0
@@ -147,7 +147,7 @@ class MultiScaleDeformableAttention(nn.Module):
  param.requires_grad_(False)

  def reset_parameters(self) -> None:
- nn.init.constant_(self.sampling_offsets.weight, 0.0)
+ nn.init.zeros_(self.sampling_offsets.weight)
  thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
  grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
  grid_init = grid_init / grid_init.abs().max(-1, keepdim=True)[0]
@@ -158,25 +158,27 @@ class MultiScaleDeformableAttention(nn.Module):
  with torch.no_grad():
  self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))

- nn.init.constant_(self.attention_weights.weight, 0.0)
- nn.init.constant_(self.attention_weights.bias, 0.0)
+ nn.init.zeros_(self.attention_weights.weight)
+ nn.init.zeros_(self.attention_weights.bias)
  nn.init.xavier_uniform_(self.value_proj.weight)
- nn.init.constant_(self.value_proj.bias, 0.0)
+ nn.init.zeros_(self.value_proj.bias)
  nn.init.xavier_uniform_(self.output_proj.weight)
- nn.init.constant_(self.output_proj.bias, 0.0)
+ nn.init.zeros_(self.output_proj.bias)

+ # pylint: disable=too-many-locals
  def forward(
  self,
  query: torch.Tensor,
  reference_points: torch.Tensor,
  input_flatten: torch.Tensor,
  input_spatial_shapes: torch.Tensor,
+ src_shapes: list[list[int]],
  input_level_start_index: torch.Tensor,
  input_padding_mask: Optional[torch.Tensor] = None,
  ) -> torch.Tensor:
- N, num_queries, _ = query.size()
+ num_queries = query.size(1)
  N, sequence_length, _ = input_flatten.size()
- assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == sequence_length
+ # assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == sequence_length

  value = self.value_proj(input_flatten)
  if input_padding_mask is not None:
@@ -231,7 +233,7 @@ class MultiScaleDeformableAttention(nn.Module):

  if self.method == "discrete":
  output = self._forward_fallback(
- value, input_spatial_shapes, sampling_locations, attention_weights, method="discrete"
+ value, input_spatial_shapes, src_shapes, sampling_locations, attention_weights, method="discrete"
  )
  else:
  if self.uniform_points is True:
@@ -245,10 +247,11 @@ class MultiScaleDeformableAttention(nn.Module):
  sampling_locations,
  attention_weights,
  self.im2col_step,
+ src_shapes,
  )
  else:
  output = self._forward_fallback(
- value, input_spatial_shapes, sampling_locations, attention_weights, method="default"
+ value, input_spatial_shapes, src_shapes, sampling_locations, attention_weights, method="default"
  )

  output = self.output_proj(output)
@@ -258,6 +261,7 @@ class MultiScaleDeformableAttention(nn.Module):
  self,
  value: torch.Tensor,
  spatial_shapes: torch.Tensor,
+ src_shapes: list[list[int]],
  sampling_locations: torch.Tensor,
  attention_weights: torch.Tensor,
  method: str = "default",
@@ -272,8 +276,7 @@ class MultiScaleDeformableAttention(nn.Module):
  sampling_locations_list = sampling_grids.split(self.num_points, dim=-2)

  sampling_value_list = []
- spatial_shapes_list: list[list[int]] = spatial_shapes.tolist()
- for level, (H, W) in enumerate(spatial_shapes_list):
+ for level, (H, W) in enumerate(src_shapes):
  value_l = value_list[level].reshape(B * n_heads, head_dim, H, W)
  sampling_grid_l = sampling_locations_list[level]
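The new src_shapes argument carries the per-level feature map sizes as plain Python ints, so the fallback path no longer calls .tolist() on a shapes tensor inside forward (a host sync that also gets in the way of tracing and compilation). A minimal, self-contained sketch of the per-level reshape this enables; all sizes below are invented for illustration:

    import torch

    # Invented toy sizes, not the library's defaults
    B, n_heads, head_dim = 2, 8, 32
    src_shapes = [[64, 64], [32, 32]]  # [H, W] per feature level, as plain ints
    seq_len = sum(h * w for h, w in src_shapes)
    value = torch.randn(B * n_heads, seq_len, head_dim)

    # Split the flattened multi-scale sequence back into per-level maps without
    # reading any shape out of a tensor at runtime
    value_list = value.split([h * w for h, w in src_shapes], dim=1)
    for level, (H, W) in enumerate(src_shapes):
        value_l = value_list[level].transpose(1, 2).reshape(B * n_heads, head_dim, H, W)
        # value_l is (B * n_heads, head_dim, H, W), ready for grid sampling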
@@ -361,21 +364,21 @@ class TransformerDecoderLayer(nn.Module):
  reference_points: torch.Tensor,
  src: torch.Tensor,
  src_spatial_shapes: torch.Tensor,
+ src_shapes: list[list[int]],
  level_start_index: torch.Tensor,
  src_padding_mask: Optional[torch.Tensor],
  self_attn_mask: Optional[torch.Tensor] = None,
  ) -> torch.Tensor:
  # Self attention
- q = tgt + query_pos
- k = tgt + query_pos
+ q_k = tgt + query_pos

- tgt2 = self.self_attn(q, k, tgt, attn_mask=self_attn_mask)
+ tgt2 = self.self_attn(q_k, q_k, tgt, attn_mask=self_attn_mask)
  tgt = tgt + self.dropout(tgt2)
  tgt = self.norm1(tgt)

  # Cross attention
  tgt2 = self.cross_attn(
- tgt + query_pos, reference_points, src, src_spatial_shapes, level_start_index, src_padding_mask
+ tgt + query_pos, reference_points, src, src_spatial_shapes, src_shapes, level_start_index, src_padding_mask
  )
  tgt = tgt + self.dropout(tgt2)
  tgt = self.norm2(tgt)
@@ -526,18 +529,18 @@ class RT_DETRDecoder(nn.Module):

  # Gather reference points
  reference_points_unact = enc_outputs_coord_unact.gather(
- dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_coord_unact.shape[-1])
+ dim=1, index=topk_ind.unsqueeze(-1).expand(-1, -1, enc_outputs_coord_unact.shape[-1])
  )

  enc_topk_bboxes = reference_points_unact.sigmoid()

  # Gather encoder logits for loss computation
  enc_topk_logits = enc_outputs_class.gather(
- dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_class.shape[-1])
+ dim=1, index=topk_ind.unsqueeze(-1).expand(-1, -1, enc_outputs_class.shape[-1])
  )

  # Extract region features
- target = output_memory.gather(dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, output_memory.shape[-1]))
+ target = output_memory.gather(dim=1, index=topk_ind.unsqueeze(-1).expand(-1, -1, output_memory.shape[-1]))
  target = target.detach()

  return (target, reference_points_unact.detach(), enc_topk_bboxes, enc_topk_logits)
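Several hunks above and below replace repeat(...) with expand(...) when building gather indices. expand returns a broadcasted view instead of materializing a copy, and torch.gather accepts such a view, so the result is unchanged while an allocation per call is avoided. A small self-contained example with invented sizes:

    import torch

    memory = torch.randn(2, 300, 256)           # (batch, tokens, dim), invented sizes
    topk_ind = torch.randint(0, 300, (2, 100))  # (batch, k) indices of selected tokens

    # expand() broadcasts the index without copying; repeat() would allocate a new tensor
    idx = topk_ind.unsqueeze(-1).expand(-1, -1, memory.shape[-1])
    selected = memory.gather(dim=1, index=idx)  # (2, 100, 256)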
@@ -551,6 +554,7 @@ class RT_DETRDecoder(nn.Module):
  denoising_bbox_unact: Optional[torch.Tensor] = None,
  attn_mask: Optional[torch.Tensor] = None,
  padding_mask: Optional[list[torch.Tensor]] = None,
+ return_intermediates: bool = True,
  ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
  memory = []
  mask_flatten = []
@@ -578,18 +582,19 @@ class RT_DETRDecoder(nn.Module):
  level_start_index_tensor = torch.tensor(level_start_index, dtype=torch.long, device=memory.device)

  # Decoder forward
- out_bboxes = []
- out_logits = []
+ bboxes_list: list[torch.Tensor] = []
+ logits_list: list[torch.Tensor] = []
  reference_points = init_ref_points_unact.sigmoid()
  for decoder_layer, bbox_head, class_head in zip(self.layers, self.bbox_embed, self.class_embed):
  query_pos = self.query_pos_head(reference_points)
- reference_points_input = reference_points.unsqueeze(2).repeat(1, 1, len(spatial_shapes), 1)
+ reference_points_input = reference_points.unsqueeze(2).expand(-1, -1, len(spatial_shapes), -1)
  target = decoder_layer(
  target,
  query_pos,
  reference_points_input,
  memory,
  spatial_shapes_tensor,
+ spatial_shapes,
  level_start_index_tensor,
  memory_padding_mask,
  attn_mask,
@@ -602,14 +607,19 @@ class RT_DETRDecoder(nn.Module):
  # Classification
  class_logits = class_head(target)

- out_bboxes.append(new_reference_points)
- out_logits.append(class_logits)
+ if return_intermediates is True:
+ bboxes_list.append(new_reference_points)
+ logits_list.append(class_logits)

  # Update reference points for next layer
  reference_points = new_reference_points.detach()

- out_bboxes = torch.stack(out_bboxes)
- out_logits = torch.stack(out_logits)
+ if return_intermediates is True:
+ out_bboxes = torch.stack(bboxes_list)
+ out_logits = torch.stack(logits_list)
+ else:
+ out_bboxes = new_reference_points
+ out_logits = class_logits

  return (out_bboxes, out_logits, enc_topk_bboxes, enc_topk_logits)
@@ -675,7 +685,7 @@ class RT_DETR_v2(DetectionBaseNet):
  self.decoder = RT_DETRDecoder(
  hidden_dim=hidden_dim,
  num_classes=self.num_classes,
- num_queries=num_queries,
+ num_queries=self.num_queries,
  num_decoder_layers=num_decoder_layers,
  num_levels=self.num_levels,
  num_heads=num_heads,
@@ -744,20 +754,32 @@ class RT_DETR_v2(DetectionBaseNet):
  for param in self.denoising_class_embed.parameters():
  param.requires_grad_(True)

- def _get_src_permutation_idx(self, indices: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+ @staticmethod
+ def _get_src_permutation_idx(indices: list[tuple[torch.Tensor, torch.Tensor]]) -> tuple[torch.Tensor, torch.Tensor]:
  batch_idx = torch.concat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
  src_idx = torch.concat([src for (src, _) in indices])
  return (batch_idx, src_idx)

- def _class_loss(
+ def _compute_layer_losses(
  self,
  cls_logits: torch.Tensor,
  box_output: torch.Tensor,
  targets: list[dict[str, torch.Tensor]],
- indices: list[torch.Tensor],
+ indices: list[tuple[torch.Tensor, torch.Tensor]],
  num_boxes: float,
- ) -> torch.Tensor:
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  idx = self._get_src_permutation_idx(indices)
+
+ src_boxes = box_output[idx]
+ target_boxes = torch.concat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)
+
+ src_boxes_xyxy = box_ops.box_convert(src_boxes, in_fmt="cxcywh", out_fmt="xyxy")
+ target_boxes_xyxy = box_ops.box_convert(target_boxes, in_fmt="cxcywh", out_fmt="xyxy")
+
+ # IoU for varifocal loss (class loss)
+ ious = torch.diag(box_ops.box_iou(src_boxes_xyxy, target_boxes_xyxy)).detach()
+
+ # Classification loss
  target_classes_o = torch.concat([t["labels"][J] for t, (_, J) in zip(targets, indices)], dim=0)
  target_classes = torch.full(cls_logits.shape[:2], self.num_classes, dtype=torch.int64, device=cls_logits.device)
  target_classes[idx] = target_classes_o
@@ -771,15 +793,6 @@ class RT_DETR_v2(DetectionBaseNet):
  target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1)
  target_classes_onehot = target_classes_onehot[:, :, :-1]

- src_boxes = box_output[idx]
- target_boxes = torch.concat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)
- ious = torch.diag(
- box_ops.box_iou(
- box_ops.box_convert(src_boxes, in_fmt="cxcywh", out_fmt="xyxy"),
- box_ops.box_convert(target_boxes, in_fmt="cxcywh", out_fmt="xyxy"),
- )
- ).detach()
-
  target_score_o = torch.zeros(cls_logits.shape[:2], dtype=cls_logits.dtype, device=cls_logits.device)
  target_score_o[idx] = ious.to(cls_logits.dtype)
  target_score = target_score_o.unsqueeze(-1) * target_classes_onehot
@@ -787,31 +800,13 @@ class RT_DETR_v2(DetectionBaseNet):
  loss = varifocal_loss(cls_logits, target_score, target_classes_onehot, alpha=0.75, gamma=2.0)
  loss_ce = (loss.mean(1).sum() / num_boxes) * cls_logits.shape[1]

- return loss_ce
-
- def _box_loss(
- self,
- box_output: torch.Tensor,
- targets: list[dict[str, torch.Tensor]],
- indices: list[torch.Tensor],
- num_boxes: float,
- ) -> tuple[torch.Tensor, torch.Tensor]:
- idx = self._get_src_permutation_idx(indices)
- src_boxes = box_output[idx]
- target_boxes = torch.concat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)
-
- loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction="none")
- loss_bbox = loss_bbox.sum() / num_boxes
+ # Box L1 loss
+ loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction="none").sum() / num_boxes

- loss_giou = 1 - torch.diag(
- box_ops.generalized_box_iou(
- box_ops.box_convert(src_boxes, in_fmt="cxcywh", out_fmt="xyxy"),
- box_ops.box_convert(target_boxes, in_fmt="cxcywh", out_fmt="xyxy"),
- )
- )
- loss_giou = loss_giou.sum() / num_boxes
+ # GIoU loss
+ loss_giou = (1 - torch.diag(box_ops.generalized_box_iou(src_boxes_xyxy, target_boxes_xyxy))).sum() / num_boxes

- return (loss_bbox, loss_giou)
+ return (loss_ce, loss_bbox, loss_giou)

  def _compute_denoising_loss(
  self,
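The refactor above merges _class_loss and _box_loss into a single _compute_layer_losses, so the matched boxes are converted to xyxy once and reused for both the IoU targets and the GIoU term. A stand-alone sketch of the shared computation, using dummy boxes rather than values from the library:

    import torch
    import torch.nn.functional as F
    from torchvision.ops import boxes as box_ops

    # Matched predictions and targets in normalized cxcywh format (dummy values)
    src_boxes = torch.tensor([[0.50, 0.50, 0.20, 0.30], [0.30, 0.60, 0.10, 0.10]])
    target_boxes = torch.tensor([[0.52, 0.48, 0.20, 0.30], [0.30, 0.60, 0.12, 0.10]])
    num_boxes = 2.0

    # Convert once, reuse for both the IoU targets and the GIoU loss
    src_xyxy = box_ops.box_convert(src_boxes, in_fmt="cxcywh", out_fmt="xyxy")
    tgt_xyxy = box_ops.box_convert(target_boxes, in_fmt="cxcywh", out_fmt="xyxy")

    ious = torch.diag(box_ops.box_iou(src_xyxy, tgt_xyxy)).detach()  # IoU targets for the varifocal loss
    loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction="none").sum() / num_boxes
    loss_giou = (1 - torch.diag(box_ops.generalized_box_iou(src_xyxy, tgt_xyxy))).sum() / num_boxes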
@@ -846,11 +841,9 @@ class RT_DETR_v2(DetectionBaseNet):
  )
  )

- loss_ce = self._class_loss(
+ loss_ce, loss_bbox, loss_giou = self._compute_layer_losses(
  dn_out_logits[layer_idx], dn_out_bboxes[layer_idx], targets, indices, dn_num_boxes
  )
- loss_bbox, loss_giou = self._box_loss(dn_out_bboxes[layer_idx], targets, indices, dn_num_boxes)
-
  loss_ce_list.append(loss_ce)
  loss_bbox_list.append(loss_bbox)
  loss_giou_list.append(loss_giou)
@@ -861,9 +854,7 @@ class RT_DETR_v2(DetectionBaseNet):

  return (loss_ce_dn, loss_bbox_dn, loss_giou_dn)

- @torch.jit.unused # type: ignore[untyped-decorator]
- @torch.compiler.disable() # type: ignore[untyped-decorator]
- def _compute_loss_from_outputs( # pylint: disable=too-many-locals
+ def _compute_loss_from_outputs(
  self,
  targets: list[dict[str, torch.Tensor]],
  out_bboxes: torch.Tensor,
@@ -880,7 +871,7 @@ class RT_DETR_v2(DetectionBaseNet):
  if training_utils.is_dist_available_and_initialized() is True:
  torch.distributed.all_reduce(num_boxes)

- num_boxes = torch.clamp(num_boxes / training_utils.get_world_size(), min=1).item()
+ num_boxes = torch.clamp(num_boxes / training_utils.get_world_size(), min=1)

  loss_ce_list = []
  loss_bbox_list = []
@@ -889,19 +880,21 @@ class RT_DETR_v2(DetectionBaseNet):
  # Decoder losses (all layers)
  for layer_idx in range(out_logits.shape[0]):
  indices = self.matcher(out_logits[layer_idx], out_bboxes[layer_idx], targets)
- loss_ce = self._class_loss(out_logits[layer_idx], out_bboxes[layer_idx], targets, indices, num_boxes)
- loss_bbox, loss_giou = self._box_loss(out_bboxes[layer_idx], targets, indices, num_boxes)
+ loss_ce, loss_bbox, loss_giou = self._compute_layer_losses(
+ out_logits[layer_idx], out_bboxes[layer_idx], targets, indices, num_boxes
+ )
  loss_ce_list.append(loss_ce)
  loss_bbox_list.append(loss_bbox)
  loss_giou_list.append(loss_giou)

  # Encoder auxiliary loss
  enc_indices = self.matcher(enc_topk_logits, enc_topk_bboxes, targets)
- loss_ce_enc = self._class_loss(enc_topk_logits, enc_topk_bboxes, targets, enc_indices, num_boxes)
- loss_bbox_enc, loss_giou_enc = self._box_loss(enc_topk_bboxes, targets, enc_indices, num_boxes)
- loss_ce_list.append(loss_ce_enc)
- loss_bbox_list.append(loss_bbox_enc)
- loss_giou_list.append(loss_giou_enc)
+ loss_ce, loss_bbox, loss_giou = self._compute_layer_losses(
+ enc_topk_logits, enc_topk_bboxes, targets, enc_indices, num_boxes
+ )
+ loss_ce_list.append(loss_ce)
+ loss_bbox_list.append(loss_bbox)
+ loss_giou_list.append(loss_giou)

  loss_ce = torch.stack(loss_ce_list).sum() # VFL weight is 1
  loss_bbox = torch.stack(loss_bbox_list).sum() * 5
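Note also that num_boxes now stays a tensor instead of being collapsed with .item(): the scalar read forces a device-to-host sync and, under torch.compile, typically causes a graph break, while dividing by a 0-dim tensor does not. A tiny illustration with made-up numbers:

    import torch

    world_size = 2  # made-up value
    num_boxes = torch.tensor(7.0)

    # Keeping the divisor a 0-dim tensor (no .item()) avoids the host sync and
    # lets the expression stay inside a compiled/traced graph
    num_boxes = torch.clamp(num_boxes / world_size, min=1)
    loss_bbox = torch.randn(7, 4).abs().sum() / num_boxes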
@@ -935,11 +928,11 @@ class RT_DETR_v2(DetectionBaseNet):
  images: Any,
  masks: Optional[list[torch.Tensor]] = None,
  ) -> dict[str, torch.Tensor]:
- device = encoder_features[0].device
  for idx, target in enumerate(targets):
  boxes = target["boxes"]
  boxes = box_ops.box_convert(boxes, in_fmt="xyxy", out_fmt="cxcywh")
- boxes = boxes / torch.tensor(images.image_sizes[idx][::-1] * 2, dtype=torch.float32, device=device)
+ scale = images.image_sizes[idx].flip(0).repeat(2).float() # flip to [W, H], repeat to [W, H, W, H]
+ boxes = boxes / scale
  targets[idx]["boxes"] = boxes
  targets[idx]["labels"] = target["labels"] - 1 # No background

@@ -972,7 +965,7 @@ class RT_DETR_v2(DetectionBaseNet):
  return losses

  def postprocess_detections(
- self, class_logits: torch.Tensor, box_regression: torch.Tensor, image_shapes: list[tuple[int, int]]
+ self, class_logits: torch.Tensor, box_regression: torch.Tensor, image_sizes: torch.Tensor
  ) -> list[dict[str, torch.Tensor]]:
  prob = class_logits.sigmoid()
  topk_values, topk_indexes = torch.topk(prob.view(class_logits.shape[0], -1), k=self.decoder.num_queries, dim=1)
@@ -981,14 +974,12 @@ class RT_DETR_v2(DetectionBaseNet):
  labels = topk_indexes % class_logits.shape[2]
  labels += 1 # Background offset

- target_sizes = torch.tensor(image_shapes, device=class_logits.device)
-
  # Convert to [x0, y0, x1, y1] format
  boxes = box_ops.box_convert(box_regression, in_fmt="cxcywh", out_fmt="xyxy")
- boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4))
+ boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).expand(-1, -1, 4))

  # Convert from relative [0, 1] to absolute [0, height] coordinates
- img_h, img_w = target_sizes.unbind(1)
+ img_h, img_w = image_sizes.unbind(1)
  scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
  boxes = boxes * scale_fct[:, None, :]
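postprocess_detections now takes image_sizes as a tensor of per-image [H, W] pairs rather than a Python list, so no tensor needs to be built inside the method. A self-contained sketch of the normalized cxcywh to absolute xyxy conversion it performs, with invented sizes:

    import torch
    from torchvision.ops import boxes as box_ops

    # One image, two predicted boxes in normalized cxcywh
    box_regression = torch.tensor([[[0.50, 0.50, 0.40, 0.20], [0.25, 0.25, 0.10, 0.10]]])
    image_sizes = torch.tensor([[480, 640]])  # per-image [H, W]

    boxes = box_ops.box_convert(box_regression, in_fmt="cxcywh", out_fmt="xyxy")
    img_h, img_w = image_sizes.unbind(1)
    scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)  # [W, H, W, H]
    abs_boxes = boxes * scale_fct[:, None, :]
    # abs_boxes -> [[[192., 192., 448., 288.], [128., 96., 192., 144.]]]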
@@ -1024,32 +1015,34 @@ class RT_DETR_v2(DetectionBaseNet):

  return (None, None, None, None)

+ def forward_net(
+ self, x: torch.Tensor, masks: Optional[torch.Tensor]
+ ) -> tuple[list[torch.Tensor], Optional[list[torch.Tensor]]]:
+ features: dict[str, torch.Tensor] = self.backbone.detection_features(x)
+ feature_list = list(features.values())
+
+ mask_list: Optional[list[torch.Tensor]] = None
+ if masks is not None:
+ mask_list = []
+ for feat in feature_list:
+ m = F.interpolate(masks[None].float(), size=feat.shape[-2:], mode="nearest").to(torch.bool)[0]
+ mask_list.append(m)
+
+ encoder_features = self.encoder(feature_list, masks=mask_list)
+
+ return (encoder_features, mask_list)
+
  def forward(
  self,
  x: torch.Tensor,
  targets: Optional[list[dict[str, torch.Tensor]]] = None,
  masks: Optional[torch.Tensor] = None,
- image_sizes: Optional[list[list[int]]] = None,
+ image_sizes: Optional[list[tuple[int, int]]] = None,
  ) -> tuple[list[dict[str, torch.Tensor]], dict[str, torch.Tensor]]:
  self._input_check(targets)
  images = self._to_img_list(x, image_sizes)

- # Backbone features
- features: dict[str, torch.Tensor] = self.backbone.detection_features(x)
- feature_list = list(features.values())
-
- # Hybrid encoder
- mask_list: list[torch.Tensor] = []
- for feat in feature_list:
- if masks is not None:
- mask_size = feat.shape[-2:]
- m = F.interpolate(masks[None].float(), size=mask_size, mode="nearest").to(torch.bool)[0]
- else:
- B, _, H, W = feat.size()
- m = torch.zeros(B, H, W, dtype=torch.bool, device=x.device)
- mask_list.append(m)
-
- encoder_features = self.encoder(feature_list, masks=mask_list)
+ encoder_features, mask_list = self.forward_net(x, masks)

  # Prepare spatial shapes and level start index
  spatial_shapes: list[list[int]] = []
@@ -1070,9 +1063,9 @@ class RT_DETR_v2(DetectionBaseNet):
  else:
  # Inference path - no CDN
  out_bboxes, out_logits, _, _ = self.decoder(
- encoder_features, spatial_shapes, level_start_index, padding_mask=mask_list
+ encoder_features, spatial_shapes, level_start_index, padding_mask=mask_list, return_intermediates=False
  )
- detections = self.postprocess_detections(out_logits[-1], out_bboxes[-1], images.image_sizes)
+ detections = self.postprocess_detections(out_logits, out_bboxes, images.image_sizes)

  return (detections, losses)
@@ -30,6 +30,7 @@ from birder.net.detection.base import BoxCoder
  from birder.net.detection.base import DetectionBaseNet
  from birder.net.detection.base import ImageList
  from birder.net.detection.base import Matcher
+ from birder.net.detection.base import clip_boxes_to_image


  class SSDMatcher(Matcher):
@@ -303,6 +304,12 @@ class SSD(DetectionBaseNet):
  topk_candidates = 400
  positive_fraction = 0.25

+ self.score_thresh = score_thresh
+ self.nms_thresh = nms_thresh
+ self.detections_per_img = detections_per_img
+ self.topk_candidates = topk_candidates
+ self.neg_to_pos_ratio = (1.0 - positive_fraction) / positive_fraction
+
  self.backbone.return_channels = self.backbone.return_channels[-2:]
  self.backbone.return_stages = self.backbone.return_stages[-2:]
  self.extra_blocks = nn.ModuleList(
@@ -325,11 +332,8 @@ class SSD(DetectionBaseNet):
  self.head = SSDHead(self.backbone.return_channels + [512, 256, 256, 256], num_anchors, self.num_classes)
  self.proposal_matcher = SSDMatcher(iou_thresh)

- self.score_thresh = score_thresh
- self.nms_thresh = nms_thresh
- self.detections_per_img = detections_per_img
- self.topk_candidates = topk_candidates
- self.neg_to_pos_ratio = (1.0 - positive_fraction) / positive_fraction
+ if self.export_mode is False:
+ self.forward = torch.compiler.disable(recursive=False)(self.forward) # type: ignore[method-assign]

  def reset_classifier(self, num_classes: int) -> None:
  self.num_classes = num_classes + 1
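torch.compiler.disable(recursive=False) wraps the bound forward so that only this frame is skipped by torch.compile, while modules it calls can still be compiled; the wrapping is skipped in export mode. A minimal sketch of the same pattern on an invented toy module:

    import torch
    from torch import nn

    class ToyDetector(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.head = nn.Linear(16, 4)
            # Skip compiling only this forward frame; recursive=False still
            # allows self.head to be compiled when called from it
            self.forward = torch.compiler.disable(recursive=False)(self.forward)  # type: ignore[method-assign]

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            if x.shape[0] > 1:  # branch that is awkward under tracing
                x = x.flip(0)
            return self.head(x)

    model = ToyDetector()
    out = model(torch.randn(2, 16))  # runs eagerly; under torch.compile this frame is skipped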
@@ -348,6 +352,8 @@ class SSD(DetectionBaseNet):
  param.requires_grad_(True)

  # pylint: disable=too-many-locals
+ @torch.jit.unused # type: ignore[untyped-decorator]
+ @torch.compiler.disable() # type: ignore[untyped-decorator]
  def compute_loss(
  self,
  targets: list[dict[str, torch.Tensor]],
@@ -423,7 +429,7 @@ class SSD(DetectionBaseNet):
  self,
  head_outputs: dict[str, torch.Tensor],
  image_anchors: list[torch.Tensor],
- image_shapes: list[tuple[int, int]],
+ image_sizes: torch.Tensor,
  ) -> list[dict[str, torch.Tensor]]:
  bbox_regression = head_outputs["bbox_regression"]
  pred_scores = F.softmax(head_outputs["cls_logits"], dim=-1)
@@ -431,11 +437,10 @@ class SSD(DetectionBaseNet):
  num_classes = pred_scores.size(-1)
  device = pred_scores.device
  detections: list[dict[str, torch.Tensor]] = []
- for boxes, scores, anchors, image_shape in zip(bbox_regression, pred_scores, image_anchors, image_shapes):
+ for boxes, scores, anchors, image_shape in zip(bbox_regression, pred_scores, image_anchors, image_sizes):
  boxes = self.box_coder.decode_single(boxes, anchors)
- boxes = box_ops.clip_boxes_to_image(boxes, image_shape)
+ boxes = clip_boxes_to_image(boxes, image_shape)

- list_empty = True
  image_boxes_list = []
  image_scores_list = []
  image_labels_list = []
@@ -447,51 +452,62 @@ class SSD(DetectionBaseNet):
  box = boxes[keep_idxs]

  # Keep only topk scoring predictions
- num_topk = min(self.topk_candidates, int(score.size(0)))
+ num_topk = min(self.topk_candidates, score.size(0))
  score, idxs = score.topk(num_topk)
  box = box[idxs]
- if len(box) == 0 and list_empty is False:
- continue

  image_boxes_list.append(box)
  image_scores_list.append(score)
  image_labels_list.append(torch.full_like(score, fill_value=label, dtype=torch.int64, device=device))
- list_empty = False

  image_boxes = torch.concat(image_boxes_list, dim=0)
  image_scores = torch.concat(image_scores_list, dim=0)
  image_labels = torch.concat(image_labels_list, dim=0)

- # Non-maximum suppression
- keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh)
- keep = keep[: self.detections_per_img]
-
- detections.append(
- {
- "boxes": image_boxes[keep],
- "scores": image_scores[keep],
- "labels": image_labels[keep],
- }
- )
+ if self.export_mode is False:
+ # Non-maximum suppression
+ keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh)
+ keep = keep[: self.detections_per_img]
+
+ detections.append(
+ {
+ "boxes": image_boxes[keep],
+ "scores": image_scores[keep],
+ "labels": image_labels[keep],
+ }
+ )
+ else:
+ detections.append(
+ {
+ "boxes": image_boxes,
+ "scores": image_scores,
+ "labels": image_labels,
+ }
+ )

  return detections

+ def forward_net(self, x: torch.Tensor) -> tuple[list[torch.Tensor], dict[str, torch.Tensor]]:
+ features = self.backbone.detection_features(x)
+ feature_list = list(features.values())
+ for extra_block in self.extra_blocks:
+ feature_list.append(extra_block(feature_list[-1]))
+
+ head_outputs = self.head(feature_list)
+
+ return (feature_list, head_outputs)
+
  def forward(
  self,
  x: torch.Tensor,
  targets: Optional[list[dict[str, torch.Tensor]]] = None,
  masks: Optional[torch.Tensor] = None,
- image_sizes: Optional[list[list[int]]] = None,
+ image_sizes: Optional[list[tuple[int, int]]] = None,
  ) -> tuple[list[dict[str, torch.Tensor]], dict[str, torch.Tensor]]:
  self._input_check(targets)
  images = self._to_img_list(x, image_sizes)

- features = self.backbone.detection_features(x)
- feature_list = list(features.values())
- for extra_block in self.extra_blocks:
- feature_list.append(extra_block(feature_list[-1]))
-
- head_outputs = self.head(feature_list)
+ feature_list, head_outputs = self.forward_net(x)
  anchors = self.anchor_generator(images, feature_list)

  losses = {}
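In export mode the per-image candidates above are returned without suppression, since batched_nms yields a data-dependent number of boxes, which exported graphs handle poorly; suppression is then left to the caller. For reference, a tiny self-contained batched_nms example with invented values:

    import torch
    from torchvision.ops import boxes as box_ops

    boxes = torch.tensor([[0.0, 0.0, 10.0, 10.0], [1.0, 1.0, 11.0, 11.0], [20.0, 20.0, 30.0, 30.0]])
    scores = torch.tensor([0.9, 0.8, 0.7])
    labels = torch.tensor([1, 1, 2])

    # Box 1 overlaps box 0 (same label) above the threshold and is suppressed;
    # box 2 has a different label and is kept regardless of overlap
    keep = box_ops.batched_nms(boxes, scores, labels, iou_threshold=0.5)
    # keep -> tensor([0, 2])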
@@ -50,7 +50,7 @@ class SSDLiteClassificationHead(SSDScoringHead):
  if isinstance(layer, nn.Conv2d):
  nn.init.xavier_uniform_(layer.weight)
  if layer.bias is not None:
- nn.init.constant_(layer.bias, 0.0)
+ nn.init.zeros_(layer.bias)

  super().__init__(cls_logits, num_classes)

@@ -79,7 +79,7 @@ class SSDLiteRegressionHead(SSDScoringHead):
  if isinstance(layer, nn.Conv2d):
  nn.init.xavier_uniform_(layer.weight)
  if layer.bias is not None:
- nn.init.constant_(layer.bias, 0.0)
+ nn.init.zeros_(layer.bias)

  super().__init__(bbox_reg, 4)