birder 0.2.3__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- birder/common/training_cli.py +6 -0
- birder/common/training_utils.py +215 -31
- birder/data/collators/detection.py +1 -0
- birder/data/dataloader/webdataset.py +12 -2
- birder/kernels/load_kernel.py +16 -11
- birder/kernels/soft_nms/soft_nms.cpp +17 -18
- birder/net/cait.py +4 -3
- birder/net/convnext_v1.py +5 -0
- birder/net/crossformer.py +33 -30
- birder/net/crossvit.py +4 -3
- birder/net/deit.py +3 -3
- birder/net/deit3.py +3 -3
- birder/net/detection/deformable_detr.py +2 -5
- birder/net/detection/detr.py +2 -5
- birder/net/detection/efficientdet.py +2 -7
- birder/net/detection/fcos.py +2 -7
- birder/net/detection/retinanet.py +2 -7
- birder/net/detection/rt_detr_v1.py +1 -0
- birder/net/efficientformer_v1.py +15 -9
- birder/net/efficientformer_v2.py +39 -29
- birder/net/efficientvit_msft.py +9 -7
- birder/net/fastvit.py +1 -0
- birder/net/flexivit.py +5 -4
- birder/net/hiera.py +12 -9
- birder/net/hornet.py +9 -7
- birder/net/iformer.py +8 -6
- birder/net/levit.py +42 -30
- birder/net/lit_v1_tiny.py +15 -0
- birder/net/maxvit.py +67 -55
- birder/net/mobileone.py +1 -0
- birder/net/mvit_v2.py +13 -12
- birder/net/pit.py +4 -3
- birder/net/pvt_v1.py +4 -1
- birder/net/repghost.py +1 -0
- birder/net/repvgg.py +1 -0
- birder/net/repvit.py +1 -0
- birder/net/rope_deit3.py +5 -3
- birder/net/rope_flexivit.py +7 -4
- birder/net/rope_vit.py +10 -5
- birder/net/simple_vit.py +9 -6
- birder/net/swin_transformer_v1.py +71 -68
- birder/net/swin_transformer_v2.py +38 -31
- birder/net/tiny_vit.py +20 -10
- birder/net/transnext.py +38 -28
- birder/net/vit.py +5 -4
- birder/net/vit_parallel.py +5 -4
- birder/net/vit_sam.py +38 -37
- birder/net/vovnet_v1.py +15 -0
- birder/ops/msda.py +108 -43
- birder/ops/swattention.py +124 -61
- birder/results/detection.py +4 -0
- birder/scripts/benchmark.py +21 -12
- birder/scripts/predict.py +7 -0
- birder/scripts/train.py +39 -13
- birder/scripts/train_barlow_twins.py +35 -12
- birder/scripts/train_byol.py +35 -12
- birder/scripts/train_capi.py +41 -15
- birder/scripts/train_data2vec.py +37 -14
- birder/scripts/train_data2vec2.py +37 -14
- birder/scripts/train_detection.py +36 -11
- birder/scripts/train_dino_v1.py +51 -14
- birder/scripts/train_dino_v2.py +78 -19
- birder/scripts/train_dino_v2_dist.py +76 -17
- birder/scripts/train_franca.py +43 -19
- birder/scripts/train_i_jepa.py +37 -14
- birder/scripts/train_ibot.py +43 -20
- birder/scripts/train_kd.py +39 -13
- birder/scripts/train_mim.py +35 -12
- birder/scripts/train_mmcr.py +35 -12
- birder/scripts/train_rotnet.py +36 -13
- birder/scripts/train_simclr.py +35 -12
- birder/scripts/train_vicreg.py +35 -12
- birder/tools/convert_model.py +18 -15
- birder/tools/det_results.py +114 -2
- birder/tools/quantize_model.py +73 -67
- birder/version.py +1 -1
- {birder-0.2.3.dist-info → birder-0.3.0.dist-info}/METADATA +2 -1
- {birder-0.2.3.dist-info → birder-0.3.0.dist-info}/RECORD +82 -82
- {birder-0.2.3.dist-info → birder-0.3.0.dist-info}/WHEEL +0 -0
- {birder-0.2.3.dist-info → birder-0.3.0.dist-info}/entry_points.txt +0 -0
- {birder-0.2.3.dist-info → birder-0.3.0.dist-info}/licenses/LICENSE +0 -0
- {birder-0.2.3.dist-info → birder-0.3.0.dist-info}/top_level.txt +0 -0
birder/common/training_cli.py
CHANGED
|
@@ -211,6 +211,12 @@ def add_training_schedule_args(parser: argparse.ArgumentParser, default_epochs:
|
|
|
211
211
|
group.add_argument(
|
|
212
212
|
"--stop-epoch", type=int, metavar="N", help="epoch to stop the training at (multi stage training)"
|
|
213
213
|
)
|
|
214
|
+
group.add_argument(
|
|
215
|
+
"--steps-per-epoch",
|
|
216
|
+
type=int,
|
|
217
|
+
metavar="N",
|
|
218
|
+
help="virtual epoch length in steps, leave unset to use the full dataset",
|
|
219
|
+
)
|
|
214
220
|
group.add_argument("--warmup-epochs", type=int, metavar="N", help="number of warmup epochs")
|
|
215
221
|
group.add_argument("--warmup-steps", type=int, metavar="N", help="number of warmup optimizer steps")
|
|
216
222
|
group.add_argument("--cooldown-epochs", type=int, metavar="N", help="number of cooldown epochs (linear to zero)")
|
birder/common/training_utils.py
CHANGED
|
@@ -17,6 +17,7 @@ from typing import Any
|
|
|
17
17
|
from typing import Literal
|
|
18
18
|
from typing import Optional
|
|
19
19
|
from typing import Sized
|
|
20
|
+
from typing import overload
|
|
20
21
|
|
|
21
22
|
import numpy as np
|
|
22
23
|
import torch
|
|
@@ -70,13 +71,7 @@ class RASampler(torch.utils.data.Sampler):
|
|
|
70
71
|
"""
|
|
71
72
|
|
|
72
73
|
def __init__(
|
|
73
|
-
self,
|
|
74
|
-
dataset: Sized,
|
|
75
|
-
num_replicas: int,
|
|
76
|
-
rank: int,
|
|
77
|
-
shuffle: bool,
|
|
78
|
-
seed: int = 0,
|
|
79
|
-
repetitions: int = 3,
|
|
74
|
+
self, dataset: Sized, num_replicas: int, rank: int, shuffle: bool, seed: int = 0, repetitions: int = 3
|
|
80
75
|
) -> None:
|
|
81
76
|
super().__init__()
|
|
82
77
|
self.dataset = dataset
|
|
@@ -85,12 +80,11 @@ class RASampler(torch.utils.data.Sampler):
|
|
|
85
80
|
self.epoch = 0
|
|
86
81
|
self.num_samples = int(math.ceil(len(self.dataset) * float(repetitions) / self.num_replicas))
|
|
87
82
|
self.total_size = self.num_samples * self.num_replicas
|
|
88
|
-
self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
|
|
89
83
|
self.shuffle = shuffle
|
|
90
84
|
self.seed = seed
|
|
91
85
|
self.repetitions = repetitions
|
|
92
86
|
|
|
93
|
-
def __iter__(self) -> Iterator[
|
|
87
|
+
def __iter__(self) -> Iterator[int]:
|
|
94
88
|
if self.shuffle is True:
|
|
95
89
|
# Deterministically shuffle based on epoch
|
|
96
90
|
g = torch.Generator()
|
|
@@ -100,18 +94,148 @@ class RASampler(torch.utils.data.Sampler):
|
|
|
100
94
|
indices = list(range(len(self.dataset)))
|
|
101
95
|
|
|
102
96
|
# Add extra samples to make it evenly divisible
|
|
103
|
-
indices = [ele for ele in indices for
|
|
104
|
-
indices
|
|
105
|
-
|
|
97
|
+
indices = [ele for ele in indices for _ in range(self.repetitions)]
|
|
98
|
+
if len(indices) < self.total_size:
|
|
99
|
+
indices += indices[: (self.total_size - len(indices))]
|
|
100
|
+
else:
|
|
101
|
+
indices = indices[: self.total_size]
|
|
106
102
|
|
|
107
|
-
#
|
|
103
|
+
# Shard by rank
|
|
108
104
|
indices = indices[self.rank : self.total_size : self.num_replicas]
|
|
109
105
|
assert len(indices) == self.num_samples
|
|
110
106
|
|
|
111
|
-
|
|
107
|
+
yield from indices
|
|
108
|
+
|
|
109
|
+
def __len__(self) -> int:
|
|
110
|
+
return self.num_samples
|
|
111
|
+
|
|
112
|
+
def set_epoch(self, epoch: int) -> None:
|
|
113
|
+
self.epoch = epoch
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
class InfiniteSampler(torch.utils.data.Sampler):
|
|
117
|
+
"""
|
|
118
|
+
Infinite sampler that loops indefinitely over the dataset
|
|
119
|
+
"""
|
|
120
|
+
|
|
121
|
+
def __init__(self, dataset: Sized, shuffle: bool, seed: int = 0) -> None:
|
|
122
|
+
super().__init__()
|
|
123
|
+
self.dataset = dataset
|
|
124
|
+
self.shuffle = shuffle
|
|
125
|
+
self.seed = seed
|
|
126
|
+
self.epoch = 0
|
|
127
|
+
|
|
128
|
+
def __iter__(self) -> Iterator[int]:
|
|
129
|
+
g = torch.Generator()
|
|
130
|
+
while True:
|
|
131
|
+
if self.shuffle is True:
|
|
132
|
+
g.manual_seed(self.seed + self.epoch)
|
|
133
|
+
indices = torch.randperm(len(self.dataset), generator=g).tolist()
|
|
134
|
+
else:
|
|
135
|
+
indices = list(range(len(self.dataset)))
|
|
136
|
+
|
|
137
|
+
yield from indices
|
|
138
|
+
|
|
139
|
+
logger.debug(f"InfiniteSampler finished epoch {self.epoch}")
|
|
140
|
+
self.epoch += 1
|
|
141
|
+
|
|
142
|
+
def __len__(self) -> int:
|
|
143
|
+
return len(self.dataset)
|
|
144
|
+
|
|
145
|
+
def set_epoch(self, epoch: int) -> None:
|
|
146
|
+
self.epoch = epoch
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
class InfiniteDistributedSampler(torch.utils.data.Sampler):
|
|
150
|
+
"""
|
|
151
|
+
Infinite distributed sampler that keeps a continuous shuffled stream per rank
|
|
152
|
+
"""
|
|
153
|
+
|
|
154
|
+
def __init__(self, dataset: Sized, num_replicas: int, rank: int, shuffle: bool, seed: int = 0) -> None:
|
|
155
|
+
super().__init__()
|
|
156
|
+
self.dataset = dataset
|
|
157
|
+
self.num_replicas = num_replicas
|
|
158
|
+
self.rank = rank
|
|
159
|
+
self.shuffle = shuffle
|
|
160
|
+
self.seed = seed
|
|
161
|
+
self.epoch = 0
|
|
162
|
+
self.num_samples = int(math.ceil(len(self.dataset) / self.num_replicas))
|
|
163
|
+
self.total_size = self.num_samples * self.num_replicas
|
|
164
|
+
|
|
165
|
+
def __iter__(self) -> Iterator[int]:
|
|
166
|
+
g = torch.Generator()
|
|
167
|
+
while True:
|
|
168
|
+
if self.shuffle is True:
|
|
169
|
+
g.manual_seed(self.seed + self.epoch)
|
|
170
|
+
indices = torch.randperm(len(self.dataset), generator=g).tolist()
|
|
171
|
+
else:
|
|
172
|
+
indices = list(range(len(self.dataset)))
|
|
173
|
+
|
|
174
|
+
if len(indices) < self.total_size:
|
|
175
|
+
indices += indices[: (self.total_size - len(indices))]
|
|
176
|
+
else:
|
|
177
|
+
indices = indices[: self.total_size]
|
|
178
|
+
|
|
179
|
+
indices = indices[self.rank : self.total_size : self.num_replicas]
|
|
180
|
+
assert len(indices) == self.num_samples
|
|
181
|
+
|
|
182
|
+
yield from indices
|
|
183
|
+
|
|
184
|
+
logger.debug(f"InfiniteDistributedSampler finished epoch {self.epoch}")
|
|
185
|
+
self.epoch += 1
|
|
112
186
|
|
|
113
187
|
def __len__(self) -> int:
|
|
114
|
-
return self.
|
|
188
|
+
return self.num_samples
|
|
189
|
+
|
|
190
|
+
def set_epoch(self, epoch: int) -> None:
|
|
191
|
+
self.epoch = epoch
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
class InfiniteRASampler(torch.utils.data.Sampler):
|
|
195
|
+
"""
|
|
196
|
+
Infinite version of the repeated augmentation sampler
|
|
197
|
+
"""
|
|
198
|
+
|
|
199
|
+
def __init__(
|
|
200
|
+
self, dataset: Sized, num_replicas: int, rank: int, shuffle: bool, seed: int = 0, repetitions: int = 3
|
|
201
|
+
) -> None:
|
|
202
|
+
super().__init__()
|
|
203
|
+
self.dataset = dataset
|
|
204
|
+
self.num_replicas = num_replicas
|
|
205
|
+
self.rank = rank
|
|
206
|
+
self.epoch = 0
|
|
207
|
+
self.num_samples = int(math.ceil(len(self.dataset) * float(repetitions) / self.num_replicas))
|
|
208
|
+
self.total_size = self.num_samples * self.num_replicas
|
|
209
|
+
self.shuffle = shuffle
|
|
210
|
+
self.seed = seed
|
|
211
|
+
self.repetitions = repetitions
|
|
212
|
+
|
|
213
|
+
def __iter__(self) -> Iterator[int]:
|
|
214
|
+
g = torch.Generator()
|
|
215
|
+
while True:
|
|
216
|
+
if self.shuffle is True:
|
|
217
|
+
g.manual_seed(self.seed + self.epoch)
|
|
218
|
+
indices = torch.randperm(len(self.dataset), generator=g).tolist()
|
|
219
|
+
else:
|
|
220
|
+
indices = list(range(len(self.dataset)))
|
|
221
|
+
|
|
222
|
+
indices = [ele for ele in indices for _ in range(self.repetitions)]
|
|
223
|
+
if len(indices) < self.total_size:
|
|
224
|
+
indices += indices[: (self.total_size - len(indices))]
|
|
225
|
+
else:
|
|
226
|
+
indices = indices[: self.total_size]
|
|
227
|
+
|
|
228
|
+
# Shard by rank
|
|
229
|
+
indices = indices[self.rank : self.total_size : self.num_replicas]
|
|
230
|
+
assert len(indices) == self.num_samples
|
|
231
|
+
|
|
232
|
+
yield from indices
|
|
233
|
+
|
|
234
|
+
logger.debug(f"InfiniteRASampler finished epoch {self.epoch}")
|
|
235
|
+
self.epoch += 1
|
|
236
|
+
|
|
237
|
+
def __len__(self) -> int:
|
|
238
|
+
return self.num_samples
|
|
115
239
|
|
|
116
240
|
def set_epoch(self, epoch: int) -> None:
|
|
117
241
|
self.epoch = epoch
|
|
@@ -636,27 +760,87 @@ def get_amp_scaler(amp: bool, amp_dtype_str: str) -> tuple[Optional[torch.amp.Gr
|
|
|
636
760
|
return (scaler, amp_dtype)
|
|
637
761
|
|
|
638
762
|
|
|
763
|
+
@overload
|
|
639
764
|
def get_samplers(
|
|
640
|
-
args: argparse.Namespace,
|
|
641
|
-
|
|
642
|
-
|
|
643
|
-
|
|
644
|
-
|
|
645
|
-
training_dataset,
|
|
646
|
-
num_replicas=args.world_size,
|
|
647
|
-
rank=args.rank,
|
|
648
|
-
shuffle=True,
|
|
649
|
-
repetitions=args.ra_reps,
|
|
650
|
-
)
|
|
765
|
+
args: argparse.Namespace,
|
|
766
|
+
training_dataset: torch.utils.data.Dataset,
|
|
767
|
+
validation_dataset: torch.utils.data.Dataset,
|
|
768
|
+
infinite: bool = False,
|
|
769
|
+
) -> tuple[torch.utils.data.Sampler, torch.utils.data.Sampler]: ...
|
|
651
770
|
|
|
652
|
-
else:
|
|
653
|
-
train_sampler = torch.utils.data.distributed.DistributedSampler(training_dataset, shuffle=True)
|
|
654
771
|
|
|
655
|
-
|
|
772
|
+
@overload
|
|
773
|
+
def get_samplers(
|
|
774
|
+
args: argparse.Namespace,
|
|
775
|
+
training_dataset: torch.utils.data.Dataset,
|
|
776
|
+
validation_dataset: None = None,
|
|
777
|
+
infinite: bool = False,
|
|
778
|
+
) -> tuple[torch.utils.data.Sampler, None]: ...
|
|
779
|
+
|
|
780
|
+
|
|
781
|
+
def get_samplers(
|
|
782
|
+
args: argparse.Namespace,
|
|
783
|
+
training_dataset: torch.utils.data.Dataset,
|
|
784
|
+
validation_dataset: Optional[torch.utils.data.Dataset] = None,
|
|
785
|
+
infinite: bool = False,
|
|
786
|
+
) -> tuple[torch.utils.data.Sampler, Optional[torch.utils.data.Sampler]]:
|
|
787
|
+
if args.seed is None:
|
|
788
|
+
seed = int(torch.empty((), dtype=torch.int64).random_().item())
|
|
789
|
+
if is_dist_available_and_initialized() is True:
|
|
790
|
+
seed_tensor = torch.tensor(seed, dtype=torch.int64).cuda()
|
|
791
|
+
dist.broadcast(seed_tensor, src=0, async_op=False)
|
|
792
|
+
seed = int(seed_tensor.item())
|
|
793
|
+
else:
|
|
794
|
+
seed = args.seed
|
|
795
|
+
|
|
796
|
+
ra_sampler = getattr(args, "ra_sampler", False)
|
|
797
|
+
if args.distributed is True:
|
|
798
|
+
if infinite is True:
|
|
799
|
+
if ra_sampler is True:
|
|
800
|
+
train_sampler = InfiniteRASampler(
|
|
801
|
+
training_dataset,
|
|
802
|
+
num_replicas=args.world_size,
|
|
803
|
+
rank=args.rank,
|
|
804
|
+
shuffle=True,
|
|
805
|
+
seed=seed,
|
|
806
|
+
repetitions=args.ra_reps,
|
|
807
|
+
)
|
|
808
|
+
else:
|
|
809
|
+
train_sampler = InfiniteDistributedSampler(
|
|
810
|
+
training_dataset, num_replicas=args.world_size, rank=args.rank, shuffle=True, seed=seed
|
|
811
|
+
)
|
|
812
|
+
else:
|
|
813
|
+
if ra_sampler is True:
|
|
814
|
+
train_sampler = RASampler(
|
|
815
|
+
training_dataset,
|
|
816
|
+
num_replicas=args.world_size,
|
|
817
|
+
rank=args.rank,
|
|
818
|
+
shuffle=True,
|
|
819
|
+
seed=seed,
|
|
820
|
+
repetitions=args.ra_reps,
|
|
821
|
+
)
|
|
822
|
+
else:
|
|
823
|
+
train_sampler = torch.utils.data.distributed.DistributedSampler(
|
|
824
|
+
training_dataset, shuffle=True, seed=seed
|
|
825
|
+
)
|
|
826
|
+
|
|
827
|
+
if validation_dataset is None:
|
|
828
|
+
validation_sampler = None
|
|
829
|
+
else:
|
|
830
|
+
validation_sampler = torch.utils.data.distributed.DistributedSampler(validation_dataset, shuffle=False)
|
|
656
831
|
|
|
657
832
|
else:
|
|
658
|
-
|
|
659
|
-
|
|
833
|
+
if infinite is True:
|
|
834
|
+
train_sampler = InfiniteSampler(training_dataset, shuffle=True, seed=seed)
|
|
835
|
+
else:
|
|
836
|
+
generator = torch.Generator()
|
|
837
|
+
generator.manual_seed(seed)
|
|
838
|
+
train_sampler = torch.utils.data.RandomSampler(training_dataset, generator=generator)
|
|
839
|
+
|
|
840
|
+
if validation_dataset is None:
|
|
841
|
+
validation_sampler = None
|
|
842
|
+
else:
|
|
843
|
+
validation_sampler = torch.utils.data.SequentialSampler(validation_dataset)
|
|
660
844
|
|
|
661
845
|
return (train_sampler, validation_sampler)
|
|
662
846
|
|
|
@@ -98,6 +98,7 @@ class BatchRandomResizeCollator(DetectionCollator):
|
|
|
98
98
|
if isinstance(boxes, tv_tensors.BoundingBoxes) is False:
|
|
99
99
|
if boxes.numel() == 0:
|
|
100
100
|
boxes = boxes.reshape(0, 4)
|
|
101
|
+
|
|
101
102
|
boxes = tv_tensors.BoundingBoxes(
|
|
102
103
|
boxes, format=tv_tensors.BoundingBoxFormat.XYXY, canvas_size=F.get_size(image)
|
|
103
104
|
)
|
|
@@ -22,9 +22,19 @@ def make_wds_loader(
|
|
|
22
22
|
shuffle: bool = False,
|
|
23
23
|
*,
|
|
24
24
|
exact: bool = False,
|
|
25
|
+
infinite: bool = False,
|
|
25
26
|
) -> DataLoader:
|
|
27
|
+
assert exact is False or infinite is False
|
|
28
|
+
|
|
29
|
+
if infinite is True:
|
|
30
|
+
dataset_iterable = dataset.repeat()
|
|
31
|
+
elif exact is False:
|
|
32
|
+
dataset_iterable = dataset.repeat()
|
|
33
|
+
else:
|
|
34
|
+
dataset_iterable = dataset
|
|
35
|
+
|
|
26
36
|
dataloader = wds.WebLoader(
|
|
27
|
-
|
|
37
|
+
dataset_iterable,
|
|
28
38
|
batch_size=batch_size,
|
|
29
39
|
num_workers=num_workers,
|
|
30
40
|
prefetch_factor=prefetch_factor,
|
|
@@ -43,7 +53,7 @@ def make_wds_loader(
|
|
|
43
53
|
epoch_size = math.ceil(len(dataset) / (batch_size * world_size))
|
|
44
54
|
|
|
45
55
|
dataloader = dataloader.with_length(epoch_size, silent=True)
|
|
46
|
-
if exact is False:
|
|
56
|
+
if exact is False and infinite is False:
|
|
47
57
|
dataloader = dataloader.with_epoch(epoch_size)
|
|
48
58
|
|
|
49
59
|
return dataloader
|
birder/kernels/load_kernel.py
CHANGED
|
@@ -14,11 +14,24 @@ logger = logging.getLogger(__name__)
|
|
|
14
14
|
|
|
15
15
|
|
|
16
16
|
_CACHED_KERNELS: dict[str, ModuleType] = {}
|
|
17
|
+
_CUSTOM_KERNELS_ENABLED = True
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def set_custom_kernels_enabled(enabled: bool) -> None:
|
|
21
|
+
global _CUSTOM_KERNELS_ENABLED # pylint: disable=global-statement
|
|
22
|
+
_CUSTOM_KERNELS_ENABLED = enabled
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def is_custom_kernels_enabled() -> bool:
|
|
26
|
+
if os.environ.get("DISABLE_CUSTOM_KERNELS", "0") == "1":
|
|
27
|
+
return False
|
|
28
|
+
|
|
29
|
+
return _CUSTOM_KERNELS_ENABLED
|
|
17
30
|
|
|
18
31
|
|
|
19
32
|
def load_msda() -> Optional[ModuleType]:
|
|
20
33
|
name = "msda"
|
|
21
|
-
if torch.cuda.is_available() is False or
|
|
34
|
+
if torch.cuda.is_available() is False or is_custom_kernels_enabled() is False:
|
|
22
35
|
return None
|
|
23
36
|
|
|
24
37
|
if name in _CACHED_KERNELS:
|
|
@@ -60,7 +73,7 @@ def load_msda() -> Optional[ModuleType]:
|
|
|
60
73
|
|
|
61
74
|
def load_swattention() -> Optional[ModuleType]:
|
|
62
75
|
name = "swattention"
|
|
63
|
-
if torch.cuda.is_available() is False or
|
|
76
|
+
if torch.cuda.is_available() is False or is_custom_kernels_enabled() is False:
|
|
64
77
|
return None
|
|
65
78
|
|
|
66
79
|
if name in _CACHED_KERNELS:
|
|
@@ -103,7 +116,7 @@ def load_swattention() -> Optional[ModuleType]:
|
|
|
103
116
|
|
|
104
117
|
def load_soft_nms() -> Optional[ModuleType]:
|
|
105
118
|
name = "soft_nms"
|
|
106
|
-
if
|
|
119
|
+
if is_custom_kernels_enabled() is False:
|
|
107
120
|
return None
|
|
108
121
|
|
|
109
122
|
if name in _CACHED_KERNELS:
|
|
@@ -120,14 +133,6 @@ def load_soft_nms() -> Optional[ModuleType]:
|
|
|
120
133
|
soft_nms: Optional[ModuleType] = load(
|
|
121
134
|
"soft_nms",
|
|
122
135
|
src_files,
|
|
123
|
-
with_cuda=True,
|
|
124
|
-
extra_cflags=["-DWITH_CUDA=1"],
|
|
125
|
-
extra_cuda_cflags=[
|
|
126
|
-
"-DCUDA_HAS_FP16=1",
|
|
127
|
-
"-D__CUDA_NO_HALF_OPERATORS__",
|
|
128
|
-
"-D__CUDA_NO_HALF_CONVERSIONS__",
|
|
129
|
-
"-D__CUDA_NO_HALF2_OPERATORS__",
|
|
130
|
-
],
|
|
131
136
|
)
|
|
132
137
|
|
|
133
138
|
if soft_nms is not None:
|
|
@@ -61,24 +61,23 @@ void update_sorting_order(torch::Tensor& boxes, torch::Tensor& scores, torch::Te
|
|
|
61
61
|
std::tie(max_score, t_max_idx) = torch::max(scores.index({Slice(idx + 1, None)}), 0);
|
|
62
62
|
|
|
63
63
|
// max_idx is computed from sliced data, therefore need to convert it to "global" max idx
|
|
64
|
-
auto max_idx = t_max_idx
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
}
|
|
64
|
+
auto max_idx = t_max_idx + (idx + 1);
|
|
65
|
+
auto should_swap = scores.index({idx}) < max_score;
|
|
66
|
+
|
|
67
|
+
auto boxes_idx = boxes.index({idx}).clone();
|
|
68
|
+
auto boxes_max = boxes.index({max_idx}).clone();
|
|
69
|
+
boxes.index_put_({idx}, torch::where(should_swap, boxes_max, boxes_idx));
|
|
70
|
+
boxes.index_put_({max_idx}, torch::where(should_swap, boxes_idx, boxes_max));
|
|
71
|
+
|
|
72
|
+
auto scores_idx = scores.index({idx}).clone();
|
|
73
|
+
auto scores_max = scores.index({max_idx}).clone();
|
|
74
|
+
scores.index_put_({idx}, torch::where(should_swap, scores_max, scores_idx));
|
|
75
|
+
scores.index_put_({max_idx}, torch::where(should_swap, scores_idx, scores_max));
|
|
76
|
+
|
|
77
|
+
auto areas_idx = areas.index({idx}).clone();
|
|
78
|
+
auto areas_max = areas.index({max_idx}).clone();
|
|
79
|
+
areas.index_put_({idx}, torch::where(should_swap, areas_max, areas_idx));
|
|
80
|
+
areas.index_put_({max_idx}, torch::where(should_swap, areas_idx, areas_max));
|
|
82
81
|
}
|
|
83
82
|
|
|
84
83
|
std::tuple<torch::Tensor, torch::Tensor> soft_nms(
|
birder/net/cait.py
CHANGED
|
@@ -268,14 +268,15 @@ class CaiT(BaseNet):
|
|
|
268
268
|
super().adjust_size(new_size)
|
|
269
269
|
|
|
270
270
|
# Add back class tokens
|
|
271
|
-
|
|
272
|
-
adjust_position_embedding(
|
|
271
|
+
with torch.no_grad():
|
|
272
|
+
pos_embed = adjust_position_embedding(
|
|
273
273
|
self.pos_embed,
|
|
274
274
|
(old_size[0] // self.patch_size[0], old_size[1] // self.patch_size[1]),
|
|
275
275
|
(new_size[0] // self.patch_size[0], new_size[1] // self.patch_size[1]),
|
|
276
276
|
0,
|
|
277
277
|
)
|
|
278
|
-
|
|
278
|
+
|
|
279
|
+
self.pos_embed = nn.Parameter(pos_embed)
|
|
279
280
|
|
|
280
281
|
|
|
281
282
|
registry.register_model_config(
|
birder/net/convnext_v1.py
CHANGED
|
@@ -195,6 +195,11 @@ class ConvNeXt_v1(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
|
|
|
195
195
|
return self.features(x)
|
|
196
196
|
|
|
197
197
|
|
|
198
|
+
registry.register_model_config(
|
|
199
|
+
"convnext_v1_nano", # Not in the original v1, taken from v2
|
|
200
|
+
ConvNeXt_v1,
|
|
201
|
+
config={"in_channels": [80, 160, 320, 640], "num_layers": [2, 2, 8, 2], "drop_path_rate": 0.1},
|
|
202
|
+
)
|
|
198
203
|
registry.register_model_config(
|
|
199
204
|
"convnext_v1_tiny",
|
|
200
205
|
ConvNeXt_v1,
|
birder/net/crossformer.py
CHANGED
|
@@ -98,15 +98,17 @@ class Attention(nn.Module):
|
|
|
98
98
|
self.proj_drop = nn.Dropout(proj_drop)
|
|
99
99
|
|
|
100
100
|
def define_bias_table(self) -> None:
|
|
101
|
-
|
|
102
|
-
|
|
101
|
+
device = next(self.pos.parameters()).device
|
|
102
|
+
position_bias_h = torch.arange(1 - self.group_size[0], self.group_size[0], device=device)
|
|
103
|
+
position_bias_w = torch.arange(1 - self.group_size[1], self.group_size[1], device=device)
|
|
103
104
|
biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w], indexing="ij")) # 2, 2Wh-1, 2W2-1
|
|
104
105
|
biases = biases.flatten(1).transpose(0, 1).float()
|
|
105
106
|
self.biases = nn.Buffer(biases)
|
|
106
107
|
|
|
107
108
|
def define_relative_position_index(self) -> None:
|
|
108
|
-
|
|
109
|
-
|
|
109
|
+
device = self.biases.device
|
|
110
|
+
coords_h = torch.arange(self.group_size[0], device=device)
|
|
111
|
+
coords_w = torch.arange(self.group_size[1], device=device)
|
|
110
112
|
coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij")) # 2, Wh, Ww
|
|
111
113
|
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
|
112
114
|
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
|
@@ -430,32 +432,33 @@ class CrossFormer(DetectorBackbone):
|
|
|
430
432
|
|
|
431
433
|
new_patch_resolution = (new_size[0] // self.patch_sizes[0], new_size[1] // self.patch_sizes[0])
|
|
432
434
|
input_resolution = new_patch_resolution
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
m
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
m
|
|
441
|
-
|
|
442
|
-
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
m
|
|
448
|
-
|
|
449
|
-
m.
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
m.
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
|
|
456
|
-
m
|
|
457
|
-
|
|
458
|
-
|
|
435
|
+
with torch.no_grad():
|
|
436
|
+
for mod in self.body.modules():
|
|
437
|
+
if isinstance(mod, CrossFormerStage):
|
|
438
|
+
for m in mod.modules():
|
|
439
|
+
if isinstance(m, PatchMerging):
|
|
440
|
+
m.input_resolution = input_resolution
|
|
441
|
+
input_resolution = (input_resolution[0] // 2, input_resolution[1] // 2)
|
|
442
|
+
elif isinstance(m, CrossFormerBlock):
|
|
443
|
+
m.input_resolution = input_resolution
|
|
444
|
+
|
|
445
|
+
mod.resolution = input_resolution
|
|
446
|
+
|
|
447
|
+
new_group_size = (int(new_size[0] / (2**5)), int(new_size[1] / (2**5)))
|
|
448
|
+
for m in self.body.modules():
|
|
449
|
+
if isinstance(m, CrossFormerBlock):
|
|
450
|
+
m.group_size = new_group_size
|
|
451
|
+
if m.input_resolution[0] <= m.group_size[0]:
|
|
452
|
+
m.use_lda = False
|
|
453
|
+
m.group_size = (m.input_resolution[0], m.group_size[1])
|
|
454
|
+
if m.input_resolution[1] <= m.group_size[1]:
|
|
455
|
+
m.use_lda = False
|
|
456
|
+
m.group_size = (m.group_size[0], m.input_resolution[1])
|
|
457
|
+
|
|
458
|
+
elif isinstance(m, Attention):
|
|
459
|
+
m.group_size = new_group_size
|
|
460
|
+
m.define_bias_table()
|
|
461
|
+
m.define_relative_position_index()
|
|
459
462
|
|
|
460
463
|
|
|
461
464
|
registry.register_model_config(
|
birder/net/crossvit.py
CHANGED
|
@@ -359,9 +359,10 @@ class CrossViT(BaseNet):
|
|
|
359
359
|
old_w = old_size[1] // self.patch_size[i]
|
|
360
360
|
h = new_size[0] // self.patch_size[i]
|
|
361
361
|
w = new_size[1] // self.patch_size[i]
|
|
362
|
-
|
|
363
|
-
adjust_position_embedding(self.pos_embed[i], (old_h, old_w), (h, w), num_prefix_tokens=1)
|
|
364
|
-
|
|
362
|
+
with torch.no_grad():
|
|
363
|
+
pos_embed = adjust_position_embedding(self.pos_embed[i], (old_h, old_w), (h, w), num_prefix_tokens=1)
|
|
364
|
+
|
|
365
|
+
self.pos_embed[i] = nn.Parameter(pos_embed)
|
|
365
366
|
|
|
366
367
|
|
|
367
368
|
registry.register_model_config(
|
birder/net/deit.py
CHANGED
|
@@ -187,14 +187,14 @@ class DeiT(BaseNet):
|
|
|
187
187
|
num_prefix_tokens = 2
|
|
188
188
|
|
|
189
189
|
# Add back class tokens
|
|
190
|
-
|
|
191
|
-
adjust_position_embedding(
|
|
190
|
+
with torch.no_grad():
|
|
191
|
+
pos_embedding = adjust_position_embedding(
|
|
192
192
|
self.pos_embedding,
|
|
193
193
|
(old_size[0] // self.patch_size, old_size[1] // self.patch_size),
|
|
194
194
|
(new_size[0] // self.patch_size, new_size[1] // self.patch_size),
|
|
195
195
|
num_prefix_tokens,
|
|
196
196
|
)
|
|
197
|
-
)
|
|
197
|
+
self.pos_embedding = nn.Parameter(pos_embedding)
|
|
198
198
|
|
|
199
199
|
|
|
200
200
|
registry.register_model_config(
|
birder/net/deit3.py
CHANGED
|
@@ -355,14 +355,14 @@ class DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedT
|
|
|
355
355
|
num_prefix_tokens = 0
|
|
356
356
|
|
|
357
357
|
# Add back class tokens
|
|
358
|
-
|
|
359
|
-
adjust_position_embedding(
|
|
358
|
+
with torch.no_grad():
|
|
359
|
+
pos_embedding = adjust_position_embedding(
|
|
360
360
|
self.pos_embedding,
|
|
361
361
|
(old_size[0] // self.patch_size, old_size[1] // self.patch_size),
|
|
362
362
|
(new_size[0] // self.patch_size, new_size[1] // self.patch_size),
|
|
363
363
|
num_prefix_tokens,
|
|
364
364
|
)
|
|
365
|
-
)
|
|
365
|
+
self.pos_embedding = nn.Parameter(pos_embedding)
|
|
366
366
|
|
|
367
367
|
|
|
368
368
|
registry.register_model_config(
|
|
@@ -757,11 +757,8 @@ class Deformable_DETR(DetectionBaseNet):
|
|
|
757
757
|
for s, l, b in zip(scores, labels, boxes):
|
|
758
758
|
# Non-maximum suppression
|
|
759
759
|
if self.soft_nms is not None:
|
|
760
|
-
|
|
761
|
-
|
|
762
|
-
(soft_scores, keep) = self.soft_nms(b.cpu(), s.cpu(), l.cpu(), score_threshold=0.001)
|
|
763
|
-
keep = keep.to(device)
|
|
764
|
-
s[keep] = soft_scores.to(device)
|
|
760
|
+
(soft_scores, keep) = self.soft_nms(b, s, l, score_threshold=0.001)
|
|
761
|
+
s[keep] = soft_scores
|
|
765
762
|
|
|
766
763
|
b = b[keep]
|
|
767
764
|
s = s[keep]
|
birder/net/detection/detr.py
CHANGED
|
@@ -465,11 +465,8 @@ class DETR(DetectionBaseNet):
|
|
|
465
465
|
for s, l, b in zip(scores, labels, boxes):
|
|
466
466
|
# Non-maximum suppression
|
|
467
467
|
if self.soft_nms is not None:
|
|
468
|
-
|
|
469
|
-
|
|
470
|
-
(soft_scores, keep) = self.soft_nms(b.cpu(), s.cpu(), l.cpu(), score_threshold=0.001)
|
|
471
|
-
keep = keep.to(device)
|
|
472
|
-
s[keep] = soft_scores.to(device)
|
|
468
|
+
(soft_scores, keep) = self.soft_nms(b, s, l, score_threshold=0.001)
|
|
469
|
+
s[keep] = soft_scores
|
|
473
470
|
|
|
474
471
|
b = b[keep]
|
|
475
472
|
s = s[keep]
|
|
@@ -685,13 +685,8 @@ class EfficientDet(DetectionBaseNet):
|
|
|
685
685
|
|
|
686
686
|
# Non-maximum suppression
|
|
687
687
|
if self.soft_nms is not None:
|
|
688
|
-
|
|
689
|
-
|
|
690
|
-
(soft_scores, keep) = self.soft_nms(
|
|
691
|
-
image_boxes.cpu(), image_scores.cpu(), image_labels.cpu(), score_threshold=0.001
|
|
692
|
-
)
|
|
693
|
-
keep = keep.to(device)
|
|
694
|
-
image_scores[keep] = soft_scores.to(device)
|
|
688
|
+
(soft_scores, keep) = self.soft_nms(image_boxes, image_scores, image_labels, score_threshold=0.001)
|
|
689
|
+
image_scores[keep] = soft_scores
|
|
695
690
|
else:
|
|
696
691
|
keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh)
|
|
697
692
|
|