dnt-0.2.4-py3-none-any.whl → dnt-0.3.1.3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of dnt might be problematic.
- dnt/__init__.py +3 -2
- dnt/analysis/__init__.py +3 -2
- dnt/analysis/interaction.py +503 -0
- dnt/analysis/stop.py +22 -17
- dnt/analysis/stop2.py +289 -0
- dnt/analysis/stop3.py +754 -0
- dnt/detect/signal/detector.py +317 -0
- dnt/detect/yolov8/detector.py +116 -16
- dnt/engine/__init__.py +8 -0
- dnt/engine/bbox_interp.py +83 -0
- dnt/engine/bbox_iou.py +20 -0
- dnt/engine/cluster.py +31 -0
- dnt/engine/iob.py +66 -0
- dnt/filter/filter.py +321 -1
- dnt/label/labeler.py +4 -4
- dnt/label/labeler2.py +502 -0
- dnt/shared/__init__.py +2 -1
- dnt/shared/data/coco.names +0 -0
- dnt/shared/data/openimages.names +0 -0
- dnt/shared/data/voc.names +0 -0
- dnt/shared/download.py +12 -0
- dnt/shared/synhcro.py +150 -0
- dnt/shared/util.py +17 -4
- dnt/third_party/fast-reid/__init__.py +1 -0
- dnt/third_party/fast-reid/configs/Base-AGW.yml +19 -0
- dnt/third_party/fast-reid/configs/Base-MGN.yml +12 -0
- dnt/third_party/fast-reid/configs/Base-SBS.yml +63 -0
- dnt/third_party/fast-reid/configs/Base-bagtricks.yml +76 -0
- dnt/third_party/fast-reid/configs/DukeMTMC/AGW_R101-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/DukeMTMC/AGW_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/DukeMTMC/AGW_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/DukeMTMC/AGW_S50.yml +11 -0
- dnt/third_party/fast-reid/configs/DukeMTMC/bagtricks_R101-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/DukeMTMC/bagtricks_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/DukeMTMC/bagtricks_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/DukeMTMC/bagtricks_S50.yml +11 -0
- dnt/third_party/fast-reid/configs/DukeMTMC/mgn_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/DukeMTMC/sbs_R101-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/DukeMTMC/sbs_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/DukeMTMC/sbs_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/DukeMTMC/sbs_S50.yml +11 -0
- dnt/third_party/fast-reid/configs/MOT17/AGW_R101-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/MOT17/AGW_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/MOT17/AGW_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/MOT17/AGW_S50.yml +11 -0
- dnt/third_party/fast-reid/configs/MOT17/bagtricks_R101-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/MOT17/bagtricks_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/MOT17/bagtricks_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/MOT17/bagtricks_S50.yml +11 -0
- dnt/third_party/fast-reid/configs/MOT17/mgn_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/MOT17/sbs_R101-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/MOT17/sbs_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/MOT17/sbs_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/MOT17/sbs_S50.yml +11 -0
- dnt/third_party/fast-reid/configs/MOT20/AGW_R101-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/MOT20/AGW_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/MOT20/AGW_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/MOT20/AGW_S50.yml +11 -0
- dnt/third_party/fast-reid/configs/MOT20/bagtricks_R101-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/MOT20/bagtricks_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/MOT20/bagtricks_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/MOT20/bagtricks_S50.yml +11 -0
- dnt/third_party/fast-reid/configs/MOT20/mgn_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/MOT20/sbs_R101-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/MOT20/sbs_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/MOT20/sbs_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/MOT20/sbs_S50.yml +11 -0
- dnt/third_party/fast-reid/configs/MSMT17/AGW_R101-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/MSMT17/AGW_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/MSMT17/AGW_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/MSMT17/AGW_S50.yml +11 -0
- dnt/third_party/fast-reid/configs/MSMT17/bagtricks_R101-ibn.yml +13 -0
- dnt/third_party/fast-reid/configs/MSMT17/bagtricks_R50-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/MSMT17/bagtricks_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/MSMT17/bagtricks_S50.yml +12 -0
- dnt/third_party/fast-reid/configs/MSMT17/mgn_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/MSMT17/sbs_R101-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/MSMT17/sbs_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/MSMT17/sbs_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/MSMT17/sbs_S50.yml +11 -0
- dnt/third_party/fast-reid/configs/Market1501/AGW_R101-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/Market1501/AGW_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/Market1501/AGW_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/Market1501/AGW_S50.yml +11 -0
- dnt/third_party/fast-reid/configs/Market1501/bagtricks_R101-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/Market1501/bagtricks_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/Market1501/bagtricks_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/Market1501/bagtricks_S50.yml +11 -0
- dnt/third_party/fast-reid/configs/Market1501/bagtricks_vit.yml +88 -0
- dnt/third_party/fast-reid/configs/Market1501/mgn_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/Market1501/sbs_R101-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/Market1501/sbs_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/Market1501/sbs_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/Market1501/sbs_S50.yml +11 -0
- dnt/third_party/fast-reid/configs/VERIWild/bagtricks_R50-ibn.yml +35 -0
- dnt/third_party/fast-reid/configs/VeRi/sbs_R50-ibn.yml +35 -0
- dnt/third_party/fast-reid/configs/VehicleID/bagtricks_R50-ibn.yml +36 -0
- dnt/third_party/fast-reid/configs/__init__.py +0 -0
- dnt/third_party/fast-reid/fast_reid_interfece.py +175 -0
- dnt/third_party/fast-reid/fastreid/__init__.py +6 -0
- dnt/third_party/fast-reid/fastreid/config/__init__.py +15 -0
- dnt/third_party/fast-reid/fastreid/config/config.py +319 -0
- dnt/third_party/fast-reid/fastreid/config/defaults.py +329 -0
- dnt/third_party/fast-reid/fastreid/data/__init__.py +17 -0
- dnt/third_party/fast-reid/fastreid/data/build.py +194 -0
- dnt/third_party/fast-reid/fastreid/data/common.py +58 -0
- dnt/third_party/fast-reid/fastreid/data/data_utils.py +202 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/AirportALERT.py +50 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/__init__.py +43 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/bases.py +183 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/caviara.py +44 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/cuhk03.py +274 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/cuhk_sysu.py +58 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/dukemtmcreid.py +70 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/grid.py +44 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/iLIDS.py +45 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/lpw.py +49 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/market1501.py +89 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/msmt17.py +114 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/pes3d.py +44 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/pku.py +44 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/prai.py +43 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/prid.py +41 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/saivt.py +47 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/sensereid.py +47 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/shinpuhkan.py +48 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/sysu_mm.py +47 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/thermalworld.py +43 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/vehicleid.py +126 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/veri.py +69 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/veriwild.py +140 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/viper.py +45 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/wildtracker.py +59 -0
- dnt/third_party/fast-reid/fastreid/data/samplers/__init__.py +18 -0
- dnt/third_party/fast-reid/fastreid/data/samplers/data_sampler.py +85 -0
- dnt/third_party/fast-reid/fastreid/data/samplers/imbalance_sampler.py +67 -0
- dnt/third_party/fast-reid/fastreid/data/samplers/triplet_sampler.py +260 -0
- dnt/third_party/fast-reid/fastreid/data/transforms/__init__.py +11 -0
- dnt/third_party/fast-reid/fastreid/data/transforms/autoaugment.py +806 -0
- dnt/third_party/fast-reid/fastreid/data/transforms/build.py +100 -0
- dnt/third_party/fast-reid/fastreid/data/transforms/functional.py +180 -0
- dnt/third_party/fast-reid/fastreid/data/transforms/transforms.py +161 -0
- dnt/third_party/fast-reid/fastreid/engine/__init__.py +15 -0
- dnt/third_party/fast-reid/fastreid/engine/defaults.py +490 -0
- dnt/third_party/fast-reid/fastreid/engine/hooks.py +534 -0
- dnt/third_party/fast-reid/fastreid/engine/launch.py +103 -0
- dnt/third_party/fast-reid/fastreid/engine/train_loop.py +357 -0
- dnt/third_party/fast-reid/fastreid/evaluation/__init__.py +6 -0
- dnt/third_party/fast-reid/fastreid/evaluation/clas_evaluator.py +81 -0
- dnt/third_party/fast-reid/fastreid/evaluation/evaluator.py +176 -0
- dnt/third_party/fast-reid/fastreid/evaluation/query_expansion.py +46 -0
- dnt/third_party/fast-reid/fastreid/evaluation/rank.py +200 -0
- dnt/third_party/fast-reid/fastreid/evaluation/rank_cylib/__init__.py +20 -0
- dnt/third_party/fast-reid/fastreid/evaluation/rank_cylib/setup.py +32 -0
- dnt/third_party/fast-reid/fastreid/evaluation/rank_cylib/test_cython.py +106 -0
- dnt/third_party/fast-reid/fastreid/evaluation/reid_evaluation.py +143 -0
- dnt/third_party/fast-reid/fastreid/evaluation/rerank.py +73 -0
- dnt/third_party/fast-reid/fastreid/evaluation/roc.py +90 -0
- dnt/third_party/fast-reid/fastreid/evaluation/testing.py +88 -0
- dnt/third_party/fast-reid/fastreid/layers/__init__.py +19 -0
- dnt/third_party/fast-reid/fastreid/layers/activation.py +59 -0
- dnt/third_party/fast-reid/fastreid/layers/any_softmax.py +80 -0
- dnt/third_party/fast-reid/fastreid/layers/batch_norm.py +205 -0
- dnt/third_party/fast-reid/fastreid/layers/context_block.py +113 -0
- dnt/third_party/fast-reid/fastreid/layers/drop.py +161 -0
- dnt/third_party/fast-reid/fastreid/layers/frn.py +199 -0
- dnt/third_party/fast-reid/fastreid/layers/gather_layer.py +30 -0
- dnt/third_party/fast-reid/fastreid/layers/helpers.py +31 -0
- dnt/third_party/fast-reid/fastreid/layers/non_local.py +54 -0
- dnt/third_party/fast-reid/fastreid/layers/pooling.py +124 -0
- dnt/third_party/fast-reid/fastreid/layers/se_layer.py +25 -0
- dnt/third_party/fast-reid/fastreid/layers/splat.py +109 -0
- dnt/third_party/fast-reid/fastreid/layers/weight_init.py +122 -0
- dnt/third_party/fast-reid/fastreid/modeling/__init__.py +23 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/__init__.py +18 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/build.py +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/mobilenet.py +195 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/mobilenetv3.py +283 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/osnet.py +525 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/__init__.py +4 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/config.py +396 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/effnet/EN-B0_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/effnet/EN-B1_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/effnet/EN-B2_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/effnet/EN-B3_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/effnet/EN-B4_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/effnet/EN-B5_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/effnet.py +281 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnet.py +596 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-1.6GF_dds_8gpu.yaml +26 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-12GF_dds_8gpu.yaml +26 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-16GF_dds_8gpu.yaml +26 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-200MF_dds_8gpu.yaml +26 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-3.2GF_dds_8gpu.yaml +26 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-32GF_dds_8gpu.yaml +26 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-4.0GF_dds_8gpu.yaml +26 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-400MF_dds_8gpu.yaml +26 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-6.4GF_dds_8gpu.yaml +26 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-600MF_dds_8gpu.yaml +26 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-8.0GF_dds_8gpu.yaml +26 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-800MF_dds_8gpu.yaml +26 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-1.6GF_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-12GF_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-16GF_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-200MF_dds_8gpu.yaml +26 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-3.2GF_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-32GF_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-4.0GF_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-400MF_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-6.4GF_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-600MF_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-8.0GF_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-800MF_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/repvgg.py +309 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/resnest.py +365 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/resnet.py +364 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/resnext.py +335 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/shufflenet.py +203 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/vision_transformer.py +399 -0
- dnt/third_party/fast-reid/fastreid/modeling/heads/__init__.py +11 -0
- dnt/third_party/fast-reid/fastreid/modeling/heads/build.py +25 -0
- dnt/third_party/fast-reid/fastreid/modeling/heads/clas_head.py +36 -0
- dnt/third_party/fast-reid/fastreid/modeling/heads/embedding_head.py +151 -0
- dnt/third_party/fast-reid/fastreid/modeling/losses/__init__.py +12 -0
- dnt/third_party/fast-reid/fastreid/modeling/losses/circle_loss.py +71 -0
- dnt/third_party/fast-reid/fastreid/modeling/losses/cross_entroy_loss.py +54 -0
- dnt/third_party/fast-reid/fastreid/modeling/losses/focal_loss.py +92 -0
- dnt/third_party/fast-reid/fastreid/modeling/losses/triplet_loss.py +113 -0
- dnt/third_party/fast-reid/fastreid/modeling/losses/utils.py +48 -0
- dnt/third_party/fast-reid/fastreid/modeling/meta_arch/__init__.py +14 -0
- dnt/third_party/fast-reid/fastreid/modeling/meta_arch/baseline.py +188 -0
- dnt/third_party/fast-reid/fastreid/modeling/meta_arch/build.py +26 -0
- dnt/third_party/fast-reid/fastreid/modeling/meta_arch/distiller.py +140 -0
- dnt/third_party/fast-reid/fastreid/modeling/meta_arch/mgn.py +394 -0
- dnt/third_party/fast-reid/fastreid/modeling/meta_arch/moco.py +126 -0
- dnt/third_party/fast-reid/fastreid/solver/__init__.py +8 -0
- dnt/third_party/fast-reid/fastreid/solver/build.py +348 -0
- dnt/third_party/fast-reid/fastreid/solver/lr_scheduler.py +66 -0
- dnt/third_party/fast-reid/fastreid/solver/optim/__init__.py +10 -0
- dnt/third_party/fast-reid/fastreid/solver/optim/lamb.py +123 -0
- dnt/third_party/fast-reid/fastreid/solver/optim/radam.py +149 -0
- dnt/third_party/fast-reid/fastreid/solver/optim/swa.py +246 -0
- dnt/third_party/fast-reid/fastreid/utils/__init__.py +6 -0
- dnt/third_party/fast-reid/fastreid/utils/checkpoint.py +503 -0
- dnt/third_party/fast-reid/fastreid/utils/collect_env.py +158 -0
- dnt/third_party/fast-reid/fastreid/utils/comm.py +255 -0
- dnt/third_party/fast-reid/fastreid/utils/compute_dist.py +200 -0
- dnt/third_party/fast-reid/fastreid/utils/env.py +119 -0
- dnt/third_party/fast-reid/fastreid/utils/events.py +461 -0
- dnt/third_party/fast-reid/fastreid/utils/faiss_utils.py +127 -0
- dnt/third_party/fast-reid/fastreid/utils/file_io.py +520 -0
- dnt/third_party/fast-reid/fastreid/utils/history_buffer.py +71 -0
- dnt/third_party/fast-reid/fastreid/utils/logger.py +211 -0
- dnt/third_party/fast-reid/fastreid/utils/params.py +103 -0
- dnt/third_party/fast-reid/fastreid/utils/precision_bn.py +94 -0
- dnt/third_party/fast-reid/fastreid/utils/registry.py +66 -0
- dnt/third_party/fast-reid/fastreid/utils/summary.py +120 -0
- dnt/third_party/fast-reid/fastreid/utils/timer.py +68 -0
- dnt/third_party/fast-reid/fastreid/utils/visualizer.py +278 -0
- dnt/track/__init__.py +2 -0
- dnt/track/botsort/__init__.py +4 -0
- dnt/track/botsort/bot_tracker/__init__.py +3 -0
- dnt/track/botsort/bot_tracker/basetrack.py +60 -0
- dnt/track/botsort/bot_tracker/bot_sort.py +473 -0
- dnt/track/botsort/bot_tracker/gmc.py +316 -0
- dnt/track/botsort/bot_tracker/kalman_filter.py +269 -0
- dnt/track/botsort/bot_tracker/matching.py +194 -0
- dnt/track/botsort/bot_tracker/mc_bot_sort.py +505 -0
- dnt/track/{dsort/utils → botsort/bot_tracker/tracking_utils}/evaluation.py +14 -4
- dnt/track/{dsort/utils → botsort/bot_tracker/tracking_utils}/io.py +19 -36
- dnt/track/botsort/bot_tracker/tracking_utils/timer.py +37 -0
- dnt/track/botsort/inference.py +96 -0
- dnt/track/config.py +120 -0
- dnt/track/dsort/configs/bagtricks_R50.yml +7 -0
- dnt/track/dsort/configs/deep_sort.yaml +0 -0
- dnt/track/dsort/configs/fastreid.yaml +1 -1
- dnt/track/dsort/deep_sort/deep/checkpoint/ckpt.t7 +0 -0
- dnt/track/dsort/deep_sort/deep/feature_extractor.py +87 -8
- dnt/track/dsort/deep_sort/deep_sort.py +28 -18
- dnt/track/dsort/deep_sort/sort/iou_matching.py +0 -2
- dnt/track/dsort/deep_sort/sort/linear_assignment.py +0 -3
- dnt/track/dsort/deep_sort/sort/nn_matching.py +5 -5
- dnt/track/dsort/deep_sort/sort/preprocessing.py +1 -2
- dnt/track/dsort/dsort.py +21 -28
- dnt/track/re_class.py +94 -0
- dnt/track/sort/sort.py +5 -1
- dnt/track/tracker.py +207 -30
- {dnt-0.2.4.dist-info → dnt-0.3.1.3.dist-info}/METADATA +30 -10
- dnt-0.3.1.3.dist-info/RECORD +314 -0
- {dnt-0.2.4.dist-info → dnt-0.3.1.3.dist-info}/WHEEL +1 -1
- dnt/analysis/yield.py +0 -9
- dnt/track/dsort/deep_sort/deep/evaluate.py +0 -15
- dnt/track/dsort/deep_sort/deep/original_model.py +0 -106
- dnt/track/dsort/deep_sort/deep/test.py +0 -77
- dnt/track/dsort/deep_sort/deep/train.py +0 -189
- dnt/track/dsort/utils/asserts.py +0 -13
- dnt/track/dsort/utils/draw.py +0 -36
- dnt/track/dsort/utils/json_logger.py +0 -383
- dnt/track/dsort/utils/log.py +0 -17
- dnt/track/dsort/utils/parser.py +0 -35
- dnt/track/dsort/utils/tools.py +0 -39
- dnt-0.2.4.dist-info/RECORD +0 -64
- /dnt/{track/dsort/utils → third_party/fast-reid/checkpoint}/__init__.py +0 -0
- {dnt-0.2.4.dist-info → dnt-0.3.1.3.dist-info}/LICENSE +0 -0
- {dnt-0.2.4.dist-info → dnt-0.3.1.3.dist-info}/top_level.txt +0 -0

dnt/third_party/fast-reid/fastreid/evaluation/testing.py
@@ -0,0 +1,88 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+import logging
+import pprint
+import sys
+from collections.abc import Mapping
+from collections import OrderedDict
+
+import numpy as np
+from tabulate import tabulate
+from termcolor import colored
+
+
+def print_csv_format(results):
+    """
+    Print main metrics in a format similar to Detectron2,
+    so that they are easy to copypaste into a spreadsheet.
+    Args:
+        results (OrderedDict): {metric -> score}
+    """
+    # unordered results cannot be properly printed
+    assert isinstance(results, OrderedDict) or not len(results), results
+    logger = logging.getLogger(__name__)
+
+    dataset_name = results.pop('dataset')
+    metrics = ["Dataset"] + [k for k in results]
+    csv_results = [(dataset_name, *list(results.values()))]
+
+    # tabulate it
+    table = tabulate(
+        csv_results,
+        tablefmt="pipe",
+        floatfmt=".2f",
+        headers=metrics,
+        numalign="left",
+    )
+
+    logger.info("Evaluation results in csv format: \n" + colored(table, "cyan"))
+
+
+def verify_results(cfg, results):
+    """
+    Args:
+        results (OrderedDict[dict]): task_name -> {metric -> score}
+    Returns:
+        bool: whether the verification succeeds or not
+    """
+    expected_results = cfg.TEST.EXPECTED_RESULTS
+    if not len(expected_results):
+        return True
+
+    ok = True
+    for task, metric, expected, tolerance in expected_results:
+        actual = results[task][metric]
+        if not np.isfinite(actual):
+            ok = False
+        diff = abs(actual - expected)
+        if diff > tolerance:
+            ok = False
+
+    logger = logging.getLogger(__name__)
+    if not ok:
+        logger.error("Result verification failed!")
+        logger.error("Expected Results: " + str(expected_results))
+        logger.error("Actual Results: " + pprint.pformat(results))
+
+        sys.exit(1)
+    else:
+        logger.info("Results verification passed.")
+    return ok
+
+
+def flatten_results_dict(results):
+    """
+    Expand a hierarchical dict of scalars into a flat dict of scalars.
+    If results[k1][k2][k3] = v, the returned dict will have the entry
+    {"k1/k2/k3": v}.
+    Args:
+        results (dict):
+    """
+    r = {}
+    for k, v in results.items():
+        if isinstance(v, Mapping):
+            v = flatten_results_dict(v)
+            for kk, vv in v.items():
+                r[k + "/" + kk] = vv
+        else:
+            r[k] = v
+    return r
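
For orientation, a minimal usage sketch of the helpers this new testing.py module adds. The import path below is the upstream FastReID layout (the vendored location under dnt/third_party/fast-reid may not be importable as-is), and the dataset/metric values are illustrative assumptions, not values from the package:

```python
import logging
from collections import OrderedDict

# Assumption: the vendored FastReID code is importable as `fastreid`.
from fastreid.evaluation.testing import flatten_results_dict, print_csv_format

logging.basicConfig(level=logging.INFO)

# Nested task -> metric dict flattened to "task/metric" keys.
flat = flatten_results_dict({"market1501": {"Rank-1": 94.8, "mAP": 86.2}})
print(flat)  # {'market1501/Rank-1': 94.8, 'market1501/mAP': 86.2}

# print_csv_format expects an OrderedDict with a 'dataset' entry and logs a
# pipe-format table via tabulate/termcolor.
results = OrderedDict(dataset="Market1501")
results["Rank-1"] = 94.8
results["mAP"] = 86.2
print_csv_format(results)
```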

dnt/third_party/fast-reid/fastreid/layers/__init__.py
@@ -0,0 +1,19 @@
+# encoding: utf-8
+"""
+@author: liaoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+from .activation import *
+from .batch_norm import *
+from .context_block import ContextBlock
+from .drop import DropPath, DropBlock2d, drop_block_2d, drop_path
+from .frn import FRN, TLU
+from .gather_layer import GatherLayer
+from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible
+from .non_local import Non_local
+from .se_layer import SELayer
+from .splat import SplAtConv2d, DropBlock2D
+from .weight_init import (
+    trunc_normal_, variance_scaling_, lecun_normal_, weights_init_kaiming, weights_init_classifier
+)

dnt/third_party/fast-reid/fastreid/layers/activation.py
@@ -0,0 +1,59 @@
+# encoding: utf-8
+"""
+@author: xingyu liao
+@contact: sherlockliao01@gmail.com
+"""
+
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+__all__ = [
+    'Mish',
+    'Swish',
+    'MemoryEfficientSwish',
+    'GELU']
+
+
+class Mish(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        # inlining this saves 1 second per epoch (V100 GPU) vs having a temp x and then returning x(!)
+        return x * (torch.tanh(F.softplus(x)))
+
+
+class Swish(nn.Module):
+    def forward(self, x):
+        return x * torch.sigmoid(x)
+
+
+class SwishImplementation(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, i):
+        result = i * torch.sigmoid(i)
+        ctx.save_for_backward(i)
+        return result
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        i = ctx.saved_variables[0]
+        sigmoid_i = torch.sigmoid(i)
+        return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
+
+
+class MemoryEfficientSwish(nn.Module):
+    def forward(self, x):
+        return SwishImplementation.apply(x)
+
+
+class GELU(nn.Module):
+    """
+    Paper Section 3.4, last paragraph notice that BERT used the GELU instead of RELU
+    """
+
+    def forward(self, x):
+        return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
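
As a quick, illustrative check (not part of the diff), the tanh-based GELU added above can be compared against PyTorch's built-in exact GELU; the two agree to well under 1e-2 over a typical input range:

```python
import math

import torch
import torch.nn.functional as F

x = torch.linspace(-4.0, 4.0, steps=101)
# Same tanh approximation as the GELU module above.
approx = 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
print(torch.max(torch.abs(approx - F.gelu(x))).item())  # small approximation error (< 1e-2)
```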

dnt/third_party/fast-reid/fastreid/layers/any_softmax.py
@@ -0,0 +1,80 @@
+# encoding: utf-8
+"""
+@author: liaoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+import torch
+import torch.nn as nn
+
+__all__ = [
+    "Linear",
+    "ArcSoftmax",
+    "CosSoftmax",
+    "CircleSoftmax"
+]
+
+
+class Linear(nn.Module):
+    def __init__(self, num_classes, scale, margin):
+        super().__init__()
+        self.num_classes = num_classes
+        self.s = scale
+        self.m = margin
+
+    def forward(self, logits, targets):
+        return logits.mul_(self.s)
+
+    def extra_repr(self):
+        return f"num_classes={self.num_classes}, scale={self.s}, margin={self.m}"
+
+
+class CosSoftmax(Linear):
+    r"""Implement of large margin cosine distance:
+    """
+
+    def forward(self, logits, targets):
+        index = torch.where(targets != -1)[0]
+        m_hot = torch.zeros(index.size()[0], logits.size()[1], device=logits.device, dtype=logits.dtype)
+        m_hot.scatter_(1, targets[index, None], self.m)
+        logits[index] -= m_hot
+        logits.mul_(self.s)
+        return logits
+
+
+class ArcSoftmax(Linear):
+
+    def forward(self, logits, targets):
+        index = torch.where(targets != -1)[0]
+        m_hot = torch.zeros(index.size()[0], logits.size()[1], device=logits.device, dtype=logits.dtype)
+        m_hot.scatter_(1, targets[index, None], self.m)
+        logits.acos_()
+        logits[index] += m_hot
+        logits.cos_().mul_(self.s)
+        return logits
+
+
+class CircleSoftmax(Linear):
+
+    def forward(self, logits, targets):
+        alpha_p = torch.clamp_min(-logits.detach() + 1 + self.m, min=0.)
+        alpha_n = torch.clamp_min(logits.detach() + self.m, min=0.)
+        delta_p = 1 - self.m
+        delta_n = self.m
+
+        # When use model parallel, there are some targets not in class centers of local rank
+        index = torch.where(targets != -1)[0]
+        m_hot = torch.zeros(index.size()[0], logits.size()[1], device=logits.device, dtype=logits.dtype)
+        m_hot.scatter_(1, targets[index, None], 1)
+
+        logits_p = alpha_p * (logits - delta_p)
+        logits_n = alpha_n * (logits - delta_n)
+
+        logits[index] = logits_p[index] * m_hot + logits_n[index] * (1 - m_hot)
+
+        neg_index = torch.where(targets == -1)[0]
+        logits[neg_index] = logits_n[neg_index]
+
+        logits.mul_(self.s)
+
+        return logits
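
A hedged sketch of how these margin-softmax heads are typically driven: they take cosine-similarity logits and integer class targets (with -1 marking samples to ignore), modify the logits in place, and feed a standard cross-entropy loss. The feature/weight setup and hyperparameters below are illustrative assumptions, not dnt's training code, and the import path is the upstream FastReID layout:

```python
import torch
import torch.nn.functional as F

# Assumption: the vendored FastReID code is importable as `fastreid`.
from fastreid.layers.any_softmax import CosSoftmax

num_classes, scale, margin = 10, 64.0, 0.35
head = CosSoftmax(num_classes, scale, margin)

features = F.normalize(torch.randn(4, 128), dim=1)                # embedding vectors
class_centers = F.normalize(torch.randn(num_classes, 128), dim=1)
cos_logits = features @ class_centers.t()                         # cosine similarities in [-1, 1]
targets = torch.tensor([0, 3, 7, -1])                             # -1 = no valid label

scaled = head(cos_logits, targets)                                # margin + scale applied in place
valid = targets != -1
loss = F.cross_entropy(scaled[valid], targets[valid])
print(loss.item())
```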

dnt/third_party/fast-reid/fastreid/layers/batch_norm.py
@@ -0,0 +1,205 @@
+# encoding: utf-8
+"""
+@author: liaoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+import logging
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+__all__ = ["IBN", "get_norm"]
+
+
+class BatchNorm(nn.BatchNorm2d):
+    def __init__(self, num_features, eps=1e-05, momentum=0.1, weight_freeze=False, bias_freeze=False, weight_init=1.0,
+                 bias_init=0.0, **kwargs):
+        super().__init__(num_features, eps=eps, momentum=momentum)
+        if weight_init is not None: nn.init.constant_(self.weight, weight_init)
+        if bias_init is not None: nn.init.constant_(self.bias, bias_init)
+        self.weight.requires_grad_(not weight_freeze)
+        self.bias.requires_grad_(not bias_freeze)
+
+
+class SyncBatchNorm(nn.SyncBatchNorm):
+    def __init__(self, num_features, eps=1e-05, momentum=0.1, weight_freeze=False, bias_freeze=False, weight_init=1.0,
+                 bias_init=0.0):
+        super().__init__(num_features, eps=eps, momentum=momentum)
+        if weight_init is not None: nn.init.constant_(self.weight, weight_init)
+        if bias_init is not None: nn.init.constant_(self.bias, bias_init)
+        self.weight.requires_grad_(not weight_freeze)
+        self.bias.requires_grad_(not bias_freeze)
+
+
+class IBN(nn.Module):
+    def __init__(self, planes, bn_norm, **kwargs):
+        super(IBN, self).__init__()
+        half1 = int(planes / 2)
+        self.half = half1
+        half2 = planes - half1
+        self.IN = nn.InstanceNorm2d(half1, affine=True)
+        self.BN = get_norm(bn_norm, half2, **kwargs)
+
+    def forward(self, x):
+        split = torch.split(x, self.half, 1)
+        out1 = self.IN(split[0].contiguous())
+        out2 = self.BN(split[1].contiguous())
+        out = torch.cat((out1, out2), 1)
+        return out
+
+
+class GhostBatchNorm(BatchNorm):
+    def __init__(self, num_features, num_splits=1, **kwargs):
+        super().__init__(num_features, **kwargs)
+        self.num_splits = num_splits
+        self.register_buffer('running_mean', torch.zeros(num_features))
+        self.register_buffer('running_var', torch.ones(num_features))
+
+    def forward(self, input):
+        N, C, H, W = input.shape
+        if self.training or not self.track_running_stats:
+            self.running_mean = self.running_mean.repeat(self.num_splits)
+            self.running_var = self.running_var.repeat(self.num_splits)
+            outputs = F.batch_norm(
+                input.view(-1, C * self.num_splits, H, W), self.running_mean, self.running_var,
+                self.weight.repeat(self.num_splits), self.bias.repeat(self.num_splits),
+                True, self.momentum, self.eps).view(N, C, H, W)
+            self.running_mean = torch.mean(self.running_mean.view(self.num_splits, self.num_features), dim=0)
+            self.running_var = torch.mean(self.running_var.view(self.num_splits, self.num_features), dim=0)
+            return outputs
+        else:
+            return F.batch_norm(
+                input, self.running_mean, self.running_var,
+                self.weight, self.bias, False, self.momentum, self.eps)
+
+
+class FrozenBatchNorm(nn.Module):
+    """
+    BatchNorm2d where the batch statistics and the affine parameters are fixed.
+    It contains non-trainable buffers called
+    "weight" and "bias", "running_mean", "running_var",
+    initialized to perform identity transformation.
+    The pre-trained backbone models from Caffe2 only contain "weight" and "bias",
+    which are computed from the original four parameters of BN.
+    The affine transform `x * weight + bias` will perform the equivalent
+    computation of `(x - running_mean) / sqrt(running_var) * weight + bias`.
+    When loading a backbone model from Caffe2, "running_mean" and "running_var"
+    will be left unchanged as identity transformation.
+    Other pre-trained backbone models may contain all 4 parameters.
+    The forward is implemented by `F.batch_norm(..., training=False)`.
+    """
+
+    _version = 3
+
+    def __init__(self, num_features, eps=1e-5, **kwargs):
+        super().__init__()
+        self.num_features = num_features
+        self.eps = eps
+        self.register_buffer("weight", torch.ones(num_features))
+        self.register_buffer("bias", torch.zeros(num_features))
+        self.register_buffer("running_mean", torch.zeros(num_features))
+        self.register_buffer("running_var", torch.ones(num_features) - eps)
+
+    def forward(self, x):
+        if x.requires_grad:
+            # When gradients are needed, F.batch_norm will use extra memory
+            # because its backward op computes gradients for weight/bias as well.
+            scale = self.weight * (self.running_var + self.eps).rsqrt()
+            bias = self.bias - self.running_mean * scale
+            scale = scale.reshape(1, -1, 1, 1)
+            bias = bias.reshape(1, -1, 1, 1)
+            return x * scale + bias
+        else:
+            # When gradients are not needed, F.batch_norm is a single fused op
+            # and provide more optimization opportunities.
+            return F.batch_norm(
+                x,
+                self.running_mean,
+                self.running_var,
+                self.weight,
+                self.bias,
+                training=False,
+                eps=self.eps,
+            )
+
+    def _load_from_state_dict(
+            self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+    ):
+        version = local_metadata.get("version", None)
+
+        if version is None or version < 2:
+            # No running_mean/var in early versions
+            # This will silent the warnings
+            if prefix + "running_mean" not in state_dict:
+                state_dict[prefix + "running_mean"] = torch.zeros_like(self.running_mean)
+            if prefix + "running_var" not in state_dict:
+                state_dict[prefix + "running_var"] = torch.ones_like(self.running_var)
+
+        if version is not None and version < 3:
+            logger = logging.getLogger(__name__)
+            logger.info("FrozenBatchNorm {} is upgraded to version 3.".format(prefix.rstrip(".")))
+            # In version < 3, running_var are used without +eps.
+            state_dict[prefix + "running_var"] -= self.eps
+
+        super()._load_from_state_dict(
+            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+        )
+
+    def __repr__(self):
+        return "FrozenBatchNorm2d(num_features={}, eps={})".format(self.num_features, self.eps)
+
+    @classmethod
+    def convert_frozen_batchnorm(cls, module):
+        """
+        Convert BatchNorm/SyncBatchNorm in module into FrozenBatchNorm.
+        Args:
+            module (torch.nn.Module):
+        Returns:
+            If module is BatchNorm/SyncBatchNorm, returns a new module.
+            Otherwise, in-place convert module and return it.
+        Similar to convert_sync_batchnorm in
+        https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py
+        """
+        bn_module = nn.modules.batchnorm
+        bn_module = (bn_module.BatchNorm2d, bn_module.SyncBatchNorm)
+        res = module
+        if isinstance(module, bn_module):
+            res = cls(module.num_features)
+            if module.affine:
+                res.weight.data = module.weight.data.clone().detach()
+                res.bias.data = module.bias.data.clone().detach()
+            res.running_mean.data = module.running_mean.data
+            res.running_var.data = module.running_var.data
+            res.eps = module.eps
+        else:
+            for name, child in module.named_children():
+                new_child = cls.convert_frozen_batchnorm(child)
+                if new_child is not child:
+                    res.add_module(name, new_child)
+        return res
+
+
+def get_norm(norm, out_channels, **kwargs):
+    """
+    Args:
+        norm (str or callable): either one of BN, GhostBN, FrozenBN, GN or SyncBN;
+            or a callable that takes a channel number and returns
+            the normalization layer as a nn.Module
+        out_channels: number of channels for normalization layer
+
+    Returns:
+        nn.Module or None: the normalization layer
+    """
+    if isinstance(norm, str):
+        if len(norm) == 0:
+            return None
+        norm = {
+            "BN": BatchNorm,
+            "syncBN": SyncBatchNorm,
+            "GhostBN": GhostBatchNorm,
+            "FrozenBN": FrozenBatchNorm,
+            "GN": lambda channels, **args: nn.GroupNorm(32, channels),
+        }[norm]
+    return norm(out_channels, **kwargs)
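
For context, a minimal sketch of the two entry points this module exposes, get_norm() and FrozenBatchNorm.convert_frozen_batchnorm(). The layer sizes are arbitrary examples and the import path is the upstream FastReID layout, not necessarily the vendored one:

```python
import torch
from torch import nn

# Assumption: the vendored FastReID code is importable as `fastreid`.
from fastreid.layers.batch_norm import FrozenBatchNorm, get_norm

bn = get_norm("BN", 64)        # freezable BatchNorm2d variant defined above
gn = get_norm("GN", 64)        # nn.GroupNorm(32, 64)

# Replace every BatchNorm2d/SyncBatchNorm in a model with FrozenBatchNorm,
# fixing both the running statistics and the affine parameters.
model = nn.Sequential(nn.Conv2d(3, 64, 3), nn.BatchNorm2d(64), nn.ReLU())
frozen = FrozenBatchNorm.convert_frozen_batchnorm(model)
print(frozen[1])               # FrozenBatchNorm2d(num_features=64, eps=1e-05)
```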

dnt/third_party/fast-reid/fastreid/layers/context_block.py
@@ -0,0 +1,113 @@
+# copy from https://github.com/xvjiarui/GCNet/blob/master/mmdet/ops/gcb/context_block.py
+
+import torch
+from torch import nn
+
+__all__ = ['ContextBlock']
+
+
+def last_zero_init(m):
+    if isinstance(m, nn.Sequential):
+        nn.init.constant_(m[-1].weight, val=0)
+        if hasattr(m[-1], 'bias') and m[-1].bias is not None:
+            nn.init.constant_(m[-1].bias, 0)
+    else:
+        nn.init.constant_(m.weight, val=0)
+        if hasattr(m, 'bias') and m.bias is not None:
+            nn.init.constant_(m.bias, 0)
+
+
+class ContextBlock(nn.Module):
+
+    def __init__(self,
+                 inplanes,
+                 ratio,
+                 pooling_type='att',
+                 fusion_types=('channel_add',)):
+        super(ContextBlock, self).__init__()
+        assert pooling_type in ['avg', 'att']
+        assert isinstance(fusion_types, (list, tuple))
+        valid_fusion_types = ['channel_add', 'channel_mul']
+        assert all([f in valid_fusion_types for f in fusion_types])
+        assert len(fusion_types) > 0, 'at least one fusion should be used'
+        self.inplanes = inplanes
+        self.ratio = ratio
+        self.planes = int(inplanes * ratio)
+        self.pooling_type = pooling_type
+        self.fusion_types = fusion_types
+        if pooling_type == 'att':
+            self.conv_mask = nn.Conv2d(inplanes, 1, kernel_size=1)
+            self.softmax = nn.Softmax(dim=2)
+        else:
+            self.avg_pool = nn.AdaptiveAvgPool2d(1)
+        if 'channel_add' in fusion_types:
+            self.channel_add_conv = nn.Sequential(
+                nn.Conv2d(self.inplanes, self.planes, kernel_size=1),
+                nn.LayerNorm([self.planes, 1, 1]),
+                nn.ReLU(inplace=True),  # yapf: disable
+                nn.Conv2d(self.planes, self.inplanes, kernel_size=1))
+        else:
+            self.channel_add_conv = None
+        if 'channel_mul' in fusion_types:
+            self.channel_mul_conv = nn.Sequential(
+                nn.Conv2d(self.inplanes, self.planes, kernel_size=1),
+                nn.LayerNorm([self.planes, 1, 1]),
+                nn.ReLU(inplace=True),  # yapf: disable
+                nn.Conv2d(self.planes, self.inplanes, kernel_size=1))
+        else:
+            self.channel_mul_conv = None
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        if self.pooling_type == 'att':
+            nn.init.kaiming_normal_(self.conv_mask.weight, a=0, mode='fan_in', nonlinearity='relu')
+            if hasattr(self.conv_mask, 'bias') and self.conv_mask.bias is not None:
+                nn.init.constant_(self.conv_mask.bias, 0)
+            self.conv_mask.inited = True
+
+        if self.channel_add_conv is not None:
+            last_zero_init(self.channel_add_conv)
+        if self.channel_mul_conv is not None:
+            last_zero_init(self.channel_mul_conv)
+
+    def spatial_pool(self, x):
+        batch, channel, height, width = x.size()
+        if self.pooling_type == 'att':
+            input_x = x
+            # [N, C, H * W]
+            input_x = input_x.view(batch, channel, height * width)
+            # [N, 1, C, H * W]
+            input_x = input_x.unsqueeze(1)
+            # [N, 1, H, W]
+            context_mask = self.conv_mask(x)
+            # [N, 1, H * W]
+            context_mask = context_mask.view(batch, 1, height * width)
+            # [N, 1, H * W]
+            context_mask = self.softmax(context_mask)
+            # [N, 1, H * W, 1]
+            context_mask = context_mask.unsqueeze(-1)
+            # [N, 1, C, 1]
+            context = torch.matmul(input_x, context_mask)
+            # [N, C, 1, 1]
+            context = context.view(batch, channel, 1, 1)
+        else:
+            # [N, C, 1, 1]
+            context = self.avg_pool(x)
+
+        return context
+
+    def forward(self, x):
+        # [N, C, 1, 1]
+        context = self.spatial_pool(x)
+
+        out = x
+        if self.channel_mul_conv is not None:
+            # [N, C, 1, 1]
+            channel_mul_term = torch.sigmoid(self.channel_mul_conv(context))
+            out = out * channel_mul_term
+        if self.channel_add_conv is not None:
+            # [N, C, 1, 1]
+            channel_add_term = self.channel_add_conv(context)
+            out = out + channel_add_term
+
+        return out
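
Finally, an illustrative sketch of the ContextBlock added above: it is a drop-in global-context module whose output shape matches its input, so it can be attached to any [N, C, H, W] feature map. The channel count and ratio below are arbitrary examples, and the import path follows the upstream FastReID layout:

```python
import torch

# Assumption: the vendored FastReID code is importable as `fastreid`.
from fastreid.layers import ContextBlock

gc_block = ContextBlock(inplanes=256, ratio=1.0 / 16, pooling_type='att',
                        fusion_types=('channel_add',))
feat = torch.randn(2, 256, 32, 32)
out = gc_block(feat)
print(out.shape)  # torch.Size([2, 256, 32, 32]) -- same shape as the input
```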