birder 0.4.0__tar.gz → 0.4.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (329)
  1. {birder-0.4.0 → birder-0.4.2}/PKG-INFO +5 -5
  2. {birder-0.4.0 → birder-0.4.2}/birder/adversarial/base.py +1 -1
  3. {birder-0.4.0 → birder-0.4.2}/birder/adversarial/simba.py +4 -4
  4. {birder-0.4.0 → birder-0.4.2}/birder/common/cli.py +1 -1
  5. {birder-0.4.0 → birder-0.4.2}/birder/common/fs_ops.py +11 -11
  6. {birder-0.4.0 → birder-0.4.2}/birder/common/lib.py +2 -2
  7. {birder-0.4.0 → birder-0.4.2}/birder/common/masking.py +3 -3
  8. {birder-0.4.0 → birder-0.4.2}/birder/common/training_cli.py +30 -3
  9. {birder-0.4.0 → birder-0.4.2}/birder/common/training_utils.py +97 -16
  10. {birder-0.4.0 → birder-0.4.2}/birder/data/collators/detection.py +9 -1
  11. {birder-0.4.0 → birder-0.4.2}/birder/data/transforms/detection.py +27 -8
  12. {birder-0.4.0 → birder-0.4.2}/birder/data/transforms/mosaic.py +1 -1
  13. {birder-0.4.0 → birder-0.4.2}/birder/datahub/classification.py +3 -3
  14. {birder-0.4.0 → birder-0.4.2}/birder/inference/classification.py +3 -3
  15. {birder-0.4.0 → birder-0.4.2}/birder/inference/data_parallel.py +1 -1
  16. {birder-0.4.0 → birder-0.4.2}/birder/inference/detection.py +5 -5
  17. {birder-0.4.0 → birder-0.4.2}/birder/inference/wbf.py +1 -1
  18. {birder-0.4.0 → birder-0.4.2}/birder/introspection/attention_rollout.py +5 -5
  19. {birder-0.4.0 → birder-0.4.2}/birder/introspection/feature_pca.py +4 -4
  20. {birder-0.4.0 → birder-0.4.2}/birder/introspection/gradcam.py +1 -1
  21. {birder-0.4.0 → birder-0.4.2}/birder/introspection/guided_backprop.py +2 -2
  22. {birder-0.4.0 → birder-0.4.2}/birder/introspection/transformer_attribution.py +3 -3
  23. {birder-0.4.0 → birder-0.4.2}/birder/layers/attention_pool.py +2 -2
  24. {birder-0.4.0 → birder-0.4.2}/birder/model_registry/model_registry.py +2 -1
  25. {birder-0.4.0 → birder-0.4.2}/birder/net/__init__.py +2 -0
  26. {birder-0.4.0 → birder-0.4.2}/birder/net/_rope_vit_configs.py +5 -0
  27. {birder-0.4.0 → birder-0.4.2}/birder/net/_vit_configs.py +5 -13
  28. {birder-0.4.0 → birder-0.4.2}/birder/net/alexnet.py +5 -5
  29. {birder-0.4.0 → birder-0.4.2}/birder/net/base.py +28 -3
  30. {birder-0.4.0 → birder-0.4.2}/birder/net/biformer.py +17 -17
  31. {birder-0.4.0 → birder-0.4.2}/birder/net/cait.py +5 -5
  32. {birder-0.4.0 → birder-0.4.2}/birder/net/cas_vit.py +1 -1
  33. {birder-0.4.0 → birder-0.4.2}/birder/net/coat.py +18 -18
  34. {birder-0.4.0 → birder-0.4.2}/birder/net/convnext_v1.py +2 -10
  35. birder-0.4.2/birder/net/convnext_v1_iso.py +198 -0
  36. {birder-0.4.0 → birder-0.4.2}/birder/net/convnext_v2.py +2 -10
  37. {birder-0.4.0 → birder-0.4.2}/birder/net/crossformer.py +9 -9
  38. {birder-0.4.0 → birder-0.4.2}/birder/net/crossvit.py +1 -1
  39. {birder-0.4.0 → birder-0.4.2}/birder/net/cspnet.py +1 -1
  40. {birder-0.4.0 → birder-0.4.2}/birder/net/cswin_transformer.py +10 -10
  41. {birder-0.4.0 → birder-0.4.2}/birder/net/davit.py +10 -10
  42. {birder-0.4.0 → birder-0.4.2}/birder/net/deit.py +56 -3
  43. {birder-0.4.0 → birder-0.4.2}/birder/net/deit3.py +27 -15
  44. {birder-0.4.0 → birder-0.4.2}/birder/net/detection/__init__.py +6 -0
  45. birder-0.4.0/birder/net/detection/yolo_anchors.py → birder-0.4.2/birder/net/detection/_yolo_anchors.py +5 -5
  46. {birder-0.4.0 → birder-0.4.2}/birder/net/detection/base.py +6 -5
  47. {birder-0.4.0 → birder-0.4.2}/birder/net/detection/deformable_detr.py +38 -40
  48. {birder-0.4.0 → birder-0.4.2}/birder/net/detection/detr.py +15 -15
  49. {birder-0.4.0 → birder-0.4.2}/birder/net/detection/efficientdet.py +9 -28
  50. {birder-0.4.0 → birder-0.4.2}/birder/net/detection/faster_rcnn.py +22 -22
  51. {birder-0.4.0 → birder-0.4.2}/birder/net/detection/fcos.py +8 -8
  52. birder-0.4.2/birder/net/detection/lw_detr.py +1181 -0
  53. birder-0.4.2/birder/net/detection/plain_detr.py +854 -0
  54. {birder-0.4.0 → birder-0.4.2}/birder/net/detection/retinanet.py +5 -5
  55. {birder-0.4.0 → birder-0.4.2}/birder/net/detection/rt_detr_v1.py +91 -35
  56. birder-0.4.2/birder/net/detection/rt_detr_v2.py +1130 -0
  57. {birder-0.4.0 → birder-0.4.2}/birder/net/detection/ssd.py +5 -5
  58. {birder-0.4.0 → birder-0.4.2}/birder/net/detection/ssdlite.py +2 -2
  59. {birder-0.4.0 → birder-0.4.2}/birder/net/detection/yolo_v2.py +12 -12
  60. {birder-0.4.0 → birder-0.4.2}/birder/net/detection/yolo_v3.py +19 -19
  61. {birder-0.4.0 → birder-0.4.2}/birder/net/detection/yolo_v4.py +16 -16
  62. {birder-0.4.0 → birder-0.4.2}/birder/net/detection/yolo_v4_tiny.py +3 -3
  63. {birder-0.4.0 → birder-0.4.2}/birder/net/edgenext.py +3 -3
  64. {birder-0.4.0 → birder-0.4.2}/birder/net/edgevit.py +13 -17
  65. {birder-0.4.0 → birder-0.4.2}/birder/net/efficientformer_v1.py +1 -1
  66. {birder-0.4.0 → birder-0.4.2}/birder/net/efficientvim.py +9 -9
  67. {birder-0.4.0 → birder-0.4.2}/birder/net/efficientvit_mit.py +2 -2
  68. {birder-0.4.0 → birder-0.4.2}/birder/net/efficientvit_msft.py +4 -4
  69. {birder-0.4.0 → birder-0.4.2}/birder/net/fasternet.py +1 -1
  70. {birder-0.4.0 → birder-0.4.2}/birder/net/fastvit.py +5 -12
  71. {birder-0.4.0 → birder-0.4.2}/birder/net/flexivit.py +28 -15
  72. {birder-0.4.0 → birder-0.4.2}/birder/net/focalnet.py +5 -9
  73. {birder-0.4.0 → birder-0.4.2}/birder/net/gc_vit.py +11 -11
  74. {birder-0.4.0 → birder-0.4.2}/birder/net/ghostnet_v1.py +1 -1
  75. {birder-0.4.0 → birder-0.4.2}/birder/net/ghostnet_v2.py +1 -1
  76. {birder-0.4.0 → birder-0.4.2}/birder/net/groupmixformer.py +12 -12
  77. {birder-0.4.0 → birder-0.4.2}/birder/net/hgnet_v1.py +1 -1
  78. {birder-0.4.0 → birder-0.4.2}/birder/net/hgnet_v2.py +4 -4
  79. {birder-0.4.0 → birder-0.4.2}/birder/net/hiera.py +6 -6
  80. {birder-0.4.0 → birder-0.4.2}/birder/net/hieradet.py +11 -11
  81. {birder-0.4.0 → birder-0.4.2}/birder/net/hornet.py +3 -3
  82. {birder-0.4.0 → birder-0.4.2}/birder/net/iformer.py +4 -4
  83. {birder-0.4.0 → birder-0.4.2}/birder/net/inception_next.py +4 -14
  84. {birder-0.4.0 → birder-0.4.2}/birder/net/levit.py +3 -3
  85. {birder-0.4.0 → birder-0.4.2}/birder/net/lit_v1.py +13 -15
  86. {birder-0.4.0 → birder-0.4.2}/birder/net/lit_v1_tiny.py +9 -9
  87. {birder-0.4.0 → birder-0.4.2}/birder/net/lit_v2.py +14 -15
  88. {birder-0.4.0 → birder-0.4.2}/birder/net/maxvit.py +10 -22
  89. {birder-0.4.0 → birder-0.4.2}/birder/net/metaformer.py +2 -2
  90. {birder-0.4.0 → birder-0.4.2}/birder/net/mim/crossmae.py +5 -5
  91. {birder-0.4.0 → birder-0.4.2}/birder/net/mim/fcmae.py +3 -5
  92. {birder-0.4.0 → birder-0.4.2}/birder/net/mim/mae_hiera.py +7 -7
  93. {birder-0.4.0 → birder-0.4.2}/birder/net/mim/mae_vit.py +3 -5
  94. {birder-0.4.0 → birder-0.4.2}/birder/net/mim/simmim.py +2 -3
  95. {birder-0.4.0 → birder-0.4.2}/birder/net/mnasnet.py +2 -2
  96. {birder-0.4.0 → birder-0.4.2}/birder/net/mobilenet_v4_hybrid.py +4 -4
  97. {birder-0.4.0 → birder-0.4.2}/birder/net/mobileone.py +5 -12
  98. {birder-0.4.0 → birder-0.4.2}/birder/net/mobilevit_v1.py +2 -2
  99. {birder-0.4.0 → birder-0.4.2}/birder/net/mobilevit_v2.py +5 -9
  100. {birder-0.4.0 → birder-0.4.2}/birder/net/mvit_v2.py +24 -24
  101. {birder-0.4.0 → birder-0.4.2}/birder/net/nextvit.py +2 -2
  102. {birder-0.4.0 → birder-0.4.2}/birder/net/pit.py +11 -26
  103. {birder-0.4.0 → birder-0.4.2}/birder/net/pvt_v1.py +4 -4
  104. {birder-0.4.0 → birder-0.4.2}/birder/net/pvt_v2.py +5 -11
  105. {birder-0.4.0 → birder-0.4.2}/birder/net/regionvit.py +15 -15
  106. {birder-0.4.0 → birder-0.4.2}/birder/net/regnet.py +1 -1
  107. {birder-0.4.0 → birder-0.4.2}/birder/net/repghost.py +4 -5
  108. {birder-0.4.0 → birder-0.4.2}/birder/net/repvgg.py +3 -5
  109. {birder-0.4.0 → birder-0.4.2}/birder/net/repvit.py +2 -2
  110. {birder-0.4.0 → birder-0.4.2}/birder/net/resnest.py +1 -1
  111. {birder-0.4.0 → birder-0.4.2}/birder/net/resnext.py +2 -2
  112. {birder-0.4.0 → birder-0.4.2}/birder/net/rope_deit3.py +29 -15
  113. {birder-0.4.0 → birder-0.4.2}/birder/net/rope_flexivit.py +28 -15
  114. {birder-0.4.0 → birder-0.4.2}/birder/net/rope_vit.py +41 -23
  115. {birder-0.4.0 → birder-0.4.2}/birder/net/sequencer2d.py +3 -4
  116. {birder-0.4.0 → birder-0.4.2}/birder/net/shufflenet_v1.py +1 -1
  117. {birder-0.4.0 → birder-0.4.2}/birder/net/shufflenet_v2.py +1 -1
  118. {birder-0.4.0 → birder-0.4.2}/birder/net/simple_vit.py +47 -5
  119. {birder-0.4.0 → birder-0.4.2}/birder/net/smt.py +7 -7
  120. {birder-0.4.0 → birder-0.4.2}/birder/net/ssl/barlow_twins.py +1 -1
  121. {birder-0.4.0 → birder-0.4.2}/birder/net/ssl/byol.py +2 -2
  122. {birder-0.4.0 → birder-0.4.2}/birder/net/ssl/capi.py +3 -3
  123. {birder-0.4.0 → birder-0.4.2}/birder/net/ssl/data2vec2.py +1 -1
  124. {birder-0.4.0 → birder-0.4.2}/birder/net/ssl/dino_v2.py +11 -1
  125. {birder-0.4.0 → birder-0.4.2}/birder/net/ssl/franca.py +26 -2
  126. {birder-0.4.0 → birder-0.4.2}/birder/net/ssl/i_jepa.py +4 -4
  127. {birder-0.4.0 → birder-0.4.2}/birder/net/ssl/mmcr.py +1 -1
  128. {birder-0.4.0 → birder-0.4.2}/birder/net/swiftformer.py +1 -1
  129. {birder-0.4.0 → birder-0.4.2}/birder/net/swin_transformer_v1.py +4 -5
  130. {birder-0.4.0 → birder-0.4.2}/birder/net/swin_transformer_v2.py +4 -7
  131. {birder-0.4.0 → birder-0.4.2}/birder/net/tiny_vit.py +3 -3
  132. {birder-0.4.0 → birder-0.4.2}/birder/net/transnext.py +19 -19
  133. {birder-0.4.0 → birder-0.4.2}/birder/net/uniformer.py +4 -4
  134. {birder-0.4.0 → birder-0.4.2}/birder/net/vgg.py +1 -10
  135. {birder-0.4.0 → birder-0.4.2}/birder/net/vit.py +58 -27
  136. {birder-0.4.0 → birder-0.4.2}/birder/net/vit_parallel.py +35 -20
  137. {birder-0.4.0 → birder-0.4.2}/birder/net/vit_sam.py +72 -26
  138. {birder-0.4.0 → birder-0.4.2}/birder/net/vovnet_v2.py +1 -1
  139. {birder-0.4.0 → birder-0.4.2}/birder/net/xcit.py +9 -7
  140. {birder-0.4.0 → birder-0.4.2}/birder/ops/msda.py +4 -4
  141. {birder-0.4.0 → birder-0.4.2}/birder/ops/swattention.py +10 -10
  142. {birder-0.4.0 → birder-0.4.2}/birder/results/classification.py +3 -3
  143. {birder-0.4.0 → birder-0.4.2}/birder/results/gui.py +8 -8
  144. {birder-0.4.0 → birder-0.4.2}/birder/scripts/benchmark.py +37 -12
  145. {birder-0.4.0 → birder-0.4.2}/birder/scripts/evaluate.py +1 -1
  146. {birder-0.4.0 → birder-0.4.2}/birder/scripts/predict.py +3 -3
  147. {birder-0.4.0 → birder-0.4.2}/birder/scripts/predict_detection.py +2 -2
  148. {birder-0.4.0 → birder-0.4.2}/birder/scripts/train.py +69 -17
  149. {birder-0.4.0 → birder-0.4.2}/birder/scripts/train_barlow_twins.py +10 -7
  150. {birder-0.4.0 → birder-0.4.2}/birder/scripts/train_byol.py +10 -7
  151. {birder-0.4.0 → birder-0.4.2}/birder/scripts/train_capi.py +28 -20
  152. {birder-0.4.0 → birder-0.4.2}/birder/scripts/train_data2vec.py +10 -7
  153. {birder-0.4.0 → birder-0.4.2}/birder/scripts/train_data2vec2.py +10 -7
  154. {birder-0.4.0 → birder-0.4.2}/birder/scripts/train_detection.py +31 -15
  155. {birder-0.4.0 → birder-0.4.2}/birder/scripts/train_dino_v1.py +13 -9
  156. {birder-0.4.0 → birder-0.4.2}/birder/scripts/train_dino_v2.py +27 -14
  157. {birder-0.4.0 → birder-0.4.2}/birder/scripts/train_dino_v2_dist.py +28 -15
  158. {birder-0.4.0 → birder-0.4.2}/birder/scripts/train_franca.py +16 -9
  159. {birder-0.4.0 → birder-0.4.2}/birder/scripts/train_i_jepa.py +12 -9
  160. {birder-0.4.0 → birder-0.4.2}/birder/scripts/train_ibot.py +15 -11
  161. {birder-0.4.0 → birder-0.4.2}/birder/scripts/train_kd.py +70 -19
  162. {birder-0.4.0 → birder-0.4.2}/birder/scripts/train_mim.py +11 -8
  163. {birder-0.4.0 → birder-0.4.2}/birder/scripts/train_mmcr.py +11 -8
  164. {birder-0.4.0 → birder-0.4.2}/birder/scripts/train_rotnet.py +11 -7
  165. {birder-0.4.0 → birder-0.4.2}/birder/scripts/train_simclr.py +10 -7
  166. {birder-0.4.0 → birder-0.4.2}/birder/scripts/train_vicreg.py +10 -7
  167. {birder-0.4.0 → birder-0.4.2}/birder/tools/adversarial.py +4 -4
  168. {birder-0.4.0 → birder-0.4.2}/birder/tools/auto_anchors.py +5 -5
  169. {birder-0.4.0 → birder-0.4.2}/birder/tools/avg_model.py +1 -1
  170. {birder-0.4.0 → birder-0.4.2}/birder/tools/convert_model.py +30 -22
  171. {birder-0.4.0 → birder-0.4.2}/birder/tools/det_results.py +1 -1
  172. {birder-0.4.0 → birder-0.4.2}/birder/tools/download_model.py +1 -1
  173. {birder-0.4.0 → birder-0.4.2}/birder/tools/ensemble_model.py +1 -1
  174. {birder-0.4.0 → birder-0.4.2}/birder/tools/introspection.py +11 -2
  175. {birder-0.4.0 → birder-0.4.2}/birder/tools/labelme_to_coco.py +2 -2
  176. {birder-0.4.0 → birder-0.4.2}/birder/tools/model_info.py +12 -14
  177. {birder-0.4.0 → birder-0.4.2}/birder/tools/pack.py +8 -8
  178. {birder-0.4.0 → birder-0.4.2}/birder/tools/quantize_model.py +53 -4
  179. {birder-0.4.0 → birder-0.4.2}/birder/tools/results.py +2 -2
  180. {birder-0.4.0 → birder-0.4.2}/birder/tools/show_det_iterator.py +19 -6
  181. {birder-0.4.0 → birder-0.4.2}/birder/tools/show_iterator.py +2 -2
  182. {birder-0.4.0 → birder-0.4.2}/birder/tools/similarity.py +5 -5
  183. {birder-0.4.0 → birder-0.4.2}/birder/tools/stats.py +4 -6
  184. {birder-0.4.0 → birder-0.4.2}/birder/tools/voc_to_coco.py +1 -1
  185. birder-0.4.2/birder/version.py +1 -0
  186. {birder-0.4.0 → birder-0.4.2}/birder.egg-info/PKG-INFO +5 -5
  187. {birder-0.4.0 → birder-0.4.2}/birder.egg-info/SOURCES.txt +5 -1
  188. {birder-0.4.0 → birder-0.4.2}/birder.egg-info/requires.txt +4 -4
  189. {birder-0.4.0 → birder-0.4.2}/requirements/_requirements-dev.txt +3 -3
  190. {birder-0.4.0 → birder-0.4.2}/requirements/requirements.txt +1 -1
  191. {birder-0.4.0 → birder-0.4.2}/tests/test_adversarial.py +5 -5
  192. {birder-0.4.0 → birder-0.4.2}/tests/test_collators.py +2 -2
  193. {birder-0.4.0 → birder-0.4.2}/tests/test_common.py +74 -14
  194. {birder-0.4.0 → birder-0.4.2}/tests/test_datasets.py +2 -2
  195. {birder-0.4.0 → birder-0.4.2}/tests/test_inference.py +10 -10
  196. {birder-0.4.0 → birder-0.4.2}/tests/test_introspection.py +5 -6
  197. {birder-0.4.0 → birder-0.4.2}/tests/test_kernels.py +19 -14
  198. {birder-0.4.0 → birder-0.4.2}/tests/test_model_registry.py +1 -1
  199. {birder-0.4.0 → birder-0.4.2}/tests/test_net.py +85 -40
  200. {birder-0.4.0 → birder-0.4.2}/tests/test_net_detection.py +34 -12
  201. {birder-0.4.0 → birder-0.4.2}/tests/test_net_mim.py +1 -1
  202. {birder-0.4.0 → birder-0.4.2}/tests/test_net_ssl.py +166 -49
  203. {birder-0.4.0 → birder-0.4.2}/tests/test_ops.py +6 -6
  204. {birder-0.4.0 → birder-0.4.2}/tests/test_transforms.py +8 -8
  205. birder-0.4.0/birder/version.py +0 -1
  206. {birder-0.4.0 → birder-0.4.2}/LICENSE +0 -0
  207. {birder-0.4.0 → birder-0.4.2}/README.md +0 -0
  208. {birder-0.4.0 → birder-0.4.2}/birder/__init__.py +0 -0
  209. {birder-0.4.0 → birder-0.4.2}/birder/adversarial/__init__.py +0 -0
  210. {birder-0.4.0 → birder-0.4.2}/birder/adversarial/deepfool.py +0 -0
  211. {birder-0.4.0 → birder-0.4.2}/birder/adversarial/fgsm.py +0 -0
  212. {birder-0.4.0 → birder-0.4.2}/birder/adversarial/pgd.py +0 -0
  213. {birder-0.4.0 → birder-0.4.2}/birder/common/__init__.py +0 -0
  214. {birder-0.4.0 → birder-0.4.2}/birder/conf/__init__.py +0 -0
  215. {birder-0.4.0 → birder-0.4.2}/birder/conf/settings.py +0 -0
  216. {birder-0.4.0 → birder-0.4.2}/birder/data/__init__.py +0 -0
  217. {birder-0.4.0 → birder-0.4.2}/birder/data/collators/__init__.py +0 -0
  218. {birder-0.4.0 → birder-0.4.2}/birder/data/dataloader/__init__.py +0 -0
  219. {birder-0.4.0 → birder-0.4.2}/birder/data/dataloader/webdataset.py +0 -0
  220. {birder-0.4.0 → birder-0.4.2}/birder/data/datasets/__init__.py +0 -0
  221. {birder-0.4.0 → birder-0.4.2}/birder/data/datasets/coco.py +0 -0
  222. {birder-0.4.0 → birder-0.4.2}/birder/data/datasets/directory.py +0 -0
  223. {birder-0.4.0 → birder-0.4.2}/birder/data/datasets/fake.py +0 -0
  224. {birder-0.4.0 → birder-0.4.2}/birder/data/datasets/webdataset.py +0 -0
  225. {birder-0.4.0 → birder-0.4.2}/birder/data/transforms/__init__.py +0 -0
  226. {birder-0.4.0 → birder-0.4.2}/birder/data/transforms/classification.py +0 -0
  227. {birder-0.4.0 → birder-0.4.2}/birder/datahub/__init__.py +0 -0
  228. {birder-0.4.0 → birder-0.4.2}/birder/datahub/_lib.py +0 -0
  229. {birder-0.4.0 → birder-0.4.2}/birder/inference/__init__.py +0 -0
  230. {birder-0.4.0 → birder-0.4.2}/birder/introspection/__init__.py +0 -0
  231. {birder-0.4.0 → birder-0.4.2}/birder/introspection/base.py +0 -0
  232. {birder-0.4.0 → birder-0.4.2}/birder/kernels/__init__.py +0 -0
  233. {birder-0.4.0 → birder-0.4.2}/birder/kernels/deformable_detr/cpu/ms_deform_attn_cpu.cpp +0 -0
  234. {birder-0.4.0 → birder-0.4.2}/birder/kernels/deformable_detr/cpu/ms_deform_attn_cpu.h +0 -0
  235. {birder-0.4.0 → birder-0.4.2}/birder/kernels/deformable_detr/cuda/ms_deform_attn_cuda.cu +0 -0
  236. {birder-0.4.0 → birder-0.4.2}/birder/kernels/deformable_detr/cuda/ms_deform_attn_cuda.h +0 -0
  237. {birder-0.4.0 → birder-0.4.2}/birder/kernels/deformable_detr/cuda/ms_deform_im2col_cuda.cuh +0 -0
  238. {birder-0.4.0 → birder-0.4.2}/birder/kernels/deformable_detr/ms_deform_attn.h +0 -0
  239. {birder-0.4.0 → birder-0.4.2}/birder/kernels/deformable_detr/vision.cpp +0 -0
  240. {birder-0.4.0 → birder-0.4.2}/birder/kernels/load_kernel.py +0 -0
  241. {birder-0.4.0 → birder-0.4.2}/birder/kernels/soft_nms/op.cpp +0 -0
  242. {birder-0.4.0 → birder-0.4.2}/birder/kernels/soft_nms/soft_nms.cpp +0 -0
  243. {birder-0.4.0 → birder-0.4.2}/birder/kernels/soft_nms/soft_nms.h +0 -0
  244. {birder-0.4.0 → birder-0.4.2}/birder/kernels/transnext/av_bw_kernel.cu +0 -0
  245. {birder-0.4.0 → birder-0.4.2}/birder/kernels/transnext/av_fw_kernel.cu +0 -0
  246. {birder-0.4.0 → birder-0.4.2}/birder/kernels/transnext/qk_bw_kernel.cu +0 -0
  247. {birder-0.4.0 → birder-0.4.2}/birder/kernels/transnext/qk_fw_kernel.cu +0 -0
  248. {birder-0.4.0 → birder-0.4.2}/birder/kernels/transnext/qk_rpb_bw_kernel.cu +0 -0
  249. {birder-0.4.0 → birder-0.4.2}/birder/kernels/transnext/qk_rpb_fw_kernel.cu +0 -0
  250. {birder-0.4.0 → birder-0.4.2}/birder/kernels/transnext/swattention.cpp +0 -0
  251. {birder-0.4.0 → birder-0.4.2}/birder/layers/__init__.py +0 -0
  252. {birder-0.4.0 → birder-0.4.2}/birder/layers/activations.py +0 -0
  253. {birder-0.4.0 → birder-0.4.2}/birder/layers/ffn.py +0 -0
  254. {birder-0.4.0 → birder-0.4.2}/birder/layers/gem.py +0 -0
  255. {birder-0.4.0 → birder-0.4.2}/birder/layers/layer_norm.py +0 -0
  256. {birder-0.4.0 → birder-0.4.2}/birder/layers/layer_scale.py +0 -0
  257. {birder-0.4.0 → birder-0.4.2}/birder/model_registry/__init__.py +0 -0
  258. {birder-0.4.0 → birder-0.4.2}/birder/model_registry/manifest.py +0 -0
  259. {birder-0.4.0 → birder-0.4.2}/birder/net/conv2former.py +0 -0
  260. {birder-0.4.0 → birder-0.4.2}/birder/net/convmixer.py +0 -0
  261. {birder-0.4.0 → birder-0.4.2}/birder/net/darknet.py +0 -0
  262. {birder-0.4.0 → birder-0.4.2}/birder/net/densenet.py +0 -0
  263. {birder-0.4.0 → birder-0.4.2}/birder/net/detection/vitdet.py +0 -0
  264. {birder-0.4.0 → birder-0.4.2}/birder/net/dpn.py +0 -0
  265. {birder-0.4.0 → birder-0.4.2}/birder/net/efficientformer_v2.py +0 -0
  266. {birder-0.4.0 → birder-0.4.2}/birder/net/efficientnet_lite.py +0 -0
  267. {birder-0.4.0 → birder-0.4.2}/birder/net/efficientnet_v1.py +0 -0
  268. {birder-0.4.0 → birder-0.4.2}/birder/net/efficientnet_v2.py +0 -0
  269. {birder-0.4.0 → birder-0.4.2}/birder/net/inception_resnet_v1.py +0 -0
  270. {birder-0.4.0 → birder-0.4.2}/birder/net/inception_resnet_v2.py +0 -0
  271. {birder-0.4.0 → birder-0.4.2}/birder/net/inception_v3.py +0 -0
  272. {birder-0.4.0 → birder-0.4.2}/birder/net/inception_v4.py +0 -0
  273. {birder-0.4.0 → birder-0.4.2}/birder/net/mim/__init__.py +0 -0
  274. {birder-0.4.0 → birder-0.4.2}/birder/net/mim/base.py +0 -0
  275. {birder-0.4.0 → birder-0.4.2}/birder/net/mobilenet_v1.py +0 -0
  276. {birder-0.4.0 → birder-0.4.2}/birder/net/mobilenet_v2.py +0 -0
  277. {birder-0.4.0 → birder-0.4.2}/birder/net/mobilenet_v3.py +0 -0
  278. {birder-0.4.0 → birder-0.4.2}/birder/net/mobilenet_v4.py +0 -0
  279. {birder-0.4.0 → birder-0.4.2}/birder/net/moganet.py +0 -0
  280. {birder-0.4.0 → birder-0.4.2}/birder/net/nfnet.py +0 -0
  281. {birder-0.4.0 → birder-0.4.2}/birder/net/rdnet.py +0 -0
  282. {birder-0.4.0 → birder-0.4.2}/birder/net/regnet_z.py +0 -0
  283. {birder-0.4.0 → birder-0.4.2}/birder/net/resmlp.py +0 -0
  284. {birder-0.4.0 → birder-0.4.2}/birder/net/resnet_v1.py +0 -0
  285. {birder-0.4.0 → birder-0.4.2}/birder/net/resnet_v2.py +0 -0
  286. {birder-0.4.0 → birder-0.4.2}/birder/net/squeezenet.py +0 -0
  287. {birder-0.4.0 → birder-0.4.2}/birder/net/squeezenext.py +0 -0
  288. {birder-0.4.0 → birder-0.4.2}/birder/net/ssl/__init__.py +0 -0
  289. {birder-0.4.0 → birder-0.4.2}/birder/net/ssl/base.py +0 -0
  290. {birder-0.4.0 → birder-0.4.2}/birder/net/ssl/data2vec.py +0 -0
  291. {birder-0.4.0 → birder-0.4.2}/birder/net/ssl/dino_v1.py +0 -0
  292. {birder-0.4.0 → birder-0.4.2}/birder/net/ssl/ibot.py +0 -0
  293. {birder-0.4.0 → birder-0.4.2}/birder/net/ssl/simclr.py +0 -0
  294. {birder-0.4.0 → birder-0.4.2}/birder/net/ssl/sscd.py +0 -0
  295. {birder-0.4.0 → birder-0.4.2}/birder/net/ssl/vicreg.py +0 -0
  296. {birder-0.4.0 → birder-0.4.2}/birder/net/starnet.py +0 -0
  297. {birder-0.4.0 → birder-0.4.2}/birder/net/van.py +0 -0
  298. {birder-0.4.0 → birder-0.4.2}/birder/net/vgg_reduced.py +0 -0
  299. {birder-0.4.0 → birder-0.4.2}/birder/net/vovnet_v1.py +0 -0
  300. {birder-0.4.0 → birder-0.4.2}/birder/net/wide_resnet.py +0 -0
  301. {birder-0.4.0 → birder-0.4.2}/birder/net/xception.py +0 -0
  302. {birder-0.4.0 → birder-0.4.2}/birder/ops/__init__.py +0 -0
  303. {birder-0.4.0 → birder-0.4.2}/birder/ops/soft_nms.py +0 -0
  304. {birder-0.4.0 → birder-0.4.2}/birder/optim/__init__.py +0 -0
  305. {birder-0.4.0 → birder-0.4.2}/birder/optim/lamb.py +0 -0
  306. {birder-0.4.0 → birder-0.4.2}/birder/optim/lars.py +0 -0
  307. {birder-0.4.0 → birder-0.4.2}/birder/py.typed +0 -0
  308. {birder-0.4.0 → birder-0.4.2}/birder/results/__init__.py +0 -0
  309. {birder-0.4.0 → birder-0.4.2}/birder/results/detection.py +0 -0
  310. {birder-0.4.0 → birder-0.4.2}/birder/scheduler/__init__.py +0 -0
  311. {birder-0.4.0 → birder-0.4.2}/birder/scheduler/cooldown.py +0 -0
  312. {birder-0.4.0 → birder-0.4.2}/birder/scripts/__init__.py +0 -0
  313. {birder-0.4.0 → birder-0.4.2}/birder/scripts/__main__.py +0 -0
  314. {birder-0.4.0 → birder-0.4.2}/birder/tools/__init__.py +0 -0
  315. {birder-0.4.0 → birder-0.4.2}/birder/tools/__main__.py +0 -0
  316. {birder-0.4.0 → birder-0.4.2}/birder/tools/list_models.py +0 -0
  317. {birder-0.4.0 → birder-0.4.2}/birder/tools/verify_coco.py +0 -0
  318. {birder-0.4.0 → birder-0.4.2}/birder/tools/verify_directory.py +0 -0
  319. {birder-0.4.0 → birder-0.4.2}/birder.egg-info/dependency_links.txt +0 -0
  320. {birder-0.4.0 → birder-0.4.2}/birder.egg-info/entry_points.txt +0 -0
  321. {birder-0.4.0 → birder-0.4.2}/birder.egg-info/top_level.txt +0 -0
  322. {birder-0.4.0 → birder-0.4.2}/pyproject.toml +0 -0
  323. {birder-0.4.0 → birder-0.4.2}/requirements/requirements-hf.txt +0 -0
  324. {birder-0.4.0 → birder-0.4.2}/setup.cfg +0 -0
  325. {birder-0.4.0 → birder-0.4.2}/tests/test_dataloaders.py +0 -0
  326. {birder-0.4.0 → birder-0.4.2}/tests/test_layers.py +0 -0
  327. {birder-0.4.0 → birder-0.4.2}/tests/test_optim.py +0 -0
  328. {birder-0.4.0 → birder-0.4.2}/tests/test_results.py +0 -0
  329. {birder-0.4.0 → birder-0.4.2}/tests/test_scheduler.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: birder
3
- Version: 0.4.0
3
+ Version: 0.4.2
4
4
  Summary: An open-source computer vision framework for wildlife image analysis, featuring state-of-the-art models for species classification and detection.
5
5
  Author: Ofer Hasson
6
6
  License-Expression: Apache-2.0
@@ -26,7 +26,7 @@ License-File: LICENSE
26
26
  Requires-Dist: matplotlib>=3.9.0
27
27
  Requires-Dist: numpy>=2.2.0
28
28
  Requires-Dist: onnx>=1.18.0
29
- Requires-Dist: onnxscript~=0.5.7
29
+ Requires-Dist: onnxscript~=0.6.0
30
30
  Requires-Dist: Pillow>=12.0.0
31
31
  Requires-Dist: polars>=1.31.0
32
32
  Requires-Dist: pyarrow>=20.0.0
@@ -43,12 +43,12 @@ Requires-Dist: torch>=2.7.0
43
43
  Requires-Dist: torchvision
44
44
  Provides-Extra: dev
45
45
  Requires-Dist: altair~=5.5.0; extra == "dev"
46
- Requires-Dist: bandit~=1.9.2; extra == "dev"
47
- Requires-Dist: black~=25.12.0; extra == "dev"
46
+ Requires-Dist: bandit~=1.9.3; extra == "dev"
47
+ Requires-Dist: black~=26.1.0; extra == "dev"
48
48
  Requires-Dist: build~=1.4.0; extra == "dev"
49
49
  Requires-Dist: bumpver~=2025.1131; extra == "dev"
50
50
  Requires-Dist: captum~=0.7.0; extra == "dev"
51
- Requires-Dist: coverage~=7.13.1; extra == "dev"
51
+ Requires-Dist: coverage~=7.13.2; extra == "dev"
52
52
  Requires-Dist: debugpy; extra == "dev"
53
53
  Requires-Dist: flake8-pep585~=0.1.7; extra == "dev"
54
54
  Requires-Dist: flake8~=7.3.0; extra == "dev"
@@ -56,7 +56,7 @@ def pixel_eps_to_normalized(
56
56
 
57
57
 
58
58
  def clamp_normalized(inputs: torch.Tensor, rgb_stats: RGBType) -> torch.Tensor:
59
- (min_val, max_val) = normalized_bounds(rgb_stats, device=inputs.device, dtype=inputs.dtype)
59
+ min_val, max_val = normalized_bounds(rgb_stats, device=inputs.device, dtype=inputs.dtype)
60
60
  return torch.clamp(inputs, min=min_val, max=max_val)
61
61
 
62
62
 
@@ -87,7 +87,7 @@ class SimBA:
87
87
  if self._is_successful(current_logits, label, target_label):
88
88
  return adv_inputs.detach(), num_queries
89
89
 
90
- (_, channels, height, width) = adv_inputs.shape
90
+ _, channels, height, width = adv_inputs.shape
91
91
  num_dims = channels * height * width
92
92
  step = pixel_eps_to_normalized(self.step_size, self.rgb_stats, device=adv_inputs.device, dtype=adv_inputs.dtype)
93
93
  step_vals = step.view(-1) # Per-channel steps
@@ -98,11 +98,11 @@ class SimBA:
98
98
 
99
99
  # Coordinate-wise search in random order
100
100
  for flat_idx in perm[:num_steps]:
101
- (c, rem) = divmod(int(flat_idx.item()), stride)
102
- (h, w) = divmod(rem, width)
101
+ c, rem = divmod(int(flat_idx.item()), stride)
102
+ h, w = divmod(rem, width)
103
103
  step_val = step_vals[c]
104
104
 
105
- (candidate_inputs, candidate_logits, candidate_objective) = self._best_candidate(
105
+ candidate_inputs, candidate_logits, candidate_objective = self._best_candidate(
106
106
  adv_inputs, c, h, w, step_val, label, target_label
107
107
  )
108
108
  num_queries += 2
@@ -49,7 +49,7 @@ class FlexibleDictAction(argparse.Action):
49
49
  new_dict = {}
50
50
  for pair in pairs:
51
51
  # Split each pair into key and value
52
- (key, value) = pair.split("=", 1)
52
+ key, value = pair.split("=", 1)
53
53
  key = key.strip()
54
54
 
55
55
  # Try to safely evaluate the value (handles ints and strings mostly)
@@ -384,7 +384,7 @@ def load_checkpoint(
384
384
  )
385
385
 
386
386
  # Initialize network and restore checkpoint state
387
- net = registry.net_factory(network, input_channels, num_classes, config=config, size=size)
387
+ net = registry.net_factory(network, num_classes, input_channels, config=config, size=size)
388
388
 
389
389
  # When a checkpoint was trained with EMA:
390
390
  # The primary weights in the checkpoint file are the EMA weights
@@ -437,7 +437,7 @@ def load_mim_checkpoint(
437
437
  size = lib.get_size_from_signature(signature)
438
438
 
439
439
  # Initialize network and restore checkpoint state
440
- net_encoder = registry.net_factory(encoder, input_channels, num_classes, config=encoder_config, size=size)
440
+ net_encoder = registry.net_factory(encoder, num_classes, input_channels, config=encoder_config, size=size)
441
441
  net = registry.mim_net_factory(
442
442
  network, net_encoder, config=config, size=size, mask_ratio=mask_ratio, min_mask_size=min_mask_size
443
443
  )
@@ -488,7 +488,7 @@ def load_detection_checkpoint(
488
488
  size = lib.get_size_from_signature(signature)
489
489
 
490
490
  # Initialize network and restore checkpoint state
491
- net_backbone = registry.net_factory(backbone, input_channels, num_classes, config=backbone_config, size=size)
491
+ net_backbone = registry.net_factory(backbone, num_classes, input_channels, config=backbone_config, size=size)
492
492
  net = registry.detection_net_factory(network, num_classes, net_backbone, config=config, size=size)
493
493
 
494
494
  # When a checkpoint was trained with EMA:
@@ -584,7 +584,7 @@ def load_model(
584
584
  merged_config = None # type: ignore[assignment]
585
585
 
586
586
  model_state: dict[str, Any] = safetensors.torch.load_file(path, device=device.type)
587
- net = registry.net_factory(network, input_channels, num_classes, config=merged_config, size=size)
587
+ net = registry.net_factory(network, num_classes, input_channels, config=merged_config, size=size)
588
588
  if reparameterized is True:
589
589
  net.reparameterize_model()
590
590
 
@@ -611,7 +611,7 @@ def load_model(
611
611
  if len(merged_config) == 0:
612
612
  merged_config = None
613
613
 
614
- net = registry.net_factory(network, input_channels, num_classes, config=merged_config, size=size)
614
+ net = registry.net_factory(network, num_classes, input_channels, config=merged_config, size=size)
615
615
  if reparameterized is True:
616
616
  net.reparameterize_model()
617
617
 
@@ -733,7 +733,7 @@ def load_detection_model(
733
733
 
734
734
  model_state: dict[str, Any] = safetensors.torch.load_file(path, device=device.type)
735
735
  net_backbone = registry.net_factory(
736
- backbone, input_channels, num_classes, config=backbone_merged_config, size=size
736
+ backbone, num_classes, input_channels, config=backbone_merged_config, size=size
737
737
  )
738
738
  if backbone_reparameterized is True:
739
739
  net_backbone.reparameterize_model()
@@ -776,7 +776,7 @@ def load_detection_model(
776
776
  merged_config = None
777
777
 
778
778
  net_backbone = registry.net_factory(
779
- backbone, input_channels, num_classes, config=backbone_merged_config, size=size
779
+ backbone, num_classes, input_channels, config=backbone_merged_config, size=size
780
780
  )
781
781
  if backbone_reparameterized is True:
782
782
  net_backbone.reparameterize_model()
@@ -959,7 +959,7 @@ def load_model_with_cfg(
959
959
  encoder_name = cfg["encoder"]
960
960
 
961
961
  encoder_config = cfg.get("encoder_config", None)
962
- encoder = registry.net_factory(encoder_name, input_channels, num_classes=0, config=encoder_config, size=size)
962
+ encoder = registry.net_factory(encoder_name, 0, input_channels, config=encoder_config, size=size)
963
963
  net = registry.mim_net_factory(name, encoder, config=model_config, size=size)
964
964
 
965
965
  elif cfg["task"] == Task.OBJECT_DETECTION:
@@ -969,14 +969,14 @@ def load_model_with_cfg(
969
969
  backbone_name = cfg["backbone"]
970
970
 
971
971
  backbone_config = cfg.get("backbone_config", None)
972
- backbone = registry.net_factory(backbone_name, input_channels, num_classes, config=backbone_config, size=size)
972
+ backbone = registry.net_factory(backbone_name, num_classes, input_channels, config=backbone_config, size=size)
973
973
  if cfg.get("backbone_reparameterized", False) is True:
974
974
  backbone.reparameterize_model()
975
975
 
976
976
  net = registry.detection_net_factory(name, num_classes, backbone, config=model_config, size=size)
977
977
 
978
978
  elif cfg["task"] == Task.IMAGE_CLASSIFICATION:
979
- net = registry.net_factory(name, input_channels, num_classes, config=model_config, size=size)
979
+ net = registry.net_factory(name, num_classes, input_channels, config=model_config, size=size)
980
980
 
981
981
  else:
982
982
  raise ValueError(f"Configuration not supported: {cfg['task']}")
@@ -1019,7 +1019,7 @@ def download_model_by_weights(
1019
1019
  f"Requested format '{file_format}' not available for {weights}, available formats are: {available_formats}"
1020
1020
  )
1021
1021
 
1022
- (model_file, url) = get_pretrained_model_url(weights, file_format)
1022
+ model_file, url = get_pretrained_model_url(weights, file_format)
1023
1023
  if dst is None:
1024
1024
  dst = settings.MODELS_DIR.joinpath(model_file)
1025
1025
 
@@ -157,6 +157,6 @@ def get_pretrained_model_url(weights: str, file_format: str) -> tuple[str, str]:
157
157
 
158
158
  def format_duration(seconds: float) -> str:
159
159
  s = int(seconds)
160
- (mm, ss) = divmod(s, 60)
161
- (hh, mm) = divmod(mm, 60)
160
+ mm, ss = divmod(s, 60)
161
+ hh, mm = divmod(mm, 60)
162
162
  return f"{hh:d}:{mm:02d}:{ss:02d}"
@@ -16,7 +16,7 @@ def _mask_token_omission(
16
16
  Parameters
17
17
  ----------
18
18
  x
19
- Tensor of shape (N, L, D), where N is the batch size, L is the sequence length, and D is the feature dimension.
19
+ Tensor of shape (N, L, D), where N is the batch size, L is the sequence length and D is the feature dimension.
20
20
  mask_ratio
21
21
  The ratio of the sequence length to be masked. This value should be between 0 and 1.
22
22
  kept_mask_ratio
@@ -48,7 +48,7 @@ def _mask_token_omission(
48
48
  # Masking: length -> length * mask_ratio
49
49
  # Perform per-sample random masking by per-sample shuffling.
50
50
  # Per-sample shuffling is done by argsort random noise.
51
- (N, L, D) = x.size() # batch, length, dim
51
+ N, L, D = x.size() # batch, length, dim
52
52
  len_keep = int(L * (1 - mask_ratio))
53
53
  len_masked = int(L * (mask_ratio - kept_mask_ratio))
54
54
 
@@ -82,7 +82,7 @@ def mask_tensor(
82
82
  if channels_last is False:
83
83
  x = x.permute(0, 2, 3, 1)
84
84
 
85
- (B, H, W, _) = x.size()
85
+ B, H, W, _ = x.size()
86
86
 
87
87
  shaped_mask = mask.reshape(B, H // patch_factor, W // patch_factor)
88
88
  shaped_mask = shaped_mask.repeat_interleave(patch_factor, dim=1).repeat_interleave(patch_factor, dim=2)
@@ -13,6 +13,7 @@ from birder.conf import settings
13
13
  from birder.data.datasets.coco import MosaicType
14
14
  from birder.data.transforms.classification import AugType
15
15
  from birder.data.transforms.classification import RGBMode
16
+ from birder.data.transforms.detection import MULTISCALE_STEP
16
17
  from birder.data.transforms.detection import AugType as DetAugType
17
18
 
18
19
  logger = logging.getLogger(__name__)
@@ -55,7 +56,9 @@ def add_optimization_args(parser: argparse.ArgumentParser, default_batch_size: i
55
56
  )
56
57
 
57
58
 
58
- def add_lr_wd_args(parser: argparse.ArgumentParser, backbone_lr: bool = False, wd_end: bool = False) -> None:
59
+ def add_lr_wd_args(
60
+ parser: argparse.ArgumentParser, backbone_lr: bool = False, wd_end: bool = False, backbone_layer_decay: bool = False
61
+ ) -> None:
59
62
  group = parser.add_argument_group("Learning rate and regularization parameters")
60
63
  group.add_argument("--lr", type=float, default=0.1, metavar="LR", help="base learning rate")
61
64
  group.add_argument("--bias-lr", type=float, metavar="LR", help="learning rate of biases")
@@ -91,6 +94,9 @@ def add_lr_wd_args(parser: argparse.ArgumentParser, backbone_lr: bool = False, w
91
94
  help="custom weight decay for specific layers by name (e.g., offset_conv=0.0)",
92
95
  )
93
96
  group.add_argument("--layer-decay", type=float, help="layer-wise learning rate decay (LLRD)")
97
+ if backbone_layer_decay is True:
98
+ group.add_argument("--backbone-layer-decay", type=float, help="backbone layer-wise learning rate decay (LLRD)")
99
+
94
100
  group.add_argument("--layer-decay-min-scale", type=float, help="minimum layer scale factor clamp value")
95
101
  group.add_argument(
96
102
  "--layer-decay-no-opt-scale", type=float, help="layer scale threshold below which parameters are frozen"
@@ -199,10 +205,16 @@ def add_detection_input_args(parser: argparse.ArgumentParser) -> None:
199
205
  action="store_true",
200
206
  help="enable random square resize once per batch (capped by max(--size))",
201
207
  )
208
+ group.add_argument(
209
+ "--multiscale-step",
210
+ type=int,
211
+ default=MULTISCALE_STEP,
212
+ help="step size for multiscale size lists and collator padding divisibility (size_divisible)",
213
+ )
202
214
  group.add_argument(
203
215
  "--multiscale-min-size",
204
216
  type=int,
205
- help="minimum short-edge size for multiscale lists (rounded up to nearest multiple of 32)",
217
+ help="minimum short-edge size for multiscale lists (rounded up to nearest multiple of --multiscale-step)",
206
218
  )
207
219
 
208
220
 
@@ -515,7 +527,10 @@ def add_distributed_args(parser: argparse.ArgumentParser) -> None:
515
527
 
516
528
 
517
529
  def add_logging_and_debug_args(
518
- parser: argparse.ArgumentParser, default_log_interval: int = 50, fake_data: bool = True
530
+ parser: argparse.ArgumentParser,
531
+ default_log_interval: int = 50,
532
+ fake_data: bool = True,
533
+ classification: bool = False,
519
534
  ) -> None:
520
535
  group = parser.add_argument_group("Logging and debugging parameters")
521
536
  group.add_argument(
@@ -525,6 +540,11 @@ def add_logging_and_debug_args(
525
540
  metavar="NAME",
526
541
  help="experiment name for logging (creates dedicated directory for the run)",
527
542
  )
543
+ if classification is True:
544
+ group.add_argument(
545
+ "--top-k", type=int, metavar="K", help="additional top-k accuracy value to track (top-1 is always tracked)"
546
+ )
547
+
528
548
  group.add_argument(
529
549
  "--log-interval",
530
550
  type=int,
@@ -746,3 +766,10 @@ def common_args_validation(args: argparse.Namespace) -> None:
746
766
  # Precision_args, shared by all scripts
747
767
  if args.amp is True and args.model_dtype != "float32":
748
768
  raise ValidationError("--amp can only be used with --model-dtype float32")
769
+
770
+ if hasattr(args, "top_k") is True and args.top_k is not None:
771
+ if args.top_k == 1:
772
+ raise ValidationError("Top-1 accuracy is tracked by default, please remove 1 from --top-k argument")
773
+
774
+ if args.top_k <= 0:
775
+ raise ValidationError("--top-k value must be a positive integer")
@@ -11,6 +11,7 @@ from collections import deque
11
11
  from collections.abc import Callable
12
12
  from collections.abc import Generator
13
13
  from collections.abc import Iterator
14
+ from collections.abc import Sequence
14
15
  from datetime import datetime
15
16
  from pathlib import Path
16
17
  from typing import Any
@@ -342,7 +343,7 @@ def count_layers(model: torch.nn.Module) -> int:
342
343
  return num_layers
343
344
 
344
345
 
345
- # pylint: disable=protected-access,too-many-locals,too-many-branches
346
+ # pylint: disable=protected-access,too-many-locals,too-many-branches,too-many-statements
346
347
  def optimizer_parameter_groups(
347
348
  model: torch.nn.Module,
348
349
  weight_decay: float,
@@ -351,6 +352,7 @@ def optimizer_parameter_groups(
351
352
  custom_keys_weight_decay: Optional[list[tuple[str, float]]] = None,
352
353
  custom_layer_weight_decay: Optional[dict[str, float]] = None,
353
354
  layer_decay: Optional[float] = None,
355
+ backbone_layer_decay: Optional[float] = None,
354
356
  layer_decay_min_scale: Optional[float] = None,
355
357
  layer_decay_no_opt_scale: Optional[float] = None,
356
358
  bias_lr: Optional[float] = None,
@@ -361,7 +363,7 @@ def optimizer_parameter_groups(
361
363
  Return parameter groups for optimizers with per-parameter group weight decay.
362
364
 
363
365
  This function creates parameter groups with customizable weight decay, layer-wise
364
- learning rate scaling, and special handling for different parameter types. It supports
366
+ learning rate scaling and special handling for different parameter types. It supports
365
367
  advanced optimization techniques like layer decay and custom weight decay rules.
366
368
 
367
369
  Referenced from https://github.com/pytorch/vision/blob/main/references/classification/utils.py and from
@@ -387,6 +389,8 @@ def optimizer_parameter_groups(
387
389
  Applied to parameters whose names contain the specified keys.
388
390
  layer_decay
389
391
  Layer-wise learning rate decay factor.
392
+ backbone_layer_decay
393
+ Layer-wise learning rate decay factor for backbone parameters only.
390
394
  layer_decay_min_scale
391
395
  Minimum learning rate scale factor when using layer decay. Prevents layers from having too small learning rates.
392
396
  layer_decay_no_opt_scale
@@ -433,6 +437,27 @@ def optimizer_parameter_groups(
433
437
  if layer_decay is not None:
434
438
  logger.warning("Assigning lr scaling (layer decay) without a block group map")
435
439
 
440
+ backbone_group_map: dict[str, int] = {}
441
+ backbone_num_layers = 0
442
+ if backbone_layer_decay is not None:
443
+ backbone_module = getattr(model, "backbone", None)
444
+ if backbone_module is None:
445
+ logger.warning("Backbone layer decay requested but model has no backbone")
446
+ backbone_layer_decay = None
447
+ else:
448
+ backbone_block_group_regex = getattr(backbone_module, "block_group_regex", None)
449
+ if backbone_block_group_regex is not None:
450
+ names = [n for n, _ in backbone_module.named_parameters()]
451
+ groups = group_by_regex(names, backbone_block_group_regex)
452
+ backbone_group_map = {
453
+ f"backbone.{item}": index for index, sublist in enumerate(groups) for item in sublist
454
+ }
455
+ backbone_num_layers = len(groups)
456
+ else:
457
+ backbone_group_map = {}
458
+ backbone_num_layers = count_layers(backbone_module)
459
+ logger.warning("Assigning lr scaling (backbone layer decay) without a block group map")
460
+
436
461
  # Build layer scale
437
462
  if layer_decay_min_scale is None:
438
463
  layer_decay_min_scale = 0.0
@@ -443,14 +468,28 @@ def optimizer_parameter_groups(
443
468
  layer_scales = [max(layer_decay_min_scale, layer_decay ** (layer_max - i)) for i in range(num_layers)]
444
469
  logger.info(f"Layer scaling ranges from {min(layer_scales)} to {max(layer_scales)} across {num_layers} layers")
445
470
 
471
+ backbone_layer_scales = []
472
+ if backbone_layer_decay is not None:
473
+ backbone_layer_max = backbone_num_layers - 1
474
+ backbone_layer_scales = [
475
+ max(layer_decay_min_scale, backbone_layer_decay ** (backbone_layer_max - i))
476
+ for i in range(backbone_num_layers)
477
+ ]
478
+ logger.info(
479
+ "Backbone layer scaling ranges from "
480
+ f"{min(backbone_layer_scales)} to {max(backbone_layer_scales)} across {backbone_num_layers} layers"
481
+ )
482
+
446
483
  # Set weight decay and layer decay
447
484
  idx = 0
485
+ backbone_idx = 0
448
486
  params = []
449
487
  module_stack_with_prefix = [(model, "")]
450
488
  visited_modules = []
451
489
  while len(module_stack_with_prefix) > 0: # pylint: disable=too-many-nested-blocks
452
490
  skip_module = False
453
- (module, prefix) = module_stack_with_prefix.pop()
491
+ module, prefix = module_stack_with_prefix.pop()
492
+ is_backbone_module = prefix == "backbone" or prefix.startswith("backbone.")
454
493
  if id(module) in visited_modules:
455
494
  skip_module = True
456
495
 
@@ -459,23 +498,35 @@ def optimizer_parameter_groups(
459
498
  for name, p in module.named_parameters(recurse=False):
460
499
  target_name = f"{prefix}.{name}" if prefix != "" else name
461
500
  idx = group_map.get(target_name, idx)
501
+ is_backbone_param = target_name.startswith("backbone.")
502
+ if backbone_layer_decay is not None and is_backbone_param is True:
503
+ backbone_idx = backbone_group_map.get(target_name, backbone_idx)
462
504
  if skip_module is True:
463
505
  break
464
506
 
465
507
  parameters_found = True
466
508
  if p.requires_grad is False:
467
509
  continue
468
- if layer_decay is not None and layer_decay_no_opt_scale is not None:
469
- if layer_scales[idx] < layer_decay_no_opt_scale:
470
- p.requires_grad_(False)
510
+ if layer_decay_no_opt_scale is not None:
511
+ if backbone_layer_decay is not None and is_backbone_param is True:
512
+ if backbone_layer_scales and backbone_layer_scales[backbone_idx] < layer_decay_no_opt_scale:
513
+ p.requires_grad_(False)
514
+ elif layer_decay is not None:
515
+ if layer_scales[idx] < layer_decay_no_opt_scale:
516
+ p.requires_grad_(False)
471
517
 
472
518
  is_custom_key = False
473
519
  if custom_keys_weight_decay is not None:
474
520
  for key, custom_wd in custom_keys_weight_decay:
475
521
  target_name_for_custom_key = f"{prefix}.{name}" if prefix != "" and "." in key else name
476
522
  if key == target_name_for_custom_key:
477
- # Calculate lr_scale (from layer_decay or custom_layer_lr_scale)
478
- lr_scale = 1.0 if layer_decay is None else layer_scales[idx]
523
+ # Calculate lr_scale (from layer_decay/backbone_layer_decay or custom_layer_lr_scale)
524
+ if layer_decay is not None and (backbone_layer_decay is None or is_backbone_param is False):
525
+ lr_scale = layer_scales[idx]
526
+ elif backbone_layer_decay is not None and is_backbone_param is True:
527
+ lr_scale = backbone_layer_scales[backbone_idx]
528
+ else:
529
+ lr_scale = 1.0
479
530
  if custom_layer_lr_scale is not None:
480
531
  for layer_name_key, custom_scale in custom_layer_lr_scale.items():
481
532
  if layer_name_key in target_name:
@@ -499,8 +550,8 @@ def optimizer_parameter_groups(
499
550
  # Apply learning rate based on priority: bias_lr > backbone_lr > lr_scale
500
551
  if bias_lr is not None and target_name.endswith(".bias") is True:
501
552
  d["lr"] = bias_lr
502
- elif backbone_lr is not None and target_name.startswith("backbone.") is True:
503
- d["lr"] = backbone_lr
553
+ elif backbone_lr is not None and is_backbone_param is True:
554
+ d["lr"] = backbone_lr * lr_scale if backbone_layer_decay is not None else backbone_lr
504
555
  elif lr_scale != 1.0:
505
556
  d["lr"] = base_lr * lr_scale
506
557
 
@@ -521,8 +572,13 @@ def optimizer_parameter_groups(
521
572
  wd = custom_wd_value
522
573
  break
523
574
 
524
- # Calculate lr_scale (from layer_decay or custom_layer_lr_scale)
525
- lr_scale = 1.0 if layer_decay is None else layer_scales[idx]
575
+ # Calculate lr_scale (from layer_decay/backbone_layer_decay or custom_layer_lr_scale)
576
+ if layer_decay is not None and (backbone_layer_decay is None or is_backbone_param is False):
577
+ lr_scale = layer_scales[idx]
578
+ elif backbone_layer_decay is not None and is_backbone_param is True:
579
+ lr_scale = backbone_layer_scales[backbone_idx]
580
+ else:
581
+ lr_scale = 1.0
526
582
  if custom_layer_lr_scale is not None:
527
583
  for layer_name_key, custom_scale in custom_layer_lr_scale.items():
528
584
  if layer_name_key in target_name:
@@ -538,8 +594,8 @@ def optimizer_parameter_groups(
538
594
  # Apply learning rate based on priority: bias_lr > backbone_lr > lr_scale
539
595
  if bias_lr is not None and target_name.endswith(".bias") is True:
540
596
  d["lr"] = bias_lr
541
- elif backbone_lr is not None and target_name.startswith("backbone.") is True:
542
- d["lr"] = backbone_lr
597
+ elif backbone_lr is not None and is_backbone_param is True:
598
+ d["lr"] = backbone_lr * lr_scale if backbone_layer_decay is not None else backbone_lr
543
599
  elif lr_scale != 1.0:
544
600
  d["lr"] = base_lr * lr_scale
545
601
 
@@ -547,6 +603,8 @@ def optimizer_parameter_groups(
547
603
 
548
604
  if parameters_found is True:
549
605
  idx += 1
606
+ if is_backbone_module is True:
607
+ backbone_idx += 1
550
608
 
551
609
  for child_name, child_module in reversed(list(module.named_children())):
552
610
  child_prefix = f"{prefix}.{child_name}" if prefix != "" else child_name
@@ -884,6 +942,11 @@ class SmoothedValue:
884
942
  self.total: torch.Tensor | float = 0.0
885
943
  self.count: int = 0
886
944
 
945
+ def clear(self) -> None:
946
+ self.deque.clear()
947
+ self.total = 0.0
948
+ self.count = 0
949
+
887
950
  def update(self, value: torch.Tensor | float, n: int = 1) -> None:
888
951
  self.deque.append(value)
889
952
  self.count += n
@@ -927,14 +990,32 @@ class SmoothedValue:
927
990
  return to_tensor(v, torch.device("cpu")).item() # type: ignore[no-any-return]
928
991
 
929
992
 
930
- def accuracy(y_true: torch.Tensor, y_pred: torch.Tensor) -> float:
993
+ @torch.no_grad() # type: ignore[untyped-decorator]
994
+ def accuracy(y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
931
995
  if y_pred.dim() > 1 and y_pred.size(1) > 1:
932
996
  y_pred = y_pred.argmax(dim=1)
933
997
 
934
998
  y_true = y_true.flatten()
935
999
  y_pred = y_pred.flatten()
936
1000
 
937
- return (y_true == y_pred).float().mean().item() # type: ignore[no-any-return]
1001
+ return (y_true == y_pred).sum() / y_true.numel()
1002
+
1003
+
1004
+ @torch.no_grad() # type: ignore[untyped-decorator]
1005
+ def topk_accuracy(y_true: torch.Tensor, y_pred: torch.Tensor, topk: Sequence[int]) -> list[torch.Tensor]:
1006
+ maxk = min(max(topk), y_pred.size(1))
1007
+ batch_size = y_true.size(0)
1008
+
1009
+ _, pred = y_pred.topk(maxk, dim=1, largest=True, sorted=True)
1010
+ correct = pred.eq(y_true.unsqueeze(1))
1011
+
1012
+ res: list[torch.Tensor] = []
1013
+ for k in topk:
1014
+ k = min(k, maxk)
1015
+ correct_k = correct[:, :k].any(dim=1).sum(dtype=torch.float32)
1016
+ res.append((correct_k / batch_size))
1017
+
1018
+ return res
938
1019
 
939
1020
 
940
1021
  ###############################################################################
@@ -70,13 +70,21 @@ class BatchRandomResizeCollator(DetectionCollator):
70
70
  size: tuple[int, int],
71
71
  size_divisible: int = 32,
72
72
  multiscale_min_size: Optional[int] = None,
73
+ multiscale_step: Optional[int] = None,
73
74
  ) -> None:
74
75
  super().__init__(input_offset, size_divisible=size_divisible)
75
76
  if size is None:
76
77
  raise ValueError("size must be provided for batch multiscale")
77
78
 
78
79
  max_side = max(size)
79
- sizes = [side for side in build_multiscale_sizes(multiscale_min_size) if side <= max_side]
80
+ if multiscale_step is None:
81
+ multiscale_step = size_divisible
82
+
83
+ sizes = []
84
+ for side in build_multiscale_sizes(multiscale_min_size, multiscale_step=multiscale_step):
85
+ if side <= max_side:
86
+ sizes.append(side)
87
+
80
88
  if len(sizes) == 0:
81
89
  sizes = [max_side]
82
90
 
@@ -17,17 +17,20 @@ DEFAULT_MULTISCALE_MAX_SIZE = 800
17
17
 
18
18
 
19
19
  def build_multiscale_sizes(
20
- min_size: Optional[int] = None, max_size: int = DEFAULT_MULTISCALE_MAX_SIZE
20
+ min_size: Optional[int] = None, max_size: int = DEFAULT_MULTISCALE_MAX_SIZE, multiscale_step: int = MULTISCALE_STEP
21
21
  ) -> tuple[int, ...]:
22
+ if multiscale_step <= 0:
23
+ raise ValueError("multiscale_step must be positive")
24
+
22
25
  if min_size is None:
23
26
  min_size = DEFAULT_MULTISCALE_MIN_SIZE
24
27
 
25
- start = int(math.ceil(min_size / MULTISCALE_STEP) * MULTISCALE_STEP)
26
- end = int(math.floor(max_size / MULTISCALE_STEP) * MULTISCALE_STEP)
28
+ start = int(math.ceil(min_size / multiscale_step) * multiscale_step)
29
+ end = int(math.floor(max_size / multiscale_step) * multiscale_step)
27
30
  if end < start:
28
31
  return (start,)
29
32
 
30
- return tuple(range(start, end + 1, MULTISCALE_STEP))
33
+ return tuple(range(start, end + 1, multiscale_step))
31
34
 
32
35
 
33
36
  class ResizeWithRandomInterpolation(nn.Module):
@@ -59,6 +62,7 @@ def get_birder_augment(
59
62
  multiscale: bool,
60
63
  max_size: Optional[int],
61
64
  multiscale_min_size: Optional[int],
65
+ multiscale_step: int = MULTISCALE_STEP,
62
66
  post_mosaic: bool = False,
63
67
  ) -> Callable[..., torch.Tensor]:
64
68
  if dynamic_size is True:
@@ -98,7 +102,10 @@ def get_birder_augment(
98
102
  # Resize
99
103
  if multiscale is True:
100
104
  transformations.append(
101
- v2.RandomShortestSize(min_size=build_multiscale_sizes(multiscale_min_size), max_size=max_size or 1333),
105
+ v2.RandomShortestSize(
106
+ min_size=build_multiscale_sizes(multiscale_min_size, multiscale_step=multiscale_step),
107
+ max_size=max_size or 1333,
108
+ ),
102
109
  )
103
110
  else:
104
111
  transformations.append(
@@ -160,6 +167,7 @@ def training_preset(
160
167
  multiscale: bool = False,
161
168
  max_size: Optional[int] = None,
162
169
  multiscale_min_size: Optional[int] = None,
170
+ multiscale_step: int = MULTISCALE_STEP,
163
171
  post_mosaic: bool = False,
164
172
  ) -> Callable[..., torch.Tensor]:
165
173
  mean = rgv_values["mean"]
@@ -180,7 +188,15 @@ def training_preset(
180
188
  [
181
189
  v2.ToImage(),
182
190
  get_birder_augment(
183
- size, level, fill_value, dynamic_size, multiscale, max_size, multiscale_min_size, post_mosaic
191
+ size,
192
+ level,
193
+ fill_value,
194
+ dynamic_size,
195
+ multiscale,
196
+ max_size,
197
+ multiscale_min_size,
198
+ multiscale_step,
199
+ post_mosaic,
184
200
  ),
185
201
  v2.ToDtype(torch.float32, scale=True),
186
202
  v2.Normalize(mean=mean, std=std),
@@ -212,7 +228,10 @@ def training_preset(
212
228
  return v2.Compose( # type: ignore
213
229
  [
214
230
  v2.ToImage(),
215
- v2.RandomShortestSize(min_size=build_multiscale_sizes(multiscale_min_size), max_size=max_size or 1333),
231
+ v2.RandomShortestSize(
232
+ min_size=build_multiscale_sizes(multiscale_min_size, multiscale_step=multiscale_step),
233
+ max_size=max_size or 1333,
234
+ ),
216
235
  v2.RandomHorizontalFlip(0.5),
217
236
  v2.SanitizeBoundingBoxes(),
218
237
  v2.ToDtype(torch.float32, scale=True),
@@ -284,7 +303,7 @@ def training_preset(
284
303
  )
285
304
 
286
305
  if aug_type == "detr":
287
- multiscale_sizes = build_multiscale_sizes(multiscale_min_size)
306
+ multiscale_sizes = build_multiscale_sizes(multiscale_min_size, multiscale_step=multiscale_step)
288
307
  return v2.Compose( # type: ignore
289
308
  [
290
309
  v2.ToImage(),
@@ -19,7 +19,7 @@ def mosaic_random_center(
19
19
  Create a mosaic augmentation by combining 4 images into a single image.
20
20
 
21
21
  This augmentation places 4 images on a canvas, meeting at a randomly selected
22
- center point. Each image is scaled to fit, cropped as needed, and their bounding
22
+ center point. Each image is scaled to fit, cropped as needed and their bounding
23
23
  boxes are transformed accordingly.
24
24
 
25
25
  Parameters