birder 0.3.0__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169)
  1. birder/adversarial/deepfool.py +2 -0
  2. birder/adversarial/simba.py +2 -0
  3. birder/common/fs_ops.py +2 -2
  4. birder/common/masking.py +13 -4
  5. birder/common/training_cli.py +6 -1
  6. birder/common/training_utils.py +4 -2
  7. birder/inference/classification.py +1 -1
  8. birder/introspection/__init__.py +2 -0
  9. birder/introspection/base.py +0 -7
  10. birder/introspection/feature_pca.py +101 -0
  11. birder/kernels/soft_nms/soft_nms.cpp +5 -2
  12. birder/model_registry/model_registry.py +3 -2
  13. birder/net/base.py +3 -3
  14. birder/net/biformer.py +2 -2
  15. birder/net/cas_vit.py +6 -6
  16. birder/net/coat.py +8 -8
  17. birder/net/conv2former.py +2 -2
  18. birder/net/convnext_v1.py +22 -2
  19. birder/net/convnext_v2.py +2 -2
  20. birder/net/crossformer.py +2 -2
  21. birder/net/cspnet.py +2 -2
  22. birder/net/cswin_transformer.py +2 -2
  23. birder/net/darknet.py +2 -2
  24. birder/net/davit.py +2 -2
  25. birder/net/deit.py +3 -3
  26. birder/net/deit3.py +3 -3
  27. birder/net/densenet.py +2 -2
  28. birder/net/detection/deformable_detr.py +2 -2
  29. birder/net/detection/detr.py +2 -2
  30. birder/net/detection/efficientdet.py +2 -2
  31. birder/net/detection/faster_rcnn.py +2 -2
  32. birder/net/detection/fcos.py +2 -2
  33. birder/net/detection/retinanet.py +2 -2
  34. birder/net/detection/rt_detr_v1.py +4 -4
  35. birder/net/detection/ssd.py +2 -2
  36. birder/net/detection/ssdlite.py +2 -2
  37. birder/net/detection/yolo_v2.py +2 -2
  38. birder/net/detection/yolo_v3.py +2 -2
  39. birder/net/detection/yolo_v4.py +2 -2
  40. birder/net/edgenext.py +2 -2
  41. birder/net/edgevit.py +1 -1
  42. birder/net/efficientformer_v1.py +4 -4
  43. birder/net/efficientformer_v2.py +6 -6
  44. birder/net/efficientnet_lite.py +2 -2
  45. birder/net/efficientnet_v1.py +2 -2
  46. birder/net/efficientnet_v2.py +2 -2
  47. birder/net/efficientvim.py +3 -3
  48. birder/net/efficientvit_mit.py +2 -2
  49. birder/net/efficientvit_msft.py +2 -2
  50. birder/net/fasternet.py +2 -2
  51. birder/net/fastvit.py +2 -3
  52. birder/net/flexivit.py +11 -6
  53. birder/net/focalnet.py +2 -3
  54. birder/net/gc_vit.py +17 -2
  55. birder/net/ghostnet_v1.py +2 -2
  56. birder/net/ghostnet_v2.py +2 -2
  57. birder/net/groupmixformer.py +2 -2
  58. birder/net/hgnet_v1.py +2 -2
  59. birder/net/hgnet_v2.py +2 -2
  60. birder/net/hiera.py +2 -2
  61. birder/net/hieradet.py +2 -2
  62. birder/net/hornet.py +2 -2
  63. birder/net/iformer.py +2 -2
  64. birder/net/inception_next.py +2 -2
  65. birder/net/inception_resnet_v1.py +2 -2
  66. birder/net/inception_resnet_v2.py +2 -2
  67. birder/net/inception_v3.py +2 -2
  68. birder/net/inception_v4.py +2 -2
  69. birder/net/levit.py +4 -4
  70. birder/net/lit_v1.py +2 -2
  71. birder/net/lit_v1_tiny.py +2 -2
  72. birder/net/lit_v2.py +2 -2
  73. birder/net/maxvit.py +2 -2
  74. birder/net/metaformer.py +2 -2
  75. birder/net/mnasnet.py +2 -2
  76. birder/net/mobilenet_v1.py +2 -2
  77. birder/net/mobilenet_v2.py +2 -2
  78. birder/net/mobilenet_v3_large.py +2 -2
  79. birder/net/mobilenet_v4.py +2 -2
  80. birder/net/mobilenet_v4_hybrid.py +2 -2
  81. birder/net/mobileone.py +2 -2
  82. birder/net/mobilevit_v2.py +2 -2
  83. birder/net/moganet.py +2 -2
  84. birder/net/mvit_v2.py +2 -2
  85. birder/net/nextvit.py +2 -2
  86. birder/net/nfnet.py +2 -2
  87. birder/net/pit.py +6 -6
  88. birder/net/pvt_v1.py +2 -2
  89. birder/net/pvt_v2.py +2 -2
  90. birder/net/rdnet.py +2 -2
  91. birder/net/regionvit.py +6 -6
  92. birder/net/regnet.py +2 -2
  93. birder/net/regnet_z.py +2 -2
  94. birder/net/repghost.py +2 -2
  95. birder/net/repvgg.py +2 -2
  96. birder/net/repvit.py +6 -6
  97. birder/net/resnest.py +2 -2
  98. birder/net/resnet_v1.py +2 -2
  99. birder/net/resnet_v2.py +2 -2
  100. birder/net/resnext.py +2 -2
  101. birder/net/rope_deit3.py +3 -3
  102. birder/net/rope_flexivit.py +13 -6
  103. birder/net/rope_vit.py +69 -10
  104. birder/net/shufflenet_v1.py +2 -2
  105. birder/net/shufflenet_v2.py +2 -2
  106. birder/net/smt.py +1 -2
  107. birder/net/squeezenext.py +2 -2
  108. birder/net/ssl/byol.py +3 -2
  109. birder/net/ssl/capi.py +156 -11
  110. birder/net/ssl/data2vec.py +3 -1
  111. birder/net/ssl/data2vec2.py +3 -1
  112. birder/net/ssl/dino_v1.py +1 -1
  113. birder/net/ssl/dino_v2.py +140 -18
  114. birder/net/ssl/franca.py +145 -13
  115. birder/net/ssl/ibot.py +1 -2
  116. birder/net/ssl/mmcr.py +3 -1
  117. birder/net/starnet.py +2 -2
  118. birder/net/swiftformer.py +6 -6
  119. birder/net/swin_transformer_v1.py +2 -2
  120. birder/net/swin_transformer_v2.py +2 -2
  121. birder/net/tiny_vit.py +2 -2
  122. birder/net/transnext.py +1 -1
  123. birder/net/uniformer.py +1 -1
  124. birder/net/van.py +1 -1
  125. birder/net/vgg.py +1 -1
  126. birder/net/vgg_reduced.py +1 -1
  127. birder/net/vit.py +172 -8
  128. birder/net/vit_parallel.py +5 -5
  129. birder/net/vit_sam.py +3 -3
  130. birder/net/vovnet_v1.py +2 -2
  131. birder/net/vovnet_v2.py +2 -2
  132. birder/net/wide_resnet.py +2 -2
  133. birder/net/xception.py +2 -2
  134. birder/net/xcit.py +2 -2
  135. birder/results/detection.py +104 -0
  136. birder/results/gui.py +10 -8
  137. birder/scripts/benchmark.py +1 -1
  138. birder/scripts/train.py +13 -18
  139. birder/scripts/train_barlow_twins.py +10 -14
  140. birder/scripts/train_byol.py +11 -15
  141. birder/scripts/train_capi.py +38 -17
  142. birder/scripts/train_data2vec.py +11 -15
  143. birder/scripts/train_data2vec2.py +13 -17
  144. birder/scripts/train_detection.py +11 -14
  145. birder/scripts/train_dino_v1.py +20 -22
  146. birder/scripts/train_dino_v2.py +126 -63
  147. birder/scripts/train_dino_v2_dist.py +127 -64
  148. birder/scripts/train_franca.py +49 -34
  149. birder/scripts/train_i_jepa.py +11 -14
  150. birder/scripts/train_ibot.py +16 -18
  151. birder/scripts/train_kd.py +14 -20
  152. birder/scripts/train_mim.py +10 -13
  153. birder/scripts/train_mmcr.py +11 -15
  154. birder/scripts/train_rotnet.py +12 -16
  155. birder/scripts/train_simclr.py +10 -14
  156. birder/scripts/train_vicreg.py +10 -14
  157. birder/tools/avg_model.py +24 -8
  158. birder/tools/det_results.py +91 -0
  159. birder/tools/introspection.py +35 -9
  160. birder/tools/results.py +11 -7
  161. birder/tools/show_iterator.py +1 -1
  162. birder/version.py +1 -1
  163. {birder-0.3.0.dist-info → birder-0.3.2.dist-info}/METADATA +1 -1
  164. birder-0.3.2.dist-info/RECORD +299 -0
  165. birder-0.3.0.dist-info/RECORD +0 -298
  166. {birder-0.3.0.dist-info → birder-0.3.2.dist-info}/WHEEL +0 -0
  167. {birder-0.3.0.dist-info → birder-0.3.2.dist-info}/entry_points.txt +0 -0
  168. {birder-0.3.0.dist-info → birder-0.3.2.dist-info}/licenses/LICENSE +0 -0
  169. {birder-0.3.0.dist-info → birder-0.3.2.dist-info}/top_level.txt +0 -0
birder/net/densenet.py CHANGED
@@ -140,14 +140,14 @@ class DenseNet(DetectorBackbone):
140
140
 
141
141
  def freeze_stages(self, up_to_stage: int) -> None:
142
142
  for param in self.stem.parameters():
143
- param.requires_grad = False
143
+ param.requires_grad_(False)
144
144
 
145
145
  for idx, module in enumerate(self.body.children()):
146
146
  if idx >= up_to_stage:
147
147
  break
148
148
 
149
149
  for param in module.parameters():
150
- param.requires_grad = False
150
+ param.requires_grad_(False)
151
151
 
152
152
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
153
153
  x = self.stem(x)
@@ -633,11 +633,11 @@ class Deformable_DETR(DetectionBaseNet):
633
633
 
634
634
  def freeze(self, freeze_classifier: bool = True) -> None:
635
635
  for param in self.parameters():
636
- param.requires_grad = False
636
+ param.requires_grad_(False)
637
637
 
638
638
  if freeze_classifier is False:
639
639
  for param in self.class_embed.parameters():
640
- param.requires_grad = True
640
+ param.requires_grad_(True)
641
641
 
642
642
  def _get_src_permutation_idx(self, indices: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
643
643
  batch_idx = torch.concat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
@@ -356,11 +356,11 @@ class DETR(DetectionBaseNet):
356
356
 
357
357
  def freeze(self, freeze_classifier: bool = True) -> None:
358
358
  for param in self.parameters():
359
- param.requires_grad = False
359
+ param.requires_grad_(False)
360
360
 
361
361
  if freeze_classifier is False:
362
362
  for param in self.class_embed.parameters():
363
- param.requires_grad = True
363
+ param.requires_grad_(True)
364
364
 
365
365
  def _get_src_permutation_idx(self, indices: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
366
366
  batch_idx = torch.concat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
@@ -601,11 +601,11 @@ class EfficientDet(DetectionBaseNet):
601
601
 
602
602
  def freeze(self, freeze_classifier: bool = True) -> None:
603
603
  for param in self.parameters():
604
- param.requires_grad = False
604
+ param.requires_grad_(False)
605
605
 
606
606
  if freeze_classifier is False:
607
607
  for param in self.class_net.parameters():
608
- param.requires_grad = True
608
+ param.requires_grad_(True)
609
609
 
610
610
  def compute_loss(
611
611
  self,
@@ -851,11 +851,11 @@ class Faster_RCNN(DetectionBaseNet):
851
851
 
852
852
  def freeze(self, freeze_classifier: bool = True) -> None:
853
853
  for param in self.parameters():
854
- param.requires_grad = False
854
+ param.requires_grad_(False)
855
855
 
856
856
  if freeze_classifier is False:
857
857
  for param in self.roi_heads.box_predictor.parameters():
858
- param.requires_grad = True
858
+ param.requires_grad_(True)
859
859
 
860
860
  def forward(
861
861
  self,
@@ -338,11 +338,11 @@ class FCOS(DetectionBaseNet):
338
338
 
339
339
  def freeze(self, freeze_classifier: bool = True) -> None:
340
340
  for param in self.parameters():
341
- param.requires_grad = False
341
+ param.requires_grad_(False)
342
342
 
343
343
  if freeze_classifier is False:
344
344
  for param in self.head.classification_head.parameters():
345
- param.requires_grad = True
345
+ param.requires_grad_(True)
346
346
 
347
347
  def compute_loss(
348
348
  self,
@@ -332,11 +332,11 @@ class RetinaNet(DetectionBaseNet):
332
332
 
333
333
  def freeze(self, freeze_classifier: bool = True) -> None:
334
334
  for param in self.parameters():
335
- param.requires_grad = False
335
+ param.requires_grad_(False)
336
336
 
337
337
  if freeze_classifier is False:
338
338
  for param in self.head.classification_head.parameters():
339
- param.requires_grad = True
339
+ param.requires_grad_(True)
340
340
 
341
341
  @torch.jit.unused # type: ignore[untyped-decorator]
342
342
  @torch.compiler.disable() # type: ignore[untyped-decorator]
@@ -742,16 +742,16 @@ class RT_DETR_v1(DetectionBaseNet):
742
742
 
743
743
  def freeze(self, freeze_classifier: bool = True) -> None:
744
744
  for param in self.parameters():
745
- param.requires_grad = False
745
+ param.requires_grad_(False)
746
746
 
747
747
  if freeze_classifier is False:
748
748
  for param in self.decoder.class_embed.parameters():
749
- param.requires_grad = True
749
+ param.requires_grad_(True)
750
750
  for param in self.decoder.enc_score_head.parameters():
751
- param.requires_grad = True
751
+ param.requires_grad_(True)
752
752
  if self.num_denoising > 0:
753
753
  for param in self.denoising_class_embed.parameters():
754
- param.requires_grad = True
754
+ param.requires_grad_(True)
755
755
 
756
756
  def _get_src_permutation_idx(self, indices: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
757
757
  batch_idx = torch.concat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
@@ -341,11 +341,11 @@ class SSD(DetectionBaseNet):
341
341
 
342
342
  def freeze(self, freeze_classifier: bool = True) -> None:
343
343
  for param in self.parameters():
344
- param.requires_grad = False
344
+ param.requires_grad_(False)
345
345
 
346
346
  if freeze_classifier is False:
347
347
  for param in self.head.classification_head.parameters():
348
- param.requires_grad = True
348
+ param.requires_grad_(True)
349
349
 
350
350
  # pylint: disable=too-many-locals
351
351
  def compute_loss(
@@ -197,8 +197,8 @@ class SSDLite(SSD):
197
197
 
198
198
  def freeze(self, freeze_classifier: bool = True) -> None:
199
199
  for param in self.parameters():
200
- param.requires_grad = False
200
+ param.requires_grad_(False)
201
201
 
202
202
  if freeze_classifier is False:
203
203
  for param in self.head.classification_head.parameters():
204
- param.requires_grad = True
204
+ param.requires_grad_(True)
@@ -270,11 +270,11 @@ class YOLO_v2(DetectionBaseNet):
270
270
 
271
271
  def freeze(self, freeze_classifier: bool = True) -> None:
272
272
  for param in self.parameters():
273
- param.requires_grad = False
273
+ param.requires_grad_(False)
274
274
 
275
275
  if freeze_classifier is False:
276
276
  for param in self.head.parameters():
277
- param.requires_grad = True
277
+ param.requires_grad_(True)
278
278
 
279
279
  def _compute_anchor_iou(self, box_wh: torch.Tensor, anchor_wh: torch.Tensor) -> torch.Tensor:
280
280
  """
@@ -376,11 +376,11 @@ class YOLO_v3(DetectionBaseNet):
376
376
 
377
377
  def freeze(self, freeze_classifier: bool = True) -> None:
378
378
  for param in self.parameters():
379
- param.requires_grad = False
379
+ param.requires_grad_(False)
380
380
 
381
381
  if freeze_classifier is False:
382
382
  for param in self.head.parameters():
383
- param.requires_grad = True
383
+ param.requires_grad_(True)
384
384
 
385
385
  def _compute_anchor_iou(self, box_wh: torch.Tensor, anchor_wh: torch.Tensor) -> torch.Tensor:
386
386
  """
@@ -444,11 +444,11 @@ class YOLO_v4(DetectionBaseNet):
444
444
 
445
445
  def freeze(self, freeze_classifier: bool = True) -> None:
446
446
  for param in self.parameters():
447
- param.requires_grad = False
447
+ param.requires_grad_(False)
448
448
 
449
449
  if freeze_classifier is False:
450
450
  for param in self.head.parameters():
451
- param.requires_grad = True
451
+ param.requires_grad_(True)
452
452
 
453
453
  def _compute_anchor_iou(self, box_wh: torch.Tensor, anchor_wh: torch.Tensor) -> torch.Tensor:
454
454
  """
birder/net/edgenext.py CHANGED
@@ -336,14 +336,14 @@ class EdgeNeXt(DetectorBackbone):
336
336
 
337
337
  def freeze_stages(self, up_to_stage: int) -> None:
338
338
  for param in self.stem.parameters():
339
- param.requires_grad = False
339
+ param.requires_grad_(False)
340
340
 
341
341
  for idx, module in enumerate(self.body.children()):
342
342
  if idx >= up_to_stage:
343
343
  break
344
344
 
345
345
  for param in module.parameters():
346
- param.requires_grad = False
346
+ param.requires_grad_(False)
347
347
 
348
348
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
349
349
  x = self.stem(x)
birder/net/edgevit.py CHANGED
@@ -354,7 +354,7 @@ class EdgeViT(DetectorBackbone):
354
354
  break
355
355
 
356
356
  for param in module.parameters():
357
- param.requires_grad = False
357
+ param.requires_grad_(False)
358
358
 
359
359
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
360
360
  return self.body(x)
@@ -307,18 +307,18 @@ class EfficientFormer_v1(BaseNet):
307
307
 
308
308
  def freeze(self, freeze_classifier: bool = True, unfreeze_features: bool = False) -> None:
309
309
  for param in self.parameters():
310
- param.requires_grad = False
310
+ param.requires_grad_(False)
311
311
 
312
312
  if freeze_classifier is False:
313
313
  for param in self.classifier.parameters():
314
- param.requires_grad = True
314
+ param.requires_grad_(True)
315
315
 
316
316
  for param in self.dist_classifier.parameters():
317
- param.requires_grad = True
317
+ param.requires_grad_(True)
318
318
 
319
319
  if unfreeze_features is True:
320
320
  for param in self.features.parameters():
321
- param.requires_grad = True
321
+ param.requires_grad_(True)
322
322
 
323
323
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
324
324
  x = self.stem(x)
@@ -469,18 +469,18 @@ class EfficientFormer_v2(DetectorBackbone):
469
469
 
470
470
  def freeze(self, freeze_classifier: bool = True, unfreeze_features: bool = False) -> None:
471
471
  for param in self.parameters():
472
- param.requires_grad = False
472
+ param.requires_grad_(False)
473
473
 
474
474
  if freeze_classifier is False:
475
475
  for param in self.classifier.parameters():
476
- param.requires_grad = True
476
+ param.requires_grad_(True)
477
477
 
478
478
  for param in self.dist_classifier.parameters():
479
- param.requires_grad = True
479
+ param.requires_grad_(True)
480
480
 
481
481
  if unfreeze_features is True:
482
482
  for param in self.features.parameters():
483
- param.requires_grad = True
483
+ param.requires_grad_(True)
484
484
 
485
485
  def transform_to_backbone(self) -> None:
486
486
  self.features = nn.Identity()
@@ -500,14 +500,14 @@ class EfficientFormer_v2(DetectorBackbone):
500
500
 
501
501
  def freeze_stages(self, up_to_stage: int) -> None:
502
502
  for param in self.stem.parameters():
503
- param.requires_grad = False
503
+ param.requires_grad_(False)
504
504
 
505
505
  for idx, module in enumerate(self.body.children()):
506
506
  if idx >= up_to_stage:
507
507
  break
508
508
 
509
509
  for param in module.parameters():
510
- param.requires_grad = False
510
+ param.requires_grad_(False)
511
511
 
512
512
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
513
513
  x = self.stem(x)
@@ -242,14 +242,14 @@ class EfficientNet_Lite(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionM
242
242
 
243
243
  def freeze_stages(self, up_to_stage: int) -> None:
244
244
  for param in self.stem.parameters():
245
- param.requires_grad = False
245
+ param.requires_grad_(False)
246
246
 
247
247
  for idx, module in enumerate(self.body.children()):
248
248
  if idx >= up_to_stage:
249
249
  break
250
250
 
251
251
  for param in module.parameters():
252
- param.requires_grad = False
252
+ param.requires_grad_(False)
253
253
 
254
254
  def masked_encoding_retention(
255
255
  self,
@@ -249,14 +249,14 @@ class EfficientNet_v1(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMix
249
249
 
250
250
  def freeze_stages(self, up_to_stage: int) -> None:
251
251
  for param in self.stem.parameters():
252
- param.requires_grad = False
252
+ param.requires_grad_(False)
253
253
 
254
254
  for idx, module in enumerate(self.body.children()):
255
255
  if idx >= up_to_stage:
256
256
  break
257
257
 
258
258
  for param in module.parameters():
259
- param.requires_grad = False
259
+ param.requires_grad_(False)
260
260
 
261
261
  def masked_encoding_retention(
262
262
  self,
@@ -257,14 +257,14 @@ class EfficientNet_v2(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMix
257
257
 
258
258
  def freeze_stages(self, up_to_stage: int) -> None:
259
259
  for param in self.stem.parameters():
260
- param.requires_grad = False
260
+ param.requires_grad_(False)
261
261
 
262
262
  for idx, module in enumerate(self.body.children()):
263
263
  if idx >= up_to_stage:
264
264
  break
265
265
 
266
266
  for param in module.parameters():
267
- param.requires_grad = False
267
+ param.requires_grad_(False)
268
268
 
269
269
  def masked_encoding_retention(
270
270
  self,
@@ -418,21 +418,21 @@ class EfficientViM(DetectorBackbone):
418
418
 
419
419
  def freeze_stages(self, up_to_stage: int) -> None:
420
420
  for param in self.stem.parameters():
421
- param.requires_grad = False
421
+ param.requires_grad_(False)
422
422
 
423
423
  for idx, module in enumerate(self.body.children()):
424
424
  if idx >= up_to_stage:
425
425
  break
426
426
 
427
427
  for param in module.parameters():
428
- param.requires_grad = False
428
+ param.requires_grad_(False)
429
429
 
430
430
  for idx, module in enumerate(self.norm.children()):
431
431
  if idx >= up_to_stage:
432
432
  break
433
433
 
434
434
  for param in module.parameters():
435
- param.requires_grad = False
435
+ param.requires_grad_(False)
436
436
 
437
437
  def forward_features(self, x: torch.Tensor) -> tuple[torch.Tensor, list[torch.Tensor]]:
438
438
  x = self.stem(x)
@@ -619,14 +619,14 @@ class EfficientViT_MIT(DetectorBackbone):
619
619
 
620
620
  def freeze_stages(self, up_to_stage: int) -> None:
621
621
  for param in self.stem.parameters():
622
- param.requires_grad = False
622
+ param.requires_grad_(False)
623
623
 
624
624
  for idx, module in enumerate(self.body.children()):
625
625
  if idx >= up_to_stage:
626
626
  break
627
627
 
628
628
  for param in module.parameters():
629
- param.requires_grad = False
629
+ param.requires_grad_(False)
630
630
 
631
631
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
632
632
  x = self.stem(x)
@@ -436,14 +436,14 @@ class EfficientViT_MSFT(DetectorBackbone):
436
436
 
437
437
  def freeze_stages(self, up_to_stage: int) -> None:
438
438
  for param in self.stem.parameters():
439
- param.requires_grad = False
439
+ param.requires_grad_(False)
440
440
 
441
441
  for idx, module in enumerate(self.body.children()):
442
442
  if idx >= up_to_stage:
443
443
  break
444
444
 
445
445
  for param in module.parameters():
446
- param.requires_grad = False
446
+ param.requires_grad_(False)
447
447
 
448
448
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
449
449
  x = self.stem(x)
birder/net/fasternet.py CHANGED
@@ -199,14 +199,14 @@ class FasterNet(DetectorBackbone):
199
199
 
200
200
  def freeze_stages(self, up_to_stage: int) -> None:
201
201
  for param in self.stem.parameters():
202
- param.requires_grad = False
202
+ param.requires_grad_(False)
203
203
 
204
204
  for idx, module in enumerate(self.body.children()):
205
205
  if idx >= up_to_stage:
206
206
  break
207
207
 
208
208
  for param in module.parameters():
209
- param.requires_grad = False
209
+ param.requires_grad_(False)
210
210
 
211
211
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
212
212
  x = self.stem(x)
birder/net/fastvit.py CHANGED
@@ -607,7 +607,6 @@ class AttentionBlock(nn.Module):
607
607
 
608
608
 
609
609
  class FastVitStage(nn.Module):
610
- # pylint: disable=too-many-arguments,too-many-positional-arguments
611
610
  def __init__(
612
611
  self,
613
612
  dim: int,
@@ -843,14 +842,14 @@ class FastViT(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
843
842
 
844
843
  def freeze_stages(self, up_to_stage: int) -> None:
845
844
  for param in self.stem.parameters():
846
- param.requires_grad = False
845
+ param.requires_grad_(False)
847
846
 
848
847
  for idx, module in enumerate(self.body.children()):
849
848
  if idx >= up_to_stage:
850
849
  break
851
850
 
852
851
  for param in module.parameters():
853
- param.requires_grad = False
852
+ param.requires_grad_(False)
854
853
 
855
854
  def masked_encoding_retention(
856
855
  self,
birder/net/flexivit.py CHANGED
@@ -98,6 +98,8 @@ class FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Mask
98
98
  layer_scale_init_value: Optional[float] = self.config.get("layer_scale_init_value", None)
99
99
  pre_norm: bool = self.config.get("pre_norm", False)
100
100
  post_norm: bool = self.config.get("post_norm", True)
101
+ qkv_bias: bool = self.config.get("qkv_bias", True)
102
+ qk_norm: bool = self.config.get("qk_norm", False)
101
103
  num_reg_tokens: int = self.config.get("num_reg_tokens", 0)
102
104
  class_token: bool = self.config.get("class_token", True)
103
105
  attn_pool_head: bool = self.config.get("attn_pool_head", False)
@@ -186,6 +188,8 @@ class FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Mask
186
188
  attention_dropout,
187
189
  dpr,
188
190
  pre_norm=pre_norm,
191
+ qkv_bias=qkv_bias,
192
+ qk_norm=qk_norm,
189
193
  activation_layer=act_layer,
190
194
  layer_scale_init_value=layer_scale_init_value,
191
195
  norm_layer=norm_layer,
@@ -224,6 +228,7 @@ class FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Mask
224
228
  drop_path=0,
225
229
  activation_layer=act_layer,
226
230
  norm_layer=norm_layer,
231
+ norm_layer_eps=norm_layer_eps,
227
232
  mlp_layer=mlp_layer,
228
233
  )
229
234
 
@@ -258,16 +263,16 @@ class FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Mask
258
263
 
259
264
  def freeze(self, freeze_classifier: bool = True, unfreeze_features: bool = False) -> None:
260
265
  for param in self.parameters():
261
- param.requires_grad = False
266
+ param.requires_grad_(False)
262
267
 
263
268
  if freeze_classifier is False:
264
269
  for param in self.classifier.parameters():
265
- param.requires_grad = True
270
+ param.requires_grad_(True)
266
271
 
267
272
  if unfreeze_features is True:
268
273
  if self.attn_pool is not None:
269
274
  for param in self.attn_pool.parameters():
270
- param.requires_grad = True
275
+ param.requires_grad_(True)
271
276
 
272
277
  def set_causal_attention(self, is_causal: bool = True) -> None:
273
278
  self.encoder.set_causal_attention(is_causal)
@@ -305,16 +310,16 @@ class FlexiViT(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, Mask
305
310
 
306
311
  def freeze_stages(self, up_to_stage: int) -> None:
307
312
  for param in self.conv_proj.parameters():
308
- param.requires_grad = False
313
+ param.requires_grad_(False)
309
314
 
310
- self.pos_embedding.requires_grad = False
315
+ self.pos_embedding.requires_grad_(False)
311
316
 
312
317
  for idx, module in enumerate(self.encoder.children()):
313
318
  if idx >= up_to_stage:
314
319
  break
315
320
 
316
321
  for param in module.parameters():
317
- param.requires_grad = False
322
+ param.requires_grad_(False)
318
323
 
319
324
  # pylint: disable=too-many-branches
320
325
  def masked_encoding_omission(
birder/net/focalnet.py CHANGED
@@ -245,7 +245,6 @@ class FocalNetBlock(nn.Module):
245
245
 
246
246
 
247
247
  class FocalNetStage(nn.Module):
248
- # pylint: disable=too-many-arguments,too-many-positional-arguments
249
248
  def __init__(
250
249
  self,
251
250
  dim: int,
@@ -398,14 +397,14 @@ class FocalNet(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
398
397
 
399
398
  def freeze_stages(self, up_to_stage: int) -> None:
400
399
  for param in self.stem.parameters():
401
- param.requires_grad = False
400
+ param.requires_grad_(False)
402
401
 
403
402
  for idx, module in enumerate(self.body.children()):
404
403
  if idx >= up_to_stage:
405
404
  break
406
405
 
407
406
  for param in module.parameters():
408
- param.requires_grad = False
407
+ param.requires_grad_(False)
409
408
 
410
409
  def masked_encoding_retention(
411
410
  self,
birder/net/gc_vit.py CHANGED
@@ -500,14 +500,14 @@ class GC_ViT(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
500
500
 
501
501
  def freeze_stages(self, up_to_stage: int) -> None:
502
502
  for param in self.stem.parameters():
503
- param.requires_grad = False
503
+ param.requires_grad_(False)
504
504
 
505
505
  for idx, module in enumerate(self.body.children()):
506
506
  if idx >= up_to_stage:
507
507
  break
508
508
 
509
509
  for param in module.parameters():
510
- param.requires_grad = False
510
+ param.requires_grad_(False)
511
511
 
512
512
  def set_dynamic_size(self, dynamic_size: bool = True) -> None:
513
513
  super().set_dynamic_size(dynamic_size)
@@ -669,3 +669,18 @@ registry.register_model_config(
669
669
  "drop_path_rate": 0.5,
670
670
  },
671
671
  )
672
+
673
+ registry.register_weights(
674
+ "gc_vit_xxt_il-common",
675
+ {
676
+ "description": "GC ViT XX-Tiny model trained on the il-common dataset",
677
+ "resolution": (256, 256),
678
+ "formats": {
679
+ "pt": {
680
+ "file_size": 47.9,
681
+ "sha256": "5326a53903759e32178a6c2994639e6d0172faa51e1573a700f8d12b4f447c61",
682
+ }
683
+ },
684
+ "net": {"network": "gc_vit_xxt", "tag": "il-common"},
685
+ },
686
+ )
birder/net/ghostnet_v1.py CHANGED
@@ -254,14 +254,14 @@ class GhostNet_v1(DetectorBackbone):
254
254
 
255
255
  def freeze_stages(self, up_to_stage: int) -> None:
256
256
  for param in self.stem.parameters():
257
- param.requires_grad = False
257
+ param.requires_grad_(False)
258
258
 
259
259
  for idx, module in enumerate(self.body.children()):
260
260
  if idx >= up_to_stage:
261
261
  break
262
262
 
263
263
  for param in module.parameters():
264
- param.requires_grad = False
264
+ param.requires_grad_(False)
265
265
 
266
266
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
267
267
  x = self.stem(x)
birder/net/ghostnet_v2.py CHANGED
@@ -317,14 +317,14 @@ class GhostNet_v2(DetectorBackbone):
317
317
 
318
318
  def freeze_stages(self, up_to_stage: int) -> None:
319
319
  for param in self.stem.parameters():
320
- param.requires_grad = False
320
+ param.requires_grad_(False)
321
321
 
322
322
  for idx, module in enumerate(self.body.children()):
323
323
  if idx >= up_to_stage:
324
324
  break
325
325
 
326
326
  for param in module.parameters():
327
- param.requires_grad = False
327
+ param.requires_grad_(False)
328
328
 
329
329
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
330
330
  x = self.stem(x)
@@ -391,14 +391,14 @@ class GroupMixFormer(DetectorBackbone):
391
391
 
392
392
  def freeze_stages(self, up_to_stage: int) -> None:
393
393
  for param in self.stem.parameters():
394
- param.requires_grad = False
394
+ param.requires_grad_(False)
395
395
 
396
396
  for idx, module in enumerate(self.body.children()):
397
397
  if idx >= up_to_stage:
398
398
  break
399
399
 
400
400
  for param in module.parameters():
401
- param.requires_grad = False
401
+ param.requires_grad_(False)
402
402
 
403
403
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
404
404
  x = self.stem(x)