vismatch 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vismatch/TEMPLATE.py +101 -0
- vismatch/__init__.py +475 -0
- vismatch/assets/example_pairs/false_positive/chartres.jpg +0 -0
- vismatch/assets/example_pairs/false_positive/notre_dame.jpg +0 -0
- vismatch/assets/example_pairs/fresco/fsm.jpg +0 -0
- vismatch/assets/example_pairs/fresco/sist_chapel.jpg +0 -0
- vismatch/assets/example_pairs/indoor/gcs_close.jpg +0 -0
- vismatch/assets/example_pairs/indoor/gcs_far.jpg +0 -0
- vismatch/assets/example_pairs/outdoor/montmartre_close.jpg +0 -0
- vismatch/assets/example_pairs/outdoor/montmartre_far.jpg +0 -0
- vismatch/assets/example_pairs/sat2iss/photo_from_iss.jpg +0 -0
- vismatch/assets/example_pairs/sat2iss/satellite_img.jpg +0 -0
- vismatch/assets/example_pairs/sphereglue/barbershop-00000000.jpg +0 -0
- vismatch/assets/example_pairs/sphereglue/barbershop-00000001.jpg +0 -0
- vismatch/assets/example_pairs/thermal/thermal.jpg +0 -0
- vismatch/assets/example_pairs/thermal/visible.jpg +0 -0
- vismatch/assets/example_test/original.jpg +0 -0
- vismatch/assets/example_test/warped.jpg +0 -0
- vismatch/base_matcher.py +242 -0
- vismatch/im_models/__init__.py +0 -0
- vismatch/im_models/aff_steerers.py +143 -0
- vismatch/im_models/aspanformer.py +74 -0
- vismatch/im_models/dedode.py +150 -0
- vismatch/im_models/duster.py +104 -0
- vismatch/im_models/edm.py +64 -0
- vismatch/im_models/efficient_loftr.py +60 -0
- vismatch/im_models/gim.py +187 -0
- vismatch/im_models/handcrafted.py +81 -0
- vismatch/im_models/keypt2subpx.py +154 -0
- vismatch/im_models/kornia.py +72 -0
- vismatch/im_models/liftfeat.py +44 -0
- vismatch/im_models/lightglue.py +75 -0
- vismatch/im_models/lisrd.py +98 -0
- vismatch/im_models/loftr.py +23 -0
- vismatch/im_models/master.py +107 -0
- vismatch/im_models/matchanything.py +221 -0
- vismatch/im_models/matchformer.py +61 -0
- vismatch/im_models/matching_toolbox.py +238 -0
- vismatch/im_models/minima.py +164 -0
- vismatch/im_models/omniglue.py +91 -0
- vismatch/im_models/rdd.py +250 -0
- vismatch/im_models/ripe.py +55 -0
- vismatch/im_models/roma.py +92 -0
- vismatch/im_models/romav2.py +62 -0
- vismatch/im_models/se2loftr.py +71 -0
- vismatch/im_models/silk.py +405 -0
- vismatch/im_models/sphereglue.py +97 -0
- vismatch/im_models/steerers.py +140 -0
- vismatch/im_models/topicfm.py +93 -0
- vismatch/im_models/ufm.py +57 -0
- vismatch/im_models/xfeat.py +78 -0
- vismatch/im_models/xfeat_steerers.py +151 -0
- vismatch/im_models/xoftr.py +71 -0
- vismatch/third_party/DeDoDe/DeDoDe/__init__.py +2 -0
- vismatch/third_party/DeDoDe/DeDoDe/benchmarks/__init__.py +4 -0
- vismatch/third_party/DeDoDe/DeDoDe/benchmarks/mega_pose_est.py +114 -0
- vismatch/third_party/DeDoDe/DeDoDe/benchmarks/mega_pose_est_mnn.py +119 -0
- vismatch/third_party/DeDoDe/DeDoDe/benchmarks/nll_benchmark.py +57 -0
- vismatch/third_party/DeDoDe/DeDoDe/benchmarks/num_inliers.py +76 -0
- vismatch/third_party/DeDoDe/DeDoDe/checkpoint.py +59 -0
- vismatch/third_party/DeDoDe/DeDoDe/datasets/__init__.py +0 -0
- vismatch/third_party/DeDoDe/DeDoDe/datasets/megadepth.py +269 -0
- vismatch/third_party/DeDoDe/DeDoDe/decoder.py +90 -0
- vismatch/third_party/DeDoDe/DeDoDe/descriptors/__init__.py +0 -0
- vismatch/third_party/DeDoDe/DeDoDe/descriptors/dedode_descriptor.py +50 -0
- vismatch/third_party/DeDoDe/DeDoDe/descriptors/descriptor_loss.py +68 -0
- vismatch/third_party/DeDoDe/DeDoDe/detectors/__init__.py +0 -0
- vismatch/third_party/DeDoDe/DeDoDe/detectors/dedode_detector.py +76 -0
- vismatch/third_party/DeDoDe/DeDoDe/detectors/keypoint_loss.py +185 -0
- vismatch/third_party/DeDoDe/DeDoDe/encoder.py +87 -0
- vismatch/third_party/DeDoDe/DeDoDe/matchers/__init__.py +0 -0
- vismatch/third_party/DeDoDe/DeDoDe/matchers/dual_softmax_matcher.py +38 -0
- vismatch/third_party/DeDoDe/DeDoDe/model_zoo/__init__.py +3 -0
- vismatch/third_party/DeDoDe/DeDoDe/model_zoo/dedode_models.py +249 -0
- vismatch/third_party/DeDoDe/DeDoDe/train.py +76 -0
- vismatch/third_party/DeDoDe/DeDoDe/transformer/__init__.py +8 -0
- vismatch/third_party/DeDoDe/DeDoDe/transformer/dinov2.py +359 -0
- vismatch/third_party/DeDoDe/DeDoDe/transformer/layers/__init__.py +12 -0
- vismatch/third_party/DeDoDe/DeDoDe/transformer/layers/attention.py +81 -0
- vismatch/third_party/DeDoDe/DeDoDe/transformer/layers/block.py +252 -0
- vismatch/third_party/DeDoDe/DeDoDe/transformer/layers/dino_head.py +59 -0
- vismatch/third_party/DeDoDe/DeDoDe/transformer/layers/drop_path.py +35 -0
- vismatch/third_party/DeDoDe/DeDoDe/transformer/layers/layer_scale.py +28 -0
- vismatch/third_party/DeDoDe/DeDoDe/transformer/layers/mlp.py +41 -0
- vismatch/third_party/DeDoDe/DeDoDe/transformer/layers/patch_embed.py +89 -0
- vismatch/third_party/DeDoDe/DeDoDe/transformer/layers/swiglu_ffn.py +63 -0
- vismatch/third_party/DeDoDe/DeDoDe/utils.py +717 -0
- vismatch/third_party/DeDoDe/data_prep/prep_keypoints.py +103 -0
- vismatch/third_party/DeDoDe/demo/demo_kpts.py +24 -0
- vismatch/third_party/DeDoDe/demo/demo_match.py +46 -0
- vismatch/third_party/DeDoDe/demo/demo_match_dedode_G.py +45 -0
- vismatch/third_party/DeDoDe/demo/demo_scoremap.py +23 -0
- vismatch/third_party/DeDoDe/experiments/dedode_descriptor-B.py +135 -0
- vismatch/third_party/DeDoDe/experiments/dedode_descriptor-G.py +145 -0
- vismatch/third_party/DeDoDe/experiments/dedode_detector.py +126 -0
- vismatch/third_party/DeDoDe/experiments/eval/eval_dedode_descriptor-B.py +38 -0
- vismatch/third_party/DeDoDe/experiments/eval/eval_dedode_descriptor-G.py +38 -0
- vismatch/third_party/DeDoDe/setup.py +11 -0
- vismatch/third_party/EDM/configs/data/__init__.py +0 -0
- vismatch/third_party/EDM/configs/data/base.py +37 -0
- vismatch/third_party/EDM/configs/data/megadepth_test_1500.py +23 -0
- vismatch/third_party/EDM/configs/data/megadepth_trainval_832.py +32 -0
- vismatch/third_party/EDM/configs/data/scannet_test_1500.py +24 -0
- vismatch/third_party/EDM/configs/data/scannet_trainval.py +31 -0
- vismatch/third_party/EDM/configs/edm/indoor/edm_base.py +15 -0
- vismatch/third_party/EDM/configs/edm/outdoor/edm_base.py +17 -0
- vismatch/third_party/EDM/deploy/export_onnx.py +69 -0
- vismatch/third_party/EDM/deploy/run_onnx.py +138 -0
- vismatch/third_party/EDM/runtime_single_pair.py +73 -0
- vismatch/third_party/EDM/src/__init__.py +0 -0
- vismatch/third_party/EDM/src/config/default.py +184 -0
- vismatch/third_party/EDM/src/datasets/megadepth.py +164 -0
- vismatch/third_party/EDM/src/datasets/sampler.py +95 -0
- vismatch/third_party/EDM/src/datasets/scannet.py +147 -0
- vismatch/third_party/EDM/src/edm/__init__.py +2 -0
- vismatch/third_party/EDM/src/edm/backbone/resnet.py +116 -0
- vismatch/third_party/EDM/src/edm/edm.py +204 -0
- vismatch/third_party/EDM/src/edm/head/coarse_matching.py +158 -0
- vismatch/third_party/EDM/src/edm/head/fine_matching.py +383 -0
- vismatch/third_party/EDM/src/edm/neck/__init__.py +1 -0
- vismatch/third_party/EDM/src/edm/neck/loftr_module/__init__.py +1 -0
- vismatch/third_party/EDM/src/edm/neck/loftr_module/transformer.py +418 -0
- vismatch/third_party/EDM/src/edm/neck/neck.py +156 -0
- vismatch/third_party/EDM/src/edm/utils/geometry.py +58 -0
- vismatch/third_party/EDM/src/edm/utils/supervision.py +255 -0
- vismatch/third_party/EDM/src/lightning/data.py +450 -0
- vismatch/third_party/EDM/src/lightning/lightning_edm.py +379 -0
- vismatch/third_party/EDM/src/losses/edm_loss.py +206 -0
- vismatch/third_party/EDM/src/optimizers/__init__.py +57 -0
- vismatch/third_party/EDM/src/utils/augment.py +65 -0
- vismatch/third_party/EDM/src/utils/comm.py +271 -0
- vismatch/third_party/EDM/src/utils/dataloader.py +24 -0
- vismatch/third_party/EDM/src/utils/dataset.py +192 -0
- vismatch/third_party/EDM/src/utils/metrics.py +299 -0
- vismatch/third_party/EDM/src/utils/misc.py +113 -0
- vismatch/third_party/EDM/src/utils/plotting.py +186 -0
- vismatch/third_party/EDM/src/utils/profiler.py +40 -0
- vismatch/third_party/EDM/src/utils/warppers.py +428 -0
- vismatch/third_party/EDM/src/utils/warppers_utils.py +172 -0
- vismatch/third_party/EDM/test.py +132 -0
- vismatch/third_party/EDM/train.py +156 -0
- vismatch/third_party/EfficientLoFTR/configs/data/__init__.py +0 -0
- vismatch/third_party/EfficientLoFTR/configs/data/base.py +35 -0
- vismatch/third_party/EfficientLoFTR/configs/data/megadepth_test_1500.py +13 -0
- vismatch/third_party/EfficientLoFTR/configs/data/megadepth_trainval_832.py +24 -0
- vismatch/third_party/EfficientLoFTR/configs/data/scannet_test_1500.py +16 -0
- vismatch/third_party/EfficientLoFTR/configs/loftr/eloftr_full.py +36 -0
- vismatch/third_party/EfficientLoFTR/configs/loftr/eloftr_optimized.py +37 -0
- vismatch/third_party/EfficientLoFTR/src/__init__.py +0 -0
- vismatch/third_party/EfficientLoFTR/src/config/default.py +182 -0
- vismatch/third_party/EfficientLoFTR/src/datasets/megadepth.py +133 -0
- vismatch/third_party/EfficientLoFTR/src/datasets/sampler.py +77 -0
- vismatch/third_party/EfficientLoFTR/src/datasets/scannet.py +129 -0
- vismatch/third_party/EfficientLoFTR/src/lightning/data.py +357 -0
- vismatch/third_party/EfficientLoFTR/src/lightning/lightning_loftr.py +272 -0
- vismatch/third_party/EfficientLoFTR/src/loftr/__init__.py +4 -0
- vismatch/third_party/EfficientLoFTR/src/loftr/backbone/__init__.py +11 -0
- vismatch/third_party/EfficientLoFTR/src/loftr/backbone/backbone.py +37 -0
- vismatch/third_party/EfficientLoFTR/src/loftr/backbone/repvgg.py +224 -0
- vismatch/third_party/EfficientLoFTR/src/loftr/loftr.py +124 -0
- vismatch/third_party/EfficientLoFTR/src/loftr/loftr_module/__init__.py +2 -0
- vismatch/third_party/EfficientLoFTR/src/loftr/loftr_module/fine_preprocess.py +112 -0
- vismatch/third_party/EfficientLoFTR/src/loftr/loftr_module/linear_attention.py +103 -0
- vismatch/third_party/EfficientLoFTR/src/loftr/loftr_module/transformer.py +164 -0
- vismatch/third_party/EfficientLoFTR/src/loftr/utils/coarse_matching.py +241 -0
- vismatch/third_party/EfficientLoFTR/src/loftr/utils/fine_matching.py +156 -0
- vismatch/third_party/EfficientLoFTR/src/loftr/utils/full_config.py +50 -0
- vismatch/third_party/EfficientLoFTR/src/loftr/utils/geometry.py +54 -0
- vismatch/third_party/EfficientLoFTR/src/loftr/utils/opt_config.py +50 -0
- vismatch/third_party/EfficientLoFTR/src/loftr/utils/position_encoding.py +50 -0
- vismatch/third_party/EfficientLoFTR/src/loftr/utils/supervision.py +275 -0
- vismatch/third_party/EfficientLoFTR/src/losses/loftr_loss.py +229 -0
- vismatch/third_party/EfficientLoFTR/src/optimizers/__init__.py +42 -0
- vismatch/third_party/EfficientLoFTR/src/utils/augment.py +55 -0
- vismatch/third_party/EfficientLoFTR/src/utils/comm.py +265 -0
- vismatch/third_party/EfficientLoFTR/src/utils/dataloader.py +23 -0
- vismatch/third_party/EfficientLoFTR/src/utils/dataset.py +186 -0
- vismatch/third_party/EfficientLoFTR/src/utils/metrics.py +264 -0
- vismatch/third_party/EfficientLoFTR/src/utils/misc.py +106 -0
- vismatch/third_party/EfficientLoFTR/src/utils/plotting.py +154 -0
- vismatch/third_party/EfficientLoFTR/src/utils/profiler.py +39 -0
- vismatch/third_party/EfficientLoFTR/src/utils/warppers.py +426 -0
- vismatch/third_party/EfficientLoFTR/src/utils/warppers_utils.py +171 -0
- vismatch/third_party/EfficientLoFTR/test.py +143 -0
- vismatch/third_party/EfficientLoFTR/train.py +154 -0
- vismatch/third_party/LISRD/lisrd/__init__.py +0 -0
- vismatch/third_party/LISRD/lisrd/datasets/__init__.py +7 -0
- vismatch/third_party/LISRD/lisrd/datasets/base_dataset.py +38 -0
- vismatch/third_party/LISRD/lisrd/datasets/coco.py +148 -0
- vismatch/third_party/LISRD/lisrd/datasets/flashes.py +170 -0
- vismatch/third_party/LISRD/lisrd/datasets/hpatches.py +135 -0
- vismatch/third_party/LISRD/lisrd/datasets/mixed_dataset.py +53 -0
- vismatch/third_party/LISRD/lisrd/datasets/rdnim.py +117 -0
- vismatch/third_party/LISRD/lisrd/datasets/utils/data_augmentation.py +168 -0
- vismatch/third_party/LISRD/lisrd/datasets/utils/data_reader.py +48 -0
- vismatch/third_party/LISRD/lisrd/datasets/utils/homographies.py +215 -0
- vismatch/third_party/LISRD/lisrd/datasets/vidit.py +152 -0
- vismatch/third_party/LISRD/lisrd/evaluation/__init__.py +0 -0
- vismatch/third_party/LISRD/lisrd/evaluation/descriptor_evaluation.py +142 -0
- vismatch/third_party/LISRD/lisrd/experiment.py +129 -0
- vismatch/third_party/LISRD/lisrd/export_features.py +148 -0
- vismatch/third_party/LISRD/lisrd/models/__init__.py +7 -0
- vismatch/third_party/LISRD/lisrd/models/backbones/__init__.py +0 -0
- vismatch/third_party/LISRD/lisrd/models/backbones/net_vlad.py +62 -0
- vismatch/third_party/LISRD/lisrd/models/backbones/vgg.py +46 -0
- vismatch/third_party/LISRD/lisrd/models/base_model.py +336 -0
- vismatch/third_party/LISRD/lisrd/models/keypoint_detectors.py +34 -0
- vismatch/third_party/LISRD/lisrd/models/lisrd.py +328 -0
- vismatch/third_party/LISRD/lisrd/models/lisrd_sift.py +289 -0
- vismatch/third_party/LISRD/lisrd/third_party/super_point_magic_leap/demo_superpoint.py +734 -0
- vismatch/third_party/LISRD/lisrd/utils/geometry_utils.py +123 -0
- vismatch/third_party/LISRD/lisrd/utils/losses.py +191 -0
- vismatch/third_party/LISRD/lisrd/utils/metrics.py +66 -0
- vismatch/third_party/LISRD/lisrd/utils/pytorch_utils.py +14 -0
- vismatch/third_party/LISRD/lisrd/utils/stdout_capturing.py +81 -0
- vismatch/third_party/LISRD/notebooks/utils.py +103 -0
- vismatch/third_party/LISRD/setup.py +4 -0
- vismatch/third_party/LiftFeat/dataset/__init__.py +0 -0
- vismatch/third_party/LiftFeat/dataset/coco_augmentor.py +298 -0
- vismatch/third_party/LiftFeat/dataset/coco_wrapper.py +175 -0
- vismatch/third_party/LiftFeat/dataset/dataset_utils.py +183 -0
- vismatch/third_party/LiftFeat/dataset/megadepth.py +177 -0
- vismatch/third_party/LiftFeat/dataset/megadepth_wrapper.py +167 -0
- vismatch/third_party/LiftFeat/demo.py +116 -0
- vismatch/third_party/LiftFeat/evaluation/HPatch_evaluation.py +182 -0
- vismatch/third_party/LiftFeat/evaluation/MegaDepth1500_evaluation.py +105 -0
- vismatch/third_party/LiftFeat/evaluation/eval_utils.py +127 -0
- vismatch/third_party/LiftFeat/loss/loss.py +291 -0
- vismatch/third_party/LiftFeat/models/interpolator.py +34 -0
- vismatch/third_party/LiftFeat/models/liftfeat_wrapper.py +172 -0
- vismatch/third_party/LiftFeat/models/model.py +419 -0
- vismatch/third_party/LiftFeat/tools/demo_match_video.py +145 -0
- vismatch/third_party/LiftFeat/tools/demo_vo.py +163 -0
- vismatch/third_party/LiftFeat/train.py +369 -0
- vismatch/third_party/LiftFeat/utils/VisualOdometry.py +339 -0
- vismatch/third_party/LiftFeat/utils/__init__.py +0 -0
- vismatch/third_party/LiftFeat/utils/alike_wrapper.py +45 -0
- vismatch/third_party/LiftFeat/utils/config.py +16 -0
- vismatch/third_party/LiftFeat/utils/depth_anything_wrapper.py +150 -0
- vismatch/third_party/LiftFeat/utils/featurebooster.py +247 -0
- vismatch/third_party/LiftFeat/utils/post_process.py +21 -0
- vismatch/third_party/LightGlue/benchmark.py +255 -0
- vismatch/third_party/LightGlue/lightglue/__init__.py +7 -0
- vismatch/third_party/LightGlue/lightglue/aliked.py +760 -0
- vismatch/third_party/LightGlue/lightglue/disk.py +55 -0
- vismatch/third_party/LightGlue/lightglue/dog_hardnet.py +41 -0
- vismatch/third_party/LightGlue/lightglue/lightglue.py +662 -0
- vismatch/third_party/LightGlue/lightglue/sift.py +216 -0
- vismatch/third_party/LightGlue/lightglue/superpoint.py +227 -0
- vismatch/third_party/LightGlue/lightglue/utils.py +165 -0
- vismatch/third_party/LightGlue/lightglue/viz2d.py +203 -0
- vismatch/third_party/MINIMA/demo.py +201 -0
- vismatch/third_party/MINIMA/src/__init__.py +0 -0
- vismatch/third_party/MINIMA/src/config/default.py +203 -0
- vismatch/third_party/MINIMA/src/config/default_for_megadepth_dense.py +203 -0
- vismatch/third_party/MINIMA/src/config/default_for_megadepth_sparse.py +203 -0
- vismatch/third_party/MINIMA/src/utils/__init__.py +0 -0
- vismatch/third_party/MINIMA/src/utils/culculate_auc.py +28 -0
- vismatch/third_party/MINIMA/src/utils/data_io.py +156 -0
- vismatch/third_party/MINIMA/src/utils/data_io_loftr.py +152 -0
- vismatch/third_party/MINIMA/src/utils/data_io_roma.py +186 -0
- vismatch/third_party/MINIMA/src/utils/data_io_sp_lg.py +158 -0
- vismatch/third_party/MINIMA/src/utils/load_model.py +164 -0
- vismatch/third_party/MINIMA/src/utils/metrics.py +214 -0
- vismatch/third_party/MINIMA/src/utils/misc.py +101 -0
- vismatch/third_party/MINIMA/src/utils/plotting.py +291 -0
- vismatch/third_party/MINIMA/src/utils/sample_h.py +142 -0
- vismatch/third_party/MINIMA/test_relative_homo_depth.py +683 -0
- vismatch/third_party/MINIMA/test_relative_homo_event.py +722 -0
- vismatch/third_party/MINIMA/test_relative_homo_mmim.py +669 -0
- vismatch/third_party/MINIMA/test_relative_pose_infrared.py +500 -0
- vismatch/third_party/MINIMA/test_relative_pose_mega_1500.py +487 -0
- vismatch/third_party/MINIMA/test_relative_pose_mega_1500_syn.py +516 -0
- vismatch/third_party/MINIMA/third_party/LightGlue/benchmark.py +255 -0
- vismatch/third_party/MINIMA/third_party/LightGlue/lightglue/__init__.py +7 -0
- vismatch/third_party/MINIMA/third_party/LightGlue/lightglue/aliked.py +758 -0
- vismatch/third_party/MINIMA/third_party/LightGlue/lightglue/disk.py +55 -0
- vismatch/third_party/MINIMA/third_party/LightGlue/lightglue/dog_hardnet.py +41 -0
- vismatch/third_party/MINIMA/third_party/LightGlue/lightglue/lightglue.py +655 -0
- vismatch/third_party/MINIMA/third_party/LightGlue/lightglue/sift.py +216 -0
- vismatch/third_party/MINIMA/third_party/LightGlue/lightglue/superpoint.py +227 -0
- vismatch/third_party/MINIMA/third_party/LightGlue/lightglue/utils.py +165 -0
- vismatch/third_party/MINIMA/third_party/LightGlue/lightglue/viz2d.py +184 -0
- vismatch/third_party/MINIMA/third_party/LoFTR/configs/data/__init__.py +0 -0
- vismatch/third_party/MINIMA/third_party/LoFTR/configs/data/base.py +35 -0
- vismatch/third_party/MINIMA/third_party/LoFTR/configs/data/megadepth_test_1500.py +11 -0
- vismatch/third_party/MINIMA/third_party/LoFTR/configs/data/megadepth_trainval_640.py +22 -0
- vismatch/third_party/MINIMA/third_party/LoFTR/configs/data/megadepth_trainval_840.py +22 -0
- vismatch/third_party/MINIMA/third_party/LoFTR/configs/data/scannet_test_1500.py +11 -0
- vismatch/third_party/MINIMA/third_party/LoFTR/configs/data/scannet_trainval.py +17 -0
- vismatch/third_party/MINIMA/third_party/LoFTR/configs/loftr/indoor/buggy_pos_enc/loftr_ds.py +6 -0
- vismatch/third_party/MINIMA/third_party/LoFTR/configs/loftr/indoor/buggy_pos_enc/loftr_ds_dense.py +8 -0
- vismatch/third_party/MINIMA/third_party/LoFTR/configs/loftr/indoor/buggy_pos_enc/loftr_ot.py +6 -0
- vismatch/third_party/MINIMA/third_party/LoFTR/configs/loftr/indoor/buggy_pos_enc/loftr_ot_dense.py +8 -0
- vismatch/third_party/MINIMA/third_party/LoFTR/configs/loftr/indoor/loftr_ds.py +5 -0
- vismatch/third_party/MINIMA/third_party/LoFTR/configs/loftr/indoor/loftr_ds_dense.py +7 -0
- vismatch/third_party/MINIMA/third_party/LoFTR/configs/loftr/indoor/loftr_ot.py +5 -0
- vismatch/third_party/MINIMA/third_party/LoFTR/configs/loftr/indoor/loftr_ot_dense.py +7 -0
- vismatch/third_party/MINIMA/third_party/LoFTR/configs/loftr/indoor/scannet/loftr_ds_eval.py +16 -0
- vismatch/third_party/MINIMA/third_party/LoFTR/configs/loftr/indoor/scannet/loftr_ds_eval_new.py +18 -0
- vismatch/third_party/MINIMA/third_party/LoFTR/configs/loftr/outdoor/buggy_pos_enc/loftr_ds.py +16 -0
- vismatch/third_party/MINIMA/third_party/LoFTR/configs/loftr/outdoor/buggy_pos_enc/loftr_ds_dense.py +17 -0
- vismatch/third_party/MINIMA/third_party/LoFTR/configs/loftr/outdoor/buggy_pos_enc/loftr_ot.py +16 -0
- vismatch/third_party/MINIMA/third_party/LoFTR/configs/loftr/outdoor/buggy_pos_enc/loftr_ot_dense.py +17 -0
- vismatch/third_party/MINIMA/third_party/LoFTR/configs/loftr/outdoor/loftr_ds.py +15 -0
- vismatch/third_party/MINIMA/third_party/LoFTR/configs/loftr/outdoor/loftr_ds_dense.py +16 -0
- vismatch/third_party/MINIMA/third_party/LoFTR/configs/loftr/outdoor/loftr_ot.py +15 -0
- vismatch/third_party/MINIMA/third_party/LoFTR/configs/loftr/outdoor/loftr_ot_dense.py +16 -0
- vismatch/third_party/MINIMA/third_party/LoFTR/demo/demo_loftr.py +240 -0
- vismatch/third_party/MINIMA/third_party/LoFTR/src/__init__.py +0 -0
- vismatch/third_party/MINIMA/third_party/LoFTR/src/config/default.py +171 -0
- vismatch/third_party/MINIMA/third_party/LoFTR/src/datasets/megadepth.py +127 -0
- vismatch/third_party/MINIMA/third_party/LoFTR/src/datasets/sampler.py +77 -0
- vismatch/third_party/MINIMA/third_party/LoFTR/src/datasets/scannet.py +114 -0
- vismatch/third_party/MINIMA/third_party/LoFTR/src/lightning/data.py +320 -0
- vismatch/third_party/MINIMA/third_party/LoFTR/src/lightning/lightning_loftr.py +249 -0
- vismatch/third_party/MINIMA/third_party/LoFTR/src/loftr/__init__.py +2 -0
- vismatch/third_party/MINIMA/third_party/LoFTR/src/loftr/backbone/__init__.py +11 -0
- vismatch/third_party/MINIMA/third_party/LoFTR/src/loftr/backbone/resnet_fpn.py +199 -0
- vismatch/third_party/MINIMA/third_party/LoFTR/src/loftr/loftr.py +81 -0
- vismatch/third_party/MINIMA/third_party/LoFTR/src/loftr/loftr_module/__init__.py +2 -0
- vismatch/third_party/MINIMA/third_party/LoFTR/src/loftr/loftr_module/fine_preprocess.py +59 -0
- vismatch/third_party/MINIMA/third_party/LoFTR/src/loftr/loftr_module/linear_attention.py +81 -0
- vismatch/third_party/MINIMA/third_party/LoFTR/src/loftr/loftr_module/transformer.py +101 -0
- vismatch/third_party/MINIMA/third_party/LoFTR/src/loftr/utils/coarse_matching.py +261 -0
- vismatch/third_party/MINIMA/third_party/LoFTR/src/loftr/utils/cvpr_ds_config.py +50 -0
- vismatch/third_party/MINIMA/third_party/LoFTR/src/loftr/utils/fine_matching.py +74 -0
- vismatch/third_party/MINIMA/third_party/LoFTR/src/loftr/utils/geometry.py +54 -0
- vismatch/third_party/MINIMA/third_party/LoFTR/src/loftr/utils/position_encoding.py +42 -0
- vismatch/third_party/MINIMA/third_party/LoFTR/src/loftr/utils/supervision.py +151 -0
- vismatch/third_party/MINIMA/third_party/LoFTR/src/losses/loftr_loss.py +192 -0
- vismatch/third_party/MINIMA/third_party/LoFTR/src/optimizers/__init__.py +42 -0
- vismatch/third_party/MINIMA/third_party/LoFTR/src/utils/augment.py +55 -0
- vismatch/third_party/MINIMA/third_party/LoFTR/src/utils/comm.py +265 -0
- vismatch/third_party/MINIMA/third_party/LoFTR/src/utils/dataloader.py +23 -0
- vismatch/third_party/MINIMA/third_party/LoFTR/src/utils/dataset.py +185 -0
- vismatch/third_party/MINIMA/third_party/LoFTR/src/utils/metrics.py +193 -0
- vismatch/third_party/MINIMA/third_party/LoFTR/src/utils/misc.py +101 -0
- vismatch/third_party/MINIMA/third_party/LoFTR/src/utils/plotting.py +154 -0
- vismatch/third_party/MINIMA/third_party/LoFTR/src/utils/profiler.py +39 -0
- vismatch/third_party/MINIMA/third_party/LoFTR/test.py +68 -0
- vismatch/third_party/MINIMA/third_party/LoFTR/third_party/SuperGluePretrainedNetwork/demo_superglue.py +259 -0
- vismatch/third_party/MINIMA/third_party/LoFTR/third_party/SuperGluePretrainedNetwork/match_pairs.py +425 -0
- vismatch/third_party/MINIMA/third_party/LoFTR/third_party/SuperGluePretrainedNetwork/models/__init__.py +0 -0
- vismatch/third_party/MINIMA/third_party/LoFTR/third_party/SuperGluePretrainedNetwork/models/matching.py +84 -0
- vismatch/third_party/MINIMA/third_party/LoFTR/third_party/SuperGluePretrainedNetwork/models/superglue.py +283 -0
- vismatch/third_party/MINIMA/third_party/LoFTR/third_party/SuperGluePretrainedNetwork/models/superpoint.py +202 -0
- vismatch/third_party/MINIMA/third_party/LoFTR/third_party/SuperGluePretrainedNetwork/models/utils.py +555 -0
- vismatch/third_party/MINIMA/third_party/LoFTR/train.py +123 -0
- vismatch/third_party/MINIMA/third_party/RoMa/demo/demo_3D_effect.py +47 -0
- vismatch/third_party/MINIMA/third_party/RoMa/demo/demo_fundamental.py +34 -0
- vismatch/third_party/MINIMA/third_party/RoMa/demo/demo_match.py +50 -0
- vismatch/third_party/MINIMA/third_party/RoMa/demo/demo_match_opencv_sift.py +43 -0
- vismatch/third_party/MINIMA/third_party/RoMa/demo/demo_match_tiny.py +77 -0
- vismatch/third_party/MINIMA/third_party/RoMa/experiments/eval_roma_outdoor.py +57 -0
- vismatch/third_party/MINIMA/third_party/RoMa/experiments/eval_tiny_roma_v1_outdoor.py +84 -0
- vismatch/third_party/MINIMA/third_party/RoMa/experiments/roma_indoor.py +320 -0
- vismatch/third_party/MINIMA/third_party/RoMa/experiments/train_roma_outdoor.py +307 -0
- vismatch/third_party/MINIMA/third_party/RoMa/experiments/train_tiny_roma_v1_outdoor.py +498 -0
- vismatch/third_party/MINIMA/third_party/RoMa/romatch/__init__.py +8 -0
- vismatch/third_party/MINIMA/third_party/RoMa/romatch/benchmarks/__init__.py +6 -0
- vismatch/third_party/MINIMA/third_party/RoMa/romatch/benchmarks/hpatches_sequences_homog_benchmark.py +113 -0
- vismatch/third_party/MINIMA/third_party/RoMa/romatch/benchmarks/megadepth_dense_benchmark.py +106 -0
- vismatch/third_party/MINIMA/third_party/RoMa/romatch/benchmarks/megadepth_pose_estimation_benchmark.py +118 -0
- vismatch/third_party/MINIMA/third_party/RoMa/romatch/benchmarks/megadepth_pose_estimation_benchmark_poselib.py +119 -0
- vismatch/third_party/MINIMA/third_party/RoMa/romatch/benchmarks/scannet_benchmark.py +143 -0
- vismatch/third_party/MINIMA/third_party/RoMa/romatch/checkpointing/__init__.py +1 -0
- vismatch/third_party/MINIMA/third_party/RoMa/romatch/checkpointing/checkpoint.py +60 -0
- vismatch/third_party/MINIMA/third_party/RoMa/romatch/datasets/__init__.py +2 -0
- vismatch/third_party/MINIMA/third_party/RoMa/romatch/datasets/megadepth.py +232 -0
- vismatch/third_party/MINIMA/third_party/RoMa/romatch/datasets/scannet.py +160 -0
- vismatch/third_party/MINIMA/third_party/RoMa/romatch/losses/__init__.py +1 -0
- vismatch/third_party/MINIMA/third_party/RoMa/romatch/losses/robust_loss.py +161 -0
- vismatch/third_party/MINIMA/third_party/RoMa/romatch/losses/robust_loss_tiny_roma.py +160 -0
- vismatch/third_party/MINIMA/third_party/RoMa/romatch/models/__init__.py +1 -0
- vismatch/third_party/MINIMA/third_party/RoMa/romatch/models/encoders.py +122 -0
- vismatch/third_party/MINIMA/third_party/RoMa/romatch/models/matcher.py +766 -0
- vismatch/third_party/MINIMA/third_party/RoMa/romatch/models/model_zoo/__init__.py +73 -0
- vismatch/third_party/MINIMA/third_party/RoMa/romatch/models/model_zoo/roma_models.py +170 -0
- vismatch/third_party/MINIMA/third_party/RoMa/romatch/models/tiny.py +304 -0
- vismatch/third_party/MINIMA/third_party/RoMa/romatch/models/transformer/__init__.py +48 -0
- vismatch/third_party/MINIMA/third_party/RoMa/romatch/models/transformer/dinov2.py +359 -0
- vismatch/third_party/MINIMA/third_party/RoMa/romatch/models/transformer/layers/__init__.py +12 -0
- vismatch/third_party/MINIMA/third_party/RoMa/romatch/models/transformer/layers/attention.py +81 -0
- vismatch/third_party/MINIMA/third_party/RoMa/romatch/models/transformer/layers/block.py +252 -0
- vismatch/third_party/MINIMA/third_party/RoMa/romatch/models/transformer/layers/dino_head.py +59 -0
- vismatch/third_party/MINIMA/third_party/RoMa/romatch/models/transformer/layers/drop_path.py +35 -0
- vismatch/third_party/MINIMA/third_party/RoMa/romatch/models/transformer/layers/layer_scale.py +28 -0
- vismatch/third_party/MINIMA/third_party/RoMa/romatch/models/transformer/layers/mlp.py +41 -0
- vismatch/third_party/MINIMA/third_party/RoMa/romatch/models/transformer/layers/patch_embed.py +89 -0
- vismatch/third_party/MINIMA/third_party/RoMa/romatch/models/transformer/layers/swiglu_ffn.py +63 -0
- vismatch/third_party/MINIMA/third_party/RoMa/romatch/train/__init__.py +1 -0
- vismatch/third_party/MINIMA/third_party/RoMa/romatch/train/train.py +102 -0
- vismatch/third_party/MINIMA/third_party/RoMa/romatch/utils/__init__.py +16 -0
- vismatch/third_party/MINIMA/third_party/RoMa/romatch/utils/kde.py +13 -0
- vismatch/third_party/MINIMA/third_party/RoMa/romatch/utils/local_correlation.py +48 -0
- vismatch/third_party/MINIMA/third_party/RoMa/romatch/utils/transforms.py +118 -0
- vismatch/third_party/MINIMA/third_party/RoMa/romatch/utils/utils.py +662 -0
- vismatch/third_party/MINIMA/third_party/RoMa/setup.py +9 -0
- vismatch/third_party/MINIMA/third_party/XoFTR/configs/data/__init__.py +0 -0
- vismatch/third_party/MINIMA/third_party/XoFTR/configs/data/base.py +35 -0
- vismatch/third_party/MINIMA/third_party/XoFTR/configs/data/megadepth_trainval_840.py +22 -0
- vismatch/third_party/MINIMA/third_party/XoFTR/configs/data/megadepth_vistir_trainval_640.py +23 -0
- vismatch/third_party/MINIMA/third_party/XoFTR/configs/data/pretrain.py +8 -0
- vismatch/third_party/MINIMA/third_party/XoFTR/configs/xoftr/outdoor/visible_thermal.py +17 -0
- vismatch/third_party/MINIMA/third_party/XoFTR/configs/xoftr/pretrain/pretrain.py +12 -0
- vismatch/third_party/MINIMA/third_party/XoFTR/pretrain.py +125 -0
- vismatch/third_party/MINIMA/third_party/XoFTR/src/__init__.py +0 -0
- vismatch/third_party/MINIMA/third_party/XoFTR/src/config/default.py +203 -0
- vismatch/third_party/MINIMA/third_party/XoFTR/src/datasets/megadepth.py +143 -0
- vismatch/third_party/MINIMA/third_party/XoFTR/src/datasets/pretrain_dataset.py +156 -0
- vismatch/third_party/MINIMA/third_party/XoFTR/src/datasets/sampler.py +77 -0
- vismatch/third_party/MINIMA/third_party/XoFTR/src/datasets/scannet.py +114 -0
- vismatch/third_party/MINIMA/third_party/XoFTR/src/datasets/vistir.py +109 -0
- vismatch/third_party/MINIMA/third_party/XoFTR/src/lightning/data.py +346 -0
- vismatch/third_party/MINIMA/third_party/XoFTR/src/lightning/data_pretrain.py +125 -0
- vismatch/third_party/MINIMA/third_party/XoFTR/src/lightning/lightning_xoftr.py +334 -0
- vismatch/third_party/MINIMA/third_party/XoFTR/src/lightning/lightning_xoftr_pretrain.py +171 -0
- vismatch/third_party/MINIMA/third_party/XoFTR/src/losses/xoftr_loss.py +170 -0
- vismatch/third_party/MINIMA/third_party/XoFTR/src/losses/xoftr_loss_pretrain.py +37 -0
- vismatch/third_party/MINIMA/third_party/XoFTR/src/optimizers/__init__.py +42 -0
- vismatch/third_party/MINIMA/third_party/XoFTR/src/utils/augment.py +113 -0
- vismatch/third_party/MINIMA/third_party/XoFTR/src/utils/comm.py +265 -0
- vismatch/third_party/MINIMA/third_party/XoFTR/src/utils/data_io.py +144 -0
- vismatch/third_party/MINIMA/third_party/XoFTR/src/utils/dataloader.py +23 -0
- vismatch/third_party/MINIMA/third_party/XoFTR/src/utils/dataset.py +279 -0
- vismatch/third_party/MINIMA/third_party/XoFTR/src/utils/metrics.py +211 -0
- vismatch/third_party/MINIMA/third_party/XoFTR/src/utils/misc.py +101 -0
- vismatch/third_party/MINIMA/third_party/XoFTR/src/utils/plotting.py +227 -0
- vismatch/third_party/MINIMA/third_party/XoFTR/src/utils/pretrain_utils.py +83 -0
- vismatch/third_party/MINIMA/third_party/XoFTR/src/utils/profiler.py +39 -0
- vismatch/third_party/MINIMA/third_party/XoFTR/src/xoftr/__init__.py +2 -0
- vismatch/third_party/MINIMA/third_party/XoFTR/src/xoftr/backbone/__init__.py +1 -0
- vismatch/third_party/MINIMA/third_party/XoFTR/src/xoftr/backbone/resnet.py +95 -0
- vismatch/third_party/MINIMA/third_party/XoFTR/src/xoftr/utils/geometry.py +107 -0
- vismatch/third_party/MINIMA/third_party/XoFTR/src/xoftr/utils/position_encoding.py +36 -0
- vismatch/third_party/MINIMA/third_party/XoFTR/src/xoftr/utils/supervision.py +290 -0
- vismatch/third_party/MINIMA/third_party/XoFTR/src/xoftr/xoftr.py +94 -0
- vismatch/third_party/MINIMA/third_party/XoFTR/src/xoftr/xoftr_module/__init__.py +4 -0
- vismatch/third_party/MINIMA/third_party/XoFTR/src/xoftr/xoftr_module/coarse_matching.py +305 -0
- vismatch/third_party/MINIMA/third_party/XoFTR/src/xoftr/xoftr_module/fine_matching.py +170 -0
- vismatch/third_party/MINIMA/third_party/XoFTR/src/xoftr/xoftr_module/fine_process.py +321 -0
- vismatch/third_party/MINIMA/third_party/XoFTR/src/xoftr/xoftr_module/linear_attention.py +81 -0
- vismatch/third_party/MINIMA/third_party/XoFTR/src/xoftr/xoftr_module/transformer.py +101 -0
- vismatch/third_party/MINIMA/third_party/XoFTR/src/xoftr/xoftr_pretrain.py +209 -0
- vismatch/third_party/MINIMA/third_party/XoFTR/test.py +68 -0
- vismatch/third_party/MINIMA/third_party/XoFTR/test_relative_pose.py +330 -0
- vismatch/third_party/MINIMA/third_party/XoFTR/train.py +126 -0
- vismatch/third_party/MatchAnything/app.py +27 -0
- vismatch/third_party/MatchAnything/imcui/__init__.py +0 -0
- vismatch/third_party/MatchAnything/imcui/api/__init__.py +47 -0
- vismatch/third_party/MatchAnything/imcui/api/client.py +232 -0
- vismatch/third_party/MatchAnything/imcui/api/core.py +308 -0
- vismatch/third_party/MatchAnything/imcui/api/server.py +170 -0
- vismatch/third_party/MatchAnything/imcui/hloc/__init__.py +65 -0
- vismatch/third_party/MatchAnything/imcui/hloc/colmap_from_nvm.py +216 -0
- vismatch/third_party/MatchAnything/imcui/hloc/extract_features.py +607 -0
- vismatch/third_party/MatchAnything/imcui/hloc/extractors/__init__.py +0 -0
- vismatch/third_party/MatchAnything/imcui/hloc/extractors/alike.py +61 -0
- vismatch/third_party/MatchAnything/imcui/hloc/extractors/aliked.py +32 -0
- vismatch/third_party/MatchAnything/imcui/hloc/extractors/cosplace.py +44 -0
- vismatch/third_party/MatchAnything/imcui/hloc/extractors/d2net.py +60 -0
- vismatch/third_party/MatchAnything/imcui/hloc/extractors/darkfeat.py +44 -0
- vismatch/third_party/MatchAnything/imcui/hloc/extractors/dedode.py +86 -0
- vismatch/third_party/MatchAnything/imcui/hloc/extractors/dir.py +78 -0
- vismatch/third_party/MatchAnything/imcui/hloc/extractors/disk.py +35 -0
- vismatch/third_party/MatchAnything/imcui/hloc/extractors/dog.py +135 -0
- vismatch/third_party/MatchAnything/imcui/hloc/extractors/eigenplaces.py +57 -0
- vismatch/third_party/MatchAnything/imcui/hloc/extractors/example.py +56 -0
- vismatch/third_party/MatchAnything/imcui/hloc/extractors/fire.py +72 -0
- vismatch/third_party/MatchAnything/imcui/hloc/extractors/fire_local.py +84 -0
- vismatch/third_party/MatchAnything/imcui/hloc/extractors/lanet.py +63 -0
- vismatch/third_party/MatchAnything/imcui/hloc/extractors/netvlad.py +146 -0
- vismatch/third_party/MatchAnything/imcui/hloc/extractors/openibl.py +26 -0
- vismatch/third_party/MatchAnything/imcui/hloc/extractors/r2d2.py +73 -0
- vismatch/third_party/MatchAnything/imcui/hloc/extractors/rekd.py +60 -0
- vismatch/third_party/MatchAnything/imcui/hloc/extractors/rord.py +59 -0
- vismatch/third_party/MatchAnything/imcui/hloc/extractors/sfd2.py +44 -0
- vismatch/third_party/MatchAnything/imcui/hloc/extractors/sift.py +216 -0
- vismatch/third_party/MatchAnything/imcui/hloc/extractors/superpoint.py +51 -0
- vismatch/third_party/MatchAnything/imcui/hloc/extractors/xfeat.py +33 -0
- vismatch/third_party/MatchAnything/imcui/hloc/localize_inloc.py +179 -0
- vismatch/third_party/MatchAnything/imcui/hloc/localize_sfm.py +243 -0
- vismatch/third_party/MatchAnything/imcui/hloc/match_dense.py +1158 -0
- vismatch/third_party/MatchAnything/imcui/hloc/match_features.py +459 -0
- vismatch/third_party/MatchAnything/imcui/hloc/matchers/__init__.py +3 -0
- vismatch/third_party/MatchAnything/imcui/hloc/matchers/adalam.py +68 -0
- vismatch/third_party/MatchAnything/imcui/hloc/matchers/aspanformer.py +66 -0
- vismatch/third_party/MatchAnything/imcui/hloc/matchers/cotr.py +77 -0
- vismatch/third_party/MatchAnything/imcui/hloc/matchers/dkm.py +53 -0
- vismatch/third_party/MatchAnything/imcui/hloc/matchers/dual_softmax.py +71 -0
- vismatch/third_party/MatchAnything/imcui/hloc/matchers/duster.py +109 -0
- vismatch/third_party/MatchAnything/imcui/hloc/matchers/eloftr.py +97 -0
- vismatch/third_party/MatchAnything/imcui/hloc/matchers/gim.py +200 -0
- vismatch/third_party/MatchAnything/imcui/hloc/matchers/gluestick.py +99 -0
- vismatch/third_party/MatchAnything/imcui/hloc/matchers/imp.py +50 -0
- vismatch/third_party/MatchAnything/imcui/hloc/matchers/lightglue.py +67 -0
- vismatch/third_party/MatchAnything/imcui/hloc/matchers/loftr.py +58 -0
- vismatch/third_party/MatchAnything/imcui/hloc/matchers/mast3r.py +96 -0
- vismatch/third_party/MatchAnything/imcui/hloc/matchers/matchanything.py +191 -0
- vismatch/third_party/MatchAnything/imcui/hloc/matchers/mickey.py +50 -0
- vismatch/third_party/MatchAnything/imcui/hloc/matchers/nearest_neighbor.py +66 -0
- vismatch/third_party/MatchAnything/imcui/hloc/matchers/omniglue.py +80 -0
- vismatch/third_party/MatchAnything/imcui/hloc/matchers/roma.py +80 -0
- vismatch/third_party/MatchAnything/imcui/hloc/matchers/sgmnet.py +106 -0
- vismatch/third_party/MatchAnything/imcui/hloc/matchers/sold2.py +144 -0
- vismatch/third_party/MatchAnything/imcui/hloc/matchers/superglue.py +33 -0
- vismatch/third_party/MatchAnything/imcui/hloc/matchers/topicfm.py +60 -0
- vismatch/third_party/MatchAnything/imcui/hloc/matchers/xfeat_dense.py +54 -0
- vismatch/third_party/MatchAnything/imcui/hloc/matchers/xfeat_lightglue.py +48 -0
- vismatch/third_party/MatchAnything/imcui/hloc/matchers/xoftr.py +90 -0
- vismatch/third_party/MatchAnything/imcui/hloc/pairs_from_covisibility.py +60 -0
- vismatch/third_party/MatchAnything/imcui/hloc/pairs_from_exhaustive.py +64 -0
- vismatch/third_party/MatchAnything/imcui/hloc/pairs_from_poses.py +68 -0
- vismatch/third_party/MatchAnything/imcui/hloc/pairs_from_retrieval.py +133 -0
- vismatch/third_party/MatchAnything/imcui/hloc/pipelines/4Seasons/__init__.py +0 -0
- vismatch/third_party/MatchAnything/imcui/hloc/pipelines/4Seasons/localize.py +89 -0
- vismatch/third_party/MatchAnything/imcui/hloc/pipelines/4Seasons/prepare_reference.py +51 -0
- vismatch/third_party/MatchAnything/imcui/hloc/pipelines/4Seasons/utils.py +231 -0
- vismatch/third_party/MatchAnything/imcui/hloc/pipelines/7Scenes/__init__.py +0 -0
- vismatch/third_party/MatchAnything/imcui/hloc/pipelines/7Scenes/create_gt_sfm.py +134 -0
- vismatch/third_party/MatchAnything/imcui/hloc/pipelines/7Scenes/pipeline.py +139 -0
- vismatch/third_party/MatchAnything/imcui/hloc/pipelines/7Scenes/utils.py +34 -0
- vismatch/third_party/MatchAnything/imcui/hloc/pipelines/Aachen/__init__.py +0 -0
- vismatch/third_party/MatchAnything/imcui/hloc/pipelines/Aachen/pipeline.py +109 -0
- vismatch/third_party/MatchAnything/imcui/hloc/pipelines/Aachen_v1_1/__init__.py +0 -0
- vismatch/third_party/MatchAnything/imcui/hloc/pipelines/Aachen_v1_1/pipeline.py +104 -0
- vismatch/third_party/MatchAnything/imcui/hloc/pipelines/Aachen_v1_1/pipeline_loftr.py +104 -0
- vismatch/third_party/MatchAnything/imcui/hloc/pipelines/CMU/__init__.py +0 -0
- vismatch/third_party/MatchAnything/imcui/hloc/pipelines/CMU/pipeline.py +133 -0
- vismatch/third_party/MatchAnything/imcui/hloc/pipelines/Cambridge/__init__.py +0 -0
- vismatch/third_party/MatchAnything/imcui/hloc/pipelines/Cambridge/pipeline.py +140 -0
- vismatch/third_party/MatchAnything/imcui/hloc/pipelines/Cambridge/utils.py +145 -0
- vismatch/third_party/MatchAnything/imcui/hloc/pipelines/RobotCar/__init__.py +0 -0
- vismatch/third_party/MatchAnything/imcui/hloc/pipelines/RobotCar/colmap_from_nvm.py +176 -0
- vismatch/third_party/MatchAnything/imcui/hloc/pipelines/RobotCar/pipeline.py +143 -0
- vismatch/third_party/MatchAnything/imcui/hloc/pipelines/__init__.py +0 -0
- vismatch/third_party/MatchAnything/imcui/hloc/reconstruction.py +194 -0
- vismatch/third_party/MatchAnything/imcui/hloc/triangulation.py +311 -0
- vismatch/third_party/MatchAnything/imcui/hloc/utils/__init__.py +12 -0
- vismatch/third_party/MatchAnything/imcui/hloc/utils/base_model.py +56 -0
- vismatch/third_party/MatchAnything/imcui/hloc/utils/database.py +412 -0
- vismatch/third_party/MatchAnything/imcui/hloc/utils/geometry.py +16 -0
- vismatch/third_party/MatchAnything/imcui/hloc/utils/io.py +77 -0
- vismatch/third_party/MatchAnything/imcui/hloc/utils/parsers.py +59 -0
- vismatch/third_party/MatchAnything/imcui/hloc/utils/read_write_model.py +588 -0
- vismatch/third_party/MatchAnything/imcui/hloc/utils/viz.py +146 -0
- vismatch/third_party/MatchAnything/imcui/hloc/utils/viz_3d.py +203 -0
- vismatch/third_party/MatchAnything/imcui/hloc/visualization.py +178 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/configs/models/eloftr_model.py +128 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/configs/models/roma_model.py +27 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/notebooks/notebooks_utils/__init__.py +1 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/notebooks/notebooks_utils/plotting.py +344 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/__init__.py +0 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/config/default.py +344 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/datasets/common_data_pair.py +214 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/lightning/lightning_loftr.py +343 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/loftr/__init__.py +1 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/loftr/backbone/__init__.py +61 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/loftr/backbone/repvgg.py +319 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/loftr/backbone/resnet_fpn.py +1094 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/loftr/backbone/s2dnet.py +131 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/loftr/loftr.py +273 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/loftr/loftr_module/__init__.py +2 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/loftr/loftr_module/fine_preprocess.py +350 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/loftr/loftr_module/linear_attention.py +217 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/loftr/loftr_module/transformer.py +1768 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/loftr/loftr_module/transformer_utils.py +76 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/loftr/utils/coarse_matching.py +266 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/loftr/utils/fine_matching.py +493 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/loftr/utils/geometry.py +298 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/loftr/utils/position_encoding.py +131 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/loftr/utils/supervision.py +475 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/optimizers/__init__.py +50 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/utils/__init__.py +0 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/utils/augment.py +55 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/utils/colmap/__init__.py +0 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/utils/colmap/database.py +417 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/utils/colmap/eval_helper.py +232 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/utils/colmap/read_write_model.py +509 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/utils/colmap.py +530 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/utils/comm.py +265 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/utils/dataloader.py +23 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/utils/dataset.py +518 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/utils/easydict.py +148 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/utils/geometry.py +366 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/utils/homography_utils.py +366 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/utils/metrics.py +445 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/utils/misc.py +101 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/utils/plotting.py +248 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/utils/profiler.py +39 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/utils/ray_utils.py +134 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/utils/sample_homo.py +58 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/utils/utils.py +600 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/demo/demo_3D_effect.py +46 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/demo/demo_fundamental.py +32 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/demo/demo_fundamental_model_warpper.py +34 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/demo/demo_match.py +50 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/demo/demo_match_opencv_sift.py +43 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/demo_single_pair.py +329 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/experiments/roma_indoor.py +320 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/experiments/roma_outdoor.py +327 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/notebooks/notebooks_utils/__init__.py +1 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/notebooks/notebooks_utils/plotting.py +331 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/__init__.py +8 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/benchmarks/__init__.py +4 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/benchmarks/hpatches_sequences_homog_benchmark.py +113 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/benchmarks/megadepth_dense_benchmark.py +106 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/benchmarks/megadepth_pose_estimation_benchmark.py +140 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/benchmarks/scannet_benchmark.py +143 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/checkpointing/__init__.py +1 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/checkpointing/checkpoint.py +60 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/datasets/__init__.py +2 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/datasets/megadepth.py +230 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/datasets/scannet.py +160 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/losses/__init__.py +1 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/losses/robust_loss.py +157 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/matchanything_roma_model.py +104 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/__init__.py +1 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/croco/blocks.py +241 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/croco/criterion.py +37 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/croco/croco.py +253 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/croco/croco_downstream.py +122 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/croco/curope/__init__.py +4 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/croco/curope/curope2d.py +40 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/croco/curope/setup.py +34 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/croco/dpt_block.py +450 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/croco/head_downstream.py +58 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/croco/masking.py +25 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/croco/pos_embed.py +159 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/__init__.py +2 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/cloud_opt/__init__.py +29 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/cloud_opt/base_opt.py +375 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/cloud_opt/commons.py +90 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/cloud_opt/init_im_poses.py +312 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/cloud_opt/optimizer.py +230 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/cloud_opt/pair_viewer.py +125 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/datasets/__init__.py +42 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/datasets/base/__init__.py +2 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/datasets/base/base_stereo_view_dataset.py +220 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/datasets/base/batched_sampler.py +74 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/datasets/base/easy_dataset.py +157 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/datasets/co3d.py +146 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/datasets/utils/__init__.py +2 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/datasets/utils/cropping.py +119 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/datasets/utils/transforms.py +11 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/heads/__init__.py +19 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/heads/dpt_head.py +114 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/heads/linear_head.py +41 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/heads/postprocess.py +58 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/image_pairs.py +83 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/inference.py +165 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/losses.py +297 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/model.py +167 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/optim_factory.py +14 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/patch_embed.py +70 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/post_process.py +60 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/utils/__init__.py +2 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/utils/device.py +76 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/utils/geometry.py +361 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/utils/image.py +104 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/utils/misc.py +121 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/utils/path_to_croco.py +19 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/viz.py +320 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/encoders.py +137 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/matcher.py +937 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/model_zoo/__init__.py +53 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/model_zoo/roma_models.py +162 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/transformer/__init__.py +47 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/transformer/dinov2.py +359 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/transformer/layers/__init__.py +12 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/transformer/layers/attention.py +81 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/transformer/layers/block.py +252 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/transformer/layers/dino_head.py +59 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/transformer/layers/drop_path.py +35 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/transformer/layers/layer_scale.py +28 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/transformer/layers/mlp.py +41 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/transformer/layers/patch_embed.py +89 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/transformer/layers/swiglu_ffn.py +63 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/roma_adpat_model.py +32 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/train/__init__.py +1 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/train/train.py +102 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/utils/__init__.py +18 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/utils/kde.py +8 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/utils/local_correlation.py +44 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/utils/transforms.py +118 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/utils/utils.py +661 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/setup.py +9 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/tools/__init__.py +0 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/tools/evaluate_datasets.py +239 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/tools/tools_utils/data_io.py +94 -0
- vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/tools/tools_utils/plot.py +77 -0
- vismatch/third_party/MatchAnything/imcui/ui/__init__.py +5 -0
- vismatch/third_party/MatchAnything/imcui/ui/app_class.py +824 -0
- vismatch/third_party/MatchAnything/imcui/ui/sfm.py +164 -0
- vismatch/third_party/MatchAnything/imcui/ui/utils.py +1085 -0
- vismatch/third_party/MatchAnything/imcui/ui/viz.py +511 -0
- vismatch/third_party/MatchAnything/tests/test_basic.py +111 -0
- vismatch/third_party/MatchFormer/config/data/__init__.py +0 -0
- vismatch/third_party/MatchFormer/config/data/base.py +35 -0
- vismatch/third_party/MatchFormer/config/data/megadepth_test_1500.py +11 -0
- vismatch/third_party/MatchFormer/config/data/scannet_test_1500.py +11 -0
- vismatch/third_party/MatchFormer/config/defaultmf.py +88 -0
- vismatch/third_party/MatchFormer/model/backbone/__init__.py +17 -0
- vismatch/third_party/MatchFormer/model/backbone/coarse_matching.py +228 -0
- vismatch/third_party/MatchFormer/model/backbone/fine_matching.py +74 -0
- vismatch/third_party/MatchFormer/model/backbone/fine_preprocess.py +59 -0
- vismatch/third_party/MatchFormer/model/backbone/match_LA_large.py +254 -0
- vismatch/third_party/MatchFormer/model/backbone/match_LA_lite.py +254 -0
- vismatch/third_party/MatchFormer/model/backbone/match_SEA_large.py +291 -0
- vismatch/third_party/MatchFormer/model/backbone/match_SEA_lite.py +291 -0
- vismatch/third_party/MatchFormer/model/data.py +320 -0
- vismatch/third_party/MatchFormer/model/datasets/dataset.py +231 -0
- vismatch/third_party/MatchFormer/model/datasets/megadepth.py +126 -0
- vismatch/third_party/MatchFormer/model/datasets/sampler.py +77 -0
- vismatch/third_party/MatchFormer/model/datasets/scannet.py +113 -0
- vismatch/third_party/MatchFormer/model/lightning_loftr.py +102 -0
- vismatch/third_party/MatchFormer/model/matchformer.py +54 -0
- vismatch/third_party/MatchFormer/model/utils/augment.py +55 -0
- vismatch/third_party/MatchFormer/model/utils/comm.py +265 -0
- vismatch/third_party/MatchFormer/model/utils/dataloader.py +23 -0
- vismatch/third_party/MatchFormer/model/utils/metrics.py +193 -0
- vismatch/third_party/MatchFormer/model/utils/misc.py +101 -0
- vismatch/third_party/MatchFormer/model/utils/profiler.py +39 -0
- vismatch/third_party/MatchFormer/test.py +55 -0
- vismatch/third_party/RIPE/app.py +272 -0
- vismatch/third_party/RIPE/demo.py +51 -0
- vismatch/third_party/RIPE/ripe/__init__.py +1 -0
- vismatch/third_party/RIPE/ripe/benchmarks/imw_2020.py +320 -0
- vismatch/third_party/RIPE/ripe/data/__init__.py +0 -0
- vismatch/third_party/RIPE/ripe/data/data_transforms.py +204 -0
- vismatch/third_party/RIPE/ripe/data/datasets/__init__.py +0 -0
- vismatch/third_party/RIPE/ripe/data/datasets/acdc.py +154 -0
- vismatch/third_party/RIPE/ripe/data/datasets/dataset_combinator.py +88 -0
- vismatch/third_party/RIPE/ripe/data/datasets/disk_imw.py +160 -0
- vismatch/third_party/RIPE/ripe/data/datasets/disk_megadepth.py +157 -0
- vismatch/third_party/RIPE/ripe/data/datasets/tokyo247.py +134 -0
- vismatch/third_party/RIPE/ripe/data/datasets/tokyo_query_v3.py +78 -0
- vismatch/third_party/RIPE/ripe/losses/__init__.py +0 -0
- vismatch/third_party/RIPE/ripe/losses/contrastive_loss.py +88 -0
- vismatch/third_party/RIPE/ripe/matcher/__init__.py +0 -0
- vismatch/third_party/RIPE/ripe/matcher/concurrent_matcher.py +97 -0
- vismatch/third_party/RIPE/ripe/matcher/pose_estimator_poselib.py +31 -0
- vismatch/third_party/RIPE/ripe/model_zoo/__init__.py +1 -0
- vismatch/third_party/RIPE/ripe/model_zoo/vgg_hyper.py +39 -0
- vismatch/third_party/RIPE/ripe/models/__init__.py +0 -0
- vismatch/third_party/RIPE/ripe/models/backbones/__init__.py +0 -0
- vismatch/third_party/RIPE/ripe/models/backbones/backbone_base.py +61 -0
- vismatch/third_party/RIPE/ripe/models/backbones/vgg.py +99 -0
- vismatch/third_party/RIPE/ripe/models/backbones/vgg_utils.py +143 -0
- vismatch/third_party/RIPE/ripe/models/ripe.py +303 -0
- vismatch/third_party/RIPE/ripe/models/upsampler/hypercolumn_features.py +54 -0
- vismatch/third_party/RIPE/ripe/models/upsampler/interpolate_sparse2d.py +37 -0
- vismatch/third_party/RIPE/ripe/scheduler/__init__.py +0 -0
- vismatch/third_party/RIPE/ripe/scheduler/constant.py +6 -0
- vismatch/third_party/RIPE/ripe/scheduler/expDecay.py +26 -0
- vismatch/third_party/RIPE/ripe/scheduler/linearLR.py +37 -0
- vismatch/third_party/RIPE/ripe/scheduler/linear_with_plateaus.py +44 -0
- vismatch/third_party/RIPE/ripe/train.py +410 -0
- vismatch/third_party/RIPE/ripe/utils/__init__.py +2 -0
- vismatch/third_party/RIPE/ripe/utils/image_utils.py +62 -0
- vismatch/third_party/RIPE/ripe/utils/pose_error.py +62 -0
- vismatch/third_party/RIPE/ripe/utils/pylogger.py +32 -0
- vismatch/third_party/RIPE/ripe/utils/utils.py +192 -0
- vismatch/third_party/RIPE/ripe/utils/wandb_utils.py +16 -0
- vismatch/third_party/RoMa/demo/demo_3D_effect.py +47 -0
- vismatch/third_party/RoMa/demo/demo_fundamental.py +34 -0
- vismatch/third_party/RoMa/demo/demo_match.py +50 -0
- vismatch/third_party/RoMa/demo/demo_match_opencv_sift.py +43 -0
- vismatch/third_party/RoMa/demo/demo_match_tiny.py +77 -0
- vismatch/third_party/RoMa/experiments/eval_roma_outdoor.py +57 -0
- vismatch/third_party/RoMa/experiments/eval_tiny_roma_v1_outdoor.py +84 -0
- vismatch/third_party/RoMa/experiments/roma_indoor.py +320 -0
- vismatch/third_party/RoMa/experiments/train_roma_outdoor.py +307 -0
- vismatch/third_party/RoMa/experiments/train_tiny_roma_v1_outdoor.py +498 -0
- vismatch/third_party/RoMa/romatch/__init__.py +8 -0
- vismatch/third_party/RoMa/romatch/benchmarks/__init__.py +6 -0
- vismatch/third_party/RoMa/romatch/benchmarks/hpatches_sequences_homog_benchmark.py +113 -0
- vismatch/third_party/RoMa/romatch/benchmarks/megadepth_dense_benchmark.py +106 -0
- vismatch/third_party/RoMa/romatch/benchmarks/megadepth_pose_estimation_benchmark.py +118 -0
- vismatch/third_party/RoMa/romatch/benchmarks/megadepth_pose_estimation_benchmark_poselib.py +119 -0
- vismatch/third_party/RoMa/romatch/benchmarks/scannet_benchmark.py +143 -0
- vismatch/third_party/RoMa/romatch/checkpointing/__init__.py +1 -0
- vismatch/third_party/RoMa/romatch/checkpointing/checkpoint.py +60 -0
- vismatch/third_party/RoMa/romatch/datasets/__init__.py +2 -0
- vismatch/third_party/RoMa/romatch/datasets/megadepth.py +232 -0
- vismatch/third_party/RoMa/romatch/datasets/scannet.py +160 -0
- vismatch/third_party/RoMa/romatch/losses/__init__.py +1 -0
- vismatch/third_party/RoMa/romatch/losses/robust_loss.py +161 -0
- vismatch/third_party/RoMa/romatch/losses/robust_loss_tiny_roma.py +160 -0
- vismatch/third_party/RoMa/romatch/models/__init__.py +1 -0
- vismatch/third_party/RoMa/romatch/models/encoders.py +122 -0
- vismatch/third_party/RoMa/romatch/models/matcher.py +748 -0
- vismatch/third_party/RoMa/romatch/models/model_zoo/__init__.py +73 -0
- vismatch/third_party/RoMa/romatch/models/model_zoo/roma_models.py +170 -0
- vismatch/third_party/RoMa/romatch/models/tiny.py +304 -0
- vismatch/third_party/RoMa/romatch/models/transformer/__init__.py +48 -0
- vismatch/third_party/RoMa/romatch/models/transformer/dinov2.py +359 -0
- vismatch/third_party/RoMa/romatch/models/transformer/layers/__init__.py +12 -0
- vismatch/third_party/RoMa/romatch/models/transformer/layers/attention.py +81 -0
- vismatch/third_party/RoMa/romatch/models/transformer/layers/block.py +252 -0
- vismatch/third_party/RoMa/romatch/models/transformer/layers/dino_head.py +59 -0
- vismatch/third_party/RoMa/romatch/models/transformer/layers/drop_path.py +35 -0
- vismatch/third_party/RoMa/romatch/models/transformer/layers/layer_scale.py +28 -0
- vismatch/third_party/RoMa/romatch/models/transformer/layers/mlp.py +41 -0
- vismatch/third_party/RoMa/romatch/models/transformer/layers/patch_embed.py +89 -0
- vismatch/third_party/RoMa/romatch/models/transformer/layers/swiglu_ffn.py +63 -0
- vismatch/third_party/RoMa/romatch/train/__init__.py +1 -0
- vismatch/third_party/RoMa/romatch/train/train.py +102 -0
- vismatch/third_party/RoMa/romatch/utils/__init__.py +16 -0
- vismatch/third_party/RoMa/romatch/utils/kde.py +13 -0
- vismatch/third_party/RoMa/romatch/utils/local_correlation.py +48 -0
- vismatch/third_party/RoMa/romatch/utils/transforms.py +118 -0
- vismatch/third_party/RoMa/romatch/utils/utils.py +654 -0
- vismatch/third_party/RoMa/setup.py +9 -0
- vismatch/third_party/RoMaV2/demo/demo_covariance.py +52 -0
- vismatch/third_party/RoMaV2/demo/demo_match.py +55 -0
- vismatch/third_party/RoMaV2/src/romav2/__init__.py +8 -0
- vismatch/third_party/RoMaV2/src/romav2/benchmarks/__init__.py +4 -0
- vismatch/third_party/RoMaV2/src/romav2/benchmarks/mega1500.py +115 -0
- vismatch/third_party/RoMaV2/src/romav2/benchmarks/satast.py +463 -0
- vismatch/third_party/RoMaV2/src/romav2/benchmarks/scannet1500.py +125 -0
- vismatch/third_party/RoMaV2/src/romav2/benchmarks/wxbs.py +104 -0
- vismatch/third_party/RoMaV2/src/romav2/device.py +9 -0
- vismatch/third_party/RoMaV2/src/romav2/dpt.py +516 -0
- vismatch/third_party/RoMaV2/src/romav2/features.py +190 -0
- vismatch/third_party/RoMaV2/src/romav2/geometry.py +261 -0
- vismatch/third_party/RoMaV2/src/romav2/io.py +24 -0
- vismatch/third_party/RoMaV2/src/romav2/local_correlation.py +152 -0
- vismatch/third_party/RoMaV2/src/romav2/logging.py +97 -0
- vismatch/third_party/RoMaV2/src/romav2/matcher.py +207 -0
- vismatch/third_party/RoMaV2/src/romav2/normalizers.py +17 -0
- vismatch/third_party/RoMaV2/src/romav2/refiner.py +277 -0
- vismatch/third_party/RoMaV2/src/romav2/romav2.py +533 -0
- vismatch/third_party/RoMaV2/src/romav2/types.py +75 -0
- vismatch/third_party/RoMaV2/src/romav2/vis.py +36 -0
- vismatch/third_party/RoMaV2/src/romav2/vit/__init__.py +304 -0
- vismatch/third_party/RoMaV2/src/romav2/vit/attention.py +181 -0
- vismatch/third_party/RoMaV2/src/romav2/vit/block.py +293 -0
- vismatch/third_party/RoMaV2/src/romav2/vit/ffn_layers.py +83 -0
- vismatch/third_party/RoMaV2/src/romav2/vit/layer_scale.py +29 -0
- vismatch/third_party/RoMaV2/src/romav2/vit/patch_embed.py +94 -0
- vismatch/third_party/RoMaV2/src/romav2/vit/rms_norm.py +24 -0
- vismatch/third_party/RoMaV2/src/romav2/vit/rope.py +133 -0
- vismatch/third_party/RoMaV2/src/romav2/vit/rope_mixed.py +111 -0
- vismatch/third_party/RoMaV2/src/romav2/vit/utils.py +48 -0
- vismatch/third_party/RoMaV2/tests/test_bidirectional.py +93 -0
- vismatch/third_party/RoMaV2/tests/test_fps.py +49 -0
- vismatch/third_party/RoMaV2/tests/test_mega1500.py +22 -0
- vismatch/third_party/RoMaV2/tests/test_scannet1500.py +21 -0
- vismatch/third_party/RoMaV2/tests/test_smoke.py +15 -0
- vismatch/third_party/Se2_LoFTR/configs/data/__init__.py +0 -0
- vismatch/third_party/Se2_LoFTR/configs/data/base.py +35 -0
- vismatch/third_party/Se2_LoFTR/configs/data/megadepth_test_1500.py +11 -0
- vismatch/third_party/Se2_LoFTR/configs/data/megadepth_trainval_640.py +22 -0
- vismatch/third_party/Se2_LoFTR/configs/data/megadepth_trainval_840.py +22 -0
- vismatch/third_party/Se2_LoFTR/configs/data/scannet_test_1500.py +11 -0
- vismatch/third_party/Se2_LoFTR/configs/data/scannet_trainval.py +17 -0
- vismatch/third_party/Se2_LoFTR/configs/loftr/indoor/buggy_pos_enc/loftr_ds.py +6 -0
- vismatch/third_party/Se2_LoFTR/configs/loftr/indoor/buggy_pos_enc/loftr_ds_dense.py +8 -0
- vismatch/third_party/Se2_LoFTR/configs/loftr/indoor/buggy_pos_enc/loftr_ot.py +6 -0
- vismatch/third_party/Se2_LoFTR/configs/loftr/indoor/buggy_pos_enc/loftr_ot_dense.py +8 -0
- vismatch/third_party/Se2_LoFTR/configs/loftr/indoor/loftr_ds.py +5 -0
- vismatch/third_party/Se2_LoFTR/configs/loftr/indoor/loftr_ds_dense.py +7 -0
- vismatch/third_party/Se2_LoFTR/configs/loftr/indoor/loftr_ot.py +5 -0
- vismatch/third_party/Se2_LoFTR/configs/loftr/indoor/loftr_ot_dense.py +7 -0
- vismatch/third_party/Se2_LoFTR/configs/loftr/indoor/scannet/loftr_ds_eval.py +16 -0
- vismatch/third_party/Se2_LoFTR/configs/loftr/indoor/scannet/loftr_ds_eval_new.py +18 -0
- vismatch/third_party/Se2_LoFTR/configs/loftr/outdoor/buggy_pos_enc/loftr_ds.py +16 -0
- vismatch/third_party/Se2_LoFTR/configs/loftr/outdoor/buggy_pos_enc/loftr_ds_dense.py +17 -0
- vismatch/third_party/Se2_LoFTR/configs/loftr/outdoor/buggy_pos_enc/loftr_ot.py +16 -0
- vismatch/third_party/Se2_LoFTR/configs/loftr/outdoor/buggy_pos_enc/loftr_ot_dense.py +17 -0
- vismatch/third_party/Se2_LoFTR/configs/loftr/outdoor/loftr_ds.py +17 -0
- vismatch/third_party/Se2_LoFTR/configs/loftr/outdoor/loftr_ds_dense.py +17 -0
- vismatch/third_party/Se2_LoFTR/configs/loftr/outdoor/loftr_ds_e2.py +20 -0
- vismatch/third_party/Se2_LoFTR/configs/loftr/outdoor/loftr_ds_e2_dense.py +23 -0
- vismatch/third_party/Se2_LoFTR/configs/loftr/outdoor/loftr_ds_e2_dense_8rot.py +23 -0
- vismatch/third_party/Se2_LoFTR/configs/loftr/outdoor/loftr_ds_e2_dense_big.py +22 -0
- vismatch/third_party/Se2_LoFTR/configs/loftr/outdoor/loftr_ot.py +17 -0
- vismatch/third_party/Se2_LoFTR/configs/loftr/outdoor/loftr_ot_dense.py +18 -0
- vismatch/third_party/Se2_LoFTR/demo/demo_loftr.py +240 -0
- vismatch/third_party/Se2_LoFTR/src/__init__.py +0 -0
- vismatch/third_party/Se2_LoFTR/src/config/default.py +173 -0
- vismatch/third_party/Se2_LoFTR/src/datasets/megadepth.py +127 -0
- vismatch/third_party/Se2_LoFTR/src/datasets/sampler.py +77 -0
- vismatch/third_party/Se2_LoFTR/src/datasets/scannet.py +114 -0
- vismatch/third_party/Se2_LoFTR/src/lightning/data.py +320 -0
- vismatch/third_party/Se2_LoFTR/src/lightning/lightning_loftr.py +249 -0
- vismatch/third_party/Se2_LoFTR/src/loftr/__init__.py +2 -0
- vismatch/third_party/Se2_LoFTR/src/loftr/backbone/__init__.py +17 -0
- vismatch/third_party/Se2_LoFTR/src/loftr/backbone/resnet_e2.py +170 -0
- vismatch/third_party/Se2_LoFTR/src/loftr/backbone/resnet_fpn.py +199 -0
- vismatch/third_party/Se2_LoFTR/src/loftr/loftr.py +81 -0
- vismatch/third_party/Se2_LoFTR/src/loftr/loftr_module/__init__.py +2 -0
- vismatch/third_party/Se2_LoFTR/src/loftr/loftr_module/fine_preprocess.py +59 -0
- vismatch/third_party/Se2_LoFTR/src/loftr/loftr_module/linear_attention.py +81 -0
- vismatch/third_party/Se2_LoFTR/src/loftr/loftr_module/transformer.py +101 -0
- vismatch/third_party/Se2_LoFTR/src/loftr/utils/coarse_matching.py +261 -0
- vismatch/third_party/Se2_LoFTR/src/loftr/utils/cvpr_ds_config.py +50 -0
- vismatch/third_party/Se2_LoFTR/src/loftr/utils/fine_matching.py +74 -0
- vismatch/third_party/Se2_LoFTR/src/loftr/utils/geometry.py +54 -0
- vismatch/third_party/Se2_LoFTR/src/loftr/utils/position_encoding.py +42 -0
- vismatch/third_party/Se2_LoFTR/src/loftr/utils/supervision.py +151 -0
- vismatch/third_party/Se2_LoFTR/src/losses/loftr_loss.py +192 -0
- vismatch/third_party/Se2_LoFTR/src/optimizers/__init__.py +42 -0
- vismatch/third_party/Se2_LoFTR/src/utils/augment.py +55 -0
- vismatch/third_party/Se2_LoFTR/src/utils/comm.py +265 -0
- vismatch/third_party/Se2_LoFTR/src/utils/dataloader.py +23 -0
- vismatch/third_party/Se2_LoFTR/src/utils/dataset.py +185 -0
- vismatch/third_party/Se2_LoFTR/src/utils/metrics.py +193 -0
- vismatch/third_party/Se2_LoFTR/src/utils/misc.py +104 -0
- vismatch/third_party/Se2_LoFTR/src/utils/plotting.py +154 -0
- vismatch/third_party/Se2_LoFTR/src/utils/profiler.py +39 -0
- vismatch/third_party/Se2_LoFTR/test.py +68 -0
- vismatch/third_party/Se2_LoFTR/train.py +123 -0
- vismatch/third_party/SphereGlue/demo_SphereGlue.py +141 -0
- vismatch/third_party/SphereGlue/model/sphereglue.py +230 -0
- vismatch/third_party/SphereGlue/utils/Utils.py +191 -0
- vismatch/third_party/SphereGlue/utils/demo_mydataset.py +119 -0
- vismatch/third_party/Steerers/rotation_steerers/matchers/dual_softmax_matcher.py +44 -0
- vismatch/third_party/Steerers/rotation_steerers/matchers/max_matches.py +205 -0
- vismatch/third_party/Steerers/rotation_steerers/matchers/max_similarity.py +115 -0
- vismatch/third_party/Steerers/rotation_steerers/steerers.py +37 -0
- vismatch/third_party/Steerers/setup.py +14 -0
- vismatch/third_party/TopicFM/configs/megadepth_test.py +17 -0
- vismatch/third_party/TopicFM/configs/megadepth_test_topicfmfast.py +17 -0
- vismatch/third_party/TopicFM/configs/megadepth_test_topicfmplus.py +20 -0
- vismatch/third_party/TopicFM/configs/megadepth_train.py +36 -0
- vismatch/third_party/TopicFM/configs/megadepth_train_topicfmfast.py +34 -0
- vismatch/third_party/TopicFM/configs/megadepth_train_topicfmplus.py +37 -0
- vismatch/third_party/TopicFM/configs/scannet_test.py +15 -0
- vismatch/third_party/TopicFM/configs/scannet_test_topicfmfast.py +15 -0
- vismatch/third_party/TopicFM/configs/scannet_test_topicfmplus.py +19 -0
- vismatch/third_party/TopicFM/src/__init__.py +11 -0
- vismatch/third_party/TopicFM/src/config/default.py +174 -0
- vismatch/third_party/TopicFM/src/datasets/aachen.py +29 -0
- vismatch/third_party/TopicFM/src/datasets/custom_dataloader.py +126 -0
- vismatch/third_party/TopicFM/src/datasets/inloc.py +29 -0
- vismatch/third_party/TopicFM/src/datasets/megadepth.py +170 -0
- vismatch/third_party/TopicFM/src/datasets/sampler.py +77 -0
- vismatch/third_party/TopicFM/src/datasets/scannet.py +115 -0
- vismatch/third_party/TopicFM/src/lightning_trainer/data.py +292 -0
- vismatch/third_party/TopicFM/src/lightning_trainer/trainer.py +244 -0
- vismatch/third_party/TopicFM/src/losses/loss.py +228 -0
- vismatch/third_party/TopicFM/src/models/__init__.py +1 -0
- vismatch/third_party/TopicFM/src/models/backbone/__init__.py +12 -0
- vismatch/third_party/TopicFM/src/models/backbone/convnext.py +165 -0
- vismatch/third_party/TopicFM/src/models/backbone/fpn.py +114 -0
- vismatch/third_party/TopicFM/src/models/modules/__init__.py +2 -0
- vismatch/third_party/TopicFM/src/models/modules/encoder.py +266 -0
- vismatch/third_party/TopicFM/src/models/modules/fine_preprocess.py +59 -0
- vismatch/third_party/TopicFM/src/models/modules/linear_attention.py +84 -0
- vismatch/third_party/TopicFM/src/models/topic_fm.py +100 -0
- vismatch/third_party/TopicFM/src/models/utils/coarse_matching.py +213 -0
- vismatch/third_party/TopicFM/src/models/utils/fine_matching.py +172 -0
- vismatch/third_party/TopicFM/src/models/utils/geometry.py +54 -0
- vismatch/third_party/TopicFM/src/models/utils/supervision.py +167 -0
- vismatch/third_party/TopicFM/src/optimizers/__init__.py +42 -0
- vismatch/third_party/TopicFM/src/utils/augment.py +55 -0
- vismatch/third_party/TopicFM/src/utils/comm.py +265 -0
- vismatch/third_party/TopicFM/src/utils/dataloader.py +23 -0
- vismatch/third_party/TopicFM/src/utils/dataset.py +206 -0
- vismatch/third_party/TopicFM/src/utils/metrics.py +193 -0
- vismatch/third_party/TopicFM/src/utils/misc.py +101 -0
- vismatch/third_party/TopicFM/src/utils/plotting.py +313 -0
- vismatch/third_party/TopicFM/src/utils/profiler.py +39 -0
- vismatch/third_party/TopicFM/test.py +70 -0
- vismatch/third_party/TopicFM/third_party/__init__.py +0 -0
- vismatch/third_party/TopicFM/third_party/aspanformer/configs/aspan/indoor/aspan_test.py +7 -0
- vismatch/third_party/TopicFM/third_party/aspanformer/configs/aspan/indoor/aspan_train.py +8 -0
- vismatch/third_party/TopicFM/third_party/aspanformer/configs/aspan/outdoor/aspan_test.py +18 -0
- vismatch/third_party/TopicFM/third_party/aspanformer/configs/aspan/outdoor/aspan_train.py +17 -0
- vismatch/third_party/TopicFM/third_party/aspanformer/configs/data/__init__.py +0 -0
- vismatch/third_party/TopicFM/third_party/aspanformer/configs/data/base.py +35 -0
- vismatch/third_party/TopicFM/third_party/aspanformer/configs/data/megadepth_test_1500.py +13 -0
- vismatch/third_party/TopicFM/third_party/aspanformer/configs/data/megadepth_trainval_832.py +22 -0
- vismatch/third_party/TopicFM/third_party/aspanformer/configs/data/scannet_test_1500.py +11 -0
- vismatch/third_party/TopicFM/third_party/aspanformer/configs/data/scannet_trainval.py +17 -0
- vismatch/third_party/TopicFM/third_party/aspanformer/demo/demo.py +63 -0
- vismatch/third_party/TopicFM/third_party/aspanformer/demo/demo_utils.py +44 -0
- vismatch/third_party/TopicFM/third_party/aspanformer/src/ASpanFormer/__init__.py +2 -0
- vismatch/third_party/TopicFM/third_party/aspanformer/src/ASpanFormer/aspan_module/__init__.py +3 -0
- vismatch/third_party/TopicFM/third_party/aspanformer/src/ASpanFormer/aspan_module/attention.py +198 -0
- vismatch/third_party/TopicFM/third_party/aspanformer/src/ASpanFormer/aspan_module/fine_preprocess.py +59 -0
- vismatch/third_party/TopicFM/third_party/aspanformer/src/ASpanFormer/aspan_module/loftr.py +112 -0
- vismatch/third_party/TopicFM/third_party/aspanformer/src/ASpanFormer/aspan_module/transformer.py +244 -0
- vismatch/third_party/TopicFM/third_party/aspanformer/src/ASpanFormer/aspanformer.py +133 -0
- vismatch/third_party/TopicFM/third_party/aspanformer/src/ASpanFormer/backbone/__init__.py +11 -0
- vismatch/third_party/TopicFM/third_party/aspanformer/src/ASpanFormer/backbone/resnet_fpn.py +199 -0
- vismatch/third_party/TopicFM/third_party/aspanformer/src/ASpanFormer/utils/coarse_matching.py +331 -0
- vismatch/third_party/TopicFM/third_party/aspanformer/src/ASpanFormer/utils/cvpr_ds_config.py +50 -0
- vismatch/third_party/TopicFM/third_party/aspanformer/src/ASpanFormer/utils/fine_matching.py +74 -0
- vismatch/third_party/TopicFM/third_party/aspanformer/src/ASpanFormer/utils/geometry.py +54 -0
- vismatch/third_party/TopicFM/third_party/aspanformer/src/ASpanFormer/utils/position_encoding.py +61 -0
- vismatch/third_party/TopicFM/third_party/aspanformer/src/ASpanFormer/utils/supervision.py +151 -0
- vismatch/third_party/TopicFM/third_party/aspanformer/src/__init__.py +0 -0
- vismatch/third_party/TopicFM/third_party/aspanformer/src/config/default.py +180 -0
- vismatch/third_party/TopicFM/third_party/aspanformer/src/datasets/__init__.py +3 -0
- vismatch/third_party/TopicFM/third_party/aspanformer/src/datasets/megadepth.py +127 -0
- vismatch/third_party/TopicFM/third_party/aspanformer/src/datasets/sampler.py +77 -0
- vismatch/third_party/TopicFM/third_party/aspanformer/src/datasets/scannet.py +113 -0
- vismatch/third_party/TopicFM/third_party/aspanformer/src/lightning/data.py +326 -0
- vismatch/third_party/TopicFM/third_party/aspanformer/src/lightning/lightning_aspanformer.py +276 -0
- vismatch/third_party/TopicFM/third_party/aspanformer/src/losses/aspan_loss.py +231 -0
- vismatch/third_party/TopicFM/third_party/aspanformer/src/optimizers/__init__.py +42 -0
- vismatch/third_party/TopicFM/third_party/aspanformer/src/utils/augment.py +55 -0
- vismatch/third_party/TopicFM/third_party/aspanformer/src/utils/comm.py +265 -0
- vismatch/third_party/TopicFM/third_party/aspanformer/src/utils/dataloader.py +23 -0
- vismatch/third_party/TopicFM/third_party/aspanformer/src/utils/dataset.py +222 -0
- vismatch/third_party/TopicFM/third_party/aspanformer/src/utils/metrics.py +260 -0
- vismatch/third_party/TopicFM/third_party/aspanformer/src/utils/misc.py +139 -0
- vismatch/third_party/TopicFM/third_party/aspanformer/src/utils/plotting.py +219 -0
- vismatch/third_party/TopicFM/third_party/aspanformer/src/utils/profiler.py +39 -0
- vismatch/third_party/TopicFM/third_party/aspanformer/test.py +69 -0
- vismatch/third_party/TopicFM/third_party/aspanformer/tools/SensorData.py +125 -0
- vismatch/third_party/TopicFM/third_party/aspanformer/tools/extract.py +47 -0
- vismatch/third_party/TopicFM/third_party/aspanformer/tools/preprocess_scene.py +242 -0
- vismatch/third_party/TopicFM/third_party/aspanformer/tools/reader.py +39 -0
- vismatch/third_party/TopicFM/third_party/aspanformer/tools/undistort_mega.py +69 -0
- vismatch/third_party/TopicFM/third_party/aspanformer/train.py +134 -0
- vismatch/third_party/TopicFM/third_party/loftr/configs/data/__init__.py +0 -0
- vismatch/third_party/TopicFM/third_party/loftr/configs/data/base.py +35 -0
- vismatch/third_party/TopicFM/third_party/loftr/configs/data/megadepth_test_1500.py +11 -0
- vismatch/third_party/TopicFM/third_party/loftr/configs/data/megadepth_trainval_640.py +22 -0
- vismatch/third_party/TopicFM/third_party/loftr/configs/data/megadepth_trainval_840.py +22 -0
- vismatch/third_party/TopicFM/third_party/loftr/configs/data/scannet_test_1500.py +11 -0
- vismatch/third_party/TopicFM/third_party/loftr/configs/data/scannet_trainval.py +17 -0
- vismatch/third_party/TopicFM/third_party/loftr/configs/loftr/indoor/buggy_pos_enc/loftr_ds.py +6 -0
- vismatch/third_party/TopicFM/third_party/loftr/configs/loftr/indoor/buggy_pos_enc/loftr_ds_dense.py +8 -0
- vismatch/third_party/TopicFM/third_party/loftr/configs/loftr/indoor/buggy_pos_enc/loftr_ot.py +6 -0
- vismatch/third_party/TopicFM/third_party/loftr/configs/loftr/indoor/buggy_pos_enc/loftr_ot_dense.py +8 -0
- vismatch/third_party/TopicFM/third_party/loftr/configs/loftr/indoor/loftr_ds.py +5 -0
- vismatch/third_party/TopicFM/third_party/loftr/configs/loftr/indoor/loftr_ds_dense.py +7 -0
- vismatch/third_party/TopicFM/third_party/loftr/configs/loftr/indoor/loftr_ot.py +5 -0
- vismatch/third_party/TopicFM/third_party/loftr/configs/loftr/indoor/loftr_ot_dense.py +7 -0
- vismatch/third_party/TopicFM/third_party/loftr/configs/loftr/indoor/scannet/loftr_ds_eval.py +16 -0
- vismatch/third_party/TopicFM/third_party/loftr/configs/loftr/indoor/scannet/loftr_ds_eval_new.py +18 -0
- vismatch/third_party/TopicFM/third_party/loftr/configs/loftr/outdoor/buggy_pos_enc/loftr_ds.py +16 -0
- vismatch/third_party/TopicFM/third_party/loftr/configs/loftr/outdoor/buggy_pos_enc/loftr_ds_dense.py +17 -0
- vismatch/third_party/TopicFM/third_party/loftr/configs/loftr/outdoor/buggy_pos_enc/loftr_ot.py +16 -0
- vismatch/third_party/TopicFM/third_party/loftr/configs/loftr/outdoor/buggy_pos_enc/loftr_ot_dense.py +17 -0
- vismatch/third_party/TopicFM/third_party/loftr/configs/loftr/outdoor/loftr_ds.py +15 -0
- vismatch/third_party/TopicFM/third_party/loftr/configs/loftr/outdoor/loftr_ds_dense.py +16 -0
- vismatch/third_party/TopicFM/third_party/loftr/configs/loftr/outdoor/loftr_ot.py +15 -0
- vismatch/third_party/TopicFM/third_party/loftr/configs/loftr/outdoor/loftr_ot_dense.py +16 -0
- vismatch/third_party/TopicFM/third_party/loftr/demo/demo_loftr.py +240 -0
- vismatch/third_party/TopicFM/third_party/loftr/src/__init__.py +0 -0
- vismatch/third_party/TopicFM/third_party/loftr/src/config/default.py +171 -0
- vismatch/third_party/TopicFM/third_party/loftr/src/datasets/megadepth.py +127 -0
- vismatch/third_party/TopicFM/third_party/loftr/src/datasets/sampler.py +77 -0
- vismatch/third_party/TopicFM/third_party/loftr/src/datasets/scannet.py +114 -0
- vismatch/third_party/TopicFM/third_party/loftr/src/lightning/data.py +320 -0
- vismatch/third_party/TopicFM/third_party/loftr/src/lightning/lightning_loftr.py +249 -0
- vismatch/third_party/TopicFM/third_party/loftr/src/loftr/__init__.py +2 -0
- vismatch/third_party/TopicFM/third_party/loftr/src/loftr/backbone/__init__.py +11 -0
- vismatch/third_party/TopicFM/third_party/loftr/src/loftr/backbone/resnet_fpn.py +199 -0
- vismatch/third_party/TopicFM/third_party/loftr/src/loftr/loftr.py +81 -0
- vismatch/third_party/TopicFM/third_party/loftr/src/loftr/loftr_module/__init__.py +2 -0
- vismatch/third_party/TopicFM/third_party/loftr/src/loftr/loftr_module/fine_preprocess.py +59 -0
- vismatch/third_party/TopicFM/third_party/loftr/src/loftr/loftr_module/linear_attention.py +81 -0
- vismatch/third_party/TopicFM/third_party/loftr/src/loftr/loftr_module/transformer.py +101 -0
- vismatch/third_party/TopicFM/third_party/loftr/src/loftr/utils/coarse_matching.py +261 -0
- vismatch/third_party/TopicFM/third_party/loftr/src/loftr/utils/cvpr_ds_config.py +50 -0
- vismatch/third_party/TopicFM/third_party/loftr/src/loftr/utils/fine_matching.py +74 -0
- vismatch/third_party/TopicFM/third_party/loftr/src/loftr/utils/geometry.py +54 -0
- vismatch/third_party/TopicFM/third_party/loftr/src/loftr/utils/position_encoding.py +42 -0
- vismatch/third_party/TopicFM/third_party/loftr/src/loftr/utils/supervision.py +151 -0
- vismatch/third_party/TopicFM/third_party/loftr/src/losses/loftr_loss.py +192 -0
- vismatch/third_party/TopicFM/third_party/loftr/src/optimizers/__init__.py +42 -0
- vismatch/third_party/TopicFM/third_party/loftr/src/utils/augment.py +55 -0
- vismatch/third_party/TopicFM/third_party/loftr/src/utils/comm.py +265 -0
- vismatch/third_party/TopicFM/third_party/loftr/src/utils/dataloader.py +23 -0
- vismatch/third_party/TopicFM/third_party/loftr/src/utils/dataset.py +185 -0
- vismatch/third_party/TopicFM/third_party/loftr/src/utils/metrics.py +193 -0
- vismatch/third_party/TopicFM/third_party/loftr/src/utils/misc.py +101 -0
- vismatch/third_party/TopicFM/third_party/loftr/src/utils/plotting.py +154 -0
- vismatch/third_party/TopicFM/third_party/loftr/src/utils/profiler.py +39 -0
- vismatch/third_party/TopicFM/third_party/loftr/test.py +68 -0
- vismatch/third_party/TopicFM/third_party/loftr/third_party/SuperGluePretrainedNetwork/demo_superglue.py +259 -0
- vismatch/third_party/TopicFM/third_party/loftr/third_party/SuperGluePretrainedNetwork/match_pairs.py +425 -0
- vismatch/third_party/TopicFM/third_party/loftr/third_party/SuperGluePretrainedNetwork/models/__init__.py +0 -0
- vismatch/third_party/TopicFM/third_party/loftr/third_party/SuperGluePretrainedNetwork/models/matching.py +84 -0
- vismatch/third_party/TopicFM/third_party/loftr/third_party/SuperGluePretrainedNetwork/models/superglue.py +283 -0
- vismatch/third_party/TopicFM/third_party/loftr/third_party/SuperGluePretrainedNetwork/models/superpoint.py +202 -0
- vismatch/third_party/TopicFM/third_party/loftr/third_party/SuperGluePretrainedNetwork/models/utils.py +555 -0
- vismatch/third_party/TopicFM/third_party/loftr/train.py +123 -0
- vismatch/third_party/TopicFM/third_party/matchformer/config/data/__init__.py +0 -0
- vismatch/third_party/TopicFM/third_party/matchformer/config/data/base.py +35 -0
- vismatch/third_party/TopicFM/third_party/matchformer/config/data/megadepth_test_1500.py +11 -0
- vismatch/third_party/TopicFM/third_party/matchformer/config/data/scannet_test_1500.py +11 -0
- vismatch/third_party/TopicFM/third_party/matchformer/config/defaultmf.py +88 -0
- vismatch/third_party/TopicFM/third_party/matchformer/model/backbone/__init__.py +17 -0
- vismatch/third_party/TopicFM/third_party/matchformer/model/backbone/coarse_matching.py +228 -0
- vismatch/third_party/TopicFM/third_party/matchformer/model/backbone/fine_matching.py +74 -0
- vismatch/third_party/TopicFM/third_party/matchformer/model/backbone/fine_preprocess.py +59 -0
- vismatch/third_party/TopicFM/third_party/matchformer/model/backbone/match_LA_large.py +254 -0
- vismatch/third_party/TopicFM/third_party/matchformer/model/backbone/match_LA_lite.py +254 -0
- vismatch/third_party/TopicFM/third_party/matchformer/model/backbone/match_SEA_large.py +291 -0
- vismatch/third_party/TopicFM/third_party/matchformer/model/backbone/match_SEA_lite.py +291 -0
- vismatch/third_party/TopicFM/third_party/matchformer/model/data.py +320 -0
- vismatch/third_party/TopicFM/third_party/matchformer/model/datasets/dataset.py +231 -0
- vismatch/third_party/TopicFM/third_party/matchformer/model/datasets/megadepth.py +126 -0
- vismatch/third_party/TopicFM/third_party/matchformer/model/datasets/sampler.py +77 -0
- vismatch/third_party/TopicFM/third_party/matchformer/model/datasets/scannet.py +113 -0
- vismatch/third_party/TopicFM/third_party/matchformer/model/lightning_loftr.py +102 -0
- vismatch/third_party/TopicFM/third_party/matchformer/model/matchformer.py +54 -0
- vismatch/third_party/TopicFM/third_party/matchformer/model/utils/augment.py +55 -0
- vismatch/third_party/TopicFM/third_party/matchformer/model/utils/comm.py +265 -0
- vismatch/third_party/TopicFM/third_party/matchformer/model/utils/dataloader.py +23 -0
- vismatch/third_party/TopicFM/third_party/matchformer/model/utils/metrics.py +193 -0
- vismatch/third_party/TopicFM/third_party/matchformer/model/utils/misc.py +101 -0
- vismatch/third_party/TopicFM/third_party/matchformer/model/utils/profiler.py +39 -0
- vismatch/third_party/TopicFM/third_party/matchformer/test.py +55 -0
- vismatch/third_party/TopicFM/train.py +123 -0
- vismatch/third_party/TopicFM/visualization.py +123 -0
- vismatch/third_party/TopicFM/viz/__init__.py +1 -0
- vismatch/third_party/TopicFM/viz/configs/__init__.py +0 -0
- vismatch/third_party/TopicFM/viz/methods/__init__.py +0 -0
- vismatch/third_party/TopicFM/viz/methods/base.py +70 -0
- vismatch/third_party/TopicFM/viz/methods/topicfmv2.py +208 -0
- vismatch/third_party/UFM/UniCeption/examples/models/cosmos/autoencoding.py +48 -0
- vismatch/third_party/UFM/UniCeption/examples/models/dust3r/convert_dust3r_weights_to_uniception.py +331 -0
- vismatch/third_party/UFM/UniCeption/examples/models/dust3r/dust3r.py +261 -0
- vismatch/third_party/UFM/UniCeption/examples/models/dust3r/profile_dust3r.py +47 -0
- vismatch/third_party/UFM/UniCeption/scripts/check_dependencies.py +48 -0
- vismatch/third_party/UFM/UniCeption/scripts/download_checkpoints.py +50 -0
- vismatch/third_party/UFM/UniCeption/scripts/install_croco_rope.py +61 -0
- vismatch/third_party/UFM/UniCeption/scripts/prepare_offline_install.py +398 -0
- vismatch/third_party/UFM/UniCeption/scripts/validate_installation.py +212 -0
- vismatch/third_party/UFM/UniCeption/setup.py +185 -0
- vismatch/third_party/UFM/UniCeption/tests/models/encoders/conftest.py +26 -0
- vismatch/third_party/UFM/UniCeption/tests/models/encoders/test_encoders.py +202 -0
- vismatch/third_party/UFM/UniCeption/tests/models/encoders/viz_image_encoders.py +294 -0
- vismatch/third_party/UFM/UniCeption/tests/models/info_sharing/viz_mulit_view_cross_attn_transformers.py +337 -0
- vismatch/third_party/UFM/UniCeption/uniception/__init__.py +0 -0
- vismatch/third_party/UFM/UniCeption/uniception/models/__init__.py +0 -0
- vismatch/third_party/UFM/UniCeption/uniception/models/encoders/__init__.py +225 -0
- vismatch/third_party/UFM/UniCeption/uniception/models/encoders/base.py +157 -0
- vismatch/third_party/UFM/UniCeption/uniception/models/encoders/cosmos.py +137 -0
- vismatch/third_party/UFM/UniCeption/uniception/models/encoders/croco.py +457 -0
- vismatch/third_party/UFM/UniCeption/uniception/models/encoders/dense_rep_encoder.py +344 -0
- vismatch/third_party/UFM/UniCeption/uniception/models/encoders/dinov2.py +333 -0
- vismatch/third_party/UFM/UniCeption/uniception/models/encoders/global_rep_encoder.py +115 -0
- vismatch/third_party/UFM/UniCeption/uniception/models/encoders/image_normalizations.py +35 -0
- vismatch/third_party/UFM/UniCeption/uniception/models/encoders/list.py +10 -0
- vismatch/third_party/UFM/UniCeption/uniception/models/encoders/patch_embedder.py +235 -0
- vismatch/third_party/UFM/UniCeption/uniception/models/encoders/radio.py +367 -0
- vismatch/third_party/UFM/UniCeption/uniception/models/encoders/utils.py +86 -0
- vismatch/third_party/UFM/UniCeption/uniception/models/factory/__init__.py +3 -0
- vismatch/third_party/UFM/UniCeption/uniception/models/factory/dust3r.py +332 -0
- vismatch/third_party/UFM/UniCeption/uniception/models/info_sharing/__init__.py +39 -0
- vismatch/third_party/UFM/UniCeption/uniception/models/info_sharing/alternating_attention_transformer.py +973 -0
- vismatch/third_party/UFM/UniCeption/uniception/models/info_sharing/base.py +116 -0
- vismatch/third_party/UFM/UniCeption/uniception/models/info_sharing/cross_attention_transformer.py +612 -0
- vismatch/third_party/UFM/UniCeption/uniception/models/info_sharing/diff_cross_attention_transformer.py +588 -0
- vismatch/third_party/UFM/UniCeption/uniception/models/info_sharing/global_attention_transformer.py +1154 -0
- vismatch/third_party/UFM/UniCeption/uniception/models/libs/__init__.py +0 -0
- vismatch/third_party/UFM/UniCeption/uniception/models/libs/cosmos_tokenizer/__init__.py +14 -0
- vismatch/third_party/UFM/UniCeption/uniception/models/libs/cosmos_tokenizer/image_cli.py +175 -0
- vismatch/third_party/UFM/UniCeption/uniception/models/libs/cosmos_tokenizer/image_lib.py +123 -0
- vismatch/third_party/UFM/UniCeption/uniception/models/libs/cosmos_tokenizer/modules/__init__.py +60 -0
- vismatch/third_party/UFM/UniCeption/uniception/models/libs/cosmos_tokenizer/modules/distributions.py +41 -0
- vismatch/third_party/UFM/UniCeption/uniception/models/libs/cosmos_tokenizer/modules/layers2d.py +326 -0
- vismatch/third_party/UFM/UniCeption/uniception/models/libs/cosmos_tokenizer/modules/layers3d.py +965 -0
- vismatch/third_party/UFM/UniCeption/uniception/models/libs/cosmos_tokenizer/modules/patching.py +310 -0
- vismatch/third_party/UFM/UniCeption/uniception/models/libs/cosmos_tokenizer/modules/quantizers.py +510 -0
- vismatch/third_party/UFM/UniCeption/uniception/models/libs/cosmos_tokenizer/modules/utils.py +115 -0
- vismatch/third_party/UFM/UniCeption/uniception/models/libs/cosmos_tokenizer/networks/__init__.py +39 -0
- vismatch/third_party/UFM/UniCeption/uniception/models/libs/cosmos_tokenizer/networks/configs.py +146 -0
- vismatch/third_party/UFM/UniCeption/uniception/models/libs/cosmos_tokenizer/networks/continuous_image.py +86 -0
- vismatch/third_party/UFM/UniCeption/uniception/models/libs/cosmos_tokenizer/networks/continuous_video.py +98 -0
- vismatch/third_party/UFM/UniCeption/uniception/models/libs/cosmos_tokenizer/networks/discrete_image.py +113 -0
- vismatch/third_party/UFM/UniCeption/uniception/models/libs/cosmos_tokenizer/networks/discrete_video.py +115 -0
- vismatch/third_party/UFM/UniCeption/uniception/models/libs/cosmos_tokenizer/utils.py +402 -0
- vismatch/third_party/UFM/UniCeption/uniception/models/libs/cosmos_tokenizer/video_cli.py +195 -0
- vismatch/third_party/UFM/UniCeption/uniception/models/libs/cosmos_tokenizer/video_lib.py +145 -0
- vismatch/third_party/UFM/UniCeption/uniception/models/libs/croco/__init__.py +0 -0
- vismatch/third_party/UFM/UniCeption/uniception/models/libs/croco/blocks.py +249 -0
- vismatch/third_party/UFM/UniCeption/uniception/models/libs/croco/curope/__init__.py +4 -0
- vismatch/third_party/UFM/UniCeption/uniception/models/libs/croco/curope/curope2d.py +39 -0
- vismatch/third_party/UFM/UniCeption/uniception/models/libs/croco/curope/setup.py +33 -0
- vismatch/third_party/UFM/UniCeption/uniception/models/libs/croco/dpt_block.py +530 -0
- vismatch/third_party/UFM/UniCeption/uniception/models/libs/croco/patch_embed.py +127 -0
- vismatch/third_party/UFM/UniCeption/uniception/models/libs/croco/pos_embed.py +155 -0
- vismatch/third_party/UFM/UniCeption/uniception/models/prediction_heads/__init__.py +18 -0
- vismatch/third_party/UFM/UniCeption/uniception/models/prediction_heads/adaptors.py +1765 -0
- vismatch/third_party/UFM/UniCeption/uniception/models/prediction_heads/base.py +210 -0
- vismatch/third_party/UFM/UniCeption/uniception/models/prediction_heads/cosmos.py +211 -0
- vismatch/third_party/UFM/UniCeption/uniception/models/prediction_heads/dpt.py +676 -0
- vismatch/third_party/UFM/UniCeption/uniception/models/prediction_heads/global_head.py +142 -0
- vismatch/third_party/UFM/UniCeption/uniception/models/prediction_heads/linear.py +95 -0
- vismatch/third_party/UFM/UniCeption/uniception/models/prediction_heads/mlp_feature.py +114 -0
- vismatch/third_party/UFM/UniCeption/uniception/models/prediction_heads/mlp_head.py +114 -0
- vismatch/third_party/UFM/UniCeption/uniception/models/prediction_heads/moge_conv.py +342 -0
- vismatch/third_party/UFM/UniCeption/uniception/models/prediction_heads/pose_head.py +181 -0
- vismatch/third_party/UFM/UniCeption/uniception/models/utils/__init__.py +0 -0
- vismatch/third_party/UFM/UniCeption/uniception/models/utils/config.py +34 -0
- vismatch/third_party/UFM/UniCeption/uniception/models/utils/intermediate_feature_return.py +85 -0
- vismatch/third_party/UFM/UniCeption/uniception/models/utils/positional_encoding.py +23 -0
- vismatch/third_party/UFM/UniCeption/uniception/models/utils/transformer_blocks.py +1072 -0
- vismatch/third_party/UFM/UniCeption/uniception/utils/__init__.py +0 -0
- vismatch/third_party/UFM/UniCeption/uniception/utils/profile.py +13 -0
- vismatch/third_party/UFM/UniCeption/uniception/utils/viz.py +99 -0
- vismatch/third_party/UFM/example_inference.py +138 -0
- vismatch/third_party/UFM/gradio_demo.py +238 -0
- vismatch/third_party/UFM/setup.py +86 -0
- vismatch/third_party/UFM/uniflowmatch/__init__.py +16 -0
- vismatch/third_party/UFM/uniflowmatch/cli.py +217 -0
- vismatch/third_party/UFM/uniflowmatch/models/__init__.py +25 -0
- vismatch/third_party/UFM/uniflowmatch/models/base.py +334 -0
- vismatch/third_party/UFM/uniflowmatch/models/ufm.py +1323 -0
- vismatch/third_party/UFM/uniflowmatch/models/unet_encoder.py +90 -0
- vismatch/third_party/UFM/uniflowmatch/models/utils.py +16 -0
- vismatch/third_party/UFM/uniflowmatch/utils/__init__.py +63 -0
- vismatch/third_party/UFM/uniflowmatch/utils/flow_resizing.py +1091 -0
- vismatch/third_party/UFM/uniflowmatch/utils/geometry.py +612 -0
- vismatch/third_party/UFM/uniflowmatch/utils/viz.py +97 -0
- vismatch/third_party/XoFTR/configs/data/__init__.py +0 -0
- vismatch/third_party/XoFTR/configs/data/base.py +35 -0
- vismatch/third_party/XoFTR/configs/data/megadepth_trainval_840.py +22 -0
- vismatch/third_party/XoFTR/configs/data/megadepth_vistir_trainval_640.py +23 -0
- vismatch/third_party/XoFTR/configs/data/pretrain.py +8 -0
- vismatch/third_party/XoFTR/configs/xoftr/outdoor/visible_thermal.py +17 -0
- vismatch/third_party/XoFTR/configs/xoftr/pretrain/pretrain.py +12 -0
- vismatch/third_party/XoFTR/pretrain.py +125 -0
- vismatch/third_party/XoFTR/src/__init__.py +0 -0
- vismatch/third_party/XoFTR/src/config/default.py +203 -0
- vismatch/third_party/XoFTR/src/datasets/megadepth.py +143 -0
- vismatch/third_party/XoFTR/src/datasets/pretrain_dataset.py +156 -0
- vismatch/third_party/XoFTR/src/datasets/sampler.py +77 -0
- vismatch/third_party/XoFTR/src/datasets/scannet.py +114 -0
- vismatch/third_party/XoFTR/src/datasets/vistir.py +109 -0
- vismatch/third_party/XoFTR/src/lightning/data.py +346 -0
- vismatch/third_party/XoFTR/src/lightning/data_pretrain.py +125 -0
- vismatch/third_party/XoFTR/src/lightning/lightning_xoftr.py +334 -0
- vismatch/third_party/XoFTR/src/lightning/lightning_xoftr_pretrain.py +171 -0
- vismatch/third_party/XoFTR/src/losses/xoftr_loss.py +170 -0
- vismatch/third_party/XoFTR/src/losses/xoftr_loss_pretrain.py +37 -0
- vismatch/third_party/XoFTR/src/optimizers/__init__.py +42 -0
- vismatch/third_party/XoFTR/src/utils/augment.py +113 -0
- vismatch/third_party/XoFTR/src/utils/comm.py +265 -0
- vismatch/third_party/XoFTR/src/utils/data_io.py +144 -0
- vismatch/third_party/XoFTR/src/utils/dataloader.py +23 -0
- vismatch/third_party/XoFTR/src/utils/dataset.py +279 -0
- vismatch/third_party/XoFTR/src/utils/metrics.py +211 -0
- vismatch/third_party/XoFTR/src/utils/misc.py +101 -0
- vismatch/third_party/XoFTR/src/utils/plotting.py +227 -0
- vismatch/third_party/XoFTR/src/utils/pretrain_utils.py +83 -0
- vismatch/third_party/XoFTR/src/utils/profiler.py +39 -0
- vismatch/third_party/XoFTR/src/xoftr/__init__.py +2 -0
- vismatch/third_party/XoFTR/src/xoftr/backbone/__init__.py +1 -0
- vismatch/third_party/XoFTR/src/xoftr/backbone/resnet.py +95 -0
- vismatch/third_party/XoFTR/src/xoftr/utils/geometry.py +107 -0
- vismatch/third_party/XoFTR/src/xoftr/utils/position_encoding.py +36 -0
- vismatch/third_party/XoFTR/src/xoftr/utils/supervision.py +290 -0
- vismatch/third_party/XoFTR/src/xoftr/xoftr.py +94 -0
- vismatch/third_party/XoFTR/src/xoftr/xoftr_module/__init__.py +4 -0
- vismatch/third_party/XoFTR/src/xoftr/xoftr_module/coarse_matching.py +305 -0
- vismatch/third_party/XoFTR/src/xoftr/xoftr_module/fine_matching.py +170 -0
- vismatch/third_party/XoFTR/src/xoftr/xoftr_module/fine_process.py +321 -0
- vismatch/third_party/XoFTR/src/xoftr/xoftr_module/linear_attention.py +81 -0
- vismatch/third_party/XoFTR/src/xoftr/xoftr_module/transformer.py +101 -0
- vismatch/third_party/XoFTR/src/xoftr/xoftr_pretrain.py +209 -0
- vismatch/third_party/XoFTR/test.py +68 -0
- vismatch/third_party/XoFTR/test_relative_pose.py +330 -0
- vismatch/third_party/XoFTR/train.py +126 -0
- vismatch/third_party/accelerated_features/hubconf.py +15 -0
- vismatch/third_party/accelerated_features/minimal_example.py +49 -0
- vismatch/third_party/accelerated_features/modules/__init__.py +4 -0
- vismatch/third_party/accelerated_features/modules/dataset/__init__.py +5 -0
- vismatch/third_party/accelerated_features/modules/dataset/augmentation.py +314 -0
- vismatch/third_party/accelerated_features/modules/dataset/megadepth/__init__.py +7 -0
- vismatch/third_party/accelerated_features/modules/dataset/megadepth/megadepth.py +174 -0
- vismatch/third_party/accelerated_features/modules/dataset/megadepth/megadepth_warper.py +170 -0
- vismatch/third_party/accelerated_features/modules/dataset/megadepth/utils.py +160 -0
- vismatch/third_party/accelerated_features/modules/interpolator.py +33 -0
- vismatch/third_party/accelerated_features/modules/lighterglue.py +56 -0
- vismatch/third_party/accelerated_features/modules/model.py +154 -0
- vismatch/third_party/accelerated_features/modules/training/__init__.py +4 -0
- vismatch/third_party/accelerated_features/modules/training/losses.py +224 -0
- vismatch/third_party/accelerated_features/modules/training/train.py +311 -0
- vismatch/third_party/accelerated_features/modules/training/utils.py +200 -0
- vismatch/third_party/accelerated_features/modules/xfeat.py +402 -0
- vismatch/third_party/accelerated_features/realtime_demo.py +295 -0
- vismatch/third_party/accelerated_features/third_party/ALIKE/alike.py +143 -0
- vismatch/third_party/accelerated_features/third_party/ALIKE/alnet.py +164 -0
- vismatch/third_party/accelerated_features/third_party/ALIKE/demo.py +167 -0
- vismatch/third_party/accelerated_features/third_party/ALIKE/hseq/eval.py +162 -0
- vismatch/third_party/accelerated_features/third_party/ALIKE/hseq/extract.py +159 -0
- vismatch/third_party/accelerated_features/third_party/ALIKE/soft_detect.py +194 -0
- vismatch/third_party/accelerated_features/third_party/__init__.py +4 -0
- vismatch/third_party/accelerated_features/third_party/alike_wrapper.py +110 -0
- vismatch/third_party/affine-steerers/affine_steerers/__init__.py +7 -0
- vismatch/third_party/affine-steerers/affine_steerers/benchmarks/__init__.py +5 -0
- vismatch/third_party/affine-steerers/affine_steerers/benchmarks/hpatches.py +92 -0
- vismatch/third_party/affine-steerers/affine_steerers/benchmarks/hpatches_oracle_steer.py +108 -0
- vismatch/third_party/affine-steerers/affine_steerers/benchmarks/mega_pose_est.py +116 -0
- vismatch/third_party/affine-steerers/affine_steerers/benchmarks/mega_pose_est_mnn.py +162 -0
- vismatch/third_party/affine-steerers/affine_steerers/benchmarks/nll_benchmark.py +57 -0
- vismatch/third_party/affine-steerers/affine_steerers/benchmarks/num_inliers.py +76 -0
- vismatch/third_party/affine-steerers/affine_steerers/checkpoint.py +82 -0
- vismatch/third_party/affine-steerers/affine_steerers/datasets/__init__.py +0 -0
- vismatch/third_party/affine-steerers/affine_steerers/datasets/homog.py +284 -0
- vismatch/third_party/affine-steerers/affine_steerers/datasets/megadepth.py +408 -0
- vismatch/third_party/affine-steerers/affine_steerers/decoder.py +90 -0
- vismatch/third_party/affine-steerers/affine_steerers/descriptors/__init__.py +0 -0
- vismatch/third_party/affine-steerers/affine_steerers/descriptors/dedode_descriptor.py +77 -0
- vismatch/third_party/affine-steerers/affine_steerers/descriptors/descriptor_loss.py +358 -0
- vismatch/third_party/affine-steerers/affine_steerers/detectors/__init__.py +0 -0
- vismatch/third_party/affine-steerers/affine_steerers/detectors/dedode_detector.py +75 -0
- vismatch/third_party/affine-steerers/affine_steerers/detectors/keypoint_loss.py +215 -0
- vismatch/third_party/affine-steerers/affine_steerers/encoder.py +87 -0
- vismatch/third_party/affine-steerers/affine_steerers/matchers/__init__.py +0 -0
- vismatch/third_party/affine-steerers/affine_steerers/matchers/dual_softmax_matcher.py +816 -0
- vismatch/third_party/affine-steerers/affine_steerers/model_zoo/__init__.py +3 -0
- vismatch/third_party/affine-steerers/affine_steerers/model_zoo/dedode_models.py +298 -0
- vismatch/third_party/affine-steerers/affine_steerers/steerers.py +732 -0
- vismatch/third_party/affine-steerers/affine_steerers/train.py +90 -0
- vismatch/third_party/affine-steerers/affine_steerers/transformer/__init__.py +8 -0
- vismatch/third_party/affine-steerers/affine_steerers/transformer/dinov2.py +359 -0
- vismatch/third_party/affine-steerers/affine_steerers/transformer/layers/__init__.py +12 -0
- vismatch/third_party/affine-steerers/affine_steerers/transformer/layers/attention.py +81 -0
- vismatch/third_party/affine-steerers/affine_steerers/transformer/layers/block.py +252 -0
- vismatch/third_party/affine-steerers/affine_steerers/transformer/layers/dino_head.py +59 -0
- vismatch/third_party/affine-steerers/affine_steerers/transformer/layers/drop_path.py +35 -0
- vismatch/third_party/affine-steerers/affine_steerers/transformer/layers/layer_scale.py +28 -0
- vismatch/third_party/affine-steerers/affine_steerers/transformer/layers/mlp.py +41 -0
- vismatch/third_party/affine-steerers/affine_steerers/transformer/layers/patch_embed.py +89 -0
- vismatch/third_party/affine-steerers/affine_steerers/transformer/layers/swiglu_ffn.py +63 -0
- vismatch/third_party/affine-steerers/affine_steerers/utils.py +1422 -0
- vismatch/third_party/affine-steerers/experiments/aff_equi_B.py +182 -0
- vismatch/third_party/affine-steerers/experiments/aff_equi_G.py +193 -0
- vismatch/third_party/affine-steerers/experiments/aff_steer_B.py +213 -0
- vismatch/third_party/affine-steerers/experiments/aff_steer_G.py +223 -0
- vismatch/third_party/affine-steerers/experiments/aff_steer_pretrain_B.py +187 -0
- vismatch/third_party/affine-steerers/experiments/aff_steer_pretrain_G.py +198 -0
- vismatch/third_party/affine-steerers/setup.py +15 -0
- vismatch/third_party/aspanformer/configs/aspan/indoor/aspan_test.py +7 -0
- vismatch/third_party/aspanformer/configs/aspan/indoor/aspan_train.py +8 -0
- vismatch/third_party/aspanformer/configs/aspan/outdoor/aspan_test.py +19 -0
- vismatch/third_party/aspanformer/configs/aspan/outdoor/aspan_train.py +17 -0
- vismatch/third_party/aspanformer/configs/data/__init__.py +0 -0
- vismatch/third_party/aspanformer/configs/data/base.py +35 -0
- vismatch/third_party/aspanformer/configs/data/megadepth_test_1500.py +13 -0
- vismatch/third_party/aspanformer/configs/data/megadepth_trainval_832.py +22 -0
- vismatch/third_party/aspanformer/configs/data/scannet_test_1500.py +11 -0
- vismatch/third_party/aspanformer/configs/data/scannet_trainval.py +17 -0
- vismatch/third_party/aspanformer/demo/demo.py +63 -0
- vismatch/third_party/aspanformer/demo/demo_utils.py +44 -0
- vismatch/third_party/aspanformer/src/ASpanFormer/__init__.py +2 -0
- vismatch/third_party/aspanformer/src/ASpanFormer/aspan_module/__init__.py +3 -0
- vismatch/third_party/aspanformer/src/ASpanFormer/aspan_module/attention.py +198 -0
- vismatch/third_party/aspanformer/src/ASpanFormer/aspan_module/fine_preprocess.py +59 -0
- vismatch/third_party/aspanformer/src/ASpanFormer/aspan_module/loftr.py +112 -0
- vismatch/third_party/aspanformer/src/ASpanFormer/aspan_module/transformer.py +244 -0
- vismatch/third_party/aspanformer/src/ASpanFormer/aspanformer.py +152 -0
- vismatch/third_party/aspanformer/src/ASpanFormer/backbone/__init__.py +11 -0
- vismatch/third_party/aspanformer/src/ASpanFormer/backbone/resnet_fpn.py +199 -0
- vismatch/third_party/aspanformer/src/ASpanFormer/utils/coarse_matching.py +331 -0
- vismatch/third_party/aspanformer/src/ASpanFormer/utils/cvpr_ds_config.py +50 -0
- vismatch/third_party/aspanformer/src/ASpanFormer/utils/fine_matching.py +74 -0
- vismatch/third_party/aspanformer/src/ASpanFormer/utils/geometry.py +54 -0
- vismatch/third_party/aspanformer/src/ASpanFormer/utils/position_encoding.py +61 -0
- vismatch/third_party/aspanformer/src/ASpanFormer/utils/supervision.py +151 -0
- vismatch/third_party/aspanformer/src/__init__.py +0 -0
- vismatch/third_party/aspanformer/src/config/default.py +180 -0
- vismatch/third_party/aspanformer/src/datasets/__init__.py +3 -0
- vismatch/third_party/aspanformer/src/datasets/megadepth.py +127 -0
- vismatch/third_party/aspanformer/src/datasets/sampler.py +77 -0
- vismatch/third_party/aspanformer/src/datasets/scannet.py +113 -0
- vismatch/third_party/aspanformer/src/lightning/data.py +326 -0
- vismatch/third_party/aspanformer/src/lightning/lightning_aspanformer.py +276 -0
- vismatch/third_party/aspanformer/src/losses/aspan_loss.py +231 -0
- vismatch/third_party/aspanformer/src/optimizers/__init__.py +42 -0
- vismatch/third_party/aspanformer/src/utils/augment.py +55 -0
- vismatch/third_party/aspanformer/src/utils/comm.py +265 -0
- vismatch/third_party/aspanformer/src/utils/dataloader.py +23 -0
- vismatch/third_party/aspanformer/src/utils/dataset.py +222 -0
- vismatch/third_party/aspanformer/src/utils/metrics.py +260 -0
- vismatch/third_party/aspanformer/src/utils/misc.py +139 -0
- vismatch/third_party/aspanformer/src/utils/plotting.py +219 -0
- vismatch/third_party/aspanformer/src/utils/profiler.py +39 -0
- vismatch/third_party/aspanformer/test.py +69 -0
- vismatch/third_party/aspanformer/tools/SensorData.py +125 -0
- vismatch/third_party/aspanformer/tools/extract.py +47 -0
- vismatch/third_party/aspanformer/tools/preprocess_scene.py +242 -0
- vismatch/third_party/aspanformer/tools/reader.py +39 -0
- vismatch/third_party/aspanformer/tools/undistort_mega.py +69 -0
- vismatch/third_party/aspanformer/train.py +134 -0
- vismatch/third_party/duster/croco/datasets/__init__.py +0 -0
- vismatch/third_party/duster/croco/datasets/crops/extract_crops_from_images.py +159 -0
- vismatch/third_party/duster/croco/datasets/habitat_sim/__init__.py +0 -0
- vismatch/third_party/duster/croco/datasets/habitat_sim/generate_from_metadata.py +92 -0
- vismatch/third_party/duster/croco/datasets/habitat_sim/generate_from_metadata_files.py +27 -0
- vismatch/third_party/duster/croco/datasets/habitat_sim/generate_multiview_images.py +177 -0
- vismatch/third_party/duster/croco/datasets/habitat_sim/multiview_habitat_sim_generator.py +390 -0
- vismatch/third_party/duster/croco/datasets/habitat_sim/pack_metadata_files.py +69 -0
- vismatch/third_party/duster/croco/datasets/habitat_sim/paths.py +129 -0
- vismatch/third_party/duster/croco/datasets/pairs_dataset.py +109 -0
- vismatch/third_party/duster/croco/datasets/transforms.py +95 -0
- vismatch/third_party/duster/croco/demo.py +55 -0
- vismatch/third_party/duster/croco/models/blocks.py +241 -0
- vismatch/third_party/duster/croco/models/criterion.py +37 -0
- vismatch/third_party/duster/croco/models/croco.py +249 -0
- vismatch/third_party/duster/croco/models/croco_downstream.py +122 -0
- vismatch/third_party/duster/croco/models/curope/__init__.py +4 -0
- vismatch/third_party/duster/croco/models/curope/curope2d.py +40 -0
- vismatch/third_party/duster/croco/models/curope/setup.py +34 -0
- vismatch/third_party/duster/croco/models/dpt_block.py +450 -0
- vismatch/third_party/duster/croco/models/head_downstream.py +58 -0
- vismatch/third_party/duster/croco/models/masking.py +25 -0
- vismatch/third_party/duster/croco/models/pos_embed.py +157 -0
- vismatch/third_party/duster/croco/pretrain.py +254 -0
- vismatch/third_party/duster/croco/stereoflow/augmentor.py +290 -0
- vismatch/third_party/duster/croco/stereoflow/criterion.py +251 -0
- vismatch/third_party/duster/croco/stereoflow/datasets_flow.py +630 -0
- vismatch/third_party/duster/croco/stereoflow/datasets_stereo.py +674 -0
- vismatch/third_party/duster/croco/stereoflow/engine.py +280 -0
- vismatch/third_party/duster/croco/stereoflow/test.py +216 -0
- vismatch/third_party/duster/croco/stereoflow/train.py +253 -0
- vismatch/third_party/duster/croco/utils/misc.py +463 -0
- vismatch/third_party/duster/datasets_preprocess/habitat/find_scenes.py +78 -0
- vismatch/third_party/duster/datasets_preprocess/habitat/habitat_renderer/__init__.py +2 -0
- vismatch/third_party/duster/datasets_preprocess/habitat/habitat_renderer/habitat_sim_envmaps_renderer.py +170 -0
- vismatch/third_party/duster/datasets_preprocess/habitat/habitat_renderer/multiview_crop_generator.py +93 -0
- vismatch/third_party/duster/datasets_preprocess/habitat/habitat_renderer/projections.py +151 -0
- vismatch/third_party/duster/datasets_preprocess/habitat/habitat_renderer/projections_conversions.py +45 -0
- vismatch/third_party/duster/datasets_preprocess/habitat/preprocess_habitat.py +121 -0
- vismatch/third_party/duster/datasets_preprocess/path_to_root.py +13 -0
- vismatch/third_party/duster/datasets_preprocess/preprocess_arkitscenes.py +355 -0
- vismatch/third_party/duster/datasets_preprocess/preprocess_blendedMVS.py +149 -0
- vismatch/third_party/duster/datasets_preprocess/preprocess_co3d.py +295 -0
- vismatch/third_party/duster/datasets_preprocess/preprocess_megadepth.py +198 -0
- vismatch/third_party/duster/datasets_preprocess/preprocess_scannetpp.py +400 -0
- vismatch/third_party/duster/datasets_preprocess/preprocess_staticthings3d.py +130 -0
- vismatch/third_party/duster/datasets_preprocess/preprocess_waymo.py +257 -0
- vismatch/third_party/duster/datasets_preprocess/preprocess_wildrgbd.py +209 -0
- vismatch/third_party/duster/demo.py +45 -0
- vismatch/third_party/duster/dust3r/__init__.py +2 -0
- vismatch/third_party/duster/dust3r/cloud_opt/__init__.py +33 -0
- vismatch/third_party/duster/dust3r/cloud_opt/base_opt.py +405 -0
- vismatch/third_party/duster/dust3r/cloud_opt/commons.py +90 -0
- vismatch/third_party/duster/dust3r/cloud_opt/init_im_poses.py +316 -0
- vismatch/third_party/duster/dust3r/cloud_opt/modular_optimizer.py +145 -0
- vismatch/third_party/duster/dust3r/cloud_opt/optimizer.py +248 -0
- vismatch/third_party/duster/dust3r/cloud_opt/pair_viewer.py +127 -0
- vismatch/third_party/duster/dust3r/datasets/__init__.py +50 -0
- vismatch/third_party/duster/dust3r/datasets/arkitscenes.py +102 -0
- vismatch/third_party/duster/dust3r/datasets/base/__init__.py +2 -0
- vismatch/third_party/duster/dust3r/datasets/base/base_stereo_view_dataset.py +220 -0
- vismatch/third_party/duster/dust3r/datasets/base/batched_sampler.py +74 -0
- vismatch/third_party/duster/dust3r/datasets/base/easy_dataset.py +157 -0
- vismatch/third_party/duster/dust3r/datasets/blendedmvs.py +104 -0
- vismatch/third_party/duster/dust3r/datasets/co3d.py +165 -0
- vismatch/third_party/duster/dust3r/datasets/habitat.py +107 -0
- vismatch/third_party/duster/dust3r/datasets/megadepth.py +123 -0
- vismatch/third_party/duster/dust3r/datasets/scannetpp.py +96 -0
- vismatch/third_party/duster/dust3r/datasets/staticthings3d.py +96 -0
- vismatch/third_party/duster/dust3r/datasets/utils/__init__.py +2 -0
- vismatch/third_party/duster/dust3r/datasets/utils/cropping.py +124 -0
- vismatch/third_party/duster/dust3r/datasets/utils/transforms.py +11 -0
- vismatch/third_party/duster/dust3r/datasets/waymo.py +93 -0
- vismatch/third_party/duster/dust3r/datasets/wildrgbd.py +67 -0
- vismatch/third_party/duster/dust3r/demo.py +287 -0
- vismatch/third_party/duster/dust3r/heads/__init__.py +19 -0
- vismatch/third_party/duster/dust3r/heads/dpt_head.py +115 -0
- vismatch/third_party/duster/dust3r/heads/linear_head.py +41 -0
- vismatch/third_party/duster/dust3r/heads/postprocess.py +58 -0
- vismatch/third_party/duster/dust3r/image_pairs.py +104 -0
- vismatch/third_party/duster/dust3r/inference.py +150 -0
- vismatch/third_party/duster/dust3r/losses.py +299 -0
- vismatch/third_party/duster/dust3r/model.py +211 -0
- vismatch/third_party/duster/dust3r/optim_factory.py +14 -0
- vismatch/third_party/duster/dust3r/patch_embed.py +70 -0
- vismatch/third_party/duster/dust3r/post_process.py +60 -0
- vismatch/third_party/duster/dust3r/training.py +377 -0
- vismatch/third_party/duster/dust3r/utils/__init__.py +2 -0
- vismatch/third_party/duster/dust3r/utils/device.py +76 -0
- vismatch/third_party/duster/dust3r/utils/geometry.py +366 -0
- vismatch/third_party/duster/dust3r/utils/image.py +128 -0
- vismatch/third_party/duster/dust3r/utils/misc.py +121 -0
- vismatch/third_party/duster/dust3r/utils/parallel.py +79 -0
- vismatch/third_party/duster/dust3r/utils/path_to_croco.py +19 -0
- vismatch/third_party/duster/dust3r/viz.py +381 -0
- vismatch/third_party/duster/dust3r_visloc/__init__.py +2 -0
- vismatch/third_party/duster/dust3r_visloc/datasets/__init__.py +6 -0
- vismatch/third_party/duster/dust3r_visloc/datasets/aachen_day_night.py +24 -0
- vismatch/third_party/duster/dust3r_visloc/datasets/base_colmap.py +282 -0
- vismatch/third_party/duster/dust3r_visloc/datasets/base_dataset.py +19 -0
- vismatch/third_party/duster/dust3r_visloc/datasets/cambridge_landmarks.py +19 -0
- vismatch/third_party/duster/dust3r_visloc/datasets/inloc.py +167 -0
- vismatch/third_party/duster/dust3r_visloc/datasets/sevenscenes.py +123 -0
- vismatch/third_party/duster/dust3r_visloc/datasets/utils.py +118 -0
- vismatch/third_party/duster/dust3r_visloc/evaluation.py +65 -0
- vismatch/third_party/duster/dust3r_visloc/localization.py +140 -0
- vismatch/third_party/duster/train.py +13 -0
- vismatch/third_party/duster/visloc.py +193 -0
- vismatch/third_party/gim/demo.py +479 -0
- vismatch/third_party/gim/dkm/__init__.py +4 -0
- vismatch/third_party/gim/dkm/benchmarks/__init__.py +4 -0
- vismatch/third_party/gim/dkm/benchmarks/hpatches_sequences_homog_benchmark.py +114 -0
- vismatch/third_party/gim/dkm/benchmarks/megadepth1500_benchmark.py +124 -0
- vismatch/third_party/gim/dkm/benchmarks/megadepth_dense_benchmark.py +86 -0
- vismatch/third_party/gim/dkm/benchmarks/scannet_benchmark.py +143 -0
- vismatch/third_party/gim/dkm/checkpointing/__init__.py +1 -0
- vismatch/third_party/gim/dkm/checkpointing/checkpoint.py +31 -0
- vismatch/third_party/gim/dkm/datasets/__init__.py +1 -0
- vismatch/third_party/gim/dkm/datasets/megadepth.py +177 -0
- vismatch/third_party/gim/dkm/datasets/scannet.py +151 -0
- vismatch/third_party/gim/dkm/losses/__init__.py +1 -0
- vismatch/third_party/gim/dkm/losses/depth_match_regression_loss.py +128 -0
- vismatch/third_party/gim/dkm/models/__init__.py +4 -0
- vismatch/third_party/gim/dkm/models/dkm.py +745 -0
- vismatch/third_party/gim/dkm/models/encoders.py +148 -0
- vismatch/third_party/gim/dkm/models/model_zoo/DKMv3.py +148 -0
- vismatch/third_party/gim/dkm/models/model_zoo/__init__.py +39 -0
- vismatch/third_party/gim/dkm/train/__init__.py +1 -0
- vismatch/third_party/gim/dkm/train/train.py +67 -0
- vismatch/third_party/gim/dkm/utils/__init__.py +13 -0
- vismatch/third_party/gim/dkm/utils/kde.py +26 -0
- vismatch/third_party/gim/dkm/utils/local_correlation.py +40 -0
- vismatch/third_party/gim/dkm/utils/transforms.py +104 -0
- vismatch/third_party/gim/dkm/utils/utils.py +341 -0
- vismatch/third_party/gim/gluefactory/__init__.py +17 -0
- vismatch/third_party/gim/gluefactory/datasets/__init__.py +25 -0
- vismatch/third_party/gim/gluefactory/datasets/augmentations.py +244 -0
- vismatch/third_party/gim/gluefactory/datasets/base_dataset.py +206 -0
- vismatch/third_party/gim/gluefactory/datasets/eth3d.py +254 -0
- vismatch/third_party/gim/gluefactory/datasets/homographies.py +311 -0
- vismatch/third_party/gim/gluefactory/datasets/hpatches.py +145 -0
- vismatch/third_party/gim/gluefactory/datasets/image_folder.py +59 -0
- vismatch/third_party/gim/gluefactory/datasets/image_pairs.py +100 -0
- vismatch/third_party/gim/gluefactory/datasets/megadepth.py +514 -0
- vismatch/third_party/gim/gluefactory/datasets/utils.py +131 -0
- vismatch/third_party/gim/gluefactory/eval/__init__.py +20 -0
- vismatch/third_party/gim/gluefactory/eval/eth3d.py +202 -0
- vismatch/third_party/gim/gluefactory/eval/eval_pipeline.py +109 -0
- vismatch/third_party/gim/gluefactory/eval/hpatches.py +203 -0
- vismatch/third_party/gim/gluefactory/eval/inspect.py +61 -0
- vismatch/third_party/gim/gluefactory/eval/io.py +109 -0
- vismatch/third_party/gim/gluefactory/eval/megadepth1500.py +189 -0
- vismatch/third_party/gim/gluefactory/eval/utils.py +272 -0
- vismatch/third_party/gim/gluefactory/geometry/depth.py +88 -0
- vismatch/third_party/gim/gluefactory/geometry/epipolar.py +155 -0
- vismatch/third_party/gim/gluefactory/geometry/gt_generation.py +558 -0
- vismatch/third_party/gim/gluefactory/geometry/homography.py +342 -0
- vismatch/third_party/gim/gluefactory/geometry/utils.py +167 -0
- vismatch/third_party/gim/gluefactory/geometry/wrappers.py +425 -0
- vismatch/third_party/gim/gluefactory/models/__init__.py +30 -0
- vismatch/third_party/gim/gluefactory/models/backbones/__init__.py +0 -0
- vismatch/third_party/gim/gluefactory/models/backbones/dinov2.py +30 -0
- vismatch/third_party/gim/gluefactory/models/base_model.py +157 -0
- vismatch/third_party/gim/gluefactory/models/cache_loader.py +139 -0
- vismatch/third_party/gim/gluefactory/models/extractors/__init__.py +0 -0
- vismatch/third_party/gim/gluefactory/models/extractors/aliked.py +786 -0
- vismatch/third_party/gim/gluefactory/models/extractors/disk_kornia.py +108 -0
- vismatch/third_party/gim/gluefactory/models/extractors/grid_extractor.py +60 -0
- vismatch/third_party/gim/gluefactory/models/extractors/keynet_affnet_hardnet.py +74 -0
- vismatch/third_party/gim/gluefactory/models/extractors/mixed.py +76 -0
- vismatch/third_party/gim/gluefactory/models/extractors/sift.py +234 -0
- vismatch/third_party/gim/gluefactory/models/extractors/sift_kornia.py +46 -0
- vismatch/third_party/gim/gluefactory/models/extractors/superpoint_open.py +210 -0
- vismatch/third_party/gim/gluefactory/models/lines/__init__.py +0 -0
- vismatch/third_party/gim/gluefactory/models/lines/deeplsd.py +106 -0
- vismatch/third_party/gim/gluefactory/models/lines/lsd.py +88 -0
- vismatch/third_party/gim/gluefactory/models/lines/wireframe.py +312 -0
- vismatch/third_party/gim/gluefactory/models/matchers/__init__.py +0 -0
- vismatch/third_party/gim/gluefactory/models/matchers/adalam.py +0 -0
- vismatch/third_party/gim/gluefactory/models/matchers/depth_matcher.py +82 -0
- vismatch/third_party/gim/gluefactory/models/matchers/gluestick.py +776 -0
- vismatch/third_party/gim/gluefactory/models/matchers/homography_matcher.py +66 -0
- vismatch/third_party/gim/gluefactory/models/matchers/kornia_loftr.py +66 -0
- vismatch/third_party/gim/gluefactory/models/matchers/lightglue.py +632 -0
- vismatch/third_party/gim/gluefactory/models/matchers/lightglue_pretrained.py +36 -0
- vismatch/third_party/gim/gluefactory/models/matchers/nearest_neighbor_matcher.py +97 -0
- vismatch/third_party/gim/gluefactory/models/triplet_pipeline.py +99 -0
- vismatch/third_party/gim/gluefactory/models/two_view_pipeline.py +114 -0
- vismatch/third_party/gim/gluefactory/models/utils/__init__.py +0 -0
- vismatch/third_party/gim/gluefactory/models/utils/losses.py +73 -0
- vismatch/third_party/gim/gluefactory/models/utils/metrics.py +50 -0
- vismatch/third_party/gim/gluefactory/models/utils/misc.py +70 -0
- vismatch/third_party/gim/gluefactory/robust_estimators/__init__.py +15 -0
- vismatch/third_party/gim/gluefactory/robust_estimators/base_estimator.py +33 -0
- vismatch/third_party/gim/gluefactory/robust_estimators/homography/__init__.py +0 -0
- vismatch/third_party/gim/gluefactory/robust_estimators/homography/homography_est.py +74 -0
- vismatch/third_party/gim/gluefactory/robust_estimators/homography/opencv.py +53 -0
- vismatch/third_party/gim/gluefactory/robust_estimators/homography/poselib.py +40 -0
- vismatch/third_party/gim/gluefactory/robust_estimators/relative_pose/__init__.py +0 -0
- vismatch/third_party/gim/gluefactory/robust_estimators/relative_pose/opencv.py +64 -0
- vismatch/third_party/gim/gluefactory/robust_estimators/relative_pose/poselib.py +44 -0
- vismatch/third_party/gim/gluefactory/robust_estimators/relative_pose/pycolmap.py +52 -0
- vismatch/third_party/gim/gluefactory/scripts/__init__.py +0 -0
- vismatch/third_party/gim/gluefactory/scripts/export_local_features.py +127 -0
- vismatch/third_party/gim/gluefactory/scripts/export_megadepth.py +173 -0
- vismatch/third_party/gim/gluefactory/settings.py +6 -0
- vismatch/third_party/gim/gluefactory/superpoint.py +361 -0
- vismatch/third_party/gim/gluefactory/train.py +691 -0
- vismatch/third_party/gim/gluefactory/utils/__init__.py +0 -0
- vismatch/third_party/gim/gluefactory/utils/benchmark.py +33 -0
- vismatch/third_party/gim/gluefactory/utils/experiments.py +134 -0
- vismatch/third_party/gim/gluefactory/utils/export_predictions.py +81 -0
- vismatch/third_party/gim/gluefactory/utils/image.py +130 -0
- vismatch/third_party/gim/gluefactory/utils/misc.py +44 -0
- vismatch/third_party/gim/gluefactory/utils/patches.py +50 -0
- vismatch/third_party/gim/gluefactory/utils/stdout_capturing.py +134 -0
- vismatch/third_party/gim/gluefactory/utils/tensor.py +48 -0
- vismatch/third_party/gim/gluefactory/utils/tools.py +269 -0
- vismatch/third_party/gim/gluefactory/visualization/global_frame.py +289 -0
- vismatch/third_party/gim/gluefactory/visualization/tools.py +465 -0
- vismatch/third_party/gim/gluefactory/visualization/two_view_frame.py +158 -0
- vismatch/third_party/gim/gluefactory/visualization/visualize_batch.py +57 -0
- vismatch/third_party/gim/gluefactory/visualization/viz2d.py +486 -0
- vismatch/third_party/imatch-toolbox/configs/d2net.yml +26 -0
- vismatch/third_party/imatch-toolbox/configs/dogaffnethardnet.yml +10 -0
- vismatch/third_party/imatch-toolbox/configs/ncnet.yml +7 -0
- vismatch/third_party/imatch-toolbox/configs/patch2pix.yml +56 -0
- vismatch/third_party/imatch-toolbox/configs/patch2pix_superglue.yml +58 -0
- vismatch/third_party/imatch-toolbox/configs/r2d2.yml +31 -0
- vismatch/third_party/imatch-toolbox/configs/sift.yml +27 -0
- vismatch/third_party/imatch-toolbox/configs/superglue.yml +69 -0
- vismatch/third_party/imatch-toolbox/configs/superpoint.yml +21 -0
- vismatch/third_party/imatch-toolbox/environment.yml +14 -0
- vismatch/third_party/imatch-toolbox/immatch/__init__.py +8 -0
- vismatch/third_party/imatch-toolbox/immatch/eval_aachen.py +88 -0
- vismatch/third_party/imatch-toolbox/immatch/eval_hpatches.py +117 -0
- vismatch/third_party/imatch-toolbox/immatch/eval_inloc.py +45 -0
- vismatch/third_party/imatch-toolbox/immatch/eval_relapose.py +231 -0
- vismatch/third_party/imatch-toolbox/immatch/eval_robotcar.py +83 -0
- vismatch/third_party/imatch-toolbox/immatch/modules/__init__.py +0 -0
- vismatch/third_party/imatch-toolbox/immatch/modules/base.py +89 -0
- vismatch/third_party/imatch-toolbox/immatch/modules/d2net.py +69 -0
- vismatch/third_party/imatch-toolbox/immatch/modules/dogaffnethardnet.py +94 -0
- vismatch/third_party/imatch-toolbox/immatch/modules/nn_matching.py +31 -0
- vismatch/third_party/imatch-toolbox/immatch/modules/patch2pix.py +126 -0
- vismatch/third_party/imatch-toolbox/immatch/modules/r2d2.py +64 -0
- vismatch/third_party/imatch-toolbox/immatch/modules/sift.py +67 -0
- vismatch/third_party/imatch-toolbox/immatch/modules/superglue.py +62 -0
- vismatch/third_party/imatch-toolbox/immatch/modules/superpoint.py +56 -0
- vismatch/third_party/imatch-toolbox/immatch/utils/__init__.py +13 -0
- vismatch/third_party/imatch-toolbox/immatch/utils/colmap/data_parsing.py +257 -0
- vismatch/third_party/imatch-toolbox/immatch/utils/colmap/database.py +362 -0
- vismatch/third_party/imatch-toolbox/immatch/utils/colmap/read_write_model.py +506 -0
- vismatch/third_party/imatch-toolbox/immatch/utils/data_io.py +111 -0
- vismatch/third_party/imatch-toolbox/immatch/utils/hpatches_helper.py +242 -0
- vismatch/third_party/imatch-toolbox/immatch/utils/localize_sfm_helper.py +403 -0
- vismatch/third_party/imatch-toolbox/immatch/utils/metrics.py +90 -0
- vismatch/third_party/imatch-toolbox/immatch/utils/model_helper.py +27 -0
- vismatch/third_party/imatch-toolbox/setup.py +36 -0
- vismatch/third_party/imatch-toolbox/third_party/d2net/extract_features.py +156 -0
- vismatch/third_party/imatch-toolbox/third_party/d2net/extract_kapture.py +248 -0
- vismatch/third_party/imatch-toolbox/third_party/d2net/lib/dataset.py +239 -0
- vismatch/third_party/imatch-toolbox/third_party/d2net/lib/exceptions.py +6 -0
- vismatch/third_party/imatch-toolbox/third_party/d2net/lib/loss.py +340 -0
- vismatch/third_party/imatch-toolbox/third_party/d2net/lib/model.py +121 -0
- vismatch/third_party/imatch-toolbox/third_party/d2net/lib/model_test.py +187 -0
- vismatch/third_party/imatch-toolbox/third_party/d2net/lib/pyramid.py +129 -0
- vismatch/third_party/imatch-toolbox/third_party/d2net/lib/utils.py +167 -0
- vismatch/third_party/imatch-toolbox/third_party/d2net/megadepth_utils/preprocess_scene.py +242 -0
- vismatch/third_party/imatch-toolbox/third_party/d2net/megadepth_utils/undistort_reconstructions.py +69 -0
- vismatch/third_party/imatch-toolbox/third_party/d2net/train.py +279 -0
- vismatch/third_party/imatch-toolbox/third_party/patch2pix/data_pairs/precompute_immatch_val_ovs.py +20 -0
- vismatch/third_party/imatch-toolbox/third_party/patch2pix/environment.yml +21 -0
- vismatch/third_party/imatch-toolbox/third_party/patch2pix/networks/modules.py +167 -0
- vismatch/third_party/imatch-toolbox/third_party/patch2pix/networks/ncn/conv4d.py +91 -0
- vismatch/third_party/imatch-toolbox/third_party/patch2pix/networks/ncn/extract_ncmatches.py +158 -0
- vismatch/third_party/imatch-toolbox/third_party/patch2pix/networks/ncn/model.py +333 -0
- vismatch/third_party/imatch-toolbox/third_party/patch2pix/networks/patch2pix.py +403 -0
- vismatch/third_party/imatch-toolbox/third_party/patch2pix/networks/resnet.py +191 -0
- vismatch/third_party/imatch-toolbox/third_party/patch2pix/networks/utils.py +111 -0
- vismatch/third_party/imatch-toolbox/third_party/patch2pix/train_patch2pix.py +374 -0
- vismatch/third_party/imatch-toolbox/third_party/patch2pix/utils/colmap/data_loading.py +169 -0
- vismatch/third_party/imatch-toolbox/third_party/patch2pix/utils/colmap/read_database.py +175 -0
- vismatch/third_party/imatch-toolbox/third_party/patch2pix/utils/colmap/read_write_model.py +483 -0
- vismatch/third_party/imatch-toolbox/third_party/patch2pix/utils/common/plotting.py +393 -0
- vismatch/third_party/imatch-toolbox/third_party/patch2pix/utils/common/setup_helper.py +59 -0
- vismatch/third_party/imatch-toolbox/third_party/patch2pix/utils/common/visdom_helper.py +95 -0
- vismatch/third_party/imatch-toolbox/third_party/patch2pix/utils/datasets/__init__.py +1 -0
- vismatch/third_party/imatch-toolbox/third_party/patch2pix/utils/datasets/data_parsing.py +145 -0
- vismatch/third_party/imatch-toolbox/third_party/patch2pix/utils/datasets/dataset_megadepth.py +141 -0
- vismatch/third_party/imatch-toolbox/third_party/patch2pix/utils/datasets/preprocess.py +184 -0
- vismatch/third_party/imatch-toolbox/third_party/patch2pix/utils/eval/geometry.py +90 -0
- vismatch/third_party/imatch-toolbox/third_party/patch2pix/utils/eval/measure.py +161 -0
- vismatch/third_party/imatch-toolbox/third_party/patch2pix/utils/eval/model_helper.py +129 -0
- vismatch/third_party/imatch-toolbox/third_party/patch2pix/utils/train/eval_epoch_immatch.py +99 -0
- vismatch/third_party/imatch-toolbox/third_party/patch2pix/utils/train/helper.py +196 -0
- vismatch/third_party/imatch-toolbox/third_party/r2d2/datasets/__init__.py +33 -0
- vismatch/third_party/imatch-toolbox/third_party/r2d2/datasets/aachen.py +146 -0
- vismatch/third_party/imatch-toolbox/third_party/r2d2/datasets/dataset.py +77 -0
- vismatch/third_party/imatch-toolbox/third_party/r2d2/datasets/imgfolder.py +23 -0
- vismatch/third_party/imatch-toolbox/third_party/r2d2/datasets/pair_dataset.py +287 -0
- vismatch/third_party/imatch-toolbox/third_party/r2d2/datasets/web_images.py +64 -0
- vismatch/third_party/imatch-toolbox/third_party/r2d2/extract.py +183 -0
- vismatch/third_party/imatch-toolbox/third_party/r2d2/extract_kapture.py +194 -0
- vismatch/third_party/imatch-toolbox/third_party/r2d2/nets/ap_loss.py +67 -0
- vismatch/third_party/imatch-toolbox/third_party/r2d2/nets/losses.py +56 -0
- vismatch/third_party/imatch-toolbox/third_party/r2d2/nets/patchnet.py +134 -0
- vismatch/third_party/imatch-toolbox/third_party/r2d2/nets/reliability_loss.py +59 -0
- vismatch/third_party/imatch-toolbox/third_party/r2d2/nets/repeatability_loss.py +66 -0
- vismatch/third_party/imatch-toolbox/third_party/r2d2/nets/sampler.py +390 -0
- vismatch/third_party/imatch-toolbox/third_party/r2d2/tools/common.py +41 -0
- vismatch/third_party/imatch-toolbox/third_party/r2d2/tools/dataloader.py +367 -0
- vismatch/third_party/imatch-toolbox/third_party/r2d2/tools/trainer.py +76 -0
- vismatch/third_party/imatch-toolbox/third_party/r2d2/tools/transforms.py +513 -0
- vismatch/third_party/imatch-toolbox/third_party/r2d2/tools/transforms_tools.py +230 -0
- vismatch/third_party/imatch-toolbox/third_party/r2d2/tools/viz.py +191 -0
- vismatch/third_party/imatch-toolbox/third_party/r2d2/train.py +138 -0
- vismatch/third_party/imatch-toolbox/third_party/r2d2/viz_heatmaps.py +122 -0
- vismatch/third_party/imatch-toolbox/third_party/superglue/demo_superglue.py +259 -0
- vismatch/third_party/imatch-toolbox/third_party/superglue/match_pairs.py +425 -0
- vismatch/third_party/imatch-toolbox/third_party/superglue/models/__init__.py +0 -0
- vismatch/third_party/imatch-toolbox/third_party/superglue/models/matching.py +84 -0
- vismatch/third_party/imatch-toolbox/third_party/superglue/models/superglue.py +283 -0
- vismatch/third_party/imatch-toolbox/third_party/superglue/models/superpoint.py +202 -0
- vismatch/third_party/imatch-toolbox/third_party/superglue/models/utils.py +555 -0
- vismatch/third_party/keypt2subpx/dataprocess/aliked.py +163 -0
- vismatch/third_party/keypt2subpx/dataprocess/dedode.py +215 -0
- vismatch/third_party/keypt2subpx/dataprocess/splg.py +162 -0
- vismatch/third_party/keypt2subpx/dataprocess/spnn.py +157 -0
- vismatch/third_party/keypt2subpx/dataprocess/superpoint_densescore.py +357 -0
- vismatch/third_party/keypt2subpx/dataprocess/xfeat.py +187 -0
- vismatch/third_party/keypt2subpx/dataset.py +145 -0
- vismatch/third_party/keypt2subpx/hubconf.py +38 -0
- vismatch/third_party/keypt2subpx/logger.py +127 -0
- vismatch/third_party/keypt2subpx/model.py +183 -0
- vismatch/third_party/keypt2subpx/settings.py +108 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/__init__.py +17 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/datasets/__init__.py +25 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/datasets/augmentations.py +244 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/datasets/base_dataset.py +206 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/datasets/eth3d.py +254 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/datasets/homographies.py +311 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/datasets/hpatches.py +145 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/datasets/image_folder.py +59 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/datasets/image_pairs.py +100 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/datasets/megadepth.py +510 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/datasets/utils.py +131 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/eval/__init__.py +20 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/eval/eth3d.py +202 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/eval/eval_pipeline.py +109 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/eval/hpatches.py +203 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/eval/inspect.py +61 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/eval/io.py +109 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/eval/megadepth1500.py +189 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/eval/utils.py +272 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/geometry/__init__.py +0 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/geometry/depth.py +88 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/geometry/epipolar.py +155 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/geometry/gt_generation.py +558 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/geometry/homography.py +342 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/geometry/utils.py +167 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/geometry/wrappers.py +425 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/models/__init__.py +30 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/models/backbones/__init__.py +0 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/models/backbones/dinov2.py +30 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/models/base_model.py +157 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/models/cache_loader.py +139 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/models/extractors/__init__.py +0 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/models/extractors/aliked.py +786 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/models/extractors/disk_kornia.py +108 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/models/extractors/grid_extractor.py +60 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/models/extractors/keynet_affnet_hardnet.py +74 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/models/extractors/mixed.py +76 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/models/extractors/sift.py +234 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/models/extractors/sift_kornia.py +46 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/models/extractors/superpoint_open.py +210 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/models/lines/__init__.py +0 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/models/lines/deeplsd.py +106 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/models/lines/lsd.py +88 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/models/lines/wireframe.py +312 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/models/matchers/__init__.py +0 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/models/matchers/adalam.py +0 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/models/matchers/depth_matcher.py +82 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/models/matchers/gluestick.py +776 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/models/matchers/homography_matcher.py +66 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/models/matchers/kornia_loftr.py +66 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/models/matchers/lightglue.py +612 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/models/matchers/lightglue_pretrained.py +36 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/models/matchers/nearest_neighbor_matcher.py +97 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/models/triplet_pipeline.py +99 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/models/two_view_pipeline.py +114 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/models/utils/__init__.py +0 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/models/utils/losses.py +73 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/models/utils/metrics.py +50 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/models/utils/misc.py +70 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/robust_estimators/__init__.py +15 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/robust_estimators/base_estimator.py +33 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/robust_estimators/homography/__init__.py +0 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/robust_estimators/homography/homography_est.py +74 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/robust_estimators/homography/opencv.py +53 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/robust_estimators/homography/poselib.py +40 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/robust_estimators/relative_pose/__init__.py +0 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/robust_estimators/relative_pose/opencv.py +64 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/robust_estimators/relative_pose/poselib.py +44 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/robust_estimators/relative_pose/pycolmap.py +52 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/scripts/__init__.py +0 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/scripts/export_local_features.py +127 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/scripts/export_megadepth.py +173 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/settings.py +6 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/train.py +691 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/utils/__init__.py +0 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/utils/benchmark.py +33 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/utils/experiments.py +134 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/utils/export_predictions.py +81 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/utils/image.py +130 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/utils/misc.py +44 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/utils/patches.py +50 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/utils/stdout_capturing.py +134 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/utils/tensor.py +48 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/utils/tools.py +269 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/visualization/global_frame.py +289 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/visualization/tools.py +465 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/visualization/two_view_frame.py +158 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/visualization/visualize_batch.py +57 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/visualization/viz2d.py +486 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory_nonfree/__init__.py +0 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory_nonfree/superglue.py +342 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory_nonfree/superpoint.py +356 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/tests/__init__.py +0 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/tests/test_eval_utils.py +88 -0
- vismatch/third_party/keypt2subpx/submodules/glue_factory/tests/test_integration.py +132 -0
- vismatch/third_party/keypt2subpx/summarize.py +44 -0
- vismatch/third_party/keypt2subpx/test.py +225 -0
- vismatch/third_party/keypt2subpx/train.py +180 -0
- vismatch/third_party/keypt2subpx/utils.py +150 -0
- vismatch/third_party/mast3r/demo.py +51 -0
- vismatch/third_party/mast3r/demo_dust3r_ga.py +99 -0
- vismatch/third_party/mast3r/demo_glomap.py +52 -0
- vismatch/third_party/mast3r/dust3r/croco/datasets/__init__.py +0 -0
- vismatch/third_party/mast3r/dust3r/croco/datasets/crops/extract_crops_from_images.py +159 -0
- vismatch/third_party/mast3r/dust3r/croco/datasets/habitat_sim/__init__.py +0 -0
- vismatch/third_party/mast3r/dust3r/croco/datasets/habitat_sim/generate_from_metadata.py +92 -0
- vismatch/third_party/mast3r/dust3r/croco/datasets/habitat_sim/generate_from_metadata_files.py +27 -0
- vismatch/third_party/mast3r/dust3r/croco/datasets/habitat_sim/generate_multiview_images.py +177 -0
- vismatch/third_party/mast3r/dust3r/croco/datasets/habitat_sim/multiview_habitat_sim_generator.py +390 -0
- vismatch/third_party/mast3r/dust3r/croco/datasets/habitat_sim/pack_metadata_files.py +69 -0
- vismatch/third_party/mast3r/dust3r/croco/datasets/habitat_sim/paths.py +129 -0
- vismatch/third_party/mast3r/dust3r/croco/datasets/pairs_dataset.py +109 -0
- vismatch/third_party/mast3r/dust3r/croco/datasets/transforms.py +95 -0
- vismatch/third_party/mast3r/dust3r/croco/demo.py +55 -0
- vismatch/third_party/mast3r/dust3r/croco/models/blocks.py +241 -0
- vismatch/third_party/mast3r/dust3r/croco/models/criterion.py +37 -0
- vismatch/third_party/mast3r/dust3r/croco/models/croco.py +249 -0
- vismatch/third_party/mast3r/dust3r/croco/models/croco_downstream.py +122 -0
- vismatch/third_party/mast3r/dust3r/croco/models/curope/__init__.py +4 -0
- vismatch/third_party/mast3r/dust3r/croco/models/curope/curope2d.py +40 -0
- vismatch/third_party/mast3r/dust3r/croco/models/curope/setup.py +34 -0
- vismatch/third_party/mast3r/dust3r/croco/models/dpt_block.py +450 -0
- vismatch/third_party/mast3r/dust3r/croco/models/head_downstream.py +58 -0
- vismatch/third_party/mast3r/dust3r/croco/models/masking.py +25 -0
- vismatch/third_party/mast3r/dust3r/croco/models/pos_embed.py +157 -0
- vismatch/third_party/mast3r/dust3r/croco/pretrain.py +254 -0
- vismatch/third_party/mast3r/dust3r/croco/stereoflow/augmentor.py +290 -0
- vismatch/third_party/mast3r/dust3r/croco/stereoflow/criterion.py +251 -0
- vismatch/third_party/mast3r/dust3r/croco/stereoflow/datasets_flow.py +630 -0
- vismatch/third_party/mast3r/dust3r/croco/stereoflow/datasets_stereo.py +674 -0
- vismatch/third_party/mast3r/dust3r/croco/stereoflow/engine.py +280 -0
- vismatch/third_party/mast3r/dust3r/croco/stereoflow/test.py +216 -0
- vismatch/third_party/mast3r/dust3r/croco/stereoflow/train.py +253 -0
- vismatch/third_party/mast3r/dust3r/croco/utils/misc.py +463 -0
- vismatch/third_party/mast3r/dust3r/datasets_preprocess/habitat/find_scenes.py +78 -0
- vismatch/third_party/mast3r/dust3r/datasets_preprocess/habitat/habitat_renderer/__init__.py +2 -0
- vismatch/third_party/mast3r/dust3r/datasets_preprocess/habitat/habitat_renderer/habitat_sim_envmaps_renderer.py +170 -0
- vismatch/third_party/mast3r/dust3r/datasets_preprocess/habitat/habitat_renderer/multiview_crop_generator.py +93 -0
- vismatch/third_party/mast3r/dust3r/datasets_preprocess/habitat/habitat_renderer/projections.py +151 -0
- vismatch/third_party/mast3r/dust3r/datasets_preprocess/habitat/habitat_renderer/projections_conversions.py +45 -0
- vismatch/third_party/mast3r/dust3r/datasets_preprocess/habitat/preprocess_habitat.py +121 -0
- vismatch/third_party/mast3r/dust3r/datasets_preprocess/path_to_root.py +13 -0
- vismatch/third_party/mast3r/dust3r/datasets_preprocess/preprocess_arkitscenes.py +355 -0
- vismatch/third_party/mast3r/dust3r/datasets_preprocess/preprocess_blendedMVS.py +149 -0
- vismatch/third_party/mast3r/dust3r/datasets_preprocess/preprocess_co3d.py +295 -0
- vismatch/third_party/mast3r/dust3r/datasets_preprocess/preprocess_megadepth.py +198 -0
- vismatch/third_party/mast3r/dust3r/datasets_preprocess/preprocess_scannetpp.py +390 -0
- vismatch/third_party/mast3r/dust3r/datasets_preprocess/preprocess_staticthings3d.py +130 -0
- vismatch/third_party/mast3r/dust3r/datasets_preprocess/preprocess_waymo.py +257 -0
- vismatch/third_party/mast3r/dust3r/datasets_preprocess/preprocess_wildrgbd.py +209 -0
- vismatch/third_party/mast3r/dust3r/demo.py +45 -0
- vismatch/third_party/mast3r/dust3r/dust3r/__init__.py +2 -0
- vismatch/third_party/mast3r/dust3r/dust3r/cloud_opt/__init__.py +33 -0
- vismatch/third_party/mast3r/dust3r/dust3r/cloud_opt/base_opt.py +405 -0
- vismatch/third_party/mast3r/dust3r/dust3r/cloud_opt/commons.py +90 -0
- vismatch/third_party/mast3r/dust3r/dust3r/cloud_opt/init_im_poses.py +316 -0
- vismatch/third_party/mast3r/dust3r/dust3r/cloud_opt/modular_optimizer.py +145 -0
- vismatch/third_party/mast3r/dust3r/dust3r/cloud_opt/optimizer.py +248 -0
- vismatch/third_party/mast3r/dust3r/dust3r/cloud_opt/pair_viewer.py +127 -0
- vismatch/third_party/mast3r/dust3r/dust3r/datasets/__init__.py +50 -0
- vismatch/third_party/mast3r/dust3r/dust3r/datasets/arkitscenes.py +102 -0
- vismatch/third_party/mast3r/dust3r/dust3r/datasets/base/__init__.py +2 -0
- vismatch/third_party/mast3r/dust3r/dust3r/datasets/base/base_stereo_view_dataset.py +220 -0
- vismatch/third_party/mast3r/dust3r/dust3r/datasets/base/batched_sampler.py +74 -0
- vismatch/third_party/mast3r/dust3r/dust3r/datasets/base/easy_dataset.py +157 -0
- vismatch/third_party/mast3r/dust3r/dust3r/datasets/blendedmvs.py +104 -0
- vismatch/third_party/mast3r/dust3r/dust3r/datasets/co3d.py +165 -0
- vismatch/third_party/mast3r/dust3r/dust3r/datasets/habitat.py +107 -0
- vismatch/third_party/mast3r/dust3r/dust3r/datasets/megadepth.py +123 -0
- vismatch/third_party/mast3r/dust3r/dust3r/datasets/scannetpp.py +96 -0
- vismatch/third_party/mast3r/dust3r/dust3r/datasets/staticthings3d.py +96 -0
- vismatch/third_party/mast3r/dust3r/dust3r/datasets/utils/__init__.py +2 -0
- vismatch/third_party/mast3r/dust3r/dust3r/datasets/utils/cropping.py +124 -0
- vismatch/third_party/mast3r/dust3r/dust3r/datasets/utils/transforms.py +11 -0
- vismatch/third_party/mast3r/dust3r/dust3r/datasets/waymo.py +93 -0
- vismatch/third_party/mast3r/dust3r/dust3r/datasets/wildrgbd.py +67 -0
- vismatch/third_party/mast3r/dust3r/dust3r/demo.py +287 -0
- vismatch/third_party/mast3r/dust3r/dust3r/heads/__init__.py +19 -0
- vismatch/third_party/mast3r/dust3r/dust3r/heads/dpt_head.py +115 -0
- vismatch/third_party/mast3r/dust3r/dust3r/heads/linear_head.py +41 -0
- vismatch/third_party/mast3r/dust3r/dust3r/heads/postprocess.py +58 -0
- vismatch/third_party/mast3r/dust3r/dust3r/image_pairs.py +104 -0
- vismatch/third_party/mast3r/dust3r/dust3r/inference.py +150 -0
- vismatch/third_party/mast3r/dust3r/dust3r/losses.py +299 -0
- vismatch/third_party/mast3r/dust3r/dust3r/model.py +211 -0
- vismatch/third_party/mast3r/dust3r/dust3r/optim_factory.py +14 -0
- vismatch/third_party/mast3r/dust3r/dust3r/patch_embed.py +70 -0
- vismatch/third_party/mast3r/dust3r/dust3r/post_process.py +60 -0
- vismatch/third_party/mast3r/dust3r/dust3r/training.py +377 -0
- vismatch/third_party/mast3r/dust3r/dust3r/utils/__init__.py +2 -0
- vismatch/third_party/mast3r/dust3r/dust3r/utils/device.py +76 -0
- vismatch/third_party/mast3r/dust3r/dust3r/utils/geometry.py +366 -0
- vismatch/third_party/mast3r/dust3r/dust3r/utils/image.py +128 -0
- vismatch/third_party/mast3r/dust3r/dust3r/utils/misc.py +121 -0
- vismatch/third_party/mast3r/dust3r/dust3r/utils/parallel.py +79 -0
- vismatch/third_party/mast3r/dust3r/dust3r/utils/path_to_croco.py +19 -0
- vismatch/third_party/mast3r/dust3r/dust3r/viz.py +381 -0
- vismatch/third_party/mast3r/dust3r/dust3r_visloc/__init__.py +2 -0
- vismatch/third_party/mast3r/dust3r/dust3r_visloc/datasets/__init__.py +6 -0
- vismatch/third_party/mast3r/dust3r/dust3r_visloc/datasets/aachen_day_night.py +24 -0
- vismatch/third_party/mast3r/dust3r/dust3r_visloc/datasets/base_colmap.py +282 -0
- vismatch/third_party/mast3r/dust3r/dust3r_visloc/datasets/base_dataset.py +19 -0
- vismatch/third_party/mast3r/dust3r/dust3r_visloc/datasets/cambridge_landmarks.py +19 -0
- vismatch/third_party/mast3r/dust3r/dust3r_visloc/datasets/inloc.py +167 -0
- vismatch/third_party/mast3r/dust3r/dust3r_visloc/datasets/sevenscenes.py +123 -0
- vismatch/third_party/mast3r/dust3r/dust3r_visloc/datasets/utils.py +118 -0
- vismatch/third_party/mast3r/dust3r/dust3r_visloc/evaluation.py +65 -0
- vismatch/third_party/mast3r/dust3r/dust3r_visloc/localization.py +140 -0
- vismatch/third_party/mast3r/dust3r/train.py +13 -0
- vismatch/third_party/mast3r/dust3r/visloc.py +193 -0
- vismatch/third_party/mast3r/kapture_mast3r_mapping.py +127 -0
- vismatch/third_party/mast3r/make_pairs.py +105 -0
- vismatch/third_party/mast3r/mast3r/__init__.py +2 -0
- vismatch/third_party/mast3r/mast3r/catmlp_dpt_head.py +239 -0
- vismatch/third_party/mast3r/mast3r/cloud_opt/__init__.py +2 -0
- vismatch/third_party/mast3r/mast3r/cloud_opt/sparse_ga.py +1078 -0
- vismatch/third_party/mast3r/mast3r/cloud_opt/triangulation.py +80 -0
- vismatch/third_party/mast3r/mast3r/cloud_opt/tsdf_optimizer.py +273 -0
- vismatch/third_party/mast3r/mast3r/cloud_opt/utils/__init__.py +2 -0
- vismatch/third_party/mast3r/mast3r/cloud_opt/utils/losses.py +32 -0
- vismatch/third_party/mast3r/mast3r/cloud_opt/utils/schedules.py +17 -0
- vismatch/third_party/mast3r/mast3r/colmap/__init__.py +2 -0
- vismatch/third_party/mast3r/mast3r/colmap/database.py +383 -0
- vismatch/third_party/mast3r/mast3r/colmap/mapping.py +196 -0
- vismatch/third_party/mast3r/mast3r/datasets/__init__.py +62 -0
- vismatch/third_party/mast3r/mast3r/datasets/base/__init__.py +2 -0
- vismatch/third_party/mast3r/mast3r/datasets/base/mast3r_base_stereo_view_dataset.py +355 -0
- vismatch/third_party/mast3r/mast3r/datasets/utils/__init__.py +2 -0
- vismatch/third_party/mast3r/mast3r/datasets/utils/cropping.py +219 -0
- vismatch/third_party/mast3r/mast3r/demo.py +381 -0
- vismatch/third_party/mast3r/mast3r/demo_glomap.py +343 -0
- vismatch/third_party/mast3r/mast3r/fast_nn.py +223 -0
- vismatch/third_party/mast3r/mast3r/image_pairs.py +115 -0
- vismatch/third_party/mast3r/mast3r/losses.py +508 -0
- vismatch/third_party/mast3r/mast3r/model.py +213 -0
- vismatch/third_party/mast3r/mast3r/retrieval/graph.py +77 -0
- vismatch/third_party/mast3r/mast3r/retrieval/model.py +271 -0
- vismatch/third_party/mast3r/mast3r/retrieval/processor.py +129 -0
- vismatch/third_party/mast3r/mast3r/utils/__init__.py +2 -0
- vismatch/third_party/mast3r/mast3r/utils/coarse_to_fine.py +214 -0
- vismatch/third_party/mast3r/mast3r/utils/collate.py +62 -0
- vismatch/third_party/mast3r/mast3r/utils/misc.py +17 -0
- vismatch/third_party/mast3r/mast3r/utils/path_to_dust3r.py +19 -0
- vismatch/third_party/mast3r/train.py +48 -0
- vismatch/third_party/mast3r/visloc.py +538 -0
- vismatch/third_party/omniglue/__init__.py +19 -0
- vismatch/third_party/omniglue/demo.py +89 -0
- vismatch/third_party/omniglue/src/omniglue/__init__.py +17 -0
- vismatch/third_party/omniglue/src/omniglue/dino_extract.py +215 -0
- vismatch/third_party/omniglue/src/omniglue/omniglue_extract.py +159 -0
- vismatch/third_party/omniglue/src/omniglue/superpoint_extract.py +214 -0
- vismatch/third_party/omniglue/src/omniglue/utils.py +274 -0
- vismatch/third_party/omniglue/third_party/dinov2/__init__.py +0 -0
- vismatch/third_party/omniglue/third_party/dinov2/dino.py +411 -0
- vismatch/third_party/omniglue/third_party/dinov2/dino_utils.py +341 -0
- vismatch/third_party/rdd/RDD/RDD.py +262 -0
- vismatch/third_party/rdd/RDD/RDD_helper.py +181 -0
- vismatch/third_party/rdd/RDD/dataset/__init__.py +0 -0
- vismatch/third_party/rdd/RDD/dataset/megadepth/__init__.py +2 -0
- vismatch/third_party/rdd/RDD/dataset/megadepth/megadepth.py +313 -0
- vismatch/third_party/rdd/RDD/dataset/megadepth/megadepth_warper.py +75 -0
- vismatch/third_party/rdd/RDD/dataset/megadepth/utils.py +848 -0
- vismatch/third_party/rdd/RDD/matchers/__init__.py +3 -0
- vismatch/third_party/rdd/RDD/matchers/dense_matcher.py +137 -0
- vismatch/third_party/rdd/RDD/matchers/dual_softmax_matcher.py +31 -0
- vismatch/third_party/rdd/RDD/matchers/lightglue.py +667 -0
- vismatch/third_party/rdd/RDD/models/backbone.py +147 -0
- vismatch/third_party/rdd/RDD/models/deformable_transformer.py +270 -0
- vismatch/third_party/rdd/RDD/models/descriptor.py +116 -0
- vismatch/third_party/rdd/RDD/models/detector.py +141 -0
- vismatch/third_party/rdd/RDD/models/interpolator.py +33 -0
- vismatch/third_party/rdd/RDD/models/ops/functions/__init__.py +13 -0
- vismatch/third_party/rdd/RDD/models/ops/functions/ms_deform_attn_func.py +74 -0
- vismatch/third_party/rdd/RDD/models/ops/modules/__init__.py +12 -0
- vismatch/third_party/rdd/RDD/models/ops/modules/ms_deform_attn.py +125 -0
- vismatch/third_party/rdd/RDD/models/ops/setup.py +78 -0
- vismatch/third_party/rdd/RDD/models/ops/test.py +92 -0
- vismatch/third_party/rdd/RDD/models/position_encoding.py +48 -0
- vismatch/third_party/rdd/RDD/models/soft_detect.py +176 -0
- vismatch/third_party/rdd/RDD/utils/__init__.py +1 -0
- vismatch/third_party/rdd/RDD/utils/misc.py +531 -0
- vismatch/third_party/rdd/benchmarks/air_ground.py +250 -0
- vismatch/third_party/rdd/benchmarks/mega_1500.py +259 -0
- vismatch/third_party/rdd/benchmarks/mega_view.py +252 -0
- vismatch/third_party/rdd/benchmarks/scannet_1500.py +251 -0
- vismatch/third_party/rdd/benchmarks/utils.py +112 -0
- vismatch/third_party/rdd/configs/default.yaml +19 -0
- vismatch/third_party/rdd/sfm/extract_rdd.py +145 -0
- vismatch/third_party/rdd/sfm/match_rdd.py +259 -0
- vismatch/third_party/rdd/third_party/LightGlue/.github/workflows/code-quality.yml +24 -0
- vismatch/third_party/rdd/third_party/LightGlue/benchmark.py +255 -0
- vismatch/third_party/rdd/third_party/LightGlue/lightglue/__init__.py +7 -0
- vismatch/third_party/rdd/third_party/LightGlue/lightglue/aliked.py +760 -0
- vismatch/third_party/rdd/third_party/LightGlue/lightglue/disk.py +55 -0
- vismatch/third_party/rdd/third_party/LightGlue/lightglue/dog_hardnet.py +41 -0
- vismatch/third_party/rdd/third_party/LightGlue/lightglue/lightglue.py +662 -0
- vismatch/third_party/rdd/third_party/LightGlue/lightglue/sift.py +216 -0
- vismatch/third_party/rdd/third_party/LightGlue/lightglue/superpoint.py +227 -0
- vismatch/third_party/rdd/third_party/LightGlue/lightglue/utils.py +165 -0
- vismatch/third_party/rdd/third_party/LightGlue/lightglue/viz2d.py +203 -0
- vismatch/third_party/rdd/third_party/__init__.py +1 -0
- vismatch/third_party/rdd/third_party/aliked_wrapper.py +17 -0
- vismatch/third_party/rdd/training/losses/descriptor_loss.py +73 -0
- vismatch/third_party/rdd/training/losses/detector_loss.py +499 -0
- vismatch/third_party/rdd/training/train.py +473 -0
- vismatch/third_party/rdd/training/utils.py +246 -0
- vismatch/utils.py +390 -0
- vismatch/viz.py +222 -0
- vismatch-1.1.1.dist-info/METADATA +265 -0
- vismatch-1.1.1.dist-info/RECORD +2042 -0
- vismatch-1.1.1.dist-info/WHEEL +5 -0
- vismatch-1.1.1.dist-info/entry_points.txt +4 -0
- vismatch-1.1.1.dist-info/licenses/LICENSE +28 -0
- vismatch-1.1.1.dist-info/top_level.txt +4 -0
- vismatch_extract.py +103 -0
- vismatch_match.py +114 -0
- vismatch_test.py +186 -0
|
@@ -0,0 +1,1768 @@
|
|
|
1
|
+
import copy
|
|
2
|
+
import torch
|
|
3
|
+
import torch.nn as nn
|
|
4
|
+
import torch.nn.functional as F
|
|
5
|
+
from .linear_attention import LinearAttention, RoPELinearAttention, FullAttention, XAttention
|
|
6
|
+
from einops.einops import rearrange
|
|
7
|
+
from collections import OrderedDict
|
|
8
|
+
from .transformer_utils import TokenConfidence, MatchAssignment, filter_matches
|
|
9
|
+
from ..utils.coarse_matching import CoarseMatching
|
|
10
|
+
from ..utils.position_encoding import RoPEPositionEncodingSine
|
|
11
|
+
import numpy as np
|
|
12
|
+
from loguru import logger
|
|
13
|
+
|
|
14
|
+
# Layout switch read by the encoder layers' forward passes: when True, q/k/v
# are arranged as [N, nhead, L, d] before attention and the attention output is
# transposed back to [N, L, nhead, d]; when False the [N, L, nhead, d] layout is
# used throughout. It is hard-coded to False in this module.
PFLASH_AVAILABLE = False
|
|
15
|
+
|
|
16
|
+
class PANEncoderLayer(nn.Module):
    """Attention layer that attends on pooled feature maps.

    ``x`` and ``source`` ([N, C, H, W] maps) are down-sampled by ``pool_size``,
    layer-normalised, projected with depth-wise 3x3 convolutions, passed
    through multi-head attention, merged, up-sampled back to the resolution of
    ``x`` and fused with ``x`` by an MLP whose output is added residually.

    Args:
        d_model: channel count C of the feature maps.
        nhead: number of attention heads; the per-head width is
            ``d_model // nhead``.
        attention: ``'linear'`` selects ``LinearAttention``, anything else
            ``FullAttention`` (ignored when ``xformer`` is true).
        pool_size: down-sampling factor applied before attention.
        bn: add a BatchNorm2d after each depth-wise q/k/v convolution.
        xformer: use ``XAttention`` instead of the two options above.
        leaky: if > 0, negative slope of a LeakyReLU in the MLP; otherwise ReLU.
        dw_conv: pool the query with a learned strided depth-wise convolution
            instead of max pooling (key/value are always max pooled).
        scatter: rejected by an assertion in ``__init__``; the scatter
            up-sampling branch is kept only for reference.
    """

    def __init__(self,
                 d_model,
                 nhead,
                 attention='linear',
                 pool_size=4,
                 bn=True,
                 xformer=False,
                 leaky=-1.0,
                 dw_conv=False,
                 scatter=False,
                 ):
        super(PANEncoderLayer, self).__init__()

        self.pool_size = pool_size
        self.dw_conv = dw_conv
        self.scatter = scatter
        if self.dw_conv:
            self.aggregate = nn.Conv2d(d_model, d_model, kernel_size=pool_size, padding=0,
                                       stride=pool_size, bias=False, groups=d_model)

        assert not self.scatter, 'buggy implemented here'
        self.dim = d_model // nhead
        self.nhead = nhead

        self.max_pool = torch.nn.MaxPool2d(kernel_size=self.pool_size, stride=self.pool_size)

        # multi-head attention
        method = 'dw_bn' if bn else 'dw'
        self.q_proj_conv = self._build_projection(d_model, d_model, method=method)
        self.k_proj_conv = self._build_projection(d_model, d_model, method=method)
        self.v_proj_conv = self._build_projection(d_model, d_model, method=method)

        if xformer:
            self.attention = XAttention()
        else:
            self.attention = LinearAttention() if attention == 'linear' else FullAttention()
        self.merge = nn.Linear(d_model, d_model, bias=False)

        # feed-forward network: only the activation depends on `leaky`
        activation = nn.LeakyReLU(leaky, True) if leaky > 0 else nn.ReLU(True)
        self.mlp = nn.Sequential(
            nn.Linear(d_model*2, d_model*2, bias=False),
            activation,
            nn.Linear(d_model*2, d_model, bias=False),
        )

        # norm and dropout
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    # ------------------------------------------------------------------ helpers

    def _pool_and_norm(self, feat, pool):
        """Down-sample ``feat`` with ``pool`` and apply ``norm1`` over channels."""
        return self.norm1(pool(feat).permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

    @staticmethod
    def _valid_extent(mask):
        """Return (height, width) read from the first column / first row of a [H, W] mask.

        The counts are the number of True entries in column 0 and row 0; the
        callers use them as the size of a top-left valid rectangle.
        """
        return int(mask.sum(-2)[0]), int(mask.sum(-1)[0])

    @staticmethod
    def _zero_pad_hw(message, full_h, full_w):
        """Zero-pad a [N, h, w, nhead, d] tensor at the bottom/right to (full_h, full_w).

        Both axes are padded in one call, so inputs that are short in height
        and width at the same time are handled as well.
        """
        pad_h = full_h - message.size(1)
        pad_w = full_w - message.size(2)
        if pad_h == 0 and pad_w == 0:
            return message
        # F.pad lists (before, after) pairs starting from the last dimension.
        return F.pad(message, (0, 0, 0, 0, 0, pad_w, 0, pad_h))

    def _attend(self, query, key, value):
        """Run multi-head attention on [N, C, h, w] maps; returns [N, L, nhead, d]."""
        if PFLASH_AVAILABLE:  # N H L D
            layout = 'n (nhead d) h w -> n nhead (h w) d'
        else:  # N L H D
            layout = 'n (nhead d) h w -> n (h w) nhead d'
        query, key, value = (
            rearrange(t, layout, nhead=self.nhead, d=self.dim) for t in (query, key, value)
        )
        message = self.attention(query, key, value, q_mask=None, kv_mask=None)
        if PFLASH_AVAILABLE:  # back to N L H D
            message = rearrange(message, 'n nhead L d -> n L nhead d', nhead=self.nhead, d=self.dim)
        return message

    def _upsample(self, message):
        """Bring a pooled [N, C, h, w] message back to the resolution of ``x``."""
        p = self.pool_size
        if self.scatter:
            # Unreachable while __init__ rejects scatter=True; kept for reference.
            message = torch.repeat_interleave(message, p, dim=-2)
            message = torch.repeat_interleave(message, p, dim=-1)
            weight = self.aggregate.weight.data.reshape(1, message.size(1), p, p)
            return message * weight.repeat(1, 1, message.shape[-2] // p, message.shape[-1] // p)
        return F.interpolate(message, scale_factor=p, mode='bilinear', align_corners=False)  # [N, C, H1, W1]

    def _feed_forward(self, x, message):
        """Fuse ``x`` with the up-sampled message; returns the normalised update [N, C, H1, W1]."""
        message = self.mlp(torch.cat([x, message], dim=1).permute(0, 2, 3, 1))  # [N, H1, W1, C]
        return self.norm2(message).permute(0, 3, 1, 2)  # [N, C, H1, W1]

    # ------------------------------------------------------------------ forward

    def forward(self, x, source, x_mask=None, source_mask=None):
        """
        Args:
            x (torch.Tensor): [N, C, H1, W1]
            source (torch.Tensor): [N, C, H2, W2]
            x_mask (torch.Tensor): [N, H1, W1] (optional) (L = H1*W1)
            source_mask (torch.Tensor): [N, H2, W2] (optional) (S = H2*W2)

        Returns:
            torch.Tensor: ``x`` plus the attention/MLP update, [N, C, H1, W1].

        Masks must be given for both inputs or for neither. With masks, each
        sample is cropped to its valid top-left rectangle before attention and
        the result is zero-padded back to the pooled grid.
        """
        bs = x.size(0)
        H1, W1 = x.size(-2), x.size(-1)

        query = self._pool_and_norm(x, self.aggregate if self.dw_conv else self.max_pool)
        # key and value share the same pooled/normalised source; compute it once.
        pooled_source = self._pool_and_norm(source, self.max_pool)

        # multi-head attention
        query = self.q_proj_conv(query)  # [N, C, H1//pool, W1//pool]
        key = self.k_proj_conv(pooled_source)
        value = self.v_proj_conv(pooled_source)

        if x_mask is not None and source_mask is not None:
            x_mask = self.max_pool(x_mask.float()).bool()  # [N, H1//pool, W1//pool]
            source_mask = self.max_pool(source_mask.float()).bool()
            full_h, full_w = x_mask.size(-2), x_mask.size(-1)

            per_sample = []
            for i in range(bs):
                h0, w0 = self._valid_extent(x_mask[i])
                h1, w1 = self._valid_extent(source_mask[i])
                m = self._attend(
                    query[i:i+1, :, :h0, :w0],
                    key[i:i+1, :, :h1, :w1],
                    value[i:i+1, :, :h1, :w1],
                )  # [1, L, H, D]
                m = m.reshape(1, h0, w0, self.nhead, self.dim)
                per_sample.append(self._zero_pad_hw(m, full_h, full_w))
            message = torch.cat(per_sample, dim=0)
        else:
            assert x_mask is None and source_mask is None
            message = self._attend(query, key, value)  # [N, L, H, D]

        message = self.merge(message.reshape(bs, -1, self.nhead*self.dim))  # [N, L, C]
        message = rearrange(message, 'b (h w) c -> b c h w',
                            h=H1//self.pool_size, w=W1//self.pool_size)  # [N, C, H, W]
        message = self._upsample(message)

        # feed-forward network
        return x + self._feed_forward(x, message)

    def pro(self, x, source, x_mask=None, source_mask=None, profiler=None):
        """Profiled variant of the forward pass.

        Args:
            x (torch.Tensor): [N, C, H1, W1]
            source (torch.Tensor): [N, C, H2, W2]
            x_mask (torch.Tensor): [N, H1, W1] (optional) (L = H1*W1)
            source_mask (torch.Tensor): [N, H2, W2] (optional) (S = H2*W2)
            profiler: object whose ``profile(label)`` returns a context
                manager; each stage runs inside one. ``None`` disables
                profiling.

        Note:
            Unlike ``forward``, the valid region of sample 0 is applied to the
            whole batch when masks are given.
        """
        if profiler is None:
            from contextlib import nullcontext

            def stage(_label):
                return nullcontext()
        else:
            stage = profiler.profile

        bs = x.size(0)
        H1, W1 = x.size(-2), x.size(-1)

        # key and value are deliberately computed separately here so the stage
        # label keeps matching the work being measured.
        with stage("permute*6+norm1*3+max_pool*3"):
            query = self._pool_and_norm(x, self.aggregate if self.dw_conv else self.max_pool)
            key = self._pool_and_norm(source, self.max_pool)
            value = self._pool_and_norm(source, self.max_pool)

        # Round trip that leaves the tensors' shapes unchanged; it exists only
        # to time the permutes.
        with stage("permute*6"):
            query = query.permute(0, 2, 3, 1)
            key = key.permute(0, 2, 3, 1)
            value = value.permute(0, 2, 3, 1)

            query = query.permute(0, 3, 1, 2)
            key = key.permute(0, 3, 1, 2)
            value = value.permute(0, 3, 1, 2)

        # multi-head attention
        with stage("q_conv+k_conv+v_conv"):
            query = self.q_proj_conv(query)  # [N, C, H1//pool, W1//pool]
            key = self.k_proj_conv(key)
            value = self.v_proj_conv(value)

        # TODO: Need to be consistent with bs=1 (where mask region do not in attention at all)
        use_mask = x_mask is not None and source_mask is not None
        if use_mask:
            x_mask = self.max_pool(x_mask.float()).bool()  # [N, H1//pool, W1//pool]
            source_mask = self.max_pool(source_mask.float()).bool()

            h0, w0 = self._valid_extent(x_mask[0])
            h1, w1 = self._valid_extent(source_mask[0])

            query = query[:, :, :h0, :w0]
            key = key[:, :, :h1, :w1]
            value = value[:, :, :h1, :w1]
        else:
            assert x_mask is None and source_mask is None

        with stage("rearrange*3"):
            if PFLASH_AVAILABLE:  # N H L D
                layout = 'n (nhead d) h w -> n nhead (h w) d'
            else:  # N L H D
                layout = 'n (nhead d) h w -> n (h w) nhead d'
            query = rearrange(query, layout, nhead=self.nhead, d=self.dim)
            key = rearrange(key, layout, nhead=self.nhead, d=self.dim)
            value = rearrange(value, layout, nhead=self.nhead, d=self.dim)

        with stage("attention"):
            message = self.attention(query, key, value, q_mask=None, kv_mask=None)

        if PFLASH_AVAILABLE:  # back to N L H D
            message = rearrange(message, 'n nhead L d -> n L nhead d', nhead=self.nhead, d=self.dim)

        if use_mask:
            message = message.reshape(bs, h0, w0, self.nhead, self.dim)
            message = self._zero_pad_hw(message, x_mask.size(-2), x_mask.size(-1))

        with stage("merge"):
            message = self.merge(message.reshape(bs, -1, self.nhead*self.dim))  # [N, L, C]

        with stage("rearrange*1"):
            message = rearrange(message, 'b (h w) c -> b c h w',
                                h=H1//self.pool_size, w=W1//self.pool_size)  # [N, C, H, W]

        with stage("upsample"):
            message = self._upsample(message)

        # feed-forward network
        with stage("feed-forward_mlp+permute*2+norm2"):
            message = self._feed_forward(x, message)

        return x + message

    def _build_projection(self,
                          dim_in,
                          dim_out,
                          kernel_size=3,
                          padding=1,
                          stride=1,
                          method='dw_bn',
                          ):
        """Build the q/k/v projection module.

        Args:
            dim_in: channel count of the projected map.
            dim_out: accepted for signature compatibility; not used.
            kernel_size, padding, stride: forwarded to the conv / pooling layer.
            method: ``'dw'`` (depth-wise conv), ``'dw_bn'`` (depth-wise conv +
                BatchNorm2d), ``'avg'`` (average pooling) or ``'linear'``.

        Returns:
            An ``nn.Sequential`` with named children (``conv``/``bn``/``avg``),
            or ``None`` for ``'linear'``.

        Raises:
            ValueError: if ``method`` is not one of the names above.
        """
        if method in ('dw_bn', 'dw'):
            layers = [
                ('conv', nn.Conv2d(
                    dim_in,
                    dim_in,
                    kernel_size=kernel_size,
                    padding=padding,
                    stride=stride,
                    bias=False,
                    groups=dim_in
                )),
            ]
            if method == 'dw_bn':
                layers.append(('bn', nn.BatchNorm2d(dim_in)))
            return nn.Sequential(OrderedDict(layers))

        if method == 'avg':
            return nn.Sequential(OrderedDict([
                ('avg', nn.AvgPool2d(
                    kernel_size=kernel_size,
                    padding=padding,
                    stride=stride,
                    ceil_mode=True
                )),
            ]))

        if method == 'linear':
            return None

        raise ValueError('Unknown method ({})'.format(method))
|
|
416
|
+
|
|
417
|
+
class AG_RoPE_EncoderLayer(nn.Module):
    """Attention layer on pooled feature maps with optional rotary position encoding.

    ``x`` is down-sampled by ``pool_size`` and ``source`` by ``pool_size2``
    (a factor of 1 disables pooling), projected to query / key / value,
    optionally position-encoded, passed through multi-head attention, merged,
    up-sampled back to the resolution of ``x`` and combined with ``x``.

    Args:
        d_model: channel count C of the feature maps.
        nhead: number of heads; per-head width is ``d_model // nhead``.
        attention: ``'linear'`` selects ``LinearAttention``, anything else
            ``FullAttention`` (ignored when ``xformer`` is true).
        pool_size: down-sampling factor for ``x`` / ``x_mask``.
        pool_size2: down-sampling factor for ``source`` / ``source_mask``.
        xformer: use ``XAttention``.
        leaky: if > 0, negative slope of a LeakyReLU in the MLP; otherwise ReLU.
        dw_conv: pool ``x`` with a strided depth-wise conv instead of max pooling.
        dw_conv2: same for ``source``.
        scatter: up-sample by repeating values and weighting them with the
            ``aggregate`` kernel instead of bilinear interpolation (reads
            ``self.aggregate``, which only exists when ``dw_conv`` is set and
            ``pool_size != 1``).
        norm_before: apply ``norm1`` after pooling (before attention); when
            false and ``vit_norm`` is false, ``norm1`` is applied to the
            up-sampled message instead.
        rope: apply ``RoPEPositionEncodingSine`` to query and key.
        npe: forwarded to ``RoPEPositionEncodingSine``.
        vit_norm: pre-norm variant: ``norm1`` before pooling, and the MLP runs
            on ``norm2(x + message)`` with a second residual connection.
        dw_proj: use depth-wise 3x3 convs instead of linear q/k/v projections.
    """

    def __init__(self,
                 d_model,
                 nhead,
                 attention='linear',
                 pool_size=4,
                 pool_size2=4,
                 xformer=False,
                 leaky=-1.0,
                 dw_conv=False,
                 dw_conv2=False,
                 scatter=False,
                 norm_before=True,
                 rope=False,
                 npe=None,
                 vit_norm=False,
                 dw_proj=False,
                 ):
        super(AG_RoPE_EncoderLayer, self).__init__()

        self.pool_size = pool_size
        self.pool_size2 = pool_size2
        self.dw_conv = dw_conv
        self.dw_conv2 = dw_conv2
        self.scatter = scatter
        self.norm_before = norm_before
        self.vit_norm = vit_norm
        self.dw_proj = dw_proj
        self.rope = rope
        if self.dw_conv and self.pool_size != 1:
            self.aggregate = nn.Conv2d(d_model, d_model, kernel_size=pool_size, padding=0,
                                       stride=pool_size, bias=False, groups=d_model)
        if self.dw_conv2 and self.pool_size2 != 1:
            self.aggregate2 = nn.Conv2d(d_model, d_model, kernel_size=pool_size2, padding=0,
                                        stride=pool_size2, bias=False, groups=d_model)

        self.dim = d_model // nhead
        self.nhead = nhead

        # Kept as an attribute for compatibility; forward() pools each input
        # with its own factor through F.max_pool2d.
        self.max_pool = torch.nn.MaxPool2d(kernel_size=self.pool_size2, stride=self.pool_size2)

        # multi-head attention
        if self.dw_proj:
            self.q_proj = nn.Conv2d(d_model, d_model, kernel_size=3, padding=1, stride=1, bias=False, groups=d_model)
            self.k_proj = nn.Conv2d(d_model, d_model, kernel_size=3, padding=1, stride=1, bias=False, groups=d_model)
            self.v_proj = nn.Conv2d(d_model, d_model, kernel_size=3, padding=1, stride=1, bias=False, groups=d_model)
        else:
            self.q_proj = nn.Linear(d_model, d_model, bias=False)
            self.k_proj = nn.Linear(d_model, d_model, bias=False)
            self.v_proj = nn.Linear(d_model, d_model, bias=False)

        if self.rope:
            self.rope_pos_enc = RoPEPositionEncodingSine(d_model, max_shape=(256, 256), npe=npe, ropefp16=True)

        if xformer:
            self.attention = XAttention()
        else:
            self.attention = LinearAttention() if attention == 'linear' else FullAttention()
        self.merge = nn.Linear(d_model, d_model, bias=False)

        # feed-forward network: with vit_norm the MLP sees x + message (C
        # channels), otherwise the concatenation [x, message] (2C channels).
        mlp_in = d_model if self.vit_norm else d_model*2
        activation = nn.LeakyReLU(leaky, True) if leaky > 0 else nn.ReLU(True)
        self.mlp = nn.Sequential(
            nn.Linear(mlp_in, d_model*2, bias=False),
            activation,
            nn.Linear(d_model*2, d_model, bias=False),
        )

        # norm and dropout
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _valid_extent(mask):
        """Return (height, width) read from the first column / first row of a [H, W] mask."""
        return int(mask.sum(-2)[0]), int(mask.sum(-1)[0])

    @staticmethod
    def _zero_pad_hw(message, full_h, full_w):
        """Zero-pad a [N, h, w, nhead, d] tensor at the bottom/right to (full_h, full_w)."""
        pad_h = full_h - message.size(1)
        pad_w = full_w - message.size(2)
        if pad_h == 0 and pad_w == 0:
            return message
        # F.pad lists (before, after) pairs starting from the last dimension.
        return F.pad(message, (0, 0, 0, 0, 0, pad_w, 0, pad_h))

    def _downsample(self, feat, pool_size, dw_module):
        """Pool and normalise a [N, C, H, W] map; returns channels-last [N, h, w, C].

        ``pool_size == 1`` skips pooling; otherwise ``dw_module`` is used when
        given and max pooling with ``pool_size`` when it is ``None``. Where
        ``norm1`` is applied depends on ``vit_norm`` / ``norm_before``.
        """
        if pool_size == 1:
            pool = None
        elif dw_module is not None:
            pool = dw_module
        else:
            def pool(t):
                return F.max_pool2d(t, kernel_size=pool_size, stride=pool_size)

        if self.vit_norm:
            # normalise at full resolution, then pool
            tokens = self.norm1(feat.permute(0, 2, 3, 1))
            if pool is None:
                return tokens
            return pool(tokens.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)

        tokens = (feat if pool is None else pool(feat)).permute(0, 2, 3, 1)
        return self.norm1(tokens) if self.norm_before else tokens

    def _attend(self, query, key, value):
        """Run multi-head attention on [N, h, w, C] maps; returns [N, L, nhead, d]."""
        if PFLASH_AVAILABLE:  # [N, H, W, C] -> [N, h, L, D]
            layout = 'n h w (nhead d) -> n nhead (h w) d'
        else:  # N L H D
            layout = 'n h w (nhead d) -> n (h w) nhead d'
        query, key, value = (
            rearrange(t, layout, nhead=self.nhead, d=self.dim) for t in (query, key, value)
        )
        message = self.attention(query, key, value, q_mask=None, kv_mask=None)
        if PFLASH_AVAILABLE:  # [N, h, L, D] -> [N, L, h, D]
            message = rearrange(message, 'n nhead L d -> n L nhead d', nhead=self.nhead, d=self.dim)
        return message

    # ------------------------------------------------------------------ forward

    def forward(self, x, source, x_mask=None, source_mask=None):
        """
        Args:
            x (torch.Tensor): [N, C, H1, W1]
            source (torch.Tensor): [N, C, H2, W2]
            x_mask (torch.Tensor): [N, H1, W1] (optional) (L = H1*W1)
            source_mask (torch.Tensor): [N, H2, W2] (optional) (S = H2*W2)

        Returns:
            torch.Tensor: updated features, [N, C, H1, W1].

        Masks must be given for both inputs or for neither. With masks, each
        sample is cropped to its valid top-left rectangle before attention and
        the result is zero-padded back to the pooled grid of ``x``.
        """
        bs, C, H1, W1 = x.size()

        query = self._downsample(
            x, self.pool_size,
            self.aggregate if (self.dw_conv and self.pool_size != 1) else None)
        source = self._downsample(
            source, self.pool_size2,
            self.aggregate2 if (self.dw_conv2 and self.pool_size2 != 1) else None)

        # projection
        if self.dw_proj:
            source_chw = source.permute(0, 3, 1, 2)
            query = self.q_proj(query.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
            key = self.k_proj(source_chw).permute(0, 2, 3, 1)
            value = self.v_proj(source_chw).permute(0, 2, 3, 1)
        else:
            query, key, value = self.q_proj(query), self.k_proj(source), self.v_proj(source)

        # RoPE
        if self.rope:
            query = self.rope_pos_enc(query)
            if self.pool_size == 1 and self.pool_size2 == 4:
                # second argument as in the original call; its meaning is
                # defined by RoPEPositionEncodingSine
                key = self.rope_pos_enc(key, 4)
            else:
                key = self.rope_pos_enc(key)

        if x_mask is not None and source_mask is not None:
            # downsample each mask with the factor of the map it belongs to
            if self.pool_size != 1:
                x_mask = F.max_pool2d(x_mask.float(), kernel_size=self.pool_size,
                                      stride=self.pool_size).bool()  # [N, H1//pool, W1//pool]
            if self.pool_size2 != 1:
                source_mask = F.max_pool2d(source_mask.float(), kernel_size=self.pool_size2,
                                           stride=self.pool_size2).bool()
            full_h, full_w = x_mask.size(-2), x_mask.size(-1)

            per_sample = []
            for i in range(bs):
                h0, w0 = self._valid_extent(x_mask[i])
                h1, w1 = self._valid_extent(source_mask[i])
                m = self._attend(
                    query[i:i+1, :h0, :w0, :],
                    key[i:i+1, :h1, :w1, :],
                    value[i:i+1, :h1, :w1, :],
                )  # [1, L, h, D]
                m = m.reshape(1, h0, w0, self.nhead, self.dim)
                per_sample.append(self._zero_pad_hw(m, full_h, full_w))  # padding zero
            message = torch.cat(per_sample, dim=0)
        else:
            assert x_mask is None and source_mask is None
            message = self._attend(query, key, value)  # [N, L, h, D]

        message = self.merge(message.reshape(bs, -1, self.nhead*self.dim))  # [N, L, C]
        message = rearrange(message, 'b (h w) c -> b c h w',
                            h=H1//self.pool_size, w=W1//self.pool_size)  # [N, C, H, W]

        if self.pool_size != 1:
            p = self.pool_size
            if self.scatter:
                message = torch.repeat_interleave(message, p, dim=-2)
                message = torch.repeat_interleave(message, p, dim=-1)
                weight = self.aggregate.weight.data.reshape(1, C, p, p)
                message = message * weight.repeat(1, 1, message.shape[-2]//p, message.shape[-1]//p)
            else:
                message = F.interpolate(message, scale_factor=p, mode='bilinear',
                                        align_corners=False)  # [N, C, H1, W1]

        if not self.norm_before and not self.vit_norm:
            message = self.norm1(message.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)  # [N, C, H, W]

        # feed-forward network
        if self.vit_norm:
            message_inter = x + message
            message = self.norm2(message_inter.permute(0, 2, 3, 1))
            message = self.mlp(message).permute(0, 3, 1, 2)  # [N, C, H1, W1]
            return message_inter + message

        message = self.mlp(torch.cat([x, message], dim=1).permute(0, 2, 3, 1))  # [N, H1, W1, C]
        message = self.norm2(message).permute(0, 3, 1, 2)  # [N, C, H1, W1]
        return x + message
|
|
725
|
+
|
|
726
|
+
class AG_Conv_EncoderLayer(nn.Module):
    """Encoder layer that attends on spatially aggregated (pooled) NCHW feature maps.

    The query map ``x`` and the ``source`` map are first reduced by a factor of
    ``pool_size`` (max-pooling, or a strided depth-wise conv when ``dw_conv`` /
    ``dw_conv2`` is set), projected with depth-wise convs, passed through the
    configured attention module, and the resulting message is brought back to
    the input resolution before being fused with ``x`` by an MLP.
    """

    def __init__(self,
                 d_model,
                 nhead,
                 attention='linear',
                 pool_size=4,
                 bn=True,
                 xformer=False,
                 leaky=-1.0,
                 dw_conv=False,
                 dw_conv2=False,
                 scatter=False,
                 norm_before=True,
                 ):
        """
        Args:
            d_model (int): channel count of the feature maps.
            nhead (int): number of attention heads; per-head dim is ``d_model // nhead``.
            attention (str): ``'linear'`` selects LinearAttention, anything else FullAttention
                (ignored when ``xformer`` is true).
            pool_size (int): spatial aggregation factor (kernel size and stride).
            bn (bool): use conv + BatchNorm (``'dw_bn'``) projections instead of conv only (``'dw'``).
            xformer (bool): use XAttention as the attention module.
            leaky (float): if > 0, the MLP uses LeakyReLU with this slope, otherwise ReLU.
            dw_conv (bool): aggregate ``x`` with a strided depth-wise conv instead of max-pooling.
            dw_conv2 (bool): same as ``dw_conv`` but for ``source``.
            scatter (bool): upsample the message by repeating it and weighting with the
                ``aggregate`` kernel instead of bilinear interpolation.
            norm_before (bool): apply ``norm1`` to the aggregated maps (before attention)
                rather than to the upsampled message (after attention).
        """
        super(AG_Conv_EncoderLayer, self).__init__()

        self.pool_size = pool_size
        self.dw_conv = dw_conv
        self.dw_conv2 = dw_conv2
        self.scatter = scatter
        self.norm_before = norm_before
        if self.dw_conv:
            self.aggregate = nn.Conv2d(d_model, d_model, kernel_size=pool_size, padding=0, stride=pool_size, bias=False, groups=d_model)
        if self.dw_conv2:
            self.aggregate2 = nn.Conv2d(d_model, d_model, kernel_size=pool_size, padding=0, stride=pool_size, bias=False, groups=d_model)
        self.dim = d_model // nhead
        self.nhead = nhead

        self.max_pool = torch.nn.MaxPool2d(kernel_size=self.pool_size, stride=self.pool_size)

        # multi-head attention
        method = 'dw_bn' if bn else 'dw'
        self.q_proj_conv = self._build_projection(d_model, d_model, method=method)
        self.k_proj_conv = self._build_projection(d_model, d_model, method=method)
        self.v_proj_conv = self._build_projection(d_model, d_model, method=method)

        if xformer:
            self.attention = XAttention()
        else:
            self.attention = LinearAttention() if attention == 'linear' else FullAttention()
        self.merge = nn.Linear(d_model, d_model, bias=False)

        # feed-forward network
        activation = nn.LeakyReLU(leaky, True) if leaky > 0 else nn.ReLU(True)
        self.mlp = nn.Sequential(
            nn.Linear(d_model*2, d_model*2, bias=False),
            activation,
            nn.Linear(d_model*2, d_model, bias=False),
        )

        # norm and dropout
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def _attend(self, query, key, value):
        """Run multi-head attention on NCHW maps and return the message as [N, L, nhead, dim]."""
        # The flash path expects [N, H, L, D]; the default path expects [N, L, H, D].
        if PFLASH_AVAILABLE:
            layout = 'n (nhead d) h w -> n nhead (h w) d'
        else:
            layout = 'n (nhead d) h w -> n (h w) nhead d'
        query = rearrange(query, layout, nhead=self.nhead, d=self.dim)
        key = rearrange(key, layout, nhead=self.nhead, d=self.dim)
        value = rearrange(value, layout, nhead=self.nhead, d=self.dim)

        message = self.attention(query, key, value, q_mask=None, kv_mask=None)

        if PFLASH_AVAILABLE:
            message = rearrange(message, 'n nhead L d -> n L nhead d', nhead=self.nhead, d=self.dim)
        return message

    def forward(self, x, source, x_mask=None, source_mask=None):
        """
        Args:
            x (torch.Tensor): [N, C, H1, W1]
            source (torch.Tensor): [N, C, H2, W2]
            x_mask (torch.Tensor): [N, H1, W1] (optional) (L = H1*W1)
            source_mask (torch.Tensor): [N, H2, W2] (optional) (S = H2*W2)

        Returns:
            torch.Tensor: ``x`` plus the fused message, [N, C, H1, W1].

        Masks must be given together or not at all. When given, the valid
        region of each sample is taken to be the top-left rectangle whose
        size is read from the first column / first row of the pooled mask.
        """
        bs = x.size(0)
        H1, W1 = x.size(-2), x.size(-1)
        C = x.shape[-3]

        # spatial aggregation of both inputs
        query = self.aggregate(x) if self.dw_conv else self.max_pool(x)
        source = self.aggregate2(source) if self.dw_conv2 else self.max_pool(source)
        if self.norm_before:
            # LayerNorm works on the last axis, so go channels-last and back.
            query = self.norm1(query.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
            source = self.norm1(source.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

        query = self.q_proj_conv(query)  # [N, C, H1//pool, W1//pool]
        key = self.k_proj_conv(source)
        value = self.v_proj_conv(source)

        use_mask = x_mask is not None and source_mask is not None
        if use_mask:
            x_mask = self.max_pool(x_mask.float()).bool()  # [N, H1//pool, W1//pool]
            source_mask = self.max_pool(source_mask.float()).bool()
            full_h, full_w = x_mask.size(-2), x_mask.size(-1)

            # Each sample may have a different valid extent, so attention is run
            # per sample on the cropped maps (a batch of one is the same loop).
            per_sample = []
            for i in range(bs):
                h0, w0 = int(x_mask[i].sum(-2)[0]), int(x_mask[i].sum(-1)[0])
                h1, w1 = int(source_mask[i].sum(-2)[0]), int(source_mask[i].sum(-1)[0])

                m = self._attend(query[i:i+1, :, :h0, :w0],
                                 key[i:i+1, :, :h1, :w1],
                                 value[i:i+1, :, :h1, :w1])
                m = m.reshape(1, h0, w0, self.nhead, self.dim)
                if h0 != full_h or w0 != full_w:
                    # Zero-fill back to the pooled size. Padding both axes at once
                    # also covers samples cropped in height AND width.
                    m = torch.nn.functional.pad(m, (0, 0, 0, 0, 0, full_w - w0, 0, full_h - h0))
                per_sample.append(m)
            message = torch.cat(per_sample, dim=0)  # [N, H, W, nhead, dim]
        else:
            assert x_mask is None and source_mask is None
            message = self._attend(query, key, value)  # [N, L, nhead, dim]

        # reshape (not view): the flash path can hand back a non-contiguous tensor
        message = self.merge(message.reshape(bs, -1, self.nhead*self.dim))  # [N, L, C]
        message = rearrange(message, 'b (h w) c -> b c h w', h=H1//self.pool_size, w=W1//self.pool_size)  # [N, C, H, W]

        if self.scatter:
            # NOTE: reads self.aggregate, which only exists when dw_conv=True.
            message = torch.repeat_interleave(message, self.pool_size, dim=-2)
            message = torch.repeat_interleave(message, self.pool_size, dim=-1)
            message = message * self.aggregate.weight.data.reshape(1, C, self.pool_size, self.pool_size).repeat(1, 1, message.shape[-2]//self.pool_size, message.shape[-1]//self.pool_size)
        else:
            message = torch.nn.functional.interpolate(message, scale_factor=self.pool_size, mode='bilinear', align_corners=False)  # [N, C, H1, W1]

        if not self.norm_before:
            message = self.norm1(message.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)  # [N, C, H, W]

        # feed-forward network
        message = self.mlp(torch.cat([x, message], dim=1).permute(0, 2, 3, 1))  # [N, H1, W1, C]
        message = self.norm2(message).permute(0, 3, 1, 2)  # [N, C, H1, W1]

        return x + message

    def _build_projection(self,
                          dim_in,
                          dim_out,
                          kernel_size=3,
                          padding=1,
                          stride=1,
                          method='dw_bn',
                          ):
        """Build the q/k/v projection module.

        ``'dw'`` is a depth-wise conv, ``'dw_bn'`` adds BatchNorm after it,
        ``'avg'`` is average pooling and ``'linear'`` yields ``None``.
        ``dim_out`` is accepted for interface compatibility but not used.

        Raises:
            ValueError: if ``method`` is not one of the names above.
        """
        if method == 'linear':
            return None
        if method == 'avg':
            return nn.Sequential(OrderedDict([
                ('avg', nn.AvgPool2d(
                    kernel_size=kernel_size,
                    padding=padding,
                    stride=stride,
                    ceil_mode=True
                )),
            ]))
        if method in ('dw_bn', 'dw'):
            stages = [
                ('conv', nn.Conv2d(
                    dim_in,
                    dim_in,
                    kernel_size=kernel_size,
                    padding=padding,
                    stride=stride,
                    bias=False,
                    groups=dim_in
                )),
            ]
            if method == 'dw_bn':
                stages.append(('bn', nn.BatchNorm2d(dim_in)))
            return nn.Sequential(OrderedDict(stages))
        raise ValueError('Unknown method ({})'.format(method))
|
|
993
|
+
|
|
994
|
+
|
|
995
|
+
class RoPELoFTREncoderLayer(nn.Module):
    """LoFTR-style encoder layer on flattened tokens with two operating modes.

    With ``token_mixer=None`` the layer does projected multi-head attention
    followed by an MLP on ``[x, message]``. With a token mixer set, the
    attention is replaced by a module applied to the tokens reshaped to an
    image (a depth-wise 3x3 conv for ``'dwcn'``) in a pre-norm residual block.
    """

    def __init__(self,
                 d_model,
                 nhead,
                 attention='linear',
                 rope=False,
                 token_mixer=None,
                 ):
        super(RoPELoFTREncoderLayer, self).__init__()

        attention_mode = token_mixer is None

        self.dim = d_model // nhead
        self.nhead = nhead

        # multi-head attention projections (attention mode only)
        if attention_mode:
            self.q_proj = nn.Linear(d_model, d_model, bias=False)
            self.k_proj = nn.Linear(d_model, d_model, bias=False)
            self.v_proj = nn.Linear(d_model, d_model, bias=False)

        self.rope = rope
        self.token_mixer = token_mixer

        # NOTE(review): branch nesting reconstructed from a whitespace-stripped
        # source; `elif self.rope` is taken to pair with the token_mixer check,
        # so no `attention` attribute is set for other combinations - confirm.
        if not attention_mode:
            if token_mixer == 'dwcn':
                depthwise = nn.Conv2d(
                    d_model,
                    d_model,
                    kernel_size=3,
                    padding=1,
                    stride=1,
                    bias=False,
                    groups=d_model
                )
                self.attention = nn.Sequential(OrderedDict([('conv', depthwise)]))
        elif self.rope:
            assert attention == 'linear'
            self.attention = RoPELinearAttention()

        if attention_mode:
            self.merge = nn.Linear(d_model, d_model, bias=False)

        # feed-forward network: attention mode consumes cat([x, message]),
        # hence the doubled width; mixer mode works on d_model directly.
        width = d_model*2 if attention_mode else d_model
        self.mlp = nn.Sequential(
            nn.Linear(width, width, bias=False),
            nn.ReLU(True),
            nn.Linear(width, d_model, bias=False),
        )

        # norm and dropout
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, source, x_mask=None, source_mask=None, H=None, W=None):
        """
        Args:
            x (torch.Tensor): [N, L, C]
            source (torch.Tensor): [N, L, C]
            x_mask (torch.Tensor): [N, L] (optional)
            source_mask (torch.Tensor): [N, S] (optional)
            H, W (int): spatial size of the token grid; ``H*W`` must equal L.

        Returns:
            torch.Tensor: updated tokens, [N, L, C].
        """
        bs = x.size(0)
        assert H*W == x.size(-2)

        if self.token_mixer is not None:
            # token-mixer block: pre-norm, mix on the H x W grid, residual, MLP
            grid = rearrange(self.norm1(x), 'n (h w) c -> n c h w', h=H, w=W)
            mixed = rearrange(self.attention(grid), 'n c h w -> n (h w) c')

            x = x + mixed
            x = x + self.mlp(self.norm2(x))
            return x

        # multi-head attention
        heads = (bs, -1, self.nhead, self.dim)
        query = self.q_proj(x).view(*heads)  # [N, L, (H, D)]
        key = self.k_proj(source).view(*heads)  # [N, S, (H, D)]
        value = self.v_proj(source).view(*heads)
        message = self.attention(query, key, value, q_mask=x_mask, kv_mask=source_mask, H=H, W=W)  # [N, L, (H, D)]
        message = self.norm1(self.merge(message.view(bs, -1, self.nhead*self.dim)))  # [N, L, C]

        # feed-forward network
        message = self.norm2(self.mlp(torch.cat([x, message], dim=2)))

        return x + message
|
|
1093
|
+
|
|
1094
|
+
class LoFTREncoderLayer(nn.Module):
    """Standard LoFTR encoder layer on flattened tokens.

    Projects ``x`` to queries and ``source`` to keys/values, runs the
    configured attention module, merges the heads, and fuses the normalised
    message with ``x`` through an MLP, returning ``x + message``.
    """

    def __init__(self,
                 d_model,
                 nhead,
                 attention='linear',
                 xformer=False,
                 ):
        """
        Args:
            d_model (int): token feature size.
            nhead (int): number of heads; per-head dim is ``d_model // nhead``.
            attention (str): ``'linear'`` selects LinearAttention, anything else
                FullAttention (ignored when ``xformer`` is true).
            xformer (bool): use XAttention as the attention module.
        """
        super(LoFTREncoderLayer, self).__init__()

        self.dim = d_model // nhead
        self.nhead = nhead

        # multi-head attention
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)

        if xformer:
            self.attention = XAttention()
        else:
            self.attention = LinearAttention() if attention == 'linear' else FullAttention()
        self.merge = nn.Linear(d_model, d_model, bias=False)

        # feed-forward network
        self.mlp = nn.Sequential(
            nn.Linear(d_model*2, d_model*2, bias=False),
            nn.ReLU(True),
            nn.Linear(d_model*2, d_model, bias=False),
        )

        # norm and dropout
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, source, x_mask=None, source_mask=None):
        """
        Args:
            x (torch.Tensor): [N, L, C]
            source (torch.Tensor): [N, S, C]
            x_mask (torch.Tensor): [N, L] (optional)
            source_mask (torch.Tensor): [N, S] (optional)

        Returns:
            torch.Tensor: updated tokens, [N, L, C].
        """
        bs = x.size(0)
        query, key, value = x, source, source

        # multi-head attention
        query = self.q_proj(query).view(bs, -1, self.nhead, self.dim)  # [N, L, (H, D)]
        key = self.k_proj(key).view(bs, -1, self.nhead, self.dim)  # [N, S, (H, D)]
        value = self.v_proj(value).view(bs, -1, self.nhead, self.dim)
        message = self.attention(query, key, value, q_mask=x_mask, kv_mask=source_mask)  # [N, L, (H, D)]
        message = self.merge(message.view(bs, -1, self.nhead*self.dim))  # [N, L, C]
        message = self.norm1(message)

        # feed-forward network
        message = self.mlp(torch.cat([x, message], dim=2))
        message = self.norm2(message)

        return x + message

    def pro(self, x, source, x_mask=None, source_mask=None, profiler=None):
        """Same computation as ``forward`` with each stage wrapped in a profiler section.

        Args:
            x (torch.Tensor): [N, L, C]
            source (torch.Tensor): [N, S, C]
            x_mask (torch.Tensor): [N, L] (optional)
            source_mask (torch.Tensor): [N, S] (optional)
            profiler: object whose ``profile(name)`` returns a context manager.
                When ``None`` (the default) the stages run without profiling.

        Returns:
            torch.Tensor: updated tokens, [N, L, C].
        """
        from contextlib import nullcontext

        def section(name):
            # The signature defaults profiler to None, so that case must not crash.
            return nullcontext() if profiler is None else profiler.profile(name)

        bs = x.size(0)
        query, key, value = x, source, source

        # multi-head attention
        with section("proj*3"):
            query = self.q_proj(query).view(bs, -1, self.nhead, self.dim)  # [N, L, (H, D)]
            key = self.k_proj(key).view(bs, -1, self.nhead, self.dim)  # [N, S, (H, D)]
            value = self.v_proj(value).view(bs, -1, self.nhead, self.dim)
        with section("attention"):
            message = self.attention(query, key, value, q_mask=x_mask, kv_mask=source_mask)  # [N, L, (H, D)]
        with section("merge"):
            message = self.merge(message.view(bs, -1, self.nhead*self.dim))  # [N, L, C]
        with section("norm1"):
            message = self.norm1(message)

        # feed-forward network
        with section("mlp"):
            message = self.mlp(torch.cat([x, message], dim=2))
        with section("norm2"):
            message = self.norm2(message)

        return x + message
|
|
1183
|
+
|
|
1184
|
+
class PANEncoderLayer_cross(nn.Module):
    """Bidirectional cross-attention layer on pooled NCHW feature maps.

    Both inputs are max-pooled, normalised and projected with shared
    depth-wise convs. A single similarity tensor between the two maps is used
    to produce a message for each side, and each message is upsampled and
    fused with its own input through the shared merge / MLP / norm modules.
    """

    def __init__(self,
                 d_model,
                 nhead,
                 attention='linear',
                 pool_size=4,
                 bn=True,
                 ):
        """
        Args:
            d_model (int): channel count of the feature maps.
            nhead (int): number of heads; per-head dim is ``d_model // nhead``.
            attention (str): accepted for interface parity; not read by this layer.
            pool_size (int): max-pool kernel size and stride.
            bn (bool): use conv + BatchNorm (``'dw_bn'``) projections instead of conv only (``'dw'``).
        """
        super(PANEncoderLayer_cross, self).__init__()

        self.pool_size = pool_size

        self.dim = d_model // nhead
        self.nhead = nhead

        self.max_pool = torch.nn.MaxPool2d(kernel_size=self.pool_size, stride=self.pool_size)
        # multi-head attention
        method = 'dw_bn' if bn else 'dw'
        self.qk_proj_conv = self._build_projection(d_model, d_model, method=method)
        self.v_proj_conv = self._build_projection(d_model, d_model, method=method)

        # Registered as a submodule; forward() computes the attention inline
        # and does not call it.
        self.attention = FullAttention()
        self.merge = nn.Linear(d_model, d_model, bias=False)

        # feed-forward network
        self.mlp = nn.Sequential(
            nn.Linear(d_model*2, d_model*2, bias=False),
            nn.ReLU(True),
            nn.Linear(d_model*2, d_model, bias=False),
        )

        # norm and dropout
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def _fuse_message(self, x, message, valid_hw, pooled_hw):
        """Zero-pad a message to the pooled size, upsample it and fuse it with ``x``.

        Args:
            x (torch.Tensor): [N, C, H, W] input the message belongs to.
            message (torch.Tensor): [N, L, nhead, dim] attention output.
            valid_hw (tuple | None): (h, w) of the cropped valid region, or
                ``None`` when no masks were used.
            pooled_hw (tuple): (h, w) of the full pooled map.

        Returns:
            torch.Tensor: ``x + message`` at input resolution, [N, C, H, W].
        """
        bs = x.size(0)
        full_h, full_w = pooled_hw
        if valid_hw is not None:
            valid_h, valid_w = valid_hw
            message = message.reshape(bs, valid_h, valid_w, self.nhead, self.dim)
            if valid_h != full_h or valid_w != full_w:
                # Pad height and width together so a region cropped on both
                # axes is restored as well.
                message = torch.nn.functional.pad(
                    message, (0, 0, 0, 0, 0, full_w - valid_w, 0, full_h - valid_h))

        message = self.merge(message.reshape(bs, -1, self.nhead*self.dim))  # [N, L, C]
        message = rearrange(message, 'b (h w) c -> b c h w', h=full_h, w=full_w)  # [N, C, H, W]
        message = torch.nn.functional.interpolate(message, scale_factor=self.pool_size, mode='bilinear', align_corners=False)
        # feed-forward network
        message = self.mlp(torch.cat([x, message], dim=1).permute(0, 2, 3, 1))  # [N, H, W, C]
        message = self.norm2(message).permute(0, 3, 1, 2)  # [N, C, H, W]
        return x + message

    def forward(self, x1, x2, x1_mask=None, x2_mask=None):
        """
        Args:
            x1 (torch.Tensor): [N, C, H1, W1]
            x2 (torch.Tensor): [N, C, H2, W2]
            x1_mask (torch.Tensor): [N, H1, W1] (optional) (L = H1*W1)
            x2_mask (torch.Tensor): [N, H2, W2] (optional) (S = H2*W2)

        Returns:
            tuple[torch.Tensor, torch.Tensor]: updated ``x1`` and ``x2``.

        Masks must be given together or not at all. The valid extent is read
        from the first sample's mask and applied to the whole batch.
        """
        H1, W1 = x1.size(-2) // self.pool_size, x1.size(-1) // self.pool_size
        H2, W2 = x2.size(-2) // self.pool_size, x2.size(-1) // self.pool_size

        # Pool + channels-last LayerNorm once per input; the result feeds both
        # the shared q/k projection and the value projection.
        pooled1 = self.norm1(self.max_pool(x1).permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        pooled2 = self.norm1(self.max_pool(x2).permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

        # multi-head attention (call order kept: it fixes the order of
        # BatchNorm running-stat updates in training mode)
        query = self.qk_proj_conv(pooled1)  # [N, C, H1//pool, W1//pool]
        key = self.qk_proj_conv(pooled2)
        v2 = self.v_proj_conv(pooled2)
        v1 = self.v_proj_conv(pooled1)

        use_mask = x1_mask is not None and x2_mask is not None
        valid1 = valid2 = None
        if use_mask:
            x1_mask = self.max_pool(x1_mask.float()).bool()  # [N, H1//pool, W1//pool]
            x2_mask = self.max_pool(x2_mask.float()).bool()

            mask_h1, mask_w1 = int(x1_mask[0].sum(-2)[0]), int(x1_mask[0].sum(-1)[0])
            mask_h2, mask_w2 = int(x2_mask[0].sum(-2)[0]), int(x2_mask[0].sum(-1)[0])
            valid1, valid2 = (mask_h1, mask_w1), (mask_h2, mask_w2)

            query = query[:, :, :mask_h1, :mask_w1]
            key = key[:, :, :mask_h2, :mask_w2]
            v1 = v1[:, :, :mask_h1, :mask_w1]
            v2 = v2[:, :, :mask_h2, :mask_w2]
            x1_mask = x1_mask[:, :mask_h1, :mask_w1].flatten(-2)
            x2_mask = x2_mask[:, :mask_h2, :mask_w2].flatten(-2)
        else:
            assert x1_mask is None and x2_mask is None

        query = rearrange(query, 'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim)  # [N, L, H, D]
        key = rearrange(key, 'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim)  # [N, S, H, D]
        v2 = rearrange(v2, 'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim)  # [N, S, H, D]
        v1 = rearrange(v1, 'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim)  # [N, L, H, D]

        QK = torch.einsum("nlhd,nshd->nlsh", query, key)
        # NOTE(review): extent of the autocast-disabled block reconstructed from a
        # whitespace-stripped source; only the fp32 mask fill is kept inside - confirm.
        with torch.autocast(enabled=False, device_type='cuda'):
            if use_mask:
                QK = QK.float().masked_fill_(~(x1_mask[:, :, None, None] * x2_mask[:, None, :, None]), -1e9)  # float('-inf')

        # Compute the attention and the weighted average
        softmax_temp = 1. / query.size(3)**.5  # sqrt(D)
        S1 = torch.softmax(softmax_temp * QK, dim=2)
        # NOTE(review): dim=3 is the head axis of "nlsh", not the l axis; kept
        # as-is because changing it would alter the layer's output.
        S2 = torch.softmax(softmax_temp * QK, dim=3)

        m1 = torch.einsum("nlsh,nshd->nlhd", S1, v2)
        m2 = torch.einsum("nlsh,nlhd->nshd", S2, v1)

        out1 = self._fuse_message(x1, m1, valid1, (H1, W1))
        out2 = self._fuse_message(x2, m2, valid2, (H2, W2))

        return out1, out2

    def _build_projection(self,
                          dim_in,
                          dim_out,
                          kernel_size=3,
                          padding=1,
                          stride=1,
                          method='dw_bn',
                          ):
        """Build a projection module.

        ``'dw'`` is a depth-wise conv, ``'dw_bn'`` adds BatchNorm after it,
        ``'avg'`` is average pooling and ``'linear'`` yields ``None``.
        ``dim_out`` is accepted for interface compatibility but not used.

        Raises:
            ValueError: if ``method`` is not one of the names above.
        """
        if method == 'linear':
            return None
        if method == 'avg':
            return nn.Sequential(OrderedDict([
                ('avg', nn.AvgPool2d(
                    kernel_size=kernel_size,
                    padding=padding,
                    stride=stride,
                    ceil_mode=True
                )),
            ]))
        if method in ('dw_bn', 'dw'):
            stages = [
                ('conv', nn.Conv2d(
                    dim_in,
                    dim_in,
                    kernel_size=kernel_size,
                    padding=padding,
                    stride=stride,
                    bias=False,
                    groups=dim_in
                )),
            ]
            if method == 'dw_bn':
                stages.append(('bn', nn.BatchNorm2d(dim_in)))
            return nn.Sequential(OrderedDict(stages))
        raise ValueError('Unknown method ({})'.format(method))
|
|
1377
|
+
|
|
1378
|
+
class LocalFeatureTransformer(nn.Module):
|
|
1379
|
+
"""A Local Feature Transformer (LoFTR) module."""
|
|
1380
|
+
|
|
1381
|
+
def __init__(self, config):
|
|
1382
|
+
super(LocalFeatureTransformer, self).__init__()
|
|
1383
|
+
|
|
1384
|
+
self.full_config = config
|
|
1385
|
+
self.fine = False
|
|
1386
|
+
if 'coarse' not in config:
|
|
1387
|
+
self.fine = True # fine attention
|
|
1388
|
+
else:
|
|
1389
|
+
config = config['coarse']
|
|
1390
|
+
self.d_model = config['d_model']
|
|
1391
|
+
self.nhead = config['nhead']
|
|
1392
|
+
self.layer_names = config['layer_names']
|
|
1393
|
+
self.pan = config['pan']
|
|
1394
|
+
self.bidirect = config['bidirection']
|
|
1395
|
+
# prune
|
|
1396
|
+
self.pool_size = config['pool_size']
|
|
1397
|
+
self.matchability = False
|
|
1398
|
+
self.depth_confidence = -1.0
|
|
1399
|
+
self.width_confidence = -1.0
|
|
1400
|
+
# self.depth_confidence = config['depth_confidence']
|
|
1401
|
+
# self.width_confidence = config['width_confidence']
|
|
1402
|
+
# self.matchability = self.depth_confidence > 0 or self.width_confidence > 0
|
|
1403
|
+
# self.thr = self.full_config['match_coarse']['thr']
|
|
1404
|
+
if not self.fine:
|
|
1405
|
+
# asy
|
|
1406
|
+
self.asymmetric = config['asymmetric']
|
|
1407
|
+
self.asymmetric_self = config['asymmetric_self']
|
|
1408
|
+
# aggregate
|
|
1409
|
+
self.aggregate = config['dwconv']
|
|
1410
|
+
# RoPE
|
|
1411
|
+
self.rope = config['rope']
|
|
1412
|
+
# absPE
|
|
1413
|
+
self.abspe = config['abspe']
|
|
1414
|
+
|
|
1415
|
+
else:
|
|
1416
|
+
self.rope, self.asymmetric, self.asymmetric_self, self.aggregate = False, False, False, False
|
|
1417
|
+
if self.matchability:
|
|
1418
|
+
self.n_layers = len(self.layer_names) // 2
|
|
1419
|
+
assert self.n_layers == 4
|
|
1420
|
+
self.log_assignment = nn.ModuleList(
|
|
1421
|
+
[MatchAssignment(self.d_model) for _ in range(self.n_layers)])
|
|
1422
|
+
self.token_confidence = nn.ModuleList([
|
|
1423
|
+
TokenConfidence(self.d_model) for _ in range(self.n_layers-1)])
|
|
1424
|
+
|
|
1425
|
+
self.CoarseMatching = CoarseMatching(self.full_config['match_coarse'])
|
|
1426
|
+
|
|
1427
|
+
# self only
|
|
1428
|
+
# if self.rope:
|
|
1429
|
+
# self_layer = RoPELoFTREncoderLayer(config['d_model'], config['nhead'], config['attention'], config['rope'], config['token_mixer'])
|
|
1430
|
+
# self.layers = nn.ModuleList([copy.deepcopy(self_layer) for _ in range(len(self.layer_names))])
|
|
1431
|
+
|
|
1432
|
+
if self.bidirect:
|
|
1433
|
+
assert config['xformer'] is False and config['pan'] is True
|
|
1434
|
+
self_layer = PANEncoderLayer(config['d_model'], config['nhead'], config['attention'], config['pool_size'], config['bn'], config['xformer'])
|
|
1435
|
+
cross_layer = PANEncoderLayer_cross(config['d_model'], config['nhead'], config['attention'], config['pool_size'], config['bn'])
|
|
1436
|
+
self.layers = nn.ModuleList([copy.deepcopy(self_layer) if _ == 'self' else copy.deepcopy(cross_layer) for _ in self.layer_names])
|
|
1437
|
+
else:
|
|
1438
|
+
if self.aggregate:
|
|
1439
|
+
if self.rope:
|
|
1440
|
+
# assert config['npe'][0] == 832 and config['npe'][1] == 832 and config['npe'][2] == 832 and config['npe'][3] == 832
|
|
1441
|
+
logger.info(f'npe trainH,trainW,testH,testW: {config["npe"][0]}, {config["npe"][1]}, {config["npe"][2]}, {config["npe"][3]}')
|
|
1442
|
+
self_layer = AG_RoPE_EncoderLayer(config['d_model'], config['nhead'], config['attention'], config['pool_size'], config['pool_size2'],
|
|
1443
|
+
config['xformer'], config['leaky'], config['dwconv'], config['dwconv2'], config['scatter'],
|
|
1444
|
+
config['norm_before'], config['rope'], config['npe'], config['vit_norm'], config['rope_dwproj'])
|
|
1445
|
+
cross_layer = AG_RoPE_EncoderLayer(config['d_model'], config['nhead'], config['attention'], config['pool_size'], config['pool_size2'],
|
|
1446
|
+
config['xformer'], config['leaky'], config['dwconv'], config['dwconv2'], config['scatter'],
|
|
1447
|
+
config['norm_before'], False, config['npe'], config['vit_norm'], config['rope_dwproj'])
|
|
1448
|
+
self.layers = nn.ModuleList([copy.deepcopy(self_layer) if _ == 'self' else copy.deepcopy(cross_layer) for _ in self.layer_names])
|
|
1449
|
+
elif self.abspe:
|
|
1450
|
+
logger.info(f'npe trainH,trainW,testH,testW: {config["npe"][0]}, {config["npe"][1]}, {config["npe"][2]}, {config["npe"][3]}')
|
|
1451
|
+
self_layer = AG_RoPE_EncoderLayer(config['d_model'], config['nhead'], config['attention'], config['pool_size'], config['pool_size2'],
|
|
1452
|
+
config['xformer'], config['leaky'], config['dwconv'], config['dwconv2'], config['scatter'],
|
|
1453
|
+
config['norm_before'], False, config['npe'], config['vit_norm'], config['rope_dwproj'])
|
|
1454
|
+
cross_layer = AG_RoPE_EncoderLayer(config['d_model'], config['nhead'], config['attention'], config['pool_size'], config['pool_size2'],
|
|
1455
|
+
config['xformer'], config['leaky'], config['dwconv'], config['dwconv2'], config['scatter'],
|
|
1456
|
+
config['norm_before'], False, config['npe'], config['vit_norm'], config['rope_dwproj'])
|
|
1457
|
+
self.layers = nn.ModuleList([copy.deepcopy(self_layer) if _ == 'self' else copy.deepcopy(cross_layer) for _ in self.layer_names])
|
|
1458
|
+
|
|
1459
|
+
else:
|
|
1460
|
+
encoder_layer = AG_Conv_EncoderLayer(config['d_model'], config['nhead'], config['attention'], config['pool_size'], config['bn'],
|
|
1461
|
+
config['xformer'], config['leaky'], config['dwconv'], config['scatter'],
|
|
1462
|
+
config['norm_before'])
|
|
1463
|
+
self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))])
|
|
1464
|
+
else:
|
|
1465
|
+
encoder_layer = PANEncoderLayer(config['d_model'], config['nhead'], config['attention'], config['pool_size'],
|
|
1466
|
+
config['bn'], config['xformer'], config['leaky'], config['dwconv'], config['scatter']) \
|
|
1467
|
+
if config['pan'] else LoFTREncoderLayer(config['d_model'], config['nhead'],
|
|
1468
|
+
config['attention'], config['xformer'])
|
|
1469
|
+
self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))])
|
|
1470
|
+
self._reset_parameters()
|
|
1471
|
+
|
|
1472
|
+
def _reset_parameters(self):
    """Re-initialise the module's multi-dimensional parameters.

    Every parameter yielded by ``self.parameters()`` whose ``dim()`` is
    greater than one is overwritten in place with Xavier-uniform values;
    parameters with one dimension or fewer are left as they are.
    """
    multi_dim_params = (param for param in self.parameters() if param.dim() > 1)
    for param in multi_dim_params:
        nn.init.xavier_uniform_(param)
|
|
1476
|
+
|
|
1477
|
+
def forward(self, feat0, feat1, mask0=None, mask1=None, data=None):
    """
    Run two feature maps through the configured 'self' / 'cross' layers.

    Args:
        feat0 (torch.Tensor): [N, C, H, W]
        feat1 (torch.Tensor): [N, C, H, W]
        mask0 (torch.Tensor): [N, L] (optional)
        mask1 (torch.Tensor): [N, S] (optional)
        data (dict): updated in place with matchability / match / confidence
            entries when ``self.matchability`` is set (optional otherwise).

    Returns:
        tuple: ``(feat0, feat1)`` after the last executed layer. When the
        inputs were cropped to their valid region (see below) and are still
        channel-first, they are zero-padded back to the mask size.

    Raises:
        KeyError: if ``self.layer_names`` contains a name other than
            'self' or 'cross'.

    Cropping: when the batch size is 1, both masks are given and ``self.pan``
    is set, the valid extent is read from the masks, rounded down to a
    multiple of ``self.pool_size``, the features are cropped to it and the
    layers are called with ``None`` masks.

    NOTE(review): the cropping branch reads ``mask.size(-2)``, ``mask.size(-1)``
    and ``mask[0].sum(-2)``, i.e. it indexes the masks as [N, H, W]; the
    [N, L] / [N, S] shapes above are carried over from the original
    docstring -- confirm.
    NOTE(review): block nesting inside the 'cross' branch was reconstructed
    from a flattened listing; confirm against upstream before merging.
    """

    def _pad_tail(tensor, deltas, value=0):
        # Append `value` after the end of the trailing dims. `deltas` lists
        # the growth per dim in dim order, e.g. (dH, dW) for the last two.
        # Handles several cropped dims at once; the previous if/elif
        # `torch.cat` chains only worked when exactly one dim was cropped.
        deltas = [int(delta) for delta in deltas]
        if not any(deltas):
            return tensor
        widths = []
        for delta in reversed(deltas):  # F.pad expects the last dim first
            widths += [0, delta]
        return torch.nn.functional.pad(tensor, widths, value=value)

    # nchw for pan and n(hw)c for loftr
    assert self.d_model == feat0.size(1) or self.d_model == feat0.size(-1), "the feature number of src and transformer must be equal"
    H0, W0, H1, W1 = feat0.size(-2), feat0.size(-1), feat1.size(-2), feat1.size(-1)
    bs = feat0.shape[0]
    padding = False
    if bs == 1 and mask0 is not None and mask1 is not None and self.pan:  # NCHW for pan
        mask_H0, mask_W0 = mask0.size(-2), mask0.size(-1)
        mask_H1, mask_W1 = mask1.size(-2), mask1.size(-1)
        mask_h0, mask_w0 = mask0[0].sum(-2)[0], mask0[0].sum(-1)[0]
        mask_h1, mask_w1 = mask1[0].sum(-2)[0], mask1[0].sum(-1)[0]

        # round down to a multiple of self.pool_size
        mask_h0, mask_w0, mask_h1, mask_w1 = (
            extent // self.pool_size * self.pool_size
            for extent in (mask_h0, mask_w0, mask_h1, mask_w1)
        )

        feat0 = feat0[:, :, :mask_h0, :mask_w0]
        feat1 = feat1[:, :, :mask_h1, :mask_w1]

        padding = True

    # prune
    if padding:
        l0, l1 = mask_h0 * mask_w0, mask_h1 * mask_w1
    else:
        l0, l1 = H0 * W0, H1 * W1
    do_early_stop = self.depth_confidence > 0
    do_point_pruning = self.width_confidence > 0
    if do_point_pruning:
        ind0 = torch.arange(0, l0, device=feat0.device)[None]
        ind1 = torch.arange(0, l1, device=feat0.device)[None]
        # We store the index of the layer at which pruning is detected.
        prune0 = torch.ones_like(ind0)
        prune1 = torch.ones_like(ind1)
    if do_early_stop:
        token0, token1 = None, None

    for i, (layer, name) in enumerate(zip(self.layers, self.layer_names)):
        if padding:
            # Features were cropped to the valid region, so no mask is needed.
            mask0, mask1 = None, None
        if name == 'self':
            if self.asymmetric:
                assert False, 'not worked'
                feat1 = layer(feat1, feat1, mask1, mask1)
            else:
                feat0 = layer(feat0, feat0, mask0, mask0)
                feat1 = layer(feat1, feat1, mask1, mask1)
        elif name == 'cross':
            if self.bidirect:
                feat0, feat1 = layer(feat0, feat1, mask0, mask1)
            else:
                if self.asymmetric or self.asymmetric_self:
                    assert False, 'not worked'
                    feat0 = layer(feat0, feat1, mask0, mask1)
                else:
                    feat0 = layer(feat0, feat1, mask0, mask1)
                    feat1 = layer(feat1, feat0, mask1, mask0)

            # Nothing below is needed after the final layer at inference time.
            if i == len(self.layer_names) - 1 and not self.training:
                continue
            if self.matchability:
                desc0, desc1 = rearrange(feat0, 'b c h w -> b (h w) c'), rearrange(feat1, 'b c h w -> b (h w) c')
            # NOTE(review): desc0/desc1 only exist when self.matchability is
            # set, yet the early-stop and pruning branches read them.
            if do_early_stop:
                token0, token1 = self.token_confidence[i//2](desc0, desc1)
                if self.check_if_stop(token0, token1, i, l0+l1) and not self.training:
                    break
            if do_point_pruning:
                scores0, scores1 = self.log_assignment[i//2].scores(desc0, desc1)
                mask0 = self.get_pruning_mask(token0, scores0, i)
                mask1 = self.get_pruning_mask(token1, scores1, i)
                ind0, ind1 = ind0[mask0][None], ind1[mask1][None]
                feat0, feat1 = desc0[mask0][None], desc1[mask1][None]
                # NOTE(review): the second operand checks the un-pruned desc1,
                # not feat1 -- looks like a typo, left unchanged; confirm.
                if feat0.shape[-2] == 0 or desc1.shape[-2] == 0:
                    break
                prune0[:, ind0] += 1
                prune1[:, ind1] += 1
            if self.training and self.matchability:
                # NOTE(review): this branch reads mask_h0/mask_w0 (set only when
                # cropping happened) and ind0/ind1 (set only when point pruning
                # is enabled); confirm the supported configurations.
                scores, _, matchability0, matchability1 = self.log_assignment[i//2](desc0, desc1)
                m0_full = torch.zeros((bs, mask_h0 * mask_w0), device=matchability0.device, dtype=matchability0.dtype)
                # Tensor.scatter is out-of-place: keep its result (it used to be discarded).
                m0_full = m0_full.scatter(1, ind0, matchability0.squeeze(-1))
                if padding and self.d_model == feat0.size(1):
                    m0_full = m0_full.reshape(bs, mask_h0, mask_w0)
                    bs, c, mask_h0, mask_w0 = feat0.size()
                    m0_full = _pad_tail(m0_full, (mask_H0 - mask_h0, mask_W0 - mask_w0))
                    m0_full = m0_full.reshape(bs, mask_H0*mask_W0)
                m1_full = torch.zeros((bs, mask_h1 * mask_w1), device=matchability0.device, dtype=matchability0.dtype)
                m1_full = m1_full.scatter(1, ind1, matchability1.squeeze(-1))
                if padding and self.d_model == feat1.size(1):
                    m1_full = m1_full.reshape(bs, mask_h1, mask_w1)
                    bs, c, mask_h1, mask_w1 = feat1.size()
                    m1_full = _pad_tail(m1_full, (mask_H1 - mask_h1, mask_W1 - mask_w1))
                    m1_full = m1_full.reshape(bs, mask_H1*mask_W1)
                data.update({'matchability0_'+str(i//2): m0_full, 'matchability1_'+str(i//2): m1_full})
                m0, m1, mscores0, mscores1 = filter_matches(
                    scores, self.thr)
                if do_point_pruning:
                    # Map matches of the surviving points back to full-length index space.
                    m0_ = torch.full((bs, l0), -1, device=m0.device, dtype=m0.dtype)
                    m1_ = torch.full((bs, l1), -1, device=m1.device, dtype=m1.dtype)
                    m0_[:, ind0] = torch.where(
                        m0 == -1, -1, ind1.gather(1, m0.clamp(min=0)))
                    m1_[:, ind1] = torch.where(
                        m1 == -1, -1, ind0.gather(1, m1.clamp(min=0)))
                    mscores0_ = torch.zeros((bs, l0), device=mscores0.device)
                    mscores1_ = torch.zeros((bs, l1), device=mscores1.device)
                    mscores0_[:, ind0] = mscores0
                    mscores1_[:, ind1] = mscores1
                    m0, m1, mscores0, mscores1 = m0_, m1_, mscores0_, mscores1_
                if padding and self.d_model == feat0.size(1):
                    m0 = m0.reshape(bs, mask_h0, mask_w0)
                    bs, c, mask_h0, mask_w0 = feat0.size()
                    # -1 marks "no match" for the padded positions.
                    m0 = _pad_tail(m0, (mask_H0 - mask_h0, mask_W0 - mask_w0), value=-1)
                    m0 = m0.reshape(bs, mask_H0*mask_W0)
                if padding and self.d_model == feat1.size(1):
                    m1 = m1.reshape(bs, mask_h1, mask_w1)
                    bs, c, mask_h1, mask_w1 = feat1.size()
                    m1 = _pad_tail(m1, (mask_H1 - mask_h1, mask_W1 - mask_w1), value=-1)
                    m1 = m1.reshape(bs, mask_H1*mask_W1)
                data.update({'matches0_'+str(i//2): m0, 'matches1_'+str(i//2): m1})
                conf = torch.zeros((bs, l0 * l1), device=scores.device, dtype=scores.dtype)
                ind = ind0[..., None] * l1 + ind1[:, None, :]
                conf = conf.scatter(1, ind.reshape(bs, -1), scores.reshape(bs, -1).exp())
                if padding and self.d_model == feat0.size(1):
                    conf = conf.reshape(bs, mask_h0, mask_w0, mask_h1, mask_w1)
                    bs, c, mask_h0, mask_w0 = feat0.size()
                    bs, c, mask_h1, mask_w1 = feat1.size()
                    conf = _pad_tail(conf, (mask_H0 - mask_h0, mask_W0 - mask_w0,
                                            mask_H1 - mask_h1, mask_W1 - mask_w1))
                    conf = conf.reshape(bs, mask_H0*mask_W0, mask_H1*mask_W1)
                data.update({'conf_matrix_'+str(i//2): conf})

        else:
            raise KeyError

    if self.matchability and not self.training:
        # NOTE(review): desc0/desc1 were last refreshed before the final layer
        # (that iteration hits `continue` above), and ind0/ind1 exist only when
        # point pruning is enabled -- confirm this is intended.
        scores, _, matchability0, matchability1 = self.log_assignment[i//2](desc0, desc1)
        conf = torch.zeros((bs, l0 * l1), device=scores.device, dtype=scores.dtype)
        ind = ind0[..., None] * l1 + ind1[:, None, :]
        # Keep the scatter result (previously discarded, leaving conf all zero).
        conf = conf.scatter(1, ind.reshape(bs, -1), scores.reshape(bs, -1).exp())
        if padding and self.d_model == feat0.size(1):
            conf = conf.reshape(bs, mask_h0, mask_w0, mask_h1, mask_w1)
            bs, c, mask_h0, mask_w0 = feat0.size()
            bs, c, mask_h1, mask_w1 = feat1.size()
            conf = _pad_tail(conf, (mask_H0 - mask_h0, mask_W0 - mask_w0,
                                    mask_H1 - mask_h1, mask_W1 - mask_w1))
            conf = conf.reshape(bs, mask_H0*mask_W0, mask_H1*mask_W1)
        data.update({'conf_matrix': conf})
        data.update(**self.CoarseMatching.get_coarse_match(conf, data))

    if padding and self.d_model == feat0.size(1):
        # Restore the mask-sized layout; zeros fill the cropped-away border
        # in both height and width.
        bs, c, mask_h0, mask_w0 = feat0.size()
        feat0 = _pad_tail(feat0, (mask_H0 - mask_h0, mask_W0 - mask_w0))
        bs, c, mask_h1, mask_w1 = feat1.size()
        feat1 = _pad_tail(feat1, (mask_H1 - mask_h1, mask_W1 - mask_w1))

    return feat0, feat1
|
|
1722
|
+
|
|
1723
|
+
def pro(self, feat0, feat1, mask0=None, mask1=None, profiler=None):
    """
    Profiled variant of the layer stack: every layer is invoked through its
    ``pro`` method inside one ``profiler.profile`` section.

    Args:
        feat0 (torch.Tensor): [N, C, H, W]
        feat1 (torch.Tensor): [N, C, H, W]
        mask0 (torch.Tensor): [N, L] (optional)
        mask1 (torch.Tensor): [N, S] (optional)
        profiler: object whose ``profile(name)`` returns a context manager.
            It is used unconditionally, so the ``None`` default is not a
            usable value.

    Returns:
        tuple: the updated ``(feat0, feat1)``.

    Raises:
        KeyError: if a layer name is neither 'self' nor 'cross'.
    """
    channels_match = self.d_model == feat0.size(1) or self.d_model == feat0.size(-1)
    assert channels_match, "the feature number of src and transformer must be equal"

    with profiler.profile("LoFTR_transformer_attention"):
        for attn, kind in zip(self.layers, self.layer_names):
            if kind not in ('self', 'cross'):
                raise KeyError
            if kind == 'self':
                # Each image attends to itself.
                feat0 = attn.pro(feat0, feat0, mask0, mask0, profiler=profiler)
                feat1 = attn.pro(feat1, feat1, mask1, mask1, profiler=profiler)
            else:
                # Image 0 attends to image 1 first; image 1 then attends to
                # the already-updated image 0.
                feat0 = attn.pro(feat0, feat1, mask0, mask1, profiler=profiler)
                feat1 = attn.pro(feat1, feat0, mask1, mask0, profiler=profiler)

    return feat0, feat1
|
|
1745
|
+
|
|
1746
|
+
def confidence_threshold(self, layer_index: int) -> float:
    """Return the confidence threshold used at ``layer_index``.

    The value is ``0.8 + 0.1 * exp(-4 * layer_index / self.n_layers)``,
    clipped to the range [0, 1]; it starts at 0.9 for layer 0 and decays
    towards 0.8 as the layer index grows.
    """
    decay = np.exp(-4.0 * layer_index / self.n_layers)
    return np.clip(0.8 + 0.1 * decay, 0, 1)
|
|
1750
|
+
|
|
1751
|
+
def get_pruning_mask(self, confidences: torch.Tensor, scores: torch.Tensor,
                     layer_index: int) -> torch.Tensor:
    """Build the boolean pruning mask for one set of points.

    The mask is True where the (gated) score exceeds
    ``1 - self.width_confidence``. When ``confidences`` is given, any point
    whose confidence is not above this layer's threshold has its score
    replaced by 1.0 before the comparison; with ``confidences=None`` the raw
    scores are compared directly.
    """
    threshold = self.confidence_threshold(layer_index)
    if confidences is None:
        gated = scores
    else:
        gated = torch.where(
            confidences > threshold, scores, scores.new_tensor(1.0))
    return gated > (1 - self.width_confidence)
|
|
1759
|
+
|
|
1760
|
+
def check_if_stop(self,
                  confidences0: torch.Tensor,
                  confidences1: torch.Tensor,
                  layer_index: int, num_points: int) -> torch.Tensor:
    """Evaluate the early-stopping condition.

    Both confidence tensors are concatenated along the last dimension and
    the entries below this layer's threshold are counted. The result is
    True when ``1 - count / num_points`` exceeds ``self.depth_confidence``.
    """
    joined = torch.cat([confidences0, confidences1], -1)
    cutoff = self.confidence_threshold(layer_index)
    unsure_count = (joined < cutoff).float().sum()
    confident_ratio = 1.0 - unsure_count / num_points
    return confident_ratio > self.depth_confidence
|