birder 0.4.1-py3-none-any.whl → 0.4.4-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- birder/__init__.py +2 -0
- birder/common/fs_ops.py +81 -1
- birder/common/training_cli.py +12 -2
- birder/common/training_utils.py +73 -12
- birder/data/collators/detection.py +3 -1
- birder/datahub/_lib.py +15 -6
- birder/datahub/evaluation.py +591 -0
- birder/eval/__init__.py +0 -0
- birder/eval/__main__.py +74 -0
- birder/eval/_embeddings.py +50 -0
- birder/eval/adversarial.py +315 -0
- birder/eval/benchmarks/__init__.py +0 -0
- birder/eval/benchmarks/awa2.py +357 -0
- birder/eval/benchmarks/bioscan5m.py +198 -0
- birder/eval/benchmarks/fishnet.py +318 -0
- birder/eval/benchmarks/flowers102.py +210 -0
- birder/eval/benchmarks/fungiclef.py +261 -0
- birder/eval/benchmarks/nabirds.py +202 -0
- birder/eval/benchmarks/newt.py +262 -0
- birder/eval/benchmarks/plankton.py +255 -0
- birder/eval/benchmarks/plantdoc.py +259 -0
- birder/eval/benchmarks/plantnet.py +252 -0
- birder/eval/classification.py +235 -0
- birder/eval/methods/__init__.py +0 -0
- birder/eval/methods/ami.py +78 -0
- birder/eval/methods/knn.py +71 -0
- birder/eval/methods/linear.py +152 -0
- birder/eval/methods/mlp.py +178 -0
- birder/eval/methods/simpleshot.py +100 -0
- birder/eval/methods/svm.py +92 -0
- birder/inference/classification.py +23 -2
- birder/inference/detection.py +35 -15
- birder/net/_vit_configs.py +5 -0
- birder/net/cait.py +3 -3
- birder/net/coat.py +3 -3
- birder/net/cswin_transformer.py +2 -1
- birder/net/deit.py +1 -1
- birder/net/deit3.py +1 -1
- birder/net/detection/__init__.py +2 -0
- birder/net/detection/base.py +41 -18
- birder/net/detection/deformable_detr.py +74 -50
- birder/net/detection/detr.py +29 -26
- birder/net/detection/efficientdet.py +42 -25
- birder/net/detection/faster_rcnn.py +53 -21
- birder/net/detection/fcos.py +42 -23
- birder/net/detection/lw_detr.py +1204 -0
- birder/net/detection/plain_detr.py +60 -47
- birder/net/detection/retinanet.py +47 -35
- birder/net/detection/rt_detr_v1.py +49 -46
- birder/net/detection/rt_detr_v2.py +95 -102
- birder/net/detection/ssd.py +47 -31
- birder/net/detection/ssdlite.py +2 -2
- birder/net/detection/yolo_v2.py +33 -18
- birder/net/detection/yolo_v3.py +35 -33
- birder/net/detection/yolo_v4.py +35 -20
- birder/net/detection/yolo_v4_tiny.py +1 -2
- birder/net/edgevit.py +3 -3
- birder/net/efficientvit_msft.py +1 -1
- birder/net/flexivit.py +1 -1
- birder/net/hiera.py +44 -67
- birder/net/hieradet.py +2 -2
- birder/net/maxvit.py +2 -2
- birder/net/mim/fcmae.py +2 -2
- birder/net/mim/mae_hiera.py +9 -16
- birder/net/mnasnet.py +2 -2
- birder/net/nextvit.py +4 -4
- birder/net/resnext.py +2 -2
- birder/net/rope_deit3.py +2 -2
- birder/net/rope_flexivit.py +2 -2
- birder/net/rope_vit.py +2 -2
- birder/net/simple_vit.py +1 -1
- birder/net/squeezenet.py +1 -1
- birder/net/ssl/capi.py +32 -25
- birder/net/ssl/dino_v2.py +12 -15
- birder/net/ssl/franca.py +26 -19
- birder/net/van.py +2 -2
- birder/net/vit.py +21 -3
- birder/net/vit_parallel.py +1 -1
- birder/net/vit_sam.py +62 -16
- birder/net/xcit.py +1 -1
- birder/ops/msda.py +46 -16
- birder/scripts/benchmark.py +35 -8
- birder/scripts/predict.py +14 -1
- birder/scripts/predict_detection.py +7 -1
- birder/scripts/train.py +27 -11
- birder/scripts/train_capi.py +13 -10
- birder/scripts/train_detection.py +18 -7
- birder/scripts/train_franca.py +10 -2
- birder/scripts/train_kd.py +28 -11
- birder/tools/adversarial.py +5 -0
- birder/tools/convert_model.py +101 -43
- birder/tools/quantize_model.py +33 -16
- birder/version.py +1 -1
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/METADATA +17 -10
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/RECORD +99 -75
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/WHEEL +1 -1
- birder/scripts/evaluate.py +0 -176
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/entry_points.txt +0 -0
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/licenses/LICENSE +0 -0
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/top_level.txt +0 -0
birder/net/detection/efficientdet.py
CHANGED

@@ -27,6 +27,7 @@ from birder.net.detection.base import AnchorGenerator
 from birder.net.detection.base import BoxCoder
 from birder.net.detection.base import DetectionBaseNet
 from birder.net.detection.base import Matcher
+from birder.net.detection.base import clip_boxes_to_image
 from birder.ops.soft_nms import SoftNMS


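All three detection heads touched in this release (EfficientDet, Faster R-CNN, FCOS) now import `clip_boxes_to_image` from `birder.net.detection.base` and apply it to decoded boxes. The helper itself is not part of this diff; the following is only a minimal sketch of what such a function typically does, assuming xyxy boxes and a (height, width) size tensor, so birder's actual implementation may differ:

    import torch

    def clip_boxes_xyxy(boxes: torch.Tensor, size: torch.Tensor) -> torch.Tensor:
        # Hypothetical stand-in for birder's clip_boxes_to_image
        # boxes: (N, 4) in xyxy format; size: (height, width)
        height = float(size[0])
        width = float(size[1])
        boxes = boxes.clone()
        boxes[:, 0].clamp_(min=0.0, max=width)   # x1
        boxes[:, 1].clamp_(min=0.0, max=height)  # y1
        boxes[:, 2].clamp_(min=0.0, max=width)   # x2
        boxes[:, 3].clamp_(min=0.0, max=height)  # y2
        return boxes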
@@ -588,6 +589,8 @@ class EfficientDet(DetectionBaseNet):
         for param in self.class_net.parameters():
             param.requires_grad_(True)

+    @torch.jit.unused  # type: ignore[untyped-decorator]
+    @torch.compiler.disable()  # type: ignore[untyped-decorator]
     def compute_loss(
         self,
         targets: list[dict[str, torch.Tensor]],
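The loss computation is now kept out of TorchScript and torch.compile: `torch.jit.unused` drops the method when scripting, and `torch.compiler.disable()` forces it to run eagerly. An illustrative sketch of the pattern (not birder code):

    import torch
    import torch.nn.functional as F
    from torch import nn

    class Toy(nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x * 2

        @torch.jit.unused            # excluded when the module is scripted
        @torch.compiler.disable()    # always runs eagerly, never traced by torch.compile
        def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
            return F.mse_loss(pred, target)

    model = torch.compile(Toy())
    out = model(torch.randn(4))                      # compiled path
    loss = model.compute_loss(out, torch.zeros(4))   # eager path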
@@ -617,16 +620,16 @@ class EfficientDet(DetectionBaseNet):
         class_logits: list[torch.Tensor],
         box_regression: list[torch.Tensor],
         anchors: list[list[torch.Tensor]],
-
+        image_sizes: torch.Tensor,
     ) -> list[dict[str, torch.Tensor]]:
-        num_images =
+        num_images = image_sizes.size(0)

         detections: list[dict[str, torch.Tensor]] = []
         for index in range(num_images):
             box_regression_per_image = [br[index] for br in box_regression]
             logits_per_image = [cl[index] for cl in class_logits]
             anchors_per_image = anchors[index]
-            image_shape =
+            image_shape = image_sizes[index]

             image_boxes_list = []
             image_scores_list = []
@@ -643,7 +646,7 @@ class EfficientDet(DetectionBaseNet):
                 topk_idxs = torch.where(keep_idxs)[0]

                 # Keep only topk scoring predictions
-                num_topk = min(self.topk_candidates,
+                num_topk = min(self.topk_candidates, topk_idxs.size(0))
                 scores_per_level, idxs = scores_per_level.topk(num_topk)
                 topk_idxs = topk_idxs[idxs]

@@ -654,7 +657,7 @@ class EfficientDet(DetectionBaseNet):
                 boxes_per_level = self.box_coder.decode_single(
                     box_regression_per_level[anchor_idxs], anchors_per_level[anchor_idxs]
                 )
-                boxes_per_level =
+                boxes_per_level = clip_boxes_to_image(boxes_per_level, image_shape)

                 image_boxes_list.append(boxes_per_level)
                 image_scores_list.append(scores_per_level)
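Both EfficientDet and FCOS now cap the per-level candidate count with `topk_idxs.size(0)` and clip the decoded boxes to the image before collecting them. A self-contained sketch of that per-level selection step, with hypothetical threshold and top-k values:

    import torch

    def select_topk_candidates(
        scores_per_level: torch.Tensor,  # flattened per-level scores, shape (num_anchors,)
        score_thresh: float = 0.05,
        topk_candidates: int = 1000,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        keep_idxs = scores_per_level > score_thresh
        scores_per_level = scores_per_level[keep_idxs]
        topk_idxs = torch.where(keep_idxs)[0]

        # Keep only the top-k scoring predictions; size(0) keeps num_topk well-defined
        # even when fewer candidates pass the threshold
        num_topk = min(topk_candidates, topk_idxs.size(0))
        scores_per_level, idxs = scores_per_level.topk(num_topk)
        return scores_per_level, topk_idxs[idxs]

    scores, anchor_idxs = select_topk_candidates(torch.rand(9000))
    print(scores.shape, anchor_idxs.shape)  # torch.Size([1000]) torch.Size([1000])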
@@ -664,24 +667,42 @@ class EfficientDet(DetectionBaseNet):
             image_scores = torch.concat(image_scores_list, dim=0)
             image_labels = torch.concat(image_labels_list, dim=0)

-
-
-
-
+            if self.export_mode is False:
+                # Non-maximum suppression
+                if self.soft_nms is not None:
+                    soft_scores, keep = self.soft_nms(image_boxes, image_scores, image_labels, score_threshold=0.001)
+                    image_scores[keep] = soft_scores
+                else:
+                    keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh)
+
+                keep = keep[: self.detections_per_img]
+
+                detections.append(
+                    {
+                        "boxes": image_boxes[keep],
+                        "scores": image_scores[keep],
+                        "labels": image_labels[keep],
+                    }
+                )
             else:
-
+                detections.append(
+                    {
+                        "boxes": image_boxes,
+                        "scores": image_scores,
+                        "labels": image_labels,
+                    }
+                )

-
+        return detections

-
-
-
-
-
-
-            )
+    def forward_net(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]:
+        features: dict[str, torch.Tensor] = self.backbone.detection_features(x)
+        feature_list = list(features.values())
+        feature_list = self.bifpn(feature_list)
+        cls_logits = self.class_net(feature_list)
+        box_output = self.box_net(feature_list)

-        return
+        return (cls_logits, box_output, feature_list)

     # pylint: disable=invalid-name
     def forward(
@@ -689,16 +710,12 @@ class EfficientDet(DetectionBaseNet):
         x: torch.Tensor,
         targets: Optional[list[dict[str, torch.Tensor]]] = None,
         masks: Optional[torch.Tensor] = None,
-        image_sizes: Optional[list[
+        image_sizes: Optional[list[tuple[int, int]]] = None,
     ) -> tuple[list[dict[str, torch.Tensor]], dict[str, torch.Tensor]]:
         self._input_check(targets)
         images = self._to_img_list(x, image_sizes)

-
-        feature_list = list(features.values())
-        feature_list = self.bifpn(feature_list)
-        cls_logits = self.class_net(feature_list)
-        box_output = self.box_net(feature_list)
+        cls_logits, box_output, feature_list = self.forward_net(x)
         anchors = self.anchor_generator(images, feature_list)

         losses: dict[str, torch.Tensor] = {}
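Postprocessing in the anchor-based heads is now gated on `self.export_mode`: the eager path runs (soft-)NMS and truncates to `detections_per_img`, while the export path returns all decoded, clipped boxes so suppression can happen outside the exported graph. A minimal sketch of how a caller could apply that suppression to one image's raw outputs, using `torchvision.ops.batched_nms` (the threshold and limit below are illustrative, not birder defaults):

    import torch
    from torchvision.ops import batched_nms

    def postprocess_exported(
        boxes: torch.Tensor,    # (N, 4) decoded and clipped boxes for one image
        scores: torch.Tensor,   # (N,)
        labels: torch.Tensor,   # (N,)
        nms_thresh: float = 0.5,
        max_detections: int = 100,
    ) -> dict[str, torch.Tensor]:
        # Class-aware NMS, then keep the highest-scoring detections
        keep = batched_nms(boxes, scores, labels, nms_thresh)
        keep = keep[:max_detections]
        return {"boxes": boxes[keep], "scores": scores[keep], "labels": labels[keep]}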
birder/net/detection/faster_rcnn.py
CHANGED

@@ -27,6 +27,7 @@ from birder.net.detection.base import BoxCoder
 from birder.net.detection.base import DetectionBaseNet
 from birder.net.detection.base import ImageList
 from birder.net.detection.base import Matcher
+from birder.net.detection.base import clip_boxes_to_image


 class BalancedPositiveNegativeSampler:
@@ -169,6 +170,24 @@ def concat_box_prediction_layers(
     return (box_cls, box_regression)


+def _batched_nms_coordinate_trick(
+    boxes: torch.Tensor, scores: torch.Tensor, idxs: torch.Tensor, iou_threshold: float
+) -> torch.Tensor:
+    """
+    Batched NMS using coordinate offset trick (same as torchvision)
+
+    Separates boxes from different classes by adding class-specific offsets,
+    then runs standard NMS on all boxes together.
+    """
+
+    max_coordinate = boxes.max()
+    offsets = idxs.to(boxes) * (max_coordinate + 1)
+    boxes_for_nms = boxes + offsets[:, None]
+    keep = box_ops.nms(boxes_for_nms, scores, iou_threshold)
+
+    return keep
+
+
 class RegionProposalNetwork(nn.Module):
     def __init__(
         self,
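The coordinate-offset trick makes class-aware NMS a single, export-friendly op: each class's boxes are shifted into a disjoint coordinate range, so one NMS pass never suppresses boxes across classes. Recent torchvision versions use the same idea inside `batched_nms` for small inputs but switch to a per-class loop for large ones, which is why the export path below calls the trick directly. A small standalone demonstration:

    import torch
    from torchvision.ops import nms

    boxes = torch.tensor([
        [0.0, 0.0, 10.0, 10.0],  # class 0
        [1.0, 1.0, 11.0, 11.0],  # class 0, heavily overlaps the first box
        [0.0, 0.0, 10.0, 10.0],  # class 1, identical coordinates but a different class
    ])
    scores = torch.tensor([0.9, 0.8, 0.7])
    idxs = torch.tensor([0, 0, 1])

    # Shift each class into its own coordinate range, then run one NMS pass
    offsets = idxs.to(boxes) * (boxes.max() + 1)
    keep = nms(boxes + offsets[:, None], scores, iou_threshold=0.5)
    print(keep)  # tensor([0, 2]): the class-1 box survives despite overlapping box 0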
@@ -184,6 +203,8 @@ class RegionProposalNetwork(nn.Module):
         post_nms_top_n: dict[str, int],
         nms_thresh: float,
         score_thresh: float = 0.0,
+        # Other
+        export_mode: bool = False,
     ) -> None:
         super().__init__()
         self.anchor_generator = anchor_generator
@@ -206,6 +227,7 @@ class RegionProposalNetwork(nn.Module):
         self.nms_thresh = nms_thresh
         self.score_thresh = score_thresh
         self.min_size = 1e-3
+        self.export_mode = export_mode

     def pre_nms_top_n(self) -> int:
         if self.training is True:
@@ -275,7 +297,7 @@ class RegionProposalNetwork(nn.Module):
         self,
         proposals: torch.Tensor,
         objectness: torch.Tensor,
-
+        image_sizes: torch.Tensor,
         num_anchors_per_level: list[int],
     ) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
         num_images = proposals.shape[0]
@@ -305,8 +327,8 @@ class RegionProposalNetwork(nn.Module):

         final_boxes = []
         final_scores = []
-        for boxes, scores, lvl, img_shape in zip(proposals, objectness_prob, levels,
-            boxes =
+        for boxes, scores, lvl, img_shape in zip(proposals, objectness_prob, levels, image_sizes):
+            boxes = clip_boxes_to_image(boxes, img_shape)

             # Remove small boxes
             keep = box_ops.remove_small_boxes(boxes, self.min_size)
@@ -317,13 +339,17 @@ class RegionProposalNetwork(nn.Module):
             keep = torch.where(scores >= self.score_thresh)[0]
             boxes, scores, lvl = boxes[keep], scores[keep], lvl[keep]

-
-
+            if self.export_mode is False:
+                # Non-maximum suppression, independently done per level
+                keep = box_ops.batched_nms(boxes, scores, lvl, self.nms_thresh)

-
-
-
+                # Keep only topk scoring predictions
+                keep = keep[: self.post_nms_top_n()]
+            else:
+                keep = _batched_nms_coordinate_trick(boxes, scores, lvl, self.nms_thresh)

+            boxes = boxes[keep]
+            scores = scores[keep]
             final_boxes.append(boxes)
             final_scores.append(scores)

@@ -531,6 +557,7 @@ class RoIHeads(nn.Module):
         self.score_thresh = score_thresh
         self.nms_thresh = nms_thresh
         self.detections_per_img = detections_per_img
+        self.export_mode = export_mode

         if export_mode is False:
             self.forward = torch.compiler.disable(recursive=False)(self.forward)  # type: ignore[method-assign]
@@ -637,7 +664,7 @@ class RoIHeads(nn.Module):
         class_logits: torch.Tensor,
         box_regression: torch.Tensor,
         proposals: list[torch.Tensor],
-
+        image_sizes: torch.Tensor,
     ) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]:
         device = class_logits.device
         num_classes = class_logits.shape[-1]
@@ -653,8 +680,8 @@ class RoIHeads(nn.Module):
         all_boxes = []
         all_scores = []
         all_labels = []
-        for boxes, scores, image_shape in zip(pred_boxes_list, pred_scores_list,
-            boxes =
+        for boxes, scores, image_shape in zip(pred_boxes_list, pred_scores_list, image_sizes):
+            boxes = clip_boxes_to_image(boxes, image_shape)

             # Create labels for each prediction
             labels = torch.arange(num_classes, device=device)
@@ -682,14 +709,15 @@ class RoIHeads(nn.Module):
             scores = scores[keep]
             labels = labels[keep]

-
-
+            if self.export_mode is False:
+                # Non-maximum suppression, independently done per class
+                keep = box_ops.batched_nms(boxes, scores, labels, self.nms_thresh)

-
-
-
-
-
+                # Keep only topk scoring predictions
+                keep = keep[: self.detections_per_img]
+                boxes = boxes[keep]
+                scores = scores[keep]
+                labels = labels[keep]

             all_boxes.append(boxes)
             all_scores.append(scores)
@@ -701,7 +729,7 @@ class RoIHeads(nn.Module):
         self,
         features: dict[str, torch.Tensor],
         proposals: list[torch.Tensor],
-
+        image_sizes: torch.Tensor,
         targets: Optional[list[dict[str, torch.Tensor]]] = None,
     ) -> tuple[list[dict[str, torch.Tensor]], dict[str, torch.Tensor]]:
         if targets is not None:
@@ -719,6 +747,8 @@ class RoIHeads(nn.Module):
         regression_targets = None
         _matched_idxs = None  # noqa: F841

+        image_shapes: list[tuple[int, int]] = [(int(s[0]), int(s[1])) for s in image_sizes]  # TorchScript
+
         box_features = self.box_roi_pool(features, proposals, image_shapes)
         box_features = self.box_head(box_features)
         class_logits, box_regression = self.box_predictor(box_features)
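RoIHeads now receives the batched `image_sizes` tensor and converts it to a plain `list[tuple[int, int]]` before RoI pooling, since `MultiScaleRoIAlign` expects Python integer (height, width) pairs. In isolation the conversion is simply:

    import torch

    image_sizes = torch.tensor([[480, 640], [512, 512]])  # (num_images, 2) as (height, width)
    image_shapes: list[tuple[int, int]] = [(int(s[0]), int(s[1])) for s in image_sizes]
    print(image_shapes)  # [(480, 640), (512, 512)]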
@@ -735,7 +765,7 @@ class RoIHeads(nn.Module):
             losses = {"loss_classifier": loss_classifier, "loss_box_reg": loss_box_reg}

         else:
-            boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals,
+            boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals, image_sizes)
             num_images = len(boxes)
             for i in range(num_images):
                 result.append(
@@ -753,6 +783,7 @@ class RoIHeads(nn.Module):
 class Faster_RCNN(DetectionBaseNet):
     default_size = (640, 640)
     auto_register = True
+    exportable = False

     # pylint: disable=too-many-locals
     def __init__(
@@ -812,6 +843,7 @@ class Faster_RCNN(DetectionBaseNet):
             rpn_post_nms_top_n,
             rpn_nms_thresh,
             score_thresh=rpn_score_thresh,
+            export_mode=self.export_mode,
         )
         box_roi_pool = MultiScaleRoIAlign(
             featmap_names=self.backbone.return_stages,
@@ -862,7 +894,7 @@ class Faster_RCNN(DetectionBaseNet):
         x: torch.Tensor,
         targets: Optional[list[dict[str, torch.Tensor]]] = None,
         masks: Optional[torch.Tensor] = None,
-        image_sizes: Optional[list[
+        image_sizes: Optional[list[tuple[int, int]]] = None,
     ) -> tuple[list[dict[str, torch.Tensor]], dict[str, torch.Tensor]]:
         self._input_check(targets)
         images = self._to_img_list(x, image_sizes)
birder/net/detection/fcos.py
CHANGED
@@ -26,6 +26,7 @@ from birder.net.base import DetectorBackbone
 from birder.net.detection.base import AnchorGenerator
 from birder.net.detection.base import BackboneWithFPN
 from birder.net.detection.base import DetectionBaseNet
+from birder.net.detection.base import clip_boxes_to_image
 from birder.ops.soft_nms import SoftNMS


@@ -326,6 +327,9 @@ class FCOS(DetectionBaseNet):

         self.box_coder = BoxLinearCoder(normalize_by_size=True)

+        if self.export_mode is False:
+            self.forward = torch.compiler.disable(recursive=False)(self.forward)  # type: ignore[method-assign]
+
     def reset_classifier(self, num_classes: int) -> None:
         self.num_classes = num_classes

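Wrapping `self.forward` with `torch.compiler.disable(recursive=False)` skips compilation of the forward frame itself (which contains data-dependent control flow such as NMS) while still allowing torch.compile to optimize the submodules it calls. A rough, self-contained illustration of the pattern (not birder code):

    import torch
    from torch import nn

    class TinyDetector(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.backbone = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU())
            # Keep forward itself eager, but let torch.compile handle the callees
            self.forward = torch.compiler.disable(recursive=False)(self.forward)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            feats = self.backbone(x)
            if feats.mean() > 0:  # data-dependent branch that would break a traced graph
                return feats * 2
            return feats

    model = torch.compile(TinyDetector())
    _ = model(torch.randn(1, 3, 32, 32))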
@@ -344,6 +348,7 @@ class FCOS(DetectionBaseNet):
         for param in self.head.classification_head.parameters():
             param.requires_grad_(True)

+    @torch.compiler.disable()  # type: ignore[untyped-decorator]
     def compute_loss(
         self,
         targets: list[dict[str, torch.Tensor]],
@@ -400,20 +405,20 @@ class FCOS(DetectionBaseNet):
         self,
         head_outputs: dict[str, list[torch.Tensor]],
         anchors: list[list[torch.Tensor]],
-
+        image_sizes: torch.Tensor,
     ) -> list[dict[str, torch.Tensor]]:
         class_logits = head_outputs["cls_logits"]
         box_regression = head_outputs["bbox_regression"]
         box_ctrness = head_outputs["bbox_ctrness"]

-        num_images =
+        num_images = image_sizes.size(0)
         detections: list[dict[str, torch.Tensor]] = []
         for index in range(num_images):
             box_regression_per_image = [br[index] for br in box_regression]
             logits_per_image = [cl[index] for cl in class_logits]
             box_ctrness_per_image = [bc[index] for bc in box_ctrness]
             anchors_per_image = anchors[index]
-            image_shape =
+            image_shape = image_sizes[index]

             image_boxes_list = []
             image_scores_list = []
@@ -432,7 +437,7 @@ class FCOS(DetectionBaseNet):
                 topk_idxs = torch.where(keep_idxs)[0]

                 # Keep only topk scoring predictions
-                num_topk = min(self.topk_candidates,
+                num_topk = min(self.topk_candidates, topk_idxs.size(0))
                 scores_per_level, idxs = scores_per_level.topk(num_topk)
                 topk_idxs = topk_idxs[idxs]

@@ -443,7 +448,7 @@ class FCOS(DetectionBaseNet):
                 boxes_per_level = self.box_coder.decode(
                     box_regression_per_level[anchor_idxs], anchors_per_level[anchor_idxs]
                 )
-                boxes_per_level =
+                boxes_per_level = clip_boxes_to_image(boxes_per_level, image_shape)

                 image_boxes_list.append(boxes_per_level)
                 image_scores_list.append(scores_per_level)
@@ -453,38 +458,52 @@ class FCOS(DetectionBaseNet):
             image_scores = torch.concat(image_scores_list, dim=0)
             image_labels = torch.concat(image_labels_list, dim=0)

-
-
-
-
+            if self.export_mode is False:
+                # Non-maximum suppression
+                if self.soft_nms is not None:
+                    soft_scores, keep = self.soft_nms(image_boxes, image_scores, image_labels, score_threshold=0.001)
+                    image_scores[keep] = soft_scores
+                else:
+                    keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh)
+
+                keep = keep[: self.detections_per_img]
+
+                detections.append(
+                    {
+                        "boxes": image_boxes[keep],
+                        "scores": image_scores[keep],
+                        "labels": image_labels[keep],
+                    }
+                )
             else:
-
+                detections.append(
+                    {
+                        "boxes": image_boxes,
+                        "scores": image_scores,
+                        "labels": image_labels,
+                    }
+                )

-
+        return detections

-
-
-
-
-            "labels": image_labels[keep],
-            }
-        )
+    def forward_net(self, x: torch.Tensor) -> tuple[list[torch.Tensor], dict[str, torch.Tensor]]:
+        features: dict[str, torch.Tensor] = self.backbone_with_fpn(x)
+        feature_list = list(features.values())
+        head_outputs = self.head(feature_list)

-        return
+        return (feature_list, head_outputs)

     def forward(
         self,
         x: torch.Tensor,
         targets: Optional[list[dict[str, torch.Tensor]]] = None,
         masks: Optional[torch.Tensor] = None,
-        image_sizes: Optional[list[
+        image_sizes: Optional[list[tuple[int, int]]] = None,
     ) -> tuple[list[dict[str, torch.Tensor]], dict[str, torch.Tensor]]:
         self._input_check(targets)
         images = self._to_img_list(x, image_sizes)

-
-        feature_list = list(features.values())
-        head_outputs = self.head(feature_list)
+        feature_list, head_outputs = self.forward_net(x)
         anchors = self.anchor_generator(images, feature_list)

         # recover level sizes