birder 0.3.1__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. birder/adversarial/deepfool.py +2 -0
  2. birder/adversarial/simba.py +2 -0
  3. birder/common/masking.py +13 -4
  4. birder/inference/classification.py +1 -1
  5. birder/introspection/__init__.py +2 -0
  6. birder/introspection/base.py +0 -7
  7. birder/introspection/feature_pca.py +101 -0
  8. birder/kernels/soft_nms/soft_nms.cpp +5 -2
  9. birder/model_registry/model_registry.py +3 -2
  10. birder/net/convnext_v1.py +20 -0
  11. birder/net/fastvit.py +0 -1
  12. birder/net/flexivit.py +5 -0
  13. birder/net/focalnet.py +0 -1
  14. birder/net/rope_flexivit.py +7 -0
  15. birder/net/rope_vit.py +49 -4
  16. birder/net/smt.py +0 -1
  17. birder/net/ssl/ibot.py +0 -1
  18. birder/net/vit.py +166 -2
  19. birder/scripts/train.py +7 -6
  20. birder/scripts/train_barlow_twins.py +4 -3
  21. birder/scripts/train_byol.py +4 -3
  22. birder/scripts/train_capi.py +6 -5
  23. birder/scripts/train_data2vec.py +4 -3
  24. birder/scripts/train_data2vec2.py +4 -3
  25. birder/scripts/train_detection.py +7 -5
  26. birder/scripts/train_dino_v1.py +5 -4
  27. birder/scripts/train_dino_v2.py +69 -20
  28. birder/scripts/train_dino_v2_dist.py +70 -21
  29. birder/scripts/train_franca.py +8 -7
  30. birder/scripts/train_i_jepa.py +4 -3
  31. birder/scripts/train_ibot.py +5 -4
  32. birder/scripts/train_kd.py +8 -8
  33. birder/scripts/train_mim.py +4 -3
  34. birder/scripts/train_mmcr.py +4 -3
  35. birder/scripts/train_rotnet.py +5 -4
  36. birder/scripts/train_simclr.py +4 -3
  37. birder/scripts/train_vicreg.py +4 -3
  38. birder/tools/avg_model.py +24 -8
  39. birder/tools/introspection.py +35 -9
  40. birder/tools/show_iterator.py +1 -1
  41. birder/version.py +1 -1
  42. {birder-0.3.1.dist-info → birder-0.3.2.dist-info}/METADATA +1 -1
  43. {birder-0.3.1.dist-info → birder-0.3.2.dist-info}/RECORD +47 -46
  44. {birder-0.3.1.dist-info → birder-0.3.2.dist-info}/WHEEL +0 -0
  45. {birder-0.3.1.dist-info → birder-0.3.2.dist-info}/entry_points.txt +0 -0
  46. {birder-0.3.1.dist-info → birder-0.3.2.dist-info}/licenses/LICENSE +0 -0
  47. {birder-0.3.1.dist-info → birder-0.3.2.dist-info}/top_level.txt +0 -0
birder/adversarial/deepfool.py CHANGED
@@ -2,6 +2,8 @@
 DeepFool
 
 Paper "DeepFool: a simple and accurate method to fool deep neural networks", https://arxiv.org/abs/1511.04599
+
+Generated by gpt-5.2-codex xhigh.
 """
 
 from typing import Optional
birder/adversarial/simba.py CHANGED
@@ -2,6 +2,8 @@
 SimBA (Simple Black-box Attack)
 
 Paper "Simple Black-box Adversarial Attacks", https://arxiv.org/abs/1905.07121
+
+Generated by gpt-5.2-codex xhigh.
 """
 
 from typing import Optional
birder/common/masking.py CHANGED
@@ -84,8 +84,8 @@ def mask_tensor(
 
     (B, H, W, _) = x.size()
 
-    shaped_mask = mask.reshape(-1, H // patch_factor, W // patch_factor)
-    shaped_mask = shaped_mask.repeat_interleave(patch_factor, axis=1).repeat_interleave(patch_factor, axis=2)
+    shaped_mask = mask.reshape(B, H // patch_factor, W // patch_factor)
+    shaped_mask = shaped_mask.repeat_interleave(patch_factor, dim=1).repeat_interleave(patch_factor, dim=2)
     shaped_mask = shaped_mask.unsqueeze(3).type_as(x)
 
     if mask_token is not None:
@@ -228,14 +228,23 @@ class Masking:
 
 
 class UniformMasking(Masking):
-    def __init__(self, input_size: tuple[int, int], mask_ratio: float, device: Optional[torch.device] = None) -> None:
+    def __init__(
+        self,
+        input_size: tuple[int, int],
+        mask_ratio: float,
+        min_mask_size: int = 1,
+        device: Optional[torch.device] = None,
+    ) -> None:
         self.h = input_size[0]
         self.w = input_size[1]
         self.mask_ratio = mask_ratio
+        self.min_mask_size = min_mask_size
         self.device = device
 
     def __call__(self, batch_size: int) -> torch.Tensor:
-        return uniform_mask(batch_size, self.h, self.w, self.mask_ratio, device=self.device)[0]
+        return uniform_mask(
+            batch_size, self.h, self.w, self.mask_ratio, min_mask_size=self.min_mask_size, device=self.device
+        )[0]
 
 
 class BlockMasking(Masking):
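
The masking.py fix above makes the batch dimension explicit in the reshape and uses torch's documented `dim=` keyword in place of the NumPy-style `axis=`. A standalone sketch of the same patch-to-pixel mask upsampling, with hypothetical sizes (not the birder API):

# Standalone sketch of the patch-to-pixel mask upsampling fixed above
import torch

B, patch_factor = 2, 4
H, W = 8, 8  # pixel-grid height/width, multiples of patch_factor
mask = torch.randint(0, 2, (B, (H // patch_factor) * (W // patch_factor)))

shaped_mask = mask.reshape(B, H // patch_factor, W // patch_factor)  # explicit batch dim, not -1
shaped_mask = shaped_mask.repeat_interleave(patch_factor, dim=1).repeat_interleave(patch_factor, dim=2)
assert shaped_mask.shape == (B, H, W)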
birder/inference/classification.py CHANGED
@@ -85,7 +85,7 @@ def infer_batch(
             logits = net(t(tta_input), **kwargs)
             outs.append(logits if return_logits is True else F.softmax(logits, dim=1))
 
-        out = torch.stack(outs).mean(axis=0)
+        out = torch.stack(outs).mean(dim=0)
 
     else:
         logits = net(inputs, **kwargs)
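
The same `axis=` to `dim=` cleanup is applied to the test-time-augmentation path of `infer_batch`, which stacks per-transform outputs and averages over the transform dimension. A minimal sketch of that pattern with hypothetical shapes:

# Sketch of the TTA averaging pattern: stack per-transform probabilities,
# then average over the transform dimension
import torch
import torch.nn.functional as F

outs = [F.softmax(torch.randn(4, 10), dim=1) for _ in range(3)]  # 3 TTA views, batch of 4
out = torch.stack(outs).mean(dim=0)  # (3, 4, 10) -> (4, 10)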
birder/introspection/__init__.py CHANGED
@@ -1,5 +1,6 @@
 from birder.introspection.attention_rollout import AttentionRollout
 from birder.introspection.base import InterpretabilityResult
+from birder.introspection.feature_pca import FeaturePCA
 from birder.introspection.gradcam import GradCAM
 from birder.introspection.guided_backprop import GuidedBackprop
 from birder.introspection.transformer_attribution import TransformerAttribution
@@ -7,6 +8,7 @@ from birder.introspection.transformer_attribution import TransformerAttribution
 __all__ = [
     "InterpretabilityResult",
     "AttentionRollout",
+    "FeaturePCA",
     "GradCAM",
     "GuidedBackprop",
     "TransformerAttribution",
birder/introspection/base.py CHANGED
@@ -2,7 +2,6 @@ from collections.abc import Callable
 from dataclasses import dataclass
 from pathlib import Path
 from typing import Optional
-from typing import Protocol
 
 import matplotlib
 import matplotlib.pyplot as plt
@@ -27,12 +26,6 @@ class InterpretabilityResult:
         plt.show()
 
 
-class Interpreter(Protocol):
-    def __call__(
-        self, image: str | Path | Image.Image, target_class: Optional[int] = None
-    ) -> InterpretabilityResult: ...
-
-
 def load_image(image: str | Path | Image.Image) -> Image.Image:
     if isinstance(image, (str, Path)):
         return Image.open(image)
birder/introspection/feature_pca.py ADDED
@@ -0,0 +1,101 @@
+from collections.abc import Callable
+from pathlib import Path
+from typing import Optional
+
+import numpy as np
+import torch
+from PIL import Image
+from sklearn.decomposition import PCA
+
+from birder.introspection.base import InterpretabilityResult
+from birder.introspection.base import preprocess_image
+from birder.net.base import DetectorBackbone
+
+
+class FeaturePCA:
+    """
+    Visualizes feature maps using Principal Component Analysis
+
+    This method extracts feature maps from a specified stage of a DetectorBackbone model,
+    applies PCA to reduce the channel dimension to 3 components, and visualizes them as an RGB image where:
+    - R channel = 1st principal component (most important)
+    - G channel = 2nd principal component
+    - B channel = 3rd principal component
+    """
+
+    def __init__(
+        self,
+        net: DetectorBackbone,
+        device: torch.device,
+        transform: Callable[..., torch.Tensor],
+        normalize: bool = False,
+        channels_last: bool = False,
+        stage: Optional[str] = None,
+    ) -> None:
+        self.net = net.eval()
+        self.device = device
+        self.transform = transform
+        self.normalize = normalize
+        self.channels_last = channels_last
+        self.stage = stage
+
+    def __call__(self, image: str | Path | Image.Image) -> InterpretabilityResult:
+        (input_tensor, rgb_img) = preprocess_image(image, self.transform, self.device)
+
+        with torch.inference_mode():
+            features_dict = self.net.detection_features(input_tensor)
+
+        if self.stage is not None:
+            features = features_dict[self.stage]
+        else:
+            features = list(features_dict.values())[-1]  # Use the last stage by default
+
+        features_np = features.cpu().numpy()
+
+        # Handle channels_last format (B, H, W, C) vs channels_first (B, C, H, W)
+        if self.channels_last is True:
+            (B, H, W, C) = features_np.shape
+            # Already in (B, H, W, C), just reshape to (B*H*W, C)
+            features_reshaped = features_np.reshape(-1, C)
+        else:
+            (B, C, H, W) = features_np.shape
+            # Reshape to (spatial_points, channels) for PCA
+            features_reshaped = features_np.reshape(B, C, -1)
+            features_reshaped = features_reshaped.transpose(0, 2, 1)  # (B, H*W, C)
+            features_reshaped = features_reshaped.reshape(-1, C)  # (B*H*W, C)
+
+        x = features_reshaped
+        if self.normalize is True:
+            x = x / (np.linalg.norm(x, axis=1, keepdims=True) + 1e-6)
+
+        pca = PCA(n_components=3)
+        pca_features = pca.fit_transform(x)
+        pca_features = pca_features.reshape(B, H, W, 3)
+
+        # Extract all 3 components (B=1)
+        pca_rgb = pca_features[0]  # (H, W, 3)
+
+        # Normalize each channel independently to [0, 1]
+        for i in range(3):
+            channel = pca_rgb[:, :, i]
+            channel = channel - channel.min()
+            channel = channel / (channel.max() + 1e-7)
+            pca_rgb[:, :, i] = channel
+
+        target_size = (input_tensor.size(-1), input_tensor.size(-2))  # PIL expects (width, height)
+        pca_rgb_resized = (
+            np.array(
+                Image.fromarray((pca_rgb * 255).astype(np.uint8)).resize(target_size, Image.Resampling.BILINEAR)
+            ).astype(np.float32)
+            / 255.0
+        )
+
+        visualization = (pca_rgb_resized * 255).astype(np.uint8)
+
+        return InterpretabilityResult(
+            original_image=rgb_img,
+            visualization=visualization,
+            raw_output=pca_rgb.astype(np.float32),
+            logits=None,
+            predicted_class=None,
+        )
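
At its core, the new FeaturePCA class flattens a (B, C, H, W) feature map into (B*H*W, C) points, fits a 3-component PCA over the channel dimension, and min-max normalizes each component into an RGB channel. A self-contained sketch of that mapping on a random tensor (a stand-in for real backbone features):

# Self-contained sketch of the PCA-to-RGB mapping performed by FeaturePCA
import numpy as np
import torch
from sklearn.decomposition import PCA

features = torch.randn(1, 256, 14, 14)  # (B, C, H, W), e.g. one backbone stage
(B, C, H, W) = features.shape
flat = features.numpy().reshape(B, C, -1).transpose(0, 2, 1).reshape(-1, C)  # (B*H*W, C)

pca_rgb = PCA(n_components=3).fit_transform(flat).reshape(B, H, W, 3)[0]
for i in range(3):  # min-max normalize each component into [0, 1]
    channel = pca_rgb[:, :, i]
    pca_rgb[:, :, i] = (channel - channel.min()) / (channel.max() - channel.min() + 1e-7)
visualization = (pca_rgb * 255).astype(np.uint8)  # components become the R, G, B channels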
birder/kernels/soft_nms/soft_nms.cpp CHANGED
@@ -4,6 +4,9 @@
  * Taken from:
  * https://github.com/MrParosk/soft_nms
  * Licensed under the MIT License
+ *
+ * Modified by:
+ * Ofer Hasson — 2026-01-10
  **************************************************************************************************
  */
 
@@ -40,8 +43,8 @@ torch::Tensor calculate_iou(const torch::Tensor& boxes, const torch::Tensor& are
     auto xx2 = torch::minimum(boxes.index({idx, 2}), boxes.index({Slice(idx + 1, None), 2}));
     auto yy2 = torch::minimum(boxes.index({idx, 3}), boxes.index({Slice(idx + 1, None), 3}));
 
-    auto w = torch::maximum(torch::zeros_like(xx1), xx2 - xx1);
-    auto h = torch::maximum(torch::zeros_like(yy1), yy2 - yy1);
+    auto w = (xx2 - xx1).clamp_min(0);
+    auto h = (yy2 - yy1).clamp_min(0);
 
     auto intersection = w * h;
     auto union_ = areas.index({idx}) + areas.index({Slice(idx + 1, None)}) - intersection;
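
The soft-NMS change replaces `torch::maximum` against a freshly allocated zeros tensor with `clamp_min(0)`, which yields the same clamped intersection widths and heights without the extra allocation. A Python-side sketch of the equivalence:

# Python mirror of the C++ change: clamp_min(0) matches torch.maximum
# against a zeros tensor, without allocating one
import torch

xx1 = torch.tensor([1.0, 5.0])
xx2 = torch.tensor([4.0, 3.0])
w_old = torch.maximum(torch.zeros_like(xx1), xx2 - xx1)  # previous form
w_new = (xx2 - xx1).clamp_min(0)                         # new form
assert torch.equal(w_old, w_new)  # both are tensor([3., 0.])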
birder/model_registry/model_registry.py CHANGED
@@ -87,14 +87,15 @@ class ModelRegistry:
         no further registration is needed.
         """
 
+        alias_key = alias.lower()
         if net_type.auto_register is False:
             # Register the model manually, as the base class doesn't take care of that for us
-            registry.register_model(alias, type(alias, (net_type,), {"config": config}))
+            self.register_model(alias_key, type(alias, (net_type,), {"config": config}))
 
         if alias in self.aliases:
             warnings.warn(f"Alias {alias} is already registered", UserWarning)
 
-        self.aliases[alias] = type(alias, (net_type,), {"config": config})
+        self.aliases[alias_key] = type(alias, (net_type,), {"config": config})
 
     def register_weights(self, name: str, weights_info: manifest.ModelMetadataType) -> None:
        if name in self._pretrained_nets:
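
The registry fix lowercases the alias once and uses that key both for manual registration (now via `self.register_model` rather than the module-level `registry` singleton) and for storage in `self.aliases`. A simplified sketch of lowercase-keyed registration (a hypothetical minimal registry; the `auto_register` branch and duplicate warning are omitted):

# Sketch of lowercase-keyed alias registration
aliases: dict[str, type] = {}

def register_model_config(alias: str, net_type: type) -> None:
    alias_key = alias.lower()  # normalize once, use the same key for storage and lookup
    aliases[alias_key] = net_type

register_model_config("ConvNeXt_v1_Tiny", object)
assert "convnext_v1_tiny" in aliases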
birder/net/convnext_v1.py CHANGED
@@ -195,6 +195,21 @@ class ConvNeXt_v1(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
         return self.features(x)
 
 
+registry.register_model_config(
+    "convnext_v1_atto",  # Not in the original v1, taken from v2
+    ConvNeXt_v1,
+    config={"in_channels": [40, 80, 160, 320], "num_layers": [2, 2, 6, 2], "drop_path_rate": 0.0},
+)
+registry.register_model_config(
+    "convnext_v1_femto",  # Not in the original v1, taken from v2
+    ConvNeXt_v1,
+    config={"in_channels": [48, 96, 192, 384], "num_layers": [2, 2, 6, 2], "drop_path_rate": 0.0},
+)
+registry.register_model_config(
+    "convnext_v1_pico",  # Not in the original v1, taken from v2
+    ConvNeXt_v1,
+    config={"in_channels": [64, 128, 256, 512], "num_layers": [2, 2, 6, 2], "drop_path_rate": 0.0},
+)
 registry.register_model_config(
     "convnext_v1_nano",  # Not in the original v1, taken from v2
     ConvNeXt_v1,
@@ -220,6 +235,11 @@ registry.register_model_config(
     ConvNeXt_v1,
     config={"in_channels": [192, 384, 768, 1536], "num_layers": [3, 3, 27, 3], "drop_path_rate": 0.5},
 )
+registry.register_model_config(
+    "convnext_v1_huge",  # Not in the original v1, taken from v2
+    ConvNeXt_v1,
+    config={"in_channels": [352, 704, 1408, 2816], "num_layers": [3, 3, 27, 3], "drop_path_rate": 0.5},
+)
 
 registry.register_weights(
     "convnext_v1_tiny_eu-common256px",
birder/net/fastvit.py CHANGED
@@ -607,7 +607,6 @@ class AttentionBlock(nn.Module):
 
 
 class FastVitStage(nn.Module):
-    # pylint: disable=too-many-arguments,too-many-positional-arguments
     def __init__(
         self,
         dim: int,
birder/net/flexivit.py CHANGED
@@ -98,6 +98,8 @@ class FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Mask
         layer_scale_init_value: Optional[float] = self.config.get("layer_scale_init_value", None)
         pre_norm: bool = self.config.get("pre_norm", False)
         post_norm: bool = self.config.get("post_norm", True)
+        qkv_bias: bool = self.config.get("qkv_bias", True)
+        qk_norm: bool = self.config.get("qk_norm", False)
         num_reg_tokens: int = self.config.get("num_reg_tokens", 0)
         class_token: bool = self.config.get("class_token", True)
         attn_pool_head: bool = self.config.get("attn_pool_head", False)
@@ -186,6 +188,8 @@ class FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Mask
             attention_dropout,
             dpr,
             pre_norm=pre_norm,
+            qkv_bias=qkv_bias,
+            qk_norm=qk_norm,
             activation_layer=act_layer,
             layer_scale_init_value=layer_scale_init_value,
             norm_layer=norm_layer,
@@ -224,6 +228,7 @@ class FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Mask
             drop_path=0,
             activation_layer=act_layer,
             norm_layer=norm_layer,
+            norm_layer_eps=norm_layer_eps,
             mlp_layer=mlp_layer,
         )
 
birder/net/focalnet.py CHANGED
@@ -245,7 +245,6 @@ class FocalNetBlock(nn.Module):
 
 
 class FocalNetStage(nn.Module):
-    # pylint: disable=too-many-arguments,too-many-positional-arguments
     def __init__(
         self,
         dim: int,
birder/net/rope_flexivit.py CHANGED
@@ -69,6 +69,8 @@ class RoPE_FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin,
         layer_scale_init_value: Optional[float] = self.config.get("layer_scale_init_value", None)
         pre_norm: bool = self.config.get("pre_norm", False)
         post_norm: bool = self.config.get("post_norm", True)
+        qkv_bias: bool = self.config.get("qkv_bias", True)
+        qk_norm: bool = self.config.get("qk_norm", False)
         num_reg_tokens: int = self.config.get("num_reg_tokens", 0)
         class_token: bool = self.config.get("class_token", True)
         attn_pool_head: bool = self.config.get("attn_pool_head", False)
@@ -118,6 +120,7 @@ class RoPE_FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin,
         self.num_reg_tokens = num_reg_tokens
         self.attn_pool_special_tokens = attn_pool_special_tokens
         self.norm_layer = norm_layer
+        self.norm_layer_eps = norm_layer_eps
         self.mlp_layer = mlp_layer
         self.act_layer = act_layer
         self.rope_rot_type = rope_rot_type
@@ -190,6 +193,8 @@ class RoPE_FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin,
             attention_dropout,
             dpr,
             pre_norm=pre_norm,
+            qkv_bias=qkv_bias,
+            qk_norm=qk_norm,
             activation_layer=act_layer,
             layer_scale_init_value=layer_scale_init_value,
             norm_layer=norm_layer,
@@ -231,6 +236,7 @@ class RoPE_FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin,
             rope_temperature=rope_temperature,
             layer_scale_init_value=layer_scale_init_value,
             norm_layer=norm_layer,
+            norm_layer_eps=norm_layer_eps,
             mlp_layer=mlp_layer,
             rope_rot_type=rope_rot_type,
         )
@@ -588,6 +594,7 @@ class RoPE_FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin,
             rope_temperature=self.rope_temperature,
             layer_scale_init_value=self.layer_scale_init_value,
             norm_layer=self.norm_layer,
+            norm_layer_eps=self.norm_layer_eps,
             mlp_layer=self.mlp_layer,
             rope_rot_type=self.rope_rot_type,
         )
birder/net/rope_vit.py CHANGED
@@ -150,6 +150,10 @@ class RoPEAttention(nn.Module):
         attn_drop: float,
         proj_drop: float,
         num_special_tokens: int,
+        qkv_bias: bool = True,
+        qk_norm: bool = False,
+        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
+        norm_layer_eps: float = 1e-6,
         rope_rot_type: str = "standard",
     ) -> None:
         super().__init__()
@@ -167,7 +171,14 @@
         else:
             raise ValueError(f"Unknown rope_rot_type, got '{rope_rot_type}'")
 
-        self.qkv = nn.Linear(dim, dim * 3)
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        if qk_norm is True:
+            self.q_norm = norm_layer(self.head_dim, eps=norm_layer_eps)
+            self.k_norm = norm_layer(self.head_dim, eps=norm_layer_eps)
+        else:
+            self.q_norm = nn.Identity()
+            self.k_norm = nn.Identity()
+
         self.attn_drop = nn.Dropout(attn_drop)
         self.proj = nn.Linear(dim, dim)
         self.proj_drop = nn.Dropout(proj_drop)
@@ -176,6 +187,8 @@
         (B, N, C) = x.size()
         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
         (q, k, v) = qkv.unbind(0)
+        q = self.q_norm(q)
+        k = self.k_norm(k)
 
         n = self.num_special_tokens
         q = torch.concat([q[:, :, :n, :], self.apply_rot_fn(q[:, :, n:, :], rope)], dim=2)
@@ -207,6 +220,8 @@
         norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
         norm_layer_eps: float = 1e-6,
         mlp_layer: Callable[..., nn.Module] = FFN,
+        qkv_bias: bool = True,
+        qk_norm: bool = False,
         rope_rot_type: str = "standard",
     ) -> None:
         super().__init__()
@@ -222,6 +237,10 @@
             attn_drop=attention_dropout,
             proj_drop=dropout,
             num_special_tokens=num_special_tokens,
+            qkv_bias=qkv_bias,
+            qk_norm=qk_norm,
+            norm_layer=norm_layer,
+            norm_layer_eps=norm_layer_eps,
             rope_rot_type=rope_rot_type,
         )
         if layer_scale_init_value is not None:
@@ -249,7 +268,6 @@
 
 
 class Encoder(nn.Module):
-    # pylint: disable=too-many-arguments,too-many-positional-arguments
     def __init__(
         self,
         num_layers: int,
@@ -261,6 +279,8 @@
         attention_dropout: float,
         dpr: list[float],
         pre_norm: bool = False,
+        qkv_bias: bool = True,
+        qk_norm: bool = False,
         activation_layer: Callable[..., nn.Module] = nn.GELU,
         layer_scale_init_value: Optional[float] = None,
         norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
@@ -293,6 +313,8 @@
                 norm_layer=norm_layer,
                 norm_layer_eps=norm_layer_eps,
                 mlp_layer=mlp_layer,
+                qkv_bias=qkv_bias,
+                qk_norm=qk_norm,
                 rope_rot_type=rope_rot_type,
             )
         )
@@ -331,6 +353,7 @@
         rope_temperature: float,
         layer_scale_init_value: Optional[float] = None,
         norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
+        norm_layer_eps: float = 1e-6,
         mlp_layer: Callable[..., nn.Module] = FFN,
         rope_rot_type: str = "standard",
     ) -> None:
@@ -346,7 +369,7 @@
         )
 
         # Attention block
-        self.norm1 = norm_layer(hidden_dim, eps=1e-6)
+        self.norm1 = norm_layer(hidden_dim, eps=norm_layer_eps)
         self.attn = RoPEAttention(
             hidden_dim,
             num_heads,
@@ -361,7 +384,7 @@
             self.layer_scale_1 = nn.Identity()
 
         # MLP block
-        self.norm2 = norm_layer(hidden_dim, eps=1e-6)
+        self.norm2 = norm_layer(hidden_dim, eps=norm_layer_eps)
         self.mlp = mlp_layer(hidden_dim, mlp_dim, act_layer=activation_layer, dropout=0.0)
         if layer_scale_init_value is not None:
             self.layer_scale_2 = LayerScale(hidden_dim, layer_scale_init_value)
@@ -403,6 +426,8 @@ class RoPE_ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Mask
         layer_scale_init_value: Optional[float] = self.config.get("layer_scale_init_value", None)
         pre_norm: bool = self.config.get("pre_norm", False)
         post_norm: bool = self.config.get("post_norm", True)
+        qkv_bias: bool = self.config.get("qkv_bias", True)
+        qk_norm: bool = self.config.get("qk_norm", False)
         num_reg_tokens: int = self.config.get("num_reg_tokens", 0)
         class_token: bool = self.config.get("class_token", True)
         attn_pool_head: bool = self.config.get("attn_pool_head", False)
@@ -450,6 +475,7 @@ class RoPE_ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Mask
         self.num_reg_tokens = num_reg_tokens
         self.attn_pool_special_tokens = attn_pool_special_tokens
         self.norm_layer = norm_layer
+        self.norm_layer_eps = norm_layer_eps
         self.mlp_layer = mlp_layer
         self.act_layer = act_layer
         self.rope_rot_type = rope_rot_type
@@ -521,6 +547,8 @@ class RoPE_ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Mask
             attention_dropout,
             dpr,
             pre_norm=pre_norm,
+            qkv_bias=qkv_bias,
+            qk_norm=qk_norm,
             activation_layer=act_layer,
             layer_scale_init_value=layer_scale_init_value,
             norm_layer=norm_layer,
@@ -562,6 +590,7 @@ class RoPE_ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Mask
             rope_temperature=rope_temperature,
             layer_scale_init_value=layer_scale_init_value,
             norm_layer=norm_layer,
+            norm_layer_eps=norm_layer_eps,
             mlp_layer=mlp_layer,
             rope_rot_type=rope_rot_type,
         )
@@ -904,6 +933,7 @@ class RoPE_ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Mask
             rope_temperature=self.rope_temperature,
             layer_scale_init_value=self.layer_scale_init_value,
             norm_layer=self.norm_layer,
+            norm_layer_eps=self.norm_layer_eps,
             mlp_layer=self.mlp_layer,
             rope_rot_type=self.rope_rot_type,
         )
@@ -931,6 +961,7 @@ class RoPE_ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Mask
 # - rms : RMSNorm (instead of LayerNorm)
 # - pn : Pre-Norm (layer norm before the encoder) - implies different norm eps
 # - npn : No Post Norm (disables post-normalization layer)
+# - qkn : QK Norm
 #
 # Feed-Forward Network:
 # - swiglu : SwiGLU FFN layer type (instead of standard FFN)
@@ -1068,6 +1099,20 @@ registry.register_model_config(
         "drop_path_rate": 0.1,
     },
 )
+registry.register_model_config(
+    "rope_vit_b16_qkn_ls",
+    RoPE_ViT,
+    config={
+        "patch_size": 16,
+        "num_layers": 12,
+        "num_heads": 12,
+        "hidden_dim": 768,
+        "mlp_dim": 3072,
+        "layer_scale_init_value": 1e-5,
+        "qk_norm": True,
+        "drop_path_rate": 0.1,
+    },
+)
 registry.register_model_config(
     "rope_i_vit_b16_pn_aps_c1",  # For PE Core - https://arxiv.org/abs/2504.13181
     RoPE_ViT,
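
The rope_vit changes thread two new options through the attention stack: `qkv_bias` toggles the bias on the fused QKV projection, and `qk_norm` normalizes queries and keys per head before the dot product, a common stabilizer for large-scale ViT training (registered here as the `qkn` naming suffix and the `rope_vit_b16_qkn_ls` config). A standalone sketch of the pattern, simplified to omit RoPE application and special-token handling (`QKNormAttention` is a hypothetical name, not the birder class):

# Minimal sketch of the QK-Norm attention pattern added to RoPEAttention
import torch
import torch.nn as nn
import torch.nn.functional as F

class QKNormAttention(nn.Module):
    def __init__(self, dim: int, num_heads: int, qkv_bias: bool = True, qk_norm: bool = False) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads  # assumes dim is divisible by num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        # Per-head LayerNorm over head_dim, or a no-op when disabled
        self.q_norm = nn.LayerNorm(self.head_dim, eps=1e-6) if qk_norm else nn.Identity()
        self.k_norm = nn.LayerNorm(self.head_dim, eps=1e-6) if qk_norm else nn.Identity()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        (B, N, C) = x.size()
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        (q, k, v) = qkv.unbind(0)
        q = self.q_norm(q)  # normalize queries and keys before the dot product
        k = self.k_norm(k)
        x = F.scaled_dot_product_attention(q, k, v)
        return self.proj(x.transpose(1, 2).reshape(B, N, C))

attn = QKNormAttention(dim=768, num_heads=12, qk_norm=True)
out = attn(torch.randn(2, 197, 768))  # (B, N, C) -> (B, N, C)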
birder/net/smt.py CHANGED
@@ -275,7 +275,6 @@ class Stem(nn.Module):
 
 
 class SMTStage(nn.Module):
-    # pylint: disable=too-many-arguments,too-many-positional-arguments
     def __init__(
         self,
         dim: int,
birder/net/ssl/ibot.py CHANGED
@@ -25,7 +25,6 @@ from birder.net.ssl.dino_v1 import DINOHead
 
 # pylint: disable=invalid-name
 class iBOTLoss(nn.Module):
-    # pylint: disable=too-many-arguments,too-many-positional-arguments
     def __init__(
         self,
         out_dim: int,