birder 0.3.0__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (157)
  1. birder/common/fs_ops.py +2 -2
  2. birder/common/training_cli.py +6 -1
  3. birder/common/training_utils.py +4 -2
  4. birder/net/base.py +3 -3
  5. birder/net/biformer.py +2 -2
  6. birder/net/cas_vit.py +6 -6
  7. birder/net/coat.py +8 -8
  8. birder/net/conv2former.py +2 -2
  9. birder/net/convnext_v1.py +2 -2
  10. birder/net/convnext_v2.py +2 -2
  11. birder/net/crossformer.py +2 -2
  12. birder/net/cspnet.py +2 -2
  13. birder/net/cswin_transformer.py +2 -2
  14. birder/net/darknet.py +2 -2
  15. birder/net/davit.py +2 -2
  16. birder/net/deit.py +3 -3
  17. birder/net/deit3.py +3 -3
  18. birder/net/densenet.py +2 -2
  19. birder/net/detection/deformable_detr.py +2 -2
  20. birder/net/detection/detr.py +2 -2
  21. birder/net/detection/efficientdet.py +2 -2
  22. birder/net/detection/faster_rcnn.py +2 -2
  23. birder/net/detection/fcos.py +2 -2
  24. birder/net/detection/retinanet.py +2 -2
  25. birder/net/detection/rt_detr_v1.py +4 -4
  26. birder/net/detection/ssd.py +2 -2
  27. birder/net/detection/ssdlite.py +2 -2
  28. birder/net/detection/yolo_v2.py +2 -2
  29. birder/net/detection/yolo_v3.py +2 -2
  30. birder/net/detection/yolo_v4.py +2 -2
  31. birder/net/edgenext.py +2 -2
  32. birder/net/edgevit.py +1 -1
  33. birder/net/efficientformer_v1.py +4 -4
  34. birder/net/efficientformer_v2.py +6 -6
  35. birder/net/efficientnet_lite.py +2 -2
  36. birder/net/efficientnet_v1.py +2 -2
  37. birder/net/efficientnet_v2.py +2 -2
  38. birder/net/efficientvim.py +3 -3
  39. birder/net/efficientvit_mit.py +2 -2
  40. birder/net/efficientvit_msft.py +2 -2
  41. birder/net/fasternet.py +2 -2
  42. birder/net/fastvit.py +2 -2
  43. birder/net/flexivit.py +6 -6
  44. birder/net/focalnet.py +2 -2
  45. birder/net/gc_vit.py +17 -2
  46. birder/net/ghostnet_v1.py +2 -2
  47. birder/net/ghostnet_v2.py +2 -2
  48. birder/net/groupmixformer.py +2 -2
  49. birder/net/hgnet_v1.py +2 -2
  50. birder/net/hgnet_v2.py +2 -2
  51. birder/net/hiera.py +2 -2
  52. birder/net/hieradet.py +2 -2
  53. birder/net/hornet.py +2 -2
  54. birder/net/iformer.py +2 -2
  55. birder/net/inception_next.py +2 -2
  56. birder/net/inception_resnet_v1.py +2 -2
  57. birder/net/inception_resnet_v2.py +2 -2
  58. birder/net/inception_v3.py +2 -2
  59. birder/net/inception_v4.py +2 -2
  60. birder/net/levit.py +4 -4
  61. birder/net/lit_v1.py +2 -2
  62. birder/net/lit_v1_tiny.py +2 -2
  63. birder/net/lit_v2.py +2 -2
  64. birder/net/maxvit.py +2 -2
  65. birder/net/metaformer.py +2 -2
  66. birder/net/mnasnet.py +2 -2
  67. birder/net/mobilenet_v1.py +2 -2
  68. birder/net/mobilenet_v2.py +2 -2
  69. birder/net/mobilenet_v3_large.py +2 -2
  70. birder/net/mobilenet_v4.py +2 -2
  71. birder/net/mobilenet_v4_hybrid.py +2 -2
  72. birder/net/mobileone.py +2 -2
  73. birder/net/mobilevit_v2.py +2 -2
  74. birder/net/moganet.py +2 -2
  75. birder/net/mvit_v2.py +2 -2
  76. birder/net/nextvit.py +2 -2
  77. birder/net/nfnet.py +2 -2
  78. birder/net/pit.py +6 -6
  79. birder/net/pvt_v1.py +2 -2
  80. birder/net/pvt_v2.py +2 -2
  81. birder/net/rdnet.py +2 -2
  82. birder/net/regionvit.py +6 -6
  83. birder/net/regnet.py +2 -2
  84. birder/net/regnet_z.py +2 -2
  85. birder/net/repghost.py +2 -2
  86. birder/net/repvgg.py +2 -2
  87. birder/net/repvit.py +6 -6
  88. birder/net/resnest.py +2 -2
  89. birder/net/resnet_v1.py +2 -2
  90. birder/net/resnet_v2.py +2 -2
  91. birder/net/resnext.py +2 -2
  92. birder/net/rope_deit3.py +3 -3
  93. birder/net/rope_flexivit.py +6 -6
  94. birder/net/rope_vit.py +20 -6
  95. birder/net/shufflenet_v1.py +2 -2
  96. birder/net/shufflenet_v2.py +2 -2
  97. birder/net/smt.py +1 -1
  98. birder/net/squeezenext.py +2 -2
  99. birder/net/ssl/byol.py +3 -2
  100. birder/net/ssl/capi.py +156 -11
  101. birder/net/ssl/data2vec.py +3 -1
  102. birder/net/ssl/data2vec2.py +3 -1
  103. birder/net/ssl/dino_v1.py +1 -1
  104. birder/net/ssl/dino_v2.py +140 -18
  105. birder/net/ssl/franca.py +145 -13
  106. birder/net/ssl/ibot.py +1 -1
  107. birder/net/ssl/mmcr.py +3 -1
  108. birder/net/starnet.py +2 -2
  109. birder/net/swiftformer.py +6 -6
  110. birder/net/swin_transformer_v1.py +2 -2
  111. birder/net/swin_transformer_v2.py +2 -2
  112. birder/net/tiny_vit.py +2 -2
  113. birder/net/transnext.py +1 -1
  114. birder/net/uniformer.py +1 -1
  115. birder/net/van.py +1 -1
  116. birder/net/vgg.py +1 -1
  117. birder/net/vgg_reduced.py +1 -1
  118. birder/net/vit.py +6 -6
  119. birder/net/vit_parallel.py +5 -5
  120. birder/net/vit_sam.py +3 -3
  121. birder/net/vovnet_v1.py +2 -2
  122. birder/net/vovnet_v2.py +2 -2
  123. birder/net/wide_resnet.py +2 -2
  124. birder/net/xception.py +2 -2
  125. birder/net/xcit.py +2 -2
  126. birder/results/detection.py +104 -0
  127. birder/results/gui.py +10 -8
  128. birder/scripts/benchmark.py +1 -1
  129. birder/scripts/train.py +7 -13
  130. birder/scripts/train_barlow_twins.py +7 -12
  131. birder/scripts/train_byol.py +8 -13
  132. birder/scripts/train_capi.py +33 -13
  133. birder/scripts/train_data2vec.py +8 -13
  134. birder/scripts/train_data2vec2.py +10 -15
  135. birder/scripts/train_detection.py +5 -10
  136. birder/scripts/train_dino_v1.py +16 -19
  137. birder/scripts/train_dino_v2.py +58 -44
  138. birder/scripts/train_dino_v2_dist.py +58 -44
  139. birder/scripts/train_franca.py +42 -28
  140. birder/scripts/train_i_jepa.py +8 -12
  141. birder/scripts/train_ibot.py +12 -15
  142. birder/scripts/train_kd.py +7 -13
  143. birder/scripts/train_mim.py +7 -11
  144. birder/scripts/train_mmcr.py +8 -13
  145. birder/scripts/train_rotnet.py +8 -13
  146. birder/scripts/train_simclr.py +7 -12
  147. birder/scripts/train_vicreg.py +7 -12
  148. birder/tools/det_results.py +91 -0
  149. birder/tools/results.py +11 -7
  150. birder/version.py +1 -1
  151. {birder-0.3.0.dist-info → birder-0.3.1.dist-info}/METADATA +1 -1
  152. birder-0.3.1.dist-info/RECORD +298 -0
  153. birder-0.3.0.dist-info/RECORD +0 -298
  154. {birder-0.3.0.dist-info → birder-0.3.1.dist-info}/WHEEL +0 -0
  155. {birder-0.3.0.dist-info → birder-0.3.1.dist-info}/entry_points.txt +0 -0
  156. {birder-0.3.0.dist-info → birder-0.3.1.dist-info}/licenses/LICENSE +0 -0
  157. {birder-0.3.0.dist-info → birder-0.3.1.dist-info}/top_level.txt +0 -0
birder/common/fs_ops.py CHANGED
@@ -627,7 +627,7 @@ def load_model(
627
627
  net.to(dtype)
628
628
  if inference is True:
629
629
  for param in net.parameters():
630
- param.requires_grad = False
630
+ param.requires_grad_(False)
631
631
 
632
632
  if pt2 is False: # NOTE: Remove when GraphModule add support for 'eval'
633
633
  net.eval()
@@ -799,7 +799,7 @@ def load_detection_model(
799
799
  net.to(dtype)
800
800
  if inference is True:
801
801
  for param in net.parameters():
802
- param.requires_grad = False
802
+ param.requires_grad_(False)
803
803
 
804
804
  net.eval()
805
805
 
birder/common/training_cli.py CHANGED
@@ -39,6 +39,7 @@ def add_optimization_args(parser: argparse.ArgumentParser, default_batch_size: i
39
39
  group = parser.add_argument_group("Optimization parameters")
40
40
  group.add_argument("--batch-size", type=int, default=default_batch_size, metavar="N", help="the batch size")
41
41
  group.add_argument("--opt", type=str, choices=list(get_args(OptimizerType)), default="sgd", help="optimizer to use")
42
+ group.add_argument("--opt-fused", default=False, action="store_true", help="use fused optimizer implementation")
42
43
  group.add_argument("--momentum", type=float, default=0.9, metavar="M", help="optimizer momentum")
43
44
  group.add_argument("--nesterov", default=False, action="store_true", help="use nesterov momentum")
44
45
  group.add_argument("--opt-eps", type=float, help="optimizer epsilon (None to use the optimizer default)")
@@ -249,6 +250,7 @@ def add_data_aug_args(
249
250
  default_level: int = 4,
250
251
  default_min_scale: Optional[float] = None,
251
252
  default_re_prob: Optional[float] = None,
253
+ smoothing_alpha: bool = False,
252
254
  mixup_cutmix: bool = False,
253
255
  ) -> None:
254
256
  group = parser.add_argument_group("Data augmentation parameters")
@@ -285,6 +287,8 @@ def add_data_aug_args(
285
287
  group.add_argument(
286
288
  "--simple-crop", default=False, action="store_true", help="use simple random crop (SRC) instead of RRC"
287
289
  )
290
+ if smoothing_alpha is True:
291
+ group.add_argument("--smoothing-alpha", type=float, default=0.0, help="label smoothing alpha")
288
292
  if mixup_cutmix is True:
289
293
  group.add_argument("--mixup-alpha", type=float, help="mixup alpha")
290
294
  group.add_argument("--cutmix", default=False, action="store_true", help="enable cutmix")
@@ -565,9 +569,9 @@ def add_training_data_args(parser: argparse.ArgumentParser, unsupervised: bool =
565
569
  group.add_argument("--wds", default=False, action="store_true", help="use webdataset for training")
566
570
  group.add_argument("--wds-info", type=str, metavar="FILE", help="wds info file path")
567
571
  group.add_argument("--wds-cache-dir", type=str, metavar="DIR", help="webdataset cache directory")
568
- group.add_argument("--wds-train-size", type=int, metavar="N", help="size of the wds training set")
569
572
  if unsupervised is False:
570
573
  group.add_argument("--wds-class-file", type=str, metavar="FILE", help="class list file")
574
+ group.add_argument("--wds-train-size", type=int, metavar="N", help="size of the wds training set")
571
575
  group.add_argument("--wds-val-size", type=int, metavar="N", help="size of the wds validation set")
572
576
  group.add_argument(
573
577
  "--wds-training-split", type=str, default="training", metavar="NAME", help="wds dataset train split"
@@ -576,6 +580,7 @@ def add_training_data_args(parser: argparse.ArgumentParser, unsupervised: bool =
576
580
  "--wds-val-split", type=str, default="validation", metavar="NAME", help="wds dataset validation split"
577
581
  )
578
582
  else:
583
+ group.add_argument("--wds-size", type=int, metavar="N", help="size of the wds")
579
584
  group.add_argument(
580
585
  "--wds-split", type=str, default="training", metavar="NAME", help="wds dataset split to load"
581
586
  )
birder/common/training_utils.py CHANGED
@@ -593,12 +593,14 @@ def get_optimizer(parameters: list[dict[str, Any]], l_rate: float, args: argpars
593
593
  kwargs["betas"] = args.opt_betas
594
594
  if getattr(args, "opt_alpha", None) is not None:
595
595
  kwargs["alpha"] = args.opt_alpha
596
+ if getattr(args, "opt_fused", False) is True:
597
+ kwargs["fused"] = True
596
598
 
597
599
  # For optimizer compilation
598
600
  # lr = torch.tensor(l_rate) - Causes weird LR scheduling bugs
599
601
  lr = l_rate
600
- if getattr(args, "compile_opt", False) is not False:
601
- if opt not in ("lamb", "lambw", "lars"):
602
+ if getattr(args, "compile_opt", False) is True:
603
+ if opt not in ("sgd", "lamb", "lambw", "lars"):
602
604
  logger.debug("Setting optimizer capturable to True")
603
605
  kwargs["capturable"] = True
604
606
 
birder/net/base.py CHANGED
@@ -173,14 +173,14 @@ class BaseNet(nn.Module):
173
173
 
174
174
  def freeze(self, freeze_classifier: bool = True, unfreeze_features: bool = False) -> None:
175
175
  for param in self.parameters():
176
- param.requires_grad = False
176
+ param.requires_grad_(False)
177
177
 
178
178
  if freeze_classifier is False:
179
179
  for param in self.classifier.parameters():
180
- param.requires_grad = True
180
+ param.requires_grad_(True)
181
181
  if unfreeze_features is True and hasattr(self, "features") is True:
182
182
  for param in self.features.parameters():
183
- param.requires_grad = True
183
+ param.requires_grad_(True)
184
184
 
185
185
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
186
186
  """
birder/net/biformer.py CHANGED
@@ -468,14 +468,14 @@ class BiFormer(DetectorBackbone):
468
468
 
469
469
  def freeze_stages(self, up_to_stage: int) -> None:
470
470
  for param in self.stem.parameters():
471
- param.requires_grad = False
471
+ param.requires_grad_(False)
472
472
 
473
473
  for idx, module in enumerate(self.body.children()):
474
474
  if idx >= up_to_stage:
475
475
  break
476
476
 
477
477
  for param in module.parameters():
478
- param.requires_grad = False
478
+ param.requires_grad_(False)
479
479
 
480
480
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
481
481
  x = self.stem(x)
birder/net/cas_vit.py CHANGED
@@ -269,18 +269,18 @@ class CAS_ViT(DetectorBackbone):
269
269
 
270
270
  def freeze(self, freeze_classifier: bool = True, unfreeze_features: bool = False) -> None:
271
271
  for param in self.parameters():
272
- param.requires_grad = False
272
+ param.requires_grad_(False)
273
273
 
274
274
  if freeze_classifier is False:
275
275
  for param in self.classifier.parameters():
276
- param.requires_grad = True
276
+ param.requires_grad_(True)
277
277
 
278
278
  for param in self.dist_classifier.parameters():
279
- param.requires_grad = True
279
+ param.requires_grad_(True)
280
280
 
281
281
  if unfreeze_features is True:
282
282
  for param in self.features.parameters():
283
- param.requires_grad = True
283
+ param.requires_grad_(True)
284
284
 
285
285
  def transform_to_backbone(self) -> None:
286
286
  self.features = nn.Identity()
@@ -300,14 +300,14 @@ class CAS_ViT(DetectorBackbone):
300
300
 
301
301
  def freeze_stages(self, up_to_stage: int) -> None:
302
302
  for param in self.stem.parameters():
303
- param.requires_grad = False
303
+ param.requires_grad_(False)
304
304
 
305
305
  for idx, module in enumerate(self.body.children()):
306
306
  if idx >= up_to_stage:
307
307
  break
308
308
 
309
309
  for param in module.parameters():
310
- param.requires_grad = False
310
+ param.requires_grad_(False)
311
311
 
312
312
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
313
313
  x = self.stem(x)
birder/net/coat.py CHANGED
@@ -563,24 +563,24 @@ class CoaT(DetectorBackbone):
563
563
  def freeze_stages(self, up_to_stage: int) -> None:
564
564
  if up_to_stage >= 1:
565
565
  for param in self.patch_embed1.parameters():
566
- param.requires_grad = False
566
+ param.requires_grad_(False)
567
567
  for param in self.serial_blocks1.parameters():
568
- param.requires_grad = False
568
+ param.requires_grad_(False)
569
569
  if up_to_stage >= 2:
570
570
  for param in self.patch_embed2.parameters():
571
- param.requires_grad = False
571
+ param.requires_grad_(False)
572
572
  for param in self.serial_blocks2.parameters():
573
- param.requires_grad = False
573
+ param.requires_grad_(False)
574
574
  if up_to_stage >= 3:
575
575
  for param in self.patch_embed3.parameters():
576
- param.requires_grad = False
576
+ param.requires_grad_(False)
577
577
  for param in self.serial_blocks3.parameters():
578
- param.requires_grad = False
578
+ param.requires_grad_(False)
579
579
  if up_to_stage >= 4:
580
580
  for param in self.patch_embed4.parameters():
581
- param.requires_grad = False
581
+ param.requires_grad_(False)
582
582
  for param in self.serial_blocks4.parameters():
583
- param.requires_grad = False
583
+ param.requires_grad_(False)
584
584
 
585
585
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
586
586
  features = self._features(x)
birder/net/conv2former.py CHANGED
@@ -218,14 +218,14 @@ class Conv2Former(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
218
218
 
219
219
  def freeze_stages(self, up_to_stage: int) -> None:
220
220
  for param in self.stem.parameters():
221
- param.requires_grad = False
221
+ param.requires_grad_(False)
222
222
 
223
223
  for idx, module in enumerate(self.body.children()):
224
224
  if idx >= up_to_stage:
225
225
  break
226
226
 
227
227
  for param in module.parameters():
228
- param.requires_grad = False
228
+ param.requires_grad_(False)
229
229
 
230
230
  def masked_encoding_retention(
231
231
  self,
birder/net/convnext_v1.py CHANGED
@@ -158,14 +158,14 @@ class ConvNeXt_v1(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
158
158
 
159
159
  def freeze_stages(self, up_to_stage: int) -> None:
160
160
  for param in self.stem.parameters():
161
- param.requires_grad = False
161
+ param.requires_grad_(False)
162
162
 
163
163
  for idx, module in enumerate(self.body.children()):
164
164
  if idx >= up_to_stage:
165
165
  break
166
166
 
167
167
  for param in module.parameters():
168
- param.requires_grad = False
168
+ param.requires_grad_(False)
169
169
 
170
170
  def masked_encoding_retention(
171
171
  self,
birder/net/convnext_v2.py CHANGED
@@ -180,14 +180,14 @@ class ConvNeXt_v2(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
180
180
 
181
181
  def freeze_stages(self, up_to_stage: int) -> None:
182
182
  for param in self.stem.parameters():
183
- param.requires_grad = False
183
+ param.requires_grad_(False)
184
184
 
185
185
  for idx, module in enumerate(self.body.children()):
186
186
  if idx >= up_to_stage:
187
187
  break
188
188
 
189
189
  for param in module.parameters():
190
- param.requires_grad = False
190
+ param.requires_grad_(False)
191
191
 
192
192
  def masked_encoding_retention(
193
193
  self,
birder/net/crossformer.py CHANGED
@@ -404,14 +404,14 @@ class CrossFormer(DetectorBackbone):
404
404
 
405
405
  def freeze_stages(self, up_to_stage: int) -> None:
406
406
  for param in self.patch_embed.parameters():
407
- param.requires_grad = False
407
+ param.requires_grad_(False)
408
408
 
409
409
  for idx, module in enumerate(self.body.children()):
410
410
  if idx >= up_to_stage:
411
411
  break
412
412
 
413
413
  for param in module.parameters():
414
- param.requires_grad = False
414
+ param.requires_grad_(False)
415
415
 
416
416
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
417
417
  x = self.patch_embed(x)
birder/net/cspnet.py CHANGED
@@ -342,14 +342,14 @@ class CSPNet(DetectorBackbone):
342
342
 
343
343
  def freeze_stages(self, up_to_stage: int) -> None:
344
344
  for param in self.stem.parameters():
345
- param.requires_grad = False
345
+ param.requires_grad_(False)
346
346
 
347
347
  for idx, module in enumerate(self.body.children()):
348
348
  if idx >= up_to_stage:
349
349
  break
350
350
 
351
351
  for param in module.parameters():
352
- param.requires_grad = False
352
+ param.requires_grad_(False)
353
353
 
354
354
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
355
355
  x = self.stem(x)
birder/net/cswin_transformer.py CHANGED
@@ -359,14 +359,14 @@ class CSWin_Transformer(DetectorBackbone):
359
359
 
360
360
  def freeze_stages(self, up_to_stage: int) -> None:
361
361
  for param in self.stem.parameters():
362
- param.requires_grad = False
362
+ param.requires_grad_(False)
363
363
 
364
364
  for idx, module in enumerate(self.body.children()):
365
365
  if idx >= up_to_stage:
366
366
  break
367
367
 
368
368
  for param in module.parameters():
369
- param.requires_grad = False
369
+ param.requires_grad_(False)
370
370
 
371
371
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
372
372
  x = self.stem(x)
birder/net/darknet.py CHANGED
@@ -115,14 +115,14 @@ class Darknet(DetectorBackbone):
115
115
 
116
116
  def freeze_stages(self, up_to_stage: int) -> None:
117
117
  for param in self.stem.parameters():
118
- param.requires_grad = False
118
+ param.requires_grad_(False)
119
119
 
120
120
  for idx, module in enumerate(self.body.children()):
121
121
  if idx >= up_to_stage:
122
122
  break
123
123
 
124
124
  for param in module.parameters():
125
- param.requires_grad = False
125
+ param.requires_grad_(False)
126
126
 
127
127
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
128
128
  x = self.stem(x)
birder/net/davit.py CHANGED
@@ -391,14 +391,14 @@ class DaViT(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
391
391
 
392
392
  def freeze_stages(self, up_to_stage: int) -> None:
393
393
  for param in self.stem.parameters():
394
- param.requires_grad = False
394
+ param.requires_grad_(False)
395
395
 
396
396
  for idx, module in enumerate(self.body.children()):
397
397
  if idx >= up_to_stage:
398
398
  break
399
399
 
400
400
  for param in module.parameters():
401
- param.requires_grad = False
401
+ param.requires_grad_(False)
402
402
 
403
403
  def masked_encoding_retention(
404
404
  self,
birder/net/deit.py CHANGED
@@ -117,14 +117,14 @@ class DeiT(BaseNet):
117
117
 
118
118
  def freeze(self, freeze_classifier: bool = True, unfreeze_features: bool = False) -> None:
119
119
  for param in self.parameters():
120
- param.requires_grad = False
120
+ param.requires_grad_(False)
121
121
 
122
122
  if freeze_classifier is False:
123
123
  for param in self.classifier.parameters():
124
- param.requires_grad = True
124
+ param.requires_grad_(True)
125
125
 
126
126
  for param in self.dist_classifier.parameters():
127
- param.requires_grad = True
127
+ param.requires_grad_(True)
128
128
 
129
129
  def set_causal_attention(self, is_causal: bool = True) -> None:
130
130
  self.encoder.set_causal_attention(is_causal)
birder/net/deit3.py CHANGED
@@ -182,16 +182,16 @@ class DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedT
182
182
 
183
183
  def freeze_stages(self, up_to_stage: int) -> None:
184
184
  for param in self.conv_proj.parameters():
185
- param.requires_grad = False
185
+ param.requires_grad_(False)
186
186
 
187
- self.pos_embedding.requires_grad = False
187
+ self.pos_embedding.requires_grad_(False)
188
188
 
189
189
  for idx, module in enumerate(self.encoder.children()):
190
190
  if idx >= up_to_stage:
191
191
  break
192
192
 
193
193
  for param in module.parameters():
194
- param.requires_grad = False
194
+ param.requires_grad_(False)
195
195
 
196
196
  def set_causal_attention(self, is_causal: bool = True) -> None:
197
197
  self.encoder.set_causal_attention(is_causal)
birder/net/densenet.py CHANGED
@@ -140,14 +140,14 @@ class DenseNet(DetectorBackbone):
140
140
 
141
141
  def freeze_stages(self, up_to_stage: int) -> None:
142
142
  for param in self.stem.parameters():
143
- param.requires_grad = False
143
+ param.requires_grad_(False)
144
144
 
145
145
  for idx, module in enumerate(self.body.children()):
146
146
  if idx >= up_to_stage:
147
147
  break
148
148
 
149
149
  for param in module.parameters():
150
- param.requires_grad = False
150
+ param.requires_grad_(False)
151
151
 
152
152
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
153
153
  x = self.stem(x)
birder/net/detection/deformable_detr.py CHANGED
@@ -633,11 +633,11 @@ class Deformable_DETR(DetectionBaseNet):
633
633
 
634
634
  def freeze(self, freeze_classifier: bool = True) -> None:
635
635
  for param in self.parameters():
636
- param.requires_grad = False
636
+ param.requires_grad_(False)
637
637
 
638
638
  if freeze_classifier is False:
639
639
  for param in self.class_embed.parameters():
640
- param.requires_grad = True
640
+ param.requires_grad_(True)
641
641
 
642
642
  def _get_src_permutation_idx(self, indices: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
643
643
  batch_idx = torch.concat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
birder/net/detection/detr.py CHANGED
@@ -356,11 +356,11 @@ class DETR(DetectionBaseNet):
356
356
 
357
357
  def freeze(self, freeze_classifier: bool = True) -> None:
358
358
  for param in self.parameters():
359
- param.requires_grad = False
359
+ param.requires_grad_(False)
360
360
 
361
361
  if freeze_classifier is False:
362
362
  for param in self.class_embed.parameters():
363
- param.requires_grad = True
363
+ param.requires_grad_(True)
364
364
 
365
365
  def _get_src_permutation_idx(self, indices: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
366
366
  batch_idx = torch.concat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
birder/net/detection/efficientdet.py CHANGED
@@ -601,11 +601,11 @@ class EfficientDet(DetectionBaseNet):
601
601
 
602
602
  def freeze(self, freeze_classifier: bool = True) -> None:
603
603
  for param in self.parameters():
604
- param.requires_grad = False
604
+ param.requires_grad_(False)
605
605
 
606
606
  if freeze_classifier is False:
607
607
  for param in self.class_net.parameters():
608
- param.requires_grad = True
608
+ param.requires_grad_(True)
609
609
 
610
610
  def compute_loss(
611
611
  self,
birder/net/detection/faster_rcnn.py CHANGED
@@ -851,11 +851,11 @@ class Faster_RCNN(DetectionBaseNet):
851
851
 
852
852
  def freeze(self, freeze_classifier: bool = True) -> None:
853
853
  for param in self.parameters():
854
- param.requires_grad = False
854
+ param.requires_grad_(False)
855
855
 
856
856
  if freeze_classifier is False:
857
857
  for param in self.roi_heads.box_predictor.parameters():
858
- param.requires_grad = True
858
+ param.requires_grad_(True)
859
859
 
860
860
  def forward(
861
861
  self,
birder/net/detection/fcos.py CHANGED
@@ -338,11 +338,11 @@ class FCOS(DetectionBaseNet):
338
338
 
339
339
  def freeze(self, freeze_classifier: bool = True) -> None:
340
340
  for param in self.parameters():
341
- param.requires_grad = False
341
+ param.requires_grad_(False)
342
342
 
343
343
  if freeze_classifier is False:
344
344
  for param in self.head.classification_head.parameters():
345
- param.requires_grad = True
345
+ param.requires_grad_(True)
346
346
 
347
347
  def compute_loss(
348
348
  self,
birder/net/detection/retinanet.py CHANGED
@@ -332,11 +332,11 @@ class RetinaNet(DetectionBaseNet):
332
332
 
333
333
  def freeze(self, freeze_classifier: bool = True) -> None:
334
334
  for param in self.parameters():
335
- param.requires_grad = False
335
+ param.requires_grad_(False)
336
336
 
337
337
  if freeze_classifier is False:
338
338
  for param in self.head.classification_head.parameters():
339
- param.requires_grad = True
339
+ param.requires_grad_(True)
340
340
 
341
341
  @torch.jit.unused # type: ignore[untyped-decorator]
342
342
  @torch.compiler.disable() # type: ignore[untyped-decorator]
birder/net/detection/rt_detr_v1.py CHANGED
@@ -742,16 +742,16 @@ class RT_DETR_v1(DetectionBaseNet):
742
742
 
743
743
  def freeze(self, freeze_classifier: bool = True) -> None:
744
744
  for param in self.parameters():
745
- param.requires_grad = False
745
+ param.requires_grad_(False)
746
746
 
747
747
  if freeze_classifier is False:
748
748
  for param in self.decoder.class_embed.parameters():
749
- param.requires_grad = True
749
+ param.requires_grad_(True)
750
750
  for param in self.decoder.enc_score_head.parameters():
751
- param.requires_grad = True
751
+ param.requires_grad_(True)
752
752
  if self.num_denoising > 0:
753
753
  for param in self.denoising_class_embed.parameters():
754
- param.requires_grad = True
754
+ param.requires_grad_(True)
755
755
 
756
756
  def _get_src_permutation_idx(self, indices: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
757
757
  batch_idx = torch.concat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
birder/net/detection/ssd.py CHANGED
@@ -341,11 +341,11 @@ class SSD(DetectionBaseNet):
341
341
 
342
342
  def freeze(self, freeze_classifier: bool = True) -> None:
343
343
  for param in self.parameters():
344
- param.requires_grad = False
344
+ param.requires_grad_(False)
345
345
 
346
346
  if freeze_classifier is False:
347
347
  for param in self.head.classification_head.parameters():
348
- param.requires_grad = True
348
+ param.requires_grad_(True)
349
349
 
350
350
  # pylint: disable=too-many-locals
351
351
  def compute_loss(
birder/net/detection/ssdlite.py CHANGED
@@ -197,8 +197,8 @@ class SSDLite(SSD):
197
197
 
198
198
  def freeze(self, freeze_classifier: bool = True) -> None:
199
199
  for param in self.parameters():
200
- param.requires_grad = False
200
+ param.requires_grad_(False)
201
201
 
202
202
  if freeze_classifier is False:
203
203
  for param in self.head.classification_head.parameters():
204
- param.requires_grad = True
204
+ param.requires_grad_(True)
birder/net/detection/yolo_v2.py CHANGED
@@ -270,11 +270,11 @@ class YOLO_v2(DetectionBaseNet):
270
270
 
271
271
  def freeze(self, freeze_classifier: bool = True) -> None:
272
272
  for param in self.parameters():
273
- param.requires_grad = False
273
+ param.requires_grad_(False)
274
274
 
275
275
  if freeze_classifier is False:
276
276
  for param in self.head.parameters():
277
- param.requires_grad = True
277
+ param.requires_grad_(True)
278
278
 
279
279
  def _compute_anchor_iou(self, box_wh: torch.Tensor, anchor_wh: torch.Tensor) -> torch.Tensor:
280
280
  """
birder/net/detection/yolo_v3.py CHANGED
@@ -376,11 +376,11 @@ class YOLO_v3(DetectionBaseNet):
376
376
 
377
377
  def freeze(self, freeze_classifier: bool = True) -> None:
378
378
  for param in self.parameters():
379
- param.requires_grad = False
379
+ param.requires_grad_(False)
380
380
 
381
381
  if freeze_classifier is False:
382
382
  for param in self.head.parameters():
383
- param.requires_grad = True
383
+ param.requires_grad_(True)
384
384
 
385
385
  def _compute_anchor_iou(self, box_wh: torch.Tensor, anchor_wh: torch.Tensor) -> torch.Tensor:
386
386
  """
birder/net/detection/yolo_v4.py CHANGED
@@ -444,11 +444,11 @@ class YOLO_v4(DetectionBaseNet):
444
444
 
445
445
  def freeze(self, freeze_classifier: bool = True) -> None:
446
446
  for param in self.parameters():
447
- param.requires_grad = False
447
+ param.requires_grad_(False)
448
448
 
449
449
  if freeze_classifier is False:
450
450
  for param in self.head.parameters():
451
- param.requires_grad = True
451
+ param.requires_grad_(True)
452
452
 
453
453
  def _compute_anchor_iou(self, box_wh: torch.Tensor, anchor_wh: torch.Tensor) -> torch.Tensor:
454
454
  """
birder/net/edgenext.py CHANGED
@@ -336,14 +336,14 @@ class EdgeNeXt(DetectorBackbone):
336
336
 
337
337
  def freeze_stages(self, up_to_stage: int) -> None:
338
338
  for param in self.stem.parameters():
339
- param.requires_grad = False
339
+ param.requires_grad_(False)
340
340
 
341
341
  for idx, module in enumerate(self.body.children()):
342
342
  if idx >= up_to_stage:
343
343
  break
344
344
 
345
345
  for param in module.parameters():
346
- param.requires_grad = False
346
+ param.requires_grad_(False)
347
347
 
348
348
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
349
349
  x = self.stem(x)
birder/net/edgevit.py CHANGED
@@ -354,7 +354,7 @@ class EdgeViT(DetectorBackbone):
354
354
  break
355
355
 
356
356
  for param in module.parameters():
357
- param.requires_grad = False
357
+ param.requires_grad_(False)
358
358
 
359
359
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
360
360
  return self.body(x)
birder/net/efficientformer_v1.py CHANGED
@@ -307,18 +307,18 @@ class EfficientFormer_v1(BaseNet):
307
307
 
308
308
  def freeze(self, freeze_classifier: bool = True, unfreeze_features: bool = False) -> None:
309
309
  for param in self.parameters():
310
- param.requires_grad = False
310
+ param.requires_grad_(False)
311
311
 
312
312
  if freeze_classifier is False:
313
313
  for param in self.classifier.parameters():
314
- param.requires_grad = True
314
+ param.requires_grad_(True)
315
315
 
316
316
  for param in self.dist_classifier.parameters():
317
- param.requires_grad = True
317
+ param.requires_grad_(True)
318
318
 
319
319
  if unfreeze_features is True:
320
320
  for param in self.features.parameters():
321
- param.requires_grad = True
321
+ param.requires_grad_(True)
322
322
 
323
323
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
324
324
  x = self.stem(x)