birder-0.2.2-py3-none-any.whl → birder-0.2.3-py3-none-any.whl

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
Files changed (61)
  1. birder/common/lib.py +2 -9
  2. birder/common/training_cli.py +18 -0
  3. birder/common/training_utils.py +123 -10
  4. birder/data/collators/detection.py +10 -3
  5. birder/data/datasets/coco.py +8 -10
  6. birder/data/transforms/detection.py +30 -13
  7. birder/inference/detection.py +108 -4
  8. birder/inference/wbf.py +226 -0
  9. birder/net/__init__.py +8 -0
  10. birder/net/detection/efficientdet.py +65 -86
  11. birder/net/detection/rt_detr_v1.py +1 -0
  12. birder/net/detection/yolo_anchors.py +205 -0
  13. birder/net/detection/yolo_v2.py +25 -24
  14. birder/net/detection/yolo_v3.py +39 -40
  15. birder/net/detection/yolo_v4.py +28 -26
  16. birder/net/detection/yolo_v4_tiny.py +24 -20
  17. birder/net/fasternet.py +1 -1
  18. birder/net/gc_vit.py +671 -0
  19. birder/net/lit_v1.py +472 -0
  20. birder/net/lit_v1_tiny.py +342 -0
  21. birder/net/lit_v2.py +436 -0
  22. birder/net/mobilenet_v4_hybrid.py +1 -1
  23. birder/net/resnet_v1.py +1 -1
  24. birder/net/resnext.py +67 -25
  25. birder/net/se_resnet_v1.py +46 -0
  26. birder/net/se_resnext.py +3 -0
  27. birder/net/simple_vit.py +2 -2
  28. birder/net/vit.py +0 -15
  29. birder/net/vovnet_v2.py +31 -1
  30. birder/scripts/benchmark.py +90 -21
  31. birder/scripts/predict.py +1 -0
  32. birder/scripts/predict_detection.py +18 -11
  33. birder/scripts/train.py +10 -34
  34. birder/scripts/train_barlow_twins.py +10 -34
  35. birder/scripts/train_byol.py +10 -34
  36. birder/scripts/train_capi.py +10 -35
  37. birder/scripts/train_data2vec.py +9 -34
  38. birder/scripts/train_data2vec2.py +9 -34
  39. birder/scripts/train_detection.py +48 -40
  40. birder/scripts/train_dino_v1.py +10 -34
  41. birder/scripts/train_dino_v2.py +9 -34
  42. birder/scripts/train_dino_v2_dist.py +9 -34
  43. birder/scripts/train_franca.py +9 -34
  44. birder/scripts/train_i_jepa.py +9 -34
  45. birder/scripts/train_ibot.py +9 -34
  46. birder/scripts/train_kd.py +156 -64
  47. birder/scripts/train_mim.py +10 -34
  48. birder/scripts/train_mmcr.py +10 -34
  49. birder/scripts/train_rotnet.py +10 -34
  50. birder/scripts/train_simclr.py +10 -34
  51. birder/scripts/train_vicreg.py +10 -34
  52. birder/tools/auto_anchors.py +20 -1
  53. birder/tools/pack.py +172 -103
  54. birder/tools/show_det_iterator.py +10 -1
  55. birder/version.py +1 -1
  56. {birder-0.2.2.dist-info → birder-0.2.3.dist-info}/METADATA +3 -3
  57. {birder-0.2.2.dist-info → birder-0.2.3.dist-info}/RECORD +61 -55
  58. {birder-0.2.2.dist-info → birder-0.2.3.dist-info}/WHEEL +0 -0
  59. {birder-0.2.2.dist-info → birder-0.2.3.dist-info}/entry_points.txt +0 -0
  60. {birder-0.2.2.dist-info → birder-0.2.3.dist-info}/licenses/LICENSE +0 -0
  61. {birder-0.2.2.dist-info → birder-0.2.3.dist-info}/top_level.txt +0 -0
birder/net/lit_v2.py ADDED
@@ -0,0 +1,436 @@
+ """
+ LIT v2, adapted from
+ https://github.com/ziplab/LITv2/blob/main/classification/models/litv2.py
+
+ Paper "Fast Vision Transformers with HiLo Attention", https://arxiv.org/abs/2205.13213
+
+ Generated by Claude Code Opus 4.5
+ """
+
+ # Reference license: Apache-2.0
+
+ import math
+ from collections import OrderedDict
+ from typing import Any
+ from typing import Optional
+
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+ from torchvision.ops import Permute
+ from torchvision.ops import StochasticDepth
+
+ from birder.model_registry import registry
+ from birder.net.base import DetectorBackbone
+ from birder.net.lit_v1 import DeformablePatchMerging
+ from birder.net.lit_v1 import IdentityDownsample
+
+
+ class DepthwiseMLP(nn.Module):
+     def __init__(self, in_features: int, hidden_features: int) -> None:
+         super().__init__()
+         self.fc1 = nn.Linear(in_features, hidden_features)
+         self.dwconv = nn.Conv2d(
+             hidden_features, hidden_features, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=hidden_features
+         )
+         self.act = nn.GELU()
+         self.fc2 = nn.Linear(hidden_features, in_features)
+
+     def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
+         x = self.fc1(x)
+
+         (B, N, C) = x.size()
+         x = x.reshape(B, H, W, C).permute(0, 3, 1, 2).contiguous()
+         x = self.dwconv(x)
+         x = x.permute(0, 2, 3, 1).reshape(B, N, C)
+         x = self.act(x)
+         x = self.fc2(x)
+
+         return x
+
+
+ class DepthwiseMLPBlock(nn.Module):
+     def __init__(self, dim: int, mlp_ratio: float, drop_path: float) -> None:
+         super().__init__()
+         self.norm = nn.LayerNorm(dim)
+         self.mlp = DepthwiseMLP(dim, int(dim * mlp_ratio))
+         self.drop_path = StochasticDepth(drop_path, mode="row")
+
+     def forward(self, x: torch.Tensor, resolution: tuple[int, int]) -> torch.Tensor:
+         (H, W) = resolution
+         return x + self.drop_path(self.mlp(self.norm(x), H, W))
+
+
+ class HiLoAttention(nn.Module):
+     """
+     HiLo Attention: High-frequency local attention + Low-frequency global attention
+
+     Hi-Fi (High frequency): Local window attention
+     Lo-Fi (Low frequency): Global attention with average pooling
+     """
+
+     def __init__(
+         self,
+         dim: int,
+         num_heads: int,
+         window_size: int,
+         alpha: float,
+     ) -> None:
+         super().__init__()
+         assert dim % num_heads == 0, "dim must be divisible by num_heads"
+
+         self.window_size = window_size
+         self.head_dim = dim // num_heads
+         self.scale = self.head_dim**-0.5
+
+         # Split heads between Lo-Fi (global) and Hi-Fi (local)
+         self.l_heads = int(num_heads * alpha)  # Lo-Fi heads
+         self.h_heads = num_heads - self.l_heads  # Hi-Fi heads
+         self.l_dim = self.l_heads * self.head_dim
+         self.h_dim = self.h_heads * self.head_dim
+         self.head_dim = self.head_dim
+
+         # ws == 1 is equal to standard multi-head self-attention
+         if window_size == 1:
+             self.h_heads = 0
+             self.h_dim = 0
+             self.l_heads = num_heads
+             self.l_dim = dim
+
+         # Lo-Fi: Global attention with pooling
+         if self.l_heads > 0:
+             if window_size > 1:
+                 self.sr = nn.AvgPool2d(kernel_size=(window_size, window_size), stride=(window_size, window_size))
+             else:
+                 self.sr = nn.Identity()
+
+             self.l_q = nn.Linear(dim, self.l_dim)
+             self.l_kv = nn.Linear(dim, self.l_dim * 2)
+             self.l_proj = nn.Linear(self.l_dim, self.l_dim)
+         else:
+             self.l_q = nn.Identity()
+             self.l_kv = nn.Identity()
+             self.l_proj = nn.Identity()
+
+         # Hi-Fi: Local window attention
+         if self.h_heads > 0:
+             self.h_qkv = nn.Linear(dim, self.h_dim * 3)
+             self.h_proj = nn.Linear(self.h_dim, self.h_dim)
+         else:
+             self.h_qkv = nn.Identity()
+             self.h_proj = nn.Identity()
+
+     def _lofi(self, x: torch.Tensor) -> torch.Tensor:
+         (B, H, W, C) = x.size()
+
+         q = self.l_q(x).reshape(B, H * W, self.l_heads, self.head_dim).permute(0, 2, 1, 3)
+
+         # Spatial reduction for k, v
+         if self.window_size > 1:
+             x = x.permute(0, 3, 1, 2)
+             x = self.sr(x).reshape(B, C, -1).permute(0, 2, 1)
+             kv = self.l_kv(x).reshape(B, -1, 2, self.l_heads, self.head_dim).permute(2, 0, 3, 1, 4)
+         else:
+             kv = self.l_kv(x).reshape(B, -1, 2, self.l_heads, self.head_dim).permute(2, 0, 3, 1, 4)
+
+         (k, v) = kv.unbind(0)
+
+         attn = (q @ k.transpose(-2, -1)) * self.scale
+         attn = F.softmax(attn, dim=-1)
+
+         x = (attn @ v).transpose(1, 2).reshape(B, H, W, self.l_dim)
+         x = self.l_proj(x)
+
+         return x
+
+     def _hifi(self, x: torch.Tensor) -> torch.Tensor:
+         (B, H, W, _) = x.size()
+         ws = self.window_size
+
+         # Pad if needed
+         pad_h = (ws - H % ws) % ws
+         pad_w = (ws - W % ws) % ws
+         if pad_h > 0 or pad_w > 0:
+             x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
+
+         (_, h_pad, w_pad, _) = x.size()
+         h_groups = h_pad // ws
+         w_groups = w_pad // ws
+         total_groups = h_groups * w_groups
+
+         x = x.reshape(B, h_groups, ws, w_groups, ws, -1).transpose(2, 3)
+
+         qkv = self.h_qkv(x).reshape(B, total_groups, -1, 3, self.h_heads, self.head_dim).permute(3, 0, 1, 4, 2, 5)
+         (q, k, v) = qkv.unbind(0)
+
+         attn = (q @ k.transpose(-2, -1)) * self.scale
+         attn = F.softmax(attn, dim=-1)
+
+         x = (attn @ v).transpose(2, 3).reshape(B, h_groups, w_groups, ws, ws, self.h_dim)
+         x = x.transpose(2, 3).reshape(B, h_pad, w_pad, self.h_dim)
+         x = self.h_proj(x)
+
+         # Remove padding
+         if pad_h > 0 or pad_w > 0:
+             x = x[:, :H, :W, :].contiguous()
+
+         return x
+
+     def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
+         (B, N, C) = x.size()
+         x = x.reshape(B, H, W, C)
+
+         if self.h_heads == 0:
+             x = self._lofi(x)
+             return x.reshape(B, N, C)
+
+         if self.l_heads == 0:
+             x = self._hifi(x)
+             return x.reshape(B, N, C)
+
+         # Process both branches and concatenate
+         hifi_out = self._hifi(x)
+         lofi_out = self._lofi(x)
+
+         x = torch.concat((hifi_out, lofi_out), dim=-1)
+         return x.reshape(B, N, C)
+
+
+ class HiLoBlock(nn.Module):
+     def __init__(
+         self,
+         dim: int,
+         num_heads: int,
+         window_size: int,
+         alpha: float,
+         mlp_ratio: float,
+         drop_path: float,
+     ) -> None:
+         super().__init__()
+         self.norm1 = nn.LayerNorm(dim)
+         self.attn = HiLoAttention(dim, num_heads, window_size, alpha)
+         self.drop_path1 = StochasticDepth(drop_path, mode="row")
+         self.norm2 = nn.LayerNorm(dim)
+         self.mlp = DepthwiseMLP(dim, int(dim * mlp_ratio))
+         self.drop_path2 = StochasticDepth(drop_path, mode="row")
+
+     def forward(self, x: torch.Tensor, resolution: tuple[int, int]) -> torch.Tensor:
+         (H, W) = resolution
+         x = x + self.drop_path1(self.attn(self.norm1(x), H, W))
+         x = x + self.drop_path2(self.mlp(self.norm2(x), H, W))
+         return x
+
+
+ class LITStage(nn.Module):
+     def __init__(
+         self,
+         in_dim: int,
+         out_dim: int,
+         resolution: tuple[int, int],
+         depth: int,
+         num_heads: int,
+         window_size: int,
+         alpha: float,
+         mlp_ratio: float,
+         downsample: bool,
+         drop_path: list[float],
+     ) -> None:
+         super().__init__()
+         if downsample is True:
+             self.downsample = DeformablePatchMerging(in_dim, out_dim)
+             resolution = (resolution[0] // 2, resolution[1] // 2)
+         else:
+             self.downsample = IdentityDownsample()
+
+         blocks: list[nn.Module] = []
+         for i in range(depth):
+             if window_size > 0:
+                 blocks.append(HiLoBlock(out_dim, num_heads, window_size, alpha, mlp_ratio, drop_path[i]))
+             else:
+                 blocks.append(DepthwiseMLPBlock(out_dim, mlp_ratio, drop_path[i]))
+
+         self.blocks = nn.ModuleList(blocks)
+
+     def forward(self, x: torch.Tensor, input_resolution: tuple[int, int]) -> tuple[torch.Tensor, int, int]:
+         (x, H, W) = self.downsample(x, input_resolution)
+         for block in self.blocks:
+             x = block(x, (H, W))
+
+         return (x, H, W)
+
+
+ # pylint: disable=invalid-name
+ class LIT_v2(DetectorBackbone):
+     block_group_regex = r"body\.stage(\d+)\.blocks\.(\d+)"
+
+     # pylint:disable=too-many-locals
+     def __init__(
+         self,
+         input_channels: int,
+         num_classes: int,
+         *,
+         config: Optional[dict[str, Any]] = None,
+         size: Optional[tuple[int, int]] = None,
+     ) -> None:
+         super().__init__(input_channels, num_classes, config=config, size=size)
+         assert self.config is not None, "must set config"
+
+         patch_size = 4
+         embed_dim: int = self.config["embed_dim"]
+         depths: list[int] = self.config["depths"]
+         num_heads: list[int] = self.config["num_heads"]
+         local_ws: list[int] = self.config["local_ws"]
+         alpha: float = self.config["alpha"]
+         drop_path_rate: float = self.config["drop_path_rate"]
+
+         num_stages = len(depths)
+
+         self.stem = nn.Sequential(
+             nn.Conv2d(
+                 self.input_channels,
+                 embed_dim,
+                 kernel_size=(patch_size, patch_size),
+                 stride=(patch_size, patch_size),
+                 padding=(0, 0),
+                 bias=True,
+             ),
+             Permute([0, 2, 3, 1]),
+             nn.LayerNorm(embed_dim),
+         )
+
+         # Stochastic depth
+         dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
+
+         stages: OrderedDict[str, nn.Module] = OrderedDict()
+         return_channels: list[int] = []
+         prev_dim = embed_dim
+         resolution = (self.size[0] // patch_size, self.size[1] // patch_size)
+         for i_stage in range(num_stages):
+             in_dim = prev_dim
+             out_dim = in_dim * 2 if i_stage > 0 else in_dim
+             stage = LITStage(
+                 in_dim,
+                 out_dim,
+                 resolution,
+                 depth=depths[i_stage],
+                 num_heads=num_heads[i_stage],
+                 window_size=local_ws[i_stage],
+                 alpha=alpha,
+                 mlp_ratio=4.0,
+                 downsample=i_stage > 0,
+                 drop_path=dpr[i_stage],
+             )
+             stages[f"stage{i_stage + 1}"] = stage
+
+             if i_stage > 0:
+                 resolution = (resolution[0] // 2, resolution[1] // 2)
+
+             prev_dim = out_dim
+             return_channels.append(out_dim)
+
+         num_features = embed_dim * (2 ** (num_stages - 1))
+         self.body = nn.ModuleDict(stages)
+         self.features = nn.Sequential(
+             nn.LayerNorm(num_features),
+             Permute([0, 2, 1]),
+             nn.AdaptiveAvgPool1d(output_size=1),
+             nn.Flatten(1),
+         )
+         self.return_channels = return_channels
+         self.embedding_size = num_features
+         self.classifier = self.create_classifier()
+
+         # Weight initialization
+         for name, m in self.named_modules():
+             if isinstance(m, nn.Linear):
+                 nn.init.trunc_normal_(m.weight, std=0.02)
+                 if m.bias is not None:
+                     nn.init.zeros_(m.bias)
+             elif isinstance(m, nn.LayerNorm):
+                 nn.init.ones_(m.weight)
+                 nn.init.zeros_(m.bias)
+             elif isinstance(m, nn.Conv2d):
+                 if name.endswith("offset_conv") is True:
+                     continue
+
+                 fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                 fan_out //= m.groups
+                 nn.init.normal_(m.weight, mean=0.0, std=math.sqrt(2.0 / fan_out))
+                 if m.bias is not None:
+                     nn.init.zeros_(m.bias)
+
+     def detection_features(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
+         x = self.stem(x)
+         (B, H, W, C) = x.size()
+         x = x.reshape(B, H * W, C)
+
+         out = {}
+         for name, stage in self.body.items():
+             (x, H, W) = stage(x, (H, W))
+             if name in self.return_stages:
+                 features = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
+                 out[name] = features
+
+         return out
+
+     def freeze_stages(self, up_to_stage: int) -> None:
+         for param in self.stem.parameters():
+             param.requires_grad = False
+
+         for idx, stage in enumerate(self.body.values()):
+             if idx >= up_to_stage:
+                 break
+
+             for param in stage.parameters():
+                 param.requires_grad = False
+
+     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
+         x = self.stem(x)
+         (B, H, W, C) = x.size()
+         x = x.reshape(B, H * W, C)
+         for stage in self.body.values():
+             (x, H, W) = stage(x, (H, W))
+
+         return x
+
+     def embedding(self, x: torch.Tensor) -> torch.Tensor:
+         x = self.forward_features(x)
+         return self.features(x)
+
+
+ registry.register_model_config(
+     "lit_v2_s",
+     LIT_v2,
+     config={
+         "embed_dim": 96,
+         "depths": [2, 2, 6, 2],
+         "num_heads": [3, 6, 12, 24],
+         "local_ws": [0, 0, 2, 1],
+         "alpha": 0.9,
+         "drop_path_rate": 0.2,
+     },
+ )
+ registry.register_model_config(
+     "lit_v2_m",
+     LIT_v2,
+     config={
+         "embed_dim": 96,
+         "depths": [2, 2, 18, 2],
+         "num_heads": [3, 6, 12, 24],
+         "local_ws": [0, 0, 2, 1],
+         "alpha": 0.9,
+         "drop_path_rate": 0.3,
+     },
+ )
+ registry.register_model_config(
+     "lit_v2_b",
+     LIT_v2,
+     config={
+         "embed_dim": 128,
+         "depths": [2, 2, 18, 2],
+         "num_heads": [4, 8, 16, 32],
+         "local_ws": [0, 0, 2, 1],
+         "alpha": 0.9,
+         "drop_path_rate": 0.5,
+     },
+ )
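For orientation on the new backbone above: local_ws selects the block type per stage (0 gives the attention-free DepthwiseMLPBlock, 1 degenerates HiLo into standard multi-head self-attention, values above 1 enable the windowed Hi-Fi branch), while alpha assigns int(num_heads * alpha) heads to the pooled Lo-Fi branch and the rest to Hi-Fi, with both outputs concatenated back to the full width. A minimal sketch of exercising HiLoAttention directly, assuming the 0.2.3 wheel is installed (the shapes and the dim/num_heads/alpha values here are illustrative, borrowed from the lit_v2_s config):

import torch

from birder.net.lit_v2 import HiLoAttention

# dim=96, num_heads=3, alpha=0.9 -> l_heads=2 (Lo-Fi), h_heads=1 (Hi-Fi)
attn = HiLoAttention(dim=96, num_heads=3, window_size=2, alpha=0.9)

x = torch.randn(1, 14 * 14, 96)  # (B, N, C) token sequence
out = attn(x, 14, 14)  # H=14, W=14 so that N == H * W
print(out.shape)  # torch.Size([1, 196, 96]): h_dim (32) + l_dim (64) concatenated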
birder/net/mobilenet_v4_hybrid.py CHANGED
@@ -491,7 +491,7 @@ registry.register_weights(
          "formats": {
              "pt": {
                  "file_size": 39.7,
-                 "sha256": "220df49e08ea49e24f30dcc777bf48c7aaea4aaa5909b56c931f41747381d390",
+                 "sha256": "d7d76733e0116d351bf8aafc563659eab7bea02174a02c10fba8eb3a64ea87e1",
              }
          },
          "net": {"network": "mobilenet_v4_hybrid_m", "tag": "il-common"},
birder/net/resnet_v1.py CHANGED
@@ -58,7 +58,7 @@ class ResidualBlock(nn.Module):
              nn.BatchNorm2d(out_channels),
          )

-         if in_channels == out_channels:
+         if in_channels == out_channels and stride == (1, 1):
              self.block2 = nn.Identity()
          else:
              if avg_down is True and stride != (1, 1):
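The tightened condition above guards the case where a block keeps its channel count but downsamples: an identity shortcut would then no longer match the main branch spatially. A small standalone illustration of the mismatch the extra stride check prevents (hypothetical shapes, not package code):

import torch
from torch import nn

# With in_channels == out_channels but stride (2, 2), the main branch halves
# the spatial size while an identity shortcut would not, so the residual
# addition x + block(x) would fail on shape.
block = nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
x = torch.randn(1, 64, 56, 56)
print(block(x).shape)  # torch.Size([1, 64, 28, 28]) vs. identity at 56x56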
birder/net/resnext.py CHANGED
@@ -30,6 +30,7 @@ class ResidualBlock(nn.Module):
          base_width: int,
          expansion: int,
          squeeze_excitation: bool,
+         avg_down: bool,
      ) -> None:
          super().__init__()
          width = int(out_channels * (base_width / 64.0)) * groups
@@ -62,20 +63,34 @@
              nn.BatchNorm2d(out_channels * expansion),
          )

-         if in_channels == out_channels * expansion:
+         if in_channels == out_channels * expansion and stride == (1, 1):
              self.block2 = nn.Identity()
          else:
-             self.block2 = nn.Sequential(
-                 nn.Conv2d(
-                     in_channels,
-                     out_channels * expansion,
-                     kernel_size=(1, 1),
-                     stride=stride,
-                     padding=(0, 0),
-                     bias=False,
-                 ),
-                 nn.BatchNorm2d(out_channels * expansion),
-             )
+             if avg_down is True and stride != (1, 1):
+                 self.block2 = nn.Sequential(
+                     nn.AvgPool2d(kernel_size=2, stride=stride, ceil_mode=True, count_include_pad=False),
+                     nn.Conv2d(
+                         in_channels,
+                         out_channels * expansion,
+                         kernel_size=(1, 1),
+                         stride=(1, 1),
+                         padding=(0, 0),
+                         bias=False,
+                     ),
+                     nn.BatchNorm2d(out_channels * expansion),
+                 )
+             else:
+                 self.block2 = nn.Sequential(
+                     nn.Conv2d(
+                         in_channels,
+                         out_channels * expansion,
+                         kernel_size=(1, 1),
+                         stride=stride,
+                         padding=(0, 0),
+                         bias=False,
+                     ),
+                     nn.BatchNorm2d(out_channels * expansion),
+                 )

          self.relu = nn.ReLU(inplace=True)
          if squeeze_excitation is True:
@@ -107,23 +122,35 @@ class ResNeXt(DetectorBackbone):
          super().__init__(input_channels, num_classes, config=config, size=size)
          assert self.config is not None, "must set config"

-         groups = 32
-         base_width = 4
          expansion = 4
+         groups: int = self.config.get("groups", 32)
+         base_width: int = self.config.get("base_width", 4)
          filter_list = [64, 128, 256, 512]
          units: list[int] = self.config["units"]
+         deep_stem: bool = self.config.get("deep_stem", False)
+         avg_down: bool = self.config.get("avg_down", False)

-         self.stem = nn.Sequential(
-             Conv2dNormActivation(
-                 self.input_channels,
-                 filter_list[0],
-                 kernel_size=(7, 7),
-                 stride=(2, 2),
-                 padding=(3, 3),
-                 bias=False,
-             ),
-             nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
-         )
+         if deep_stem is True:
+             self.stem = nn.Sequential(
+                 Conv2dNormActivation(
+                     self.input_channels, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False
+                 ),
+                 Conv2dNormActivation(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
+                 Conv2dNormActivation(32, filter_list[0], kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
+                 nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
+             )
+         else:
+             self.stem = nn.Sequential(
+                 Conv2dNormActivation(
+                     self.input_channels,
+                     filter_list[0],
+                     kernel_size=(7, 7),
+                     stride=(2, 2),
+                     padding=(3, 3),
+                     bias=False,
+                 ),
+                 nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
+             )

          # Generate body layers
          in_channels = filter_list[0]
@@ -150,6 +177,7 @@
                      base_width=base_width,
                      expansion=expansion,
                      squeeze_excitation=squeeze_excitation,
+                     avg_down=avg_down,
                  )
              )
              in_channels = channels * expansion
@@ -209,3 +237,17 @@
  registry.register_model_config("resnext_50", ResNeXt, config={"units": [3, 4, 6, 3]})
  registry.register_model_config("resnext_101", ResNeXt, config={"units": [3, 4, 23, 3]})
  registry.register_model_config("resnext_152", ResNeXt, config={"units": [3, 8, 36, 3]})
+
+ registry.register_model_config("resnext_101_32x8", ResNeXt, config={"units": [3, 4, 23, 3], "base_width": 8})
+ registry.register_model_config("resnext_101_64x4", ResNeXt, config={"units": [3, 4, 23, 3], "groups": 64})
+
+ # ResNeXt-D variants (From: Bag of Tricks for Image Classification with Convolutional Neural Networks)
+ registry.register_model_config(
+     "resnext_d_50", ResNeXt, config={"units": [3, 4, 6, 3], "deep_stem": True, "avg_down": True}
+ )
+ registry.register_model_config(
+     "resnext_d_101", ResNeXt, config={"units": [3, 4, 23, 3], "deep_stem": True, "avg_down": True}
+ )
+ registry.register_model_config(
+     "resnext_d_152", ResNeXt, config={"units": [3, 8, 36, 3], "deep_stem": True, "avg_down": True}
+ )
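The new avg_down shortcut follows the ResNet-D trick from the "Bag of Tricks" paper cited in the diff: downsampling moves out of the strided 1x1 projection into an average pool, so the 1x1 conv sees every input pixel instead of every other one. A standalone sketch of the two shortcut forms, mirroring the Sequential blocks above (variable names here are illustrative):

import torch
from torch import nn

in_channels, out_channels, stride = 256, 512, (2, 2)

# Plain shortcut: the strided 1x1 conv samples every other pixel
plain = nn.Sequential(
    nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1), stride=stride, bias=False),
    nn.BatchNorm2d(out_channels),
)

# avg_down shortcut: pool first, then a stride-1 1x1 conv sees all pixels
avg_down = nn.Sequential(
    nn.AvgPool2d(kernel_size=2, stride=stride, ceil_mode=True, count_include_pad=False),
    nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1), stride=(1, 1), bias=False),
    nn.BatchNorm2d(out_channels),
)

x = torch.randn(1, in_channels, 28, 28)
print(plain(x).shape, avg_down(x).shape)  # both torch.Size([1, 512, 14, 14])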
birder/net/se_resnet_v1.py CHANGED
@@ -57,3 +57,49 @@ registry.register_model_config(
      SE_ResNet_v1,
      config={"bottle_neck": True, "filter_list": [64, 256, 512, 1024, 2048], "units": [3, 30, 48, 8]},
  )
+
+ # SE-ResNet-D variants (From: Bag of Tricks for Image Classification with Convolutional Neural Networks)
+ registry.register_model_config(
+     "se_resnet_d_50",
+     SE_ResNet_v1,
+     config={
+         "bottle_neck": True,
+         "filter_list": [64, 256, 512, 1024, 2048],
+         "units": [3, 4, 6, 3],
+         "deep_stem": True,
+         "avg_down": True,
+     },
+ )
+ registry.register_model_config(
+     "se_resnet_d_101",
+     SE_ResNet_v1,
+     config={
+         "bottle_neck": True,
+         "filter_list": [64, 256, 512, 1024, 2048],
+         "units": [3, 4, 23, 3],
+         "deep_stem": True,
+         "avg_down": True,
+     },
+ )
+ registry.register_model_config(
+     "se_resnet_d_152",
+     SE_ResNet_v1,
+     config={
+         "bottle_neck": True,
+         "filter_list": [64, 256, 512, 1024, 2048],
+         "units": [3, 8, 36, 3],
+         "deep_stem": True,
+         "avg_down": True,
+     },
+ )
+ registry.register_model_config(
+     "se_resnet_d_200",
+     SE_ResNet_v1,
+     config={
+         "bottle_neck": True,
+         "filter_list": [64, 256, 512, 1024, 2048],
+         "units": [3, 24, 36, 3],
+         "deep_stem": True,
+         "avg_down": True,
+     },
+ )
birder/net/se_resnext.py CHANGED
@@ -25,3 +25,6 @@ class SE_ResNeXt(ResNeXt):
  registry.register_model_config("se_resnext_50", SE_ResNeXt, config={"units": [3, 4, 6, 3]})
  registry.register_model_config("se_resnext_101", SE_ResNeXt, config={"units": [3, 4, 23, 3]})
  registry.register_model_config("se_resnext_152", SE_ResNeXt, config={"units": [3, 8, 36, 3]})
+
+ registry.register_model_config("se_resnext_101_32x8", SE_ResNeXt, config={"units": [3, 4, 23, 3], "base_width": 8})
+ registry.register_model_config("se_resnext_101_64x4", SE_ResNeXt, config={"units": [3, 4, 23, 3], "groups": 64})
birder/net/simple_vit.py CHANGED
@@ -79,7 +79,7 @@ class Simple_ViT(PreTrainEncoder, MaskedTokenOmissionMixin):
              dim=hidden_dim,
              num_special_tokens=self.num_special_tokens,
          )
-         self.pos_embedding = nn.Parameter(pos_embedding, requires_grad=False)
+         self.pos_embedding = nn.Buffer(pos_embedding)

          self.encoder = Encoder(num_layers, num_heads, hidden_dim, mlp_dim, dropout=0.0, attention_dropout=0.0, dpr=dpr)
          self.norm = nn.LayerNorm(hidden_dim, eps=1e-6)
@@ -203,7 +203,7 @@ class Simple_ViT(PreTrainEncoder, MaskedTokenOmissionMixin):
              dim=self.hidden_dim,
              num_special_tokens=self.num_special_tokens,
          )
-         self.pos_embedding = nn.Parameter(pos_embedding, requires_grad=False)
+         self.pos_embedding = nn.Buffer(pos_embedding)

      def set_causal_attention(self, is_causal: bool = True) -> None:
          self.encoder.set_causal_attention(is_causal)
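The nn.Parameter-to-nn.Buffer swap above changes bookkeeping rather than math: a buffer still moves with .to() and is saved in the state_dict, but it no longer appears in parameters(), so optimizers and weight-decay rules never see the frozen positional embedding. A minimal sketch (nn.Buffer is a relatively recent PyTorch addition, around 2.3; the module and shapes here are illustrative):

import torch
from torch import nn

class Toy(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # Registered as a buffer: present in state_dict, absent from parameters()
        self.pos_embedding = nn.Buffer(torch.zeros(1, 196, 384))

m = Toy()
print(sum(1 for _ in m.parameters()))     # 0
print("pos_embedding" in m.state_dict())  # True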
birder/net/vit.py CHANGED
@@ -1588,21 +1588,6 @@ registry.register_weights(
          "net": {"network": "vit_l16", "tag": "mim"},
      },
  )
- registry.register_weights(
-     "vit_l16_mim-eu-common",
-     {
-         "url": "https://huggingface.co/birder-project/vit_l16_mim-eu-common/resolve/main",
-         "description": "ViT l16 model with MIM pretraining, then fine-tuned on the eu-common dataset",
-         "resolution": (256, 256),
-         "formats": {
-             "pt": {
-                 "file_size": 1160.1,
-                 "sha256": "3b7235b90f76fb1e0e36d4c4111777a4cc4e4500552fe840c51170b208310d16",
-             },
-         },
-         "net": {"network": "vit_l16", "tag": "mim-eu-common"},
-     },
- )
  registry.register_weights(  # BioCLIP v2: https://arxiv.org/abs/2505.23883
      "vit_l14_pn_bioclip-v2",
      {