dnt 0.2.1__py3-none-any.whl → 0.3.1.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dnt/__init__.py +4 -1
- dnt/analysis/__init__.py +3 -1
- dnt/analysis/count.py +107 -0
- dnt/analysis/interaction2.py +518 -0
- dnt/analysis/position.py +12 -0
- dnt/analysis/stop.py +92 -33
- dnt/analysis/stop2.py +289 -0
- dnt/analysis/stop3.py +758 -0
- dnt/detect/__init__.py +1 -1
- dnt/detect/signal/detector.py +326 -0
- dnt/detect/timestamp.py +105 -0
- dnt/detect/yolov8/detector.py +182 -35
- dnt/detect/yolov8/segmentor.py +171 -0
- dnt/engine/__init__.py +8 -0
- dnt/engine/bbox_interp.py +83 -0
- dnt/engine/bbox_iou.py +20 -0
- dnt/engine/cluster.py +31 -0
- dnt/engine/iob.py +66 -0
- dnt/filter/__init__.py +4 -0
- dnt/filter/filter.py +450 -21
- dnt/label/__init__.py +1 -1
- dnt/label/labeler.py +215 -14
- dnt/label/labeler2.py +631 -0
- dnt/shared/__init__.py +2 -1
- dnt/shared/data/coco.names +0 -0
- dnt/shared/data/openimages.names +0 -0
- dnt/shared/data/voc.names +0 -0
- dnt/shared/download.py +12 -0
- dnt/shared/synhcro.py +150 -0
- dnt/shared/util.py +17 -4
- dnt/third_party/fast-reid/__init__.py +1 -0
- dnt/third_party/fast-reid/configs/Base-AGW.yml +19 -0
- dnt/third_party/fast-reid/configs/Base-MGN.yml +12 -0
- dnt/third_party/fast-reid/configs/Base-SBS.yml +63 -0
- dnt/third_party/fast-reid/configs/Base-bagtricks.yml +76 -0
- dnt/third_party/fast-reid/configs/DukeMTMC/AGW_R101-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/DukeMTMC/AGW_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/DukeMTMC/AGW_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/DukeMTMC/AGW_S50.yml +11 -0
- dnt/third_party/fast-reid/configs/DukeMTMC/bagtricks_R101-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/DukeMTMC/bagtricks_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/DukeMTMC/bagtricks_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/DukeMTMC/bagtricks_S50.yml +11 -0
- dnt/third_party/fast-reid/configs/DukeMTMC/mgn_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/DukeMTMC/sbs_R101-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/DukeMTMC/sbs_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/DukeMTMC/sbs_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/DukeMTMC/sbs_S50.yml +11 -0
- dnt/third_party/fast-reid/configs/MOT17/AGW_R101-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/MOT17/AGW_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/MOT17/AGW_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/MOT17/AGW_S50.yml +11 -0
- dnt/third_party/fast-reid/configs/MOT17/bagtricks_R101-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/MOT17/bagtricks_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/MOT17/bagtricks_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/MOT17/bagtricks_S50.yml +11 -0
- dnt/third_party/fast-reid/configs/MOT17/mgn_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/MOT17/sbs_R101-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/MOT17/sbs_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/MOT17/sbs_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/MOT17/sbs_S50.yml +11 -0
- dnt/third_party/fast-reid/configs/MOT20/AGW_R101-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/MOT20/AGW_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/MOT20/AGW_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/MOT20/AGW_S50.yml +11 -0
- dnt/third_party/fast-reid/configs/MOT20/bagtricks_R101-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/MOT20/bagtricks_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/MOT20/bagtricks_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/MOT20/bagtricks_S50.yml +11 -0
- dnt/third_party/fast-reid/configs/MOT20/mgn_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/MOT20/sbs_R101-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/MOT20/sbs_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/MOT20/sbs_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/MOT20/sbs_S50.yml +11 -0
- dnt/third_party/fast-reid/configs/MSMT17/AGW_R101-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/MSMT17/AGW_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/MSMT17/AGW_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/MSMT17/AGW_S50.yml +11 -0
- dnt/third_party/fast-reid/configs/MSMT17/bagtricks_R101-ibn.yml +13 -0
- dnt/third_party/fast-reid/configs/MSMT17/bagtricks_R50-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/MSMT17/bagtricks_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/MSMT17/bagtricks_S50.yml +12 -0
- dnt/third_party/fast-reid/configs/MSMT17/mgn_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/MSMT17/sbs_R101-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/MSMT17/sbs_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/MSMT17/sbs_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/MSMT17/sbs_S50.yml +11 -0
- dnt/third_party/fast-reid/configs/Market1501/AGW_R101-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/Market1501/AGW_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/Market1501/AGW_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/Market1501/AGW_S50.yml +11 -0
- dnt/third_party/fast-reid/configs/Market1501/bagtricks_R101-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/Market1501/bagtricks_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/Market1501/bagtricks_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/Market1501/bagtricks_S50.yml +11 -0
- dnt/third_party/fast-reid/configs/Market1501/bagtricks_vit.yml +88 -0
- dnt/third_party/fast-reid/configs/Market1501/mgn_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/Market1501/sbs_R101-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/Market1501/sbs_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/Market1501/sbs_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/Market1501/sbs_S50.yml +11 -0
- dnt/third_party/fast-reid/configs/VERIWild/bagtricks_R50-ibn.yml +35 -0
- dnt/third_party/fast-reid/configs/VeRi/sbs_R50-ibn.yml +35 -0
- dnt/third_party/fast-reid/configs/VehicleID/bagtricks_R50-ibn.yml +36 -0
- dnt/third_party/fast-reid/configs/__init__.py +0 -0
- dnt/third_party/fast-reid/fast_reid_interfece.py +175 -0
- dnt/third_party/fast-reid/fastreid/__init__.py +6 -0
- dnt/third_party/fast-reid/fastreid/config/__init__.py +15 -0
- dnt/third_party/fast-reid/fastreid/config/config.py +319 -0
- dnt/third_party/fast-reid/fastreid/config/defaults.py +329 -0
- dnt/third_party/fast-reid/fastreid/data/__init__.py +17 -0
- dnt/third_party/fast-reid/fastreid/data/build.py +194 -0
- dnt/third_party/fast-reid/fastreid/data/common.py +58 -0
- dnt/third_party/fast-reid/fastreid/data/data_utils.py +202 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/AirportALERT.py +50 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/__init__.py +43 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/bases.py +183 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/caviara.py +44 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/cuhk03.py +274 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/cuhk_sysu.py +58 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/dukemtmcreid.py +70 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/grid.py +44 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/iLIDS.py +45 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/lpw.py +49 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/market1501.py +89 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/msmt17.py +114 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/pes3d.py +44 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/pku.py +44 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/prai.py +43 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/prid.py +41 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/saivt.py +47 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/sensereid.py +47 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/shinpuhkan.py +48 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/sysu_mm.py +47 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/thermalworld.py +43 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/vehicleid.py +126 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/veri.py +69 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/veriwild.py +140 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/viper.py +45 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/wildtracker.py +59 -0
- dnt/third_party/fast-reid/fastreid/data/samplers/__init__.py +18 -0
- dnt/third_party/fast-reid/fastreid/data/samplers/data_sampler.py +85 -0
- dnt/third_party/fast-reid/fastreid/data/samplers/imbalance_sampler.py +67 -0
- dnt/third_party/fast-reid/fastreid/data/samplers/triplet_sampler.py +260 -0
- dnt/third_party/fast-reid/fastreid/data/transforms/__init__.py +11 -0
- dnt/third_party/fast-reid/fastreid/data/transforms/autoaugment.py +806 -0
- dnt/third_party/fast-reid/fastreid/data/transforms/build.py +100 -0
- dnt/third_party/fast-reid/fastreid/data/transforms/functional.py +180 -0
- dnt/third_party/fast-reid/fastreid/data/transforms/transforms.py +161 -0
- dnt/third_party/fast-reid/fastreid/engine/__init__.py +15 -0
- dnt/third_party/fast-reid/fastreid/engine/defaults.py +490 -0
- dnt/third_party/fast-reid/fastreid/engine/hooks.py +534 -0
- dnt/third_party/fast-reid/fastreid/engine/launch.py +103 -0
- dnt/third_party/fast-reid/fastreid/engine/train_loop.py +357 -0
- dnt/third_party/fast-reid/fastreid/evaluation/__init__.py +6 -0
- dnt/third_party/fast-reid/fastreid/evaluation/clas_evaluator.py +81 -0
- dnt/third_party/fast-reid/fastreid/evaluation/evaluator.py +176 -0
- dnt/third_party/fast-reid/fastreid/evaluation/query_expansion.py +46 -0
- dnt/third_party/fast-reid/fastreid/evaluation/rank.py +200 -0
- dnt/third_party/fast-reid/fastreid/evaluation/rank_cylib/__init__.py +20 -0
- dnt/third_party/fast-reid/fastreid/evaluation/rank_cylib/setup.py +32 -0
- dnt/third_party/fast-reid/fastreid/evaluation/rank_cylib/test_cython.py +106 -0
- dnt/third_party/fast-reid/fastreid/evaluation/reid_evaluation.py +143 -0
- dnt/third_party/fast-reid/fastreid/evaluation/rerank.py +73 -0
- dnt/third_party/fast-reid/fastreid/evaluation/roc.py +90 -0
- dnt/third_party/fast-reid/fastreid/evaluation/testing.py +88 -0
- dnt/third_party/fast-reid/fastreid/layers/__init__.py +19 -0
- dnt/third_party/fast-reid/fastreid/layers/activation.py +59 -0
- dnt/third_party/fast-reid/fastreid/layers/any_softmax.py +80 -0
- dnt/third_party/fast-reid/fastreid/layers/batch_norm.py +205 -0
- dnt/third_party/fast-reid/fastreid/layers/context_block.py +113 -0
- dnt/third_party/fast-reid/fastreid/layers/drop.py +161 -0
- dnt/third_party/fast-reid/fastreid/layers/frn.py +199 -0
- dnt/third_party/fast-reid/fastreid/layers/gather_layer.py +30 -0
- dnt/third_party/fast-reid/fastreid/layers/helpers.py +31 -0
- dnt/third_party/fast-reid/fastreid/layers/non_local.py +54 -0
- dnt/third_party/fast-reid/fastreid/layers/pooling.py +124 -0
- dnt/third_party/fast-reid/fastreid/layers/se_layer.py +25 -0
- dnt/third_party/fast-reid/fastreid/layers/splat.py +109 -0
- dnt/third_party/fast-reid/fastreid/layers/weight_init.py +122 -0
- dnt/third_party/fast-reid/fastreid/modeling/__init__.py +23 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/__init__.py +18 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/build.py +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/mobilenet.py +195 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/mobilenetv3.py +283 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/osnet.py +525 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/__init__.py +4 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/config.py +396 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/effnet/EN-B0_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/effnet/EN-B1_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/effnet/EN-B2_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/effnet/EN-B3_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/effnet/EN-B4_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/effnet/EN-B5_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/effnet.py +281 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnet.py +596 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-1.6GF_dds_8gpu.yaml +26 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-12GF_dds_8gpu.yaml +26 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-16GF_dds_8gpu.yaml +26 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-200MF_dds_8gpu.yaml +26 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-3.2GF_dds_8gpu.yaml +26 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-32GF_dds_8gpu.yaml +26 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-4.0GF_dds_8gpu.yaml +26 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-400MF_dds_8gpu.yaml +26 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-6.4GF_dds_8gpu.yaml +26 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-600MF_dds_8gpu.yaml +26 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-8.0GF_dds_8gpu.yaml +26 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-800MF_dds_8gpu.yaml +26 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-1.6GF_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-12GF_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-16GF_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-200MF_dds_8gpu.yaml +26 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-3.2GF_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-32GF_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-4.0GF_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-400MF_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-6.4GF_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-600MF_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-8.0GF_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-800MF_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/repvgg.py +309 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/resnest.py +365 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/resnet.py +364 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/resnext.py +335 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/shufflenet.py +203 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/vision_transformer.py +399 -0
- dnt/third_party/fast-reid/fastreid/modeling/heads/__init__.py +11 -0
- dnt/third_party/fast-reid/fastreid/modeling/heads/build.py +25 -0
- dnt/third_party/fast-reid/fastreid/modeling/heads/clas_head.py +36 -0
- dnt/third_party/fast-reid/fastreid/modeling/heads/embedding_head.py +151 -0
- dnt/third_party/fast-reid/fastreid/modeling/losses/__init__.py +12 -0
- dnt/third_party/fast-reid/fastreid/modeling/losses/circle_loss.py +71 -0
- dnt/third_party/fast-reid/fastreid/modeling/losses/cross_entroy_loss.py +54 -0
- dnt/third_party/fast-reid/fastreid/modeling/losses/focal_loss.py +92 -0
- dnt/third_party/fast-reid/fastreid/modeling/losses/triplet_loss.py +113 -0
- dnt/third_party/fast-reid/fastreid/modeling/losses/utils.py +48 -0
- dnt/third_party/fast-reid/fastreid/modeling/meta_arch/__init__.py +14 -0
- dnt/third_party/fast-reid/fastreid/modeling/meta_arch/baseline.py +188 -0
- dnt/third_party/fast-reid/fastreid/modeling/meta_arch/build.py +26 -0
- dnt/third_party/fast-reid/fastreid/modeling/meta_arch/distiller.py +140 -0
- dnt/third_party/fast-reid/fastreid/modeling/meta_arch/mgn.py +394 -0
- dnt/third_party/fast-reid/fastreid/modeling/meta_arch/moco.py +126 -0
- dnt/third_party/fast-reid/fastreid/solver/__init__.py +8 -0
- dnt/third_party/fast-reid/fastreid/solver/build.py +348 -0
- dnt/third_party/fast-reid/fastreid/solver/lr_scheduler.py +66 -0
- dnt/third_party/fast-reid/fastreid/solver/optim/__init__.py +10 -0
- dnt/third_party/fast-reid/fastreid/solver/optim/lamb.py +123 -0
- dnt/third_party/fast-reid/fastreid/solver/optim/radam.py +149 -0
- dnt/third_party/fast-reid/fastreid/solver/optim/swa.py +246 -0
- dnt/third_party/fast-reid/fastreid/utils/__init__.py +6 -0
- dnt/third_party/fast-reid/fastreid/utils/checkpoint.py +503 -0
- dnt/third_party/fast-reid/fastreid/utils/collect_env.py +158 -0
- dnt/third_party/fast-reid/fastreid/utils/comm.py +255 -0
- dnt/third_party/fast-reid/fastreid/utils/compute_dist.py +200 -0
- dnt/third_party/fast-reid/fastreid/utils/env.py +119 -0
- dnt/third_party/fast-reid/fastreid/utils/events.py +461 -0
- dnt/third_party/fast-reid/fastreid/utils/faiss_utils.py +127 -0
- dnt/third_party/fast-reid/fastreid/utils/file_io.py +520 -0
- dnt/third_party/fast-reid/fastreid/utils/history_buffer.py +71 -0
- dnt/third_party/fast-reid/fastreid/utils/logger.py +211 -0
- dnt/third_party/fast-reid/fastreid/utils/params.py +103 -0
- dnt/third_party/fast-reid/fastreid/utils/precision_bn.py +94 -0
- dnt/third_party/fast-reid/fastreid/utils/registry.py +66 -0
- dnt/third_party/fast-reid/fastreid/utils/summary.py +120 -0
- dnt/third_party/fast-reid/fastreid/utils/timer.py +68 -0
- dnt/third_party/fast-reid/fastreid/utils/visualizer.py +278 -0
- dnt/track/__init__.py +3 -1
- dnt/track/botsort/__init__.py +4 -0
- dnt/track/botsort/bot_tracker/__init__.py +3 -0
- dnt/track/botsort/bot_tracker/basetrack.py +60 -0
- dnt/track/botsort/bot_tracker/bot_sort.py +473 -0
- dnt/track/botsort/bot_tracker/gmc.py +316 -0
- dnt/track/botsort/bot_tracker/kalman_filter.py +269 -0
- dnt/track/botsort/bot_tracker/matching.py +194 -0
- dnt/track/botsort/bot_tracker/mc_bot_sort.py +505 -0
- dnt/track/{dsort/utils → botsort/bot_tracker/tracking_utils}/evaluation.py +14 -4
- dnt/track/{dsort/utils → botsort/bot_tracker/tracking_utils}/io.py +19 -36
- dnt/track/botsort/bot_tracker/tracking_utils/timer.py +37 -0
- dnt/track/botsort/inference.py +96 -0
- dnt/track/config.py +120 -0
- dnt/track/dsort/configs/bagtricks_R50.yml +7 -0
- dnt/track/dsort/configs/deep_sort.yaml +0 -0
- dnt/track/dsort/configs/fastreid.yaml +1 -1
- dnt/track/dsort/deep_sort/deep/checkpoint/ckpt.t7 +0 -0
- dnt/track/dsort/deep_sort/deep/feature_extractor.py +87 -8
- dnt/track/dsort/deep_sort/deep_sort.py +31 -21
- dnt/track/dsort/deep_sort/sort/detection.py +2 -1
- dnt/track/dsort/deep_sort/sort/iou_matching.py +0 -2
- dnt/track/dsort/deep_sort/sort/linear_assignment.py +0 -3
- dnt/track/dsort/deep_sort/sort/nn_matching.py +5 -5
- dnt/track/dsort/deep_sort/sort/preprocessing.py +1 -2
- dnt/track/dsort/deep_sort/sort/track.py +2 -1
- dnt/track/dsort/deep_sort/sort/tracker.py +1 -1
- dnt/track/dsort/dsort.py +44 -27
- dnt/track/re_class.py +117 -0
- dnt/track/sort/sort.py +9 -7
- dnt/track/tracker.py +225 -20
- dnt-0.3.1.8.dist-info/METADATA +117 -0
- dnt-0.3.1.8.dist-info/RECORD +315 -0
- {dnt-0.2.1.dist-info → dnt-0.3.1.8.dist-info}/WHEEL +1 -1
- dnt/track/dsort/deep_sort/deep/evaluate.py +0 -15
- dnt/track/dsort/deep_sort/deep/original_model.py +0 -106
- dnt/track/dsort/deep_sort/deep/test.py +0 -77
- dnt/track/dsort/deep_sort/deep/train.py +0 -189
- dnt/track/dsort/utils/asserts.py +0 -13
- dnt/track/dsort/utils/draw.py +0 -36
- dnt/track/dsort/utils/json_logger.py +0 -383
- dnt/track/dsort/utils/log.py +0 -17
- dnt/track/dsort/utils/parser.py +0 -35
- dnt/track/dsort/utils/tools.py +0 -39
- dnt-0.2.1.dist-info/METADATA +0 -35
- dnt-0.2.1.dist-info/RECORD +0 -60
- /dnt/{track/dsort/utils → third_party/fast-reid/checkpoint}/__init__.py +0 -0
- {dnt-0.2.1.dist-info → dnt-0.3.1.8.dist-info/licenses}/LICENSE +0 -0
- {dnt-0.2.1.dist-info → dnt-0.3.1.8.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,175 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import cv2
|
|
3
|
+
import requests
|
|
4
|
+
from tqdm import tqdm
|
|
5
|
+
import numpy as np
|
|
6
|
+
import matplotlib.pyplot as plt
|
|
7
|
+
import torch
|
|
8
|
+
import torch.nn.functional as F
|
|
9
|
+
# from torch.backends import cudnn
|
|
10
|
+
|
|
11
|
+
from fastreid.config import get_cfg
|
|
12
|
+
from fastreid.modeling.meta_arch import build_model
|
|
13
|
+
from fastreid.utils.checkpoint import Checkpointer
|
|
14
|
+
from fastreid.engine import DefaultTrainer, default_argument_parser, default_setup, launch
|
|
15
|
+
|
|
16
|
+
# Public API of this module (what ``from ... import *`` exposes).
__all__ = [
    "FastReIDInterface",
    "setup_cfg",
    "postprocess",
    "preprocess",
]
|
|
17
|
+
|
|
18
|
+
def setup_cfg(config_file, opts):
    """Build a frozen fast-reid config from a YAML file plus override pairs.

    Args:
        config_file: path to the fast-reid YAML config file.
        opts: flat ``[KEY, VALUE, KEY, VALUE, ...]`` override list, merged
            after the file so it takes precedence.

    Returns:
        The frozen config node, with ``MODEL.BACKBONE.PRETRAIN`` forced off.
    """
    config = get_cfg()
    config.merge_from_file(config_file)
    config.merge_from_list(opts)
    # Backbone pre-training weights are not requested; presumably because the
    # full model checkpoint is loaded separately (FastReIDInterface does so via
    # Checkpointer).
    config.MODEL.BACKBONE.PRETRAIN = False
    config.freeze()
    return config
|
|
28
|
+
|
|
29
|
+
def check_weights(weights_path):
    """Ensure ``weights_path`` exists, downloading it from the fast-reid release if not.

    The basename of ``weights_path`` is fetched from the fast-reid v0.1.1
    GitHub release. The payload is streamed into ``<weights_path>.part`` and
    moved into place only after the download finished. A failed or interrupted
    download therefore never leaves a truncated file behind that a later call
    would accept as valid weights.

    Args:
        weights_path: local path of the checkpoint file.

    Raises:
        FileNotFoundError: the file is missing and could not be downloaded.
    """
    if os.path.exists(weights_path):
        return

    file_name = os.path.basename(weights_path)
    url = 'https://github.com/JDAI-CV/fast-reid/releases/download/v0.1.1/' + file_name
    tmp_path = weights_path + '.part'
    try:
        response = requests.get(url, stream=True, allow_redirects=True)
        # Without this, an HTTP error page (e.g. 404) would be saved as weights.
        response.raise_for_status()
        total = response.headers.get('Content-Length')
        with open(tmp_path, mode="wb") as file, tqdm(
            unit="B",
            total=int(total) if total is not None else None,
            desc="Downloading " + file_name + " ...",
        ) as pbar:
            for chunk in response.iter_content(chunk_size=10 * 1024):
                file.write(chunk)
                pbar.update(len(chunk))
        # Atomic rename: weights_path only ever appears fully written.
        os.replace(tmp_path, weights_path)
    except (requests.RequestException, ValueError) as err:
        raise FileNotFoundError(
            f"File {weights_path} not found and cannot be downloaded from {url}!"
        ) from err
    finally:
        # Drop any partial download (no-op after a successful os.replace).
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
|
|
42
|
+
|
|
43
|
+
def postprocess(features):
    """L2-normalise each row of ``features`` and return it as a NumPy array.

    With unit-length rows, the dot product of two rows equals their cosine
    similarity.

    Args:
        features: tensor of embeddings, normalised along dim 1 (assumed
            ``(N, D)``, one embedding per row).

    Returns:
        A NumPy copy on the CPU, detached from any autograd graph.
    """
    unit_rows = F.normalize(features)  # defaults: p=2, dim=1
    return unit_rows.detach().cpu().numpy()
|
|
48
|
+
|
|
49
|
+
def preprocess(image, input_size):
    """Letterbox ``image`` onto a gray (114) canvas of size ``input_size``.

    The image is scaled by one factor so that it fits inside the canvas, with
    the aspect ratio preserved. It is then pasted into the top-left corner, and
    the rest of the canvas keeps the fill value.

    Args:
        image: ``HxW`` or ``HxWx3`` image (must expose ``.shape``).
        input_size: canvas size as ``(width, height)``.

    Returns:
        ``(padded_img, r)``: the canvas (``uint8`` for 3-D input, ``float64``
        for 2-D input) and the scale factor applied to ``image``.
    """
    width, height = input_size[0], input_size[1]
    if len(image.shape) == 3:
        padded_img = np.ones((height, width, 3), dtype=np.uint8) * 114
    else:
        # Rows = height, columns = width, exactly as in the colour branch.
        # np.ones(input_size) would swap the axes, and the paste below would
        # then fail for non-square canvases.
        padded_img = np.ones((height, width)) * 114
    img = np.array(image)
    r = min(height / img.shape[0], width / img.shape[1])
    new_w = int(img.shape[1] * r)
    new_h = int(img.shape[0] * r)
    # cv2.resize takes the target size as (width, height).
    resized_img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
    padded_img[:new_h, :new_w] = resized_img
    return padded_img, r
|
|
64
|
+
|
|
65
|
+
class FastReIDInterface:
    """Extract appearance (re-ID) embeddings for detection boxes with fast-reid.

    The constructor loads a fast-reid config and checkpoint. :meth:`inference`
    crops every detection box from a frame and resizes the crops to the model's
    test input size. It returns one L2-normalised feature row per box.
    """

    def __init__(self, config_file: str, weights_path: str, device: str = 'cuda',
                 half: bool = True, batch_size: int = 1):
        """Build the model and load its weights.

        Args:
            config_file: path to the fast-reid YAML config.
            weights_path: checkpoint path; fetched by ``check_weights`` when the
                file does not exist yet.
            device: ``'cpu'`` runs on the CPU; any other value selects ``'cuda'``.
            half: run in FP16. Only honoured on CUDA, because the model is never
                cast to half precision on the CPU. Half inputs there would not
                match the float32 weights.
            batch_size: number of crops per forward pass (values < 1 become 1).
        """
        super().__init__()
        self.device = 'cpu' if device == 'cpu' else 'cuda'
        # FP16 only where the model itself is converted to FP16 (CUDA, below).
        self.half = bool(half) and self.device == 'cuda'
        self.batch_size = max(1, int(batch_size))

        check_weights(weights_path)

        self.cfg = setup_cfg(config_file, ['MODEL.WEIGHTS', weights_path])
        self.model = build_model(self.cfg)
        self.model.eval()

        Checkpointer(self.model).load(weights_path)

        if self.device == 'cuda':
            self.model = self.model.to(device='cuda')
            if self.half:
                self.model = self.model.half()
        self.model = self.model.eval()

        # Test-time input size as (height, width); cv2.resize needs it reversed.
        self.pH, self.pW = self.cfg.INPUT.SIZE_TEST

    def _crop_patch(self, image, det):
        """Return the RGB crop of ``image`` under box ``det`` = (x1, y1, x2, y2, ...)."""
        H, W, _ = np.shape(image)
        x1, y1, x2, y2 = (int(v) for v in det[:4])
        # Clamp to the frame and keep at least one pixel, so degenerate or
        # out-of-frame boxes still give a non-empty crop (cv2.resize rejects
        # empty input). Normal boxes are clamped exactly as before.
        x1 = min(max(0, x1), W - 1)
        y1 = min(max(0, y1), H - 1)
        x2 = max(min(W - 1, x2), x1 + 1)
        y2 = max(min(H - 1, y2), y1 + 1)
        patch = image[y1:y2, x1:x2, :]
        # The model expects RGB; the frame is assumed BGR (as from cv2.imread).
        return patch[:, :, ::-1]

    def _to_tensor(self, patch):
        """Resize a crop to the test size and turn it into a CHW float tensor."""
        patch = cv2.resize(patch, tuple(self.cfg.INPUT.SIZE_TEST[::-1]),
                           interpolation=cv2.INTER_LINEAR)
        tensor = torch.as_tensor(patch.astype("float32").transpose(2, 0, 1))
        if self.half:
            tensor = tensor.to(device=self.device).half()
        return tensor

    def inference(self, image: np.ndarray, detections: np.ndarray) -> np.ndarray:
        """Compute one embedding per detection box.

        Args:
            image: ``HxWxC`` frame, assumed BGR since crops are channel-reversed
                to RGB before inference.
            detections: one row per box; the first four columns are
                ``x1, y1, x2, y2`` pixel coordinates, extra columns are ignored.
                A single 1-D box is also accepted.

        Returns:
            ``(N, D)`` float64 array of L2-normalised features, where ``D`` is
            the model's embedding size, or ``[]`` when there are no detections.
            NaN rows produced by the model are passed through unchanged.
        """
        if detections is None or np.size(detections) == 0:
            return []

        detections = np.atleast_2d(np.asarray(detections))
        crops = [self._to_tensor(self._crop_patch(image, det)) for det in detections]

        feature_batches = []
        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            for start in range(0, len(crops), self.batch_size):
                batch = torch.stack(crops[start:start + self.batch_size], dim=0)
                pred = self.model(batch.to(device=self.device))
                # Replace infinities (e.g. FP16 overflow) before normalising.
                pred[torch.isinf(pred)] = 1.0
                feature_batches.append(postprocess(pred))

        # Feature width comes from the model output (no hard-coded 2048);
        # float64 output regardless of model precision.
        return np.concatenate(feature_batches, axis=0).astype(np.float64)
|
|
166
|
+
|
|
167
|
+
if __name__ == '__main__':
    # Manual smoke test: embed two hard-coded boxes from a single frame.
    import argparse

    cwd = os.path.dirname(os.path.abspath(__file__))
    parser = argparse.ArgumentParser(
        description="Extract fast-reid features for a few hard-coded boxes.")
    parser.add_argument('--config', default=os.path.join(cwd, 'configs/MOT17/sbs_S50.yml'))
    parser.add_argument(
        '--weights',
        default='/home/wonstran/dnt/detntrack_0.3alpha/dnt/track/botsort/checkpoint/mot17_sbs_S50.pth')
    parser.add_argument('--image', default='/mnt/d/videos/ped2stage/frames/824.jpg')
    args = parser.parse_args()

    extractor = FastReIDInterface(args.config, args.weights)

    image = cv2.imread(args.image)
    # cv2.imread returns None instead of raising when the file can't be read.
    if image is None:
        raise SystemExit(f"Could not read image: {args.image}")
    detections = np.array([[971, 482, 1016, 611], [913, 482, 963, 614]])
    features = extractor.inference(image, detections)
    print(features)
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
# encoding: utf-8
|
|
2
|
+
"""
|
|
3
|
+
@author: l1aoxingyu
|
|
4
|
+
@contact: sherlockliao01@gmail.com
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from .config import CfgNode, get_cfg, global_cfg, set_global_cfg, configurable
|
|
8
|
+
|
|
9
|
+
# Public names re-exported from the ``.config`` submodule.
__all__ = [
    "CfgNode",
    "get_cfg",
    "global_cfg",
    "set_global_cfg",
    "configurable",
]
|
|
@@ -0,0 +1,319 @@
|
|
|
1
|
+
# encoding: utf-8
|
|
2
|
+
"""
|
|
3
|
+
@author: l1aoxingyu
|
|
4
|
+
@contact: sherlockliao01@gmail.com
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import functools
|
|
8
|
+
import inspect
|
|
9
|
+
import logging
|
|
10
|
+
import os
|
|
11
|
+
from typing import Any
|
|
12
|
+
|
|
13
|
+
import yaml
|
|
14
|
+
from yacs.config import CfgNode as _CfgNode
|
|
15
|
+
|
|
16
|
+
from ..utils.file_io import PathManager
|
|
17
|
+
|
|
18
|
+
BASE_KEY = "_BASE_"
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class CfgNode(_CfgNode):
|
|
22
|
+
"""
|
|
23
|
+
Our own extended version of :class:`yacs.config.CfgNode`.
|
|
24
|
+
It contains the following extra features:
|
|
25
|
+
1. The :meth:`merge_from_file` method supports the "_BASE_" key,
|
|
26
|
+
which allows the new CfgNode to inherit all the attributes from the
|
|
27
|
+
base configuration file.
|
|
28
|
+
2. Keys that start with "COMPUTED_" are treated as insertion-only
|
|
29
|
+
"computed" attributes. They can be inserted regardless of whether
|
|
30
|
+
the CfgNode is frozen or not.
|
|
31
|
+
3. With "allow_unsafe=True", it supports pyyaml tags that evaluate
|
|
32
|
+
expressions in config. See examples in
|
|
33
|
+
https://pyyaml.org/wiki/PyYAMLDocumentation#yaml-tags-and-python-types
|
|
34
|
+
Note that this may lead to arbitrary code execution: you must not
|
|
35
|
+
load a config file from untrusted sources before manually inspecting
|
|
36
|
+
the content of the file.
|
|
37
|
+
"""
|
|
38
|
+
|
|
39
|
+
@staticmethod
|
|
40
|
+
def load_yaml_with_base(filename: str, allow_unsafe: bool = False):
|
|
41
|
+
"""
|
|
42
|
+
Just like `yaml.load(open(filename))`, but inherit attributes from its
|
|
43
|
+
`_BASE_`.
|
|
44
|
+
Args:
|
|
45
|
+
filename (str): the file name of the current config. Will be used to
|
|
46
|
+
find the base config file.
|
|
47
|
+
allow_unsafe (bool): whether to allow loading the config file with
|
|
48
|
+
`yaml.unsafe_load`.
|
|
49
|
+
Returns:
|
|
50
|
+
(dict): the loaded yaml
|
|
51
|
+
"""
|
|
52
|
+
with PathManager.open(filename, "r") as f:
|
|
53
|
+
try:
|
|
54
|
+
cfg = yaml.safe_load(f)
|
|
55
|
+
except yaml.constructor.ConstructorError:
|
|
56
|
+
if not allow_unsafe:
|
|
57
|
+
raise
|
|
58
|
+
logger = logging.getLogger(__name__)
|
|
59
|
+
logger.warning(
|
|
60
|
+
"Loading config {} with yaml.unsafe_load. Your machine may "
|
|
61
|
+
"be at risk if the file contains malicious content.".format(
|
|
62
|
+
filename
|
|
63
|
+
)
|
|
64
|
+
)
|
|
65
|
+
f.close()
|
|
66
|
+
with open(filename, "r") as f:
|
|
67
|
+
cfg = yaml.unsafe_load(f)
|
|
68
|
+
|
|
69
|
+
def merge_a_into_b(a, b):
|
|
70
|
+
# merge dict a into dict b. values in a will overwrite b.
|
|
71
|
+
for k, v in a.items():
|
|
72
|
+
if isinstance(v, dict) and k in b:
|
|
73
|
+
assert isinstance(
|
|
74
|
+
b[k], dict
|
|
75
|
+
), "Cannot inherit key '{}' from base!".format(k)
|
|
76
|
+
merge_a_into_b(v, b[k])
|
|
77
|
+
else:
|
|
78
|
+
b[k] = v
|
|
79
|
+
|
|
80
|
+
if BASE_KEY in cfg:
|
|
81
|
+
base_cfg_file = cfg[BASE_KEY]
|
|
82
|
+
if base_cfg_file.startswith("~"):
|
|
83
|
+
base_cfg_file = os.path.expanduser(base_cfg_file)
|
|
84
|
+
if not any(
|
|
85
|
+
map(base_cfg_file.startswith, ["/", "https://", "http://"])
|
|
86
|
+
):
|
|
87
|
+
# the path to base cfg is relative to the config file itself.
|
|
88
|
+
base_cfg_file = os.path.join(
|
|
89
|
+
os.path.dirname(filename), base_cfg_file
|
|
90
|
+
)
|
|
91
|
+
base_cfg = CfgNode.load_yaml_with_base(
|
|
92
|
+
base_cfg_file, allow_unsafe=allow_unsafe
|
|
93
|
+
)
|
|
94
|
+
del cfg[BASE_KEY]
|
|
95
|
+
|
|
96
|
+
merge_a_into_b(cfg, base_cfg)
|
|
97
|
+
return base_cfg
|
|
98
|
+
return cfg
|
|
99
|
+
|
|
100
|
+
def merge_from_file(self, cfg_filename: str, allow_unsafe: bool = False):
    """
    Read a yaml config file and merge its contents into this node.

    Base-config inheritance declared in the file is resolved by
    ``CfgNode.load_yaml_with_base`` before merging.

    Args:
        cfg_filename: path of the yaml config file to read.
        allow_unsafe: forwarded to ``load_yaml_with_base``; when True, a file
            that the safe loader rejects may be re-read with
            ``yaml.unsafe_load``.
    """
    parsed = CfgNode.load_yaml_with_base(cfg_filename, allow_unsafe=allow_unsafe)
    # Wrap the loaded mapping in this node's own class so the merge goes
    # through the regular node-to-node path (which also rejects BASE_KEY).
    self.merge_from_other_cfg(type(self)(parsed))
|
|
113
|
+
|
|
114
|
+
# The merge_* overrides below delegate to the base class once they have
# verified that the reserved BASE_KEY is not being injected.
def merge_from_other_cfg(self, cfg_other):
    """
    Merge another config node into this one.

    Args:
        cfg_other (CfgNode): configs to merge from. It must not contain the
            reserved BASE_KEY, which is only meaningful inside yaml files.

    Returns:
        Whatever the base class ``merge_from_other_cfg`` returns.
    """
    assert BASE_KEY not in cfg_other, (
        "The reserved key '{}' can only be used in files!".format(BASE_KEY)
    )
    return super().merge_from_other_cfg(cfg_other)
|
|
124
|
+
|
|
125
|
+
def merge_from_list(self, cfg_list: list):
    """
    Merge settings given as a flat alternating key/value list.

    Args:
        cfg_list (list): ``[key1, value1, key2, value2, ...]``. Keys are read
            from the even positions; none of them may be the reserved BASE_KEY.

    Returns:
        Whatever the base class ``merge_from_list`` returns.
    """
    key_names = {*cfg_list[::2]}
    assert BASE_KEY not in key_names, (
        "The reserved key '{}' can only be used in files!".format(BASE_KEY)
    )
    return super().merge_from_list(cfg_list)
|
|
135
|
+
|
|
136
|
+
def __setattr__(self, name: str, val: Any):
    """
    Set an attribute, treating ``COMPUTED_*`` names as write-once items.

    Names without the ``COMPUTED_`` prefix are delegated to the base class.
    A ``COMPUTED_*`` name is stored as an item of this node (``self[name]``).
    Assigning an equal value to an existing one is a no-op. Assigning a
    different value raises KeyError.
    """
    if not name.startswith("COMPUTED_"):
        super().__setattr__(name, val)
        return

    if name in self:
        existing = self[name]
        if existing == val:
            return
        raise KeyError(
            "Computed attributed '{}' already exists "
            "with a different value! old={}, new={}.".format(name, existing, val)
        )
    self[name] = val
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
# Module-level config object. set_global_cfg() replaces its contents in place
# (clear + update), so references to it imported elsewhere stay valid.
global_cfg = CfgNode()
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
def get_cfg() -> CfgNode:
    """
    Return an independent copy of the default fastreid config.

    The defaults object ``_C`` is imported from the sibling ``defaults``
    module inside this function and cloned, so callers can modify the result
    without touching the shared defaults.

    Returns:
        a fastreid CfgNode instance.
    """
    from .defaults import _C as default_cfg

    return default_cfg.clone()
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def set_global_cfg(cfg: CfgNode) -> None:
    """
    Replace the contents of the module-level ``global_cfg`` with ``cfg``.

    ``global_cfg`` is cleared and then updated from ``cfg``; it is not rebound.
    Assume that the given "cfg" has the key "KEY". After calling
    `set_global_cfg(cfg)`, the key can be accessed by:
    ::
        from fastreid.config import global_cfg
        print(global_cfg.KEY)

    By using a hacky global config, you can access these configs anywhere,
    without having to pass the config object or the values deep into the code.
    This is a hacky feature introduced for quick prototyping / research exploration.

    Args:
        cfg (CfgNode): the config whose entries should become the global ones.
    """
    # `global_cfg` is mutated in place and never rebound, so no `global`
    # statement is needed. Every module that already imported the object
    # sees the new contents.
    global_cfg.clear()
    global_cfg.update(cfg)
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
def configurable(init_func=None, *, from_config=None):
    """
    Decorate a function or a class's __init__ method so that it can be called
    with a :class:`CfgNode` object using a :func:`from_config` function that translates
    :class:`CfgNode` to arguments.

    Examples:
    ::
        # Usage 1: Decorator on __init__:
        class A:
            @configurable
            def __init__(self, a, b=2, c=3):
                pass

            @classmethod
            def from_config(cls, cfg):  # 'cfg' must be the first argument
                # Returns kwargs to be passed to __init__
                return {"a": cfg.A, "b": cfg.B}

        a1 = A(a=1, b=2)  # regular construction
        a2 = A(cfg)  # construct with a cfg
        a3 = A(cfg, b=3, c=4)  # construct with extra overwrite

        # Usage 2: Decorator on any function. Needs an extra from_config argument:
        @configurable(from_config=lambda cfg: {"a": cfg.A, "b": cfg.B})
        def a_func(a, b=2, c=3):
            pass

        a1 = a_func(a=1, b=2)  # regular call
        a2 = a_func(cfg)  # call with a cfg
        a3 = a_func(cfg, b=3, c=4)  # call with extra overwrite

    Args:
        init_func (callable): a class's ``__init__`` method in usage 1. The
            class must have a ``from_config`` classmethod which takes `cfg` as
            the first argument.
        from_config (callable): the from_config function in usage 2. It must take `cfg`
            as its first argument.

    Returns:
        The wrapped ``__init__`` (usage 1), a decorator (usage 2), or
        ``configurable`` itself when called as ``@configurable()``.

    Raises:
        AssertionError: if the decorator is used incorrectly, if ``from_config``
            is not a plain function, or if a decorated callable defined in a
            ``fastreid.*`` module lacks "experimental" in its docstring.
        AttributeError / TypeError: when a usage-1 instance is constructed and
            its class has no ``from_config`` classmethod.
    """

    def check_docstring(func):
        # Callables inside fastreid's own modules must flag this API as experimental.
        if func.__module__.startswith("fastreid."):
            assert (
                func.__doc__ is not None and "experimental" in func.__doc__.lower()
            ), f"configurable {func} should be marked experimental"

    if init_func is not None:
        # Usage 1: applied directly to a class's __init__.
        assert (
            inspect.isfunction(init_func)
            and from_config is None
            and init_func.__name__ == "__init__"
        ), "Incorrect use of @configurable. Check API documentation for examples."
        check_docstring(init_func)

        @functools.wraps(init_func)
        def wrapped(self, *args, **kwargs):
            try:
                from_config_func = type(self).from_config
            except AttributeError as e:
                raise AttributeError(
                    "Class with @configurable must have a 'from_config' classmethod."
                ) from e
            # A classmethod accessed on the class is a bound method; a plain
            # function or staticmethod is not, and is rejected here.
            if not inspect.ismethod(from_config_func):
                raise TypeError("Class with @configurable must have a 'from_config' classmethod.")

            if _called_with_cfg(*args, **kwargs):
                explicit_args = _get_args_from_config(from_config_func, *args, **kwargs)
                init_func(self, **explicit_args)
            else:
                init_func(self, *args, **kwargs)

        return wrapped

    else:
        if from_config is None:
            return configurable  # @configurable() is made equivalent to @configurable
        assert inspect.isfunction(
            from_config
        ), "from_config argument of configurable must be a function!"

        # Usage 2: return a decorator that routes cfg-style calls through from_config.
        def wrapper(orig_func):
            check_docstring(orig_func)

            @functools.wraps(orig_func)
            def wrapped(*args, **kwargs):
                if _called_with_cfg(*args, **kwargs):
                    explicit_args = _get_args_from_config(from_config, *args, **kwargs)
                    return orig_func(**explicit_args)
                else:
                    return orig_func(*args, **kwargs)

            return wrapped

        return wrapper
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
def _get_args_from_config(from_config_func, *args, **kwargs):
    """
    Use `from_config` to obtain explicit arguments.

    If ``from_config_func`` declares ``*args`` or ``**kwargs``, every argument
    is forwarded to it. Otherwise, keyword arguments that are not among its
    parameter names are held back and merged into the returned dict. Those
    extra keywords therefore reach ``__init__`` / the wrapped function directly.

    Args:
        from_config_func (callable): a function or bound classmethod whose
            first parameter must be named ``cfg``.
        *args, **kwargs: the arguments the decorated callable was invoked with.

    Returns:
        dict: arguments to be used for cls.__init__ (or the wrapped function).

    Raises:
        TypeError: if ``from_config_func`` does not take ``cfg`` as its first
            parameter, including when it takes no parameters at all.
    """
    signature = inspect.signature(from_config_func)
    param_names = list(signature.parameters)
    # Check emptiness first: indexing [0] on a parameterless callable would
    # otherwise raise an unhelpful IndexError instead of this TypeError.
    if not param_names or param_names[0] != "cfg":
        if inspect.isfunction(from_config_func):
            func_name = from_config_func.__name__
        else:
            func_name = f"{from_config_func.__self__}.from_config"
        raise TypeError(f"{func_name} must take 'cfg' as the first argument!")

    support_var_arg = any(
        param.kind in (param.VAR_POSITIONAL, param.VAR_KEYWORD)
        for param in signature.parameters.values()
    )
    if support_var_arg:
        # from_config accepts arbitrary arguments: forward everything.
        return from_config_func(*args, **kwargs)

    # Forward only the keywords from_config understands; keep the rest aside.
    supported_arg_names = set(param_names)
    extra_kwargs = {}
    for key in list(kwargs):
        if key not in supported_arg_names:
            extra_kwargs[key] = kwargs.pop(key)
    ret = from_config_func(*args, **kwargs)
    # Forward the other arguments to __init__.
    ret.update(extra_kwargs)
    return ret
|
|
304
|
+
|
|
305
|
+
|
|
306
|
+
def _called_with_cfg(*args, **kwargs):
    """
    Decide whether a call should be routed through ``from_config``.

    Returns:
        bool: True when the first positional argument, or the ``cfg`` keyword
        argument, is a ``_CfgNode``; False otherwise.
    """
    # `from_config` is required to take "cfg" as its first parameter, so a
    # config can only arrive positionally first or as the `cfg` keyword.
    if args and isinstance(args[0], _CfgNode):
        return True
    return isinstance(kwargs.get("cfg"), _CfgNode)
|