dnt 0.2.1__py3-none-any.whl → 0.3.1.8__py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
Files changed (315)
  1. dnt/__init__.py +4 -1
  2. dnt/analysis/__init__.py +3 -1
  3. dnt/analysis/count.py +107 -0
  4. dnt/analysis/interaction2.py +518 -0
  5. dnt/analysis/position.py +12 -0
  6. dnt/analysis/stop.py +92 -33
  7. dnt/analysis/stop2.py +289 -0
  8. dnt/analysis/stop3.py +758 -0
  9. dnt/detect/__init__.py +1 -1
  10. dnt/detect/signal/detector.py +326 -0
  11. dnt/detect/timestamp.py +105 -0
  12. dnt/detect/yolov8/detector.py +182 -35
  13. dnt/detect/yolov8/segmentor.py +171 -0
  14. dnt/engine/__init__.py +8 -0
  15. dnt/engine/bbox_interp.py +83 -0
  16. dnt/engine/bbox_iou.py +20 -0
  17. dnt/engine/cluster.py +31 -0
  18. dnt/engine/iob.py +66 -0
  19. dnt/filter/__init__.py +4 -0
  20. dnt/filter/filter.py +450 -21
  21. dnt/label/__init__.py +1 -1
  22. dnt/label/labeler.py +215 -14
  23. dnt/label/labeler2.py +631 -0
  24. dnt/shared/__init__.py +2 -1
  25. dnt/shared/data/coco.names +0 -0
  26. dnt/shared/data/openimages.names +0 -0
  27. dnt/shared/data/voc.names +0 -0
  28. dnt/shared/download.py +12 -0
  29. dnt/shared/synhcro.py +150 -0
  30. dnt/shared/util.py +17 -4
  31. dnt/third_party/fast-reid/__init__.py +1 -0
  32. dnt/third_party/fast-reid/configs/Base-AGW.yml +19 -0
  33. dnt/third_party/fast-reid/configs/Base-MGN.yml +12 -0
  34. dnt/third_party/fast-reid/configs/Base-SBS.yml +63 -0
  35. dnt/third_party/fast-reid/configs/Base-bagtricks.yml +76 -0
  36. dnt/third_party/fast-reid/configs/DukeMTMC/AGW_R101-ibn.yml +12 -0
  37. dnt/third_party/fast-reid/configs/DukeMTMC/AGW_R50-ibn.yml +11 -0
  38. dnt/third_party/fast-reid/configs/DukeMTMC/AGW_R50.yml +7 -0
  39. dnt/third_party/fast-reid/configs/DukeMTMC/AGW_S50.yml +11 -0
  40. dnt/third_party/fast-reid/configs/DukeMTMC/bagtricks_R101-ibn.yml +12 -0
  41. dnt/third_party/fast-reid/configs/DukeMTMC/bagtricks_R50-ibn.yml +11 -0
  42. dnt/third_party/fast-reid/configs/DukeMTMC/bagtricks_R50.yml +7 -0
  43. dnt/third_party/fast-reid/configs/DukeMTMC/bagtricks_S50.yml +11 -0
  44. dnt/third_party/fast-reid/configs/DukeMTMC/mgn_R50-ibn.yml +11 -0
  45. dnt/third_party/fast-reid/configs/DukeMTMC/sbs_R101-ibn.yml +12 -0
  46. dnt/third_party/fast-reid/configs/DukeMTMC/sbs_R50-ibn.yml +11 -0
  47. dnt/third_party/fast-reid/configs/DukeMTMC/sbs_R50.yml +7 -0
  48. dnt/third_party/fast-reid/configs/DukeMTMC/sbs_S50.yml +11 -0
  49. dnt/third_party/fast-reid/configs/MOT17/AGW_R101-ibn.yml +12 -0
  50. dnt/third_party/fast-reid/configs/MOT17/AGW_R50-ibn.yml +11 -0
  51. dnt/third_party/fast-reid/configs/MOT17/AGW_R50.yml +7 -0
  52. dnt/third_party/fast-reid/configs/MOT17/AGW_S50.yml +11 -0
  53. dnt/third_party/fast-reid/configs/MOT17/bagtricks_R101-ibn.yml +12 -0
  54. dnt/third_party/fast-reid/configs/MOT17/bagtricks_R50-ibn.yml +11 -0
  55. dnt/third_party/fast-reid/configs/MOT17/bagtricks_R50.yml +7 -0
  56. dnt/third_party/fast-reid/configs/MOT17/bagtricks_S50.yml +11 -0
  57. dnt/third_party/fast-reid/configs/MOT17/mgn_R50-ibn.yml +11 -0
  58. dnt/third_party/fast-reid/configs/MOT17/sbs_R101-ibn.yml +12 -0
  59. dnt/third_party/fast-reid/configs/MOT17/sbs_R50-ibn.yml +11 -0
  60. dnt/third_party/fast-reid/configs/MOT17/sbs_R50.yml +7 -0
  61. dnt/third_party/fast-reid/configs/MOT17/sbs_S50.yml +11 -0
  62. dnt/third_party/fast-reid/configs/MOT20/AGW_R101-ibn.yml +12 -0
  63. dnt/third_party/fast-reid/configs/MOT20/AGW_R50-ibn.yml +11 -0
  64. dnt/third_party/fast-reid/configs/MOT20/AGW_R50.yml +7 -0
  65. dnt/third_party/fast-reid/configs/MOT20/AGW_S50.yml +11 -0
  66. dnt/third_party/fast-reid/configs/MOT20/bagtricks_R101-ibn.yml +12 -0
  67. dnt/third_party/fast-reid/configs/MOT20/bagtricks_R50-ibn.yml +11 -0
  68. dnt/third_party/fast-reid/configs/MOT20/bagtricks_R50.yml +7 -0
  69. dnt/third_party/fast-reid/configs/MOT20/bagtricks_S50.yml +11 -0
  70. dnt/third_party/fast-reid/configs/MOT20/mgn_R50-ibn.yml +11 -0
  71. dnt/third_party/fast-reid/configs/MOT20/sbs_R101-ibn.yml +12 -0
  72. dnt/third_party/fast-reid/configs/MOT20/sbs_R50-ibn.yml +11 -0
  73. dnt/third_party/fast-reid/configs/MOT20/sbs_R50.yml +7 -0
  74. dnt/third_party/fast-reid/configs/MOT20/sbs_S50.yml +11 -0
  75. dnt/third_party/fast-reid/configs/MSMT17/AGW_R101-ibn.yml +12 -0
  76. dnt/third_party/fast-reid/configs/MSMT17/AGW_R50-ibn.yml +11 -0
  77. dnt/third_party/fast-reid/configs/MSMT17/AGW_R50.yml +7 -0
  78. dnt/third_party/fast-reid/configs/MSMT17/AGW_S50.yml +11 -0
  79. dnt/third_party/fast-reid/configs/MSMT17/bagtricks_R101-ibn.yml +13 -0
  80. dnt/third_party/fast-reid/configs/MSMT17/bagtricks_R50-ibn.yml +12 -0
  81. dnt/third_party/fast-reid/configs/MSMT17/bagtricks_R50.yml +7 -0
  82. dnt/third_party/fast-reid/configs/MSMT17/bagtricks_S50.yml +12 -0
  83. dnt/third_party/fast-reid/configs/MSMT17/mgn_R50-ibn.yml +11 -0
  84. dnt/third_party/fast-reid/configs/MSMT17/sbs_R101-ibn.yml +12 -0
  85. dnt/third_party/fast-reid/configs/MSMT17/sbs_R50-ibn.yml +11 -0
  86. dnt/third_party/fast-reid/configs/MSMT17/sbs_R50.yml +7 -0
  87. dnt/third_party/fast-reid/configs/MSMT17/sbs_S50.yml +11 -0
  88. dnt/third_party/fast-reid/configs/Market1501/AGW_R101-ibn.yml +12 -0
  89. dnt/third_party/fast-reid/configs/Market1501/AGW_R50-ibn.yml +11 -0
  90. dnt/third_party/fast-reid/configs/Market1501/AGW_R50.yml +7 -0
  91. dnt/third_party/fast-reid/configs/Market1501/AGW_S50.yml +11 -0
  92. dnt/third_party/fast-reid/configs/Market1501/bagtricks_R101-ibn.yml +12 -0
  93. dnt/third_party/fast-reid/configs/Market1501/bagtricks_R50-ibn.yml +11 -0
  94. dnt/third_party/fast-reid/configs/Market1501/bagtricks_R50.yml +7 -0
  95. dnt/third_party/fast-reid/configs/Market1501/bagtricks_S50.yml +11 -0
  96. dnt/third_party/fast-reid/configs/Market1501/bagtricks_vit.yml +88 -0
  97. dnt/third_party/fast-reid/configs/Market1501/mgn_R50-ibn.yml +11 -0
  98. dnt/third_party/fast-reid/configs/Market1501/sbs_R101-ibn.yml +12 -0
  99. dnt/third_party/fast-reid/configs/Market1501/sbs_R50-ibn.yml +11 -0
  100. dnt/third_party/fast-reid/configs/Market1501/sbs_R50.yml +7 -0
  101. dnt/third_party/fast-reid/configs/Market1501/sbs_S50.yml +11 -0
  102. dnt/third_party/fast-reid/configs/VERIWild/bagtricks_R50-ibn.yml +35 -0
  103. dnt/third_party/fast-reid/configs/VeRi/sbs_R50-ibn.yml +35 -0
  104. dnt/third_party/fast-reid/configs/VehicleID/bagtricks_R50-ibn.yml +36 -0
  105. dnt/third_party/fast-reid/configs/__init__.py +0 -0
  106. dnt/third_party/fast-reid/fast_reid_interfece.py +175 -0
  107. dnt/third_party/fast-reid/fastreid/__init__.py +6 -0
  108. dnt/third_party/fast-reid/fastreid/config/__init__.py +15 -0
  109. dnt/third_party/fast-reid/fastreid/config/config.py +319 -0
  110. dnt/third_party/fast-reid/fastreid/config/defaults.py +329 -0
  111. dnt/third_party/fast-reid/fastreid/data/__init__.py +17 -0
  112. dnt/third_party/fast-reid/fastreid/data/build.py +194 -0
  113. dnt/third_party/fast-reid/fastreid/data/common.py +58 -0
  114. dnt/third_party/fast-reid/fastreid/data/data_utils.py +202 -0
  115. dnt/third_party/fast-reid/fastreid/data/datasets/AirportALERT.py +50 -0
  116. dnt/third_party/fast-reid/fastreid/data/datasets/__init__.py +43 -0
  117. dnt/third_party/fast-reid/fastreid/data/datasets/bases.py +183 -0
  118. dnt/third_party/fast-reid/fastreid/data/datasets/caviara.py +44 -0
  119. dnt/third_party/fast-reid/fastreid/data/datasets/cuhk03.py +274 -0
  120. dnt/third_party/fast-reid/fastreid/data/datasets/cuhk_sysu.py +58 -0
  121. dnt/third_party/fast-reid/fastreid/data/datasets/dukemtmcreid.py +70 -0
  122. dnt/third_party/fast-reid/fastreid/data/datasets/grid.py +44 -0
  123. dnt/third_party/fast-reid/fastreid/data/datasets/iLIDS.py +45 -0
  124. dnt/third_party/fast-reid/fastreid/data/datasets/lpw.py +49 -0
  125. dnt/third_party/fast-reid/fastreid/data/datasets/market1501.py +89 -0
  126. dnt/third_party/fast-reid/fastreid/data/datasets/msmt17.py +114 -0
  127. dnt/third_party/fast-reid/fastreid/data/datasets/pes3d.py +44 -0
  128. dnt/third_party/fast-reid/fastreid/data/datasets/pku.py +44 -0
  129. dnt/third_party/fast-reid/fastreid/data/datasets/prai.py +43 -0
  130. dnt/third_party/fast-reid/fastreid/data/datasets/prid.py +41 -0
  131. dnt/third_party/fast-reid/fastreid/data/datasets/saivt.py +47 -0
  132. dnt/third_party/fast-reid/fastreid/data/datasets/sensereid.py +47 -0
  133. dnt/third_party/fast-reid/fastreid/data/datasets/shinpuhkan.py +48 -0
  134. dnt/third_party/fast-reid/fastreid/data/datasets/sysu_mm.py +47 -0
  135. dnt/third_party/fast-reid/fastreid/data/datasets/thermalworld.py +43 -0
  136. dnt/third_party/fast-reid/fastreid/data/datasets/vehicleid.py +126 -0
  137. dnt/third_party/fast-reid/fastreid/data/datasets/veri.py +69 -0
  138. dnt/third_party/fast-reid/fastreid/data/datasets/veriwild.py +140 -0
  139. dnt/third_party/fast-reid/fastreid/data/datasets/viper.py +45 -0
  140. dnt/third_party/fast-reid/fastreid/data/datasets/wildtracker.py +59 -0
  141. dnt/third_party/fast-reid/fastreid/data/samplers/__init__.py +18 -0
  142. dnt/third_party/fast-reid/fastreid/data/samplers/data_sampler.py +85 -0
  143. dnt/third_party/fast-reid/fastreid/data/samplers/imbalance_sampler.py +67 -0
  144. dnt/third_party/fast-reid/fastreid/data/samplers/triplet_sampler.py +260 -0
  145. dnt/third_party/fast-reid/fastreid/data/transforms/__init__.py +11 -0
  146. dnt/third_party/fast-reid/fastreid/data/transforms/autoaugment.py +806 -0
  147. dnt/third_party/fast-reid/fastreid/data/transforms/build.py +100 -0
  148. dnt/third_party/fast-reid/fastreid/data/transforms/functional.py +180 -0
  149. dnt/third_party/fast-reid/fastreid/data/transforms/transforms.py +161 -0
  150. dnt/third_party/fast-reid/fastreid/engine/__init__.py +15 -0
  151. dnt/third_party/fast-reid/fastreid/engine/defaults.py +490 -0
  152. dnt/third_party/fast-reid/fastreid/engine/hooks.py +534 -0
  153. dnt/third_party/fast-reid/fastreid/engine/launch.py +103 -0
  154. dnt/third_party/fast-reid/fastreid/engine/train_loop.py +357 -0
  155. dnt/third_party/fast-reid/fastreid/evaluation/__init__.py +6 -0
  156. dnt/third_party/fast-reid/fastreid/evaluation/clas_evaluator.py +81 -0
  157. dnt/third_party/fast-reid/fastreid/evaluation/evaluator.py +176 -0
  158. dnt/third_party/fast-reid/fastreid/evaluation/query_expansion.py +46 -0
  159. dnt/third_party/fast-reid/fastreid/evaluation/rank.py +200 -0
  160. dnt/third_party/fast-reid/fastreid/evaluation/rank_cylib/__init__.py +20 -0
  161. dnt/third_party/fast-reid/fastreid/evaluation/rank_cylib/setup.py +32 -0
  162. dnt/third_party/fast-reid/fastreid/evaluation/rank_cylib/test_cython.py +106 -0
  163. dnt/third_party/fast-reid/fastreid/evaluation/reid_evaluation.py +143 -0
  164. dnt/third_party/fast-reid/fastreid/evaluation/rerank.py +73 -0
  165. dnt/third_party/fast-reid/fastreid/evaluation/roc.py +90 -0
  166. dnt/third_party/fast-reid/fastreid/evaluation/testing.py +88 -0
  167. dnt/third_party/fast-reid/fastreid/layers/__init__.py +19 -0
  168. dnt/third_party/fast-reid/fastreid/layers/activation.py +59 -0
  169. dnt/third_party/fast-reid/fastreid/layers/any_softmax.py +80 -0
  170. dnt/third_party/fast-reid/fastreid/layers/batch_norm.py +205 -0
  171. dnt/third_party/fast-reid/fastreid/layers/context_block.py +113 -0
  172. dnt/third_party/fast-reid/fastreid/layers/drop.py +161 -0
  173. dnt/third_party/fast-reid/fastreid/layers/frn.py +199 -0
  174. dnt/third_party/fast-reid/fastreid/layers/gather_layer.py +30 -0
  175. dnt/third_party/fast-reid/fastreid/layers/helpers.py +31 -0
  176. dnt/third_party/fast-reid/fastreid/layers/non_local.py +54 -0
  177. dnt/third_party/fast-reid/fastreid/layers/pooling.py +124 -0
  178. dnt/third_party/fast-reid/fastreid/layers/se_layer.py +25 -0
  179. dnt/third_party/fast-reid/fastreid/layers/splat.py +109 -0
  180. dnt/third_party/fast-reid/fastreid/layers/weight_init.py +122 -0
  181. dnt/third_party/fast-reid/fastreid/modeling/__init__.py +23 -0
  182. dnt/third_party/fast-reid/fastreid/modeling/backbones/__init__.py +18 -0
  183. dnt/third_party/fast-reid/fastreid/modeling/backbones/build.py +27 -0
  184. dnt/third_party/fast-reid/fastreid/modeling/backbones/mobilenet.py +195 -0
  185. dnt/third_party/fast-reid/fastreid/modeling/backbones/mobilenetv3.py +283 -0
  186. dnt/third_party/fast-reid/fastreid/modeling/backbones/osnet.py +525 -0
  187. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/__init__.py +4 -0
  188. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/config.py +396 -0
  189. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/effnet/EN-B0_dds_8gpu.yaml +27 -0
  190. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/effnet/EN-B1_dds_8gpu.yaml +27 -0
  191. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/effnet/EN-B2_dds_8gpu.yaml +27 -0
  192. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/effnet/EN-B3_dds_8gpu.yaml +27 -0
  193. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/effnet/EN-B4_dds_8gpu.yaml +27 -0
  194. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/effnet/EN-B5_dds_8gpu.yaml +27 -0
  195. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/effnet.py +281 -0
  196. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnet.py +596 -0
  197. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-1.6GF_dds_8gpu.yaml +26 -0
  198. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-12GF_dds_8gpu.yaml +26 -0
  199. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-16GF_dds_8gpu.yaml +26 -0
  200. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-200MF_dds_8gpu.yaml +26 -0
  201. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-3.2GF_dds_8gpu.yaml +26 -0
  202. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-32GF_dds_8gpu.yaml +26 -0
  203. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-4.0GF_dds_8gpu.yaml +26 -0
  204. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-400MF_dds_8gpu.yaml +26 -0
  205. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-6.4GF_dds_8gpu.yaml +26 -0
  206. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-600MF_dds_8gpu.yaml +26 -0
  207. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-8.0GF_dds_8gpu.yaml +26 -0
  208. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-800MF_dds_8gpu.yaml +26 -0
  209. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-1.6GF_dds_8gpu.yaml +27 -0
  210. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-12GF_dds_8gpu.yaml +27 -0
  211. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-16GF_dds_8gpu.yaml +27 -0
  212. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-200MF_dds_8gpu.yaml +26 -0
  213. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-3.2GF_dds_8gpu.yaml +27 -0
  214. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-32GF_dds_8gpu.yaml +27 -0
  215. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-4.0GF_dds_8gpu.yaml +27 -0
  216. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-400MF_dds_8gpu.yaml +27 -0
  217. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-6.4GF_dds_8gpu.yaml +27 -0
  218. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-600MF_dds_8gpu.yaml +27 -0
  219. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-8.0GF_dds_8gpu.yaml +27 -0
  220. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-800MF_dds_8gpu.yaml +27 -0
  221. dnt/third_party/fast-reid/fastreid/modeling/backbones/repvgg.py +309 -0
  222. dnt/third_party/fast-reid/fastreid/modeling/backbones/resnest.py +365 -0
  223. dnt/third_party/fast-reid/fastreid/modeling/backbones/resnet.py +364 -0
  224. dnt/third_party/fast-reid/fastreid/modeling/backbones/resnext.py +335 -0
  225. dnt/third_party/fast-reid/fastreid/modeling/backbones/shufflenet.py +203 -0
  226. dnt/third_party/fast-reid/fastreid/modeling/backbones/vision_transformer.py +399 -0
  227. dnt/third_party/fast-reid/fastreid/modeling/heads/__init__.py +11 -0
  228. dnt/third_party/fast-reid/fastreid/modeling/heads/build.py +25 -0
  229. dnt/third_party/fast-reid/fastreid/modeling/heads/clas_head.py +36 -0
  230. dnt/third_party/fast-reid/fastreid/modeling/heads/embedding_head.py +151 -0
  231. dnt/third_party/fast-reid/fastreid/modeling/losses/__init__.py +12 -0
  232. dnt/third_party/fast-reid/fastreid/modeling/losses/circle_loss.py +71 -0
  233. dnt/third_party/fast-reid/fastreid/modeling/losses/cross_entroy_loss.py +54 -0
  234. dnt/third_party/fast-reid/fastreid/modeling/losses/focal_loss.py +92 -0
  235. dnt/third_party/fast-reid/fastreid/modeling/losses/triplet_loss.py +113 -0
  236. dnt/third_party/fast-reid/fastreid/modeling/losses/utils.py +48 -0
  237. dnt/third_party/fast-reid/fastreid/modeling/meta_arch/__init__.py +14 -0
  238. dnt/third_party/fast-reid/fastreid/modeling/meta_arch/baseline.py +188 -0
  239. dnt/third_party/fast-reid/fastreid/modeling/meta_arch/build.py +26 -0
  240. dnt/third_party/fast-reid/fastreid/modeling/meta_arch/distiller.py +140 -0
  241. dnt/third_party/fast-reid/fastreid/modeling/meta_arch/mgn.py +394 -0
  242. dnt/third_party/fast-reid/fastreid/modeling/meta_arch/moco.py +126 -0
  243. dnt/third_party/fast-reid/fastreid/solver/__init__.py +8 -0
  244. dnt/third_party/fast-reid/fastreid/solver/build.py +348 -0
  245. dnt/third_party/fast-reid/fastreid/solver/lr_scheduler.py +66 -0
  246. dnt/third_party/fast-reid/fastreid/solver/optim/__init__.py +10 -0
  247. dnt/third_party/fast-reid/fastreid/solver/optim/lamb.py +123 -0
  248. dnt/third_party/fast-reid/fastreid/solver/optim/radam.py +149 -0
  249. dnt/third_party/fast-reid/fastreid/solver/optim/swa.py +246 -0
  250. dnt/third_party/fast-reid/fastreid/utils/__init__.py +6 -0
  251. dnt/third_party/fast-reid/fastreid/utils/checkpoint.py +503 -0
  252. dnt/third_party/fast-reid/fastreid/utils/collect_env.py +158 -0
  253. dnt/third_party/fast-reid/fastreid/utils/comm.py +255 -0
  254. dnt/third_party/fast-reid/fastreid/utils/compute_dist.py +200 -0
  255. dnt/third_party/fast-reid/fastreid/utils/env.py +119 -0
  256. dnt/third_party/fast-reid/fastreid/utils/events.py +461 -0
  257. dnt/third_party/fast-reid/fastreid/utils/faiss_utils.py +127 -0
  258. dnt/third_party/fast-reid/fastreid/utils/file_io.py +520 -0
  259. dnt/third_party/fast-reid/fastreid/utils/history_buffer.py +71 -0
  260. dnt/third_party/fast-reid/fastreid/utils/logger.py +211 -0
  261. dnt/third_party/fast-reid/fastreid/utils/params.py +103 -0
  262. dnt/third_party/fast-reid/fastreid/utils/precision_bn.py +94 -0
  263. dnt/third_party/fast-reid/fastreid/utils/registry.py +66 -0
  264. dnt/third_party/fast-reid/fastreid/utils/summary.py +120 -0
  265. dnt/third_party/fast-reid/fastreid/utils/timer.py +68 -0
  266. dnt/third_party/fast-reid/fastreid/utils/visualizer.py +278 -0
  267. dnt/track/__init__.py +3 -1
  268. dnt/track/botsort/__init__.py +4 -0
  269. dnt/track/botsort/bot_tracker/__init__.py +3 -0
  270. dnt/track/botsort/bot_tracker/basetrack.py +60 -0
  271. dnt/track/botsort/bot_tracker/bot_sort.py +473 -0
  272. dnt/track/botsort/bot_tracker/gmc.py +316 -0
  273. dnt/track/botsort/bot_tracker/kalman_filter.py +269 -0
  274. dnt/track/botsort/bot_tracker/matching.py +194 -0
  275. dnt/track/botsort/bot_tracker/mc_bot_sort.py +505 -0
  276. dnt/track/{dsort/utils → botsort/bot_tracker/tracking_utils}/evaluation.py +14 -4
  277. dnt/track/{dsort/utils → botsort/bot_tracker/tracking_utils}/io.py +19 -36
  278. dnt/track/botsort/bot_tracker/tracking_utils/timer.py +37 -0
  279. dnt/track/botsort/inference.py +96 -0
  280. dnt/track/config.py +120 -0
  281. dnt/track/dsort/configs/bagtricks_R50.yml +7 -0
  282. dnt/track/dsort/configs/deep_sort.yaml +0 -0
  283. dnt/track/dsort/configs/fastreid.yaml +1 -1
  284. dnt/track/dsort/deep_sort/deep/checkpoint/ckpt.t7 +0 -0
  285. dnt/track/dsort/deep_sort/deep/feature_extractor.py +87 -8
  286. dnt/track/dsort/deep_sort/deep_sort.py +31 -21
  287. dnt/track/dsort/deep_sort/sort/detection.py +2 -1
  288. dnt/track/dsort/deep_sort/sort/iou_matching.py +0 -2
  289. dnt/track/dsort/deep_sort/sort/linear_assignment.py +0 -3
  290. dnt/track/dsort/deep_sort/sort/nn_matching.py +5 -5
  291. dnt/track/dsort/deep_sort/sort/preprocessing.py +1 -2
  292. dnt/track/dsort/deep_sort/sort/track.py +2 -1
  293. dnt/track/dsort/deep_sort/sort/tracker.py +1 -1
  294. dnt/track/dsort/dsort.py +44 -27
  295. dnt/track/re_class.py +117 -0
  296. dnt/track/sort/sort.py +9 -7
  297. dnt/track/tracker.py +225 -20
  298. dnt-0.3.1.8.dist-info/METADATA +117 -0
  299. dnt-0.3.1.8.dist-info/RECORD +315 -0
  300. {dnt-0.2.1.dist-info → dnt-0.3.1.8.dist-info}/WHEEL +1 -1
  301. dnt/track/dsort/deep_sort/deep/evaluate.py +0 -15
  302. dnt/track/dsort/deep_sort/deep/original_model.py +0 -106
  303. dnt/track/dsort/deep_sort/deep/test.py +0 -77
  304. dnt/track/dsort/deep_sort/deep/train.py +0 -189
  305. dnt/track/dsort/utils/asserts.py +0 -13
  306. dnt/track/dsort/utils/draw.py +0 -36
  307. dnt/track/dsort/utils/json_logger.py +0 -383
  308. dnt/track/dsort/utils/log.py +0 -17
  309. dnt/track/dsort/utils/parser.py +0 -35
  310. dnt/track/dsort/utils/tools.py +0 -39
  311. dnt-0.2.1.dist-info/METADATA +0 -35
  312. dnt-0.2.1.dist-info/RECORD +0 -60
  313. /dnt/{track/dsort/utils → third_party/fast-reid/checkpoint}/__init__.py +0 -0
  314. {dnt-0.2.1.dist-info → dnt-0.3.1.8.dist-info/licenses}/LICENSE +0 -0
  315. {dnt-0.2.1.dist-info → dnt-0.3.1.8.dist-info}/top_level.txt +0 -0
dnt/third_party/fast-reid/fastreid/solver/build.py
@@ -0,0 +1,348 @@
+ # encoding: utf-8
+ """
+ @author: liaoxingyu
+ @contact: sherlockliao01@gmail.com
+ """
+
+ # Based on: https://github.com/facebookresearch/detectron2/blob/master/detectron2/solver/build.py
+
+ import copy
+ import itertools
+ import math
+ import re
+ from enum import Enum
+ from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Type, Union
+
+ import torch
+
+ from fastreid.config import CfgNode
+ from fastreid.utils.params import ContiguousParams
+ from . import lr_scheduler
+
+ _GradientClipperInput = Union[torch.Tensor, Iterable[torch.Tensor]]
+ _GradientClipper = Callable[[_GradientClipperInput], None]
+
+
+ class GradientClipType(Enum):
+     VALUE = "value"
+     NORM = "norm"
+
+
+ def _create_gradient_clipper(cfg: CfgNode) -> _GradientClipper:
+     """
+     Creates gradient clipping closure to clip by value or by norm,
+     according to the provided config.
+     """
+     cfg = copy.deepcopy(cfg)
+
+     def clip_grad_norm(p: _GradientClipperInput):
+         torch.nn.utils.clip_grad_norm_(p, cfg.CLIP_VALUE, cfg.NORM_TYPE)
+
+     def clip_grad_value(p: _GradientClipperInput):
+         torch.nn.utils.clip_grad_value_(p, cfg.CLIP_VALUE)
+
+     _GRADIENT_CLIP_TYPE_TO_CLIPPER = {
+         GradientClipType.VALUE: clip_grad_value,
+         GradientClipType.NORM: clip_grad_norm,
+     }
+     return _GRADIENT_CLIP_TYPE_TO_CLIPPER[GradientClipType(cfg.CLIP_TYPE)]
+
+
+ def _generate_optimizer_class_with_gradient_clipping(
+         optimizer: Type[torch.optim.Optimizer],
+         *,
+         per_param_clipper: Optional[_GradientClipper] = None,
+         global_clipper: Optional[_GradientClipper] = None,
+ ) -> Type[torch.optim.Optimizer]:
+     """
+     Dynamically creates a new type that inherits the type of a given instance
+     and overrides the `step` method to add gradient clipping
+     """
+     assert (
+             per_param_clipper is None or global_clipper is None
+     ), "Not allowed to use both per-parameter clipping and global clipping"
+
+     @torch.no_grad()
+     def optimizer_wgc_step(self, closure=None):
+         if per_param_clipper is not None:
+             for group in self.param_groups:
+                 for p in group["params"]:
+                     per_param_clipper(p)
+         else:
+             # global clipper for future use with detr
+             # (https://github.com/facebookresearch/detr/pull/287)
+             all_params = itertools.chain(*[g["params"] for g in self.param_groups])
+             global_clipper(all_params)
+         optimizer.step(self, closure)
+
+     OptimizerWithGradientClip = type(
+         optimizer.__name__ + "WithGradientClip",
+         (optimizer,),
+         {"step": optimizer_wgc_step},
+     )
+     return OptimizerWithGradientClip
+
+
+ def maybe_add_gradient_clipping(
+         cfg: CfgNode, optimizer: Type[torch.optim.Optimizer]
+ ) -> Type[torch.optim.Optimizer]:
+     """
+     If gradient clipping is enabled through config options, wraps the existing
+     optimizer type to become a new dynamically created class OptimizerWithGradientClip
+     that inherits the given optimizer and overrides the `step` method to
+     include gradient clipping.
+     Args:
+         cfg: CfgNode, configuration options
+         optimizer: type. A subclass of torch.optim.Optimizer
+     Return:
+         type: either the input `optimizer` (if gradient clipping is disabled), or
+         a subclass of it with gradient clipping included in the `step` method.
+     """
+     if not cfg.SOLVER.CLIP_GRADIENTS.ENABLED:
+         return optimizer
+     if isinstance(optimizer, torch.optim.Optimizer):
+         optimizer_type = type(optimizer)
+     else:
+         assert issubclass(optimizer, torch.optim.Optimizer), optimizer
+         optimizer_type = optimizer
+
+     grad_clipper = _create_gradient_clipper(cfg.SOLVER.CLIP_GRADIENTS)
+     OptimizerWithGradientClip = _generate_optimizer_class_with_gradient_clipping(
+         optimizer_type, per_param_clipper=grad_clipper
+     )
+     if isinstance(optimizer, torch.optim.Optimizer):
+         optimizer.__class__ = OptimizerWithGradientClip  # a bit hacky, not recommended
+         return optimizer
+     else:
+         return OptimizerWithGradientClip
+
+
+ def _generate_optimizer_class_with_freeze_layer(
+         optimizer: Type[torch.optim.Optimizer],
+         *,
+         freeze_iters: int = 0,
+ ) -> Type[torch.optim.Optimizer]:
+     assert freeze_iters > 0, "No layers need to be frozen or freeze iterations is 0"
+
+     cnt = 0
+     @torch.no_grad()
+     def optimizer_wfl_step(self, closure=None):
+         nonlocal cnt
+         if cnt < freeze_iters:
+             cnt += 1
+             param_ref = []
+             grad_ref = []
+             for group in self.param_groups:
+                 if group["freeze_status"] == "freeze":
+                     for p in group["params"]:
+                         if p.grad is not None:
+                             param_ref.append(p)
+                             grad_ref.append(p.grad)
+                             p.grad = None
+
+             optimizer.step(self, closure)
+             for p, g in zip(param_ref, grad_ref):
+                 p.grad = g
+         else:
+             optimizer.step(self, closure)
+
+     OptimizerWithFreezeLayer = type(
+         optimizer.__name__ + "WithFreezeLayer",
+         (optimizer,),
+         {"step": optimizer_wfl_step},
+     )
+     return OptimizerWithFreezeLayer
+
+
+ def maybe_add_freeze_layer(
+         cfg: CfgNode, optimizer: Type[torch.optim.Optimizer]
+ ) -> Type[torch.optim.Optimizer]:
+     if len(cfg.MODEL.FREEZE_LAYERS) == 0 or cfg.SOLVER.FREEZE_ITERS <= 0:
+         return optimizer
+
+     if isinstance(optimizer, torch.optim.Optimizer):
+         optimizer_type = type(optimizer)
+     else:
+         assert issubclass(optimizer, torch.optim.Optimizer), optimizer
+         optimizer_type = optimizer
+
+     OptimizerWithFreezeLayer = _generate_optimizer_class_with_freeze_layer(
+         optimizer_type,
+         freeze_iters=cfg.SOLVER.FREEZE_ITERS
+     )
+     if isinstance(optimizer, torch.optim.Optimizer):
+         optimizer.__class__ = OptimizerWithFreezeLayer  # a bit hacky, not recommended
+         return optimizer
+     else:
+         return OptimizerWithFreezeLayer
+
+
+ def build_optimizer(cfg, model, contiguous=True):
+     params = get_default_optimizer_params(
+         model,
+         base_lr=cfg.SOLVER.BASE_LR,
+         weight_decay=cfg.SOLVER.WEIGHT_DECAY,
+         weight_decay_norm=cfg.SOLVER.WEIGHT_DECAY_NORM,
+         bias_lr_factor=cfg.SOLVER.BIAS_LR_FACTOR,
+         heads_lr_factor=cfg.SOLVER.HEADS_LR_FACTOR,
+         weight_decay_bias=cfg.SOLVER.WEIGHT_DECAY_BIAS,
+         freeze_layers=cfg.MODEL.FREEZE_LAYERS if cfg.SOLVER.FREEZE_ITERS > 0 else [],
+     )
+
+     if contiguous:
+         params = ContiguousParams(params)
+     solver_opt = cfg.SOLVER.OPT
+     if solver_opt == "SGD":
+         return maybe_add_freeze_layer(
+             cfg,
+             maybe_add_gradient_clipping(cfg, torch.optim.SGD)
+         )(
+             params.contiguous() if contiguous else params,
+             momentum=cfg.SOLVER.MOMENTUM,
+             nesterov=cfg.SOLVER.NESTEROV,
+         ), params
+     else:
+         return maybe_add_freeze_layer(
+             cfg,
+             maybe_add_gradient_clipping(cfg, getattr(torch.optim, solver_opt))
+         )(params.contiguous() if contiguous else params), params
+
+
+ def get_default_optimizer_params(
+         model: torch.nn.Module,
+         base_lr: Optional[float] = None,
+         weight_decay: Optional[float] = None,
+         weight_decay_norm: Optional[float] = None,
+         bias_lr_factor: Optional[float] = 1.0,
+         heads_lr_factor: Optional[float] = 1.0,
+         weight_decay_bias: Optional[float] = None,
+         overrides: Optional[Dict[str, Dict[str, float]]] = None,
+         freeze_layers: Optional[list] = [],
+ ):
+     """
+     Get default param list for optimizer, with support for a few types of
+     overrides. If no overrides needed, this is equivalent to `model.parameters()`.
+     Args:
+         base_lr: lr for every group by default. Can be omitted to use the one in optimizer.
+         weight_decay: weight decay for every group by default. Can be omitted to use the one
+             in optimizer.
+         weight_decay_norm: override weight decay for params in normalization layers
+         bias_lr_factor: multiplier of lr for bias parameters.
+         heads_lr_factor: multiplier of lr for model.head parameters.
+         weight_decay_bias: override weight decay for bias parameters
+         overrides: if not `None`, provides values for optimizer hyperparameters
+             (LR, weight decay) for module parameters with a given name; e.g.
+             ``{"embedding": {"lr": 0.01, "weight_decay": 0.1}}`` will set the LR and
+             weight decay values for all module parameters named `embedding`.
+         freeze_layers: layer names for freezing.
+     For common detection models, ``weight_decay_norm`` is the only option
+     needed to be set. ``bias_lr_factor,weight_decay_bias`` are legacy settings
+     from Detectron1 that are not found useful.
+     Example:
+     ::
+         torch.optim.SGD(get_default_optimizer_params(model, weight_decay_norm=0),
+                         lr=0.01, weight_decay=1e-4, momentum=0.9)
+     """
+     if overrides is None:
+         overrides = {}
+     defaults = {}
+     if base_lr is not None:
+         defaults["lr"] = base_lr
+     if weight_decay is not None:
+         defaults["weight_decay"] = weight_decay
+     bias_overrides = {}
+     if bias_lr_factor is not None and bias_lr_factor != 1.0:
+         # NOTE: unlike Detectron v1, we now by default make bias hyperparameters
+         # exactly the same as regular weights.
+         if base_lr is None:
+             raise ValueError("bias_lr_factor requires base_lr")
+         bias_overrides["lr"] = base_lr * bias_lr_factor
+     if weight_decay_bias is not None:
+         bias_overrides["weight_decay"] = weight_decay_bias
+     if len(bias_overrides):
+         if "bias" in overrides:
+             raise ValueError("Conflicting overrides for 'bias'")
+         overrides["bias"] = bias_overrides
+
+     layer_names_pattern = [re.compile(name) for name in freeze_layers]
+
+     norm_module_types = (
+         torch.nn.BatchNorm1d,
+         torch.nn.BatchNorm2d,
+         torch.nn.BatchNorm3d,
+         torch.nn.SyncBatchNorm,
+         # NaiveSyncBatchNorm inherits from BatchNorm2d
+         torch.nn.GroupNorm,
+         torch.nn.InstanceNorm1d,
+         torch.nn.InstanceNorm2d,
+         torch.nn.InstanceNorm3d,
+         torch.nn.LayerNorm,
+         torch.nn.LocalResponseNorm,
+     )
+     params: List[Dict[str, Any]] = []
+     memo: Set[torch.nn.parameter.Parameter] = set()
+
+     for module_name, module in model.named_modules():
+         for module_param_name, value in module.named_parameters(recurse=False):
+             if not value.requires_grad:
+                 continue
+             # Avoid duplicating parameters
+             if value in memo:
+                 continue
+             memo.add(value)
+
+             hyperparams = copy.copy(defaults)
+             if isinstance(module, norm_module_types) and weight_decay_norm is not None:
+                 hyperparams["weight_decay"] = weight_decay_norm
+             hyperparams.update(overrides.get(module_param_name, {}))
+             if module_name.split('.')[0] == "heads" and (heads_lr_factor is not None and heads_lr_factor != 1.0):
+                 hyperparams["lr"] = hyperparams.get("lr", base_lr) * heads_lr_factor
+             name = module_name + '.' + module_param_name
+             freeze_status = "normal"
+             # Search freeze layer names, it must match from beginning, so use `match` not `search`
+             for pattern in layer_names_pattern:
+                 if pattern.match(name) is not None:
+                     freeze_status = "freeze"
+                     break
+
+             params.append({"freeze_status": freeze_status, "params": [value], **hyperparams})
+     return params
+
+
+ def build_lr_scheduler(cfg, optimizer, iters_per_epoch):
+     max_epoch = cfg.SOLVER.MAX_EPOCH - max(
+         math.ceil(cfg.SOLVER.WARMUP_ITERS / iters_per_epoch), cfg.SOLVER.DELAY_EPOCHS)
+
+     scheduler_dict = {}
+
+     scheduler_args = {
+         "MultiStepLR": {
+             "optimizer": optimizer,
+             # multi-step lr scheduler options
+             "milestones": cfg.SOLVER.STEPS,
+             "gamma": cfg.SOLVER.GAMMA,
+         },
+         "CosineAnnealingLR": {
+             "optimizer": optimizer,
+             # cosine annealing lr scheduler options
+             "T_max": max_epoch,
+             "eta_min": cfg.SOLVER.ETA_MIN_LR,
+         },
+
+     }
+
+     scheduler_dict["lr_sched"] = getattr(lr_scheduler, cfg.SOLVER.SCHED)(
+         **scheduler_args[cfg.SOLVER.SCHED])
+
+     if cfg.SOLVER.WARMUP_ITERS > 0:
+         warmup_args = {
+             "optimizer": optimizer,
+
+             # warmup options
+             "warmup_factor": cfg.SOLVER.WARMUP_FACTOR,
+             "warmup_iters": cfg.SOLVER.WARMUP_ITERS,
+             "warmup_method": cfg.SOLVER.WARMUP_METHOD,
+         }
+         scheduler_dict["warmup_sched"] = lr_scheduler.WarmupLR(**warmup_args)
+
+     return scheduler_dict
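Note on the hunk above: maybe_add_gradient_clipping and maybe_add_freeze_layer both wrap an optimizer class by dynamically creating a subclass via type() whose step runs extra logic before delegating to the base class. A minimal self-contained sketch of that wrapping technique, assuming plain PyTorch rather than the CfgNode plumbing above (make_clipped_optimizer is an illustrative name, not part of the package):

    import torch

    def make_clipped_optimizer(opt_cls, clip_value=1.0):
        # Same trick as _generate_optimizer_class_with_gradient_clipping:
        # clip each parameter group's gradients, then delegate to the base step.
        @torch.no_grad()
        def step(self, closure=None):
            for group in self.param_groups:
                torch.nn.utils.clip_grad_value_(group["params"], clip_value)
            opt_cls.step(self, closure)
        return type(opt_cls.__name__ + "WithGradientClip", (opt_cls,), {"step": step})

    model = torch.nn.Linear(4, 2)
    opt = make_clipped_optimizer(torch.optim.SGD, clip_value=0.5)(model.parameters(), lr=0.1)
    model(torch.randn(8, 4)).sum().backward()
    opt.step()  # gradients are clipped to [-0.5, 0.5] before the SGD update

Note also that build_optimizer returns a pair (optimizer, params), apparently so the caller can keep the ContiguousParams handle when contiguous=True.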
dnt/third_party/fast-reid/fastreid/solver/lr_scheduler.py
@@ -0,0 +1,66 @@
+ # encoding: utf-8
+ """
+ @author: liaoxingyu
+ @contact: sherlockliao01@gmail.com
+ """
+
+ from typing import List
+
+ import torch
+ from torch.optim.lr_scheduler import *
+
+
+ class WarmupLR(torch.optim.lr_scheduler._LRScheduler):
+     def __init__(
+             self,
+             optimizer: torch.optim.Optimizer,
+             warmup_factor: float = 0.1,
+             warmup_iters: int = 1000,
+             warmup_method: str = "linear",
+             last_epoch: int = -1,
+     ):
+         self.warmup_factor = warmup_factor
+         self.warmup_iters = warmup_iters
+         self.warmup_method = warmup_method
+         super().__init__(optimizer, last_epoch)
+
+     def get_lr(self) -> List[float]:
+         warmup_factor = _get_warmup_factor_at_epoch(
+             self.warmup_method, self.last_epoch, self.warmup_iters, self.warmup_factor
+         )
+         return [
+             base_lr * warmup_factor for base_lr in self.base_lrs
+         ]
+
+     def _compute_values(self) -> List[float]:
+         # The new interface
+         return self.get_lr()
+
+
+ def _get_warmup_factor_at_epoch(
+         method: str, iter: int, warmup_iters: int, warmup_factor: float
+ ) -> float:
+     """
+     Return the learning rate warmup factor at a specific iteration.
+     See https://arxiv.org/abs/1706.02677 for more details.
+     Args:
+         method (str): warmup method; either "constant" or "linear".
+         iter (int): iter at which to calculate the warmup factor.
+         warmup_iters (int): the number of warmup epochs.
+         warmup_factor (float): the base warmup factor (the meaning changes according
+             to the method used).
+     Returns:
+         float: the effective warmup factor at the given iteration.
+     """
+     if iter >= warmup_iters:
+         return 1.0
+
+     if method == "constant":
+         return warmup_factor
+     elif method == "linear":
+         alpha = iter / warmup_iters
+         return warmup_factor * (1 - alpha) + alpha
+     elif method == "exp":
+         return warmup_factor ** (1 - iter / warmup_iters)
+     else:
+         raise ValueError("Unknown warmup method: {}".format(method))
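WarmupLR multiplies every base LR by a warmup factor; with the "linear" method the factor is warmup_factor * (1 - alpha) + alpha for alpha = iter / warmup_iters, ramping from warmup_factor up to 1. A minimal usage sketch, assuming the vendored module is importable under the path shown in this diff:

    import torch
    from fastreid.solver.lr_scheduler import WarmupLR  # import path assumed from this diff

    model = torch.nn.Linear(4, 2)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    warmup = WarmupLR(opt, warmup_factor=0.1, warmup_iters=5, warmup_method="linear")
    for it in range(8):
        print(it, opt.param_groups[0]["lr"])  # ramps 0.01 -> 0.1 over the first 5 iters
        opt.step()
        warmup.step()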
dnt/third_party/fast-reid/fastreid/solver/optim/__init__.py
@@ -0,0 +1,10 @@
+ # encoding: utf-8
+ """
+ @author: xingyu liao
+ @contact: sherlockliao01@gmail.com
+ """
+
+ from .lamb import Lamb
+ from .swa import SWA
+ from .radam import RAdam
+ from torch.optim import *
+ from torch.optim import *
@@ -0,0 +1,123 @@
1
+ ####
2
+ # CODE TAKEN FROM https://github.com/mgrankin/over9000
3
+ ####
4
+
5
+ import collections
6
+
7
+ import torch
8
+ from torch.optim.optimizer import Optimizer
9
+ from torch.utils.tensorboard import SummaryWriter
10
+
11
+
12
+ def log_lamb_rs(optimizer: Optimizer, event_writer: SummaryWriter, token_count: int):
13
+ """Log a histogram of trust ratio scalars in across layers."""
14
+ results = collections.defaultdict(list)
15
+ for group in optimizer.param_groups:
16
+ for p in group['params']:
17
+ state = optimizer.state[p]
18
+ for i in ('weight_norm', 'adam_norm', 'trust_ratio'):
19
+ if i in state:
20
+ results[i].append(state[i])
21
+
22
+ for k, v in results.items():
23
+ event_writer.add_histogram(f'lamb/{k}', torch.tensor(v), token_count)
24
+
25
+
26
+ class Lamb(Optimizer):
27
+ r"""Implements Lamb algorithm.
28
+ It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.
29
+ Arguments:
30
+ params (iterable): iterable of parameters to optimize or dicts defining
31
+ parameter groups
32
+ lr (float, optional): learning rate (default: 1e-3)
33
+ betas (Tuple[float, float], optional): coefficients used for computing
34
+ running averages of gradient and its square (default: (0.9, 0.999))
35
+ eps (float, optional): term added to the denominator to improve
36
+ numerical stability (default: 1e-8)
37
+ weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
38
+ adam (bool, optional): always use trust ratio = 1, which turns this into
39
+ Adam. Useful for comparison purposes.
40
+ .. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes:
41
+ https://arxiv.org/abs/1904.00962
42
+ """
43
+
44
+ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6,
45
+ weight_decay=0, adam=False):
46
+ if not 0.0 <= lr:
47
+ raise ValueError("Invalid learning rate: {}".format(lr))
48
+ if not 0.0 <= eps:
49
+ raise ValueError("Invalid epsilon value: {}".format(eps))
50
+ if not 0.0 <= betas[0] < 1.0:
51
+ raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
52
+ if not 0.0 <= betas[1] < 1.0:
53
+ raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
54
+ defaults = dict(lr=lr, betas=betas, eps=eps,
55
+ weight_decay=weight_decay)
56
+ self.adam = adam
57
+ super(Lamb, self).__init__(params, defaults)
58
+
59
+ def step(self, closure=None):
60
+ """Performs a single optimization step.
61
+ Arguments:
62
+ closure (callable, optional): A closure that reevaluates the model
63
+ and returns the loss.
64
+ """
65
+ loss = None
66
+ if closure is not None:
67
+ loss = closure()
68
+
69
+ for group in self.param_groups:
70
+ for p in group['params']:
71
+ if p.grad is None:
72
+ continue
73
+ grad = p.grad.data
74
+ if grad.is_sparse:
75
+ raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instad.')
76
+
77
+ state = self.state[p]
78
+
79
+ # State initialization
80
+ if len(state) == 0:
81
+ state['step'] = 0
82
+ # Exponential moving average of gradient values
83
+ state['exp_avg'] = torch.zeros_like(p.data)
84
+ # Exponential moving average of squared gradient values
85
+ state['exp_avg_sq'] = torch.zeros_like(p.data)
86
+
87
+ exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
88
+ beta1, beta2 = group['betas']
89
+
90
+ state['step'] += 1
91
+
92
+ # Decay the first and second moment running average coefficient
93
+ # m_t
94
+ exp_avg.mul_(beta1).add_(1 - beta1, grad)
95
+ # v_t
96
+ exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
97
+
98
+ # Paper v3 does not use debiasing.
99
+ # bias_correction1 = 1 - beta1 ** state['step']
100
+ # bias_correction2 = 1 - beta2 ** state['step']
101
+ # Apply bias to lr to avoid broadcast.
102
+ step_size = group['lr'] # * math.sqrt(bias_correction2) / bias_correction1
103
+
104
+ weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10)
105
+
106
+ adam_step = exp_avg / exp_avg_sq.sqrt().add(group['eps'])
107
+ if group['weight_decay'] != 0:
108
+ adam_step.add_(group['weight_decay'], p.data)
109
+
110
+ adam_norm = adam_step.pow(2).sum().sqrt()
111
+ if weight_norm == 0 or adam_norm == 0:
112
+ trust_ratio = 1
113
+ else:
114
+ trust_ratio = weight_norm / adam_norm
115
+ state['weight_norm'] = weight_norm
116
+ state['adam_norm'] = adam_norm
117
+ state['trust_ratio'] = trust_ratio
118
+ if self.adam:
119
+ trust_ratio = 1
120
+
121
+ p.data.add_(-step_size * trust_ratio, adam_step)
122
+
123
+ return loss
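Lamb rescales the Adam-style step per parameter by a trust ratio weight_norm / adam_norm (with weight_norm clamped to [0, 10]), following arXiv:1904.00962. A minimal usage sketch; note the vendored code uses the legacy add_(alpha, tensor) overload and imports SummaryWriter at module level, so it assumes an older torch with tensorboard installed:

    import torch
    from fastreid.solver.optim import Lamb  # import path assumed from this diff

    model = torch.nn.Linear(4, 2)
    opt = Lamb(model.parameters(), lr=1e-3, weight_decay=0.01)
    model(torch.randn(8, 4)).sum().backward()
    opt.step()
    print(opt.state[model.weight]["trust_ratio"])  # per-parameter trust ratio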
dnt/third_party/fast-reid/fastreid/solver/optim/radam.py
@@ -0,0 +1,149 @@
+ import math
+
+ import torch
+ from torch.optim.optimizer import Optimizer
+
+
+ class RAdam(Optimizer):
+
+     def __init__(self, params, lr=0.001, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
+         defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
+         self.buffer = [[None, None, None] for ind in range(10)]
+         super(RAdam, self).__init__(params, defaults)
+
+     def __setstate__(self, state):
+         super(RAdam, self).__setstate__(state)
+
+     def step(self, closure=None):
+
+         loss = None
+         if closure is not None:
+             loss = closure()
+
+         for group in self.param_groups:
+
+             for p in group['params']:
+                 if p.grad is None:
+                     continue
+                 grad = p.grad.data.float()
+                 if grad.is_sparse:
+                     raise RuntimeError('RAdam does not support sparse gradients')
+
+                 p_data_fp32 = p.data.float()
+
+                 state = self.state[p]
+
+                 if len(state) == 0:
+                     state['step'] = 0
+                     state['exp_avg'] = torch.zeros_like(p_data_fp32)
+                     state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
+                 else:
+                     state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
+                     state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
+
+                 exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+                 beta1, beta2 = group['betas']
+
+                 exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
+                 exp_avg.mul_(beta1).add_(1 - beta1, grad)
+
+                 state['step'] += 1
+                 buffered = self.buffer[int(state['step'] % 10)]
+                 if state['step'] == buffered[0]:
+                     N_sma, step_size = buffered[1], buffered[2]
+                 else:
+                     buffered[0] = state['step']
+                     beta2_t = beta2 ** state['step']
+                     N_sma_max = 2 / (1 - beta2) - 1
+                     N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
+                     buffered[1] = N_sma
+
+                     # more conservative since it's an approximated value
+                     if N_sma >= 5:
+                         step_size = group['lr'] * math.sqrt(
+                             (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (
+                                     N_sma_max - 2)) / (1 - beta1 ** state['step'])
+                     else:
+                         step_size = group['lr'] / (1 - beta1 ** state['step'])
+                     buffered[2] = step_size
+
+                 if group['weight_decay'] != 0:
+                     p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
+
+                 # more conservative since it's an approximated value
+                 if N_sma >= 5:
+                     denom = exp_avg_sq.sqrt().add_(group['eps'])
+                     p_data_fp32.addcdiv_(-step_size, exp_avg, denom)
+                 else:
+                     p_data_fp32.add_(-step_size, exp_avg)
+
+                 p.data.copy_(p_data_fp32)
+
+         return loss
+
+
+ class PlainRAdam(Optimizer):
+
+     def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
+         defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
+
+         super(PlainRAdam, self).__init__(params, defaults)
+
+     def __setstate__(self, state):
+         super(PlainRAdam, self).__setstate__(state)
+
+     def step(self, closure=None):
+
+         loss = None
+         if closure is not None:
+             loss = closure()
+
+         for group in self.param_groups:
+
+             for p in group['params']:
+                 if p.grad is None:
+                     continue
+                 grad = p.grad.data.float()
+                 if grad.is_sparse:
+                     raise RuntimeError('RAdam does not support sparse gradients')
+
+                 p_data_fp32 = p.data.float()
+
+                 state = self.state[p]
+
+                 if len(state) == 0:
+                     state['step'] = 0
+                     state['exp_avg'] = torch.zeros_like(p_data_fp32)
+                     state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
+                 else:
+                     state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
+                     state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
+
+                 exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+                 beta1, beta2 = group['betas']
+
+                 exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
+                 exp_avg.mul_(beta1).add_(1 - beta1, grad)
+
+                 state['step'] += 1
+                 beta2_t = beta2 ** state['step']
+                 N_sma_max = 2 / (1 - beta2) - 1
+                 N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
+
+                 if group['weight_decay'] != 0:
+                     p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
+
+                 # more conservative since it's an approximated value
+                 if N_sma >= 5:
+                     step_size = group['lr'] * math.sqrt(
+                         (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (
+                                 N_sma_max - 2)) / (1 - beta1 ** state['step'])
+                     denom = exp_avg_sq.sqrt().add_(group['eps'])
+                     p_data_fp32.addcdiv_(-step_size, exp_avg, denom)
+                 else:
+                     step_size = group['lr'] / (1 - beta1 ** state['step'])
+                     p_data_fp32.add_(-step_size, exp_avg)
+
+                 p.data.copy_(p_data_fp32)
+
+         return loss
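RAdam implements the variance rectification of arXiv:1908.03265: it estimates the length of the approximated simple moving average, N_sma = N_sma_max - 2*t*beta2^t / (1 - beta2^t), and only applies the adaptive (Adam-style) denominator once N_sma >= 5, falling back to a bias-corrected momentum step in the early iterations. A small sketch mirroring the step-size arithmetic above:

    import math

    beta1, beta2, lr = 0.9, 0.999, 1e-3
    n_sma_max = 2 / (1 - beta2) - 1
    for t in (1, 4, 5, 100, 1000):
        beta2_t = beta2 ** t
        n_sma = n_sma_max - 2 * t * beta2_t / (1 - beta2_t)
        if n_sma >= 5:  # rectified, variance-adapted step
            rect = math.sqrt((1 - beta2_t) * (n_sma - 4) / (n_sma_max - 4)
                             * (n_sma - 2) / n_sma * n_sma_max / (n_sma_max - 2))
            step = lr * rect / (1 - beta1 ** t)
        else:  # early iterations: plain momentum step
            step = lr / (1 - beta1 ** t)
        print(t, round(n_sma, 2), step)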