birder 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (187)
  1. birder/adversarial/base.py +1 -1
  2. birder/adversarial/simba.py +4 -4
  3. birder/common/cli.py +1 -1
  4. birder/common/fs_ops.py +11 -11
  5. birder/common/lib.py +2 -2
  6. birder/common/masking.py +3 -3
  7. birder/common/training_cli.py +24 -2
  8. birder/common/training_utils.py +28 -4
  9. birder/data/collators/detection.py +9 -1
  10. birder/data/transforms/detection.py +27 -8
  11. birder/data/transforms/mosaic.py +1 -1
  12. birder/datahub/classification.py +3 -3
  13. birder/inference/classification.py +3 -3
  14. birder/inference/data_parallel.py +1 -1
  15. birder/inference/detection.py +5 -5
  16. birder/inference/wbf.py +1 -1
  17. birder/introspection/attention_rollout.py +5 -5
  18. birder/introspection/feature_pca.py +4 -4
  19. birder/introspection/gradcam.py +1 -1
  20. birder/introspection/guided_backprop.py +2 -2
  21. birder/introspection/transformer_attribution.py +3 -3
  22. birder/layers/attention_pool.py +2 -2
  23. birder/model_registry/model_registry.py +2 -1
  24. birder/net/__init__.py +2 -0
  25. birder/net/_rope_vit_configs.py +5 -0
  26. birder/net/_vit_configs.py +0 -13
  27. birder/net/alexnet.py +5 -5
  28. birder/net/base.py +28 -3
  29. birder/net/biformer.py +17 -17
  30. birder/net/cait.py +2 -2
  31. birder/net/cas_vit.py +1 -1
  32. birder/net/coat.py +15 -15
  33. birder/net/convnext_v1.py +2 -10
  34. birder/net/convnext_v1_iso.py +198 -0
  35. birder/net/convnext_v2.py +2 -10
  36. birder/net/crossformer.py +9 -9
  37. birder/net/crossvit.py +1 -1
  38. birder/net/cspnet.py +1 -1
  39. birder/net/cswin_transformer.py +10 -10
  40. birder/net/davit.py +10 -10
  41. birder/net/deit.py +56 -3
  42. birder/net/deit3.py +27 -15
  43. birder/net/detection/__init__.py +4 -0
  44. birder/net/detection/{yolo_anchors.py → _yolo_anchors.py} +5 -5
  45. birder/net/detection/base.py +6 -5
  46. birder/net/detection/deformable_detr.py +26 -28
  47. birder/net/detection/detr.py +9 -9
  48. birder/net/detection/efficientdet.py +9 -28
  49. birder/net/detection/faster_rcnn.py +22 -22
  50. birder/net/detection/fcos.py +8 -8
  51. birder/net/detection/plain_detr.py +852 -0
  52. birder/net/detection/retinanet.py +4 -4
  53. birder/net/detection/rt_detr_v1.py +81 -25
  54. birder/net/detection/rt_detr_v2.py +1147 -0
  55. birder/net/detection/ssd.py +5 -5
  56. birder/net/detection/yolo_v2.py +12 -12
  57. birder/net/detection/yolo_v3.py +19 -19
  58. birder/net/detection/yolo_v4.py +16 -16
  59. birder/net/detection/yolo_v4_tiny.py +3 -3
  60. birder/net/edgenext.py +3 -3
  61. birder/net/edgevit.py +10 -14
  62. birder/net/efficientformer_v1.py +1 -1
  63. birder/net/efficientvim.py +9 -9
  64. birder/net/efficientvit_mit.py +2 -2
  65. birder/net/efficientvit_msft.py +3 -3
  66. birder/net/fasternet.py +1 -1
  67. birder/net/fastvit.py +5 -12
  68. birder/net/flexivit.py +28 -15
  69. birder/net/focalnet.py +5 -9
  70. birder/net/gc_vit.py +11 -11
  71. birder/net/ghostnet_v1.py +1 -1
  72. birder/net/ghostnet_v2.py +1 -1
  73. birder/net/groupmixformer.py +12 -12
  74. birder/net/hgnet_v1.py +1 -1
  75. birder/net/hgnet_v2.py +4 -4
  76. birder/net/hiera.py +6 -6
  77. birder/net/hieradet.py +9 -9
  78. birder/net/hornet.py +3 -3
  79. birder/net/iformer.py +4 -4
  80. birder/net/inception_next.py +4 -14
  81. birder/net/levit.py +3 -3
  82. birder/net/lit_v1.py +13 -15
  83. birder/net/lit_v1_tiny.py +9 -9
  84. birder/net/lit_v2.py +14 -15
  85. birder/net/maxvit.py +10 -22
  86. birder/net/metaformer.py +2 -2
  87. birder/net/mim/crossmae.py +5 -5
  88. birder/net/mim/fcmae.py +3 -5
  89. birder/net/mim/mae_hiera.py +7 -7
  90. birder/net/mim/mae_vit.py +3 -5
  91. birder/net/mim/simmim.py +2 -3
  92. birder/net/mobilenet_v4_hybrid.py +4 -4
  93. birder/net/mobileone.py +5 -12
  94. birder/net/mobilevit_v1.py +2 -2
  95. birder/net/mobilevit_v2.py +5 -9
  96. birder/net/mvit_v2.py +24 -24
  97. birder/net/nextvit.py +2 -2
  98. birder/net/pit.py +11 -26
  99. birder/net/pvt_v1.py +4 -4
  100. birder/net/pvt_v2.py +5 -11
  101. birder/net/regionvit.py +15 -15
  102. birder/net/regnet.py +1 -1
  103. birder/net/repghost.py +4 -5
  104. birder/net/repvgg.py +3 -5
  105. birder/net/repvit.py +2 -2
  106. birder/net/resnest.py +1 -1
  107. birder/net/rope_deit3.py +29 -15
  108. birder/net/rope_flexivit.py +28 -15
  109. birder/net/rope_vit.py +41 -23
  110. birder/net/sequencer2d.py +3 -4
  111. birder/net/shufflenet_v1.py +1 -1
  112. birder/net/shufflenet_v2.py +1 -1
  113. birder/net/simple_vit.py +47 -5
  114. birder/net/smt.py +7 -7
  115. birder/net/ssl/barlow_twins.py +1 -1
  116. birder/net/ssl/byol.py +2 -2
  117. birder/net/ssl/capi.py +3 -3
  118. birder/net/ssl/data2vec2.py +1 -1
  119. birder/net/ssl/dino_v2.py +11 -1
  120. birder/net/ssl/franca.py +26 -2
  121. birder/net/ssl/i_jepa.py +4 -4
  122. birder/net/ssl/mmcr.py +1 -1
  123. birder/net/swiftformer.py +1 -1
  124. birder/net/swin_transformer_v1.py +4 -5
  125. birder/net/swin_transformer_v2.py +4 -7
  126. birder/net/tiny_vit.py +3 -3
  127. birder/net/transnext.py +19 -19
  128. birder/net/uniformer.py +4 -4
  129. birder/net/vgg.py +1 -10
  130. birder/net/vit.py +38 -25
  131. birder/net/vit_parallel.py +35 -20
  132. birder/net/vit_sam.py +10 -10
  133. birder/net/vovnet_v2.py +1 -1
  134. birder/net/xcit.py +9 -7
  135. birder/ops/msda.py +4 -4
  136. birder/ops/swattention.py +10 -10
  137. birder/results/classification.py +3 -3
  138. birder/results/gui.py +8 -8
  139. birder/scripts/benchmark.py +37 -12
  140. birder/scripts/evaluate.py +1 -1
  141. birder/scripts/predict.py +3 -3
  142. birder/scripts/predict_detection.py +2 -2
  143. birder/scripts/train.py +63 -15
  144. birder/scripts/train_barlow_twins.py +10 -7
  145. birder/scripts/train_byol.py +10 -7
  146. birder/scripts/train_capi.py +15 -10
  147. birder/scripts/train_data2vec.py +10 -7
  148. birder/scripts/train_data2vec2.py +10 -7
  149. birder/scripts/train_detection.py +29 -14
  150. birder/scripts/train_dino_v1.py +13 -9
  151. birder/scripts/train_dino_v2.py +27 -14
  152. birder/scripts/train_dino_v2_dist.py +28 -15
  153. birder/scripts/train_franca.py +16 -9
  154. birder/scripts/train_i_jepa.py +12 -9
  155. birder/scripts/train_ibot.py +15 -11
  156. birder/scripts/train_kd.py +64 -17
  157. birder/scripts/train_mim.py +11 -8
  158. birder/scripts/train_mmcr.py +11 -8
  159. birder/scripts/train_rotnet.py +11 -7
  160. birder/scripts/train_simclr.py +10 -7
  161. birder/scripts/train_vicreg.py +10 -7
  162. birder/tools/adversarial.py +4 -4
  163. birder/tools/auto_anchors.py +5 -5
  164. birder/tools/avg_model.py +1 -1
  165. birder/tools/convert_model.py +30 -22
  166. birder/tools/det_results.py +1 -1
  167. birder/tools/download_model.py +1 -1
  168. birder/tools/ensemble_model.py +1 -1
  169. birder/tools/introspection.py +11 -2
  170. birder/tools/labelme_to_coco.py +2 -2
  171. birder/tools/model_info.py +12 -14
  172. birder/tools/pack.py +8 -8
  173. birder/tools/quantize_model.py +53 -4
  174. birder/tools/results.py +2 -2
  175. birder/tools/show_det_iterator.py +19 -6
  176. birder/tools/show_iterator.py +2 -2
  177. birder/tools/similarity.py +5 -5
  178. birder/tools/stats.py +4 -6
  179. birder/tools/voc_to_coco.py +1 -1
  180. birder/version.py +1 -1
  181. {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/METADATA +3 -3
  182. birder-0.4.1.dist-info/RECORD +300 -0
  183. {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/WHEEL +1 -1
  184. birder-0.4.0.dist-info/RECORD +0 -297
  185. {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/entry_points.txt +0 -0
  186. {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/licenses/LICENSE +0 -0
  187. {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/top_level.txt +0 -0
birder/net/detection/deformable_detr.py
@@ -58,7 +58,7 @@ class HungarianMatcher(nn.Module):
         self, class_logits: torch.Tensor, box_regression: torch.Tensor, targets: list[dict[str, torch.Tensor]]
     ) -> list[torch.Tensor]:
         with torch.no_grad():
-            (B, num_queries) = class_logits.shape[:2]
+            B, num_queries = class_logits.shape[:2]

             # We flatten to compute the cost matrices in a batch
             out_prob = class_logits.flatten(0, 1).sigmoid()  # [batch_size * num_queries, num_classes]
@@ -111,8 +111,7 @@ def inverse_sigmoid(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
 class MultiScaleDeformableAttention(nn.Module):
     def __init__(self, d_model: int, n_levels: int, n_heads: int, n_points: int) -> None:
         super().__init__()
-        if d_model % n_heads != 0:
-            raise ValueError(f"d_model must be divisible by n_heads, but got {d_model} and {n_heads}")
+        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"

         # Ensure dim_per_head is power of 2
         dim_per_head = d_model // n_heads
@@ -133,9 +132,9 @@ class MultiScaleDeformableAttention(nn.Module):
         self.value_proj = nn.Linear(d_model, d_model)
         self.output_proj = nn.Linear(d_model, d_model)

-        self._reset_parameters()
+        self.reset_parameters()

-    def _reset_parameters(self) -> None:
+    def reset_parameters(self) -> None:
         nn.init.constant_(self.sampling_offsets.weight, 0.0)
         thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
         grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
@@ -166,8 +165,8 @@ class MultiScaleDeformableAttention(nn.Module):
         input_level_start_index: torch.Tensor,
         input_padding_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        (N, num_queries, _) = query.size()
-        (N, sequence_length, _) = input_flatten.size()
+        N, num_queries, _ = query.size()
+        N, sequence_length, _ = input_flatten.size()
         assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == sequence_length

         value = self.value_proj(input_flatten)
@@ -283,7 +282,7 @@ class DeformableTransformerDecoderLayer(nn.Module):
         q = tgt + query_pos
         k = tgt + query_pos

-        (tgt2, _) = self.self_attn(
+        tgt2, _ = self.self_attn(
             q.transpose(0, 1), k.transpose(0, 1), tgt.transpose(0, 1), need_weights=False, attn_mask=self_attn_mask
         )
         tgt2 = tgt2.transpose(0, 1)
@@ -318,7 +317,7 @@ class DeformableTransformerEncoder(nn.Module):
         for lvl, spatial_shape in enumerate(spatial_shapes):
             H = spatial_shape[0]
             W = spatial_shape[1]
-            (ref_y, ref_x) = torch.meshgrid(
+            ref_y, ref_x = torch.meshgrid(
                 torch.linspace(0.5, H - 0.5, H, dtype=torch.float32, device=device),
                 torch.linspace(0.5, W - 0.5, W, dtype=torch.float32, device=device),
                 indexing="ij",
@@ -454,7 +453,7 @@ class DeformableTransformer(nn.Module):

         for m in self.modules():
             if isinstance(m, MultiScaleDeformableAttention):
-                m._reset_parameters()
+                m.reset_parameters()

         nn.init.xavier_uniform_(self.reference_points.weight, gain=1.0)
         nn.init.zeros_(self.reference_points.bias)
@@ -462,7 +461,7 @@ class DeformableTransformer(nn.Module):
         nn.init.normal_(self.level_embed)

     def get_valid_ratio(self, mask: torch.Tensor) -> torch.Tensor:
-        (_, H, W) = mask.size()
+        _, H, W = mask.size()
         valid_h = torch.sum(~mask[:, :, 0], 1)
         valid_w = torch.sum(~mask[:, 0, :], 1)
         valid_ratio_h = valid_h.float() / H
@@ -485,7 +484,7 @@ class DeformableTransformer(nn.Module):
         mask_list = []
         spatial_shape_list: list[list[int]] = []  # list[tuple[int, int]] not supported on TorchScript
         for lvl, (src, pos_embed, mask) in enumerate(zip(srcs, pos_embeds, masks)):
-            (_, _, H, W) = src.size()
+            _, _, H, W = src.size()
             spatial_shape_list.append([H, W])
             src = src.flatten(2).transpose(1, 2)
             pos_embed = pos_embed.flatten(2).transpose(1, 2)
@@ -508,14 +507,14 @@ class DeformableTransformer(nn.Module):
         )

         # Prepare input for decoder
-        (B, _, C) = memory.size()
+        B, _, C = memory.size()
         query_embed, tgt = torch.split(query_embed, C, dim=1)
         query_embed = query_embed.unsqueeze(0).expand(B, -1, -1)
         tgt = tgt.unsqueeze(0).expand(B, -1, -1)
         reference_points = self.reference_points(query_embed).sigmoid()

         # Decoder
-        (hs, inter_references) = self.decoder(
+        hs, inter_references = self.decoder(
             tgt, reference_points, memory, spatial_shapes, level_start_index, query_embed, valid_ratios, mask_flatten
         )

@@ -632,7 +631,7 @@ class Deformable_DETR(DetectionBaseNet):
         prior_prob = 0.01
         bias_value = -math.log((1 - prior_prob) / prior_prob)
         for class_embed in self.class_embed:
-            class_embed.bias.data = torch.ones(self.num_classes) * bias_value
+            nn.init.constant_(class_embed.bias, bias_value)

     def freeze(self, freeze_classifier: bool = True) -> None:
         for param in self.parameters():
@@ -656,20 +655,19 @@ class Deformable_DETR(DetectionBaseNet):
     ) -> torch.Tensor:
         idx = self._get_src_permutation_idx(indices)
         target_classes_o = torch.concat([t["labels"][J] for t, (_, J) in zip(targets, indices)], dim=0)
-        target_classes = torch.full(cls_logits.shape[:2], self.num_classes, dtype=torch.int64, device=cls_logits.device)
-        target_classes[idx] = target_classes_o

         target_classes_onehot = torch.zeros(
-            [cls_logits.shape[0], cls_logits.shape[1], cls_logits.shape[2] + 1],
+            cls_logits.size(0),
+            cls_logits.size(1),
+            cls_logits.size(2) + 1,
             dtype=cls_logits.dtype,
-            layout=cls_logits.layout,
             device=cls_logits.device,
         )
-        target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1)
-
+        target_classes_onehot[idx[0], idx[1], target_classes_o] = 1
         target_classes_onehot = target_classes_onehot[:, :, :-1]
+
         loss = sigmoid_focal_loss(cls_logits, target_classes_onehot, alpha=0.25, gamma=2.0)
-        loss_ce = (loss.mean(1).sum() / num_boxes) * cls_logits.shape[1]
+        loss_ce = (loss.mean(1).sum() / num_boxes) * cls_logits.size(1)

         return loss_ce

@@ -719,7 +717,7 @@ class Deformable_DETR(DetectionBaseNet):
         for idx in range(cls_logits.size(0)):
             indices = self.matcher(cls_logits[idx], box_output[idx], targets)
             loss_ce_i = self._class_loss(cls_logits[idx], targets, indices, num_boxes)
-            (loss_bbox_i, loss_giou_i) = self._box_loss(box_output[idx], targets, indices, num_boxes)
+            loss_bbox_i, loss_giou_i = self._box_loss(box_output[idx], targets, indices, num_boxes)
             loss_ce_list.append(loss_ce_i)
             loss_bbox_list.append(loss_bbox_i)
             loss_giou_list.append(loss_giou_i)
@@ -739,7 +737,7 @@ class Deformable_DETR(DetectionBaseNet):
         self, class_logits: torch.Tensor, box_regression: torch.Tensor, image_shapes: list[tuple[int, int]]
     ) -> list[dict[str, torch.Tensor]]:
         prob = class_logits.sigmoid()
-        (topk_values, topk_indexes) = torch.topk(prob.view(class_logits.shape[0], -1), k=100, dim=1)
+        topk_values, topk_indexes = torch.topk(prob.view(class_logits.shape[0], -1), k=100, dim=1)
         scores = topk_values
         topk_boxes = topk_indexes // class_logits.shape[2]
         labels = topk_indexes % class_logits.shape[2]
@@ -752,7 +750,7 @@ class Deformable_DETR(DetectionBaseNet):
         boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4))

         # Convert from relative [0, 1] to absolute [0, height] coordinates
-        (img_h, img_w) = target_sizes.unbind(1)
+        img_h, img_w = target_sizes.unbind(1)
         scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
         boxes = boxes * scale_fct[:, None, :]

@@ -760,7 +758,7 @@ class Deformable_DETR(DetectionBaseNet):
         for s, l, b in zip(scores, labels, boxes):
             # Non-maximum suppression
             if self.soft_nms is not None:
-                (soft_scores, keep) = self.soft_nms(b, s, l, score_threshold=0.001)
+                soft_scores, keep = self.soft_nms(b, s, l, score_threshold=0.001)
                 s[keep] = soft_scores

                 b = b[keep]
@@ -797,14 +795,14 @@ class Deformable_DETR(DetectionBaseNet):
                 mask_size = feature_list[idx].shape[-2:]
                 m = F.interpolate(masks[None].float(), size=mask_size, mode="nearest").to(torch.bool)[0]
             else:
-                (B, _, H, W) = feature_list[idx].size()
+                B, _, H, W = feature_list[idx].size()
                 m = torch.zeros(B, H, W, dtype=torch.bool, device=x.device)

             feature_list[idx] = proj(feature_list[idx])
             mask_list.append(m)
             pos_list.append(self.pos_enc(feature_list[idx], m))

-        (hs, init_reference, inter_references) = self.transformer(
+        hs, init_reference, inter_references = self.transformer(
             feature_list, pos_list, self.query_embed.weight, mask_list
         )
         outputs_classes = []
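The `_class_loss` hunk above swaps a `scatter_`-based one-hot construction for direct advanced indexing. A minimal standalone sketch (shapes, index values, and variable names are illustrative, not taken from the package) showing the two constructions yield the same focal-loss target tensor:

```python
import torch

B, Q, C = 2, 5, 3  # batch size, number of queries, number of classes (illustrative)
idx = (torch.tensor([0, 1]), torch.tensor([2, 4]))  # matched (batch, query) pairs
target_classes_o = torch.tensor([1, 0])  # matched class for each pair

# 0.4.0 style: fill with a "no-object" class, scatter, then drop the extra column
target_classes = torch.full((B, Q), C, dtype=torch.int64)
target_classes[idx] = target_classes_o
onehot_old = torch.zeros(B, Q, C + 1)
onehot_old.scatter_(2, target_classes.unsqueeze(-1), 1)
onehot_old = onehot_old[:, :, :-1]

# 0.4.1 style: write ones at the matched positions directly
onehot_new = torch.zeros(B, Q, C + 1)
onehot_new[idx[0], idx[1], target_classes_o] = 1
onehot_new = onehot_new[:, :, :-1]

assert torch.equal(onehot_old, onehot_new)
```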
birder/net/detection/detr.py
@@ -51,7 +51,7 @@ class HungarianMatcher(nn.Module):
         self, class_logits: torch.Tensor, box_regression: torch.Tensor, targets: list[dict[str, torch.Tensor]]
     ) -> list[torch.Tensor]:
         with torch.no_grad():
-            (B, num_queries) = class_logits.shape[:2]
+            B, num_queries = class_logits.shape[:2]

             # We flatten to compute the cost matrices in a batch
             out_prob = class_logits.flatten(0, 1).softmax(-1)  # [batch_size * num_queries, num_classes]
@@ -111,7 +111,7 @@ class TransformerEncoderLayer(nn.Module):
         q = src + pos
         k = src + pos

-        (src2, _) = self.self_attn(q, k, value=src, key_padding_mask=src_key_padding_mask, need_weights=False)
+        src2, _ = self.self_attn(q, k, value=src, key_padding_mask=src_key_padding_mask, need_weights=False)
         src = src + self.dropout1(src2)
         src = self.norm1(src)
         src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
@@ -151,10 +151,10 @@ class TransformerDecoderLayer(nn.Module):
         q = tgt + query_pos
         k = tgt + query_pos

-        (tgt2, _) = self.self_attn(q, k, value=tgt, need_weights=False)
+        tgt2, _ = self.self_attn(q, k, value=tgt, need_weights=False)
         tgt = tgt + self.dropout1(tgt2)
         tgt = self.norm1(tgt)
-        (tgt2, _) = self.multihead_attn(
+        tgt2, _ = self.multihead_attn(
             query=tgt + query_pos,
             key=memory + pos,
             value=memory,
@@ -270,7 +270,7 @@ class PositionEmbeddingSine(nn.Module):

     def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
         if mask is None:
-            (B, _, H, W) = x.size()
+            B, _, H, W = x.size()
             mask = torch.zeros(B, H, W, dtype=torch.bool, device=x.device)

         not_mask = ~mask
@@ -430,7 +430,7 @@ class DETR(DetectionBaseNet):
         for idx in range(cls_logits.size(0)):
             indices = self.matcher(cls_logits[idx], box_output[idx], targets)
             loss_ce_i = self._class_loss(cls_logits[idx], targets, indices)
-            (loss_bbox_i, loss_giou_i) = self._box_loss(box_output[idx], targets, indices, num_boxes)
+            loss_bbox_i, loss_giou_i = self._box_loss(box_output[idx], targets, indices, num_boxes)
             loss_ce_list.append(loss_ce_i)
             loss_bbox_list.append(loss_bbox_i)
             loss_giou_list.append(loss_giou_i)
@@ -450,7 +450,7 @@ class DETR(DetectionBaseNet):
         self, class_logits: torch.Tensor, box_regression: torch.Tensor, image_shapes: list[tuple[int, int]]
     ) -> list[dict[str, torch.Tensor]]:
         prob = F.softmax(class_logits, -1)
-        (scores, labels) = prob[..., 1:].max(-1)
+        scores, labels = prob[..., 1:].max(-1)
         labels = labels + 1

         # TorchScript doesn't support creating tensor from tuples, convert everything to lists
@@ -460,7 +460,7 @@ class DETR(DetectionBaseNet):
         boxes = box_ops.box_convert(box_regression, in_fmt="cxcywh", out_fmt="xyxy")

         # Convert from relative [0, 1] to absolute [0, height] coordinates
-        (img_h, img_w) = target_sizes.unbind(1)
+        img_h, img_w = target_sizes.unbind(1)
         scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
         boxes = boxes * scale_fct[:, None, :]

@@ -468,7 +468,7 @@ class DETR(DetectionBaseNet):
         for s, l, b in zip(scores, labels, boxes):
             # Non-maximum suppression
             if self.soft_nms is not None:
-                (soft_scores, keep) = self.soft_nms(b, s, l, score_threshold=0.001)
+                soft_scores, keep = self.soft_nms(b, s, l, score_threshold=0.001)
                 s[keep] = soft_scores

                 b = b[keep]
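Most of the detr.py edits above, and many across the other files, only drop the parentheses around tuple-unpacking targets. A quick illustrative check that the two spellings are purely cosmetic:

```python
import dis

def with_parens(pair):
    (a, b) = pair
    return a + b

def without_parens(pair):
    a, b = pair
    return a + b

# Both forms compile to the same bytecode (UNPACK_SEQUENCE plus two STOREs)
assert [i.opname for i in dis.get_instructions(with_parens)] == [
    i.opname for i in dis.get_instructions(without_parens)
]
```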
birder/net/detection/efficientdet.py
@@ -136,8 +136,8 @@ class ResampleFeatureMap(nn.Module):
         if self.conv is not None:
             x = self.conv(x)

-        (in_h, in_w) = x.shape[-2:]
-        (target_h, target_w) = target_size
+        in_h, in_w = x.shape[-2:]
+        target_h, target_w = target_size
         if in_h == target_h and in_w == target_w:
             return x

@@ -358,13 +358,7 @@ class HeadNet(nn.Module):
         for _ in range(repeats):
             layers.append(
                 nn.Conv2d(
-                    fpn_channels,
-                    fpn_channels,
-                    kernel_size=(3, 3),
-                    stride=(1, 1),
-                    padding=(1, 1),
-                    groups=fpn_channels,
-                    bias=True,
+                    fpn_channels, fpn_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=fpn_channels
                 )
             )
             layers.append(
@@ -383,22 +377,9 @@ class HeadNet(nn.Module):
         self.conv_repeat = nn.Sequential(*layers)
         self.predict = nn.Sequential(
             nn.Conv2d(
-                fpn_channels,
-                fpn_channels,
-                kernel_size=(3, 3),
-                stride=(1, 1),
-                padding=(1, 1),
-                groups=fpn_channels,
-                bias=True,
-            ),
-            nn.Conv2d(
-                fpn_channels,
-                num_outputs * num_anchors,
-                kernel_size=(1, 1),
-                stride=(1, 1),
-                padding=(0, 0),
-                bias=True,
+                fpn_channels, fpn_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=fpn_channels
             ),
+            nn.Conv2d(fpn_channels, num_outputs * num_anchors, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)),
         )

     def forward(self, x: list[torch.Tensor]) -> torch.Tensor:
@@ -453,7 +434,7 @@ class ClassificationHead(HeadNet):
             cls_logits = self.predict(cls_logits)

             # Permute classification output from (N, A * K, H, W) to (N, HWA, K).
-            (N, _, H, W) = cls_logits.shape
+            N, _, H, W = cls_logits.shape
             cls_logits = cls_logits.view(N, -1, self.num_outputs, H, W)
             cls_logits = cls_logits.permute(0, 3, 4, 1, 2)
             cls_logits = cls_logits.reshape(N, -1, self.num_outputs)  # Size=(N, HWA, K)
@@ -504,7 +485,7 @@ class RegressionHead(HeadNet):
             bbox_regression = self.predict(bbox_regression)

             # Permute bbox regression output from (N, 4 * A, H, W) to (N, HWA, 4).
-            (N, _, H, W) = bbox_regression.shape
+            N, _, H, W = bbox_regression.shape
             bbox_regression = bbox_regression.view(N, -1, 4, H, W)
             bbox_regression = bbox_regression.permute(0, 3, 4, 1, 2)
             bbox_regression = bbox_regression.reshape(N, -1, 4)  # Size=(N, HWA, 4)
@@ -663,7 +644,7 @@ class EfficientDet(DetectionBaseNet):

                 # Keep only topk scoring predictions
                 num_topk = min(self.topk_candidates, int(topk_idxs.size(0)))
-                (scores_per_level, idxs) = scores_per_level.topk(num_topk)
+                scores_per_level, idxs = scores_per_level.topk(num_topk)
                 topk_idxs = topk_idxs[idxs]

                 anchor_idxs = torch.div(topk_idxs, num_classes, rounding_mode="floor")
@@ -685,7 +666,7 @@ class EfficientDet(DetectionBaseNet):

             # Non-maximum suppression
             if self.soft_nms is not None:
-                (soft_scores, keep) = self.soft_nms(image_boxes, image_scores, image_labels, score_threshold=0.001)
+                soft_scores, keep = self.soft_nms(image_boxes, image_scores, image_labels, score_threshold=0.001)
                 image_scores[keep] = soft_scores
             else:
                 keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh)
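The HeadNet hunks above collapse the multi-line `nn.Conv2d` arguments onto one line and drop the explicit `bias=True`. A small sketch (channel counts are illustrative) confirming the dropped argument matches the PyTorch default, so the layer itself is unchanged:

```python
import torch
from torch import nn

explicit = nn.Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=8, bias=True)
implicit = nn.Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=8)

assert implicit.bias is not None  # bias=True is the nn.Conv2d default
assert explicit.weight.shape == implicit.weight.shape == torch.Size([8, 1, 3, 3])
```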
birder/net/detection/faster_rcnn.py
@@ -150,7 +150,7 @@ def concat_box_prediction_layers(
     # all feature levels concatenated, so we keep the same representation
     # for the objectness and the box_regression
     for box_cls_per_level, box_regression_per_level in zip(box_cls, box_regression):
-        (N, AxC, H, W) = box_cls_per_level.shape  # pylint: disable=invalid-name
+        N, AxC, H, W = box_cls_per_level.shape  # pylint: disable=invalid-name
         Ax4 = box_regression_per_level.shape[1]  # pylint: disable=invalid-name
         A = Ax4 // 4
         C = AxC // A
@@ -240,7 +240,7 @@ class RegionProposalNetwork(nn.Module):

                 # Get the targets corresponding GT for each proposal
                 # NB: need to clamp the indices because we can have a single
-                # GT in the image, and matched_idxs can be -2, which goes out of bounds
+                # GT in the image and matched_idxs can be -2, which goes out of bounds
                 matched_gt_boxes_per_image = gt_boxes[matched_idxs.clamp(min=0)]

                 labels_per_image = matched_idxs >= 0
@@ -265,7 +265,7 @@ class RegionProposalNetwork(nn.Module):
         for ob in objectness.split(num_anchors_per_level, 1):
             num_anchors = ob.shape[1]
             pre_nms_top_n = min(self.pre_nms_top_n(), int(ob.size(1)))
-            (_, top_n_idx) = ob.topk(pre_nms_top_n, dim=1)
+            _, top_n_idx = ob.topk(pre_nms_top_n, dim=1)
             r.append(top_n_idx + offset)
             offset += num_anchors

@@ -310,19 +310,19 @@ class RegionProposalNetwork(nn.Module):

             # Remove small boxes
             keep = box_ops.remove_small_boxes(boxes, self.min_size)
-            (boxes, scores, lvl) = boxes[keep], scores[keep], lvl[keep]
+            boxes, scores, lvl = boxes[keep], scores[keep], lvl[keep]

             # Remove low scoring boxes
             # use >= for Backwards compatibility
             keep = torch.where(scores >= self.score_thresh)[0]
-            (boxes, scores, lvl) = boxes[keep], scores[keep], lvl[keep]
+            boxes, scores, lvl = boxes[keep], scores[keep], lvl[keep]

             # Non-maximum suppression, independently done per level
             keep = box_ops.batched_nms(boxes, scores, lvl, self.nms_thresh)

             # Keep only topk scoring predictions
             keep = keep[: self.post_nms_top_n()]
-            (boxes, scores) = boxes[keep], scores[keep]
+            boxes, scores = boxes[keep], scores[keep]

             final_boxes.append(boxes)
             final_scores.append(scores)
@@ -336,7 +336,7 @@ class RegionProposalNetwork(nn.Module):
         labels: list[torch.Tensor],
         regression_targets: list[torch.Tensor],
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        (sampled_pos_idxs, sampled_neg_idxs) = self.fg_bg_sampler(labels)
+        sampled_pos_idxs, sampled_neg_idxs = self.fg_bg_sampler(labels)
         sampled_pos_idxs = torch.where(torch.concat(sampled_pos_idxs, dim=0))[0]
         sampled_neg_idxs = torch.where(torch.concat(sampled_neg_idxs, dim=0))[0]

@@ -364,29 +364,29 @@ class RegionProposalNetwork(nn.Module):
     ) -> tuple[list[torch.Tensor], dict[str, torch.Tensor]]:
         # RPN uses all feature maps that are available
         features_list = list(features.values())
-        (objectness, pred_bbox_deltas) = self.head(features_list)
+        objectness, pred_bbox_deltas = self.head(features_list)
         anchors = self.anchor_generator(images, features_list)

         num_images = len(anchors)
         num_anchors_per_level_shape_tensors = [o[0].shape for o in objectness]
         num_anchors_per_level = [s[0] * s[1] * s[2] for s in num_anchors_per_level_shape_tensors]
-        (objectness, pred_bbox_deltas) = concat_box_prediction_layers(objectness, pred_bbox_deltas)
+        objectness, pred_bbox_deltas = concat_box_prediction_layers(objectness, pred_bbox_deltas)

         # Apply pred_bbox_deltas to anchors to obtain the decoded proposals
         # note that we detach the deltas because Faster R-CNN do not backprop through
         # the proposals
         proposals = self.box_coder.decode(pred_bbox_deltas.detach(), anchors)
         proposals = proposals.view(num_images, -1, 4)
-        (boxes, _scores) = self.filter_proposals(proposals, objectness, images.image_sizes, num_anchors_per_level)
+        boxes, _scores = self.filter_proposals(proposals, objectness, images.image_sizes, num_anchors_per_level)

         losses: dict[str, torch.Tensor] = {}
         if self.training is True:
             if targets is None:
                 raise ValueError("targets should not be None")

-            (labels, matched_gt_boxes) = self.assign_targets_to_anchors(anchors, targets)
+            labels, matched_gt_boxes = self.assign_targets_to_anchors(anchors, targets)
             regression_targets = self.box_coder.encode(matched_gt_boxes, anchors)
-            (loss_objectness, loss_rpn_box_reg) = self.compute_loss(
+            loss_objectness, loss_rpn_box_reg = self.compute_loss(
                 objectness, pred_bbox_deltas, labels, regression_targets
             )
             losses = {
@@ -405,7 +405,7 @@ class FastRCNNConvFCHead(nn.Sequential):
         fc_layers: list[int],
         norm_layer: Optional[Callable[..., nn.Module]] = None,
     ):
-        (in_channels, in_height, in_width) = input_size
+        in_channels, in_height, in_width = input_size

         blocks = []
         previous_channels = in_channels
@@ -481,7 +481,7 @@ def faster_rcnn_loss(
     # advanced indexing
     sampled_pos_idxs_subset = torch.where(labels > 0)[0]
     labels_pos = labels[sampled_pos_idxs_subset]
-    (N, _num_classes) = class_logits.shape
+    N, _num_classes = class_logits.shape
    box_regression = box_regression.reshape(N, box_regression.size(-1) // 4, 4)

     box_loss = F.smooth_l1_loss(
@@ -573,7 +573,7 @@ class RoIHeads(nn.Module):
         return (matched_idxs, labels)

     def subsample(self, labels: list[torch.Tensor]) -> list[torch.Tensor]:
-        (sampled_pos_idxs, sampled_neg_idxs) = self.fg_bg_sampler(labels)
+        sampled_pos_idxs, sampled_neg_idxs = self.fg_bg_sampler(labels)
         sampled_idxs = []
         for pos_idxs_img, neg_idxs_img in zip(sampled_pos_idxs, sampled_neg_idxs):
             img_sampled_idxs = torch.where(pos_idxs_img | neg_idxs_img)[0]
@@ -610,7 +610,7 @@ class RoIHeads(nn.Module):
         proposals = self.add_gt_proposals(proposals, gt_boxes)

         # Get matching gt indices for each proposal
-        (matched_idxs, labels) = self.assign_targets_to_proposals(proposals, gt_boxes, gt_labels)
+        matched_idxs, labels = self.assign_targets_to_proposals(proposals, gt_boxes, gt_labels)

         # Sample a fixed proportion of positive-negative proposals
         sampled_idxs = self.subsample(labels)
@@ -713,7 +713,7 @@ class RoIHeads(nn.Module):
                     raise TypeError(f"target labels must of int64 type, instead got {t['labels'].dtype}")

         if self.training is True:
-            (proposals, _matched_idxs, labels, regression_targets) = self.select_training_samples(proposals, targets)
+            proposals, _matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets)
         else:
             labels = None
             regression_targets = None
@@ -721,7 +721,7 @@ class RoIHeads(nn.Module):

         box_features = self.box_roi_pool(features, proposals, image_shapes)
         box_features = self.box_head(box_features)
-        (class_logits, box_regression) = self.box_predictor(box_features)
+        class_logits, box_regression = self.box_predictor(box_features)

         losses = {}
         result: list[dict[str, torch.Tensor]] = []
@@ -731,11 +731,11 @@ class RoIHeads(nn.Module):
             if regression_targets is None:
                 raise ValueError("regression_targets cannot be None")

-            (loss_classifier, loss_box_reg) = faster_rcnn_loss(class_logits, box_regression, labels, regression_targets)
+            loss_classifier, loss_box_reg = faster_rcnn_loss(class_logits, box_regression, labels, regression_targets)
             losses = {"loss_classifier": loss_classifier, "loss_box_reg": loss_box_reg}

         else:
-            (boxes, scores, labels) = self.postprocess_detections(class_logits, box_regression, proposals, image_shapes)
+            boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals, image_shapes)
             num_images = len(boxes)
             for i in range(num_images):
                 result.append(
@@ -868,8 +868,8 @@ class Faster_RCNN(DetectionBaseNet):
         images = self._to_img_list(x, image_sizes)

         features = self.backbone_with_fpn(x)
-        (proposals, proposal_losses) = self.rpn(images, features, targets)
-        (detections, detector_losses) = self.roi_heads(features, proposals, images.image_sizes, targets)
+        proposals, proposal_losses = self.rpn(images, features, targets)
+        detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets)

         losses = {}
         losses.update(detector_losses)
birder/net/detection/fcos.py
@@ -125,7 +125,7 @@ class FCOSClassificationHead(nn.Module):
             cls_logits = self.cls_logits(cls_logits)

             # Permute classification output from (N, A * K, H, W) to (N, HWA, K).
-            (N, _, H, W) = cls_logits.size()
+            N, _, H, W = cls_logits.size()
             cls_logits = cls_logits.view(N, -1, self.num_classes, H, W)
             cls_logits = cls_logits.permute(0, 3, 4, 1, 2)
             cls_logits = cls_logits.reshape(N, -1, self.num_classes)  # (N, HWA, 4)
@@ -165,7 +165,7 @@ class FCOSRegressionHead(nn.Module):
             bbox_ctrness = self.bbox_ctrness(bbox_feature)

             # Permute bbox regression output from (N, 4 * A, H, W) to (N, HWA, 4).
-            (N, _, H, W) = bbox_regression.size()
+            N, _, H, W = bbox_regression.size()
             bbox_regression = bbox_regression.view(N, -1, 4, H, W)
             bbox_regression = bbox_regression.permute(0, 3, 4, 1, 2)
             bbox_regression = bbox_regression.reshape(N, -1, 4)  # (N, HWA, 4)
@@ -262,7 +262,7 @@ class FCOSHead(nn.Module):

     def forward(self, x: list[torch.Tensor]) -> dict[str, torch.Tensor]:
         cls_logits = self.classification_head(x)
-        (bbox_regression, bbox_ctrness) = self.regression_head(x)
+        bbox_regression, bbox_ctrness = self.regression_head(x)

         return {
             "cls_logits": cls_logits,
@@ -370,8 +370,8 @@ class FCOS(DetectionBaseNet):
             ).values < self.center_sampling_radius * anchor_sizes[:, None]

             # Compute pairwise distance between N points and M boxes
-            (x, y) = anchor_centers.unsqueeze(dim=2).unbind(dim=1)  # (N, 1)
-            (x0, y0, x1, y1) = gt_boxes.unsqueeze(dim=0).unbind(dim=2)  # (1, M)
+            x, y = anchor_centers.unsqueeze(dim=2).unbind(dim=1)  # (N, 1)
+            x0, y0, x1, y1 = gt_boxes.unsqueeze(dim=0).unbind(dim=2)  # (1, M)
             pairwise_dist = torch.stack([x - x0, y - y0, x1 - x, y1 - y], dim=2)  # (N, M)

             # Anchor point must be inside gt
@@ -388,7 +388,7 @@ class FCOS(DetectionBaseNet):
             # Match the GT box with minimum area, if there are multiple GT matches
             gt_areas = (gt_boxes[:, 2] - gt_boxes[:, 0]) * (gt_boxes[:, 3] - gt_boxes[:, 1])  # N
             pairwise_match = pairwise_match.to(torch.float32) * (1e8 - gt_areas[None, :])
-            (min_values, matched_idx) = pairwise_match.max(dim=1)  # R, per-anchor match
+            min_values, matched_idx = pairwise_match.max(dim=1)  # R, per-anchor match
             matched_idx[min_values < 1e-5] = -1  # Unmatched anchors are assigned -1

             matched_idxs.append(matched_idx)
@@ -433,7 +433,7 @@ class FCOS(DetectionBaseNet):

                 # Keep only topk scoring predictions
                 num_topk = min(self.topk_candidates, int(topk_idxs.size(0)))
-                (scores_per_level, idxs) = scores_per_level.topk(num_topk)
+                scores_per_level, idxs = scores_per_level.topk(num_topk)
                 topk_idxs = topk_idxs[idxs]

                 anchor_idxs = torch.div(topk_idxs, num_classes, rounding_mode="floor")
@@ -455,7 +455,7 @@ class FCOS(DetectionBaseNet):

             # Non-maximum suppression
             if self.soft_nms is not None:
-                (soft_scores, keep) = self.soft_nms(image_boxes, image_scores, image_labels, score_threshold=0.001)
+                soft_scores, keep = self.soft_nms(image_boxes, image_scores, image_labels, score_threshold=0.001)
                 image_scores[keep] = soft_scores
             else:
                 keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh)
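The FCOS and EfficientDet post-processing hunks (and the Deformable DETR top-k hunk) all recover anchor and label indices from a top-k over flattened (anchor, class) scores via floor division and modulo. A minimal sketch with illustrative sizes:

```python
import torch

num_anchors, num_classes = 4, 3
scores = torch.rand(num_anchors, num_classes)

flat_scores, flat_idxs = scores.flatten().topk(5)
anchor_idxs = torch.div(flat_idxs, num_classes, rounding_mode="floor")
label_idxs = flat_idxs % num_classes

# The decoded (anchor, label) pairs index back to exactly the selected scores
assert torch.equal(scores[anchor_idxs, label_idxs], flat_scores)
```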