birder 0.3.0__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169) hide show
  1. birder/adversarial/deepfool.py +2 -0
  2. birder/adversarial/simba.py +2 -0
  3. birder/common/fs_ops.py +2 -2
  4. birder/common/masking.py +13 -4
  5. birder/common/training_cli.py +6 -1
  6. birder/common/training_utils.py +4 -2
  7. birder/inference/classification.py +1 -1
  8. birder/introspection/__init__.py +2 -0
  9. birder/introspection/base.py +0 -7
  10. birder/introspection/feature_pca.py +101 -0
  11. birder/kernels/soft_nms/soft_nms.cpp +5 -2
  12. birder/model_registry/model_registry.py +3 -2
  13. birder/net/base.py +3 -3
  14. birder/net/biformer.py +2 -2
  15. birder/net/cas_vit.py +6 -6
  16. birder/net/coat.py +8 -8
  17. birder/net/conv2former.py +2 -2
  18. birder/net/convnext_v1.py +22 -2
  19. birder/net/convnext_v2.py +2 -2
  20. birder/net/crossformer.py +2 -2
  21. birder/net/cspnet.py +2 -2
  22. birder/net/cswin_transformer.py +2 -2
  23. birder/net/darknet.py +2 -2
  24. birder/net/davit.py +2 -2
  25. birder/net/deit.py +3 -3
  26. birder/net/deit3.py +3 -3
  27. birder/net/densenet.py +2 -2
  28. birder/net/detection/deformable_detr.py +2 -2
  29. birder/net/detection/detr.py +2 -2
  30. birder/net/detection/efficientdet.py +2 -2
  31. birder/net/detection/faster_rcnn.py +2 -2
  32. birder/net/detection/fcos.py +2 -2
  33. birder/net/detection/retinanet.py +2 -2
  34. birder/net/detection/rt_detr_v1.py +4 -4
  35. birder/net/detection/ssd.py +2 -2
  36. birder/net/detection/ssdlite.py +2 -2
  37. birder/net/detection/yolo_v2.py +2 -2
  38. birder/net/detection/yolo_v3.py +2 -2
  39. birder/net/detection/yolo_v4.py +2 -2
  40. birder/net/edgenext.py +2 -2
  41. birder/net/edgevit.py +1 -1
  42. birder/net/efficientformer_v1.py +4 -4
  43. birder/net/efficientformer_v2.py +6 -6
  44. birder/net/efficientnet_lite.py +2 -2
  45. birder/net/efficientnet_v1.py +2 -2
  46. birder/net/efficientnet_v2.py +2 -2
  47. birder/net/efficientvim.py +3 -3
  48. birder/net/efficientvit_mit.py +2 -2
  49. birder/net/efficientvit_msft.py +2 -2
  50. birder/net/fasternet.py +2 -2
  51. birder/net/fastvit.py +2 -3
  52. birder/net/flexivit.py +11 -6
  53. birder/net/focalnet.py +2 -3
  54. birder/net/gc_vit.py +17 -2
  55. birder/net/ghostnet_v1.py +2 -2
  56. birder/net/ghostnet_v2.py +2 -2
  57. birder/net/groupmixformer.py +2 -2
  58. birder/net/hgnet_v1.py +2 -2
  59. birder/net/hgnet_v2.py +2 -2
  60. birder/net/hiera.py +2 -2
  61. birder/net/hieradet.py +2 -2
  62. birder/net/hornet.py +2 -2
  63. birder/net/iformer.py +2 -2
  64. birder/net/inception_next.py +2 -2
  65. birder/net/inception_resnet_v1.py +2 -2
  66. birder/net/inception_resnet_v2.py +2 -2
  67. birder/net/inception_v3.py +2 -2
  68. birder/net/inception_v4.py +2 -2
  69. birder/net/levit.py +4 -4
  70. birder/net/lit_v1.py +2 -2
  71. birder/net/lit_v1_tiny.py +2 -2
  72. birder/net/lit_v2.py +2 -2
  73. birder/net/maxvit.py +2 -2
  74. birder/net/metaformer.py +2 -2
  75. birder/net/mnasnet.py +2 -2
  76. birder/net/mobilenet_v1.py +2 -2
  77. birder/net/mobilenet_v2.py +2 -2
  78. birder/net/mobilenet_v3_large.py +2 -2
  79. birder/net/mobilenet_v4.py +2 -2
  80. birder/net/mobilenet_v4_hybrid.py +2 -2
  81. birder/net/mobileone.py +2 -2
  82. birder/net/mobilevit_v2.py +2 -2
  83. birder/net/moganet.py +2 -2
  84. birder/net/mvit_v2.py +2 -2
  85. birder/net/nextvit.py +2 -2
  86. birder/net/nfnet.py +2 -2
  87. birder/net/pit.py +6 -6
  88. birder/net/pvt_v1.py +2 -2
  89. birder/net/pvt_v2.py +2 -2
  90. birder/net/rdnet.py +2 -2
  91. birder/net/regionvit.py +6 -6
  92. birder/net/regnet.py +2 -2
  93. birder/net/regnet_z.py +2 -2
  94. birder/net/repghost.py +2 -2
  95. birder/net/repvgg.py +2 -2
  96. birder/net/repvit.py +6 -6
  97. birder/net/resnest.py +2 -2
  98. birder/net/resnet_v1.py +2 -2
  99. birder/net/resnet_v2.py +2 -2
  100. birder/net/resnext.py +2 -2
  101. birder/net/rope_deit3.py +3 -3
  102. birder/net/rope_flexivit.py +13 -6
  103. birder/net/rope_vit.py +69 -10
  104. birder/net/shufflenet_v1.py +2 -2
  105. birder/net/shufflenet_v2.py +2 -2
  106. birder/net/smt.py +1 -2
  107. birder/net/squeezenext.py +2 -2
  108. birder/net/ssl/byol.py +3 -2
  109. birder/net/ssl/capi.py +156 -11
  110. birder/net/ssl/data2vec.py +3 -1
  111. birder/net/ssl/data2vec2.py +3 -1
  112. birder/net/ssl/dino_v1.py +1 -1
  113. birder/net/ssl/dino_v2.py +140 -18
  114. birder/net/ssl/franca.py +145 -13
  115. birder/net/ssl/ibot.py +1 -2
  116. birder/net/ssl/mmcr.py +3 -1
  117. birder/net/starnet.py +2 -2
  118. birder/net/swiftformer.py +6 -6
  119. birder/net/swin_transformer_v1.py +2 -2
  120. birder/net/swin_transformer_v2.py +2 -2
  121. birder/net/tiny_vit.py +2 -2
  122. birder/net/transnext.py +1 -1
  123. birder/net/uniformer.py +1 -1
  124. birder/net/van.py +1 -1
  125. birder/net/vgg.py +1 -1
  126. birder/net/vgg_reduced.py +1 -1
  127. birder/net/vit.py +172 -8
  128. birder/net/vit_parallel.py +5 -5
  129. birder/net/vit_sam.py +3 -3
  130. birder/net/vovnet_v1.py +2 -2
  131. birder/net/vovnet_v2.py +2 -2
  132. birder/net/wide_resnet.py +2 -2
  133. birder/net/xception.py +2 -2
  134. birder/net/xcit.py +2 -2
  135. birder/results/detection.py +104 -0
  136. birder/results/gui.py +10 -8
  137. birder/scripts/benchmark.py +1 -1
  138. birder/scripts/train.py +13 -18
  139. birder/scripts/train_barlow_twins.py +10 -14
  140. birder/scripts/train_byol.py +11 -15
  141. birder/scripts/train_capi.py +38 -17
  142. birder/scripts/train_data2vec.py +11 -15
  143. birder/scripts/train_data2vec2.py +13 -17
  144. birder/scripts/train_detection.py +11 -14
  145. birder/scripts/train_dino_v1.py +20 -22
  146. birder/scripts/train_dino_v2.py +126 -63
  147. birder/scripts/train_dino_v2_dist.py +127 -64
  148. birder/scripts/train_franca.py +49 -34
  149. birder/scripts/train_i_jepa.py +11 -14
  150. birder/scripts/train_ibot.py +16 -18
  151. birder/scripts/train_kd.py +14 -20
  152. birder/scripts/train_mim.py +10 -13
  153. birder/scripts/train_mmcr.py +11 -15
  154. birder/scripts/train_rotnet.py +12 -16
  155. birder/scripts/train_simclr.py +10 -14
  156. birder/scripts/train_vicreg.py +10 -14
  157. birder/tools/avg_model.py +24 -8
  158. birder/tools/det_results.py +91 -0
  159. birder/tools/introspection.py +35 -9
  160. birder/tools/results.py +11 -7
  161. birder/tools/show_iterator.py +1 -1
  162. birder/version.py +1 -1
  163. {birder-0.3.0.dist-info → birder-0.3.2.dist-info}/METADATA +1 -1
  164. birder-0.3.2.dist-info/RECORD +299 -0
  165. birder-0.3.0.dist-info/RECORD +0 -298
  166. {birder-0.3.0.dist-info → birder-0.3.2.dist-info}/WHEEL +0 -0
  167. {birder-0.3.0.dist-info → birder-0.3.2.dist-info}/entry_points.txt +0 -0
  168. {birder-0.3.0.dist-info → birder-0.3.2.dist-info}/licenses/LICENSE +0 -0
  169. {birder-0.3.0.dist-info → birder-0.3.2.dist-info}/top_level.txt +0 -0
birder/net/hgnet_v1.py CHANGED
@@ -387,14 +387,14 @@ class HGNet_v1(DetectorBackbone):
387
387
 
388
388
  def freeze_stages(self, up_to_stage: int) -> None:
389
389
  for param in self.stem.parameters():
390
- param.requires_grad = False
390
+ param.requires_grad_(False)
391
391
 
392
392
  for idx, module in enumerate(self.body.children()):
393
393
  if idx >= up_to_stage:
394
394
  break
395
395
 
396
396
  for param in module.parameters():
397
- param.requires_grad = False
397
+ param.requires_grad_(False)
398
398
 
399
399
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
400
400
  x = self.stem(x)
birder/net/hgnet_v2.py CHANGED
@@ -180,14 +180,14 @@ class HGNet_v2(DetectorBackbone):
180
180
 
181
181
  def freeze_stages(self, up_to_stage: int) -> None:
182
182
  for param in self.stem.parameters():
183
- param.requires_grad = False
183
+ param.requires_grad_(False)
184
184
 
185
185
  for idx, module in enumerate(self.body.children()):
186
186
  if idx >= up_to_stage:
187
187
  break
188
188
 
189
189
  for param in module.parameters():
190
- param.requires_grad = False
190
+ param.requires_grad_(False)
191
191
 
192
192
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
193
193
  x = self.stem(x)
birder/net/hiera.py CHANGED
@@ -515,14 +515,14 @@ class Hiera(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin):
515
515
 
516
516
  def freeze_stages(self, up_to_stage: int) -> None:
517
517
  for param in self.stem.parameters():
518
- param.requires_grad = False
518
+ param.requires_grad_(False)
519
519
 
520
520
  for idx, module in enumerate(self.body.children()):
521
521
  if idx >= up_to_stage:
522
522
  break
523
523
 
524
524
  for param in module.parameters():
525
- param.requires_grad = False
525
+ param.requires_grad_(False)
526
526
 
527
527
  def masked_encoding_omission(
528
528
  self,
birder/net/hieradet.py CHANGED
@@ -312,14 +312,14 @@ class HieraDet(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
312
312
 
313
313
  def freeze_stages(self, up_to_stage: int) -> None:
314
314
  for param in self.stem.parameters():
315
- param.requires_grad = False
315
+ param.requires_grad_(False)
316
316
 
317
317
  for idx, module in enumerate(self.body.children()):
318
318
  if idx >= up_to_stage:
319
319
  break
320
320
 
321
321
  for param in module.parameters():
322
- param.requires_grad = False
322
+ param.requires_grad_(False)
323
323
 
324
324
  def masked_encoding_retention(
325
325
  self,
birder/net/hornet.py CHANGED
@@ -299,14 +299,14 @@ class HorNet(DetectorBackbone):
299
299
 
300
300
  def freeze_stages(self, up_to_stage: int) -> None:
301
301
  for param in self.stem.parameters():
302
- param.requires_grad = False
302
+ param.requires_grad_(False)
303
303
 
304
304
  for idx, module in enumerate(self.body.children()):
305
305
  if idx >= up_to_stage:
306
306
  break
307
307
 
308
308
  for param in module.parameters():
309
- param.requires_grad = False
309
+ param.requires_grad_(False)
310
310
 
311
311
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
312
312
  x = self.stem(x)
birder/net/iformer.py CHANGED
@@ -424,14 +424,14 @@ class iFormer(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
424
424
 
425
425
  def freeze_stages(self, up_to_stage: int) -> None:
426
426
  for param in self.stem.parameters():
427
- param.requires_grad = False
427
+ param.requires_grad_(False)
428
428
 
429
429
  for idx, module in enumerate(self.body.children()):
430
430
  if idx >= up_to_stage:
431
431
  break
432
432
 
433
433
  for param in module.parameters():
434
- param.requires_grad = False
434
+ param.requires_grad_(False)
435
435
 
436
436
  def masked_encoding_retention(
437
437
  self,
@@ -261,14 +261,14 @@ class Inception_NeXt(DetectorBackbone):
261
261
 
262
262
  def freeze_stages(self, up_to_stage: int) -> None:
263
263
  for param in self.stem.parameters():
264
- param.requires_grad = False
264
+ param.requires_grad_(False)
265
265
 
266
266
  for idx, module in enumerate(self.body.children()):
267
267
  if idx >= up_to_stage:
268
268
  break
269
269
 
270
270
  for param in module.parameters():
271
- param.requires_grad = False
271
+ param.requires_grad_(False)
272
272
 
273
273
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
274
274
  x = self.stem(x)
@@ -236,14 +236,14 @@ class Inception_ResNet_v1(DetectorBackbone):
236
236
 
237
237
  def freeze_stages(self, up_to_stage: int) -> None:
238
238
  for param in self.stem.parameters():
239
- param.requires_grad = False
239
+ param.requires_grad_(False)
240
240
 
241
241
  for idx, module in enumerate(self.body.children()):
242
242
  if idx >= up_to_stage:
243
243
  break
244
244
 
245
245
  for param in module.parameters():
246
- param.requires_grad = False
246
+ param.requires_grad_(False)
247
247
 
248
248
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
249
249
  x = self.stem(x)
@@ -277,14 +277,14 @@ class Inception_ResNet_v2(DetectorBackbone):
277
277
 
278
278
  def freeze_stages(self, up_to_stage: int) -> None:
279
279
  for param in self.stem.parameters():
280
- param.requires_grad = False
280
+ param.requires_grad_(False)
281
281
 
282
282
  for idx, module in enumerate(self.body.children()):
283
283
  if idx >= up_to_stage:
284
284
  break
285
285
 
286
286
  for param in module.parameters():
287
- param.requires_grad = False
287
+ param.requires_grad_(False)
288
288
 
289
289
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
290
290
  x = self.stem(x)
@@ -277,14 +277,14 @@ class Inception_v3(DetectorBackbone):
277
277
 
278
278
  def freeze_stages(self, up_to_stage: int) -> None:
279
279
  for param in self.stem.parameters():
280
- param.requires_grad = False
280
+ param.requires_grad_(False)
281
281
 
282
282
  for idx, module in enumerate(self.body.children()):
283
283
  if idx >= up_to_stage:
284
284
  break
285
285
 
286
286
  for param in module.parameters():
287
- param.requires_grad = False
287
+ param.requires_grad_(False)
288
288
 
289
289
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
290
290
  x = self.stem(x)
@@ -306,14 +306,14 @@ class Inception_v4(DetectorBackbone):
306
306
 
307
307
  def freeze_stages(self, up_to_stage: int) -> None:
308
308
  for param in self.stem.parameters():
309
- param.requires_grad = False
309
+ param.requires_grad_(False)
310
310
 
311
311
  for idx, module in enumerate(self.body.children()):
312
312
  if idx >= up_to_stage:
313
313
  break
314
314
 
315
315
  for param in module.parameters():
316
- param.requires_grad = False
316
+ param.requires_grad_(False)
317
317
 
318
318
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
319
319
  x = self.stem(x)
birder/net/levit.py CHANGED
@@ -399,18 +399,18 @@ class LeViT(BaseNet):
399
399
 
400
400
  def freeze(self, freeze_classifier: bool = True, unfreeze_features: bool = False) -> None:
401
401
  for param in self.parameters():
402
- param.requires_grad = False
402
+ param.requires_grad_(False)
403
403
 
404
404
  if freeze_classifier is False:
405
405
  for param in self.classifier.parameters():
406
- param.requires_grad = True
406
+ param.requires_grad_(True)
407
407
 
408
408
  for param in self.dist_classifier.parameters():
409
- param.requires_grad = True
409
+ param.requires_grad_(True)
410
410
 
411
411
  if unfreeze_features is True:
412
412
  for param in self.features.parameters():
413
- param.requires_grad = True
413
+ param.requires_grad_(True)
414
414
 
415
415
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
416
416
  x = self.stem(x)
birder/net/lit_v1.py CHANGED
@@ -375,14 +375,14 @@ class LIT_v1(DetectorBackbone):
375
375
 
376
376
  def freeze_stages(self, up_to_stage: int) -> None:
377
377
  for param in self.stem.parameters():
378
- param.requires_grad = False
378
+ param.requires_grad_(False)
379
379
 
380
380
  for idx, stage in enumerate(self.body.values()):
381
381
  if idx >= up_to_stage:
382
382
  break
383
383
 
384
384
  for param in stage.parameters():
385
- param.requires_grad = False
385
+ param.requires_grad_(False)
386
386
 
387
387
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
388
388
  x = self.stem(x)
birder/net/lit_v1_tiny.py CHANGED
@@ -265,14 +265,14 @@ class LIT_v1_Tiny(DetectorBackbone):
265
265
 
266
266
  def freeze_stages(self, up_to_stage: int) -> None:
267
267
  for param in self.stem.parameters():
268
- param.requires_grad = False
268
+ param.requires_grad_(False)
269
269
 
270
270
  for idx, stage in enumerate(self.body.values()):
271
271
  if idx >= up_to_stage:
272
272
  break
273
273
 
274
274
  for param in stage.parameters():
275
- param.requires_grad = False
275
+ param.requires_grad_(False)
276
276
 
277
277
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
278
278
  x = self.stem(x)
birder/net/lit_v2.py CHANGED
@@ -375,14 +375,14 @@ class LIT_v2(DetectorBackbone):
375
375
 
376
376
  def freeze_stages(self, up_to_stage: int) -> None:
377
377
  for param in self.stem.parameters():
378
- param.requires_grad = False
378
+ param.requires_grad_(False)
379
379
 
380
380
  for idx, stage in enumerate(self.body.values()):
381
381
  if idx >= up_to_stage:
382
382
  break
383
383
 
384
384
  for param in stage.parameters():
385
- param.requires_grad = False
385
+ param.requires_grad_(False)
386
386
 
387
387
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
388
388
  x = self.stem(x)
birder/net/maxvit.py CHANGED
@@ -589,14 +589,14 @@ class MaxViT(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
589
589
 
590
590
  def freeze_stages(self, up_to_stage: int) -> None:
591
591
  for param in self.stem.parameters():
592
- param.requires_grad = False
592
+ param.requires_grad_(False)
593
593
 
594
594
  for idx, module in enumerate(self.body.children()):
595
595
  if idx >= up_to_stage:
596
596
  break
597
597
 
598
598
  for param in module.parameters():
599
- param.requires_grad = False
599
+ param.requires_grad_(False)
600
600
 
601
601
  def masked_encoding_retention(
602
602
  self,
birder/net/metaformer.py CHANGED
@@ -449,14 +449,14 @@ class MetaFormer(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
449
449
 
450
450
  def freeze_stages(self, up_to_stage: int) -> None:
451
451
  for param in self.stem.parameters():
452
- param.requires_grad = False
452
+ param.requires_grad_(False)
453
453
 
454
454
  for idx, module in enumerate(self.body.children()):
455
455
  if idx >= up_to_stage:
456
456
  break
457
457
 
458
458
  for param in module.parameters():
459
- param.requires_grad = False
459
+ param.requires_grad_(False)
460
460
 
461
461
  def masked_encoding_retention(
462
462
  self,
birder/net/mnasnet.py CHANGED
@@ -251,14 +251,14 @@ class MNASNet(DetectorBackbone):
251
251
 
252
252
  def freeze_stages(self, up_to_stage: int) -> None:
253
253
  for param in self.stem.parameters():
254
- param.requires_grad = False
254
+ param.requires_grad_(False)
255
255
 
256
256
  for idx, module in enumerate(self.body.children()):
257
257
  if idx >= up_to_stage:
258
258
  break
259
259
 
260
260
  for param in module.parameters():
261
- param.requires_grad = False
261
+ param.requires_grad_(False)
262
262
 
263
263
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
264
264
  x = self.stem(x)
@@ -136,14 +136,14 @@ class MobileNet_v1(DetectorBackbone):
136
136
 
137
137
  def freeze_stages(self, up_to_stage: int) -> None:
138
138
  for param in self.stem.parameters():
139
- param.requires_grad = False
139
+ param.requires_grad_(False)
140
140
 
141
141
  for idx, module in enumerate(self.body.children()):
142
142
  if idx >= up_to_stage:
143
143
  break
144
144
 
145
145
  for param in module.parameters():
146
- param.requires_grad = False
146
+ param.requires_grad_(False)
147
147
 
148
148
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
149
149
  x = self.stem(x)
@@ -204,14 +204,14 @@ class MobileNet_v2(DetectorBackbone):
204
204
 
205
205
  def freeze_stages(self, up_to_stage: int) -> None:
206
206
  for param in self.stem.parameters():
207
- param.requires_grad = False
207
+ param.requires_grad_(False)
208
208
 
209
209
  for idx, module in enumerate(self.body.children()):
210
210
  if idx >= up_to_stage:
211
211
  break
212
212
 
213
213
  for param in module.parameters():
214
- param.requires_grad = False
214
+ param.requires_grad_(False)
215
215
 
216
216
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
217
217
  x = self.stem(x)
@@ -236,14 +236,14 @@ class MobileNet_v3_Large(DetectorBackbone):
236
236
 
237
237
  def freeze_stages(self, up_to_stage: int) -> None:
238
238
  for param in self.stem.parameters():
239
- param.requires_grad = False
239
+ param.requires_grad_(False)
240
240
 
241
241
  for idx, module in enumerate(self.body.children()):
242
242
  if idx >= up_to_stage:
243
243
  break
244
244
 
245
245
  for param in module.parameters():
246
- param.requires_grad = False
246
+ param.requires_grad_(False)
247
247
 
248
248
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
249
249
  x = self.stem(x)
@@ -493,14 +493,14 @@ class MobileNet_v4(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin)
493
493
 
494
494
  def freeze_stages(self, up_to_stage: int) -> None:
495
495
  for param in self.stem.parameters():
496
- param.requires_grad = False
496
+ param.requires_grad_(False)
497
497
 
498
498
  for idx, module in enumerate(self.body.children()):
499
499
  if idx >= up_to_stage:
500
500
  break
501
501
 
502
502
  for param in module.parameters():
503
- param.requires_grad = False
503
+ param.requires_grad_(False)
504
504
 
505
505
  def masked_encoding_retention(
506
506
  self,
@@ -439,14 +439,14 @@ class MobileNet_v4_Hybrid(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentio
439
439
 
440
440
  def freeze_stages(self, up_to_stage: int) -> None:
441
441
  for param in self.stem.parameters():
442
- param.requires_grad = False
442
+ param.requires_grad_(False)
443
443
 
444
444
  for idx, module in enumerate(self.body.children()):
445
445
  if idx >= up_to_stage:
446
446
  break
447
447
 
448
448
  for param in module.parameters():
449
- param.requires_grad = False
449
+ param.requires_grad_(False)
450
450
 
451
451
  def masked_encoding_retention(
452
452
  self,
birder/net/mobileone.py CHANGED
@@ -363,14 +363,14 @@ class MobileOne(DetectorBackbone):
363
363
 
364
364
  def freeze_stages(self, up_to_stage: int) -> None:
365
365
  for param in self.stem.parameters():
366
- param.requires_grad = False
366
+ param.requires_grad_(False)
367
367
 
368
368
  for idx, module in enumerate(self.body.children()):
369
369
  if idx >= up_to_stage:
370
370
  break
371
371
 
372
372
  for param in module.parameters():
373
- param.requires_grad = False
373
+ param.requires_grad_(False)
374
374
 
375
375
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
376
376
  x = self.stem(x)
@@ -323,14 +323,14 @@ class MobileViT_v2(DetectorBackbone):
323
323
 
324
324
  def freeze_stages(self, up_to_stage: int) -> None:
325
325
  for param in self.stem.parameters():
326
- param.requires_grad = False
326
+ param.requires_grad_(False)
327
327
 
328
328
  for idx, module in enumerate(self.body.children()):
329
329
  if idx >= up_to_stage:
330
330
  break
331
331
 
332
332
  for param in module.parameters():
333
- param.requires_grad = False
333
+ param.requires_grad_(False)
334
334
 
335
335
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
336
336
  x = self.stem(x)
birder/net/moganet.py CHANGED
@@ -330,14 +330,14 @@ class MogaNet(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
330
330
 
331
331
  def freeze_stages(self, up_to_stage: int) -> None:
332
332
  for param in self.stem.parameters():
333
- param.requires_grad = False
333
+ param.requires_grad_(False)
334
334
 
335
335
  for idx, module in enumerate(self.body.children()):
336
336
  if idx >= up_to_stage:
337
337
  break
338
338
 
339
339
  for param in module.parameters():
340
- param.requires_grad = False
340
+ param.requires_grad_(False)
341
341
 
342
342
  def masked_encoding_retention(
343
343
  self,
birder/net/mvit_v2.py CHANGED
@@ -543,14 +543,14 @@ class MViT_v2(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
543
543
 
544
544
  def freeze_stages(self, up_to_stage: int) -> None:
545
545
  for param in self.patch_embed.parameters():
546
- param.requires_grad = False
546
+ param.requires_grad_(False)
547
547
 
548
548
  for idx, module in enumerate(self.body.children()):
549
549
  if idx >= up_to_stage:
550
550
  break
551
551
 
552
552
  for param in module.parameters():
553
- param.requires_grad = False
553
+ param.requires_grad_(False)
554
554
 
555
555
  def masked_encoding_retention(
556
556
  self,
birder/net/nextvit.py CHANGED
@@ -381,14 +381,14 @@ class NextViT(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
381
381
 
382
382
  def freeze_stages(self, up_to_stage: int) -> None:
383
383
  for param in self.stem.parameters():
384
- param.requires_grad = False
384
+ param.requires_grad_(False)
385
385
 
386
386
  for idx, module in enumerate(self.body.children()):
387
387
  if idx >= up_to_stage:
388
388
  break
389
389
 
390
390
  for param in module.parameters():
391
- param.requires_grad = False
391
+ param.requires_grad_(False)
392
392
 
393
393
  def masked_encoding_retention(
394
394
  self,
birder/net/nfnet.py CHANGED
@@ -294,14 +294,14 @@ class NFNet(DetectorBackbone):
294
294
 
295
295
  def freeze_stages(self, up_to_stage: int) -> None:
296
296
  for param in self.stem.parameters():
297
- param.requires_grad = False
297
+ param.requires_grad_(False)
298
298
 
299
299
  for idx, module in enumerate(self.body.children()):
300
300
  if idx >= up_to_stage:
301
301
  break
302
302
 
303
303
  for param in module.parameters():
304
- param.requires_grad = False
304
+ param.requires_grad_(False)
305
305
 
306
306
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
307
307
  x = self.stem(x)
birder/net/pit.py CHANGED
@@ -172,18 +172,18 @@ class PiT(DetectorBackbone):
172
172
 
173
173
  def freeze(self, freeze_classifier: bool = True, unfreeze_features: bool = False) -> None:
174
174
  for param in self.parameters():
175
- param.requires_grad = False
175
+ param.requires_grad_(False)
176
176
 
177
177
  if freeze_classifier is False:
178
178
  for param in self.classifier.parameters():
179
- param.requires_grad = True
179
+ param.requires_grad_(True)
180
180
 
181
181
  for param in self.dist_classifier.parameters():
182
- param.requires_grad = True
182
+ param.requires_grad_(True)
183
183
 
184
184
  if unfreeze_features is True:
185
185
  for param in self.norm.parameters():
186
- param.requires_grad = True
186
+ param.requires_grad_(True)
187
187
 
188
188
  def transform_to_backbone(self) -> None:
189
189
  self.norm = nn.Identity()
@@ -205,14 +205,14 @@ class PiT(DetectorBackbone):
205
205
 
206
206
  def freeze_stages(self, up_to_stage: int) -> None:
207
207
  for param in self.stem.parameters():
208
- param.requires_grad = False
208
+ param.requires_grad_(False)
209
209
 
210
210
  for idx, module in enumerate(self.body.children()):
211
211
  if idx >= up_to_stage:
212
212
  break
213
213
 
214
214
  for param in module.parameters():
215
- param.requires_grad = False
215
+ param.requires_grad_(False)
216
216
 
217
217
  def forward_features(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
218
218
  x = self.stem(x)
birder/net/pvt_v1.py CHANGED
@@ -277,14 +277,14 @@ class PVT_v1(DetectorBackbone):
277
277
 
278
278
  def freeze_stages(self, up_to_stage: int) -> None:
279
279
  for param in self.patch_embed.parameters():
280
- param.requires_grad = False
280
+ param.requires_grad_(False)
281
281
 
282
282
  for idx, module in enumerate(self.body.children()):
283
283
  if idx >= up_to_stage:
284
284
  break
285
285
 
286
286
  for param in module.parameters():
287
- param.requires_grad = False
287
+ param.requires_grad_(False)
288
288
 
289
289
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
290
290
  x = self.patch_embed(x)
birder/net/pvt_v2.py CHANGED
@@ -336,14 +336,14 @@ class PVT_v2(DetectorBackbone):
336
336
 
337
337
  def freeze_stages(self, up_to_stage: int) -> None:
338
338
  for param in self.patch_embed.parameters():
339
- param.requires_grad = False
339
+ param.requires_grad_(False)
340
340
 
341
341
  for idx, module in enumerate(self.body.children()):
342
342
  if idx >= up_to_stage:
343
343
  break
344
344
 
345
345
  for param in module.parameters():
346
- param.requires_grad = False
346
+ param.requires_grad_(False)
347
347
 
348
348
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
349
349
  x = self.patch_embed(x)
birder/net/rdnet.py CHANGED
@@ -247,14 +247,14 @@ class RDNet(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
247
247
 
248
248
  def freeze_stages(self, up_to_stage: int) -> None:
249
249
  for param in self.stem.parameters():
250
- param.requires_grad = False
250
+ param.requires_grad_(False)
251
251
 
252
252
  for idx, module in enumerate(self.body.children()):
253
253
  if idx >= up_to_stage:
254
254
  break
255
255
 
256
256
  for param in module.parameters():
257
- param.requires_grad = False
257
+ param.requires_grad_(False)
258
258
 
259
259
  def masked_encoding_retention(
260
260
  self,