birder 0.2.1__tar.gz → 0.2.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (334) hide show
  1. {birder-0.2.1 → birder-0.2.3}/PKG-INFO +4 -3
  2. {birder-0.2.1 → birder-0.2.3}/README.md +1 -1
  3. birder-0.2.3/birder/adversarial/__init__.py +13 -0
  4. birder-0.2.3/birder/adversarial/base.py +101 -0
  5. birder-0.2.3/birder/adversarial/deepfool.py +173 -0
  6. birder-0.2.3/birder/adversarial/fgsm.py +67 -0
  7. birder-0.2.3/birder/adversarial/pgd.py +105 -0
  8. birder-0.2.3/birder/adversarial/simba.py +172 -0
  9. {birder-0.2.1 → birder-0.2.3}/birder/common/lib.py +2 -9
  10. {birder-0.2.1 → birder-0.2.3}/birder/common/training_cli.py +29 -3
  11. {birder-0.2.1 → birder-0.2.3}/birder/common/training_utils.py +141 -11
  12. {birder-0.2.1 → birder-0.2.3}/birder/data/collators/detection.py +10 -3
  13. {birder-0.2.1 → birder-0.2.3}/birder/data/datasets/coco.py +8 -10
  14. {birder-0.2.1 → birder-0.2.3}/birder/data/transforms/detection.py +30 -13
  15. {birder-0.2.1 → birder-0.2.3}/birder/inference/data_parallel.py +1 -2
  16. {birder-0.2.1 → birder-0.2.3}/birder/inference/detection.py +108 -4
  17. birder-0.2.3/birder/inference/wbf.py +226 -0
  18. birder-0.2.3/birder/introspection/__init__.py +13 -0
  19. birder-0.2.3/birder/introspection/attention_rollout.py +185 -0
  20. birder-0.2.3/birder/introspection/base.py +104 -0
  21. birder-0.2.3/birder/introspection/gradcam.py +147 -0
  22. birder-0.2.3/birder/introspection/guided_backprop.py +229 -0
  23. birder-0.2.3/birder/introspection/transformer_attribution.py +182 -0
  24. {birder-0.2.1 → birder-0.2.3}/birder/net/__init__.py +8 -0
  25. {birder-0.2.1 → birder-0.2.3}/birder/net/detection/deformable_detr.py +14 -12
  26. {birder-0.2.1 → birder-0.2.3}/birder/net/detection/detr.py +7 -3
  27. {birder-0.2.1 → birder-0.2.3}/birder/net/detection/efficientdet.py +65 -86
  28. {birder-0.2.1 → birder-0.2.3}/birder/net/detection/rt_detr_v1.py +4 -3
  29. birder-0.2.3/birder/net/detection/yolo_anchors.py +205 -0
  30. {birder-0.2.1 → birder-0.2.3}/birder/net/detection/yolo_v2.py +25 -24
  31. {birder-0.2.1 → birder-0.2.3}/birder/net/detection/yolo_v3.py +42 -48
  32. {birder-0.2.1 → birder-0.2.3}/birder/net/detection/yolo_v4.py +31 -40
  33. {birder-0.2.1 → birder-0.2.3}/birder/net/detection/yolo_v4_tiny.py +24 -20
  34. {birder-0.2.1 → birder-0.2.3}/birder/net/fasternet.py +1 -1
  35. birder-0.2.3/birder/net/gc_vit.py +671 -0
  36. birder-0.2.3/birder/net/lit_v1.py +472 -0
  37. birder-0.2.3/birder/net/lit_v1_tiny.py +342 -0
  38. birder-0.2.3/birder/net/lit_v2.py +436 -0
  39. {birder-0.2.1 → birder-0.2.3}/birder/net/mim/mae_vit.py +7 -8
  40. {birder-0.2.1 → birder-0.2.3}/birder/net/mobilenet_v4_hybrid.py +1 -1
  41. {birder-0.2.1 → birder-0.2.3}/birder/net/pit.py +1 -1
  42. {birder-0.2.1 → birder-0.2.3}/birder/net/resnet_v1.py +95 -35
  43. {birder-0.2.1 → birder-0.2.3}/birder/net/resnext.py +67 -25
  44. {birder-0.2.1 → birder-0.2.3}/birder/net/se_resnet_v1.py +46 -0
  45. {birder-0.2.1 → birder-0.2.3}/birder/net/se_resnext.py +3 -0
  46. {birder-0.2.1 → birder-0.2.3}/birder/net/simple_vit.py +2 -2
  47. {birder-0.2.1 → birder-0.2.3}/birder/net/ssl/data2vec.py +1 -1
  48. {birder-0.2.1 → birder-0.2.3}/birder/net/ssl/data2vec2.py +4 -2
  49. {birder-0.2.1 → birder-0.2.3}/birder/net/vit.py +0 -15
  50. {birder-0.2.1 → birder-0.2.3}/birder/net/vovnet_v2.py +31 -1
  51. {birder-0.2.1 → birder-0.2.3}/birder/results/gui.py +15 -2
  52. {birder-0.2.1 → birder-0.2.3}/birder/scripts/benchmark.py +90 -21
  53. {birder-0.2.1 → birder-0.2.3}/birder/scripts/predict.py +1 -0
  54. {birder-0.2.1 → birder-0.2.3}/birder/scripts/predict_detection.py +48 -9
  55. {birder-0.2.1 → birder-0.2.3}/birder/scripts/train.py +33 -50
  56. {birder-0.2.1 → birder-0.2.3}/birder/scripts/train_barlow_twins.py +19 -40
  57. {birder-0.2.1 → birder-0.2.3}/birder/scripts/train_byol.py +19 -40
  58. {birder-0.2.1 → birder-0.2.3}/birder/scripts/train_capi.py +21 -43
  59. {birder-0.2.1 → birder-0.2.3}/birder/scripts/train_data2vec.py +18 -40
  60. {birder-0.2.1 → birder-0.2.3}/birder/scripts/train_data2vec2.py +18 -40
  61. {birder-0.2.1 → birder-0.2.3}/birder/scripts/train_detection.py +89 -57
  62. {birder-0.2.1 → birder-0.2.3}/birder/scripts/train_dino_v1.py +19 -40
  63. {birder-0.2.1 → birder-0.2.3}/birder/scripts/train_dino_v2.py +18 -40
  64. {birder-0.2.1 → birder-0.2.3}/birder/scripts/train_dino_v2_dist.py +25 -40
  65. {birder-0.2.1 → birder-0.2.3}/birder/scripts/train_franca.py +18 -40
  66. {birder-0.2.1 → birder-0.2.3}/birder/scripts/train_i_jepa.py +25 -46
  67. {birder-0.2.1 → birder-0.2.3}/birder/scripts/train_ibot.py +18 -40
  68. {birder-0.2.1 → birder-0.2.3}/birder/scripts/train_kd.py +179 -81
  69. {birder-0.2.1 → birder-0.2.3}/birder/scripts/train_mim.py +20 -43
  70. {birder-0.2.1 → birder-0.2.3}/birder/scripts/train_mmcr.py +19 -40
  71. {birder-0.2.1 → birder-0.2.3}/birder/scripts/train_rotnet.py +19 -40
  72. {birder-0.2.1 → birder-0.2.3}/birder/scripts/train_simclr.py +19 -40
  73. {birder-0.2.1 → birder-0.2.3}/birder/scripts/train_vicreg.py +19 -40
  74. {birder-0.2.1 → birder-0.2.3}/birder/tools/__main__.py +6 -2
  75. birder-0.2.3/birder/tools/adversarial.py +214 -0
  76. birder-0.2.3/birder/tools/auto_anchors.py +380 -0
  77. {birder-0.2.1 → birder-0.2.3}/birder/tools/ensemble_model.py +1 -1
  78. {birder-0.2.1 → birder-0.2.3}/birder/tools/introspection.py +58 -31
  79. {birder-0.2.1 → birder-0.2.3}/birder/tools/pack.py +172 -103
  80. {birder-0.2.1 → birder-0.2.3}/birder/tools/show_det_iterator.py +10 -1
  81. birder-0.2.3/birder/version.py +1 -0
  82. {birder-0.2.1 → birder-0.2.3}/birder.egg-info/PKG-INFO +4 -3
  83. {birder-0.2.1 → birder-0.2.3}/birder.egg-info/SOURCES.txt +13 -0
  84. {birder-0.2.1 → birder-0.2.3}/birder.egg-info/requires.txt +2 -1
  85. {birder-0.2.1 → birder-0.2.3}/requirements/_requirements-dev.txt +2 -1
  86. birder-0.2.3/tests/test_adversarial.py +238 -0
  87. {birder-0.2.1 → birder-0.2.3}/tests/test_common.py +202 -14
  88. {birder-0.2.1 → birder-0.2.3}/tests/test_inference.py +69 -0
  89. birder-0.2.3/tests/test_introspection.py +310 -0
  90. {birder-0.2.1 → birder-0.2.3}/tests/test_kernels.py +13 -0
  91. {birder-0.2.1 → birder-0.2.3}/tests/test_model_registry.py +2 -2
  92. {birder-0.2.1 → birder-0.2.3}/tests/test_net.py +237 -176
  93. {birder-0.2.1 → birder-0.2.3}/tests/test_net_detection.py +44 -0
  94. {birder-0.2.1 → birder-0.2.3}/tests/test_transforms.py +9 -0
  95. birder-0.2.1/birder/adversarial/fgsm.py +0 -34
  96. birder-0.2.1/birder/adversarial/pgd.py +0 -54
  97. birder-0.2.1/birder/introspection/__init__.py +0 -9
  98. birder-0.2.1/birder/introspection/attention_rollout.py +0 -117
  99. birder-0.2.1/birder/introspection/base.py +0 -60
  100. birder-0.2.1/birder/introspection/gradcam.py +0 -176
  101. birder-0.2.1/birder/introspection/guided_backprop.py +0 -155
  102. birder-0.2.1/birder/tools/__init__.py +0 -0
  103. birder-0.2.1/birder/tools/adversarial.py +0 -163
  104. birder-0.2.1/birder/version.py +0 -1
  105. {birder-0.2.1 → birder-0.2.3}/LICENSE +0 -0
  106. {birder-0.2.1 → birder-0.2.3}/birder/__init__.py +0 -0
  107. {birder-0.2.1/birder/adversarial → birder-0.2.3/birder/common}/__init__.py +0 -0
  108. {birder-0.2.1 → birder-0.2.3}/birder/common/cli.py +0 -0
  109. {birder-0.2.1 → birder-0.2.3}/birder/common/fs_ops.py +0 -0
  110. {birder-0.2.1 → birder-0.2.3}/birder/common/masking.py +0 -0
  111. {birder-0.2.1/birder/common → birder-0.2.3/birder/conf}/__init__.py +0 -0
  112. {birder-0.2.1 → birder-0.2.3}/birder/conf/settings.py +0 -0
  113. {birder-0.2.1/birder/conf → birder-0.2.3/birder/data}/__init__.py +0 -0
  114. {birder-0.2.1/birder/data → birder-0.2.3/birder/data/collators}/__init__.py +0 -0
  115. {birder-0.2.1/birder/data/collators → birder-0.2.3/birder/data/dataloader}/__init__.py +0 -0
  116. {birder-0.2.1 → birder-0.2.3}/birder/data/dataloader/webdataset.py +0 -0
  117. {birder-0.2.1/birder/data/dataloader → birder-0.2.3/birder/data/datasets}/__init__.py +0 -0
  118. {birder-0.2.1 → birder-0.2.3}/birder/data/datasets/directory.py +0 -0
  119. {birder-0.2.1 → birder-0.2.3}/birder/data/datasets/fake.py +0 -0
  120. {birder-0.2.1 → birder-0.2.3}/birder/data/datasets/webdataset.py +0 -0
  121. {birder-0.2.1/birder/data/datasets → birder-0.2.3/birder/data/transforms}/__init__.py +0 -0
  122. {birder-0.2.1 → birder-0.2.3}/birder/data/transforms/classification.py +0 -0
  123. {birder-0.2.1 → birder-0.2.3}/birder/data/transforms/mosaic.py +0 -0
  124. {birder-0.2.1/birder/data/transforms → birder-0.2.3/birder/datahub}/__init__.py +0 -0
  125. {birder-0.2.1 → birder-0.2.3}/birder/datahub/_lib.py +0 -0
  126. {birder-0.2.1 → birder-0.2.3}/birder/datahub/classification.py +0 -0
  127. {birder-0.2.1/birder/datahub → birder-0.2.3/birder/inference}/__init__.py +0 -0
  128. {birder-0.2.1 → birder-0.2.3}/birder/inference/classification.py +0 -0
  129. {birder-0.2.1/birder/inference → birder-0.2.3/birder/kernels}/__init__.py +0 -0
  130. {birder-0.2.1 → birder-0.2.3}/birder/kernels/deformable_detr/cpu/ms_deform_attn_cpu.cpp +0 -0
  131. {birder-0.2.1 → birder-0.2.3}/birder/kernels/deformable_detr/cpu/ms_deform_attn_cpu.h +0 -0
  132. {birder-0.2.1 → birder-0.2.3}/birder/kernels/deformable_detr/cuda/ms_deform_attn_cuda.cu +0 -0
  133. {birder-0.2.1 → birder-0.2.3}/birder/kernels/deformable_detr/cuda/ms_deform_attn_cuda.h +0 -0
  134. {birder-0.2.1 → birder-0.2.3}/birder/kernels/deformable_detr/cuda/ms_deform_im2col_cuda.cuh +0 -0
  135. {birder-0.2.1 → birder-0.2.3}/birder/kernels/deformable_detr/ms_deform_attn.h +0 -0
  136. {birder-0.2.1 → birder-0.2.3}/birder/kernels/deformable_detr/vision.cpp +0 -0
  137. {birder-0.2.1 → birder-0.2.3}/birder/kernels/load_kernel.py +0 -0
  138. {birder-0.2.1 → birder-0.2.3}/birder/kernels/soft_nms/op.cpp +0 -0
  139. {birder-0.2.1 → birder-0.2.3}/birder/kernels/soft_nms/soft_nms.cpp +0 -0
  140. {birder-0.2.1 → birder-0.2.3}/birder/kernels/soft_nms/soft_nms.h +0 -0
  141. {birder-0.2.1 → birder-0.2.3}/birder/kernels/transnext/av_bw_kernel.cu +0 -0
  142. {birder-0.2.1 → birder-0.2.3}/birder/kernels/transnext/av_fw_kernel.cu +0 -0
  143. {birder-0.2.1 → birder-0.2.3}/birder/kernels/transnext/qk_bw_kernel.cu +0 -0
  144. {birder-0.2.1 → birder-0.2.3}/birder/kernels/transnext/qk_fw_kernel.cu +0 -0
  145. {birder-0.2.1 → birder-0.2.3}/birder/kernels/transnext/qk_rpb_bw_kernel.cu +0 -0
  146. {birder-0.2.1 → birder-0.2.3}/birder/kernels/transnext/qk_rpb_fw_kernel.cu +0 -0
  147. {birder-0.2.1 → birder-0.2.3}/birder/kernels/transnext/swattention.cpp +0 -0
  148. {birder-0.2.1 → birder-0.2.3}/birder/layers/__init__.py +0 -0
  149. {birder-0.2.1 → birder-0.2.3}/birder/layers/activations.py +0 -0
  150. {birder-0.2.1 → birder-0.2.3}/birder/layers/attention_pool.py +0 -0
  151. {birder-0.2.1 → birder-0.2.3}/birder/layers/ffn.py +0 -0
  152. {birder-0.2.1 → birder-0.2.3}/birder/layers/gem.py +0 -0
  153. {birder-0.2.1 → birder-0.2.3}/birder/layers/layer_norm.py +0 -0
  154. {birder-0.2.1 → birder-0.2.3}/birder/layers/layer_scale.py +0 -0
  155. {birder-0.2.1 → birder-0.2.3}/birder/model_registry/__init__.py +0 -0
  156. {birder-0.2.1 → birder-0.2.3}/birder/model_registry/manifest.py +0 -0
  157. {birder-0.2.1 → birder-0.2.3}/birder/model_registry/model_registry.py +0 -0
  158. {birder-0.2.1 → birder-0.2.3}/birder/net/alexnet.py +0 -0
  159. {birder-0.2.1 → birder-0.2.3}/birder/net/base.py +0 -0
  160. {birder-0.2.1 → birder-0.2.3}/birder/net/biformer.py +0 -0
  161. {birder-0.2.1 → birder-0.2.3}/birder/net/cait.py +0 -0
  162. {birder-0.2.1 → birder-0.2.3}/birder/net/cas_vit.py +0 -0
  163. {birder-0.2.1 → birder-0.2.3}/birder/net/coat.py +0 -0
  164. {birder-0.2.1 → birder-0.2.3}/birder/net/conv2former.py +0 -0
  165. {birder-0.2.1 → birder-0.2.3}/birder/net/convmixer.py +0 -0
  166. {birder-0.2.1 → birder-0.2.3}/birder/net/convnext_v1.py +0 -0
  167. {birder-0.2.1 → birder-0.2.3}/birder/net/convnext_v2.py +0 -0
  168. {birder-0.2.1 → birder-0.2.3}/birder/net/crossformer.py +0 -0
  169. {birder-0.2.1 → birder-0.2.3}/birder/net/crossvit.py +0 -0
  170. {birder-0.2.1 → birder-0.2.3}/birder/net/cspnet.py +0 -0
  171. {birder-0.2.1 → birder-0.2.3}/birder/net/cswin_transformer.py +0 -0
  172. {birder-0.2.1 → birder-0.2.3}/birder/net/darknet.py +0 -0
  173. {birder-0.2.1 → birder-0.2.3}/birder/net/davit.py +0 -0
  174. {birder-0.2.1 → birder-0.2.3}/birder/net/deit.py +0 -0
  175. {birder-0.2.1 → birder-0.2.3}/birder/net/deit3.py +0 -0
  176. {birder-0.2.1 → birder-0.2.3}/birder/net/densenet.py +0 -0
  177. {birder-0.2.1 → birder-0.2.3}/birder/net/detection/__init__.py +0 -0
  178. {birder-0.2.1 → birder-0.2.3}/birder/net/detection/base.py +0 -0
  179. {birder-0.2.1 → birder-0.2.3}/birder/net/detection/faster_rcnn.py +0 -0
  180. {birder-0.2.1 → birder-0.2.3}/birder/net/detection/fcos.py +0 -0
  181. {birder-0.2.1 → birder-0.2.3}/birder/net/detection/retinanet.py +0 -0
  182. {birder-0.2.1 → birder-0.2.3}/birder/net/detection/ssd.py +0 -0
  183. {birder-0.2.1 → birder-0.2.3}/birder/net/detection/ssdlite.py +0 -0
  184. {birder-0.2.1 → birder-0.2.3}/birder/net/detection/vitdet.py +0 -0
  185. {birder-0.2.1 → birder-0.2.3}/birder/net/dpn.py +0 -0
  186. {birder-0.2.1 → birder-0.2.3}/birder/net/edgenext.py +0 -0
  187. {birder-0.2.1 → birder-0.2.3}/birder/net/edgevit.py +0 -0
  188. {birder-0.2.1 → birder-0.2.3}/birder/net/efficientformer_v1.py +0 -0
  189. {birder-0.2.1 → birder-0.2.3}/birder/net/efficientformer_v2.py +0 -0
  190. {birder-0.2.1 → birder-0.2.3}/birder/net/efficientnet_lite.py +0 -0
  191. {birder-0.2.1 → birder-0.2.3}/birder/net/efficientnet_v1.py +0 -0
  192. {birder-0.2.1 → birder-0.2.3}/birder/net/efficientnet_v2.py +0 -0
  193. {birder-0.2.1 → birder-0.2.3}/birder/net/efficientvim.py +0 -0
  194. {birder-0.2.1 → birder-0.2.3}/birder/net/efficientvit_mit.py +0 -0
  195. {birder-0.2.1 → birder-0.2.3}/birder/net/efficientvit_msft.py +0 -0
  196. {birder-0.2.1 → birder-0.2.3}/birder/net/fastvit.py +1 -1
  197. {birder-0.2.1 → birder-0.2.3}/birder/net/flexivit.py +0 -0
  198. {birder-0.2.1 → birder-0.2.3}/birder/net/focalnet.py +0 -0
  199. {birder-0.2.1 → birder-0.2.3}/birder/net/ghostnet_v1.py +0 -0
  200. {birder-0.2.1 → birder-0.2.3}/birder/net/ghostnet_v2.py +0 -0
  201. {birder-0.2.1 → birder-0.2.3}/birder/net/groupmixformer.py +0 -0
  202. {birder-0.2.1 → birder-0.2.3}/birder/net/hgnet_v1.py +0 -0
  203. {birder-0.2.1 → birder-0.2.3}/birder/net/hgnet_v2.py +0 -0
  204. {birder-0.2.1 → birder-0.2.3}/birder/net/hiera.py +0 -0
  205. {birder-0.2.1 → birder-0.2.3}/birder/net/hieradet.py +0 -0
  206. {birder-0.2.1 → birder-0.2.3}/birder/net/hornet.py +0 -0
  207. {birder-0.2.1 → birder-0.2.3}/birder/net/iformer.py +0 -0
  208. {birder-0.2.1 → birder-0.2.3}/birder/net/inception_next.py +0 -0
  209. {birder-0.2.1 → birder-0.2.3}/birder/net/inception_resnet_v1.py +0 -0
  210. {birder-0.2.1 → birder-0.2.3}/birder/net/inception_resnet_v2.py +0 -0
  211. {birder-0.2.1 → birder-0.2.3}/birder/net/inception_v3.py +0 -0
  212. {birder-0.2.1 → birder-0.2.3}/birder/net/inception_v4.py +0 -0
  213. {birder-0.2.1 → birder-0.2.3}/birder/net/levit.py +0 -0
  214. {birder-0.2.1 → birder-0.2.3}/birder/net/maxvit.py +0 -0
  215. {birder-0.2.1 → birder-0.2.3}/birder/net/metaformer.py +0 -0
  216. {birder-0.2.1 → birder-0.2.3}/birder/net/mim/__init__.py +0 -0
  217. {birder-0.2.1 → birder-0.2.3}/birder/net/mim/base.py +0 -0
  218. {birder-0.2.1 → birder-0.2.3}/birder/net/mim/crossmae.py +0 -0
  219. {birder-0.2.1 → birder-0.2.3}/birder/net/mim/fcmae.py +0 -0
  220. {birder-0.2.1 → birder-0.2.3}/birder/net/mim/mae_hiera.py +0 -0
  221. {birder-0.2.1 → birder-0.2.3}/birder/net/mim/simmim.py +0 -0
  222. {birder-0.2.1 → birder-0.2.3}/birder/net/mnasnet.py +0 -0
  223. {birder-0.2.1 → birder-0.2.3}/birder/net/mobilenet_v1.py +0 -0
  224. {birder-0.2.1 → birder-0.2.3}/birder/net/mobilenet_v2.py +0 -0
  225. {birder-0.2.1 → birder-0.2.3}/birder/net/mobilenet_v3_large.py +0 -0
  226. {birder-0.2.1 → birder-0.2.3}/birder/net/mobilenet_v3_small.py +0 -0
  227. {birder-0.2.1 → birder-0.2.3}/birder/net/mobilenet_v4.py +0 -0
  228. {birder-0.2.1 → birder-0.2.3}/birder/net/mobileone.py +0 -0
  229. {birder-0.2.1 → birder-0.2.3}/birder/net/mobilevit_v1.py +0 -0
  230. {birder-0.2.1 → birder-0.2.3}/birder/net/mobilevit_v2.py +0 -0
  231. {birder-0.2.1 → birder-0.2.3}/birder/net/moganet.py +0 -0
  232. {birder-0.2.1 → birder-0.2.3}/birder/net/mvit_v2.py +0 -0
  233. {birder-0.2.1 → birder-0.2.3}/birder/net/nextvit.py +0 -0
  234. {birder-0.2.1 → birder-0.2.3}/birder/net/nfnet.py +0 -0
  235. {birder-0.2.1 → birder-0.2.3}/birder/net/pvt_v1.py +0 -0
  236. {birder-0.2.1 → birder-0.2.3}/birder/net/pvt_v2.py +0 -0
  237. {birder-0.2.1 → birder-0.2.3}/birder/net/rdnet.py +0 -0
  238. {birder-0.2.1 → birder-0.2.3}/birder/net/regionvit.py +0 -0
  239. {birder-0.2.1 → birder-0.2.3}/birder/net/regnet.py +0 -0
  240. {birder-0.2.1 → birder-0.2.3}/birder/net/regnet_z.py +0 -0
  241. {birder-0.2.1 → birder-0.2.3}/birder/net/repghost.py +0 -0
  242. {birder-0.2.1 → birder-0.2.3}/birder/net/repvgg.py +0 -0
  243. {birder-0.2.1 → birder-0.2.3}/birder/net/repvit.py +0 -0
  244. {birder-0.2.1 → birder-0.2.3}/birder/net/resmlp.py +0 -0
  245. {birder-0.2.1 → birder-0.2.3}/birder/net/resnest.py +0 -0
  246. {birder-0.2.1 → birder-0.2.3}/birder/net/resnet_v2.py +0 -0
  247. {birder-0.2.1 → birder-0.2.3}/birder/net/rope_deit3.py +0 -0
  248. {birder-0.2.1 → birder-0.2.3}/birder/net/rope_flexivit.py +0 -0
  249. {birder-0.2.1 → birder-0.2.3}/birder/net/rope_vit.py +0 -0
  250. {birder-0.2.1 → birder-0.2.3}/birder/net/se_resnet_v2.py +0 -0
  251. {birder-0.2.1 → birder-0.2.3}/birder/net/sequencer2d.py +0 -0
  252. {birder-0.2.1 → birder-0.2.3}/birder/net/shufflenet_v1.py +0 -0
  253. {birder-0.2.1 → birder-0.2.3}/birder/net/shufflenet_v2.py +0 -0
  254. {birder-0.2.1 → birder-0.2.3}/birder/net/smt.py +0 -0
  255. {birder-0.2.1 → birder-0.2.3}/birder/net/squeezenet.py +0 -0
  256. {birder-0.2.1 → birder-0.2.3}/birder/net/squeezenext.py +0 -0
  257. {birder-0.2.1 → birder-0.2.3}/birder/net/ssl/__init__.py +0 -0
  258. {birder-0.2.1 → birder-0.2.3}/birder/net/ssl/barlow_twins.py +0 -0
  259. {birder-0.2.1 → birder-0.2.3}/birder/net/ssl/base.py +0 -0
  260. {birder-0.2.1 → birder-0.2.3}/birder/net/ssl/byol.py +0 -0
  261. {birder-0.2.1 → birder-0.2.3}/birder/net/ssl/capi.py +0 -0
  262. {birder-0.2.1 → birder-0.2.3}/birder/net/ssl/dino_v1.py +0 -0
  263. {birder-0.2.1 → birder-0.2.3}/birder/net/ssl/dino_v2.py +0 -0
  264. {birder-0.2.1 → birder-0.2.3}/birder/net/ssl/franca.py +0 -0
  265. {birder-0.2.1 → birder-0.2.3}/birder/net/ssl/i_jepa.py +0 -0
  266. {birder-0.2.1 → birder-0.2.3}/birder/net/ssl/ibot.py +0 -0
  267. {birder-0.2.1 → birder-0.2.3}/birder/net/ssl/mmcr.py +0 -0
  268. {birder-0.2.1 → birder-0.2.3}/birder/net/ssl/simclr.py +0 -0
  269. {birder-0.2.1 → birder-0.2.3}/birder/net/ssl/sscd.py +0 -0
  270. {birder-0.2.1 → birder-0.2.3}/birder/net/ssl/vicreg.py +0 -0
  271. {birder-0.2.1 → birder-0.2.3}/birder/net/starnet.py +0 -0
  272. {birder-0.2.1 → birder-0.2.3}/birder/net/swiftformer.py +0 -0
  273. {birder-0.2.1 → birder-0.2.3}/birder/net/swin_transformer_v1.py +0 -0
  274. {birder-0.2.1 → birder-0.2.3}/birder/net/swin_transformer_v2.py +0 -0
  275. {birder-0.2.1 → birder-0.2.3}/birder/net/tiny_vit.py +0 -0
  276. {birder-0.2.1 → birder-0.2.3}/birder/net/transnext.py +0 -0
  277. {birder-0.2.1 → birder-0.2.3}/birder/net/uniformer.py +0 -0
  278. {birder-0.2.1 → birder-0.2.3}/birder/net/van.py +0 -0
  279. {birder-0.2.1 → birder-0.2.3}/birder/net/vgg.py +0 -0
  280. {birder-0.2.1 → birder-0.2.3}/birder/net/vgg_reduced.py +0 -0
  281. {birder-0.2.1 → birder-0.2.3}/birder/net/vit_parallel.py +0 -0
  282. {birder-0.2.1 → birder-0.2.3}/birder/net/vit_sam.py +0 -0
  283. {birder-0.2.1 → birder-0.2.3}/birder/net/vovnet_v1.py +0 -0
  284. {birder-0.2.1 → birder-0.2.3}/birder/net/wide_resnet.py +0 -0
  285. {birder-0.2.1 → birder-0.2.3}/birder/net/xception.py +0 -0
  286. {birder-0.2.1 → birder-0.2.3}/birder/net/xcit.py +0 -0
  287. {birder-0.2.1/birder/kernels → birder-0.2.3/birder/ops}/__init__.py +0 -0
  288. {birder-0.2.1 → birder-0.2.3}/birder/ops/msda.py +0 -0
  289. {birder-0.2.1 → birder-0.2.3}/birder/ops/soft_nms.py +0 -0
  290. {birder-0.2.1 → birder-0.2.3}/birder/ops/swattention.py +0 -0
  291. {birder-0.2.1 → birder-0.2.3}/birder/optim/__init__.py +0 -0
  292. {birder-0.2.1 → birder-0.2.3}/birder/optim/lamb.py +0 -0
  293. {birder-0.2.1 → birder-0.2.3}/birder/optim/lars.py +0 -0
  294. {birder-0.2.1 → birder-0.2.3}/birder/py.typed +0 -0
  295. {birder-0.2.1/birder/ops → birder-0.2.3/birder/results}/__init__.py +0 -0
  296. {birder-0.2.1 → birder-0.2.3}/birder/results/classification.py +0 -0
  297. {birder-0.2.1 → birder-0.2.3}/birder/results/detection.py +0 -0
  298. {birder-0.2.1 → birder-0.2.3}/birder/scheduler/__init__.py +0 -0
  299. {birder-0.2.1 → birder-0.2.3}/birder/scheduler/cooldown.py +0 -0
  300. {birder-0.2.1/birder/results → birder-0.2.3/birder/scripts}/__init__.py +0 -0
  301. {birder-0.2.1 → birder-0.2.3}/birder/scripts/__main__.py +0 -0
  302. {birder-0.2.1 → birder-0.2.3}/birder/scripts/evaluate.py +0 -0
  303. {birder-0.2.1/birder/scripts → birder-0.2.3/birder/tools}/__init__.py +0 -0
  304. {birder-0.2.1 → birder-0.2.3}/birder/tools/avg_model.py +0 -0
  305. {birder-0.2.1 → birder-0.2.3}/birder/tools/convert_model.py +0 -0
  306. {birder-0.2.1 → birder-0.2.3}/birder/tools/det_results.py +0 -0
  307. {birder-0.2.1 → birder-0.2.3}/birder/tools/download_model.py +0 -0
  308. {birder-0.2.1 → birder-0.2.3}/birder/tools/labelme_to_coco.py +0 -0
  309. {birder-0.2.1 → birder-0.2.3}/birder/tools/list_models.py +0 -0
  310. {birder-0.2.1 → birder-0.2.3}/birder/tools/model_info.py +0 -0
  311. {birder-0.2.1 → birder-0.2.3}/birder/tools/quantize_model.py +0 -0
  312. {birder-0.2.1 → birder-0.2.3}/birder/tools/results.py +0 -0
  313. {birder-0.2.1 → birder-0.2.3}/birder/tools/show_iterator.py +0 -0
  314. {birder-0.2.1 → birder-0.2.3}/birder/tools/similarity.py +0 -0
  315. {birder-0.2.1 → birder-0.2.3}/birder/tools/stats.py +0 -0
  316. {birder-0.2.1 → birder-0.2.3}/birder/tools/verify_coco.py +0 -0
  317. {birder-0.2.1 → birder-0.2.3}/birder/tools/verify_directory.py +0 -0
  318. {birder-0.2.1 → birder-0.2.3}/birder/tools/voc_to_coco.py +0 -0
  319. {birder-0.2.1 → birder-0.2.3}/birder.egg-info/dependency_links.txt +0 -0
  320. {birder-0.2.1 → birder-0.2.3}/birder.egg-info/entry_points.txt +0 -0
  321. {birder-0.2.1 → birder-0.2.3}/birder.egg-info/top_level.txt +0 -0
  322. {birder-0.2.1 → birder-0.2.3}/pyproject.toml +0 -0
  323. {birder-0.2.1 → birder-0.2.3}/requirements/requirements-hf.txt +0 -0
  324. {birder-0.2.1 → birder-0.2.3}/requirements/requirements.txt +0 -0
  325. {birder-0.2.1 → birder-0.2.3}/setup.cfg +0 -0
  326. {birder-0.2.1 → birder-0.2.3}/tests/test_collators.py +0 -0
  327. {birder-0.2.1 → birder-0.2.3}/tests/test_datasets.py +0 -0
  328. {birder-0.2.1 → birder-0.2.3}/tests/test_layers.py +0 -0
  329. {birder-0.2.1 → birder-0.2.3}/tests/test_net_mim.py +0 -0
  330. {birder-0.2.1 → birder-0.2.3}/tests/test_net_ssl.py +0 -0
  331. {birder-0.2.1 → birder-0.2.3}/tests/test_ops.py +0 -0
  332. {birder-0.2.1 → birder-0.2.3}/tests/test_optim.py +0 -0
  333. {birder-0.2.1 → birder-0.2.3}/tests/test_results.py +0 -0
  334. {birder-0.2.1 → birder-0.2.3}/tests/test_scheduler.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: birder
3
- Version: 0.2.1
3
+ Version: 0.2.3
4
4
  Summary: An open-source computer vision framework for wildlife image analysis, featuring state-of-the-art models for species classification and detection.
5
5
  Author: Ofer Hasson
6
6
  License-Expression: Apache-2.0
@@ -45,7 +45,7 @@ Provides-Extra: dev
45
45
  Requires-Dist: altair~=5.5.0; extra == "dev"
46
46
  Requires-Dist: bandit~=1.9.2; extra == "dev"
47
47
  Requires-Dist: black~=25.12.0; extra == "dev"
48
- Requires-Dist: build~=1.3.0; extra == "dev"
48
+ Requires-Dist: build~=1.4.0; extra == "dev"
49
49
  Requires-Dist: bumpver~=2025.1131; extra == "dev"
50
50
  Requires-Dist: captum~=0.7.0; extra == "dev"
51
51
  Requires-Dist: coverage~=7.13.1; extra == "dev"
@@ -62,6 +62,7 @@ Requires-Dist: MonkeyType~=23.3.0; extra == "dev"
62
62
  Requires-Dist: mypy~=1.19.1; extra == "dev"
63
63
  Requires-Dist: parameterized~=0.9.0; extra == "dev"
64
64
  Requires-Dist: pylint~=4.0.4; extra == "dev"
65
+ Requires-Dist: pytest; extra == "dev"
65
66
  Requires-Dist: requests~=2.32.5; extra == "dev"
66
67
  Requires-Dist: safetensors~=0.7.0; extra == "dev"
67
68
  Requires-Dist: setuptools; extra == "dev"
@@ -207,7 +208,7 @@ For detailed information about these datasets, including descriptions, citations
207
208
 
208
209
  ## Detection
209
210
 
210
- Detection training and inference are available, see [docs/training_detection.md](docs/training_detection.md) and
211
+ Detection training and inference are available, see [docs/training_scripts.md](docs/training_scripts.md) and
211
212
  [docs/inference.md](docs/inference.md). APIs and model coverage may evolve as detection support matures.
212
213
 
213
214
  ## Project Status and Contributions
@@ -129,7 +129,7 @@ For detailed information about these datasets, including descriptions, citations
129
129
 
130
130
  ## Detection
131
131
 
132
- Detection training and inference are available, see [docs/training_detection.md](docs/training_detection.md) and
132
+ Detection training and inference are available, see [docs/training_scripts.md](docs/training_scripts.md) and
133
133
  [docs/inference.md](docs/inference.md). APIs and model coverage may evolve as detection support matures.
134
134
 
135
135
  ## Project Status and Contributions
@@ -0,0 +1,13 @@
1
+ from birder.adversarial.base import AttackResult
2
+ from birder.adversarial.deepfool import DeepFool
3
+ from birder.adversarial.fgsm import FGSM
4
+ from birder.adversarial.pgd import PGD
5
+ from birder.adversarial.simba import SimBA
6
+
7
+ __all__ = [
8
+ "AttackResult",
9
+ "DeepFool",
10
+ "FGSM",
11
+ "PGD",
12
+ "SimBA",
13
+ ]
@@ -0,0 +1,101 @@
1
+ from dataclasses import dataclass
2
+ from typing import Optional
3
+ from typing import Protocol
4
+
5
+ import torch
6
+
7
+ from birder.data.transforms.classification import RGBType
8
+
9
+
10
+ @dataclass(frozen=True)
11
+ class AttackResult:
12
+ adv_inputs: torch.Tensor
13
+ adv_logits: torch.Tensor
14
+ perturbation: torch.Tensor
15
+ logits: Optional[torch.Tensor] = None
16
+ success: Optional[torch.Tensor] = None
17
+ num_queries: Optional[int] = None
18
+
19
+
20
+ class Attack(Protocol):
21
+ def __call__(self, input_tensor: torch.Tensor, target: Optional[torch.Tensor]) -> AttackResult: ...
22
+
23
+
24
+ def _to_channel_tensor(
25
+ values: tuple[float, float, float], device: Optional[torch.device], dtype: Optional[torch.dtype]
26
+ ) -> torch.Tensor:
27
+ return torch.tensor(values, device=device, dtype=dtype).view(1, -1, 1, 1)
28
+
29
+
30
+ def normalized_bounds(
31
+ rgb_stats: RGBType, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None
32
+ ) -> tuple[torch.Tensor, torch.Tensor]:
33
+ mean = _to_channel_tensor(rgb_stats["mean"], device=device, dtype=dtype)
34
+ std = _to_channel_tensor(rgb_stats["std"], device=device, dtype=dtype)
35
+ min_val = (0.0 - mean) / std
36
+ max_val = (1.0 - mean) / std
37
+
38
+ return (min_val, max_val)
39
+
40
+
41
+ def pixel_eps_to_normalized(
42
+ eps: float | torch.Tensor,
43
+ rgb_stats: RGBType,
44
+ device: Optional[torch.device] = None,
45
+ dtype: Optional[torch.dtype] = None,
46
+ ) -> torch.Tensor:
47
+ eps_tensor = torch.as_tensor(eps, device=device, dtype=dtype)
48
+ std = _to_channel_tensor(rgb_stats["std"], device=eps_tensor.device, dtype=eps_tensor.dtype)
49
+
50
+ if eps_tensor.numel() == 1:
51
+ eps_tensor = eps_tensor.reshape(1, 1, 1, 1)
52
+ else:
53
+ eps_tensor = eps_tensor.reshape(1, -1, 1, 1)
54
+
55
+ return eps_tensor / std
56
+
57
+
58
+ def clamp_normalized(inputs: torch.Tensor, rgb_stats: RGBType) -> torch.Tensor:
59
+ (min_val, max_val) = normalized_bounds(rgb_stats, device=inputs.device, dtype=inputs.dtype)
60
+ return torch.clamp(inputs, min=min_val, max=max_val)
61
+
62
+
63
+ def predict_labels(logits: torch.Tensor) -> torch.Tensor:
64
+ return torch.argmax(logits, dim=1)
65
+
66
+
67
+ def validate_target(
68
+ target: Optional[torch.Tensor], batch_size: int, num_classes: int, device: torch.device
69
+ ) -> Optional[torch.Tensor]:
70
+ if target is None:
71
+ return None
72
+
73
+ target = target.to(device=device, dtype=torch.long)
74
+ if target.ndim == 0:
75
+ target = target.view(1)
76
+
77
+ if target.shape[0] != batch_size:
78
+ raise ValueError(f"Target shape {target.shape[0]} must match batch size {batch_size}")
79
+
80
+ if torch.any(target < 0) or torch.any(target >= num_classes):
81
+ raise ValueError(f"Target values must be in range [0, {num_classes})")
82
+
83
+ return target
84
+
85
+
86
+ def attack_success(
87
+ logits: torch.Tensor,
88
+ adv_logits: torch.Tensor,
89
+ targeted: bool,
90
+ target: Optional[torch.Tensor] = None,
91
+ labels: Optional[torch.Tensor] = None,
92
+ ) -> torch.Tensor:
93
+ adv_pred = predict_labels(adv_logits)
94
+ if targeted is True:
95
+ if target is None:
96
+ raise ValueError("Target labels required for targeted attacks")
97
+
98
+ return adv_pred.eq(target)
99
+
100
+ base_labels = labels if labels is not None else predict_labels(logits)
101
+ return adv_pred.ne(base_labels)
@@ -0,0 +1,173 @@
1
+ """
2
+ DeepFool
3
+
4
+ Paper "DeepFool: a simple and accurate method to fool deep neural networks", https://arxiv.org/abs/1511.04599
5
+ """
6
+
7
+ from typing import Optional
8
+
9
+ import torch
10
+ from torch import nn
11
+
12
+ from birder.adversarial.base import AttackResult
13
+ from birder.adversarial.base import attack_success
14
+ from birder.adversarial.base import clamp_normalized
15
+ from birder.adversarial.base import predict_labels
16
+ from birder.adversarial.base import validate_target
17
+ from birder.data.transforms.classification import RGBType
18
+
19
+ GRAD_EPS = 1e-12
20
+
21
+
22
class DeepFool:
    """
    DeepFool minimal-perturbation attack.

    Each sample is attacked independently. At every iteration the classifier is linearized around the
    current adversarial input. The input then moves to the closest linearized decision boundary
    (untargeted) or to the boundary with the requested target class (targeted). Each step is scaled by
    ``1 + overshoot`` and clamped to the valid normalized range described by ``rgb_stats``.

    Args:
        net: classifier returning logits indexed as (batch, class); it is switched to eval mode.
        num_classes: number of top-scoring classes searched for the closest boundary (untargeted mode).
        overshoot: extra relative step length used to push the input across the boundary.
        max_iter: maximum number of linearization steps per sample.
        rgb_stats: normalization statistics used to clamp adversarial inputs.
    """

    def __init__(
        self, net: nn.Module, num_classes: int = 10, overshoot: float = 0.02, max_iter: int = 50, *, rgb_stats: RGBType
    ) -> None:
        if num_classes < 2:
            raise ValueError("num_classes must be at least 2")
        if max_iter <= 0:
            raise ValueError("max_iter must be positive")
        if overshoot < 0:
            raise ValueError("overshoot must be non-negative")

        self.net = net.eval()
        self.num_classes = num_classes
        self.overshoot = overshoot
        self.max_iter = max_iter
        self.rgb_stats = rgb_stats

    def __call__(self, input_tensor: torch.Tensor, target: Optional[torch.Tensor]) -> AttackResult:
        """
        Attack a batch, one sample at a time.

        Args:
            input_tensor: normalized input batch.
            target: optional per-sample target class indices; when given the attack is targeted.

        Returns:
            AttackResult with detached adversarial inputs, their logits, the perturbation, the clean logits
            and a per-sample success mask.
        """
        inputs = input_tensor.detach()
        with torch.no_grad():
            logits = self.net(inputs)

        target_labels = (
            validate_target(target, inputs.shape[0], logits.shape[1], inputs.device) if target is not None else None
        )
        targeted = target_labels is not None

        adv_inputs_list = []
        for idx in range(inputs.size(0)):
            target_label = target_labels[idx : idx + 1] if target_labels is not None else None
            adv_input = self._attack_single(inputs[idx : idx + 1], logits[idx : idx + 1], target_label)
            adv_inputs_list.append(adv_input)

        adv_inputs = torch.concat(adv_inputs_list, dim=0)
        with torch.no_grad():
            adv_logits = self.net(adv_inputs)

        success = attack_success(
            logits,
            adv_logits,
            targeted,
            target=target_labels if targeted else None,
        )

        return AttackResult(
            adv_inputs=adv_inputs,
            adv_logits=adv_logits,
            perturbation=adv_inputs - inputs,
            logits=logits.detach(),
            success=success,
        )

    def _attack_single(
        self, inputs: torch.Tensor, logits: torch.Tensor, target_label: Optional[torch.Tensor]
    ) -> torch.Tensor:
        """
        Run DeepFool on a single sample (batch dimension of size 1).

        Stops early once the prediction reaches the target (targeted) or leaves the original class
        (untargeted), or when no usable boundary direction can be found.
        """
        adv_inputs = inputs.clone()
        original_label = int(predict_labels(logits).item())
        target_value = int(target_label.item()) if target_label is not None else None

        for _ in range(self.max_iter):
            adv_inputs.requires_grad_(True)
            outputs = self.net(adv_inputs)
            current_label = int(predict_labels(outputs).item())

            if target_value is not None:
                if current_label == target_value:
                    break

                perturbation = self._targeted_perturbation(adv_inputs, outputs, current_label, target_value)

            else:
                if current_label != original_label:
                    break

                perturbation = self._untargeted_perturbation(adv_inputs, outputs, current_label)

            if perturbation is None:
                break

            # Apply the step outside autograd. The next iteration then starts from a fresh leaf tensor and
            # the previous iteration's graph can be released instead of accumulating across iterations.
            # Overshoot helps ensure boundary crossing.
            with torch.no_grad():
                adv_inputs = adv_inputs.detach() + (1.0 + self.overshoot) * perturbation
                adv_inputs = clamp_normalized(adv_inputs, self.rgb_stats)

        return adv_inputs.detach()

    def _targeted_perturbation(
        self, adv_inputs: torch.Tensor, outputs: torch.Tensor, current_label: int, target_label: int
    ) -> Optional[torch.Tensor]:
        """
        Linearized step from the current class toward the target class boundary.

        Returns a detached tensor shaped like ``adv_inputs``. Returns None when the gradient difference
        vanishes (norm below GRAD_EPS).
        """
        self.net.zero_grad(set_to_none=True)
        grad_current = torch.autograd.grad(outputs[0, current_label], adv_inputs, retain_graph=True)[0]
        grad_target = torch.autograd.grad(outputs[0, target_label], adv_inputs, retain_graph=False)[0]

        # Direction toward the target boundary; reshape (not view) tolerates non-contiguous gradients
        w = grad_target - grad_current
        w_norm = torch.norm(w.reshape(-1))
        if w_norm.item() < GRAD_EPS:
            return None

        # Distance to the decision boundary; detached so the step holds no reference to the forward graph
        logit_values = outputs.detach()
        f = logit_values[0, target_label] - logit_values[0, current_label]
        perturbation = (f.abs() / (w_norm**2 + GRAD_EPS)) * w

        return perturbation

    def _untargeted_perturbation(
        self, adv_inputs: torch.Tensor, outputs: torch.Tensor, current_label: int
    ) -> Optional[torch.Tensor]:
        """
        Linearized step toward the closest boundary among the top-scoring competing classes.

        Returns a detached tensor shaped like ``adv_inputs``. Returns None when there is no competing
        class, or when every candidate's gradient difference vanishes.
        """
        logit_values = outputs.detach()

        # Search the top-k competing classes
        top_k = min(self.num_classes, logit_values.shape[1])
        top_indices = torch.topk(logit_values, k=top_k, dim=1).indices[0]
        candidate_labels = [idx for idx in top_indices.tolist() if idx != current_label]

        if len(candidate_labels) == 0:
            return None

        self.net.zero_grad(set_to_none=True)
        grad_current = torch.autograd.grad(outputs[0, current_label], adv_inputs, retain_graph=True)[0]

        # Track the closest decision boundary
        best_dist = None
        best_w = None
        best_f = None
        best_w_norm = None
        for idx, label in enumerate(candidate_labels):
            # Keep the graph until the last class
            retain_graph = idx != len(candidate_labels) - 1
            grad_other = torch.autograd.grad(outputs[0, label], adv_inputs, retain_graph=retain_graph)[0]

            w_k = grad_other - grad_current
            w_norm = torch.norm(w_k.reshape(-1))
            if w_norm.item() < GRAD_EPS:
                continue

            f_k = logit_values[0, label] - logit_values[0, current_label]
            dist = f_k.abs() / (w_norm + GRAD_EPS)

            if best_dist is None or dist < best_dist:
                best_dist = dist
                best_w = w_k
                best_f = f_k
                best_w_norm = w_norm

        if best_w is None or best_f is None or best_w_norm is None:
            return None

        # Minimal perturbation toward the closest boundary
        perturbation = (best_f.abs() / (best_w_norm**2 + GRAD_EPS)) * best_w

        return perturbation
@@ -0,0 +1,67 @@
1
+ """
2
+ Fast Gradient Sign Method (FGSM)
3
+
4
+ Paper "Explaining and Harnessing Adversarial Examples", https://arxiv.org/abs/1412.6572
5
+ """
6
+
7
+ from typing import Optional
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from torch import nn
12
+
13
+ from birder.adversarial.base import AttackResult
14
+ from birder.adversarial.base import attack_success
15
+ from birder.adversarial.base import clamp_normalized
16
+ from birder.adversarial.base import pixel_eps_to_normalized
17
+ from birder.adversarial.base import predict_labels
18
+ from birder.adversarial.base import validate_target
19
+ from birder.data.transforms.classification import RGBType
20
+
21
+
22
class FGSM:
    """
    Fast Gradient Sign Method attack (single step).

    Moves the input by ``eps`` along the sign of the input gradient of the cross-entropy loss. Untargeted
    mode ascends the loss of the clean prediction; targeted mode descends the loss of the target. ``eps``
    is converted to normalized space with ``rgb_stats``, and the result is clamped to the valid normalized
    range.
    """

    def __init__(self, net: nn.Module, eps: float, *, rgb_stats: RGBType) -> None:
        self.net = net.eval()
        self.eps = eps
        self.rgb_stats = rgb_stats

    def __call__(self, input_tensor: torch.Tensor, target: Optional[torch.Tensor]) -> AttackResult:
        """
        Attack a batch.

        Args:
            input_tensor: normalized input batch.
            target: optional per-sample target class indices; when given the attack is targeted.

        Returns:
            AttackResult whose tensors are all detached from autograd.
        """
        inputs = input_tensor.detach().clone()
        inputs.requires_grad_(True)

        logits = self.net(inputs)
        targeted = target is not None
        if targeted is True:
            target = validate_target(target, inputs.shape[0], logits.shape[1], inputs.device)
        else:
            target = predict_labels(logits)

        loss = F.cross_entropy(logits, target)
        (grad,) = torch.autograd.grad(loss, inputs, retain_graph=False, create_graph=False)
        eps_norm = pixel_eps_to_normalized(self.eps, self.rgb_stats, device=inputs.device, dtype=inputs.dtype)

        # Targeted steps descend toward target, untargeted ascend away from original
        direction = -1.0 if targeted is True else 1.0

        # Build the adversarial example from detached inputs. Otherwise adv_inputs and the perturbation
        # would inherit requires_grad from `inputs` and be returned attached to an autograd graph.
        clean_inputs = inputs.detach()
        with torch.no_grad():
            perturbation = direction * eps_norm * grad.sign()
            adv_inputs = clamp_normalized(clean_inputs + perturbation, self.rgb_stats)
            adv_logits = self.net(adv_inputs)

        clean_logits = logits.detach()
        success = attack_success(
            clean_logits,
            adv_logits,
            targeted,
            target=target if targeted else None,
        )

        return AttackResult(
            adv_inputs=adv_inputs,
            adv_logits=adv_logits,
            perturbation=adv_inputs - clean_inputs,
            logits=clean_logits,
            success=success,
        )
@@ -0,0 +1,105 @@
1
+ """
2
+ Projected Gradient Descent (PGD)
3
+
4
+ Paper "Towards Deep Learning Models Resistant to Adversarial Attacks", https://arxiv.org/abs/1706.06083
5
+ """
6
+
7
+ # Reference license: MIT
8
+
9
+ from typing import Optional
10
+
11
+ import torch
12
+ import torch.nn.functional as F
13
+ from torch import nn
14
+
15
+ from birder.adversarial.base import AttackResult
16
+ from birder.adversarial.base import attack_success
17
+ from birder.adversarial.base import clamp_normalized
18
+ from birder.adversarial.base import pixel_eps_to_normalized
19
+ from birder.adversarial.base import predict_labels
20
+ from birder.adversarial.base import validate_target
21
+ from birder.data.transforms.classification import RGBType
22
+
23
+
24
class PGD:
    """
    Projected Gradient Descent attack with an L-infinity ball.

    Takes ``steps`` signed-gradient steps of size ``step_size`` on the cross-entropy loss. After every
    step the result is projected back into the ``eps`` ball around the clean input and into the valid
    normalized range. ``eps`` and ``step_size`` are converted to normalized space with ``rgb_stats``.
    ``step_size`` defaults to ``eps / steps``. With ``random_start`` the search begins from a uniform
    random point inside the ball.

    Raises:
        ValueError: if ``steps`` or ``step_size`` is not positive, or ``eps`` is negative.
    """

    def __init__(
        self,
        net: nn.Module,
        eps: float,
        steps: int = 10,
        step_size: Optional[float] = None,
        random_start: bool = False,
        *,
        rgb_stats: RGBType,
    ) -> None:
        if steps <= 0:
            raise ValueError("steps must be a positive integer")

        # A negative radius makes the projection clamp below have min > max, which torch resolves by
        # setting every element to max: a silently wrong perturbation instead of an error.
        # as_tensor keeps per-channel tensor eps values working.
        if torch.any(torch.as_tensor(eps) < 0):
            raise ValueError("eps must be non-negative")

        self.net = net.eval()
        self.eps = eps
        self.steps = steps
        if step_size is not None:
            self.step_size = step_size
        else:
            self.step_size = eps / steps

        self.random_start = random_start
        self.rgb_stats = rgb_stats

        if self.step_size <= 0:
            raise ValueError("step_size must be positive")

    def __call__(self, input_tensor: torch.Tensor, target: Optional[torch.Tensor]) -> AttackResult:
        """
        Attack a batch.

        Args:
            input_tensor: normalized input batch.
            target: optional per-sample target class indices; when given the attack is targeted.

        Returns:
            AttackResult with the adversarial batch, its logits, the perturbation, the clean logits and a
            per-sample success mask.
        """
        inputs = input_tensor.detach()
        with torch.no_grad():
            logits = self.net(inputs)

        targeted = target is not None
        if targeted:
            target = validate_target(target, inputs.shape[0], logits.shape[1], inputs.device)
        else:
            target = predict_labels(logits)

        eps_norm = pixel_eps_to_normalized(self.eps, self.rgb_stats, device=inputs.device, dtype=inputs.dtype)
        step_norm = pixel_eps_to_normalized(self.step_size, self.rgb_stats, device=inputs.device, dtype=inputs.dtype)

        # Targeted steps descend toward target, untargeted ascend away from original
        if targeted is True:
            direction = -1.0
        else:
            direction = 1.0

        adv_inputs = inputs.clone()
        if self.random_start is True:
            # Random start inside the epsilon ball
            adv_inputs = adv_inputs + torch.empty_like(adv_inputs).uniform_(-1.0, 1.0) * eps_norm
            adv_inputs = clamp_normalized(adv_inputs, self.rgb_stats)

        for _ in range(self.steps):
            adv_inputs.requires_grad_(True)
            adv_logits = self.net(adv_inputs)
            loss = F.cross_entropy(adv_logits, target)
            (grad,) = torch.autograd.grad(loss, adv_inputs, retain_graph=False, create_graph=False)
            adv_inputs = adv_inputs.detach() + direction * step_norm * grad.sign()

            # Project back into the epsilon ball around the original input.
            delta = torch.clamp(adv_inputs - inputs, min=-eps_norm, max=eps_norm)
            adv_inputs = clamp_normalized(inputs + delta, self.rgb_stats)

        with torch.no_grad():
            adv_logits = self.net(adv_inputs)

        success = attack_success(
            logits.detach(),
            adv_logits,
            targeted,
            target=target if targeted else None,
        )

        return AttackResult(
            adv_inputs=adv_inputs,
            adv_logits=adv_logits,
            perturbation=adv_inputs - inputs,
            logits=logits.detach(),
            success=success,
        )
@@ -0,0 +1,172 @@
1
+ """
2
+ SimBA (Simple Black-box Attack)
3
+
4
+ Paper "Simple Black-box Adversarial Attacks", https://arxiv.org/abs/1905.07121
5
+ """
6
+
7
+ from typing import Optional
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from torch import nn
12
+
13
+ from birder.adversarial.base import AttackResult
14
+ from birder.adversarial.base import attack_success
15
+ from birder.adversarial.base import clamp_normalized
16
+ from birder.adversarial.base import pixel_eps_to_normalized
17
+ from birder.adversarial.base import predict_labels
18
+ from birder.adversarial.base import validate_target
19
+ from birder.data.transforms.classification import RGBType
20
+
21
+
22
class SimBA:
    """
    Simple Black-box Attack using single-pixel coordinate steps.

    Samples are attacked one at a time with forward passes only. Pixel coordinates are visited in a
    random order (at most ``max_iter`` of them). For each coordinate, a step of +/- ``step_size`` is tried,
    converted to normalized space per channel with ``rgb_stats``. The better candidate replaces the
    current input only if it strictly lowers the objective.

    Raises:
        ValueError: if ``step_size`` or ``max_iter`` is not positive.
    """

    def __init__(self, net: nn.Module, step_size: float, max_iter: int = 1000, *, rgb_stats: RGBType) -> None:
        if step_size <= 0:
            raise ValueError("step_size must be positive")
        if max_iter <= 0:
            raise ValueError("max_iter must be positive")

        self.net = net.eval()
        self.step_size = step_size
        self.max_iter = max_iter
        self.rgb_stats = rgb_stats

    def __call__(self, input_tensor: torch.Tensor, target: Optional[torch.Tensor]) -> AttackResult:
        """
        Attack a batch and report the number of per-sample model queries in ``num_queries``.

        Args:
            input_tensor: normalized input batch.
            target: optional per-sample target class indices; when given the attack is targeted.
        """
        inputs = input_tensor.detach()
        with torch.no_grad():
            logits = self.net(inputs)

        labels = predict_labels(logits)
        target_labels = None
        if target is not None:
            target_labels = validate_target(target, inputs.shape[0], logits.shape[1], inputs.device)
        targeted = target_labels is not None

        per_sample_results = []
        total_queries = 0
        for i in range(inputs.size(0)):
            sample_label = labels[i : i + 1]
            sample_target = None if target_labels is None else target_labels[i : i + 1]
            adv_sample, sample_queries = self._attack_single(inputs[i : i + 1], sample_label, sample_target)
            per_sample_results.append(adv_sample)
            total_queries += sample_queries

        adv_inputs = torch.concat(per_sample_results, dim=0)
        with torch.no_grad():
            adv_logits = self.net(adv_inputs)

        # targeted is True exactly when target_labels is set, so it can be passed through directly
        success = attack_success(logits, adv_logits, targeted, target=target_labels)

        return AttackResult(
            adv_inputs=adv_inputs,
            adv_logits=adv_logits,
            perturbation=adv_inputs - inputs,
            logits=logits.detach(),
            success=success,
            num_queries=total_queries,
        )

    # pylint: disable=too-many-locals
    def _attack_single(
        self, inputs: torch.Tensor, label: torch.Tensor, target_label: Optional[torch.Tensor]
    ) -> tuple[torch.Tensor, int]:
        """
        Attack one sample (batch dimension of size 1).

        Returns the detached adversarial sample and the number of forward passes spent on it: one for the
        baseline plus two per visited coordinate.
        """
        adv = inputs.clone()
        queries = 1  # Baseline forward pass

        with torch.no_grad():
            best_logits = self.net(adv)
        best_objective = self._compute_objective(best_logits, label, target_label)

        if self._is_successful(best_logits, label, target_label):
            return adv.detach(), queries

        (_, channels, height, width) = adv.shape
        plane_size = height * width
        total_coords = channels * plane_size
        channel_steps = pixel_eps_to_normalized(
            self.step_size, self.rgb_stats, device=adv.device, dtype=adv.dtype
        ).view(-1)  # One step value per channel

        visit_order = torch.randperm(total_coords, device=adv.device)
        budget = min(self.max_iter, total_coords)

        # Coordinate-wise search in random order
        for coord in visit_order[:budget]:
            channel, offset = divmod(int(coord.item()), plane_size)
            row, col = divmod(offset, width)

            candidate, candidate_logits, candidate_objective = self._best_candidate(
                adv, channel, row, col, channel_steps[channel], label, target_label
            )
            queries += 2

            if candidate_objective < best_objective:
                adv, best_logits, best_objective = candidate, candidate_logits, candidate_objective
                if self._is_successful(best_logits, label, target_label):
                    break

        return adv.detach(), queries

    def _perturb_pixel(
        self, inputs: torch.Tensor, channel: int, row: int, col: int, step: torch.Tensor
    ) -> torch.Tensor:
        """Return a clamped copy of ``inputs`` with ``step`` added to a single pixel of sample 0."""
        shifted = inputs.clone()
        shifted[0, channel, row, col] += step
        return clamp_normalized(shifted, self.rgb_stats)

    def _evaluate_candidate(
        self, inputs: torch.Tensor, label: torch.Tensor, target_label: Optional[torch.Tensor]
    ) -> tuple[torch.Tensor, float]:
        """Run one forward pass and return the logits together with their objective value."""
        with torch.no_grad():
            candidate_logits = self.net(inputs)

        objective = self._compute_objective(candidate_logits, label, target_label)
        return candidate_logits, objective

    def _best_candidate(
        self,
        inputs: torch.Tensor,
        channel: int,
        row: int,
        col: int,
        step: torch.Tensor,
        label: torch.Tensor,
        target_label: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor, float]:
        """Try +step then -step on one pixel and return (inputs, logits, objective) of the better one."""
        adv_plus = self._perturb_pixel(inputs, channel, row, col, step)
        logits_plus, objective_plus = self._evaluate_candidate(adv_plus, label, target_label)

        adv_minus = self._perturb_pixel(inputs, channel, row, col, -step)
        logits_minus, objective_minus = self._evaluate_candidate(adv_minus, label, target_label)

        # Ties (and only ties) favor the positive step
        keep_plus = objective_plus <= objective_minus
        if keep_plus:
            return adv_plus, logits_plus, objective_plus

        return adv_minus, logits_minus, objective_minus

    @staticmethod
    def _compute_objective(
        logits: torch.Tensor, original_label: torch.Tensor, target_label: Optional[torch.Tensor]
    ) -> float:
        """
        Scalar objective that is minimized in both modes.

        Targeted: cross-entropy toward the target. Untargeted: negated cross-entropy of the original label.
        """
        if target_label is None:
            return -float(F.cross_entropy(logits, original_label).item())

        return float(F.cross_entropy(logits, target_label).item())

    @staticmethod
    def _is_successful(
        logits: torch.Tensor, original_label: torch.Tensor, target_label: Optional[torch.Tensor]
    ) -> bool:
        """True when the prediction hits the target (targeted) or leaves the original label (untargeted)."""
        prediction = predict_labels(logits)
        if target_label is None:
            return bool(prediction.ne(original_label).item())

        return bool(prediction.eq(target_label).item())