birder 0.3.3__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- birder/adversarial/base.py +1 -1
- birder/adversarial/simba.py +4 -4
- birder/common/cli.py +1 -1
- birder/common/fs_ops.py +13 -13
- birder/common/lib.py +2 -2
- birder/common/masking.py +3 -3
- birder/common/training_cli.py +24 -2
- birder/common/training_utils.py +28 -4
- birder/data/collators/detection.py +9 -1
- birder/data/transforms/detection.py +27 -8
- birder/data/transforms/mosaic.py +1 -1
- birder/datahub/classification.py +3 -3
- birder/inference/classification.py +3 -3
- birder/inference/data_parallel.py +1 -1
- birder/inference/detection.py +5 -5
- birder/inference/wbf.py +1 -1
- birder/introspection/attention_rollout.py +6 -6
- birder/introspection/feature_pca.py +4 -4
- birder/introspection/gradcam.py +1 -1
- birder/introspection/guided_backprop.py +2 -2
- birder/introspection/transformer_attribution.py +4 -4
- birder/layers/attention_pool.py +2 -2
- birder/layers/layer_scale.py +1 -1
- birder/model_registry/model_registry.py +2 -1
- birder/net/__init__.py +4 -10
- birder/net/_rope_vit_configs.py +435 -0
- birder/net/_vit_configs.py +466 -0
- birder/net/alexnet.py +5 -5
- birder/net/base.py +28 -3
- birder/net/biformer.py +18 -17
- birder/net/cait.py +7 -7
- birder/net/cas_vit.py +1 -1
- birder/net/coat.py +27 -27
- birder/net/conv2former.py +3 -3
- birder/net/convmixer.py +1 -1
- birder/net/convnext_v1.py +3 -11
- birder/net/convnext_v1_iso.py +198 -0
- birder/net/convnext_v2.py +2 -10
- birder/net/crossformer.py +9 -9
- birder/net/crossvit.py +6 -6
- birder/net/cspnet.py +1 -1
- birder/net/cswin_transformer.py +10 -10
- birder/net/davit.py +11 -11
- birder/net/deit.py +68 -29
- birder/net/deit3.py +69 -204
- birder/net/densenet.py +9 -8
- birder/net/detection/__init__.py +4 -0
- birder/net/detection/{yolo_anchors.py → _yolo_anchors.py} +5 -5
- birder/net/detection/base.py +6 -5
- birder/net/detection/deformable_detr.py +31 -30
- birder/net/detection/detr.py +14 -11
- birder/net/detection/efficientdet.py +10 -29
- birder/net/detection/faster_rcnn.py +22 -22
- birder/net/detection/fcos.py +8 -8
- birder/net/detection/plain_detr.py +852 -0
- birder/net/detection/retinanet.py +4 -4
- birder/net/detection/rt_detr_v1.py +81 -25
- birder/net/detection/rt_detr_v2.py +1147 -0
- birder/net/detection/ssd.py +5 -5
- birder/net/detection/yolo_v2.py +12 -12
- birder/net/detection/yolo_v3.py +19 -19
- birder/net/detection/yolo_v4.py +16 -16
- birder/net/detection/yolo_v4_tiny.py +3 -3
- birder/net/dpn.py +1 -2
- birder/net/edgenext.py +5 -4
- birder/net/edgevit.py +13 -14
- birder/net/efficientformer_v1.py +3 -2
- birder/net/efficientformer_v2.py +18 -31
- birder/net/efficientnet_v2.py +3 -0
- birder/net/efficientvim.py +9 -9
- birder/net/efficientvit_mit.py +7 -7
- birder/net/efficientvit_msft.py +3 -3
- birder/net/fasternet.py +3 -3
- birder/net/fastvit.py +5 -12
- birder/net/flexivit.py +50 -58
- birder/net/focalnet.py +5 -9
- birder/net/gc_vit.py +11 -11
- birder/net/ghostnet_v1.py +1 -1
- birder/net/ghostnet_v2.py +1 -1
- birder/net/groupmixformer.py +13 -13
- birder/net/hgnet_v1.py +6 -6
- birder/net/hgnet_v2.py +4 -4
- birder/net/hiera.py +6 -6
- birder/net/hieradet.py +9 -9
- birder/net/hornet.py +3 -3
- birder/net/iformer.py +4 -4
- birder/net/inception_next.py +5 -15
- birder/net/inception_resnet_v1.py +3 -3
- birder/net/inception_resnet_v2.py +7 -4
- birder/net/inception_v3.py +3 -0
- birder/net/inception_v4.py +3 -0
- birder/net/levit.py +3 -3
- birder/net/lit_v1.py +13 -15
- birder/net/lit_v1_tiny.py +9 -9
- birder/net/lit_v2.py +14 -15
- birder/net/maxvit.py +11 -23
- birder/net/metaformer.py +5 -5
- birder/net/mim/crossmae.py +6 -6
- birder/net/mim/fcmae.py +3 -5
- birder/net/mim/mae_hiera.py +7 -7
- birder/net/mim/mae_vit.py +4 -6
- birder/net/mim/simmim.py +3 -4
- birder/net/mobilenet_v1.py +0 -9
- birder/net/mobilenet_v2.py +38 -44
- birder/net/{mobilenet_v3_large.py → mobilenet_v3.py} +37 -10
- birder/net/mobilenet_v4_hybrid.py +4 -4
- birder/net/mobileone.py +5 -12
- birder/net/mobilevit_v1.py +7 -34
- birder/net/mobilevit_v2.py +6 -54
- birder/net/moganet.py +8 -5
- birder/net/mvit_v2.py +30 -30
- birder/net/nextvit.py +2 -2
- birder/net/nfnet.py +4 -0
- birder/net/pit.py +11 -26
- birder/net/pvt_v1.py +9 -9
- birder/net/pvt_v2.py +10 -16
- birder/net/regionvit.py +15 -15
- birder/net/regnet.py +1 -1
- birder/net/repghost.py +5 -35
- birder/net/repvgg.py +3 -5
- birder/net/repvit.py +2 -2
- birder/net/resmlp.py +2 -2
- birder/net/resnest.py +4 -1
- birder/net/resnet_v1.py +125 -1
- birder/net/resnet_v2.py +75 -1
- birder/net/resnext.py +35 -1
- birder/net/rope_deit3.py +62 -151
- birder/net/rope_flexivit.py +46 -33
- birder/net/rope_vit.py +44 -758
- birder/net/sequencer2d.py +3 -4
- birder/net/shufflenet_v1.py +1 -1
- birder/net/shufflenet_v2.py +1 -1
- birder/net/simple_vit.py +69 -21
- birder/net/smt.py +8 -8
- birder/net/squeezenet.py +5 -12
- birder/net/squeezenext.py +0 -24
- birder/net/ssl/barlow_twins.py +1 -1
- birder/net/ssl/byol.py +2 -2
- birder/net/ssl/capi.py +4 -4
- birder/net/ssl/data2vec.py +1 -1
- birder/net/ssl/data2vec2.py +1 -1
- birder/net/ssl/dino_v2.py +13 -3
- birder/net/ssl/franca.py +28 -4
- birder/net/ssl/i_jepa.py +5 -5
- birder/net/ssl/ibot.py +1 -1
- birder/net/ssl/mmcr.py +1 -1
- birder/net/swiftformer.py +13 -3
- birder/net/swin_transformer_v1.py +4 -5
- birder/net/swin_transformer_v2.py +5 -8
- birder/net/tiny_vit.py +6 -19
- birder/net/transnext.py +19 -19
- birder/net/uniformer.py +4 -4
- birder/net/van.py +2 -2
- birder/net/vgg.py +1 -10
- birder/net/vit.py +72 -987
- birder/net/vit_parallel.py +35 -20
- birder/net/vit_sam.py +23 -48
- birder/net/vovnet_v2.py +1 -1
- birder/net/xcit.py +16 -13
- birder/ops/msda.py +4 -4
- birder/ops/swattention.py +10 -10
- birder/results/classification.py +3 -3
- birder/results/gui.py +8 -8
- birder/scripts/benchmark.py +37 -12
- birder/scripts/evaluate.py +1 -1
- birder/scripts/predict.py +3 -3
- birder/scripts/predict_detection.py +2 -2
- birder/scripts/train.py +63 -15
- birder/scripts/train_barlow_twins.py +10 -7
- birder/scripts/train_byol.py +10 -7
- birder/scripts/train_capi.py +15 -10
- birder/scripts/train_data2vec.py +10 -7
- birder/scripts/train_data2vec2.py +10 -7
- birder/scripts/train_detection.py +29 -14
- birder/scripts/train_dino_v1.py +13 -9
- birder/scripts/train_dino_v2.py +27 -14
- birder/scripts/train_dino_v2_dist.py +28 -15
- birder/scripts/train_franca.py +16 -9
- birder/scripts/train_i_jepa.py +12 -9
- birder/scripts/train_ibot.py +15 -11
- birder/scripts/train_kd.py +64 -17
- birder/scripts/train_mim.py +11 -8
- birder/scripts/train_mmcr.py +11 -8
- birder/scripts/train_rotnet.py +11 -7
- birder/scripts/train_simclr.py +10 -7
- birder/scripts/train_vicreg.py +10 -7
- birder/tools/adversarial.py +4 -4
- birder/tools/auto_anchors.py +5 -5
- birder/tools/avg_model.py +1 -1
- birder/tools/convert_model.py +30 -22
- birder/tools/det_results.py +1 -1
- birder/tools/download_model.py +1 -1
- birder/tools/ensemble_model.py +1 -1
- birder/tools/introspection.py +12 -3
- birder/tools/labelme_to_coco.py +2 -2
- birder/tools/model_info.py +15 -15
- birder/tools/pack.py +8 -8
- birder/tools/quantize_model.py +53 -4
- birder/tools/results.py +2 -2
- birder/tools/show_det_iterator.py +19 -6
- birder/tools/show_iterator.py +2 -2
- birder/tools/similarity.py +5 -5
- birder/tools/stats.py +4 -6
- birder/tools/voc_to_coco.py +1 -1
- birder/version.py +1 -1
- {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/METADATA +3 -3
- birder-0.4.1.dist-info/RECORD +300 -0
- {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/WHEEL +1 -1
- birder/net/mobilenet_v3_small.py +0 -43
- birder/net/se_resnet_v1.py +0 -105
- birder/net/se_resnet_v2.py +0 -59
- birder/net/se_resnext.py +0 -30
- birder-0.3.3.dist-info/RECORD +0 -299
- {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/entry_points.txt +0 -0
- {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/licenses/LICENSE +0 -0
- {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/top_level.txt +0 -0
@@ -9,7 +9,7 @@ https://arxiv.org/abs/2010.04159

 Changes from original:
 * Removed two stage support
-*
+* Penalize cost matrix elements on overflow (HungarianMatcher)
 """

 # Reference license: Apache-2.0 (both)

@@ -58,7 +58,7 @@ class HungarianMatcher(nn.Module):
         self, class_logits: torch.Tensor, box_regression: torch.Tensor, targets: list[dict[str, torch.Tensor]]
     ) -> list[torch.Tensor]:
         with torch.no_grad():
-
+            B, num_queries = class_logits.shape[:2]

             # We flatten to compute the cost matrices in a batch
             out_prob = class_logits.flatten(0, 1).sigmoid()  # [batch_size * num_queries, num_classes]

@@ -89,7 +89,10 @@ class HungarianMatcher(nn.Module):
             # Final cost matrix
             C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou
             C = C.view(B, num_queries, -1).cpu()
-
+            finite = torch.isfinite(C)
+            if not torch.all(finite):
+                penalty = C[finite].max().item() + 1.0 if finite.any().item() else 1.0
+                C.nan_to_num_(nan=penalty, posinf=penalty, neginf=penalty)

             sizes = [len(v["boxes"]) for v in targets]
             indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
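The cost-matrix guard added in this hunk exists because scipy's `linear_sum_assignment` rejects matrices containing NaN or infinite entries, which can appear when the classification or GIoU cost overflows. A minimal, self-contained sketch of the same idea (toy cost matrix, not code from the package):

```python
import torch
from scipy.optimize import linear_sum_assignment

# Toy cost matrix with overflow artifacts (e.g. from degenerate boxes)
C = torch.tensor([[0.3, float("inf"), 0.7], [0.1, 0.4, float("nan")]])

# Replace non-finite entries with a value larger than every finite cost,
# so the assignment solver simply avoids those pairs
finite = torch.isfinite(C)
if not torch.all(finite):
    penalty = C[finite].max().item() + 1.0 if finite.any().item() else 1.0
    C.nan_to_num_(nan=penalty, posinf=penalty, neginf=penalty)

row_idx, col_idx = linear_sum_assignment(C.numpy())
print(row_idx, col_idx)  # a valid matching instead of a ValueError
```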
@@ -108,8 +111,7 @@ def inverse_sigmoid(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
 class MultiScaleDeformableAttention(nn.Module):
     def __init__(self, d_model: int, n_levels: int, n_heads: int, n_points: int) -> None:
         super().__init__()
-
-            raise ValueError(f"d_model must be divisible by n_heads, but got {d_model} and {n_heads}")
+        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"

         # Ensure dim_per_head is power of 2
         dim_per_head = d_model // n_heads

@@ -130,9 +132,9 @@ class MultiScaleDeformableAttention(nn.Module):
         self.value_proj = nn.Linear(d_model, d_model)
         self.output_proj = nn.Linear(d_model, d_model)

-        self.
+        self.reset_parameters()

-    def
+    def reset_parameters(self) -> None:
         nn.init.constant_(self.sampling_offsets.weight, 0.0)
         thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
         grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
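For context on the `reset_parameters` body shown above: the sampling-offset weights start at zero, and the bias directions are laid out evenly on the unit circle so each attention head initially looks in a distinct direction. A tiny sketch of just that angle computation, with an arbitrary head count:

```python
import math
import torch

n_heads = 8
# One angle per head, evenly spaced around the unit circle
thetas = torch.arange(n_heads, dtype=torch.float32) * (2.0 * math.pi / n_heads)
directions = torch.stack([thetas.cos(), thetas.sin()], -1)  # (n_heads, 2) unit directions
print(directions)
```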
@@ -163,8 +165,8 @@ class MultiScaleDeformableAttention(nn.Module):
         input_level_start_index: torch.Tensor,
         input_padding_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-
-
+        N, num_queries, _ = query.size()
+        N, sequence_length, _ = input_flatten.size()
         assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == sequence_length

         value = self.value_proj(input_flatten)

@@ -280,7 +282,7 @@ class DeformableTransformerDecoderLayer(nn.Module):
         q = tgt + query_pos
         k = tgt + query_pos

-
+        tgt2, _ = self.self_attn(
             q.transpose(0, 1), k.transpose(0, 1), tgt.transpose(0, 1), need_weights=False, attn_mask=self_attn_mask
         )
         tgt2 = tgt2.transpose(0, 1)

@@ -315,7 +317,7 @@ class DeformableTransformerEncoder(nn.Module):
         for lvl, spatial_shape in enumerate(spatial_shapes):
             H = spatial_shape[0]
             W = spatial_shape[1]
-
+            ref_y, ref_x = torch.meshgrid(
                 torch.linspace(0.5, H - 0.5, H, dtype=torch.float32, device=device),
                 torch.linspace(0.5, W - 0.5, W, dtype=torch.float32, device=device),
                 indexing="ij",

@@ -451,7 +453,7 @@ class DeformableTransformer(nn.Module):

         for m in self.modules():
             if isinstance(m, MultiScaleDeformableAttention):
-                m.
+                m.reset_parameters()

         nn.init.xavier_uniform_(self.reference_points.weight, gain=1.0)
         nn.init.zeros_(self.reference_points.bias)

@@ -459,7 +461,7 @@ class DeformableTransformer(nn.Module):
         nn.init.normal_(self.level_embed)

     def get_valid_ratio(self, mask: torch.Tensor) -> torch.Tensor:
-
+        _, H, W = mask.size()
         valid_h = torch.sum(~mask[:, :, 0], 1)
         valid_w = torch.sum(~mask[:, 0, :], 1)
         valid_ratio_h = valid_h.float() / H

@@ -482,7 +484,7 @@ class DeformableTransformer(nn.Module):
         mask_list = []
         spatial_shape_list: list[list[int]] = []  # list[tuple[int, int]] not supported on TorchScript
         for lvl, (src, pos_embed, mask) in enumerate(zip(srcs, pos_embeds, masks)):
-
+            _, _, H, W = src.size()
             spatial_shape_list.append([H, W])
             src = src.flatten(2).transpose(1, 2)
             pos_embed = pos_embed.flatten(2).transpose(1, 2)

@@ -505,14 +507,14 @@ class DeformableTransformer(nn.Module):
         )

         # Prepare input for decoder
-
+        B, _, C = memory.size()
         query_embed, tgt = torch.split(query_embed, C, dim=1)
         query_embed = query_embed.unsqueeze(0).expand(B, -1, -1)
         tgt = tgt.unsqueeze(0).expand(B, -1, -1)
         reference_points = self.reference_points(query_embed).sigmoid()

         # Decoder
-
+        hs, inter_references = self.decoder(
             tgt, reference_points, memory, spatial_shapes, level_start_index, query_embed, valid_ratios, mask_flatten
         )
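The encoder hunk above rebuilds per-level reference points with `torch.meshgrid` at pixel centers. A simplified sketch of that pattern for a single feature level, normalizing by plain height and width and ignoring the valid-ratio scaling the full implementation applies:

```python
import torch

# Reference points for one 4x6 feature level: cell centers mapped into [0, 1]
H, W = 4, 6
ref_y, ref_x = torch.meshgrid(
    torch.linspace(0.5, H - 0.5, H, dtype=torch.float32),
    torch.linspace(0.5, W - 0.5, W, dtype=torch.float32),
    indexing="ij",
)
ref = torch.stack([ref_x.reshape(-1) / W, ref_y.reshape(-1) / H], dim=-1)  # (H*W, 2)
print(ref.shape, ref[0], ref[-1])
```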
@@ -629,7 +631,7 @@ class Deformable_DETR(DetectionBaseNet):
         prior_prob = 0.01
         bias_value = -math.log((1 - prior_prob) / prior_prob)
         for class_embed in self.class_embed:
-
+            nn.init.constant_(class_embed.bias, bias_value)

     def freeze(self, freeze_classifier: bool = True) -> None:
         for param in self.parameters():

@@ -653,20 +655,19 @@ class Deformable_DETR(DetectionBaseNet):
     ) -> torch.Tensor:
         idx = self._get_src_permutation_idx(indices)
         target_classes_o = torch.concat([t["labels"][J] for t, (_, J) in zip(targets, indices)], dim=0)
-        target_classes = torch.full(cls_logits.shape[:2], self.num_classes, dtype=torch.int64, device=cls_logits.device)
-        target_classes[idx] = target_classes_o

         target_classes_onehot = torch.zeros(
-
+            cls_logits.size(0),
+            cls_logits.size(1),
+            cls_logits.size(2) + 1,
             dtype=cls_logits.dtype,
-            layout=cls_logits.layout,
             device=cls_logits.device,
         )
-        target_classes_onehot
-
+        target_classes_onehot[idx[0], idx[1], target_classes_o] = 1
         target_classes_onehot = target_classes_onehot[:, :, :-1]
+
         loss = sigmoid_focal_loss(cls_logits, target_classes_onehot, alpha=0.25, gamma=2.0)
-        loss_ce = (loss.mean(1).sum() / num_boxes) * cls_logits.
+        loss_ce = (loss.mean(1).sum() / num_boxes) * cls_logits.size(1)

         return loss_ce

@@ -716,7 +717,7 @@ class Deformable_DETR(DetectionBaseNet):
         for idx in range(cls_logits.size(0)):
             indices = self.matcher(cls_logits[idx], box_output[idx], targets)
             loss_ce_i = self._class_loss(cls_logits[idx], targets, indices, num_boxes)
-
+            loss_bbox_i, loss_giou_i = self._box_loss(box_output[idx], targets, indices, num_boxes)
             loss_ce_list.append(loss_ce_i)
             loss_bbox_list.append(loss_bbox_i)
             loss_giou_list.append(loss_giou_i)

@@ -736,7 +737,7 @@ class Deformable_DETR(DetectionBaseNet):
         self, class_logits: torch.Tensor, box_regression: torch.Tensor, image_shapes: list[tuple[int, int]]
     ) -> list[dict[str, torch.Tensor]]:
         prob = class_logits.sigmoid()
-
+        topk_values, topk_indexes = torch.topk(prob.view(class_logits.shape[0], -1), k=100, dim=1)
         scores = topk_values
         topk_boxes = topk_indexes // class_logits.shape[2]
         labels = topk_indexes % class_logits.shape[2]

@@ -749,7 +750,7 @@ class Deformable_DETR(DetectionBaseNet):
         boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4))

         # Convert from relative [0, 1] to absolute [0, height] coordinates
-
+        img_h, img_w = target_sizes.unbind(1)
         scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
         boxes = boxes * scale_fct[:, None, :]

@@ -757,7 +758,7 @@ class Deformable_DETR(DetectionBaseNet):
         for s, l, b in zip(scores, labels, boxes):
             # Non-maximum suppression
             if self.soft_nms is not None:
-
+                soft_scores, keep = self.soft_nms(b, s, l, score_threshold=0.001)
                 s[keep] = soft_scores

             b = b[keep]

@@ -794,14 +795,14 @@ class Deformable_DETR(DetectionBaseNet):
                 mask_size = feature_list[idx].shape[-2:]
                 m = F.interpolate(masks[None].float(), size=mask_size, mode="nearest").to(torch.bool)[0]
             else:
-
+                B, _, H, W = feature_list[idx].size()
                 m = torch.zeros(B, H, W, dtype=torch.bool, device=x.device)

             feature_list[idx] = proj(feature_list[idx])
             mask_list.append(m)
             pos_list.append(self.pos_enc(feature_list[idx], m))

-
+        hs, init_reference, inter_references = self.transformer(
             feature_list, pos_list, self.query_embed.weight, mask_list
         )
         outputs_classes = []
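The `_class_loss` hunk above builds one-hot focal-loss targets with an extra "no object" column that is sliced off before the loss. A standalone sketch of that pattern with toy shapes (the matched indices and class ids here are made up):

```python
import torch
from torchvision.ops import sigmoid_focal_loss

# Toy shapes: 2 images, 5 queries, 3 foreground classes
cls_logits = torch.randn(2, 5, 3)
# Matched (image_idx, query_idx) pairs and their target class ids
idx = (torch.tensor([0, 1]), torch.tensor([2, 4]))
target_classes_o = torch.tensor([1, 0])

# One-hot targets with an extra "no object" column that is dropped afterwards
onehot = torch.zeros(cls_logits.size(0), cls_logits.size(1), cls_logits.size(2) + 1)
onehot[idx[0], idx[1], target_classes_o] = 1
onehot = onehot[:, :, :-1]

loss = sigmoid_focal_loss(cls_logits, onehot, alpha=0.25, gamma=2.0)
num_boxes = 2
loss_ce = (loss.mean(1).sum() / num_boxes) * cls_logits.size(1)
print(loss_ce)
```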
birder/net/detection/detr.py CHANGED

@@ -6,7 +6,7 @@ Paper "End-to-End Object Detection with Transformers", https://arxiv.org/abs/200

 Changes from original:
 * Move background index to first from last (to be inline with the rest of Birder detectors)
-*
+* Penalize cost matrix elements on overflow (HungarianMatcher)
 """

 # Reference license: Apache-2.0

@@ -51,7 +51,7 @@ class HungarianMatcher(nn.Module):
         self, class_logits: torch.Tensor, box_regression: torch.Tensor, targets: list[dict[str, torch.Tensor]]
     ) -> list[torch.Tensor]:
         with torch.no_grad():
-
+            B, num_queries = class_logits.shape[:2]

             # We flatten to compute the cost matrices in a batch
             out_prob = class_logits.flatten(0, 1).softmax(-1)  # [batch_size * num_queries, num_classes]

@@ -78,7 +78,10 @@ class HungarianMatcher(nn.Module):
             # Final cost matrix
             C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou
             C = C.view(B, num_queries, -1).cpu()
-
+            finite = torch.isfinite(C)
+            if not torch.all(finite):
+                penalty = C[finite].max().item() + 1.0 if finite.any().item() else 1.0
+                C.nan_to_num_(nan=penalty, posinf=penalty, neginf=penalty)

             sizes = [len(v["boxes"]) for v in targets]
             indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
@@ -108,7 +111,7 @@ class TransformerEncoderLayer(nn.Module):
         q = src + pos
         k = src + pos

-
+        src2, _ = self.self_attn(q, k, value=src, key_padding_mask=src_key_padding_mask, need_weights=False)
         src = src + self.dropout1(src2)
         src = self.norm1(src)
         src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))

@@ -148,10 +151,10 @@ class TransformerDecoderLayer(nn.Module):
         q = tgt + query_pos
         k = tgt + query_pos

-
+        tgt2, _ = self.self_attn(q, k, value=tgt, need_weights=False)
         tgt = tgt + self.dropout1(tgt2)
         tgt = self.norm1(tgt)
-
+        tgt2, _ = self.multihead_attn(
             query=tgt + query_pos,
             key=memory + pos,
             value=memory,

@@ -267,7 +270,7 @@ class PositionEmbeddingSine(nn.Module):

     def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
         if mask is None:
-
+            B, _, H, W = x.size()
             mask = torch.zeros(B, H, W, dtype=torch.bool, device=x.device)

         not_mask = ~mask

@@ -427,7 +430,7 @@ class DETR(DetectionBaseNet):
         for idx in range(cls_logits.size(0)):
             indices = self.matcher(cls_logits[idx], box_output[idx], targets)
             loss_ce_i = self._class_loss(cls_logits[idx], targets, indices)
-
+            loss_bbox_i, loss_giou_i = self._box_loss(box_output[idx], targets, indices, num_boxes)
             loss_ce_list.append(loss_ce_i)
             loss_bbox_list.append(loss_bbox_i)
             loss_giou_list.append(loss_giou_i)

@@ -447,7 +450,7 @@ class DETR(DetectionBaseNet):
         self, class_logits: torch.Tensor, box_regression: torch.Tensor, image_shapes: list[tuple[int, int]]
     ) -> list[dict[str, torch.Tensor]]:
         prob = F.softmax(class_logits, -1)
-
+        scores, labels = prob[..., 1:].max(-1)
         labels = labels + 1

         # TorchScript doesn't support creating tensor from tuples, convert everything to lists

@@ -457,7 +460,7 @@ class DETR(DetectionBaseNet):
         boxes = box_ops.box_convert(box_regression, in_fmt="cxcywh", out_fmt="xyxy")

         # Convert from relative [0, 1] to absolute [0, height] coordinates
-
+        img_h, img_w = target_sizes.unbind(1)
         scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
         boxes = boxes * scale_fct[:, None, :]

@@ -465,7 +468,7 @@ class DETR(DetectionBaseNet):
         for s, l, b in zip(scores, labels, boxes):
             # Non-maximum suppression
             if self.soft_nms is not None:
-
+                soft_scores, keep = self.soft_nms(b, s, l, score_threshold=0.001)
                 s[keep] = soft_scores

             b = b[keep]
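The DETR postprocessing hunks above take per-class probabilities excluding the background column (index 0 in this fork), convert relative `cxcywh` boxes to `xyxy`, and scale them to pixel coordinates. A toy, single-image sketch of those steps (image size and tensors are arbitrary):

```python
import torch
from torchvision.ops import box_convert

# Background occupies class index 0, so scores come from the remaining columns
class_logits = torch.randn(5, 4)  # 5 queries, background + 3 classes
prob = class_logits.softmax(-1)
scores, labels = prob[..., 1:].max(-1)
labels = labels + 1  # shift back to absolute class ids

# Boxes are predicted as relative (cx, cy, w, h); convert and scale to pixels
box_regression = torch.rand(5, 4)
boxes = box_convert(box_regression, in_fmt="cxcywh", out_fmt="xyxy")
img_h, img_w = 480, 640
scale_fct = torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
boxes = boxes * scale_fct
print(scores.shape, labels.shape, boxes.shape)
```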
birder/net/detection/efficientdet.py CHANGED

@@ -136,8 +136,8 @@ class ResampleFeatureMap(nn.Module):
         if self.conv is not None:
             x = self.conv(x)

-
-
+        in_h, in_w = x.shape[-2:]
+        target_h, target_w = target_size
         if in_h == target_h and in_w == target_w:
             return x

@@ -195,7 +195,7 @@ class FpnCombine(nn.Module):
         )

         if weight_method in {"attn", "fastattn"}:
-            self.edge_weights = nn.Parameter(torch.ones(len(inputs_offsets))
+            self.edge_weights = nn.Parameter(torch.ones(len(inputs_offsets)))  # WSM
         else:
             self.edge_weights = None

@@ -358,13 +358,7 @@ class HeadNet(nn.Module):
         for _ in range(repeats):
             layers.append(
                 nn.Conv2d(
-                    fpn_channels,
-                    fpn_channels,
-                    kernel_size=(3, 3),
-                    stride=(1, 1),
-                    padding=(1, 1),
-                    groups=fpn_channels,
-                    bias=True,
+                    fpn_channels, fpn_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=fpn_channels
                 )
             )
             layers.append(

@@ -383,22 +377,9 @@ class HeadNet(nn.Module):
         self.conv_repeat = nn.Sequential(*layers)
         self.predict = nn.Sequential(
             nn.Conv2d(
-                fpn_channels,
-                fpn_channels,
-                kernel_size=(3, 3),
-                stride=(1, 1),
-                padding=(1, 1),
-                groups=fpn_channels,
-                bias=True,
-            ),
-            nn.Conv2d(
-                fpn_channels,
-                num_outputs * num_anchors,
-                kernel_size=(1, 1),
-                stride=(1, 1),
-                padding=(0, 0),
-                bias=True,
+                fpn_channels, fpn_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=fpn_channels
             ),
+            nn.Conv2d(fpn_channels, num_outputs * num_anchors, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)),
         )

     def forward(self, x: list[torch.Tensor]) -> torch.Tensor:

@@ -453,7 +434,7 @@ class ClassificationHead(HeadNet):
         cls_logits = self.predict(cls_logits)

         # Permute classification output from (N, A * K, H, W) to (N, HWA, K).
-
+        N, _, H, W = cls_logits.shape
         cls_logits = cls_logits.view(N, -1, self.num_outputs, H, W)
         cls_logits = cls_logits.permute(0, 3, 4, 1, 2)
         cls_logits = cls_logits.reshape(N, -1, self.num_outputs)  # Size=(N, HWA, K)

@@ -504,7 +485,7 @@ class RegressionHead(HeadNet):
         bbox_regression = self.predict(bbox_regression)

         # Permute bbox regression output from (N, 4 * A, H, W) to (N, HWA, 4).
-
+        N, _, H, W = bbox_regression.shape
         bbox_regression = bbox_regression.view(N, -1, 4, H, W)
         bbox_regression = bbox_regression.permute(0, 3, 4, 1, 2)
         bbox_regression = bbox_regression.reshape(N, -1, 4)  # Size=(N, HWA, 4)

@@ -663,7 +644,7 @@ class EfficientDet(DetectionBaseNet):

         # Keep only topk scoring predictions
         num_topk = min(self.topk_candidates, int(topk_idxs.size(0)))
-
+        scores_per_level, idxs = scores_per_level.topk(num_topk)
         topk_idxs = topk_idxs[idxs]

         anchor_idxs = torch.div(topk_idxs, num_classes, rounding_mode="floor")

@@ -685,7 +666,7 @@ class EfficientDet(DetectionBaseNet):

         # Non-maximum suppression
         if self.soft_nms is not None:
-
+            soft_scores, keep = self.soft_nms(image_boxes, image_scores, image_labels, score_threshold=0.001)
             image_scores[keep] = soft_scores
         else:
             keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh)
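Several of the head hunks above reshape per-level outputs from `(N, A * K, H, W)` to `(N, HWA, K)` before concatenating levels. A small sketch of that view/permute/reshape chain on a dummy tensor:

```python
import torch

# Toy reshape matching the ClassificationHead hunk: per-level logits of
# shape (N, A * K, H, W) become (N, H*W*A, K)
N, A, K, H, W = 2, 9, 4, 5, 5
cls_logits = torch.randn(N, A * K, H, W)

out = cls_logits.view(N, -1, K, H, W)  # (N, A, K, H, W)
out = out.permute(0, 3, 4, 1, 2)       # (N, H, W, A, K)
out = out.reshape(N, -1, K)            # (N, H*W*A, K)
print(out.shape)                       # torch.Size([2, 225, 4])
```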
birder/net/detection/faster_rcnn.py CHANGED

@@ -150,7 +150,7 @@ def concat_box_prediction_layers(
     # all feature levels concatenated, so we keep the same representation
     # for the objectness and the box_regression
     for box_cls_per_level, box_regression_per_level in zip(box_cls, box_regression):
-
+        N, AxC, H, W = box_cls_per_level.shape  # pylint: disable=invalid-name
         Ax4 = box_regression_per_level.shape[1]  # pylint: disable=invalid-name
         A = Ax4 // 4
         C = AxC // A

@@ -240,7 +240,7 @@ class RegionProposalNetwork(nn.Module):

         # Get the targets corresponding GT for each proposal
         # NB: need to clamp the indices because we can have a single
-        # GT in the image
+        # GT in the image and matched_idxs can be -2, which goes out of bounds
         matched_gt_boxes_per_image = gt_boxes[matched_idxs.clamp(min=0)]

         labels_per_image = matched_idxs >= 0

@@ -265,7 +265,7 @@ class RegionProposalNetwork(nn.Module):
         for ob in objectness.split(num_anchors_per_level, 1):
             num_anchors = ob.shape[1]
             pre_nms_top_n = min(self.pre_nms_top_n(), int(ob.size(1)))
-
+            _, top_n_idx = ob.topk(pre_nms_top_n, dim=1)
             r.append(top_n_idx + offset)
             offset += num_anchors

@@ -310,19 +310,19 @@ class RegionProposalNetwork(nn.Module):

             # Remove small boxes
             keep = box_ops.remove_small_boxes(boxes, self.min_size)
-
+            boxes, scores, lvl = boxes[keep], scores[keep], lvl[keep]

             # Remove low scoring boxes
             # use >= for Backwards compatibility
             keep = torch.where(scores >= self.score_thresh)[0]
-
+            boxes, scores, lvl = boxes[keep], scores[keep], lvl[keep]

             # Non-maximum suppression, independently done per level
             keep = box_ops.batched_nms(boxes, scores, lvl, self.nms_thresh)

             # Keep only topk scoring predictions
             keep = keep[: self.post_nms_top_n()]
-
+            boxes, scores = boxes[keep], scores[keep]

             final_boxes.append(boxes)
             final_scores.append(scores)

@@ -336,7 +336,7 @@ class RegionProposalNetwork(nn.Module):
         labels: list[torch.Tensor],
         regression_targets: list[torch.Tensor],
     ) -> tuple[torch.Tensor, torch.Tensor]:
-
+        sampled_pos_idxs, sampled_neg_idxs = self.fg_bg_sampler(labels)
         sampled_pos_idxs = torch.where(torch.concat(sampled_pos_idxs, dim=0))[0]
         sampled_neg_idxs = torch.where(torch.concat(sampled_neg_idxs, dim=0))[0]

@@ -364,29 +364,29 @@ class RegionProposalNetwork(nn.Module):
     ) -> tuple[list[torch.Tensor], dict[str, torch.Tensor]]:
         # RPN uses all feature maps that are available
         features_list = list(features.values())
-
+        objectness, pred_bbox_deltas = self.head(features_list)
         anchors = self.anchor_generator(images, features_list)

         num_images = len(anchors)
         num_anchors_per_level_shape_tensors = [o[0].shape for o in objectness]
         num_anchors_per_level = [s[0] * s[1] * s[2] for s in num_anchors_per_level_shape_tensors]
-
+        objectness, pred_bbox_deltas = concat_box_prediction_layers(objectness, pred_bbox_deltas)

         # Apply pred_bbox_deltas to anchors to obtain the decoded proposals
         # note that we detach the deltas because Faster R-CNN do not backprop through
         # the proposals
         proposals = self.box_coder.decode(pred_bbox_deltas.detach(), anchors)
         proposals = proposals.view(num_images, -1, 4)
-
+        boxes, _scores = self.filter_proposals(proposals, objectness, images.image_sizes, num_anchors_per_level)

         losses: dict[str, torch.Tensor] = {}
         if self.training is True:
             if targets is None:
                 raise ValueError("targets should not be None")

-
+            labels, matched_gt_boxes = self.assign_targets_to_anchors(anchors, targets)
             regression_targets = self.box_coder.encode(matched_gt_boxes, anchors)
-
+            loss_objectness, loss_rpn_box_reg = self.compute_loss(
                 objectness, pred_bbox_deltas, labels, regression_targets
             )
             losses = {

@@ -405,7 +405,7 @@ class FastRCNNConvFCHead(nn.Sequential):
         fc_layers: list[int],
         norm_layer: Optional[Callable[..., nn.Module]] = None,
     ):
-
+        in_channels, in_height, in_width = input_size

         blocks = []
         previous_channels = in_channels

@@ -481,7 +481,7 @@ def faster_rcnn_loss(
     # advanced indexing
     sampled_pos_idxs_subset = torch.where(labels > 0)[0]
     labels_pos = labels[sampled_pos_idxs_subset]
-
+    N, _num_classes = class_logits.shape
     box_regression = box_regression.reshape(N, box_regression.size(-1) // 4, 4)

     box_loss = F.smooth_l1_loss(

@@ -573,7 +573,7 @@ class RoIHeads(nn.Module):
         return (matched_idxs, labels)

     def subsample(self, labels: list[torch.Tensor]) -> list[torch.Tensor]:
-
+        sampled_pos_idxs, sampled_neg_idxs = self.fg_bg_sampler(labels)
         sampled_idxs = []
         for pos_idxs_img, neg_idxs_img in zip(sampled_pos_idxs, sampled_neg_idxs):
             img_sampled_idxs = torch.where(pos_idxs_img | neg_idxs_img)[0]

@@ -610,7 +610,7 @@ class RoIHeads(nn.Module):
         proposals = self.add_gt_proposals(proposals, gt_boxes)

         # Get matching gt indices for each proposal
-
+        matched_idxs, labels = self.assign_targets_to_proposals(proposals, gt_boxes, gt_labels)

         # Sample a fixed proportion of positive-negative proposals
         sampled_idxs = self.subsample(labels)

@@ -713,7 +713,7 @@ class RoIHeads(nn.Module):
             raise TypeError(f"target labels must of int64 type, instead got {t['labels'].dtype}")

         if self.training is True:
-
+            proposals, _matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets)
         else:
             labels = None
             regression_targets = None

@@ -721,7 +721,7 @@ class RoIHeads(nn.Module):

         box_features = self.box_roi_pool(features, proposals, image_shapes)
         box_features = self.box_head(box_features)
-
+        class_logits, box_regression = self.box_predictor(box_features)

         losses = {}
         result: list[dict[str, torch.Tensor]] = []

@@ -731,11 +731,11 @@ class RoIHeads(nn.Module):
             if regression_targets is None:
                 raise ValueError("regression_targets cannot be None")

-
+            loss_classifier, loss_box_reg = faster_rcnn_loss(class_logits, box_regression, labels, regression_targets)
             losses = {"loss_classifier": loss_classifier, "loss_box_reg": loss_box_reg}

         else:
-
+            boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals, image_shapes)
             num_images = len(boxes)
             for i in range(num_images):
                 result.append(

@@ -868,8 +868,8 @@ class Faster_RCNN(DetectionBaseNet):
         images = self._to_img_list(x, image_sizes)

         features = self.backbone_with_fpn(x)
-
-
+        proposals, proposal_losses = self.rpn(images, features, targets)
+        detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets)

         losses = {}
         losses.update(detector_losses)
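The RPN hunks above mostly unpack results explicitly, but they sit inside the usual proposal-filtering pipeline: drop tiny boxes, drop low scores, per-level NMS, then a post-NMS top-k cut. A toy sketch of that pipeline using torchvision ops (boxes and thresholds are made up):

```python
import torch
from torchvision.ops import batched_nms, remove_small_boxes

boxes = torch.tensor([[0.0, 0.0, 1.0, 1.0], [10.0, 10.0, 60.0, 60.0], [12.0, 12.0, 58.0, 58.0]])
scores = torch.tensor([0.9, 0.8, 0.7])
lvl = torch.zeros(3, dtype=torch.int64)  # all proposals from the same FPN level

# Remove small boxes
keep = remove_small_boxes(boxes, min_size=2.0)
boxes, scores, lvl = boxes[keep], scores[keep], lvl[keep]

# Remove low scoring boxes
keep = torch.where(scores >= 0.05)[0]
boxes, scores, lvl = boxes[keep], scores[keep], lvl[keep]

# Per-level NMS, then keep only the top scoring proposals
keep = batched_nms(boxes, scores, lvl, iou_threshold=0.7)
keep = keep[:1000]
boxes, scores = boxes[keep], scores[keep]
print(boxes, scores)
```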
birder/net/detection/fcos.py CHANGED

@@ -125,7 +125,7 @@ class FCOSClassificationHead(nn.Module):
         cls_logits = self.cls_logits(cls_logits)

         # Permute classification output from (N, A * K, H, W) to (N, HWA, K).
-
+        N, _, H, W = cls_logits.size()
         cls_logits = cls_logits.view(N, -1, self.num_classes, H, W)
         cls_logits = cls_logits.permute(0, 3, 4, 1, 2)
         cls_logits = cls_logits.reshape(N, -1, self.num_classes)  # (N, HWA, 4)

@@ -165,7 +165,7 @@ class FCOSRegressionHead(nn.Module):
         bbox_ctrness = self.bbox_ctrness(bbox_feature)

         # Permute bbox regression output from (N, 4 * A, H, W) to (N, HWA, 4).
-
+        N, _, H, W = bbox_regression.size()
         bbox_regression = bbox_regression.view(N, -1, 4, H, W)
         bbox_regression = bbox_regression.permute(0, 3, 4, 1, 2)
         bbox_regression = bbox_regression.reshape(N, -1, 4)  # (N, HWA, 4)

@@ -262,7 +262,7 @@ class FCOSHead(nn.Module):

     def forward(self, x: list[torch.Tensor]) -> dict[str, torch.Tensor]:
         cls_logits = self.classification_head(x)
-
+        bbox_regression, bbox_ctrness = self.regression_head(x)

         return {
             "cls_logits": cls_logits,

@@ -370,8 +370,8 @@ class FCOS(DetectionBaseNet):
         ).values < self.center_sampling_radius * anchor_sizes[:, None]

         # Compute pairwise distance between N points and M boxes
-
-
+        x, y = anchor_centers.unsqueeze(dim=2).unbind(dim=1)  # (N, 1)
+        x0, y0, x1, y1 = gt_boxes.unsqueeze(dim=0).unbind(dim=2)  # (1, M)
         pairwise_dist = torch.stack([x - x0, y - y0, x1 - x, y1 - y], dim=2)  # (N, M)

         # Anchor point must be inside gt

@@ -388,7 +388,7 @@ class FCOS(DetectionBaseNet):
         # Match the GT box with minimum area, if there are multiple GT matches
         gt_areas = (gt_boxes[:, 2] - gt_boxes[:, 0]) * (gt_boxes[:, 3] - gt_boxes[:, 1])  # N
         pairwise_match = pairwise_match.to(torch.float32) * (1e8 - gt_areas[None, :])
-
+        min_values, matched_idx = pairwise_match.max(dim=1)  # R, per-anchor match
         matched_idx[min_values < 1e-5] = -1  # Unmatched anchors are assigned -1

         matched_idxs.append(matched_idx)

@@ -433,7 +433,7 @@ class FCOS(DetectionBaseNet):

         # Keep only topk scoring predictions
         num_topk = min(self.topk_candidates, int(topk_idxs.size(0)))
-
+        scores_per_level, idxs = scores_per_level.topk(num_topk)
         topk_idxs = topk_idxs[idxs]

         anchor_idxs = torch.div(topk_idxs, num_classes, rounding_mode="floor")

@@ -455,7 +455,7 @@ class FCOS(DetectionBaseNet):

         # Non-maximum suppression
         if self.soft_nms is not None:
-
+            soft_scores, keep = self.soft_nms(image_boxes, image_scores, image_labels, score_threshold=0.001)
             image_scores[keep] = soft_scores
         else:
             keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh)
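The FCOS hunk that rebuilds `pairwise_dist` relies on broadcasting `(N, 1)` anchor coordinates against `(1, M)` box edges. A small sketch of that pattern with two anchors and two ground-truth boxes:

```python
import torch

# Distances from N anchor centers to the four sides of M ground-truth boxes
anchor_centers = torch.tensor([[10.0, 10.0], [50.0, 40.0]])                   # (N, 2)
gt_boxes = torch.tensor([[0.0, 0.0, 30.0, 30.0], [20.0, 20.0, 80.0, 60.0]])  # (M, 4) as x0, y0, x1, y1

x, y = anchor_centers.unsqueeze(dim=2).unbind(dim=1)      # each (N, 1)
x0, y0, x1, y1 = gt_boxes.unsqueeze(dim=0).unbind(dim=2)  # each (1, M)
pairwise_dist = torch.stack([x - x0, y - y0, x1 - x, y1 - y], dim=2)  # (N, M, 4)

# An anchor lies inside a box iff all four signed distances are positive
inside = pairwise_dist.min(dim=2).values > 0
print(pairwise_dist.shape, inside)
```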