birder 0.4.1__py3-none-any.whl → 0.4.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- birder/__init__.py +2 -0
- birder/common/fs_ops.py +81 -1
- birder/common/training_cli.py +12 -2
- birder/common/training_utils.py +73 -12
- birder/data/collators/detection.py +3 -1
- birder/datahub/_lib.py +15 -6
- birder/datahub/evaluation.py +591 -0
- birder/eval/__init__.py +0 -0
- birder/eval/__main__.py +74 -0
- birder/eval/_embeddings.py +50 -0
- birder/eval/adversarial.py +315 -0
- birder/eval/benchmarks/__init__.py +0 -0
- birder/eval/benchmarks/awa2.py +357 -0
- birder/eval/benchmarks/bioscan5m.py +198 -0
- birder/eval/benchmarks/fishnet.py +318 -0
- birder/eval/benchmarks/flowers102.py +210 -0
- birder/eval/benchmarks/fungiclef.py +261 -0
- birder/eval/benchmarks/nabirds.py +202 -0
- birder/eval/benchmarks/newt.py +262 -0
- birder/eval/benchmarks/plankton.py +255 -0
- birder/eval/benchmarks/plantdoc.py +259 -0
- birder/eval/benchmarks/plantnet.py +252 -0
- birder/eval/classification.py +235 -0
- birder/eval/methods/__init__.py +0 -0
- birder/eval/methods/ami.py +78 -0
- birder/eval/methods/knn.py +71 -0
- birder/eval/methods/linear.py +152 -0
- birder/eval/methods/mlp.py +178 -0
- birder/eval/methods/simpleshot.py +100 -0
- birder/eval/methods/svm.py +92 -0
- birder/inference/classification.py +23 -2
- birder/inference/detection.py +35 -15
- birder/net/_vit_configs.py +5 -0
- birder/net/cait.py +3 -3
- birder/net/coat.py +3 -3
- birder/net/cswin_transformer.py +2 -1
- birder/net/deit.py +1 -1
- birder/net/deit3.py +1 -1
- birder/net/detection/__init__.py +2 -0
- birder/net/detection/base.py +41 -18
- birder/net/detection/deformable_detr.py +74 -50
- birder/net/detection/detr.py +29 -26
- birder/net/detection/efficientdet.py +42 -25
- birder/net/detection/faster_rcnn.py +53 -21
- birder/net/detection/fcos.py +42 -23
- birder/net/detection/lw_detr.py +1204 -0
- birder/net/detection/plain_detr.py +60 -47
- birder/net/detection/retinanet.py +47 -35
- birder/net/detection/rt_detr_v1.py +49 -46
- birder/net/detection/rt_detr_v2.py +95 -102
- birder/net/detection/ssd.py +47 -31
- birder/net/detection/ssdlite.py +2 -2
- birder/net/detection/yolo_v2.py +33 -18
- birder/net/detection/yolo_v3.py +35 -33
- birder/net/detection/yolo_v4.py +35 -20
- birder/net/detection/yolo_v4_tiny.py +1 -2
- birder/net/edgevit.py +3 -3
- birder/net/efficientvit_msft.py +1 -1
- birder/net/flexivit.py +1 -1
- birder/net/hiera.py +44 -67
- birder/net/hieradet.py +2 -2
- birder/net/maxvit.py +2 -2
- birder/net/mim/fcmae.py +2 -2
- birder/net/mim/mae_hiera.py +9 -16
- birder/net/mnasnet.py +2 -2
- birder/net/nextvit.py +4 -4
- birder/net/resnext.py +2 -2
- birder/net/rope_deit3.py +2 -2
- birder/net/rope_flexivit.py +2 -2
- birder/net/rope_vit.py +2 -2
- birder/net/simple_vit.py +1 -1
- birder/net/squeezenet.py +1 -1
- birder/net/ssl/capi.py +32 -25
- birder/net/ssl/dino_v2.py +12 -15
- birder/net/ssl/franca.py +26 -19
- birder/net/van.py +2 -2
- birder/net/vit.py +21 -3
- birder/net/vit_parallel.py +1 -1
- birder/net/vit_sam.py +62 -16
- birder/net/xcit.py +1 -1
- birder/ops/msda.py +46 -16
- birder/scripts/benchmark.py +35 -8
- birder/scripts/predict.py +14 -1
- birder/scripts/predict_detection.py +7 -1
- birder/scripts/train.py +27 -11
- birder/scripts/train_capi.py +13 -10
- birder/scripts/train_detection.py +18 -7
- birder/scripts/train_franca.py +10 -2
- birder/scripts/train_kd.py +28 -11
- birder/tools/adversarial.py +5 -0
- birder/tools/convert_model.py +101 -43
- birder/tools/quantize_model.py +33 -16
- birder/version.py +1 -1
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/METADATA +17 -10
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/RECORD +99 -75
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/WHEEL +1 -1
- birder/scripts/evaluate.py +0 -176
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/entry_points.txt +0 -0
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/licenses/LICENSE +0 -0
- {birder-0.4.1.dist-info → birder-0.4.4.dist-info}/top_level.txt +0 -0
birder/net/detection/plain_detr.py

@@ -301,14 +301,11 @@ class GlobalDecoderLayer(nn.Module):
 
 
 class GlobalDecoder(nn.Module):
-    def __init__(
-        self, decoder_layer: nn.Module, num_layers: int, norm: nn.Module, return_intermediate: bool, d_model: int
-    ) -> None:
+    def __init__(self, decoder_layer: nn.Module, num_layers: int, norm: nn.Module, d_model: int) -> None:
         super().__init__()
         self.layers = _get_clones(decoder_layer, num_layers)
         self.num_layers = num_layers
         self.norm = norm
-        self.return_intermediate = return_intermediate
         self.d_model = d_model
 
         self.bbox_embed: Optional[nn.ModuleList] = None

@@ -339,6 +336,7 @@ class GlobalDecoder(nn.Module):
         reference_points: torch.Tensor,
         spatial_shape: tuple[int, int],
         memory_key_padding_mask: Optional[torch.Tensor] = None,
+        return_intermediates: bool = True,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         output = tgt
         intermediate = []

@@ -364,14 +362,14 @@
                 new_reference_points = new_reference_points.sigmoid()
                 reference_points = new_reference_points.detach()
 
-                if …
+                if return_intermediates is True:
                     intermediate.append(output_for_pred)
                     intermediate_reference_points.append(new_reference_points)
 
-            if …
+            if return_intermediates is True:
                 return torch.stack(intermediate), torch.stack(intermediate_reference_points)
 
-            return output_for_pred
+            return output_for_pred, new_reference_points
 
         for layer in self.layers:
             reference_points_input = reference_points.detach().clamp(0, 1)

@@ -388,14 +386,14 @@
 
             output_for_pred = self.norm(output)
 
-            if …
+            if return_intermediates is True:
                 intermediate.append(output_for_pred)
                 intermediate_reference_points.append(reference_points)
 
-        if …
+        if return_intermediates is True:
             return torch.stack(intermediate), torch.stack(intermediate_reference_points)
 
-        return output_for_pred
+        return output_for_pred, reference_points
 
 
 class TransformerEncoderLayer(nn.Module):

@@ -467,7 +465,6 @@ class Plain_DETR(DetectionBaseNet):
         hidden_dim = 256
         num_heads = 8
         dropout = 0.0
-        return_intermediate = True
         dim_feedforward: int = self.config.get("dim_feedforward", 2048)
         num_encoder_layers: int = self.config["num_encoder_layers"]
         num_decoder_layers: int = self.config["num_decoder_layers"]

@@ -516,19 +513,18 @@ class Plain_DETR(DetectionBaseNet):
             decoder_layer,
             num_decoder_layers,
             decoder_norm,
-            return_intermediate=return_intermediate,
             d_model=hidden_dim,
         )
 
         self.class_embed = nn.Linear(hidden_dim, self.num_classes)
         self.bbox_embed = MLP(hidden_dim, [hidden_dim, hidden_dim, 4], activation_layer=nn.ReLU)
-        self.query_embed = nn.…
+        self.query_embed = nn.Parameter(torch.empty(self.num_queries, hidden_dim * 2))
         self.reference_point_head = MLP(hidden_dim, [hidden_dim, hidden_dim, 4], activation_layer=nn.ReLU)
         self.input_proj = nn.Conv2d(
             self.backbone.return_channels[-1], hidden_dim, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)
         )
         self.pos_enc = PositionEmbeddingSine(hidden_dim // 2, normalize=True)
-        self.matcher = HungarianMatcher(cost_class=2, cost_bbox=5, cost_giou=2)
+        self.matcher = HungarianMatcher(cost_class=2.0, cost_bbox=5.0, cost_giou=2.0)
 
         if box_refine is True:
             self.class_embed = _get_clones(self.class_embed, num_decoder_layers)

@@ -554,6 +550,7 @@ class Plain_DETR(DetectionBaseNet):
             if idx == 0:
                 nn.init.constant_(last_linear.bias[2:], -2.0)  # Small initial wh
 
+        nn.init.normal_(self.query_embed)
         ref_last_linear = [m for m in self.reference_point_head.modules() if isinstance(m, nn.Linear)][-1]
         nn.init.zeros_(ref_last_linear.weight)
         nn.init.zeros_(ref_last_linear.bias)

@@ -576,7 +573,8 @@ class Plain_DETR(DetectionBaseNet):
         for param in self.class_embed.parameters():
             param.requires_grad_(True)
 
-…
+    @staticmethod
+    def _get_src_permutation_idx(indices: list[tuple[torch.Tensor, torch.Tensor]]) -> tuple[torch.Tensor, torch.Tensor]:
         batch_idx = torch.concat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
         src_idx = torch.concat([src for (src, _) in indices])
         return (batch_idx, src_idx)

@@ -585,7 +583,7 @@ class Plain_DETR(DetectionBaseNet):
         self,
         cls_logits: torch.Tensor,
         targets: list[dict[str, torch.Tensor]],
-        indices: list[torch.Tensor],
+        indices: list[tuple[torch.Tensor, torch.Tensor]],
         num_boxes: int,
     ) -> torch.Tensor:
         idx = self._get_src_permutation_idx(indices)

@@ -610,7 +608,7 @@ class Plain_DETR(DetectionBaseNet):
         self,
         box_output: torch.Tensor,
         targets: list[dict[str, torch.Tensor]],
-        indices: list[torch.Tensor],
+        indices: list[tuple[torch.Tensor, torch.Tensor]],
         num_boxes: int,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         idx = self._get_src_permutation_idx(indices)

@@ -646,7 +644,7 @@ class Plain_DETR(DetectionBaseNet):
         if training_utils.is_dist_available_and_initialized() is True:
             torch.distributed.all_reduce(num_boxes)
 
-        num_boxes = torch.clamp(num_boxes / training_utils.get_world_size(), min=1)
+        num_boxes = torch.clamp(num_boxes / training_utils.get_world_size(), min=1)
 
         loss_ce_list = []
         loss_bbox_list = []

@@ -697,20 +695,17 @@ class Plain_DETR(DetectionBaseNet):
         return losses
 
     def postprocess_detections(
-        self, class_logits: torch.Tensor, box_regression: torch.Tensor, …
+        self, class_logits: torch.Tensor, box_regression: torch.Tensor, image_sizes: torch.Tensor
     ) -> list[dict[str, torch.Tensor]]:
         prob = class_logits.sigmoid()
         scores, labels = prob.max(-1)
         labels = labels + 1  # Background offset
 
-        # TorchScript doesn't support creating tensor from tuples, convert everything to lists
-        target_sizes = torch.tensor([list(s) for s in image_shapes], device=class_logits.device)
-…
         # Convert to [x0, y0, x1, y1] format
         boxes = box_ops.box_convert(box_regression, in_fmt="cxcywh", out_fmt="xyxy")
 
         # Convert from relative [0, 1] to absolute [0, height] coordinates
-        img_h, img_w = …
+        img_h, img_w = image_sizes.unbind(1)
         scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
         boxes = boxes * scale_fct[:, None, :]
 

@@ -735,17 +730,7 @@ class Plain_DETR(DetectionBaseNet):
 
         return detections
 
-…
-    def forward(
-        self,
-        x: torch.Tensor,
-        targets: Optional[list[dict[str, torch.Tensor]]] = None,
-        masks: Optional[torch.Tensor] = None,
-        image_sizes: Optional[list[list[int]]] = None,
-    ) -> tuple[list[dict[str, torch.Tensor]], dict[str, torch.Tensor]]:
-        self._input_check(targets)
-        images = self._to_img_list(x, image_sizes)
-…
+    def forward_net(self, x: torch.Tensor, masks: Optional[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
         features: dict[str, torch.Tensor] = self.backbone.detection_features(x)
         src = features[self.backbone.return_stages[-1]]
         src = self.input_proj(src)

@@ -772,7 +757,7 @@ class Plain_DETR(DetectionBaseNet):
         else:
             num_queries_to_use = self.num_queries_one2one
 
-        query_embed = self.query_embed
+        query_embed = self.query_embed[:num_queries_to_use]
         query_embed, query_pos = torch.split(query_embed, self.hidden_dim, dim=1)
         query_embed = query_embed.unsqueeze(0).expand(B, -1, -1)
         query_pos = query_pos.unsqueeze(0).expand(B, -1, -1)

@@ -787,25 +772,52 @@
             reference_points=reference_points,
             spatial_shape=(H, W),
             memory_key_padding_mask=mask_flatten,
+            return_intermediates=self.training is True,
         )
 
-…
-…
-…
-…
-…
+        if self.training is True:
+            outputs_classes = []
+            outputs_coords = []
+            for lvl, (class_embed, bbox_embed) in enumerate(zip(self.class_embed, self.bbox_embed)):
+                outputs_class = class_embed(hs[lvl])
+                outputs_classes.append(outputs_class)
+
+                if self.box_refine is True:
+                    outputs_coord = inter_references[lvl]
+                else:
+                    tmp = bbox_embed(hs[lvl])
+                    tmp = tmp + inverse_sigmoid(reference_points)
+                    outputs_coord = tmp.sigmoid()
+
+                outputs_coords.append(outputs_coord)
+
+            outputs_class = torch.stack(outputs_classes)
+            outputs_coord = torch.stack(outputs_coords)
+        else:
+            class_embed = self.class_embed[-1]
+            bbox_embed = self.bbox_embed[-1]
+            outputs_class = class_embed(hs)
 
             if self.box_refine is True:
-                outputs_coord = inter_references
+                outputs_coord = inter_references
             else:
-                tmp = bbox_embed(hs…
+                tmp = bbox_embed(hs)
                 tmp = tmp + inverse_sigmoid(reference_points)
                 outputs_coord = tmp.sigmoid()
 
-…
+        return (outputs_class, outputs_coord)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        targets: Optional[list[dict[str, torch.Tensor]]] = None,
+        masks: Optional[torch.Tensor] = None,
+        image_sizes: Optional[list[tuple[int, int]]] = None,
+    ) -> tuple[list[dict[str, torch.Tensor]], dict[str, torch.Tensor]]:
+        self._input_check(targets)
+        images = self._to_img_list(x, image_sizes)
 
-        outputs_class =…
-        outputs_coord = torch.stack(outputs_coords)
+        outputs_class, outputs_coord = self.forward_net(x, masks)
 
         losses = {}
         detections: list[dict[str, torch.Tensor]] = []

@@ -815,7 +827,8 @@ class Plain_DETR(DetectionBaseNet):
             for idx, target in enumerate(targets):
                 boxes = target["boxes"]
                 boxes = box_ops.box_convert(boxes, in_fmt="xyxy", out_fmt="cxcywh")
-…
+                scale = images.image_sizes[idx].flip(0).repeat(2).float()  # flip to [W, H], repeat to [W, H, W, H]
+                boxes = boxes / scale
                 targets[idx]["boxes"] = boxes
                 targets[idx]["labels"] = target["labels"] - 1  # No background
 

@@ -835,7 +848,7 @@ class Plain_DETR(DetectionBaseNet):
             )
 
         else:
-            detections = self.postprocess_detections(outputs_class…
+            detections = self.postprocess_detections(outputs_class, outputs_coord, images.image_sizes)
 
         return (detections, losses)
 

birder/net/detection/retinanet.py

@@ -30,6 +30,7 @@ from birder.net.detection.base import BackboneWithSimpleFPN
 from birder.net.detection.base import BoxCoder
 from birder.net.detection.base import DetectionBaseNet
 from birder.net.detection.base import Matcher
+from birder.net.detection.base import clip_boxes_to_image
 from birder.ops.soft_nms import SoftNMS
 
 

@@ -63,7 +64,7 @@ class RetinaNetClassificationHead(nn.Module):
             if isinstance(layer, nn.Conv2d):
                 nn.init.normal_(layer.weight, std=0.01)
                 if layer.bias is not None:
-                    nn.init.…
+                    nn.init.zeros_(layer.bias)
 
         self.cls_logits = nn.Conv2d(in_channels, num_anchors * num_classes, kernel_size=3, stride=1, padding=1)
 

@@ -281,6 +282,11 @@ class RetinaNet(DetectionBaseNet):
         if soft_nms is True:
             self.soft_nms = SoftNMS()
 
+        self.score_thresh = score_thresh
+        self.nms_thresh = nms_thresh
+        self.detections_per_img = detections_per_img
+        self.topk_candidates = topk_candidates
+
         if feature_pyramid_type == "fpn":
             feature_pyramid: Callable[..., nn.Module] = BackboneWithFPN
             num_anchor_sizes = len(self.backbone.return_stages) + 2

@@ -314,10 +320,8 @@ class RetinaNet(DetectionBaseNet):
         self.proposal_matcher = Matcher(fg_iou_thresh, bg_iou_thresh, allow_low_quality_matches=True)
         self.box_coder = BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))
 
-        self.…
-…
-        self.detections_per_img = detections_per_img
-        self.topk_candidates = topk_candidates
+        if self.export_mode is False:
+            self.forward = torch.compiler.disable(recursive=False)(self.forward)  # type: ignore[method-assign]
 
     def reset_classifier(self, num_classes: int) -> None:
         self.num_classes = num_classes

@@ -341,10 +345,7 @@ class RetinaNet(DetectionBaseNet):
     @torch.jit.unused  # type: ignore[untyped-decorator]
     @torch.compiler.disable()  # type: ignore[untyped-decorator]
     def compute_loss(
-        self,
-        targets: list[dict[str, torch.Tensor]],
-        head_outputs: dict[str, torch.Tensor],
-        anchors: list[torch.Tensor],
+        self, targets: list[dict[str, torch.Tensor]], head_outputs: dict[str, torch.Tensor], anchors: list[torch.Tensor]
     ) -> dict[str, torch.Tensor]:
         matched_idxs = []
         for idx, (anchors_per_image, targets_per_image) in enumerate(zip(anchors, targets)):

@@ -362,22 +363,19 @@ class RetinaNet(DetectionBaseNet):
 
     # pylint: disable=too-many-locals
     def postprocess_detections(
-        self,
-        head_outputs: dict[str, list[torch.Tensor]],
-        anchors: list[list[torch.Tensor]],
-        image_shapes: list[tuple[int, int]],
+        self, head_outputs: dict[str, list[torch.Tensor]], anchors: list[list[torch.Tensor]], image_sizes: torch.Tensor
     ) -> list[dict[str, torch.Tensor]]:
         class_logits = head_outputs["cls_logits"]
         box_regression = head_outputs["bbox_regression"]
 
-        num_images = …
+        num_images = image_sizes.size(0)
 
         detections: list[dict[str, torch.Tensor]] = []
         for index in range(num_images):
            box_regression_per_image = [br[index] for br in box_regression]
            logits_per_image = [cl[index] for cl in class_logits]
            anchors_per_image = anchors[index]
-            image_shape = …
+            image_shape = image_sizes[index]
 
            image_boxes_list = []
            image_scores_list = []

@@ -394,7 +392,7 @@ class RetinaNet(DetectionBaseNet):
                 topk_idxs = torch.where(keep_idxs)[0]
 
                 # Keep only topk scoring predictions
-                num_topk = min(self.topk_candidates, …
+                num_topk = min(self.topk_candidates, topk_idxs.size(0))
                 scores_per_level, idxs = scores_per_level.topk(num_topk)
                 topk_idxs = topk_idxs[idxs]
 

@@ -405,7 +403,7 @@ class RetinaNet(DetectionBaseNet):
                 boxes_per_level = self.box_coder.decode_single(
                     box_regression_per_level[anchor_idxs], anchors_per_level[anchor_idxs]
                 )
-                boxes_per_level = …
+                boxes_per_level = clip_boxes_to_image(boxes_per_level, image_shape)
 
                 image_boxes_list.append(boxes_per_level)
                 image_scores_list.append(scores_per_level)

@@ -415,24 +413,40 @@ class RetinaNet(DetectionBaseNet):
             image_scores = torch.concat(image_scores_list, dim=0)
             image_labels = torch.concat(image_labels_list, dim=0)
 
-…
-…
-…
-…
+            if self.export_mode is False:
+                # Non-maximum suppression
+                if self.soft_nms is not None:
+                    soft_scores, keep = self.soft_nms(image_boxes, image_scores, image_labels, score_threshold=0.001)
+                    image_scores[keep] = soft_scores
+                else:
+                    keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh)
+
+                keep = keep[: self.detections_per_img]
+
+                detections.append(
+                    {
+                        "boxes": image_boxes[keep],
+                        "scores": image_scores[keep],
+                        "labels": image_labels[keep],
+                    }
+                )
             else:
-…
+                detections.append(
+                    {
+                        "boxes": image_boxes,
+                        "scores": image_scores,
+                        "labels": image_labels,
+                    }
+                )
 
-…
+        return detections
 
-…
-…
-…
-…
-                    "labels": image_labels[keep],
-                }
-            )
+    def forward_net(self, x: torch.Tensor) -> tuple[list[torch.Tensor], dict[str, torch.Tensor]]:
+        features: dict[str, torch.Tensor] = self.backbone_with_fpn(x)
+        feature_list = list(features.values())
+        head_outputs = self.head(feature_list)
 
-        return…
+        return (feature_list, head_outputs)
 
     # pylint: disable=invalid-name
     def forward(

@@ -440,14 +454,12 @@ class RetinaNet(DetectionBaseNet):
         x: torch.Tensor,
         targets: Optional[list[dict[str, torch.Tensor]]] = None,
         masks: Optional[torch.Tensor] = None,
-        image_sizes: Optional[list[…
+        image_sizes: Optional[list[tuple[int, int]]] = None,
     ) -> tuple[list[dict[str, torch.Tensor]], dict[str, torch.Tensor]]:
         self._input_check(targets)
         images = self._to_img_list(x, image_sizes)
 
-…
-        feature_list = list(features.values())
-        head_outputs = self.head(feature_list)
+        feature_list, head_outputs = self.forward_net(x)
         anchors = self.anchor_generator(images, feature_list)
 
         losses: dict[str, torch.Tensor] = {}

birder/net/detection/rt_detr_v1.py

@@ -47,9 +47,6 @@ def get_contrastive_denoising_training_group(  # pylint: disable=too-many-locals
     label_noise_ratio: float,
     box_noise_scale: float,
 ) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], Optional[dict[str, Any]]]:
-    if num_denoising_queries <= 0:
-        return (None, None, None, None)
-…
     num_ground_truths = [len(t["labels"]) for t in targets]
     device = targets[0]["labels"].device
 

@@ -596,18 +593,18 @@ class RT_DETRDecoder(nn.Module):
 
         # Gather reference points
         reference_points_unact = enc_outputs_coord_unact.gather(
-            dim=1, index=topk_ind.unsqueeze(-1).…
+            dim=1, index=topk_ind.unsqueeze(-1).expand(-1, -1, enc_outputs_coord_unact.shape[-1])
         )
 
         enc_topk_bboxes = reference_points_unact.sigmoid()
 
         # Gather encoder logits for loss computation
         enc_topk_logits = enc_outputs_class.gather(
-            dim=1, index=topk_ind.unsqueeze(-1).…
+            dim=1, index=topk_ind.unsqueeze(-1).expand(-1, -1, enc_outputs_class.shape[-1])
         )
 
         # Extract region features
-        target = output_memory.gather(dim=1, index=topk_ind.unsqueeze(-1).…
+        target = output_memory.gather(dim=1, index=topk_ind.unsqueeze(-1).expand(-1, -1, output_memory.shape[-1]))
         target = target.detach()
 
         return (target, reference_points_unact.detach(), enc_topk_bboxes, enc_topk_logits)

@@ -621,6 +618,7 @@
         denoising_bbox_unact: Optional[torch.Tensor] = None,
         attn_mask: Optional[torch.Tensor] = None,
         padding_mask: Optional[list[torch.Tensor]] = None,
+        return_intermediates: bool = True,
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         memory = []
         mask_flatten = []

@@ -648,12 +646,12 @@
         level_start_index_tensor = torch.tensor(level_start_index, dtype=torch.long, device=memory.device)
 
         # Decoder forward
-…
-…
+        bboxes_list: list[torch.Tensor] = []
+        logits_list: list[torch.Tensor] = []
         reference_points = init_ref_points_unact.sigmoid()
         for decoder_layer, bbox_head, class_head in zip(self.layers, self.bbox_embed, self.class_embed):
             query_pos = self.query_pos_head(reference_points)
-            reference_points_input = reference_points.unsqueeze(2).…
+            reference_points_input = reference_points.unsqueeze(2).expand(-1, -1, len(spatial_shapes), -1)
             target = decoder_layer(
                 target,
                 query_pos,

@@ -663,6 +661,7 @@
                 level_start_index_tensor,
                 memory_padding_mask,
                 attn_mask,
+                src_shapes=spatial_shapes,
             )
 
             bbox_delta = bbox_head(target)

@@ -672,14 +671,19 @@
             # Classification
             class_logits = class_head(target)
 
-…
-…
+            if return_intermediates is True:
+                bboxes_list.append(new_reference_points)
+                logits_list.append(class_logits)
 
             # Update reference points for next layer
             reference_points = new_reference_points.detach()
 
-…
-…
+        if return_intermediates is True:
+            out_bboxes = torch.stack(bboxes_list)
+            out_logits = torch.stack(logits_list)
+        else:
+            out_bboxes = new_reference_points
+            out_logits = class_logits
 
         return (out_bboxes, out_logits, enc_topk_bboxes, enc_topk_logits)

@@ -743,7 +747,7 @@ class RT_DETR_v1(DetectionBaseNet):
         self.decoder = RT_DETRDecoder(
             hidden_dim=hidden_dim,
             num_classes=self.num_classes,
-            num_queries=num_queries,
+            num_queries=self.num_queries,
             num_decoder_layers=num_decoder_layers,
             num_levels=self.num_levels,
             num_heads=num_heads,

@@ -810,7 +814,8 @@ class RT_DETR_v1(DetectionBaseNet):
         for param in self.denoising_class_embed.parameters():
             param.requires_grad_(True)
 
-…
+    @staticmethod
+    def _get_src_permutation_idx(indices: list[tuple[torch.Tensor, torch.Tensor]]) -> tuple[torch.Tensor, torch.Tensor]:
         batch_idx = torch.concat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
         src_idx = torch.concat([src for (src, _) in indices])
         return (batch_idx, src_idx)

@@ -820,7 +825,7 @@ class RT_DETR_v1(DetectionBaseNet):
         cls_logits: torch.Tensor,
         box_output: torch.Tensor,
         targets: list[dict[str, torch.Tensor]],
-        indices: list[torch.Tensor],
+        indices: list[tuple[torch.Tensor, torch.Tensor]],
         num_boxes: float,
     ) -> torch.Tensor:
         idx = self._get_src_permutation_idx(indices)

@@ -859,7 +864,7 @@ class RT_DETR_v1(DetectionBaseNet):
         self,
         box_output: torch.Tensor,
         targets: list[dict[str, torch.Tensor]],
-        indices: list[torch.Tensor],
+        indices: list[tuple[torch.Tensor, torch.Tensor]],
         num_boxes: float,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         idx = self._get_src_permutation_idx(indices)

@@ -927,8 +932,6 @@ class RT_DETR_v1(DetectionBaseNet):
 
         return (loss_ce_dn, loss_bbox_dn, loss_giou_dn)
 
-    @torch.jit.unused  # type: ignore[untyped-decorator]
-    @torch.compiler.disable()  # type: ignore[untyped-decorator]
     def _compute_loss_from_outputs(  # pylint: disable=too-many-locals
         self,
         targets: list[dict[str, torch.Tensor]],

@@ -946,7 +949,7 @@ class RT_DETR_v1(DetectionBaseNet):
         if training_utils.is_dist_available_and_initialized() is True:
             torch.distributed.all_reduce(num_boxes)
 
-        num_boxes = torch.clamp(num_boxes / training_utils.get_world_size(), min=1)
+        num_boxes = torch.clamp(num_boxes / training_utils.get_world_size(), min=1)
 
         loss_ce_list = []
         loss_bbox_list = []

@@ -1001,11 +1004,11 @@ class RT_DETR_v1(DetectionBaseNet):
         images: Any,
         masks: Optional[list[torch.Tensor]] = None,
     ) -> dict[str, torch.Tensor]:
-        device = encoder_features[0].device
         for idx, target in enumerate(targets):
             boxes = target["boxes"]
             boxes = box_ops.box_convert(boxes, in_fmt="xyxy", out_fmt="cxcywh")
-…
+            scale = images.image_sizes[idx].flip(0).repeat(2).float()  # flip to [W, H], repeat to [W, H, W, H]
+            boxes = boxes / scale
             targets[idx]["boxes"] = boxes
             targets[idx]["labels"] = target["labels"] - 1  # No background
 

@@ -1038,7 +1041,7 @@ class RT_DETR_v1(DetectionBaseNet):
         return losses
 
     def postprocess_detections(
-        self, class_logits: torch.Tensor, box_regression: torch.Tensor,…
+        self, class_logits: torch.Tensor, box_regression: torch.Tensor, image_sizes: torch.Tensor
     ) -> list[dict[str, torch.Tensor]]:
         prob = class_logits.sigmoid()
         topk_values, topk_indexes = torch.topk(prob.view(class_logits.shape[0], -1), k=self.decoder.num_queries, dim=1)

@@ -1047,14 +1050,12 @@ class RT_DETR_v1(DetectionBaseNet):
         labels = topk_indexes % class_logits.shape[2]
         labels += 1  # Background offset
 
-        target_sizes = torch.tensor(image_shapes, device=class_logits.device)
-…
         # Convert to [x0, y0, x1, y1] format
         boxes = box_ops.box_convert(box_regression, in_fmt="cxcywh", out_fmt="xyxy")
-        boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).…
+        boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).expand(-1, -1, 4))
 
         # Convert from relative [0, 1] to absolute [0, height] coordinates
-        img_h, img_w = …
+        img_h, img_w = image_sizes.unbind(1)
         scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
         boxes = boxes * scale_fct[:, None, :]
 

@@ -1090,32 +1091,34 @@ class RT_DETR_v1(DetectionBaseNet):
 
             return (None, None, None, None)
 
+    def forward_net(
+        self, x: torch.Tensor, masks: Optional[torch.Tensor]
+    ) -> tuple[list[torch.Tensor], Optional[list[torch.Tensor]]]:
+        features: dict[str, torch.Tensor] = self.backbone.detection_features(x)
+        feature_list = list(features.values())
+
+        mask_list: Optional[list[torch.Tensor]] = None
+        if masks is not None:
+            mask_list = []
+            for feat in feature_list:
+                m = F.interpolate(masks[None].float(), size=feat.shape[-2:], mode="nearest").to(torch.bool)[0]
+                mask_list.append(m)
+
+        encoder_features = self.encoder(feature_list, masks=mask_list)
+
+        return (encoder_features, mask_list)
+
     def forward(
         self,
         x: torch.Tensor,
         targets: Optional[list[dict[str, torch.Tensor]]] = None,
         masks: Optional[torch.Tensor] = None,
-        image_sizes: Optional[list[…
+        image_sizes: Optional[list[tuple[int, int]]] = None,
     ) -> tuple[list[dict[str, torch.Tensor]], dict[str, torch.Tensor]]:
         self._input_check(targets)
         images = self._to_img_list(x, image_sizes)
 
-…
-        features: dict[str, torch.Tensor] = self.backbone.detection_features(x)
-        feature_list = list(features.values())
-…
-        # Hybrid encoder
-        mask_list: list[torch.Tensor] = []
-        for feat in feature_list:
-            if masks is not None:
-                mask_size = feat.shape[-2:]
-                m = F.interpolate(masks[None].float(), size=mask_size, mode="nearest").to(torch.bool)[0]
-            else:
-                B, _, H, W = feat.size()
-                m = torch.zeros(B, H, W, dtype=torch.bool, device=x.device)
-            mask_list.append(m)
-…
-        encoder_features = self.encoder(feature_list, masks=mask_list)
+        encoder_features, mask_list = self.forward_net(x, masks)
 
         # Prepare spatial shapes and level start index
         spatial_shapes: list[list[int]] = []

@@ -1136,9 +1139,9 @@ class RT_DETR_v1(DetectionBaseNet):
         else:
             # Inference path - no CDN
             out_bboxes, out_logits, _, _ = self.decoder(
-                encoder_features, spatial_shapes, level_start_index, padding_mask=mask_list
+                encoder_features, spatial_shapes, level_start_index, padding_mask=mask_list, return_intermediates=False
             )
-            detections = self.postprocess_detections(out_logits…
+            detections = self.postprocess_detections(out_logits, out_bboxes, images.image_sizes)
 
         return (detections, losses)
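
For the plain_detr.py hunks above, postprocess_detections now takes an image_sizes tensor with one [H, W] row per image instead of a list of shapes. A minimal illustrative sketch of that scaling step, not taken from the package; the tensor shapes and the box_ops alias (torchvision.ops.boxes) are assumptions for illustration:

```python
# Hypothetical sketch only: relative cxcywh boxes -> absolute xyxy boxes,
# mirroring the img_h/img_w unbind and scale_fct stacking in the diff above.
import torch
from torchvision.ops import boxes as box_ops  # assumed alias, matching the hunks

box_regression = torch.rand(2, 100, 4)                # (batch, queries, [cx, cy, w, h]) in [0, 1]
image_sizes = torch.tensor([[480, 640], [600, 800]])  # (batch, [H, W])

boxes = box_ops.box_convert(box_regression, in_fmt="cxcywh", out_fmt="xyxy")
img_h, img_w = image_sizes.unbind(1)
scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
boxes = boxes * scale_fct[:, None, :]
print(boxes.shape)  # torch.Size([2, 100, 4]), now in pixel coordinates
```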
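
For the retinanet.py hunks above, postprocessing now branches on export_mode: when it is False, candidates go through NMS (soft-NMS if configured) and are truncated to detections_per_img; otherwise the raw boxes, scores, and labels are returned. A small sketch of the non-export branch using torchvision's batched_nms; the soft-NMS path is omitted and the inputs are made up for illustration:

```python
# Hypothetical sketch only: the keep/truncate step from the export_mode=False branch.
import torch
from torchvision.ops import boxes as box_ops

image_boxes = torch.tensor([[0.0, 0.0, 10.0, 10.0], [1.0, 1.0, 11.0, 11.0], [50.0, 50.0, 60.0, 60.0]])
image_scores = torch.tensor([0.9, 0.8, 0.7])
image_labels = torch.tensor([1, 1, 2])
nms_thresh = 0.5
detections_per_img = 100

keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, nms_thresh)
keep = keep[:detections_per_img]
detection = {"boxes": image_boxes[keep], "scores": image_scores[keep], "labels": image_labels[keep]}
print(detection["boxes"].shape)  # torch.Size([2, 4]); the overlapping lower-score box is suppressed
```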
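
For the rt_detr_v1.py hunks above, several gather calls are completed with an unsqueeze(-1).expand(...) index so that top-k proposal indices select whole feature rows. A standalone sketch of that pattern; the shapes and variable names here are invented for illustration:

```python
# Hypothetical sketch only: gather top-k rows along dim=1 using an expanded index.
import torch

enc_outputs_coord_unact = torch.randn(2, 300, 4)          # (batch, proposals, 4)
enc_scores = torch.rand(2, 300)
topk_ind = torch.topk(enc_scores, k=100, dim=1).indices   # (batch, 100)

reference_points_unact = enc_outputs_coord_unact.gather(
    dim=1, index=topk_ind.unsqueeze(-1).expand(-1, -1, enc_outputs_coord_unact.shape[-1])
)
print(reference_points_unact.shape)  # torch.Size([2, 100, 4])
```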