birder 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- birder/adversarial/base.py +1 -1
- birder/adversarial/simba.py +4 -4
- birder/common/cli.py +1 -1
- birder/common/fs_ops.py +11 -11
- birder/common/lib.py +2 -2
- birder/common/masking.py +3 -3
- birder/common/training_cli.py +24 -2
- birder/common/training_utils.py +28 -4
- birder/data/collators/detection.py +9 -1
- birder/data/transforms/detection.py +27 -8
- birder/data/transforms/mosaic.py +1 -1
- birder/datahub/classification.py +3 -3
- birder/inference/classification.py +3 -3
- birder/inference/data_parallel.py +1 -1
- birder/inference/detection.py +5 -5
- birder/inference/wbf.py +1 -1
- birder/introspection/attention_rollout.py +5 -5
- birder/introspection/feature_pca.py +4 -4
- birder/introspection/gradcam.py +1 -1
- birder/introspection/guided_backprop.py +2 -2
- birder/introspection/transformer_attribution.py +3 -3
- birder/layers/attention_pool.py +2 -2
- birder/model_registry/model_registry.py +2 -1
- birder/net/__init__.py +2 -0
- birder/net/_rope_vit_configs.py +5 -0
- birder/net/_vit_configs.py +0 -13
- birder/net/alexnet.py +5 -5
- birder/net/base.py +28 -3
- birder/net/biformer.py +17 -17
- birder/net/cait.py +2 -2
- birder/net/cas_vit.py +1 -1
- birder/net/coat.py +15 -15
- birder/net/convnext_v1.py +2 -10
- birder/net/convnext_v1_iso.py +198 -0
- birder/net/convnext_v2.py +2 -10
- birder/net/crossformer.py +9 -9
- birder/net/crossvit.py +1 -1
- birder/net/cspnet.py +1 -1
- birder/net/cswin_transformer.py +10 -10
- birder/net/davit.py +10 -10
- birder/net/deit.py +56 -3
- birder/net/deit3.py +27 -15
- birder/net/detection/__init__.py +4 -0
- birder/net/detection/{yolo_anchors.py → _yolo_anchors.py} +5 -5
- birder/net/detection/base.py +6 -5
- birder/net/detection/deformable_detr.py +26 -28
- birder/net/detection/detr.py +9 -9
- birder/net/detection/efficientdet.py +9 -28
- birder/net/detection/faster_rcnn.py +22 -22
- birder/net/detection/fcos.py +8 -8
- birder/net/detection/plain_detr.py +852 -0
- birder/net/detection/retinanet.py +4 -4
- birder/net/detection/rt_detr_v1.py +81 -25
- birder/net/detection/rt_detr_v2.py +1147 -0
- birder/net/detection/ssd.py +5 -5
- birder/net/detection/yolo_v2.py +12 -12
- birder/net/detection/yolo_v3.py +19 -19
- birder/net/detection/yolo_v4.py +16 -16
- birder/net/detection/yolo_v4_tiny.py +3 -3
- birder/net/edgenext.py +3 -3
- birder/net/edgevit.py +10 -14
- birder/net/efficientformer_v1.py +1 -1
- birder/net/efficientvim.py +9 -9
- birder/net/efficientvit_mit.py +2 -2
- birder/net/efficientvit_msft.py +3 -3
- birder/net/fasternet.py +1 -1
- birder/net/fastvit.py +5 -12
- birder/net/flexivit.py +28 -15
- birder/net/focalnet.py +5 -9
- birder/net/gc_vit.py +11 -11
- birder/net/ghostnet_v1.py +1 -1
- birder/net/ghostnet_v2.py +1 -1
- birder/net/groupmixformer.py +12 -12
- birder/net/hgnet_v1.py +1 -1
- birder/net/hgnet_v2.py +4 -4
- birder/net/hiera.py +6 -6
- birder/net/hieradet.py +9 -9
- birder/net/hornet.py +3 -3
- birder/net/iformer.py +4 -4
- birder/net/inception_next.py +4 -14
- birder/net/levit.py +3 -3
- birder/net/lit_v1.py +13 -15
- birder/net/lit_v1_tiny.py +9 -9
- birder/net/lit_v2.py +14 -15
- birder/net/maxvit.py +10 -22
- birder/net/metaformer.py +2 -2
- birder/net/mim/crossmae.py +5 -5
- birder/net/mim/fcmae.py +3 -5
- birder/net/mim/mae_hiera.py +7 -7
- birder/net/mim/mae_vit.py +3 -5
- birder/net/mim/simmim.py +2 -3
- birder/net/mobilenet_v4_hybrid.py +4 -4
- birder/net/mobileone.py +5 -12
- birder/net/mobilevit_v1.py +2 -2
- birder/net/mobilevit_v2.py +5 -9
- birder/net/mvit_v2.py +24 -24
- birder/net/nextvit.py +2 -2
- birder/net/pit.py +11 -26
- birder/net/pvt_v1.py +4 -4
- birder/net/pvt_v2.py +5 -11
- birder/net/regionvit.py +15 -15
- birder/net/regnet.py +1 -1
- birder/net/repghost.py +4 -5
- birder/net/repvgg.py +3 -5
- birder/net/repvit.py +2 -2
- birder/net/resnest.py +1 -1
- birder/net/rope_deit3.py +29 -15
- birder/net/rope_flexivit.py +28 -15
- birder/net/rope_vit.py +41 -23
- birder/net/sequencer2d.py +3 -4
- birder/net/shufflenet_v1.py +1 -1
- birder/net/shufflenet_v2.py +1 -1
- birder/net/simple_vit.py +47 -5
- birder/net/smt.py +7 -7
- birder/net/ssl/barlow_twins.py +1 -1
- birder/net/ssl/byol.py +2 -2
- birder/net/ssl/capi.py +3 -3
- birder/net/ssl/data2vec2.py +1 -1
- birder/net/ssl/dino_v2.py +11 -1
- birder/net/ssl/franca.py +26 -2
- birder/net/ssl/i_jepa.py +4 -4
- birder/net/ssl/mmcr.py +1 -1
- birder/net/swiftformer.py +1 -1
- birder/net/swin_transformer_v1.py +4 -5
- birder/net/swin_transformer_v2.py +4 -7
- birder/net/tiny_vit.py +3 -3
- birder/net/transnext.py +19 -19
- birder/net/uniformer.py +4 -4
- birder/net/vgg.py +1 -10
- birder/net/vit.py +38 -25
- birder/net/vit_parallel.py +35 -20
- birder/net/vit_sam.py +10 -10
- birder/net/vovnet_v2.py +1 -1
- birder/net/xcit.py +9 -7
- birder/ops/msda.py +4 -4
- birder/ops/swattention.py +10 -10
- birder/results/classification.py +3 -3
- birder/results/gui.py +8 -8
- birder/scripts/benchmark.py +37 -12
- birder/scripts/evaluate.py +1 -1
- birder/scripts/predict.py +3 -3
- birder/scripts/predict_detection.py +2 -2
- birder/scripts/train.py +63 -15
- birder/scripts/train_barlow_twins.py +10 -7
- birder/scripts/train_byol.py +10 -7
- birder/scripts/train_capi.py +15 -10
- birder/scripts/train_data2vec.py +10 -7
- birder/scripts/train_data2vec2.py +10 -7
- birder/scripts/train_detection.py +29 -14
- birder/scripts/train_dino_v1.py +13 -9
- birder/scripts/train_dino_v2.py +27 -14
- birder/scripts/train_dino_v2_dist.py +28 -15
- birder/scripts/train_franca.py +16 -9
- birder/scripts/train_i_jepa.py +12 -9
- birder/scripts/train_ibot.py +15 -11
- birder/scripts/train_kd.py +64 -17
- birder/scripts/train_mim.py +11 -8
- birder/scripts/train_mmcr.py +11 -8
- birder/scripts/train_rotnet.py +11 -7
- birder/scripts/train_simclr.py +10 -7
- birder/scripts/train_vicreg.py +10 -7
- birder/tools/adversarial.py +4 -4
- birder/tools/auto_anchors.py +5 -5
- birder/tools/avg_model.py +1 -1
- birder/tools/convert_model.py +30 -22
- birder/tools/det_results.py +1 -1
- birder/tools/download_model.py +1 -1
- birder/tools/ensemble_model.py +1 -1
- birder/tools/introspection.py +11 -2
- birder/tools/labelme_to_coco.py +2 -2
- birder/tools/model_info.py +12 -14
- birder/tools/pack.py +8 -8
- birder/tools/quantize_model.py +53 -4
- birder/tools/results.py +2 -2
- birder/tools/show_det_iterator.py +19 -6
- birder/tools/show_iterator.py +2 -2
- birder/tools/similarity.py +5 -5
- birder/tools/stats.py +4 -6
- birder/tools/voc_to_coco.py +1 -1
- birder/version.py +1 -1
- {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/METADATA +3 -3
- birder-0.4.1.dist-info/RECORD +300 -0
- {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/WHEEL +1 -1
- birder-0.4.0.dist-info/RECORD +0 -297
- {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/entry_points.txt +0 -0
- {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/licenses/LICENSE +0 -0
- {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/top_level.txt +0 -0
@@ -58,7 +58,7 @@ class HungarianMatcher(nn.Module):
         self, class_logits: torch.Tensor, box_regression: torch.Tensor, targets: list[dict[str, torch.Tensor]]
     ) -> list[torch.Tensor]:
         with torch.no_grad():
-
+            B, num_queries = class_logits.shape[:2]

             # We flatten to compute the cost matrices in a batch
             out_prob = class_logits.flatten(0, 1).sigmoid()  # [batch_size * num_queries, num_classes]

@@ -111,8 +111,7 @@ def inverse_sigmoid(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
 class MultiScaleDeformableAttention(nn.Module):
     def __init__(self, d_model: int, n_levels: int, n_heads: int, n_points: int) -> None:
         super().__init__()
-
-            raise ValueError(f"d_model must be divisible by n_heads, but got {d_model} and {n_heads}")
+        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"

         # Ensure dim_per_head is power of 2
         dim_per_head = d_model // n_heads

@@ -133,9 +132,9 @@ class MultiScaleDeformableAttention(nn.Module):
         self.value_proj = nn.Linear(d_model, d_model)
         self.output_proj = nn.Linear(d_model, d_model)

-        self.
+        self.reset_parameters()

-    def
+    def reset_parameters(self) -> None:
         nn.init.constant_(self.sampling_offsets.weight, 0.0)
         thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
         grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)

@@ -166,8 +165,8 @@ class MultiScaleDeformableAttention(nn.Module):
         input_level_start_index: torch.Tensor,
         input_padding_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-
-
+        N, num_queries, _ = query.size()
+        N, sequence_length, _ = input_flatten.size()
         assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == sequence_length

         value = self.value_proj(input_flatten)

@@ -283,7 +282,7 @@ class DeformableTransformerDecoderLayer(nn.Module):
         q = tgt + query_pos
         k = tgt + query_pos

-
+        tgt2, _ = self.self_attn(
             q.transpose(0, 1), k.transpose(0, 1), tgt.transpose(0, 1), need_weights=False, attn_mask=self_attn_mask
         )
         tgt2 = tgt2.transpose(0, 1)

@@ -318,7 +317,7 @@ class DeformableTransformerEncoder(nn.Module):
         for lvl, spatial_shape in enumerate(spatial_shapes):
             H = spatial_shape[0]
             W = spatial_shape[1]
-
+            ref_y, ref_x = torch.meshgrid(
                 torch.linspace(0.5, H - 0.5, H, dtype=torch.float32, device=device),
                 torch.linspace(0.5, W - 0.5, W, dtype=torch.float32, device=device),
                 indexing="ij",
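The encoder hunk above shows the per-level `torch.meshgrid` that places one reference point at every feature-map cell center. A toy sketch of that grid construction, with made-up level dimensions; the normalization here is simplified to dividing by the level size, whereas the Deformable DETR reference implementation also folds in per-image valid ratios (not fully visible in this diff):

import torch

# Toy feature level: H x W = 3 x 4, one reference point per cell center
H, W = 3, 4
device = torch.device("cpu")

ref_y, ref_x = torch.meshgrid(
    torch.linspace(0.5, H - 0.5, H, dtype=torch.float32, device=device),
    torch.linspace(0.5, W - 0.5, W, dtype=torch.float32, device=device),
    indexing="ij",
)

# Normalize to [0, 1] so the points are usable as sampling locations
ref_y = ref_y.reshape(-1) / H
ref_x = ref_x.reshape(-1) / W
reference_points = torch.stack([ref_x, ref_y], dim=-1)  # (H*W, 2)
print(reference_points.shape)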
@@ -454,7 +453,7 @@ class DeformableTransformer(nn.Module):

         for m in self.modules():
             if isinstance(m, MultiScaleDeformableAttention):
-                m.
+                m.reset_parameters()

         nn.init.xavier_uniform_(self.reference_points.weight, gain=1.0)
         nn.init.zeros_(self.reference_points.bias)

@@ -462,7 +461,7 @@ class DeformableTransformer(nn.Module):
         nn.init.normal_(self.level_embed)

     def get_valid_ratio(self, mask: torch.Tensor) -> torch.Tensor:
-
+        _, H, W = mask.size()
         valid_h = torch.sum(~mask[:, :, 0], 1)
         valid_w = torch.sum(~mask[:, 0, :], 1)
         valid_ratio_h = valid_h.float() / H

@@ -485,7 +484,7 @@ class DeformableTransformer(nn.Module):
         mask_list = []
         spatial_shape_list: list[list[int]] = []  # list[tuple[int, int]] not supported on TorchScript
         for lvl, (src, pos_embed, mask) in enumerate(zip(srcs, pos_embeds, masks)):
-
+            _, _, H, W = src.size()
             spatial_shape_list.append([H, W])
             src = src.flatten(2).transpose(1, 2)
             pos_embed = pos_embed.flatten(2).transpose(1, 2)

@@ -508,14 +507,14 @@ class DeformableTransformer(nn.Module):
         )

         # Prepare input for decoder
-
+        B, _, C = memory.size()
         query_embed, tgt = torch.split(query_embed, C, dim=1)
         query_embed = query_embed.unsqueeze(0).expand(B, -1, -1)
         tgt = tgt.unsqueeze(0).expand(B, -1, -1)
         reference_points = self.reference_points(query_embed).sigmoid()

         # Decoder
-
+        hs, inter_references = self.decoder(
             tgt, reference_points, memory, spatial_shapes, level_start_index, query_embed, valid_ratios, mask_flatten
         )

@@ -632,7 +631,7 @@ class Deformable_DETR(DetectionBaseNet):
         prior_prob = 0.01
         bias_value = -math.log((1 - prior_prob) / prior_prob)
         for class_embed in self.class_embed:
-
+            nn.init.constant_(class_embed.bias, bias_value)

     def freeze(self, freeze_classifier: bool = True) -> None:
         for param in self.parameters():

@@ -656,20 +655,19 @@ class Deformable_DETR(DetectionBaseNet):
     ) -> torch.Tensor:
         idx = self._get_src_permutation_idx(indices)
         target_classes_o = torch.concat([t["labels"][J] for t, (_, J) in zip(targets, indices)], dim=0)
-        target_classes = torch.full(cls_logits.shape[:2], self.num_classes, dtype=torch.int64, device=cls_logits.device)
-        target_classes[idx] = target_classes_o

         target_classes_onehot = torch.zeros(
-
+            cls_logits.size(0),
+            cls_logits.size(1),
+            cls_logits.size(2) + 1,
             dtype=cls_logits.dtype,
-            layout=cls_logits.layout,
             device=cls_logits.device,
         )
-        target_classes_onehot
-
+        target_classes_onehot[idx[0], idx[1], target_classes_o] = 1
         target_classes_onehot = target_classes_onehot[:, :, :-1]
+
         loss = sigmoid_focal_loss(cls_logits, target_classes_onehot, alpha=0.25, gamma=2.0)
-        loss_ce = (loss.mean(1).sum() / num_boxes) * cls_logits.
+        loss_ce = (loss.mean(1).sum() / num_boxes) * cls_logits.size(1)

         return loss_ce

@@ -719,7 +717,7 @@ class Deformable_DETR(DetectionBaseNet):
         for idx in range(cls_logits.size(0)):
             indices = self.matcher(cls_logits[idx], box_output[idx], targets)
             loss_ce_i = self._class_loss(cls_logits[idx], targets, indices, num_boxes)
-
+            loss_bbox_i, loss_giou_i = self._box_loss(box_output[idx], targets, indices, num_boxes)
             loss_ce_list.append(loss_ce_i)
             loss_bbox_list.append(loss_bbox_i)
             loss_giou_list.append(loss_giou_i)

@@ -739,7 +737,7 @@ class Deformable_DETR(DetectionBaseNet):
         self, class_logits: torch.Tensor, box_regression: torch.Tensor, image_shapes: list[tuple[int, int]]
     ) -> list[dict[str, torch.Tensor]]:
         prob = class_logits.sigmoid()
-
+        topk_values, topk_indexes = torch.topk(prob.view(class_logits.shape[0], -1), k=100, dim=1)
         scores = topk_values
         topk_boxes = topk_indexes // class_logits.shape[2]
         labels = topk_indexes % class_logits.shape[2]

@@ -752,7 +750,7 @@ class Deformable_DETR(DetectionBaseNet):
         boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4))

         # Convert from relative [0, 1] to absolute [0, height] coordinates
-
+        img_h, img_w = target_sizes.unbind(1)
         scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
         boxes = boxes * scale_fct[:, None, :]

@@ -760,7 +758,7 @@ class Deformable_DETR(DetectionBaseNet):
         for s, l, b in zip(scores, labels, boxes):
             # Non-maximum suppression
             if self.soft_nms is not None:
-
+                soft_scores, keep = self.soft_nms(b, s, l, score_threshold=0.001)
                 s[keep] = soft_scores

             b = b[keep]

@@ -797,14 +795,14 @@ class Deformable_DETR(DetectionBaseNet):
                 mask_size = feature_list[idx].shape[-2:]
                 m = F.interpolate(masks[None].float(), size=mask_size, mode="nearest").to(torch.bool)[0]
             else:
-
+                B, _, H, W = feature_list[idx].size()
                 m = torch.zeros(B, H, W, dtype=torch.bool, device=x.device)

             feature_list[idx] = proj(feature_list[idx])
             mask_list.append(m)
             pos_list.append(self.pos_enc(feature_list[idx], m))

-
+        hs, init_reference, inter_references = self.transformer(
             feature_list, pos_list, self.query_embed.weight, mask_list
         )
         outputs_classes = []
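For readers following the `_class_loss` hunk above: the new code builds a one-hot target tensor (with an extra "no object" column that is sliced off) and scores it with torchvision's `sigmoid_focal_loss`. A minimal runnable sketch of that pattern; the tensor names mirror the diff, while the shapes, index values and `num_boxes` are toy assumptions, not the package's data:

import torch
from torchvision.ops import sigmoid_focal_loss

# Toy shapes: batch of 2 images, 5 queries, 3 foreground classes
cls_logits = torch.randn(2, 5, 3)

# Matched (batch, query) pairs and their ground-truth labels; these stand in
# for idx and target_classes_o in the hunk above
batch_idx = torch.tensor([0, 0, 1])
query_idx = torch.tensor([1, 3, 0])
target_classes_o = torch.tensor([2, 0, 1])
num_boxes = 3

# One-hot targets with an extra "no object" column that is sliced off below
target_classes_onehot = torch.zeros(
    cls_logits.size(0),
    cls_logits.size(1),
    cls_logits.size(2) + 1,
    dtype=cls_logits.dtype,
    device=cls_logits.device,
)
target_classes_onehot[batch_idx, query_idx, target_classes_o] = 1
target_classes_onehot = target_classes_onehot[:, :, :-1]

loss = sigmoid_focal_loss(cls_logits, target_classes_onehot, alpha=0.25, gamma=2.0)
loss_ce = (loss.mean(1).sum() / num_boxes) * cls_logits.size(1)
print(loss_ce)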
birder/net/detection/detr.py
CHANGED
@@ -51,7 +51,7 @@ class HungarianMatcher(nn.Module):
         self, class_logits: torch.Tensor, box_regression: torch.Tensor, targets: list[dict[str, torch.Tensor]]
     ) -> list[torch.Tensor]:
         with torch.no_grad():
-
+            B, num_queries = class_logits.shape[:2]

             # We flatten to compute the cost matrices in a batch
             out_prob = class_logits.flatten(0, 1).softmax(-1)  # [batch_size * num_queries, num_classes]

@@ -111,7 +111,7 @@ class TransformerEncoderLayer(nn.Module):
         q = src + pos
         k = src + pos

-
+        src2, _ = self.self_attn(q, k, value=src, key_padding_mask=src_key_padding_mask, need_weights=False)
         src = src + self.dropout1(src2)
         src = self.norm1(src)
         src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))

@@ -151,10 +151,10 @@ class TransformerDecoderLayer(nn.Module):
         q = tgt + query_pos
         k = tgt + query_pos

-
+        tgt2, _ = self.self_attn(q, k, value=tgt, need_weights=False)
         tgt = tgt + self.dropout1(tgt2)
         tgt = self.norm1(tgt)
-
+        tgt2, _ = self.multihead_attn(
             query=tgt + query_pos,
             key=memory + pos,
             value=memory,

@@ -270,7 +270,7 @@ class PositionEmbeddingSine(nn.Module):

     def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
         if mask is None:
-
+            B, _, H, W = x.size()
             mask = torch.zeros(B, H, W, dtype=torch.bool, device=x.device)

         not_mask = ~mask

@@ -430,7 +430,7 @@ class DETR(DetectionBaseNet):
         for idx in range(cls_logits.size(0)):
             indices = self.matcher(cls_logits[idx], box_output[idx], targets)
             loss_ce_i = self._class_loss(cls_logits[idx], targets, indices)
-
+            loss_bbox_i, loss_giou_i = self._box_loss(box_output[idx], targets, indices, num_boxes)
             loss_ce_list.append(loss_ce_i)
             loss_bbox_list.append(loss_bbox_i)
             loss_giou_list.append(loss_giou_i)

@@ -450,7 +450,7 @@ class DETR(DetectionBaseNet):
         self, class_logits: torch.Tensor, box_regression: torch.Tensor, image_shapes: list[tuple[int, int]]
     ) -> list[dict[str, torch.Tensor]]:
         prob = F.softmax(class_logits, -1)
-
+        scores, labels = prob[..., 1:].max(-1)
         labels = labels + 1

         # TorchScript doesn't support creating tensor from tuples, convert everything to lists

@@ -460,7 +460,7 @@ class DETR(DetectionBaseNet):
         boxes = box_ops.box_convert(box_regression, in_fmt="cxcywh", out_fmt="xyxy")

         # Convert from relative [0, 1] to absolute [0, height] coordinates
-
+        img_h, img_w = target_sizes.unbind(1)
         scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
         boxes = boxes * scale_fct[:, None, :]

@@ -468,7 +468,7 @@ class DETR(DetectionBaseNet):
         for s, l, b in zip(scores, labels, boxes):
             # Non-maximum suppression
             if self.soft_nms is not None:
-
+                soft_scores, keep = self.soft_nms(b, s, l, score_threshold=0.001)
                 s[keep] = soft_scores

             b = b[keep]
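The DETR postprocessing hunks above take a softmax over classes, read scores and labels while skipping the background column, and rescale relative cxcywh boxes to absolute pixels. A small self-contained sketch of that flow, assuming toy shapes and a hand-made `(height, width)` `target_sizes` tensor (illustrative only, not the package's helpers):

import torch
from torchvision.ops import box_convert

# Toy predictions: 1 image, 4 queries, background class + 2 foreground classes
class_logits = torch.randn(1, 4, 3)
box_regression = torch.rand(1, 4, 4)       # relative cxcywh in [0, 1]
target_sizes = torch.tensor([[480, 640]])  # (height, width) per image

prob = torch.softmax(class_logits, -1)
scores, labels = prob[..., 1:].max(-1)     # skip the background column
labels = labels + 1

boxes = box_convert(box_regression, in_fmt="cxcywh", out_fmt="xyxy")

# Convert from relative [0, 1] to absolute pixel coordinates
img_h, img_w = target_sizes.unbind(1)
scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
boxes = boxes * scale_fct[:, None, :]
print(scores.shape, labels.shape, boxes.shape)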
birder/net/detection/efficientdet.py
CHANGED

@@ -136,8 +136,8 @@ class ResampleFeatureMap(nn.Module):
         if self.conv is not None:
             x = self.conv(x)

-
-
+        in_h, in_w = x.shape[-2:]
+        target_h, target_w = target_size
         if in_h == target_h and in_w == target_w:
             return x

@@ -358,13 +358,7 @@ class HeadNet(nn.Module):
         for _ in range(repeats):
             layers.append(
                 nn.Conv2d(
-                    fpn_channels,
-                    fpn_channels,
-                    kernel_size=(3, 3),
-                    stride=(1, 1),
-                    padding=(1, 1),
-                    groups=fpn_channels,
-                    bias=True,
+                    fpn_channels, fpn_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=fpn_channels
                 )
             )
             layers.append(

@@ -383,22 +377,9 @@ class HeadNet(nn.Module):
         self.conv_repeat = nn.Sequential(*layers)
         self.predict = nn.Sequential(
             nn.Conv2d(
-                fpn_channels,
-                fpn_channels,
-                kernel_size=(3, 3),
-                stride=(1, 1),
-                padding=(1, 1),
-                groups=fpn_channels,
-                bias=True,
-            ),
-            nn.Conv2d(
-                fpn_channels,
-                num_outputs * num_anchors,
-                kernel_size=(1, 1),
-                stride=(1, 1),
-                padding=(0, 0),
-                bias=True,
+                fpn_channels, fpn_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=fpn_channels
             ),
+            nn.Conv2d(fpn_channels, num_outputs * num_anchors, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)),
         )

     def forward(self, x: list[torch.Tensor]) -> torch.Tensor:

@@ -453,7 +434,7 @@ class ClassificationHead(HeadNet):
             cls_logits = self.predict(cls_logits)

             # Permute classification output from (N, A * K, H, W) to (N, HWA, K).
-
+            N, _, H, W = cls_logits.shape
             cls_logits = cls_logits.view(N, -1, self.num_outputs, H, W)
             cls_logits = cls_logits.permute(0, 3, 4, 1, 2)
             cls_logits = cls_logits.reshape(N, -1, self.num_outputs)  # Size=(N, HWA, K)

@@ -504,7 +485,7 @@ class RegressionHead(HeadNet):
             bbox_regression = self.predict(bbox_regression)

             # Permute bbox regression output from (N, 4 * A, H, W) to (N, HWA, 4).
-
+            N, _, H, W = bbox_regression.shape
             bbox_regression = bbox_regression.view(N, -1, 4, H, W)
             bbox_regression = bbox_regression.permute(0, 3, 4, 1, 2)
             bbox_regression = bbox_regression.reshape(N, -1, 4)  # Size=(N, HWA, 4)

@@ -663,7 +644,7 @@ class EfficientDet(DetectionBaseNet):

             # Keep only topk scoring predictions
             num_topk = min(self.topk_candidates, int(topk_idxs.size(0)))
-
+            scores_per_level, idxs = scores_per_level.topk(num_topk)
             topk_idxs = topk_idxs[idxs]

             anchor_idxs = torch.div(topk_idxs, num_classes, rounding_mode="floor")

@@ -685,7 +666,7 @@ class EfficientDet(DetectionBaseNet):

         # Non-maximum suppression
         if self.soft_nms is not None:
-
+            soft_scores, keep = self.soft_nms(image_boxes, image_scores, image_labels, score_threshold=0.001)
             image_scores[keep] = soft_scores
         else:
             keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh)
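The `ClassificationHead`/`RegressionHead` hunks above add an explicit `N, _, H, W` unpack before the usual `(N, A * K, H, W)` to `(N, HWA, K)` reshuffle. A toy-shape sketch of that reshape pattern (the dimensions are made up for illustration):

import torch

# Toy head output: batch 2, A=3 anchors x K=4 outputs = 12 channels on a 5x5 map
N, A, K, H, W = 2, 3, 4, 5, 5
cls_logits = torch.randn(N, A * K, H, W)

# Permute classification output from (N, A * K, H, W) to (N, HWA, K)
cls_logits = cls_logits.view(N, -1, K, H, W)
cls_logits = cls_logits.permute(0, 3, 4, 1, 2)
cls_logits = cls_logits.reshape(N, -1, K)

print(cls_logits.shape)  # torch.Size([2, 75, 4]), i.e. (N, H*W*A, K)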
birder/net/detection/faster_rcnn.py
CHANGED

@@ -150,7 +150,7 @@ def concat_box_prediction_layers(
     # all feature levels concatenated, so we keep the same representation
     # for the objectness and the box_regression
     for box_cls_per_level, box_regression_per_level in zip(box_cls, box_regression):
-
+        N, AxC, H, W = box_cls_per_level.shape  # pylint: disable=invalid-name
         Ax4 = box_regression_per_level.shape[1]  # pylint: disable=invalid-name
         A = Ax4 // 4
         C = AxC // A

@@ -240,7 +240,7 @@ class RegionProposalNetwork(nn.Module):

                 # Get the targets corresponding GT for each proposal
                 # NB: need to clamp the indices because we can have a single
-                # GT in the image
+                # GT in the image and matched_idxs can be -2, which goes out of bounds
                 matched_gt_boxes_per_image = gt_boxes[matched_idxs.clamp(min=0)]

                 labels_per_image = matched_idxs >= 0

@@ -265,7 +265,7 @@ class RegionProposalNetwork(nn.Module):
         for ob in objectness.split(num_anchors_per_level, 1):
             num_anchors = ob.shape[1]
             pre_nms_top_n = min(self.pre_nms_top_n(), int(ob.size(1)))
-
+            _, top_n_idx = ob.topk(pre_nms_top_n, dim=1)
             r.append(top_n_idx + offset)
             offset += num_anchors

@@ -310,19 +310,19 @@ class RegionProposalNetwork(nn.Module):

             # Remove small boxes
             keep = box_ops.remove_small_boxes(boxes, self.min_size)
-
+            boxes, scores, lvl = boxes[keep], scores[keep], lvl[keep]

             # Remove low scoring boxes
             # use >= for Backwards compatibility
             keep = torch.where(scores >= self.score_thresh)[0]
-
+            boxes, scores, lvl = boxes[keep], scores[keep], lvl[keep]

             # Non-maximum suppression, independently done per level
             keep = box_ops.batched_nms(boxes, scores, lvl, self.nms_thresh)

             # Keep only topk scoring predictions
             keep = keep[: self.post_nms_top_n()]
-
+            boxes, scores = boxes[keep], scores[keep]

             final_boxes.append(boxes)
             final_scores.append(scores)

@@ -336,7 +336,7 @@ class RegionProposalNetwork(nn.Module):
         labels: list[torch.Tensor],
         regression_targets: list[torch.Tensor],
     ) -> tuple[torch.Tensor, torch.Tensor]:
-
+        sampled_pos_idxs, sampled_neg_idxs = self.fg_bg_sampler(labels)
         sampled_pos_idxs = torch.where(torch.concat(sampled_pos_idxs, dim=0))[0]
         sampled_neg_idxs = torch.where(torch.concat(sampled_neg_idxs, dim=0))[0]

@@ -364,29 +364,29 @@ class RegionProposalNetwork(nn.Module):
     ) -> tuple[list[torch.Tensor], dict[str, torch.Tensor]]:
         # RPN uses all feature maps that are available
         features_list = list(features.values())
-
+        objectness, pred_bbox_deltas = self.head(features_list)
         anchors = self.anchor_generator(images, features_list)

         num_images = len(anchors)
         num_anchors_per_level_shape_tensors = [o[0].shape for o in objectness]
         num_anchors_per_level = [s[0] * s[1] * s[2] for s in num_anchors_per_level_shape_tensors]
-
+        objectness, pred_bbox_deltas = concat_box_prediction_layers(objectness, pred_bbox_deltas)

         # Apply pred_bbox_deltas to anchors to obtain the decoded proposals
         # note that we detach the deltas because Faster R-CNN do not backprop through
         # the proposals
         proposals = self.box_coder.decode(pred_bbox_deltas.detach(), anchors)
         proposals = proposals.view(num_images, -1, 4)
-
+        boxes, _scores = self.filter_proposals(proposals, objectness, images.image_sizes, num_anchors_per_level)

         losses: dict[str, torch.Tensor] = {}
         if self.training is True:
             if targets is None:
                 raise ValueError("targets should not be None")

-
+            labels, matched_gt_boxes = self.assign_targets_to_anchors(anchors, targets)
             regression_targets = self.box_coder.encode(matched_gt_boxes, anchors)
-
+            loss_objectness, loss_rpn_box_reg = self.compute_loss(
                 objectness, pred_bbox_deltas, labels, regression_targets
             )
             losses = {

@@ -405,7 +405,7 @@ class FastRCNNConvFCHead(nn.Sequential):
         fc_layers: list[int],
         norm_layer: Optional[Callable[..., nn.Module]] = None,
     ):
-
+        in_channels, in_height, in_width = input_size

         blocks = []
         previous_channels = in_channels

@@ -481,7 +481,7 @@ def faster_rcnn_loss(
     # advanced indexing
     sampled_pos_idxs_subset = torch.where(labels > 0)[0]
     labels_pos = labels[sampled_pos_idxs_subset]
-
+    N, _num_classes = class_logits.shape
     box_regression = box_regression.reshape(N, box_regression.size(-1) // 4, 4)

     box_loss = F.smooth_l1_loss(

@@ -573,7 +573,7 @@ class RoIHeads(nn.Module):
         return (matched_idxs, labels)

     def subsample(self, labels: list[torch.Tensor]) -> list[torch.Tensor]:
-
+        sampled_pos_idxs, sampled_neg_idxs = self.fg_bg_sampler(labels)
         sampled_idxs = []
         for pos_idxs_img, neg_idxs_img in zip(sampled_pos_idxs, sampled_neg_idxs):
             img_sampled_idxs = torch.where(pos_idxs_img | neg_idxs_img)[0]

@@ -610,7 +610,7 @@ class RoIHeads(nn.Module):
         proposals = self.add_gt_proposals(proposals, gt_boxes)

         # Get matching gt indices for each proposal
-
+        matched_idxs, labels = self.assign_targets_to_proposals(proposals, gt_boxes, gt_labels)

         # Sample a fixed proportion of positive-negative proposals
         sampled_idxs = self.subsample(labels)

@@ -713,7 +713,7 @@ class RoIHeads(nn.Module):
                 raise TypeError(f"target labels must of int64 type, instead got {t['labels'].dtype}")

         if self.training is True:
-
+            proposals, _matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets)
         else:
             labels = None
             regression_targets = None

@@ -721,7 +721,7 @@ class RoIHeads(nn.Module):

         box_features = self.box_roi_pool(features, proposals, image_shapes)
         box_features = self.box_head(box_features)
-
+        class_logits, box_regression = self.box_predictor(box_features)

         losses = {}
         result: list[dict[str, torch.Tensor]] = []

@@ -731,11 +731,11 @@ class RoIHeads(nn.Module):
             if regression_targets is None:
                 raise ValueError("regression_targets cannot be None")

-
+            loss_classifier, loss_box_reg = faster_rcnn_loss(class_logits, box_regression, labels, regression_targets)
             losses = {"loss_classifier": loss_classifier, "loss_box_reg": loss_box_reg}

         else:
-
+            boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals, image_shapes)
             num_images = len(boxes)
             for i in range(num_images):
                 result.append(

@@ -868,8 +868,8 @@ class Faster_RCNN(DetectionBaseNet):
         images = self._to_img_list(x, image_sizes)

         features = self.backbone_with_fpn(x)
-
-
+        proposals, proposal_losses = self.rpn(images, features, targets)
+        detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets)

         losses = {}
         losses.update(detector_losses)
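The `filter_proposals` hunk above interleaves torchvision's box utilities with re-indexing of `boxes`/`scores`/`lvl`. A compact, standalone sketch of that per-image filtering sequence, assuming hand-made toy boxes and thresholds (not the package's defaults):

import torch
from torchvision.ops import boxes as box_ops

# Toy per-level proposals (xyxy), objectness scores and level indices
boxes = torch.tensor([[0.0, 0.0, 50.0, 50.0], [10.0, 10.0, 11.0, 11.0], [5.0, 5.0, 60.0, 40.0]])
scores = torch.tensor([0.9, 0.8, 0.4])
lvl = torch.zeros(3, dtype=torch.int64)
min_size, score_thresh, nms_thresh, post_nms_top_n = 4.0, 0.05, 0.7, 100

# Remove small boxes
keep = box_ops.remove_small_boxes(boxes, min_size)
boxes, scores, lvl = boxes[keep], scores[keep], lvl[keep]

# Remove low scoring boxes
keep = torch.where(scores >= score_thresh)[0]
boxes, scores, lvl = boxes[keep], scores[keep], lvl[keep]

# Non-maximum suppression, done independently per level
keep = box_ops.batched_nms(boxes, scores, lvl, nms_thresh)

# Keep only the top scoring predictions
keep = keep[:post_nms_top_n]
boxes, scores = boxes[keep], scores[keep]
print(boxes, scores)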
birder/net/detection/fcos.py
CHANGED
@@ -125,7 +125,7 @@ class FCOSClassificationHead(nn.Module):
             cls_logits = self.cls_logits(cls_logits)

             # Permute classification output from (N, A * K, H, W) to (N, HWA, K).
-
+            N, _, H, W = cls_logits.size()
             cls_logits = cls_logits.view(N, -1, self.num_classes, H, W)
             cls_logits = cls_logits.permute(0, 3, 4, 1, 2)
             cls_logits = cls_logits.reshape(N, -1, self.num_classes)  # (N, HWA, 4)

@@ -165,7 +165,7 @@ class FCOSRegressionHead(nn.Module):
             bbox_ctrness = self.bbox_ctrness(bbox_feature)

             # Permute bbox regression output from (N, 4 * A, H, W) to (N, HWA, 4).
-
+            N, _, H, W = bbox_regression.size()
             bbox_regression = bbox_regression.view(N, -1, 4, H, W)
             bbox_regression = bbox_regression.permute(0, 3, 4, 1, 2)
             bbox_regression = bbox_regression.reshape(N, -1, 4)  # (N, HWA, 4)

@@ -262,7 +262,7 @@ class FCOSHead(nn.Module):

     def forward(self, x: list[torch.Tensor]) -> dict[str, torch.Tensor]:
         cls_logits = self.classification_head(x)
-
+        bbox_regression, bbox_ctrness = self.regression_head(x)

         return {
             "cls_logits": cls_logits,

@@ -370,8 +370,8 @@ class FCOS(DetectionBaseNet):
             ).values < self.center_sampling_radius * anchor_sizes[:, None]

             # Compute pairwise distance between N points and M boxes
-
-
+            x, y = anchor_centers.unsqueeze(dim=2).unbind(dim=1)  # (N, 1)
+            x0, y0, x1, y1 = gt_boxes.unsqueeze(dim=0).unbind(dim=2)  # (1, M)
             pairwise_dist = torch.stack([x - x0, y - y0, x1 - x, y1 - y], dim=2)  # (N, M)

             # Anchor point must be inside gt

@@ -388,7 +388,7 @@ class FCOS(DetectionBaseNet):
             # Match the GT box with minimum area, if there are multiple GT matches
             gt_areas = (gt_boxes[:, 2] - gt_boxes[:, 0]) * (gt_boxes[:, 3] - gt_boxes[:, 1])  # N
             pairwise_match = pairwise_match.to(torch.float32) * (1e8 - gt_areas[None, :])
-
+            min_values, matched_idx = pairwise_match.max(dim=1)  # R, per-anchor match
             matched_idx[min_values < 1e-5] = -1  # Unmatched anchors are assigned -1

             matched_idxs.append(matched_idx)

@@ -433,7 +433,7 @@ class FCOS(DetectionBaseNet):

             # Keep only topk scoring predictions
             num_topk = min(self.topk_candidates, int(topk_idxs.size(0)))
-
+            scores_per_level, idxs = scores_per_level.topk(num_topk)
             topk_idxs = topk_idxs[idxs]

             anchor_idxs = torch.div(topk_idxs, num_classes, rounding_mode="floor")

@@ -455,7 +455,7 @@ class FCOS(DetectionBaseNet):

             # Non-maximum suppression
             if self.soft_nms is not None:
-
+                soft_scores, keep = self.soft_nms(image_boxes, image_scores, image_labels, score_threshold=0.001)
                 image_scores[keep] = soft_scores
             else:
                 keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh)
|