birder 0.2.2__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- birder/common/lib.py +2 -9
- birder/common/training_cli.py +18 -0
- birder/common/training_utils.py +123 -10
- birder/data/collators/detection.py +10 -3
- birder/data/datasets/coco.py +8 -10
- birder/data/transforms/detection.py +30 -13
- birder/inference/detection.py +108 -4
- birder/inference/wbf.py +226 -0
- birder/net/__init__.py +8 -0
- birder/net/detection/efficientdet.py +65 -86
- birder/net/detection/rt_detr_v1.py +1 -0
- birder/net/detection/yolo_anchors.py +205 -0
- birder/net/detection/yolo_v2.py +25 -24
- birder/net/detection/yolo_v3.py +39 -40
- birder/net/detection/yolo_v4.py +28 -26
- birder/net/detection/yolo_v4_tiny.py +24 -20
- birder/net/fasternet.py +1 -1
- birder/net/gc_vit.py +671 -0
- birder/net/lit_v1.py +472 -0
- birder/net/lit_v1_tiny.py +342 -0
- birder/net/lit_v2.py +436 -0
- birder/net/mobilenet_v4_hybrid.py +1 -1
- birder/net/resnet_v1.py +1 -1
- birder/net/resnext.py +67 -25
- birder/net/se_resnet_v1.py +46 -0
- birder/net/se_resnext.py +3 -0
- birder/net/simple_vit.py +2 -2
- birder/net/vit.py +0 -15
- birder/net/vovnet_v2.py +31 -1
- birder/scripts/benchmark.py +90 -21
- birder/scripts/predict.py +1 -0
- birder/scripts/predict_detection.py +18 -11
- birder/scripts/train.py +10 -34
- birder/scripts/train_barlow_twins.py +10 -34
- birder/scripts/train_byol.py +10 -34
- birder/scripts/train_capi.py +10 -35
- birder/scripts/train_data2vec.py +9 -34
- birder/scripts/train_data2vec2.py +9 -34
- birder/scripts/train_detection.py +48 -40
- birder/scripts/train_dino_v1.py +10 -34
- birder/scripts/train_dino_v2.py +9 -34
- birder/scripts/train_dino_v2_dist.py +9 -34
- birder/scripts/train_franca.py +9 -34
- birder/scripts/train_i_jepa.py +9 -34
- birder/scripts/train_ibot.py +9 -34
- birder/scripts/train_kd.py +156 -64
- birder/scripts/train_mim.py +10 -34
- birder/scripts/train_mmcr.py +10 -34
- birder/scripts/train_rotnet.py +10 -34
- birder/scripts/train_simclr.py +10 -34
- birder/scripts/train_vicreg.py +10 -34
- birder/tools/auto_anchors.py +20 -1
- birder/tools/pack.py +172 -103
- birder/tools/show_det_iterator.py +10 -1
- birder/version.py +1 -1
- {birder-0.2.2.dist-info → birder-0.2.3.dist-info}/METADATA +3 -3
- {birder-0.2.2.dist-info → birder-0.2.3.dist-info}/RECORD +61 -55
- {birder-0.2.2.dist-info → birder-0.2.3.dist-info}/WHEEL +0 -0
- {birder-0.2.2.dist-info → birder-0.2.3.dist-info}/entry_points.txt +0 -0
- {birder-0.2.2.dist-info → birder-0.2.3.dist-info}/licenses/LICENSE +0 -0
- {birder-0.2.2.dist-info → birder-0.2.3.dist-info}/top_level.txt +0 -0
birder/net/lit_v1_tiny.py

@@ -0,0 +1,342 @@
+"""
+LIT v1 Tiny, adapted from
+https://github.com/ziplab/LIT/blob/main/classification/code_for_lit_ti/lit.py
+
+Paper "Less is More: Pay Less Attention in Vision Transformers", https://arxiv.org/abs/2105.14217
+
+Generated by Claude Code Opus 4.5
+"""
+
+# Reference license: Apache-2.0
+
+import math
+from collections import OrderedDict
+from typing import Any
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torchvision.ops import Permute
+from torchvision.ops import StochasticDepth
+
+from birder.model_registry import registry
+from birder.net.base import DetectorBackbone
+from birder.net.lit_v1 import MLP
+from birder.net.lit_v1 import DeformablePatchMerging
+from birder.net.lit_v1 import IdentityDownsample
+from birder.net.vit import adjust_position_embedding
+
+
+class MLPBlock(nn.Module):
+    def __init__(self, dim: int, mlp_ratio: float, drop_path: float) -> None:
+        super().__init__()
+        self.norm = nn.LayerNorm(dim, eps=1e-6)
+        self.mlp = MLP(dim, int(dim * mlp_ratio))
+        self.drop_path = StochasticDepth(drop_path, mode="row")
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return x + self.drop_path(self.mlp(self.norm(x)))
+
+
+class Attention(nn.Module):
+    def __init__(self, dim: int, num_heads: int) -> None:
+        super().__init__()
+        self.num_heads = num_heads
+        self.scale = (dim // num_heads) ** -0.5
+        self.qkv = nn.Linear(dim, dim * 3, bias=True)
+        self.proj = nn.Linear(dim, dim)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        (B, N, C) = x.size()
+        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+        (q, k, v) = qkv.unbind(0)
+
+        attn = (q @ k.transpose(-2, -1)) * self.scale
+        attn = F.softmax(attn, dim=-1)
+
+        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+        x = self.proj(x)
+
+        return x
+
+
+class ViTBlock(nn.Module):
+    def __init__(self, dim: int, num_heads: int, mlp_ratio: float, drop_path: float) -> None:
+        super().__init__()
+        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
+        self.attn = Attention(dim, num_heads)
+        self.drop_path1 = StochasticDepth(drop_path, mode="row")
+
+        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
+        self.mlp = MLP(dim, int(dim * mlp_ratio))
+        self.drop_path2 = StochasticDepth(drop_path, mode="row")
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = x + self.drop_path1(self.attn(self.norm1(x)))
+        x = x + self.drop_path2(self.mlp(self.norm2(x)))
+
+        return x
+
+
+class LITStage(nn.Module):
+    def __init__(
+        self,
+        in_dim: int,
+        out_dim: int,
+        input_resolution: tuple[int, int],
+        depth: int,
+        num_heads: int,
+        mlp_ratio: float,
+        has_msa: bool,
+        downsample: bool,
+        use_cls_token: bool,
+        drop_path: list[float],
+    ) -> None:
+        super().__init__()
+        self.dynamic_size = False
+        self.input_resolution = input_resolution
+        self.downsample: nn.Module
+        if downsample is True:
+            self.downsample = DeformablePatchMerging(in_dim, out_dim)
+        else:
+            self.downsample = IdentityDownsample()
+
+        blocks: list[nn.Module] = []
+        for i in range(depth):
+            if has_msa is True:
+                blocks.append(ViTBlock(out_dim, num_heads, mlp_ratio, drop_path[i]))
+            else:
+                blocks.append(MLPBlock(out_dim, mlp_ratio, drop_path[i]))
+
+        self.blocks = nn.ModuleList(blocks)
+
+        num_tokens = input_resolution[0] * input_resolution[1]
+        if use_cls_token is True:
+            self.cls_token = nn.Parameter(torch.zeros(1, 1, out_dim))
+            nn.init.trunc_normal_(self.cls_token, std=0.02)
+            num_tokens += 1
+        else:
+            self.cls_token = None
+
+        self.pos_embed = nn.Parameter(torch.zeros(1, num_tokens, out_dim))
+        nn.init.trunc_normal_(self.pos_embed, std=0.02)
+
+    def set_dynamic_size(self, dynamic_size: bool = True) -> None:
+        self.dynamic_size = dynamic_size
+
+    def _get_pos_embed(self, H: int, W: int) -> torch.Tensor:
+        if self.dynamic_size is False or (H == self.input_resolution[0] and W == self.input_resolution[1]):
+            return self.pos_embed
+
+        if self.cls_token is not None:
+            num_prefix_tokens = 1
+        else:
+            num_prefix_tokens = 0
+
+        return adjust_position_embedding(
+            self.pos_embed, self.input_resolution, (H, W), num_prefix_tokens=num_prefix_tokens
+        )
+
+    def forward(self, x: torch.Tensor, input_resolution: tuple[int, int]) -> tuple[torch.Tensor, int, int]:
+        (x, H, W) = self.downsample(x, input_resolution)
+
+        if self.cls_token is not None:
+            cls_tokens = self.cls_token.expand(x.size(0), -1, -1)
+            x = torch.concat((cls_tokens, x), dim=1)
+
+        x = x + self._get_pos_embed(H, W)
+
+        for block in self.blocks:
+            x = block(x)
+
+        return (x, H, W)
+
+
+# pylint: disable=invalid-name
+class LIT_v1_Tiny(DetectorBackbone):
+    block_group_regex = r"body\.stage(\d+)\.blocks\.(\d+)"
+
+    def __init__(
+        self,
+        input_channels: int,
+        num_classes: int,
+        *,
+        config: Optional[dict[str, Any]] = None,
+        size: Optional[tuple[int, int]] = None,
+    ) -> None:
+        super().__init__(input_channels, num_classes, config=config, size=size)
+        assert self.config is not None, "must set config"
+
+        patch_size = 4
+        stage_dims: list[int] = self.config["stage_dims"]
+        depths: list[int] = self.config["depths"]
+        num_heads: list[int] = self.config["num_heads"]
+        mlp_ratios: list[float] = self.config["mlp_ratios"]
+        has_msa: list[bool] = self.config["has_msa"]
+        drop_path_rate: float = self.config["drop_path_rate"]
+
+        num_stages = len(depths)
+
+        self.stem = nn.Sequential(
+            nn.Conv2d(
+                self.input_channels,
+                stage_dims[0],
+                kernel_size=(patch_size, patch_size),
+                stride=(patch_size, patch_size),
+                padding=(0, 0),
+            ),
+            Permute([0, 2, 3, 1]),
+            nn.LayerNorm(stage_dims[0], eps=1e-6),
+        )
+
+        # Stochastic depth
+        dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
+
+        stages: OrderedDict[str, nn.Module] = OrderedDict()
+        return_channels: list[int] = []
+        resolution = (self.size[0] // patch_size, self.size[1] // patch_size)
+
+        for i in range(num_stages):
+            if i > 0:
+                resolution = (resolution[0] // 2, resolution[1] // 2)
+
+            stage = LITStage(
+                stage_dims[i - 1] if i > 0 else stage_dims[0],
+                stage_dims[i],
+                input_resolution=resolution,
+                depth=depths[i],
+                num_heads=num_heads[i],
+                mlp_ratio=mlp_ratios[i],
+                has_msa=has_msa[i],
+                downsample=i > 0,
+                use_cls_token=i == num_stages - 1,
+                drop_path=dpr[i],
+            )
+            stages[f"stage{i + 1}"] = stage
+            return_channels.append(stage_dims[i])
+
+        self.body = nn.ModuleDict(stages)
+        self.norm = nn.LayerNorm(stage_dims[-1], eps=1e-6)
+        self.return_channels = return_channels
+        self.embedding_size = stage_dims[-1]
+        self.classifier = self.create_classifier()
+        self.patch_size = patch_size
+
+        # Weight initialization
+        for name, m in self.named_modules():
+            if isinstance(m, nn.Linear):
+                nn.init.trunc_normal_(m.weight, std=0.02)
+                if m.bias is not None:
+                    nn.init.zeros_(m.bias)
+            elif isinstance(m, nn.LayerNorm):
+                nn.init.ones_(m.weight)
+                nn.init.zeros_(m.bias)
+            elif isinstance(m, nn.Conv2d):
+                if name.endswith("offset_conv") is True:
+                    continue
+
+                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                fan_out //= m.groups
+                nn.init.normal_(m.weight, mean=0.0, std=math.sqrt(2.0 / fan_out))
+                if m.bias is not None:
+                    nn.init.zeros_(m.bias)
+            elif isinstance(m, nn.BatchNorm2d):
+                nn.init.ones_(m.weight)
+                nn.init.zeros_(m.bias)
+
+    def detection_features(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
+        x = self.stem(x)
+        (B, H, W, C) = x.size()
+        x = x.reshape(B, H * W, C)
+
+        out = {}
+        for name, stage in self.body.items():
+            (x, H, W) = stage(x, (H, W))
+            if name in self.return_stages:
+                if stage.cls_token is not None:
+                    spatial_x = x[:, 1:]
+                else:
+                    spatial_x = x
+
+                out[name] = spatial_x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
+
+        return out
+
+    def freeze_stages(self, up_to_stage: int) -> None:
+        for param in self.stem.parameters():
+            param.requires_grad = False
+
+        for idx, stage in enumerate(self.body.values()):
+            if idx >= up_to_stage:
+                break
+
+            for param in stage.parameters():
+                param.requires_grad = False
+
+    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.stem(x)
+        (B, H, W, C) = x.size()
+        x = x.reshape(B, H * W, C)
+        for stage in self.body.values():
+            (x, H, W) = stage(x, (H, W))
+
+        return x
+
+    def embedding(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.forward_features(x)
+        x = self.norm(x)
+        return x[:, 0]
+
+    def set_dynamic_size(self, dynamic_size: bool = True) -> None:
+        super().set_dynamic_size(dynamic_size)
+        for stage in self.body.values():
+            stage.set_dynamic_size(dynamic_size)
+
+    def adjust_size(self, new_size: tuple[int, int]) -> None:
+        if new_size == self.size:
+            return
+
+        super().adjust_size(new_size)
+
+        new_patches_resolution = (new_size[0] // self.patch_size, new_size[1] // self.patch_size)
+
+        (h, w) = new_patches_resolution
+        for stage in self.body.values():
+            if not isinstance(stage.downsample, IdentityDownsample):
+                h = h // 2
+                w = w // 2
+
+            out_resolution = (h, w)
+            if out_resolution == stage.input_resolution:
+                continue
+
+            if stage.cls_token is not None:
+                num_prefix_tokens = 1
+            else:
+                num_prefix_tokens = 0
+
+            with torch.no_grad():
+                pos_embed = adjust_position_embedding(
+                    stage.pos_embed,
+                    stage.input_resolution,
+                    out_resolution,
+                    num_prefix_tokens=num_prefix_tokens,
+                )
+
+            stage.input_resolution = out_resolution
+            stage.pos_embed = nn.Parameter(pos_embed)
+
+
+registry.register_model_config(
+    "lit_v1_t",
+    LIT_v1_Tiny,
+    config={
+        "stage_dims": [64, 128, 320, 512],
+        "depths": [3, 4, 6, 3],
+        "num_heads": [1, 2, 5, 8],
+        "mlp_ratios": [8.0, 8.0, 4.0, 4.0],
+        "has_msa": [False, False, True, True],
+        "drop_path_rate": 0.1,
+    },
+)