birder 0.2.2__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (107)
  1. birder/common/lib.py +2 -9
  2. birder/common/training_cli.py +24 -0
  3. birder/common/training_utils.py +338 -41
  4. birder/data/collators/detection.py +11 -3
  5. birder/data/dataloader/webdataset.py +12 -2
  6. birder/data/datasets/coco.py +8 -10
  7. birder/data/transforms/detection.py +30 -13
  8. birder/inference/detection.py +108 -4
  9. birder/inference/wbf.py +226 -0
  10. birder/kernels/load_kernel.py +16 -11
  11. birder/kernels/soft_nms/soft_nms.cpp +17 -18
  12. birder/net/__init__.py +8 -0
  13. birder/net/cait.py +4 -3
  14. birder/net/convnext_v1.py +5 -0
  15. birder/net/crossformer.py +33 -30
  16. birder/net/crossvit.py +4 -3
  17. birder/net/deit.py +3 -3
  18. birder/net/deit3.py +3 -3
  19. birder/net/detection/deformable_detr.py +2 -5
  20. birder/net/detection/detr.py +2 -5
  21. birder/net/detection/efficientdet.py +67 -93
  22. birder/net/detection/fcos.py +2 -7
  23. birder/net/detection/retinanet.py +2 -7
  24. birder/net/detection/rt_detr_v1.py +2 -0
  25. birder/net/detection/yolo_anchors.py +205 -0
  26. birder/net/detection/yolo_v2.py +25 -24
  27. birder/net/detection/yolo_v3.py +39 -40
  28. birder/net/detection/yolo_v4.py +28 -26
  29. birder/net/detection/yolo_v4_tiny.py +24 -20
  30. birder/net/efficientformer_v1.py +15 -9
  31. birder/net/efficientformer_v2.py +39 -29
  32. birder/net/efficientvit_msft.py +9 -7
  33. birder/net/fasternet.py +1 -1
  34. birder/net/fastvit.py +1 -0
  35. birder/net/flexivit.py +5 -4
  36. birder/net/gc_vit.py +671 -0
  37. birder/net/hiera.py +12 -9
  38. birder/net/hornet.py +9 -7
  39. birder/net/iformer.py +8 -6
  40. birder/net/levit.py +42 -30
  41. birder/net/lit_v1.py +472 -0
  42. birder/net/lit_v1_tiny.py +357 -0
  43. birder/net/lit_v2.py +436 -0
  44. birder/net/maxvit.py +67 -55
  45. birder/net/mobilenet_v4_hybrid.py +1 -1
  46. birder/net/mobileone.py +1 -0
  47. birder/net/mvit_v2.py +13 -12
  48. birder/net/pit.py +4 -3
  49. birder/net/pvt_v1.py +4 -1
  50. birder/net/repghost.py +1 -0
  51. birder/net/repvgg.py +1 -0
  52. birder/net/repvit.py +1 -0
  53. birder/net/resnet_v1.py +1 -1
  54. birder/net/resnext.py +67 -25
  55. birder/net/rope_deit3.py +5 -3
  56. birder/net/rope_flexivit.py +7 -4
  57. birder/net/rope_vit.py +10 -5
  58. birder/net/se_resnet_v1.py +46 -0
  59. birder/net/se_resnext.py +3 -0
  60. birder/net/simple_vit.py +11 -8
  61. birder/net/swin_transformer_v1.py +71 -68
  62. birder/net/swin_transformer_v2.py +38 -31
  63. birder/net/tiny_vit.py +20 -10
  64. birder/net/transnext.py +38 -28
  65. birder/net/vit.py +5 -19
  66. birder/net/vit_parallel.py +5 -4
  67. birder/net/vit_sam.py +38 -37
  68. birder/net/vovnet_v1.py +15 -0
  69. birder/net/vovnet_v2.py +31 -1
  70. birder/ops/msda.py +108 -43
  71. birder/ops/swattention.py +124 -61
  72. birder/results/detection.py +4 -0
  73. birder/scripts/benchmark.py +110 -32
  74. birder/scripts/predict.py +8 -0
  75. birder/scripts/predict_detection.py +18 -11
  76. birder/scripts/train.py +48 -46
  77. birder/scripts/train_barlow_twins.py +44 -45
  78. birder/scripts/train_byol.py +44 -45
  79. birder/scripts/train_capi.py +50 -49
  80. birder/scripts/train_data2vec.py +45 -47
  81. birder/scripts/train_data2vec2.py +45 -47
  82. birder/scripts/train_detection.py +83 -50
  83. birder/scripts/train_dino_v1.py +60 -47
  84. birder/scripts/train_dino_v2.py +86 -52
  85. birder/scripts/train_dino_v2_dist.py +84 -50
  86. birder/scripts/train_franca.py +51 -52
  87. birder/scripts/train_i_jepa.py +45 -47
  88. birder/scripts/train_ibot.py +51 -53
  89. birder/scripts/train_kd.py +194 -76
  90. birder/scripts/train_mim.py +44 -45
  91. birder/scripts/train_mmcr.py +44 -45
  92. birder/scripts/train_rotnet.py +45 -46
  93. birder/scripts/train_simclr.py +44 -45
  94. birder/scripts/train_vicreg.py +44 -45
  95. birder/tools/auto_anchors.py +20 -1
  96. birder/tools/convert_model.py +18 -15
  97. birder/tools/det_results.py +114 -2
  98. birder/tools/pack.py +172 -103
  99. birder/tools/quantize_model.py +73 -67
  100. birder/tools/show_det_iterator.py +10 -1
  101. birder/version.py +1 -1
  102. {birder-0.2.2.dist-info → birder-0.3.0.dist-info}/METADATA +4 -3
  103. {birder-0.2.2.dist-info → birder-0.3.0.dist-info}/RECORD +107 -101
  104. {birder-0.2.2.dist-info → birder-0.3.0.dist-info}/WHEEL +0 -0
  105. {birder-0.2.2.dist-info → birder-0.3.0.dist-info}/entry_points.txt +0 -0
  106. {birder-0.2.2.dist-info → birder-0.3.0.dist-info}/licenses/LICENSE +0 -0
  107. {birder-0.2.2.dist-info → birder-0.3.0.dist-info}/top_level.txt +0 -0
birder/net/detection/yolo_v4_tiny.py CHANGED
@@ -15,18 +15,13 @@ import torch
 from torch import nn
 from torchvision.ops import Conv2dNormActivation
 
+from birder.model_registry import registry
 from birder.net.base import DetectorBackbone
+from birder.net.detection.yolo_anchors import resolve_anchor_groups
 from birder.net.detection.yolo_v3 import YOLOAnchorGenerator
 from birder.net.detection.yolo_v3 import YOLOHead
-from birder.net.detection.yolo_v3 import scale_anchors
 from birder.net.detection.yolo_v4 import YOLO_v4
 
-# Default anchors from YOLO v4 Tiny (COCO)
-DEFAULT_ANCHORS = [
-    [(10.0, 14.0), (23.0, 27.0), (37.0, 58.0)],  # Medium
-    [(81.0, 82.0), (135.0, 169.0), (344.0, 319.0)],  # Large
-]
-
 # Scale factors per detection scale to eliminate grid sensitivity
 DEFAULT_SCALE_XY = [1.05, 1.05]  # [medium, large]
 
@@ -92,7 +87,6 @@ class YOLOTinyNeck(nn.Module):
 # pylint: disable=invalid-name
 class YOLO_v4_Tiny(YOLO_v4):
     default_size = (416, 416)
-    auto_register = True
 
     def __init__(
         self,
@@ -104,22 +98,26 @@ class YOLO_v4_Tiny(YOLO_v4):
         export_mode: bool = False,
     ) -> None:
         super().__init__(num_classes, backbone, config=config, size=size, export_mode=export_mode)
-        assert self.config is None, "config not supported"
+        assert self.config is not None, "must set config"
 
         # self.num_classes = self.num_classes - 1 (Subtracted at parent)
 
         score_thresh = 0.05
         nms_thresh = 0.45
         detections_per_img = 300
-        self.ignore_thresh = 0.7
-
-        # Loss coefficients
-        self.noobj_coeff = 0.25
-        self.coord_coeff = 3.0
-        self.obj_coeff = 1.0
-        self.cls_coeff = 1.0
-
-        self.anchors = scale_anchors(DEFAULT_ANCHORS, self.default_size, self.size)
+        ignore_thresh = 0.7
+        noobj_coeff = 0.25
+        coord_coeff = 3.0
+        obj_coeff = 1.0
+        cls_coeff = 1.0
+        label_smoothing = 0.1
+        anchor_spec = self.config["anchors"]
+
+        self.ignore_thresh = ignore_thresh
+        self.noobj_coeff = noobj_coeff
+        self.coord_coeff = coord_coeff
+        self.obj_coeff = obj_coeff
+        self.cls_coeff = cls_coeff
         self.scale_xy = DEFAULT_SCALE_XY
         self.score_thresh = score_thresh
         self.nms_thresh = nms_thresh
@@ -128,12 +126,18 @@ class YOLO_v4_Tiny(YOLO_v4):
         self.backbone.return_channels = self.backbone.return_channels[-2:]
         self.backbone.return_stages = self.backbone.return_stages[-2:]
 
-        self.label_smoothing = 0.1
+        self.label_smoothing = label_smoothing
         self.smooth_positive = 1.0 - self.label_smoothing
         self.smooth_negative = self.label_smoothing / self.num_classes
 
         self.neck = YOLOTinyNeck(self.backbone.return_channels)
 
-        self.anchor_generator = YOLOAnchorGenerator(self.anchors)
+        anchors = resolve_anchor_groups(
+            anchor_spec, anchor_format="pixels", model_size=self.size, model_strides=(16, 32)
+        )
+        self.anchor_generator = YOLOAnchorGenerator(anchors)
        num_anchors = self.anchor_generator.num_anchors_per_location()
         self.head = YOLOHead(self.neck.out_channels, num_anchors, self.num_classes)
+
+
+registry.register_model_config("yolo_v4_tiny", YOLO_v4_Tiny, config={"anchors": "yolo_v4_tiny"})
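The hunks above capture the main API change in this file: the hard-coded DEFAULT_ANCHORS table and the scale_anchors() import are gone, and anchors are now resolved from the model config by the new birder/net/detection/yolo_anchors.py module, with the model registered under a named anchor preset. A hedged Python sketch of the resulting call, using only what is visible in this diff (whether resolve_anchor_groups() also accepts an explicit anchor list instead of a preset name is an assumption):

# Sketch only: based on the calls visible in the hunks above, not on birder docs.
from birder.net.detection.yolo_anchors import resolve_anchor_groups

# The registered config maps "anchors" to the preset name "yolo_v4_tiny";
# resolve_anchor_groups() turns that spec into per-scale anchor groups.
anchors = resolve_anchor_groups(
    "yolo_v4_tiny", anchor_format="pixels", model_size=(416, 416), model_strides=(16, 32)
)

# Assumption: an explicit per-scale pixel spec (the values removed above) may
# also be accepted in place of a preset name.
anchor_spec = [
    [(10.0, 14.0), (23.0, 27.0), (37.0, 58.0)],  # medium, stride 16
    [(81.0, 82.0), (135.0, 169.0), (344.0, 319.0)],  # large, stride 32
]
anchors = resolve_anchor_groups(
    anchor_spec, anchor_format="pixels", model_size=(416, 416), model_strides=(16, 32)
)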
birder/net/efficientformer_v1.py CHANGED
@@ -357,16 +357,22 @@ class EfficientFormer_v1(BaseNet):
         resolution = (int(new_size[0] / (2**5)), int(new_size[1] / (2**5)))
         for m in self.body.modules():
             if isinstance(m, Attention):
-                m.attention_biases = nn.Parameter(
-                    interpolate_attention_bias(m.attention_biases, old_resolution, resolution)
-                )
+                with torch.no_grad():
+                    m.attention_biases = nn.Parameter(
+                        interpolate_attention_bias(m.attention_biases, old_resolution, resolution)
+                    )
 
-                pos = torch.stack(
-                    torch.meshgrid(torch.arange(resolution[0]), torch.arange(resolution[1]), indexing="ij")
-                ).flatten(1)
-                rel_pos = (pos[..., :, None] - pos[..., None, :]).abs()
-                rel_pos = (rel_pos[0] * resolution[1]) + rel_pos[1]
-                m.attention_bias_idxs = nn.Buffer(rel_pos)
+                    device = m.attention_biases.device
+                    pos = torch.stack(
+                        torch.meshgrid(
+                            torch.arange(resolution[0], device=device),
+                            torch.arange(resolution[1], device=device),
+                            indexing="ij",
+                        )
+                    ).flatten(1)
+                    rel_pos = (pos[..., :, None] - pos[..., None, :]).abs()
+                    rel_pos = (rel_pos[0] * resolution[1]) + rel_pos[1]
+                    m.attention_bias_idxs = nn.Buffer(rel_pos)
 
 
 registry.register_model_config(
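This hunk, and the EfficientFormer v2 and EfficientViT hunks that follow, all apply the same two fixes: the interpolation and reassignment of attention_biases happens under torch.no_grad(), and the rebuilt relative-position index tensors are allocated on the parameter's device instead of implicitly on CPU. A standalone Python sketch of the pattern (the shapes and the F.interpolate() call are stand-ins for birder's interpolate_attention_bias(), not its actual implementation):

import torch
import torch.nn.functional as F
from torch import nn

heads = 8
old_res = (7, 7)
new_res = (9, 9)
bias = nn.Parameter(torch.zeros(heads, old_res[0] * old_res[1]))

with torch.no_grad():
    # Reassign the Parameter under no_grad() so the resize ops never end up
    # on the autograd graph of a live model.
    resized = F.interpolate(
        bias.reshape(1, heads, *old_res), size=new_res, mode="bicubic"
    ).reshape(heads, -1)
    bias = nn.Parameter(resized)

    # Rebuild the relative-position index table on the parameter's device so a
    # CUDA model does not silently mix CPU and GPU tensors.
    device = bias.device
    pos = torch.stack(
        torch.meshgrid(
            torch.arange(new_res[0], device=device),
            torch.arange(new_res[1], device=device),
            indexing="ij",
        )
    ).flatten(1)
    rel_pos = (pos[..., :, None] - pos[..., None, :]).abs()
    rel_pos = (rel_pos[0] * new_res[1]) + rel_pos[1]  # flat index into the bias table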
birder/net/efficientformer_v2.py CHANGED
@@ -554,26 +554,30 @@ class EfficientFormer_v2(DetectorBackbone):
                 attn.N = attn.resolution[0] * attn.resolution[1]
                 attn.N2 = attn.resolution2[0] * attn.resolution2[1]
 
-                # Interpolate attention_biases
-                attn.attention_biases = nn.Parameter(
-                    interpolate_attention_bias(attn.attention_biases, old_base, new_base)
-                )
-
-                k_pos = torch.stack(
-                    torch.meshgrid(
-                        torch.arange(attn.resolution[0]), torch.arange(attn.resolution[1]), indexing="ij"
+                with torch.no_grad():
+                    # Interpolate attention_biases
+                    attn.attention_biases = nn.Parameter(
+                        interpolate_attention_bias(attn.attention_biases, old_base, new_base)
                     )
-                ).flatten(1)
-                q_pos = torch.stack(
-                    torch.meshgrid(
-                        torch.arange(0, attn.resolution[0], step=2),
-                        torch.arange(0, attn.resolution[1], step=2),
-                        indexing="ij",
-                    )
-                ).flatten(1)
-                rel_pos = (q_pos[..., :, None] - k_pos[..., None, :]).abs()
-                rel_pos = (rel_pos[0] * attn.resolution[1]) + rel_pos[1]
-                attn.attention_bias_idxs = nn.Buffer(torch.LongTensor(rel_pos), persistent=False)
+
+                    device = attn.attention_biases.device
+                    k_pos = torch.stack(
+                        torch.meshgrid(
+                            torch.arange(attn.resolution[0], device=device),
+                            torch.arange(attn.resolution[1], device=device),
+                            indexing="ij",
+                        )
+                    ).flatten(1)
+                    q_pos = torch.stack(
+                        torch.meshgrid(
+                            torch.arange(0, attn.resolution[0], step=2, device=device),
+                            torch.arange(0, attn.resolution[1], step=2, device=device),
+                            indexing="ij",
+                        )
+                    ).flatten(1)
+                    rel_pos = (q_pos[..., :, None] - k_pos[..., None, :]).abs()
+                    rel_pos = (rel_pos[0] * attn.resolution[1]) + rel_pos[1]
+                    attn.attention_bias_idxs = nn.Buffer(rel_pos.to(torch.long), persistent=False)
 
             old_base = (old_base[0] // 2, old_base[1] // 2)
             new_base = (new_base[0] // 2, new_base[1] // 2)
@@ -590,16 +594,22 @@ class EfficientFormer_v2(DetectorBackbone):
                 m.token_mixer.resolution = c_new_base
                 m.token_mixer.N = m.token_mixer.resolution[0] * m.token_mixer.resolution[1]
 
-                m.token_mixer.attention_biases = nn.Parameter(
-                    interpolate_attention_bias(m.token_mixer.attention_biases, c_old_base, c_new_base)
-                )
-
-                pos = torch.stack(
-                    torch.meshgrid(torch.arange(c_new_base[0]), torch.arange(c_new_base[1]), indexing="ij")
-                ).flatten(1)
-                rel_pos = (pos[..., :, None] - pos[..., None, :]).abs()
-                rel_pos = (rel_pos[0] * c_new_base[1]) + rel_pos[1]
-                m.token_mixer.attention_bias_idxs = nn.Buffer(torch.LongTensor(rel_pos), persistent=False)
+                with torch.no_grad():
+                    m.token_mixer.attention_biases = nn.Parameter(
+                        interpolate_attention_bias(m.token_mixer.attention_biases, c_old_base, c_new_base)
+                    )
+
+                    device = m.token_mixer.attention_biases.device
+                    pos = torch.stack(
+                        torch.meshgrid(
+                            torch.arange(c_new_base[0], device=device),
+                            torch.arange(c_new_base[1], device=device),
+                            indexing="ij",
+                        )
+                    ).flatten(1)
+                    rel_pos = (pos[..., :, None] - pos[..., None, :]).abs()
+                    rel_pos = (rel_pos[0] * c_new_base[1]) + rel_pos[1]
+                    m.token_mixer.attention_bias_idxs = nn.Buffer(rel_pos.to(torch.long), persistent=False)
 
 
 registry.register_model_config(
birder/net/efficientvit_msft.py CHANGED
@@ -497,14 +497,16 @@ class EfficientViT_MSFT(DetectorBackbone):
 
                     idxs.append(attention_offsets[offset])
 
-            m.mixer.m.attn.attention_biases = nn.Parameter(
-                interpolate_attention_bias(
-                    m.mixer.m.attn.attention_biases, old_window_resolution, window_resolution
+            with torch.no_grad():
+                m.mixer.m.attn.attention_biases = nn.Parameter(
+                    interpolate_attention_bias(
+                        m.mixer.m.attn.attention_biases, old_window_resolution, window_resolution
+                    )
+                )
+                device = m.mixer.m.attn.attention_biases.device
+                m.mixer.m.attn.attention_bias_idxs = nn.Buffer(
+                    torch.tensor(idxs, device=device, dtype=torch.long).view(N, N), persistent=False
                 )
-            )
-            m.mixer.m.attn.attention_bias_idxs = nn.Buffer(
-                torch.LongTensor(idxs).view(N, N), persistent=False
-            )
 
 
 registry.register_model_config(
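Besides the no_grad() wrapper, these attention-bias hunks retire the legacy torch.LongTensor constructor, which always allocates on CPU and takes its dtype from the class name, in favor of the explicit torch.tensor(..., dtype=torch.long, device=...) and .to(torch.long) spellings. A minimal Python equivalence check:

import torch

idxs = [0, 2, 2, 1]
# Legacy type constructor: always CPU, dtype implied by the class name
a = torch.LongTensor(idxs).view(2, 2)
# Modern spelling used in the new code: explicit dtype and device
b = torch.tensor(idxs, dtype=torch.long, device="cpu").view(2, 2)
assert torch.equal(a, b)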
birder/net/fasternet.py CHANGED
@@ -6,7 +6,7 @@ Paper "Run, Don't Walk: Chasing Higher FLOPS for Faster Neural Networks",
 https://arxiv.org/abs/2303.03667
 
 Changes from original:
-* No extra norm's for detection
+* No extra norms for detection
 """
 
 # Reference license: MIT
birder/net/fastvit.py CHANGED
@@ -879,6 +879,7 @@ class FastViT(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
         x = self.forward_features(x)
         return self.features(x)
 
+    @torch.no_grad()  # type: ignore[untyped-decorator]
     def reparameterize_model(self) -> None:
         for module in self.modules():
             if hasattr(module, "reparameterize") is True:
birder/net/flexivit.py CHANGED
@@ -519,15 +519,16 @@ class FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Mask
         else:
             num_prefix_tokens = 0
 
-        self.pos_embedding = nn.Parameter(
-            # On rounding error see: https://github.com/facebookresearch/dino/issues/8
-            adjust_position_embedding(
+        with torch.no_grad():
+            pos_embedding = adjust_position_embedding(
+                # On rounding error see: https://github.com/facebookresearch/dino/issues/8
                 self.pos_embedding,
                 (old_size[0] // self.patch_size, old_size[1] // self.patch_size),
                 (new_size[0] // self.patch_size, new_size[1] // self.patch_size),
                 num_prefix_tokens,
             )
-        )
+
+            self.pos_embedding = nn.Parameter(pos_embedding)
 
     def adjust_patch_size(self, patch_size: int) -> None:
         if self.patch_size == patch_size:
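Same theme as the attention-bias changes above: the position-embedding resize now runs under torch.no_grad() before the result is wrapped back into an nn.Parameter. For reference, a hedged Python sketch of what an adjust_position_embedding-style helper does (birder's actual implementation may differ in rounding and antialiasing details; the linked DINO issue concerns recovering the grid shape from the token count, which is why old and new grid sizes are passed explicitly):

import torch
import torch.nn.functional as F

def resize_pos_embedding(
    pos: torch.Tensor,  # (1, num_prefix + H*W, dim)
    old_hw: tuple[int, int],
    new_hw: tuple[int, int],
    num_prefix: int,
) -> torch.Tensor:
    # Split off class/register tokens, which are not part of the spatial grid
    prefix = pos[:, :num_prefix]
    grid = pos[:, num_prefix:]
    dim = grid.shape[-1]
    # Reshape tokens back into a 2D grid and resize with bicubic interpolation
    grid = grid.reshape(1, old_hw[0], old_hw[1], dim).permute(0, 3, 1, 2)
    grid = F.interpolate(grid, size=new_hw, mode="bicubic")
    # Flatten back to a token sequence and re-attach the prefix tokens
    grid = grid.permute(0, 2, 3, 1).reshape(1, new_hw[0] * new_hw[1], dim)
    return torch.cat([prefix, grid], dim=1)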