birder 0.2.1__tar.gz → 0.2.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {birder-0.2.1 → birder-0.2.2}/PKG-INFO +2 -1
- birder-0.2.2/birder/adversarial/__init__.py +13 -0
- birder-0.2.2/birder/adversarial/base.py +101 -0
- birder-0.2.2/birder/adversarial/deepfool.py +173 -0
- birder-0.2.2/birder/adversarial/fgsm.py +67 -0
- birder-0.2.2/birder/adversarial/pgd.py +105 -0
- birder-0.2.2/birder/adversarial/simba.py +172 -0
- {birder-0.2.1 → birder-0.2.2}/birder/common/training_cli.py +11 -3
- {birder-0.2.1 → birder-0.2.2}/birder/common/training_utils.py +18 -1
- {birder-0.2.1 → birder-0.2.2}/birder/inference/data_parallel.py +1 -2
- birder-0.2.2/birder/introspection/__init__.py +13 -0
- birder-0.2.2/birder/introspection/attention_rollout.py +185 -0
- birder-0.2.2/birder/introspection/base.py +104 -0
- birder-0.2.2/birder/introspection/gradcam.py +147 -0
- birder-0.2.2/birder/introspection/guided_backprop.py +229 -0
- birder-0.2.2/birder/introspection/transformer_attribution.py +182 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/detection/deformable_detr.py +14 -12
- {birder-0.2.1 → birder-0.2.2}/birder/net/detection/detr.py +7 -3
- {birder-0.2.1 → birder-0.2.2}/birder/net/detection/rt_detr_v1.py +3 -3
- {birder-0.2.1 → birder-0.2.2}/birder/net/detection/yolo_v3.py +6 -11
- {birder-0.2.1 → birder-0.2.2}/birder/net/detection/yolo_v4.py +7 -18
- {birder-0.2.1 → birder-0.2.2}/birder/net/detection/yolo_v4_tiny.py +3 -3
- {birder-0.2.1 → birder-0.2.2}/birder/net/mim/mae_vit.py +7 -8
- {birder-0.2.1 → birder-0.2.2}/birder/net/pit.py +1 -1
- {birder-0.2.1 → birder-0.2.2}/birder/net/resnet_v1.py +94 -34
- {birder-0.2.1 → birder-0.2.2}/birder/net/ssl/data2vec.py +1 -1
- {birder-0.2.1 → birder-0.2.2}/birder/net/ssl/data2vec2.py +4 -2
- {birder-0.2.1 → birder-0.2.2}/birder/results/gui.py +15 -2
- {birder-0.2.1 → birder-0.2.2}/birder/scripts/predict_detection.py +33 -1
- {birder-0.2.1 → birder-0.2.2}/birder/scripts/train.py +24 -17
- {birder-0.2.1 → birder-0.2.2}/birder/scripts/train_barlow_twins.py +10 -7
- {birder-0.2.1 → birder-0.2.2}/birder/scripts/train_byol.py +10 -7
- {birder-0.2.1 → birder-0.2.2}/birder/scripts/train_capi.py +12 -9
- {birder-0.2.1 → birder-0.2.2}/birder/scripts/train_data2vec.py +10 -7
- {birder-0.2.1 → birder-0.2.2}/birder/scripts/train_data2vec2.py +10 -7
- {birder-0.2.1 → birder-0.2.2}/birder/scripts/train_detection.py +42 -18
- {birder-0.2.1 → birder-0.2.2}/birder/scripts/train_dino_v1.py +10 -7
- {birder-0.2.1 → birder-0.2.2}/birder/scripts/train_dino_v2.py +10 -7
- {birder-0.2.1 → birder-0.2.2}/birder/scripts/train_dino_v2_dist.py +17 -7
- {birder-0.2.1 → birder-0.2.2}/birder/scripts/train_franca.py +10 -7
- {birder-0.2.1 → birder-0.2.2}/birder/scripts/train_i_jepa.py +17 -13
- {birder-0.2.1 → birder-0.2.2}/birder/scripts/train_ibot.py +10 -7
- {birder-0.2.1 → birder-0.2.2}/birder/scripts/train_kd.py +24 -18
- {birder-0.2.1 → birder-0.2.2}/birder/scripts/train_mim.py +11 -10
- {birder-0.2.1 → birder-0.2.2}/birder/scripts/train_mmcr.py +10 -7
- {birder-0.2.1 → birder-0.2.2}/birder/scripts/train_rotnet.py +10 -7
- {birder-0.2.1 → birder-0.2.2}/birder/scripts/train_simclr.py +10 -7
- {birder-0.2.1 → birder-0.2.2}/birder/scripts/train_vicreg.py +10 -7
- {birder-0.2.1 → birder-0.2.2}/birder/tools/__main__.py +6 -2
- birder-0.2.2/birder/tools/adversarial.py +214 -0
- birder-0.2.2/birder/tools/auto_anchors.py +361 -0
- {birder-0.2.1 → birder-0.2.2}/birder/tools/ensemble_model.py +1 -1
- {birder-0.2.1 → birder-0.2.2}/birder/tools/introspection.py +58 -31
- birder-0.2.2/birder/version.py +1 -0
- {birder-0.2.1 → birder-0.2.2}/birder.egg-info/PKG-INFO +2 -1
- {birder-0.2.1 → birder-0.2.2}/birder.egg-info/SOURCES.txt +7 -0
- {birder-0.2.1 → birder-0.2.2}/birder.egg-info/requires.txt +1 -0
- {birder-0.2.1 → birder-0.2.2}/requirements/_requirements-dev.txt +1 -0
- birder-0.2.2/tests/test_adversarial.py +238 -0
- {birder-0.2.1 → birder-0.2.2}/tests/test_common.py +5 -1
- birder-0.2.2/tests/test_introspection.py +310 -0
- {birder-0.2.1 → birder-0.2.2}/tests/test_net.py +3 -0
- birder-0.2.1/birder/adversarial/fgsm.py +0 -34
- birder-0.2.1/birder/adversarial/pgd.py +0 -54
- birder-0.2.1/birder/introspection/__init__.py +0 -9
- birder-0.2.1/birder/introspection/attention_rollout.py +0 -117
- birder-0.2.1/birder/introspection/base.py +0 -60
- birder-0.2.1/birder/introspection/gradcam.py +0 -176
- birder-0.2.1/birder/introspection/guided_backprop.py +0 -155
- birder-0.2.1/birder/tools/__init__.py +0 -0
- birder-0.2.1/birder/tools/adversarial.py +0 -163
- birder-0.2.1/birder/version.py +0 -1
- {birder-0.2.1 → birder-0.2.2}/LICENSE +0 -0
- {birder-0.2.1 → birder-0.2.2}/README.md +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/__init__.py +0 -0
- {birder-0.2.1/birder/adversarial → birder-0.2.2/birder/common}/__init__.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/common/cli.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/common/fs_ops.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/common/lib.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/common/masking.py +0 -0
- {birder-0.2.1/birder/common → birder-0.2.2/birder/conf}/__init__.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/conf/settings.py +0 -0
- {birder-0.2.1/birder/conf → birder-0.2.2/birder/data}/__init__.py +0 -0
- {birder-0.2.1/birder/data → birder-0.2.2/birder/data/collators}/__init__.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/data/collators/detection.py +0 -0
- {birder-0.2.1/birder/data/collators → birder-0.2.2/birder/data/dataloader}/__init__.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/data/dataloader/webdataset.py +0 -0
- {birder-0.2.1/birder/data/dataloader → birder-0.2.2/birder/data/datasets}/__init__.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/data/datasets/coco.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/data/datasets/directory.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/data/datasets/fake.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/data/datasets/webdataset.py +0 -0
- {birder-0.2.1/birder/data/datasets → birder-0.2.2/birder/data/transforms}/__init__.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/data/transforms/classification.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/data/transforms/detection.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/data/transforms/mosaic.py +0 -0
- {birder-0.2.1/birder/data/transforms → birder-0.2.2/birder/datahub}/__init__.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/datahub/_lib.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/datahub/classification.py +0 -0
- {birder-0.2.1/birder/datahub → birder-0.2.2/birder/inference}/__init__.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/inference/classification.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/inference/detection.py +0 -0
- {birder-0.2.1/birder/inference → birder-0.2.2/birder/kernels}/__init__.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/kernels/deformable_detr/cpu/ms_deform_attn_cpu.cpp +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/kernels/deformable_detr/cpu/ms_deform_attn_cpu.h +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/kernels/deformable_detr/cuda/ms_deform_attn_cuda.cu +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/kernels/deformable_detr/cuda/ms_deform_attn_cuda.h +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/kernels/deformable_detr/cuda/ms_deform_im2col_cuda.cuh +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/kernels/deformable_detr/ms_deform_attn.h +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/kernels/deformable_detr/vision.cpp +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/kernels/load_kernel.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/kernels/soft_nms/op.cpp +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/kernels/soft_nms/soft_nms.cpp +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/kernels/soft_nms/soft_nms.h +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/kernels/transnext/av_bw_kernel.cu +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/kernels/transnext/av_fw_kernel.cu +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/kernels/transnext/qk_bw_kernel.cu +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/kernels/transnext/qk_fw_kernel.cu +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/kernels/transnext/qk_rpb_bw_kernel.cu +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/kernels/transnext/qk_rpb_fw_kernel.cu +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/kernels/transnext/swattention.cpp +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/layers/__init__.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/layers/activations.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/layers/attention_pool.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/layers/ffn.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/layers/gem.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/layers/layer_norm.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/layers/layer_scale.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/model_registry/__init__.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/model_registry/manifest.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/model_registry/model_registry.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/__init__.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/alexnet.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/base.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/biformer.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/cait.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/cas_vit.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/coat.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/conv2former.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/convmixer.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/convnext_v1.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/convnext_v2.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/crossformer.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/crossvit.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/cspnet.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/cswin_transformer.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/darknet.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/davit.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/deit.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/deit3.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/densenet.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/detection/__init__.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/detection/base.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/detection/efficientdet.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/detection/faster_rcnn.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/detection/fcos.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/detection/retinanet.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/detection/ssd.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/detection/ssdlite.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/detection/vitdet.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/detection/yolo_v2.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/dpn.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/edgenext.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/edgevit.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/efficientformer_v1.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/efficientformer_v2.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/efficientnet_lite.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/efficientnet_v1.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/efficientnet_v2.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/efficientvim.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/efficientvit_mit.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/efficientvit_msft.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/fasternet.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/fastvit.py +1 -1
- {birder-0.2.1 → birder-0.2.2}/birder/net/flexivit.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/focalnet.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/ghostnet_v1.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/ghostnet_v2.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/groupmixformer.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/hgnet_v1.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/hgnet_v2.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/hiera.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/hieradet.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/hornet.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/iformer.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/inception_next.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/inception_resnet_v1.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/inception_resnet_v2.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/inception_v3.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/inception_v4.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/levit.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/maxvit.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/metaformer.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/mim/__init__.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/mim/base.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/mim/crossmae.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/mim/fcmae.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/mim/mae_hiera.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/mim/simmim.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/mnasnet.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/mobilenet_v1.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/mobilenet_v2.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/mobilenet_v3_large.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/mobilenet_v3_small.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/mobilenet_v4.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/mobilenet_v4_hybrid.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/mobileone.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/mobilevit_v1.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/mobilevit_v2.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/moganet.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/mvit_v2.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/nextvit.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/nfnet.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/pvt_v1.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/pvt_v2.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/rdnet.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/regionvit.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/regnet.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/regnet_z.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/repghost.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/repvgg.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/repvit.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/resmlp.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/resnest.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/resnet_v2.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/resnext.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/rope_deit3.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/rope_flexivit.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/rope_vit.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/se_resnet_v1.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/se_resnet_v2.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/se_resnext.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/sequencer2d.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/shufflenet_v1.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/shufflenet_v2.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/simple_vit.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/smt.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/squeezenet.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/squeezenext.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/ssl/__init__.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/ssl/barlow_twins.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/ssl/base.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/ssl/byol.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/ssl/capi.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/ssl/dino_v1.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/ssl/dino_v2.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/ssl/franca.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/ssl/i_jepa.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/ssl/ibot.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/ssl/mmcr.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/ssl/simclr.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/ssl/sscd.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/ssl/vicreg.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/starnet.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/swiftformer.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/swin_transformer_v1.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/swin_transformer_v2.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/tiny_vit.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/transnext.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/uniformer.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/van.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/vgg.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/vgg_reduced.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/vit.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/vit_parallel.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/vit_sam.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/vovnet_v1.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/vovnet_v2.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/wide_resnet.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/xception.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/net/xcit.py +0 -0
- {birder-0.2.1/birder/kernels → birder-0.2.2/birder/ops}/__init__.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/ops/msda.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/ops/soft_nms.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/ops/swattention.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/optim/__init__.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/optim/lamb.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/optim/lars.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/py.typed +0 -0
- {birder-0.2.1/birder/ops → birder-0.2.2/birder/results}/__init__.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/results/classification.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/results/detection.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/scheduler/__init__.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/scheduler/cooldown.py +0 -0
- {birder-0.2.1/birder/results → birder-0.2.2/birder/scripts}/__init__.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/scripts/__main__.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/scripts/benchmark.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/scripts/evaluate.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/scripts/predict.py +0 -0
- {birder-0.2.1/birder/scripts → birder-0.2.2/birder/tools}/__init__.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/tools/avg_model.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/tools/convert_model.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/tools/det_results.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/tools/download_model.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/tools/labelme_to_coco.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/tools/list_models.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/tools/model_info.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/tools/pack.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/tools/quantize_model.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/tools/results.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/tools/show_det_iterator.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/tools/show_iterator.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/tools/similarity.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/tools/stats.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/tools/verify_coco.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/tools/verify_directory.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder/tools/voc_to_coco.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder.egg-info/dependency_links.txt +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder.egg-info/entry_points.txt +0 -0
- {birder-0.2.1 → birder-0.2.2}/birder.egg-info/top_level.txt +0 -0
- {birder-0.2.1 → birder-0.2.2}/pyproject.toml +0 -0
- {birder-0.2.1 → birder-0.2.2}/requirements/requirements-hf.txt +0 -0
- {birder-0.2.1 → birder-0.2.2}/requirements/requirements.txt +0 -0
- {birder-0.2.1 → birder-0.2.2}/setup.cfg +0 -0
- {birder-0.2.1 → birder-0.2.2}/tests/test_collators.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/tests/test_datasets.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/tests/test_inference.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/tests/test_kernels.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/tests/test_layers.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/tests/test_model_registry.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/tests/test_net_detection.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/tests/test_net_mim.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/tests/test_net_ssl.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/tests/test_ops.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/tests/test_optim.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/tests/test_results.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/tests/test_scheduler.py +0 -0
- {birder-0.2.1 → birder-0.2.2}/tests/test_transforms.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: birder
|
|
3
|
-
Version: 0.2.1
|
|
3
|
+
Version: 0.2.2
|
|
4
4
|
Summary: An open-source computer vision framework for wildlife image analysis, featuring state-of-the-art models for species classification and detection.
|
|
5
5
|
Author: Ofer Hasson
|
|
6
6
|
License-Expression: Apache-2.0
|
|
@@ -62,6 +62,7 @@ Requires-Dist: MonkeyType~=23.3.0; extra == "dev"
|
|
|
62
62
|
Requires-Dist: mypy~=1.19.1; extra == "dev"
|
|
63
63
|
Requires-Dist: parameterized~=0.9.0; extra == "dev"
|
|
64
64
|
Requires-Dist: pylint~=4.0.4; extra == "dev"
|
|
65
|
+
Requires-Dist: pytest; extra == "dev"
|
|
65
66
|
Requires-Dist: requests~=2.32.5; extra == "dev"
|
|
66
67
|
Requires-Dist: safetensors~=0.7.0; extra == "dev"
|
|
67
68
|
Requires-Dist: setuptools; extra == "dev"
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
from birder.adversarial.base import AttackResult
|
|
2
|
+
from birder.adversarial.deepfool import DeepFool
|
|
3
|
+
from birder.adversarial.fgsm import FGSM
|
|
4
|
+
from birder.adversarial.pgd import PGD
|
|
5
|
+
from birder.adversarial.simba import SimBA
|
|
6
|
+
|
|
7
|
+
# Names re-exported as the public API of the adversarial package
__all__ = ["AttackResult", "DeepFool", "FGSM", "PGD", "SimBA"]
|
|
@@ -0,0 +1,101 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from typing import Optional
|
|
3
|
+
from typing import Protocol
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
|
|
7
|
+
from birder.data.transforms.classification import RGBType
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@dataclass(frozen=True)
|
|
11
|
+
class AttackResult:
|
|
12
|
+
adv_inputs: torch.Tensor
|
|
13
|
+
adv_logits: torch.Tensor
|
|
14
|
+
perturbation: torch.Tensor
|
|
15
|
+
logits: Optional[torch.Tensor] = None
|
|
16
|
+
success: Optional[torch.Tensor] = None
|
|
17
|
+
num_queries: Optional[int] = None
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class Attack(Protocol):
    """
    Structural interface for attack implementations.

    Any callable object that accepts an input batch plus an optional target
    tensor and returns an ``AttackResult`` satisfies this protocol.
    """

    def __call__(self, input_tensor: torch.Tensor, target: torch.Tensor | None) -> AttackResult:
        ...
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _to_channel_tensor(
|
|
25
|
+
values: tuple[float, float, float], device: Optional[torch.device], dtype: Optional[torch.dtype]
|
|
26
|
+
) -> torch.Tensor:
|
|
27
|
+
return torch.tensor(values, device=device, dtype=dtype).view(1, -1, 1, 1)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def normalized_bounds(
    rgb_stats: RGBType, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Map the pixel endpoints 0 and 1 into normalized space, per channel.

    Each endpoint ``p`` becomes ``(p - mean) / std`` using ``rgb_stats["mean"]``
    and ``rgb_stats["std"]``. Returns ``(min_val, max_val)``, each shaped
    (1, C, 1, 1).
    """
    mean = _to_channel_tensor(rgb_stats["mean"], device=device, dtype=dtype)
    std = _to_channel_tensor(rgb_stats["std"], device=device, dtype=dtype)
    (lower, upper) = ((pixel - mean) / std for pixel in (0.0, 1.0))

    return (lower, upper)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def pixel_eps_to_normalized(
    eps: float | torch.Tensor,
    rgb_stats: RGBType,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    """
    Convert a perturbation budget from pixel units to normalized units.

    Normalization divides by the per-channel std, so a pixel-space budget
    ``eps`` corresponds to ``eps / std`` in normalized space.

    Parameters
    ----------
    eps
        Either a single budget shared by all channels or one value per channel.
    rgb_stats
        Mapping providing a per-channel ``"std"`` sequence.
    device
        Device for the result (defaults to wherever ``torch.as_tensor`` puts eps).
    dtype
        Dtype for the result. When None and ``eps`` is an integer, the default
        floating dtype is used instead, so the std values are not truncated to
        integers.

    Returns
    -------
    Tensor of shape (1, C, 1, 1), with C the number of std values.
    """
    eps_tensor = torch.as_tensor(eps, device=device, dtype=dtype)
    if dtype is None and not eps_tensor.is_floating_point():
        # An integer eps would otherwise force an integer std tensor, truncating
        # e.g. 0.229 to 0 and yielding inf after the division below
        eps_tensor = eps_tensor.to(torch.get_default_dtype())

    std = _to_channel_tensor(rgb_stats["std"], device=eps_tensor.device, dtype=eps_tensor.dtype)

    # A scalar budget broadcasts over channels; otherwise one value per channel
    if eps_tensor.numel() == 1:
        eps_tensor = eps_tensor.reshape(1, 1, 1, 1)
    else:
        eps_tensor = eps_tensor.reshape(1, -1, 1, 1)

    return eps_tensor / std
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def clamp_normalized(inputs: torch.Tensor, rgb_stats: RGBType) -> torch.Tensor:
    """
    Clip normalized inputs to the per-channel bounds that correspond to the
    [0, 1] pixel range, using the inputs' own device and dtype.
    """
    bounds = normalized_bounds(rgb_stats, device=inputs.device, dtype=inputs.dtype)
    return inputs.clamp(min=bounds[0], max=bounds[1])
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def predict_labels(logits: torch.Tensor) -> torch.Tensor:
    """Return the index of the highest-scoring entry along dim 1 (one label per row)."""
    return logits.argmax(dim=1)
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def validate_target(
    target: Optional[torch.Tensor], batch_size: int, num_classes: int, device: torch.device
) -> Optional[torch.Tensor]:
    """
    Normalize and check user-supplied target labels.

    Returns None when ``target`` is None. Otherwise the target is moved to
    ``device`` as ``torch.long`` and a 0-d tensor is promoted to shape (1,).

    Raises
    ------
    ValueError
        If the leading dimension differs from ``batch_size`` or any label lies
        outside ``[0, num_classes)``.
    """
    if target is None:
        return None

    labels = target.to(device=device, dtype=torch.long)
    if labels.ndim == 0:
        labels = labels.view(1)

    if labels.shape[0] != batch_size:
        raise ValueError(f"Target shape {labels.shape[0]} must match batch size {batch_size}")

    out_of_range = (labels < 0) | (labels >= num_classes)
    if out_of_range.any():
        raise ValueError(f"Target values must be in range [0, {num_classes})")

    return labels
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def attack_success(
    logits: torch.Tensor,
    adv_logits: torch.Tensor,
    targeted: bool,
    target: Optional[torch.Tensor] = None,
    labels: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Compute a per-sample boolean success mask for an attack.

    Targeted (``targeted is True``): success where the adversarial prediction
    equals ``target``. Otherwise: success where the adversarial prediction
    differs from the reference labels, i.e. ``labels`` when provided, else the
    predictions on the clean ``logits``.

    Raises
    ------
    ValueError
        If ``targeted`` is True and ``target`` is None.
    """
    adv_pred = predict_labels(adv_logits)
    if targeted is not True:
        reference = predict_labels(logits) if labels is None else labels
        return adv_pred.ne(reference)

    if target is None:
        raise ValueError("Target labels required for targeted attacks")

    return adv_pred.eq(target)
|
|
@@ -0,0 +1,173 @@
|
|
|
1
|
+
"""
|
|
2
|
+
DeepFool
|
|
3
|
+
|
|
4
|
+
Paper "DeepFool: a simple and accurate method to fool deep neural networks", https://arxiv.org/abs/1511.04599
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from typing import Optional
|
|
8
|
+
|
|
9
|
+
import torch
|
|
10
|
+
from torch import nn
|
|
11
|
+
|
|
12
|
+
from birder.adversarial.base import AttackResult
|
|
13
|
+
from birder.adversarial.base import attack_success
|
|
14
|
+
from birder.adversarial.base import clamp_normalized
|
|
15
|
+
from birder.adversarial.base import predict_labels
|
|
16
|
+
from birder.adversarial.base import validate_target
|
|
17
|
+
from birder.data.transforms.classification import RGBType
|
|
18
|
+
|
|
19
|
+
# Tiny constant used both as a "gradient is effectively zero" threshold and as
# a guard against division by zero when scaling DeepFool steps
GRAD_EPS = 1.0e-12
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class DeepFool:
|
|
23
|
+
def __init__(
|
|
24
|
+
self, net: nn.Module, num_classes: int = 10, overshoot: float = 0.02, max_iter: int = 50, *, rgb_stats: RGBType
|
|
25
|
+
) -> None:
|
|
26
|
+
if num_classes < 2:
|
|
27
|
+
raise ValueError("num_classes must be at least 2")
|
|
28
|
+
if max_iter <= 0:
|
|
29
|
+
raise ValueError("max_iter must be positive")
|
|
30
|
+
if overshoot < 0:
|
|
31
|
+
raise ValueError("overshoot must be non-negative")
|
|
32
|
+
|
|
33
|
+
self.net = net.eval()
|
|
34
|
+
self.num_classes = num_classes
|
|
35
|
+
self.overshoot = overshoot
|
|
36
|
+
self.max_iter = max_iter
|
|
37
|
+
self.rgb_stats = rgb_stats
|
|
38
|
+
|
|
39
|
+
def __call__(self, input_tensor: torch.Tensor, target: Optional[torch.Tensor]) -> AttackResult:
|
|
40
|
+
inputs = input_tensor.detach()
|
|
41
|
+
with torch.no_grad():
|
|
42
|
+
logits = self.net(inputs)
|
|
43
|
+
|
|
44
|
+
target_labels = (
|
|
45
|
+
validate_target(target, inputs.shape[0], logits.shape[1], inputs.device) if target is not None else None
|
|
46
|
+
)
|
|
47
|
+
targeted = target_labels is not None
|
|
48
|
+
|
|
49
|
+
adv_inputs_list = []
|
|
50
|
+
for idx in range(inputs.size(0)):
|
|
51
|
+
target_label = target_labels[idx : idx + 1] if target_labels is not None else None
|
|
52
|
+
adv_input = self._attack_single(inputs[idx : idx + 1], logits[idx : idx + 1], target_label)
|
|
53
|
+
adv_inputs_list.append(adv_input)
|
|
54
|
+
|
|
55
|
+
adv_inputs = torch.concat(adv_inputs_list, dim=0)
|
|
56
|
+
with torch.no_grad():
|
|
57
|
+
adv_logits = self.net(adv_inputs)
|
|
58
|
+
|
|
59
|
+
success = attack_success(
|
|
60
|
+
logits,
|
|
61
|
+
adv_logits,
|
|
62
|
+
targeted,
|
|
63
|
+
target=target_labels if targeted else None,
|
|
64
|
+
)
|
|
65
|
+
|
|
66
|
+
return AttackResult(
|
|
67
|
+
adv_inputs=adv_inputs,
|
|
68
|
+
adv_logits=adv_logits,
|
|
69
|
+
perturbation=adv_inputs - inputs,
|
|
70
|
+
logits=logits.detach(),
|
|
71
|
+
success=success,
|
|
72
|
+
)
|
|
73
|
+
|
|
74
|
+
def _attack_single(
|
|
75
|
+
self, inputs: torch.Tensor, logits: torch.Tensor, target_label: Optional[torch.Tensor]
|
|
76
|
+
) -> torch.Tensor:
|
|
77
|
+
adv_inputs = inputs.clone()
|
|
78
|
+
original_label = int(predict_labels(logits).item())
|
|
79
|
+
targeted = target_label is not None
|
|
80
|
+
for _ in range(self.max_iter):
|
|
81
|
+
adv_inputs.requires_grad_(True)
|
|
82
|
+
outputs = self.net(adv_inputs)
|
|
83
|
+
current_label = int(predict_labels(outputs).item())
|
|
84
|
+
|
|
85
|
+
if targeted is True:
|
|
86
|
+
assert target_label is not None
|
|
87
|
+
target_value = int(target_label.item())
|
|
88
|
+
if current_label == target_value:
|
|
89
|
+
break
|
|
90
|
+
|
|
91
|
+
perturbation = self._targeted_perturbation(adv_inputs, outputs, current_label, target_value)
|
|
92
|
+
|
|
93
|
+
else:
|
|
94
|
+
if current_label != original_label:
|
|
95
|
+
break
|
|
96
|
+
|
|
97
|
+
perturbation = self._untargeted_perturbation(adv_inputs, outputs, current_label)
|
|
98
|
+
|
|
99
|
+
if perturbation is None:
|
|
100
|
+
break
|
|
101
|
+
|
|
102
|
+
# Overshoot helps ensure boundary crossing
|
|
103
|
+
adv_inputs = adv_inputs.detach() + (1.0 + self.overshoot) * perturbation
|
|
104
|
+
adv_inputs = clamp_normalized(adv_inputs, self.rgb_stats)
|
|
105
|
+
|
|
106
|
+
return adv_inputs.detach()
|
|
107
|
+
|
|
108
|
+
def _targeted_perturbation(
|
|
109
|
+
self, adv_inputs: torch.Tensor, outputs: torch.Tensor, current_label: int, target_label: int
|
|
110
|
+
) -> Optional[torch.Tensor]:
|
|
111
|
+
self.net.zero_grad(set_to_none=True)
|
|
112
|
+
grad_current = torch.autograd.grad(outputs[0, current_label], adv_inputs, retain_graph=True)[0]
|
|
113
|
+
grad_target = torch.autograd.grad(outputs[0, target_label], adv_inputs, retain_graph=False)[0]
|
|
114
|
+
|
|
115
|
+
# Direction toward the target boundary
|
|
116
|
+
w = grad_target - grad_current
|
|
117
|
+
w_norm = torch.norm(w.view(-1))
|
|
118
|
+
if w_norm.item() < GRAD_EPS:
|
|
119
|
+
return None
|
|
120
|
+
|
|
121
|
+
# Distance to the decision boundary
|
|
122
|
+
f = outputs[0, target_label] - outputs[0, current_label]
|
|
123
|
+
perturbation = (f.abs() / (w_norm**2 + GRAD_EPS)) * w
|
|
124
|
+
|
|
125
|
+
return perturbation
|
|
126
|
+
|
|
127
|
+
def _untargeted_perturbation(
    self, adv_inputs: torch.Tensor, outputs: torch.Tensor, current_label: int
) -> Optional[torch.Tensor]:
    """
    DeepFool step toward the closest decision boundary among the highest-scoring competing classes.

    Only the ``self.num_classes`` largest logits are examined; ``current_label`` is skipped. Each
    candidate boundary is linearized and the one with the smallest distance is chosen.

    Parameters
    ----------
    adv_inputs
        Current adversarial input (single sample) that ``outputs`` was computed from; must require grad.
    outputs
        Logits of ``adv_inputs``, still attached to the autograd graph. The graph is freed after the
        last candidate has been differentiated.
    current_label
        Class currently predicted for ``adv_inputs``.

    Returns
    -------
    The minimal linearized perturbation reaching the nearest boundary, detached from the autograd
    graph. Returns None when there is no candidate class or every candidate has a vanishing normal.
    """

    # Search the top-k competing classes
    top_k = min(self.num_classes, outputs.shape[1])
    top_indices = torch.topk(outputs, k=top_k, dim=1).indices[0]
    candidate_labels = [label for label in top_indices.tolist() if label != current_label]

    if len(candidate_labels) == 0:
        return None

    self.net.zero_grad(set_to_none=True)

    # Track the closest decision boundary
    best_dist: Optional[torch.Tensor] = None
    best_w: Optional[torch.Tensor] = None
    best_f: Optional[torch.Tensor] = None
    best_w_norm: Optional[torch.Tensor] = None
    last_idx = len(candidate_labels) - 1
    for idx, label in enumerate(candidate_labels):
        # Logit gap to this class; its input gradient is the boundary normal
        # w_k = grad_k - grad_current (one backward pass per candidate)
        f_k = outputs[0, label] - outputs[0, current_label]

        # Keep the graph until the last class
        (w_k,) = torch.autograd.grad(f_k, adv_inputs, retain_graph=idx != last_idx)

        w_norm = torch.linalg.vector_norm(w_k)
        if w_norm.item() < GRAD_EPS:
            continue

        # Detached so the perturbation built from it carries no autograd history back to the caller
        f_k = f_k.detach()
        dist = f_k.abs() / (w_norm + GRAD_EPS)

        if best_dist is None or dist < best_dist:
            best_dist = dist
            best_w = w_k
            best_f = f_k
            best_w_norm = w_norm

    if best_w is None or best_f is None or best_w_norm is None:
        return None

    # Minimal perturbation toward the closest boundary (best_w_norm already passed the GRAD_EPS check)
    perturbation = (best_f.abs() / (best_w_norm**2 + GRAD_EPS)) * best_w

    return perturbation
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Fast Gradient Sign Method (FGSM)
|
|
3
|
+
|
|
4
|
+
Paper "Explaining and Harnessing Adversarial Examples", https://arxiv.org/abs/1412.6572
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from typing import Optional
|
|
8
|
+
|
|
9
|
+
import torch
|
|
10
|
+
import torch.nn.functional as F
|
|
11
|
+
from torch import nn
|
|
12
|
+
|
|
13
|
+
from birder.adversarial.base import AttackResult
|
|
14
|
+
from birder.adversarial.base import attack_success
|
|
15
|
+
from birder.adversarial.base import clamp_normalized
|
|
16
|
+
from birder.adversarial.base import pixel_eps_to_normalized
|
|
17
|
+
from birder.adversarial.base import predict_labels
|
|
18
|
+
from birder.adversarial.base import validate_target
|
|
19
|
+
from birder.data.transforms.classification import RGBType
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class FGSM:
    """
    Single-step L-infinity attack along the sign of the input gradient of the cross-entropy loss.

    ``eps`` is given in pixel units and converted to the normalized input space using ``rgb_stats``.
    """

    def __init__(self, net: nn.Module, eps: float, *, rgb_stats: RGBType) -> None:
        # A negative budget would silently flip the attack direction
        if eps < 0:
            raise ValueError("eps must be non-negative")

        self.net = net.eval()
        self.eps = eps
        self.rgb_stats = rgb_stats

    def __call__(self, input_tensor: torch.Tensor, target: Optional[torch.Tensor]) -> AttackResult:
        """
        Run the attack on a batch.

        Parameters
        ----------
        input_tensor
            Input batch, presumably already normalized with ``rgb_stats`` (TODO confirm with callers).
        target
            Optional target labels. When given, the step descends the loss toward ``target`` (targeted
            attack). Otherwise it ascends the loss away from the model's own predictions.

        Returns
        -------
        AttackResult whose tensors are all free of autograd history.
        """

        inputs = input_tensor.detach().clone()
        inputs.requires_grad_(True)

        logits = self.net(inputs)
        targeted = target is not None
        if targeted is True:
            target = validate_target(target, inputs.shape[0], logits.shape[1], inputs.device)
        else:
            target = predict_labels(logits)

        loss = F.cross_entropy(logits, target)
        (grad,) = torch.autograd.grad(loss, inputs, retain_graph=False, create_graph=False)
        eps_norm = pixel_eps_to_normalized(self.eps, self.rgb_stats, device=inputs.device, dtype=inputs.dtype)

        # Targeted steps descend toward target, untargeted ascend away from original
        if targeted is True:
            direction = -1.0
        else:
            direction = 1.0

        # Build the adversarial example from a detached copy. Starting from the grad-requiring leaf
        # would return adv_inputs / perturbation still attached to the autograd graph.
        clean_inputs = inputs.detach()
        clean_logits = logits.detach()
        step = direction * eps_norm * grad.sign()
        adv_inputs = clamp_normalized(clean_inputs + step, self.rgb_stats)
        with torch.no_grad():
            adv_logits = self.net(adv_inputs)

        success = attack_success(
            clean_logits,
            adv_logits,
            targeted,
            target=target if targeted else None,
        )

        return AttackResult(
            adv_inputs=adv_inputs,
            adv_logits=adv_logits,
            perturbation=adv_inputs - clean_inputs,
            logits=clean_logits,
            success=success,
        )
|
|
@@ -0,0 +1,105 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Projected Gradient Descent (PGD)
|
|
3
|
+
|
|
4
|
+
Paper "Towards Deep Learning Models Resistant to Adversarial Attacks", https://arxiv.org/abs/1706.06083
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
# Reference license: MIT
|
|
8
|
+
|
|
9
|
+
from typing import Optional
|
|
10
|
+
|
|
11
|
+
import torch
|
|
12
|
+
import torch.nn.functional as F
|
|
13
|
+
from torch import nn
|
|
14
|
+
|
|
15
|
+
from birder.adversarial.base import AttackResult
|
|
16
|
+
from birder.adversarial.base import attack_success
|
|
17
|
+
from birder.adversarial.base import clamp_normalized
|
|
18
|
+
from birder.adversarial.base import pixel_eps_to_normalized
|
|
19
|
+
from birder.adversarial.base import predict_labels
|
|
20
|
+
from birder.adversarial.base import validate_target
|
|
21
|
+
from birder.data.transforms.classification import RGBType
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class PGD:
    """
    Iterative L-infinity attack: repeated signed-gradient steps projected back onto the eps-ball.

    ``eps`` and ``step_size`` are given in pixel units and converted to the normalized input space
    using ``rgb_stats``. When ``step_size`` is omitted it defaults to ``eps / steps``. With
    ``random_start`` the search begins from a uniformly random point inside the eps-ball.
    """

    def __init__(
        self,
        net: nn.Module,
        eps: float,
        steps: int = 10,
        step_size: Optional[float] = None,
        random_start: bool = False,
        *,
        rgb_stats: RGBType,
    ) -> None:
        if steps <= 0:
            raise ValueError("steps must be a positive integer")

        # A negative radius makes the projection bounds inverted (min > max), producing a meaningless delta
        if eps < 0:
            raise ValueError("eps must be non-negative")

        self.net = net.eval()
        self.eps = eps
        self.steps = steps
        if step_size is not None:
            self.step_size = step_size
        else:
            self.step_size = eps / steps

        self.random_start = random_start
        self.rgb_stats = rgb_stats

        if self.step_size <= 0:
            raise ValueError("step_size must be positive")

    def __call__(self, input_tensor: torch.Tensor, target: Optional[torch.Tensor]) -> AttackResult:
        """
        Run the attack on a batch.

        Parameters
        ----------
        input_tensor
            Input batch, presumably already normalized with ``rgb_stats`` (TODO confirm with callers).
        target
            Optional target labels. When given, each step descends the loss toward ``target``.
            Otherwise each step ascends the loss away from the model's clean predictions.

        Returns
        -------
        AttackResult for the final iterate, which lies within ``eps`` (per channel, normalized space)
        of the input.
        """

        inputs = input_tensor.detach()
        with torch.no_grad():
            logits = self.net(inputs)

        targeted = target is not None
        if targeted:
            target = validate_target(target, inputs.shape[0], logits.shape[1], inputs.device)
        else:
            target = predict_labels(logits)

        eps_norm = pixel_eps_to_normalized(self.eps, self.rgb_stats, device=inputs.device, dtype=inputs.dtype)
        step_norm = pixel_eps_to_normalized(self.step_size, self.rgb_stats, device=inputs.device, dtype=inputs.dtype)

        # Targeted steps descend toward target, untargeted ascend away from original
        if targeted is True:
            direction = -1.0
        else:
            direction = 1.0

        adv_inputs = inputs.clone()
        if self.random_start is True:
            # Random start inside the epsilon ball
            adv_inputs = adv_inputs + torch.empty_like(adv_inputs).uniform_(-1.0, 1.0) * eps_norm
            adv_inputs = clamp_normalized(adv_inputs, self.rgb_stats)

        for _ in range(self.steps):
            adv_inputs.requires_grad_(True)
            adv_logits = self.net(adv_inputs)
            loss = F.cross_entropy(adv_logits, target)
            (grad,) = torch.autograd.grad(loss, adv_inputs, retain_graph=False, create_graph=False)
            adv_inputs = adv_inputs.detach() + direction * step_norm * grad.sign()

            # Project back into the epsilon ball around the original input.
            delta = torch.clamp(adv_inputs - inputs, min=-eps_norm, max=eps_norm)
            adv_inputs = clamp_normalized(inputs + delta, self.rgb_stats)

        # Re-evaluate the final iterate (the last in-loop forward pass preceded the last step)
        with torch.no_grad():
            adv_logits = self.net(adv_inputs)

        success = attack_success(
            logits.detach(),
            adv_logits,
            targeted,
            target=target if targeted else None,
        )

        return AttackResult(
            adv_inputs=adv_inputs,
            adv_logits=adv_logits,
            perturbation=adv_inputs - inputs,
            logits=logits.detach(),
            success=success,
        )
|
|
@@ -0,0 +1,172 @@
|
|
|
1
|
+
"""
|
|
2
|
+
SimBA (Simple Black-box Attack)
|
|
3
|
+
|
|
4
|
+
Paper "Simple Black-box Adversarial Attacks", https://arxiv.org/abs/1905.07121
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from typing import Optional
|
|
8
|
+
|
|
9
|
+
import torch
|
|
10
|
+
import torch.nn.functional as F
|
|
11
|
+
from torch import nn
|
|
12
|
+
|
|
13
|
+
from birder.adversarial.base import AttackResult
|
|
14
|
+
from birder.adversarial.base import attack_success
|
|
15
|
+
from birder.adversarial.base import clamp_normalized
|
|
16
|
+
from birder.adversarial.base import pixel_eps_to_normalized
|
|
17
|
+
from birder.adversarial.base import predict_labels
|
|
18
|
+
from birder.adversarial.base import validate_target
|
|
19
|
+
from birder.data.transforms.classification import RGBType
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class SimBA:
    """
    Simple Black-box Attack: greedy random-coordinate search that uses only model outputs.

    Each sample is attacked independently. Input coordinates (channel, row, col) are visited in a
    random order without repetition, at most ``max_iter`` of them. At every coordinate the input is
    shifted by +step and -step, and the better candidate is kept if it lowers the objective.
    ``step_size`` is given in pixel units and converted per channel to the normalized space using
    ``rgb_stats``.

    NOTE: the paper's Algorithm 1 queries -step only when +step fails to improve. This implementation
    always queries both directions, so every visited coordinate costs exactly two model queries.
    """

    def __init__(self, net: nn.Module, step_size: float, max_iter: int = 1000, *, rgb_stats: RGBType) -> None:
        if step_size <= 0:
            raise ValueError("step_size must be positive")
        if max_iter <= 0:
            raise ValueError("max_iter must be positive")

        self.net = net.eval()
        self.step_size = step_size
        self.max_iter = max_iter
        self.rgb_stats = rgb_stats

    def __call__(self, input_tensor: torch.Tensor, target: Optional[torch.Tensor]) -> AttackResult:
        """
        Run the attack on a batch.

        Parameters
        ----------
        input_tensor
            Input batch, presumably already normalized with ``rgb_stats`` (TODO confirm with callers).
        target
            Optional target labels. When given, the search maximizes the target class. Otherwise it
            moves away from the model's clean prediction.

        Returns
        -------
        AttackResult, with ``num_queries`` summed over all samples in the batch.
        """

        inputs = input_tensor.detach()
        with torch.no_grad():
            logits = self.net(inputs)

        labels = predict_labels(logits)
        target_labels = (
            validate_target(target, inputs.shape[0], logits.shape[1], inputs.device) if target is not None else None
        )
        targeted = target_labels is not None

        adv_inputs_list = []
        total_queries = 0
        for idx in range(inputs.size(0)):
            label = labels[idx : idx + 1]
            target_label = target_labels[idx : idx + 1] if target_labels is not None else None
            adv_input, num_queries = self._attack_single(inputs[idx : idx + 1], label, target_label)
            adv_inputs_list.append(adv_input)
            total_queries += num_queries

        if len(adv_inputs_list) > 0:
            adv_inputs = torch.concat(adv_inputs_list, dim=0)
        else:
            # torch.concat rejects an empty list; an empty batch has nothing to perturb
            adv_inputs = inputs.clone()

        with torch.no_grad():
            adv_logits = self.net(adv_inputs)

        success = attack_success(
            logits,
            adv_logits,
            targeted,
            target=target_labels if targeted else None,
        )

        return AttackResult(
            adv_inputs=adv_inputs,
            adv_logits=adv_logits,
            perturbation=adv_inputs - inputs,
            logits=logits.detach(),
            success=success,
            num_queries=total_queries,
        )

    # pylint: disable=too-many-locals
    def _attack_single(
        self, inputs: torch.Tensor, label: torch.Tensor, target_label: Optional[torch.Tensor]
    ) -> tuple[torch.Tensor, int]:
        """
        Attack one sample (a batch of size one).

        Returns the adversarial input and the number of model forward passes spent on it.
        """

        adv_inputs = inputs.clone()
        num_queries = 1  # Baseline forward pass

        with torch.no_grad():
            current_logits = self.net(adv_inputs)
        current_objective = self._compute_objective(current_logits, label, target_label)

        if self._is_successful(current_logits, label, target_label):
            return adv_inputs.detach(), num_queries

        (_, channels, height, width) = adv_inputs.shape
        num_dims = channels * height * width
        step = pixel_eps_to_normalized(self.step_size, self.rgb_stats, device=adv_inputs.device, dtype=adv_inputs.dtype)
        step_vals = step.reshape(-1)  # Per-channel steps
        stride = height * width

        perm = torch.randperm(num_dims, device=adv_inputs.device)
        num_steps = min(self.max_iter, num_dims)

        # Coordinate-wise search in random order. The indices are copied to Python once, instead of
        # calling .item() per coordinate, which would force a device sync on every iteration.
        for flat_idx in perm[:num_steps].tolist():
            (c, rem) = divmod(flat_idx, stride)
            (h, w) = divmod(rem, width)
            step_val = step_vals[c]

            (candidate_inputs, candidate_logits, candidate_objective) = self._best_candidate(
                adv_inputs, c, h, w, step_val, label, target_label
            )
            num_queries += 2  # Both +step and -step were evaluated

            if candidate_objective < current_objective:
                adv_inputs = candidate_inputs
                current_logits = candidate_logits
                current_objective = candidate_objective

                if self._is_successful(current_logits, label, target_label) is True:
                    break

        return adv_inputs.detach(), num_queries

    def _perturb_pixel(
        self, inputs: torch.Tensor, channel: int, row: int, col: int, step: torch.Tensor
    ) -> torch.Tensor:
        """Return a clamped copy of ``inputs`` with ``step`` added at a single (channel, row, col) coordinate."""

        adv_inputs = inputs.clone()
        adv_inputs[0, channel, row, col] = adv_inputs[0, channel, row, col] + step
        return clamp_normalized(adv_inputs, self.rgb_stats)

    def _evaluate_candidate(
        self, inputs: torch.Tensor, label: torch.Tensor, target_label: Optional[torch.Tensor]
    ) -> tuple[torch.Tensor, float]:
        """Query the model once and return the logits together with the attack objective."""

        with torch.no_grad():
            logits = self.net(inputs)

        return logits, self._compute_objective(logits, label, target_label)

    def _best_candidate(
        self,
        inputs: torch.Tensor,
        channel: int,
        row: int,
        col: int,
        step: torch.Tensor,
        label: torch.Tensor,
        target_label: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor, float]:
        """
        Evaluate +step and -step at one coordinate and return the candidate with the lower objective.

        This costs two model queries. Ties go to +step.
        """

        adv_plus = self._perturb_pixel(inputs, channel, row, col, step)
        logits_plus, objective_plus = self._evaluate_candidate(adv_plus, label, target_label)

        adv_minus = self._perturb_pixel(inputs, channel, row, col, -step)
        logits_minus, objective_minus = self._evaluate_candidate(adv_minus, label, target_label)

        if objective_plus <= objective_minus:
            return adv_plus, logits_plus, objective_plus

        return adv_minus, logits_minus, objective_minus

    @staticmethod
    def _compute_objective(
        logits: torch.Tensor, original_label: torch.Tensor, target_label: Optional[torch.Tensor]
    ) -> float:
        """
        Scalar objective to minimize.

        Targeted: cross-entropy to the target label. Untargeted: negated cross-entropy to the
        original label.
        """

        # Lower objective is better in both modes
        if target_label is not None:
            return float(F.cross_entropy(logits, target_label).item())

        return -float(F.cross_entropy(logits, original_label).item())

    @staticmethod
    def _is_successful(
        logits: torch.Tensor, original_label: torch.Tensor, target_label: Optional[torch.Tensor]
    ) -> bool:
        """
        Return whether the attack has succeeded for a single sample.

        Targeted: the prediction equals the target. Untargeted: the prediction differs from the
        original label.
        """

        pred = predict_labels(logits)
        if target_label is not None:
            return bool(pred.eq(target_label).item())

        return bool(pred.ne(original_label).item())
|
|
@@ -110,10 +110,13 @@ def add_lr_scheduler_args(parser: argparse.ArgumentParser) -> None:
|
|
|
110
110
|
type=int,
|
|
111
111
|
default=40,
|
|
112
112
|
metavar="N",
|
|
113
|
-
help="decrease lr every
|
|
113
|
+
help="decrease lr every N epochs/steps (relative to after warmup, step scheduler only)",
|
|
114
114
|
)
|
|
115
115
|
group.add_argument(
|
|
116
|
-
"--lr-steps",
|
|
116
|
+
"--lr-steps",
|
|
117
|
+
type=int,
|
|
118
|
+
nargs="+",
|
|
119
|
+
help="absolute epoch/step milestones when to decrease lr (multistep scheduler only)",
|
|
117
120
|
)
|
|
118
121
|
group.add_argument(
|
|
119
122
|
"--lr-step-gamma",
|
|
@@ -391,7 +394,7 @@ def add_ema_args(
|
|
|
391
394
|
"--model-ema-warmup",
|
|
392
395
|
type=int,
|
|
393
396
|
metavar="N",
|
|
394
|
-
help="number of epochs before EMA is applied (defaults to warmup epochs/
|
|
397
|
+
help="number of epochs/steps before EMA is applied (defaults to warmup epochs/steps, pass 0 to disable warmup)",
|
|
395
398
|
)
|
|
396
399
|
|
|
397
400
|
|
|
@@ -656,6 +659,11 @@ def common_args_validation(args: argparse.Namespace) -> None:
|
|
|
656
659
|
f"but it is set to '{args.lr_scheduler_update}'"
|
|
657
660
|
)
|
|
658
661
|
|
|
662
|
+
# EMA
|
|
663
|
+
if hasattr(args, "model_ema_steps") is True:
|
|
664
|
+
if args.model_ema_steps < 1:
|
|
665
|
+
raise ValidationError("--model-ema-steps must be >= 1")
|
|
666
|
+
|
|
659
667
|
# Compile args, argument dependant
|
|
660
668
|
if hasattr(args, "compile_teacher") is True:
|
|
661
669
|
if args.compile is True and args.compile_teacher is True:
|