birder 0.4.1__py3-none-any.whl → 0.4.2__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (40)
  1. birder/common/training_cli.py +6 -1
  2. birder/common/training_utils.py +69 -12
  3. birder/net/_vit_configs.py +5 -0
  4. birder/net/cait.py +3 -3
  5. birder/net/coat.py +3 -3
  6. birder/net/deit.py +1 -1
  7. birder/net/deit3.py +1 -1
  8. birder/net/detection/__init__.py +2 -0
  9. birder/net/detection/deformable_detr.py +12 -12
  10. birder/net/detection/detr.py +7 -7
  11. birder/net/detection/lw_detr.py +1181 -0
  12. birder/net/detection/plain_detr.py +7 -5
  13. birder/net/detection/retinanet.py +1 -1
  14. birder/net/detection/rt_detr_v1.py +10 -10
  15. birder/net/detection/rt_detr_v2.py +47 -64
  16. birder/net/detection/ssdlite.py +2 -2
  17. birder/net/edgevit.py +3 -3
  18. birder/net/efficientvit_msft.py +1 -1
  19. birder/net/flexivit.py +1 -1
  20. birder/net/hieradet.py +2 -2
  21. birder/net/mnasnet.py +2 -2
  22. birder/net/resnext.py +2 -2
  23. birder/net/rope_deit3.py +1 -1
  24. birder/net/rope_flexivit.py +1 -1
  25. birder/net/rope_vit.py +1 -1
  26. birder/net/simple_vit.py +1 -1
  27. birder/net/vit.py +21 -3
  28. birder/net/vit_parallel.py +1 -1
  29. birder/net/vit_sam.py +62 -16
  30. birder/scripts/train.py +12 -8
  31. birder/scripts/train_capi.py +13 -10
  32. birder/scripts/train_detection.py +2 -1
  33. birder/scripts/train_kd.py +12 -8
  34. birder/version.py +1 -1
  35. {birder-0.4.1.dist-info → birder-0.4.2.dist-info}/METADATA +3 -3
  36. {birder-0.4.1.dist-info → birder-0.4.2.dist-info}/RECORD +40 -39
  37. {birder-0.4.1.dist-info → birder-0.4.2.dist-info}/WHEEL +1 -1
  38. {birder-0.4.1.dist-info → birder-0.4.2.dist-info}/entry_points.txt +0 -0
  39. {birder-0.4.1.dist-info → birder-0.4.2.dist-info}/licenses/LICENSE +0 -0
  40. {birder-0.4.1.dist-info → birder-0.4.2.dist-info}/top_level.txt +0 -0
@@ -56,7 +56,9 @@ def add_optimization_args(parser: argparse.ArgumentParser, default_batch_size: i
     )


-def add_lr_wd_args(parser: argparse.ArgumentParser, backbone_lr: bool = False, wd_end: bool = False) -> None:
+def add_lr_wd_args(
+    parser: argparse.ArgumentParser, backbone_lr: bool = False, wd_end: bool = False, backbone_layer_decay: bool = False
+) -> None:
     group = parser.add_argument_group("Learning rate and regularization parameters")
     group.add_argument("--lr", type=float, default=0.1, metavar="LR", help="base learning rate")
     group.add_argument("--bias-lr", type=float, metavar="LR", help="learning rate of biases")
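Illustrative sketch only (assumes birder is installed): together with the --backbone-layer-decay argument registered in the next hunk, the new backbone_layer_decay switch surfaces on the command line roughly like this:

    import argparse

    from birder.common.training_cli import add_lr_wd_args

    parser = argparse.ArgumentParser()
    add_lr_wd_args(parser, backbone_layer_decay=True)

    # --backbone-layer-decay only exists because backbone_layer_decay=True was passed above
    args = parser.parse_args(["--lr", "1e-4", "--backbone-layer-decay", "0.9"])
    print(args.backbone_layer_decay)  # 0.9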
@@ -92,6 +94,9 @@ def add_lr_wd_args(parser: argparse.ArgumentParser, backbone_lr: bool = False, w
         help="custom weight decay for specific layers by name (e.g., offset_conv=0.0)",
     )
     group.add_argument("--layer-decay", type=float, help="layer-wise learning rate decay (LLRD)")
+    if backbone_layer_decay is True:
+        group.add_argument("--backbone-layer-decay", type=float, help="backbone layer-wise learning rate decay (LLRD)")
+
     group.add_argument("--layer-decay-min-scale", type=float, help="minimum layer scale factor clamp value")
     group.add_argument(
         "--layer-decay-no-opt-scale", type=float, help="layer scale threshold below which parameters are frozen"
@@ -343,7 +343,7 @@ def count_layers(model: torch.nn.Module) -> int:
     return num_layers


-# pylint: disable=protected-access,too-many-locals,too-many-branches
+# pylint: disable=protected-access,too-many-locals,too-many-branches,too-many-statements
 def optimizer_parameter_groups(
     model: torch.nn.Module,
     weight_decay: float,
@@ -352,6 +352,7 @@ def optimizer_parameter_groups(
     custom_keys_weight_decay: Optional[list[tuple[str, float]]] = None,
     custom_layer_weight_decay: Optional[dict[str, float]] = None,
     layer_decay: Optional[float] = None,
+    backbone_layer_decay: Optional[float] = None,
     layer_decay_min_scale: Optional[float] = None,
     layer_decay_no_opt_scale: Optional[float] = None,
     bias_lr: Optional[float] = None,
@@ -388,6 +389,8 @@ def optimizer_parameter_groups(
         Applied to parameters whose names contain the specified keys.
     layer_decay
         Layer-wise learning rate decay factor.
+    backbone_layer_decay
+        Layer-wise learning rate decay factor for backbone parameters only.
     layer_decay_min_scale
         Minimum learning rate scale factor when using layer decay. Prevents layers from having too small learning rates.
     layer_decay_no_opt_scale
@@ -434,6 +437,27 @@ def optimizer_parameter_groups(
         if layer_decay is not None:
             logger.warning("Assigning lr scaling (layer decay) without a block group map")

+    backbone_group_map: dict[str, int] = {}
+    backbone_num_layers = 0
+    if backbone_layer_decay is not None:
+        backbone_module = getattr(model, "backbone", None)
+        if backbone_module is None:
+            logger.warning("Backbone layer decay requested but model has no backbone")
+            backbone_layer_decay = None
+        else:
+            backbone_block_group_regex = getattr(backbone_module, "block_group_regex", None)
+            if backbone_block_group_regex is not None:
+                names = [n for n, _ in backbone_module.named_parameters()]
+                groups = group_by_regex(names, backbone_block_group_regex)
+                backbone_group_map = {
+                    f"backbone.{item}": index for index, sublist in enumerate(groups) for item in sublist
+                }
+                backbone_num_layers = len(groups)
+            else:
+                backbone_group_map = {}
+                backbone_num_layers = count_layers(backbone_module)
+                logger.warning("Assigning lr scaling (backbone layer decay) without a block group map")
+
     # Build layer scale
     if layer_decay_min_scale is None:
         layer_decay_min_scale = 0.0
@@ -444,14 +468,28 @@ def optimizer_parameter_groups(
     layer_scales = [max(layer_decay_min_scale, layer_decay ** (layer_max - i)) for i in range(num_layers)]
     logger.info(f"Layer scaling ranges from {min(layer_scales)} to {max(layer_scales)} across {num_layers} layers")

+    backbone_layer_scales = []
+    if backbone_layer_decay is not None:
+        backbone_layer_max = backbone_num_layers - 1
+        backbone_layer_scales = [
+            max(layer_decay_min_scale, backbone_layer_decay ** (backbone_layer_max - i))
+            for i in range(backbone_num_layers)
+        ]
+        logger.info(
+            "Backbone layer scaling ranges from "
+            f"{min(backbone_layer_scales)} to {max(backbone_layer_scales)} across {backbone_num_layers} layers"
+        )
+
     # Set weight decay and layer decay
     idx = 0
+    backbone_idx = 0
     params = []
     module_stack_with_prefix = [(model, "")]
     visited_modules = []
     while len(module_stack_with_prefix) > 0:  # pylint: disable=too-many-nested-blocks
         skip_module = False
         module, prefix = module_stack_with_prefix.pop()
+        is_backbone_module = prefix == "backbone" or prefix.startswith("backbone.")
         if id(module) in visited_modules:
             skip_module = True
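To make the schedule concrete, here is a standalone re-run of the same formula with made-up values (not taken from the package): earlier backbone blocks get progressively smaller learning-rate scales, while the last block keeps the full rate.

    backbone_layer_decay = 0.5
    backbone_num_layers = 4
    layer_decay_min_scale = 0.0
    backbone_layer_max = backbone_num_layers - 1
    backbone_layer_scales = [
        max(layer_decay_min_scale, backbone_layer_decay ** (backbone_layer_max - i))
        for i in range(backbone_num_layers)
    ]
    print(backbone_layer_scales)  # [0.125, 0.25, 0.5, 1.0]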
@@ -460,23 +498,35 @@ def optimizer_parameter_groups(
         for name, p in module.named_parameters(recurse=False):
             target_name = f"{prefix}.{name}" if prefix != "" else name
             idx = group_map.get(target_name, idx)
+            is_backbone_param = target_name.startswith("backbone.")
+            if backbone_layer_decay is not None and is_backbone_param is True:
+                backbone_idx = backbone_group_map.get(target_name, backbone_idx)
             if skip_module is True:
                 break

             parameters_found = True
             if p.requires_grad is False:
                 continue
-            if layer_decay is not None and layer_decay_no_opt_scale is not None:
-                if layer_scales[idx] < layer_decay_no_opt_scale:
-                    p.requires_grad_(False)
+            if layer_decay_no_opt_scale is not None:
+                if backbone_layer_decay is not None and is_backbone_param is True:
+                    if backbone_layer_scales and backbone_layer_scales[backbone_idx] < layer_decay_no_opt_scale:
+                        p.requires_grad_(False)
+                elif layer_decay is not None:
+                    if layer_scales[idx] < layer_decay_no_opt_scale:
+                        p.requires_grad_(False)

             is_custom_key = False
             if custom_keys_weight_decay is not None:
                 for key, custom_wd in custom_keys_weight_decay:
                     target_name_for_custom_key = f"{prefix}.{name}" if prefix != "" and "." in key else name
                     if key == target_name_for_custom_key:
-                        # Calculate lr_scale (from layer_decay or custom_layer_lr_scale)
-                        lr_scale = 1.0 if layer_decay is None else layer_scales[idx]
+                        # Calculate lr_scale (from layer_decay/backbone_layer_decay or custom_layer_lr_scale)
+                        if layer_decay is not None and (backbone_layer_decay is None or is_backbone_param is False):
+                            lr_scale = layer_scales[idx]
+                        elif backbone_layer_decay is not None and is_backbone_param is True:
+                            lr_scale = backbone_layer_scales[backbone_idx]
+                        else:
+                            lr_scale = 1.0
                         if custom_layer_lr_scale is not None:
                             for layer_name_key, custom_scale in custom_layer_lr_scale.items():
                                 if layer_name_key in target_name:
@@ -500,8 +550,8 @@ def optimizer_parameter_groups(
                         # Apply learning rate based on priority: bias_lr > backbone_lr > lr_scale
                         if bias_lr is not None and target_name.endswith(".bias") is True:
                             d["lr"] = bias_lr
-                        elif backbone_lr is not None and target_name.startswith("backbone.") is True:
-                            d["lr"] = backbone_lr
+                        elif backbone_lr is not None and is_backbone_param is True:
+                            d["lr"] = backbone_lr * lr_scale if backbone_layer_decay is not None else backbone_lr
                         elif lr_scale != 1.0:
                             d["lr"] = base_lr * lr_scale

@@ -522,8 +572,13 @@ def optimizer_parameter_groups(
                         wd = custom_wd_value
                         break

-            # Calculate lr_scale (from layer_decay or custom_layer_lr_scale)
-            lr_scale = 1.0 if layer_decay is None else layer_scales[idx]
+            # Calculate lr_scale (from layer_decay/backbone_layer_decay or custom_layer_lr_scale)
+            if layer_decay is not None and (backbone_layer_decay is None or is_backbone_param is False):
+                lr_scale = layer_scales[idx]
+            elif backbone_layer_decay is not None and is_backbone_param is True:
+                lr_scale = backbone_layer_scales[backbone_idx]
+            else:
+                lr_scale = 1.0
             if custom_layer_lr_scale is not None:
                 for layer_name_key, custom_scale in custom_layer_lr_scale.items():
                     if layer_name_key in target_name:
@@ -539,8 +594,8 @@ def optimizer_parameter_groups(
             # Apply learning rate based on priority: bias_lr > backbone_lr > lr_scale
             if bias_lr is not None and target_name.endswith(".bias") is True:
                 d["lr"] = bias_lr
-            elif backbone_lr is not None and target_name.startswith("backbone.") is True:
-                d["lr"] = backbone_lr
+            elif backbone_lr is not None and is_backbone_param is True:
+                d["lr"] = backbone_lr * lr_scale if backbone_layer_decay is not None else backbone_lr
             elif lr_scale != 1.0:
                 d["lr"] = base_lr * lr_scale

@@ -548,6 +603,8 @@ def optimizer_parameter_groups(

         if parameters_found is True:
             idx += 1
+            if is_backbone_module is True:
+                backbone_idx += 1

         for child_name, child_module in reversed(list(module.named_children())):
             child_prefix = f"{prefix}.{child_name}" if prefix != "" else child_name
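A hedged, standalone illustration of the resulting priority (bias_lr > backbone_lr > lr_scale), using made-up numbers outside the real parameter-group loop:

    base_lr = 1e-3
    backbone_lr = 1e-4
    backbone_layer_decay = 0.9
    lr_scale = backbone_layer_decay ** 5  # e.g. a backbone block five levels below the top
    is_bias = False
    is_backbone_param = True

    if is_bias:
        lr = 1e-5  # a bias_lr value would win here
    elif is_backbone_param:
        # New in 0.4.2: the backbone lr is additionally scaled when backbone_layer_decay is set
        lr = backbone_lr * lr_scale if backbone_layer_decay is not None else backbone_lr
    else:
        lr = base_lr * lr_scale
    print(lr)  # ~5.9e-05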
@@ -111,6 +111,11 @@ def register_vit_configs(vit: type[BaseNet]) -> None:
         vit,
         config={"patch_size": 16, **BASE, "layer_scale_init_value": 1e-5},
     )
+    registry.register_model_config(
+        "vit_b16_pn",
+        vit,
+        config={"patch_size": 16, **BASE, "pre_norm": True, "norm_layer_eps": 1e-5},
+    )
     registry.register_model_config(
         "vit_b16_qkn_ls",
         vit,
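For orientation only: the new vit_b16_pn entry layers a pre-norm flag and a LayerNorm epsilon on top of the shared ViT-Base settings. The BASE keys and values below are an assumption (the real dict is defined elsewhere in _vit_configs.py), shown purely to illustrate how the config resolves:

    BASE = {"num_layers": 12, "num_heads": 12, "hidden_dim": 768, "mlp_dim": 3072}  # assumed values
    config = {"patch_size": 16, **BASE, "pre_norm": True, "norm_layer_eps": 1e-5}
    print(config["pre_norm"], config["norm_layer_eps"])  # True 1e-05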
birder/net/cait.py CHANGED
@@ -231,11 +231,11 @@ class CaiT(BaseNet):
             if isinstance(m, nn.Linear):
                 nn.init.trunc_normal_(m.weight, std=0.02)
                 if m.bias is not None:
-                    nn.init.constant_(m.bias, 0)
+                    nn.init.zeros_(m.bias)

             elif isinstance(m, nn.LayerNorm):
-                nn.init.constant_(m.bias, 0)
-                nn.init.constant_(m.weight, 1.0)
+                nn.init.zeros_(m.bias)
+                nn.init.ones_(m.weight)

         nn.init.trunc_normal_(self.pos_embed, std=0.02)
         nn.init.trunc_normal_(self.cls_token, std=0.02)
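The initializer swap here (and in CoaT below) is purely cosmetic; zeros_ and ones_ are the modern spellings of the constant_ calls they replace. A quick standalone check, not taken from the package:

    import torch
    from torch import nn

    m = nn.Linear(4, 4)
    nn.init.constant_(m.bias, 0)
    before = m.bias.detach().clone()
    nn.init.zeros_(m.bias)   # new spelling
    nn.init.ones_(m.weight)  # equivalent to constant_(m.weight, 1.0)
    assert torch.equal(before, m.bias)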
birder/net/coat.py CHANGED
@@ -474,11 +474,11 @@ class CoaT(DetectorBackbone):
             if isinstance(m, nn.Linear):
                 nn.init.trunc_normal_(m.weight, std=0.02)
                 if m.bias is not None:
-                    nn.init.constant_(m.bias, 0)
+                    nn.init.zeros_(m.bias)

             elif isinstance(m, nn.LayerNorm):
-                nn.init.constant_(m.bias, 0)
-                nn.init.constant_(m.weight, 1.0)
+                nn.init.zeros_(m.bias)
+                nn.init.ones_(m.weight)

         nn.init.trunc_normal_(self.cls_token1, std=0.02)
         nn.init.trunc_normal_(self.cls_token2, std=0.02)
birder/net/deit.py CHANGED
@@ -167,7 +167,7 @@ class DeiT(DetectorBackbone):
         xs = self.encoder.forward_features(x, out_indices=self.out_indices)

         out: dict[str, torch.Tensor] = {}
-        for stage_name, stage_x in zip(self.return_stages, xs):
+        for stage_name, stage_x in zip(self.return_stages, xs, strict=True):
             stage_x = stage_x[:, self.num_special_tokens :]
             stage_x = stage_x.permute(0, 2, 1)
             B, C, _ = stage_x.size()
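The strict=True added in DeiT (and DeiT3 below) requires Python 3.10+ and turns a silent truncation into an error if the number of returned feature maps ever drifts from return_stages. Standalone illustration:

    stages = ["stage1", "stage2", "stage3"]
    features = ["f1", "f2"]  # one feature map missing
    try:
        list(zip(stages, features, strict=True))
    except ValueError as exc:
        print(exc)  # zip() argument 2 is shorter than argument 1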
birder/net/deit3.py CHANGED
@@ -185,7 +185,7 @@ class DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedT
         xs = self.encoder.forward_features(x, out_indices=self.out_indices)

         out: dict[str, torch.Tensor] = {}
-        for stage_name, stage_x in zip(self.return_stages, xs):
+        for stage_name, stage_x in zip(self.return_stages, xs, strict=True):
             stage_x = stage_x[:, self.num_special_tokens :]
             stage_x = stage_x.permute(0, 2, 1)
             B, C, _ = stage_x.size()
@@ -3,6 +3,7 @@ from birder.net.detection.detr import DETR
 from birder.net.detection.efficientdet import EfficientDet
 from birder.net.detection.faster_rcnn import Faster_RCNN
 from birder.net.detection.fcos import FCOS
+from birder.net.detection.lw_detr import LW_DETR
 from birder.net.detection.plain_detr import Plain_DETR
 from birder.net.detection.retinanet import RetinaNet
 from birder.net.detection.rt_detr_v1 import RT_DETR_v1
@@ -21,6 +22,7 @@ __all__ = [
     "EfficientDet",
     "Faster_RCNN",
     "FCOS",
+    "LW_DETR",
     "Plain_DETR",
     "RetinaNet",
     "RT_DETR_v1",
@@ -56,7 +56,7 @@ class HungarianMatcher(nn.Module):
     @torch.jit.unused  # type: ignore[untyped-decorator]
     def forward(
         self, class_logits: torch.Tensor, box_regression: torch.Tensor, targets: list[dict[str, torch.Tensor]]
-    ) -> list[torch.Tensor]:
+    ) -> list[tuple[torch.Tensor, torch.Tensor]]:
         with torch.no_grad():
             B, num_queries = class_logits.shape[:2]

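The tightened return annotation reflects what the matcher already produces: one (prediction indices, target indices) pair per image, which _get_src_permutation_idx then flattens. A hypothetical value, for illustration only:

    import torch

    indices: list[tuple[torch.Tensor, torch.Tensor]] = [
        (torch.tensor([3, 17]), torch.tensor([0, 1])),  # image 0: queries 3 and 17 matched to targets 0 and 1
        (torch.tensor([5]), torch.tensor([0])),         # image 1: query 5 matched to target 0
    ]
    batch_idx = torch.concat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
    print(batch_idx)  # tensor([0, 0, 1])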
@@ -135,7 +135,7 @@ class MultiScaleDeformableAttention(nn.Module):
         self.reset_parameters()

     def reset_parameters(self) -> None:
-        nn.init.constant_(self.sampling_offsets.weight, 0.0)
+        nn.init.zeros_(self.sampling_offsets.weight)
         thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
         grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
         grid_init = (
@@ -149,12 +149,12 @@ class MultiScaleDeformableAttention(nn.Module):
         with torch.no_grad():
             self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))

-        nn.init.constant_(self.attention_weights.weight, 0.0)
-        nn.init.constant_(self.attention_weights.bias, 0.0)
+        nn.init.zeros_(self.attention_weights.weight)
+        nn.init.zeros_(self.attention_weights.bias)
         nn.init.xavier_uniform_(self.value_proj.weight)
-        nn.init.constant_(self.value_proj.bias, 0.0)
+        nn.init.zeros_(self.value_proj.bias)
         nn.init.xavier_uniform_(self.output_proj.weight)
-        nn.init.constant_(self.output_proj.bias, 0.0)
+        nn.init.zeros_(self.output_proj.bias)

     def forward(
         self,
@@ -279,11 +279,10 @@ class DeformableTransformerDecoderLayer(nn.Module):
         self_attn_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         # Self attention
-        q = tgt + query_pos
-        k = tgt + query_pos
+        q_k = tgt + query_pos

         tgt2, _ = self.self_attn(
-            q.transpose(0, 1), k.transpose(0, 1), tgt.transpose(0, 1), need_weights=False, attn_mask=self_attn_mask
+            q_k.transpose(0, 1), q_k.transpose(0, 1), tgt.transpose(0, 1), need_weights=False, attn_mask=self_attn_mask
         )
         tgt2 = tgt2.transpose(0, 1)
         tgt = tgt + self.dropout(tgt2)
@@ -587,7 +586,7 @@ class Deformable_DETR(DetectionBaseNet):

         self.query_embed = nn.Embedding(num_queries, hidden_dim * 2)
         self.pos_enc = PositionEmbeddingSine(hidden_dim // 2, normalize=True)
-        self.matcher = HungarianMatcher(cost_class=2, cost_bbox=5, cost_giou=2)
+        self.matcher = HungarianMatcher(cost_class=2.0, cost_bbox=5.0, cost_giou=2.0)

         class_embed = nn.Linear(hidden_dim, self.num_classes)
         bbox_embed = MLP(hidden_dim, [hidden_dim, hidden_dim, 4], activation_layer=nn.ReLU)
@@ -641,7 +640,8 @@ class Deformable_DETR(DetectionBaseNet):
         for param in self.class_embed.parameters():
             param.requires_grad_(True)

-    def _get_src_permutation_idx(self, indices: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+    @staticmethod
+    def _get_src_permutation_idx(indices: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
         batch_idx = torch.concat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
         src_idx = torch.concat([src for (src, _) in indices])
         return (batch_idx, src_idx)
@@ -709,7 +709,7 @@ class Deformable_DETR(DetectionBaseNet):
         if training_utils.is_dist_available_and_initialized() is True:
             torch.distributed.all_reduce(num_boxes)

-        num_boxes = torch.clamp(num_boxes / training_utils.get_world_size(), min=1).item()
+        num_boxes = torch.clamp(num_boxes / training_utils.get_world_size(), min=1)

         loss_ce_list = []
         loss_bbox_list = []
@@ -49,7 +49,7 @@ class HungarianMatcher(nn.Module):
     @torch.jit.unused  # type: ignore[untyped-decorator]
     def forward(
         self, class_logits: torch.Tensor, box_regression: torch.Tensor, targets: list[dict[str, torch.Tensor]]
-    ) -> list[torch.Tensor]:
+    ) -> list[tuple[torch.Tensor, torch.Tensor]]:
         with torch.no_grad():
             B, num_queries = class_logits.shape[:2]

@@ -148,10 +148,9 @@ class TransformerDecoderLayer(nn.Module):
         query_pos: torch.Tensor,
         memory_key_padding_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        q = tgt + query_pos
-        k = tgt + query_pos
+        q_k = tgt + query_pos

-        tgt2, _ = self.self_attn(q, k, value=tgt, need_weights=False)
+        tgt2, _ = self.self_attn(q_k, q_k, value=tgt, need_weights=False)
         tgt = tgt + self.dropout1(tgt2)
         tgt = self.norm1(tgt)
         tgt2, _ = self.multihead_attn(
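Since q and k were literally the same tensor (tgt + query_pos), folding them into a single q_k is a behavior-preserving refactor; MultiheadAttention receives identical inputs. Minimal sanity sketch with toy shapes, not package code:

    import torch
    from torch import nn

    attn = nn.MultiheadAttention(embed_dim=8, num_heads=2)
    tgt = torch.randn(5, 2, 8)  # (sequence, batch, embed)
    query_pos = torch.randn(5, 2, 8)
    q_k = tgt + query_pos
    out, _ = attn(q_k, q_k, tgt, need_weights=False)
    print(out.shape)  # torch.Size([5, 2, 8])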
@@ -341,7 +340,7 @@ class DETR(DetectionBaseNet):
         )
         self.pos_enc = PositionEmbeddingSine(hidden_dim // 2, normalize=True)

-        self.matcher = HungarianMatcher(cost_class=1, cost_bbox=5, cost_giou=2)
+        self.matcher = HungarianMatcher(cost_class=1.0, cost_bbox=5.0, cost_giou=2.0)
         empty_weight = torch.ones(self.num_classes)
         empty_weight[0] = 0.1
         self.empty_weight = nn.Buffer(empty_weight)
@@ -365,7 +364,8 @@ class DETR(DetectionBaseNet):
         for param in self.class_embed.parameters():
             param.requires_grad_(True)

-    def _get_src_permutation_idx(self, indices: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+    @staticmethod
+    def _get_src_permutation_idx(indices: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
         batch_idx = torch.concat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
         src_idx = torch.concat([src for (src, _) in indices])
         return (batch_idx, src_idx)
@@ -422,7 +422,7 @@ class DETR(DetectionBaseNet):
         if training_utils.is_dist_available_and_initialized() is True:
             torch.distributed.all_reduce(num_boxes)

-        num_boxes = torch.clamp(num_boxes / training_utils.get_world_size(), min=1).item()
+        num_boxes = torch.clamp(num_boxes / training_utils.get_world_size(), min=1)

         loss_ce_list = []
         loss_bbox_list = []
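Dropping .item() in both DETR variants keeps num_boxes as a tensor (avoiding a device-to-host sync); dividing the losses by a 0-dim tensor broadcasts exactly like a Python float. Standalone illustration, not the package code:

    import torch

    num_boxes = torch.tensor(7.0)
    world_size = 1  # stand-in for training_utils.get_world_size()
    num_boxes = torch.clamp(num_boxes / world_size, min=1)
    loss_bbox = torch.tensor(3.5)
    print(loss_bbox / num_boxes)  # tensor(0.5000)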