birder 0.2.3__tar.gz → 0.3.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (331) hide show
  1. {birder-0.2.3 → birder-0.3.1}/PKG-INFO +2 -1
  2. {birder-0.2.3 → birder-0.3.1}/birder/common/fs_ops.py +2 -2
  3. {birder-0.2.3 → birder-0.3.1}/birder/common/training_cli.py +12 -1
  4. {birder-0.2.3 → birder-0.3.1}/birder/common/training_utils.py +219 -33
  5. {birder-0.2.3 → birder-0.3.1}/birder/data/collators/detection.py +1 -0
  6. {birder-0.2.3 → birder-0.3.1}/birder/data/dataloader/webdataset.py +12 -2
  7. {birder-0.2.3 → birder-0.3.1}/birder/kernels/load_kernel.py +16 -11
  8. {birder-0.2.3 → birder-0.3.1}/birder/kernels/soft_nms/soft_nms.cpp +17 -18
  9. {birder-0.2.3 → birder-0.3.1}/birder/net/base.py +3 -3
  10. {birder-0.2.3 → birder-0.3.1}/birder/net/biformer.py +2 -2
  11. {birder-0.2.3 → birder-0.3.1}/birder/net/cait.py +4 -3
  12. {birder-0.2.3 → birder-0.3.1}/birder/net/cas_vit.py +6 -6
  13. {birder-0.2.3 → birder-0.3.1}/birder/net/coat.py +8 -8
  14. {birder-0.2.3 → birder-0.3.1}/birder/net/conv2former.py +2 -2
  15. {birder-0.2.3 → birder-0.3.1}/birder/net/convnext_v1.py +7 -2
  16. {birder-0.2.3 → birder-0.3.1}/birder/net/convnext_v2.py +2 -2
  17. {birder-0.2.3 → birder-0.3.1}/birder/net/crossformer.py +35 -32
  18. {birder-0.2.3 → birder-0.3.1}/birder/net/crossvit.py +4 -3
  19. {birder-0.2.3 → birder-0.3.1}/birder/net/cspnet.py +2 -2
  20. {birder-0.2.3 → birder-0.3.1}/birder/net/cswin_transformer.py +2 -2
  21. {birder-0.2.3 → birder-0.3.1}/birder/net/darknet.py +2 -2
  22. {birder-0.2.3 → birder-0.3.1}/birder/net/davit.py +2 -2
  23. {birder-0.2.3 → birder-0.3.1}/birder/net/deit.py +6 -6
  24. {birder-0.2.3 → birder-0.3.1}/birder/net/deit3.py +6 -6
  25. {birder-0.2.3 → birder-0.3.1}/birder/net/densenet.py +2 -2
  26. {birder-0.2.3 → birder-0.3.1}/birder/net/detection/deformable_detr.py +4 -7
  27. {birder-0.2.3 → birder-0.3.1}/birder/net/detection/detr.py +4 -7
  28. {birder-0.2.3 → birder-0.3.1}/birder/net/detection/efficientdet.py +4 -9
  29. {birder-0.2.3 → birder-0.3.1}/birder/net/detection/faster_rcnn.py +2 -2
  30. {birder-0.2.3 → birder-0.3.1}/birder/net/detection/fcos.py +4 -9
  31. {birder-0.2.3 → birder-0.3.1}/birder/net/detection/retinanet.py +4 -9
  32. {birder-0.2.3 → birder-0.3.1}/birder/net/detection/rt_detr_v1.py +5 -4
  33. {birder-0.2.3 → birder-0.3.1}/birder/net/detection/ssd.py +2 -2
  34. {birder-0.2.3 → birder-0.3.1}/birder/net/detection/ssdlite.py +2 -2
  35. {birder-0.2.3 → birder-0.3.1}/birder/net/detection/yolo_v2.py +2 -2
  36. {birder-0.2.3 → birder-0.3.1}/birder/net/detection/yolo_v3.py +2 -2
  37. {birder-0.2.3 → birder-0.3.1}/birder/net/detection/yolo_v4.py +2 -2
  38. {birder-0.2.3 → birder-0.3.1}/birder/net/edgenext.py +2 -2
  39. {birder-0.2.3 → birder-0.3.1}/birder/net/edgevit.py +1 -1
  40. {birder-0.2.3 → birder-0.3.1}/birder/net/efficientformer_v1.py +19 -13
  41. {birder-0.2.3 → birder-0.3.1}/birder/net/efficientformer_v2.py +45 -35
  42. {birder-0.2.3 → birder-0.3.1}/birder/net/efficientnet_lite.py +2 -2
  43. {birder-0.2.3 → birder-0.3.1}/birder/net/efficientnet_v1.py +2 -2
  44. {birder-0.2.3 → birder-0.3.1}/birder/net/efficientnet_v2.py +2 -2
  45. {birder-0.2.3 → birder-0.3.1}/birder/net/efficientvim.py +3 -3
  46. {birder-0.2.3 → birder-0.3.1}/birder/net/efficientvit_mit.py +2 -2
  47. {birder-0.2.3 → birder-0.3.1}/birder/net/efficientvit_msft.py +11 -9
  48. {birder-0.2.3 → birder-0.3.1}/birder/net/fasternet.py +2 -2
  49. {birder-0.2.3 → birder-0.3.1}/birder/net/fastvit.py +3 -2
  50. {birder-0.2.3 → birder-0.3.1}/birder/net/flexivit.py +11 -10
  51. {birder-0.2.3 → birder-0.3.1}/birder/net/focalnet.py +2 -2
  52. {birder-0.2.3 → birder-0.3.1}/birder/net/gc_vit.py +17 -2
  53. {birder-0.2.3 → birder-0.3.1}/birder/net/ghostnet_v1.py +2 -2
  54. {birder-0.2.3 → birder-0.3.1}/birder/net/ghostnet_v2.py +2 -2
  55. {birder-0.2.3 → birder-0.3.1}/birder/net/groupmixformer.py +2 -2
  56. {birder-0.2.3 → birder-0.3.1}/birder/net/hgnet_v1.py +2 -2
  57. {birder-0.2.3 → birder-0.3.1}/birder/net/hgnet_v2.py +2 -2
  58. {birder-0.2.3 → birder-0.3.1}/birder/net/hiera.py +14 -11
  59. {birder-0.2.3 → birder-0.3.1}/birder/net/hieradet.py +2 -2
  60. {birder-0.2.3 → birder-0.3.1}/birder/net/hornet.py +11 -9
  61. {birder-0.2.3 → birder-0.3.1}/birder/net/iformer.py +10 -8
  62. {birder-0.2.3 → birder-0.3.1}/birder/net/inception_next.py +2 -2
  63. {birder-0.2.3 → birder-0.3.1}/birder/net/inception_resnet_v1.py +2 -2
  64. {birder-0.2.3 → birder-0.3.1}/birder/net/inception_resnet_v2.py +2 -2
  65. {birder-0.2.3 → birder-0.3.1}/birder/net/inception_v3.py +2 -2
  66. {birder-0.2.3 → birder-0.3.1}/birder/net/inception_v4.py +2 -2
  67. {birder-0.2.3 → birder-0.3.1}/birder/net/levit.py +46 -34
  68. {birder-0.2.3 → birder-0.3.1}/birder/net/lit_v1.py +2 -2
  69. {birder-0.2.3 → birder-0.3.1}/birder/net/lit_v1_tiny.py +17 -2
  70. {birder-0.2.3 → birder-0.3.1}/birder/net/lit_v2.py +2 -2
  71. {birder-0.2.3 → birder-0.3.1}/birder/net/maxvit.py +69 -57
  72. {birder-0.2.3 → birder-0.3.1}/birder/net/metaformer.py +2 -2
  73. {birder-0.2.3 → birder-0.3.1}/birder/net/mnasnet.py +2 -2
  74. {birder-0.2.3 → birder-0.3.1}/birder/net/mobilenet_v1.py +2 -2
  75. {birder-0.2.3 → birder-0.3.1}/birder/net/mobilenet_v2.py +2 -2
  76. {birder-0.2.3 → birder-0.3.1}/birder/net/mobilenet_v3_large.py +2 -2
  77. {birder-0.2.3 → birder-0.3.1}/birder/net/mobilenet_v4.py +2 -2
  78. {birder-0.2.3 → birder-0.3.1}/birder/net/mobilenet_v4_hybrid.py +2 -2
  79. {birder-0.2.3 → birder-0.3.1}/birder/net/mobileone.py +3 -2
  80. {birder-0.2.3 → birder-0.3.1}/birder/net/mobilevit_v2.py +2 -2
  81. {birder-0.2.3 → birder-0.3.1}/birder/net/moganet.py +2 -2
  82. {birder-0.2.3 → birder-0.3.1}/birder/net/mvit_v2.py +15 -14
  83. {birder-0.2.3 → birder-0.3.1}/birder/net/nextvit.py +2 -2
  84. {birder-0.2.3 → birder-0.3.1}/birder/net/nfnet.py +2 -2
  85. {birder-0.2.3 → birder-0.3.1}/birder/net/pit.py +10 -9
  86. {birder-0.2.3 → birder-0.3.1}/birder/net/pvt_v1.py +6 -3
  87. {birder-0.2.3 → birder-0.3.1}/birder/net/pvt_v2.py +2 -2
  88. {birder-0.2.3 → birder-0.3.1}/birder/net/rdnet.py +2 -2
  89. {birder-0.2.3 → birder-0.3.1}/birder/net/regionvit.py +6 -6
  90. {birder-0.2.3 → birder-0.3.1}/birder/net/regnet.py +2 -2
  91. {birder-0.2.3 → birder-0.3.1}/birder/net/regnet_z.py +2 -2
  92. {birder-0.2.3 → birder-0.3.1}/birder/net/repghost.py +3 -2
  93. {birder-0.2.3 → birder-0.3.1}/birder/net/repvgg.py +3 -2
  94. {birder-0.2.3 → birder-0.3.1}/birder/net/repvit.py +7 -6
  95. {birder-0.2.3 → birder-0.3.1}/birder/net/resnest.py +2 -2
  96. {birder-0.2.3 → birder-0.3.1}/birder/net/resnet_v1.py +2 -2
  97. {birder-0.2.3 → birder-0.3.1}/birder/net/resnet_v2.py +2 -2
  98. {birder-0.2.3 → birder-0.3.1}/birder/net/resnext.py +2 -2
  99. {birder-0.2.3 → birder-0.3.1}/birder/net/rope_deit3.py +8 -6
  100. {birder-0.2.3 → birder-0.3.1}/birder/net/rope_flexivit.py +13 -10
  101. {birder-0.2.3 → birder-0.3.1}/birder/net/rope_vit.py +30 -11
  102. {birder-0.2.3 → birder-0.3.1}/birder/net/shufflenet_v1.py +2 -2
  103. {birder-0.2.3 → birder-0.3.1}/birder/net/shufflenet_v2.py +2 -2
  104. {birder-0.2.3 → birder-0.3.1}/birder/net/simple_vit.py +9 -6
  105. {birder-0.2.3 → birder-0.3.1}/birder/net/smt.py +1 -1
  106. {birder-0.2.3 → birder-0.3.1}/birder/net/squeezenext.py +2 -2
  107. {birder-0.2.3 → birder-0.3.1}/birder/net/ssl/byol.py +3 -2
  108. {birder-0.2.3 → birder-0.3.1}/birder/net/ssl/capi.py +156 -11
  109. {birder-0.2.3 → birder-0.3.1}/birder/net/ssl/data2vec.py +3 -1
  110. {birder-0.2.3 → birder-0.3.1}/birder/net/ssl/data2vec2.py +3 -1
  111. {birder-0.2.3 → birder-0.3.1}/birder/net/ssl/dino_v1.py +1 -1
  112. {birder-0.2.3 → birder-0.3.1}/birder/net/ssl/dino_v2.py +140 -18
  113. {birder-0.2.3 → birder-0.3.1}/birder/net/ssl/franca.py +145 -13
  114. {birder-0.2.3 → birder-0.3.1}/birder/net/ssl/ibot.py +1 -1
  115. {birder-0.2.3 → birder-0.3.1}/birder/net/ssl/mmcr.py +3 -1
  116. {birder-0.2.3 → birder-0.3.1}/birder/net/starnet.py +2 -2
  117. {birder-0.2.3 → birder-0.3.1}/birder/net/swiftformer.py +6 -6
  118. {birder-0.2.3 → birder-0.3.1}/birder/net/swin_transformer_v1.py +73 -70
  119. {birder-0.2.3 → birder-0.3.1}/birder/net/swin_transformer_v2.py +40 -33
  120. {birder-0.2.3 → birder-0.3.1}/birder/net/tiny_vit.py +22 -12
  121. {birder-0.2.3 → birder-0.3.1}/birder/net/transnext.py +39 -29
  122. {birder-0.2.3 → birder-0.3.1}/birder/net/uniformer.py +1 -1
  123. {birder-0.2.3 → birder-0.3.1}/birder/net/van.py +1 -1
  124. {birder-0.2.3 → birder-0.3.1}/birder/net/vgg.py +1 -1
  125. {birder-0.2.3 → birder-0.3.1}/birder/net/vgg_reduced.py +1 -1
  126. {birder-0.2.3 → birder-0.3.1}/birder/net/vit.py +11 -10
  127. {birder-0.2.3 → birder-0.3.1}/birder/net/vit_parallel.py +10 -9
  128. {birder-0.2.3 → birder-0.3.1}/birder/net/vit_sam.py +41 -40
  129. {birder-0.2.3 → birder-0.3.1}/birder/net/vovnet_v1.py +17 -2
  130. {birder-0.2.3 → birder-0.3.1}/birder/net/vovnet_v2.py +2 -2
  131. {birder-0.2.3 → birder-0.3.1}/birder/net/wide_resnet.py +2 -2
  132. {birder-0.2.3 → birder-0.3.1}/birder/net/xception.py +2 -2
  133. {birder-0.2.3 → birder-0.3.1}/birder/net/xcit.py +2 -2
  134. birder-0.3.1/birder/ops/msda.py +203 -0
  135. birder-0.3.1/birder/ops/swattention.py +288 -0
  136. {birder-0.2.3 → birder-0.3.1}/birder/results/detection.py +108 -0
  137. {birder-0.2.3 → birder-0.3.1}/birder/results/gui.py +10 -8
  138. {birder-0.2.3 → birder-0.3.1}/birder/scripts/benchmark.py +22 -13
  139. {birder-0.2.3 → birder-0.3.1}/birder/scripts/predict.py +7 -0
  140. {birder-0.2.3 → birder-0.3.1}/birder/scripts/train.py +44 -24
  141. {birder-0.2.3 → birder-0.3.1}/birder/scripts/train_barlow_twins.py +41 -23
  142. {birder-0.2.3 → birder-0.3.1}/birder/scripts/train_byol.py +42 -24
  143. {birder-0.2.3 → birder-0.3.1}/birder/scripts/train_capi.py +72 -26
  144. {birder-0.2.3 → birder-0.3.1}/birder/scripts/train_data2vec.py +44 -26
  145. {birder-0.2.3 → birder-0.3.1}/birder/scripts/train_data2vec2.py +46 -28
  146. {birder-0.2.3 → birder-0.3.1}/birder/scripts/train_detection.py +40 -20
  147. {birder-0.2.3 → birder-0.3.1}/birder/scripts/train_dino_v1.py +65 -31
  148. {birder-0.2.3 → birder-0.3.1}/birder/scripts/train_dino_v2.py +133 -60
  149. {birder-0.2.3 → birder-0.3.1}/birder/scripts/train_dino_v2_dist.py +131 -58
  150. {birder-0.2.3 → birder-0.3.1}/birder/scripts/train_franca.py +84 -46
  151. {birder-0.2.3 → birder-0.3.1}/birder/scripts/train_i_jepa.py +44 -25
  152. {birder-0.2.3 → birder-0.3.1}/birder/scripts/train_ibot.py +53 -33
  153. {birder-0.2.3 → birder-0.3.1}/birder/scripts/train_kd.py +44 -24
  154. {birder-0.2.3 → birder-0.3.1}/birder/scripts/train_mim.py +41 -22
  155. {birder-0.2.3 → birder-0.3.1}/birder/scripts/train_mmcr.py +42 -24
  156. {birder-0.2.3 → birder-0.3.1}/birder/scripts/train_rotnet.py +42 -24
  157. {birder-0.2.3 → birder-0.3.1}/birder/scripts/train_simclr.py +41 -23
  158. {birder-0.2.3 → birder-0.3.1}/birder/scripts/train_vicreg.py +41 -23
  159. {birder-0.2.3 → birder-0.3.1}/birder/tools/convert_model.py +18 -15
  160. birder-0.3.1/birder/tools/det_results.py +264 -0
  161. birder-0.3.1/birder/tools/quantize_model.py +162 -0
  162. {birder-0.2.3 → birder-0.3.1}/birder/tools/results.py +11 -7
  163. birder-0.3.1/birder/version.py +1 -0
  164. {birder-0.2.3 → birder-0.3.1}/birder.egg-info/PKG-INFO +2 -1
  165. {birder-0.2.3 → birder-0.3.1}/birder.egg-info/SOURCES.txt +1 -0
  166. {birder-0.2.3 → birder-0.3.1}/birder.egg-info/requires.txt +1 -0
  167. {birder-0.2.3 → birder-0.3.1}/requirements/_requirements-dev.txt +2 -0
  168. {birder-0.2.3 → birder-0.3.1}/tests/test_common.py +73 -2
  169. birder-0.3.1/tests/test_dataloaders.py +101 -0
  170. {birder-0.2.3 → birder-0.3.1}/tests/test_inference.py +2 -2
  171. {birder-0.2.3 → birder-0.3.1}/tests/test_net.py +41 -2
  172. birder-0.3.1/tests/test_net_detection.py +248 -0
  173. {birder-0.2.3 → birder-0.3.1}/tests/test_net_ssl.py +594 -6
  174. {birder-0.2.3 → birder-0.3.1}/tests/test_results.py +173 -0
  175. birder-0.2.3/birder/ops/msda.py +0 -138
  176. birder-0.2.3/birder/ops/swattention.py +0 -225
  177. birder-0.2.3/birder/tools/det_results.py +0 -61
  178. birder-0.2.3/birder/tools/quantize_model.py +0 -156
  179. birder-0.2.3/birder/version.py +0 -1
  180. birder-0.2.3/tests/test_net_detection.py +0 -170
  181. {birder-0.2.3 → birder-0.3.1}/LICENSE +0 -0
  182. {birder-0.2.3 → birder-0.3.1}/README.md +0 -0
  183. {birder-0.2.3 → birder-0.3.1}/birder/__init__.py +0 -0
  184. {birder-0.2.3 → birder-0.3.1}/birder/adversarial/__init__.py +0 -0
  185. {birder-0.2.3 → birder-0.3.1}/birder/adversarial/base.py +0 -0
  186. {birder-0.2.3 → birder-0.3.1}/birder/adversarial/deepfool.py +0 -0
  187. {birder-0.2.3 → birder-0.3.1}/birder/adversarial/fgsm.py +0 -0
  188. {birder-0.2.3 → birder-0.3.1}/birder/adversarial/pgd.py +0 -0
  189. {birder-0.2.3 → birder-0.3.1}/birder/adversarial/simba.py +0 -0
  190. {birder-0.2.3 → birder-0.3.1}/birder/common/__init__.py +0 -0
  191. {birder-0.2.3 → birder-0.3.1}/birder/common/cli.py +0 -0
  192. {birder-0.2.3 → birder-0.3.1}/birder/common/lib.py +0 -0
  193. {birder-0.2.3 → birder-0.3.1}/birder/common/masking.py +0 -0
  194. {birder-0.2.3 → birder-0.3.1}/birder/conf/__init__.py +0 -0
  195. {birder-0.2.3 → birder-0.3.1}/birder/conf/settings.py +0 -0
  196. {birder-0.2.3 → birder-0.3.1}/birder/data/__init__.py +0 -0
  197. {birder-0.2.3 → birder-0.3.1}/birder/data/collators/__init__.py +0 -0
  198. {birder-0.2.3 → birder-0.3.1}/birder/data/dataloader/__init__.py +0 -0
  199. {birder-0.2.3 → birder-0.3.1}/birder/data/datasets/__init__.py +0 -0
  200. {birder-0.2.3 → birder-0.3.1}/birder/data/datasets/coco.py +0 -0
  201. {birder-0.2.3 → birder-0.3.1}/birder/data/datasets/directory.py +0 -0
  202. {birder-0.2.3 → birder-0.3.1}/birder/data/datasets/fake.py +0 -0
  203. {birder-0.2.3 → birder-0.3.1}/birder/data/datasets/webdataset.py +0 -0
  204. {birder-0.2.3 → birder-0.3.1}/birder/data/transforms/__init__.py +0 -0
  205. {birder-0.2.3 → birder-0.3.1}/birder/data/transforms/classification.py +0 -0
  206. {birder-0.2.3 → birder-0.3.1}/birder/data/transforms/detection.py +0 -0
  207. {birder-0.2.3 → birder-0.3.1}/birder/data/transforms/mosaic.py +0 -0
  208. {birder-0.2.3 → birder-0.3.1}/birder/datahub/__init__.py +0 -0
  209. {birder-0.2.3 → birder-0.3.1}/birder/datahub/_lib.py +0 -0
  210. {birder-0.2.3 → birder-0.3.1}/birder/datahub/classification.py +0 -0
  211. {birder-0.2.3 → birder-0.3.1}/birder/inference/__init__.py +0 -0
  212. {birder-0.2.3 → birder-0.3.1}/birder/inference/classification.py +0 -0
  213. {birder-0.2.3 → birder-0.3.1}/birder/inference/data_parallel.py +0 -0
  214. {birder-0.2.3 → birder-0.3.1}/birder/inference/detection.py +0 -0
  215. {birder-0.2.3 → birder-0.3.1}/birder/inference/wbf.py +0 -0
  216. {birder-0.2.3 → birder-0.3.1}/birder/introspection/__init__.py +0 -0
  217. {birder-0.2.3 → birder-0.3.1}/birder/introspection/attention_rollout.py +0 -0
  218. {birder-0.2.3 → birder-0.3.1}/birder/introspection/base.py +0 -0
  219. {birder-0.2.3 → birder-0.3.1}/birder/introspection/gradcam.py +0 -0
  220. {birder-0.2.3 → birder-0.3.1}/birder/introspection/guided_backprop.py +0 -0
  221. {birder-0.2.3 → birder-0.3.1}/birder/introspection/transformer_attribution.py +0 -0
  222. {birder-0.2.3 → birder-0.3.1}/birder/kernels/__init__.py +0 -0
  223. {birder-0.2.3 → birder-0.3.1}/birder/kernels/deformable_detr/cpu/ms_deform_attn_cpu.cpp +0 -0
  224. {birder-0.2.3 → birder-0.3.1}/birder/kernels/deformable_detr/cpu/ms_deform_attn_cpu.h +0 -0
  225. {birder-0.2.3 → birder-0.3.1}/birder/kernels/deformable_detr/cuda/ms_deform_attn_cuda.cu +0 -0
  226. {birder-0.2.3 → birder-0.3.1}/birder/kernels/deformable_detr/cuda/ms_deform_attn_cuda.h +0 -0
  227. {birder-0.2.3 → birder-0.3.1}/birder/kernels/deformable_detr/cuda/ms_deform_im2col_cuda.cuh +0 -0
  228. {birder-0.2.3 → birder-0.3.1}/birder/kernels/deformable_detr/ms_deform_attn.h +0 -0
  229. {birder-0.2.3 → birder-0.3.1}/birder/kernels/deformable_detr/vision.cpp +0 -0
  230. {birder-0.2.3 → birder-0.3.1}/birder/kernels/soft_nms/op.cpp +0 -0
  231. {birder-0.2.3 → birder-0.3.1}/birder/kernels/soft_nms/soft_nms.h +0 -0
  232. {birder-0.2.3 → birder-0.3.1}/birder/kernels/transnext/av_bw_kernel.cu +0 -0
  233. {birder-0.2.3 → birder-0.3.1}/birder/kernels/transnext/av_fw_kernel.cu +0 -0
  234. {birder-0.2.3 → birder-0.3.1}/birder/kernels/transnext/qk_bw_kernel.cu +0 -0
  235. {birder-0.2.3 → birder-0.3.1}/birder/kernels/transnext/qk_fw_kernel.cu +0 -0
  236. {birder-0.2.3 → birder-0.3.1}/birder/kernels/transnext/qk_rpb_bw_kernel.cu +0 -0
  237. {birder-0.2.3 → birder-0.3.1}/birder/kernels/transnext/qk_rpb_fw_kernel.cu +0 -0
  238. {birder-0.2.3 → birder-0.3.1}/birder/kernels/transnext/swattention.cpp +0 -0
  239. {birder-0.2.3 → birder-0.3.1}/birder/layers/__init__.py +0 -0
  240. {birder-0.2.3 → birder-0.3.1}/birder/layers/activations.py +0 -0
  241. {birder-0.2.3 → birder-0.3.1}/birder/layers/attention_pool.py +0 -0
  242. {birder-0.2.3 → birder-0.3.1}/birder/layers/ffn.py +0 -0
  243. {birder-0.2.3 → birder-0.3.1}/birder/layers/gem.py +0 -0
  244. {birder-0.2.3 → birder-0.3.1}/birder/layers/layer_norm.py +0 -0
  245. {birder-0.2.3 → birder-0.3.1}/birder/layers/layer_scale.py +0 -0
  246. {birder-0.2.3 → birder-0.3.1}/birder/model_registry/__init__.py +0 -0
  247. {birder-0.2.3 → birder-0.3.1}/birder/model_registry/manifest.py +0 -0
  248. {birder-0.2.3 → birder-0.3.1}/birder/model_registry/model_registry.py +0 -0
  249. {birder-0.2.3 → birder-0.3.1}/birder/net/__init__.py +0 -0
  250. {birder-0.2.3 → birder-0.3.1}/birder/net/alexnet.py +0 -0
  251. {birder-0.2.3 → birder-0.3.1}/birder/net/convmixer.py +0 -0
  252. {birder-0.2.3 → birder-0.3.1}/birder/net/detection/__init__.py +0 -0
  253. {birder-0.2.3 → birder-0.3.1}/birder/net/detection/base.py +0 -0
  254. {birder-0.2.3 → birder-0.3.1}/birder/net/detection/vitdet.py +0 -0
  255. {birder-0.2.3 → birder-0.3.1}/birder/net/detection/yolo_anchors.py +0 -0
  256. {birder-0.2.3 → birder-0.3.1}/birder/net/detection/yolo_v4_tiny.py +0 -0
  257. {birder-0.2.3 → birder-0.3.1}/birder/net/dpn.py +0 -0
  258. {birder-0.2.3 → birder-0.3.1}/birder/net/mim/__init__.py +0 -0
  259. {birder-0.2.3 → birder-0.3.1}/birder/net/mim/base.py +0 -0
  260. {birder-0.2.3 → birder-0.3.1}/birder/net/mim/crossmae.py +0 -0
  261. {birder-0.2.3 → birder-0.3.1}/birder/net/mim/fcmae.py +0 -0
  262. {birder-0.2.3 → birder-0.3.1}/birder/net/mim/mae_hiera.py +0 -0
  263. {birder-0.2.3 → birder-0.3.1}/birder/net/mim/mae_vit.py +0 -0
  264. {birder-0.2.3 → birder-0.3.1}/birder/net/mim/simmim.py +0 -0
  265. {birder-0.2.3 → birder-0.3.1}/birder/net/mobilenet_v3_small.py +0 -0
  266. {birder-0.2.3 → birder-0.3.1}/birder/net/mobilevit_v1.py +0 -0
  267. {birder-0.2.3 → birder-0.3.1}/birder/net/resmlp.py +0 -0
  268. {birder-0.2.3 → birder-0.3.1}/birder/net/se_resnet_v1.py +0 -0
  269. {birder-0.2.3 → birder-0.3.1}/birder/net/se_resnet_v2.py +0 -0
  270. {birder-0.2.3 → birder-0.3.1}/birder/net/se_resnext.py +0 -0
  271. {birder-0.2.3 → birder-0.3.1}/birder/net/sequencer2d.py +0 -0
  272. {birder-0.2.3 → birder-0.3.1}/birder/net/squeezenet.py +0 -0
  273. {birder-0.2.3 → birder-0.3.1}/birder/net/ssl/__init__.py +0 -0
  274. {birder-0.2.3 → birder-0.3.1}/birder/net/ssl/barlow_twins.py +0 -0
  275. {birder-0.2.3 → birder-0.3.1}/birder/net/ssl/base.py +0 -0
  276. {birder-0.2.3 → birder-0.3.1}/birder/net/ssl/i_jepa.py +0 -0
  277. {birder-0.2.3 → birder-0.3.1}/birder/net/ssl/simclr.py +0 -0
  278. {birder-0.2.3 → birder-0.3.1}/birder/net/ssl/sscd.py +0 -0
  279. {birder-0.2.3 → birder-0.3.1}/birder/net/ssl/vicreg.py +0 -0
  280. {birder-0.2.3 → birder-0.3.1}/birder/ops/__init__.py +0 -0
  281. {birder-0.2.3 → birder-0.3.1}/birder/ops/soft_nms.py +0 -0
  282. {birder-0.2.3 → birder-0.3.1}/birder/optim/__init__.py +0 -0
  283. {birder-0.2.3 → birder-0.3.1}/birder/optim/lamb.py +0 -0
  284. {birder-0.2.3 → birder-0.3.1}/birder/optim/lars.py +0 -0
  285. {birder-0.2.3 → birder-0.3.1}/birder/py.typed +0 -0
  286. {birder-0.2.3 → birder-0.3.1}/birder/results/__init__.py +0 -0
  287. {birder-0.2.3 → birder-0.3.1}/birder/results/classification.py +0 -0
  288. {birder-0.2.3 → birder-0.3.1}/birder/scheduler/__init__.py +0 -0
  289. {birder-0.2.3 → birder-0.3.1}/birder/scheduler/cooldown.py +0 -0
  290. {birder-0.2.3 → birder-0.3.1}/birder/scripts/__init__.py +0 -0
  291. {birder-0.2.3 → birder-0.3.1}/birder/scripts/__main__.py +0 -0
  292. {birder-0.2.3 → birder-0.3.1}/birder/scripts/evaluate.py +0 -0
  293. {birder-0.2.3 → birder-0.3.1}/birder/scripts/predict_detection.py +0 -0
  294. {birder-0.2.3 → birder-0.3.1}/birder/tools/__init__.py +0 -0
  295. {birder-0.2.3 → birder-0.3.1}/birder/tools/__main__.py +0 -0
  296. {birder-0.2.3 → birder-0.3.1}/birder/tools/adversarial.py +0 -0
  297. {birder-0.2.3 → birder-0.3.1}/birder/tools/auto_anchors.py +0 -0
  298. {birder-0.2.3 → birder-0.3.1}/birder/tools/avg_model.py +0 -0
  299. {birder-0.2.3 → birder-0.3.1}/birder/tools/download_model.py +0 -0
  300. {birder-0.2.3 → birder-0.3.1}/birder/tools/ensemble_model.py +0 -0
  301. {birder-0.2.3 → birder-0.3.1}/birder/tools/introspection.py +0 -0
  302. {birder-0.2.3 → birder-0.3.1}/birder/tools/labelme_to_coco.py +0 -0
  303. {birder-0.2.3 → birder-0.3.1}/birder/tools/list_models.py +0 -0
  304. {birder-0.2.3 → birder-0.3.1}/birder/tools/model_info.py +0 -0
  305. {birder-0.2.3 → birder-0.3.1}/birder/tools/pack.py +0 -0
  306. {birder-0.2.3 → birder-0.3.1}/birder/tools/show_det_iterator.py +0 -0
  307. {birder-0.2.3 → birder-0.3.1}/birder/tools/show_iterator.py +0 -0
  308. {birder-0.2.3 → birder-0.3.1}/birder/tools/similarity.py +0 -0
  309. {birder-0.2.3 → birder-0.3.1}/birder/tools/stats.py +0 -0
  310. {birder-0.2.3 → birder-0.3.1}/birder/tools/verify_coco.py +0 -0
  311. {birder-0.2.3 → birder-0.3.1}/birder/tools/verify_directory.py +0 -0
  312. {birder-0.2.3 → birder-0.3.1}/birder/tools/voc_to_coco.py +0 -0
  313. {birder-0.2.3 → birder-0.3.1}/birder.egg-info/dependency_links.txt +0 -0
  314. {birder-0.2.3 → birder-0.3.1}/birder.egg-info/entry_points.txt +0 -0
  315. {birder-0.2.3 → birder-0.3.1}/birder.egg-info/top_level.txt +0 -0
  316. {birder-0.2.3 → birder-0.3.1}/pyproject.toml +0 -0
  317. {birder-0.2.3 → birder-0.3.1}/requirements/requirements-hf.txt +0 -0
  318. {birder-0.2.3 → birder-0.3.1}/requirements/requirements.txt +0 -0
  319. {birder-0.2.3 → birder-0.3.1}/setup.cfg +0 -0
  320. {birder-0.2.3 → birder-0.3.1}/tests/test_adversarial.py +0 -0
  321. {birder-0.2.3 → birder-0.3.1}/tests/test_collators.py +0 -0
  322. {birder-0.2.3 → birder-0.3.1}/tests/test_datasets.py +0 -0
  323. {birder-0.2.3 → birder-0.3.1}/tests/test_introspection.py +0 -0
  324. {birder-0.2.3 → birder-0.3.1}/tests/test_kernels.py +0 -0
  325. {birder-0.2.3 → birder-0.3.1}/tests/test_layers.py +0 -0
  326. {birder-0.2.3 → birder-0.3.1}/tests/test_model_registry.py +0 -0
  327. {birder-0.2.3 → birder-0.3.1}/tests/test_net_mim.py +0 -0
  328. {birder-0.2.3 → birder-0.3.1}/tests/test_ops.py +0 -0
  329. {birder-0.2.3 → birder-0.3.1}/tests/test_optim.py +0 -0
  330. {birder-0.2.3 → birder-0.3.1}/tests/test_scheduler.py +0 -0
  331. {birder-0.2.3 → birder-0.3.1}/tests/test_transforms.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: birder
3
- Version: 0.2.3
3
+ Version: 0.3.1
4
4
  Summary: An open-source computer vision framework for wildlife image analysis, featuring state-of-the-art models for species classification and detection.
5
5
  Author: Ofer Hasson
6
6
  License-Expression: Apache-2.0
@@ -66,6 +66,7 @@ Requires-Dist: pytest; extra == "dev"
66
66
  Requires-Dist: requests~=2.32.5; extra == "dev"
67
67
  Requires-Dist: safetensors~=0.7.0; extra == "dev"
68
68
  Requires-Dist: setuptools; extra == "dev"
69
+ Requires-Dist: torchao~=0.15.0; extra == "dev"
69
70
  Requires-Dist: torchprofile==0.0.4; extra == "dev"
70
71
  Requires-Dist: twine~=6.2.0; extra == "dev"
71
72
  Requires-Dist: types-requests~=2.32.4; extra == "dev"
@@ -627,7 +627,7 @@ def load_model(
627
627
  net.to(dtype)
628
628
  if inference is True:
629
629
  for param in net.parameters():
630
- param.requires_grad = False
630
+ param.requires_grad_(False)
631
631
 
632
632
  if pt2 is False: # NOTE: Remove when GraphModule add support for 'eval'
633
633
  net.eval()
@@ -799,7 +799,7 @@ def load_detection_model(
799
799
  net.to(dtype)
800
800
  if inference is True:
801
801
  for param in net.parameters():
802
- param.requires_grad = False
802
+ param.requires_grad_(False)
803
803
 
804
804
  net.eval()
805
805
 
@@ -39,6 +39,7 @@ def add_optimization_args(parser: argparse.ArgumentParser, default_batch_size: i
39
39
  group = parser.add_argument_group("Optimization parameters")
40
40
  group.add_argument("--batch-size", type=int, default=default_batch_size, metavar="N", help="the batch size")
41
41
  group.add_argument("--opt", type=str, choices=list(get_args(OptimizerType)), default="sgd", help="optimizer to use")
42
+ group.add_argument("--opt-fused", default=False, action="store_true", help="use fused optimizer implementation")
42
43
  group.add_argument("--momentum", type=float, default=0.9, metavar="M", help="optimizer momentum")
43
44
  group.add_argument("--nesterov", default=False, action="store_true", help="use nesterov momentum")
44
45
  group.add_argument("--opt-eps", type=float, help="optimizer epsilon (None to use the optimizer default)")
@@ -211,6 +212,12 @@ def add_training_schedule_args(parser: argparse.ArgumentParser, default_epochs:
211
212
  group.add_argument(
212
213
  "--stop-epoch", type=int, metavar="N", help="epoch to stop the training at (multi stage training)"
213
214
  )
215
+ group.add_argument(
216
+ "--steps-per-epoch",
217
+ type=int,
218
+ metavar="N",
219
+ help="virtual epoch length in steps, leave unset to use the full dataset",
220
+ )
214
221
  group.add_argument("--warmup-epochs", type=int, metavar="N", help="number of warmup epochs")
215
222
  group.add_argument("--warmup-steps", type=int, metavar="N", help="number of warmup optimizer steps")
216
223
  group.add_argument("--cooldown-epochs", type=int, metavar="N", help="number of cooldown epochs (linear to zero)")
@@ -243,6 +250,7 @@ def add_data_aug_args(
243
250
  default_level: int = 4,
244
251
  default_min_scale: Optional[float] = None,
245
252
  default_re_prob: Optional[float] = None,
253
+ smoothing_alpha: bool = False,
246
254
  mixup_cutmix: bool = False,
247
255
  ) -> None:
248
256
  group = parser.add_argument_group("Data augmentation parameters")
@@ -279,6 +287,8 @@ def add_data_aug_args(
279
287
  group.add_argument(
280
288
  "--simple-crop", default=False, action="store_true", help="use simple random crop (SRC) instead of RRC"
281
289
  )
290
+ if smoothing_alpha is True:
291
+ group.add_argument("--smoothing-alpha", type=float, default=0.0, help="label smoothing alpha")
282
292
  if mixup_cutmix is True:
283
293
  group.add_argument("--mixup-alpha", type=float, help="mixup alpha")
284
294
  group.add_argument("--cutmix", default=False, action="store_true", help="enable cutmix")
@@ -559,9 +569,9 @@ def add_training_data_args(parser: argparse.ArgumentParser, unsupervised: bool =
559
569
  group.add_argument("--wds", default=False, action="store_true", help="use webdataset for training")
560
570
  group.add_argument("--wds-info", type=str, metavar="FILE", help="wds info file path")
561
571
  group.add_argument("--wds-cache-dir", type=str, metavar="DIR", help="webdataset cache directory")
562
- group.add_argument("--wds-train-size", type=int, metavar="N", help="size of the wds training set")
563
572
  if unsupervised is False:
564
573
  group.add_argument("--wds-class-file", type=str, metavar="FILE", help="class list file")
574
+ group.add_argument("--wds-train-size", type=int, metavar="N", help="size of the wds training set")
565
575
  group.add_argument("--wds-val-size", type=int, metavar="N", help="size of the wds validation set")
566
576
  group.add_argument(
567
577
  "--wds-training-split", type=str, default="training", metavar="NAME", help="wds dataset train split"
@@ -570,6 +580,7 @@ def add_training_data_args(parser: argparse.ArgumentParser, unsupervised: bool =
570
580
  "--wds-val-split", type=str, default="validation", metavar="NAME", help="wds dataset validation split"
571
581
  )
572
582
  else:
583
+ group.add_argument("--wds-size", type=int, metavar="N", help="size of the wds")
573
584
  group.add_argument(
574
585
  "--wds-split", type=str, default="training", metavar="NAME", help="wds dataset split to load"
575
586
  )
@@ -17,6 +17,7 @@ from typing import Any
17
17
  from typing import Literal
18
18
  from typing import Optional
19
19
  from typing import Sized
20
+ from typing import overload
20
21
 
21
22
  import numpy as np
22
23
  import torch
@@ -70,13 +71,7 @@ class RASampler(torch.utils.data.Sampler):
70
71
  """
71
72
 
72
73
  def __init__(
73
- self,
74
- dataset: Sized,
75
- num_replicas: int,
76
- rank: int,
77
- shuffle: bool,
78
- seed: int = 0,
79
- repetitions: int = 3,
74
+ self, dataset: Sized, num_replicas: int, rank: int, shuffle: bool, seed: int = 0, repetitions: int = 3
80
75
  ) -> None:
81
76
  super().__init__()
82
77
  self.dataset = dataset
@@ -85,12 +80,11 @@ class RASampler(torch.utils.data.Sampler):
85
80
  self.epoch = 0
86
81
  self.num_samples = int(math.ceil(len(self.dataset) * float(repetitions) / self.num_replicas))
87
82
  self.total_size = self.num_samples * self.num_replicas
88
- self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
89
83
  self.shuffle = shuffle
90
84
  self.seed = seed
91
85
  self.repetitions = repetitions
92
86
 
93
- def __iter__(self) -> Iterator[list[int]]:
87
+ def __iter__(self) -> Iterator[int]:
94
88
  if self.shuffle is True:
95
89
  # Deterministically shuffle based on epoch
96
90
  g = torch.Generator()
@@ -100,18 +94,148 @@ class RASampler(torch.utils.data.Sampler):
100
94
  indices = list(range(len(self.dataset)))
101
95
 
102
96
  # Add extra samples to make it evenly divisible
103
- indices = [ele for ele in indices for i in range(self.repetitions)]
104
- indices += indices[: (self.total_size - len(indices))]
105
- assert len(indices) == self.total_size
97
+ indices = [ele for ele in indices for _ in range(self.repetitions)]
98
+ if len(indices) < self.total_size:
99
+ indices += indices[: (self.total_size - len(indices))]
100
+ else:
101
+ indices = indices[: self.total_size]
106
102
 
107
- # Subsample
103
+ # Shard by rank
108
104
  indices = indices[self.rank : self.total_size : self.num_replicas]
109
105
  assert len(indices) == self.num_samples
110
106
 
111
- return iter(indices[: self.num_selected_samples])
107
+ yield from indices
108
+
109
+ def __len__(self) -> int:
110
+ return self.num_samples
111
+
112
+ def set_epoch(self, epoch: int) -> None:
113
+ self.epoch = epoch
114
+
115
+
116
+ class InfiniteSampler(torch.utils.data.Sampler):
117
+ """
118
+ Infinite sampler that loops indefinitely over the dataset
119
+ """
120
+
121
+ def __init__(self, dataset: Sized, shuffle: bool, seed: int = 0) -> None:
122
+ super().__init__()
123
+ self.dataset = dataset
124
+ self.shuffle = shuffle
125
+ self.seed = seed
126
+ self.epoch = 0
127
+
128
+ def __iter__(self) -> Iterator[int]:
129
+ g = torch.Generator()
130
+ while True:
131
+ if self.shuffle is True:
132
+ g.manual_seed(self.seed + self.epoch)
133
+ indices = torch.randperm(len(self.dataset), generator=g).tolist()
134
+ else:
135
+ indices = list(range(len(self.dataset)))
136
+
137
+ yield from indices
138
+
139
+ logger.debug(f"InfiniteSampler finished epoch {self.epoch}")
140
+ self.epoch += 1
141
+
142
+ def __len__(self) -> int:
143
+ return len(self.dataset)
144
+
145
+ def set_epoch(self, epoch: int) -> None:
146
+ self.epoch = epoch
147
+
148
+
149
+ class InfiniteDistributedSampler(torch.utils.data.Sampler):
150
+ """
151
+ Infinite distributed sampler that keeps a continuous shuffled stream per rank
152
+ """
153
+
154
+ def __init__(self, dataset: Sized, num_replicas: int, rank: int, shuffle: bool, seed: int = 0) -> None:
155
+ super().__init__()
156
+ self.dataset = dataset
157
+ self.num_replicas = num_replicas
158
+ self.rank = rank
159
+ self.shuffle = shuffle
160
+ self.seed = seed
161
+ self.epoch = 0
162
+ self.num_samples = int(math.ceil(len(self.dataset) / self.num_replicas))
163
+ self.total_size = self.num_samples * self.num_replicas
164
+
165
+ def __iter__(self) -> Iterator[int]:
166
+ g = torch.Generator()
167
+ while True:
168
+ if self.shuffle is True:
169
+ g.manual_seed(self.seed + self.epoch)
170
+ indices = torch.randperm(len(self.dataset), generator=g).tolist()
171
+ else:
172
+ indices = list(range(len(self.dataset)))
173
+
174
+ if len(indices) < self.total_size:
175
+ indices += indices[: (self.total_size - len(indices))]
176
+ else:
177
+ indices = indices[: self.total_size]
178
+
179
+ indices = indices[self.rank : self.total_size : self.num_replicas]
180
+ assert len(indices) == self.num_samples
181
+
182
+ yield from indices
183
+
184
+ logger.debug(f"InfiniteDistributedSampler finished epoch {self.epoch}")
185
+ self.epoch += 1
112
186
 
113
187
  def __len__(self) -> int:
114
- return self.num_selected_samples
188
+ return self.num_samples
189
+
190
+ def set_epoch(self, epoch: int) -> None:
191
+ self.epoch = epoch
192
+
193
+
194
+ class InfiniteRASampler(torch.utils.data.Sampler):
195
+ """
196
+ Infinite version of the repeated augmentation sampler
197
+ """
198
+
199
+ def __init__(
200
+ self, dataset: Sized, num_replicas: int, rank: int, shuffle: bool, seed: int = 0, repetitions: int = 3
201
+ ) -> None:
202
+ super().__init__()
203
+ self.dataset = dataset
204
+ self.num_replicas = num_replicas
205
+ self.rank = rank
206
+ self.epoch = 0
207
+ self.num_samples = int(math.ceil(len(self.dataset) * float(repetitions) / self.num_replicas))
208
+ self.total_size = self.num_samples * self.num_replicas
209
+ self.shuffle = shuffle
210
+ self.seed = seed
211
+ self.repetitions = repetitions
212
+
213
+ def __iter__(self) -> Iterator[int]:
214
+ g = torch.Generator()
215
+ while True:
216
+ if self.shuffle is True:
217
+ g.manual_seed(self.seed + self.epoch)
218
+ indices = torch.randperm(len(self.dataset), generator=g).tolist()
219
+ else:
220
+ indices = list(range(len(self.dataset)))
221
+
222
+ indices = [ele for ele in indices for _ in range(self.repetitions)]
223
+ if len(indices) < self.total_size:
224
+ indices += indices[: (self.total_size - len(indices))]
225
+ else:
226
+ indices = indices[: self.total_size]
227
+
228
+ # Shard by rank
229
+ indices = indices[self.rank : self.total_size : self.num_replicas]
230
+ assert len(indices) == self.num_samples
231
+
232
+ yield from indices
233
+
234
+ logger.debug(f"InfiniteRASampler finished epoch {self.epoch}")
235
+ self.epoch += 1
236
+
237
+ def __len__(self) -> int:
238
+ return self.num_samples
115
239
 
116
240
  def set_epoch(self, epoch: int) -> None:
117
241
  self.epoch = epoch
@@ -469,12 +593,14 @@ def get_optimizer(parameters: list[dict[str, Any]], l_rate: float, args: argpars
469
593
  kwargs["betas"] = args.opt_betas
470
594
  if getattr(args, "opt_alpha", None) is not None:
471
595
  kwargs["alpha"] = args.opt_alpha
596
+ if getattr(args, "opt_fused", False) is True:
597
+ kwargs["fused"] = True
472
598
 
473
599
  # For optimizer compilation
474
600
  # lr = torch.tensor(l_rate) - Causes weird LR scheduling bugs
475
601
  lr = l_rate
476
- if getattr(args, "compile_opt", False) is not False:
477
- if opt not in ("lamb", "lambw", "lars"):
602
+ if getattr(args, "compile_opt", False) is True:
603
+ if opt not in ("sgd", "lamb", "lambw", "lars"):
478
604
  logger.debug("Setting optimizer capturable to True")
479
605
  kwargs["capturable"] = True
480
606
 
@@ -636,27 +762,87 @@ def get_amp_scaler(amp: bool, amp_dtype_str: str) -> tuple[Optional[torch.amp.Gr
636
762
  return (scaler, amp_dtype)
637
763
 
638
764
 
765
+ @overload
639
766
  def get_samplers(
640
- args: argparse.Namespace, training_dataset: torch.utils.data.Dataset, validation_dataset: torch.utils.data.Dataset
641
- ) -> torch.utils.data.Sampler:
642
- if args.distributed is True:
643
- if args.ra_sampler is True:
644
- train_sampler = RASampler(
645
- training_dataset,
646
- num_replicas=args.world_size,
647
- rank=args.rank,
648
- shuffle=True,
649
- repetitions=args.ra_reps,
650
- )
767
+ args: argparse.Namespace,
768
+ training_dataset: torch.utils.data.Dataset,
769
+ validation_dataset: torch.utils.data.Dataset,
770
+ infinite: bool = False,
771
+ ) -> tuple[torch.utils.data.Sampler, torch.utils.data.Sampler]: ...
651
772
 
652
- else:
653
- train_sampler = torch.utils.data.distributed.DistributedSampler(training_dataset, shuffle=True)
654
773
 
655
- validation_sampler = torch.utils.data.distributed.DistributedSampler(validation_dataset, shuffle=False)
774
+ @overload
775
+ def get_samplers(
776
+ args: argparse.Namespace,
777
+ training_dataset: torch.utils.data.Dataset,
778
+ validation_dataset: None = None,
779
+ infinite: bool = False,
780
+ ) -> tuple[torch.utils.data.Sampler, None]: ...
781
+
782
+
783
+ def get_samplers(
784
+ args: argparse.Namespace,
785
+ training_dataset: torch.utils.data.Dataset,
786
+ validation_dataset: Optional[torch.utils.data.Dataset] = None,
787
+ infinite: bool = False,
788
+ ) -> tuple[torch.utils.data.Sampler, Optional[torch.utils.data.Sampler]]:
789
+ if args.seed is None:
790
+ seed = int(torch.empty((), dtype=torch.int64).random_().item())
791
+ if is_dist_available_and_initialized() is True:
792
+ seed_tensor = torch.tensor(seed, dtype=torch.int64).cuda()
793
+ dist.broadcast(seed_tensor, src=0, async_op=False)
794
+ seed = int(seed_tensor.item())
795
+ else:
796
+ seed = args.seed
797
+
798
+ ra_sampler = getattr(args, "ra_sampler", False)
799
+ if args.distributed is True:
800
+ if infinite is True:
801
+ if ra_sampler is True:
802
+ train_sampler = InfiniteRASampler(
803
+ training_dataset,
804
+ num_replicas=args.world_size,
805
+ rank=args.rank,
806
+ shuffle=True,
807
+ seed=seed,
808
+ repetitions=args.ra_reps,
809
+ )
810
+ else:
811
+ train_sampler = InfiniteDistributedSampler(
812
+ training_dataset, num_replicas=args.world_size, rank=args.rank, shuffle=True, seed=seed
813
+ )
814
+ else:
815
+ if ra_sampler is True:
816
+ train_sampler = RASampler(
817
+ training_dataset,
818
+ num_replicas=args.world_size,
819
+ rank=args.rank,
820
+ shuffle=True,
821
+ seed=seed,
822
+ repetitions=args.ra_reps,
823
+ )
824
+ else:
825
+ train_sampler = torch.utils.data.distributed.DistributedSampler(
826
+ training_dataset, shuffle=True, seed=seed
827
+ )
828
+
829
+ if validation_dataset is None:
830
+ validation_sampler = None
831
+ else:
832
+ validation_sampler = torch.utils.data.distributed.DistributedSampler(validation_dataset, shuffle=False)
656
833
 
657
834
  else:
658
- train_sampler = torch.utils.data.RandomSampler(training_dataset)
659
- validation_sampler = torch.utils.data.SequentialSampler(validation_dataset)
835
+ if infinite is True:
836
+ train_sampler = InfiniteSampler(training_dataset, shuffle=True, seed=seed)
837
+ else:
838
+ generator = torch.Generator()
839
+ generator.manual_seed(seed)
840
+ train_sampler = torch.utils.data.RandomSampler(training_dataset, generator=generator)
841
+
842
+ if validation_dataset is None:
843
+ validation_sampler = None
844
+ else:
845
+ validation_sampler = torch.utils.data.SequentialSampler(validation_dataset)
660
846
 
661
847
  return (train_sampler, validation_sampler)
662
848
 
@@ -98,6 +98,7 @@ class BatchRandomResizeCollator(DetectionCollator):
98
98
  if isinstance(boxes, tv_tensors.BoundingBoxes) is False:
99
99
  if boxes.numel() == 0:
100
100
  boxes = boxes.reshape(0, 4)
101
+
101
102
  boxes = tv_tensors.BoundingBoxes(
102
103
  boxes, format=tv_tensors.BoundingBoxFormat.XYXY, canvas_size=F.get_size(image)
103
104
  )
@@ -22,9 +22,19 @@ def make_wds_loader(
22
22
  shuffle: bool = False,
23
23
  *,
24
24
  exact: bool = False,
25
+ infinite: bool = False,
25
26
  ) -> DataLoader:
27
+ assert exact is False or infinite is False
28
+
29
+ if infinite is True:
30
+ dataset_iterable = dataset.repeat()
31
+ elif exact is False:
32
+ dataset_iterable = dataset.repeat()
33
+ else:
34
+ dataset_iterable = dataset
35
+
26
36
  dataloader = wds.WebLoader(
27
- dataset.repeat() if exact is False else dataset,
37
+ dataset_iterable,
28
38
  batch_size=batch_size,
29
39
  num_workers=num_workers,
30
40
  prefetch_factor=prefetch_factor,
@@ -43,7 +53,7 @@ def make_wds_loader(
43
53
  epoch_size = math.ceil(len(dataset) / (batch_size * world_size))
44
54
 
45
55
  dataloader = dataloader.with_length(epoch_size, silent=True)
46
- if exact is False:
56
+ if exact is False and infinite is False:
47
57
  dataloader = dataloader.with_epoch(epoch_size)
48
58
 
49
59
  return dataloader
@@ -14,11 +14,24 @@ logger = logging.getLogger(__name__)
14
14
 
15
15
 
16
16
  _CACHED_KERNELS: dict[str, ModuleType] = {}
17
+ _CUSTOM_KERNELS_ENABLED = True
18
+
19
+
20
+ def set_custom_kernels_enabled(enabled: bool) -> None:
21
+ global _CUSTOM_KERNELS_ENABLED # pylint: disable=global-statement
22
+ _CUSTOM_KERNELS_ENABLED = enabled
23
+
24
+
25
+ def is_custom_kernels_enabled() -> bool:
26
+ if os.environ.get("DISABLE_CUSTOM_KERNELS", "0") == "1":
27
+ return False
28
+
29
+ return _CUSTOM_KERNELS_ENABLED
17
30
 
18
31
 
19
32
  def load_msda() -> Optional[ModuleType]:
20
33
  name = "msda"
21
- if torch.cuda.is_available() is False or os.environ.get("DISABLE_CUSTOM_KERNELS", "0") == "1":
34
+ if torch.cuda.is_available() is False or is_custom_kernels_enabled() is False:
22
35
  return None
23
36
 
24
37
  if name in _CACHED_KERNELS:
@@ -60,7 +73,7 @@ def load_msda() -> Optional[ModuleType]:
60
73
 
61
74
  def load_swattention() -> Optional[ModuleType]:
62
75
  name = "swattention"
63
- if torch.cuda.is_available() is False or os.environ.get("DISABLE_CUSTOM_KERNELS", "0") == "1":
76
+ if torch.cuda.is_available() is False or is_custom_kernels_enabled() is False:
64
77
  return None
65
78
 
66
79
  if name in _CACHED_KERNELS:
@@ -103,7 +116,7 @@ def load_swattention() -> Optional[ModuleType]:
103
116
 
104
117
  def load_soft_nms() -> Optional[ModuleType]:
105
118
  name = "soft_nms"
106
- if os.environ.get("DISABLE_CUSTOM_KERNELS", "0") == "1":
119
+ if is_custom_kernels_enabled() is False:
107
120
  return None
108
121
 
109
122
  if name in _CACHED_KERNELS:
@@ -120,14 +133,6 @@ def load_soft_nms() -> Optional[ModuleType]:
120
133
  soft_nms: Optional[ModuleType] = load(
121
134
  "soft_nms",
122
135
  src_files,
123
- with_cuda=True,
124
- extra_cflags=["-DWITH_CUDA=1"],
125
- extra_cuda_cflags=[
126
- "-DCUDA_HAS_FP16=1",
127
- "-D__CUDA_NO_HALF_OPERATORS__",
128
- "-D__CUDA_NO_HALF_CONVERSIONS__",
129
- "-D__CUDA_NO_HALF2_OPERATORS__",
130
- ],
131
136
  )
132
137
 
133
138
  if soft_nms is not None:
@@ -61,24 +61,23 @@ void update_sorting_order(torch::Tensor& boxes, torch::Tensor& scores, torch::Te
61
61
  std::tie(max_score, t_max_idx) = torch::max(scores.index({Slice(idx + 1, None)}), 0);
62
62
 
63
63
  // max_idx is computed from sliced data, therefore need to convert it to "global" max idx
64
- auto max_idx = t_max_idx.item<int>() + idx + 1;
65
-
66
- if (scores.index({idx}).item<float>() < max_score.item<float>()) {
67
- auto boxes_idx = boxes.index({idx}).clone();
68
- auto boxes_max = boxes.index({max_idx}).clone();
69
- boxes.index({idx}) = boxes_max;
70
- boxes.index({max_idx}) = boxes_idx;
71
-
72
- auto scores_idx = scores.index({idx}).clone();
73
- auto scores_max = scores.index({max_idx}).clone();
74
- scores.index({idx}) = scores_max;
75
- scores.index({max_idx}) = scores_idx;
76
-
77
- auto areas_idx = areas.index({idx}).clone();
78
- auto areas_max = areas.index({max_idx}).clone();
79
- areas.index({idx}) = areas_max;
80
- areas.index({max_idx}) = areas_idx;
81
- }
64
+ auto max_idx = t_max_idx + (idx + 1);
65
+ auto should_swap = scores.index({idx}) < max_score;
66
+
67
+ auto boxes_idx = boxes.index({idx}).clone();
68
+ auto boxes_max = boxes.index({max_idx}).clone();
69
+ boxes.index_put_({idx}, torch::where(should_swap, boxes_max, boxes_idx));
70
+ boxes.index_put_({max_idx}, torch::where(should_swap, boxes_idx, boxes_max));
71
+
72
+ auto scores_idx = scores.index({idx}).clone();
73
+ auto scores_max = scores.index({max_idx}).clone();
74
+ scores.index_put_({idx}, torch::where(should_swap, scores_max, scores_idx));
75
+ scores.index_put_({max_idx}, torch::where(should_swap, scores_idx, scores_max));
76
+
77
+ auto areas_idx = areas.index({idx}).clone();
78
+ auto areas_max = areas.index({max_idx}).clone();
79
+ areas.index_put_({idx}, torch::where(should_swap, areas_max, areas_idx));
80
+ areas.index_put_({max_idx}, torch::where(should_swap, areas_idx, areas_max));
82
81
  }
83
82
 
84
83
  std::tuple<torch::Tensor, torch::Tensor> soft_nms(
@@ -173,14 +173,14 @@ class BaseNet(nn.Module):
173
173
 
174
174
  def freeze(self, freeze_classifier: bool = True, unfreeze_features: bool = False) -> None:
175
175
  for param in self.parameters():
176
- param.requires_grad = False
176
+ param.requires_grad_(False)
177
177
 
178
178
  if freeze_classifier is False:
179
179
  for param in self.classifier.parameters():
180
- param.requires_grad = True
180
+ param.requires_grad_(True)
181
181
  if unfreeze_features is True and hasattr(self, "features") is True:
182
182
  for param in self.features.parameters():
183
- param.requires_grad = True
183
+ param.requires_grad_(True)
184
184
 
185
185
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
186
186
  """
@@ -468,14 +468,14 @@ class BiFormer(DetectorBackbone):
468
468
 
469
469
  def freeze_stages(self, up_to_stage: int) -> None:
470
470
  for param in self.stem.parameters():
471
- param.requires_grad = False
471
+ param.requires_grad_(False)
472
472
 
473
473
  for idx, module in enumerate(self.body.children()):
474
474
  if idx >= up_to_stage:
475
475
  break
476
476
 
477
477
  for param in module.parameters():
478
- param.requires_grad = False
478
+ param.requires_grad_(False)
479
479
 
480
480
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
481
481
  x = self.stem(x)
@@ -268,14 +268,15 @@ class CaiT(BaseNet):
268
268
  super().adjust_size(new_size)
269
269
 
270
270
  # Add back class tokens
271
- self.pos_embed = nn.Parameter(
272
- adjust_position_embedding(
271
+ with torch.no_grad():
272
+ pos_embed = adjust_position_embedding(
273
273
  self.pos_embed,
274
274
  (old_size[0] // self.patch_size[0], old_size[1] // self.patch_size[1]),
275
275
  (new_size[0] // self.patch_size[0], new_size[1] // self.patch_size[1]),
276
276
  0,
277
277
  )
278
- )
278
+
279
+ self.pos_embed = nn.Parameter(pos_embed)
279
280
 
280
281
 
281
282
  registry.register_model_config(
@@ -269,18 +269,18 @@ class CAS_ViT(DetectorBackbone):
269
269
 
270
270
  def freeze(self, freeze_classifier: bool = True, unfreeze_features: bool = False) -> None:
271
271
  for param in self.parameters():
272
- param.requires_grad = False
272
+ param.requires_grad_(False)
273
273
 
274
274
  if freeze_classifier is False:
275
275
  for param in self.classifier.parameters():
276
- param.requires_grad = True
276
+ param.requires_grad_(True)
277
277
 
278
278
  for param in self.dist_classifier.parameters():
279
- param.requires_grad = True
279
+ param.requires_grad_(True)
280
280
 
281
281
  if unfreeze_features is True:
282
282
  for param in self.features.parameters():
283
- param.requires_grad = True
283
+ param.requires_grad_(True)
284
284
 
285
285
  def transform_to_backbone(self) -> None:
286
286
  self.features = nn.Identity()
@@ -300,14 +300,14 @@ class CAS_ViT(DetectorBackbone):
300
300
 
301
301
  def freeze_stages(self, up_to_stage: int) -> None:
302
302
  for param in self.stem.parameters():
303
- param.requires_grad = False
303
+ param.requires_grad_(False)
304
304
 
305
305
  for idx, module in enumerate(self.body.children()):
306
306
  if idx >= up_to_stage:
307
307
  break
308
308
 
309
309
  for param in module.parameters():
310
- param.requires_grad = False
310
+ param.requires_grad_(False)
311
311
 
312
312
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
313
313
  x = self.stem(x)