birder 0.4.1__py3-none-any.whl → 0.4.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (100) hide show
  1. birder/__init__.py +2 -0
  2. birder/common/fs_ops.py +81 -1
  3. birder/common/training_cli.py +12 -2
  4. birder/common/training_utils.py +73 -12
  5. birder/data/collators/detection.py +3 -1
  6. birder/datahub/_lib.py +15 -6
  7. birder/datahub/evaluation.py +591 -0
  8. birder/eval/__init__.py +0 -0
  9. birder/eval/__main__.py +74 -0
  10. birder/eval/_embeddings.py +50 -0
  11. birder/eval/adversarial.py +315 -0
  12. birder/eval/benchmarks/__init__.py +0 -0
  13. birder/eval/benchmarks/awa2.py +357 -0
  14. birder/eval/benchmarks/bioscan5m.py +198 -0
  15. birder/eval/benchmarks/fishnet.py +318 -0
  16. birder/eval/benchmarks/flowers102.py +210 -0
  17. birder/eval/benchmarks/fungiclef.py +261 -0
  18. birder/eval/benchmarks/nabirds.py +202 -0
  19. birder/eval/benchmarks/newt.py +262 -0
  20. birder/eval/benchmarks/plankton.py +255 -0
  21. birder/eval/benchmarks/plantdoc.py +259 -0
  22. birder/eval/benchmarks/plantnet.py +252 -0
  23. birder/eval/classification.py +235 -0
  24. birder/eval/methods/__init__.py +0 -0
  25. birder/eval/methods/ami.py +78 -0
  26. birder/eval/methods/knn.py +71 -0
  27. birder/eval/methods/linear.py +152 -0
  28. birder/eval/methods/mlp.py +178 -0
  29. birder/eval/methods/simpleshot.py +100 -0
  30. birder/eval/methods/svm.py +92 -0
  31. birder/inference/classification.py +23 -2
  32. birder/inference/detection.py +35 -15
  33. birder/net/_vit_configs.py +5 -0
  34. birder/net/cait.py +3 -3
  35. birder/net/coat.py +3 -3
  36. birder/net/cswin_transformer.py +2 -1
  37. birder/net/deit.py +1 -1
  38. birder/net/deit3.py +1 -1
  39. birder/net/detection/__init__.py +2 -0
  40. birder/net/detection/base.py +41 -18
  41. birder/net/detection/deformable_detr.py +74 -50
  42. birder/net/detection/detr.py +29 -26
  43. birder/net/detection/efficientdet.py +42 -25
  44. birder/net/detection/faster_rcnn.py +53 -21
  45. birder/net/detection/fcos.py +42 -23
  46. birder/net/detection/lw_detr.py +1204 -0
  47. birder/net/detection/plain_detr.py +60 -47
  48. birder/net/detection/retinanet.py +47 -35
  49. birder/net/detection/rt_detr_v1.py +49 -46
  50. birder/net/detection/rt_detr_v2.py +95 -102
  51. birder/net/detection/ssd.py +47 -31
  52. birder/net/detection/ssdlite.py +2 -2
  53. birder/net/detection/yolo_v2.py +33 -18
  54. birder/net/detection/yolo_v3.py +35 -33
  55. birder/net/detection/yolo_v4.py +35 -20
  56. birder/net/detection/yolo_v4_tiny.py +1 -2
  57. birder/net/edgevit.py +3 -3
  58. birder/net/efficientvit_msft.py +1 -1
  59. birder/net/flexivit.py +1 -1
  60. birder/net/hiera.py +44 -67
  61. birder/net/hieradet.py +2 -2
  62. birder/net/maxvit.py +2 -2
  63. birder/net/mim/fcmae.py +2 -2
  64. birder/net/mim/mae_hiera.py +9 -16
  65. birder/net/mnasnet.py +2 -2
  66. birder/net/nextvit.py +4 -4
  67. birder/net/resnext.py +2 -2
  68. birder/net/rope_deit3.py +2 -2
  69. birder/net/rope_flexivit.py +2 -2
  70. birder/net/rope_vit.py +2 -2
  71. birder/net/simple_vit.py +1 -1
  72. birder/net/squeezenet.py +1 -1
  73. birder/net/ssl/capi.py +32 -25
  74. birder/net/ssl/dino_v2.py +12 -15
  75. birder/net/ssl/franca.py +26 -19
  76. birder/net/van.py +2 -2
  77. birder/net/vit.py +21 -3
  78. birder/net/vit_parallel.py +1 -1
  79. birder/net/vit_sam.py +62 -16
  80. birder/net/xcit.py +1 -1
  81. birder/ops/msda.py +46 -16
  82. birder/scripts/benchmark.py +35 -8
  83. birder/scripts/predict.py +14 -1
  84. birder/scripts/predict_detection.py +7 -1
  85. birder/scripts/train.py +27 -11
  86. birder/scripts/train_capi.py +13 -10
  87. birder/scripts/train_detection.py +18 -7
  88. birder/scripts/train_franca.py +10 -2
  89. birder/scripts/train_kd.py +28 -11
  90. birder/tools/adversarial.py +5 -0
  91. birder/tools/convert_model.py +101 -43
  92. birder/tools/quantize_model.py +33 -16
  93. birder/version.py +1 -1
  94. {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/METADATA +17 -10
  95. {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/RECORD +99 -75
  96. {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/WHEEL +1 -1
  97. birder/scripts/evaluate.py +0 -176
  98. {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/entry_points.txt +0 -0
  99. {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/licenses/LICENSE +0 -0
  100. {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/top_level.txt +0 -0
@@ -27,6 +27,7 @@ from birder.net.detection.base import AnchorGenerator
27
27
  from birder.net.detection.base import BoxCoder
28
28
  from birder.net.detection.base import DetectionBaseNet
29
29
  from birder.net.detection.base import Matcher
30
+ from birder.net.detection.base import clip_boxes_to_image
30
31
  from birder.ops.soft_nms import SoftNMS
31
32
 
32
33
 
@@ -588,6 +589,8 @@ class EfficientDet(DetectionBaseNet):
588
589
  for param in self.class_net.parameters():
589
590
  param.requires_grad_(True)
590
591
 
592
+ @torch.jit.unused # type: ignore[untyped-decorator]
593
+ @torch.compiler.disable() # type: ignore[untyped-decorator]
591
594
  def compute_loss(
592
595
  self,
593
596
  targets: list[dict[str, torch.Tensor]],
@@ -617,16 +620,16 @@ class EfficientDet(DetectionBaseNet):
617
620
  class_logits: list[torch.Tensor],
618
621
  box_regression: list[torch.Tensor],
619
622
  anchors: list[list[torch.Tensor]],
620
- image_shapes: list[tuple[int, int]],
623
+ image_sizes: torch.Tensor,
621
624
  ) -> list[dict[str, torch.Tensor]]:
622
- num_images = len(image_shapes)
625
+ num_images = image_sizes.size(0)
623
626
 
624
627
  detections: list[dict[str, torch.Tensor]] = []
625
628
  for index in range(num_images):
626
629
  box_regression_per_image = [br[index] for br in box_regression]
627
630
  logits_per_image = [cl[index] for cl in class_logits]
628
631
  anchors_per_image = anchors[index]
629
- image_shape = image_shapes[index]
632
+ image_shape = image_sizes[index]
630
633
 
631
634
  image_boxes_list = []
632
635
  image_scores_list = []
@@ -643,7 +646,7 @@ class EfficientDet(DetectionBaseNet):
643
646
  topk_idxs = torch.where(keep_idxs)[0]
644
647
 
645
648
  # Keep only topk scoring predictions
646
- num_topk = min(self.topk_candidates, int(topk_idxs.size(0)))
649
+ num_topk = min(self.topk_candidates, topk_idxs.size(0))
647
650
  scores_per_level, idxs = scores_per_level.topk(num_topk)
648
651
  topk_idxs = topk_idxs[idxs]
649
652
 
@@ -654,7 +657,7 @@ class EfficientDet(DetectionBaseNet):
654
657
  boxes_per_level = self.box_coder.decode_single(
655
658
  box_regression_per_level[anchor_idxs], anchors_per_level[anchor_idxs]
656
659
  )
657
- boxes_per_level = box_ops.clip_boxes_to_image(boxes_per_level, image_shape)
660
+ boxes_per_level = clip_boxes_to_image(boxes_per_level, image_shape)
658
661
 
659
662
  image_boxes_list.append(boxes_per_level)
660
663
  image_scores_list.append(scores_per_level)
@@ -664,24 +667,42 @@ class EfficientDet(DetectionBaseNet):
664
667
  image_scores = torch.concat(image_scores_list, dim=0)
665
668
  image_labels = torch.concat(image_labels_list, dim=0)
666
669
 
667
- # Non-maximum suppression
668
- if self.soft_nms is not None:
669
- soft_scores, keep = self.soft_nms(image_boxes, image_scores, image_labels, score_threshold=0.001)
670
- image_scores[keep] = soft_scores
670
+ if self.export_mode is False:
671
+ # Non-maximum suppression
672
+ if self.soft_nms is not None:
673
+ soft_scores, keep = self.soft_nms(image_boxes, image_scores, image_labels, score_threshold=0.001)
674
+ image_scores[keep] = soft_scores
675
+ else:
676
+ keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh)
677
+
678
+ keep = keep[: self.detections_per_img]
679
+
680
+ detections.append(
681
+ {
682
+ "boxes": image_boxes[keep],
683
+ "scores": image_scores[keep],
684
+ "labels": image_labels[keep],
685
+ }
686
+ )
671
687
  else:
672
- keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh)
688
+ detections.append(
689
+ {
690
+ "boxes": image_boxes,
691
+ "scores": image_scores,
692
+ "labels": image_labels,
693
+ }
694
+ )
673
695
 
674
- keep = keep[: self.detections_per_img]
696
+ return detections
675
697
 
676
- detections.append(
677
- {
678
- "boxes": image_boxes[keep],
679
- "scores": image_scores[keep],
680
- "labels": image_labels[keep],
681
- }
682
- )
698
+ def forward_net(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]:
699
+ features: dict[str, torch.Tensor] = self.backbone.detection_features(x)
700
+ feature_list = list(features.values())
701
+ feature_list = self.bifpn(feature_list)
702
+ cls_logits = self.class_net(feature_list)
703
+ box_output = self.box_net(feature_list)
683
704
 
684
- return detections
705
+ return (cls_logits, box_output, feature_list)
685
706
 
686
707
  # pylint: disable=invalid-name
687
708
  def forward(
@@ -689,16 +710,12 @@ class EfficientDet(DetectionBaseNet):
689
710
  x: torch.Tensor,
690
711
  targets: Optional[list[dict[str, torch.Tensor]]] = None,
691
712
  masks: Optional[torch.Tensor] = None,
692
- image_sizes: Optional[list[list[int]]] = None,
713
+ image_sizes: Optional[list[tuple[int, int]]] = None,
693
714
  ) -> tuple[list[dict[str, torch.Tensor]], dict[str, torch.Tensor]]:
694
715
  self._input_check(targets)
695
716
  images = self._to_img_list(x, image_sizes)
696
717
 
697
- features: dict[str, torch.Tensor] = self.backbone.detection_features(x)
698
- feature_list = list(features.values())
699
- feature_list = self.bifpn(feature_list)
700
- cls_logits = self.class_net(feature_list)
701
- box_output = self.box_net(feature_list)
718
+ cls_logits, box_output, feature_list = self.forward_net(x)
702
719
  anchors = self.anchor_generator(images, feature_list)
703
720
 
704
721
  losses: dict[str, torch.Tensor] = {}
@@ -27,6 +27,7 @@ from birder.net.detection.base import BoxCoder
27
27
  from birder.net.detection.base import DetectionBaseNet
28
28
  from birder.net.detection.base import ImageList
29
29
  from birder.net.detection.base import Matcher
30
+ from birder.net.detection.base import clip_boxes_to_image
30
31
 
31
32
 
32
33
  class BalancedPositiveNegativeSampler:
@@ -169,6 +170,24 @@ def concat_box_prediction_layers(
169
170
  return (box_cls, box_regression)
170
171
 
171
172
 
173
+ def _batched_nms_coordinate_trick(
174
+ boxes: torch.Tensor, scores: torch.Tensor, idxs: torch.Tensor, iou_threshold: float
175
+ ) -> torch.Tensor:
176
+ """
177
+ Batched NMS using coordinate offset trick (same as torchvision)
178
+
179
+ Separates boxes from different classes by adding class-specific offsets,
180
+ then runs standard NMS on all boxes together.
181
+ """
182
+
183
+ max_coordinate = boxes.max()
184
+ offsets = idxs.to(boxes) * (max_coordinate + 1)
185
+ boxes_for_nms = boxes + offsets[:, None]
186
+ keep = box_ops.nms(boxes_for_nms, scores, iou_threshold)
187
+
188
+ return keep
189
+
190
+
172
191
  class RegionProposalNetwork(nn.Module):
173
192
  def __init__(
174
193
  self,
@@ -184,6 +203,8 @@ class RegionProposalNetwork(nn.Module):
184
203
  post_nms_top_n: dict[str, int],
185
204
  nms_thresh: float,
186
205
  score_thresh: float = 0.0,
206
+ # Other
207
+ export_mode: bool = False,
187
208
  ) -> None:
188
209
  super().__init__()
189
210
  self.anchor_generator = anchor_generator
@@ -206,6 +227,7 @@ class RegionProposalNetwork(nn.Module):
206
227
  self.nms_thresh = nms_thresh
207
228
  self.score_thresh = score_thresh
208
229
  self.min_size = 1e-3
230
+ self.export_mode = export_mode
209
231
 
210
232
  def pre_nms_top_n(self) -> int:
211
233
  if self.training is True:
@@ -275,7 +297,7 @@ class RegionProposalNetwork(nn.Module):
275
297
  self,
276
298
  proposals: torch.Tensor,
277
299
  objectness: torch.Tensor,
278
- image_shapes: list[tuple[int, int]],
300
+ image_sizes: torch.Tensor,
279
301
  num_anchors_per_level: list[int],
280
302
  ) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
281
303
  num_images = proposals.shape[0]
@@ -305,8 +327,8 @@ class RegionProposalNetwork(nn.Module):
305
327
 
306
328
  final_boxes = []
307
329
  final_scores = []
308
- for boxes, scores, lvl, img_shape in zip(proposals, objectness_prob, levels, image_shapes):
309
- boxes = box_ops.clip_boxes_to_image(boxes, img_shape)
330
+ for boxes, scores, lvl, img_shape in zip(proposals, objectness_prob, levels, image_sizes):
331
+ boxes = clip_boxes_to_image(boxes, img_shape)
310
332
 
311
333
  # Remove small boxes
312
334
  keep = box_ops.remove_small_boxes(boxes, self.min_size)
@@ -317,13 +339,17 @@ class RegionProposalNetwork(nn.Module):
317
339
  keep = torch.where(scores >= self.score_thresh)[0]
318
340
  boxes, scores, lvl = boxes[keep], scores[keep], lvl[keep]
319
341
 
320
- # Non-maximum suppression, independently done per level
321
- keep = box_ops.batched_nms(boxes, scores, lvl, self.nms_thresh)
342
+ if self.export_mode is False:
343
+ # Non-maximum suppression, independently done per level
344
+ keep = box_ops.batched_nms(boxes, scores, lvl, self.nms_thresh)
322
345
 
323
- # Keep only topk scoring predictions
324
- keep = keep[: self.post_nms_top_n()]
325
- boxes, scores = boxes[keep], scores[keep]
346
+ # Keep only topk scoring predictions
347
+ keep = keep[: self.post_nms_top_n()]
348
+ else:
349
+ keep = _batched_nms_coordinate_trick(boxes, scores, lvl, self.nms_thresh)
326
350
 
351
+ boxes = boxes[keep]
352
+ scores = scores[keep]
327
353
  final_boxes.append(boxes)
328
354
  final_scores.append(scores)
329
355
 
@@ -531,6 +557,7 @@ class RoIHeads(nn.Module):
531
557
  self.score_thresh = score_thresh
532
558
  self.nms_thresh = nms_thresh
533
559
  self.detections_per_img = detections_per_img
560
+ self.export_mode = export_mode
534
561
 
535
562
  if export_mode is False:
536
563
  self.forward = torch.compiler.disable(recursive=False)(self.forward) # type: ignore[method-assign]
@@ -637,7 +664,7 @@ class RoIHeads(nn.Module):
637
664
  class_logits: torch.Tensor,
638
665
  box_regression: torch.Tensor,
639
666
  proposals: list[torch.Tensor],
640
- image_shapes: list[tuple[int, int]],
667
+ image_sizes: torch.Tensor,
641
668
  ) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]:
642
669
  device = class_logits.device
643
670
  num_classes = class_logits.shape[-1]
@@ -653,8 +680,8 @@ class RoIHeads(nn.Module):
653
680
  all_boxes = []
654
681
  all_scores = []
655
682
  all_labels = []
656
- for boxes, scores, image_shape in zip(pred_boxes_list, pred_scores_list, image_shapes):
657
- boxes = box_ops.clip_boxes_to_image(boxes, image_shape)
683
+ for boxes, scores, image_shape in zip(pred_boxes_list, pred_scores_list, image_sizes):
684
+ boxes = clip_boxes_to_image(boxes, image_shape)
658
685
 
659
686
  # Create labels for each prediction
660
687
  labels = torch.arange(num_classes, device=device)
@@ -682,14 +709,15 @@ class RoIHeads(nn.Module):
682
709
  scores = scores[keep]
683
710
  labels = labels[keep]
684
711
 
685
- # Non-maximum suppression, independently done per class
686
- keep = box_ops.batched_nms(boxes, scores, labels, self.nms_thresh)
712
+ if self.export_mode is False:
713
+ # Non-maximum suppression, independently done per class
714
+ keep = box_ops.batched_nms(boxes, scores, labels, self.nms_thresh)
687
715
 
688
- # Keep only topk scoring predictions
689
- keep = keep[: self.detections_per_img]
690
- boxes = boxes[keep]
691
- scores = scores[keep]
692
- labels = labels[keep]
716
+ # Keep only topk scoring predictions
717
+ keep = keep[: self.detections_per_img]
718
+ boxes = boxes[keep]
719
+ scores = scores[keep]
720
+ labels = labels[keep]
693
721
 
694
722
  all_boxes.append(boxes)
695
723
  all_scores.append(scores)
@@ -701,7 +729,7 @@ class RoIHeads(nn.Module):
701
729
  self,
702
730
  features: dict[str, torch.Tensor],
703
731
  proposals: list[torch.Tensor],
704
- image_shapes: list[tuple[int, int]],
732
+ image_sizes: torch.Tensor,
705
733
  targets: Optional[list[dict[str, torch.Tensor]]] = None,
706
734
  ) -> tuple[list[dict[str, torch.Tensor]], dict[str, torch.Tensor]]:
707
735
  if targets is not None:
@@ -719,6 +747,8 @@ class RoIHeads(nn.Module):
719
747
  regression_targets = None
720
748
  _matched_idxs = None # noqa: F841
721
749
 
750
+ image_shapes: list[tuple[int, int]] = [(int(s[0]), int(s[1])) for s in image_sizes] # TorchScript
751
+
722
752
  box_features = self.box_roi_pool(features, proposals, image_shapes)
723
753
  box_features = self.box_head(box_features)
724
754
  class_logits, box_regression = self.box_predictor(box_features)
@@ -735,7 +765,7 @@ class RoIHeads(nn.Module):
735
765
  losses = {"loss_classifier": loss_classifier, "loss_box_reg": loss_box_reg}
736
766
 
737
767
  else:
738
- boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals, image_shapes)
768
+ boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals, image_sizes)
739
769
  num_images = len(boxes)
740
770
  for i in range(num_images):
741
771
  result.append(
@@ -753,6 +783,7 @@ class RoIHeads(nn.Module):
753
783
  class Faster_RCNN(DetectionBaseNet):
754
784
  default_size = (640, 640)
755
785
  auto_register = True
786
+ exportable = False
756
787
 
757
788
  # pylint: disable=too-many-locals
758
789
  def __init__(
@@ -812,6 +843,7 @@ class Faster_RCNN(DetectionBaseNet):
812
843
  rpn_post_nms_top_n,
813
844
  rpn_nms_thresh,
814
845
  score_thresh=rpn_score_thresh,
846
+ export_mode=self.export_mode,
815
847
  )
816
848
  box_roi_pool = MultiScaleRoIAlign(
817
849
  featmap_names=self.backbone.return_stages,
@@ -862,7 +894,7 @@ class Faster_RCNN(DetectionBaseNet):
862
894
  x: torch.Tensor,
863
895
  targets: Optional[list[dict[str, torch.Tensor]]] = None,
864
896
  masks: Optional[torch.Tensor] = None,
865
- image_sizes: Optional[list[list[int]]] = None,
897
+ image_sizes: Optional[list[tuple[int, int]]] = None,
866
898
  ) -> tuple[list[dict[str, torch.Tensor]], dict[str, torch.Tensor]]:
867
899
  self._input_check(targets)
868
900
  images = self._to_img_list(x, image_sizes)
@@ -26,6 +26,7 @@ from birder.net.base import DetectorBackbone
26
26
  from birder.net.detection.base import AnchorGenerator
27
27
  from birder.net.detection.base import BackboneWithFPN
28
28
  from birder.net.detection.base import DetectionBaseNet
29
+ from birder.net.detection.base import clip_boxes_to_image
29
30
  from birder.ops.soft_nms import SoftNMS
30
31
 
31
32
 
@@ -326,6 +327,9 @@ class FCOS(DetectionBaseNet):
326
327
 
327
328
  self.box_coder = BoxLinearCoder(normalize_by_size=True)
328
329
 
330
+ if self.export_mode is False:
331
+ self.forward = torch.compiler.disable(recursive=False)(self.forward) # type: ignore[method-assign]
332
+
329
333
  def reset_classifier(self, num_classes: int) -> None:
330
334
  self.num_classes = num_classes
331
335
 
@@ -344,6 +348,7 @@ class FCOS(DetectionBaseNet):
344
348
  for param in self.head.classification_head.parameters():
345
349
  param.requires_grad_(True)
346
350
 
351
+ @torch.compiler.disable() # type: ignore[untyped-decorator]
347
352
  def compute_loss(
348
353
  self,
349
354
  targets: list[dict[str, torch.Tensor]],
@@ -400,20 +405,20 @@ class FCOS(DetectionBaseNet):
400
405
  self,
401
406
  head_outputs: dict[str, list[torch.Tensor]],
402
407
  anchors: list[list[torch.Tensor]],
403
- image_shapes: list[tuple[int, int]],
408
+ image_sizes: torch.Tensor,
404
409
  ) -> list[dict[str, torch.Tensor]]:
405
410
  class_logits = head_outputs["cls_logits"]
406
411
  box_regression = head_outputs["bbox_regression"]
407
412
  box_ctrness = head_outputs["bbox_ctrness"]
408
413
 
409
- num_images = len(image_shapes)
414
+ num_images = image_sizes.size(0)
410
415
  detections: list[dict[str, torch.Tensor]] = []
411
416
  for index in range(num_images):
412
417
  box_regression_per_image = [br[index] for br in box_regression]
413
418
  logits_per_image = [cl[index] for cl in class_logits]
414
419
  box_ctrness_per_image = [bc[index] for bc in box_ctrness]
415
420
  anchors_per_image = anchors[index]
416
- image_shape = image_shapes[index]
421
+ image_shape = image_sizes[index]
417
422
 
418
423
  image_boxes_list = []
419
424
  image_scores_list = []
@@ -432,7 +437,7 @@ class FCOS(DetectionBaseNet):
432
437
  topk_idxs = torch.where(keep_idxs)[0]
433
438
 
434
439
  # Keep only topk scoring predictions
435
- num_topk = min(self.topk_candidates, int(topk_idxs.size(0)))
440
+ num_topk = min(self.topk_candidates, topk_idxs.size(0))
436
441
  scores_per_level, idxs = scores_per_level.topk(num_topk)
437
442
  topk_idxs = topk_idxs[idxs]
438
443
 
@@ -443,7 +448,7 @@ class FCOS(DetectionBaseNet):
443
448
  boxes_per_level = self.box_coder.decode(
444
449
  box_regression_per_level[anchor_idxs], anchors_per_level[anchor_idxs]
445
450
  )
446
- boxes_per_level = box_ops.clip_boxes_to_image(boxes_per_level, image_shape)
451
+ boxes_per_level = clip_boxes_to_image(boxes_per_level, image_shape)
447
452
 
448
453
  image_boxes_list.append(boxes_per_level)
449
454
  image_scores_list.append(scores_per_level)
@@ -453,38 +458,52 @@ class FCOS(DetectionBaseNet):
453
458
  image_scores = torch.concat(image_scores_list, dim=0)
454
459
  image_labels = torch.concat(image_labels_list, dim=0)
455
460
 
456
- # Non-maximum suppression
457
- if self.soft_nms is not None:
458
- soft_scores, keep = self.soft_nms(image_boxes, image_scores, image_labels, score_threshold=0.001)
459
- image_scores[keep] = soft_scores
461
+ if self.export_mode is False:
462
+ # Non-maximum suppression
463
+ if self.soft_nms is not None:
464
+ soft_scores, keep = self.soft_nms(image_boxes, image_scores, image_labels, score_threshold=0.001)
465
+ image_scores[keep] = soft_scores
466
+ else:
467
+ keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh)
468
+
469
+ keep = keep[: self.detections_per_img]
470
+
471
+ detections.append(
472
+ {
473
+ "boxes": image_boxes[keep],
474
+ "scores": image_scores[keep],
475
+ "labels": image_labels[keep],
476
+ }
477
+ )
460
478
  else:
461
- keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh)
479
+ detections.append(
480
+ {
481
+ "boxes": image_boxes,
482
+ "scores": image_scores,
483
+ "labels": image_labels,
484
+ }
485
+ )
462
486
 
463
- keep = keep[: self.detections_per_img]
487
+ return detections
464
488
 
465
- detections.append(
466
- {
467
- "boxes": image_boxes[keep],
468
- "scores": image_scores[keep],
469
- "labels": image_labels[keep],
470
- }
471
- )
489
+ def forward_net(self, x: torch.Tensor) -> tuple[list[torch.Tensor], dict[str, torch.Tensor]]:
490
+ features: dict[str, torch.Tensor] = self.backbone_with_fpn(x)
491
+ feature_list = list(features.values())
492
+ head_outputs = self.head(feature_list)
472
493
 
473
- return detections
494
+ return (feature_list, head_outputs)
474
495
 
475
496
  def forward(
476
497
  self,
477
498
  x: torch.Tensor,
478
499
  targets: Optional[list[dict[str, torch.Tensor]]] = None,
479
500
  masks: Optional[torch.Tensor] = None,
480
- image_sizes: Optional[list[list[int]]] = None,
501
+ image_sizes: Optional[list[tuple[int, int]]] = None,
481
502
  ) -> tuple[list[dict[str, torch.Tensor]], dict[str, torch.Tensor]]:
482
503
  self._input_check(targets)
483
504
  images = self._to_img_list(x, image_sizes)
484
505
 
485
- features: dict[str, torch.Tensor] = self.backbone_with_fpn(x)
486
- feature_list = list(features.values())
487
- head_outputs = self.head(feature_list)
506
+ feature_list, head_outputs = self.forward_net(x)
488
507
  anchors = self.anchor_generator(images, feature_list)
489
508
 
490
509
  # recover level sizes