birder 0.3.0__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169)
  1. birder/adversarial/deepfool.py +2 -0
  2. birder/adversarial/simba.py +2 -0
  3. birder/common/fs_ops.py +2 -2
  4. birder/common/masking.py +13 -4
  5. birder/common/training_cli.py +6 -1
  6. birder/common/training_utils.py +4 -2
  7. birder/inference/classification.py +1 -1
  8. birder/introspection/__init__.py +2 -0
  9. birder/introspection/base.py +0 -7
  10. birder/introspection/feature_pca.py +101 -0
  11. birder/kernels/soft_nms/soft_nms.cpp +5 -2
  12. birder/model_registry/model_registry.py +3 -2
  13. birder/net/base.py +3 -3
  14. birder/net/biformer.py +2 -2
  15. birder/net/cas_vit.py +6 -6
  16. birder/net/coat.py +8 -8
  17. birder/net/conv2former.py +2 -2
  18. birder/net/convnext_v1.py +22 -2
  19. birder/net/convnext_v2.py +2 -2
  20. birder/net/crossformer.py +2 -2
  21. birder/net/cspnet.py +2 -2
  22. birder/net/cswin_transformer.py +2 -2
  23. birder/net/darknet.py +2 -2
  24. birder/net/davit.py +2 -2
  25. birder/net/deit.py +3 -3
  26. birder/net/deit3.py +3 -3
  27. birder/net/densenet.py +2 -2
  28. birder/net/detection/deformable_detr.py +2 -2
  29. birder/net/detection/detr.py +2 -2
  30. birder/net/detection/efficientdet.py +2 -2
  31. birder/net/detection/faster_rcnn.py +2 -2
  32. birder/net/detection/fcos.py +2 -2
  33. birder/net/detection/retinanet.py +2 -2
  34. birder/net/detection/rt_detr_v1.py +4 -4
  35. birder/net/detection/ssd.py +2 -2
  36. birder/net/detection/ssdlite.py +2 -2
  37. birder/net/detection/yolo_v2.py +2 -2
  38. birder/net/detection/yolo_v3.py +2 -2
  39. birder/net/detection/yolo_v4.py +2 -2
  40. birder/net/edgenext.py +2 -2
  41. birder/net/edgevit.py +1 -1
  42. birder/net/efficientformer_v1.py +4 -4
  43. birder/net/efficientformer_v2.py +6 -6
  44. birder/net/efficientnet_lite.py +2 -2
  45. birder/net/efficientnet_v1.py +2 -2
  46. birder/net/efficientnet_v2.py +2 -2
  47. birder/net/efficientvim.py +3 -3
  48. birder/net/efficientvit_mit.py +2 -2
  49. birder/net/efficientvit_msft.py +2 -2
  50. birder/net/fasternet.py +2 -2
  51. birder/net/fastvit.py +2 -3
  52. birder/net/flexivit.py +11 -6
  53. birder/net/focalnet.py +2 -3
  54. birder/net/gc_vit.py +17 -2
  55. birder/net/ghostnet_v1.py +2 -2
  56. birder/net/ghostnet_v2.py +2 -2
  57. birder/net/groupmixformer.py +2 -2
  58. birder/net/hgnet_v1.py +2 -2
  59. birder/net/hgnet_v2.py +2 -2
  60. birder/net/hiera.py +2 -2
  61. birder/net/hieradet.py +2 -2
  62. birder/net/hornet.py +2 -2
  63. birder/net/iformer.py +2 -2
  64. birder/net/inception_next.py +2 -2
  65. birder/net/inception_resnet_v1.py +2 -2
  66. birder/net/inception_resnet_v2.py +2 -2
  67. birder/net/inception_v3.py +2 -2
  68. birder/net/inception_v4.py +2 -2
  69. birder/net/levit.py +4 -4
  70. birder/net/lit_v1.py +2 -2
  71. birder/net/lit_v1_tiny.py +2 -2
  72. birder/net/lit_v2.py +2 -2
  73. birder/net/maxvit.py +2 -2
  74. birder/net/metaformer.py +2 -2
  75. birder/net/mnasnet.py +2 -2
  76. birder/net/mobilenet_v1.py +2 -2
  77. birder/net/mobilenet_v2.py +2 -2
  78. birder/net/mobilenet_v3_large.py +2 -2
  79. birder/net/mobilenet_v4.py +2 -2
  80. birder/net/mobilenet_v4_hybrid.py +2 -2
  81. birder/net/mobileone.py +2 -2
  82. birder/net/mobilevit_v2.py +2 -2
  83. birder/net/moganet.py +2 -2
  84. birder/net/mvit_v2.py +2 -2
  85. birder/net/nextvit.py +2 -2
  86. birder/net/nfnet.py +2 -2
  87. birder/net/pit.py +6 -6
  88. birder/net/pvt_v1.py +2 -2
  89. birder/net/pvt_v2.py +2 -2
  90. birder/net/rdnet.py +2 -2
  91. birder/net/regionvit.py +6 -6
  92. birder/net/regnet.py +2 -2
  93. birder/net/regnet_z.py +2 -2
  94. birder/net/repghost.py +2 -2
  95. birder/net/repvgg.py +2 -2
  96. birder/net/repvit.py +6 -6
  97. birder/net/resnest.py +2 -2
  98. birder/net/resnet_v1.py +2 -2
  99. birder/net/resnet_v2.py +2 -2
  100. birder/net/resnext.py +2 -2
  101. birder/net/rope_deit3.py +3 -3
  102. birder/net/rope_flexivit.py +13 -6
  103. birder/net/rope_vit.py +69 -10
  104. birder/net/shufflenet_v1.py +2 -2
  105. birder/net/shufflenet_v2.py +2 -2
  106. birder/net/smt.py +1 -2
  107. birder/net/squeezenext.py +2 -2
  108. birder/net/ssl/byol.py +3 -2
  109. birder/net/ssl/capi.py +156 -11
  110. birder/net/ssl/data2vec.py +3 -1
  111. birder/net/ssl/data2vec2.py +3 -1
  112. birder/net/ssl/dino_v1.py +1 -1
  113. birder/net/ssl/dino_v2.py +140 -18
  114. birder/net/ssl/franca.py +145 -13
  115. birder/net/ssl/ibot.py +1 -2
  116. birder/net/ssl/mmcr.py +3 -1
  117. birder/net/starnet.py +2 -2
  118. birder/net/swiftformer.py +6 -6
  119. birder/net/swin_transformer_v1.py +2 -2
  120. birder/net/swin_transformer_v2.py +2 -2
  121. birder/net/tiny_vit.py +2 -2
  122. birder/net/transnext.py +1 -1
  123. birder/net/uniformer.py +1 -1
  124. birder/net/van.py +1 -1
  125. birder/net/vgg.py +1 -1
  126. birder/net/vgg_reduced.py +1 -1
  127. birder/net/vit.py +172 -8
  128. birder/net/vit_parallel.py +5 -5
  129. birder/net/vit_sam.py +3 -3
  130. birder/net/vovnet_v1.py +2 -2
  131. birder/net/vovnet_v2.py +2 -2
  132. birder/net/wide_resnet.py +2 -2
  133. birder/net/xception.py +2 -2
  134. birder/net/xcit.py +2 -2
  135. birder/results/detection.py +104 -0
  136. birder/results/gui.py +10 -8
  137. birder/scripts/benchmark.py +1 -1
  138. birder/scripts/train.py +13 -18
  139. birder/scripts/train_barlow_twins.py +10 -14
  140. birder/scripts/train_byol.py +11 -15
  141. birder/scripts/train_capi.py +38 -17
  142. birder/scripts/train_data2vec.py +11 -15
  143. birder/scripts/train_data2vec2.py +13 -17
  144. birder/scripts/train_detection.py +11 -14
  145. birder/scripts/train_dino_v1.py +20 -22
  146. birder/scripts/train_dino_v2.py +126 -63
  147. birder/scripts/train_dino_v2_dist.py +127 -64
  148. birder/scripts/train_franca.py +49 -34
  149. birder/scripts/train_i_jepa.py +11 -14
  150. birder/scripts/train_ibot.py +16 -18
  151. birder/scripts/train_kd.py +14 -20
  152. birder/scripts/train_mim.py +10 -13
  153. birder/scripts/train_mmcr.py +11 -15
  154. birder/scripts/train_rotnet.py +12 -16
  155. birder/scripts/train_simclr.py +10 -14
  156. birder/scripts/train_vicreg.py +10 -14
  157. birder/tools/avg_model.py +24 -8
  158. birder/tools/det_results.py +91 -0
  159. birder/tools/introspection.py +35 -9
  160. birder/tools/results.py +11 -7
  161. birder/tools/show_iterator.py +1 -1
  162. birder/version.py +1 -1
  163. {birder-0.3.0.dist-info → birder-0.3.2.dist-info}/METADATA +1 -1
  164. birder-0.3.2.dist-info/RECORD +299 -0
  165. birder-0.3.0.dist-info/RECORD +0 -298
  166. {birder-0.3.0.dist-info → birder-0.3.2.dist-info}/WHEEL +0 -0
  167. {birder-0.3.0.dist-info → birder-0.3.2.dist-info}/entry_points.txt +0 -0
  168. {birder-0.3.0.dist-info → birder-0.3.2.dist-info}/licenses/LICENSE +0 -0
  169. {birder-0.3.0.dist-info → birder-0.3.2.dist-info}/top_level.txt +0 -0
birder/net/regionvit.py CHANGED
@@ -464,14 +464,14 @@ class RegionViT(DetectorBackbone):
464
464
 
465
465
  def freeze(self, freeze_classifier: bool = True, unfreeze_features: bool = False) -> None:
466
466
  for param in self.parameters():
467
- param.requires_grad = False
467
+ param.requires_grad_(False)
468
468
 
469
469
  if freeze_classifier is False:
470
470
  for param in self.classifier.parameters():
471
- param.requires_grad = True
471
+ param.requires_grad_(True)
472
472
  if unfreeze_features is True:
473
473
  for param in self.norm.parameters():
474
- param.requires_grad = True
474
+ param.requires_grad_(True)
475
475
 
476
476
  def detection_features(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
477
477
  o_x = x
@@ -488,16 +488,16 @@ class RegionViT(DetectorBackbone):
488
488
 
489
489
  def freeze_stages(self, up_to_stage: int) -> None:
490
490
  for param in self.patch_embed.parameters():
491
- param.requires_grad = False
491
+ param.requires_grad_(False)
492
492
  for param in self.cls_token.parameters():
493
- param.requires_grad = False
493
+ param.requires_grad_(False)
494
494
 
495
495
  for idx, module in enumerate(self.body.children()):
496
496
  if idx >= up_to_stage:
497
497
  break
498
498
 
499
499
  for param in module.parameters():
500
- param.requires_grad = False
500
+ param.requires_grad_(False)
501
501
 
502
502
  def forward_features(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
503
503
  o_x = x
birder/net/regnet.py CHANGED
@@ -364,14 +364,14 @@ class RegNet(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
364
364
 
365
365
  def freeze_stages(self, up_to_stage: int) -> None:
366
366
  for param in self.stem.parameters():
367
- param.requires_grad = False
367
+ param.requires_grad_(False)
368
368
 
369
369
  for idx, module in enumerate(self.body.children()):
370
370
  if idx >= up_to_stage:
371
371
  break
372
372
 
373
373
  for param in module.parameters():
374
- param.requires_grad = False
374
+ param.requires_grad_(False)
375
375
 
376
376
  def masked_encoding_retention(
377
377
  self,
birder/net/regnet_z.py CHANGED
@@ -210,14 +210,14 @@ class RegNet_Z(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
210
210
 
211
211
  def freeze_stages(self, up_to_stage: int) -> None:
212
212
  for param in self.stem.parameters():
213
- param.requires_grad = False
213
+ param.requires_grad_(False)
214
214
 
215
215
  for idx, module in enumerate(self.body.children()):
216
216
  if idx >= up_to_stage:
217
217
  break
218
218
 
219
219
  for param in module.parameters():
220
- param.requires_grad = False
220
+ param.requires_grad_(False)
221
221
 
222
222
  def masked_encoding_retention(
223
223
  self,
birder/net/repghost.py CHANGED
@@ -321,14 +321,14 @@ class RepGhost(DetectorBackbone):
321
321
 
322
322
  def freeze_stages(self, up_to_stage: int) -> None:
323
323
  for param in self.stem.parameters():
324
- param.requires_grad = False
324
+ param.requires_grad_(False)
325
325
 
326
326
  for idx, module in enumerate(self.body.children()):
327
327
  if idx >= up_to_stage:
328
328
  break
329
329
 
330
330
  for param in module.parameters():
331
- param.requires_grad = False
331
+ param.requires_grad_(False)
332
332
 
333
333
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
334
334
  x = self.stem(x)
birder/net/repvgg.py CHANGED
@@ -302,14 +302,14 @@ class RepVgg(DetectorBackbone):
302
302
 
303
303
  def freeze_stages(self, up_to_stage: int) -> None:
304
304
  for param in self.stem.parameters():
305
- param.requires_grad = False
305
+ param.requires_grad_(False)
306
306
 
307
307
  for idx, module in enumerate(self.body.children()):
308
308
  if idx >= up_to_stage:
309
309
  break
310
310
 
311
311
  for param in module.parameters():
312
- param.requires_grad = False
312
+ param.requires_grad_(False)
313
313
 
314
314
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
315
315
  x = self.stem(x)
birder/net/repvit.py CHANGED
@@ -399,18 +399,18 @@ class RepViT(DetectorBackbone):
399
399
 
400
400
  def freeze(self, freeze_classifier: bool = True, unfreeze_features: bool = False) -> None:
401
401
  for param in self.parameters():
402
- param.requires_grad = False
402
+ param.requires_grad_(False)
403
403
 
404
404
  if freeze_classifier is False:
405
405
  for param in self.classifier.parameters():
406
- param.requires_grad = True
406
+ param.requires_grad_(True)
407
407
 
408
408
  for param in self.dist_classifier.parameters():
409
- param.requires_grad = True
409
+ param.requires_grad_(True)
410
410
 
411
411
  if unfreeze_features is True:
412
412
  for param in self.features.parameters():
413
- param.requires_grad = True
413
+ param.requires_grad_(True)
414
414
 
415
415
  def transform_to_backbone(self) -> None:
416
416
  self.features = nn.Identity()
@@ -430,14 +430,14 @@ class RepViT(DetectorBackbone):
430
430
 
431
431
  def freeze_stages(self, up_to_stage: int) -> None:
432
432
  for param in self.stem.parameters():
433
- param.requires_grad = False
433
+ param.requires_grad_(False)
434
434
 
435
435
  for idx, module in enumerate(self.body.children()):
436
436
  if idx >= up_to_stage:
437
437
  break
438
438
 
439
439
  for param in module.parameters():
440
- param.requires_grad = False
440
+ param.requires_grad_(False)
441
441
 
442
442
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
443
443
  x = self.stem(x)
birder/net/resnest.py CHANGED
@@ -271,14 +271,14 @@ class ResNeSt(DetectorBackbone):
271
271
 
272
272
  def freeze_stages(self, up_to_stage: int) -> None:
273
273
  for param in self.stem.parameters():
274
- param.requires_grad = False
274
+ param.requires_grad_(False)
275
275
 
276
276
  for idx, module in enumerate(self.body.children()):
277
277
  if idx >= up_to_stage:
278
278
  break
279
279
 
280
280
  for param in module.parameters():
281
- param.requires_grad = False
281
+ param.requires_grad_(False)
282
282
 
283
283
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
284
284
  x = self.stem(x)
birder/net/resnet_v1.py CHANGED
@@ -192,14 +192,14 @@ class ResNet_v1(DetectorBackbone):
192
192
 
193
193
  def freeze_stages(self, up_to_stage: int) -> None:
194
194
  for param in self.stem.parameters():
195
- param.requires_grad = False
195
+ param.requires_grad_(False)
196
196
 
197
197
  for idx, module in enumerate(self.body.children()):
198
198
  if idx >= up_to_stage:
199
199
  break
200
200
 
201
201
  for param in module.parameters():
202
- param.requires_grad = False
202
+ param.requires_grad_(False)
203
203
 
204
204
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
205
205
  x = self.stem(x)
birder/net/resnet_v2.py CHANGED
@@ -178,14 +178,14 @@ class ResNet_v2(DetectorBackbone):
178
178
 
179
179
  def freeze_stages(self, up_to_stage: int) -> None:
180
180
  for param in self.stem.parameters():
181
- param.requires_grad = False
181
+ param.requires_grad_(False)
182
182
 
183
183
  for idx, module in enumerate(self.body.children()):
184
184
  if idx >= up_to_stage:
185
185
  break
186
186
 
187
187
  for param in module.parameters():
188
- param.requires_grad = False
188
+ param.requires_grad_(False)
189
189
 
190
190
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
191
191
  x = self.stem(x)
birder/net/resnext.py CHANGED
@@ -216,14 +216,14 @@ class ResNeXt(DetectorBackbone):
216
216
 
217
217
  def freeze_stages(self, up_to_stage: int) -> None:
218
218
  for param in self.stem.parameters():
219
- param.requires_grad = False
219
+ param.requires_grad_(False)
220
220
 
221
221
  for idx, module in enumerate(self.body.children()):
222
222
  if idx >= up_to_stage:
223
223
  break
224
224
 
225
225
  for param in module.parameters():
226
- param.requires_grad = False
226
+ param.requires_grad_(False)
227
227
 
228
228
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
229
229
  x = self.stem(x)
birder/net/rope_deit3.py CHANGED
@@ -245,16 +245,16 @@ class RoPE_DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Ma
245
245
 
246
246
  def freeze_stages(self, up_to_stage: int) -> None:
247
247
  for param in self.conv_proj.parameters():
248
- param.requires_grad = False
248
+ param.requires_grad_(False)
249
249
 
250
- self.pos_embedding.requires_grad = False
250
+ self.pos_embedding.requires_grad_(False)
251
251
 
252
252
  for idx, module in enumerate(self.encoder.children()):
253
253
  if idx >= up_to_stage:
254
254
  break
255
255
 
256
256
  for param in module.parameters():
257
- param.requires_grad = False
257
+ param.requires_grad_(False)
258
258
 
259
259
  def set_causal_attention(self, is_causal: bool = True) -> None:
260
260
  self.encoder.set_causal_attention(is_causal)
@@ -69,6 +69,8 @@ class RoPE_FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin,
69
69
  layer_scale_init_value: Optional[float] = self.config.get("layer_scale_init_value", None)
70
70
  pre_norm: bool = self.config.get("pre_norm", False)
71
71
  post_norm: bool = self.config.get("post_norm", True)
72
+ qkv_bias: bool = self.config.get("qkv_bias", True)
73
+ qk_norm: bool = self.config.get("qk_norm", False)
72
74
  num_reg_tokens: int = self.config.get("num_reg_tokens", 0)
73
75
  class_token: bool = self.config.get("class_token", True)
74
76
  attn_pool_head: bool = self.config.get("attn_pool_head", False)
@@ -118,6 +120,7 @@ class RoPE_FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin,
118
120
  self.num_reg_tokens = num_reg_tokens
119
121
  self.attn_pool_special_tokens = attn_pool_special_tokens
120
122
  self.norm_layer = norm_layer
123
+ self.norm_layer_eps = norm_layer_eps
121
124
  self.mlp_layer = mlp_layer
122
125
  self.act_layer = act_layer
123
126
  self.rope_rot_type = rope_rot_type
@@ -190,6 +193,8 @@ class RoPE_FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin,
190
193
  attention_dropout,
191
194
  dpr,
192
195
  pre_norm=pre_norm,
196
+ qkv_bias=qkv_bias,
197
+ qk_norm=qk_norm,
193
198
  activation_layer=act_layer,
194
199
  layer_scale_init_value=layer_scale_init_value,
195
200
  norm_layer=norm_layer,
@@ -231,6 +236,7 @@ class RoPE_FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin,
231
236
  rope_temperature=rope_temperature,
232
237
  layer_scale_init_value=layer_scale_init_value,
233
238
  norm_layer=norm_layer,
239
+ norm_layer_eps=norm_layer_eps,
234
240
  mlp_layer=mlp_layer,
235
241
  rope_rot_type=rope_rot_type,
236
242
  )
@@ -285,16 +291,16 @@ class RoPE_FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin,
285
291
 
286
292
  def freeze(self, freeze_classifier: bool = True, unfreeze_features: bool = False) -> None:
287
293
  for param in self.parameters():
288
- param.requires_grad = False
294
+ param.requires_grad_(False)
289
295
 
290
296
  if freeze_classifier is False:
291
297
  for param in self.classifier.parameters():
292
- param.requires_grad = True
298
+ param.requires_grad_(True)
293
299
 
294
300
  if unfreeze_features is True:
295
301
  if self.attn_pool is not None:
296
302
  for param in self.attn_pool.parameters():
297
- param.requires_grad = True
303
+ param.requires_grad_(True)
298
304
 
299
305
  def set_causal_attention(self, is_causal: bool = True) -> None:
300
306
  self.encoder.set_causal_attention(is_causal)
@@ -332,16 +338,16 @@ class RoPE_FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin,
332
338
 
333
339
  def freeze_stages(self, up_to_stage: int) -> None:
334
340
  for param in self.conv_proj.parameters():
335
- param.requires_grad = False
341
+ param.requires_grad_(False)
336
342
 
337
- self.pos_embedding.requires_grad = False
343
+ self.pos_embedding.requires_grad_(False)
338
344
 
339
345
  for idx, module in enumerate(self.encoder.children()):
340
346
  if idx >= up_to_stage:
341
347
  break
342
348
 
343
349
  for param in module.parameters():
344
- param.requires_grad = False
350
+ param.requires_grad_(False)
345
351
 
346
352
  # pylint: disable=too-many-branches
347
353
  def masked_encoding_omission(
@@ -588,6 +594,7 @@ class RoPE_FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin,
588
594
  rope_temperature=self.rope_temperature,
589
595
  layer_scale_init_value=self.layer_scale_init_value,
590
596
  norm_layer=self.norm_layer,
597
+ norm_layer_eps=self.norm_layer_eps,
591
598
  mlp_layer=self.mlp_layer,
592
599
  rope_rot_type=self.rope_rot_type,
593
600
  )
birder/net/rope_vit.py CHANGED
@@ -150,6 +150,10 @@ class RoPEAttention(nn.Module):
150
150
  attn_drop: float,
151
151
  proj_drop: float,
152
152
  num_special_tokens: int,
153
+ qkv_bias: bool = True,
154
+ qk_norm: bool = False,
155
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
156
+ norm_layer_eps: float = 1e-6,
153
157
  rope_rot_type: str = "standard",
154
158
  ) -> None:
155
159
  super().__init__()
@@ -167,7 +171,14 @@ class RoPEAttention(nn.Module):
167
171
  else:
168
172
  raise ValueError(f"Unknown rope_rot_type, got '{rope_rot_type}'")
169
173
 
170
- self.qkv = nn.Linear(dim, dim * 3)
174
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
175
+ if qk_norm is True:
176
+ self.q_norm = norm_layer(self.head_dim, eps=norm_layer_eps)
177
+ self.k_norm = norm_layer(self.head_dim, eps=norm_layer_eps)
178
+ else:
179
+ self.q_norm = nn.Identity()
180
+ self.k_norm = nn.Identity()
181
+
171
182
  self.attn_drop = nn.Dropout(attn_drop)
172
183
  self.proj = nn.Linear(dim, dim)
173
184
  self.proj_drop = nn.Dropout(proj_drop)
@@ -176,6 +187,8 @@ class RoPEAttention(nn.Module):
176
187
  (B, N, C) = x.size()
177
188
  qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
178
189
  (q, k, v) = qkv.unbind(0)
190
+ q = self.q_norm(q)
191
+ k = self.k_norm(k)
179
192
 
180
193
  n = self.num_special_tokens
181
194
  q = torch.concat([q[:, :, :n, :], self.apply_rot_fn(q[:, :, n:, :], rope)], dim=2)
@@ -207,6 +220,8 @@ class EncoderBlock(nn.Module):
207
220
  norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
208
221
  norm_layer_eps: float = 1e-6,
209
222
  mlp_layer: Callable[..., nn.Module] = FFN,
223
+ qkv_bias: bool = True,
224
+ qk_norm: bool = False,
210
225
  rope_rot_type: str = "standard",
211
226
  ) -> None:
212
227
  super().__init__()
@@ -222,6 +237,10 @@ class EncoderBlock(nn.Module):
222
237
  attn_drop=attention_dropout,
223
238
  proj_drop=dropout,
224
239
  num_special_tokens=num_special_tokens,
240
+ qkv_bias=qkv_bias,
241
+ qk_norm=qk_norm,
242
+ norm_layer=norm_layer,
243
+ norm_layer_eps=norm_layer_eps,
225
244
  rope_rot_type=rope_rot_type,
226
245
  )
227
246
  if layer_scale_init_value is not None:
@@ -249,7 +268,6 @@ class EncoderBlock(nn.Module):
249
268
 
250
269
 
251
270
  class Encoder(nn.Module):
252
- # pylint: disable=too-many-arguments,too-many-positional-arguments
253
271
  def __init__(
254
272
  self,
255
273
  num_layers: int,
@@ -261,6 +279,8 @@ class Encoder(nn.Module):
261
279
  attention_dropout: float,
262
280
  dpr: list[float],
263
281
  pre_norm: bool = False,
282
+ qkv_bias: bool = True,
283
+ qk_norm: bool = False,
264
284
  activation_layer: Callable[..., nn.Module] = nn.GELU,
265
285
  layer_scale_init_value: Optional[float] = None,
266
286
  norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
@@ -293,6 +313,8 @@ class Encoder(nn.Module):
293
313
  norm_layer=norm_layer,
294
314
  norm_layer_eps=norm_layer_eps,
295
315
  mlp_layer=mlp_layer,
316
+ qkv_bias=qkv_bias,
317
+ qk_norm=qk_norm,
296
318
  rope_rot_type=rope_rot_type,
297
319
  )
298
320
  )
@@ -331,6 +353,7 @@ class MAEDecoderBlock(nn.Module):
331
353
  rope_temperature: float,
332
354
  layer_scale_init_value: Optional[float] = None,
333
355
  norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
356
+ norm_layer_eps: float = 1e-6,
334
357
  mlp_layer: Callable[..., nn.Module] = FFN,
335
358
  rope_rot_type: str = "standard",
336
359
  ) -> None:
@@ -346,7 +369,7 @@ class MAEDecoderBlock(nn.Module):
346
369
  )
347
370
 
348
371
  # Attention block
349
- self.norm1 = norm_layer(hidden_dim, eps=1e-6)
372
+ self.norm1 = norm_layer(hidden_dim, eps=norm_layer_eps)
350
373
  self.attn = RoPEAttention(
351
374
  hidden_dim,
352
375
  num_heads,
@@ -361,7 +384,7 @@ class MAEDecoderBlock(nn.Module):
361
384
  self.layer_scale_1 = nn.Identity()
362
385
 
363
386
  # MLP block
364
- self.norm2 = norm_layer(hidden_dim, eps=1e-6)
387
+ self.norm2 = norm_layer(hidden_dim, eps=norm_layer_eps)
365
388
  self.mlp = mlp_layer(hidden_dim, mlp_dim, act_layer=activation_layer, dropout=0.0)
366
389
  if layer_scale_init_value is not None:
367
390
  self.layer_scale_2 = LayerScale(hidden_dim, layer_scale_init_value)
@@ -403,6 +426,8 @@ class RoPE_ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Mask
403
426
  layer_scale_init_value: Optional[float] = self.config.get("layer_scale_init_value", None)
404
427
  pre_norm: bool = self.config.get("pre_norm", False)
405
428
  post_norm: bool = self.config.get("post_norm", True)
429
+ qkv_bias: bool = self.config.get("qkv_bias", True)
430
+ qk_norm: bool = self.config.get("qk_norm", False)
406
431
  num_reg_tokens: int = self.config.get("num_reg_tokens", 0)
407
432
  class_token: bool = self.config.get("class_token", True)
408
433
  attn_pool_head: bool = self.config.get("attn_pool_head", False)
@@ -450,6 +475,7 @@ class RoPE_ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Mask
450
475
  self.num_reg_tokens = num_reg_tokens
451
476
  self.attn_pool_special_tokens = attn_pool_special_tokens
452
477
  self.norm_layer = norm_layer
478
+ self.norm_layer_eps = norm_layer_eps
453
479
  self.mlp_layer = mlp_layer
454
480
  self.act_layer = act_layer
455
481
  self.rope_rot_type = rope_rot_type
@@ -521,6 +547,8 @@ class RoPE_ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Mask
521
547
  attention_dropout,
522
548
  dpr,
523
549
  pre_norm=pre_norm,
550
+ qkv_bias=qkv_bias,
551
+ qk_norm=qk_norm,
524
552
  activation_layer=act_layer,
525
553
  layer_scale_init_value=layer_scale_init_value,
526
554
  norm_layer=norm_layer,
@@ -562,6 +590,7 @@ class RoPE_ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Mask
562
590
  rope_temperature=rope_temperature,
563
591
  layer_scale_init_value=layer_scale_init_value,
564
592
  norm_layer=norm_layer,
593
+ norm_layer_eps=norm_layer_eps,
565
594
  mlp_layer=mlp_layer,
566
595
  rope_rot_type=rope_rot_type,
567
596
  )
@@ -614,16 +643,16 @@ class RoPE_ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Mask
614
643
 
615
644
  def freeze(self, freeze_classifier: bool = True, unfreeze_features: bool = False) -> None:
616
645
  for param in self.parameters():
617
- param.requires_grad = False
646
+ param.requires_grad_(False)
618
647
 
619
648
  if freeze_classifier is False:
620
649
  for param in self.classifier.parameters():
621
- param.requires_grad = True
650
+ param.requires_grad_(True)
622
651
 
623
652
  if unfreeze_features is True:
624
653
  if self.attn_pool is not None:
625
654
  for param in self.attn_pool.parameters():
626
- param.requires_grad = True
655
+ param.requires_grad_(True)
627
656
 
628
657
  def set_causal_attention(self, is_causal: bool = True) -> None:
629
658
  self.encoder.set_causal_attention(is_causal)
@@ -661,16 +690,16 @@ class RoPE_ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Mask
661
690
 
662
691
  def freeze_stages(self, up_to_stage: int) -> None:
663
692
  for param in self.conv_proj.parameters():
664
- param.requires_grad = False
693
+ param.requires_grad_(False)
665
694
 
666
- self.pos_embedding.requires_grad = False
695
+ self.pos_embedding.requires_grad_(False)
667
696
 
668
697
  for idx, module in enumerate(self.encoder.children()):
669
698
  if idx >= up_to_stage:
670
699
  break
671
700
 
672
701
  for param in module.parameters():
673
- param.requires_grad = False
702
+ param.requires_grad_(False)
674
703
 
675
704
  def masked_encoding_omission(
676
705
  self,
@@ -904,6 +933,7 @@ class RoPE_ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Mask
904
933
  rope_temperature=self.rope_temperature,
905
934
  layer_scale_init_value=self.layer_scale_init_value,
906
935
  norm_layer=self.norm_layer,
936
+ norm_layer_eps=self.norm_layer_eps,
907
937
  mlp_layer=self.mlp_layer,
908
938
  rope_rot_type=self.rope_rot_type,
909
939
  )
@@ -931,6 +961,7 @@ class RoPE_ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Mask
931
961
  # - rms : RMSNorm (instead of LayerNorm)
932
962
  # - pn : Pre-Norm (layer norm before the encoder) - implies different norm eps
933
963
  # - npn : No Post Norm (disables post-normalization layer)
964
+ # - qkn : QK Norm
934
965
  #
935
966
  # Feed-Forward Network:
936
967
  # - swiglu : SwiGLU FFN layer type (instead of standard FFN)
@@ -1068,6 +1099,20 @@ registry.register_model_config(
1068
1099
  "drop_path_rate": 0.1,
1069
1100
  },
1070
1101
  )
1102
+ registry.register_model_config(
1103
+ "rope_vit_b16_qkn_ls",
1104
+ RoPE_ViT,
1105
+ config={
1106
+ "patch_size": 16,
1107
+ "num_layers": 12,
1108
+ "num_heads": 12,
1109
+ "hidden_dim": 768,
1110
+ "mlp_dim": 3072,
1111
+ "layer_scale_init_value": 1e-5,
1112
+ "qk_norm": True,
1113
+ "drop_path_rate": 0.1,
1114
+ },
1115
+ )
1071
1116
  registry.register_model_config(
1072
1117
  "rope_i_vit_b16_pn_aps_c1", # For PE Core - https://arxiv.org/abs/2504.13181
1073
1118
  RoPE_ViT,
@@ -1310,6 +1355,20 @@ registry.register_model_config(
1310
1355
  "drop_path_rate": 0.0,
1311
1356
  },
1312
1357
  )
1358
+ registry.register_model_config(
1359
+ "rope_vit_reg4_m14_avg",
1360
+ RoPE_ViT,
1361
+ config={
1362
+ "patch_size": 14,
1363
+ "num_layers": 12,
1364
+ "num_heads": 8,
1365
+ "hidden_dim": 512,
1366
+ "mlp_dim": 2048,
1367
+ "num_reg_tokens": 4,
1368
+ "class_token": False,
1369
+ "drop_path_rate": 0.0,
1370
+ },
1371
+ )
1313
1372
  registry.register_model_config(
1314
1373
  "rope_vit_reg4_b32",
1315
1374
  RoPE_ViT,
@@ -220,14 +220,14 @@ class ShuffleNet_v1(DetectorBackbone):
220
220
 
221
221
  def freeze_stages(self, up_to_stage: int) -> None:
222
222
  for param in self.stem.parameters():
223
- param.requires_grad = False
223
+ param.requires_grad_(False)
224
224
 
225
225
  for idx, module in enumerate(self.body.children()):
226
226
  if idx >= up_to_stage:
227
227
  break
228
228
 
229
229
  for param in module.parameters():
230
- param.requires_grad = False
230
+ param.requires_grad_(False)
231
231
 
232
232
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
233
233
  x = self.stem(x)
@@ -166,14 +166,14 @@ class ShuffleNet_v2(DetectorBackbone):
166
166
 
167
167
  def freeze_stages(self, up_to_stage: int) -> None:
168
168
  for param in self.stem.parameters():
169
- param.requires_grad = False
169
+ param.requires_grad_(False)
170
170
 
171
171
  for idx, module in enumerate(self.body.children()):
172
172
  if idx >= up_to_stage:
173
173
  break
174
174
 
175
175
  for param in module.parameters():
176
- param.requires_grad = False
176
+ param.requires_grad_(False)
177
177
 
178
178
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
179
179
  x = self.stem(x)
birder/net/smt.py CHANGED
@@ -275,7 +275,6 @@ class Stem(nn.Module):
275
275
 
276
276
 
277
277
  class SMTStage(nn.Module):
278
- # pylint: disable=too-many-arguments,too-many-positional-arguments
279
278
  def __init__(
280
279
  self,
281
280
  dim: int,
@@ -429,7 +428,7 @@ class SMT(DetectorBackbone):
429
428
  break
430
429
 
431
430
  for param in module.parameters():
432
- param.requires_grad = False
431
+ param.requires_grad_(False)
433
432
 
434
433
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
435
434
  return self.body(x)
birder/net/squeezenext.py CHANGED
@@ -177,14 +177,14 @@ class SqueezeNext(DetectorBackbone):
177
177
 
178
178
  def freeze_stages(self, up_to_stage: int) -> None:
179
179
  for param in self.stem.parameters():
180
- param.requires_grad = False
180
+ param.requires_grad_(False)
181
181
 
182
182
  for idx, module in enumerate(self.body.children()):
183
183
  if idx >= up_to_stage:
184
184
  break
185
185
 
186
186
  for param in module.parameters():
187
- param.requires_grad = False
187
+ param.requires_grad_(False)
188
188
 
189
189
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
190
190
  x = self.stem(x)
birder/net/ssl/byol.py CHANGED
@@ -82,8 +82,9 @@ class BYOL(SSLBaseNet):
82
82
  online_predictions = self.online_predictor(projection)
83
83
  (online_pred_one, online_pred_two) = online_predictions.chunk(2, dim=0)
84
84
 
85
- target_projections = self.target_encoder(x)
86
- (target_proj_one, target_proj_two) = target_projections.chunk(2, dim=0)
85
+ with torch.no_grad():
86
+ target_projections = self.target_encoder(x)
87
+ (target_proj_one, target_proj_two) = target_projections.chunk(2, dim=0)
87
88
 
88
89
  loss_one = loss_fn(online_pred_one, target_proj_two.detach())
89
90
  loss_two = loss_fn(online_pred_two, target_proj_one.detach())