vismatch 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (2042)
  1. vismatch/TEMPLATE.py +101 -0
  2. vismatch/__init__.py +475 -0
  3. vismatch/assets/example_pairs/false_positive/chartres.jpg +0 -0
  4. vismatch/assets/example_pairs/false_positive/notre_dame.jpg +0 -0
  5. vismatch/assets/example_pairs/fresco/fsm.jpg +0 -0
  6. vismatch/assets/example_pairs/fresco/sist_chapel.jpg +0 -0
  7. vismatch/assets/example_pairs/indoor/gcs_close.jpg +0 -0
  8. vismatch/assets/example_pairs/indoor/gcs_far.jpg +0 -0
  9. vismatch/assets/example_pairs/outdoor/montmartre_close.jpg +0 -0
  10. vismatch/assets/example_pairs/outdoor/montmartre_far.jpg +0 -0
  11. vismatch/assets/example_pairs/sat2iss/photo_from_iss.jpg +0 -0
  12. vismatch/assets/example_pairs/sat2iss/satellite_img.jpg +0 -0
  13. vismatch/assets/example_pairs/sphereglue/barbershop-00000000.jpg +0 -0
  14. vismatch/assets/example_pairs/sphereglue/barbershop-00000001.jpg +0 -0
  15. vismatch/assets/example_pairs/thermal/thermal.jpg +0 -0
  16. vismatch/assets/example_pairs/thermal/visible.jpg +0 -0
  17. vismatch/assets/example_test/original.jpg +0 -0
  18. vismatch/assets/example_test/warped.jpg +0 -0
  19. vismatch/base_matcher.py +242 -0
  20. vismatch/im_models/__init__.py +0 -0
  21. vismatch/im_models/aff_steerers.py +143 -0
  22. vismatch/im_models/aspanformer.py +74 -0
  23. vismatch/im_models/dedode.py +150 -0
  24. vismatch/im_models/duster.py +104 -0
  25. vismatch/im_models/edm.py +64 -0
  26. vismatch/im_models/efficient_loftr.py +60 -0
  27. vismatch/im_models/gim.py +187 -0
  28. vismatch/im_models/handcrafted.py +81 -0
  29. vismatch/im_models/keypt2subpx.py +154 -0
  30. vismatch/im_models/kornia.py +72 -0
  31. vismatch/im_models/liftfeat.py +44 -0
  32. vismatch/im_models/lightglue.py +75 -0
  33. vismatch/im_models/lisrd.py +98 -0
  34. vismatch/im_models/loftr.py +23 -0
  35. vismatch/im_models/master.py +107 -0
  36. vismatch/im_models/matchanything.py +221 -0
  37. vismatch/im_models/matchformer.py +61 -0
  38. vismatch/im_models/matching_toolbox.py +238 -0
  39. vismatch/im_models/minima.py +164 -0
  40. vismatch/im_models/omniglue.py +91 -0
  41. vismatch/im_models/rdd.py +250 -0
  42. vismatch/im_models/ripe.py +55 -0
  43. vismatch/im_models/roma.py +92 -0
  44. vismatch/im_models/romav2.py +62 -0
  45. vismatch/im_models/se2loftr.py +71 -0
  46. vismatch/im_models/silk.py +405 -0
  47. vismatch/im_models/sphereglue.py +97 -0
  48. vismatch/im_models/steerers.py +140 -0
  49. vismatch/im_models/topicfm.py +93 -0
  50. vismatch/im_models/ufm.py +57 -0
  51. vismatch/im_models/xfeat.py +78 -0
  52. vismatch/im_models/xfeat_steerers.py +151 -0
  53. vismatch/im_models/xoftr.py +71 -0
  54. vismatch/third_party/DeDoDe/DeDoDe/__init__.py +2 -0
  55. vismatch/third_party/DeDoDe/DeDoDe/benchmarks/__init__.py +4 -0
  56. vismatch/third_party/DeDoDe/DeDoDe/benchmarks/mega_pose_est.py +114 -0
  57. vismatch/third_party/DeDoDe/DeDoDe/benchmarks/mega_pose_est_mnn.py +119 -0
  58. vismatch/third_party/DeDoDe/DeDoDe/benchmarks/nll_benchmark.py +57 -0
  59. vismatch/third_party/DeDoDe/DeDoDe/benchmarks/num_inliers.py +76 -0
  60. vismatch/third_party/DeDoDe/DeDoDe/checkpoint.py +59 -0
  61. vismatch/third_party/DeDoDe/DeDoDe/datasets/__init__.py +0 -0
  62. vismatch/third_party/DeDoDe/DeDoDe/datasets/megadepth.py +269 -0
  63. vismatch/third_party/DeDoDe/DeDoDe/decoder.py +90 -0
  64. vismatch/third_party/DeDoDe/DeDoDe/descriptors/__init__.py +0 -0
  65. vismatch/third_party/DeDoDe/DeDoDe/descriptors/dedode_descriptor.py +50 -0
  66. vismatch/third_party/DeDoDe/DeDoDe/descriptors/descriptor_loss.py +68 -0
  67. vismatch/third_party/DeDoDe/DeDoDe/detectors/__init__.py +0 -0
  68. vismatch/third_party/DeDoDe/DeDoDe/detectors/dedode_detector.py +76 -0
  69. vismatch/third_party/DeDoDe/DeDoDe/detectors/keypoint_loss.py +185 -0
  70. vismatch/third_party/DeDoDe/DeDoDe/encoder.py +87 -0
  71. vismatch/third_party/DeDoDe/DeDoDe/matchers/__init__.py +0 -0
  72. vismatch/third_party/DeDoDe/DeDoDe/matchers/dual_softmax_matcher.py +38 -0
  73. vismatch/third_party/DeDoDe/DeDoDe/model_zoo/__init__.py +3 -0
  74. vismatch/third_party/DeDoDe/DeDoDe/model_zoo/dedode_models.py +249 -0
  75. vismatch/third_party/DeDoDe/DeDoDe/train.py +76 -0
  76. vismatch/third_party/DeDoDe/DeDoDe/transformer/__init__.py +8 -0
  77. vismatch/third_party/DeDoDe/DeDoDe/transformer/dinov2.py +359 -0
  78. vismatch/third_party/DeDoDe/DeDoDe/transformer/layers/__init__.py +12 -0
  79. vismatch/third_party/DeDoDe/DeDoDe/transformer/layers/attention.py +81 -0
  80. vismatch/third_party/DeDoDe/DeDoDe/transformer/layers/block.py +252 -0
  81. vismatch/third_party/DeDoDe/DeDoDe/transformer/layers/dino_head.py +59 -0
  82. vismatch/third_party/DeDoDe/DeDoDe/transformer/layers/drop_path.py +35 -0
  83. vismatch/third_party/DeDoDe/DeDoDe/transformer/layers/layer_scale.py +28 -0
  84. vismatch/third_party/DeDoDe/DeDoDe/transformer/layers/mlp.py +41 -0
  85. vismatch/third_party/DeDoDe/DeDoDe/transformer/layers/patch_embed.py +89 -0
  86. vismatch/third_party/DeDoDe/DeDoDe/transformer/layers/swiglu_ffn.py +63 -0
  87. vismatch/third_party/DeDoDe/DeDoDe/utils.py +717 -0
  88. vismatch/third_party/DeDoDe/data_prep/prep_keypoints.py +103 -0
  89. vismatch/third_party/DeDoDe/demo/demo_kpts.py +24 -0
  90. vismatch/third_party/DeDoDe/demo/demo_match.py +46 -0
  91. vismatch/third_party/DeDoDe/demo/demo_match_dedode_G.py +45 -0
  92. vismatch/third_party/DeDoDe/demo/demo_scoremap.py +23 -0
  93. vismatch/third_party/DeDoDe/experiments/dedode_descriptor-B.py +135 -0
  94. vismatch/third_party/DeDoDe/experiments/dedode_descriptor-G.py +145 -0
  95. vismatch/third_party/DeDoDe/experiments/dedode_detector.py +126 -0
  96. vismatch/third_party/DeDoDe/experiments/eval/eval_dedode_descriptor-B.py +38 -0
  97. vismatch/third_party/DeDoDe/experiments/eval/eval_dedode_descriptor-G.py +38 -0
  98. vismatch/third_party/DeDoDe/setup.py +11 -0
  99. vismatch/third_party/EDM/configs/data/__init__.py +0 -0
  100. vismatch/third_party/EDM/configs/data/base.py +37 -0
  101. vismatch/third_party/EDM/configs/data/megadepth_test_1500.py +23 -0
  102. vismatch/third_party/EDM/configs/data/megadepth_trainval_832.py +32 -0
  103. vismatch/third_party/EDM/configs/data/scannet_test_1500.py +24 -0
  104. vismatch/third_party/EDM/configs/data/scannet_trainval.py +31 -0
  105. vismatch/third_party/EDM/configs/edm/indoor/edm_base.py +15 -0
  106. vismatch/third_party/EDM/configs/edm/outdoor/edm_base.py +17 -0
  107. vismatch/third_party/EDM/deploy/export_onnx.py +69 -0
  108. vismatch/third_party/EDM/deploy/run_onnx.py +138 -0
  109. vismatch/third_party/EDM/runtime_single_pair.py +73 -0
  110. vismatch/third_party/EDM/src/__init__.py +0 -0
  111. vismatch/third_party/EDM/src/config/default.py +184 -0
  112. vismatch/third_party/EDM/src/datasets/megadepth.py +164 -0
  113. vismatch/third_party/EDM/src/datasets/sampler.py +95 -0
  114. vismatch/third_party/EDM/src/datasets/scannet.py +147 -0
  115. vismatch/third_party/EDM/src/edm/__init__.py +2 -0
  116. vismatch/third_party/EDM/src/edm/backbone/resnet.py +116 -0
  117. vismatch/third_party/EDM/src/edm/edm.py +204 -0
  118. vismatch/third_party/EDM/src/edm/head/coarse_matching.py +158 -0
  119. vismatch/third_party/EDM/src/edm/head/fine_matching.py +383 -0
  120. vismatch/third_party/EDM/src/edm/neck/__init__.py +1 -0
  121. vismatch/third_party/EDM/src/edm/neck/loftr_module/__init__.py +1 -0
  122. vismatch/third_party/EDM/src/edm/neck/loftr_module/transformer.py +418 -0
  123. vismatch/third_party/EDM/src/edm/neck/neck.py +156 -0
  124. vismatch/third_party/EDM/src/edm/utils/geometry.py +58 -0
  125. vismatch/third_party/EDM/src/edm/utils/supervision.py +255 -0
  126. vismatch/third_party/EDM/src/lightning/data.py +450 -0
  127. vismatch/third_party/EDM/src/lightning/lightning_edm.py +379 -0
  128. vismatch/third_party/EDM/src/losses/edm_loss.py +206 -0
  129. vismatch/third_party/EDM/src/optimizers/__init__.py +57 -0
  130. vismatch/third_party/EDM/src/utils/augment.py +65 -0
  131. vismatch/third_party/EDM/src/utils/comm.py +271 -0
  132. vismatch/third_party/EDM/src/utils/dataloader.py +24 -0
  133. vismatch/third_party/EDM/src/utils/dataset.py +192 -0
  134. vismatch/third_party/EDM/src/utils/metrics.py +299 -0
  135. vismatch/third_party/EDM/src/utils/misc.py +113 -0
  136. vismatch/third_party/EDM/src/utils/plotting.py +186 -0
  137. vismatch/third_party/EDM/src/utils/profiler.py +40 -0
  138. vismatch/third_party/EDM/src/utils/warppers.py +428 -0
  139. vismatch/third_party/EDM/src/utils/warppers_utils.py +172 -0
  140. vismatch/third_party/EDM/test.py +132 -0
  141. vismatch/third_party/EDM/train.py +156 -0
  142. vismatch/third_party/EfficientLoFTR/configs/data/__init__.py +0 -0
  143. vismatch/third_party/EfficientLoFTR/configs/data/base.py +35 -0
  144. vismatch/third_party/EfficientLoFTR/configs/data/megadepth_test_1500.py +13 -0
  145. vismatch/third_party/EfficientLoFTR/configs/data/megadepth_trainval_832.py +24 -0
  146. vismatch/third_party/EfficientLoFTR/configs/data/scannet_test_1500.py +16 -0
  147. vismatch/third_party/EfficientLoFTR/configs/loftr/eloftr_full.py +36 -0
  148. vismatch/third_party/EfficientLoFTR/configs/loftr/eloftr_optimized.py +37 -0
  149. vismatch/third_party/EfficientLoFTR/src/__init__.py +0 -0
  150. vismatch/third_party/EfficientLoFTR/src/config/default.py +182 -0
  151. vismatch/third_party/EfficientLoFTR/src/datasets/megadepth.py +133 -0
  152. vismatch/third_party/EfficientLoFTR/src/datasets/sampler.py +77 -0
  153. vismatch/third_party/EfficientLoFTR/src/datasets/scannet.py +129 -0
  154. vismatch/third_party/EfficientLoFTR/src/lightning/data.py +357 -0
  155. vismatch/third_party/EfficientLoFTR/src/lightning/lightning_loftr.py +272 -0
  156. vismatch/third_party/EfficientLoFTR/src/loftr/__init__.py +4 -0
  157. vismatch/third_party/EfficientLoFTR/src/loftr/backbone/__init__.py +11 -0
  158. vismatch/third_party/EfficientLoFTR/src/loftr/backbone/backbone.py +37 -0
  159. vismatch/third_party/EfficientLoFTR/src/loftr/backbone/repvgg.py +224 -0
  160. vismatch/third_party/EfficientLoFTR/src/loftr/loftr.py +124 -0
  161. vismatch/third_party/EfficientLoFTR/src/loftr/loftr_module/__init__.py +2 -0
  162. vismatch/third_party/EfficientLoFTR/src/loftr/loftr_module/fine_preprocess.py +112 -0
  163. vismatch/third_party/EfficientLoFTR/src/loftr/loftr_module/linear_attention.py +103 -0
  164. vismatch/third_party/EfficientLoFTR/src/loftr/loftr_module/transformer.py +164 -0
  165. vismatch/third_party/EfficientLoFTR/src/loftr/utils/coarse_matching.py +241 -0
  166. vismatch/third_party/EfficientLoFTR/src/loftr/utils/fine_matching.py +156 -0
  167. vismatch/third_party/EfficientLoFTR/src/loftr/utils/full_config.py +50 -0
  168. vismatch/third_party/EfficientLoFTR/src/loftr/utils/geometry.py +54 -0
  169. vismatch/third_party/EfficientLoFTR/src/loftr/utils/opt_config.py +50 -0
  170. vismatch/third_party/EfficientLoFTR/src/loftr/utils/position_encoding.py +50 -0
  171. vismatch/third_party/EfficientLoFTR/src/loftr/utils/supervision.py +275 -0
  172. vismatch/third_party/EfficientLoFTR/src/losses/loftr_loss.py +229 -0
  173. vismatch/third_party/EfficientLoFTR/src/optimizers/__init__.py +42 -0
  174. vismatch/third_party/EfficientLoFTR/src/utils/augment.py +55 -0
  175. vismatch/third_party/EfficientLoFTR/src/utils/comm.py +265 -0
  176. vismatch/third_party/EfficientLoFTR/src/utils/dataloader.py +23 -0
  177. vismatch/third_party/EfficientLoFTR/src/utils/dataset.py +186 -0
  178. vismatch/third_party/EfficientLoFTR/src/utils/metrics.py +264 -0
  179. vismatch/third_party/EfficientLoFTR/src/utils/misc.py +106 -0
  180. vismatch/third_party/EfficientLoFTR/src/utils/plotting.py +154 -0
  181. vismatch/third_party/EfficientLoFTR/src/utils/profiler.py +39 -0
  182. vismatch/third_party/EfficientLoFTR/src/utils/warppers.py +426 -0
  183. vismatch/third_party/EfficientLoFTR/src/utils/warppers_utils.py +171 -0
  184. vismatch/third_party/EfficientLoFTR/test.py +143 -0
  185. vismatch/third_party/EfficientLoFTR/train.py +154 -0
  186. vismatch/third_party/LISRD/lisrd/__init__.py +0 -0
  187. vismatch/third_party/LISRD/lisrd/datasets/__init__.py +7 -0
  188. vismatch/third_party/LISRD/lisrd/datasets/base_dataset.py +38 -0
  189. vismatch/third_party/LISRD/lisrd/datasets/coco.py +148 -0
  190. vismatch/third_party/LISRD/lisrd/datasets/flashes.py +170 -0
  191. vismatch/third_party/LISRD/lisrd/datasets/hpatches.py +135 -0
  192. vismatch/third_party/LISRD/lisrd/datasets/mixed_dataset.py +53 -0
  193. vismatch/third_party/LISRD/lisrd/datasets/rdnim.py +117 -0
  194. vismatch/third_party/LISRD/lisrd/datasets/utils/data_augmentation.py +168 -0
  195. vismatch/third_party/LISRD/lisrd/datasets/utils/data_reader.py +48 -0
  196. vismatch/third_party/LISRD/lisrd/datasets/utils/homographies.py +215 -0
  197. vismatch/third_party/LISRD/lisrd/datasets/vidit.py +152 -0
  198. vismatch/third_party/LISRD/lisrd/evaluation/__init__.py +0 -0
  199. vismatch/third_party/LISRD/lisrd/evaluation/descriptor_evaluation.py +142 -0
  200. vismatch/third_party/LISRD/lisrd/experiment.py +129 -0
  201. vismatch/third_party/LISRD/lisrd/export_features.py +148 -0
  202. vismatch/third_party/LISRD/lisrd/models/__init__.py +7 -0
  203. vismatch/third_party/LISRD/lisrd/models/backbones/__init__.py +0 -0
  204. vismatch/third_party/LISRD/lisrd/models/backbones/net_vlad.py +62 -0
  205. vismatch/third_party/LISRD/lisrd/models/backbones/vgg.py +46 -0
  206. vismatch/third_party/LISRD/lisrd/models/base_model.py +336 -0
  207. vismatch/third_party/LISRD/lisrd/models/keypoint_detectors.py +34 -0
  208. vismatch/third_party/LISRD/lisrd/models/lisrd.py +328 -0
  209. vismatch/third_party/LISRD/lisrd/models/lisrd_sift.py +289 -0
  210. vismatch/third_party/LISRD/lisrd/third_party/super_point_magic_leap/demo_superpoint.py +734 -0
  211. vismatch/third_party/LISRD/lisrd/utils/geometry_utils.py +123 -0
  212. vismatch/third_party/LISRD/lisrd/utils/losses.py +191 -0
  213. vismatch/third_party/LISRD/lisrd/utils/metrics.py +66 -0
  214. vismatch/third_party/LISRD/lisrd/utils/pytorch_utils.py +14 -0
  215. vismatch/third_party/LISRD/lisrd/utils/stdout_capturing.py +81 -0
  216. vismatch/third_party/LISRD/notebooks/utils.py +103 -0
  217. vismatch/third_party/LISRD/setup.py +4 -0
  218. vismatch/third_party/LiftFeat/dataset/__init__.py +0 -0
  219. vismatch/third_party/LiftFeat/dataset/coco_augmentor.py +298 -0
  220. vismatch/third_party/LiftFeat/dataset/coco_wrapper.py +175 -0
  221. vismatch/third_party/LiftFeat/dataset/dataset_utils.py +183 -0
  222. vismatch/third_party/LiftFeat/dataset/megadepth.py +177 -0
  223. vismatch/third_party/LiftFeat/dataset/megadepth_wrapper.py +167 -0
  224. vismatch/third_party/LiftFeat/demo.py +116 -0
  225. vismatch/third_party/LiftFeat/evaluation/HPatch_evaluation.py +182 -0
  226. vismatch/third_party/LiftFeat/evaluation/MegaDepth1500_evaluation.py +105 -0
  227. vismatch/third_party/LiftFeat/evaluation/eval_utils.py +127 -0
  228. vismatch/third_party/LiftFeat/loss/loss.py +291 -0
  229. vismatch/third_party/LiftFeat/models/interpolator.py +34 -0
  230. vismatch/third_party/LiftFeat/models/liftfeat_wrapper.py +172 -0
  231. vismatch/third_party/LiftFeat/models/model.py +419 -0
  232. vismatch/third_party/LiftFeat/tools/demo_match_video.py +145 -0
  233. vismatch/third_party/LiftFeat/tools/demo_vo.py +163 -0
  234. vismatch/third_party/LiftFeat/train.py +369 -0
  235. vismatch/third_party/LiftFeat/utils/VisualOdometry.py +339 -0
  236. vismatch/third_party/LiftFeat/utils/__init__.py +0 -0
  237. vismatch/third_party/LiftFeat/utils/alike_wrapper.py +45 -0
  238. vismatch/third_party/LiftFeat/utils/config.py +16 -0
  239. vismatch/third_party/LiftFeat/utils/depth_anything_wrapper.py +150 -0
  240. vismatch/third_party/LiftFeat/utils/featurebooster.py +247 -0
  241. vismatch/third_party/LiftFeat/utils/post_process.py +21 -0
  242. vismatch/third_party/LightGlue/benchmark.py +255 -0
  243. vismatch/third_party/LightGlue/lightglue/__init__.py +7 -0
  244. vismatch/third_party/LightGlue/lightglue/aliked.py +760 -0
  245. vismatch/third_party/LightGlue/lightglue/disk.py +55 -0
  246. vismatch/third_party/LightGlue/lightglue/dog_hardnet.py +41 -0
  247. vismatch/third_party/LightGlue/lightglue/lightglue.py +662 -0
  248. vismatch/third_party/LightGlue/lightglue/sift.py +216 -0
  249. vismatch/third_party/LightGlue/lightglue/superpoint.py +227 -0
  250. vismatch/third_party/LightGlue/lightglue/utils.py +165 -0
  251. vismatch/third_party/LightGlue/lightglue/viz2d.py +203 -0
  252. vismatch/third_party/MINIMA/demo.py +201 -0
  253. vismatch/third_party/MINIMA/src/__init__.py +0 -0
  254. vismatch/third_party/MINIMA/src/config/default.py +203 -0
  255. vismatch/third_party/MINIMA/src/config/default_for_megadepth_dense.py +203 -0
  256. vismatch/third_party/MINIMA/src/config/default_for_megadepth_sparse.py +203 -0
  257. vismatch/third_party/MINIMA/src/utils/__init__.py +0 -0
  258. vismatch/third_party/MINIMA/src/utils/culculate_auc.py +28 -0
  259. vismatch/third_party/MINIMA/src/utils/data_io.py +156 -0
  260. vismatch/third_party/MINIMA/src/utils/data_io_loftr.py +152 -0
  261. vismatch/third_party/MINIMA/src/utils/data_io_roma.py +186 -0
  262. vismatch/third_party/MINIMA/src/utils/data_io_sp_lg.py +158 -0
  263. vismatch/third_party/MINIMA/src/utils/load_model.py +164 -0
  264. vismatch/third_party/MINIMA/src/utils/metrics.py +214 -0
  265. vismatch/third_party/MINIMA/src/utils/misc.py +101 -0
  266. vismatch/third_party/MINIMA/src/utils/plotting.py +291 -0
  267. vismatch/third_party/MINIMA/src/utils/sample_h.py +142 -0
  268. vismatch/third_party/MINIMA/test_relative_homo_depth.py +683 -0
  269. vismatch/third_party/MINIMA/test_relative_homo_event.py +722 -0
  270. vismatch/third_party/MINIMA/test_relative_homo_mmim.py +669 -0
  271. vismatch/third_party/MINIMA/test_relative_pose_infrared.py +500 -0
  272. vismatch/third_party/MINIMA/test_relative_pose_mega_1500.py +487 -0
  273. vismatch/third_party/MINIMA/test_relative_pose_mega_1500_syn.py +516 -0
  274. vismatch/third_party/MINIMA/third_party/LightGlue/benchmark.py +255 -0
  275. vismatch/third_party/MINIMA/third_party/LightGlue/lightglue/__init__.py +7 -0
  276. vismatch/third_party/MINIMA/third_party/LightGlue/lightglue/aliked.py +758 -0
  277. vismatch/third_party/MINIMA/third_party/LightGlue/lightglue/disk.py +55 -0
  278. vismatch/third_party/MINIMA/third_party/LightGlue/lightglue/dog_hardnet.py +41 -0
  279. vismatch/third_party/MINIMA/third_party/LightGlue/lightglue/lightglue.py +655 -0
  280. vismatch/third_party/MINIMA/third_party/LightGlue/lightglue/sift.py +216 -0
  281. vismatch/third_party/MINIMA/third_party/LightGlue/lightglue/superpoint.py +227 -0
  282. vismatch/third_party/MINIMA/third_party/LightGlue/lightglue/utils.py +165 -0
  283. vismatch/third_party/MINIMA/third_party/LightGlue/lightglue/viz2d.py +184 -0
  284. vismatch/third_party/MINIMA/third_party/LoFTR/configs/data/__init__.py +0 -0
  285. vismatch/third_party/MINIMA/third_party/LoFTR/configs/data/base.py +35 -0
  286. vismatch/third_party/MINIMA/third_party/LoFTR/configs/data/megadepth_test_1500.py +11 -0
  287. vismatch/third_party/MINIMA/third_party/LoFTR/configs/data/megadepth_trainval_640.py +22 -0
  288. vismatch/third_party/MINIMA/third_party/LoFTR/configs/data/megadepth_trainval_840.py +22 -0
  289. vismatch/third_party/MINIMA/third_party/LoFTR/configs/data/scannet_test_1500.py +11 -0
  290. vismatch/third_party/MINIMA/third_party/LoFTR/configs/data/scannet_trainval.py +17 -0
  291. vismatch/third_party/MINIMA/third_party/LoFTR/configs/loftr/indoor/buggy_pos_enc/loftr_ds.py +6 -0
  292. vismatch/third_party/MINIMA/third_party/LoFTR/configs/loftr/indoor/buggy_pos_enc/loftr_ds_dense.py +8 -0
  293. vismatch/third_party/MINIMA/third_party/LoFTR/configs/loftr/indoor/buggy_pos_enc/loftr_ot.py +6 -0
  294. vismatch/third_party/MINIMA/third_party/LoFTR/configs/loftr/indoor/buggy_pos_enc/loftr_ot_dense.py +8 -0
  295. vismatch/third_party/MINIMA/third_party/LoFTR/configs/loftr/indoor/loftr_ds.py +5 -0
  296. vismatch/third_party/MINIMA/third_party/LoFTR/configs/loftr/indoor/loftr_ds_dense.py +7 -0
  297. vismatch/third_party/MINIMA/third_party/LoFTR/configs/loftr/indoor/loftr_ot.py +5 -0
  298. vismatch/third_party/MINIMA/third_party/LoFTR/configs/loftr/indoor/loftr_ot_dense.py +7 -0
  299. vismatch/third_party/MINIMA/third_party/LoFTR/configs/loftr/indoor/scannet/loftr_ds_eval.py +16 -0
  300. vismatch/third_party/MINIMA/third_party/LoFTR/configs/loftr/indoor/scannet/loftr_ds_eval_new.py +18 -0
  301. vismatch/third_party/MINIMA/third_party/LoFTR/configs/loftr/outdoor/buggy_pos_enc/loftr_ds.py +16 -0
  302. vismatch/third_party/MINIMA/third_party/LoFTR/configs/loftr/outdoor/buggy_pos_enc/loftr_ds_dense.py +17 -0
  303. vismatch/third_party/MINIMA/third_party/LoFTR/configs/loftr/outdoor/buggy_pos_enc/loftr_ot.py +16 -0
  304. vismatch/third_party/MINIMA/third_party/LoFTR/configs/loftr/outdoor/buggy_pos_enc/loftr_ot_dense.py +17 -0
  305. vismatch/third_party/MINIMA/third_party/LoFTR/configs/loftr/outdoor/loftr_ds.py +15 -0
  306. vismatch/third_party/MINIMA/third_party/LoFTR/configs/loftr/outdoor/loftr_ds_dense.py +16 -0
  307. vismatch/third_party/MINIMA/third_party/LoFTR/configs/loftr/outdoor/loftr_ot.py +15 -0
  308. vismatch/third_party/MINIMA/third_party/LoFTR/configs/loftr/outdoor/loftr_ot_dense.py +16 -0
  309. vismatch/third_party/MINIMA/third_party/LoFTR/demo/demo_loftr.py +240 -0
  310. vismatch/third_party/MINIMA/third_party/LoFTR/src/__init__.py +0 -0
  311. vismatch/third_party/MINIMA/third_party/LoFTR/src/config/default.py +171 -0
  312. vismatch/third_party/MINIMA/third_party/LoFTR/src/datasets/megadepth.py +127 -0
  313. vismatch/third_party/MINIMA/third_party/LoFTR/src/datasets/sampler.py +77 -0
  314. vismatch/third_party/MINIMA/third_party/LoFTR/src/datasets/scannet.py +114 -0
  315. vismatch/third_party/MINIMA/third_party/LoFTR/src/lightning/data.py +320 -0
  316. vismatch/third_party/MINIMA/third_party/LoFTR/src/lightning/lightning_loftr.py +249 -0
  317. vismatch/third_party/MINIMA/third_party/LoFTR/src/loftr/__init__.py +2 -0
  318. vismatch/third_party/MINIMA/third_party/LoFTR/src/loftr/backbone/__init__.py +11 -0
  319. vismatch/third_party/MINIMA/third_party/LoFTR/src/loftr/backbone/resnet_fpn.py +199 -0
  320. vismatch/third_party/MINIMA/third_party/LoFTR/src/loftr/loftr.py +81 -0
  321. vismatch/third_party/MINIMA/third_party/LoFTR/src/loftr/loftr_module/__init__.py +2 -0
  322. vismatch/third_party/MINIMA/third_party/LoFTR/src/loftr/loftr_module/fine_preprocess.py +59 -0
  323. vismatch/third_party/MINIMA/third_party/LoFTR/src/loftr/loftr_module/linear_attention.py +81 -0
  324. vismatch/third_party/MINIMA/third_party/LoFTR/src/loftr/loftr_module/transformer.py +101 -0
  325. vismatch/third_party/MINIMA/third_party/LoFTR/src/loftr/utils/coarse_matching.py +261 -0
  326. vismatch/third_party/MINIMA/third_party/LoFTR/src/loftr/utils/cvpr_ds_config.py +50 -0
  327. vismatch/third_party/MINIMA/third_party/LoFTR/src/loftr/utils/fine_matching.py +74 -0
  328. vismatch/third_party/MINIMA/third_party/LoFTR/src/loftr/utils/geometry.py +54 -0
  329. vismatch/third_party/MINIMA/third_party/LoFTR/src/loftr/utils/position_encoding.py +42 -0
  330. vismatch/third_party/MINIMA/third_party/LoFTR/src/loftr/utils/supervision.py +151 -0
  331. vismatch/third_party/MINIMA/third_party/LoFTR/src/losses/loftr_loss.py +192 -0
  332. vismatch/third_party/MINIMA/third_party/LoFTR/src/optimizers/__init__.py +42 -0
  333. vismatch/third_party/MINIMA/third_party/LoFTR/src/utils/augment.py +55 -0
  334. vismatch/third_party/MINIMA/third_party/LoFTR/src/utils/comm.py +265 -0
  335. vismatch/third_party/MINIMA/third_party/LoFTR/src/utils/dataloader.py +23 -0
  336. vismatch/third_party/MINIMA/third_party/LoFTR/src/utils/dataset.py +185 -0
  337. vismatch/third_party/MINIMA/third_party/LoFTR/src/utils/metrics.py +193 -0
  338. vismatch/third_party/MINIMA/third_party/LoFTR/src/utils/misc.py +101 -0
  339. vismatch/third_party/MINIMA/third_party/LoFTR/src/utils/plotting.py +154 -0
  340. vismatch/third_party/MINIMA/third_party/LoFTR/src/utils/profiler.py +39 -0
  341. vismatch/third_party/MINIMA/third_party/LoFTR/test.py +68 -0
  342. vismatch/third_party/MINIMA/third_party/LoFTR/third_party/SuperGluePretrainedNetwork/demo_superglue.py +259 -0
  343. vismatch/third_party/MINIMA/third_party/LoFTR/third_party/SuperGluePretrainedNetwork/match_pairs.py +425 -0
  344. vismatch/third_party/MINIMA/third_party/LoFTR/third_party/SuperGluePretrainedNetwork/models/__init__.py +0 -0
  345. vismatch/third_party/MINIMA/third_party/LoFTR/third_party/SuperGluePretrainedNetwork/models/matching.py +84 -0
  346. vismatch/third_party/MINIMA/third_party/LoFTR/third_party/SuperGluePretrainedNetwork/models/superglue.py +283 -0
  347. vismatch/third_party/MINIMA/third_party/LoFTR/third_party/SuperGluePretrainedNetwork/models/superpoint.py +202 -0
  348. vismatch/third_party/MINIMA/third_party/LoFTR/third_party/SuperGluePretrainedNetwork/models/utils.py +555 -0
  349. vismatch/third_party/MINIMA/third_party/LoFTR/train.py +123 -0
  350. vismatch/third_party/MINIMA/third_party/RoMa/demo/demo_3D_effect.py +47 -0
  351. vismatch/third_party/MINIMA/third_party/RoMa/demo/demo_fundamental.py +34 -0
  352. vismatch/third_party/MINIMA/third_party/RoMa/demo/demo_match.py +50 -0
  353. vismatch/third_party/MINIMA/third_party/RoMa/demo/demo_match_opencv_sift.py +43 -0
  354. vismatch/third_party/MINIMA/third_party/RoMa/demo/demo_match_tiny.py +77 -0
  355. vismatch/third_party/MINIMA/third_party/RoMa/experiments/eval_roma_outdoor.py +57 -0
  356. vismatch/third_party/MINIMA/third_party/RoMa/experiments/eval_tiny_roma_v1_outdoor.py +84 -0
  357. vismatch/third_party/MINIMA/third_party/RoMa/experiments/roma_indoor.py +320 -0
  358. vismatch/third_party/MINIMA/third_party/RoMa/experiments/train_roma_outdoor.py +307 -0
  359. vismatch/third_party/MINIMA/third_party/RoMa/experiments/train_tiny_roma_v1_outdoor.py +498 -0
  360. vismatch/third_party/MINIMA/third_party/RoMa/romatch/__init__.py +8 -0
  361. vismatch/third_party/MINIMA/third_party/RoMa/romatch/benchmarks/__init__.py +6 -0
  362. vismatch/third_party/MINIMA/third_party/RoMa/romatch/benchmarks/hpatches_sequences_homog_benchmark.py +113 -0
  363. vismatch/third_party/MINIMA/third_party/RoMa/romatch/benchmarks/megadepth_dense_benchmark.py +106 -0
  364. vismatch/third_party/MINIMA/third_party/RoMa/romatch/benchmarks/megadepth_pose_estimation_benchmark.py +118 -0
  365. vismatch/third_party/MINIMA/third_party/RoMa/romatch/benchmarks/megadepth_pose_estimation_benchmark_poselib.py +119 -0
  366. vismatch/third_party/MINIMA/third_party/RoMa/romatch/benchmarks/scannet_benchmark.py +143 -0
  367. vismatch/third_party/MINIMA/third_party/RoMa/romatch/checkpointing/__init__.py +1 -0
  368. vismatch/third_party/MINIMA/third_party/RoMa/romatch/checkpointing/checkpoint.py +60 -0
  369. vismatch/third_party/MINIMA/third_party/RoMa/romatch/datasets/__init__.py +2 -0
  370. vismatch/third_party/MINIMA/third_party/RoMa/romatch/datasets/megadepth.py +232 -0
  371. vismatch/third_party/MINIMA/third_party/RoMa/romatch/datasets/scannet.py +160 -0
  372. vismatch/third_party/MINIMA/third_party/RoMa/romatch/losses/__init__.py +1 -0
  373. vismatch/third_party/MINIMA/third_party/RoMa/romatch/losses/robust_loss.py +161 -0
  374. vismatch/third_party/MINIMA/third_party/RoMa/romatch/losses/robust_loss_tiny_roma.py +160 -0
  375. vismatch/third_party/MINIMA/third_party/RoMa/romatch/models/__init__.py +1 -0
  376. vismatch/third_party/MINIMA/third_party/RoMa/romatch/models/encoders.py +122 -0
  377. vismatch/third_party/MINIMA/third_party/RoMa/romatch/models/matcher.py +766 -0
  378. vismatch/third_party/MINIMA/third_party/RoMa/romatch/models/model_zoo/__init__.py +73 -0
  379. vismatch/third_party/MINIMA/third_party/RoMa/romatch/models/model_zoo/roma_models.py +170 -0
  380. vismatch/third_party/MINIMA/third_party/RoMa/romatch/models/tiny.py +304 -0
  381. vismatch/third_party/MINIMA/third_party/RoMa/romatch/models/transformer/__init__.py +48 -0
  382. vismatch/third_party/MINIMA/third_party/RoMa/romatch/models/transformer/dinov2.py +359 -0
  383. vismatch/third_party/MINIMA/third_party/RoMa/romatch/models/transformer/layers/__init__.py +12 -0
  384. vismatch/third_party/MINIMA/third_party/RoMa/romatch/models/transformer/layers/attention.py +81 -0
  385. vismatch/third_party/MINIMA/third_party/RoMa/romatch/models/transformer/layers/block.py +252 -0
  386. vismatch/third_party/MINIMA/third_party/RoMa/romatch/models/transformer/layers/dino_head.py +59 -0
  387. vismatch/third_party/MINIMA/third_party/RoMa/romatch/models/transformer/layers/drop_path.py +35 -0
  388. vismatch/third_party/MINIMA/third_party/RoMa/romatch/models/transformer/layers/layer_scale.py +28 -0
  389. vismatch/third_party/MINIMA/third_party/RoMa/romatch/models/transformer/layers/mlp.py +41 -0
  390. vismatch/third_party/MINIMA/third_party/RoMa/romatch/models/transformer/layers/patch_embed.py +89 -0
  391. vismatch/third_party/MINIMA/third_party/RoMa/romatch/models/transformer/layers/swiglu_ffn.py +63 -0
  392. vismatch/third_party/MINIMA/third_party/RoMa/romatch/train/__init__.py +1 -0
  393. vismatch/third_party/MINIMA/third_party/RoMa/romatch/train/train.py +102 -0
  394. vismatch/third_party/MINIMA/third_party/RoMa/romatch/utils/__init__.py +16 -0
  395. vismatch/third_party/MINIMA/third_party/RoMa/romatch/utils/kde.py +13 -0
  396. vismatch/third_party/MINIMA/third_party/RoMa/romatch/utils/local_correlation.py +48 -0
  397. vismatch/third_party/MINIMA/third_party/RoMa/romatch/utils/transforms.py +118 -0
  398. vismatch/third_party/MINIMA/third_party/RoMa/romatch/utils/utils.py +662 -0
  399. vismatch/third_party/MINIMA/third_party/RoMa/setup.py +9 -0
  400. vismatch/third_party/MINIMA/third_party/XoFTR/configs/data/__init__.py +0 -0
  401. vismatch/third_party/MINIMA/third_party/XoFTR/configs/data/base.py +35 -0
  402. vismatch/third_party/MINIMA/third_party/XoFTR/configs/data/megadepth_trainval_840.py +22 -0
  403. vismatch/third_party/MINIMA/third_party/XoFTR/configs/data/megadepth_vistir_trainval_640.py +23 -0
  404. vismatch/third_party/MINIMA/third_party/XoFTR/configs/data/pretrain.py +8 -0
  405. vismatch/third_party/MINIMA/third_party/XoFTR/configs/xoftr/outdoor/visible_thermal.py +17 -0
  406. vismatch/third_party/MINIMA/third_party/XoFTR/configs/xoftr/pretrain/pretrain.py +12 -0
  407. vismatch/third_party/MINIMA/third_party/XoFTR/pretrain.py +125 -0
  408. vismatch/third_party/MINIMA/third_party/XoFTR/src/__init__.py +0 -0
  409. vismatch/third_party/MINIMA/third_party/XoFTR/src/config/default.py +203 -0
  410. vismatch/third_party/MINIMA/third_party/XoFTR/src/datasets/megadepth.py +143 -0
  411. vismatch/third_party/MINIMA/third_party/XoFTR/src/datasets/pretrain_dataset.py +156 -0
  412. vismatch/third_party/MINIMA/third_party/XoFTR/src/datasets/sampler.py +77 -0
  413. vismatch/third_party/MINIMA/third_party/XoFTR/src/datasets/scannet.py +114 -0
  414. vismatch/third_party/MINIMA/third_party/XoFTR/src/datasets/vistir.py +109 -0
  415. vismatch/third_party/MINIMA/third_party/XoFTR/src/lightning/data.py +346 -0
  416. vismatch/third_party/MINIMA/third_party/XoFTR/src/lightning/data_pretrain.py +125 -0
  417. vismatch/third_party/MINIMA/third_party/XoFTR/src/lightning/lightning_xoftr.py +334 -0
  418. vismatch/third_party/MINIMA/third_party/XoFTR/src/lightning/lightning_xoftr_pretrain.py +171 -0
  419. vismatch/third_party/MINIMA/third_party/XoFTR/src/losses/xoftr_loss.py +170 -0
  420. vismatch/third_party/MINIMA/third_party/XoFTR/src/losses/xoftr_loss_pretrain.py +37 -0
  421. vismatch/third_party/MINIMA/third_party/XoFTR/src/optimizers/__init__.py +42 -0
  422. vismatch/third_party/MINIMA/third_party/XoFTR/src/utils/augment.py +113 -0
  423. vismatch/third_party/MINIMA/third_party/XoFTR/src/utils/comm.py +265 -0
  424. vismatch/third_party/MINIMA/third_party/XoFTR/src/utils/data_io.py +144 -0
  425. vismatch/third_party/MINIMA/third_party/XoFTR/src/utils/dataloader.py +23 -0
  426. vismatch/third_party/MINIMA/third_party/XoFTR/src/utils/dataset.py +279 -0
  427. vismatch/third_party/MINIMA/third_party/XoFTR/src/utils/metrics.py +211 -0
  428. vismatch/third_party/MINIMA/third_party/XoFTR/src/utils/misc.py +101 -0
  429. vismatch/third_party/MINIMA/third_party/XoFTR/src/utils/plotting.py +227 -0
  430. vismatch/third_party/MINIMA/third_party/XoFTR/src/utils/pretrain_utils.py +83 -0
  431. vismatch/third_party/MINIMA/third_party/XoFTR/src/utils/profiler.py +39 -0
  432. vismatch/third_party/MINIMA/third_party/XoFTR/src/xoftr/__init__.py +2 -0
  433. vismatch/third_party/MINIMA/third_party/XoFTR/src/xoftr/backbone/__init__.py +1 -0
  434. vismatch/third_party/MINIMA/third_party/XoFTR/src/xoftr/backbone/resnet.py +95 -0
  435. vismatch/third_party/MINIMA/third_party/XoFTR/src/xoftr/utils/geometry.py +107 -0
  436. vismatch/third_party/MINIMA/third_party/XoFTR/src/xoftr/utils/position_encoding.py +36 -0
  437. vismatch/third_party/MINIMA/third_party/XoFTR/src/xoftr/utils/supervision.py +290 -0
  438. vismatch/third_party/MINIMA/third_party/XoFTR/src/xoftr/xoftr.py +94 -0
  439. vismatch/third_party/MINIMA/third_party/XoFTR/src/xoftr/xoftr_module/__init__.py +4 -0
  440. vismatch/third_party/MINIMA/third_party/XoFTR/src/xoftr/xoftr_module/coarse_matching.py +305 -0
  441. vismatch/third_party/MINIMA/third_party/XoFTR/src/xoftr/xoftr_module/fine_matching.py +170 -0
  442. vismatch/third_party/MINIMA/third_party/XoFTR/src/xoftr/xoftr_module/fine_process.py +321 -0
  443. vismatch/third_party/MINIMA/third_party/XoFTR/src/xoftr/xoftr_module/linear_attention.py +81 -0
  444. vismatch/third_party/MINIMA/third_party/XoFTR/src/xoftr/xoftr_module/transformer.py +101 -0
  445. vismatch/third_party/MINIMA/third_party/XoFTR/src/xoftr/xoftr_pretrain.py +209 -0
  446. vismatch/third_party/MINIMA/third_party/XoFTR/test.py +68 -0
  447. vismatch/third_party/MINIMA/third_party/XoFTR/test_relative_pose.py +330 -0
  448. vismatch/third_party/MINIMA/third_party/XoFTR/train.py +126 -0
  449. vismatch/third_party/MatchAnything/app.py +27 -0
  450. vismatch/third_party/MatchAnything/imcui/__init__.py +0 -0
  451. vismatch/third_party/MatchAnything/imcui/api/__init__.py +47 -0
  452. vismatch/third_party/MatchAnything/imcui/api/client.py +232 -0
  453. vismatch/third_party/MatchAnything/imcui/api/core.py +308 -0
  454. vismatch/third_party/MatchAnything/imcui/api/server.py +170 -0
  455. vismatch/third_party/MatchAnything/imcui/hloc/__init__.py +65 -0
  456. vismatch/third_party/MatchAnything/imcui/hloc/colmap_from_nvm.py +216 -0
  457. vismatch/third_party/MatchAnything/imcui/hloc/extract_features.py +607 -0
  458. vismatch/third_party/MatchAnything/imcui/hloc/extractors/__init__.py +0 -0
  459. vismatch/third_party/MatchAnything/imcui/hloc/extractors/alike.py +61 -0
  460. vismatch/third_party/MatchAnything/imcui/hloc/extractors/aliked.py +32 -0
  461. vismatch/third_party/MatchAnything/imcui/hloc/extractors/cosplace.py +44 -0
  462. vismatch/third_party/MatchAnything/imcui/hloc/extractors/d2net.py +60 -0
  463. vismatch/third_party/MatchAnything/imcui/hloc/extractors/darkfeat.py +44 -0
  464. vismatch/third_party/MatchAnything/imcui/hloc/extractors/dedode.py +86 -0
  465. vismatch/third_party/MatchAnything/imcui/hloc/extractors/dir.py +78 -0
  466. vismatch/third_party/MatchAnything/imcui/hloc/extractors/disk.py +35 -0
  467. vismatch/third_party/MatchAnything/imcui/hloc/extractors/dog.py +135 -0
  468. vismatch/third_party/MatchAnything/imcui/hloc/extractors/eigenplaces.py +57 -0
  469. vismatch/third_party/MatchAnything/imcui/hloc/extractors/example.py +56 -0
  470. vismatch/third_party/MatchAnything/imcui/hloc/extractors/fire.py +72 -0
  471. vismatch/third_party/MatchAnything/imcui/hloc/extractors/fire_local.py +84 -0
  472. vismatch/third_party/MatchAnything/imcui/hloc/extractors/lanet.py +63 -0
  473. vismatch/third_party/MatchAnything/imcui/hloc/extractors/netvlad.py +146 -0
  474. vismatch/third_party/MatchAnything/imcui/hloc/extractors/openibl.py +26 -0
  475. vismatch/third_party/MatchAnything/imcui/hloc/extractors/r2d2.py +73 -0
  476. vismatch/third_party/MatchAnything/imcui/hloc/extractors/rekd.py +60 -0
  477. vismatch/third_party/MatchAnything/imcui/hloc/extractors/rord.py +59 -0
  478. vismatch/third_party/MatchAnything/imcui/hloc/extractors/sfd2.py +44 -0
  479. vismatch/third_party/MatchAnything/imcui/hloc/extractors/sift.py +216 -0
  480. vismatch/third_party/MatchAnything/imcui/hloc/extractors/superpoint.py +51 -0
  481. vismatch/third_party/MatchAnything/imcui/hloc/extractors/xfeat.py +33 -0
  482. vismatch/third_party/MatchAnything/imcui/hloc/localize_inloc.py +179 -0
  483. vismatch/third_party/MatchAnything/imcui/hloc/localize_sfm.py +243 -0
  484. vismatch/third_party/MatchAnything/imcui/hloc/match_dense.py +1158 -0
  485. vismatch/third_party/MatchAnything/imcui/hloc/match_features.py +459 -0
  486. vismatch/third_party/MatchAnything/imcui/hloc/matchers/__init__.py +3 -0
  487. vismatch/third_party/MatchAnything/imcui/hloc/matchers/adalam.py +68 -0
  488. vismatch/third_party/MatchAnything/imcui/hloc/matchers/aspanformer.py +66 -0
  489. vismatch/third_party/MatchAnything/imcui/hloc/matchers/cotr.py +77 -0
  490. vismatch/third_party/MatchAnything/imcui/hloc/matchers/dkm.py +53 -0
  491. vismatch/third_party/MatchAnything/imcui/hloc/matchers/dual_softmax.py +71 -0
  492. vismatch/third_party/MatchAnything/imcui/hloc/matchers/duster.py +109 -0
  493. vismatch/third_party/MatchAnything/imcui/hloc/matchers/eloftr.py +97 -0
  494. vismatch/third_party/MatchAnything/imcui/hloc/matchers/gim.py +200 -0
  495. vismatch/third_party/MatchAnything/imcui/hloc/matchers/gluestick.py +99 -0
  496. vismatch/third_party/MatchAnything/imcui/hloc/matchers/imp.py +50 -0
  497. vismatch/third_party/MatchAnything/imcui/hloc/matchers/lightglue.py +67 -0
  498. vismatch/third_party/MatchAnything/imcui/hloc/matchers/loftr.py +58 -0
  499. vismatch/third_party/MatchAnything/imcui/hloc/matchers/mast3r.py +96 -0
  500. vismatch/third_party/MatchAnything/imcui/hloc/matchers/matchanything.py +191 -0
  501. vismatch/third_party/MatchAnything/imcui/hloc/matchers/mickey.py +50 -0
  502. vismatch/third_party/MatchAnything/imcui/hloc/matchers/nearest_neighbor.py +66 -0
  503. vismatch/third_party/MatchAnything/imcui/hloc/matchers/omniglue.py +80 -0
  504. vismatch/third_party/MatchAnything/imcui/hloc/matchers/roma.py +80 -0
  505. vismatch/third_party/MatchAnything/imcui/hloc/matchers/sgmnet.py +106 -0
  506. vismatch/third_party/MatchAnything/imcui/hloc/matchers/sold2.py +144 -0
  507. vismatch/third_party/MatchAnything/imcui/hloc/matchers/superglue.py +33 -0
  508. vismatch/third_party/MatchAnything/imcui/hloc/matchers/topicfm.py +60 -0
  509. vismatch/third_party/MatchAnything/imcui/hloc/matchers/xfeat_dense.py +54 -0
  510. vismatch/third_party/MatchAnything/imcui/hloc/matchers/xfeat_lightglue.py +48 -0
  511. vismatch/third_party/MatchAnything/imcui/hloc/matchers/xoftr.py +90 -0
  512. vismatch/third_party/MatchAnything/imcui/hloc/pairs_from_covisibility.py +60 -0
  513. vismatch/third_party/MatchAnything/imcui/hloc/pairs_from_exhaustive.py +64 -0
  514. vismatch/third_party/MatchAnything/imcui/hloc/pairs_from_poses.py +68 -0
  515. vismatch/third_party/MatchAnything/imcui/hloc/pairs_from_retrieval.py +133 -0
  516. vismatch/third_party/MatchAnything/imcui/hloc/pipelines/4Seasons/__init__.py +0 -0
  517. vismatch/third_party/MatchAnything/imcui/hloc/pipelines/4Seasons/localize.py +89 -0
  518. vismatch/third_party/MatchAnything/imcui/hloc/pipelines/4Seasons/prepare_reference.py +51 -0
  519. vismatch/third_party/MatchAnything/imcui/hloc/pipelines/4Seasons/utils.py +231 -0
  520. vismatch/third_party/MatchAnything/imcui/hloc/pipelines/7Scenes/__init__.py +0 -0
  521. vismatch/third_party/MatchAnything/imcui/hloc/pipelines/7Scenes/create_gt_sfm.py +134 -0
  522. vismatch/third_party/MatchAnything/imcui/hloc/pipelines/7Scenes/pipeline.py +139 -0
  523. vismatch/third_party/MatchAnything/imcui/hloc/pipelines/7Scenes/utils.py +34 -0
  524. vismatch/third_party/MatchAnything/imcui/hloc/pipelines/Aachen/__init__.py +0 -0
  525. vismatch/third_party/MatchAnything/imcui/hloc/pipelines/Aachen/pipeline.py +109 -0
  526. vismatch/third_party/MatchAnything/imcui/hloc/pipelines/Aachen_v1_1/__init__.py +0 -0
  527. vismatch/third_party/MatchAnything/imcui/hloc/pipelines/Aachen_v1_1/pipeline.py +104 -0
  528. vismatch/third_party/MatchAnything/imcui/hloc/pipelines/Aachen_v1_1/pipeline_loftr.py +104 -0
  529. vismatch/third_party/MatchAnything/imcui/hloc/pipelines/CMU/__init__.py +0 -0
  530. vismatch/third_party/MatchAnything/imcui/hloc/pipelines/CMU/pipeline.py +133 -0
  531. vismatch/third_party/MatchAnything/imcui/hloc/pipelines/Cambridge/__init__.py +0 -0
  532. vismatch/third_party/MatchAnything/imcui/hloc/pipelines/Cambridge/pipeline.py +140 -0
  533. vismatch/third_party/MatchAnything/imcui/hloc/pipelines/Cambridge/utils.py +145 -0
  534. vismatch/third_party/MatchAnything/imcui/hloc/pipelines/RobotCar/__init__.py +0 -0
  535. vismatch/third_party/MatchAnything/imcui/hloc/pipelines/RobotCar/colmap_from_nvm.py +176 -0
  536. vismatch/third_party/MatchAnything/imcui/hloc/pipelines/RobotCar/pipeline.py +143 -0
  537. vismatch/third_party/MatchAnything/imcui/hloc/pipelines/__init__.py +0 -0
  538. vismatch/third_party/MatchAnything/imcui/hloc/reconstruction.py +194 -0
  539. vismatch/third_party/MatchAnything/imcui/hloc/triangulation.py +311 -0
  540. vismatch/third_party/MatchAnything/imcui/hloc/utils/__init__.py +12 -0
  541. vismatch/third_party/MatchAnything/imcui/hloc/utils/base_model.py +56 -0
  542. vismatch/third_party/MatchAnything/imcui/hloc/utils/database.py +412 -0
  543. vismatch/third_party/MatchAnything/imcui/hloc/utils/geometry.py +16 -0
  544. vismatch/third_party/MatchAnything/imcui/hloc/utils/io.py +77 -0
  545. vismatch/third_party/MatchAnything/imcui/hloc/utils/parsers.py +59 -0
  546. vismatch/third_party/MatchAnything/imcui/hloc/utils/read_write_model.py +588 -0
  547. vismatch/third_party/MatchAnything/imcui/hloc/utils/viz.py +146 -0
  548. vismatch/third_party/MatchAnything/imcui/hloc/utils/viz_3d.py +203 -0
  549. vismatch/third_party/MatchAnything/imcui/hloc/visualization.py +178 -0
  550. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/configs/models/eloftr_model.py +128 -0
  551. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/configs/models/roma_model.py +27 -0
  552. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/notebooks/notebooks_utils/__init__.py +1 -0
  553. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/notebooks/notebooks_utils/plotting.py +344 -0
  554. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/__init__.py +0 -0
  555. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/config/default.py +344 -0
  556. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/datasets/common_data_pair.py +214 -0
  557. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/lightning/lightning_loftr.py +343 -0
  558. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/loftr/__init__.py +1 -0
  559. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/loftr/backbone/__init__.py +61 -0
  560. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/loftr/backbone/repvgg.py +319 -0
  561. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/loftr/backbone/resnet_fpn.py +1094 -0
  562. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/loftr/backbone/s2dnet.py +131 -0
  563. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/loftr/loftr.py +273 -0
  564. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/loftr/loftr_module/__init__.py +2 -0
  565. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/loftr/loftr_module/fine_preprocess.py +350 -0
  566. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/loftr/loftr_module/linear_attention.py +217 -0
  567. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/loftr/loftr_module/transformer.py +1768 -0
  568. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/loftr/loftr_module/transformer_utils.py +76 -0
  569. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/loftr/utils/coarse_matching.py +266 -0
  570. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/loftr/utils/fine_matching.py +493 -0
  571. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/loftr/utils/geometry.py +298 -0
  572. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/loftr/utils/position_encoding.py +131 -0
  573. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/loftr/utils/supervision.py +475 -0
  574. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/optimizers/__init__.py +50 -0
  575. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/utils/__init__.py +0 -0
  576. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/utils/augment.py +55 -0
  577. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/utils/colmap/__init__.py +0 -0
  578. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/utils/colmap/database.py +417 -0
  579. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/utils/colmap/eval_helper.py +232 -0
  580. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/utils/colmap/read_write_model.py +509 -0
  581. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/utils/colmap.py +530 -0
  582. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/utils/comm.py +265 -0
  583. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/utils/dataloader.py +23 -0
  584. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/utils/dataset.py +518 -0
  585. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/utils/easydict.py +148 -0
  586. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/utils/geometry.py +366 -0
  587. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/utils/homography_utils.py +366 -0
  588. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/utils/metrics.py +445 -0
  589. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/utils/misc.py +101 -0
  590. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/utils/plotting.py +248 -0
  591. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/utils/profiler.py +39 -0
  592. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/utils/ray_utils.py +134 -0
  593. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/utils/sample_homo.py +58 -0
  594. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/src/utils/utils.py +600 -0
  595. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/demo/demo_3D_effect.py +46 -0
  596. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/demo/demo_fundamental.py +32 -0
  597. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/demo/demo_fundamental_model_warpper.py +34 -0
  598. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/demo/demo_match.py +50 -0
  599. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/demo/demo_match_opencv_sift.py +43 -0
  600. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/demo_single_pair.py +329 -0
  601. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/experiments/roma_indoor.py +320 -0
  602. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/experiments/roma_outdoor.py +327 -0
  603. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/notebooks/notebooks_utils/__init__.py +1 -0
  604. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/notebooks/notebooks_utils/plotting.py +331 -0
  605. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/__init__.py +8 -0
  606. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/benchmarks/__init__.py +4 -0
  607. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/benchmarks/hpatches_sequences_homog_benchmark.py +113 -0
  608. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/benchmarks/megadepth_dense_benchmark.py +106 -0
  609. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/benchmarks/megadepth_pose_estimation_benchmark.py +140 -0
  610. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/benchmarks/scannet_benchmark.py +143 -0
  611. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/checkpointing/__init__.py +1 -0
  612. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/checkpointing/checkpoint.py +60 -0
  613. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/datasets/__init__.py +2 -0
  614. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/datasets/megadepth.py +230 -0
  615. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/datasets/scannet.py +160 -0
  616. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/losses/__init__.py +1 -0
  617. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/losses/robust_loss.py +157 -0
  618. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/matchanything_roma_model.py +104 -0
  619. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/__init__.py +1 -0
  620. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/croco/blocks.py +241 -0
  621. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/croco/criterion.py +37 -0
  622. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/croco/croco.py +253 -0
  623. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/croco/croco_downstream.py +122 -0
  624. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/croco/curope/__init__.py +4 -0
  625. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/croco/curope/curope2d.py +40 -0
  626. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/croco/curope/setup.py +34 -0
  627. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/croco/dpt_block.py +450 -0
  628. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/croco/head_downstream.py +58 -0
  629. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/croco/masking.py +25 -0
  630. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/croco/pos_embed.py +159 -0
  631. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/__init__.py +2 -0
  632. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/cloud_opt/__init__.py +29 -0
  633. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/cloud_opt/base_opt.py +375 -0
  634. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/cloud_opt/commons.py +90 -0
  635. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/cloud_opt/init_im_poses.py +312 -0
  636. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/cloud_opt/optimizer.py +230 -0
  637. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/cloud_opt/pair_viewer.py +125 -0
  638. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/datasets/__init__.py +42 -0
  639. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/datasets/base/__init__.py +2 -0
  640. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/datasets/base/base_stereo_view_dataset.py +220 -0
  641. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/datasets/base/batched_sampler.py +74 -0
  642. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/datasets/base/easy_dataset.py +157 -0
  643. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/datasets/co3d.py +146 -0
  644. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/datasets/utils/__init__.py +2 -0
  645. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/datasets/utils/cropping.py +119 -0
  646. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/datasets/utils/transforms.py +11 -0
  647. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/heads/__init__.py +19 -0
  648. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/heads/dpt_head.py +114 -0
  649. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/heads/linear_head.py +41 -0
  650. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/heads/postprocess.py +58 -0
  651. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/image_pairs.py +83 -0
  652. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/inference.py +165 -0
  653. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/losses.py +297 -0
  654. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/model.py +167 -0
  655. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/optim_factory.py +14 -0
  656. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/patch_embed.py +70 -0
  657. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/post_process.py +60 -0
  658. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/utils/__init__.py +2 -0
  659. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/utils/device.py +76 -0
  660. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/utils/geometry.py +361 -0
  661. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/utils/image.py +104 -0
  662. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/utils/misc.py +121 -0
  663. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/utils/path_to_croco.py +19 -0
  664. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/dust3r/viz.py +320 -0
  665. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/encoders.py +137 -0
  666. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/matcher.py +937 -0
  667. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/model_zoo/__init__.py +53 -0
  668. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/model_zoo/roma_models.py +162 -0
  669. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/transformer/__init__.py +47 -0
  670. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/transformer/dinov2.py +359 -0
  671. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/transformer/layers/__init__.py +12 -0
  672. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/transformer/layers/attention.py +81 -0
  673. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/transformer/layers/block.py +252 -0
  674. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/transformer/layers/dino_head.py +59 -0
  675. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/transformer/layers/drop_path.py +35 -0
  676. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/transformer/layers/layer_scale.py +28 -0
  677. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/transformer/layers/mlp.py +41 -0
  678. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/transformer/layers/patch_embed.py +89 -0
  679. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/models/transformer/layers/swiglu_ffn.py +63 -0
  680. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/roma_adpat_model.py +32 -0
  681. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/train/__init__.py +1 -0
  682. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/train/train.py +102 -0
  683. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/utils/__init__.py +18 -0
  684. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/utils/kde.py +8 -0
  685. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/utils/local_correlation.py +44 -0
  686. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/utils/transforms.py +118 -0
  687. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/roma/utils/utils.py +661 -0
  688. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA/setup.py +9 -0
  689. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/tools/__init__.py +0 -0
  690. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/tools/evaluate_datasets.py +239 -0
  691. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/tools/tools_utils/data_io.py +94 -0
  692. vismatch/third_party/MatchAnything/imcui/third_party/MatchAnything/tools/tools_utils/plot.py +77 -0
  693. vismatch/third_party/MatchAnything/imcui/ui/__init__.py +5 -0
  694. vismatch/third_party/MatchAnything/imcui/ui/app_class.py +824 -0
  695. vismatch/third_party/MatchAnything/imcui/ui/sfm.py +164 -0
  696. vismatch/third_party/MatchAnything/imcui/ui/utils.py +1085 -0
  697. vismatch/third_party/MatchAnything/imcui/ui/viz.py +511 -0
  698. vismatch/third_party/MatchAnything/tests/test_basic.py +111 -0
  699. vismatch/third_party/MatchFormer/config/data/__init__.py +0 -0
  700. vismatch/third_party/MatchFormer/config/data/base.py +35 -0
  701. vismatch/third_party/MatchFormer/config/data/megadepth_test_1500.py +11 -0
  702. vismatch/third_party/MatchFormer/config/data/scannet_test_1500.py +11 -0
  703. vismatch/third_party/MatchFormer/config/defaultmf.py +88 -0
  704. vismatch/third_party/MatchFormer/model/backbone/__init__.py +17 -0
  705. vismatch/third_party/MatchFormer/model/backbone/coarse_matching.py +228 -0
  706. vismatch/third_party/MatchFormer/model/backbone/fine_matching.py +74 -0
  707. vismatch/third_party/MatchFormer/model/backbone/fine_preprocess.py +59 -0
  708. vismatch/third_party/MatchFormer/model/backbone/match_LA_large.py +254 -0
  709. vismatch/third_party/MatchFormer/model/backbone/match_LA_lite.py +254 -0
  710. vismatch/third_party/MatchFormer/model/backbone/match_SEA_large.py +291 -0
  711. vismatch/third_party/MatchFormer/model/backbone/match_SEA_lite.py +291 -0
  712. vismatch/third_party/MatchFormer/model/data.py +320 -0
  713. vismatch/third_party/MatchFormer/model/datasets/dataset.py +231 -0
  714. vismatch/third_party/MatchFormer/model/datasets/megadepth.py +126 -0
  715. vismatch/third_party/MatchFormer/model/datasets/sampler.py +77 -0
  716. vismatch/third_party/MatchFormer/model/datasets/scannet.py +113 -0
  717. vismatch/third_party/MatchFormer/model/lightning_loftr.py +102 -0
  718. vismatch/third_party/MatchFormer/model/matchformer.py +54 -0
  719. vismatch/third_party/MatchFormer/model/utils/augment.py +55 -0
  720. vismatch/third_party/MatchFormer/model/utils/comm.py +265 -0
  721. vismatch/third_party/MatchFormer/model/utils/dataloader.py +23 -0
  722. vismatch/third_party/MatchFormer/model/utils/metrics.py +193 -0
  723. vismatch/third_party/MatchFormer/model/utils/misc.py +101 -0
  724. vismatch/third_party/MatchFormer/model/utils/profiler.py +39 -0
  725. vismatch/third_party/MatchFormer/test.py +55 -0
  726. vismatch/third_party/RIPE/app.py +272 -0
  727. vismatch/third_party/RIPE/demo.py +51 -0
  728. vismatch/third_party/RIPE/ripe/__init__.py +1 -0
  729. vismatch/third_party/RIPE/ripe/benchmarks/imw_2020.py +320 -0
  730. vismatch/third_party/RIPE/ripe/data/__init__.py +0 -0
  731. vismatch/third_party/RIPE/ripe/data/data_transforms.py +204 -0
  732. vismatch/third_party/RIPE/ripe/data/datasets/__init__.py +0 -0
  733. vismatch/third_party/RIPE/ripe/data/datasets/acdc.py +154 -0
  734. vismatch/third_party/RIPE/ripe/data/datasets/dataset_combinator.py +88 -0
  735. vismatch/third_party/RIPE/ripe/data/datasets/disk_imw.py +160 -0
  736. vismatch/third_party/RIPE/ripe/data/datasets/disk_megadepth.py +157 -0
  737. vismatch/third_party/RIPE/ripe/data/datasets/tokyo247.py +134 -0
  738. vismatch/third_party/RIPE/ripe/data/datasets/tokyo_query_v3.py +78 -0
  739. vismatch/third_party/RIPE/ripe/losses/__init__.py +0 -0
  740. vismatch/third_party/RIPE/ripe/losses/contrastive_loss.py +88 -0
  741. vismatch/third_party/RIPE/ripe/matcher/__init__.py +0 -0
  742. vismatch/third_party/RIPE/ripe/matcher/concurrent_matcher.py +97 -0
  743. vismatch/third_party/RIPE/ripe/matcher/pose_estimator_poselib.py +31 -0
  744. vismatch/third_party/RIPE/ripe/model_zoo/__init__.py +1 -0
  745. vismatch/third_party/RIPE/ripe/model_zoo/vgg_hyper.py +39 -0
  746. vismatch/third_party/RIPE/ripe/models/__init__.py +0 -0
  747. vismatch/third_party/RIPE/ripe/models/backbones/__init__.py +0 -0
  748. vismatch/third_party/RIPE/ripe/models/backbones/backbone_base.py +61 -0
  749. vismatch/third_party/RIPE/ripe/models/backbones/vgg.py +99 -0
  750. vismatch/third_party/RIPE/ripe/models/backbones/vgg_utils.py +143 -0
  751. vismatch/third_party/RIPE/ripe/models/ripe.py +303 -0
  752. vismatch/third_party/RIPE/ripe/models/upsampler/hypercolumn_features.py +54 -0
  753. vismatch/third_party/RIPE/ripe/models/upsampler/interpolate_sparse2d.py +37 -0
  754. vismatch/third_party/RIPE/ripe/scheduler/__init__.py +0 -0
  755. vismatch/third_party/RIPE/ripe/scheduler/constant.py +6 -0
  756. vismatch/third_party/RIPE/ripe/scheduler/expDecay.py +26 -0
  757. vismatch/third_party/RIPE/ripe/scheduler/linearLR.py +37 -0
  758. vismatch/third_party/RIPE/ripe/scheduler/linear_with_plateaus.py +44 -0
  759. vismatch/third_party/RIPE/ripe/train.py +410 -0
  760. vismatch/third_party/RIPE/ripe/utils/__init__.py +2 -0
  761. vismatch/third_party/RIPE/ripe/utils/image_utils.py +62 -0
  762. vismatch/third_party/RIPE/ripe/utils/pose_error.py +62 -0
  763. vismatch/third_party/RIPE/ripe/utils/pylogger.py +32 -0
  764. vismatch/third_party/RIPE/ripe/utils/utils.py +192 -0
  765. vismatch/third_party/RIPE/ripe/utils/wandb_utils.py +16 -0
  766. vismatch/third_party/RoMa/demo/demo_3D_effect.py +47 -0
  767. vismatch/third_party/RoMa/demo/demo_fundamental.py +34 -0
  768. vismatch/third_party/RoMa/demo/demo_match.py +50 -0
  769. vismatch/third_party/RoMa/demo/demo_match_opencv_sift.py +43 -0
  770. vismatch/third_party/RoMa/demo/demo_match_tiny.py +77 -0
  771. vismatch/third_party/RoMa/experiments/eval_roma_outdoor.py +57 -0
  772. vismatch/third_party/RoMa/experiments/eval_tiny_roma_v1_outdoor.py +84 -0
  773. vismatch/third_party/RoMa/experiments/roma_indoor.py +320 -0
  774. vismatch/third_party/RoMa/experiments/train_roma_outdoor.py +307 -0
  775. vismatch/third_party/RoMa/experiments/train_tiny_roma_v1_outdoor.py +498 -0
  776. vismatch/third_party/RoMa/romatch/__init__.py +8 -0
  777. vismatch/third_party/RoMa/romatch/benchmarks/__init__.py +6 -0
  778. vismatch/third_party/RoMa/romatch/benchmarks/hpatches_sequences_homog_benchmark.py +113 -0
  779. vismatch/third_party/RoMa/romatch/benchmarks/megadepth_dense_benchmark.py +106 -0
  780. vismatch/third_party/RoMa/romatch/benchmarks/megadepth_pose_estimation_benchmark.py +118 -0
  781. vismatch/third_party/RoMa/romatch/benchmarks/megadepth_pose_estimation_benchmark_poselib.py +119 -0
  782. vismatch/third_party/RoMa/romatch/benchmarks/scannet_benchmark.py +143 -0
  783. vismatch/third_party/RoMa/romatch/checkpointing/__init__.py +1 -0
  784. vismatch/third_party/RoMa/romatch/checkpointing/checkpoint.py +60 -0
  785. vismatch/third_party/RoMa/romatch/datasets/__init__.py +2 -0
  786. vismatch/third_party/RoMa/romatch/datasets/megadepth.py +232 -0
  787. vismatch/third_party/RoMa/romatch/datasets/scannet.py +160 -0
  788. vismatch/third_party/RoMa/romatch/losses/__init__.py +1 -0
  789. vismatch/third_party/RoMa/romatch/losses/robust_loss.py +161 -0
  790. vismatch/third_party/RoMa/romatch/losses/robust_loss_tiny_roma.py +160 -0
  791. vismatch/third_party/RoMa/romatch/models/__init__.py +1 -0
  792. vismatch/third_party/RoMa/romatch/models/encoders.py +122 -0
  793. vismatch/third_party/RoMa/romatch/models/matcher.py +748 -0
  794. vismatch/third_party/RoMa/romatch/models/model_zoo/__init__.py +73 -0
  795. vismatch/third_party/RoMa/romatch/models/model_zoo/roma_models.py +170 -0
  796. vismatch/third_party/RoMa/romatch/models/tiny.py +304 -0
  797. vismatch/third_party/RoMa/romatch/models/transformer/__init__.py +48 -0
  798. vismatch/third_party/RoMa/romatch/models/transformer/dinov2.py +359 -0
  799. vismatch/third_party/RoMa/romatch/models/transformer/layers/__init__.py +12 -0
  800. vismatch/third_party/RoMa/romatch/models/transformer/layers/attention.py +81 -0
  801. vismatch/third_party/RoMa/romatch/models/transformer/layers/block.py +252 -0
  802. vismatch/third_party/RoMa/romatch/models/transformer/layers/dino_head.py +59 -0
  803. vismatch/third_party/RoMa/romatch/models/transformer/layers/drop_path.py +35 -0
  804. vismatch/third_party/RoMa/romatch/models/transformer/layers/layer_scale.py +28 -0
  805. vismatch/third_party/RoMa/romatch/models/transformer/layers/mlp.py +41 -0
  806. vismatch/third_party/RoMa/romatch/models/transformer/layers/patch_embed.py +89 -0
  807. vismatch/third_party/RoMa/romatch/models/transformer/layers/swiglu_ffn.py +63 -0
  808. vismatch/third_party/RoMa/romatch/train/__init__.py +1 -0
  809. vismatch/third_party/RoMa/romatch/train/train.py +102 -0
  810. vismatch/third_party/RoMa/romatch/utils/__init__.py +16 -0
  811. vismatch/third_party/RoMa/romatch/utils/kde.py +13 -0
  812. vismatch/third_party/RoMa/romatch/utils/local_correlation.py +48 -0
  813. vismatch/third_party/RoMa/romatch/utils/transforms.py +118 -0
  814. vismatch/third_party/RoMa/romatch/utils/utils.py +654 -0
  815. vismatch/third_party/RoMa/setup.py +9 -0
  816. vismatch/third_party/RoMaV2/demo/demo_covariance.py +52 -0
  817. vismatch/third_party/RoMaV2/demo/demo_match.py +55 -0
  818. vismatch/third_party/RoMaV2/src/romav2/__init__.py +8 -0
  819. vismatch/third_party/RoMaV2/src/romav2/benchmarks/__init__.py +4 -0
  820. vismatch/third_party/RoMaV2/src/romav2/benchmarks/mega1500.py +115 -0
  821. vismatch/third_party/RoMaV2/src/romav2/benchmarks/satast.py +463 -0
  822. vismatch/third_party/RoMaV2/src/romav2/benchmarks/scannet1500.py +125 -0
  823. vismatch/third_party/RoMaV2/src/romav2/benchmarks/wxbs.py +104 -0
  824. vismatch/third_party/RoMaV2/src/romav2/device.py +9 -0
  825. vismatch/third_party/RoMaV2/src/romav2/dpt.py +516 -0
  826. vismatch/third_party/RoMaV2/src/romav2/features.py +190 -0
  827. vismatch/third_party/RoMaV2/src/romav2/geometry.py +261 -0
  828. vismatch/third_party/RoMaV2/src/romav2/io.py +24 -0
  829. vismatch/third_party/RoMaV2/src/romav2/local_correlation.py +152 -0
  830. vismatch/third_party/RoMaV2/src/romav2/logging.py +97 -0
  831. vismatch/third_party/RoMaV2/src/romav2/matcher.py +207 -0
  832. vismatch/third_party/RoMaV2/src/romav2/normalizers.py +17 -0
  833. vismatch/third_party/RoMaV2/src/romav2/refiner.py +277 -0
  834. vismatch/third_party/RoMaV2/src/romav2/romav2.py +533 -0
  835. vismatch/third_party/RoMaV2/src/romav2/types.py +75 -0
  836. vismatch/third_party/RoMaV2/src/romav2/vis.py +36 -0
  837. vismatch/third_party/RoMaV2/src/romav2/vit/__init__.py +304 -0
  838. vismatch/third_party/RoMaV2/src/romav2/vit/attention.py +181 -0
  839. vismatch/third_party/RoMaV2/src/romav2/vit/block.py +293 -0
  840. vismatch/third_party/RoMaV2/src/romav2/vit/ffn_layers.py +83 -0
  841. vismatch/third_party/RoMaV2/src/romav2/vit/layer_scale.py +29 -0
  842. vismatch/third_party/RoMaV2/src/romav2/vit/patch_embed.py +94 -0
  843. vismatch/third_party/RoMaV2/src/romav2/vit/rms_norm.py +24 -0
  844. vismatch/third_party/RoMaV2/src/romav2/vit/rope.py +133 -0
  845. vismatch/third_party/RoMaV2/src/romav2/vit/rope_mixed.py +111 -0
  846. vismatch/third_party/RoMaV2/src/romav2/vit/utils.py +48 -0
  847. vismatch/third_party/RoMaV2/tests/test_bidirectional.py +93 -0
  848. vismatch/third_party/RoMaV2/tests/test_fps.py +49 -0
  849. vismatch/third_party/RoMaV2/tests/test_mega1500.py +22 -0
  850. vismatch/third_party/RoMaV2/tests/test_scannet1500.py +21 -0
  851. vismatch/third_party/RoMaV2/tests/test_smoke.py +15 -0
  852. vismatch/third_party/Se2_LoFTR/configs/data/__init__.py +0 -0
  853. vismatch/third_party/Se2_LoFTR/configs/data/base.py +35 -0
  854. vismatch/third_party/Se2_LoFTR/configs/data/megadepth_test_1500.py +11 -0
  855. vismatch/third_party/Se2_LoFTR/configs/data/megadepth_trainval_640.py +22 -0
  856. vismatch/third_party/Se2_LoFTR/configs/data/megadepth_trainval_840.py +22 -0
  857. vismatch/third_party/Se2_LoFTR/configs/data/scannet_test_1500.py +11 -0
  858. vismatch/third_party/Se2_LoFTR/configs/data/scannet_trainval.py +17 -0
  859. vismatch/third_party/Se2_LoFTR/configs/loftr/indoor/buggy_pos_enc/loftr_ds.py +6 -0
  860. vismatch/third_party/Se2_LoFTR/configs/loftr/indoor/buggy_pos_enc/loftr_ds_dense.py +8 -0
  861. vismatch/third_party/Se2_LoFTR/configs/loftr/indoor/buggy_pos_enc/loftr_ot.py +6 -0
  862. vismatch/third_party/Se2_LoFTR/configs/loftr/indoor/buggy_pos_enc/loftr_ot_dense.py +8 -0
  863. vismatch/third_party/Se2_LoFTR/configs/loftr/indoor/loftr_ds.py +5 -0
  864. vismatch/third_party/Se2_LoFTR/configs/loftr/indoor/loftr_ds_dense.py +7 -0
  865. vismatch/third_party/Se2_LoFTR/configs/loftr/indoor/loftr_ot.py +5 -0
  866. vismatch/third_party/Se2_LoFTR/configs/loftr/indoor/loftr_ot_dense.py +7 -0
  867. vismatch/third_party/Se2_LoFTR/configs/loftr/indoor/scannet/loftr_ds_eval.py +16 -0
  868. vismatch/third_party/Se2_LoFTR/configs/loftr/indoor/scannet/loftr_ds_eval_new.py +18 -0
  869. vismatch/third_party/Se2_LoFTR/configs/loftr/outdoor/buggy_pos_enc/loftr_ds.py +16 -0
  870. vismatch/third_party/Se2_LoFTR/configs/loftr/outdoor/buggy_pos_enc/loftr_ds_dense.py +17 -0
  871. vismatch/third_party/Se2_LoFTR/configs/loftr/outdoor/buggy_pos_enc/loftr_ot.py +16 -0
  872. vismatch/third_party/Se2_LoFTR/configs/loftr/outdoor/buggy_pos_enc/loftr_ot_dense.py +17 -0
  873. vismatch/third_party/Se2_LoFTR/configs/loftr/outdoor/loftr_ds.py +17 -0
  874. vismatch/third_party/Se2_LoFTR/configs/loftr/outdoor/loftr_ds_dense.py +17 -0
  875. vismatch/third_party/Se2_LoFTR/configs/loftr/outdoor/loftr_ds_e2.py +20 -0
  876. vismatch/third_party/Se2_LoFTR/configs/loftr/outdoor/loftr_ds_e2_dense.py +23 -0
  877. vismatch/third_party/Se2_LoFTR/configs/loftr/outdoor/loftr_ds_e2_dense_8rot.py +23 -0
  878. vismatch/third_party/Se2_LoFTR/configs/loftr/outdoor/loftr_ds_e2_dense_big.py +22 -0
  879. vismatch/third_party/Se2_LoFTR/configs/loftr/outdoor/loftr_ot.py +17 -0
  880. vismatch/third_party/Se2_LoFTR/configs/loftr/outdoor/loftr_ot_dense.py +18 -0
  881. vismatch/third_party/Se2_LoFTR/demo/demo_loftr.py +240 -0
  882. vismatch/third_party/Se2_LoFTR/src/__init__.py +0 -0
  883. vismatch/third_party/Se2_LoFTR/src/config/default.py +173 -0
  884. vismatch/third_party/Se2_LoFTR/src/datasets/megadepth.py +127 -0
  885. vismatch/third_party/Se2_LoFTR/src/datasets/sampler.py +77 -0
  886. vismatch/third_party/Se2_LoFTR/src/datasets/scannet.py +114 -0
  887. vismatch/third_party/Se2_LoFTR/src/lightning/data.py +320 -0
  888. vismatch/third_party/Se2_LoFTR/src/lightning/lightning_loftr.py +249 -0
  889. vismatch/third_party/Se2_LoFTR/src/loftr/__init__.py +2 -0
  890. vismatch/third_party/Se2_LoFTR/src/loftr/backbone/__init__.py +17 -0
  891. vismatch/third_party/Se2_LoFTR/src/loftr/backbone/resnet_e2.py +170 -0
  892. vismatch/third_party/Se2_LoFTR/src/loftr/backbone/resnet_fpn.py +199 -0
  893. vismatch/third_party/Se2_LoFTR/src/loftr/loftr.py +81 -0
  894. vismatch/third_party/Se2_LoFTR/src/loftr/loftr_module/__init__.py +2 -0
  895. vismatch/third_party/Se2_LoFTR/src/loftr/loftr_module/fine_preprocess.py +59 -0
  896. vismatch/third_party/Se2_LoFTR/src/loftr/loftr_module/linear_attention.py +81 -0
  897. vismatch/third_party/Se2_LoFTR/src/loftr/loftr_module/transformer.py +101 -0
  898. vismatch/third_party/Se2_LoFTR/src/loftr/utils/coarse_matching.py +261 -0
  899. vismatch/third_party/Se2_LoFTR/src/loftr/utils/cvpr_ds_config.py +50 -0
  900. vismatch/third_party/Se2_LoFTR/src/loftr/utils/fine_matching.py +74 -0
  901. vismatch/third_party/Se2_LoFTR/src/loftr/utils/geometry.py +54 -0
  902. vismatch/third_party/Se2_LoFTR/src/loftr/utils/position_encoding.py +42 -0
  903. vismatch/third_party/Se2_LoFTR/src/loftr/utils/supervision.py +151 -0
  904. vismatch/third_party/Se2_LoFTR/src/losses/loftr_loss.py +192 -0
  905. vismatch/third_party/Se2_LoFTR/src/optimizers/__init__.py +42 -0
  906. vismatch/third_party/Se2_LoFTR/src/utils/augment.py +55 -0
  907. vismatch/third_party/Se2_LoFTR/src/utils/comm.py +265 -0
  908. vismatch/third_party/Se2_LoFTR/src/utils/dataloader.py +23 -0
  909. vismatch/third_party/Se2_LoFTR/src/utils/dataset.py +185 -0
  910. vismatch/third_party/Se2_LoFTR/src/utils/metrics.py +193 -0
  911. vismatch/third_party/Se2_LoFTR/src/utils/misc.py +104 -0
  912. vismatch/third_party/Se2_LoFTR/src/utils/plotting.py +154 -0
  913. vismatch/third_party/Se2_LoFTR/src/utils/profiler.py +39 -0
  914. vismatch/third_party/Se2_LoFTR/test.py +68 -0
  915. vismatch/third_party/Se2_LoFTR/train.py +123 -0
  916. vismatch/third_party/SphereGlue/demo_SphereGlue.py +141 -0
  917. vismatch/third_party/SphereGlue/model/sphereglue.py +230 -0
  918. vismatch/third_party/SphereGlue/utils/Utils.py +191 -0
  919. vismatch/third_party/SphereGlue/utils/demo_mydataset.py +119 -0
  920. vismatch/third_party/Steerers/rotation_steerers/matchers/dual_softmax_matcher.py +44 -0
  921. vismatch/third_party/Steerers/rotation_steerers/matchers/max_matches.py +205 -0
  922. vismatch/third_party/Steerers/rotation_steerers/matchers/max_similarity.py +115 -0
  923. vismatch/third_party/Steerers/rotation_steerers/steerers.py +37 -0
  924. vismatch/third_party/Steerers/setup.py +14 -0
  925. vismatch/third_party/TopicFM/configs/megadepth_test.py +17 -0
  926. vismatch/third_party/TopicFM/configs/megadepth_test_topicfmfast.py +17 -0
  927. vismatch/third_party/TopicFM/configs/megadepth_test_topicfmplus.py +20 -0
  928. vismatch/third_party/TopicFM/configs/megadepth_train.py +36 -0
  929. vismatch/third_party/TopicFM/configs/megadepth_train_topicfmfast.py +34 -0
  930. vismatch/third_party/TopicFM/configs/megadepth_train_topicfmplus.py +37 -0
  931. vismatch/third_party/TopicFM/configs/scannet_test.py +15 -0
  932. vismatch/third_party/TopicFM/configs/scannet_test_topicfmfast.py +15 -0
  933. vismatch/third_party/TopicFM/configs/scannet_test_topicfmplus.py +19 -0
  934. vismatch/third_party/TopicFM/src/__init__.py +11 -0
  935. vismatch/third_party/TopicFM/src/config/default.py +174 -0
  936. vismatch/third_party/TopicFM/src/datasets/aachen.py +29 -0
  937. vismatch/third_party/TopicFM/src/datasets/custom_dataloader.py +126 -0
  938. vismatch/third_party/TopicFM/src/datasets/inloc.py +29 -0
  939. vismatch/third_party/TopicFM/src/datasets/megadepth.py +170 -0
  940. vismatch/third_party/TopicFM/src/datasets/sampler.py +77 -0
  941. vismatch/third_party/TopicFM/src/datasets/scannet.py +115 -0
  942. vismatch/third_party/TopicFM/src/lightning_trainer/data.py +292 -0
  943. vismatch/third_party/TopicFM/src/lightning_trainer/trainer.py +244 -0
  944. vismatch/third_party/TopicFM/src/losses/loss.py +228 -0
  945. vismatch/third_party/TopicFM/src/models/__init__.py +1 -0
  946. vismatch/third_party/TopicFM/src/models/backbone/__init__.py +12 -0
  947. vismatch/third_party/TopicFM/src/models/backbone/convnext.py +165 -0
  948. vismatch/third_party/TopicFM/src/models/backbone/fpn.py +114 -0
  949. vismatch/third_party/TopicFM/src/models/modules/__init__.py +2 -0
  950. vismatch/third_party/TopicFM/src/models/modules/encoder.py +266 -0
  951. vismatch/third_party/TopicFM/src/models/modules/fine_preprocess.py +59 -0
  952. vismatch/third_party/TopicFM/src/models/modules/linear_attention.py +84 -0
  953. vismatch/third_party/TopicFM/src/models/topic_fm.py +100 -0
  954. vismatch/third_party/TopicFM/src/models/utils/coarse_matching.py +213 -0
  955. vismatch/third_party/TopicFM/src/models/utils/fine_matching.py +172 -0
  956. vismatch/third_party/TopicFM/src/models/utils/geometry.py +54 -0
  957. vismatch/third_party/TopicFM/src/models/utils/supervision.py +167 -0
  958. vismatch/third_party/TopicFM/src/optimizers/__init__.py +42 -0
  959. vismatch/third_party/TopicFM/src/utils/augment.py +55 -0
  960. vismatch/third_party/TopicFM/src/utils/comm.py +265 -0
  961. vismatch/third_party/TopicFM/src/utils/dataloader.py +23 -0
  962. vismatch/third_party/TopicFM/src/utils/dataset.py +206 -0
  963. vismatch/third_party/TopicFM/src/utils/metrics.py +193 -0
  964. vismatch/third_party/TopicFM/src/utils/misc.py +101 -0
  965. vismatch/third_party/TopicFM/src/utils/plotting.py +313 -0
  966. vismatch/third_party/TopicFM/src/utils/profiler.py +39 -0
  967. vismatch/third_party/TopicFM/test.py +70 -0
  968. vismatch/third_party/TopicFM/third_party/__init__.py +0 -0
  969. vismatch/third_party/TopicFM/third_party/aspanformer/configs/aspan/indoor/aspan_test.py +7 -0
  970. vismatch/third_party/TopicFM/third_party/aspanformer/configs/aspan/indoor/aspan_train.py +8 -0
  971. vismatch/third_party/TopicFM/third_party/aspanformer/configs/aspan/outdoor/aspan_test.py +18 -0
  972. vismatch/third_party/TopicFM/third_party/aspanformer/configs/aspan/outdoor/aspan_train.py +17 -0
  973. vismatch/third_party/TopicFM/third_party/aspanformer/configs/data/__init__.py +0 -0
  974. vismatch/third_party/TopicFM/third_party/aspanformer/configs/data/base.py +35 -0
  975. vismatch/third_party/TopicFM/third_party/aspanformer/configs/data/megadepth_test_1500.py +13 -0
  976. vismatch/third_party/TopicFM/third_party/aspanformer/configs/data/megadepth_trainval_832.py +22 -0
  977. vismatch/third_party/TopicFM/third_party/aspanformer/configs/data/scannet_test_1500.py +11 -0
  978. vismatch/third_party/TopicFM/third_party/aspanformer/configs/data/scannet_trainval.py +17 -0
  979. vismatch/third_party/TopicFM/third_party/aspanformer/demo/demo.py +63 -0
  980. vismatch/third_party/TopicFM/third_party/aspanformer/demo/demo_utils.py +44 -0
  981. vismatch/third_party/TopicFM/third_party/aspanformer/src/ASpanFormer/__init__.py +2 -0
  982. vismatch/third_party/TopicFM/third_party/aspanformer/src/ASpanFormer/aspan_module/__init__.py +3 -0
  983. vismatch/third_party/TopicFM/third_party/aspanformer/src/ASpanFormer/aspan_module/attention.py +198 -0
  984. vismatch/third_party/TopicFM/third_party/aspanformer/src/ASpanFormer/aspan_module/fine_preprocess.py +59 -0
  985. vismatch/third_party/TopicFM/third_party/aspanformer/src/ASpanFormer/aspan_module/loftr.py +112 -0
  986. vismatch/third_party/TopicFM/third_party/aspanformer/src/ASpanFormer/aspan_module/transformer.py +244 -0
  987. vismatch/third_party/TopicFM/third_party/aspanformer/src/ASpanFormer/aspanformer.py +133 -0
  988. vismatch/third_party/TopicFM/third_party/aspanformer/src/ASpanFormer/backbone/__init__.py +11 -0
  989. vismatch/third_party/TopicFM/third_party/aspanformer/src/ASpanFormer/backbone/resnet_fpn.py +199 -0
  990. vismatch/third_party/TopicFM/third_party/aspanformer/src/ASpanFormer/utils/coarse_matching.py +331 -0
  991. vismatch/third_party/TopicFM/third_party/aspanformer/src/ASpanFormer/utils/cvpr_ds_config.py +50 -0
  992. vismatch/third_party/TopicFM/third_party/aspanformer/src/ASpanFormer/utils/fine_matching.py +74 -0
  993. vismatch/third_party/TopicFM/third_party/aspanformer/src/ASpanFormer/utils/geometry.py +54 -0
  994. vismatch/third_party/TopicFM/third_party/aspanformer/src/ASpanFormer/utils/position_encoding.py +61 -0
  995. vismatch/third_party/TopicFM/third_party/aspanformer/src/ASpanFormer/utils/supervision.py +151 -0
  996. vismatch/third_party/TopicFM/third_party/aspanformer/src/__init__.py +0 -0
  997. vismatch/third_party/TopicFM/third_party/aspanformer/src/config/default.py +180 -0
  998. vismatch/third_party/TopicFM/third_party/aspanformer/src/datasets/__init__.py +3 -0
  999. vismatch/third_party/TopicFM/third_party/aspanformer/src/datasets/megadepth.py +127 -0
  1000. vismatch/third_party/TopicFM/third_party/aspanformer/src/datasets/sampler.py +77 -0
  1001. vismatch/third_party/TopicFM/third_party/aspanformer/src/datasets/scannet.py +113 -0
  1002. vismatch/third_party/TopicFM/third_party/aspanformer/src/lightning/data.py +326 -0
  1003. vismatch/third_party/TopicFM/third_party/aspanformer/src/lightning/lightning_aspanformer.py +276 -0
  1004. vismatch/third_party/TopicFM/third_party/aspanformer/src/losses/aspan_loss.py +231 -0
  1005. vismatch/third_party/TopicFM/third_party/aspanformer/src/optimizers/__init__.py +42 -0
  1006. vismatch/third_party/TopicFM/third_party/aspanformer/src/utils/augment.py +55 -0
  1007. vismatch/third_party/TopicFM/third_party/aspanformer/src/utils/comm.py +265 -0
  1008. vismatch/third_party/TopicFM/third_party/aspanformer/src/utils/dataloader.py +23 -0
  1009. vismatch/third_party/TopicFM/third_party/aspanformer/src/utils/dataset.py +222 -0
  1010. vismatch/third_party/TopicFM/third_party/aspanformer/src/utils/metrics.py +260 -0
  1011. vismatch/third_party/TopicFM/third_party/aspanformer/src/utils/misc.py +139 -0
  1012. vismatch/third_party/TopicFM/third_party/aspanformer/src/utils/plotting.py +219 -0
  1013. vismatch/third_party/TopicFM/third_party/aspanformer/src/utils/profiler.py +39 -0
  1014. vismatch/third_party/TopicFM/third_party/aspanformer/test.py +69 -0
  1015. vismatch/third_party/TopicFM/third_party/aspanformer/tools/SensorData.py +125 -0
  1016. vismatch/third_party/TopicFM/third_party/aspanformer/tools/extract.py +47 -0
  1017. vismatch/third_party/TopicFM/third_party/aspanformer/tools/preprocess_scene.py +242 -0
  1018. vismatch/third_party/TopicFM/third_party/aspanformer/tools/reader.py +39 -0
  1019. vismatch/third_party/TopicFM/third_party/aspanformer/tools/undistort_mega.py +69 -0
  1020. vismatch/third_party/TopicFM/third_party/aspanformer/train.py +134 -0
  1021. vismatch/third_party/TopicFM/third_party/loftr/configs/data/__init__.py +0 -0
  1022. vismatch/third_party/TopicFM/third_party/loftr/configs/data/base.py +35 -0
  1023. vismatch/third_party/TopicFM/third_party/loftr/configs/data/megadepth_test_1500.py +11 -0
  1024. vismatch/third_party/TopicFM/third_party/loftr/configs/data/megadepth_trainval_640.py +22 -0
  1025. vismatch/third_party/TopicFM/third_party/loftr/configs/data/megadepth_trainval_840.py +22 -0
  1026. vismatch/third_party/TopicFM/third_party/loftr/configs/data/scannet_test_1500.py +11 -0
  1027. vismatch/third_party/TopicFM/third_party/loftr/configs/data/scannet_trainval.py +17 -0
  1028. vismatch/third_party/TopicFM/third_party/loftr/configs/loftr/indoor/buggy_pos_enc/loftr_ds.py +6 -0
  1029. vismatch/third_party/TopicFM/third_party/loftr/configs/loftr/indoor/buggy_pos_enc/loftr_ds_dense.py +8 -0
  1030. vismatch/third_party/TopicFM/third_party/loftr/configs/loftr/indoor/buggy_pos_enc/loftr_ot.py +6 -0
  1031. vismatch/third_party/TopicFM/third_party/loftr/configs/loftr/indoor/buggy_pos_enc/loftr_ot_dense.py +8 -0
  1032. vismatch/third_party/TopicFM/third_party/loftr/configs/loftr/indoor/loftr_ds.py +5 -0
  1033. vismatch/third_party/TopicFM/third_party/loftr/configs/loftr/indoor/loftr_ds_dense.py +7 -0
  1034. vismatch/third_party/TopicFM/third_party/loftr/configs/loftr/indoor/loftr_ot.py +5 -0
  1035. vismatch/third_party/TopicFM/third_party/loftr/configs/loftr/indoor/loftr_ot_dense.py +7 -0
  1036. vismatch/third_party/TopicFM/third_party/loftr/configs/loftr/indoor/scannet/loftr_ds_eval.py +16 -0
  1037. vismatch/third_party/TopicFM/third_party/loftr/configs/loftr/indoor/scannet/loftr_ds_eval_new.py +18 -0
  1038. vismatch/third_party/TopicFM/third_party/loftr/configs/loftr/outdoor/buggy_pos_enc/loftr_ds.py +16 -0
  1039. vismatch/third_party/TopicFM/third_party/loftr/configs/loftr/outdoor/buggy_pos_enc/loftr_ds_dense.py +17 -0
  1040. vismatch/third_party/TopicFM/third_party/loftr/configs/loftr/outdoor/buggy_pos_enc/loftr_ot.py +16 -0
  1041. vismatch/third_party/TopicFM/third_party/loftr/configs/loftr/outdoor/buggy_pos_enc/loftr_ot_dense.py +17 -0
  1042. vismatch/third_party/TopicFM/third_party/loftr/configs/loftr/outdoor/loftr_ds.py +15 -0
  1043. vismatch/third_party/TopicFM/third_party/loftr/configs/loftr/outdoor/loftr_ds_dense.py +16 -0
  1044. vismatch/third_party/TopicFM/third_party/loftr/configs/loftr/outdoor/loftr_ot.py +15 -0
  1045. vismatch/third_party/TopicFM/third_party/loftr/configs/loftr/outdoor/loftr_ot_dense.py +16 -0
  1046. vismatch/third_party/TopicFM/third_party/loftr/demo/demo_loftr.py +240 -0
  1047. vismatch/third_party/TopicFM/third_party/loftr/src/__init__.py +0 -0
  1048. vismatch/third_party/TopicFM/third_party/loftr/src/config/default.py +171 -0
  1049. vismatch/third_party/TopicFM/third_party/loftr/src/datasets/megadepth.py +127 -0
  1050. vismatch/third_party/TopicFM/third_party/loftr/src/datasets/sampler.py +77 -0
  1051. vismatch/third_party/TopicFM/third_party/loftr/src/datasets/scannet.py +114 -0
  1052. vismatch/third_party/TopicFM/third_party/loftr/src/lightning/data.py +320 -0
  1053. vismatch/third_party/TopicFM/third_party/loftr/src/lightning/lightning_loftr.py +249 -0
  1054. vismatch/third_party/TopicFM/third_party/loftr/src/loftr/__init__.py +2 -0
  1055. vismatch/third_party/TopicFM/third_party/loftr/src/loftr/backbone/__init__.py +11 -0
  1056. vismatch/third_party/TopicFM/third_party/loftr/src/loftr/backbone/resnet_fpn.py +199 -0
  1057. vismatch/third_party/TopicFM/third_party/loftr/src/loftr/loftr.py +81 -0
  1058. vismatch/third_party/TopicFM/third_party/loftr/src/loftr/loftr_module/__init__.py +2 -0
  1059. vismatch/third_party/TopicFM/third_party/loftr/src/loftr/loftr_module/fine_preprocess.py +59 -0
  1060. vismatch/third_party/TopicFM/third_party/loftr/src/loftr/loftr_module/linear_attention.py +81 -0
  1061. vismatch/third_party/TopicFM/third_party/loftr/src/loftr/loftr_module/transformer.py +101 -0
  1062. vismatch/third_party/TopicFM/third_party/loftr/src/loftr/utils/coarse_matching.py +261 -0
  1063. vismatch/third_party/TopicFM/third_party/loftr/src/loftr/utils/cvpr_ds_config.py +50 -0
  1064. vismatch/third_party/TopicFM/third_party/loftr/src/loftr/utils/fine_matching.py +74 -0
  1065. vismatch/third_party/TopicFM/third_party/loftr/src/loftr/utils/geometry.py +54 -0
  1066. vismatch/third_party/TopicFM/third_party/loftr/src/loftr/utils/position_encoding.py +42 -0
  1067. vismatch/third_party/TopicFM/third_party/loftr/src/loftr/utils/supervision.py +151 -0
  1068. vismatch/third_party/TopicFM/third_party/loftr/src/losses/loftr_loss.py +192 -0
  1069. vismatch/third_party/TopicFM/third_party/loftr/src/optimizers/__init__.py +42 -0
  1070. vismatch/third_party/TopicFM/third_party/loftr/src/utils/augment.py +55 -0
  1071. vismatch/third_party/TopicFM/third_party/loftr/src/utils/comm.py +265 -0
  1072. vismatch/third_party/TopicFM/third_party/loftr/src/utils/dataloader.py +23 -0
  1073. vismatch/third_party/TopicFM/third_party/loftr/src/utils/dataset.py +185 -0
  1074. vismatch/third_party/TopicFM/third_party/loftr/src/utils/metrics.py +193 -0
  1075. vismatch/third_party/TopicFM/third_party/loftr/src/utils/misc.py +101 -0
  1076. vismatch/third_party/TopicFM/third_party/loftr/src/utils/plotting.py +154 -0
  1077. vismatch/third_party/TopicFM/third_party/loftr/src/utils/profiler.py +39 -0
  1078. vismatch/third_party/TopicFM/third_party/loftr/test.py +68 -0
  1079. vismatch/third_party/TopicFM/third_party/loftr/third_party/SuperGluePretrainedNetwork/demo_superglue.py +259 -0
  1080. vismatch/third_party/TopicFM/third_party/loftr/third_party/SuperGluePretrainedNetwork/match_pairs.py +425 -0
  1081. vismatch/third_party/TopicFM/third_party/loftr/third_party/SuperGluePretrainedNetwork/models/__init__.py +0 -0
  1082. vismatch/third_party/TopicFM/third_party/loftr/third_party/SuperGluePretrainedNetwork/models/matching.py +84 -0
  1083. vismatch/third_party/TopicFM/third_party/loftr/third_party/SuperGluePretrainedNetwork/models/superglue.py +283 -0
  1084. vismatch/third_party/TopicFM/third_party/loftr/third_party/SuperGluePretrainedNetwork/models/superpoint.py +202 -0
  1085. vismatch/third_party/TopicFM/third_party/loftr/third_party/SuperGluePretrainedNetwork/models/utils.py +555 -0
  1086. vismatch/third_party/TopicFM/third_party/loftr/train.py +123 -0
  1087. vismatch/third_party/TopicFM/third_party/matchformer/config/data/__init__.py +0 -0
  1088. vismatch/third_party/TopicFM/third_party/matchformer/config/data/base.py +35 -0
  1089. vismatch/third_party/TopicFM/third_party/matchformer/config/data/megadepth_test_1500.py +11 -0
  1090. vismatch/third_party/TopicFM/third_party/matchformer/config/data/scannet_test_1500.py +11 -0
  1091. vismatch/third_party/TopicFM/third_party/matchformer/config/defaultmf.py +88 -0
  1092. vismatch/third_party/TopicFM/third_party/matchformer/model/backbone/__init__.py +17 -0
  1093. vismatch/third_party/TopicFM/third_party/matchformer/model/backbone/coarse_matching.py +228 -0
  1094. vismatch/third_party/TopicFM/third_party/matchformer/model/backbone/fine_matching.py +74 -0
  1095. vismatch/third_party/TopicFM/third_party/matchformer/model/backbone/fine_preprocess.py +59 -0
  1096. vismatch/third_party/TopicFM/third_party/matchformer/model/backbone/match_LA_large.py +254 -0
  1097. vismatch/third_party/TopicFM/third_party/matchformer/model/backbone/match_LA_lite.py +254 -0
  1098. vismatch/third_party/TopicFM/third_party/matchformer/model/backbone/match_SEA_large.py +291 -0
  1099. vismatch/third_party/TopicFM/third_party/matchformer/model/backbone/match_SEA_lite.py +291 -0
  1100. vismatch/third_party/TopicFM/third_party/matchformer/model/data.py +320 -0
  1101. vismatch/third_party/TopicFM/third_party/matchformer/model/datasets/dataset.py +231 -0
  1102. vismatch/third_party/TopicFM/third_party/matchformer/model/datasets/megadepth.py +126 -0
  1103. vismatch/third_party/TopicFM/third_party/matchformer/model/datasets/sampler.py +77 -0
  1104. vismatch/third_party/TopicFM/third_party/matchformer/model/datasets/scannet.py +113 -0
  1105. vismatch/third_party/TopicFM/third_party/matchformer/model/lightning_loftr.py +102 -0
  1106. vismatch/third_party/TopicFM/third_party/matchformer/model/matchformer.py +54 -0
  1107. vismatch/third_party/TopicFM/third_party/matchformer/model/utils/augment.py +55 -0
  1108. vismatch/third_party/TopicFM/third_party/matchformer/model/utils/comm.py +265 -0
  1109. vismatch/third_party/TopicFM/third_party/matchformer/model/utils/dataloader.py +23 -0
  1110. vismatch/third_party/TopicFM/third_party/matchformer/model/utils/metrics.py +193 -0
  1111. vismatch/third_party/TopicFM/third_party/matchformer/model/utils/misc.py +101 -0
  1112. vismatch/third_party/TopicFM/third_party/matchformer/model/utils/profiler.py +39 -0
  1113. vismatch/third_party/TopicFM/third_party/matchformer/test.py +55 -0
  1114. vismatch/third_party/TopicFM/train.py +123 -0
  1115. vismatch/third_party/TopicFM/visualization.py +123 -0
  1116. vismatch/third_party/TopicFM/viz/__init__.py +1 -0
  1117. vismatch/third_party/TopicFM/viz/configs/__init__.py +0 -0
  1118. vismatch/third_party/TopicFM/viz/methods/__init__.py +0 -0
  1119. vismatch/third_party/TopicFM/viz/methods/base.py +70 -0
  1120. vismatch/third_party/TopicFM/viz/methods/topicfmv2.py +208 -0
  1121. vismatch/third_party/UFM/UniCeption/examples/models/cosmos/autoencoding.py +48 -0
  1122. vismatch/third_party/UFM/UniCeption/examples/models/dust3r/convert_dust3r_weights_to_uniception.py +331 -0
  1123. vismatch/third_party/UFM/UniCeption/examples/models/dust3r/dust3r.py +261 -0
  1124. vismatch/third_party/UFM/UniCeption/examples/models/dust3r/profile_dust3r.py +47 -0
  1125. vismatch/third_party/UFM/UniCeption/scripts/check_dependencies.py +48 -0
  1126. vismatch/third_party/UFM/UniCeption/scripts/download_checkpoints.py +50 -0
  1127. vismatch/third_party/UFM/UniCeption/scripts/install_croco_rope.py +61 -0
  1128. vismatch/third_party/UFM/UniCeption/scripts/prepare_offline_install.py +398 -0
  1129. vismatch/third_party/UFM/UniCeption/scripts/validate_installation.py +212 -0
  1130. vismatch/third_party/UFM/UniCeption/setup.py +185 -0
  1131. vismatch/third_party/UFM/UniCeption/tests/models/encoders/conftest.py +26 -0
  1132. vismatch/third_party/UFM/UniCeption/tests/models/encoders/test_encoders.py +202 -0
  1133. vismatch/third_party/UFM/UniCeption/tests/models/encoders/viz_image_encoders.py +294 -0
  1134. vismatch/third_party/UFM/UniCeption/tests/models/info_sharing/viz_mulit_view_cross_attn_transformers.py +337 -0
  1135. vismatch/third_party/UFM/UniCeption/uniception/__init__.py +0 -0
  1136. vismatch/third_party/UFM/UniCeption/uniception/models/__init__.py +0 -0
  1137. vismatch/third_party/UFM/UniCeption/uniception/models/encoders/__init__.py +225 -0
  1138. vismatch/third_party/UFM/UniCeption/uniception/models/encoders/base.py +157 -0
  1139. vismatch/third_party/UFM/UniCeption/uniception/models/encoders/cosmos.py +137 -0
  1140. vismatch/third_party/UFM/UniCeption/uniception/models/encoders/croco.py +457 -0
  1141. vismatch/third_party/UFM/UniCeption/uniception/models/encoders/dense_rep_encoder.py +344 -0
  1142. vismatch/third_party/UFM/UniCeption/uniception/models/encoders/dinov2.py +333 -0
  1143. vismatch/third_party/UFM/UniCeption/uniception/models/encoders/global_rep_encoder.py +115 -0
  1144. vismatch/third_party/UFM/UniCeption/uniception/models/encoders/image_normalizations.py +35 -0
  1145. vismatch/third_party/UFM/UniCeption/uniception/models/encoders/list.py +10 -0
  1146. vismatch/third_party/UFM/UniCeption/uniception/models/encoders/patch_embedder.py +235 -0
  1147. vismatch/third_party/UFM/UniCeption/uniception/models/encoders/radio.py +367 -0
  1148. vismatch/third_party/UFM/UniCeption/uniception/models/encoders/utils.py +86 -0
  1149. vismatch/third_party/UFM/UniCeption/uniception/models/factory/__init__.py +3 -0
  1150. vismatch/third_party/UFM/UniCeption/uniception/models/factory/dust3r.py +332 -0
  1151. vismatch/third_party/UFM/UniCeption/uniception/models/info_sharing/__init__.py +39 -0
  1152. vismatch/third_party/UFM/UniCeption/uniception/models/info_sharing/alternating_attention_transformer.py +973 -0
  1153. vismatch/third_party/UFM/UniCeption/uniception/models/info_sharing/base.py +116 -0
  1154. vismatch/third_party/UFM/UniCeption/uniception/models/info_sharing/cross_attention_transformer.py +612 -0
  1155. vismatch/third_party/UFM/UniCeption/uniception/models/info_sharing/diff_cross_attention_transformer.py +588 -0
  1156. vismatch/third_party/UFM/UniCeption/uniception/models/info_sharing/global_attention_transformer.py +1154 -0
  1157. vismatch/third_party/UFM/UniCeption/uniception/models/libs/__init__.py +0 -0
  1158. vismatch/third_party/UFM/UniCeption/uniception/models/libs/cosmos_tokenizer/__init__.py +14 -0
  1159. vismatch/third_party/UFM/UniCeption/uniception/models/libs/cosmos_tokenizer/image_cli.py +175 -0
  1160. vismatch/third_party/UFM/UniCeption/uniception/models/libs/cosmos_tokenizer/image_lib.py +123 -0
  1161. vismatch/third_party/UFM/UniCeption/uniception/models/libs/cosmos_tokenizer/modules/__init__.py +60 -0
  1162. vismatch/third_party/UFM/UniCeption/uniception/models/libs/cosmos_tokenizer/modules/distributions.py +41 -0
  1163. vismatch/third_party/UFM/UniCeption/uniception/models/libs/cosmos_tokenizer/modules/layers2d.py +326 -0
  1164. vismatch/third_party/UFM/UniCeption/uniception/models/libs/cosmos_tokenizer/modules/layers3d.py +965 -0
  1165. vismatch/third_party/UFM/UniCeption/uniception/models/libs/cosmos_tokenizer/modules/patching.py +310 -0
  1166. vismatch/third_party/UFM/UniCeption/uniception/models/libs/cosmos_tokenizer/modules/quantizers.py +510 -0
  1167. vismatch/third_party/UFM/UniCeption/uniception/models/libs/cosmos_tokenizer/modules/utils.py +115 -0
  1168. vismatch/third_party/UFM/UniCeption/uniception/models/libs/cosmos_tokenizer/networks/__init__.py +39 -0
  1169. vismatch/third_party/UFM/UniCeption/uniception/models/libs/cosmos_tokenizer/networks/configs.py +146 -0
  1170. vismatch/third_party/UFM/UniCeption/uniception/models/libs/cosmos_tokenizer/networks/continuous_image.py +86 -0
  1171. vismatch/third_party/UFM/UniCeption/uniception/models/libs/cosmos_tokenizer/networks/continuous_video.py +98 -0
  1172. vismatch/third_party/UFM/UniCeption/uniception/models/libs/cosmos_tokenizer/networks/discrete_image.py +113 -0
  1173. vismatch/third_party/UFM/UniCeption/uniception/models/libs/cosmos_tokenizer/networks/discrete_video.py +115 -0
  1174. vismatch/third_party/UFM/UniCeption/uniception/models/libs/cosmos_tokenizer/utils.py +402 -0
  1175. vismatch/third_party/UFM/UniCeption/uniception/models/libs/cosmos_tokenizer/video_cli.py +195 -0
  1176. vismatch/third_party/UFM/UniCeption/uniception/models/libs/cosmos_tokenizer/video_lib.py +145 -0
  1177. vismatch/third_party/UFM/UniCeption/uniception/models/libs/croco/__init__.py +0 -0
  1178. vismatch/third_party/UFM/UniCeption/uniception/models/libs/croco/blocks.py +249 -0
  1179. vismatch/third_party/UFM/UniCeption/uniception/models/libs/croco/curope/__init__.py +4 -0
  1180. vismatch/third_party/UFM/UniCeption/uniception/models/libs/croco/curope/curope2d.py +39 -0
  1181. vismatch/third_party/UFM/UniCeption/uniception/models/libs/croco/curope/setup.py +33 -0
  1182. vismatch/third_party/UFM/UniCeption/uniception/models/libs/croco/dpt_block.py +530 -0
  1183. vismatch/third_party/UFM/UniCeption/uniception/models/libs/croco/patch_embed.py +127 -0
  1184. vismatch/third_party/UFM/UniCeption/uniception/models/libs/croco/pos_embed.py +155 -0
  1185. vismatch/third_party/UFM/UniCeption/uniception/models/prediction_heads/__init__.py +18 -0
  1186. vismatch/third_party/UFM/UniCeption/uniception/models/prediction_heads/adaptors.py +1765 -0
  1187. vismatch/third_party/UFM/UniCeption/uniception/models/prediction_heads/base.py +210 -0
  1188. vismatch/third_party/UFM/UniCeption/uniception/models/prediction_heads/cosmos.py +211 -0
  1189. vismatch/third_party/UFM/UniCeption/uniception/models/prediction_heads/dpt.py +676 -0
  1190. vismatch/third_party/UFM/UniCeption/uniception/models/prediction_heads/global_head.py +142 -0
  1191. vismatch/third_party/UFM/UniCeption/uniception/models/prediction_heads/linear.py +95 -0
  1192. vismatch/third_party/UFM/UniCeption/uniception/models/prediction_heads/mlp_feature.py +114 -0
  1193. vismatch/third_party/UFM/UniCeption/uniception/models/prediction_heads/mlp_head.py +114 -0
  1194. vismatch/third_party/UFM/UniCeption/uniception/models/prediction_heads/moge_conv.py +342 -0
  1195. vismatch/third_party/UFM/UniCeption/uniception/models/prediction_heads/pose_head.py +181 -0
  1196. vismatch/third_party/UFM/UniCeption/uniception/models/utils/__init__.py +0 -0
  1197. vismatch/third_party/UFM/UniCeption/uniception/models/utils/config.py +34 -0
  1198. vismatch/third_party/UFM/UniCeption/uniception/models/utils/intermediate_feature_return.py +85 -0
  1199. vismatch/third_party/UFM/UniCeption/uniception/models/utils/positional_encoding.py +23 -0
  1200. vismatch/third_party/UFM/UniCeption/uniception/models/utils/transformer_blocks.py +1072 -0
  1201. vismatch/third_party/UFM/UniCeption/uniception/utils/__init__.py +0 -0
  1202. vismatch/third_party/UFM/UniCeption/uniception/utils/profile.py +13 -0
  1203. vismatch/third_party/UFM/UniCeption/uniception/utils/viz.py +99 -0
  1204. vismatch/third_party/UFM/example_inference.py +138 -0
  1205. vismatch/third_party/UFM/gradio_demo.py +238 -0
  1206. vismatch/third_party/UFM/setup.py +86 -0
  1207. vismatch/third_party/UFM/uniflowmatch/__init__.py +16 -0
  1208. vismatch/third_party/UFM/uniflowmatch/cli.py +217 -0
  1209. vismatch/third_party/UFM/uniflowmatch/models/__init__.py +25 -0
  1210. vismatch/third_party/UFM/uniflowmatch/models/base.py +334 -0
  1211. vismatch/third_party/UFM/uniflowmatch/models/ufm.py +1323 -0
  1212. vismatch/third_party/UFM/uniflowmatch/models/unet_encoder.py +90 -0
  1213. vismatch/third_party/UFM/uniflowmatch/models/utils.py +16 -0
  1214. vismatch/third_party/UFM/uniflowmatch/utils/__init__.py +63 -0
  1215. vismatch/third_party/UFM/uniflowmatch/utils/flow_resizing.py +1091 -0
  1216. vismatch/third_party/UFM/uniflowmatch/utils/geometry.py +612 -0
  1217. vismatch/third_party/UFM/uniflowmatch/utils/viz.py +97 -0
  1218. vismatch/third_party/XoFTR/configs/data/__init__.py +0 -0
  1219. vismatch/third_party/XoFTR/configs/data/base.py +35 -0
  1220. vismatch/third_party/XoFTR/configs/data/megadepth_trainval_840.py +22 -0
  1221. vismatch/third_party/XoFTR/configs/data/megadepth_vistir_trainval_640.py +23 -0
  1222. vismatch/third_party/XoFTR/configs/data/pretrain.py +8 -0
  1223. vismatch/third_party/XoFTR/configs/xoftr/outdoor/visible_thermal.py +17 -0
  1224. vismatch/third_party/XoFTR/configs/xoftr/pretrain/pretrain.py +12 -0
  1225. vismatch/third_party/XoFTR/pretrain.py +125 -0
  1226. vismatch/third_party/XoFTR/src/__init__.py +0 -0
  1227. vismatch/third_party/XoFTR/src/config/default.py +203 -0
  1228. vismatch/third_party/XoFTR/src/datasets/megadepth.py +143 -0
  1229. vismatch/third_party/XoFTR/src/datasets/pretrain_dataset.py +156 -0
  1230. vismatch/third_party/XoFTR/src/datasets/sampler.py +77 -0
  1231. vismatch/third_party/XoFTR/src/datasets/scannet.py +114 -0
  1232. vismatch/third_party/XoFTR/src/datasets/vistir.py +109 -0
  1233. vismatch/third_party/XoFTR/src/lightning/data.py +346 -0
  1234. vismatch/third_party/XoFTR/src/lightning/data_pretrain.py +125 -0
  1235. vismatch/third_party/XoFTR/src/lightning/lightning_xoftr.py +334 -0
  1236. vismatch/third_party/XoFTR/src/lightning/lightning_xoftr_pretrain.py +171 -0
  1237. vismatch/third_party/XoFTR/src/losses/xoftr_loss.py +170 -0
  1238. vismatch/third_party/XoFTR/src/losses/xoftr_loss_pretrain.py +37 -0
  1239. vismatch/third_party/XoFTR/src/optimizers/__init__.py +42 -0
  1240. vismatch/third_party/XoFTR/src/utils/augment.py +113 -0
  1241. vismatch/third_party/XoFTR/src/utils/comm.py +265 -0
  1242. vismatch/third_party/XoFTR/src/utils/data_io.py +144 -0
  1243. vismatch/third_party/XoFTR/src/utils/dataloader.py +23 -0
  1244. vismatch/third_party/XoFTR/src/utils/dataset.py +279 -0
  1245. vismatch/third_party/XoFTR/src/utils/metrics.py +211 -0
  1246. vismatch/third_party/XoFTR/src/utils/misc.py +101 -0
  1247. vismatch/third_party/XoFTR/src/utils/plotting.py +227 -0
  1248. vismatch/third_party/XoFTR/src/utils/pretrain_utils.py +83 -0
  1249. vismatch/third_party/XoFTR/src/utils/profiler.py +39 -0
  1250. vismatch/third_party/XoFTR/src/xoftr/__init__.py +2 -0
  1251. vismatch/third_party/XoFTR/src/xoftr/backbone/__init__.py +1 -0
  1252. vismatch/third_party/XoFTR/src/xoftr/backbone/resnet.py +95 -0
  1253. vismatch/third_party/XoFTR/src/xoftr/utils/geometry.py +107 -0
  1254. vismatch/third_party/XoFTR/src/xoftr/utils/position_encoding.py +36 -0
  1255. vismatch/third_party/XoFTR/src/xoftr/utils/supervision.py +290 -0
  1256. vismatch/third_party/XoFTR/src/xoftr/xoftr.py +94 -0
  1257. vismatch/third_party/XoFTR/src/xoftr/xoftr_module/__init__.py +4 -0
  1258. vismatch/third_party/XoFTR/src/xoftr/xoftr_module/coarse_matching.py +305 -0
  1259. vismatch/third_party/XoFTR/src/xoftr/xoftr_module/fine_matching.py +170 -0
  1260. vismatch/third_party/XoFTR/src/xoftr/xoftr_module/fine_process.py +321 -0
  1261. vismatch/third_party/XoFTR/src/xoftr/xoftr_module/linear_attention.py +81 -0
  1262. vismatch/third_party/XoFTR/src/xoftr/xoftr_module/transformer.py +101 -0
  1263. vismatch/third_party/XoFTR/src/xoftr/xoftr_pretrain.py +209 -0
  1264. vismatch/third_party/XoFTR/test.py +68 -0
  1265. vismatch/third_party/XoFTR/test_relative_pose.py +330 -0
  1266. vismatch/third_party/XoFTR/train.py +126 -0
  1267. vismatch/third_party/accelerated_features/hubconf.py +15 -0
  1268. vismatch/third_party/accelerated_features/minimal_example.py +49 -0
  1269. vismatch/third_party/accelerated_features/modules/__init__.py +4 -0
  1270. vismatch/third_party/accelerated_features/modules/dataset/__init__.py +5 -0
  1271. vismatch/third_party/accelerated_features/modules/dataset/augmentation.py +314 -0
  1272. vismatch/third_party/accelerated_features/modules/dataset/megadepth/__init__.py +7 -0
  1273. vismatch/third_party/accelerated_features/modules/dataset/megadepth/megadepth.py +174 -0
  1274. vismatch/third_party/accelerated_features/modules/dataset/megadepth/megadepth_warper.py +170 -0
  1275. vismatch/third_party/accelerated_features/modules/dataset/megadepth/utils.py +160 -0
  1276. vismatch/third_party/accelerated_features/modules/interpolator.py +33 -0
  1277. vismatch/third_party/accelerated_features/modules/lighterglue.py +56 -0
  1278. vismatch/third_party/accelerated_features/modules/model.py +154 -0
  1279. vismatch/third_party/accelerated_features/modules/training/__init__.py +4 -0
  1280. vismatch/third_party/accelerated_features/modules/training/losses.py +224 -0
  1281. vismatch/third_party/accelerated_features/modules/training/train.py +311 -0
  1282. vismatch/third_party/accelerated_features/modules/training/utils.py +200 -0
  1283. vismatch/third_party/accelerated_features/modules/xfeat.py +402 -0
  1284. vismatch/third_party/accelerated_features/realtime_demo.py +295 -0
  1285. vismatch/third_party/accelerated_features/third_party/ALIKE/alike.py +143 -0
  1286. vismatch/third_party/accelerated_features/third_party/ALIKE/alnet.py +164 -0
  1287. vismatch/third_party/accelerated_features/third_party/ALIKE/demo.py +167 -0
  1288. vismatch/third_party/accelerated_features/third_party/ALIKE/hseq/eval.py +162 -0
  1289. vismatch/third_party/accelerated_features/third_party/ALIKE/hseq/extract.py +159 -0
  1290. vismatch/third_party/accelerated_features/third_party/ALIKE/soft_detect.py +194 -0
  1291. vismatch/third_party/accelerated_features/third_party/__init__.py +4 -0
  1292. vismatch/third_party/accelerated_features/third_party/alike_wrapper.py +110 -0
  1293. vismatch/third_party/affine-steerers/affine_steerers/__init__.py +7 -0
  1294. vismatch/third_party/affine-steerers/affine_steerers/benchmarks/__init__.py +5 -0
  1295. vismatch/third_party/affine-steerers/affine_steerers/benchmarks/hpatches.py +92 -0
  1296. vismatch/third_party/affine-steerers/affine_steerers/benchmarks/hpatches_oracle_steer.py +108 -0
  1297. vismatch/third_party/affine-steerers/affine_steerers/benchmarks/mega_pose_est.py +116 -0
  1298. vismatch/third_party/affine-steerers/affine_steerers/benchmarks/mega_pose_est_mnn.py +162 -0
  1299. vismatch/third_party/affine-steerers/affine_steerers/benchmarks/nll_benchmark.py +57 -0
  1300. vismatch/third_party/affine-steerers/affine_steerers/benchmarks/num_inliers.py +76 -0
  1301. vismatch/third_party/affine-steerers/affine_steerers/checkpoint.py +82 -0
  1302. vismatch/third_party/affine-steerers/affine_steerers/datasets/__init__.py +0 -0
  1303. vismatch/third_party/affine-steerers/affine_steerers/datasets/homog.py +284 -0
  1304. vismatch/third_party/affine-steerers/affine_steerers/datasets/megadepth.py +408 -0
  1305. vismatch/third_party/affine-steerers/affine_steerers/decoder.py +90 -0
  1306. vismatch/third_party/affine-steerers/affine_steerers/descriptors/__init__.py +0 -0
  1307. vismatch/third_party/affine-steerers/affine_steerers/descriptors/dedode_descriptor.py +77 -0
  1308. vismatch/third_party/affine-steerers/affine_steerers/descriptors/descriptor_loss.py +358 -0
  1309. vismatch/third_party/affine-steerers/affine_steerers/detectors/__init__.py +0 -0
  1310. vismatch/third_party/affine-steerers/affine_steerers/detectors/dedode_detector.py +75 -0
  1311. vismatch/third_party/affine-steerers/affine_steerers/detectors/keypoint_loss.py +215 -0
  1312. vismatch/third_party/affine-steerers/affine_steerers/encoder.py +87 -0
  1313. vismatch/third_party/affine-steerers/affine_steerers/matchers/__init__.py +0 -0
  1314. vismatch/third_party/affine-steerers/affine_steerers/matchers/dual_softmax_matcher.py +816 -0
  1315. vismatch/third_party/affine-steerers/affine_steerers/model_zoo/__init__.py +3 -0
  1316. vismatch/third_party/affine-steerers/affine_steerers/model_zoo/dedode_models.py +298 -0
  1317. vismatch/third_party/affine-steerers/affine_steerers/steerers.py +732 -0
  1318. vismatch/third_party/affine-steerers/affine_steerers/train.py +90 -0
  1319. vismatch/third_party/affine-steerers/affine_steerers/transformer/__init__.py +8 -0
  1320. vismatch/third_party/affine-steerers/affine_steerers/transformer/dinov2.py +359 -0
  1321. vismatch/third_party/affine-steerers/affine_steerers/transformer/layers/__init__.py +12 -0
  1322. vismatch/third_party/affine-steerers/affine_steerers/transformer/layers/attention.py +81 -0
  1323. vismatch/third_party/affine-steerers/affine_steerers/transformer/layers/block.py +252 -0
  1324. vismatch/third_party/affine-steerers/affine_steerers/transformer/layers/dino_head.py +59 -0
  1325. vismatch/third_party/affine-steerers/affine_steerers/transformer/layers/drop_path.py +35 -0
  1326. vismatch/third_party/affine-steerers/affine_steerers/transformer/layers/layer_scale.py +28 -0
  1327. vismatch/third_party/affine-steerers/affine_steerers/transformer/layers/mlp.py +41 -0
  1328. vismatch/third_party/affine-steerers/affine_steerers/transformer/layers/patch_embed.py +89 -0
  1329. vismatch/third_party/affine-steerers/affine_steerers/transformer/layers/swiglu_ffn.py +63 -0
  1330. vismatch/third_party/affine-steerers/affine_steerers/utils.py +1422 -0
  1331. vismatch/third_party/affine-steerers/experiments/aff_equi_B.py +182 -0
  1332. vismatch/third_party/affine-steerers/experiments/aff_equi_G.py +193 -0
  1333. vismatch/third_party/affine-steerers/experiments/aff_steer_B.py +213 -0
  1334. vismatch/third_party/affine-steerers/experiments/aff_steer_G.py +223 -0
  1335. vismatch/third_party/affine-steerers/experiments/aff_steer_pretrain_B.py +187 -0
  1336. vismatch/third_party/affine-steerers/experiments/aff_steer_pretrain_G.py +198 -0
  1337. vismatch/third_party/affine-steerers/setup.py +15 -0
  1338. vismatch/third_party/aspanformer/configs/aspan/indoor/aspan_test.py +7 -0
  1339. vismatch/third_party/aspanformer/configs/aspan/indoor/aspan_train.py +8 -0
  1340. vismatch/third_party/aspanformer/configs/aspan/outdoor/aspan_test.py +19 -0
  1341. vismatch/third_party/aspanformer/configs/aspan/outdoor/aspan_train.py +17 -0
  1342. vismatch/third_party/aspanformer/configs/data/__init__.py +0 -0
  1343. vismatch/third_party/aspanformer/configs/data/base.py +35 -0
  1344. vismatch/third_party/aspanformer/configs/data/megadepth_test_1500.py +13 -0
  1345. vismatch/third_party/aspanformer/configs/data/megadepth_trainval_832.py +22 -0
  1346. vismatch/third_party/aspanformer/configs/data/scannet_test_1500.py +11 -0
  1347. vismatch/third_party/aspanformer/configs/data/scannet_trainval.py +17 -0
  1348. vismatch/third_party/aspanformer/demo/demo.py +63 -0
  1349. vismatch/third_party/aspanformer/demo/demo_utils.py +44 -0
  1350. vismatch/third_party/aspanformer/src/ASpanFormer/__init__.py +2 -0
  1351. vismatch/third_party/aspanformer/src/ASpanFormer/aspan_module/__init__.py +3 -0
  1352. vismatch/third_party/aspanformer/src/ASpanFormer/aspan_module/attention.py +198 -0
  1353. vismatch/third_party/aspanformer/src/ASpanFormer/aspan_module/fine_preprocess.py +59 -0
  1354. vismatch/third_party/aspanformer/src/ASpanFormer/aspan_module/loftr.py +112 -0
  1355. vismatch/third_party/aspanformer/src/ASpanFormer/aspan_module/transformer.py +244 -0
  1356. vismatch/third_party/aspanformer/src/ASpanFormer/aspanformer.py +152 -0
  1357. vismatch/third_party/aspanformer/src/ASpanFormer/backbone/__init__.py +11 -0
  1358. vismatch/third_party/aspanformer/src/ASpanFormer/backbone/resnet_fpn.py +199 -0
  1359. vismatch/third_party/aspanformer/src/ASpanFormer/utils/coarse_matching.py +331 -0
  1360. vismatch/third_party/aspanformer/src/ASpanFormer/utils/cvpr_ds_config.py +50 -0
  1361. vismatch/third_party/aspanformer/src/ASpanFormer/utils/fine_matching.py +74 -0
  1362. vismatch/third_party/aspanformer/src/ASpanFormer/utils/geometry.py +54 -0
  1363. vismatch/third_party/aspanformer/src/ASpanFormer/utils/position_encoding.py +61 -0
  1364. vismatch/third_party/aspanformer/src/ASpanFormer/utils/supervision.py +151 -0
  1365. vismatch/third_party/aspanformer/src/__init__.py +0 -0
  1366. vismatch/third_party/aspanformer/src/config/default.py +180 -0
  1367. vismatch/third_party/aspanformer/src/datasets/__init__.py +3 -0
  1368. vismatch/third_party/aspanformer/src/datasets/megadepth.py +127 -0
  1369. vismatch/third_party/aspanformer/src/datasets/sampler.py +77 -0
  1370. vismatch/third_party/aspanformer/src/datasets/scannet.py +113 -0
  1371. vismatch/third_party/aspanformer/src/lightning/data.py +326 -0
  1372. vismatch/third_party/aspanformer/src/lightning/lightning_aspanformer.py +276 -0
  1373. vismatch/third_party/aspanformer/src/losses/aspan_loss.py +231 -0
  1374. vismatch/third_party/aspanformer/src/optimizers/__init__.py +42 -0
  1375. vismatch/third_party/aspanformer/src/utils/augment.py +55 -0
  1376. vismatch/third_party/aspanformer/src/utils/comm.py +265 -0
  1377. vismatch/third_party/aspanformer/src/utils/dataloader.py +23 -0
  1378. vismatch/third_party/aspanformer/src/utils/dataset.py +222 -0
  1379. vismatch/third_party/aspanformer/src/utils/metrics.py +260 -0
  1380. vismatch/third_party/aspanformer/src/utils/misc.py +139 -0
  1381. vismatch/third_party/aspanformer/src/utils/plotting.py +219 -0
  1382. vismatch/third_party/aspanformer/src/utils/profiler.py +39 -0
  1383. vismatch/third_party/aspanformer/test.py +69 -0
  1384. vismatch/third_party/aspanformer/tools/SensorData.py +125 -0
  1385. vismatch/third_party/aspanformer/tools/extract.py +47 -0
  1386. vismatch/third_party/aspanformer/tools/preprocess_scene.py +242 -0
  1387. vismatch/third_party/aspanformer/tools/reader.py +39 -0
  1388. vismatch/third_party/aspanformer/tools/undistort_mega.py +69 -0
  1389. vismatch/third_party/aspanformer/train.py +134 -0
  1390. vismatch/third_party/duster/croco/datasets/__init__.py +0 -0
  1391. vismatch/third_party/duster/croco/datasets/crops/extract_crops_from_images.py +159 -0
  1392. vismatch/third_party/duster/croco/datasets/habitat_sim/__init__.py +0 -0
  1393. vismatch/third_party/duster/croco/datasets/habitat_sim/generate_from_metadata.py +92 -0
  1394. vismatch/third_party/duster/croco/datasets/habitat_sim/generate_from_metadata_files.py +27 -0
  1395. vismatch/third_party/duster/croco/datasets/habitat_sim/generate_multiview_images.py +177 -0
  1396. vismatch/third_party/duster/croco/datasets/habitat_sim/multiview_habitat_sim_generator.py +390 -0
  1397. vismatch/third_party/duster/croco/datasets/habitat_sim/pack_metadata_files.py +69 -0
  1398. vismatch/third_party/duster/croco/datasets/habitat_sim/paths.py +129 -0
  1399. vismatch/third_party/duster/croco/datasets/pairs_dataset.py +109 -0
  1400. vismatch/third_party/duster/croco/datasets/transforms.py +95 -0
  1401. vismatch/third_party/duster/croco/demo.py +55 -0
  1402. vismatch/third_party/duster/croco/models/blocks.py +241 -0
  1403. vismatch/third_party/duster/croco/models/criterion.py +37 -0
  1404. vismatch/third_party/duster/croco/models/croco.py +249 -0
  1405. vismatch/third_party/duster/croco/models/croco_downstream.py +122 -0
  1406. vismatch/third_party/duster/croco/models/curope/__init__.py +4 -0
  1407. vismatch/third_party/duster/croco/models/curope/curope2d.py +40 -0
  1408. vismatch/third_party/duster/croco/models/curope/setup.py +34 -0
  1409. vismatch/third_party/duster/croco/models/dpt_block.py +450 -0
  1410. vismatch/third_party/duster/croco/models/head_downstream.py +58 -0
  1411. vismatch/third_party/duster/croco/models/masking.py +25 -0
  1412. vismatch/third_party/duster/croco/models/pos_embed.py +157 -0
  1413. vismatch/third_party/duster/croco/pretrain.py +254 -0
  1414. vismatch/third_party/duster/croco/stereoflow/augmentor.py +290 -0
  1415. vismatch/third_party/duster/croco/stereoflow/criterion.py +251 -0
  1416. vismatch/third_party/duster/croco/stereoflow/datasets_flow.py +630 -0
  1417. vismatch/third_party/duster/croco/stereoflow/datasets_stereo.py +674 -0
  1418. vismatch/third_party/duster/croco/stereoflow/engine.py +280 -0
  1419. vismatch/third_party/duster/croco/stereoflow/test.py +216 -0
  1420. vismatch/third_party/duster/croco/stereoflow/train.py +253 -0
  1421. vismatch/third_party/duster/croco/utils/misc.py +463 -0
  1422. vismatch/third_party/duster/datasets_preprocess/habitat/find_scenes.py +78 -0
  1423. vismatch/third_party/duster/datasets_preprocess/habitat/habitat_renderer/__init__.py +2 -0
  1424. vismatch/third_party/duster/datasets_preprocess/habitat/habitat_renderer/habitat_sim_envmaps_renderer.py +170 -0
  1425. vismatch/third_party/duster/datasets_preprocess/habitat/habitat_renderer/multiview_crop_generator.py +93 -0
  1426. vismatch/third_party/duster/datasets_preprocess/habitat/habitat_renderer/projections.py +151 -0
  1427. vismatch/third_party/duster/datasets_preprocess/habitat/habitat_renderer/projections_conversions.py +45 -0
  1428. vismatch/third_party/duster/datasets_preprocess/habitat/preprocess_habitat.py +121 -0
  1429. vismatch/third_party/duster/datasets_preprocess/path_to_root.py +13 -0
  1430. vismatch/third_party/duster/datasets_preprocess/preprocess_arkitscenes.py +355 -0
  1431. vismatch/third_party/duster/datasets_preprocess/preprocess_blendedMVS.py +149 -0
  1432. vismatch/third_party/duster/datasets_preprocess/preprocess_co3d.py +295 -0
  1433. vismatch/third_party/duster/datasets_preprocess/preprocess_megadepth.py +198 -0
  1434. vismatch/third_party/duster/datasets_preprocess/preprocess_scannetpp.py +400 -0
  1435. vismatch/third_party/duster/datasets_preprocess/preprocess_staticthings3d.py +130 -0
  1436. vismatch/third_party/duster/datasets_preprocess/preprocess_waymo.py +257 -0
  1437. vismatch/third_party/duster/datasets_preprocess/preprocess_wildrgbd.py +209 -0
  1438. vismatch/third_party/duster/demo.py +45 -0
  1439. vismatch/third_party/duster/dust3r/__init__.py +2 -0
  1440. vismatch/third_party/duster/dust3r/cloud_opt/__init__.py +33 -0
  1441. vismatch/third_party/duster/dust3r/cloud_opt/base_opt.py +405 -0
  1442. vismatch/third_party/duster/dust3r/cloud_opt/commons.py +90 -0
  1443. vismatch/third_party/duster/dust3r/cloud_opt/init_im_poses.py +316 -0
  1444. vismatch/third_party/duster/dust3r/cloud_opt/modular_optimizer.py +145 -0
  1445. vismatch/third_party/duster/dust3r/cloud_opt/optimizer.py +248 -0
  1446. vismatch/third_party/duster/dust3r/cloud_opt/pair_viewer.py +127 -0
  1447. vismatch/third_party/duster/dust3r/datasets/__init__.py +50 -0
  1448. vismatch/third_party/duster/dust3r/datasets/arkitscenes.py +102 -0
  1449. vismatch/third_party/duster/dust3r/datasets/base/__init__.py +2 -0
  1450. vismatch/third_party/duster/dust3r/datasets/base/base_stereo_view_dataset.py +220 -0
  1451. vismatch/third_party/duster/dust3r/datasets/base/batched_sampler.py +74 -0
  1452. vismatch/third_party/duster/dust3r/datasets/base/easy_dataset.py +157 -0
  1453. vismatch/third_party/duster/dust3r/datasets/blendedmvs.py +104 -0
  1454. vismatch/third_party/duster/dust3r/datasets/co3d.py +165 -0
  1455. vismatch/third_party/duster/dust3r/datasets/habitat.py +107 -0
  1456. vismatch/third_party/duster/dust3r/datasets/megadepth.py +123 -0
  1457. vismatch/third_party/duster/dust3r/datasets/scannetpp.py +96 -0
  1458. vismatch/third_party/duster/dust3r/datasets/staticthings3d.py +96 -0
  1459. vismatch/third_party/duster/dust3r/datasets/utils/__init__.py +2 -0
  1460. vismatch/third_party/duster/dust3r/datasets/utils/cropping.py +124 -0
  1461. vismatch/third_party/duster/dust3r/datasets/utils/transforms.py +11 -0
  1462. vismatch/third_party/duster/dust3r/datasets/waymo.py +93 -0
  1463. vismatch/third_party/duster/dust3r/datasets/wildrgbd.py +67 -0
  1464. vismatch/third_party/duster/dust3r/demo.py +287 -0
  1465. vismatch/third_party/duster/dust3r/heads/__init__.py +19 -0
  1466. vismatch/third_party/duster/dust3r/heads/dpt_head.py +115 -0
  1467. vismatch/third_party/duster/dust3r/heads/linear_head.py +41 -0
  1468. vismatch/third_party/duster/dust3r/heads/postprocess.py +58 -0
  1469. vismatch/third_party/duster/dust3r/image_pairs.py +104 -0
  1470. vismatch/third_party/duster/dust3r/inference.py +150 -0
  1471. vismatch/third_party/duster/dust3r/losses.py +299 -0
  1472. vismatch/third_party/duster/dust3r/model.py +211 -0
  1473. vismatch/third_party/duster/dust3r/optim_factory.py +14 -0
  1474. vismatch/third_party/duster/dust3r/patch_embed.py +70 -0
  1475. vismatch/third_party/duster/dust3r/post_process.py +60 -0
  1476. vismatch/third_party/duster/dust3r/training.py +377 -0
  1477. vismatch/third_party/duster/dust3r/utils/__init__.py +2 -0
  1478. vismatch/third_party/duster/dust3r/utils/device.py +76 -0
  1479. vismatch/third_party/duster/dust3r/utils/geometry.py +366 -0
  1480. vismatch/third_party/duster/dust3r/utils/image.py +128 -0
  1481. vismatch/third_party/duster/dust3r/utils/misc.py +121 -0
  1482. vismatch/third_party/duster/dust3r/utils/parallel.py +79 -0
  1483. vismatch/third_party/duster/dust3r/utils/path_to_croco.py +19 -0
  1484. vismatch/third_party/duster/dust3r/viz.py +381 -0
  1485. vismatch/third_party/duster/dust3r_visloc/__init__.py +2 -0
  1486. vismatch/third_party/duster/dust3r_visloc/datasets/__init__.py +6 -0
  1487. vismatch/third_party/duster/dust3r_visloc/datasets/aachen_day_night.py +24 -0
  1488. vismatch/third_party/duster/dust3r_visloc/datasets/base_colmap.py +282 -0
  1489. vismatch/third_party/duster/dust3r_visloc/datasets/base_dataset.py +19 -0
  1490. vismatch/third_party/duster/dust3r_visloc/datasets/cambridge_landmarks.py +19 -0
  1491. vismatch/third_party/duster/dust3r_visloc/datasets/inloc.py +167 -0
  1492. vismatch/third_party/duster/dust3r_visloc/datasets/sevenscenes.py +123 -0
  1493. vismatch/third_party/duster/dust3r_visloc/datasets/utils.py +118 -0
  1494. vismatch/third_party/duster/dust3r_visloc/evaluation.py +65 -0
  1495. vismatch/third_party/duster/dust3r_visloc/localization.py +140 -0
  1496. vismatch/third_party/duster/train.py +13 -0
  1497. vismatch/third_party/duster/visloc.py +193 -0
  1498. vismatch/third_party/gim/demo.py +479 -0
  1499. vismatch/third_party/gim/dkm/__init__.py +4 -0
  1500. vismatch/third_party/gim/dkm/benchmarks/__init__.py +4 -0
  1501. vismatch/third_party/gim/dkm/benchmarks/hpatches_sequences_homog_benchmark.py +114 -0
  1502. vismatch/third_party/gim/dkm/benchmarks/megadepth1500_benchmark.py +124 -0
  1503. vismatch/third_party/gim/dkm/benchmarks/megadepth_dense_benchmark.py +86 -0
  1504. vismatch/third_party/gim/dkm/benchmarks/scannet_benchmark.py +143 -0
  1505. vismatch/third_party/gim/dkm/checkpointing/__init__.py +1 -0
  1506. vismatch/third_party/gim/dkm/checkpointing/checkpoint.py +31 -0
  1507. vismatch/third_party/gim/dkm/datasets/__init__.py +1 -0
  1508. vismatch/third_party/gim/dkm/datasets/megadepth.py +177 -0
  1509. vismatch/third_party/gim/dkm/datasets/scannet.py +151 -0
  1510. vismatch/third_party/gim/dkm/losses/__init__.py +1 -0
  1511. vismatch/third_party/gim/dkm/losses/depth_match_regression_loss.py +128 -0
  1512. vismatch/third_party/gim/dkm/models/__init__.py +4 -0
  1513. vismatch/third_party/gim/dkm/models/dkm.py +745 -0
  1514. vismatch/third_party/gim/dkm/models/encoders.py +148 -0
  1515. vismatch/third_party/gim/dkm/models/model_zoo/DKMv3.py +148 -0
  1516. vismatch/third_party/gim/dkm/models/model_zoo/__init__.py +39 -0
  1517. vismatch/third_party/gim/dkm/train/__init__.py +1 -0
  1518. vismatch/third_party/gim/dkm/train/train.py +67 -0
  1519. vismatch/third_party/gim/dkm/utils/__init__.py +13 -0
  1520. vismatch/third_party/gim/dkm/utils/kde.py +26 -0
  1521. vismatch/third_party/gim/dkm/utils/local_correlation.py +40 -0
  1522. vismatch/third_party/gim/dkm/utils/transforms.py +104 -0
  1523. vismatch/third_party/gim/dkm/utils/utils.py +341 -0
  1524. vismatch/third_party/gim/gluefactory/__init__.py +17 -0
  1525. vismatch/third_party/gim/gluefactory/datasets/__init__.py +25 -0
  1526. vismatch/third_party/gim/gluefactory/datasets/augmentations.py +244 -0
  1527. vismatch/third_party/gim/gluefactory/datasets/base_dataset.py +206 -0
  1528. vismatch/third_party/gim/gluefactory/datasets/eth3d.py +254 -0
  1529. vismatch/third_party/gim/gluefactory/datasets/homographies.py +311 -0
  1530. vismatch/third_party/gim/gluefactory/datasets/hpatches.py +145 -0
  1531. vismatch/third_party/gim/gluefactory/datasets/image_folder.py +59 -0
  1532. vismatch/third_party/gim/gluefactory/datasets/image_pairs.py +100 -0
  1533. vismatch/third_party/gim/gluefactory/datasets/megadepth.py +514 -0
  1534. vismatch/third_party/gim/gluefactory/datasets/utils.py +131 -0
  1535. vismatch/third_party/gim/gluefactory/eval/__init__.py +20 -0
  1536. vismatch/third_party/gim/gluefactory/eval/eth3d.py +202 -0
  1537. vismatch/third_party/gim/gluefactory/eval/eval_pipeline.py +109 -0
  1538. vismatch/third_party/gim/gluefactory/eval/hpatches.py +203 -0
  1539. vismatch/third_party/gim/gluefactory/eval/inspect.py +61 -0
  1540. vismatch/third_party/gim/gluefactory/eval/io.py +109 -0
  1541. vismatch/third_party/gim/gluefactory/eval/megadepth1500.py +189 -0
  1542. vismatch/third_party/gim/gluefactory/eval/utils.py +272 -0
  1543. vismatch/third_party/gim/gluefactory/geometry/depth.py +88 -0
  1544. vismatch/third_party/gim/gluefactory/geometry/epipolar.py +155 -0
  1545. vismatch/third_party/gim/gluefactory/geometry/gt_generation.py +558 -0
  1546. vismatch/third_party/gim/gluefactory/geometry/homography.py +342 -0
  1547. vismatch/third_party/gim/gluefactory/geometry/utils.py +167 -0
  1548. vismatch/third_party/gim/gluefactory/geometry/wrappers.py +425 -0
  1549. vismatch/third_party/gim/gluefactory/models/__init__.py +30 -0
  1550. vismatch/third_party/gim/gluefactory/models/backbones/__init__.py +0 -0
  1551. vismatch/third_party/gim/gluefactory/models/backbones/dinov2.py +30 -0
  1552. vismatch/third_party/gim/gluefactory/models/base_model.py +157 -0
  1553. vismatch/third_party/gim/gluefactory/models/cache_loader.py +139 -0
  1554. vismatch/third_party/gim/gluefactory/models/extractors/__init__.py +0 -0
  1555. vismatch/third_party/gim/gluefactory/models/extractors/aliked.py +786 -0
  1556. vismatch/third_party/gim/gluefactory/models/extractors/disk_kornia.py +108 -0
  1557. vismatch/third_party/gim/gluefactory/models/extractors/grid_extractor.py +60 -0
  1558. vismatch/third_party/gim/gluefactory/models/extractors/keynet_affnet_hardnet.py +74 -0
  1559. vismatch/third_party/gim/gluefactory/models/extractors/mixed.py +76 -0
  1560. vismatch/third_party/gim/gluefactory/models/extractors/sift.py +234 -0
  1561. vismatch/third_party/gim/gluefactory/models/extractors/sift_kornia.py +46 -0
  1562. vismatch/third_party/gim/gluefactory/models/extractors/superpoint_open.py +210 -0
  1563. vismatch/third_party/gim/gluefactory/models/lines/__init__.py +0 -0
  1564. vismatch/third_party/gim/gluefactory/models/lines/deeplsd.py +106 -0
  1565. vismatch/third_party/gim/gluefactory/models/lines/lsd.py +88 -0
  1566. vismatch/third_party/gim/gluefactory/models/lines/wireframe.py +312 -0
  1567. vismatch/third_party/gim/gluefactory/models/matchers/__init__.py +0 -0
  1568. vismatch/third_party/gim/gluefactory/models/matchers/adalam.py +0 -0
  1569. vismatch/third_party/gim/gluefactory/models/matchers/depth_matcher.py +82 -0
  1570. vismatch/third_party/gim/gluefactory/models/matchers/gluestick.py +776 -0
  1571. vismatch/third_party/gim/gluefactory/models/matchers/homography_matcher.py +66 -0
  1572. vismatch/third_party/gim/gluefactory/models/matchers/kornia_loftr.py +66 -0
  1573. vismatch/third_party/gim/gluefactory/models/matchers/lightglue.py +632 -0
  1574. vismatch/third_party/gim/gluefactory/models/matchers/lightglue_pretrained.py +36 -0
  1575. vismatch/third_party/gim/gluefactory/models/matchers/nearest_neighbor_matcher.py +97 -0
  1576. vismatch/third_party/gim/gluefactory/models/triplet_pipeline.py +99 -0
  1577. vismatch/third_party/gim/gluefactory/models/two_view_pipeline.py +114 -0
  1578. vismatch/third_party/gim/gluefactory/models/utils/__init__.py +0 -0
  1579. vismatch/third_party/gim/gluefactory/models/utils/losses.py +73 -0
  1580. vismatch/third_party/gim/gluefactory/models/utils/metrics.py +50 -0
  1581. vismatch/third_party/gim/gluefactory/models/utils/misc.py +70 -0
  1582. vismatch/third_party/gim/gluefactory/robust_estimators/__init__.py +15 -0
  1583. vismatch/third_party/gim/gluefactory/robust_estimators/base_estimator.py +33 -0
  1584. vismatch/third_party/gim/gluefactory/robust_estimators/homography/__init__.py +0 -0
  1585. vismatch/third_party/gim/gluefactory/robust_estimators/homography/homography_est.py +74 -0
  1586. vismatch/third_party/gim/gluefactory/robust_estimators/homography/opencv.py +53 -0
  1587. vismatch/third_party/gim/gluefactory/robust_estimators/homography/poselib.py +40 -0
  1588. vismatch/third_party/gim/gluefactory/robust_estimators/relative_pose/__init__.py +0 -0
  1589. vismatch/third_party/gim/gluefactory/robust_estimators/relative_pose/opencv.py +64 -0
  1590. vismatch/third_party/gim/gluefactory/robust_estimators/relative_pose/poselib.py +44 -0
  1591. vismatch/third_party/gim/gluefactory/robust_estimators/relative_pose/pycolmap.py +52 -0
  1592. vismatch/third_party/gim/gluefactory/scripts/__init__.py +0 -0
  1593. vismatch/third_party/gim/gluefactory/scripts/export_local_features.py +127 -0
  1594. vismatch/third_party/gim/gluefactory/scripts/export_megadepth.py +173 -0
  1595. vismatch/third_party/gim/gluefactory/settings.py +6 -0
  1596. vismatch/third_party/gim/gluefactory/superpoint.py +361 -0
  1597. vismatch/third_party/gim/gluefactory/train.py +691 -0
  1598. vismatch/third_party/gim/gluefactory/utils/__init__.py +0 -0
  1599. vismatch/third_party/gim/gluefactory/utils/benchmark.py +33 -0
  1600. vismatch/third_party/gim/gluefactory/utils/experiments.py +134 -0
  1601. vismatch/third_party/gim/gluefactory/utils/export_predictions.py +81 -0
  1602. vismatch/third_party/gim/gluefactory/utils/image.py +130 -0
  1603. vismatch/third_party/gim/gluefactory/utils/misc.py +44 -0
  1604. vismatch/third_party/gim/gluefactory/utils/patches.py +50 -0
  1605. vismatch/third_party/gim/gluefactory/utils/stdout_capturing.py +134 -0
  1606. vismatch/third_party/gim/gluefactory/utils/tensor.py +48 -0
  1607. vismatch/third_party/gim/gluefactory/utils/tools.py +269 -0
  1608. vismatch/third_party/gim/gluefactory/visualization/global_frame.py +289 -0
  1609. vismatch/third_party/gim/gluefactory/visualization/tools.py +465 -0
  1610. vismatch/third_party/gim/gluefactory/visualization/two_view_frame.py +158 -0
  1611. vismatch/third_party/gim/gluefactory/visualization/visualize_batch.py +57 -0
  1612. vismatch/third_party/gim/gluefactory/visualization/viz2d.py +486 -0
  1613. vismatch/third_party/imatch-toolbox/configs/d2net.yml +26 -0
  1614. vismatch/third_party/imatch-toolbox/configs/dogaffnethardnet.yml +10 -0
  1615. vismatch/third_party/imatch-toolbox/configs/ncnet.yml +7 -0
  1616. vismatch/third_party/imatch-toolbox/configs/patch2pix.yml +56 -0
  1617. vismatch/third_party/imatch-toolbox/configs/patch2pix_superglue.yml +58 -0
  1618. vismatch/third_party/imatch-toolbox/configs/r2d2.yml +31 -0
  1619. vismatch/third_party/imatch-toolbox/configs/sift.yml +27 -0
  1620. vismatch/third_party/imatch-toolbox/configs/superglue.yml +69 -0
  1621. vismatch/third_party/imatch-toolbox/configs/superpoint.yml +21 -0
  1622. vismatch/third_party/imatch-toolbox/environment.yml +14 -0
  1623. vismatch/third_party/imatch-toolbox/immatch/__init__.py +8 -0
  1624. vismatch/third_party/imatch-toolbox/immatch/eval_aachen.py +88 -0
  1625. vismatch/third_party/imatch-toolbox/immatch/eval_hpatches.py +117 -0
  1626. vismatch/third_party/imatch-toolbox/immatch/eval_inloc.py +45 -0
  1627. vismatch/third_party/imatch-toolbox/immatch/eval_relapose.py +231 -0
  1628. vismatch/third_party/imatch-toolbox/immatch/eval_robotcar.py +83 -0
  1629. vismatch/third_party/imatch-toolbox/immatch/modules/__init__.py +0 -0
  1630. vismatch/third_party/imatch-toolbox/immatch/modules/base.py +89 -0
  1631. vismatch/third_party/imatch-toolbox/immatch/modules/d2net.py +69 -0
  1632. vismatch/third_party/imatch-toolbox/immatch/modules/dogaffnethardnet.py +94 -0
  1633. vismatch/third_party/imatch-toolbox/immatch/modules/nn_matching.py +31 -0
  1634. vismatch/third_party/imatch-toolbox/immatch/modules/patch2pix.py +126 -0
  1635. vismatch/third_party/imatch-toolbox/immatch/modules/r2d2.py +64 -0
  1636. vismatch/third_party/imatch-toolbox/immatch/modules/sift.py +67 -0
  1637. vismatch/third_party/imatch-toolbox/immatch/modules/superglue.py +62 -0
  1638. vismatch/third_party/imatch-toolbox/immatch/modules/superpoint.py +56 -0
  1639. vismatch/third_party/imatch-toolbox/immatch/utils/__init__.py +13 -0
  1640. vismatch/third_party/imatch-toolbox/immatch/utils/colmap/data_parsing.py +257 -0
  1641. vismatch/third_party/imatch-toolbox/immatch/utils/colmap/database.py +362 -0
  1642. vismatch/third_party/imatch-toolbox/immatch/utils/colmap/read_write_model.py +506 -0
  1643. vismatch/third_party/imatch-toolbox/immatch/utils/data_io.py +111 -0
  1644. vismatch/third_party/imatch-toolbox/immatch/utils/hpatches_helper.py +242 -0
  1645. vismatch/third_party/imatch-toolbox/immatch/utils/localize_sfm_helper.py +403 -0
  1646. vismatch/third_party/imatch-toolbox/immatch/utils/metrics.py +90 -0
  1647. vismatch/third_party/imatch-toolbox/immatch/utils/model_helper.py +27 -0
  1648. vismatch/third_party/imatch-toolbox/setup.py +36 -0
  1649. vismatch/third_party/imatch-toolbox/third_party/d2net/extract_features.py +156 -0
  1650. vismatch/third_party/imatch-toolbox/third_party/d2net/extract_kapture.py +248 -0
  1651. vismatch/third_party/imatch-toolbox/third_party/d2net/lib/dataset.py +239 -0
  1652. vismatch/third_party/imatch-toolbox/third_party/d2net/lib/exceptions.py +6 -0
  1653. vismatch/third_party/imatch-toolbox/third_party/d2net/lib/loss.py +340 -0
  1654. vismatch/third_party/imatch-toolbox/third_party/d2net/lib/model.py +121 -0
  1655. vismatch/third_party/imatch-toolbox/third_party/d2net/lib/model_test.py +187 -0
  1656. vismatch/third_party/imatch-toolbox/third_party/d2net/lib/pyramid.py +129 -0
  1657. vismatch/third_party/imatch-toolbox/third_party/d2net/lib/utils.py +167 -0
  1658. vismatch/third_party/imatch-toolbox/third_party/d2net/megadepth_utils/preprocess_scene.py +242 -0
  1659. vismatch/third_party/imatch-toolbox/third_party/d2net/megadepth_utils/undistort_reconstructions.py +69 -0
  1660. vismatch/third_party/imatch-toolbox/third_party/d2net/train.py +279 -0
  1661. vismatch/third_party/imatch-toolbox/third_party/patch2pix/data_pairs/precompute_immatch_val_ovs.py +20 -0
  1662. vismatch/third_party/imatch-toolbox/third_party/patch2pix/environment.yml +21 -0
  1663. vismatch/third_party/imatch-toolbox/third_party/patch2pix/networks/modules.py +167 -0
  1664. vismatch/third_party/imatch-toolbox/third_party/patch2pix/networks/ncn/conv4d.py +91 -0
  1665. vismatch/third_party/imatch-toolbox/third_party/patch2pix/networks/ncn/extract_ncmatches.py +158 -0
  1666. vismatch/third_party/imatch-toolbox/third_party/patch2pix/networks/ncn/model.py +333 -0
  1667. vismatch/third_party/imatch-toolbox/third_party/patch2pix/networks/patch2pix.py +403 -0
  1668. vismatch/third_party/imatch-toolbox/third_party/patch2pix/networks/resnet.py +191 -0
  1669. vismatch/third_party/imatch-toolbox/third_party/patch2pix/networks/utils.py +111 -0
  1670. vismatch/third_party/imatch-toolbox/third_party/patch2pix/train_patch2pix.py +374 -0
  1671. vismatch/third_party/imatch-toolbox/third_party/patch2pix/utils/colmap/data_loading.py +169 -0
  1672. vismatch/third_party/imatch-toolbox/third_party/patch2pix/utils/colmap/read_database.py +175 -0
  1673. vismatch/third_party/imatch-toolbox/third_party/patch2pix/utils/colmap/read_write_model.py +483 -0
  1674. vismatch/third_party/imatch-toolbox/third_party/patch2pix/utils/common/plotting.py +393 -0
  1675. vismatch/third_party/imatch-toolbox/third_party/patch2pix/utils/common/setup_helper.py +59 -0
  1676. vismatch/third_party/imatch-toolbox/third_party/patch2pix/utils/common/visdom_helper.py +95 -0
  1677. vismatch/third_party/imatch-toolbox/third_party/patch2pix/utils/datasets/__init__.py +1 -0
  1678. vismatch/third_party/imatch-toolbox/third_party/patch2pix/utils/datasets/data_parsing.py +145 -0
  1679. vismatch/third_party/imatch-toolbox/third_party/patch2pix/utils/datasets/dataset_megadepth.py +141 -0
  1680. vismatch/third_party/imatch-toolbox/third_party/patch2pix/utils/datasets/preprocess.py +184 -0
  1681. vismatch/third_party/imatch-toolbox/third_party/patch2pix/utils/eval/geometry.py +90 -0
  1682. vismatch/third_party/imatch-toolbox/third_party/patch2pix/utils/eval/measure.py +161 -0
  1683. vismatch/third_party/imatch-toolbox/third_party/patch2pix/utils/eval/model_helper.py +129 -0
  1684. vismatch/third_party/imatch-toolbox/third_party/patch2pix/utils/train/eval_epoch_immatch.py +99 -0
  1685. vismatch/third_party/imatch-toolbox/third_party/patch2pix/utils/train/helper.py +196 -0
  1686. vismatch/third_party/imatch-toolbox/third_party/r2d2/datasets/__init__.py +33 -0
  1687. vismatch/third_party/imatch-toolbox/third_party/r2d2/datasets/aachen.py +146 -0
  1688. vismatch/third_party/imatch-toolbox/third_party/r2d2/datasets/dataset.py +77 -0
  1689. vismatch/third_party/imatch-toolbox/third_party/r2d2/datasets/imgfolder.py +23 -0
  1690. vismatch/third_party/imatch-toolbox/third_party/r2d2/datasets/pair_dataset.py +287 -0
  1691. vismatch/third_party/imatch-toolbox/third_party/r2d2/datasets/web_images.py +64 -0
  1692. vismatch/third_party/imatch-toolbox/third_party/r2d2/extract.py +183 -0
  1693. vismatch/third_party/imatch-toolbox/third_party/r2d2/extract_kapture.py +194 -0
  1694. vismatch/third_party/imatch-toolbox/third_party/r2d2/nets/ap_loss.py +67 -0
  1695. vismatch/third_party/imatch-toolbox/third_party/r2d2/nets/losses.py +56 -0
  1696. vismatch/third_party/imatch-toolbox/third_party/r2d2/nets/patchnet.py +134 -0
  1697. vismatch/third_party/imatch-toolbox/third_party/r2d2/nets/reliability_loss.py +59 -0
  1698. vismatch/third_party/imatch-toolbox/third_party/r2d2/nets/repeatability_loss.py +66 -0
  1699. vismatch/third_party/imatch-toolbox/third_party/r2d2/nets/sampler.py +390 -0
  1700. vismatch/third_party/imatch-toolbox/third_party/r2d2/tools/common.py +41 -0
  1701. vismatch/third_party/imatch-toolbox/third_party/r2d2/tools/dataloader.py +367 -0
  1702. vismatch/third_party/imatch-toolbox/third_party/r2d2/tools/trainer.py +76 -0
  1703. vismatch/third_party/imatch-toolbox/third_party/r2d2/tools/transforms.py +513 -0
  1704. vismatch/third_party/imatch-toolbox/third_party/r2d2/tools/transforms_tools.py +230 -0
  1705. vismatch/third_party/imatch-toolbox/third_party/r2d2/tools/viz.py +191 -0
  1706. vismatch/third_party/imatch-toolbox/third_party/r2d2/train.py +138 -0
  1707. vismatch/third_party/imatch-toolbox/third_party/r2d2/viz_heatmaps.py +122 -0
  1708. vismatch/third_party/imatch-toolbox/third_party/superglue/demo_superglue.py +259 -0
  1709. vismatch/third_party/imatch-toolbox/third_party/superglue/match_pairs.py +425 -0
  1710. vismatch/third_party/imatch-toolbox/third_party/superglue/models/__init__.py +0 -0
  1711. vismatch/third_party/imatch-toolbox/third_party/superglue/models/matching.py +84 -0
  1712. vismatch/third_party/imatch-toolbox/third_party/superglue/models/superglue.py +283 -0
  1713. vismatch/third_party/imatch-toolbox/third_party/superglue/models/superpoint.py +202 -0
  1714. vismatch/third_party/imatch-toolbox/third_party/superglue/models/utils.py +555 -0
  1715. vismatch/third_party/keypt2subpx/dataprocess/aliked.py +163 -0
  1716. vismatch/third_party/keypt2subpx/dataprocess/dedode.py +215 -0
  1717. vismatch/third_party/keypt2subpx/dataprocess/splg.py +162 -0
  1718. vismatch/third_party/keypt2subpx/dataprocess/spnn.py +157 -0
  1719. vismatch/third_party/keypt2subpx/dataprocess/superpoint_densescore.py +357 -0
  1720. vismatch/third_party/keypt2subpx/dataprocess/xfeat.py +187 -0
  1721. vismatch/third_party/keypt2subpx/dataset.py +145 -0
  1722. vismatch/third_party/keypt2subpx/hubconf.py +38 -0
  1723. vismatch/third_party/keypt2subpx/logger.py +127 -0
  1724. vismatch/third_party/keypt2subpx/model.py +183 -0
  1725. vismatch/third_party/keypt2subpx/settings.py +108 -0
  1726. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/__init__.py +17 -0
  1727. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/datasets/__init__.py +25 -0
  1728. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/datasets/augmentations.py +244 -0
  1729. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/datasets/base_dataset.py +206 -0
  1730. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/datasets/eth3d.py +254 -0
  1731. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/datasets/homographies.py +311 -0
  1732. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/datasets/hpatches.py +145 -0
  1733. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/datasets/image_folder.py +59 -0
  1734. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/datasets/image_pairs.py +100 -0
  1735. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/datasets/megadepth.py +510 -0
  1736. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/datasets/utils.py +131 -0
  1737. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/eval/__init__.py +20 -0
  1738. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/eval/eth3d.py +202 -0
  1739. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/eval/eval_pipeline.py +109 -0
  1740. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/eval/hpatches.py +203 -0
  1741. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/eval/inspect.py +61 -0
  1742. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/eval/io.py +109 -0
  1743. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/eval/megadepth1500.py +189 -0
  1744. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/eval/utils.py +272 -0
  1745. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/geometry/__init__.py +0 -0
  1746. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/geometry/depth.py +88 -0
  1747. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/geometry/epipolar.py +155 -0
  1748. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/geometry/gt_generation.py +558 -0
  1749. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/geometry/homography.py +342 -0
  1750. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/geometry/utils.py +167 -0
  1751. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/geometry/wrappers.py +425 -0
  1752. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/models/__init__.py +30 -0
  1753. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/models/backbones/__init__.py +0 -0
  1754. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/models/backbones/dinov2.py +30 -0
  1755. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/models/base_model.py +157 -0
  1756. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/models/cache_loader.py +139 -0
  1757. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/models/extractors/__init__.py +0 -0
  1758. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/models/extractors/aliked.py +786 -0
  1759. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/models/extractors/disk_kornia.py +108 -0
  1760. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/models/extractors/grid_extractor.py +60 -0
  1761. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/models/extractors/keynet_affnet_hardnet.py +74 -0
  1762. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/models/extractors/mixed.py +76 -0
  1763. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/models/extractors/sift.py +234 -0
  1764. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/models/extractors/sift_kornia.py +46 -0
  1765. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/models/extractors/superpoint_open.py +210 -0
  1766. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/models/lines/__init__.py +0 -0
  1767. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/models/lines/deeplsd.py +106 -0
  1768. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/models/lines/lsd.py +88 -0
  1769. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/models/lines/wireframe.py +312 -0
  1770. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/models/matchers/__init__.py +0 -0
  1771. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/models/matchers/adalam.py +0 -0
  1772. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/models/matchers/depth_matcher.py +82 -0
  1773. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/models/matchers/gluestick.py +776 -0
  1774. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/models/matchers/homography_matcher.py +66 -0
  1775. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/models/matchers/kornia_loftr.py +66 -0
  1776. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/models/matchers/lightglue.py +612 -0
  1777. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/models/matchers/lightglue_pretrained.py +36 -0
  1778. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/models/matchers/nearest_neighbor_matcher.py +97 -0
  1779. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/models/triplet_pipeline.py +99 -0
  1780. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/models/two_view_pipeline.py +114 -0
  1781. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/models/utils/__init__.py +0 -0
  1782. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/models/utils/losses.py +73 -0
  1783. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/models/utils/metrics.py +50 -0
  1784. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/models/utils/misc.py +70 -0
  1785. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/robust_estimators/__init__.py +15 -0
  1786. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/robust_estimators/base_estimator.py +33 -0
  1787. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/robust_estimators/homography/__init__.py +0 -0
  1788. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/robust_estimators/homography/homography_est.py +74 -0
  1789. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/robust_estimators/homography/opencv.py +53 -0
  1790. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/robust_estimators/homography/poselib.py +40 -0
  1791. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/robust_estimators/relative_pose/__init__.py +0 -0
  1792. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/robust_estimators/relative_pose/opencv.py +64 -0
  1793. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/robust_estimators/relative_pose/poselib.py +44 -0
  1794. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/robust_estimators/relative_pose/pycolmap.py +52 -0
  1795. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/scripts/__init__.py +0 -0
  1796. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/scripts/export_local_features.py +127 -0
  1797. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/scripts/export_megadepth.py +173 -0
  1798. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/settings.py +6 -0
  1799. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/train.py +691 -0
  1800. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/utils/__init__.py +0 -0
  1801. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/utils/benchmark.py +33 -0
  1802. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/utils/experiments.py +134 -0
  1803. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/utils/export_predictions.py +81 -0
  1804. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/utils/image.py +130 -0
  1805. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/utils/misc.py +44 -0
  1806. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/utils/patches.py +50 -0
  1807. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/utils/stdout_capturing.py +134 -0
  1808. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/utils/tensor.py +48 -0
  1809. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/utils/tools.py +269 -0
  1810. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/visualization/global_frame.py +289 -0
  1811. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/visualization/tools.py +465 -0
  1812. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/visualization/two_view_frame.py +158 -0
  1813. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/visualization/visualize_batch.py +57 -0
  1814. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory/visualization/viz2d.py +486 -0
  1815. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory_nonfree/__init__.py +0 -0
  1816. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory_nonfree/superglue.py +342 -0
  1817. vismatch/third_party/keypt2subpx/submodules/glue_factory/gluefactory_nonfree/superpoint.py +356 -0
  1818. vismatch/third_party/keypt2subpx/submodules/glue_factory/tests/__init__.py +0 -0
  1819. vismatch/third_party/keypt2subpx/submodules/glue_factory/tests/test_eval_utils.py +88 -0
  1820. vismatch/third_party/keypt2subpx/submodules/glue_factory/tests/test_integration.py +132 -0
  1821. vismatch/third_party/keypt2subpx/summarize.py +44 -0
  1822. vismatch/third_party/keypt2subpx/test.py +225 -0
  1823. vismatch/third_party/keypt2subpx/train.py +180 -0
  1824. vismatch/third_party/keypt2subpx/utils.py +150 -0
  1825. vismatch/third_party/mast3r/demo.py +51 -0
  1826. vismatch/third_party/mast3r/demo_dust3r_ga.py +99 -0
  1827. vismatch/third_party/mast3r/demo_glomap.py +52 -0
  1828. vismatch/third_party/mast3r/dust3r/croco/datasets/__init__.py +0 -0
  1829. vismatch/third_party/mast3r/dust3r/croco/datasets/crops/extract_crops_from_images.py +159 -0
  1830. vismatch/third_party/mast3r/dust3r/croco/datasets/habitat_sim/__init__.py +0 -0
  1831. vismatch/third_party/mast3r/dust3r/croco/datasets/habitat_sim/generate_from_metadata.py +92 -0
  1832. vismatch/third_party/mast3r/dust3r/croco/datasets/habitat_sim/generate_from_metadata_files.py +27 -0
  1833. vismatch/third_party/mast3r/dust3r/croco/datasets/habitat_sim/generate_multiview_images.py +177 -0
  1834. vismatch/third_party/mast3r/dust3r/croco/datasets/habitat_sim/multiview_habitat_sim_generator.py +390 -0
  1835. vismatch/third_party/mast3r/dust3r/croco/datasets/habitat_sim/pack_metadata_files.py +69 -0
  1836. vismatch/third_party/mast3r/dust3r/croco/datasets/habitat_sim/paths.py +129 -0
  1837. vismatch/third_party/mast3r/dust3r/croco/datasets/pairs_dataset.py +109 -0
  1838. vismatch/third_party/mast3r/dust3r/croco/datasets/transforms.py +95 -0
  1839. vismatch/third_party/mast3r/dust3r/croco/demo.py +55 -0
  1840. vismatch/third_party/mast3r/dust3r/croco/models/blocks.py +241 -0
  1841. vismatch/third_party/mast3r/dust3r/croco/models/criterion.py +37 -0
  1842. vismatch/third_party/mast3r/dust3r/croco/models/croco.py +249 -0
  1843. vismatch/third_party/mast3r/dust3r/croco/models/croco_downstream.py +122 -0
  1844. vismatch/third_party/mast3r/dust3r/croco/models/curope/__init__.py +4 -0
  1845. vismatch/third_party/mast3r/dust3r/croco/models/curope/curope2d.py +40 -0
  1846. vismatch/third_party/mast3r/dust3r/croco/models/curope/setup.py +34 -0
  1847. vismatch/third_party/mast3r/dust3r/croco/models/dpt_block.py +450 -0
  1848. vismatch/third_party/mast3r/dust3r/croco/models/head_downstream.py +58 -0
  1849. vismatch/third_party/mast3r/dust3r/croco/models/masking.py +25 -0
  1850. vismatch/third_party/mast3r/dust3r/croco/models/pos_embed.py +157 -0
  1851. vismatch/third_party/mast3r/dust3r/croco/pretrain.py +254 -0
  1852. vismatch/third_party/mast3r/dust3r/croco/stereoflow/augmentor.py +290 -0
  1853. vismatch/third_party/mast3r/dust3r/croco/stereoflow/criterion.py +251 -0
  1854. vismatch/third_party/mast3r/dust3r/croco/stereoflow/datasets_flow.py +630 -0
  1855. vismatch/third_party/mast3r/dust3r/croco/stereoflow/datasets_stereo.py +674 -0
  1856. vismatch/third_party/mast3r/dust3r/croco/stereoflow/engine.py +280 -0
  1857. vismatch/third_party/mast3r/dust3r/croco/stereoflow/test.py +216 -0
  1858. vismatch/third_party/mast3r/dust3r/croco/stereoflow/train.py +253 -0
  1859. vismatch/third_party/mast3r/dust3r/croco/utils/misc.py +463 -0
  1860. vismatch/third_party/mast3r/dust3r/datasets_preprocess/habitat/find_scenes.py +78 -0
  1861. vismatch/third_party/mast3r/dust3r/datasets_preprocess/habitat/habitat_renderer/__init__.py +2 -0
  1862. vismatch/third_party/mast3r/dust3r/datasets_preprocess/habitat/habitat_renderer/habitat_sim_envmaps_renderer.py +170 -0
  1863. vismatch/third_party/mast3r/dust3r/datasets_preprocess/habitat/habitat_renderer/multiview_crop_generator.py +93 -0
  1864. vismatch/third_party/mast3r/dust3r/datasets_preprocess/habitat/habitat_renderer/projections.py +151 -0
  1865. vismatch/third_party/mast3r/dust3r/datasets_preprocess/habitat/habitat_renderer/projections_conversions.py +45 -0
  1866. vismatch/third_party/mast3r/dust3r/datasets_preprocess/habitat/preprocess_habitat.py +121 -0
  1867. vismatch/third_party/mast3r/dust3r/datasets_preprocess/path_to_root.py +13 -0
  1868. vismatch/third_party/mast3r/dust3r/datasets_preprocess/preprocess_arkitscenes.py +355 -0
  1869. vismatch/third_party/mast3r/dust3r/datasets_preprocess/preprocess_blendedMVS.py +149 -0
  1870. vismatch/third_party/mast3r/dust3r/datasets_preprocess/preprocess_co3d.py +295 -0
  1871. vismatch/third_party/mast3r/dust3r/datasets_preprocess/preprocess_megadepth.py +198 -0
  1872. vismatch/third_party/mast3r/dust3r/datasets_preprocess/preprocess_scannetpp.py +390 -0
  1873. vismatch/third_party/mast3r/dust3r/datasets_preprocess/preprocess_staticthings3d.py +130 -0
  1874. vismatch/third_party/mast3r/dust3r/datasets_preprocess/preprocess_waymo.py +257 -0
  1875. vismatch/third_party/mast3r/dust3r/datasets_preprocess/preprocess_wildrgbd.py +209 -0
  1876. vismatch/third_party/mast3r/dust3r/demo.py +45 -0
  1877. vismatch/third_party/mast3r/dust3r/dust3r/__init__.py +2 -0
  1878. vismatch/third_party/mast3r/dust3r/dust3r/cloud_opt/__init__.py +33 -0
  1879. vismatch/third_party/mast3r/dust3r/dust3r/cloud_opt/base_opt.py +405 -0
  1880. vismatch/third_party/mast3r/dust3r/dust3r/cloud_opt/commons.py +90 -0
  1881. vismatch/third_party/mast3r/dust3r/dust3r/cloud_opt/init_im_poses.py +316 -0
  1882. vismatch/third_party/mast3r/dust3r/dust3r/cloud_opt/modular_optimizer.py +145 -0
  1883. vismatch/third_party/mast3r/dust3r/dust3r/cloud_opt/optimizer.py +248 -0
  1884. vismatch/third_party/mast3r/dust3r/dust3r/cloud_opt/pair_viewer.py +127 -0
  1885. vismatch/third_party/mast3r/dust3r/dust3r/datasets/__init__.py +50 -0
  1886. vismatch/third_party/mast3r/dust3r/dust3r/datasets/arkitscenes.py +102 -0
  1887. vismatch/third_party/mast3r/dust3r/dust3r/datasets/base/__init__.py +2 -0
  1888. vismatch/third_party/mast3r/dust3r/dust3r/datasets/base/base_stereo_view_dataset.py +220 -0
  1889. vismatch/third_party/mast3r/dust3r/dust3r/datasets/base/batched_sampler.py +74 -0
  1890. vismatch/third_party/mast3r/dust3r/dust3r/datasets/base/easy_dataset.py +157 -0
  1891. vismatch/third_party/mast3r/dust3r/dust3r/datasets/blendedmvs.py +104 -0
  1892. vismatch/third_party/mast3r/dust3r/dust3r/datasets/co3d.py +165 -0
  1893. vismatch/third_party/mast3r/dust3r/dust3r/datasets/habitat.py +107 -0
  1894. vismatch/third_party/mast3r/dust3r/dust3r/datasets/megadepth.py +123 -0
  1895. vismatch/third_party/mast3r/dust3r/dust3r/datasets/scannetpp.py +96 -0
  1896. vismatch/third_party/mast3r/dust3r/dust3r/datasets/staticthings3d.py +96 -0
  1897. vismatch/third_party/mast3r/dust3r/dust3r/datasets/utils/__init__.py +2 -0
  1898. vismatch/third_party/mast3r/dust3r/dust3r/datasets/utils/cropping.py +124 -0
  1899. vismatch/third_party/mast3r/dust3r/dust3r/datasets/utils/transforms.py +11 -0
  1900. vismatch/third_party/mast3r/dust3r/dust3r/datasets/waymo.py +93 -0
  1901. vismatch/third_party/mast3r/dust3r/dust3r/datasets/wildrgbd.py +67 -0
  1902. vismatch/third_party/mast3r/dust3r/dust3r/demo.py +287 -0
  1903. vismatch/third_party/mast3r/dust3r/dust3r/heads/__init__.py +19 -0
  1904. vismatch/third_party/mast3r/dust3r/dust3r/heads/dpt_head.py +115 -0
  1905. vismatch/third_party/mast3r/dust3r/dust3r/heads/linear_head.py +41 -0
  1906. vismatch/third_party/mast3r/dust3r/dust3r/heads/postprocess.py +58 -0
  1907. vismatch/third_party/mast3r/dust3r/dust3r/image_pairs.py +104 -0
  1908. vismatch/third_party/mast3r/dust3r/dust3r/inference.py +150 -0
  1909. vismatch/third_party/mast3r/dust3r/dust3r/losses.py +299 -0
  1910. vismatch/third_party/mast3r/dust3r/dust3r/model.py +211 -0
  1911. vismatch/third_party/mast3r/dust3r/dust3r/optim_factory.py +14 -0
  1912. vismatch/third_party/mast3r/dust3r/dust3r/patch_embed.py +70 -0
  1913. vismatch/third_party/mast3r/dust3r/dust3r/post_process.py +60 -0
  1914. vismatch/third_party/mast3r/dust3r/dust3r/training.py +377 -0
  1915. vismatch/third_party/mast3r/dust3r/dust3r/utils/__init__.py +2 -0
  1916. vismatch/third_party/mast3r/dust3r/dust3r/utils/device.py +76 -0
  1917. vismatch/third_party/mast3r/dust3r/dust3r/utils/geometry.py +366 -0
  1918. vismatch/third_party/mast3r/dust3r/dust3r/utils/image.py +128 -0
  1919. vismatch/third_party/mast3r/dust3r/dust3r/utils/misc.py +121 -0
  1920. vismatch/third_party/mast3r/dust3r/dust3r/utils/parallel.py +79 -0
  1921. vismatch/third_party/mast3r/dust3r/dust3r/utils/path_to_croco.py +19 -0
  1922. vismatch/third_party/mast3r/dust3r/dust3r/viz.py +381 -0
  1923. vismatch/third_party/mast3r/dust3r/dust3r_visloc/__init__.py +2 -0
  1924. vismatch/third_party/mast3r/dust3r/dust3r_visloc/datasets/__init__.py +6 -0
  1925. vismatch/third_party/mast3r/dust3r/dust3r_visloc/datasets/aachen_day_night.py +24 -0
  1926. vismatch/third_party/mast3r/dust3r/dust3r_visloc/datasets/base_colmap.py +282 -0
  1927. vismatch/third_party/mast3r/dust3r/dust3r_visloc/datasets/base_dataset.py +19 -0
  1928. vismatch/third_party/mast3r/dust3r/dust3r_visloc/datasets/cambridge_landmarks.py +19 -0
  1929. vismatch/third_party/mast3r/dust3r/dust3r_visloc/datasets/inloc.py +167 -0
  1930. vismatch/third_party/mast3r/dust3r/dust3r_visloc/datasets/sevenscenes.py +123 -0
  1931. vismatch/third_party/mast3r/dust3r/dust3r_visloc/datasets/utils.py +118 -0
  1932. vismatch/third_party/mast3r/dust3r/dust3r_visloc/evaluation.py +65 -0
  1933. vismatch/third_party/mast3r/dust3r/dust3r_visloc/localization.py +140 -0
  1934. vismatch/third_party/mast3r/dust3r/train.py +13 -0
  1935. vismatch/third_party/mast3r/dust3r/visloc.py +193 -0
  1936. vismatch/third_party/mast3r/kapture_mast3r_mapping.py +127 -0
  1937. vismatch/third_party/mast3r/make_pairs.py +105 -0
  1938. vismatch/third_party/mast3r/mast3r/__init__.py +2 -0
  1939. vismatch/third_party/mast3r/mast3r/catmlp_dpt_head.py +239 -0
  1940. vismatch/third_party/mast3r/mast3r/cloud_opt/__init__.py +2 -0
  1941. vismatch/third_party/mast3r/mast3r/cloud_opt/sparse_ga.py +1078 -0
  1942. vismatch/third_party/mast3r/mast3r/cloud_opt/triangulation.py +80 -0
  1943. vismatch/third_party/mast3r/mast3r/cloud_opt/tsdf_optimizer.py +273 -0
  1944. vismatch/third_party/mast3r/mast3r/cloud_opt/utils/__init__.py +2 -0
  1945. vismatch/third_party/mast3r/mast3r/cloud_opt/utils/losses.py +32 -0
  1946. vismatch/third_party/mast3r/mast3r/cloud_opt/utils/schedules.py +17 -0
  1947. vismatch/third_party/mast3r/mast3r/colmap/__init__.py +2 -0
  1948. vismatch/third_party/mast3r/mast3r/colmap/database.py +383 -0
  1949. vismatch/third_party/mast3r/mast3r/colmap/mapping.py +196 -0
  1950. vismatch/third_party/mast3r/mast3r/datasets/__init__.py +62 -0
  1951. vismatch/third_party/mast3r/mast3r/datasets/base/__init__.py +2 -0
  1952. vismatch/third_party/mast3r/mast3r/datasets/base/mast3r_base_stereo_view_dataset.py +355 -0
  1953. vismatch/third_party/mast3r/mast3r/datasets/utils/__init__.py +2 -0
  1954. vismatch/third_party/mast3r/mast3r/datasets/utils/cropping.py +219 -0
  1955. vismatch/third_party/mast3r/mast3r/demo.py +381 -0
  1956. vismatch/third_party/mast3r/mast3r/demo_glomap.py +343 -0
  1957. vismatch/third_party/mast3r/mast3r/fast_nn.py +223 -0
  1958. vismatch/third_party/mast3r/mast3r/image_pairs.py +115 -0
  1959. vismatch/third_party/mast3r/mast3r/losses.py +508 -0
  1960. vismatch/third_party/mast3r/mast3r/model.py +213 -0
  1961. vismatch/third_party/mast3r/mast3r/retrieval/graph.py +77 -0
  1962. vismatch/third_party/mast3r/mast3r/retrieval/model.py +271 -0
  1963. vismatch/third_party/mast3r/mast3r/retrieval/processor.py +129 -0
  1964. vismatch/third_party/mast3r/mast3r/utils/__init__.py +2 -0
  1965. vismatch/third_party/mast3r/mast3r/utils/coarse_to_fine.py +214 -0
  1966. vismatch/third_party/mast3r/mast3r/utils/collate.py +62 -0
  1967. vismatch/third_party/mast3r/mast3r/utils/misc.py +17 -0
  1968. vismatch/third_party/mast3r/mast3r/utils/path_to_dust3r.py +19 -0
  1969. vismatch/third_party/mast3r/train.py +48 -0
  1970. vismatch/third_party/mast3r/visloc.py +538 -0
  1971. vismatch/third_party/omniglue/__init__.py +19 -0
  1972. vismatch/third_party/omniglue/demo.py +89 -0
  1973. vismatch/third_party/omniglue/src/omniglue/__init__.py +17 -0
  1974. vismatch/third_party/omniglue/src/omniglue/dino_extract.py +215 -0
  1975. vismatch/third_party/omniglue/src/omniglue/omniglue_extract.py +159 -0
  1976. vismatch/third_party/omniglue/src/omniglue/superpoint_extract.py +214 -0
  1977. vismatch/third_party/omniglue/src/omniglue/utils.py +274 -0
  1978. vismatch/third_party/omniglue/third_party/dinov2/__init__.py +0 -0
  1979. vismatch/third_party/omniglue/third_party/dinov2/dino.py +411 -0
  1980. vismatch/third_party/omniglue/third_party/dinov2/dino_utils.py +341 -0
  1981. vismatch/third_party/rdd/RDD/RDD.py +262 -0
  1982. vismatch/third_party/rdd/RDD/RDD_helper.py +181 -0
  1983. vismatch/third_party/rdd/RDD/dataset/__init__.py +0 -0
  1984. vismatch/third_party/rdd/RDD/dataset/megadepth/__init__.py +2 -0
  1985. vismatch/third_party/rdd/RDD/dataset/megadepth/megadepth.py +313 -0
  1986. vismatch/third_party/rdd/RDD/dataset/megadepth/megadepth_warper.py +75 -0
  1987. vismatch/third_party/rdd/RDD/dataset/megadepth/utils.py +848 -0
  1988. vismatch/third_party/rdd/RDD/matchers/__init__.py +3 -0
  1989. vismatch/third_party/rdd/RDD/matchers/dense_matcher.py +137 -0
  1990. vismatch/third_party/rdd/RDD/matchers/dual_softmax_matcher.py +31 -0
  1991. vismatch/third_party/rdd/RDD/matchers/lightglue.py +667 -0
  1992. vismatch/third_party/rdd/RDD/models/backbone.py +147 -0
  1993. vismatch/third_party/rdd/RDD/models/deformable_transformer.py +270 -0
  1994. vismatch/third_party/rdd/RDD/models/descriptor.py +116 -0
  1995. vismatch/third_party/rdd/RDD/models/detector.py +141 -0
  1996. vismatch/third_party/rdd/RDD/models/interpolator.py +33 -0
  1997. vismatch/third_party/rdd/RDD/models/ops/functions/__init__.py +13 -0
  1998. vismatch/third_party/rdd/RDD/models/ops/functions/ms_deform_attn_func.py +74 -0
  1999. vismatch/third_party/rdd/RDD/models/ops/modules/__init__.py +12 -0
  2000. vismatch/third_party/rdd/RDD/models/ops/modules/ms_deform_attn.py +125 -0
  2001. vismatch/third_party/rdd/RDD/models/ops/setup.py +78 -0
  2002. vismatch/third_party/rdd/RDD/models/ops/test.py +92 -0
  2003. vismatch/third_party/rdd/RDD/models/position_encoding.py +48 -0
  2004. vismatch/third_party/rdd/RDD/models/soft_detect.py +176 -0
  2005. vismatch/third_party/rdd/RDD/utils/__init__.py +1 -0
  2006. vismatch/third_party/rdd/RDD/utils/misc.py +531 -0
  2007. vismatch/third_party/rdd/benchmarks/air_ground.py +250 -0
  2008. vismatch/third_party/rdd/benchmarks/mega_1500.py +259 -0
  2009. vismatch/third_party/rdd/benchmarks/mega_view.py +252 -0
  2010. vismatch/third_party/rdd/benchmarks/scannet_1500.py +251 -0
  2011. vismatch/third_party/rdd/benchmarks/utils.py +112 -0
  2012. vismatch/third_party/rdd/configs/default.yaml +19 -0
  2013. vismatch/third_party/rdd/sfm/extract_rdd.py +145 -0
  2014. vismatch/third_party/rdd/sfm/match_rdd.py +259 -0
  2015. vismatch/third_party/rdd/third_party/LightGlue/.github/workflows/code-quality.yml +24 -0
  2016. vismatch/third_party/rdd/third_party/LightGlue/benchmark.py +255 -0
  2017. vismatch/third_party/rdd/third_party/LightGlue/lightglue/__init__.py +7 -0
  2018. vismatch/third_party/rdd/third_party/LightGlue/lightglue/aliked.py +760 -0
  2019. vismatch/third_party/rdd/third_party/LightGlue/lightglue/disk.py +55 -0
  2020. vismatch/third_party/rdd/third_party/LightGlue/lightglue/dog_hardnet.py +41 -0
  2021. vismatch/third_party/rdd/third_party/LightGlue/lightglue/lightglue.py +662 -0
  2022. vismatch/third_party/rdd/third_party/LightGlue/lightglue/sift.py +216 -0
  2023. vismatch/third_party/rdd/third_party/LightGlue/lightglue/superpoint.py +227 -0
  2024. vismatch/third_party/rdd/third_party/LightGlue/lightglue/utils.py +165 -0
  2025. vismatch/third_party/rdd/third_party/LightGlue/lightglue/viz2d.py +203 -0
  2026. vismatch/third_party/rdd/third_party/__init__.py +1 -0
  2027. vismatch/third_party/rdd/third_party/aliked_wrapper.py +17 -0
  2028. vismatch/third_party/rdd/training/losses/descriptor_loss.py +73 -0
  2029. vismatch/third_party/rdd/training/losses/detector_loss.py +499 -0
  2030. vismatch/third_party/rdd/training/train.py +473 -0
  2031. vismatch/third_party/rdd/training/utils.py +246 -0
  2032. vismatch/utils.py +390 -0
  2033. vismatch/viz.py +222 -0
  2034. vismatch-1.1.1.dist-info/METADATA +265 -0
  2035. vismatch-1.1.1.dist-info/RECORD +2042 -0
  2036. vismatch-1.1.1.dist-info/WHEEL +5 -0
  2037. vismatch-1.1.1.dist-info/entry_points.txt +4 -0
  2038. vismatch-1.1.1.dist-info/licenses/LICENSE +28 -0
  2039. vismatch-1.1.1.dist-info/top_level.txt +4 -0
  2040. vismatch_extract.py +103 -0
  2041. vismatch_match.py +114 -0
  2042. vismatch_test.py +186 -0
@@ -0,0 +1,1768 @@
1
+ import copy
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from .linear_attention import LinearAttention, RoPELinearAttention, FullAttention, XAttention
6
+ from einops.einops import rearrange
7
+ from collections import OrderedDict
8
+ from .transformer_utils import TokenConfidence, MatchAssignment, filter_matches
9
+ from ..utils.coarse_matching import CoarseMatching
10
+ from ..utils.position_encoding import RoPEPositionEncodingSine
11
+ import numpy as np
12
+ from loguru import logger
13
+
14
+ PFLASH_AVAILABLE = False  # When True, the encoder layers feed attention with [N, heads, L, D] tensors instead of [N, L, heads, D]; hard-coded off here.
15
+
16
class PANEncoderLayer(nn.Module):
    """Attention encoder layer that attends on a spatially pooled feature grid.

    The query map is reduced by ``pool_size`` (depth-wise strided conv when
    ``dw_conv`` is set, max-pooling otherwise), the source map is max-pooled,
    attention is computed between the two reduced grids, and the resulting
    message is upsampled back to the input resolution, concatenated with the
    input and passed through an MLP. The layer output is ``x + message``.

    Args:
        d_model: channel count of the feature maps.
        nhead: number of attention heads (``d_model // nhead`` channels each).
        attention: ``'linear'`` selects ``LinearAttention``; any other value
            selects ``FullAttention``. Ignored when ``xformer`` is true.
        pool_size: kernel size and stride of the spatial reduction.
        bn: append a ``BatchNorm2d`` to the depth-wise q/k/v projections.
        xformer: use ``XAttention`` instead of linear/full attention.
        leaky: negative slope for ``LeakyReLU`` in the MLP when ``> 0``;
            otherwise plain ``ReLU`` is used.
        dw_conv: reduce the query with a learned depth-wise conv.
        scatter: must be false; the constructor asserts on it.
    """

    def __init__(self,
                 d_model,
                 nhead,
                 attention='linear',
                 pool_size=4,
                 bn=True,
                 xformer=False,
                 leaky=-1.0,
                 dw_conv=False,
                 scatter=False,
                 ):
        super(PANEncoderLayer, self).__init__()

        self.pool_size = pool_size
        self.dw_conv = dw_conv
        self.scatter = scatter
        if self.dw_conv:
            self.aggregate = nn.Conv2d(d_model, d_model, kernel_size=pool_size, padding=0,
                                       stride=pool_size, bias=False, groups=d_model)

        assert not self.scatter, 'buggy implemented here'
        self.dim = d_model // nhead
        self.nhead = nhead

        self.max_pool = torch.nn.MaxPool2d(kernel_size=self.pool_size, stride=self.pool_size)

        # multi-head attention: depth-wise conv projections for q, k and v
        method = 'dw_bn' if bn else 'dw'
        self.q_proj_conv = self._build_projection(d_model, d_model, method=method)
        self.k_proj_conv = self._build_projection(d_model, d_model, method=method)
        self.v_proj_conv = self._build_projection(d_model, d_model, method=method)

        if xformer:
            self.attention = XAttention()
        else:
            self.attention = LinearAttention() if attention == 'linear' else FullAttention()
        self.merge = nn.Linear(d_model, d_model, bias=False)

        # feed-forward network; input is the channel-wise concat of x and the message
        activation = nn.LeakyReLU(leaky, True) if leaky > 0 else nn.ReLU(True)
        self.mlp = nn.Sequential(
            nn.Linear(d_model*2, d_model*2, bias=False),
            activation,
            nn.Linear(d_model*2, d_model, bias=False),
        )

        # norms
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    @staticmethod
    def _pad_message(message, height, width):
        """Zero-pad ``message`` [n, h, w, heads, d] at the bottom/right to [n, height, width, heads, d].

        Handles padding along rows, columns, or both at once; returns the
        input object unchanged when it already has the requested grid size.
        """
        pad_h = height - message.size(1)
        pad_w = width - message.size(2)
        if pad_h == 0 and pad_w == 0:
            return message
        # F.pad takes (left, right) pairs starting from the LAST dimension:
        # d, heads, w, h.
        return F.pad(message, (0, 0, 0, 0, 0, pad_w, 0, pad_h))

    def _pool_and_project(self, x, source):
        """Reduce both maps, layer-norm them over channels and apply the q/k/v projections."""
        pooled_x = self.aggregate(x) if self.dw_conv else self.max_pool(x)
        query = self.norm1(pooled_x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        # key and value share the same pooled + normalised source, so compute it once
        pooled_source = self.norm1(self.max_pool(source).permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        return (self.q_proj_conv(query),      # [N, C, H1//pool, W1//pool]
                self.k_proj_conv(pooled_source),
                self.v_proj_conv(pooled_source))

    def _attend(self, query, key, value):
        """Run attention on [n, C, h, w] maps and return the message as [n, L, heads, d]."""
        if PFLASH_AVAILABLE:  # N H L D
            layout = 'n (nhead d) h w -> n nhead (h w) d'
        else:  # N L H D
            layout = 'n (nhead d) h w -> n (h w) nhead d'
        query = rearrange(query, layout, nhead=self.nhead, d=self.dim)
        key = rearrange(key, layout, nhead=self.nhead, d=self.dim)
        value = rearrange(value, layout, nhead=self.nhead, d=self.dim)

        message = self.attention(query, key, value, q_mask=None, kv_mask=None)

        if PFLASH_AVAILABLE:  # back to N L H D
            message = rearrange(message, 'n nhead L d -> n L nhead d', nhead=self.nhead, d=self.dim)
        return message

    def forward(self, x, source, x_mask=None, source_mask=None):
        """
        Args:
            x (torch.Tensor): [N, C, H1, W1]
            source (torch.Tensor): [N, C, H2, W2]
            x_mask (torch.Tensor): [N, H1, W1] (optional) (L = H1*W1)
            source_mask (torch.Tensor): [N, H2, W2] (optional) (S = H2*W2)

        Returns:
            torch.Tensor: ``x + message``, [N, C, H1, W1]

        Masks must be given together or not at all. With masks, every sample
        is cropped to its own valid extent before attention and the message
        is zero-padded back to the pooled grid afterwards.
        NOTE(review): H1 and W1 are presumably multiples of ``pool_size``;
        otherwise the upsampled message would not match ``x`` -- confirm.
        """
        bs = x.size(0)
        H1, W1 = x.size(-2), x.size(-1)

        query, key, value = self._pool_and_project(x, source)
        C = query.shape[-3]

        ismask = x_mask is not None and source_mask is not None
        if ismask:
            x_mask = self.max_pool(x_mask.float()).bool()  # [N, H1//pool, W1//pool]
            source_mask = self.max_pool(source_mask.float()).bool()
            grid_h, grid_w = x_mask.size(-2), x_mask.size(-1)

            per_sample = []
            for i in range(bs):
                # Valid extent is read from the first column / first row of the
                # pooled mask (assumes a top-left aligned valid rectangle).
                mask_h0, mask_w0 = x_mask[i].sum(-2)[0], x_mask[i].sum(-1)[0]
                mask_h1, mask_w1 = source_mask[i].sum(-2)[0], source_mask[i].sum(-1)[0]

                m = self._attend(query[i:i+1, :, :mask_h0, :mask_w0],
                                 key[i:i+1, :, :mask_h1, :mask_w1],
                                 value[i:i+1, :, :mask_h1, :mask_w1])
                m = m.reshape(1, mask_h0, mask_w0, self.nhead, self.dim)
                per_sample.append(self._pad_message(m, grid_h, grid_w))
            message = torch.cat(per_sample, dim=0)
        else:
            assert x_mask is None and source_mask is None
            message = self._attend(query, key, value)

        message = self.merge(message.reshape(bs, -1, self.nhead*self.dim))  # [N, L, C]
        message = rearrange(message, 'b (h w) c -> b c h w',
                            h=H1//self.pool_size, w=W1//self.pool_size)  # [N, C, H, W]

        if self.scatter:
            # Disabled by the constructor assert. The historical behaviour is
            # kept as-is: the aggregate weights were applied on every path
            # except the batched (bs > 1) masked one.
            message = torch.repeat_interleave(message, self.pool_size, dim=-2)
            message = torch.repeat_interleave(message, self.pool_size, dim=-1)
            if bs == 1 or not ismask:
                weights = self.aggregate.weight.data.reshape(1, C, self.pool_size, self.pool_size)
                message = message * weights.repeat(1, 1,
                                                   message.shape[-2]//self.pool_size,
                                                   message.shape[-1]//self.pool_size)
        else:
            message = torch.nn.functional.interpolate(message, scale_factor=self.pool_size,
                                                      mode='bilinear', align_corners=False)  # [N, C, H1, W1]

        # feed-forward network
        message = self.mlp(torch.cat([x, message], dim=1).permute(0, 2, 3, 1))  # [N, H1, W1, C]
        message = self.norm2(message).permute(0, 3, 1, 2)  # [N, C, H1, W1]

        return x + message

    def pro(self, x, source, x_mask=None, source_mask=None, profiler=None):
        """Profiled variant of the layer: each stage runs inside ``profiler.profile(label)``.

        Args:
            x (torch.Tensor): [N, C, H1, W1]
            source (torch.Tensor): [N, C, H2, W2]
            x_mask (torch.Tensor): [N, H1, W1] (optional) (L = H1*W1)
            source_mask (torch.Tensor): [N, H2, W2] (optional) (S = H2*W2)
            profiler: object whose ``profile(label)`` returns a context
                manager. When ``None`` the stages run without profiling.

        Returns:
            torch.Tensor: ``x + message``, [N, C, H1, W1]

        Unlike ``forward``, the crop extents are taken from sample 0 for the
        whole batch and the ``PFLASH_AVAILABLE`` layout is not used.
        """
        if profiler is None:
            from contextlib import nullcontext

            def section(_label):
                return nullcontext()
        else:
            section = profiler.profile

        bs = x.size(0)
        H1, W1 = x.size(-2), x.size(-1)

        query, key, value = x, source, source

        with section("permute*6+norm1*3+max_pool*3"):
            if self.dw_conv:
                query = self.norm1(self.aggregate(query).permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
            else:
                query = self.norm1(self.max_pool(query).permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
            # key and value are computed separately here so the section label stays accurate
            key = self.norm1(self.max_pool(key).permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
            value = self.norm1(self.max_pool(value).permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

        with section("permute*6"):
            # Round-trip permutes with no net effect; they exist only so the
            # cost of the six permutes above can be timed on its own.
            query = query.permute(0, 2, 3, 1)
            key = key.permute(0, 2, 3, 1)
            value = value.permute(0, 2, 3, 1)

            query = query.permute(0, 3, 1, 2)
            key = key.permute(0, 3, 1, 2)
            value = value.permute(0, 3, 1, 2)

        # multi-head attention
        with section("q_conv+k_conv+v_conv"):
            query = self.q_proj_conv(query)  # [N, C, H1//pool, W1//pool]
            key = self.k_proj_conv(key)
            value = self.v_proj_conv(value)

        # TODO: Need to be consistent with bs=1 (where mask region do not in attention at all)
        has_mask = x_mask is not None and source_mask is not None
        if has_mask:
            x_mask = self.max_pool(x_mask.float()).bool()  # [N, H1//pool, W1//pool]
            source_mask = self.max_pool(source_mask.float()).bool()

            mask_h0, mask_w0 = x_mask[0].sum(-2)[0], x_mask[0].sum(-1)[0]
            mask_h1, mask_w1 = source_mask[0].sum(-2)[0], source_mask[0].sum(-1)[0]

            query = query[:, :, :mask_h0, :mask_w0]
            key = key[:, :, :mask_h1, :mask_w1]
            value = value[:, :, :mask_h1, :mask_w1]
        else:
            assert x_mask is None and source_mask is None

        with section("rearrange*3"):
            layout = 'n (nhead d) h w -> n (h w) nhead d'
            query = rearrange(query, layout, nhead=self.nhead, d=self.dim)  # [N, L, H, D]
            key = rearrange(key, layout, nhead=self.nhead, d=self.dim)  # [N, S, H, D]
            value = rearrange(value, layout, nhead=self.nhead, d=self.dim)  # [N, S, H, D]

        with section("attention"):
            message = self.attention(query, key, value, q_mask=None, kv_mask=None)  # [N, L, H, D]

        if has_mask:
            message = message.reshape(bs, mask_h0, mask_w0, self.nhead, self.dim)
            message = self._pad_message(message, x_mask.size(-2), x_mask.size(-1))

        with section("merge"):
            message = self.merge(message.reshape(bs, -1, self.nhead*self.dim))  # [N, L, C]

        with section("rearrange*1"):
            message = rearrange(message, 'b (h w) c -> b c h w',
                                h=H1//self.pool_size, w=W1//self.pool_size)  # [N, C, H, W]

        with section("upsample"):
            if self.scatter:
                message = torch.repeat_interleave(message, self.pool_size, dim=-2)
                message = torch.repeat_interleave(message, self.pool_size, dim=-1)
            else:
                message = torch.nn.functional.interpolate(message, scale_factor=self.pool_size,
                                                          mode='bilinear', align_corners=False)  # [N, C, H1, W1]

        # feed-forward network
        with section("feed-forward_mlp+permute*2+norm2"):
            message = self.mlp(torch.cat([x, message], dim=1).permute(0, 2, 3, 1))  # [N, H1, W1, C]
            message = self.norm2(message).permute(0, 3, 1, 2)  # [N, C, H1, W1]

        return x + message

    def _build_projection(self,
                          dim_in,
                          dim_out,
                          kernel_size=3,
                          padding=1,
                          stride=1,
                          method='dw_bn',
                          ):
        """Build the spatial projection applied to q, k or v.

        Args:
            dim_in: channel count; the depth-wise variants keep it unchanged.
            dim_out: accepted for interface compatibility; not used.
            kernel_size, padding, stride: forwarded to the conv / pooling layer.
            method: ``'dw_bn'`` (depth-wise conv + BatchNorm2d), ``'dw'``
                (depth-wise conv only), ``'avg'`` (average pooling) or
                ``'linear'`` (no module; returns ``None``).

        Returns:
            nn.Sequential with named children (``conv``/``bn`` or ``avg``),
            or ``None`` for ``'linear'``.

        Raises:
            ValueError: if ``method`` is not one of the names above.
        """
        if method in ('dw_bn', 'dw'):
            layers = [
                ('conv', nn.Conv2d(
                    dim_in,
                    dim_in,
                    kernel_size=kernel_size,
                    padding=padding,
                    stride=stride,
                    bias=False,
                    groups=dim_in
                )),
            ]
            if method == 'dw_bn':
                layers.append(('bn', nn.BatchNorm2d(dim_in)))
            return nn.Sequential(OrderedDict(layers))

        if method == 'avg':
            return nn.Sequential(OrderedDict([
                ('avg', nn.AvgPool2d(
                    kernel_size=kernel_size,
                    padding=padding,
                    stride=stride,
                    ceil_mode=True
                )),
            ]))

        if method == 'linear':
            return None

        raise ValueError('Unknown method ({})'.format(method))
416
+
417
+ class AG_RoPE_EncoderLayer(nn.Module):
418
+ def __init__(self,
419
+ d_model,
420
+ nhead,
421
+ attention='linear',
422
+ pool_size=4,
423
+ pool_size2=4,
424
+ xformer=False,
425
+ leaky=-1.0,
426
+ dw_conv=False,
427
+ dw_conv2=False,
428
+ scatter=False,
429
+ norm_before=True,
430
+ rope=False,
431
+ npe=None,
432
+ vit_norm=False,
433
+ dw_proj=False,
434
+ ):
435
+ super(AG_RoPE_EncoderLayer, self).__init__()
436
+
437
+ self.pool_size = pool_size
438
+ self.pool_size2 = pool_size2
439
+ self.dw_conv = dw_conv
440
+ self.dw_conv2 = dw_conv2
441
+ self.scatter = scatter
442
+ self.norm_before = norm_before
443
+ self.vit_norm = vit_norm
444
+ self.dw_proj = dw_proj
445
+ self.rope = rope
446
+ if self.dw_conv and self.pool_size != 1:
447
+ self.aggregate = nn.Conv2d(d_model, d_model, kernel_size=pool_size, padding=0, stride=pool_size, bias=False, groups=d_model)
448
+ if self.dw_conv2 and self.pool_size2 != 1:
449
+ self.aggregate2 = nn.Conv2d(d_model, d_model, kernel_size=pool_size2, padding=0, stride=pool_size2, bias=False, groups=d_model)
450
+
451
+ self.dim = d_model // nhead
452
+ self.nhead = nhead
453
+
454
+ self.max_pool = torch.nn.MaxPool2d(kernel_size=self.pool_size2, stride=self.pool_size2)
455
+
456
+ # multi-head attention
457
+ if self.dw_proj:
458
+ self.q_proj = nn.Conv2d(d_model, d_model, kernel_size=3, padding=1, stride=1, bias=False, groups=d_model)
459
+ self.k_proj = nn.Conv2d(d_model, d_model, kernel_size=3, padding=1, stride=1, bias=False, groups=d_model)
460
+ self.v_proj = nn.Conv2d(d_model, d_model, kernel_size=3, padding=1, stride=1, bias=False, groups=d_model)
461
+ else:
462
+ self.q_proj = nn.Linear(d_model, d_model, bias=False)
463
+ self.k_proj = nn.Linear(d_model, d_model, bias=False)
464
+ self.v_proj = nn.Linear(d_model, d_model, bias=False)
465
+
466
+ if self.rope:
467
+ self.rope_pos_enc = RoPEPositionEncodingSine(d_model, max_shape=(256, 256), npe=npe, ropefp16=True)
468
+
469
+ if xformer:
470
+ self.attention = XAttention()
471
+ else:
472
+ self.attention = LinearAttention() if attention == 'linear' else FullAttention()
473
+ self.merge = nn.Linear(d_model, d_model, bias=False)
474
+
475
+ # feed-forward network
476
+ if leaky > 0:
477
+ if self.vit_norm:
478
+ self.mlp = nn.Sequential(
479
+ nn.Linear(d_model, d_model*2, bias=False),
480
+ nn.LeakyReLU(leaky, True),
481
+ nn.Linear(d_model*2, d_model, bias=False),
482
+ )
483
+ else:
484
+ self.mlp = nn.Sequential(
485
+ nn.Linear(d_model*2, d_model*2, bias=False),
486
+ nn.LeakyReLU(leaky, True),
487
+ nn.Linear(d_model*2, d_model, bias=False),
488
+ )
489
+
490
+ else:
491
+ if self.vit_norm:
492
+ self.mlp = nn.Sequential(
493
+ nn.Linear(d_model, d_model*2, bias=False),
494
+ nn.ReLU(True),
495
+ nn.Linear(d_model*2, d_model, bias=False),
496
+ )
497
+ else:
498
+ self.mlp = nn.Sequential(
499
+ nn.Linear(d_model*2, d_model*2, bias=False),
500
+ nn.ReLU(True),
501
+ nn.Linear(d_model*2, d_model, bias=False),
502
+ )
503
+
504
+ # norm and dropout
505
+ self.norm1 = nn.LayerNorm(d_model)
506
+ self.norm2 = nn.LayerNorm(d_model)
507
+
508
+ # self.norm1 = nn.BatchNorm2d(d_model)
509
+
510
def forward(self, x, source, x_mask=None, source_mask=None):
    """
    Attend from `x` to `source` on (optionally) down-sampled feature maps,
    up-sample the resulting message back to the resolution of `x`, and fuse
    it into `x` through the feed-forward network.

    Args:
        x (torch.Tensor): [N, C, H1, W1]
        source (torch.Tensor): [N, C, H2, W2]
        x_mask (torch.Tensor): [N, H1, W1] (optional) (L = H1*W1)
        source_mask (torch.Tensor): [N, H2, W2] (optional) (S = H2*W2)

    Returns:
        torch.Tensor: [N, C, H1, W1]

    Notes:
        * Masks are only used when both are given; passing exactly one trips
          the assertion below.
        * A mask is interpreted as a top-left valid rectangle: its height and
          width are read from the first column / first row of the pooled
          mask, attention runs on that crop only, and the message is
          zero-padded back to the pooled mask size afterwards.
    """
    bs, C, H1, W1 = x.size()
    nhead, dim = self.nhead, self.dim

    def to_nhwc(t):
        return t.permute(0, 2, 3, 1)

    def to_nchw(t):
        return t.permute(0, 3, 1, 2)

    def downsample(feat, size, depthwise, conv_name):
        # feat is [N, C, H, W]; the depthwise conv is looked up lazily because
        # it is only referenced when the corresponding dw flag is set.
        if size == 1:
            return feat
        if depthwise:
            return getattr(self, conv_name)(feat)
        return self.max_pool(feat)

    def prepare(feat, size, depthwise, conv_name):
        # Returns [N, H, W, C] in one of three normalisation modes:
        #   vit_norm    -> norm1 first, then down-sample
        #   norm_before -> down-sample first, then norm1
        #   otherwise   -> down-sample only (norm1 is applied to the message later)
        if self.vit_norm:
            normed = to_nchw(self.norm1(to_nhwc(feat)))
            return to_nhwc(downsample(normed, size, depthwise, conv_name))
        pooled = to_nhwc(downsample(feat, size, depthwise, conv_name))
        return self.norm1(pooled) if self.norm_before else pooled

    def attend(q, k, v):
        # [N, H, W, C] -> per-head token layout expected by the attention
        # backend, run attention, and always hand back [N, L, h, D].
        if PFLASH_AVAILABLE:  # [N, H, W, C] -> [N, h, L, D]
            pattern = 'n h w (nhead d) -> n nhead (h w) d'
        else:  # [N, H, W, C] -> [N, L, h, D]
            pattern = 'n h w (nhead d) -> n (h w) nhead d'
        q, k, v = (rearrange(t, pattern, nhead=nhead, d=dim) for t in (q, k, v))
        out = self.attention(q, k, v, q_mask=None, kv_mask=None)
        if PFLASH_AVAILABLE:  # [N, h, L, D] -> [N, L, h, D]
            out = rearrange(out, 'n nhead L d -> n L nhead d', nhead=nhead, d=dim)
        return out

    query = prepare(x, self.pool_size, self.dw_conv, 'aggregate')
    source = prepare(source, self.pool_size2, self.dw_conv2, 'aggregate2')

    # projection
    if self.dw_proj:
        source_nchw = to_nchw(source)
        query = to_nhwc(self.q_proj(to_nchw(query)))
        key = to_nhwc(self.k_proj(source_nchw))
        value = to_nhwc(self.v_proj(source_nchw))
    else:
        query, key, value = self.q_proj(query), self.k_proj(source), self.v_proj(source)

    # RoPE
    if self.rope:
        query = self.rope_pos_enc(query)
        if self.pool_size == 1 and self.pool_size2 == 4:
            key = self.rope_pos_enc(key, 4)
        else:
            key = self.rope_pos_enc(key)

    use_mask = x_mask is not None and source_mask is not None
    if not use_mask:
        assert x_mask is None and source_mask is None
        message = attend(query, key, value)  # [N, L, h, D]
    else:
        # downsample mask
        if self.pool_size != 1:
            x_mask = self.max_pool(x_mask.float()).bool()  # [N, H1//pool, W1//pool]
        if self.pool_size2 != 1:
            source_mask = self.max_pool(source_mask.float()).bool()
        full_h, full_w = x_mask.size(-2), x_mask.size(-1)

        # Valid regions may differ per sample, so attention runs sample by
        # sample on the cropped maps (a single iteration when bs == 1).
        padded = []
        for i in range(bs):
            mask_h0, mask_w0 = int(x_mask[i].sum(-2)[0]), int(x_mask[i].sum(-1)[0])
            mask_h1, mask_w1 = int(source_mask[i].sum(-2)[0]), int(source_mask[i].sum(-1)[0])

            m = attend(
                query[i:i + 1, :mask_h0, :mask_w0, :],
                key[i:i + 1, :mask_h1, :mask_w1, :],
                value[i:i + 1, :mask_h1, :mask_w1, :],
            )
            m = m.reshape(1, mask_h0, mask_w0, nhead, dim)
            # Zero-pad bottom and right in one call so that a valid region
            # that is short in BOTH height and width is restored correctly
            # (padding only one of the two cannot rebuild the full grid).
            m = torch.nn.functional.pad(
                m, (0, 0, 0, 0, 0, full_w - mask_w0, 0, full_h - mask_h0))
            padded.append(m)
        message = torch.cat(padded, dim=0)  # [N, H, W, h, D]

    message = self.merge(message.reshape(bs, -1, nhead * dim))  # [N, L, C]
    message = rearrange(message, 'b (h w) c -> b c h w',
                        h=H1 // self.pool_size, w=W1 // self.pool_size)  # [N, C, H, W]

    if self.pool_size != 1:
        if self.scatter:
            # Spread every pooled value over its pool_size x pool_size cell,
            # weighted by the depthwise aggregation kernel.
            message = torch.repeat_interleave(message, self.pool_size, dim=-2)
            message = torch.repeat_interleave(message, self.pool_size, dim=-1)
            kernel = self.aggregate.weight.data.reshape(1, C, self.pool_size, self.pool_size)
            message = message * kernel.repeat(
                1, 1, message.shape[-2] // self.pool_size, message.shape[-1] // self.pool_size)
        else:
            message = torch.nn.functional.interpolate(
                message, scale_factor=self.pool_size, mode='bilinear', align_corners=False)  # [N, C, H1, W1]

    if not self.norm_before and not self.vit_norm:
        message = to_nchw(self.norm1(to_nhwc(message)))  # [N, C, H, W]

    # feed-forward network
    if self.vit_norm:
        # pre-norm residual block: x + attn, then + mlp(norm2(.))
        message_inter = x + message
        message = to_nchw(self.mlp(self.norm2(to_nhwc(message_inter))))  # [N, C, H1, W1]
        return message_inter + message

    message = self.mlp(to_nhwc(torch.cat([x, message], dim=1)))  # [N, H1, W1, C]
    message = to_nchw(self.norm2(message))  # [N, C, H1, W1]
    return x + message
725
+
726
class AG_Conv_EncoderLayer(nn.Module):
    """Aggregated-attention encoder layer with depthwise-conv q/k/v projections.

    Both inputs are down-sampled by `pool_size` (depthwise strided conv when
    `dw_conv` / `dw_conv2` is set, max-pooling otherwise), projected with 3x3
    depthwise convolutions, attended, and the message is up-sampled back and
    fused into `x` by an MLP over the concatenation `[x, message]`.

    NOTE: `scatter=True` reads `self.aggregate.weight` in `forward`, and
    `self.aggregate` is only created when `dw_conv=True`.
    """

    def __init__(self,
                 d_model,
                 nhead,
                 attention='linear',
                 pool_size=4,
                 bn=True,
                 xformer=False,
                 leaky=-1.0,
                 dw_conv=False,
                 dw_conv2=False,
                 scatter=False,
                 norm_before=True,
                 ):
        super(AG_Conv_EncoderLayer, self).__init__()

        self.pool_size = pool_size
        self.dw_conv = dw_conv
        self.dw_conv2 = dw_conv2
        self.scatter = scatter
        self.norm_before = norm_before
        if self.dw_conv:
            self.aggregate = nn.Conv2d(d_model, d_model, kernel_size=pool_size, padding=0,
                                       stride=pool_size, bias=False, groups=d_model)
        if self.dw_conv2:
            self.aggregate2 = nn.Conv2d(d_model, d_model, kernel_size=pool_size, padding=0,
                                        stride=pool_size, bias=False, groups=d_model)
        self.dim = d_model // nhead
        self.nhead = nhead

        self.max_pool = torch.nn.MaxPool2d(kernel_size=self.pool_size, stride=self.pool_size)

        # multi-head attention
        method = 'dw_bn' if bn else 'dw'
        self.q_proj_conv = self._build_projection(d_model, d_model, method=method)
        self.k_proj_conv = self._build_projection(d_model, d_model, method=method)
        self.v_proj_conv = self._build_projection(d_model, d_model, method=method)

        if xformer:
            self.attention = XAttention()
        else:
            self.attention = LinearAttention() if attention == 'linear' else FullAttention()
        self.merge = nn.Linear(d_model, d_model, bias=False)

        # feed-forward network (a positive `leaky` selects LeakyReLU)
        activation = nn.LeakyReLU(leaky, True) if leaky > 0 else nn.ReLU(True)
        self.mlp = nn.Sequential(
            nn.Linear(d_model*2, d_model*2, bias=False),
            activation,
            nn.Linear(d_model*2, d_model, bias=False),
        )

        # norm and dropout
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def _channel_norm(self, feat):
        """Apply `norm1` over the channel axis of an [N, C, H, W] tensor."""
        return self.norm1(feat.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

    def _attend(self, query, key, value):
        """Run attention on [N, C, H, W] maps and return the message as [N, L, h, D]."""
        if PFLASH_AVAILABLE:  # [N, h, L, D]
            pattern = 'n (nhead d) h w -> n nhead (h w) d'
        else:  # [N, L, h, D]
            pattern = 'n (nhead d) h w -> n (h w) nhead d'
        query, key, value = (
            rearrange(t, pattern, nhead=self.nhead, d=self.dim) for t in (query, key, value)
        )
        message = self.attention(query, key, value, q_mask=None, kv_mask=None)
        if PFLASH_AVAILABLE:  # [N, h, L, D] -> [N, L, h, D]
            message = rearrange(message, 'n nhead L d -> n L nhead d', nhead=self.nhead, d=self.dim)
        return message

    def forward(self, x, source, x_mask=None, source_mask=None):
        """
        Args:
            x (torch.Tensor): [N, C, H1, W1]
            source (torch.Tensor): [N, C, H2, W2]
            x_mask (torch.Tensor): [N, H1, W1] (optional) (L = H1*W1)
            source_mask (torch.Tensor): [N, H2, W2] (optional) (S = H2*W2)

        Returns:
            torch.Tensor: [N, C, H1, W1]

        Masks are only used when both are given (passing exactly one trips the
        assertion). Each mask is read as a top-left valid rectangle whose size
        comes from the first column / first row of the pooled mask.
        """
        bs = x.size(0)
        H1, W1 = x.size(-2), x.size(-1)
        C = x.shape[-3]

        query = self.aggregate(x) if self.dw_conv else self.max_pool(x)
        source = self.aggregate2(source) if self.dw_conv2 else self.max_pool(source)
        if self.norm_before:
            query = self._channel_norm(query)
            source = self._channel_norm(source)

        query = self.q_proj_conv(query)  # [N, C, H1//pool, W1//pool]
        key = self.k_proj_conv(source)
        value = self.v_proj_conv(source)

        use_mask = x_mask is not None and source_mask is not None
        if not use_mask:
            assert x_mask is None and source_mask is None
            message = self._attend(query, key, value)  # [N, L, h, D]
        else:
            x_mask = self.max_pool(x_mask.float()).bool()  # [N, H1//pool, W1//pool]
            source_mask = self.max_pool(source_mask.float()).bool()
            full_h, full_w = x_mask.size(-2), x_mask.size(-1)

            # Valid regions may differ per sample, so attention runs sample by
            # sample on the cropped maps (a single iteration when bs == 1).
            padded = []
            for i in range(bs):
                mask_h0, mask_w0 = int(x_mask[i].sum(-2)[0]), int(x_mask[i].sum(-1)[0])
                mask_h1, mask_w1 = int(source_mask[i].sum(-2)[0]), int(source_mask[i].sum(-1)[0])

                m = self._attend(
                    query[i:i + 1, :, :mask_h0, :mask_w0],
                    key[i:i + 1, :, :mask_h1, :mask_w1],
                    value[i:i + 1, :, :mask_h1, :mask_w1],
                )
                m = m.reshape(1, mask_h0, mask_w0, self.nhead, self.dim)
                # Zero-pad bottom and right in one call so a region short in
                # BOTH height and width is restored to the full pooled grid.
                m = torch.nn.functional.pad(
                    m, (0, 0, 0, 0, 0, full_w - mask_w0, 0, full_h - mask_h0))
                padded.append(m)
            message = torch.cat(padded, dim=0)  # [N, H, W, h, D]

        # reshape (not view): with PFLASH_AVAILABLE the message is a permuted,
        # non-contiguous tensor whose head/channel axes cannot be merged as a view.
        message = self.merge(message.reshape(bs, -1, self.nhead*self.dim))  # [N, L, C]
        message = rearrange(message, 'b (h w) c -> b c h w',
                            h=H1//self.pool_size, w=W1//self.pool_size)  # [N, C, H, W]

        if self.scatter:
            # Spread every pooled value over its pool_size x pool_size cell,
            # weighted by the depthwise aggregation kernel.
            message = torch.repeat_interleave(message, self.pool_size, dim=-2)
            message = torch.repeat_interleave(message, self.pool_size, dim=-1)
            kernel = self.aggregate.weight.data.reshape(1, C, self.pool_size, self.pool_size)
            message = message * kernel.repeat(
                1, 1, message.shape[-2]//self.pool_size, message.shape[-1]//self.pool_size)
        else:
            message = torch.nn.functional.interpolate(
                message, scale_factor=self.pool_size, mode='bilinear', align_corners=False)  # [N, C, H1, W1]

        if not self.norm_before:
            message = self._channel_norm(message)  # [N, C, H, W]

        # feed-forward network
        message = self.mlp(torch.cat([x, message], dim=1).permute(0, 2, 3, 1))  # [N, H1, W1, C]
        message = self.norm2(message).permute(0, 3, 1, 2)  # [N, C, H1, W1]

        return x + message

    def _build_projection(self,
                          dim_in,
                          dim_out,
                          kernel_size=3,
                          padding=1,
                          stride=1,
                          method='dw_bn',
                          ):
        """Build the q/k/v projection module.

        `method` is one of:
            'dw_bn'  - depthwise conv followed by BatchNorm2d
            'dw'     - depthwise conv only
            'avg'    - average pooling (ceil_mode=True)
            'linear' - no module (returns None)

        `dim_out` is accepted for signature compatibility; the depthwise conv
        always maps `dim_in` channels to `dim_in` channels.

        Raises:
            ValueError: if `method` is none of the above.
        """
        if method in ('dw_bn', 'dw'):
            layers = [
                ('conv', nn.Conv2d(
                    dim_in,
                    dim_in,
                    kernel_size=kernel_size,
                    padding=padding,
                    stride=stride,
                    bias=False,
                    groups=dim_in
                )),
            ]
            if method == 'dw_bn':
                layers.append(('bn', nn.BatchNorm2d(dim_in)))
            return nn.Sequential(OrderedDict(layers))
        if method == 'avg':
            return nn.Sequential(OrderedDict([
                ('avg', nn.AvgPool2d(
                    kernel_size=kernel_size,
                    padding=padding,
                    stride=stride,
                    ceil_mode=True
                )),
            ]))
        if method == 'linear':
            return None
        raise ValueError('Unknown method ({})'.format(method))
993
+
994
+
995
class RoPELoFTREncoderLayer(nn.Module):
    """LoFTR-style encoder layer operating on flattened tokens.

    Two variants are built depending on `token_mixer`:
      * `token_mixer is None`: q/k/v projections + attention + merge, fused
        with `x` by an MLP over the concatenation `[x, message]`.
      * `token_mixer is not None`: a token-mixing module replaces attention
        and the layer becomes a pre-norm residual block.

    NOTE: `self.attention` is only assigned for `token_mixer == 'dwcn'`, or
    for `token_mixer is None` together with `rope=True`.
    """

    def __init__(self,
                 d_model,
                 nhead,
                 attention='linear',
                 rope=False,
                 token_mixer=None,
                 ):
        super(RoPELoFTREncoderLayer, self).__init__()

        self.dim = d_model // nhead
        self.nhead = nhead
        attention_variant = token_mixer is None

        # multi-head attention
        if attention_variant:
            self.q_proj = nn.Linear(d_model, d_model, bias=False)
            self.k_proj = nn.Linear(d_model, d_model, bias=False)
            self.v_proj = nn.Linear(d_model, d_model, bias=False)

        self.rope = rope
        self.token_mixer = token_mixer
        if not attention_variant:
            if token_mixer == 'dwcn':
                # 3x3 depthwise convolution used as the token mixer
                depthwise = nn.Conv2d(
                    d_model,
                    d_model,
                    kernel_size=3,
                    padding=1,
                    stride=1,
                    bias=False,
                    groups=d_model
                )
                self.attention = nn.Sequential(OrderedDict([('conv', depthwise)]))
        elif self.rope:
            assert attention == 'linear'
            self.attention = RoPELinearAttention()

        if attention_variant:
            self.merge = nn.Linear(d_model, d_model, bias=False)

        # feed-forward network: the attention variant consumes [x, message]
        # (2 * d_model channels), the token-mixer variant consumes d_model.
        mlp_width = d_model*2 if attention_variant else d_model
        self.mlp = nn.Sequential(
            nn.Linear(mlp_width, mlp_width, bias=False),
            nn.ReLU(True),
            nn.Linear(mlp_width, d_model, bias=False),
        )
        # norm and dropout
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, source, x_mask=None, source_mask=None, H=None, W=None):
        """
        Args:
            x (torch.Tensor): [N, L, C]
            source (torch.Tensor): [N, L, C]
            x_mask (torch.Tensor): [N, L] (optional)
            source_mask (torch.Tensor): [N, S] (optional)
            H, W (int): spatial size of the token grid; H*W must equal L.
        """
        bs = x.size(0)
        assert H*W == x.size(-2)

        if self.token_mixer is None:
            # multi-head attention
            heads = (bs, -1, self.nhead, self.dim)
            query = self.q_proj(x).view(*heads)  # [N, L, (H, D)]
            key = self.k_proj(source).view(*heads)  # [N, S, (H, D)]
            value = self.v_proj(source).view(*heads)
            message = self.attention(query, key, value, q_mask=x_mask,
                                     kv_mask=source_mask, H=H, W=W)  # [N, L, (H, D)]
            message = self.norm1(self.merge(message.view(bs, -1, self.nhead*self.dim)))  # [N, L, C]

            # feed-forward network
            message = self.norm2(self.mlp(torch.cat([x, message], dim=2)))
            return x + message

        # token-mixer variant: mix tokens on the 2D grid, then pre-norm MLP
        mixed = rearrange(self.norm1(x), 'n (h w) c -> n c h w', h=H, w=W)
        mixed = rearrange(self.attention(mixed), 'n c h w -> n (h w) c')

        x = x + mixed
        return x + self.mlp(self.norm2(x))
1093
+
1094
class LoFTREncoderLayer(nn.Module):
    """Standard LoFTR encoder layer on flattened tokens.

    q/k/v linear projections, multi-head attention, a merge projection with
    LayerNorm, then an MLP over the concatenation `[x, message]` whose
    normalised output is added to `x`.
    """

    def __init__(self,
                 d_model,
                 nhead,
                 attention='linear',
                 xformer=False,
                 ):
        super(LoFTREncoderLayer, self).__init__()

        self.dim = d_model // nhead
        self.nhead = nhead

        # multi-head attention
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)

        if xformer:
            self.attention = XAttention()
        else:
            self.attention = LinearAttention() if attention == 'linear' else FullAttention()
        self.merge = nn.Linear(d_model, d_model, bias=False)

        # feed-forward network
        self.mlp = nn.Sequential(
            nn.Linear(d_model*2, d_model*2, bias=False),
            nn.ReLU(True),
            nn.Linear(d_model*2, d_model, bias=False),
        )

        # norm and dropout
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, source, x_mask=None, source_mask=None):
        """
        Args:
            x (torch.Tensor): [N, L, C]
            source (torch.Tensor): [N, S, C]
            x_mask (torch.Tensor): [N, L] (optional)
            source_mask (torch.Tensor): [N, S] (optional)
        """
        bs = x.size(0)
        query, key, value = x, source, source

        # multi-head attention
        query = self.q_proj(query).view(bs, -1, self.nhead, self.dim)  # [N, L, (H, D)]
        key = self.k_proj(key).view(bs, -1, self.nhead, self.dim)  # [N, S, (H, D)]
        value = self.v_proj(value).view(bs, -1, self.nhead, self.dim)
        message = self.attention(query, key, value, q_mask=x_mask, kv_mask=source_mask)  # [N, L, (H, D)]
        # reshape tolerates a non-contiguous attention output
        message = self.merge(message.reshape(bs, -1, self.nhead*self.dim))  # [N, L, C]
        message = self.norm1(message)

        # feed-forward network
        message = self.mlp(torch.cat([x, message], dim=2))
        message = self.norm2(message)

        return x + message

    def pro(self, x, source, x_mask=None, source_mask=None, profiler=None):
        """Same computation as `forward`, with each stage wrapped in a profiler section.

        Args:
            x (torch.Tensor): [N, L, C]
            source (torch.Tensor): [N, S, C]
            x_mask (torch.Tensor): [N, L] (optional)
            source_mask (torch.Tensor): [N, S] (optional)
            profiler: object exposing `profile(name)` that returns a context
                manager (optional). When None, the stages run unprofiled.
        """
        from contextlib import nullcontext

        def section(name):
            # `profiler` defaults to None, so fall back to a no-op context
            # instead of failing on `None.profile`.
            return profiler.profile(name) if profiler is not None else nullcontext()

        bs = x.size(0)
        query, key, value = x, source, source

        # multi-head attention
        with section("proj*3"):
            query = self.q_proj(query).view(bs, -1, self.nhead, self.dim)  # [N, L, (H, D)]
            key = self.k_proj(key).view(bs, -1, self.nhead, self.dim)  # [N, S, (H, D)]
            value = self.v_proj(value).view(bs, -1, self.nhead, self.dim)
        with section("attention"):
            message = self.attention(query, key, value, q_mask=x_mask, kv_mask=source_mask)  # [N, L, (H, D)]
        with section("merge"):
            message = self.merge(message.reshape(bs, -1, self.nhead*self.dim))  # [N, L, C]
        with section("norm1"):
            message = self.norm1(message)

        # feed-forward network
        with section("mlp"):
            message = self.mlp(torch.cat([x, message], dim=2))
        with section("norm2"):
            message = self.norm2(message)

        return x + message
1183
+
1184
+ class PANEncoderLayer_cross(nn.Module):
1185
+ def __init__(self,
1186
+ d_model,
1187
+ nhead,
1188
+ attention='linear',
1189
+ pool_size=4,
1190
+ bn=True,
1191
+ ):
1192
+ super(PANEncoderLayer_cross, self).__init__()
1193
+
1194
+ self.pool_size = pool_size
1195
+
1196
+ self.dim = d_model // nhead
1197
+ self.nhead = nhead
1198
+
1199
+ self.max_pool = torch.nn.MaxPool2d(kernel_size=self.pool_size, stride=self.pool_size)
1200
+ # multi-head attention
1201
+ if bn:
1202
+ method = 'dw_bn'
1203
+ else:
1204
+ method = 'dw'
1205
+ self.qk_proj_conv = self._build_projection(d_model, d_model, method=method)
1206
+ self.v_proj_conv = self._build_projection(d_model, d_model, method=method)
1207
+
1208
+ # self.q_proj = nn.Linear(d_mosdel, d_model, bias=False)
1209
+ # self.k_proj = nn.Linear(d_model, d_model, bias=False)
1210
+ # self.v_proj = nn.Linear(d_model, d_model, bias=False)
1211
+ self.attention = FullAttention()
1212
+ self.merge = nn.Linear(d_model, d_model, bias=False)
1213
+
1214
+ # feed-forward network
1215
+ self.mlp = nn.Sequential(
1216
+ nn.Linear(d_model*2, d_model*2, bias=False),
1217
+ nn.ReLU(True),
1218
+ nn.Linear(d_model*2, d_model, bias=False),
1219
+ )
1220
+
1221
+ # norm and dropout
1222
+ self.norm1 = nn.LayerNorm(d_model)
1223
+ self.norm2 = nn.LayerNorm(d_model)
1224
+
1225
+ # self.norm1 = nn.BatchNorm2d(d_model)
1226
+
1227
+ def forward(self, x1, x2, x1_mask=None, x2_mask=None):
1228
+ """
1229
+ Args:
1230
+ x (torch.Tensor): [N, C, H1, W1]
1231
+ source (torch.Tensor): [N, C, H2, W2]
1232
+ x_mask (torch.Tensor): [N, H1, W1] (optional) (L = H1*W1)
1233
+ source_mask (torch.Tensor): [N, H2, W2] (optional) (S = H2*W2)
1234
+ """
1235
+ bs = x1.size(0)
1236
+ H1, W1 = x1.size(-2) // self.pool_size, x1.size(-1) // self.pool_size
1237
+ H2, W2 = x2.size(-2) // self.pool_size, x2.size(-1) // self.pool_size
1238
+
1239
+ query = self.norm1(self.max_pool(x1).permute(0,2,3,1)).permute(0,3,1,2)
1240
+ key = self.norm1(self.max_pool(x2).permute(0,2,3,1)).permute(0,3,1,2)
1241
+ v2 = self.norm1(self.max_pool(x2).permute(0,2,3,1)).permute(0,3,1,2)
1242
+ v1 = self.norm1(self.max_pool(x1).permute(0,2,3,1)).permute(0,3,1,2)
1243
+
1244
+ # multi-head attention
1245
+ query = self.qk_proj_conv(query) # [N, C, H1//pool, W1//pool]
1246
+ key = self.qk_proj_conv(key)
1247
+ v2 = self.v_proj_conv(v2)
1248
+ v1 = self.v_proj_conv(v1)
1249
+
1250
+ C = query.shape[-3]
1251
+ if x1_mask is not None and x2_mask is not None:
1252
+ x1_mask = self.max_pool(x1_mask.float()).bool() # [N, H1//pool, W1//pool]
1253
+ x2_mask = self.max_pool(x2_mask.float()).bool()
1254
+
1255
+ mask_h1, mask_w1 = x1_mask[0].sum(-2)[0], x1_mask[0].sum(-1)[0]
1256
+ mask_h2, mask_w2 = x2_mask[0].sum(-2)[0], x2_mask[0].sum(-1)[0]
1257
+
1258
+ query = query[:, :, :mask_h1, :mask_w1]
1259
+ key = key[:, :, :mask_h2, :mask_w2]
1260
+ v1 = v1[:, :, :mask_h1, :mask_w1]
1261
+ v2 = v2[:, :, :mask_h2, :mask_w2]
1262
+ x1_mask = x1_mask[:, :mask_h1, :mask_w1]
1263
+ x2_mask = x2_mask[:, :mask_h2, :mask_w2]
1264
+
1265
+ else:
1266
+ assert x1_mask is None and x2_mask is None
1267
+
1268
+ query = rearrange(query, 'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim) # [N, L, H, D]
1269
+ key = rearrange(key, 'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim) # [N, S, H, D]
1270
+ v2 = rearrange(v2, 'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim) # [N, S, H, D]
1271
+ v1 = rearrange(v1, 'n (nhead d) h w -> n (h w) nhead d', nhead=self.nhead, d=self.dim) # [N, S, H, D]
1272
+ if x2_mask is not None or x1_mask is not None:
1273
+ x1_mask = x1_mask.flatten(-2)
1274
+ x2_mask = x2_mask.flatten(-2)
1275
+
1276
+
1277
+ QK = torch.einsum("nlhd,nshd->nlsh", query, key)
1278
+ with torch.autocast(enabled=False, device_type='cuda'):
1279
+ if x2_mask is not None or x1_mask is not None:
1280
+ # S1 = S2.transpose(-2,-3).masked_fill(~(x_mask[:, None, :, None] * source_mask[:, :, None, None]), -1e9) # float('-inf')
1281
+ QK = QK.float().masked_fill_(~(x1_mask[:, :, None, None] * x2_mask[:, None, :, None]), -1e9) # float('-inf')
1282
+
1283
+
1284
+ # Compute the attention and the weighted average
1285
+ softmax_temp = 1. / query.size(3)**.5 # sqrt(D)
1286
+ S1 = torch.softmax(softmax_temp * QK, dim=2)
1287
+ S2 = torch.softmax(softmax_temp * QK, dim=3)
1288
+
1289
+ m1 = torch.einsum("nlsh,nshd->nlhd", S1, v2)
1290
+ m2 = torch.einsum("nlsh,nlhd->nshd", S2, v1)
1291
+
1292
+ if x1_mask is not None and x2_mask is not None:
1293
+ m1 = m1.view(bs, mask_h1, mask_w1, self.nhead, self.dim)
1294
+ if mask_h1 != H1:
1295
+ m1 = torch.cat([m1, torch.zeros(m1.size(0), H1-mask_h1, W1, self.nhead, self.dim, device=m1.device, dtype=m1.dtype)], dim=1)
1296
+ elif mask_w1 != W1:
1297
+ m1 = torch.cat([m1, torch.zeros(m1.size(0), H1, W1-mask_w1, self.nhead, self.dim, device=m1.device, dtype=m1.dtype)], dim=2)
1298
+ else:
1299
+ assert x1_mask is None and x2_mask is None
1300
+
1301
+ m1 = self.merge(m1.reshape(bs, -1, self.nhead*self.dim)) # [N, L, C]
1302
+ m1 = rearrange(m1, 'b (h w) c -> b c h w', h=H1, w=W1) # [N, C, H, W]
1303
+ m1 = torch.nn.functional.interpolate(m1, scale_factor=self.pool_size, mode='bilinear', align_corners=False) # [N, C, H1, W1]
1304
+ # feed-forward network
1305
+ m1 = self.mlp(torch.cat([x1, m1], dim=1).permute(0, 2, 3, 1)) # [N, H1, W1, C]
1306
+ m1 = self.norm2(m1).permute(0, 3, 1, 2) # [N, C, H1, W1]
1307
+
1308
+ if x1_mask is not None and x2_mask is not None:
1309
+ m2 = m2.view(bs, mask_h2, mask_w2, self.nhead, self.dim)
1310
+ if mask_h2 != H2:
1311
+ m2 = torch.cat([m2, torch.zeros(m2.size(0), H2-mask_h2, W2, self.nhead, self.dim, device=m2.device, dtype=m2.dtype)], dim=1)
1312
+ elif mask_w2 != W2:
1313
+ m2 = torch.cat([m2, torch.zeros(m2.size(0), H2, W2-mask_w2, self.nhead, self.dim, device=m2.device, dtype=m2.dtype)], dim=2)
1314
+ else:
1315
+ assert x1_mask is None and x2_mask is None
1316
+
1317
+ m2 = self.merge(m2.reshape(bs, -1, self.nhead*self.dim)) # [N, L, C]
1318
+ m2 = rearrange(m2, 'b (h w) c -> b c h w', h=H2, w=W2) # [N, C, H, W]
1319
+ m2 = torch.nn.functional.interpolate(m2, scale_factor=self.pool_size, mode='bilinear', align_corners=False) # [N, C, H1, W1]
1320
+ # feed-forward network
1321
+ m2 = self.mlp(torch.cat([x2, m2], dim=1).permute(0, 2, 3, 1)) # [N, H1, W1, C]
1322
+ m2 = self.norm2(m2).permute(0, 3, 1, 2) # [N, C, H1, W1]
1323
+
1324
+ return x1 + m1, x2 + m2
1325
+
1326
def _build_projection(self,
                      dim_in,
                      dim_out,
                      kernel_size=3,
                      padding=1,
                      stride=1,
                      method='dw_bn',
                      ):
    """Create the spatial projection module selected by ``method``.

    Args:
        dim_in: channel count; the depthwise convolution and the batch norm
            both use it as their width.
        dim_out: accepted for signature compatibility; not read here.
        kernel_size, padding, stride: forwarded to the conv / pooling layer.
        method: 'dw_bn' (depthwise conv followed by BatchNorm2d), 'dw'
            (depthwise conv only), 'avg' (AvgPool2d with ceil_mode=True) or
            'linear' (no spatial projection).

    Returns:
        An ``nn.Sequential`` whose children are named 'conv' (+ 'bn') or
        'avg', or ``None`` for 'linear'.

    Raises:
        ValueError: if ``method`` is none of the names above.
    """
    if method == 'linear':
        return None

    if method == 'avg':
        pooling = nn.AvgPool2d(
            kernel_size=kernel_size,
            padding=padding,
            stride=stride,
            ceil_mode=True,
        )
        return nn.Sequential(OrderedDict([('avg', pooling)]))

    if method in ('dw_bn', 'dw'):
        # groups == channels makes the convolution depthwise.
        stages = [
            ('conv', nn.Conv2d(
                dim_in,
                dim_in,
                kernel_size=kernel_size,
                padding=padding,
                stride=stride,
                bias=False,
                groups=dim_in,
            )),
        ]
        if method == 'dw_bn':
            stages.append(('bn', nn.BatchNorm2d(dim_in)))
        return nn.Sequential(OrderedDict(stages))

    raise ValueError('Unknown method ({})'.format(method))
1377
+
1378
+ class LocalFeatureTransformer(nn.Module):
1379
+ """A Local Feature Transformer (LoFTR) module."""
1380
+
1381
def __init__(self, config):
    """Read the transformer settings from ``config`` and build the layer stack.

    A config without a 'coarse' key is used directly and marks the module as
    the fine-level transformer; otherwise the 'coarse' sub-config is used.
    One layer is created per entry of ``config['layer_names']``; the layer
    class is selected by the 'bidirection', 'dwconv', 'rope', 'abspe' and
    'pan' switches.  All parameters with more than one dimension are
    re-initialised at the end via ``_reset_parameters``.
    """
    super(LocalFeatureTransformer, self).__init__()

    self.full_config = config
    self.fine = 'coarse' not in config  # fine attention when no coarse sub-config exists
    if not self.fine:
        config = config['coarse']

    self.d_model = config['d_model']
    self.nhead = config['nhead']
    self.layer_names = config['layer_names']
    self.pan = config['pan']
    self.bidirect = config['bidirection']
    # prune
    self.pool_size = config['pool_size']
    # Early stopping / point pruning are switched off with fixed values; the
    # config-driven assignments (including self.thr) are disabled in the
    # original code.
    self.matchability = False
    self.depth_confidence = -1.0
    self.width_confidence = -1.0

    if self.fine:
        self.rope, self.asymmetric, self.asymmetric_self, self.aggregate = False, False, False, False
    else:
        # asy
        self.asymmetric = config['asymmetric']
        self.asymmetric_self = config['asymmetric_self']
        # aggregate
        self.aggregate = config['dwconv']
        # RoPE
        self.rope = config['rope']
        # absPE
        self.abspe = config['abspe']

    if self.matchability:
        self.n_layers = len(self.layer_names) // 2
        assert self.n_layers == 4
        self.log_assignment = nn.ModuleList(
            [MatchAssignment(self.d_model) for _ in range(self.n_layers)])
        self.token_confidence = nn.ModuleList(
            [TokenConfidence(self.d_model) for _ in range(self.n_layers - 1)])

        # NOTE(review): the flattened source does not show whether this line
        # sat inside the matchability branch; it is only used by the
        # matchability path of forward -- confirm against upstream.
        self.CoarseMatching = CoarseMatching(self.full_config['match_coarse'])

    def clone_per_name(self_proto, cross_proto):
        # One independent deep copy per layer name: 'self' entries copy
        # self_proto, every other entry copies cross_proto.
        return nn.ModuleList([
            copy.deepcopy(self_proto if kind == 'self' else cross_proto)
            for kind in self.layer_names
        ])

    def ag_rope_layer(rope_flag):
        # The AG_RoPE_EncoderLayer calls differ only in their rope argument.
        return AG_RoPE_EncoderLayer(
            config['d_model'], config['nhead'], config['attention'], config['pool_size'], config['pool_size2'],
            config['xformer'], config['leaky'], config['dwconv'], config['dwconv2'], config['scatter'],
            config['norm_before'], rope_flag, config['npe'], config['vit_norm'], config['rope_dwproj'])

    if self.bidirect:
        assert config['xformer'] is False and config['pan'] is True
        self_layer = PANEncoderLayer(config['d_model'], config['nhead'], config['attention'],
                                     config['pool_size'], config['bn'], config['xformer'])
        cross_layer = PANEncoderLayer_cross(config['d_model'], config['nhead'], config['attention'],
                                            config['pool_size'], config['bn'])
        self.layers = clone_per_name(self_layer, cross_layer)
    elif self.aggregate:
        if self.rope:
            logger.info(f'npe trainH,trainW,testH,testW: {config["npe"][0]}, {config["npe"][1]}, {config["npe"][2]}, {config["npe"][3]}')
            # Only the self-attention layers receive the rope setting.
            self_layer = ag_rope_layer(config['rope'])
            cross_layer = ag_rope_layer(False)
            self.layers = clone_per_name(self_layer, cross_layer)
        elif self.abspe:
            logger.info(f'npe trainH,trainW,testH,testW: {config["npe"][0]}, {config["npe"][1]}, {config["npe"][2]}, {config["npe"][3]}')
            self_layer = ag_rope_layer(False)
            cross_layer = ag_rope_layer(False)
            self.layers = clone_per_name(self_layer, cross_layer)
        else:
            encoder_layer = AG_Conv_EncoderLayer(
                config['d_model'], config['nhead'], config['attention'], config['pool_size'], config['bn'],
                config['xformer'], config['leaky'], config['dwconv'], config['scatter'],
                config['norm_before'])
            self.layers = clone_per_name(encoder_layer, encoder_layer)
    else:
        if config['pan']:
            encoder_layer = PANEncoderLayer(
                config['d_model'], config['nhead'], config['attention'], config['pool_size'],
                config['bn'], config['xformer'], config['leaky'], config['dwconv'], config['scatter'])
        else:
            encoder_layer = LoFTREncoderLayer(
                config['d_model'], config['nhead'], config['attention'], config['xformer'])
        self.layers = clone_per_name(encoder_layer, encoder_layer)

    self._reset_parameters()
1471
+
1472
def _reset_parameters(self):
    """Re-initialise every parameter with more than one dimension (Xavier uniform).

    Parameters with at most one dimension are left untouched.
    """
    multi_dim = (weight for weight in self.parameters() if weight.dim() > 1)
    for weight in multi_dim:
        nn.init.xavier_uniform_(weight)
1476
+
1477
def forward(self, feat0, feat1, mask0=None, mask1=None, data=None):
    """Pass two feature maps through the configured self/cross attention layers.

    Args:
        feat0 (torch.Tensor): [N, C, H, W]
        feat1 (torch.Tensor): [N, C, H, W]
        mask0 (torch.Tensor): optional validity mask for feat0.  The cropping
            branch below indexes it as [N, H, W] and counts the valid rows /
            columns along its first column / row.
        mask1 (torch.Tensor): optional validity mask for feat1, same layout.
        data (dict): only written to when ``self.matchability`` is set.

    Returns:
        ``(feat0, feat1)`` after all layers.  When the inputs were cropped to
        their valid region (batch size 1, both masks given, ``self.pan``),
        the outputs are zero-padded back to the spatial size of the masks.
    """

    def pad_back(t, full_sizes, fill=0):
        # Pad the trailing len(full_sizes) dims of `t` at their far end up to
        # `full_sizes`.  Handles several cropped dims at once; returns `t`
        # itself when nothing needs padding.
        widths = []
        trailing = t.shape[-len(full_sizes):]
        # F.pad expects (before, after) pairs starting from the last dim.
        for cur, full in zip(reversed(trailing), reversed(full_sizes)):
            widths.extend((0, int(full) - int(cur)))
        if not any(widths):
            return t
        return torch.nn.functional.pad(t, widths, value=fill)

    # nchw for pan and n(hw)c for loftr
    assert self.d_model == feat0.size(1) or self.d_model == feat0.size(-1), "the feature number of src and transformer must be equal"
    H0, W0, H1, W1 = feat0.size(-2), feat0.size(-1), feat1.size(-2), feat1.size(-1)
    bs = feat0.shape[0]
    padding = False
    if bs == 1 and mask0 is not None and mask1 is not None and self.pan:  # NCHW for pan
        mask_H0, mask_W0 = mask0.size(-2), mask0.size(-1)
        mask_H1, mask_W1 = mask1.size(-2), mask1.size(-1)
        mask_h0, mask_w0 = mask0[0].sum(-2)[0], mask0[0].sum(-1)[0]
        mask_h1, mask_w1 = mask1[0].sum(-2)[0], mask1[0].sum(-1)[0]

        # Round the valid extent down to a multiple of self.pool_size.
        pool = self.pool_size
        mask_h0, mask_w0 = mask_h0 // pool * pool, mask_w0 // pool * pool
        mask_h1, mask_w1 = mask_h1 // pool * pool, mask_w1 // pool * pool

        feat0 = feat0[:, :, :mask_h0, :mask_w0]
        feat1 = feat1[:, :, :mask_h1, :mask_w1]

        padding = True

    # prune
    if padding:
        l0, l1 = mask_h0 * mask_w0, mask_h1 * mask_w1
    else:
        l0, l1 = H0 * W0, H1 * W1
    do_early_stop = self.depth_confidence > 0
    do_point_pruning = self.width_confidence > 0
    if do_point_pruning:
        ind0 = torch.arange(0, l0, device=feat0.device)[None]
        ind1 = torch.arange(0, l1, device=feat0.device)[None]
        # We store the index of the layer at which pruning is detected.
        prune0 = torch.ones_like(ind0)
        prune1 = torch.ones_like(ind1)
    # Always defined: the pruning branch reads the tokens even when early
    # stopping is disabled.
    token0, token1 = None, None

    for i, (layer, name) in enumerate(zip(self.layers, self.layer_names)):
        if padding:
            # Features were cropped to the valid region, so layers run unmasked.
            mask0, mask1 = None, None
        if name == 'self':
            if self.asymmetric:
                assert False, 'not worked'
                feat1 = layer(feat1, feat1, mask1, mask1)
            else:
                feat0 = layer(feat0, feat0, mask0, mask0)
                feat1 = layer(feat1, feat1, mask1, mask1)
        elif name == 'cross':
            if self.bidirect:
                feat0, feat1 = layer(feat0, feat1, mask0, mask1)
            else:
                if self.asymmetric or self.asymmetric_self:
                    assert False, 'not worked'
                    feat0 = layer(feat0, feat1, mask0, mask1)
                else:
                    feat0 = layer(feat0, feat1, mask0, mask1)
                    feat1 = layer(feat1, feat0, mask1, mask0)

            if i == len(self.layer_names) - 1 and not self.training:
                continue
            if self.matchability:
                desc0, desc1 = rearrange(feat0, 'b c h w -> b (h w) c'), rearrange(feat1, 'b c h w -> b (h w) c')
            if do_early_stop:
                token0, token1 = self.token_confidence[i//2](desc0, desc1)
                if self.check_if_stop(token0, token1, i, l0+l1) and not self.training:
                    break
            if do_point_pruning:
                scores0, scores1 = self.log_assignment[i//2].scores(desc0, desc1)
                mask0 = self.get_pruning_mask(token0, scores0, i)
                mask1 = self.get_pruning_mask(token1, scores1, i)
                ind0, ind1 = ind0[mask0][None], ind1[mask1][None]
                feat0, feat1 = desc0[mask0][None], desc1[mask1][None]
                # Stop once either pruned set is empty (the unpruned desc1
                # was tested here before).
                if feat0.shape[-2] == 0 or feat1.shape[-2] == 0:
                    break
                prune0[:, ind0] += 1
                prune1[:, ind1] += 1
            if self.training and self.matchability:
                scores, _, matchability0, matchability1 = self.log_assignment[i//2](desc0, desc1)
                m0_full = torch.zeros((bs, l0), device=matchability0.device, dtype=matchability0.dtype)
                # In-place scatter; the out-of-place result used to be discarded.
                m0_full.scatter_(1, ind0, matchability0.squeeze(-1))
                if padding and self.d_model == feat0.size(1):
                    m0_full = pad_back(m0_full.reshape(bs, mask_h0, mask_w0), (mask_H0, mask_W0))
                    m0_full = m0_full.reshape(bs, mask_H0*mask_W0)
                m1_full = torch.zeros((bs, l1), device=matchability0.device, dtype=matchability0.dtype)
                m1_full.scatter_(1, ind1, matchability1.squeeze(-1))
                if padding and self.d_model == feat1.size(1):
                    m1_full = pad_back(m1_full.reshape(bs, mask_h1, mask_w1), (mask_H1, mask_W1))
                    m1_full = m1_full.reshape(bs, mask_H1*mask_W1)
                data.update({'matchability0_'+str(i//2): m0_full, 'matchability1_'+str(i//2): m1_full})
                # NOTE(review): self.thr is not assigned in the visible
                # __init__ (its assignment is commented out) -- confirm
                # before enabling matchability.
                m0, m1, mscores0, mscores1 = filter_matches(
                    scores, self.thr)
                if do_point_pruning:
                    m0_ = torch.full((bs, l0), -1, device=m0.device, dtype=m0.dtype)
                    m1_ = torch.full((bs, l1), -1, device=m1.device, dtype=m1.dtype)
                    m0_[:, ind0] = torch.where(
                        m0 == -1, -1, ind1.gather(1, m0.clamp(min=0)))
                    m1_[:, ind1] = torch.where(
                        m1 == -1, -1, ind0.gather(1, m1.clamp(min=0)))
                    mscores0_ = torch.zeros((bs, l0), device=mscores0.device)
                    mscores1_ = torch.zeros((bs, l1), device=mscores1.device)
                    mscores0_[:, ind0] = mscores0
                    mscores1_[:, ind1] = mscores1
                    m0, m1, mscores0, mscores1 = m0_, m1_, mscores0_, mscores1_
                if padding and self.d_model == feat0.size(1):
                    # Padded positions are marked unmatched (-1).
                    m0 = pad_back(m0.reshape(bs, mask_h0, mask_w0), (mask_H0, mask_W0), fill=-1)
                    m0 = m0.reshape(bs, mask_H0*mask_W0)
                if padding and self.d_model == feat1.size(1):
                    m1 = pad_back(m1.reshape(bs, mask_h1, mask_w1), (mask_H1, mask_W1), fill=-1)
                    m1 = m1.reshape(bs, mask_H1*mask_W1)
                data.update({'matches0_'+str(i//2): m0, 'matches1_'+str(i//2): m1})
                conf = torch.zeros((bs, l0 * l1), device=scores.device, dtype=scores.dtype)
                ind = ind0[..., None] * l1 + ind1[:, None, :]
                conf.scatter_(1, ind.reshape(bs, -1), scores.reshape(bs, -1).exp())
                if padding and self.d_model == feat0.size(1):
                    conf = conf.reshape(bs, mask_h0, mask_w0, mask_h1, mask_w1)
                    conf = pad_back(conf, (mask_H0, mask_W0, mask_H1, mask_W1))
                    conf = conf.reshape(bs, mask_H0*mask_W0, mask_H1*mask_W1)
                data.update({'conf_matrix_'+str(i//2): conf})
        else:
            raise KeyError

    if self.matchability and not self.training:
        scores, _, matchability0, matchability1 = self.log_assignment[i//2](desc0, desc1)
        conf = torch.zeros((bs, l0 * l1), device=scores.device, dtype=scores.dtype)
        ind = ind0[..., None] * l1 + ind1[:, None, :]
        conf.scatter_(1, ind.reshape(bs, -1), scores.reshape(bs, -1).exp())
        if padding and self.d_model == feat0.size(1):
            conf = conf.reshape(bs, mask_h0, mask_w0, mask_h1, mask_w1)
            conf = pad_back(conf, (mask_H0, mask_W0, mask_H1, mask_W1))
            conf = conf.reshape(bs, mask_H0*mask_W0, mask_H1*mask_W1)
        data.update({'conf_matrix': conf})
        data.update(**self.CoarseMatching.get_coarse_match(conf, data))

    if padding and self.d_model == feat0.size(1):
        # Restore the spatial size of the masks; both height and width may
        # have been cropped, so every trailing dim is padded as needed.
        feat0 = pad_back(feat0, (mask_H0, mask_W0))
        feat1 = pad_back(feat1, (mask_H1, mask_W1))

    return feat0, feat1
1722
+
1723
def pro(self, feat0, feat1, mask0=None, mask1=None, profiler=None):
    """Run the layer stack through each layer's ``pro`` method under a profiler.

    Args:
        feat0 (torch.Tensor): [N, C, H, W]
        feat1 (torch.Tensor): [N, C, H, W]
        mask0 (torch.Tensor): [N, L] (optional)
        mask1 (torch.Tensor): [N, S] (optional)
        profiler: object providing ``profile(name)`` as a context manager;
            it is also handed to every layer.  The default ``None`` is not
            usable because ``profiler.profile`` is called unconditionally.

    Returns:
        The updated ``(feat0, feat1)`` pair.

    Raises:
        KeyError: if a layer name is neither 'self' nor 'cross'.
    """
    width_matches = feat0.size(1) == self.d_model or feat0.size(-1) == self.d_model
    assert width_matches, "the feature number of src and transformer must be equal"

    with profiler.profile("LoFTR_transformer_attention"):
        for kind, block in zip(self.layer_names, self.layers):
            if kind == 'self':
                feat0 = block.pro(feat0, feat0, mask0, mask0, profiler=profiler)
                feat1 = block.pro(feat1, feat1, mask1, mask1, profiler=profiler)
                continue
            if kind == 'cross':
                # feat1 attends to the already-updated feat0.
                feat0 = block.pro(feat0, feat1, mask0, mask1, profiler=profiler)
                feat1 = block.pro(feat1, feat0, mask1, mask0, profiler=profiler)
                continue
            raise KeyError

    return feat0, feat1
1745
+
1746
def confidence_threshold(self, layer_index: int) -> float:
    """Return the confidence threshold for ``layer_index``.

    The value is 0.8 + 0.1 * exp(-4 * layer_index / self.n_layers), clipped
    to [0, 1]; it starts at 0.9 for layer 0 and decays towards 0.8.
    """
    decay = np.exp(-4.0 * layer_index / self.n_layers)
    return np.clip(0.8 + 0.1 * decay, 0, 1)
1750
+
1751
def get_pruning_mask(self, confidences: torch.Tensor, scores: torch.Tensor,
                     layer_index: int) -> torch.Tensor:
    """Return a boolean mask that is True where ``scores > 1 - self.width_confidence``.

    When ``confidences`` is given, every entry whose confidence does not
    exceed the layer threshold has its score replaced by 1.0 before the
    comparison.  ``forward`` uses the mask to index the entries it keeps.
    """
    cutoff = self.confidence_threshold(layer_index)
    if confidences is None:
        return scores > (1 - self.width_confidence)
    effective = torch.where(confidences > cutoff, scores, scores.new_tensor(1.0))
    return effective > (1 - self.width_confidence)
1759
+
1760
def check_if_stop(self,
                  confidences0: torch.Tensor,
                  confidences1: torch.Tensor,
                  layer_index: int, num_points: int) -> torch.Tensor:
    """Evaluate the early-stopping condition.

    Concatenates both confidence tensors along the last dim, counts the
    entries below the layer threshold and returns whether
    ``1 - count / num_points`` exceeds ``self.depth_confidence``.
    """
    joined = torch.cat((confidences0, confidences1), -1)
    cutoff = self.confidence_threshold(layer_index)
    below_cutoff = (joined < cutoff).float().sum()
    confident_ratio = 1.0 - below_cutoff / num_points
    return confident_ratio > self.depth_confidence