dnt 0.2.4__py3-none-any.whl → 0.3.1.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dnt/__init__.py +3 -2
- dnt/analysis/__init__.py +3 -2
- dnt/analysis/count.py +54 -37
- dnt/analysis/interaction2.py +518 -0
- dnt/analysis/stop.py +22 -17
- dnt/analysis/stop2.py +289 -0
- dnt/analysis/stop3.py +758 -0
- dnt/detect/signal/detector.py +326 -0
- dnt/detect/timestamp.py +105 -0
- dnt/detect/yolov8/detector.py +179 -36
- dnt/detect/yolov8/segmentor.py +60 -2
- dnt/engine/__init__.py +8 -0
- dnt/engine/bbox_interp.py +83 -0
- dnt/engine/bbox_iou.py +20 -0
- dnt/engine/cluster.py +31 -0
- dnt/engine/iob.py +66 -0
- dnt/filter/filter.py +333 -2
- dnt/label/labeler.py +4 -4
- dnt/label/labeler2.py +631 -0
- dnt/shared/__init__.py +2 -1
- dnt/shared/data/coco.names +0 -0
- dnt/shared/data/openimages.names +0 -0
- dnt/shared/data/voc.names +0 -0
- dnt/shared/download.py +12 -0
- dnt/shared/synhcro.py +150 -0
- dnt/shared/util.py +17 -4
- dnt/third_party/fast-reid/__init__.py +1 -0
- dnt/third_party/fast-reid/configs/Base-AGW.yml +19 -0
- dnt/third_party/fast-reid/configs/Base-MGN.yml +12 -0
- dnt/third_party/fast-reid/configs/Base-SBS.yml +63 -0
- dnt/third_party/fast-reid/configs/Base-bagtricks.yml +76 -0
- dnt/third_party/fast-reid/configs/DukeMTMC/AGW_R101-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/DukeMTMC/AGW_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/DukeMTMC/AGW_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/DukeMTMC/AGW_S50.yml +11 -0
- dnt/third_party/fast-reid/configs/DukeMTMC/bagtricks_R101-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/DukeMTMC/bagtricks_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/DukeMTMC/bagtricks_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/DukeMTMC/bagtricks_S50.yml +11 -0
- dnt/third_party/fast-reid/configs/DukeMTMC/mgn_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/DukeMTMC/sbs_R101-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/DukeMTMC/sbs_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/DukeMTMC/sbs_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/DukeMTMC/sbs_S50.yml +11 -0
- dnt/third_party/fast-reid/configs/MOT17/AGW_R101-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/MOT17/AGW_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/MOT17/AGW_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/MOT17/AGW_S50.yml +11 -0
- dnt/third_party/fast-reid/configs/MOT17/bagtricks_R101-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/MOT17/bagtricks_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/MOT17/bagtricks_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/MOT17/bagtricks_S50.yml +11 -0
- dnt/third_party/fast-reid/configs/MOT17/mgn_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/MOT17/sbs_R101-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/MOT17/sbs_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/MOT17/sbs_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/MOT17/sbs_S50.yml +11 -0
- dnt/third_party/fast-reid/configs/MOT20/AGW_R101-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/MOT20/AGW_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/MOT20/AGW_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/MOT20/AGW_S50.yml +11 -0
- dnt/third_party/fast-reid/configs/MOT20/bagtricks_R101-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/MOT20/bagtricks_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/MOT20/bagtricks_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/MOT20/bagtricks_S50.yml +11 -0
- dnt/third_party/fast-reid/configs/MOT20/mgn_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/MOT20/sbs_R101-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/MOT20/sbs_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/MOT20/sbs_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/MOT20/sbs_S50.yml +11 -0
- dnt/third_party/fast-reid/configs/MSMT17/AGW_R101-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/MSMT17/AGW_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/MSMT17/AGW_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/MSMT17/AGW_S50.yml +11 -0
- dnt/third_party/fast-reid/configs/MSMT17/bagtricks_R101-ibn.yml +13 -0
- dnt/third_party/fast-reid/configs/MSMT17/bagtricks_R50-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/MSMT17/bagtricks_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/MSMT17/bagtricks_S50.yml +12 -0
- dnt/third_party/fast-reid/configs/MSMT17/mgn_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/MSMT17/sbs_R101-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/MSMT17/sbs_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/MSMT17/sbs_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/MSMT17/sbs_S50.yml +11 -0
- dnt/third_party/fast-reid/configs/Market1501/AGW_R101-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/Market1501/AGW_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/Market1501/AGW_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/Market1501/AGW_S50.yml +11 -0
- dnt/third_party/fast-reid/configs/Market1501/bagtricks_R101-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/Market1501/bagtricks_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/Market1501/bagtricks_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/Market1501/bagtricks_S50.yml +11 -0
- dnt/third_party/fast-reid/configs/Market1501/bagtricks_vit.yml +88 -0
- dnt/third_party/fast-reid/configs/Market1501/mgn_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/Market1501/sbs_R101-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/Market1501/sbs_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/Market1501/sbs_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/Market1501/sbs_S50.yml +11 -0
- dnt/third_party/fast-reid/configs/VERIWild/bagtricks_R50-ibn.yml +35 -0
- dnt/third_party/fast-reid/configs/VeRi/sbs_R50-ibn.yml +35 -0
- dnt/third_party/fast-reid/configs/VehicleID/bagtricks_R50-ibn.yml +36 -0
- dnt/third_party/fast-reid/configs/__init__.py +0 -0
- dnt/third_party/fast-reid/fast_reid_interfece.py +175 -0
- dnt/third_party/fast-reid/fastreid/__init__.py +6 -0
- dnt/third_party/fast-reid/fastreid/config/__init__.py +15 -0
- dnt/third_party/fast-reid/fastreid/config/config.py +319 -0
- dnt/third_party/fast-reid/fastreid/config/defaults.py +329 -0
- dnt/third_party/fast-reid/fastreid/data/__init__.py +17 -0
- dnt/third_party/fast-reid/fastreid/data/build.py +194 -0
- dnt/third_party/fast-reid/fastreid/data/common.py +58 -0
- dnt/third_party/fast-reid/fastreid/data/data_utils.py +202 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/AirportALERT.py +50 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/__init__.py +43 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/bases.py +183 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/caviara.py +44 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/cuhk03.py +274 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/cuhk_sysu.py +58 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/dukemtmcreid.py +70 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/grid.py +44 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/iLIDS.py +45 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/lpw.py +49 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/market1501.py +89 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/msmt17.py +114 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/pes3d.py +44 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/pku.py +44 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/prai.py +43 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/prid.py +41 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/saivt.py +47 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/sensereid.py +47 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/shinpuhkan.py +48 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/sysu_mm.py +47 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/thermalworld.py +43 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/vehicleid.py +126 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/veri.py +69 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/veriwild.py +140 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/viper.py +45 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/wildtracker.py +59 -0
- dnt/third_party/fast-reid/fastreid/data/samplers/__init__.py +18 -0
- dnt/third_party/fast-reid/fastreid/data/samplers/data_sampler.py +85 -0
- dnt/third_party/fast-reid/fastreid/data/samplers/imbalance_sampler.py +67 -0
- dnt/third_party/fast-reid/fastreid/data/samplers/triplet_sampler.py +260 -0
- dnt/third_party/fast-reid/fastreid/data/transforms/__init__.py +11 -0
- dnt/third_party/fast-reid/fastreid/data/transforms/autoaugment.py +806 -0
- dnt/third_party/fast-reid/fastreid/data/transforms/build.py +100 -0
- dnt/third_party/fast-reid/fastreid/data/transforms/functional.py +180 -0
- dnt/third_party/fast-reid/fastreid/data/transforms/transforms.py +161 -0
- dnt/third_party/fast-reid/fastreid/engine/__init__.py +15 -0
- dnt/third_party/fast-reid/fastreid/engine/defaults.py +490 -0
- dnt/third_party/fast-reid/fastreid/engine/hooks.py +534 -0
- dnt/third_party/fast-reid/fastreid/engine/launch.py +103 -0
- dnt/third_party/fast-reid/fastreid/engine/train_loop.py +357 -0
- dnt/third_party/fast-reid/fastreid/evaluation/__init__.py +6 -0
- dnt/third_party/fast-reid/fastreid/evaluation/clas_evaluator.py +81 -0
- dnt/third_party/fast-reid/fastreid/evaluation/evaluator.py +176 -0
- dnt/third_party/fast-reid/fastreid/evaluation/query_expansion.py +46 -0
- dnt/third_party/fast-reid/fastreid/evaluation/rank.py +200 -0
- dnt/third_party/fast-reid/fastreid/evaluation/rank_cylib/__init__.py +20 -0
- dnt/third_party/fast-reid/fastreid/evaluation/rank_cylib/setup.py +32 -0
- dnt/third_party/fast-reid/fastreid/evaluation/rank_cylib/test_cython.py +106 -0
- dnt/third_party/fast-reid/fastreid/evaluation/reid_evaluation.py +143 -0
- dnt/third_party/fast-reid/fastreid/evaluation/rerank.py +73 -0
- dnt/third_party/fast-reid/fastreid/evaluation/roc.py +90 -0
- dnt/third_party/fast-reid/fastreid/evaluation/testing.py +88 -0
- dnt/third_party/fast-reid/fastreid/layers/__init__.py +19 -0
- dnt/third_party/fast-reid/fastreid/layers/activation.py +59 -0
- dnt/third_party/fast-reid/fastreid/layers/any_softmax.py +80 -0
- dnt/third_party/fast-reid/fastreid/layers/batch_norm.py +205 -0
- dnt/third_party/fast-reid/fastreid/layers/context_block.py +113 -0
- dnt/third_party/fast-reid/fastreid/layers/drop.py +161 -0
- dnt/third_party/fast-reid/fastreid/layers/frn.py +199 -0
- dnt/third_party/fast-reid/fastreid/layers/gather_layer.py +30 -0
- dnt/third_party/fast-reid/fastreid/layers/helpers.py +31 -0
- dnt/third_party/fast-reid/fastreid/layers/non_local.py +54 -0
- dnt/third_party/fast-reid/fastreid/layers/pooling.py +124 -0
- dnt/third_party/fast-reid/fastreid/layers/se_layer.py +25 -0
- dnt/third_party/fast-reid/fastreid/layers/splat.py +109 -0
- dnt/third_party/fast-reid/fastreid/layers/weight_init.py +122 -0
- dnt/third_party/fast-reid/fastreid/modeling/__init__.py +23 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/__init__.py +18 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/build.py +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/mobilenet.py +195 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/mobilenetv3.py +283 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/osnet.py +525 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/__init__.py +4 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/config.py +396 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/effnet/EN-B0_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/effnet/EN-B1_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/effnet/EN-B2_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/effnet/EN-B3_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/effnet/EN-B4_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/effnet/EN-B5_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/effnet.py +281 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnet.py +596 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-1.6GF_dds_8gpu.yaml +26 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-12GF_dds_8gpu.yaml +26 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-16GF_dds_8gpu.yaml +26 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-200MF_dds_8gpu.yaml +26 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-3.2GF_dds_8gpu.yaml +26 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-32GF_dds_8gpu.yaml +26 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-4.0GF_dds_8gpu.yaml +26 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-400MF_dds_8gpu.yaml +26 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-6.4GF_dds_8gpu.yaml +26 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-600MF_dds_8gpu.yaml +26 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-8.0GF_dds_8gpu.yaml +26 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-800MF_dds_8gpu.yaml +26 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-1.6GF_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-12GF_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-16GF_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-200MF_dds_8gpu.yaml +26 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-3.2GF_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-32GF_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-4.0GF_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-400MF_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-6.4GF_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-600MF_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-8.0GF_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-800MF_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/repvgg.py +309 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/resnest.py +365 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/resnet.py +364 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/resnext.py +335 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/shufflenet.py +203 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/vision_transformer.py +399 -0
- dnt/third_party/fast-reid/fastreid/modeling/heads/__init__.py +11 -0
- dnt/third_party/fast-reid/fastreid/modeling/heads/build.py +25 -0
- dnt/third_party/fast-reid/fastreid/modeling/heads/clas_head.py +36 -0
- dnt/third_party/fast-reid/fastreid/modeling/heads/embedding_head.py +151 -0
- dnt/third_party/fast-reid/fastreid/modeling/losses/__init__.py +12 -0
- dnt/third_party/fast-reid/fastreid/modeling/losses/circle_loss.py +71 -0
- dnt/third_party/fast-reid/fastreid/modeling/losses/cross_entroy_loss.py +54 -0
- dnt/third_party/fast-reid/fastreid/modeling/losses/focal_loss.py +92 -0
- dnt/third_party/fast-reid/fastreid/modeling/losses/triplet_loss.py +113 -0
- dnt/third_party/fast-reid/fastreid/modeling/losses/utils.py +48 -0
- dnt/third_party/fast-reid/fastreid/modeling/meta_arch/__init__.py +14 -0
- dnt/third_party/fast-reid/fastreid/modeling/meta_arch/baseline.py +188 -0
- dnt/third_party/fast-reid/fastreid/modeling/meta_arch/build.py +26 -0
- dnt/third_party/fast-reid/fastreid/modeling/meta_arch/distiller.py +140 -0
- dnt/third_party/fast-reid/fastreid/modeling/meta_arch/mgn.py +394 -0
- dnt/third_party/fast-reid/fastreid/modeling/meta_arch/moco.py +126 -0
- dnt/third_party/fast-reid/fastreid/solver/__init__.py +8 -0
- dnt/third_party/fast-reid/fastreid/solver/build.py +348 -0
- dnt/third_party/fast-reid/fastreid/solver/lr_scheduler.py +66 -0
- dnt/third_party/fast-reid/fastreid/solver/optim/__init__.py +10 -0
- dnt/third_party/fast-reid/fastreid/solver/optim/lamb.py +123 -0
- dnt/third_party/fast-reid/fastreid/solver/optim/radam.py +149 -0
- dnt/third_party/fast-reid/fastreid/solver/optim/swa.py +246 -0
- dnt/third_party/fast-reid/fastreid/utils/__init__.py +6 -0
- dnt/third_party/fast-reid/fastreid/utils/checkpoint.py +503 -0
- dnt/third_party/fast-reid/fastreid/utils/collect_env.py +158 -0
- dnt/third_party/fast-reid/fastreid/utils/comm.py +255 -0
- dnt/third_party/fast-reid/fastreid/utils/compute_dist.py +200 -0
- dnt/third_party/fast-reid/fastreid/utils/env.py +119 -0
- dnt/third_party/fast-reid/fastreid/utils/events.py +461 -0
- dnt/third_party/fast-reid/fastreid/utils/faiss_utils.py +127 -0
- dnt/third_party/fast-reid/fastreid/utils/file_io.py +520 -0
- dnt/third_party/fast-reid/fastreid/utils/history_buffer.py +71 -0
- dnt/third_party/fast-reid/fastreid/utils/logger.py +211 -0
- dnt/third_party/fast-reid/fastreid/utils/params.py +103 -0
- dnt/third_party/fast-reid/fastreid/utils/precision_bn.py +94 -0
- dnt/third_party/fast-reid/fastreid/utils/registry.py +66 -0
- dnt/third_party/fast-reid/fastreid/utils/summary.py +120 -0
- dnt/third_party/fast-reid/fastreid/utils/timer.py +68 -0
- dnt/third_party/fast-reid/fastreid/utils/visualizer.py +278 -0
- dnt/track/__init__.py +2 -0
- dnt/track/botsort/__init__.py +4 -0
- dnt/track/botsort/bot_tracker/__init__.py +3 -0
- dnt/track/botsort/bot_tracker/basetrack.py +60 -0
- dnt/track/botsort/bot_tracker/bot_sort.py +473 -0
- dnt/track/botsort/bot_tracker/gmc.py +316 -0
- dnt/track/botsort/bot_tracker/kalman_filter.py +269 -0
- dnt/track/botsort/bot_tracker/matching.py +194 -0
- dnt/track/botsort/bot_tracker/mc_bot_sort.py +505 -0
- dnt/track/{dsort/utils → botsort/bot_tracker/tracking_utils}/evaluation.py +14 -4
- dnt/track/{dsort/utils → botsort/bot_tracker/tracking_utils}/io.py +19 -36
- dnt/track/botsort/bot_tracker/tracking_utils/timer.py +37 -0
- dnt/track/botsort/inference.py +96 -0
- dnt/track/config.py +120 -0
- dnt/track/dsort/configs/bagtricks_R50.yml +7 -0
- dnt/track/dsort/configs/deep_sort.yaml +0 -0
- dnt/track/dsort/configs/fastreid.yaml +1 -1
- dnt/track/dsort/deep_sort/deep/checkpoint/ckpt.t7 +0 -0
- dnt/track/dsort/deep_sort/deep/feature_extractor.py +87 -8
- dnt/track/dsort/deep_sort/deep_sort.py +31 -20
- dnt/track/dsort/deep_sort/sort/detection.py +2 -1
- dnt/track/dsort/deep_sort/sort/iou_matching.py +0 -2
- dnt/track/dsort/deep_sort/sort/linear_assignment.py +0 -3
- dnt/track/dsort/deep_sort/sort/nn_matching.py +5 -5
- dnt/track/dsort/deep_sort/sort/preprocessing.py +1 -2
- dnt/track/dsort/deep_sort/sort/track.py +2 -1
- dnt/track/dsort/deep_sort/sort/tracker.py +1 -1
- dnt/track/dsort/dsort.py +43 -33
- dnt/track/re_class.py +117 -0
- dnt/track/sort/sort.py +9 -6
- dnt/track/tracker.py +213 -32
- dnt-0.3.1.8.dist-info/METADATA +117 -0
- dnt-0.3.1.8.dist-info/RECORD +315 -0
- {dnt-0.2.4.dist-info → dnt-0.3.1.8.dist-info}/WHEEL +1 -1
- dnt/analysis/yield.py +0 -9
- dnt/track/dsort/deep_sort/deep/evaluate.py +0 -15
- dnt/track/dsort/deep_sort/deep/original_model.py +0 -106
- dnt/track/dsort/deep_sort/deep/test.py +0 -77
- dnt/track/dsort/deep_sort/deep/train.py +0 -189
- dnt/track/dsort/utils/asserts.py +0 -13
- dnt/track/dsort/utils/draw.py +0 -36
- dnt/track/dsort/utils/json_logger.py +0 -383
- dnt/track/dsort/utils/log.py +0 -17
- dnt/track/dsort/utils/parser.py +0 -35
- dnt/track/dsort/utils/tools.py +0 -39
- dnt-0.2.4.dist-info/METADATA +0 -35
- dnt-0.2.4.dist-info/RECORD +0 -64
- /dnt/{track/dsort/utils → third_party/fast-reid/checkpoint}/__init__.py +0 -0
- {dnt-0.2.4.dist-info → dnt-0.3.1.8.dist-info/licenses}/LICENSE +0 -0
- {dnt-0.2.4.dist-info → dnt-0.3.1.8.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,348 @@
|
|
|
1
|
+
# encoding: utf-8
|
|
2
|
+
"""
|
|
3
|
+
@author: liaoxingyu
|
|
4
|
+
@contact: sherlockliao01@gmail.com
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
# Based on: https://github.com/facebookresearch/detectron2/blob/master/detectron2/solver/build.py
|
|
8
|
+
|
|
9
|
+
import copy
|
|
10
|
+
import itertools
|
|
11
|
+
import math
|
|
12
|
+
import re
|
|
13
|
+
from enum import Enum
|
|
14
|
+
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Type, Union
|
|
15
|
+
|
|
16
|
+
import torch
|
|
17
|
+
|
|
18
|
+
from fastreid.config import CfgNode
|
|
19
|
+
from fastreid.utils.params import ContiguousParams
|
|
20
|
+
from . import lr_scheduler
|
|
21
|
+
|
|
22
|
+
# Input accepted by the gradient clippers: a single tensor or an iterable of
# tensors (the same input torch.nn.utils.clip_grad_norm_/clip_grad_value_ take).
_GradientClipperInput = Union[torch.Tensor, Iterable[torch.Tensor]]
# A clipper modifies the gradients of its input in place and returns nothing.
_GradientClipper = Callable[[_GradientClipperInput], None]
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class GradientClipType(Enum):
    """Gradient-clipping strategies, keyed by the string stored in CLIP_TYPE."""

    # Selects torch.nn.utils.clip_grad_value_ (element-wise clipping).
    VALUE = "value"
    # Selects torch.nn.utils.clip_grad_norm_ (norm-based rescaling).
    NORM = "norm"
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def _create_gradient_clipper(cfg: CfgNode) -> _GradientClipper:
    """
    Build a closure that clips gradients either by value or by norm.

    Args:
        cfg: gradient-clipping config node; ``CLIP_TYPE`` picks the strategy,
            ``CLIP_VALUE`` is the threshold and ``NORM_TYPE`` the norm order
            (norm clipping only).
    Returns:
        A callable taking a tensor or iterable of tensors and clipping their
        gradients in place.
    Raises:
        ValueError: if ``CLIP_TYPE`` is not a valid ``GradientClipType`` value.
    """
    # Snapshot the config so later edits to the caller's node don't leak in.
    clip_cfg = copy.deepcopy(cfg)

    def clip_grad_norm(p: _GradientClipperInput):
        torch.nn.utils.clip_grad_norm_(p, clip_cfg.CLIP_VALUE, clip_cfg.NORM_TYPE)

    def clip_grad_value(p: _GradientClipperInput):
        torch.nn.utils.clip_grad_value_(p, clip_cfg.CLIP_VALUE)

    clip_type = GradientClipType(clip_cfg.CLIP_TYPE)
    if clip_type is GradientClipType.NORM:
        return clip_grad_norm
    return clip_grad_value
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def _generate_optimizer_class_with_gradient_clipping(
    optimizer: Type[torch.optim.Optimizer],
    *,
    per_param_clipper: Optional[_GradientClipper] = None,
    global_clipper: Optional[_GradientClipper] = None,
) -> Type[torch.optim.Optimizer]:
    """
    Dynamically create a subclass of ``optimizer`` whose ``step`` clips
    gradients before delegating to the base optimizer's ``step``.

    Args:
        optimizer: optimizer class to extend.
        per_param_clipper: if given, called once per parameter of every group.
        global_clipper: if given, called once with an iterator over all
            parameters of all groups. Mutually exclusive with
            ``per_param_clipper``. If neither is given, ``step`` performs no
            clipping.
    Returns:
        The new class, named ``<OptimizerName>WithGradientClip``.
    """
    assert (
        per_param_clipper is None or global_clipper is None
    ), "Not allowed to use both per-parameter clipping and global clipping"

    @torch.no_grad()
    def optimizer_wgc_step(self, closure=None):
        if per_param_clipper is not None:
            for group in self.param_groups:
                for p in group["params"]:
                    per_param_clipper(p)
        elif global_clipper is not None:
            # global clipper for future use with detr
            # (https://github.com/facebookresearch/detr/pull/287)
            all_params = itertools.chain(*[g["params"] for g in self.param_groups])
            global_clipper(all_params)
        # Propagate the closure loss returned by the base optimizer.
        return optimizer.step(self, closure)

    OptimizerWithGradientClip = type(
        optimizer.__name__ + "WithGradientClip",
        (optimizer,),
        {"step": optimizer_wgc_step},
    )
    return OptimizerWithGradientClip
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def maybe_add_gradient_clipping(
    cfg: CfgNode, optimizer: Type[torch.optim.Optimizer]
) -> Type[torch.optim.Optimizer]:
    """
    Add gradient clipping to an optimizer when ``cfg.SOLVER.CLIP_GRADIENTS.ENABLED``.

    The optimizer type is wrapped in a dynamically created subclass whose
    ``step`` clips each parameter's gradient before stepping.

    Args:
        cfg: full config; only ``cfg.SOLVER.CLIP_GRADIENTS`` is read.
        optimizer: either a subclass of ``torch.optim.Optimizer`` or an
            optimizer instance.
    Return:
        ``optimizer`` unchanged if clipping is disabled. Otherwise, for a class
        input, the new subclass; for an instance input, the same instance with
        its ``__class__`` swapped to the new subclass.
    """
    if not cfg.SOLVER.CLIP_GRADIENTS.ENABLED:
        return optimizer

    given_instance = isinstance(optimizer, torch.optim.Optimizer)
    if given_instance:
        base_type = type(optimizer)
    else:
        assert issubclass(optimizer, torch.optim.Optimizer), optimizer
        base_type = optimizer

    clipper = _create_gradient_clipper(cfg.SOLVER.CLIP_GRADIENTS)
    wrapped_type = _generate_optimizer_class_with_gradient_clipping(
        base_type, per_param_clipper=clipper
    )
    if not given_instance:
        return wrapped_type
    optimizer.__class__ = wrapped_type  # a bit hacky, not recommended
    return optimizer
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def _generate_optimizer_class_with_freeze_layer(
|
|
121
|
+
optimizer: Type[torch.optim.Optimizer],
|
|
122
|
+
*,
|
|
123
|
+
freeze_iters: int = 0,
|
|
124
|
+
) -> Type[torch.optim.Optimizer]:
|
|
125
|
+
assert freeze_iters > 0, "No layers need to be frozen or freeze iterations is 0"
|
|
126
|
+
|
|
127
|
+
cnt = 0
|
|
128
|
+
@torch.no_grad()
|
|
129
|
+
def optimizer_wfl_step(self, closure=None):
|
|
130
|
+
nonlocal cnt
|
|
131
|
+
if cnt < freeze_iters:
|
|
132
|
+
cnt += 1
|
|
133
|
+
param_ref = []
|
|
134
|
+
grad_ref = []
|
|
135
|
+
for group in self.param_groups:
|
|
136
|
+
if group["freeze_status"] == "freeze":
|
|
137
|
+
for p in group["params"]:
|
|
138
|
+
if p.grad is not None:
|
|
139
|
+
param_ref.append(p)
|
|
140
|
+
grad_ref.append(p.grad)
|
|
141
|
+
p.grad = None
|
|
142
|
+
|
|
143
|
+
optimizer.step(self, closure)
|
|
144
|
+
for p, g in zip(param_ref, grad_ref):
|
|
145
|
+
p.grad = g
|
|
146
|
+
else:
|
|
147
|
+
optimizer.step(self, closure)
|
|
148
|
+
|
|
149
|
+
OptimizerWithFreezeLayer = type(
|
|
150
|
+
optimizer.__name__ + "WithFreezeLayer",
|
|
151
|
+
(optimizer,),
|
|
152
|
+
{"step": optimizer_wfl_step},
|
|
153
|
+
)
|
|
154
|
+
return OptimizerWithFreezeLayer
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
def maybe_add_freeze_layer(
    cfg: CfgNode, optimizer: Type[torch.optim.Optimizer]
) -> Type[torch.optim.Optimizer]:
    """
    Add layer freezing to an optimizer when the config asks for it.

    Freezing is enabled only if ``cfg.MODEL.FREEZE_LAYERS`` is non-empty and
    ``cfg.SOLVER.FREEZE_ITERS`` is positive.

    Args:
        cfg: full config.
        optimizer: either a subclass of ``torch.optim.Optimizer`` or an
            optimizer instance.
    Return:
        ``optimizer`` unchanged if freezing is disabled. Otherwise, for a class
        input, the new subclass; for an instance input, the same instance with
        its ``__class__`` swapped to the new subclass.
    """
    freezing_disabled = len(cfg.MODEL.FREEZE_LAYERS) == 0 or cfg.SOLVER.FREEZE_ITERS <= 0
    if freezing_disabled:
        return optimizer

    given_instance = isinstance(optimizer, torch.optim.Optimizer)
    if given_instance:
        base_type = type(optimizer)
    else:
        assert issubclass(optimizer, torch.optim.Optimizer), optimizer
        base_type = optimizer

    wrapped_type = _generate_optimizer_class_with_freeze_layer(
        base_type,
        freeze_iters=cfg.SOLVER.FREEZE_ITERS,
    )
    if not given_instance:
        return wrapped_type
    optimizer.__class__ = wrapped_type  # a bit hacky, not recommended
    return optimizer
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
def build_optimizer(cfg, model, contiguous=True):
    """
    Build the optimizer configured in ``cfg.SOLVER`` for ``model``.

    Parameter groups come from ``get_default_optimizer_params``. The optimizer
    class named by ``cfg.SOLVER.OPT`` is looked up in ``torch.optim``. It is
    wrapped with gradient clipping and layer freezing when the config enables
    them.

    Args:
        cfg: full config.
        model: module whose parameters are optimized.
        contiguous: if True, wrap the groups in ``ContiguousParams`` and hand
            the optimizer ``params.contiguous()``.
    Returns:
        Tuple ``(optimizer, params)``. ``params`` is the ``ContiguousParams``
        wrapper when ``contiguous`` is True, else the raw parameter groups.
    """
    params = get_default_optimizer_params(
        model,
        base_lr=cfg.SOLVER.BASE_LR,
        weight_decay=cfg.SOLVER.WEIGHT_DECAY,
        weight_decay_norm=cfg.SOLVER.WEIGHT_DECAY_NORM,
        bias_lr_factor=cfg.SOLVER.BIAS_LR_FACTOR,
        heads_lr_factor=cfg.SOLVER.HEADS_LR_FACTOR,
        weight_decay_bias=cfg.SOLVER.WEIGHT_DECAY_BIAS,
        freeze_layers=cfg.MODEL.FREEZE_LAYERS if cfg.SOLVER.FREEZE_ITERS > 0 else [],
    )
    if contiguous:
        params = ContiguousParams(params)

    solver_opt = cfg.SOLVER.OPT
    if solver_opt == "SGD":
        # Only SGD receives extra constructor arguments from the config.
        base_cls = torch.optim.SGD
        extra_kwargs = {
            "momentum": cfg.SOLVER.MOMENTUM,
            "nesterov": cfg.SOLVER.NESTEROV,
        }
    else:
        base_cls = getattr(torch.optim, solver_opt)
        extra_kwargs = {}

    optim_cls = maybe_add_freeze_layer(cfg, maybe_add_gradient_clipping(cfg, base_cls))
    optim_params = params.contiguous() if contiguous else params
    return optim_cls(optim_params, **extra_kwargs), params
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
def get_default_optimizer_params(
    model: torch.nn.Module,
    base_lr: Optional[float] = None,
    weight_decay: Optional[float] = None,
    weight_decay_norm: Optional[float] = None,
    bias_lr_factor: Optional[float] = 1.0,
    heads_lr_factor: Optional[float] = 1.0,
    weight_decay_bias: Optional[float] = None,
    overrides: Optional[Dict[str, Dict[str, float]]] = None,
    freeze_layers: Optional[List[str]] = None,
) -> List[Dict[str, Any]]:
    """
    Build optimizer param groups, one group per trainable parameter.

    Each returned dict holds ``"params": [param]``, a ``"freeze_status"`` of
    ``"freeze"`` or ``"normal"``, and any hyperparameters (``lr``,
    ``weight_decay``) that differ from the optimizer's own defaults.
    Parameters with ``requires_grad=False`` are skipped, and shared
    parameters are emitted only once.

    Args:
        model: module whose parameters are grouped.
        base_lr: lr for every group by default. Can be omitted to use the one in optimizer.
        weight_decay: weight decay for every group by default. Can be omitted to use the one
            in optimizer.
        weight_decay_norm: override weight decay for params in normalization layers.
        bias_lr_factor: multiplier of lr for parameters named ``bias`` (requires ``base_lr``).
        heads_lr_factor: multiplier of lr for parameters under the top-level ``heads`` module.
        weight_decay_bias: override weight decay for parameters named ``bias``.
        overrides: if not `None`, provides values for optimizer hyperparameters
            (LR, weight decay) for module parameters with a given name; e.g.
            ``{"embedding": {"lr": 0.01, "weight_decay": 0.1}}`` will set the LR and
            weight decay values for all module parameters named `embedding`.
            The caller's dict is not modified.
        freeze_layers: regular expressions matched with ``re.match`` (i.e. from the
            start) against full parameter names such as ``"backbone.conv1.weight"``;
            matching parameters get ``freeze_status="freeze"``. ``None`` means none.

    For common detection models, ``weight_decay_norm`` is the only option
    needed to be set. ``bias_lr_factor,weight_decay_bias`` are legacy settings
    from Detectron1 that are not found useful.

    Returns:
        List of param-group dicts suitable for a ``torch.optim`` optimizer (after the
        caller handles the extra ``freeze_status`` key).

    Raises:
        ValueError: if ``bias_lr_factor`` (or ``heads_lr_factor`` on a head parameter
            without an lr override) is used without ``base_lr``, or if ``overrides``
            already contains ``"bias"`` while bias overrides are requested.

    Example:
    ::
        torch.optim.SGD(get_default_optimizer_params(model, weight_decay_norm=0),
                       lr=0.01, weight_decay=1e-4, momentum=0.9)
    """
    # Copy so adding the "bias" entry below never mutates the caller's dict
    # (which would make a second call with the same dict raise a conflict).
    overrides = dict(overrides) if overrides is not None else {}
    defaults = {}
    if base_lr is not None:
        defaults["lr"] = base_lr
    if weight_decay is not None:
        defaults["weight_decay"] = weight_decay
    bias_overrides = {}
    if bias_lr_factor is not None and bias_lr_factor != 1.0:
        # NOTE: unlike Detectron v1, we now by default make bias hyperparameters
        # exactly the same as regular weights.
        if base_lr is None:
            raise ValueError("bias_lr_factor requires base_lr")
        bias_overrides["lr"] = base_lr * bias_lr_factor
    if weight_decay_bias is not None:
        bias_overrides["weight_decay"] = weight_decay_bias
    if len(bias_overrides):
        if "bias" in overrides:
            raise ValueError("Conflicting overrides for 'bias'")
        overrides["bias"] = bias_overrides

    layer_names_pattern = [re.compile(name) for name in (freeze_layers or [])]
    apply_heads_factor = heads_lr_factor is not None and heads_lr_factor != 1.0

    norm_module_types = (
        torch.nn.BatchNorm1d,
        torch.nn.BatchNorm2d,
        torch.nn.BatchNorm3d,
        torch.nn.SyncBatchNorm,
        # NaiveSyncBatchNorm inherits from BatchNorm2d
        torch.nn.GroupNorm,
        torch.nn.InstanceNorm1d,
        torch.nn.InstanceNorm2d,
        torch.nn.InstanceNorm3d,
        torch.nn.LayerNorm,
        torch.nn.LocalResponseNorm,
    )
    params: List[Dict[str, Any]] = []
    memo: Set[torch.nn.parameter.Parameter] = set()

    for module_name, module in model.named_modules():
        for module_param_name, value in module.named_parameters(recurse=False):
            if not value.requires_grad:
                continue
            # Avoid duplicating parameters
            if value in memo:
                continue
            memo.add(value)

            hyperparams = copy.copy(defaults)
            if isinstance(module, norm_module_types) and weight_decay_norm is not None:
                hyperparams["weight_decay"] = weight_decay_norm
            hyperparams.update(overrides.get(module_param_name, {}))
            if apply_heads_factor and module_name.split('.')[0] == "heads":
                head_lr = hyperparams.get("lr", base_lr)
                if head_lr is None:
                    raise ValueError("heads_lr_factor requires base_lr")
                hyperparams["lr"] = head_lr * heads_lr_factor

            # Same naming as model.named_parameters(): the root module has an empty
            # name, so its own parameters must not get a leading '.'.
            name = module_name + '.' + module_param_name if module_name else module_param_name
            freeze_status = "normal"
            # Search freeze layer names, it must match from beginning, so use `match` not `search`
            for pattern in layer_names_pattern:
                if pattern.match(name) is not None:
                    freeze_status = "freeze"
                    break

            params.append({"freeze_status": freeze_status, "params": [value], **hyperparams})
    return params
|
|
310
|
+
|
|
311
|
+
|
|
312
|
+
def build_lr_scheduler(cfg, optimizer, iters_per_epoch):
    """Create the learning-rate scheduler(s) described by ``cfg.SOLVER``.

    The main scheduler class is looked up by name (``cfg.SOLVER.SCHED``) on the
    ``lr_scheduler`` module referenced by this file. Only ``"MultiStepLR"`` and
    ``"CosineAnnealingLR"`` have argument sets here; any other name raises
    ``KeyError``. A cosine schedule runs for ``MAX_EPOCH`` minus the longer of
    the warmup length (rounded up to whole epochs) and ``DELAY_EPOCHS``.

    Args:
        cfg: config object exposing the ``SOLVER.*`` fields read below.
        optimizer: optimizer the schedulers will drive.
        iters_per_epoch: iterations per epoch, used to convert ``WARMUP_ITERS``
            into epochs.

    Returns:
        dict: ``{"lr_sched": scheduler}``, plus ``"warmup_sched"`` holding a
        ``lr_scheduler.WarmupLR`` when ``cfg.SOLVER.WARMUP_ITERS > 0``.
    """
    solver = cfg.SOLVER

    total_epochs = solver.MAX_EPOCH
    warmup_epochs = math.ceil(solver.WARMUP_ITERS / iters_per_epoch)
    max_epoch = total_epochs - max(warmup_epochs, solver.DELAY_EPOCHS)

    # Both argument sets are built up-front, so every field they reference must exist.
    multistep_kwargs = dict(
        optimizer=optimizer,
        milestones=solver.STEPS,
        gamma=solver.GAMMA,
    )
    cosine_kwargs = dict(
        optimizer=optimizer,
        T_max=max_epoch,
        eta_min=solver.ETA_MIN_LR,
    )
    kwargs_by_name = {
        "MultiStepLR": multistep_kwargs,
        "CosineAnnealingLR": cosine_kwargs,
    }

    scheduler_cls = getattr(lr_scheduler, solver.SCHED)
    schedulers = {"lr_sched": scheduler_cls(**kwargs_by_name[solver.SCHED])}

    if solver.WARMUP_ITERS > 0:
        schedulers["warmup_sched"] = lr_scheduler.WarmupLR(
            optimizer=optimizer,
            warmup_factor=solver.WARMUP_FACTOR,
            warmup_iters=solver.WARMUP_ITERS,
            warmup_method=solver.WARMUP_METHOD,
        )

    return schedulers
|
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
# encoding: utf-8
|
|
2
|
+
"""
|
|
3
|
+
@author: liaoxingyu
|
|
4
|
+
@contact: sherlockliao01@gmail.com
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from typing import List
|
|
8
|
+
|
|
9
|
+
import torch
|
|
10
|
+
from torch.optim.lr_scheduler import *
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class WarmupLR(torch.optim.lr_scheduler._LRScheduler):
    """Learning-rate scheduler that scales every base lr by a warmup factor.

    At each step the factor is ``_get_warmup_factor_at_epoch(warmup_method,
    last_epoch, warmup_iters, warmup_factor)``, which is 1.0 once
    ``last_epoch >= warmup_iters``.

    Args:
        optimizer: wrapped optimizer.
        warmup_factor: starting factor; its exact use depends on ``warmup_method``.
        warmup_iters: number of scheduler steps the warmup lasts.
        warmup_method: "constant", "linear" or "exp".
        last_epoch: index of the last step, forwarded to the base scheduler.
    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        warmup_factor: float = 0.1,
        warmup_iters: int = 1000,
        warmup_method: str = "linear",
        last_epoch: int = -1,
    ):
        # Assigned before the base constructor because torch's scheduler
        # __init__ performs an initial step(), which calls get_lr().
        self.warmup_factor = warmup_factor
        self.warmup_iters = warmup_iters
        self.warmup_method = warmup_method
        super().__init__(optimizer, last_epoch)

    def get_lr(self) -> List[float]:
        """Return each base learning rate multiplied by the current warmup factor."""
        factor = _get_warmup_factor_at_epoch(
            self.warmup_method,
            self.last_epoch,
            self.warmup_iters,
            self.warmup_factor,
        )
        scaled: List[float] = []
        for lr in self.base_lrs:
            scaled.append(lr * factor)
        return scaled

    def _compute_values(self) -> List[float]:
        # Alias of get_lr() under another name; not referenced elsewhere in this
        # module (presumably kept for an alternative scheduler interface).
        return self.get_lr()
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def _get_warmup_factor_at_epoch(
|
|
41
|
+
method: str, iter: int, warmup_iters: int, warmup_factor: float
|
|
42
|
+
) -> float:
|
|
43
|
+
"""
|
|
44
|
+
Return the learning rate warmup factor at a specific iteration.
|
|
45
|
+
See https://arxiv.org/abs/1706.02677 for more details.
|
|
46
|
+
Args:
|
|
47
|
+
method (str): warmup method; either "constant" or "linear".
|
|
48
|
+
iter (int): iter at which to calculate the warmup factor.
|
|
49
|
+
warmup_iters (int): the number of warmup epochs.
|
|
50
|
+
warmup_factor (float): the base warmup factor (the meaning changes according
|
|
51
|
+
to the method used).
|
|
52
|
+
Returns:
|
|
53
|
+
float: the effective warmup factor at the given iteration.
|
|
54
|
+
"""
|
|
55
|
+
if iter >= warmup_iters:
|
|
56
|
+
return 1.0
|
|
57
|
+
|
|
58
|
+
if method == "constant":
|
|
59
|
+
return warmup_factor
|
|
60
|
+
elif method == "linear":
|
|
61
|
+
alpha = iter / warmup_iters
|
|
62
|
+
return warmup_factor * (1 - alpha) + alpha
|
|
63
|
+
elif method == "exp":
|
|
64
|
+
return warmup_factor ** (1 - iter / warmup_iters)
|
|
65
|
+
else:
|
|
66
|
+
raise ValueError("Unknown warmup method: {}".format(method))
|
|
@@ -0,0 +1,123 @@
|
|
|
1
|
+
####
|
|
2
|
+
# CODE TAKEN FROM https://github.com/mgrankin/over9000
|
|
3
|
+
####
|
|
4
|
+
|
|
5
|
+
import collections
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
from torch.optim.optimizer import Optimizer
|
|
9
|
+
from torch.utils.tensorboard import SummaryWriter
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def log_lamb_rs(optimizer: Optimizer, event_writer: SummaryWriter, token_count: int):
    """Write one histogram per LAMB statistic found in the optimizer state.

    For each of ``weight_norm``, ``adam_norm`` and ``trust_ratio``, the values
    stored in every parameter's state are gathered into a tensor and logged under
    the tag ``lamb/<name>`` at global step ``token_count``. Statistics absent from
    all states are not logged.
    """
    tracked_keys = ('weight_norm', 'adam_norm', 'trust_ratio')
    collected = collections.defaultdict(list)

    every_param = (param for group in optimizer.param_groups for param in group['params'])
    for param in every_param:
        param_state = optimizer.state[param]
        for key in tracked_keys:
            if key in param_state:
                collected[key].append(param_state[key])

    for key, values in collected.items():
        event_writer.add_histogram(f'lamb/{key}', torch.tensor(values), token_count)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class Lamb(Optimizer):
    r"""Implements Lamb algorithm.

    It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-6)
        weight_decay (float, optional): ``weight_decay * p`` is added to the Adam
            step before the trust ratio is computed (default: 0)
        adam (bool, optional): always use trust ratio = 1, which turns this into
            Adam. Useful for comparison purposes.

    After each step, every parameter's state also holds ``weight_norm``,
    ``adam_norm`` and ``trust_ratio`` from that step.

    .. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes:
        https://arxiv.org/abs/1904.00962
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6,
                 weight_decay=0, adam=False):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay)
        self.adam = adam
        super(Lamb, self).__init__(params, defaults)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure is given.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instead.')

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

                # Decay the first and second moment running average coefficient.
                # Keyword alpha/value forms replace the deprecated positional-scalar overloads.
                # m_t
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                # v_t
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                # Paper v3 does not use debiasing, so the step size is the raw lr.
                step_size = group['lr']

                # Weight norm is clamped to [0, 10] before forming the trust ratio.
                weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10)

                adam_step = exp_avg / exp_avg_sq.sqrt().add(group['eps'])
                if group['weight_decay'] != 0:
                    adam_step.add_(p.data, alpha=group['weight_decay'])

                adam_norm = adam_step.pow(2).sum().sqrt()
                if weight_norm == 0 or adam_norm == 0:
                    trust_ratio = 1
                else:
                    trust_ratio = weight_norm / adam_norm
                state['weight_norm'] = weight_norm
                state['adam_norm'] = adam_norm
                state['trust_ratio'] = trust_ratio
                if self.adam:
                    trust_ratio = 1

                # trust_ratio may be a 0-dim tensor; alpha takes a Python number.
                p.data.add_(adam_step, alpha=float(-step_size * trust_ratio))

        return loss
|
|
@@ -0,0 +1,149 @@
|
|
|
1
|
+
import math
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from torch.optim.optimizer import Optimizer
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class RAdam(Optimizer):
    """Rectified Adam (RAdam) optimizer.

    Variance-rectified Adam from "On the Variance of the Adaptive Learning Rate
    and Beyond" (https://arxiv.org/abs/1908.03265). While the approximated SMA
    length ``N_sma`` is below 5 the update is a bias-corrected momentum step
    without the adaptive denominator; afterwards the rectified Adam step is used.
    The update is computed in fp32 and copied back into ``p.data``.

    Args:
        params (iterable): parameters or dicts defining parameter groups.
        lr (float, optional): learning rate (default: 1e-3).
        betas (Tuple[float, float], optional): coefficients for the running
            averages of the gradient and its square (default: (0.9, 0.999)).
        eps (float, optional): term added to the denominator (default: 1e-8).
        weight_decay (float, optional): before each update the parameter is
            multiplied by ``1 - weight_decay * lr`` (default: 0).
    """

    def __init__(self, params, lr=0.001, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        self.buffer = self._empty_buffer()
        super(RAdam, self).__init__(params, defaults)

    @staticmethod
    def _empty_buffer():
        """Return a 10-slot cache; each slot is ``[key, N_sma, step_factor]``."""
        return [[None, None, None] for _ in range(10)]

    def __setstate__(self, state):
        super(RAdam, self).__setstate__(state)
        # Optimizer.__getstate__ only serialises defaults/state/param_groups, so the
        # cache is missing after unpickling or deepcopy; recreate it.
        if not hasattr(self, 'buffer'):
            self.buffer = self._empty_buffer()

    def _rectification(self, step, beta1, beta2):
        """Return ``(N_sma, step_factor)`` for ``step``; the step size is ``lr * step_factor``.

        The factor does not depend on the learning rate, so groups with different
        lrs can share the cache. The cache key includes the betas because they
        may also differ between param groups.
        """
        buffered = self.buffer[int(step % 10)]
        key = (step, beta1, beta2)
        if buffered[0] == key:
            return buffered[1], buffered[2]

        beta2_t = beta2 ** step
        N_sma_max = 2 / (1 - beta2) - 1
        N_sma = N_sma_max - 2 * step * beta2_t / (1 - beta2_t)
        bias_correction1 = 1 - beta1 ** step

        # more conservative since it's an approximated value
        if N_sma >= 5:
            step_factor = math.sqrt(
                (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (
                        N_sma_max - 2)) / bias_correction1
        else:
            step_factor = 1.0 / bias_correction1

        buffered[0], buffered[1], buffered[2] = key, N_sma, step_factor
        return N_sma, step_factor

    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure (callable, optional): re-evaluates the model and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure is given.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            beta1, beta2 = group['betas']

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError('RAdam does not support sparse gradients')

                p_data_fp32 = p.data.float()

                state = self.state[p]

                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
                else:
                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

                state['step'] += 1
                N_sma, step_factor = self._rectification(state['step'], beta1, beta2)
                # Scale by this group's own lr (the cached factor is lr-independent).
                step_size = group['lr'] * step_factor

                if group['weight_decay'] != 0:
                    p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])

                # more conservative since it's an approximated value
                if N_sma >= 5:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                    p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size)
                else:
                    p_data_fp32.add_(exp_avg, alpha=-step_size)

                p.data.copy_(p_data_fp32)

        return loss
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
class PlainRAdam(Optimizer):
    """Rectified Adam without the step-size cache.

    Same update rule as ``RAdam``, but the rectification term ``N_sma`` and the
    step size are recomputed for every parameter on every step. While
    ``N_sma < 5`` the update is a bias-corrected momentum step without the
    adaptive denominator. The update is computed in fp32 and copied back into ``p.data``.

    Args:
        params (iterable): parameters or dicts defining parameter groups.
        lr (float, optional): learning rate (default: 1e-3).
        betas (Tuple[float, float], optional): coefficients for the running
            averages of the gradient and its square (default: (0.9, 0.999)).
        eps (float, optional): term added to the denominator (default: 1e-8).
        weight_decay (float, optional): before each update the parameter is
            multiplied by ``1 - weight_decay * lr`` (default: 0).
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)

        super(PlainRAdam, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(PlainRAdam, self).__setstate__(state)

    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure (callable, optional): re-evaluates the model and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure is given.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError('RAdam does not support sparse gradients')

                p_data_fp32 = p.data.float()

                state = self.state[p]

                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
                else:
                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                # Keyword alpha/value forms replace the deprecated positional-scalar overloads.
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

                state['step'] += 1
                beta2_t = beta2 ** state['step']
                N_sma_max = 2 / (1 - beta2) - 1
                N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)

                if group['weight_decay'] != 0:
                    p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])

                # more conservative since it's an approximated value
                if N_sma >= 5:
                    step_size = group['lr'] * math.sqrt(
                        (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (
                                N_sma_max - 2)) / (1 - beta1 ** state['step'])
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                    p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size)
                else:
                    step_size = group['lr'] / (1 - beta1 ** state['step'])
                    p_data_fp32.add_(exp_avg, alpha=-step_size)

                p.data.copy_(p_data_fp32)

        return loss
|