birder 0.4.1__py3-none-any.whl → 0.4.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (100)
  1. birder/__init__.py +2 -0
  2. birder/common/fs_ops.py +81 -1
  3. birder/common/training_cli.py +12 -2
  4. birder/common/training_utils.py +73 -12
  5. birder/data/collators/detection.py +3 -1
  6. birder/datahub/_lib.py +15 -6
  7. birder/datahub/evaluation.py +591 -0
  8. birder/eval/__init__.py +0 -0
  9. birder/eval/__main__.py +74 -0
  10. birder/eval/_embeddings.py +50 -0
  11. birder/eval/adversarial.py +315 -0
  12. birder/eval/benchmarks/__init__.py +0 -0
  13. birder/eval/benchmarks/awa2.py +357 -0
  14. birder/eval/benchmarks/bioscan5m.py +198 -0
  15. birder/eval/benchmarks/fishnet.py +318 -0
  16. birder/eval/benchmarks/flowers102.py +210 -0
  17. birder/eval/benchmarks/fungiclef.py +261 -0
  18. birder/eval/benchmarks/nabirds.py +202 -0
  19. birder/eval/benchmarks/newt.py +262 -0
  20. birder/eval/benchmarks/plankton.py +255 -0
  21. birder/eval/benchmarks/plantdoc.py +259 -0
  22. birder/eval/benchmarks/plantnet.py +252 -0
  23. birder/eval/classification.py +235 -0
  24. birder/eval/methods/__init__.py +0 -0
  25. birder/eval/methods/ami.py +78 -0
  26. birder/eval/methods/knn.py +71 -0
  27. birder/eval/methods/linear.py +152 -0
  28. birder/eval/methods/mlp.py +178 -0
  29. birder/eval/methods/simpleshot.py +100 -0
  30. birder/eval/methods/svm.py +92 -0
  31. birder/inference/classification.py +23 -2
  32. birder/inference/detection.py +35 -15
  33. birder/net/_vit_configs.py +5 -0
  34. birder/net/cait.py +3 -3
  35. birder/net/coat.py +3 -3
  36. birder/net/cswin_transformer.py +2 -1
  37. birder/net/deit.py +1 -1
  38. birder/net/deit3.py +1 -1
  39. birder/net/detection/__init__.py +2 -0
  40. birder/net/detection/base.py +41 -18
  41. birder/net/detection/deformable_detr.py +74 -50
  42. birder/net/detection/detr.py +29 -26
  43. birder/net/detection/efficientdet.py +42 -25
  44. birder/net/detection/faster_rcnn.py +53 -21
  45. birder/net/detection/fcos.py +42 -23
  46. birder/net/detection/lw_detr.py +1204 -0
  47. birder/net/detection/plain_detr.py +60 -47
  48. birder/net/detection/retinanet.py +47 -35
  49. birder/net/detection/rt_detr_v1.py +49 -46
  50. birder/net/detection/rt_detr_v2.py +95 -102
  51. birder/net/detection/ssd.py +47 -31
  52. birder/net/detection/ssdlite.py +2 -2
  53. birder/net/detection/yolo_v2.py +33 -18
  54. birder/net/detection/yolo_v3.py +35 -33
  55. birder/net/detection/yolo_v4.py +35 -20
  56. birder/net/detection/yolo_v4_tiny.py +1 -2
  57. birder/net/edgevit.py +3 -3
  58. birder/net/efficientvit_msft.py +1 -1
  59. birder/net/flexivit.py +1 -1
  60. birder/net/hiera.py +44 -67
  61. birder/net/hieradet.py +2 -2
  62. birder/net/maxvit.py +2 -2
  63. birder/net/mim/fcmae.py +2 -2
  64. birder/net/mim/mae_hiera.py +9 -16
  65. birder/net/mnasnet.py +2 -2
  66. birder/net/nextvit.py +4 -4
  67. birder/net/resnext.py +2 -2
  68. birder/net/rope_deit3.py +2 -2
  69. birder/net/rope_flexivit.py +2 -2
  70. birder/net/rope_vit.py +2 -2
  71. birder/net/simple_vit.py +1 -1
  72. birder/net/squeezenet.py +1 -1
  73. birder/net/ssl/capi.py +32 -25
  74. birder/net/ssl/dino_v2.py +12 -15
  75. birder/net/ssl/franca.py +26 -19
  76. birder/net/van.py +2 -2
  77. birder/net/vit.py +21 -3
  78. birder/net/vit_parallel.py +1 -1
  79. birder/net/vit_sam.py +62 -16
  80. birder/net/xcit.py +1 -1
  81. birder/ops/msda.py +46 -16
  82. birder/scripts/benchmark.py +35 -8
  83. birder/scripts/predict.py +14 -1
  84. birder/scripts/predict_detection.py +7 -1
  85. birder/scripts/train.py +27 -11
  86. birder/scripts/train_capi.py +13 -10
  87. birder/scripts/train_detection.py +18 -7
  88. birder/scripts/train_franca.py +10 -2
  89. birder/scripts/train_kd.py +28 -11
  90. birder/tools/adversarial.py +5 -0
  91. birder/tools/convert_model.py +101 -43
  92. birder/tools/quantize_model.py +33 -16
  93. birder/version.py +1 -1
  94. {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/METADATA +17 -10
  95. {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/RECORD +99 -75
  96. {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/WHEEL +1 -1
  97. birder/scripts/evaluate.py +0 -176
  98. {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/entry_points.txt +0 -0
  99. {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/licenses/LICENSE +0 -0
  100. {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/top_level.txt +0 -0
@@ -301,14 +301,11 @@ class GlobalDecoderLayer(nn.Module):
 
 
 class GlobalDecoder(nn.Module):
-    def __init__(
-        self, decoder_layer: nn.Module, num_layers: int, norm: nn.Module, return_intermediate: bool, d_model: int
-    ) -> None:
+    def __init__(self, decoder_layer: nn.Module, num_layers: int, norm: nn.Module, d_model: int) -> None:
         super().__init__()
         self.layers = _get_clones(decoder_layer, num_layers)
         self.num_layers = num_layers
         self.norm = norm
-        self.return_intermediate = return_intermediate
         self.d_model = d_model
 
         self.bbox_embed: Optional[nn.ModuleList] = None
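Note on the GlobalDecoder change above: whether intermediate decoder outputs are returned is no longer fixed at construction time; it becomes a per-call return_intermediates argument (see the following hunks), and the single-output path drops the extra unsqueeze(0) stacking dimension. A minimal sketch of the pattern, using a hypothetical TinyDecoder in place of the real module:

import torch
from torch import nn

class TinyDecoder(nn.Module):
    def __init__(self, num_layers: int, d_model: int) -> None:
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(num_layers)])

    def forward(self, x: torch.Tensor, return_intermediates: bool = True) -> torch.Tensor:
        intermediate = []
        for layer in self.layers:
            x = layer(x)
            if return_intermediates:
                intermediate.append(x)

        if return_intermediates:
            return torch.stack(intermediate)  # (num_layers, ...) for auxiliary losses

        return x  # last layer only, no extra leading dimension

decoder = TinyDecoder(num_layers=3, d_model=16)
train_out = decoder(torch.randn(2, 5, 16))                              # shape (3, 2, 5, 16)
infer_out = decoder(torch.randn(2, 5, 16), return_intermediates=False)  # shape (2, 5, 16)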
@@ -339,6 +336,7 @@ class GlobalDecoder(nn.Module):
         reference_points: torch.Tensor,
         spatial_shape: tuple[int, int],
         memory_key_padding_mask: Optional[torch.Tensor] = None,
+        return_intermediates: bool = True,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         output = tgt
         intermediate = []
@@ -364,14 +362,14 @@ class GlobalDecoder(nn.Module):
                 new_reference_points = new_reference_points.sigmoid()
                 reference_points = new_reference_points.detach()
 
-                if self.return_intermediate is True:
+                if return_intermediates is True:
                     intermediate.append(output_for_pred)
                     intermediate_reference_points.append(new_reference_points)
 
-            if self.return_intermediate is True:
+            if return_intermediates is True:
                 return torch.stack(intermediate), torch.stack(intermediate_reference_points)
 
-            return output_for_pred.unsqueeze(0), new_reference_points.unsqueeze(0)
+            return output_for_pred, new_reference_points
 
         for layer in self.layers:
             reference_points_input = reference_points.detach().clamp(0, 1)
@@ -388,14 +386,14 @@ class GlobalDecoder(nn.Module):
 
             output_for_pred = self.norm(output)
 
-            if self.return_intermediate is True:
+            if return_intermediates is True:
                 intermediate.append(output_for_pred)
                 intermediate_reference_points.append(reference_points)
 
-        if self.return_intermediate is True:
+        if return_intermediates is True:
             return torch.stack(intermediate), torch.stack(intermediate_reference_points)
 
-        return output_for_pred.unsqueeze(0), reference_points.unsqueeze(0)
+        return output_for_pred, reference_points
 
 
 class TransformerEncoderLayer(nn.Module):
@@ -467,7 +465,6 @@ class Plain_DETR(DetectionBaseNet):
         hidden_dim = 256
         num_heads = 8
         dropout = 0.0
-        return_intermediate = True
         dim_feedforward: int = self.config.get("dim_feedforward", 2048)
         num_encoder_layers: int = self.config["num_encoder_layers"]
         num_decoder_layers: int = self.config["num_decoder_layers"]
@@ -516,19 +513,18 @@ class Plain_DETR(DetectionBaseNet):
             decoder_layer,
             num_decoder_layers,
             decoder_norm,
-            return_intermediate=return_intermediate,
             d_model=hidden_dim,
         )
 
         self.class_embed = nn.Linear(hidden_dim, self.num_classes)
         self.bbox_embed = MLP(hidden_dim, [hidden_dim, hidden_dim, 4], activation_layer=nn.ReLU)
-        self.query_embed = nn.Embedding(self.num_queries, hidden_dim * 2)
+        self.query_embed = nn.Parameter(torch.empty(self.num_queries, hidden_dim * 2))
         self.reference_point_head = MLP(hidden_dim, [hidden_dim, hidden_dim, 4], activation_layer=nn.ReLU)
         self.input_proj = nn.Conv2d(
             self.backbone.return_channels[-1], hidden_dim, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)
         )
         self.pos_enc = PositionEmbeddingSine(hidden_dim // 2, normalize=True)
-        self.matcher = HungarianMatcher(cost_class=2, cost_bbox=5, cost_giou=2)
+        self.matcher = HungarianMatcher(cost_class=2.0, cost_bbox=5.0, cost_giou=2.0)
 
         if box_refine is True:
             self.class_embed = _get_clones(self.class_embed, num_decoder_layers)
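Note on the query_embed change above: an nn.Embedding that is only ever read through .weight (never indexed by token ids) can be replaced by a plain nn.Parameter of the same shape; a later hunk adds nn.init.normal_ to keep an initialization comparable to nn.Embedding's default. A small sketch of the equivalence (shapes and names are illustrative):

import torch
from torch import nn

num_queries, hidden_dim = 300, 256

# Embedding used purely as a learnable table
emb = nn.Embedding(num_queries, hidden_dim * 2)
q_a = emb.weight[:100]

# Plain parameter: same table without the .weight indirection
query_embed = nn.Parameter(torch.empty(num_queries, hidden_dim * 2))
nn.init.normal_(query_embed)  # nn.Embedding also defaults to N(0, 1)
q_b = query_embed[:100]

assert q_a.shape == q_b.shape == (100, hidden_dim * 2)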
@@ -554,6 +550,7 @@ class Plain_DETR(DetectionBaseNet):
             if idx == 0:
                 nn.init.constant_(last_linear.bias[2:], -2.0)  # Small initial wh
 
+        nn.init.normal_(self.query_embed)
         ref_last_linear = [m for m in self.reference_point_head.modules() if isinstance(m, nn.Linear)][-1]
         nn.init.zeros_(ref_last_linear.weight)
         nn.init.zeros_(ref_last_linear.bias)
@@ -576,7 +573,8 @@ class Plain_DETR(DetectionBaseNet):
         for param in self.class_embed.parameters():
             param.requires_grad_(True)
 
-    def _get_src_permutation_idx(self, indices: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+    @staticmethod
+    def _get_src_permutation_idx(indices: list[tuple[torch.Tensor, torch.Tensor]]) -> tuple[torch.Tensor, torch.Tensor]:
         batch_idx = torch.concat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
         src_idx = torch.concat([src for (src, _) in indices])
         return (batch_idx, src_idx)
@@ -585,7 +583,7 @@ class Plain_DETR(DetectionBaseNet):
         self,
         cls_logits: torch.Tensor,
         targets: list[dict[str, torch.Tensor]],
-        indices: list[torch.Tensor],
+        indices: list[tuple[torch.Tensor, torch.Tensor]],
         num_boxes: int,
     ) -> torch.Tensor:
         idx = self._get_src_permutation_idx(indices)
@@ -610,7 +608,7 @@ class Plain_DETR(DetectionBaseNet):
         self,
         box_output: torch.Tensor,
         targets: list[dict[str, torch.Tensor]],
-        indices: list[torch.Tensor],
+        indices: list[tuple[torch.Tensor, torch.Tensor]],
         num_boxes: int,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         idx = self._get_src_permutation_idx(indices)
@@ -646,7 +644,7 @@ class Plain_DETR(DetectionBaseNet):
         if training_utils.is_dist_available_and_initialized() is True:
             torch.distributed.all_reduce(num_boxes)
 
-        num_boxes = torch.clamp(num_boxes / training_utils.get_world_size(), min=1).item()
+        num_boxes = torch.clamp(num_boxes / training_utils.get_world_size(), min=1)
 
         loss_ce_list = []
         loss_bbox_list = []
@@ -697,20 +695,17 @@ class Plain_DETR(DetectionBaseNet):
         return losses
 
     def postprocess_detections(
-        self, class_logits: torch.Tensor, box_regression: torch.Tensor, image_shapes: list[tuple[int, int]]
+        self, class_logits: torch.Tensor, box_regression: torch.Tensor, image_sizes: torch.Tensor
    ) -> list[dict[str, torch.Tensor]]:
         prob = class_logits.sigmoid()
         scores, labels = prob.max(-1)
         labels = labels + 1  # Background offset
 
-        # TorchScript doesn't support creating tensor from tuples, convert everything to lists
-        target_sizes = torch.tensor([list(s) for s in image_shapes], device=class_logits.device)
-
         # Convert to [x0, y0, x1, y1] format
         boxes = box_ops.box_convert(box_regression, in_fmt="cxcywh", out_fmt="xyxy")
 
         # Convert from relative [0, 1] to absolute [0, height] coordinates
-        img_h, img_w = target_sizes.unbind(1)
+        img_h, img_w = image_sizes.unbind(1)
         scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
         boxes = boxes * scale_fct[:, None, :]
 
@@ -735,17 +730,7 @@ class Plain_DETR(DetectionBaseNet):
 
         return detections
 
-    # pylint: disable=too-many-locals
-    def forward(
-        self,
-        x: torch.Tensor,
-        targets: Optional[list[dict[str, torch.Tensor]]] = None,
-        masks: Optional[torch.Tensor] = None,
-        image_sizes: Optional[list[list[int]]] = None,
-    ) -> tuple[list[dict[str, torch.Tensor]], dict[str, torch.Tensor]]:
-        self._input_check(targets)
-        images = self._to_img_list(x, image_sizes)
-
+    def forward_net(self, x: torch.Tensor, masks: Optional[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
         features: dict[str, torch.Tensor] = self.backbone.detection_features(x)
         src = features[self.backbone.return_stages[-1]]
         src = self.input_proj(src)
@@ -772,7 +757,7 @@ class Plain_DETR(DetectionBaseNet):
         else:
             num_queries_to_use = self.num_queries_one2one
 
-        query_embed = self.query_embed.weight[:num_queries_to_use]
+        query_embed = self.query_embed[:num_queries_to_use]
         query_embed, query_pos = torch.split(query_embed, self.hidden_dim, dim=1)
         query_embed = query_embed.unsqueeze(0).expand(B, -1, -1)
         query_pos = query_pos.unsqueeze(0).expand(B, -1, -1)
@@ -787,25 +772,52 @@ class Plain_DETR(DetectionBaseNet):
             reference_points=reference_points,
             spatial_shape=(H, W),
             memory_key_padding_mask=mask_flatten,
+            return_intermediates=self.training is True,
         )
 
-        outputs_classes = []
-        outputs_coords = []
-        for lvl, (class_embed, bbox_embed) in enumerate(zip(self.class_embed, self.bbox_embed)):
-            outputs_class = class_embed(hs[lvl])
-            outputs_classes.append(outputs_class)
+        if self.training is True:
+            outputs_classes = []
+            outputs_coords = []
+            for lvl, (class_embed, bbox_embed) in enumerate(zip(self.class_embed, self.bbox_embed)):
+                outputs_class = class_embed(hs[lvl])
+                outputs_classes.append(outputs_class)
+
+                if self.box_refine is True:
+                    outputs_coord = inter_references[lvl]
+                else:
+                    tmp = bbox_embed(hs[lvl])
+                    tmp = tmp + inverse_sigmoid(reference_points)
+                    outputs_coord = tmp.sigmoid()
+
+                outputs_coords.append(outputs_coord)
+
+            outputs_class = torch.stack(outputs_classes)
+            outputs_coord = torch.stack(outputs_coords)
+        else:
+            class_embed = self.class_embed[-1]
+            bbox_embed = self.bbox_embed[-1]
+            outputs_class = class_embed(hs)
 
         if self.box_refine is True:
-            outputs_coord = inter_references[lvl]
+            outputs_coord = inter_references
         else:
-            tmp = bbox_embed(hs[lvl])
+            tmp = bbox_embed(hs)
             tmp = tmp + inverse_sigmoid(reference_points)
             outputs_coord = tmp.sigmoid()
 
-        outputs_coords.append(outputs_coord)
+        return (outputs_class, outputs_coord)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        targets: Optional[list[dict[str, torch.Tensor]]] = None,
+        masks: Optional[torch.Tensor] = None,
+        image_sizes: Optional[list[tuple[int, int]]] = None,
+    ) -> tuple[list[dict[str, torch.Tensor]], dict[str, torch.Tensor]]:
+        self._input_check(targets)
+        images = self._to_img_list(x, image_sizes)
 
-        outputs_class = torch.stack(outputs_classes)
-        outputs_coord = torch.stack(outputs_coords)
+        outputs_class, outputs_coord = self.forward_net(x, masks)
 
         losses = {}
         detections: list[dict[str, torch.Tensor]] = []
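Note on the Plain_DETR refactor above: the model-only computation moves into forward_net (tensors in, tensors out), while forward keeps the target handling, and the per-layer intermediate predictions are only built when self.training is true; at inference only the last class/box heads run. A rough sketch of that split, with a hypothetical TinyDetr standing in for the real model:

import torch
from torch import nn

class TinyDetr(nn.Module):
    def __init__(self, num_layers: int = 3, dim: int = 16, num_classes: int = 4) -> None:
        super().__init__()
        self.blocks = nn.ModuleList([nn.Linear(dim, dim) for _ in range(num_layers)])
        self.class_embed = nn.ModuleList([nn.Linear(dim, num_classes) for _ in range(num_layers)])

    def forward_net(self, x: torch.Tensor) -> torch.Tensor:
        hs = []
        for block in self.blocks:
            x = block(x)
            hs.append(x)

        if self.training:
            # One prediction per decoder layer, stacked for auxiliary losses
            return torch.stack([head(h) for head, h in zip(self.class_embed, hs)])

        return self.class_embed[-1](hs[-1])  # inference: last head only

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.forward_net(x)

model = TinyDetr().eval()
logits = model(torch.randn(2, 5, 16))  # (2, 5, 4): no per-layer stacking at inference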
@@ -815,7 +827,8 @@ class Plain_DETR(DetectionBaseNet):
             for idx, target in enumerate(targets):
                 boxes = target["boxes"]
                 boxes = box_ops.box_convert(boxes, in_fmt="xyxy", out_fmt="cxcywh")
-                boxes = boxes / torch.tensor(images.image_sizes[idx][::-1] * 2, dtype=torch.float32, device=x.device)
+                scale = images.image_sizes[idx].flip(0).repeat(2).float()  # flip to [W, H], repeat to [W, H, W, H]
+                boxes = boxes / scale
                 targets[idx]["boxes"] = boxes
                 targets[idx]["labels"] = target["labels"] - 1  # No background
 
@@ -835,7 +848,7 @@ class Plain_DETR(DetectionBaseNet):
             )
 
         else:
-            detections = self.postprocess_detections(outputs_class[-1], outputs_coord[-1], images.image_sizes)
+            detections = self.postprocess_detections(outputs_class, outputs_coord, images.image_sizes)
 
         return (detections, losses)
 
@@ -30,6 +30,7 @@ from birder.net.detection.base import BackboneWithSimpleFPN
 from birder.net.detection.base import BoxCoder
 from birder.net.detection.base import DetectionBaseNet
 from birder.net.detection.base import Matcher
+from birder.net.detection.base import clip_boxes_to_image
 from birder.ops.soft_nms import SoftNMS
 
 
@@ -63,7 +64,7 @@ class RetinaNetClassificationHead(nn.Module):
             if isinstance(layer, nn.Conv2d):
                 nn.init.normal_(layer.weight, std=0.01)
                 if layer.bias is not None:
-                    nn.init.constant_(layer.bias, 0)
+                    nn.init.zeros_(layer.bias)
 
         self.cls_logits = nn.Conv2d(in_channels, num_anchors * num_classes, kernel_size=3, stride=1, padding=1)
 
@@ -281,6 +282,11 @@ class RetinaNet(DetectionBaseNet):
         if soft_nms is True:
             self.soft_nms = SoftNMS()
 
+        self.score_thresh = score_thresh
+        self.nms_thresh = nms_thresh
+        self.detections_per_img = detections_per_img
+        self.topk_candidates = topk_candidates
+
         if feature_pyramid_type == "fpn":
             feature_pyramid: Callable[..., nn.Module] = BackboneWithFPN
             num_anchor_sizes = len(self.backbone.return_stages) + 2
@@ -314,10 +320,8 @@ class RetinaNet(DetectionBaseNet):
         self.proposal_matcher = Matcher(fg_iou_thresh, bg_iou_thresh, allow_low_quality_matches=True)
         self.box_coder = BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))
 
-        self.score_thresh = score_thresh
-        self.nms_thresh = nms_thresh
-        self.detections_per_img = detections_per_img
-        self.topk_candidates = topk_candidates
+        if self.export_mode is False:
+            self.forward = torch.compiler.disable(recursive=False)(self.forward)  # type: ignore[method-assign]
 
     def reset_classifier(self, num_classes: int) -> None:
         self.num_classes = num_classes
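Note on the RetinaNet hunk above: outside export mode, forward is wrapped with torch.compiler.disable(recursive=False), so torch.compile skips the outer forward (anchor generation, NMS and other data-dependent control flow) while the modules called inside it can still be compiled. A minimal, self-contained sketch of that wrapping (module names here are illustrative):

import torch
from torch import nn

class TinyDetector(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.backbone = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        # Skip compiling this forward, but allow inner calls to be compiled
        self.forward = torch.compiler.disable(recursive=False)(self.forward)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        feats = self.backbone(x)
        return feats.flatten(1)

model = torch.compile(TinyDetector())
out = model(torch.randn(1, 3, 32, 32))  # the top-level forward runs eagerly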
@@ -341,10 +345,7 @@ class RetinaNet(DetectionBaseNet):
     @torch.jit.unused  # type: ignore[untyped-decorator]
     @torch.compiler.disable()  # type: ignore[untyped-decorator]
     def compute_loss(
-        self,
-        targets: list[dict[str, torch.Tensor]],
-        head_outputs: dict[str, torch.Tensor],
-        anchors: list[torch.Tensor],
+        self, targets: list[dict[str, torch.Tensor]], head_outputs: dict[str, torch.Tensor], anchors: list[torch.Tensor]
     ) -> dict[str, torch.Tensor]:
         matched_idxs = []
         for idx, (anchors_per_image, targets_per_image) in enumerate(zip(anchors, targets)):
@@ -362,22 +363,19 @@ class RetinaNet(DetectionBaseNet):
 
     # pylint: disable=too-many-locals
     def postprocess_detections(
-        self,
-        head_outputs: dict[str, list[torch.Tensor]],
-        anchors: list[list[torch.Tensor]],
-        image_shapes: list[tuple[int, int]],
+        self, head_outputs: dict[str, list[torch.Tensor]], anchors: list[list[torch.Tensor]], image_sizes: torch.Tensor
     ) -> list[dict[str, torch.Tensor]]:
         class_logits = head_outputs["cls_logits"]
         box_regression = head_outputs["bbox_regression"]
 
-        num_images = len(image_shapes)
+        num_images = image_sizes.size(0)
 
         detections: list[dict[str, torch.Tensor]] = []
         for index in range(num_images):
             box_regression_per_image = [br[index] for br in box_regression]
             logits_per_image = [cl[index] for cl in class_logits]
             anchors_per_image = anchors[index]
-            image_shape = image_shapes[index]
+            image_shape = image_sizes[index]
 
             image_boxes_list = []
             image_scores_list = []
@@ -394,7 +392,7 @@ class RetinaNet(DetectionBaseNet):
                 topk_idxs = torch.where(keep_idxs)[0]
 
                 # Keep only topk scoring predictions
-                num_topk = min(self.topk_candidates, int(topk_idxs.size(0)))
+                num_topk = min(self.topk_candidates, topk_idxs.size(0))
                 scores_per_level, idxs = scores_per_level.topk(num_topk)
                 topk_idxs = topk_idxs[idxs]
 
@@ -405,7 +403,7 @@ class RetinaNet(DetectionBaseNet):
                 boxes_per_level = self.box_coder.decode_single(
                     box_regression_per_level[anchor_idxs], anchors_per_level[anchor_idxs]
                 )
-                boxes_per_level = box_ops.clip_boxes_to_image(boxes_per_level, image_shape)
+                boxes_per_level = clip_boxes_to_image(boxes_per_level, image_shape)
 
                 image_boxes_list.append(boxes_per_level)
                 image_scores_list.append(scores_per_level)
@@ -415,24 +413,40 @@ class RetinaNet(DetectionBaseNet):
             image_scores = torch.concat(image_scores_list, dim=0)
             image_labels = torch.concat(image_labels_list, dim=0)
 
-            # Non-maximum suppression
-            if self.soft_nms is not None:
-                soft_scores, keep = self.soft_nms(image_boxes, image_scores, image_labels, score_threshold=0.001)
-                image_scores[keep] = soft_scores
+            if self.export_mode is False:
+                # Non-maximum suppression
+                if self.soft_nms is not None:
+                    soft_scores, keep = self.soft_nms(image_boxes, image_scores, image_labels, score_threshold=0.001)
+                    image_scores[keep] = soft_scores
+                else:
+                    keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh)
+
+                keep = keep[: self.detections_per_img]
+
+                detections.append(
+                    {
+                        "boxes": image_boxes[keep],
+                        "scores": image_scores[keep],
+                        "labels": image_labels[keep],
+                    }
+                )
             else:
-                keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh)
+                detections.append(
+                    {
+                        "boxes": image_boxes,
+                        "scores": image_scores,
+                        "labels": image_labels,
+                    }
+                )
 
-            keep = keep[: self.detections_per_img]
+        return detections
 
-            detections.append(
-                {
-                    "boxes": image_boxes[keep],
-                    "scores": image_scores[keep],
-                    "labels": image_labels[keep],
-                }
-            )
+    def forward_net(self, x: torch.Tensor) -> tuple[list[torch.Tensor], dict[str, torch.Tensor]]:
+        features: dict[str, torch.Tensor] = self.backbone_with_fpn(x)
+        feature_list = list(features.values())
+        head_outputs = self.head(feature_list)
 
-        return detections
+        return (feature_list, head_outputs)
 
     # pylint: disable=invalid-name
     def forward(
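Note on the postprocess_detections hunk above: when export_mode is set, NMS and the per-image top-k trimming are skipped so the exported graph avoids data-dependent output shapes, and the raw candidates are returned instead. A rough sketch of the gating idea using torchvision's batched_nms (the function and flag names below are illustrative, not the library's API):

import torch
from torchvision.ops import batched_nms

def finalize(boxes: torch.Tensor, scores: torch.Tensor, labels: torch.Tensor,
             export_mode: bool, iou_thresh: float = 0.5, max_det: int = 100) -> dict[str, torch.Tensor]:
    if not export_mode:
        keep = batched_nms(boxes, scores, labels, iou_thresh)[:max_det]
        return {"boxes": boxes[keep], "scores": scores[keep], "labels": labels[keep]}

    # Export path: defer suppression to the caller, keep shapes static
    return {"boxes": boxes, "scores": scores, "labels": labels}

boxes = torch.tensor([[0.0, 0.0, 10.0, 10.0], [1.0, 1.0, 11.0, 11.0]])
scores = torch.tensor([0.9, 0.8])
labels = torch.tensor([1, 1])
print(finalize(boxes, scores, labels, export_mode=False)["boxes"].shape)  # torch.Size([1, 4])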
@@ -440,14 +454,12 @@ class RetinaNet(DetectionBaseNet):
         x: torch.Tensor,
         targets: Optional[list[dict[str, torch.Tensor]]] = None,
         masks: Optional[torch.Tensor] = None,
-        image_sizes: Optional[list[list[int]]] = None,
+        image_sizes: Optional[list[tuple[int, int]]] = None,
     ) -> tuple[list[dict[str, torch.Tensor]], dict[str, torch.Tensor]]:
         self._input_check(targets)
         images = self._to_img_list(x, image_sizes)
 
-        features: dict[str, torch.Tensor] = self.backbone_with_fpn(x)
-        feature_list = list(features.values())
-        head_outputs = self.head(feature_list)
+        feature_list, head_outputs = self.forward_net(x)
         anchors = self.anchor_generator(images, feature_list)
 
         losses: dict[str, torch.Tensor] = {}
@@ -47,9 +47,6 @@ def get_contrastive_denoising_training_group(  # pylint: disable=too-many-locals
     label_noise_ratio: float,
     box_noise_scale: float,
 ) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], Optional[dict[str, Any]]]:
-    if num_denoising_queries <= 0:
-        return (None, None, None, None)
-
     num_ground_truths = [len(t["labels"]) for t in targets]
     device = targets[0]["labels"].device
 
@@ -596,18 +593,18 @@ class RT_DETRDecoder(nn.Module):
 
         # Gather reference points
         reference_points_unact = enc_outputs_coord_unact.gather(
-            dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_coord_unact.shape[-1])
+            dim=1, index=topk_ind.unsqueeze(-1).expand(-1, -1, enc_outputs_coord_unact.shape[-1])
        )
 
         enc_topk_bboxes = reference_points_unact.sigmoid()
 
         # Gather encoder logits for loss computation
         enc_topk_logits = enc_outputs_class.gather(
-            dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_class.shape[-1])
+            dim=1, index=topk_ind.unsqueeze(-1).expand(-1, -1, enc_outputs_class.shape[-1])
        )
 
         # Extract region features
-        target = output_memory.gather(dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, output_memory.shape[-1]))
+        target = output_memory.gather(dim=1, index=topk_ind.unsqueeze(-1).expand(-1, -1, output_memory.shape[-1]))
         target = target.detach()
 
         return (target, reference_points_unact.detach(), enc_topk_bboxes, enc_topk_logits)
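Note on the gather hunks above: topk_ind.unsqueeze(-1).expand(...) builds a broadcast view of the index instead of materializing copies the way repeat(...) does; gather only reads the index, so the result is identical while the temporary is cheaper. A quick check:

import torch

memory = torch.randn(2, 10, 256)          # (batch, tokens, channels)
topk_ind = torch.randint(0, 10, (2, 5))   # (batch, k)

idx_repeat = topk_ind.unsqueeze(-1).repeat(1, 1, memory.shape[-1])    # allocates (2, 5, 256)
idx_expand = topk_ind.unsqueeze(-1).expand(-1, -1, memory.shape[-1])  # view, no copy

assert torch.equal(memory.gather(dim=1, index=idx_repeat),
                   memory.gather(dim=1, index=idx_expand))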
@@ -621,6 +618,7 @@ class RT_DETRDecoder(nn.Module):
         denoising_bbox_unact: Optional[torch.Tensor] = None,
         attn_mask: Optional[torch.Tensor] = None,
         padding_mask: Optional[list[torch.Tensor]] = None,
+        return_intermediates: bool = True,
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         memory = []
         mask_flatten = []
@@ -648,12 +646,12 @@ class RT_DETRDecoder(nn.Module):
         level_start_index_tensor = torch.tensor(level_start_index, dtype=torch.long, device=memory.device)
 
         # Decoder forward
-        out_bboxes = []
-        out_logits = []
+        bboxes_list: list[torch.Tensor] = []
+        logits_list: list[torch.Tensor] = []
         reference_points = init_ref_points_unact.sigmoid()
         for decoder_layer, bbox_head, class_head in zip(self.layers, self.bbox_embed, self.class_embed):
             query_pos = self.query_pos_head(reference_points)
-            reference_points_input = reference_points.unsqueeze(2).repeat(1, 1, len(spatial_shapes), 1)
+            reference_points_input = reference_points.unsqueeze(2).expand(-1, -1, len(spatial_shapes), -1)
             target = decoder_layer(
                 target,
                 query_pos,
@@ -663,6 +661,7 @@ class RT_DETRDecoder(nn.Module):
                 level_start_index_tensor,
                 memory_padding_mask,
                 attn_mask,
+                src_shapes=spatial_shapes,
             )
 
             bbox_delta = bbox_head(target)
@@ -672,14 +671,19 @@ class RT_DETRDecoder(nn.Module):
             # Classification
             class_logits = class_head(target)
 
-            out_bboxes.append(new_reference_points)
-            out_logits.append(class_logits)
+            if return_intermediates is True:
+                bboxes_list.append(new_reference_points)
+                logits_list.append(class_logits)
 
             # Update reference points for next layer
             reference_points = new_reference_points.detach()
 
-        out_bboxes = torch.stack(out_bboxes)
-        out_logits = torch.stack(out_logits)
+        if return_intermediates is True:
+            out_bboxes = torch.stack(bboxes_list)
+            out_logits = torch.stack(logits_list)
+        else:
+            out_bboxes = new_reference_points
+            out_logits = class_logits
 
         return (out_bboxes, out_logits, enc_topk_bboxes, enc_topk_logits)
 
@@ -743,7 +747,7 @@ class RT_DETR_v1(DetectionBaseNet):
         self.decoder = RT_DETRDecoder(
             hidden_dim=hidden_dim,
             num_classes=self.num_classes,
-            num_queries=num_queries,
+            num_queries=self.num_queries,
             num_decoder_layers=num_decoder_layers,
             num_levels=self.num_levels,
             num_heads=num_heads,
@@ -810,7 +814,8 @@ class RT_DETR_v1(DetectionBaseNet):
         for param in self.denoising_class_embed.parameters():
             param.requires_grad_(True)
 
-    def _get_src_permutation_idx(self, indices: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+    @staticmethod
+    def _get_src_permutation_idx(indices: list[tuple[torch.Tensor, torch.Tensor]]) -> tuple[torch.Tensor, torch.Tensor]:
         batch_idx = torch.concat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
         src_idx = torch.concat([src for (src, _) in indices])
         return (batch_idx, src_idx)
@@ -820,7 +825,7 @@ class RT_DETR_v1(DetectionBaseNet):
         cls_logits: torch.Tensor,
         box_output: torch.Tensor,
         targets: list[dict[str, torch.Tensor]],
-        indices: list[torch.Tensor],
+        indices: list[tuple[torch.Tensor, torch.Tensor]],
         num_boxes: float,
     ) -> torch.Tensor:
         idx = self._get_src_permutation_idx(indices)
@@ -859,7 +864,7 @@ class RT_DETR_v1(DetectionBaseNet):
         self,
         box_output: torch.Tensor,
         targets: list[dict[str, torch.Tensor]],
-        indices: list[torch.Tensor],
+        indices: list[tuple[torch.Tensor, torch.Tensor]],
         num_boxes: float,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         idx = self._get_src_permutation_idx(indices)
@@ -927,8 +932,6 @@ class RT_DETR_v1(DetectionBaseNet):
 
         return (loss_ce_dn, loss_bbox_dn, loss_giou_dn)
 
-    @torch.jit.unused  # type: ignore[untyped-decorator]
-    @torch.compiler.disable()  # type: ignore[untyped-decorator]
     def _compute_loss_from_outputs(  # pylint: disable=too-many-locals
         self,
         targets: list[dict[str, torch.Tensor]],
@@ -946,7 +949,7 @@ class RT_DETR_v1(DetectionBaseNet):
         if training_utils.is_dist_available_and_initialized() is True:
             torch.distributed.all_reduce(num_boxes)
 
-        num_boxes = torch.clamp(num_boxes / training_utils.get_world_size(), min=1).item()
+        num_boxes = torch.clamp(num_boxes / training_utils.get_world_size(), min=1)
 
         loss_ce_list = []
         loss_bbox_list = []
@@ -1001,11 +1004,11 @@ class RT_DETR_v1(DetectionBaseNet):
         images: Any,
         masks: Optional[list[torch.Tensor]] = None,
     ) -> dict[str, torch.Tensor]:
-        device = encoder_features[0].device
         for idx, target in enumerate(targets):
             boxes = target["boxes"]
             boxes = box_ops.box_convert(boxes, in_fmt="xyxy", out_fmt="cxcywh")
-            boxes = boxes / torch.tensor(images.image_sizes[idx][::-1] * 2, dtype=torch.float32, device=device)
+            scale = images.image_sizes[idx].flip(0).repeat(2).float()  # flip to [W, H], repeat to [W, H, W, H]
+            boxes = boxes / scale
             targets[idx]["boxes"] = boxes
             targets[idx]["labels"] = target["labels"] - 1  # No background
 
@@ -1038,7 +1041,7 @@ class RT_DETR_v1(DetectionBaseNet):
         return losses
 
     def postprocess_detections(
-        self, class_logits: torch.Tensor, box_regression: torch.Tensor, image_shapes: list[tuple[int, int]]
+        self, class_logits: torch.Tensor, box_regression: torch.Tensor, image_sizes: torch.Tensor
     ) -> list[dict[str, torch.Tensor]]:
         prob = class_logits.sigmoid()
         topk_values, topk_indexes = torch.topk(prob.view(class_logits.shape[0], -1), k=self.decoder.num_queries, dim=1)
@@ -1047,14 +1050,12 @@ class RT_DETR_v1(DetectionBaseNet):
         labels = topk_indexes % class_logits.shape[2]
         labels += 1  # Background offset
 
-        target_sizes = torch.tensor(image_shapes, device=class_logits.device)
-
         # Convert to [x0, y0, x1, y1] format
         boxes = box_ops.box_convert(box_regression, in_fmt="cxcywh", out_fmt="xyxy")
-        boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4))
+        boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).expand(-1, -1, 4))
 
         # Convert from relative [0, 1] to absolute [0, height] coordinates
-        img_h, img_w = target_sizes.unbind(1)
+        img_h, img_w = image_sizes.unbind(1)
         scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
         boxes = boxes * scale_fct[:, None, :]
 
@@ -1090,32 +1091,34 @@ class RT_DETR_v1(DetectionBaseNet):
 
         return (None, None, None, None)
 
+    def forward_net(
+        self, x: torch.Tensor, masks: Optional[torch.Tensor]
+    ) -> tuple[list[torch.Tensor], Optional[list[torch.Tensor]]]:
+        features: dict[str, torch.Tensor] = self.backbone.detection_features(x)
+        feature_list = list(features.values())
+
+        mask_list: Optional[list[torch.Tensor]] = None
+        if masks is not None:
+            mask_list = []
+            for feat in feature_list:
+                m = F.interpolate(masks[None].float(), size=feat.shape[-2:], mode="nearest").to(torch.bool)[0]
+                mask_list.append(m)
+
+        encoder_features = self.encoder(feature_list, masks=mask_list)
+
+        return (encoder_features, mask_list)
+
     def forward(
         self,
         x: torch.Tensor,
         targets: Optional[list[dict[str, torch.Tensor]]] = None,
         masks: Optional[torch.Tensor] = None,
-        image_sizes: Optional[list[list[int]]] = None,
+        image_sizes: Optional[list[tuple[int, int]]] = None,
     ) -> tuple[list[dict[str, torch.Tensor]], dict[str, torch.Tensor]]:
         self._input_check(targets)
         images = self._to_img_list(x, image_sizes)
 
-        # Backbone features
-        features: dict[str, torch.Tensor] = self.backbone.detection_features(x)
-        feature_list = list(features.values())
-
-        # Hybrid encoder
-        mask_list: list[torch.Tensor] = []
-        for feat in feature_list:
-            if masks is not None:
-                mask_size = feat.shape[-2:]
-                m = F.interpolate(masks[None].float(), size=mask_size, mode="nearest").to(torch.bool)[0]
-            else:
-                B, _, H, W = feat.size()
-                m = torch.zeros(B, H, W, dtype=torch.bool, device=x.device)
-            mask_list.append(m)
-
-        encoder_features = self.encoder(feature_list, masks=mask_list)
+        encoder_features, mask_list = self.forward_net(x, masks)
 
         # Prepare spatial shapes and level start index
         spatial_shapes: list[list[int]] = []
@@ -1136,9 +1139,9 @@ class RT_DETR_v1(DetectionBaseNet):
         else:
             # Inference path - no CDN
             out_bboxes, out_logits, _, _ = self.decoder(
-                encoder_features, spatial_shapes, level_start_index, padding_mask=mask_list
+                encoder_features, spatial_shapes, level_start_index, padding_mask=mask_list, return_intermediates=False
            )
-            detections = self.postprocess_detections(out_logits[-1], out_bboxes[-1], images.image_sizes)
+            detections = self.postprocess_detections(out_logits, out_bboxes, images.image_sizes)
 
         return (detections, losses)
 
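Note on the detection postprocessing changes across these files: image_sizes is now carried as a tensor of [H, W] pairs, so normalized cxcywh outputs can be scaled to absolute xyxy boxes directly with unbind/stack instead of rebuilding a tensor from Python lists. A standalone sketch of that scaling step (tensor shapes are illustrative):

import torch
from torchvision.ops import box_convert

box_regression = torch.rand(2, 4, 4)                   # (batch, queries, 4), normalized cxcywh
image_sizes = torch.tensor([[480, 640], [320, 320]])   # (batch, 2) as [H, W]

boxes = box_convert(box_regression, in_fmt="cxcywh", out_fmt="xyxy")
img_h, img_w = image_sizes.unbind(1)
scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
abs_boxes = boxes * scale_fct[:, None, :]               # absolute xyxy per image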