dnt 0.2.4__py3-none-any.whl → 0.3.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of dnt might be problematic. Click here for more details.

Files changed (305)
  1. dnt/__init__.py +3 -2
  2. dnt/analysis/__init__.py +3 -2
  3. dnt/analysis/interaction.py +503 -0
  4. dnt/analysis/stop.py +22 -17
  5. dnt/analysis/stop2.py +289 -0
  6. dnt/analysis/stop3.py +754 -0
  7. dnt/detect/signal/detector.py +317 -0
  8. dnt/detect/yolov8/detector.py +116 -16
  9. dnt/engine/__init__.py +8 -0
  10. dnt/engine/bbox_interp.py +83 -0
  11. dnt/engine/bbox_iou.py +20 -0
  12. dnt/engine/cluster.py +31 -0
  13. dnt/engine/iob.py +66 -0
  14. dnt/filter/filter.py +321 -1
  15. dnt/label/labeler.py +4 -4
  16. dnt/label/labeler2.py +502 -0
  17. dnt/shared/__init__.py +2 -1
  18. dnt/shared/data/coco.names +0 -0
  19. dnt/shared/data/openimages.names +0 -0
  20. dnt/shared/data/voc.names +0 -0
  21. dnt/shared/download.py +12 -0
  22. dnt/shared/synhcro.py +150 -0
  23. dnt/shared/util.py +17 -4
  24. dnt/third_party/fast-reid/__init__.py +1 -0
  25. dnt/third_party/fast-reid/configs/Base-AGW.yml +19 -0
  26. dnt/third_party/fast-reid/configs/Base-MGN.yml +12 -0
  27. dnt/third_party/fast-reid/configs/Base-SBS.yml +63 -0
  28. dnt/third_party/fast-reid/configs/Base-bagtricks.yml +76 -0
  29. dnt/third_party/fast-reid/configs/DukeMTMC/AGW_R101-ibn.yml +12 -0
  30. dnt/third_party/fast-reid/configs/DukeMTMC/AGW_R50-ibn.yml +11 -0
  31. dnt/third_party/fast-reid/configs/DukeMTMC/AGW_R50.yml +7 -0
  32. dnt/third_party/fast-reid/configs/DukeMTMC/AGW_S50.yml +11 -0
  33. dnt/third_party/fast-reid/configs/DukeMTMC/bagtricks_R101-ibn.yml +12 -0
  34. dnt/third_party/fast-reid/configs/DukeMTMC/bagtricks_R50-ibn.yml +11 -0
  35. dnt/third_party/fast-reid/configs/DukeMTMC/bagtricks_R50.yml +7 -0
  36. dnt/third_party/fast-reid/configs/DukeMTMC/bagtricks_S50.yml +11 -0
  37. dnt/third_party/fast-reid/configs/DukeMTMC/mgn_R50-ibn.yml +11 -0
  38. dnt/third_party/fast-reid/configs/DukeMTMC/sbs_R101-ibn.yml +12 -0
  39. dnt/third_party/fast-reid/configs/DukeMTMC/sbs_R50-ibn.yml +11 -0
  40. dnt/third_party/fast-reid/configs/DukeMTMC/sbs_R50.yml +7 -0
  41. dnt/third_party/fast-reid/configs/DukeMTMC/sbs_S50.yml +11 -0
  42. dnt/third_party/fast-reid/configs/MOT17/AGW_R101-ibn.yml +12 -0
  43. dnt/third_party/fast-reid/configs/MOT17/AGW_R50-ibn.yml +11 -0
  44. dnt/third_party/fast-reid/configs/MOT17/AGW_R50.yml +7 -0
  45. dnt/third_party/fast-reid/configs/MOT17/AGW_S50.yml +11 -0
  46. dnt/third_party/fast-reid/configs/MOT17/bagtricks_R101-ibn.yml +12 -0
  47. dnt/third_party/fast-reid/configs/MOT17/bagtricks_R50-ibn.yml +11 -0
  48. dnt/third_party/fast-reid/configs/MOT17/bagtricks_R50.yml +7 -0
  49. dnt/third_party/fast-reid/configs/MOT17/bagtricks_S50.yml +11 -0
  50. dnt/third_party/fast-reid/configs/MOT17/mgn_R50-ibn.yml +11 -0
  51. dnt/third_party/fast-reid/configs/MOT17/sbs_R101-ibn.yml +12 -0
  52. dnt/third_party/fast-reid/configs/MOT17/sbs_R50-ibn.yml +11 -0
  53. dnt/third_party/fast-reid/configs/MOT17/sbs_R50.yml +7 -0
  54. dnt/third_party/fast-reid/configs/MOT17/sbs_S50.yml +11 -0
  55. dnt/third_party/fast-reid/configs/MOT20/AGW_R101-ibn.yml +12 -0
  56. dnt/third_party/fast-reid/configs/MOT20/AGW_R50-ibn.yml +11 -0
  57. dnt/third_party/fast-reid/configs/MOT20/AGW_R50.yml +7 -0
  58. dnt/third_party/fast-reid/configs/MOT20/AGW_S50.yml +11 -0
  59. dnt/third_party/fast-reid/configs/MOT20/bagtricks_R101-ibn.yml +12 -0
  60. dnt/third_party/fast-reid/configs/MOT20/bagtricks_R50-ibn.yml +11 -0
  61. dnt/third_party/fast-reid/configs/MOT20/bagtricks_R50.yml +7 -0
  62. dnt/third_party/fast-reid/configs/MOT20/bagtricks_S50.yml +11 -0
  63. dnt/third_party/fast-reid/configs/MOT20/mgn_R50-ibn.yml +11 -0
  64. dnt/third_party/fast-reid/configs/MOT20/sbs_R101-ibn.yml +12 -0
  65. dnt/third_party/fast-reid/configs/MOT20/sbs_R50-ibn.yml +11 -0
  66. dnt/third_party/fast-reid/configs/MOT20/sbs_R50.yml +7 -0
  67. dnt/third_party/fast-reid/configs/MOT20/sbs_S50.yml +11 -0
  68. dnt/third_party/fast-reid/configs/MSMT17/AGW_R101-ibn.yml +12 -0
  69. dnt/third_party/fast-reid/configs/MSMT17/AGW_R50-ibn.yml +11 -0
  70. dnt/third_party/fast-reid/configs/MSMT17/AGW_R50.yml +7 -0
  71. dnt/third_party/fast-reid/configs/MSMT17/AGW_S50.yml +11 -0
  72. dnt/third_party/fast-reid/configs/MSMT17/bagtricks_R101-ibn.yml +13 -0
  73. dnt/third_party/fast-reid/configs/MSMT17/bagtricks_R50-ibn.yml +12 -0
  74. dnt/third_party/fast-reid/configs/MSMT17/bagtricks_R50.yml +7 -0
  75. dnt/third_party/fast-reid/configs/MSMT17/bagtricks_S50.yml +12 -0
  76. dnt/third_party/fast-reid/configs/MSMT17/mgn_R50-ibn.yml +11 -0
  77. dnt/third_party/fast-reid/configs/MSMT17/sbs_R101-ibn.yml +12 -0
  78. dnt/third_party/fast-reid/configs/MSMT17/sbs_R50-ibn.yml +11 -0
  79. dnt/third_party/fast-reid/configs/MSMT17/sbs_R50.yml +7 -0
  80. dnt/third_party/fast-reid/configs/MSMT17/sbs_S50.yml +11 -0
  81. dnt/third_party/fast-reid/configs/Market1501/AGW_R101-ibn.yml +12 -0
  82. dnt/third_party/fast-reid/configs/Market1501/AGW_R50-ibn.yml +11 -0
  83. dnt/third_party/fast-reid/configs/Market1501/AGW_R50.yml +7 -0
  84. dnt/third_party/fast-reid/configs/Market1501/AGW_S50.yml +11 -0
  85. dnt/third_party/fast-reid/configs/Market1501/bagtricks_R101-ibn.yml +12 -0
  86. dnt/third_party/fast-reid/configs/Market1501/bagtricks_R50-ibn.yml +11 -0
  87. dnt/third_party/fast-reid/configs/Market1501/bagtricks_R50.yml +7 -0
  88. dnt/third_party/fast-reid/configs/Market1501/bagtricks_S50.yml +11 -0
  89. dnt/third_party/fast-reid/configs/Market1501/bagtricks_vit.yml +88 -0
  90. dnt/third_party/fast-reid/configs/Market1501/mgn_R50-ibn.yml +11 -0
  91. dnt/third_party/fast-reid/configs/Market1501/sbs_R101-ibn.yml +12 -0
  92. dnt/third_party/fast-reid/configs/Market1501/sbs_R50-ibn.yml +11 -0
  93. dnt/third_party/fast-reid/configs/Market1501/sbs_R50.yml +7 -0
  94. dnt/third_party/fast-reid/configs/Market1501/sbs_S50.yml +11 -0
  95. dnt/third_party/fast-reid/configs/VERIWild/bagtricks_R50-ibn.yml +35 -0
  96. dnt/third_party/fast-reid/configs/VeRi/sbs_R50-ibn.yml +35 -0
  97. dnt/third_party/fast-reid/configs/VehicleID/bagtricks_R50-ibn.yml +36 -0
  98. dnt/third_party/fast-reid/configs/__init__.py +0 -0
  99. dnt/third_party/fast-reid/fast_reid_interfece.py +175 -0
  100. dnt/third_party/fast-reid/fastreid/__init__.py +6 -0
  101. dnt/third_party/fast-reid/fastreid/config/__init__.py +15 -0
  102. dnt/third_party/fast-reid/fastreid/config/config.py +319 -0
  103. dnt/third_party/fast-reid/fastreid/config/defaults.py +329 -0
  104. dnt/third_party/fast-reid/fastreid/data/__init__.py +17 -0
  105. dnt/third_party/fast-reid/fastreid/data/build.py +194 -0
  106. dnt/third_party/fast-reid/fastreid/data/common.py +58 -0
  107. dnt/third_party/fast-reid/fastreid/data/data_utils.py +202 -0
  108. dnt/third_party/fast-reid/fastreid/data/datasets/AirportALERT.py +50 -0
  109. dnt/third_party/fast-reid/fastreid/data/datasets/__init__.py +43 -0
  110. dnt/third_party/fast-reid/fastreid/data/datasets/bases.py +183 -0
  111. dnt/third_party/fast-reid/fastreid/data/datasets/caviara.py +44 -0
  112. dnt/third_party/fast-reid/fastreid/data/datasets/cuhk03.py +274 -0
  113. dnt/third_party/fast-reid/fastreid/data/datasets/cuhk_sysu.py +58 -0
  114. dnt/third_party/fast-reid/fastreid/data/datasets/dukemtmcreid.py +70 -0
  115. dnt/third_party/fast-reid/fastreid/data/datasets/grid.py +44 -0
  116. dnt/third_party/fast-reid/fastreid/data/datasets/iLIDS.py +45 -0
  117. dnt/third_party/fast-reid/fastreid/data/datasets/lpw.py +49 -0
  118. dnt/third_party/fast-reid/fastreid/data/datasets/market1501.py +89 -0
  119. dnt/third_party/fast-reid/fastreid/data/datasets/msmt17.py +114 -0
  120. dnt/third_party/fast-reid/fastreid/data/datasets/pes3d.py +44 -0
  121. dnt/third_party/fast-reid/fastreid/data/datasets/pku.py +44 -0
  122. dnt/third_party/fast-reid/fastreid/data/datasets/prai.py +43 -0
  123. dnt/third_party/fast-reid/fastreid/data/datasets/prid.py +41 -0
  124. dnt/third_party/fast-reid/fastreid/data/datasets/saivt.py +47 -0
  125. dnt/third_party/fast-reid/fastreid/data/datasets/sensereid.py +47 -0
  126. dnt/third_party/fast-reid/fastreid/data/datasets/shinpuhkan.py +48 -0
  127. dnt/third_party/fast-reid/fastreid/data/datasets/sysu_mm.py +47 -0
  128. dnt/third_party/fast-reid/fastreid/data/datasets/thermalworld.py +43 -0
  129. dnt/third_party/fast-reid/fastreid/data/datasets/vehicleid.py +126 -0
  130. dnt/third_party/fast-reid/fastreid/data/datasets/veri.py +69 -0
  131. dnt/third_party/fast-reid/fastreid/data/datasets/veriwild.py +140 -0
  132. dnt/third_party/fast-reid/fastreid/data/datasets/viper.py +45 -0
  133. dnt/third_party/fast-reid/fastreid/data/datasets/wildtracker.py +59 -0
  134. dnt/third_party/fast-reid/fastreid/data/samplers/__init__.py +18 -0
  135. dnt/third_party/fast-reid/fastreid/data/samplers/data_sampler.py +85 -0
  136. dnt/third_party/fast-reid/fastreid/data/samplers/imbalance_sampler.py +67 -0
  137. dnt/third_party/fast-reid/fastreid/data/samplers/triplet_sampler.py +260 -0
  138. dnt/third_party/fast-reid/fastreid/data/transforms/__init__.py +11 -0
  139. dnt/third_party/fast-reid/fastreid/data/transforms/autoaugment.py +806 -0
  140. dnt/third_party/fast-reid/fastreid/data/transforms/build.py +100 -0
  141. dnt/third_party/fast-reid/fastreid/data/transforms/functional.py +180 -0
  142. dnt/third_party/fast-reid/fastreid/data/transforms/transforms.py +161 -0
  143. dnt/third_party/fast-reid/fastreid/engine/__init__.py +15 -0
  144. dnt/third_party/fast-reid/fastreid/engine/defaults.py +490 -0
  145. dnt/third_party/fast-reid/fastreid/engine/hooks.py +534 -0
  146. dnt/third_party/fast-reid/fastreid/engine/launch.py +103 -0
  147. dnt/third_party/fast-reid/fastreid/engine/train_loop.py +357 -0
  148. dnt/third_party/fast-reid/fastreid/evaluation/__init__.py +6 -0
  149. dnt/third_party/fast-reid/fastreid/evaluation/clas_evaluator.py +81 -0
  150. dnt/third_party/fast-reid/fastreid/evaluation/evaluator.py +176 -0
  151. dnt/third_party/fast-reid/fastreid/evaluation/query_expansion.py +46 -0
  152. dnt/third_party/fast-reid/fastreid/evaluation/rank.py +200 -0
  153. dnt/third_party/fast-reid/fastreid/evaluation/rank_cylib/__init__.py +20 -0
  154. dnt/third_party/fast-reid/fastreid/evaluation/rank_cylib/setup.py +32 -0
  155. dnt/third_party/fast-reid/fastreid/evaluation/rank_cylib/test_cython.py +106 -0
  156. dnt/third_party/fast-reid/fastreid/evaluation/reid_evaluation.py +143 -0
  157. dnt/third_party/fast-reid/fastreid/evaluation/rerank.py +73 -0
  158. dnt/third_party/fast-reid/fastreid/evaluation/roc.py +90 -0
  159. dnt/third_party/fast-reid/fastreid/evaluation/testing.py +88 -0
  160. dnt/third_party/fast-reid/fastreid/layers/__init__.py +19 -0
  161. dnt/third_party/fast-reid/fastreid/layers/activation.py +59 -0
  162. dnt/third_party/fast-reid/fastreid/layers/any_softmax.py +80 -0
  163. dnt/third_party/fast-reid/fastreid/layers/batch_norm.py +205 -0
  164. dnt/third_party/fast-reid/fastreid/layers/context_block.py +113 -0
  165. dnt/third_party/fast-reid/fastreid/layers/drop.py +161 -0
  166. dnt/third_party/fast-reid/fastreid/layers/frn.py +199 -0
  167. dnt/third_party/fast-reid/fastreid/layers/gather_layer.py +30 -0
  168. dnt/third_party/fast-reid/fastreid/layers/helpers.py +31 -0
  169. dnt/third_party/fast-reid/fastreid/layers/non_local.py +54 -0
  170. dnt/third_party/fast-reid/fastreid/layers/pooling.py +124 -0
  171. dnt/third_party/fast-reid/fastreid/layers/se_layer.py +25 -0
  172. dnt/third_party/fast-reid/fastreid/layers/splat.py +109 -0
  173. dnt/third_party/fast-reid/fastreid/layers/weight_init.py +122 -0
  174. dnt/third_party/fast-reid/fastreid/modeling/__init__.py +23 -0
  175. dnt/third_party/fast-reid/fastreid/modeling/backbones/__init__.py +18 -0
  176. dnt/third_party/fast-reid/fastreid/modeling/backbones/build.py +27 -0
  177. dnt/third_party/fast-reid/fastreid/modeling/backbones/mobilenet.py +195 -0
  178. dnt/third_party/fast-reid/fastreid/modeling/backbones/mobilenetv3.py +283 -0
  179. dnt/third_party/fast-reid/fastreid/modeling/backbones/osnet.py +525 -0
  180. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/__init__.py +4 -0
  181. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/config.py +396 -0
  182. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/effnet/EN-B0_dds_8gpu.yaml +27 -0
  183. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/effnet/EN-B1_dds_8gpu.yaml +27 -0
  184. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/effnet/EN-B2_dds_8gpu.yaml +27 -0
  185. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/effnet/EN-B3_dds_8gpu.yaml +27 -0
  186. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/effnet/EN-B4_dds_8gpu.yaml +27 -0
  187. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/effnet/EN-B5_dds_8gpu.yaml +27 -0
  188. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/effnet.py +281 -0
  189. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnet.py +596 -0
  190. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-1.6GF_dds_8gpu.yaml +26 -0
  191. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-12GF_dds_8gpu.yaml +26 -0
  192. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-16GF_dds_8gpu.yaml +26 -0
  193. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-200MF_dds_8gpu.yaml +26 -0
  194. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-3.2GF_dds_8gpu.yaml +26 -0
  195. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-32GF_dds_8gpu.yaml +26 -0
  196. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-4.0GF_dds_8gpu.yaml +26 -0
  197. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-400MF_dds_8gpu.yaml +26 -0
  198. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-6.4GF_dds_8gpu.yaml +26 -0
  199. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-600MF_dds_8gpu.yaml +26 -0
  200. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-8.0GF_dds_8gpu.yaml +26 -0
  201. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-800MF_dds_8gpu.yaml +26 -0
  202. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-1.6GF_dds_8gpu.yaml +27 -0
  203. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-12GF_dds_8gpu.yaml +27 -0
  204. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-16GF_dds_8gpu.yaml +27 -0
  205. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-200MF_dds_8gpu.yaml +26 -0
  206. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-3.2GF_dds_8gpu.yaml +27 -0
  207. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-32GF_dds_8gpu.yaml +27 -0
  208. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-4.0GF_dds_8gpu.yaml +27 -0
  209. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-400MF_dds_8gpu.yaml +27 -0
  210. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-6.4GF_dds_8gpu.yaml +27 -0
  211. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-600MF_dds_8gpu.yaml +27 -0
  212. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-8.0GF_dds_8gpu.yaml +27 -0
  213. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-800MF_dds_8gpu.yaml +27 -0
  214. dnt/third_party/fast-reid/fastreid/modeling/backbones/repvgg.py +309 -0
  215. dnt/third_party/fast-reid/fastreid/modeling/backbones/resnest.py +365 -0
  216. dnt/third_party/fast-reid/fastreid/modeling/backbones/resnet.py +364 -0
  217. dnt/third_party/fast-reid/fastreid/modeling/backbones/resnext.py +335 -0
  218. dnt/third_party/fast-reid/fastreid/modeling/backbones/shufflenet.py +203 -0
  219. dnt/third_party/fast-reid/fastreid/modeling/backbones/vision_transformer.py +399 -0
  220. dnt/third_party/fast-reid/fastreid/modeling/heads/__init__.py +11 -0
  221. dnt/third_party/fast-reid/fastreid/modeling/heads/build.py +25 -0
  222. dnt/third_party/fast-reid/fastreid/modeling/heads/clas_head.py +36 -0
  223. dnt/third_party/fast-reid/fastreid/modeling/heads/embedding_head.py +151 -0
  224. dnt/third_party/fast-reid/fastreid/modeling/losses/__init__.py +12 -0
  225. dnt/third_party/fast-reid/fastreid/modeling/losses/circle_loss.py +71 -0
  226. dnt/third_party/fast-reid/fastreid/modeling/losses/cross_entroy_loss.py +54 -0
  227. dnt/third_party/fast-reid/fastreid/modeling/losses/focal_loss.py +92 -0
  228. dnt/third_party/fast-reid/fastreid/modeling/losses/triplet_loss.py +113 -0
  229. dnt/third_party/fast-reid/fastreid/modeling/losses/utils.py +48 -0
  230. dnt/third_party/fast-reid/fastreid/modeling/meta_arch/__init__.py +14 -0
  231. dnt/third_party/fast-reid/fastreid/modeling/meta_arch/baseline.py +188 -0
  232. dnt/third_party/fast-reid/fastreid/modeling/meta_arch/build.py +26 -0
  233. dnt/third_party/fast-reid/fastreid/modeling/meta_arch/distiller.py +140 -0
  234. dnt/third_party/fast-reid/fastreid/modeling/meta_arch/mgn.py +394 -0
  235. dnt/third_party/fast-reid/fastreid/modeling/meta_arch/moco.py +126 -0
  236. dnt/third_party/fast-reid/fastreid/solver/__init__.py +8 -0
  237. dnt/third_party/fast-reid/fastreid/solver/build.py +348 -0
  238. dnt/third_party/fast-reid/fastreid/solver/lr_scheduler.py +66 -0
  239. dnt/third_party/fast-reid/fastreid/solver/optim/__init__.py +10 -0
  240. dnt/third_party/fast-reid/fastreid/solver/optim/lamb.py +123 -0
  241. dnt/third_party/fast-reid/fastreid/solver/optim/radam.py +149 -0
  242. dnt/third_party/fast-reid/fastreid/solver/optim/swa.py +246 -0
  243. dnt/third_party/fast-reid/fastreid/utils/__init__.py +6 -0
  244. dnt/third_party/fast-reid/fastreid/utils/checkpoint.py +503 -0
  245. dnt/third_party/fast-reid/fastreid/utils/collect_env.py +158 -0
  246. dnt/third_party/fast-reid/fastreid/utils/comm.py +255 -0
  247. dnt/third_party/fast-reid/fastreid/utils/compute_dist.py +200 -0
  248. dnt/third_party/fast-reid/fastreid/utils/env.py +119 -0
  249. dnt/third_party/fast-reid/fastreid/utils/events.py +461 -0
  250. dnt/third_party/fast-reid/fastreid/utils/faiss_utils.py +127 -0
  251. dnt/third_party/fast-reid/fastreid/utils/file_io.py +520 -0
  252. dnt/third_party/fast-reid/fastreid/utils/history_buffer.py +71 -0
  253. dnt/third_party/fast-reid/fastreid/utils/logger.py +211 -0
  254. dnt/third_party/fast-reid/fastreid/utils/params.py +103 -0
  255. dnt/third_party/fast-reid/fastreid/utils/precision_bn.py +94 -0
  256. dnt/third_party/fast-reid/fastreid/utils/registry.py +66 -0
  257. dnt/third_party/fast-reid/fastreid/utils/summary.py +120 -0
  258. dnt/third_party/fast-reid/fastreid/utils/timer.py +68 -0
  259. dnt/third_party/fast-reid/fastreid/utils/visualizer.py +278 -0
  260. dnt/track/__init__.py +2 -0
  261. dnt/track/botsort/__init__.py +4 -0
  262. dnt/track/botsort/bot_tracker/__init__.py +3 -0
  263. dnt/track/botsort/bot_tracker/basetrack.py +60 -0
  264. dnt/track/botsort/bot_tracker/bot_sort.py +473 -0
  265. dnt/track/botsort/bot_tracker/gmc.py +316 -0
  266. dnt/track/botsort/bot_tracker/kalman_filter.py +269 -0
  267. dnt/track/botsort/bot_tracker/matching.py +194 -0
  268. dnt/track/botsort/bot_tracker/mc_bot_sort.py +505 -0
  269. dnt/track/{dsort/utils → botsort/bot_tracker/tracking_utils}/evaluation.py +14 -4
  270. dnt/track/{dsort/utils → botsort/bot_tracker/tracking_utils}/io.py +19 -36
  271. dnt/track/botsort/bot_tracker/tracking_utils/timer.py +37 -0
  272. dnt/track/botsort/inference.py +96 -0
  273. dnt/track/config.py +120 -0
  274. dnt/track/dsort/configs/bagtricks_R50.yml +7 -0
  275. dnt/track/dsort/configs/deep_sort.yaml +0 -0
  276. dnt/track/dsort/configs/fastreid.yaml +1 -1
  277. dnt/track/dsort/deep_sort/deep/checkpoint/ckpt.t7 +0 -0
  278. dnt/track/dsort/deep_sort/deep/feature_extractor.py +87 -8
  279. dnt/track/dsort/deep_sort/deep_sort.py +28 -18
  280. dnt/track/dsort/deep_sort/sort/iou_matching.py +0 -2
  281. dnt/track/dsort/deep_sort/sort/linear_assignment.py +0 -3
  282. dnt/track/dsort/deep_sort/sort/nn_matching.py +5 -5
  283. dnt/track/dsort/deep_sort/sort/preprocessing.py +1 -2
  284. dnt/track/dsort/dsort.py +21 -28
  285. dnt/track/re_class.py +94 -0
  286. dnt/track/sort/sort.py +5 -1
  287. dnt/track/tracker.py +207 -30
  288. {dnt-0.2.4.dist-info → dnt-0.3.1.3.dist-info}/METADATA +30 -10
  289. dnt-0.3.1.3.dist-info/RECORD +314 -0
  290. {dnt-0.2.4.dist-info → dnt-0.3.1.3.dist-info}/WHEEL +1 -1
  291. dnt/analysis/yield.py +0 -9
  292. dnt/track/dsort/deep_sort/deep/evaluate.py +0 -15
  293. dnt/track/dsort/deep_sort/deep/original_model.py +0 -106
  294. dnt/track/dsort/deep_sort/deep/test.py +0 -77
  295. dnt/track/dsort/deep_sort/deep/train.py +0 -189
  296. dnt/track/dsort/utils/asserts.py +0 -13
  297. dnt/track/dsort/utils/draw.py +0 -36
  298. dnt/track/dsort/utils/json_logger.py +0 -383
  299. dnt/track/dsort/utils/log.py +0 -17
  300. dnt/track/dsort/utils/parser.py +0 -35
  301. dnt/track/dsort/utils/tools.py +0 -39
  302. dnt-0.2.4.dist-info/RECORD +0 -64
  303. /dnt/{track/dsort/utils → third_party/fast-reid/checkpoint}/__init__.py +0 -0
  304. {dnt-0.2.4.dist-info → dnt-0.3.1.3.dist-info}/LICENSE +0 -0
  305. {dnt-0.2.4.dist-info → dnt-0.3.1.3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,394 @@
1
+ # encoding: utf-8
2
+ """
3
+ @author: liaoxingyu
4
+ @contact: sherlockliao01@gmail.com
5
+ """
6
+ import copy
7
+
8
+ import torch
9
+ from torch import nn
10
+
11
+ from fastreid.config import configurable
12
+ from fastreid.layers import get_norm
13
+ from fastreid.modeling.backbones import build_backbone
14
+ from fastreid.modeling.backbones.resnet import Bottleneck
15
+ from fastreid.modeling.heads import build_heads
16
+ from fastreid.modeling.losses import *
17
+ from .build import META_ARCH_REGISTRY
18
+
19
+
20
@META_ARCH_REGISTRY.register()
class MGN(nn.Module):
    """
    Multiple Granularities Network architecture, which contains the following two components:
    1. Per-image feature extraction (aka backbone)
    2. Multi-branch feature aggregation
    """

    @configurable
    def __init__(
            self,
            *,
            backbone,
            neck1,
            neck2,
            neck3,
            b1_head,
            b2_head,
            b21_head,
            b22_head,
            b3_head,
            b31_head,
            b32_head,
            b33_head,
            pixel_mean,
            pixel_std,
            loss_kwargs=None
    ):
        """
        NOTE: this interface is experimental.

        Args:
            backbone: shared stem module; its output feeds all three necks.
            neck1: branch-1 module; its output is used whole (not split into stripes).
            neck2: branch-2 module; its output is also split into 2 stripes along dim 2.
            neck3: branch-3 module; its output is also split into 3 stripes along dim 2.
            b1_head: head applied to the branch-1 feature map.
            b2_head: head applied to the whole branch-2 feature map.
            b21_head, b22_head: heads applied to the two branch-2 stripes.
            b3_head: head applied to the whole branch-3 feature map.
            b31_head, b32_head, b33_head: heads applied to the three branch-3 stripes.
            pixel_mean: per-channel mean subtracted from input images.
            pixel_std: per-channel std that input images are divided by.
            loss_kwargs: dict with 'loss_names' plus per-loss hyperparameter dicts
                ('ce', 'tri', ...). Only read by ``losses``, i.e. needed for training.
        """

        super().__init__()

        self.backbone = backbone

        # branch1
        self.b1 = neck1
        self.b1_head = b1_head

        # branch2
        self.b2 = neck2
        self.b2_head = b2_head
        self.b21_head = b21_head
        self.b22_head = b22_head

        # branch3
        self.b3 = neck3
        self.b3_head = b3_head
        self.b31_head = b31_head
        self.b32_head = b32_head
        self.b33_head = b33_head

        self.loss_kwargs = loss_kwargs
        # Non-persistent buffers (third arg False): they follow .to()/.cuda() but
        # are not saved in the state_dict.
        self.register_buffer('pixel_mean', torch.Tensor(pixel_mean).view(1, -1, 1, 1), False)
        self.register_buffer('pixel_std', torch.Tensor(pixel_std).view(1, -1, 1, 1), False)

    @classmethod
    def from_config(cls, cfg):
        """Build the ``__init__`` keyword arguments from a fastreid config."""
        bn_norm = cfg.MODEL.BACKBONE.NORM
        with_se = cfg.MODEL.BACKBONE.WITH_SE

        all_blocks = build_backbone(cfg)

        # Shared stem: everything up to and including the first block of layer3.
        backbone = nn.Sequential(
            all_blocks.conv1,
            all_blocks.bn1,
            all_blocks.relu,
            all_blocks.maxpool,
            all_blocks.layer1,
            all_blocks.layer2,
            all_blocks.layer3[0]
        )
        # Remaining layer3 blocks; deep-copied per branch below so that the
        # branches do not share weights.
        res_conv4 = nn.Sequential(*all_blocks.layer3[1:])
        res_g_conv5 = all_blocks.layer4

        # Replacement for layer4 used by the part branches. It is built without an
        # explicit stride (presumably stride 1, so the map stays larger for striping;
        # verify against Bottleneck's signature). It is initialised from layer4's weights.
        res_p_conv5 = nn.Sequential(
            Bottleneck(1024, 512, bn_norm, False, with_se, downsample=nn.Sequential(
                nn.Conv2d(1024, 2048, 1, bias=False), get_norm(bn_norm, 2048))),
            Bottleneck(2048, 512, bn_norm, False, with_se),
            Bottleneck(2048, 512, bn_norm, False, with_se))
        res_p_conv5.load_state_dict(all_blocks.layer4.state_dict())

        # branch1
        neck1 = nn.Sequential(
            copy.deepcopy(res_conv4),
            copy.deepcopy(res_g_conv5)
        )
        b1_head = build_heads(cfg)

        # branch2
        neck2 = nn.Sequential(
            copy.deepcopy(res_conv4),
            copy.deepcopy(res_p_conv5)
        )
        b2_head = build_heads(cfg)
        b21_head = build_heads(cfg)
        b22_head = build_heads(cfg)

        # branch3
        neck3 = nn.Sequential(
            copy.deepcopy(res_conv4),
            copy.deepcopy(res_p_conv5)
        )
        b3_head = build_heads(cfg)
        b31_head = build_heads(cfg)
        b32_head = build_heads(cfg)
        b33_head = build_heads(cfg)

        return {
            'backbone': backbone,
            'neck1': neck1,
            'neck2': neck2,
            'neck3': neck3,
            'b1_head': b1_head,
            'b2_head': b2_head,
            'b21_head': b21_head,
            'b22_head': b22_head,
            'b3_head': b3_head,
            'b31_head': b31_head,
            'b32_head': b32_head,
            'b33_head': b33_head,
            'pixel_mean': cfg.MODEL.PIXEL_MEAN,
            'pixel_std': cfg.MODEL.PIXEL_STD,
            'loss_kwargs':
                {
                    # loss name
                    'loss_names': cfg.MODEL.LOSSES.NAME,

                    # loss hyperparameters
                    'ce': {
                        'eps': cfg.MODEL.LOSSES.CE.EPSILON,
                        'alpha': cfg.MODEL.LOSSES.CE.ALPHA,
                        'scale': cfg.MODEL.LOSSES.CE.SCALE
                    },
                    'tri': {
                        'margin': cfg.MODEL.LOSSES.TRI.MARGIN,
                        'norm_feat': cfg.MODEL.LOSSES.TRI.NORM_FEAT,
                        'hard_mining': cfg.MODEL.LOSSES.TRI.HARD_MINING,
                        'scale': cfg.MODEL.LOSSES.TRI.SCALE
                    },
                    'circle': {
                        'margin': cfg.MODEL.LOSSES.CIRCLE.MARGIN,
                        'gamma': cfg.MODEL.LOSSES.CIRCLE.GAMMA,
                        'scale': cfg.MODEL.LOSSES.CIRCLE.SCALE
                    },
                    'cosface': {
                        'margin': cfg.MODEL.LOSSES.COSFACE.MARGIN,
                        'gamma': cfg.MODEL.LOSSES.COSFACE.GAMMA,
                        'scale': cfg.MODEL.LOSSES.COSFACE.SCALE
                    }
                }
        }

    @property
    def device(self):
        """Device the model lives on (taken from the ``pixel_mean`` buffer)."""
        return self.pixel_mean.device

    def forward(self, batched_inputs):
        """
        Run all three branches.

        In training mode, returns the loss dict from ``losses``. This requires
        ``batched_inputs["targets"]``. Otherwise, returns the eight head outputs
        concatenated along dim 1, in this order: b1, b2, b3, b21, b22, b31, b32, b33.
        """
        images = self.preprocess_image(batched_inputs)
        # Shared stem output (ends at layer3[0]). It is expected to have 1024 channels,
        # since the part branches' first Bottleneck takes 1024 input channels.
        features = self.backbone(images)

        # branch1
        b1_feat = self.b1(features)

        # branch2: whole map plus 2 stripes along dim 2
        b2_feat = self.b2(features)
        b21_feat, b22_feat = torch.chunk(b2_feat, 2, dim=2)

        # branch3: whole map plus 3 stripes along dim 2
        b3_feat = self.b3(features)
        b31_feat, b32_feat, b33_feat = torch.chunk(b3_feat, 3, dim=2)

        if self.training:
            assert "targets" in batched_inputs, "Person ID annotation are missing in training!"
            targets = batched_inputs["targets"]

            # A negative label sum is treated as "no usable labels": the targets are
            # zeroed in place (this also modifies batched_inputs["targets"]).
            if targets.sum() < 0: targets.zero_()

            b1_outputs = self.b1_head(b1_feat, targets)
            b2_outputs = self.b2_head(b2_feat, targets)
            b21_outputs = self.b21_head(b21_feat, targets)
            b22_outputs = self.b22_head(b22_feat, targets)
            b3_outputs = self.b3_head(b3_feat, targets)
            b31_outputs = self.b31_head(b31_feat, targets)
            b32_outputs = self.b32_head(b32_feat, targets)
            b33_outputs = self.b33_head(b33_feat, targets)

            losses = self.losses(b1_outputs,
                                 b2_outputs, b21_outputs, b22_outputs,
                                 b3_outputs, b31_outputs, b32_outputs, b33_outputs,
                                 targets)
            return losses
        else:
            b1_pool_feat = self.b1_head(b1_feat)
            b2_pool_feat = self.b2_head(b2_feat)
            b21_pool_feat = self.b21_head(b21_feat)
            b22_pool_feat = self.b22_head(b22_feat)
            b3_pool_feat = self.b3_head(b3_feat)
            b31_pool_feat = self.b31_head(b31_feat)
            b32_pool_feat = self.b32_head(b32_feat)
            b33_pool_feat = self.b33_head(b33_feat)

            pred_feat = torch.cat([b1_pool_feat, b2_pool_feat, b3_pool_feat, b21_pool_feat,
                                   b22_pool_feat, b31_pool_feat, b32_pool_feat, b33_pool_feat], dim=1)
            return pred_feat

    def preprocess_image(self, batched_inputs):
        r"""
        Normalize and batch the input images.

        Accepts either a dict holding the image tensor under "images" or a bare
        tensor. Returns a new normalized tensor on ``self.device``; the caller's
        tensor is not modified.

        Raises:
            TypeError: if ``batched_inputs`` is neither a dict nor a torch.Tensor.
        """
        if isinstance(batched_inputs, dict):
            images = batched_inputs["images"].to(self.device)
        elif isinstance(batched_inputs, torch.Tensor):
            images = batched_inputs.to(self.device)
        else:
            raise TypeError("batched_inputs must be dict or torch.Tensor, but get {}".format(type(batched_inputs)))

        # Deliberately out-of-place. ``.to()`` returns the very same tensor when it is
        # already on the target device, so in-place sub_/div_ would silently normalize
        # the caller's tensor (and normalize it again on a repeated call).
        return (images - self.pixel_mean) / self.pixel_std

    def losses(self,
               b1_outputs,
               b2_outputs, b21_outputs, b22_outputs,
               b3_outputs, b31_outputs, b32_outputs, b33_outputs, gt_labels):
        """
        Compute the MGN training losses.

        Args:
            b*_outputs: head output dicts, one per branch/stripe. Each must provide
                'cls_outputs' (for cross-entropy) and 'features' (for triplet).
                ``b1_outputs`` must also provide 'pred_class_logits' (accuracy logging).
            gt_labels: identity labels for the batch.

        Returns:
            dict of loss tensors, depending on ``loss_kwargs['loss_names']``:
            - "CrossEntropyLoss" adds 'loss_cls_<b1|b2|b21|b22|b3|b31|b32|b33>'. There
              are eight terms, each weighted by ``ce.scale * 0.125`` (0.125 == 1/8).
            - "TripletLoss" adds 'loss_triplet_<b1|b2|b3|b22|b33>'. There are five
              terms, each weighted by ``tri.scale * 0.2`` (0.2 == 1/5). 'b22' and
              'b33' use the concatenated stripe features of branches 2 and 3.
        """
        # Training accuracy is logged from branch 1's logits only.
        log_accuracy(b1_outputs['pred_class_logits'].detach(), gt_labels)

        loss_dict = {}
        loss_names = self.loss_kwargs['loss_names']

        if "CrossEntropyLoss" in loss_names:
            ce_kwargs = self.loss_kwargs.get('ce')
            # Tuple order fixes the key insertion order of the returned dict.
            ce_terms = (
                ('b1', b1_outputs),
                ('b2', b2_outputs),
                ('b21', b21_outputs),
                ('b22', b22_outputs),
                ('b3', b3_outputs),
                ('b31', b31_outputs),
                ('b32', b32_outputs),
                ('b33', b33_outputs),
            )
            for suffix, outputs in ce_terms:
                loss_dict['loss_cls_' + suffix] = cross_entropy_loss(
                    outputs['cls_outputs'],
                    gt_labels,
                    ce_kwargs.get('eps'),
                    ce_kwargs.get('alpha')
                ) * ce_kwargs.get('scale') * 0.125

        if "TripletLoss" in loss_names:
            tri_kwargs = self.loss_kwargs.get('tri')
            # The stripe features of each part branch are concatenated along dim 1
            # and reported under the last stripe's name ('b22', 'b33').
            triplet_terms = (
                ('b1', b1_outputs['features']),
                ('b2', b2_outputs['features']),
                ('b3', b3_outputs['features']),
                ('b22', torch.cat((b21_outputs['features'],
                                   b22_outputs['features']), dim=1)),
                ('b33', torch.cat((b31_outputs['features'],
                                   b32_outputs['features'],
                                   b33_outputs['features']), dim=1)),
            )
            for suffix, feats in triplet_terms:
                loss_dict['loss_triplet_' + suffix] = triplet_loss(
                    feats,
                    gt_labels,
                    tri_kwargs.get('margin'),
                    tri_kwargs.get('norm_feat'),
                    tri_kwargs.get('hard_mining')
                ) * tri_kwargs.get('scale') * 0.2

        return loss_dict
@@ -0,0 +1,126 @@
1
+ # encoding: utf-8
2
+ """
3
+ @author: xingyu liao
4
+ @contact: sherlockliao01@gmail.com
5
+ """
6
+
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from torch import nn
10
+
11
+ from fastreid.modeling.losses.utils import concat_all_gather
12
+ from fastreid.utils import comm
13
+ from .baseline import Baseline
14
+ from .build import META_ARCH_REGISTRY
15
+
16
+
17
@META_ARCH_REGISTRY.register()
class MoCo(Baseline):
    """Baseline re-id meta-architecture extended with a MoCo-style memory loss.

    On top of the regular ``Baseline`` losses, every training step also adds a
    memory-bank loss (``loss_mb``) computed by a ``Memory`` queue of past
    features.
    """

    def __init__(self, cfg):
        super().__init__(cfg)

        # Fall back to the backbone feature size when the head does not
        # define its own embedding dimension (EMBEDDING_DIM is falsy).
        feat_dim = cfg.MODEL.HEADS.EMBEDDING_DIM or cfg.MODEL.BACKBONE.FEAT_DIM
        queue_size = cfg.MODEL.QUEUE_SIZE
        self.memory = Memory(feat_dim, queue_size)

    def losses(self, outputs, gt_labels):
        """
        Return the Baseline loss dict plus the memory-bank loss.

        The arguments must match what the model's forward pass produces;
        ``outputs['features']`` is fed to the memory bank.
        """
        loss_dict = super().losses(outputs, gt_labels)
        loss_dict['loss_mb'] = self.memory(outputs['features'], gt_labels)
        return loss_dict
40
+
41
+
42
class Memory(nn.Module):
    """
    Build a MoCo memory with a queue
    https://arxiv.org/abs/1911.05722

    Features are stored column-wise in the ``queue`` buffer (shape ``(dim, K)``)
    and their labels in ``queue_label`` (shape ``(1, K)``). Each forward pass
    enqueues the current batch and then computes a pairwise CosFace-style loss
    between the batch and every stored entry.
    """

    def __init__(self, dim=512, K=65536):
        """
        dim: feature dimension (default: 512)
        K: queue size; number of negative keys (default: 65536)
        """
        super().__init__()
        self.K = K

        # Margin and scale of the CosFace-style loss in ``_pairwise_cosface``.
        self.margin = 0.25
        self.gamma = 32

        # create the queue, initialised with random unit-norm columns
        self.register_buffer("queue", torch.randn(dim, K))
        self.queue = F.normalize(self.queue, dim=0)

        self.register_buffer("queue_label", torch.zeros((1, K), dtype=torch.long))
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys, targets):
        """
        Overwrite the next ``batch_size`` queue slots with ``keys``/``targets``.

        With more than one process, keys and targets from all ranks are
        gathered first (``concat_all_gather``), so the written batch is the
        concatenation of every rank's batch.

        Returns:
            int: column index in ``queue`` where THIS process's first key was
            written. ``_pairwise_cosface`` uses it to mask each query's
            similarity with its own freshly enqueued copy.

        Raises:
            ValueError: if the queue size is not a multiple of the batch size.
        """
        world_size = comm.get_world_size()
        local_batch_size = keys.shape[0]

        # gather keys/targets before updating queue
        if world_size > 1:
            keys = concat_all_gather(keys)
            targets = concat_all_gather(targets)
        else:
            keys = keys.detach()
            targets = targets.detach()

        batch_size = keys.shape[0]

        ptr = int(self.queue_ptr)
        # Explicit check instead of ``assert`` so it survives ``python -O``;
        # divisibility guarantees a batch never wraps around the queue end.
        if self.K % batch_size != 0:
            raise ValueError(
                f"queue size K={self.K} must be divisible by batch size {batch_size}"
            )

        # replace the keys at ptr (dequeue and enqueue)
        self.queue[:, ptr:ptr + batch_size] = keys.T
        self.queue_label[:, ptr:ptr + batch_size] = targets

        # The gathered batch is concatenated in rank order (assumes equal
        # per-rank batch sizes, which the gather requires), so the local
        # samples start ``rank * local_batch_size`` columns past ``ptr``.
        local_offset = ptr
        if world_size > 1:
            local_offset += comm.get_rank() * local_batch_size

        ptr = (ptr + batch_size) % self.K  # move pointer
        self.queue_ptr[0] = ptr
        return local_offset

    def forward(self, feat_q, targets):
        """
        Memory bank enqueue and compute metric loss
        Args:
            feat_q: model features, one row per sample
            targets: gt labels

        Returns:
            Scalar loss tensor (mean over the batch).
        """
        # normalize embedding features
        feat_q = F.normalize(feat_q, p=2, dim=1)
        # dequeue and enqueue; remember where our own samples landed
        self_offset = self._dequeue_and_enqueue(feat_q.detach(), targets)
        # compute loss
        loss = self._pairwise_cosface(feat_q, targets, self_offset=self_offset)
        return loss

    def _pairwise_cosface(self, feat_q, targets, self_offset=0):
        """
        Pairwise CosFace-style loss between the batch and the whole queue.

        Args:
            feat_q: normalized query features, shape ``(N, dim)``.
            targets: labels of the queries, ``N`` elements.
            self_offset: column of ``queue`` holding the copy of query 0;
                query ``i``'s own copy is at ``self_offset + i``. That entry
                is excluded from the positives so a sample is not its own
                (trivially easy) positive.
        """
        dist_mat = torch.matmul(feat_q, self.queue)

        N, M = dist_mat.size()  # (bsz, memory)
        query_labels = targets.view(N, 1).expand(N, M)
        memory_labels = self.queue_label.expand(N, M)
        is_pos = query_labels.eq(memory_labels).float()
        is_neg = query_labels.ne(memory_labels).float()

        # Mask scores related to themselves, at the columns where the batch
        # was actually enqueued (not blindly the first N columns).
        rows = torch.arange(N, device=is_pos.device)
        self_mask = torch.zeros(N, M, dtype=torch.bool, device=is_pos.device)
        self_mask[rows, rows + self_offset] = True
        is_pos = is_pos.masked_fill(self_mask, 0.)

        s_p = dist_mat * is_pos
        s_n = dist_mat * is_neg

        # Large negative constant effectively removes masked entries from
        # the logsumexp.
        logit_p = -self.gamma * s_p + (-99999999.) * (1 - is_pos)
        logit_n = self.gamma * (s_n + self.margin) + (-99999999.) * (1 - is_neg)

        loss = F.softplus(torch.logsumexp(logit_p, dim=1) + torch.logsumexp(logit_n, dim=1)).mean()

        return loss
@@ -0,0 +1,8 @@
1
+ # encoding: utf-8
2
+ """
3
+ @author: liaoxingyu
4
+ @contact: sherlockliao01@gmail.com
5
+ """
6
+
7
+
8
+ from .build import build_lr_scheduler, build_optimizer