birder 0.2.3__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82)
  1. birder/common/training_cli.py +6 -0
  2. birder/common/training_utils.py +215 -31
  3. birder/data/collators/detection.py +1 -0
  4. birder/data/dataloader/webdataset.py +12 -2
  5. birder/kernels/load_kernel.py +16 -11
  6. birder/kernels/soft_nms/soft_nms.cpp +17 -18
  7. birder/net/cait.py +4 -3
  8. birder/net/convnext_v1.py +5 -0
  9. birder/net/crossformer.py +33 -30
  10. birder/net/crossvit.py +4 -3
  11. birder/net/deit.py +3 -3
  12. birder/net/deit3.py +3 -3
  13. birder/net/detection/deformable_detr.py +2 -5
  14. birder/net/detection/detr.py +2 -5
  15. birder/net/detection/efficientdet.py +2 -7
  16. birder/net/detection/fcos.py +2 -7
  17. birder/net/detection/retinanet.py +2 -7
  18. birder/net/detection/rt_detr_v1.py +1 -0
  19. birder/net/efficientformer_v1.py +15 -9
  20. birder/net/efficientformer_v2.py +39 -29
  21. birder/net/efficientvit_msft.py +9 -7
  22. birder/net/fastvit.py +1 -0
  23. birder/net/flexivit.py +5 -4
  24. birder/net/hiera.py +12 -9
  25. birder/net/hornet.py +9 -7
  26. birder/net/iformer.py +8 -6
  27. birder/net/levit.py +42 -30
  28. birder/net/lit_v1_tiny.py +15 -0
  29. birder/net/maxvit.py +67 -55
  30. birder/net/mobileone.py +1 -0
  31. birder/net/mvit_v2.py +13 -12
  32. birder/net/pit.py +4 -3
  33. birder/net/pvt_v1.py +4 -1
  34. birder/net/repghost.py +1 -0
  35. birder/net/repvgg.py +1 -0
  36. birder/net/repvit.py +1 -0
  37. birder/net/rope_deit3.py +5 -3
  38. birder/net/rope_flexivit.py +7 -4
  39. birder/net/rope_vit.py +10 -5
  40. birder/net/simple_vit.py +9 -6
  41. birder/net/swin_transformer_v1.py +71 -68
  42. birder/net/swin_transformer_v2.py +38 -31
  43. birder/net/tiny_vit.py +20 -10
  44. birder/net/transnext.py +38 -28
  45. birder/net/vit.py +5 -4
  46. birder/net/vit_parallel.py +5 -4
  47. birder/net/vit_sam.py +38 -37
  48. birder/net/vovnet_v1.py +15 -0
  49. birder/ops/msda.py +108 -43
  50. birder/ops/swattention.py +124 -61
  51. birder/results/detection.py +4 -0
  52. birder/scripts/benchmark.py +21 -12
  53. birder/scripts/predict.py +7 -0
  54. birder/scripts/train.py +39 -13
  55. birder/scripts/train_barlow_twins.py +35 -12
  56. birder/scripts/train_byol.py +35 -12
  57. birder/scripts/train_capi.py +41 -15
  58. birder/scripts/train_data2vec.py +37 -14
  59. birder/scripts/train_data2vec2.py +37 -14
  60. birder/scripts/train_detection.py +36 -11
  61. birder/scripts/train_dino_v1.py +51 -14
  62. birder/scripts/train_dino_v2.py +78 -19
  63. birder/scripts/train_dino_v2_dist.py +76 -17
  64. birder/scripts/train_franca.py +43 -19
  65. birder/scripts/train_i_jepa.py +37 -14
  66. birder/scripts/train_ibot.py +43 -20
  67. birder/scripts/train_kd.py +39 -13
  68. birder/scripts/train_mim.py +35 -12
  69. birder/scripts/train_mmcr.py +35 -12
  70. birder/scripts/train_rotnet.py +36 -13
  71. birder/scripts/train_simclr.py +35 -12
  72. birder/scripts/train_vicreg.py +35 -12
  73. birder/tools/convert_model.py +18 -15
  74. birder/tools/det_results.py +114 -2
  75. birder/tools/quantize_model.py +73 -67
  76. birder/version.py +1 -1
  77. {birder-0.2.3.dist-info → birder-0.3.0.dist-info}/METADATA +2 -1
  78. {birder-0.2.3.dist-info → birder-0.3.0.dist-info}/RECORD +82 -82
  79. {birder-0.2.3.dist-info → birder-0.3.0.dist-info}/WHEEL +0 -0
  80. {birder-0.2.3.dist-info → birder-0.3.0.dist-info}/entry_points.txt +0 -0
  81. {birder-0.2.3.dist-info → birder-0.3.0.dist-info}/licenses/LICENSE +0 -0
  82. {birder-0.2.3.dist-info → birder-0.3.0.dist-info}/top_level.txt +0 -0
birder/net/detection/fcos.py CHANGED
@@ -455,13 +455,8 @@ class FCOS(DetectionBaseNet):
 
             # Non-maximum suppression
             if self.soft_nms is not None:
-                # Actually much faster on CPU
-                device = image_boxes.device
-                (soft_scores, keep) = self.soft_nms(
-                    image_boxes.cpu(), image_scores.cpu(), image_labels.cpu(), score_threshold=0.001
-                )
-                keep = keep.to(device)
-                image_scores[keep] = soft_scores.to(device)
+                (soft_scores, keep) = self.soft_nms(image_boxes, image_scores, image_labels, score_threshold=0.001)
+                image_scores[keep] = soft_scores
             else:
                 keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh)
 
birder/net/detection/retinanet.py CHANGED
@@ -417,13 +417,8 @@ class RetinaNet(DetectionBaseNet):
 
         # Non-maximum suppression
         if self.soft_nms is not None:
-            # Actually much faster on CPU
-            device = image_boxes.device
-            (soft_scores, keep) = self.soft_nms(
-                image_boxes.cpu(), image_scores.cpu(), image_labels.cpu(), score_threshold=0.001
-            )
-            keep = keep.to(device)
-            image_scores[keep] = soft_scores.to(device)
+            (soft_scores, keep) = self.soft_nms(image_boxes, image_scores, image_labels, score_threshold=0.001)
+            image_scores[keep] = soft_scores
        else:
            keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh)
 
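Both detection heads drop the CPU round-trip: 0.2.3 moved boxes, scores, and labels to the CPU before calling soft-NMS (the deleted comment claimed the CPU path was faster), while 0.3.0 calls the kernel on whatever device the tensors already live on, in line with the soft_nms.cpp changes listed above. For context, here is a minimal single-class sketch of linear soft-NMS; this is a standalone illustration of the technique, not birder's kernel (which also takes labels and is implemented in C++):

```python
import torch
from torchvision.ops import box_iou


def soft_nms_linear(
    boxes: torch.Tensor,
    scores: torch.Tensor,
    iou_threshold: float = 0.3,
    score_threshold: float = 0.001,
) -> tuple[torch.Tensor, torch.Tensor]:
    # Work on a copy so the caller's scores are left untouched
    scores = scores.clone()
    selected = []
    active = torch.arange(boxes.size(0), device=boxes.device)
    while active.numel() > 0:
        # Pick the highest-scoring box still in play
        top = scores[active].argmax()
        selected.append(active[top])
        active = torch.cat([active[:top], active[top + 1 :]])
        if active.numel() == 0:
            break

        # Linear decay: overlapping boxes lose score in proportion to their IoU,
        # instead of being discarded outright as in hard NMS
        ious = box_iou(boxes[selected[-1]].unsqueeze(0), boxes[active]).squeeze(0)
        scores[active] *= torch.where(ious > iou_threshold, 1.0 - ious, torch.ones_like(ious))

        # Boxes whose score fell below the threshold are pruned entirely
        active = active[scores[active] > score_threshold]

    keep = torch.stack(selected) if selected else torch.empty(0, dtype=torch.long, device=boxes.device)
    return (scores[keep], keep)
```

The call shape mirrors the new code above: keep indexes back into the input boxes, and the decayed soft_scores are written over the kept entries of image_scores.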
birder/net/detection/rt_detr_v1.py CHANGED
@@ -1087,6 +1087,7 @@ class RT_DETR_v1(DetectionBaseNet):
 
         return (detections, losses)
 
+    @torch.no_grad()  # type: ignore[untyped-decorator]
     def reparameterize_model(self) -> None:
         if self.reparameterized is True:
             return
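The same decorator lands on reparameterize_model in fastvit.py, mobileone.py, repghost.py, repvgg.py, and repvit.py (see the file list above). Running the whole method under torch.no_grad() keeps the branch-fusion arithmetic out of the autograd graph, and the `# type: ignore[untyped-decorator]` silences mypy, which treats torch.no_grad() as an untyped decorator. A schematic of the pattern, using a hypothetical class rather than birder's code:

```python
import torch
from torch import nn


class ReparamNet(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.reparameterized = False

    @torch.no_grad()  # type: ignore[untyped-decorator]
    def reparameterize_model(self) -> None:
        if self.reparameterized is True:
            return

        for module in self.modules():
            # Each re-parameterizable block folds its training-time branches here
            if hasattr(module, "reparameterize") is True:
                module.reparameterize()

        self.reparameterized = True
```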
birder/net/efficientformer_v1.py CHANGED
@@ -357,16 +357,22 @@ class EfficientFormer_v1(BaseNet):
         resolution = (int(new_size[0] / (2**5)), int(new_size[1] / (2**5)))
         for m in self.body.modules():
             if isinstance(m, Attention):
-                m.attention_biases = nn.Parameter(
-                    interpolate_attention_bias(m.attention_biases, old_resolution, resolution)
-                )
+                with torch.no_grad():
+                    m.attention_biases = nn.Parameter(
+                        interpolate_attention_bias(m.attention_biases, old_resolution, resolution)
+                    )
 
-                pos = torch.stack(
-                    torch.meshgrid(torch.arange(resolution[0]), torch.arange(resolution[1]), indexing="ij")
-                ).flatten(1)
-                rel_pos = (pos[..., :, None] - pos[..., None, :]).abs()
-                rel_pos = (rel_pos[0] * resolution[1]) + rel_pos[1]
-                m.attention_bias_idxs = nn.Buffer(rel_pos)
+                device = m.attention_biases.device
+                pos = torch.stack(
+                    torch.meshgrid(
+                        torch.arange(resolution[0], device=device),
+                        torch.arange(resolution[1], device=device),
+                        indexing="ij",
+                    )
+                ).flatten(1)
+                rel_pos = (pos[..., :, None] - pos[..., None, :]).abs()
+                rel_pos = (rel_pos[0] * resolution[1]) + rel_pos[1]
+                m.attention_bias_idxs = nn.Buffer(rel_pos)
 
 
 registry.register_model_config(
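The rebuilt buffer maps every query/key pair to a shared bias entry: `pos` holds the (row, col) coordinate of each grid cell, and the absolute coordinate differences are collapsed into a single index. A worked example on a 2x2 grid, standalone and mirroring the lines above:

```python
import torch

resolution = (2, 2)
pos = torch.stack(
    torch.meshgrid(torch.arange(resolution[0]), torch.arange(resolution[1]), indexing="ij")
).flatten(1)  # (2, 4): row and col coordinate of each of the 4 cells

rel_pos = (pos[..., :, None] - pos[..., None, :]).abs()  # (2, 4, 4): |dy| and |dx| per pair
rel_pos = (rel_pos[0] * resolution[1]) + rel_pos[1]  # collapse to one bias index per pair

print(rel_pos)
# tensor([[0, 1, 2, 3],
#         [1, 0, 3, 2],
#         [2, 3, 0, 1],
#         [3, 2, 1, 0]])
```

Pairs with the same relative offset share an index, so the learned attention_biases table only needs one row per distinct (|dy|, |dx|) offset.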
birder/net/efficientformer_v2.py CHANGED
@@ -554,26 +554,30 @@ class EfficientFormer_v2(DetectorBackbone):
                 attn.N = attn.resolution[0] * attn.resolution[1]
                 attn.N2 = attn.resolution2[0] * attn.resolution2[1]
 
-                # Interpolate attention_biases
-                attn.attention_biases = nn.Parameter(
-                    interpolate_attention_bias(attn.attention_biases, old_base, new_base)
-                )
-
-                k_pos = torch.stack(
-                    torch.meshgrid(
-                        torch.arange(attn.resolution[0]), torch.arange(attn.resolution[1]), indexing="ij"
+                with torch.no_grad():
+                    # Interpolate attention_biases
+                    attn.attention_biases = nn.Parameter(
+                        interpolate_attention_bias(attn.attention_biases, old_base, new_base)
                     )
-                ).flatten(1)
-                q_pos = torch.stack(
-                    torch.meshgrid(
-                        torch.arange(0, attn.resolution[0], step=2),
-                        torch.arange(0, attn.resolution[1], step=2),
-                        indexing="ij",
-                    )
-                ).flatten(1)
-                rel_pos = (q_pos[..., :, None] - k_pos[..., None, :]).abs()
-                rel_pos = (rel_pos[0] * attn.resolution[1]) + rel_pos[1]
-                attn.attention_bias_idxs = nn.Buffer(torch.LongTensor(rel_pos), persistent=False)
+
+                    device = attn.attention_biases.device
+                    k_pos = torch.stack(
+                        torch.meshgrid(
+                            torch.arange(attn.resolution[0], device=device),
+                            torch.arange(attn.resolution[1], device=device),
+                            indexing="ij",
+                        )
+                    ).flatten(1)
+                    q_pos = torch.stack(
+                        torch.meshgrid(
+                            torch.arange(0, attn.resolution[0], step=2, device=device),
+                            torch.arange(0, attn.resolution[1], step=2, device=device),
+                            indexing="ij",
+                        )
+                    ).flatten(1)
+                    rel_pos = (q_pos[..., :, None] - k_pos[..., None, :]).abs()
+                    rel_pos = (rel_pos[0] * attn.resolution[1]) + rel_pos[1]
+                    attn.attention_bias_idxs = nn.Buffer(rel_pos.to(torch.long), persistent=False)
 
                 old_base = (old_base[0] // 2, old_base[1] // 2)
                 new_base = (new_base[0] // 2, new_base[1] // 2)
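Besides the no_grad wrapper, every torch.arange in the rebuilt index computation now passes device=, taken from the freshly interpolated parameter. torch.arange defaults to the CPU, so the old code produced a CPU buffer even when the model lived on an accelerator; building the nn.Buffer on the right device avoids a placement mismatch (or an implicit transfer) on first use after adjust_size. A small illustration of the default, runnable anywhere since the CUDA branch is guarded:

```python
import torch

print(torch.arange(4).device)  # cpu - the default, regardless of where the model lives

if torch.cuda.is_available():
    device = torch.device("cuda")
    print(torch.arange(4, device=device).device)  # cuda:0 - now matches the parameters
```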
@@ -590,16 +594,22 @@ class EfficientFormer_v2(DetectorBackbone):
                 m.token_mixer.resolution = c_new_base
                 m.token_mixer.N = m.token_mixer.resolution[0] * m.token_mixer.resolution[1]
 
-                m.token_mixer.attention_biases = nn.Parameter(
-                    interpolate_attention_bias(m.token_mixer.attention_biases, c_old_base, c_new_base)
-                )
-
-                pos = torch.stack(
-                    torch.meshgrid(torch.arange(c_new_base[0]), torch.arange(c_new_base[1]), indexing="ij")
-                ).flatten(1)
-                rel_pos = (pos[..., :, None] - pos[..., None, :]).abs()
-                rel_pos = (rel_pos[0] * c_new_base[1]) + rel_pos[1]
-                m.token_mixer.attention_bias_idxs = nn.Buffer(torch.LongTensor(rel_pos), persistent=False)
+                with torch.no_grad():
+                    m.token_mixer.attention_biases = nn.Parameter(
+                        interpolate_attention_bias(m.token_mixer.attention_biases, c_old_base, c_new_base)
+                    )
+
+                    device = m.token_mixer.attention_biases.device
+                    pos = torch.stack(
+                        torch.meshgrid(
+                            torch.arange(c_new_base[0], device=device),
+                            torch.arange(c_new_base[1], device=device),
+                            indexing="ij",
+                        )
+                    ).flatten(1)
+                    rel_pos = (pos[..., :, None] - pos[..., None, :]).abs()
+                    rel_pos = (rel_pos[0] * c_new_base[1]) + rel_pos[1]
+                    m.token_mixer.attention_bias_idxs = nn.Buffer(rel_pos.to(torch.long), persistent=False)
 
 
 registry.register_model_config(
birder/net/efficientvit_msft.py CHANGED
@@ -497,14 +497,16 @@ class EfficientViT_MSFT(DetectorBackbone):
 
                     idxs.append(attention_offsets[offset])
 
-            m.mixer.m.attn.attention_biases = nn.Parameter(
-                interpolate_attention_bias(
-                    m.mixer.m.attn.attention_biases, old_window_resolution, window_resolution
+            with torch.no_grad():
+                m.mixer.m.attn.attention_biases = nn.Parameter(
+                    interpolate_attention_bias(
+                        m.mixer.m.attn.attention_biases, old_window_resolution, window_resolution
+                    )
+                )
+                device = m.mixer.m.attn.attention_biases.device
+                m.mixer.m.attn.attention_bias_idxs = nn.Buffer(
+                    torch.tensor(idxs, device=device, dtype=torch.long).view(N, N), persistent=False
                 )
-            )
-            m.mixer.m.attn.attention_bias_idxs = nn.Buffer(
-                torch.LongTensor(idxs).view(N, N), persistent=False
-            )
 
 
 registry.register_model_config(
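Note the constructor swap: the legacy torch.LongTensor(idxs) always builds on the CPU and accepts no device argument, while torch.tensor(idxs, device=device, dtype=torch.long) is the modern factory form (the EfficientFormer hunks above make the same move via rel_pos.to(torch.long)). A quick comparison:

```python
import torch

idxs = [0, 1, 1, 0]

legacy = torch.LongTensor(idxs)  # always CPU, dtype fixed by the class name
modern = torch.tensor(idxs, dtype=torch.long)  # dtype= and device= are explicit kwargs

print(legacy.dtype == modern.dtype)  # True - same int64 result, but only the
# factory form can be asked to build directly on an accelerator
```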
birder/net/fastvit.py CHANGED
@@ -879,6 +879,7 @@ class FastViT(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
         x = self.forward_features(x)
         return self.features(x)
 
+    @torch.no_grad()  # type: ignore[untyped-decorator]
     def reparameterize_model(self) -> None:
         for module in self.modules():
             if hasattr(module, "reparameterize") is True:
birder/net/flexivit.py CHANGED
@@ -519,15 +519,16 @@ class FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Mask
         else:
             num_prefix_tokens = 0
 
-        self.pos_embedding = nn.Parameter(
-            # On rounding error see: https://github.com/facebookresearch/dino/issues/8
-            adjust_position_embedding(
+        with torch.no_grad():
+            pos_embedding = adjust_position_embedding(
+                # On rounding error see: https://github.com/facebookresearch/dino/issues/8
                 self.pos_embedding,
                 (old_size[0] // self.patch_size, old_size[1] // self.patch_size),
                 (new_size[0] // self.patch_size, new_size[1] // self.patch_size),
                 num_prefix_tokens,
             )
-        )
+
+        self.pos_embedding = nn.Parameter(pos_embedding)
 
     def adjust_patch_size(self, patch_size: int) -> None:
         if self.patch_size == patch_size:
birder/net/hiera.py CHANGED
@@ -612,23 +612,26 @@ class Hiera(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin):
 
         if self.pos_embed_win is not None:
             global_pos_size = (new_size[0] // 2**4, new_size[1] // 2**4)
-            pos_embed = F.interpolate(
-                self.pos_embed,
-                size=global_pos_size,
-                mode="bicubic",
-                antialias=True,
-            )
+            with torch.no_grad():
+                pos_embed = F.interpolate(
+                    self.pos_embed,
+                    size=global_pos_size,
+                    mode="bicubic",
+                    antialias=True,
+                )
+
             self.pos_embed = nn.Parameter(pos_embed)
 
         else:
-            self.pos_embed = nn.Parameter(
-                adjust_position_embedding(
+            with torch.no_grad():
+                pos_embed = adjust_position_embedding(
                     self.pos_embed,
                     (old_size[0] // self.patch_stride[0], old_size[1] // self.patch_stride[1]),
                     (new_size[0] // self.patch_stride[0], new_size[1] // self.patch_stride[1]),
                     0,
                 )
-            )
+
+            self.pos_embed = nn.Parameter(pos_embed)
 
         # Re-init vars
         self.tokens_spatial_shape = [i // s for i, s in zip(new_size, self.patch_stride)]
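Both branches now compute the resized table inside torch.no_grad() and only afterwards wrap it in nn.Parameter, so the interpolation itself is never recorded by autograd while the resulting parameter is still trainable. A standalone sketch of the windowed branch's resize on a dummy (1, C, H, W) embedding (the shapes here are made up for illustration):

```python
import torch
import torch.nn.functional as F
from torch import nn

pos_embed = nn.Parameter(torch.randn(1, 96, 7, 7))  # (1, C, H, W) global table

with torch.no_grad():
    resized = F.interpolate(pos_embed, size=(14, 14), mode="bicubic", antialias=True)

pos_embed = nn.Parameter(resized)
print(pos_embed.shape, pos_embed.requires_grad)  # torch.Size([1, 96, 14, 14]) True
```

Re-wrapping in nn.Parameter restores requires_grad even though the values were produced under no_grad, which is exactly the behavior these hunks rely on.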
birder/net/hornet.py CHANGED
@@ -332,13 +332,15 @@ class HorNet(DetectorBackbone):
         for m in module.modules():
             if isinstance(m, HorBlock):
                 if isinstance(m.gn_conv.dwconv, GlobalLocalFilter):
-                    weight = m.gn_conv.dwconv.complex_weight
-                    weight = F.interpolate(
-                        weight.permute(3, 0, 1, 2),
-                        size=(gn_conv_h[i], gn_conv_w[i]),
-                        mode="bilinear",
-                        align_corners=True,
-                    ).permute(1, 2, 3, 0)
+                    with torch.no_grad():
+                        weight = m.gn_conv.dwconv.complex_weight
+                        weight = F.interpolate(
+                            weight.permute(3, 0, 1, 2),
+                            size=(gn_conv_h[i], gn_conv_w[i]),
+                            mode="bilinear",
+                            align_corners=True,
+                        ).permute(1, 2, 3, 0)
+
                     m.gn_conv.dwconv.complex_weight = nn.Parameter(weight)
 
 
birder/net/iformer.py CHANGED
@@ -477,12 +477,14 @@ class iFormer(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
         resolution = (new_size[0] // 4, new_size[1] // 4)
         for stage in self.body.modules():
             if isinstance(stage, InceptionTransformerStage):
-                orig_dtype = stage.pos_embed.dtype
-                pos_embedding = stage.pos_embed.float()
-                pos_embedding = F.interpolate(
-                    pos_embedding.permute(0, 3, 1, 2), size=resolution, mode="bilinear"
-                ).permute(0, 2, 3, 1)
-                pos_embedding = pos_embedding.to(orig_dtype)
+                with torch.no_grad():
+                    orig_dtype = stage.pos_embed.dtype
+                    pos_embedding = stage.pos_embed.float()
+                    pos_embedding = F.interpolate(
+                        pos_embedding.permute(0, 3, 1, 2), size=resolution, mode="bilinear"
+                    ).permute(0, 2, 3, 1)
+                    pos_embedding = pos_embedding.to(orig_dtype)
+
                 stage.pos_embed = nn.Parameter(pos_embedding)
                 stage.resolution = resolution
                 resolution = (resolution[0] // 2, resolution[1] // 2)
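This hunk keeps the existing float32 round-trip (upcast, interpolate, cast back to the stored dtype) and simply wraps it in no_grad. The round-trip matters when the model has been cast to half precision, since interpolation is more accurate, and on some backends only implemented, in float32. A standalone sketch of the same dance, with made-up shapes:

```python
import torch
import torch.nn.functional as F

pos_embed = torch.randn(1, 56, 56, 96, dtype=torch.bfloat16)  # (1, H, W, C), as the stage stores it

orig_dtype = pos_embed.dtype
resized = F.interpolate(pos_embed.float().permute(0, 3, 1, 2), size=(28, 28), mode="bilinear")
resized = resized.permute(0, 2, 3, 1).to(orig_dtype)  # back to channels-last and original precision
print(resized.shape, resized.dtype)  # torch.Size([1, 28, 28, 96]) torch.bfloat16
```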
birder/net/levit.py CHANGED
@@ -454,42 +454,54 @@ class LeViT(BaseNet):
                 # Update Subsample resolution
                 m.q[0].resolution = resolution
 
-                # Interpolate attention biases
-                m.attention_biases = nn.Parameter(
-                    interpolate_attention_bias(m.attention_biases, old_resolution, resolution)
-                )
-
-                # Rebuild attention bias indices
-                k_pos = torch.stack(
-                    torch.meshgrid(torch.arange(resolution[0]), torch.arange(resolution[1]), indexing="ij")
-                ).flatten(1)
-                q_pos = torch.stack(
-                    torch.meshgrid(
-                        torch.arange(0, resolution[0], step=m.stride),
-                        torch.arange(0, resolution[1], step=m.stride),
-                        indexing="ij",
+                with torch.no_grad():
+                    # Interpolate attention biases
+                    m.attention_biases = nn.Parameter(
+                        interpolate_attention_bias(m.attention_biases, old_resolution, resolution)
                     )
-                ).flatten(1)
-                rel_pos = (q_pos[..., :, None] - k_pos[..., None, :]).abs()
-                rel_pos = (rel_pos[0] * resolution[1]) + rel_pos[1]
-                m.attention_bias_idxs = nn.Buffer(rel_pos, persistent=False)
+
+                    # Rebuild attention bias indices
+                    device = m.attention_biases.device
+                    k_pos = torch.stack(
+                        torch.meshgrid(
+                            torch.arange(resolution[0], device=device),
+                            torch.arange(resolution[1], device=device),
+                            indexing="ij",
+                        )
+                    ).flatten(1)
+                    q_pos = torch.stack(
+                        torch.meshgrid(
+                            torch.arange(0, resolution[0], step=m.stride, device=device),
+                            torch.arange(0, resolution[1], step=m.stride, device=device),
+                            indexing="ij",
+                        )
+                    ).flatten(1)
+                    rel_pos = (q_pos[..., :, None] - k_pos[..., None, :]).abs()
+                    rel_pos = (rel_pos[0] * resolution[1]) + rel_pos[1]
+                    m.attention_bias_idxs = nn.Buffer(rel_pos, persistent=False)
 
                 old_resolution = ((old_resolution[0] - 1) // 2 + 1, (old_resolution[1] - 1) // 2 + 1)
                 resolution = ((resolution[0] - 1) // 2 + 1, (resolution[1] - 1) // 2 + 1)
 
             elif isinstance(m, Attention):
-                # Interpolate attention biases
-                m.attention_biases = nn.Parameter(
-                    interpolate_attention_bias(m.attention_biases, old_resolution, resolution)
-                )
-
-                # Rebuild attention bias indices
-                pos = torch.stack(
-                    torch.meshgrid(torch.arange(resolution[0]), torch.arange(resolution[1]), indexing="ij")
-                ).flatten(1)
-                rel_pos = (pos[..., :, None] - pos[..., None, :]).abs()
-                rel_pos = (rel_pos[0] * resolution[1]) + rel_pos[1]
-                m.attention_bias_idxs = nn.Buffer(rel_pos, persistent=False)
+                with torch.no_grad():
+                    # Interpolate attention biases
+                    m.attention_biases = nn.Parameter(
+                        interpolate_attention_bias(m.attention_biases, old_resolution, resolution)
+                    )
+
+                    # Rebuild attention bias indices
+                    device = m.attention_biases.device
+                    pos = torch.stack(
+                        torch.meshgrid(
+                            torch.arange(resolution[0], device=device),
+                            torch.arange(resolution[1], device=device),
+                            indexing="ij",
+                        )
+                    ).flatten(1)
+                    rel_pos = (pos[..., :, None] - pos[..., None, :]).abs()
+                    rel_pos = (rel_pos[0] * resolution[1]) + rel_pos[1]
+                    m.attention_bias_idxs = nn.Buffer(rel_pos, persistent=False)
 
 
 registry.register_model_config(
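The AttentionSubsample branch differs from plain Attention in one detail: its queries come from a strided (downsampled) grid, so q_pos walks the grid with step=m.stride while k_pos covers every cell, yielding a rectangular (num_queries, num_keys) index buffer. On a 4x4 grid with stride 2 (a standalone sketch with made-up sizes):

```python
import torch

resolution = (4, 4)
stride = 2

k_pos = torch.stack(
    torch.meshgrid(torch.arange(resolution[0]), torch.arange(resolution[1]), indexing="ij")
).flatten(1)  # (2, 16): every key position
q_pos = torch.stack(
    torch.meshgrid(
        torch.arange(0, resolution[0], step=stride),
        torch.arange(0, resolution[1], step=stride),
        indexing="ij",
    )
).flatten(1)  # (2, 4): queries only on the strided grid

rel_pos = (q_pos[..., :, None] - k_pos[..., None, :]).abs()
rel_pos = (rel_pos[0] * resolution[1]) + rel_pos[1]
print(rel_pos.shape)  # torch.Size([4, 16]) - one bias index per (query, key) pair
```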
birder/net/lit_v1_tiny.py CHANGED
@@ -340,3 +340,18 @@ registry.register_model_config(
         "drop_path_rate": 0.1,
     },
 )
+
+registry.register_weights(
+    "lit_v1_t_il-common",
+    {
+        "description": "LIT v1 Tiny model trained on the il-common dataset",
+        "resolution": (256, 256),
+        "formats": {
+            "pt": {
+                "file_size": 75.2,
+                "sha256": "93813b2716eb9f33e06dc15ab2ba335c6d219354d2983bbc4f834f8f4e688e5c",
+            }
+        },
+        "net": {"network": "lit_v1_t", "tag": "il-common"},
+    },
+)
birder/net/maxvit.py CHANGED
@@ -52,8 +52,10 @@ def _make_block_input_shapes(input_size: tuple[int, int], n_blocks: int) -> list
     return shapes
 
 
-def _get_relative_position_index(height: int, width: int) -> torch.Tensor:
-    coords = torch.stack(torch.meshgrid([torch.arange(height), torch.arange(width)], indexing="ij"))
+def _get_relative_position_index(height: int, width: int, device: torch.device | None = None) -> torch.Tensor:
+    coords = torch.stack(
+        torch.meshgrid([torch.arange(height, device=device), torch.arange(width, device=device)], indexing="ij")
+    )
     coords_flat = torch.flatten(coords, 1)
     relative_coords = coords_flat[:, :, None] - coords_flat[:, None, :]
     relative_coords = relative_coords.permute(1, 2, 0).contiguous()
@@ -152,7 +154,9 @@ class RelativePositionalMultiHeadAttention(nn.Module):
         self.relative_position_bias_table = nn.Parameter(
             torch.empty(((2 * self.size[0] - 1) * (2 * self.size[1] - 1), self.n_heads), dtype=torch.float32),
         )
-        self.relative_position_index = nn.Buffer(_get_relative_position_index(self.size[0], self.size[1]))
+        self.relative_position_index = nn.Buffer(
+            _get_relative_position_index(self.size[0], self.size[1], device=self.relative_position_bias_table.device)
+        )
 
         # Initialize with truncated normal the bias
         nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)
@@ -682,60 +686,68 @@ class MaxViT(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
             old_attn_size = attn.size
             attn.size = self.partition_size
             attn.max_seq_len = self.partition_size[0] * self.partition_size[1]
-            attn.relative_position_index = nn.Buffer(
-                _get_relative_position_index(attn.size[0], attn.size[1])
-            )
-
-            # Interpolate relative_position_bias_table, adapted from
-            # https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/pos_embed_rel.py
-            dst_size = (2 * attn.size[0] - 1, 2 * attn.size[1] - 1)
-            rel_pos_bias = attn.relative_position_bias_table
-            rel_pos_bias = rel_pos_bias.detach()
-
-            num_attn_heads = rel_pos_bias.size(1)
-            src_size = (2 * old_attn_size[0] - 1, 2 * old_attn_size[1] - 1)
-
-            def _calc(src: int, dst: int) -> list[float]:
-                (left, right) = 1.01, 1.5
-                while right - left > 1e-6:
-                    q = (left + right) / 2.0
-                    gp = (1.0 - q ** (src // 2)) / (1.0 - q)  # Geometric progression
-                    if gp > dst // 2:
-                        right = q
-
-                    else:
-                        left = q
-
-                dis = []
-                cur = 1.0
-                for i in range(src // 2):
-                    dis.append(cur)
-                    cur += q ** (i + 1)
-
-                r_ids = [-_ for _ in reversed(dis)]
-                return r_ids + [0] + dis
-
-            y = _calc(src_size[0], dst_size[0])
-            x = _calc(src_size[1], dst_size[1])
-
-            ty = dst_size[0] // 2.0
-            tx = dst_size[1] // 2.0
-            dy = torch.arange(-ty, ty + 0.1, 1.0)
-            dx = torch.arange(-tx, tx + 0.1, 1.0)
-            dxy = torch.meshgrid(dx, dy, indexing="ij")
-
-            all_rel_pos_bias = []
-            for i in range(num_attn_heads):
-                z = rel_pos_bias[:, i].view(src_size[0], src_size[1]).float()
-                rgi = interpolate.RegularGridInterpolator(
-                    (x, y), z.numpy().T, method="cubic", bounds_error=False, fill_value=None
+            with torch.no_grad():
+                attn.relative_position_index = nn.Buffer(
+                    _get_relative_position_index(
+                        attn.size[0],
+                        attn.size[1],
+                        device=attn.relative_position_bias_table.device,
+                    )
                 )
-                r = torch.Tensor(rgi(dxy)).T.contiguous().to(rel_pos_bias.device)
-
-                r = r.view(-1, 1)
-                all_rel_pos_bias.append(r)
 
-            rel_pos_bias = torch.concat(all_rel_pos_bias, dim=-1)
+                # Interpolate relative_position_bias_table, adapted from
+                # https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/pos_embed_rel.py
+                dst_size = (2 * attn.size[0] - 1, 2 * attn.size[1] - 1)
+                rel_pos_bias = attn.relative_position_bias_table.detach()
+                rel_pos_device = rel_pos_bias.device
+                rel_pos_bias = rel_pos_bias.float().cpu()
+
+                num_attn_heads = rel_pos_bias.size(1)
+                src_size = (2 * old_attn_size[0] - 1, 2 * old_attn_size[1] - 1)
+
+                def _calc(src: int, dst: int) -> list[float]:
+                    (left, right) = 1.01, 1.5
+                    while right - left > 1e-6:
+                        q = (left + right) / 2.0
+                        gp = (1.0 - q ** (src // 2)) / (1.0 - q)  # Geometric progression
+                        if gp > dst // 2:
+                            right = q
+
+                        else:
+                            left = q
+
+                    dis = []
+                    cur = 1.0
+                    for i in range(src // 2):
+                        dis.append(cur)
+                        cur += q ** (i + 1)
+
+                    r_ids = [-_ for _ in reversed(dis)]
+                    return r_ids + [0] + dis
+
+                y = _calc(src_size[0], dst_size[0])
+                x = _calc(src_size[1], dst_size[1])
+
+                ty = dst_size[0] // 2.0
+                tx = dst_size[1] // 2.0
+                dy = torch.arange(-ty, ty + 0.1, 1.0)
+                dx = torch.arange(-tx, tx + 0.1, 1.0)
+                dxy = torch.meshgrid(dx, dy, indexing="ij")
+
+                all_rel_pos_bias = []
+                for i in range(num_attn_heads):
+                    z = rel_pos_bias[:, i].view(src_size[0], src_size[1])
+                    rgi = interpolate.RegularGridInterpolator(
+                        (x, y), z.numpy().T, method="cubic", bounds_error=False, fill_value=None
+                    )
+                    r = torch.tensor(
+                        rgi(dxy), device=rel_pos_device, dtype=rel_pos_bias.dtype
+                    ).T.contiguous()
+
+                    r = r.view(-1, 1)
+                    all_rel_pos_bias.append(r)
+
+                rel_pos_bias = torch.concat(all_rel_pos_bias, dim=-1)
             attn.relative_position_bias_table = nn.Parameter(rel_pos_bias)
 
             new_grid_size = m.grid_size
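The bias-table resize follows timm's approach: _calc binary-searches a geometric ratio q so that the cumulative steps 1, q, q^2, ... from the source grid's centre span half of the destination grid, producing log-spaced sample coordinates that are dense near zero offset and sparse far away; scipy's RegularGridInterpolator then resamples each head's (src_h, src_w) bias map at the new coordinates. The rest of the hunk is precision and device hygiene: the table is detached and moved to the CPU as float32 for the scipy call, and the result is created back on the original device and dtype. A quick look at the coordinates _calc produces, copying the helper from the hunk and using made-up sizes:

```python
def _calc(src: int, dst: int) -> list[float]:
    (left, right) = 1.01, 1.5
    while right - left > 1e-6:
        q = (left + right) / 2.0
        gp = (1.0 - q ** (src // 2)) / (1.0 - q)  # Geometric progression sum
        if gp > dst // 2:
            right = q
        else:
            left = q

    dis = []
    cur = 1.0
    for i in range(src // 2):
        dis.append(cur)
        cur += q ** (i + 1)

    r_ids = [-_ for _ in reversed(dis)]
    return r_ids + [0] + dis


# 13 source rows resampled onto a 25-row destination grid:
coords = _calc(13, 25)
print(len(coords))  # 13 symmetric coordinates, one per source row
print(coords[6:9])  # [0, 1.0, ~2.28] - spacing grows geometrically away from the centre
```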
birder/net/mobileone.py CHANGED
@@ -380,6 +380,7 @@ class MobileOne(DetectorBackbone):
         x = self.forward_features(x)
         return self.features(x)
 
+    @torch.no_grad()  # type: ignore[untyped-decorator]
     def reparameterize_model(self) -> None:
         for module in self.modules():
             if hasattr(module, "reparameterize") is True:
birder/net/mvit_v2.py CHANGED
@@ -638,18 +638,19 @@ class MViT_v2(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
                 rel_sp_dim_h = 2 * max(q_size_h, kv_size_h) - 1
                 rel_sp_dim_w = 2 * max(q_size_w, kv_size_w) - 1
 
-                rel_pos_h = m.attn.rel_pos_h
-                rel_pos_h_resized = F.interpolate(
-                    rel_pos_h.reshape(1, rel_pos_h.shape[0], -1).permute(0, 2, 1),
-                    size=rel_sp_dim_h,
-                    mode="linear",
-                )
-                rel_pos_w = m.attn.rel_pos_w
-                rel_pos_w_resized = F.interpolate(
-                    rel_pos_w.reshape(1, rel_pos_w.shape[0], -1).permute(0, 2, 1),
-                    size=rel_sp_dim_w,
-                    mode="linear",
-                )
+                with torch.no_grad():
+                    rel_pos_h = m.attn.rel_pos_h
+                    rel_pos_h_resized = F.interpolate(
+                        rel_pos_h.reshape(1, rel_pos_h.shape[0], -1).permute(0, 2, 1),
+                        size=rel_sp_dim_h,
+                        mode="linear",
+                    )
+                    rel_pos_w = m.attn.rel_pos_w
+                    rel_pos_w_resized = F.interpolate(
+                        rel_pos_w.reshape(1, rel_pos_w.shape[0], -1).permute(0, 2, 1),
+                        size=rel_sp_dim_w,
+                        mode="linear",
+                    )
 
                 m.attn.rel_pos_h = nn.Parameter(rel_pos_h_resized.reshape(-1, rel_sp_dim_h).permute(1, 0))
                 m.attn.rel_pos_w = nn.Parameter(rel_pos_w_resized.reshape(-1, rel_sp_dim_w).permute(1, 0))
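Here the tables are 1-D per-offset embeddings of shape (length, dim), so the resize uses mode="linear" on a (1, dim, length) view and then restores the original layout. A standalone sketch of the same reshape dance, with made-up sizes:

```python
import torch
import torch.nn.functional as F

rel_pos_h = torch.randn(27, 96)  # (2 * old_size - 1, head_dim)
rel_sp_dim_h = 55                # 2 * new_size - 1

with torch.no_grad():
    resized = F.interpolate(
        rel_pos_h.reshape(1, rel_pos_h.shape[0], -1).permute(0, 2, 1),  # (1, 96, 27)
        size=rel_sp_dim_h,
        mode="linear",  # 1-D interpolation along the offset axis
    )

rel_pos_h = resized.reshape(-1, rel_sp_dim_h).permute(1, 0)  # back to (55, 96)
print(rel_pos_h.shape)  # torch.Size([55, 96])
```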
birder/net/pit.py CHANGED
@@ -258,9 +258,10 @@ class PiT(DetectorBackbone):
         height = (new_size[0] - self.patch_size[0]) // self.patch_stride[0] + 1
         width = (new_size[1] - self.patch_size[1]) // self.patch_stride[1] + 1
 
-        self.pos_embed = nn.Parameter(
-            F.interpolate(self.pos_embed, (height, width), mode="bicubic"), requires_grad=True
-        )
+        with torch.no_grad():
+            pos_embed = F.interpolate(self.pos_embed, (height, width), mode="bicubic")
+
+        self.pos_embed = nn.Parameter(pos_embed)
 
 
 registry.register_model_config(
birder/net/pvt_v1.py CHANGED
@@ -308,7 +308,10 @@ class PVT_v1(DetectorBackbone):
         s = (new_size[0] // 4, new_size[1] // 4)
         for m in self.body.modules():
             if isinstance(m, PyramidVisionTransformerStage):
-                m.pos_embed = nn.Parameter(adjust_position_embedding(m.pos_embed, old_s, s, 0))
+                with torch.no_grad():
+                    pos_embed = adjust_position_embedding(m.pos_embed, old_s, s, 0)
+
+                m.pos_embed = nn.Parameter(pos_embed)
                 old_s = (old_s[0] // 2, old_s[1] // 2)
                 s = (s[0] // 2, s[1] // 2)
 
birder/net/repghost.py CHANGED
@@ -338,6 +338,7 @@ class RepGhost(DetectorBackbone):
         x = self.forward_features(x)
         return self.features(x)
 
+    @torch.no_grad()  # type: ignore[untyped-decorator]
     def reparameterize_model(self) -> None:
         for module in self.modules():
             if hasattr(module, "reparameterize") is True: