birder-0.3.1-py3-none-any.whl → birder-0.3.2-py3-none-any.whl
This diff reflects the changes between two publicly released package versions as they appear in their public registry. It is provided for informational purposes only.
- birder/adversarial/deepfool.py +2 -0
- birder/adversarial/simba.py +2 -0
- birder/common/masking.py +13 -4
- birder/inference/classification.py +1 -1
- birder/introspection/__init__.py +2 -0
- birder/introspection/base.py +0 -7
- birder/introspection/feature_pca.py +101 -0
- birder/kernels/soft_nms/soft_nms.cpp +5 -2
- birder/model_registry/model_registry.py +3 -2
- birder/net/convnext_v1.py +20 -0
- birder/net/fastvit.py +0 -1
- birder/net/flexivit.py +5 -0
- birder/net/focalnet.py +0 -1
- birder/net/rope_flexivit.py +7 -0
- birder/net/rope_vit.py +49 -4
- birder/net/smt.py +0 -1
- birder/net/ssl/ibot.py +0 -1
- birder/net/vit.py +166 -2
- birder/scripts/train.py +7 -6
- birder/scripts/train_barlow_twins.py +4 -3
- birder/scripts/train_byol.py +4 -3
- birder/scripts/train_capi.py +6 -5
- birder/scripts/train_data2vec.py +4 -3
- birder/scripts/train_data2vec2.py +4 -3
- birder/scripts/train_detection.py +7 -5
- birder/scripts/train_dino_v1.py +5 -4
- birder/scripts/train_dino_v2.py +69 -20
- birder/scripts/train_dino_v2_dist.py +70 -21
- birder/scripts/train_franca.py +8 -7
- birder/scripts/train_i_jepa.py +4 -3
- birder/scripts/train_ibot.py +5 -4
- birder/scripts/train_kd.py +8 -8
- birder/scripts/train_mim.py +4 -3
- birder/scripts/train_mmcr.py +4 -3
- birder/scripts/train_rotnet.py +5 -4
- birder/scripts/train_simclr.py +4 -3
- birder/scripts/train_vicreg.py +4 -3
- birder/tools/avg_model.py +24 -8
- birder/tools/introspection.py +35 -9
- birder/tools/show_iterator.py +1 -1
- birder/version.py +1 -1
- {birder-0.3.1.dist-info → birder-0.3.2.dist-info}/METADATA +1 -1
- {birder-0.3.1.dist-info → birder-0.3.2.dist-info}/RECORD +47 -46
- {birder-0.3.1.dist-info → birder-0.3.2.dist-info}/WHEEL +0 -0
- {birder-0.3.1.dist-info → birder-0.3.2.dist-info}/entry_points.txt +0 -0
- {birder-0.3.1.dist-info → birder-0.3.2.dist-info}/licenses/LICENSE +0 -0
- {birder-0.3.1.dist-info → birder-0.3.2.dist-info}/top_level.txt +0 -0
birder/adversarial/deepfool.py
CHANGED
birder/adversarial/simba.py
CHANGED
birder/common/masking.py
CHANGED
@@ -84,8 +84,8 @@ def mask_tensor(
 
     (B, H, W, _) = x.size()
 
-    shaped_mask = mask.reshape(
-    shaped_mask = shaped_mask.repeat_interleave(patch_factor,
+    shaped_mask = mask.reshape(B, H // patch_factor, W // patch_factor)
+    shaped_mask = shaped_mask.repeat_interleave(patch_factor, dim=1).repeat_interleave(patch_factor, dim=2)
     shaped_mask = shaped_mask.unsqueeze(3).type_as(x)
 
     if mask_token is not None:

@@ -228,14 +228,23 @@ class Masking:
 
 
 class UniformMasking(Masking):
-    def __init__(
+    def __init__(
+        self,
+        input_size: tuple[int, int],
+        mask_ratio: float,
+        min_mask_size: int = 1,
+        device: Optional[torch.device] = None,
+    ) -> None:
         self.h = input_size[0]
         self.w = input_size[1]
         self.mask_ratio = mask_ratio
+        self.min_mask_size = min_mask_size
         self.device = device
 
     def __call__(self, batch_size: int) -> torch.Tensor:
-        return uniform_mask(
+        return uniform_mask(
+            batch_size, self.h, self.w, self.mask_ratio, min_mask_size=self.min_mask_size, device=self.device
+        )[0]
 
 
 class BlockMasking(Masking):
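Note: the fixed mask_tensor lines above reshape a flat per-patch mask and apply repeat_interleave twice to expand each patch flag to pixel resolution. A minimal standalone sketch of that shape arithmetic (plain PyTorch, not birder API; the toy sizes are assumptions):

import torch

B, H, W, patch_factor = 2, 8, 8, 4  # 8x8 feature map covered by a 2x2 grid of 4x4 patches
mask = torch.randint(0, 2, (B, (H // patch_factor) * (W // patch_factor)))  # one flag per patch

shaped_mask = mask.reshape(B, H // patch_factor, W // patch_factor)
shaped_mask = shaped_mask.repeat_interleave(patch_factor, dim=1).repeat_interleave(patch_factor, dim=2)
print(shaped_mask.shape)  # torch.Size([2, 8, 8]) -- each patch flag now covers a 4x4 pixel block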
birder/inference/classification.py
CHANGED

@@ -85,7 +85,7 @@ def infer_batch(
             logits = net(t(tta_input), **kwargs)
             outs.append(logits if return_logits is True else F.softmax(logits, dim=1))
 
-        out = torch.stack(outs).mean(
+        out = torch.stack(outs).mean(dim=0)
 
     else:
         logits = net(inputs, **kwargs)
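Note: the one-line fix above makes the TTA reduction explicit. outs holds one (batch, classes) tensor per augmented view, so the stack is (num_views, batch, classes) and averaging over dim=0 collapses only the view axis. A standalone sketch (the shapes below are assumptions):

import torch

outs = [torch.randn(4, 10) for _ in range(3)]  # 3 TTA views, batch of 4, 10 classes
out = torch.stack(outs).mean(dim=0)            # (3, 4, 10) -> (4, 10)
print(out.shape)  # torch.Size([4, 10])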
birder/introspection/__init__.py
CHANGED
@@ -1,5 +1,6 @@
 from birder.introspection.attention_rollout import AttentionRollout
 from birder.introspection.base import InterpretabilityResult
+from birder.introspection.feature_pca import FeaturePCA
 from birder.introspection.gradcam import GradCAM
 from birder.introspection.guided_backprop import GuidedBackprop
 from birder.introspection.transformer_attribution import TransformerAttribution

@@ -7,6 +8,7 @@ from birder.introspection.transformer_attribution import TransformerAttribution
 __all__ = [
     "InterpretabilityResult",
     "AttentionRollout",
+    "FeaturePCA",
     "GradCAM",
     "GuidedBackprop",
     "TransformerAttribution",
birder/introspection/base.py
CHANGED
@@ -2,7 +2,6 @@ from collections.abc import Callable
 from dataclasses import dataclass
 from pathlib import Path
 from typing import Optional
-from typing import Protocol
 
 import matplotlib
 import matplotlib.pyplot as plt

@@ -27,12 +26,6 @@ class InterpretabilityResult:
         plt.show()
 
 
-class Interpreter(Protocol):
-    def __call__(
-        self, image: str | Path | Image.Image, target_class: Optional[int] = None
-    ) -> InterpretabilityResult: ...
-
-
 def load_image(image: str | Path | Image.Image) -> Image.Image:
     if isinstance(image, (str, Path)):
         return Image.open(image)
birder/introspection/feature_pca.py
ADDED

@@ -0,0 +1,101 @@
+from collections.abc import Callable
+from pathlib import Path
+from typing import Optional
+
+import numpy as np
+import torch
+from PIL import Image
+from sklearn.decomposition import PCA
+
+from birder.introspection.base import InterpretabilityResult
+from birder.introspection.base import preprocess_image
+from birder.net.base import DetectorBackbone
+
+
+class FeaturePCA:
+    """
+    Visualizes feature maps using Principal Component Analysis
+
+    This method extracts feature maps from a specified stage of a DetectorBackbone model,
+    applies PCA to reduce the channel dimension to 3 components, and visualizes them as an RGB image where:
+    - R channel = 1st principal component (most important)
+    - G channel = 2nd principal component
+    - B channel = 3rd principal component
+    """
+
+    def __init__(
+        self,
+        net: DetectorBackbone,
+        device: torch.device,
+        transform: Callable[..., torch.Tensor],
+        normalize: bool = False,
+        channels_last: bool = False,
+        stage: Optional[str] = None,
+    ) -> None:
+        self.net = net.eval()
+        self.device = device
+        self.transform = transform
+        self.normalize = normalize
+        self.channels_last = channels_last
+        self.stage = stage
+
+    def __call__(self, image: str | Path | Image.Image) -> InterpretabilityResult:
+        (input_tensor, rgb_img) = preprocess_image(image, self.transform, self.device)
+
+        with torch.inference_mode():
+            features_dict = self.net.detection_features(input_tensor)
+
+        if self.stage is not None:
+            features = features_dict[self.stage]
+        else:
+            features = list(features_dict.values())[-1]  # Use the last stage by default
+
+        features_np = features.cpu().numpy()
+
+        # Handle channels_last format (B, H, W, C) vs channels_first (B, C, H, W)
+        if self.channels_last is True:
+            (B, H, W, C) = features_np.shape
+            # Already in (B, H, W, C), just reshape to (B*H*W, C)
+            features_reshaped = features_np.reshape(-1, C)
+        else:
+            (B, C, H, W) = features_np.shape
+            # Reshape to (spatial_points, channels) for PCA
+            features_reshaped = features_np.reshape(B, C, -1)
+            features_reshaped = features_reshaped.transpose(0, 2, 1)  # (B, H*W, C)
+            features_reshaped = features_reshaped.reshape(-1, C)  # (B*H*W, C)
+
+        x = features_reshaped
+        if self.normalize is True:
+            x = x / (np.linalg.norm(x, axis=1, keepdims=True) + 1e-6)
+
+        pca = PCA(n_components=3)
+        pca_features = pca.fit_transform(x)
+        pca_features = pca_features.reshape(B, H, W, 3)
+
+        # Extract all 3 components (B=1)
+        pca_rgb = pca_features[0]  # (H, W, 3)
+
+        # Normalize each channel independently to [0, 1]
+        for i in range(3):
+            channel = pca_rgb[:, :, i]
+            channel = channel - channel.min()
+            channel = channel / (channel.max() + 1e-7)
+            pca_rgb[:, :, i] = channel
+
+        target_size = (input_tensor.size(-1), input_tensor.size(-2))  # PIL expects (width, height)
+        pca_rgb_resized = (
+            np.array(
+                Image.fromarray((pca_rgb * 255).astype(np.uint8)).resize(target_size, Image.Resampling.BILINEAR)
+            ).astype(np.float32)
+            / 255.0
+        )
+
+        visualization = (pca_rgb_resized * 255).astype(np.uint8)
+
+        return InterpretabilityResult(
+            original_image=rgb_img,
+            visualization=visualization,
+            raw_output=pca_rgb.astype(np.float32),
+            logits=None,
+            predicted_class=None,
+        )
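Note: a hedged usage sketch for the new FeaturePCA introspection class. How the backbone and transform are obtained is outside this diff, so the first two assignments are placeholders; the call signature and the visualization field follow the code above:

import torch
from PIL import Image
from birder.introspection import FeaturePCA

net = ...        # any birder DetectorBackbone instance (obtained elsewhere; placeholder)
transform = ...  # preprocessing callable returning a (1, C, H, W) tensor (placeholder)

feature_pca = FeaturePCA(net, device=torch.device("cpu"), transform=transform, normalize=True)
result = feature_pca("example.jpeg")  # hypothetical image path
Image.fromarray(result.visualization).save("feature_pca.png")  # (H, W, 3) uint8 RGB of the top 3 components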
birder/kernels/soft_nms/soft_nms.cpp
CHANGED

@@ -4,6 +4,9 @@
  * Taken from:
  * https://github.com/MrParosk/soft_nms
  * Licensed under the MIT License
+ *
+ * Modified by:
+ * Ofer Hasson — 2026-01-10
  **************************************************************************************************
  */
 

@@ -40,8 +43,8 @@ torch::Tensor calculate_iou(const torch::Tensor& boxes, const torch::Tensor& are
     auto xx2 = torch::minimum(boxes.index({idx, 2}), boxes.index({Slice(idx + 1, None), 2}));
     auto yy2 = torch::minimum(boxes.index({idx, 3}), boxes.index({Slice(idx + 1, None), 3}));
 
-    auto w =
-    auto h =
+    auto w = (xx2 - xx1).clamp_min(0);
+    auto h = (yy2 - yy1).clamp_min(0);
 
     auto intersection = w * h;
     auto union_ = areas.index({idx}) + areas.index({Slice(idx + 1, None)}) - intersection;
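Note: the C++ fix above clamps the intersection width and height at zero, so non-overlapping boxes contribute zero intersection area. A rough PyTorch equivalent of the calculate_iou logic (a sketch only, not the extension code; boxes are assumed to be (N, 4) in xyxy format with precomputed areas):

import torch

def iou_against_rest(boxes: torch.Tensor, areas: torch.Tensor, idx: int) -> torch.Tensor:
    # Intersection of box `idx` with every later box
    xx1 = torch.maximum(boxes[idx, 0], boxes[idx + 1:, 0])
    yy1 = torch.maximum(boxes[idx, 1], boxes[idx + 1:, 1])
    xx2 = torch.minimum(boxes[idx, 2], boxes[idx + 1:, 2])
    yy2 = torch.minimum(boxes[idx, 3], boxes[idx + 1:, 3])

    w = (xx2 - xx1).clamp_min(0)  # matches the fixed lines above
    h = (yy2 - yy1).clamp_min(0)

    intersection = w * h
    union = areas[idx] + areas[idx + 1:] - intersection
    return intersection / union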
birder/model_registry/model_registry.py
CHANGED

@@ -87,14 +87,15 @@ class ModelRegistry:
         no further registration is needed.
         """
 
+        alias_key = alias.lower()
         if net_type.auto_register is False:
             # Register the model manually, as the base class doesn't take care of that for us
-
+            self.register_model(alias_key, type(alias, (net_type,), {"config": config}))
 
         if alias in self.aliases:
             warnings.warn(f"Alias {alias} is already registered", UserWarning)
 
-        self.aliases[
+        self.aliases[alias_key] = type(alias, (net_type,), {"config": config})
 
     def register_weights(self, name: str, weights_info: manifest.ModelMetadataType) -> None:
         if name in self._pretrained_nets:
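Note: the registry change above stores registered aliases under a lower-cased key. A minimal illustration of that keying idea (a plain dict stand-in, not the ModelRegistry class; the names below are made up):

aliases: dict[str, type] = {}

class DummyNet:  # stand-in for a registered network type
    pass

def register_alias(alias: str, net_cls: type) -> None:
    aliases[alias.lower()] = net_cls  # keys are stored lower-cased

register_alias("ConvNeXt_v1_Tiny_Example", DummyNet)
assert "convnext_v1_tiny_example" in aliases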
birder/net/convnext_v1.py
CHANGED
@@ -195,6 +195,21 @@ class ConvNeXt_v1(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
         return self.features(x)
 
 
+registry.register_model_config(
+    "convnext_v1_atto",  # Not in the original v1, taken from v2
+    ConvNeXt_v1,
+    config={"in_channels": [40, 80, 160, 320], "num_layers": [2, 2, 6, 2], "drop_path_rate": 0.0},
+)
+registry.register_model_config(
+    "convnext_v1_femto",  # Not in the original v1, taken from v2
+    ConvNeXt_v1,
+    config={"in_channels": [48, 96, 192, 384], "num_layers": [2, 2, 6, 2], "drop_path_rate": 0.0},
+)
+registry.register_model_config(
+    "convnext_v1_pico",  # Not in the original v1, taken from v2
+    ConvNeXt_v1,
+    config={"in_channels": [64, 128, 256, 512], "num_layers": [2, 2, 6, 2], "drop_path_rate": 0.0},
+)
 registry.register_model_config(
     "convnext_v1_nano",  # Not in the original v1, taken from v2
     ConvNeXt_v1,

@@ -220,6 +235,11 @@ registry.register_model_config(
     ConvNeXt_v1,
     config={"in_channels": [192, 384, 768, 1536], "num_layers": [3, 3, 27, 3], "drop_path_rate": 0.5},
 )
+registry.register_model_config(
+    "convnext_v1_huge",  # Not in the original v1, taken from v2
+    ConvNeXt_v1,
+    config={"in_channels": [352, 704, 1408, 2816], "num_layers": [3, 3, 27, 3], "drop_path_rate": 0.5},
+)
 
 registry.register_weights(
     "convnext_v1_tiny_eu-common256px",
birder/net/fastvit.py
CHANGED
birder/net/flexivit.py
CHANGED
@@ -98,6 +98,8 @@ class FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Mask
         layer_scale_init_value: Optional[float] = self.config.get("layer_scale_init_value", None)
         pre_norm: bool = self.config.get("pre_norm", False)
         post_norm: bool = self.config.get("post_norm", True)
+        qkv_bias: bool = self.config.get("qkv_bias", True)
+        qk_norm: bool = self.config.get("qk_norm", False)
         num_reg_tokens: int = self.config.get("num_reg_tokens", 0)
         class_token: bool = self.config.get("class_token", True)
         attn_pool_head: bool = self.config.get("attn_pool_head", False)

@@ -186,6 +188,8 @@ class FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Mask
             attention_dropout,
             dpr,
             pre_norm=pre_norm,
+            qkv_bias=qkv_bias,
+            qk_norm=qk_norm,
             activation_layer=act_layer,
             layer_scale_init_value=layer_scale_init_value,
             norm_layer=norm_layer,

@@ -224,6 +228,7 @@ class FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Mask
                 drop_path=0,
                 activation_layer=act_layer,
                 norm_layer=norm_layer,
+                norm_layer_eps=norm_layer_eps,
                 mlp_layer=mlp_layer,
             )
 
birder/net/focalnet.py
CHANGED
birder/net/rope_flexivit.py
CHANGED
@@ -69,6 +69,8 @@ class RoPE_FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin,
         layer_scale_init_value: Optional[float] = self.config.get("layer_scale_init_value", None)
         pre_norm: bool = self.config.get("pre_norm", False)
         post_norm: bool = self.config.get("post_norm", True)
+        qkv_bias: bool = self.config.get("qkv_bias", True)
+        qk_norm: bool = self.config.get("qk_norm", False)
         num_reg_tokens: int = self.config.get("num_reg_tokens", 0)
         class_token: bool = self.config.get("class_token", True)
         attn_pool_head: bool = self.config.get("attn_pool_head", False)

@@ -118,6 +120,7 @@ class RoPE_FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin,
         self.num_reg_tokens = num_reg_tokens
         self.attn_pool_special_tokens = attn_pool_special_tokens
         self.norm_layer = norm_layer
+        self.norm_layer_eps = norm_layer_eps
         self.mlp_layer = mlp_layer
         self.act_layer = act_layer
         self.rope_rot_type = rope_rot_type

@@ -190,6 +193,8 @@ class RoPE_FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin,
             attention_dropout,
             dpr,
             pre_norm=pre_norm,
+            qkv_bias=qkv_bias,
+            qk_norm=qk_norm,
             activation_layer=act_layer,
             layer_scale_init_value=layer_scale_init_value,
             norm_layer=norm_layer,

@@ -231,6 +236,7 @@ class RoPE_FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin,
             rope_temperature=rope_temperature,
             layer_scale_init_value=layer_scale_init_value,
             norm_layer=norm_layer,
+            norm_layer_eps=norm_layer_eps,
             mlp_layer=mlp_layer,
             rope_rot_type=rope_rot_type,
         )

@@ -588,6 +594,7 @@ class RoPE_FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin,
             rope_temperature=self.rope_temperature,
             layer_scale_init_value=self.layer_scale_init_value,
             norm_layer=self.norm_layer,
+            norm_layer_eps=self.norm_layer_eps,
             mlp_layer=self.mlp_layer,
             rope_rot_type=self.rope_rot_type,
         )
birder/net/rope_vit.py
CHANGED
@@ -150,6 +150,10 @@ class RoPEAttention(nn.Module):
         attn_drop: float,
         proj_drop: float,
         num_special_tokens: int,
+        qkv_bias: bool = True,
+        qk_norm: bool = False,
+        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
+        norm_layer_eps: float = 1e-6,
         rope_rot_type: str = "standard",
     ) -> None:
         super().__init__()

@@ -167,7 +171,14 @@
         else:
             raise ValueError(f"Unknown rope_rot_type, got '{rope_rot_type}'")
 
-        self.qkv = nn.Linear(dim, dim * 3)
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        if qk_norm is True:
+            self.q_norm = norm_layer(self.head_dim, eps=norm_layer_eps)
+            self.k_norm = norm_layer(self.head_dim, eps=norm_layer_eps)
+        else:
+            self.q_norm = nn.Identity()
+            self.k_norm = nn.Identity()
+
         self.attn_drop = nn.Dropout(attn_drop)
         self.proj = nn.Linear(dim, dim)
         self.proj_drop = nn.Dropout(proj_drop)

@@ -176,6 +187,8 @@
         (B, N, C) = x.size()
         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
         (q, k, v) = qkv.unbind(0)
+        q = self.q_norm(q)
+        k = self.k_norm(k)
 
         n = self.num_special_tokens
         q = torch.concat([q[:, :, :n, :], self.apply_rot_fn(q[:, :, n:, :], rope)], dim=2)
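Note: the three hunks above add optional QKV bias and QK-Norm to RoPEAttention, normalizing q and k per attention head right after the QKV projection. A self-contained sketch of that pattern (plain LayerNorm, no RoPE rotation or dropout; an illustration, not the birder module):

import torch
import torch.nn.functional as F
from torch import nn

class QKNormAttention(nn.Module):
    def __init__(self, dim: int, num_heads: int, qkv_bias: bool = True, qk_norm: bool = True) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = nn.LayerNorm(self.head_dim, eps=1e-6) if qk_norm else nn.Identity()
        self.k_norm = nn.LayerNorm(self.head_dim, eps=1e-6) if qk_norm else nn.Identity()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        (B, N, C) = x.size()
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        (q, k, v) = qkv.unbind(0)
        q = self.q_norm(q)  # normalize queries per head
        k = self.k_norm(k)  # normalize keys per head
        x = F.scaled_dot_product_attention(q, k, v)
        return self.proj(x.transpose(1, 2).reshape(B, N, C))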
@@ -207,6 +220,8 @@ class EncoderBlock(nn.Module):
         norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
         norm_layer_eps: float = 1e-6,
         mlp_layer: Callable[..., nn.Module] = FFN,
+        qkv_bias: bool = True,
+        qk_norm: bool = False,
         rope_rot_type: str = "standard",
     ) -> None:
         super().__init__()

@@ -222,6 +237,10 @@
             attn_drop=attention_dropout,
             proj_drop=dropout,
             num_special_tokens=num_special_tokens,
+            qkv_bias=qkv_bias,
+            qk_norm=qk_norm,
+            norm_layer=norm_layer,
+            norm_layer_eps=norm_layer_eps,
             rope_rot_type=rope_rot_type,
         )
         if layer_scale_init_value is not None:

@@ -249,7 +268,6 @@
 
 
 class Encoder(nn.Module):
-    # pylint: disable=too-many-arguments,too-many-positional-arguments
     def __init__(
         self,
         num_layers: int,
@@ -261,6 +279,8 @@ class Encoder(nn.Module):
         attention_dropout: float,
         dpr: list[float],
         pre_norm: bool = False,
+        qkv_bias: bool = True,
+        qk_norm: bool = False,
         activation_layer: Callable[..., nn.Module] = nn.GELU,
         layer_scale_init_value: Optional[float] = None,
         norm_layer: Callable[..., nn.Module] = nn.LayerNorm,

@@ -293,6 +313,8 @@
                 norm_layer=norm_layer,
                 norm_layer_eps=norm_layer_eps,
                 mlp_layer=mlp_layer,
+                qkv_bias=qkv_bias,
+                qk_norm=qk_norm,
                 rope_rot_type=rope_rot_type,
             )
         )
@@ -331,6 +353,7 @@ class MAEDecoderBlock(nn.Module):
         rope_temperature: float,
         layer_scale_init_value: Optional[float] = None,
         norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
+        norm_layer_eps: float = 1e-6,
         mlp_layer: Callable[..., nn.Module] = FFN,
         rope_rot_type: str = "standard",
     ) -> None:

@@ -346,7 +369,7 @@
         )
 
         # Attention block
-        self.norm1 = norm_layer(hidden_dim, eps=
+        self.norm1 = norm_layer(hidden_dim, eps=norm_layer_eps)
         self.attn = RoPEAttention(
             hidden_dim,
             num_heads,

@@ -361,7 +384,7 @@
             self.layer_scale_1 = nn.Identity()
 
         # MLP block
-        self.norm2 = norm_layer(hidden_dim, eps=
+        self.norm2 = norm_layer(hidden_dim, eps=norm_layer_eps)
         self.mlp = mlp_layer(hidden_dim, mlp_dim, act_layer=activation_layer, dropout=0.0)
         if layer_scale_init_value is not None:
             self.layer_scale_2 = LayerScale(hidden_dim, layer_scale_init_value)
@@ -403,6 +426,8 @@ class RoPE_ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Mask
         layer_scale_init_value: Optional[float] = self.config.get("layer_scale_init_value", None)
         pre_norm: bool = self.config.get("pre_norm", False)
         post_norm: bool = self.config.get("post_norm", True)
+        qkv_bias: bool = self.config.get("qkv_bias", True)
+        qk_norm: bool = self.config.get("qk_norm", False)
         num_reg_tokens: int = self.config.get("num_reg_tokens", 0)
         class_token: bool = self.config.get("class_token", True)
         attn_pool_head: bool = self.config.get("attn_pool_head", False)

@@ -450,6 +475,7 @@ class RoPE_ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Mask
         self.num_reg_tokens = num_reg_tokens
         self.attn_pool_special_tokens = attn_pool_special_tokens
         self.norm_layer = norm_layer
+        self.norm_layer_eps = norm_layer_eps
         self.mlp_layer = mlp_layer
         self.act_layer = act_layer
         self.rope_rot_type = rope_rot_type

@@ -521,6 +547,8 @@ class RoPE_ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Mask
             attention_dropout,
             dpr,
             pre_norm=pre_norm,
+            qkv_bias=qkv_bias,
+            qk_norm=qk_norm,
             activation_layer=act_layer,
             layer_scale_init_value=layer_scale_init_value,
             norm_layer=norm_layer,

@@ -562,6 +590,7 @@ class RoPE_ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Mask
             rope_temperature=rope_temperature,
             layer_scale_init_value=layer_scale_init_value,
             norm_layer=norm_layer,
+            norm_layer_eps=norm_layer_eps,
             mlp_layer=mlp_layer,
             rope_rot_type=rope_rot_type,
         )
@@ -904,6 +933,7 @@ class RoPE_ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Mask
             rope_temperature=self.rope_temperature,
             layer_scale_init_value=self.layer_scale_init_value,
             norm_layer=self.norm_layer,
+            norm_layer_eps=self.norm_layer_eps,
             mlp_layer=self.mlp_layer,
             rope_rot_type=self.rope_rot_type,
         )

@@ -931,6 +961,7 @@ class RoPE_ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Mask
 # - rms : RMSNorm (instead of LayerNorm)
 # - pn : Pre-Norm (layer norm before the encoder) - implies different norm eps
 # - npn : No Post Norm (disables post-normalization layer)
+# - qkn : QK Norm
 #
 # Feed-Forward Network:
 # - swiglu : SwiGLU FFN layer type (instead of standard FFN)
@@ -1068,6 +1099,20 @@ registry.register_model_config(
         "drop_path_rate": 0.1,
     },
 )
+registry.register_model_config(
+    "rope_vit_b16_qkn_ls",
+    RoPE_ViT,
+    config={
+        "patch_size": 16,
+        "num_layers": 12,
+        "num_heads": 12,
+        "hidden_dim": 768,
+        "mlp_dim": 3072,
+        "layer_scale_init_value": 1e-5,
+        "qk_norm": True,
+        "drop_path_rate": 0.1,
+    },
+)
 registry.register_model_config(
     "rope_i_vit_b16_pn_aps_c1",  # For PE Core - https://arxiv.org/abs/2504.13181
     RoPE_ViT,
birder/net/smt.py
CHANGED