birder 0.2.3__tar.gz → 0.3.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {birder-0.2.3 → birder-0.3.0}/PKG-INFO +2 -1
- {birder-0.2.3 → birder-0.3.0}/birder/common/training_cli.py +6 -0
- {birder-0.2.3 → birder-0.3.0}/birder/common/training_utils.py +215 -31
- {birder-0.2.3 → birder-0.3.0}/birder/data/collators/detection.py +1 -0
- {birder-0.2.3 → birder-0.3.0}/birder/data/dataloader/webdataset.py +12 -2
- {birder-0.2.3 → birder-0.3.0}/birder/kernels/load_kernel.py +16 -11
- {birder-0.2.3 → birder-0.3.0}/birder/kernels/soft_nms/soft_nms.cpp +17 -18
- {birder-0.2.3 → birder-0.3.0}/birder/net/cait.py +4 -3
- {birder-0.2.3 → birder-0.3.0}/birder/net/convnext_v1.py +5 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/crossformer.py +33 -30
- {birder-0.2.3 → birder-0.3.0}/birder/net/crossvit.py +4 -3
- {birder-0.2.3 → birder-0.3.0}/birder/net/deit.py +3 -3
- {birder-0.2.3 → birder-0.3.0}/birder/net/deit3.py +3 -3
- {birder-0.2.3 → birder-0.3.0}/birder/net/detection/deformable_detr.py +2 -5
- {birder-0.2.3 → birder-0.3.0}/birder/net/detection/detr.py +2 -5
- {birder-0.2.3 → birder-0.3.0}/birder/net/detection/efficientdet.py +2 -7
- {birder-0.2.3 → birder-0.3.0}/birder/net/detection/fcos.py +2 -7
- {birder-0.2.3 → birder-0.3.0}/birder/net/detection/retinanet.py +2 -7
- {birder-0.2.3 → birder-0.3.0}/birder/net/detection/rt_detr_v1.py +1 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/efficientformer_v1.py +15 -9
- {birder-0.2.3 → birder-0.3.0}/birder/net/efficientformer_v2.py +39 -29
- {birder-0.2.3 → birder-0.3.0}/birder/net/efficientvit_msft.py +9 -7
- {birder-0.2.3 → birder-0.3.0}/birder/net/fastvit.py +1 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/flexivit.py +5 -4
- {birder-0.2.3 → birder-0.3.0}/birder/net/hiera.py +12 -9
- {birder-0.2.3 → birder-0.3.0}/birder/net/hornet.py +9 -7
- {birder-0.2.3 → birder-0.3.0}/birder/net/iformer.py +8 -6
- {birder-0.2.3 → birder-0.3.0}/birder/net/levit.py +42 -30
- {birder-0.2.3 → birder-0.3.0}/birder/net/lit_v1_tiny.py +15 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/maxvit.py +67 -55
- {birder-0.2.3 → birder-0.3.0}/birder/net/mobileone.py +1 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/mvit_v2.py +13 -12
- {birder-0.2.3 → birder-0.3.0}/birder/net/pit.py +4 -3
- {birder-0.2.3 → birder-0.3.0}/birder/net/pvt_v1.py +4 -1
- {birder-0.2.3 → birder-0.3.0}/birder/net/repghost.py +1 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/repvgg.py +1 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/repvit.py +1 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/rope_deit3.py +5 -3
- {birder-0.2.3 → birder-0.3.0}/birder/net/rope_flexivit.py +7 -4
- {birder-0.2.3 → birder-0.3.0}/birder/net/rope_vit.py +10 -5
- {birder-0.2.3 → birder-0.3.0}/birder/net/simple_vit.py +9 -6
- {birder-0.2.3 → birder-0.3.0}/birder/net/swin_transformer_v1.py +71 -68
- {birder-0.2.3 → birder-0.3.0}/birder/net/swin_transformer_v2.py +38 -31
- {birder-0.2.3 → birder-0.3.0}/birder/net/tiny_vit.py +20 -10
- {birder-0.2.3 → birder-0.3.0}/birder/net/transnext.py +38 -28
- {birder-0.2.3 → birder-0.3.0}/birder/net/vit.py +5 -4
- {birder-0.2.3 → birder-0.3.0}/birder/net/vit_parallel.py +5 -4
- {birder-0.2.3 → birder-0.3.0}/birder/net/vit_sam.py +38 -37
- {birder-0.2.3 → birder-0.3.0}/birder/net/vovnet_v1.py +15 -0
- birder-0.3.0/birder/ops/msda.py +203 -0
- birder-0.3.0/birder/ops/swattention.py +288 -0
- {birder-0.2.3 → birder-0.3.0}/birder/results/detection.py +4 -0
- {birder-0.2.3 → birder-0.3.0}/birder/scripts/benchmark.py +21 -12
- {birder-0.2.3 → birder-0.3.0}/birder/scripts/predict.py +7 -0
- {birder-0.2.3 → birder-0.3.0}/birder/scripts/train.py +39 -13
- {birder-0.2.3 → birder-0.3.0}/birder/scripts/train_barlow_twins.py +35 -12
- {birder-0.2.3 → birder-0.3.0}/birder/scripts/train_byol.py +35 -12
- {birder-0.2.3 → birder-0.3.0}/birder/scripts/train_capi.py +41 -15
- {birder-0.2.3 → birder-0.3.0}/birder/scripts/train_data2vec.py +37 -14
- {birder-0.2.3 → birder-0.3.0}/birder/scripts/train_data2vec2.py +37 -14
- {birder-0.2.3 → birder-0.3.0}/birder/scripts/train_detection.py +36 -11
- {birder-0.2.3 → birder-0.3.0}/birder/scripts/train_dino_v1.py +51 -14
- {birder-0.2.3 → birder-0.3.0}/birder/scripts/train_dino_v2.py +78 -19
- {birder-0.2.3 → birder-0.3.0}/birder/scripts/train_dino_v2_dist.py +76 -17
- {birder-0.2.3 → birder-0.3.0}/birder/scripts/train_franca.py +43 -19
- {birder-0.2.3 → birder-0.3.0}/birder/scripts/train_i_jepa.py +37 -14
- {birder-0.2.3 → birder-0.3.0}/birder/scripts/train_ibot.py +43 -20
- {birder-0.2.3 → birder-0.3.0}/birder/scripts/train_kd.py +39 -13
- {birder-0.2.3 → birder-0.3.0}/birder/scripts/train_mim.py +35 -12
- {birder-0.2.3 → birder-0.3.0}/birder/scripts/train_mmcr.py +35 -12
- {birder-0.2.3 → birder-0.3.0}/birder/scripts/train_rotnet.py +36 -13
- {birder-0.2.3 → birder-0.3.0}/birder/scripts/train_simclr.py +35 -12
- {birder-0.2.3 → birder-0.3.0}/birder/scripts/train_vicreg.py +35 -12
- {birder-0.2.3 → birder-0.3.0}/birder/tools/convert_model.py +18 -15
- birder-0.3.0/birder/tools/det_results.py +173 -0
- birder-0.3.0/birder/tools/quantize_model.py +162 -0
- birder-0.3.0/birder/version.py +1 -0
- {birder-0.2.3 → birder-0.3.0}/birder.egg-info/PKG-INFO +2 -1
- {birder-0.2.3 → birder-0.3.0}/birder.egg-info/SOURCES.txt +1 -0
- {birder-0.2.3 → birder-0.3.0}/birder.egg-info/requires.txt +1 -0
- {birder-0.2.3 → birder-0.3.0}/requirements/_requirements-dev.txt +2 -0
- {birder-0.2.3 → birder-0.3.0}/tests/test_common.py +73 -2
- birder-0.3.0/tests/test_dataloaders.py +101 -0
- {birder-0.2.3 → birder-0.3.0}/tests/test_net.py +41 -2
- birder-0.2.3/birder/ops/msda.py +0 -138
- birder-0.2.3/birder/ops/swattention.py +0 -225
- birder-0.2.3/birder/tools/det_results.py +0 -61
- birder-0.2.3/birder/tools/quantize_model.py +0 -156
- birder-0.2.3/birder/version.py +0 -1
- {birder-0.2.3 → birder-0.3.0}/LICENSE +0 -0
- {birder-0.2.3 → birder-0.3.0}/README.md +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/__init__.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/adversarial/__init__.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/adversarial/base.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/adversarial/deepfool.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/adversarial/fgsm.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/adversarial/pgd.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/adversarial/simba.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/common/__init__.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/common/cli.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/common/fs_ops.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/common/lib.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/common/masking.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/conf/__init__.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/conf/settings.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/data/__init__.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/data/collators/__init__.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/data/dataloader/__init__.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/data/datasets/__init__.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/data/datasets/coco.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/data/datasets/directory.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/data/datasets/fake.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/data/datasets/webdataset.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/data/transforms/__init__.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/data/transforms/classification.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/data/transforms/detection.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/data/transforms/mosaic.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/datahub/__init__.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/datahub/_lib.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/datahub/classification.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/inference/__init__.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/inference/classification.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/inference/data_parallel.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/inference/detection.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/inference/wbf.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/introspection/__init__.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/introspection/attention_rollout.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/introspection/base.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/introspection/gradcam.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/introspection/guided_backprop.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/introspection/transformer_attribution.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/kernels/__init__.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/kernels/deformable_detr/cpu/ms_deform_attn_cpu.cpp +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/kernels/deformable_detr/cpu/ms_deform_attn_cpu.h +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/kernels/deformable_detr/cuda/ms_deform_attn_cuda.cu +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/kernels/deformable_detr/cuda/ms_deform_attn_cuda.h +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/kernels/deformable_detr/cuda/ms_deform_im2col_cuda.cuh +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/kernels/deformable_detr/ms_deform_attn.h +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/kernels/deformable_detr/vision.cpp +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/kernels/soft_nms/op.cpp +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/kernels/soft_nms/soft_nms.h +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/kernels/transnext/av_bw_kernel.cu +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/kernels/transnext/av_fw_kernel.cu +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/kernels/transnext/qk_bw_kernel.cu +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/kernels/transnext/qk_fw_kernel.cu +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/kernels/transnext/qk_rpb_bw_kernel.cu +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/kernels/transnext/qk_rpb_fw_kernel.cu +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/kernels/transnext/swattention.cpp +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/layers/__init__.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/layers/activations.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/layers/attention_pool.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/layers/ffn.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/layers/gem.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/layers/layer_norm.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/layers/layer_scale.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/model_registry/__init__.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/model_registry/manifest.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/model_registry/model_registry.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/__init__.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/alexnet.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/base.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/biformer.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/cas_vit.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/coat.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/conv2former.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/convmixer.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/convnext_v2.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/cspnet.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/cswin_transformer.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/darknet.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/davit.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/densenet.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/detection/__init__.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/detection/base.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/detection/faster_rcnn.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/detection/ssd.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/detection/ssdlite.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/detection/vitdet.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/detection/yolo_anchors.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/detection/yolo_v2.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/detection/yolo_v3.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/detection/yolo_v4.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/detection/yolo_v4_tiny.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/dpn.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/edgenext.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/edgevit.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/efficientnet_lite.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/efficientnet_v1.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/efficientnet_v2.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/efficientvim.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/efficientvit_mit.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/fasternet.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/focalnet.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/gc_vit.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/ghostnet_v1.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/ghostnet_v2.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/groupmixformer.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/hgnet_v1.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/hgnet_v2.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/hieradet.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/inception_next.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/inception_resnet_v1.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/inception_resnet_v2.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/inception_v3.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/inception_v4.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/lit_v1.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/lit_v2.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/metaformer.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/mim/__init__.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/mim/base.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/mim/crossmae.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/mim/fcmae.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/mim/mae_hiera.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/mim/mae_vit.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/mim/simmim.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/mnasnet.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/mobilenet_v1.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/mobilenet_v2.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/mobilenet_v3_large.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/mobilenet_v3_small.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/mobilenet_v4.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/mobilenet_v4_hybrid.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/mobilevit_v1.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/mobilevit_v2.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/moganet.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/nextvit.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/nfnet.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/pvt_v2.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/rdnet.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/regionvit.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/regnet.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/regnet_z.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/resmlp.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/resnest.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/resnet_v1.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/resnet_v2.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/resnext.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/se_resnet_v1.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/se_resnet_v2.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/se_resnext.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/sequencer2d.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/shufflenet_v1.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/shufflenet_v2.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/smt.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/squeezenet.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/squeezenext.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/ssl/__init__.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/ssl/barlow_twins.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/ssl/base.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/ssl/byol.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/ssl/capi.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/ssl/data2vec.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/ssl/data2vec2.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/ssl/dino_v1.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/ssl/dino_v2.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/ssl/franca.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/ssl/i_jepa.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/ssl/ibot.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/ssl/mmcr.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/ssl/simclr.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/ssl/sscd.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/ssl/vicreg.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/starnet.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/swiftformer.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/uniformer.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/van.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/vgg.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/vgg_reduced.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/vovnet_v2.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/wide_resnet.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/xception.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/net/xcit.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/ops/__init__.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/ops/soft_nms.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/optim/__init__.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/optim/lamb.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/optim/lars.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/py.typed +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/results/__init__.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/results/classification.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/results/gui.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/scheduler/__init__.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/scheduler/cooldown.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/scripts/__init__.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/scripts/__main__.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/scripts/evaluate.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/scripts/predict_detection.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/tools/__init__.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/tools/__main__.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/tools/adversarial.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/tools/auto_anchors.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/tools/avg_model.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/tools/download_model.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/tools/ensemble_model.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/tools/introspection.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/tools/labelme_to_coco.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/tools/list_models.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/tools/model_info.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/tools/pack.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/tools/results.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/tools/show_det_iterator.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/tools/show_iterator.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/tools/similarity.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/tools/stats.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/tools/verify_coco.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/tools/verify_directory.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder/tools/voc_to_coco.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder.egg-info/dependency_links.txt +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder.egg-info/entry_points.txt +0 -0
- {birder-0.2.3 → birder-0.3.0}/birder.egg-info/top_level.txt +0 -0
- {birder-0.2.3 → birder-0.3.0}/pyproject.toml +0 -0
- {birder-0.2.3 → birder-0.3.0}/requirements/requirements-hf.txt +0 -0
- {birder-0.2.3 → birder-0.3.0}/requirements/requirements.txt +0 -0
- {birder-0.2.3 → birder-0.3.0}/setup.cfg +0 -0
- {birder-0.2.3 → birder-0.3.0}/tests/test_adversarial.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/tests/test_collators.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/tests/test_datasets.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/tests/test_inference.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/tests/test_introspection.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/tests/test_kernels.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/tests/test_layers.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/tests/test_model_registry.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/tests/test_net_detection.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/tests/test_net_mim.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/tests/test_net_ssl.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/tests/test_ops.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/tests/test_optim.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/tests/test_results.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/tests/test_scheduler.py +0 -0
- {birder-0.2.3 → birder-0.3.0}/tests/test_transforms.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: birder
|
|
3
|
-
Version: 0.2.3
|
|
3
|
+
Version: 0.3.0
|
|
4
4
|
Summary: An open-source computer vision framework for wildlife image analysis, featuring state-of-the-art models for species classification and detection.
|
|
5
5
|
Author: Ofer Hasson
|
|
6
6
|
License-Expression: Apache-2.0
|
|
@@ -66,6 +66,7 @@ Requires-Dist: pytest; extra == "dev"
|
|
|
66
66
|
Requires-Dist: requests~=2.32.5; extra == "dev"
|
|
67
67
|
Requires-Dist: safetensors~=0.7.0; extra == "dev"
|
|
68
68
|
Requires-Dist: setuptools; extra == "dev"
|
|
69
|
+
Requires-Dist: torchao~=0.15.0; extra == "dev"
|
|
69
70
|
Requires-Dist: torchprofile==0.0.4; extra == "dev"
|
|
70
71
|
Requires-Dist: twine~=6.2.0; extra == "dev"
|
|
71
72
|
Requires-Dist: types-requests~=2.32.4; extra == "dev"
|
|
@@ -211,6 +211,12 @@ def add_training_schedule_args(parser: argparse.ArgumentParser, default_epochs:
|
|
|
211
211
|
group.add_argument(
|
|
212
212
|
"--stop-epoch", type=int, metavar="N", help="epoch to stop the training at (multi stage training)"
|
|
213
213
|
)
|
|
214
|
+
group.add_argument(
|
|
215
|
+
"--steps-per-epoch",
|
|
216
|
+
type=int,
|
|
217
|
+
metavar="N",
|
|
218
|
+
help="virtual epoch length in steps, leave unset to use the full dataset",
|
|
219
|
+
)
|
|
214
220
|
group.add_argument("--warmup-epochs", type=int, metavar="N", help="number of warmup epochs")
|
|
215
221
|
group.add_argument("--warmup-steps", type=int, metavar="N", help="number of warmup optimizer steps")
|
|
216
222
|
group.add_argument("--cooldown-epochs", type=int, metavar="N", help="number of cooldown epochs (linear to zero)")
|
|
@@ -17,6 +17,7 @@ from typing import Any
|
|
|
17
17
|
from typing import Literal
|
|
18
18
|
from typing import Optional
|
|
19
19
|
from typing import Sized
|
|
20
|
+
from typing import overload
|
|
20
21
|
|
|
21
22
|
import numpy as np
|
|
22
23
|
import torch
|
|
@@ -70,13 +71,7 @@ class RASampler(torch.utils.data.Sampler):
|
|
|
70
71
|
"""
|
|
71
72
|
|
|
72
73
|
def __init__(
|
|
73
|
-
self,
|
|
74
|
-
dataset: Sized,
|
|
75
|
-
num_replicas: int,
|
|
76
|
-
rank: int,
|
|
77
|
-
shuffle: bool,
|
|
78
|
-
seed: int = 0,
|
|
79
|
-
repetitions: int = 3,
|
|
74
|
+
self, dataset: Sized, num_replicas: int, rank: int, shuffle: bool, seed: int = 0, repetitions: int = 3
|
|
80
75
|
) -> None:
|
|
81
76
|
super().__init__()
|
|
82
77
|
self.dataset = dataset
|
|
@@ -85,12 +80,11 @@ class RASampler(torch.utils.data.Sampler):
|
|
|
85
80
|
self.epoch = 0
|
|
86
81
|
self.num_samples = int(math.ceil(len(self.dataset) * float(repetitions) / self.num_replicas))
|
|
87
82
|
self.total_size = self.num_samples * self.num_replicas
|
|
88
|
-
self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
|
|
89
83
|
self.shuffle = shuffle
|
|
90
84
|
self.seed = seed
|
|
91
85
|
self.repetitions = repetitions
|
|
92
86
|
|
|
93
|
-
def __iter__(self) -> Iterator[
|
|
87
|
+
def __iter__(self) -> Iterator[int]:
|
|
94
88
|
if self.shuffle is True:
|
|
95
89
|
# Deterministically shuffle based on epoch
|
|
96
90
|
g = torch.Generator()
|
|
@@ -100,18 +94,148 @@ class RASampler(torch.utils.data.Sampler):
|
|
|
100
94
|
indices = list(range(len(self.dataset)))
|
|
101
95
|
|
|
102
96
|
# Add extra samples to make it evenly divisible
|
|
103
|
-
indices = [ele for ele in indices for
|
|
104
|
-
indices
|
|
105
|
-
|
|
97
|
+
indices = [ele for ele in indices for _ in range(self.repetitions)]
|
|
98
|
+
if len(indices) < self.total_size:
|
|
99
|
+
indices += indices[: (self.total_size - len(indices))]
|
|
100
|
+
else:
|
|
101
|
+
indices = indices[: self.total_size]
|
|
106
102
|
|
|
107
|
-
#
|
|
103
|
+
# Shard by rank
|
|
108
104
|
indices = indices[self.rank : self.total_size : self.num_replicas]
|
|
109
105
|
assert len(indices) == self.num_samples
|
|
110
106
|
|
|
111
|
-
|
|
107
|
+
yield from indices
|
|
108
|
+
|
|
109
|
+
def __len__(self) -> int:
|
|
110
|
+
return self.num_samples
|
|
111
|
+
|
|
112
|
+
def set_epoch(self, epoch: int) -> None:
|
|
113
|
+
self.epoch = epoch
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
class InfiniteSampler(torch.utils.data.Sampler):
|
|
117
|
+
"""
|
|
118
|
+
Infinite sampler that loops indefinitely over the dataset
|
|
119
|
+
"""
|
|
120
|
+
|
|
121
|
+
def __init__(self, dataset: Sized, shuffle: bool, seed: int = 0) -> None:
|
|
122
|
+
super().__init__()
|
|
123
|
+
self.dataset = dataset
|
|
124
|
+
self.shuffle = shuffle
|
|
125
|
+
self.seed = seed
|
|
126
|
+
self.epoch = 0
|
|
127
|
+
|
|
128
|
+
def __iter__(self) -> Iterator[int]:
|
|
129
|
+
g = torch.Generator()
|
|
130
|
+
while True:
|
|
131
|
+
if self.shuffle is True:
|
|
132
|
+
g.manual_seed(self.seed + self.epoch)
|
|
133
|
+
indices = torch.randperm(len(self.dataset), generator=g).tolist()
|
|
134
|
+
else:
|
|
135
|
+
indices = list(range(len(self.dataset)))
|
|
136
|
+
|
|
137
|
+
yield from indices
|
|
138
|
+
|
|
139
|
+
logger.debug(f"InfiniteSampler finished epoch {self.epoch}")
|
|
140
|
+
self.epoch += 1
|
|
141
|
+
|
|
142
|
+
def __len__(self) -> int:
|
|
143
|
+
return len(self.dataset)
|
|
144
|
+
|
|
145
|
+
def set_epoch(self, epoch: int) -> None:
|
|
146
|
+
self.epoch = epoch
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
class InfiniteDistributedSampler(torch.utils.data.Sampler):
|
|
150
|
+
"""
|
|
151
|
+
Infinite distributed sampler that keeps a continuous shuffled stream per rank
|
|
152
|
+
"""
|
|
153
|
+
|
|
154
|
+
def __init__(self, dataset: Sized, num_replicas: int, rank: int, shuffle: bool, seed: int = 0) -> None:
|
|
155
|
+
super().__init__()
|
|
156
|
+
self.dataset = dataset
|
|
157
|
+
self.num_replicas = num_replicas
|
|
158
|
+
self.rank = rank
|
|
159
|
+
self.shuffle = shuffle
|
|
160
|
+
self.seed = seed
|
|
161
|
+
self.epoch = 0
|
|
162
|
+
self.num_samples = int(math.ceil(len(self.dataset) / self.num_replicas))
|
|
163
|
+
self.total_size = self.num_samples * self.num_replicas
|
|
164
|
+
|
|
165
|
+
def __iter__(self) -> Iterator[int]:
|
|
166
|
+
g = torch.Generator()
|
|
167
|
+
while True:
|
|
168
|
+
if self.shuffle is True:
|
|
169
|
+
g.manual_seed(self.seed + self.epoch)
|
|
170
|
+
indices = torch.randperm(len(self.dataset), generator=g).tolist()
|
|
171
|
+
else:
|
|
172
|
+
indices = list(range(len(self.dataset)))
|
|
173
|
+
|
|
174
|
+
if len(indices) < self.total_size:
|
|
175
|
+
indices += indices[: (self.total_size - len(indices))]
|
|
176
|
+
else:
|
|
177
|
+
indices = indices[: self.total_size]
|
|
178
|
+
|
|
179
|
+
indices = indices[self.rank : self.total_size : self.num_replicas]
|
|
180
|
+
assert len(indices) == self.num_samples
|
|
181
|
+
|
|
182
|
+
yield from indices
|
|
183
|
+
|
|
184
|
+
logger.debug(f"InfiniteDistributedSampler finished epoch {self.epoch}")
|
|
185
|
+
self.epoch += 1
|
|
112
186
|
|
|
113
187
|
def __len__(self) -> int:
|
|
114
|
-
return self.num_selected_samples
|
|
188
|
+
return self.num_samples
|
|
189
|
+
|
|
190
|
+
def set_epoch(self, epoch: int) -> None:
|
|
191
|
+
self.epoch = epoch
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
class InfiniteRASampler(torch.utils.data.Sampler):
|
|
195
|
+
"""
|
|
196
|
+
Infinite version of the repeated augmentation sampler
|
|
197
|
+
"""
|
|
198
|
+
|
|
199
|
+
def __init__(
|
|
200
|
+
self, dataset: Sized, num_replicas: int, rank: int, shuffle: bool, seed: int = 0, repetitions: int = 3
|
|
201
|
+
) -> None:
|
|
202
|
+
super().__init__()
|
|
203
|
+
self.dataset = dataset
|
|
204
|
+
self.num_replicas = num_replicas
|
|
205
|
+
self.rank = rank
|
|
206
|
+
self.epoch = 0
|
|
207
|
+
self.num_samples = int(math.ceil(len(self.dataset) * float(repetitions) / self.num_replicas))
|
|
208
|
+
self.total_size = self.num_samples * self.num_replicas
|
|
209
|
+
self.shuffle = shuffle
|
|
210
|
+
self.seed = seed
|
|
211
|
+
self.repetitions = repetitions
|
|
212
|
+
|
|
213
|
+
def __iter__(self) -> Iterator[int]:
|
|
214
|
+
g = torch.Generator()
|
|
215
|
+
while True:
|
|
216
|
+
if self.shuffle is True:
|
|
217
|
+
g.manual_seed(self.seed + self.epoch)
|
|
218
|
+
indices = torch.randperm(len(self.dataset), generator=g).tolist()
|
|
219
|
+
else:
|
|
220
|
+
indices = list(range(len(self.dataset)))
|
|
221
|
+
|
|
222
|
+
indices = [ele for ele in indices for _ in range(self.repetitions)]
|
|
223
|
+
if len(indices) < self.total_size:
|
|
224
|
+
indices += indices[: (self.total_size - len(indices))]
|
|
225
|
+
else:
|
|
226
|
+
indices = indices[: self.total_size]
|
|
227
|
+
|
|
228
|
+
# Shard by rank
|
|
229
|
+
indices = indices[self.rank : self.total_size : self.num_replicas]
|
|
230
|
+
assert len(indices) == self.num_samples
|
|
231
|
+
|
|
232
|
+
yield from indices
|
|
233
|
+
|
|
234
|
+
logger.debug(f"InfiniteRASampler finished epoch {self.epoch}")
|
|
235
|
+
self.epoch += 1
|
|
236
|
+
|
|
237
|
+
def __len__(self) -> int:
|
|
238
|
+
return self.num_samples
|
|
115
239
|
|
|
116
240
|
def set_epoch(self, epoch: int) -> None:
|
|
117
241
|
self.epoch = epoch
|
|
@@ -636,27 +760,87 @@ def get_amp_scaler(amp: bool, amp_dtype_str: str) -> tuple[Optional[torch.amp.Gr
|
|
|
636
760
|
return (scaler, amp_dtype)
|
|
637
761
|
|
|
638
762
|
|
|
763
|
+
@overload
|
|
639
764
|
def get_samplers(
|
|
640
|
-
args: argparse.Namespace,
|
|
641
|
-
|
|
642
|
-
|
|
643
|
-
|
|
644
|
-
|
|
645
|
-
training_dataset,
|
|
646
|
-
num_replicas=args.world_size,
|
|
647
|
-
rank=args.rank,
|
|
648
|
-
shuffle=True,
|
|
649
|
-
repetitions=args.ra_reps,
|
|
650
|
-
)
|
|
765
|
+
args: argparse.Namespace,
|
|
766
|
+
training_dataset: torch.utils.data.Dataset,
|
|
767
|
+
validation_dataset: torch.utils.data.Dataset,
|
|
768
|
+
infinite: bool = False,
|
|
769
|
+
) -> tuple[torch.utils.data.Sampler, torch.utils.data.Sampler]: ...
|
|
651
770
|
|
|
652
|
-
else:
|
|
653
|
-
train_sampler = torch.utils.data.distributed.DistributedSampler(training_dataset, shuffle=True)
|
|
654
771
|
|
|
655
|
-
|
|
772
|
+
@overload
|
|
773
|
+
def get_samplers(
|
|
774
|
+
args: argparse.Namespace,
|
|
775
|
+
training_dataset: torch.utils.data.Dataset,
|
|
776
|
+
validation_dataset: None = None,
|
|
777
|
+
infinite: bool = False,
|
|
778
|
+
) -> tuple[torch.utils.data.Sampler, None]: ...
|
|
779
|
+
|
|
780
|
+
|
|
781
|
+
def get_samplers(
|
|
782
|
+
args: argparse.Namespace,
|
|
783
|
+
training_dataset: torch.utils.data.Dataset,
|
|
784
|
+
validation_dataset: Optional[torch.utils.data.Dataset] = None,
|
|
785
|
+
infinite: bool = False,
|
|
786
|
+
) -> tuple[torch.utils.data.Sampler, Optional[torch.utils.data.Sampler]]:
|
|
787
|
+
if args.seed is None:
|
|
788
|
+
seed = int(torch.empty((), dtype=torch.int64).random_().item())
|
|
789
|
+
if is_dist_available_and_initialized() is True:
|
|
790
|
+
seed_tensor = torch.tensor(seed, dtype=torch.int64).cuda()
|
|
791
|
+
dist.broadcast(seed_tensor, src=0, async_op=False)
|
|
792
|
+
seed = int(seed_tensor.item())
|
|
793
|
+
else:
|
|
794
|
+
seed = args.seed
|
|
795
|
+
|
|
796
|
+
ra_sampler = getattr(args, "ra_sampler", False)
|
|
797
|
+
if args.distributed is True:
|
|
798
|
+
if infinite is True:
|
|
799
|
+
if ra_sampler is True:
|
|
800
|
+
train_sampler = InfiniteRASampler(
|
|
801
|
+
training_dataset,
|
|
802
|
+
num_replicas=args.world_size,
|
|
803
|
+
rank=args.rank,
|
|
804
|
+
shuffle=True,
|
|
805
|
+
seed=seed,
|
|
806
|
+
repetitions=args.ra_reps,
|
|
807
|
+
)
|
|
808
|
+
else:
|
|
809
|
+
train_sampler = InfiniteDistributedSampler(
|
|
810
|
+
training_dataset, num_replicas=args.world_size, rank=args.rank, shuffle=True, seed=seed
|
|
811
|
+
)
|
|
812
|
+
else:
|
|
813
|
+
if ra_sampler is True:
|
|
814
|
+
train_sampler = RASampler(
|
|
815
|
+
training_dataset,
|
|
816
|
+
num_replicas=args.world_size,
|
|
817
|
+
rank=args.rank,
|
|
818
|
+
shuffle=True,
|
|
819
|
+
seed=seed,
|
|
820
|
+
repetitions=args.ra_reps,
|
|
821
|
+
)
|
|
822
|
+
else:
|
|
823
|
+
train_sampler = torch.utils.data.distributed.DistributedSampler(
|
|
824
|
+
training_dataset, shuffle=True, seed=seed
|
|
825
|
+
)
|
|
826
|
+
|
|
827
|
+
if validation_dataset is None:
|
|
828
|
+
validation_sampler = None
|
|
829
|
+
else:
|
|
830
|
+
validation_sampler = torch.utils.data.distributed.DistributedSampler(validation_dataset, shuffle=False)
|
|
656
831
|
|
|
657
832
|
else:
|
|
658
|
-
|
|
659
|
-
|
|
833
|
+
if infinite is True:
|
|
834
|
+
train_sampler = InfiniteSampler(training_dataset, shuffle=True, seed=seed)
|
|
835
|
+
else:
|
|
836
|
+
generator = torch.Generator()
|
|
837
|
+
generator.manual_seed(seed)
|
|
838
|
+
train_sampler = torch.utils.data.RandomSampler(training_dataset, generator=generator)
|
|
839
|
+
|
|
840
|
+
if validation_dataset is None:
|
|
841
|
+
validation_sampler = None
|
|
842
|
+
else:
|
|
843
|
+
validation_sampler = torch.utils.data.SequentialSampler(validation_dataset)
|
|
660
844
|
|
|
661
845
|
return (train_sampler, validation_sampler)
|
|
662
846
|
|
|
@@ -98,6 +98,7 @@ class BatchRandomResizeCollator(DetectionCollator):
|
|
|
98
98
|
if isinstance(boxes, tv_tensors.BoundingBoxes) is False:
|
|
99
99
|
if boxes.numel() == 0:
|
|
100
100
|
boxes = boxes.reshape(0, 4)
|
|
101
|
+
|
|
101
102
|
boxes = tv_tensors.BoundingBoxes(
|
|
102
103
|
boxes, format=tv_tensors.BoundingBoxFormat.XYXY, canvas_size=F.get_size(image)
|
|
103
104
|
)
|
|
@@ -22,9 +22,19 @@ def make_wds_loader(
|
|
|
22
22
|
shuffle: bool = False,
|
|
23
23
|
*,
|
|
24
24
|
exact: bool = False,
|
|
25
|
+
infinite: bool = False,
|
|
25
26
|
) -> DataLoader:
|
|
27
|
+
assert exact is False or infinite is False
|
|
28
|
+
|
|
29
|
+
if infinite is True:
|
|
30
|
+
dataset_iterable = dataset.repeat()
|
|
31
|
+
elif exact is False:
|
|
32
|
+
dataset_iterable = dataset.repeat()
|
|
33
|
+
else:
|
|
34
|
+
dataset_iterable = dataset
|
|
35
|
+
|
|
26
36
|
dataloader = wds.WebLoader(
|
|
27
|
-
|
|
37
|
+
dataset_iterable,
|
|
28
38
|
batch_size=batch_size,
|
|
29
39
|
num_workers=num_workers,
|
|
30
40
|
prefetch_factor=prefetch_factor,
|
|
@@ -43,7 +53,7 @@ def make_wds_loader(
|
|
|
43
53
|
epoch_size = math.ceil(len(dataset) / (batch_size * world_size))
|
|
44
54
|
|
|
45
55
|
dataloader = dataloader.with_length(epoch_size, silent=True)
|
|
46
|
-
if exact is False:
|
|
56
|
+
if exact is False and infinite is False:
|
|
47
57
|
dataloader = dataloader.with_epoch(epoch_size)
|
|
48
58
|
|
|
49
59
|
return dataloader
|
|
@@ -14,11 +14,24 @@ logger = logging.getLogger(__name__)
|
|
|
14
14
|
|
|
15
15
|
|
|
16
16
|
_CACHED_KERNELS: dict[str, ModuleType] = {}
|
|
17
|
+
_CUSTOM_KERNELS_ENABLED = True
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def set_custom_kernels_enabled(enabled: bool) -> None:
|
|
21
|
+
global _CUSTOM_KERNELS_ENABLED # pylint: disable=global-statement
|
|
22
|
+
_CUSTOM_KERNELS_ENABLED = enabled
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def is_custom_kernels_enabled() -> bool:
|
|
26
|
+
if os.environ.get("DISABLE_CUSTOM_KERNELS", "0") == "1":
|
|
27
|
+
return False
|
|
28
|
+
|
|
29
|
+
return _CUSTOM_KERNELS_ENABLED
|
|
17
30
|
|
|
18
31
|
|
|
19
32
|
def load_msda() -> Optional[ModuleType]:
|
|
20
33
|
name = "msda"
|
|
21
|
-
if torch.cuda.is_available() is False or
|
|
34
|
+
if torch.cuda.is_available() is False or is_custom_kernels_enabled() is False:
|
|
22
35
|
return None
|
|
23
36
|
|
|
24
37
|
if name in _CACHED_KERNELS:
|
|
@@ -60,7 +73,7 @@ def load_msda() -> Optional[ModuleType]:
|
|
|
60
73
|
|
|
61
74
|
def load_swattention() -> Optional[ModuleType]:
|
|
62
75
|
name = "swattention"
|
|
63
|
-
if torch.cuda.is_available() is False or
|
|
76
|
+
if torch.cuda.is_available() is False or is_custom_kernels_enabled() is False:
|
|
64
77
|
return None
|
|
65
78
|
|
|
66
79
|
if name in _CACHED_KERNELS:
|
|
@@ -103,7 +116,7 @@ def load_swattention() -> Optional[ModuleType]:
|
|
|
103
116
|
|
|
104
117
|
def load_soft_nms() -> Optional[ModuleType]:
|
|
105
118
|
name = "soft_nms"
|
|
106
|
-
if
|
|
119
|
+
if is_custom_kernels_enabled() is False:
|
|
107
120
|
return None
|
|
108
121
|
|
|
109
122
|
if name in _CACHED_KERNELS:
|
|
@@ -120,14 +133,6 @@ def load_soft_nms() -> Optional[ModuleType]:
|
|
|
120
133
|
soft_nms: Optional[ModuleType] = load(
|
|
121
134
|
"soft_nms",
|
|
122
135
|
src_files,
|
|
123
|
-
with_cuda=True,
|
|
124
|
-
extra_cflags=["-DWITH_CUDA=1"],
|
|
125
|
-
extra_cuda_cflags=[
|
|
126
|
-
"-DCUDA_HAS_FP16=1",
|
|
127
|
-
"-D__CUDA_NO_HALF_OPERATORS__",
|
|
128
|
-
"-D__CUDA_NO_HALF_CONVERSIONS__",
|
|
129
|
-
"-D__CUDA_NO_HALF2_OPERATORS__",
|
|
130
|
-
],
|
|
131
136
|
)
|
|
132
137
|
|
|
133
138
|
if soft_nms is not None:
|
|
@@ -61,24 +61,23 @@ void update_sorting_order(torch::Tensor& boxes, torch::Tensor& scores, torch::Te
|
|
|
61
61
|
std::tie(max_score, t_max_idx) = torch::max(scores.index({Slice(idx + 1, None)}), 0);
|
|
62
62
|
|
|
63
63
|
// max_idx is computed from sliced data, therefore need to convert it to "global" max idx
|
|
64
|
-
auto max_idx = t_max_idx
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
}
|
|
64
|
+
auto max_idx = t_max_idx + (idx + 1);
|
|
65
|
+
auto should_swap = scores.index({idx}) < max_score;
|
|
66
|
+
|
|
67
|
+
auto boxes_idx = boxes.index({idx}).clone();
|
|
68
|
+
auto boxes_max = boxes.index({max_idx}).clone();
|
|
69
|
+
boxes.index_put_({idx}, torch::where(should_swap, boxes_max, boxes_idx));
|
|
70
|
+
boxes.index_put_({max_idx}, torch::where(should_swap, boxes_idx, boxes_max));
|
|
71
|
+
|
|
72
|
+
auto scores_idx = scores.index({idx}).clone();
|
|
73
|
+
auto scores_max = scores.index({max_idx}).clone();
|
|
74
|
+
scores.index_put_({idx}, torch::where(should_swap, scores_max, scores_idx));
|
|
75
|
+
scores.index_put_({max_idx}, torch::where(should_swap, scores_idx, scores_max));
|
|
76
|
+
|
|
77
|
+
auto areas_idx = areas.index({idx}).clone();
|
|
78
|
+
auto areas_max = areas.index({max_idx}).clone();
|
|
79
|
+
areas.index_put_({idx}, torch::where(should_swap, areas_max, areas_idx));
|
|
80
|
+
areas.index_put_({max_idx}, torch::where(should_swap, areas_idx, areas_max));
|
|
82
81
|
}
|
|
83
82
|
|
|
84
83
|
std::tuple<torch::Tensor, torch::Tensor> soft_nms(
|
|
@@ -268,14 +268,15 @@ class CaiT(BaseNet):
|
|
|
268
268
|
super().adjust_size(new_size)
|
|
269
269
|
|
|
270
270
|
# Add back class tokens
|
|
271
|
-
|
|
272
|
-
adjust_position_embedding(
|
|
271
|
+
with torch.no_grad():
|
|
272
|
+
pos_embed = adjust_position_embedding(
|
|
273
273
|
self.pos_embed,
|
|
274
274
|
(old_size[0] // self.patch_size[0], old_size[1] // self.patch_size[1]),
|
|
275
275
|
(new_size[0] // self.patch_size[0], new_size[1] // self.patch_size[1]),
|
|
276
276
|
0,
|
|
277
277
|
)
|
|
278
|
-
|
|
278
|
+
|
|
279
|
+
self.pos_embed = nn.Parameter(pos_embed)
|
|
279
280
|
|
|
280
281
|
|
|
281
282
|
registry.register_model_config(
|
|
@@ -195,6 +195,11 @@ class ConvNeXt_v1(DetectorBackbone, PreTrainEncoder, MaskedTokenRetentionMixin):
|
|
|
195
195
|
return self.features(x)
|
|
196
196
|
|
|
197
197
|
|
|
198
|
+
registry.register_model_config(
|
|
199
|
+
"convnext_v1_nano", # Not in the original v1, taken from v2
|
|
200
|
+
ConvNeXt_v1,
|
|
201
|
+
config={"in_channels": [80, 160, 320, 640], "num_layers": [2, 2, 8, 2], "drop_path_rate": 0.1},
|
|
202
|
+
)
|
|
198
203
|
registry.register_model_config(
|
|
199
204
|
"convnext_v1_tiny",
|
|
200
205
|
ConvNeXt_v1,
|
|
@@ -98,15 +98,17 @@ class Attention(nn.Module):
|
|
|
98
98
|
self.proj_drop = nn.Dropout(proj_drop)
|
|
99
99
|
|
|
100
100
|
def define_bias_table(self) -> None:
|
|
101
|
-
|
|
102
|
-
|
|
101
|
+
device = next(self.pos.parameters()).device
|
|
102
|
+
position_bias_h = torch.arange(1 - self.group_size[0], self.group_size[0], device=device)
|
|
103
|
+
position_bias_w = torch.arange(1 - self.group_size[1], self.group_size[1], device=device)
|
|
103
104
|
biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w], indexing="ij")) # 2, 2Wh-1, 2W2-1
|
|
104
105
|
biases = biases.flatten(1).transpose(0, 1).float()
|
|
105
106
|
self.biases = nn.Buffer(biases)
|
|
106
107
|
|
|
107
108
|
def define_relative_position_index(self) -> None:
|
|
108
|
-
|
|
109
|
-
|
|
109
|
+
device = self.biases.device
|
|
110
|
+
coords_h = torch.arange(self.group_size[0], device=device)
|
|
111
|
+
coords_w = torch.arange(self.group_size[1], device=device)
|
|
110
112
|
coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij")) # 2, Wh, Ww
|
|
111
113
|
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
|
112
114
|
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
|
@@ -430,32 +432,33 @@ class CrossFormer(DetectorBackbone):
|
|
|
430
432
|
|
|
431
433
|
new_patch_resolution = (new_size[0] // self.patch_sizes[0], new_size[1] // self.patch_sizes[0])
|
|
432
434
|
input_resolution = new_patch_resolution
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
m
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
m
|
|
441
|
-
|
|
442
|
-
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
m
|
|
448
|
-
|
|
449
|
-
m.
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
m.
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
|
|
456
|
-
m
|
|
457
|
-
|
|
458
|
-
|
|
435
|
+
with torch.no_grad():
|
|
436
|
+
for mod in self.body.modules():
|
|
437
|
+
if isinstance(mod, CrossFormerStage):
|
|
438
|
+
for m in mod.modules():
|
|
439
|
+
if isinstance(m, PatchMerging):
|
|
440
|
+
m.input_resolution = input_resolution
|
|
441
|
+
input_resolution = (input_resolution[0] // 2, input_resolution[1] // 2)
|
|
442
|
+
elif isinstance(m, CrossFormerBlock):
|
|
443
|
+
m.input_resolution = input_resolution
|
|
444
|
+
|
|
445
|
+
mod.resolution = input_resolution
|
|
446
|
+
|
|
447
|
+
new_group_size = (int(new_size[0] / (2**5)), int(new_size[1] / (2**5)))
|
|
448
|
+
for m in self.body.modules():
|
|
449
|
+
if isinstance(m, CrossFormerBlock):
|
|
450
|
+
m.group_size = new_group_size
|
|
451
|
+
if m.input_resolution[0] <= m.group_size[0]:
|
|
452
|
+
m.use_lda = False
|
|
453
|
+
m.group_size = (m.input_resolution[0], m.group_size[1])
|
|
454
|
+
if m.input_resolution[1] <= m.group_size[1]:
|
|
455
|
+
m.use_lda = False
|
|
456
|
+
m.group_size = (m.group_size[0], m.input_resolution[1])
|
|
457
|
+
|
|
458
|
+
elif isinstance(m, Attention):
|
|
459
|
+
m.group_size = new_group_size
|
|
460
|
+
m.define_bias_table()
|
|
461
|
+
m.define_relative_position_index()
|
|
459
462
|
|
|
460
463
|
|
|
461
464
|
registry.register_model_config(
|
|
@@ -359,9 +359,10 @@ class CrossViT(BaseNet):
|
|
|
359
359
|
old_w = old_size[1] // self.patch_size[i]
|
|
360
360
|
h = new_size[0] // self.patch_size[i]
|
|
361
361
|
w = new_size[1] // self.patch_size[i]
|
|
362
|
-
|
|
363
|
-
adjust_position_embedding(self.pos_embed[i], (old_h, old_w), (h, w), num_prefix_tokens=1)
|
|
364
|
-
|
|
362
|
+
with torch.no_grad():
|
|
363
|
+
pos_embed = adjust_position_embedding(self.pos_embed[i], (old_h, old_w), (h, w), num_prefix_tokens=1)
|
|
364
|
+
|
|
365
|
+
self.pos_embed[i] = nn.Parameter(pos_embed)
|
|
365
366
|
|
|
366
367
|
|
|
367
368
|
registry.register_model_config(
|
|
@@ -187,14 +187,14 @@ class DeiT(BaseNet):
|
|
|
187
187
|
num_prefix_tokens = 2
|
|
188
188
|
|
|
189
189
|
# Add back class tokens
|
|
190
|
-
|
|
191
|
-
adjust_position_embedding(
|
|
190
|
+
with torch.no_grad():
|
|
191
|
+
pos_embedding = adjust_position_embedding(
|
|
192
192
|
self.pos_embedding,
|
|
193
193
|
(old_size[0] // self.patch_size, old_size[1] // self.patch_size),
|
|
194
194
|
(new_size[0] // self.patch_size, new_size[1] // self.patch_size),
|
|
195
195
|
num_prefix_tokens,
|
|
196
196
|
)
|
|
197
|
-
)
|
|
197
|
+
self.pos_embedding = nn.Parameter(pos_embedding)
|
|
198
198
|
|
|
199
199
|
|
|
200
200
|
registry.register_model_config(
|
|
@@ -355,14 +355,14 @@ class DeiT3(DetectorBackbone, PreTrainEncoder, MaskedTokenOmissionMixin, MaskedT
|
|
|
355
355
|
num_prefix_tokens = 0
|
|
356
356
|
|
|
357
357
|
# Add back class tokens
|
|
358
|
-
|
|
359
|
-
adjust_position_embedding(
|
|
358
|
+
with torch.no_grad():
|
|
359
|
+
pos_embedding = adjust_position_embedding(
|
|
360
360
|
self.pos_embedding,
|
|
361
361
|
(old_size[0] // self.patch_size, old_size[1] // self.patch_size),
|
|
362
362
|
(new_size[0] // self.patch_size, new_size[1] // self.patch_size),
|
|
363
363
|
num_prefix_tokens,
|
|
364
364
|
)
|
|
365
|
-
)
|
|
365
|
+
self.pos_embedding = nn.Parameter(pos_embedding)
|
|
366
366
|
|
|
367
367
|
|
|
368
368
|
registry.register_model_config(
|
|
@@ -757,11 +757,8 @@ class Deformable_DETR(DetectionBaseNet):
|
|
|
757
757
|
for s, l, b in zip(scores, labels, boxes):
|
|
758
758
|
# Non-maximum suppression
|
|
759
759
|
if self.soft_nms is not None:
|
|
760
|
-
|
|
761
|
-
|
|
762
|
-
(soft_scores, keep) = self.soft_nms(b.cpu(), s.cpu(), l.cpu(), score_threshold=0.001)
|
|
763
|
-
keep = keep.to(device)
|
|
764
|
-
s[keep] = soft_scores.to(device)
|
|
760
|
+
(soft_scores, keep) = self.soft_nms(b, s, l, score_threshold=0.001)
|
|
761
|
+
s[keep] = soft_scores
|
|
765
762
|
|
|
766
763
|
b = b[keep]
|
|
767
764
|
s = s[keep]
|
|
@@ -465,11 +465,8 @@ class DETR(DetectionBaseNet):
|
|
|
465
465
|
for s, l, b in zip(scores, labels, boxes):
|
|
466
466
|
# Non-maximum suppression
|
|
467
467
|
if self.soft_nms is not None:
|
|
468
|
-
|
|
469
|
-
|
|
470
|
-
(soft_scores, keep) = self.soft_nms(b.cpu(), s.cpu(), l.cpu(), score_threshold=0.001)
|
|
471
|
-
keep = keep.to(device)
|
|
472
|
-
s[keep] = soft_scores.to(device)
|
|
468
|
+
(soft_scores, keep) = self.soft_nms(b, s, l, score_threshold=0.001)
|
|
469
|
+
s[keep] = soft_scores
|
|
473
470
|
|
|
474
471
|
b = b[keep]
|
|
475
472
|
s = s[keep]
|