birder 0.3.3__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (216)
  1. birder/adversarial/base.py +1 -1
  2. birder/adversarial/simba.py +4 -4
  3. birder/common/cli.py +1 -1
  4. birder/common/fs_ops.py +13 -13
  5. birder/common/lib.py +2 -2
  6. birder/common/masking.py +3 -3
  7. birder/common/training_cli.py +24 -2
  8. birder/common/training_utils.py +28 -4
  9. birder/data/collators/detection.py +9 -1
  10. birder/data/transforms/detection.py +27 -8
  11. birder/data/transforms/mosaic.py +1 -1
  12. birder/datahub/classification.py +3 -3
  13. birder/inference/classification.py +3 -3
  14. birder/inference/data_parallel.py +1 -1
  15. birder/inference/detection.py +5 -5
  16. birder/inference/wbf.py +1 -1
  17. birder/introspection/attention_rollout.py +6 -6
  18. birder/introspection/feature_pca.py +4 -4
  19. birder/introspection/gradcam.py +1 -1
  20. birder/introspection/guided_backprop.py +2 -2
  21. birder/introspection/transformer_attribution.py +4 -4
  22. birder/layers/attention_pool.py +2 -2
  23. birder/layers/layer_scale.py +1 -1
  24. birder/model_registry/model_registry.py +2 -1
  25. birder/net/__init__.py +4 -10
  26. birder/net/_rope_vit_configs.py +435 -0
  27. birder/net/_vit_configs.py +466 -0
  28. birder/net/alexnet.py +5 -5
  29. birder/net/base.py +28 -3
  30. birder/net/biformer.py +18 -17
  31. birder/net/cait.py +7 -7
  32. birder/net/cas_vit.py +1 -1
  33. birder/net/coat.py +27 -27
  34. birder/net/conv2former.py +3 -3
  35. birder/net/convmixer.py +1 -1
  36. birder/net/convnext_v1.py +3 -11
  37. birder/net/convnext_v1_iso.py +198 -0
  38. birder/net/convnext_v2.py +2 -10
  39. birder/net/crossformer.py +9 -9
  40. birder/net/crossvit.py +6 -6
  41. birder/net/cspnet.py +1 -1
  42. birder/net/cswin_transformer.py +10 -10
  43. birder/net/davit.py +11 -11
  44. birder/net/deit.py +68 -29
  45. birder/net/deit3.py +69 -204
  46. birder/net/densenet.py +9 -8
  47. birder/net/detection/__init__.py +4 -0
  48. birder/net/detection/{yolo_anchors.py → _yolo_anchors.py} +5 -5
  49. birder/net/detection/base.py +6 -5
  50. birder/net/detection/deformable_detr.py +31 -30
  51. birder/net/detection/detr.py +14 -11
  52. birder/net/detection/efficientdet.py +10 -29
  53. birder/net/detection/faster_rcnn.py +22 -22
  54. birder/net/detection/fcos.py +8 -8
  55. birder/net/detection/plain_detr.py +852 -0
  56. birder/net/detection/retinanet.py +4 -4
  57. birder/net/detection/rt_detr_v1.py +81 -25
  58. birder/net/detection/rt_detr_v2.py +1147 -0
  59. birder/net/detection/ssd.py +5 -5
  60. birder/net/detection/yolo_v2.py +12 -12
  61. birder/net/detection/yolo_v3.py +19 -19
  62. birder/net/detection/yolo_v4.py +16 -16
  63. birder/net/detection/yolo_v4_tiny.py +3 -3
  64. birder/net/dpn.py +1 -2
  65. birder/net/edgenext.py +5 -4
  66. birder/net/edgevit.py +13 -14
  67. birder/net/efficientformer_v1.py +3 -2
  68. birder/net/efficientformer_v2.py +18 -31
  69. birder/net/efficientnet_v2.py +3 -0
  70. birder/net/efficientvim.py +9 -9
  71. birder/net/efficientvit_mit.py +7 -7
  72. birder/net/efficientvit_msft.py +3 -3
  73. birder/net/fasternet.py +3 -3
  74. birder/net/fastvit.py +5 -12
  75. birder/net/flexivit.py +50 -58
  76. birder/net/focalnet.py +5 -9
  77. birder/net/gc_vit.py +11 -11
  78. birder/net/ghostnet_v1.py +1 -1
  79. birder/net/ghostnet_v2.py +1 -1
  80. birder/net/groupmixformer.py +13 -13
  81. birder/net/hgnet_v1.py +6 -6
  82. birder/net/hgnet_v2.py +4 -4
  83. birder/net/hiera.py +6 -6
  84. birder/net/hieradet.py +9 -9
  85. birder/net/hornet.py +3 -3
  86. birder/net/iformer.py +4 -4
  87. birder/net/inception_next.py +5 -15
  88. birder/net/inception_resnet_v1.py +3 -3
  89. birder/net/inception_resnet_v2.py +7 -4
  90. birder/net/inception_v3.py +3 -0
  91. birder/net/inception_v4.py +3 -0
  92. birder/net/levit.py +3 -3
  93. birder/net/lit_v1.py +13 -15
  94. birder/net/lit_v1_tiny.py +9 -9
  95. birder/net/lit_v2.py +14 -15
  96. birder/net/maxvit.py +11 -23
  97. birder/net/metaformer.py +5 -5
  98. birder/net/mim/crossmae.py +6 -6
  99. birder/net/mim/fcmae.py +3 -5
  100. birder/net/mim/mae_hiera.py +7 -7
  101. birder/net/mim/mae_vit.py +4 -6
  102. birder/net/mim/simmim.py +3 -4
  103. birder/net/mobilenet_v1.py +0 -9
  104. birder/net/mobilenet_v2.py +38 -44
  105. birder/net/{mobilenet_v3_large.py → mobilenet_v3.py} +37 -10
  106. birder/net/mobilenet_v4_hybrid.py +4 -4
  107. birder/net/mobileone.py +5 -12
  108. birder/net/mobilevit_v1.py +7 -34
  109. birder/net/mobilevit_v2.py +6 -54
  110. birder/net/moganet.py +8 -5
  111. birder/net/mvit_v2.py +30 -30
  112. birder/net/nextvit.py +2 -2
  113. birder/net/nfnet.py +4 -0
  114. birder/net/pit.py +11 -26
  115. birder/net/pvt_v1.py +9 -9
  116. birder/net/pvt_v2.py +10 -16
  117. birder/net/regionvit.py +15 -15
  118. birder/net/regnet.py +1 -1
  119. birder/net/repghost.py +5 -35
  120. birder/net/repvgg.py +3 -5
  121. birder/net/repvit.py +2 -2
  122. birder/net/resmlp.py +2 -2
  123. birder/net/resnest.py +4 -1
  124. birder/net/resnet_v1.py +125 -1
  125. birder/net/resnet_v2.py +75 -1
  126. birder/net/resnext.py +35 -1
  127. birder/net/rope_deit3.py +62 -151
  128. birder/net/rope_flexivit.py +46 -33
  129. birder/net/rope_vit.py +44 -758
  130. birder/net/sequencer2d.py +3 -4
  131. birder/net/shufflenet_v1.py +1 -1
  132. birder/net/shufflenet_v2.py +1 -1
  133. birder/net/simple_vit.py +69 -21
  134. birder/net/smt.py +8 -8
  135. birder/net/squeezenet.py +5 -12
  136. birder/net/squeezenext.py +0 -24
  137. birder/net/ssl/barlow_twins.py +1 -1
  138. birder/net/ssl/byol.py +2 -2
  139. birder/net/ssl/capi.py +4 -4
  140. birder/net/ssl/data2vec.py +1 -1
  141. birder/net/ssl/data2vec2.py +1 -1
  142. birder/net/ssl/dino_v2.py +13 -3
  143. birder/net/ssl/franca.py +28 -4
  144. birder/net/ssl/i_jepa.py +5 -5
  145. birder/net/ssl/ibot.py +1 -1
  146. birder/net/ssl/mmcr.py +1 -1
  147. birder/net/swiftformer.py +13 -3
  148. birder/net/swin_transformer_v1.py +4 -5
  149. birder/net/swin_transformer_v2.py +5 -8
  150. birder/net/tiny_vit.py +6 -19
  151. birder/net/transnext.py +19 -19
  152. birder/net/uniformer.py +4 -4
  153. birder/net/van.py +2 -2
  154. birder/net/vgg.py +1 -10
  155. birder/net/vit.py +72 -987
  156. birder/net/vit_parallel.py +35 -20
  157. birder/net/vit_sam.py +23 -48
  158. birder/net/vovnet_v2.py +1 -1
  159. birder/net/xcit.py +16 -13
  160. birder/ops/msda.py +4 -4
  161. birder/ops/swattention.py +10 -10
  162. birder/results/classification.py +3 -3
  163. birder/results/gui.py +8 -8
  164. birder/scripts/benchmark.py +37 -12
  165. birder/scripts/evaluate.py +1 -1
  166. birder/scripts/predict.py +3 -3
  167. birder/scripts/predict_detection.py +2 -2
  168. birder/scripts/train.py +63 -15
  169. birder/scripts/train_barlow_twins.py +10 -7
  170. birder/scripts/train_byol.py +10 -7
  171. birder/scripts/train_capi.py +15 -10
  172. birder/scripts/train_data2vec.py +10 -7
  173. birder/scripts/train_data2vec2.py +10 -7
  174. birder/scripts/train_detection.py +29 -14
  175. birder/scripts/train_dino_v1.py +13 -9
  176. birder/scripts/train_dino_v2.py +27 -14
  177. birder/scripts/train_dino_v2_dist.py +28 -15
  178. birder/scripts/train_franca.py +16 -9
  179. birder/scripts/train_i_jepa.py +12 -9
  180. birder/scripts/train_ibot.py +15 -11
  181. birder/scripts/train_kd.py +64 -17
  182. birder/scripts/train_mim.py +11 -8
  183. birder/scripts/train_mmcr.py +11 -8
  184. birder/scripts/train_rotnet.py +11 -7
  185. birder/scripts/train_simclr.py +10 -7
  186. birder/scripts/train_vicreg.py +10 -7
  187. birder/tools/adversarial.py +4 -4
  188. birder/tools/auto_anchors.py +5 -5
  189. birder/tools/avg_model.py +1 -1
  190. birder/tools/convert_model.py +30 -22
  191. birder/tools/det_results.py +1 -1
  192. birder/tools/download_model.py +1 -1
  193. birder/tools/ensemble_model.py +1 -1
  194. birder/tools/introspection.py +12 -3
  195. birder/tools/labelme_to_coco.py +2 -2
  196. birder/tools/model_info.py +15 -15
  197. birder/tools/pack.py +8 -8
  198. birder/tools/quantize_model.py +53 -4
  199. birder/tools/results.py +2 -2
  200. birder/tools/show_det_iterator.py +19 -6
  201. birder/tools/show_iterator.py +2 -2
  202. birder/tools/similarity.py +5 -5
  203. birder/tools/stats.py +4 -6
  204. birder/tools/voc_to_coco.py +1 -1
  205. birder/version.py +1 -1
  206. {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/METADATA +3 -3
  207. birder-0.4.1.dist-info/RECORD +300 -0
  208. {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/WHEEL +1 -1
  209. birder/net/mobilenet_v3_small.py +0 -43
  210. birder/net/se_resnet_v1.py +0 -105
  211. birder/net/se_resnet_v2.py +0 -59
  212. birder/net/se_resnext.py +0 -30
  213. birder-0.3.3.dist-info/RECORD +0 -299
  214. {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/entry_points.txt +0 -0
  215. {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/licenses/LICENSE +0 -0
  216. {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/top_level.txt +0 -0
@@ -9,7 +9,7 @@ https://arxiv.org/abs/2010.04159
 
  Changes from original:
  * Removed two stage support
- * Zero cost matrix elements on overflow (HungarianMatcher)
+ * Penalize cost matrix elements on overflow (HungarianMatcher)
  """
 
  # Reference license: Apache-2.0 (both)
@@ -58,7 +58,7 @@ class HungarianMatcher(nn.Module):
  self, class_logits: torch.Tensor, box_regression: torch.Tensor, targets: list[dict[str, torch.Tensor]]
  ) -> list[torch.Tensor]:
  with torch.no_grad():
- (B, num_queries) = class_logits.shape[:2]
+ B, num_queries = class_logits.shape[:2]
 
  # We flatten to compute the cost matrices in a batch
  out_prob = class_logits.flatten(0, 1).sigmoid() # [batch_size * num_queries, num_classes]
@@ -89,7 +89,10 @@ class HungarianMatcher(nn.Module):
  # Final cost matrix
  C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou
  C = C.view(B, num_queries, -1).cpu()
- C[C.isnan() | C.isinf()] = 0.0
+ finite = torch.isfinite(C)
+ if not torch.all(finite):
+ penalty = C[finite].max().item() + 1.0 if finite.any().item() else 1.0
+ C.nan_to_num_(nan=penalty, posinf=penalty, neginf=penalty)
 
  sizes = [len(v["boxes"]) for v in targets]
  indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
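As context for the HungarianMatcher hunk above: the old code zeroed NaN/Inf cost entries, which could make overflowed query-target pairs look artificially cheap to the assignment, while the new code replaces them with a penalty larger than any finite cost so linear_sum_assignment steers away from them. A minimal standalone sketch of that logic, assuming only torch and a toy 2x2 cost matrix (the values are hypothetical, for illustration only):

    import torch

    # Toy cost matrix with one NaN and one Inf entry
    C = torch.tensor([[1.0, float("nan")], [float("inf"), 2.0]])

    finite = torch.isfinite(C)
    if not torch.all(finite):
        # Penalty strictly larger than any finite cost; fall back to 1.0 if nothing is finite
        penalty = C[finite].max().item() + 1.0 if finite.any().item() else 1.0
        C.nan_to_num_(nan=penalty, posinf=penalty, neginf=penalty)

    print(C)  # tensor([[1., 3.], [3., 2.]])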
@@ -108,8 +111,7 @@ def inverse_sigmoid(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
  class MultiScaleDeformableAttention(nn.Module):
  def __init__(self, d_model: int, n_levels: int, n_heads: int, n_points: int) -> None:
  super().__init__()
- if d_model % n_heads != 0:
- raise ValueError(f"d_model must be divisible by n_heads, but got {d_model} and {n_heads}")
+ assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
 
  # Ensure dim_per_head is power of 2
  dim_per_head = d_model // n_heads
@@ -130,9 +132,9 @@ class MultiScaleDeformableAttention(nn.Module):
  self.value_proj = nn.Linear(d_model, d_model)
  self.output_proj = nn.Linear(d_model, d_model)
 
- self._reset_parameters()
+ self.reset_parameters()
 
- def _reset_parameters(self) -> None:
+ def reset_parameters(self) -> None:
  nn.init.constant_(self.sampling_offsets.weight, 0.0)
  thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
  grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
@@ -163,8 +165,8 @@ class MultiScaleDeformableAttention(nn.Module):
  input_level_start_index: torch.Tensor,
  input_padding_mask: Optional[torch.Tensor] = None,
  ) -> torch.Tensor:
- (N, num_queries, _) = query.size()
- (N, sequence_length, _) = input_flatten.size()
+ N, num_queries, _ = query.size()
+ N, sequence_length, _ = input_flatten.size()
  assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == sequence_length
 
  value = self.value_proj(input_flatten)
@@ -280,7 +282,7 @@ class DeformableTransformerDecoderLayer(nn.Module):
  q = tgt + query_pos
  k = tgt + query_pos
 
- (tgt2, _) = self.self_attn(
+ tgt2, _ = self.self_attn(
  q.transpose(0, 1), k.transpose(0, 1), tgt.transpose(0, 1), need_weights=False, attn_mask=self_attn_mask
  )
  tgt2 = tgt2.transpose(0, 1)
@@ -315,7 +317,7 @@ class DeformableTransformerEncoder(nn.Module):
  for lvl, spatial_shape in enumerate(spatial_shapes):
  H = spatial_shape[0]
  W = spatial_shape[1]
- (ref_y, ref_x) = torch.meshgrid(
+ ref_y, ref_x = torch.meshgrid(
  torch.linspace(0.5, H - 0.5, H, dtype=torch.float32, device=device),
  torch.linspace(0.5, W - 0.5, W, dtype=torch.float32, device=device),
  indexing="ij",
@@ -451,7 +453,7 @@ class DeformableTransformer(nn.Module):
 
  for m in self.modules():
  if isinstance(m, MultiScaleDeformableAttention):
- m._reset_parameters()
+ m.reset_parameters()
 
  nn.init.xavier_uniform_(self.reference_points.weight, gain=1.0)
  nn.init.zeros_(self.reference_points.bias)
@@ -459,7 +461,7 @@ class DeformableTransformer(nn.Module):
  nn.init.normal_(self.level_embed)
 
  def get_valid_ratio(self, mask: torch.Tensor) -> torch.Tensor:
- (_, H, W) = mask.size()
+ _, H, W = mask.size()
  valid_h = torch.sum(~mask[:, :, 0], 1)
  valid_w = torch.sum(~mask[:, 0, :], 1)
  valid_ratio_h = valid_h.float() / H
@@ -482,7 +484,7 @@ class DeformableTransformer(nn.Module):
  mask_list = []
  spatial_shape_list: list[list[int]] = [] # list[tuple[int, int]] not supported on TorchScript
  for lvl, (src, pos_embed, mask) in enumerate(zip(srcs, pos_embeds, masks)):
- (_, _, H, W) = src.size()
+ _, _, H, W = src.size()
  spatial_shape_list.append([H, W])
  src = src.flatten(2).transpose(1, 2)
  pos_embed = pos_embed.flatten(2).transpose(1, 2)
@@ -505,14 +507,14 @@ class DeformableTransformer(nn.Module):
  )
 
  # Prepare input for decoder
- (B, _, C) = memory.size()
+ B, _, C = memory.size()
  query_embed, tgt = torch.split(query_embed, C, dim=1)
  query_embed = query_embed.unsqueeze(0).expand(B, -1, -1)
  tgt = tgt.unsqueeze(0).expand(B, -1, -1)
  reference_points = self.reference_points(query_embed).sigmoid()
 
  # Decoder
- (hs, inter_references) = self.decoder(
+ hs, inter_references = self.decoder(
  tgt, reference_points, memory, spatial_shapes, level_start_index, query_embed, valid_ratios, mask_flatten
  )
 
@@ -629,7 +631,7 @@ class Deformable_DETR(DetectionBaseNet):
  prior_prob = 0.01
  bias_value = -math.log((1 - prior_prob) / prior_prob)
  for class_embed in self.class_embed:
- class_embed.bias.data = torch.ones(self.num_classes) * bias_value
+ nn.init.constant_(class_embed.bias, bias_value)
 
  def freeze(self, freeze_classifier: bool = True) -> None:
  for param in self.parameters():
@@ -653,20 +655,19 @@ class Deformable_DETR(DetectionBaseNet):
  ) -> torch.Tensor:
  idx = self._get_src_permutation_idx(indices)
  target_classes_o = torch.concat([t["labels"][J] for t, (_, J) in zip(targets, indices)], dim=0)
- target_classes = torch.full(cls_logits.shape[:2], self.num_classes, dtype=torch.int64, device=cls_logits.device)
- target_classes[idx] = target_classes_o
 
  target_classes_onehot = torch.zeros(
- [cls_logits.shape[0], cls_logits.shape[1], cls_logits.shape[2] + 1],
+ cls_logits.size(0),
+ cls_logits.size(1),
+ cls_logits.size(2) + 1,
  dtype=cls_logits.dtype,
- layout=cls_logits.layout,
  device=cls_logits.device,
  )
- target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1)
-
+ target_classes_onehot[idx[0], idx[1], target_classes_o] = 1
  target_classes_onehot = target_classes_onehot[:, :, :-1]
+
  loss = sigmoid_focal_loss(cls_logits, target_classes_onehot, alpha=0.25, gamma=2.0)
- loss_ce = (loss.mean(1).sum() / num_boxes) * cls_logits.shape[1]
+ loss_ce = (loss.mean(1).sum() / num_boxes) * cls_logits.size(1)
 
  return loss_ce
 
@@ -716,7 +717,7 @@ class Deformable_DETR(DetectionBaseNet):
  for idx in range(cls_logits.size(0)):
  indices = self.matcher(cls_logits[idx], box_output[idx], targets)
  loss_ce_i = self._class_loss(cls_logits[idx], targets, indices, num_boxes)
- (loss_bbox_i, loss_giou_i) = self._box_loss(box_output[idx], targets, indices, num_boxes)
+ loss_bbox_i, loss_giou_i = self._box_loss(box_output[idx], targets, indices, num_boxes)
  loss_ce_list.append(loss_ce_i)
  loss_bbox_list.append(loss_bbox_i)
  loss_giou_list.append(loss_giou_i)
@@ -736,7 +737,7 @@ class Deformable_DETR(DetectionBaseNet):
  self, class_logits: torch.Tensor, box_regression: torch.Tensor, image_shapes: list[tuple[int, int]]
  ) -> list[dict[str, torch.Tensor]]:
  prob = class_logits.sigmoid()
- (topk_values, topk_indexes) = torch.topk(prob.view(class_logits.shape[0], -1), k=100, dim=1)
+ topk_values, topk_indexes = torch.topk(prob.view(class_logits.shape[0], -1), k=100, dim=1)
  scores = topk_values
  topk_boxes = topk_indexes // class_logits.shape[2]
  labels = topk_indexes % class_logits.shape[2]
@@ -749,7 +750,7 @@ class Deformable_DETR(DetectionBaseNet):
  boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4))
 
  # Convert from relative [0, 1] to absolute [0, height] coordinates
- (img_h, img_w) = target_sizes.unbind(1)
+ img_h, img_w = target_sizes.unbind(1)
  scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
  boxes = boxes * scale_fct[:, None, :]
 
@@ -757,7 +758,7 @@ class Deformable_DETR(DetectionBaseNet):
  for s, l, b in zip(scores, labels, boxes):
  # Non-maximum suppression
  if self.soft_nms is not None:
- (soft_scores, keep) = self.soft_nms(b, s, l, score_threshold=0.001)
+ soft_scores, keep = self.soft_nms(b, s, l, score_threshold=0.001)
  s[keep] = soft_scores
 
  b = b[keep]
@@ -794,14 +795,14 @@ class Deformable_DETR(DetectionBaseNet):
  mask_size = feature_list[idx].shape[-2:]
  m = F.interpolate(masks[None].float(), size=mask_size, mode="nearest").to(torch.bool)[0]
  else:
- (B, _, H, W) = feature_list[idx].size()
+ B, _, H, W = feature_list[idx].size()
  m = torch.zeros(B, H, W, dtype=torch.bool, device=x.device)
 
  feature_list[idx] = proj(feature_list[idx])
  mask_list.append(m)
  pos_list.append(self.pos_enc(feature_list[idx], m))
 
- (hs, init_reference, inter_references) = self.transformer(
+ hs, init_reference, inter_references = self.transformer(
  feature_list, pos_list, self.query_embed.weight, mask_list
  )
  outputs_classes = []
@@ -6,7 +6,7 @@ Paper "End-to-End Object Detection with Transformers", https://arxiv.org/abs/200
 
  Changes from original:
  * Move background index to first from last (to be inline with the rest of Birder detectors)
- * Zero cost matrix elements on overflow (HungarianMatcher)
+ * Penalize cost matrix elements on overflow (HungarianMatcher)
  """
 
  # Reference license: Apache-2.0
@@ -51,7 +51,7 @@ class HungarianMatcher(nn.Module):
  self, class_logits: torch.Tensor, box_regression: torch.Tensor, targets: list[dict[str, torch.Tensor]]
  ) -> list[torch.Tensor]:
  with torch.no_grad():
- (B, num_queries) = class_logits.shape[:2]
+ B, num_queries = class_logits.shape[:2]
 
  # We flatten to compute the cost matrices in a batch
  out_prob = class_logits.flatten(0, 1).softmax(-1) # [batch_size * num_queries, num_classes]
@@ -78,7 +78,10 @@ class HungarianMatcher(nn.Module):
  # Final cost matrix
  C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou
  C = C.view(B, num_queries, -1).cpu()
- C[C.isnan() | C.isinf()] = 0.0
+ finite = torch.isfinite(C)
+ if not torch.all(finite):
+ penalty = C[finite].max().item() + 1.0 if finite.any().item() else 1.0
+ C.nan_to_num_(nan=penalty, posinf=penalty, neginf=penalty)
 
  sizes = [len(v["boxes"]) for v in targets]
  indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
@@ -108,7 +111,7 @@ class TransformerEncoderLayer(nn.Module):
  q = src + pos
  k = src + pos
 
- (src2, _) = self.self_attn(q, k, value=src, key_padding_mask=src_key_padding_mask, need_weights=False)
+ src2, _ = self.self_attn(q, k, value=src, key_padding_mask=src_key_padding_mask, need_weights=False)
  src = src + self.dropout1(src2)
  src = self.norm1(src)
  src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
@@ -148,10 +151,10 @@ class TransformerDecoderLayer(nn.Module):
  q = tgt + query_pos
  k = tgt + query_pos
 
- (tgt2, _) = self.self_attn(q, k, value=tgt, need_weights=False)
+ tgt2, _ = self.self_attn(q, k, value=tgt, need_weights=False)
  tgt = tgt + self.dropout1(tgt2)
  tgt = self.norm1(tgt)
- (tgt2, _) = self.multihead_attn(
+ tgt2, _ = self.multihead_attn(
  query=tgt + query_pos,
  key=memory + pos,
  value=memory,
@@ -267,7 +270,7 @@ class PositionEmbeddingSine(nn.Module):
 
  def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
  if mask is None:
- (B, _, H, W) = x.size()
+ B, _, H, W = x.size()
  mask = torch.zeros(B, H, W, dtype=torch.bool, device=x.device)
 
  not_mask = ~mask
@@ -427,7 +430,7 @@ class DETR(DetectionBaseNet):
  for idx in range(cls_logits.size(0)):
  indices = self.matcher(cls_logits[idx], box_output[idx], targets)
  loss_ce_i = self._class_loss(cls_logits[idx], targets, indices)
- (loss_bbox_i, loss_giou_i) = self._box_loss(box_output[idx], targets, indices, num_boxes)
+ loss_bbox_i, loss_giou_i = self._box_loss(box_output[idx], targets, indices, num_boxes)
  loss_ce_list.append(loss_ce_i)
  loss_bbox_list.append(loss_bbox_i)
  loss_giou_list.append(loss_giou_i)
@@ -447,7 +450,7 @@ class DETR(DetectionBaseNet):
  self, class_logits: torch.Tensor, box_regression: torch.Tensor, image_shapes: list[tuple[int, int]]
  ) -> list[dict[str, torch.Tensor]]:
  prob = F.softmax(class_logits, -1)
- (scores, labels) = prob[..., 1:].max(-1)
+ scores, labels = prob[..., 1:].max(-1)
  labels = labels + 1
 
  # TorchScript doesn't support creating tensor from tuples, convert everything to lists
@@ -457,7 +460,7 @@ class DETR(DetectionBaseNet):
  boxes = box_ops.box_convert(box_regression, in_fmt="cxcywh", out_fmt="xyxy")
 
  # Convert from relative [0, 1] to absolute [0, height] coordinates
- (img_h, img_w) = target_sizes.unbind(1)
+ img_h, img_w = target_sizes.unbind(1)
  scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
  boxes = boxes * scale_fct[:, None, :]
 
@@ -465,7 +468,7 @@ class DETR(DetectionBaseNet):
  for s, l, b in zip(scores, labels, boxes):
  # Non-maximum suppression
  if self.soft_nms is not None:
- (soft_scores, keep) = self.soft_nms(b, s, l, score_threshold=0.001)
+ soft_scores, keep = self.soft_nms(b, s, l, score_threshold=0.001)
  s[keep] = soft_scores
 
  b = b[keep]
@@ -136,8 +136,8 @@ class ResampleFeatureMap(nn.Module):
  if self.conv is not None:
  x = self.conv(x)
 
- (in_h, in_w) = x.shape[-2:]
- (target_h, target_w) = target_size
+ in_h, in_w = x.shape[-2:]
+ target_h, target_w = target_size
  if in_h == target_h and in_w == target_w:
  return x
 
@@ -195,7 +195,7 @@ class FpnCombine(nn.Module):
  )
 
  if weight_method in {"attn", "fastattn"}:
- self.edge_weights = nn.Parameter(torch.ones(len(inputs_offsets)), requires_grad=True) # WSM
+ self.edge_weights = nn.Parameter(torch.ones(len(inputs_offsets))) # WSM
  else:
  self.edge_weights = None
 
@@ -358,13 +358,7 @@ class HeadNet(nn.Module):
  for _ in range(repeats):
  layers.append(
  nn.Conv2d(
- fpn_channels,
- fpn_channels,
- kernel_size=(3, 3),
- stride=(1, 1),
- padding=(1, 1),
- groups=fpn_channels,
- bias=True,
+ fpn_channels, fpn_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=fpn_channels
  )
  )
  layers.append(
@@ -383,22 +377,9 @@ class HeadNet(nn.Module):
  self.conv_repeat = nn.Sequential(*layers)
  self.predict = nn.Sequential(
  nn.Conv2d(
- fpn_channels,
- fpn_channels,
- kernel_size=(3, 3),
- stride=(1, 1),
- padding=(1, 1),
- groups=fpn_channels,
- bias=True,
- ),
- nn.Conv2d(
- fpn_channels,
- num_outputs * num_anchors,
- kernel_size=(1, 1),
- stride=(1, 1),
- padding=(0, 0),
- bias=True,
+ fpn_channels, fpn_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=fpn_channels
  ),
+ nn.Conv2d(fpn_channels, num_outputs * num_anchors, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)),
  )
 
  def forward(self, x: list[torch.Tensor]) -> torch.Tensor:
@@ -453,7 +434,7 @@ class ClassificationHead(HeadNet):
  cls_logits = self.predict(cls_logits)
 
  # Permute classification output from (N, A * K, H, W) to (N, HWA, K).
- (N, _, H, W) = cls_logits.shape
+ N, _, H, W = cls_logits.shape
  cls_logits = cls_logits.view(N, -1, self.num_outputs, H, W)
  cls_logits = cls_logits.permute(0, 3, 4, 1, 2)
  cls_logits = cls_logits.reshape(N, -1, self.num_outputs) # Size=(N, HWA, K)
@@ -504,7 +485,7 @@ class RegressionHead(HeadNet):
  bbox_regression = self.predict(bbox_regression)
 
  # Permute bbox regression output from (N, 4 * A, H, W) to (N, HWA, 4).
- (N, _, H, W) = bbox_regression.shape
+ N, _, H, W = bbox_regression.shape
  bbox_regression = bbox_regression.view(N, -1, 4, H, W)
  bbox_regression = bbox_regression.permute(0, 3, 4, 1, 2)
  bbox_regression = bbox_regression.reshape(N, -1, 4) # Size=(N, HWA, 4)
@@ -663,7 +644,7 @@ class EfficientDet(DetectionBaseNet):
 
  # Keep only topk scoring predictions
  num_topk = min(self.topk_candidates, int(topk_idxs.size(0)))
- (scores_per_level, idxs) = scores_per_level.topk(num_topk)
+ scores_per_level, idxs = scores_per_level.topk(num_topk)
  topk_idxs = topk_idxs[idxs]
 
  anchor_idxs = torch.div(topk_idxs, num_classes, rounding_mode="floor")
@@ -685,7 +666,7 @@ class EfficientDet(DetectionBaseNet):
 
  # Non-maximum suppression
  if self.soft_nms is not None:
- (soft_scores, keep) = self.soft_nms(image_boxes, image_scores, image_labels, score_threshold=0.001)
+ soft_scores, keep = self.soft_nms(image_boxes, image_scores, image_labels, score_threshold=0.001)
  image_scores[keep] = soft_scores
  else:
  keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh)
@@ -150,7 +150,7 @@ def concat_box_prediction_layers(
  # all feature levels concatenated, so we keep the same representation
  # for the objectness and the box_regression
  for box_cls_per_level, box_regression_per_level in zip(box_cls, box_regression):
- (N, AxC, H, W) = box_cls_per_level.shape # pylint: disable=invalid-name
+ N, AxC, H, W = box_cls_per_level.shape # pylint: disable=invalid-name
  Ax4 = box_regression_per_level.shape[1] # pylint: disable=invalid-name
  A = Ax4 // 4
  C = AxC // A
@@ -240,7 +240,7 @@ class RegionProposalNetwork(nn.Module):
 
  # Get the targets corresponding GT for each proposal
  # NB: need to clamp the indices because we can have a single
- # GT in the image, and matched_idxs can be -2, which goes out of bounds
+ # GT in the image and matched_idxs can be -2, which goes out of bounds
  matched_gt_boxes_per_image = gt_boxes[matched_idxs.clamp(min=0)]
 
  labels_per_image = matched_idxs >= 0
@@ -265,7 +265,7 @@ class RegionProposalNetwork(nn.Module):
  for ob in objectness.split(num_anchors_per_level, 1):
  num_anchors = ob.shape[1]
  pre_nms_top_n = min(self.pre_nms_top_n(), int(ob.size(1)))
- (_, top_n_idx) = ob.topk(pre_nms_top_n, dim=1)
+ _, top_n_idx = ob.topk(pre_nms_top_n, dim=1)
  r.append(top_n_idx + offset)
  offset += num_anchors
 
@@ -310,19 +310,19 @@ class RegionProposalNetwork(nn.Module):
 
  # Remove small boxes
  keep = box_ops.remove_small_boxes(boxes, self.min_size)
- (boxes, scores, lvl) = boxes[keep], scores[keep], lvl[keep]
+ boxes, scores, lvl = boxes[keep], scores[keep], lvl[keep]
 
  # Remove low scoring boxes
  # use >= for Backwards compatibility
  keep = torch.where(scores >= self.score_thresh)[0]
- (boxes, scores, lvl) = boxes[keep], scores[keep], lvl[keep]
+ boxes, scores, lvl = boxes[keep], scores[keep], lvl[keep]
 
  # Non-maximum suppression, independently done per level
  keep = box_ops.batched_nms(boxes, scores, lvl, self.nms_thresh)
 
  # Keep only topk scoring predictions
  keep = keep[: self.post_nms_top_n()]
- (boxes, scores) = boxes[keep], scores[keep]
+ boxes, scores = boxes[keep], scores[keep]
 
  final_boxes.append(boxes)
  final_scores.append(scores)
@@ -336,7 +336,7 @@ class RegionProposalNetwork(nn.Module):
  labels: list[torch.Tensor],
  regression_targets: list[torch.Tensor],
  ) -> tuple[torch.Tensor, torch.Tensor]:
- (sampled_pos_idxs, sampled_neg_idxs) = self.fg_bg_sampler(labels)
+ sampled_pos_idxs, sampled_neg_idxs = self.fg_bg_sampler(labels)
  sampled_pos_idxs = torch.where(torch.concat(sampled_pos_idxs, dim=0))[0]
  sampled_neg_idxs = torch.where(torch.concat(sampled_neg_idxs, dim=0))[0]
 
@@ -364,29 +364,29 @@ class RegionProposalNetwork(nn.Module):
  ) -> tuple[list[torch.Tensor], dict[str, torch.Tensor]]:
  # RPN uses all feature maps that are available
  features_list = list(features.values())
- (objectness, pred_bbox_deltas) = self.head(features_list)
+ objectness, pred_bbox_deltas = self.head(features_list)
  anchors = self.anchor_generator(images, features_list)
 
  num_images = len(anchors)
  num_anchors_per_level_shape_tensors = [o[0].shape for o in objectness]
  num_anchors_per_level = [s[0] * s[1] * s[2] for s in num_anchors_per_level_shape_tensors]
- (objectness, pred_bbox_deltas) = concat_box_prediction_layers(objectness, pred_bbox_deltas)
+ objectness, pred_bbox_deltas = concat_box_prediction_layers(objectness, pred_bbox_deltas)
 
  # Apply pred_bbox_deltas to anchors to obtain the decoded proposals
  # note that we detach the deltas because Faster R-CNN do not backprop through
  # the proposals
  proposals = self.box_coder.decode(pred_bbox_deltas.detach(), anchors)
  proposals = proposals.view(num_images, -1, 4)
- (boxes, _scores) = self.filter_proposals(proposals, objectness, images.image_sizes, num_anchors_per_level)
+ boxes, _scores = self.filter_proposals(proposals, objectness, images.image_sizes, num_anchors_per_level)
 
  losses: dict[str, torch.Tensor] = {}
  if self.training is True:
  if targets is None:
  raise ValueError("targets should not be None")
 
- (labels, matched_gt_boxes) = self.assign_targets_to_anchors(anchors, targets)
+ labels, matched_gt_boxes = self.assign_targets_to_anchors(anchors, targets)
  regression_targets = self.box_coder.encode(matched_gt_boxes, anchors)
- (loss_objectness, loss_rpn_box_reg) = self.compute_loss(
+ loss_objectness, loss_rpn_box_reg = self.compute_loss(
  objectness, pred_bbox_deltas, labels, regression_targets
  )
  losses = {
@@ -405,7 +405,7 @@ class FastRCNNConvFCHead(nn.Sequential):
  fc_layers: list[int],
  norm_layer: Optional[Callable[..., nn.Module]] = None,
  ):
- (in_channels, in_height, in_width) = input_size
+ in_channels, in_height, in_width = input_size
 
  blocks = []
  previous_channels = in_channels
@@ -481,7 +481,7 @@ def faster_rcnn_loss(
  # advanced indexing
  sampled_pos_idxs_subset = torch.where(labels > 0)[0]
  labels_pos = labels[sampled_pos_idxs_subset]
- (N, _num_classes) = class_logits.shape
+ N, _num_classes = class_logits.shape
  box_regression = box_regression.reshape(N, box_regression.size(-1) // 4, 4)
 
  box_loss = F.smooth_l1_loss(
@@ -573,7 +573,7 @@ class RoIHeads(nn.Module):
  return (matched_idxs, labels)
 
  def subsample(self, labels: list[torch.Tensor]) -> list[torch.Tensor]:
- (sampled_pos_idxs, sampled_neg_idxs) = self.fg_bg_sampler(labels)
+ sampled_pos_idxs, sampled_neg_idxs = self.fg_bg_sampler(labels)
  sampled_idxs = []
  for pos_idxs_img, neg_idxs_img in zip(sampled_pos_idxs, sampled_neg_idxs):
  img_sampled_idxs = torch.where(pos_idxs_img | neg_idxs_img)[0]
@@ -610,7 +610,7 @@ class RoIHeads(nn.Module):
  proposals = self.add_gt_proposals(proposals, gt_boxes)
 
  # Get matching gt indices for each proposal
- (matched_idxs, labels) = self.assign_targets_to_proposals(proposals, gt_boxes, gt_labels)
+ matched_idxs, labels = self.assign_targets_to_proposals(proposals, gt_boxes, gt_labels)
 
  # Sample a fixed proportion of positive-negative proposals
  sampled_idxs = self.subsample(labels)
@@ -713,7 +713,7 @@ class RoIHeads(nn.Module):
  raise TypeError(f"target labels must of int64 type, instead got {t['labels'].dtype}")
 
  if self.training is True:
- (proposals, _matched_idxs, labels, regression_targets) = self.select_training_samples(proposals, targets)
+ proposals, _matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets)
  else:
  labels = None
  regression_targets = None
@@ -721,7 +721,7 @@ class RoIHeads(nn.Module):
 
  box_features = self.box_roi_pool(features, proposals, image_shapes)
  box_features = self.box_head(box_features)
- (class_logits, box_regression) = self.box_predictor(box_features)
+ class_logits, box_regression = self.box_predictor(box_features)
 
  losses = {}
  result: list[dict[str, torch.Tensor]] = []
@@ -731,11 +731,11 @@ class RoIHeads(nn.Module):
  if regression_targets is None:
  raise ValueError("regression_targets cannot be None")
 
- (loss_classifier, loss_box_reg) = faster_rcnn_loss(class_logits, box_regression, labels, regression_targets)
+ loss_classifier, loss_box_reg = faster_rcnn_loss(class_logits, box_regression, labels, regression_targets)
  losses = {"loss_classifier": loss_classifier, "loss_box_reg": loss_box_reg}
 
  else:
- (boxes, scores, labels) = self.postprocess_detections(class_logits, box_regression, proposals, image_shapes)
+ boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals, image_shapes)
  num_images = len(boxes)
  for i in range(num_images):
  result.append(
@@ -868,8 +868,8 @@ class Faster_RCNN(DetectionBaseNet):
  images = self._to_img_list(x, image_sizes)
 
  features = self.backbone_with_fpn(x)
- (proposals, proposal_losses) = self.rpn(images, features, targets)
- (detections, detector_losses) = self.roi_heads(features, proposals, images.image_sizes, targets)
+ proposals, proposal_losses = self.rpn(images, features, targets)
+ detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets)
 
  losses = {}
  losses.update(detector_losses)
@@ -125,7 +125,7 @@ class FCOSClassificationHead(nn.Module):
  cls_logits = self.cls_logits(cls_logits)
 
  # Permute classification output from (N, A * K, H, W) to (N, HWA, K).
- (N, _, H, W) = cls_logits.size()
+ N, _, H, W = cls_logits.size()
  cls_logits = cls_logits.view(N, -1, self.num_classes, H, W)
  cls_logits = cls_logits.permute(0, 3, 4, 1, 2)
  cls_logits = cls_logits.reshape(N, -1, self.num_classes) # (N, HWA, 4)
@@ -165,7 +165,7 @@ class FCOSRegressionHead(nn.Module):
  bbox_ctrness = self.bbox_ctrness(bbox_feature)
 
  # Permute bbox regression output from (N, 4 * A, H, W) to (N, HWA, 4).
- (N, _, H, W) = bbox_regression.size()
+ N, _, H, W = bbox_regression.size()
  bbox_regression = bbox_regression.view(N, -1, 4, H, W)
  bbox_regression = bbox_regression.permute(0, 3, 4, 1, 2)
  bbox_regression = bbox_regression.reshape(N, -1, 4) # (N, HWA, 4)
@@ -262,7 +262,7 @@ class FCOSHead(nn.Module):
 
  def forward(self, x: list[torch.Tensor]) -> dict[str, torch.Tensor]:
  cls_logits = self.classification_head(x)
- (bbox_regression, bbox_ctrness) = self.regression_head(x)
+ bbox_regression, bbox_ctrness = self.regression_head(x)
 
  return {
  "cls_logits": cls_logits,
@@ -370,8 +370,8 @@ class FCOS(DetectionBaseNet):
  ).values < self.center_sampling_radius * anchor_sizes[:, None]
 
  # Compute pairwise distance between N points and M boxes
- (x, y) = anchor_centers.unsqueeze(dim=2).unbind(dim=1) # (N, 1)
- (x0, y0, x1, y1) = gt_boxes.unsqueeze(dim=0).unbind(dim=2) # (1, M)
+ x, y = anchor_centers.unsqueeze(dim=2).unbind(dim=1) # (N, 1)
+ x0, y0, x1, y1 = gt_boxes.unsqueeze(dim=0).unbind(dim=2) # (1, M)
  pairwise_dist = torch.stack([x - x0, y - y0, x1 - x, y1 - y], dim=2) # (N, M)
 
  # Anchor point must be inside gt
@@ -388,7 +388,7 @@ class FCOS(DetectionBaseNet):
  # Match the GT box with minimum area, if there are multiple GT matches
  gt_areas = (gt_boxes[:, 2] - gt_boxes[:, 0]) * (gt_boxes[:, 3] - gt_boxes[:, 1]) # N
  pairwise_match = pairwise_match.to(torch.float32) * (1e8 - gt_areas[None, :])
- (min_values, matched_idx) = pairwise_match.max(dim=1) # R, per-anchor match
+ min_values, matched_idx = pairwise_match.max(dim=1) # R, per-anchor match
  matched_idx[min_values < 1e-5] = -1 # Unmatched anchors are assigned -1
 
  matched_idxs.append(matched_idx)
@@ -433,7 +433,7 @@ class FCOS(DetectionBaseNet):
 
  # Keep only topk scoring predictions
  num_topk = min(self.topk_candidates, int(topk_idxs.size(0)))
- (scores_per_level, idxs) = scores_per_level.topk(num_topk)
+ scores_per_level, idxs = scores_per_level.topk(num_topk)
  topk_idxs = topk_idxs[idxs]
 
  anchor_idxs = torch.div(topk_idxs, num_classes, rounding_mode="floor")
@@ -455,7 +455,7 @@ class FCOS(DetectionBaseNet):
 
  # Non-maximum suppression
  if self.soft_nms is not None:
- (soft_scores, keep) = self.soft_nms(image_boxes, image_scores, image_labels, score_threshold=0.001)
+ soft_scores, keep = self.soft_nms(image_boxes, image_scores, image_labels, score_threshold=0.001)
  image_scores[keep] = soft_scores
  else:
  keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh)