birder 0.2.2__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (107)
  1. birder/common/lib.py +2 -9
  2. birder/common/training_cli.py +24 -0
  3. birder/common/training_utils.py +338 -41
  4. birder/data/collators/detection.py +11 -3
  5. birder/data/dataloader/webdataset.py +12 -2
  6. birder/data/datasets/coco.py +8 -10
  7. birder/data/transforms/detection.py +30 -13
  8. birder/inference/detection.py +108 -4
  9. birder/inference/wbf.py +226 -0
  10. birder/kernels/load_kernel.py +16 -11
  11. birder/kernels/soft_nms/soft_nms.cpp +17 -18
  12. birder/net/__init__.py +8 -0
  13. birder/net/cait.py +4 -3
  14. birder/net/convnext_v1.py +5 -0
  15. birder/net/crossformer.py +33 -30
  16. birder/net/crossvit.py +4 -3
  17. birder/net/deit.py +3 -3
  18. birder/net/deit3.py +3 -3
  19. birder/net/detection/deformable_detr.py +2 -5
  20. birder/net/detection/detr.py +2 -5
  21. birder/net/detection/efficientdet.py +67 -93
  22. birder/net/detection/fcos.py +2 -7
  23. birder/net/detection/retinanet.py +2 -7
  24. birder/net/detection/rt_detr_v1.py +2 -0
  25. birder/net/detection/yolo_anchors.py +205 -0
  26. birder/net/detection/yolo_v2.py +25 -24
  27. birder/net/detection/yolo_v3.py +39 -40
  28. birder/net/detection/yolo_v4.py +28 -26
  29. birder/net/detection/yolo_v4_tiny.py +24 -20
  30. birder/net/efficientformer_v1.py +15 -9
  31. birder/net/efficientformer_v2.py +39 -29
  32. birder/net/efficientvit_msft.py +9 -7
  33. birder/net/fasternet.py +1 -1
  34. birder/net/fastvit.py +1 -0
  35. birder/net/flexivit.py +5 -4
  36. birder/net/gc_vit.py +671 -0
  37. birder/net/hiera.py +12 -9
  38. birder/net/hornet.py +9 -7
  39. birder/net/iformer.py +8 -6
  40. birder/net/levit.py +42 -30
  41. birder/net/lit_v1.py +472 -0
  42. birder/net/lit_v1_tiny.py +357 -0
  43. birder/net/lit_v2.py +436 -0
  44. birder/net/maxvit.py +67 -55
  45. birder/net/mobilenet_v4_hybrid.py +1 -1
  46. birder/net/mobileone.py +1 -0
  47. birder/net/mvit_v2.py +13 -12
  48. birder/net/pit.py +4 -3
  49. birder/net/pvt_v1.py +4 -1
  50. birder/net/repghost.py +1 -0
  51. birder/net/repvgg.py +1 -0
  52. birder/net/repvit.py +1 -0
  53. birder/net/resnet_v1.py +1 -1
  54. birder/net/resnext.py +67 -25
  55. birder/net/rope_deit3.py +5 -3
  56. birder/net/rope_flexivit.py +7 -4
  57. birder/net/rope_vit.py +10 -5
  58. birder/net/se_resnet_v1.py +46 -0
  59. birder/net/se_resnext.py +3 -0
  60. birder/net/simple_vit.py +11 -8
  61. birder/net/swin_transformer_v1.py +71 -68
  62. birder/net/swin_transformer_v2.py +38 -31
  63. birder/net/tiny_vit.py +20 -10
  64. birder/net/transnext.py +38 -28
  65. birder/net/vit.py +5 -19
  66. birder/net/vit_parallel.py +5 -4
  67. birder/net/vit_sam.py +38 -37
  68. birder/net/vovnet_v1.py +15 -0
  69. birder/net/vovnet_v2.py +31 -1
  70. birder/ops/msda.py +108 -43
  71. birder/ops/swattention.py +124 -61
  72. birder/results/detection.py +4 -0
  73. birder/scripts/benchmark.py +110 -32
  74. birder/scripts/predict.py +8 -0
  75. birder/scripts/predict_detection.py +18 -11
  76. birder/scripts/train.py +48 -46
  77. birder/scripts/train_barlow_twins.py +44 -45
  78. birder/scripts/train_byol.py +44 -45
  79. birder/scripts/train_capi.py +50 -49
  80. birder/scripts/train_data2vec.py +45 -47
  81. birder/scripts/train_data2vec2.py +45 -47
  82. birder/scripts/train_detection.py +83 -50
  83. birder/scripts/train_dino_v1.py +60 -47
  84. birder/scripts/train_dino_v2.py +86 -52
  85. birder/scripts/train_dino_v2_dist.py +84 -50
  86. birder/scripts/train_franca.py +51 -52
  87. birder/scripts/train_i_jepa.py +45 -47
  88. birder/scripts/train_ibot.py +51 -53
  89. birder/scripts/train_kd.py +194 -76
  90. birder/scripts/train_mim.py +44 -45
  91. birder/scripts/train_mmcr.py +44 -45
  92. birder/scripts/train_rotnet.py +45 -46
  93. birder/scripts/train_simclr.py +44 -45
  94. birder/scripts/train_vicreg.py +44 -45
  95. birder/tools/auto_anchors.py +20 -1
  96. birder/tools/convert_model.py +18 -15
  97. birder/tools/det_results.py +114 -2
  98. birder/tools/pack.py +172 -103
  99. birder/tools/quantize_model.py +73 -67
  100. birder/tools/show_det_iterator.py +10 -1
  101. birder/version.py +1 -1
  102. {birder-0.2.2.dist-info → birder-0.3.0.dist-info}/METADATA +4 -3
  103. {birder-0.2.2.dist-info → birder-0.3.0.dist-info}/RECORD +107 -101
  104. {birder-0.2.2.dist-info → birder-0.3.0.dist-info}/WHEEL +0 -0
  105. {birder-0.2.2.dist-info → birder-0.3.0.dist-info}/entry_points.txt +0 -0
  106. {birder-0.2.2.dist-info → birder-0.3.0.dist-info}/licenses/LICENSE +0 -0
  107. {birder-0.2.2.dist-info → birder-0.3.0.dist-info}/top_level.txt +0 -0
birder/net/gc_vit.py ADDED
@@ -0,0 +1,671 @@
+"""
+GC ViT, adapted from
+https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/gcvit.py
+
+Paper "Global Context Vision Transformers", https://arxiv.org/abs/2206.09959
+"""
+
+# Reference license: Apache-2.0
+
+import math
+from collections import OrderedDict
+from typing import Any
+from typing import Literal
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torchvision.ops import MLP
+from torchvision.ops import StochasticDepth
+
+from birder.common.masking import mask_tensor
+from birder.layers import LayerNorm2d
+from birder.layers import LayerScale
+from birder.model_registry import registry
+from birder.net.base import DetectorBackbone
+from birder.net.base import MaskedTokenRetentionMixin
+from birder.net.base import PreTrainEncoder
+from birder.net.base import TokenRetentionResultType
+
+
+def window_partition(x: torch.Tensor, window_size: tuple[int, int]) -> torch.Tensor:
+    (B, H, W, C) = x.size()
+    x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C)
+    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C)
+
+    return windows
+
+
+def window_reverse(windows: torch.Tensor, window_size: tuple[int, int], H: int, W: int) -> torch.Tensor:
+    C = windows.size(-1)
+    x = windows.view(-1, H // window_size[0], W // window_size[1], window_size[0], window_size[1], C)
+    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C)
+
+    return x
+
+
+def build_relative_position_index(window_size: tuple[int, int], device: torch.device) -> torch.Tensor:
+    coords_h = torch.arange(window_size[0], device=device)
+    coords_w = torch.arange(window_size[1], device=device)
+    coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"))  # (2, Wh, Ww)
+    coords_flatten = torch.flatten(coords, 1)  # (2, Wh*Ww)
+    relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # (2, Wh*Ww, Wh*Ww)
+    relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # (Wh*Ww, Wh*Ww, 2)
+    relative_coords[:, :, 0] += window_size[0] - 1
+    relative_coords[:, :, 1] += window_size[1] - 1
+    relative_coords[:, :, 0] *= 2 * window_size[1] - 1
+
+    return relative_coords.sum(-1).flatten()  # (Wh*Ww*Wh*Ww,)
+
+
+def interpolate_rel_pos_bias_table(
+    rel_pos_bias_table: torch.Tensor, base_window_size: tuple[int, int], new_window_size: tuple[int, int]
+) -> torch.Tensor:
+    if new_window_size == base_window_size:
+        return rel_pos_bias_table
+
+    (base_h, base_w) = base_window_size
+    num_heads = rel_pos_bias_table.size(1)
+    orig_dtype = rel_pos_bias_table.dtype
+    bias_table = rel_pos_bias_table.float()
+    bias_table = bias_table.view(2 * base_h - 1, 2 * base_w - 1, num_heads).permute(2, 0, 1).unsqueeze(0)
+    bias_table = F.interpolate(
+        bias_table,
+        size=(2 * new_window_size[0] - 1, 2 * new_window_size[1] - 1),
+        mode="bicubic",
+        align_corners=False,
+    )
+    bias_table = bias_table.squeeze(0).permute(1, 2, 0).reshape(-1, num_heads)
+
+    return bias_table.to(orig_dtype)
+
+
+class SqueezeExcitation(nn.Module):
+    def __init__(self, in_channels: int, squeeze_channels: int) -> None:
+        super().__init__()
+        self.avg_pool = nn.AdaptiveAvgPool2d(1)
+        self.fc1 = nn.Conv2d(
+            in_channels, squeeze_channels, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=False
+        )
+        self.act = nn.GELU()
+        self.fc2 = nn.Conv2d(
+            squeeze_channels, in_channels, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=False
+        )
+        self.scale_act = nn.Sigmoid()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        scale = self.avg_pool(x)
+        scale = self.fc1(scale)
+        scale = self.act(scale)
+        scale = self.fc2(scale)
+        scale = self.scale_act(scale)
+
+        return x * scale
+
+
+class RelPosBias(nn.Module):
+    def __init__(self, window_size: tuple[int, int], num_heads: int) -> None:
+        super().__init__()
+        self.window_size = window_size
+
+        bias_table = torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads)
+        self.relative_position_bias_table = nn.Parameter(bias_table)
+        relative_position_index = build_relative_position_index(self.window_size, device=bias_table.device)
+        self.relative_position_index = nn.Buffer(relative_position_index)
+
+        # Weight initialization
+        nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)
+
+    def forward(self, window_size: tuple[int, int], dynamic_size: bool = False) -> torch.Tensor:
+        if dynamic_size is False or window_size == self.window_size:
+            N = self.window_size[0] * self.window_size[1]
+            relative_position_bias = self.relative_position_bias_table[self.relative_position_index].view(N, N, -1)
+            relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
+            return relative_position_bias.unsqueeze(0)
+
+        bias_table = interpolate_rel_pos_bias_table(
+            self.relative_position_bias_table,
+            self.window_size,
+            window_size,
+        )
+        relative_position_index = build_relative_position_index(window_size, device=bias_table.device)
+        N = window_size[0] * window_size[1]
+        relative_position_bias = bias_table[relative_position_index].view(N, N, -1)
+        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
+
+        return relative_position_bias.unsqueeze(0)
+
+
+class MBConvBlock(nn.Module):
+    def __init__(self, in_channels: int, out_channels: int) -> None:
+        super().__init__()
+        self.dw_conv = nn.Conv2d(
+            in_channels, in_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=in_channels, bias=False
+        )
+        self.act = nn.GELU()
+
+        squeeze_channels = max(1, int(in_channels * 0.25))
+        self.se = SqueezeExcitation(in_channels, squeeze_channels)
+
+        self.pw_conv = nn.Conv2d(
+            in_channels, out_channels, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=False
+        )
+
+        if in_channels == out_channels:
+            self.has_residual = True
+        else:
+            self.has_residual = False
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        residual = x
+
+        x = self.dw_conv(x)
+        x = self.act(x)
+        x = self.se(x)
+        x = self.pw_conv(x)
+
+        if self.has_residual is True:
+            x = x + residual
+
+        return x
+
+
+class FeatureBlock(nn.Module):
+    def __init__(self, dim: int, levels: int) -> None:
+        super().__init__()
+        reductions = levels
+        levels = max(1, levels)
+        layers = []
+        for _ in range(levels):
+            layers.append(MBConvBlock(dim, dim))
+            if reductions > 0:
+                layers.append(nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)))
+                reductions -= 1
+
+        self.blocks = nn.Sequential(*layers)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.blocks(x)
+
+
+class Downsample2d(nn.Module):
+    def __init__(self, in_channels: int, out_channels: int) -> None:
+        super().__init__()
+
+        self.norm1 = LayerNorm2d(in_channels)
+        self.conv = MBConvBlock(in_channels, in_channels)
+        self.reduction = nn.Conv2d(
+            in_channels, out_channels, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False
+        )
+        self.norm2 = LayerNorm2d(out_channels)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.norm1(x)
+        x = self.conv(x)
+        x = self.reduction(x)
+        x = self.norm2(x)
+
+        return x
+
+
+class Stem(nn.Module):
+    def __init__(self, in_channels: int, out_channels: int) -> None:
+        super().__init__()
+        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=True)
+        self.downsample = Downsample2d(out_channels, out_channels)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.conv(x)
+        x = self.downsample(x)
+
+        return x
+
+
+class WindowAttentionGlobal(nn.Module):
+    def __init__(self, dim: int, num_heads: int, window_size: tuple[int, int], use_global: bool) -> None:
+        super().__init__()
+        assert dim % num_heads == 0, "dim must be divisible by num_heads"
+
+        self.num_heads = num_heads
+        self.head_dim = dim // num_heads
+        self.scale = self.head_dim**-0.5
+        self.use_global = use_global
+
+        self.rel_pos = RelPosBias(window_size=window_size, num_heads=num_heads)
+        if self.use_global is True:
+            self.qkv = nn.Linear(dim, dim * 2, bias=True)
+        else:
+            self.qkv = nn.Linear(dim, dim * 3, bias=True)
+
+        self.proj = nn.Linear(dim, dim)
+
+    def forward(
+        self, x: torch.Tensor, q_global: torch.Tensor, window_size: tuple[int, int], dynamic_size: bool
+    ) -> torch.Tensor:
+        (B, N, C) = x.size()
+        if self.use_global is True:
+            kv = self.qkv(x)
+            kv = kv.reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
+            (k, v) = kv.unbind(0)
+
+            q_global = q_global.repeat(B // q_global.size(0), 1, 1, 1)
+            q = q_global.reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
+
+        else:
+            qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
+            (q, k, v) = qkv.unbind(0)
+
+        q = q * self.scale
+        attn = q @ k.transpose(-2, -1)
+        attn = attn + self.rel_pos(window_size, dynamic_size)
+        attn = attn.softmax(dim=-1)
+
+        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+        x = self.proj(x)
+
+        return x
+
+
+class GlobalContextVitBlock(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int,
+        window_size: tuple[int, int],
+        mlp_ratio: float,
+        use_global: bool,
+        layer_scale: Optional[float],
+        drop_path: float,
+    ) -> None:
+        super().__init__()
+        self.norm1 = nn.LayerNorm(dim)
+        self.attn = WindowAttentionGlobal(dim, num_heads, window_size, use_global)
+        if layer_scale is not None:
+            self.ls1 = LayerScale(dim, layer_scale)
+        else:
+            self.ls1 = nn.Identity()
+
+        self.drop_path1 = StochasticDepth(drop_path, mode="row")
+
+        self.norm2 = nn.LayerNorm(dim)
+        self.mlp = MLP(dim, [int(dim * mlp_ratio), dim], activation_layer=nn.GELU, inplace=None)
+        if layer_scale is not None:
+            self.ls2 = LayerScale(dim, layer_scale)
+        else:
+            self.ls2 = nn.Identity()
+
+        self.drop_path2 = StochasticDepth(drop_path, mode="row")
+        self.dynamic_size = False
+
+    def set_dynamic_size(self, dynamic_size: bool = True) -> None:
+        self.dynamic_size = dynamic_size
+
+    def _window_attn(self, x: torch.Tensor, q_global: torch.Tensor, window_size: tuple[int, int]) -> torch.Tensor:
+        (_, H, W, C) = x.size()
+
+        # Pad feature maps to multiples of window size for dynamic size support
+        pad_b = (window_size[0] - H % window_size[0]) % window_size[0]
+        pad_r = (window_size[1] - W % window_size[1]) % window_size[1]
+        x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))
+
+        # Resize global query to match window size if needed
+        (_, h_g, w_g, _) = q_global.size()
+        if h_g != window_size[0] or w_g != window_size[1]:
+            q_global = q_global.permute(0, 3, 1, 2)
+            q_global = F.interpolate(q_global, size=window_size, mode="bilinear", align_corners=False)
+            q_global = q_global.permute(0, 2, 3, 1)
+
+        (_, pad_h, pad_w, _) = x.size()
+        x_win = window_partition(x, window_size)
+        x_win = x_win.view(-1, window_size[0] * window_size[1], C)
+        attn_win = self.attn(x_win, q_global, window_size, self.dynamic_size)
+        x = window_reverse(attn_win, window_size, pad_h, pad_w)
+
+        # Unpad features
+        x = x[:, :H, :W, :].contiguous()
+
+        return x
+
+    def forward(self, x: torch.Tensor, q_global: torch.Tensor, window_size: tuple[int, int]) -> torch.Tensor:
+        x = x + self.drop_path1(self.ls1(self._window_attn(self.norm1(x), q_global, window_size)))
+        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
+
+        return x
+
+
+class GlobalContextVitStage(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        depth: int,
+        num_heads: int,
+        feat_size: tuple[int, int],
+        window_size: tuple[int, int],
+        downsample: bool,
+        mlp_ratio: float,
+        layer_scale: Optional[float],
+        stage_norm: bool,
+        drop_path: list[float],
+    ) -> None:
+        super().__init__()
+        if downsample is True:
+            self.downsample = Downsample2d(dim, dim * 2)
+            dim = dim * 2
+            feat_size = (math.ceil(feat_size[0] / 2), math.ceil(feat_size[1] / 2))
+        else:
+            self.downsample = nn.Identity()
+
+        self.window_size = window_size
+        self.window_ratio = (max(1, feat_size[0] // window_size[0]), max(1, feat_size[1] // window_size[1]))
+        self.dynamic_size = False
+
+        feat_levels = int(math.log2(min(feat_size) / min(window_size)))
+        self.global_block = FeatureBlock(dim, feat_levels)
+
+        self.blocks = nn.ModuleList(
+            [
+                GlobalContextVitBlock(
+                    dim=dim,
+                    num_heads=num_heads,
+                    window_size=window_size,
+                    mlp_ratio=mlp_ratio,
+                    use_global=(idx % 2 != 0),
+                    layer_scale=layer_scale,
+                    drop_path=drop_path[idx],
+                )
+                for idx in range(depth)
+            ]
+        )
+        if stage_norm is True:
+            self.norm = nn.LayerNorm(dim)
+        else:
+            self.norm = nn.Identity()
+
+    def _get_window_size(self, feat_size: tuple[int, int]) -> tuple[int, int]:
+        if self.dynamic_size is False:
+            return self.window_size
+
+        window_h = max(1, feat_size[0] // self.window_ratio[0])
+        window_w = max(1, feat_size[1] // self.window_ratio[1])
+        return (window_h, window_w)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.downsample(x)
+        window_size = self._get_window_size((x.size(2), x.size(3)))
+        global_query = self.global_block(x)
+
+        x = x.permute(0, 2, 3, 1)
+        global_query = global_query.permute(0, 2, 3, 1)
+        for blk in self.blocks:
+            x = blk(x, global_query, window_size)
+
+        x = self.norm(x)
+        x = x.permute(0, 3, 1, 2).contiguous()
+
+        return x
+
+
+# pylint: disable=invalid-name
+class GC_ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
+    block_group_regex = r"body\.stage(\d+)\.blocks\.(\d+)"
+
+    def __init__(
+        self,
+        input_channels: int,
+        num_classes: int,
+        *,
+        config: Optional[dict[str, Any]] = None,
+        size: Optional[tuple[int, int]] = None,
+    ) -> None:
+        super().__init__(input_channels, num_classes, config=config, size=size)
+        assert self.config is not None, "must set config"
+
+        depths: list[int] = self.config["depths"]
+        num_heads: list[int] = self.config["num_heads"]
+        window_ratio: list[int] = self.config["window_ratio"]
+        embed_dim: int = self.config["embed_dim"]
+        mlp_ratio: float = self.config["mlp_ratio"]
+        layer_scale: Optional[float] = self.config["layer_scale"]
+        drop_path_rate: float = self.config["drop_path_rate"]
+
+        self.window_ratio = window_ratio
+        num_stages = len(depths)
+        img_size = self.size
+
+        # Calculate window sizes from window ratios
+        window_sizes = []
+        for r in window_ratio:
+            window_sizes.append((max(1, img_size[0] // r), max(1, img_size[1] // r)))
+
+        self.stem = Stem(self.input_channels, embed_dim)
+
+        feat_size = (math.ceil(img_size[0] / 4), math.ceil(img_size[1] / 4))
+        dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
+
+        in_dim = embed_dim
+        stages: OrderedDict[str, nn.Module] = OrderedDict()
+        return_channels: list[int] = []
+        for idx in range(num_stages):
+            stage = GlobalContextVitStage(
+                dim=in_dim,
+                depth=depths[idx],
+                num_heads=num_heads[idx],
+                feat_size=feat_size,
+                window_size=window_sizes[idx],
+                downsample=idx > 0,
+                mlp_ratio=mlp_ratio,
+                layer_scale=layer_scale,
+                stage_norm=(idx == num_stages - 1),
+                drop_path=dpr[idx],
+            )
+
+            stages[f"stage{idx + 1}"] = stage
+            if idx > 0:
+                in_dim = in_dim * 2
+                feat_size = (math.ceil(feat_size[0] / 2), math.ceil(feat_size[1] / 2))
+
+            return_channels.append(in_dim)
+
+        self.body = nn.Sequential(stages)
+        self.features = nn.Sequential(
+            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
+            nn.Flatten(1),
+        )
+        self.return_channels = return_channels
+        self.embedding_size = return_channels[-1]
+        self.classifier = self.create_classifier()
+
+        self.stem_stride = 4
+        self.stem_width = embed_dim
+        self.encoding_size = return_channels[-1]
+
+        # Weight initialization
+        for m in self.modules():
+            if isinstance(m, nn.Linear):
+                nn.init.trunc_normal_(m.weight, std=0.02)
+                if m.bias is not None:
+                    nn.init.zeros_(m.bias)
+
+    def detection_features(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
+        x = self.stem(x)
+
+        out = {}
+        for name, module in self.body.named_children():
+            x = module(x)
+            if name in self.return_stages:
+                out[name] = x
+
+        return out
+
+    def freeze_stages(self, up_to_stage: int) -> None:
+        for param in self.stem.parameters():
+            param.requires_grad = False
+
+        for idx, module in enumerate(self.body.children()):
+            if idx >= up_to_stage:
+                break
+
+            for param in module.parameters():
+                param.requires_grad = False
+
+    def set_dynamic_size(self, dynamic_size: bool = True) -> None:
+        super().set_dynamic_size(dynamic_size)
+        for stage in self.body.children():
+            if isinstance(stage, GlobalContextVitStage):
+                stage.dynamic_size = dynamic_size
+                for block in stage.blocks:
+                    block.set_dynamic_size(dynamic_size)
+
+    def masked_encoding_retention(
+        self,
+        x: torch.Tensor,
+        mask: torch.Tensor,
+        mask_token: Optional[torch.Tensor] = None,
+        return_keys: Literal["all", "features", "embedding"] = "features",
+    ) -> TokenRetentionResultType:
+        x = self.stem(x)
+        x = mask_tensor(x, mask, patch_factor=self.max_stride // self.stem_stride, mask_token=mask_token)
+        x = self.body(x)
+
+        result: TokenRetentionResultType = {}
+        if return_keys in ("all", "features"):
+            result["features"] = x
+        if return_keys in ("all", "embedding"):
+            result["embedding"] = self.features(x)
+
+        return result
+
+    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.stem(x)
+        return self.body(x)
+
+    def embedding(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.forward_features(x)
+        return self.features(x)
+
+    def adjust_size(self, new_size: tuple[int, int]) -> None:
+        if new_size == self.size:
+            return
+
+        super().adjust_size(new_size)
+
+        new_window_sizes = []
+        for r in self.window_ratio:
+            new_window_sizes.append((max(1, new_size[0] // r), max(1, new_size[1] // r)))
+
+        feat_size = (math.ceil(new_size[0] / self.stem_stride), math.ceil(new_size[1] / self.stem_stride))
+        stage_idx = 0
+        for stage in self.body.children():
+            if isinstance(stage, GlobalContextVitStage):
+                new_window_size = new_window_sizes[stage_idx]
+                if isinstance(stage.downsample, nn.Identity):
+                    stage_feat_size = feat_size
+                else:
+                    stage_feat_size = (math.ceil(feat_size[0] / 2), math.ceil(feat_size[1] / 2))
+
+                stage.window_size = new_window_size
+                stage.window_ratio = (
+                    max(1, stage_feat_size[0] // new_window_size[0]),
+                    max(1, stage_feat_size[1] // new_window_size[1]),
+                )
+                for block in stage.blocks:
+                    rel_pos = block.attn.rel_pos
+                    if new_window_size == rel_pos.window_size:
+                        continue
+
+                    with torch.no_grad():
+                        bias_table = interpolate_rel_pos_bias_table(
+                            rel_pos.relative_position_bias_table,
+                            rel_pos.window_size,
+                            new_window_size,
+                        )
+
+                    rel_pos.window_size = new_window_size
+                    rel_pos.relative_position_bias_table = nn.Parameter(bias_table)
+                    rel_pos.relative_position_index = nn.Buffer(
+                        build_relative_position_index(new_window_size, device=bias_table.device)
+                    )
+
+                feat_size = stage_feat_size
+                stage_idx += 1
+
+
+registry.register_model_config(
+    "gc_vit_xxt",
+    GC_ViT,
+    config={
+        "depths": [2, 2, 6, 2],
+        "num_heads": [2, 4, 8, 16],
+        "window_ratio": [32, 32, 16, 32],
+        "embed_dim": 64,
+        "mlp_ratio": 3.0,
+        "layer_scale": None,
+        "drop_path_rate": 0.2,
+    },
+)
+registry.register_model_config(
+    "gc_vit_xt",
+    GC_ViT,
+    config={
+        "depths": [3, 4, 6, 5],
+        "num_heads": [2, 4, 8, 16],
+        "window_ratio": [32, 32, 16, 32],
+        "embed_dim": 64,
+        "mlp_ratio": 3.0,
+        "layer_scale": None,
+        "drop_path_rate": 0.2,
+    },
+)
+registry.register_model_config(
+    "gc_vit_t",
+    GC_ViT,
+    config={
+        "depths": [3, 4, 19, 5],
+        "num_heads": [2, 4, 8, 16],
+        "window_ratio": [32, 32, 16, 32],
+        "embed_dim": 64,
+        "mlp_ratio": 3.0,
+        "layer_scale": None,
+        "drop_path_rate": 0.2,
+    },
+)
+registry.register_model_config(
+    "gc_vit_s",
+    GC_ViT,
+    config={
+        "depths": [3, 4, 19, 5],
+        "num_heads": [3, 6, 12, 24],
+        "window_ratio": [32, 32, 16, 32],
+        "embed_dim": 96,
+        "mlp_ratio": 2.0,
+        "layer_scale": 1e-5,
+        "drop_path_rate": 0.3,
+    },
+)
+registry.register_model_config(
+    "gc_vit_b",
+    GC_ViT,
+    config={
+        "depths": [3, 4, 19, 5],
+        "num_heads": [4, 8, 16, 32],
+        "window_ratio": [32, 32, 16, 32],
+        "embed_dim": 128,
+        "mlp_ratio": 2.0,
+        "layer_scale": 1e-5,
+        "drop_path_rate": 0.5,
+    },
+)
+registry.register_model_config(
+    "gc_vit_l",
+    GC_ViT,
+    config={
+        "depths": [3, 4, 19, 5],
+        "num_heads": [6, 12, 24, 48],
+        "window_ratio": [32, 32, 16, 32],
+        "embed_dim": 192,
+        "mlp_ratio": 2.0,
+        "layer_scale": 1e-5,
+        "drop_path_rate": 0.5,
+    },
+)
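
For orientation, a forward pass through the new GC_ViT backbone might look like the sketch below. The input resolution, class count and expected output shapes are illustrative assumptions (the DetectorBackbone/PreTrainEncoder plumbing the class relies on is defined elsewhere in the package), not values taken from this diff.

# Usage sketch for the new GC_ViT backbone (assumed values, not part of the package)
import torch

from birder.net.gc_vit import GC_ViT

config = {  # same values as the "gc_vit_xxt" registration above
    "depths": [2, 2, 6, 2],
    "num_heads": [2, 4, 8, 16],
    "window_ratio": [32, 32, 16, 32],
    "embed_dim": 64,
    "mlp_ratio": 3.0,
    "layer_scale": None,
    "drop_path_rate": 0.2,
}

net = GC_ViT(input_channels=3, num_classes=10, config=config, size=(256, 256))
net.eval()

x = torch.rand(1, 3, 256, 256)
with torch.no_grad():
    embedding = net.embedding(x)  # pooled features, (1, 512) for this config
    features = net.detection_features(x)  # per-stage feature maps, keyed by stage name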
birder/net/hiera.py CHANGED
@@ -612,23 +612,26 @@ class Hiera(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin):
 
         if self.pos_embed_win is not None:
             global_pos_size = (new_size[0] // 2**4, new_size[1] // 2**4)
-            pos_embed = F.interpolate(
-                self.pos_embed,
-                size=global_pos_size,
-                mode="bicubic",
-                antialias=True,
-            )
+            with torch.no_grad():
+                pos_embed = F.interpolate(
+                    self.pos_embed,
+                    size=global_pos_size,
+                    mode="bicubic",
+                    antialias=True,
+                )
+
             self.pos_embed = nn.Parameter(pos_embed)
 
         else:
-            self.pos_embed = nn.Parameter(
-                adjust_position_embedding(
+            with torch.no_grad():
+                pos_embed = adjust_position_embedding(
                     self.pos_embed,
                     (old_size[0] // self.patch_stride[0], old_size[1] // self.patch_stride[1]),
                     (new_size[0] // self.patch_stride[0], new_size[1] // self.patch_stride[1]),
                     0,
                 )
-            )
+
+            self.pos_embed = nn.Parameter(pos_embed)
 
         # Re-init vars
         self.tokens_spatial_shape = [i // s for i, s in zip(new_size, self.patch_stride)]
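
The hiera.py hunk applies the same discipline now used in gc_vit.py's adjust_size: interpolate the stored embedding under torch.no_grad() and only then re-register the result as a parameter, so the resize never becomes part of an autograd graph. A generic sketch of that pattern, with placeholder module and attribute names rather than the birder API, and assuming a (1, C, H, W) embedding:

import torch
import torch.nn.functional as F
from torch import nn


def resize_pos_embed(module: nn.Module, new_hw: tuple[int, int]) -> None:
    # One-off surgery on the weights: interpolate outside autograd,
    # then wrap the result as a fresh Parameter on the module.
    with torch.no_grad():
        pos_embed = F.interpolate(module.pos_embed, size=new_hw, mode="bicubic", antialias=True)

    module.pos_embed = nn.Parameter(pos_embed)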