birder 0.2.2__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- birder/common/lib.py +2 -9
- birder/common/training_cli.py +18 -0
- birder/common/training_utils.py +123 -10
- birder/data/collators/detection.py +10 -3
- birder/data/datasets/coco.py +8 -10
- birder/data/transforms/detection.py +30 -13
- birder/inference/detection.py +108 -4
- birder/inference/wbf.py +226 -0
- birder/net/__init__.py +8 -0
- birder/net/detection/efficientdet.py +65 -86
- birder/net/detection/rt_detr_v1.py +1 -0
- birder/net/detection/yolo_anchors.py +205 -0
- birder/net/detection/yolo_v2.py +25 -24
- birder/net/detection/yolo_v3.py +39 -40
- birder/net/detection/yolo_v4.py +28 -26
- birder/net/detection/yolo_v4_tiny.py +24 -20
- birder/net/fasternet.py +1 -1
- birder/net/gc_vit.py +671 -0
- birder/net/lit_v1.py +472 -0
- birder/net/lit_v1_tiny.py +342 -0
- birder/net/lit_v2.py +436 -0
- birder/net/mobilenet_v4_hybrid.py +1 -1
- birder/net/resnet_v1.py +1 -1
- birder/net/resnext.py +67 -25
- birder/net/se_resnet_v1.py +46 -0
- birder/net/se_resnext.py +3 -0
- birder/net/simple_vit.py +2 -2
- birder/net/vit.py +0 -15
- birder/net/vovnet_v2.py +31 -1
- birder/scripts/benchmark.py +90 -21
- birder/scripts/predict.py +1 -0
- birder/scripts/predict_detection.py +18 -11
- birder/scripts/train.py +10 -34
- birder/scripts/train_barlow_twins.py +10 -34
- birder/scripts/train_byol.py +10 -34
- birder/scripts/train_capi.py +10 -35
- birder/scripts/train_data2vec.py +9 -34
- birder/scripts/train_data2vec2.py +9 -34
- birder/scripts/train_detection.py +48 -40
- birder/scripts/train_dino_v1.py +10 -34
- birder/scripts/train_dino_v2.py +9 -34
- birder/scripts/train_dino_v2_dist.py +9 -34
- birder/scripts/train_franca.py +9 -34
- birder/scripts/train_i_jepa.py +9 -34
- birder/scripts/train_ibot.py +9 -34
- birder/scripts/train_kd.py +156 -64
- birder/scripts/train_mim.py +10 -34
- birder/scripts/train_mmcr.py +10 -34
- birder/scripts/train_rotnet.py +10 -34
- birder/scripts/train_simclr.py +10 -34
- birder/scripts/train_vicreg.py +10 -34
- birder/tools/auto_anchors.py +20 -1
- birder/tools/pack.py +172 -103
- birder/tools/show_det_iterator.py +10 -1
- birder/version.py +1 -1
- {birder-0.2.2.dist-info → birder-0.2.3.dist-info}/METADATA +3 -3
- {birder-0.2.2.dist-info → birder-0.2.3.dist-info}/RECORD +61 -55
- {birder-0.2.2.dist-info → birder-0.2.3.dist-info}/WHEEL +0 -0
- {birder-0.2.2.dist-info → birder-0.2.3.dist-info}/entry_points.txt +0 -0
- {birder-0.2.2.dist-info → birder-0.2.3.dist-info}/licenses/LICENSE +0 -0
- {birder-0.2.2.dist-info → birder-0.2.3.dist-info}/top_level.txt +0 -0
birder/net/lit_v2.py
ADDED
@@ -0,0 +1,436 @@
+"""
+LIT v2, adapted from
+https://github.com/ziplab/LITv2/blob/main/classification/models/litv2.py
+
+Paper "Fast Vision Transformers with HiLo Attention", https://arxiv.org/abs/2205.13213
+
+Generated by Claude Code Opus 4.5
+"""
+
+# Reference license: Apache-2.0
+
+import math
+from collections import OrderedDict
+from typing import Any
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torchvision.ops import Permute
+from torchvision.ops import StochasticDepth
+
+from birder.model_registry import registry
+from birder.net.base import DetectorBackbone
+from birder.net.lit_v1 import DeformablePatchMerging
+from birder.net.lit_v1 import IdentityDownsample
+
+
+class DepthwiseMLP(nn.Module):
+    def __init__(self, in_features: int, hidden_features: int) -> None:
+        super().__init__()
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.dwconv = nn.Conv2d(
+            hidden_features, hidden_features, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=hidden_features
+        )
+        self.act = nn.GELU()
+        self.fc2 = nn.Linear(hidden_features, in_features)
+
+    def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
+        x = self.fc1(x)
+
+        (B, N, C) = x.size()
+        x = x.reshape(B, H, W, C).permute(0, 3, 1, 2).contiguous()
+        x = self.dwconv(x)
+        x = x.permute(0, 2, 3, 1).reshape(B, N, C)
+        x = self.act(x)
+        x = self.fc2(x)
+
+        return x
+
+
+class DepthwiseMLPBlock(nn.Module):
+    def __init__(self, dim: int, mlp_ratio: float, drop_path: float) -> None:
+        super().__init__()
+        self.norm = nn.LayerNorm(dim)
+        self.mlp = DepthwiseMLP(dim, int(dim * mlp_ratio))
+        self.drop_path = StochasticDepth(drop_path, mode="row")
+
+    def forward(self, x: torch.Tensor, resolution: tuple[int, int]) -> torch.Tensor:
+        (H, W) = resolution
+        return x + self.drop_path(self.mlp(self.norm(x), H, W))
+
+
+class HiLoAttention(nn.Module):
+    """
+    HiLo Attention: High-frequency local attention + Low-frequency global attention
+
+    Hi-Fi (High frequency): Local window attention
+    Lo-Fi (Low frequency): Global attention with average pooling
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int,
+        window_size: int,
+        alpha: float,
+    ) -> None:
+        super().__init__()
+        assert dim % num_heads == 0, "dim must be divisible by num_heads"
+
+        self.window_size = window_size
+        self.head_dim = dim // num_heads
+        self.scale = self.head_dim**-0.5
+
+        # Split heads between Lo-Fi (global) and Hi-Fi (local)
+        self.l_heads = int(num_heads * alpha)  # Lo-Fi heads
+        self.h_heads = num_heads - self.l_heads  # Hi-Fi heads
+        self.l_dim = self.l_heads * self.head_dim
+        self.h_dim = self.h_heads * self.head_dim
+        self.head_dim = self.head_dim
+
+        # ws == 1 is equal to standard multi-head self-attention
+        if window_size == 1:
+            self.h_heads = 0
+            self.h_dim = 0
+            self.l_heads = num_heads
+            self.l_dim = dim
+
+        # Lo-Fi: Global attention with pooling
+        if self.l_heads > 0:
+            if window_size > 1:
+                self.sr = nn.AvgPool2d(kernel_size=(window_size, window_size), stride=(window_size, window_size))
+            else:
+                self.sr = nn.Identity()
+
+            self.l_q = nn.Linear(dim, self.l_dim)
+            self.l_kv = nn.Linear(dim, self.l_dim * 2)
+            self.l_proj = nn.Linear(self.l_dim, self.l_dim)
+        else:
+            self.l_q = nn.Identity()
+            self.l_kv = nn.Identity()
+            self.l_proj = nn.Identity()
+
+        # Hi-Fi: Local window attention
+        if self.h_heads > 0:
+            self.h_qkv = nn.Linear(dim, self.h_dim * 3)
+            self.h_proj = nn.Linear(self.h_dim, self.h_dim)
+        else:
+            self.h_qkv = nn.Identity()
+            self.h_proj = nn.Identity()
+
+    def _lofi(self, x: torch.Tensor) -> torch.Tensor:
+        (B, H, W, C) = x.size()
+
+        q = self.l_q(x).reshape(B, H * W, self.l_heads, self.head_dim).permute(0, 2, 1, 3)
+
+        # Spatial reduction for k, v
+        if self.window_size > 1:
+            x = x.permute(0, 3, 1, 2)
+            x = self.sr(x).reshape(B, C, -1).permute(0, 2, 1)
+            kv = self.l_kv(x).reshape(B, -1, 2, self.l_heads, self.head_dim).permute(2, 0, 3, 1, 4)
+        else:
+            kv = self.l_kv(x).reshape(B, -1, 2, self.l_heads, self.head_dim).permute(2, 0, 3, 1, 4)
+
+        (k, v) = kv.unbind(0)
+
+        attn = (q @ k.transpose(-2, -1)) * self.scale
+        attn = F.softmax(attn, dim=-1)
+
+        x = (attn @ v).transpose(1, 2).reshape(B, H, W, self.l_dim)
+        x = self.l_proj(x)
+
+        return x
+
+    def _hifi(self, x: torch.Tensor) -> torch.Tensor:
+        (B, H, W, _) = x.size()
+        ws = self.window_size
+
+        # Pad if needed
+        pad_h = (ws - H % ws) % ws
+        pad_w = (ws - W % ws) % ws
+        if pad_h > 0 or pad_w > 0:
+            x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
+
+        (_, h_pad, w_pad, _) = x.size()
+        h_groups = h_pad // ws
+        w_groups = w_pad // ws
+        total_groups = h_groups * w_groups
+
+        x = x.reshape(B, h_groups, ws, w_groups, ws, -1).transpose(2, 3)
+
+        qkv = self.h_qkv(x).reshape(B, total_groups, -1, 3, self.h_heads, self.head_dim).permute(3, 0, 1, 4, 2, 5)
+        (q, k, v) = qkv.unbind(0)
+
+        attn = (q @ k.transpose(-2, -1)) * self.scale
+        attn = F.softmax(attn, dim=-1)
+
+        x = (attn @ v).transpose(2, 3).reshape(B, h_groups, w_groups, ws, ws, self.h_dim)
+        x = x.transpose(2, 3).reshape(B, h_pad, w_pad, self.h_dim)
+        x = self.h_proj(x)
+
+        # Remove padding
+        if pad_h > 0 or pad_w > 0:
+            x = x[:, :H, :W, :].contiguous()
+
+        return x
+
+    def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
+        (B, N, C) = x.size()
+        x = x.reshape(B, H, W, C)
+
+        if self.h_heads == 0:
+            x = self._lofi(x)
+            return x.reshape(B, N, C)
+
+        if self.l_heads == 0:
+            x = self._hifi(x)
+            return x.reshape(B, N, C)
+
+        # Process both branches and concatenate
+        hifi_out = self._hifi(x)
+        lofi_out = self._lofi(x)
+
+        x = torch.concat((hifi_out, lofi_out), dim=-1)
+        return x.reshape(B, N, C)
+
+
+class HiLoBlock(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int,
+        window_size: int,
+        alpha: float,
+        mlp_ratio: float,
+        drop_path: float,
+    ) -> None:
+        super().__init__()
+        self.norm1 = nn.LayerNorm(dim)
+        self.attn = HiLoAttention(dim, num_heads, window_size, alpha)
+        self.drop_path1 = StochasticDepth(drop_path, mode="row")
+        self.norm2 = nn.LayerNorm(dim)
+        self.mlp = DepthwiseMLP(dim, int(dim * mlp_ratio))
+        self.drop_path2 = StochasticDepth(drop_path, mode="row")
+
+    def forward(self, x: torch.Tensor, resolution: tuple[int, int]) -> torch.Tensor:
+        (H, W) = resolution
+        x = x + self.drop_path1(self.attn(self.norm1(x), H, W))
+        x = x + self.drop_path2(self.mlp(self.norm2(x), H, W))
+        return x
+
+
+class LITStage(nn.Module):
+    def __init__(
+        self,
+        in_dim: int,
+        out_dim: int,
+        resolution: tuple[int, int],
+        depth: int,
+        num_heads: int,
+        window_size: int,
+        alpha: float,
+        mlp_ratio: float,
+        downsample: bool,
+        drop_path: list[float],
+    ) -> None:
+        super().__init__()
+        if downsample is True:
+            self.downsample = DeformablePatchMerging(in_dim, out_dim)
+            resolution = (resolution[0] // 2, resolution[1] // 2)
+        else:
+            self.downsample = IdentityDownsample()
+
+        blocks: list[nn.Module] = []
+        for i in range(depth):
+            if window_size > 0:
+                blocks.append(HiLoBlock(out_dim, num_heads, window_size, alpha, mlp_ratio, drop_path[i]))
+            else:
+                blocks.append(DepthwiseMLPBlock(out_dim, mlp_ratio, drop_path[i]))
+
+        self.blocks = nn.ModuleList(blocks)
+
+    def forward(self, x: torch.Tensor, input_resolution: tuple[int, int]) -> tuple[torch.Tensor, int, int]:
+        (x, H, W) = self.downsample(x, input_resolution)
+        for block in self.blocks:
+            x = block(x, (H, W))
+
+        return (x, H, W)
+
+
+# pylint: disable=invalid-name
+class LIT_v2(DetectorBackbone):
+    block_group_regex = r"body\.stage(\d+)\.blocks\.(\d+)"
+
+    # pylint:disable=too-many-locals
+    def __init__(
+        self,
+        input_channels: int,
+        num_classes: int,
+        *,
+        config: Optional[dict[str, Any]] = None,
+        size: Optional[tuple[int, int]] = None,
+    ) -> None:
+        super().__init__(input_channels, num_classes, config=config, size=size)
+        assert self.config is not None, "must set config"
+
+        patch_size = 4
+        embed_dim: int = self.config["embed_dim"]
+        depths: list[int] = self.config["depths"]
+        num_heads: list[int] = self.config["num_heads"]
+        local_ws: list[int] = self.config["local_ws"]
+        alpha: float = self.config["alpha"]
+        drop_path_rate: float = self.config["drop_path_rate"]
+
+        num_stages = len(depths)
+
+        self.stem = nn.Sequential(
+            nn.Conv2d(
+                self.input_channels,
+                embed_dim,
+                kernel_size=(patch_size, patch_size),
+                stride=(patch_size, patch_size),
+                padding=(0, 0),
+                bias=True,
+            ),
+            Permute([0, 2, 3, 1]),
+            nn.LayerNorm(embed_dim),
+        )
+
+        # Stochastic depth
+        dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
+
+        stages: OrderedDict[str, nn.Module] = OrderedDict()
+        return_channels: list[int] = []
+        prev_dim = embed_dim
+        resolution = (self.size[0] // patch_size, self.size[1] // patch_size)
+        for i_stage in range(num_stages):
+            in_dim = prev_dim
+            out_dim = in_dim * 2 if i_stage > 0 else in_dim
+            stage = LITStage(
+                in_dim,
+                out_dim,
+                resolution,
+                depth=depths[i_stage],
+                num_heads=num_heads[i_stage],
+                window_size=local_ws[i_stage],
+                alpha=alpha,
+                mlp_ratio=4.0,
+                downsample=i_stage > 0,
+                drop_path=dpr[i_stage],
+            )
+            stages[f"stage{i_stage + 1}"] = stage
+
+            if i_stage > 0:
+                resolution = (resolution[0] // 2, resolution[1] // 2)
+
+            prev_dim = out_dim
+            return_channels.append(out_dim)
+
+        num_features = embed_dim * (2 ** (num_stages - 1))
+        self.body = nn.ModuleDict(stages)
+        self.features = nn.Sequential(
+            nn.LayerNorm(num_features),
+            Permute([0, 2, 1]),
+            nn.AdaptiveAvgPool1d(output_size=1),
+            nn.Flatten(1),
+        )
+        self.return_channels = return_channels
+        self.embedding_size = num_features
+        self.classifier = self.create_classifier()
+
+        # Weight initialization
+        for name, m in self.named_modules():
+            if isinstance(m, nn.Linear):
+                nn.init.trunc_normal_(m.weight, std=0.02)
+                if m.bias is not None:
+                    nn.init.zeros_(m.bias)
+            elif isinstance(m, nn.LayerNorm):
+                nn.init.ones_(m.weight)
+                nn.init.zeros_(m.bias)
+            elif isinstance(m, nn.Conv2d):
+                if name.endswith("offset_conv") is True:
+                    continue
+
+                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                fan_out //= m.groups
+                nn.init.normal_(m.weight, mean=0.0, std=math.sqrt(2.0 / fan_out))
+                if m.bias is not None:
+                    nn.init.zeros_(m.bias)
+
+    def detection_features(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
+        x = self.stem(x)
+        (B, H, W, C) = x.size()
+        x = x.reshape(B, H * W, C)
+
+        out = {}
+        for name, stage in self.body.items():
+            (x, H, W) = stage(x, (H, W))
+            if name in self.return_stages:
+                features = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
+                out[name] = features
+
+        return out
+
+    def freeze_stages(self, up_to_stage: int) -> None:
+        for param in self.stem.parameters():
+            param.requires_grad = False
+
+        for idx, stage in enumerate(self.body.values()):
+            if idx >= up_to_stage:
+                break
+
+            for param in stage.parameters():
+                param.requires_grad = False
+
+    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.stem(x)
+        (B, H, W, C) = x.size()
+        x = x.reshape(B, H * W, C)
+        for stage in self.body.values():
+            (x, H, W) = stage(x, (H, W))
+
+        return x
+
+    def embedding(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.forward_features(x)
+        return self.features(x)
+
+
+registry.register_model_config(
+    "lit_v2_s",
+    LIT_v2,
+    config={
+        "embed_dim": 96,
+        "depths": [2, 2, 6, 2],
+        "num_heads": [3, 6, 12, 24],
+        "local_ws": [0, 0, 2, 1],
+        "alpha": 0.9,
+        "drop_path_rate": 0.2,
+    },
+)
+registry.register_model_config(
+    "lit_v2_m",
+    LIT_v2,
+    config={
+        "embed_dim": 96,
+        "depths": [2, 2, 18, 2],
+        "num_heads": [3, 6, 12, 24],
+        "local_ws": [0, 0, 2, 1],
+        "alpha": 0.9,
+        "drop_path_rate": 0.3,
+    },
+)
+registry.register_model_config(
+    "lit_v2_b",
+    LIT_v2,
+    config={
+        "embed_dim": 128,
+        "depths": [2, 2, 18, 2],
+        "num_heads": [4, 8, 16, 32],
+        "local_ws": [0, 0, 2, 1],
+        "alpha": 0.9,
+        "drop_path_rate": 0.5,
+    },
+)
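Usage note: the new backbone follows the same DetectorBackbone interface as the rest of birder.net. The sketch below is illustrative rather than part of the diff; the constructor signature, embedding(), and detection_features() are taken verbatim from the file above, the config dict is copied from the lit_v2_s registration, and the input size and class count are arbitrary examples.

import torch

from birder.net.lit_v2 import LIT_v2

# Config copied from the "lit_v2_s" registration; any size divisible by the
# patch size (4) and the three 2x stage downsamples works here
net = LIT_v2(
    input_channels=3,
    num_classes=10,
    config={
        "embed_dim": 96,
        "depths": [2, 2, 6, 2],
        "num_heads": [3, 6, 12, 24],
        "local_ws": [0, 0, 2, 1],
        "alpha": 0.9,
        "drop_path_rate": 0.2,
    },
    size=(256, 256),
)
embedding = net.embedding(torch.rand(1, 3, 256, 256))  # (1, 768), since num_features = embed_dim * 2**3
features = net.detection_features(torch.rand(1, 3, 256, 256))  # NCHW maps for the stages in return_stages

With local_ws = [0, 0, 2, 1], the first two stages run pure depthwise-conv MLP blocks, stage 3 splits its heads between windowed Hi-Fi and pooled Lo-Fi attention, and stage 4 (window size 1) collapses to standard global self-attention via the window_size == 1 branch in HiLoAttention.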
birder/net/mobilenet_v4_hybrid.py
CHANGED
@@ -491,7 +491,7 @@ registry.register_weights(
         "formats": {
             "pt": {
                 "file_size": 39.7,
-                "sha256": "
+                "sha256": "d7d76733e0116d351bf8aafc563659eab7bea02174a02c10fba8eb3a64ea87e1",
             }
         },
         "net": {"network": "mobilenet_v4_hybrid_m", "tag": "il-common"},
birder/net/resnet_v1.py
CHANGED
birder/net/resnext.py
CHANGED
@@ -30,6 +30,7 @@ class ResidualBlock(nn.Module):
         base_width: int,
         expansion: int,
         squeeze_excitation: bool,
+        avg_down: bool,
     ) -> None:
         super().__init__()
         width = int(out_channels * (base_width / 64.0)) * groups
@@ -62,20 +63,34 @@ class ResidualBlock(nn.Module):
             nn.BatchNorm2d(out_channels * expansion),
         )
 
-        if in_channels == out_channels * expansion:
+        if in_channels == out_channels * expansion and stride == (1, 1):
             self.block2 = nn.Identity()
         else:
-            self.block2 = nn.Sequential(
-                nn.Conv2d(
-                    in_channels,
-                    out_channels * expansion,
-                    kernel_size=(1, 1),
-                    stride=stride,
-                    padding=(0, 0),
-                    bias=False,
-                ),
-                nn.BatchNorm2d(out_channels * expansion),
-            )
+            if avg_down is True and stride != (1, 1):
+                self.block2 = nn.Sequential(
+                    nn.AvgPool2d(kernel_size=2, stride=stride, ceil_mode=True, count_include_pad=False),
+                    nn.Conv2d(
+                        in_channels,
+                        out_channels * expansion,
+                        kernel_size=(1, 1),
+                        stride=(1, 1),
+                        padding=(0, 0),
+                        bias=False,
+                    ),
+                    nn.BatchNorm2d(out_channels * expansion),
+                )
+            else:
+                self.block2 = nn.Sequential(
+                    nn.Conv2d(
+                        in_channels,
+                        out_channels * expansion,
+                        kernel_size=(1, 1),
+                        stride=stride,
+                        padding=(0, 0),
+                        bias=False,
+                    ),
+                    nn.BatchNorm2d(out_channels * expansion),
+                )
 
         self.relu = nn.ReLU(inplace=True)
         if squeeze_excitation is True:
@@ -107,23 +122,35 @@ class ResNeXt(DetectorBackbone):
         super().__init__(input_channels, num_classes, config=config, size=size)
         assert self.config is not None, "must set config"
 
-        groups = 32
-        base_width = 4
         expansion = 4
+        groups: int = self.config.get("groups", 32)
+        base_width: int = self.config.get("base_width", 4)
         filter_list = [64, 128, 256, 512]
         units: list[int] = self.config["units"]
+        deep_stem: bool = self.config.get("deep_stem", False)
+        avg_down: bool = self.config.get("avg_down", False)
 
-        self.stem = nn.Sequential(
-            Conv2dNormActivation(
-                self.input_channels,
-                filter_list[0],
-                kernel_size=(7, 7),
-                stride=(2, 2),
-                padding=(3, 3),
-                bias=False,
-            ),
-            nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
-        )
+        if deep_stem is True:
+            self.stem = nn.Sequential(
+                Conv2dNormActivation(
+                    self.input_channels, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False
+                ),
+                Conv2dNormActivation(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
+                Conv2dNormActivation(32, filter_list[0], kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
+                nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
+            )
+        else:
+            self.stem = nn.Sequential(
+                Conv2dNormActivation(
+                    self.input_channels,
+                    filter_list[0],
+                    kernel_size=(7, 7),
+                    stride=(2, 2),
+                    padding=(3, 3),
+                    bias=False,
+                ),
+                nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
+            )
 
         # Generate body layers
         in_channels = filter_list[0]
@@ -150,6 +177,7 @@ class ResNeXt(DetectorBackbone):
                        base_width=base_width,
                        expansion=expansion,
                        squeeze_excitation=squeeze_excitation,
+                       avg_down=avg_down,
                    )
                )
            in_channels = channels * expansion
@@ -209,3 +237,17 @@ class ResNeXt(DetectorBackbone):
 registry.register_model_config("resnext_50", ResNeXt, config={"units": [3, 4, 6, 3]})
 registry.register_model_config("resnext_101", ResNeXt, config={"units": [3, 4, 23, 3]})
 registry.register_model_config("resnext_152", ResNeXt, config={"units": [3, 8, 36, 3]})
+
+registry.register_model_config("resnext_101_32x8", ResNeXt, config={"units": [3, 4, 23, 3], "base_width": 8})
+registry.register_model_config("resnext_101_64x4", ResNeXt, config={"units": [3, 4, 23, 3], "groups": 64})
+
+# ResNeXt-D variants (From: Bag of Tricks for Image Classification with Convolutional Neural Networks)
+registry.register_model_config(
+    "resnext_d_50", ResNeXt, config={"units": [3, 4, 6, 3], "deep_stem": True, "avg_down": True}
+)
+registry.register_model_config(
+    "resnext_d_101", ResNeXt, config={"units": [3, 4, 23, 3], "deep_stem": True, "avg_down": True}
+)
+registry.register_model_config(
+    "resnext_d_152", ResNeXt, config={"units": [3, 8, 36, 3], "deep_stem": True, "avg_down": True}
+)
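In the residual block, avg_down swaps the strided 1x1 projection shortcut for a 2x2 average pool followed by a stride-1 1x1 convolution, so downsampling no longer discards three quarters of the shortcut activations; combined with deep_stem (three 3x3 convolutions in place of the single 7x7) this is the ResNet-D recipe from the Bag of Tricks paper cited in the new comment. A minimal sketch of building one of the new variants directly, not part of the diff; it assumes ResNeXt exposes the keyword constructor its own super().__init__ call above implies, and the size and class count are arbitrary examples:

from birder.net.resnext import ResNeXt

# Same config dict as the "resnext_d_50" registration above
resnext_d_50 = ResNeXt(
    input_channels=3,
    num_classes=1000,
    config={"units": [3, 4, 6, 3], "deep_stem": True, "avg_down": True},
    size=(224, 224),
)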
birder/net/se_resnet_v1.py
CHANGED
@@ -57,3 +57,49 @@ registry.register_model_config(
     SE_ResNet_v1,
     config={"bottle_neck": True, "filter_list": [64, 256, 512, 1024, 2048], "units": [3, 30, 48, 8]},
 )
+
+# SE-ResNet-D variants (From: Bag of Tricks for Image Classification with Convolutional Neural Networks)
+registry.register_model_config(
+    "se_resnet_d_50",
+    SE_ResNet_v1,
+    config={
+        "bottle_neck": True,
+        "filter_list": [64, 256, 512, 1024, 2048],
+        "units": [3, 4, 6, 3],
+        "deep_stem": True,
+        "avg_down": True,
+    },
+)
+registry.register_model_config(
+    "se_resnet_d_101",
+    SE_ResNet_v1,
+    config={
+        "bottle_neck": True,
+        "filter_list": [64, 256, 512, 1024, 2048],
+        "units": [3, 4, 23, 3],
+        "deep_stem": True,
+        "avg_down": True,
+    },
+)
+registry.register_model_config(
+    "se_resnet_d_152",
+    SE_ResNet_v1,
+    config={
+        "bottle_neck": True,
+        "filter_list": [64, 256, 512, 1024, 2048],
+        "units": [3, 8, 36, 3],
+        "deep_stem": True,
+        "avg_down": True,
+    },
+)
+registry.register_model_config(
+    "se_resnet_d_200",
+    SE_ResNet_v1,
+    config={
+        "bottle_neck": True,
+        "filter_list": [64, 256, 512, 1024, 2048],
+        "units": [3, 24, 36, 3],
+        "deep_stem": True,
+        "avg_down": True,
+    },
+)
birder/net/se_resnext.py
CHANGED
@@ -25,3 +25,6 @@ class SE_ResNeXt(ResNeXt):
 registry.register_model_config("se_resnext_50", SE_ResNeXt, config={"units": [3, 4, 6, 3]})
 registry.register_model_config("se_resnext_101", SE_ResNeXt, config={"units": [3, 4, 23, 3]})
 registry.register_model_config("se_resnext_152", SE_ResNeXt, config={"units": [3, 8, 36, 3]})
+
+registry.register_model_config("se_resnext_101_32x8", SE_ResNeXt, config={"units": [3, 4, 23, 3], "base_width": 8})
+registry.register_model_config("se_resnext_101_64x4", SE_ResNeXt, config={"units": [3, 4, 23, 3], "groups": 64})
birder/net/simple_vit.py
CHANGED
@@ -79,7 +79,7 @@ class Simple_ViT(PreTrainEncoder, MaskedTokenOmissionMixin):
             dim=hidden_dim,
             num_special_tokens=self.num_special_tokens,
         )
-        self.pos_embedding = nn.
+        self.pos_embedding = nn.Buffer(pos_embedding)
 
         self.encoder = Encoder(num_layers, num_heads, hidden_dim, mlp_dim, dropout=0.0, attention_dropout=0.0, dpr=dpr)
         self.norm = nn.LayerNorm(hidden_dim, eps=1e-6)
@@ -203,7 +203,7 @@ class Simple_ViT(PreTrainEncoder, MaskedTokenOmissionMixin):
             dim=self.hidden_dim,
             num_special_tokens=self.num_special_tokens,
         )
-        self.pos_embedding = nn.
+        self.pos_embedding = nn.Buffer(pos_embedding)
 
     def set_causal_attention(self, is_causal: bool = True) -> None:
         self.encoder.set_causal_attention(is_causal)
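Both hunks assign the positional embedding through nn.Buffer. For context: assigning an nn.Buffer to a module attribute registers the tensor as a buffer, the declarative equivalent of register_buffer(), so it is stored in the state_dict and moves with the module but receives no gradient. A minimal sketch, not from the diff:

import torch
from torch import nn


class Toy(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # Declarative equivalent of: self.register_buffer("pos", torch.zeros(1, 4, 8))
        self.pos = nn.Buffer(torch.zeros(1, 4, 8))


net = Toy()
print("pos" in net.state_dict())  # True: buffers are saved and loaded with the module
print(len(list(net.parameters())))  # 0: a buffer is not a parameter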
birder/net/vit.py
CHANGED
@@ -1588,21 +1588,6 @@ registry.register_weights(
         "net": {"network": "vit_l16", "tag": "mim"},
     },
 )
-registry.register_weights(
-    "vit_l16_mim-eu-common",
-    {
-        "url": "https://huggingface.co/birder-project/vit_l16_mim-eu-common/resolve/main",
-        "description": "ViT l16 model with MIM pretraining, then fine-tuned on the eu-common dataset",
-        "resolution": (256, 256),
-        "formats": {
-            "pt": {
-                "file_size": 1160.1,
-                "sha256": "3b7235b90f76fb1e0e36d4c4111777a4cc4e4500552fe840c51170b208310d16",
-            },
-        },
-        "net": {"network": "vit_l16", "tag": "mim-eu-common"},
-    },
-)
 registry.register_weights(  # BioCLIP v2: https://arxiv.org/abs/2505.23883
     "vit_l14_pn_bioclip-v2",
     {