birder 0.2.1__tar.gz → 0.2.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {birder-0.2.1 → birder-0.2.3}/PKG-INFO +4 -3
- {birder-0.2.1 → birder-0.2.3}/README.md +1 -1
- birder-0.2.3/birder/adversarial/__init__.py +13 -0
- birder-0.2.3/birder/adversarial/base.py +101 -0
- birder-0.2.3/birder/adversarial/deepfool.py +173 -0
- birder-0.2.3/birder/adversarial/fgsm.py +67 -0
- birder-0.2.3/birder/adversarial/pgd.py +105 -0
- birder-0.2.3/birder/adversarial/simba.py +172 -0
- {birder-0.2.1 → birder-0.2.3}/birder/common/lib.py +2 -9
- {birder-0.2.1 → birder-0.2.3}/birder/common/training_cli.py +29 -3
- {birder-0.2.1 → birder-0.2.3}/birder/common/training_utils.py +141 -11
- {birder-0.2.1 → birder-0.2.3}/birder/data/collators/detection.py +10 -3
- {birder-0.2.1 → birder-0.2.3}/birder/data/datasets/coco.py +8 -10
- {birder-0.2.1 → birder-0.2.3}/birder/data/transforms/detection.py +30 -13
- {birder-0.2.1 → birder-0.2.3}/birder/inference/data_parallel.py +1 -2
- {birder-0.2.1 → birder-0.2.3}/birder/inference/detection.py +108 -4
- birder-0.2.3/birder/inference/wbf.py +226 -0
- birder-0.2.3/birder/introspection/__init__.py +13 -0
- birder-0.2.3/birder/introspection/attention_rollout.py +185 -0
- birder-0.2.3/birder/introspection/base.py +104 -0
- birder-0.2.3/birder/introspection/gradcam.py +147 -0
- birder-0.2.3/birder/introspection/guided_backprop.py +229 -0
- birder-0.2.3/birder/introspection/transformer_attribution.py +182 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/__init__.py +8 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/detection/deformable_detr.py +14 -12
- {birder-0.2.1 → birder-0.2.3}/birder/net/detection/detr.py +7 -3
- {birder-0.2.1 → birder-0.2.3}/birder/net/detection/efficientdet.py +65 -86
- {birder-0.2.1 → birder-0.2.3}/birder/net/detection/rt_detr_v1.py +4 -3
- birder-0.2.3/birder/net/detection/yolo_anchors.py +205 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/detection/yolo_v2.py +25 -24
- {birder-0.2.1 → birder-0.2.3}/birder/net/detection/yolo_v3.py +42 -48
- {birder-0.2.1 → birder-0.2.3}/birder/net/detection/yolo_v4.py +31 -40
- {birder-0.2.1 → birder-0.2.3}/birder/net/detection/yolo_v4_tiny.py +24 -20
- {birder-0.2.1 → birder-0.2.3}/birder/net/fasternet.py +1 -1
- birder-0.2.3/birder/net/gc_vit.py +671 -0
- birder-0.2.3/birder/net/lit_v1.py +472 -0
- birder-0.2.3/birder/net/lit_v1_tiny.py +342 -0
- birder-0.2.3/birder/net/lit_v2.py +436 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/mim/mae_vit.py +7 -8
- {birder-0.2.1 → birder-0.2.3}/birder/net/mobilenet_v4_hybrid.py +1 -1
- {birder-0.2.1 → birder-0.2.3}/birder/net/pit.py +1 -1
- {birder-0.2.1 → birder-0.2.3}/birder/net/resnet_v1.py +95 -35
- {birder-0.2.1 → birder-0.2.3}/birder/net/resnext.py +67 -25
- {birder-0.2.1 → birder-0.2.3}/birder/net/se_resnet_v1.py +46 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/se_resnext.py +3 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/simple_vit.py +2 -2
- {birder-0.2.1 → birder-0.2.3}/birder/net/ssl/data2vec.py +1 -1
- {birder-0.2.1 → birder-0.2.3}/birder/net/ssl/data2vec2.py +4 -2
- {birder-0.2.1 → birder-0.2.3}/birder/net/vit.py +0 -15
- {birder-0.2.1 → birder-0.2.3}/birder/net/vovnet_v2.py +31 -1
- {birder-0.2.1 → birder-0.2.3}/birder/results/gui.py +15 -2
- {birder-0.2.1 → birder-0.2.3}/birder/scripts/benchmark.py +90 -21
- {birder-0.2.1 → birder-0.2.3}/birder/scripts/predict.py +1 -0
- {birder-0.2.1 → birder-0.2.3}/birder/scripts/predict_detection.py +48 -9
- {birder-0.2.1 → birder-0.2.3}/birder/scripts/train.py +33 -50
- {birder-0.2.1 → birder-0.2.3}/birder/scripts/train_barlow_twins.py +19 -40
- {birder-0.2.1 → birder-0.2.3}/birder/scripts/train_byol.py +19 -40
- {birder-0.2.1 → birder-0.2.3}/birder/scripts/train_capi.py +21 -43
- {birder-0.2.1 → birder-0.2.3}/birder/scripts/train_data2vec.py +18 -40
- {birder-0.2.1 → birder-0.2.3}/birder/scripts/train_data2vec2.py +18 -40
- {birder-0.2.1 → birder-0.2.3}/birder/scripts/train_detection.py +89 -57
- {birder-0.2.1 → birder-0.2.3}/birder/scripts/train_dino_v1.py +19 -40
- {birder-0.2.1 → birder-0.2.3}/birder/scripts/train_dino_v2.py +18 -40
- {birder-0.2.1 → birder-0.2.3}/birder/scripts/train_dino_v2_dist.py +25 -40
- {birder-0.2.1 → birder-0.2.3}/birder/scripts/train_franca.py +18 -40
- {birder-0.2.1 → birder-0.2.3}/birder/scripts/train_i_jepa.py +25 -46
- {birder-0.2.1 → birder-0.2.3}/birder/scripts/train_ibot.py +18 -40
- {birder-0.2.1 → birder-0.2.3}/birder/scripts/train_kd.py +179 -81
- {birder-0.2.1 → birder-0.2.3}/birder/scripts/train_mim.py +20 -43
- {birder-0.2.1 → birder-0.2.3}/birder/scripts/train_mmcr.py +19 -40
- {birder-0.2.1 → birder-0.2.3}/birder/scripts/train_rotnet.py +19 -40
- {birder-0.2.1 → birder-0.2.3}/birder/scripts/train_simclr.py +19 -40
- {birder-0.2.1 → birder-0.2.3}/birder/scripts/train_vicreg.py +19 -40
- {birder-0.2.1 → birder-0.2.3}/birder/tools/__main__.py +6 -2
- birder-0.2.3/birder/tools/adversarial.py +214 -0
- birder-0.2.3/birder/tools/auto_anchors.py +380 -0
- {birder-0.2.1 → birder-0.2.3}/birder/tools/ensemble_model.py +1 -1
- {birder-0.2.1 → birder-0.2.3}/birder/tools/introspection.py +58 -31
- {birder-0.2.1 → birder-0.2.3}/birder/tools/pack.py +172 -103
- {birder-0.2.1 → birder-0.2.3}/birder/tools/show_det_iterator.py +10 -1
- birder-0.2.3/birder/version.py +1 -0
- {birder-0.2.1 → birder-0.2.3}/birder.egg-info/PKG-INFO +4 -3
- {birder-0.2.1 → birder-0.2.3}/birder.egg-info/SOURCES.txt +13 -0
- {birder-0.2.1 → birder-0.2.3}/birder.egg-info/requires.txt +2 -1
- {birder-0.2.1 → birder-0.2.3}/requirements/_requirements-dev.txt +2 -1
- birder-0.2.3/tests/test_adversarial.py +238 -0
- {birder-0.2.1 → birder-0.2.3}/tests/test_common.py +202 -14
- {birder-0.2.1 → birder-0.2.3}/tests/test_inference.py +69 -0
- birder-0.2.3/tests/test_introspection.py +310 -0
- {birder-0.2.1 → birder-0.2.3}/tests/test_kernels.py +13 -0
- {birder-0.2.1 → birder-0.2.3}/tests/test_model_registry.py +2 -2
- {birder-0.2.1 → birder-0.2.3}/tests/test_net.py +237 -176
- {birder-0.2.1 → birder-0.2.3}/tests/test_net_detection.py +44 -0
- {birder-0.2.1 → birder-0.2.3}/tests/test_transforms.py +9 -0
- birder-0.2.1/birder/adversarial/fgsm.py +0 -34
- birder-0.2.1/birder/adversarial/pgd.py +0 -54
- birder-0.2.1/birder/introspection/__init__.py +0 -9
- birder-0.2.1/birder/introspection/attention_rollout.py +0 -117
- birder-0.2.1/birder/introspection/base.py +0 -60
- birder-0.2.1/birder/introspection/gradcam.py +0 -176
- birder-0.2.1/birder/introspection/guided_backprop.py +0 -155
- birder-0.2.1/birder/tools/__init__.py +0 -0
- birder-0.2.1/birder/tools/adversarial.py +0 -163
- birder-0.2.1/birder/version.py +0 -1
- {birder-0.2.1 → birder-0.2.3}/LICENSE +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/__init__.py +0 -0
- {birder-0.2.1/birder/adversarial → birder-0.2.3/birder/common}/__init__.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/common/cli.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/common/fs_ops.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/common/masking.py +0 -0
- {birder-0.2.1/birder/common → birder-0.2.3/birder/conf}/__init__.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/conf/settings.py +0 -0
- {birder-0.2.1/birder/conf → birder-0.2.3/birder/data}/__init__.py +0 -0
- {birder-0.2.1/birder/data → birder-0.2.3/birder/data/collators}/__init__.py +0 -0
- {birder-0.2.1/birder/data/collators → birder-0.2.3/birder/data/dataloader}/__init__.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/data/dataloader/webdataset.py +0 -0
- {birder-0.2.1/birder/data/dataloader → birder-0.2.3/birder/data/datasets}/__init__.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/data/datasets/directory.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/data/datasets/fake.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/data/datasets/webdataset.py +0 -0
- {birder-0.2.1/birder/data/datasets → birder-0.2.3/birder/data/transforms}/__init__.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/data/transforms/classification.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/data/transforms/mosaic.py +0 -0
- {birder-0.2.1/birder/data/transforms → birder-0.2.3/birder/datahub}/__init__.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/datahub/_lib.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/datahub/classification.py +0 -0
- {birder-0.2.1/birder/datahub → birder-0.2.3/birder/inference}/__init__.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/inference/classification.py +0 -0
- {birder-0.2.1/birder/inference → birder-0.2.3/birder/kernels}/__init__.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/kernels/deformable_detr/cpu/ms_deform_attn_cpu.cpp +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/kernels/deformable_detr/cpu/ms_deform_attn_cpu.h +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/kernels/deformable_detr/cuda/ms_deform_attn_cuda.cu +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/kernels/deformable_detr/cuda/ms_deform_attn_cuda.h +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/kernels/deformable_detr/cuda/ms_deform_im2col_cuda.cuh +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/kernels/deformable_detr/ms_deform_attn.h +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/kernels/deformable_detr/vision.cpp +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/kernels/load_kernel.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/kernels/soft_nms/op.cpp +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/kernels/soft_nms/soft_nms.cpp +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/kernels/soft_nms/soft_nms.h +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/kernels/transnext/av_bw_kernel.cu +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/kernels/transnext/av_fw_kernel.cu +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/kernels/transnext/qk_bw_kernel.cu +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/kernels/transnext/qk_fw_kernel.cu +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/kernels/transnext/qk_rpb_bw_kernel.cu +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/kernels/transnext/qk_rpb_fw_kernel.cu +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/kernels/transnext/swattention.cpp +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/layers/__init__.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/layers/activations.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/layers/attention_pool.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/layers/ffn.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/layers/gem.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/layers/layer_norm.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/layers/layer_scale.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/model_registry/__init__.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/model_registry/manifest.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/model_registry/model_registry.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/alexnet.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/base.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/biformer.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/cait.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/cas_vit.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/coat.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/conv2former.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/convmixer.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/convnext_v1.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/convnext_v2.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/crossformer.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/crossvit.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/cspnet.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/cswin_transformer.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/darknet.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/davit.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/deit.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/deit3.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/densenet.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/detection/__init__.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/detection/base.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/detection/faster_rcnn.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/detection/fcos.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/detection/retinanet.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/detection/ssd.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/detection/ssdlite.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/detection/vitdet.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/dpn.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/edgenext.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/edgevit.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/efficientformer_v1.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/efficientformer_v2.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/efficientnet_lite.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/efficientnet_v1.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/efficientnet_v2.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/efficientvim.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/efficientvit_mit.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/efficientvit_msft.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/fastvit.py +1 -1
- {birder-0.2.1 → birder-0.2.3}/birder/net/flexivit.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/focalnet.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/ghostnet_v1.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/ghostnet_v2.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/groupmixformer.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/hgnet_v1.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/hgnet_v2.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/hiera.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/hieradet.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/hornet.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/iformer.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/inception_next.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/inception_resnet_v1.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/inception_resnet_v2.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/inception_v3.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/inception_v4.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/levit.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/maxvit.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/metaformer.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/mim/__init__.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/mim/base.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/mim/crossmae.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/mim/fcmae.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/mim/mae_hiera.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/mim/simmim.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/mnasnet.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/mobilenet_v1.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/mobilenet_v2.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/mobilenet_v3_large.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/mobilenet_v3_small.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/mobilenet_v4.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/mobileone.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/mobilevit_v1.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/mobilevit_v2.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/moganet.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/mvit_v2.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/nextvit.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/nfnet.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/pvt_v1.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/pvt_v2.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/rdnet.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/regionvit.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/regnet.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/regnet_z.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/repghost.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/repvgg.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/repvit.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/resmlp.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/resnest.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/resnet_v2.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/rope_deit3.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/rope_flexivit.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/rope_vit.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/se_resnet_v2.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/sequencer2d.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/shufflenet_v1.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/shufflenet_v2.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/smt.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/squeezenet.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/squeezenext.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/ssl/__init__.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/ssl/barlow_twins.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/ssl/base.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/ssl/byol.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/ssl/capi.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/ssl/dino_v1.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/ssl/dino_v2.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/ssl/franca.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/ssl/i_jepa.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/ssl/ibot.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/ssl/mmcr.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/ssl/simclr.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/ssl/sscd.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/ssl/vicreg.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/starnet.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/swiftformer.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/swin_transformer_v1.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/swin_transformer_v2.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/tiny_vit.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/transnext.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/uniformer.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/van.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/vgg.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/vgg_reduced.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/vit_parallel.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/vit_sam.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/vovnet_v1.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/wide_resnet.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/xception.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/net/xcit.py +0 -0
- {birder-0.2.1/birder/kernels → birder-0.2.3/birder/ops}/__init__.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/ops/msda.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/ops/soft_nms.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/ops/swattention.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/optim/__init__.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/optim/lamb.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/optim/lars.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/py.typed +0 -0
- {birder-0.2.1/birder/ops → birder-0.2.3/birder/results}/__init__.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/results/classification.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/results/detection.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/scheduler/__init__.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/scheduler/cooldown.py +0 -0
- {birder-0.2.1/birder/results → birder-0.2.3/birder/scripts}/__init__.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/scripts/__main__.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/scripts/evaluate.py +0 -0
- {birder-0.2.1/birder/scripts → birder-0.2.3/birder/tools}/__init__.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/tools/avg_model.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/tools/convert_model.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/tools/det_results.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/tools/download_model.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/tools/labelme_to_coco.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/tools/list_models.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/tools/model_info.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/tools/quantize_model.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/tools/results.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/tools/show_iterator.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/tools/similarity.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/tools/stats.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/tools/verify_coco.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/tools/verify_directory.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder/tools/voc_to_coco.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder.egg-info/dependency_links.txt +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder.egg-info/entry_points.txt +0 -0
- {birder-0.2.1 → birder-0.2.3}/birder.egg-info/top_level.txt +0 -0
- {birder-0.2.1 → birder-0.2.3}/pyproject.toml +0 -0
- {birder-0.2.1 → birder-0.2.3}/requirements/requirements-hf.txt +0 -0
- {birder-0.2.1 → birder-0.2.3}/requirements/requirements.txt +0 -0
- {birder-0.2.1 → birder-0.2.3}/setup.cfg +0 -0
- {birder-0.2.1 → birder-0.2.3}/tests/test_collators.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/tests/test_datasets.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/tests/test_layers.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/tests/test_net_mim.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/tests/test_net_ssl.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/tests/test_ops.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/tests/test_optim.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/tests/test_results.py +0 -0
- {birder-0.2.1 → birder-0.2.3}/tests/test_scheduler.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: birder
|
|
3
|
-
Version: 0.2.
|
|
3
|
+
Version: 0.2.3
|
|
4
4
|
Summary: An open-source computer vision framework for wildlife image analysis, featuring state-of-the-art models for species classification and detection.
|
|
5
5
|
Author: Ofer Hasson
|
|
6
6
|
License-Expression: Apache-2.0
|
|
@@ -45,7 +45,7 @@ Provides-Extra: dev
|
|
|
45
45
|
Requires-Dist: altair~=5.5.0; extra == "dev"
|
|
46
46
|
Requires-Dist: bandit~=1.9.2; extra == "dev"
|
|
47
47
|
Requires-Dist: black~=25.12.0; extra == "dev"
|
|
48
|
-
Requires-Dist: build~=1.
|
|
48
|
+
Requires-Dist: build~=1.4.0; extra == "dev"
|
|
49
49
|
Requires-Dist: bumpver~=2025.1131; extra == "dev"
|
|
50
50
|
Requires-Dist: captum~=0.7.0; extra == "dev"
|
|
51
51
|
Requires-Dist: coverage~=7.13.1; extra == "dev"
|
|
@@ -62,6 +62,7 @@ Requires-Dist: MonkeyType~=23.3.0; extra == "dev"
|
|
|
62
62
|
Requires-Dist: mypy~=1.19.1; extra == "dev"
|
|
63
63
|
Requires-Dist: parameterized~=0.9.0; extra == "dev"
|
|
64
64
|
Requires-Dist: pylint~=4.0.4; extra == "dev"
|
|
65
|
+
Requires-Dist: pytest; extra == "dev"
|
|
65
66
|
Requires-Dist: requests~=2.32.5; extra == "dev"
|
|
66
67
|
Requires-Dist: safetensors~=0.7.0; extra == "dev"
|
|
67
68
|
Requires-Dist: setuptools; extra == "dev"
|
|
@@ -207,7 +208,7 @@ For detailed information about these datasets, including descriptions, citations
|
|
|
207
208
|
|
|
208
209
|
## Detection
|
|
209
210
|
|
|
210
|
-
Detection training and inference are available, see [docs/
|
|
211
|
+
Detection training and inference are available, see [docs/training_scripts.md](docs/training_scripts.md) and
|
|
211
212
|
[docs/inference.md](docs/inference.md). APIs and model coverage may evolve as detection support matures.
|
|
212
213
|
|
|
213
214
|
## Project Status and Contributions
|
|
@@ -129,7 +129,7 @@ For detailed information about these datasets, including descriptions, citations
|
|
|
129
129
|
|
|
130
130
|
## Detection
|
|
131
131
|
|
|
132
|
-
Detection training and inference are available, see [docs/
|
|
132
|
+
Detection training and inference are available, see [docs/training_scripts.md](docs/training_scripts.md) and
|
|
133
133
|
[docs/inference.md](docs/inference.md). APIs and model coverage may evolve as detection support matures.
|
|
134
134
|
|
|
135
135
|
## Project Status and Contributions
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
from birder.adversarial.base import AttackResult
|
|
2
|
+
from birder.adversarial.deepfool import DeepFool
|
|
3
|
+
from birder.adversarial.fgsm import FGSM
|
|
4
|
+
from birder.adversarial.pgd import PGD
|
|
5
|
+
from birder.adversarial.simba import SimBA
|
|
6
|
+
|
|
7
|
+
__all__ = [
|
|
8
|
+
"AttackResult",
|
|
9
|
+
"DeepFool",
|
|
10
|
+
"FGSM",
|
|
11
|
+
"PGD",
|
|
12
|
+
"SimBA",
|
|
13
|
+
]
|
|
@@ -0,0 +1,101 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from typing import Optional
|
|
3
|
+
from typing import Protocol
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
|
|
7
|
+
from birder.data.transforms.classification import RGBType
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@dataclass(frozen=True)
|
|
11
|
+
class AttackResult:
|
|
12
|
+
adv_inputs: torch.Tensor
|
|
13
|
+
adv_logits: torch.Tensor
|
|
14
|
+
perturbation: torch.Tensor
|
|
15
|
+
logits: Optional[torch.Tensor] = None
|
|
16
|
+
success: Optional[torch.Tensor] = None
|
|
17
|
+
num_queries: Optional[int] = None
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class Attack(Protocol):
|
|
21
|
+
def __call__(self, input_tensor: torch.Tensor, target: Optional[torch.Tensor]) -> AttackResult: ...
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _to_channel_tensor(
|
|
25
|
+
values: tuple[float, float, float], device: Optional[torch.device], dtype: Optional[torch.dtype]
|
|
26
|
+
) -> torch.Tensor:
|
|
27
|
+
return torch.tensor(values, device=device, dtype=dtype).view(1, -1, 1, 1)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def normalized_bounds(
|
|
31
|
+
rgb_stats: RGBType, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None
|
|
32
|
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
33
|
+
mean = _to_channel_tensor(rgb_stats["mean"], device=device, dtype=dtype)
|
|
34
|
+
std = _to_channel_tensor(rgb_stats["std"], device=device, dtype=dtype)
|
|
35
|
+
min_val = (0.0 - mean) / std
|
|
36
|
+
max_val = (1.0 - mean) / std
|
|
37
|
+
|
|
38
|
+
return (min_val, max_val)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def pixel_eps_to_normalized(
|
|
42
|
+
eps: float | torch.Tensor,
|
|
43
|
+
rgb_stats: RGBType,
|
|
44
|
+
device: Optional[torch.device] = None,
|
|
45
|
+
dtype: Optional[torch.dtype] = None,
|
|
46
|
+
) -> torch.Tensor:
|
|
47
|
+
eps_tensor = torch.as_tensor(eps, device=device, dtype=dtype)
|
|
48
|
+
std = _to_channel_tensor(rgb_stats["std"], device=eps_tensor.device, dtype=eps_tensor.dtype)
|
|
49
|
+
|
|
50
|
+
if eps_tensor.numel() == 1:
|
|
51
|
+
eps_tensor = eps_tensor.reshape(1, 1, 1, 1)
|
|
52
|
+
else:
|
|
53
|
+
eps_tensor = eps_tensor.reshape(1, -1, 1, 1)
|
|
54
|
+
|
|
55
|
+
return eps_tensor / std
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def clamp_normalized(inputs: torch.Tensor, rgb_stats: RGBType) -> torch.Tensor:
|
|
59
|
+
(min_val, max_val) = normalized_bounds(rgb_stats, device=inputs.device, dtype=inputs.dtype)
|
|
60
|
+
return torch.clamp(inputs, min=min_val, max=max_val)
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def predict_labels(logits: torch.Tensor) -> torch.Tensor:
|
|
64
|
+
return torch.argmax(logits, dim=1)
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def validate_target(
|
|
68
|
+
target: Optional[torch.Tensor], batch_size: int, num_classes: int, device: torch.device
|
|
69
|
+
) -> Optional[torch.Tensor]:
|
|
70
|
+
if target is None:
|
|
71
|
+
return None
|
|
72
|
+
|
|
73
|
+
target = target.to(device=device, dtype=torch.long)
|
|
74
|
+
if target.ndim == 0:
|
|
75
|
+
target = target.view(1)
|
|
76
|
+
|
|
77
|
+
if target.shape[0] != batch_size:
|
|
78
|
+
raise ValueError(f"Target shape {target.shape[0]} must match batch size {batch_size}")
|
|
79
|
+
|
|
80
|
+
if torch.any(target < 0) or torch.any(target >= num_classes):
|
|
81
|
+
raise ValueError(f"Target values must be in range [0, {num_classes})")
|
|
82
|
+
|
|
83
|
+
return target
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def attack_success(
|
|
87
|
+
logits: torch.Tensor,
|
|
88
|
+
adv_logits: torch.Tensor,
|
|
89
|
+
targeted: bool,
|
|
90
|
+
target: Optional[torch.Tensor] = None,
|
|
91
|
+
labels: Optional[torch.Tensor] = None,
|
|
92
|
+
) -> torch.Tensor:
|
|
93
|
+
adv_pred = predict_labels(adv_logits)
|
|
94
|
+
if targeted is True:
|
|
95
|
+
if target is None:
|
|
96
|
+
raise ValueError("Target labels required for targeted attacks")
|
|
97
|
+
|
|
98
|
+
return adv_pred.eq(target)
|
|
99
|
+
|
|
100
|
+
base_labels = labels if labels is not None else predict_labels(logits)
|
|
101
|
+
return adv_pred.ne(base_labels)
|
|
@@ -0,0 +1,173 @@
|
|
|
1
|
+
"""
|
|
2
|
+
DeepFool
|
|
3
|
+
|
|
4
|
+
Paper "DeepFool: a simple and accurate method to fool deep neural networks", https://arxiv.org/abs/1511.04599
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from typing import Optional
|
|
8
|
+
|
|
9
|
+
import torch
|
|
10
|
+
from torch import nn
|
|
11
|
+
|
|
12
|
+
from birder.adversarial.base import AttackResult
|
|
13
|
+
from birder.adversarial.base import attack_success
|
|
14
|
+
from birder.adversarial.base import clamp_normalized
|
|
15
|
+
from birder.adversarial.base import predict_labels
|
|
16
|
+
from birder.adversarial.base import validate_target
|
|
17
|
+
from birder.data.transforms.classification import RGBType
|
|
18
|
+
|
|
19
|
+
GRAD_EPS = 1e-12
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class DeepFool:
    """DeepFool adversarial attack.

    Iteratively linearizes the classifier around the current point and takes the
    minimal step that crosses the nearest (untargeted) or the target (targeted)
    decision boundary, overshooting slightly to make sure the boundary is crossed.
    """

    def __init__(
        self, net: nn.Module, num_classes: int = 10, overshoot: float = 0.02, max_iter: int = 50, *, rgb_stats: RGBType
    ) -> None:
        """Initialize the attack.

        Parameters
        ----------
        net
            Classification model, switched to eval mode.
        num_classes
            Number of top-scoring classes considered as candidate boundaries in
            the untargeted search.
        overshoot
            Each step is scaled by (1 + overshoot) to help cross the boundary
            instead of landing exactly on it.
        max_iter
            Maximum number of linearization iterations per sample.
        rgb_stats
            Normalization statistics used to clamp adversarial inputs to the
            valid (normalized) image range.

        Raises
        ------
        ValueError
            If num_classes < 2, max_iter <= 0 or overshoot < 0.
        """
        if num_classes < 2:
            raise ValueError("num_classes must be at least 2")
        if max_iter <= 0:
            raise ValueError("max_iter must be positive")
        if overshoot < 0:
            raise ValueError("overshoot must be non-negative")

        self.net = net.eval()
        self.num_classes = num_classes
        self.overshoot = overshoot
        self.max_iter = max_iter
        self.rgb_stats = rgb_stats

    def __call__(self, input_tensor: torch.Tensor, target: Optional[torch.Tensor]) -> AttackResult:
        """Attack a batch of inputs.

        When target is None the attack is untargeted, otherwise each sample is
        pushed toward its corresponding target class. Samples are processed one
        at a time because DeepFool is inherently per-sample.
        """
        inputs = input_tensor.detach()
        with torch.no_grad():
            logits = self.net(inputs)

        target_labels = (
            validate_target(target, inputs.shape[0], logits.shape[1], inputs.device) if target is not None else None
        )
        targeted = target_labels is not None

        adv_inputs_list = []
        for idx in range(inputs.size(0)):
            target_label = target_labels[idx : idx + 1] if target_labels is not None else None
            adv_input = self._attack_single(inputs[idx : idx + 1], logits[idx : idx + 1], target_label)
            adv_inputs_list.append(adv_input)

        adv_inputs = torch.concat(adv_inputs_list, dim=0)
        with torch.no_grad():
            adv_logits = self.net(adv_inputs)

        success = attack_success(
            logits,
            adv_logits,
            targeted,
            target=target_labels if targeted else None,
        )

        return AttackResult(
            adv_inputs=adv_inputs,
            adv_logits=adv_logits,
            perturbation=adv_inputs - inputs,
            logits=logits.detach(),
            success=success,
        )

    def _attack_single(
        self, inputs: torch.Tensor, logits: torch.Tensor, target_label: Optional[torch.Tensor]
    ) -> torch.Tensor:
        """Run the iterative DeepFool update on a single sample (batch of 1)."""
        adv_inputs = inputs.clone()
        original_label = int(predict_labels(logits).item())
        targeted = target_label is not None
        for _ in range(self.max_iter):
            adv_inputs.requires_grad_(True)
            outputs = self.net(adv_inputs)
            current_label = int(predict_labels(outputs).item())

            if targeted is True:
                assert target_label is not None
                target_value = int(target_label.item())
                # Stop as soon as the sample is classified as the target
                if current_label == target_value:
                    break

                perturbation = self._targeted_perturbation(adv_inputs, outputs, current_label, target_value)

            else:
                # Stop as soon as the label flips away from the original
                if current_label != original_label:
                    break

                perturbation = self._untargeted_perturbation(adv_inputs, outputs, current_label)

            # All candidate gradients were degenerate (see GRAD_EPS checks below)
            if perturbation is None:
                break

            # Overshoot helps ensure boundary crossing
            # NOTE(review): perturbation is derived from `outputs` and still requires grad,
            # so adv_inputs briefly carries a graph until the final detach - harmless here
            # since no backward pass is ever run through it, but worth confirming.
            adv_inputs = adv_inputs.detach() + (1.0 + self.overshoot) * perturbation
            adv_inputs = clamp_normalized(adv_inputs, self.rgb_stats)

        return adv_inputs.detach()

    def _targeted_perturbation(
        self, adv_inputs: torch.Tensor, outputs: torch.Tensor, current_label: int, target_label: int
    ) -> Optional[torch.Tensor]:
        """Compute the minimal linearized step from the current class toward the target class.

        Returns None when the boundary direction is numerically degenerate.
        """
        self.net.zero_grad(set_to_none=True)
        grad_current = torch.autograd.grad(outputs[0, current_label], adv_inputs, retain_graph=True)[0]
        grad_target = torch.autograd.grad(outputs[0, target_label], adv_inputs, retain_graph=False)[0]

        # Direction toward the target boundary
        w = grad_target - grad_current
        w_norm = torch.norm(w.view(-1))
        if w_norm.item() < GRAD_EPS:
            return None

        # Distance to the decision boundary
        f = outputs[0, target_label] - outputs[0, current_label]
        perturbation = (f.abs() / (w_norm**2 + GRAD_EPS)) * w

        return perturbation

    def _untargeted_perturbation(
        self, adv_inputs: torch.Tensor, outputs: torch.Tensor, current_label: int
    ) -> Optional[torch.Tensor]:
        """Compute the minimal linearized step toward the closest competing class boundary.

        Returns None when no usable candidate boundary is found.
        """
        # Search the top-k competing classes
        top_k = min(self.num_classes, outputs.shape[1])
        top_indices = torch.topk(outputs, k=top_k, dim=1).indices[0]
        candidate_labels = [int(idx) for idx in top_indices if int(idx) != current_label]

        if len(candidate_labels) == 0:
            return None

        self.net.zero_grad(set_to_none=True)
        grad_current = torch.autograd.grad(outputs[0, current_label], adv_inputs, retain_graph=True)[0]

        # Track the closest decision boundary
        best_dist = None
        best_w = None
        best_f = None
        for idx, label in enumerate(candidate_labels):
            # Keep the graph until the last class
            retain_graph = idx != len(candidate_labels) - 1
            grad_other = torch.autograd.grad(outputs[0, label], adv_inputs, retain_graph=retain_graph)[0]

            w_k = grad_other - grad_current
            w_norm = torch.norm(w_k.view(-1))
            if w_norm.item() < GRAD_EPS:
                continue

            f_k = outputs[0, label] - outputs[0, current_label]
            dist = f_k.abs() / (w_norm + GRAD_EPS)

            if best_dist is None or dist < best_dist:
                best_dist = dist
                best_w = w_k
                best_f = f_k

        if best_w is None or best_f is None:
            return None

        # Minimal perturbation toward the closest boundary
        best_w_norm = torch.norm(best_w.view(-1))
        if best_w_norm.item() < GRAD_EPS:
            return None

        perturbation = (best_f.abs() / (best_w_norm**2 + GRAD_EPS)) * best_w

        return perturbation
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Fast Gradient Sign Method (FGSM)
|
|
3
|
+
|
|
4
|
+
Paper "Explaining and Harnessing Adversarial Examples", https://arxiv.org/abs/1412.6572
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from typing import Optional
|
|
8
|
+
|
|
9
|
+
import torch
|
|
10
|
+
import torch.nn.functional as F
|
|
11
|
+
from torch import nn
|
|
12
|
+
|
|
13
|
+
from birder.adversarial.base import AttackResult
|
|
14
|
+
from birder.adversarial.base import attack_success
|
|
15
|
+
from birder.adversarial.base import clamp_normalized
|
|
16
|
+
from birder.adversarial.base import pixel_eps_to_normalized
|
|
17
|
+
from birder.adversarial.base import predict_labels
|
|
18
|
+
from birder.adversarial.base import validate_target
|
|
19
|
+
from birder.data.transforms.classification import RGBType
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class FGSM:
    """Fast Gradient Sign Method attack.

    Single-step attack: take one signed-gradient step of magnitude `eps`
    (specified in pixel space and converted per channel to the normalized
    input space) away from the original prediction, or toward the target
    class when a target is given.
    """

    def __init__(self, net: nn.Module, eps: float, *, rgb_stats: RGBType) -> None:
        """Initialize the attack.

        Parameters
        ----------
        net
            Classification model, switched to eval mode.
        eps
            Step magnitude in pixel space (converted via pixel_eps_to_normalized).
        rgb_stats
            Normalization statistics used to clamp adversarial inputs to the
            valid (normalized) image range.

        Raises
        ------
        ValueError
            If eps is negative.
        """
        if eps < 0:
            raise ValueError("eps must be non-negative")

        self.net = net.eval()
        self.eps = eps
        self.rgb_stats = rgb_stats

    def __call__(self, input_tensor: torch.Tensor, target: Optional[torch.Tensor]) -> AttackResult:
        """Attack a batch of inputs.

        When target is None the attack is untargeted (ascend the loss of the
        current prediction), otherwise descend the loss toward the target class.
        """
        inputs = input_tensor.detach().clone()
        inputs.requires_grad_(True)

        logits = self.net(inputs)
        targeted = target is not None
        if targeted is True:
            target = validate_target(target, inputs.shape[0], logits.shape[1], inputs.device)
        else:
            # Untargeted: use the model's own prediction as the anchor label
            target = predict_labels(logits)

        loss = F.cross_entropy(logits, target)
        (grad,) = torch.autograd.grad(loss, inputs, retain_graph=False, create_graph=False)
        eps_norm = pixel_eps_to_normalized(self.eps, self.rgb_stats, device=inputs.device, dtype=inputs.dtype)

        # Targeted steps descend toward target, untargeted ascend away from original
        if targeted is True:
            direction = -1.0
        else:
            direction = 1.0

        perturbation = direction * eps_norm * grad.sign()
        # Detach so the returned tensors do not keep the forward autograd graph
        # alive (matches the detached results of the other attacks in this package)
        adv_inputs = clamp_normalized(inputs + perturbation, self.rgb_stats).detach()
        with torch.no_grad():
            adv_logits = self.net(adv_inputs)

        success = attack_success(
            logits.detach(),
            adv_logits,
            targeted,
            target=target if targeted else None,
        )

        return AttackResult(
            adv_inputs=adv_inputs,
            adv_logits=adv_logits,
            perturbation=adv_inputs - inputs.detach(),
            logits=logits.detach(),
            success=success,
        )
|
|
@@ -0,0 +1,105 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Projected Gradient Descent (PGD)
|
|
3
|
+
|
|
4
|
+
Paper "Towards Deep Learning Models Resistant to Adversarial Attacks", https://arxiv.org/abs/1706.06083
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
# Reference license: MIT
|
|
8
|
+
|
|
9
|
+
from typing import Optional
|
|
10
|
+
|
|
11
|
+
import torch
|
|
12
|
+
import torch.nn.functional as F
|
|
13
|
+
from torch import nn
|
|
14
|
+
|
|
15
|
+
from birder.adversarial.base import AttackResult
|
|
16
|
+
from birder.adversarial.base import attack_success
|
|
17
|
+
from birder.adversarial.base import clamp_normalized
|
|
18
|
+
from birder.adversarial.base import pixel_eps_to_normalized
|
|
19
|
+
from birder.adversarial.base import predict_labels
|
|
20
|
+
from birder.adversarial.base import validate_target
|
|
21
|
+
from birder.data.transforms.classification import RGBType
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class PGD:
    """Projected Gradient Descent attack.

    Iterative FGSM-style attack: take `steps` signed-gradient steps of size
    `step_size` (pixel space) and project the running perturbation back into
    the L-inf ball of radius `eps` around the original input after each step.
    """

    def __init__(
        self,
        net: nn.Module,
        eps: float,
        steps: int = 10,
        step_size: Optional[float] = None,
        random_start: bool = False,
        *,
        rgb_stats: RGBType,
    ) -> None:
        """Initialize the attack, deriving a default step size of eps / steps.

        Raises ValueError when steps or the (derived) step size is not positive.
        """
        if steps <= 0:
            raise ValueError("steps must be a positive integer")

        self.net = net.eval()
        self.eps = eps
        self.steps = steps
        # Default step size spreads the full epsilon budget evenly over the steps
        self.step_size = step_size if step_size is not None else eps / steps
        self.random_start = random_start
        self.rgb_stats = rgb_stats

        if self.step_size <= 0:
            raise ValueError("step_size must be positive")

    def __call__(self, input_tensor: torch.Tensor, target: Optional[torch.Tensor]) -> AttackResult:
        """Attack a batch of inputs (targeted when target is not None)."""
        x_orig = input_tensor.detach()
        with torch.no_grad():
            clean_logits = self.net(x_orig)

        targeted = target is not None
        if targeted:
            attack_target = validate_target(target, x_orig.shape[0], clean_logits.shape[1], x_orig.device)
        else:
            attack_target = predict_labels(clean_logits)

        eps_n = pixel_eps_to_normalized(self.eps, self.rgb_stats, device=x_orig.device, dtype=x_orig.dtype)
        step_n = pixel_eps_to_normalized(self.step_size, self.rgb_stats, device=x_orig.device, dtype=x_orig.dtype)

        # Targeted steps descend toward target, untargeted ascend away from original
        sign_factor = -1.0 if targeted else 1.0

        adv = x_orig.clone()
        if self.random_start:
            # Random start inside the epsilon ball
            adv = adv + torch.empty_like(adv).uniform_(-1.0, 1.0) * eps_n
            adv = clamp_normalized(adv, self.rgb_stats)

        for _ in range(self.steps):
            adv.requires_grad_(True)
            step_logits = self.net(adv)
            step_loss = F.cross_entropy(step_logits, attack_target)
            (grad,) = torch.autograd.grad(step_loss, adv, retain_graph=False, create_graph=False)
            adv = adv.detach() + sign_factor * step_n * grad.sign()

            # Project back into the epsilon ball around the original input.
            adv = clamp_normalized(x_orig + torch.clamp(adv - x_orig, min=-eps_n, max=eps_n), self.rgb_stats)

        with torch.no_grad():
            adv_logits = self.net(adv)

        success = attack_success(
            clean_logits.detach(),
            adv_logits,
            targeted,
            target=attack_target if targeted else None,
        )

        return AttackResult(
            adv_inputs=adv,
            adv_logits=adv_logits,
            perturbation=adv - x_orig,
            logits=clean_logits.detach(),
            success=success,
        )
|
|
@@ -0,0 +1,172 @@
|
|
|
1
|
+
"""
|
|
2
|
+
SimBA (Simple Black-box Attack)
|
|
3
|
+
|
|
4
|
+
Paper "Simple Black-box Adversarial Attacks", https://arxiv.org/abs/1905.07121
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from typing import Optional
|
|
8
|
+
|
|
9
|
+
import torch
|
|
10
|
+
import torch.nn.functional as F
|
|
11
|
+
from torch import nn
|
|
12
|
+
|
|
13
|
+
from birder.adversarial.base import AttackResult
|
|
14
|
+
from birder.adversarial.base import attack_success
|
|
15
|
+
from birder.adversarial.base import clamp_normalized
|
|
16
|
+
from birder.adversarial.base import pixel_eps_to_normalized
|
|
17
|
+
from birder.adversarial.base import predict_labels
|
|
18
|
+
from birder.adversarial.base import validate_target
|
|
19
|
+
from birder.data.transforms.classification import RGBType
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class SimBA:
    """Simple Black-box Attack (SimBA).

    Query-based attack that perturbs one randomly chosen pixel coordinate at a
    time (trying both +step and -step), keeping a change only when it lowers
    the attack objective. Requires no gradients, only forward passes.
    """

    def __init__(self, net: nn.Module, step_size: float, max_iter: int = 1000, *, rgb_stats: RGBType) -> None:
        """Initialize the attack.

        Raises ValueError when step_size or max_iter is not positive.
        """
        if step_size <= 0:
            raise ValueError("step_size must be positive")
        if max_iter <= 0:
            raise ValueError("max_iter must be positive")

        self.net = net.eval()
        self.step_size = step_size
        self.max_iter = max_iter
        self.rgb_stats = rgb_stats

    def __call__(self, input_tensor: torch.Tensor, target: Optional[torch.Tensor]) -> AttackResult:
        """Attack a batch of inputs, one sample at a time, tracking query count."""
        clean = input_tensor.detach()
        with torch.no_grad():
            clean_logits = self.net(clean)

        pred_labels = predict_labels(clean_logits)
        if target is not None:
            target_labels = validate_target(target, clean.shape[0], clean_logits.shape[1], clean.device)
        else:
            target_labels = None

        targeted = target_labels is not None

        # SimBA is inherently per-image, so iterate over the batch
        per_sample = []
        query_count = 0
        for i in range(clean.size(0)):
            tgt = None if target_labels is None else target_labels[i : i + 1]
            (adv_i, queries_i) = self._attack_single(clean[i : i + 1], pred_labels[i : i + 1], tgt)
            per_sample.append(adv_i)
            query_count += queries_i

        adv_batch = torch.concat(per_sample, dim=0)
        with torch.no_grad():
            adv_logits = self.net(adv_batch)

        success = attack_success(
            clean_logits,
            adv_logits,
            targeted,
            target=target_labels if targeted else None,
        )

        return AttackResult(
            adv_inputs=adv_batch,
            adv_logits=adv_logits,
            perturbation=adv_batch - clean,
            logits=clean_logits.detach(),
            success=success,
            num_queries=query_count,
        )

    # pylint: disable=too-many-locals
    def _attack_single(
        self, inputs: torch.Tensor, label: torch.Tensor, target_label: Optional[torch.Tensor]
    ) -> tuple[torch.Tensor, int]:
        """Greedy coordinate descent over a random permutation of pixel coordinates.

        Returns the adversarial sample and the number of model queries spent.
        """
        adv = inputs.clone()
        queries = 1  # Baseline forward pass
        with torch.no_grad():
            logits_now = self.net(adv)
            objective_now = self._compute_objective(logits_now, label, target_label)

        # Nothing to do if the sample is already misclassified / at the target
        if self._is_successful(logits_now, label, target_label):
            return adv.detach(), queries

        (_, channels, height, width) = adv.shape
        total_coords = channels * height * width
        step = pixel_eps_to_normalized(self.step_size, self.rgb_stats, device=adv.device, dtype=adv.dtype)
        channel_steps = step.view(-1)  # One step magnitude per channel
        plane = height * width

        order = torch.randperm(total_coords, device=adv.device)
        budget = min(self.max_iter, total_coords)

        for coord in order[:budget]:
            # Decode flat index into (channel, row, column)
            (c, offset) = divmod(int(coord.item()), plane)
            (h, w) = divmod(offset, width)

            (cand, cand_logits, cand_objective) = self._best_candidate(
                adv, c, h, w, channel_steps[c], label, target_label
            )
            queries += 2

            # Keep the candidate only on strict improvement
            if cand_objective < objective_now:
                adv = cand
                logits_now = cand_logits
                objective_now = cand_objective

            if self._is_successful(logits_now, label, target_label):
                break

        return adv.detach(), queries

    def _perturb_pixel(
        self, inputs: torch.Tensor, channel: int, row: int, col: int, step: torch.Tensor
    ) -> torch.Tensor:
        """Return a copy of inputs with one pixel shifted by step, clamped to the valid range."""
        out = inputs.clone()
        out[0, channel, row, col] = out[0, channel, row, col] + step
        return clamp_normalized(out, self.rgb_stats)

    def _evaluate_candidate(
        self, inputs: torch.Tensor, label: torch.Tensor, target_label: Optional[torch.Tensor]
    ) -> tuple[torch.Tensor, float]:
        """Run one forward pass and score it, returning (logits, objective)."""
        with torch.no_grad():
            candidate_logits = self.net(inputs)

        return candidate_logits, self._compute_objective(candidate_logits, label, target_label)

    def _best_candidate(
        self,
        inputs: torch.Tensor,
        channel: int,
        row: int,
        col: int,
        step: torch.Tensor,
        label: torch.Tensor,
        target_label: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor, float]:
        """Try +step and -step on one pixel and return the lower-objective candidate."""
        plus = self._perturb_pixel(inputs, channel, row, col, step)
        (plus_logits, plus_obj) = self._evaluate_candidate(plus, label, target_label)

        minus = self._perturb_pixel(inputs, channel, row, col, -step)
        (minus_logits, minus_obj) = self._evaluate_candidate(minus, label, target_label)

        # Ties go to the positive direction
        if plus_obj <= minus_obj:
            return plus, plus_logits, plus_obj

        return minus, minus_logits, minus_obj

    @staticmethod
    def _compute_objective(
        logits: torch.Tensor, original_label: torch.Tensor, target_label: Optional[torch.Tensor]
    ) -> float:
        """Score logits so that lower is better in both targeted and untargeted modes."""
        if target_label is not None:
            # Targeted: minimize cross-entropy toward the target class
            return float(F.cross_entropy(logits, target_label).item())

        # Untargeted: maximize cross-entropy of the original label (minimize its negation)
        return -float(F.cross_entropy(logits, original_label).item())

    @staticmethod
    def _is_successful(
        logits: torch.Tensor, original_label: torch.Tensor, target_label: Optional[torch.Tensor]
    ) -> bool:
        """Success is hitting the target class (targeted) or leaving the original label (untargeted)."""
        prediction = predict_labels(logits)
        if target_label is None:
            return bool(prediction.ne(original_label).item())

        return bool(prediction.eq(target_label).item())
|