dnt 0.2.1__py3-none-any.whl → 0.3.1.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (315)
  1. dnt/__init__.py +4 -1
  2. dnt/analysis/__init__.py +3 -1
  3. dnt/analysis/count.py +107 -0
  4. dnt/analysis/interaction2.py +518 -0
  5. dnt/analysis/position.py +12 -0
  6. dnt/analysis/stop.py +92 -33
  7. dnt/analysis/stop2.py +289 -0
  8. dnt/analysis/stop3.py +758 -0
  9. dnt/detect/__init__.py +1 -1
  10. dnt/detect/signal/detector.py +326 -0
  11. dnt/detect/timestamp.py +105 -0
  12. dnt/detect/yolov8/detector.py +182 -35
  13. dnt/detect/yolov8/segmentor.py +171 -0
  14. dnt/engine/__init__.py +8 -0
  15. dnt/engine/bbox_interp.py +83 -0
  16. dnt/engine/bbox_iou.py +20 -0
  17. dnt/engine/cluster.py +31 -0
  18. dnt/engine/iob.py +66 -0
  19. dnt/filter/__init__.py +4 -0
  20. dnt/filter/filter.py +450 -21
  21. dnt/label/__init__.py +1 -1
  22. dnt/label/labeler.py +215 -14
  23. dnt/label/labeler2.py +631 -0
  24. dnt/shared/__init__.py +2 -1
  25. dnt/shared/data/coco.names +0 -0
  26. dnt/shared/data/openimages.names +0 -0
  27. dnt/shared/data/voc.names +0 -0
  28. dnt/shared/download.py +12 -0
  29. dnt/shared/synhcro.py +150 -0
  30. dnt/shared/util.py +17 -4
  31. dnt/third_party/fast-reid/__init__.py +1 -0
  32. dnt/third_party/fast-reid/configs/Base-AGW.yml +19 -0
  33. dnt/third_party/fast-reid/configs/Base-MGN.yml +12 -0
  34. dnt/third_party/fast-reid/configs/Base-SBS.yml +63 -0
  35. dnt/third_party/fast-reid/configs/Base-bagtricks.yml +76 -0
  36. dnt/third_party/fast-reid/configs/DukeMTMC/AGW_R101-ibn.yml +12 -0
  37. dnt/third_party/fast-reid/configs/DukeMTMC/AGW_R50-ibn.yml +11 -0
  38. dnt/third_party/fast-reid/configs/DukeMTMC/AGW_R50.yml +7 -0
  39. dnt/third_party/fast-reid/configs/DukeMTMC/AGW_S50.yml +11 -0
  40. dnt/third_party/fast-reid/configs/DukeMTMC/bagtricks_R101-ibn.yml +12 -0
  41. dnt/third_party/fast-reid/configs/DukeMTMC/bagtricks_R50-ibn.yml +11 -0
  42. dnt/third_party/fast-reid/configs/DukeMTMC/bagtricks_R50.yml +7 -0
  43. dnt/third_party/fast-reid/configs/DukeMTMC/bagtricks_S50.yml +11 -0
  44. dnt/third_party/fast-reid/configs/DukeMTMC/mgn_R50-ibn.yml +11 -0
  45. dnt/third_party/fast-reid/configs/DukeMTMC/sbs_R101-ibn.yml +12 -0
  46. dnt/third_party/fast-reid/configs/DukeMTMC/sbs_R50-ibn.yml +11 -0
  47. dnt/third_party/fast-reid/configs/DukeMTMC/sbs_R50.yml +7 -0
  48. dnt/third_party/fast-reid/configs/DukeMTMC/sbs_S50.yml +11 -0
  49. dnt/third_party/fast-reid/configs/MOT17/AGW_R101-ibn.yml +12 -0
  50. dnt/third_party/fast-reid/configs/MOT17/AGW_R50-ibn.yml +11 -0
  51. dnt/third_party/fast-reid/configs/MOT17/AGW_R50.yml +7 -0
  52. dnt/third_party/fast-reid/configs/MOT17/AGW_S50.yml +11 -0
  53. dnt/third_party/fast-reid/configs/MOT17/bagtricks_R101-ibn.yml +12 -0
  54. dnt/third_party/fast-reid/configs/MOT17/bagtricks_R50-ibn.yml +11 -0
  55. dnt/third_party/fast-reid/configs/MOT17/bagtricks_R50.yml +7 -0
  56. dnt/third_party/fast-reid/configs/MOT17/bagtricks_S50.yml +11 -0
  57. dnt/third_party/fast-reid/configs/MOT17/mgn_R50-ibn.yml +11 -0
  58. dnt/third_party/fast-reid/configs/MOT17/sbs_R101-ibn.yml +12 -0
  59. dnt/third_party/fast-reid/configs/MOT17/sbs_R50-ibn.yml +11 -0
  60. dnt/third_party/fast-reid/configs/MOT17/sbs_R50.yml +7 -0
  61. dnt/third_party/fast-reid/configs/MOT17/sbs_S50.yml +11 -0
  62. dnt/third_party/fast-reid/configs/MOT20/AGW_R101-ibn.yml +12 -0
  63. dnt/third_party/fast-reid/configs/MOT20/AGW_R50-ibn.yml +11 -0
  64. dnt/third_party/fast-reid/configs/MOT20/AGW_R50.yml +7 -0
  65. dnt/third_party/fast-reid/configs/MOT20/AGW_S50.yml +11 -0
  66. dnt/third_party/fast-reid/configs/MOT20/bagtricks_R101-ibn.yml +12 -0
  67. dnt/third_party/fast-reid/configs/MOT20/bagtricks_R50-ibn.yml +11 -0
  68. dnt/third_party/fast-reid/configs/MOT20/bagtricks_R50.yml +7 -0
  69. dnt/third_party/fast-reid/configs/MOT20/bagtricks_S50.yml +11 -0
  70. dnt/third_party/fast-reid/configs/MOT20/mgn_R50-ibn.yml +11 -0
  71. dnt/third_party/fast-reid/configs/MOT20/sbs_R101-ibn.yml +12 -0
  72. dnt/third_party/fast-reid/configs/MOT20/sbs_R50-ibn.yml +11 -0
  73. dnt/third_party/fast-reid/configs/MOT20/sbs_R50.yml +7 -0
  74. dnt/third_party/fast-reid/configs/MOT20/sbs_S50.yml +11 -0
  75. dnt/third_party/fast-reid/configs/MSMT17/AGW_R101-ibn.yml +12 -0
  76. dnt/third_party/fast-reid/configs/MSMT17/AGW_R50-ibn.yml +11 -0
  77. dnt/third_party/fast-reid/configs/MSMT17/AGW_R50.yml +7 -0
  78. dnt/third_party/fast-reid/configs/MSMT17/AGW_S50.yml +11 -0
  79. dnt/third_party/fast-reid/configs/MSMT17/bagtricks_R101-ibn.yml +13 -0
  80. dnt/third_party/fast-reid/configs/MSMT17/bagtricks_R50-ibn.yml +12 -0
  81. dnt/third_party/fast-reid/configs/MSMT17/bagtricks_R50.yml +7 -0
  82. dnt/third_party/fast-reid/configs/MSMT17/bagtricks_S50.yml +12 -0
  83. dnt/third_party/fast-reid/configs/MSMT17/mgn_R50-ibn.yml +11 -0
  84. dnt/third_party/fast-reid/configs/MSMT17/sbs_R101-ibn.yml +12 -0
  85. dnt/third_party/fast-reid/configs/MSMT17/sbs_R50-ibn.yml +11 -0
  86. dnt/third_party/fast-reid/configs/MSMT17/sbs_R50.yml +7 -0
  87. dnt/third_party/fast-reid/configs/MSMT17/sbs_S50.yml +11 -0
  88. dnt/third_party/fast-reid/configs/Market1501/AGW_R101-ibn.yml +12 -0
  89. dnt/third_party/fast-reid/configs/Market1501/AGW_R50-ibn.yml +11 -0
  90. dnt/third_party/fast-reid/configs/Market1501/AGW_R50.yml +7 -0
  91. dnt/third_party/fast-reid/configs/Market1501/AGW_S50.yml +11 -0
  92. dnt/third_party/fast-reid/configs/Market1501/bagtricks_R101-ibn.yml +12 -0
  93. dnt/third_party/fast-reid/configs/Market1501/bagtricks_R50-ibn.yml +11 -0
  94. dnt/third_party/fast-reid/configs/Market1501/bagtricks_R50.yml +7 -0
  95. dnt/third_party/fast-reid/configs/Market1501/bagtricks_S50.yml +11 -0
  96. dnt/third_party/fast-reid/configs/Market1501/bagtricks_vit.yml +88 -0
  97. dnt/third_party/fast-reid/configs/Market1501/mgn_R50-ibn.yml +11 -0
  98. dnt/third_party/fast-reid/configs/Market1501/sbs_R101-ibn.yml +12 -0
  99. dnt/third_party/fast-reid/configs/Market1501/sbs_R50-ibn.yml +11 -0
  100. dnt/third_party/fast-reid/configs/Market1501/sbs_R50.yml +7 -0
  101. dnt/third_party/fast-reid/configs/Market1501/sbs_S50.yml +11 -0
  102. dnt/third_party/fast-reid/configs/VERIWild/bagtricks_R50-ibn.yml +35 -0
  103. dnt/third_party/fast-reid/configs/VeRi/sbs_R50-ibn.yml +35 -0
  104. dnt/third_party/fast-reid/configs/VehicleID/bagtricks_R50-ibn.yml +36 -0
  105. dnt/third_party/fast-reid/configs/__init__.py +0 -0
  106. dnt/third_party/fast-reid/fast_reid_interfece.py +175 -0
  107. dnt/third_party/fast-reid/fastreid/__init__.py +6 -0
  108. dnt/third_party/fast-reid/fastreid/config/__init__.py +15 -0
  109. dnt/third_party/fast-reid/fastreid/config/config.py +319 -0
  110. dnt/third_party/fast-reid/fastreid/config/defaults.py +329 -0
  111. dnt/third_party/fast-reid/fastreid/data/__init__.py +17 -0
  112. dnt/third_party/fast-reid/fastreid/data/build.py +194 -0
  113. dnt/third_party/fast-reid/fastreid/data/common.py +58 -0
  114. dnt/third_party/fast-reid/fastreid/data/data_utils.py +202 -0
  115. dnt/third_party/fast-reid/fastreid/data/datasets/AirportALERT.py +50 -0
  116. dnt/third_party/fast-reid/fastreid/data/datasets/__init__.py +43 -0
  117. dnt/third_party/fast-reid/fastreid/data/datasets/bases.py +183 -0
  118. dnt/third_party/fast-reid/fastreid/data/datasets/caviara.py +44 -0
  119. dnt/third_party/fast-reid/fastreid/data/datasets/cuhk03.py +274 -0
  120. dnt/third_party/fast-reid/fastreid/data/datasets/cuhk_sysu.py +58 -0
  121. dnt/third_party/fast-reid/fastreid/data/datasets/dukemtmcreid.py +70 -0
  122. dnt/third_party/fast-reid/fastreid/data/datasets/grid.py +44 -0
  123. dnt/third_party/fast-reid/fastreid/data/datasets/iLIDS.py +45 -0
  124. dnt/third_party/fast-reid/fastreid/data/datasets/lpw.py +49 -0
  125. dnt/third_party/fast-reid/fastreid/data/datasets/market1501.py +89 -0
  126. dnt/third_party/fast-reid/fastreid/data/datasets/msmt17.py +114 -0
  127. dnt/third_party/fast-reid/fastreid/data/datasets/pes3d.py +44 -0
  128. dnt/third_party/fast-reid/fastreid/data/datasets/pku.py +44 -0
  129. dnt/third_party/fast-reid/fastreid/data/datasets/prai.py +43 -0
  130. dnt/third_party/fast-reid/fastreid/data/datasets/prid.py +41 -0
  131. dnt/third_party/fast-reid/fastreid/data/datasets/saivt.py +47 -0
  132. dnt/third_party/fast-reid/fastreid/data/datasets/sensereid.py +47 -0
  133. dnt/third_party/fast-reid/fastreid/data/datasets/shinpuhkan.py +48 -0
  134. dnt/third_party/fast-reid/fastreid/data/datasets/sysu_mm.py +47 -0
  135. dnt/third_party/fast-reid/fastreid/data/datasets/thermalworld.py +43 -0
  136. dnt/third_party/fast-reid/fastreid/data/datasets/vehicleid.py +126 -0
  137. dnt/third_party/fast-reid/fastreid/data/datasets/veri.py +69 -0
  138. dnt/third_party/fast-reid/fastreid/data/datasets/veriwild.py +140 -0
  139. dnt/third_party/fast-reid/fastreid/data/datasets/viper.py +45 -0
  140. dnt/third_party/fast-reid/fastreid/data/datasets/wildtracker.py +59 -0
  141. dnt/third_party/fast-reid/fastreid/data/samplers/__init__.py +18 -0
  142. dnt/third_party/fast-reid/fastreid/data/samplers/data_sampler.py +85 -0
  143. dnt/third_party/fast-reid/fastreid/data/samplers/imbalance_sampler.py +67 -0
  144. dnt/third_party/fast-reid/fastreid/data/samplers/triplet_sampler.py +260 -0
  145. dnt/third_party/fast-reid/fastreid/data/transforms/__init__.py +11 -0
  146. dnt/third_party/fast-reid/fastreid/data/transforms/autoaugment.py +806 -0
  147. dnt/third_party/fast-reid/fastreid/data/transforms/build.py +100 -0
  148. dnt/third_party/fast-reid/fastreid/data/transforms/functional.py +180 -0
  149. dnt/third_party/fast-reid/fastreid/data/transforms/transforms.py +161 -0
  150. dnt/third_party/fast-reid/fastreid/engine/__init__.py +15 -0
  151. dnt/third_party/fast-reid/fastreid/engine/defaults.py +490 -0
  152. dnt/third_party/fast-reid/fastreid/engine/hooks.py +534 -0
  153. dnt/third_party/fast-reid/fastreid/engine/launch.py +103 -0
  154. dnt/third_party/fast-reid/fastreid/engine/train_loop.py +357 -0
  155. dnt/third_party/fast-reid/fastreid/evaluation/__init__.py +6 -0
  156. dnt/third_party/fast-reid/fastreid/evaluation/clas_evaluator.py +81 -0
  157. dnt/third_party/fast-reid/fastreid/evaluation/evaluator.py +176 -0
  158. dnt/third_party/fast-reid/fastreid/evaluation/query_expansion.py +46 -0
  159. dnt/third_party/fast-reid/fastreid/evaluation/rank.py +200 -0
  160. dnt/third_party/fast-reid/fastreid/evaluation/rank_cylib/__init__.py +20 -0
  161. dnt/third_party/fast-reid/fastreid/evaluation/rank_cylib/setup.py +32 -0
  162. dnt/third_party/fast-reid/fastreid/evaluation/rank_cylib/test_cython.py +106 -0
  163. dnt/third_party/fast-reid/fastreid/evaluation/reid_evaluation.py +143 -0
  164. dnt/third_party/fast-reid/fastreid/evaluation/rerank.py +73 -0
  165. dnt/third_party/fast-reid/fastreid/evaluation/roc.py +90 -0
  166. dnt/third_party/fast-reid/fastreid/evaluation/testing.py +88 -0
  167. dnt/third_party/fast-reid/fastreid/layers/__init__.py +19 -0
  168. dnt/third_party/fast-reid/fastreid/layers/activation.py +59 -0
  169. dnt/third_party/fast-reid/fastreid/layers/any_softmax.py +80 -0
  170. dnt/third_party/fast-reid/fastreid/layers/batch_norm.py +205 -0
  171. dnt/third_party/fast-reid/fastreid/layers/context_block.py +113 -0
  172. dnt/third_party/fast-reid/fastreid/layers/drop.py +161 -0
  173. dnt/third_party/fast-reid/fastreid/layers/frn.py +199 -0
  174. dnt/third_party/fast-reid/fastreid/layers/gather_layer.py +30 -0
  175. dnt/third_party/fast-reid/fastreid/layers/helpers.py +31 -0
  176. dnt/third_party/fast-reid/fastreid/layers/non_local.py +54 -0
  177. dnt/third_party/fast-reid/fastreid/layers/pooling.py +124 -0
  178. dnt/third_party/fast-reid/fastreid/layers/se_layer.py +25 -0
  179. dnt/third_party/fast-reid/fastreid/layers/splat.py +109 -0
  180. dnt/third_party/fast-reid/fastreid/layers/weight_init.py +122 -0
  181. dnt/third_party/fast-reid/fastreid/modeling/__init__.py +23 -0
  182. dnt/third_party/fast-reid/fastreid/modeling/backbones/__init__.py +18 -0
  183. dnt/third_party/fast-reid/fastreid/modeling/backbones/build.py +27 -0
  184. dnt/third_party/fast-reid/fastreid/modeling/backbones/mobilenet.py +195 -0
  185. dnt/third_party/fast-reid/fastreid/modeling/backbones/mobilenetv3.py +283 -0
  186. dnt/third_party/fast-reid/fastreid/modeling/backbones/osnet.py +525 -0
  187. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/__init__.py +4 -0
  188. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/config.py +396 -0
  189. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/effnet/EN-B0_dds_8gpu.yaml +27 -0
  190. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/effnet/EN-B1_dds_8gpu.yaml +27 -0
  191. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/effnet/EN-B2_dds_8gpu.yaml +27 -0
  192. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/effnet/EN-B3_dds_8gpu.yaml +27 -0
  193. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/effnet/EN-B4_dds_8gpu.yaml +27 -0
  194. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/effnet/EN-B5_dds_8gpu.yaml +27 -0
  195. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/effnet.py +281 -0
  196. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnet.py +596 -0
  197. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-1.6GF_dds_8gpu.yaml +26 -0
  198. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-12GF_dds_8gpu.yaml +26 -0
  199. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-16GF_dds_8gpu.yaml +26 -0
  200. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-200MF_dds_8gpu.yaml +26 -0
  201. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-3.2GF_dds_8gpu.yaml +26 -0
  202. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-32GF_dds_8gpu.yaml +26 -0
  203. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-4.0GF_dds_8gpu.yaml +26 -0
  204. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-400MF_dds_8gpu.yaml +26 -0
  205. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-6.4GF_dds_8gpu.yaml +26 -0
  206. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-600MF_dds_8gpu.yaml +26 -0
  207. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-8.0GF_dds_8gpu.yaml +26 -0
  208. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-800MF_dds_8gpu.yaml +26 -0
  209. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-1.6GF_dds_8gpu.yaml +27 -0
  210. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-12GF_dds_8gpu.yaml +27 -0
  211. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-16GF_dds_8gpu.yaml +27 -0
  212. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-200MF_dds_8gpu.yaml +26 -0
  213. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-3.2GF_dds_8gpu.yaml +27 -0
  214. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-32GF_dds_8gpu.yaml +27 -0
  215. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-4.0GF_dds_8gpu.yaml +27 -0
  216. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-400MF_dds_8gpu.yaml +27 -0
  217. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-6.4GF_dds_8gpu.yaml +27 -0
  218. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-600MF_dds_8gpu.yaml +27 -0
  219. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-8.0GF_dds_8gpu.yaml +27 -0
  220. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-800MF_dds_8gpu.yaml +27 -0
  221. dnt/third_party/fast-reid/fastreid/modeling/backbones/repvgg.py +309 -0
  222. dnt/third_party/fast-reid/fastreid/modeling/backbones/resnest.py +365 -0
  223. dnt/third_party/fast-reid/fastreid/modeling/backbones/resnet.py +364 -0
  224. dnt/third_party/fast-reid/fastreid/modeling/backbones/resnext.py +335 -0
  225. dnt/third_party/fast-reid/fastreid/modeling/backbones/shufflenet.py +203 -0
  226. dnt/third_party/fast-reid/fastreid/modeling/backbones/vision_transformer.py +399 -0
  227. dnt/third_party/fast-reid/fastreid/modeling/heads/__init__.py +11 -0
  228. dnt/third_party/fast-reid/fastreid/modeling/heads/build.py +25 -0
  229. dnt/third_party/fast-reid/fastreid/modeling/heads/clas_head.py +36 -0
  230. dnt/third_party/fast-reid/fastreid/modeling/heads/embedding_head.py +151 -0
  231. dnt/third_party/fast-reid/fastreid/modeling/losses/__init__.py +12 -0
  232. dnt/third_party/fast-reid/fastreid/modeling/losses/circle_loss.py +71 -0
  233. dnt/third_party/fast-reid/fastreid/modeling/losses/cross_entroy_loss.py +54 -0
  234. dnt/third_party/fast-reid/fastreid/modeling/losses/focal_loss.py +92 -0
  235. dnt/third_party/fast-reid/fastreid/modeling/losses/triplet_loss.py +113 -0
  236. dnt/third_party/fast-reid/fastreid/modeling/losses/utils.py +48 -0
  237. dnt/third_party/fast-reid/fastreid/modeling/meta_arch/__init__.py +14 -0
  238. dnt/third_party/fast-reid/fastreid/modeling/meta_arch/baseline.py +188 -0
  239. dnt/third_party/fast-reid/fastreid/modeling/meta_arch/build.py +26 -0
  240. dnt/third_party/fast-reid/fastreid/modeling/meta_arch/distiller.py +140 -0
  241. dnt/third_party/fast-reid/fastreid/modeling/meta_arch/mgn.py +394 -0
  242. dnt/third_party/fast-reid/fastreid/modeling/meta_arch/moco.py +126 -0
  243. dnt/third_party/fast-reid/fastreid/solver/__init__.py +8 -0
  244. dnt/third_party/fast-reid/fastreid/solver/build.py +348 -0
  245. dnt/third_party/fast-reid/fastreid/solver/lr_scheduler.py +66 -0
  246. dnt/third_party/fast-reid/fastreid/solver/optim/__init__.py +10 -0
  247. dnt/third_party/fast-reid/fastreid/solver/optim/lamb.py +123 -0
  248. dnt/third_party/fast-reid/fastreid/solver/optim/radam.py +149 -0
  249. dnt/third_party/fast-reid/fastreid/solver/optim/swa.py +246 -0
  250. dnt/third_party/fast-reid/fastreid/utils/__init__.py +6 -0
  251. dnt/third_party/fast-reid/fastreid/utils/checkpoint.py +503 -0
  252. dnt/third_party/fast-reid/fastreid/utils/collect_env.py +158 -0
  253. dnt/third_party/fast-reid/fastreid/utils/comm.py +255 -0
  254. dnt/third_party/fast-reid/fastreid/utils/compute_dist.py +200 -0
  255. dnt/third_party/fast-reid/fastreid/utils/env.py +119 -0
  256. dnt/third_party/fast-reid/fastreid/utils/events.py +461 -0
  257. dnt/third_party/fast-reid/fastreid/utils/faiss_utils.py +127 -0
  258. dnt/third_party/fast-reid/fastreid/utils/file_io.py +520 -0
  259. dnt/third_party/fast-reid/fastreid/utils/history_buffer.py +71 -0
  260. dnt/third_party/fast-reid/fastreid/utils/logger.py +211 -0
  261. dnt/third_party/fast-reid/fastreid/utils/params.py +103 -0
  262. dnt/third_party/fast-reid/fastreid/utils/precision_bn.py +94 -0
  263. dnt/third_party/fast-reid/fastreid/utils/registry.py +66 -0
  264. dnt/third_party/fast-reid/fastreid/utils/summary.py +120 -0
  265. dnt/third_party/fast-reid/fastreid/utils/timer.py +68 -0
  266. dnt/third_party/fast-reid/fastreid/utils/visualizer.py +278 -0
  267. dnt/track/__init__.py +3 -1
  268. dnt/track/botsort/__init__.py +4 -0
  269. dnt/track/botsort/bot_tracker/__init__.py +3 -0
  270. dnt/track/botsort/bot_tracker/basetrack.py +60 -0
  271. dnt/track/botsort/bot_tracker/bot_sort.py +473 -0
  272. dnt/track/botsort/bot_tracker/gmc.py +316 -0
  273. dnt/track/botsort/bot_tracker/kalman_filter.py +269 -0
  274. dnt/track/botsort/bot_tracker/matching.py +194 -0
  275. dnt/track/botsort/bot_tracker/mc_bot_sort.py +505 -0
  276. dnt/track/{dsort/utils → botsort/bot_tracker/tracking_utils}/evaluation.py +14 -4
  277. dnt/track/{dsort/utils → botsort/bot_tracker/tracking_utils}/io.py +19 -36
  278. dnt/track/botsort/bot_tracker/tracking_utils/timer.py +37 -0
  279. dnt/track/botsort/inference.py +96 -0
  280. dnt/track/config.py +120 -0
  281. dnt/track/dsort/configs/bagtricks_R50.yml +7 -0
  282. dnt/track/dsort/configs/deep_sort.yaml +0 -0
  283. dnt/track/dsort/configs/fastreid.yaml +1 -1
  284. dnt/track/dsort/deep_sort/deep/checkpoint/ckpt.t7 +0 -0
  285. dnt/track/dsort/deep_sort/deep/feature_extractor.py +87 -8
  286. dnt/track/dsort/deep_sort/deep_sort.py +31 -21
  287. dnt/track/dsort/deep_sort/sort/detection.py +2 -1
  288. dnt/track/dsort/deep_sort/sort/iou_matching.py +0 -2
  289. dnt/track/dsort/deep_sort/sort/linear_assignment.py +0 -3
  290. dnt/track/dsort/deep_sort/sort/nn_matching.py +5 -5
  291. dnt/track/dsort/deep_sort/sort/preprocessing.py +1 -2
  292. dnt/track/dsort/deep_sort/sort/track.py +2 -1
  293. dnt/track/dsort/deep_sort/sort/tracker.py +1 -1
  294. dnt/track/dsort/dsort.py +44 -27
  295. dnt/track/re_class.py +117 -0
  296. dnt/track/sort/sort.py +9 -7
  297. dnt/track/tracker.py +225 -20
  298. dnt-0.3.1.8.dist-info/METADATA +117 -0
  299. dnt-0.3.1.8.dist-info/RECORD +315 -0
  300. {dnt-0.2.1.dist-info → dnt-0.3.1.8.dist-info}/WHEEL +1 -1
  301. dnt/track/dsort/deep_sort/deep/evaluate.py +0 -15
  302. dnt/track/dsort/deep_sort/deep/original_model.py +0 -106
  303. dnt/track/dsort/deep_sort/deep/test.py +0 -77
  304. dnt/track/dsort/deep_sort/deep/train.py +0 -189
  305. dnt/track/dsort/utils/asserts.py +0 -13
  306. dnt/track/dsort/utils/draw.py +0 -36
  307. dnt/track/dsort/utils/json_logger.py +0 -383
  308. dnt/track/dsort/utils/log.py +0 -17
  309. dnt/track/dsort/utils/parser.py +0 -35
  310. dnt/track/dsort/utils/tools.py +0 -39
  311. dnt-0.2.1.dist-info/METADATA +0 -35
  312. dnt-0.2.1.dist-info/RECORD +0 -60
  313. /dnt/{track/dsort/utils → third_party/fast-reid/checkpoint}/__init__.py +0 -0
  314. {dnt-0.2.1.dist-info → dnt-0.3.1.8.dist-info/licenses}/LICENSE +0 -0
  315. {dnt-0.2.1.dist-info → dnt-0.3.1.8.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,357 @@
1
+ # encoding: utf-8
2
+ """
3
+ credit:
4
+ https://github.com/facebookresearch/detectron2/blob/master/detectron2/engine/train_loop.py
5
+ """
6
+ import os, sys
7
+ sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))
8
+
9
+ import logging
10
+ import time
11
+ import weakref
12
+ from typing import Dict
13
+
14
+ import numpy as np
15
+ import torch
16
+ from torch.nn.parallel import DataParallel, DistributedDataParallel
17
+
18
+
19
+ import fastreid.utils.comm as comm
20
+ from fastreid.utils.events import EventStorage, get_event_storage
21
+ from fastreid.utils.params import ContiguousParams
22
+
23
# Public API of this module. AMPTrainer is a public trainer class defined here,
# so it is listed too; otherwise `from .train_loop import *` silently drops it.
__all__ = ["HookBase", "TrainerBase", "SimpleTrainer", "AMPTrainer"]

logger = logging.getLogger(__name__)
26
+
27
+
28
class HookBase:
    """
    Base class for hooks that can be registered with :class:`TrainerBase`.

    A hook may override any of the six callbacks below; the default
    implementations do nothing. The trainer invokes them in this order:

    .. code-block:: python

        hook.before_train()
        for epoch in range(start_epoch, max_epoch):
            hook.before_epoch()
            for _ in range(iters_per_epoch):
                hook.before_step()
                trainer.run_step()
                hook.after_step()
            hook.after_epoch()
        hook.after_train()

    Notes:
        1. Inside a callback, ``self.trainer`` gives access to the trainer and
           its state (for example the current iteration).
        2. Work placed in :meth:`before_step` can usually be moved to
           :meth:`after_step`. By convention :meth:`before_step` must be cheap,
           so put anything non-trivial in :meth:`after_step`; hooks that care
           about the difference between the two (e.g. timers) then behave
           correctly.

    Attributes:
        trainer: a weak proxy to the owning trainer, assigned by the trainer
            when the hook is registered.
    """

    def before_train(self):
        """Hook point: runs once, before the first iteration."""

    def after_train(self):
        """Hook point: runs once, after the last iteration."""

    def before_epoch(self):
        """Hook point: runs at the start of every epoch."""

    def after_epoch(self):
        """Hook point: runs at the end of every epoch."""

    def before_step(self):
        """Hook point: runs before every iteration (keep it cheap)."""

    def after_step(self):
        """Hook point: runs after every iteration."""
94
+
95
+
96
class TrainerBase:
    """
    Base class for an iterative trainer that drives a list of hooks.

    The only assumption made here is that training runs in a loop; what one
    step does is left to subclasses via :meth:`run_step`. Nothing in this class
    depends on a data loader, optimizer or model.

    Attributes:
        iter (int): index of the current iteration.
        epoch (int): index of the current epoch.
        start_iter (int): iteration training starts from, computed in
            :meth:`train` as ``start_epoch * iters_per_epoch`` (minimum 0 by
            convention).
        storage (EventStorage): event storage kept open for the whole of
            :meth:`train`.
    """

    def __init__(self):
        self._hooks = []

    def register_hooks(self, hooks):
        """
        Attach hooks to this trainer. Hooks run in the order they are registered.

        Args:
            hooks (list[Optional[HookBase]]): hooks to add; ``None`` entries are
                dropped.
        """
        accepted = [hook for hook in hooks if hook is not None]
        for hook in accepted:
            assert isinstance(hook, HookBase)
            # Hooks hold only a weak proxy to the trainer so the two objects never
            # own each other. A strong cycle normally does not matter, but it
            # leaks memory when an involved object defines __del__. See
            # http://engineering.hearsaysocial.com/2013/06/16/circular-references-in-python/
            hook.trainer = weakref.proxy(self)
        self._hooks.extend(accepted)

    def train(self, start_epoch: int, max_epoch: int, iters_per_epoch: int):
        """
        Run the training loop.

        Args:
            start_epoch (int): first epoch to run (inclusive).
            max_epoch (int): epoch at which training stops (exclusive).
            iters_per_epoch (int): number of :meth:`run_step` calls per epoch.

        Any exception raised during training is logged and re-raised;
        :meth:`after_train` is called in every case.
        """
        logger = logging.getLogger(__name__)
        logger.info("Starting training from epoch {}".format(start_epoch))

        self.start_iter = start_epoch * iters_per_epoch
        self.iter = self.start_iter

        with EventStorage(self.start_iter) as self.storage:
            try:
                self.before_train()
                for epoch in range(start_epoch, max_epoch):
                    self.epoch = epoch
                    self.before_epoch()
                    for _ in range(iters_per_epoch):
                        self.before_step()
                        self.run_step()
                        self.after_step()
                        self.iter += 1
                    self.after_epoch()
            except Exception:
                logger.exception("Exception during training:")
                raise
            finally:
                self.after_train()

    def _call_hooks(self, callback_name):
        """Invoke the callback named ``callback_name`` on every registered hook, in order."""
        for hook in self._hooks:
            getattr(hook, callback_name)()

    def before_train(self):
        self._call_hooks("before_train")

    def after_train(self):
        # Publish the final iteration count before the hooks observe it.
        self.storage.iter = self.iter
        self._call_hooks("after_train")

    def before_epoch(self):
        self.storage.epoch = self.epoch
        self._call_hooks("before_epoch")

    def before_step(self):
        self.storage.iter = self.iter
        self._call_hooks("before_step")

    def after_step(self):
        self._call_hooks("after_step")

    def after_epoch(self):
        self._call_hooks("after_epoch")

    def run_step(self):
        """Perform one training iteration; must be implemented by subclasses."""
        raise NotImplementedError
189
+
190
+
191
class SimpleTrainer(TrainerBase):
    """
    A simple trainer for the most common type of task:
    single-cost, single-optimizer, single-data-source iterative optimization.

    Every step it:

    1. computes the losses on one batch taken from the data loader;
    2. back-propagates the sum of those losses;
    3. updates the model with the optimizer.

    For anything fancier, either subclass :class:`TrainerBase` and implement
    your own :meth:`run_step`, or write your own training loop.
    """

    def __init__(self, model, data_loader, optimizer, param_wrapper):
        """
        Args:
            model: a torch Module. Takes a batch from ``data_loader`` and returns
                a dict of losses (their values are summed in :meth:`run_step`).
            data_loader: an iterable producing the batches fed to ``model``.
            optimizer: a torch optimizer.
            param_wrapper: parameter wrapper checked after every optimizer step.
                When it is a ``ContiguousParams`` instance,
                ``assert_buffer_is_valid()`` is called on it; any other value
                is ignored.
        """
        super().__init__()

        # Put the model into training mode here. Training a model that is in
        # eval mode is still valid: if the model (or a submodule) should behave
        # like evaluation during training, override its train() method.
        model.train()

        self.model = model
        self.data_loader = data_loader
        self._data_loader_iter = iter(data_loader)
        self.optimizer = optimizer
        self.param_wrapper = param_wrapper

    def run_step(self):
        """
        Implement the standard training logic described above.
        """
        assert self.model.training, "[SimpleTrainer] model was changed to eval mode!"
        start = time.perf_counter()
        # To customize how the data is produced or transformed, wrap the data loader.
        data = next(self._data_loader_iter)
        data_time = time.perf_counter() - start

        # To customize what is done with the model outputs, wrap the model.
        loss_dict = self.model(data)
        losses = sum(loss_dict.values())

        # To accumulate gradients or similar, wrap the optimizer with a custom
        # `zero_grad()` method.
        self.optimizer.zero_grad()
        losses.backward()

        self._write_metrics(loss_dict, data_time)

        # For gradient clipping/scaling or other processing, wrap the optimizer
        # with a custom `step()` method.
        self.optimizer.step()
        if isinstance(self.param_wrapper, ContiguousParams):
            self.param_wrapper.assert_buffer_is_valid()

    def _write_metrics(self, loss_dict: Dict[str, torch.Tensor], data_time: float):
        """
        Gather this step's metrics from all workers and record them on the main process.

        Args:
            loss_dict (dict): dict of scalar losses
            data_time (float): time taken by the dataloader iteration

        Raises:
            FloatingPointError: on the main process, if the worker-averaged total
                loss is infinite or NaN.
        """
        device = next(iter(loss_dict.values())).device

        # Use a new stream so these ops don't wait for DDP or backward
        with torch.cuda.stream(torch.cuda.Stream() if device.type == "cuda" else None):
            metrics_dict = {k: v.detach().cpu().item() for k, v in loss_dict.items()}
            metrics_dict["data_time"] = data_time

            # Gather metrics among all workers for logging.
            # This assumes DDP-style training (one process per worker).
            all_metrics_dict = comm.gather(metrics_dict)

        if comm.is_main_process():
            storage = get_event_storage()

            # data_time among workers can have high variance. The actual latency
            # caused by data_time is the maximum among workers.
            data_time = np.max([x.pop("data_time") for x in all_metrics_dict])
            storage.put_scalar("data_time", data_time)

            # Average the remaining metrics across workers.
            metrics_dict = {
                k: np.mean([x[k] for x in all_metrics_dict]) for k in all_metrics_dict[0].keys()
            }
            total_losses_reduced = sum(metrics_dict.values())
            if not np.isfinite(total_losses_reduced):
                raise FloatingPointError(
                    f"Loss became infinite or NaN at iteration={self.iter}!\n"
                    f"loss_dict = {metrics_dict}"
                )

            storage.put_scalar("total_loss", total_losses_reduced)
            # Individual losses are only worth logging separately when there is
            # more than one (otherwise it duplicates total_loss).
            if len(metrics_dict) > 1:
                storage.put_scalars(**metrics_dict)
305
+
306
+
307
class AMPTrainer(SimpleTrainer):
    """
    Like :class:`SimpleTrainer`, but uses automatic mixed precision
    in the training loop. CUDA is required (checked in :meth:`run_step`).
    """

    def __init__(self, model, data_loader, optimizer, param_wrapper, grad_scaler=None):
        """
        Args:
            model, data_loader, optimizer, param_wrapper: same as in :class:`SimpleTrainer`.
            grad_scaler: torch GradScaler used to scale the loss and step the
                optimizer. When ``None``, a default CUDA GradScaler is created.

        Raises:
            AssertionError: if ``model`` is a ``DataParallel``, or a
                ``DistributedDataParallel`` with more than one device id
                (single-process multi-device training is unsupported).
        """
        unsupported = "AMPTrainer does not support single-process multi-device training!"
        if isinstance(model, DistributedDataParallel):
            assert not (model.device_ids and len(model.device_ids) > 1), unsupported
        assert not isinstance(model, DataParallel), unsupported

        super().__init__(model, data_loader, optimizer, param_wrapper)

        if grad_scaler is None:
            grad_scaler = self._make_default_grad_scaler()
        self.grad_scaler = grad_scaler

    @staticmethod
    def _make_default_grad_scaler():
        """Create a CUDA GradScaler, preferring the non-deprecated ``torch.amp`` API."""
        try:
            # Device-generic API; newer PyTorch releases emit a FutureWarning
            # for torch.cuda.amp.GradScaler.
            from torch.amp import GradScaler
        except ImportError:
            # Older PyTorch only provides the CUDA-specific class.
            from torch.cuda.amp import GradScaler as CudaGradScaler

            return CudaGradScaler()
        return GradScaler("cuda")

    @staticmethod
    def _autocast():
        """Return a CUDA autocast context, preferring the non-deprecated ``torch.autocast``."""
        if hasattr(torch, "autocast"):
            return torch.autocast("cuda")
        # Older PyTorch: fall back to the CUDA-specific context manager.
        from torch.cuda.amp import autocast

        return autocast()

    def run_step(self):
        """
        Implement the AMP training logic: forward pass under autocast, backward
        pass through the gradient scaler, then a scaler-managed optimizer step.
        """
        assert self.model.training, "[AMPTrainer] model was changed to eval mode!"
        assert torch.cuda.is_available(), "[AMPTrainer] CUDA is required for AMP training!"

        start = time.perf_counter()
        data = next(self._data_loader_iter)
        data_time = time.perf_counter() - start

        with self._autocast():
            loss_dict = self.model(data)
            losses = sum(loss_dict.values())

        self.optimizer.zero_grad()
        self.grad_scaler.scale(losses).backward()

        self._write_metrics(loss_dict, data_time)

        self.grad_scaler.step(self.optimizer)
        self.grad_scaler.update()
        if isinstance(self.param_wrapper, ContiguousParams):
            self.param_wrapper.assert_buffer_is_valid()
@@ -0,0 +1,6 @@
1
+ from .evaluator import DatasetEvaluator, inference_context, inference_on_dataset
2
+ from .reid_evaluation import ReidEvaluator
3
+ from .clas_evaluator import ClasEvaluator
4
+ from .testing import print_csv_format, verify_results
5
+
6
# Export every public (non-underscore) name bound in this package namespace.
# Note: besides the re-exported classes/functions, this also picks up the
# submodule names that the relative imports above bind on the package.
__all__ = [name for name in list(globals()) if not name.startswith("_")]
@@ -0,0 +1,81 @@
1
+ # encoding: utf-8
2
+ """
3
+ @author: xingyu liao
4
+ @contact: sherlockliao01@gmail.com
5
+ """
6
+
7
+ import copy
8
+ import itertools
9
+ import logging
10
+ from collections import OrderedDict
11
+
12
+ import torch
13
+
14
+ from fastreid.utils import comm
15
+ from .evaluator import DatasetEvaluator
16
+
17
logger = logging.getLogger(__name__)


def accuracy(output, target, topk=(1,)):
    """
    Compute top-k accuracy percentages for each ``k`` in ``topk``.

    Args:
        output: score tensor indexed as (sample, class); ranking is done along dim 1.
        target: ground-truth class indices, one per sample (flattened with ``view(1, -1)``).
        topk: iterable of k values; the largest one determines how many
            predictions are ranked.

    Returns:
        list of 1-element float tensors, one per entry of ``topk`` (same order),
        each holding the percentage (0-100) of samples whose target appears
        among their top-k scores.
    """
    with torch.no_grad():
        largest_k = max(topk)
        num_samples = target.size(0)

        # Indices of the highest scores, transposed so row r holds every
        # sample's r-th ranked prediction.
        _, ranked = output.topk(largest_k, dim=1, largest=True, sorted=True)
        ranked = ranked.t()
        hits = ranked.eq(target.view(1, -1).expand_as(ranked))

        scale = 100.0 / num_samples
        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(scale)
            for k in topk
        ]
35
+
36
+
37
class ClasEvaluator(DatasetEvaluator):
    """
    Top-1 classification accuracy evaluator.

    :meth:`process` counts, per batch, how many samples have the target label
    as their top-1 prediction. :meth:`evaluate` sums the counts (gathering them
    to rank 0 in distributed runs) and reports the accuracy in percent under
    both ``"Acc@1"`` and ``"metric"``.
    """

    def __init__(self, cfg, output_dir=None):
        """
        Args:
            cfg: config object; stored as ``self.cfg`` (not read by this class).
            output_dir: stored as ``self._output_dir`` (not read by this class).
        """
        self.cfg = cfg
        self._output_dir = output_dir
        self._cpu_device = torch.device('cpu')

        self._predictions = []

    def reset(self):
        """Drop the per-batch counts accumulated by earlier :meth:`process` calls."""
        self._predictions = []

    def process(self, inputs, outputs):
        """
        Record the number of correct top-1 predictions for one batch.

        Args:
            inputs: dict whose ``"targets"`` entry holds the ground-truth class
                indices (presumably a 1-D tensor, one per sample).
            outputs: logits tensor (presumably (batch, num_classes)); moved to
                CPU as float32 before scoring.
        """
        pred_logits = outputs.to(self._cpu_device, torch.float32)
        labels = inputs["targets"].to(self._cpu_device)
        num_samples = labels.size(0)
        if num_samples == 0:
            # An empty batch contributes nothing (and accuracy() would divide by zero).
            return

        # measure accuracy
        acc1, = accuracy(pred_logits, labels, topk=(1,))
        # accuracy() returns a percentage; turning it back into a count goes through
        # floating point (e.g. 100/3 * 3 / 100 != 1 exactly), so round to recover
        # the exact integer count and avoid drift when summing many batches.
        num_correct_acc1 = torch.round(acc1 * num_samples / 100)

        self._predictions.append({"num_correct": num_correct_acc1, "num_samples": num_samples})

    def evaluate(self):
        """
        Aggregate the recorded counts into top-1 accuracy.

        With more than one process, counts are gathered to rank 0 and every other
        rank returns ``{}``.

        Returns:
            OrderedDict (deep copy) with ``"Acc@1"`` and ``"metric"`` both set to
            the top-1 accuracy in percent. If no samples were processed, a
            warning is logged and both are 0.0.
        """
        if comm.get_world_size() > 1:
            comm.synchronize()
            predictions = comm.gather(self._predictions, dst=0)
            predictions = list(itertools.chain(*predictions))

            if not comm.is_main_process():
                return {}
        else:
            predictions = self._predictions

        total_correct_num = sum(prediction["num_correct"] for prediction in predictions)
        total_samples = sum(prediction["num_samples"] for prediction in predictions)

        if total_samples == 0:
            logger.warning("ClasEvaluator received no samples; reporting 0.0 accuracy.")
            acc1 = 0.0
        else:
            acc1 = total_correct_num / total_samples * 100

        self._results = OrderedDict()
        self._results["Acc@1"] = acc1
        self._results["metric"] = acc1

        return copy.deepcopy(self._results)
@@ -0,0 +1,176 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+ import datetime
3
+ import logging
4
+ import time
5
+ from contextlib import contextmanager
6
+
7
+ import torch
8
+
9
+ from fastreid.utils import comm
10
+ from fastreid.utils.logger import log_every_n_seconds
11
+
12
+
13
class DatasetEvaluator:
    """
    Interface for objects that accumulate model inputs/outputs over a dataset
    and summarize them into results.

    :func:`inference_on_dataset` calls :meth:`reset` once before the loop,
    :meth:`process` after every batch, and :meth:`evaluate` at the end. Every
    method in this base class is a no-op returning None; subclasses override
    the ones they need.
    """

    def reset(self):
        """Prepare for a new evaluation round; call before the first :meth:`process`."""
        return None

    def preprocess_inputs(self, inputs):
        """Hook for preprocessing a batch of inputs; a no-op in the base class."""
        return None

    def process(self, inputs, outputs):
        """
        Consume one input/output pair.

        Args:
            inputs: the object that was passed to the model.
            outputs: what ``model(inputs)`` returned.
        """
        return None

    def evaluate(self):
        """
        Summarize everything seen by :meth:`process`.

        Returns:
            dict: results in any format the caller knows how to consume, e.g.
            a mapping of task name to ``{metric name: score}`` such as
            ``{"AP50": 80}``. :func:`inference_on_dataset` treats a None return
            as an empty dict.
        """
        return None
53
+
54
+
55
+ # class DatasetEvaluators(DatasetEvaluator):
56
+ # def __init__(self, evaluators):
57
+ # assert len(evaluators)
58
+ # super().__init__()
59
+ # self._evaluators = evaluators
60
+ #
61
+ # def reset(self):
62
+ # for evaluator in self._evaluators:
63
+ # evaluator.reset()
64
+ #
65
+ # def process(self, input, output):
66
+ # for evaluator in self._evaluators:
67
+ # evaluator.process(input, output)
68
+ #
69
+ # def evaluate(self):
70
+ # results = OrderedDict()
71
+ # for evaluator in self._evaluators:
72
+ # result = evaluator.evaluate()
73
+ # if is_main_process() and result is not None:
74
+ # for k, v in result.items():
75
+ # assert (
76
+ # k not in results
77
+ # ), "Different evaluators produce results with the same key {}".format(k)
78
+ # results[k] = v
79
+ # return results
80
+
81
+
82
def inference_on_dataset(model, data_loader, evaluator, flip_test=False):
    """
    Run ``model`` over every batch of ``data_loader`` and score the outputs with ``evaluator``.

    The model is put in eval mode for the duration (via :func:`inference_context`)
    and gradients are disabled. Progress and timing are logged. The first
    ``min(5, len(data_loader) - 1)`` batches are treated as warm-up and excluded
    from the timing statistics.

    Args:
        model (nn.Module): called as ``model(inputs)`` on each element of
            ``data_loader``. To evaluate in training mode, wrap it and override
            its ``.eval()``/``.train()`` behaviour.
        data_loader: iterable with a fixed ``len()`` whose ``.dataset`` also
            supports ``len()``; its elements are passed to the model unchanged.
        evaluator (DatasetEvaluator): ``reset()`` is called first,
            ``process(inputs, outputs)`` after every batch, and ``evaluate()``
            at the end.
        flip_test (bool): if True, the model is run a second time on
            ``inputs["images"]`` flipped along dim 3, and the two outputs are
            averaged. Note that ``inputs["images"]`` is overwritten with the
            flipped tensor, so the evaluator sees the flipped images.

    Returns:
        The value returned by ``evaluator.evaluate()``, or ``{}`` if that was None.
    """
    world_size = comm.get_world_size()
    logger = logging.getLogger(__name__)
    logger.info("Start inference on {} images".format(len(data_loader.dataset)))

    num_batches = len(data_loader)  # inference data loader must have a fixed length
    evaluator.reset()

    num_warmup = min(5, num_batches - 1)
    wall_start = time.perf_counter()
    compute_seconds = 0
    with inference_context(model), torch.no_grad():
        for batch_idx, inputs in enumerate(data_loader):
            if batch_idx == num_warmup:
                # Warm-up is over: restart both clocks.
                wall_start = time.perf_counter()
                compute_seconds = 0

            step_start = time.perf_counter()
            outputs = model(inputs)
            if flip_test:
                # Test-time augmentation: average with the prediction on flipped images.
                inputs["images"] = inputs["images"].flip(dims=[3])
                outputs = (outputs + model(inputs)) / 2
            if torch.cuda.is_available():
                # CUDA work is asynchronous; wait for it so the measured time is real.
                torch.cuda.synchronize()
            compute_seconds += time.perf_counter() - step_start
            evaluator.process(inputs, outputs)

            timed_iters = batch_idx + 1 - (num_warmup if batch_idx >= num_warmup else 0)
            seconds_per_batch = compute_seconds / timed_iters
            if batch_idx >= 2 * num_warmup or seconds_per_batch > 30:
                wall_per_batch = (time.perf_counter() - wall_start) / timed_iters
                remaining = num_batches - batch_idx - 1
                eta = datetime.timedelta(seconds=int(wall_per_batch * remaining))
                log_every_n_seconds(
                    logging.INFO,
                    "Inference done {}/{}. {:.4f} s / batch. ETA={}".format(
                        batch_idx + 1, num_batches, seconds_per_batch, str(eta)
                    ),
                    n=30,
                )

    # Measure the time only for this worker (before the synchronization barrier)
    total_time = time.perf_counter() - wall_start
    timed_batches = num_batches - num_warmup
    # NOTE this format is parsed by grep
    logger.info(
        "Total inference time: {} ({:.6f} s / batch per device, on {} devices)".format(
            str(datetime.timedelta(seconds=total_time)), total_time / timed_batches, world_size
        )
    )
    logger.info(
        "Total inference pure compute time: {} ({:.6f} s / batch per device, on {} devices)".format(
            str(datetime.timedelta(seconds=int(compute_seconds))),
            compute_seconds / timed_batches,
            world_size,
        )
    )

    results = evaluator.evaluate()
    # An evaluator may return None when not in the main process; hand callers {} instead.
    return {} if results is None else results
163
+
164
+
165
@contextmanager
def inference_context(model):
    """
    Temporarily switch ``model`` to eval mode, restoring its previous mode on exit.

    The previous training flag is restored in a ``finally`` clause. An exception
    raised inside the ``with`` body therefore does not leave a training model
    stuck in eval mode.

    Args:
        model: a torch Module (anything with ``.training``, ``.eval()`` and
            ``.train(mode)``).
    """
    training_mode = model.training
    model.eval()
    try:
        yield
    finally:
        model.train(training_mode)
@@ -0,0 +1,46 @@
1
+ # encoding: utf-8
2
+ """
3
+ @author: xingyu liao
4
+ @contact: sherlockliao01@gmail.com
5
+ """
6
+
7
+ # based on
8
+ # https://github.com/PyRetri/PyRetri/blob/master/pyretri/index/re_ranker/re_ranker_impl/query_expansion.py
9
+
10
+ import numpy as np
11
+ import torch
12
+ import torch.nn.functional as F
13
+
14
+
15
def aqe(query_feat: torch.Tensor, gallery_feat: torch.Tensor,
        qe_times: int = 1, qe_k: int = 10, alpha: float = 3.0):
    """
    Alpha-weighted query expansion over the joint query + gallery set.

    Each feature is replaced by the mean of its ``qe_k`` most cosine-similar
    features (normally including itself), each weighted by ``similarity ** alpha``.
    This is repeated ``qe_times`` times; the similarities are recomputed from
    the expanded features each round.
    c.f. https://www.robots.ox.ac.uk/~vgg/publications/papers/chum07b.pdf

    Args:
        query_feat (torch.Tensor): 2-D query features (num_query, dim).
        gallery_feat (torch.Tensor): 2-D gallery features (num_gallery, dim);
            must have the same feature dimension as ``query_feat``.
        qe_times (int): number of query expansion rounds.
        qe_k (int): number of neighbours combined per feature; must be smaller
            than num_query + num_gallery (``np.argpartition`` kth bound).
        alpha (float): exponent applied to the similarity weights. Negative
            similarities with a non-integer ``alpha`` yield NaN weights.

    Returns:
        tuple(torch.Tensor, torch.Tensor): expanded (query, gallery) features
        as CPU tensors. They are not L2-normalised.
    """
    num_query = query_feat.shape[0]
    all_feat = torch.cat((query_feat, gallery_feat), dim=0)
    norm_feat = F.normalize(all_feat, p=2, dim=1)

    # Tensor.numpy() rejects CUDA tensors and tensors that require grad;
    # detach()/cpu() handle both and are no-copy for plain CPU tensors.
    all_feat = all_feat.detach().cpu().numpy()
    for _ in range(qe_times):
        all_feat_list = []
        sims = torch.mm(norm_feat, norm_feat.t())
        sims = sims.data.cpu().numpy()
        for sim in sims:
            # Partial sort so the first qe_k positions hold the most similar entries.
            init_rank = np.argpartition(-sim, range(1, qe_k + 1))
            weights = sim[init_rank[:qe_k]].reshape((-1, 1))
            weights = np.power(weights, alpha)
            all_feat_list.append(np.mean(all_feat[init_rank[:qe_k], :] * weights, axis=0))
        all_feat = np.stack(all_feat_list, axis=0)
        norm_feat = F.normalize(torch.from_numpy(all_feat), p=2, dim=1)

    query_feat = torch.from_numpy(all_feat[:num_query])
    gallery_feat = torch.from_numpy(all_feat[num_query:])
    return query_feat, gallery_feat
+ return query_feat, gallery_feat