birder 0.2.3__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82)
  1. birder/common/training_cli.py +6 -0
  2. birder/common/training_utils.py +215 -31
  3. birder/data/collators/detection.py +1 -0
  4. birder/data/dataloader/webdataset.py +12 -2
  5. birder/kernels/load_kernel.py +16 -11
  6. birder/kernels/soft_nms/soft_nms.cpp +17 -18
  7. birder/net/cait.py +4 -3
  8. birder/net/convnext_v1.py +5 -0
  9. birder/net/crossformer.py +33 -30
  10. birder/net/crossvit.py +4 -3
  11. birder/net/deit.py +3 -3
  12. birder/net/deit3.py +3 -3
  13. birder/net/detection/deformable_detr.py +2 -5
  14. birder/net/detection/detr.py +2 -5
  15. birder/net/detection/efficientdet.py +2 -7
  16. birder/net/detection/fcos.py +2 -7
  17. birder/net/detection/retinanet.py +2 -7
  18. birder/net/detection/rt_detr_v1.py +1 -0
  19. birder/net/efficientformer_v1.py +15 -9
  20. birder/net/efficientformer_v2.py +39 -29
  21. birder/net/efficientvit_msft.py +9 -7
  22. birder/net/fastvit.py +1 -0
  23. birder/net/flexivit.py +5 -4
  24. birder/net/hiera.py +12 -9
  25. birder/net/hornet.py +9 -7
  26. birder/net/iformer.py +8 -6
  27. birder/net/levit.py +42 -30
  28. birder/net/lit_v1_tiny.py +15 -0
  29. birder/net/maxvit.py +67 -55
  30. birder/net/mobileone.py +1 -0
  31. birder/net/mvit_v2.py +13 -12
  32. birder/net/pit.py +4 -3
  33. birder/net/pvt_v1.py +4 -1
  34. birder/net/repghost.py +1 -0
  35. birder/net/repvgg.py +1 -0
  36. birder/net/repvit.py +1 -0
  37. birder/net/rope_deit3.py +5 -3
  38. birder/net/rope_flexivit.py +7 -4
  39. birder/net/rope_vit.py +10 -5
  40. birder/net/simple_vit.py +9 -6
  41. birder/net/swin_transformer_v1.py +71 -68
  42. birder/net/swin_transformer_v2.py +38 -31
  43. birder/net/tiny_vit.py +20 -10
  44. birder/net/transnext.py +38 -28
  45. birder/net/vit.py +5 -4
  46. birder/net/vit_parallel.py +5 -4
  47. birder/net/vit_sam.py +38 -37
  48. birder/net/vovnet_v1.py +15 -0
  49. birder/ops/msda.py +108 -43
  50. birder/ops/swattention.py +124 -61
  51. birder/results/detection.py +4 -0
  52. birder/scripts/benchmark.py +21 -12
  53. birder/scripts/predict.py +7 -0
  54. birder/scripts/train.py +39 -13
  55. birder/scripts/train_barlow_twins.py +35 -12
  56. birder/scripts/train_byol.py +35 -12
  57. birder/scripts/train_capi.py +41 -15
  58. birder/scripts/train_data2vec.py +37 -14
  59. birder/scripts/train_data2vec2.py +37 -14
  60. birder/scripts/train_detection.py +36 -11
  61. birder/scripts/train_dino_v1.py +51 -14
  62. birder/scripts/train_dino_v2.py +78 -19
  63. birder/scripts/train_dino_v2_dist.py +76 -17
  64. birder/scripts/train_franca.py +43 -19
  65. birder/scripts/train_i_jepa.py +37 -14
  66. birder/scripts/train_ibot.py +43 -20
  67. birder/scripts/train_kd.py +39 -13
  68. birder/scripts/train_mim.py +35 -12
  69. birder/scripts/train_mmcr.py +35 -12
  70. birder/scripts/train_rotnet.py +36 -13
  71. birder/scripts/train_simclr.py +35 -12
  72. birder/scripts/train_vicreg.py +35 -12
  73. birder/tools/convert_model.py +18 -15
  74. birder/tools/det_results.py +114 -2
  75. birder/tools/quantize_model.py +73 -67
  76. birder/version.py +1 -1
  77. {birder-0.2.3.dist-info → birder-0.3.0.dist-info}/METADATA +2 -1
  78. {birder-0.2.3.dist-info → birder-0.3.0.dist-info}/RECORD +82 -82
  79. {birder-0.2.3.dist-info → birder-0.3.0.dist-info}/WHEEL +0 -0
  80. {birder-0.2.3.dist-info → birder-0.3.0.dist-info}/entry_points.txt +0 -0
  81. {birder-0.2.3.dist-info → birder-0.3.0.dist-info}/licenses/LICENSE +0 -0
  82. {birder-0.2.3.dist-info → birder-0.3.0.dist-info}/top_level.txt +0 -0
@@ -211,6 +211,12 @@ def add_training_schedule_args(parser: argparse.ArgumentParser, default_epochs:
211
211
  group.add_argument(
212
212
  "--stop-epoch", type=int, metavar="N", help="epoch to stop the training at (multi stage training)"
213
213
  )
214
+ group.add_argument(
215
+ "--steps-per-epoch",
216
+ type=int,
217
+ metavar="N",
218
+ help="virtual epoch length in steps, leave unset to use the full dataset",
219
+ )
214
220
  group.add_argument("--warmup-epochs", type=int, metavar="N", help="number of warmup epochs")
215
221
  group.add_argument("--warmup-steps", type=int, metavar="N", help="number of warmup optimizer steps")
216
222
  group.add_argument("--cooldown-epochs", type=int, metavar="N", help="number of cooldown epochs (linear to zero)")
@@ -17,6 +17,7 @@ from typing import Any
17
17
  from typing import Literal
18
18
  from typing import Optional
19
19
  from typing import Sized
20
+ from typing import overload
20
21
 
21
22
  import numpy as np
22
23
  import torch
@@ -70,13 +71,7 @@ class RASampler(torch.utils.data.Sampler):
70
71
  """
71
72
 
72
73
  def __init__(
73
- self,
74
- dataset: Sized,
75
- num_replicas: int,
76
- rank: int,
77
- shuffle: bool,
78
- seed: int = 0,
79
- repetitions: int = 3,
74
+ self, dataset: Sized, num_replicas: int, rank: int, shuffle: bool, seed: int = 0, repetitions: int = 3
80
75
  ) -> None:
81
76
  super().__init__()
82
77
  self.dataset = dataset
@@ -85,12 +80,11 @@ class RASampler(torch.utils.data.Sampler):
85
80
  self.epoch = 0
86
81
  self.num_samples = int(math.ceil(len(self.dataset) * float(repetitions) / self.num_replicas))
87
82
  self.total_size = self.num_samples * self.num_replicas
88
- self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
89
83
  self.shuffle = shuffle
90
84
  self.seed = seed
91
85
  self.repetitions = repetitions
92
86
 
93
- def __iter__(self) -> Iterator[list[int]]:
87
+ def __iter__(self) -> Iterator[int]:
94
88
  if self.shuffle is True:
95
89
  # Deterministically shuffle based on epoch
96
90
  g = torch.Generator()
@@ -100,18 +94,148 @@ class RASampler(torch.utils.data.Sampler):
100
94
  indices = list(range(len(self.dataset)))
101
95
 
102
96
  # Add extra samples to make it evenly divisible
103
- indices = [ele for ele in indices for i in range(self.repetitions)]
104
- indices += indices[: (self.total_size - len(indices))]
105
- assert len(indices) == self.total_size
97
+ indices = [ele for ele in indices for _ in range(self.repetitions)]
98
+ if len(indices) < self.total_size:
99
+ indices += indices[: (self.total_size - len(indices))]
100
+ else:
101
+ indices = indices[: self.total_size]
106
102
 
107
- # Subsample
103
+ # Shard by rank
108
104
  indices = indices[self.rank : self.total_size : self.num_replicas]
109
105
  assert len(indices) == self.num_samples
110
106
 
111
- return iter(indices[: self.num_selected_samples])
107
+ yield from indices
108
+
109
+ def __len__(self) -> int:
110
+ return self.num_samples
111
+
112
+ def set_epoch(self, epoch: int) -> None:
113
+ self.epoch = epoch
114
+
115
+
116
+ class InfiniteSampler(torch.utils.data.Sampler):
117
+ """
118
+ Infinite sampler that loops indefinitely over the dataset
119
+ """
120
+
121
+ def __init__(self, dataset: Sized, shuffle: bool, seed: int = 0) -> None:
122
+ super().__init__()
123
+ self.dataset = dataset
124
+ self.shuffle = shuffle
125
+ self.seed = seed
126
+ self.epoch = 0
127
+
128
+ def __iter__(self) -> Iterator[int]:
129
+ g = torch.Generator()
130
+ while True:
131
+ if self.shuffle is True:
132
+ g.manual_seed(self.seed + self.epoch)
133
+ indices = torch.randperm(len(self.dataset), generator=g).tolist()
134
+ else:
135
+ indices = list(range(len(self.dataset)))
136
+
137
+ yield from indices
138
+
139
+ logger.debug(f"InfiniteSampler finished epoch {self.epoch}")
140
+ self.epoch += 1
141
+
142
+ def __len__(self) -> int:
143
+ return len(self.dataset)
144
+
145
+ def set_epoch(self, epoch: int) -> None:
146
+ self.epoch = epoch
147
+
148
+
149
+ class InfiniteDistributedSampler(torch.utils.data.Sampler):
150
+ """
151
+ Infinite distributed sampler that keeps a continuous shuffled stream per rank
152
+ """
153
+
154
+ def __init__(self, dataset: Sized, num_replicas: int, rank: int, shuffle: bool, seed: int = 0) -> None:
155
+ super().__init__()
156
+ self.dataset = dataset
157
+ self.num_replicas = num_replicas
158
+ self.rank = rank
159
+ self.shuffle = shuffle
160
+ self.seed = seed
161
+ self.epoch = 0
162
+ self.num_samples = int(math.ceil(len(self.dataset) / self.num_replicas))
163
+ self.total_size = self.num_samples * self.num_replicas
164
+
165
+ def __iter__(self) -> Iterator[int]:
166
+ g = torch.Generator()
167
+ while True:
168
+ if self.shuffle is True:
169
+ g.manual_seed(self.seed + self.epoch)
170
+ indices = torch.randperm(len(self.dataset), generator=g).tolist()
171
+ else:
172
+ indices = list(range(len(self.dataset)))
173
+
174
+ if len(indices) < self.total_size:
175
+ indices += indices[: (self.total_size - len(indices))]
176
+ else:
177
+ indices = indices[: self.total_size]
178
+
179
+ indices = indices[self.rank : self.total_size : self.num_replicas]
180
+ assert len(indices) == self.num_samples
181
+
182
+ yield from indices
183
+
184
+ logger.debug(f"InfiniteDistributedSampler finished epoch {self.epoch}")
185
+ self.epoch += 1
112
186
 
113
187
  def __len__(self) -> int:
114
- return self.num_selected_samples
188
+ return self.num_samples
189
+
190
+ def set_epoch(self, epoch: int) -> None:
191
+ self.epoch = epoch
192
+
193
+
194
+ class InfiniteRASampler(torch.utils.data.Sampler):
195
+ """
196
+ Infinite version of the repeated augmentation sampler
197
+ """
198
+
199
+ def __init__(
200
+ self, dataset: Sized, num_replicas: int, rank: int, shuffle: bool, seed: int = 0, repetitions: int = 3
201
+ ) -> None:
202
+ super().__init__()
203
+ self.dataset = dataset
204
+ self.num_replicas = num_replicas
205
+ self.rank = rank
206
+ self.epoch = 0
207
+ self.num_samples = int(math.ceil(len(self.dataset) * float(repetitions) / self.num_replicas))
208
+ self.total_size = self.num_samples * self.num_replicas
209
+ self.shuffle = shuffle
210
+ self.seed = seed
211
+ self.repetitions = repetitions
212
+
213
+ def __iter__(self) -> Iterator[int]:
214
+ g = torch.Generator()
215
+ while True:
216
+ if self.shuffle is True:
217
+ g.manual_seed(self.seed + self.epoch)
218
+ indices = torch.randperm(len(self.dataset), generator=g).tolist()
219
+ else:
220
+ indices = list(range(len(self.dataset)))
221
+
222
+ indices = [ele for ele in indices for _ in range(self.repetitions)]
223
+ if len(indices) < self.total_size:
224
+ indices += indices[: (self.total_size - len(indices))]
225
+ else:
226
+ indices = indices[: self.total_size]
227
+
228
+ # Shard by rank
229
+ indices = indices[self.rank : self.total_size : self.num_replicas]
230
+ assert len(indices) == self.num_samples
231
+
232
+ yield from indices
233
+
234
+ logger.debug(f"InfiniteRASampler finished epoch {self.epoch}")
235
+ self.epoch += 1
236
+
237
+ def __len__(self) -> int:
238
+ return self.num_samples
115
239
 
116
240
  def set_epoch(self, epoch: int) -> None:
117
241
  self.epoch = epoch
@@ -636,27 +760,87 @@ def get_amp_scaler(amp: bool, amp_dtype_str: str) -> tuple[Optional[torch.amp.Gr
636
760
  return (scaler, amp_dtype)
637
761
 
638
762
 
763
+ @overload
639
764
  def get_samplers(
640
- args: argparse.Namespace, training_dataset: torch.utils.data.Dataset, validation_dataset: torch.utils.data.Dataset
641
- ) -> torch.utils.data.Sampler:
642
- if args.distributed is True:
643
- if args.ra_sampler is True:
644
- train_sampler = RASampler(
645
- training_dataset,
646
- num_replicas=args.world_size,
647
- rank=args.rank,
648
- shuffle=True,
649
- repetitions=args.ra_reps,
650
- )
765
+ args: argparse.Namespace,
766
+ training_dataset: torch.utils.data.Dataset,
767
+ validation_dataset: torch.utils.data.Dataset,
768
+ infinite: bool = False,
769
+ ) -> tuple[torch.utils.data.Sampler, torch.utils.data.Sampler]: ...
651
770
 
652
- else:
653
- train_sampler = torch.utils.data.distributed.DistributedSampler(training_dataset, shuffle=True)
654
771
 
655
- validation_sampler = torch.utils.data.distributed.DistributedSampler(validation_dataset, shuffle=False)
772
+ @overload
773
+ def get_samplers(
774
+ args: argparse.Namespace,
775
+ training_dataset: torch.utils.data.Dataset,
776
+ validation_dataset: None = None,
777
+ infinite: bool = False,
778
+ ) -> tuple[torch.utils.data.Sampler, None]: ...
779
+
780
+
781
+ def get_samplers(
782
+ args: argparse.Namespace,
783
+ training_dataset: torch.utils.data.Dataset,
784
+ validation_dataset: Optional[torch.utils.data.Dataset] = None,
785
+ infinite: bool = False,
786
+ ) -> tuple[torch.utils.data.Sampler, Optional[torch.utils.data.Sampler]]:
787
+ if args.seed is None:
788
+ seed = int(torch.empty((), dtype=torch.int64).random_().item())
789
+ if is_dist_available_and_initialized() is True:
790
+ seed_tensor = torch.tensor(seed, dtype=torch.int64).cuda()
791
+ dist.broadcast(seed_tensor, src=0, async_op=False)
792
+ seed = int(seed_tensor.item())
793
+ else:
794
+ seed = args.seed
795
+
796
+ ra_sampler = getattr(args, "ra_sampler", False)
797
+ if args.distributed is True:
798
+ if infinite is True:
799
+ if ra_sampler is True:
800
+ train_sampler = InfiniteRASampler(
801
+ training_dataset,
802
+ num_replicas=args.world_size,
803
+ rank=args.rank,
804
+ shuffle=True,
805
+ seed=seed,
806
+ repetitions=args.ra_reps,
807
+ )
808
+ else:
809
+ train_sampler = InfiniteDistributedSampler(
810
+ training_dataset, num_replicas=args.world_size, rank=args.rank, shuffle=True, seed=seed
811
+ )
812
+ else:
813
+ if ra_sampler is True:
814
+ train_sampler = RASampler(
815
+ training_dataset,
816
+ num_replicas=args.world_size,
817
+ rank=args.rank,
818
+ shuffle=True,
819
+ seed=seed,
820
+ repetitions=args.ra_reps,
821
+ )
822
+ else:
823
+ train_sampler = torch.utils.data.distributed.DistributedSampler(
824
+ training_dataset, shuffle=True, seed=seed
825
+ )
826
+
827
+ if validation_dataset is None:
828
+ validation_sampler = None
829
+ else:
830
+ validation_sampler = torch.utils.data.distributed.DistributedSampler(validation_dataset, shuffle=False)
656
831
 
657
832
  else:
658
- train_sampler = torch.utils.data.RandomSampler(training_dataset)
659
- validation_sampler = torch.utils.data.SequentialSampler(validation_dataset)
833
+ if infinite is True:
834
+ train_sampler = InfiniteSampler(training_dataset, shuffle=True, seed=seed)
835
+ else:
836
+ generator = torch.Generator()
837
+ generator.manual_seed(seed)
838
+ train_sampler = torch.utils.data.RandomSampler(training_dataset, generator=generator)
839
+
840
+ if validation_dataset is None:
841
+ validation_sampler = None
842
+ else:
843
+ validation_sampler = torch.utils.data.SequentialSampler(validation_dataset)
660
844
 
661
845
  return (train_sampler, validation_sampler)
662
846
 
@@ -98,6 +98,7 @@ class BatchRandomResizeCollator(DetectionCollator):
98
98
  if isinstance(boxes, tv_tensors.BoundingBoxes) is False:
99
99
  if boxes.numel() == 0:
100
100
  boxes = boxes.reshape(0, 4)
101
+
101
102
  boxes = tv_tensors.BoundingBoxes(
102
103
  boxes, format=tv_tensors.BoundingBoxFormat.XYXY, canvas_size=F.get_size(image)
103
104
  )
@@ -22,9 +22,19 @@ def make_wds_loader(
22
22
  shuffle: bool = False,
23
23
  *,
24
24
  exact: bool = False,
25
+ infinite: bool = False,
25
26
  ) -> DataLoader:
27
+ assert exact is False or infinite is False
28
+
29
+ if infinite is True:
30
+ dataset_iterable = dataset.repeat()
31
+ elif exact is False:
32
+ dataset_iterable = dataset.repeat()
33
+ else:
34
+ dataset_iterable = dataset
35
+
26
36
  dataloader = wds.WebLoader(
27
- dataset.repeat() if exact is False else dataset,
37
+ dataset_iterable,
28
38
  batch_size=batch_size,
29
39
  num_workers=num_workers,
30
40
  prefetch_factor=prefetch_factor,
@@ -43,7 +53,7 @@ def make_wds_loader(
43
53
  epoch_size = math.ceil(len(dataset) / (batch_size * world_size))
44
54
 
45
55
  dataloader = dataloader.with_length(epoch_size, silent=True)
46
- if exact is False:
56
+ if exact is False and infinite is False:
47
57
  dataloader = dataloader.with_epoch(epoch_size)
48
58
 
49
59
  return dataloader
@@ -14,11 +14,24 @@ logger = logging.getLogger(__name__)
14
14
 
15
15
 
16
16
  _CACHED_KERNELS: dict[str, ModuleType] = {}
17
+ _CUSTOM_KERNELS_ENABLED = True
18
+
19
+
20
+ def set_custom_kernels_enabled(enabled: bool) -> None:
21
+ global _CUSTOM_KERNELS_ENABLED # pylint: disable=global-statement
22
+ _CUSTOM_KERNELS_ENABLED = enabled
23
+
24
+
25
+ def is_custom_kernels_enabled() -> bool:
26
+ if os.environ.get("DISABLE_CUSTOM_KERNELS", "0") == "1":
27
+ return False
28
+
29
+ return _CUSTOM_KERNELS_ENABLED
17
30
 
18
31
 
19
32
  def load_msda() -> Optional[ModuleType]:
20
33
  name = "msda"
21
- if torch.cuda.is_available() is False or os.environ.get("DISABLE_CUSTOM_KERNELS", "0") == "1":
34
+ if torch.cuda.is_available() is False or is_custom_kernels_enabled() is False:
22
35
  return None
23
36
 
24
37
  if name in _CACHED_KERNELS:
@@ -60,7 +73,7 @@ def load_msda() -> Optional[ModuleType]:
60
73
 
61
74
  def load_swattention() -> Optional[ModuleType]:
62
75
  name = "swattention"
63
- if torch.cuda.is_available() is False or os.environ.get("DISABLE_CUSTOM_KERNELS", "0") == "1":
76
+ if torch.cuda.is_available() is False or is_custom_kernels_enabled() is False:
64
77
  return None
65
78
 
66
79
  if name in _CACHED_KERNELS:
@@ -103,7 +116,7 @@ def load_swattention() -> Optional[ModuleType]:
103
116
 
104
117
  def load_soft_nms() -> Optional[ModuleType]:
105
118
  name = "soft_nms"
106
- if os.environ.get("DISABLE_CUSTOM_KERNELS", "0") == "1":
119
+ if is_custom_kernels_enabled() is False:
107
120
  return None
108
121
 
109
122
  if name in _CACHED_KERNELS:
@@ -120,14 +133,6 @@ def load_soft_nms() -> Optional[ModuleType]:
120
133
  soft_nms: Optional[ModuleType] = load(
121
134
  "soft_nms",
122
135
  src_files,
123
- with_cuda=True,
124
- extra_cflags=["-DWITH_CUDA=1"],
125
- extra_cuda_cflags=[
126
- "-DCUDA_HAS_FP16=1",
127
- "-D__CUDA_NO_HALF_OPERATORS__",
128
- "-D__CUDA_NO_HALF_CONVERSIONS__",
129
- "-D__CUDA_NO_HALF2_OPERATORS__",
130
- ],
131
136
  )
132
137
 
133
138
  if soft_nms is not None:
@@ -61,24 +61,23 @@ void update_sorting_order(torch::Tensor& boxes, torch::Tensor& scores, torch::Te
61
61
  std::tie(max_score, t_max_idx) = torch::max(scores.index({Slice(idx + 1, None)}), 0);
62
62
 
63
63
  // max_idx is computed from sliced data, therefore need to convert it to "global" max idx
64
- auto max_idx = t_max_idx.item<int>() + idx + 1;
65
-
66
- if (scores.index({idx}).item<float>() < max_score.item<float>()) {
67
- auto boxes_idx = boxes.index({idx}).clone();
68
- auto boxes_max = boxes.index({max_idx}).clone();
69
- boxes.index({idx}) = boxes_max;
70
- boxes.index({max_idx}) = boxes_idx;
71
-
72
- auto scores_idx = scores.index({idx}).clone();
73
- auto scores_max = scores.index({max_idx}).clone();
74
- scores.index({idx}) = scores_max;
75
- scores.index({max_idx}) = scores_idx;
76
-
77
- auto areas_idx = areas.index({idx}).clone();
78
- auto areas_max = areas.index({max_idx}).clone();
79
- areas.index({idx}) = areas_max;
80
- areas.index({max_idx}) = areas_idx;
81
- }
64
+ auto max_idx = t_max_idx + (idx + 1);
65
+ auto should_swap = scores.index({idx}) < max_score;
66
+
67
+ auto boxes_idx = boxes.index({idx}).clone();
68
+ auto boxes_max = boxes.index({max_idx}).clone();
69
+ boxes.index_put_({idx}, torch::where(should_swap, boxes_max, boxes_idx));
70
+ boxes.index_put_({max_idx}, torch::where(should_swap, boxes_idx, boxes_max));
71
+
72
+ auto scores_idx = scores.index({idx}).clone();
73
+ auto scores_max = scores.index({max_idx}).clone();
74
+ scores.index_put_({idx}, torch::where(should_swap, scores_max, scores_idx));
75
+ scores.index_put_({max_idx}, torch::where(should_swap, scores_idx, scores_max));
76
+
77
+ auto areas_idx = areas.index({idx}).clone();
78
+ auto areas_max = areas.index({max_idx}).clone();
79
+ areas.index_put_({idx}, torch::where(should_swap, areas_max, areas_idx));
80
+ areas.index_put_({max_idx}, torch::where(should_swap, areas_idx, areas_max));
82
81
  }
83
82
 
84
83
  std::tuple<torch::Tensor, torch::Tensor> soft_nms(
birder/net/cait.py CHANGED
@@ -268,14 +268,15 @@ class CaiT(BaseNet):
268
268
  super().adjust_size(new_size)
269
269
 
270
270
  # Add back class tokens
271
- self.pos_embed = nn.Parameter(
272
- adjust_position_embedding(
271
+ with torch.no_grad():
272
+ pos_embed = adjust_position_embedding(
273
273
  self.pos_embed,
274
274
  (old_size[0] // self.patch_size[0], old_size[1] // self.patch_size[1]),
275
275
  (new_size[0] // self.patch_size[0], new_size[1] // self.patch_size[1]),
276
276
  0,
277
277
  )
278
- )
278
+
279
+ self.pos_embed = nn.Parameter(pos_embed)
279
280
 
280
281
 
281
282
  registry.register_model_config(
birder/net/convnext_v1.py CHANGED
@@ -195,6 +195,11 @@ class ConvNeXt_v1(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
195
195
  return self.features(x)
196
196
 
197
197
 
198
+ registry.register_model_config(
199
+ "convnext_v1_nano", # Not in the original v1, taken from v2
200
+ ConvNeXt_v1,
201
+ config={"in_channels": [80, 160, 320, 640], "num_layers": [2, 2, 8, 2], "drop_path_rate": 0.1},
202
+ )
198
203
  registry.register_model_config(
199
204
  "convnext_v1_tiny",
200
205
  ConvNeXt_v1,
birder/net/crossformer.py CHANGED
@@ -98,15 +98,17 @@ class Attention(nn.Module):
98
98
  self.proj_drop = nn.Dropout(proj_drop)
99
99
 
100
100
  def define_bias_table(self) -> None:
101
- position_bias_h = torch.arange(1 - self.group_size[0], self.group_size[0])
102
- position_bias_w = torch.arange(1 - self.group_size[1], self.group_size[1])
101
+ device = next(self.pos.parameters()).device
102
+ position_bias_h = torch.arange(1 - self.group_size[0], self.group_size[0], device=device)
103
+ position_bias_w = torch.arange(1 - self.group_size[1], self.group_size[1], device=device)
103
104
  biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w], indexing="ij")) # 2, 2Wh-1, 2W2-1
104
105
  biases = biases.flatten(1).transpose(0, 1).float()
105
106
  self.biases = nn.Buffer(biases)
106
107
 
107
108
  def define_relative_position_index(self) -> None:
108
- coords_h = torch.arange(self.group_size[0])
109
- coords_w = torch.arange(self.group_size[1])
109
+ device = self.biases.device
110
+ coords_h = torch.arange(self.group_size[0], device=device)
111
+ coords_w = torch.arange(self.group_size[1], device=device)
110
112
  coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij")) # 2, Wh, Ww
111
113
  coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
112
114
  relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
@@ -430,32 +432,33 @@ class CrossFormer(DetectorBackbone):
430
432
 
431
433
  new_patch_resolution = (new_size[0] // self.patch_sizes[0], new_size[1] // self.patch_sizes[0])
432
434
  input_resolution = new_patch_resolution
433
- for mod in self.body.modules():
434
- if isinstance(mod, CrossFormerStage):
435
- for m in mod.modules():
436
- if isinstance(m, PatchMerging):
437
- m.input_resolution = input_resolution
438
- input_resolution = (input_resolution[0] // 2, input_resolution[1] // 2)
439
- elif isinstance(m, CrossFormerBlock):
440
- m.input_resolution = input_resolution
441
-
442
- mod.resolution = input_resolution
443
-
444
- new_group_size = (int(new_size[0] / (2**5)), int(new_size[1] / (2**5)))
445
- for m in self.body.modules():
446
- if isinstance(m, CrossFormerBlock):
447
- m.group_size = new_group_size
448
- if m.input_resolution[0] <= m.group_size[0]:
449
- m.use_lda = False
450
- m.group_size = (m.input_resolution[0], m.group_size[1])
451
- if m.input_resolution[1] <= m.group_size[1]:
452
- m.use_lda = False
453
- m.group_size = (m.group_size[0], m.input_resolution[1])
454
-
455
- elif isinstance(m, Attention):
456
- m.group_size = new_group_size
457
- m.define_bias_table()
458
- m.define_relative_position_index()
435
+ with torch.no_grad():
436
+ for mod in self.body.modules():
437
+ if isinstance(mod, CrossFormerStage):
438
+ for m in mod.modules():
439
+ if isinstance(m, PatchMerging):
440
+ m.input_resolution = input_resolution
441
+ input_resolution = (input_resolution[0] // 2, input_resolution[1] // 2)
442
+ elif isinstance(m, CrossFormerBlock):
443
+ m.input_resolution = input_resolution
444
+
445
+ mod.resolution = input_resolution
446
+
447
+ new_group_size = (int(new_size[0] / (2**5)), int(new_size[1] / (2**5)))
448
+ for m in self.body.modules():
449
+ if isinstance(m, CrossFormerBlock):
450
+ m.group_size = new_group_size
451
+ if m.input_resolution[0] <= m.group_size[0]:
452
+ m.use_lda = False
453
+ m.group_size = (m.input_resolution[0], m.group_size[1])
454
+ if m.input_resolution[1] <= m.group_size[1]:
455
+ m.use_lda = False
456
+ m.group_size = (m.group_size[0], m.input_resolution[1])
457
+
458
+ elif isinstance(m, Attention):
459
+ m.group_size = new_group_size
460
+ m.define_bias_table()
461
+ m.define_relative_position_index()
459
462
 
460
463
 
461
464
  registry.register_model_config(
birder/net/crossvit.py CHANGED
@@ -359,9 +359,10 @@ class CrossViT(BaseNet):
359
359
  old_w = old_size[1] // self.patch_size[i]
360
360
  h = new_size[0] // self.patch_size[i]
361
361
  w = new_size[1] // self.patch_size[i]
362
- self.pos_embed[i] = nn.Parameter(
363
- adjust_position_embedding(self.pos_embed[i], (old_h, old_w), (h, w), num_prefix_tokens=1)
364
- )
362
+ with torch.no_grad():
363
+ pos_embed = adjust_position_embedding(self.pos_embed[i], (old_h, old_w), (h, w), num_prefix_tokens=1)
364
+
365
+ self.pos_embed[i] = nn.Parameter(pos_embed)
365
366
 
366
367
 
367
368
  registry.register_model_config(
birder/net/deit.py CHANGED
@@ -187,14 +187,14 @@ class DeiT(BaseNet):
187
187
  num_prefix_tokens = 2
188
188
 
189
189
  # Add back class tokens
190
- self.pos_embedding = nn.Parameter(
191
- adjust_position_embedding(
190
+ with torch.no_grad():
191
+ pos_embedding = adjust_position_embedding(
192
192
  self.pos_embedding,
193
193
  (old_size[0] // self.patch_size, old_size[1] // self.patch_size),
194
194
  (new_size[0] // self.patch_size, new_size[1] // self.patch_size),
195
195
  num_prefix_tokens,
196
196
  )
197
- )
197
+ self.pos_embedding = nn.Parameter(pos_embedding)
198
198
 
199
199
 
200
200
  registry.register_model_config(
birder/net/deit3.py CHANGED
@@ -355,14 +355,14 @@ class DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedT
355
355
  num_prefix_tokens = 0
356
356
 
357
357
  # Add back class tokens
358
- self.pos_embedding = nn.Parameter(
359
- adjust_position_embedding(
358
+ with torch.no_grad():
359
+ pos_embedding = adjust_position_embedding(
360
360
  self.pos_embedding,
361
361
  (old_size[0] // self.patch_size, old_size[1] // self.patch_size),
362
362
  (new_size[0] // self.patch_size, new_size[1] // self.patch_size),
363
363
  num_prefix_tokens,
364
364
  )
365
- )
365
+ self.pos_embedding = nn.Parameter(pos_embedding)
366
366
 
367
367
 
368
368
  registry.register_model_config(
@@ -757,11 +757,8 @@ class Deformable_DETR(DetectionBaseNet):
757
757
  for s, l, b in zip(scores, labels, boxes):
758
758
  # Non-maximum suppression
759
759
  if self.soft_nms is not None:
760
- # Actually much faster on CPU
761
- device = b.device
762
- (soft_scores, keep) = self.soft_nms(b.cpu(), s.cpu(), l.cpu(), score_threshold=0.001)
763
- keep = keep.to(device)
764
- s[keep] = soft_scores.to(device)
760
+ (soft_scores, keep) = self.soft_nms(b, s, l, score_threshold=0.001)
761
+ s[keep] = soft_scores
765
762
 
766
763
  b = b[keep]
767
764
  s = s[keep]
@@ -465,11 +465,8 @@ class DETR(DetectionBaseNet):
465
465
  for s, l, b in zip(scores, labels, boxes):
466
466
  # Non-maximum suppression
467
467
  if self.soft_nms is not None:
468
- # Actually much faster on CPU
469
- device = b.device
470
- (soft_scores, keep) = self.soft_nms(b.cpu(), s.cpu(), l.cpu(), score_threshold=0.001)
471
- keep = keep.to(device)
472
- s[keep] = soft_scores.to(device)
468
+ (soft_scores, keep) = self.soft_nms(b, s, l, score_threshold=0.001)
469
+ s[keep] = soft_scores
473
470
 
474
471
  b = b[keep]
475
472
  s = s[keep]
@@ -685,13 +685,8 @@ class EfficientDet(DetectionBaseNet):
685
685
 
686
686
  # Non-maximum suppression
687
687
  if self.soft_nms is not None:
688
- # Actually much faster on CPU
689
- device = image_boxes.device
690
- (soft_scores, keep) = self.soft_nms(
691
- image_boxes.cpu(), image_scores.cpu(), image_labels.cpu(), score_threshold=0.001
692
- )
693
- keep = keep.to(device)
694
- image_scores[keep] = soft_scores.to(device)
688
+ (soft_scores, keep) = self.soft_nms(image_boxes, image_scores, image_labels, score_threshold=0.001)
689
+ image_scores[keep] = soft_scores
695
690
  else:
696
691
  keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh)
697
692