birder 0.3.1__py3-none-any.whl → 0.3.3__py3-none-any.whl
- birder/adversarial/deepfool.py +2 -0
- birder/adversarial/simba.py +2 -0
- birder/common/masking.py +13 -4
- birder/inference/classification.py +1 -1
- birder/introspection/__init__.py +2 -0
- birder/introspection/base.py +0 -7
- birder/introspection/feature_pca.py +101 -0
- birder/kernels/soft_nms/soft_nms.cpp +5 -2
- birder/model_registry/model_registry.py +3 -2
- birder/net/convnext_v1.py +20 -0
- birder/net/fastvit.py +0 -1
- birder/net/flexivit.py +5 -0
- birder/net/focalnet.py +0 -1
- birder/net/hiera.py +3 -3
- birder/net/hieradet.py +116 -28
- birder/net/rope_flexivit.py +7 -0
- birder/net/rope_vit.py +49 -4
- birder/net/smt.py +0 -1
- birder/net/ssl/ibot.py +0 -1
- birder/net/vit.py +166 -2
- birder/scripts/train.py +24 -21
- birder/scripts/train_barlow_twins.py +4 -3
- birder/scripts/train_byol.py +4 -3
- birder/scripts/train_capi.py +6 -5
- birder/scripts/train_data2vec.py +4 -3
- birder/scripts/train_data2vec2.py +4 -3
- birder/scripts/train_detection.py +7 -5
- birder/scripts/train_dino_v1.py +5 -4
- birder/scripts/train_dino_v2.py +69 -20
- birder/scripts/train_dino_v2_dist.py +70 -21
- birder/scripts/train_franca.py +8 -7
- birder/scripts/train_i_jepa.py +4 -3
- birder/scripts/train_ibot.py +5 -4
- birder/scripts/train_kd.py +25 -24
- birder/scripts/train_mim.py +4 -3
- birder/scripts/train_mmcr.py +4 -3
- birder/scripts/train_rotnet.py +5 -4
- birder/scripts/train_simclr.py +4 -3
- birder/scripts/train_vicreg.py +4 -3
- birder/tools/avg_model.py +24 -8
- birder/tools/introspection.py +35 -9
- birder/tools/show_iterator.py +17 -3
- birder/version.py +1 -1
- {birder-0.3.1.dist-info → birder-0.3.3.dist-info}/METADATA +1 -1
- {birder-0.3.1.dist-info → birder-0.3.3.dist-info}/RECORD +49 -48
- {birder-0.3.1.dist-info → birder-0.3.3.dist-info}/WHEEL +0 -0
- {birder-0.3.1.dist-info → birder-0.3.3.dist-info}/entry_points.txt +0 -0
- {birder-0.3.1.dist-info → birder-0.3.3.dist-info}/licenses/LICENSE +0 -0
- {birder-0.3.1.dist-info → birder-0.3.3.dist-info}/top_level.txt +0 -0
birder/adversarial/deepfool.py
CHANGED
birder/adversarial/simba.py
CHANGED
birder/common/masking.py
CHANGED
@@ -84,8 +84,8 @@ def mask_tensor(
 
     (B, H, W, _) = x.size()
 
-    shaped_mask = mask.reshape(
-    shaped_mask = shaped_mask.repeat_interleave(patch_factor,
+    shaped_mask = mask.reshape(B, H // patch_factor, W // patch_factor)
+    shaped_mask = shaped_mask.repeat_interleave(patch_factor, dim=1).repeat_interleave(patch_factor, dim=2)
     shaped_mask = shaped_mask.unsqueeze(3).type_as(x)
 
     if mask_token is not None:
@@ -228,14 +228,23 @@ class Masking:
 
 
 class UniformMasking(Masking):
-    def __init__(
+    def __init__(
+        self,
+        input_size: tuple[int, int],
+        mask_ratio: float,
+        min_mask_size: int = 1,
+        device: Optional[torch.device] = None,
+    ) -> None:
         self.h = input_size[0]
         self.w = input_size[1]
         self.mask_ratio = mask_ratio
+        self.min_mask_size = min_mask_size
         self.device = device
 
     def __call__(self, batch_size: int) -> torch.Tensor:
-        return uniform_mask(
+        return uniform_mask(
+            batch_size, self.h, self.w, self.mask_ratio, min_mask_size=self.min_mask_size, device=self.device
+        )[0]
 
 
 class BlockMasking(Masking):
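The corrected `mask_tensor` lines upsample a patch-level mask to pixel resolution by repeating entries along both spatial axes. A minimal standalone sketch of the same idea in plain PyTorch (shapes are illustrative, not birder's defaults):

import torch

# One image, a 4x4 grid of patches, 1 = masked
mask = torch.randint(0, 2, (1, 16))
patch_factor = 4  # each patch covers a 4x4 pixel block

shaped = mask.reshape(1, 4, 4)
# dim=1 repeats rows, dim=2 repeats columns -> (1, 16, 16) pixel-level mask
shaped = shaped.repeat_interleave(patch_factor, dim=1).repeat_interleave(patch_factor, dim=2)
assert shaped.shape == (1, 16, 16)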
birder/inference/classification.py
CHANGED

@@ -85,7 +85,7 @@ def infer_batch(
             logits = net(t(tta_input), **kwargs)
             outs.append(logits if return_logits is True else F.softmax(logits, dim=1))
 
-        out = torch.stack(outs).mean(
+        out = torch.stack(outs).mean(dim=0)
 
     else:
         logits = net(inputs, **kwargs)
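`torch.stack(outs)` yields a tensor of shape (num_tta, batch, num_classes), so the average over test-time-augmentation passes has to be taken along dim=0 to keep per-sample outputs; a quick sketch:

import torch

outs = [torch.softmax(torch.randn(8, 10), dim=1) for _ in range(3)]  # 3 TTA passes
out = torch.stack(outs).mean(dim=0)  # average over the TTA axis
assert out.shape == (8, 10)  # (batch, num_classes) preserved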
birder/introspection/__init__.py
CHANGED
@@ -1,5 +1,6 @@
 from birder.introspection.attention_rollout import AttentionRollout
 from birder.introspection.base import InterpretabilityResult
+from birder.introspection.feature_pca import FeaturePCA
 from birder.introspection.gradcam import GradCAM
 from birder.introspection.guided_backprop import GuidedBackprop
 from birder.introspection.transformer_attribution import TransformerAttribution
@@ -7,6 +8,7 @@ from birder.introspection.transformer_attribution import TransformerAttribution
 __all__ = [
     "InterpretabilityResult",
     "AttentionRollout",
+    "FeaturePCA",
     "GradCAM",
     "GuidedBackprop",
     "TransformerAttribution",
birder/introspection/base.py
CHANGED
@@ -2,7 +2,6 @@ from collections.abc import Callable
 from dataclasses import dataclass
 from pathlib import Path
 from typing import Optional
-from typing import Protocol
 
 import matplotlib
 import matplotlib.pyplot as plt
@@ -27,12 +26,6 @@ class InterpretabilityResult:
         plt.show()
 
 
-class Interpreter(Protocol):
-    def __call__(
-        self, image: str | Path | Image.Image, target_class: Optional[int] = None
-    ) -> InterpretabilityResult: ...
-
-
 def load_image(image: str | Path | Image.Image) -> Image.Image:
     if isinstance(image, (str, Path)):
         return Image.open(image)
birder/introspection/feature_pca.py
ADDED

@@ -0,0 +1,101 @@
+from collections.abc import Callable
+from pathlib import Path
+from typing import Optional
+
+import numpy as np
+import torch
+from PIL import Image
+from sklearn.decomposition import PCA
+
+from birder.introspection.base import InterpretabilityResult
+from birder.introspection.base import preprocess_image
+from birder.net.base import DetectorBackbone
+
+
+class FeaturePCA:
+    """
+    Visualizes feature maps using Principal Component Analysis
+
+    This method extracts feature maps from a specified stage of a DetectorBackbone model,
+    applies PCA to reduce the channel dimension to 3 components, and visualizes them as an RGB image where:
+    - R channel = 1st principal component (most important)
+    - G channel = 2nd principal component
+    - B channel = 3rd principal component
+    """
+
+    def __init__(
+        self,
+        net: DetectorBackbone,
+        device: torch.device,
+        transform: Callable[..., torch.Tensor],
+        normalize: bool = False,
+        channels_last: bool = False,
+        stage: Optional[str] = None,
+    ) -> None:
+        self.net = net.eval()
+        self.device = device
+        self.transform = transform
+        self.normalize = normalize
+        self.channels_last = channels_last
+        self.stage = stage
+
+    def __call__(self, image: str | Path | Image.Image) -> InterpretabilityResult:
+        (input_tensor, rgb_img) = preprocess_image(image, self.transform, self.device)
+
+        with torch.inference_mode():
+            features_dict = self.net.detection_features(input_tensor)
+
+        if self.stage is not None:
+            features = features_dict[self.stage]
+        else:
+            features = list(features_dict.values())[-1]  # Use the last stage by default
+
+        features_np = features.cpu().numpy()
+
+        # Handle channels_last format (B, H, W, C) vs channels_first (B, C, H, W)
+        if self.channels_last is True:
+            (B, H, W, C) = features_np.shape
+            # Already in (B, H, W, C), just reshape to (B*H*W, C)
+            features_reshaped = features_np.reshape(-1, C)
+        else:
+            (B, C, H, W) = features_np.shape
+            # Reshape to (spatial_points, channels) for PCA
+            features_reshaped = features_np.reshape(B, C, -1)
+            features_reshaped = features_reshaped.transpose(0, 2, 1)  # (B, H*W, C)
+            features_reshaped = features_reshaped.reshape(-1, C)  # (B*H*W, C)
+
+        x = features_reshaped
+        if self.normalize is True:
+            x = x / (np.linalg.norm(x, axis=1, keepdims=True) + 1e-6)
+
+        pca = PCA(n_components=3)
+        pca_features = pca.fit_transform(x)
+        pca_features = pca_features.reshape(B, H, W, 3)
+
+        # Extract all 3 components (B=1)
+        pca_rgb = pca_features[0]  # (H, W, 3)
+
+        # Normalize each channel independently to [0, 1]
+        for i in range(3):
+            channel = pca_rgb[:, :, i]
+            channel = channel - channel.min()
+            channel = channel / (channel.max() + 1e-7)
+            pca_rgb[:, :, i] = channel
+
+        target_size = (input_tensor.size(-1), input_tensor.size(-2))  # PIL expects (width, height)
+        pca_rgb_resized = (
+            np.array(
+                Image.fromarray((pca_rgb * 255).astype(np.uint8)).resize(target_size, Image.Resampling.BILINEAR)
+            ).astype(np.float32)
+            / 255.0
+        )
+
+        visualization = (pca_rgb_resized * 255).astype(np.uint8)
+
+        return InterpretabilityResult(
+            original_image=rgb_img,
+            visualization=visualization,
+            raw_output=pca_rgb.astype(np.float32),
+            logits=None,
+            predicted_class=None,
+        )
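Going by the constructor and `__call__` signatures above, usage should look roughly like the sketch below. The backbone and transform setup are placeholders, not birder's exact loading API, and the `show()` call assumes the matplotlib helper visible in `base.py`:

import torch

from birder.introspection import FeaturePCA

net = ...        # placeholder: any DetectorBackbone instance
transform = ...  # placeholder: the matching eval preprocessing
device = torch.device("cpu")

feature_pca = FeaturePCA(net, device, transform, normalize=True, stage=None)
result = feature_pca("example.jpeg")  # InterpretabilityResult
# result.visualization holds the PCA-as-RGB uint8 image; result.show() should display it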
birder/kernels/soft_nms/soft_nms.cpp
CHANGED

@@ -4,6 +4,9 @@
  * Taken from:
  * https://github.com/MrParosk/soft_nms
  * Licensed under the MIT License
+ *
+ * Modified by:
+ * Ofer Hasson — 2026-01-10
 **************************************************************************************************
 */
 
@@ -40,8 +43,8 @@ torch::Tensor calculate_iou(const torch::Tensor& boxes, const torch::Tensor& are
     auto xx2 = torch::minimum(boxes.index({idx, 2}), boxes.index({Slice(idx + 1, None), 2}));
     auto yy2 = torch::minimum(boxes.index({idx, 3}), boxes.index({Slice(idx + 1, None), 3}));
 
-    auto w =
-    auto h =
+    auto w = (xx2 - xx1).clamp_min(0);
+    auto h = (yy2 - yy1).clamp_min(0);
 
     auto intersection = w * h;
     auto union_ = areas.index({idx}) + areas.index({Slice(idx + 1, None)}) - intersection;
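The restored lines clamp negative widths and heights to zero, so disjoint boxes contribute zero intersection rather than a spurious positive product of two negatives. The same computation in PyTorch terms, as a sketch mirroring the C++ above (not the kernel itself):

import torch

def iou_against_rest(box: torch.Tensor, boxes: torch.Tensor) -> torch.Tensor:
    # box: (4,) in xyxy format; boxes: (N, 4) in xyxy format
    xx1 = torch.maximum(box[0], boxes[:, 0])
    yy1 = torch.maximum(box[1], boxes[:, 1])
    xx2 = torch.minimum(box[2], boxes[:, 2])
    yy2 = torch.minimum(box[3], boxes[:, 3])
    w = (xx2 - xx1).clamp_min(0)  # disjoint boxes -> 0, never negative
    h = (yy2 - yy1).clamp_min(0)
    intersection = w * h
    area = (box[2] - box[0]) * (box[3] - box[1])
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    return intersection / (area + areas - intersection)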
birder/model_registry/model_registry.py
CHANGED

@@ -87,14 +87,15 @@ class ModelRegistry:
         no further registration is needed.
         """
 
+        alias_key = alias.lower()
         if net_type.auto_register is False:
             # Register the model manually, as the base class doesn't take care of that for us
-
+            self.register_model(alias_key, type(alias, (net_type,), {"config": config}))
 
         if alias in self.aliases:
             warnings.warn(f"Alias {alias} is already registered", UserWarning)
 
-        self.aliases[
+        self.aliases[alias_key] = type(alias, (net_type,), {"config": config})
 
     def register_weights(self, name: str, weights_info: manifest.ModelMetadataType) -> None:
         if name in self._pretrained_nets:
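The fix stores every alias under a lowercased key before registration, so the stored mapping no longer depends on the caller's casing. The pattern reduced to its essence (illustrative, not the full ModelRegistry):

class AliasRegistry:
    def __init__(self) -> None:
        self.aliases: dict[str, type] = {}

    def register_alias(self, alias: str, cls: type) -> None:
        alias_key = alias.lower()  # canonical key, independent of caller casing
        self.aliases[alias_key] = cls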
birder/net/convnext_v1.py
CHANGED
@@ -195,6 +195,21 @@ class ConvNeXt_v1(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
         return self.features(x)
 
 
+registry.register_model_config(
+    "convnext_v1_atto",  # Not in the original v1, taken from v2
+    ConvNeXt_v1,
+    config={"in_channels": [40, 80, 160, 320], "num_layers": [2, 2, 6, 2], "drop_path_rate": 0.0},
+)
+registry.register_model_config(
+    "convnext_v1_femto",  # Not in the original v1, taken from v2
+    ConvNeXt_v1,
+    config={"in_channels": [48, 96, 192, 384], "num_layers": [2, 2, 6, 2], "drop_path_rate": 0.0},
+)
+registry.register_model_config(
+    "convnext_v1_pico",  # Not in the original v1, taken from v2
+    ConvNeXt_v1,
+    config={"in_channels": [64, 128, 256, 512], "num_layers": [2, 2, 6, 2], "drop_path_rate": 0.0},
+)
 registry.register_model_config(
     "convnext_v1_nano",  # Not in the original v1, taken from v2
     ConvNeXt_v1,
@@ -220,6 +235,11 @@ registry.register_model_config(
     ConvNeXt_v1,
     config={"in_channels": [192, 384, 768, 1536], "num_layers": [3, 3, 27, 3], "drop_path_rate": 0.5},
 )
+registry.register_model_config(
+    "convnext_v1_huge",  # Not in the original v1, taken from v2
+    ConvNeXt_v1,
+    config={"in_channels": [352, 704, 1408, 2816], "num_layers": [3, 3, 27, 3], "drop_path_rate": 0.5},
+)
 
 registry.register_weights(
     "convnext_v1_tiny_eu-common256px",
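All of the added sizes keep the standard ConvNeXt ladder, with each stage doubling the previous stage's width; a quick check against the configs above:

for name, channels in {
    "atto": [40, 80, 160, 320],
    "femto": [48, 96, 192, 384],
    "pico": [64, 128, 256, 512],
    "huge": [352, 704, 1408, 2816],
}.items():
    assert all(b == 2 * a for a, b in zip(channels, channels[1:])), name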
birder/net/fastvit.py
CHANGED
birder/net/flexivit.py
CHANGED
@@ -98,6 +98,8 @@ class FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Mask
         layer_scale_init_value: Optional[float] = self.config.get("layer_scale_init_value", None)
         pre_norm: bool = self.config.get("pre_norm", False)
         post_norm: bool = self.config.get("post_norm", True)
+        qkv_bias: bool = self.config.get("qkv_bias", True)
+        qk_norm: bool = self.config.get("qk_norm", False)
         num_reg_tokens: int = self.config.get("num_reg_tokens", 0)
         class_token: bool = self.config.get("class_token", True)
         attn_pool_head: bool = self.config.get("attn_pool_head", False)
@@ -186,6 +188,8 @@ class FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Mask
             attention_dropout,
             dpr,
             pre_norm=pre_norm,
+            qkv_bias=qkv_bias,
+            qk_norm=qk_norm,
             activation_layer=act_layer,
             layer_scale_init_value=layer_scale_init_value,
             norm_layer=norm_layer,
@@ -224,6 +228,7 @@ class FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Mask
             drop_path=0,
             activation_layer=act_layer,
             norm_layer=norm_layer,
+            norm_layer_eps=norm_layer_eps,
             mlp_layer=mlp_layer,
         )
 
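The two new config keys surface standard ViT attention options: `qkv_bias` toggles the bias on the fused QKV projection, and `qk_norm` normalizes queries and keys per head before the dot product. A minimal sketch of what these switches typically control inside an attention block (illustrative, not birder's encoder code):

import torch
import torch.nn.functional as F
from torch import nn

class QKNormAttention(nn.Module):
    def __init__(self, dim: int, num_heads: int, qkv_bias: bool = True, qk_norm: bool = False) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)  # qkv_bias toggles this bias
        self.q_norm = nn.LayerNorm(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = nn.LayerNorm(self.head_dim) if qk_norm else nn.Identity()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        (B, N, C) = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        (q, k, v) = qkv.unbind(0)
        (q, k) = (self.q_norm(q), self.k_norm(k))  # qk_norm: normalize before the dot product
        out = F.scaled_dot_product_attention(q, k, v)
        return self.proj(out.transpose(1, 2).reshape(B, N, C))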
birder/net/focalnet.py
CHANGED
birder/net/hiera.py
CHANGED
@@ -301,14 +301,14 @@ class HieraBlock(nn.Module):
         self.dim = dim
         self.dim_out = dim_out
 
-        self.norm1 = nn.LayerNorm(dim)
+        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
         if dim != dim_out:
             self.proj = nn.Linear(dim, dim_out)
         else:
             self.proj = None
 
         self.attn = MaskUnitAttention(dim, dim_out, heads, q_stride, window_size, use_mask_unit_attn)
-        self.norm2 = nn.LayerNorm(dim_out)
+        self.norm2 = nn.LayerNorm(dim_out, eps=1e-6)
         self.mlp = MLP(dim_out, [int(dim_out * mlp_ratio), dim_out], activation_layer=nn.GELU)
         self.drop_path = StochasticDepth(drop_path, mode="row")
 
@@ -450,7 +450,7 @@ class Hiera(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin):
         self.body = nn.Sequential(stages)
         self.features = nn.Sequential(
             attn_pool if attn_pool is not None else AvgTokens(),
-            nn.LayerNorm(embed_dim),
+            nn.LayerNorm(embed_dim, eps=1e-6),
             nn.Flatten(1),
         )
         self.return_channels = return_channels
birder/net/hieradet.py
CHANGED
@@ -125,7 +125,7 @@ class MultiScaleBlock(nn.Module):
         self.dim = dim
         self.dim_out = dim_out
 
-        self.norm1 = nn.LayerNorm(dim)
+        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
         if dim != dim_out:
             self.proj = nn.Linear(dim, dim_out)
         else:
@@ -144,7 +144,7 @@ class MultiScaleBlock(nn.Module):
             num_heads=num_heads,
             q_pool=copy.deepcopy(self.pool),
         )
-        self.norm2 = nn.LayerNorm(dim_out)
+        self.norm2 = nn.LayerNorm(dim_out, eps=1e-6)
         self.mlp = MLP(dim_out, [int(dim_out * mlp_ratio), dim_out], activation_layer=nn.GELU)
         self.drop_path = StochasticDepth(drop_path, mode="row")
 
@@ -173,11 +173,9 @@ class MultiScaleBlock(nn.Module):
         if self.q_stride is not None:
             # Shapes have changed due to Q pooling
             window_size = self.window_size // self.q_stride[0]
-
+            pad_hw = (pad_hw[0] // self.q_stride[0], pad_hw[1] // self.q_stride[1])
 
-
-            pad_w = (window_size - W % window_size) % window_size
-            pad_hw = (H + pad_h, W + pad_w)
+            (H, W) = (shortcut.size(1), shortcut.size(2))
 
         # Reverse window partition
         if self.window_size > 0:
@@ -271,7 +269,7 @@ class HieraDet(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
 
         self.body = nn.Sequential(stages)
         self.features = nn.Sequential(
-            nn.LayerNorm(embed_dim),
+            nn.LayerNorm(embed_dim, eps=1e-6),
             Permute([0, 3, 1, 2]),  # B H W C -> B C H W
             nn.AdaptiveAvgPool2d(output_size=(1, 1)),
             nn.Flatten(1),
@@ -415,7 +413,7 @@ registry.register_model_config(
         "num_heads": 1,
         "global_pos_size": (7, 7),
         "global_att_blocks": [5, 7, 9],
-        "window_spec": [8, 4,
+        "window_spec": [8, 4, 14, 7],
         "drop_path_rate": 0.1,
     },
 )
@@ -428,7 +426,7 @@ registry.register_model_config(
         "num_heads": 1,
         "global_pos_size": (7, 7),
         "global_att_blocks": [7, 10, 13],
-        "window_spec": [8, 4,
+        "window_spec": [8, 4, 14, 7],
         "drop_path_rate": 0.1,
     },
 )
@@ -441,7 +439,7 @@ registry.register_model_config(
         "num_heads": 1,
         "global_pos_size": (14, 14),
         "global_att_blocks": [12, 16, 20],
-        "window_spec": [8, 4,
+        "window_spec": [8, 4, 14, 7],
         "drop_path_rate": 0.1,
     },
 )
@@ -454,7 +452,7 @@ registry.register_model_config(
         "num_heads": 2,
         "global_pos_size": (14, 14),
         "global_att_blocks": [12, 16, 20],
-        "window_spec": [8, 4,
+        "window_spec": [8, 4, 14, 7],
         "drop_path_rate": 0.1,
     },
 )
@@ -467,17 +465,84 @@ registry.register_model_config(
         "num_heads": 2,
         "global_pos_size": (7, 7),
         "global_att_blocks": [23, 33, 43],
-        "window_spec": [8, 4,
+        "window_spec": [8, 4, 14, 7],
+        "drop_path_rate": 0.2,
+    },
+)
+
+# Dynamic window size
+registry.register_model_config(
+    "hieradet_d_tiny",
+    HieraDet,
+    config={
+        "depths": [1, 2, 7, 2],
+        "embed_dim": 96,
+        "num_heads": 1,
+        "global_pos_size": (7, 7),
+        "global_att_blocks": [5, 7, 9],
+        "window_spec": [8, 4, 0, 0],
+        "drop_path_rate": 0.1,
+    },
+)
+registry.register_model_config(
+    "hieradet_d_small",
+    HieraDet,
+    config={
+        "depths": [1, 2, 11, 2],
+        "embed_dim": 96,
+        "num_heads": 1,
+        "global_pos_size": (7, 7),
+        "global_att_blocks": [7, 10, 13],
+        "window_spec": [8, 4, 0, 0],
+        "drop_path_rate": 0.1,
+    },
+)
+registry.register_model_config(
+    "hieradet_d_base",
+    HieraDet,
+    config={
+        "depths": [2, 3, 16, 3],
+        "embed_dim": 96,
+        "num_heads": 1,
+        "global_pos_size": (14, 14),
+        "global_att_blocks": [12, 16, 20],
+        "window_spec": [8, 4, 0, 0],
+        "drop_path_rate": 0.1,
+    },
+)
+registry.register_model_config(
+    "hieradet_d_base_plus",
+    HieraDet,
+    config={
+        "depths": [2, 3, 16, 3],
+        "embed_dim": 112,
+        "num_heads": 2,
+        "global_pos_size": (14, 14),
+        "global_att_blocks": [12, 16, 20],
+        "window_spec": [8, 4, 0, 0],
+        "drop_path_rate": 0.1,
+    },
+)
+registry.register_model_config(
+    "hieradet_d_large",
+    HieraDet,
+    config={
+        "depths": [2, 6, 36, 4],
+        "embed_dim": 144,
+        "num_heads": 2,
+        "global_pos_size": (7, 7),
+        "global_att_blocks": [23, 33, 43],
+        "window_spec": [8, 4, 0, 0],
         "drop_path_rate": 0.2,
     },
 )
 
 registry.register_weights(
-    "
+    "hieradet_d_small_dino-v2",
     {
-        "url": "https://huggingface.co/birder-project/
+        "url": "https://huggingface.co/birder-project/hieradet_d_small_dino-v2/resolve/main",
         "description": (
-            "HieraDet small image encoder pre-trained using DINOv2. "
+            "HieraDet (d) small image encoder pre-trained using DINOv2. "
             "This model has not been fine-tuned for a specific classification task"
         ),
         "resolution": (224, 224),
@@ -487,14 +552,16 @@ registry.register_weights(
                 "sha256": "eb41b8a35445e7f350797094d5e365306b29351e64edd4a316420c23d1e17073",
             }
         },
-        "net": {"network": "
+        "net": {"network": "hieradet_d_small", "tag": "dino-v2"},
     },
 )
 registry.register_weights(
-    "
+    "hieradet_d_small_dino-v2-inat21-256px",
     {
-        "url": "https://huggingface.co/birder-project/
-        "description":
+        "url": "https://huggingface.co/birder-project/hieradet_d_small_dino-v2-inat21/resolve/main",
+        "description": (
+            "HieraDet (d) small model pre-trained using DINOv2, then fine-tuned on the iNaturalist 2021 dataset"
+        ),
         "resolution": (256, 256),
         "formats": {
             "pt": {
@@ -502,14 +569,16 @@ registry.register_weights(
                 "sha256": "e1bdeba97eae816ec3ab9b3238d97decf2c34d29b70f9291116ce962b9a4f9df",
             }
         },
-        "net": {"network": "
+        "net": {"network": "hieradet_d_small", "tag": "dino-v2-inat21-256px"},
     },
 )
 registry.register_weights(
-    "
+    "hieradet_d_small_dino-v2-inat21",
     {
-        "url": "https://huggingface.co/birder-project/
-        "description":
+        "url": "https://huggingface.co/birder-project/hieradet_d_small_dino-v2-inat21/resolve/main",
+        "description": (
+            "HieraDet (d) small model pre-trained using DINOv2, then fine-tuned on the iNaturalist 2021 dataset"
+        ),
         "resolution": (384, 384),
         "formats": {
             "pt": {
@@ -517,14 +586,14 @@ registry.register_weights(
                 "sha256": "271fa9ed6a9aa1f4d1fc8bbb4c4cac9d15b264f2ac544efb5cd971412691880d",
             }
         },
-        "net": {"network": "
+        "net": {"network": "hieradet_d_small", "tag": "dino-v2-inat21"},
     },
 )
 registry.register_weights(
-    "
+    "hieradet_d_small_dino-v2-imagenet12k",
     {
-        "url": "https://huggingface.co/birder-project/
-        "description": "HieraDet small model pre-trained using DINOv2, then fine-tuned on the ImageNet-12K dataset",
+        "url": "https://huggingface.co/birder-project/hieradet_d_small_dino-v2-imagenet12k/resolve/main",
+        "description": "HieraDet (d) small model pre-trained using DINOv2, then fine-tuned on the ImageNet-12K dataset",
         "resolution": (256, 256),
         "formats": {
             "pt": {
@@ -532,6 +601,25 @@ registry.register_weights(
                 "sha256": "b89dd6c13d061fe8a09d051bb3d76e632e650067ca71578e37b02033107c9963",
             }
         },
-        "net": {"network": "
+        "net": {"network": "hieradet_d_small", "tag": "dino-v2-imagenet12k"},
+    },
+)
+
+registry.register_weights(  # SAM v2: https://arxiv.org/abs/2408.00714
+    "hieradet_small_sam2_1",
+    {
+        "url": "https://huggingface.co/birder-project/hieradet_small_sam2_1/resolve/main",
+        "description": (
+            "HieraDet small image encoder pre-trained by Meta AI using SAM v2. "
+            "This model has not been fine-tuned for a specific classification task"
+        ),
+        "resolution": (224, 224),
+        "formats": {
+            "pt": {
+                "file_size": 129.6,
+                "sha256": "79b6ffdfd4ea9f3b1489ce5a229fe9756b215fc3b52640d01d64136560c1d341",
+            }
+        },
+        "net": {"network": "hieradet_small", "tag": "sam2_1"},
     },
 )
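The rewritten `MultiScaleBlock` lines derive the post-pooling padded size by dividing `pad_hw` directly by the Q-pooling stride and re-read (H, W) from the pooled shortcut, instead of recomputing padding from stale values. With concrete numbers, the invariant being maintained (an illustrative check, not the module code):

# Before pooling: window_size divides the padded grid
q_stride = (2, 2)
window_size = 8
pad_hw = (64, 64)

# After Q pooling, both shrink by the same stride, so divisibility is preserved
window_size = window_size // q_stride[0]                       # 4
pad_hw = (pad_hw[0] // q_stride[0], pad_hw[1] // q_stride[1])  # (32, 32)
assert pad_hw[0] % window_size == 0 and pad_hw[1] % window_size == 0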
birder/net/rope_flexivit.py
CHANGED
@@ -69,6 +69,8 @@ class RoPE_FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin,
         layer_scale_init_value: Optional[float] = self.config.get("layer_scale_init_value", None)
         pre_norm: bool = self.config.get("pre_norm", False)
         post_norm: bool = self.config.get("post_norm", True)
+        qkv_bias: bool = self.config.get("qkv_bias", True)
+        qk_norm: bool = self.config.get("qk_norm", False)
         num_reg_tokens: int = self.config.get("num_reg_tokens", 0)
         class_token: bool = self.config.get("class_token", True)
         attn_pool_head: bool = self.config.get("attn_pool_head", False)
@@ -118,6 +120,7 @@ class RoPE_FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin,
         self.num_reg_tokens = num_reg_tokens
         self.attn_pool_special_tokens = attn_pool_special_tokens
         self.norm_layer = norm_layer
+        self.norm_layer_eps = norm_layer_eps
         self.mlp_layer = mlp_layer
         self.act_layer = act_layer
         self.rope_rot_type = rope_rot_type
@@ -190,6 +193,8 @@ class RoPE_FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin,
             attention_dropout,
             dpr,
             pre_norm=pre_norm,
+            qkv_bias=qkv_bias,
+            qk_norm=qk_norm,
             activation_layer=act_layer,
             layer_scale_init_value=layer_scale_init_value,
             norm_layer=norm_layer,
@@ -231,6 +236,7 @@ class RoPE_FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin,
             rope_temperature=rope_temperature,
             layer_scale_init_value=layer_scale_init_value,
             norm_layer=norm_layer,
+            norm_layer_eps=norm_layer_eps,
             mlp_layer=mlp_layer,
             rope_rot_type=rope_rot_type,
         )
@@ -588,6 +594,7 @@ class RoPE_FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin,
             rope_temperature=self.rope_temperature,
             layer_scale_init_value=self.layer_scale_init_value,
             norm_layer=self.norm_layer,
+            norm_layer_eps=self.norm_layer_eps,
             mlp_layer=self.mlp_layer,
             rope_rot_type=self.rope_rot_type,
         )