dnt 0.2.1__py3-none-any.whl → 0.3.1.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (315)
  1. dnt/__init__.py +4 -1
  2. dnt/analysis/__init__.py +3 -1
  3. dnt/analysis/count.py +107 -0
  4. dnt/analysis/interaction2.py +518 -0
  5. dnt/analysis/position.py +12 -0
  6. dnt/analysis/stop.py +92 -33
  7. dnt/analysis/stop2.py +289 -0
  8. dnt/analysis/stop3.py +758 -0
  9. dnt/detect/__init__.py +1 -1
  10. dnt/detect/signal/detector.py +326 -0
  11. dnt/detect/timestamp.py +105 -0
  12. dnt/detect/yolov8/detector.py +182 -35
  13. dnt/detect/yolov8/segmentor.py +171 -0
  14. dnt/engine/__init__.py +8 -0
  15. dnt/engine/bbox_interp.py +83 -0
  16. dnt/engine/bbox_iou.py +20 -0
  17. dnt/engine/cluster.py +31 -0
  18. dnt/engine/iob.py +66 -0
  19. dnt/filter/__init__.py +4 -0
  20. dnt/filter/filter.py +450 -21
  21. dnt/label/__init__.py +1 -1
  22. dnt/label/labeler.py +215 -14
  23. dnt/label/labeler2.py +631 -0
  24. dnt/shared/__init__.py +2 -1
  25. dnt/shared/data/coco.names +0 -0
  26. dnt/shared/data/openimages.names +0 -0
  27. dnt/shared/data/voc.names +0 -0
  28. dnt/shared/download.py +12 -0
  29. dnt/shared/synhcro.py +150 -0
  30. dnt/shared/util.py +17 -4
  31. dnt/third_party/fast-reid/__init__.py +1 -0
  32. dnt/third_party/fast-reid/configs/Base-AGW.yml +19 -0
  33. dnt/third_party/fast-reid/configs/Base-MGN.yml +12 -0
  34. dnt/third_party/fast-reid/configs/Base-SBS.yml +63 -0
  35. dnt/third_party/fast-reid/configs/Base-bagtricks.yml +76 -0
  36. dnt/third_party/fast-reid/configs/DukeMTMC/AGW_R101-ibn.yml +12 -0
  37. dnt/third_party/fast-reid/configs/DukeMTMC/AGW_R50-ibn.yml +11 -0
  38. dnt/third_party/fast-reid/configs/DukeMTMC/AGW_R50.yml +7 -0
  39. dnt/third_party/fast-reid/configs/DukeMTMC/AGW_S50.yml +11 -0
  40. dnt/third_party/fast-reid/configs/DukeMTMC/bagtricks_R101-ibn.yml +12 -0
  41. dnt/third_party/fast-reid/configs/DukeMTMC/bagtricks_R50-ibn.yml +11 -0
  42. dnt/third_party/fast-reid/configs/DukeMTMC/bagtricks_R50.yml +7 -0
  43. dnt/third_party/fast-reid/configs/DukeMTMC/bagtricks_S50.yml +11 -0
  44. dnt/third_party/fast-reid/configs/DukeMTMC/mgn_R50-ibn.yml +11 -0
  45. dnt/third_party/fast-reid/configs/DukeMTMC/sbs_R101-ibn.yml +12 -0
  46. dnt/third_party/fast-reid/configs/DukeMTMC/sbs_R50-ibn.yml +11 -0
  47. dnt/third_party/fast-reid/configs/DukeMTMC/sbs_R50.yml +7 -0
  48. dnt/third_party/fast-reid/configs/DukeMTMC/sbs_S50.yml +11 -0
  49. dnt/third_party/fast-reid/configs/MOT17/AGW_R101-ibn.yml +12 -0
  50. dnt/third_party/fast-reid/configs/MOT17/AGW_R50-ibn.yml +11 -0
  51. dnt/third_party/fast-reid/configs/MOT17/AGW_R50.yml +7 -0
  52. dnt/third_party/fast-reid/configs/MOT17/AGW_S50.yml +11 -0
  53. dnt/third_party/fast-reid/configs/MOT17/bagtricks_R101-ibn.yml +12 -0
  54. dnt/third_party/fast-reid/configs/MOT17/bagtricks_R50-ibn.yml +11 -0
  55. dnt/third_party/fast-reid/configs/MOT17/bagtricks_R50.yml +7 -0
  56. dnt/third_party/fast-reid/configs/MOT17/bagtricks_S50.yml +11 -0
  57. dnt/third_party/fast-reid/configs/MOT17/mgn_R50-ibn.yml +11 -0
  58. dnt/third_party/fast-reid/configs/MOT17/sbs_R101-ibn.yml +12 -0
  59. dnt/third_party/fast-reid/configs/MOT17/sbs_R50-ibn.yml +11 -0
  60. dnt/third_party/fast-reid/configs/MOT17/sbs_R50.yml +7 -0
  61. dnt/third_party/fast-reid/configs/MOT17/sbs_S50.yml +11 -0
  62. dnt/third_party/fast-reid/configs/MOT20/AGW_R101-ibn.yml +12 -0
  63. dnt/third_party/fast-reid/configs/MOT20/AGW_R50-ibn.yml +11 -0
  64. dnt/third_party/fast-reid/configs/MOT20/AGW_R50.yml +7 -0
  65. dnt/third_party/fast-reid/configs/MOT20/AGW_S50.yml +11 -0
  66. dnt/third_party/fast-reid/configs/MOT20/bagtricks_R101-ibn.yml +12 -0
  67. dnt/third_party/fast-reid/configs/MOT20/bagtricks_R50-ibn.yml +11 -0
  68. dnt/third_party/fast-reid/configs/MOT20/bagtricks_R50.yml +7 -0
  69. dnt/third_party/fast-reid/configs/MOT20/bagtricks_S50.yml +11 -0
  70. dnt/third_party/fast-reid/configs/MOT20/mgn_R50-ibn.yml +11 -0
  71. dnt/third_party/fast-reid/configs/MOT20/sbs_R101-ibn.yml +12 -0
  72. dnt/third_party/fast-reid/configs/MOT20/sbs_R50-ibn.yml +11 -0
  73. dnt/third_party/fast-reid/configs/MOT20/sbs_R50.yml +7 -0
  74. dnt/third_party/fast-reid/configs/MOT20/sbs_S50.yml +11 -0
  75. dnt/third_party/fast-reid/configs/MSMT17/AGW_R101-ibn.yml +12 -0
  76. dnt/third_party/fast-reid/configs/MSMT17/AGW_R50-ibn.yml +11 -0
  77. dnt/third_party/fast-reid/configs/MSMT17/AGW_R50.yml +7 -0
  78. dnt/third_party/fast-reid/configs/MSMT17/AGW_S50.yml +11 -0
  79. dnt/third_party/fast-reid/configs/MSMT17/bagtricks_R101-ibn.yml +13 -0
  80. dnt/third_party/fast-reid/configs/MSMT17/bagtricks_R50-ibn.yml +12 -0
  81. dnt/third_party/fast-reid/configs/MSMT17/bagtricks_R50.yml +7 -0
  82. dnt/third_party/fast-reid/configs/MSMT17/bagtricks_S50.yml +12 -0
  83. dnt/third_party/fast-reid/configs/MSMT17/mgn_R50-ibn.yml +11 -0
  84. dnt/third_party/fast-reid/configs/MSMT17/sbs_R101-ibn.yml +12 -0
  85. dnt/third_party/fast-reid/configs/MSMT17/sbs_R50-ibn.yml +11 -0
  86. dnt/third_party/fast-reid/configs/MSMT17/sbs_R50.yml +7 -0
  87. dnt/third_party/fast-reid/configs/MSMT17/sbs_S50.yml +11 -0
  88. dnt/third_party/fast-reid/configs/Market1501/AGW_R101-ibn.yml +12 -0
  89. dnt/third_party/fast-reid/configs/Market1501/AGW_R50-ibn.yml +11 -0
  90. dnt/third_party/fast-reid/configs/Market1501/AGW_R50.yml +7 -0
  91. dnt/third_party/fast-reid/configs/Market1501/AGW_S50.yml +11 -0
  92. dnt/third_party/fast-reid/configs/Market1501/bagtricks_R101-ibn.yml +12 -0
  93. dnt/third_party/fast-reid/configs/Market1501/bagtricks_R50-ibn.yml +11 -0
  94. dnt/third_party/fast-reid/configs/Market1501/bagtricks_R50.yml +7 -0
  95. dnt/third_party/fast-reid/configs/Market1501/bagtricks_S50.yml +11 -0
  96. dnt/third_party/fast-reid/configs/Market1501/bagtricks_vit.yml +88 -0
  97. dnt/third_party/fast-reid/configs/Market1501/mgn_R50-ibn.yml +11 -0
  98. dnt/third_party/fast-reid/configs/Market1501/sbs_R101-ibn.yml +12 -0
  99. dnt/third_party/fast-reid/configs/Market1501/sbs_R50-ibn.yml +11 -0
  100. dnt/third_party/fast-reid/configs/Market1501/sbs_R50.yml +7 -0
  101. dnt/third_party/fast-reid/configs/Market1501/sbs_S50.yml +11 -0
  102. dnt/third_party/fast-reid/configs/VERIWild/bagtricks_R50-ibn.yml +35 -0
  103. dnt/third_party/fast-reid/configs/VeRi/sbs_R50-ibn.yml +35 -0
  104. dnt/third_party/fast-reid/configs/VehicleID/bagtricks_R50-ibn.yml +36 -0
  105. dnt/third_party/fast-reid/configs/__init__.py +0 -0
  106. dnt/third_party/fast-reid/fast_reid_interfece.py +175 -0
  107. dnt/third_party/fast-reid/fastreid/__init__.py +6 -0
  108. dnt/third_party/fast-reid/fastreid/config/__init__.py +15 -0
  109. dnt/third_party/fast-reid/fastreid/config/config.py +319 -0
  110. dnt/third_party/fast-reid/fastreid/config/defaults.py +329 -0
  111. dnt/third_party/fast-reid/fastreid/data/__init__.py +17 -0
  112. dnt/third_party/fast-reid/fastreid/data/build.py +194 -0
  113. dnt/third_party/fast-reid/fastreid/data/common.py +58 -0
  114. dnt/third_party/fast-reid/fastreid/data/data_utils.py +202 -0
  115. dnt/third_party/fast-reid/fastreid/data/datasets/AirportALERT.py +50 -0
  116. dnt/third_party/fast-reid/fastreid/data/datasets/__init__.py +43 -0
  117. dnt/third_party/fast-reid/fastreid/data/datasets/bases.py +183 -0
  118. dnt/third_party/fast-reid/fastreid/data/datasets/caviara.py +44 -0
  119. dnt/third_party/fast-reid/fastreid/data/datasets/cuhk03.py +274 -0
  120. dnt/third_party/fast-reid/fastreid/data/datasets/cuhk_sysu.py +58 -0
  121. dnt/third_party/fast-reid/fastreid/data/datasets/dukemtmcreid.py +70 -0
  122. dnt/third_party/fast-reid/fastreid/data/datasets/grid.py +44 -0
  123. dnt/third_party/fast-reid/fastreid/data/datasets/iLIDS.py +45 -0
  124. dnt/third_party/fast-reid/fastreid/data/datasets/lpw.py +49 -0
  125. dnt/third_party/fast-reid/fastreid/data/datasets/market1501.py +89 -0
  126. dnt/third_party/fast-reid/fastreid/data/datasets/msmt17.py +114 -0
  127. dnt/third_party/fast-reid/fastreid/data/datasets/pes3d.py +44 -0
  128. dnt/third_party/fast-reid/fastreid/data/datasets/pku.py +44 -0
  129. dnt/third_party/fast-reid/fastreid/data/datasets/prai.py +43 -0
  130. dnt/third_party/fast-reid/fastreid/data/datasets/prid.py +41 -0
  131. dnt/third_party/fast-reid/fastreid/data/datasets/saivt.py +47 -0
  132. dnt/third_party/fast-reid/fastreid/data/datasets/sensereid.py +47 -0
  133. dnt/third_party/fast-reid/fastreid/data/datasets/shinpuhkan.py +48 -0
  134. dnt/third_party/fast-reid/fastreid/data/datasets/sysu_mm.py +47 -0
  135. dnt/third_party/fast-reid/fastreid/data/datasets/thermalworld.py +43 -0
  136. dnt/third_party/fast-reid/fastreid/data/datasets/vehicleid.py +126 -0
  137. dnt/third_party/fast-reid/fastreid/data/datasets/veri.py +69 -0
  138. dnt/third_party/fast-reid/fastreid/data/datasets/veriwild.py +140 -0
  139. dnt/third_party/fast-reid/fastreid/data/datasets/viper.py +45 -0
  140. dnt/third_party/fast-reid/fastreid/data/datasets/wildtracker.py +59 -0
  141. dnt/third_party/fast-reid/fastreid/data/samplers/__init__.py +18 -0
  142. dnt/third_party/fast-reid/fastreid/data/samplers/data_sampler.py +85 -0
  143. dnt/third_party/fast-reid/fastreid/data/samplers/imbalance_sampler.py +67 -0
  144. dnt/third_party/fast-reid/fastreid/data/samplers/triplet_sampler.py +260 -0
  145. dnt/third_party/fast-reid/fastreid/data/transforms/__init__.py +11 -0
  146. dnt/third_party/fast-reid/fastreid/data/transforms/autoaugment.py +806 -0
  147. dnt/third_party/fast-reid/fastreid/data/transforms/build.py +100 -0
  148. dnt/third_party/fast-reid/fastreid/data/transforms/functional.py +180 -0
  149. dnt/third_party/fast-reid/fastreid/data/transforms/transforms.py +161 -0
  150. dnt/third_party/fast-reid/fastreid/engine/__init__.py +15 -0
  151. dnt/third_party/fast-reid/fastreid/engine/defaults.py +490 -0
  152. dnt/third_party/fast-reid/fastreid/engine/hooks.py +534 -0
  153. dnt/third_party/fast-reid/fastreid/engine/launch.py +103 -0
  154. dnt/third_party/fast-reid/fastreid/engine/train_loop.py +357 -0
  155. dnt/third_party/fast-reid/fastreid/evaluation/__init__.py +6 -0
  156. dnt/third_party/fast-reid/fastreid/evaluation/clas_evaluator.py +81 -0
  157. dnt/third_party/fast-reid/fastreid/evaluation/evaluator.py +176 -0
  158. dnt/third_party/fast-reid/fastreid/evaluation/query_expansion.py +46 -0
  159. dnt/third_party/fast-reid/fastreid/evaluation/rank.py +200 -0
  160. dnt/third_party/fast-reid/fastreid/evaluation/rank_cylib/__init__.py +20 -0
  161. dnt/third_party/fast-reid/fastreid/evaluation/rank_cylib/setup.py +32 -0
  162. dnt/third_party/fast-reid/fastreid/evaluation/rank_cylib/test_cython.py +106 -0
  163. dnt/third_party/fast-reid/fastreid/evaluation/reid_evaluation.py +143 -0
  164. dnt/third_party/fast-reid/fastreid/evaluation/rerank.py +73 -0
  165. dnt/third_party/fast-reid/fastreid/evaluation/roc.py +90 -0
  166. dnt/third_party/fast-reid/fastreid/evaluation/testing.py +88 -0
  167. dnt/third_party/fast-reid/fastreid/layers/__init__.py +19 -0
  168. dnt/third_party/fast-reid/fastreid/layers/activation.py +59 -0
  169. dnt/third_party/fast-reid/fastreid/layers/any_softmax.py +80 -0
  170. dnt/third_party/fast-reid/fastreid/layers/batch_norm.py +205 -0
  171. dnt/third_party/fast-reid/fastreid/layers/context_block.py +113 -0
  172. dnt/third_party/fast-reid/fastreid/layers/drop.py +161 -0
  173. dnt/third_party/fast-reid/fastreid/layers/frn.py +199 -0
  174. dnt/third_party/fast-reid/fastreid/layers/gather_layer.py +30 -0
  175. dnt/third_party/fast-reid/fastreid/layers/helpers.py +31 -0
  176. dnt/third_party/fast-reid/fastreid/layers/non_local.py +54 -0
  177. dnt/third_party/fast-reid/fastreid/layers/pooling.py +124 -0
  178. dnt/third_party/fast-reid/fastreid/layers/se_layer.py +25 -0
  179. dnt/third_party/fast-reid/fastreid/layers/splat.py +109 -0
  180. dnt/third_party/fast-reid/fastreid/layers/weight_init.py +122 -0
  181. dnt/third_party/fast-reid/fastreid/modeling/__init__.py +23 -0
  182. dnt/third_party/fast-reid/fastreid/modeling/backbones/__init__.py +18 -0
  183. dnt/third_party/fast-reid/fastreid/modeling/backbones/build.py +27 -0
  184. dnt/third_party/fast-reid/fastreid/modeling/backbones/mobilenet.py +195 -0
  185. dnt/third_party/fast-reid/fastreid/modeling/backbones/mobilenetv3.py +283 -0
  186. dnt/third_party/fast-reid/fastreid/modeling/backbones/osnet.py +525 -0
  187. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/__init__.py +4 -0
  188. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/config.py +396 -0
  189. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/effnet/EN-B0_dds_8gpu.yaml +27 -0
  190. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/effnet/EN-B1_dds_8gpu.yaml +27 -0
  191. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/effnet/EN-B2_dds_8gpu.yaml +27 -0
  192. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/effnet/EN-B3_dds_8gpu.yaml +27 -0
  193. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/effnet/EN-B4_dds_8gpu.yaml +27 -0
  194. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/effnet/EN-B5_dds_8gpu.yaml +27 -0
  195. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/effnet.py +281 -0
  196. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnet.py +596 -0
  197. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-1.6GF_dds_8gpu.yaml +26 -0
  198. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-12GF_dds_8gpu.yaml +26 -0
  199. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-16GF_dds_8gpu.yaml +26 -0
  200. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-200MF_dds_8gpu.yaml +26 -0
  201. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-3.2GF_dds_8gpu.yaml +26 -0
  202. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-32GF_dds_8gpu.yaml +26 -0
  203. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-4.0GF_dds_8gpu.yaml +26 -0
  204. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-400MF_dds_8gpu.yaml +26 -0
  205. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-6.4GF_dds_8gpu.yaml +26 -0
  206. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-600MF_dds_8gpu.yaml +26 -0
  207. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-8.0GF_dds_8gpu.yaml +26 -0
  208. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-800MF_dds_8gpu.yaml +26 -0
  209. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-1.6GF_dds_8gpu.yaml +27 -0
  210. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-12GF_dds_8gpu.yaml +27 -0
  211. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-16GF_dds_8gpu.yaml +27 -0
  212. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-200MF_dds_8gpu.yaml +26 -0
  213. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-3.2GF_dds_8gpu.yaml +27 -0
  214. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-32GF_dds_8gpu.yaml +27 -0
  215. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-4.0GF_dds_8gpu.yaml +27 -0
  216. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-400MF_dds_8gpu.yaml +27 -0
  217. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-6.4GF_dds_8gpu.yaml +27 -0
  218. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-600MF_dds_8gpu.yaml +27 -0
  219. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-8.0GF_dds_8gpu.yaml +27 -0
  220. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-800MF_dds_8gpu.yaml +27 -0
  221. dnt/third_party/fast-reid/fastreid/modeling/backbones/repvgg.py +309 -0
  222. dnt/third_party/fast-reid/fastreid/modeling/backbones/resnest.py +365 -0
  223. dnt/third_party/fast-reid/fastreid/modeling/backbones/resnet.py +364 -0
  224. dnt/third_party/fast-reid/fastreid/modeling/backbones/resnext.py +335 -0
  225. dnt/third_party/fast-reid/fastreid/modeling/backbones/shufflenet.py +203 -0
  226. dnt/third_party/fast-reid/fastreid/modeling/backbones/vision_transformer.py +399 -0
  227. dnt/third_party/fast-reid/fastreid/modeling/heads/__init__.py +11 -0
  228. dnt/third_party/fast-reid/fastreid/modeling/heads/build.py +25 -0
  229. dnt/third_party/fast-reid/fastreid/modeling/heads/clas_head.py +36 -0
  230. dnt/third_party/fast-reid/fastreid/modeling/heads/embedding_head.py +151 -0
  231. dnt/third_party/fast-reid/fastreid/modeling/losses/__init__.py +12 -0
  232. dnt/third_party/fast-reid/fastreid/modeling/losses/circle_loss.py +71 -0
  233. dnt/third_party/fast-reid/fastreid/modeling/losses/cross_entroy_loss.py +54 -0
  234. dnt/third_party/fast-reid/fastreid/modeling/losses/focal_loss.py +92 -0
  235. dnt/third_party/fast-reid/fastreid/modeling/losses/triplet_loss.py +113 -0
  236. dnt/third_party/fast-reid/fastreid/modeling/losses/utils.py +48 -0
  237. dnt/third_party/fast-reid/fastreid/modeling/meta_arch/__init__.py +14 -0
  238. dnt/third_party/fast-reid/fastreid/modeling/meta_arch/baseline.py +188 -0
  239. dnt/third_party/fast-reid/fastreid/modeling/meta_arch/build.py +26 -0
  240. dnt/third_party/fast-reid/fastreid/modeling/meta_arch/distiller.py +140 -0
  241. dnt/third_party/fast-reid/fastreid/modeling/meta_arch/mgn.py +394 -0
  242. dnt/third_party/fast-reid/fastreid/modeling/meta_arch/moco.py +126 -0
  243. dnt/third_party/fast-reid/fastreid/solver/__init__.py +8 -0
  244. dnt/third_party/fast-reid/fastreid/solver/build.py +348 -0
  245. dnt/third_party/fast-reid/fastreid/solver/lr_scheduler.py +66 -0
  246. dnt/third_party/fast-reid/fastreid/solver/optim/__init__.py +10 -0
  247. dnt/third_party/fast-reid/fastreid/solver/optim/lamb.py +123 -0
  248. dnt/third_party/fast-reid/fastreid/solver/optim/radam.py +149 -0
  249. dnt/third_party/fast-reid/fastreid/solver/optim/swa.py +246 -0
  250. dnt/third_party/fast-reid/fastreid/utils/__init__.py +6 -0
  251. dnt/third_party/fast-reid/fastreid/utils/checkpoint.py +503 -0
  252. dnt/third_party/fast-reid/fastreid/utils/collect_env.py +158 -0
  253. dnt/third_party/fast-reid/fastreid/utils/comm.py +255 -0
  254. dnt/third_party/fast-reid/fastreid/utils/compute_dist.py +200 -0
  255. dnt/third_party/fast-reid/fastreid/utils/env.py +119 -0
  256. dnt/third_party/fast-reid/fastreid/utils/events.py +461 -0
  257. dnt/third_party/fast-reid/fastreid/utils/faiss_utils.py +127 -0
  258. dnt/third_party/fast-reid/fastreid/utils/file_io.py +520 -0
  259. dnt/third_party/fast-reid/fastreid/utils/history_buffer.py +71 -0
  260. dnt/third_party/fast-reid/fastreid/utils/logger.py +211 -0
  261. dnt/third_party/fast-reid/fastreid/utils/params.py +103 -0
  262. dnt/third_party/fast-reid/fastreid/utils/precision_bn.py +94 -0
  263. dnt/third_party/fast-reid/fastreid/utils/registry.py +66 -0
  264. dnt/third_party/fast-reid/fastreid/utils/summary.py +120 -0
  265. dnt/third_party/fast-reid/fastreid/utils/timer.py +68 -0
  266. dnt/third_party/fast-reid/fastreid/utils/visualizer.py +278 -0
  267. dnt/track/__init__.py +3 -1
  268. dnt/track/botsort/__init__.py +4 -0
  269. dnt/track/botsort/bot_tracker/__init__.py +3 -0
  270. dnt/track/botsort/bot_tracker/basetrack.py +60 -0
  271. dnt/track/botsort/bot_tracker/bot_sort.py +473 -0
  272. dnt/track/botsort/bot_tracker/gmc.py +316 -0
  273. dnt/track/botsort/bot_tracker/kalman_filter.py +269 -0
  274. dnt/track/botsort/bot_tracker/matching.py +194 -0
  275. dnt/track/botsort/bot_tracker/mc_bot_sort.py +505 -0
  276. dnt/track/{dsort/utils → botsort/bot_tracker/tracking_utils}/evaluation.py +14 -4
  277. dnt/track/{dsort/utils → botsort/bot_tracker/tracking_utils}/io.py +19 -36
  278. dnt/track/botsort/bot_tracker/tracking_utils/timer.py +37 -0
  279. dnt/track/botsort/inference.py +96 -0
  280. dnt/track/config.py +120 -0
  281. dnt/track/dsort/configs/bagtricks_R50.yml +7 -0
  282. dnt/track/dsort/configs/deep_sort.yaml +0 -0
  283. dnt/track/dsort/configs/fastreid.yaml +1 -1
  284. dnt/track/dsort/deep_sort/deep/checkpoint/ckpt.t7 +0 -0
  285. dnt/track/dsort/deep_sort/deep/feature_extractor.py +87 -8
  286. dnt/track/dsort/deep_sort/deep_sort.py +31 -21
  287. dnt/track/dsort/deep_sort/sort/detection.py +2 -1
  288. dnt/track/dsort/deep_sort/sort/iou_matching.py +0 -2
  289. dnt/track/dsort/deep_sort/sort/linear_assignment.py +0 -3
  290. dnt/track/dsort/deep_sort/sort/nn_matching.py +5 -5
  291. dnt/track/dsort/deep_sort/sort/preprocessing.py +1 -2
  292. dnt/track/dsort/deep_sort/sort/track.py +2 -1
  293. dnt/track/dsort/deep_sort/sort/tracker.py +1 -1
  294. dnt/track/dsort/dsort.py +44 -27
  295. dnt/track/re_class.py +117 -0
  296. dnt/track/sort/sort.py +9 -7
  297. dnt/track/tracker.py +225 -20
  298. dnt-0.3.1.8.dist-info/METADATA +117 -0
  299. dnt-0.3.1.8.dist-info/RECORD +315 -0
  300. {dnt-0.2.1.dist-info → dnt-0.3.1.8.dist-info}/WHEEL +1 -1
  301. dnt/track/dsort/deep_sort/deep/evaluate.py +0 -15
  302. dnt/track/dsort/deep_sort/deep/original_model.py +0 -106
  303. dnt/track/dsort/deep_sort/deep/test.py +0 -77
  304. dnt/track/dsort/deep_sort/deep/train.py +0 -189
  305. dnt/track/dsort/utils/asserts.py +0 -13
  306. dnt/track/dsort/utils/draw.py +0 -36
  307. dnt/track/dsort/utils/json_logger.py +0 -383
  308. dnt/track/dsort/utils/log.py +0 -17
  309. dnt/track/dsort/utils/parser.py +0 -35
  310. dnt/track/dsort/utils/tools.py +0 -39
  311. dnt-0.2.1.dist-info/METADATA +0 -35
  312. dnt-0.2.1.dist-info/RECORD +0 -60
  313. /dnt/{track/dsort/utils → third_party/fast-reid/checkpoint}/__init__.py +0 -0
  314. {dnt-0.2.1.dist-info → dnt-0.3.1.8.dist-info/licenses}/LICENSE +0 -0
  315. {dnt-0.2.1.dist-info → dnt-0.3.1.8.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,211 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+ import functools
3
+ import logging
4
+ import os
5
+ import sys
6
+ import time
7
+ from collections import Counter
8
+
9
+ from termcolor import colored
10
+
11
+ from .file_io import PathManager
12
+
13
+
14
class _ColorfulFormatter(logging.Formatter):
    """Formatter that shortens the root logger name and highlights warnings/errors.

    Extra keyword arguments (consumed here, not forwarded to ``logging.Formatter``):
        root_name (str): name of the root logger. A leading ``root_name + "."``
            in a record's logger name is replaced by ``abbrev_name + "."``.
        abbrev_name (str): replacement for the root name. The default ``""``
            removes the root-name prefix entirely.

    WARNING records get a red blinking "WARNING" prefix; ERROR and CRITICAL
    records get a red blinking, underlined "ERROR" prefix. Other levels are
    returned unchanged.
    """

    def __init__(self, *args, **kwargs):
        self._root_name = kwargs.pop("root_name") + "."
        self._abbrev_name = kwargs.pop("abbrev_name", "")
        if len(self._abbrev_name):
            self._abbrev_name = self._abbrev_name + "."
        super(_ColorfulFormatter, self).__init__(*args, **kwargs)

    def formatMessage(self, record):
        original_name = record.name
        # Only rewrite the leading root-name prefix, not every occurrence of it.
        if original_name.startswith(self._root_name):
            record.name = self._abbrev_name + original_name[len(self._root_name):]
        try:
            log = super(_ColorfulFormatter, self).formatMessage(record)
        finally:
            # The same LogRecord is handed to every handler of the logger;
            # restore the name so later handlers (e.g. a plain file handler)
            # still see the full logger name.
            record.name = original_name
        if record.levelno == logging.WARNING:
            prefix = colored("WARNING", "red", attrs=["blink"])
        elif record.levelno in (logging.ERROR, logging.CRITICAL):
            prefix = colored("ERROR", "red", attrs=["blink", "underline"])
        else:
            return log
        return prefix + " " + log
32
+
33
+
34
@functools.lru_cache()  # so that calling setup_logger multiple times won't add many handlers
def setup_logger(
    output=None, distributed_rank=0, *, color=True, name="fastreid", abbrev_name=None
):
    """
    Configure the logger ``name`` at DEBUG level with stdout and/or file handlers.

    Args:
        output (str): a file name or a directory to save log. If None, will not save log file.
            If ends with ".txt" or ".log", assumed to be a file name.
            Otherwise, logs will be saved to `output/log.txt`.
        distributed_rank (int): rank of the current process. Only rank 0 logs to
            stdout; ranks > 0 write to a log file suffixed with ".rank{rank}".
        color (bool): if True, stdout lines use the colored formatter;
            otherwise the plain formatter.
        name (str): the root module name of this logger
        abbrev_name (str): an abbreviation of the module, to avoid long names in logs.
            Set to "" to not log the root module in logs.
            By default, will abbreviate "detectron2" to "d2" and leave other
            modules unchanged.

    Returns:
        logging.Logger: the configured logger. Because of ``lru_cache``, calling
        again with identical arguments returns it without adding handlers.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    logger.propagate = False

    if abbrev_name is None:
        abbrev_name = "d2" if name == "detectron2" else name

    plain_formatter = logging.Formatter(
        "[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%m/%d %H:%M:%S"
    )
    # stdout logging: master only
    if distributed_rank == 0:
        ch = logging.StreamHandler(stream=sys.stdout)
        ch.setLevel(logging.DEBUG)
        if color:
            formatter = _ColorfulFormatter(
                colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s",
                datefmt="%m/%d %H:%M:%S",
                root_name=name,
                abbrev_name=str(abbrev_name),
            )
        else:
            formatter = plain_formatter
        ch.setFormatter(formatter)
        logger.addHandler(ch)

    # file logging: all workers
    if output is not None:
        if output.endswith((".txt", ".log")):
            filename = output
        else:
            filename = os.path.join(output, "log.txt")
        if distributed_rank > 0:
            filename = filename + ".rank{}".format(distributed_rank)
        log_dir = os.path.dirname(filename)
        # A bare file name such as "train.log" has an empty directory part;
        # asking to create "" fails, and the current directory already exists.
        if log_dir:
            PathManager.mkdirs(log_dir)

        fh = logging.StreamHandler(_cached_log_stream(filename))
        fh.setLevel(logging.DEBUG)
        fh.setFormatter(plain_formatter)
        logger.addHandler(fh)

    return logger
91
+
92
+
93
# cache the opened file object, so that different calls to `setup_logger`
# with the same file name can safely write to the same file.
@functools.lru_cache(maxsize=None)
def _cached_log_stream(filename):
    """Open ``filename`` in append mode, once per distinct file name.

    The stream is cached without eviction so every handler that targets the
    same file shares one file object. It is registered with ``atexit`` so it
    is closed (and its buffer flushed) when the interpreter exits.
    """
    import atexit

    stream = PathManager.open(filename, "a")
    atexit.register(stream.close)
    return stream
98
+
99
+
100
+ """
101
+ Below are some other convenient logging methods.
102
+ They are mainly adapted from
103
+ https://github.com/abseil/abseil-py/blob/master/absl/logging/__init__.py
104
+ """
105
+
106
+
107
def _find_caller():
    """
    Locate the first stack frame (above the direct caller) outside this logger module.

    Returns:
        str: the ``__name__`` of that frame's module ("detectron2" when it is ``__main__``)
        tuple: ``(filename, line number, function name)`` of that frame, usable
            as a hashable per-call-site key

    Returns None implicitly if every remaining frame belongs to this module.
    """
    logger_marker = os.path.join("utils", "logger.")
    current = sys._getframe(2)
    while current:
        code_obj = current.f_code
        if logger_marker in code_obj.co_filename:
            current = current.f_back
            continue
        module_name = current.f_globals["__name__"]
        if module_name == "__main__":
            module_name = "detectron2"
        site = (code_obj.co_filename, current.f_lineno, code_obj.co_name)
        return module_name, site
122
+
123
+
124
# Per-key call counts, shared by ``log_first_n`` and ``log_every_n``.
_LOG_COUNTER = Counter()
# Caller key -> ``time.time()`` of the last message emitted by ``log_every_n_seconds``.
_LOG_TIMER = dict()
126
+
127
+
128
def log_first_n(lvl, msg, n=1, *, name=None, key="caller"):
    """
    Log only for the first n times.

    Args:
        lvl (int): the logging level
        msg (str): the message to log
        n (int): how many times a message with the same key may be logged
        name (str): name of the logger to use. Will use the caller's module by default.
        key (str or tuple[str]): the string(s) can be one of "caller" or
            "message", which defines how to identify duplicated logs.
            For example, if called with `n=1, key="caller"`, this function
            will only log the first call from the same caller, regardless of
            the message content.
            If called with `n=1, key="message"`, this function will log the
            same content only once, even if they are called from different places.
            If called with `n=1, key=("caller", "message")`, this function
            will not log only if the same caller has logged the same message before.

    Raises:
        ValueError: if ``key`` is empty or contains anything other than
            "caller" or "message".
    """
    if isinstance(key, str):
        key = (key,)
    # Validate explicitly rather than with ``assert`` (stripped under ``-O``).
    # An unknown entry would add nothing to the hash key, silently making
    # every call share a single counter.
    if len(key) == 0:
        raise ValueError("key must contain at least one of 'caller' or 'message'")
    unknown = [k for k in key if k not in ("caller", "message")]
    if unknown:
        raise ValueError(
            "key entries must be 'caller' or 'message', got {!r}".format(unknown)
        )

    caller_module, caller_key = _find_caller()
    hash_key = ()
    if "caller" in key:
        hash_key = hash_key + caller_key
    if "message" in key:
        hash_key = hash_key + (msg,)

    _LOG_COUNTER[hash_key] += 1
    if _LOG_COUNTER[hash_key] <= n:
        logging.getLogger(name or caller_module).log(lvl, msg)
160
+
161
+
162
def log_every_n(lvl, msg, n=1, *, name=None):
    """
    Log on the 1st, (n+1)-th, (2n+1)-th, ... call from the same call site.

    Args:
        lvl (int): the logging level
        msg (str): the message to log
        n (int): logging period, counted in calls
        name (str): name of the logger to use. Will use the caller's module by default.
    """
    module_name, site = _find_caller()
    _LOG_COUNTER[site] += 1
    hits = _LOG_COUNTER[site]
    should_log = n == 1 or hits % n == 1
    if should_log:
        logging.getLogger(name or module_name).log(lvl, msg)
175
+
176
+
177
def log_every_n_seconds(lvl, msg, n=1, *, name=None):
    """
    Log at most once every n seconds per call site.

    Args:
        lvl (int): the logging level
        msg (str): the message to log
        n (int): minimum number of seconds between two messages from the same call site
        name (str): name of the logger to use. Will use the caller's module by default.
    """
    module_name, site = _find_caller()
    previous = _LOG_TIMER.get(site, None)
    now = time.time()
    if previous is None or now - previous >= n:
        target = logging.getLogger(name or module_name)
        target.log(lvl, msg)
        _LOG_TIMER[site] = now
192
+
193
+ # def create_small_table(small_dict):
194
+ # """
195
+ # Create a small table using the keys of small_dict as headers. This is only
196
+ # suitable for small dictionaries.
197
+ # Args:
198
+ # small_dict (dict): a result dictionary of only a few items.
199
+ # Returns:
200
+ # str: the table as a string.
201
+ # """
202
+ # keys, values = tuple(zip(*small_dict.items()))
203
+ # table = tabulate(
204
+ # [values],
205
+ # headers=keys,
206
+ # tablefmt="pipe",
207
+ # floatfmt=".3f",
208
+ # stralign="center",
209
+ # numalign="center",
210
+ # )
211
+ # return table
@@ -0,0 +1,103 @@
1
+ # encoding: utf-8
2
+ """
3
+ @author: liaoxingyu
4
+ @contact: sherlockliao01@gmail.com
5
+ """
6
+
7
+ # based on: https://github.com/PhilJd/contiguous_pytorch_params/blob/master/contiguous_params/params.py
8
+
9
+ from collections import OrderedDict
10
+
11
+ import torch
12
+
13
+
14
class ContiguousParams:
    """Pack parameter groups into one contiguous data buffer and one grad buffer per group.

    ``parameters`` is any iterable of dicts with keys ``"params"`` (a list whose
    first element is the parameter tensor), ``"lr"``, ``"weight_decay"`` and
    ``"freeze_status"``. Dicts sharing the same (freeze_status, lr,
    weight_decay) are packed into a single buffer. Each parameter's ``.data``
    and ``.grad`` become views into those buffers.
    """

    def __init__(self, parameters):
        # Create a list of the parameters to prevent emptying an iterator:
        # the input is traversed several times while building the buffers.
        self._parameters = list(parameters)
        self._param_buffer = []
        self._grad_buffer = []
        self._group_dict = OrderedDict()
        self._name_buffer = []
        self._init_buffers()
        # Store the data pointers for each parameter into the buffer. These
        # can be used to check if an operation overwrites the gradient/data
        # tensor (invalidating the assumption of a contiguous buffer).
        self.data_pointers = []
        self.grad_pointers = []
        self.make_params_contiguous()

    def _init_buffers(self):
        """Group parameter dicts by (freeze_status, lr, weight_decay) and allocate zeroed buffers."""
        if not self._parameters:
            raise ValueError("ContiguousParams requires at least one parameter group.")
        dtype = self._parameters[0]["params"][0].dtype
        device = self._parameters[0]["params"][0].device
        if not all(p["params"][0].dtype == dtype for p in self._parameters):
            raise ValueError("All parameters must be of the same dtype.")
        if not all(p["params"][0].device == device for p in self._parameters):
            raise ValueError("All parameters must be on the same device.")

        # Group parameters by lr and weight decay
        for param_dict in self._parameters:
            freeze_status = param_dict["freeze_status"]
            param_key = freeze_status + '_' + str(param_dict["lr"]) + '_' + str(param_dict["weight_decay"])
            if param_key not in self._group_dict:
                self._group_dict[param_key] = []
            self._group_dict[param_key].append(param_dict)

        for key, params in self._group_dict.items():
            size = sum(p["params"][0].numel() for p in params)
            self._param_buffer.append(torch.zeros(size, dtype=dtype, device=device))
            self._grad_buffer.append(torch.zeros(size, dtype=dtype, device=device))
            self._name_buffer.append(key)

    def make_params_contiguous(self):
        """Copy each parameter into its group buffer and rebind ``.data``/``.grad`` as views of it."""
        for i, params in enumerate(self._group_dict.values()):
            index = 0
            for param_dict in params:
                p = param_dict["params"][0]
                size = p.numel()
                self._param_buffer[i][index:index + size] = p.data.view(-1)
                p.data = self._param_buffer[i][index:index + size].view(p.data.shape)
                p.grad = self._grad_buffer[i][index:index + size].view(p.data.shape)
                self.data_pointers.append(p.data.data_ptr)
                self.grad_pointers.append(p.grad.data.data_ptr)
                index += size
            # Bend the param_buffer to use grad_buffer to track its gradients.
            self._param_buffer[i].grad = self._grad_buffer[i]

    def contiguous(self):
        """Return all parameters as one contiguous buffer per group."""
        groups = []
        for name, buffer in zip(self._name_buffer, self._param_buffer):
            # lr and weight_decay were appended last with '_' separators and their
            # str() forms contain no underscores, so split from the right; this
            # keeps a freeze_status that itself contains '_' intact.
            freeze_status, lr, weight_decay = name.rsplit('_', 2)
            groups.append({
                "freeze_status": freeze_status,
                "params": buffer,
                "lr": float(lr),
                "weight_decay": float(weight_decay),
            })
        return groups

    def original(self):
        """Return the non-flattened parameter groups (as a list)."""
        return self._parameters

    def buffer_is_valid(self):
        """Verify that all parameters and gradients still use the buffer."""
        params = (
            param_dict["params"][0]
            for group in self._group_dict.values()
            for param_dict in group
        )
        for p, data_ptr, grad_ptr in zip(params, self.data_pointers, self.grad_pointers):
            # A grad reset to None (e.g. zero_grad(set_to_none=True)) no longer
            # lives in the grad buffer, so the buffer is invalid.
            if p.grad is None:
                return False
            if (p.data.data_ptr() != data_ptr()) or (p.grad.data.data_ptr() != grad_ptr()):
                return False
        return True

    def assert_buffer_is_valid(self):
        """Raise ValueError if any parameter or gradient has left its buffer."""
        if not self.buffer_is_valid():
            raise ValueError(
                "The data or gradient buffer has been invalidated. Please make "
                "sure to use inplace operations only when updating parameters "
                "or gradients.")
@@ -0,0 +1,94 @@
1
+ # encoding: utf-8
2
+ """
3
+ @author: liaoxingyu
4
+ @contact: sherlockliao01@gmail.com
5
+ """
6
+
7
+ import itertools
8
+
9
+ import torch
10
+
11
# Batch-norm layer classes whose running statistics update_bn_stats()
# recomputes (see get_bn_modules()).
BN_MODULE_TYPES = (
    torch.nn.BatchNorm1d,
    torch.nn.BatchNorm2d,
    torch.nn.BatchNorm3d,
    torch.nn.SyncBatchNorm,
)


@torch.no_grad()
def update_bn_stats(model, data_loader, num_iters: int = 200):
    """
    Recompute and update the batch norm stats to make them more precise.

    During training both BN stats and the weights change after every
    iteration, so the running average cannot precisely reflect the actual
    stats of the current model. Here the BN stats are recomputed with fixed
    weights. Specifically, this computes the true average of per-batch
    mean/variance instead of an exponential running average.

    Args:
        model (nn.Module): the model whose bn stats will be recomputed.
            Note that:

            1. This function will not alter the training mode of the given
               model. Users are responsible for setting the layers that need
               precise-BN to training mode before calling this function.
            2. Be careful if your model contains other stateful layers in
               addition to BN, i.e. layers whose state can change in forward
               iterations. This function will alter their state. If you want
               them unchanged, pass in a submodule without those layers, or
               back up their states.
        data_loader (iterable): produces the inputs passed to ``model``. Each
            item is indexed with ``['targets']`` and that tensor is filled
            with -1 in place before the forward pass.
        num_iters (int): number of batches to average over.

    Raises:
        AssertionError: if ``data_loader`` yields fewer than ``num_iters``
            batches (including an empty loader). The original momentum of
            every BN layer is restored in that case too. The running stats
            are left as the last forward pass set them.
    """
    bn_layers = get_bn_modules(model)
    if not bn_layers:
        return

    # In order to make the running stats only reflect the current batch, the
    # momentum is disabled:
    #   bn.running_mean = (1 - momentum) * bn.running_mean + momentum * batch_mean
    # Setting the momentum to 1.0 computes the stats without momentum.
    momentum_actual = [bn.momentum for bn in bn_layers]
    for bn in bn_layers:
        bn.momentum = 1.0

    # Note that running_var actually means "running average of variance".
    running_mean = [torch.zeros_like(bn.running_mean) for bn in bn_layers]
    running_var = [torch.zeros_like(bn.running_var) for bn in bn_layers]

    # Pre-initialised so that an empty data_loader reaches the assertion below
    # instead of failing with UnboundLocalError.
    ind = -1
    try:
        for ind, inputs in enumerate(itertools.islice(data_loader, num_iters)):
            inputs['targets'].fill_(-1)
            model(inputs)  # gradients already disabled by the decorator
            for i, bn in enumerate(bn_layers):
                # Incremental mean over the batches seen so far.
                running_mean[i] += (bn.running_mean - running_mean[i]) / (ind + 1)
                running_var[i] += (bn.running_var - running_var[i]) / (ind + 1)
        assert ind == num_iters - 1, (
            "update_bn_stats is meant to run for {} iterations, "
            "but the dataloader stops at {} iterations.".format(num_iters, ind + 1)
        )

        for i, bn in enumerate(bn_layers):
            # Set the precise bn stats.
            bn.running_mean = running_mean[i]
            bn.running_var = running_var[i]
    finally:
        # Restore the user's momentum even if the forward pass or the
        # iteration-count check fails.
        for bn, momentum in zip(bn_layers, momentum_actual):
            bn.momentum = momentum


def get_bn_modules(model):
    """
    Find all BatchNorm (BN) modules of ``model`` that are in training mode.

    A submodule is included when it is an instance of one of the classes in
    this module's ``BN_MODULE_TYPES`` and its ``training`` flag is set.

    Args:
        model (nn.Module): a model possibly containing BN modules.

    Returns:
        list[nn.Module]: the matching BN modules, in ``model.modules()`` order.
    """
    return [
        m for m in model.modules() if m.training and isinstance(m, BN_MODULE_TYPES)
    ]
@@ -0,0 +1,66 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
3
+
4
+ from typing import Dict, Optional
5
+
6
+
7
class Registry(object):
    """
    The registry that provides name -> object mapping, to support third-party
    users' custom modules.

    To create a registry (e.g. a backbone registry):

    .. code-block:: python

        BACKBONE_REGISTRY = Registry('BACKBONE')

    To register an object:

    .. code-block:: python

        @BACKBONE_REGISTRY.register()
        class MyBackbone():
            ...

    Or:

    .. code-block:: python

        BACKBONE_REGISTRY.register(MyBackbone)
    """

    def __init__(self, name: str) -> None:
        """
        Args:
            name (str): the name of this registry (used in error messages).
        """
        self._name: str = name
        self._obj_map: Dict[str, object] = {}

    def _do_register(self, name: str, obj: object) -> None:
        # Duplicate names are rejected with AssertionError (kept for callers
        # that rely on it); note that ``python -O`` strips this check.
        assert (
            name not in self._obj_map
        ), "An object named '{}' was already registered in '{}' registry!".format(
            name, self._name
        )
        self._obj_map[name] = obj

    def register(self, obj: Optional[object] = None) -> Optional[object]:
        """
        Register the given object under the name ``obj.__name__``.

        Can be used either as a decorator or as a plain call; see the class
        docstring for usage.

        Returns:
            When called with no argument, a decorator that registers its
            target and returns it unchanged. When called with ``obj``, None.

        Raises:
            AssertionError: if the name is already registered.
        """
        if obj is None:
            # used as a decorator
            def deco(func_or_class: object) -> object:
                name = func_or_class.__name__  # pyre-ignore
                self._do_register(name, func_or_class)
                return func_or_class

            return deco

        # used as a function call
        name = obj.__name__  # pyre-ignore
        self._do_register(name, obj)
        return None

    def get(self, name: str) -> object:
        """
        Return the object registered under ``name``.

        Raises:
            KeyError: if nothing is registered under ``name``.
        """
        try:
            return self._obj_map[name]
        except KeyError:
            raise KeyError(
                "No object named '{}' found in '{}' registry!".format(
                    name, self._name
                )
            ) from None

    def __contains__(self, name: str) -> bool:
        """Return True if an object is registered under ``name``."""
        return name in self._obj_map

    def __repr__(self) -> str:
        return "Registry(name={!r}, registered={})".format(
            self._name, sorted(self._obj_map)
        )
@@ -0,0 +1,120 @@
1
+ # encoding: utf-8
2
+ """
3
+ @author: liaoxingyu
4
+ @contact: sherlockliao01@gmail.com
5
+ """
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from torch.autograd import Variable
10
+
11
+ from collections import OrderedDict
12
+ import numpy as np
13
+
14
+
15
def summary(model, input_size, batch_size=-1, device="cuda"):
    """Print a layer-by-layer summary of ``model`` (output shapes, param counts).

    A forward hook is attached to every submodule except ``nn.Sequential``,
    ``nn.ModuleList`` and ``model`` itself. One forward pass is run on random
    input, then each hooked layer's output shape and parameter count are
    printed, followed by totals and a rough memory estimate.

    Args:
        model (nn.Module): the network to summarize.
        input_size (tuple or list[tuple]): shape of one input without the
            batch dimension, or a list of such shapes for multi-input models.
        batch_size (int): value reported in place of the batch dimension and
            used in the input-size estimate. The forward pass itself always
            uses a batch of 2 so that BatchNorm layers work.
        device (str): "cuda" or "cpu" (case-insensitive). "cuda" falls back
            to CPU tensors when CUDA is not available.

    Returns:
        None; the summary is printed to stdout.
    """
    # Per-layer info, filled in by the forward hooks in execution order.
    layer_info = OrderedDict()
    hooks = []

    def register_hook(module):

        def hook(module, input, output):
            class_name = str(module.__class__).split(".")[-1].split("'")[0]
            m_key = "%s-%i" % (class_name, len(layer_info) + 1)
            info = OrderedDict()
            layer_info[m_key] = info
            info["input_shape"] = list(input[0].size())
            info["input_shape"][0] = batch_size
            if isinstance(output, (list, tuple)):
                # Report batch_size for multi-output layers too, consistent
                # with the single-tensor branch below.
                info["output_shape"] = [
                    [batch_size] + list(o.size())[1:] for o in output
                ]
            else:
                info["output_shape"] = list(output.size())
                info["output_shape"][0] = batch_size

            # Plain ints keep the totals usable even for parameter-free models.
            params = 0
            if hasattr(module, "weight") and hasattr(module.weight, "size"):
                params += int(torch.prod(torch.LongTensor(list(module.weight.size()))))
                info["trainable"] = module.weight.requires_grad
            if hasattr(module, "bias") and hasattr(module.bias, "size"):
                params += int(torch.prod(torch.LongTensor(list(module.bias.size()))))
            info["nb_params"] = params

        if (
            not isinstance(module, nn.Sequential)
            and not isinstance(module, nn.ModuleList)
            and module is not model
        ):
            hooks.append(module.register_forward_hook(hook))

    device = device.lower()
    assert device in [
        "cuda",
        "cpu",
    ], "Input device is not valid, please specify 'cuda' or 'cpu'"

    if device == "cuda" and torch.cuda.is_available():
        dtype = torch.cuda.FloatTensor
    else:
        dtype = torch.FloatTensor

    # multiple inputs to the network
    if isinstance(input_size, tuple):
        input_size = [input_size]

    # batch_size of 2 for batchnorm
    x = [torch.rand(2, *in_size).type(dtype) for in_size in input_size]

    model.apply(register_hook)
    try:
        model(*x)
    finally:
        # Always detach the hooks, even if the forward pass raises, so the
        # model is not left instrumented.
        for h in hooks:
            h.remove()

    print("----------------------------------------------------------------")
    line_new = "{:>20} {:>25} {:>15}".format("Layer (type)", "Output Shape", "Param #")
    print(line_new)
    print("================================================================")
    total_params = 0
    total_output = 0
    trainable_params = 0
    for layer, info in layer_info.items():
        # input_shape, output_shape, trainable, nb_params
        line_new = "{:>20} {:>25} {:>15}".format(
            layer,
            str(info["output_shape"]),
            "{0:,}".format(info["nb_params"]),
        )
        total_params += info["nb_params"]
        shape = info["output_shape"]
        if shape and isinstance(shape[0], list):
            # Multi-output layer: one shape per output, possibly of different
            # ranks, so count each output separately.
            total_output += sum(np.prod(s) for s in shape)
        else:
            total_output += np.prod(shape)
        if info.get("trainable"):
            trainable_params += info["nb_params"]
        print(line_new)

    # assume 4 bytes/number (float on cuda).
    # Sum per-input element counts (multiplying all shapes together is wrong
    # for multi-input models).
    input_elems = sum(np.prod(in_size) for in_size in input_size)
    total_input_size = abs(input_elems * batch_size * 4. / (1024 ** 2.))
    total_output_size = abs(2. * total_output * 4. / (1024 ** 2.))  # x2 for gradients
    total_params_size = abs(total_params * 4. / (1024 ** 2.))
    total_size = total_params_size + total_output_size + total_input_size

    print("================================================================")
    print("Total params: {0:,}".format(total_params))
    print("Trainable params: {0:,}".format(trainable_params))
    print("Non-trainable params: {0:,}".format(total_params - trainable_params))
    print("----------------------------------------------------------------")
    print("Input size (MB): %0.2f" % total_input_size)
    print("Forward/backward pass size (MB): %0.2f" % total_output_size)
    print("Params size (MB): %0.2f" % total_params_size)
    print("Estimated Total Size (MB): %0.2f" % total_size)
    print("----------------------------------------------------------------")
@@ -0,0 +1,68 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2
+ # -*- coding: utf-8 -*-
3
+
4
+ from time import perf_counter
5
+ from typing import Optional
6
+
7
+
8
class Timer:
    """
    Stopwatch that measures ``perf_counter`` time elapsed since creation or
    the last ``reset()``, excluding intervals spent paused.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """
        Restart the timer from now and clear all pause bookkeeping.
        """
        self._start = perf_counter()
        # perf_counter() value at which the current pause began; None while running.
        self._paused: Optional[float] = None
        # Seconds spent paused since the last reset.
        self._total_paused = 0
        # Number of running segments since the last reset (1 + number of resumes).
        self._count_start = 1

    def pause(self):
        """
        Pause the timer.

        Raises:
            ValueError: if the timer is already paused.
        """
        if self._paused is not None:
            raise ValueError("Trying to pause a Timer that is already paused!")
        self._paused = perf_counter()

    def is_paused(self) -> bool:
        """
        Returns:
            bool: True while the timer is paused, False while it is running.
        """
        return self._paused is not None

    def resume(self):
        """
        Resume a paused timer, adding the paused interval to the excluded time.

        Raises:
            ValueError: if the timer is not paused.
        """
        if self._paused is None:
            raise ValueError("Trying to resume a Timer that is not paused!")
        paused_at, self._paused = self._paused, None
        self._total_paused += perf_counter() - paused_at
        self._count_start += 1

    def seconds(self) -> float:
        """
        Returns:
            (float): seconds since the start/reset of the timer, not counting
            paused time. While paused, the clock is frozen at the pause moment.
        """
        end_time = perf_counter() if self._paused is None else self._paused
        return end_time - self._start - self._total_paused

    def avg_seconds(self) -> float:
        """
        Returns:
            (float): ``seconds()`` divided by the number of running segments
            since the start/reset (one, plus one per ``resume()``).
        """
        elapsed = self.seconds()
        return elapsed / self._count_start