faiss 0.6.0 → 0.6.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (378)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +8 -0
  3. data/ext/faiss/extconf.rb +2 -1
  4. data/ext/faiss/{index_rb.cpp → index.cpp} +1 -1
  5. data/ext/faiss/index_binary.cpp +1 -1
  6. data/ext/faiss/kmeans.cpp +1 -1
  7. data/ext/faiss/pca_matrix.cpp +1 -1
  8. data/ext/faiss/product_quantizer.cpp +1 -1
  9. data/ext/faiss/{utils_rb.cpp → utils.cpp} +1 -1
  10. data/lib/faiss/version.rb +1 -1
  11. data/vendor/faiss/faiss/AutoTune.cpp +93 -80
  12. data/vendor/faiss/faiss/Clustering.cpp +39 -240
  13. data/vendor/faiss/faiss/Clustering.h +6 -0
  14. data/vendor/faiss/faiss/IVFlib.cpp +41 -21
  15. data/vendor/faiss/faiss/Index.cpp +6 -5
  16. data/vendor/faiss/faiss/Index.h +5 -5
  17. data/vendor/faiss/faiss/Index2Layer.cpp +37 -53
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +49 -37
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +36 -34
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +4 -1
  21. data/vendor/faiss/faiss/IndexBinary.cpp +5 -3
  22. data/vendor/faiss/faiss/IndexBinary.h +4 -4
  23. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +1 -1
  24. data/vendor/faiss/faiss/IndexBinaryFlat.h +1 -1
  25. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -4
  26. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +88 -97
  27. data/vendor/faiss/faiss/IndexBinaryHNSW.h +9 -3
  28. data/vendor/faiss/faiss/IndexBinaryHash.cpp +45 -236
  29. data/vendor/faiss/faiss/IndexBinaryHash.h +6 -6
  30. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +89 -417
  31. data/vendor/faiss/faiss/IndexFastScan.cpp +72 -109
  32. data/vendor/faiss/faiss/IndexFastScan.h +25 -23
  33. data/vendor/faiss/faiss/IndexFlat.cpp +27 -20
  34. data/vendor/faiss/faiss/IndexFlat.h +21 -18
  35. data/vendor/faiss/faiss/IndexFlatCodes.cpp +42 -19
  36. data/vendor/faiss/faiss/IndexHNSW.cpp +374 -206
  37. data/vendor/faiss/faiss/IndexHNSW.h +16 -2
  38. data/vendor/faiss/faiss/IndexIDMap.cpp +25 -21
  39. data/vendor/faiss/faiss/IndexIDMap.h +9 -7
  40. data/vendor/faiss/faiss/IndexIVF.cpp +467 -364
  41. data/vendor/faiss/faiss/IndexIVF.h +33 -12
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +79 -76
  43. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +96 -93
  44. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -1
  45. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +357 -238
  46. data/vendor/faiss/faiss/IndexIVFFastScan.h +42 -41
  47. data/vendor/faiss/faiss/IndexIVFFlat.cpp +39 -69
  48. data/vendor/faiss/faiss/IndexIVFFlat.h +32 -0
  49. data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +56 -33
  50. data/vendor/faiss/faiss/IndexIVFFlatPanorama.h +3 -1
  51. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +18 -15
  52. data/vendor/faiss/faiss/IndexIVFPQ.cpp +73 -846
  53. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +151 -121
  54. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +3 -0
  55. data/vendor/faiss/faiss/IndexIVFPQR.cpp +23 -20
  56. data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +30 -52
  57. data/vendor/faiss/faiss/IndexIVFRaBitQ.h +2 -1
  58. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +475 -476
  59. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +248 -93
  60. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +41 -127
  61. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +1 -1
  62. data/vendor/faiss/faiss/IndexLSH.cpp +36 -19
  63. data/vendor/faiss/faiss/IndexLattice.cpp +13 -13
  64. data/vendor/faiss/faiss/IndexNNDescent.cpp +36 -21
  65. data/vendor/faiss/faiss/IndexNNDescent.h +2 -2
  66. data/vendor/faiss/faiss/IndexNSG.cpp +38 -23
  67. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +31 -11
  68. data/vendor/faiss/faiss/IndexPQ.cpp +128 -221
  69. data/vendor/faiss/faiss/IndexPQ.h +3 -2
  70. data/vendor/faiss/faiss/IndexPQFastScan.cpp +20 -14
  71. data/vendor/faiss/faiss/IndexPQFastScan.h +3 -0
  72. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -18
  73. data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
  74. data/vendor/faiss/faiss/IndexRaBitQ.cpp +11 -36
  75. data/vendor/faiss/faiss/IndexRaBitQ.h +2 -1
  76. data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +41 -277
  77. data/vendor/faiss/faiss/IndexRaBitQFastScan.h +183 -27
  78. data/vendor/faiss/faiss/IndexRefine.cpp +30 -25
  79. data/vendor/faiss/faiss/IndexRefine.h +4 -4
  80. data/vendor/faiss/faiss/IndexReplicas.cpp +6 -6
  81. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +15 -14
  82. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +1 -1
  83. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +150 -20
  84. data/vendor/faiss/faiss/IndexScalarQuantizer.h +10 -0
  85. data/vendor/faiss/faiss/IndexShards.cpp +10 -9
  86. data/vendor/faiss/faiss/IndexShardsIVF.cpp +21 -15
  87. data/vendor/faiss/faiss/MatrixStats.cpp +5 -4
  88. data/vendor/faiss/faiss/MetaIndexes.cpp +19 -17
  89. data/vendor/faiss/faiss/MetaIndexes.h +1 -1
  90. data/vendor/faiss/faiss/MetricType.h +14 -7
  91. data/vendor/faiss/faiss/SuperKMeans.cpp +656 -0
  92. data/vendor/faiss/faiss/SuperKMeans.h +97 -0
  93. data/vendor/faiss/faiss/VectorTransform.cpp +237 -149
  94. data/vendor/faiss/faiss/VectorTransform.h +16 -16
  95. data/vendor/faiss/faiss/build.cpp +23 -0
  96. data/vendor/faiss/faiss/build.h +15 -0
  97. data/vendor/faiss/faiss/clone_index.cpp +48 -47
  98. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +1 -1
  99. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +47 -47
  100. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +11 -0
  101. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +902 -12
  102. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +38 -38
  103. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +11 -0
  104. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +702 -10
  105. data/vendor/faiss/faiss/factory_tools.cpp +9 -0
  106. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +6 -5
  107. data/vendor/faiss/faiss/gpu/GpuResources.h +3 -2
  108. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +15 -16
  109. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +5 -4
  110. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +46 -0
  111. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +56 -0
  112. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +78 -1
  113. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +72 -0
  114. data/vendor/faiss/faiss/gpu/test/TestUtils.h +23 -0
  115. data/vendor/faiss/faiss/gpu/utils/CuvsFilterConvert.h +1 -1
  116. data/vendor/faiss/faiss/gpu/utils/CuvsUtils.h +21 -10
  117. data/vendor/faiss/faiss/gpu_metal/GpuIndexFlat.h +22 -0
  118. data/vendor/faiss/faiss/gpu_metal/MetalCloner.h +35 -0
  119. data/vendor/faiss/faiss/gpu_metal/MetalDistance.h +87 -0
  120. data/vendor/faiss/faiss/gpu_metal/MetalFlatKernels.h +40 -0
  121. data/vendor/faiss/faiss/gpu_metal/MetalIndex.h +58 -0
  122. data/vendor/faiss/faiss/gpu_metal/MetalIndexFlat.h +65 -0
  123. data/vendor/faiss/faiss/gpu_metal/MetalIndexIVFFlat.h +181 -0
  124. data/vendor/faiss/faiss/gpu_metal/MetalKernels.h +111 -0
  125. data/vendor/faiss/faiss/gpu_metal/MetalPythonBridge.h +45 -0
  126. data/vendor/faiss/faiss/gpu_metal/MetalResources.h +79 -0
  127. data/vendor/faiss/faiss/gpu_metal/StandardMetalResources.h +35 -0
  128. data/vendor/faiss/faiss/gpu_metal/impl/MetalIVFFlat.h +193 -0
  129. data/vendor/faiss/faiss/impl/AdSampling.cpp +103 -0
  130. data/vendor/faiss/faiss/impl/AdSampling.h +35 -0
  131. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +29 -25
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +1 -0
  133. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +10 -9
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +3 -0
  135. data/vendor/faiss/faiss/impl/ClusteringHelpers.cpp +244 -0
  136. data/vendor/faiss/faiss/impl/ClusteringHelpers.h +94 -0
  137. data/vendor/faiss/faiss/impl/ClusteringInitialization.cpp +16 -16
  138. data/vendor/faiss/faiss/impl/CodePacker.cpp +3 -3
  139. data/vendor/faiss/faiss/impl/CodePackerRaBitQ.cpp +1 -1
  140. data/vendor/faiss/faiss/impl/DistanceComputer.h +8 -8
  141. data/vendor/faiss/faiss/impl/FaissAssert.h +6 -3
  142. data/vendor/faiss/faiss/impl/FaissException.h +50 -3
  143. data/vendor/faiss/faiss/impl/HNSW.cpp +639 -507
  144. data/vendor/faiss/faiss/impl/HNSW.h +61 -44
  145. data/vendor/faiss/faiss/impl/IDSelector.cpp +15 -11
  146. data/vendor/faiss/faiss/impl/IDSelector.h +8 -8
  147. data/vendor/faiss/faiss/impl/InvertedListScannerStats.h +26 -0
  148. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +82 -77
  149. data/vendor/faiss/faiss/impl/NNDescent.cpp +62 -25
  150. data/vendor/faiss/faiss/impl/NNDescent.h +6 -2
  151. data/vendor/faiss/faiss/impl/NSG.cpp +53 -32
  152. data/vendor/faiss/faiss/impl/NSG.h +4 -4
  153. data/vendor/faiss/faiss/impl/Panorama.cpp +23 -6
  154. data/vendor/faiss/faiss/impl/Panorama.h +269 -87
  155. data/vendor/faiss/faiss/impl/PdxLayout.cpp +93 -0
  156. data/vendor/faiss/faiss/impl/PdxLayout.h +41 -0
  157. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +46 -32
  158. data/vendor/faiss/faiss/impl/PolysemousTraining.h +3 -3
  159. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +35 -35
  160. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +21 -16
  161. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +55 -25
  162. data/vendor/faiss/faiss/impl/Quantizer.h +2 -2
  163. data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +55 -49
  164. data/vendor/faiss/faiss/impl/RaBitQUtils.h +65 -0
  165. data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +302 -283
  166. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +26 -23
  167. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +1 -1
  168. data/vendor/faiss/faiss/impl/ResultHandler.h +100 -75
  169. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +318 -7
  170. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +77 -1
  171. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +14 -11
  172. data/vendor/faiss/faiss/impl/VisitedTable.cpp +10 -10
  173. data/vendor/faiss/faiss/impl/VisitedTable.h +70 -28
  174. data/vendor/faiss/faiss/impl/approx_topk/approx_topk.h +276 -0
  175. data/vendor/faiss/faiss/impl/approx_topk/avx2.cpp +68 -0
  176. data/vendor/faiss/faiss/{utils → impl}/approx_topk/generic.h +15 -8
  177. data/vendor/faiss/faiss/impl/approx_topk/neon.cpp +68 -0
  178. data/vendor/faiss/faiss/impl/approx_topk/rq_beam_search_tab-inl.h +169 -0
  179. data/vendor/faiss/faiss/impl/approx_topk/rq_beam_search_tab.h +117 -0
  180. data/vendor/faiss/faiss/impl/approx_topk/simdlib256-inl.h +146 -0
  181. data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryHNSW_impl.h +73 -0
  182. data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryHash_impl.h +270 -0
  183. data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryIVF_impl.h +460 -0
  184. data/vendor/faiss/faiss/impl/binary_hamming/IndexIVFSpectralHash_impl.h +159 -0
  185. data/vendor/faiss/faiss/impl/binary_hamming/IndexPQ_impl.h +92 -0
  186. data/vendor/faiss/faiss/impl/binary_hamming/avx2.cpp +26 -0
  187. data/vendor/faiss/faiss/impl/binary_hamming/avx512.cpp +26 -0
  188. data/vendor/faiss/faiss/impl/binary_hamming/dispatch.h +143 -0
  189. data/vendor/faiss/faiss/impl/binary_hamming/neon.cpp +26 -0
  190. data/vendor/faiss/faiss/impl/binary_hamming/rvv.cpp +26 -0
  191. data/vendor/faiss/faiss/impl/expanded_scanners.h +8 -3
  192. data/vendor/faiss/faiss/impl/{FastScanDistancePostProcessing.h → fast_scan/FastScanDistancePostProcessing.h} +13 -6
  193. data/vendor/faiss/faiss/impl/{LookupTableScaler.h → fast_scan/LookupTableScaler.h} +16 -5
  194. data/vendor/faiss/faiss/impl/fast_scan/accumulate_loops.h +237 -0
  195. data/vendor/faiss/faiss/impl/fast_scan/accumulate_loops_512.h +185 -0
  196. data/vendor/faiss/faiss/impl/fast_scan/decompose_qbs.h +229 -0
  197. data/vendor/faiss/faiss/impl/fast_scan/dispatching.h +270 -0
  198. data/vendor/faiss/faiss/impl/{pq4_fast_scan.cpp → fast_scan/fast_scan.cpp} +169 -2
  199. data/vendor/faiss/faiss/impl/fast_scan/fast_scan.h +341 -0
  200. data/vendor/faiss/faiss/impl/fast_scan/impl-avx2.cpp +36 -0
  201. data/vendor/faiss/faiss/impl/fast_scan/impl-avx512.cpp +40 -0
  202. data/vendor/faiss/faiss/impl/fast_scan/impl-neon.cpp +120 -0
  203. data/vendor/faiss/faiss/impl/fast_scan/impl-riscv.cpp +104 -0
  204. data/vendor/faiss/faiss/impl/fast_scan/kernels_simd256.h +213 -0
  205. data/vendor/faiss/faiss/impl/{pq4_fast_scan_search_qbs.cpp → fast_scan/kernels_simd512.h} +26 -356
  206. data/vendor/faiss/faiss/impl/fast_scan/rabitq_dispatching.h +90 -0
  207. data/vendor/faiss/faiss/impl/fast_scan/rabitq_result_handler.h +108 -0
  208. data/vendor/faiss/faiss/impl/{simd_result_handlers.h → fast_scan/simd_result_handlers.h} +282 -134
  209. data/vendor/faiss/faiss/impl/hnsw/LockVector.cpp +54 -0
  210. data/vendor/faiss/faiss/impl/hnsw/LockVector.h +64 -0
  211. data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.cpp +83 -0
  212. data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.h +113 -0
  213. data/vendor/faiss/faiss/impl/hnsw/avx2.cpp +150 -0
  214. data/vendor/faiss/faiss/impl/hnsw/avx512.cpp +142 -0
  215. data/vendor/faiss/faiss/impl/index_read.cpp +1227 -79
  216. data/vendor/faiss/faiss/impl/index_read_utils.h +1 -1
  217. data/vendor/faiss/faiss/impl/index_write.cpp +96 -13
  218. data/vendor/faiss/faiss/impl/io.cpp +6 -6
  219. data/vendor/faiss/faiss/impl/io_macros.h +58 -16
  220. data/vendor/faiss/faiss/impl/kmeans1d.cpp +10 -10
  221. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +37 -23
  222. data/vendor/faiss/faiss/impl/lattice_Zn.h +6 -6
  223. data/vendor/faiss/faiss/impl/mapped_io.cpp +6 -6
  224. data/vendor/faiss/faiss/impl/platform_macros.h +15 -4
  225. data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQScanner_impl.h +549 -0
  226. data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQ_QueryTables.cpp +245 -0
  227. data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQ_QueryTables.h +105 -0
  228. data/vendor/faiss/faiss/impl/pq_code_distance/PQDistanceComputer_impl.h +106 -0
  229. data/vendor/faiss/faiss/impl/pq_code_distance/avx2.cpp +23 -0
  230. data/vendor/faiss/faiss/impl/pq_code_distance/avx512.cpp +23 -0
  231. data/vendor/faiss/faiss/impl/pq_code_distance/neon.cpp +23 -0
  232. data/vendor/faiss/faiss/impl/pq_code_distance/{pq_code_distance-avx2.cpp → pq_code_distance-avx2.h} +9 -13
  233. data/vendor/faiss/faiss/impl/pq_code_distance/{pq_code_distance-avx512.cpp → pq_code_distance-avx512.h} +9 -57
  234. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.cpp +45 -107
  235. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.h +96 -0
  236. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-inl.h +274 -5
  237. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-sve.cpp +10 -7
  238. data/vendor/faiss/faiss/impl/pq_code_distance/pq_scan_impl.h +105 -0
  239. data/vendor/faiss/faiss/impl/pq_code_distance/rvv.cpp +70 -0
  240. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +311 -477
  241. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +1 -1
  242. data/vendor/faiss/faiss/impl/scalar_quantizer/codecs.h +1 -1
  243. data/vendor/faiss/faiss/impl/scalar_quantizer/distance_computers.h +9 -2
  244. data/vendor/faiss/faiss/impl/scalar_quantizer/quantizers.h +419 -19
  245. data/vendor/faiss/faiss/impl/scalar_quantizer/scanners.h +27 -1
  246. data/vendor/faiss/faiss/impl/scalar_quantizer/similarities.h +3 -3
  247. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx2.cpp +387 -2
  248. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512-impl.h +553 -0
  249. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512-spr.cpp +559 -0
  250. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512.cpp +341 -2
  251. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-dispatch.h +425 -3
  252. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-neon.cpp +290 -2
  253. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-rvv.cpp +337 -0
  254. data/vendor/faiss/faiss/impl/scalar_quantizer/training.cpp +192 -8
  255. data/vendor/faiss/faiss/impl/scalar_quantizer/training.h +12 -0
  256. data/vendor/faiss/faiss/impl/simd_dispatch.h +157 -66
  257. data/vendor/faiss/faiss/impl/simdlib/simdlib.h +57 -0
  258. data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_avx2.h +264 -172
  259. data/vendor/faiss/faiss/impl/simdlib/simdlib_avx512.h +414 -0
  260. data/vendor/faiss/faiss/impl/simdlib/simdlib_dispatch.h +44 -0
  261. data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_emulated.h +231 -166
  262. data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_neon.h +270 -218
  263. data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_ppc64.h +201 -160
  264. data/vendor/faiss/faiss/impl/svs_io.cpp +12 -3
  265. data/vendor/faiss/faiss/impl/svs_io.h +8 -2
  266. data/vendor/faiss/faiss/index_factory.cpp +90 -18
  267. data/vendor/faiss/faiss/index_io.h +40 -0
  268. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +66 -16
  269. data/vendor/faiss/faiss/invlists/DirectMap.cpp +28 -15
  270. data/vendor/faiss/faiss/invlists/DirectMap.h +4 -3
  271. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +170 -86
  272. data/vendor/faiss/faiss/invlists/InvertedLists.h +88 -25
  273. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +4 -4
  274. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +13 -13
  275. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
  276. data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +1 -1
  277. data/vendor/faiss/faiss/svs/IndexSVSFlat.cpp +2 -2
  278. data/vendor/faiss/faiss/svs/IndexSVSIVF.cpp +350 -0
  279. data/vendor/faiss/faiss/svs/IndexSVSIVF.h +128 -0
  280. data/vendor/faiss/faiss/svs/IndexSVSIVFLVQ.cpp +40 -0
  281. data/vendor/faiss/faiss/svs/IndexSVSIVFLVQ.h +43 -0
  282. data/vendor/faiss/faiss/svs/IndexSVSIVFLeanVec.cpp +225 -0
  283. data/vendor/faiss/faiss/svs/IndexSVSIVFLeanVec.h +71 -0
  284. data/vendor/faiss/faiss/svs/IndexSVSVamana.cpp +142 -21
  285. data/vendor/faiss/faiss/svs/IndexSVSVamana.h +33 -7
  286. data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.cpp +3 -2
  287. data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.h +2 -1
  288. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +77 -27
  289. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +10 -4
  290. data/vendor/faiss/faiss/utils/Heap.cpp +10 -10
  291. data/vendor/faiss/faiss/utils/NeuralNet.cpp +47 -36
  292. data/vendor/faiss/faiss/utils/NeuralNet.h +1 -1
  293. data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +10 -4
  294. data/vendor/faiss/faiss/utils/bf16.h +34 -0
  295. data/vendor/faiss/faiss/utils/distances.cpp +390 -560
  296. data/vendor/faiss/faiss/utils/distances.h +20 -1
  297. data/vendor/faiss/faiss/utils/distances_dispatch.h +117 -37
  298. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +8 -7
  299. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +33 -14
  300. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +12 -1
  301. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +16 -293
  302. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based_neon.cpp +57 -0
  303. data/vendor/faiss/faiss/utils/distances_fused/simdlib_kernel-inl.h +290 -0
  304. data/vendor/faiss/faiss/utils/distances_simd.cpp +5 -178
  305. data/vendor/faiss/faiss/utils/extra_distances.cpp +9 -8
  306. data/vendor/faiss/faiss/utils/extra_distances.h +32 -6
  307. data/vendor/faiss/faiss/utils/hamming-inl.h +13 -11
  308. data/vendor/faiss/faiss/utils/hamming.cpp +66 -517
  309. data/vendor/faiss/faiss/utils/hamming.h +92 -2
  310. data/vendor/faiss/faiss/utils/hamming_distance/common.h +287 -10
  311. data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx2.cpp +16 -0
  312. data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx512.cpp +15 -0
  313. data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx512_spr.cpp +15 -0
  314. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx2.h +142 -0
  315. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx512.h +210 -0
  316. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx512_spr.h +171 -0
  317. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-generic.h +368 -0
  318. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-neon.h +322 -0
  319. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-rvv.h +39 -0
  320. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer.h +146 -0
  321. data/vendor/faiss/faiss/utils/hamming_distance/hamming_impl.h +481 -0
  322. data/vendor/faiss/faiss/utils/hamming_distance/hamming_neon.cpp +15 -0
  323. data/vendor/faiss/faiss/utils/hamming_distance/hamming_rvv.cpp +15 -0
  324. data/vendor/faiss/faiss/utils/partitioning.cpp +66 -989
  325. data/vendor/faiss/faiss/utils/partitioning.h +31 -0
  326. data/vendor/faiss/faiss/utils/popcount.h +29 -0
  327. data/vendor/faiss/faiss/utils/pq_code_distance.h +2 -2
  328. data/vendor/faiss/faiss/utils/prefetch.h +2 -2
  329. data/vendor/faiss/faiss/utils/quantize_lut.cpp +30 -30
  330. data/vendor/faiss/faiss/utils/quantize_lut.h +1 -1
  331. data/vendor/faiss/faiss/utils/rabitq_simd.h +57 -536
  332. data/vendor/faiss/faiss/utils/random.cpp +6 -6
  333. data/vendor/faiss/faiss/utils/simd_impl/IVFFlatScanner-inl.h +51 -0
  334. data/vendor/faiss/faiss/utils/simd_impl/distances_aarch64.cpp +5 -1
  335. data/vendor/faiss/faiss/utils/simd_impl/distances_arm_sve.cpp +213 -4
  336. data/vendor/faiss/faiss/utils/simd_impl/distances_autovec-inl.h +163 -10
  337. data/vendor/faiss/faiss/utils/simd_impl/distances_avx2.cpp +250 -4
  338. data/vendor/faiss/faiss/utils/simd_impl/distances_avx512.cpp +7 -4
  339. data/vendor/faiss/faiss/utils/simd_impl/distances_rvv.cpp +189 -0
  340. data/vendor/faiss/faiss/utils/simd_impl/distances_simdlib256.h +195 -0
  341. data/vendor/faiss/faiss/utils/simd_impl/distances_sse-inl.h +2 -1
  342. data/vendor/faiss/faiss/utils/{distances_fused/simdlib_based.h → simd_impl/exhaustive_L2sqr_blas_cmax.h} +5 -10
  343. data/vendor/faiss/faiss/utils/simd_impl/hamming_impl.h +481 -0
  344. data/vendor/faiss/faiss/utils/simd_impl/partitioning_avx2.cpp +14 -0
  345. data/vendor/faiss/faiss/utils/simd_impl/partitioning_neon.cpp +14 -0
  346. data/vendor/faiss/faiss/utils/simd_impl/partitioning_simdlib256.h +1031 -0
  347. data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx2.cpp +355 -0
  348. data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx512.cpp +477 -0
  349. data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx512_spr.cpp +343 -0
  350. data/vendor/faiss/faiss/utils/simd_impl/rabitq_neon.cpp +55 -0
  351. data/vendor/faiss/faiss/utils/simd_impl/rabitq_rvv.cpp +55 -0
  352. data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_dispatch.h +32 -0
  353. data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels.h +43 -0
  354. data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels_avx2.cpp +57 -0
  355. data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels_avx512.cpp +45 -0
  356. data/vendor/faiss/faiss/utils/simd_levels.cpp +29 -7
  357. data/vendor/faiss/faiss/utils/simd_levels.h +93 -1
  358. data/vendor/faiss/faiss/utils/sorting.cpp +48 -36
  359. data/vendor/faiss/faiss/utils/utils.cpp +5 -5
  360. data/vendor/faiss/faiss/utils/utils.h +3 -3
  361. metadata +129 -34
  362. data/vendor/faiss/faiss/impl/RaBitQStats.cpp +0 -29
  363. data/vendor/faiss/faiss/impl/RaBitQStats.h +0 -56
  364. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +0 -224
  365. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +0 -230
  366. data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +0 -84
  367. data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +0 -196
  368. data/vendor/faiss/faiss/utils/approx_topk/mode.h +0 -34
  369. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +0 -36
  370. data/vendor/faiss/faiss/utils/extra_distances-inl.h +0 -235
  371. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +0 -462
  372. data/vendor/faiss/faiss/utils/hamming_distance/avx512-inl.h +0 -490
  373. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +0 -449
  374. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +0 -87
  375. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +0 -524
  376. data/vendor/faiss/faiss/utils/simdlib.h +0 -42
  377. data/vendor/faiss/faiss/utils/simdlib_avx512.h +0 -365
  378. /data/ext/faiss/{utils_rb.h → utils.h} +0 -0
@@ -9,6 +9,8 @@
9
9
 
10
10
  #include <cinttypes>
11
11
  #include <cstddef>
12
+ #include <cstdlib>
13
+ #include <type_traits>
12
14
 
13
15
  #include <faiss/IndexHNSW.h>
14
16
 
@@ -16,13 +18,7 @@
16
18
  #include <faiss/impl/IDSelector.h>
17
19
  #include <faiss/impl/ResultHandler.h>
18
20
  #include <faiss/impl/VisitedTable.h>
19
-
20
- #ifdef __AVX2__
21
- #include <immintrin.h>
22
-
23
- #include <limits>
24
- #include <type_traits>
25
- #endif
21
+ #include <faiss/impl/hnsw/MinimaxHeap.h>
26
22
 
27
23
  namespace faiss {
28
24
 
@@ -31,7 +27,8 @@ namespace faiss {
31
27
  **************************************************************/
32
28
 
33
29
  int HNSW::nb_neighbors(int layer_no) const {
34
- FAISS_THROW_IF_NOT(layer_no + 1 < cum_nneighbor_per_level.size());
30
+ FAISS_THROW_IF_NOT(
31
+ static_cast<size_t>(layer_no + 1) < cum_nneighbor_per_level.size());
35
32
  return cum_nneighbor_per_level[layer_no + 1] -
36
33
  cum_nneighbor_per_level[layer_no];
37
34
  }
@@ -39,7 +36,7 @@ int HNSW::nb_neighbors(int layer_no) const {
39
36
  void HNSW::set_nb_neighbors(int level_no, int n) {
40
37
  FAISS_THROW_IF_NOT(levels.size() == 0);
41
38
  int cur_n = nb_neighbors(level_no);
42
- for (int i = level_no + 1; i < cum_nneighbor_per_level.size(); i++) {
39
+ for (size_t i = level_no + 1; i < cum_nneighbor_per_level.size(); i++) {
43
40
  cum_nneighbor_per_level[i] += n - cur_n;
44
41
  }
45
42
  }
@@ -67,7 +64,7 @@ HNSW::HNSW(int M) : rng(12345) {
67
64
  int HNSW::random_level() {
68
65
  double f = rng.rand_float();
69
66
  // could be a bit faster with bisection
70
- for (int level = 0; level < assign_probas.size(); level++) {
67
+ for (size_t level = 0; level < assign_probas.size(); level++) {
71
68
  if (f < assign_probas[level]) {
72
69
  return level;
73
70
  }
@@ -92,7 +89,7 @@ void HNSW::set_default_probas(int M, float levelMult) {
92
89
  }
93
90
 
94
91
  void HNSW::clear_neighbor_tables(int level) {
95
- for (int i = 0; i < levels.size(); i++) {
92
+ for (size_t i = 0; i < levels.size(); i++) {
96
93
  size_t begin, end;
97
94
  neighbor_range(i, level, &begin, &end);
98
95
  for (size_t j = begin; j < end; j++) {
@@ -111,14 +108,15 @@ void HNSW::reset() {
111
108
  }
112
109
 
113
110
  void HNSW::print_neighbor_stats(int level) const {
114
- FAISS_THROW_IF_NOT(level < cum_nneighbor_per_level.size());
111
+ FAISS_THROW_IF_NOT(
112
+ static_cast<size_t>(level) < cum_nneighbor_per_level.size());
115
113
  printf("stats on level %d, max %d neighbors per vertex:\n",
116
114
  level,
117
115
  nb_neighbors(level));
118
116
  size_t tot_neigh = 0, tot_common = 0, tot_reciprocal = 0, n_node = 0;
119
117
  #pragma omp parallel for reduction(+ : tot_neigh) reduction(+ : tot_common) \
120
118
  reduction(+ : tot_reciprocal) reduction(+ : n_node)
121
- for (int i = 0; i < levels.size(); i++) {
119
+ for (idx_t i = 0; i < static_cast<idx_t>(levels.size()); i++) {
122
120
  if (levels[i] > level) {
123
121
  n_node++;
124
122
  size_t begin, end;
@@ -130,7 +128,7 @@ void HNSW::print_neighbor_stats(int level) const {
130
128
  }
131
129
  neighset.insert(neighbors[j]);
132
130
  }
133
- int n_neigh = neighset.size();
131
+ size_t n_neigh = neighset.size();
134
132
  int n_common = 0;
135
133
  int n_reciprocal = 0;
136
134
  for (size_t j = begin; j < end; j++) {
@@ -179,7 +177,7 @@ void HNSW::fill_with_random_links(size_t n) {
179
177
 
180
178
  for (int level = max_level_2 - 1; level >= 0; --level) {
181
179
  std::vector<int> elts;
182
- for (int i = 0; i < n; i++) {
180
+ for (size_t i = 0; i < n; i++) {
183
181
  if (levels[i] > level) {
184
182
  elts.push_back(i);
185
183
  }
@@ -190,10 +188,10 @@ void HNSW::fill_with_random_links(size_t n) {
190
188
  continue;
191
189
  }
192
190
 
193
- for (int ii = 0; ii < elts.size(); ii++) {
191
+ for (size_t ii = 0; ii < elts.size(); ii++) {
194
192
  int i = elts[ii];
195
193
  size_t begin, end;
196
- neighbor_range(i, 0, &begin, &end);
194
+ neighbor_range(i, level, &begin, &end);
197
195
  for (size_t j = begin; j < end; j++) {
198
196
  int other = 0;
199
197
  do {
@@ -213,14 +211,14 @@ int HNSW::prepare_level_tab(size_t n, bool preset_levels) {
213
211
  FAISS_ASSERT(n0 + n == levels.size());
214
212
  } else {
215
213
  FAISS_ASSERT(n0 == levels.size());
216
- for (int i = 0; i < n; i++) {
214
+ for (size_t i = 0; i < n; i++) {
217
215
  int pt_level = random_level();
218
216
  levels.push_back(pt_level + 1);
219
217
  }
220
218
  }
221
219
 
222
220
  int max_level_2 = 0;
223
- for (int i = 0; i < n; i++) {
221
+ for (size_t i = 0; i < n; i++) {
224
222
  int pt_level = levels[i + n0] - 1;
225
223
  if (pt_level > max_level_2) {
226
224
  max_level_2 = pt_level;
@@ -236,28 +234,32 @@ int HNSW::prepare_level_tab(size_t n, bool preset_levels) {
236
234
  * neighbor only if there is no previous neighbor that is closer to
237
235
  * that vertex than the query.
238
236
  */
237
+ template <class Comp>
239
238
  void HNSW::shrink_neighbor_list(
240
239
  DistanceComputer& qdis,
241
- std::priority_queue<NodeDistFarther>& input,
242
- std::vector<NodeDistFarther>& output,
243
- int max_size,
240
+ std::priority_queue<NodeDistFartherT<Comp>>& input,
241
+ std::vector<NodeDistFartherT<Comp>>& output,
242
+ size_t max_size,
244
243
  bool keep_max_size_level0) {
245
244
  // This prevents number of neighbors at
246
245
  // level 0 from being shrunk to less than 2 * M.
247
246
  // This is essential in making sure
248
247
  // `faiss::gpu::GpuIndexCagra::copyFrom(IndexHNSWCagra*)` is functional
249
- std::vector<NodeDistFarther> outsiders;
248
+ std::vector<NodeDistFartherT<Comp>> outsiders;
250
249
 
251
250
  while (input.size() > 0) {
252
- NodeDistFarther v1 = input.top();
251
+ NodeDistFartherT<Comp> v1 = input.top();
253
252
  input.pop();
254
253
  float dist_v1_q = v1.d;
255
254
 
256
255
  bool good = true;
257
- for (NodeDistFarther v2 : output) {
256
+ for (NodeDistFartherT<Comp> v2 : output) {
258
257
  float dist_v1_v2 = qdis.symmetric_dis(v2.id, v1.id);
259
258
 
260
- if (dist_v1_v2 < dist_v1_q) {
259
+ // "v1 is bad" if some previously-kept neighbor v2 is closer
260
+ // (more similar, under CMin) to v1 than the query is. Encoded
261
+ // generically as: v1v2 is "better than" v1q under Comp.
262
+ if (Comp::cmp(dist_v1_q, dist_v1_v2)) {
261
263
  good = false;
262
264
  break;
263
265
  }
@@ -265,7 +267,7 @@ void HNSW::shrink_neighbor_list(
265
267
 
266
268
  if (good) {
267
269
  output.push_back(v1);
268
- if (output.size() >= max_size) {
270
+ if (output.size() >= static_cast<size_t>(max_size)) {
269
271
  return;
270
272
  }
271
273
  } else if (keep_max_size_level0) {
@@ -273,50 +275,95 @@ void HNSW::shrink_neighbor_list(
273
275
  }
274
276
  }
275
277
  size_t idx = 0;
276
- while (keep_max_size_level0 && (output.size() < max_size) &&
278
+ while (keep_max_size_level0 &&
279
+ (output.size() < static_cast<size_t>(max_size)) &&
277
280
  (idx < outsiders.size())) {
278
281
  output.push_back(outsiders[idx++]);
279
282
  }
280
283
  }
281
284
 
285
+ // Explicit instantiations for the two supported comparators.
286
+ template void HNSW::shrink_neighbor_list<HNSW::C_distance>(
287
+ DistanceComputer&,
288
+ std::priority_queue<HNSW::NodeDistFartherT<HNSW::C_distance>>&,
289
+ std::vector<HNSW::NodeDistFartherT<HNSW::C_distance>>&,
290
+ size_t,
291
+ bool);
292
+ template void HNSW::shrink_neighbor_list<HNSW::C_similarity>(
293
+ DistanceComputer&,
294
+ std::priority_queue<HNSW::NodeDistFartherT<HNSW::C_similarity>>&,
295
+ std::vector<HNSW::NodeDistFartherT<HNSW::C_similarity>>&,
296
+ size_t,
297
+ bool);
298
+
282
299
  namespace {
283
300
 
284
301
  using storage_idx_t = HNSW::storage_idx_t;
285
- using NodeDistCloser = HNSW::NodeDistCloser;
286
- using NodeDistFarther = HNSW::NodeDistFarther;
302
+
303
+ // Map a (high-level) HNSW comparator C — which uses int64_t IDs — to the
304
+ // (low-level) MinimaxHeap comparator HC, which uses int32_t IDs.
305
+ template <class C>
306
+ using HC_for = std::
307
+ conditional_t<C::is_max, CMax<float, int32_t>, CMin<float, int32_t>>;
308
+
309
+ // Priority queue types used by the unbounded search variant. For CMax
310
+ // (distance) "top_candidates" is a max-heap of the kept-so-far results
311
+ // (top is the farthest) and "candidates" is a min-heap of the next nodes
312
+ // to explore (top is the closest). For CMin (similarity) the orderings are
313
+ // swapped: top_candidates is a min-heap (top is the least similar) and
314
+ // candidates is a max-heap (top is the most similar).
315
+ template <class C>
316
+ using TopCandidatesQueue = std::conditional_t<
317
+ C::is_max,
318
+ std::priority_queue<HNSW::Node>,
319
+ std::priority_queue<
320
+ HNSW::Node,
321
+ std::vector<HNSW::Node>,
322
+ std::greater<HNSW::Node>>>;
323
+
324
+ template <class C>
325
+ using CandidatesQueue = std::conditional_t<
326
+ C::is_max,
327
+ std::priority_queue<
328
+ HNSW::Node,
329
+ std::vector<HNSW::Node>,
330
+ std::greater<HNSW::Node>>,
331
+ std::priority_queue<HNSW::Node>>;
287
332
 
288
333
  /**************************************************************
289
334
  * Addition subroutines
290
335
  **************************************************************/
291
336
 
292
337
  /// remove neighbors from the list to make it smaller than max_size
293
- void shrink_neighbor_list(
338
+ template <class C>
339
+ void shrink_neighbor_list_inner(
294
340
  DistanceComputer& qdis,
295
- std::priority_queue<NodeDistCloser>& resultSet1,
296
- int max_size,
341
+ std::priority_queue<HNSW::NodeDistCloserT<C>>& resultSet1,
342
+ size_t max_size,
297
343
  bool keep_max_size_level0 = false) {
298
- if (resultSet1.size() < max_size) {
344
+ if (resultSet1.size() < static_cast<size_t>(max_size)) {
299
345
  return;
300
346
  }
301
- std::priority_queue<NodeDistFarther> resultSet;
302
- std::vector<NodeDistFarther> returnlist;
347
+ std::priority_queue<HNSW::NodeDistFartherT<C>> resultSet;
348
+ std::vector<HNSW::NodeDistFartherT<C>> returnlist;
303
349
 
304
350
  while (resultSet1.size() > 0) {
305
351
  resultSet.emplace(resultSet1.top().d, resultSet1.top().id);
306
352
  resultSet1.pop();
307
353
  }
308
354
 
309
- HNSW::shrink_neighbor_list(
355
+ HNSW::shrink_neighbor_list<C>(
310
356
  qdis, resultSet, returnlist, max_size, keep_max_size_level0);
311
357
 
312
- for (NodeDistFarther curen2 : returnlist) {
358
+ for (HNSW::NodeDistFartherT<C> curen2 : returnlist) {
313
359
  resultSet1.emplace(curen2.d, curen2.id);
314
360
  }
315
361
  }
316
362
 
317
363
  /// add a link between two elements, possibly shrinking the list
318
364
  /// of links to make room for it.
319
- void add_link(
365
+ template <class C>
366
+ void add_link_tpl(
320
367
  HNSW& hnsw,
321
368
  DistanceComputer& qdis,
322
369
  storage_idx_t src,
@@ -341,14 +388,17 @@ void add_link(
341
388
  // otherwise we let them fight out which to keep
342
389
 
343
390
  // copy to resultSet...
344
- std::priority_queue<NodeDistCloser> resultSet;
391
+ std::priority_queue<HNSW::NodeDistCloserT<C>> resultSet;
345
392
  resultSet.emplace(qdis.symmetric_dis(src, dest), dest);
346
- for (size_t i = begin; i < end; i++) { // HERE WAS THE BUG
393
+ for (size_t i = begin; i < end; i++) {
347
394
  storage_idx_t neigh = hnsw.neighbors[i];
348
395
  resultSet.emplace(qdis.symmetric_dis(src, neigh), neigh);
349
396
  }
350
397
 
351
- shrink_neighbor_list(qdis, resultSet, end - begin, keep_max_size_level0);
398
+ size_t max_size = end - begin;
399
+ max_size -= max_size * std::clamp(hnsw.prune_headroom, 0.0f, 0.5f);
400
+ shrink_neighbor_list_inner<C>(
401
+ qdis, resultSet, max_size, keep_max_size_level0);
352
402
 
353
403
  // ...and back
354
404
  size_t i = begin;
@@ -362,31 +412,33 @@ void add_link(
362
412
  }
363
413
  }
364
414
 
365
- } // namespace
366
-
367
- /// search neighbors on a single level, starting from an entry point
368
- void search_neighbors_to_add(
415
+ /** Templated body of `search_neighbors_to_add` — instantiated once per final
416
+ * VisitedTable subclass × comparator so that `vt.set/advance` are inlined
417
+ * and the cost of virtual dispatch is paid only once at the top of the call.
418
+ */
419
+ template <typename VTType, class C>
420
+ static void search_neighbors_to_add_fixVT(
369
421
  HNSW& hnsw,
370
422
  DistanceComputer& qdis,
371
- std::priority_queue<NodeDistCloser>& results,
423
+ std::priority_queue<HNSW::NodeDistCloserT<C>>& results,
372
424
  int entry_point,
373
425
  float d_entry_point,
374
426
  int level,
375
- VisitedTable& vt,
427
+ VTType& vt,
376
428
  bool reference_version) {
377
429
  // top is nearest candidate
378
- std::priority_queue<NodeDistFarther> candidates;
430
+ std::priority_queue<HNSW::NodeDistFartherT<C>> candidates;
379
431
 
380
- NodeDistFarther ev(d_entry_point, entry_point);
432
+ HNSW::NodeDistFartherT<C> ev(d_entry_point, entry_point);
381
433
  candidates.push(ev);
382
434
  results.emplace(d_entry_point, entry_point);
383
435
  vt.set(entry_point);
384
436
 
385
437
  while (!candidates.empty()) {
386
438
  // get nearest
387
- const NodeDistFarther& currEv = candidates.top();
439
+ const HNSW::NodeDistFartherT<C>& currEv = candidates.top();
388
440
 
389
- if (currEv.d > results.top().d) {
441
+ if (C::cmp(currEv.d, results.top().d)) {
390
442
  break;
391
443
  }
392
444
  int currNode = currEv.id;
@@ -407,7 +459,7 @@ void search_neighbors_to_add(
407
459
  if (reference_version) {
408
460
  // a reference version
409
461
  for (size_t i = begin; i < end; i++) {
410
- storage_idx_t nodeId = hnsw.neighbors[i];
462
+ HNSW::storage_idx_t nodeId = hnsw.neighbors[i];
411
463
  if (nodeId < 0) {
412
464
  break;
413
465
  }
@@ -416,13 +468,14 @@ void search_neighbors_to_add(
416
468
  }
417
469
 
418
470
  float dis = qdis(nodeId);
419
- NodeDistFarther evE1(dis, nodeId);
471
+ HNSW::NodeDistFartherT<C> evE1(dis, nodeId);
420
472
 
421
- if (results.size() < hnsw.efConstruction ||
422
- results.top().d > dis) {
473
+ if (results.size() < static_cast<size_t>(hnsw.efConstruction) ||
474
+ C::cmp(results.top().d, dis)) {
423
475
  results.emplace(dis, nodeId);
424
476
  candidates.emplace(dis, nodeId);
425
- if (results.size() > hnsw.efConstruction) {
477
+ if (results.size() >
478
+ static_cast<size_t>(hnsw.efConstruction)) {
426
479
  results.pop();
427
480
  }
428
481
  }
@@ -431,23 +484,24 @@ void search_neighbors_to_add(
431
484
  // a faster version
432
485
 
433
486
  // the following version processes 4 neighbors at a time
434
- auto update_with_candidate = [&](const storage_idx_t idx,
487
+ auto update_with_candidate = [&](const HNSW::storage_idx_t idx,
435
488
  const float dis) {
436
- if (results.size() < hnsw.efConstruction ||
437
- results.top().d > dis) {
489
+ if (results.size() < static_cast<size_t>(hnsw.efConstruction) ||
490
+ C::cmp(results.top().d, dis)) {
438
491
  results.emplace(dis, idx);
439
492
  candidates.emplace(dis, idx);
440
- if (results.size() > hnsw.efConstruction) {
493
+ if (results.size() >
494
+ static_cast<size_t>(hnsw.efConstruction)) {
441
495
  results.pop();
442
496
  }
443
497
  }
444
498
  };
445
499
 
446
500
  int n_buffered = 0;
447
- storage_idx_t buffered_ids[4];
501
+ HNSW::storage_idx_t buffered_ids[4];
448
502
 
449
503
  for (size_t j = begin; j < end; j++) {
450
- storage_idx_t nodeId = hnsw.neighbors[j];
504
+ HNSW::storage_idx_t nodeId = hnsw.neighbors[j];
451
505
  if (nodeId < 0) {
452
506
  break;
453
507
  }
@@ -479,7 +533,7 @@ void search_neighbors_to_add(
479
533
  }
480
534
 
481
535
  // process leftovers
482
- for (size_t icnt = 0; icnt < n_buffered; icnt++) {
536
+ for (int icnt = 0; icnt < n_buffered; icnt++) {
483
537
  float dis = qdis(buffered_ids[icnt]);
484
538
  update_with_candidate(buffered_ids[icnt], dis);
485
539
  }
@@ -489,66 +543,263 @@ void search_neighbors_to_add(
489
543
  vt.advance();
490
544
  }
491
545
 
492
- /// Finds neighbors and builds links with them, starting from an entry
493
- /// point. The own neighbor list is assumed to be locked.
494
- void HNSW::add_links_starting_from(
546
+ /// Dispatches the VisitedTable concrete type for a given C, then calls
547
+ /// the templated `search_neighbors_to_add_fixVT<VTType, C>`.
548
+ template <class C>
549
+ void search_neighbors_to_add_dispatch(
550
+ HNSW& hnsw,
551
+ DistanceComputer& qdis,
552
+ std::priority_queue<HNSW::NodeDistCloserT<C>>& results,
553
+ int entry_point,
554
+ float d_entry_point,
555
+ int level,
556
+ VisitedTable& vt,
557
+ bool reference_version) {
558
+ auto call = [&]<typename VTType>(VTType& vt_concrete) {
559
+ search_neighbors_to_add_fixVT<VTType, C>(
560
+ hnsw,
561
+ qdis,
562
+ results,
563
+ entry_point,
564
+ d_entry_point,
565
+ level,
566
+ vt_concrete,
567
+ reference_version);
568
+ };
569
+ if (VisitedTableVector* vtv = dynamic_cast<VisitedTableVector*>(&vt)) {
570
+ call(*vtv);
571
+ return;
572
+ }
573
+ VisitedTableSet& vts = dynamic_cast<VisitedTableSet&>(vt);
574
+ call(vts);
575
+ }
576
+
577
+ /// Templated implementation of `HNSW::add_links_starting_from`.
578
+ template <class C>
579
+ void add_links_starting_from_impl(
580
+ HNSW& hnsw,
495
581
  DistanceComputer& ptdis,
496
582
  storage_idx_t pt_id,
497
583
  storage_idx_t nearest,
498
584
  float d_nearest,
499
585
  int level,
500
- omp_lock_t* locks,
586
+ LockVector& locks,
501
587
  VisitedTable& vt,
502
588
  bool keep_max_size_level0) {
503
- std::priority_queue<NodeDistCloser> link_targets;
589
+ std::priority_queue<HNSW::NodeDistCloserT<C>> link_targets;
504
590
 
505
- search_neighbors_to_add(
506
- *this, ptdis, link_targets, nearest, d_nearest, level, vt);
591
+ search_neighbors_to_add_dispatch<C>(
592
+ hnsw, ptdis, link_targets, nearest, d_nearest, level, vt, false);
507
593
 
508
594
  // but we can afford only this many neighbors
509
- int M = nb_neighbors(level);
595
+ int M = hnsw.nb_neighbors(level);
510
596
 
511
- ::faiss::shrink_neighbor_list(ptdis, link_targets, M, keep_max_size_level0);
597
+ shrink_neighbor_list_inner<C>(ptdis, link_targets, M, keep_max_size_level0);
512
598
 
513
599
  std::vector<storage_idx_t> neighbors_to_add;
514
600
  neighbors_to_add.reserve(link_targets.size());
515
601
  while (!link_targets.empty()) {
516
602
  storage_idx_t other_id = link_targets.top().id;
517
- add_link(*this, ptdis, pt_id, other_id, level, keep_max_size_level0);
603
+ add_link_tpl<C>(
604
+ hnsw, ptdis, pt_id, other_id, level, keep_max_size_level0);
518
605
  neighbors_to_add.push_back(other_id);
519
606
  link_targets.pop();
520
607
  }
521
608
 
522
- omp_unset_lock(&locks[pt_id]);
609
+ locks.unlock(pt_id);
523
610
  for (storage_idx_t other_id : neighbors_to_add) {
524
- omp_set_lock(&locks[other_id]);
525
- add_link(*this, ptdis, other_id, pt_id, level, keep_max_size_level0);
526
- omp_unset_lock(&locks[other_id]);
611
+ locks.lock(other_id);
612
+ add_link_tpl<C>(
613
+ hnsw, ptdis, other_id, pt_id, level, keep_max_size_level0);
614
+ locks.unlock(other_id);
615
+ }
616
+ locks.lock(pt_id);
617
+ }
618
+
619
+ } // namespace
620
+
621
+ /// Finds neighbors and builds links with them, starting from an entry
622
+ /// point. The own neighbor list is assumed to be locked.
623
+ void HNSW::add_links_starting_from(
624
+ DistanceComputer& ptdis,
625
+ storage_idx_t pt_id,
626
+ storage_idx_t nearest,
627
+ float d_nearest,
628
+ int level,
629
+ LockVector& locks,
630
+ VisitedTable& vt,
631
+ bool keep_max_size_level0) {
632
+ if (is_similarity) {
633
+ add_links_starting_from_impl<C_similarity>(
634
+ *this,
635
+ ptdis,
636
+ pt_id,
637
+ nearest,
638
+ d_nearest,
639
+ level,
640
+ locks,
641
+ vt,
642
+ keep_max_size_level0);
643
+ } else {
644
+ add_links_starting_from_impl<C_distance>(
645
+ *this,
646
+ ptdis,
647
+ pt_id,
648
+ nearest,
649
+ d_nearest,
650
+ level,
651
+ locks,
652
+ vt,
653
+ keep_max_size_level0);
527
654
  }
528
- omp_set_lock(&locks[pt_id]);
655
+ }
656
+
657
+ /// search neighbors on a single level, starting from an entry point.
658
+ /// Public dispatcher: always operates in distance (CMax) mode because its
659
+ /// `priority_queue<HNSW::NodeDistCloser>` signature is the back-compat
660
+ /// distance flavor. Internal callers that need similarity mode reach the
661
+ /// templated implementation directly via `search_neighbors_to_add_dispatch`.
662
+ void hnsw_detail::search_neighbors_to_add(
663
+ HNSW& hnsw,
664
+ DistanceComputer& qdis,
665
+ std::priority_queue<HNSW::NodeDistCloser>& results,
666
+ int entry_point,
667
+ float d_entry_point,
668
+ int level,
669
+ VisitedTable& vt,
670
+ bool reference_version) {
671
+ search_neighbors_to_add_dispatch<HNSW::C_distance>(
672
+ hnsw,
673
+ qdis,
674
+ results,
675
+ entry_point,
676
+ d_entry_point,
677
+ level,
678
+ vt,
679
+ reference_version);
529
680
  }
530
681
 
531
682
  /**************************************************************
532
683
  * Building, parallel
533
684
  **************************************************************/
534
685
 
535
- void HNSW::add_with_locks(
686
+ namespace {
687
+
688
+ /// Greedy update of the nearest entry point at a given level.
689
+ template <class C>
690
+ HNSWStats greedy_update_nearest_impl(
691
+ const HNSW& hnsw,
692
+ DistanceComputer& qdis,
693
+ int level,
694
+ storage_idx_t& nearest,
695
+ float& d_nearest) {
696
+ HNSWStats stats;
697
+
698
+ for (;;) {
699
+ storage_idx_t prev_nearest = nearest;
700
+
701
+ size_t begin, end;
702
+ hnsw.neighbor_range(nearest, level, &begin, &end);
703
+
704
+ size_t ndis = 0;
705
+
706
+ // a faster version: reference version in unit test test_hnsw.cpp
707
+ // the following version processes 4 neighbors at a time
708
+ auto update_with_candidate = [&](const storage_idx_t idx,
709
+ const float dis) {
710
+ if (C::cmp(d_nearest, dis)) {
711
+ nearest = idx;
712
+ d_nearest = dis;
713
+ }
714
+ };
715
+
716
+ int n_buffered = 0;
717
+ storage_idx_t buffered_ids[4];
718
+
719
+ for (size_t j = begin; j < end; j++) {
720
+ storage_idx_t v = hnsw.neighbors[j];
721
+ if (v < 0) {
722
+ break;
723
+ }
724
+ ndis += 1;
725
+
726
+ buffered_ids[n_buffered] = v;
727
+ n_buffered += 1;
728
+
729
+ if (n_buffered == 4) {
730
+ float dis[4];
731
+ qdis.distances_batch_4(
732
+ buffered_ids[0],
733
+ buffered_ids[1],
734
+ buffered_ids[2],
735
+ buffered_ids[3],
736
+ dis[0],
737
+ dis[1],
738
+ dis[2],
739
+ dis[3]);
740
+
741
+ for (size_t id4 = 0; id4 < 4; id4++) {
742
+ update_with_candidate(buffered_ids[id4], dis[id4]);
743
+ }
744
+
745
+ n_buffered = 0;
746
+ }
747
+ }
748
+
749
+ // process leftovers
750
+ for (int icnt = 0; icnt < n_buffered; icnt++) {
751
+ float dis = qdis(buffered_ids[icnt]);
752
+ update_with_candidate(buffered_ids[icnt], dis);
753
+ }
754
+
755
+ // update stats
756
+ stats.ndis += ndis;
757
+ stats.nhops += 1;
758
+
759
+ if (nearest == prev_nearest) {
760
+ return stats;
761
+ }
762
+ }
763
+ }
764
+
765
+ } // namespace
766
+
767
+ /// greedily update a nearest vector at a given level
768
+ HNSWStats hnsw_detail::greedy_update_nearest(
769
+ const HNSW& hnsw,
770
+ DistanceComputer& qdis,
771
+ int level,
772
+ storage_idx_t& nearest,
773
+ float& d_nearest) {
774
+ if (hnsw.is_similarity) {
775
+ return greedy_update_nearest_impl<HNSW::C_similarity>(
776
+ hnsw, qdis, level, nearest, d_nearest);
777
+ }
778
+ return greedy_update_nearest_impl<HNSW::C_distance>(
779
+ hnsw, qdis, level, nearest, d_nearest);
780
+ }
781
+
782
+ namespace {
783
+
784
+ template <class C>
785
+ void add_with_locks_impl(
786
+ HNSW& hnsw,
536
787
  DistanceComputer& ptdis,
537
788
  int pt_level,
538
789
  int pt_id,
539
- std::vector<omp_lock_t>& locks,
790
+ LockVector& locks,
540
791
  VisitedTable& vt,
541
792
  bool keep_max_size_level0) {
542
- // greedy search on upper levels
543
-
544
- storage_idx_t nearest;
793
+ storage_idx_t nearest = hnsw.entry_point;
794
+ if (nearest == -1) { // avoid locking after the first point.
545
795
  #pragma omp critical
546
- {
547
- nearest = entry_point;
548
-
549
- if (nearest == -1) {
550
- max_level = pt_level;
551
- entry_point = pt_id;
796
+ if (hnsw.entry_point == -1) { // double-check under lock.
797
+ hnsw.max_level = pt_level;
798
+ hnsw.entry_point = pt_id;
799
+ // leave nearest = -1 to trigger early exit after critical block.
800
+ } else {
801
+ // else: Another thread set the entry point.
802
+ nearest = hnsw.entry_point;
552
803
  }
553
804
  }
554
805
 
@@ -556,32 +807,55 @@ void HNSW::add_with_locks(
556
807
  return;
557
808
  }
558
809
 
559
- omp_set_lock(&locks[pt_id]);
810
+ locks.lock(pt_id);
560
811
 
561
- int level = max_level; // level at which we start adding neighbors
812
+ int level = hnsw.max_level; // level at which we start adding neighbors
562
813
  float d_nearest = ptdis(nearest);
563
814
 
815
+ // greedy search on upper levels
564
816
  for (; level > pt_level; level--) {
565
- greedy_update_nearest(*this, ptdis, level, nearest, d_nearest);
817
+ greedy_update_nearest_impl<C>(hnsw, ptdis, level, nearest, d_nearest);
566
818
  }
567
819
 
568
820
  for (; level >= 0; level--) {
569
- add_links_starting_from(
821
+ add_links_starting_from_impl<C>(
822
+ hnsw,
570
823
  ptdis,
571
824
  pt_id,
572
825
  nearest,
573
826
  d_nearest,
574
827
  level,
575
- locks.data(),
828
+ locks,
576
829
  vt,
577
830
  keep_max_size_level0);
578
831
  }
579
832
 
580
- omp_unset_lock(&locks[pt_id]);
833
+ locks.unlock(pt_id);
834
+
835
+ #pragma omp critical
836
+ {
837
+ if (pt_level > hnsw.max_level) {
838
+ hnsw.max_level = pt_level;
839
+ hnsw.entry_point = pt_id;
840
+ }
841
+ }
842
+ }
843
+
844
+ } // namespace
581
845
 
582
- if (pt_level > max_level) {
583
- max_level = pt_level;
584
- entry_point = pt_id;
846
+ void HNSW::add_with_locks(
847
+ DistanceComputer& ptdis,
848
+ int pt_level,
849
+ int pt_id,
850
+ LockVector& locks,
851
+ VisitedTable& vt,
852
+ bool keep_max_size_level0) {
853
+ if (is_similarity) {
854
+ add_with_locks_impl<C_similarity>(
855
+ *this, ptdis, pt_level, pt_id, locks, vt, keep_max_size_level0);
856
+ } else {
857
+ add_with_locks_impl<C_distance>(
858
+ *this, ptdis, pt_level, pt_id, locks, vt, keep_max_size_level0);
585
859
  }
586
860
  }
587
861
 
@@ -589,12 +863,10 @@ void HNSW::add_with_locks(
589
863
  * Searching
590
864
  **************************************************************/
591
865
 
592
- using MinimaxHeap = HNSW::MinimaxHeap;
593
- using Node = HNSW::Node;
594
- using C = HNSW::C;
866
+ namespace {
595
867
 
596
868
  /** Helper to extract search parameters from HNSW and SearchParameters */
597
- static inline void extract_search_params(
869
+ inline void extract_search_params(
598
870
  const HNSW& hnsw,
599
871
  const SearchParameters* params,
600
872
  bool& do_dis_check,
@@ -614,13 +886,16 @@ static inline void extract_search_params(
614
886
  }
615
887
  }
616
888
 
617
- /** Do a BFS on the candidates list */
618
- int search_from_candidates(
889
+ /** Templated body of `search_from_candidates` instantiated once per
890
+ * VisitedTable subclass × comparator.
891
+ */
892
+ template <typename VTType, class C>
893
+ int search_from_candidates_fixVT(
619
894
  const HNSW& hnsw,
620
895
  DistanceComputer& qdis,
621
896
  ResultHandler& res,
622
- MinimaxHeap& candidates,
623
- VisitedTable& vt,
897
+ MinimaxHeapT<HC_for<C>>& candidates,
898
+ VTType& vt,
624
899
  HNSWStats& stats,
625
900
  int level,
626
901
  int nres_in,
@@ -633,13 +908,15 @@ int search_from_candidates(
633
908
  const IDSelector* sel;
634
909
  extract_search_params(hnsw, params, do_dis_check, efSearch, sel);
635
910
 
636
- C::T threshold = res.threshold;
911
+ vt.reserve(efSearch);
912
+
913
+ typename C::T threshold = res.threshold;
637
914
  for (int i = 0; i < candidates.size(); i++) {
638
915
  idx_t v1 = candidates.ids[i];
639
916
  float d = candidates.dis[i];
640
917
  FAISS_ASSERT(v1 >= 0);
641
918
  if (!sel || sel->is_member(v1)) {
642
- if (d < threshold) {
919
+ if (C::cmp(threshold, d)) {
643
920
  if (res.add_result(d, v1)) {
644
921
  threshold = res.threshold;
645
922
  }
@@ -688,7 +965,7 @@ int search_from_candidates(
688
965
 
689
966
  auto add_to_heap = [&](const size_t idx, const float dis) {
690
967
  if (!sel || sel->is_member(idx)) {
691
- if (dis < threshold) {
968
+ if (C::cmp(threshold, dis)) {
692
969
  if (res.add_result(dis, idx)) {
693
970
  threshold = res.threshold;
694
971
  nres += 1;
@@ -726,7 +1003,7 @@ int search_from_candidates(
726
1003
  }
727
1004
  }
728
1005
 
729
- for (size_t icnt = 0; icnt < counter; icnt++) {
1006
+ for (int icnt = 0; icnt < counter; icnt++) {
730
1007
  float dis = qdis(saved_j[icnt]);
731
1008
  add_to_heap(saved_j[icnt], dis);
732
1009
 
@@ -751,7 +1028,58 @@ int search_from_candidates(
751
1028
  return nres;
752
1029
  }
753
1030
 
754
- int search_from_candidates_panorama(
1031
+ /// Dispatches the VisitedTable concrete type for a given C, then calls
1032
+ /// the templated `search_from_candidates_fixVT<VTType, C>`.
1033
+ template <class C>
1034
+ int search_from_candidates_dispatch(
1035
+ const HNSW& hnsw,
1036
+ DistanceComputer& qdis,
1037
+ ResultHandler& res,
1038
+ MinimaxHeapT<HC_for<C>>& candidates,
1039
+ VisitedTable& vt,
1040
+ HNSWStats& stats,
1041
+ int level,
1042
+ int nres_in,
1043
+ const SearchParameters* params) {
1044
+ auto call = [&]<typename VTType>(VTType& vt_concrete) -> int {
1045
+ return search_from_candidates_fixVT<VTType, C>(
1046
+ hnsw,
1047
+ qdis,
1048
+ res,
1049
+ candidates,
1050
+ vt_concrete,
1051
+ stats,
1052
+ level,
1053
+ nres_in,
1054
+ params);
1055
+ };
1056
+ if (VisitedTableVector* vtv = dynamic_cast<VisitedTableVector*>(&vt)) {
1057
+ return call(*vtv);
1058
+ }
1059
+ VisitedTableSet& vts = dynamic_cast<VisitedTableSet&>(vt);
1060
+ return call(vts);
1061
+ }
1062
+
1063
+ } // namespace
1064
+
1065
+ /** Do a BFS on the candidates list. Public dispatcher: only handles the
1066
+ * distance (CMax) flavor because its `MinimaxHeap` parameter is the
1067
+ * CMax instantiation. */
1068
+ int hnsw_detail::search_from_candidates(
1069
+ const HNSW& hnsw,
1070
+ DistanceComputer& qdis,
1071
+ ResultHandler& res,
1072
+ MinimaxHeap& candidates,
1073
+ VisitedTable& vt,
1074
+ HNSWStats& stats,
1075
+ int level,
1076
+ int nres_in,
1077
+ const SearchParameters* params) {
1078
+ return search_from_candidates_dispatch<HNSW::C_distance>(
1079
+ hnsw, qdis, res, candidates, vt, stats, level, nres_in, params);
1080
+ }
1081
+
1082
+ int hnsw_detail::search_from_candidates_panorama(
755
1083
  const HNSW& hnsw,
756
1084
  const IndexHNSW* index,
757
1085
  DistanceComputer& qdis,
@@ -762,6 +1090,14 @@ int search_from_candidates_panorama(
762
1090
  int level,
763
1091
  int nres_in,
764
1092
  const SearchParameters* params) {
1093
+ // Panorama's progressive-bound math is L2-specific: refuse to run in
1094
+ // similarity mode.
1095
+ FAISS_THROW_IF_NOT_MSG(
1096
+ !hnsw.is_similarity,
1097
+ "search_from_candidates_panorama does not support is_similarity=true");
1098
+
1099
+ using C = HNSW::C_distance;
1100
+
765
1101
  int nres = nres_in;
766
1102
  int ndis = 0;
767
1103
 
@@ -776,7 +1112,7 @@ int search_from_candidates_panorama(
776
1112
  float d = candidates.dis[i];
777
1113
  FAISS_ASSERT(v1 >= 0);
778
1114
  if (!sel || sel->is_member(v1)) {
779
- if (d < threshold) {
1115
+ if (C::cmp(threshold, d)) {
780
1116
  if (res.add_result(d, v1)) {
781
1117
  threshold = res.threshold;
782
1118
  }
@@ -807,6 +1143,10 @@ int search_from_candidates_panorama(
807
1143
  float query_norm_sq = query_cum_sums[0] * query_cum_sums[0];
808
1144
 
809
1145
  int nstep = 0;
1146
+ const size_t d = static_cast<size_t>(panorama_index->d);
1147
+
1148
+ PanoramaStats local_pano_stats;
1149
+ local_pano_stats.reset();
810
1150
 
811
1151
  while (candidates.size() > 0) {
812
1152
  float d0 = 0;
@@ -845,6 +1185,7 @@ int search_from_candidates_panorama(
845
1185
  initial_size += is_selected && vt.set(v1) ? 1 : 0;
846
1186
  }
847
1187
 
1188
+ local_pano_stats.total_dims += initial_size * d;
848
1189
  size_t batch_size = initial_size;
849
1190
  size_t curr_panorama_level = 0;
850
1191
  const size_t num_panorama_levels = panorama_index->pano.n_levels;
@@ -907,28 +1248,28 @@ int search_from_candidates_panorama(
907
1248
  // the maintenance of the candidate heap), but micro-benchmarks
908
1249
  // have shown that it is not worth it to write horrible code to
909
1250
  // squeeze out those cycles.
910
- if (lower_bound_0 <= threshold) {
1251
+ if (!C::cmp(lower_bound_0, threshold)) {
911
1252
  exact_distances[next_batch_size] = new_exact_0;
912
1253
  index_array[next_batch_size] = idx_0;
913
1254
  next_batch_size += 1;
914
1255
  } else {
915
1256
  candidates.push(idx_0, new_exact_0);
916
1257
  }
917
- if (lower_bound_1 <= threshold) {
1258
+ if (!C::cmp(lower_bound_1, threshold)) {
918
1259
  exact_distances[next_batch_size] = new_exact_1;
919
1260
  index_array[next_batch_size] = idx_1;
920
1261
  next_batch_size += 1;
921
1262
  } else {
922
1263
  candidates.push(idx_1, new_exact_1);
923
1264
  }
924
- if (lower_bound_2 <= threshold) {
1265
+ if (!C::cmp(lower_bound_2, threshold)) {
925
1266
  exact_distances[next_batch_size] = new_exact_2;
926
1267
  index_array[next_batch_size] = idx_2;
927
1268
  next_batch_size += 1;
928
1269
  } else {
929
1270
  candidates.push(idx_2, new_exact_2);
930
1271
  }
931
- if (lower_bound_3 <= threshold) {
1272
+ if (!C::cmp(lower_bound_3, threshold)) {
932
1273
  exact_distances[next_batch_size] = new_exact_3;
933
1274
  index_array[next_batch_size] = idx_3;
934
1275
  next_batch_size += 1;
@@ -951,7 +1292,7 @@ int search_from_candidates_panorama(
951
1292
  float cs_bound = 2.0f * cum_sum * query_cum_norm;
952
1293
  float lower_bound = new_exact - cs_bound;
953
1294
 
954
- if (lower_bound <= threshold) {
1295
+ if (!C::cmp(lower_bound, threshold)) {
955
1296
  exact_distances[next_batch_size] = new_exact;
956
1297
  index_array[next_batch_size] = idx;
957
1298
  next_batch_size += 1;
@@ -960,6 +1301,8 @@ int search_from_candidates_panorama(
960
1301
  }
961
1302
  }
962
1303
 
1304
+ local_pano_stats.total_dims_scanned +=
1305
+ batch_size * (end_dim - start_dim);
963
1306
  batch_size = next_batch_size;
964
1307
  curr_panorama_level++;
965
1308
  }
@@ -968,6 +1311,7 @@ int search_from_candidates_panorama(
968
1311
  for (size_t i = 0; i < batch_size; i++) {
969
1312
  idx_t idx = index_array[i];
970
1313
  if (res.add_result(exact_distances[i], idx)) {
1314
+ threshold = res.threshold;
971
1315
  nres += 1;
972
1316
  }
973
1317
  candidates.push(idx, exact_distances[i]);
@@ -988,31 +1332,53 @@ int search_from_candidates_panorama(
988
1332
  stats.nhops += nstep;
989
1333
  }
990
1334
 
1335
+ indexPanorama_stats.add(local_pano_stats);
991
1336
  return nres;
992
1337
  }
993
1338
 
994
- std::priority_queue<HNSW::Node> search_from_candidate_unbounded(
1339
+ namespace {
1340
+
1341
+ template <typename T, typename Container, typename Compare>
1342
+ void reservePriorityQueue(
1343
+ std::priority_queue<T, Container, Compare>& q,
1344
+ std::size_t size) {
1345
+ struct Access : std::priority_queue<T, Container, Compare> {
1346
+ using std::priority_queue<T, Container, Compare>::c;
1347
+ };
1348
+ Access access{std::move(q)};
1349
+ access.c.reserve(size);
1350
+ q = std::move(access);
1351
+ }
1352
+
1353
+ /// Templated body of `search_from_candidate_unbounded`. The choice of
1354
+ /// max-heap vs min-heap for both `top_candidates` and `candidates` is
1355
+ /// derived from C via `TopCandidatesQueue` / `CandidatesQueue`.
1356
+ template <typename VTType, class C>
1357
+ TopCandidatesQueue<C> search_from_candidate_unbounded_fixVT(
995
1358
  const HNSW& hnsw,
996
- const Node& node,
1359
+ const HNSW::Node& node,
997
1360
  DistanceComputer& qdis,
998
1361
  int ef,
999
- VisitedTable* vt,
1362
+ VTType& vt,
1000
1363
  HNSWStats& stats) {
1001
1364
  int ndis = 0;
1002
- std::priority_queue<Node> top_candidates;
1003
- std::priority_queue<Node, std::vector<Node>, std::greater<Node>> candidates;
1365
+ TopCandidatesQueue<C> top_candidates;
1366
+ reservePriorityQueue(top_candidates, ef);
1367
+
1368
+ CandidatesQueue<C> candidates;
1369
+ reservePriorityQueue(candidates, ef);
1004
1370
 
1005
1371
  top_candidates.push(node);
1006
1372
  candidates.push(node);
1007
1373
 
1008
- vt->set(node.second);
1374
+ vt.set(node.second);
1009
1375
 
1010
1376
  while (!candidates.empty()) {
1011
1377
  float d0;
1012
1378
  storage_idx_t v0;
1013
1379
  std::tie(d0, v0) = candidates.top();
1014
1380
 
1015
- if (d0 > top_candidates.top().first) {
1381
+ if (C::cmp(d0, top_candidates.top().first)) {
1016
1382
  break;
1017
1383
  }
1018
1384
 
@@ -1030,7 +1396,7 @@ std::priority_queue<HNSW::Node> search_from_candidate_unbounded(
1030
1396
  break;
1031
1397
  }
1032
1398
 
1033
- vt->prefetch(v1);
1399
+ vt.prefetch(v1);
1034
1400
  jmax += 1;
1035
1401
  }
1036
1402
 
@@ -1038,12 +1404,12 @@ std::priority_queue<HNSW::Node> search_from_candidate_unbounded(
1038
1404
  size_t saved_j[4];
1039
1405
 
1040
1406
  auto add_to_heap = [&](const size_t idx, const float dis) {
1041
- if (top_candidates.top().first > dis ||
1042
- top_candidates.size() < ef) {
1407
+ if (C::cmp(top_candidates.top().first, dis) ||
1408
+ top_candidates.size() < static_cast<size_t>(ef)) {
1043
1409
  candidates.emplace(dis, idx);
1044
1410
  top_candidates.emplace(dis, idx);
1045
1411
 
1046
- if (top_candidates.size() > ef) {
1412
+ if (top_candidates.size() > static_cast<size_t>(ef)) {
1047
1413
  top_candidates.pop();
1048
1414
  }
1049
1415
  }
@@ -1053,7 +1419,7 @@ std::priority_queue<HNSW::Node> search_from_candidate_unbounded(
1053
1419
  int v1 = hnsw.neighbors[j];
1054
1420
 
1055
1421
  saved_j[counter] = v1;
1056
- counter += vt->set(v1) ? 1 : 0;
1422
+ counter += vt.set(v1) ? 1 : 0;
1057
1423
 
1058
1424
  if (counter == 4) {
1059
1425
  float dis[4];
@@ -1077,7 +1443,7 @@ std::priority_queue<HNSW::Node> search_from_candidate_unbounded(
1077
1443
  }
1078
1444
  }
1079
1445
 
1080
- for (size_t icnt = 0; icnt < counter; icnt++) {
1446
+ for (int icnt = 0; icnt < counter; icnt++) {
1081
1447
  float dis = qdis(saved_j[icnt]);
1082
1448
  add_to_heap(saved_j[icnt], dis);
1083
1449
 
@@ -1096,159 +1462,127 @@ std::priority_queue<HNSW::Node> search_from_candidate_unbounded(
1096
1462
  return top_candidates;
1097
1463
  }
1098
1464
 
1099
- /// greedily update a nearest vector at a given level
1100
- HNSWStats greedy_update_nearest(
1465
+ } // namespace
1466
+
1467
+ /// Public dispatcher: only the distance (CMax) flavor is exposed because
1468
+ /// its return type — `std::priority_queue<HNSW::Node>` — is the CMax
1469
+ /// max-heap. Internal callers that need similarity mode use the same
1470
+ /// dispatch pattern inline.
1471
+ std::priority_queue<HNSW::Node> hnsw_detail::search_from_candidate_unbounded(
1101
1472
  const HNSW& hnsw,
1473
+ const HNSW::Node& node,
1102
1474
  DistanceComputer& qdis,
1103
- int level,
1104
- storage_idx_t& nearest,
1105
- float& d_nearest) {
1106
- HNSWStats stats;
1107
-
1108
- for (;;) {
1109
- storage_idx_t prev_nearest = nearest;
1110
-
1111
- size_t begin, end;
1112
- hnsw.neighbor_range(nearest, level, &begin, &end);
1113
-
1114
- size_t ndis = 0;
1115
-
1116
- // a faster version: reference version in unit test test_hnsw.cpp
1117
- // the following version processes 4 neighbors at a time
1118
- auto update_with_candidate = [&](const storage_idx_t idx,
1119
- const float dis) {
1120
- if (dis < d_nearest) {
1121
- nearest = idx;
1122
- d_nearest = dis;
1123
- }
1124
- };
1125
-
1126
- int n_buffered = 0;
1127
- storage_idx_t buffered_ids[4];
1128
-
1129
- for (size_t j = begin; j < end; j++) {
1130
- storage_idx_t v = hnsw.neighbors[j];
1131
- if (v < 0) {
1132
- break;
1133
- }
1134
- ndis += 1;
1135
-
1136
- buffered_ids[n_buffered] = v;
1137
- n_buffered += 1;
1138
-
1139
- if (n_buffered == 4) {
1140
- float dis[4];
1141
- qdis.distances_batch_4(
1142
- buffered_ids[0],
1143
- buffered_ids[1],
1144
- buffered_ids[2],
1145
- buffered_ids[3],
1146
- dis[0],
1147
- dis[1],
1148
- dis[2],
1149
- dis[3]);
1150
-
1151
- for (size_t id4 = 0; id4 < 4; id4++) {
1152
- update_with_candidate(buffered_ids[id4], dis[id4]);
1153
- }
1154
-
1155
- n_buffered = 0;
1156
- }
1157
- }
1158
-
1159
- // process leftovers
1160
- for (size_t icnt = 0; icnt < n_buffered; icnt++) {
1161
- float dis = qdis(buffered_ids[icnt]);
1162
- update_with_candidate(buffered_ids[icnt], dis);
1163
- }
1164
-
1165
- // update stats
1166
- stats.ndis += ndis;
1167
- stats.nhops += 1;
1168
-
1169
- if (nearest == prev_nearest) {
1170
- return stats;
1171
- }
1475
+ int ef,
1476
+ VisitedTable* vt,
1477
+ HNSWStats& stats) {
1478
+ using C = HNSW::C_distance;
1479
+ auto call = [&]<typename VTType>(VTType& vt_concrete) {
1480
+ return search_from_candidate_unbounded_fixVT<VTType, C>(
1481
+ hnsw, node, qdis, ef, vt_concrete, stats);
1482
+ };
1483
+ if (VisitedTableVector* vtv = dynamic_cast<VisitedTableVector*>(vt)) {
1484
+ return call(*vtv);
1172
1485
  }
1486
+ VisitedTableSet& vts = dynamic_cast<VisitedTableSet&>(*vt);
1487
+ return call(vts);
1173
1488
  }
1174
1489
 
1175
1490
  namespace {
1176
- using MinimaxHeap = HNSW::MinimaxHeap;
1177
- using Node = HNSW::Node;
1178
- using C = HNSW::C;
1179
1491
 
1180
1492
  // just used as a lower bound for the minmaxheap, but it is set for heap search
1493
+ template <class C>
1181
1494
  int extract_k_from_ResultHandler(ResultHandler& res) {
1182
1495
  using RH = HeapBlockResultHandler<C>;
1183
- if (auto hres = dynamic_cast<RH::SingleResultHandler*>(&res)) {
1496
+ if (auto hres = dynamic_cast<typename RH::SingleResultHandler*>(&res)) {
1184
1497
  return hres->k;
1185
1498
  }
1186
1499
  return 1;
1187
1500
  }
1188
1501
 
1189
- } // namespace
1190
-
1191
- HNSWStats HNSW::search(
1502
+ template <class C>
1503
+ HNSWStats search_impl(
1504
+ const HNSW& hnsw,
1192
1505
  DistanceComputer& qdis,
1193
1506
  const IndexHNSW* index,
1194
1507
  ResultHandler& res,
1195
1508
  VisitedTable& vt,
1196
- const SearchParameters* params) const {
1509
+ const SearchParameters* params) {
1197
1510
  HNSWStats stats;
1198
- if (entry_point == -1) {
1511
+ if (hnsw.entry_point == -1) {
1199
1512
  return stats;
1200
1513
  }
1201
- int k = extract_k_from_ResultHandler(res);
1514
+ int k = extract_k_from_ResultHandler<C>(res);
1202
1515
 
1203
- bool bounded_queue = this->search_bounded_queue;
1204
- int efSearch = this->efSearch;
1516
+ bool bounded_queue = hnsw.search_bounded_queue;
1517
+ int cur_efSearch = hnsw.efSearch;
1205
1518
  if (params) {
1206
1519
  if (const SearchParametersHNSW* hnsw_params =
1207
1520
  dynamic_cast<const SearchParametersHNSW*>(params)) {
1208
1521
  bounded_queue = hnsw_params->bounded_queue;
1209
- efSearch = hnsw_params->efSearch;
1522
+ cur_efSearch = hnsw_params->efSearch;
1210
1523
  }
1211
1524
  }
1212
1525
 
1213
1526
  // greedy search on upper levels
1214
- storage_idx_t nearest = entry_point;
1527
+ storage_idx_t nearest = hnsw.entry_point;
1215
1528
  float d_nearest = qdis(nearest);
1216
1529
 
1217
- for (int level = max_level; level >= 1; level--) {
1218
- HNSWStats local_stats =
1219
- greedy_update_nearest(*this, qdis, level, nearest, d_nearest);
1530
+ for (int level = hnsw.max_level; level >= 1; level--) {
1531
+ HNSWStats local_stats = greedy_update_nearest_impl<C>(
1532
+ hnsw, qdis, level, nearest, d_nearest);
1220
1533
  stats.combine(local_stats);
1221
1534
  }
1222
1535
 
1223
- int ef = std::max(efSearch, k);
1536
+ int ef = std::max(cur_efSearch, k);
1224
1537
  if (bounded_queue) { // this is the most common branch, for now we only
1225
1538
  // support Panorama search in this branch
1226
- MinimaxHeap candidates(ef);
1539
+ MinimaxHeapT<HC_for<C>> candidates(ef);
1227
1540
 
1228
1541
  candidates.push(nearest, d_nearest);
1229
1542
 
1230
- if (!is_panorama) {
1231
- search_from_candidates(
1232
- *this, qdis, res, candidates, vt, stats, 0, 0, params);
1543
+ if (!hnsw.is_panorama) {
1544
+ search_from_candidates_dispatch<C>(
1545
+ hnsw, qdis, res, candidates, vt, stats, 0, 0, params);
1233
1546
  } else {
1234
- search_from_candidates_panorama(
1235
- *this,
1236
- index,
1237
- qdis,
1238
- res,
1239
- candidates,
1240
- vt,
1241
- stats,
1242
- 0,
1243
- 0,
1244
- params);
1547
+ // Panorama is L2-specific and is only valid for C_distance.
1548
+ // The public dispatch ensures we never reach this code path
1549
+ // with C != C_distance, but assert in debug builds.
1550
+ if constexpr (std::is_same_v<C, HNSW::C_distance>) {
1551
+ hnsw_detail::search_from_candidates_panorama(
1552
+ hnsw,
1553
+ index,
1554
+ qdis,
1555
+ res,
1556
+ candidates,
1557
+ vt,
1558
+ stats,
1559
+ 0,
1560
+ 0,
1561
+ params);
1562
+ } else {
1563
+ FAISS_THROW_MSG(
1564
+ "Panorama search is not supported with is_similarity=true");
1565
+ }
1245
1566
  }
1246
1567
  } else {
1247
- std::priority_queue<Node> top_candidates =
1248
- search_from_candidate_unbounded(
1249
- *this, Node(d_nearest, nearest), qdis, ef, &vt, stats);
1568
+ auto call = [&]<typename VTType>(VTType& vt_concrete) {
1569
+ return search_from_candidate_unbounded_fixVT<VTType, C>(
1570
+ hnsw,
1571
+ HNSW::Node(d_nearest, nearest),
1572
+ qdis,
1573
+ ef,
1574
+ vt_concrete,
1575
+ stats);
1576
+ };
1577
+ TopCandidatesQueue<C> top_candidates;
1578
+ if (VisitedTableVector* vtv = dynamic_cast<VisitedTableVector*>(&vt)) {
1579
+ top_candidates = call(*vtv);
1580
+ } else {
1581
+ VisitedTableSet& vts = dynamic_cast<VisitedTableSet&>(vt);
1582
+ top_candidates = call(vts);
1583
+ }
1250
1584
 
1251
- while (top_candidates.size() > k) {
1585
+ while (top_candidates.size() > static_cast<size_t>(k)) {
1252
1586
  top_candidates.pop();
1253
1587
  }
1254
1588
 
@@ -1266,7 +1600,9 @@ HNSWStats HNSW::search(
1266
1600
  return stats;
1267
1601
  }
1268
1602
 
1269
- void HNSW::search_level_0(
1603
+ template <class C>
1604
+ void search_level_0_impl(
1605
+ const HNSW& hnsw,
1270
1606
  DistanceComputer& qdis,
1271
1607
  ResultHandler& res,
1272
1608
  idx_t nprobe,
@@ -1275,23 +1611,21 @@ void HNSW::search_level_0(
1275
1611
  int search_type,
1276
1612
  HNSWStats& search_stats,
1277
1613
  VisitedTable& vt,
1278
- const SearchParameters* params) const {
1279
- const HNSW& hnsw = *this;
1280
-
1281
- auto efSearch = hnsw.efSearch;
1614
+ const SearchParameters* params) {
1615
+ auto cur_efSearch = hnsw.efSearch;
1282
1616
  if (params) {
1283
1617
  if (const SearchParametersHNSW* hnsw_params =
1284
1618
  dynamic_cast<const SearchParametersHNSW*>(params)) {
1285
- efSearch = hnsw_params->efSearch;
1619
+ cur_efSearch = hnsw_params->efSearch;
1286
1620
  }
1287
1621
  }
1288
1622
 
1289
- int k = extract_k_from_ResultHandler(res);
1623
+ int k = extract_k_from_ResultHandler<C>(res);
1290
1624
 
1291
1625
  if (search_type == 1) {
1292
1626
  int nres = 0;
1293
1627
 
1294
- for (int j = 0; j < nprobe; j++) {
1628
+ for (idx_t j = 0; j < nprobe; j++) {
1295
1629
  storage_idx_t cj = nearest_i[j];
1296
1630
 
1297
1631
  if (cj < 0) {
@@ -1302,12 +1636,12 @@ void HNSW::search_level_0(
1302
1636
  continue;
1303
1637
  }
1304
1638
 
1305
- int candidates_size = std::max(efSearch, k);
1306
- MinimaxHeap candidates(candidates_size);
1639
+ int candidates_size = std::max(cur_efSearch, k);
1640
+ MinimaxHeapT<HC_for<C>> candidates(candidates_size);
1307
1641
 
1308
1642
  candidates.push(cj, nearest_d[j]);
1309
1643
 
1310
- nres = search_from_candidates(
1644
+ nres = search_from_candidates_dispatch<C>(
1311
1645
  hnsw,
1312
1646
  qdis,
1313
1647
  res,
@@ -1320,11 +1654,11 @@ void HNSW::search_level_0(
1320
1654
  nres = std::min(nres, candidates_size);
1321
1655
  }
1322
1656
  } else if (search_type == 2) {
1323
- int candidates_size = std::max(efSearch, int(k));
1657
+ int candidates_size = std::max(cur_efSearch, int(k));
1324
1658
  candidates_size = std::max(candidates_size, int(nprobe));
1325
1659
 
1326
- MinimaxHeap candidates(candidates_size);
1327
- for (int j = 0; j < nprobe; j++) {
1660
+ MinimaxHeapT<HC_for<C>> candidates(candidates_size);
1661
+ for (idx_t j = 0; j < nprobe; j++) {
1328
1662
  storage_idx_t cj = nearest_i[j];
1329
1663
 
1330
1664
  if (cj < 0) {
@@ -1333,11 +1667,62 @@ void HNSW::search_level_0(
1333
1667
  candidates.push(cj, nearest_d[j]);
1334
1668
  }
1335
1669
 
1336
- search_from_candidates(
1670
+ search_from_candidates_dispatch<C>(
1337
1671
  hnsw, qdis, res, candidates, vt, search_stats, 0, 0, params);
1338
1672
  }
1339
1673
  }
1340
1674
 
1675
+ } // namespace
1676
+
1677
+ HNSWStats HNSW::search(
1678
+ DistanceComputer& qdis,
1679
+ const IndexHNSW* index,
1680
+ ResultHandler& res,
1681
+ VisitedTable& vt,
1682
+ const SearchParameters* params) const {
1683
+ if (is_similarity) {
1684
+ return search_impl<C_similarity>(*this, qdis, index, res, vt, params);
1685
+ }
1686
+ return search_impl<C_distance>(*this, qdis, index, res, vt, params);
1687
+ }
1688
+
1689
+ void HNSW::search_level_0(
1690
+ DistanceComputer& qdis,
1691
+ ResultHandler& res,
1692
+ idx_t nprobe,
1693
+ const storage_idx_t* nearest_i,
1694
+ const float* nearest_d,
1695
+ int search_type,
1696
+ HNSWStats& search_stats,
1697
+ VisitedTable& vt,
1698
+ const SearchParameters* params) const {
1699
+ if (is_similarity) {
1700
+ search_level_0_impl<C_similarity>(
1701
+ *this,
1702
+ qdis,
1703
+ res,
1704
+ nprobe,
1705
+ nearest_i,
1706
+ nearest_d,
1707
+ search_type,
1708
+ search_stats,
1709
+ vt,
1710
+ params);
1711
+ } else {
1712
+ search_level_0_impl<C_distance>(
1713
+ *this,
1714
+ qdis,
1715
+ res,
1716
+ nprobe,
1717
+ nearest_i,
1718
+ nearest_d,
1719
+ search_type,
1720
+ search_stats,
1721
+ vt,
1722
+ params);
1723
+ }
1724
+ }
1725
+
1341
1726
  void HNSW::permute_entries(const idx_t* map) {
1342
1727
  // remap levels
1343
1728
  storage_idx_t ntotal = levels.size();
@@ -1371,257 +1756,4 @@ void HNSW::permute_entries(const idx_t* map) {
1371
1756
  neighbors = std::move(new_neighbors);
1372
1757
  }
1373
1758
 
1374
- /**************************************************************
1375
- * MinimaxHeap
1376
- **************************************************************/
1377
-
1378
- void HNSW::MinimaxHeap::push(storage_idx_t i, float v) {
1379
- if (k == n) {
1380
- if (v >= dis[0]) {
1381
- return;
1382
- }
1383
- if (ids[0] != -1) {
1384
- --nvalid;
1385
- }
1386
- faiss::heap_pop<HC>(k--, dis.data(), ids.data());
1387
- }
1388
- faiss::heap_push<HC>(++k, dis.data(), ids.data(), v, i);
1389
- ++nvalid;
1390
- }
1391
-
1392
- float HNSW::MinimaxHeap::max() const {
1393
- return dis[0];
1394
- }
1395
-
1396
- int HNSW::MinimaxHeap::size() const {
1397
- return nvalid;
1398
- }
1399
-
1400
- void HNSW::MinimaxHeap::clear() {
1401
- nvalid = k = 0;
1402
- }
1403
-
1404
- #ifdef __AVX512F__
1405
-
1406
- int HNSW::MinimaxHeap::pop_min(float* vmin_out) {
1407
- assert(k > 0);
1408
- static_assert(
1409
- std::is_same<storage_idx_t, int32_t>::value,
1410
- "This code expects storage_idx_t to be int32_t");
1411
-
1412
- int32_t min_idx = -1;
1413
- float min_dis = std::numeric_limits<float>::infinity();
1414
-
1415
- __m512i min_indices = _mm512_set1_epi32(-1);
1416
- __m512 min_distances =
1417
- _mm512_set1_ps(std::numeric_limits<float>::infinity());
1418
- __m512i current_indices = _mm512_setr_epi32(
1419
- 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
1420
- __m512i offset = _mm512_set1_epi32(16);
1421
-
1422
- // The following loop tracks the rightmost index with the min distance.
1423
- // -1 index values are ignored.
1424
- const int k16 = (k / 16) * 16;
1425
- for (size_t iii = 0; iii < k16; iii += 16) {
1426
- __m512i indices =
1427
- _mm512_loadu_si512((const __m512i*)(ids.data() + iii));
1428
- __m512 distances = _mm512_loadu_ps(dis.data() + iii);
1429
-
1430
- // This mask filters out -1 values among indices.
1431
- __mmask16 m1mask =
1432
- _mm512_cmpgt_epi32_mask(_mm512_setzero_si512(), indices);
1433
-
1434
- __mmask16 dmask =
1435
- _mm512_cmp_ps_mask(min_distances, distances, _CMP_LT_OS);
1436
- __mmask16 finalmask = m1mask | dmask;
1437
-
1438
- const __m512i min_indices_new = _mm512_mask_blend_epi32(
1439
- finalmask, current_indices, min_indices);
1440
- const __m512 min_distances_new =
1441
- _mm512_mask_blend_ps(finalmask, distances, min_distances);
1442
-
1443
- min_indices = min_indices_new;
1444
- min_distances = min_distances_new;
1445
-
1446
- current_indices = _mm512_add_epi32(current_indices, offset);
1447
- }
1448
-
1449
- // leftovers
1450
- if (k16 != k) {
1451
- const __mmask16 kmask = (1 << (k - k16)) - 1;
1452
-
1453
- __m512i indices = _mm512_mask_loadu_epi32(
1454
- _mm512_set1_epi32(-1), kmask, ids.data() + k16);
1455
- __m512 distances = _mm512_maskz_loadu_ps(kmask, dis.data() + k16);
1456
-
1457
- // This mask filters out -1 values among indices.
1458
- __mmask16 m1mask =
1459
- _mm512_cmpgt_epi32_mask(_mm512_setzero_si512(), indices);
1460
-
1461
- __mmask16 dmask =
1462
- _mm512_cmp_ps_mask(min_distances, distances, _CMP_LT_OS);
1463
- __mmask16 finalmask = m1mask | dmask;
1464
-
1465
- const __m512i min_indices_new = _mm512_mask_blend_epi32(
1466
- finalmask, current_indices, min_indices);
1467
- const __m512 min_distances_new =
1468
- _mm512_mask_blend_ps(finalmask, distances, min_distances);
1469
-
1470
- min_indices = min_indices_new;
1471
- min_distances = min_distances_new;
1472
- }
1473
-
1474
- // grab min distance
1475
- min_dis = _mm512_reduce_min_ps(min_distances);
1476
- // blend
1477
- __mmask16 mindmask =
1478
- _mm512_cmpeq_ps_mask(min_distances, _mm512_set1_ps(min_dis));
1479
- // pick the max one
1480
- min_idx = _mm512_mask_reduce_max_epi32(mindmask, min_indices);
1481
-
1482
- if (min_idx == -1) {
1483
- return -1;
1484
- }
1485
-
1486
- if (vmin_out) {
1487
- *vmin_out = min_dis;
1488
- }
1489
- int ret = ids[min_idx];
1490
- ids[min_idx] = -1;
1491
- --nvalid;
1492
- return ret;
1493
- }
1494
-
1495
- #elif __AVX2__
1496
-
1497
- int HNSW::MinimaxHeap::pop_min(float* vmin_out) {
1498
- assert(k > 0);
1499
- static_assert(
1500
- std::is_same<storage_idx_t, int32_t>::value,
1501
- "This code expects storage_idx_t to be int32_t");
1502
-
1503
- int32_t min_idx = -1;
1504
- float min_dis = std::numeric_limits<float>::infinity();
1505
-
1506
- size_t iii = 0;
1507
-
1508
- __m256i min_indices = _mm256_setr_epi32(-1, -1, -1, -1, -1, -1, -1, -1);
1509
- __m256 min_distances =
1510
- _mm256_set1_ps(std::numeric_limits<float>::infinity());
1511
- __m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
1512
- __m256i offset = _mm256_set1_epi32(8);
1513
-
1514
- // The baseline version is available in non-AVX2 branch.
1515
-
1516
- // The following loop tracks the rightmost index with the min distance.
1517
- // -1 index values are ignored.
1518
- const int k8 = (k / 8) * 8;
1519
- for (; iii < k8; iii += 8) {
1520
- __m256i indices =
1521
- _mm256_loadu_si256((const __m256i*)(ids.data() + iii));
1522
- __m256 distances = _mm256_loadu_ps(dis.data() + iii);
1523
-
1524
- // This mask filters out -1 values among indices.
1525
- __m256i m1mask = _mm256_cmpgt_epi32(_mm256_setzero_si256(), indices);
1526
-
1527
- __m256i dmask = _mm256_castps_si256(
1528
- _mm256_cmp_ps(min_distances, distances, _CMP_LT_OS));
1529
- __m256 finalmask = _mm256_castsi256_ps(_mm256_or_si256(m1mask, dmask));
1530
-
1531
- const __m256i min_indices_new = _mm256_castps_si256(_mm256_blendv_ps(
1532
- _mm256_castsi256_ps(current_indices),
1533
- _mm256_castsi256_ps(min_indices),
1534
- finalmask));
1535
-
1536
- const __m256 min_distances_new =
1537
- _mm256_blendv_ps(distances, min_distances, finalmask);
1538
-
1539
- min_indices = min_indices_new;
1540
- min_distances = min_distances_new;
1541
-
1542
- current_indices = _mm256_add_epi32(current_indices, offset);
1543
- }
1544
-
1545
- // Vectorizing is doable, but is not practical
1546
- int32_t vidx8[8];
1547
- float vdis8[8];
1548
- _mm256_storeu_ps(vdis8, min_distances);
1549
- _mm256_storeu_si256((__m256i*)vidx8, min_indices);
1550
-
1551
- for (size_t j = 0; j < 8; j++) {
1552
- if (min_dis > vdis8[j] || (min_dis == vdis8[j] && min_idx < vidx8[j])) {
1553
- min_idx = vidx8[j];
1554
- min_dis = vdis8[j];
1555
- }
1556
- }
1557
-
1558
- // process last values. Vectorizing is doable, but is not practical
1559
- for (; iii < k; iii++) {
1560
- if (ids[iii] != -1 && dis[iii] <= min_dis) {
1561
- min_dis = dis[iii];
1562
- min_idx = iii;
1563
- }
1564
- }
1565
-
1566
- if (min_idx == -1) {
1567
- return -1;
1568
- }
1569
-
1570
- if (vmin_out) {
1571
- *vmin_out = min_dis;
1572
- }
1573
- int ret = ids[min_idx];
1574
- ids[min_idx] = -1;
1575
- --nvalid;
1576
- return ret;
1577
- }
1578
-
1579
- #else
1580
-
1581
- // baseline non-vectorized version
1582
- int HNSW::MinimaxHeap::pop_min(float* vmin_out) {
1583
- assert(k > 0);
1584
- // returns min. This is an O(n) operation
1585
- int i = k - 1;
1586
- while (i >= 0) {
1587
- if (ids[i] != -1) {
1588
- break;
1589
- }
1590
- i--;
1591
- }
1592
- if (i == -1) {
1593
- return -1;
1594
- }
1595
- int imin = i;
1596
- float vmin = dis[i];
1597
- i--;
1598
- while (i >= 0) {
1599
- if (ids[i] != -1 && dis[i] < vmin) {
1600
- vmin = dis[i];
1601
- imin = i;
1602
- }
1603
- i--;
1604
- }
1605
- if (vmin_out) {
1606
- *vmin_out = vmin;
1607
- }
1608
- int ret = ids[imin];
1609
- ids[imin] = -1;
1610
- --nvalid;
1611
-
1612
- return ret;
1613
- }
1614
- #endif
1615
-
1616
- int HNSW::MinimaxHeap::count_below(float thresh) {
1617
- int n_below = 0;
1618
- for (int i = 0; i < k; i++) {
1619
- if (dis[i] < thresh) {
1620
- n_below++;
1621
- }
1622
- }
1623
-
1624
- return n_below;
1625
- }
1626
-
1627
1759
  } // namespace faiss