birder 0.4.0-py3-none-any.whl → 0.4.1-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- birder/adversarial/base.py +1 -1
- birder/adversarial/simba.py +4 -4
- birder/common/cli.py +1 -1
- birder/common/fs_ops.py +11 -11
- birder/common/lib.py +2 -2
- birder/common/masking.py +3 -3
- birder/common/training_cli.py +24 -2
- birder/common/training_utils.py +28 -4
- birder/data/collators/detection.py +9 -1
- birder/data/transforms/detection.py +27 -8
- birder/data/transforms/mosaic.py +1 -1
- birder/datahub/classification.py +3 -3
- birder/inference/classification.py +3 -3
- birder/inference/data_parallel.py +1 -1
- birder/inference/detection.py +5 -5
- birder/inference/wbf.py +1 -1
- birder/introspection/attention_rollout.py +5 -5
- birder/introspection/feature_pca.py +4 -4
- birder/introspection/gradcam.py +1 -1
- birder/introspection/guided_backprop.py +2 -2
- birder/introspection/transformer_attribution.py +3 -3
- birder/layers/attention_pool.py +2 -2
- birder/model_registry/model_registry.py +2 -1
- birder/net/__init__.py +2 -0
- birder/net/_rope_vit_configs.py +5 -0
- birder/net/_vit_configs.py +0 -13
- birder/net/alexnet.py +5 -5
- birder/net/base.py +28 -3
- birder/net/biformer.py +17 -17
- birder/net/cait.py +2 -2
- birder/net/cas_vit.py +1 -1
- birder/net/coat.py +15 -15
- birder/net/convnext_v1.py +2 -10
- birder/net/convnext_v1_iso.py +198 -0
- birder/net/convnext_v2.py +2 -10
- birder/net/crossformer.py +9 -9
- birder/net/crossvit.py +1 -1
- birder/net/cspnet.py +1 -1
- birder/net/cswin_transformer.py +10 -10
- birder/net/davit.py +10 -10
- birder/net/deit.py +56 -3
- birder/net/deit3.py +27 -15
- birder/net/detection/__init__.py +4 -0
- birder/net/detection/{yolo_anchors.py → _yolo_anchors.py} +5 -5
- birder/net/detection/base.py +6 -5
- birder/net/detection/deformable_detr.py +26 -28
- birder/net/detection/detr.py +9 -9
- birder/net/detection/efficientdet.py +9 -28
- birder/net/detection/faster_rcnn.py +22 -22
- birder/net/detection/fcos.py +8 -8
- birder/net/detection/plain_detr.py +852 -0
- birder/net/detection/retinanet.py +4 -4
- birder/net/detection/rt_detr_v1.py +81 -25
- birder/net/detection/rt_detr_v2.py +1147 -0
- birder/net/detection/ssd.py +5 -5
- birder/net/detection/yolo_v2.py +12 -12
- birder/net/detection/yolo_v3.py +19 -19
- birder/net/detection/yolo_v4.py +16 -16
- birder/net/detection/yolo_v4_tiny.py +3 -3
- birder/net/edgenext.py +3 -3
- birder/net/edgevit.py +10 -14
- birder/net/efficientformer_v1.py +1 -1
- birder/net/efficientvim.py +9 -9
- birder/net/efficientvit_mit.py +2 -2
- birder/net/efficientvit_msft.py +3 -3
- birder/net/fasternet.py +1 -1
- birder/net/fastvit.py +5 -12
- birder/net/flexivit.py +28 -15
- birder/net/focalnet.py +5 -9
- birder/net/gc_vit.py +11 -11
- birder/net/ghostnet_v1.py +1 -1
- birder/net/ghostnet_v2.py +1 -1
- birder/net/groupmixformer.py +12 -12
- birder/net/hgnet_v1.py +1 -1
- birder/net/hgnet_v2.py +4 -4
- birder/net/hiera.py +6 -6
- birder/net/hieradet.py +9 -9
- birder/net/hornet.py +3 -3
- birder/net/iformer.py +4 -4
- birder/net/inception_next.py +4 -14
- birder/net/levit.py +3 -3
- birder/net/lit_v1.py +13 -15
- birder/net/lit_v1_tiny.py +9 -9
- birder/net/lit_v2.py +14 -15
- birder/net/maxvit.py +10 -22
- birder/net/metaformer.py +2 -2
- birder/net/mim/crossmae.py +5 -5
- birder/net/mim/fcmae.py +3 -5
- birder/net/mim/mae_hiera.py +7 -7
- birder/net/mim/mae_vit.py +3 -5
- birder/net/mim/simmim.py +2 -3
- birder/net/mobilenet_v4_hybrid.py +4 -4
- birder/net/mobileone.py +5 -12
- birder/net/mobilevit_v1.py +2 -2
- birder/net/mobilevit_v2.py +5 -9
- birder/net/mvit_v2.py +24 -24
- birder/net/nextvit.py +2 -2
- birder/net/pit.py +11 -26
- birder/net/pvt_v1.py +4 -4
- birder/net/pvt_v2.py +5 -11
- birder/net/regionvit.py +15 -15
- birder/net/regnet.py +1 -1
- birder/net/repghost.py +4 -5
- birder/net/repvgg.py +3 -5
- birder/net/repvit.py +2 -2
- birder/net/resnest.py +1 -1
- birder/net/rope_deit3.py +29 -15
- birder/net/rope_flexivit.py +28 -15
- birder/net/rope_vit.py +41 -23
- birder/net/sequencer2d.py +3 -4
- birder/net/shufflenet_v1.py +1 -1
- birder/net/shufflenet_v2.py +1 -1
- birder/net/simple_vit.py +47 -5
- birder/net/smt.py +7 -7
- birder/net/ssl/barlow_twins.py +1 -1
- birder/net/ssl/byol.py +2 -2
- birder/net/ssl/capi.py +3 -3
- birder/net/ssl/data2vec2.py +1 -1
- birder/net/ssl/dino_v2.py +11 -1
- birder/net/ssl/franca.py +26 -2
- birder/net/ssl/i_jepa.py +4 -4
- birder/net/ssl/mmcr.py +1 -1
- birder/net/swiftformer.py +1 -1
- birder/net/swin_transformer_v1.py +4 -5
- birder/net/swin_transformer_v2.py +4 -7
- birder/net/tiny_vit.py +3 -3
- birder/net/transnext.py +19 -19
- birder/net/uniformer.py +4 -4
- birder/net/vgg.py +1 -10
- birder/net/vit.py +38 -25
- birder/net/vit_parallel.py +35 -20
- birder/net/vit_sam.py +10 -10
- birder/net/vovnet_v2.py +1 -1
- birder/net/xcit.py +9 -7
- birder/ops/msda.py +4 -4
- birder/ops/swattention.py +10 -10
- birder/results/classification.py +3 -3
- birder/results/gui.py +8 -8
- birder/scripts/benchmark.py +37 -12
- birder/scripts/evaluate.py +1 -1
- birder/scripts/predict.py +3 -3
- birder/scripts/predict_detection.py +2 -2
- birder/scripts/train.py +63 -15
- birder/scripts/train_barlow_twins.py +10 -7
- birder/scripts/train_byol.py +10 -7
- birder/scripts/train_capi.py +15 -10
- birder/scripts/train_data2vec.py +10 -7
- birder/scripts/train_data2vec2.py +10 -7
- birder/scripts/train_detection.py +29 -14
- birder/scripts/train_dino_v1.py +13 -9
- birder/scripts/train_dino_v2.py +27 -14
- birder/scripts/train_dino_v2_dist.py +28 -15
- birder/scripts/train_franca.py +16 -9
- birder/scripts/train_i_jepa.py +12 -9
- birder/scripts/train_ibot.py +15 -11
- birder/scripts/train_kd.py +64 -17
- birder/scripts/train_mim.py +11 -8
- birder/scripts/train_mmcr.py +11 -8
- birder/scripts/train_rotnet.py +11 -7
- birder/scripts/train_simclr.py +10 -7
- birder/scripts/train_vicreg.py +10 -7
- birder/tools/adversarial.py +4 -4
- birder/tools/auto_anchors.py +5 -5
- birder/tools/avg_model.py +1 -1
- birder/tools/convert_model.py +30 -22
- birder/tools/det_results.py +1 -1
- birder/tools/download_model.py +1 -1
- birder/tools/ensemble_model.py +1 -1
- birder/tools/introspection.py +11 -2
- birder/tools/labelme_to_coco.py +2 -2
- birder/tools/model_info.py +12 -14
- birder/tools/pack.py +8 -8
- birder/tools/quantize_model.py +53 -4
- birder/tools/results.py +2 -2
- birder/tools/show_det_iterator.py +19 -6
- birder/tools/show_iterator.py +2 -2
- birder/tools/similarity.py +5 -5
- birder/tools/stats.py +4 -6
- birder/tools/voc_to_coco.py +1 -1
- birder/version.py +1 -1
- {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/METADATA +3 -3
- birder-0.4.1.dist-info/RECORD +300 -0
- {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/WHEEL +1 -1
- birder-0.4.0.dist-info/RECORD +0 -297
- {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/entry_points.txt +0 -0
- {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/licenses/LICENSE +0 -0
- {birder-0.4.0.dist-info → birder-0.4.1.dist-info}/top_level.txt +0 -0
birder/introspection/transformer_attribution.py CHANGED

@@ -66,7 +66,7 @@ def compute_attribution_rollout(

     mask = mask / (mask.max() + 1e-8)

-    (grid_h, grid_w) = patch_grid_shape
+    grid_h, grid_w = patch_grid_shape
     mask = mask.reshape(grid_h, grid_w)

     return mask

@@ -140,7 +140,7 @@ class TransformerAttribution:
         self.gatherer = AttributionGatherer(net, attention_layer_name)

     def __call__(self, image: str | Path | Image.Image, target_class: Optional[int] = None) -> InterpretabilityResult:
-        (input_tensor, rgb_img) = preprocess_image(image, self.transform, self.device)
+        input_tensor, rgb_img = preprocess_image(image, self.transform, self.device)
         input_tensor.requires_grad_(True)

         self.net.zero_grad()

@@ -156,7 +156,7 @@ class TransformerAttribution:

         attribution_data = self.gatherer.get_captured_data()

-        (_, _, H, W) = input_tensor.shape
+        _, _, H, W = input_tensor.shape
         patch_grid_shape = (H // self.net.stem_stride, W // self.net.stem_stride)

         attribution_map = compute_attribution_rollout(
birder/layers/attention_pool.py CHANGED

@@ -39,13 +39,13 @@ class MultiHeadAttentionPool(nn.Module):
         nn.init.trunc_normal_(self.latent, std=dim**-0.5)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        (B, N, C) = x.size()
+        B, N, C = x.size()

         q_latent = self.latent.expand(B, self.latent_len, -1)
         q = self.q(q_latent).reshape(B, self.latent_len, self.num_heads, self.head_dim).transpose(1, 2)

         kv = self.kv(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
-        (k, v) = kv.unbind(0)
+        k, v = kv.unbind(0)

         x = F.scaled_dot_product_attention(q, k, v, scale=self.scale)  # pylint: disable=not-callable
         x = x.transpose(1, 2).reshape(B, self.latent_len, C)
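Note: the attention_pool change is part of the release-wide tuple-unpacking cleanup; the layer itself is latent-query attention pooling. Below is a minimal self-contained sketch of the same pattern (an illustrative class, not birder's actual module; the constructor details are assumptions):

import torch
import torch.nn.functional as F
from torch import nn


class LatentQueryPool(nn.Module):
    """Illustrative latent-query attention pooling, mirroring the forward above."""

    def __init__(self, dim: int, num_heads: int = 8, latent_len: int = 1) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.latent_len = latent_len
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5
        self.latent = nn.Parameter(torch.zeros(1, latent_len, dim))  # learned query tokens
        self.q = nn.Linear(dim, dim)
        self.kv = nn.Linear(dim, dim * 2)
        nn.init.trunc_normal_(self.latent, std=dim**-0.5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N, C = x.size()
        q_latent = self.latent.expand(B, self.latent_len, -1)
        q = self.q(q_latent).reshape(B, self.latent_len, self.num_heads, self.head_dim).transpose(1, 2)
        kv = self.kv(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        k, v = kv.unbind(0)
        x = F.scaled_dot_product_attention(q, k, v, scale=self.scale)
        return x.transpose(1, 2).reshape(B, self.latent_len, C)


pooled = LatentQueryPool(dim=384)(torch.randn(2, 196, 384))
print(pooled.shape)  # torch.Size([2, 1, 384])

With latent_len=1 this reduces a (B, N, C) token sequence to a single pooled token that attends over all N inputs.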
birder/model_registry/model_registry.py CHANGED

@@ -6,6 +6,7 @@ from typing import Any
 from typing import Literal
 from typing import Optional

+from birder.conf.settings import DEFAULT_NUM_CHANNELS
 from birder.model_registry import manifest

 if TYPE_CHECKING is True:

@@ -229,8 +230,8 @@ class ModelRegistry:
     def net_factory(
         self,
         name: str,
-        input_channels: int,
         num_classes: int,
+        input_channels: int = DEFAULT_NUM_CHANNELS,
         *,
         config: Optional[dict[str, Any]] = None,
         size: Optional[tuple[int, int]] = None,
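Note: this change both adds a default and reorders the leading parameters: input_channels now follows num_classes and falls back to DEFAULT_NUM_CHANNELS. A sketch of the call-site impact (the model name is one registered elsewhere in this release):

from birder.model_registry import registry

# 0.4.0 signature: net_factory(name, input_channels, num_classes, ...)
#   net = registry.net_factory("convnext_v1_iso_small", 3, 100)
# 0.4.1 signature: net_factory(name, num_classes, input_channels=DEFAULT_NUM_CHANNELS, ...)
net = registry.net_factory("convnext_v1_iso_small", num_classes=100)

The reorder is breaking for positional callers: a 0.4.0-style net_factory(name, 3, 100) now binds 3 to num_classes.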
birder/net/__init__.py CHANGED

@@ -6,6 +6,7 @@ from birder.net.coat import CoaT
 from birder.net.conv2former import Conv2Former
 from birder.net.convmixer import ConvMixer
 from birder.net.convnext_v1 import ConvNeXt_v1
+from birder.net.convnext_v1_iso import ConvNeXt_v1_Isotropic
 from birder.net.convnext_v2 import ConvNeXt_v2
 from birder.net.crossformer import CrossFormer
 from birder.net.crossvit import CrossViT

@@ -118,6 +119,7 @@ __all__ = [
     "Conv2Former",
     "ConvMixer",
     "ConvNeXt_v1",
+    "ConvNeXt_v1_Isotropic",
     "ConvNeXt_v2",
     "CrossFormer",
     "CrossViT",
birder/net/_rope_vit_configs.py CHANGED

@@ -88,6 +88,11 @@ def register_rope_vit_configs(rope_vit: type[BaseNet]) -> None:
         rope_vit,
         config={"patch_size": 16, **SMALL},
     )
+    registry.register_model_config(
+        "rope_vit_s16_avg",
+        rope_vit,
+        config={"patch_size": 16, **SMALL, "class_token": False},
+    )
     registry.register_model_config(
         "rope_i_vit_s16_pn_aps_c1",  # For PE Core - https://arxiv.org/abs/2504.13181
         rope_vit,
birder/net/_vit_configs.py CHANGED

@@ -215,19 +215,6 @@ def register_vit_configs(vit: type[BaseNet]) -> None:
             "drop_path_rate": 0.1,
         },
     )
-    registry.register_model_config(  # From "Scaling Vision Transformers to 22 Billion Parameters"
-        "vit_22b_p16_qkn",
-        vit,
-        config={
-            "patch_size": 16,
-            "num_layers": 48,
-            "num_heads": 48,
-            "hidden_dim": 6144,
-            "mlp_dim": 24576,
-            "qk_norm": True,
-            "drop_path_rate": 0.1,
-        },
-    )

     # With registers
     ####################
birder/net/alexnet.py CHANGED

@@ -27,17 +27,17 @@ class AlexNet(BaseNet):
         assert self.config is None, "config not supported"

         self.body = nn.Sequential(
-            nn.Conv2d(self.input_channels, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2), bias=True),
+            nn.Conv2d(self.input_channels, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2)),
             nn.ReLU(inplace=True),
             nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2), padding=(0, 0)),
-            nn.Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=True),
+            nn.Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2)),
             nn.ReLU(inplace=True),
             nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2), padding=(0, 0)),
-            nn.Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=True),
+            nn.Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
             nn.ReLU(inplace=True),
-            nn.Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=True),
+            nn.Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
             nn.ReLU(inplace=True),
-            nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=True),
+            nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
             nn.ReLU(inplace=True),
             nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2), padding=(0, 0)),
             nn.AdaptiveAvgPool2d(output_size=(6, 6)),
birder/net/base.py CHANGED

@@ -5,6 +5,7 @@ from typing import Literal
 from typing import NotRequired
 from typing import Optional
 from typing import TypedDict
+from typing import overload

 import torch
 import torch.nn.functional as F

@@ -54,6 +55,30 @@ def make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int:
     return new_v


+@overload
+def normalize_out_indices(out_indices: None, num_layers: int) -> None: ...
+
+
+@overload
+def normalize_out_indices(out_indices: list[int], num_layers: int) -> list[int]: ...
+
+
+def normalize_out_indices(out_indices: Optional[list[int]], num_layers: int) -> Optional[list[int]]:
+    if out_indices is None:
+        return None
+
+    normalized_indices = []
+    for idx in out_indices:
+        if idx < 0:
+            idx = num_layers + idx
+        if idx < 0 or idx >= num_layers:
+            raise ValueError(f"out_indices contains invalid index for num_layers={num_layers}")
+
+        normalized_indices.append(idx)
+
+    return normalized_indices
+
+
 # class MiscNet(nn.Module):
 #     """
 #     Base class for general-purpose neural networks with automatic model registration

@@ -137,8 +162,8 @@ class BaseNet(nn.Module):

         self.dynamic_size = False

-        self.classifier: nn.Module
         self.embedding_size: int
+        self.classifier: nn.Module

     def create_classifier(self, embed_dim: Optional[int] = None) -> nn.Module:
         if self.num_classes == 0:

@@ -274,7 +299,7 @@ def pos_embedding_sin_cos_2d(
 ) -> torch.Tensor:
     # assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sin-cos emb"

-    (y, x) = torch.meshgrid(torch.arange(h, device=device), torch.arange(w, device=device), indexing="ij")
+    y, x = torch.meshgrid(torch.arange(h, device=device), torch.arange(w, device=device), indexing="ij")
     omega = torch.arange(dim // 4, device=device) / (dim // 4 - 1)
     omega = 1.0 / (temperature**omega)

@@ -294,7 +319,7 @@ def interpolate_attention_bias(
     new_resolution: tuple[int, int],
     mode: Literal["bilinear", "bicubic"] = "bicubic",
 ) -> torch.Tensor:
-    (H, _) = attention_bias.size()
+    H, _ = attention_bias.size()

     # Interpolate
     orig_dtype = attention_bias.dtype
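Note: the new normalize_out_indices helper gives every backbone one place to resolve Python-style negative feature-tap indices against num_layers and to reject out-of-range values, with overloads that preserve the Optional for type checkers. Straight from the definition above:

from birder.net.base import normalize_out_indices

normalize_out_indices([0, -1, -2], num_layers=18)  # [0, 17, 16]
normalize_out_indices(None, num_layers=18)         # None is passed through
normalize_out_indices([18], num_layers=18)         # raises ValueError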
birder/net/biformer.py CHANGED

@@ -30,7 +30,7 @@ from birder.net.base import DetectorBackbone


 def _grid2seq(x: torch.Tensor, region_size: tuple[int, int], num_heads: int) -> tuple[torch.Tensor, int, int]:
-    (B, C, H, W) = x.size()
+    B, C, H, W = x.size()
     region_h = H // region_size[0]
     region_w = W // region_size[1]
     x = x.view(B, num_heads, C // num_heads, region_h, region_size[0], region_w, region_size[1])

@@ -40,7 +40,7 @@ def _grid2seq(x: torch.Tensor, region_size: tuple[int, int], num_heads: int) ->


 def _seq2grid(x: torch.Tensor, region_h: int, region_w: int, region_size: tuple[int, int]) -> torch.Tensor:
-    (bs, n_head, _, _, head_dim) = x.size()
+    bs, n_head, _, _, head_dim = x.size()
     x = x.view(bs, n_head, region_h, region_w, region_size[0], region_size[1], head_dim)
     x = torch.einsum("bmhwpqd->bmdhpwq", x).reshape(
         bs, n_head * head_dim, region_h * region_size[0], region_w * region_size[1]

@@ -60,7 +60,7 @@ def regional_routing_attention_torch(
     auto_pad: bool,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     kv_region_size = region_size
-    (bs, n_head, q_nregion, topk) = region_graph.size()
+    bs, n_head, q_nregion, topk = region_graph.size()

     # Pad to deal with any input size
     q_pad_b = 0

@@ -68,13 +68,13 @@ def regional_routing_attention_torch(
     kv_pad_b = 0
     kv_pad_r = 0
     if auto_pad is True:
-        (_, _, h_q, w_q) = query.size()
+        _, _, h_q, w_q = query.size()
         q_pad_b = (region_size[0] - h_q % region_size[0]) % region_size[0]
         q_pad_r = (region_size[1] - w_q % region_size[1]) % region_size[1]
         if q_pad_b > 0 or q_pad_r > 0:
             query = F.pad(query, (0, q_pad_r, 0, q_pad_b))

-        (_, _, h_k, w_k) = key.size()
+        _, _, h_k, w_k = key.size()
         kv_pad_b = (kv_region_size[0] - h_k % kv_region_size[0]) % kv_region_size[0]
         kv_pad_r = (kv_region_size[1] - w_k % kv_region_size[1]) % kv_region_size[1]
         if kv_pad_r > 0 or kv_pad_b > 0:

@@ -87,12 +87,12 @@ def regional_routing_attention_torch(
         w_k = None

     # To sequence format
-    (query, q_region_h, q_region_w) = _grid2seq(query, region_size=region_size, num_heads=n_head)
-    (key, _, _) = _grid2seq(key, region_size=kv_region_size, num_heads=n_head)
-    (value, _, _) = _grid2seq(value, region_size=kv_region_size, num_heads=n_head)
+    query, q_region_h, q_region_w = _grid2seq(query, region_size=region_size, num_heads=n_head)
+    key, _, _ = _grid2seq(key, region_size=kv_region_size, num_heads=n_head)
+    value, _, _ = _grid2seq(value, region_size=kv_region_size, num_heads=n_head)

     # Gather key and values
-    (bs, n_head, kv_nregion, kv_region_size, head_dim) = key.size()
+    bs, n_head, kv_nregion, kv_region_size, head_dim = key.size()
     broadcasted_region_graph = region_graph.view(bs, n_head, q_nregion, topk, 1, 1).expand(
         -1, -1, -1, -1, kv_region_size, head_dim
     )

@@ -146,12 +146,12 @@ class BiLevelRoutingAttention(nn.Module):
         self.output_linear = nn.Conv2d(dim, dim, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        (_, _, H, W) = x.size()
+        _, _, H, W = x.size()
         region_size = (H // self.n_win_h, W // self.n_win_w)

         # Linear projection
         qkv = self.qkv_linear(x)
-        (q, k, v) = qkv.chunk(3, dim=1)
+        q, k, v = qkv.chunk(3, dim=1)

         # Region-to-region routing
         q_r = F.avg_pool2d(  # pylint: disable=not-callable

@@ -163,11 +163,11 @@ class BiLevelRoutingAttention(nn.Module):
         q_r = q_r.permute(0, 2, 3, 1).flatten(1, 2)  # (n, (hw), c)
         k_r = k_r.flatten(2, 3)  # (n, c, (hw))
         a_r = q_r @ k_r
-        (_, idx_r) = torch.topk(a_r, k=self.topk, dim=-1)
+        _, idx_r = torch.topk(a_r, k=self.topk, dim=-1)
         idx_r = idx_r.unsqueeze_(1).expand(-1, self.num_heads, -1, -1)

         # Token to token attention
-        (output, _) = regional_routing_attention_torch(
+        output, _ = regional_routing_attention_torch(
             q, k, v, scale=self.scale, region_graph=idx_r, region_size=region_size, auto_pad=True
         )

@@ -190,12 +190,12 @@ class Attention(nn.Module):
         self.proj_drop = nn.Dropout(proj_drop)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        (B, C, H, W) = x.size()
+        B, C, H, W = x.size()
         x = x.permute(0, 2, 3, 1).reshape(B, H * W, C)

         N = H * W
         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
-        (q, k, v) = qkv.unbind(0)
+        q, k, v = qkv.unbind(0)

         x = F.scaled_dot_product_attention(  # pylint: disable=not-callable
             q, k, v, dropout_p=self.attn_drop.p if self.training else 0.0, scale=self.scale

@@ -237,8 +237,8 @@ class AttentionLePE(nn.Module):
         )

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        (B, C, H, W) = x.size()
-        (q, k, v) = self.qkv(x).chunk(3, dim=1)
+        B, C, H, W = x.size()
+        q, k, v = self.qkv(x).chunk(3, dim=1)

         attn = q.view(B, self.num_heads, self.head_dim, H * W).transpose(-1, -2) @ k.view(
             B, self.num_heads, self.head_dim, H * W
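Note: the biformer hunks follow the same unpacking cleanup and showcase the two idioms this release uses for splitting fused QKV projections: unbind on a stacked dimension and chunk along channels. A standalone illustration with made-up shapes:

import torch

B, N, C, num_heads = 2, 196, 384, 6

# Token-format QKV: reshape so dim 0 indexes q/k/v, then unbind (views, no copy)
qkv = torch.randn(B, N, 3 * C).reshape(B, N, 3, num_heads, C // num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)  # each (B, num_heads, N, head_dim)

# Map-format QKV from a 1x1 conv: split along the channel dimension
q2, k2, v2 = torch.randn(B, 3 * C, 14, 14).chunk(3, dim=1)  # each (B, C, 14, 14)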
birder/net/cait.py CHANGED

@@ -47,7 +47,7 @@ class ClassAttention(nn.Module):
         self.proj_drop = nn.Dropout(proj_drop)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        (B, N, C) = x.shape
+        B, N, C = x.shape
         q = self.q(x[:, 0]).unsqueeze(1).reshape(B, 1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
         k = self.k(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
         v = self.v(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)

@@ -103,7 +103,7 @@ class TalkingHeadAttn(nn.Module):
         self.proj_drop = nn.Dropout(proj_drop)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        (B, N, C) = x.shape
+        B, N, C = x.shape
         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
         q = qkv[0] * self.scale
         k = qkv[1]
birder/net/cas_vit.py CHANGED

@@ -122,7 +122,7 @@ class AdditiveTokenMixer(nn.Module):
         self.proj_drop = nn.Dropout(proj_drop)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        (q, k, v) = self.qkv(x).chunk(3, dim=1)
+        q, k, v = self.qkv(x).chunk(3, dim=1)
         q = self.op_q(q)
         k = self.op_k(k)
birder/net/coat.py CHANGED

@@ -57,8 +57,8 @@ class ConvRelPosEnc(nn.Module):
         self.channel_splits = [x * head_channels for x in head_splits]

     def forward(self, q: torch.Tensor, v: torch.Tensor, size: tuple[int, int]) -> torch.Tensor:
-        (B, num_heads, N, C) = q.size()
-        (H, W) = size
+        B, num_heads, N, C = q.size()
+        H, W = size
         torch._assert(N == 1 + H * W, "size mismatch")  # pylint: disable=protected-access

         # Convolutional relative position encoding.

@@ -102,11 +102,11 @@ class FactorAttnConvRelPosEnc(nn.Module):
         self.crpe = shared_crpe

     def forward(self, x: torch.Tensor, size: tuple[int, int]) -> torch.Tensor:
-        (B, N, C) = x.size()
+        B, N, C = x.size()

         # Generate Q, K, V
         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
-        (q, k, v) = qkv.unbind(0)  # [B, h, N, Ch]
+        q, k, v = qkv.unbind(0)  # [B, h, N, Ch]

         # Factorized attention
         k_softmax = k.softmax(dim=2)

@@ -135,8 +135,8 @@ class ConvPosEnc(nn.Module):
         )

     def forward(self, x: torch.Tensor, size: tuple[int, int]) -> torch.Tensor:
-        (B, N, C) = x.size()
-        (H, W) = size
+        B, N, C = x.size()
+        H, W = size
         torch._assert(N == 1 + H * W, "size mismatch")  # pylint: disable=protected-access

         # Extract CLS token and image tokens

@@ -244,8 +244,8 @@ class ParallelBlock(nn.Module):
         return self.interpolate(x, scale_factor=1.0 / factor, size=size)

     def interpolate(self, x: torch.Tensor, scale_factor: float, size: tuple[int, int]) -> torch.Tensor:
-        (B, N, C) = x.size()
-        (H, W) = size
+        B, N, C = x.size()
+        H, W = size
         torch._assert(N == 1 + H * W, "size mismatch")  # pylint: disable=protected-access

         cls_token = x[:, :1, :]

@@ -268,7 +268,7 @@ class ParallelBlock(nn.Module):
     def forward(
         self, x1: torch.Tensor, x2: torch.Tensor, x3: torch.Tensor, x4: torch.Tensor, sizes: list[tuple[int, int]]
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-        (_, s2, s3, s4) = sizes
+        _, s2, s3, s4 = sizes
         cur2 = self.norm12(x2)
         cur3 = self.norm13(x3)
         cur4 = self.norm14(x4)

@@ -310,7 +310,7 @@ class PatchEmbed(nn.Module):

     def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, tuple[int, int]]:
         x = self.proj(x)
-        (H, W) = x.shape[2:4]
+        H, W = x.shape[2:4]

         x = x.flatten(2).transpose(1, 2)
         x = self.norm(x)

@@ -500,7 +500,7 @@ class CoaT(DetectorBackbone):
         B = x.shape[0]

         # Serial blocks 1
-        (x1, (h1, w1)) = self.patch_embed1(x)
+        x1, (h1, w1) = self.patch_embed1(x)
         x1 = insert_cls(x1, self.cls_token1)
         for blk in self.serial_blocks1:
             x1 = blk(x1, size=(h1, w1))

@@ -508,7 +508,7 @@ class CoaT(DetectorBackbone):
         x1_no_cls = remove_cls(x1).reshape(B, h1, w1, -1).permute(0, 3, 1, 2).contiguous()

         # Serial blocks 2
-        (x2, (h2, w2)) = self.patch_embed2(x1_no_cls)
+        x2, (h2, w2) = self.patch_embed2(x1_no_cls)
         x2 = insert_cls(x2, self.cls_token2)
         for blk in self.serial_blocks2:
             x2 = blk(x2, size=(h2, w2))

@@ -516,7 +516,7 @@ class CoaT(DetectorBackbone):
         x2_no_cls = remove_cls(x2).reshape(B, h2, w2, -1).permute(0, 3, 1, 2).contiguous()

         # Serial blocks 3
-        (x3, (h3, w3)) = self.patch_embed3(x2_no_cls)
+        x3, (h3, w3) = self.patch_embed3(x2_no_cls)
         x3 = insert_cls(x3, self.cls_token3)
         for blk in self.serial_blocks3:
             x3 = blk(x3, size=(h3, w3))

@@ -524,7 +524,7 @@ class CoaT(DetectorBackbone):
         x3_no_cls = remove_cls(x3).reshape(B, h3, w3, -1).permute(0, 3, 1, 2).contiguous()

         # Serial blocks 4
-        (x4, (h4, w4)) = self.patch_embed4(x3_no_cls)
+        x4, (h4, w4) = self.patch_embed4(x3_no_cls)
         x4 = insert_cls(x4, self.cls_token4)
         for blk in self.serial_blocks4:
             x4 = blk(x4, size=(h4, w4))

@@ -537,7 +537,7 @@ class CoaT(DetectorBackbone):
             x2 = self.cpe2(x2, (h2, w2))
             x3 = self.cpe3(x3, (h3, w3))
             x4 = self.cpe4(x4, (h4, w4))
-            (x1, x2, x3, x4) = blk(x1, x2, x3, x4, sizes=[(h1, w1), (h2, w2), (h3, w3), (h4, w4)])
+            x1, x2, x3, x4 = blk(x1, x2, x3, x4, sizes=[(h1, w1), (h2, w2), (h3, w3), (h4, w4)])

         x1_no_cls = remove_cls(x1).reshape(B, h1, w1, -1).permute(0, 3, 1, 2).contiguous()
         x2_no_cls = remove_cls(x2).reshape(B, h2, w2, -1).permute(0, 3, 1, 2).contiguous()
birder/net/convnext_v1.py CHANGED

@@ -37,15 +37,7 @@ class ConvNeXtBlock(nn.Module):
     ) -> None:
         super().__init__()
         self.block = nn.Sequential(
-            nn.Conv2d(
-                channels,
-                channels,
-                kernel_size=(7, 7),
-                stride=(1, 1),
-                padding=(3, 3),
-                groups=channels,
-                bias=True,
-            ),
+            nn.Conv2d(channels, channels, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=channels),
             Permute([0, 2, 3, 1]),
             nn.LayerNorm(channels, eps=1e-6),
             nn.Linear(channels, 4 * channels),  # Same as 1x1 conv

@@ -119,7 +111,7 @@ class ConvNeXt_v1(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
         layers.append(
             nn.Sequential(
                 LayerNorm2d(i, eps=1e-6),
-                nn.Conv2d(i, out, kernel_size=(2, 2), stride=(2, 2), padding=(0, 0), bias=True),
+                nn.Conv2d(i, out, kernel_size=(2, 2), stride=(2, 2), padding=(0, 0)),
             )
         )

birder/net/convnext_v1_iso.py ADDED

@@ -0,0 +1,198 @@
+"""
+ConvNeXt v1 Isotropic, adapted from
+https://github.com/facebookresearch/ConvNeXt/blob/main/models/convnext_isotropic.py
+
+Paper "A ConvNet for the 2020s", https://arxiv.org/abs/2201.03545
+"""
+
+# Reference license: MIT
+
+from functools import partial
+from typing import Any
+from typing import Literal
+from typing import Optional
+
+import torch
+from torch import nn
+from torchvision.ops import Permute
+from torchvision.ops import StochasticDepth
+
+from birder.common.masking import mask_tensor
+from birder.layers import LayerNorm2d
+from birder.model_registry import registry
+from birder.net.base import DetectorBackbone
+from birder.net.base import MaskedTokenRetentionMixin
+from birder.net.base import PreTrainEncoder
+from birder.net.base import TokenRetentionResultType
+from birder.net.base import normalize_out_indices
+
+
+class ConvNeXtBlock(nn.Module):
+    def __init__(self, channels: int, stochastic_depth_prob: float) -> None:
+        super().__init__()
+        self.block = nn.Sequential(
+            nn.Conv2d(channels, channels, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=channels),
+            Permute([0, 2, 3, 1]),
+            nn.LayerNorm(channels, eps=1e-6),
+            nn.Linear(channels, 4 * channels),
+            nn.GELU(),
+            nn.Linear(4 * channels, channels),
+            Permute([0, 3, 1, 2]),
+        )
+        self.stochastic_depth = StochasticDepth(stochastic_depth_prob, mode="row")
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        identity = x
+        x = self.block(x)
+        x = self.stochastic_depth(x)
+        x += identity
+
+        return x
+
+
+# pylint: disable=invalid-name
+class ConvNeXt_v1_Isotropic(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
+    block_group_regex = r"body\.(\d+)"
+
+    def __init__(
+        self,
+        input_channels: int,
+        num_classes: int,
+        *,
+        config: Optional[dict[str, Any]] = None,
+        size: Optional[tuple[int, int]] = None,
+    ) -> None:
+        super().__init__(input_channels, num_classes, config=config, size=size)
+        assert self.config is not None, "must set config"
+
+        patch_size = 16
+        dim: int = self.config["dim"]
+        num_layers: int = self.config["num_layers"]
+        out_indices: Optional[list[int]] = self.config.get("out_indices", None)
+        drop_path_rate: float = self.config["drop_path_rate"]
+
+        torch._assert(self.size[0] % patch_size == 0, "Input shape indivisible by patch size!")
+        torch._assert(self.size[1] % patch_size == 0, "Input shape indivisible by patch size!")
+        self.patch_size = patch_size
+        self.out_indices = normalize_out_indices(out_indices, num_layers)
+
+        self.stem = nn.Conv2d(
+            self.input_channels,
+            dim,
+            kernel_size=(patch_size, patch_size),
+            stride=(patch_size, patch_size),
+            padding=(0, 0),
+        )
+
+        layers = []
+        for idx in range(num_layers):
+            # Adjust stochastic depth probability based on the depth of the stage block
+            sd_prob = drop_path_rate * idx / (num_layers - 1.0)
+            layers.append(ConvNeXtBlock(dim, sd_prob))
+
+        self.body = nn.Sequential(*layers)
+        self.features = nn.Sequential(
+            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
+            LayerNorm2d(dim, eps=1e-6),
+            nn.Flatten(1),
+        )
+
+        num_return_stages = len(self.out_indices) if self.out_indices is not None else 1
+        self.return_stages = [f"stage{stage_idx + 1}" for stage_idx in range(num_return_stages)]
+        self.return_channels = [dim] * num_return_stages
+        self.embedding_size = dim
+        self.classifier = self.create_classifier()
+
+        self.max_stride = patch_size
+        self.stem_stride = patch_size
+        self.stem_width = dim
+        self.encoding_size = dim
+        self.decoder_block = partial(ConvNeXtBlock, stochastic_depth_prob=0)
+
+        # Weights initialization
+        for m in self.modules():
+            if isinstance(m, (nn.Conv2d, nn.Linear)):
+                nn.init.trunc_normal_(m.weight, std=0.02)
+                if m.bias is not None:
+                    nn.init.zeros_(m.bias)
+
+    def detection_features(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
+        x = self.stem(x)
+
+        if self.out_indices is None:
+            x = self.body(x)
+            return {self.return_stages[0]: x}
+
+        stage_num = 0
+        out: dict[str, torch.Tensor] = {}
+        for idx, module in enumerate(self.body.children()):
+            x = module(x)
+            if idx in self.out_indices:
+                out[self.return_stages[stage_num]] = x
+                stage_num += 1
+
+        return out
+
+    def freeze_stages(self, up_to_stage: int) -> None:
+        for param in self.stem.parameters():
+            param.requires_grad_(False)
+
+        for idx, module in enumerate(self.body.children()):
+            if idx >= up_to_stage:
+                break
+
+            for param in module.parameters():
+                param.requires_grad_(False)
+
+    def masked_encoding_retention(
+        self,
+        x: torch.Tensor,
+        mask: torch.Tensor,
+        mask_token: Optional[torch.Tensor] = None,
+        return_keys: Literal["all", "features", "embedding"] = "features",
+    ) -> TokenRetentionResultType:
+        x = self.stem(x)
+        x = mask_tensor(x, mask, patch_factor=self.max_stride // self.stem_stride, mask_token=mask_token)
+        x = self.body(x)
+
+        result: TokenRetentionResultType = {}
+        if return_keys in ("all", "features"):
+            result["features"] = x
+        if return_keys in ("all", "embedding"):
+            result["embedding"] = self.features(x)
+
+        return result
+
+    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.stem(x)
+        return self.body(x)
+
+    def embedding(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.forward_features(x)
+        return self.features(x)
+
+    def adjust_size(self, new_size: tuple[int, int]) -> None:
+        if new_size == self.size:
+            return
+
+        assert new_size[0] % self.patch_size == 0, "Input shape indivisible by patch size!"
+        assert new_size[1] % self.patch_size == 0, "Input shape indivisible by patch size!"
+
+        super().adjust_size(new_size)
+
+
+registry.register_model_config(
+    "convnext_v1_iso_small",
+    ConvNeXt_v1_Isotropic,
+    config={"dim": 384, "num_layers": 18, "drop_path_rate": 0.1},
+)
+registry.register_model_config(
+    "convnext_v1_iso_base",
+    ConvNeXt_v1_Isotropic,
+    config={"dim": 768, "num_layers": 18, "drop_path_rate": 0.2},
+)
+registry.register_model_config(
+    "convnext_v1_iso_large",
+    ConvNeXt_v1_Isotropic,
+    config={"dim": 1024, "num_layers": 36, "drop_path_rate": 0.5},
+)