birder 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- birder/adversarial/base.py +1 -1
- birder/adversarial/simba.py +4 -4
- birder/common/cli.py +1 -1
- birder/common/fs_ops.py +11 -11
- birder/common/lib.py +2 -2
- birder/common/masking.py +3 -3
- birder/common/training_cli.py +24 -2
- birder/common/training_utils.py +28 -4
- birder/data/collators/detection.py +9 -1
- birder/data/transforms/detection.py +27 -8
- birder/data/transforms/mosaic.py +1 -1
- birder/datahub/classification.py +3 -3
- birder/inference/classification.py +3 -3
- birder/inference/data_parallel.py +1 -1
- birder/inference/detection.py +5 -5
- birder/inference/wbf.py +1 -1
- birder/introspection/attention_rollout.py +5 -5
- birder/introspection/feature_pca.py +4 -4
- birder/introspection/gradcam.py +1 -1
- birder/introspection/guided_backprop.py +2 -2
- birder/introspection/transformer_attribution.py +3 -3
- birder/layers/attention_pool.py +2 -2
- birder/model_registry/model_registry.py +2 -1
- birder/net/__init__.py +2 -0
- birder/net/_rope_vit_configs.py +5 -0
- birder/net/_vit_configs.py +0 -13
- birder/net/alexnet.py +5 -5
- birder/net/base.py +28 -3
- birder/net/biformer.py +17 -17
- birder/net/cait.py +2 -2
- birder/net/cas_vit.py +1 -1
- birder/net/coat.py +15 -15
- birder/net/convnext_v1.py +2 -10
- birder/net/convnext_v1_iso.py +198 -0
- birder/net/convnext_v2.py +2 -10
- birder/net/crossformer.py +9 -9
- birder/net/crossvit.py +1 -1
- birder/net/cspnet.py +1 -1
- birder/net/cswin_transformer.py +10 -10
- birder/net/davit.py +10 -10
- birder/net/deit.py +56 -3
- birder/net/deit3.py +27 -15
- birder/net/detection/__init__.py +4 -0
- birder/net/detection/{yolo_anchors.py → _yolo_anchors.py} +5 -5
- birder/net/detection/base.py +6 -5
- birder/net/detection/deformable_detr.py +26 -28
- birder/net/detection/detr.py +9 -9
- birder/net/detection/efficientdet.py +9 -28
- birder/net/detection/faster_rcnn.py +22 -22
- birder/net/detection/fcos.py +8 -8
- birder/net/detection/plain_detr.py +852 -0
- birder/net/detection/retinanet.py +4 -4
- birder/net/detection/rt_detr_v1.py +81 -25
- birder/net/detection/rt_detr_v2.py +1147 -0
- birder/net/detection/ssd.py +5 -5
- birder/net/detection/yolo_v2.py +12 -12
- birder/net/detection/yolo_v3.py +19 -19
- birder/net/detection/yolo_v4.py +16 -16
- birder/net/detection/yolo_v4_tiny.py +3 -3
- birder/net/edgenext.py +3 -3
- birder/net/edgevit.py +10 -14
- birder/net/efficientformer_v1.py +1 -1
- birder/net/efficientvim.py +9 -9
- birder/net/efficientvit_mit.py +2 -2
- birder/net/efficientvit_msft.py +3 -3
- birder/net/fasternet.py +1 -1
- birder/net/fastvit.py +5 -12
- birder/net/flexivit.py +28 -15
- birder/net/focalnet.py +5 -9
- birder/net/gc_vit.py +11 -11
- birder/net/ghostnet_v1.py +1 -1
- birder/net/ghostnet_v2.py +1 -1
- birder/net/groupmixformer.py +12 -12
- birder/net/hgnet_v1.py +1 -1
- birder/net/hgnet_v2.py +4 -4
- birder/net/hiera.py +6 -6
- birder/net/hieradet.py +9 -9
- birder/net/hornet.py +3 -3
- birder/net/iformer.py +4 -4
- birder/net/inception_next.py +4 -14
- birder/net/levit.py +3 -3
- birder/net/lit_v1.py +13 -15
- birder/net/lit_v1_tiny.py +9 -9
- birder/net/lit_v2.py +14 -15
- birder/net/maxvit.py +10 -22
- birder/net/metaformer.py +2 -2
- birder/net/mim/crossmae.py +5 -5
- birder/net/mim/fcmae.py +3 -5
- birder/net/mim/mae_hiera.py +7 -7
- birder/net/mim/mae_vit.py +3 -5
- birder/net/mim/simmim.py +2 -3
- birder/net/mobilenet_v4_hybrid.py +4 -4
- birder/net/mobileone.py +5 -12
- birder/net/mobilevit_v1.py +2 -2
- birder/net/mobilevit_v2.py +5 -9
- birder/net/mvit_v2.py +24 -24
- birder/net/nextvit.py +2 -2
- birder/net/pit.py +11 -26
- birder/net/pvt_v1.py +4 -4
- birder/net/pvt_v2.py +5 -11
- birder/net/regionvit.py +15 -15
- birder/net/regnet.py +1 -1
- birder/net/repghost.py +4 -5
- birder/net/repvgg.py +3 -5
- birder/net/repvit.py +2 -2
- birder/net/resnest.py +1 -1
- birder/net/rope_deit3.py +29 -15
- birder/net/rope_flexivit.py +28 -15
- birder/net/rope_vit.py +41 -23
- birder/net/sequencer2d.py +3 -4
- birder/net/shufflenet_v1.py +1 -1
- birder/net/shufflenet_v2.py +1 -1
- birder/net/simple_vit.py +47 -5
- birder/net/smt.py +7 -7
- birder/net/ssl/barlow_twins.py +1 -1
- birder/net/ssl/byol.py +2 -2
- birder/net/ssl/capi.py +3 -3
- birder/net/ssl/data2vec2.py +1 -1
- birder/net/ssl/dino_v2.py +11 -1
- birder/net/ssl/franca.py +26 -2
- birder/net/ssl/i_jepa.py +4 -4
- birder/net/ssl/mmcr.py +1 -1
- birder/net/swiftformer.py +1 -1
- birder/net/swin_transformer_v1.py +4 -5
- birder/net/swin_transformer_v2.py +4 -7
- birder/net/tiny_vit.py +3 -3
- birder/net/transnext.py +19 -19
- birder/net/uniformer.py +4 -4
- birder/net/vgg.py +1 -10
- birder/net/vit.py +38 -25
- birder/net/vit_parallel.py +35 -20
- birder/net/vit_sam.py +10 -10
- birder/net/vovnet_v2.py +1 -1
- birder/net/xcit.py +9 -7
- birder/ops/msda.py +4 -4
- birder/ops/swattention.py +10 -10
- birder/results/classification.py +3 -3
- birder/results/gui.py +8 -8
- birder/scripts/benchmark.py +37 -12
- birder/scripts/evaluate.py +1 -1
- birder/scripts/predict.py +3 -3
- birder/scripts/predict_detection.py +2 -2
- birder/scripts/train.py +63 -15
- birder/scripts/train_barlow_twins.py +10 -7
- birder/scripts/train_byol.py +10 -7
- birder/scripts/train_capi.py +15 -10
- birder/scripts/train_data2vec.py +10 -7
- birder/scripts/train_data2vec2.py +10 -7
- birder/scripts/train_detection.py +29 -14
- birder/scripts/train_dino_v1.py +13 -9
- birder/scripts/train_dino_v2.py +27 -14
- birder/scripts/train_dino_v2_dist.py +28 -15
- birder/scripts/train_franca.py +16 -9
- birder/scripts/train_i_jepa.py +12 -9
- birder/scripts/train_ibot.py +15 -11
- birder/scripts/train_kd.py +64 -17
- birder/scripts/train_mim.py +11 -8
- birder/scripts/train_mmcr.py +11 -8
- birder/scripts/train_rotnet.py +11 -7
- birder/scripts/train_simclr.py +10 -7
- birder/scripts/train_vicreg.py +10 -7
- birder/tools/adversarial.py +4 -4
- birder/tools/auto_anchors.py +5 -5
- birder/tools/avg_model.py +1 -1
- birder/tools/convert_model.py +30 -22
- birder/tools/det_results.py +1 -1
- birder/tools/download_model.py +1 -1
- birder/tools/ensemble_model.py +1 -1
- birder/tools/introspection.py +11 -2
- birder/tools/labelme_to_coco.py +2 -2
- birder/tools/model_info.py +12 -14
- birder/tools/pack.py +8 -8
- birder/tools/quantize_model.py +53 -4
- birder/tools/results.py +2 -2
- birder/tools/show_det_iterator.py +19 -6
- birder/tools/show_iterator.py +2 -2
- birder/tools/similarity.py +5 -5
- birder/tools/stats.py +4 -6
- birder/tools/voc_to_coco.py +1 -1
- birder/version.py +1 -1
- {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/METADATA +3 -3
- birder-0.4.1.dist-info/RECORD +300 -0
- {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/WHEEL +1 -1
- birder-0.4.0.dist-info/RECORD +0 -297
- {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/entry_points.txt +0 -0
- {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/licenses/LICENSE +0 -0
- {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/top_level.txt +0 -0
birder/net/repvgg.py
CHANGED

@@ -56,7 +56,6 @@ class RepVggBlock(nn.Module):
                 stride=(stride, stride),
                 padding=(padding, padding),
                 groups=groups,
-                bias=True,
             )
         else:
             self.reparam_conv = None
@@ -113,7 +112,7 @@ class RepVggBlock(nn.Module):
         if self.reparameterized is True:
             return

-
+        kernel, bias = self._get_kernel_bias()
         self.reparam_conv = nn.Conv2d(
             in_channels=self.conv_kxk.conv.in_channels,
             out_channels=self.conv_kxk.conv.out_channels,
@@ -122,7 +121,6 @@ class RepVggBlock(nn.Module):
             padding=self.conv_kxk.conv.padding,
             dilation=self.conv_kxk.conv.dilation,
             groups=self.conv_kxk.conv.groups,
-            bias=True,
         )
         self.reparam_conv.weight.data = kernel
         self.reparam_conv.bias.data = bias
@@ -151,10 +149,10 @@ class RepVggBlock(nn.Module):
         kernel_identity = 0
         bias_identity = 0
         if self.rbr_identity is not None:
-
+            kernel_identity, bias_identity = self._fuse_bn_tensor(self.rbr_identity)

         # Get weights and bias of conv branches
-
+        kernel_conv, bias_conv = self._fuse_bn_tensor(self.conv_kxk)

         kernel_final = kernel_conv + kernel_1x1 + kernel_identity
         bias_final = bias_conv + bias_1x1 + bias_identity
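The reparameterize hunks above fold each BatchNorm branch into its convolution via `_fuse_bn_tensor` and then sum the branch kernels. Below is a minimal standalone sketch of that folding, assuming the usual RepVGG setup of a bias-free conv followed by BatchNorm; `fuse_conv_bn` is illustrative, not birder's API.

import torch
from torch import nn


def fuse_conv_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> tuple[torch.Tensor, torch.Tensor]:
    # Fold BN statistics into the conv: W' = W * gamma / std, b' = beta - mean * gamma / std
    std = (bn.running_var + bn.eps).sqrt()
    kernel = conv.weight * (bn.weight / std)[:, None, None, None]
    bias = bn.bias - bn.running_mean * bn.weight / std
    return kernel, bias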
birder/net/repvit.py
CHANGED

@@ -60,7 +60,7 @@ class RepConvBN(nn.Sequential):
         if self.reparameterized is True:
             return

-
+        c, bn = self._modules.values()
         w = bn.weight / (bn.running_var + bn.eps) ** 0.5
         w = c.weight * w[:, None, None, None]
         b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5
@@ -101,7 +101,7 @@ class RepNormLinear(nn.Sequential):
         if self.reparameterized is True:
             return

-
+        bn, li = self._modules.values()
         w = bn.weight / (bn.running_var + bn.eps) ** 0.5
         b = bn.bias - self.bn.running_mean * self.bn.weight / (bn.running_var + bn.eps) ** 0.5
         w = li.weight * w[None, :]
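The `RepConvBN` fusion above is the same conv-BN folding. In eval mode the fused conv must reproduce the two-module sequence exactly; a quick equivalence check using the `fuse_conv_bn` sketch from the repvgg note (illustrative, not part of the package):

import torch
from torch import nn

conv = nn.Conv2d(8, 16, kernel_size=(3, 3), padding=(1, 1), bias=False)
bn = nn.BatchNorm2d(16)
bn.eval()  # use running statistics; reparameterization assumes inference mode

kernel, bias = fuse_conv_bn(conv, bn)
fused = nn.Conv2d(8, 16, kernel_size=(3, 3), padding=(1, 1), bias=True)
fused.weight.data = kernel
fused.bias.data = bias

x = torch.randn(1, 8, 32, 32)
assert torch.allclose(bn(conv(x)), fused(x), atol=1e-6)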
birder/net/resnest.py
CHANGED

@@ -85,7 +85,7 @@ class SplitAttn(nn.Module):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.conv(x)

-
+        B, RC, H, W = x.size()  # pylint: disable=invalid-name
         if self.radix > 1:
             x = x.reshape((B, self.radix, RC // self.radix, H, W))
             x_gap = x.sum(dim=1)
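For reference, a shape walk-through of the radix split that the unpacked `B, RC, H, W` feeds (sizes are illustrative):

import torch

B, radix, C, H, W = 2, 2, 64, 7, 7
x = torch.randn(B, radix * C, H, W)  # output of self.conv: RC = radix * C channels

x = x.reshape((B, radix, C, H, W))   # separate the radix groups from the channel dim
x_gap = x.sum(dim=1)                 # (B, C, H, W), pooled over radix as in forward()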
birder/net/rope_deit3.py
CHANGED

@@ -34,6 +34,7 @@ from birder.net.base import MaskedTokenRetentionMixin
 from birder.net.base import PreTrainEncoder
 from birder.net.base import TokenOmissionResultType
 from birder.net.base import TokenRetentionResultType
+from birder.net.base import normalize_out_indices
 from birder.net.rope_vit import Encoder
 from birder.net.rope_vit import MAEDecoderBlock
 from birder.net.rope_vit import RoPE
@@ -46,6 +47,7 @@ from birder.net.vit import adjust_position_embedding
 class RoPE_DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedTokenRetentionMixin):
     block_group_regex = r"encoder\.block\.(\d+)"

+    # pylint: disable=too-many-locals
     def __init__(
         self,
         input_channels: int,
@@ -68,6 +70,7 @@ class RoPE_DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedTokenRetentionMixin):
         mlp_dim: int = self.config["mlp_dim"]
         layer_scale_init_value: Optional[float] = self.config.get("layer_scale_init_value", 1e-5)
         num_reg_tokens: int = self.config.get("num_reg_tokens", 0)
+        out_indices: Optional[list[int]] = self.config.get("out_indices", None)
         rope_rot_type: Literal["standard", "interleaved"] = self.config.get("rope_rot_type", "standard")
         rope_grid_indexing: Literal["ij", "xy"] = self.config.get("rope_grid_indexing", "ij")
         rope_grid_offset: int = self.config.get("rope_grid_offset", 0)
@@ -86,6 +89,7 @@ class RoPE_DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedTokenRetentionMixin):
         self.num_reg_tokens = num_reg_tokens
         self.num_special_tokens = 1 + self.num_reg_tokens
         self.pos_embed_special_tokens = pos_embed_special_tokens
+        self.out_indices = normalize_out_indices(out_indices, num_layers)
         self.rope_rot_type = rope_rot_type
         self.rope_grid_indexing = rope_grid_indexing
         self.rope_grid_offset = rope_grid_offset
@@ -105,7 +109,6 @@ class RoPE_DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedTokenRetentionMixin):
             kernel_size=(patch_size, patch_size),
             stride=(patch_size, patch_size),
             padding=(0, 0),
-            bias=True,
         )
         self.patch_embed = PatchEmbed()

@@ -153,8 +156,9 @@ class RoPE_DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedTokenRetentionMixin):
         )
         self.norm = nn.LayerNorm(hidden_dim, eps=1e-6)

-        self.
-        self.
+        num_return_stages = len(self.out_indices) if self.out_indices is not None else 1
+        self.return_stages = [f"stage{stage_idx + 1}" for stage_idx in range(num_return_stages)]
+        self.return_channels = [hidden_dim] * num_return_stages
         self.embedding_size = hidden_dim
         self.classifier = self.create_classifier()

@@ -222,7 +226,7 @@ class RoPE_DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedTokenRetentionMixin):
         ).to(self.rope.pos_embed.device)

     def detection_features(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
-
+        H, W = x.shape[-2:]
         x = self.conv_proj(x)
         x = self.patch_embed(x)

@@ -238,15 +242,21 @@ class RoPE_DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedTokenRetentionMixin):
         x = x + self._get_pos_embed(H, W)
         x = torch.concat([batch_special_tokens, x], dim=1)

-
-
+        rope = self._get_rope_embed(H, W)
+        if self.out_indices is None:
+            xs = [self.encoder(x, rope)]
+        else:
+            xs = self.encoder.forward_features(x, rope, out_indices=self.out_indices)

-
-
-
-
+        out: dict[str, torch.Tensor] = {}
+        for stage_name, stage_x in zip(self.return_stages, xs):
+            stage_x = stage_x[:, self.num_special_tokens :]
+            stage_x = stage_x.permute(0, 2, 1)
+            B, C, _ = stage_x.size()
+            stage_x = stage_x.reshape(B, C, H // self.patch_size, W // self.patch_size)
+            out[stage_name] = stage_x

-        return
+        return out

     def freeze_stages(self, up_to_stage: int) -> None:
         for param in self.conv_proj.parameters():
@@ -261,6 +271,10 @@ class RoPE_DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedTokenRetentionMixin):
             for param in module.parameters():
                 param.requires_grad_(False)

+    def transform_to_backbone(self) -> None:
+        super().transform_to_backbone()
+        self.norm = nn.Identity()
+
     def set_causal_attention(self, is_causal: bool = True) -> None:
         self.encoder.set_causal_attention(is_causal)

@@ -271,7 +285,7 @@ class RoPE_DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedTokenRetentionMixin):
         return_all_features: bool = False,
         return_keys: Literal["all", "tokens", "embedding"] = "tokens",
     ) -> TokenOmissionResultType:
-
+        H, W = x.shape[-2:]

         # Reshape and permute the input tensor
         x = self.conv_proj(x)
@@ -340,7 +354,7 @@ class RoPE_DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedTokenRetentionMixin):
         mask_token: Optional[torch.Tensor] = None,
         return_keys: Literal["all", "features", "embedding"] = "features",
     ) -> TokenRetentionResultType:
-
+        H, W = x.shape[-2:]

         x = self.conv_proj(x)
         x = mask_tensor(x, mask, mask_token=mask_token, patch_factor=self.max_stride // self.stem_stride)
@@ -370,7 +384,7 @@ class RoPE_DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedTokenRetentionMixin):
         if return_keys in ("all", "features"):
             features = x[:, self.num_special_tokens :]
             features = features.permute(0, 2, 1)
-
+            B, C, _ = features.size()
             features = features.reshape(B, C, H // self.patch_size, W // self.patch_size)
             result["features"] = features

@@ -380,7 +394,7 @@ class RoPE_DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedTokenRetentionMixin):
         return result

     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
-
+        H, W = x.shape[-2:]

         # Reshape and permute the input tensor
         x = self.conv_proj(x)
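The rewritten `detection_features` converts per-block token sequences into spatial feature maps, one `stageN` entry per requested block. A self-contained shape sketch of that conversion (sizes are illustrative):

import torch

B, H, W, patch_size, hidden_dim, num_special_tokens = 2, 224, 224, 16, 384, 1
num_patches = (H // patch_size) * (W // patch_size)
tokens = torch.randn(B, num_special_tokens + num_patches, hidden_dim)  # one encoder block output

x = tokens[:, num_special_tokens:]  # drop the class/register tokens
x = x.permute(0, 2, 1)              # (B, C, N)
B_, C, _ = x.size()
fmap = x.reshape(B_, C, H // patch_size, W // patch_size)  # (2, 384, 14, 14)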
birder/net/rope_flexivit.py
CHANGED

@@ -29,6 +29,7 @@ from birder.net.base import MaskedTokenRetentionMixin
 from birder.net.base import PreTrainEncoder
 from birder.net.base import TokenOmissionResultType
 from birder.net.base import TokenRetentionResultType
+from birder.net.base import normalize_out_indices
 from birder.net.flexivit import flex_proj
 from birder.net.flexivit import get_patch_sizes
 from birder.net.flexivit import interpolate_proj
@@ -82,6 +83,7 @@ class RoPE_FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedTokenRetentionMixin):
         norm_layer_eps: float = self.config.get("norm_layer_eps", 1e-6)
         mlp_layer_type: str = self.config.get("mlp_layer_type", "FFN")
         act_layer_type: Optional[str] = self.config.get("act_layer_type", None)  # Default according to mlp type
+        out_indices: Optional[list[int]] = self.config.get("out_indices", None)
         rope_rot_type: Literal["standard", "interleaved"] = self.config.get("rope_rot_type", "standard")
         rope_grid_indexing: Literal["ij", "xy"] = self.config.get("rope_grid_indexing", "ij")
         rope_grid_offset: int = self.config.get("rope_grid_offset", 0)
@@ -125,6 +127,7 @@ class RoPE_FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedTokenRetentionMixin):
         self.norm_layer_eps = norm_layer_eps
         self.mlp_layer = mlp_layer
         self.act_layer = act_layer
+        self.out_indices = normalize_out_indices(out_indices, num_layers)
         self.rope_rot_type = rope_rot_type
         self.rope_grid_indexing = rope_grid_indexing
         self.rope_grid_offset = rope_grid_offset
@@ -145,7 +148,6 @@ class RoPE_FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedTokenRetentionMixin):
             kernel_size=(patch_size, patch_size),
             stride=(patch_size, patch_size),
             padding=(0, 0),
-            bias=True,
         )
         self.patch_embed = PatchEmbed()

@@ -218,8 +220,9 @@ class RoPE_FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedTokenRetentionMixin):

         self.attn_pool = MultiHeadAttentionPool(hidden_dim, attn_pool_num_heads, mlp_dim, qkv_bias=True)

-        self.
-        self.
+        num_return_stages = len(self.out_indices) if self.out_indices is not None else 1
+        self.return_stages = [f"stage{stage_idx + 1}" for stage_idx in range(num_return_stages)]
+        self.return_channels = [hidden_dim] * num_return_stages
         self.embedding_size = hidden_dim
         self.classifier = self.create_classifier()

@@ -307,8 +310,12 @@ class RoPE_FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedTokenRetentionMixin):
     def set_causal_attention(self, is_causal: bool = True) -> None:
         self.encoder.set_causal_attention(is_causal)

+    def transform_to_backbone(self) -> None:
+        super().transform_to_backbone()
+        self.norm = nn.Identity()
+
     def detection_features(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
-
+        H, W = x.shape[-2:]
         x = self.conv_proj(x)
         x = self.patch_embed(x)

@@ -328,15 +335,21 @@ class RoPE_FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedTokenRetentionMixin):
         if self.pos_embed_special_tokens is True:
             x = x + self._get_pos_embed(H, W)

-
-
+        rope = self._get_rope_embed(H, W)
+        if self.out_indices is None:
+            xs = [self.encoder(x, rope)]
+        else:
+            xs = self.encoder.forward_features(x, rope, out_indices=self.out_indices)

-
-
-
-
+        out: dict[str, torch.Tensor] = {}
+        for stage_name, stage_x in zip(self.return_stages, xs):
+            stage_x = stage_x[:, self.num_special_tokens :]
+            stage_x = stage_x.permute(0, 2, 1)
+            B, C, _ = stage_x.size()
+            stage_x = stage_x.reshape(B, C, H // self.patch_size, W // self.patch_size)
+            out[stage_name] = stage_x

-        return
+        return out

     def freeze_stages(self, up_to_stage: int) -> None:
         for param in self.conv_proj.parameters():
@@ -359,7 +372,7 @@ class RoPE_FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedTokenRetentionMixin):
         return_all_features: bool = False,
         return_keys: Literal["all", "tokens", "embedding"] = "tokens",
     ) -> TokenOmissionResultType:
-
+        H, W = x.shape[-2:]

         # Reshape and permute the input tensor
         x = self.conv_proj(x)
@@ -439,7 +452,7 @@ class RoPE_FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedTokenRetentionMixin):
         mask_token: Optional[torch.Tensor] = None,
         return_keys: Literal["all", "features", "embedding"] = "features",
     ) -> TokenRetentionResultType:
-
+        H, W = x.shape[-2:]

         x = self.conv_proj(x)
         x = mask_tensor(x, mask, mask_token=mask_token, patch_factor=self.max_stride // self.stem_stride)
@@ -470,7 +483,7 @@ class RoPE_FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedTokenRetentionMixin):
         if return_keys in ("all", "features"):
             features = x[:, self.num_special_tokens :]
             features = features.permute(0, 2, 1)
-
+            B, C, _ = features.size()
             features = features.reshape(B, C, H // self.patch_size, W // self.patch_size)
             result["features"] = features

@@ -490,7 +503,7 @@ class RoPE_FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedTokenRetentionMixin):
         return result

     def forward_features(self, x: torch.Tensor, patch_size: Optional[int] = None) -> torch.Tensor:
-
+        H, W = x.shape[-2:]

         # Reshape and permute the input tensor
         x = flex_proj(x, self.conv_proj.weight, self.conv_proj.bias, patch_size)
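Both RoPE backbones route `out_indices` through the new `normalize_out_indices` helper before sizing `return_stages`. The diff only shows the call site; a plausible reading of the helper, assuming Python-style negative-index normalization (this sketch is an assumption, not birder's actual implementation):

from typing import Optional


def normalize_out_indices(out_indices: Optional[list[int]], num_layers: int) -> Optional[list[int]]:
    # Assumed behavior: map negative indices (e.g. -1 = last block) to absolute
    # block indices and leave None (single-stage mode) untouched
    if out_indices is None:
        return None

    normalized = [idx if idx >= 0 else num_layers + idx for idx in out_indices]
    assert all(0 <= idx < num_layers for idx in normalized)
    return normalized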
birder/net/rope_vit.py
CHANGED

@@ -38,6 +38,7 @@ from birder.net.base import MaskedTokenRetentionMixin
 from birder.net.base import PreTrainEncoder
 from birder.net.base import TokenOmissionResultType
 from birder.net.base import TokenRetentionResultType
+from birder.net.base import normalize_out_indices
 from birder.net.vit import PatchEmbed
 from birder.net.vit import adjust_position_embedding

@@ -76,7 +77,7 @@ def build_rotary_pos_embed(

 def rotate_half(x: torch.Tensor) -> torch.Tensor:
     # Taken from: https://github.com/facebookresearch/capi/blob/main/model.py
-
+    x1, x2 = x.chunk(2, dim=-1)
     return torch.concat((-x2, x1), dim=-1)


@@ -85,7 +86,7 @@ def rotate_half_interleaved(x: torch.Tensor) -> torch.Tensor:


 def apply_rotary_pos_embed(x: torch.Tensor, embed: torch.Tensor) -> torch.Tensor:
-
+    sin_emb, cos_emb = embed.tensor_split(2, dim=-1)
     if cos_emb.ndim == 3:
         return x * cos_emb.unsqueeze(1).expand_as(x) + rotate_half(x) * sin_emb.unsqueeze(1).expand_as(x)

@@ -93,7 +94,7 @@ def apply_rotary_pos_embed(x: torch.Tensor, embed: torch.Tensor) -> torch.Tensor


 def apply_interleaved_rotary_pos_embed(x: torch.Tensor, embed: torch.Tensor) -> torch.Tensor:
-
+    sin_emb, cos_emb = embed.tensor_split(2, dim=-1)
     if cos_emb.ndim == 3:
         return x * cos_emb.unsqueeze(1).expand_as(x) + rotate_half_interleaved(x) * sin_emb.unsqueeze(1).expand_as(x)

@@ -128,7 +129,7 @@ class RoPE(nn.Module):
         else:
             raise ValueError(f"Unknown rope_rot_type, got '{rope_rot_type}'")

-
+        sin_emb, cos_emb = build_rotary_pos_embed(
             dim,
             temperature,
             grid_size=grid_size,
@@ -185,9 +186,9 @@ class RoPEAttention(nn.Module):
         self.proj_drop = nn.Dropout(proj_drop)

     def forward(self, x: torch.Tensor, rope: torch.Tensor) -> torch.Tensor:
-
+        B, N, C = x.size()
         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
-
+        q, k, v = qkv.unbind(0)
         q = self.q_norm(q)
         k = self.k_norm(k)

@@ -326,13 +327,17 @@ class Encoder(nn.Module):
         x = self.pre_block(x)
         return self.block(x, rope)

-    def forward_features(
+    def forward_features(
+        self, x: torch.Tensor, rope: torch.Tensor, out_indices: Optional[list[int]] = None
+    ) -> list[torch.Tensor]:
         x = self.pre_block(x)

+        out_indices_set = set(out_indices) if out_indices is not None else None
         xs = []
-        for blk in self.block:
+        for idx, blk in enumerate(self.block):
             x = blk(x, rope)
-
+            if out_indices_set is None or idx in out_indices_set:
+                xs.append(x)

         return xs

@@ -438,6 +443,7 @@ class RoPE_ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedTokenRetentionMixin):
         norm_layer_eps: float = self.config.get("norm_layer_eps", 1e-6)
         mlp_layer_type: str = self.config.get("mlp_layer_type", "FFN")
         act_layer_type: Optional[str] = self.config.get("act_layer_type", None)  # Default according to mlp type
+        out_indices: Optional[list[int]] = self.config.get("out_indices", None)
         rope_rot_type: Literal["standard", "interleaved"] = self.config.get("rope_rot_type", "standard")
         rope_grid_indexing: Literal["ij", "xy"] = self.config.get("rope_grid_indexing", "ij")
         rope_grid_offset: int = self.config.get("rope_grid_offset", 0)
@@ -479,6 +485,7 @@ class RoPE_ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedTokenRetentionMixin):
         self.norm_layer_eps = norm_layer_eps
         self.mlp_layer = mlp_layer
         self.act_layer = act_layer
+        self.out_indices = normalize_out_indices(out_indices, num_layers)
         self.rope_rot_type = rope_rot_type
         self.rope_grid_indexing = rope_grid_indexing
         self.rope_grid_offset = rope_grid_offset
@@ -571,8 +578,9 @@ class RoPE_ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedTokenRetentionMixin):

         self.attn_pool = MultiHeadAttentionPool(hidden_dim, attn_pool_num_heads, mlp_dim, qkv_bias=True)

-        self.
-        self.
+        num_return_stages = len(self.out_indices) if self.out_indices is not None else 1
+        self.return_stages = [f"stage{stage_idx + 1}" for stage_idx in range(num_return_stages)]
+        self.return_channels = [hidden_dim] * num_return_stages
         self.embedding_size = hidden_dim
         self.classifier = self.create_classifier()

@@ -658,8 +666,12 @@ class RoPE_ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedTokenRetentionMixin):
     def set_causal_attention(self, is_causal: bool = True) -> None:
         self.encoder.set_causal_attention(is_causal)

+    def transform_to_backbone(self) -> None:
+        super().transform_to_backbone()
+        self.norm = nn.Identity()
+
     def detection_features(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
-
+        H, W = x.shape[-2:]
         x = self.conv_proj(x)
         x = self.patch_embed(x)

@@ -679,15 +691,21 @@ class RoPE_ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedTokenRetentionMixin):
         if self.pos_embed_special_tokens is True:
             x = x + self._get_pos_embed(H, W)

-
-
+        rope = self._get_rope_embed(H, W)
+        if self.out_indices is None:
+            xs = [self.encoder(x, rope)]
+        else:
+            xs = self.encoder.forward_features(x, rope, out_indices=self.out_indices)

-
-
-
-
+        out: dict[str, torch.Tensor] = {}
+        for stage_name, stage_x in zip(self.return_stages, xs):
+            stage_x = stage_x[:, self.num_special_tokens :]
+            stage_x = stage_x.permute(0, 2, 1)
+            B, C, _ = stage_x.size()
+            stage_x = stage_x.reshape(B, C, H // self.patch_size, W // self.patch_size)
+            out[stage_name] = stage_x

-        return
+        return out

     def freeze_stages(self, up_to_stage: int) -> None:
         for param in self.conv_proj.parameters():
@@ -709,7 +727,7 @@ class RoPE_ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedTokenRetentionMixin):
         return_all_features: bool = False,
         return_keys: Literal["all", "tokens", "embedding"] = "tokens",
     ) -> TokenOmissionResultType:
-
+        H, W = x.shape[-2:]

         # Reshape and permute the input tensor
         x = self.conv_proj(x)
@@ -789,7 +807,7 @@ class RoPE_ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedTokenRetentionMixin):
         mask_token: Optional[torch.Tensor] = None,
         return_keys: Literal["all", "features", "embedding"] = "features",
     ) -> TokenRetentionResultType:
-
+        H, W = x.shape[-2:]

         x = self.conv_proj(x)
         x = mask_tensor(x, mask, mask_token=mask_token, patch_factor=self.max_stride // self.stem_stride)
@@ -820,7 +838,7 @@ class RoPE_ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedTokenRetentionMixin):
         if return_keys in ("all", "features"):
             features = x[:, self.num_special_tokens :]
             features = features.permute(0, 2, 1)
-
+            B, C, _ = features.size()
             features = features.reshape(B, C, H // self.patch_size, W // self.patch_size)
             result["features"] = features

@@ -840,7 +858,7 @@ class RoPE_ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedTokenRetentionMixin):
         return result

     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
-
+        H, W = x.shape[-2:]

         # Reshape and permute the input tensor
         x = self.conv_proj(x)
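The `rotate_half` and `apply_rotary_pos_embed` hunks above are the core of the rotary embedding path. A runnable sketch of the same functions, simplified to the 2-D `embed` broadcasting case (the module also handles a 3-D per-sample variant):

import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.concat((-x2, x1), dim=-1)


def apply_rotary_pos_embed(x: torch.Tensor, embed: torch.Tensor) -> torch.Tensor:
    # embed packs the sin half and the cos half along the last dimension
    sin_emb, cos_emb = embed.tensor_split(2, dim=-1)
    return x * cos_emb + rotate_half(x) * sin_emb


q = torch.randn(2, 6, 196, 64)  # (B, num_heads, tokens, head_dim)
embed = torch.randn(196, 128)   # sin and cos, head_dim values each
q_rot = apply_rotary_pos_embed(q, embed)
assert q_rot.shape == q.shape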
birder/net/sequencer2d.py
CHANGED

@@ -57,16 +57,16 @@ class LSTM2d(nn.Module):
         )

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-
+        B, H, W, C = x.shape

         v = x.permute(0, 2, 1, 3)
         v = v.reshape(-1, H, C)
-
+        v, _ = self.rnn_v(v)
         v = v.reshape(B, W, H, -1)
         v = v.permute(0, 2, 1, 3)

         h = x.reshape(-1, W, C)
-
+        h, _ = self.rnn_h(h)
         h = h.reshape(B, H, W, -1)

         x = torch.concat([v, h], dim=-1)
@@ -187,7 +187,6 @@ class Sequencer2d(BaseNet):
             kernel_size=(patch_sizes[0], patch_sizes[0]),
             stride=(patch_sizes[0], patch_sizes[0]),
             padding=(0, 0),
-            bias=True,
         ),
         Permute([0, 2, 3, 1]),
     )
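The `LSTM2d.forward` hunk unpacks the `(B, H, W, C)` input and runs the vertical and horizontal RNNs over columns and rows respectively. A shape sketch of the vertical pass, assuming a bidirectional LSTM as in the Sequencer architecture (sizes are illustrative):

import torch
from torch import nn

B, H, W, C, hidden = 2, 14, 14, 96, 48
rnn_v = nn.LSTM(C, hidden, batch_first=True, bidirectional=True)

x = torch.randn(B, H, W, C)
v = x.permute(0, 2, 1, 3).reshape(-1, H, C)     # each column becomes a sequence: (B*W, H, C)
v, _ = rnn_v(v)                                 # (B*W, H, 2*hidden)
v = v.reshape(B, W, H, -1).permute(0, 2, 1, 3)  # back to (B, H, W, 2*hidden)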
birder/net/shufflenet_v1.py
CHANGED

@@ -22,7 +22,7 @@ from birder.net.base import DetectorBackbone


 def channel_shuffle(x: torch.Tensor, groups: int) -> torch.Tensor:
-
+    batch_size, num_channels, height, width = x.size()
     channels_per_group = num_channels // groups

     # Reshape
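Only the unpacking line of `channel_shuffle` changed; the rest of the function is not shown in this diff. For context, the standard channel-shuffle body that the `# Reshape` comment introduces (a sketch consistent with the shown signature, not copied from the package):

import torch


def channel_shuffle(x: torch.Tensor, groups: int) -> torch.Tensor:
    batch_size, num_channels, height, width = x.size()
    channels_per_group = num_channels // groups

    # Reshape to (B, groups, C/groups, H, W), swap the group axes, flatten back
    x = x.reshape(batch_size, groups, channels_per_group, height, width)
    x = x.transpose(1, 2).contiguous()
    return x.reshape(batch_size, num_channels, height, width)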
birder/net/shufflenet_v2.py
CHANGED

@@ -85,7 +85,7 @@ class ShuffleUnit(nn.Module):

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         if self.dw_conv_stride == 1:
-
+            branch1, branch2 = x.chunk(2, dim=1)
             x = torch.concat((branch1, self.branch2(branch2)), dim=1)
         else:
             x = torch.concat((self.branch1(x), self.branch2(x)), dim=1)
birder/net/simple_vit.py
CHANGED

@@ -26,17 +26,19 @@ from birder.net._vit_configs import HUGE
 from birder.net._vit_configs import LARGE
 from birder.net._vit_configs import MEDIUM
 from birder.net._vit_configs import SMALL
+from birder.net.base import DetectorBackbone
 from birder.net.base import MaskedTokenOmissionMixin
 from birder.net.base import PreTrainEncoder
 from birder.net.base import TokenOmissionResultType
+from birder.net.base import normalize_out_indices
 from birder.net.base import pos_embedding_sin_cos_2d
 from birder.net.vit import Encoder
 from birder.net.vit import EncoderBlock
 from birder.net.vit import PatchEmbed


-# pylint: disable=invalid-name
-class Simple_ViT(PreTrainEncoder, MaskedTokenOmissionMixin):
+# pylint: disable=invalid-name,too-many-instance-attributes
+class Simple_ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin):
     block_group_regex = r"encoder\.block\.(\d+)"

     def __init__(
@@ -56,6 +58,7 @@ class Simple_ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin):
         num_heads: int = self.config["num_heads"]
         hidden_dim: int = self.config["hidden_dim"]
         mlp_dim: int = self.config["mlp_dim"]
+        out_indices: Optional[list[int]] = self.config.get("out_indices", None)
         drop_path_rate: float = self.config["drop_path_rate"]

         torch._assert(image_size[0] % patch_size == 0, "Input shape indivisible by patch size!")
@@ -66,6 +69,7 @@ class Simple_ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin):
         self.hidden_dim = hidden_dim
         self.mlp_dim = mlp_dim
         self.num_special_tokens = 0
+        self.out_indices = normalize_out_indices(out_indices, num_layers)
         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_layers)]  # Stochastic depth decay rule

         self.conv_proj = nn.Conv2d(
@@ -74,7 +78,6 @@ class Simple_ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin):
             kernel_size=(patch_size, patch_size),
             stride=(patch_size, patch_size),
             padding=(0, 0),
-            bias=True,
         )
         self.patch_embed = PatchEmbed()

@@ -94,6 +97,9 @@ class Simple_ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin):
             nn.Flatten(1),
         )

+        num_return_stages = len(self.out_indices) if self.out_indices is not None else 1
+        self.return_stages = [f"stage{stage_idx + 1}" for stage_idx in range(num_return_stages)]
+        self.return_channels = [hidden_dim] * num_return_stages
         self.embedding_size = hidden_dim
         self.classifier = self.create_classifier()

@@ -144,7 +150,7 @@ class Simple_ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin):
         return_all_features: bool = False,
         return_keys: Literal["all", "tokens", "embedding"] = "tokens",
     ) -> TokenOmissionResultType:
-
+        H, W = x.shape[-2:]

         # Reshape and permute the input tensor
         x = self.conv_proj(x)
@@ -179,7 +185,7 @@ class Simple_ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin):
         return result

     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
-
+        H, W = x.shape[-2:]
         x = self.conv_proj(x)
         x = self.patch_embed(x)
         x = x + self._get_pos_embed(H, W)
@@ -193,6 +199,42 @@ class Simple_ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin):
         x = x.permute(0, 2, 1)
         return self.features(x)

+    def transform_to_backbone(self) -> None:
+        super().transform_to_backbone()
+        self.norm = nn.Identity()
+
+    def detection_features(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
+        H, W = x.shape[-2:]
+        x = self.conv_proj(x)
+        x = self.patch_embed(x)
+        x = x + self._get_pos_embed(H, W)
+
+        if self.out_indices is None:
+            xs = [self.encoder(x)]
+        else:
+            xs = self.encoder.forward_features(x, out_indices=self.out_indices)
+
+        out: dict[str, torch.Tensor] = {}
+        for stage_name, stage_x in zip(self.return_stages, xs):
+            stage_x = stage_x[:, self.num_special_tokens :]
+            stage_x = stage_x.permute(0, 2, 1)
+            B, C, _ = stage_x.size()
+            stage_x = stage_x.reshape(B, C, H // self.patch_size, W // self.patch_size)
+            out[stage_name] = stage_x
+
+        return out
+
+    def freeze_stages(self, up_to_stage: int) -> None:
+        for param in self.conv_proj.parameters():
+            param.requires_grad_(False)
+
+        for idx, module in enumerate(self.encoder.children()):
+            if idx >= up_to_stage:
+                break
+
+            for param in module.parameters():
+                param.requires_grad_(False)
+
     def adjust_size(self, new_size: tuple[int, int]) -> None:
         if new_size == self.size:
             return
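With Simple_ViT now a DetectorBackbone, multi-stage features are requested through the `out_indices` config key read above. A hedged configuration sketch (key names follow the `self.config` reads shown in the hunks; the dict itself is illustrative):

config = {
    "num_layers": 12,
    "num_heads": 6,
    "hidden_dim": 384,
    "mlp_dim": 1536,
    "drop_path_rate": 0.0,
    "out_indices": [3, 7, 11],  # collect blocks 3, 7 and 11 as stage1..stage3
}
# detection_features() then returns {"stage1": ..., "stage2": ..., "stage3": ...},
# each of shape (B, hidden_dim, H // patch_size, W // patch_size).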