birder 0.2.3__tar.gz → 0.3.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {birder-0.2.3 → birder-0.3.1}/PKG-INFO +2 -1
- {birder-0.2.3 → birder-0.3.1}/birder/common/fs_ops.py +2 -2
- {birder-0.2.3 → birder-0.3.1}/birder/common/training_cli.py +12 -1
- {birder-0.2.3 → birder-0.3.1}/birder/common/training_utils.py +219 -33
- {birder-0.2.3 → birder-0.3.1}/birder/data/collators/detection.py +1 -0
- {birder-0.2.3 → birder-0.3.1}/birder/data/dataloader/webdataset.py +12 -2
- {birder-0.2.3 → birder-0.3.1}/birder/kernels/load_kernel.py +16 -11
- {birder-0.2.3 → birder-0.3.1}/birder/kernels/soft_nms/soft_nms.cpp +17 -18
- {birder-0.2.3 → birder-0.3.1}/birder/net/base.py +3 -3
- {birder-0.2.3 → birder-0.3.1}/birder/net/biformer.py +2 -2
- {birder-0.2.3 → birder-0.3.1}/birder/net/cait.py +4 -3
- {birder-0.2.3 → birder-0.3.1}/birder/net/cas_vit.py +6 -6
- {birder-0.2.3 → birder-0.3.1}/birder/net/coat.py +8 -8
- {birder-0.2.3 → birder-0.3.1}/birder/net/conv2former.py +2 -2
- {birder-0.2.3 → birder-0.3.1}/birder/net/convnext_v1.py +7 -2
- {birder-0.2.3 → birder-0.3.1}/birder/net/convnext_v2.py +2 -2
- {birder-0.2.3 → birder-0.3.1}/birder/net/crossformer.py +35 -32
- {birder-0.2.3 → birder-0.3.1}/birder/net/crossvit.py +4 -3
- {birder-0.2.3 → birder-0.3.1}/birder/net/cspnet.py +2 -2
- {birder-0.2.3 → birder-0.3.1}/birder/net/cswin_transformer.py +2 -2
- {birder-0.2.3 → birder-0.3.1}/birder/net/darknet.py +2 -2
- {birder-0.2.3 → birder-0.3.1}/birder/net/davit.py +2 -2
- {birder-0.2.3 → birder-0.3.1}/birder/net/deit.py +6 -6
- {birder-0.2.3 → birder-0.3.1}/birder/net/deit3.py +6 -6
- {birder-0.2.3 → birder-0.3.1}/birder/net/densenet.py +2 -2
- {birder-0.2.3 → birder-0.3.1}/birder/net/detection/deformable_detr.py +4 -7
- {birder-0.2.3 → birder-0.3.1}/birder/net/detection/detr.py +4 -7
- {birder-0.2.3 → birder-0.3.1}/birder/net/detection/efficientdet.py +4 -9
- {birder-0.2.3 → birder-0.3.1}/birder/net/detection/faster_rcnn.py +2 -2
- {birder-0.2.3 → birder-0.3.1}/birder/net/detection/fcos.py +4 -9
- {birder-0.2.3 → birder-0.3.1}/birder/net/detection/retinanet.py +4 -9
- {birder-0.2.3 → birder-0.3.1}/birder/net/detection/rt_detr_v1.py +5 -4
- {birder-0.2.3 → birder-0.3.1}/birder/net/detection/ssd.py +2 -2
- {birder-0.2.3 → birder-0.3.1}/birder/net/detection/ssdlite.py +2 -2
- {birder-0.2.3 → birder-0.3.1}/birder/net/detection/yolo_v2.py +2 -2
- {birder-0.2.3 → birder-0.3.1}/birder/net/detection/yolo_v3.py +2 -2
- {birder-0.2.3 → birder-0.3.1}/birder/net/detection/yolo_v4.py +2 -2
- {birder-0.2.3 → birder-0.3.1}/birder/net/edgenext.py +2 -2
- {birder-0.2.3 → birder-0.3.1}/birder/net/edgevit.py +1 -1
- {birder-0.2.3 → birder-0.3.1}/birder/net/efficientformer_v1.py +19 -13
- {birder-0.2.3 → birder-0.3.1}/birder/net/efficientformer_v2.py +45 -35
- {birder-0.2.3 → birder-0.3.1}/birder/net/efficientnet_lite.py +2 -2
- {birder-0.2.3 → birder-0.3.1}/birder/net/efficientnet_v1.py +2 -2
- {birder-0.2.3 → birder-0.3.1}/birder/net/efficientnet_v2.py +2 -2
- {birder-0.2.3 → birder-0.3.1}/birder/net/efficientvim.py +3 -3
- {birder-0.2.3 → birder-0.3.1}/birder/net/efficientvit_mit.py +2 -2
- {birder-0.2.3 → birder-0.3.1}/birder/net/efficientvit_msft.py +11 -9
- {birder-0.2.3 → birder-0.3.1}/birder/net/fasternet.py +2 -2
- {birder-0.2.3 → birder-0.3.1}/birder/net/fastvit.py +3 -2
- {birder-0.2.3 → birder-0.3.1}/birder/net/flexivit.py +11 -10
- {birder-0.2.3 → birder-0.3.1}/birder/net/focalnet.py +2 -2
- {birder-0.2.3 → birder-0.3.1}/birder/net/gc_vit.py +17 -2
- {birder-0.2.3 → birder-0.3.1}/birder/net/ghostnet_v1.py +2 -2
- {birder-0.2.3 → birder-0.3.1}/birder/net/ghostnet_v2.py +2 -2
- {birder-0.2.3 → birder-0.3.1}/birder/net/groupmixformer.py +2 -2
- {birder-0.2.3 → birder-0.3.1}/birder/net/hgnet_v1.py +2 -2
- {birder-0.2.3 → birder-0.3.1}/birder/net/hgnet_v2.py +2 -2
- {birder-0.2.3 → birder-0.3.1}/birder/net/hiera.py +14 -11
- {birder-0.2.3 → birder-0.3.1}/birder/net/hieradet.py +2 -2
- {birder-0.2.3 → birder-0.3.1}/birder/net/hornet.py +11 -9
- {birder-0.2.3 → birder-0.3.1}/birder/net/iformer.py +10 -8
- {birder-0.2.3 → birder-0.3.1}/birder/net/inception_next.py +2 -2
- {birder-0.2.3 → birder-0.3.1}/birder/net/inception_resnet_v1.py +2 -2
- {birder-0.2.3 → birder-0.3.1}/birder/net/inception_resnet_v2.py +2 -2
- {birder-0.2.3 → birder-0.3.1}/birder/net/inception_v3.py +2 -2
- {birder-0.2.3 → birder-0.3.1}/birder/net/inception_v4.py +2 -2
- {birder-0.2.3 → birder-0.3.1}/birder/net/levit.py +46 -34
- {birder-0.2.3 → birder-0.3.1}/birder/net/lit_v1.py +2 -2
- {birder-0.2.3 → birder-0.3.1}/birder/net/lit_v1_tiny.py +17 -2
- {birder-0.2.3 → birder-0.3.1}/birder/net/lit_v2.py +2 -2
- {birder-0.2.3 → birder-0.3.1}/birder/net/maxvit.py +69 -57
- {birder-0.2.3 → birder-0.3.1}/birder/net/metaformer.py +2 -2
- {birder-0.2.3 → birder-0.3.1}/birder/net/mnasnet.py +2 -2
- {birder-0.2.3 → birder-0.3.1}/birder/net/mobilenet_v1.py +2 -2
- {birder-0.2.3 → birder-0.3.1}/birder/net/mobilenet_v2.py +2 -2
- {birder-0.2.3 → birder-0.3.1}/birder/net/mobilenet_v3_large.py +2 -2
- {birder-0.2.3 → birder-0.3.1}/birder/net/mobilenet_v4.py +2 -2
- {birder-0.2.3 → birder-0.3.1}/birder/net/mobilenet_v4_hybrid.py +2 -2
- {birder-0.2.3 → birder-0.3.1}/birder/net/mobileone.py +3 -2
- {birder-0.2.3 → birder-0.3.1}/birder/net/mobilevit_v2.py +2 -2
- {birder-0.2.3 → birder-0.3.1}/birder/net/moganet.py +2 -2
- {birder-0.2.3 → birder-0.3.1}/birder/net/mvit_v2.py +15 -14
- {birder-0.2.3 → birder-0.3.1}/birder/net/nextvit.py +2 -2
- {birder-0.2.3 → birder-0.3.1}/birder/net/nfnet.py +2 -2
- {birder-0.2.3 → birder-0.3.1}/birder/net/pit.py +10 -9
- {birder-0.2.3 → birder-0.3.1}/birder/net/pvt_v1.py +6 -3
- {birder-0.2.3 → birder-0.3.1}/birder/net/pvt_v2.py +2 -2
- {birder-0.2.3 → birder-0.3.1}/birder/net/rdnet.py +2 -2
- {birder-0.2.3 → birder-0.3.1}/birder/net/regionvit.py +6 -6
- {birder-0.2.3 → birder-0.3.1}/birder/net/regnet.py +2 -2
- {birder-0.2.3 → birder-0.3.1}/birder/net/regnet_z.py +2 -2
- {birder-0.2.3 → birder-0.3.1}/birder/net/repghost.py +3 -2
- {birder-0.2.3 → birder-0.3.1}/birder/net/repvgg.py +3 -2
- {birder-0.2.3 → birder-0.3.1}/birder/net/repvit.py +7 -6
- {birder-0.2.3 → birder-0.3.1}/birder/net/resnest.py +2 -2
- {birder-0.2.3 → birder-0.3.1}/birder/net/resnet_v1.py +2 -2
- {birder-0.2.3 → birder-0.3.1}/birder/net/resnet_v2.py +2 -2
- {birder-0.2.3 → birder-0.3.1}/birder/net/resnext.py +2 -2
- {birder-0.2.3 → birder-0.3.1}/birder/net/rope_deit3.py +8 -6
- {birder-0.2.3 → birder-0.3.1}/birder/net/rope_flexivit.py +13 -10
- {birder-0.2.3 → birder-0.3.1}/birder/net/rope_vit.py +30 -11
- {birder-0.2.3 → birder-0.3.1}/birder/net/shufflenet_v1.py +2 -2
- {birder-0.2.3 → birder-0.3.1}/birder/net/shufflenet_v2.py +2 -2
- {birder-0.2.3 → birder-0.3.1}/birder/net/simple_vit.py +9 -6
- {birder-0.2.3 → birder-0.3.1}/birder/net/smt.py +1 -1
- {birder-0.2.3 → birder-0.3.1}/birder/net/squeezenext.py +2 -2
- {birder-0.2.3 → birder-0.3.1}/birder/net/ssl/byol.py +3 -2
- {birder-0.2.3 → birder-0.3.1}/birder/net/ssl/capi.py +156 -11
- {birder-0.2.3 → birder-0.3.1}/birder/net/ssl/data2vec.py +3 -1
- {birder-0.2.3 → birder-0.3.1}/birder/net/ssl/data2vec2.py +3 -1
- {birder-0.2.3 → birder-0.3.1}/birder/net/ssl/dino_v1.py +1 -1
- {birder-0.2.3 → birder-0.3.1}/birder/net/ssl/dino_v2.py +140 -18
- {birder-0.2.3 → birder-0.3.1}/birder/net/ssl/franca.py +145 -13
- {birder-0.2.3 → birder-0.3.1}/birder/net/ssl/ibot.py +1 -1
- {birder-0.2.3 → birder-0.3.1}/birder/net/ssl/mmcr.py +3 -1
- {birder-0.2.3 → birder-0.3.1}/birder/net/starnet.py +2 -2
- {birder-0.2.3 → birder-0.3.1}/birder/net/swiftformer.py +6 -6
- {birder-0.2.3 → birder-0.3.1}/birder/net/swin_transformer_v1.py +73 -70
- {birder-0.2.3 → birder-0.3.1}/birder/net/swin_transformer_v2.py +40 -33
- {birder-0.2.3 → birder-0.3.1}/birder/net/tiny_vit.py +22 -12
- {birder-0.2.3 → birder-0.3.1}/birder/net/transnext.py +39 -29
- {birder-0.2.3 → birder-0.3.1}/birder/net/uniformer.py +1 -1
- {birder-0.2.3 → birder-0.3.1}/birder/net/van.py +1 -1
- {birder-0.2.3 → birder-0.3.1}/birder/net/vgg.py +1 -1
- {birder-0.2.3 → birder-0.3.1}/birder/net/vgg_reduced.py +1 -1
- {birder-0.2.3 → birder-0.3.1}/birder/net/vit.py +11 -10
- {birder-0.2.3 → birder-0.3.1}/birder/net/vit_parallel.py +10 -9
- {birder-0.2.3 → birder-0.3.1}/birder/net/vit_sam.py +41 -40
- {birder-0.2.3 → birder-0.3.1}/birder/net/vovnet_v1.py +17 -2
- {birder-0.2.3 → birder-0.3.1}/birder/net/vovnet_v2.py +2 -2
- {birder-0.2.3 → birder-0.3.1}/birder/net/wide_resnet.py +2 -2
- {birder-0.2.3 → birder-0.3.1}/birder/net/xception.py +2 -2
- {birder-0.2.3 → birder-0.3.1}/birder/net/xcit.py +2 -2
- birder-0.3.1/birder/ops/msda.py +203 -0
- birder-0.3.1/birder/ops/swattention.py +288 -0
- {birder-0.2.3 → birder-0.3.1}/birder/results/detection.py +108 -0
- {birder-0.2.3 → birder-0.3.1}/birder/results/gui.py +10 -8
- {birder-0.2.3 → birder-0.3.1}/birder/scripts/benchmark.py +22 -13
- {birder-0.2.3 → birder-0.3.1}/birder/scripts/predict.py +7 -0
- {birder-0.2.3 → birder-0.3.1}/birder/scripts/train.py +44 -24
- {birder-0.2.3 → birder-0.3.1}/birder/scripts/train_barlow_twins.py +41 -23
- {birder-0.2.3 → birder-0.3.1}/birder/scripts/train_byol.py +42 -24
- {birder-0.2.3 → birder-0.3.1}/birder/scripts/train_capi.py +72 -26
- {birder-0.2.3 → birder-0.3.1}/birder/scripts/train_data2vec.py +44 -26
- {birder-0.2.3 → birder-0.3.1}/birder/scripts/train_data2vec2.py +46 -28
- {birder-0.2.3 → birder-0.3.1}/birder/scripts/train_detection.py +40 -20
- {birder-0.2.3 → birder-0.3.1}/birder/scripts/train_dino_v1.py +65 -31
- {birder-0.2.3 → birder-0.3.1}/birder/scripts/train_dino_v2.py +133 -60
- {birder-0.2.3 → birder-0.3.1}/birder/scripts/train_dino_v2_dist.py +131 -58
- {birder-0.2.3 → birder-0.3.1}/birder/scripts/train_franca.py +84 -46
- {birder-0.2.3 → birder-0.3.1}/birder/scripts/train_i_jepa.py +44 -25
- {birder-0.2.3 → birder-0.3.1}/birder/scripts/train_ibot.py +53 -33
- {birder-0.2.3 → birder-0.3.1}/birder/scripts/train_kd.py +44 -24
- {birder-0.2.3 → birder-0.3.1}/birder/scripts/train_mim.py +41 -22
- {birder-0.2.3 → birder-0.3.1}/birder/scripts/train_mmcr.py +42 -24
- {birder-0.2.3 → birder-0.3.1}/birder/scripts/train_rotnet.py +42 -24
- {birder-0.2.3 → birder-0.3.1}/birder/scripts/train_simclr.py +41 -23
- {birder-0.2.3 → birder-0.3.1}/birder/scripts/train_vicreg.py +41 -23
- {birder-0.2.3 → birder-0.3.1}/birder/tools/convert_model.py +18 -15
- birder-0.3.1/birder/tools/det_results.py +264 -0
- birder-0.3.1/birder/tools/quantize_model.py +162 -0
- {birder-0.2.3 → birder-0.3.1}/birder/tools/results.py +11 -7
- birder-0.3.1/birder/version.py +1 -0
- {birder-0.2.3 → birder-0.3.1}/birder.egg-info/PKG-INFO +2 -1
- {birder-0.2.3 → birder-0.3.1}/birder.egg-info/SOURCES.txt +1 -0
- {birder-0.2.3 → birder-0.3.1}/birder.egg-info/requires.txt +1 -0
- {birder-0.2.3 → birder-0.3.1}/requirements/_requirements-dev.txt +2 -0
- {birder-0.2.3 → birder-0.3.1}/tests/test_common.py +73 -2
- birder-0.3.1/tests/test_dataloaders.py +101 -0
- {birder-0.2.3 → birder-0.3.1}/tests/test_inference.py +2 -2
- {birder-0.2.3 → birder-0.3.1}/tests/test_net.py +41 -2
- birder-0.3.1/tests/test_net_detection.py +248 -0
- {birder-0.2.3 → birder-0.3.1}/tests/test_net_ssl.py +594 -6
- {birder-0.2.3 → birder-0.3.1}/tests/test_results.py +173 -0
- birder-0.2.3/birder/ops/msda.py +0 -138
- birder-0.2.3/birder/ops/swattention.py +0 -225
- birder-0.2.3/birder/tools/det_results.py +0 -61
- birder-0.2.3/birder/tools/quantize_model.py +0 -156
- birder-0.2.3/birder/version.py +0 -1
- birder-0.2.3/tests/test_net_detection.py +0 -170
- {birder-0.2.3 → birder-0.3.1}/LICENSE +0 -0
- {birder-0.2.3 → birder-0.3.1}/README.md +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/__init__.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/adversarial/__init__.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/adversarial/base.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/adversarial/deepfool.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/adversarial/fgsm.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/adversarial/pgd.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/adversarial/simba.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/common/__init__.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/common/cli.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/common/lib.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/common/masking.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/conf/__init__.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/conf/settings.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/data/__init__.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/data/collators/__init__.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/data/dataloader/__init__.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/data/datasets/__init__.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/data/datasets/coco.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/data/datasets/directory.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/data/datasets/fake.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/data/datasets/webdataset.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/data/transforms/__init__.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/data/transforms/classification.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/data/transforms/detection.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/data/transforms/mosaic.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/datahub/__init__.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/datahub/_lib.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/datahub/classification.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/inference/__init__.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/inference/classification.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/inference/data_parallel.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/inference/detection.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/inference/wbf.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/introspection/__init__.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/introspection/attention_rollout.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/introspection/base.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/introspection/gradcam.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/introspection/guided_backprop.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/introspection/transformer_attribution.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/kernels/__init__.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/kernels/deformable_detr/cpu/ms_deform_attn_cpu.cpp +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/kernels/deformable_detr/cpu/ms_deform_attn_cpu.h +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/kernels/deformable_detr/cuda/ms_deform_attn_cuda.cu +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/kernels/deformable_detr/cuda/ms_deform_attn_cuda.h +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/kernels/deformable_detr/cuda/ms_deform_im2col_cuda.cuh +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/kernels/deformable_detr/ms_deform_attn.h +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/kernels/deformable_detr/vision.cpp +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/kernels/soft_nms/op.cpp +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/kernels/soft_nms/soft_nms.h +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/kernels/transnext/av_bw_kernel.cu +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/kernels/transnext/av_fw_kernel.cu +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/kernels/transnext/qk_bw_kernel.cu +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/kernels/transnext/qk_fw_kernel.cu +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/kernels/transnext/qk_rpb_bw_kernel.cu +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/kernels/transnext/qk_rpb_fw_kernel.cu +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/kernels/transnext/swattention.cpp +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/layers/__init__.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/layers/activations.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/layers/attention_pool.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/layers/ffn.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/layers/gem.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/layers/layer_norm.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/layers/layer_scale.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/model_registry/__init__.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/model_registry/manifest.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/model_registry/model_registry.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/net/__init__.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/net/alexnet.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/net/convmixer.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/net/detection/__init__.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/net/detection/base.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/net/detection/vitdet.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/net/detection/yolo_anchors.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/net/detection/yolo_v4_tiny.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/net/dpn.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/net/mim/__init__.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/net/mim/base.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/net/mim/crossmae.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/net/mim/fcmae.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/net/mim/mae_hiera.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/net/mim/mae_vit.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/net/mim/simmim.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/net/mobilenet_v3_small.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/net/mobilevit_v1.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/net/resmlp.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/net/se_resnet_v1.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/net/se_resnet_v2.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/net/se_resnext.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/net/sequencer2d.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/net/squeezenet.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/net/ssl/__init__.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/net/ssl/barlow_twins.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/net/ssl/base.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/net/ssl/i_jepa.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/net/ssl/simclr.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/net/ssl/sscd.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/net/ssl/vicreg.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/ops/__init__.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/ops/soft_nms.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/optim/__init__.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/optim/lamb.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/optim/lars.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/py.typed +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/results/__init__.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/results/classification.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/scheduler/__init__.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/scheduler/cooldown.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/scripts/__init__.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/scripts/__main__.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/scripts/evaluate.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/scripts/predict_detection.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/tools/__init__.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/tools/__main__.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/tools/adversarial.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/tools/auto_anchors.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/tools/avg_model.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/tools/download_model.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/tools/ensemble_model.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/tools/introspection.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/tools/labelme_to_coco.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/tools/list_models.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/tools/model_info.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/tools/pack.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/tools/show_det_iterator.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/tools/show_iterator.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/tools/similarity.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/tools/stats.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/tools/verify_coco.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/tools/verify_directory.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder/tools/voc_to_coco.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder.egg-info/dependency_links.txt +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder.egg-info/entry_points.txt +0 -0
- {birder-0.2.3 → birder-0.3.1}/birder.egg-info/top_level.txt +0 -0
- {birder-0.2.3 → birder-0.3.1}/pyproject.toml +0 -0
- {birder-0.2.3 → birder-0.3.1}/requirements/requirements-hf.txt +0 -0
- {birder-0.2.3 → birder-0.3.1}/requirements/requirements.txt +0 -0
- {birder-0.2.3 → birder-0.3.1}/setup.cfg +0 -0
- {birder-0.2.3 → birder-0.3.1}/tests/test_adversarial.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/tests/test_collators.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/tests/test_datasets.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/tests/test_introspection.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/tests/test_kernels.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/tests/test_layers.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/tests/test_model_registry.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/tests/test_net_mim.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/tests/test_ops.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/tests/test_optim.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/tests/test_scheduler.py +0 -0
- {birder-0.2.3 → birder-0.3.1}/tests/test_transforms.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: birder
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.3.1
|
|
4
4
|
Summary: An open-source computer vision framework for wildlife image analysis, featuring state-of-the-art models for species classification and detection.
|
|
5
5
|
Author: Ofer Hasson
|
|
6
6
|
License-Expression: Apache-2.0
|
|
@@ -66,6 +66,7 @@ Requires-Dist: pytest; extra == "dev"
|
|
|
66
66
|
Requires-Dist: requests~=2.32.5; extra == "dev"
|
|
67
67
|
Requires-Dist: safetensors~=0.7.0; extra == "dev"
|
|
68
68
|
Requires-Dist: setuptools; extra == "dev"
|
|
69
|
+
Requires-Dist: torchao~=0.15.0; extra == "dev"
|
|
69
70
|
Requires-Dist: torchprofile==0.0.4; extra == "dev"
|
|
70
71
|
Requires-Dist: twine~=6.2.0; extra == "dev"
|
|
71
72
|
Requires-Dist: types-requests~=2.32.4; extra == "dev"
|
|
@@ -627,7 +627,7 @@ def load_model(
|
|
|
627
627
|
net.to(dtype)
|
|
628
628
|
if inference is True:
|
|
629
629
|
for param in net.parameters():
|
|
630
|
-
param.
|
|
630
|
+
param.requires_grad_(False)
|
|
631
631
|
|
|
632
632
|
if pt2 is False: # NOTE: Remove when GraphModule add support for 'eval'
|
|
633
633
|
net.eval()
|
|
@@ -799,7 +799,7 @@ def load_detection_model(
|
|
|
799
799
|
net.to(dtype)
|
|
800
800
|
if inference is True:
|
|
801
801
|
for param in net.parameters():
|
|
802
|
-
param.
|
|
802
|
+
param.requires_grad_(False)
|
|
803
803
|
|
|
804
804
|
net.eval()
|
|
805
805
|
|
|
@@ -39,6 +39,7 @@ def add_optimization_args(parser: argparse.ArgumentParser, default_batch_size: i
|
|
|
39
39
|
group = parser.add_argument_group("Optimization parameters")
|
|
40
40
|
group.add_argument("--batch-size", type=int, default=default_batch_size, metavar="N", help="the batch size")
|
|
41
41
|
group.add_argument("--opt", type=str, choices=list(get_args(OptimizerType)), default="sgd", help="optimizer to use")
|
|
42
|
+
group.add_argument("--opt-fused", default=False, action="store_true", help="use fused optimizer implementation")
|
|
42
43
|
group.add_argument("--momentum", type=float, default=0.9, metavar="M", help="optimizer momentum")
|
|
43
44
|
group.add_argument("--nesterov", default=False, action="store_true", help="use nesterov momentum")
|
|
44
45
|
group.add_argument("--opt-eps", type=float, help="optimizer epsilon (None to use the optimizer default)")
|
|
@@ -211,6 +212,12 @@ def add_training_schedule_args(parser: argparse.ArgumentParser, default_epochs:
|
|
|
211
212
|
group.add_argument(
|
|
212
213
|
"--stop-epoch", type=int, metavar="N", help="epoch to stop the training at (multi stage training)"
|
|
213
214
|
)
|
|
215
|
+
group.add_argument(
|
|
216
|
+
"--steps-per-epoch",
|
|
217
|
+
type=int,
|
|
218
|
+
metavar="N",
|
|
219
|
+
help="virtual epoch length in steps, leave unset to use the full dataset",
|
|
220
|
+
)
|
|
214
221
|
group.add_argument("--warmup-epochs", type=int, metavar="N", help="number of warmup epochs")
|
|
215
222
|
group.add_argument("--warmup-steps", type=int, metavar="N", help="number of warmup optimizer steps")
|
|
216
223
|
group.add_argument("--cooldown-epochs", type=int, metavar="N", help="number of cooldown epochs (linear to zero)")
|
|
@@ -243,6 +250,7 @@ def add_data_aug_args(
|
|
|
243
250
|
default_level: int = 4,
|
|
244
251
|
default_min_scale: Optional[float] = None,
|
|
245
252
|
default_re_prob: Optional[float] = None,
|
|
253
|
+
smoothing_alpha: bool = False,
|
|
246
254
|
mixup_cutmix: bool = False,
|
|
247
255
|
) -> None:
|
|
248
256
|
group = parser.add_argument_group("Data augmentation parameters")
|
|
@@ -279,6 +287,8 @@ def add_data_aug_args(
|
|
|
279
287
|
group.add_argument(
|
|
280
288
|
"--simple-crop", default=False, action="store_true", help="use simple random crop (SRC) instead of RRC"
|
|
281
289
|
)
|
|
290
|
+
if smoothing_alpha is True:
|
|
291
|
+
group.add_argument("--smoothing-alpha", type=float, default=0.0, help="label smoothing alpha")
|
|
282
292
|
if mixup_cutmix is True:
|
|
283
293
|
group.add_argument("--mixup-alpha", type=float, help="mixup alpha")
|
|
284
294
|
group.add_argument("--cutmix", default=False, action="store_true", help="enable cutmix")
|
|
@@ -559,9 +569,9 @@ def add_training_data_args(parser: argparse.ArgumentParser, unsupervised: bool =
|
|
|
559
569
|
group.add_argument("--wds", default=False, action="store_true", help="use webdataset for training")
|
|
560
570
|
group.add_argument("--wds-info", type=str, metavar="FILE", help="wds info file path")
|
|
561
571
|
group.add_argument("--wds-cache-dir", type=str, metavar="DIR", help="webdataset cache directory")
|
|
562
|
-
group.add_argument("--wds-train-size", type=int, metavar="N", help="size of the wds training set")
|
|
563
572
|
if unsupervised is False:
|
|
564
573
|
group.add_argument("--wds-class-file", type=str, metavar="FILE", help="class list file")
|
|
574
|
+
group.add_argument("--wds-train-size", type=int, metavar="N", help="size of the wds training set")
|
|
565
575
|
group.add_argument("--wds-val-size", type=int, metavar="N", help="size of the wds validation set")
|
|
566
576
|
group.add_argument(
|
|
567
577
|
"--wds-training-split", type=str, default="training", metavar="NAME", help="wds dataset train split"
|
|
@@ -570,6 +580,7 @@ def add_training_data_args(parser: argparse.ArgumentParser, unsupervised: bool =
|
|
|
570
580
|
"--wds-val-split", type=str, default="validation", metavar="NAME", help="wds dataset validation split"
|
|
571
581
|
)
|
|
572
582
|
else:
|
|
583
|
+
group.add_argument("--wds-size", type=int, metavar="N", help="size of the wds")
|
|
573
584
|
group.add_argument(
|
|
574
585
|
"--wds-split", type=str, default="training", metavar="NAME", help="wds dataset split to load"
|
|
575
586
|
)
|
|
@@ -17,6 +17,7 @@ from typing import Any
|
|
|
17
17
|
from typing import Literal
|
|
18
18
|
from typing import Optional
|
|
19
19
|
from typing import Sized
|
|
20
|
+
from typing import overload
|
|
20
21
|
|
|
21
22
|
import numpy as np
|
|
22
23
|
import torch
|
|
@@ -70,13 +71,7 @@ class RASampler(torch.utils.data.Sampler):
|
|
|
70
71
|
"""
|
|
71
72
|
|
|
72
73
|
def __init__(
|
|
73
|
-
self,
|
|
74
|
-
dataset: Sized,
|
|
75
|
-
num_replicas: int,
|
|
76
|
-
rank: int,
|
|
77
|
-
shuffle: bool,
|
|
78
|
-
seed: int = 0,
|
|
79
|
-
repetitions: int = 3,
|
|
74
|
+
self, dataset: Sized, num_replicas: int, rank: int, shuffle: bool, seed: int = 0, repetitions: int = 3
|
|
80
75
|
) -> None:
|
|
81
76
|
super().__init__()
|
|
82
77
|
self.dataset = dataset
|
|
@@ -85,12 +80,11 @@ class RASampler(torch.utils.data.Sampler):
|
|
|
85
80
|
self.epoch = 0
|
|
86
81
|
self.num_samples = int(math.ceil(len(self.dataset) * float(repetitions) / self.num_replicas))
|
|
87
82
|
self.total_size = self.num_samples * self.num_replicas
|
|
88
|
-
self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
|
|
89
83
|
self.shuffle = shuffle
|
|
90
84
|
self.seed = seed
|
|
91
85
|
self.repetitions = repetitions
|
|
92
86
|
|
|
93
|
-
def __iter__(self) -> Iterator[
|
|
87
|
+
def __iter__(self) -> Iterator[int]:
|
|
94
88
|
if self.shuffle is True:
|
|
95
89
|
# Deterministically shuffle based on epoch
|
|
96
90
|
g = torch.Generator()
|
|
@@ -100,18 +94,148 @@ class RASampler(torch.utils.data.Sampler):
|
|
|
100
94
|
indices = list(range(len(self.dataset)))
|
|
101
95
|
|
|
102
96
|
# Add extra samples to make it evenly divisible
|
|
103
|
-
indices = [ele for ele in indices for
|
|
104
|
-
indices
|
|
105
|
-
|
|
97
|
+
indices = [ele for ele in indices for _ in range(self.repetitions)]
|
|
98
|
+
if len(indices) < self.total_size:
|
|
99
|
+
indices += indices[: (self.total_size - len(indices))]
|
|
100
|
+
else:
|
|
101
|
+
indices = indices[: self.total_size]
|
|
106
102
|
|
|
107
|
-
#
|
|
103
|
+
# Shard by rank
|
|
108
104
|
indices = indices[self.rank : self.total_size : self.num_replicas]
|
|
109
105
|
assert len(indices) == self.num_samples
|
|
110
106
|
|
|
111
|
-
|
|
107
|
+
yield from indices
|
|
108
|
+
|
|
109
|
+
def __len__(self) -> int:
|
|
110
|
+
return self.num_samples
|
|
111
|
+
|
|
112
|
+
def set_epoch(self, epoch: int) -> None:
|
|
113
|
+
self.epoch = epoch
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
class InfiniteSampler(torch.utils.data.Sampler):
|
|
117
|
+
"""
|
|
118
|
+
Infinite sampler that loops indefinitely over the dataset
|
|
119
|
+
"""
|
|
120
|
+
|
|
121
|
+
def __init__(self, dataset: Sized, shuffle: bool, seed: int = 0) -> None:
|
|
122
|
+
super().__init__()
|
|
123
|
+
self.dataset = dataset
|
|
124
|
+
self.shuffle = shuffle
|
|
125
|
+
self.seed = seed
|
|
126
|
+
self.epoch = 0
|
|
127
|
+
|
|
128
|
+
def __iter__(self) -> Iterator[int]:
|
|
129
|
+
g = torch.Generator()
|
|
130
|
+
while True:
|
|
131
|
+
if self.shuffle is True:
|
|
132
|
+
g.manual_seed(self.seed + self.epoch)
|
|
133
|
+
indices = torch.randperm(len(self.dataset), generator=g).tolist()
|
|
134
|
+
else:
|
|
135
|
+
indices = list(range(len(self.dataset)))
|
|
136
|
+
|
|
137
|
+
yield from indices
|
|
138
|
+
|
|
139
|
+
logger.debug(f"InfiniteSampler finished epoch {self.epoch}")
|
|
140
|
+
self.epoch += 1
|
|
141
|
+
|
|
142
|
+
def __len__(self) -> int:
|
|
143
|
+
return len(self.dataset)
|
|
144
|
+
|
|
145
|
+
def set_epoch(self, epoch: int) -> None:
|
|
146
|
+
self.epoch = epoch
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
class InfiniteDistributedSampler(torch.utils.data.Sampler):
|
|
150
|
+
"""
|
|
151
|
+
Infinite distributed sampler that keeps a continuous shuffled stream per rank
|
|
152
|
+
"""
|
|
153
|
+
|
|
154
|
+
def __init__(self, dataset: Sized, num_replicas: int, rank: int, shuffle: bool, seed: int = 0) -> None:
|
|
155
|
+
super().__init__()
|
|
156
|
+
self.dataset = dataset
|
|
157
|
+
self.num_replicas = num_replicas
|
|
158
|
+
self.rank = rank
|
|
159
|
+
self.shuffle = shuffle
|
|
160
|
+
self.seed = seed
|
|
161
|
+
self.epoch = 0
|
|
162
|
+
self.num_samples = int(math.ceil(len(self.dataset) / self.num_replicas))
|
|
163
|
+
self.total_size = self.num_samples * self.num_replicas
|
|
164
|
+
|
|
165
|
+
def __iter__(self) -> Iterator[int]:
|
|
166
|
+
g = torch.Generator()
|
|
167
|
+
while True:
|
|
168
|
+
if self.shuffle is True:
|
|
169
|
+
g.manual_seed(self.seed + self.epoch)
|
|
170
|
+
indices = torch.randperm(len(self.dataset), generator=g).tolist()
|
|
171
|
+
else:
|
|
172
|
+
indices = list(range(len(self.dataset)))
|
|
173
|
+
|
|
174
|
+
if len(indices) < self.total_size:
|
|
175
|
+
indices += indices[: (self.total_size - len(indices))]
|
|
176
|
+
else:
|
|
177
|
+
indices = indices[: self.total_size]
|
|
178
|
+
|
|
179
|
+
indices = indices[self.rank : self.total_size : self.num_replicas]
|
|
180
|
+
assert len(indices) == self.num_samples
|
|
181
|
+
|
|
182
|
+
yield from indices
|
|
183
|
+
|
|
184
|
+
logger.debug(f"InfiniteDistributedSampler finished epoch {self.epoch}")
|
|
185
|
+
self.epoch += 1
|
|
112
186
|
|
|
113
187
|
def __len__(self) -> int:
|
|
114
|
-
return self.
|
|
188
|
+
return self.num_samples
|
|
189
|
+
|
|
190
|
+
def set_epoch(self, epoch: int) -> None:
|
|
191
|
+
self.epoch = epoch
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
class InfiniteRASampler(torch.utils.data.Sampler):
|
|
195
|
+
"""
|
|
196
|
+
Infinite version of the repeated augmentation sampler
|
|
197
|
+
"""
|
|
198
|
+
|
|
199
|
+
def __init__(
|
|
200
|
+
self, dataset: Sized, num_replicas: int, rank: int, shuffle: bool, seed: int = 0, repetitions: int = 3
|
|
201
|
+
) -> None:
|
|
202
|
+
super().__init__()
|
|
203
|
+
self.dataset = dataset
|
|
204
|
+
self.num_replicas = num_replicas
|
|
205
|
+
self.rank = rank
|
|
206
|
+
self.epoch = 0
|
|
207
|
+
self.num_samples = int(math.ceil(len(self.dataset) * float(repetitions) / self.num_replicas))
|
|
208
|
+
self.total_size = self.num_samples * self.num_replicas
|
|
209
|
+
self.shuffle = shuffle
|
|
210
|
+
self.seed = seed
|
|
211
|
+
self.repetitions = repetitions
|
|
212
|
+
|
|
213
|
+
def __iter__(self) -> Iterator[int]:
|
|
214
|
+
g = torch.Generator()
|
|
215
|
+
while True:
|
|
216
|
+
if self.shuffle is True:
|
|
217
|
+
g.manual_seed(self.seed + self.epoch)
|
|
218
|
+
indices = torch.randperm(len(self.dataset), generator=g).tolist()
|
|
219
|
+
else:
|
|
220
|
+
indices = list(range(len(self.dataset)))
|
|
221
|
+
|
|
222
|
+
indices = [ele for ele in indices for _ in range(self.repetitions)]
|
|
223
|
+
if len(indices) < self.total_size:
|
|
224
|
+
indices += indices[: (self.total_size - len(indices))]
|
|
225
|
+
else:
|
|
226
|
+
indices = indices[: self.total_size]
|
|
227
|
+
|
|
228
|
+
# Shard by rank
|
|
229
|
+
indices = indices[self.rank : self.total_size : self.num_replicas]
|
|
230
|
+
assert len(indices) == self.num_samples
|
|
231
|
+
|
|
232
|
+
yield from indices
|
|
233
|
+
|
|
234
|
+
logger.debug(f"InfiniteRASampler finished epoch {self.epoch}")
|
|
235
|
+
self.epoch += 1
|
|
236
|
+
|
|
237
|
+
def __len__(self) -> int:
|
|
238
|
+
return self.num_samples
|
|
115
239
|
|
|
116
240
|
def set_epoch(self, epoch: int) -> None:
|
|
117
241
|
self.epoch = epoch
|
|
@@ -469,12 +593,14 @@ def get_optimizer(parameters: list[dict[str, Any]], l_rate: float, args: argpars
|
|
|
469
593
|
kwargs["betas"] = args.opt_betas
|
|
470
594
|
if getattr(args, "opt_alpha", None) is not None:
|
|
471
595
|
kwargs["alpha"] = args.opt_alpha
|
|
596
|
+
if getattr(args, "opt_fused", False) is True:
|
|
597
|
+
kwargs["fused"] = True
|
|
472
598
|
|
|
473
599
|
# For optimizer compilation
|
|
474
600
|
# lr = torch.tensor(l_rate) - Causes weird LR scheduling bugs
|
|
475
601
|
lr = l_rate
|
|
476
|
-
if getattr(args, "compile_opt", False) is
|
|
477
|
-
if opt not in ("lamb", "lambw", "lars"):
|
|
602
|
+
if getattr(args, "compile_opt", False) is True:
|
|
603
|
+
if opt not in ("sgd", "lamb", "lambw", "lars"):
|
|
478
604
|
logger.debug("Setting optimizer capturable to True")
|
|
479
605
|
kwargs["capturable"] = True
|
|
480
606
|
|
|
@@ -636,27 +762,87 @@ def get_amp_scaler(amp: bool, amp_dtype_str: str) -> tuple[Optional[torch.amp.Gr
|
|
|
636
762
|
return (scaler, amp_dtype)
|
|
637
763
|
|
|
638
764
|
|
|
765
|
+
@overload
def get_samplers(
    args: argparse.Namespace,
    training_dataset: torch.utils.data.Dataset,
    validation_dataset: torch.utils.data.Dataset,
    infinite: bool = False,
) -> tuple[torch.utils.data.Sampler, torch.utils.data.Sampler]: ...


@overload
def get_samplers(
    args: argparse.Namespace,
    training_dataset: torch.utils.data.Dataset,
    validation_dataset: None = None,
    infinite: bool = False,
) -> tuple[torch.utils.data.Sampler, None]: ...


def get_samplers(
    args: argparse.Namespace,
    training_dataset: torch.utils.data.Dataset,
    validation_dataset: Optional[torch.utils.data.Dataset] = None,
    infinite: bool = False,
) -> tuple[torch.utils.data.Sampler, Optional[torch.utils.data.Sampler]]:
    """
    Build the training sampler and, when a validation dataset is given, the validation sampler

    The training sampler always shuffles and is seeded with args.seed. When args.seed
    is None a random seed is drawn instead and, if is_dist_available_and_initialized()
    reports True, the value of rank 0 is broadcast so every rank ends up with the same seed.

    In the distributed case the training sampler is chosen by the 'infinite' flag and the
    optional args.ra_sampler flag, the validation sampler is a non-shuffling DistributedSampler.
    Otherwise the training sampler is an InfiniteSampler or a seeded RandomSampler and the
    validation sampler is a SequentialSampler.

    Returns a (train_sampler, validation_sampler) tuple, validation_sampler is None when
    no validation dataset was passed.
    """

    seed = args.seed
    if seed is None:
        seed = int(torch.empty((), dtype=torch.int64).random_().item())
        if is_dist_available_and_initialized() is True:
            # Every rank drew its own value, adopt the one from rank 0
            seed_tensor = torch.tensor(seed, dtype=torch.int64).cuda()
            dist.broadcast(seed_tensor, src=0, async_op=False)
            seed = int(seed_tensor.item())

    use_ra = getattr(args, "ra_sampler", False)
    train_sampler: torch.utils.data.Sampler
    validation_sampler: Optional[torch.utils.data.Sampler] = None

    # Single process
    if args.distributed is not True:
        if infinite is True:
            train_sampler = InfiniteSampler(training_dataset, shuffle=True, seed=seed)
        else:
            rng = torch.Generator()
            rng.manual_seed(seed)
            train_sampler = torch.utils.data.RandomSampler(training_dataset, generator=rng)

        if validation_dataset is not None:
            validation_sampler = torch.utils.data.SequentialSampler(validation_dataset)

        return (train_sampler, validation_sampler)

    # Distributed
    if infinite is True and use_ra is True:
        train_sampler = InfiniteRASampler(
            training_dataset,
            num_replicas=args.world_size,
            rank=args.rank,
            shuffle=True,
            seed=seed,
            repetitions=args.ra_reps,
        )
    elif infinite is True:
        train_sampler = InfiniteDistributedSampler(
            training_dataset, num_replicas=args.world_size, rank=args.rank, shuffle=True, seed=seed
        )
    elif use_ra is True:
        train_sampler = RASampler(
            training_dataset,
            num_replicas=args.world_size,
            rank=args.rank,
            shuffle=True,
            seed=seed,
            repetitions=args.ra_reps,
        )
    else:
        train_sampler = torch.utils.data.distributed.DistributedSampler(training_dataset, shuffle=True, seed=seed)

    if validation_dataset is not None:
        validation_sampler = torch.utils.data.distributed.DistributedSampler(validation_dataset, shuffle=False)

    return (train_sampler, validation_sampler)
|
|
662
848
|
|
|
@@ -98,6 +98,7 @@ class BatchRandomResizeCollator(DetectionCollator):
|
|
|
98
98
|
if isinstance(boxes, tv_tensors.BoundingBoxes) is False:
|
|
99
99
|
if boxes.numel() == 0:
|
|
100
100
|
boxes = boxes.reshape(0, 4)
|
|
101
|
+
|
|
101
102
|
boxes = tv_tensors.BoundingBoxes(
|
|
102
103
|
boxes, format=tv_tensors.BoundingBoxFormat.XYXY, canvas_size=F.get_size(image)
|
|
103
104
|
)
|
|
@@ -22,9 +22,19 @@ def make_wds_loader(
|
|
|
22
22
|
shuffle: bool = False,
|
|
23
23
|
*,
|
|
24
24
|
exact: bool = False,
|
|
25
|
+
infinite: bool = False,
|
|
25
26
|
) -> DataLoader:
|
|
27
|
+
assert exact is False or infinite is False
|
|
28
|
+
|
|
29
|
+
if infinite is True:
|
|
30
|
+
dataset_iterable = dataset.repeat()
|
|
31
|
+
elif exact is False:
|
|
32
|
+
dataset_iterable = dataset.repeat()
|
|
33
|
+
else:
|
|
34
|
+
dataset_iterable = dataset
|
|
35
|
+
|
|
26
36
|
dataloader = wds.WebLoader(
|
|
27
|
-
|
|
37
|
+
dataset_iterable,
|
|
28
38
|
batch_size=batch_size,
|
|
29
39
|
num_workers=num_workers,
|
|
30
40
|
prefetch_factor=prefetch_factor,
|
|
@@ -43,7 +53,7 @@ def make_wds_loader(
|
|
|
43
53
|
epoch_size = math.ceil(len(dataset) / (batch_size * world_size))
|
|
44
54
|
|
|
45
55
|
dataloader = dataloader.with_length(epoch_size, silent=True)
|
|
46
|
-
if exact is False:
|
|
56
|
+
if exact is False and infinite is False:
|
|
47
57
|
dataloader = dataloader.with_epoch(epoch_size)
|
|
48
58
|
|
|
49
59
|
return dataloader
|
|
@@ -14,11 +14,24 @@ logger = logging.getLogger(__name__)
|
|
|
14
14
|
|
|
15
15
|
|
|
16
16
|
_CACHED_KERNELS: dict[str, ModuleType] = {}
|
|
17
|
+
_CUSTOM_KERNELS_ENABLED = True
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def set_custom_kernels_enabled(enabled: bool) -> None:
    """
    Set the module-wide flag that is_custom_kernels_enabled reports

    The DISABLE_CUSTOM_KERNELS=1 environment variable still takes precedence over
    a value of True stored here.
    """

    global _CUSTOM_KERNELS_ENABLED  # pylint: disable=global-statement
    _CUSTOM_KERNELS_ENABLED = enabled
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def is_custom_kernels_enabled() -> bool:
    """
    Report whether custom kernels may be loaded

    The environment variable DISABLE_CUSTOM_KERNELS=1 always wins, otherwise the
    module-wide flag is returned.
    """

    disabled_by_env = os.environ.get("DISABLE_CUSTOM_KERNELS", "0") == "1"
    return False if disabled_by_env else _CUSTOM_KERNELS_ENABLED
|
|
17
30
|
|
|
18
31
|
|
|
19
32
|
def load_msda() -> Optional[ModuleType]:
|
|
20
33
|
name = "msda"
|
|
21
|
-
if torch.cuda.is_available() is False or
|
|
34
|
+
if torch.cuda.is_available() is False or is_custom_kernels_enabled() is False:
|
|
22
35
|
return None
|
|
23
36
|
|
|
24
37
|
if name in _CACHED_KERNELS:
|
|
@@ -60,7 +73,7 @@ def load_msda() -> Optional[ModuleType]:
|
|
|
60
73
|
|
|
61
74
|
def load_swattention() -> Optional[ModuleType]:
|
|
62
75
|
name = "swattention"
|
|
63
|
-
if torch.cuda.is_available() is False or
|
|
76
|
+
if torch.cuda.is_available() is False or is_custom_kernels_enabled() is False:
|
|
64
77
|
return None
|
|
65
78
|
|
|
66
79
|
if name in _CACHED_KERNELS:
|
|
@@ -103,7 +116,7 @@ def load_swattention() -> Optional[ModuleType]:
|
|
|
103
116
|
|
|
104
117
|
def load_soft_nms() -> Optional[ModuleType]:
|
|
105
118
|
name = "soft_nms"
|
|
106
|
-
if
|
|
119
|
+
if is_custom_kernels_enabled() is False:
|
|
107
120
|
return None
|
|
108
121
|
|
|
109
122
|
if name in _CACHED_KERNELS:
|
|
@@ -120,14 +133,6 @@ def load_soft_nms() -> Optional[ModuleType]:
|
|
|
120
133
|
soft_nms: Optional[ModuleType] = load(
|
|
121
134
|
"soft_nms",
|
|
122
135
|
src_files,
|
|
123
|
-
with_cuda=True,
|
|
124
|
-
extra_cflags=["-DWITH_CUDA=1"],
|
|
125
|
-
extra_cuda_cflags=[
|
|
126
|
-
"-DCUDA_HAS_FP16=1",
|
|
127
|
-
"-D__CUDA_NO_HALF_OPERATORS__",
|
|
128
|
-
"-D__CUDA_NO_HALF_CONVERSIONS__",
|
|
129
|
-
"-D__CUDA_NO_HALF2_OPERATORS__",
|
|
130
|
-
],
|
|
131
136
|
)
|
|
132
137
|
|
|
133
138
|
if soft_nms is not None:
|
|
@@ -61,24 +61,23 @@ void update_sorting_order(torch::Tensor& boxes, torch::Tensor& scores, torch::Te
|
|
|
61
61
|
std::tie(max_score, t_max_idx) = torch::max(scores.index({Slice(idx + 1, None)}), 0);
|
|
62
62
|
|
|
63
63
|
// max_idx is computed from sliced data, therefore need to convert it to "global" max idx
|
|
64
|
-
auto max_idx = t_max_idx
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
}
|
|
64
|
+
auto max_idx = t_max_idx + (idx + 1);
|
|
65
|
+
auto should_swap = scores.index({idx}) < max_score;
|
|
66
|
+
|
|
67
|
+
auto boxes_idx = boxes.index({idx}).clone();
|
|
68
|
+
auto boxes_max = boxes.index({max_idx}).clone();
|
|
69
|
+
boxes.index_put_({idx}, torch::where(should_swap, boxes_max, boxes_idx));
|
|
70
|
+
boxes.index_put_({max_idx}, torch::where(should_swap, boxes_idx, boxes_max));
|
|
71
|
+
|
|
72
|
+
auto scores_idx = scores.index({idx}).clone();
|
|
73
|
+
auto scores_max = scores.index({max_idx}).clone();
|
|
74
|
+
scores.index_put_({idx}, torch::where(should_swap, scores_max, scores_idx));
|
|
75
|
+
scores.index_put_({max_idx}, torch::where(should_swap, scores_idx, scores_max));
|
|
76
|
+
|
|
77
|
+
auto areas_idx = areas.index({idx}).clone();
|
|
78
|
+
auto areas_max = areas.index({max_idx}).clone();
|
|
79
|
+
areas.index_put_({idx}, torch::where(should_swap, areas_max, areas_idx));
|
|
80
|
+
areas.index_put_({max_idx}, torch::where(should_swap, areas_idx, areas_max));
|
|
82
81
|
}
|
|
83
82
|
|
|
84
83
|
std::tuple<torch::Tensor, torch::Tensor> soft_nms(
|
|
@@ -173,14 +173,14 @@ class BaseNet(nn.Module):
|
|
|
173
173
|
|
|
174
174
|
def freeze(self, freeze_classifier: bool = True, unfreeze_features: bool = False) -> None:
|
|
175
175
|
for param in self.parameters():
|
|
176
|
-
param.
|
|
176
|
+
param.requires_grad_(False)
|
|
177
177
|
|
|
178
178
|
if freeze_classifier is False:
|
|
179
179
|
for param in self.classifier.parameters():
|
|
180
|
-
param.
|
|
180
|
+
param.requires_grad_(True)
|
|
181
181
|
if unfreeze_features is True and hasattr(self, "features") is True:
|
|
182
182
|
for param in self.features.parameters():
|
|
183
|
-
param.
|
|
183
|
+
param.requires_grad_(True)
|
|
184
184
|
|
|
185
185
|
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
|
|
186
186
|
"""
|
|
@@ -468,14 +468,14 @@ class BiFormer(DetectorBackbone):
|
|
|
468
468
|
|
|
469
469
|
def freeze_stages(self, up_to_stage: int) -> None:
    """
    Disable gradients for the stem and for the first 'up_to_stage' children of the body

    Children of self.body at index 'up_to_stage' and above are left untouched.
    """

    for weight in self.stem.parameters():
        weight.requires_grad_(False)

    for stage_idx, stage in enumerate(self.body.children()):
        if stage_idx < up_to_stage:
            for weight in stage.parameters():
                weight.requires_grad_(False)
        else:
            # All remaining stages stay as they are
            break
|
|
479
479
|
|
|
480
480
|
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
|
|
481
481
|
x = self.stem(x)
|
|
@@ -268,14 +268,15 @@ class CaiT(BaseNet):
|
|
|
268
268
|
super().adjust_size(new_size)
|
|
269
269
|
|
|
270
270
|
# Add back class tokens
|
|
271
|
-
|
|
272
|
-
adjust_position_embedding(
|
|
271
|
+
with torch.no_grad():
|
|
272
|
+
pos_embed = adjust_position_embedding(
|
|
273
273
|
self.pos_embed,
|
|
274
274
|
(old_size[0] // self.patch_size[0], old_size[1] // self.patch_size[1]),
|
|
275
275
|
(new_size[0] // self.patch_size[0], new_size[1] // self.patch_size[1]),
|
|
276
276
|
0,
|
|
277
277
|
)
|
|
278
|
-
|
|
278
|
+
|
|
279
|
+
self.pos_embed = nn.Parameter(pos_embed)
|
|
279
280
|
|
|
280
281
|
|
|
281
282
|
registry.register_model_config(
|
|
@@ -269,18 +269,18 @@ class CAS_ViT(DetectorBackbone):
|
|
|
269
269
|
|
|
270
270
|
def freeze(self, freeze_classifier: bool = True, unfreeze_features: bool = False) -> None:
|
|
271
271
|
for param in self.parameters():
|
|
272
|
-
param.
|
|
272
|
+
param.requires_grad_(False)
|
|
273
273
|
|
|
274
274
|
if freeze_classifier is False:
|
|
275
275
|
for param in self.classifier.parameters():
|
|
276
|
-
param.
|
|
276
|
+
param.requires_grad_(True)
|
|
277
277
|
|
|
278
278
|
for param in self.dist_classifier.parameters():
|
|
279
|
-
param.
|
|
279
|
+
param.requires_grad_(True)
|
|
280
280
|
|
|
281
281
|
if unfreeze_features is True:
|
|
282
282
|
for param in self.features.parameters():
|
|
283
|
-
param.
|
|
283
|
+
param.requires_grad_(True)
|
|
284
284
|
|
|
285
285
|
def transform_to_backbone(self) -> None:
|
|
286
286
|
self.features = nn.Identity()
|
|
@@ -300,14 +300,14 @@ class CAS_ViT(DetectorBackbone):
|
|
|
300
300
|
|
|
301
301
|
def freeze_stages(self, up_to_stage: int) -> None:
    """
    Disable gradients for the stem and for the first 'up_to_stage' children of the body

    Children of self.body at index 'up_to_stage' and above are left untouched.
    """

    for weight in self.stem.parameters():
        weight.requires_grad_(False)

    for stage_idx, stage in enumerate(self.body.children()):
        if stage_idx < up_to_stage:
            for weight in stage.parameters():
                weight.requires_grad_(False)
        else:
            # All remaining stages stay as they are
            break
|
|
311
311
|
|
|
312
312
|
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
|
|
313
313
|
x = self.stem(x)
|