birder 0.2.2__tar.gz → 0.3.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (330) hide show
  1. {birder-0.2.2 → birder-0.3.0}/PKG-INFO +4 -3
  2. {birder-0.2.2 → birder-0.3.0}/README.md +1 -1
  3. {birder-0.2.2 → birder-0.3.0}/birder/common/lib.py +2 -9
  4. {birder-0.2.2 → birder-0.3.0}/birder/common/training_cli.py +24 -0
  5. {birder-0.2.2 → birder-0.3.0}/birder/common/training_utils.py +338 -41
  6. {birder-0.2.2 → birder-0.3.0}/birder/data/collators/detection.py +11 -3
  7. {birder-0.2.2 → birder-0.3.0}/birder/data/dataloader/webdataset.py +12 -2
  8. {birder-0.2.2 → birder-0.3.0}/birder/data/datasets/coco.py +8 -10
  9. {birder-0.2.2 → birder-0.3.0}/birder/data/transforms/detection.py +30 -13
  10. {birder-0.2.2 → birder-0.3.0}/birder/inference/detection.py +108 -4
  11. birder-0.3.0/birder/inference/wbf.py +226 -0
  12. {birder-0.2.2 → birder-0.3.0}/birder/kernels/load_kernel.py +16 -11
  13. {birder-0.2.2 → birder-0.3.0}/birder/kernels/soft_nms/soft_nms.cpp +17 -18
  14. {birder-0.2.2 → birder-0.3.0}/birder/net/__init__.py +8 -0
  15. {birder-0.2.2 → birder-0.3.0}/birder/net/cait.py +4 -3
  16. {birder-0.2.2 → birder-0.3.0}/birder/net/convnext_v1.py +5 -0
  17. {birder-0.2.2 → birder-0.3.0}/birder/net/crossformer.py +33 -30
  18. {birder-0.2.2 → birder-0.3.0}/birder/net/crossvit.py +4 -3
  19. {birder-0.2.2 → birder-0.3.0}/birder/net/deit.py +3 -3
  20. {birder-0.2.2 → birder-0.3.0}/birder/net/deit3.py +3 -3
  21. {birder-0.2.2 → birder-0.3.0}/birder/net/detection/deformable_detr.py +2 -5
  22. {birder-0.2.2 → birder-0.3.0}/birder/net/detection/detr.py +2 -5
  23. {birder-0.2.2 → birder-0.3.0}/birder/net/detection/efficientdet.py +67 -93
  24. {birder-0.2.2 → birder-0.3.0}/birder/net/detection/fcos.py +2 -7
  25. {birder-0.2.2 → birder-0.3.0}/birder/net/detection/retinanet.py +2 -7
  26. {birder-0.2.2 → birder-0.3.0}/birder/net/detection/rt_detr_v1.py +2 -0
  27. birder-0.3.0/birder/net/detection/yolo_anchors.py +205 -0
  28. {birder-0.2.2 → birder-0.3.0}/birder/net/detection/yolo_v2.py +25 -24
  29. {birder-0.2.2 → birder-0.3.0}/birder/net/detection/yolo_v3.py +39 -40
  30. {birder-0.2.2 → birder-0.3.0}/birder/net/detection/yolo_v4.py +28 -26
  31. {birder-0.2.2 → birder-0.3.0}/birder/net/detection/yolo_v4_tiny.py +24 -20
  32. {birder-0.2.2 → birder-0.3.0}/birder/net/efficientformer_v1.py +15 -9
  33. {birder-0.2.2 → birder-0.3.0}/birder/net/efficientformer_v2.py +39 -29
  34. {birder-0.2.2 → birder-0.3.0}/birder/net/efficientvit_msft.py +9 -7
  35. {birder-0.2.2 → birder-0.3.0}/birder/net/fasternet.py +1 -1
  36. {birder-0.2.2 → birder-0.3.0}/birder/net/fastvit.py +1 -0
  37. {birder-0.2.2 → birder-0.3.0}/birder/net/flexivit.py +5 -4
  38. birder-0.3.0/birder/net/gc_vit.py +671 -0
  39. {birder-0.2.2 → birder-0.3.0}/birder/net/hiera.py +12 -9
  40. {birder-0.2.2 → birder-0.3.0}/birder/net/hornet.py +9 -7
  41. {birder-0.2.2 → birder-0.3.0}/birder/net/iformer.py +8 -6
  42. {birder-0.2.2 → birder-0.3.0}/birder/net/levit.py +42 -30
  43. birder-0.3.0/birder/net/lit_v1.py +472 -0
  44. birder-0.3.0/birder/net/lit_v1_tiny.py +357 -0
  45. birder-0.3.0/birder/net/lit_v2.py +436 -0
  46. {birder-0.2.2 → birder-0.3.0}/birder/net/maxvit.py +67 -55
  47. {birder-0.2.2 → birder-0.3.0}/birder/net/mobilenet_v4_hybrid.py +1 -1
  48. {birder-0.2.2 → birder-0.3.0}/birder/net/mobileone.py +1 -0
  49. {birder-0.2.2 → birder-0.3.0}/birder/net/mvit_v2.py +13 -12
  50. {birder-0.2.2 → birder-0.3.0}/birder/net/pit.py +4 -3
  51. {birder-0.2.2 → birder-0.3.0}/birder/net/pvt_v1.py +4 -1
  52. {birder-0.2.2 → birder-0.3.0}/birder/net/repghost.py +1 -0
  53. {birder-0.2.2 → birder-0.3.0}/birder/net/repvgg.py +1 -0
  54. {birder-0.2.2 → birder-0.3.0}/birder/net/repvit.py +1 -0
  55. {birder-0.2.2 → birder-0.3.0}/birder/net/resnet_v1.py +1 -1
  56. {birder-0.2.2 → birder-0.3.0}/birder/net/resnext.py +67 -25
  57. {birder-0.2.2 → birder-0.3.0}/birder/net/rope_deit3.py +5 -3
  58. {birder-0.2.2 → birder-0.3.0}/birder/net/rope_flexivit.py +7 -4
  59. {birder-0.2.2 → birder-0.3.0}/birder/net/rope_vit.py +10 -5
  60. {birder-0.2.2 → birder-0.3.0}/birder/net/se_resnet_v1.py +46 -0
  61. {birder-0.2.2 → birder-0.3.0}/birder/net/se_resnext.py +3 -0
  62. {birder-0.2.2 → birder-0.3.0}/birder/net/simple_vit.py +11 -8
  63. {birder-0.2.2 → birder-0.3.0}/birder/net/swin_transformer_v1.py +71 -68
  64. {birder-0.2.2 → birder-0.3.0}/birder/net/swin_transformer_v2.py +38 -31
  65. {birder-0.2.2 → birder-0.3.0}/birder/net/tiny_vit.py +20 -10
  66. {birder-0.2.2 → birder-0.3.0}/birder/net/transnext.py +38 -28
  67. {birder-0.2.2 → birder-0.3.0}/birder/net/vit.py +5 -19
  68. {birder-0.2.2 → birder-0.3.0}/birder/net/vit_parallel.py +5 -4
  69. {birder-0.2.2 → birder-0.3.0}/birder/net/vit_sam.py +38 -37
  70. {birder-0.2.2 → birder-0.3.0}/birder/net/vovnet_v1.py +15 -0
  71. {birder-0.2.2 → birder-0.3.0}/birder/net/vovnet_v2.py +31 -1
  72. birder-0.3.0/birder/ops/msda.py +203 -0
  73. birder-0.3.0/birder/ops/swattention.py +288 -0
  74. {birder-0.2.2 → birder-0.3.0}/birder/results/detection.py +4 -0
  75. {birder-0.2.2 → birder-0.3.0}/birder/scripts/benchmark.py +110 -32
  76. {birder-0.2.2 → birder-0.3.0}/birder/scripts/predict.py +8 -0
  77. {birder-0.2.2 → birder-0.3.0}/birder/scripts/predict_detection.py +18 -11
  78. {birder-0.2.2 → birder-0.3.0}/birder/scripts/train.py +48 -46
  79. {birder-0.2.2 → birder-0.3.0}/birder/scripts/train_barlow_twins.py +44 -45
  80. {birder-0.2.2 → birder-0.3.0}/birder/scripts/train_byol.py +44 -45
  81. {birder-0.2.2 → birder-0.3.0}/birder/scripts/train_capi.py +50 -49
  82. {birder-0.2.2 → birder-0.3.0}/birder/scripts/train_data2vec.py +45 -47
  83. {birder-0.2.2 → birder-0.3.0}/birder/scripts/train_data2vec2.py +45 -47
  84. {birder-0.2.2 → birder-0.3.0}/birder/scripts/train_detection.py +83 -50
  85. {birder-0.2.2 → birder-0.3.0}/birder/scripts/train_dino_v1.py +60 -47
  86. {birder-0.2.2 → birder-0.3.0}/birder/scripts/train_dino_v2.py +86 -52
  87. {birder-0.2.2 → birder-0.3.0}/birder/scripts/train_dino_v2_dist.py +84 -50
  88. {birder-0.2.2 → birder-0.3.0}/birder/scripts/train_franca.py +51 -52
  89. {birder-0.2.2 → birder-0.3.0}/birder/scripts/train_i_jepa.py +45 -47
  90. {birder-0.2.2 → birder-0.3.0}/birder/scripts/train_ibot.py +51 -53
  91. {birder-0.2.2 → birder-0.3.0}/birder/scripts/train_kd.py +194 -76
  92. {birder-0.2.2 → birder-0.3.0}/birder/scripts/train_mim.py +44 -45
  93. {birder-0.2.2 → birder-0.3.0}/birder/scripts/train_mmcr.py +44 -45
  94. {birder-0.2.2 → birder-0.3.0}/birder/scripts/train_rotnet.py +45 -46
  95. {birder-0.2.2 → birder-0.3.0}/birder/scripts/train_simclr.py +44 -45
  96. {birder-0.2.2 → birder-0.3.0}/birder/scripts/train_vicreg.py +44 -45
  97. {birder-0.2.2 → birder-0.3.0}/birder/tools/auto_anchors.py +20 -1
  98. {birder-0.2.2 → birder-0.3.0}/birder/tools/convert_model.py +18 -15
  99. birder-0.3.0/birder/tools/det_results.py +173 -0
  100. {birder-0.2.2 → birder-0.3.0}/birder/tools/pack.py +172 -103
  101. birder-0.3.0/birder/tools/quantize_model.py +162 -0
  102. {birder-0.2.2 → birder-0.3.0}/birder/tools/show_det_iterator.py +10 -1
  103. birder-0.3.0/birder/version.py +1 -0
  104. {birder-0.2.2 → birder-0.3.0}/birder.egg-info/PKG-INFO +4 -3
  105. {birder-0.2.2 → birder-0.3.0}/birder.egg-info/SOURCES.txt +7 -0
  106. {birder-0.2.2 → birder-0.3.0}/birder.egg-info/requires.txt +2 -1
  107. {birder-0.2.2 → birder-0.3.0}/requirements/_requirements-dev.txt +3 -1
  108. {birder-0.2.2 → birder-0.3.0}/tests/test_common.py +271 -16
  109. birder-0.3.0/tests/test_dataloaders.py +101 -0
  110. {birder-0.2.2 → birder-0.3.0}/tests/test_inference.py +69 -0
  111. {birder-0.2.2 → birder-0.3.0}/tests/test_kernels.py +13 -0
  112. {birder-0.2.2 → birder-0.3.0}/tests/test_model_registry.py +2 -2
  113. {birder-0.2.2 → birder-0.3.0}/tests/test_net.py +274 -177
  114. {birder-0.2.2 → birder-0.3.0}/tests/test_net_detection.py +44 -0
  115. {birder-0.2.2 → birder-0.3.0}/tests/test_transforms.py +9 -0
  116. birder-0.2.2/birder/ops/msda.py +0 -138
  117. birder-0.2.2/birder/ops/swattention.py +0 -225
  118. birder-0.2.2/birder/tools/det_results.py +0 -61
  119. birder-0.2.2/birder/tools/quantize_model.py +0 -156
  120. birder-0.2.2/birder/version.py +0 -1
  121. {birder-0.2.2 → birder-0.3.0}/LICENSE +0 -0
  122. {birder-0.2.2 → birder-0.3.0}/birder/__init__.py +0 -0
  123. {birder-0.2.2 → birder-0.3.0}/birder/adversarial/__init__.py +0 -0
  124. {birder-0.2.2 → birder-0.3.0}/birder/adversarial/base.py +0 -0
  125. {birder-0.2.2 → birder-0.3.0}/birder/adversarial/deepfool.py +0 -0
  126. {birder-0.2.2 → birder-0.3.0}/birder/adversarial/fgsm.py +0 -0
  127. {birder-0.2.2 → birder-0.3.0}/birder/adversarial/pgd.py +0 -0
  128. {birder-0.2.2 → birder-0.3.0}/birder/adversarial/simba.py +0 -0
  129. {birder-0.2.2 → birder-0.3.0}/birder/common/__init__.py +0 -0
  130. {birder-0.2.2 → birder-0.3.0}/birder/common/cli.py +0 -0
  131. {birder-0.2.2 → birder-0.3.0}/birder/common/fs_ops.py +0 -0
  132. {birder-0.2.2 → birder-0.3.0}/birder/common/masking.py +0 -0
  133. {birder-0.2.2 → birder-0.3.0}/birder/conf/__init__.py +0 -0
  134. {birder-0.2.2 → birder-0.3.0}/birder/conf/settings.py +0 -0
  135. {birder-0.2.2 → birder-0.3.0}/birder/data/__init__.py +0 -0
  136. {birder-0.2.2 → birder-0.3.0}/birder/data/collators/__init__.py +0 -0
  137. {birder-0.2.2 → birder-0.3.0}/birder/data/dataloader/__init__.py +0 -0
  138. {birder-0.2.2 → birder-0.3.0}/birder/data/datasets/__init__.py +0 -0
  139. {birder-0.2.2 → birder-0.3.0}/birder/data/datasets/directory.py +0 -0
  140. {birder-0.2.2 → birder-0.3.0}/birder/data/datasets/fake.py +0 -0
  141. {birder-0.2.2 → birder-0.3.0}/birder/data/datasets/webdataset.py +0 -0
  142. {birder-0.2.2 → birder-0.3.0}/birder/data/transforms/__init__.py +0 -0
  143. {birder-0.2.2 → birder-0.3.0}/birder/data/transforms/classification.py +0 -0
  144. {birder-0.2.2 → birder-0.3.0}/birder/data/transforms/mosaic.py +0 -0
  145. {birder-0.2.2 → birder-0.3.0}/birder/datahub/__init__.py +0 -0
  146. {birder-0.2.2 → birder-0.3.0}/birder/datahub/_lib.py +0 -0
  147. {birder-0.2.2 → birder-0.3.0}/birder/datahub/classification.py +0 -0
  148. {birder-0.2.2 → birder-0.3.0}/birder/inference/__init__.py +0 -0
  149. {birder-0.2.2 → birder-0.3.0}/birder/inference/classification.py +0 -0
  150. {birder-0.2.2 → birder-0.3.0}/birder/inference/data_parallel.py +0 -0
  151. {birder-0.2.2 → birder-0.3.0}/birder/introspection/__init__.py +0 -0
  152. {birder-0.2.2 → birder-0.3.0}/birder/introspection/attention_rollout.py +0 -0
  153. {birder-0.2.2 → birder-0.3.0}/birder/introspection/base.py +0 -0
  154. {birder-0.2.2 → birder-0.3.0}/birder/introspection/gradcam.py +0 -0
  155. {birder-0.2.2 → birder-0.3.0}/birder/introspection/guided_backprop.py +0 -0
  156. {birder-0.2.2 → birder-0.3.0}/birder/introspection/transformer_attribution.py +0 -0
  157. {birder-0.2.2 → birder-0.3.0}/birder/kernels/__init__.py +0 -0
  158. {birder-0.2.2 → birder-0.3.0}/birder/kernels/deformable_detr/cpu/ms_deform_attn_cpu.cpp +0 -0
  159. {birder-0.2.2 → birder-0.3.0}/birder/kernels/deformable_detr/cpu/ms_deform_attn_cpu.h +0 -0
  160. {birder-0.2.2 → birder-0.3.0}/birder/kernels/deformable_detr/cuda/ms_deform_attn_cuda.cu +0 -0
  161. {birder-0.2.2 → birder-0.3.0}/birder/kernels/deformable_detr/cuda/ms_deform_attn_cuda.h +0 -0
  162. {birder-0.2.2 → birder-0.3.0}/birder/kernels/deformable_detr/cuda/ms_deform_im2col_cuda.cuh +0 -0
  163. {birder-0.2.2 → birder-0.3.0}/birder/kernels/deformable_detr/ms_deform_attn.h +0 -0
  164. {birder-0.2.2 → birder-0.3.0}/birder/kernels/deformable_detr/vision.cpp +0 -0
  165. {birder-0.2.2 → birder-0.3.0}/birder/kernels/soft_nms/op.cpp +0 -0
  166. {birder-0.2.2 → birder-0.3.0}/birder/kernels/soft_nms/soft_nms.h +0 -0
  167. {birder-0.2.2 → birder-0.3.0}/birder/kernels/transnext/av_bw_kernel.cu +0 -0
  168. {birder-0.2.2 → birder-0.3.0}/birder/kernels/transnext/av_fw_kernel.cu +0 -0
  169. {birder-0.2.2 → birder-0.3.0}/birder/kernels/transnext/qk_bw_kernel.cu +0 -0
  170. {birder-0.2.2 → birder-0.3.0}/birder/kernels/transnext/qk_fw_kernel.cu +0 -0
  171. {birder-0.2.2 → birder-0.3.0}/birder/kernels/transnext/qk_rpb_bw_kernel.cu +0 -0
  172. {birder-0.2.2 → birder-0.3.0}/birder/kernels/transnext/qk_rpb_fw_kernel.cu +0 -0
  173. {birder-0.2.2 → birder-0.3.0}/birder/kernels/transnext/swattention.cpp +0 -0
  174. {birder-0.2.2 → birder-0.3.0}/birder/layers/__init__.py +0 -0
  175. {birder-0.2.2 → birder-0.3.0}/birder/layers/activations.py +0 -0
  176. {birder-0.2.2 → birder-0.3.0}/birder/layers/attention_pool.py +0 -0
  177. {birder-0.2.2 → birder-0.3.0}/birder/layers/ffn.py +0 -0
  178. {birder-0.2.2 → birder-0.3.0}/birder/layers/gem.py +0 -0
  179. {birder-0.2.2 → birder-0.3.0}/birder/layers/layer_norm.py +0 -0
  180. {birder-0.2.2 → birder-0.3.0}/birder/layers/layer_scale.py +0 -0
  181. {birder-0.2.2 → birder-0.3.0}/birder/model_registry/__init__.py +0 -0
  182. {birder-0.2.2 → birder-0.3.0}/birder/model_registry/manifest.py +0 -0
  183. {birder-0.2.2 → birder-0.3.0}/birder/model_registry/model_registry.py +0 -0
  184. {birder-0.2.2 → birder-0.3.0}/birder/net/alexnet.py +0 -0
  185. {birder-0.2.2 → birder-0.3.0}/birder/net/base.py +0 -0
  186. {birder-0.2.2 → birder-0.3.0}/birder/net/biformer.py +0 -0
  187. {birder-0.2.2 → birder-0.3.0}/birder/net/cas_vit.py +0 -0
  188. {birder-0.2.2 → birder-0.3.0}/birder/net/coat.py +0 -0
  189. {birder-0.2.2 → birder-0.3.0}/birder/net/conv2former.py +0 -0
  190. {birder-0.2.2 → birder-0.3.0}/birder/net/convmixer.py +0 -0
  191. {birder-0.2.2 → birder-0.3.0}/birder/net/convnext_v2.py +0 -0
  192. {birder-0.2.2 → birder-0.3.0}/birder/net/cspnet.py +0 -0
  193. {birder-0.2.2 → birder-0.3.0}/birder/net/cswin_transformer.py +0 -0
  194. {birder-0.2.2 → birder-0.3.0}/birder/net/darknet.py +0 -0
  195. {birder-0.2.2 → birder-0.3.0}/birder/net/davit.py +0 -0
  196. {birder-0.2.2 → birder-0.3.0}/birder/net/densenet.py +0 -0
  197. {birder-0.2.2 → birder-0.3.0}/birder/net/detection/__init__.py +0 -0
  198. {birder-0.2.2 → birder-0.3.0}/birder/net/detection/base.py +0 -0
  199. {birder-0.2.2 → birder-0.3.0}/birder/net/detection/faster_rcnn.py +0 -0
  200. {birder-0.2.2 → birder-0.3.0}/birder/net/detection/ssd.py +0 -0
  201. {birder-0.2.2 → birder-0.3.0}/birder/net/detection/ssdlite.py +0 -0
  202. {birder-0.2.2 → birder-0.3.0}/birder/net/detection/vitdet.py +0 -0
  203. {birder-0.2.2 → birder-0.3.0}/birder/net/dpn.py +0 -0
  204. {birder-0.2.2 → birder-0.3.0}/birder/net/edgenext.py +0 -0
  205. {birder-0.2.2 → birder-0.3.0}/birder/net/edgevit.py +0 -0
  206. {birder-0.2.2 → birder-0.3.0}/birder/net/efficientnet_lite.py +0 -0
  207. {birder-0.2.2 → birder-0.3.0}/birder/net/efficientnet_v1.py +0 -0
  208. {birder-0.2.2 → birder-0.3.0}/birder/net/efficientnet_v2.py +0 -0
  209. {birder-0.2.2 → birder-0.3.0}/birder/net/efficientvim.py +0 -0
  210. {birder-0.2.2 → birder-0.3.0}/birder/net/efficientvit_mit.py +0 -0
  211. {birder-0.2.2 → birder-0.3.0}/birder/net/focalnet.py +0 -0
  212. {birder-0.2.2 → birder-0.3.0}/birder/net/ghostnet_v1.py +0 -0
  213. {birder-0.2.2 → birder-0.3.0}/birder/net/ghostnet_v2.py +0 -0
  214. {birder-0.2.2 → birder-0.3.0}/birder/net/groupmixformer.py +0 -0
  215. {birder-0.2.2 → birder-0.3.0}/birder/net/hgnet_v1.py +0 -0
  216. {birder-0.2.2 → birder-0.3.0}/birder/net/hgnet_v2.py +0 -0
  217. {birder-0.2.2 → birder-0.3.0}/birder/net/hieradet.py +0 -0
  218. {birder-0.2.2 → birder-0.3.0}/birder/net/inception_next.py +0 -0
  219. {birder-0.2.2 → birder-0.3.0}/birder/net/inception_resnet_v1.py +0 -0
  220. {birder-0.2.2 → birder-0.3.0}/birder/net/inception_resnet_v2.py +0 -0
  221. {birder-0.2.2 → birder-0.3.0}/birder/net/inception_v3.py +0 -0
  222. {birder-0.2.2 → birder-0.3.0}/birder/net/inception_v4.py +0 -0
  223. {birder-0.2.2 → birder-0.3.0}/birder/net/metaformer.py +0 -0
  224. {birder-0.2.2 → birder-0.3.0}/birder/net/mim/__init__.py +0 -0
  225. {birder-0.2.2 → birder-0.3.0}/birder/net/mim/base.py +0 -0
  226. {birder-0.2.2 → birder-0.3.0}/birder/net/mim/crossmae.py +0 -0
  227. {birder-0.2.2 → birder-0.3.0}/birder/net/mim/fcmae.py +0 -0
  228. {birder-0.2.2 → birder-0.3.0}/birder/net/mim/mae_hiera.py +0 -0
  229. {birder-0.2.2 → birder-0.3.0}/birder/net/mim/mae_vit.py +0 -0
  230. {birder-0.2.2 → birder-0.3.0}/birder/net/mim/simmim.py +0 -0
  231. {birder-0.2.2 → birder-0.3.0}/birder/net/mnasnet.py +0 -0
  232. {birder-0.2.2 → birder-0.3.0}/birder/net/mobilenet_v1.py +0 -0
  233. {birder-0.2.2 → birder-0.3.0}/birder/net/mobilenet_v2.py +0 -0
  234. {birder-0.2.2 → birder-0.3.0}/birder/net/mobilenet_v3_large.py +0 -0
  235. {birder-0.2.2 → birder-0.3.0}/birder/net/mobilenet_v3_small.py +0 -0
  236. {birder-0.2.2 → birder-0.3.0}/birder/net/mobilenet_v4.py +0 -0
  237. {birder-0.2.2 → birder-0.3.0}/birder/net/mobilevit_v1.py +0 -0
  238. {birder-0.2.2 → birder-0.3.0}/birder/net/mobilevit_v2.py +0 -0
  239. {birder-0.2.2 → birder-0.3.0}/birder/net/moganet.py +0 -0
  240. {birder-0.2.2 → birder-0.3.0}/birder/net/nextvit.py +0 -0
  241. {birder-0.2.2 → birder-0.3.0}/birder/net/nfnet.py +0 -0
  242. {birder-0.2.2 → birder-0.3.0}/birder/net/pvt_v2.py +0 -0
  243. {birder-0.2.2 → birder-0.3.0}/birder/net/rdnet.py +0 -0
  244. {birder-0.2.2 → birder-0.3.0}/birder/net/regionvit.py +0 -0
  245. {birder-0.2.2 → birder-0.3.0}/birder/net/regnet.py +0 -0
  246. {birder-0.2.2 → birder-0.3.0}/birder/net/regnet_z.py +0 -0
  247. {birder-0.2.2 → birder-0.3.0}/birder/net/resmlp.py +0 -0
  248. {birder-0.2.2 → birder-0.3.0}/birder/net/resnest.py +0 -0
  249. {birder-0.2.2 → birder-0.3.0}/birder/net/resnet_v2.py +0 -0
  250. {birder-0.2.2 → birder-0.3.0}/birder/net/se_resnet_v2.py +0 -0
  251. {birder-0.2.2 → birder-0.3.0}/birder/net/sequencer2d.py +0 -0
  252. {birder-0.2.2 → birder-0.3.0}/birder/net/shufflenet_v1.py +0 -0
  253. {birder-0.2.2 → birder-0.3.0}/birder/net/shufflenet_v2.py +0 -0
  254. {birder-0.2.2 → birder-0.3.0}/birder/net/smt.py +0 -0
  255. {birder-0.2.2 → birder-0.3.0}/birder/net/squeezenet.py +0 -0
  256. {birder-0.2.2 → birder-0.3.0}/birder/net/squeezenext.py +0 -0
  257. {birder-0.2.2 → birder-0.3.0}/birder/net/ssl/__init__.py +0 -0
  258. {birder-0.2.2 → birder-0.3.0}/birder/net/ssl/barlow_twins.py +0 -0
  259. {birder-0.2.2 → birder-0.3.0}/birder/net/ssl/base.py +0 -0
  260. {birder-0.2.2 → birder-0.3.0}/birder/net/ssl/byol.py +0 -0
  261. {birder-0.2.2 → birder-0.3.0}/birder/net/ssl/capi.py +0 -0
  262. {birder-0.2.2 → birder-0.3.0}/birder/net/ssl/data2vec.py +0 -0
  263. {birder-0.2.2 → birder-0.3.0}/birder/net/ssl/data2vec2.py +0 -0
  264. {birder-0.2.2 → birder-0.3.0}/birder/net/ssl/dino_v1.py +0 -0
  265. {birder-0.2.2 → birder-0.3.0}/birder/net/ssl/dino_v2.py +0 -0
  266. {birder-0.2.2 → birder-0.3.0}/birder/net/ssl/franca.py +0 -0
  267. {birder-0.2.2 → birder-0.3.0}/birder/net/ssl/i_jepa.py +0 -0
  268. {birder-0.2.2 → birder-0.3.0}/birder/net/ssl/ibot.py +0 -0
  269. {birder-0.2.2 → birder-0.3.0}/birder/net/ssl/mmcr.py +0 -0
  270. {birder-0.2.2 → birder-0.3.0}/birder/net/ssl/simclr.py +0 -0
  271. {birder-0.2.2 → birder-0.3.0}/birder/net/ssl/sscd.py +0 -0
  272. {birder-0.2.2 → birder-0.3.0}/birder/net/ssl/vicreg.py +0 -0
  273. {birder-0.2.2 → birder-0.3.0}/birder/net/starnet.py +0 -0
  274. {birder-0.2.2 → birder-0.3.0}/birder/net/swiftformer.py +0 -0
  275. {birder-0.2.2 → birder-0.3.0}/birder/net/uniformer.py +0 -0
  276. {birder-0.2.2 → birder-0.3.0}/birder/net/van.py +0 -0
  277. {birder-0.2.2 → birder-0.3.0}/birder/net/vgg.py +0 -0
  278. {birder-0.2.2 → birder-0.3.0}/birder/net/vgg_reduced.py +0 -0
  279. {birder-0.2.2 → birder-0.3.0}/birder/net/wide_resnet.py +0 -0
  280. {birder-0.2.2 → birder-0.3.0}/birder/net/xception.py +0 -0
  281. {birder-0.2.2 → birder-0.3.0}/birder/net/xcit.py +0 -0
  282. {birder-0.2.2 → birder-0.3.0}/birder/ops/__init__.py +0 -0
  283. {birder-0.2.2 → birder-0.3.0}/birder/ops/soft_nms.py +0 -0
  284. {birder-0.2.2 → birder-0.3.0}/birder/optim/__init__.py +0 -0
  285. {birder-0.2.2 → birder-0.3.0}/birder/optim/lamb.py +0 -0
  286. {birder-0.2.2 → birder-0.3.0}/birder/optim/lars.py +0 -0
  287. {birder-0.2.2 → birder-0.3.0}/birder/py.typed +0 -0
  288. {birder-0.2.2 → birder-0.3.0}/birder/results/__init__.py +0 -0
  289. {birder-0.2.2 → birder-0.3.0}/birder/results/classification.py +0 -0
  290. {birder-0.2.2 → birder-0.3.0}/birder/results/gui.py +0 -0
  291. {birder-0.2.2 → birder-0.3.0}/birder/scheduler/__init__.py +0 -0
  292. {birder-0.2.2 → birder-0.3.0}/birder/scheduler/cooldown.py +0 -0
  293. {birder-0.2.2 → birder-0.3.0}/birder/scripts/__init__.py +0 -0
  294. {birder-0.2.2 → birder-0.3.0}/birder/scripts/__main__.py +0 -0
  295. {birder-0.2.2 → birder-0.3.0}/birder/scripts/evaluate.py +0 -0
  296. {birder-0.2.2 → birder-0.3.0}/birder/tools/__init__.py +0 -0
  297. {birder-0.2.2 → birder-0.3.0}/birder/tools/__main__.py +0 -0
  298. {birder-0.2.2 → birder-0.3.0}/birder/tools/adversarial.py +0 -0
  299. {birder-0.2.2 → birder-0.3.0}/birder/tools/avg_model.py +0 -0
  300. {birder-0.2.2 → birder-0.3.0}/birder/tools/download_model.py +0 -0
  301. {birder-0.2.2 → birder-0.3.0}/birder/tools/ensemble_model.py +0 -0
  302. {birder-0.2.2 → birder-0.3.0}/birder/tools/introspection.py +0 -0
  303. {birder-0.2.2 → birder-0.3.0}/birder/tools/labelme_to_coco.py +0 -0
  304. {birder-0.2.2 → birder-0.3.0}/birder/tools/list_models.py +0 -0
  305. {birder-0.2.2 → birder-0.3.0}/birder/tools/model_info.py +0 -0
  306. {birder-0.2.2 → birder-0.3.0}/birder/tools/results.py +0 -0
  307. {birder-0.2.2 → birder-0.3.0}/birder/tools/show_iterator.py +0 -0
  308. {birder-0.2.2 → birder-0.3.0}/birder/tools/similarity.py +0 -0
  309. {birder-0.2.2 → birder-0.3.0}/birder/tools/stats.py +0 -0
  310. {birder-0.2.2 → birder-0.3.0}/birder/tools/verify_coco.py +0 -0
  311. {birder-0.2.2 → birder-0.3.0}/birder/tools/verify_directory.py +0 -0
  312. {birder-0.2.2 → birder-0.3.0}/birder/tools/voc_to_coco.py +0 -0
  313. {birder-0.2.2 → birder-0.3.0}/birder.egg-info/dependency_links.txt +0 -0
  314. {birder-0.2.2 → birder-0.3.0}/birder.egg-info/entry_points.txt +0 -0
  315. {birder-0.2.2 → birder-0.3.0}/birder.egg-info/top_level.txt +0 -0
  316. {birder-0.2.2 → birder-0.3.0}/pyproject.toml +0 -0
  317. {birder-0.2.2 → birder-0.3.0}/requirements/requirements-hf.txt +0 -0
  318. {birder-0.2.2 → birder-0.3.0}/requirements/requirements.txt +0 -0
  319. {birder-0.2.2 → birder-0.3.0}/setup.cfg +0 -0
  320. {birder-0.2.2 → birder-0.3.0}/tests/test_adversarial.py +0 -0
  321. {birder-0.2.2 → birder-0.3.0}/tests/test_collators.py +0 -0
  322. {birder-0.2.2 → birder-0.3.0}/tests/test_datasets.py +0 -0
  323. {birder-0.2.2 → birder-0.3.0}/tests/test_introspection.py +0 -0
  324. {birder-0.2.2 → birder-0.3.0}/tests/test_layers.py +0 -0
  325. {birder-0.2.2 → birder-0.3.0}/tests/test_net_mim.py +0 -0
  326. {birder-0.2.2 → birder-0.3.0}/tests/test_net_ssl.py +0 -0
  327. {birder-0.2.2 → birder-0.3.0}/tests/test_ops.py +0 -0
  328. {birder-0.2.2 → birder-0.3.0}/tests/test_optim.py +0 -0
  329. {birder-0.2.2 → birder-0.3.0}/tests/test_results.py +0 -0
  330. {birder-0.2.2 → birder-0.3.0}/tests/test_scheduler.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: birder
3
- Version: 0.2.2
3
+ Version: 0.3.0
4
4
  Summary: An open-source computer vision framework for wildlife image analysis, featuring state-of-the-art models for species classification and detection.
5
5
  Author: Ofer Hasson
6
6
  License-Expression: Apache-2.0
@@ -45,7 +45,7 @@ Provides-Extra: dev
45
45
  Requires-Dist: altair~=5.5.0; extra == "dev"
46
46
  Requires-Dist: bandit~=1.9.2; extra == "dev"
47
47
  Requires-Dist: black~=25.12.0; extra == "dev"
48
- Requires-Dist: build~=1.3.0; extra == "dev"
48
+ Requires-Dist: build~=1.4.0; extra == "dev"
49
49
  Requires-Dist: bumpver~=2025.1131; extra == "dev"
50
50
  Requires-Dist: captum~=0.7.0; extra == "dev"
51
51
  Requires-Dist: coverage~=7.13.1; extra == "dev"
@@ -66,6 +66,7 @@ Requires-Dist: pytest; extra == "dev"
66
66
  Requires-Dist: requests~=2.32.5; extra == "dev"
67
67
  Requires-Dist: safetensors~=0.7.0; extra == "dev"
68
68
  Requires-Dist: setuptools; extra == "dev"
69
+ Requires-Dist: torchao~=0.15.0; extra == "dev"
69
70
  Requires-Dist: torchprofile==0.0.4; extra == "dev"
70
71
  Requires-Dist: twine~=6.2.0; extra == "dev"
71
72
  Requires-Dist: types-requests~=2.32.4; extra == "dev"
@@ -208,7 +209,7 @@ For detailed information about these datasets, including descriptions, citations
208
209
 
209
210
  ## Detection
210
211
 
211
- Detection training and inference are available, see [docs/training_detection.md](docs/training_detection.md) and
212
+ Detection training and inference are available, see [docs/training_scripts.md](docs/training_scripts.md) and
212
213
  [docs/inference.md](docs/inference.md). APIs and model coverage may evolve as detection support matures.
213
214
 
214
215
  ## Project Status and Contributions
@@ -129,7 +129,7 @@ For detailed information about these datasets, including descriptions, citations
129
129
 
130
130
  ## Detection
131
131
 
132
- Detection training and inference are available, see [docs/training_detection.md](docs/training_detection.md) and
132
+ Detection training and inference are available, see [docs/training_scripts.md](docs/training_scripts.md) and
133
133
  [docs/inference.md](docs/inference.md). APIs and model coverage may evolve as detection support matures.
134
134
 
135
135
  ## Project Status and Contributions
@@ -1,11 +1,7 @@
1
1
  import os
2
- import random
3
2
  from typing import Any
4
3
  from typing import Optional
5
4
 
6
- import numpy as np
7
- import torch
8
-
9
5
  from birder.conf import settings
10
6
  from birder.data.transforms.classification import RGBType
11
7
  from birder.model_registry import registry
@@ -19,11 +15,8 @@ from birder.net.ssl.base import SSLBaseNet
19
15
  from birder.version import __version__
20
16
 
21
17
 
22
- def set_random_seeds(seed: int) -> None:
23
- torch.manual_seed(seed)
24
- torch.cuda.manual_seed_all(seed)
25
- np.random.seed(seed)
26
- random.seed(seed)
18
+ def env_bool(name: str) -> bool:
19
+ return os.environ.get(name, "").lower() in {"1", "true", "yes", "on"}
27
20
 
28
21
 
29
22
  def get_size_from_signature(signature: SignatureType | DetectionSignatureType) -> tuple[int, int]:
@@ -5,6 +5,7 @@ import typing
5
5
  from typing import Optional
6
6
  from typing import get_args
7
7
 
8
+ from birder.common.cli import FlexibleDictAction
8
9
  from birder.common.cli import ValidationError
9
10
  from birder.common.training_utils import OptimizerType
10
11
  from birder.common.training_utils import SchedulerType
@@ -82,11 +83,23 @@ def add_lr_wd_args(parser: argparse.ArgumentParser, backbone_lr: bool = False, w
82
83
  metavar="WD",
83
84
  help="weight decay for embedding parameters for vision transformer models",
84
85
  )
86
+ group.add_argument(
87
+ "--custom-layer-wd",
88
+ action=FlexibleDictAction,
89
+ metavar="LAYER=WD",
90
+ help="custom weight decay for specific layers by name (e.g., offset_conv=0.0)",
91
+ )
85
92
  group.add_argument("--layer-decay", type=float, help="layer-wise learning rate decay (LLRD)")
86
93
  group.add_argument("--layer-decay-min-scale", type=float, help="minimum layer scale factor clamp value")
87
94
  group.add_argument(
88
95
  "--layer-decay-no-opt-scale", type=float, help="layer scale threshold below which parameters are frozen"
89
96
  )
97
+ group.add_argument(
98
+ "--custom-layer-lr-scale",
99
+ action=FlexibleDictAction,
100
+ metavar="LAYER=SCALE",
101
+ help="custom lr_scale for specific layers by name (e.g., offset_conv=0.01,attention=0.5)",
102
+ )
90
103
 
91
104
 
92
105
  def add_lr_scheduler_args(parser: argparse.ArgumentParser) -> None:
@@ -185,6 +198,11 @@ def add_detection_input_args(parser: argparse.ArgumentParser) -> None:
185
198
  action="store_true",
186
199
  help="enable random square resize once per batch (capped by max(--size))",
187
200
  )
201
+ group.add_argument(
202
+ "--multiscale-min-size",
203
+ type=int,
204
+ help="minimum short-edge size for multiscale lists (rounded up to nearest multiple of 32)",
205
+ )
188
206
 
189
207
 
190
208
  def add_training_schedule_args(parser: argparse.ArgumentParser, default_epochs: int = 100) -> None:
@@ -193,6 +211,12 @@ def add_training_schedule_args(parser: argparse.ArgumentParser, default_epochs:
193
211
  group.add_argument(
194
212
  "--stop-epoch", type=int, metavar="N", help="epoch to stop the training at (multi stage training)"
195
213
  )
214
+ group.add_argument(
215
+ "--steps-per-epoch",
216
+ type=int,
217
+ metavar="N",
218
+ help="virtual epoch length in steps, leave unset to use the full dataset",
219
+ )
196
220
  group.add_argument("--warmup-epochs", type=int, metavar="N", help="number of warmup epochs")
197
221
  group.add_argument("--warmup-steps", type=int, metavar="N", help="number of warmup optimizer steps")
198
222
  group.add_argument("--cooldown-epochs", type=int, metavar="N", help="number of cooldown epochs (linear to zero)")
@@ -3,8 +3,10 @@ import contextlib
3
3
  import logging
4
4
  import math
5
5
  import os
6
+ import random
6
7
  import re
7
8
  import subprocess
9
+ import sys
8
10
  from collections import deque
9
11
  from collections.abc import Callable
10
12
  from collections.abc import Generator
@@ -15,6 +17,7 @@ from typing import Any
15
17
  from typing import Literal
16
18
  from typing import Optional
17
19
  from typing import Sized
20
+ from typing import overload
18
21
 
19
22
  import numpy as np
20
23
  import torch
@@ -29,12 +32,25 @@ from birder.data.transforms.classification import training_preset
29
32
  from birder.optim import Lamb
30
33
  from birder.optim import Lars
31
34
  from birder.scheduler import CooldownLR
35
+ from birder.version import __version__ as birder_version
32
36
 
33
37
  logger = logging.getLogger(__name__)
34
38
 
35
39
  OptimizerType = Literal["sgd", "rmsprop", "adam", "adamw", "nadam", "nadamw", "lamb", "lambw", "lars"]
36
40
  SchedulerType = Literal["constant", "step", "multistep", "cosine", "polynomial"]
37
41
 
42
+ ###############################################################################
43
+ # Core Utilities
44
+ ###############################################################################
45
+
46
+
47
+ def set_random_seeds(seed: int) -> None:
48
+ torch.manual_seed(seed)
49
+ torch.cuda.manual_seed_all(seed)
50
+ np.random.seed(seed)
51
+ random.seed(seed)
52
+
53
+
38
54
  ###############################################################################
39
55
  # Data Sampling
40
56
  ###############################################################################
@@ -55,13 +71,7 @@ class RASampler(torch.utils.data.Sampler):
55
71
  """
56
72
 
57
73
  def __init__(
58
- self,
59
- dataset: Sized,
60
- num_replicas: int,
61
- rank: int,
62
- shuffle: bool,
63
- seed: int = 0,
64
- repetitions: int = 3,
74
+ self, dataset: Sized, num_replicas: int, rank: int, shuffle: bool, seed: int = 0, repetitions: int = 3
65
75
  ) -> None:
66
76
  super().__init__()
67
77
  self.dataset = dataset
@@ -70,12 +80,11 @@ class RASampler(torch.utils.data.Sampler):
70
80
  self.epoch = 0
71
81
  self.num_samples = int(math.ceil(len(self.dataset) * float(repetitions) / self.num_replicas))
72
82
  self.total_size = self.num_samples * self.num_replicas
73
- self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
74
83
  self.shuffle = shuffle
75
84
  self.seed = seed
76
85
  self.repetitions = repetitions
77
86
 
78
- def __iter__(self) -> Iterator[list[int]]:
87
+ def __iter__(self) -> Iterator[int]:
79
88
  if self.shuffle is True:
80
89
  # Deterministically shuffle based on epoch
81
90
  g = torch.Generator()
@@ -85,18 +94,148 @@ class RASampler(torch.utils.data.Sampler):
85
94
  indices = list(range(len(self.dataset)))
86
95
 
87
96
  # Add extra samples to make it evenly divisible
88
- indices = [ele for ele in indices for i in range(self.repetitions)]
89
- indices += indices[: (self.total_size - len(indices))]
90
- assert len(indices) == self.total_size
97
+ indices = [ele for ele in indices for _ in range(self.repetitions)]
98
+ if len(indices) < self.total_size:
99
+ indices += indices[: (self.total_size - len(indices))]
100
+ else:
101
+ indices = indices[: self.total_size]
91
102
 
92
- # Subsample
103
+ # Shard by rank
93
104
  indices = indices[self.rank : self.total_size : self.num_replicas]
94
105
  assert len(indices) == self.num_samples
95
106
 
96
- return iter(indices[: self.num_selected_samples])
107
+ yield from indices
108
+
109
def __len__(self) -> int:
    """Return the number of indices this rank yields per epoch (``num_samples``, after repetition and padding)."""
    return self.num_samples
111
+
112
def set_epoch(self, epoch: int) -> None:
    """Set the epoch that ``__iter__`` uses to derive its deterministic shuffle order."""
    self.epoch = epoch
114
+
115
+
116
+ class InfiniteSampler(torch.utils.data.Sampler):
117
+ """
118
+ Infinite sampler that loops indefinitely over the dataset
119
+ """
120
+
121
+ def __init__(self, dataset: Sized, shuffle: bool, seed: int = 0) -> None:
122
+ super().__init__()
123
+ self.dataset = dataset
124
+ self.shuffle = shuffle
125
+ self.seed = seed
126
+ self.epoch = 0
127
+
128
+ def __iter__(self) -> Iterator[int]:
129
+ g = torch.Generator()
130
+ while True:
131
+ if self.shuffle is True:
132
+ g.manual_seed(self.seed + self.epoch)
133
+ indices = torch.randperm(len(self.dataset), generator=g).tolist()
134
+ else:
135
+ indices = list(range(len(self.dataset)))
136
+
137
+ yield from indices
138
+
139
+ logger.debug(f"InfiniteSampler finished epoch {self.epoch}")
140
+ self.epoch += 1
141
+
142
+ def __len__(self) -> int:
143
+ return len(self.dataset)
144
+
145
+ def set_epoch(self, epoch: int) -> None:
146
+ self.epoch = epoch
147
+
148
+
149
+ class InfiniteDistributedSampler(torch.utils.data.Sampler):
150
+ """
151
+ Infinite distributed sampler that keeps a continuous shuffled stream per rank
152
+ """
153
+
154
+ def __init__(self, dataset: Sized, num_replicas: int, rank: int, shuffle: bool, seed: int = 0) -> None:
155
+ super().__init__()
156
+ self.dataset = dataset
157
+ self.num_replicas = num_replicas
158
+ self.rank = rank
159
+ self.shuffle = shuffle
160
+ self.seed = seed
161
+ self.epoch = 0
162
+ self.num_samples = int(math.ceil(len(self.dataset) / self.num_replicas))
163
+ self.total_size = self.num_samples * self.num_replicas
164
+
165
+ def __iter__(self) -> Iterator[int]:
166
+ g = torch.Generator()
167
+ while True:
168
+ if self.shuffle is True:
169
+ g.manual_seed(self.seed + self.epoch)
170
+ indices = torch.randperm(len(self.dataset), generator=g).tolist()
171
+ else:
172
+ indices = list(range(len(self.dataset)))
173
+
174
+ if len(indices) < self.total_size:
175
+ indices += indices[: (self.total_size - len(indices))]
176
+ else:
177
+ indices = indices[: self.total_size]
178
+
179
+ indices = indices[self.rank : self.total_size : self.num_replicas]
180
+ assert len(indices) == self.num_samples
181
+
182
+ yield from indices
183
+
184
+ logger.debug(f"InfiniteDistributedSampler finished epoch {self.epoch}")
185
+ self.epoch += 1
186
+
187
+ def __len__(self) -> int:
188
+ return self.num_samples
189
+
190
+ def set_epoch(self, epoch: int) -> None:
191
+ self.epoch = epoch
192
+
193
+
194
+ class InfiniteRASampler(torch.utils.data.Sampler):
195
+ """
196
+ Infinite version of the repeated augmentation sampler
197
+ """
198
+
199
+ def __init__(
200
+ self, dataset: Sized, num_replicas: int, rank: int, shuffle: bool, seed: int = 0, repetitions: int = 3
201
+ ) -> None:
202
+ super().__init__()
203
+ self.dataset = dataset
204
+ self.num_replicas = num_replicas
205
+ self.rank = rank
206
+ self.epoch = 0
207
+ self.num_samples = int(math.ceil(len(self.dataset) * float(repetitions) / self.num_replicas))
208
+ self.total_size = self.num_samples * self.num_replicas
209
+ self.shuffle = shuffle
210
+ self.seed = seed
211
+ self.repetitions = repetitions
212
+
213
+ def __iter__(self) -> Iterator[int]:
214
+ g = torch.Generator()
215
+ while True:
216
+ if self.shuffle is True:
217
+ g.manual_seed(self.seed + self.epoch)
218
+ indices = torch.randperm(len(self.dataset), generator=g).tolist()
219
+ else:
220
+ indices = list(range(len(self.dataset)))
221
+
222
+ indices = [ele for ele in indices for _ in range(self.repetitions)]
223
+ if len(indices) < self.total_size:
224
+ indices += indices[: (self.total_size - len(indices))]
225
+ else:
226
+ indices = indices[: self.total_size]
227
+
228
+ # Shard by rank
229
+ indices = indices[self.rank : self.total_size : self.num_replicas]
230
+ assert len(indices) == self.num_samples
231
+
232
+ yield from indices
233
+
234
+ logger.debug(f"InfiniteRASampler finished epoch {self.epoch}")
235
+ self.epoch += 1
97
236
 
98
237
  def __len__(self) -> int:
99
- return self.num_selected_samples
238
+ return self.num_samples
100
239
 
101
240
  def set_epoch(self, epoch: int) -> None:
102
241
  self.epoch = epoch
@@ -207,13 +346,16 @@ def count_layers(model: torch.nn.Module) -> int:
207
346
  def optimizer_parameter_groups(
208
347
  model: torch.nn.Module,
209
348
  weight_decay: float,
349
+ base_lr: float,
210
350
  norm_weight_decay: Optional[float] = None,
211
351
  custom_keys_weight_decay: Optional[list[tuple[str, float]]] = None,
352
+ custom_layer_weight_decay: Optional[dict[str, float]] = None,
212
353
  layer_decay: Optional[float] = None,
213
354
  layer_decay_min_scale: Optional[float] = None,
214
355
  layer_decay_no_opt_scale: Optional[float] = None,
215
356
  bias_lr: Optional[float] = None,
216
357
  backbone_lr: Optional[float] = None,
358
+ custom_layer_lr_scale: Optional[dict[str, float]] = None,
217
359
  ) -> list[dict[str, Any]]:
218
360
  """
219
361
  Return parameter groups for optimizers with per-parameter group weight decay.
@@ -233,11 +375,16 @@ def optimizer_parameter_groups(
233
375
  The PyTorch model whose parameters will be grouped for optimization.
234
376
  weight_decay
235
377
  Default weight decay (L2 regularization) value applied to parameters.
378
+ base_lr
379
+ Base learning rate that will be scaled by lr_scale factors for each parameter group.
236
380
  norm_weight_decay
237
381
  Weight decay value specifically for normalization layers. If None, uses weight_decay.
238
382
  custom_keys_weight_decay
239
383
  List of (parameter_name, weight_decay) tuples for applying custom weight decay
240
384
  values to specific parameters by name matching.
385
+ custom_layer_weight_decay
386
+ Dictionary mapping layer name substrings to custom weight decay values.
387
+ Applied to parameters whose names contain the specified keys.
241
388
  layer_decay
242
389
  Layer-wise learning rate decay factor.
243
390
  layer_decay_min_scale
@@ -248,6 +395,9 @@ def optimizer_parameter_groups(
248
395
  Custom learning rate for bias parameters (parameters ending with '.bias').
249
396
  backbone_lr
250
397
  Custom learning rate for backbone parameters (parameters starting with 'backbone.').
398
+ custom_layer_lr_scale
399
+ Dictionary mapping layer name substrings to custom lr_scale values.
400
+ Applied to parameters whose names contain the specified keys.
251
401
 
252
402
  Returns
253
403
  -------
@@ -291,14 +441,14 @@ def optimizer_parameter_groups(
291
441
  if layer_decay is not None:
292
442
  layer_max = num_layers - 1
293
443
  layer_scales = [max(layer_decay_min_scale, layer_decay ** (layer_max - i)) for i in range(num_layers)]
294
- logger.info(f"Layer scaling in range of {min(layer_scales)} - {max(layer_scales)} on {num_layers} layers")
444
+ logger.info(f"Layer scaling ranges from {min(layer_scales)} to {max(layer_scales)} across {num_layers} layers")
295
445
 
296
446
  # Set weight decay and layer decay
297
447
  idx = 0
298
448
  params = []
299
449
  module_stack_with_prefix = [(model, "")]
300
450
  visited_modules = []
301
- while len(module_stack_with_prefix) > 0:
451
+ while len(module_stack_with_prefix) > 0: # pylint: disable=too-many-nested-blocks
302
452
  skip_module = False
303
453
  (module, prefix) = module_stack_with_prefix.pop()
304
454
  if id(module) in visited_modules:
@@ -324,13 +474,35 @@ def optimizer_parameter_groups(
324
474
  for key, custom_wd in custom_keys_weight_decay:
325
475
  target_name_for_custom_key = f"{prefix}.{name}" if prefix != "" and "." in key else name
326
476
  if key == target_name_for_custom_key:
477
+ # Calculate lr_scale (from layer_decay or custom_layer_lr_scale)
478
+ lr_scale = 1.0 if layer_decay is None else layer_scales[idx]
479
+ if custom_layer_lr_scale is not None:
480
+ for layer_name_key, custom_scale in custom_layer_lr_scale.items():
481
+ if layer_name_key in target_name:
482
+ lr_scale = custom_scale
483
+ break
484
+
485
+ # Apply custom layer weight decay (substring matching)
486
+ wd = custom_wd
487
+ if custom_layer_weight_decay is not None:
488
+ for layer_name_key, custom_wd_value in custom_layer_weight_decay.items():
489
+ if layer_name_key in target_name:
490
+ wd = custom_wd_value
491
+ break
492
+
327
493
  d = {
328
494
  "params": p,
329
- "weight_decay": custom_wd,
330
- "lr_scale": 1.0 if layer_decay is None else layer_scales[idx],
495
+ "weight_decay": wd,
496
+ "lr_scale": lr_scale, # Used only for reference/debugging
331
497
  }
332
- if backbone_lr is not None and target_name.startswith("backbone.") is True:
498
+
499
+ # Apply learning rate based on priority: bias_lr > backbone_lr > lr_scale
500
+ if bias_lr is not None and target_name.endswith(".bias") is True:
501
+ d["lr"] = bias_lr
502
+ elif backbone_lr is not None and target_name.startswith("backbone.") is True:
333
503
  d["lr"] = backbone_lr
504
+ elif lr_scale != 1.0:
505
+ d["lr"] = base_lr * lr_scale
334
506
 
335
507
  params.append(d)
336
508
  is_custom_key = True
@@ -342,16 +514,34 @@ def optimizer_parameter_groups(
342
514
  else:
343
515
  wd = weight_decay
344
516
 
517
+ # Apply custom layer weight decay (substring matching)
518
+ if custom_layer_weight_decay is not None:
519
+ for layer_name_key, custom_wd_value in custom_layer_weight_decay.items():
520
+ if layer_name_key in target_name:
521
+ wd = custom_wd_value
522
+ break
523
+
524
+ # Calculate lr_scale (from layer_decay or custom_layer_lr_scale)
525
+ lr_scale = 1.0 if layer_decay is None else layer_scales[idx]
526
+ if custom_layer_lr_scale is not None:
527
+ for layer_name_key, custom_scale in custom_layer_lr_scale.items():
528
+ if layer_name_key in target_name:
529
+ lr_scale = custom_scale
530
+ break
531
+
345
532
  d = {
346
533
  "params": p,
347
534
  "weight_decay": wd,
348
- "lr_scale": 1.0 if layer_decay is None else layer_scales[idx],
535
+ "lr_scale": lr_scale, # Used only for reference/debugging
349
536
  }
350
- if backbone_lr is not None and target_name.startswith("backbone.") is True:
351
- d["lr"] = backbone_lr
352
537
 
538
+ # Apply learning rate based on priority: bias_lr > backbone_lr > lr_scale
353
539
  if bias_lr is not None and target_name.endswith(".bias") is True:
354
540
  d["lr"] = bias_lr
541
+ elif backbone_lr is not None and target_name.startswith("backbone.") is True:
542
+ d["lr"] = backbone_lr
543
+ elif lr_scale != 1.0:
544
+ d["lr"] = base_lr * lr_scale
355
545
 
356
546
  params.append(d)
357
547
 
@@ -442,6 +632,8 @@ def get_optimizer(parameters: list[dict[str, Any]], l_rate: float, args: argpars
442
632
  else:
443
633
  raise ValueError("Unknown optimizer")
444
634
 
635
+ logger.debug(f"Created {opt} optimizer with lr={lr}, weight_decay={args.wd}")
636
+
445
637
  return optimizer
446
638
 
447
639
 
@@ -477,10 +669,10 @@ def get_scheduler(
477
669
 
478
670
  main_steps = steps - begin_step - remaining_warmup - remaining_cooldown - 1
479
671
 
480
- logger.debug(f"Using {steps_per_epoch} steps per epoch")
672
+ logger.debug(f"Scheduler using {steps_per_epoch} steps per epoch")
481
673
  logger.debug(
482
674
  f"Scheduler {args.lr_scheduler} set for {steps} steps of which {warmup_steps} "
483
- f"are warmup and {cooldown_steps} cooldown"
675
+ f"are warmup and {cooldown_steps} are cooldown"
484
676
  )
485
677
  logger.debug(
486
678
  f"Currently starting from step {begin_step} with {remaining_warmup} remaining warmup steps "
@@ -568,27 +760,87 @@ def get_amp_scaler(amp: bool, amp_dtype_str: str) -> tuple[Optional[torch.amp.Gr
568
760
  return (scaler, amp_dtype)
569
761
 
570
762
 
763
@overload
def get_samplers(
    args: argparse.Namespace,
    training_dataset: torch.utils.data.Dataset,
    validation_dataset: torch.utils.data.Dataset,
    infinite: bool = False,
) -> tuple[torch.utils.data.Sampler, torch.utils.data.Sampler]: ...


@overload
def get_samplers(
    args: argparse.Namespace,
    training_dataset: torch.utils.data.Dataset,
    validation_dataset: None = None,
    infinite: bool = False,
) -> tuple[torch.utils.data.Sampler, None]: ...


def get_samplers(
    args: argparse.Namespace,
    training_dataset: torch.utils.data.Dataset,
    validation_dataset: Optional[torch.utils.data.Dataset] = None,
    infinite: bool = False,
) -> tuple[torch.utils.data.Sampler, Optional[torch.utils.data.Sampler]]:
    """
    Build the training sampler and, when a validation dataset is given, the validation sampler.

    Parameters
    ----------
    args
        Parsed CLI arguments. Reads ``seed`` and ``distributed``; in distributed mode also ``world_size``
        and ``rank``. ``ra_sampler`` is optional (defaults to False) and, when True, ``ra_reps`` selects
        the number of repeated-augmentation copies.
    training_dataset
        Dataset the training sampler draws from; it is always shuffled.
    validation_dataset
        Optional dataset for a sequential (never shuffled) validation sampler. When None, the returned
        validation sampler is None.
    infinite
        If True, the training sampler is one of the never-ending ``Infinite*`` samplers.

    Returns
    -------
    Tuple of (train_sampler, validation_sampler).
    """

    if args.seed is None:
        # Draw a fresh seed; in a distributed run every rank adopts rank 0's value so all ranks
        # shuffle with the same permutation and their shards stay disjoint
        seed = int(torch.empty((), dtype=torch.int64).random_().item())
        if is_dist_available_and_initialized() is True:
            # NCCL only communicates CUDA tensors; other backends (e.g. gloo for CPU-only runs,
            # where CUDA may not even be available) work with a CPU tensor
            tensor_device = "cuda" if dist.get_backend() == "nccl" else "cpu"
            seed_tensor = torch.tensor(seed, dtype=torch.int64, device=tensor_device)
            dist.broadcast(seed_tensor, src=0, async_op=False)
            seed = int(seed_tensor.item())
    else:
        seed = args.seed

    ra_sampler = getattr(args, "ra_sampler", False)
    if args.distributed is True:
        if infinite is True:
            if ra_sampler is True:
                train_sampler = InfiniteRASampler(
                    training_dataset,
                    num_replicas=args.world_size,
                    rank=args.rank,
                    shuffle=True,
                    seed=seed,
                    repetitions=args.ra_reps,
                )
            else:
                train_sampler = InfiniteDistributedSampler(
                    training_dataset, num_replicas=args.world_size, rank=args.rank, shuffle=True, seed=seed
                )
        else:
            if ra_sampler is True:
                train_sampler = RASampler(
                    training_dataset,
                    num_replicas=args.world_size,
                    rank=args.rank,
                    shuffle=True,
                    seed=seed,
                    repetitions=args.ra_reps,
                )
            else:
                train_sampler = torch.utils.data.distributed.DistributedSampler(
                    training_dataset, shuffle=True, seed=seed
                )

        if validation_dataset is None:
            validation_sampler = None
        else:
            validation_sampler = torch.utils.data.distributed.DistributedSampler(validation_dataset, shuffle=False)

    else:
        if infinite is True:
            train_sampler = InfiniteSampler(training_dataset, shuffle=True, seed=seed)
        else:
            generator = torch.Generator()
            generator.manual_seed(seed)
            train_sampler = torch.utils.data.RandomSampler(training_dataset, generator=generator)

        if validation_dataset is None:
            validation_sampler = None
        else:
            validation_sampler = torch.utils.data.SequentialSampler(validation_dataset)

    return (train_sampler, validation_sampler)
594
846
 
@@ -810,6 +1062,51 @@ def is_local_primary(args: argparse.Namespace) -> bool:
810
1062
  return args.local_rank == 0 # type: ignore[no-any-return]
811
1063
 
812
1064
 
1065
def init_training(
    args: argparse.Namespace,
    log: logging.Logger,
    *,
    cudnn_dynamic_size: bool = False,
) -> tuple[torch.device, int, bool]:
    """
    Run the shared start-of-training setup and return the runtime context.

    Steps, in order: initialize distributed mode, log the birder / PyTorch versions and git info,
    select the device, configure cuDNN / deterministic algorithms, seed the RNGs (only when
    ``args.seed`` is set), decide whether tqdm should be disabled and set autograd anomaly detection.

    Parameters
    ----------
    args
        Parsed CLI arguments. Reads ``cpu``, ``use_deterministic_algorithms``, ``seed``,
        ``non_interactive`` and ``grad_anomaly_detection`` directly; ``init_distributed_mode`` and
        ``is_local_primary`` may read further fields.
    log
        Logger used for the startup message.
    cudnn_dynamic_size
        When True (and deterministic algorithms are not requested), cuDNN is disabled instead of being
        put into benchmark mode - intended for inputs whose size varies.

    Returns
    -------
    Tuple of (device, device_id, disable_tqdm). ``device_id`` is 0 on CPU, otherwise the current CUDA
    device index.
    """

    init_distributed_mode(args)
    log.info(f"Starting training, birder version: {birder_version}, pytorch version: {torch.__version__}")
    log_git_info()

    on_cpu = args.cpu is True
    device = torch.device("cpu") if on_cpu else torch.device("cuda")
    device_id = 0 if on_cpu else torch.cuda.current_device()

    # cuDNN configuration: determinism takes precedence, then the dynamic-size opt-out, else autotuning
    if args.use_deterministic_algorithms is True:
        torch.backends.cudnn.benchmark = False
        torch.use_deterministic_algorithms(True)
    elif cudnn_dynamic_size is True:
        # Varying input sizes would trigger per-size algorithm selection, so skip cuDNN altogether
        torch.backends.cudnn.enabled = False
    else:
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = True

    if args.seed is not None:
        set_random_seeds(args.seed)

    # Progress bars only for an interactive terminal on the local primary process
    disable_tqdm = (
        args.non_interactive is True
        or is_local_primary(args) is False
        or sys.stderr.isatty() is False
    )

    # Enable or disable autograd anomaly detection
    torch.autograd.set_detect_anomaly(args.grad_anomaly_detection)

    return (device, device_id, disable_tqdm)
1108
+
1109
+
813
1110
  ###############################################################################
814
1111
  # Utility Functions
815
1112
  ###############################################################################