birder 0.4.1__tar.gz → 0.4.4__tar.gz

This diff compares the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (354)
  1. {birder-0.4.1 → birder-0.4.4}/PKG-INFO +17 -10
  2. {birder-0.4.1 → birder-0.4.4}/README.md +12 -5
  3. {birder-0.4.1 → birder-0.4.4}/birder/__init__.py +2 -0
  4. {birder-0.4.1 → birder-0.4.4}/birder/common/fs_ops.py +81 -1
  5. {birder-0.4.1 → birder-0.4.4}/birder/common/training_cli.py +12 -2
  6. {birder-0.4.1 → birder-0.4.4}/birder/common/training_utils.py +73 -12
  7. {birder-0.4.1 → birder-0.4.4}/birder/data/collators/detection.py +3 -1
  8. {birder-0.4.1 → birder-0.4.4}/birder/datahub/_lib.py +15 -6
  9. birder-0.4.4/birder/datahub/evaluation.py +591 -0
  10. birder-0.4.4/birder/eval/__main__.py +74 -0
  11. birder-0.4.4/birder/eval/_embeddings.py +50 -0
  12. birder-0.4.4/birder/eval/adversarial.py +315 -0
  13. birder-0.4.4/birder/eval/benchmarks/awa2.py +357 -0
  14. birder-0.4.4/birder/eval/benchmarks/bioscan5m.py +198 -0
  15. birder-0.4.4/birder/eval/benchmarks/fishnet.py +318 -0
  16. birder-0.4.4/birder/eval/benchmarks/flowers102.py +210 -0
  17. birder-0.4.4/birder/eval/benchmarks/fungiclef.py +261 -0
  18. birder-0.4.4/birder/eval/benchmarks/nabirds.py +202 -0
  19. birder-0.4.4/birder/eval/benchmarks/newt.py +262 -0
  20. birder-0.4.4/birder/eval/benchmarks/plankton.py +255 -0
  21. birder-0.4.4/birder/eval/benchmarks/plantdoc.py +259 -0
  22. birder-0.4.4/birder/eval/benchmarks/plantnet.py +252 -0
  23. birder-0.4.4/birder/eval/classification.py +235 -0
  24. birder-0.4.4/birder/eval/methods/ami.py +78 -0
  25. birder-0.4.4/birder/eval/methods/knn.py +71 -0
  26. birder-0.4.4/birder/eval/methods/linear.py +152 -0
  27. birder-0.4.4/birder/eval/methods/mlp.py +178 -0
  28. birder-0.4.4/birder/eval/methods/simpleshot.py +100 -0
  29. birder-0.4.4/birder/eval/methods/svm.py +92 -0
  30. {birder-0.4.1 → birder-0.4.4}/birder/inference/classification.py +23 -2
  31. {birder-0.4.1 → birder-0.4.4}/birder/inference/detection.py +35 -15
  32. {birder-0.4.1 → birder-0.4.4}/birder/net/_vit_configs.py +5 -0
  33. {birder-0.4.1 → birder-0.4.4}/birder/net/cait.py +3 -3
  34. {birder-0.4.1 → birder-0.4.4}/birder/net/coat.py +3 -3
  35. {birder-0.4.1 → birder-0.4.4}/birder/net/cswin_transformer.py +2 -1
  36. {birder-0.4.1 → birder-0.4.4}/birder/net/deit.py +1 -1
  37. {birder-0.4.1 → birder-0.4.4}/birder/net/deit3.py +1 -1
  38. {birder-0.4.1 → birder-0.4.4}/birder/net/detection/__init__.py +2 -0
  39. {birder-0.4.1 → birder-0.4.4}/birder/net/detection/base.py +41 -18
  40. {birder-0.4.1 → birder-0.4.4}/birder/net/detection/deformable_detr.py +74 -50
  41. {birder-0.4.1 → birder-0.4.4}/birder/net/detection/detr.py +29 -26
  42. {birder-0.4.1 → birder-0.4.4}/birder/net/detection/efficientdet.py +42 -25
  43. {birder-0.4.1 → birder-0.4.4}/birder/net/detection/faster_rcnn.py +53 -21
  44. {birder-0.4.1 → birder-0.4.4}/birder/net/detection/fcos.py +42 -23
  45. birder-0.4.4/birder/net/detection/lw_detr.py +1204 -0
  46. {birder-0.4.1 → birder-0.4.4}/birder/net/detection/plain_detr.py +60 -47
  47. {birder-0.4.1 → birder-0.4.4}/birder/net/detection/retinanet.py +47 -35
  48. {birder-0.4.1 → birder-0.4.4}/birder/net/detection/rt_detr_v1.py +49 -46
  49. {birder-0.4.1 → birder-0.4.4}/birder/net/detection/rt_detr_v2.py +95 -102
  50. {birder-0.4.1 → birder-0.4.4}/birder/net/detection/ssd.py +47 -31
  51. {birder-0.4.1 → birder-0.4.4}/birder/net/detection/ssdlite.py +2 -2
  52. {birder-0.4.1 → birder-0.4.4}/birder/net/detection/yolo_v2.py +33 -18
  53. {birder-0.4.1 → birder-0.4.4}/birder/net/detection/yolo_v3.py +35 -33
  54. {birder-0.4.1 → birder-0.4.4}/birder/net/detection/yolo_v4.py +35 -20
  55. {birder-0.4.1 → birder-0.4.4}/birder/net/detection/yolo_v4_tiny.py +1 -2
  56. {birder-0.4.1 → birder-0.4.4}/birder/net/edgevit.py +3 -3
  57. {birder-0.4.1 → birder-0.4.4}/birder/net/efficientvit_msft.py +1 -1
  58. {birder-0.4.1 → birder-0.4.4}/birder/net/flexivit.py +1 -1
  59. {birder-0.4.1 → birder-0.4.4}/birder/net/hiera.py +44 -67
  60. {birder-0.4.1 → birder-0.4.4}/birder/net/hieradet.py +2 -2
  61. {birder-0.4.1 → birder-0.4.4}/birder/net/maxvit.py +2 -2
  62. {birder-0.4.1 → birder-0.4.4}/birder/net/mim/fcmae.py +2 -2
  63. {birder-0.4.1 → birder-0.4.4}/birder/net/mim/mae_hiera.py +9 -16
  64. {birder-0.4.1 → birder-0.4.4}/birder/net/mnasnet.py +2 -2
  65. {birder-0.4.1 → birder-0.4.4}/birder/net/nextvit.py +4 -4
  66. {birder-0.4.1 → birder-0.4.4}/birder/net/resnext.py +2 -2
  67. {birder-0.4.1 → birder-0.4.4}/birder/net/rope_deit3.py +2 -2
  68. {birder-0.4.1 → birder-0.4.4}/birder/net/rope_flexivit.py +2 -2
  69. {birder-0.4.1 → birder-0.4.4}/birder/net/rope_vit.py +2 -2
  70. {birder-0.4.1 → birder-0.4.4}/birder/net/simple_vit.py +1 -1
  71. {birder-0.4.1 → birder-0.4.4}/birder/net/squeezenet.py +1 -1
  72. {birder-0.4.1 → birder-0.4.4}/birder/net/ssl/capi.py +32 -25
  73. {birder-0.4.1 → birder-0.4.4}/birder/net/ssl/dino_v2.py +12 -15
  74. {birder-0.4.1 → birder-0.4.4}/birder/net/ssl/franca.py +26 -19
  75. {birder-0.4.1 → birder-0.4.4}/birder/net/van.py +2 -2
  76. {birder-0.4.1 → birder-0.4.4}/birder/net/vit.py +21 -3
  77. {birder-0.4.1 → birder-0.4.4}/birder/net/vit_parallel.py +1 -1
  78. {birder-0.4.1 → birder-0.4.4}/birder/net/vit_sam.py +62 -16
  79. {birder-0.4.1 → birder-0.4.4}/birder/net/xcit.py +1 -1
  80. {birder-0.4.1 → birder-0.4.4}/birder/ops/msda.py +46 -16
  81. birder-0.4.4/birder/results/__init__.py +0 -0
  82. birder-0.4.4/birder/scripts/__init__.py +0 -0
  83. {birder-0.4.1 → birder-0.4.4}/birder/scripts/benchmark.py +35 -8
  84. {birder-0.4.1 → birder-0.4.4}/birder/scripts/predict.py +14 -1
  85. {birder-0.4.1 → birder-0.4.4}/birder/scripts/predict_detection.py +7 -1
  86. {birder-0.4.1 → birder-0.4.4}/birder/scripts/train.py +27 -11
  87. {birder-0.4.1 → birder-0.4.4}/birder/scripts/train_capi.py +13 -10
  88. {birder-0.4.1 → birder-0.4.4}/birder/scripts/train_detection.py +18 -7
  89. {birder-0.4.1 → birder-0.4.4}/birder/scripts/train_franca.py +10 -2
  90. {birder-0.4.1 → birder-0.4.4}/birder/scripts/train_kd.py +28 -11
  91. birder-0.4.4/birder/tools/__init__.py +0 -0
  92. {birder-0.4.1 → birder-0.4.4}/birder/tools/adversarial.py +5 -0
  93. {birder-0.4.1 → birder-0.4.4}/birder/tools/convert_model.py +101 -43
  94. {birder-0.4.1 → birder-0.4.4}/birder/tools/quantize_model.py +33 -16
  95. birder-0.4.4/birder/version.py +1 -0
  96. {birder-0.4.1 → birder-0.4.4}/birder.egg-info/PKG-INFO +17 -10
  97. {birder-0.4.1 → birder-0.4.4}/birder.egg-info/SOURCES.txt +26 -1
  98. {birder-0.4.1 → birder-0.4.4}/birder.egg-info/requires.txt +2 -2
  99. {birder-0.4.1 → birder-0.4.4}/pyproject.toml +4 -1
  100. {birder-0.4.1 → birder-0.4.4}/requirements/_requirements-dev.txt +1 -1
  101. {birder-0.4.1 → birder-0.4.4}/requirements/requirements.txt +1 -1
  102. {birder-0.4.1 → birder-0.4.4}/tests/test_common.py +22 -1
  103. birder-0.4.4/tests/test_eval.py +111 -0
  104. {birder-0.4.1 → birder-0.4.4}/tests/test_kernels.py +16 -11
  105. {birder-0.4.1 → birder-0.4.4}/tests/test_net.py +43 -12
  106. {birder-0.4.1 → birder-0.4.4}/tests/test_net_detection.py +50 -25
  107. {birder-0.4.1 → birder-0.4.4}/tests/test_net_ssl.py +72 -14
  108. birder-0.4.1/birder/scripts/evaluate.py +0 -176
  109. birder-0.4.1/birder/version.py +0 -1
  110. {birder-0.4.1 → birder-0.4.4}/LICENSE +0 -0
  111. {birder-0.4.1 → birder-0.4.4}/birder/adversarial/__init__.py +0 -0
  112. {birder-0.4.1 → birder-0.4.4}/birder/adversarial/base.py +0 -0
  113. {birder-0.4.1 → birder-0.4.4}/birder/adversarial/deepfool.py +0 -0
  114. {birder-0.4.1 → birder-0.4.4}/birder/adversarial/fgsm.py +0 -0
  115. {birder-0.4.1 → birder-0.4.4}/birder/adversarial/pgd.py +0 -0
  116. {birder-0.4.1 → birder-0.4.4}/birder/adversarial/simba.py +0 -0
  117. {birder-0.4.1 → birder-0.4.4}/birder/common/__init__.py +0 -0
  118. {birder-0.4.1 → birder-0.4.4}/birder/common/cli.py +0 -0
  119. {birder-0.4.1 → birder-0.4.4}/birder/common/lib.py +0 -0
  120. {birder-0.4.1 → birder-0.4.4}/birder/common/masking.py +0 -0
  121. {birder-0.4.1 → birder-0.4.4}/birder/conf/__init__.py +0 -0
  122. {birder-0.4.1 → birder-0.4.4}/birder/conf/settings.py +0 -0
  123. {birder-0.4.1 → birder-0.4.4}/birder/data/__init__.py +0 -0
  124. {birder-0.4.1 → birder-0.4.4}/birder/data/collators/__init__.py +0 -0
  125. {birder-0.4.1 → birder-0.4.4}/birder/data/dataloader/__init__.py +0 -0
  126. {birder-0.4.1 → birder-0.4.4}/birder/data/dataloader/webdataset.py +0 -0
  127. {birder-0.4.1 → birder-0.4.4}/birder/data/datasets/__init__.py +0 -0
  128. {birder-0.4.1 → birder-0.4.4}/birder/data/datasets/coco.py +0 -0
  129. {birder-0.4.1 → birder-0.4.4}/birder/data/datasets/directory.py +0 -0
  130. {birder-0.4.1 → birder-0.4.4}/birder/data/datasets/fake.py +0 -0
  131. {birder-0.4.1 → birder-0.4.4}/birder/data/datasets/webdataset.py +0 -0
  132. {birder-0.4.1 → birder-0.4.4}/birder/data/transforms/__init__.py +0 -0
  133. {birder-0.4.1 → birder-0.4.4}/birder/data/transforms/classification.py +0 -0
  134. {birder-0.4.1 → birder-0.4.4}/birder/data/transforms/detection.py +0 -0
  135. {birder-0.4.1 → birder-0.4.4}/birder/data/transforms/mosaic.py +0 -0
  136. {birder-0.4.1 → birder-0.4.4}/birder/datahub/__init__.py +0 -0
  137. {birder-0.4.1 → birder-0.4.4}/birder/datahub/classification.py +0 -0
  138. {birder-0.4.1/birder/inference → birder-0.4.4/birder/eval}/__init__.py +0 -0
  139. {birder-0.4.1/birder/kernels → birder-0.4.4/birder/eval/benchmarks}/__init__.py +0 -0
  140. {birder-0.4.1/birder/ops → birder-0.4.4/birder/eval/methods}/__init__.py +0 -0
  141. {birder-0.4.1/birder/results → birder-0.4.4/birder/inference}/__init__.py +0 -0
  142. {birder-0.4.1 → birder-0.4.4}/birder/inference/data_parallel.py +0 -0
  143. {birder-0.4.1 → birder-0.4.4}/birder/inference/wbf.py +0 -0
  144. {birder-0.4.1 → birder-0.4.4}/birder/introspection/__init__.py +0 -0
  145. {birder-0.4.1 → birder-0.4.4}/birder/introspection/attention_rollout.py +0 -0
  146. {birder-0.4.1 → birder-0.4.4}/birder/introspection/base.py +0 -0
  147. {birder-0.4.1 → birder-0.4.4}/birder/introspection/feature_pca.py +0 -0
  148. {birder-0.4.1 → birder-0.4.4}/birder/introspection/gradcam.py +0 -0
  149. {birder-0.4.1 → birder-0.4.4}/birder/introspection/guided_backprop.py +0 -0
  150. {birder-0.4.1 → birder-0.4.4}/birder/introspection/transformer_attribution.py +0 -0
  151. {birder-0.4.1/birder/scripts → birder-0.4.4/birder/kernels}/__init__.py +0 -0
  152. {birder-0.4.1 → birder-0.4.4}/birder/kernels/deformable_detr/cpu/ms_deform_attn_cpu.cpp +0 -0
  153. {birder-0.4.1 → birder-0.4.4}/birder/kernels/deformable_detr/cpu/ms_deform_attn_cpu.h +0 -0
  154. {birder-0.4.1 → birder-0.4.4}/birder/kernels/deformable_detr/cuda/ms_deform_attn_cuda.cu +0 -0
  155. {birder-0.4.1 → birder-0.4.4}/birder/kernels/deformable_detr/cuda/ms_deform_attn_cuda.h +0 -0
  156. {birder-0.4.1 → birder-0.4.4}/birder/kernels/deformable_detr/cuda/ms_deform_im2col_cuda.cuh +0 -0
  157. {birder-0.4.1 → birder-0.4.4}/birder/kernels/deformable_detr/ms_deform_attn.h +0 -0
  158. {birder-0.4.1 → birder-0.4.4}/birder/kernels/deformable_detr/vision.cpp +0 -0
  159. {birder-0.4.1 → birder-0.4.4}/birder/kernels/load_kernel.py +0 -0
  160. {birder-0.4.1 → birder-0.4.4}/birder/kernels/soft_nms/op.cpp +0 -0
  161. {birder-0.4.1 → birder-0.4.4}/birder/kernels/soft_nms/soft_nms.cpp +0 -0
  162. {birder-0.4.1 → birder-0.4.4}/birder/kernels/soft_nms/soft_nms.h +0 -0
  163. {birder-0.4.1 → birder-0.4.4}/birder/kernels/transnext/av_bw_kernel.cu +0 -0
  164. {birder-0.4.1 → birder-0.4.4}/birder/kernels/transnext/av_fw_kernel.cu +0 -0
  165. {birder-0.4.1 → birder-0.4.4}/birder/kernels/transnext/qk_bw_kernel.cu +0 -0
  166. {birder-0.4.1 → birder-0.4.4}/birder/kernels/transnext/qk_fw_kernel.cu +0 -0
  167. {birder-0.4.1 → birder-0.4.4}/birder/kernels/transnext/qk_rpb_bw_kernel.cu +0 -0
  168. {birder-0.4.1 → birder-0.4.4}/birder/kernels/transnext/qk_rpb_fw_kernel.cu +0 -0
  169. {birder-0.4.1 → birder-0.4.4}/birder/kernels/transnext/swattention.cpp +0 -0
  170. {birder-0.4.1 → birder-0.4.4}/birder/layers/__init__.py +0 -0
  171. {birder-0.4.1 → birder-0.4.4}/birder/layers/activations.py +0 -0
  172. {birder-0.4.1 → birder-0.4.4}/birder/layers/attention_pool.py +0 -0
  173. {birder-0.4.1 → birder-0.4.4}/birder/layers/ffn.py +0 -0
  174. {birder-0.4.1 → birder-0.4.4}/birder/layers/gem.py +0 -0
  175. {birder-0.4.1 → birder-0.4.4}/birder/layers/layer_norm.py +0 -0
  176. {birder-0.4.1 → birder-0.4.4}/birder/layers/layer_scale.py +0 -0
  177. {birder-0.4.1 → birder-0.4.4}/birder/model_registry/__init__.py +0 -0
  178. {birder-0.4.1 → birder-0.4.4}/birder/model_registry/manifest.py +0 -0
  179. {birder-0.4.1 → birder-0.4.4}/birder/model_registry/model_registry.py +0 -0
  180. {birder-0.4.1 → birder-0.4.4}/birder/net/__init__.py +0 -0
  181. {birder-0.4.1 → birder-0.4.4}/birder/net/_rope_vit_configs.py +0 -0
  182. {birder-0.4.1 → birder-0.4.4}/birder/net/alexnet.py +0 -0
  183. {birder-0.4.1 → birder-0.4.4}/birder/net/base.py +0 -0
  184. {birder-0.4.1 → birder-0.4.4}/birder/net/biformer.py +0 -0
  185. {birder-0.4.1 → birder-0.4.4}/birder/net/cas_vit.py +0 -0
  186. {birder-0.4.1 → birder-0.4.4}/birder/net/conv2former.py +0 -0
  187. {birder-0.4.1 → birder-0.4.4}/birder/net/convmixer.py +0 -0
  188. {birder-0.4.1 → birder-0.4.4}/birder/net/convnext_v1.py +0 -0
  189. {birder-0.4.1 → birder-0.4.4}/birder/net/convnext_v1_iso.py +0 -0
  190. {birder-0.4.1 → birder-0.4.4}/birder/net/convnext_v2.py +0 -0
  191. {birder-0.4.1 → birder-0.4.4}/birder/net/crossformer.py +0 -0
  192. {birder-0.4.1 → birder-0.4.4}/birder/net/crossvit.py +0 -0
  193. {birder-0.4.1 → birder-0.4.4}/birder/net/cspnet.py +0 -0
  194. {birder-0.4.1 → birder-0.4.4}/birder/net/darknet.py +0 -0
  195. {birder-0.4.1 → birder-0.4.4}/birder/net/davit.py +0 -0
  196. {birder-0.4.1 → birder-0.4.4}/birder/net/densenet.py +0 -0
  197. {birder-0.4.1 → birder-0.4.4}/birder/net/detection/_yolo_anchors.py +0 -0
  198. {birder-0.4.1 → birder-0.4.4}/birder/net/detection/vitdet.py +0 -0
  199. {birder-0.4.1 → birder-0.4.4}/birder/net/dpn.py +0 -0
  200. {birder-0.4.1 → birder-0.4.4}/birder/net/edgenext.py +0 -0
  201. {birder-0.4.1 → birder-0.4.4}/birder/net/efficientformer_v1.py +0 -0
  202. {birder-0.4.1 → birder-0.4.4}/birder/net/efficientformer_v2.py +0 -0
  203. {birder-0.4.1 → birder-0.4.4}/birder/net/efficientnet_lite.py +0 -0
  204. {birder-0.4.1 → birder-0.4.4}/birder/net/efficientnet_v1.py +0 -0
  205. {birder-0.4.1 → birder-0.4.4}/birder/net/efficientnet_v2.py +0 -0
  206. {birder-0.4.1 → birder-0.4.4}/birder/net/efficientvim.py +0 -0
  207. {birder-0.4.1 → birder-0.4.4}/birder/net/efficientvit_mit.py +0 -0
  208. {birder-0.4.1 → birder-0.4.4}/birder/net/fasternet.py +0 -0
  209. {birder-0.4.1 → birder-0.4.4}/birder/net/fastvit.py +0 -0
  210. {birder-0.4.1 → birder-0.4.4}/birder/net/focalnet.py +0 -0
  211. {birder-0.4.1 → birder-0.4.4}/birder/net/gc_vit.py +0 -0
  212. {birder-0.4.1 → birder-0.4.4}/birder/net/ghostnet_v1.py +0 -0
  213. {birder-0.4.1 → birder-0.4.4}/birder/net/ghostnet_v2.py +0 -0
  214. {birder-0.4.1 → birder-0.4.4}/birder/net/groupmixformer.py +0 -0
  215. {birder-0.4.1 → birder-0.4.4}/birder/net/hgnet_v1.py +0 -0
  216. {birder-0.4.1 → birder-0.4.4}/birder/net/hgnet_v2.py +0 -0
  217. {birder-0.4.1 → birder-0.4.4}/birder/net/hornet.py +0 -0
  218. {birder-0.4.1 → birder-0.4.4}/birder/net/iformer.py +0 -0
  219. {birder-0.4.1 → birder-0.4.4}/birder/net/inception_next.py +0 -0
  220. {birder-0.4.1 → birder-0.4.4}/birder/net/inception_resnet_v1.py +0 -0
  221. {birder-0.4.1 → birder-0.4.4}/birder/net/inception_resnet_v2.py +0 -0
  222. {birder-0.4.1 → birder-0.4.4}/birder/net/inception_v3.py +0 -0
  223. {birder-0.4.1 → birder-0.4.4}/birder/net/inception_v4.py +0 -0
  224. {birder-0.4.1 → birder-0.4.4}/birder/net/levit.py +0 -0
  225. {birder-0.4.1 → birder-0.4.4}/birder/net/lit_v1.py +0 -0
  226. {birder-0.4.1 → birder-0.4.4}/birder/net/lit_v1_tiny.py +0 -0
  227. {birder-0.4.1 → birder-0.4.4}/birder/net/lit_v2.py +0 -0
  228. {birder-0.4.1 → birder-0.4.4}/birder/net/metaformer.py +0 -0
  229. {birder-0.4.1 → birder-0.4.4}/birder/net/mim/__init__.py +0 -0
  230. {birder-0.4.1 → birder-0.4.4}/birder/net/mim/base.py +0 -0
  231. {birder-0.4.1 → birder-0.4.4}/birder/net/mim/crossmae.py +0 -0
  232. {birder-0.4.1 → birder-0.4.4}/birder/net/mim/mae_vit.py +0 -0
  233. {birder-0.4.1 → birder-0.4.4}/birder/net/mim/simmim.py +0 -0
  234. {birder-0.4.1 → birder-0.4.4}/birder/net/mobilenet_v1.py +0 -0
  235. {birder-0.4.1 → birder-0.4.4}/birder/net/mobilenet_v2.py +0 -0
  236. {birder-0.4.1 → birder-0.4.4}/birder/net/mobilenet_v3.py +0 -0
  237. {birder-0.4.1 → birder-0.4.4}/birder/net/mobilenet_v4.py +0 -0
  238. {birder-0.4.1 → birder-0.4.4}/birder/net/mobilenet_v4_hybrid.py +0 -0
  239. {birder-0.4.1 → birder-0.4.4}/birder/net/mobileone.py +0 -0
  240. {birder-0.4.1 → birder-0.4.4}/birder/net/mobilevit_v1.py +0 -0
  241. {birder-0.4.1 → birder-0.4.4}/birder/net/mobilevit_v2.py +0 -0
  242. {birder-0.4.1 → birder-0.4.4}/birder/net/moganet.py +0 -0
  243. {birder-0.4.1 → birder-0.4.4}/birder/net/mvit_v2.py +0 -0
  244. {birder-0.4.1 → birder-0.4.4}/birder/net/nfnet.py +0 -0
  245. {birder-0.4.1 → birder-0.4.4}/birder/net/pit.py +0 -0
  246. {birder-0.4.1 → birder-0.4.4}/birder/net/pvt_v1.py +0 -0
  247. {birder-0.4.1 → birder-0.4.4}/birder/net/pvt_v2.py +0 -0
  248. {birder-0.4.1 → birder-0.4.4}/birder/net/rdnet.py +0 -0
  249. {birder-0.4.1 → birder-0.4.4}/birder/net/regionvit.py +0 -0
  250. {birder-0.4.1 → birder-0.4.4}/birder/net/regnet.py +0 -0
  251. {birder-0.4.1 → birder-0.4.4}/birder/net/regnet_z.py +0 -0
  252. {birder-0.4.1 → birder-0.4.4}/birder/net/repghost.py +0 -0
  253. {birder-0.4.1 → birder-0.4.4}/birder/net/repvgg.py +0 -0
  254. {birder-0.4.1 → birder-0.4.4}/birder/net/repvit.py +0 -0
  255. {birder-0.4.1 → birder-0.4.4}/birder/net/resmlp.py +0 -0
  256. {birder-0.4.1 → birder-0.4.4}/birder/net/resnest.py +0 -0
  257. {birder-0.4.1 → birder-0.4.4}/birder/net/resnet_v1.py +0 -0
  258. {birder-0.4.1 → birder-0.4.4}/birder/net/resnet_v2.py +0 -0
  259. {birder-0.4.1 → birder-0.4.4}/birder/net/sequencer2d.py +0 -0
  260. {birder-0.4.1 → birder-0.4.4}/birder/net/shufflenet_v1.py +0 -0
  261. {birder-0.4.1 → birder-0.4.4}/birder/net/shufflenet_v2.py +0 -0
  262. {birder-0.4.1 → birder-0.4.4}/birder/net/smt.py +0 -0
  263. {birder-0.4.1 → birder-0.4.4}/birder/net/squeezenext.py +0 -0
  264. {birder-0.4.1 → birder-0.4.4}/birder/net/ssl/__init__.py +0 -0
  265. {birder-0.4.1 → birder-0.4.4}/birder/net/ssl/barlow_twins.py +0 -0
  266. {birder-0.4.1 → birder-0.4.4}/birder/net/ssl/base.py +0 -0
  267. {birder-0.4.1 → birder-0.4.4}/birder/net/ssl/byol.py +0 -0
  268. {birder-0.4.1 → birder-0.4.4}/birder/net/ssl/data2vec.py +0 -0
  269. {birder-0.4.1 → birder-0.4.4}/birder/net/ssl/data2vec2.py +0 -0
  270. {birder-0.4.1 → birder-0.4.4}/birder/net/ssl/dino_v1.py +0 -0
  271. {birder-0.4.1 → birder-0.4.4}/birder/net/ssl/i_jepa.py +0 -0
  272. {birder-0.4.1 → birder-0.4.4}/birder/net/ssl/ibot.py +0 -0
  273. {birder-0.4.1 → birder-0.4.4}/birder/net/ssl/mmcr.py +0 -0
  274. {birder-0.4.1 → birder-0.4.4}/birder/net/ssl/simclr.py +0 -0
  275. {birder-0.4.1 → birder-0.4.4}/birder/net/ssl/sscd.py +0 -0
  276. {birder-0.4.1 → birder-0.4.4}/birder/net/ssl/vicreg.py +0 -0
  277. {birder-0.4.1 → birder-0.4.4}/birder/net/starnet.py +0 -0
  278. {birder-0.4.1 → birder-0.4.4}/birder/net/swiftformer.py +0 -0
  279. {birder-0.4.1 → birder-0.4.4}/birder/net/swin_transformer_v1.py +0 -0
  280. {birder-0.4.1 → birder-0.4.4}/birder/net/swin_transformer_v2.py +0 -0
  281. {birder-0.4.1 → birder-0.4.4}/birder/net/tiny_vit.py +0 -0
  282. {birder-0.4.1 → birder-0.4.4}/birder/net/transnext.py +0 -0
  283. {birder-0.4.1 → birder-0.4.4}/birder/net/uniformer.py +0 -0
  284. {birder-0.4.1 → birder-0.4.4}/birder/net/vgg.py +0 -0
  285. {birder-0.4.1 → birder-0.4.4}/birder/net/vgg_reduced.py +0 -0
  286. {birder-0.4.1 → birder-0.4.4}/birder/net/vovnet_v1.py +0 -0
  287. {birder-0.4.1 → birder-0.4.4}/birder/net/vovnet_v2.py +0 -0
  288. {birder-0.4.1 → birder-0.4.4}/birder/net/wide_resnet.py +0 -0
  289. {birder-0.4.1 → birder-0.4.4}/birder/net/xception.py +0 -0
  290. {birder-0.4.1/birder/tools → birder-0.4.4/birder/ops}/__init__.py +0 -0
  291. {birder-0.4.1 → birder-0.4.4}/birder/ops/soft_nms.py +0 -0
  292. {birder-0.4.1 → birder-0.4.4}/birder/ops/swattention.py +0 -0
  293. {birder-0.4.1 → birder-0.4.4}/birder/optim/__init__.py +0 -0
  294. {birder-0.4.1 → birder-0.4.4}/birder/optim/lamb.py +0 -0
  295. {birder-0.4.1 → birder-0.4.4}/birder/optim/lars.py +0 -0
  296. {birder-0.4.1 → birder-0.4.4}/birder/py.typed +0 -0
  297. {birder-0.4.1 → birder-0.4.4}/birder/results/classification.py +0 -0
  298. {birder-0.4.1 → birder-0.4.4}/birder/results/detection.py +0 -0
  299. {birder-0.4.1 → birder-0.4.4}/birder/results/gui.py +0 -0
  300. {birder-0.4.1 → birder-0.4.4}/birder/scheduler/__init__.py +0 -0
  301. {birder-0.4.1 → birder-0.4.4}/birder/scheduler/cooldown.py +0 -0
  302. {birder-0.4.1 → birder-0.4.4}/birder/scripts/__main__.py +0 -0
  303. {birder-0.4.1 → birder-0.4.4}/birder/scripts/train_barlow_twins.py +0 -0
  304. {birder-0.4.1 → birder-0.4.4}/birder/scripts/train_byol.py +0 -0
  305. {birder-0.4.1 → birder-0.4.4}/birder/scripts/train_data2vec.py +0 -0
  306. {birder-0.4.1 → birder-0.4.4}/birder/scripts/train_data2vec2.py +0 -0
  307. {birder-0.4.1 → birder-0.4.4}/birder/scripts/train_dino_v1.py +0 -0
  308. {birder-0.4.1 → birder-0.4.4}/birder/scripts/train_dino_v2.py +0 -0
  309. {birder-0.4.1 → birder-0.4.4}/birder/scripts/train_dino_v2_dist.py +0 -0
  310. {birder-0.4.1 → birder-0.4.4}/birder/scripts/train_i_jepa.py +0 -0
  311. {birder-0.4.1 → birder-0.4.4}/birder/scripts/train_ibot.py +0 -0
  312. {birder-0.4.1 → birder-0.4.4}/birder/scripts/train_mim.py +0 -0
  313. {birder-0.4.1 → birder-0.4.4}/birder/scripts/train_mmcr.py +0 -0
  314. {birder-0.4.1 → birder-0.4.4}/birder/scripts/train_rotnet.py +0 -0
  315. {birder-0.4.1 → birder-0.4.4}/birder/scripts/train_simclr.py +0 -0
  316. {birder-0.4.1 → birder-0.4.4}/birder/scripts/train_vicreg.py +0 -0
  317. {birder-0.4.1 → birder-0.4.4}/birder/tools/__main__.py +0 -0
  318. {birder-0.4.1 → birder-0.4.4}/birder/tools/auto_anchors.py +0 -0
  319. {birder-0.4.1 → birder-0.4.4}/birder/tools/avg_model.py +0 -0
  320. {birder-0.4.1 → birder-0.4.4}/birder/tools/det_results.py +0 -0
  321. {birder-0.4.1 → birder-0.4.4}/birder/tools/download_model.py +0 -0
  322. {birder-0.4.1 → birder-0.4.4}/birder/tools/ensemble_model.py +0 -0
  323. {birder-0.4.1 → birder-0.4.4}/birder/tools/introspection.py +0 -0
  324. {birder-0.4.1 → birder-0.4.4}/birder/tools/labelme_to_coco.py +0 -0
  325. {birder-0.4.1 → birder-0.4.4}/birder/tools/list_models.py +0 -0
  326. {birder-0.4.1 → birder-0.4.4}/birder/tools/model_info.py +0 -0
  327. {birder-0.4.1 → birder-0.4.4}/birder/tools/pack.py +0 -0
  328. {birder-0.4.1 → birder-0.4.4}/birder/tools/results.py +0 -0
  329. {birder-0.4.1 → birder-0.4.4}/birder/tools/show_det_iterator.py +0 -0
  330. {birder-0.4.1 → birder-0.4.4}/birder/tools/show_iterator.py +0 -0
  331. {birder-0.4.1 → birder-0.4.4}/birder/tools/similarity.py +0 -0
  332. {birder-0.4.1 → birder-0.4.4}/birder/tools/stats.py +0 -0
  333. {birder-0.4.1 → birder-0.4.4}/birder/tools/verify_coco.py +0 -0
  334. {birder-0.4.1 → birder-0.4.4}/birder/tools/verify_directory.py +0 -0
  335. {birder-0.4.1 → birder-0.4.4}/birder/tools/voc_to_coco.py +0 -0
  336. {birder-0.4.1 → birder-0.4.4}/birder.egg-info/dependency_links.txt +0 -0
  337. {birder-0.4.1 → birder-0.4.4}/birder.egg-info/entry_points.txt +0 -0
  338. {birder-0.4.1 → birder-0.4.4}/birder.egg-info/top_level.txt +0 -0
  339. {birder-0.4.1 → birder-0.4.4}/requirements/requirements-hf.txt +0 -0
  340. {birder-0.4.1 → birder-0.4.4}/setup.cfg +0 -0
  341. {birder-0.4.1 → birder-0.4.4}/tests/test_adversarial.py +0 -0
  342. {birder-0.4.1 → birder-0.4.4}/tests/test_collators.py +0 -0
  343. {birder-0.4.1 → birder-0.4.4}/tests/test_dataloaders.py +0 -0
  344. {birder-0.4.1 → birder-0.4.4}/tests/test_datasets.py +0 -0
  345. {birder-0.4.1 → birder-0.4.4}/tests/test_inference.py +0 -0
  346. {birder-0.4.1 → birder-0.4.4}/tests/test_introspection.py +0 -0
  347. {birder-0.4.1 → birder-0.4.4}/tests/test_layers.py +0 -0
  348. {birder-0.4.1 → birder-0.4.4}/tests/test_model_registry.py +0 -0
  349. {birder-0.4.1 → birder-0.4.4}/tests/test_net_mim.py +0 -0
  350. {birder-0.4.1 → birder-0.4.4}/tests/test_ops.py +0 -0
  351. {birder-0.4.1 → birder-0.4.4}/tests/test_optim.py +0 -0
  352. {birder-0.4.1 → birder-0.4.4}/tests/test_results.py +0 -0
  353. {birder-0.4.1 → birder-0.4.4}/tests/test_scheduler.py +0 -0
  354. {birder-0.4.1 → birder-0.4.4}/tests/test_transforms.py +0 -0
@@ -1,14 +1,14 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: birder
3
- Version: 0.4.1
3
+ Version: 0.4.4
4
4
  Summary: An open-source computer vision framework for wildlife image analysis, featuring state-of-the-art models for species classification and detection.
5
5
  Author: Ofer Hasson
6
6
  License-Expression: Apache-2.0
7
7
  Project-URL: Homepage, https://gitlab.com/birder/birder
8
8
  Project-URL: Documentation, https://birder.gitlab.io/birder/
9
9
  Project-URL: Issues, https://gitlab.com/birder/birder/-/issues
10
- Keywords: computer-vision,image-classification,object-detection,pytorch,deep-learning
11
- Classifier: Development Status :: 3 - Alpha
10
+ Keywords: computer-vision,image-classification,object-detection,self-supervised learning,masked image modeling,pytorch,deep-learning,artificial intelligence
11
+ Classifier: Development Status :: 4 - Beta
12
12
  Classifier: Intended Audience :: Science/Research
13
13
  Classifier: Intended Audience :: Developers
14
14
  Classifier: Intended Audience :: Education
@@ -26,7 +26,7 @@ License-File: LICENSE
26
26
  Requires-Dist: matplotlib>=3.9.0
27
27
  Requires-Dist: numpy>=2.2.0
28
28
  Requires-Dist: onnx>=1.18.0
29
- Requires-Dist: onnxscript~=0.5.7
29
+ Requires-Dist: onnxscript~=0.6.0
30
30
  Requires-Dist: Pillow>=12.0.0
31
31
  Requires-Dist: polars>=1.31.0
32
32
  Requires-Dist: pyarrow>=20.0.0
@@ -48,7 +48,7 @@ Requires-Dist: black~=26.1.0; extra == "dev"
48
48
  Requires-Dist: build~=1.4.0; extra == "dev"
49
49
  Requires-Dist: bumpver~=2025.1131; extra == "dev"
50
50
  Requires-Dist: captum~=0.7.0; extra == "dev"
51
- Requires-Dist: coverage~=7.13.1; extra == "dev"
51
+ Requires-Dist: coverage~=7.13.3; extra == "dev"
52
52
  Requires-Dist: debugpy; extra == "dev"
53
53
  Requires-Dist: flake8-pep585~=0.1.7; extra == "dev"
54
54
  Requires-Dist: flake8~=7.3.0; extra == "dev"
@@ -87,6 +87,7 @@ An open-source computer vision framework for wildlife image analysis, featuring
87
87
  - [Getting Started](#getting-started)
88
88
  - [Pre-trained Models](#pre-trained-models)
89
89
  - [Detection](#detection)
90
+ - [Evaluation](#evaluation)
90
91
  - [Project Status and Contributions](#project-status-and-contributions)
91
92
  - [Licenses](#licenses)
92
93
  - [Acknowledgments](#acknowledgments)
@@ -117,7 +118,9 @@ The same principle applies to Birder. We stand on the shoulders of giants in the
117
118
 
118
119
  ## Setup
119
120
 
120
- 1. Ensure PyTorch 2.7 is installed on your system
121
+ 1. Ensure your environment meets the minimum requirements:
122
+ - Python 3.11 or newer
123
+ - PyTorch 2.7 or newer (installed for your hardware/driver stack)
121
124
 
122
125
  1. Install the latest Birder version:
123
126
 
@@ -212,6 +215,10 @@ For detailed information about these datasets, including descriptions, citations
212
215
  Detection training and inference are available, see [docs/training_scripts.md](docs/training_scripts.md) and
213
216
  [docs/inference.md](docs/inference.md). APIs and model coverage may evolve as detection support matures.
214
217
 
218
+ ## Evaluation
219
+
220
+ Evaluation workflows are documented in [docs/evaluation.md](docs/evaluation.md).
221
+
215
222
  ## Project Status and Contributions
216
223
 
217
224
  Birder is currently a personal project in active development. As the sole developer, I am focused on building and refining the core functionalities of the framework. At this time, I am not actively seeking external contributors.
@@ -240,15 +247,15 @@ Files subject to additional license restrictions are marked in their headers. So
240
247
 
241
248
  If you think we've missed a reference or a license, please create an issue.
242
249
 
243
- ### Pretrained Weights
250
+ ### Pre-trained Weights
244
251
 
245
- Some of the pretrained weights available here are pretrained on ImageNet. ImageNet was released for non-commercial research purposes only (<https://image-net.org/download>). It's not clear what the implications of that are for the use of pretrained weights from that dataset. It's best to seek legal advice if you intend to use the pretrained weights in a commercial product.
252
+ Some of the pre-trained weights available here are pre-trained on ImageNet. ImageNet was released for non-commercial research purposes only (<https://image-net.org/download>). It's not clear what the implications are for the use of pre-trained weights from that dataset. It's best to seek legal advice if you intend to use the pre-trained weights in a commercial product.
246
253
 
247
254
  ### Disclaimer
248
255
 
249
- If you intend to use Birder, its pretrained weights, or any associated datasets in a commercial product, we strongly recommend seeking legal advice to ensure compliance with all relevant licenses and terms of use.
256
+ If you intend to use Birder, its pre-trained weights, or any associated datasets in a commercial product, we strongly recommend seeking legal advice to ensure compliance with all relevant licenses and terms of use.
250
257
 
251
- It's the user's responsibility to ensure that their use of this project, including any pretrained weights or datasets, complies with all applicable licenses and legal requirements.
258
+ It's the user's responsibility to ensure that their use of this project, including any pre-trained weights or datasets, complies with all applicable licenses and legal requirements.
252
259
 
253
260
  ## Acknowledgments
254
261
 
@@ -7,6 +7,7 @@ An open-source computer vision framework for wildlife image analysis, featuring
7
7
  - [Getting Started](#getting-started)
8
8
  - [Pre-trained Models](#pre-trained-models)
9
9
  - [Detection](#detection)
10
+ - [Evaluation](#evaluation)
10
11
  - [Project Status and Contributions](#project-status-and-contributions)
11
12
  - [Licenses](#licenses)
12
13
  - [Acknowledgments](#acknowledgments)
@@ -37,7 +38,9 @@ The same principle applies to Birder. We stand on the shoulders of giants in the
37
38
 
38
39
  ## Setup
39
40
 
40
- 1. Ensure PyTorch 2.7 is installed on your system
41
+ 1. Ensure your environment meets the minimum requirements:
42
+ - Python 3.11 or newer
43
+ - PyTorch 2.7 or newer (installed for your hardware/driver stack)
41
44
 
42
45
  1. Install the latest Birder version:
43
46
 
@@ -132,6 +135,10 @@ For detailed information about these datasets, including descriptions, citations
132
135
  Detection training and inference are available, see [docs/training_scripts.md](docs/training_scripts.md) and
133
136
  [docs/inference.md](docs/inference.md). APIs and model coverage may evolve as detection support matures.
134
137
 
138
+ ## Evaluation
139
+
140
+ Evaluation workflows are documented in [docs/evaluation.md](docs/evaluation.md).
141
+
135
142
  ## Project Status and Contributions
136
143
 
137
144
  Birder is currently a personal project in active development. As the sole developer, I am focused on building and refining the core functionalities of the framework. At this time, I am not actively seeking external contributors.
@@ -160,15 +167,15 @@ Files subject to additional license restrictions are marked in their headers. So
160
167
 
161
168
  If you think we've missed a reference or a license, please create an issue.
162
169
 
163
- ### Pretrained Weights
170
+ ### Pre-trained Weights
164
171
 
165
- Some of the pretrained weights available here are pretrained on ImageNet. ImageNet was released for non-commercial research purposes only (<https://image-net.org/download>). It's not clear what the implications of that are for the use of pretrained weights from that dataset. It's best to seek legal advice if you intend to use the pretrained weights in a commercial product.
172
+ Some of the pre-trained weights available here are pre-trained on ImageNet. ImageNet was released for non-commercial research purposes only (<https://image-net.org/download>). It's not clear what the implications are for the use of pre-trained weights from that dataset. It's best to seek legal advice if you intend to use the pre-trained weights in a commercial product.
166
173
 
167
174
  ### Disclaimer
168
175
 
169
- If you intend to use Birder, its pretrained weights, or any associated datasets in a commercial product, we strongly recommend seeking legal advice to ensure compliance with all relevant licenses and terms of use.
176
+ If you intend to use Birder, its pre-trained weights, or any associated datasets in a commercial product, we strongly recommend seeking legal advice to ensure compliance with all relevant licenses and terms of use.
170
177
 
171
- It's the user's responsibility to ensure that their use of this project, including any pretrained weights or datasets, complies with all applicable licenses and legal requirements.
178
+ It's the user's responsibility to ensure that their use of this project, including any pre-trained weights or datasets, complies with all applicable licenses and legal requirements.
172
179
 
173
180
  ## Acknowledgments
174
181
 
@@ -1,5 +1,6 @@
1
1
  from birder.common.fs_ops import load_model_with_cfg
2
2
  from birder.common.fs_ops import load_pretrained_model
3
+ from birder.common.fs_ops import load_pretrained_model_and_transform
3
4
  from birder.common.lib import get_channels_from_signature
4
5
  from birder.common.lib import get_size_from_signature
5
6
  from birder.data.transforms.classification import inference_preset as classification_transform
@@ -17,5 +18,6 @@ __all__ = [
17
18
  "list_pretrained_models",
18
19
  "load_model_with_cfg",
19
20
  "load_pretrained_model",
21
+ "load_pretrained_model_and_transform",
20
22
  "__version__",
21
23
  ]
@@ -2,6 +2,7 @@ import json
2
2
  import logging
3
3
  import os
4
4
  import re
5
+ from collections.abc import Callable
5
6
  from collections.abc import Iterator
6
7
  from pathlib import Path
7
8
  from typing import Any
@@ -24,6 +25,8 @@ from birder.common.lib import get_network_name
24
25
  from birder.common.lib import get_pretrained_model_url
25
26
  from birder.conf import settings
26
27
  from birder.data.transforms.classification import RGBType
28
+ from birder.data.transforms.classification import inference_preset
29
+ from birder.data.transforms.detection import InferenceTransform
27
30
  from birder.model_registry import Task
28
31
  from birder.model_registry import registry
29
32
  from birder.model_registry.manifest import FileFormatType
@@ -801,7 +804,8 @@ def load_detection_model(
801
804
  for param in net.parameters():
802
805
  param.requires_grad_(False)
803
806
 
804
- net.eval()
807
+ if pt2 is False: # NOTE: Remove when GraphModule adds support for 'eval'
808
+ net.eval()
805
809
 
806
810
  if len(backbone_loaded_config) == 0:
807
811
  backbone_custom_config = None
@@ -918,6 +922,82 @@ def load_pretrained_model(
918
922
  raise ValueError(f"Unknown model type: {model_metadata['task']}")
919
923
 
920
924
 
925
+ def load_pretrained_model_and_transform(
926
+ weights: str,
927
+ *,
928
+ dst: Optional[str | Path] = None,
929
+ file_format: FileFormatType = "pt",
930
+ inference: bool = True,
931
+ device: Optional[torch.device] = None,
932
+ dtype: Optional[torch.dtype] = None,
933
+ custom_config: Optional[dict[str, Any]] = None,
934
+ progress_bar: bool = True,
935
+ classification_kwargs: Optional[dict[str, Any]] = None,
936
+ detection_kwargs: Optional[dict[str, Any]] = None,
937
+ ) -> tuple[BaseNet | DetectionBaseNet, ModelInfo | DetectionModelInfo, Callable[..., torch.Tensor]]:
938
+ """
939
+ Loads a pre-trained model and builds the matching inference transform
940
+
941
+ This is a convenience helper for the common inference path where the model and
942
+ its default preprocessing are needed together. Classification models use
943
+ inference_preset, detection models use InferenceTransform.
944
+
945
+ Parameters
946
+ ----------
947
+ weights
948
+ Name of the pre-trained weights to load from the model registry.
949
+ dst
950
+ Destination path where the model weights will be downloaded or loaded from.
951
+ file_format
952
+ Model format (e.g., pt, pt2, safetensors)
953
+ inference
954
+ Flag to prepare the model for inference mode.
955
+ device
956
+ The device to load the model on (cpu/cuda).
957
+ dtype
958
+ Data type for model parameters and computations (e.g., torch.float32, torch.float16).
959
+ custom_config
960
+ Additional model configuration that overrides or extends the predefined configuration.
961
+ progress_bar
962
+ Whether to display a progress bar during file download.
963
+ classification_kwargs
964
+ Optional keyword arguments forwarded to inference_preset.
965
+ detection_kwargs
966
+ Optional keyword arguments forwarded to InferenceTransform. If dynamic_size is
967
+ not provided it defaults to the model signature value.
968
+
969
+ Returns
970
+ -------
971
+ A tuple containing three elements:
972
+ - A PyTorch module (neural network model) loaded with pre-trained weights.
973
+ - Model info containing class mappings, signature, and RGB stats.
974
+ - An inference transform matching the model task.
975
+ """
976
+
977
+ net, model_info = load_pretrained_model(
978
+ weights,
979
+ dst=dst,
980
+ file_format=file_format,
981
+ inference=inference,
982
+ device=device,
983
+ dtype=dtype,
984
+ custom_config=custom_config,
985
+ progress_bar=progress_bar,
986
+ )
987
+
988
+ size = lib.get_size_from_signature(model_info.signature)
989
+ transform: Callable[..., torch.Tensor]
990
+ if isinstance(model_info, DetectionModelInfo):
991
+ detection_args = {} if detection_kwargs is None else dict(detection_kwargs)
992
+ detection_args.setdefault("dynamic_size", model_info.signature["dynamic"])
993
+ transform = InferenceTransform(size, model_info.rgb_stats, **detection_args)
994
+ else:
995
+ classification_args = {} if classification_kwargs is None else dict(classification_kwargs)
996
+ transform = inference_preset(size, model_info.rgb_stats, **classification_args)
997
+
998
+ return (net, model_info, transform)
999
+
1000
+
921
1001
  def load_model_with_cfg(
922
1002
  cfg: dict[str, Any] | str | Path, weights_path: Optional[str | Path]
923
1003
  ) -> tuple[torch.nn.Module, dict[str, Any]]:
@@ -56,7 +56,9 @@ def add_optimization_args(parser: argparse.ArgumentParser, default_batch_size: i
56
56
  )
57
57
 
58
58
 
59
- def add_lr_wd_args(parser: argparse.ArgumentParser, backbone_lr: bool = False, wd_end: bool = False) -> None:
59
+ def add_lr_wd_args(
60
+ parser: argparse.ArgumentParser, backbone_lr: bool = False, wd_end: bool = False, backbone_layer_decay: bool = False
61
+ ) -> None:
60
62
  group = parser.add_argument_group("Learning rate and regularization parameters")
61
63
  group.add_argument("--lr", type=float, default=0.1, metavar="LR", help="base learning rate")
62
64
  group.add_argument("--bias-lr", type=float, metavar="LR", help="learning rate of biases")
@@ -92,6 +94,9 @@ def add_lr_wd_args(parser: argparse.ArgumentParser, backbone_lr: bool = False, w
92
94
  help="custom weight decay for specific layers by name (e.g., offset_conv=0.0)",
93
95
  )
94
96
  group.add_argument("--layer-decay", type=float, help="layer-wise learning rate decay (LLRD)")
97
+ if backbone_layer_decay is True:
98
+ group.add_argument("--backbone-layer-decay", type=float, help="backbone layer-wise learning rate decay (LLRD)")
99
+
95
100
  group.add_argument("--layer-decay-min-scale", type=float, help="minimum layer scale factor clamp value")
96
101
  group.add_argument(
97
102
  "--layer-decay-no-opt-scale", type=float, help="layer scale threshold below which parameters are frozen"
@@ -480,8 +485,13 @@ def add_dataloader_args(
480
485
  )
481
486
 
482
487
 
483
- def add_precision_args(parser: argparse.ArgumentParser) -> None:
488
+ def add_precision_args(parser: argparse.ArgumentParser, channels_last: bool = False) -> None:
484
489
  group = parser.add_argument_group("Precision parameters")
490
+ if channels_last is True:
491
+ group.add_argument(
492
+ "--channels-last", default=False, action="store_true", help="use channels-last memory format"
493
+ )
494
+
485
495
  group.add_argument(
486
496
  "--model-dtype",
487
497
  type=str,
@@ -343,7 +343,7 @@ def count_layers(model: torch.nn.Module) -> int:
343
343
  return num_layers
344
344
 
345
345
 
346
- # pylint: disable=protected-access,too-many-locals,too-many-branches
346
+ # pylint: disable=protected-access,too-many-locals,too-many-branches,too-many-statements
347
347
  def optimizer_parameter_groups(
348
348
  model: torch.nn.Module,
349
349
  weight_decay: float,
@@ -352,6 +352,7 @@ def optimizer_parameter_groups(
352
352
  custom_keys_weight_decay: Optional[list[tuple[str, float]]] = None,
353
353
  custom_layer_weight_decay: Optional[dict[str, float]] = None,
354
354
  layer_decay: Optional[float] = None,
355
+ backbone_layer_decay: Optional[float] = None,
355
356
  layer_decay_min_scale: Optional[float] = None,
356
357
  layer_decay_no_opt_scale: Optional[float] = None,
357
358
  bias_lr: Optional[float] = None,
@@ -388,6 +389,8 @@ def optimizer_parameter_groups(
388
389
  Applied to parameters whose names contain the specified keys.
389
390
  layer_decay
390
391
  Layer-wise learning rate decay factor.
392
+ backbone_layer_decay
393
+ Layer-wise learning rate decay factor for backbone parameters only.
391
394
  layer_decay_min_scale
392
395
  Minimum learning rate scale factor when using layer decay. Prevents layers from having too small learning rates.
393
396
  layer_decay_no_opt_scale
@@ -434,6 +437,27 @@ def optimizer_parameter_groups(
434
437
  if layer_decay is not None:
435
438
  logger.warning("Assigning lr scaling (layer decay) without a block group map")
436
439
 
440
+ backbone_group_map: dict[str, int] = {}
441
+ backbone_num_layers = 0
442
+ if backbone_layer_decay is not None:
443
+ backbone_module = getattr(model, "backbone", None)
444
+ if backbone_module is None:
445
+ logger.warning("Backbone layer decay requested but model has no backbone")
446
+ backbone_layer_decay = None
447
+ else:
448
+ backbone_block_group_regex = getattr(backbone_module, "block_group_regex", None)
449
+ if backbone_block_group_regex is not None:
450
+ names = [n for n, _ in backbone_module.named_parameters()]
451
+ groups = group_by_regex(names, backbone_block_group_regex)
452
+ backbone_group_map = {
453
+ f"backbone.{item}": index for index, sublist in enumerate(groups) for item in sublist
454
+ }
455
+ backbone_num_layers = len(groups)
456
+ else:
457
+ backbone_group_map = {}
458
+ backbone_num_layers = count_layers(backbone_module)
459
+ logger.warning("Assigning lr scaling (backbone layer decay) without a block group map")
460
+
437
461
  # Build layer scale
438
462
  if layer_decay_min_scale is None:
439
463
  layer_decay_min_scale = 0.0
@@ -444,14 +468,28 @@ def optimizer_parameter_groups(
444
468
  layer_scales = [max(layer_decay_min_scale, layer_decay ** (layer_max - i)) for i in range(num_layers)]
445
469
  logger.info(f"Layer scaling ranges from {min(layer_scales)} to {max(layer_scales)} across {num_layers} layers")
446
470
 
471
+ backbone_layer_scales = []
472
+ if backbone_layer_decay is not None:
473
+ backbone_layer_max = backbone_num_layers - 1
474
+ backbone_layer_scales = [
475
+ max(layer_decay_min_scale, backbone_layer_decay ** (backbone_layer_max - i))
476
+ for i in range(backbone_num_layers)
477
+ ]
478
+ logger.info(
479
+ "Backbone layer scaling ranges from "
480
+ f"{min(backbone_layer_scales)} to {max(backbone_layer_scales)} across {backbone_num_layers} layers"
481
+ )
482
+
447
483
  # Set weight decay and layer decay
448
484
  idx = 0
485
+ backbone_idx = 0
449
486
  params = []
450
487
  module_stack_with_prefix = [(model, "")]
451
488
  visited_modules = []
452
489
  while len(module_stack_with_prefix) > 0: # pylint: disable=too-many-nested-blocks
453
490
  skip_module = False
454
491
  module, prefix = module_stack_with_prefix.pop()
492
+ is_backbone_module = prefix == "backbone" or prefix.startswith("backbone.")
455
493
  if id(module) in visited_modules:
456
494
  skip_module = True
457
495
 
@@ -460,23 +498,35 @@ def optimizer_parameter_groups(
460
498
  for name, p in module.named_parameters(recurse=False):
461
499
  target_name = f"{prefix}.{name}" if prefix != "" else name
462
500
  idx = group_map.get(target_name, idx)
501
+ is_backbone_param = target_name.startswith("backbone.")
502
+ if backbone_layer_decay is not None and is_backbone_param is True:
503
+ backbone_idx = backbone_group_map.get(target_name, backbone_idx)
463
504
  if skip_module is True:
464
505
  break
465
506
 
466
507
  parameters_found = True
467
508
  if p.requires_grad is False:
468
509
  continue
469
- if layer_decay is not None and layer_decay_no_opt_scale is not None:
470
- if layer_scales[idx] < layer_decay_no_opt_scale:
471
- p.requires_grad_(False)
510
+ if layer_decay_no_opt_scale is not None:
511
+ if backbone_layer_decay is not None and is_backbone_param is True:
512
+ if backbone_layer_scales and backbone_layer_scales[backbone_idx] < layer_decay_no_opt_scale:
513
+ p.requires_grad_(False)
514
+ elif layer_decay is not None:
515
+ if layer_scales[idx] < layer_decay_no_opt_scale:
516
+ p.requires_grad_(False)
472
517
 
473
518
  is_custom_key = False
474
519
  if custom_keys_weight_decay is not None:
475
520
  for key, custom_wd in custom_keys_weight_decay:
476
521
  target_name_for_custom_key = f"{prefix}.{name}" if prefix != "" and "." in key else name
477
522
  if key == target_name_for_custom_key:
478
- # Calculate lr_scale (from layer_decay or custom_layer_lr_scale)
479
- lr_scale = 1.0 if layer_decay is None else layer_scales[idx]
523
+ # Calculate lr_scale (from layer_decay/backbone_layer_decay or custom_layer_lr_scale)
524
+ if layer_decay is not None and (backbone_layer_decay is None or is_backbone_param is False):
525
+ lr_scale = layer_scales[idx]
526
+ elif backbone_layer_decay is not None and is_backbone_param is True:
527
+ lr_scale = backbone_layer_scales[backbone_idx]
528
+ else:
529
+ lr_scale = 1.0
480
530
  if custom_layer_lr_scale is not None:
481
531
  for layer_name_key, custom_scale in custom_layer_lr_scale.items():
482
532
  if layer_name_key in target_name:
@@ -500,8 +550,8 @@ def optimizer_parameter_groups(
500
550
  # Apply learning rate based on priority: bias_lr > backbone_lr > lr_scale
501
551
  if bias_lr is not None and target_name.endswith(".bias") is True:
502
552
  d["lr"] = bias_lr
503
- elif backbone_lr is not None and target_name.startswith("backbone.") is True:
504
- d["lr"] = backbone_lr
553
+ elif backbone_lr is not None and is_backbone_param is True:
554
+ d["lr"] = backbone_lr * lr_scale if backbone_layer_decay is not None else backbone_lr
505
555
  elif lr_scale != 1.0:
506
556
  d["lr"] = base_lr * lr_scale
507
557
 
@@ -522,8 +572,13 @@ def optimizer_parameter_groups(
522
572
  wd = custom_wd_value
523
573
  break
524
574
 
525
- # Calculate lr_scale (from layer_decay or custom_layer_lr_scale)
526
- lr_scale = 1.0 if layer_decay is None else layer_scales[idx]
575
+ # Calculate lr_scale (from layer_decay/backbone_layer_decay or custom_layer_lr_scale)
576
+ if layer_decay is not None and (backbone_layer_decay is None or is_backbone_param is False):
577
+ lr_scale = layer_scales[idx]
578
+ elif backbone_layer_decay is not None and is_backbone_param is True:
579
+ lr_scale = backbone_layer_scales[backbone_idx]
580
+ else:
581
+ lr_scale = 1.0
527
582
  if custom_layer_lr_scale is not None:
528
583
  for layer_name_key, custom_scale in custom_layer_lr_scale.items():
529
584
  if layer_name_key in target_name:
@@ -539,8 +594,8 @@ def optimizer_parameter_groups(
539
594
  # Apply learning rate based on priority: bias_lr > backbone_lr > lr_scale
540
595
  if bias_lr is not None and target_name.endswith(".bias") is True:
541
596
  d["lr"] = bias_lr
542
- elif backbone_lr is not None and target_name.startswith("backbone.") is True:
543
- d["lr"] = backbone_lr
597
+ elif backbone_lr is not None and is_backbone_param is True:
598
+ d["lr"] = backbone_lr * lr_scale if backbone_layer_decay is not None else backbone_lr
544
599
  elif lr_scale != 1.0:
545
600
  d["lr"] = base_lr * lr_scale
546
601
 
@@ -548,6 +603,8 @@ def optimizer_parameter_groups(
548
603
 
549
604
  if parameters_found is True:
550
605
  idx += 1
606
+ if is_backbone_module is True:
607
+ backbone_idx += 1
551
608
 
552
609
  for child_name, child_module in reversed(list(module.named_children())):
553
610
  child_prefix = f"{prefix}.{child_name}" if prefix != "" else child_name
@@ -1108,12 +1165,16 @@ def init_training(
1108
1165
  device_id = torch.cuda.current_device()
1109
1166
 
1110
1167
  if args.use_deterministic_algorithms is True:
1168
+ log.debug("Turning on deterministic algorithms")
1111
1169
  torch.backends.cudnn.benchmark = False
1112
1170
  torch.use_deterministic_algorithms(True)
1113
1171
  elif cudnn_dynamic_size is True:
1114
1172
  # Dynamic sizes: avoid per-size algorithm selection overhead.
1173
+ log.debug("Turning off cudnn")
1115
1174
  torch.backends.cudnn.enabled = False
1175
+ torch.backends.cudnn.benchmark = False
1116
1176
  else:
1177
+ log.debug("Turning on cudnn")
1117
1178
  torch.backends.cudnn.enabled = True
1118
1179
  torch.backends.cudnn.benchmark = True
1119
1180
 
@@ -15,7 +15,9 @@ def collate_fn(batch: list[tuple[Any, ...]]) -> tuple[Any, ...]:
15
15
  return tuple(zip(*batch))
16
16
 
17
17
 
18
- def batch_images(images: list[torch.Tensor], size_divisible: int) -> tuple[torch.Tensor, torch.Tensor, list[list[int]]]:
18
+ def batch_images(
19
+ images: list[torch.Tensor], size_divisible: int
20
+ ) -> tuple[torch.Tensor, torch.Tensor, list[tuple[int, int]]]:
19
21
  """
20
22
  Batch list of image tensors of different sizes into a single batch.
21
23
  Pad with zeros all images to the shape of the largest image in the list.
@@ -1,5 +1,6 @@
1
1
  import logging
2
2
  import tarfile
3
+ import zipfile
3
4
  from pathlib import Path
4
5
 
5
6
  from birder.common import cli
@@ -26,9 +27,17 @@ def download_url(url: str, target: str | Path, sha256: str, progress_bar: bool =
26
27
 
27
28
  def extract_archive(from_path: str | Path, to_path: str | Path) -> None:
28
29
  logger.info(f"Extracting {from_path} to {to_path}")
29
- with tarfile.open(from_path, "r") as tar:
30
- if hasattr(tarfile, "data_filter") is True:
31
- tar.extractall(to_path, filter="data")
32
- else:
33
- # NOTE: Remove once minimum Python version is 3.12 or above
34
- tar.extractall(to_path) # nosec # tarfile_unsafe_members
30
+ if isinstance(from_path, str):
31
+ from_path = Path(from_path)
32
+
33
+ if from_path.suffix == ".zip":
34
+ with zipfile.ZipFile(from_path, "r") as zf:
35
+ zf.extractall(to_path) # nosec # tarfile_unsafe_members
36
+
37
+ else:
38
+ with tarfile.open(from_path, "r") as tar:
39
+ if hasattr(tarfile, "data_filter") is True:
40
+ tar.extractall(to_path, filter="data")
41
+ else:
42
+ # NOTE: Remove once minimum Python version is 3.12 or above
43
+ tar.extractall(to_path) # nosec # tarfile_unsafe_members