dnt 0.2.1__py3-none-any.whl → 0.3.1.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dnt/__init__.py +4 -1
- dnt/analysis/__init__.py +3 -1
- dnt/analysis/count.py +107 -0
- dnt/analysis/interaction2.py +518 -0
- dnt/analysis/position.py +12 -0
- dnt/analysis/stop.py +92 -33
- dnt/analysis/stop2.py +289 -0
- dnt/analysis/stop3.py +758 -0
- dnt/detect/__init__.py +1 -1
- dnt/detect/signal/detector.py +326 -0
- dnt/detect/timestamp.py +105 -0
- dnt/detect/yolov8/detector.py +182 -35
- dnt/detect/yolov8/segmentor.py +171 -0
- dnt/engine/__init__.py +8 -0
- dnt/engine/bbox_interp.py +83 -0
- dnt/engine/bbox_iou.py +20 -0
- dnt/engine/cluster.py +31 -0
- dnt/engine/iob.py +66 -0
- dnt/filter/__init__.py +4 -0
- dnt/filter/filter.py +450 -21
- dnt/label/__init__.py +1 -1
- dnt/label/labeler.py +215 -14
- dnt/label/labeler2.py +631 -0
- dnt/shared/__init__.py +2 -1
- dnt/shared/data/coco.names +0 -0
- dnt/shared/data/openimages.names +0 -0
- dnt/shared/data/voc.names +0 -0
- dnt/shared/download.py +12 -0
- dnt/shared/synhcro.py +150 -0
- dnt/shared/util.py +17 -4
- dnt/third_party/fast-reid/__init__.py +1 -0
- dnt/third_party/fast-reid/configs/Base-AGW.yml +19 -0
- dnt/third_party/fast-reid/configs/Base-MGN.yml +12 -0
- dnt/third_party/fast-reid/configs/Base-SBS.yml +63 -0
- dnt/third_party/fast-reid/configs/Base-bagtricks.yml +76 -0
- dnt/third_party/fast-reid/configs/DukeMTMC/AGW_R101-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/DukeMTMC/AGW_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/DukeMTMC/AGW_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/DukeMTMC/AGW_S50.yml +11 -0
- dnt/third_party/fast-reid/configs/DukeMTMC/bagtricks_R101-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/DukeMTMC/bagtricks_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/DukeMTMC/bagtricks_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/DukeMTMC/bagtricks_S50.yml +11 -0
- dnt/third_party/fast-reid/configs/DukeMTMC/mgn_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/DukeMTMC/sbs_R101-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/DukeMTMC/sbs_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/DukeMTMC/sbs_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/DukeMTMC/sbs_S50.yml +11 -0
- dnt/third_party/fast-reid/configs/MOT17/AGW_R101-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/MOT17/AGW_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/MOT17/AGW_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/MOT17/AGW_S50.yml +11 -0
- dnt/third_party/fast-reid/configs/MOT17/bagtricks_R101-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/MOT17/bagtricks_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/MOT17/bagtricks_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/MOT17/bagtricks_S50.yml +11 -0
- dnt/third_party/fast-reid/configs/MOT17/mgn_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/MOT17/sbs_R101-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/MOT17/sbs_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/MOT17/sbs_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/MOT17/sbs_S50.yml +11 -0
- dnt/third_party/fast-reid/configs/MOT20/AGW_R101-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/MOT20/AGW_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/MOT20/AGW_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/MOT20/AGW_S50.yml +11 -0
- dnt/third_party/fast-reid/configs/MOT20/bagtricks_R101-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/MOT20/bagtricks_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/MOT20/bagtricks_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/MOT20/bagtricks_S50.yml +11 -0
- dnt/third_party/fast-reid/configs/MOT20/mgn_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/MOT20/sbs_R101-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/MOT20/sbs_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/MOT20/sbs_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/MOT20/sbs_S50.yml +11 -0
- dnt/third_party/fast-reid/configs/MSMT17/AGW_R101-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/MSMT17/AGW_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/MSMT17/AGW_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/MSMT17/AGW_S50.yml +11 -0
- dnt/third_party/fast-reid/configs/MSMT17/bagtricks_R101-ibn.yml +13 -0
- dnt/third_party/fast-reid/configs/MSMT17/bagtricks_R50-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/MSMT17/bagtricks_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/MSMT17/bagtricks_S50.yml +12 -0
- dnt/third_party/fast-reid/configs/MSMT17/mgn_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/MSMT17/sbs_R101-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/MSMT17/sbs_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/MSMT17/sbs_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/MSMT17/sbs_S50.yml +11 -0
- dnt/third_party/fast-reid/configs/Market1501/AGW_R101-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/Market1501/AGW_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/Market1501/AGW_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/Market1501/AGW_S50.yml +11 -0
- dnt/third_party/fast-reid/configs/Market1501/bagtricks_R101-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/Market1501/bagtricks_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/Market1501/bagtricks_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/Market1501/bagtricks_S50.yml +11 -0
- dnt/third_party/fast-reid/configs/Market1501/bagtricks_vit.yml +88 -0
- dnt/third_party/fast-reid/configs/Market1501/mgn_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/Market1501/sbs_R101-ibn.yml +12 -0
- dnt/third_party/fast-reid/configs/Market1501/sbs_R50-ibn.yml +11 -0
- dnt/third_party/fast-reid/configs/Market1501/sbs_R50.yml +7 -0
- dnt/third_party/fast-reid/configs/Market1501/sbs_S50.yml +11 -0
- dnt/third_party/fast-reid/configs/VERIWild/bagtricks_R50-ibn.yml +35 -0
- dnt/third_party/fast-reid/configs/VeRi/sbs_R50-ibn.yml +35 -0
- dnt/third_party/fast-reid/configs/VehicleID/bagtricks_R50-ibn.yml +36 -0
- dnt/third_party/fast-reid/configs/__init__.py +0 -0
- dnt/third_party/fast-reid/fast_reid_interfece.py +175 -0
- dnt/third_party/fast-reid/fastreid/__init__.py +6 -0
- dnt/third_party/fast-reid/fastreid/config/__init__.py +15 -0
- dnt/third_party/fast-reid/fastreid/config/config.py +319 -0
- dnt/third_party/fast-reid/fastreid/config/defaults.py +329 -0
- dnt/third_party/fast-reid/fastreid/data/__init__.py +17 -0
- dnt/third_party/fast-reid/fastreid/data/build.py +194 -0
- dnt/third_party/fast-reid/fastreid/data/common.py +58 -0
- dnt/third_party/fast-reid/fastreid/data/data_utils.py +202 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/AirportALERT.py +50 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/__init__.py +43 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/bases.py +183 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/caviara.py +44 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/cuhk03.py +274 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/cuhk_sysu.py +58 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/dukemtmcreid.py +70 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/grid.py +44 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/iLIDS.py +45 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/lpw.py +49 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/market1501.py +89 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/msmt17.py +114 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/pes3d.py +44 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/pku.py +44 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/prai.py +43 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/prid.py +41 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/saivt.py +47 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/sensereid.py +47 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/shinpuhkan.py +48 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/sysu_mm.py +47 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/thermalworld.py +43 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/vehicleid.py +126 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/veri.py +69 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/veriwild.py +140 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/viper.py +45 -0
- dnt/third_party/fast-reid/fastreid/data/datasets/wildtracker.py +59 -0
- dnt/third_party/fast-reid/fastreid/data/samplers/__init__.py +18 -0
- dnt/third_party/fast-reid/fastreid/data/samplers/data_sampler.py +85 -0
- dnt/third_party/fast-reid/fastreid/data/samplers/imbalance_sampler.py +67 -0
- dnt/third_party/fast-reid/fastreid/data/samplers/triplet_sampler.py +260 -0
- dnt/third_party/fast-reid/fastreid/data/transforms/__init__.py +11 -0
- dnt/third_party/fast-reid/fastreid/data/transforms/autoaugment.py +806 -0
- dnt/third_party/fast-reid/fastreid/data/transforms/build.py +100 -0
- dnt/third_party/fast-reid/fastreid/data/transforms/functional.py +180 -0
- dnt/third_party/fast-reid/fastreid/data/transforms/transforms.py +161 -0
- dnt/third_party/fast-reid/fastreid/engine/__init__.py +15 -0
- dnt/third_party/fast-reid/fastreid/engine/defaults.py +490 -0
- dnt/third_party/fast-reid/fastreid/engine/hooks.py +534 -0
- dnt/third_party/fast-reid/fastreid/engine/launch.py +103 -0
- dnt/third_party/fast-reid/fastreid/engine/train_loop.py +357 -0
- dnt/third_party/fast-reid/fastreid/evaluation/__init__.py +6 -0
- dnt/third_party/fast-reid/fastreid/evaluation/clas_evaluator.py +81 -0
- dnt/third_party/fast-reid/fastreid/evaluation/evaluator.py +176 -0
- dnt/third_party/fast-reid/fastreid/evaluation/query_expansion.py +46 -0
- dnt/third_party/fast-reid/fastreid/evaluation/rank.py +200 -0
- dnt/third_party/fast-reid/fastreid/evaluation/rank_cylib/__init__.py +20 -0
- dnt/third_party/fast-reid/fastreid/evaluation/rank_cylib/setup.py +32 -0
- dnt/third_party/fast-reid/fastreid/evaluation/rank_cylib/test_cython.py +106 -0
- dnt/third_party/fast-reid/fastreid/evaluation/reid_evaluation.py +143 -0
- dnt/third_party/fast-reid/fastreid/evaluation/rerank.py +73 -0
- dnt/third_party/fast-reid/fastreid/evaluation/roc.py +90 -0
- dnt/third_party/fast-reid/fastreid/evaluation/testing.py +88 -0
- dnt/third_party/fast-reid/fastreid/layers/__init__.py +19 -0
- dnt/third_party/fast-reid/fastreid/layers/activation.py +59 -0
- dnt/third_party/fast-reid/fastreid/layers/any_softmax.py +80 -0
- dnt/third_party/fast-reid/fastreid/layers/batch_norm.py +205 -0
- dnt/third_party/fast-reid/fastreid/layers/context_block.py +113 -0
- dnt/third_party/fast-reid/fastreid/layers/drop.py +161 -0
- dnt/third_party/fast-reid/fastreid/layers/frn.py +199 -0
- dnt/third_party/fast-reid/fastreid/layers/gather_layer.py +30 -0
- dnt/third_party/fast-reid/fastreid/layers/helpers.py +31 -0
- dnt/third_party/fast-reid/fastreid/layers/non_local.py +54 -0
- dnt/third_party/fast-reid/fastreid/layers/pooling.py +124 -0
- dnt/third_party/fast-reid/fastreid/layers/se_layer.py +25 -0
- dnt/third_party/fast-reid/fastreid/layers/splat.py +109 -0
- dnt/third_party/fast-reid/fastreid/layers/weight_init.py +122 -0
- dnt/third_party/fast-reid/fastreid/modeling/__init__.py +23 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/__init__.py +18 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/build.py +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/mobilenet.py +195 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/mobilenetv3.py +283 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/osnet.py +525 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/__init__.py +4 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/config.py +396 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/effnet/EN-B0_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/effnet/EN-B1_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/effnet/EN-B2_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/effnet/EN-B3_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/effnet/EN-B4_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/effnet/EN-B5_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/effnet.py +281 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnet.py +596 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-1.6GF_dds_8gpu.yaml +26 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-12GF_dds_8gpu.yaml +26 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-16GF_dds_8gpu.yaml +26 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-200MF_dds_8gpu.yaml +26 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-3.2GF_dds_8gpu.yaml +26 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-32GF_dds_8gpu.yaml +26 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-4.0GF_dds_8gpu.yaml +26 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-400MF_dds_8gpu.yaml +26 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-6.4GF_dds_8gpu.yaml +26 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-600MF_dds_8gpu.yaml +26 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-8.0GF_dds_8gpu.yaml +26 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnetx/RegNetX-800MF_dds_8gpu.yaml +26 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-1.6GF_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-12GF_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-16GF_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-200MF_dds_8gpu.yaml +26 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-3.2GF_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-32GF_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-4.0GF_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-400MF_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-6.4GF_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-600MF_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-8.0GF_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/regnet/regnety/RegNetY-800MF_dds_8gpu.yaml +27 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/repvgg.py +309 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/resnest.py +365 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/resnet.py +364 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/resnext.py +335 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/shufflenet.py +203 -0
- dnt/third_party/fast-reid/fastreid/modeling/backbones/vision_transformer.py +399 -0
- dnt/third_party/fast-reid/fastreid/modeling/heads/__init__.py +11 -0
- dnt/third_party/fast-reid/fastreid/modeling/heads/build.py +25 -0
- dnt/third_party/fast-reid/fastreid/modeling/heads/clas_head.py +36 -0
- dnt/third_party/fast-reid/fastreid/modeling/heads/embedding_head.py +151 -0
- dnt/third_party/fast-reid/fastreid/modeling/losses/__init__.py +12 -0
- dnt/third_party/fast-reid/fastreid/modeling/losses/circle_loss.py +71 -0
- dnt/third_party/fast-reid/fastreid/modeling/losses/cross_entroy_loss.py +54 -0
- dnt/third_party/fast-reid/fastreid/modeling/losses/focal_loss.py +92 -0
- dnt/third_party/fast-reid/fastreid/modeling/losses/triplet_loss.py +113 -0
- dnt/third_party/fast-reid/fastreid/modeling/losses/utils.py +48 -0
- dnt/third_party/fast-reid/fastreid/modeling/meta_arch/__init__.py +14 -0
- dnt/third_party/fast-reid/fastreid/modeling/meta_arch/baseline.py +188 -0
- dnt/third_party/fast-reid/fastreid/modeling/meta_arch/build.py +26 -0
- dnt/third_party/fast-reid/fastreid/modeling/meta_arch/distiller.py +140 -0
- dnt/third_party/fast-reid/fastreid/modeling/meta_arch/mgn.py +394 -0
- dnt/third_party/fast-reid/fastreid/modeling/meta_arch/moco.py +126 -0
- dnt/third_party/fast-reid/fastreid/solver/__init__.py +8 -0
- dnt/third_party/fast-reid/fastreid/solver/build.py +348 -0
- dnt/third_party/fast-reid/fastreid/solver/lr_scheduler.py +66 -0
- dnt/third_party/fast-reid/fastreid/solver/optim/__init__.py +10 -0
- dnt/third_party/fast-reid/fastreid/solver/optim/lamb.py +123 -0
- dnt/third_party/fast-reid/fastreid/solver/optim/radam.py +149 -0
- dnt/third_party/fast-reid/fastreid/solver/optim/swa.py +246 -0
- dnt/third_party/fast-reid/fastreid/utils/__init__.py +6 -0
- dnt/third_party/fast-reid/fastreid/utils/checkpoint.py +503 -0
- dnt/third_party/fast-reid/fastreid/utils/collect_env.py +158 -0
- dnt/third_party/fast-reid/fastreid/utils/comm.py +255 -0
- dnt/third_party/fast-reid/fastreid/utils/compute_dist.py +200 -0
- dnt/third_party/fast-reid/fastreid/utils/env.py +119 -0
- dnt/third_party/fast-reid/fastreid/utils/events.py +461 -0
- dnt/third_party/fast-reid/fastreid/utils/faiss_utils.py +127 -0
- dnt/third_party/fast-reid/fastreid/utils/file_io.py +520 -0
- dnt/third_party/fast-reid/fastreid/utils/history_buffer.py +71 -0
- dnt/third_party/fast-reid/fastreid/utils/logger.py +211 -0
- dnt/third_party/fast-reid/fastreid/utils/params.py +103 -0
- dnt/third_party/fast-reid/fastreid/utils/precision_bn.py +94 -0
- dnt/third_party/fast-reid/fastreid/utils/registry.py +66 -0
- dnt/third_party/fast-reid/fastreid/utils/summary.py +120 -0
- dnt/third_party/fast-reid/fastreid/utils/timer.py +68 -0
- dnt/third_party/fast-reid/fastreid/utils/visualizer.py +278 -0
- dnt/track/__init__.py +3 -1
- dnt/track/botsort/__init__.py +4 -0
- dnt/track/botsort/bot_tracker/__init__.py +3 -0
- dnt/track/botsort/bot_tracker/basetrack.py +60 -0
- dnt/track/botsort/bot_tracker/bot_sort.py +473 -0
- dnt/track/botsort/bot_tracker/gmc.py +316 -0
- dnt/track/botsort/bot_tracker/kalman_filter.py +269 -0
- dnt/track/botsort/bot_tracker/matching.py +194 -0
- dnt/track/botsort/bot_tracker/mc_bot_sort.py +505 -0
- dnt/track/{dsort/utils → botsort/bot_tracker/tracking_utils}/evaluation.py +14 -4
- dnt/track/{dsort/utils → botsort/bot_tracker/tracking_utils}/io.py +19 -36
- dnt/track/botsort/bot_tracker/tracking_utils/timer.py +37 -0
- dnt/track/botsort/inference.py +96 -0
- dnt/track/config.py +120 -0
- dnt/track/dsort/configs/bagtricks_R50.yml +7 -0
- dnt/track/dsort/configs/deep_sort.yaml +0 -0
- dnt/track/dsort/configs/fastreid.yaml +1 -1
- dnt/track/dsort/deep_sort/deep/checkpoint/ckpt.t7 +0 -0
- dnt/track/dsort/deep_sort/deep/feature_extractor.py +87 -8
- dnt/track/dsort/deep_sort/deep_sort.py +31 -21
- dnt/track/dsort/deep_sort/sort/detection.py +2 -1
- dnt/track/dsort/deep_sort/sort/iou_matching.py +0 -2
- dnt/track/dsort/deep_sort/sort/linear_assignment.py +0 -3
- dnt/track/dsort/deep_sort/sort/nn_matching.py +5 -5
- dnt/track/dsort/deep_sort/sort/preprocessing.py +1 -2
- dnt/track/dsort/deep_sort/sort/track.py +2 -1
- dnt/track/dsort/deep_sort/sort/tracker.py +1 -1
- dnt/track/dsort/dsort.py +44 -27
- dnt/track/re_class.py +117 -0
- dnt/track/sort/sort.py +9 -7
- dnt/track/tracker.py +225 -20
- dnt-0.3.1.8.dist-info/METADATA +117 -0
- dnt-0.3.1.8.dist-info/RECORD +315 -0
- {dnt-0.2.1.dist-info → dnt-0.3.1.8.dist-info}/WHEEL +1 -1
- dnt/track/dsort/deep_sort/deep/evaluate.py +0 -15
- dnt/track/dsort/deep_sort/deep/original_model.py +0 -106
- dnt/track/dsort/deep_sort/deep/test.py +0 -77
- dnt/track/dsort/deep_sort/deep/train.py +0 -189
- dnt/track/dsort/utils/asserts.py +0 -13
- dnt/track/dsort/utils/draw.py +0 -36
- dnt/track/dsort/utils/json_logger.py +0 -383
- dnt/track/dsort/utils/log.py +0 -17
- dnt/track/dsort/utils/parser.py +0 -35
- dnt/track/dsort/utils/tools.py +0 -39
- dnt-0.2.1.dist-info/METADATA +0 -35
- dnt-0.2.1.dist-info/RECORD +0 -60
- /dnt/{track/dsort/utils → third_party/fast-reid/checkpoint}/__init__.py +0 -0
- {dnt-0.2.1.dist-info → dnt-0.3.1.8.dist-info/licenses}/LICENSE +0 -0
- {dnt-0.2.1.dist-info → dnt-0.3.1.8.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
# encoding: utf-8
|
|
2
|
+
"""
|
|
3
|
+
@author: Jinkai Zheng
|
|
4
|
+
@contact: 1315673509@qq.com
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import glob
|
|
8
|
+
import os.path as osp
|
|
9
|
+
import re
|
|
10
|
+
|
|
11
|
+
from .bases import ImageDataset
|
|
12
|
+
from ..datasets import DATASET_REGISTRY
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@DATASET_REGISTRY.register()
class VeRi(ImageDataset):
    """VeRi.

    Reference:
        Xinchen Liu et al. A Deep Learning based Approach for Progressive Vehicle Re-Identification. ECCV 2016.
        Xinchen Liu et al. PROVID: Progressive and Multimodal Vehicle Reidentification for Large-Scale Urban Surveillance. IEEE TMM 2018.

    URL: `<https://vehiclereid.github.io/VeRi/>`_

    Dataset statistics:
        - identities: 775.
        - images: 37778 (train) + 1678 (query) + 11579 (gallery).
    """
    dataset_dir = "veri"
    dataset_name = "veri"

    def __init__(self, root='datasets', **kwargs):
        """Load the VeRi train/query/gallery splits from ``<root>/veri``.

        Args:
            root (str): directory that contains the ``veri`` dataset folder.
            **kwargs: forwarded unchanged to ``ImageDataset.__init__``.
        """
        self.dataset_dir = osp.join(root, self.dataset_dir)

        self.train_dir = osp.join(self.dataset_dir, 'image_train')
        self.query_dir = osp.join(self.dataset_dir, 'image_query')
        self.gallery_dir = osp.join(self.dataset_dir, 'image_test')

        required_files = [
            self.dataset_dir,
            self.train_dir,
            self.query_dir,
            self.gallery_dir,
        ]
        self.check_before_run(required_files)

        train = self.process_dir(self.train_dir)
        query = self.process_dir(self.query_dir, is_train=False)
        gallery = self.process_dir(self.gallery_dir, is_train=False)

        super(VeRi, self).__init__(train, query, gallery, **kwargs)

    def process_dir(self, dir_path, is_train=True):
        """Collect ``(img_path, pid, camid)`` tuples for every ``*.jpg`` in *dir_path*.

        The vehicle id and camera number are parsed from the file name, which
        must contain ``<pid>_c<ddd>``. The camera number is shifted to be 0-based.
        For the training split both ids are prefixed with ``dataset_name`` (as
        strings) so they stay unique when several datasets are combined.

        Args:
            dir_path (str): directory holding the split's images.
            is_train (bool): whether this is the training split.

        Returns:
            list: ``(img_path, pid, camid)`` tuples.
        """
        img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
        pattern = re.compile(r'([\d]+)_c(\d\d\d)')

        data = []
        for img_path in img_paths:
            # Parse the file name only: searching the full path could match
            # digits in a parent directory name instead of the image's ids.
            pid, camid = map(int, pattern.search(osp.basename(img_path)).groups())
            # Note: the pattern only matches non-negative ids, so there are no
            # "-1" junk images to skip here (unlike Market-1501).
            assert 0 <= pid <= 776
            assert 1 <= camid <= 20
            camid -= 1  # index starts from 0
            if is_train:
                pid = self.dataset_name + "_" + str(pid)
                camid = self.dataset_name + "_" + str(camid)
            data.append((img_path, pid, camid))

        return data
|
|
@@ -0,0 +1,140 @@
|
|
|
1
|
+
# encoding: utf-8
|
|
2
|
+
"""
|
|
3
|
+
@author: Jinkai Zheng
|
|
4
|
+
@contact: 1315673509@qq.com
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import os.path as osp
|
|
8
|
+
|
|
9
|
+
from .bases import ImageDataset
|
|
10
|
+
from ..datasets import DATASET_REGISTRY
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@DATASET_REGISTRY.register()
class VeRiWild(ImageDataset):
    """VeRi-Wild.

    Reference:
        Lou et al. A Large-Scale Dataset for Vehicle Re-Identification in the Wild. CVPR 2019.

    URL: `<https://github.com/PKU-IMRE/VERI-Wild>`_

    Train dataset statistics:
        - identities: 30671.
        - images: 277797.
    """
    dataset_dir = "VERI-Wild"
    dataset_name = "veriwild"

    def __init__(self, root='datasets', query_list='', gallery_list='', **kwargs):
        """Load VeRi-Wild from ``<root>/VERI-Wild``.

        Args:
            root (str): directory that contains the ``VERI-Wild`` folder.
            query_list (str): optional path to a custom query list file.
            gallery_list (str): optional path to a custom gallery list file.
                Custom lists are used only when BOTH are given; otherwise the
                ``test_10000`` split under ``train_test_split`` is used.
            **kwargs: forwarded unchanged to ``ImageDataset.__init__``.
        """
        self.dataset_dir = osp.join(root, self.dataset_dir)

        self.image_dir = osp.join(self.dataset_dir, 'images')
        self.train_list = osp.join(self.dataset_dir, 'train_test_split/train_list.txt')
        self.vehicle_info = osp.join(self.dataset_dir, 'train_test_split/vehicle_info.txt')
        if query_list and gallery_list:
            self.query_list = query_list
            self.gallery_list = gallery_list
        else:
            self.query_list = osp.join(self.dataset_dir, 'train_test_split/test_10000_query.txt')
            self.gallery_list = osp.join(self.dataset_dir, 'train_test_split/test_10000.txt')

        required_files = [
            self.image_dir,
            self.train_list,
            self.query_list,
            self.gallery_list,
            self.vehicle_info,
        ]
        self.check_before_run(required_files)

        self.imgid2vid, self.imgid2camid, self.imgid2imgpath = self.process_vehicle(self.vehicle_info)

        train = self.process_dir(self.train_list)
        query = self.process_dir(self.query_list, is_train=False)
        gallery = self.process_dir(self.gallery_list, is_train=False)

        super(VeRiWild, self).__init__(train, query, gallery, **kwargs)

    @staticmethod
    def _read_lines(path):
        """Return the stripped, non-blank lines of the text file at *path* (file is closed)."""
        with open(path, 'r') as f:
            stripped = (line.strip() for line in f)
            return [line for line in stripped if line]

    def process_dir(self, img_list, is_train=True):
        """Build ``(img_path, vid, camid)`` tuples from a split list file.

        Each non-blank line is parsed as ``<vid>/<imgid>[.ext]``. The camera id
        and image path are looked up in the tables built by
        :meth:`process_vehicle`. For the training split, vid and camid are
        prefixed with ``dataset_name``.

        Args:
            img_list (str): path of the list file.
            is_train (bool): whether this is the training split.

        Returns:
            list: ``(img_path, vid, camid)`` tuples, one per non-blank line.
        """
        img_list_lines = self._read_lines(img_list)

        dataset = []
        for line in img_list_lines:
            parts = line.split('/')
            vid = int(parts[0])
            imgid = parts[1].split('.')[0]
            camid = int(self.imgid2camid[imgid])
            if is_train:
                vid = f"{self.dataset_name}_{vid}"
                camid = f"{self.dataset_name}_{camid}"
            dataset.append((self.imgid2imgpath[imgid], vid, camid))

        assert len(dataset) == len(img_list_lines)
        return dataset

    def process_vehicle(self, vehicle_info):
        """Parse ``vehicle_info.txt`` into per-image lookup tables.

        The first non-blank line is treated as a header and skipped. In each
        remaining line, the text before the first ``'/'`` is the vehicle id.
        The second ``'/'``-separated part of the first ``';'`` field is the
        image id, and the second ``';'`` field is the camera id (kept as a string).

        Args:
            vehicle_info (str): path of ``vehicle_info.txt``.

        Returns:
            tuple: ``(imgid2vid, imgid2camid, imgid2imgpath)`` dicts keyed by image id.
        """
        imgid2vid = {}
        imgid2camid = {}
        imgid2imgpath = {}
        vehicle_info_lines = self._read_lines(vehicle_info)

        for line in vehicle_info_lines[1:]:
            fields = line.split(';')
            vid = line.split('/')[0]
            imgid = fields[0].split('/')[1]
            camid = fields[1]
            img_path = osp.join(self.image_dir, vid, imgid + '.jpg')
            imgid2vid[imgid] = vid
            imgid2camid[imgid] = camid
            imgid2imgpath[imgid] = img_path

        # Image ids must be unique: every data line adds exactly one new key.
        assert len(imgid2vid) == len(vehicle_info_lines) - 1
        return imgid2vid, imgid2camid, imgid2imgpath
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
@DATASET_REGISTRY.register()
class SmallVeRiWild(VeRiWild):
    """VeRi-Wild evaluated on the small (3000-identity) test split.

    Small test dataset statistics:
        - identities: 3000.
        - images: 41861.
    """

    def __init__(self, root='datasets', **kwargs):
        # Select the 3000-identity query/gallery list files and hand them to VeRiWild.
        base_dir = osp.join(root, self.dataset_dir)
        query_file = osp.join(base_dir, 'train_test_split/test_3000_query.txt')
        gallery_file = osp.join(base_dir, 'train_test_split/test_3000.txt')
        self.query_list, self.gallery_list = query_file, gallery_file

        super().__init__(root, query_file, gallery_file, **kwargs)
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
@DATASET_REGISTRY.register()
class MediumVeRiWild(VeRiWild):
    """VeRi-Wild evaluated on the medium (5000-identity) test split.

    Medium test dataset statistics:
        - identities: 5000.
        - images: 69389.
    """

    def __init__(self, root='datasets', **kwargs):
        # Select the 5000-identity query/gallery list files and hand them to VeRiWild.
        base_dir = osp.join(root, self.dataset_dir)
        query_file = osp.join(base_dir, 'train_test_split/test_5000_query.txt')
        gallery_file = osp.join(base_dir, 'train_test_split/test_5000.txt')
        self.query_list, self.gallery_list = query_file, gallery_file

        super().__init__(root, query_file, gallery_file, **kwargs)
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
@DATASET_REGISTRY.register()
class LargeVeRiWild(VeRiWild):
    """VeRi-Wild evaluated on the large (10000-identity) test split.

    Large test dataset statistics:
        - identities: 10000.
        - images: 138517.
    """

    def __init__(self, root='datasets', **kwargs):
        # Select the 10000-identity query/gallery list files and hand them to VeRiWild.
        base_dir = osp.join(root, self.dataset_dir)
        query_file = osp.join(base_dir, 'train_test_split/test_10000_query.txt')
        gallery_file = osp.join(base_dir, 'train_test_split/test_10000.txt')
        self.query_list, self.gallery_list = query_file, gallery_file

        super().__init__(root, query_file, gallery_file, **kwargs)
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
# encoding: utf-8
|
|
2
|
+
"""
|
|
3
|
+
@author: xingyu liao
|
|
4
|
+
@contact: sherlockliao01@gmail.com
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import os
|
|
8
|
+
from glob import glob
|
|
9
|
+
|
|
10
|
+
from fastreid.data.datasets import DATASET_REGISTRY
|
|
11
|
+
from fastreid.data.datasets.bases import ImageDataset
|
|
12
|
+
|
|
13
|
+
# Public API of this module: only the VIPeR dataset class is exported.
__all__ = ['VIPeR', ]
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@DATASET_REGISTRY.register()
class VIPeR(ImageDataset):
    """VIPeR person re-identification dataset (training split only).

    Images are read from ``<root>/VIPeR/cam_a/*.bmp`` and
    ``<root>/VIPeR/cam_b/*.bmp``. Query and gallery are left empty.
    """
    dataset_dir = "VIPeR"
    dataset_name = "viper"

    def __init__(self, root='datasets', **kwargs):
        """Load VIPeR from ``<root>/VIPeR``.

        Args:
            root (str): directory that contains the ``VIPeR`` folder.
            **kwargs: forwarded unchanged to ``ImageDataset.__init__``.
        """
        self.root = root
        self.train_path = os.path.join(self.root, self.dataset_dir)

        required_files = [self.train_path]
        self.check_before_run(required_files)

        train = self.process_train(self.train_path)

        super().__init__(train, [], [], **kwargs)

    def process_train(self, train_path):
        """Collect ``[img_path, pid, camid]`` entries from the two camera folders.

        The camera id is the folder name (``cam_a``/``cam_b``). The person id is
        the file-name prefix before the first underscore. Both are prefixed
        with ``dataset_name``.

        Args:
            train_path (str): the ``VIPeR`` dataset directory.

        Returns:
            list: ``[img_path, pid, camid]`` lists.
        """
        data = []

        file_path_list = ['cam_a', 'cam_b']

        for file_path in file_path_list:
            camid = self.dataset_name + "_" + file_path
            img_list = glob(os.path.join(train_path, file_path, "*.bmp"))
            for img_path in img_list:
                # basename() rather than split('/') so Windows separators work too.
                img_name = os.path.basename(img_path)
                pid = self.dataset_name + "_" + img_name.split('_')[0]
                data.append([img_path, pid, camid])

        return data
|
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
# encoding: utf-8
|
|
2
|
+
"""
|
|
3
|
+
@author: wangguanan
|
|
4
|
+
@contact: guan.wang0706@gmail.com
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import glob
|
|
8
|
+
import os
|
|
9
|
+
|
|
10
|
+
from .bases import ImageDataset
|
|
11
|
+
from ..datasets import DATASET_REGISTRY
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@DATASET_REGISTRY.register()
class WildTrackCrop(ImageDataset):
    """WildTrack.

    Reference:
        WILDTRACK: A Multi-camera HD Dataset for Dense Unscripted Pedestrian Detection
        T. Chavdarova; P. Baqué; A. Maksai; S. Bouquet; C. Jose et al.
    URL: `<https://www.epfl.ch/labs/cvlab/data/data-wildtrack/>`_
    Dataset statistics:
        - identities: 313
        - images: 33979 (train only)
        - cameras: 7
    Args:
        root(str): directory containing the ``Wildtrack_crop_dataset`` folder
        **kwargs: forwarded unchanged to ``ImageDataset.__init__``
            (e.g. combineall(bool): combine train and test sets as train set if True)
    """
    dataset_url = None
    dataset_dir = 'Wildtrack_crop_dataset'
    dataset_name = 'wildtrack'

    def __init__(self, root='datasets', **kwargs):
        self.root = root
        self.dataset_dir = os.path.join(self.root, self.dataset_dir)

        self.train_dir = os.path.join(self.dataset_dir, "crop")

        train = self.process_dir(self.train_dir)
        query = []
        gallery = []

        super(WildTrackCrop, self).__init__(train, query, gallery, **kwargs)

    def process_dir(self, dir_path):
        r"""Collect training entries from per-identity sub-directories.

        Each sub-directory name of *dir_path* is used as the person id. The
        camera id is the prefix of each ``*.png`` file name before the first
        underscore. Both ids are prefixed with ``dataset_name``.

        :param dir_path: directory path saving images
        Returns
            data(list): list of ``[img_path, pid, camid]`` entries
        """
        data = []
        # sorted() makes the dataset order deterministic across file systems.
        for dir_name in sorted(os.listdir(dir_path)):
            pid = self.dataset_name + "_" + dir_name
            img_lists = sorted(glob.glob(os.path.join(dir_path, dir_name, "*.png")))
            for img_path in img_lists:
                # basename() rather than split('/') so Windows separators work too.
                camid = os.path.basename(img_path).split('_')[0]
                camid = self.dataset_name + "_" + camid
                data.append([img_path, pid, camid])
        return data
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
# encoding: utf-8
|
|
2
|
+
"""
|
|
3
|
+
@author: liaoxingyu
|
|
4
|
+
@contact: sherlockliao01@gmail.com
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from .triplet_sampler import BalancedIdentitySampler, NaiveIdentitySampler, SetReWeightSampler
|
|
8
|
+
from .data_sampler import TrainingSampler, InferenceSampler
|
|
9
|
+
from .imbalance_sampler import ImbalancedDatasetSampler
|
|
10
|
+
|
|
11
|
+
# Public API of the samplers package: the sampler classes imported above from
# triplet_sampler, data_sampler and imbalance_sampler, listed here so that
# ``from <package> import *`` exposes exactly these names.
__all__ = [
    "BalancedIdentitySampler",
    "NaiveIdentitySampler",
    "SetReWeightSampler",
    "TrainingSampler",
    "InferenceSampler",
    "ImbalancedDatasetSampler",
]
|
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
# encoding: utf-8
|
|
2
|
+
"""
|
|
3
|
+
@author: l1aoxingyu
|
|
4
|
+
@contact: sherlockliao01@gmail.com
|
|
5
|
+
"""
|
|
6
|
+
import itertools
|
|
7
|
+
from typing import Optional
|
|
8
|
+
|
|
9
|
+
import numpy as np
|
|
10
|
+
from torch.utils.data import Sampler
|
|
11
|
+
|
|
12
|
+
from fastreid.utils import comm
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class TrainingSampler(Sampler):
    """
    In training, we only care about the "infinite stream" of training data.
    So this sampler produces an infinite stream of indices and
    all workers cooperate to correctly shuffle the indices and sample different indices.
    The samplers in each worker effectively produces `indices[worker_id::num_workers]`
    where `indices` is an infinite stream of indices consisting of
    `shuffle(range(size)) + shuffle(range(size)) + ...` (if shuffle is True)
    or `range(size) + range(size) + ...` (if shuffle is False)
    """

    def __init__(self, size: int, shuffle: bool = True, seed: Optional[int] = None):
        """
        Args:
            size (int): the total number of data of the underlying dataset to sample from
            shuffle (bool): whether to shuffle the indices or not
            seed (int): the initial seed of the shuffle. Must be the same
                across all workers. If None, will use a random seed shared
                among workers (require synchronization among all workers).
        """
        self._size = size
        assert size > 0
        self._shuffle = shuffle
        if seed is None:
            seed = comm.shared_random_seed()
        self._seed = int(seed)

        self._rank = comm.get_rank()
        self._world_size = comm.get_world_size()

    def __iter__(self):
        # Every rank builds the same seeded stream and keeps every
        # `world_size`-th element starting at its own rank, so ranks read
        # disjoint positions of the shared stream without communicating.
        start = self._rank
        yield from itertools.islice(self._infinite_indices(), start, None, self._world_size)

    def _infinite_indices(self):
        # A private RandomState yields exactly the sequence that
        # `np.random.seed(seed); np.random.permutation(...)` would, but it does
        # not reset the process-wide NumPy RNG that other code may depend on.
        # It is also unaffected by other np.random calls made between yields.
        rng = np.random.RandomState(self._seed)
        while True:
            if self._shuffle:
                yield from rng.permutation(self._size)
            else:
                yield from np.arange(self._size)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class InferenceSampler(Sampler):
    """
    Index sampler for evaluation.

    Evaluation has to visit every sample exactly once, so the index range
    ``0 .. size-1`` is cut into contiguous shards, one per worker. When the
    size is not a multiple of the worker count, trailing shards are shorter
    (and may be empty), so workers can receive different numbers of samples.
    """

    def __init__(self, size: int):
        """
        Args:
            size (int): number of items in the dataset being sampled
        """
        self._size = size
        assert size > 0
        self._rank = comm.get_rank()
        self._world_size = comm.get_world_size()

        # ceil(size / world_size) indices per worker
        per_worker = -(-self._size // self._world_size)
        first = per_worker * self._rank
        self._local_indices = range(first, min(first + per_worker, self._size))

    def __iter__(self):
        for index in self._local_indices:
            yield index

    def __len__(self):
        return len(self._local_indices)
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
# encoding: utf-8
|
|
2
|
+
"""
|
|
3
|
+
@author: xingyu liao
|
|
4
|
+
@contact: sherlockliao01@gmail.com
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
# based on:
|
|
8
|
+
# https://github.com/ufoym/imbalanced-dataset-sampler/blob/master/torchsampler/imbalanced.py
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
import itertools
|
|
12
|
+
from typing import Optional, List, Callable
|
|
13
|
+
|
|
14
|
+
import numpy as np
|
|
15
|
+
import torch
|
|
16
|
+
from torch.utils.data.sampler import Sampler
|
|
17
|
+
|
|
18
|
+
from fastreid.utils import comm
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class ImbalancedDatasetSampler(Sampler):
    """Samples elements randomly from a given list of indices for imbalanced dataset.

    Every index is drawn with replacement, with probability proportional to
    ``1 / count(label of that index)``, so each label is drawn about equally
    often regardless of how many items it has.

    Arguments:
        data_source: a list of data items
        size: number of samples to draw per pass; defaults to ``len(data_source)``.
            Must be positive.
        seed: seed of the sampling generator. If None, ``comm.shared_random_seed()``
            is used. It should be identical on all ranks so that they shard the
            same stream.
        callback_get_label: optional ``f(dataset, idx) -> label``; by default the
            label is ``dataset[idx][1]``.

    Raises:
        ValueError: if the resulting sample size is not positive (e.g. an empty
            ``data_source`` with ``size=None``).
    """

    def __init__(self, data_source: List, size: int = None, seed: Optional[int] = None,
                 callback_get_label: Callable = None):
        self.data_source = data_source
        # consider all elements in the dataset
        self.indices = list(range(len(data_source)))
        # if size is not provided, draw `len(indices)` samples in each pass
        self._size = len(self.indices) if size is None else size
        if self._size <= 0:
            # multinomial would yield nothing and _infinite_indices would spin forever
            raise ValueError(
                f"ImbalancedDatasetSampler needs a positive sample size, got {self._size}")
        self.callback_get_label = callback_get_label

        # distribution of classes in the dataset (each label looked up once)
        labels = [self._get_label(data_source, idx) for idx in self.indices]
        label_to_count = {}
        for label in labels:
            label_to_count[label] = label_to_count.get(label, 0) + 1

        # weight for each sample
        weights = [1.0 / label_to_count[label] for label in labels]
        self.weights = torch.DoubleTensor(weights)

        if seed is None:
            seed = comm.shared_random_seed()
        self._seed = int(seed)
        self._rank = comm.get_rank()
        self._world_size = comm.get_world_size()

    def _get_label(self, dataset, idx):
        if self.callback_get_label:
            return self.callback_get_label(dataset, idx)
        else:
            return dataset[idx][1]

    def __iter__(self):
        # All ranks generate the same seeded stream; each keeps every
        # `world_size`-th element starting at its rank.
        start = self._rank
        yield from itertools.islice(self._infinite_indices(), start, None, self._world_size)

    def _infinite_indices(self):
        # torch.multinomial draws from a torch RNG, so the seed must go into a
        # torch.Generator; seeding NumPy would leave the draws unseeded and let
        # ranks disagree on the stream they shard.
        generator = torch.Generator()
        generator.manual_seed(self._seed)
        while True:
            draws = torch.multinomial(self.weights, self._size, replacement=True,
                                      generator=generator)
            for i in draws.tolist():
                yield self.indices[i]
|