birder 0.4.1__py3-none-any.whl → 0.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- birder/common/training_cli.py +6 -1
- birder/common/training_utils.py +69 -12
- birder/net/_vit_configs.py +5 -0
- birder/net/cait.py +3 -3
- birder/net/coat.py +3 -3
- birder/net/deit.py +1 -1
- birder/net/deit3.py +1 -1
- birder/net/detection/__init__.py +2 -0
- birder/net/detection/deformable_detr.py +12 -12
- birder/net/detection/detr.py +7 -7
- birder/net/detection/lw_detr.py +1181 -0
- birder/net/detection/plain_detr.py +7 -5
- birder/net/detection/retinanet.py +1 -1
- birder/net/detection/rt_detr_v1.py +10 -10
- birder/net/detection/rt_detr_v2.py +47 -64
- birder/net/detection/ssdlite.py +2 -2
- birder/net/edgevit.py +3 -3
- birder/net/efficientvit_msft.py +1 -1
- birder/net/flexivit.py +1 -1
- birder/net/hieradet.py +2 -2
- birder/net/mnasnet.py +2 -2
- birder/net/resnext.py +2 -2
- birder/net/rope_deit3.py +1 -1
- birder/net/rope_flexivit.py +1 -1
- birder/net/rope_vit.py +1 -1
- birder/net/simple_vit.py +1 -1
- birder/net/vit.py +21 -3
- birder/net/vit_parallel.py +1 -1
- birder/net/vit_sam.py +62 -16
- birder/scripts/train.py +12 -8
- birder/scripts/train_capi.py +13 -10
- birder/scripts/train_detection.py +2 -1
- birder/scripts/train_kd.py +12 -8
- birder/version.py +1 -1
- {birder-0.4.1.dist-info → birder-0.4.2.dist-info}/METADATA +3 -3
- {birder-0.4.1.dist-info → birder-0.4.2.dist-info}/RECORD +40 -39
- {birder-0.4.1.dist-info → birder-0.4.2.dist-info}/WHEEL +1 -1
- {birder-0.4.1.dist-info → birder-0.4.2.dist-info}/entry_points.txt +0 -0
- {birder-0.4.1.dist-info → birder-0.4.2.dist-info}/licenses/LICENSE +0 -0
- {birder-0.4.1.dist-info → birder-0.4.2.dist-info}/top_level.txt +0 -0
@@ -522,13 +522,13 @@ class Plain_DETR(DetectionBaseNet):
 
         self.class_embed = nn.Linear(hidden_dim, self.num_classes)
         self.bbox_embed = MLP(hidden_dim, [hidden_dim, hidden_dim, 4], activation_layer=nn.ReLU)
-        self.query_embed = nn.
+        self.query_embed = nn.Parameter(torch.empty(self.num_queries, hidden_dim * 2))
         self.reference_point_head = MLP(hidden_dim, [hidden_dim, hidden_dim, 4], activation_layer=nn.ReLU)
         self.input_proj = nn.Conv2d(
             self.backbone.return_channels[-1], hidden_dim, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)
         )
         self.pos_enc = PositionEmbeddingSine(hidden_dim // 2, normalize=True)
-        self.matcher = HungarianMatcher(cost_class=2, cost_bbox=5, cost_giou=2)
+        self.matcher = HungarianMatcher(cost_class=2.0, cost_bbox=5.0, cost_giou=2.0)
 
         if box_refine is True:
             self.class_embed = _get_clones(self.class_embed, num_decoder_layers)

@@ -554,6 +554,7 @@ class Plain_DETR(DetectionBaseNet):
             if idx == 0:
                 nn.init.constant_(last_linear.bias[2:], -2.0) # Small initial wh
 
+        nn.init.normal_(self.query_embed)
         ref_last_linear = [m for m in self.reference_point_head.modules() if isinstance(m, nn.Linear)][-1]
         nn.init.zeros_(ref_last_linear.weight)
         nn.init.zeros_(ref_last_linear.bias)

@@ -576,7 +577,8 @@ class Plain_DETR(DetectionBaseNet):
         for param in self.class_embed.parameters():
             param.requires_grad_(True)
 
-
+    @staticmethod
+    def _get_src_permutation_idx(indices: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
         batch_idx = torch.concat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
         src_idx = torch.concat([src for (src, _) in indices])
         return (batch_idx, src_idx)

@@ -646,7 +648,7 @@ class Plain_DETR(DetectionBaseNet):
         if training_utils.is_dist_available_and_initialized() is True:
             torch.distributed.all_reduce(num_boxes)
 
-        num_boxes = torch.clamp(num_boxes / training_utils.get_world_size(), min=1)
+        num_boxes = torch.clamp(num_boxes / training_utils.get_world_size(), min=1)
 
         loss_ce_list = []
         loss_bbox_list = []

@@ -772,7 +774,7 @@ class Plain_DETR(DetectionBaseNet):
         else:
             num_queries_to_use = self.num_queries_one2one
 
-        query_embed = self.query_embed
+        query_embed = self.query_embed[:num_queries_to_use]
        query_embed, query_pos = torch.split(query_embed, self.hidden_dim, dim=1)
        query_embed = query_embed.unsqueeze(0).expand(B, -1, -1)
        query_pos = query_pos.unsqueeze(0).expand(B, -1, -1)
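For context, a minimal standalone sketch (shapes assumed, not birder code) of the query-embedding layout used above: a single (num_queries, 2 * hidden_dim) parameter is sliced to the number of active queries and then split into content and positional halves.

import torch

num_queries, hidden_dim, batch_size = 300, 256, 2
query_embed_param = torch.empty(num_queries, hidden_dim * 2)
torch.nn.init.normal_(query_embed_param)

num_queries_to_use = 100  # e.g. one-to-one queries only at inference
selected = query_embed_param[:num_queries_to_use]           # (100, 512)
content, pos = torch.split(selected, hidden_dim, dim=1)     # two (100, 256) halves
content = content.unsqueeze(0).expand(batch_size, -1, -1)   # (2, 100, 256)
pos = pos.unsqueeze(0).expand(batch_size, -1, -1)           # (2, 100, 256)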
@@ -63,7 +63,7 @@ class RetinaNetClassificationHead(nn.Module):
             if isinstance(layer, nn.Conv2d):
                 nn.init.normal_(layer.weight, std=0.01)
                 if layer.bias is not None:
-                    nn.init.
+                    nn.init.zeros_(layer.bias)
 
         self.cls_logits = nn.Conv2d(in_channels, num_anchors * num_classes, kernel_size=3, stride=1, padding=1)
 
@@ -596,18 +596,18 @@ class RT_DETRDecoder(nn.Module):
 
         # Gather reference points
         reference_points_unact = enc_outputs_coord_unact.gather(
-            dim=1, index=topk_ind.unsqueeze(-1).
+            dim=1, index=topk_ind.unsqueeze(-1).expand(-1, -1, enc_outputs_coord_unact.shape[-1])
         )
 
         enc_topk_bboxes = reference_points_unact.sigmoid()
 
         # Gather encoder logits for loss computation
         enc_topk_logits = enc_outputs_class.gather(
-            dim=1, index=topk_ind.unsqueeze(-1).
+            dim=1, index=topk_ind.unsqueeze(-1).expand(-1, -1, enc_outputs_class.shape[-1])
         )
 
         # Extract region features
-        target = output_memory.gather(dim=1, index=topk_ind.unsqueeze(-1).
+        target = output_memory.gather(dim=1, index=topk_ind.unsqueeze(-1).expand(-1, -1, output_memory.shape[-1]))
         target = target.detach()
 
         return (target, reference_points_unact.detach(), enc_topk_bboxes, enc_topk_logits)
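The recurring pattern above selects the top-k entries along the token dimension with torch.gather, expanding the index tensor to the feature width first. A minimal standalone sketch (shapes assumed, not birder code):

import torch

memory = torch.randn(2, 100, 256)          # (batch, tokens, channels)
scores = torch.randn(2, 100)
topk_ind = scores.topk(10, dim=1).indices   # (2, 10)

# gather needs an index with the same rank as the source tensor,
# so the (2, 10) indices are expanded across the channel dimension
index = topk_ind.unsqueeze(-1).expand(-1, -1, memory.shape[-1])  # (2, 10, 256)
selected = memory.gather(dim=1, index=index)                     # (2, 10, 256)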
@@ -653,7 +653,7 @@ class RT_DETRDecoder(nn.Module):
         reference_points = init_ref_points_unact.sigmoid()
         for decoder_layer, bbox_head, class_head in zip(self.layers, self.bbox_embed, self.class_embed):
             query_pos = self.query_pos_head(reference_points)
-            reference_points_input = reference_points.unsqueeze(2).
+            reference_points_input = reference_points.unsqueeze(2).expand(-1, -1, len(spatial_shapes), -1)
             target = decoder_layer(
                 target,
                 query_pos,

@@ -743,7 +743,7 @@ class RT_DETR_v1(DetectionBaseNet):
         self.decoder = RT_DETRDecoder(
             hidden_dim=hidden_dim,
             num_classes=self.num_classes,
-            num_queries=num_queries,
+            num_queries=self.num_queries,
             num_decoder_layers=num_decoder_layers,
             num_levels=self.num_levels,
             num_heads=num_heads,

@@ -810,7 +810,8 @@ class RT_DETR_v1(DetectionBaseNet):
         for param in self.denoising_class_embed.parameters():
             param.requires_grad_(True)
 
-
+    @staticmethod
+    def _get_src_permutation_idx(indices: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
         batch_idx = torch.concat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
         src_idx = torch.concat([src for (src, _) in indices])
         return (batch_idx, src_idx)
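The _get_src_permutation_idx helper, now a @staticmethod in several detection heads, turns the matcher output into a pair of flat index tensors for advanced indexing. A small illustration with made-up matcher indices (not birder code):

import torch

# matcher output: one (src_idx, tgt_idx) pair per image in the batch
indices = [
    (torch.tensor([3, 7]), torch.tensor([0, 1])),  # image 0: queries 3 and 7 matched
    (torch.tensor([1]), torch.tensor([0])),        # image 1: query 1 matched
]

batch_idx = torch.concat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])  # [0, 0, 1]
src_idx = torch.concat([src for (src, _) in indices])                                      # [3, 7, 1]

# the pair indexes the matched predictions out of a (batch, num_queries, ...) tensor
pred_boxes = torch.randn(2, 10, 4)
matched = pred_boxes[(batch_idx, src_idx)]  # shape (3, 4)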
@@ -927,8 +928,6 @@ class RT_DETR_v1(DetectionBaseNet):
 
         return (loss_ce_dn, loss_bbox_dn, loss_giou_dn)
 
-    @torch.jit.unused # type: ignore[untyped-decorator]
-    @torch.compiler.disable() # type: ignore[untyped-decorator]
     def _compute_loss_from_outputs(  # pylint: disable=too-many-locals
         self,
         targets: list[dict[str, torch.Tensor]],

@@ -946,7 +945,7 @@ class RT_DETR_v1(DetectionBaseNet):
         if training_utils.is_dist_available_and_initialized() is True:
             torch.distributed.all_reduce(num_boxes)
 
-        num_boxes = torch.clamp(num_boxes / training_utils.get_world_size(), min=1)
+        num_boxes = torch.clamp(num_boxes / training_utils.get_world_size(), min=1)
 
         loss_ce_list = []
         loss_bbox_list = []

@@ -1051,7 +1050,7 @@ class RT_DETR_v1(DetectionBaseNet):
 
         # Convert to [x0, y0, x1, y1] format
         boxes = box_ops.box_convert(box_regression, in_fmt="cxcywh", out_fmt="xyxy")
-        boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).
+        boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).expand(-1, -1, 4))
 
         # Convert from relative [0, 1] to absolute [0, height] coordinates
         img_h, img_w = target_sizes.unbind(1)

@@ -1113,6 +1112,7 @@ class RT_DETR_v1(DetectionBaseNet):
             else:
                 B, _, H, W = feat.size()
                 m = torch.zeros(B, H, W, dtype=torch.bool, device=x.device)
+
             mask_list.append(m)
 
         encoder_features = self.encoder(feature_list, masks=mask_list)
@@ -147,7 +147,7 @@ class MultiScaleDeformableAttention(nn.Module):
             param.requires_grad_(False)
 
     def reset_parameters(self) -> None:
-        nn.init.
+        nn.init.zeros_(self.sampling_offsets.weight)
         thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
         grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
         grid_init = grid_init / grid_init.abs().max(-1, keepdim=True)[0]

@@ -158,12 +158,12 @@ class MultiScaleDeformableAttention(nn.Module):
         with torch.no_grad():
             self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
 
-        nn.init.
-        nn.init.
+        nn.init.zeros_(self.attention_weights.weight)
+        nn.init.zeros_(self.attention_weights.bias)
         nn.init.xavier_uniform_(self.value_proj.weight)
-        nn.init.
+        nn.init.zeros_(self.value_proj.bias)
         nn.init.xavier_uniform_(self.output_proj.weight)
-        nn.init.
+        nn.init.zeros_(self.output_proj.bias)
 
     def forward(
         self,
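reset_parameters above zero-initializes the offset and attention-weight projections and seeds the sampling-offset bias with one direction per attention head. A standalone sketch of that directional grid (head count assumed, tiling across levels and points omitted, not birder code):

import math
import torch

n_heads = 8
thetas = torch.arange(n_heads, dtype=torch.float32) * (2.0 * math.pi / n_heads)
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)         # (8, 2) unit-circle directions
grid_init = grid_init / grid_init.abs().max(-1, keepdim=True)[0]  # scale so the largest coordinate is 1
# each head starts by sampling along its own direction; the flattened grid
# becomes the bias of the sampling-offset linear layer
print(grid_init.view(-1).shape)  # torch.Size([16])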
@@ -174,7 +174,7 @@ class MultiScaleDeformableAttention(nn.Module):
         input_level_start_index: torch.Tensor,
         input_padding_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-
+        num_queries = query.size(1)
         N, sequence_length, _ = input_flatten.size()
         assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == sequence_length
 
@@ -366,10 +366,9 @@ class TransformerDecoderLayer(nn.Module):
         self_attn_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         # Self attention
-
-        k = tgt + query_pos
+        q_k = tgt + query_pos
 
-        tgt2 = self.self_attn(
+        tgt2 = self.self_attn(q_k, q_k, tgt, attn_mask=self_attn_mask)
         tgt = tgt + self.dropout(tgt2)
         tgt = self.norm1(tgt)
 
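The decoder self-attention above adds the positional query embedding to both query and key while leaving the value position-free, the usual DETR-style scheme. A minimal sketch with nn.MultiheadAttention (batch_first layout assumed, dropout omitted, not birder code):

import torch
from torch import nn

hidden_dim, num_heads = 256, 8
self_attn = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)

tgt = torch.randn(2, 100, hidden_dim)        # decoder queries (content)
query_pos = torch.randn(2, 100, hidden_dim)  # learned positional part

q_k = tgt + query_pos          # position goes into Q and K only
tgt2, _ = self_attn(q_k, q_k, tgt)
tgt = tgt + tgt2               # residual connection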
@@ -526,18 +525,18 @@ class RT_DETRDecoder(nn.Module):
 
         # Gather reference points
         reference_points_unact = enc_outputs_coord_unact.gather(
-            dim=1, index=topk_ind.unsqueeze(-1).
+            dim=1, index=topk_ind.unsqueeze(-1).expand(-1, -1, enc_outputs_coord_unact.shape[-1])
         )
 
         enc_topk_bboxes = reference_points_unact.sigmoid()
 
         # Gather encoder logits for loss computation
         enc_topk_logits = enc_outputs_class.gather(
-            dim=1, index=topk_ind.unsqueeze(-1).
+            dim=1, index=topk_ind.unsqueeze(-1).expand(-1, -1, enc_outputs_class.shape[-1])
         )
 
         # Extract region features
-        target = output_memory.gather(dim=1, index=topk_ind.unsqueeze(-1).
+        target = output_memory.gather(dim=1, index=topk_ind.unsqueeze(-1).expand(-1, -1, output_memory.shape[-1]))
         target = target.detach()
 
         return (target, reference_points_unact.detach(), enc_topk_bboxes, enc_topk_logits)
@@ -583,7 +582,7 @@ class RT_DETRDecoder(nn.Module):
         reference_points = init_ref_points_unact.sigmoid()
         for decoder_layer, bbox_head, class_head in zip(self.layers, self.bbox_embed, self.class_embed):
             query_pos = self.query_pos_head(reference_points)
-            reference_points_input = reference_points.unsqueeze(2).
+            reference_points_input = reference_points.unsqueeze(2).expand(-1, -1, len(spatial_shapes), -1)
             target = decoder_layer(
                 target,
                 query_pos,

@@ -675,7 +674,7 @@ class RT_DETR_v2(DetectionBaseNet):
         self.decoder = RT_DETRDecoder(
             hidden_dim=hidden_dim,
             num_classes=self.num_classes,
-            num_queries=num_queries,
+            num_queries=self.num_queries,
             num_decoder_layers=num_decoder_layers,
             num_levels=self.num_levels,
             num_heads=num_heads,
@@ -744,20 +743,32 @@ class RT_DETR_v2(DetectionBaseNet):
         for param in self.denoising_class_embed.parameters():
             param.requires_grad_(True)
 
-
+    @staticmethod
+    def _get_src_permutation_idx(indices: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
         batch_idx = torch.concat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
         src_idx = torch.concat([src for (src, _) in indices])
         return (batch_idx, src_idx)
 
-    def
+    def _compute_layer_losses(
         self,
         cls_logits: torch.Tensor,
         box_output: torch.Tensor,
         targets: list[dict[str, torch.Tensor]],
         indices: list[torch.Tensor],
         num_boxes: float,
-    ) -> torch.Tensor:
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         idx = self._get_src_permutation_idx(indices)
+
+        src_boxes = box_output[idx]
+        target_boxes = torch.concat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)
+
+        src_boxes_xyxy = box_ops.box_convert(src_boxes, in_fmt="cxcywh", out_fmt="xyxy")
+        target_boxes_xyxy = box_ops.box_convert(target_boxes, in_fmt="cxcywh", out_fmt="xyxy")
+
+        # IoU for varifocal loss (class loss)
+        ious = torch.diag(box_ops.box_iou(src_boxes_xyxy, target_boxes_xyxy)).detach()
+
+        # Classification loss
         target_classes_o = torch.concat([t["labels"][J] for t, (_, J) in zip(targets, indices)], dim=0)
         target_classes = torch.full(cls_logits.shape[:2], self.num_classes, dtype=torch.int64, device=cls_logits.device)
         target_classes[idx] = target_classes_o

@@ -771,15 +782,6 @@ class RT_DETR_v2(DetectionBaseNet):
         target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1)
         target_classes_onehot = target_classes_onehot[:, :, :-1]
 
-        src_boxes = box_output[idx]
-        target_boxes = torch.concat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)
-        ious = torch.diag(
-            box_ops.box_iou(
-                box_ops.box_convert(src_boxes, in_fmt="cxcywh", out_fmt="xyxy"),
-                box_ops.box_convert(target_boxes, in_fmt="cxcywh", out_fmt="xyxy"),
-            )
-        ).detach()
-
         target_score_o = torch.zeros(cls_logits.shape[:2], dtype=cls_logits.dtype, device=cls_logits.device)
         target_score_o[idx] = ious.to(cls_logits.dtype)
         target_score = target_score_o.unsqueeze(-1) * target_classes_onehot

@@ -787,31 +789,13 @@ class RT_DETR_v2(DetectionBaseNet):
         loss = varifocal_loss(cls_logits, target_score, target_classes_onehot, alpha=0.75, gamma=2.0)
         loss_ce = (loss.mean(1).sum() / num_boxes) * cls_logits.shape[1]
 
-
+        # Box L1 loss
+        loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction="none").sum() / num_boxes
 
-
-
-        box_output: torch.Tensor,
-        targets: list[dict[str, torch.Tensor]],
-        indices: list[torch.Tensor],
-        num_boxes: float,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
-        idx = self._get_src_permutation_idx(indices)
-        src_boxes = box_output[idx]
-        target_boxes = torch.concat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)
-
-        loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction="none")
-        loss_bbox = loss_bbox.sum() / num_boxes
+        # GIoU loss
+        loss_giou = (1 - torch.diag(box_ops.generalized_box_iou(src_boxes_xyxy, target_boxes_xyxy))).sum() / num_boxes
 
-
-            box_ops.generalized_box_iou(
-                box_ops.box_convert(src_boxes, in_fmt="cxcywh", out_fmt="xyxy"),
-                box_ops.box_convert(target_boxes, in_fmt="cxcywh", out_fmt="xyxy"),
-            )
-        )
-        loss_giou = loss_giou.sum() / num_boxes
-
-        return (loss_bbox, loss_giou)
+        return (loss_ce, loss_bbox, loss_giou)
 
     def _compute_denoising_loss(
         self,
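The refactor above folds the separate box-loss helper into a single _compute_layer_losses that returns (loss_ce, loss_bbox, loss_giou). A self-contained sketch of the box-related part of that computation using torchvision's box ops (simplified, matched pairs assumed, not birder code):

import torch
import torch.nn.functional as F
from torchvision.ops import box_convert, box_iou, generalized_box_iou

# matched predictions and targets in cxcywh format, already aligned one-to-one
src_boxes = torch.tensor([[0.5, 0.5, 0.2, 0.2], [0.3, 0.3, 0.1, 0.4]])
target_boxes = torch.tensor([[0.5, 0.5, 0.25, 0.2], [0.35, 0.3, 0.1, 0.4]])
num_boxes = float(len(target_boxes))

src_xyxy = box_convert(src_boxes, in_fmt="cxcywh", out_fmt="xyxy")
target_xyxy = box_convert(target_boxes, in_fmt="cxcywh", out_fmt="xyxy")

ious = torch.diag(box_iou(src_xyxy, target_xyxy))  # per-pair IoU, feeds the varifocal targets
loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction="none").sum() / num_boxes
loss_giou = (1 - torch.diag(generalized_box_iou(src_xyxy, target_xyxy))).sum() / num_boxes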
@@ -846,11 +830,9 @@ class RT_DETR_v2(DetectionBaseNet):
                 )
             )
 
-            loss_ce = self.
+            loss_ce, loss_bbox, loss_giou = self._compute_layer_losses(
                 dn_out_logits[layer_idx], dn_out_bboxes[layer_idx], targets, indices, dn_num_boxes
             )
-            loss_bbox, loss_giou = self._box_loss(dn_out_bboxes[layer_idx], targets, indices, dn_num_boxes)
-
             loss_ce_list.append(loss_ce)
             loss_bbox_list.append(loss_bbox)
             loss_giou_list.append(loss_giou)

@@ -861,9 +843,7 @@ class RT_DETR_v2(DetectionBaseNet):
 
         return (loss_ce_dn, loss_bbox_dn, loss_giou_dn)
 
-
-    @torch.compiler.disable() # type: ignore[untyped-decorator]
-    def _compute_loss_from_outputs( # pylint: disable=too-many-locals
+    def _compute_loss_from_outputs(
         self,
         targets: list[dict[str, torch.Tensor]],
         out_bboxes: torch.Tensor,

@@ -880,7 +860,7 @@ class RT_DETR_v2(DetectionBaseNet):
         if training_utils.is_dist_available_and_initialized() is True:
             torch.distributed.all_reduce(num_boxes)
 
-        num_boxes = torch.clamp(num_boxes / training_utils.get_world_size(), min=1)
+        num_boxes = torch.clamp(num_boxes / training_utils.get_world_size(), min=1)
 
         loss_ce_list = []
         loss_bbox_list = []

@@ -889,19 +869,21 @@ class RT_DETR_v2(DetectionBaseNet):
         # Decoder losses (all layers)
         for layer_idx in range(out_logits.shape[0]):
             indices = self.matcher(out_logits[layer_idx], out_bboxes[layer_idx], targets)
-            loss_ce = self.
-
+            loss_ce, loss_bbox, loss_giou = self._compute_layer_losses(
+                out_logits[layer_idx], out_bboxes[layer_idx], targets, indices, num_boxes
+            )
             loss_ce_list.append(loss_ce)
             loss_bbox_list.append(loss_bbox)
             loss_giou_list.append(loss_giou)
 
         # Encoder auxiliary loss
         enc_indices = self.matcher(enc_topk_logits, enc_topk_bboxes, targets)
-
-
-
-
-
+        loss_ce, loss_bbox, loss_giou = self._compute_layer_losses(
+            enc_topk_logits, enc_topk_bboxes, targets, enc_indices, num_boxes
+        )
+        loss_ce_list.append(loss_ce)
+        loss_bbox_list.append(loss_bbox)
+        loss_giou_list.append(loss_giou)
 
         loss_ce = torch.stack(loss_ce_list).sum()  # VFL weight is 1
         loss_bbox = torch.stack(loss_bbox_list).sum() * 5

@@ -985,7 +967,7 @@ class RT_DETR_v2(DetectionBaseNet):
 
         # Convert to [x0, y0, x1, y1] format
         boxes = box_ops.box_convert(box_regression, in_fmt="cxcywh", out_fmt="xyxy")
-        boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).
+        boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).expand(-1, -1, 4))
 
         # Convert from relative [0, 1] to absolute [0, height] coordinates
         img_h, img_w = target_sizes.unbind(1)

@@ -1047,6 +1029,7 @@ class RT_DETR_v2(DetectionBaseNet):
             else:
                 B, _, H, W = feat.size()
                 m = torch.zeros(B, H, W, dtype=torch.bool, device=x.device)
+
             mask_list.append(m)
 
         encoder_features = self.encoder(feature_list, masks=mask_list)
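Per-layer decoder losses, plus the encoder auxiliary term, are stacked and weighted once at the end; the hunk above shows the VFL term kept at weight 1 and the L1 term scaled by 5. A toy sketch of that aggregation (the GIoU weight is assumed, as it is not visible in the hunk; not birder code):

import torch

# one entry per decoder layer (plus the encoder auxiliary term)
loss_ce_list = [torch.tensor(0.9), torch.tensor(0.7), torch.tensor(0.6)]
loss_bbox_list = [torch.tensor(0.20), torch.tensor(0.15), torch.tensor(0.12)]
loss_giou_list = [torch.tensor(0.30), torch.tensor(0.25), torch.tensor(0.22)]

loss_ce = torch.stack(loss_ce_list).sum()          # VFL weight is 1
loss_bbox = torch.stack(loss_bbox_list).sum() * 5  # L1 weight from the hunk above
loss_giou = torch.stack(loss_giou_list).sum() * 2  # GIoU weight assumed
total = loss_ce + loss_bbox + loss_giou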
birder/net/detection/ssdlite.py
CHANGED
@@ -50,7 +50,7 @@ class SSDLiteClassificationHead(SSDScoringHead):
             if isinstance(layer, nn.Conv2d):
                 nn.init.xavier_uniform_(layer.weight)
                 if layer.bias is not None:
-                    nn.init.
+                    nn.init.zeros_(layer.bias)
 
         super().__init__(cls_logits, num_classes)
 

@@ -79,7 +79,7 @@ class SSDLiteRegressionHead(SSDScoringHead):
             if isinstance(layer, nn.Conv2d):
                 nn.init.xavier_uniform_(layer.weight)
                 if layer.bias is not None:
-                    nn.init.
+                    nn.init.zeros_(layer.bias)
 
         super().__init__(bbox_reg, 4)
 
birder/net/edgevit.py
CHANGED
@@ -332,11 +332,11 @@ class EdgeViT(DetectorBackbone):
             if isinstance(m, nn.Linear):
                 nn.init.trunc_normal_(m.weight, std=0.02)
                 if m.bias is not None:
-                    nn.init.
+                    nn.init.zeros_(m.bias)
 
             elif isinstance(m, nn.LayerNorm):
-                nn.init.
-                nn.init.
+                nn.init.zeros_(m.bias)
+                nn.init.ones_(m.weight)
 
     def detection_features(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
         out = {}
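Several backbones in this release settle on the same initialization pattern: zeros for biases, ones for normalization weights. A minimal sketch of such an init loop (generic module, not birder code):

import torch
from torch import nn

def init_weights(model: nn.Module) -> None:
    for m in model.modules():
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)

init_weights(nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8)))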
birder/net/efficientvit_msft.py
CHANGED
birder/net/flexivit.py
CHANGED
@@ -314,7 +314,7 @@ class FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Mask
         xs = self.encoder.forward_features(x, out_indices=self.out_indices)
 
         out: dict[str, torch.Tensor] = {}
-        for stage_name, stage_x in zip(self.return_stages, xs):
+        for stage_name, stage_x in zip(self.return_stages, xs, strict=True):
             stage_x = stage_x[:, self.num_special_tokens :]
             stage_x = stage_x.permute(0, 2, 1)
             B, C, _ = stage_x.size()
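The recurring zip(..., strict=True) change (here and in the ViT-family backbones below) makes a silent length mismatch between return_stages and the collected feature maps raise instead of truncating. A short illustration (Python 3.10+, not birder code):

stages = ["stage1", "stage2", "stage3"]
features = ["f1", "f2"]  # one feature map missing

list(zip(stages, features))               # silently drops "stage3"
list(zip(stages, features, strict=True))  # raises ValueError: argument 2 is shorter than argument 1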
birder/net/hieradet.py
CHANGED
@@ -613,11 +613,11 @@ registry.register_weights( # SAM v2: https://arxiv.org/abs/2408.00714
             "HieraDet small image encoder pre-trained by Meta AI using SAM v2. "
             "This model has not been fine-tuned for a specific classification task"
         ),
-        "resolution": (
+        "resolution": (1024, 1024),
         "formats": {
             "pt": {
                 "file_size": 129.6,
-                "sha256": "
+                "sha256": "2ede3a78389ca74ed37d82dbc1c3410549f1fdafb5a7a94ac02968aa6d3dec80",
             }
         },
         "net": {"network": "hieradet_small", "tag": "sam2_1"},
birder/net/mnasnet.py
CHANGED
@@ -230,8 +230,8 @@ class MNASNet(DetectorBackbone):
                 nn.init.zeros_(m.bias)
 
             elif isinstance(m, nn.BatchNorm2d):
-                nn.init.
-                nn.init.
+                nn.init.ones_(m.weight)
+                nn.init.zeros_(m.bias)
 
             elif isinstance(m, nn.Linear):
                 nn.init.kaiming_uniform_(m.weight, mode="fan_out", nonlinearity="sigmoid")
birder/net/resnext.py
CHANGED
@@ -205,8 +205,8 @@ class ResNeXt(DetectorBackbone):
                 nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
 
             elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
-                nn.init.
-                nn.init.
+                nn.init.ones_(m.weight)
+                nn.init.zeros_(m.bias)
 
     def detection_features(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
         x = self.stem(x)
birder/net/rope_deit3.py
CHANGED
@@ -249,7 +249,7 @@ class RoPE_DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Ma
         xs = self.encoder.forward_features(x, rope, out_indices=self.out_indices)
 
         out: dict[str, torch.Tensor] = {}
-        for stage_name, stage_x in zip(self.return_stages, xs):
+        for stage_name, stage_x in zip(self.return_stages, xs, strict=True):
             stage_x = stage_x[:, self.num_special_tokens :]
             stage_x = stage_x.permute(0, 2, 1)
             B, C, _ = stage_x.size()
birder/net/rope_flexivit.py
CHANGED
@@ -342,7 +342,7 @@ class RoPE_FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin,
         xs = self.encoder.forward_features(x, rope, out_indices=self.out_indices)
 
         out: dict[str, torch.Tensor] = {}
-        for stage_name, stage_x in zip(self.return_stages, xs):
+        for stage_name, stage_x in zip(self.return_stages, xs, strict=True):
             stage_x = stage_x[:, self.num_special_tokens :]
             stage_x = stage_x.permute(0, 2, 1)
             B, C, _ = stage_x.size()
birder/net/rope_vit.py
CHANGED
@@ -698,7 +698,7 @@ class RoPE_ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Mask
         xs = self.encoder.forward_features(x, rope, out_indices=self.out_indices)
 
         out: dict[str, torch.Tensor] = {}
-        for stage_name, stage_x in zip(self.return_stages, xs):
+        for stage_name, stage_x in zip(self.return_stages, xs, strict=True):
             stage_x = stage_x[:, self.num_special_tokens :]
             stage_x = stage_x.permute(0, 2, 1)
             B, C, _ = stage_x.size()
birder/net/simple_vit.py
CHANGED
@@ -215,7 +215,7 @@ class Simple_ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin):
         xs = self.encoder.forward_features(x, out_indices=self.out_indices)
 
         out: dict[str, torch.Tensor] = {}
-        for stage_name, stage_x in zip(self.return_stages, xs):
+        for stage_name, stage_x in zip(self.return_stages, xs, strict=True):
             stage_x = stage_x[:, self.num_special_tokens :]
             stage_x = stage_x.permute(0, 2, 1)
             B, C, _ = stage_x.size()
birder/net/vit.py
CHANGED
@@ -572,7 +572,7 @@ class ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedTok
         xs = self.encoder.forward_features(x, out_indices=self.out_indices)
 
         out: dict[str, torch.Tensor] = {}
-        for stage_name, stage_x in zip(self.return_stages, xs):
+        for stage_name, stage_x in zip(self.return_stages, xs, strict=True):
             stage_x = stage_x[:, self.num_special_tokens :]
             stage_x = stage_x.permute(0, 2, 1)
             B, C, _ = stage_x.size()

@@ -802,6 +802,24 @@ class ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedTok
 # Register model configs (side effects)
 register_vit_configs(ViT)
 
+registry.register_weights( # BioCLIP v1: https://arxiv.org/abs/2311.18803
+    "vit_b16_pn_bioclip-v1",
+    {
+        "url": "https://huggingface.co/birder-project/vit_b16_pn_bioclip-v1/resolve/main",
+        "description": (
+            "ViT b16 image encoder pre-trained by Imageomics using CLIP on the TreeOfLife-10M dataset. "
+            "This model has not been fine-tuned for a specific classification task"
+        ),
+        "resolution": (224, 224),
+        "formats": {
+            "pt": {
+                "file_size": 328.9,
+                "sha256": "9b2e5598f233657932eeb77e027cd4c4d683bf75515768fe6971cab6ec10bf15",
+            },
+        },
+        "net": {"network": "vit_b16_pn", "tag": "bioclip-v1"},
+    },
+)
 registry.register_weights(
     "vit_l16_mim_200",
     {

@@ -849,8 +867,8 @@ registry.register_weights( # BioCLIP v2: https://arxiv.org/abs/2505.23883
         "resolution": (224, 224),
         "formats": {
             "pt": {
-                "file_size":
-                "sha256": "
+                "file_size": 1159.7,
+                "sha256": "301a325579dafdfa2ea13b0cbaf8129211ecd1429c29afa20d1c2eaaa91d8b0d",
             },
         },
         "net": {"network": "vit_l14_pn", "tag": "bioclip-v2"},
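Both new weight entries ship a file_size and sha256 field. A small standalone sketch for checking a downloaded checkpoint against the registered digest (hashlib only; the local path is hypothetical, not birder code):

import hashlib
from pathlib import Path

def sha256_of(path: Path) -> str:
    digest = hashlib.sha256()
    with path.open("rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            digest.update(chunk)
    return digest.hexdigest()

# hypothetical local path for the newly registered BioCLIP v1 weights
weights = Path("models/vit_b16_pn_bioclip-v1.pt")
expected = "9b2e5598f233657932eeb77e027cd4c4d683bf75515768fe6971cab6ec10bf15"
if weights.exists():
    assert sha256_of(weights) == expected, "checksum mismatch"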
birder/net/vit_parallel.py
CHANGED
@@ -370,7 +370,7 @@ class ViT_Parallel(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin,
         xs = self.encoder.forward_features(x, out_indices=self.out_indices)
 
         out: dict[str, torch.Tensor] = {}
-        for stage_name, stage_x in zip(self.return_stages, xs):
+        for stage_name, stage_x in zip(self.return_stages, xs, strict=True):
             stage_x = stage_x[:, self.num_special_tokens :]
             stage_x = stage_x.permute(0, 2, 1)
             B, C, _ = stage_x.size()