birder 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- birder/adversarial/base.py +1 -1
- birder/adversarial/simba.py +4 -4
- birder/common/cli.py +1 -1
- birder/common/fs_ops.py +11 -11
- birder/common/lib.py +2 -2
- birder/common/masking.py +3 -3
- birder/common/training_cli.py +24 -2
- birder/common/training_utils.py +28 -4
- birder/data/collators/detection.py +9 -1
- birder/data/transforms/detection.py +27 -8
- birder/data/transforms/mosaic.py +1 -1
- birder/datahub/classification.py +3 -3
- birder/inference/classification.py +3 -3
- birder/inference/data_parallel.py +1 -1
- birder/inference/detection.py +5 -5
- birder/inference/wbf.py +1 -1
- birder/introspection/attention_rollout.py +5 -5
- birder/introspection/feature_pca.py +4 -4
- birder/introspection/gradcam.py +1 -1
- birder/introspection/guided_backprop.py +2 -2
- birder/introspection/transformer_attribution.py +3 -3
- birder/layers/attention_pool.py +2 -2
- birder/model_registry/model_registry.py +2 -1
- birder/net/__init__.py +2 -0
- birder/net/_rope_vit_configs.py +5 -0
- birder/net/_vit_configs.py +0 -13
- birder/net/alexnet.py +5 -5
- birder/net/base.py +28 -3
- birder/net/biformer.py +17 -17
- birder/net/cait.py +2 -2
- birder/net/cas_vit.py +1 -1
- birder/net/coat.py +15 -15
- birder/net/convnext_v1.py +2 -10
- birder/net/convnext_v1_iso.py +198 -0
- birder/net/convnext_v2.py +2 -10
- birder/net/crossformer.py +9 -9
- birder/net/crossvit.py +1 -1
- birder/net/cspnet.py +1 -1
- birder/net/cswin_transformer.py +10 -10
- birder/net/davit.py +10 -10
- birder/net/deit.py +56 -3
- birder/net/deit3.py +27 -15
- birder/net/detection/__init__.py +4 -0
- birder/net/detection/{yolo_anchors.py → _yolo_anchors.py} +5 -5
- birder/net/detection/base.py +6 -5
- birder/net/detection/deformable_detr.py +26 -28
- birder/net/detection/detr.py +9 -9
- birder/net/detection/efficientdet.py +9 -28
- birder/net/detection/faster_rcnn.py +22 -22
- birder/net/detection/fcos.py +8 -8
- birder/net/detection/plain_detr.py +852 -0
- birder/net/detection/retinanet.py +4 -4
- birder/net/detection/rt_detr_v1.py +81 -25
- birder/net/detection/rt_detr_v2.py +1147 -0
- birder/net/detection/ssd.py +5 -5
- birder/net/detection/yolo_v2.py +12 -12
- birder/net/detection/yolo_v3.py +19 -19
- birder/net/detection/yolo_v4.py +16 -16
- birder/net/detection/yolo_v4_tiny.py +3 -3
- birder/net/edgenext.py +3 -3
- birder/net/edgevit.py +10 -14
- birder/net/efficientformer_v1.py +1 -1
- birder/net/efficientvim.py +9 -9
- birder/net/efficientvit_mit.py +2 -2
- birder/net/efficientvit_msft.py +3 -3
- birder/net/fasternet.py +1 -1
- birder/net/fastvit.py +5 -12
- birder/net/flexivit.py +28 -15
- birder/net/focalnet.py +5 -9
- birder/net/gc_vit.py +11 -11
- birder/net/ghostnet_v1.py +1 -1
- birder/net/ghostnet_v2.py +1 -1
- birder/net/groupmixformer.py +12 -12
- birder/net/hgnet_v1.py +1 -1
- birder/net/hgnet_v2.py +4 -4
- birder/net/hiera.py +6 -6
- birder/net/hieradet.py +9 -9
- birder/net/hornet.py +3 -3
- birder/net/iformer.py +4 -4
- birder/net/inception_next.py +4 -14
- birder/net/levit.py +3 -3
- birder/net/lit_v1.py +13 -15
- birder/net/lit_v1_tiny.py +9 -9
- birder/net/lit_v2.py +14 -15
- birder/net/maxvit.py +10 -22
- birder/net/metaformer.py +2 -2
- birder/net/mim/crossmae.py +5 -5
- birder/net/mim/fcmae.py +3 -5
- birder/net/mim/mae_hiera.py +7 -7
- birder/net/mim/mae_vit.py +3 -5
- birder/net/mim/simmim.py +2 -3
- birder/net/mobilenet_v4_hybrid.py +4 -4
- birder/net/mobileone.py +5 -12
- birder/net/mobilevit_v1.py +2 -2
- birder/net/mobilevit_v2.py +5 -9
- birder/net/mvit_v2.py +24 -24
- birder/net/nextvit.py +2 -2
- birder/net/pit.py +11 -26
- birder/net/pvt_v1.py +4 -4
- birder/net/pvt_v2.py +5 -11
- birder/net/regionvit.py +15 -15
- birder/net/regnet.py +1 -1
- birder/net/repghost.py +4 -5
- birder/net/repvgg.py +3 -5
- birder/net/repvit.py +2 -2
- birder/net/resnest.py +1 -1
- birder/net/rope_deit3.py +29 -15
- birder/net/rope_flexivit.py +28 -15
- birder/net/rope_vit.py +41 -23
- birder/net/sequencer2d.py +3 -4
- birder/net/shufflenet_v1.py +1 -1
- birder/net/shufflenet_v2.py +1 -1
- birder/net/simple_vit.py +47 -5
- birder/net/smt.py +7 -7
- birder/net/ssl/barlow_twins.py +1 -1
- birder/net/ssl/byol.py +2 -2
- birder/net/ssl/capi.py +3 -3
- birder/net/ssl/data2vec2.py +1 -1
- birder/net/ssl/dino_v2.py +11 -1
- birder/net/ssl/franca.py +26 -2
- birder/net/ssl/i_jepa.py +4 -4
- birder/net/ssl/mmcr.py +1 -1
- birder/net/swiftformer.py +1 -1
- birder/net/swin_transformer_v1.py +4 -5
- birder/net/swin_transformer_v2.py +4 -7
- birder/net/tiny_vit.py +3 -3
- birder/net/transnext.py +19 -19
- birder/net/uniformer.py +4 -4
- birder/net/vgg.py +1 -10
- birder/net/vit.py +38 -25
- birder/net/vit_parallel.py +35 -20
- birder/net/vit_sam.py +10 -10
- birder/net/vovnet_v2.py +1 -1
- birder/net/xcit.py +9 -7
- birder/ops/msda.py +4 -4
- birder/ops/swattention.py +10 -10
- birder/results/classification.py +3 -3
- birder/results/gui.py +8 -8
- birder/scripts/benchmark.py +37 -12
- birder/scripts/evaluate.py +1 -1
- birder/scripts/predict.py +3 -3
- birder/scripts/predict_detection.py +2 -2
- birder/scripts/train.py +63 -15
- birder/scripts/train_barlow_twins.py +10 -7
- birder/scripts/train_byol.py +10 -7
- birder/scripts/train_capi.py +15 -10
- birder/scripts/train_data2vec.py +10 -7
- birder/scripts/train_data2vec2.py +10 -7
- birder/scripts/train_detection.py +29 -14
- birder/scripts/train_dino_v1.py +13 -9
- birder/scripts/train_dino_v2.py +27 -14
- birder/scripts/train_dino_v2_dist.py +28 -15
- birder/scripts/train_franca.py +16 -9
- birder/scripts/train_i_jepa.py +12 -9
- birder/scripts/train_ibot.py +15 -11
- birder/scripts/train_kd.py +64 -17
- birder/scripts/train_mim.py +11 -8
- birder/scripts/train_mmcr.py +11 -8
- birder/scripts/train_rotnet.py +11 -7
- birder/scripts/train_simclr.py +10 -7
- birder/scripts/train_vicreg.py +10 -7
- birder/tools/adversarial.py +4 -4
- birder/tools/auto_anchors.py +5 -5
- birder/tools/avg_model.py +1 -1
- birder/tools/convert_model.py +30 -22
- birder/tools/det_results.py +1 -1
- birder/tools/download_model.py +1 -1
- birder/tools/ensemble_model.py +1 -1
- birder/tools/introspection.py +11 -2
- birder/tools/labelme_to_coco.py +2 -2
- birder/tools/model_info.py +12 -14
- birder/tools/pack.py +8 -8
- birder/tools/quantize_model.py +53 -4
- birder/tools/results.py +2 -2
- birder/tools/show_det_iterator.py +19 -6
- birder/tools/show_iterator.py +2 -2
- birder/tools/similarity.py +5 -5
- birder/tools/stats.py +4 -6
- birder/tools/voc_to_coco.py +1 -1
- birder/version.py +1 -1
- {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/METADATA +3 -3
- birder-0.4.1.dist-info/RECORD +300 -0
- {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/WHEEL +1 -1
- birder-0.4.0.dist-info/RECORD +0 -297
- {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/entry_points.txt +0 -0
- {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/licenses/LICENSE +0 -0
- {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/top_level.txt +0 -0
birder/net/smt.py
CHANGED
@@ -36,7 +36,7 @@ class DWConv(nn.Module):
         self.dwconv = nn.Conv2d(dim, dim, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=dim)

     def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
-        (B, _, C) = x.size()
+        B, _, C = x.size()
         x = x.transpose(1, 2).view(B, C, H, W)
         x = self.dwconv(x)
         x = x.flatten(2).transpose(1, 2)
@@ -94,7 +94,7 @@ class CAAttention(nn.Module):
         self.proj1 = nn.Conv2d(dim * expand_ratio, dim, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))

     def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
-        (B, N, C) = x.size()
+        B, N, C = x.size()

         v = self.v(x)
         s = self.s(x).reshape(B, H, W, self.ca_num_heads, C // self.ca_num_heads).permute(3, 0, 4, 1, 2)
@@ -140,11 +140,11 @@ class SAAttention(nn.Module):
         self.conv = nn.Conv2d(dim, dim, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=dim)

     def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
-        (B, N, C) = x.size()
+        B, N, C = x.size()

         q = self.q(x).reshape(B, N, self.sa_num_heads, C // self.sa_num_heads).permute(0, 2, 1, 3)
         kv = self.kv(x).reshape(B, -1, 2, self.sa_num_heads, C // self.sa_num_heads).permute(2, 0, 3, 1, 4)
-        (k, v) = kv.unbind(0)
+        k, v = kv.unbind(0)
         attn = (q @ k.transpose(-2, -1)) * self.scale
         attn = attn.softmax(dim=-1)
         attn = self.attn_drop(attn)
@@ -243,7 +243,7 @@ class OverlapPatchEmbed(nn.Module):

     def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, int, int]:
         x = self.proj(x)
-        (_, _, H, W) = x.size()
+        _, _, H, W = x.size()
         x = x.flatten(2).transpose(1, 2)
         x = self.norm(x)

@@ -267,7 +267,7 @@ class Stem(nn.Module):

     def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, int, int]:
         x = self.conv(x)
-        (_, _, H, W) = x.size()
+        _, _, H, W = x.size()
         x = x.flatten(2).transpose(1, 2)
         x = self.norm(x)

@@ -329,7 +329,7 @@ class SMTStage(nn.Module):

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         B = x.size(0)
-        (x, H, W) = self.downsample_block(x)
+        x, H, W = self.downsample_block(x)
         x = self.blocks(x, H, W)
         x = self.norm(x)
         x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
birder/net/ssl/barlow_twins.py
CHANGED
@@ -21,7 +21,7 @@ from birder.net.ssl.base import SSLBaseNet

 def off_diagonal(x: torch.Tensor) -> torch.Tensor:
     # Return a flattened view of the off-diagonal elements of a square matrix
-    (n, _) = x.size()
+    n, _ = x.size()
     # assert n == m
     return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()

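The off_diagonal helper shown above is a neat indexing trick: flattening an n x n matrix and viewing its first n^2 - 1 elements as (n - 1) x (n + 1) shifts each row by one position, so every diagonal entry lands in column 0, and slicing that column off leaves exactly the off-diagonal elements. A quick standalone check against a mask-based reference:

```python
import torch

def off_diagonal(x: torch.Tensor) -> torch.Tensor:
    n, _ = x.size()
    return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()

x = torch.arange(16.0).view(4, 4)
mask = ~torch.eye(4, dtype=torch.bool)
assert torch.equal(off_diagonal(x), x[mask])  # the same 12 off-diagonal values
```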
birder/net/ssl/byol.py
CHANGED
@@ -80,11 +80,11 @@ class BYOL(SSLBaseNet):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         projection = self.online_encoder(x)
         online_predictions = self.online_predictor(projection)
-        (online_pred_one, online_pred_two) = online_predictions.chunk(2, dim=0)
+        online_pred_one, online_pred_two = online_predictions.chunk(2, dim=0)

         with torch.no_grad():
             target_projections = self.target_encoder(x)
-            (target_proj_one, target_proj_two) = target_projections.chunk(2, dim=0)
+            target_proj_one, target_proj_two = target_projections.chunk(2, dim=0)

         loss_one = loss_fn(online_pred_one, target_proj_two.detach())
         loss_two = loss_fn(online_pred_two, target_proj_one.detach())
birder/net/ssl/capi.py
CHANGED
@@ -263,11 +263,11 @@ class CrossAttention(nn.Module):
         self.proj = nn.Linear(decoder_dim, decoder_dim)

     def forward(self, tgt: torch.Tensor, memory: torch.Tensor) -> torch.Tensor:
-        (B, N, C) = tgt.size()
+        B, N, C = tgt.size()
         n_kv = memory.size(1)
         q = self.q(tgt).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
         kv = self.kv(memory).reshape(B, n_kv, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
-        (k, v) = kv.unbind(0)
+        k, v = kv.unbind(0)

         attn = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)  # pylint: disable=not-callable
         x = attn.transpose(1, 2).reshape(B, N, C)
@@ -419,7 +419,7 @@ class CAPITeacher(SSLBaseNet):
         x = self.backbone.masked_encoding_omission(x, ids_keep)["tokens"]

         x = x[:, self.backbone.num_special_tokens :, :]
-        (assignments, clustering_loss) = self.head(x.transpose(0, 1))
+        assignments, clustering_loss = self.head(x.transpose(0, 1))

         assignments = assignments.detach().transpose(0, 1)
         row_indices = torch.arange(B).unsqueeze(1).expand_as(ids_predict)
birder/net/ssl/data2vec2.py
CHANGED
@@ -68,7 +68,7 @@ class Decoder2d(nn.Module):
         self.proj = nn.Linear(embed_dim, in_channels)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        (B, _, C) = x.size()  # B, N, C
+        B, _, C = x.size()  # B, N, C

         x = x.transpose(1, 2).reshape(B, C, self.H, self.W)

birder/net/ssl/dino_v2.py
CHANGED
@@ -148,7 +148,17 @@ class DINOLoss(nn.Module):

     def forward(
         self, student_output_list: list[torch.Tensor], teacher_out_softmax_centered_list: list[torch.Tensor]
-    ) ->
+    ) -> torch.Tensor:
+        s = torch.stack(student_output_list, 0)
+        t = torch.stack(teacher_out_softmax_centered_list, 0)
+        lsm = F.log_softmax(s / self.student_temp, dim=-1)
+        loss = -(torch.einsum("tbk,sbk->tsb", t, lsm).mean(-1).sum())
+
+        return loss
+
+    def forward_reference(
+        self, student_output_list: list[torch.Tensor], teacher_out_softmax_centered_list: list[torch.Tensor]
+    ) -> torch.Tensor:
         total_loss = 0.0
         for s in student_output_list:
             lsm = F.log_softmax(s / self.student_temp, dim=-1)
birder/net/ssl/franca.py
CHANGED
@@ -69,7 +69,7 @@ class DINOHeadMRL(nn.Module):
     ) -> None:
         super().__init__()
         self.nesting_list = nesting_list
-        self.matryoshka_projections = nn.ModuleList([nn.Linear(dim, dim
+        self.matryoshka_projections = nn.ModuleList([nn.Linear(dim, dim) for dim in self.nesting_list])

         self.mlps = nn.ModuleList(
             [
@@ -197,7 +197,31 @@ class DINOLossMRL(nn.Module):
         teacher_out_softmax_centered_list: list[torch.Tensor],
         n_crops: int | tuple[int, int],
         teacher_global: bool,
-    ) ->
+    ) -> torch.Tensor:
+        total_loss = 0.0
+        if teacher_global is False:
+            for student_outputs, teacher_outputs in zip(student_output_list, teacher_out_softmax_centered_list):
+                s = torch.stack(student_outputs.chunk(n_crops[0]), 0)  # type: ignore[index]
+                t = teacher_outputs.view(n_crops[1], -1, teacher_outputs.shape[-1])  # type: ignore[index]
+                lsm = F.log_softmax(s / self.student_temp, dim=-1)
+                total_loss -= torch.einsum("tbk,sbk->tsb", t, lsm).mean(-1).sum()
+
+        else:
+            for student_outputs, teacher_outputs in zip(student_output_list, teacher_out_softmax_centered_list):
+                teacher_outputs = teacher_outputs.view(n_crops, -1, teacher_outputs.shape[-1])
+                lsm = F.log_softmax(student_outputs / self.student_temp, dim=-1)
+                loss = torch.sum(teacher_outputs.flatten(0, 1) * lsm, dim=-1)
+                total_loss -= loss.mean()
+
+        return total_loss
+
+    def forward_reference(
+        self,
+        student_output_list: list[torch.Tensor],
+        teacher_out_softmax_centered_list: list[torch.Tensor],
+        n_crops: int | tuple[int, int],
+        teacher_global: bool,
+    ) -> torch.Tensor:
         total_loss = 0.0
         if teacher_global is False:
             for student_outputs, teacher_outputs in zip(student_output_list, teacher_out_softmax_centered_list):
birder/net/ssl/i_jepa.py
CHANGED
@@ -69,11 +69,11 @@ class MultiBlockMasking:
     ) -> tuple[int, int]:
         _rand = torch.rand(1).item()

-        (min_s, max_s) = scale
+        min_s, max_s = scale
         mask_scale = min_s + _rand * (max_s - min_s)
         max_keep = int(self.height * self.width * mask_scale)

-        (min_ar, max_ar) = aspect_ratio_scale
+        min_ar, max_ar = aspect_ratio_scale
         aspect_ratio = min_ar + _rand * (max_ar - min_ar)

         # Compute block height and width (given scale and aspect-ratio)
@@ -154,7 +154,7 @@ class MultiBlockMasking:
         masks_p = []
         masks_c = []
         for _ in range(self.n_pred):
-            (mask, mask_c) = self._sample_block_mask(p_size)
+            mask, mask_c = self._sample_block_mask(p_size)
             masks_p.append(mask)
             masks_c.append(mask_c)
             min_keep_pred = min(min_keep_pred, len(mask))
@@ -167,7 +167,7 @@ class MultiBlockMasking:

         masks_e = []
         for _ in range(self.n_enc):
-            (mask, _) = self._sample_block_mask(e_size, acceptable_regions=acceptable_regions)
+            mask, _ = self._sample_block_mask(e_size, acceptable_regions=acceptable_regions)
             masks_e.append(mask)
             min_keep_enc = min(min_keep_enc, len(mask))

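The unpacked scale and aspect_ratio_scale ranges feed I-JEPA's block sampler: a single shared random draw fixes both the mask scale and the aspect ratio, after which block height and width follow from area = h * w and ratio = h / w. A sketch of that computation (the final rounding step is an assumption; the hunk ends before it):

```python
import math
import torch

height = width = 14                    # feature-map grid, illustrative
scale = (0.15, 0.2)
aspect_ratio_scale = (0.75, 1.5)

_rand = torch.rand(1).item()
min_s, max_s = scale
mask_scale = min_s + _rand * (max_s - min_s)
max_keep = int(height * width * mask_scale)

min_ar, max_ar = aspect_ratio_scale
aspect_ratio = min_ar + _rand * (max_ar - min_ar)

# Block height and width given area (max_keep) and aspect ratio (h / w)
h = int(round(math.sqrt(max_keep * aspect_ratio)))
w = int(round(math.sqrt(max_keep / aspect_ratio)))
print(h, w, h * w, max_keep)  # h * w approximates max_keep
```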
birder/net/ssl/mmcr.py
CHANGED
@@ -125,7 +125,7 @@ class MMCR(SSLBaseNet):
         self.momentum_encoder.load_state_dict(self.encoder.state_dict())

     def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
-        (C, H, W) = x.shape[-3:]  # B, num_views, C, H, W
+        C, H, W = x.shape[-3:]  # B, num_views, C, H, W
         x = x.reshape(-1, C, H, W)
         z = self.encoder(x)

birder/net/swiftformer.py
CHANGED
@@ -111,7 +111,7 @@ class EfficientAdditiveAttention(nn.Module):
         self.final = nn.Linear(token_dim * num_heads, token_dim)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        (B, _, H, W) = x.size()
+        B, _, H, W = x.size()
         x = x.flatten(2).permute(0, 2, 1)

         query = F.normalize(self.to_query(x), dim=-1)

birder/net/swin_transformer_v1.py
CHANGED

@@ -30,7 +30,7 @@ from birder.net.base import DetectorBackbone


 def patch_merging_pad(x: torch.Tensor) -> torch.Tensor:
-    (H, W, _) = x.shape[-3:]
+    H, W, _ = x.shape[-3:]
     x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
     x0 = x[..., 0::2, 0::2, :]  # ... H/2 W/2 C
     x1 = x[..., 1::2, 0::2, :]  # ... H/2 W/2 C
@@ -73,13 +73,13 @@ def shifted_window_attention(
     proj_bias: Optional[torch.Tensor] = None,
     logit_scale: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
-    (B, H, W, C) = x.size()
+    B, H, W, C = x.size()

     # Pad feature maps to multiples of window size
     pad_b = (window_size[0] - H % window_size[0]) % window_size[0]
     pad_r = (window_size[1] - W % window_size[1]) % window_size[1]
     x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))
-    (_, pad_h, pad_w, _) = x.size()
+    _, pad_h, pad_w, _ = x.size()

     # If window size is larger than feature size, there is no need to shift window
     shift_size_w = shift_size[0]
@@ -309,7 +309,6 @@ class Swin_Transformer_v1(DetectorBackbone):
                 kernel_size=(patch_size, patch_size),
                 stride=(patch_size, patch_size),
                 padding=(0, 0),
-                bias=True,
             ),
             Permute([0, 2, 3, 1]),
             nn.LayerNorm(embed_dim, eps=1e-5),
@@ -434,7 +433,7 @@ class Swin_Transformer_v1(DetectorBackbone):
         num_attn_heads = rel_pos_bias.size(1)

         def _calc(src: int, dst: int) -> list[float]:
-            (left, right) = (1.01, 1.5)
+            left, right = 1.01, 1.5
             while right - left > 1e-6:
                 q = (left + right) / 2.0
                 gp = (1.0 - q ** (src // 2)) / (1.0 - q)  # Geometric progression

birder/net/swin_transformer_v2.py
CHANGED

@@ -76,7 +76,9 @@ class ShiftedWindowAttention(nn.Module):

         # MLP to generate continuous relative position bias
         self.cpb_mlp = nn.Sequential(
-            nn.Linear(2, 512
+            nn.Linear(2, 512),
+            nn.ReLU(inplace=True),
+            nn.Linear(512, num_heads, bias=False),
         )
         if qkv_bias is True:
             length = self.qkv.bias.numel() // 3
@@ -224,12 +226,7 @@ class Swin_Transformer_v2(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):

         self.stem = nn.Sequential(
             nn.Conv2d(
-                self.input_channels,
-                embed_dim,
-                kernel_size=(patch_size, patch_size),
-                stride=patch_size,
-                padding=(0, 0),
-                bias=True,
+                self.input_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, padding=(0, 0)
             ),
             Permute([0, 2, 3, 1]),
             nn.LayerNorm(embed_dim, eps=1e-5),
birder/net/tiny_vit.py
CHANGED
@@ -201,12 +201,12 @@ class Attention(nn.Module):

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         attn_bias = self.attention_biases[:, self.attention_bias_idxs]
-        (B, N, _) = x.shape
+        B, N, _ = x.shape

         # Normalization
         x = self.norm(x)
         qkv = self.qkv(x)
-        (q, k, v) = qkv.view(B, N, self.num_heads, -1).split([self.key_dim, self.key_dim, self.val_dim], dim=3)
+        q, k, v = qkv.view(B, N, self.num_heads, -1).split([self.key_dim, self.key_dim, self.val_dim], dim=3)

         q = q.permute(0, 2, 1, 3)
         k = k.permute(0, 2, 1, 3)
@@ -252,7 +252,7 @@ class TinyVitBlock(nn.Module):
         )

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        (B, H, W, C) = x.shape
+        B, H, W, C = x.shape
         L = H * W

         shortcut = x
birder/net/transnext.py
CHANGED
@@ -32,8 +32,8 @@ def get_relative_position_cpb(
     axis_kh = F.adaptive_avg_pool1d(axis_qh.unsqueeze(0), key_size[0]).squeeze(0)  # pylint: disable=not-callable
     axis_qw = torch.arange(query_size[1], dtype=torch.float32, device=device)
     axis_kw = F.adaptive_avg_pool1d(axis_qw.unsqueeze(0), key_size[1]).squeeze(0)  # pylint: disable=not-callable
-    (axis_kh, axis_kw) = torch.meshgrid(axis_kh, axis_kw, indexing="ij")
-    (axis_qh, axis_qw) = torch.meshgrid(axis_qh, axis_qw, indexing="ij")
+    axis_kh, axis_kw = torch.meshgrid(axis_kh, axis_kw, indexing="ij")
+    axis_qh, axis_qw = torch.meshgrid(axis_qh, axis_qw, indexing="ij")

     axis_kh = torch.reshape(axis_kh, [-1])
     axis_kw = torch.reshape(axis_kw, [-1])
@@ -44,7 +44,7 @@ def get_relative_position_cpb(
     relative_w = (axis_qw[:, None] - axis_kw[None, :]) / (pretrain_size[1] - 1) * 8
     relative_hw = torch.stack([relative_h, relative_w], dim=-1).view(-1, 2)

-    (relative_coords_table, idx_map) = torch.unique(relative_hw, return_inverse=True, dim=0)
+    relative_coords_table, idx_map = torch.unique(relative_hw, return_inverse=True, dim=0)

     relative_coords_table = (
         torch.sign(relative_coords_table)
@@ -86,9 +86,9 @@ class ConvolutionalGLU(nn.Module):
         self.drop = nn.Dropout(drop)

     def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
-        (x, v) = self.fc1(x).chunk(2, dim=-1)
+        x, v = self.fc1(x).chunk(2, dim=-1)

-        (B, _, C) = x.size()
+        B, _, C = x.size()
         x = x.transpose(1, 2).view(B, C, H, W).contiguous()
         x = self.dwconv(x)
         x = x.flatten(2).transpose(1, 2)
@@ -143,9 +143,9 @@ class Attention(nn.Module):
     def forward(
         self, x: torch.Tensor, _h: int, _w: int, relative_pos_index: torch.Tensor, relative_coords_table: torch.Tensor
     ) -> torch.Tensor:
-        (B, N, C) = x.size()
+        B, N, C = x.size()
         qkv = self.qkv(x).reshape(B, -1, 3 * self.num_heads, self.head_dim).permute(0, 2, 1, 3)
-        (q, k, v) = qkv.chunk(3, dim=1)
+        q, k, v = qkv.chunk(3, dim=1)

         # Use MLP to generate continuous relative positional bias
         rel_bias = (
@@ -217,9 +217,9 @@ class AggregatedAttention(nn.Module):
         self.act = nn.GELU()

         # MLP to generate continuous relative position bias
-        self.cpb_fc1 = nn.Linear(2, 512
+        self.cpb_fc1 = nn.Linear(2, 512)
         self.cpb_act = nn.ReLU(inplace=True)
-        self.cpb_fc2 = nn.Linear(512, num_heads
+        self.cpb_fc2 = nn.Linear(512, num_heads)

         # relative bias for local features
         self.relative_pos_bias_local = nn.Parameter(
@@ -227,7 +227,7 @@ class AggregatedAttention(nn.Module):
         )

         # Generate padding_mask and sequence length scale
-        (local_seq_length, padding_mask) = get_seqlen_and_mask(input_resolution, self.window_size)
+        local_seq_length, padding_mask = get_seqlen_and_mask(input_resolution, self.window_size)
         self.seq_length_scale = nn.Buffer(torch.log(local_seq_length + self.pool_len), persistent=False)
         self.padding_mask = nn.Buffer(padding_mask, persistent=False)

@@ -240,7 +240,7 @@ class AggregatedAttention(nn.Module):
     def forward(
         self, x: torch.Tensor, H: int, W: int, relative_pos_index: torch.Tensor, relative_coords_table: torch.Tensor
     ) -> torch.Tensor:
-        (B, N, C) = x.size()
+        B, N, C = x.size()

         # Generate queries, normalize them with L2, add query embedding,
         # and then magnify with sequence length scale and temperature.
@@ -252,7 +252,7 @@ class AggregatedAttention(nn.Module):
             * self.seq_length_scale
         )

-        (attn_local, v_local) = self.swa_qk_rpb(
+        attn_local, v_local = self.swa_qk_rpb(
             self.kv(x),
             q_norm_scaled.contiguous(),
             self.relative_pos_bias_local,
@@ -272,7 +272,7 @@ class AggregatedAttention(nn.Module):

         # Generate pooled keys and values
         kv_pool = self.kv(x_).reshape(B, self.pool_len, 2 * self.num_heads, self.head_dim).permute(0, 2, 1, 3)
-        (k_pool, v_pool) = kv_pool.chunk(2, dim=1)
+        k_pool, v_pool = kv_pool.chunk(2, dim=1)

         # Use MLP to generate continuous relative positional bias for pooled features.
         pool_bias = (
@@ -288,7 +288,7 @@ class AggregatedAttention(nn.Module):
         attn = self.attn_drop(attn)

         # Split the attention weights and separately aggregate the values of local & pooled features
-        (attn_local, attn_pool) = torch.split(attn, [self.local_len, self.pool_len], dim=-1)
+        attn_local, attn_pool = torch.split(attn, [self.local_len, self.pool_len], dim=-1)

         x_local = self.swa_av(
             q_norm, attn_local, v_local.contiguous(), self.learnable_tokens, self.learnable_bias, self.window_size, H, W
@@ -367,7 +367,7 @@ class OverlapPatchEmbed(nn.Module):

     def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, int, int]:
         x = self.proj(x)
-        (_, _, H, W) = x.size()
+        _, _, H, W = x.size()
         x = x.flatten(2).transpose(1, 2)
         x = self.norm(x)

@@ -396,7 +396,7 @@ class TransNeXtStage(nn.Module):

         # Generate relative positional coordinate table and index for each stage
         # to compute continuous relative positional bias
-        (relative_pos_index, relative_coords_table) = get_relative_position_cpb(
+        relative_pos_index, relative_coords_table = get_relative_position_cpb(
            query_size=input_resolution, key_size=(input_resolution[0] // sr_ratio, input_resolution[1] // sr_ratio)
         )
         self.relative_pos_index = nn.Buffer(relative_pos_index, persistent=False)
@@ -430,7 +430,7 @@ class TransNeXtStage(nn.Module):

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         B = x.size(0)
-        (x, H, W) = self.patch_embed(x)
+        x, H, W = self.patch_embed(x)
         for blk in self.blocks:
             x = blk(x, H, W, self.relative_pos_index, self.relative_coords_table)

@@ -553,7 +553,7 @@ class TransNeXt(DetectorBackbone):
             sr_ratio = self.sr_ratio[i]
             with torch.no_grad():
                 device = next(m.parameters()).device
-                (relative_pos_index, relative_coords_table) = get_relative_position_cpb(
+                relative_pos_index, relative_coords_table = get_relative_position_cpb(
                     query_size=input_resolution,
                     key_size=(input_resolution[0] // sr_ratio, input_resolution[1] // sr_ratio),
                     device=device,
@@ -574,7 +574,7 @@ class TransNeXt(DetectorBackbone):
                 blk.pool_len = pool_h * pool_w
                 blk.pool = nn.AdaptiveAvgPool2d((pool_h, pool_w))

-                (local_seq_length, padding_mask) = get_seqlen_and_mask(
+                local_seq_length, padding_mask = get_seqlen_and_mask(
                     input_resolution, blk.window_size, device=device
                 )
                 blk.seq_length_scale = nn.Buffer(
birder/net/uniformer.py
CHANGED
@@ -71,9 +71,9 @@ class Attention(nn.Module):
         self.proj_drop = nn.Dropout(proj_drop)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        (B, N, C) = x.shape
+        B, N, C = x.shape
         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
-        (q, k, v) = qkv.unbind(0)
+        q, k, v = qkv.unbind(0)

         x = F.scaled_dot_product_attention(  # pylint: disable=not-callable
             q, k, v, dropout_p=self.attn_drop.p if self.training else 0.0, scale=self.scale
@@ -137,7 +137,7 @@ class AttentionBlock(nn.Module):

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = x + self.pos_embed(x)
-        (B, N, H, W) = x.shape
+        B, N, H, W = x.shape
         x = x.flatten(2).transpose(1, 2)
         x = x + self.drop_path(self.layer_scale_1(self.attn(self.norm1(x))))
         x = x + self.drop_path(self.layer_scale_2(self.mlp(self.norm2(x))))
@@ -155,7 +155,7 @@ class PatchEmbed(nn.Module):

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.proj(x)
-        (B, _, H, W) = x.size()  # B, C, H, W
+        B, _, H, W = x.size()  # B, C, H, W
         x = x.flatten(2).transpose(1, 2)
         x = self.norm(x)
         x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
birder/net/vgg.py
CHANGED
@@ -40,16 +40,7 @@ class Vgg(DetectorBackbone):
             else:
                 in_channels = filters[i]

-            layers.append(
-                nn.Conv2d(
-                    in_channels,
-                    filters[i],
-                    kernel_size=(3, 3),
-                    stride=(1, 1),
-                    padding=(1, 1),
-                    bias=True,
-                )
-            )
+            layers.append(nn.Conv2d(in_channels, filters[i], kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)))
             layers.append(nn.ReLU(inplace=True))

             layers.append(nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=(0, 0)))