birder 0.2.2__tar.gz → 0.3.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {birder-0.2.2 → birder-0.3.0}/PKG-INFO +4 -3
- {birder-0.2.2 → birder-0.3.0}/README.md +1 -1
- {birder-0.2.2 → birder-0.3.0}/birder/common/lib.py +2 -9
- {birder-0.2.2 → birder-0.3.0}/birder/common/training_cli.py +24 -0
- {birder-0.2.2 → birder-0.3.0}/birder/common/training_utils.py +338 -41
- {birder-0.2.2 → birder-0.3.0}/birder/data/collators/detection.py +11 -3
- {birder-0.2.2 → birder-0.3.0}/birder/data/dataloader/webdataset.py +12 -2
- {birder-0.2.2 → birder-0.3.0}/birder/data/datasets/coco.py +8 -10
- {birder-0.2.2 → birder-0.3.0}/birder/data/transforms/detection.py +30 -13
- {birder-0.2.2 → birder-0.3.0}/birder/inference/detection.py +108 -4
- birder-0.3.0/birder/inference/wbf.py +226 -0
- {birder-0.2.2 → birder-0.3.0}/birder/kernels/load_kernel.py +16 -11
- {birder-0.2.2 → birder-0.3.0}/birder/kernels/soft_nms/soft_nms.cpp +17 -18
- {birder-0.2.2 → birder-0.3.0}/birder/net/__init__.py +8 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/cait.py +4 -3
- {birder-0.2.2 → birder-0.3.0}/birder/net/convnext_v1.py +5 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/crossformer.py +33 -30
- {birder-0.2.2 → birder-0.3.0}/birder/net/crossvit.py +4 -3
- {birder-0.2.2 → birder-0.3.0}/birder/net/deit.py +3 -3
- {birder-0.2.2 → birder-0.3.0}/birder/net/deit3.py +3 -3
- {birder-0.2.2 → birder-0.3.0}/birder/net/detection/deformable_detr.py +2 -5
- {birder-0.2.2 → birder-0.3.0}/birder/net/detection/detr.py +2 -5
- {birder-0.2.2 → birder-0.3.0}/birder/net/detection/efficientdet.py +67 -93
- {birder-0.2.2 → birder-0.3.0}/birder/net/detection/fcos.py +2 -7
- {birder-0.2.2 → birder-0.3.0}/birder/net/detection/retinanet.py +2 -7
- {birder-0.2.2 → birder-0.3.0}/birder/net/detection/rt_detr_v1.py +2 -0
- birder-0.3.0/birder/net/detection/yolo_anchors.py +205 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/detection/yolo_v2.py +25 -24
- {birder-0.2.2 → birder-0.3.0}/birder/net/detection/yolo_v3.py +39 -40
- {birder-0.2.2 → birder-0.3.0}/birder/net/detection/yolo_v4.py +28 -26
- {birder-0.2.2 → birder-0.3.0}/birder/net/detection/yolo_v4_tiny.py +24 -20
- {birder-0.2.2 → birder-0.3.0}/birder/net/efficientformer_v1.py +15 -9
- {birder-0.2.2 → birder-0.3.0}/birder/net/efficientformer_v2.py +39 -29
- {birder-0.2.2 → birder-0.3.0}/birder/net/efficientvit_msft.py +9 -7
- {birder-0.2.2 → birder-0.3.0}/birder/net/fasternet.py +1 -1
- {birder-0.2.2 → birder-0.3.0}/birder/net/fastvit.py +1 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/flexivit.py +5 -4
- birder-0.3.0/birder/net/gc_vit.py +671 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/hiera.py +12 -9
- {birder-0.2.2 → birder-0.3.0}/birder/net/hornet.py +9 -7
- {birder-0.2.2 → birder-0.3.0}/birder/net/iformer.py +8 -6
- {birder-0.2.2 → birder-0.3.0}/birder/net/levit.py +42 -30
- birder-0.3.0/birder/net/lit_v1.py +472 -0
- birder-0.3.0/birder/net/lit_v1_tiny.py +357 -0
- birder-0.3.0/birder/net/lit_v2.py +436 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/maxvit.py +67 -55
- {birder-0.2.2 → birder-0.3.0}/birder/net/mobilenet_v4_hybrid.py +1 -1
- {birder-0.2.2 → birder-0.3.0}/birder/net/mobileone.py +1 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/mvit_v2.py +13 -12
- {birder-0.2.2 → birder-0.3.0}/birder/net/pit.py +4 -3
- {birder-0.2.2 → birder-0.3.0}/birder/net/pvt_v1.py +4 -1
- {birder-0.2.2 → birder-0.3.0}/birder/net/repghost.py +1 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/repvgg.py +1 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/repvit.py +1 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/resnet_v1.py +1 -1
- {birder-0.2.2 → birder-0.3.0}/birder/net/resnext.py +67 -25
- {birder-0.2.2 → birder-0.3.0}/birder/net/rope_deit3.py +5 -3
- {birder-0.2.2 → birder-0.3.0}/birder/net/rope_flexivit.py +7 -4
- {birder-0.2.2 → birder-0.3.0}/birder/net/rope_vit.py +10 -5
- {birder-0.2.2 → birder-0.3.0}/birder/net/se_resnet_v1.py +46 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/se_resnext.py +3 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/simple_vit.py +11 -8
- {birder-0.2.2 → birder-0.3.0}/birder/net/swin_transformer_v1.py +71 -68
- {birder-0.2.2 → birder-0.3.0}/birder/net/swin_transformer_v2.py +38 -31
- {birder-0.2.2 → birder-0.3.0}/birder/net/tiny_vit.py +20 -10
- {birder-0.2.2 → birder-0.3.0}/birder/net/transnext.py +38 -28
- {birder-0.2.2 → birder-0.3.0}/birder/net/vit.py +5 -19
- {birder-0.2.2 → birder-0.3.0}/birder/net/vit_parallel.py +5 -4
- {birder-0.2.2 → birder-0.3.0}/birder/net/vit_sam.py +38 -37
- {birder-0.2.2 → birder-0.3.0}/birder/net/vovnet_v1.py +15 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/vovnet_v2.py +31 -1
- birder-0.3.0/birder/ops/msda.py +203 -0
- birder-0.3.0/birder/ops/swattention.py +288 -0
- {birder-0.2.2 → birder-0.3.0}/birder/results/detection.py +4 -0
- {birder-0.2.2 → birder-0.3.0}/birder/scripts/benchmark.py +110 -32
- {birder-0.2.2 → birder-0.3.0}/birder/scripts/predict.py +8 -0
- {birder-0.2.2 → birder-0.3.0}/birder/scripts/predict_detection.py +18 -11
- {birder-0.2.2 → birder-0.3.0}/birder/scripts/train.py +48 -46
- {birder-0.2.2 → birder-0.3.0}/birder/scripts/train_barlow_twins.py +44 -45
- {birder-0.2.2 → birder-0.3.0}/birder/scripts/train_byol.py +44 -45
- {birder-0.2.2 → birder-0.3.0}/birder/scripts/train_capi.py +50 -49
- {birder-0.2.2 → birder-0.3.0}/birder/scripts/train_data2vec.py +45 -47
- {birder-0.2.2 → birder-0.3.0}/birder/scripts/train_data2vec2.py +45 -47
- {birder-0.2.2 → birder-0.3.0}/birder/scripts/train_detection.py +83 -50
- {birder-0.2.2 → birder-0.3.0}/birder/scripts/train_dino_v1.py +60 -47
- {birder-0.2.2 → birder-0.3.0}/birder/scripts/train_dino_v2.py +86 -52
- {birder-0.2.2 → birder-0.3.0}/birder/scripts/train_dino_v2_dist.py +84 -50
- {birder-0.2.2 → birder-0.3.0}/birder/scripts/train_franca.py +51 -52
- {birder-0.2.2 → birder-0.3.0}/birder/scripts/train_i_jepa.py +45 -47
- {birder-0.2.2 → birder-0.3.0}/birder/scripts/train_ibot.py +51 -53
- {birder-0.2.2 → birder-0.3.0}/birder/scripts/train_kd.py +194 -76
- {birder-0.2.2 → birder-0.3.0}/birder/scripts/train_mim.py +44 -45
- {birder-0.2.2 → birder-0.3.0}/birder/scripts/train_mmcr.py +44 -45
- {birder-0.2.2 → birder-0.3.0}/birder/scripts/train_rotnet.py +45 -46
- {birder-0.2.2 → birder-0.3.0}/birder/scripts/train_simclr.py +44 -45
- {birder-0.2.2 → birder-0.3.0}/birder/scripts/train_vicreg.py +44 -45
- {birder-0.2.2 → birder-0.3.0}/birder/tools/auto_anchors.py +20 -1
- {birder-0.2.2 → birder-0.3.0}/birder/tools/convert_model.py +18 -15
- birder-0.3.0/birder/tools/det_results.py +173 -0
- {birder-0.2.2 → birder-0.3.0}/birder/tools/pack.py +172 -103
- birder-0.3.0/birder/tools/quantize_model.py +162 -0
- {birder-0.2.2 → birder-0.3.0}/birder/tools/show_det_iterator.py +10 -1
- birder-0.3.0/birder/version.py +1 -0
- {birder-0.2.2 → birder-0.3.0}/birder.egg-info/PKG-INFO +4 -3
- {birder-0.2.2 → birder-0.3.0}/birder.egg-info/SOURCES.txt +7 -0
- {birder-0.2.2 → birder-0.3.0}/birder.egg-info/requires.txt +2 -1
- {birder-0.2.2 → birder-0.3.0}/requirements/_requirements-dev.txt +3 -1
- {birder-0.2.2 → birder-0.3.0}/tests/test_common.py +271 -16
- birder-0.3.0/tests/test_dataloaders.py +101 -0
- {birder-0.2.2 → birder-0.3.0}/tests/test_inference.py +69 -0
- {birder-0.2.2 → birder-0.3.0}/tests/test_kernels.py +13 -0
- {birder-0.2.2 → birder-0.3.0}/tests/test_model_registry.py +2 -2
- {birder-0.2.2 → birder-0.3.0}/tests/test_net.py +274 -177
- {birder-0.2.2 → birder-0.3.0}/tests/test_net_detection.py +44 -0
- {birder-0.2.2 → birder-0.3.0}/tests/test_transforms.py +9 -0
- birder-0.2.2/birder/ops/msda.py +0 -138
- birder-0.2.2/birder/ops/swattention.py +0 -225
- birder-0.2.2/birder/tools/det_results.py +0 -61
- birder-0.2.2/birder/tools/quantize_model.py +0 -156
- birder-0.2.2/birder/version.py +0 -1
- {birder-0.2.2 → birder-0.3.0}/LICENSE +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/__init__.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/adversarial/__init__.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/adversarial/base.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/adversarial/deepfool.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/adversarial/fgsm.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/adversarial/pgd.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/adversarial/simba.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/common/__init__.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/common/cli.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/common/fs_ops.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/common/masking.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/conf/__init__.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/conf/settings.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/data/__init__.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/data/collators/__init__.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/data/dataloader/__init__.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/data/datasets/__init__.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/data/datasets/directory.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/data/datasets/fake.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/data/datasets/webdataset.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/data/transforms/__init__.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/data/transforms/classification.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/data/transforms/mosaic.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/datahub/__init__.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/datahub/_lib.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/datahub/classification.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/inference/__init__.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/inference/classification.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/inference/data_parallel.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/introspection/__init__.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/introspection/attention_rollout.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/introspection/base.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/introspection/gradcam.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/introspection/guided_backprop.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/introspection/transformer_attribution.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/kernels/__init__.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/kernels/deformable_detr/cpu/ms_deform_attn_cpu.cpp +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/kernels/deformable_detr/cpu/ms_deform_attn_cpu.h +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/kernels/deformable_detr/cuda/ms_deform_attn_cuda.cu +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/kernels/deformable_detr/cuda/ms_deform_attn_cuda.h +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/kernels/deformable_detr/cuda/ms_deform_im2col_cuda.cuh +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/kernels/deformable_detr/ms_deform_attn.h +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/kernels/deformable_detr/vision.cpp +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/kernels/soft_nms/op.cpp +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/kernels/soft_nms/soft_nms.h +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/kernels/transnext/av_bw_kernel.cu +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/kernels/transnext/av_fw_kernel.cu +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/kernels/transnext/qk_bw_kernel.cu +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/kernels/transnext/qk_fw_kernel.cu +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/kernels/transnext/qk_rpb_bw_kernel.cu +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/kernels/transnext/qk_rpb_fw_kernel.cu +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/kernels/transnext/swattention.cpp +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/layers/__init__.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/layers/activations.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/layers/attention_pool.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/layers/ffn.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/layers/gem.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/layers/layer_norm.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/layers/layer_scale.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/model_registry/__init__.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/model_registry/manifest.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/model_registry/model_registry.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/alexnet.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/base.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/biformer.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/cas_vit.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/coat.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/conv2former.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/convmixer.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/convnext_v2.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/cspnet.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/cswin_transformer.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/darknet.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/davit.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/densenet.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/detection/__init__.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/detection/base.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/detection/faster_rcnn.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/detection/ssd.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/detection/ssdlite.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/detection/vitdet.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/dpn.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/edgenext.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/edgevit.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/efficientnet_lite.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/efficientnet_v1.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/efficientnet_v2.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/efficientvim.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/efficientvit_mit.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/focalnet.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/ghostnet_v1.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/ghostnet_v2.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/groupmixformer.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/hgnet_v1.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/hgnet_v2.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/hieradet.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/inception_next.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/inception_resnet_v1.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/inception_resnet_v2.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/inception_v3.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/inception_v4.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/metaformer.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/mim/__init__.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/mim/base.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/mim/crossmae.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/mim/fcmae.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/mim/mae_hiera.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/mim/mae_vit.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/mim/simmim.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/mnasnet.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/mobilenet_v1.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/mobilenet_v2.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/mobilenet_v3_large.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/mobilenet_v3_small.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/mobilenet_v4.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/mobilevit_v1.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/mobilevit_v2.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/moganet.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/nextvit.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/nfnet.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/pvt_v2.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/rdnet.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/regionvit.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/regnet.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/regnet_z.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/resmlp.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/resnest.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/resnet_v2.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/se_resnet_v2.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/sequencer2d.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/shufflenet_v1.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/shufflenet_v2.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/smt.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/squeezenet.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/squeezenext.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/ssl/__init__.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/ssl/barlow_twins.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/ssl/base.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/ssl/byol.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/ssl/capi.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/ssl/data2vec.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/ssl/data2vec2.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/ssl/dino_v1.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/ssl/dino_v2.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/ssl/franca.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/ssl/i_jepa.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/ssl/ibot.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/ssl/mmcr.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/ssl/simclr.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/ssl/sscd.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/ssl/vicreg.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/starnet.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/swiftformer.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/uniformer.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/van.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/vgg.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/vgg_reduced.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/wide_resnet.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/xception.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/net/xcit.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/ops/__init__.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/ops/soft_nms.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/optim/__init__.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/optim/lamb.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/optim/lars.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/py.typed +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/results/__init__.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/results/classification.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/results/gui.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/scheduler/__init__.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/scheduler/cooldown.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/scripts/__init__.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/scripts/__main__.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/scripts/evaluate.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/tools/__init__.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/tools/__main__.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/tools/adversarial.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/tools/avg_model.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/tools/download_model.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/tools/ensemble_model.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/tools/introspection.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/tools/labelme_to_coco.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/tools/list_models.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/tools/model_info.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/tools/results.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/tools/show_iterator.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/tools/similarity.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/tools/stats.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/tools/verify_coco.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/tools/verify_directory.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder/tools/voc_to_coco.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder.egg-info/dependency_links.txt +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder.egg-info/entry_points.txt +0 -0
- {birder-0.2.2 → birder-0.3.0}/birder.egg-info/top_level.txt +0 -0
- {birder-0.2.2 → birder-0.3.0}/pyproject.toml +0 -0
- {birder-0.2.2 → birder-0.3.0}/requirements/requirements-hf.txt +0 -0
- {birder-0.2.2 → birder-0.3.0}/requirements/requirements.txt +0 -0
- {birder-0.2.2 → birder-0.3.0}/setup.cfg +0 -0
- {birder-0.2.2 → birder-0.3.0}/tests/test_adversarial.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/tests/test_collators.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/tests/test_datasets.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/tests/test_introspection.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/tests/test_layers.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/tests/test_net_mim.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/tests/test_net_ssl.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/tests/test_ops.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/tests/test_optim.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/tests/test_results.py +0 -0
- {birder-0.2.2 → birder-0.3.0}/tests/test_scheduler.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: birder
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.3.0
|
|
4
4
|
Summary: An open-source computer vision framework for wildlife image analysis, featuring state-of-the-art models for species classification and detection.
|
|
5
5
|
Author: Ofer Hasson
|
|
6
6
|
License-Expression: Apache-2.0
|
|
@@ -45,7 +45,7 @@ Provides-Extra: dev
|
|
|
45
45
|
Requires-Dist: altair~=5.5.0; extra == "dev"
|
|
46
46
|
Requires-Dist: bandit~=1.9.2; extra == "dev"
|
|
47
47
|
Requires-Dist: black~=25.12.0; extra == "dev"
|
|
48
|
-
Requires-Dist: build~=1.
|
|
48
|
+
Requires-Dist: build~=1.4.0; extra == "dev"
|
|
49
49
|
Requires-Dist: bumpver~=2025.1131; extra == "dev"
|
|
50
50
|
Requires-Dist: captum~=0.7.0; extra == "dev"
|
|
51
51
|
Requires-Dist: coverage~=7.13.1; extra == "dev"
|
|
@@ -66,6 +66,7 @@ Requires-Dist: pytest; extra == "dev"
|
|
|
66
66
|
Requires-Dist: requests~=2.32.5; extra == "dev"
|
|
67
67
|
Requires-Dist: safetensors~=0.7.0; extra == "dev"
|
|
68
68
|
Requires-Dist: setuptools; extra == "dev"
|
|
69
|
+
Requires-Dist: torchao~=0.15.0; extra == "dev"
|
|
69
70
|
Requires-Dist: torchprofile==0.0.4; extra == "dev"
|
|
70
71
|
Requires-Dist: twine~=6.2.0; extra == "dev"
|
|
71
72
|
Requires-Dist: types-requests~=2.32.4; extra == "dev"
|
|
@@ -208,7 +209,7 @@ For detailed information about these datasets, including descriptions, citations
|
|
|
208
209
|
|
|
209
210
|
## Detection
|
|
210
211
|
|
|
211
|
-
Detection training and inference are available, see [docs/
|
|
212
|
+
Detection training and inference are available, see [docs/training_scripts.md](docs/training_scripts.md) and
|
|
212
213
|
[docs/inference.md](docs/inference.md). APIs and model coverage may evolve as detection support matures.
|
|
213
214
|
|
|
214
215
|
## Project Status and Contributions
|
|
@@ -129,7 +129,7 @@ For detailed information about these datasets, including descriptions, citations
|
|
|
129
129
|
|
|
130
130
|
## Detection
|
|
131
131
|
|
|
132
|
-
Detection training and inference are available, see [docs/
|
|
132
|
+
Detection training and inference are available, see [docs/training_scripts.md](docs/training_scripts.md) and
|
|
133
133
|
[docs/inference.md](docs/inference.md). APIs and model coverage may evolve as detection support matures.
|
|
134
134
|
|
|
135
135
|
## Project Status and Contributions
|
|
@@ -1,11 +1,7 @@
|
|
|
1
1
|
import os
|
|
2
|
-
import random
|
|
3
2
|
from typing import Any
|
|
4
3
|
from typing import Optional
|
|
5
4
|
|
|
6
|
-
import numpy as np
|
|
7
|
-
import torch
|
|
8
|
-
|
|
9
5
|
from birder.conf import settings
|
|
10
6
|
from birder.data.transforms.classification import RGBType
|
|
11
7
|
from birder.model_registry import registry
|
|
@@ -19,11 +15,8 @@ from birder.net.ssl.base import SSLBaseNet
|
|
|
19
15
|
from birder.version import __version__
|
|
20
16
|
|
|
21
17
|
|
|
22
|
-
def
|
|
23
|
-
|
|
24
|
-
torch.cuda.manual_seed_all(seed)
|
|
25
|
-
np.random.seed(seed)
|
|
26
|
-
random.seed(seed)
|
|
18
|
+
def env_bool(name: str) -> bool:
|
|
19
|
+
return os.environ.get(name, "").lower() in {"1", "true", "yes", "on"}
|
|
27
20
|
|
|
28
21
|
|
|
29
22
|
def get_size_from_signature(signature: SignatureType | DetectionSignatureType) -> tuple[int, int]:
|
|
@@ -5,6 +5,7 @@ import typing
|
|
|
5
5
|
from typing import Optional
|
|
6
6
|
from typing import get_args
|
|
7
7
|
|
|
8
|
+
from birder.common.cli import FlexibleDictAction
|
|
8
9
|
from birder.common.cli import ValidationError
|
|
9
10
|
from birder.common.training_utils import OptimizerType
|
|
10
11
|
from birder.common.training_utils import SchedulerType
|
|
@@ -82,11 +83,23 @@ def add_lr_wd_args(parser: argparse.ArgumentParser, backbone_lr: bool = False, w
|
|
|
82
83
|
metavar="WD",
|
|
83
84
|
help="weight decay for embedding parameters for vision transformer models",
|
|
84
85
|
)
|
|
86
|
+
group.add_argument(
|
|
87
|
+
"--custom-layer-wd",
|
|
88
|
+
action=FlexibleDictAction,
|
|
89
|
+
metavar="LAYER=WD",
|
|
90
|
+
help="custom weight decay for specific layers by name (e.g., offset_conv=0.0)",
|
|
91
|
+
)
|
|
85
92
|
group.add_argument("--layer-decay", type=float, help="layer-wise learning rate decay (LLRD)")
|
|
86
93
|
group.add_argument("--layer-decay-min-scale", type=float, help="minimum layer scale factor clamp value")
|
|
87
94
|
group.add_argument(
|
|
88
95
|
"--layer-decay-no-opt-scale", type=float, help="layer scale threshold below which parameters are frozen"
|
|
89
96
|
)
|
|
97
|
+
group.add_argument(
|
|
98
|
+
"--custom-layer-lr-scale",
|
|
99
|
+
action=FlexibleDictAction,
|
|
100
|
+
metavar="LAYER=SCALE",
|
|
101
|
+
help="custom lr_scale for specific layers by name (e.g., offset_conv=0.01,attention=0.5)",
|
|
102
|
+
)
|
|
90
103
|
|
|
91
104
|
|
|
92
105
|
def add_lr_scheduler_args(parser: argparse.ArgumentParser) -> None:
|
|
@@ -185,6 +198,11 @@ def add_detection_input_args(parser: argparse.ArgumentParser) -> None:
|
|
|
185
198
|
action="store_true",
|
|
186
199
|
help="enable random square resize once per batch (capped by max(--size))",
|
|
187
200
|
)
|
|
201
|
+
group.add_argument(
|
|
202
|
+
"--multiscale-min-size",
|
|
203
|
+
type=int,
|
|
204
|
+
help="minimum short-edge size for multiscale lists (rounded up to nearest multiple of 32)",
|
|
205
|
+
)
|
|
188
206
|
|
|
189
207
|
|
|
190
208
|
def add_training_schedule_args(parser: argparse.ArgumentParser, default_epochs: int = 100) -> None:
|
|
@@ -193,6 +211,12 @@ def add_training_schedule_args(parser: argparse.ArgumentParser, default_epochs:
|
|
|
193
211
|
group.add_argument(
|
|
194
212
|
"--stop-epoch", type=int, metavar="N", help="epoch to stop the training at (multi stage training)"
|
|
195
213
|
)
|
|
214
|
+
group.add_argument(
|
|
215
|
+
"--steps-per-epoch",
|
|
216
|
+
type=int,
|
|
217
|
+
metavar="N",
|
|
218
|
+
help="virtual epoch length in steps, leave unset to use the full dataset",
|
|
219
|
+
)
|
|
196
220
|
group.add_argument("--warmup-epochs", type=int, metavar="N", help="number of warmup epochs")
|
|
197
221
|
group.add_argument("--warmup-steps", type=int, metavar="N", help="number of warmup optimizer steps")
|
|
198
222
|
group.add_argument("--cooldown-epochs", type=int, metavar="N", help="number of cooldown epochs (linear to zero)")
|
|
@@ -3,8 +3,10 @@ import contextlib
|
|
|
3
3
|
import logging
|
|
4
4
|
import math
|
|
5
5
|
import os
|
|
6
|
+
import random
|
|
6
7
|
import re
|
|
7
8
|
import subprocess
|
|
9
|
+
import sys
|
|
8
10
|
from collections import deque
|
|
9
11
|
from collections.abc import Callable
|
|
10
12
|
from collections.abc import Generator
|
|
@@ -15,6 +17,7 @@ from typing import Any
|
|
|
15
17
|
from typing import Literal
|
|
16
18
|
from typing import Optional
|
|
17
19
|
from typing import Sized
|
|
20
|
+
from typing import overload
|
|
18
21
|
|
|
19
22
|
import numpy as np
|
|
20
23
|
import torch
|
|
@@ -29,12 +32,25 @@ from birder.data.transforms.classification import training_preset
|
|
|
29
32
|
from birder.optim import Lamb
|
|
30
33
|
from birder.optim import Lars
|
|
31
34
|
from birder.scheduler import CooldownLR
|
|
35
|
+
from birder.version import __version__ as birder_version
|
|
32
36
|
|
|
33
37
|
logger = logging.getLogger(__name__)
|
|
34
38
|
|
|
35
39
|
OptimizerType = Literal["sgd", "rmsprop", "adam", "adamw", "nadam", "nadamw", "lamb", "lambw", "lars"]
|
|
36
40
|
SchedulerType = Literal["constant", "step", "multistep", "cosine", "polynomial"]
|
|
37
41
|
|
|
42
|
+
###############################################################################
|
|
43
|
+
# Core Utilities
|
|
44
|
+
###############################################################################
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def set_random_seeds(seed: int) -> None:
|
|
48
|
+
torch.manual_seed(seed)
|
|
49
|
+
torch.cuda.manual_seed_all(seed)
|
|
50
|
+
np.random.seed(seed)
|
|
51
|
+
random.seed(seed)
|
|
52
|
+
|
|
53
|
+
|
|
38
54
|
###############################################################################
|
|
39
55
|
# Data Sampling
|
|
40
56
|
###############################################################################
|
|
@@ -55,13 +71,7 @@ class RASampler(torch.utils.data.Sampler):
|
|
|
55
71
|
"""
|
|
56
72
|
|
|
57
73
|
def __init__(
|
|
58
|
-
self,
|
|
59
|
-
dataset: Sized,
|
|
60
|
-
num_replicas: int,
|
|
61
|
-
rank: int,
|
|
62
|
-
shuffle: bool,
|
|
63
|
-
seed: int = 0,
|
|
64
|
-
repetitions: int = 3,
|
|
74
|
+
self, dataset: Sized, num_replicas: int, rank: int, shuffle: bool, seed: int = 0, repetitions: int = 3
|
|
65
75
|
) -> None:
|
|
66
76
|
super().__init__()
|
|
67
77
|
self.dataset = dataset
|
|
@@ -70,12 +80,11 @@ class RASampler(torch.utils.data.Sampler):
|
|
|
70
80
|
self.epoch = 0
|
|
71
81
|
self.num_samples = int(math.ceil(len(self.dataset) * float(repetitions) / self.num_replicas))
|
|
72
82
|
self.total_size = self.num_samples * self.num_replicas
|
|
73
|
-
self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
|
|
74
83
|
self.shuffle = shuffle
|
|
75
84
|
self.seed = seed
|
|
76
85
|
self.repetitions = repetitions
|
|
77
86
|
|
|
78
|
-
def __iter__(self) -> Iterator[
|
|
87
|
+
def __iter__(self) -> Iterator[int]:
|
|
79
88
|
if self.shuffle is True:
|
|
80
89
|
# Deterministically shuffle based on epoch
|
|
81
90
|
g = torch.Generator()
|
|
@@ -85,18 +94,148 @@ class RASampler(torch.utils.data.Sampler):
|
|
|
85
94
|
indices = list(range(len(self.dataset)))
|
|
86
95
|
|
|
87
96
|
# Add extra samples to make it evenly divisible
|
|
88
|
-
indices = [ele for ele in indices for
|
|
89
|
-
indices
|
|
90
|
-
|
|
97
|
+
indices = [ele for ele in indices for _ in range(self.repetitions)]
|
|
98
|
+
if len(indices) < self.total_size:
|
|
99
|
+
indices += indices[: (self.total_size - len(indices))]
|
|
100
|
+
else:
|
|
101
|
+
indices = indices[: self.total_size]
|
|
91
102
|
|
|
92
|
-
#
|
|
103
|
+
# Shard by rank
|
|
93
104
|
indices = indices[self.rank : self.total_size : self.num_replicas]
|
|
94
105
|
assert len(indices) == self.num_samples
|
|
95
106
|
|
|
96
|
-
|
|
107
|
+
yield from indices
|
|
108
|
+
|
|
109
|
+
def __len__(self) -> int:
|
|
110
|
+
return self.num_samples
|
|
111
|
+
|
|
112
|
+
def set_epoch(self, epoch: int) -> None:
|
|
113
|
+
self.epoch = epoch
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
class InfiniteSampler(torch.utils.data.Sampler):
|
|
117
|
+
"""
|
|
118
|
+
Infinite sampler that loops indefinitely over the dataset
|
|
119
|
+
"""
|
|
120
|
+
|
|
121
|
+
def __init__(self, dataset: Sized, shuffle: bool, seed: int = 0) -> None:
|
|
122
|
+
super().__init__()
|
|
123
|
+
self.dataset = dataset
|
|
124
|
+
self.shuffle = shuffle
|
|
125
|
+
self.seed = seed
|
|
126
|
+
self.epoch = 0
|
|
127
|
+
|
|
128
|
+
def __iter__(self) -> Iterator[int]:
|
|
129
|
+
g = torch.Generator()
|
|
130
|
+
while True:
|
|
131
|
+
if self.shuffle is True:
|
|
132
|
+
g.manual_seed(self.seed + self.epoch)
|
|
133
|
+
indices = torch.randperm(len(self.dataset), generator=g).tolist()
|
|
134
|
+
else:
|
|
135
|
+
indices = list(range(len(self.dataset)))
|
|
136
|
+
|
|
137
|
+
yield from indices
|
|
138
|
+
|
|
139
|
+
logger.debug(f"InfiniteSampler finished epoch {self.epoch}")
|
|
140
|
+
self.epoch += 1
|
|
141
|
+
|
|
142
|
+
def __len__(self) -> int:
|
|
143
|
+
return len(self.dataset)
|
|
144
|
+
|
|
145
|
+
def set_epoch(self, epoch: int) -> None:
|
|
146
|
+
self.epoch = epoch
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
class InfiniteDistributedSampler(torch.utils.data.Sampler):
|
|
150
|
+
"""
|
|
151
|
+
Infinite distributed sampler that keeps a continuous shuffled stream per rank
|
|
152
|
+
"""
|
|
153
|
+
|
|
154
|
+
def __init__(self, dataset: Sized, num_replicas: int, rank: int, shuffle: bool, seed: int = 0) -> None:
|
|
155
|
+
super().__init__()
|
|
156
|
+
self.dataset = dataset
|
|
157
|
+
self.num_replicas = num_replicas
|
|
158
|
+
self.rank = rank
|
|
159
|
+
self.shuffle = shuffle
|
|
160
|
+
self.seed = seed
|
|
161
|
+
self.epoch = 0
|
|
162
|
+
self.num_samples = int(math.ceil(len(self.dataset) / self.num_replicas))
|
|
163
|
+
self.total_size = self.num_samples * self.num_replicas
|
|
164
|
+
|
|
165
|
+
def __iter__(self) -> Iterator[int]:
|
|
166
|
+
g = torch.Generator()
|
|
167
|
+
while True:
|
|
168
|
+
if self.shuffle is True:
|
|
169
|
+
g.manual_seed(self.seed + self.epoch)
|
|
170
|
+
indices = torch.randperm(len(self.dataset), generator=g).tolist()
|
|
171
|
+
else:
|
|
172
|
+
indices = list(range(len(self.dataset)))
|
|
173
|
+
|
|
174
|
+
if len(indices) < self.total_size:
|
|
175
|
+
indices += indices[: (self.total_size - len(indices))]
|
|
176
|
+
else:
|
|
177
|
+
indices = indices[: self.total_size]
|
|
178
|
+
|
|
179
|
+
indices = indices[self.rank : self.total_size : self.num_replicas]
|
|
180
|
+
assert len(indices) == self.num_samples
|
|
181
|
+
|
|
182
|
+
yield from indices
|
|
183
|
+
|
|
184
|
+
logger.debug(f"InfiniteDistributedSampler finished epoch {self.epoch}")
|
|
185
|
+
self.epoch += 1
|
|
186
|
+
|
|
187
|
+
def __len__(self) -> int:
|
|
188
|
+
return self.num_samples
|
|
189
|
+
|
|
190
|
+
def set_epoch(self, epoch: int) -> None:
|
|
191
|
+
self.epoch = epoch
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
class InfiniteRASampler(torch.utils.data.Sampler):
|
|
195
|
+
"""
|
|
196
|
+
Infinite version of the repeated augmentation sampler
|
|
197
|
+
"""
|
|
198
|
+
|
|
199
|
+
def __init__(
|
|
200
|
+
self, dataset: Sized, num_replicas: int, rank: int, shuffle: bool, seed: int = 0, repetitions: int = 3
|
|
201
|
+
) -> None:
|
|
202
|
+
super().__init__()
|
|
203
|
+
self.dataset = dataset
|
|
204
|
+
self.num_replicas = num_replicas
|
|
205
|
+
self.rank = rank
|
|
206
|
+
self.epoch = 0
|
|
207
|
+
self.num_samples = int(math.ceil(len(self.dataset) * float(repetitions) / self.num_replicas))
|
|
208
|
+
self.total_size = self.num_samples * self.num_replicas
|
|
209
|
+
self.shuffle = shuffle
|
|
210
|
+
self.seed = seed
|
|
211
|
+
self.repetitions = repetitions
|
|
212
|
+
|
|
213
|
+
def __iter__(self) -> Iterator[int]:
|
|
214
|
+
g = torch.Generator()
|
|
215
|
+
while True:
|
|
216
|
+
if self.shuffle is True:
|
|
217
|
+
g.manual_seed(self.seed + self.epoch)
|
|
218
|
+
indices = torch.randperm(len(self.dataset), generator=g).tolist()
|
|
219
|
+
else:
|
|
220
|
+
indices = list(range(len(self.dataset)))
|
|
221
|
+
|
|
222
|
+
indices = [ele for ele in indices for _ in range(self.repetitions)]
|
|
223
|
+
if len(indices) < self.total_size:
|
|
224
|
+
indices += indices[: (self.total_size - len(indices))]
|
|
225
|
+
else:
|
|
226
|
+
indices = indices[: self.total_size]
|
|
227
|
+
|
|
228
|
+
# Shard by rank
|
|
229
|
+
indices = indices[self.rank : self.total_size : self.num_replicas]
|
|
230
|
+
assert len(indices) == self.num_samples
|
|
231
|
+
|
|
232
|
+
yield from indices
|
|
233
|
+
|
|
234
|
+
logger.debug(f"InfiniteRASampler finished epoch {self.epoch}")
|
|
235
|
+
self.epoch += 1
|
|
97
236
|
|
|
98
237
|
def __len__(self) -> int:
|
|
99
|
-
return self.
|
|
238
|
+
return self.num_samples
|
|
100
239
|
|
|
101
240
|
def set_epoch(self, epoch: int) -> None:
|
|
102
241
|
self.epoch = epoch
|
|
@@ -207,13 +346,16 @@ def count_layers(model: torch.nn.Module) -> int:
|
|
|
207
346
|
def optimizer_parameter_groups(
|
|
208
347
|
model: torch.nn.Module,
|
|
209
348
|
weight_decay: float,
|
|
349
|
+
base_lr: float,
|
|
210
350
|
norm_weight_decay: Optional[float] = None,
|
|
211
351
|
custom_keys_weight_decay: Optional[list[tuple[str, float]]] = None,
|
|
352
|
+
custom_layer_weight_decay: Optional[dict[str, float]] = None,
|
|
212
353
|
layer_decay: Optional[float] = None,
|
|
213
354
|
layer_decay_min_scale: Optional[float] = None,
|
|
214
355
|
layer_decay_no_opt_scale: Optional[float] = None,
|
|
215
356
|
bias_lr: Optional[float] = None,
|
|
216
357
|
backbone_lr: Optional[float] = None,
|
|
358
|
+
custom_layer_lr_scale: Optional[dict[str, float]] = None,
|
|
217
359
|
) -> list[dict[str, Any]]:
|
|
218
360
|
"""
|
|
219
361
|
Return parameter groups for optimizers with per-parameter group weight decay.
|
|
@@ -233,11 +375,16 @@ def optimizer_parameter_groups(
|
|
|
233
375
|
The PyTorch model whose parameters will be grouped for optimization.
|
|
234
376
|
weight_decay
|
|
235
377
|
Default weight decay (L2 regularization) value applied to parameters.
|
|
378
|
+
base_lr
|
|
379
|
+
Base learning rate that will be scaled by lr_scale factors for each parameter group.
|
|
236
380
|
norm_weight_decay
|
|
237
381
|
Weight decay value specifically for normalization layers. If None, uses weight_decay.
|
|
238
382
|
custom_keys_weight_decay
|
|
239
383
|
List of (parameter_name, weight_decay) tuples for applying custom weight decay
|
|
240
384
|
values to specific parameters by name matching.
|
|
385
|
+
custom_layer_weight_decay
|
|
386
|
+
Dictionary mapping layer name substrings to custom weight decay values.
|
|
387
|
+
Applied to parameters whose names contain the specified keys.
|
|
241
388
|
layer_decay
|
|
242
389
|
Layer-wise learning rate decay factor.
|
|
243
390
|
layer_decay_min_scale
|
|
@@ -248,6 +395,9 @@ def optimizer_parameter_groups(
|
|
|
248
395
|
Custom learning rate for bias parameters (parameters ending with '.bias').
|
|
249
396
|
backbone_lr
|
|
250
397
|
Custom learning rate for backbone parameters (parameters starting with 'backbone.').
|
|
398
|
+
custom_layer_lr_scale
|
|
399
|
+
Dictionary mapping layer name substrings to custom lr_scale values.
|
|
400
|
+
Applied to parameters whose names contain the specified keys.
|
|
251
401
|
|
|
252
402
|
Returns
|
|
253
403
|
-------
|
|
@@ -291,14 +441,14 @@ def optimizer_parameter_groups(
|
|
|
291
441
|
if layer_decay is not None:
|
|
292
442
|
layer_max = num_layers - 1
|
|
293
443
|
layer_scales = [max(layer_decay_min_scale, layer_decay ** (layer_max - i)) for i in range(num_layers)]
|
|
294
|
-
logger.info(f"Layer scaling
|
|
444
|
+
logger.info(f"Layer scaling ranges from {min(layer_scales)} to {max(layer_scales)} across {num_layers} layers")
|
|
295
445
|
|
|
296
446
|
# Set weight decay and layer decay
|
|
297
447
|
idx = 0
|
|
298
448
|
params = []
|
|
299
449
|
module_stack_with_prefix = [(model, "")]
|
|
300
450
|
visited_modules = []
|
|
301
|
-
while len(module_stack_with_prefix) > 0:
|
|
451
|
+
while len(module_stack_with_prefix) > 0: # pylint: disable=too-many-nested-blocks
|
|
302
452
|
skip_module = False
|
|
303
453
|
(module, prefix) = module_stack_with_prefix.pop()
|
|
304
454
|
if id(module) in visited_modules:
|
|
@@ -324,13 +474,35 @@ def optimizer_parameter_groups(
|
|
|
324
474
|
for key, custom_wd in custom_keys_weight_decay:
|
|
325
475
|
target_name_for_custom_key = f"{prefix}.{name}" if prefix != "" and "." in key else name
|
|
326
476
|
if key == target_name_for_custom_key:
|
|
477
|
+
# Calculate lr_scale (from layer_decay or custom_layer_lr_scale)
|
|
478
|
+
lr_scale = 1.0 if layer_decay is None else layer_scales[idx]
|
|
479
|
+
if custom_layer_lr_scale is not None:
|
|
480
|
+
for layer_name_key, custom_scale in custom_layer_lr_scale.items():
|
|
481
|
+
if layer_name_key in target_name:
|
|
482
|
+
lr_scale = custom_scale
|
|
483
|
+
break
|
|
484
|
+
|
|
485
|
+
# Apply custom layer weight decay (substring matching)
|
|
486
|
+
wd = custom_wd
|
|
487
|
+
if custom_layer_weight_decay is not None:
|
|
488
|
+
for layer_name_key, custom_wd_value in custom_layer_weight_decay.items():
|
|
489
|
+
if layer_name_key in target_name:
|
|
490
|
+
wd = custom_wd_value
|
|
491
|
+
break
|
|
492
|
+
|
|
327
493
|
d = {
|
|
328
494
|
"params": p,
|
|
329
|
-
"weight_decay":
|
|
330
|
-
"lr_scale":
|
|
495
|
+
"weight_decay": wd,
|
|
496
|
+
"lr_scale": lr_scale, # Used only for reference/debugging
|
|
331
497
|
}
|
|
332
|
-
|
|
498
|
+
|
|
499
|
+
# Apply learning rate based on priority: bias_lr > backbone_lr > lr_scale
|
|
500
|
+
if bias_lr is not None and target_name.endswith(".bias") is True:
|
|
501
|
+
d["lr"] = bias_lr
|
|
502
|
+
elif backbone_lr is not None and target_name.startswith("backbone.") is True:
|
|
333
503
|
d["lr"] = backbone_lr
|
|
504
|
+
elif lr_scale != 1.0:
|
|
505
|
+
d["lr"] = base_lr * lr_scale
|
|
334
506
|
|
|
335
507
|
params.append(d)
|
|
336
508
|
is_custom_key = True
|
|
@@ -342,16 +514,34 @@ def optimizer_parameter_groups(
|
|
|
342
514
|
else:
|
|
343
515
|
wd = weight_decay
|
|
344
516
|
|
|
517
|
+
# Apply custom layer weight decay (substring matching)
|
|
518
|
+
if custom_layer_weight_decay is not None:
|
|
519
|
+
for layer_name_key, custom_wd_value in custom_layer_weight_decay.items():
|
|
520
|
+
if layer_name_key in target_name:
|
|
521
|
+
wd = custom_wd_value
|
|
522
|
+
break
|
|
523
|
+
|
|
524
|
+
# Calculate lr_scale (from layer_decay or custom_layer_lr_scale)
|
|
525
|
+
lr_scale = 1.0 if layer_decay is None else layer_scales[idx]
|
|
526
|
+
if custom_layer_lr_scale is not None:
|
|
527
|
+
for layer_name_key, custom_scale in custom_layer_lr_scale.items():
|
|
528
|
+
if layer_name_key in target_name:
|
|
529
|
+
lr_scale = custom_scale
|
|
530
|
+
break
|
|
531
|
+
|
|
345
532
|
d = {
|
|
346
533
|
"params": p,
|
|
347
534
|
"weight_decay": wd,
|
|
348
|
-
"lr_scale":
|
|
535
|
+
"lr_scale": lr_scale, # Used only for reference/debugging
|
|
349
536
|
}
|
|
350
|
-
if backbone_lr is not None and target_name.startswith("backbone.") is True:
|
|
351
|
-
d["lr"] = backbone_lr
|
|
352
537
|
|
|
538
|
+
# Apply learning rate based on priority: bias_lr > backbone_lr > lr_scale
|
|
353
539
|
if bias_lr is not None and target_name.endswith(".bias") is True:
|
|
354
540
|
d["lr"] = bias_lr
|
|
541
|
+
elif backbone_lr is not None and target_name.startswith("backbone.") is True:
|
|
542
|
+
d["lr"] = backbone_lr
|
|
543
|
+
elif lr_scale != 1.0:
|
|
544
|
+
d["lr"] = base_lr * lr_scale
|
|
355
545
|
|
|
356
546
|
params.append(d)
|
|
357
547
|
|
|
@@ -442,6 +632,8 @@ def get_optimizer(parameters: list[dict[str, Any]], l_rate: float, args: argpars
|
|
|
442
632
|
else:
|
|
443
633
|
raise ValueError("Unknown optimizer")
|
|
444
634
|
|
|
635
|
+
logger.debug(f"Created {opt} optimizer with lr={lr}, weight_decay={args.wd}")
|
|
636
|
+
|
|
445
637
|
return optimizer
|
|
446
638
|
|
|
447
639
|
|
|
@@ -477,10 +669,10 @@ def get_scheduler(
|
|
|
477
669
|
|
|
478
670
|
main_steps = steps - begin_step - remaining_warmup - remaining_cooldown - 1
|
|
479
671
|
|
|
480
|
-
logger.debug(f"
|
|
672
|
+
logger.debug(f"Scheduler using {steps_per_epoch} steps per epoch")
|
|
481
673
|
logger.debug(
|
|
482
674
|
f"Scheduler {args.lr_scheduler} set for {steps} steps of which {warmup_steps} "
|
|
483
|
-
f"are warmup and {cooldown_steps} cooldown"
|
|
675
|
+
f"are warmup and {cooldown_steps} are cooldown"
|
|
484
676
|
)
|
|
485
677
|
logger.debug(
|
|
486
678
|
f"Currently starting from step {begin_step} with {remaining_warmup} remaining warmup steps "
|
|
@@ -568,27 +760,87 @@ def get_amp_scaler(amp: bool, amp_dtype_str: str) -> tuple[Optional[torch.amp.Gr
|
|
|
568
760
|
return (scaler, amp_dtype)
|
|
569
761
|
|
|
570
762
|
|
|
763
|
+
@overload
|
|
571
764
|
def get_samplers(
|
|
572
|
-
args: argparse.Namespace,
|
|
573
|
-
|
|
574
|
-
|
|
575
|
-
|
|
576
|
-
|
|
577
|
-
training_dataset,
|
|
578
|
-
num_replicas=args.world_size,
|
|
579
|
-
rank=args.rank,
|
|
580
|
-
shuffle=True,
|
|
581
|
-
repetitions=args.ra_reps,
|
|
582
|
-
)
|
|
765
|
+
args: argparse.Namespace,
|
|
766
|
+
training_dataset: torch.utils.data.Dataset,
|
|
767
|
+
validation_dataset: torch.utils.data.Dataset,
|
|
768
|
+
infinite: bool = False,
|
|
769
|
+
) -> tuple[torch.utils.data.Sampler, torch.utils.data.Sampler]: ...
|
|
583
770
|
|
|
584
|
-
else:
|
|
585
|
-
train_sampler = torch.utils.data.distributed.DistributedSampler(training_dataset, shuffle=True)
|
|
586
771
|
|
|
587
|
-
|
|
772
|
+
@overload
|
|
773
|
+
def get_samplers(
|
|
774
|
+
args: argparse.Namespace,
|
|
775
|
+
training_dataset: torch.utils.data.Dataset,
|
|
776
|
+
validation_dataset: None = None,
|
|
777
|
+
infinite: bool = False,
|
|
778
|
+
) -> tuple[torch.utils.data.Sampler, None]: ...
|
|
779
|
+
|
|
780
|
+
|
|
781
|
+
def get_samplers(
|
|
782
|
+
args: argparse.Namespace,
|
|
783
|
+
training_dataset: torch.utils.data.Dataset,
|
|
784
|
+
validation_dataset: Optional[torch.utils.data.Dataset] = None,
|
|
785
|
+
infinite: bool = False,
|
|
786
|
+
) -> tuple[torch.utils.data.Sampler, Optional[torch.utils.data.Sampler]]:
|
|
787
|
+
if args.seed is None:
|
|
788
|
+
seed = int(torch.empty((), dtype=torch.int64).random_().item())
|
|
789
|
+
if is_dist_available_and_initialized() is True:
|
|
790
|
+
seed_tensor = torch.tensor(seed, dtype=torch.int64).cuda()
|
|
791
|
+
dist.broadcast(seed_tensor, src=0, async_op=False)
|
|
792
|
+
seed = int(seed_tensor.item())
|
|
793
|
+
else:
|
|
794
|
+
seed = args.seed
|
|
795
|
+
|
|
796
|
+
ra_sampler = getattr(args, "ra_sampler", False)
|
|
797
|
+
if args.distributed is True:
|
|
798
|
+
if infinite is True:
|
|
799
|
+
if ra_sampler is True:
|
|
800
|
+
train_sampler = InfiniteRASampler(
|
|
801
|
+
training_dataset,
|
|
802
|
+
num_replicas=args.world_size,
|
|
803
|
+
rank=args.rank,
|
|
804
|
+
shuffle=True,
|
|
805
|
+
seed=seed,
|
|
806
|
+
repetitions=args.ra_reps,
|
|
807
|
+
)
|
|
808
|
+
else:
|
|
809
|
+
train_sampler = InfiniteDistributedSampler(
|
|
810
|
+
training_dataset, num_replicas=args.world_size, rank=args.rank, shuffle=True, seed=seed
|
|
811
|
+
)
|
|
812
|
+
else:
|
|
813
|
+
if ra_sampler is True:
|
|
814
|
+
train_sampler = RASampler(
|
|
815
|
+
training_dataset,
|
|
816
|
+
num_replicas=args.world_size,
|
|
817
|
+
rank=args.rank,
|
|
818
|
+
shuffle=True,
|
|
819
|
+
seed=seed,
|
|
820
|
+
repetitions=args.ra_reps,
|
|
821
|
+
)
|
|
822
|
+
else:
|
|
823
|
+
train_sampler = torch.utils.data.distributed.DistributedSampler(
|
|
824
|
+
training_dataset, shuffle=True, seed=seed
|
|
825
|
+
)
|
|
826
|
+
|
|
827
|
+
if validation_dataset is None:
|
|
828
|
+
validation_sampler = None
|
|
829
|
+
else:
|
|
830
|
+
validation_sampler = torch.utils.data.distributed.DistributedSampler(validation_dataset, shuffle=False)
|
|
588
831
|
|
|
589
832
|
else:
|
|
590
|
-
|
|
591
|
-
|
|
833
|
+
if infinite is True:
|
|
834
|
+
train_sampler = InfiniteSampler(training_dataset, shuffle=True, seed=seed)
|
|
835
|
+
else:
|
|
836
|
+
generator = torch.Generator()
|
|
837
|
+
generator.manual_seed(seed)
|
|
838
|
+
train_sampler = torch.utils.data.RandomSampler(training_dataset, generator=generator)
|
|
839
|
+
|
|
840
|
+
if validation_dataset is None:
|
|
841
|
+
validation_sampler = None
|
|
842
|
+
else:
|
|
843
|
+
validation_sampler = torch.utils.data.SequentialSampler(validation_dataset)
|
|
592
844
|
|
|
593
845
|
return (train_sampler, validation_sampler)
|
|
594
846
|
|
|
@@ -810,6 +1062,51 @@ def is_local_primary(args: argparse.Namespace) -> bool:
|
|
|
810
1062
|
return args.local_rank == 0 # type: ignore[no-any-return]
|
|
811
1063
|
|
|
812
1064
|
|
|
1065
|
+
def init_training(
|
|
1066
|
+
args: argparse.Namespace,
|
|
1067
|
+
log: logging.Logger,
|
|
1068
|
+
*,
|
|
1069
|
+
cudnn_dynamic_size: bool = False,
|
|
1070
|
+
) -> tuple[torch.device, int, bool]:
|
|
1071
|
+
init_distributed_mode(args)
|
|
1072
|
+
|
|
1073
|
+
log.info(f"Starting training, birder version: {birder_version}, pytorch version: {torch.__version__}")
|
|
1074
|
+
|
|
1075
|
+
log_git_info()
|
|
1076
|
+
|
|
1077
|
+
if args.cpu is True:
|
|
1078
|
+
device = torch.device("cpu")
|
|
1079
|
+
device_id = 0
|
|
1080
|
+
else:
|
|
1081
|
+
device = torch.device("cuda")
|
|
1082
|
+
device_id = torch.cuda.current_device()
|
|
1083
|
+
|
|
1084
|
+
if args.use_deterministic_algorithms is True:
|
|
1085
|
+
torch.backends.cudnn.benchmark = False
|
|
1086
|
+
torch.use_deterministic_algorithms(True)
|
|
1087
|
+
elif cudnn_dynamic_size is True:
|
|
1088
|
+
# Dynamic sizes: avoid per-size algorithm selection overhead.
|
|
1089
|
+
torch.backends.cudnn.enabled = False
|
|
1090
|
+
else:
|
|
1091
|
+
torch.backends.cudnn.enabled = True
|
|
1092
|
+
torch.backends.cudnn.benchmark = True
|
|
1093
|
+
|
|
1094
|
+
if args.seed is not None:
|
|
1095
|
+
set_random_seeds(args.seed)
|
|
1096
|
+
|
|
1097
|
+
if args.non_interactive is True or is_local_primary(args) is False:
|
|
1098
|
+
disable_tqdm = True
|
|
1099
|
+
elif sys.stderr.isatty() is False:
|
|
1100
|
+
disable_tqdm = True
|
|
1101
|
+
else:
|
|
1102
|
+
disable_tqdm = False
|
|
1103
|
+
|
|
1104
|
+
# Enable or disable the autograd anomaly detection.
|
|
1105
|
+
torch.autograd.set_detect_anomaly(args.grad_anomaly_detection)
|
|
1106
|
+
|
|
1107
|
+
return (device, device_id, disable_tqdm)
|
|
1108
|
+
|
|
1109
|
+
|
|
813
1110
|
###############################################################################
|
|
814
1111
|
# Utility Functions
|
|
815
1112
|
###############################################################################
|