birder 0.3.3__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- birder/adversarial/base.py +1 -1
- birder/adversarial/simba.py +4 -4
- birder/common/cli.py +1 -1
- birder/common/fs_ops.py +13 -13
- birder/common/lib.py +2 -2
- birder/common/masking.py +3 -3
- birder/common/training_cli.py +24 -2
- birder/common/training_utils.py +28 -4
- birder/data/collators/detection.py +9 -1
- birder/data/transforms/detection.py +27 -8
- birder/data/transforms/mosaic.py +1 -1
- birder/datahub/classification.py +3 -3
- birder/inference/classification.py +3 -3
- birder/inference/data_parallel.py +1 -1
- birder/inference/detection.py +5 -5
- birder/inference/wbf.py +1 -1
- birder/introspection/attention_rollout.py +6 -6
- birder/introspection/feature_pca.py +4 -4
- birder/introspection/gradcam.py +1 -1
- birder/introspection/guided_backprop.py +2 -2
- birder/introspection/transformer_attribution.py +4 -4
- birder/layers/attention_pool.py +2 -2
- birder/layers/layer_scale.py +1 -1
- birder/model_registry/model_registry.py +2 -1
- birder/net/__init__.py +4 -10
- birder/net/_rope_vit_configs.py +435 -0
- birder/net/_vit_configs.py +466 -0
- birder/net/alexnet.py +5 -5
- birder/net/base.py +28 -3
- birder/net/biformer.py +18 -17
- birder/net/cait.py +7 -7
- birder/net/cas_vit.py +1 -1
- birder/net/coat.py +27 -27
- birder/net/conv2former.py +3 -3
- birder/net/convmixer.py +1 -1
- birder/net/convnext_v1.py +3 -11
- birder/net/convnext_v1_iso.py +198 -0
- birder/net/convnext_v2.py +2 -10
- birder/net/crossformer.py +9 -9
- birder/net/crossvit.py +6 -6
- birder/net/cspnet.py +1 -1
- birder/net/cswin_transformer.py +10 -10
- birder/net/davit.py +11 -11
- birder/net/deit.py +68 -29
- birder/net/deit3.py +69 -204
- birder/net/densenet.py +9 -8
- birder/net/detection/__init__.py +4 -0
- birder/net/detection/{yolo_anchors.py → _yolo_anchors.py} +5 -5
- birder/net/detection/base.py +6 -5
- birder/net/detection/deformable_detr.py +31 -30
- birder/net/detection/detr.py +14 -11
- birder/net/detection/efficientdet.py +10 -29
- birder/net/detection/faster_rcnn.py +22 -22
- birder/net/detection/fcos.py +8 -8
- birder/net/detection/plain_detr.py +852 -0
- birder/net/detection/retinanet.py +4 -4
- birder/net/detection/rt_detr_v1.py +81 -25
- birder/net/detection/rt_detr_v2.py +1147 -0
- birder/net/detection/ssd.py +5 -5
- birder/net/detection/yolo_v2.py +12 -12
- birder/net/detection/yolo_v3.py +19 -19
- birder/net/detection/yolo_v4.py +16 -16
- birder/net/detection/yolo_v4_tiny.py +3 -3
- birder/net/dpn.py +1 -2
- birder/net/edgenext.py +5 -4
- birder/net/edgevit.py +13 -14
- birder/net/efficientformer_v1.py +3 -2
- birder/net/efficientformer_v2.py +18 -31
- birder/net/efficientnet_v2.py +3 -0
- birder/net/efficientvim.py +9 -9
- birder/net/efficientvit_mit.py +7 -7
- birder/net/efficientvit_msft.py +3 -3
- birder/net/fasternet.py +3 -3
- birder/net/fastvit.py +5 -12
- birder/net/flexivit.py +50 -58
- birder/net/focalnet.py +5 -9
- birder/net/gc_vit.py +11 -11
- birder/net/ghostnet_v1.py +1 -1
- birder/net/ghostnet_v2.py +1 -1
- birder/net/groupmixformer.py +13 -13
- birder/net/hgnet_v1.py +6 -6
- birder/net/hgnet_v2.py +4 -4
- birder/net/hiera.py +6 -6
- birder/net/hieradet.py +9 -9
- birder/net/hornet.py +3 -3
- birder/net/iformer.py +4 -4
- birder/net/inception_next.py +5 -15
- birder/net/inception_resnet_v1.py +3 -3
- birder/net/inception_resnet_v2.py +7 -4
- birder/net/inception_v3.py +3 -0
- birder/net/inception_v4.py +3 -0
- birder/net/levit.py +3 -3
- birder/net/lit_v1.py +13 -15
- birder/net/lit_v1_tiny.py +9 -9
- birder/net/lit_v2.py +14 -15
- birder/net/maxvit.py +11 -23
- birder/net/metaformer.py +5 -5
- birder/net/mim/crossmae.py +6 -6
- birder/net/mim/fcmae.py +3 -5
- birder/net/mim/mae_hiera.py +7 -7
- birder/net/mim/mae_vit.py +4 -6
- birder/net/mim/simmim.py +3 -4
- birder/net/mobilenet_v1.py +0 -9
- birder/net/mobilenet_v2.py +38 -44
- birder/net/{mobilenet_v3_large.py → mobilenet_v3.py} +37 -10
- birder/net/mobilenet_v4_hybrid.py +4 -4
- birder/net/mobileone.py +5 -12
- birder/net/mobilevit_v1.py +7 -34
- birder/net/mobilevit_v2.py +6 -54
- birder/net/moganet.py +8 -5
- birder/net/mvit_v2.py +30 -30
- birder/net/nextvit.py +2 -2
- birder/net/nfnet.py +4 -0
- birder/net/pit.py +11 -26
- birder/net/pvt_v1.py +9 -9
- birder/net/pvt_v2.py +10 -16
- birder/net/regionvit.py +15 -15
- birder/net/regnet.py +1 -1
- birder/net/repghost.py +5 -35
- birder/net/repvgg.py +3 -5
- birder/net/repvit.py +2 -2
- birder/net/resmlp.py +2 -2
- birder/net/resnest.py +4 -1
- birder/net/resnet_v1.py +125 -1
- birder/net/resnet_v2.py +75 -1
- birder/net/resnext.py +35 -1
- birder/net/rope_deit3.py +62 -151
- birder/net/rope_flexivit.py +46 -33
- birder/net/rope_vit.py +44 -758
- birder/net/sequencer2d.py +3 -4
- birder/net/shufflenet_v1.py +1 -1
- birder/net/shufflenet_v2.py +1 -1
- birder/net/simple_vit.py +69 -21
- birder/net/smt.py +8 -8
- birder/net/squeezenet.py +5 -12
- birder/net/squeezenext.py +0 -24
- birder/net/ssl/barlow_twins.py +1 -1
- birder/net/ssl/byol.py +2 -2
- birder/net/ssl/capi.py +4 -4
- birder/net/ssl/data2vec.py +1 -1
- birder/net/ssl/data2vec2.py +1 -1
- birder/net/ssl/dino_v2.py +13 -3
- birder/net/ssl/franca.py +28 -4
- birder/net/ssl/i_jepa.py +5 -5
- birder/net/ssl/ibot.py +1 -1
- birder/net/ssl/mmcr.py +1 -1
- birder/net/swiftformer.py +13 -3
- birder/net/swin_transformer_v1.py +4 -5
- birder/net/swin_transformer_v2.py +5 -8
- birder/net/tiny_vit.py +6 -19
- birder/net/transnext.py +19 -19
- birder/net/uniformer.py +4 -4
- birder/net/van.py +2 -2
- birder/net/vgg.py +1 -10
- birder/net/vit.py +72 -987
- birder/net/vit_parallel.py +35 -20
- birder/net/vit_sam.py +23 -48
- birder/net/vovnet_v2.py +1 -1
- birder/net/xcit.py +16 -13
- birder/ops/msda.py +4 -4
- birder/ops/swattention.py +10 -10
- birder/results/classification.py +3 -3
- birder/results/gui.py +8 -8
- birder/scripts/benchmark.py +37 -12
- birder/scripts/evaluate.py +1 -1
- birder/scripts/predict.py +3 -3
- birder/scripts/predict_detection.py +2 -2
- birder/scripts/train.py +63 -15
- birder/scripts/train_barlow_twins.py +10 -7
- birder/scripts/train_byol.py +10 -7
- birder/scripts/train_capi.py +15 -10
- birder/scripts/train_data2vec.py +10 -7
- birder/scripts/train_data2vec2.py +10 -7
- birder/scripts/train_detection.py +29 -14
- birder/scripts/train_dino_v1.py +13 -9
- birder/scripts/train_dino_v2.py +27 -14
- birder/scripts/train_dino_v2_dist.py +28 -15
- birder/scripts/train_franca.py +16 -9
- birder/scripts/train_i_jepa.py +12 -9
- birder/scripts/train_ibot.py +15 -11
- birder/scripts/train_kd.py +64 -17
- birder/scripts/train_mim.py +11 -8
- birder/scripts/train_mmcr.py +11 -8
- birder/scripts/train_rotnet.py +11 -7
- birder/scripts/train_simclr.py +10 -7
- birder/scripts/train_vicreg.py +10 -7
- birder/tools/adversarial.py +4 -4
- birder/tools/auto_anchors.py +5 -5
- birder/tools/avg_model.py +1 -1
- birder/tools/convert_model.py +30 -22
- birder/tools/det_results.py +1 -1
- birder/tools/download_model.py +1 -1
- birder/tools/ensemble_model.py +1 -1
- birder/tools/introspection.py +12 -3
- birder/tools/labelme_to_coco.py +2 -2
- birder/tools/model_info.py +15 -15
- birder/tools/pack.py +8 -8
- birder/tools/quantize_model.py +53 -4
- birder/tools/results.py +2 -2
- birder/tools/show_det_iterator.py +19 -6
- birder/tools/show_iterator.py +2 -2
- birder/tools/similarity.py +5 -5
- birder/tools/stats.py +4 -6
- birder/tools/voc_to_coco.py +1 -1
- birder/version.py +1 -1
- {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/METADATA +3 -3
- birder-0.4.1.dist-info/RECORD +300 -0
- {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/WHEEL +1 -1
- birder/net/mobilenet_v3_small.py +0 -43
- birder/net/se_resnet_v1.py +0 -105
- birder/net/se_resnet_v2.py +0 -59
- birder/net/se_resnext.py +0 -30
- birder-0.3.3.dist-info/RECORD +0 -299
- {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/entry_points.txt +0 -0
- {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/licenses/LICENSE +0 -0
- {birder-0.3.3.dist-info → birder-0.4.1.dist-info}/top_level.txt +0 -0
birder/net/cswin_transformer.py
CHANGED
|
@@ -29,7 +29,7 @@ from birder.net.vit import PatchEmbed
|
|
|
29
29
|
|
|
30
30
|
|
|
31
31
|
def img2windows(img: torch.Tensor, h_sp: int, w_sp: int) -> torch.Tensor:
|
|
32
|
-
|
|
32
|
+
B, C, H, W = img.size()
|
|
33
33
|
img_reshape = img.view(B, C, H // h_sp, h_sp, W // w_sp, w_sp)
|
|
34
34
|
img_perm = img_reshape.permute(0, 2, 4, 3, 5, 1).contiguous().reshape(-1, h_sp * w_sp, C)
|
|
35
35
|
|
|
@@ -81,7 +81,7 @@ class LePEAttention(nn.Module):
|
|
|
81
81
|
raise ValueError("unsupported idx")
|
|
82
82
|
|
|
83
83
|
def im2cswin(self, x: torch.Tensor) -> torch.Tensor:
|
|
84
|
-
|
|
84
|
+
B, _, C = x.size()
|
|
85
85
|
x = x.transpose(-2, -1).contiguous().view(B, C, self.resolution[0], self.resolution[1])
|
|
86
86
|
x = img2windows(x, self.h_sp, self.w_sp)
|
|
87
87
|
x = x.reshape(-1, self.h_sp * self.w_sp, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3).contiguous()
|
|
@@ -89,7 +89,7 @@ class LePEAttention(nn.Module):
|
|
|
89
89
|
return x
|
|
90
90
|
|
|
91
91
|
def get_lepe(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
|
92
|
-
|
|
92
|
+
B, _, C = x.size()
|
|
93
93
|
H = self.resolution[0]
|
|
94
94
|
W = self.resolution[1]
|
|
95
95
|
x = x.transpose(-2, -1).contiguous().view(B, C, H, W)
|
|
@@ -107,13 +107,13 @@ class LePEAttention(nn.Module):
|
|
|
107
107
|
return (x, lepe)
|
|
108
108
|
|
|
109
109
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
110
|
-
|
|
110
|
+
q, k, v = x.unbind(0)
|
|
111
111
|
|
|
112
|
-
|
|
112
|
+
B, _, C = q.shape
|
|
113
113
|
|
|
114
114
|
q = self.im2cswin(q)
|
|
115
115
|
k = self.im2cswin(k)
|
|
116
|
-
|
|
116
|
+
v, lepe = self.get_lepe(v)
|
|
117
117
|
|
|
118
118
|
q = q * self.scale
|
|
119
119
|
attn = q @ k.transpose(-2, -1) # B head N C @ B head C N --> B head N N
|
|
@@ -136,12 +136,12 @@ class MergeBlock(nn.Module):
|
|
|
136
136
|
self.resolution = resolution
|
|
137
137
|
|
|
138
138
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
139
|
-
|
|
139
|
+
B, _, C = x.size()
|
|
140
140
|
H = self.resolution[0]
|
|
141
141
|
W = self.resolution[1]
|
|
142
142
|
x = x.transpose(-2, -1).contiguous().view(B, C, H, W)
|
|
143
143
|
x = self.conv(x)
|
|
144
|
-
|
|
144
|
+
B, C = x.shape[:2]
|
|
145
145
|
x = x.view(B, C, -1).transpose(-2, -1).contiguous()
|
|
146
146
|
x = self.norm(x)
|
|
147
147
|
|
|
@@ -206,7 +206,7 @@ class CSWinBlock(nn.Module):
|
|
|
206
206
|
self.drop_path = StochasticDepth(drop_path, mode="row")
|
|
207
207
|
|
|
208
208
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
209
|
-
|
|
209
|
+
B, _, C = x.shape
|
|
210
210
|
|
|
211
211
|
qkv = self.qkv(self.norm1(x)).reshape(B, -1, 3, C).permute(2, 0, 1, 3)
|
|
212
212
|
if self.branch_num == 2:
|
|
@@ -350,7 +350,7 @@ class CSWin_Transformer(DetectorBackbone):
|
|
|
350
350
|
for name, module in self.body.named_children():
|
|
351
351
|
x = module(x)
|
|
352
352
|
if name in self.return_stages:
|
|
353
|
-
|
|
353
|
+
B, L, C = x.size()
|
|
354
354
|
H = int(math.sqrt(L))
|
|
355
355
|
W = H
|
|
356
356
|
out[name] = x.transpose(-2, -1).contiguous().view(B, C, H, W)
|
birder/net/davit.py
CHANGED
|
@@ -31,7 +31,7 @@ from birder.net.base import TokenRetentionResultType
|
|
|
31
31
|
|
|
32
32
|
|
|
33
33
|
def window_partition(x: torch.Tensor, window_size: tuple[int, int]) -> torch.Tensor:
|
|
34
|
-
|
|
34
|
+
B, H, W, C = x.shape
|
|
35
35
|
x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C)
|
|
36
36
|
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C)
|
|
37
37
|
|
|
@@ -64,7 +64,7 @@ class ConvPosEnc(nn.Module):
|
|
|
64
64
|
dim,
|
|
65
65
|
kernel_size=kernel_size,
|
|
66
66
|
stride=(1, 1),
|
|
67
|
-
padding=(kernel_size[0] // 2, kernel_size[1] // 2),
|
|
67
|
+
padding=((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2),
|
|
68
68
|
groups=dim,
|
|
69
69
|
)
|
|
70
70
|
if act is True:
|
|
@@ -92,10 +92,10 @@ class Downsample(nn.Module):
|
|
|
92
92
|
)
|
|
93
93
|
|
|
94
94
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
95
|
-
|
|
95
|
+
_, _, H, W = x.shape
|
|
96
96
|
x = self.norm(x)
|
|
97
97
|
if self.even_k is True:
|
|
98
|
-
|
|
98
|
+
k_h, k_w = self.conv.kernel_size
|
|
99
99
|
pad_r = (k_w - W % k_w) % k_w
|
|
100
100
|
pad_b = (k_h - H % k_h) % k_h
|
|
101
101
|
x = F.pad(x, (0, pad_r, 0, pad_b))
|
|
@@ -115,10 +115,10 @@ class ChannelAttention(nn.Module):
|
|
|
115
115
|
self.proj = nn.Linear(dim, dim)
|
|
116
116
|
|
|
117
117
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
118
|
-
|
|
118
|
+
B, N, C = x.shape
|
|
119
119
|
|
|
120
120
|
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
|
121
|
-
|
|
121
|
+
q, k, v = qkv.unbind(0)
|
|
122
122
|
|
|
123
123
|
k = k * self.scale
|
|
124
124
|
attn = k.transpose(-1, -2) @ v
|
|
@@ -151,7 +151,7 @@ class ChannelBlock(nn.Module):
|
|
|
151
151
|
self.drop_path = StochasticDepth(drop_path, mode="row")
|
|
152
152
|
|
|
153
153
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
154
|
-
|
|
154
|
+
B, C, H, W = x.shape
|
|
155
155
|
x = self.cpe1(x).flatten(2).transpose(1, 2)
|
|
156
156
|
|
|
157
157
|
cur = self.norm1(x)
|
|
@@ -177,10 +177,10 @@ class WindowAttention(nn.Module):
|
|
|
177
177
|
self.proj = nn.Linear(dim, dim)
|
|
178
178
|
|
|
179
179
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
180
|
-
|
|
180
|
+
B, N, C = x.shape
|
|
181
181
|
|
|
182
182
|
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
|
183
|
-
|
|
183
|
+
q, k, v = qkv.unbind(0)
|
|
184
184
|
|
|
185
185
|
x = F.scaled_dot_product_attention(q, k, v, scale=self.scale) # pylint: disable=not-callable
|
|
186
186
|
x = x.transpose(1, 2).reshape(B, N, C)
|
|
@@ -215,7 +215,7 @@ class SpatialBlock(nn.Module):
|
|
|
215
215
|
|
|
216
216
|
# pylint: disable=invalid-name
|
|
217
217
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
218
|
-
|
|
218
|
+
B, C, H, W = x.shape
|
|
219
219
|
|
|
220
220
|
shortcut = self.cpe1(x).flatten(2).transpose(1, 2)
|
|
221
221
|
|
|
@@ -226,7 +226,7 @@ class SpatialBlock(nn.Module):
|
|
|
226
226
|
pad_r = (self.window_size[1] - W % self.window_size[1]) % self.window_size[1]
|
|
227
227
|
pad_b = (self.window_size[0] - H % self.window_size[0]) % self.window_size[0]
|
|
228
228
|
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
|
|
229
|
-
|
|
229
|
+
_, Hp, Wp, _ = x.shape
|
|
230
230
|
|
|
231
231
|
x_windows = window_partition(x, self.window_size)
|
|
232
232
|
x_windows = x_windows.view(-1, self.window_size[0] * self.window_size[1], C)
|
birder/net/deit.py
CHANGED
|
@@ -16,13 +16,18 @@ import torch
|
|
|
16
16
|
from torch import nn
|
|
17
17
|
|
|
18
18
|
from birder.model_registry import registry
|
|
19
|
-
from birder.net.
|
|
19
|
+
from birder.net._vit_configs import BASE
|
|
20
|
+
from birder.net._vit_configs import SMALL
|
|
21
|
+
from birder.net._vit_configs import TINY
|
|
22
|
+
from birder.net.base import DetectorBackbone
|
|
23
|
+
from birder.net.base import normalize_out_indices
|
|
20
24
|
from birder.net.vit import Encoder
|
|
21
25
|
from birder.net.vit import PatchEmbed
|
|
22
26
|
from birder.net.vit import adjust_position_embedding
|
|
23
27
|
|
|
24
28
|
|
|
25
|
-
|
|
29
|
+
# pylint: disable=too-many-instance-attributes
|
|
30
|
+
class DeiT(DetectorBackbone):
|
|
26
31
|
block_group_regex = r"encoder\.block\.(\d+)"
|
|
27
32
|
|
|
28
33
|
def __init__(
|
|
@@ -44,6 +49,7 @@ class DeiT(BaseNet):
|
|
|
44
49
|
num_heads: int = self.config["num_heads"]
|
|
45
50
|
hidden_dim: int = self.config["hidden_dim"]
|
|
46
51
|
mlp_dim: int = self.config["mlp_dim"]
|
|
52
|
+
out_indices: Optional[list[int]] = self.config.get("out_indices", None)
|
|
47
53
|
drop_path_rate: float = self.config["drop_path_rate"]
|
|
48
54
|
|
|
49
55
|
torch._assert(image_size[0] % patch_size == 0, "Input shape indivisible by patch size!")
|
|
@@ -53,6 +59,7 @@ class DeiT(BaseNet):
|
|
|
53
59
|
self.num_layers = num_layers
|
|
54
60
|
self.hidden_dim = hidden_dim
|
|
55
61
|
self.num_special_tokens = 2
|
|
62
|
+
self.out_indices = normalize_out_indices(out_indices, num_layers)
|
|
56
63
|
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_layers)] # Stochastic depth decay rule
|
|
57
64
|
|
|
58
65
|
self.conv_proj = nn.Conv2d(
|
|
@@ -61,7 +68,6 @@ class DeiT(BaseNet):
|
|
|
61
68
|
kernel_size=(patch_size, patch_size),
|
|
62
69
|
stride=(patch_size, patch_size),
|
|
63
70
|
padding=(0, 0),
|
|
64
|
-
bias=True,
|
|
65
71
|
)
|
|
66
72
|
self.patch_embed = PatchEmbed()
|
|
67
73
|
|
|
@@ -89,11 +95,18 @@ class DeiT(BaseNet):
|
|
|
89
95
|
)
|
|
90
96
|
self.norm = nn.LayerNorm(hidden_dim, eps=1e-6)
|
|
91
97
|
|
|
98
|
+
num_return_stages = len(self.out_indices) if self.out_indices is not None else 1
|
|
99
|
+
self.return_stages = [f"stage{stage_idx + 1}" for stage_idx in range(num_return_stages)]
|
|
100
|
+
self.return_channels = [hidden_dim] * num_return_stages
|
|
92
101
|
self.embedding_size = hidden_dim
|
|
93
102
|
self.dist_classifier = self.create_classifier()
|
|
94
103
|
self.classifier = self.create_classifier()
|
|
95
104
|
self.distillation_output = False
|
|
96
105
|
|
|
106
|
+
self.max_stride = patch_size
|
|
107
|
+
self.stem_stride = patch_size
|
|
108
|
+
self.stem_width = hidden_dim
|
|
109
|
+
|
|
97
110
|
# Weight initialization
|
|
98
111
|
if isinstance(self.conv_proj, nn.Conv2d):
|
|
99
112
|
# Init the patchify stem
|
|
@@ -129,6 +142,53 @@ class DeiT(BaseNet):
|
|
|
129
142
|
def set_causal_attention(self, is_causal: bool = True) -> None:
|
|
130
143
|
self.encoder.set_causal_attention(is_causal)
|
|
131
144
|
|
|
145
|
+
def transform_to_backbone(self) -> None:
|
|
146
|
+
super().transform_to_backbone()
|
|
147
|
+
self.norm = nn.Identity()
|
|
148
|
+
self.dist_classifier = nn.Identity()
|
|
149
|
+
|
|
150
|
+
def detection_features(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
|
|
151
|
+
H, W = x.shape[-2:]
|
|
152
|
+
|
|
153
|
+
# Reshape and permute the input tensor
|
|
154
|
+
x = self.conv_proj(x)
|
|
155
|
+
x = self.patch_embed(x)
|
|
156
|
+
|
|
157
|
+
# Expand the class token to the full batch
|
|
158
|
+
batch_class_token = self.class_token.expand(x.shape[0], -1, -1)
|
|
159
|
+
batch_dist_token = self.dist_token.expand(x.shape[0], -1, -1)
|
|
160
|
+
|
|
161
|
+
x = torch.concat([batch_class_token, batch_dist_token, x], dim=1)
|
|
162
|
+
x = x + self.pos_embedding
|
|
163
|
+
|
|
164
|
+
if self.out_indices is None:
|
|
165
|
+
xs = [self.encoder(x)]
|
|
166
|
+
else:
|
|
167
|
+
xs = self.encoder.forward_features(x, out_indices=self.out_indices)
|
|
168
|
+
|
|
169
|
+
out: dict[str, torch.Tensor] = {}
|
|
170
|
+
for stage_name, stage_x in zip(self.return_stages, xs):
|
|
171
|
+
stage_x = stage_x[:, self.num_special_tokens :]
|
|
172
|
+
stage_x = stage_x.permute(0, 2, 1)
|
|
173
|
+
B, C, _ = stage_x.size()
|
|
174
|
+
stage_x = stage_x.reshape(B, C, H // self.patch_size, W // self.patch_size)
|
|
175
|
+
out[stage_name] = stage_x
|
|
176
|
+
|
|
177
|
+
return out
|
|
178
|
+
|
|
179
|
+
def freeze_stages(self, up_to_stage: int) -> None:
|
|
180
|
+
for param in self.conv_proj.parameters():
|
|
181
|
+
param.requires_grad_(False)
|
|
182
|
+
|
|
183
|
+
self.pos_embedding.requires_grad_(False)
|
|
184
|
+
|
|
185
|
+
for idx, module in enumerate(self.encoder.children()):
|
|
186
|
+
if idx >= up_to_stage:
|
|
187
|
+
break
|
|
188
|
+
|
|
189
|
+
for param in module.parameters():
|
|
190
|
+
param.requires_grad_(False)
|
|
191
|
+
|
|
132
192
|
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
|
|
133
193
|
# Reshape and permute the input tensor
|
|
134
194
|
x = self.conv_proj(x)
|
|
@@ -200,38 +260,17 @@ class DeiT(BaseNet):
|
|
|
200
260
|
registry.register_model_config(
|
|
201
261
|
"deit_t16",
|
|
202
262
|
DeiT,
|
|
203
|
-
config={
|
|
204
|
-
"patch_size": 16,
|
|
205
|
-
"num_layers": 12,
|
|
206
|
-
"num_heads": 3,
|
|
207
|
-
"hidden_dim": 192,
|
|
208
|
-
"mlp_dim": 768,
|
|
209
|
-
"drop_path_rate": 0.0,
|
|
210
|
-
},
|
|
263
|
+
config={"patch_size": 16, **TINY},
|
|
211
264
|
)
|
|
212
265
|
registry.register_model_config(
|
|
213
266
|
"deit_s16",
|
|
214
267
|
DeiT,
|
|
215
|
-
config={
|
|
216
|
-
"patch_size": 16,
|
|
217
|
-
"num_layers": 12,
|
|
218
|
-
"num_heads": 6,
|
|
219
|
-
"hidden_dim": 384,
|
|
220
|
-
"mlp_dim": 1536,
|
|
221
|
-
"drop_path_rate": 0.1,
|
|
222
|
-
},
|
|
268
|
+
config={"patch_size": 16, **SMALL, "drop_path_rate": 0.1}, # Override the SMALL definition
|
|
223
269
|
)
|
|
224
270
|
registry.register_model_config(
|
|
225
271
|
"deit_b16",
|
|
226
272
|
DeiT,
|
|
227
|
-
config={
|
|
228
|
-
"patch_size": 16,
|
|
229
|
-
"num_layers": 12,
|
|
230
|
-
"num_heads": 12,
|
|
231
|
-
"hidden_dim": 768,
|
|
232
|
-
"mlp_dim": 3072,
|
|
233
|
-
"drop_path_rate": 0.1,
|
|
234
|
-
},
|
|
273
|
+
config={"patch_size": 16, **BASE},
|
|
235
274
|
)
|
|
236
275
|
|
|
237
276
|
registry.register_weights(
|
|
@@ -242,7 +281,7 @@ registry.register_weights(
|
|
|
242
281
|
"formats": {
|
|
243
282
|
"pt": {
|
|
244
283
|
"file_size": 21.7,
|
|
245
|
-
"sha256": "
|
|
284
|
+
"sha256": "68b33aba0c1be5e78d4a33e74a7c1ea72b6abb232d59f0048ff9b8342e43246e",
|
|
246
285
|
}
|
|
247
286
|
},
|
|
248
287
|
"net": {"network": "deit_t16", "tag": "il-common"},
|
|
@@ -258,7 +297,7 @@ registry.register_weights(
|
|
|
258
297
|
"formats": {
|
|
259
298
|
"pt": {
|
|
260
299
|
"file_size": 21.7,
|
|
261
|
-
"sha256": "
|
|
300
|
+
"sha256": "f693e89fc350341141c55152bec9f499df63738e8423071f3b8e71801c3e5415",
|
|
262
301
|
}
|
|
263
302
|
},
|
|
264
303
|
"net": {"network": "deit_t16", "tag": "dist-il-common"},
|