dnt 0.2.1-py3-none-any.whl → 0.3.1.8-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dnt/__init__.py +4 -1
- dnt/analysis/__init__.py +3 -1
- dnt/analysis/count.py +107 -0
- dnt/analysis/interaction2.py +518 -0
- dnt/analysis/position.py +12 -0
- dnt/analysis/stop.py +92 -33
- dnt/analysis/stop2.py +289 -0
- dnt/analysis/stop3.py +758 -0
- dnt/detect/__init__.py +1 -1
- dnt/detect/signal/detector.py +326 -0
- dnt/detect/timestamp.py +105 -0
- dnt/detect/yolov8/detector.py +182 -35
- dnt/detect/yolov8/segmentor.py +171 -0
- dnt/engine/__init__.py +8 -0
- dnt/engine/bbox_interp.py +83 -0
- dnt/engine/bbox_iou.py +20 -0
- dnt/engine/cluster.py +31 -0
- dnt/engine/iob.py +66 -0
- dnt/filter/__init__.py +4 -0
- dnt/filter/filter.py +450 -21
- dnt/label/__init__.py +1 -1
- dnt/label/labeler.py +215 -14
- dnt/label/labeler2.py +631 -0
- dnt/shared/__init__.py +2 -1
- dnt/shared/data/coco.names +0 -0
- dnt/shared/data/openimages.names +0 -0
- dnt/shared/data/voc.names +0 -0
- dnt/shared/download.py +12 -0
- dnt/shared/synhcro.py +150 -0
- dnt/shared/util.py +17 -4
- dnt/third_party/fast-reid/__init__.py +1 -0
- dnt/third_party/fast-reid/configs/Base-AGW.yml +19 -0
- dnt/third_party/fast-reid/configs/Base-MGN.yml +12 -0
- dnt/third_party/fast-reid/configs/Base-SBS.yml +63 -0
- dnt/third_party/fast-reid/configs/Base-bagtricks.yml +76 -0
- dnt/third_party/fast-reid/configs/DukeMTMC/AGW_R101-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/DukeMTMC/AGW_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/DukeMTMC/AGW_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/DukeMTMC/AGW_S50.yml +11 -0
- dnt/third_party/fast-reid/configs/DukeMTMC/bagtricks_R101-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/DukeMTMC/bagtricks_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/DukeMTMC/bagtricks_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/DukeMTMC/bagtricks_S50.yml +11 -0
- dnt/third_party/fast-reid/configs/DukeMTMC/mgn_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/DukeMTMC/sbs_R101-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/DukeMTMC/sbs_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/DukeMTMC/sbs_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/DukeMTMC/sbs_S50.yml +11 -0
- dnt/third_party/fast-reid/configs/MOT17/AGW_R101-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/MOT17/AGW_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/MOT17/AGW_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/MOT17/AGW_S50.yml +11 -0
- dnt/third_party/fast-reid/configs/MOT17/bagtricks_R101-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/MOT17/bagtricks_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/MOT17/bagtricks_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/MOT17/bagtricks_S50.yml +11 -0
- dnt/third_party/fast-reid/configs/MOT17/mgn_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/MOT17/sbs_R101-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/MOT17/sbs_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/MOT17/sbs_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/MOT17/sbs_S50.yml +11 -0
- dnt/third_party/fast-reid/configs/MOT20/AGW_R101-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/MOT20/AGW_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/MOT20/AGW_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/MOT20/AGW_S50.yml +11 -0
- dnt/third_party/fast-reid/configs/MOT20/bagtricks_R101-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/MOT20/bagtricks_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/MOT20/bagtricks_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/MOT20/bagtricks_S50.yml +11 -0
- dnt/third_party/fast-reid/configs/MOT20/mgn_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/MOT20/sbs_R101-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/MOT20/sbs_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/MOT20/sbs_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/MOT20/sbs_S50.yml +11 -0
- dnt/third_party/fast-reid/configs/MSMT17/AGW_R101-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/MSMT17/AGW_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/MSMT17/AGW_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/MSMT17/AGW_S50.yml +11 -0
- dnt/third_party/fast-reid/configs/MSMT17/bagtricks_R101-ibn.yml +13 -0
- dnt/third_party/fast-reid/configs/MSMT17/bagtricks_R50-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/MSMT17/bagtricks_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/MSMT17/bagtricks_S50.yml +12 -0
- dnt/third_party/fast-reid/configs/MSMT17/mgn_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/MSMT17/sbs_R101-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/MSMT17/sbs_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/MSMT17/sbs_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/MSMT17/sbs_S50.yml +11 -0
- dnt/third_party/fast-reid/configs/Market1501/AGW_R101-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/Market1501/AGW_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/Market1501/AGW_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/Market1501/AGW_S50.yml +11 -0
- dnt/third_party/fast-reid/configs/Market1501/bagtricks_R101-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/Market1501/bagtricks_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/Market1501/bagtricks_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/Market1501/bagtricks_S50.yml +11 -0
- dnt/third_party/fast-reid/configs/Market1501/bagtricks_vit.yml +88 -0
- dnt/third_party/fast-reid/configs/Market1501/mgn_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/Market1501/sbs_R101-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/Market1501/sbs_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/Market1501/sbs_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/Market1501/sbs_S50.yml +11 -0
- dnt/third_party/fast-reid/configs/VERIWild/bagtricks_R50-ibn.yml +35 -0
- dnt/third_party/fast-reid/configs/VeRi/sbs_R50-ibn.yml +35 -0
- dnt/third_party/fast-reid/configs/VehicleID/bagtricks_R50-ibn.yml +36 -0
- dnt/third_party/fast-reid/configs/__init__.py +0 -0
- dnt/third_party/fast-reid/fast_reid_interfece.py +175 -0
- dnt/third_party/fast-reid/fastreid/__init__.py +6 -0
- dnt/third_party/fast-reid/fastreid/config/__init__.py +15 -0
- dnt/third_party/fast-reid/fastreid/config/config.py +319 -0
- dnt/third_party/fast-reid/fastreid/config/defaults.py +329 -0
- dnt/third_party/fast-reid/fastreid/data/__init__.py +17 -0
- dnt/third_party/fast-reid/fastreid/data/build.py +194 -0
- dnt/third_party/fast-reid/fastreid/data/common.py +58 -0
- dnt/third_party/fast-reid/fastreid/data/data_utils.py +202 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/AirportALERT.py +50 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/__init__.py +43 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/bases.py +183 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/caviara.py +44 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/cuhk03.py +274 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/cuhk_sysu.py +58 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/dukemtmcreid.py +70 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/grid.py +44 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/iLIDS.py +45 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/lpw.py +49 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/market1501.py +89 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/msmt17.py +114 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/pes3d.py +44 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/pku.py +44 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/prai.py +43 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/prid.py +41 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/saivt.py +47 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/sensereid.py +47 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/shinpuhkan.py +48 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/sysu_mm.py +47 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/thermalworld.py +43 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/vehicleid.py +126 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/veri.py +69 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/veriwild.py +140 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/viper.py +45 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/wildtracker.py +59 -0
- dnt/third_party/fast-reid/fastreid/data/samplers/__init__.py +18 -0
- dnt/third_party/fast-reid/fastreid/data/samplers/data_sampler.py +85 -0
- dnt/third_party/fast-reid/fastreid/data/samplers/imbalance_sampler.py +67 -0
- dnt/third_party/fast-reid/fastreid/data/samplers/triplet_sampler.py +260 -0
- dnt/third_party/fast-reid/fastreid/data/transforms/__init__.py +11 -0
- dnt/third_party/fast-reid/fastreid/data/transforms/autoaugment.py +806 -0
- dnt/third_party/fast-reid/fastreid/data/transforms/build.py +100 -0
- dnt/third_party/fast-reid/fastreid/data/transforms/functional.py +180 -0
- dnt/third_party/fast-reid/fastreid/data/transforms/transforms.py +161 -0
- dnt/third_party/fast-reid/fastreid/engine/__init__.py +15 -0
- dnt/third_party/fast-reid/fastreid/engine/defaults.py +490 -0
- dnt/third_party/fast-reid/fastreid/engine/hooks.py +534 -0
- dnt/third_party/fast-reid/fastreid/engine/launch.py +103 -0
- dnt/third_party/fast-reid/fastreid/engine/train_loop.py +357 -0
- dnt/third_party/fast-reid/fastreid/evaluation/__init__.py +6 -0
- dnt/third_party/fast-reid/fastreid/evaluation/clas_evaluator.py +81 -0
- dnt/third_party/fast-reid/fastreid/evaluation/evaluator.py +176 -0
- dnt/third_party/fast-reid/fastreid/evaluation/query_expansion.py +46 -0
- dnt/third_party/fast-reid/fastreid/evaluation/rank.py +200 -0
- dnt/third_party/fast-reid/fastreid/evaluation/rank_cylib/__init__.py +20 -0
- dnt/third_party/fast-reid/fastreid/evaluation/rank_cylib/setup.py +32 -0
- dnt/third_party/fast-reid/fastreid/evaluation/rank_cylib/test_cython.py +106 -0
- dnt/third_party/fast-reid/fastreid/evaluation/reid_evaluation.py +143 -0
- dnt/third_party/fast-reid/fastreid/evaluation/rerank.py +73 -0
- dnt/third_party/fast-reid/fastreid/evaluation/roc.py +90 -0
- dnt/third_party/fast-reid/fastreid/evaluation/testing.py +88 -0
- dnt/third_party/fast-reid/fastreid/layers/__init__.py +19 -0
- dnt/third_party/fast-reid/fastreid/layers/activation.py +59 -0
- dnt/third_party/fast-reid/fastreid/layers/any_softmax.py +80 -0
- dnt/third_party/fast-reid/fastreid/layers/batch_norm.py +205 -0
- dnt/third_party/fast-reid/fastreid/layers/context_block.py +113 -0
- dnt/third_party/fast-reid/fastreid/layers/drop.py +161 -0
- dnt/third_party/fast-reid/fastreid/layers/frn.py +199 -0
- dnt/third_party/fast-reid/fastreid/layers/gather_layer.py +30 -0
- dnt/third_party/fast-reid/fastreid/layers/helpers.py +31 -0
- dnt/third_party/fast-reid/fastreid/layers/non_local.py +54 -0
- dnt/third_party/fast-reid/fastreid/layers/pooling.py +124 -0
- dnt/third_party/fast-reid/fastreid/layers/se_layer.py +25 -0
- dnt/third_party/fast-reid/fastreid/layers/splat.py +109 -0
- dnt/third_party/fast-reid/fastreid/layers/weight_init.py +122 -0
- dnt/third_party/fast-reid/fastreid/modeling/__init__.py +23 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/__init__.py +18 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/build.py +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/mobilenet.py +195 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/mobilenetv3.py +283 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/osnet.py +525 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/__init__.py +4 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/config.py +396 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/effnet/EN-B0_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/effnet/EN-B1_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/effnet/EN-B2_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/effnet/EN-B3_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/effnet/EN-B4_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/effnet/EN-B5_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/effnet.py +281 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnet.py +596 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-1.6GF_dds_8gpu.yaml +26 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-12GF_dds_8gpu.yaml +26 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-16GF_dds_8gpu.yaml +26 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-200MF_dds_8gpu.yaml +26 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-3.2GF_dds_8gpu.yaml +26 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-32GF_dds_8gpu.yaml +26 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-4.0GF_dds_8gpu.yaml +26 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-400MF_dds_8gpu.yaml +26 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-6.4GF_dds_8gpu.yaml +26 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-600MF_dds_8gpu.yaml +26 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-8.0GF_dds_8gpu.yaml +26 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-800MF_dds_8gpu.yaml +26 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-1.6GF_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-12GF_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-16GF_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-200MF_dds_8gpu.yaml +26 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-3.2GF_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-32GF_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-4.0GF_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-400MF_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-6.4GF_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-600MF_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-8.0GF_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-800MF_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/repvgg.py +309 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/resnest.py +365 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/resnet.py +364 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/resnext.py +335 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/shufflenet.py +203 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/vision_transformer.py +399 -0
- dnt/third_party/fast-reid/fastreid/modeling/heads/__init__.py +11 -0
- dnt/third_party/fast-reid/fastreid/modeling/heads/build.py +25 -0
- dnt/third_party/fast-reid/fastreid/modeling/heads/clas_head.py +36 -0
- dnt/third_party/fast-reid/fastreid/modeling/heads/embedding_head.py +151 -0
- dnt/third_party/fast-reid/fastreid/modeling/losses/__init__.py +12 -0
- dnt/third_party/fast-reid/fastreid/modeling/losses/circle_loss.py +71 -0
- dnt/third_party/fast-reid/fastreid/modeling/losses/cross_entroy_loss.py +54 -0
- dnt/third_party/fast-reid/fastreid/modeling/losses/focal_loss.py +92 -0
- dnt/third_party/fast-reid/fastreid/modeling/losses/triplet_loss.py +113 -0
- dnt/third_party/fast-reid/fastreid/modeling/losses/utils.py +48 -0
- dnt/third_party/fast-reid/fastreid/modeling/meta_arch/__init__.py +14 -0
- dnt/third_party/fast-reid/fastreid/modeling/meta_arch/baseline.py +188 -0
- dnt/third_party/fast-reid/fastreid/modeling/meta_arch/build.py +26 -0
- dnt/third_party/fast-reid/fastreid/modeling/meta_arch/distiller.py +140 -0
- dnt/third_party/fast-reid/fastreid/modeling/meta_arch/mgn.py +394 -0
- dnt/third_party/fast-reid/fastreid/modeling/meta_arch/moco.py +126 -0
- dnt/third_party/fast-reid/fastreid/solver/__init__.py +8 -0
- dnt/third_party/fast-reid/fastreid/solver/build.py +348 -0
- dnt/third_party/fast-reid/fastreid/solver/lr_scheduler.py +66 -0
- dnt/third_party/fast-reid/fastreid/solver/optim/__init__.py +10 -0
- dnt/third_party/fast-reid/fastreid/solver/optim/lamb.py +123 -0
- dnt/third_party/fast-reid/fastreid/solver/optim/radam.py +149 -0
- dnt/third_party/fast-reid/fastreid/solver/optim/swa.py +246 -0
- dnt/third_party/fast-reid/fastreid/utils/__init__.py +6 -0
- dnt/third_party/fast-reid/fastreid/utils/checkpoint.py +503 -0
- dnt/third_party/fast-reid/fastreid/utils/collect_env.py +158 -0
- dnt/third_party/fast-reid/fastreid/utils/comm.py +255 -0
- dnt/third_party/fast-reid/fastreid/utils/compute_dist.py +200 -0
- dnt/third_party/fast-reid/fastreid/utils/env.py +119 -0
- dnt/third_party/fast-reid/fastreid/utils/events.py +461 -0
- dnt/third_party/fast-reid/fastreid/utils/faiss_utils.py +127 -0
- dnt/third_party/fast-reid/fastreid/utils/file_io.py +520 -0
- dnt/third_party/fast-reid/fastreid/utils/history_buffer.py +71 -0
- dnt/third_party/fast-reid/fastreid/utils/logger.py +211 -0
- dnt/third_party/fast-reid/fastreid/utils/params.py +103 -0
- dnt/third_party/fast-reid/fastreid/utils/precision_bn.py +94 -0
- dnt/third_party/fast-reid/fastreid/utils/registry.py +66 -0
- dnt/third_party/fast-reid/fastreid/utils/summary.py +120 -0
- dnt/third_party/fast-reid/fastreid/utils/timer.py +68 -0
- dnt/third_party/fast-reid/fastreid/utils/visualizer.py +278 -0
- dnt/track/__init__.py +3 -1
- dnt/track/botsort/__init__.py +4 -0
- dnt/track/botsort/bot_tracker/__init__.py +3 -0
- dnt/track/botsort/bot_tracker/basetrack.py +60 -0
- dnt/track/botsort/bot_tracker/bot_sort.py +473 -0
- dnt/track/botsort/bot_tracker/gmc.py +316 -0
- dnt/track/botsort/bot_tracker/kalman_filter.py +269 -0
- dnt/track/botsort/bot_tracker/matching.py +194 -0
- dnt/track/botsort/bot_tracker/mc_bot_sort.py +505 -0
- dnt/track/{dsort/utils → botsort/bot_tracker/tracking_utils}/evaluation.py +14 -4
- dnt/track/{dsort/utils → botsort/bot_tracker/tracking_utils}/io.py +19 -36
- dnt/track/botsort/bot_tracker/tracking_utils/timer.py +37 -0
- dnt/track/botsort/inference.py +96 -0
- dnt/track/config.py +120 -0
- dnt/track/dsort/configs/bagtricks_R50.yml +7 -0
- dnt/track/dsort/configs/deep_sort.yaml +0 -0
- dnt/track/dsort/configs/fastreid.yaml +1 -1
- dnt/track/dsort/deep_sort/deep/checkpoint/ckpt.t7 +0 -0
- dnt/track/dsort/deep_sort/deep/feature_extractor.py +87 -8
- dnt/track/dsort/deep_sort/deep_sort.py +31 -21
- dnt/track/dsort/deep_sort/sort/detection.py +2 -1
- dnt/track/dsort/deep_sort/sort/iou_matching.py +0 -2
- dnt/track/dsort/deep_sort/sort/linear_assignment.py +0 -3
- dnt/track/dsort/deep_sort/sort/nn_matching.py +5 -5
- dnt/track/dsort/deep_sort/sort/preprocessing.py +1 -2
- dnt/track/dsort/deep_sort/sort/track.py +2 -1
- dnt/track/dsort/deep_sort/sort/tracker.py +1 -1
- dnt/track/dsort/dsort.py +44 -27
- dnt/track/re_class.py +117 -0
- dnt/track/sort/sort.py +9 -7
- dnt/track/tracker.py +225 -20
- dnt-0.3.1.8.dist-info/METADATA +117 -0
- dnt-0.3.1.8.dist-info/RECORD +315 -0
- {dnt-0.2.1.dist-info → dnt-0.3.1.8.dist-info}/WHEEL +1 -1
- dnt/track/dsort/deep_sort/deep/evaluate.py +0 -15
- dnt/track/dsort/deep_sort/deep/original_model.py +0 -106
- dnt/track/dsort/deep_sort/deep/test.py +0 -77
- dnt/track/dsort/deep_sort/deep/train.py +0 -189
- dnt/track/dsort/utils/asserts.py +0 -13
- dnt/track/dsort/utils/draw.py +0 -36
- dnt/track/dsort/utils/json_logger.py +0 -383
- dnt/track/dsort/utils/log.py +0 -17
- dnt/track/dsort/utils/parser.py +0 -35
- dnt/track/dsort/utils/tools.py +0 -39
- dnt-0.2.1.dist-info/METADATA +0 -35
- dnt-0.2.1.dist-info/RECORD +0 -60
- /dnt/{track/dsort/utils → third_party/fast-reid/checkpoint}/__init__.py +0 -0
- {dnt-0.2.1.dist-info → dnt-0.3.1.8.dist-info/licenses}/LICENSE +0 -0
- {dnt-0.2.1.dist-info → dnt-0.3.1.8.dist-info}/top_level.txt +0 -0
dnt/third_party/fast-reid/fastreid/utils/logger.py
@@ -0,0 +1,211 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+import functools
+import logging
+import os
+import sys
+import time
+from collections import Counter
+
+from termcolor import colored
+
+from .file_io import PathManager
+
+
+class _ColorfulFormatter(logging.Formatter):
+    def __init__(self, *args, **kwargs):
+        self._root_name = kwargs.pop("root_name") + "."
+        self._abbrev_name = kwargs.pop("abbrev_name", "")
+        if len(self._abbrev_name):
+            self._abbrev_name = self._abbrev_name + "."
+        super(_ColorfulFormatter, self).__init__(*args, **kwargs)
+
+    def formatMessage(self, record):
+        record.name = record.name.replace(self._root_name, self._abbrev_name)
+        log = super(_ColorfulFormatter, self).formatMessage(record)
+        if record.levelno == logging.WARNING:
+            prefix = colored("WARNING", "red", attrs=["blink"])
+        elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
+            prefix = colored("ERROR", "red", attrs=["blink", "underline"])
+        else:
+            return log
+        return prefix + " " + log
+
+
+@functools.lru_cache()  # so that calling setup_logger multiple times won't add many handlers
+def setup_logger(
+    output=None, distributed_rank=0, *, color=True, name="fastreid", abbrev_name=None
+):
+    """
+    Args:
+        output (str): a file name or a directory to save log. If None, will not save log file.
+            If ends with ".txt" or ".log", assumed to be a file name.
+            Otherwise, logs will be saved to `output/log.txt`.
+        name (str): the root module name of this logger
+        abbrev_name (str): an abbreviation of the module, to avoid long names in logs.
+            Set to "" to not log the root module in logs.
+            By default, will abbreviate "detectron2" to "d2" and leave other
+            modules unchanged.
+    """
+    logger = logging.getLogger(name)
+    logger.setLevel(logging.DEBUG)
+    logger.propagate = False
+
+    if abbrev_name is None:
+        abbrev_name = "d2" if name == "detectron2" else name
+
+    plain_formatter = logging.Formatter(
+        "[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%m/%d %H:%M:%S"
+    )
+    # stdout logging: master only
+    if distributed_rank == 0:
+        ch = logging.StreamHandler(stream=sys.stdout)
+        ch.setLevel(logging.DEBUG)
+        if color:
+            formatter = _ColorfulFormatter(
+                colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s",
+                datefmt="%m/%d %H:%M:%S",
+                root_name=name,
+                abbrev_name=str(abbrev_name),
+            )
+        else:
+            formatter = plain_formatter
+        ch.setFormatter(formatter)
+        logger.addHandler(ch)
+
+    # file logging: all workers
+    if output is not None:
+        if output.endswith(".txt") or output.endswith(".log"):
+            filename = output
+        else:
+            filename = os.path.join(output, "log.txt")
+        if distributed_rank > 0:
+            filename = filename + ".rank{}".format(distributed_rank)
+        PathManager.mkdirs(os.path.dirname(filename))
+
+        fh = logging.StreamHandler(_cached_log_stream(filename))
+        fh.setLevel(logging.DEBUG)
+        fh.setFormatter(plain_formatter)
+        logger.addHandler(fh)
+
+    return logger
+
+
+# cache the opened file object, so that different calls to `setup_logger`
+# with the same file name can safely write to the same file.
+@functools.lru_cache(maxsize=None)
+def _cached_log_stream(filename):
+    return PathManager.open(filename, "a")
+
+
+"""
+Below are some other convenient logging methods.
+They are mainly adopted from
+https://github.com/abseil/abseil-py/blob/master/absl/logging/__init__.py
+"""
+
+
+def _find_caller():
+    """
+    Returns:
+        str: module name of the caller
+        tuple: a hashable key to be used to identify different callers
+    """
+    frame = sys._getframe(2)
+    while frame:
+        code = frame.f_code
+        if os.path.join("utils", "logger.") not in code.co_filename:
+            mod_name = frame.f_globals["__name__"]
+            if mod_name == "__main__":
+                mod_name = "detectron2"
+            return mod_name, (code.co_filename, frame.f_lineno, code.co_name)
+        frame = frame.f_back
+
+
+_LOG_COUNTER = Counter()
+_LOG_TIMER = {}
+
+
+def log_first_n(lvl, msg, n=1, *, name=None, key="caller"):
+    """
+    Log only for the first n times.
+    Args:
+        lvl (int): the logging level
+        msg (str):
+        n (int):
+        name (str): name of the logger to use. Will use the caller's module by default.
+        key (str or tuple[str]): the string(s) can be one of "caller" or
+            "message", which defines how to identify duplicated logs.
+            For example, if called with `n=1, key="caller"`, this function
+            will only log the first call from the same caller, regardless of
+            the message content.
+            If called with `n=1, key="message"`, this function will log the
+            same content only once, even if they are called from different places.
+            If called with `n=1, key=("caller", "message")`, this function
+            will not log only if the same caller has logged the same message before.
+    """
+    if isinstance(key, str):
+        key = (key,)
+    assert len(key) > 0
+
+    caller_module, caller_key = _find_caller()
+    hash_key = ()
+    if "caller" in key:
+        hash_key = hash_key + caller_key
+    if "message" in key:
+        hash_key = hash_key + (msg,)
+
+    _LOG_COUNTER[hash_key] += 1
+    if _LOG_COUNTER[hash_key] <= n:
+        logging.getLogger(name or caller_module).log(lvl, msg)
+
+
+def log_every_n(lvl, msg, n=1, *, name=None):
+    """
+    Log once per n times.
+    Args:
+        lvl (int): the logging level
+        msg (str):
+        n (int):
+        name (str): name of the logger to use. Will use the caller's module by default.
+    """
+    caller_module, key = _find_caller()
+    _LOG_COUNTER[key] += 1
+    if n == 1 or _LOG_COUNTER[key] % n == 1:
+        logging.getLogger(name or caller_module).log(lvl, msg)
+
+
+def log_every_n_seconds(lvl, msg, n=1, *, name=None):
+    """
+    Log no more than once per n seconds.
+    Args:
+        lvl (int): the logging level
+        msg (str):
+        n (int):
+        name (str): name of the logger to use. Will use the caller's module by default.
+    """
+    caller_module, key = _find_caller()
+    last_logged = _LOG_TIMER.get(key, None)
+    current_time = time.time()
+    if last_logged is None or current_time - last_logged >= n:
+        logging.getLogger(name or caller_module).log(lvl, msg)
+        _LOG_TIMER[key] = current_time
+
+# def create_small_table(small_dict):
+#     """
+#     Create a small table using the keys of small_dict as headers. This is only
+#     suitable for small dictionaries.
+#     Args:
+#         small_dict (dict): a result dictionary of only a few items.
+#     Returns:
+#         str: the table as a string.
+#     """
+#     keys, values = tuple(zip(*small_dict.items()))
+#     table = tabulate(
+#         [values],
+#         headers=keys,
+#         tablefmt="pipe",
+#         floatfmt=".3f",
+#         stralign="center",
+#         numalign="center",
+#     )
+#     return table
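
For orientation, a minimal usage sketch of the logging helpers added above. The import path is an assumption (the vendored `fast-reid` directory is not itself an importable package name, so it would need to be on `sys.path`), and the log messages are placeholders.

```python
import logging

# Assumes the module above is reachable as fastreid.utils.logger.
from fastreid.utils.logger import setup_logger, log_first_n

logger = setup_logger(output="logs", name="fastreid")  # writes logs/log.txt
logger.info("detector ready")

# Deduplicated by caller: only the first of these 100 calls is emitted.
for _ in range(100):
    log_first_n(logging.WARNING, "frame dropped", n=1)
```
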
dnt/third_party/fast-reid/fastreid/utils/params.py
@@ -0,0 +1,103 @@
+# encoding: utf-8
+"""
+@author: liaoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+# based on: https://github.com/PhilJd/contiguous_pytorch_params/blob/master/contiguous_params/params.py
+
+from collections import OrderedDict
+
+import torch
+
+
+class ContiguousParams:
+
+    def __init__(self, parameters):
+        # Create a list of the parameters to prevent emptying an iterator.
+        self._parameters = parameters
+        self._param_buffer = []
+        self._grad_buffer = []
+        self._group_dict = OrderedDict()
+        self._name_buffer = []
+        self._init_buffers()
+        # Store the data pointers for each parameter into the buffer. These
+        # can be used to check if an operation overwrites the gradient/data
+        # tensor (invalidating the assumption of a contiguous buffer).
+        self.data_pointers = []
+        self.grad_pointers = []
+        self.make_params_contiguous()
+
+    def _init_buffers(self):
+        dtype = self._parameters[0]["params"][0].dtype
+        device = self._parameters[0]["params"][0].device
+        if not all(p["params"][0].dtype == dtype for p in self._parameters):
+            raise ValueError("All parameters must be of the same dtype.")
+        if not all(p["params"][0].device == device for p in self._parameters):
+            raise ValueError("All parameters must be on the same device.")
+
+        # Group parameters by lr and weight decay
+        for param_dict in self._parameters:
+            freeze_status = param_dict["freeze_status"]
+            param_key = freeze_status + '_' + str(param_dict["lr"]) + '_' + str(param_dict["weight_decay"])
+            if param_key not in self._group_dict:
+                self._group_dict[param_key] = []
+            self._group_dict[param_key].append(param_dict)
+
+        for key, params in self._group_dict.items():
+            size = sum(p["params"][0].numel() for p in params)
+            self._param_buffer.append(torch.zeros(size, dtype=dtype, device=device))
+            self._grad_buffer.append(torch.zeros(size, dtype=dtype, device=device))
+            self._name_buffer.append(key)
+
+    def make_params_contiguous(self):
+        """Create a buffer to hold all params and update the params to be views of the buffer.
+        Args:
+            parameters: An iterable of parameters.
+        """
+        for i, params in enumerate(self._group_dict.values()):
+            index = 0
+            for param_dict in params:
+                p = param_dict["params"][0]
+                size = p.numel()
+                self._param_buffer[i][index:index + size] = p.data.view(-1)
+                p.data = self._param_buffer[i][index:index + size].view(p.data.shape)
+                p.grad = self._grad_buffer[i][index:index + size].view(p.data.shape)
+                self.data_pointers.append(p.data.data_ptr)
+                self.grad_pointers.append(p.grad.data.data_ptr)
+                index += size
+            # Bend the param_buffer to use grad_buffer to track its gradients.
+            self._param_buffer[i].grad = self._grad_buffer[i]
+
+    def contiguous(self):
+        """Return all parameters as one contiguous buffer."""
+        return [{
+            "freeze_status": self._name_buffer[i].split('_')[0],
+            "params": self._param_buffer[i],
+            "lr": float(self._name_buffer[i].split('_')[1]),
+            "weight_decay": float(self._name_buffer[i].split('_')[2]),
+        } for i in range(len(self._param_buffer))]
+
+    def original(self):
+        """Return the non-flattened parameters."""
+        return self._parameters
+
+    def buffer_is_valid(self):
+        """Verify that all parameters and gradients still use the buffer."""
+        i = 0
+        for params in self._group_dict.values():
+            for param_dict in params:
+                p = param_dict["params"][0]
+                data_ptr = self.data_pointers[i]
+                grad_ptr = self.grad_pointers[i]
+                if (p.data.data_ptr() != data_ptr()) or (p.grad.data.data_ptr() != grad_ptr()):
+                    return False
+                i += 1
+        return True
+
+    def assert_buffer_is_valid(self):
+        if not self.buffer_is_valid():
+            raise ValueError(
+                "The data or gradient buffer has been invalidated. Please make "
+                "sure to use inplace operations only when updating parameters "
+                "or gradients.")
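
A minimal sketch of how `ContiguousParams` expects to be fed, inferred from `_init_buffers` above: one dict per parameter carrying `params`, `lr`, `weight_decay`, and `freeze_status` keys. The model, hyperparameters, and `"normal"` status string are placeholders, not values mandated by the package.

```python
import torch

# Placeholder model; one group dict per parameter, since the class above
# indexes param_dict["params"][0].
model = torch.nn.Linear(8, 4)
param_dicts = [
    {"freeze_status": "normal", "params": [p], "lr": 1e-3, "weight_decay": 5e-4}
    for p in model.parameters()
]

contiguous = ContiguousParams(param_dicts)
# The optimizer then steps over a few flattened buffers instead of many small
# tensors; each returned group carries its own "lr" and "weight_decay".
optimizer = torch.optim.SGD(contiguous.contiguous(), momentum=0.9)

# After each optimizer.step(), the parameter views must still alias the buffers:
contiguous.assert_buffer_is_valid()
```
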
dnt/third_party/fast-reid/fastreid/utils/precision_bn.py
@@ -0,0 +1,94 @@
+# encoding: utf-8
+"""
+@author: liaoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+import itertools
+
+import torch
+
+BN_MODULE_TYPES = (
+    torch.nn.BatchNorm1d,
+    torch.nn.BatchNorm2d,
+    torch.nn.BatchNorm3d,
+    torch.nn.SyncBatchNorm,
+)
+
+
+@torch.no_grad()
+def update_bn_stats(model, data_loader, num_iters: int = 200):
+    """
+    Recompute and update the batch norm stats to make them more precise. During
+    training both BN stats and the weight are changing after every iteration, so
+    the running average can not precisely reflect the actual stats of the
+    current model.
+    In this function, the BN stats are recomputed with fixed weights, to make
+    the running average more precise. Specifically, it computes the true average
+    of per-batch mean/variance instead of the running average.
+    Args:
+        model (nn.Module): the model whose bn stats will be recomputed.
+            Note that:
+            1. This function will not alter the training mode of the given model.
+               Users are responsible for setting the layers that needs
+               precise-BN to training mode, prior to calling this function.
+            2. Be careful if your models contain other stateful layers in
+               addition to BN, i.e. layers whose state can change in forward
+               iterations. This function will alter their state. If you wish
+               them unchanged, you need to either pass in a submodule without
+               those layers, or backup the states.
+        data_loader (iterator): an iterator. Produce data as inputs to the model.
+        num_iters (int): number of iterations to compute the stats.
+    """
+    bn_layers = get_bn_modules(model)
+    if len(bn_layers) == 0:
+        return
+
+    # In order to make the running stats only reflect the current batch, the
+    # momentum is disabled.
+    # bn.running_mean = (1 - momentum) * bn.running_mean + momentum * batch_mean
+    # Setting the momentum to 1.0 to compute the stats without momentum.
+    momentum_actual = [bn.momentum for bn in bn_layers]
+    for bn in bn_layers:
+        bn.momentum = 1.0
+
+    # Note that running_var actually means "running average of variance"
+    running_mean = [torch.zeros_like(bn.running_mean) for bn in bn_layers]
+    running_var = [torch.zeros_like(bn.running_var) for bn in bn_layers]
+
+    for ind, inputs in enumerate(itertools.islice(data_loader, num_iters)):
+        inputs['targets'].fill_(-1)
+        with torch.no_grad():  # No need to backward
+            model(inputs)
+        for i, bn in enumerate(bn_layers):
+            # Accumulates the bn stats.
+            running_mean[i] += (bn.running_mean - running_mean[i]) / (ind + 1)
+            running_var[i] += (bn.running_var - running_var[i]) / (ind + 1)
+            # We compute the "average of variance" across iterations.
+    assert ind == num_iters - 1, (
+        "update_bn_stats is meant to run for {} iterations, "
+        "but the dataloader stops at {} iterations.".format(num_iters, ind)
+    )
+
+    for i, bn in enumerate(bn_layers):
+        # Sets the precise bn stats.
+        bn.running_mean = running_mean[i]
+        bn.running_var = running_var[i]
+        bn.momentum = momentum_actual[i]
+
+
+def get_bn_modules(model):
+    """
+    Find all BatchNorm (BN) modules that are in training mode. See
+    fvcore.precise_bn.BN_MODULE_TYPES for a list of all modules that are
+    included in this search.
+    Args:
+        model (nn.Module): a model possibly containing BN modules.
+    Returns:
+        list[nn.Module]: all BN modules in the model.
+    """
+    # Finds all the bn layers.
+    bn_layers = [
+        m for m in model.modules() if m.training and isinstance(m, BN_MODULE_TYPES)
+    ]
+    return bn_layers
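
A hedged usage sketch for `update_bn_stats`: the loader must yield fastreid-style batch dicts containing a `targets` tensor, since the function overwrites `inputs['targets']` in place. The toy model and loader below are illustrative only.

```python
import torch

class ToyNet(torch.nn.Module):
    """Placeholder network with one BN layer to recalibrate."""
    def __init__(self):
        super().__init__()
        self.bn = torch.nn.BatchNorm1d(8)

    def forward(self, batch):
        return self.bn(batch["images"])

def batches():
    while True:  # endless iterator; itertools.islice above caps it at num_iters
        yield {"images": torch.randn(4, 8), "targets": torch.zeros(4, dtype=torch.long)}

model = ToyNet()
model.train()  # only BN modules in training mode are collected
update_bn_stats(model, batches(), num_iters=10)
model.eval()
```
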
dnt/third_party/fast-reid/fastreid/utils/registry.py
@@ -0,0 +1,66 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+
+from typing import Dict, Optional
+
+
+class Registry(object):
+    """
+    The registry that provides name -> object mapping, to support third-party
+    users' custom modules.
+    To create a registry (e.g. a backbone registry):
+    .. code-block:: python
+        BACKBONE_REGISTRY = Registry('BACKBONE')
+    To register an object:
+    .. code-block:: python
+        @BACKBONE_REGISTRY.register()
+        class MyBackbone():
+            ...
+    Or:
+    .. code-block:: python
+        BACKBONE_REGISTRY.register(MyBackbone)
+    """
+
+    def __init__(self, name: str) -> None:
+        """
+        Args:
+            name (str): the name of this registry
+        """
+        self._name: str = name
+        self._obj_map: Dict[str, object] = {}
+
+    def _do_register(self, name: str, obj: object) -> None:
+        assert (
+            name not in self._obj_map
+        ), "An object named '{}' was already registered in '{}' registry!".format(
+            name, self._name
+        )
+        self._obj_map[name] = obj
+
+    def register(self, obj: object = None) -> Optional[object]:
+        """
+        Register the given object under the the name `obj.__name__`.
+        Can be used as either a decorator or not. See docstring of this class for usage.
+        """
+        if obj is None:
+            # used as a decorator
+            def deco(func_or_class: object) -> object:
+                name = func_or_class.__name__  # pyre-ignore
+                self._do_register(name, func_or_class)
+                return func_or_class
+
+            return deco
+
+        # used as a function call
+        name = obj.__name__  # pyre-ignore
+        self._do_register(name, obj)
+
+    def get(self, name: str) -> object:
+        ret = self._obj_map.get(name)
+        if ret is None:
+            raise KeyError(
+                "No object named '{}' found in '{}' registry!".format(
+                    name, self._name
+                )
+            )
+        return ret
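
The class docstring above already shows both registration styles; a short combined sketch with lookup (all names here are illustrative):

```python
BACKBONE_REGISTRY = Registry("BACKBONE")

@BACKBONE_REGISTRY.register()          # decorator form
class MyBackbone:
    pass

def build_backbone(name):
    return BACKBONE_REGISTRY.get(name)()   # lookup by class name

assert isinstance(build_backbone("MyBackbone"), MyBackbone)
# BACKBONE_REGISTRY.get("Missing") would raise KeyError;
# registering "MyBackbone" a second time would trip the assert in _do_register.
```
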
dnt/third_party/fast-reid/fastreid/utils/summary.py
@@ -0,0 +1,120 @@
+# encoding: utf-8
+"""
+@author: liaoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+import torch
+import torch.nn as nn
+from torch.autograd import Variable
+
+from collections import OrderedDict
+import numpy as np
+
+
+def summary(model, input_size, batch_size=-1, device="cuda"):
+    def register_hook(module):
+
+        def hook(module, input, output):
+            class_name = str(module.__class__).split(".")[-1].split("'")[0]
+            module_idx = len(summary)
+
+            m_key = "%s-%i" % (class_name, module_idx + 1)
+            summary[m_key] = OrderedDict()
+            summary[m_key]["input_shape"] = list(input[0].size())
+            summary[m_key]["input_shape"][0] = batch_size
+            if isinstance(output, (list, tuple)):
+                summary[m_key]["output_shape"] = [
+                    [-1] + list(o.size())[1:] for o in output
+                ]
+            else:
+                summary[m_key]["output_shape"] = list(output.size())
+                summary[m_key]["output_shape"][0] = batch_size
+
+            params = 0
+            if hasattr(module, "weight") and hasattr(module.weight, "size"):
+                params += torch.prod(torch.LongTensor(list(module.weight.size())))
+                summary[m_key]["trainable"] = module.weight.requires_grad
+            if hasattr(module, "bias") and hasattr(module.bias, "size"):
+                params += torch.prod(torch.LongTensor(list(module.bias.size())))
+            summary[m_key]["nb_params"] = params
+
+        if (
+                not isinstance(module, nn.Sequential)
+                and not isinstance(module, nn.ModuleList)
+                and not (module == model)
+        ):
+            hooks.append(module.register_forward_hook(hook))
+
+    device = device.lower()
+    assert device in [
+        "cuda",
+        "cpu",
+    ], "Input device is not valid, please specify 'cuda' or 'cpu'"
+
+    if device == "cuda" and torch.cuda.is_available():
+        dtype = torch.cuda.FloatTensor
+    else:
+        dtype = torch.FloatTensor
+
+    # multiple inputs to the network
+    if isinstance(input_size, tuple):
+        input_size = [input_size]
+
+    # batch_size of 2 for batchnorm
+    x = [torch.rand(2, *in_size).type(dtype) for in_size in input_size]
+    # print(type(x[0]))
+
+    # create properties
+    summary = OrderedDict()
+    hooks = []
+
+    # register hook
+    model.apply(register_hook)
+
+    # make a forward pass
+    # print(x.shape)
+    model(*x)
+
+    # remove these hooks
+    for h in hooks:
+        h.remove()
+
+    print("----------------------------------------------------------------")
+    line_new = "{:>20} {:>25} {:>15}".format("Layer (type)", "Output Shape", "Param #")
+    print(line_new)
+    print("================================================================")
+    total_params = 0
+    total_output = 0
+    trainable_params = 0
+    for layer in summary:
+        # input_shape, output_shape, trainable, nb_params
+        line_new = "{:>20} {:>25} {:>15}".format(
+            layer,
+            str(summary[layer]["output_shape"]),
+            "{0:,}".format(summary[layer]["nb_params"]),
+        )
+        total_params += summary[layer]["nb_params"]
+        total_output += np.prod(summary[layer]["output_shape"])
+        if "trainable" in summary[layer]:
+            if summary[layer]["trainable"] == True:
+                trainable_params += summary[layer]["nb_params"]
+        print(line_new)
+
+    # assume 4 bytes/number (float on cuda).
+    total_input_size = abs(np.prod(input_size) * batch_size * 4. / (1024 ** 2.))
+    total_output_size = abs(2. * total_output * 4. / (1024 ** 2.))  # x2 for gradients
+    total_params_size = abs(total_params.numpy() * 4. / (1024 ** 2.))
+    total_size = total_params_size + total_output_size + total_input_size
+
+    print("================================================================")
+    print("Total params: {0:,}".format(total_params))
+    print("Trainable params: {0:,}".format(trainable_params))
+    print("Non-trainable params: {0:,}".format(total_params - trainable_params))
+    print("----------------------------------------------------------------")
+    print("Input size (MB): %0.2f" % total_input_size)
+    print("Forward/backward pass size (MB): %0.2f" % total_output_size)
+    print("Params size (MB): %0.2f" % total_params_size)
+    print("Estimated Total Size (MB): %0.2f" % total_size)
+    print("----------------------------------------------------------------")
+    # return summary
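
A runnable sketch of calling `summary` on a placeholder model; the input shape `(3, 256, 128)` is just a typical re-id crop size, not a value mandated by the package.

```python
import torch

# Placeholder CNN; summary() forwards a batch of 2 random tensors through it
# and prints a Keras-style per-layer table with shapes and parameter counts.
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 16, 3, padding=1),
    torch.nn.ReLU(),
    torch.nn.AdaptiveAvgPool2d(1),
    torch.nn.Flatten(),
    torch.nn.Linear(16, 10),
)
summary(model, (3, 256, 128), device="cpu")
```
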
dnt/third_party/fast-reid/fastreid/utils/timer.py
@@ -0,0 +1,68 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+# -*- coding: utf-8 -*-
+
+from time import perf_counter
+from typing import Optional
+
+
+class Timer:
+    """
+    A timer which computes the time elapsed since the start/reset of the timer.
+    """
+
+    def __init__(self):
+        self.reset()
+
+    def reset(self):
+        """
+        Reset the timer.
+        """
+        self._start = perf_counter()
+        self._paused: Optional[float] = None
+        self._total_paused = 0
+        self._count_start = 1
+
+    def pause(self):
+        """
+        Pause the timer.
+        """
+        if self._paused is not None:
+            raise ValueError("Trying to pause a Timer that is already paused!")
+        self._paused = perf_counter()
+
+    def is_paused(self) -> bool:
+        """
+        Returns:
+            bool: whether the timer is currently paused
+        """
+        return self._paused is not None
+
+    def resume(self):
+        """
+        Resume the timer.
+        """
+        if self._paused is None:
+            raise ValueError("Trying to resume a Timer that is not paused!")
+        self._total_paused += perf_counter() - self._paused
+        self._paused = None
+        self._count_start += 1
+
+    def seconds(self) -> float:
+        """
+        Returns:
+            (float): the total number of seconds since the start/reset of the
+                timer, excluding the time when the timer is paused.
+        """
+        if self._paused is not None:
+            end_time: float = self._paused  # type: ignore
+        else:
+            end_time = perf_counter()
+        return end_time - self._start - self._total_paused
+
+    def avg_seconds(self) -> float:
+        """
+        Returns:
+            (float): the average number of seconds between every start/reset and
+                pause.
+        """
+        return self.seconds() / self._count_start
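
A small sketch of the pause/resume accounting defined above; the sleep durations are arbitrary.

```python
import time

timer = Timer()
time.sleep(0.2)
timer.pause()    # clock stops; paused time is excluded from seconds()
time.sleep(0.5)
timer.resume()   # also increments the start/resume segment count
time.sleep(0.2)

print("%.1f" % timer.seconds())      # ~0.4: total elapsed minus paused time
print("%.1f" % timer.avg_seconds())  # ~0.2: seconds() / number of segments
```
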