birder 0.3.1__py3-none-any.whl → 0.3.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. birder/adversarial/deepfool.py +2 -0
  2. birder/adversarial/simba.py +2 -0
  3. birder/common/masking.py +13 -4
  4. birder/inference/classification.py +1 -1
  5. birder/introspection/__init__.py +2 -0
  6. birder/introspection/base.py +0 -7
  7. birder/introspection/feature_pca.py +101 -0
  8. birder/kernels/soft_nms/soft_nms.cpp +5 -2
  9. birder/model_registry/model_registry.py +3 -2
  10. birder/net/convnext_v1.py +20 -0
  11. birder/net/fastvit.py +0 -1
  12. birder/net/flexivit.py +5 -0
  13. birder/net/focalnet.py +0 -1
  14. birder/net/hiera.py +3 -3
  15. birder/net/hieradet.py +116 -28
  16. birder/net/rope_flexivit.py +7 -0
  17. birder/net/rope_vit.py +49 -4
  18. birder/net/smt.py +0 -1
  19. birder/net/ssl/ibot.py +0 -1
  20. birder/net/vit.py +166 -2
  21. birder/scripts/train.py +24 -21
  22. birder/scripts/train_barlow_twins.py +4 -3
  23. birder/scripts/train_byol.py +4 -3
  24. birder/scripts/train_capi.py +6 -5
  25. birder/scripts/train_data2vec.py +4 -3
  26. birder/scripts/train_data2vec2.py +4 -3
  27. birder/scripts/train_detection.py +7 -5
  28. birder/scripts/train_dino_v1.py +5 -4
  29. birder/scripts/train_dino_v2.py +69 -20
  30. birder/scripts/train_dino_v2_dist.py +70 -21
  31. birder/scripts/train_franca.py +8 -7
  32. birder/scripts/train_i_jepa.py +4 -3
  33. birder/scripts/train_ibot.py +5 -4
  34. birder/scripts/train_kd.py +25 -24
  35. birder/scripts/train_mim.py +4 -3
  36. birder/scripts/train_mmcr.py +4 -3
  37. birder/scripts/train_rotnet.py +5 -4
  38. birder/scripts/train_simclr.py +4 -3
  39. birder/scripts/train_vicreg.py +4 -3
  40. birder/tools/avg_model.py +24 -8
  41. birder/tools/introspection.py +35 -9
  42. birder/tools/show_iterator.py +17 -3
  43. birder/version.py +1 -1
  44. {birder-0.3.1.dist-info → birder-0.3.3.dist-info}/METADATA +1 -1
  45. {birder-0.3.1.dist-info → birder-0.3.3.dist-info}/RECORD +49 -48
  46. {birder-0.3.1.dist-info → birder-0.3.3.dist-info}/WHEEL +0 -0
  47. {birder-0.3.1.dist-info → birder-0.3.3.dist-info}/entry_points.txt +0 -0
  48. {birder-0.3.1.dist-info → birder-0.3.3.dist-info}/licenses/LICENSE +0 -0
  49. {birder-0.3.1.dist-info → birder-0.3.3.dist-info}/top_level.txt +0 -0
birder/adversarial/deepfool.py CHANGED
@@ -2,6 +2,8 @@
  DeepFool
 
  Paper "DeepFool: a simple and accurate method to fool deep neural networks", https://arxiv.org/abs/1511.04599
+
+ Generated by gpt-5.2-codex xhigh.
  """
 
  from typing import Optional
birder/adversarial/simba.py CHANGED
@@ -2,6 +2,8 @@
  SimBA (Simple Black-box Attack)
 
  Paper "Simple Black-box Adversarial Attacks", https://arxiv.org/abs/1905.07121
+
+ Generated by gpt-5.2-codex xhigh.
  """
 
  from typing import Optional
birder/common/masking.py CHANGED
@@ -84,8 +84,8 @@ def mask_tensor(
 
  (B, H, W, _) = x.size()
 
- shaped_mask = mask.reshape(-1, H // patch_factor, W // patch_factor)
- shaped_mask = shaped_mask.repeat_interleave(patch_factor, axis=1).repeat_interleave(patch_factor, axis=2)
+ shaped_mask = mask.reshape(B, H // patch_factor, W // patch_factor)
+ shaped_mask = shaped_mask.repeat_interleave(patch_factor, dim=1).repeat_interleave(patch_factor, dim=2)
 
  shaped_mask = shaped_mask.unsqueeze(3).type_as(x)
 
  if mask_token is not None:
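Note: the fix makes the batch dimension explicit (reshape(B, ...) instead of inferring it with -1) and standardizes on the canonical torch keyword dim= instead of the NumPy-style axis=. A minimal sketch of what the corrected lines compute, with illustrative shapes (not taken from the package):

    import torch

    (B, H, W, patch_factor) = (2, 4, 4, 2)
    mask = torch.randint(0, 2, (B, (H // patch_factor) * (W // patch_factor)))  # one flag per patch

    shaped_mask = mask.reshape(B, H // patch_factor, W // patch_factor)
    shaped_mask = shaped_mask.repeat_interleave(patch_factor, dim=1).repeat_interleave(patch_factor, dim=2)
    assert shaped_mask.shape == (B, H, W)  # each patch flag now covers a patch_factor x patch_factor block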
@@ -228,14 +228,23 @@ class Masking:
 
 
  class UniformMasking(Masking):
- def __init__(self, input_size: tuple[int, int], mask_ratio: float, device: Optional[torch.device] = None) -> None:
+ def __init__(
+ self,
+ input_size: tuple[int, int],
+ mask_ratio: float,
+ min_mask_size: int = 1,
+ device: Optional[torch.device] = None,
+ ) -> None:
  self.h = input_size[0]
  self.w = input_size[1]
  self.mask_ratio = mask_ratio
+ self.min_mask_size = min_mask_size
  self.device = device
 
  def __call__(self, batch_size: int) -> torch.Tensor:
- return uniform_mask(batch_size, self.h, self.w, self.mask_ratio, device=self.device)[0]
+ return uniform_mask(
+ batch_size, self.h, self.w, self.mask_ratio, min_mask_size=self.min_mask_size, device=self.device
+ )[0]
 
 
  class BlockMasking(Masking):
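A hypothetical usage sketch for the widened constructor. Only the signature above is from the diff; the returned mask shape and the exact semantics of min_mask_size (presumably a lower bound on masked patches per sample, forwarded to uniform_mask) are assumptions:

    from birder.common.masking import UniformMasking

    # 14x14 token grid, 60% of tokens masked, at least one masked token per sample (assumed semantics)
    masking = UniformMasking(input_size=(14, 14), mask_ratio=0.6, min_mask_size=1)
    mask = masking(batch_size=8)  # torch.Tensor of per-token flags (shape assumed to be (8, 196))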
birder/inference/classification.py CHANGED
@@ -85,7 +85,7 @@ def infer_batch(
  logits = net(t(tta_input), **kwargs)
  outs.append(logits if return_logits is True else F.softmax(logits, dim=1))
 
- out = torch.stack(outs).mean(axis=0)
+ out = torch.stack(outs).mean(dim=0)
 
  else:
  logits = net(inputs, **kwargs)
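Same dim=/axis= standardization as in masking.py: the line averages the per-transform (TTA) outputs over the stacking dimension. A toy equivalent with made-up shapes:

    import torch

    outs = [torch.softmax(torch.randn(4, 10), dim=1) for _ in range(3)]  # 3 TTA views of a 4-image batch
    out = torch.stack(outs).mean(dim=0)  # (4, 10): per-class scores averaged across the views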
birder/introspection/__init__.py CHANGED
@@ -1,5 +1,6 @@
  from birder.introspection.attention_rollout import AttentionRollout
  from birder.introspection.base import InterpretabilityResult
+ from birder.introspection.feature_pca import FeaturePCA
  from birder.introspection.gradcam import GradCAM
  from birder.introspection.guided_backprop import GuidedBackprop
  from birder.introspection.transformer_attribution import TransformerAttribution
@@ -7,6 +8,7 @@ from birder.introspection.transformer_attribution import TransformerAttribution
  __all__ = [
  "InterpretabilityResult",
  "AttentionRollout",
+ "FeaturePCA",
  "GradCAM",
  "GuidedBackprop",
  "TransformerAttribution",
birder/introspection/base.py CHANGED
@@ -2,7 +2,6 @@ from collections.abc import Callable
  from dataclasses import dataclass
  from pathlib import Path
  from typing import Optional
- from typing import Protocol
 
  import matplotlib
  import matplotlib.pyplot as plt
@@ -27,12 +26,6 @@ class InterpretabilityResult:
  plt.show()
 
 
- class Interpreter(Protocol):
- def __call__(
- self, image: str | Path | Image.Image, target_class: Optional[int] = None
- ) -> InterpretabilityResult: ...
-
-
  def load_image(image: str | Path | Image.Image) -> Image.Image:
  if isinstance(image, (str, Path)):
  return Image.open(image)
birder/introspection/feature_pca.py ADDED
@@ -0,0 +1,101 @@
+ from collections.abc import Callable
+ from pathlib import Path
+ from typing import Optional
+
+ import numpy as np
+ import torch
+ from PIL import Image
+ from sklearn.decomposition import PCA
+
+ from birder.introspection.base import InterpretabilityResult
+ from birder.introspection.base import preprocess_image
+ from birder.net.base import DetectorBackbone
+
+
+ class FeaturePCA:
+ """
+ Visualizes feature maps using Principal Component Analysis
+
+ This method extracts feature maps from a specified stage of a DetectorBackbone model,
+ applies PCA to reduce the channel dimension to 3 components, and visualizes them as an RGB image where:
+ - R channel = 1st principal component (most important)
+ - G channel = 2nd principal component
+ - B channel = 3rd principal component
+ """
+
+ def __init__(
+ self,
+ net: DetectorBackbone,
+ device: torch.device,
+ transform: Callable[..., torch.Tensor],
+ normalize: bool = False,
+ channels_last: bool = False,
+ stage: Optional[str] = None,
+ ) -> None:
+ self.net = net.eval()
+ self.device = device
+ self.transform = transform
+ self.normalize = normalize
+ self.channels_last = channels_last
+ self.stage = stage
+
+ def __call__(self, image: str | Path | Image.Image) -> InterpretabilityResult:
+ (input_tensor, rgb_img) = preprocess_image(image, self.transform, self.device)
+
+ with torch.inference_mode():
+ features_dict = self.net.detection_features(input_tensor)
+
+ if self.stage is not None:
+ features = features_dict[self.stage]
+ else:
+ features = list(features_dict.values())[-1]  # Use the last stage by default
+
+ features_np = features.cpu().numpy()
+
+ # Handle channels_last format (B, H, W, C) vs channels_first (B, C, H, W)
+ if self.channels_last is True:
+ (B, H, W, C) = features_np.shape
+ # Already in (B, H, W, C), just reshape to (B*H*W, C)
+ features_reshaped = features_np.reshape(-1, C)
+ else:
+ (B, C, H, W) = features_np.shape
+ # Reshape to (spatial_points, channels) for PCA
+ features_reshaped = features_np.reshape(B, C, -1)
+ features_reshaped = features_reshaped.transpose(0, 2, 1)  # (B, H*W, C)
+ features_reshaped = features_reshaped.reshape(-1, C)  # (B*H*W, C)
+
+ x = features_reshaped
+ if self.normalize is True:
+ x = x / (np.linalg.norm(x, axis=1, keepdims=True) + 1e-6)
+
+ pca = PCA(n_components=3)
+ pca_features = pca.fit_transform(x)
+ pca_features = pca_features.reshape(B, H, W, 3)
+
+ # Extract all 3 components (B=1)
+ pca_rgb = pca_features[0]  # (H, W, 3)
+
+ # Normalize each channel independently to [0, 1]
+ for i in range(3):
+ channel = pca_rgb[:, :, i]
+ channel = channel - channel.min()
+ channel = channel / (channel.max() + 1e-7)
+ pca_rgb[:, :, i] = channel
+
+ target_size = (input_tensor.size(-1), input_tensor.size(-2))  # PIL expects (width, height)
+ pca_rgb_resized = (
+ np.array(
+ Image.fromarray((pca_rgb * 255).astype(np.uint8)).resize(target_size, Image.Resampling.BILINEAR)
+ ).astype(np.float32)
+ / 255.0
+ )
+
+ visualization = (pca_rgb_resized * 255).astype(np.uint8)
+
+ return InterpretabilityResult(
+ original_image=rgb_img,
+ visualization=visualization,
+ raw_output=pca_rgb.astype(np.float32),
+ logits=None,
+ predicted_class=None,
+ )
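A hypothetical usage sketch for the new class. Only the FeaturePCA signature above is from the diff; the placeholders stand in for however a DetectorBackbone and its eval transform are constructed in practice:

    import torch
    from birder.introspection import FeaturePCA

    device = torch.device("cpu")
    net = ...        # placeholder: any birder.net.base.DetectorBackbone instance
    transform = ...  # placeholder: the matching eval preprocessing, Callable[..., torch.Tensor]

    feature_pca = FeaturePCA(net, device, transform, normalize=True, stage=None)  # stage=None -> last stage
    result = feature_pca("bird.jpeg")  # InterpretabilityResult with an RGB PCA visualization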
birder/kernels/soft_nms/soft_nms.cpp CHANGED
@@ -4,6 +4,9 @@
  * Taken from:
  * https://github.com/MrParosk/soft_nms
  * Licensed under the MIT License
+ *
+ * Modified by:
+ * Ofer Hasson — 2026-01-10
  **************************************************************************************************
  */
 
@@ -40,8 +43,8 @@ torch::Tensor calculate_iou(const torch::Tensor& boxes, const torch::Tensor& are
  auto xx2 = torch::minimum(boxes.index({idx, 2}), boxes.index({Slice(idx + 1, None), 2}));
  auto yy2 = torch::minimum(boxes.index({idx, 3}), boxes.index({Slice(idx + 1, None), 3}));
 
- auto w = torch::maximum(torch::zeros_like(xx1), xx2 - xx1);
- auto h = torch::maximum(torch::zeros_like(yy1), yy2 - yy1);
+ auto w = (xx2 - xx1).clamp_min(0);
+ auto h = (yy2 - yy1).clamp_min(0);
 
  auto intersection = w * h;
  auto union_ = areas.index({idx}) + areas.index({Slice(idx + 1, None)}) - intersection;
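The clamp change is behavior-preserving: clamping a difference at zero from below equals an elementwise max against zeros, without materializing the zeros tensor. The same identity in Python terms:

    import torch

    xx1, xx2 = torch.randn(6), torch.randn(6)
    w_old = torch.maximum(torch.zeros_like(xx1), xx2 - xx1)  # old formulation
    w_new = (xx2 - xx1).clamp_min(0)                         # new formulation
    assert torch.equal(w_old, w_new)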
birder/model_registry/model_registry.py CHANGED
@@ -87,14 +87,15 @@ class ModelRegistry:
  no further registration is needed.
  """
 
+ alias_key = alias.lower()
  if net_type.auto_register is False:
  # Register the model manually, as the base class doesn't take care of that for us
- registry.register_model(alias, type(alias, (net_type,), {"config": config}))
+ self.register_model(alias_key, type(alias, (net_type,), {"config": config}))
 
  if alias in self.aliases:
  warnings.warn(f"Alias {alias} is already registered", UserWarning)
 
- self.aliases[alias] = type(alias, (net_type,), {"config": config})
+ self.aliases[alias_key] = type(alias, (net_type,), {"config": config})
 
  def register_weights(self, name: str, weights_info: manifest.ModelMetadataType) -> None:
  if name in self._pretrained_nets:
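Net effect: aliases are now stored under a lowercased key (and registered via self rather than a module-level registry reference), so alias keys no longer depend on the caller's casing. A toy illustration of the keying, outside the real class:

    aliases: dict[str, type] = {}

    def register_alias(alias: str, net_type: type) -> None:
        aliases[alias.lower()] = net_type  # store under the canonical, lowercased key

    register_alias("ConvNeXt_v1_Tiny", object)
    assert "convnext_v1_tiny" in aliases  # case-insensitive by construction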
birder/net/convnext_v1.py CHANGED
@@ -195,6 +195,21 @@ class ConvNeXt_v1(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
  return self.features(x)
 
 
+ registry.register_model_config(
+ "convnext_v1_atto", # Not in the original v1, taken from v2
+ ConvNeXt_v1,
+ config={"in_channels": [40, 80, 160, 320], "num_layers": [2, 2, 6, 2], "drop_path_rate": 0.0},
+ )
+ registry.register_model_config(
+ "convnext_v1_femto", # Not in the original v1, taken from v2
+ ConvNeXt_v1,
+ config={"in_channels": [48, 96, 192, 384], "num_layers": [2, 2, 6, 2], "drop_path_rate": 0.0},
+ )
+ registry.register_model_config(
+ "convnext_v1_pico", # Not in the original v1, taken from v2
+ ConvNeXt_v1,
+ config={"in_channels": [64, 128, 256, 512], "num_layers": [2, 2, 6, 2], "drop_path_rate": 0.0},
+ )
  registry.register_model_config(
  "convnext_v1_nano", # Not in the original v1, taken from v2
  ConvNeXt_v1,
@@ -220,6 +235,11 @@ registry.register_model_config(
  ConvNeXt_v1,
  config={"in_channels": [192, 384, 768, 1536], "num_layers": [3, 3, 27, 3], "drop_path_rate": 0.5},
  )
+ registry.register_model_config(
+ "convnext_v1_huge", # Not in the original v1, taken from v2
+ ConvNeXt_v1,
+ config={"in_channels": [352, 704, 1408, 2816], "num_layers": [3, 3, 27, 3], "drop_path_rate": 0.5},
+ )
 
  registry.register_weights(
  "convnext_v1_tiny_eu-common256px",
birder/net/fastvit.py CHANGED
@@ -607,7 +607,6 @@ class AttentionBlock(nn.Module):
 
 
  class FastVitStage(nn.Module):
- # pylint: disable=too-many-arguments,too-many-positional-arguments
  def __init__(
  self,
  dim: int,
birder/net/flexivit.py CHANGED
@@ -98,6 +98,8 @@ class FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Mask
  layer_scale_init_value: Optional[float] = self.config.get("layer_scale_init_value", None)
  pre_norm: bool = self.config.get("pre_norm", False)
  post_norm: bool = self.config.get("post_norm", True)
+ qkv_bias: bool = self.config.get("qkv_bias", True)
+ qk_norm: bool = self.config.get("qk_norm", False)
  num_reg_tokens: int = self.config.get("num_reg_tokens", 0)
  class_token: bool = self.config.get("class_token", True)
  attn_pool_head: bool = self.config.get("attn_pool_head", False)
@@ -186,6 +188,8 @@ class FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Mask
  attention_dropout,
  dpr,
  pre_norm=pre_norm,
+ qkv_bias=qkv_bias,
+ qk_norm=qk_norm,
  activation_layer=act_layer,
  layer_scale_init_value=layer_scale_init_value,
  norm_layer=norm_layer,
@@ -224,6 +228,7 @@ class FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Mask
  drop_path=0,
  activation_layer=act_layer,
  norm_layer=norm_layer,
+ norm_layer_eps=norm_layer_eps,
  mlp_layer=mlp_layer,
  )
birder/net/focalnet.py CHANGED
@@ -245,7 +245,6 @@ class FocalNetBlock(nn.Module):
 
 
  class FocalNetStage(nn.Module):
- # pylint: disable=too-many-arguments,too-many-positional-arguments
  def __init__(
  self,
  dim: int,
birder/net/hiera.py CHANGED
@@ -301,14 +301,14 @@ class HieraBlock(nn.Module):
  self.dim = dim
  self.dim_out = dim_out
 
- self.norm1 = nn.LayerNorm(dim)
+ self.norm1 = nn.LayerNorm(dim, eps=1e-6)
  if dim != dim_out:
  self.proj = nn.Linear(dim, dim_out)
  else:
  self.proj = None
 
  self.attn = MaskUnitAttention(dim, dim_out, heads, q_stride, window_size, use_mask_unit_attn)
- self.norm2 = nn.LayerNorm(dim_out)
+ self.norm2 = nn.LayerNorm(dim_out, eps=1e-6)
  self.mlp = MLP(dim_out, [int(dim_out * mlp_ratio), dim_out], activation_layer=nn.GELU)
  self.drop_path = StochasticDepth(drop_path, mode="row")
 
@@ -450,7 +450,7 @@ class Hiera(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin):
  self.body = nn.Sequential(stages)
  self.features = nn.Sequential(
  attn_pool if attn_pool is not None else AvgTokens(),
- nn.LayerNorm(embed_dim),
+ nn.LayerNorm(embed_dim, eps=1e-6),
  nn.Flatten(1),
  )
  self.return_channels = return_channels
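The LayerNorm changes in hiera.py (and hieradet.py below) pin eps to 1e-6 instead of PyTorch's default 1e-5, presumably to match the reference implementations these weights originate from. The two settings only diverge noticeably when activations have near-zero variance:

    import torch
    from torch import nn

    x = torch.tensor([[1.0, 1.0 + 1e-4]])  # near-zero variance across the normalized dim
    print(nn.LayerNorm(2)(x))            # default eps=1e-5 -> roughly +/-0.016
    print(nn.LayerNorm(2, eps=1e-6)(x))  # eps=1e-6 -> roughly +/-0.05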
birder/net/hieradet.py CHANGED
@@ -125,7 +125,7 @@ class MultiScaleBlock(nn.Module):
  self.dim = dim
  self.dim_out = dim_out
 
- self.norm1 = nn.LayerNorm(dim)
+ self.norm1 = nn.LayerNorm(dim, eps=1e-6)
  if dim != dim_out:
  self.proj = nn.Linear(dim, dim_out)
  else:
@@ -144,7 +144,7 @@ class MultiScaleBlock(nn.Module):
  num_heads=num_heads,
  q_pool=copy.deepcopy(self.pool),
  )
- self.norm2 = nn.LayerNorm(dim_out)
+ self.norm2 = nn.LayerNorm(dim_out, eps=1e-6)
  self.mlp = MLP(dim_out, [int(dim_out * mlp_ratio), dim_out], activation_layer=nn.GELU)
  self.drop_path = StochasticDepth(drop_path, mode="row")
@@ -173,11 +173,9 @@ class MultiScaleBlock(nn.Module):
  if self.q_stride is not None:
  # Shapes have changed due to Q pooling
  window_size = self.window_size // self.q_stride[0]
- (H, W) = (shortcut.size(1), shortcut.size(2))
+ pad_hw = (pad_hw[0] // self.q_stride[0], pad_hw[1] // self.q_stride[1])
 
- pad_h = (window_size - H % window_size) % window_size
- pad_w = (window_size - W % window_size) % window_size
- pad_hw = (H + pad_h, W + pad_w)
+ (H, W) = (shortcut.size(1), shortcut.size(2))
 
  # Reverse window partition
  if self.window_size > 0:
@@ -271,7 +269,7 @@ class HieraDet(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
 
  self.body = nn.Sequential(stages)
  self.features = nn.Sequential(
- nn.LayerNorm(embed_dim),
+ nn.LayerNorm(embed_dim, eps=1e-6),
  Permute([0, 3, 1, 2]), # B H W C -> B C H W
  nn.AdaptiveAvgPool2d(output_size=(1, 1)),
  nn.Flatten(1),
@@ -415,7 +413,7 @@ registry.register_model_config(
  "num_heads": 1,
  "global_pos_size": (7, 7),
  "global_att_blocks": [5, 7, 9],
- "window_spec": [8, 4, -16, -32],
+ "window_spec": [8, 4, 14, 7],
  "drop_path_rate": 0.1,
  },
  )
@@ -428,7 +426,7 @@ registry.register_model_config(
  "num_heads": 1,
  "global_pos_size": (7, 7),
  "global_att_blocks": [7, 10, 13],
- "window_spec": [8, 4, -16, -32],
+ "window_spec": [8, 4, 14, 7],
  "drop_path_rate": 0.1,
  },
  )
@@ -441,7 +439,7 @@ registry.register_model_config(
  "num_heads": 1,
  "global_pos_size": (14, 14),
  "global_att_blocks": [12, 16, 20],
- "window_spec": [8, 4, -16, -32],
+ "window_spec": [8, 4, 14, 7],
  "drop_path_rate": 0.1,
  },
  )
@@ -454,7 +452,7 @@ registry.register_model_config(
  "num_heads": 2,
  "global_pos_size": (14, 14),
  "global_att_blocks": [12, 16, 20],
- "window_spec": [8, 4, -16, -32],
+ "window_spec": [8, 4, 14, 7],
  "drop_path_rate": 0.1,
  },
  )
@@ -467,17 +465,84 @@
  "num_heads": 2,
  "global_pos_size": (7, 7),
  "global_att_blocks": [23, 33, 43],
- "window_spec": [8, 4, -16, -32],
+ "window_spec": [8, 4, 14, 7],
+ "drop_path_rate": 0.2,
+ },
+ )
+
+ # Dynamic window size
+ registry.register_model_config(
+ "hieradet_d_tiny",
+ HieraDet,
+ config={
+ "depths": [1, 2, 7, 2],
+ "embed_dim": 96,
+ "num_heads": 1,
+ "global_pos_size": (7, 7),
+ "global_att_blocks": [5, 7, 9],
+ "window_spec": [8, 4, 0, 0],
+ "drop_path_rate": 0.1,
+ },
+ )
+ registry.register_model_config(
+ "hieradet_d_small",
+ HieraDet,
+ config={
+ "depths": [1, 2, 11, 2],
+ "embed_dim": 96,
+ "num_heads": 1,
+ "global_pos_size": (7, 7),
+ "global_att_blocks": [7, 10, 13],
+ "window_spec": [8, 4, 0, 0],
+ "drop_path_rate": 0.1,
+ },
+ )
+ registry.register_model_config(
+ "hieradet_d_base",
+ HieraDet,
+ config={
+ "depths": [2, 3, 16, 3],
+ "embed_dim": 96,
+ "num_heads": 1,
+ "global_pos_size": (14, 14),
+ "global_att_blocks": [12, 16, 20],
+ "window_spec": [8, 4, 0, 0],
+ "drop_path_rate": 0.1,
+ },
+ )
+ registry.register_model_config(
+ "hieradet_d_base_plus",
+ HieraDet,
+ config={
+ "depths": [2, 3, 16, 3],
+ "embed_dim": 112,
+ "num_heads": 2,
+ "global_pos_size": (14, 14),
+ "global_att_blocks": [12, 16, 20],
+ "window_spec": [8, 4, 0, 0],
+ "drop_path_rate": 0.1,
+ },
+ )
+ registry.register_model_config(
+ "hieradet_d_large",
+ HieraDet,
+ config={
+ "depths": [2, 6, 36, 4],
+ "embed_dim": 144,
+ "num_heads": 2,
+ "global_pos_size": (7, 7),
+ "global_att_blocks": [23, 33, 43],
+ "window_spec": [8, 4, 0, 0],
  "drop_path_rate": 0.2,
  },
  )
 
  registry.register_weights(
- "hieradet_small_dino-v2",
+ "hieradet_d_small_dino-v2",
  {
- "url": "https://huggingface.co/birder-project/hieradet_small_dino-v2/resolve/main",
+ "url": "https://huggingface.co/birder-project/hieradet_d_small_dino-v2/resolve/main",
  "description": (
- "HieraDet small image encoder pre-trained using DINOv2. "
+ "HieraDet (d) small image encoder pre-trained using DINOv2. "
  "This model has not been fine-tuned for a specific classification task"
  ),
  "resolution": (224, 224),
@@ -487,14 +552,16 @@ registry.register_weights(
  "sha256": "eb41b8a35445e7f350797094d5e365306b29351e64edd4a316420c23d1e17073",
  }
  },
- "net": {"network": "hieradet_small", "tag": "dino-v2"},
+ "net": {"network": "hieradet_d_small", "tag": "dino-v2"},
  },
  )
  registry.register_weights(
- "hieradet_small_dino-v2-inat21-256px",
+ "hieradet_d_small_dino-v2-inat21-256px",
  {
- "url": "https://huggingface.co/birder-project/hieradet_small_dino-v2-inat21/resolve/main",
- "description": "HieraDet small model pre-trained using DINOv2, then fine-tuned on the iNaturalist 2021 dataset",
+ "url": "https://huggingface.co/birder-project/hieradet_d_small_dino-v2-inat21/resolve/main",
+ "description": (
+ "HieraDet (d) small model pre-trained using DINOv2, then fine-tuned on the iNaturalist 2021 dataset"
+ ),
  "resolution": (256, 256),
  "formats": {
  "pt": {
@@ -502,14 +569,16 @@ registry.register_weights(
  "sha256": "e1bdeba97eae816ec3ab9b3238d97decf2c34d29b70f9291116ce962b9a4f9df",
  }
  },
- "net": {"network": "hieradet_small", "tag": "dino-v2-inat21-256px"},
+ "net": {"network": "hieradet_d_small", "tag": "dino-v2-inat21-256px"},
  },
  )
  registry.register_weights(
- "hieradet_small_dino-v2-inat21",
+ "hieradet_d_small_dino-v2-inat21",
  {
- "url": "https://huggingface.co/birder-project/hieradet_small_dino-v2-inat21/resolve/main",
- "description": "HieraDet small model pre-trained using DINOv2, then fine-tuned on the iNaturalist 2021 dataset",
+ "url": "https://huggingface.co/birder-project/hieradet_d_small_dino-v2-inat21/resolve/main",
+ "description": (
+ "HieraDet (d) small model pre-trained using DINOv2, then fine-tuned on the iNaturalist 2021 dataset"
+ ),
  "resolution": (384, 384),
  "formats": {
  "pt": {
@@ -517,14 +586,14 @@ registry.register_weights(
  "sha256": "271fa9ed6a9aa1f4d1fc8bbb4c4cac9d15b264f2ac544efb5cd971412691880d",
  }
  },
- "net": {"network": "hieradet_small", "tag": "dino-v2-inat21"},
+ "net": {"network": "hieradet_d_small", "tag": "dino-v2-inat21"},
  },
  )
  registry.register_weights(
- "hieradet_small_dino-v2-imagenet12k",
+ "hieradet_d_small_dino-v2-imagenet12k",
  {
- "url": "https://huggingface.co/birder-project/hieradet_small_dino-v2-imagenet12k/resolve/main",
- "description": "HieraDet small model pre-trained using DINOv2, then fine-tuned on the ImageNet-12K dataset",
+ "url": "https://huggingface.co/birder-project/hieradet_d_small_dino-v2-imagenet12k/resolve/main",
+ "description": "HieraDet (d) small model pre-trained using DINOv2, then fine-tuned on the ImageNet-12K dataset",
  "resolution": (256, 256),
  "formats": {
  "pt": {
@@ -532,6 +601,25 @@ registry.register_weights(
  "sha256": "b89dd6c13d061fe8a09d051bb3d76e632e650067ca71578e37b02033107c9963",
  }
  },
- "net": {"network": "hieradet_small", "tag": "dino-v2-imagenet12k"},
+ "net": {"network": "hieradet_d_small", "tag": "dino-v2-imagenet12k"},
+ },
+ )
+
+ registry.register_weights( # SAM v2: https://arxiv.org/abs/2408.00714
+ "hieradet_small_sam2_1",
+ {
+ "url": "https://huggingface.co/birder-project/hieradet_small_sam2_1/resolve/main",
+ "description": (
+ "HieraDet small image encoder pre-trained by Meta AI using SAM v2. "
+ "This model has not been fine-tuned for a specific classification task"
+ ),
+ "resolution": (224, 224),
+ "formats": {
+ "pt": {
+ "file_size": 129.6,
+ "sha256": "79b6ffdfd4ea9f3b1489ce5a229fe9756b215fc3b52640d01d64136560c1d341",
+ }
+ },
+ "net": {"network": "hieradet_small", "tag": "sam2_1"},
  },
  )
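A worked instance of the MultiScaleBlock change above, with illustrative numbers: after Q pooling, the padded extent is now derived directly from the incoming pad_hw divided by q_stride, rather than being re-derived from the pooled shortcut shape (which could disagree with the actual padding that was applied):

    # e.g. a feature map padded to 32x32 for window attention, then Q-pooled with stride 2
    pad_hw = (32, 32)
    q_stride = (2, 2)
    pad_hw = (pad_hw[0] // q_stride[0], pad_hw[1] // q_stride[1])
    assert pad_hw == (16, 16)  # padded extent used by the reverse window partition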
birder/net/rope_flexivit.py CHANGED
@@ -69,6 +69,8 @@ class RoPE_FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin,
  layer_scale_init_value: Optional[float] = self.config.get("layer_scale_init_value", None)
  pre_norm: bool = self.config.get("pre_norm", False)
  post_norm: bool = self.config.get("post_norm", True)
+ qkv_bias: bool = self.config.get("qkv_bias", True)
+ qk_norm: bool = self.config.get("qk_norm", False)
  num_reg_tokens: int = self.config.get("num_reg_tokens", 0)
  class_token: bool = self.config.get("class_token", True)
  attn_pool_head: bool = self.config.get("attn_pool_head", False)
@@ -118,6 +120,7 @@ class RoPE_FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin,
  self.num_reg_tokens = num_reg_tokens
  self.attn_pool_special_tokens = attn_pool_special_tokens
  self.norm_layer = norm_layer
+ self.norm_layer_eps = norm_layer_eps
  self.mlp_layer = mlp_layer
  self.act_layer = act_layer
  self.rope_rot_type = rope_rot_type
@@ -190,6 +193,8 @@ class RoPE_FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin,
  attention_dropout,
  dpr,
  pre_norm=pre_norm,
+ qkv_bias=qkv_bias,
+ qk_norm=qk_norm,
  activation_layer=act_layer,
  layer_scale_init_value=layer_scale_init_value,
  norm_layer=norm_layer,
@@ -231,6 +236,7 @@ class RoPE_FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin,
  rope_temperature=rope_temperature,
  layer_scale_init_value=layer_scale_init_value,
  norm_layer=norm_layer,
+ norm_layer_eps=norm_layer_eps,
  mlp_layer=mlp_layer,
  rope_rot_type=rope_rot_type,
  )
@@ -588,6 +594,7 @@ class RoPE_FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin,
  rope_temperature=self.rope_temperature,
  layer_scale_init_value=self.layer_scale_init_value,
  norm_layer=self.norm_layer,
+ norm_layer_eps=self.norm_layer_eps,
  mlp_layer=self.mlp_layer,
  rope_rot_type=self.rope_rot_type,
  )
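Across flexivit.py and rope_flexivit.py (and, per the file list, rope_vit.py and vit.py), the config schema gains two attention-related knobs with backward-compatible defaults. A sketch of how a registry config might set them (hypothetical config values, not from the diff):

    config = {
        "qkv_bias": False,  # drop the bias on the fused qkv projection
        "qk_norm": True,    # normalize q and k before the attention product
    }
    # The classes read them exactly as in the hunks above:
    qkv_bias: bool = config.get("qkv_bias", True)
    qk_norm: bool = config.get("qk_norm", False)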