birder 0.2.2__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (107)
  1. birder/common/lib.py +2 -9
  2. birder/common/training_cli.py +24 -0
  3. birder/common/training_utils.py +338 -41
  4. birder/data/collators/detection.py +11 -3
  5. birder/data/dataloader/webdataset.py +12 -2
  6. birder/data/datasets/coco.py +8 -10
  7. birder/data/transforms/detection.py +30 -13
  8. birder/inference/detection.py +108 -4
  9. birder/inference/wbf.py +226 -0
  10. birder/kernels/load_kernel.py +16 -11
  11. birder/kernels/soft_nms/soft_nms.cpp +17 -18
  12. birder/net/__init__.py +8 -0
  13. birder/net/cait.py +4 -3
  14. birder/net/convnext_v1.py +5 -0
  15. birder/net/crossformer.py +33 -30
  16. birder/net/crossvit.py +4 -3
  17. birder/net/deit.py +3 -3
  18. birder/net/deit3.py +3 -3
  19. birder/net/detection/deformable_detr.py +2 -5
  20. birder/net/detection/detr.py +2 -5
  21. birder/net/detection/efficientdet.py +67 -93
  22. birder/net/detection/fcos.py +2 -7
  23. birder/net/detection/retinanet.py +2 -7
  24. birder/net/detection/rt_detr_v1.py +2 -0
  25. birder/net/detection/yolo_anchors.py +205 -0
  26. birder/net/detection/yolo_v2.py +25 -24
  27. birder/net/detection/yolo_v3.py +39 -40
  28. birder/net/detection/yolo_v4.py +28 -26
  29. birder/net/detection/yolo_v4_tiny.py +24 -20
  30. birder/net/efficientformer_v1.py +15 -9
  31. birder/net/efficientformer_v2.py +39 -29
  32. birder/net/efficientvit_msft.py +9 -7
  33. birder/net/fasternet.py +1 -1
  34. birder/net/fastvit.py +1 -0
  35. birder/net/flexivit.py +5 -4
  36. birder/net/gc_vit.py +671 -0
  37. birder/net/hiera.py +12 -9
  38. birder/net/hornet.py +9 -7
  39. birder/net/iformer.py +8 -6
  40. birder/net/levit.py +42 -30
  41. birder/net/lit_v1.py +472 -0
  42. birder/net/lit_v1_tiny.py +357 -0
  43. birder/net/lit_v2.py +436 -0
  44. birder/net/maxvit.py +67 -55
  45. birder/net/mobilenet_v4_hybrid.py +1 -1
  46. birder/net/mobileone.py +1 -0
  47. birder/net/mvit_v2.py +13 -12
  48. birder/net/pit.py +4 -3
  49. birder/net/pvt_v1.py +4 -1
  50. birder/net/repghost.py +1 -0
  51. birder/net/repvgg.py +1 -0
  52. birder/net/repvit.py +1 -0
  53. birder/net/resnet_v1.py +1 -1
  54. birder/net/resnext.py +67 -25
  55. birder/net/rope_deit3.py +5 -3
  56. birder/net/rope_flexivit.py +7 -4
  57. birder/net/rope_vit.py +10 -5
  58. birder/net/se_resnet_v1.py +46 -0
  59. birder/net/se_resnext.py +3 -0
  60. birder/net/simple_vit.py +11 -8
  61. birder/net/swin_transformer_v1.py +71 -68
  62. birder/net/swin_transformer_v2.py +38 -31
  63. birder/net/tiny_vit.py +20 -10
  64. birder/net/transnext.py +38 -28
  65. birder/net/vit.py +5 -19
  66. birder/net/vit_parallel.py +5 -4
  67. birder/net/vit_sam.py +38 -37
  68. birder/net/vovnet_v1.py +15 -0
  69. birder/net/vovnet_v2.py +31 -1
  70. birder/ops/msda.py +108 -43
  71. birder/ops/swattention.py +124 -61
  72. birder/results/detection.py +4 -0
  73. birder/scripts/benchmark.py +110 -32
  74. birder/scripts/predict.py +8 -0
  75. birder/scripts/predict_detection.py +18 -11
  76. birder/scripts/train.py +48 -46
  77. birder/scripts/train_barlow_twins.py +44 -45
  78. birder/scripts/train_byol.py +44 -45
  79. birder/scripts/train_capi.py +50 -49
  80. birder/scripts/train_data2vec.py +45 -47
  81. birder/scripts/train_data2vec2.py +45 -47
  82. birder/scripts/train_detection.py +83 -50
  83. birder/scripts/train_dino_v1.py +60 -47
  84. birder/scripts/train_dino_v2.py +86 -52
  85. birder/scripts/train_dino_v2_dist.py +84 -50
  86. birder/scripts/train_franca.py +51 -52
  87. birder/scripts/train_i_jepa.py +45 -47
  88. birder/scripts/train_ibot.py +51 -53
  89. birder/scripts/train_kd.py +194 -76
  90. birder/scripts/train_mim.py +44 -45
  91. birder/scripts/train_mmcr.py +44 -45
  92. birder/scripts/train_rotnet.py +45 -46
  93. birder/scripts/train_simclr.py +44 -45
  94. birder/scripts/train_vicreg.py +44 -45
  95. birder/tools/auto_anchors.py +20 -1
  96. birder/tools/convert_model.py +18 -15
  97. birder/tools/det_results.py +114 -2
  98. birder/tools/pack.py +172 -103
  99. birder/tools/quantize_model.py +73 -67
  100. birder/tools/show_det_iterator.py +10 -1
  101. birder/version.py +1 -1
  102. {birder-0.2.2.dist-info → birder-0.3.0.dist-info}/METADATA +4 -3
  103. {birder-0.2.2.dist-info → birder-0.3.0.dist-info}/RECORD +107 -101
  104. {birder-0.2.2.dist-info → birder-0.3.0.dist-info}/WHEEL +0 -0
  105. {birder-0.2.2.dist-info → birder-0.3.0.dist-info}/entry_points.txt +0 -0
  106. {birder-0.2.2.dist-info → birder-0.3.0.dist-info}/licenses/LICENSE +0 -0
  107. {birder-0.2.2.dist-info → birder-0.3.0.dist-info}/top_level.txt +0 -0
birder/net/lit_v1_tiny.py
@@ -0,0 +1,357 @@
+"""
+LIT v1 Tiny, adapted from
+https://github.com/ziplab/LIT/blob/main/classification/code_for_lit_ti/lit.py
+
+Paper "Less is More: Pay Less Attention in Vision Transformers", https://arxiv.org/abs/2105.14217
+
+Generated by Claude Code Opus 4.5
+"""
+
+# Reference license: Apache-2.0
+
+import math
+from collections import OrderedDict
+from typing import Any
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torchvision.ops import Permute
+from torchvision.ops import StochasticDepth
+
+from birder.model_registry import registry
+from birder.net.base import DetectorBackbone
+from birder.net.lit_v1 import MLP
+from birder.net.lit_v1 import DeformablePatchMerging
+from birder.net.lit_v1 import IdentityDownsample
+from birder.net.vit import adjust_position_embedding
+
+
+class MLPBlock(nn.Module):
+    def __init__(self, dim: int, mlp_ratio: float, drop_path: float) -> None:
+        super().__init__()
+        self.norm = nn.LayerNorm(dim, eps=1e-6)
+        self.mlp = MLP(dim, int(dim * mlp_ratio))
+        self.drop_path = StochasticDepth(drop_path, mode="row")
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return x + self.drop_path(self.mlp(self.norm(x)))
+
+
+class Attention(nn.Module):
+    def __init__(self, dim: int, num_heads: int) -> None:
+        super().__init__()
+        self.num_heads = num_heads
+        self.scale = (dim // num_heads) ** -0.5
+        self.qkv = nn.Linear(dim, dim * 3, bias=True)
+        self.proj = nn.Linear(dim, dim)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        (B, N, C) = x.size()
+        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+        (q, k, v) = qkv.unbind(0)
+
+        attn = (q @ k.transpose(-2, -1)) * self.scale
+        attn = F.softmax(attn, dim=-1)
+
+        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+        x = self.proj(x)
+
+        return x
+
+
+class ViTBlock(nn.Module):
+    def __init__(self, dim: int, num_heads: int, mlp_ratio: float, drop_path: float) -> None:
+        super().__init__()
+        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
+        self.attn = Attention(dim, num_heads)
+        self.drop_path1 = StochasticDepth(drop_path, mode="row")
+
+        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
+        self.mlp = MLP(dim, int(dim * mlp_ratio))
+        self.drop_path2 = StochasticDepth(drop_path, mode="row")
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = x + self.drop_path1(self.attn(self.norm1(x)))
+        x = x + self.drop_path2(self.mlp(self.norm2(x)))
+
+        return x
+
+
+class LITStage(nn.Module):
+    def __init__(
+        self,
+        in_dim: int,
+        out_dim: int,
+        input_resolution: tuple[int, int],
+        depth: int,
+        num_heads: int,
+        mlp_ratio: float,
+        has_msa: bool,
+        downsample: bool,
+        use_cls_token: bool,
+        drop_path: list[float],
+    ) -> None:
+        super().__init__()
+        self.dynamic_size = False
+        self.input_resolution = input_resolution
+        self.downsample: nn.Module
+        if downsample is True:
+            self.downsample = DeformablePatchMerging(in_dim, out_dim)
+        else:
+            self.downsample = IdentityDownsample()
+
+        blocks: list[nn.Module] = []
+        for i in range(depth):
+            if has_msa is True:
+                blocks.append(ViTBlock(out_dim, num_heads, mlp_ratio, drop_path[i]))
+            else:
+                blocks.append(MLPBlock(out_dim, mlp_ratio, drop_path[i]))
+
+        self.blocks = nn.ModuleList(blocks)
+
+        num_tokens = input_resolution[0] * input_resolution[1]
+        if use_cls_token is True:
+            self.cls_token = nn.Parameter(torch.zeros(1, 1, out_dim))
+            nn.init.trunc_normal_(self.cls_token, std=0.02)
+            num_tokens += 1
+        else:
+            self.cls_token = None
+
+        self.pos_embed = nn.Parameter(torch.zeros(1, num_tokens, out_dim))
+        nn.init.trunc_normal_(self.pos_embed, std=0.02)
+
+    def set_dynamic_size(self, dynamic_size: bool = True) -> None:
+        self.dynamic_size = dynamic_size
+
+    def _get_pos_embed(self, H: int, W: int) -> torch.Tensor:
+        if self.dynamic_size is False or (H == self.input_resolution[0] and W == self.input_resolution[1]):
+            return self.pos_embed
+
+        if self.cls_token is not None:
+            num_prefix_tokens = 1
+        else:
+            num_prefix_tokens = 0
+
+        return adjust_position_embedding(
+            self.pos_embed, self.input_resolution, (H, W), num_prefix_tokens=num_prefix_tokens
+        )
+
+    def forward(self, x: torch.Tensor, input_resolution: tuple[int, int]) -> tuple[torch.Tensor, int, int]:
+        (x, H, W) = self.downsample(x, input_resolution)
+
+        if self.cls_token is not None:
+            cls_tokens = self.cls_token.expand(x.size(0), -1, -1)
+            x = torch.concat((cls_tokens, x), dim=1)
+
+        x = x + self._get_pos_embed(H, W)
+
+        for block in self.blocks:
+            x = block(x)
+
+        return (x, H, W)
+
+
+# pylint: disable=invalid-name
+class LIT_v1_Tiny(DetectorBackbone):
+    block_group_regex = r"body\.stage(\d+)\.blocks\.(\d+)"
+
+    def __init__(
+        self,
+        input_channels: int,
+        num_classes: int,
+        *,
+        config: Optional[dict[str, Any]] = None,
+        size: Optional[tuple[int, int]] = None,
+    ) -> None:
+        super().__init__(input_channels, num_classes, config=config, size=size)
+        assert self.config is not None, "must set config"
+
+        patch_size = 4
+        stage_dims: list[int] = self.config["stage_dims"]
+        depths: list[int] = self.config["depths"]
+        num_heads: list[int] = self.config["num_heads"]
+        mlp_ratios: list[float] = self.config["mlp_ratios"]
+        has_msa: list[bool] = self.config["has_msa"]
+        drop_path_rate: float = self.config["drop_path_rate"]
+
+        num_stages = len(depths)
+
+        self.stem = nn.Sequential(
+            nn.Conv2d(
+                self.input_channels,
+                stage_dims[0],
+                kernel_size=(patch_size, patch_size),
+                stride=(patch_size, patch_size),
+                padding=(0, 0),
+            ),
+            Permute([0, 2, 3, 1]),
+            nn.LayerNorm(stage_dims[0], eps=1e-6),
+        )
+
+        # Stochastic depth
+        dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
+
+        stages: OrderedDict[str, nn.Module] = OrderedDict()
+        return_channels: list[int] = []
+        resolution = (self.size[0] // patch_size, self.size[1] // patch_size)
+
+        for i in range(num_stages):
+            if i > 0:
+                resolution = (resolution[0] // 2, resolution[1] // 2)
+
+            stage = LITStage(
+                stage_dims[i - 1] if i > 0 else stage_dims[0],
+                stage_dims[i],
+                input_resolution=resolution,
+                depth=depths[i],
+                num_heads=num_heads[i],
+                mlp_ratio=mlp_ratios[i],
+                has_msa=has_msa[i],
+                downsample=i > 0,
+                use_cls_token=i == num_stages - 1,
+                drop_path=dpr[i],
+            )
+            stages[f"stage{i + 1}"] = stage
+            return_channels.append(stage_dims[i])
+
+        self.body = nn.ModuleDict(stages)
+        self.norm = nn.LayerNorm(stage_dims[-1], eps=1e-6)
+        self.return_channels = return_channels
+        self.embedding_size = stage_dims[-1]
+        self.classifier = self.create_classifier()
+        self.patch_size = patch_size
+
+        # Weight initialization
+        for name, m in self.named_modules():
+            if isinstance(m, nn.Linear):
+                nn.init.trunc_normal_(m.weight, std=0.02)
+                if m.bias is not None:
+                    nn.init.zeros_(m.bias)
+            elif isinstance(m, nn.LayerNorm):
+                nn.init.ones_(m.weight)
+                nn.init.zeros_(m.bias)
+            elif isinstance(m, nn.Conv2d):
+                if name.endswith("offset_conv") is True:
+                    continue
+
+                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                fan_out //= m.groups
+                nn.init.normal_(m.weight, mean=0.0, std=math.sqrt(2.0 / fan_out))
+                if m.bias is not None:
+                    nn.init.zeros_(m.bias)
+            elif isinstance(m, nn.BatchNorm2d):
+                nn.init.ones_(m.weight)
+                nn.init.zeros_(m.bias)
+
+    def detection_features(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
+        x = self.stem(x)
+        (B, H, W, C) = x.size()
+        x = x.reshape(B, H * W, C)
+
+        out = {}
+        for name, stage in self.body.items():
+            (x, H, W) = stage(x, (H, W))
+            if name in self.return_stages:
+                if stage.cls_token is not None:
+                    spatial_x = x[:, 1:]
+                else:
+                    spatial_x = x
+
+                out[name] = spatial_x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
+
+        return out
+
+    def freeze_stages(self, up_to_stage: int) -> None:
+        for param in self.stem.parameters():
+            param.requires_grad = False
+
+        for idx, stage in enumerate(self.body.values()):
+            if idx >= up_to_stage:
+                break
+
+            for param in stage.parameters():
+                param.requires_grad = False
+
+    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.stem(x)
+        (B, H, W, C) = x.size()
+        x = x.reshape(B, H * W, C)
+        for stage in self.body.values():
+            (x, H, W) = stage(x, (H, W))
+
+        return x
+
+    def embedding(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.forward_features(x)
+        x = self.norm(x)
+        return x[:, 0]
+
+    def set_dynamic_size(self, dynamic_size: bool = True) -> None:
+        super().set_dynamic_size(dynamic_size)
+        for stage in self.body.values():
+            stage.set_dynamic_size(dynamic_size)
+
+    def adjust_size(self, new_size: tuple[int, int]) -> None:
+        if new_size == self.size:
+            return
+
+        super().adjust_size(new_size)
+
+        new_patches_resolution = (new_size[0] // self.patch_size, new_size[1] // self.patch_size)
+
+        (h, w) = new_patches_resolution
+        for stage in self.body.values():
+            if not isinstance(stage.downsample, IdentityDownsample):
+                h = h // 2
+                w = w // 2
+
+            out_resolution = (h, w)
+            if out_resolution == stage.input_resolution:
+                continue
+
+            if stage.cls_token is not None:
+                num_prefix_tokens = 1
+            else:
+                num_prefix_tokens = 0
+
+            with torch.no_grad():
+                pos_embed = adjust_position_embedding(
+                    stage.pos_embed,
+                    stage.input_resolution,
+                    out_resolution,
+                    num_prefix_tokens=num_prefix_tokens,
+                )
+
+            stage.input_resolution = out_resolution
+            stage.pos_embed = nn.Parameter(pos_embed)
+
+
+registry.register_model_config(
+    "lit_v1_t",
+    LIT_v1_Tiny,
+    config={
+        "stage_dims": [64, 128, 320, 512],
+        "depths": [3, 4, 6, 3],
+        "num_heads": [1, 2, 5, 8],
+        "mlp_ratios": [8.0, 8.0, 4.0, 4.0],
+        "has_msa": [False, False, True, True],
+        "drop_path_rate": 0.1,
+    },
+)
+
+registry.register_weights(
+    "lit_v1_t_il-common",
+    {
+        "description": "LIT v1 Tiny model trained on the il-common dataset",
+        "resolution": (256, 256),
+        "formats": {
+            "pt": {
+                "file_size": 75.2,
+                "sha256": "93813b2716eb9f33e06dc15ab2ba335c6d219354d2983bbc4f834f8f4e688e5c",
+            }
+        },
+        "net": {"network": "lit_v1_t", "tag": "il-common"},
+    },
+)
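
For orientation, below is a minimal sketch of how the newly added backbone could be exercised on its own, using the same configuration dictionary that the module registers under "lit_v1_t" and the 256x256 resolution of the published lit_v1_t_il-common weights. This is an illustration only, not Birder's documented entry point: the project's usual registry/factory helpers, the DetectorBackbone base-class behavior (handling of size, create_classifier, return_stages), and the class count used here (100) are assumptions not shown in this diff.

import torch

from birder.net.lit_v1_tiny import LIT_v1_Tiny

# Configuration copied from the register_model_config("lit_v1_t", ...) call above
config = {
    "stage_dims": [64, 128, 320, 512],
    "depths": [3, 4, 6, 3],
    "num_heads": [1, 2, 5, 8],
    "mlp_ratios": [8.0, 8.0, 4.0, 4.0],
    "has_msa": [False, False, True, True],
    "drop_path_rate": 0.1,
}

# 3-channel input, hypothetical class count, resolution matching the registered weights
net = LIT_v1_Tiny(3, 100, config=config, size=(256, 256))
net.eval()

with torch.inference_mode():
    # CLS token of the last stage after the final LayerNorm
    embedding = net.embedding(torch.rand(1, 3, 256, 256))
    # Per-stage NCHW feature maps; which stages are returned is decided by the
    # DetectorBackbone base class (return_stages), not by this module
    features = net.detection_features(torch.rand(1, 3, 256, 256))

print(embedding.shape)  # expected torch.Size([1, 512]), i.e. embedding_size = stage_dims[-1]
print(list(features.keys()))

The sketch relies only on methods defined in the diffed file (embedding, detection_features) and on the constructor signature shown above; any end-to-end classification forward pass goes through base-class code that is outside this changeset.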