dnt 0.2.4__py3-none-any.whl → 0.3.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of dnt might be problematic. Click here for more details.

Files changed (305)
  1. dnt/__init__.py +3 -2
  2. dnt/analysis/__init__.py +3 -2
  3. dnt/analysis/interaction.py +503 -0
  4. dnt/analysis/stop.py +22 -17
  5. dnt/analysis/stop2.py +289 -0
  6. dnt/analysis/stop3.py +754 -0
  7. dnt/detect/signal/detector.py +317 -0
  8. dnt/detect/yolov8/detector.py +116 -16
  9. dnt/engine/__init__.py +8 -0
  10. dnt/engine/bbox_interp.py +83 -0
  11. dnt/engine/bbox_iou.py +20 -0
  12. dnt/engine/cluster.py +31 -0
  13. dnt/engine/iob.py +66 -0
  14. dnt/filter/filter.py +321 -1
  15. dnt/label/labeler.py +4 -4
  16. dnt/label/labeler2.py +502 -0
  17. dnt/shared/__init__.py +2 -1
  18. dnt/shared/data/coco.names +0 -0
  19. dnt/shared/data/openimages.names +0 -0
  20. dnt/shared/data/voc.names +0 -0
  21. dnt/shared/download.py +12 -0
  22. dnt/shared/synhcro.py +150 -0
  23. dnt/shared/util.py +17 -4
  24. dnt/third_party/fast-reid/__init__.py +1 -0
  25. dnt/third_party/fast-reid/configs/Base-AGW.yml +19 -0
  26. dnt/third_party/fast-reid/configs/Base-MGN.yml +12 -0
  27. dnt/third_party/fast-reid/configs/Base-SBS.yml +63 -0
  28. dnt/third_party/fast-reid/configs/Base-bagtricks.yml +76 -0
  29. dnt/third_party/fast-reid/configs/DukeMTMC/AGW_R101-ibn.yml +12 -0
  30. dnt/third_party/fast-reid/configs/DukeMTMC/AGW_R50-ibn.yml +11 -0
  31. dnt/third_party/fast-reid/configs/DukeMTMC/AGW_R50.yml +7 -0
  32. dnt/third_party/fast-reid/configs/DukeMTMC/AGW_S50.yml +11 -0
  33. dnt/third_party/fast-reid/configs/DukeMTMC/bagtricks_R101-ibn.yml +12 -0
  34. dnt/third_party/fast-reid/configs/DukeMTMC/bagtricks_R50-ibn.yml +11 -0
  35. dnt/third_party/fast-reid/configs/DukeMTMC/bagtricks_R50.yml +7 -0
  36. dnt/third_party/fast-reid/configs/DukeMTMC/bagtricks_S50.yml +11 -0
  37. dnt/third_party/fast-reid/configs/DukeMTMC/mgn_R50-ibn.yml +11 -0
  38. dnt/third_party/fast-reid/configs/DukeMTMC/sbs_R101-ibn.yml +12 -0
  39. dnt/third_party/fast-reid/configs/DukeMTMC/sbs_R50-ibn.yml +11 -0
  40. dnt/third_party/fast-reid/configs/DukeMTMC/sbs_R50.yml +7 -0
  41. dnt/third_party/fast-reid/configs/DukeMTMC/sbs_S50.yml +11 -0
  42. dnt/third_party/fast-reid/configs/MOT17/AGW_R101-ibn.yml +12 -0
  43. dnt/third_party/fast-reid/configs/MOT17/AGW_R50-ibn.yml +11 -0
  44. dnt/third_party/fast-reid/configs/MOT17/AGW_R50.yml +7 -0
  45. dnt/third_party/fast-reid/configs/MOT17/AGW_S50.yml +11 -0
  46. dnt/third_party/fast-reid/configs/MOT17/bagtricks_R101-ibn.yml +12 -0
  47. dnt/third_party/fast-reid/configs/MOT17/bagtricks_R50-ibn.yml +11 -0
  48. dnt/third_party/fast-reid/configs/MOT17/bagtricks_R50.yml +7 -0
  49. dnt/third_party/fast-reid/configs/MOT17/bagtricks_S50.yml +11 -0
  50. dnt/third_party/fast-reid/configs/MOT17/mgn_R50-ibn.yml +11 -0
  51. dnt/third_party/fast-reid/configs/MOT17/sbs_R101-ibn.yml +12 -0
  52. dnt/third_party/fast-reid/configs/MOT17/sbs_R50-ibn.yml +11 -0
  53. dnt/third_party/fast-reid/configs/MOT17/sbs_R50.yml +7 -0
  54. dnt/third_party/fast-reid/configs/MOT17/sbs_S50.yml +11 -0
  55. dnt/third_party/fast-reid/configs/MOT20/AGW_R101-ibn.yml +12 -0
  56. dnt/third_party/fast-reid/configs/MOT20/AGW_R50-ibn.yml +11 -0
  57. dnt/third_party/fast-reid/configs/MOT20/AGW_R50.yml +7 -0
  58. dnt/third_party/fast-reid/configs/MOT20/AGW_S50.yml +11 -0
  59. dnt/third_party/fast-reid/configs/MOT20/bagtricks_R101-ibn.yml +12 -0
  60. dnt/third_party/fast-reid/configs/MOT20/bagtricks_R50-ibn.yml +11 -0
  61. dnt/third_party/fast-reid/configs/MOT20/bagtricks_R50.yml +7 -0
  62. dnt/third_party/fast-reid/configs/MOT20/bagtricks_S50.yml +11 -0
  63. dnt/third_party/fast-reid/configs/MOT20/mgn_R50-ibn.yml +11 -0
  64. dnt/third_party/fast-reid/configs/MOT20/sbs_R101-ibn.yml +12 -0
  65. dnt/third_party/fast-reid/configs/MOT20/sbs_R50-ibn.yml +11 -0
  66. dnt/third_party/fast-reid/configs/MOT20/sbs_R50.yml +7 -0
  67. dnt/third_party/fast-reid/configs/MOT20/sbs_S50.yml +11 -0
  68. dnt/third_party/fast-reid/configs/MSMT17/AGW_R101-ibn.yml +12 -0
  69. dnt/third_party/fast-reid/configs/MSMT17/AGW_R50-ibn.yml +11 -0
  70. dnt/third_party/fast-reid/configs/MSMT17/AGW_R50.yml +7 -0
  71. dnt/third_party/fast-reid/configs/MSMT17/AGW_S50.yml +11 -0
  72. dnt/third_party/fast-reid/configs/MSMT17/bagtricks_R101-ibn.yml +13 -0
  73. dnt/third_party/fast-reid/configs/MSMT17/bagtricks_R50-ibn.yml +12 -0
  74. dnt/third_party/fast-reid/configs/MSMT17/bagtricks_R50.yml +7 -0
  75. dnt/third_party/fast-reid/configs/MSMT17/bagtricks_S50.yml +12 -0
  76. dnt/third_party/fast-reid/configs/MSMT17/mgn_R50-ibn.yml +11 -0
  77. dnt/third_party/fast-reid/configs/MSMT17/sbs_R101-ibn.yml +12 -0
  78. dnt/third_party/fast-reid/configs/MSMT17/sbs_R50-ibn.yml +11 -0
  79. dnt/third_party/fast-reid/configs/MSMT17/sbs_R50.yml +7 -0
  80. dnt/third_party/fast-reid/configs/MSMT17/sbs_S50.yml +11 -0
  81. dnt/third_party/fast-reid/configs/Market1501/AGW_R101-ibn.yml +12 -0
  82. dnt/third_party/fast-reid/configs/Market1501/AGW_R50-ibn.yml +11 -0
  83. dnt/third_party/fast-reid/configs/Market1501/AGW_R50.yml +7 -0
  84. dnt/third_party/fast-reid/configs/Market1501/AGW_S50.yml +11 -0
  85. dnt/third_party/fast-reid/configs/Market1501/bagtricks_R101-ibn.yml +12 -0
  86. dnt/third_party/fast-reid/configs/Market1501/bagtricks_R50-ibn.yml +11 -0
  87. dnt/third_party/fast-reid/configs/Market1501/bagtricks_R50.yml +7 -0
  88. dnt/third_party/fast-reid/configs/Market1501/bagtricks_S50.yml +11 -0
  89. dnt/third_party/fast-reid/configs/Market1501/bagtricks_vit.yml +88 -0
  90. dnt/third_party/fast-reid/configs/Market1501/mgn_R50-ibn.yml +11 -0
  91. dnt/third_party/fast-reid/configs/Market1501/sbs_R101-ibn.yml +12 -0
  92. dnt/third_party/fast-reid/configs/Market1501/sbs_R50-ibn.yml +11 -0
  93. dnt/third_party/fast-reid/configs/Market1501/sbs_R50.yml +7 -0
  94. dnt/third_party/fast-reid/configs/Market1501/sbs_S50.yml +11 -0
  95. dnt/third_party/fast-reid/configs/VERIWild/bagtricks_R50-ibn.yml +35 -0
  96. dnt/third_party/fast-reid/configs/VeRi/sbs_R50-ibn.yml +35 -0
  97. dnt/third_party/fast-reid/configs/VehicleID/bagtricks_R50-ibn.yml +36 -0
  98. dnt/third_party/fast-reid/configs/__init__.py +0 -0
  99. dnt/third_party/fast-reid/fast_reid_interfece.py +175 -0
  100. dnt/third_party/fast-reid/fastreid/__init__.py +6 -0
  101. dnt/third_party/fast-reid/fastreid/config/__init__.py +15 -0
  102. dnt/third_party/fast-reid/fastreid/config/config.py +319 -0
  103. dnt/third_party/fast-reid/fastreid/config/defaults.py +329 -0
  104. dnt/third_party/fast-reid/fastreid/data/__init__.py +17 -0
  105. dnt/third_party/fast-reid/fastreid/data/build.py +194 -0
  106. dnt/third_party/fast-reid/fastreid/data/common.py +58 -0
  107. dnt/third_party/fast-reid/fastreid/data/data_utils.py +202 -0
  108. dnt/third_party/fast-reid/fastreid/data/datasets/AirportALERT.py +50 -0
  109. dnt/third_party/fast-reid/fastreid/data/datasets/__init__.py +43 -0
  110. dnt/third_party/fast-reid/fastreid/data/datasets/bases.py +183 -0
  111. dnt/third_party/fast-reid/fastreid/data/datasets/caviara.py +44 -0
  112. dnt/third_party/fast-reid/fastreid/data/datasets/cuhk03.py +274 -0
  113. dnt/third_party/fast-reid/fastreid/data/datasets/cuhk_sysu.py +58 -0
  114. dnt/third_party/fast-reid/fastreid/data/datasets/dukemtmcreid.py +70 -0
  115. dnt/third_party/fast-reid/fastreid/data/datasets/grid.py +44 -0
  116. dnt/third_party/fast-reid/fastreid/data/datasets/iLIDS.py +45 -0
  117. dnt/third_party/fast-reid/fastreid/data/datasets/lpw.py +49 -0
  118. dnt/third_party/fast-reid/fastreid/data/datasets/market1501.py +89 -0
  119. dnt/third_party/fast-reid/fastreid/data/datasets/msmt17.py +114 -0
  120. dnt/third_party/fast-reid/fastreid/data/datasets/pes3d.py +44 -0
  121. dnt/third_party/fast-reid/fastreid/data/datasets/pku.py +44 -0
  122. dnt/third_party/fast-reid/fastreid/data/datasets/prai.py +43 -0
  123. dnt/third_party/fast-reid/fastreid/data/datasets/prid.py +41 -0
  124. dnt/third_party/fast-reid/fastreid/data/datasets/saivt.py +47 -0
  125. dnt/third_party/fast-reid/fastreid/data/datasets/sensereid.py +47 -0
  126. dnt/third_party/fast-reid/fastreid/data/datasets/shinpuhkan.py +48 -0
  127. dnt/third_party/fast-reid/fastreid/data/datasets/sysu_mm.py +47 -0
  128. dnt/third_party/fast-reid/fastreid/data/datasets/thermalworld.py +43 -0
  129. dnt/third_party/fast-reid/fastreid/data/datasets/vehicleid.py +126 -0
  130. dnt/third_party/fast-reid/fastreid/data/datasets/veri.py +69 -0
  131. dnt/third_party/fast-reid/fastreid/data/datasets/veriwild.py +140 -0
  132. dnt/third_party/fast-reid/fastreid/data/datasets/viper.py +45 -0
  133. dnt/third_party/fast-reid/fastreid/data/datasets/wildtracker.py +59 -0
  134. dnt/third_party/fast-reid/fastreid/data/samplers/__init__.py +18 -0
  135. dnt/third_party/fast-reid/fastreid/data/samplers/data_sampler.py +85 -0
  136. dnt/third_party/fast-reid/fastreid/data/samplers/imbalance_sampler.py +67 -0
  137. dnt/third_party/fast-reid/fastreid/data/samplers/triplet_sampler.py +260 -0
  138. dnt/third_party/fast-reid/fastreid/data/transforms/__init__.py +11 -0
  139. dnt/third_party/fast-reid/fastreid/data/transforms/autoaugment.py +806 -0
  140. dnt/third_party/fast-reid/fastreid/data/transforms/build.py +100 -0
  141. dnt/third_party/fast-reid/fastreid/data/transforms/functional.py +180 -0
  142. dnt/third_party/fast-reid/fastreid/data/transforms/transforms.py +161 -0
  143. dnt/third_party/fast-reid/fastreid/engine/__init__.py +15 -0
  144. dnt/third_party/fast-reid/fastreid/engine/defaults.py +490 -0
  145. dnt/third_party/fast-reid/fastreid/engine/hooks.py +534 -0
  146. dnt/third_party/fast-reid/fastreid/engine/launch.py +103 -0
  147. dnt/third_party/fast-reid/fastreid/engine/train_loop.py +357 -0
  148. dnt/third_party/fast-reid/fastreid/evaluation/__init__.py +6 -0
  149. dnt/third_party/fast-reid/fastreid/evaluation/clas_evaluator.py +81 -0
  150. dnt/third_party/fast-reid/fastreid/evaluation/evaluator.py +176 -0
  151. dnt/third_party/fast-reid/fastreid/evaluation/query_expansion.py +46 -0
  152. dnt/third_party/fast-reid/fastreid/evaluation/rank.py +200 -0
  153. dnt/third_party/fast-reid/fastreid/evaluation/rank_cylib/__init__.py +20 -0
  154. dnt/third_party/fast-reid/fastreid/evaluation/rank_cylib/setup.py +32 -0
  155. dnt/third_party/fast-reid/fastreid/evaluation/rank_cylib/test_cython.py +106 -0
  156. dnt/third_party/fast-reid/fastreid/evaluation/reid_evaluation.py +143 -0
  157. dnt/third_party/fast-reid/fastreid/evaluation/rerank.py +73 -0
  158. dnt/third_party/fast-reid/fastreid/evaluation/roc.py +90 -0
  159. dnt/third_party/fast-reid/fastreid/evaluation/testing.py +88 -0
  160. dnt/third_party/fast-reid/fastreid/layers/__init__.py +19 -0
  161. dnt/third_party/fast-reid/fastreid/layers/activation.py +59 -0
  162. dnt/third_party/fast-reid/fastreid/layers/any_softmax.py +80 -0
  163. dnt/third_party/fast-reid/fastreid/layers/batch_norm.py +205 -0
  164. dnt/third_party/fast-reid/fastreid/layers/context_block.py +113 -0
  165. dnt/third_party/fast-reid/fastreid/layers/drop.py +161 -0
  166. dnt/third_party/fast-reid/fastreid/layers/frn.py +199 -0
  167. dnt/third_party/fast-reid/fastreid/layers/gather_layer.py +30 -0
  168. dnt/third_party/fast-reid/fastreid/layers/helpers.py +31 -0
  169. dnt/third_party/fast-reid/fastreid/layers/non_local.py +54 -0
  170. dnt/third_party/fast-reid/fastreid/layers/pooling.py +124 -0
  171. dnt/third_party/fast-reid/fastreid/layers/se_layer.py +25 -0
  172. dnt/third_party/fast-reid/fastreid/layers/splat.py +109 -0
  173. dnt/third_party/fast-reid/fastreid/layers/weight_init.py +122 -0
  174. dnt/third_party/fast-reid/fastreid/modeling/__init__.py +23 -0
  175. dnt/third_party/fast-reid/fastreid/modeling/backbones/__init__.py +18 -0
  176. dnt/third_party/fast-reid/fastreid/modeling/backbones/build.py +27 -0
  177. dnt/third_party/fast-reid/fastreid/modeling/backbones/mobilenet.py +195 -0
  178. dnt/third_party/fast-reid/fastreid/modeling/backbones/mobilenetv3.py +283 -0
  179. dnt/third_party/fast-reid/fastreid/modeling/backbones/osnet.py +525 -0
  180. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/__init__.py +4 -0
  181. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/config.py +396 -0
  182. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/effnet/EN-B0_dds_8gpu.yaml +27 -0
  183. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/effnet/EN-B1_dds_8gpu.yaml +27 -0
  184. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/effnet/EN-B2_dds_8gpu.yaml +27 -0
  185. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/effnet/EN-B3_dds_8gpu.yaml +27 -0
  186. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/effnet/EN-B4_dds_8gpu.yaml +27 -0
  187. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/effnet/EN-B5_dds_8gpu.yaml +27 -0
  188. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/effnet.py +281 -0
  189. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnet.py +596 -0
  190. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-1.6GF_dds_8gpu.yaml +26 -0
  191. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-12GF_dds_8gpu.yaml +26 -0
  192. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-16GF_dds_8gpu.yaml +26 -0
  193. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-200MF_dds_8gpu.yaml +26 -0
  194. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-3.2GF_dds_8gpu.yaml +26 -0
  195. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-32GF_dds_8gpu.yaml +26 -0
  196. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-4.0GF_dds_8gpu.yaml +26 -0
  197. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-400MF_dds_8gpu.yaml +26 -0
  198. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-6.4GF_dds_8gpu.yaml +26 -0
  199. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-600MF_dds_8gpu.yaml +26 -0
  200. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-8.0GF_dds_8gpu.yaml +26 -0
  201. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-800MF_dds_8gpu.yaml +26 -0
  202. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-1.6GF_dds_8gpu.yaml +27 -0
  203. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-12GF_dds_8gpu.yaml +27 -0
  204. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-16GF_dds_8gpu.yaml +27 -0
  205. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-200MF_dds_8gpu.yaml +26 -0
  206. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-3.2GF_dds_8gpu.yaml +27 -0
  207. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-32GF_dds_8gpu.yaml +27 -0
  208. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-4.0GF_dds_8gpu.yaml +27 -0
  209. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-400MF_dds_8gpu.yaml +27 -0
  210. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-6.4GF_dds_8gpu.yaml +27 -0
  211. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-600MF_dds_8gpu.yaml +27 -0
  212. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-8.0GF_dds_8gpu.yaml +27 -0
  213. dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-800MF_dds_8gpu.yaml +27 -0
  214. dnt/third_party/fast-reid/fastreid/modeling/backbones/repvgg.py +309 -0
  215. dnt/third_party/fast-reid/fastreid/modeling/backbones/resnest.py +365 -0
  216. dnt/third_party/fast-reid/fastreid/modeling/backbones/resnet.py +364 -0
  217. dnt/third_party/fast-reid/fastreid/modeling/backbones/resnext.py +335 -0
  218. dnt/third_party/fast-reid/fastreid/modeling/backbones/shufflenet.py +203 -0
  219. dnt/third_party/fast-reid/fastreid/modeling/backbones/vision_transformer.py +399 -0
  220. dnt/third_party/fast-reid/fastreid/modeling/heads/__init__.py +11 -0
  221. dnt/third_party/fast-reid/fastreid/modeling/heads/build.py +25 -0
  222. dnt/third_party/fast-reid/fastreid/modeling/heads/clas_head.py +36 -0
  223. dnt/third_party/fast-reid/fastreid/modeling/heads/embedding_head.py +151 -0
  224. dnt/third_party/fast-reid/fastreid/modeling/losses/__init__.py +12 -0
  225. dnt/third_party/fast-reid/fastreid/modeling/losses/circle_loss.py +71 -0
  226. dnt/third_party/fast-reid/fastreid/modeling/losses/cross_entroy_loss.py +54 -0
  227. dnt/third_party/fast-reid/fastreid/modeling/losses/focal_loss.py +92 -0
  228. dnt/third_party/fast-reid/fastreid/modeling/losses/triplet_loss.py +113 -0
  229. dnt/third_party/fast-reid/fastreid/modeling/losses/utils.py +48 -0
  230. dnt/third_party/fast-reid/fastreid/modeling/meta_arch/__init__.py +14 -0
  231. dnt/third_party/fast-reid/fastreid/modeling/meta_arch/baseline.py +188 -0
  232. dnt/third_party/fast-reid/fastreid/modeling/meta_arch/build.py +26 -0
  233. dnt/third_party/fast-reid/fastreid/modeling/meta_arch/distiller.py +140 -0
  234. dnt/third_party/fast-reid/fastreid/modeling/meta_arch/mgn.py +394 -0
  235. dnt/third_party/fast-reid/fastreid/modeling/meta_arch/moco.py +126 -0
  236. dnt/third_party/fast-reid/fastreid/solver/__init__.py +8 -0
  237. dnt/third_party/fast-reid/fastreid/solver/build.py +348 -0
  238. dnt/third_party/fast-reid/fastreid/solver/lr_scheduler.py +66 -0
  239. dnt/third_party/fast-reid/fastreid/solver/optim/__init__.py +10 -0
  240. dnt/third_party/fast-reid/fastreid/solver/optim/lamb.py +123 -0
  241. dnt/third_party/fast-reid/fastreid/solver/optim/radam.py +149 -0
  242. dnt/third_party/fast-reid/fastreid/solver/optim/swa.py +246 -0
  243. dnt/third_party/fast-reid/fastreid/utils/__init__.py +6 -0
  244. dnt/third_party/fast-reid/fastreid/utils/checkpoint.py +503 -0
  245. dnt/third_party/fast-reid/fastreid/utils/collect_env.py +158 -0
  246. dnt/third_party/fast-reid/fastreid/utils/comm.py +255 -0
  247. dnt/third_party/fast-reid/fastreid/utils/compute_dist.py +200 -0
  248. dnt/third_party/fast-reid/fastreid/utils/env.py +119 -0
  249. dnt/third_party/fast-reid/fastreid/utils/events.py +461 -0
  250. dnt/third_party/fast-reid/fastreid/utils/faiss_utils.py +127 -0
  251. dnt/third_party/fast-reid/fastreid/utils/file_io.py +520 -0
  252. dnt/third_party/fast-reid/fastreid/utils/history_buffer.py +71 -0
  253. dnt/third_party/fast-reid/fastreid/utils/logger.py +211 -0
  254. dnt/third_party/fast-reid/fastreid/utils/params.py +103 -0
  255. dnt/third_party/fast-reid/fastreid/utils/precision_bn.py +94 -0
  256. dnt/third_party/fast-reid/fastreid/utils/registry.py +66 -0
  257. dnt/third_party/fast-reid/fastreid/utils/summary.py +120 -0
  258. dnt/third_party/fast-reid/fastreid/utils/timer.py +68 -0
  259. dnt/third_party/fast-reid/fastreid/utils/visualizer.py +278 -0
  260. dnt/track/__init__.py +2 -0
  261. dnt/track/botsort/__init__.py +4 -0
  262. dnt/track/botsort/bot_tracker/__init__.py +3 -0
  263. dnt/track/botsort/bot_tracker/basetrack.py +60 -0
  264. dnt/track/botsort/bot_tracker/bot_sort.py +473 -0
  265. dnt/track/botsort/bot_tracker/gmc.py +316 -0
  266. dnt/track/botsort/bot_tracker/kalman_filter.py +269 -0
  267. dnt/track/botsort/bot_tracker/matching.py +194 -0
  268. dnt/track/botsort/bot_tracker/mc_bot_sort.py +505 -0
  269. dnt/track/{dsort/utils → botsort/bot_tracker/tracking_utils}/evaluation.py +14 -4
  270. dnt/track/{dsort/utils → botsort/bot_tracker/tracking_utils}/io.py +19 -36
  271. dnt/track/botsort/bot_tracker/tracking_utils/timer.py +37 -0
  272. dnt/track/botsort/inference.py +96 -0
  273. dnt/track/config.py +120 -0
  274. dnt/track/dsort/configs/bagtricks_R50.yml +7 -0
  275. dnt/track/dsort/configs/deep_sort.yaml +0 -0
  276. dnt/track/dsort/configs/fastreid.yaml +1 -1
  277. dnt/track/dsort/deep_sort/deep/checkpoint/ckpt.t7 +0 -0
  278. dnt/track/dsort/deep_sort/deep/feature_extractor.py +87 -8
  279. dnt/track/dsort/deep_sort/deep_sort.py +28 -18
  280. dnt/track/dsort/deep_sort/sort/iou_matching.py +0 -2
  281. dnt/track/dsort/deep_sort/sort/linear_assignment.py +0 -3
  282. dnt/track/dsort/deep_sort/sort/nn_matching.py +5 -5
  283. dnt/track/dsort/deep_sort/sort/preprocessing.py +1 -2
  284. dnt/track/dsort/dsort.py +21 -28
  285. dnt/track/re_class.py +94 -0
  286. dnt/track/sort/sort.py +5 -1
  287. dnt/track/tracker.py +207 -30
  288. {dnt-0.2.4.dist-info → dnt-0.3.1.3.dist-info}/METADATA +30 -10
  289. dnt-0.3.1.3.dist-info/RECORD +314 -0
  290. {dnt-0.2.4.dist-info → dnt-0.3.1.3.dist-info}/WHEEL +1 -1
  291. dnt/analysis/yield.py +0 -9
  292. dnt/track/dsort/deep_sort/deep/evaluate.py +0 -15
  293. dnt/track/dsort/deep_sort/deep/original_model.py +0 -106
  294. dnt/track/dsort/deep_sort/deep/test.py +0 -77
  295. dnt/track/dsort/deep_sort/deep/train.py +0 -189
  296. dnt/track/dsort/utils/asserts.py +0 -13
  297. dnt/track/dsort/utils/draw.py +0 -36
  298. dnt/track/dsort/utils/json_logger.py +0 -383
  299. dnt/track/dsort/utils/log.py +0 -17
  300. dnt/track/dsort/utils/parser.py +0 -35
  301. dnt/track/dsort/utils/tools.py +0 -39
  302. dnt-0.2.4.dist-info/RECORD +0 -64
  303. /dnt/{track/dsort/utils → third_party/fast-reid/checkpoint}/__init__.py +0 -0
  304. {dnt-0.2.4.dist-info → dnt-0.3.1.3.dist-info}/LICENSE +0 -0
  305. {dnt-0.2.4.dist-info → dnt-0.3.1.3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,88 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+ import logging
3
+ import pprint
4
+ import sys
5
+ from collections.abc import Mapping
6
+ from collections import OrderedDict
7
+
8
+ import numpy as np
9
+ from tabulate import tabulate
10
+ from termcolor import colored
11
+
12
+
13
def print_csv_format(results):
    """
    Log the main metrics of one dataset as a pipe-formatted table, similar to
    Detectron2, so that they are easy to copy-paste into a spreadsheet.

    Args:
        results (OrderedDict): {metric -> score}, plus a ``'dataset'`` entry
            whose value is used as the first ("Dataset") column.

    Note:
        The ``'dataset'`` entry is popped from ``results`` (the mapping is
        modified in place), leaving only the metric entries behind.
        An empty ``results`` logs nothing and returns immediately.
    """
    # unordered results cannot be properly printed
    assert isinstance(results, OrderedDict) or not len(results), results
    if not results:
        # The assertion above explicitly admits an empty mapping; without this
        # guard ``results.pop('dataset')`` below would raise KeyError.
        return
    logger = logging.getLogger(__name__)

    dataset_name = results.pop('dataset')
    metrics = ["Dataset", *results]
    csv_results = [(dataset_name, *results.values())]

    # tabulate it
    table = tabulate(
        csv_results,
        tablefmt="pipe",
        floatfmt=".2f",
        headers=metrics,
        numalign="left",
    )

    logger.info("Evaluation results in csv format: \n" + colored(table, "cyan"))
38
+
39
+
40
def verify_results(cfg, results):
    """
    Check evaluation results against ``cfg.TEST.EXPECTED_RESULTS``.

    Args:
        cfg: config object; ``cfg.TEST.EXPECTED_RESULTS`` is a sequence of
            ``(task, metric, expected, tolerance)`` tuples.
        results (OrderedDict[dict]): task_name -> {metric -> score}

    Returns:
        bool: True when nothing is expected, or when every checked score is
        finite and within ``tolerance`` of its expected value.

    Note:
        On a failed check the details are logged and the process is terminated
        with ``sys.exit(1)`` rather than returning False.
    """
    expectations = cfg.TEST.EXPECTED_RESULTS
    if len(expectations) == 0:
        return True

    passed = True
    for task, metric, expected, tolerance in expectations:
        actual = results[task][metric]
        if not np.isfinite(actual):
            passed = False
        if abs(actual - expected) > tolerance:
            passed = False

    logger = logging.getLogger(__name__)
    if passed:
        logger.info("Results verification passed.")
    else:
        logger.error("Result verification failed!")
        logger.error("Expected Results: " + str(expectations))
        logger.error("Actual Results: " + pprint.pformat(results))

        sys.exit(1)
    return passed
70
+
71
+
72
def flatten_results_dict(results):
    """
    Flatten a nested mapping of scalars into a single-level dict.

    Keys along a nested path are joined with ``"/"``: if
    ``results[k1][k2][k3] = v``, the returned dict contains ``{"k1/k2/k3": v}``.
    Non-mapping values are copied under their own key unchanged.

    Args:
        results (dict): possibly nested mapping. Keys on nested paths are
            concatenated with ``+``, so they must be strings.

    Returns:
        dict: the flattened mapping, in the iteration order of ``results``.
    """
    flat = {}
    for key, value in results.items():
        if not isinstance(value, Mapping):
            flat[key] = value
            continue
        for sub_key, sub_value in flatten_results_dict(value).items():
            flat[key + "/" + sub_key] = sub_value
    return flat
@@ -0,0 +1,19 @@
1
+ # encoding: utf-8
2
+ """
3
+ @author: liaoxingyu
4
+ @contact: sherlockliao01@gmail.com
5
+ """
6
+
7
+ from .activation import *
8
+ from .batch_norm import *
9
+ from .context_block import ContextBlock
10
+ from .drop import DropPath, DropBlock2d, drop_block_2d, drop_path
11
+ from .frn import FRN, TLU
12
+ from .gather_layer import GatherLayer
13
+ from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible
14
+ from .non_local import Non_local
15
+ from .se_layer import SELayer
16
+ from .splat import SplAtConv2d, DropBlock2D
17
+ from .weight_init import (
18
+ trunc_normal_, variance_scaling_, lecun_normal_, weights_init_kaiming, weights_init_classifier
19
+ )
@@ -0,0 +1,59 @@
1
+ # encoding: utf-8
2
+ """
3
+ @author: xingyu liao
4
+ @contact: sherlockliao01@gmail.com
5
+ """
6
+
7
+ import math
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
# Names exported by ``from .activation import *``.
__all__ = [
    "Mish",
    "Swish",
    "MemoryEfficientSwish",
    "GELU",
]
18
+
19
+
20
class Mish(nn.Module):
    """Mish activation: ``x * tanh(softplus(x))``, applied element-wise."""

    def forward(self, x):
        # Kept as one expression on purpose: inlining this saves 1 second per
        # epoch (V100 GPU) vs having a temp x and then returning x(!)
        return x * (torch.tanh(F.softplus(x)))
27
+
28
+
29
class Swish(nn.Module):
    """Swish activation: ``x * sigmoid(x)``, applied element-wise."""

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate
32
+
33
+
34
class SwishImplementation(torch.autograd.Function):
    """
    Custom autograd function for ``i * sigmoid(i)`` with a hand-written backward.

    Only the input ``i`` is saved for the backward pass; the sigmoid is
    recomputed there instead of being stored.
    """

    @staticmethod
    def forward(ctx, i):
        result = i * torch.sigmoid(i)
        ctx.save_for_backward(i)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        # ``saved_tensors`` is the supported accessor; ``saved_variables`` is a
        # deprecated alias that emits a warning.
        i, = ctx.saved_tensors
        sigmoid_i = torch.sigmoid(i)
        # d/di [i * s(i)] = s(i) + i * s(i) * (1 - s(i)) = s(i) * (1 + i * (1 - s(i)))
        return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
46
+
47
+
48
class MemoryEfficientSwish(nn.Module):
    """Swish activation evaluated through the custom ``SwishImplementation`` autograd function."""

    def forward(self, x):
        swish = SwishImplementation.apply
        return swish(x)
51
+
52
+
53
class GELU(nn.Module):
    """
    GELU activation, tanh approximation:
    ``0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))``.

    Original note: BERT used GELU instead of ReLU (paper section 3.4, last
    paragraph).
    """

    def forward(self, x):
        cubic = 0.044715 * torch.pow(x, 3)
        inner = math.sqrt(2 / math.pi) * (x + cubic)
        return 0.5 * x * (1 + torch.tanh(inner))
@@ -0,0 +1,80 @@
1
+ # encoding: utf-8
2
+ """
3
+ @author: liaoxingyu
4
+ @contact: sherlockliao01@gmail.com
5
+ """
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+
10
# Public margin-softmax heads of this module.
__all__ = [
    'Linear',
    'ArcSoftmax',
    'CosSoftmax',
    'CircleSoftmax',
]
16
+
17
+
18
class Linear(nn.Module):
    """
    Base head that only scales the logits; ``targets`` is ignored.

    The margin subclasses in this module read ``self.s`` (scale) and
    ``self.m`` (margin).

    Args:
        num_classes (int): number of classes (reported by ``extra_repr``).
        scale (float): multiplier applied to the logits.
        margin (float): margin stored for subclasses; unused here.
    """

    def __init__(self, num_classes, scale, margin):
        super().__init__()
        self.num_classes, self.s, self.m = num_classes, scale, margin

    def forward(self, logits, targets):
        # In place: the input tensor itself is scaled and returned.
        scaled = logits.mul_(self.s)
        return scaled

    def extra_repr(self):
        return "num_classes={}, scale={}, margin={}".format(self.num_classes, self.s, self.m)
30
+
31
+
32
class CosSoftmax(Linear):
    r"""Large margin cosine head (additive margin on the target cosine logit).

    For every row whose target is not ``-1``, ``self.m`` is subtracted from the
    target-class logit. All logits are then multiplied by ``self.s``; rows with
    target ``-1`` receive only the scaling. ``logits`` is modified in place and
    returned.
    """

    def forward(self, logits, targets):
        valid_rows = torch.where(targets != -1)[0]
        margin = torch.zeros(
            valid_rows.size()[0], logits.size()[1], device=logits.device, dtype=logits.dtype
        )
        margin.scatter_(1, targets[valid_rows, None], self.m)
        logits[valid_rows] -= margin
        return logits.mul_(self.s)
43
+
44
+
45
class ArcSoftmax(Linear):
    """Additive angular margin head.

    Logits are treated as cosines. They are converted to angles with ``acos``;
    ``self.m`` is added to the target-class angle of every row whose target is
    not ``-1``; the result is mapped back with ``cos`` and multiplied by
    ``self.s``. ``logits`` is modified in place and returned.
    """

    # Distance kept from +/-1 before acos (see forward).
    _eps = 1e-7

    def forward(self, logits, targets):
        index = torch.where(targets != -1)[0]
        m_hot = torch.zeros(index.size()[0], logits.size()[1], device=logits.device, dtype=logits.dtype)
        m_hot.scatter_(1, targets[index, None], self.m)
        # Cosine logits can land marginally outside [-1, 1] through
        # floating-point error, and acos() returns NaN there, poisoning the loss.
        # Clamping slightly inside the interval also keeps acos' gradient
        # (-1 / sqrt(1 - x^2)) finite.
        logits.clamp_(-1.0 + self._eps, 1.0 - self._eps)
        logits.acos_()
        logits[index] += m_hot
        logits.cos_().mul_(self.s)
        return logits
55
+
56
+
57
class CircleSoftmax(Linear):
    """Circle-loss style re-weighting of cosine logits.

    Target-class logits become ``alpha_p * (logit - (1 - m))``. All other
    logits become ``alpha_n * (logit - m)``. The weights are
    ``alpha_p = max(1 + m - logit, 0)`` and ``alpha_n = max(logit + m, 0)``,
    computed on detached logits so no gradient flows through them. Rows whose
    target is ``-1`` are treated as having no target class. The result is
    scaled by ``self.s``; ``logits`` is modified in place and returned.
    """

    def forward(self, logits, targets):
        margin = self.m
        # Both weights are computed immediately, before logits is overwritten.
        frozen = logits.detach()
        alpha_p = torch.clamp_min(-frozen + 1 + margin, min=0.)
        alpha_n = torch.clamp_min(frozen + margin, min=0.)
        delta_p = 1 - margin
        delta_n = margin

        # When use model parallel, there are some targets not in class centers of
        # local rank (presumably those rows carry target -1 — confirm with caller).
        valid_rows = torch.where(targets != -1)[0]
        target_mask = torch.zeros(valid_rows.size()[0], logits.size()[1], device=logits.device, dtype=logits.dtype)
        target_mask.scatter_(1, targets[valid_rows, None], 1)

        # Both branches must be evaluated BEFORE logits is written into below.
        logits_p = alpha_p * (logits - delta_p)
        logits_n = alpha_n * (logits - delta_n)

        logits[valid_rows] = logits_p[valid_rows] * target_mask + logits_n[valid_rows] * (1 - target_mask)

        unmatched_rows = torch.where(targets == -1)[0]
        logits[unmatched_rows] = logits_n[unmatched_rows]

        return logits.mul_(self.s)
@@ -0,0 +1,205 @@
1
+ # encoding: utf-8
2
+ """
3
+ @author: liaoxingyu
4
+ @contact: sherlockliao01@gmail.com
5
+ """
6
+
7
+ import logging
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from torch import nn
12
+
13
# Names exported by a star-import of this module. The individual norm classes are not
# listed here, but get_norm() can select them by name.
__all__ = ["IBN", "get_norm"]
14
+
15
+
16
class BatchNorm(nn.BatchNorm2d):
    """``nn.BatchNorm2d`` with configurable initial affine values and optional freezing.

    Args:
        num_features: number of channels.
        eps, momentum: forwarded to ``nn.BatchNorm2d``.
        weight_freeze, bias_freeze: when True, the corresponding parameter gets
            ``requires_grad=False``.
        weight_init, bias_init: constant the parameter is filled with; ``None`` keeps
            ``nn.BatchNorm2d``'s own initialization.
        **kwargs: accepted and ignored.
    """

    def __init__(self, num_features, eps=1e-05, momentum=0.1, weight_freeze=False, bias_freeze=False, weight_init=1.0,
                 bias_init=0.0, **kwargs):
        super().__init__(num_features, eps=eps, momentum=momentum)
        settings = ((self.weight, weight_init, weight_freeze), (self.bias, bias_init, bias_freeze))
        for param, init_value, frozen in settings:
            if init_value is not None:
                nn.init.constant_(param, init_value)
            param.requires_grad_(not frozen)
24
+
25
+
26
class SyncBatchNorm(nn.SyncBatchNorm):
    """``nn.SyncBatchNorm`` with configurable initial affine values and optional freezing.

    Args:
        num_features: number of channels.
        eps, momentum: forwarded to ``nn.SyncBatchNorm``.
        weight_freeze, bias_freeze: when True, the corresponding parameter gets
            ``requires_grad=False``.
        weight_init, bias_init: constant the parameter is filled with; ``None`` keeps
            ``nn.SyncBatchNorm``'s own initialization.
        **kwargs: accepted and ignored, mirroring ``BatchNorm``. This lets
            ``get_norm(..., **kwargs)`` forward the same keyword arguments to either layer
            without a TypeError.
    """

    def __init__(self, num_features, eps=1e-05, momentum=0.1, weight_freeze=False, bias_freeze=False, weight_init=1.0,
                 bias_init=0.0, **kwargs):
        super().__init__(num_features, eps=eps, momentum=momentum)
        if weight_init is not None:
            nn.init.constant_(self.weight, weight_init)
        if bias_init is not None:
            nn.init.constant_(self.bias, bias_init)
        self.weight.requires_grad_(not weight_freeze)
        self.bias.requires_grad_(not bias_freeze)
34
+
35
+
36
class IBN(nn.Module):
    """Instance-Batch Normalization over the channel axis.

    The first ``planes // 2`` channels go through an affine ``nn.InstanceNorm2d``. The
    remaining ``planes - planes // 2`` channels go through the layer built by
    ``get_norm(bn_norm, ..., **kwargs)``. The two results are concatenated back along dim 1.

    Args:
        planes: total number of channels of the input.
        bn_norm: norm spec passed to ``get_norm`` for the second part.
        **kwargs: forwarded to ``get_norm``.
    """

    def __init__(self, planes, bn_norm, **kwargs):
        super().__init__()
        half1 = planes // 2
        self.half = half1
        half2 = planes - half1
        self.IN = nn.InstanceNorm2d(half1, affine=True)
        self.BN = get_norm(bn_norm, half2, **kwargs)

    def forward(self, x):
        # Explicit section sizes: torch.split with an int size would cut an odd channel count
        # into three chunks, handing the BN part the wrong number of channels.
        in_part, bn_part = torch.split(x, [self.half, x.size(1) - self.half], dim=1)
        out1 = self.IN(in_part.contiguous())
        out2 = self.BN(bn_part.contiguous())
        return torch.cat((out1, out2), 1)
51
+
52
+
53
class GhostBatchNorm(BatchNorm):
    """Batch norm computed over ``num_splits`` sub-batches ("ghost" batches) in training.

    In training mode (or when running stats are not tracked), the input of shape
    (N, C, H, W) is reshaped to (N / num_splits, C * num_splits, H, W). As a result,
    sample ``i`` is normalized together with the other samples whose index has the same
    value modulo ``num_splits``. The running statistics are updated per split and then
    averaged back into single ``num_features``-sized buffers. In eval mode, the layer is
    plain batch norm with the running statistics.

    Args:
        num_features: number of channels C.
        num_splits: number of sub-batches; during training N must be divisible by it.
        **kwargs: forwarded to ``BatchNorm``.
    """

    def __init__(self, num_features, num_splits=1, **kwargs):
        super().__init__(num_features, **kwargs)
        self.num_splits = num_splits
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))

    def forward(self, input):
        N, C, H, W = input.shape
        if self.training or not self.track_running_stats:
            # reshape (not view) so non-contiguous inputs are accepted as well.
            grouped = input.reshape(-1, C * self.num_splits, H, W)
            # Work on local per-split copies of the statistics and write the buffers back
            # only after batch_norm succeeded. Otherwise a failing call (e.g. N not divisible
            # by num_splits) would leave the buffers stuck in their repeated, enlarged form.
            running_mean = self.running_mean.repeat(self.num_splits)
            running_var = self.running_var.repeat(self.num_splits)
            outputs = F.batch_norm(
                grouped, running_mean, running_var,
                self.weight.repeat(self.num_splits), self.bias.repeat(self.num_splits),
                True, self.momentum, self.eps).view(N, C, H, W)
            # F.batch_norm updated the local copies in place; fold the splits back together.
            self.running_mean = torch.mean(running_mean.view(self.num_splits, self.num_features), dim=0)
            self.running_var = torch.mean(running_var.view(self.num_splits, self.num_features), dim=0)
            return outputs
        return F.batch_norm(
            input, self.running_mean, self.running_var,
            self.weight, self.bias, False, self.momentum, self.eps)
76
+
77
+
78
class FrozenBatchNorm(nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.

    It contains non-trainable buffers called "weight", "bias", "running_mean" and
    "running_var", initialized to perform an identity transformation.
    The pre-trained backbone models from Caffe2 only contain "weight" and "bias",
    which are computed from the original four parameters of BN.
    The affine transform ``x * weight + bias`` performs the equivalent computation of
    ``(x - running_mean) / sqrt(running_var) * weight + bias``.
    When loading a backbone model from Caffe2, "running_mean" and "running_var"
    are left unchanged as an identity transformation.
    Other pre-trained backbone models may contain all 4 parameters.
    The forward is implemented by ``F.batch_norm(..., training=False)`` when no
    gradient is required.
    """

    _version = 3

    def __init__(self, num_features, eps=1e-5, **kwargs):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.register_buffer("weight", torch.ones(num_features))
        self.register_buffer("bias", torch.zeros(num_features))
        self.register_buffer("running_mean", torch.zeros(num_features))
        # 1 - eps, so that running_var + eps used in forward gives unit variance at init.
        self.register_buffer("running_var", torch.ones(num_features) - eps)

    def forward(self, x):
        if x.requires_grad:
            # When gradients are needed, F.batch_norm will use extra memory
            # because its backward op computes gradients for weight/bias as well.
            scale = self.weight * (self.running_var + self.eps).rsqrt()
            bias = self.bias - self.running_mean * scale
            scale = scale.reshape(1, -1, 1, 1)
            bias = bias.reshape(1, -1, 1, 1)
            return x * scale + bias
        else:
            # When gradients are not needed, F.batch_norm is a single fused op
            # and provides more optimization opportunities.
            return F.batch_norm(
                x,
                self.running_mean,
                self.running_var,
                self.weight,
                self.bias,
                training=False,
                eps=self.eps,
            )

    def _load_from_state_dict(
            self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        """Upgrade checkpoints written by older versions of this module before loading."""
        version = local_metadata.get("version", None)

        if version is None or version < 2:
            # No running_mean/var in early versions.
            # Filling them in silences the missing-key warnings.
            if prefix + "running_mean" not in state_dict:
                state_dict[prefix + "running_mean"] = torch.zeros_like(self.running_mean)
            if prefix + "running_var" not in state_dict:
                state_dict[prefix + "running_var"] = torch.ones_like(self.running_var)

        if version is not None and version < 3:
            logger = logging.getLogger(__name__)
            logger.info("FrozenBatchNorm %s is upgraded to version 3.", prefix.rstrip("."))
            # In version < 3, running_var was used without +eps.
            var_key = prefix + "running_var"
            if var_key in state_dict:
                # Rebind rather than ``-=``: the tensor is shared with the caller's state_dict.
                # Mutating it in place would corrupt that dict, and loading it twice would
                # subtract eps twice. A missing key is left for the strict-load check to report.
                state_dict[var_key] = state_dict[var_key] - self.eps

        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )

    def __repr__(self):
        return "FrozenBatchNorm2d(num_features={}, eps={})".format(self.num_features, self.eps)

    @classmethod
    def convert_frozen_batchnorm(cls, module):
        """
        Convert BatchNorm/SyncBatchNorm in module into FrozenBatchNorm.

        Args:
            module (torch.nn.Module):

        Returns:
            If module is BatchNorm/SyncBatchNorm, returns a new module.
            Otherwise, in-place convert module and return it.

        Similar to convert_sync_batchnorm in
        https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py
        """
        bn_module = nn.modules.batchnorm
        bn_module = (bn_module.BatchNorm2d, bn_module.SyncBatchNorm)
        res = module
        if isinstance(module, bn_module):
            res = cls(module.num_features)
            if module.affine:
                res.weight.data = module.weight.data.clone().detach()
                res.bias.data = module.bias.data.clone().detach()
            # Running statistics share storage with the source module (not cloned).
            res.running_mean.data = module.running_mean.data
            res.running_var.data = module.running_var.data
            res.eps = module.eps
        else:
            for name, child in module.named_children():
                new_child = cls.convert_frozen_batchnorm(child)
                if new_child is not child:
                    res.add_module(name, new_child)
        return res
182
+
183
+
184
def get_norm(norm, out_channels, **kwargs):
    """
    Build a normalization layer.

    Args:
        norm (str or callable): one of "BN", "syncBN" (alias "SyncBN"), "GhostBN",
            "FrozenBN" or "GN"; an empty string for no normalization; or a callable that
            takes the channel count (plus ``kwargs``) and returns the normalization layer
            as an nn.Module.
        out_channels: number of channels for the normalization layer.
        **kwargs: forwarded to the layer constructor. "GN" ignores them and always
            uses 32 groups.

    Returns:
        nn.Module or None: the normalization layer, or None when ``norm`` is "".

    Raises:
        KeyError: if ``norm`` is a string that names no known layer.
    """
    if isinstance(norm, str):
        if not norm:
            return None
        factories = {
            "BN": BatchNorm,
            "syncBN": SyncBatchNorm,
            "SyncBN": SyncBatchNorm,  # spelling used in the docstring; "syncBN" kept for existing configs
            "GhostBN": GhostBatchNorm,
            "FrozenBN": FrozenBatchNorm,
            "GN": lambda channels, **args: nn.GroupNorm(32, channels),
        }
        try:
            norm = factories[norm]
        except KeyError:
            raise KeyError(
                f"Unknown norm {norm!r}; expected one of {sorted(factories)} or a callable"
            ) from None
    return norm(out_channels, **kwargs)
@@ -0,0 +1,113 @@
1
+ # copy from https://github.com/xvjiarui/GCNet/blob/master/mmdet/ops/gcb/context_block.py
2
+
3
+ import torch
4
+ from torch import nn
5
+
6
# Only ContextBlock is exported by a star-import; last_zero_init is not listed.
__all__ = ['ContextBlock']
7
+
8
+
9
def last_zero_init(m):
    """Zero-initialize ``m``, or only its last layer when ``m`` is an ``nn.Sequential``.

    The target's ``weight`` is set to 0. Its ``bias`` is zeroed too when it exists and
    is not None.
    """
    target = m[-1] if isinstance(m, nn.Sequential) else m
    nn.init.constant_(target.weight, val=0)
    if getattr(target, 'bias', None) is not None:
        nn.init.constant_(target.bias, 0)
18
+
19
+
20
class ContextBlock(nn.Module):
    """Global context block.

    Pools one context vector per sample (shape [N, C, 1, 1]) from the input. It then
    fuses the context back into every spatial position through a 1x1-conv bottleneck
    (conv -> LayerNorm -> ReLU -> conv). The output has the same shape as the input.

    Args:
        inplanes: number of input channels C.
        ratio: bottleneck width factor; the hidden width is ``int(inplanes * ratio)``.
        pooling_type: 'att' (softmax attention over spatial positions, from a 1x1 conv
            mask) or 'avg' (global average pooling).
        fusion_types: non-empty list/tuple with entries from 'channel_add' (add the
            transformed context) and 'channel_mul' (multiply by a sigmoid of the
            transformed context).

    Raises:
        ValueError: if ``pooling_type`` or an entry of ``fusion_types`` is invalid, or
            ``fusion_types`` is empty.
        TypeError: if ``fusion_types`` is not a list or tuple.
    """

    def __init__(self,
                 inplanes,
                 ratio,
                 pooling_type='att',
                 fusion_types=('channel_add',)):
        super().__init__()
        # Explicit checks instead of ``assert``: asserts are stripped under ``python -O``.
        # An unknown pooling_type would then silently fall through to average pooling.
        if pooling_type not in ('avg', 'att'):
            raise ValueError(f"pooling_type must be 'avg' or 'att', got {pooling_type!r}")
        if not isinstance(fusion_types, (list, tuple)):
            raise TypeError(f"fusion_types must be a list or tuple, got {type(fusion_types).__name__}")
        valid_fusion_types = ('channel_add', 'channel_mul')
        invalid = [f for f in fusion_types if f not in valid_fusion_types]
        if invalid:
            raise ValueError(f"invalid fusion types {invalid!r}; expected entries from {valid_fusion_types}")
        if len(fusion_types) == 0:
            raise ValueError('at least one fusion should be used')
        self.inplanes = inplanes
        self.ratio = ratio
        self.planes = int(inplanes * ratio)
        self.pooling_type = pooling_type
        self.fusion_types = fusion_types
        if pooling_type == 'att':
            self.conv_mask = nn.Conv2d(inplanes, 1, kernel_size=1)
            self.softmax = nn.Softmax(dim=2)
        else:
            self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Built in this order (add, then mul) to keep submodule registration order unchanged.
        self.channel_add_conv = self._make_channel_transform() if 'channel_add' in fusion_types else None
        self.channel_mul_conv = self._make_channel_transform() if 'channel_mul' in fusion_types else None
        self.reset_parameters()

    def _make_channel_transform(self):
        """Build the conv -> LayerNorm -> ReLU -> conv bottleneck shared by both fusion branches."""
        return nn.Sequential(
            nn.Conv2d(self.inplanes, self.planes, kernel_size=1),
            nn.LayerNorm([self.planes, 1, 1]),
            nn.ReLU(inplace=True),
            nn.Conv2d(self.planes, self.inplanes, kernel_size=1),
        )

    def reset_parameters(self):
        """Kaiming-init the attention mask conv and zero the last layer of each fusion branch."""
        if self.pooling_type == 'att':
            nn.init.kaiming_normal_(self.conv_mask.weight, a=0, mode='fan_in', nonlinearity='relu')
            if hasattr(self.conv_mask, 'bias') and self.conv_mask.bias is not None:
                nn.init.constant_(self.conv_mask.bias, 0)
            # NOTE(review): marker attribute not read in this file; presumably checked by
            # external init code. Kept as-is.
            self.conv_mask.inited = True

        if self.channel_add_conv is not None:
            last_zero_init(self.channel_add_conv)
        if self.channel_mul_conv is not None:
            last_zero_init(self.channel_mul_conv)

    def spatial_pool(self, x):
        """Reduce x of shape [N, C, H, W] to a context tensor of shape [N, C, 1, 1]."""
        batch, channel, height, width = x.size()
        if self.pooling_type == 'att':
            input_x = x
            # [N, C, H * W]
            input_x = input_x.view(batch, channel, height * width)
            # [N, 1, C, H * W]
            input_x = input_x.unsqueeze(1)
            # [N, 1, H, W]
            context_mask = self.conv_mask(x)
            # [N, 1, H * W]
            context_mask = context_mask.view(batch, 1, height * width)
            # [N, 1, H * W]: softmax over the spatial positions
            context_mask = self.softmax(context_mask)
            # [N, 1, H * W, 1]
            context_mask = context_mask.unsqueeze(-1)
            # [N, 1, C, 1]: attention-weighted sum of every channel over positions
            context = torch.matmul(input_x, context_mask)
            # [N, C, 1, 1]
            context = context.view(batch, channel, 1, 1)
        else:
            # [N, C, 1, 1]
            context = self.avg_pool(x)

        return context

    def forward(self, x):
        # [N, C, 1, 1]
        context = self.spatial_pool(x)

        out = x
        if self.channel_mul_conv is not None:
            # [N, C, 1, 1], broadcast over H and W
            channel_mul_term = torch.sigmoid(self.channel_mul_conv(context))
            out = out * channel_mul_term
        if self.channel_add_conv is not None:
            # [N, C, 1, 1], broadcast over H and W
            channel_add_term = self.channel_add_conv(context)
            out = out + channel_add_term

        return out