birder 0.2.2-py3-none-any.whl → 0.3.0-py3-none-any.whl
This diff compares the contents of publicly available package versions released to one of the supported registries. The information in this diff is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- birder/common/lib.py +2 -9
- birder/common/training_cli.py +24 -0
- birder/common/training_utils.py +338 -41
- birder/data/collators/detection.py +11 -3
- birder/data/dataloader/webdataset.py +12 -2
- birder/data/datasets/coco.py +8 -10
- birder/data/transforms/detection.py +30 -13
- birder/inference/detection.py +108 -4
- birder/inference/wbf.py +226 -0
- birder/kernels/load_kernel.py +16 -11
- birder/kernels/soft_nms/soft_nms.cpp +17 -18
- birder/net/__init__.py +8 -0
- birder/net/cait.py +4 -3
- birder/net/convnext_v1.py +5 -0
- birder/net/crossformer.py +33 -30
- birder/net/crossvit.py +4 -3
- birder/net/deit.py +3 -3
- birder/net/deit3.py +3 -3
- birder/net/detection/deformable_detr.py +2 -5
- birder/net/detection/detr.py +2 -5
- birder/net/detection/efficientdet.py +67 -93
- birder/net/detection/fcos.py +2 -7
- birder/net/detection/retinanet.py +2 -7
- birder/net/detection/rt_detr_v1.py +2 -0
- birder/net/detection/yolo_anchors.py +205 -0
- birder/net/detection/yolo_v2.py +25 -24
- birder/net/detection/yolo_v3.py +39 -40
- birder/net/detection/yolo_v4.py +28 -26
- birder/net/detection/yolo_v4_tiny.py +24 -20
- birder/net/efficientformer_v1.py +15 -9
- birder/net/efficientformer_v2.py +39 -29
- birder/net/efficientvit_msft.py +9 -7
- birder/net/fasternet.py +1 -1
- birder/net/fastvit.py +1 -0
- birder/net/flexivit.py +5 -4
- birder/net/gc_vit.py +671 -0
- birder/net/hiera.py +12 -9
- birder/net/hornet.py +9 -7
- birder/net/iformer.py +8 -6
- birder/net/levit.py +42 -30
- birder/net/lit_v1.py +472 -0
- birder/net/lit_v1_tiny.py +357 -0
- birder/net/lit_v2.py +436 -0
- birder/net/maxvit.py +67 -55
- birder/net/mobilenet_v4_hybrid.py +1 -1
- birder/net/mobileone.py +1 -0
- birder/net/mvit_v2.py +13 -12
- birder/net/pit.py +4 -3
- birder/net/pvt_v1.py +4 -1
- birder/net/repghost.py +1 -0
- birder/net/repvgg.py +1 -0
- birder/net/repvit.py +1 -0
- birder/net/resnet_v1.py +1 -1
- birder/net/resnext.py +67 -25
- birder/net/rope_deit3.py +5 -3
- birder/net/rope_flexivit.py +7 -4
- birder/net/rope_vit.py +10 -5
- birder/net/se_resnet_v1.py +46 -0
- birder/net/se_resnext.py +3 -0
- birder/net/simple_vit.py +11 -8
- birder/net/swin_transformer_v1.py +71 -68
- birder/net/swin_transformer_v2.py +38 -31
- birder/net/tiny_vit.py +20 -10
- birder/net/transnext.py +38 -28
- birder/net/vit.py +5 -19
- birder/net/vit_parallel.py +5 -4
- birder/net/vit_sam.py +38 -37
- birder/net/vovnet_v1.py +15 -0
- birder/net/vovnet_v2.py +31 -1
- birder/ops/msda.py +108 -43
- birder/ops/swattention.py +124 -61
- birder/results/detection.py +4 -0
- birder/scripts/benchmark.py +110 -32
- birder/scripts/predict.py +8 -0
- birder/scripts/predict_detection.py +18 -11
- birder/scripts/train.py +48 -46
- birder/scripts/train_barlow_twins.py +44 -45
- birder/scripts/train_byol.py +44 -45
- birder/scripts/train_capi.py +50 -49
- birder/scripts/train_data2vec.py +45 -47
- birder/scripts/train_data2vec2.py +45 -47
- birder/scripts/train_detection.py +83 -50
- birder/scripts/train_dino_v1.py +60 -47
- birder/scripts/train_dino_v2.py +86 -52
- birder/scripts/train_dino_v2_dist.py +84 -50
- birder/scripts/train_franca.py +51 -52
- birder/scripts/train_i_jepa.py +45 -47
- birder/scripts/train_ibot.py +51 -53
- birder/scripts/train_kd.py +194 -76
- birder/scripts/train_mim.py +44 -45
- birder/scripts/train_mmcr.py +44 -45
- birder/scripts/train_rotnet.py +45 -46
- birder/scripts/train_simclr.py +44 -45
- birder/scripts/train_vicreg.py +44 -45
- birder/tools/auto_anchors.py +20 -1
- birder/tools/convert_model.py +18 -15
- birder/tools/det_results.py +114 -2
- birder/tools/pack.py +172 -103
- birder/tools/quantize_model.py +73 -67
- birder/tools/show_det_iterator.py +10 -1
- birder/version.py +1 -1
- {birder-0.2.2.dist-info → birder-0.3.0.dist-info}/METADATA +4 -3
- {birder-0.2.2.dist-info → birder-0.3.0.dist-info}/RECORD +107 -101
- {birder-0.2.2.dist-info → birder-0.3.0.dist-info}/WHEEL +0 -0
- {birder-0.2.2.dist-info → birder-0.3.0.dist-info}/entry_points.txt +0 -0
- {birder-0.2.2.dist-info → birder-0.3.0.dist-info}/licenses/LICENSE +0 -0
- {birder-0.2.2.dist-info → birder-0.3.0.dist-info}/top_level.txt +0 -0
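Among the files added above, birder/inference/wbf.py is new in 0.3.0 and, going by its name, adds Weighted Boxes Fusion for merging overlapping detections. The sketch below is a generic, self-contained illustration of the WBF idea (score-weighted averaging of overlapping boxes instead of hard suppression); it is an assumption about the technique only and makes no claim about the module's actual API.

import torch


def box_iou_one_to_many(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # a: (4,), b: (M, 4), both in (x1, y1, x2, y2) format
    lt = torch.maximum(a[:2], b[:, :2])
    rb = torch.minimum(a[2:], b[:, 2:])
    wh = (rb - lt).clamp(min=0)
    inter = wh[:, 0] * wh[:, 1]
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1])
    return inter / (area_a + area_b - inter)


def weighted_boxes_fusion(boxes: torch.Tensor, scores: torch.Tensor, iou_thresh: float = 0.55):
    # Minimal WBF sketch: fuse each box into the best-matching cluster, weighted by score
    order = scores.argsort(descending=True)
    fused_boxes: list[torch.Tensor] = []
    cluster_scores: list[list[float]] = []
    for idx in order:
        box = boxes[idx]
        score = float(scores[idx])
        if len(fused_boxes) > 0:
            ious = box_iou_one_to_many(box, torch.stack(fused_boxes))
            best = int(ious.argmax())
            if float(ious[best]) > iou_thresh:
                # Merge: score-weighted average of the cluster box and the new box
                old_weight = sum(cluster_scores[best])
                fused_boxes[best] = (fused_boxes[best] * old_weight + box * score) / (old_weight + score)
                cluster_scores[best].append(score)
                continue
        fused_boxes.append(box.clone())
        cluster_scores.append([score])

    out_boxes = torch.stack(fused_boxes)
    out_scores = torch.tensor([sum(s) / len(s) for s in cluster_scores])
    return (out_boxes, out_scores)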
birder/net/__init__.py
CHANGED

@@ -31,6 +31,7 @@ from birder.net.fasternet import FasterNet
 from birder.net.fastvit import FastViT
 from birder.net.flexivit import FlexiViT
 from birder.net.focalnet import FocalNet
+from birder.net.gc_vit import GC_ViT
 from birder.net.ghostnet_v1 import GhostNet_v1
 from birder.net.ghostnet_v2 import GhostNet_v2
 from birder.net.groupmixformer import GroupMixFormer
@@ -46,6 +47,9 @@ from birder.net.inception_resnet_v2 import Inception_ResNet_v2
 from birder.net.inception_v3 import Inception_v3
 from birder.net.inception_v4 import Inception_v4
 from birder.net.levit import LeViT
+from birder.net.lit_v1 import LIT_v1
+from birder.net.lit_v1_tiny import LIT_v1_Tiny
+from birder.net.lit_v2 import LIT_v2
 from birder.net.maxvit import MaxViT
 from birder.net.metaformer import MetaFormer
 from birder.net.mnasnet import MNASNet
@@ -143,6 +147,7 @@ __all__ = [
     "FastViT",
     "FlexiViT",
     "FocalNet",
+    "GC_ViT",
     "GhostNet_v1",
     "GhostNet_v2",
     "GroupMixFormer",
@@ -158,6 +163,9 @@ __all__ = [
     "Inception_v3",
     "Inception_v4",
     "LeViT",
+    "LIT_v1",
+    "LIT_v1_Tiny",
+    "LIT_v2",
     "MaxViT",
     "MetaFormer",
     "MNASNet",
birder/net/cait.py
CHANGED

@@ -268,14 +268,15 @@ class CaiT(BaseNet):
         super().adjust_size(new_size)

         # Add back class tokens
-        adjust_position_embedding(
+        with torch.no_grad():
+            pos_embed = adjust_position_embedding(
                 self.pos_embed,
                 (old_size[0] // self.patch_size[0], old_size[1] // self.patch_size[1]),
                 (new_size[0] // self.patch_size[0], new_size[1] // self.patch_size[1]),
                 0,
             )
+
+            self.pos_embed = nn.Parameter(pos_embed)


 registry.register_model_config(
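Several classification nets in this release (CaiT above, and CrossViT, DeiT, and DeiT3 below) now rebuild their position embeddings inside torch.no_grad() and re-register the result as an nn.Parameter when adjust_size() is called. A minimal sketch of that pattern follows, using a hypothetical resize helper; birder's own adjust_position_embedding has a different signature and also handles prefix tokens.

import torch
import torch.nn.functional as F
from torch import nn


def resize_grid_pos_embed(pos_embed: torch.Tensor, old_grid: tuple[int, int], new_grid: tuple[int, int]) -> torch.Tensor:
    # Hypothetical helper, not birder's API: pos_embed is (1, old_h * old_w, dim) without prefix tokens
    (_, _, dim) = pos_embed.shape
    grid = pos_embed.reshape(1, old_grid[0], old_grid[1], dim).permute(0, 3, 1, 2)
    grid = F.interpolate(grid, size=new_grid, mode="bicubic", align_corners=False)
    return grid.permute(0, 2, 3, 1).reshape(1, new_grid[0] * new_grid[1], dim)


class ToyNet(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.pos_embed = nn.Parameter(torch.zeros(1, 14 * 14, 192))

    def adjust_size(self, new_grid: tuple[int, int]) -> None:
        # Same pattern as the diffs: no autograd tracking, then re-wrap as a Parameter
        with torch.no_grad():
            resized = resize_grid_pos_embed(self.pos_embed, (14, 14), new_grid)
            self.pos_embed = nn.Parameter(resized)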
birder/net/convnext_v1.py
CHANGED

@@ -195,6 +195,11 @@ class ConvNeXt_v1(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
         return self.features(x)


+registry.register_model_config(
+    "convnext_v1_nano",  # Not in the original v1, taken from v2
+    ConvNeXt_v1,
+    config={"in_channels": [80, 160, 320, 640], "num_layers": [2, 2, 8, 2], "drop_path_rate": 0.1},
+)
 registry.register_model_config(
     "convnext_v1_tiny",
     ConvNeXt_v1,
birder/net/crossformer.py
CHANGED

@@ -98,15 +98,17 @@ class Attention(nn.Module):
         self.proj_drop = nn.Dropout(proj_drop)

     def define_bias_table(self) -> None:
+        device = next(self.pos.parameters()).device
+        position_bias_h = torch.arange(1 - self.group_size[0], self.group_size[0], device=device)
+        position_bias_w = torch.arange(1 - self.group_size[1], self.group_size[1], device=device)
         biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w], indexing="ij"))  # 2, 2Wh-1, 2W2-1
         biases = biases.flatten(1).transpose(0, 1).float()
         self.biases = nn.Buffer(biases)

     def define_relative_position_index(self) -> None:
+        device = self.biases.device
+        coords_h = torch.arange(self.group_size[0], device=device)
+        coords_w = torch.arange(self.group_size[1], device=device)
         coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij"))  # 2, Wh, Ww
         coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
         relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
@@ -430,32 +432,33 @@ class CrossFormer(DetectorBackbone):

         new_patch_resolution = (new_size[0] // self.patch_sizes[0], new_size[1] // self.patch_sizes[0])
         input_resolution = new_patch_resolution
+        with torch.no_grad():
+            for mod in self.body.modules():
+                if isinstance(mod, CrossFormerStage):
+                    for m in mod.modules():
+                        if isinstance(m, PatchMerging):
+                            m.input_resolution = input_resolution
+                            input_resolution = (input_resolution[0] // 2, input_resolution[1] // 2)
+                        elif isinstance(m, CrossFormerBlock):
+                            m.input_resolution = input_resolution
+
+                    mod.resolution = input_resolution
+
+            new_group_size = (int(new_size[0] / (2**5)), int(new_size[1] / (2**5)))
+            for m in self.body.modules():
+                if isinstance(m, CrossFormerBlock):
+                    m.group_size = new_group_size
+                    if m.input_resolution[0] <= m.group_size[0]:
+                        m.use_lda = False
+                        m.group_size = (m.input_resolution[0], m.group_size[1])
+                    if m.input_resolution[1] <= m.group_size[1]:
+                        m.use_lda = False
+                        m.group_size = (m.group_size[0], m.input_resolution[1])
+
+                elif isinstance(m, Attention):
+                    m.group_size = new_group_size
+                    m.define_bias_table()
+                    m.define_relative_position_index()


 registry.register_model_config(
birder/net/crossvit.py
CHANGED

@@ -359,9 +359,10 @@ class CrossViT(BaseNet):
             old_w = old_size[1] // self.patch_size[i]
             h = new_size[0] // self.patch_size[i]
             w = new_size[1] // self.patch_size[i]
-            adjust_position_embedding(self.pos_embed[i], (old_h, old_w), (h, w), num_prefix_tokens=1)
+            with torch.no_grad():
+                pos_embed = adjust_position_embedding(self.pos_embed[i], (old_h, old_w), (h, w), num_prefix_tokens=1)
+
+                self.pos_embed[i] = nn.Parameter(pos_embed)


 registry.register_model_config(
birder/net/deit.py
CHANGED

@@ -187,14 +187,14 @@ class DeiT(BaseNet):
         num_prefix_tokens = 2

         # Add back class tokens
-        adjust_position_embedding(
+        with torch.no_grad():
+            pos_embedding = adjust_position_embedding(
                 self.pos_embedding,
                 (old_size[0] // self.patch_size, old_size[1] // self.patch_size),
                 (new_size[0] // self.patch_size, new_size[1] // self.patch_size),
                 num_prefix_tokens,
             )
-        )
+            self.pos_embedding = nn.Parameter(pos_embedding)


 registry.register_model_config(
birder/net/deit3.py
CHANGED

@@ -355,14 +355,14 @@ class DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedT
         num_prefix_tokens = 0

         # Add back class tokens
-        adjust_position_embedding(
+        with torch.no_grad():
+            pos_embedding = adjust_position_embedding(
                 self.pos_embedding,
                 (old_size[0] // self.patch_size, old_size[1] // self.patch_size),
                 (new_size[0] // self.patch_size, new_size[1] // self.patch_size),
                 num_prefix_tokens,
             )
-        )
+            self.pos_embedding = nn.Parameter(pos_embedding)


 registry.register_model_config(

birder/net/detection/deformable_detr.py
CHANGED

@@ -757,11 +757,8 @@ class Deformable_DETR(DetectionBaseNet):
             for s, l, b in zip(scores, labels, boxes):
                 # Non-maximum suppression
                 if self.soft_nms is not None:
-                    (soft_scores, keep) = self.soft_nms(b.cpu(), s.cpu(), l.cpu(), score_threshold=0.001)
-                    keep = keep.to(device)
-                    s[keep] = soft_scores.to(device)
+                    (soft_scores, keep) = self.soft_nms(b, s, l, score_threshold=0.001)
+                    s[keep] = soft_scores

                 b = b[keep]
                 s = s[keep]
birder/net/detection/detr.py
CHANGED

@@ -465,11 +465,8 @@ class DETR(DetectionBaseNet):
             for s, l, b in zip(scores, labels, boxes):
                 # Non-maximum suppression
                 if self.soft_nms is not None:
-                    (soft_scores, keep) = self.soft_nms(b.cpu(), s.cpu(), l.cpu(), score_threshold=0.001)
-                    keep = keep.to(device)
-                    s[keep] = soft_scores.to(device)
+                    (soft_scores, keep) = self.soft_nms(b, s, l, score_threshold=0.001)
+                    s[keep] = soft_scores

                 b = b[keep]
                 s = s[keep]

birder/net/detection/efficientdet.py
CHANGED

@@ -83,32 +83,25 @@ class Interpolate2d(nn.Module):

     def __init__(
         self,
-        size: Optional[int | tuple[int, int]] = None,
-        scale_factor: Optional[float | tuple[float, float]] = None,
         mode: str = "nearest",
         align_corners: Optional[bool] = False,
     ) -> None:
         super().__init__()
-        self.size = size
-        self.scale_factor = scale_factor
         self.mode = mode
         self.align_corners = align_corners
         if mode == "nearest":
             self.align_corners = None

-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(self, x: torch.Tensor, size: tuple[int, int]) -> torch.Tensor:
+        size_list = [size[0], size[1]]
+        return F.interpolate(x, size_list, None, self.mode, self.align_corners, recompute_scale_factor=False)


-class ResampleFeatureMap(nn.Sequential):
+class ResampleFeatureMap(nn.Module):
     def __init__(
         self,
         in_channels: int,
         out_channels: int,
-        input_size: tuple[int, int],
-        output_size: tuple[int, int],
         downsample: Literal["max", "bilinear"],
         upsample: Literal["nearest", "bilinear"],
         norm_layer: Optional[Callable[..., nn.Module]],
@@ -116,46 +109,63 @@ class ResampleFeatureMap(nn.Sequential):
         super().__init__()
         self.in_channels = in_channels
         self.out_channels = out_channels
-        self.input_size = input_size
-        self.output_size = output_size
+        self.downsample_mode = downsample

         if in_channels != out_channels:
             # padding = ((stride - 1) + (kernel_size - 1)) // 2
-                bias=False,
-                activation_layer=None,
-            ),
+            self.conv = Conv2dNormActivation(
+                in_channels,
+                out_channels,
+                kernel_size=(1, 1),
+                stride=(1, 1),
+                padding=(0, 0),
+                norm_layer=norm_layer,
+                bias=False,
+                activation_layer=None,
             )
+        else:
+            self.conv = None

+        self.downsample = None
+        if downsample != "max":
+            self.downsample = Interpolate2d(mode=downsample)
+
+        self.upsample = Interpolate2d(mode=upsample)
+
+    def forward(self, x: torch.Tensor, target_size: tuple[int, int]) -> torch.Tensor:
+        if self.conv is not None:
+            x = self.conv(x)
+
+        (in_h, in_w) = x.shape[-2:]
+        (target_h, target_w) = target_size
+        if in_h == target_h and in_w == target_w:
+            return x
+
+        downsample_needed = in_h > target_h or in_w > target_w
+        upsample_needed = in_h < target_h or in_w < target_w
+
+        if downsample_needed is True and upsample_needed is False:
+            if self.downsample_mode == "max":
+                stride_size_h = int((in_h - 1) // target_h + 1)
+                stride_size_w = int((in_w - 1) // target_w + 1)
                 kernel_size = (stride_size_h + 1, stride_size_w + 1)
                 stride = (stride_size_h, stride_size_w)
                 padding = (
                     ((stride[0] - 1) + (kernel_size[0] - 1)) // 2,
                     ((stride[1] - 1) + (kernel_size[1] - 1)) // 2,
                 )
+                return F.max_pool2d(x, kernel_size=kernel_size, stride=stride, padding=padding)

+            if self.downsample is not None:
+                return self.downsample(x, size=target_size)

+        if upsample_needed is True and downsample_needed is False:
+            return self.upsample(x, size=target_size)

+        if self.downsample is not None and self.downsample_mode != "max":
+            return self.downsample(x, size=target_size)

-        if input_size[0] < output_size[0] or input_size[1] < output_size[1]:
-            self.add_module("upsample", Interpolate2d(size=output_size, mode=upsample))
+        return self.upsample(x, size=target_size)


 class FpnCombine(nn.Module):
@@ -164,8 +174,6 @@ class FpnCombine(nn.Module):
         in_channels: list[int],
         fpn_channels: int,
         inputs_offsets: list[int],
-        input_size: list[tuple[int, int]],
-        output_size: tuple[int, int],
         downsample: Literal["max", "bilinear"],
         upsample: Literal["nearest", "bilinear"],
         norm_layer: Optional[Callable[..., nn.Module]],
@@ -173,14 +181,14 @@ class FpnCombine(nn.Module):
     ):
         super().__init__()
         self.weight_method = weight_method
+        self.inputs_offsets = inputs_offsets
+        self.target_offset = inputs_offsets[0]

         self.resample = nn.ModuleDict()
         for offset in inputs_offsets:
             self.resample[str(offset)] = ResampleFeatureMap(
                 in_channels[offset],
                 fpn_channels,
-                input_size=input_size[offset],
-                output_size=output_size,
                 downsample=downsample,
                 upsample=upsample,
                 norm_layer=norm_layer,
@@ -193,10 +201,12 @@ class FpnCombine(nn.Module):

     def forward(self, x: list[torch.Tensor]) -> torch.Tensor:
         dtype = x[0].dtype
+        target = x[self.target_offset]
+        target_size = (int(target.shape[-2]), int(target.shape[-1]))
         nodes = []
         for offset, resample in self.resample.items():
             input_node = x[int(offset)]
-            input_node = resample(input_node)
+            input_node = resample(input_node, target_size=target_size)
             nodes.append(input_node)

         if self.weight_method == "attn":
@@ -231,8 +241,6 @@ class BiFpnLayer(nn.Module):
     def __init__(
         self,
         in_channels: list[int],
-        input_size: list[tuple[int, int]],
-        feat_sizes: list[tuple[int, int]],
         fpn_config: list[dict[str, Any]],
         fpn_channels: int,
         num_levels: int,
@@ -248,8 +256,6 @@ class BiFpnLayer(nn.Module):
                 in_channels,
                 fpn_channels,
                 inputs_offsets=fnode_cfg["inputs_offsets"],
-                input_size=input_size,
-                output_size=feat_sizes[fnode_cfg["feat_level"]],
                 downsample=downsample,
                 upsample=upsample,
                 norm_layer=norm_layer,
@@ -290,9 +296,6 @@ class BiFpn(nn.Module):
 class BiFpn(nn.Module):
     def __init__(
         self,
-        image_size: tuple[int, int],
-        min_level: int,
-        max_level: int,
         num_levels: int,
         backbone_channels: list[int],
         fpn_channels: int,
@@ -300,45 +303,29 @@ class BiFpn(nn.Module):
         bifpn_config: list[dict[str, Any]],
     ):
         super().__init__()
-        input_size = input_size[-num_levels:]
-        prev_feat_size = feat_sizes[min_level]
-        self.resample = nn.ModuleDict()
-        for level in range(num_levels):
-            feat_size = feat_sizes[level + min_level]
-            if level < len(backbone_channels):
-                in_channels = backbone_channels[level]
-                input_size[level] = feat_size
-            else:
-                self.resample[str(level)] = ResampleFeatureMap(
+        self.resample = nn.ModuleList()
+        num_backbone_levels = len(backbone_channels)
+        extra_levels = max(0, num_levels - num_backbone_levels)
+        in_channels = backbone_channels[-1]
+        for _ in range(extra_levels):
+            self.resample.append(
+                ResampleFeatureMap(
                     in_channels=in_channels,
                     out_channels=fpn_channels,
-                    input_size=prev_feat_size,
-                    output_size=feat_size,
                     downsample="max",
                     upsample="nearest",
                     norm_layer=nn.BatchNorm2d,
                 )
-            prev_feat_size = feat_size
+            )
+            in_channels = fpn_channels
+            backbone_channels.append(in_channels)

         self.cells = nn.ModuleList()
         fpn_combine_channels = backbone_channels
         for _ in range(fpn_cell_repeats):
             fpn_combine_channels = fpn_combine_channels + [fpn_channels for _ in bifpn_config]
-            input_size = input_size + [feat_sizes[fc["feat_level"]] for fc in bifpn_config]
             fpn_layer = BiFpnLayer(
                 in_channels=fpn_combine_channels,
-                input_size=input_size,
-                feat_sizes=feat_sizes,
                 fpn_config=bifpn_config,
                 fpn_channels=fpn_channels,
                 num_levels=num_levels,
@@ -348,11 +335,12 @@ class BiFpn(nn.Module):
             )
             self.cells.append(fpn_layer)
             fpn_combine_channels = fpn_combine_channels[-num_levels::]
-            input_size = input_size[-num_levels::]

     def forward(self, x: list[torch.Tensor]) -> list[torch.Tensor]:
+        for resample in self.resample:
+            input_node = x[-1]
+            target_size = ((input_node.shape[-2] - 1) // 2 + 1, (input_node.shape[-1] - 1) // 2 + 1)
+            x.append(resample(input_node, target_size=target_size))

         for cell in self.cells:
             x = cell(x)
@@ -572,9 +560,6 @@ class EfficientDet(DetectionBaseNet):
         self.backbone.return_stages = self.backbone.return_stages[-3:]

         self.bifpn = BiFpn(
-            image_size=self.size,
-            min_level=min_level,
-            max_level=max_level,
             num_levels=num_levels,
             backbone_channels=self.backbone.return_channels,
             fpn_channels=fpn_channels,
@@ -614,12 +599,6 @@ class EfficientDet(DetectionBaseNet):
             num_anchors=self.anchor_generator.num_anchors_per_location()[0],
         )

-    def adjust_size(self, new_size: tuple[int, int]) -> None:
-        if new_size == self.size:
-            return
-
-        raise RuntimeError("Model resizing not supported")
-
     def freeze(self, freeze_classifier: bool = True) -> None:
         for param in self.parameters():
             param.requires_grad = False
@@ -706,13 +685,8 @@ class EfficientDet(DetectionBaseNet):

             # Non-maximum suppression
             if self.soft_nms is not None:
-                (soft_scores, keep) = self.soft_nms(
-                    image_boxes.cpu(), image_scores.cpu(), image_labels.cpu(), score_threshold=0.001
-                )
-                keep = keep.to(device)
-                image_scores[keep] = soft_scores.to(device)
+                (soft_scores, keep) = self.soft_nms(image_boxes, image_scores, image_labels, score_threshold=0.001)
+                image_scores[keep] = soft_scores
             else:
                 keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh)
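The EfficientDet refactor above removes the precomputed input_size/output_size bookkeeping: feature resampling now derives its pooling parameters from the runtime shapes, which is what lets the fixed adjust_size() restriction be dropped. A small sketch of that arithmetic, under the same assumptions the diff shows (ceil-division stride, kernel = stride + 1):

import torch
import torch.nn.functional as F


def downsample_to(x: torch.Tensor, target_size: tuple[int, int]) -> torch.Tensor:
    # Derive max-pool kernel, stride, and padding from the actual feature map size,
    # following the refactored ResampleFeatureMap logic
    (in_h, in_w) = x.shape[-2:]
    (target_h, target_w) = target_size
    stride_h = (in_h - 1) // target_h + 1  # ceil(in_h / target_h)
    stride_w = (in_w - 1) // target_w + 1
    kernel = (stride_h + 1, stride_w + 1)
    stride = (stride_h, stride_w)
    padding = (
        ((stride_h - 1) + (kernel[0] - 1)) // 2,
        ((stride_w - 1) + (kernel[1] - 1)) // 2,
    )
    return F.max_pool2d(x, kernel_size=kernel, stride=stride, padding=padding)


# e.g. a 64x64 feature map pooled to the next BiFPN level
feat = torch.randn(1, 96, 64, 64)
print(downsample_to(feat, (32, 32)).shape)  # torch.Size([1, 96, 32, 32])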
birder/net/detection/fcos.py
CHANGED

@@ -455,13 +455,8 @@ class FCOS(DetectionBaseNet):

             # Non-maximum suppression
             if self.soft_nms is not None:
-                (soft_scores, keep) = self.soft_nms(
-                    image_boxes.cpu(), image_scores.cpu(), image_labels.cpu(), score_threshold=0.001
-                )
-                keep = keep.to(device)
-                image_scores[keep] = soft_scores.to(device)
+                (soft_scores, keep) = self.soft_nms(image_boxes, image_scores, image_labels, score_threshold=0.001)
+                image_scores[keep] = soft_scores
             else:
                 keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh)

birder/net/detection/retinanet.py
CHANGED

@@ -417,13 +417,8 @@ class RetinaNet(DetectionBaseNet):

             # Non-maximum suppression
             if self.soft_nms is not None:
-                (soft_scores, keep) = self.soft_nms(
-                    image_boxes.cpu(), image_scores.cpu(), image_labels.cpu(), score_threshold=0.001
-                )
-                keep = keep.to(device)
-                image_scores[keep] = soft_scores.to(device)
+                (soft_scores, keep) = self.soft_nms(image_boxes, image_scores, image_labels, score_threshold=0.001)
+                image_scores[keep] = soft_scores
             else:
                 keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh)

birder/net/detection/rt_detr_v1.py
CHANGED

@@ -1070,6 +1070,7 @@ class RT_DETR_v1(DetectionBaseNet):
             W = feat.shape[3]
             spatial_shapes.append([H, W])
             level_start_index.append(H * W + level_start_index[-1])
+
         level_start_index.pop()

         detections: list[dict[str, torch.Tensor]] = []
@@ -1086,6 +1087,7 @@

         return (detections, losses)

+    @torch.no_grad()  # type: ignore[untyped-decorator]
     def reparameterize_model(self) -> None:
         if self.reparameterized is True:
             return