birder 0.3.0__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- birder/adversarial/deepfool.py +2 -0
- birder/adversarial/simba.py +2 -0
- birder/common/fs_ops.py +2 -2
- birder/common/masking.py +13 -4
- birder/common/training_cli.py +6 -1
- birder/common/training_utils.py +4 -2
- birder/inference/classification.py +1 -1
- birder/introspection/__init__.py +2 -0
- birder/introspection/base.py +0 -7
- birder/introspection/feature_pca.py +101 -0
- birder/kernels/soft_nms/soft_nms.cpp +5 -2
- birder/model_registry/model_registry.py +3 -2
- birder/net/base.py +3 -3
- birder/net/biformer.py +2 -2
- birder/net/cas_vit.py +6 -6
- birder/net/coat.py +8 -8
- birder/net/conv2former.py +2 -2
- birder/net/convnext_v1.py +22 -2
- birder/net/convnext_v2.py +2 -2
- birder/net/crossformer.py +2 -2
- birder/net/cspnet.py +2 -2
- birder/net/cswin_transformer.py +2 -2
- birder/net/darknet.py +2 -2
- birder/net/davit.py +2 -2
- birder/net/deit.py +3 -3
- birder/net/deit3.py +3 -3
- birder/net/densenet.py +2 -2
- birder/net/detection/deformable_detr.py +2 -2
- birder/net/detection/detr.py +2 -2
- birder/net/detection/efficientdet.py +2 -2
- birder/net/detection/faster_rcnn.py +2 -2
- birder/net/detection/fcos.py +2 -2
- birder/net/detection/retinanet.py +2 -2
- birder/net/detection/rt_detr_v1.py +4 -4
- birder/net/detection/ssd.py +2 -2
- birder/net/detection/ssdlite.py +2 -2
- birder/net/detection/yolo_v2.py +2 -2
- birder/net/detection/yolo_v3.py +2 -2
- birder/net/detection/yolo_v4.py +2 -2
- birder/net/edgenext.py +2 -2
- birder/net/edgevit.py +1 -1
- birder/net/efficientformer_v1.py +4 -4
- birder/net/efficientformer_v2.py +6 -6
- birder/net/efficientnet_lite.py +2 -2
- birder/net/efficientnet_v1.py +2 -2
- birder/net/efficientnet_v2.py +2 -2
- birder/net/efficientvim.py +3 -3
- birder/net/efficientvit_mit.py +2 -2
- birder/net/efficientvit_msft.py +2 -2
- birder/net/fasternet.py +2 -2
- birder/net/fastvit.py +2 -3
- birder/net/flexivit.py +11 -6
- birder/net/focalnet.py +2 -3
- birder/net/gc_vit.py +17 -2
- birder/net/ghostnet_v1.py +2 -2
- birder/net/ghostnet_v2.py +2 -2
- birder/net/groupmixformer.py +2 -2
- birder/net/hgnet_v1.py +2 -2
- birder/net/hgnet_v2.py +2 -2
- birder/net/hiera.py +2 -2
- birder/net/hieradet.py +2 -2
- birder/net/hornet.py +2 -2
- birder/net/iformer.py +2 -2
- birder/net/inception_next.py +2 -2
- birder/net/inception_resnet_v1.py +2 -2
- birder/net/inception_resnet_v2.py +2 -2
- birder/net/inception_v3.py +2 -2
- birder/net/inception_v4.py +2 -2
- birder/net/levit.py +4 -4
- birder/net/lit_v1.py +2 -2
- birder/net/lit_v1_tiny.py +2 -2
- birder/net/lit_v2.py +2 -2
- birder/net/maxvit.py +2 -2
- birder/net/metaformer.py +2 -2
- birder/net/mnasnet.py +2 -2
- birder/net/mobilenet_v1.py +2 -2
- birder/net/mobilenet_v2.py +2 -2
- birder/net/mobilenet_v3_large.py +2 -2
- birder/net/mobilenet_v4.py +2 -2
- birder/net/mobilenet_v4_hybrid.py +2 -2
- birder/net/mobileone.py +2 -2
- birder/net/mobilevit_v2.py +2 -2
- birder/net/moganet.py +2 -2
- birder/net/mvit_v2.py +2 -2
- birder/net/nextvit.py +2 -2
- birder/net/nfnet.py +2 -2
- birder/net/pit.py +6 -6
- birder/net/pvt_v1.py +2 -2
- birder/net/pvt_v2.py +2 -2
- birder/net/rdnet.py +2 -2
- birder/net/regionvit.py +6 -6
- birder/net/regnet.py +2 -2
- birder/net/regnet_z.py +2 -2
- birder/net/repghost.py +2 -2
- birder/net/repvgg.py +2 -2
- birder/net/repvit.py +6 -6
- birder/net/resnest.py +2 -2
- birder/net/resnet_v1.py +2 -2
- birder/net/resnet_v2.py +2 -2
- birder/net/resnext.py +2 -2
- birder/net/rope_deit3.py +3 -3
- birder/net/rope_flexivit.py +13 -6
- birder/net/rope_vit.py +69 -10
- birder/net/shufflenet_v1.py +2 -2
- birder/net/shufflenet_v2.py +2 -2
- birder/net/smt.py +1 -2
- birder/net/squeezenext.py +2 -2
- birder/net/ssl/byol.py +3 -2
- birder/net/ssl/capi.py +156 -11
- birder/net/ssl/data2vec.py +3 -1
- birder/net/ssl/data2vec2.py +3 -1
- birder/net/ssl/dino_v1.py +1 -1
- birder/net/ssl/dino_v2.py +140 -18
- birder/net/ssl/franca.py +145 -13
- birder/net/ssl/ibot.py +1 -2
- birder/net/ssl/mmcr.py +3 -1
- birder/net/starnet.py +2 -2
- birder/net/swiftformer.py +6 -6
- birder/net/swin_transformer_v1.py +2 -2
- birder/net/swin_transformer_v2.py +2 -2
- birder/net/tiny_vit.py +2 -2
- birder/net/transnext.py +1 -1
- birder/net/uniformer.py +1 -1
- birder/net/van.py +1 -1
- birder/net/vgg.py +1 -1
- birder/net/vgg_reduced.py +1 -1
- birder/net/vit.py +172 -8
- birder/net/vit_parallel.py +5 -5
- birder/net/vit_sam.py +3 -3
- birder/net/vovnet_v1.py +2 -2
- birder/net/vovnet_v2.py +2 -2
- birder/net/wide_resnet.py +2 -2
- birder/net/xception.py +2 -2
- birder/net/xcit.py +2 -2
- birder/results/detection.py +104 -0
- birder/results/gui.py +10 -8
- birder/scripts/benchmark.py +1 -1
- birder/scripts/train.py +13 -18
- birder/scripts/train_barlow_twins.py +10 -14
- birder/scripts/train_byol.py +11 -15
- birder/scripts/train_capi.py +38 -17
- birder/scripts/train_data2vec.py +11 -15
- birder/scripts/train_data2vec2.py +13 -17
- birder/scripts/train_detection.py +11 -14
- birder/scripts/train_dino_v1.py +20 -22
- birder/scripts/train_dino_v2.py +126 -63
- birder/scripts/train_dino_v2_dist.py +127 -64
- birder/scripts/train_franca.py +49 -34
- birder/scripts/train_i_jepa.py +11 -14
- birder/scripts/train_ibot.py +16 -18
- birder/scripts/train_kd.py +14 -20
- birder/scripts/train_mim.py +10 -13
- birder/scripts/train_mmcr.py +11 -15
- birder/scripts/train_rotnet.py +12 -16
- birder/scripts/train_simclr.py +10 -14
- birder/scripts/train_vicreg.py +10 -14
- birder/tools/avg_model.py +24 -8
- birder/tools/det_results.py +91 -0
- birder/tools/introspection.py +35 -9
- birder/tools/results.py +11 -7
- birder/tools/show_iterator.py +1 -1
- birder/version.py +1 -1
- {birder-0.3.0.dist-info → birder-0.3.2.dist-info}/METADATA +1 -1
- birder-0.3.2.dist-info/RECORD +299 -0
- birder-0.3.0.dist-info/RECORD +0 -298
- {birder-0.3.0.dist-info → birder-0.3.2.dist-info}/WHEEL +0 -0
- {birder-0.3.0.dist-info → birder-0.3.2.dist-info}/entry_points.txt +0 -0
- {birder-0.3.0.dist-info → birder-0.3.2.dist-info}/licenses/LICENSE +0 -0
- {birder-0.3.0.dist-info → birder-0.3.2.dist-info}/top_level.txt +0 -0
birder/net/hgnet_v1.py
CHANGED
|
@@ -387,14 +387,14 @@ class HGNet_v1(DetectorBackbone):
|
|
|
387
387
|
|
|
388
388
|
def freeze_stages(self, up_to_stage: int) -> None:
|
|
389
389
|
for param in self.stem.parameters():
|
|
390
|
-
param.
|
|
390
|
+
param.requires_grad_(False)
|
|
391
391
|
|
|
392
392
|
for idx, module in enumerate(self.body.children()):
|
|
393
393
|
if idx >= up_to_stage:
|
|
394
394
|
break
|
|
395
395
|
|
|
396
396
|
for param in module.parameters():
|
|
397
|
-
param.
|
|
397
|
+
param.requires_grad_(False)
|
|
398
398
|
|
|
399
399
|
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
|
|
400
400
|
x = self.stem(x)
|
birder/net/hgnet_v2.py
CHANGED
|
@@ -180,14 +180,14 @@ class HGNet_v2(DetectorBackbone):
|
|
|
180
180
|
|
|
181
181
|
def freeze_stages(self, up_to_stage: int) -> None:
|
|
182
182
|
for param in self.stem.parameters():
|
|
183
|
-
param.
|
|
183
|
+
param.requires_grad_(False)
|
|
184
184
|
|
|
185
185
|
for idx, module in enumerate(self.body.children()):
|
|
186
186
|
if idx >= up_to_stage:
|
|
187
187
|
break
|
|
188
188
|
|
|
189
189
|
for param in module.parameters():
|
|
190
|
-
param.
|
|
190
|
+
param.requires_grad_(False)
|
|
191
191
|
|
|
192
192
|
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
|
|
193
193
|
x = self.stem(x)
|
birder/net/hiera.py
CHANGED
|
@@ -515,14 +515,14 @@ class Hiera(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin):
|
|
|
515
515
|
|
|
516
516
|
def freeze_stages(self, up_to_stage: int) -> None:
|
|
517
517
|
for param in self.stem.parameters():
|
|
518
|
-
param.
|
|
518
|
+
param.requires_grad_(False)
|
|
519
519
|
|
|
520
520
|
for idx, module in enumerate(self.body.children()):
|
|
521
521
|
if idx >= up_to_stage:
|
|
522
522
|
break
|
|
523
523
|
|
|
524
524
|
for param in module.parameters():
|
|
525
|
-
param.
|
|
525
|
+
param.requires_grad_(False)
|
|
526
526
|
|
|
527
527
|
def masked_encoding_omission(
|
|
528
528
|
self,
|
birder/net/hieradet.py
CHANGED
|
@@ -312,14 +312,14 @@ class HieraDet(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
|
|
|
312
312
|
|
|
313
313
|
def freeze_stages(self, up_to_stage: int) -> None:
|
|
314
314
|
for param in self.stem.parameters():
|
|
315
|
-
param.
|
|
315
|
+
param.requires_grad_(False)
|
|
316
316
|
|
|
317
317
|
for idx, module in enumerate(self.body.children()):
|
|
318
318
|
if idx >= up_to_stage:
|
|
319
319
|
break
|
|
320
320
|
|
|
321
321
|
for param in module.parameters():
|
|
322
|
-
param.
|
|
322
|
+
param.requires_grad_(False)
|
|
323
323
|
|
|
324
324
|
def masked_encoding_retention(
|
|
325
325
|
self,
|
birder/net/hornet.py
CHANGED
|
@@ -299,14 +299,14 @@ class HorNet(DetectorBackbone):
|
|
|
299
299
|
|
|
300
300
|
def freeze_stages(self, up_to_stage: int) -> None:
|
|
301
301
|
for param in self.stem.parameters():
|
|
302
|
-
param.
|
|
302
|
+
param.requires_grad_(False)
|
|
303
303
|
|
|
304
304
|
for idx, module in enumerate(self.body.children()):
|
|
305
305
|
if idx >= up_to_stage:
|
|
306
306
|
break
|
|
307
307
|
|
|
308
308
|
for param in module.parameters():
|
|
309
|
-
param.
|
|
309
|
+
param.requires_grad_(False)
|
|
310
310
|
|
|
311
311
|
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
|
|
312
312
|
x = self.stem(x)
|
birder/net/iformer.py
CHANGED
|
@@ -424,14 +424,14 @@ class iFormer(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
|
|
|
424
424
|
|
|
425
425
|
def freeze_stages(self, up_to_stage: int) -> None:
|
|
426
426
|
for param in self.stem.parameters():
|
|
427
|
-
param.
|
|
427
|
+
param.requires_grad_(False)
|
|
428
428
|
|
|
429
429
|
for idx, module in enumerate(self.body.children()):
|
|
430
430
|
if idx >= up_to_stage:
|
|
431
431
|
break
|
|
432
432
|
|
|
433
433
|
for param in module.parameters():
|
|
434
|
-
param.
|
|
434
|
+
param.requires_grad_(False)
|
|
435
435
|
|
|
436
436
|
def masked_encoding_retention(
|
|
437
437
|
self,
|
birder/net/inception_next.py
CHANGED
|
@@ -261,14 +261,14 @@ class Inception_NeXt(DetectorBackbone):
|
|
|
261
261
|
|
|
262
262
|
def freeze_stages(self, up_to_stage: int) -> None:
|
|
263
263
|
for param in self.stem.parameters():
|
|
264
|
-
param.
|
|
264
|
+
param.requires_grad_(False)
|
|
265
265
|
|
|
266
266
|
for idx, module in enumerate(self.body.children()):
|
|
267
267
|
if idx >= up_to_stage:
|
|
268
268
|
break
|
|
269
269
|
|
|
270
270
|
for param in module.parameters():
|
|
271
|
-
param.
|
|
271
|
+
param.requires_grad_(False)
|
|
272
272
|
|
|
273
273
|
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
|
|
274
274
|
x = self.stem(x)
|
|
@@ -236,14 +236,14 @@ class Inception_ResNet_v1(DetectorBackbone):
|
|
|
236
236
|
|
|
237
237
|
def freeze_stages(self, up_to_stage: int) -> None:
|
|
238
238
|
for param in self.stem.parameters():
|
|
239
|
-
param.
|
|
239
|
+
param.requires_grad_(False)
|
|
240
240
|
|
|
241
241
|
for idx, module in enumerate(self.body.children()):
|
|
242
242
|
if idx >= up_to_stage:
|
|
243
243
|
break
|
|
244
244
|
|
|
245
245
|
for param in module.parameters():
|
|
246
|
-
param.
|
|
246
|
+
param.requires_grad_(False)
|
|
247
247
|
|
|
248
248
|
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
|
|
249
249
|
x = self.stem(x)
|
|
@@ -277,14 +277,14 @@ class Inception_ResNet_v2(DetectorBackbone):
|
|
|
277
277
|
|
|
278
278
|
def freeze_stages(self, up_to_stage: int) -> None:
|
|
279
279
|
for param in self.stem.parameters():
|
|
280
|
-
param.
|
|
280
|
+
param.requires_grad_(False)
|
|
281
281
|
|
|
282
282
|
for idx, module in enumerate(self.body.children()):
|
|
283
283
|
if idx >= up_to_stage:
|
|
284
284
|
break
|
|
285
285
|
|
|
286
286
|
for param in module.parameters():
|
|
287
|
-
param.
|
|
287
|
+
param.requires_grad_(False)
|
|
288
288
|
|
|
289
289
|
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
|
|
290
290
|
x = self.stem(x)
|
birder/net/inception_v3.py
CHANGED
|
@@ -277,14 +277,14 @@ class Inception_v3(DetectorBackbone):
|
|
|
277
277
|
|
|
278
278
|
def freeze_stages(self, up_to_stage: int) -> None:
|
|
279
279
|
for param in self.stem.parameters():
|
|
280
|
-
param.
|
|
280
|
+
param.requires_grad_(False)
|
|
281
281
|
|
|
282
282
|
for idx, module in enumerate(self.body.children()):
|
|
283
283
|
if idx >= up_to_stage:
|
|
284
284
|
break
|
|
285
285
|
|
|
286
286
|
for param in module.parameters():
|
|
287
|
-
param.
|
|
287
|
+
param.requires_grad_(False)
|
|
288
288
|
|
|
289
289
|
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
|
|
290
290
|
x = self.stem(x)
|
birder/net/inception_v4.py
CHANGED
|
@@ -306,14 +306,14 @@ class Inception_v4(DetectorBackbone):
|
|
|
306
306
|
|
|
307
307
|
def freeze_stages(self, up_to_stage: int) -> None:
|
|
308
308
|
for param in self.stem.parameters():
|
|
309
|
-
param.
|
|
309
|
+
param.requires_grad_(False)
|
|
310
310
|
|
|
311
311
|
for idx, module in enumerate(self.body.children()):
|
|
312
312
|
if idx >= up_to_stage:
|
|
313
313
|
break
|
|
314
314
|
|
|
315
315
|
for param in module.parameters():
|
|
316
|
-
param.
|
|
316
|
+
param.requires_grad_(False)
|
|
317
317
|
|
|
318
318
|
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
|
|
319
319
|
x = self.stem(x)
|
birder/net/levit.py
CHANGED
|
@@ -399,18 +399,18 @@ class LeViT(BaseNet):
|
|
|
399
399
|
|
|
400
400
|
def freeze(self, freeze_classifier: bool = True, unfreeze_features: bool = False) -> None:
|
|
401
401
|
for param in self.parameters():
|
|
402
|
-
param.
|
|
402
|
+
param.requires_grad_(False)
|
|
403
403
|
|
|
404
404
|
if freeze_classifier is False:
|
|
405
405
|
for param in self.classifier.parameters():
|
|
406
|
-
param.
|
|
406
|
+
param.requires_grad_(True)
|
|
407
407
|
|
|
408
408
|
for param in self.dist_classifier.parameters():
|
|
409
|
-
param.
|
|
409
|
+
param.requires_grad_(True)
|
|
410
410
|
|
|
411
411
|
if unfreeze_features is True:
|
|
412
412
|
for param in self.features.parameters():
|
|
413
|
-
param.
|
|
413
|
+
param.requires_grad_(True)
|
|
414
414
|
|
|
415
415
|
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
|
|
416
416
|
x = self.stem(x)
|
birder/net/lit_v1.py
CHANGED
|
@@ -375,14 +375,14 @@ class LIT_v1(DetectorBackbone):
|
|
|
375
375
|
|
|
376
376
|
def freeze_stages(self, up_to_stage: int) -> None:
|
|
377
377
|
for param in self.stem.parameters():
|
|
378
|
-
param.
|
|
378
|
+
param.requires_grad_(False)
|
|
379
379
|
|
|
380
380
|
for idx, stage in enumerate(self.body.values()):
|
|
381
381
|
if idx >= up_to_stage:
|
|
382
382
|
break
|
|
383
383
|
|
|
384
384
|
for param in stage.parameters():
|
|
385
|
-
param.
|
|
385
|
+
param.requires_grad_(False)
|
|
386
386
|
|
|
387
387
|
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
|
|
388
388
|
x = self.stem(x)
|
birder/net/lit_v1_tiny.py
CHANGED
|
@@ -265,14 +265,14 @@ class LIT_v1_Tiny(DetectorBackbone):
|
|
|
265
265
|
|
|
266
266
|
def freeze_stages(self, up_to_stage: int) -> None:
|
|
267
267
|
for param in self.stem.parameters():
|
|
268
|
-
param.
|
|
268
|
+
param.requires_grad_(False)
|
|
269
269
|
|
|
270
270
|
for idx, stage in enumerate(self.body.values()):
|
|
271
271
|
if idx >= up_to_stage:
|
|
272
272
|
break
|
|
273
273
|
|
|
274
274
|
for param in stage.parameters():
|
|
275
|
-
param.
|
|
275
|
+
param.requires_grad_(False)
|
|
276
276
|
|
|
277
277
|
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
|
|
278
278
|
x = self.stem(x)
|
birder/net/lit_v2.py
CHANGED
|
@@ -375,14 +375,14 @@ class LIT_v2(DetectorBackbone):
|
|
|
375
375
|
|
|
376
376
|
def freeze_stages(self, up_to_stage: int) -> None:
|
|
377
377
|
for param in self.stem.parameters():
|
|
378
|
-
param.
|
|
378
|
+
param.requires_grad_(False)
|
|
379
379
|
|
|
380
380
|
for idx, stage in enumerate(self.body.values()):
|
|
381
381
|
if idx >= up_to_stage:
|
|
382
382
|
break
|
|
383
383
|
|
|
384
384
|
for param in stage.parameters():
|
|
385
|
-
param.
|
|
385
|
+
param.requires_grad_(False)
|
|
386
386
|
|
|
387
387
|
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
|
|
388
388
|
x = self.stem(x)
|
birder/net/maxvit.py
CHANGED
|
@@ -589,14 +589,14 @@ class MaxViT(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
|
|
|
589
589
|
|
|
590
590
|
def freeze_stages(self, up_to_stage: int) -> None:
|
|
591
591
|
for param in self.stem.parameters():
|
|
592
|
-
param.
|
|
592
|
+
param.requires_grad_(False)
|
|
593
593
|
|
|
594
594
|
for idx, module in enumerate(self.body.children()):
|
|
595
595
|
if idx >= up_to_stage:
|
|
596
596
|
break
|
|
597
597
|
|
|
598
598
|
for param in module.parameters():
|
|
599
|
-
param.
|
|
599
|
+
param.requires_grad_(False)
|
|
600
600
|
|
|
601
601
|
def masked_encoding_retention(
|
|
602
602
|
self,
|
birder/net/metaformer.py
CHANGED
|
@@ -449,14 +449,14 @@ class MetaFormer(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
|
|
|
449
449
|
|
|
450
450
|
def freeze_stages(self, up_to_stage: int) -> None:
|
|
451
451
|
for param in self.stem.parameters():
|
|
452
|
-
param.
|
|
452
|
+
param.requires_grad_(False)
|
|
453
453
|
|
|
454
454
|
for idx, module in enumerate(self.body.children()):
|
|
455
455
|
if idx >= up_to_stage:
|
|
456
456
|
break
|
|
457
457
|
|
|
458
458
|
for param in module.parameters():
|
|
459
|
-
param.
|
|
459
|
+
param.requires_grad_(False)
|
|
460
460
|
|
|
461
461
|
def masked_encoding_retention(
|
|
462
462
|
self,
|
birder/net/mnasnet.py
CHANGED
|
@@ -251,14 +251,14 @@ class MNASNet(DetectorBackbone):
|
|
|
251
251
|
|
|
252
252
|
def freeze_stages(self, up_to_stage: int) -> None:
|
|
253
253
|
for param in self.stem.parameters():
|
|
254
|
-
param.
|
|
254
|
+
param.requires_grad_(False)
|
|
255
255
|
|
|
256
256
|
for idx, module in enumerate(self.body.children()):
|
|
257
257
|
if idx >= up_to_stage:
|
|
258
258
|
break
|
|
259
259
|
|
|
260
260
|
for param in module.parameters():
|
|
261
|
-
param.
|
|
261
|
+
param.requires_grad_(False)
|
|
262
262
|
|
|
263
263
|
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
|
|
264
264
|
x = self.stem(x)
|
birder/net/mobilenet_v1.py
CHANGED
|
@@ -136,14 +136,14 @@ class MobileNet_v1(DetectorBackbone):
|
|
|
136
136
|
|
|
137
137
|
def freeze_stages(self, up_to_stage: int) -> None:
|
|
138
138
|
for param in self.stem.parameters():
|
|
139
|
-
param.
|
|
139
|
+
param.requires_grad_(False)
|
|
140
140
|
|
|
141
141
|
for idx, module in enumerate(self.body.children()):
|
|
142
142
|
if idx >= up_to_stage:
|
|
143
143
|
break
|
|
144
144
|
|
|
145
145
|
for param in module.parameters():
|
|
146
|
-
param.
|
|
146
|
+
param.requires_grad_(False)
|
|
147
147
|
|
|
148
148
|
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
|
|
149
149
|
x = self.stem(x)
|
birder/net/mobilenet_v2.py
CHANGED
|
@@ -204,14 +204,14 @@ class MobileNet_v2(DetectorBackbone):
|
|
|
204
204
|
|
|
205
205
|
def freeze_stages(self, up_to_stage: int) -> None:
|
|
206
206
|
for param in self.stem.parameters():
|
|
207
|
-
param.
|
|
207
|
+
param.requires_grad_(False)
|
|
208
208
|
|
|
209
209
|
for idx, module in enumerate(self.body.children()):
|
|
210
210
|
if idx >= up_to_stage:
|
|
211
211
|
break
|
|
212
212
|
|
|
213
213
|
for param in module.parameters():
|
|
214
|
-
param.
|
|
214
|
+
param.requires_grad_(False)
|
|
215
215
|
|
|
216
216
|
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
|
|
217
217
|
x = self.stem(x)
|
birder/net/mobilenet_v3_large.py
CHANGED
|
@@ -236,14 +236,14 @@ class MobileNet_v3_Large(DetectorBackbone):
|
|
|
236
236
|
|
|
237
237
|
def freeze_stages(self, up_to_stage: int) -> None:
|
|
238
238
|
for param in self.stem.parameters():
|
|
239
|
-
param.
|
|
239
|
+
param.requires_grad_(False)
|
|
240
240
|
|
|
241
241
|
for idx, module in enumerate(self.body.children()):
|
|
242
242
|
if idx >= up_to_stage:
|
|
243
243
|
break
|
|
244
244
|
|
|
245
245
|
for param in module.parameters():
|
|
246
|
-
param.
|
|
246
|
+
param.requires_grad_(False)
|
|
247
247
|
|
|
248
248
|
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
|
|
249
249
|
x = self.stem(x)
|
birder/net/mobilenet_v4.py
CHANGED
|
@@ -493,14 +493,14 @@ class MobileNet_v4(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin)
|
|
|
493
493
|
|
|
494
494
|
def freeze_stages(self, up_to_stage: int) -> None:
|
|
495
495
|
for param in self.stem.parameters():
|
|
496
|
-
param.
|
|
496
|
+
param.requires_grad_(False)
|
|
497
497
|
|
|
498
498
|
for idx, module in enumerate(self.body.children()):
|
|
499
499
|
if idx >= up_to_stage:
|
|
500
500
|
break
|
|
501
501
|
|
|
502
502
|
for param in module.parameters():
|
|
503
|
-
param.
|
|
503
|
+
param.requires_grad_(False)
|
|
504
504
|
|
|
505
505
|
def masked_encoding_retention(
|
|
506
506
|
self,
|
|
@@ -439,14 +439,14 @@ class MobileNet_v4_Hybrid(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentio
|
|
|
439
439
|
|
|
440
440
|
def freeze_stages(self, up_to_stage: int) -> None:
|
|
441
441
|
for param in self.stem.parameters():
|
|
442
|
-
param.
|
|
442
|
+
param.requires_grad_(False)
|
|
443
443
|
|
|
444
444
|
for idx, module in enumerate(self.body.children()):
|
|
445
445
|
if idx >= up_to_stage:
|
|
446
446
|
break
|
|
447
447
|
|
|
448
448
|
for param in module.parameters():
|
|
449
|
-
param.
|
|
449
|
+
param.requires_grad_(False)
|
|
450
450
|
|
|
451
451
|
def masked_encoding_retention(
|
|
452
452
|
self,
|
birder/net/mobileone.py
CHANGED
|
@@ -363,14 +363,14 @@ class MobileOne(DetectorBackbone):
|
|
|
363
363
|
|
|
364
364
|
def freeze_stages(self, up_to_stage: int) -> None:
|
|
365
365
|
for param in self.stem.parameters():
|
|
366
|
-
param.
|
|
366
|
+
param.requires_grad_(False)
|
|
367
367
|
|
|
368
368
|
for idx, module in enumerate(self.body.children()):
|
|
369
369
|
if idx >= up_to_stage:
|
|
370
370
|
break
|
|
371
371
|
|
|
372
372
|
for param in module.parameters():
|
|
373
|
-
param.
|
|
373
|
+
param.requires_grad_(False)
|
|
374
374
|
|
|
375
375
|
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
|
|
376
376
|
x = self.stem(x)
|
birder/net/mobilevit_v2.py
CHANGED
|
@@ -323,14 +323,14 @@ class MobileViT_v2(DetectorBackbone):
|
|
|
323
323
|
|
|
324
324
|
def freeze_stages(self, up_to_stage: int) -> None:
|
|
325
325
|
for param in self.stem.parameters():
|
|
326
|
-
param.
|
|
326
|
+
param.requires_grad_(False)
|
|
327
327
|
|
|
328
328
|
for idx, module in enumerate(self.body.children()):
|
|
329
329
|
if idx >= up_to_stage:
|
|
330
330
|
break
|
|
331
331
|
|
|
332
332
|
for param in module.parameters():
|
|
333
|
-
param.
|
|
333
|
+
param.requires_grad_(False)
|
|
334
334
|
|
|
335
335
|
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
|
|
336
336
|
x = self.stem(x)
|
birder/net/moganet.py
CHANGED
|
@@ -330,14 +330,14 @@ class MogaNet(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
|
|
|
330
330
|
|
|
331
331
|
def freeze_stages(self, up_to_stage: int) -> None:
|
|
332
332
|
for param in self.stem.parameters():
|
|
333
|
-
param.
|
|
333
|
+
param.requires_grad_(False)
|
|
334
334
|
|
|
335
335
|
for idx, module in enumerate(self.body.children()):
|
|
336
336
|
if idx >= up_to_stage:
|
|
337
337
|
break
|
|
338
338
|
|
|
339
339
|
for param in module.parameters():
|
|
340
|
-
param.
|
|
340
|
+
param.requires_grad_(False)
|
|
341
341
|
|
|
342
342
|
def masked_encoding_retention(
|
|
343
343
|
self,
|
birder/net/mvit_v2.py
CHANGED
|
@@ -543,14 +543,14 @@ class MViT_v2(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
|
|
|
543
543
|
|
|
544
544
|
def freeze_stages(self, up_to_stage: int) -> None:
|
|
545
545
|
for param in self.patch_embed.parameters():
|
|
546
|
-
param.
|
|
546
|
+
param.requires_grad_(False)
|
|
547
547
|
|
|
548
548
|
for idx, module in enumerate(self.body.children()):
|
|
549
549
|
if idx >= up_to_stage:
|
|
550
550
|
break
|
|
551
551
|
|
|
552
552
|
for param in module.parameters():
|
|
553
|
-
param.
|
|
553
|
+
param.requires_grad_(False)
|
|
554
554
|
|
|
555
555
|
def masked_encoding_retention(
|
|
556
556
|
self,
|
birder/net/nextvit.py
CHANGED
|
@@ -381,14 +381,14 @@ class NextViT(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
|
|
|
381
381
|
|
|
382
382
|
def freeze_stages(self, up_to_stage: int) -> None:
|
|
383
383
|
for param in self.stem.parameters():
|
|
384
|
-
param.
|
|
384
|
+
param.requires_grad_(False)
|
|
385
385
|
|
|
386
386
|
for idx, module in enumerate(self.body.children()):
|
|
387
387
|
if idx >= up_to_stage:
|
|
388
388
|
break
|
|
389
389
|
|
|
390
390
|
for param in module.parameters():
|
|
391
|
-
param.
|
|
391
|
+
param.requires_grad_(False)
|
|
392
392
|
|
|
393
393
|
def masked_encoding_retention(
|
|
394
394
|
self,
|
birder/net/nfnet.py
CHANGED
|
@@ -294,14 +294,14 @@ class NFNet(DetectorBackbone):
|
|
|
294
294
|
|
|
295
295
|
def freeze_stages(self, up_to_stage: int) -> None:
|
|
296
296
|
for param in self.stem.parameters():
|
|
297
|
-
param.
|
|
297
|
+
param.requires_grad_(False)
|
|
298
298
|
|
|
299
299
|
for idx, module in enumerate(self.body.children()):
|
|
300
300
|
if idx >= up_to_stage:
|
|
301
301
|
break
|
|
302
302
|
|
|
303
303
|
for param in module.parameters():
|
|
304
|
-
param.
|
|
304
|
+
param.requires_grad_(False)
|
|
305
305
|
|
|
306
306
|
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
|
|
307
307
|
x = self.stem(x)
|
birder/net/pit.py
CHANGED
|
@@ -172,18 +172,18 @@ class PiT(DetectorBackbone):
|
|
|
172
172
|
|
|
173
173
|
def freeze(self, freeze_classifier: bool = True, unfreeze_features: bool = False) -> None:
|
|
174
174
|
for param in self.parameters():
|
|
175
|
-
param.
|
|
175
|
+
param.requires_grad_(False)
|
|
176
176
|
|
|
177
177
|
if freeze_classifier is False:
|
|
178
178
|
for param in self.classifier.parameters():
|
|
179
|
-
param.
|
|
179
|
+
param.requires_grad_(True)
|
|
180
180
|
|
|
181
181
|
for param in self.dist_classifier.parameters():
|
|
182
|
-
param.
|
|
182
|
+
param.requires_grad_(True)
|
|
183
183
|
|
|
184
184
|
if unfreeze_features is True:
|
|
185
185
|
for param in self.norm.parameters():
|
|
186
|
-
param.
|
|
186
|
+
param.requires_grad_(True)
|
|
187
187
|
|
|
188
188
|
def transform_to_backbone(self) -> None:
|
|
189
189
|
self.norm = nn.Identity()
|
|
@@ -205,14 +205,14 @@ class PiT(DetectorBackbone):
|
|
|
205
205
|
|
|
206
206
|
def freeze_stages(self, up_to_stage: int) -> None:
|
|
207
207
|
for param in self.stem.parameters():
|
|
208
|
-
param.
|
|
208
|
+
param.requires_grad_(False)
|
|
209
209
|
|
|
210
210
|
for idx, module in enumerate(self.body.children()):
|
|
211
211
|
if idx >= up_to_stage:
|
|
212
212
|
break
|
|
213
213
|
|
|
214
214
|
for param in module.parameters():
|
|
215
|
-
param.
|
|
215
|
+
param.requires_grad_(False)
|
|
216
216
|
|
|
217
217
|
def forward_features(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
|
218
218
|
x = self.stem(x)
|
birder/net/pvt_v1.py
CHANGED
|
@@ -277,14 +277,14 @@ class PVT_v1(DetectorBackbone):
|
|
|
277
277
|
|
|
278
278
|
def freeze_stages(self, up_to_stage: int) -> None:
|
|
279
279
|
for param in self.patch_embed.parameters():
|
|
280
|
-
param.
|
|
280
|
+
param.requires_grad_(False)
|
|
281
281
|
|
|
282
282
|
for idx, module in enumerate(self.body.children()):
|
|
283
283
|
if idx >= up_to_stage:
|
|
284
284
|
break
|
|
285
285
|
|
|
286
286
|
for param in module.parameters():
|
|
287
|
-
param.
|
|
287
|
+
param.requires_grad_(False)
|
|
288
288
|
|
|
289
289
|
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
|
|
290
290
|
x = self.patch_embed(x)
|
birder/net/pvt_v2.py
CHANGED
|
@@ -336,14 +336,14 @@ class PVT_v2(DetectorBackbone):
|
|
|
336
336
|
|
|
337
337
|
def freeze_stages(self, up_to_stage: int) -> None:
|
|
338
338
|
for param in self.patch_embed.parameters():
|
|
339
|
-
param.
|
|
339
|
+
param.requires_grad_(False)
|
|
340
340
|
|
|
341
341
|
for idx, module in enumerate(self.body.children()):
|
|
342
342
|
if idx >= up_to_stage:
|
|
343
343
|
break
|
|
344
344
|
|
|
345
345
|
for param in module.parameters():
|
|
346
|
-
param.
|
|
346
|
+
param.requires_grad_(False)
|
|
347
347
|
|
|
348
348
|
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
|
|
349
349
|
x = self.patch_embed(x)
|
birder/net/rdnet.py
CHANGED
|
@@ -247,14 +247,14 @@ class RDNet(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
|
|
|
247
247
|
|
|
248
248
|
def freeze_stages(self, up_to_stage: int) -> None:
|
|
249
249
|
for param in self.stem.parameters():
|
|
250
|
-
param.
|
|
250
|
+
param.requires_grad_(False)
|
|
251
251
|
|
|
252
252
|
for idx, module in enumerate(self.body.children()):
|
|
253
253
|
if idx >= up_to_stage:
|
|
254
254
|
break
|
|
255
255
|
|
|
256
256
|
for param in module.parameters():
|
|
257
|
-
param.
|
|
257
|
+
param.requires_grad_(False)
|
|
258
258
|
|
|
259
259
|
def masked_encoding_retention(
|
|
260
260
|
self,
|