birder 0.4.1__py3-none-any.whl → 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. birder/common/training_cli.py +6 -1
  2. birder/common/training_utils.py +69 -12
  3. birder/net/_vit_configs.py +5 -0
  4. birder/net/cait.py +3 -3
  5. birder/net/coat.py +3 -3
  6. birder/net/deit.py +1 -1
  7. birder/net/deit3.py +1 -1
  8. birder/net/detection/__init__.py +2 -0
  9. birder/net/detection/deformable_detr.py +12 -12
  10. birder/net/detection/detr.py +7 -7
  11. birder/net/detection/lw_detr.py +1181 -0
  12. birder/net/detection/plain_detr.py +7 -5
  13. birder/net/detection/retinanet.py +1 -1
  14. birder/net/detection/rt_detr_v1.py +10 -10
  15. birder/net/detection/rt_detr_v2.py +47 -64
  16. birder/net/detection/ssdlite.py +2 -2
  17. birder/net/edgevit.py +3 -3
  18. birder/net/efficientvit_msft.py +1 -1
  19. birder/net/flexivit.py +1 -1
  20. birder/net/hieradet.py +2 -2
  21. birder/net/mnasnet.py +2 -2
  22. birder/net/resnext.py +2 -2
  23. birder/net/rope_deit3.py +1 -1
  24. birder/net/rope_flexivit.py +1 -1
  25. birder/net/rope_vit.py +1 -1
  26. birder/net/simple_vit.py +1 -1
  27. birder/net/vit.py +21 -3
  28. birder/net/vit_parallel.py +1 -1
  29. birder/net/vit_sam.py +62 -16
  30. birder/scripts/train.py +12 -8
  31. birder/scripts/train_capi.py +13 -10
  32. birder/scripts/train_detection.py +2 -1
  33. birder/scripts/train_kd.py +12 -8
  34. birder/version.py +1 -1
  35. {birder-0.4.1.dist-info → birder-0.4.2.dist-info}/METADATA +3 -3
  36. {birder-0.4.1.dist-info → birder-0.4.2.dist-info}/RECORD +40 -39
  37. {birder-0.4.1.dist-info → birder-0.4.2.dist-info}/WHEEL +1 -1
  38. {birder-0.4.1.dist-info → birder-0.4.2.dist-info}/entry_points.txt +0 -0
  39. {birder-0.4.1.dist-info → birder-0.4.2.dist-info}/licenses/LICENSE +0 -0
  40. {birder-0.4.1.dist-info → birder-0.4.2.dist-info}/top_level.txt +0 -0
birder/net/detection/plain_detr.py CHANGED
@@ -522,13 +522,13 @@ class Plain_DETR(DetectionBaseNet):
 
  self.class_embed = nn.Linear(hidden_dim, self.num_classes)
  self.bbox_embed = MLP(hidden_dim, [hidden_dim, hidden_dim, 4], activation_layer=nn.ReLU)
- self.query_embed = nn.Embedding(self.num_queries, hidden_dim * 2)
+ self.query_embed = nn.Parameter(torch.empty(self.num_queries, hidden_dim * 2))
  self.reference_point_head = MLP(hidden_dim, [hidden_dim, hidden_dim, 4], activation_layer=nn.ReLU)
  self.input_proj = nn.Conv2d(
  self.backbone.return_channels[-1], hidden_dim, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)
  )
  self.pos_enc = PositionEmbeddingSine(hidden_dim // 2, normalize=True)
- self.matcher = HungarianMatcher(cost_class=2, cost_bbox=5, cost_giou=2)
+ self.matcher = HungarianMatcher(cost_class=2.0, cost_bbox=5.0, cost_giou=2.0)
 
  if box_refine is True:
  self.class_embed = _get_clones(self.class_embed, num_decoder_layers)
@@ -554,6 +554,7 @@ class Plain_DETR(DetectionBaseNet):
  if idx == 0:
  nn.init.constant_(last_linear.bias[2:], -2.0) # Small initial wh
 
+ nn.init.normal_(self.query_embed)
  ref_last_linear = [m for m in self.reference_point_head.modules() if isinstance(m, nn.Linear)][-1]
  nn.init.zeros_(ref_last_linear.weight)
  nn.init.zeros_(ref_last_linear.bias)
@@ -576,7 +577,8 @@ class Plain_DETR(DetectionBaseNet):
  for param in self.class_embed.parameters():
  param.requires_grad_(True)
 
- def _get_src_permutation_idx(self, indices: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+ @staticmethod
+ def _get_src_permutation_idx(indices: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
  batch_idx = torch.concat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
  src_idx = torch.concat([src for (src, _) in indices])
  return (batch_idx, src_idx)
@@ -646,7 +648,7 @@ class Plain_DETR(DetectionBaseNet):
  if training_utils.is_dist_available_and_initialized() is True:
  torch.distributed.all_reduce(num_boxes)
 
- num_boxes = torch.clamp(num_boxes / training_utils.get_world_size(), min=1).item()
+ num_boxes = torch.clamp(num_boxes / training_utils.get_world_size(), min=1)
 
  loss_ce_list = []
  loss_bbox_list = []
@@ -772,7 +774,7 @@ class Plain_DETR(DetectionBaseNet):
  else:
  num_queries_to_use = self.num_queries_one2one
 
- query_embed = self.query_embed.weight[:num_queries_to_use]
+ query_embed = self.query_embed[:num_queries_to_use]
  query_embed, query_pos = torch.split(query_embed, self.hidden_dim, dim=1)
  query_embed = query_embed.unsqueeze(0).expand(B, -1, -1)
  query_pos = query_pos.unsqueeze(0).expand(B, -1, -1)
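Note on the Plain_DETR change above: an nn.Embedding used only as a weight table can be replaced by a plain nn.Parameter of the same shape without changing behaviour, since nn.Embedding initializes its weight from N(0, 1) and the new nn.init.normal_ call reproduces that. A minimal sketch of the equivalence (names and sizes are illustrative, not taken from the package):

import torch
from torch import nn

num_queries, hidden_dim = 300, 256

# Old style: an Embedding kept only for its weight matrix
emb = nn.Embedding(num_queries, hidden_dim * 2)   # weight initialized from N(0, 1)
old_queries = emb.weight[:100]                    # must go through .weight

# New style: a bare Parameter with the same shape and init
par = nn.Parameter(torch.empty(num_queries, hidden_dim * 2))
nn.init.normal_(par)                              # N(0, 1) by default
new_queries = par[:100]                           # sliced directly

assert old_queries.shape == new_queries.shape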
birder/net/detection/retinanet.py CHANGED
@@ -63,7 +63,7 @@ class RetinaNetClassificationHead(nn.Module):
  if isinstance(layer, nn.Conv2d):
  nn.init.normal_(layer.weight, std=0.01)
  if layer.bias is not None:
- nn.init.constant_(layer.bias, 0)
+ nn.init.zeros_(layer.bias)
 
  self.cls_logits = nn.Conv2d(in_channels, num_anchors * num_classes, kernel_size=3, stride=1, padding=1)
 
birder/net/detection/rt_detr_v1.py CHANGED
@@ -596,18 +596,18 @@ class RT_DETRDecoder(nn.Module):
 
  # Gather reference points
  reference_points_unact = enc_outputs_coord_unact.gather(
- dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_coord_unact.shape[-1])
+ dim=1, index=topk_ind.unsqueeze(-1).expand(-1, -1, enc_outputs_coord_unact.shape[-1])
  )
 
  enc_topk_bboxes = reference_points_unact.sigmoid()
 
  # Gather encoder logits for loss computation
  enc_topk_logits = enc_outputs_class.gather(
- dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_class.shape[-1])
+ dim=1, index=topk_ind.unsqueeze(-1).expand(-1, -1, enc_outputs_class.shape[-1])
  )
 
  # Extract region features
- target = output_memory.gather(dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, output_memory.shape[-1]))
+ target = output_memory.gather(dim=1, index=topk_ind.unsqueeze(-1).expand(-1, -1, output_memory.shape[-1]))
  target = target.detach()
 
  return (target, reference_points_unact.detach(), enc_topk_bboxes, enc_topk_logits)
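Several hunks in this release swap Tensor.repeat for Tensor.expand when building gather indices. Both yield the same index values, but expand returns a broadcast view instead of materializing copies, so no extra memory is allocated for the index tensor. A small self-contained check of the equivalence (shapes chosen arbitrarily):

import torch

enc_outputs_coord = torch.randn(2, 100, 4)     # (batch, proposals, box dims)
topk_ind = torch.randint(0, 100, (2, 10))      # top-k proposal indices per image

idx_repeat = topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_coord.shape[-1])    # allocates a copy
idx_expand = topk_ind.unsqueeze(-1).expand(-1, -1, enc_outputs_coord.shape[-1])  # zero-copy view

assert torch.equal(idx_repeat, idx_expand)
assert torch.equal(
    enc_outputs_coord.gather(dim=1, index=idx_repeat),
    enc_outputs_coord.gather(dim=1, index=idx_expand),
)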
@@ -653,7 +653,7 @@ class RT_DETRDecoder(nn.Module):
  reference_points = init_ref_points_unact.sigmoid()
  for decoder_layer, bbox_head, class_head in zip(self.layers, self.bbox_embed, self.class_embed):
  query_pos = self.query_pos_head(reference_points)
- reference_points_input = reference_points.unsqueeze(2).repeat(1, 1, len(spatial_shapes), 1)
+ reference_points_input = reference_points.unsqueeze(2).expand(-1, -1, len(spatial_shapes), -1)
  target = decoder_layer(
  target,
  query_pos,
@@ -743,7 +743,7 @@ class RT_DETR_v1(DetectionBaseNet):
  self.decoder = RT_DETRDecoder(
  hidden_dim=hidden_dim,
  num_classes=self.num_classes,
- num_queries=num_queries,
+ num_queries=self.num_queries,
  num_decoder_layers=num_decoder_layers,
  num_levels=self.num_levels,
  num_heads=num_heads,
@@ -810,7 +810,8 @@ class RT_DETR_v1(DetectionBaseNet):
  for param in self.denoising_class_embed.parameters():
  param.requires_grad_(True)
 
- def _get_src_permutation_idx(self, indices: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+ @staticmethod
+ def _get_src_permutation_idx(indices: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
  batch_idx = torch.concat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
  src_idx = torch.concat([src for (src, _) in indices])
  return (batch_idx, src_idx)
@@ -927,8 +928,6 @@ class RT_DETR_v1(DetectionBaseNet):
 
  return (loss_ce_dn, loss_bbox_dn, loss_giou_dn)
 
- @torch.jit.unused # type: ignore[untyped-decorator]
- @torch.compiler.disable() # type: ignore[untyped-decorator]
  def _compute_loss_from_outputs( # pylint: disable=too-many-locals
  self,
  targets: list[dict[str, torch.Tensor]],
@@ -946,7 +945,7 @@ class RT_DETR_v1(DetectionBaseNet):
  if training_utils.is_dist_available_and_initialized() is True:
  torch.distributed.all_reduce(num_boxes)
 
- num_boxes = torch.clamp(num_boxes / training_utils.get_world_size(), min=1).item()
+ num_boxes = torch.clamp(num_boxes / training_utils.get_world_size(), min=1)
 
  loss_ce_list = []
  loss_bbox_list = []
@@ -1051,7 +1050,7 @@ class RT_DETR_v1(DetectionBaseNet):
 
  # Convert to [x0, y0, x1, y1] format
  boxes = box_ops.box_convert(box_regression, in_fmt="cxcywh", out_fmt="xyxy")
- boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4))
+ boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).expand(-1, -1, 4))
 
  # Convert from relative [0, 1] to absolute [0, height] coordinates
  img_h, img_w = target_sizes.unbind(1)
@@ -1113,6 +1112,7 @@ class RT_DETR_v1(DetectionBaseNet):
  else:
  B, _, H, W = feat.size()
  m = torch.zeros(B, H, W, dtype=torch.bool, device=x.device)
+
  mask_list.append(m)
 
  encoder_features = self.encoder(feature_list, masks=mask_list)
birder/net/detection/rt_detr_v2.py CHANGED
@@ -147,7 +147,7 @@ class MultiScaleDeformableAttention(nn.Module):
  param.requires_grad_(False)
 
  def reset_parameters(self) -> None:
- nn.init.constant_(self.sampling_offsets.weight, 0.0)
+ nn.init.zeros_(self.sampling_offsets.weight)
  thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
  grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
  grid_init = grid_init / grid_init.abs().max(-1, keepdim=True)[0]
@@ -158,12 +158,12 @@ class MultiScaleDeformableAttention(nn.Module):
  with torch.no_grad():
  self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
 
- nn.init.constant_(self.attention_weights.weight, 0.0)
- nn.init.constant_(self.attention_weights.bias, 0.0)
+ nn.init.zeros_(self.attention_weights.weight)
+ nn.init.zeros_(self.attention_weights.bias)
  nn.init.xavier_uniform_(self.value_proj.weight)
- nn.init.constant_(self.value_proj.bias, 0.0)
+ nn.init.zeros_(self.value_proj.bias)
  nn.init.xavier_uniform_(self.output_proj.weight)
- nn.init.constant_(self.output_proj.bias, 0.0)
+ nn.init.zeros_(self.output_proj.bias)
 
  def forward(
  self,
@@ -174,7 +174,7 @@ class MultiScaleDeformableAttention(nn.Module):
  input_level_start_index: torch.Tensor,
  input_padding_mask: Optional[torch.Tensor] = None,
  ) -> torch.Tensor:
- N, num_queries, _ = query.size()
+ num_queries = query.size(1)
  N, sequence_length, _ = input_flatten.size()
  assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == sequence_length
 
@@ -366,10 +366,9 @@ class TransformerDecoderLayer(nn.Module):
  self_attn_mask: Optional[torch.Tensor] = None,
  ) -> torch.Tensor:
  # Self attention
- q = tgt + query_pos
- k = tgt + query_pos
+ q_k = tgt + query_pos
 
- tgt2 = self.self_attn(q, k, tgt, attn_mask=self_attn_mask)
+ tgt2 = self.self_attn(q_k, q_k, tgt, attn_mask=self_attn_mask)
  tgt = tgt + self.dropout(tgt2)
  tgt = self.norm1(tgt)
 
@@ -526,18 +525,18 @@ class RT_DETRDecoder(nn.Module):
 
  # Gather reference points
  reference_points_unact = enc_outputs_coord_unact.gather(
- dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_coord_unact.shape[-1])
+ dim=1, index=topk_ind.unsqueeze(-1).expand(-1, -1, enc_outputs_coord_unact.shape[-1])
  )
 
  enc_topk_bboxes = reference_points_unact.sigmoid()
 
  # Gather encoder logits for loss computation
  enc_topk_logits = enc_outputs_class.gather(
- dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_class.shape[-1])
+ dim=1, index=topk_ind.unsqueeze(-1).expand(-1, -1, enc_outputs_class.shape[-1])
  )
 
  # Extract region features
- target = output_memory.gather(dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, output_memory.shape[-1]))
+ target = output_memory.gather(dim=1, index=topk_ind.unsqueeze(-1).expand(-1, -1, output_memory.shape[-1]))
  target = target.detach()
 
  return (target, reference_points_unact.detach(), enc_topk_bboxes, enc_topk_logits)
@@ -583,7 +582,7 @@ class RT_DETRDecoder(nn.Module):
  reference_points = init_ref_points_unact.sigmoid()
  for decoder_layer, bbox_head, class_head in zip(self.layers, self.bbox_embed, self.class_embed):
  query_pos = self.query_pos_head(reference_points)
- reference_points_input = reference_points.unsqueeze(2).repeat(1, 1, len(spatial_shapes), 1)
+ reference_points_input = reference_points.unsqueeze(2).expand(-1, -1, len(spatial_shapes), -1)
  target = decoder_layer(
  target,
  query_pos,
@@ -675,7 +674,7 @@ class RT_DETR_v2(DetectionBaseNet):
  self.decoder = RT_DETRDecoder(
  hidden_dim=hidden_dim,
  num_classes=self.num_classes,
- num_queries=num_queries,
+ num_queries=self.num_queries,
  num_decoder_layers=num_decoder_layers,
  num_levels=self.num_levels,
  num_heads=num_heads,
@@ -744,20 +743,32 @@ class RT_DETR_v2(DetectionBaseNet):
  for param in self.denoising_class_embed.parameters():
  param.requires_grad_(True)
 
- def _get_src_permutation_idx(self, indices: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+ @staticmethod
+ def _get_src_permutation_idx(indices: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
  batch_idx = torch.concat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
  src_idx = torch.concat([src for (src, _) in indices])
  return (batch_idx, src_idx)
 
- def _class_loss(
+ def _compute_layer_losses(
  self,
  cls_logits: torch.Tensor,
  box_output: torch.Tensor,
  targets: list[dict[str, torch.Tensor]],
  indices: list[torch.Tensor],
  num_boxes: float,
- ) -> torch.Tensor:
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  idx = self._get_src_permutation_idx(indices)
+
+ src_boxes = box_output[idx]
+ target_boxes = torch.concat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)
+
+ src_boxes_xyxy = box_ops.box_convert(src_boxes, in_fmt="cxcywh", out_fmt="xyxy")
+ target_boxes_xyxy = box_ops.box_convert(target_boxes, in_fmt="cxcywh", out_fmt="xyxy")
+
+ # IoU for varifocal loss (class loss)
+ ious = torch.diag(box_ops.box_iou(src_boxes_xyxy, target_boxes_xyxy)).detach()
+
+ # Classification loss
  target_classes_o = torch.concat([t["labels"][J] for t, (_, J) in zip(targets, indices)], dim=0)
  target_classes = torch.full(cls_logits.shape[:2], self.num_classes, dtype=torch.int64, device=cls_logits.device)
  target_classes[idx] = target_classes_o
@@ -771,15 +782,6 @@ class RT_DETR_v2(DetectionBaseNet):
  target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1)
  target_classes_onehot = target_classes_onehot[:, :, :-1]
 
- src_boxes = box_output[idx]
- target_boxes = torch.concat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)
- ious = torch.diag(
- box_ops.box_iou(
- box_ops.box_convert(src_boxes, in_fmt="cxcywh", out_fmt="xyxy"),
- box_ops.box_convert(target_boxes, in_fmt="cxcywh", out_fmt="xyxy"),
- )
- ).detach()
-
  target_score_o = torch.zeros(cls_logits.shape[:2], dtype=cls_logits.dtype, device=cls_logits.device)
  target_score_o[idx] = ious.to(cls_logits.dtype)
  target_score = target_score_o.unsqueeze(-1) * target_classes_onehot
@@ -787,31 +789,13 @@ class RT_DETR_v2(DetectionBaseNet):
  loss = varifocal_loss(cls_logits, target_score, target_classes_onehot, alpha=0.75, gamma=2.0)
  loss_ce = (loss.mean(1).sum() / num_boxes) * cls_logits.shape[1]
 
- return loss_ce
+ # Box L1 loss
+ loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction="none").sum() / num_boxes
 
- def _box_loss(
- self,
- box_output: torch.Tensor,
- targets: list[dict[str, torch.Tensor]],
- indices: list[torch.Tensor],
- num_boxes: float,
- ) -> tuple[torch.Tensor, torch.Tensor]:
- idx = self._get_src_permutation_idx(indices)
- src_boxes = box_output[idx]
- target_boxes = torch.concat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)
-
- loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction="none")
- loss_bbox = loss_bbox.sum() / num_boxes
+ # GIoU loss
+ loss_giou = (1 - torch.diag(box_ops.generalized_box_iou(src_boxes_xyxy, target_boxes_xyxy))).sum() / num_boxes
 
- loss_giou = 1 - torch.diag(
- box_ops.generalized_box_iou(
- box_ops.box_convert(src_boxes, in_fmt="cxcywh", out_fmt="xyxy"),
- box_ops.box_convert(target_boxes, in_fmt="cxcywh", out_fmt="xyxy"),
- )
- )
- loss_giou = loss_giou.sum() / num_boxes
-
- return (loss_bbox, loss_giou)
+ return (loss_ce, loss_bbox, loss_giou)
 
  def _compute_denoising_loss(
  self,
@@ -846,11 +830,9 @@ class RT_DETR_v2(DetectionBaseNet):
  )
  )
 
- loss_ce = self._class_loss(
+ loss_ce, loss_bbox, loss_giou = self._compute_layer_losses(
  dn_out_logits[layer_idx], dn_out_bboxes[layer_idx], targets, indices, dn_num_boxes
  )
- loss_bbox, loss_giou = self._box_loss(dn_out_bboxes[layer_idx], targets, indices, dn_num_boxes)
-
  loss_ce_list.append(loss_ce)
  loss_bbox_list.append(loss_bbox)
  loss_giou_list.append(loss_giou)
@@ -861,9 +843,7 @@ class RT_DETR_v2(DetectionBaseNet):
 
  return (loss_ce_dn, loss_bbox_dn, loss_giou_dn)
 
- @torch.jit.unused # type: ignore[untyped-decorator]
- @torch.compiler.disable() # type: ignore[untyped-decorator]
- def _compute_loss_from_outputs( # pylint: disable=too-many-locals
+ def _compute_loss_from_outputs(
  self,
  targets: list[dict[str, torch.Tensor]],
  out_bboxes: torch.Tensor,
@@ -880,7 +860,7 @@ class RT_DETR_v2(DetectionBaseNet):
  if training_utils.is_dist_available_and_initialized() is True:
  torch.distributed.all_reduce(num_boxes)
 
- num_boxes = torch.clamp(num_boxes / training_utils.get_world_size(), min=1).item()
+ num_boxes = torch.clamp(num_boxes / training_utils.get_world_size(), min=1)
 
  loss_ce_list = []
  loss_bbox_list = []
@@ -889,19 +869,21 @@ class RT_DETR_v2(DetectionBaseNet):
  # Decoder losses (all layers)
  for layer_idx in range(out_logits.shape[0]):
  indices = self.matcher(out_logits[layer_idx], out_bboxes[layer_idx], targets)
- loss_ce = self._class_loss(out_logits[layer_idx], out_bboxes[layer_idx], targets, indices, num_boxes)
- loss_bbox, loss_giou = self._box_loss(out_bboxes[layer_idx], targets, indices, num_boxes)
+ loss_ce, loss_bbox, loss_giou = self._compute_layer_losses(
+ out_logits[layer_idx], out_bboxes[layer_idx], targets, indices, num_boxes
+ )
  loss_ce_list.append(loss_ce)
  loss_bbox_list.append(loss_bbox)
  loss_giou_list.append(loss_giou)
 
  # Encoder auxiliary loss
  enc_indices = self.matcher(enc_topk_logits, enc_topk_bboxes, targets)
- loss_ce_enc = self._class_loss(enc_topk_logits, enc_topk_bboxes, targets, enc_indices, num_boxes)
- loss_bbox_enc, loss_giou_enc = self._box_loss(enc_topk_bboxes, targets, enc_indices, num_boxes)
- loss_ce_list.append(loss_ce_enc)
- loss_bbox_list.append(loss_bbox_enc)
- loss_giou_list.append(loss_giou_enc)
+ loss_ce, loss_bbox, loss_giou = self._compute_layer_losses(
+ enc_topk_logits, enc_topk_bboxes, targets, enc_indices, num_boxes
+ )
+ loss_ce_list.append(loss_ce)
+ loss_bbox_list.append(loss_bbox)
+ loss_giou_list.append(loss_giou)
 
  loss_ce = torch.stack(loss_ce_list).sum() # VFL weight is 1
  loss_bbox = torch.stack(loss_bbox_list).sum() * 5
@@ -985,7 +967,7 @@ class RT_DETR_v2(DetectionBaseNet):
 
  # Convert to [x0, y0, x1, y1] format
  boxes = box_ops.box_convert(box_regression, in_fmt="cxcywh", out_fmt="xyxy")
- boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4))
+ boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).expand(-1, -1, 4))
 
  # Convert from relative [0, 1] to absolute [0, height] coordinates
  img_h, img_w = target_sizes.unbind(1)
@@ -1047,6 +1029,7 @@ class RT_DETR_v2(DetectionBaseNet):
  else:
  B, _, H, W = feat.size()
  m = torch.zeros(B, H, W, dtype=torch.bool, device=x.device)
+
  mask_list.append(m)
 
  encoder_features = self.encoder(feature_list, masks=mask_list)
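Both RT-DETR variants also drop the .item() call when normalizing num_boxes and remove the @torch.jit.unused / @torch.compiler.disable() decorators from _compute_loss_from_outputs. Keeping num_boxes as a 0-dim tensor avoids the device-to-host synchronization that .item() forces on CUDA tensors and removes a common source of graph breaks under torch.compile, while the resulting loss values are unchanged. A hedged sketch of the difference (not the package's code):

import torch

losses = torch.randn(8).abs()
num_boxes = torch.tensor(3.0)

# Before: .item() yields a Python float (and syncs the device when num_boxes lives on GPU)
normalized_old = losses.sum() / torch.clamp(num_boxes, min=1).item()

# After: everything stays a tensor, so the whole expression is traceable
normalized_new = losses.sum() / torch.clamp(num_boxes, min=1)

assert torch.allclose(normalized_old, normalized_new)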
birder/net/detection/ssdlite.py CHANGED
@@ -50,7 +50,7 @@ class SSDLiteClassificationHead(SSDScoringHead):
  if isinstance(layer, nn.Conv2d):
  nn.init.xavier_uniform_(layer.weight)
  if layer.bias is not None:
- nn.init.constant_(layer.bias, 0.0)
+ nn.init.zeros_(layer.bias)
 
  super().__init__(cls_logits, num_classes)
 
@@ -79,7 +79,7 @@ class SSDLiteRegressionHead(SSDScoringHead):
  if isinstance(layer, nn.Conv2d):
  nn.init.xavier_uniform_(layer.weight)
  if layer.bias is not None:
- nn.init.constant_(layer.bias, 0.0)
+ nn.init.zeros_(layer.bias)
 
  super().__init__(bbox_reg, 4)
 
birder/net/edgevit.py CHANGED
@@ -332,11 +332,11 @@ class EdgeViT(DetectorBackbone):
  if isinstance(m, nn.Linear):
  nn.init.trunc_normal_(m.weight, std=0.02)
  if m.bias is not None:
- nn.init.constant_(m.bias, 0)
+ nn.init.zeros_(m.bias)
 
  elif isinstance(m, nn.LayerNorm):
- nn.init.constant_(m.bias, 0)
- nn.init.constant_(m.weight, 1.0)
+ nn.init.zeros_(m.bias)
+ nn.init.ones_(m.weight)
 
  def detection_features(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
  out = {}
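The recurring init changes in this hunk and in the detection heads above are purely cosmetic: nn.init.zeros_(t) and nn.init.ones_(t) do exactly what nn.init.constant_(t, 0) and nn.init.constant_(t, 1) did; they just state the intent directly. A quick check on a throwaway tensor:

import torch
from torch import nn

a = torch.randn(16)
b = torch.randn(16)

nn.init.constant_(a, 0)  # old spelling
nn.init.zeros_(b)        # new spelling

assert torch.equal(a, b)  # both are all-zeros of the same shape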
birder/net/efficientvit_msft.py CHANGED
@@ -55,7 +55,7 @@ class Conv2dNorm(nn.Sequential):
  )
  self.add_module("bn", nn.BatchNorm2d(out_channels))
  nn.init.constant_(self.bn.weight, bn_weight_init)
- nn.init.constant_(self.bn.bias, 0)
+ nn.init.zeros_(self.bn.bias)
 
 
  class PatchMerging(nn.Module):
birder/net/flexivit.py CHANGED
@@ -314,7 +314,7 @@ class FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Mask
  xs = self.encoder.forward_features(x, out_indices=self.out_indices)
 
  out: dict[str, torch.Tensor] = {}
- for stage_name, stage_x in zip(self.return_stages, xs):
+ for stage_name, stage_x in zip(self.return_stages, xs, strict=True):
  stage_x = stage_x[:, self.num_special_tokens :]
  stage_x = stage_x.permute(0, 2, 1)
  B, C, _ = stage_x.size()
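The strict=True added to the zip calls across the ViT-family backbones (Python 3.10+) turns a silent truncation into an error if return_stages and the feature list returned by the encoder ever disagree in length. For example:

stages = ["stage1", "stage2", "stage3"]
features = ["f1", "f2"]  # one feature missing

print(list(zip(stages, features)))  # silently drops "stage3"

try:
    list(zip(stages, features, strict=True))
except ValueError as exc:
    print(exc)  # zip() argument 2 is shorter than argument 1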
birder/net/hieradet.py CHANGED
@@ -613,11 +613,11 @@ registry.register_weights( # SAM v2: https://arxiv.org/abs/2408.00714
  "HieraDet small image encoder pre-trained by Meta AI using SAM v2. "
  "This model has not been fine-tuned for a specific classification task"
  ),
- "resolution": (224, 224),
+ "resolution": (1024, 1024),
  "formats": {
  "pt": {
  "file_size": 129.6,
- "sha256": "79b6ffdfd4ea9f3b1489ce5a229fe9756b215fc3b52640d01d64136560c1d341",
+ "sha256": "2ede3a78389ca74ed37d82dbc1c3410549f1fdafb5a7a94ac02968aa6d3dec80",
  }
  },
  "net": {"network": "hieradet_small", "tag": "sam2_1"},
birder/net/mnasnet.py CHANGED
@@ -230,8 +230,8 @@ class MNASNet(DetectorBackbone):
  nn.init.zeros_(m.bias)
 
  elif isinstance(m, nn.BatchNorm2d):
- nn.init.constant_(m.weight, 1)
- nn.init.constant_(m.bias, 0)
+ nn.init.ones_(m.weight)
+ nn.init.zeros_(m.bias)
 
  elif isinstance(m, nn.Linear):
  nn.init.kaiming_uniform_(m.weight, mode="fan_out", nonlinearity="sigmoid")
birder/net/resnext.py CHANGED
@@ -205,8 +205,8 @@ class ResNeXt(DetectorBackbone):
  nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
 
  elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
- nn.init.constant_(m.weight, 1)
- nn.init.constant_(m.bias, 0)
+ nn.init.ones_(m.weight)
+ nn.init.zeros_(m.bias)
 
  def detection_features(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
  x = self.stem(x)
birder/net/rope_deit3.py CHANGED
@@ -249,7 +249,7 @@ class RoPE_DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Ma
  xs = self.encoder.forward_features(x, rope, out_indices=self.out_indices)
 
  out: dict[str, torch.Tensor] = {}
- for stage_name, stage_x in zip(self.return_stages, xs):
+ for stage_name, stage_x in zip(self.return_stages, xs, strict=True):
  stage_x = stage_x[:, self.num_special_tokens :]
  stage_x = stage_x.permute(0, 2, 1)
  B, C, _ = stage_x.size()
birder/net/rope_flexivit.py CHANGED
@@ -342,7 +342,7 @@ class RoPE_FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin,
  xs = self.encoder.forward_features(x, rope, out_indices=self.out_indices)
 
  out: dict[str, torch.Tensor] = {}
- for stage_name, stage_x in zip(self.return_stages, xs):
+ for stage_name, stage_x in zip(self.return_stages, xs, strict=True):
  stage_x = stage_x[:, self.num_special_tokens :]
  stage_x = stage_x.permute(0, 2, 1)
  B, C, _ = stage_x.size()
birder/net/rope_vit.py CHANGED
@@ -698,7 +698,7 @@ class RoPE_ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Mask
  xs = self.encoder.forward_features(x, rope, out_indices=self.out_indices)
 
  out: dict[str, torch.Tensor] = {}
- for stage_name, stage_x in zip(self.return_stages, xs):
+ for stage_name, stage_x in zip(self.return_stages, xs, strict=True):
  stage_x = stage_x[:, self.num_special_tokens :]
  stage_x = stage_x.permute(0, 2, 1)
  B, C, _ = stage_x.size()
birder/net/simple_vit.py CHANGED
@@ -215,7 +215,7 @@ class Simple_ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin):
  xs = self.encoder.forward_features(x, out_indices=self.out_indices)
 
  out: dict[str, torch.Tensor] = {}
- for stage_name, stage_x in zip(self.return_stages, xs):
+ for stage_name, stage_x in zip(self.return_stages, xs, strict=True):
  stage_x = stage_x[:, self.num_special_tokens :]
  stage_x = stage_x.permute(0, 2, 1)
  B, C, _ = stage_x.size()
birder/net/vit.py CHANGED
@@ -572,7 +572,7 @@ class ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedTok
  xs = self.encoder.forward_features(x, out_indices=self.out_indices)
 
  out: dict[str, torch.Tensor] = {}
- for stage_name, stage_x in zip(self.return_stages, xs):
+ for stage_name, stage_x in zip(self.return_stages, xs, strict=True):
  stage_x = stage_x[:, self.num_special_tokens :]
  stage_x = stage_x.permute(0, 2, 1)
  B, C, _ = stage_x.size()
@@ -802,6 +802,24 @@ class ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedTok
  # Register model configs (side effects)
  register_vit_configs(ViT)
 
+ registry.register_weights( # BioCLIP v1: https://arxiv.org/abs/2311.18803
+ "vit_b16_pn_bioclip-v1",
+ {
+ "url": "https://huggingface.co/birder-project/vit_b16_pn_bioclip-v1/resolve/main",
+ "description": (
+ "ViT b16 image encoder pre-trained by Imageomics using CLIP on the TreeOfLife-10M dataset. "
+ "This model has not been fine-tuned for a specific classification task"
+ ),
+ "resolution": (224, 224),
+ "formats": {
+ "pt": {
+ "file_size": 328.9,
+ "sha256": "9b2e5598f233657932eeb77e027cd4c4d683bf75515768fe6971cab6ec10bf15",
+ },
+ },
+ "net": {"network": "vit_b16_pn", "tag": "bioclip-v1"},
+ },
+ )
  registry.register_weights(
  "vit_l16_mim_200",
  {
@@ -849,8 +867,8 @@ registry.register_weights( # BioCLIP v2: https://arxiv.org/abs/2505.23883
  "resolution": (224, 224),
  "formats": {
  "pt": {
- "file_size": 1156.6,
- "sha256": "6cd7bd6993762590891fe2b41db1649cde5a0c4de5a7f341672f8856ed529d07",
+ "file_size": 1159.7,
+ "sha256": "301a325579dafdfa2ea13b0cbaf8129211ecd1429c29afa20d1c2eaaa91d8b0d",
  },
  },
  "net": {"network": "vit_l14_pn", "tag": "bioclip-v2"},
birder/net/vit_parallel.py CHANGED
@@ -370,7 +370,7 @@ class ViT_Parallel(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin,
  xs = self.encoder.forward_features(x, out_indices=self.out_indices)
 
  out: dict[str, torch.Tensor] = {}
- for stage_name, stage_x in zip(self.return_stages, xs):
+ for stage_name, stage_x in zip(self.return_stages, xs, strict=True):
  stage_x = stage_x[:, self.num_special_tokens :]
  stage_x = stage_x.permute(0, 2, 1)
  B, C, _ = stage_x.size()