birder 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- birder/adversarial/base.py +1 -1
- birder/adversarial/simba.py +4 -4
- birder/common/cli.py +1 -1
- birder/common/fs_ops.py +11 -11
- birder/common/lib.py +2 -2
- birder/common/masking.py +3 -3
- birder/common/training_cli.py +24 -2
- birder/common/training_utils.py +28 -4
- birder/data/collators/detection.py +9 -1
- birder/data/transforms/detection.py +27 -8
- birder/data/transforms/mosaic.py +1 -1
- birder/datahub/classification.py +3 -3
- birder/inference/classification.py +3 -3
- birder/inference/data_parallel.py +1 -1
- birder/inference/detection.py +5 -5
- birder/inference/wbf.py +1 -1
- birder/introspection/attention_rollout.py +5 -5
- birder/introspection/feature_pca.py +4 -4
- birder/introspection/gradcam.py +1 -1
- birder/introspection/guided_backprop.py +2 -2
- birder/introspection/transformer_attribution.py +3 -3
- birder/layers/attention_pool.py +2 -2
- birder/model_registry/model_registry.py +2 -1
- birder/net/__init__.py +2 -0
- birder/net/_rope_vit_configs.py +5 -0
- birder/net/_vit_configs.py +0 -13
- birder/net/alexnet.py +5 -5
- birder/net/base.py +28 -3
- birder/net/biformer.py +17 -17
- birder/net/cait.py +2 -2
- birder/net/cas_vit.py +1 -1
- birder/net/coat.py +15 -15
- birder/net/convnext_v1.py +2 -10
- birder/net/convnext_v1_iso.py +198 -0
- birder/net/convnext_v2.py +2 -10
- birder/net/crossformer.py +9 -9
- birder/net/crossvit.py +1 -1
- birder/net/cspnet.py +1 -1
- birder/net/cswin_transformer.py +10 -10
- birder/net/davit.py +10 -10
- birder/net/deit.py +56 -3
- birder/net/deit3.py +27 -15
- birder/net/detection/__init__.py +4 -0
- birder/net/detection/{yolo_anchors.py → _yolo_anchors.py} +5 -5
- birder/net/detection/base.py +6 -5
- birder/net/detection/deformable_detr.py +26 -28
- birder/net/detection/detr.py +9 -9
- birder/net/detection/efficientdet.py +9 -28
- birder/net/detection/faster_rcnn.py +22 -22
- birder/net/detection/fcos.py +8 -8
- birder/net/detection/plain_detr.py +852 -0
- birder/net/detection/retinanet.py +4 -4
- birder/net/detection/rt_detr_v1.py +81 -25
- birder/net/detection/rt_detr_v2.py +1147 -0
- birder/net/detection/ssd.py +5 -5
- birder/net/detection/yolo_v2.py +12 -12
- birder/net/detection/yolo_v3.py +19 -19
- birder/net/detection/yolo_v4.py +16 -16
- birder/net/detection/yolo_v4_tiny.py +3 -3
- birder/net/edgenext.py +3 -3
- birder/net/edgevit.py +10 -14
- birder/net/efficientformer_v1.py +1 -1
- birder/net/efficientvim.py +9 -9
- birder/net/efficientvit_mit.py +2 -2
- birder/net/efficientvit_msft.py +3 -3
- birder/net/fasternet.py +1 -1
- birder/net/fastvit.py +5 -12
- birder/net/flexivit.py +28 -15
- birder/net/focalnet.py +5 -9
- birder/net/gc_vit.py +11 -11
- birder/net/ghostnet_v1.py +1 -1
- birder/net/ghostnet_v2.py +1 -1
- birder/net/groupmixformer.py +12 -12
- birder/net/hgnet_v1.py +1 -1
- birder/net/hgnet_v2.py +4 -4
- birder/net/hiera.py +6 -6
- birder/net/hieradet.py +9 -9
- birder/net/hornet.py +3 -3
- birder/net/iformer.py +4 -4
- birder/net/inception_next.py +4 -14
- birder/net/levit.py +3 -3
- birder/net/lit_v1.py +13 -15
- birder/net/lit_v1_tiny.py +9 -9
- birder/net/lit_v2.py +14 -15
- birder/net/maxvit.py +10 -22
- birder/net/metaformer.py +2 -2
- birder/net/mim/crossmae.py +5 -5
- birder/net/mim/fcmae.py +3 -5
- birder/net/mim/mae_hiera.py +7 -7
- birder/net/mim/mae_vit.py +3 -5
- birder/net/mim/simmim.py +2 -3
- birder/net/mobilenet_v4_hybrid.py +4 -4
- birder/net/mobileone.py +5 -12
- birder/net/mobilevit_v1.py +2 -2
- birder/net/mobilevit_v2.py +5 -9
- birder/net/mvit_v2.py +24 -24
- birder/net/nextvit.py +2 -2
- birder/net/pit.py +11 -26
- birder/net/pvt_v1.py +4 -4
- birder/net/pvt_v2.py +5 -11
- birder/net/regionvit.py +15 -15
- birder/net/regnet.py +1 -1
- birder/net/repghost.py +4 -5
- birder/net/repvgg.py +3 -5
- birder/net/repvit.py +2 -2
- birder/net/resnest.py +1 -1
- birder/net/rope_deit3.py +29 -15
- birder/net/rope_flexivit.py +28 -15
- birder/net/rope_vit.py +41 -23
- birder/net/sequencer2d.py +3 -4
- birder/net/shufflenet_v1.py +1 -1
- birder/net/shufflenet_v2.py +1 -1
- birder/net/simple_vit.py +47 -5
- birder/net/smt.py +7 -7
- birder/net/ssl/barlow_twins.py +1 -1
- birder/net/ssl/byol.py +2 -2
- birder/net/ssl/capi.py +3 -3
- birder/net/ssl/data2vec2.py +1 -1
- birder/net/ssl/dino_v2.py +11 -1
- birder/net/ssl/franca.py +26 -2
- birder/net/ssl/i_jepa.py +4 -4
- birder/net/ssl/mmcr.py +1 -1
- birder/net/swiftformer.py +1 -1
- birder/net/swin_transformer_v1.py +4 -5
- birder/net/swin_transformer_v2.py +4 -7
- birder/net/tiny_vit.py +3 -3
- birder/net/transnext.py +19 -19
- birder/net/uniformer.py +4 -4
- birder/net/vgg.py +1 -10
- birder/net/vit.py +38 -25
- birder/net/vit_parallel.py +35 -20
- birder/net/vit_sam.py +10 -10
- birder/net/vovnet_v2.py +1 -1
- birder/net/xcit.py +9 -7
- birder/ops/msda.py +4 -4
- birder/ops/swattention.py +10 -10
- birder/results/classification.py +3 -3
- birder/results/gui.py +8 -8
- birder/scripts/benchmark.py +37 -12
- birder/scripts/evaluate.py +1 -1
- birder/scripts/predict.py +3 -3
- birder/scripts/predict_detection.py +2 -2
- birder/scripts/train.py +63 -15
- birder/scripts/train_barlow_twins.py +10 -7
- birder/scripts/train_byol.py +10 -7
- birder/scripts/train_capi.py +15 -10
- birder/scripts/train_data2vec.py +10 -7
- birder/scripts/train_data2vec2.py +10 -7
- birder/scripts/train_detection.py +29 -14
- birder/scripts/train_dino_v1.py +13 -9
- birder/scripts/train_dino_v2.py +27 -14
- birder/scripts/train_dino_v2_dist.py +28 -15
- birder/scripts/train_franca.py +16 -9
- birder/scripts/train_i_jepa.py +12 -9
- birder/scripts/train_ibot.py +15 -11
- birder/scripts/train_kd.py +64 -17
- birder/scripts/train_mim.py +11 -8
- birder/scripts/train_mmcr.py +11 -8
- birder/scripts/train_rotnet.py +11 -7
- birder/scripts/train_simclr.py +10 -7
- birder/scripts/train_vicreg.py +10 -7
- birder/tools/adversarial.py +4 -4
- birder/tools/auto_anchors.py +5 -5
- birder/tools/avg_model.py +1 -1
- birder/tools/convert_model.py +30 -22
- birder/tools/det_results.py +1 -1
- birder/tools/download_model.py +1 -1
- birder/tools/ensemble_model.py +1 -1
- birder/tools/introspection.py +11 -2
- birder/tools/labelme_to_coco.py +2 -2
- birder/tools/model_info.py +12 -14
- birder/tools/pack.py +8 -8
- birder/tools/quantize_model.py +53 -4
- birder/tools/results.py +2 -2
- birder/tools/show_det_iterator.py +19 -6
- birder/tools/show_iterator.py +2 -2
- birder/tools/similarity.py +5 -5
- birder/tools/stats.py +4 -6
- birder/tools/voc_to_coco.py +1 -1
- birder/version.py +1 -1
- {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/METADATA +3 -3
- birder-0.4.1.dist-info/RECORD +300 -0
- {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/WHEEL +1 -1
- birder-0.4.0.dist-info/RECORD +0 -297
- {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/entry_points.txt +0 -0
- {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/licenses/LICENSE +0 -0
- {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/top_level.txt +0 -0
birder/net/convnext_v2.py
CHANGED
|
@@ -56,15 +56,7 @@ class ConvNeXtBlock(nn.Module):
|
|
|
56
56
|
) -> None:
|
|
57
57
|
super().__init__()
|
|
58
58
|
self.block = nn.Sequential(
|
|
59
|
-
nn.Conv2d(
|
|
60
|
-
channels,
|
|
61
|
-
channels,
|
|
62
|
-
kernel_size=(7, 7),
|
|
63
|
-
stride=(1, 1),
|
|
64
|
-
padding=(3, 3),
|
|
65
|
-
groups=channels,
|
|
66
|
-
bias=True,
|
|
67
|
-
),
|
|
59
|
+
nn.Conv2d(channels, channels, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=channels),
|
|
68
60
|
Permute([0, 2, 3, 1]),
|
|
69
61
|
nn.LayerNorm(channels, eps=1e-6),
|
|
70
62
|
nn.Linear(channels, 4 * channels), # Same as 1x1 conv
|
|
@@ -137,7 +129,7 @@ class ConvNeXt_v2(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
|
|
|
137
129
|
layers.append(
|
|
138
130
|
nn.Sequential(
|
|
139
131
|
LayerNorm2d(i, eps=1e-6),
|
|
140
|
-
nn.Conv2d(i, out, kernel_size=(2, 2), stride=(2, 2), padding=(0, 0), bias=True),
|
|
132
|
+
nn.Conv2d(i, out, kernel_size=(2, 2), stride=(2, 2), padding=(0, 0)),
|
|
141
133
|
)
|
|
142
134
|
)
|
|
143
135
|
|
birder/net/crossformer.py
CHANGED
|
@@ -120,9 +120,9 @@ class Attention(nn.Module):
|
|
|
120
120
|
self.relative_position_index = nn.Buffer(relative_position_index)
|
|
121
121
|
|
|
122
122
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
123
|
-
|
|
123
|
+
B, N, C = x.size()
|
|
124
124
|
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
|
125
|
-
|
|
125
|
+
q, k, v = qkv.unbind(0)
|
|
126
126
|
|
|
127
127
|
q = q * self.scale
|
|
128
128
|
attn = q @ k.transpose(-2, -1)
|
|
@@ -188,15 +188,15 @@ class CrossFormerBlock(nn.Module):
|
|
|
188
188
|
self.drop_path = StochasticDepth(drop_path, mode="row")
|
|
189
189
|
|
|
190
190
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
191
|
-
|
|
192
|
-
|
|
191
|
+
H, W = self.input_resolution
|
|
192
|
+
B, _, C = x.size()
|
|
193
193
|
|
|
194
194
|
shortcut = x
|
|
195
195
|
x = self.norm1(x)
|
|
196
196
|
x = x.view(B, H, W, C)
|
|
197
197
|
|
|
198
198
|
# Group embeddings
|
|
199
|
-
|
|
199
|
+
GH, GW = self.group_size # pylint: disable=invalid-name
|
|
200
200
|
if self.use_lda is False:
|
|
201
201
|
x = x.reshape(B, H // GH, GH, W // GW, GW, C).permute(0, 1, 3, 2, 4, 5)
|
|
202
202
|
else:
|
|
@@ -244,8 +244,8 @@ class PatchMerging(nn.Module):
|
|
|
244
244
|
)
|
|
245
245
|
|
|
246
246
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
247
|
-
|
|
248
|
-
|
|
247
|
+
H, W = self.input_resolution
|
|
248
|
+
B, _, C = x.shape
|
|
249
249
|
|
|
250
250
|
x = self.norm(x)
|
|
251
251
|
x = x.view(B, H, W, C).permute(0, 3, 1, 2)
|
|
@@ -396,8 +396,8 @@ class CrossFormer(DetectorBackbone):
|
|
|
396
396
|
for name, module in self.body.named_children():
|
|
397
397
|
x = module(x)
|
|
398
398
|
if name in self.return_stages:
|
|
399
|
-
|
|
400
|
-
|
|
399
|
+
H, W = module.resolution
|
|
400
|
+
B, _, C = x.size()
|
|
401
401
|
out[name] = x.view(B, H, W, C).permute(0, 3, 1, 2).contiguous()
|
|
402
402
|
|
|
403
403
|
return out
|
birder/net/crossvit.py
CHANGED
|
@@ -74,7 +74,7 @@ class CrossAttention(nn.Module):
|
|
|
74
74
|
self.proj_drop = nn.Dropout(proj_drop)
|
|
75
75
|
|
|
76
76
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
77
|
-
|
|
77
|
+
B, N, C = x.shape
|
|
78
78
|
# B1C -> B1H(C/H) -> BH1(C/H)
|
|
79
79
|
q = self.wq(x[:, 0:1, ...]).reshape(B, 1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
|
|
80
80
|
# BNC -> BNH(C/H) -> BHN(C/H)
|
birder/net/cspnet.py
CHANGED
|
@@ -226,7 +226,7 @@ class CrossStage(nn.Module):
|
|
|
226
226
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
227
227
|
x = self.conv_down(x)
|
|
228
228
|
x = self.conv_exp(x)
|
|
229
|
-
|
|
229
|
+
xs, xb = x.split(self.expand_channels // 2, dim=1)
|
|
230
230
|
xb = self.blocks(xb)
|
|
231
231
|
xb = self.conv_transition_b(xb).contiguous()
|
|
232
232
|
out = self.conv_transition(torch.concat([xs, xb], dim=1))
|
birder/net/cswin_transformer.py
CHANGED
|
@@ -29,7 +29,7 @@ from birder.net.vit import PatchEmbed
|
|
|
29
29
|
|
|
30
30
|
|
|
31
31
|
def img2windows(img: torch.Tensor, h_sp: int, w_sp: int) -> torch.Tensor:
|
|
32
|
-
|
|
32
|
+
B, C, H, W = img.size()
|
|
33
33
|
img_reshape = img.view(B, C, H // h_sp, h_sp, W // w_sp, w_sp)
|
|
34
34
|
img_perm = img_reshape.permute(0, 2, 4, 3, 5, 1).contiguous().reshape(-1, h_sp * w_sp, C)
|
|
35
35
|
|
|
@@ -81,7 +81,7 @@ class LePEAttention(nn.Module):
|
|
|
81
81
|
raise ValueError("unsupported idx")
|
|
82
82
|
|
|
83
83
|
def im2cswin(self, x: torch.Tensor) -> torch.Tensor:
|
|
84
|
-
|
|
84
|
+
B, _, C = x.size()
|
|
85
85
|
x = x.transpose(-2, -1).contiguous().view(B, C, self.resolution[0], self.resolution[1])
|
|
86
86
|
x = img2windows(x, self.h_sp, self.w_sp)
|
|
87
87
|
x = x.reshape(-1, self.h_sp * self.w_sp, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3).contiguous()
|
|
@@ -89,7 +89,7 @@ class LePEAttention(nn.Module):
|
|
|
89
89
|
return x
|
|
90
90
|
|
|
91
91
|
def get_lepe(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
|
92
|
-
|
|
92
|
+
B, _, C = x.size()
|
|
93
93
|
H = self.resolution[0]
|
|
94
94
|
W = self.resolution[1]
|
|
95
95
|
x = x.transpose(-2, -1).contiguous().view(B, C, H, W)
|
|
@@ -107,13 +107,13 @@ class LePEAttention(nn.Module):
|
|
|
107
107
|
return (x, lepe)
|
|
108
108
|
|
|
109
109
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
110
|
-
|
|
110
|
+
q, k, v = x.unbind(0)
|
|
111
111
|
|
|
112
|
-
|
|
112
|
+
B, _, C = q.shape
|
|
113
113
|
|
|
114
114
|
q = self.im2cswin(q)
|
|
115
115
|
k = self.im2cswin(k)
|
|
116
|
-
|
|
116
|
+
v, lepe = self.get_lepe(v)
|
|
117
117
|
|
|
118
118
|
q = q * self.scale
|
|
119
119
|
attn = q @ k.transpose(-2, -1) # B head N C @ B head C N --> B head N N
|
|
@@ -136,12 +136,12 @@ class MergeBlock(nn.Module):
|
|
|
136
136
|
self.resolution = resolution
|
|
137
137
|
|
|
138
138
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
139
|
-
|
|
139
|
+
B, _, C = x.size()
|
|
140
140
|
H = self.resolution[0]
|
|
141
141
|
W = self.resolution[1]
|
|
142
142
|
x = x.transpose(-2, -1).contiguous().view(B, C, H, W)
|
|
143
143
|
x = self.conv(x)
|
|
144
|
-
|
|
144
|
+
B, C = x.shape[:2]
|
|
145
145
|
x = x.view(B, C, -1).transpose(-2, -1).contiguous()
|
|
146
146
|
x = self.norm(x)
|
|
147
147
|
|
|
@@ -206,7 +206,7 @@ class CSWinBlock(nn.Module):
|
|
|
206
206
|
self.drop_path = StochasticDepth(drop_path, mode="row")
|
|
207
207
|
|
|
208
208
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
209
|
-
|
|
209
|
+
B, _, C = x.shape
|
|
210
210
|
|
|
211
211
|
qkv = self.qkv(self.norm1(x)).reshape(B, -1, 3, C).permute(2, 0, 1, 3)
|
|
212
212
|
if self.branch_num == 2:
|
|
@@ -350,7 +350,7 @@ class CSWin_Transformer(DetectorBackbone):
|
|
|
350
350
|
for name, module in self.body.named_children():
|
|
351
351
|
x = module(x)
|
|
352
352
|
if name in self.return_stages:
|
|
353
|
-
|
|
353
|
+
B, L, C = x.size()
|
|
354
354
|
H = int(math.sqrt(L))
|
|
355
355
|
W = H
|
|
356
356
|
out[name] = x.transpose(-2, -1).contiguous().view(B, C, H, W)
|
birder/net/davit.py
CHANGED
|
@@ -31,7 +31,7 @@ from birder.net.base import TokenRetentionResultType
|
|
|
31
31
|
|
|
32
32
|
|
|
33
33
|
def window_partition(x: torch.Tensor, window_size: tuple[int, int]) -> torch.Tensor:
|
|
34
|
-
|
|
34
|
+
B, H, W, C = x.shape
|
|
35
35
|
x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C)
|
|
36
36
|
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C)
|
|
37
37
|
|
|
@@ -92,10 +92,10 @@ class Downsample(nn.Module):
|
|
|
92
92
|
)
|
|
93
93
|
|
|
94
94
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
95
|
-
|
|
95
|
+
_, _, H, W = x.shape
|
|
96
96
|
x = self.norm(x)
|
|
97
97
|
if self.even_k is True:
|
|
98
|
-
|
|
98
|
+
k_h, k_w = self.conv.kernel_size
|
|
99
99
|
pad_r = (k_w - W % k_w) % k_w
|
|
100
100
|
pad_b = (k_h - H % k_h) % k_h
|
|
101
101
|
x = F.pad(x, (0, pad_r, 0, pad_b))
|
|
@@ -115,10 +115,10 @@ class ChannelAttention(nn.Module):
|
|
|
115
115
|
self.proj = nn.Linear(dim, dim)
|
|
116
116
|
|
|
117
117
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
118
|
-
|
|
118
|
+
B, N, C = x.shape
|
|
119
119
|
|
|
120
120
|
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
|
121
|
-
|
|
121
|
+
q, k, v = qkv.unbind(0)
|
|
122
122
|
|
|
123
123
|
k = k * self.scale
|
|
124
124
|
attn = k.transpose(-1, -2) @ v
|
|
@@ -151,7 +151,7 @@ class ChannelBlock(nn.Module):
|
|
|
151
151
|
self.drop_path = StochasticDepth(drop_path, mode="row")
|
|
152
152
|
|
|
153
153
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
154
|
-
|
|
154
|
+
B, C, H, W = x.shape
|
|
155
155
|
x = self.cpe1(x).flatten(2).transpose(1, 2)
|
|
156
156
|
|
|
157
157
|
cur = self.norm1(x)
|
|
@@ -177,10 +177,10 @@ class WindowAttention(nn.Module):
|
|
|
177
177
|
self.proj = nn.Linear(dim, dim)
|
|
178
178
|
|
|
179
179
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
180
|
-
|
|
180
|
+
B, N, C = x.shape
|
|
181
181
|
|
|
182
182
|
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
|
183
|
-
|
|
183
|
+
q, k, v = qkv.unbind(0)
|
|
184
184
|
|
|
185
185
|
x = F.scaled_dot_product_attention(q, k, v, scale=self.scale) # pylint: disable=not-callable
|
|
186
186
|
x = x.transpose(1, 2).reshape(B, N, C)
|
|
@@ -215,7 +215,7 @@ class SpatialBlock(nn.Module):
|
|
|
215
215
|
|
|
216
216
|
# pylint: disable=invalid-name
|
|
217
217
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
218
|
-
|
|
218
|
+
B, C, H, W = x.shape
|
|
219
219
|
|
|
220
220
|
shortcut = self.cpe1(x).flatten(2).transpose(1, 2)
|
|
221
221
|
|
|
@@ -226,7 +226,7 @@ class SpatialBlock(nn.Module):
|
|
|
226
226
|
pad_r = (self.window_size[1] - W % self.window_size[1]) % self.window_size[1]
|
|
227
227
|
pad_b = (self.window_size[0] - H % self.window_size[0]) % self.window_size[0]
|
|
228
228
|
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
|
|
229
|
-
|
|
229
|
+
_, Hp, Wp, _ = x.shape
|
|
230
230
|
|
|
231
231
|
x_windows = window_partition(x, self.window_size)
|
|
232
232
|
x_windows = x_windows.view(-1, self.window_size[0] * self.window_size[1], C)
|
birder/net/deit.py
CHANGED
|
@@ -19,13 +19,15 @@ from birder.model_registry import registry
|
|
|
19
19
|
from birder.net._vit_configs import BASE
|
|
20
20
|
from birder.net._vit_configs import SMALL
|
|
21
21
|
from birder.net._vit_configs import TINY
|
|
22
|
-
from birder.net.base import BaseNet
|
|
22
|
+
from birder.net.base import DetectorBackbone
|
|
23
|
+
from birder.net.base import normalize_out_indices
|
|
23
24
|
from birder.net.vit import Encoder
|
|
24
25
|
from birder.net.vit import PatchEmbed
|
|
25
26
|
from birder.net.vit import adjust_position_embedding
|
|
26
27
|
|
|
27
28
|
|
|
28
|
-
class DeiT(BaseNet):
|
29
|
+
# pylint: disable=too-many-instance-attributes
|
|
30
|
+
class DeiT(DetectorBackbone):
|
|
29
31
|
block_group_regex = r"encoder\.block\.(\d+)"
|
|
30
32
|
|
|
31
33
|
def __init__(
|
|
@@ -47,6 +49,7 @@ class DeiT(BaseNet):
|
|
|
47
49
|
num_heads: int = self.config["num_heads"]
|
|
48
50
|
hidden_dim: int = self.config["hidden_dim"]
|
|
49
51
|
mlp_dim: int = self.config["mlp_dim"]
|
|
52
|
+
out_indices: Optional[list[int]] = self.config.get("out_indices", None)
|
|
50
53
|
drop_path_rate: float = self.config["drop_path_rate"]
|
|
51
54
|
|
|
52
55
|
torch._assert(image_size[0] % patch_size == 0, "Input shape indivisible by patch size!")
|
|
@@ -56,6 +59,7 @@ class DeiT(BaseNet):
|
|
|
56
59
|
self.num_layers = num_layers
|
|
57
60
|
self.hidden_dim = hidden_dim
|
|
58
61
|
self.num_special_tokens = 2
|
|
62
|
+
self.out_indices = normalize_out_indices(out_indices, num_layers)
|
|
59
63
|
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_layers)] # Stochastic depth decay rule
|
|
60
64
|
|
|
61
65
|
self.conv_proj = nn.Conv2d(
|
|
@@ -64,7 +68,6 @@ class DeiT(BaseNet):
|
|
|
64
68
|
kernel_size=(patch_size, patch_size),
|
|
65
69
|
stride=(patch_size, patch_size),
|
|
66
70
|
padding=(0, 0),
|
|
67
|
-
bias=True,
|
|
68
71
|
)
|
|
69
72
|
self.patch_embed = PatchEmbed()
|
|
70
73
|
|
|
@@ -92,6 +95,9 @@ class DeiT(BaseNet):
|
|
|
92
95
|
)
|
|
93
96
|
self.norm = nn.LayerNorm(hidden_dim, eps=1e-6)
|
|
94
97
|
|
|
98
|
+
num_return_stages = len(self.out_indices) if self.out_indices is not None else 1
|
|
99
|
+
self.return_stages = [f"stage{stage_idx + 1}" for stage_idx in range(num_return_stages)]
|
|
100
|
+
self.return_channels = [hidden_dim] * num_return_stages
|
|
95
101
|
self.embedding_size = hidden_dim
|
|
96
102
|
self.dist_classifier = self.create_classifier()
|
|
97
103
|
self.classifier = self.create_classifier()
|
|
@@ -136,6 +142,53 @@ class DeiT(BaseNet):
|
|
|
136
142
|
def set_causal_attention(self, is_causal: bool = True) -> None:
|
|
137
143
|
self.encoder.set_causal_attention(is_causal)
|
|
138
144
|
|
|
145
|
+
def transform_to_backbone(self) -> None:
|
|
146
|
+
super().transform_to_backbone()
|
|
147
|
+
self.norm = nn.Identity()
|
|
148
|
+
self.dist_classifier = nn.Identity()
|
|
149
|
+
|
|
150
|
+
def detection_features(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
|
|
151
|
+
H, W = x.shape[-2:]
|
|
152
|
+
|
|
153
|
+
# Reshape and permute the input tensor
|
|
154
|
+
x = self.conv_proj(x)
|
|
155
|
+
x = self.patch_embed(x)
|
|
156
|
+
|
|
157
|
+
# Expand the class token to the full batch
|
|
158
|
+
batch_class_token = self.class_token.expand(x.shape[0], -1, -1)
|
|
159
|
+
batch_dist_token = self.dist_token.expand(x.shape[0], -1, -1)
|
|
160
|
+
|
|
161
|
+
x = torch.concat([batch_class_token, batch_dist_token, x], dim=1)
|
|
162
|
+
x = x + self.pos_embedding
|
|
163
|
+
|
|
164
|
+
if self.out_indices is None:
|
|
165
|
+
xs = [self.encoder(x)]
|
|
166
|
+
else:
|
|
167
|
+
xs = self.encoder.forward_features(x, out_indices=self.out_indices)
|
|
168
|
+
|
|
169
|
+
out: dict[str, torch.Tensor] = {}
|
|
170
|
+
for stage_name, stage_x in zip(self.return_stages, xs):
|
|
171
|
+
stage_x = stage_x[:, self.num_special_tokens :]
|
|
172
|
+
stage_x = stage_x.permute(0, 2, 1)
|
|
173
|
+
B, C, _ = stage_x.size()
|
|
174
|
+
stage_x = stage_x.reshape(B, C, H // self.patch_size, W // self.patch_size)
|
|
175
|
+
out[stage_name] = stage_x
|
|
176
|
+
|
|
177
|
+
return out
|
|
178
|
+
|
|
179
|
+
def freeze_stages(self, up_to_stage: int) -> None:
|
|
180
|
+
for param in self.conv_proj.parameters():
|
|
181
|
+
param.requires_grad_(False)
|
|
182
|
+
|
|
183
|
+
self.pos_embedding.requires_grad_(False)
|
|
184
|
+
|
|
185
|
+
for idx, module in enumerate(self.encoder.children()):
|
|
186
|
+
if idx >= up_to_stage:
|
|
187
|
+
break
|
|
188
|
+
|
|
189
|
+
for param in module.parameters():
|
|
190
|
+
param.requires_grad_(False)
|
|
191
|
+
|
|
139
192
|
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
|
|
140
193
|
# Reshape and permute the input tensor
|
|
141
194
|
x = self.conv_proj(x)
|
birder/net/deit3.py
CHANGED
|
@@ -27,6 +27,7 @@ from birder.net.base import MaskedTokenRetentionMixin
|
|
|
27
27
|
from birder.net.base import PreTrainEncoder
|
|
28
28
|
from birder.net.base import TokenOmissionResultType
|
|
29
29
|
from birder.net.base import TokenRetentionResultType
|
|
30
|
+
from birder.net.base import normalize_out_indices
|
|
30
31
|
from birder.net.vit import Encoder
|
|
31
32
|
from birder.net.vit import EncoderBlock
|
|
32
33
|
from birder.net.vit import PatchEmbed
|
|
@@ -59,6 +60,7 @@ class DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedT
|
|
|
59
60
|
mlp_dim: int = self.config["mlp_dim"]
|
|
60
61
|
layer_scale_init_value: Optional[float] = self.config.get("layer_scale_init_value", 1e-5)
|
|
61
62
|
num_reg_tokens: int = self.config.get("num_reg_tokens", 0)
|
|
63
|
+
out_indices: Optional[list[int]] = self.config.get("out_indices", None)
|
|
62
64
|
drop_path_rate: float = self.config["drop_path_rate"]
|
|
63
65
|
|
|
64
66
|
torch._assert(image_size[0] % patch_size == 0, "Input shape indivisible by patch size!")
|
|
@@ -70,6 +72,7 @@ class DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedT
|
|
|
70
72
|
self.num_reg_tokens = num_reg_tokens
|
|
71
73
|
self.num_special_tokens = 1 + self.num_reg_tokens
|
|
72
74
|
self.pos_embed_special_tokens = pos_embed_special_tokens
|
|
75
|
+
self.out_indices = normalize_out_indices(out_indices, num_layers)
|
|
73
76
|
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_layers)] # Stochastic depth decay rule
|
|
74
77
|
|
|
75
78
|
self.conv_proj = nn.Conv2d(
|
|
@@ -78,7 +81,6 @@ class DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedT
|
|
|
78
81
|
kernel_size=(patch_size, patch_size),
|
|
79
82
|
stride=(patch_size, patch_size),
|
|
80
83
|
padding=(0, 0),
|
|
81
|
-
bias=True,
|
|
82
84
|
)
|
|
83
85
|
self.patch_embed = PatchEmbed()
|
|
84
86
|
|
|
@@ -112,8 +114,9 @@ class DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedT
|
|
|
112
114
|
)
|
|
113
115
|
self.norm = nn.LayerNorm(hidden_dim, eps=1e-6)
|
|
114
116
|
|
|
115
|
-
self.
|
|
116
|
-
self.
|
|
117
|
+
num_return_stages = len(self.out_indices) if self.out_indices is not None else 1
|
|
118
|
+
self.return_stages = [f"stage{stage_idx + 1}" for stage_idx in range(num_return_stages)]
|
|
119
|
+
self.return_channels = [hidden_dim] * num_return_stages
|
|
117
120
|
self.embedding_size = hidden_dim
|
|
118
121
|
self.classifier = self.create_classifier()
|
|
119
122
|
|
|
@@ -159,7 +162,7 @@ class DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedT
|
|
|
159
162
|
)
|
|
160
163
|
|
|
161
164
|
def detection_features(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
|
|
162
|
-
|
|
165
|
+
H, W = x.shape[-2:]
|
|
163
166
|
|
|
164
167
|
x = self.conv_proj(x)
|
|
165
168
|
x = self.patch_embed(x)
|
|
@@ -176,15 +179,20 @@ class DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedT
|
|
|
176
179
|
x = x + self._get_pos_embed(H, W)
|
|
177
180
|
x = torch.concat([batch_special_tokens, x], dim=1)
|
|
178
181
|
|
|
179
|
-
|
|
180
|
-
|
|
182
|
+
if self.out_indices is None:
|
|
183
|
+
xs = [self.encoder(x)]
|
|
184
|
+
else:
|
|
185
|
+
xs = self.encoder.forward_features(x, out_indices=self.out_indices)
|
|
181
186
|
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
187
|
+
out: dict[str, torch.Tensor] = {}
|
|
188
|
+
for stage_name, stage_x in zip(self.return_stages, xs):
|
|
189
|
+
stage_x = stage_x[:, self.num_special_tokens :]
|
|
190
|
+
stage_x = stage_x.permute(0, 2, 1)
|
|
191
|
+
B, C, _ = stage_x.size()
|
|
192
|
+
stage_x = stage_x.reshape(B, C, H // self.patch_size, W // self.patch_size)
|
|
193
|
+
out[stage_name] = stage_x
|
|
186
194
|
|
|
187
|
-
return
|
|
195
|
+
return out
|
|
188
196
|
|
|
189
197
|
def freeze_stages(self, up_to_stage: int) -> None:
|
|
190
198
|
for param in self.conv_proj.parameters():
|
|
@@ -199,6 +207,10 @@ class DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedT
|
|
|
199
207
|
for param in module.parameters():
|
|
200
208
|
param.requires_grad_(False)
|
|
201
209
|
|
|
210
|
+
def transform_to_backbone(self) -> None:
|
|
211
|
+
super().transform_to_backbone()
|
|
212
|
+
self.norm = nn.Identity()
|
|
213
|
+
|
|
202
214
|
def set_causal_attention(self, is_causal: bool = True) -> None:
|
|
203
215
|
self.encoder.set_causal_attention(is_causal)
|
|
204
216
|
|
|
@@ -209,7 +221,7 @@ class DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedT
|
|
|
209
221
|
return_all_features: bool = False,
|
|
210
222
|
return_keys: Literal["all", "tokens", "embedding"] = "tokens",
|
|
211
223
|
) -> TokenOmissionResultType:
|
|
212
|
-
|
|
224
|
+
H, W = x.shape[-2:]
|
|
213
225
|
|
|
214
226
|
# Reshape and permute the input tensor
|
|
215
227
|
x = self.conv_proj(x)
|
|
@@ -272,7 +284,7 @@ class DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedT
|
|
|
272
284
|
mask_token: Optional[torch.Tensor] = None,
|
|
273
285
|
return_keys: Literal["all", "features", "embedding"] = "features",
|
|
274
286
|
) -> TokenRetentionResultType:
|
|
275
|
-
|
|
287
|
+
H, W = x.shape[-2:]
|
|
276
288
|
|
|
277
289
|
x = self.conv_proj(x)
|
|
278
290
|
x = mask_tensor(x, mask, mask_token=mask_token, patch_factor=self.max_stride // self.stem_stride)
|
|
@@ -302,7 +314,7 @@ class DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedT
|
|
|
302
314
|
if return_keys in ("all", "features"):
|
|
303
315
|
features = x[:, self.num_special_tokens :]
|
|
304
316
|
features = features.permute(0, 2, 1)
|
|
305
|
-
|
|
317
|
+
B, C, _ = features.size()
|
|
306
318
|
features = features.reshape(B, C, H // self.patch_size, W // self.patch_size)
|
|
307
319
|
result["features"] = features
|
|
308
320
|
|
|
@@ -312,7 +324,7 @@ class DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedT
|
|
|
312
324
|
return result
|
|
313
325
|
|
|
314
326
|
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
|
|
315
|
-
|
|
327
|
+
H, W = x.shape[-2:]
|
|
316
328
|
|
|
317
329
|
# Reshape and permute the input tensor
|
|
318
330
|
x = self.conv_proj(x)
|
birder/net/detection/__init__.py
CHANGED
|
@@ -3,8 +3,10 @@ from birder.net.detection.detr import DETR
|
|
|
3
3
|
from birder.net.detection.efficientdet import EfficientDet
|
|
4
4
|
from birder.net.detection.faster_rcnn import Faster_RCNN
|
|
5
5
|
from birder.net.detection.fcos import FCOS
|
|
6
|
+
from birder.net.detection.plain_detr import Plain_DETR
|
|
6
7
|
from birder.net.detection.retinanet import RetinaNet
|
|
7
8
|
from birder.net.detection.rt_detr_v1 import RT_DETR_v1
|
|
9
|
+
from birder.net.detection.rt_detr_v2 import RT_DETR_v2
|
|
8
10
|
from birder.net.detection.ssd import SSD
|
|
9
11
|
from birder.net.detection.ssdlite import SSDLite
|
|
10
12
|
from birder.net.detection.vitdet import ViTDet
|
|
@@ -19,8 +21,10 @@ __all__ = [
|
|
|
19
21
|
"EfficientDet",
|
|
20
22
|
"Faster_RCNN",
|
|
21
23
|
"FCOS",
|
|
24
|
+
"Plain_DETR",
|
|
22
25
|
"RetinaNet",
|
|
23
26
|
"RT_DETR_v1",
|
|
27
|
+
"RT_DETR_v2",
|
|
24
28
|
"SSD",
|
|
25
29
|
"SSDLite",
|
|
26
30
|
"ViTDet",
|
|
@@ -71,7 +71,7 @@ def scale_anchors(anchors: AnchorGroups, from_size: tuple[int, int], to_size: tu
|
|
|
71
71
|
|
|
72
72
|
|
|
73
73
|
def scale_anchors(anchors: AnchorLike, from_size: tuple[int, int], to_size: tuple[int, int]) -> AnchorLike:
|
|
74
|
-
|
|
74
|
+
anchor_groups, single = _normalize_anchor_groups(anchors)
|
|
75
75
|
|
|
76
76
|
if from_size == to_size:
|
|
77
77
|
# Avoid aliasing default anchors in case they are mutated later
|
|
@@ -100,7 +100,7 @@ def pixels_to_grid(anchors: AnchorGroups, strides: Sequence[int]) -> AnchorGroup
|
|
|
100
100
|
|
|
101
101
|
|
|
102
102
|
def pixels_to_grid(anchors: AnchorLike, strides: Sequence[int]) -> AnchorLike:
|
|
103
|
-
|
|
103
|
+
anchor_groups, single = _normalize_anchor_groups(anchors)
|
|
104
104
|
if len(anchor_groups) != len(strides):
|
|
105
105
|
raise ValueError("strides must provide one value per anchor scale")
|
|
106
106
|
|
|
@@ -123,7 +123,7 @@ def grid_to_pixels(anchors: AnchorGroups, strides: Sequence[int]) -> AnchorGroup
|
|
|
123
123
|
|
|
124
124
|
|
|
125
125
|
def grid_to_pixels(anchors: AnchorLike, strides: Sequence[int]) -> AnchorLike:
|
|
126
|
-
|
|
126
|
+
anchor_groups, single = _normalize_anchor_groups(anchors)
|
|
127
127
|
if len(anchor_groups) != len(strides):
|
|
128
128
|
raise ValueError("strides must provide one value per anchor scale")
|
|
129
129
|
|
|
@@ -187,7 +187,7 @@ def resolve_anchor_group(
|
|
|
187
187
|
preset: str, *, anchor_format: str, model_size: tuple[int, int], model_strides: Sequence[int]
|
|
188
188
|
) -> AnchorGroup:
|
|
189
189
|
anchors = _resolve_anchors(preset, anchor_format=anchor_format, model_size=model_size, model_strides=model_strides)
|
|
190
|
-
|
|
190
|
+
anchor_groups, single = _normalize_anchor_groups(anchors)
|
|
191
191
|
if single is False:
|
|
192
192
|
raise ValueError("Expected a single anchor group for this model")
|
|
193
193
|
|
|
@@ -198,7 +198,7 @@ def resolve_anchor_groups(
|
|
|
198
198
|
preset: str, *, anchor_format: str, model_size: tuple[int, int], model_strides: Sequence[int]
|
|
199
199
|
) -> AnchorGroups:
|
|
200
200
|
anchors = _resolve_anchors(preset, anchor_format=anchor_format, model_size=model_size, model_strides=model_strides)
|
|
201
|
-
|
|
201
|
+
anchor_groups, single = _normalize_anchor_groups(anchors)
|
|
202
202
|
if single is True:
|
|
203
203
|
raise ValueError("Expected multiple anchor groups for this model")
|
|
204
204
|
|
birder/net/detection/base.py
CHANGED
|
@@ -41,6 +41,7 @@ def get_detection_signature(input_shape: tuple[int, ...], num_outputs: int, dyna
|
|
|
41
41
|
|
|
42
42
|
class DetectionBaseNet(nn.Module):
|
|
43
43
|
default_size: tuple[int, int]
|
|
44
|
+
block_group_regex: Optional[str]
|
|
44
45
|
auto_register = False
|
|
45
46
|
scriptable = True
|
|
46
47
|
task = str(Task.OBJECT_DETECTION)
|
|
@@ -308,7 +309,7 @@ class SimpleFeaturePyramidNetwork(nn.Module):
|
|
|
308
309
|
names.append(f"stage{idx+1}")
|
|
309
310
|
|
|
310
311
|
if self.extra_blocks is not None:
|
|
311
|
-
|
|
312
|
+
results, names = self.extra_blocks(results, [x], names)
|
|
312
313
|
|
|
313
314
|
out = OrderedDict(list(zip(names, results)))
|
|
314
315
|
|
|
@@ -432,7 +433,7 @@ class BoxCoder:
|
|
|
432
433
|
ctr_x = boxes[:, 0] + 0.5 * widths
|
|
433
434
|
ctr_y = boxes[:, 1] + 0.5 * heights
|
|
434
435
|
|
|
435
|
-
|
|
436
|
+
wx, wy, ww, wh = self.weights
|
|
436
437
|
dx = rel_codes[:, 0::4] / wx
|
|
437
438
|
dy = rel_codes[:, 1::4] / wy
|
|
438
439
|
dw = rel_codes[:, 2::4] / ww
|
|
@@ -510,8 +511,8 @@ class AnchorGenerator(nn.Module):
|
|
|
510
511
|
)
|
|
511
512
|
|
|
512
513
|
for size, stride, base_anchors in zip(grid_sizes, strides, cell_anchors):
|
|
513
|
-
|
|
514
|
-
|
|
514
|
+
grid_height, grid_width = size
|
|
515
|
+
stride_height, stride_width = stride
|
|
515
516
|
device = base_anchors.device
|
|
516
517
|
|
|
517
518
|
# For output anchor, compute [x_center, y_center, x_center, y_center]
|
|
@@ -656,7 +657,7 @@ class Matcher(nn.Module):
|
|
|
656
657
|
# tensor([39796, 32055, 32070, 39190, 40255, 40390, 41455, 45470, 45325, 46390]))
|
|
657
658
|
# Each element in the first tensor is a gt index,
|
|
658
659
|
# and each element in second tensor is a prediction index
|
|
659
|
-
# Note how gt items 1, 2, 3, and 5 each have two ties
|
|
660
|
+
# Note how gt items 1, 2, 3 and 5 each have two ties
|
|
660
661
|
|
|
661
662
|
pred_idx_to_update = gt_pred_pairs_of_highest_quality[1]
|
|
662
663
|
matches[pred_idx_to_update] = all_matches[pred_idx_to_update]
|