birder 0.2.2__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (107)
  1. birder/common/lib.py +2 -9
  2. birder/common/training_cli.py +24 -0
  3. birder/common/training_utils.py +338 -41
  4. birder/data/collators/detection.py +11 -3
  5. birder/data/dataloader/webdataset.py +12 -2
  6. birder/data/datasets/coco.py +8 -10
  7. birder/data/transforms/detection.py +30 -13
  8. birder/inference/detection.py +108 -4
  9. birder/inference/wbf.py +226 -0
  10. birder/kernels/load_kernel.py +16 -11
  11. birder/kernels/soft_nms/soft_nms.cpp +17 -18
  12. birder/net/__init__.py +8 -0
  13. birder/net/cait.py +4 -3
  14. birder/net/convnext_v1.py +5 -0
  15. birder/net/crossformer.py +33 -30
  16. birder/net/crossvit.py +4 -3
  17. birder/net/deit.py +3 -3
  18. birder/net/deit3.py +3 -3
  19. birder/net/detection/deformable_detr.py +2 -5
  20. birder/net/detection/detr.py +2 -5
  21. birder/net/detection/efficientdet.py +67 -93
  22. birder/net/detection/fcos.py +2 -7
  23. birder/net/detection/retinanet.py +2 -7
  24. birder/net/detection/rt_detr_v1.py +2 -0
  25. birder/net/detection/yolo_anchors.py +205 -0
  26. birder/net/detection/yolo_v2.py +25 -24
  27. birder/net/detection/yolo_v3.py +39 -40
  28. birder/net/detection/yolo_v4.py +28 -26
  29. birder/net/detection/yolo_v4_tiny.py +24 -20
  30. birder/net/efficientformer_v1.py +15 -9
  31. birder/net/efficientformer_v2.py +39 -29
  32. birder/net/efficientvit_msft.py +9 -7
  33. birder/net/fasternet.py +1 -1
  34. birder/net/fastvit.py +1 -0
  35. birder/net/flexivit.py +5 -4
  36. birder/net/gc_vit.py +671 -0
  37. birder/net/hiera.py +12 -9
  38. birder/net/hornet.py +9 -7
  39. birder/net/iformer.py +8 -6
  40. birder/net/levit.py +42 -30
  41. birder/net/lit_v1.py +472 -0
  42. birder/net/lit_v1_tiny.py +357 -0
  43. birder/net/lit_v2.py +436 -0
  44. birder/net/maxvit.py +67 -55
  45. birder/net/mobilenet_v4_hybrid.py +1 -1
  46. birder/net/mobileone.py +1 -0
  47. birder/net/mvit_v2.py +13 -12
  48. birder/net/pit.py +4 -3
  49. birder/net/pvt_v1.py +4 -1
  50. birder/net/repghost.py +1 -0
  51. birder/net/repvgg.py +1 -0
  52. birder/net/repvit.py +1 -0
  53. birder/net/resnet_v1.py +1 -1
  54. birder/net/resnext.py +67 -25
  55. birder/net/rope_deit3.py +5 -3
  56. birder/net/rope_flexivit.py +7 -4
  57. birder/net/rope_vit.py +10 -5
  58. birder/net/se_resnet_v1.py +46 -0
  59. birder/net/se_resnext.py +3 -0
  60. birder/net/simple_vit.py +11 -8
  61. birder/net/swin_transformer_v1.py +71 -68
  62. birder/net/swin_transformer_v2.py +38 -31
  63. birder/net/tiny_vit.py +20 -10
  64. birder/net/transnext.py +38 -28
  65. birder/net/vit.py +5 -19
  66. birder/net/vit_parallel.py +5 -4
  67. birder/net/vit_sam.py +38 -37
  68. birder/net/vovnet_v1.py +15 -0
  69. birder/net/vovnet_v2.py +31 -1
  70. birder/ops/msda.py +108 -43
  71. birder/ops/swattention.py +124 -61
  72. birder/results/detection.py +4 -0
  73. birder/scripts/benchmark.py +110 -32
  74. birder/scripts/predict.py +8 -0
  75. birder/scripts/predict_detection.py +18 -11
  76. birder/scripts/train.py +48 -46
  77. birder/scripts/train_barlow_twins.py +44 -45
  78. birder/scripts/train_byol.py +44 -45
  79. birder/scripts/train_capi.py +50 -49
  80. birder/scripts/train_data2vec.py +45 -47
  81. birder/scripts/train_data2vec2.py +45 -47
  82. birder/scripts/train_detection.py +83 -50
  83. birder/scripts/train_dino_v1.py +60 -47
  84. birder/scripts/train_dino_v2.py +86 -52
  85. birder/scripts/train_dino_v2_dist.py +84 -50
  86. birder/scripts/train_franca.py +51 -52
  87. birder/scripts/train_i_jepa.py +45 -47
  88. birder/scripts/train_ibot.py +51 -53
  89. birder/scripts/train_kd.py +194 -76
  90. birder/scripts/train_mim.py +44 -45
  91. birder/scripts/train_mmcr.py +44 -45
  92. birder/scripts/train_rotnet.py +45 -46
  93. birder/scripts/train_simclr.py +44 -45
  94. birder/scripts/train_vicreg.py +44 -45
  95. birder/tools/auto_anchors.py +20 -1
  96. birder/tools/convert_model.py +18 -15
  97. birder/tools/det_results.py +114 -2
  98. birder/tools/pack.py +172 -103
  99. birder/tools/quantize_model.py +73 -67
  100. birder/tools/show_det_iterator.py +10 -1
  101. birder/version.py +1 -1
  102. {birder-0.2.2.dist-info → birder-0.3.0.dist-info}/METADATA +4 -3
  103. {birder-0.2.2.dist-info → birder-0.3.0.dist-info}/RECORD +107 -101
  104. {birder-0.2.2.dist-info → birder-0.3.0.dist-info}/WHEEL +0 -0
  105. {birder-0.2.2.dist-info → birder-0.3.0.dist-info}/entry_points.txt +0 -0
  106. {birder-0.2.2.dist-info → birder-0.3.0.dist-info}/licenses/LICENSE +0 -0
  107. {birder-0.2.2.dist-info → birder-0.3.0.dist-info}/top_level.txt +0 -0
birder/net/hornet.py CHANGED
@@ -332,13 +332,15 @@ class HorNet(DetectorBackbone):
         for m in module.modules():
             if isinstance(m, HorBlock):
                 if isinstance(m.gn_conv.dwconv, GlobalLocalFilter):
-                    weight = m.gn_conv.dwconv.complex_weight
-                    weight = F.interpolate(
-                        weight.permute(3, 0, 1, 2),
-                        size=(gn_conv_h[i], gn_conv_w[i]),
-                        mode="bilinear",
-                        align_corners=True,
-                    ).permute(1, 2, 3, 0)
+                    with torch.no_grad():
+                        weight = m.gn_conv.dwconv.complex_weight
+                        weight = F.interpolate(
+                            weight.permute(3, 0, 1, 2),
+                            size=(gn_conv_h[i], gn_conv_w[i]),
+                            mode="bilinear",
+                            align_corners=True,
+                        ).permute(1, 2, 3, 0)
+
                     m.gn_conv.dwconv.complex_weight = nn.Parameter(weight)
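Note: the HorNet hunk above (like the iFormer and LeViT hunks below) wraps the weight resize in torch.no_grad() before re-wrapping the result as an nn.Parameter. A minimal, self-contained sketch of that pattern follows; the resize_param helper and its (H, W, C) layout are illustrative assumptions, not birder code.

import torch
import torch.nn.functional as F
from torch import nn


def resize_param(weight: nn.Parameter, new_hw: tuple[int, int]) -> nn.Parameter:
    # Hypothetical helper: bilinearly resize an (H, W, C) learned filter.
    # no_grad() keeps the interpolation out of the autograd graph, so the
    # resized tensor can be re-wrapped as a fresh leaf Parameter.
    with torch.no_grad():
        resized = F.interpolate(
            weight.permute(2, 0, 1).unsqueeze(0),  # (1, C, H, W), the layout interpolate expects
            size=new_hw,
            mode="bilinear",
            align_corners=True,
        )

    return nn.Parameter(resized.squeeze(0).permute(1, 2, 0))  # back to (H, W, C)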
birder/net/iformer.py CHANGED
@@ -477,12 +477,14 @@ class iFormer(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
         resolution = (new_size[0] // 4, new_size[1] // 4)
         for stage in self.body.modules():
             if isinstance(stage, InceptionTransformerStage):
-                orig_dtype = stage.pos_embed.dtype
-                pos_embedding = stage.pos_embed.float()
-                pos_embedding = F.interpolate(
-                    pos_embedding.permute(0, 3, 1, 2), size=resolution, mode="bilinear"
-                ).permute(0, 2, 3, 1)
-                pos_embedding = pos_embedding.to(orig_dtype)
+                with torch.no_grad():
+                    orig_dtype = stage.pos_embed.dtype
+                    pos_embedding = stage.pos_embed.float()
+                    pos_embedding = F.interpolate(
+                        pos_embedding.permute(0, 3, 1, 2), size=resolution, mode="bilinear"
+                    ).permute(0, 2, 3, 1)
+                    pos_embedding = pos_embedding.to(orig_dtype)
+
                 stage.pos_embed = nn.Parameter(pos_embedding)
                 stage.resolution = resolution
                 resolution = (resolution[0] // 2, resolution[1] // 2)
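Note: besides the no_grad() wrapper, the iFormer hunk preserves the stored parameter dtype by casting to float32 for the interpolation and casting back afterward. An illustrative round-trip (not from birder), using a made-up (1, H, W, C) embedding:

import torch
import torch.nn.functional as F

pos_embed = torch.randn(1, 7, 7, 192, dtype=torch.bfloat16)  # hypothetical NHWC positional embedding
orig_dtype = pos_embed.dtype
resized = F.interpolate(
    pos_embed.float().permute(0, 3, 1, 2),  # NHWC -> NCHW, float32 for the interpolation
    size=(14, 14),
    mode="bilinear",
)
resized = resized.permute(0, 2, 3, 1).to(orig_dtype)  # back to NHWC and the original dtype
print(resized.shape, resized.dtype)  # torch.Size([1, 14, 14, 192]) torch.bfloat16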
birder/net/levit.py CHANGED
@@ -454,42 +454,54 @@ class LeViT(BaseNet):
             # Update Subsample resolution
             m.q[0].resolution = resolution

-            # Interpolate attention biases
-            m.attention_biases = nn.Parameter(
-                interpolate_attention_bias(m.attention_biases, old_resolution, resolution)
-            )
-
-            # Rebuild attention bias indices
-            k_pos = torch.stack(
-                torch.meshgrid(torch.arange(resolution[0]), torch.arange(resolution[1]), indexing="ij")
-            ).flatten(1)
-            q_pos = torch.stack(
-                torch.meshgrid(
-                    torch.arange(0, resolution[0], step=m.stride),
-                    torch.arange(0, resolution[1], step=m.stride),
-                    indexing="ij",
+            with torch.no_grad():
+                # Interpolate attention biases
+                m.attention_biases = nn.Parameter(
+                    interpolate_attention_bias(m.attention_biases, old_resolution, resolution)
                 )
-            ).flatten(1)
-            rel_pos = (q_pos[..., :, None] - k_pos[..., None, :]).abs()
-            rel_pos = (rel_pos[0] * resolution[1]) + rel_pos[1]
-            m.attention_bias_idxs = nn.Buffer(rel_pos, persistent=False)
+
+                # Rebuild attention bias indices
+                device = m.attention_biases.device
+                k_pos = torch.stack(
+                    torch.meshgrid(
+                        torch.arange(resolution[0], device=device),
+                        torch.arange(resolution[1], device=device),
+                        indexing="ij",
+                    )
+                ).flatten(1)
+                q_pos = torch.stack(
+                    torch.meshgrid(
+                        torch.arange(0, resolution[0], step=m.stride, device=device),
+                        torch.arange(0, resolution[1], step=m.stride, device=device),
+                        indexing="ij",
+                    )
+                ).flatten(1)
+                rel_pos = (q_pos[..., :, None] - k_pos[..., None, :]).abs()
+                rel_pos = (rel_pos[0] * resolution[1]) + rel_pos[1]
+                m.attention_bias_idxs = nn.Buffer(rel_pos, persistent=False)

             old_resolution = ((old_resolution[0] - 1) // 2 + 1, (old_resolution[1] - 1) // 2 + 1)
             resolution = ((resolution[0] - 1) // 2 + 1, (resolution[1] - 1) // 2 + 1)

         elif isinstance(m, Attention):
-            # Interpolate attention biases
-            m.attention_biases = nn.Parameter(
-                interpolate_attention_bias(m.attention_biases, old_resolution, resolution)
-            )
-
-            # Rebuild attention bias indices
-            pos = torch.stack(
-                torch.meshgrid(torch.arange(resolution[0]), torch.arange(resolution[1]), indexing="ij")
-            ).flatten(1)
-            rel_pos = (pos[..., :, None] - pos[..., None, :]).abs()
-            rel_pos = (rel_pos[0] * resolution[1]) + rel_pos[1]
-            m.attention_bias_idxs = nn.Buffer(rel_pos, persistent=False)
+            with torch.no_grad():
+                # Interpolate attention biases
+                m.attention_biases = nn.Parameter(
+                    interpolate_attention_bias(m.attention_biases, old_resolution, resolution)
+                )
+
+                # Rebuild attention bias indices
+                device = m.attention_biases.device
+                pos = torch.stack(
+                    torch.meshgrid(
+                        torch.arange(resolution[0], device=device),
+                        torch.arange(resolution[1], device=device),
+                        indexing="ij",
+                    )
+                ).flatten(1)
+                rel_pos = (pos[..., :, None] - pos[..., None, :]).abs()
+                rel_pos = (rel_pos[0] * resolution[1]) + rel_pos[1]
+                m.attention_bias_idxs = nn.Buffer(rel_pos, persistent=False)


 registry.register_model_config(
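Note: a worked example (not from birder) of how the rebuilt attention_bias_idxs in the LeViT hunk maps token pairs to bias-table entries, shown for a tiny 2x2 token grid:

import torch

resolution = (2, 2)
pos = torch.stack(
    torch.meshgrid(torch.arange(resolution[0]), torch.arange(resolution[1]), indexing="ij")
).flatten(1)  # shape (2, 4): the (row, col) coordinate of each of the 4 tokens
rel_pos = (pos[..., :, None] - pos[..., None, :]).abs()  # (2, 4, 4) per-axis distances
idxs = rel_pos[0] * resolution[1] + rel_pos[1]  # flatten (|drow|, |dcol|) into one bias index
print(idxs)
# tensor([[0, 1, 2, 3],
#         [1, 0, 3, 2],
#         [2, 3, 0, 1],
#         [3, 2, 1, 0]])
# The indices cover resolution[0] * resolution[1] distinct offsets, so each token pair
# looks up one learned bias per head from the interpolated attention_biases table.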
birder/net/lit_v1.py ADDED
@@ -0,0 +1,472 @@
+"""
+LIT v1, adapted from
+https://github.com/ziplab/LIT/blob/main/classification/code_for_lit_s_m_b/models/lit.py
+
+Paper "Less is More: Pay Less Attention in Vision Transformers", https://arxiv.org/abs/2105.14217
+"""
+
+# Reference license: Apache-2.0
+
+import math
+from collections import OrderedDict
+from typing import Any
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torchvision.ops import DeformConv2d
+from torchvision.ops import Permute
+from torchvision.ops import StochasticDepth
+
+from birder.model_registry import registry
+from birder.net.base import DetectorBackbone
+
+
+def build_relative_position_index(input_resolution: tuple[int, int], device: torch.device) -> torch.Tensor:
+    coords_h = torch.arange(input_resolution[0], device=device)
+    coords_w = torch.arange(input_resolution[1], device=device)
+    coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"))
+    coords_flatten = torch.flatten(coords, 1)
+    relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
+    relative_coords = relative_coords.permute(1, 2, 0).contiguous()
+    relative_coords[:, :, 0] += input_resolution[0] - 1
+    relative_coords[:, :, 1] += input_resolution[1] - 1
+    relative_coords[:, :, 0] *= 2 * input_resolution[1] - 1
+
+    return relative_coords.sum(-1).flatten()
+
+
+def interpolate_rel_pos_bias_table(
+    rel_pos_bias_table: torch.Tensor, base_resolution: tuple[int, int], new_resolution: tuple[int, int]
+) -> torch.Tensor:
+    if new_resolution == base_resolution:
+        return rel_pos_bias_table
+
+    (base_h, base_w) = base_resolution
+    num_heads = rel_pos_bias_table.size(1)
+    orig_dtype = rel_pos_bias_table.dtype
+    bias_table = rel_pos_bias_table.float()
+    bias_table = bias_table.reshape(2 * base_h - 1, 2 * base_w - 1, num_heads).permute(2, 0, 1).unsqueeze(0)
+    bias_table = F.interpolate(
+        bias_table,
+        size=(2 * new_resolution[0] - 1, 2 * new_resolution[1] - 1),
+        mode="bicubic",
+        align_corners=False,
+    )
+    bias_table = bias_table.squeeze(0).permute(1, 2, 0).reshape(-1, num_heads)
+
+    return bias_table.to(orig_dtype)
+
+
+class MLP(nn.Module):
+    def __init__(self, in_features: int, hidden_features: int) -> None:
+        super().__init__()
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.act = nn.GELU()
+        self.fc2 = nn.Linear(hidden_features, in_features)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.fc2(x)
+
+        return x
+
+
+class MLPBlock(nn.Module):
+    def __init__(self, dim: int, mlp_ratio: float, drop_path: float) -> None:
+        super().__init__()
+        self.norm = nn.LayerNorm(dim)
+        self.mlp = MLP(dim, int(dim * mlp_ratio))
+        self.drop_path = StochasticDepth(drop_path, mode="row")
+
+    def forward(self, x: torch.Tensor, _resolution: tuple[int, int]) -> torch.Tensor:
+        return x + self.drop_path(self.mlp(self.norm(x)))
+
+
+class RelPosAttention(nn.Module):
+    def __init__(self, dim: int, input_resolution: tuple[int, int], num_heads: int) -> None:
+        super().__init__()
+        assert dim % num_heads == 0, "dim must be divisible by num_heads"
+
+        self.input_resolution = input_resolution
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        self.scale = head_dim**-0.5
+        self.dynamic_size = False
+
+        # Relative position bias table
+        bias_table = torch.zeros((2 * input_resolution[0] - 1) * (2 * input_resolution[1] - 1), num_heads)
+        self.relative_position_bias_table = nn.Parameter(bias_table)
+
+        # Get pair-wise relative position index for each token
+        relative_position_index = build_relative_position_index(input_resolution, device=bias_table.device)
+        self.relative_position_index = nn.Buffer(relative_position_index)
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=True)
+        self.proj = nn.Linear(dim, dim)
+
+        # Weight initialization
+        nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)
+
+    def _get_rel_pos_bias(self, resolution: tuple[int, int]) -> torch.Tensor:
+        if self.dynamic_size is False or resolution == self.input_resolution:
+            N = self.input_resolution[0] * self.input_resolution[1]
+            relative_position_bias = self.relative_position_bias_table[self.relative_position_index].reshape(N, N, -1)
+            relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
+            return relative_position_bias.unsqueeze(0)
+
+        bias_table = interpolate_rel_pos_bias_table(
+            self.relative_position_bias_table,
+            self.input_resolution,
+            resolution,
+        )
+        relative_position_index = build_relative_position_index(resolution, device=bias_table.device)
+        N = resolution[0] * resolution[1]
+        relative_position_bias = bias_table[relative_position_index].reshape(N, N, -1)
+        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
+
+        return relative_position_bias.unsqueeze(0)
+
+    def forward(self, x: torch.Tensor, resolution: tuple[int, int]) -> torch.Tensor:
+        (B, N, C) = x.size()
+        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+        (q, k, v) = qkv.unbind(0)
+
+        attn = (q * self.scale) @ k.transpose(-2, -1)
+        attn = attn + self._get_rel_pos_bias(resolution)
+        attn = F.softmax(attn, dim=-1)
+
+        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+        x = self.proj(x)
+
+        return x
+
+
+class MSABlock(nn.Module):
+    def __init__(
+        self, dim: int, input_resolution: tuple[int, int], num_heads: int, mlp_ratio: float, drop_path: float
+    ) -> None:
+        super().__init__()
+        self.norm1 = nn.LayerNorm(dim)
+        self.attn = RelPosAttention(dim, input_resolution, num_heads)
+        self.drop_path1 = StochasticDepth(drop_path, mode="row")
+        self.norm2 = nn.LayerNorm(dim)
+        self.mlp = MLP(dim, int(dim * mlp_ratio))
+        self.drop_path2 = StochasticDepth(drop_path, mode="row")
+
+    def set_dynamic_size(self, dynamic_size: bool = True) -> None:
+        self.attn.dynamic_size = dynamic_size
+
+    def forward(self, x: torch.Tensor, resolution: tuple[int, int]) -> torch.Tensor:
+        x = x + self.drop_path1(self.attn(self.norm1(x), resolution))
+        x = x + self.drop_path2(self.mlp(self.norm2(x)))
+
+        return x
+
+
+class DeformablePatchMerging(nn.Module):
+    def __init__(self, in_dim: int, out_dim: int) -> None:
+        super().__init__()
+        kernel_size = 2
+
+        self.offset_conv = nn.Conv2d(
+            in_dim,
+            2 * kernel_size * kernel_size,
+            kernel_size=(kernel_size, kernel_size),
+            stride=(kernel_size, kernel_size),
+            padding=(0, 0),
+            bias=True,
+        )
+        self.deform_conv = DeformConv2d(
+            in_dim,
+            out_dim,
+            kernel_size=(kernel_size, kernel_size),
+            stride=(kernel_size, kernel_size),
+            padding=(0, 0),
+            bias=True,
+        )
+        self.norm = nn.BatchNorm2d(out_dim)
+        self.act = nn.GELU()
+
+        # Initialize offsets to zero (start with regular convolution behavior)
+        nn.init.zeros_(self.offset_conv.weight)
+        nn.init.zeros_(self.offset_conv.bias)
+
+    def forward(self, x: torch.Tensor, resolution: tuple[int, int]) -> tuple[torch.Tensor, int, int]:
+        (H, W) = resolution
+        (B, _, C) = x.size()
+
+        x = x.reshape(B, H, W, C).permute(0, 3, 1, 2).contiguous()
+
+        offset = self.offset_conv(x)
+        x = self.deform_conv(x, offset)
+
+        x = self.norm(x)
+        x = self.act(x)
+
+        (B, C, H, W) = x.size()
+        x = x.permute(0, 2, 3, 1).reshape(B, H * W, C)
+
+        return (x, H, W)
+
+
+class IdentityDownsample(nn.Module):
+    def forward(self, x: torch.Tensor, resolution: tuple[int, int]) -> tuple[torch.Tensor, int, int]:
+        return (x, resolution[0], resolution[1])
+
+
+class LITStage(nn.Module):
+    def __init__(
+        self,
+        in_dim: int,
+        out_dim: int,
+        resolution: tuple[int, int],
+        depth: int,
+        num_heads: int,
+        mlp_ratio: float,
+        has_msa: bool,
+        downsample: bool,
+        drop_path: list[float],
+    ) -> None:
+        super().__init__()
+        if downsample is True:
+            self.downsample = DeformablePatchMerging(in_dim, out_dim)
+            resolution = (resolution[0] // 2, resolution[1] // 2)
+        else:
+            self.downsample = IdentityDownsample()
+
+        blocks: list[nn.Module] = []
+        for i in range(depth):
+            if has_msa is True:
+                blocks.append(MSABlock(out_dim, resolution, num_heads, mlp_ratio, drop_path[i]))
+            else:
+                blocks.append(MLPBlock(out_dim, mlp_ratio, drop_path[i]))
+
+        self.blocks = nn.ModuleList(blocks)
+
+    def set_dynamic_size(self, dynamic_size: bool = True) -> None:
+        for block in self.blocks:
+            if isinstance(block, MSABlock):
+                block.set_dynamic_size(dynamic_size)
+
+    def forward(self, x: torch.Tensor, input_resolution: tuple[int, int]) -> tuple[torch.Tensor, int, int]:
+        (x, H, W) = self.downsample(x, input_resolution)
+        for block in self.blocks:
+            x = block(x, (H, W))
+
+        return (x, H, W)
+
+
+# pylint: disable=invalid-name
+class LIT_v1(DetectorBackbone):
+    block_group_regex = r"body\.stage(\d+)\.blocks.(\d+)"
+
+    def __init__(
+        self,
+        input_channels: int,
+        num_classes: int,
+        *,
+        config: Optional[dict[str, Any]] = None,
+        size: Optional[tuple[int, int]] = None,
+    ) -> None:
+        super().__init__(input_channels, num_classes, config=config, size=size)
+        assert self.config is not None, "must set config"
+
+        patch_size = 4
+        embed_dim: int = self.config["embed_dim"]
+        depths: list[int] = self.config["depths"]
+        num_heads: list[int] = self.config["num_heads"]
+        has_msa: list[bool] = self.config["has_msa"]
+        drop_path_rate: float = self.config["drop_path_rate"]
+
+        num_stages = len(depths)
+
+        # Patch embedding
+        self.stem = nn.Sequential(
+            nn.Conv2d(
+                self.input_channels,
+                embed_dim,
+                kernel_size=(patch_size, patch_size),
+                stride=(patch_size, patch_size),
+                padding=(0, 0),
+                bias=True,
+            ),
+            Permute([0, 2, 3, 1]),
+            nn.LayerNorm(embed_dim),
+        )
+
+        # Stochastic depth
+        dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
+
+        stages: OrderedDict[str, nn.Module] = OrderedDict()
+        return_channels: list[int] = []
+        prev_dim = embed_dim
+        resolution = (self.size[0] // patch_size, self.size[1] // patch_size)
+        for i_stage in range(num_stages):
+            in_dim = prev_dim
+            out_dim = in_dim * 2 if i_stage > 0 else in_dim
+            stage = LITStage(
+                in_dim,
+                out_dim,
+                resolution,
+                depth=depths[i_stage],
+                num_heads=num_heads[i_stage],
+                mlp_ratio=4.0,
+                has_msa=has_msa[i_stage],
+                downsample=i_stage > 0,
+                drop_path=dpr[i_stage],
+            )
+            stages[f"stage{i_stage + 1}"] = stage
+
+            if i_stage > 0:
+                resolution = (resolution[0] // 2, resolution[1] // 2)
+
+            prev_dim = out_dim
+            return_channels.append(out_dim)
+
+        num_features = embed_dim * (2 ** (num_stages - 1))
+        self.body = nn.ModuleDict(stages)
+        self.features = nn.Sequential(
+            nn.LayerNorm(num_features),
+            Permute([0, 2, 1]),
+            nn.AdaptiveAvgPool1d(output_size=1),
+            nn.Flatten(1),
+        )
+        self.return_channels = return_channels
+        self.embedding_size = num_features
+        self.classifier = self.create_classifier()
+
+        self.patch_size = patch_size
+
+        # Weight initialization
+        for name, m in self.named_modules():
+            if isinstance(m, nn.Linear):
+                nn.init.trunc_normal_(m.weight, std=0.02)
+                if m.bias is not None:
+                    nn.init.zeros_(m.bias)
+            elif isinstance(m, nn.LayerNorm):
+                nn.init.ones_(m.weight)
+                nn.init.zeros_(m.bias)
+            elif isinstance(m, nn.Conv2d):
+                if name.endswith("offset_conv") is True:
+                    continue
+
+                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                fan_out //= m.groups
+                nn.init.normal_(m.weight, mean=0.0, std=math.sqrt(2.0 / fan_out))
+                if m.bias is not None:
+                    nn.init.zeros_(m.bias)
+
+    def detection_features(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
+        x = self.stem(x)
+        (B, H, W, C) = x.size()
+        x = x.reshape(B, H * W, C)
+
+        out = {}
+        for name, stage in self.body.items():
+            (x, H, W) = stage(x, (H, W))
+            if name in self.return_stages:
+                features = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
+                out[name] = features
+
+        return out
+
+    def freeze_stages(self, up_to_stage: int) -> None:
+        for param in self.stem.parameters():
+            param.requires_grad = False
+
+        for idx, stage in enumerate(self.body.values()):
+            if idx >= up_to_stage:
+                break
+
+            for param in stage.parameters():
+                param.requires_grad = False
+
+    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.stem(x)
+        (B, H, W, C) = x.size()
+        x = x.reshape(B, H * W, C)
+        for stage in self.body.values():
+            (x, H, W) = stage(x, (H, W))
+
+        return x
+
+    def embedding(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.forward_features(x)
+        return self.features(x)
+
+    def set_dynamic_size(self, dynamic_size: bool = True) -> None:
+        super().set_dynamic_size(dynamic_size)
+        for stage in self.body.values():
+            stage.set_dynamic_size(dynamic_size)
+
+    def adjust_size(self, new_size: tuple[int, int]) -> None:
+        if new_size == self.size:
+            return
+
+        super().adjust_size(new_size)
+
+        new_patches_resolution = (new_size[0] // self.patch_size, new_size[1] // self.patch_size)
+
+        (h, w) = new_patches_resolution
+        for stage in self.body.values():
+            if not isinstance(stage.downsample, IdentityDownsample):
+                h = h // 2
+                w = w // 2
+
+            out_resolution = (h, w)
+            for block in stage.blocks:
+                if isinstance(block, MSABlock):
+                    attn = block.attn
+                    if out_resolution == attn.input_resolution:
+                        continue
+
+                    with torch.no_grad():
+                        bias_table = interpolate_rel_pos_bias_table(
+                            attn.relative_position_bias_table,
+                            attn.input_resolution,
+                            out_resolution,
+                        )
+
+                    attn.input_resolution = out_resolution
+                    attn.relative_position_bias_table = nn.Parameter(bias_table)
+                    attn.relative_position_index = nn.Buffer(
+                        build_relative_position_index(out_resolution, device=bias_table.device)
+                    )
+
+
+registry.register_model_config(
+    "lit_v1_s",
+    LIT_v1,
+    config={
+        "embed_dim": 96,
+        "depths": [2, 2, 6, 2],
+        "num_heads": [3, 6, 12, 24],
+        "has_msa": [False, False, True, True],
+        "drop_path_rate": 0.1,
+    },
+)
+registry.register_model_config(
+    "lit_v1_m",
+    LIT_v1,
+    config={
+        "embed_dim": 96,
+        "depths": [2, 2, 18, 2],
+        "num_heads": [3, 6, 12, 24],
+        "has_msa": [False, False, True, True],
+        "drop_path_rate": 0.2,
+    },
+)
+registry.register_model_config(
+    "lit_v1_b",
+    LIT_v1,
+    config={
+        "embed_dim": 128,
+        "depths": [2, 2, 18, 2],
+        "num_heads": [4, 8, 16, 32],
+        "has_msa": [False, False, True, True],
+        "drop_path_rate": 0.3,
+    },
+)