faiss 0.3.1 → 0.3.3

Sign up to get free protection for your applications and to get access to all the features.
Files changed (293) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +8 -0
  3. data/LICENSE.txt +1 -1
  4. data/lib/faiss/version.rb +1 -1
  5. data/vendor/faiss/faiss/AutoTune.cpp +2 -2
  6. data/vendor/faiss/faiss/AutoTune.h +3 -3
  7. data/vendor/faiss/faiss/Clustering.cpp +37 -6
  8. data/vendor/faiss/faiss/Clustering.h +12 -3
  9. data/vendor/faiss/faiss/IVFlib.cpp +6 -3
  10. data/vendor/faiss/faiss/IVFlib.h +2 -2
  11. data/vendor/faiss/faiss/Index.cpp +6 -2
  12. data/vendor/faiss/faiss/Index.h +30 -8
  13. data/vendor/faiss/faiss/Index2Layer.cpp +2 -2
  14. data/vendor/faiss/faiss/Index2Layer.h +2 -2
  15. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +7 -7
  16. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +2 -2
  17. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +14 -16
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +2 -2
  19. data/vendor/faiss/faiss/IndexBinary.cpp +13 -2
  20. data/vendor/faiss/faiss/IndexBinary.h +8 -2
  21. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +2 -3
  22. data/vendor/faiss/faiss/IndexBinaryFlat.h +2 -2
  23. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -2
  24. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -2
  25. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +2 -7
  26. data/vendor/faiss/faiss/IndexBinaryHNSW.h +3 -3
  27. data/vendor/faiss/faiss/IndexBinaryHash.cpp +2 -3
  28. data/vendor/faiss/faiss/IndexBinaryHash.h +2 -2
  29. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +3 -3
  30. data/vendor/faiss/faiss/IndexBinaryIVF.h +2 -2
  31. data/vendor/faiss/faiss/IndexFastScan.cpp +32 -18
  32. data/vendor/faiss/faiss/IndexFastScan.h +11 -2
  33. data/vendor/faiss/faiss/IndexFlat.cpp +13 -10
  34. data/vendor/faiss/faiss/IndexFlat.h +2 -2
  35. data/vendor/faiss/faiss/IndexFlatCodes.cpp +170 -7
  36. data/vendor/faiss/faiss/IndexFlatCodes.h +25 -5
  37. data/vendor/faiss/faiss/IndexHNSW.cpp +156 -96
  38. data/vendor/faiss/faiss/IndexHNSW.h +54 -5
  39. data/vendor/faiss/faiss/IndexIDMap.cpp +19 -3
  40. data/vendor/faiss/faiss/IndexIDMap.h +5 -2
  41. data/vendor/faiss/faiss/IndexIVF.cpp +5 -6
  42. data/vendor/faiss/faiss/IndexIVF.h +13 -4
  43. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +21 -7
  44. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +5 -2
  45. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +3 -14
  46. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +2 -4
  47. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +201 -91
  48. data/vendor/faiss/faiss/IndexIVFFastScan.h +33 -9
  49. data/vendor/faiss/faiss/IndexIVFFlat.cpp +2 -2
  50. data/vendor/faiss/faiss/IndexIVFFlat.h +2 -2
  51. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +2 -2
  52. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +2 -2
  53. data/vendor/faiss/faiss/IndexIVFPQ.cpp +3 -6
  54. data/vendor/faiss/faiss/IndexIVFPQ.h +2 -2
  55. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +7 -14
  56. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +2 -4
  57. data/vendor/faiss/faiss/IndexIVFPQR.cpp +2 -2
  58. data/vendor/faiss/faiss/IndexIVFPQR.h +2 -2
  59. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +2 -3
  60. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +2 -2
  61. data/vendor/faiss/faiss/IndexLSH.cpp +2 -3
  62. data/vendor/faiss/faiss/IndexLSH.h +2 -2
  63. data/vendor/faiss/faiss/IndexLattice.cpp +3 -21
  64. data/vendor/faiss/faiss/IndexLattice.h +5 -24
  65. data/vendor/faiss/faiss/IndexNNDescent.cpp +2 -31
  66. data/vendor/faiss/faiss/IndexNNDescent.h +3 -3
  67. data/vendor/faiss/faiss/IndexNSG.cpp +2 -5
  68. data/vendor/faiss/faiss/IndexNSG.h +3 -3
  69. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +56 -0
  70. data/vendor/faiss/faiss/IndexNeuralNetCodec.h +49 -0
  71. data/vendor/faiss/faiss/IndexPQ.cpp +26 -26
  72. data/vendor/faiss/faiss/IndexPQ.h +2 -2
  73. data/vendor/faiss/faiss/IndexPQFastScan.cpp +2 -5
  74. data/vendor/faiss/faiss/IndexPQFastScan.h +2 -11
  75. data/vendor/faiss/faiss/IndexPreTransform.cpp +2 -2
  76. data/vendor/faiss/faiss/IndexPreTransform.h +3 -3
  77. data/vendor/faiss/faiss/IndexRefine.cpp +46 -9
  78. data/vendor/faiss/faiss/IndexRefine.h +9 -2
  79. data/vendor/faiss/faiss/IndexReplicas.cpp +2 -2
  80. data/vendor/faiss/faiss/IndexReplicas.h +2 -2
  81. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +2 -2
  82. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +2 -2
  83. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +5 -4
  84. data/vendor/faiss/faiss/IndexScalarQuantizer.h +2 -2
  85. data/vendor/faiss/faiss/IndexShards.cpp +2 -2
  86. data/vendor/faiss/faiss/IndexShards.h +2 -2
  87. data/vendor/faiss/faiss/IndexShardsIVF.cpp +2 -2
  88. data/vendor/faiss/faiss/IndexShardsIVF.h +2 -2
  89. data/vendor/faiss/faiss/MatrixStats.cpp +2 -2
  90. data/vendor/faiss/faiss/MatrixStats.h +2 -2
  91. data/vendor/faiss/faiss/MetaIndexes.cpp +2 -3
  92. data/vendor/faiss/faiss/MetaIndexes.h +2 -2
  93. data/vendor/faiss/faiss/MetricType.h +9 -4
  94. data/vendor/faiss/faiss/VectorTransform.cpp +2 -2
  95. data/vendor/faiss/faiss/VectorTransform.h +2 -2
  96. data/vendor/faiss/faiss/clone_index.cpp +2 -2
  97. data/vendor/faiss/faiss/clone_index.h +2 -2
  98. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +2 -2
  99. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +2 -2
  100. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +97 -19
  101. data/vendor/faiss/faiss/cppcontrib/factory_tools.cpp +192 -0
  102. data/vendor/faiss/faiss/cppcontrib/factory_tools.h +29 -0
  103. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2 -2
  104. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +85 -32
  105. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2 -2
  106. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +2 -2
  107. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +2 -2
  108. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +2 -2
  109. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +2 -2
  110. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +2 -2
  111. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +2 -5
  112. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +2 -2
  113. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +45 -13
  114. data/vendor/faiss/faiss/gpu/GpuCloner.h +2 -2
  115. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +12 -6
  116. data/vendor/faiss/faiss/gpu/GpuDistance.h +11 -7
  117. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +3 -3
  118. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +2 -2
  119. data/vendor/faiss/faiss/gpu/GpuIndex.h +10 -15
  120. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -2
  121. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +285 -0
  122. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +2 -2
  123. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +8 -2
  124. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +4 -2
  125. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +3 -3
  126. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +2 -2
  127. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +2 -2
  128. data/vendor/faiss/faiss/gpu/GpuResources.cpp +7 -2
  129. data/vendor/faiss/faiss/gpu/GpuResources.h +11 -4
  130. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +66 -11
  131. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +15 -5
  132. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +2 -2
  133. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +28 -23
  134. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +2 -2
  135. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +2 -2
  136. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +2 -2
  137. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +2 -2
  138. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +2 -2
  139. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +8 -2
  140. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +2 -3
  141. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +2 -2
  142. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +10 -7
  143. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +2 -2
  144. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +54 -54
  145. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +144 -77
  146. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +51 -51
  147. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +2 -2
  148. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +3 -3
  149. data/vendor/faiss/faiss/gpu/test/TestGpuResidualQuantizer.cpp +70 -0
  150. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +74 -4
  151. data/vendor/faiss/faiss/gpu/test/TestUtils.h +2 -2
  152. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +3 -3
  153. data/vendor/faiss/faiss/gpu/utils/{RaftUtils.h → CuvsUtils.h} +12 -11
  154. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +8 -2
  155. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +2 -2
  156. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +2 -2
  157. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +2 -2
  158. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +6 -3
  159. data/vendor/faiss/faiss/gpu/utils/Timer.h +3 -3
  160. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +79 -11
  161. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +17 -5
  162. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +27 -2
  163. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +11 -3
  164. data/vendor/faiss/faiss/impl/CodePacker.cpp +2 -2
  165. data/vendor/faiss/faiss/impl/CodePacker.h +2 -2
  166. data/vendor/faiss/faiss/impl/DistanceComputer.h +48 -2
  167. data/vendor/faiss/faiss/impl/FaissAssert.h +6 -4
  168. data/vendor/faiss/faiss/impl/FaissException.cpp +2 -2
  169. data/vendor/faiss/faiss/impl/FaissException.h +2 -3
  170. data/vendor/faiss/faiss/impl/HNSW.cpp +378 -205
  171. data/vendor/faiss/faiss/impl/HNSW.h +55 -24
  172. data/vendor/faiss/faiss/impl/IDSelector.cpp +2 -2
  173. data/vendor/faiss/faiss/impl/IDSelector.h +2 -2
  174. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +10 -10
  175. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +2 -2
  176. data/vendor/faiss/faiss/impl/LookupTableScaler.h +36 -2
  177. data/vendor/faiss/faiss/impl/NNDescent.cpp +15 -10
  178. data/vendor/faiss/faiss/impl/NNDescent.h +2 -2
  179. data/vendor/faiss/faiss/impl/NSG.cpp +26 -49
  180. data/vendor/faiss/faiss/impl/NSG.h +20 -8
  181. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +2 -2
  182. data/vendor/faiss/faiss/impl/PolysemousTraining.h +2 -2
  183. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +2 -4
  184. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +2 -2
  185. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +2 -2
  186. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +3 -2
  187. data/vendor/faiss/faiss/impl/ProductQuantizer.h +7 -3
  188. data/vendor/faiss/faiss/impl/Quantizer.h +2 -2
  189. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +2 -36
  190. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +3 -13
  191. data/vendor/faiss/faiss/impl/ResultHandler.h +153 -34
  192. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +721 -104
  193. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +5 -2
  194. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +2 -2
  195. data/vendor/faiss/faiss/impl/ThreadedIndex.h +2 -2
  196. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +7 -2
  197. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx512.h +248 -0
  198. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +2 -2
  199. data/vendor/faiss/faiss/impl/code_distance/code_distance-sve.h +440 -0
  200. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +55 -2
  201. data/vendor/faiss/faiss/impl/index_read.cpp +31 -20
  202. data/vendor/faiss/faiss/impl/index_read_utils.h +37 -0
  203. data/vendor/faiss/faiss/impl/index_write.cpp +30 -16
  204. data/vendor/faiss/faiss/impl/io.cpp +15 -7
  205. data/vendor/faiss/faiss/impl/io.h +6 -6
  206. data/vendor/faiss/faiss/impl/io_macros.h +8 -9
  207. data/vendor/faiss/faiss/impl/kmeans1d.cpp +2 -3
  208. data/vendor/faiss/faiss/impl/kmeans1d.h +2 -2
  209. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +2 -3
  210. data/vendor/faiss/faiss/impl/lattice_Zn.h +2 -2
  211. data/vendor/faiss/faiss/impl/platform_macros.h +34 -2
  212. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +13 -2
  213. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +20 -2
  214. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +3 -3
  215. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +450 -3
  216. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +8 -8
  217. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +3 -3
  218. data/vendor/faiss/faiss/impl/simd_result_handlers.h +151 -67
  219. data/vendor/faiss/faiss/index_factory.cpp +51 -34
  220. data/vendor/faiss/faiss/index_factory.h +2 -2
  221. data/vendor/faiss/faiss/index_io.h +14 -7
  222. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +30 -10
  223. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +5 -2
  224. data/vendor/faiss/faiss/invlists/DirectMap.cpp +11 -3
  225. data/vendor/faiss/faiss/invlists/DirectMap.h +2 -2
  226. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +57 -19
  227. data/vendor/faiss/faiss/invlists/InvertedLists.h +20 -11
  228. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +2 -2
  229. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +2 -2
  230. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +23 -9
  231. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +4 -3
  232. data/vendor/faiss/faiss/python/python_callbacks.cpp +5 -5
  233. data/vendor/faiss/faiss/python/python_callbacks.h +2 -2
  234. data/vendor/faiss/faiss/utils/AlignedTable.h +5 -3
  235. data/vendor/faiss/faiss/utils/Heap.cpp +2 -2
  236. data/vendor/faiss/faiss/utils/Heap.h +107 -2
  237. data/vendor/faiss/faiss/utils/NeuralNet.cpp +346 -0
  238. data/vendor/faiss/faiss/utils/NeuralNet.h +147 -0
  239. data/vendor/faiss/faiss/utils/WorkerThread.cpp +2 -2
  240. data/vendor/faiss/faiss/utils/WorkerThread.h +2 -2
  241. data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +2 -2
  242. data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +2 -2
  243. data/vendor/faiss/faiss/utils/approx_topk/generic.h +2 -2
  244. data/vendor/faiss/faiss/utils/approx_topk/mode.h +2 -2
  245. data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +2 -2
  246. data/vendor/faiss/faiss/utils/bf16.h +36 -0
  247. data/vendor/faiss/faiss/utils/distances.cpp +249 -90
  248. data/vendor/faiss/faiss/utils/distances.h +8 -8
  249. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +2 -2
  250. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
  251. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
  252. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +2 -2
  253. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +2 -2
  254. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +2 -2
  255. data/vendor/faiss/faiss/utils/distances_simd.cpp +1543 -56
  256. data/vendor/faiss/faiss/utils/extra_distances-inl.h +72 -2
  257. data/vendor/faiss/faiss/utils/extra_distances.cpp +87 -140
  258. data/vendor/faiss/faiss/utils/extra_distances.h +5 -4
  259. data/vendor/faiss/faiss/utils/fp16-arm.h +2 -2
  260. data/vendor/faiss/faiss/utils/fp16-fp16c.h +2 -2
  261. data/vendor/faiss/faiss/utils/fp16-inl.h +2 -2
  262. data/vendor/faiss/faiss/utils/fp16.h +2 -2
  263. data/vendor/faiss/faiss/utils/hamming-inl.h +2 -2
  264. data/vendor/faiss/faiss/utils/hamming.cpp +3 -4
  265. data/vendor/faiss/faiss/utils/hamming.h +2 -2
  266. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +2 -2
  267. data/vendor/faiss/faiss/utils/hamming_distance/avx512-inl.h +490 -0
  268. data/vendor/faiss/faiss/utils/hamming_distance/common.h +2 -2
  269. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +6 -3
  270. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +7 -3
  271. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +5 -5
  272. data/vendor/faiss/faiss/utils/ordered_key_value.h +2 -2
  273. data/vendor/faiss/faiss/utils/partitioning.cpp +2 -2
  274. data/vendor/faiss/faiss/utils/partitioning.h +2 -2
  275. data/vendor/faiss/faiss/utils/prefetch.h +2 -2
  276. data/vendor/faiss/faiss/utils/quantize_lut.cpp +2 -2
  277. data/vendor/faiss/faiss/utils/quantize_lut.h +2 -2
  278. data/vendor/faiss/faiss/utils/random.cpp +45 -2
  279. data/vendor/faiss/faiss/utils/random.h +27 -2
  280. data/vendor/faiss/faiss/utils/simdlib.h +12 -3
  281. data/vendor/faiss/faiss/utils/simdlib_avx2.h +2 -2
  282. data/vendor/faiss/faiss/utils/simdlib_avx512.h +296 -0
  283. data/vendor/faiss/faiss/utils/simdlib_emulated.h +2 -2
  284. data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -4
  285. data/vendor/faiss/faiss/utils/simdlib_ppc64.h +1084 -0
  286. data/vendor/faiss/faiss/utils/sorting.cpp +2 -2
  287. data/vendor/faiss/faiss/utils/sorting.h +2 -2
  288. data/vendor/faiss/faiss/utils/transpose/transpose-avx2-inl.h +2 -2
  289. data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +176 -0
  290. data/vendor/faiss/faiss/utils/utils.cpp +17 -10
  291. data/vendor/faiss/faiss/utils/utils.h +7 -3
  292. metadata +22 -11
  293. data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +0 -102
@@ -1,5 +1,5 @@
1
- /**
2
- * Copyright (c) Facebook, Inc. and its affiliates.
1
+ /*
2
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3
3
  *
4
4
  * This source code is licensed under the MIT license found in the
5
5
  * LICENSE file in the root directory of this source tree.
@@ -7,6 +7,7 @@
7
7
 
8
8
  #include <faiss/impl/HNSW.h>
9
9
 
10
+ #include <cstddef>
10
11
  #include <string>
11
12
 
12
13
  #include <faiss/impl/AuxIndexStructures.h>
@@ -110,8 +111,8 @@ void HNSW::print_neighbor_stats(int level) const {
110
111
  level,
111
112
  nb_neighbors(level));
112
113
  size_t tot_neigh = 0, tot_common = 0, tot_reciprocal = 0, n_node = 0;
113
- #pragma omp parallel for reduction(+: tot_neigh) reduction(+: tot_common) \
114
- reduction(+: tot_reciprocal) reduction(+: n_node)
114
+ #pragma omp parallel for reduction(+ : tot_neigh) reduction(+ : tot_common) \
115
+ reduction(+ : tot_reciprocal) reduction(+ : n_node)
115
116
  for (int i = 0; i < levels.size(); i++) {
116
117
  if (levels[i] > level) {
117
118
  n_node++;
@@ -165,10 +166,10 @@ void HNSW::print_neighbor_stats(int level) const {
165
166
  }
166
167
 
167
168
  void HNSW::fill_with_random_links(size_t n) {
168
- int max_level = prepare_level_tab(n);
169
+ int max_level_2 = prepare_level_tab(n);
169
170
  RandomGenerator rng2(456);
170
171
 
171
- for (int level = max_level - 1; level >= 0; --level) {
172
+ for (int level = max_level_2 - 1; level >= 0; --level) {
172
173
  std::vector<int> elts;
173
174
  for (int i = 0; i < n; i++) {
174
175
  if (levels[i] > level) {
@@ -209,16 +210,16 @@ int HNSW::prepare_level_tab(size_t n, bool preset_levels) {
209
210
  }
210
211
  }
211
212
 
212
- int max_level = 0;
213
+ int max_level_2 = 0;
213
214
  for (int i = 0; i < n; i++) {
214
215
  int pt_level = levels[i + n0] - 1;
215
- if (pt_level > max_level)
216
- max_level = pt_level;
216
+ if (pt_level > max_level_2)
217
+ max_level_2 = pt_level;
217
218
  offsets.push_back(offsets.back() + cum_nb_neighbors(pt_level + 1));
218
- neighbors.resize(offsets.back(), -1);
219
219
  }
220
+ neighbors.resize(offsets.back(), -1);
220
221
 
221
- return max_level;
222
+ return max_level_2;
222
223
  }
223
224
 
224
225
  /** Enumerate vertices from nearest to farthest from query, keep a
@@ -229,7 +230,14 @@ void HNSW::shrink_neighbor_list(
229
230
  DistanceComputer& qdis,
230
231
  std::priority_queue<NodeDistFarther>& input,
231
232
  std::vector<NodeDistFarther>& output,
232
- int max_size) {
233
+ int max_size,
234
+ bool keep_max_size_level0) {
235
+ // This prevents number of neighbors at
236
+ // level 0 from being shrunk to less than 2 * M.
237
+ // This is essential in making sure
238
+ // `faiss::gpu::GpuIndexCagra::copyFrom(IndexHNSWCagra*)` is functional
239
+ std::vector<NodeDistFarther> outsiders;
240
+
233
241
  while (input.size() > 0) {
234
242
  NodeDistFarther v1 = input.top();
235
243
  input.pop();
@@ -250,8 +258,15 @@ void HNSW::shrink_neighbor_list(
250
258
  if (output.size() >= max_size) {
251
259
  return;
252
260
  }
261
+ } else if (keep_max_size_level0) {
262
+ outsiders.push_back(v1);
253
263
  }
254
264
  }
265
+ size_t idx = 0;
266
+ while (keep_max_size_level0 && (output.size() < max_size) &&
267
+ (idx < outsiders.size())) {
268
+ output.push_back(outsiders[idx++]);
269
+ }
255
270
  }
256
271
 
257
272
  namespace {
@@ -268,7 +283,8 @@ using NodeDistFarther = HNSW::NodeDistFarther;
268
283
  void shrink_neighbor_list(
269
284
  DistanceComputer& qdis,
270
285
  std::priority_queue<NodeDistCloser>& resultSet1,
271
- int max_size) {
286
+ int max_size,
287
+ bool keep_max_size_level0 = false) {
272
288
  if (resultSet1.size() < max_size) {
273
289
  return;
274
290
  }
@@ -280,7 +296,8 @@ void shrink_neighbor_list(
280
296
  resultSet1.pop();
281
297
  }
282
298
 
283
- HNSW::shrink_neighbor_list(qdis, resultSet, returnlist, max_size);
299
+ HNSW::shrink_neighbor_list(
300
+ qdis, resultSet, returnlist, max_size, keep_max_size_level0);
284
301
 
285
302
  for (NodeDistFarther curen2 : returnlist) {
286
303
  resultSet1.emplace(curen2.d, curen2.id);
@@ -294,7 +311,8 @@ void add_link(
294
311
  DistanceComputer& qdis,
295
312
  storage_idx_t src,
296
313
  storage_idx_t dest,
297
- int level) {
314
+ int level,
315
+ bool keep_max_size_level0 = false) {
298
316
  size_t begin, end;
299
317
  hnsw.neighbor_range(src, level, &begin, &end);
300
318
  if (hnsw.neighbors[end - 1] == -1) {
@@ -319,7 +337,7 @@ void add_link(
319
337
  resultSet.emplace(qdis.symmetric_dis(src, neigh), neigh);
320
338
  }
321
339
 
322
- shrink_neighbor_list(qdis, resultSet, end - begin);
340
+ shrink_neighbor_list(qdis, resultSet, end - begin, keep_max_size_level0);
323
341
 
324
342
  // ...and back
325
343
  size_t i = begin;
@@ -333,6 +351,8 @@ void add_link(
333
351
  }
334
352
  }
335
353
 
354
+ } // namespace
355
+
336
356
  /// search neighbors on a single level, starting from an entry point
337
357
  void search_neighbors_to_add(
338
358
  HNSW& hnsw,
@@ -341,7 +361,8 @@ void search_neighbors_to_add(
341
361
  int entry_point,
342
362
  float d_entry_point,
343
363
  int level,
344
- VisitedTable& vt) {
364
+ VisitedTable& vt,
365
+ bool reference_version) {
345
366
  // top is nearest candidate
346
367
  std::priority_queue<NodeDistFarther> candidates;
347
368
 
@@ -363,62 +384,98 @@ void search_neighbors_to_add(
363
384
  // loop over neighbors
364
385
  size_t begin, end;
365
386
  hnsw.neighbor_range(currNode, level, &begin, &end);
366
- for (size_t i = begin; i < end; i++) {
367
- storage_idx_t nodeId = hnsw.neighbors[i];
368
- if (nodeId < 0)
369
- break;
370
- if (vt.get(nodeId))
371
- continue;
372
- vt.set(nodeId);
373
387
 
374
- float dis = qdis(nodeId);
375
- NodeDistFarther evE1(dis, nodeId);
376
-
377
- if (results.size() < hnsw.efConstruction || results.top().d > dis) {
378
- results.emplace(dis, nodeId);
379
- candidates.emplace(dis, nodeId);
380
- if (results.size() > hnsw.efConstruction) {
381
- results.pop();
388
+ // The reference version is not used, but kept here because:
389
+ // 1. It is easier to switch back if the optimized version has a problem
390
+ // 2. It serves as a starting point for new optimizations
391
+ // 3. It helps understand the code
392
+ // 4. It ensures the reference version is still compilable if the
393
+ // optimized version changes
394
+ // The reference and the optimized versions' results are compared in
395
+ // test_hnsw.cpp
396
+ if (reference_version) {
397
+ // a reference version
398
+ for (size_t i = begin; i < end; i++) {
399
+ storage_idx_t nodeId = hnsw.neighbors[i];
400
+ if (nodeId < 0)
401
+ break;
402
+ if (vt.get(nodeId))
403
+ continue;
404
+ vt.set(nodeId);
405
+
406
+ float dis = qdis(nodeId);
407
+ NodeDistFarther evE1(dis, nodeId);
408
+
409
+ if (results.size() < hnsw.efConstruction ||
410
+ results.top().d > dis) {
411
+ results.emplace(dis, nodeId);
412
+ candidates.emplace(dis, nodeId);
413
+ if (results.size() > hnsw.efConstruction) {
414
+ results.pop();
415
+ }
382
416
  }
383
417
  }
384
- }
385
- }
386
- vt.advance();
387
- }
418
+ } else {
419
+ // a faster version
420
+
421
+ // the following version processes 4 neighbors at a time
422
+ auto update_with_candidate = [&](const storage_idx_t idx,
423
+ const float dis) {
424
+ if (results.size() < hnsw.efConstruction ||
425
+ results.top().d > dis) {
426
+ results.emplace(dis, idx);
427
+ candidates.emplace(dis, idx);
428
+ if (results.size() > hnsw.efConstruction) {
429
+ results.pop();
430
+ }
431
+ }
432
+ };
388
433
 
389
- /**************************************************************
390
- * Searching subroutines
391
- **************************************************************/
434
+ int n_buffered = 0;
435
+ storage_idx_t buffered_ids[4];
392
436
 
393
- /// greedily update a nearest vector at a given level
394
- void greedy_update_nearest(
395
- const HNSW& hnsw,
396
- DistanceComputer& qdis,
397
- int level,
398
- storage_idx_t& nearest,
399
- float& d_nearest) {
400
- for (;;) {
401
- storage_idx_t prev_nearest = nearest;
437
+ for (size_t j = begin; j < end; j++) {
438
+ storage_idx_t nodeId = hnsw.neighbors[j];
439
+ if (nodeId < 0)
440
+ break;
441
+ if (vt.get(nodeId)) {
442
+ continue;
443
+ }
444
+ vt.set(nodeId);
445
+
446
+ buffered_ids[n_buffered] = nodeId;
447
+ n_buffered += 1;
448
+
449
+ if (n_buffered == 4) {
450
+ float dis[4];
451
+ qdis.distances_batch_4(
452
+ buffered_ids[0],
453
+ buffered_ids[1],
454
+ buffered_ids[2],
455
+ buffered_ids[3],
456
+ dis[0],
457
+ dis[1],
458
+ dis[2],
459
+ dis[3]);
460
+
461
+ for (size_t id4 = 0; id4 < 4; id4++) {
462
+ update_with_candidate(buffered_ids[id4], dis[id4]);
463
+ }
402
464
 
403
- size_t begin, end;
404
- hnsw.neighbor_range(nearest, level, &begin, &end);
405
- for (size_t i = begin; i < end; i++) {
406
- storage_idx_t v = hnsw.neighbors[i];
407
- if (v < 0)
408
- break;
409
- float dis = qdis(v);
410
- if (dis < d_nearest) {
411
- nearest = v;
412
- d_nearest = dis;
465
+ n_buffered = 0;
466
+ }
467
+ }
468
+
469
+ // process leftovers
470
+ for (size_t icnt = 0; icnt < n_buffered; icnt++) {
471
+ float dis = qdis(buffered_ids[icnt]);
472
+ update_with_candidate(buffered_ids[icnt], dis);
413
473
  }
414
- }
415
- if (nearest == prev_nearest) {
416
- return;
417
474
  }
418
475
  }
419
- }
420
476
 
421
- } // namespace
477
+ vt.advance();
478
+ }
422
479
 
423
480
  /// Finds neighbors and builds links with them, starting from an entry
424
481
  /// point. The own neighbor list is assumed to be locked.
@@ -429,7 +486,8 @@ void HNSW::add_links_starting_from(
429
486
  float d_nearest,
430
487
  int level,
431
488
  omp_lock_t* locks,
432
- VisitedTable& vt) {
489
+ VisitedTable& vt,
490
+ bool keep_max_size_level0) {
433
491
  std::priority_queue<NodeDistCloser> link_targets;
434
492
 
435
493
  search_neighbors_to_add(
@@ -438,21 +496,21 @@ void HNSW::add_links_starting_from(
438
496
  // but we can afford only this many neighbors
439
497
  int M = nb_neighbors(level);
440
498
 
441
- ::faiss::shrink_neighbor_list(ptdis, link_targets, M);
499
+ ::faiss::shrink_neighbor_list(ptdis, link_targets, M, keep_max_size_level0);
442
500
 
443
- std::vector<storage_idx_t> neighbors;
444
- neighbors.reserve(link_targets.size());
501
+ std::vector<storage_idx_t> neighbors_to_add;
502
+ neighbors_to_add.reserve(link_targets.size());
445
503
  while (!link_targets.empty()) {
446
504
  storage_idx_t other_id = link_targets.top().id;
447
- add_link(*this, ptdis, pt_id, other_id, level);
448
- neighbors.push_back(other_id);
505
+ add_link(*this, ptdis, pt_id, other_id, level, keep_max_size_level0);
506
+ neighbors_to_add.push_back(other_id);
449
507
  link_targets.pop();
450
508
  }
451
509
 
452
510
  omp_unset_lock(&locks[pt_id]);
453
- for (storage_idx_t other_id : neighbors) {
511
+ for (storage_idx_t other_id : neighbors_to_add) {
454
512
  omp_set_lock(&locks[other_id]);
455
- add_link(*this, ptdis, other_id, pt_id, level);
513
+ add_link(*this, ptdis, other_id, pt_id, level, keep_max_size_level0);
456
514
  omp_unset_lock(&locks[other_id]);
457
515
  }
458
516
  omp_set_lock(&locks[pt_id]);
@@ -467,7 +525,8 @@ void HNSW::add_with_locks(
467
525
  int pt_level,
468
526
  int pt_id,
469
527
  std::vector<omp_lock_t>& locks,
470
- VisitedTable& vt) {
528
+ VisitedTable& vt,
529
+ bool keep_max_size_level0) {
471
530
  // greedy search on upper levels
472
531
 
473
532
  storage_idx_t nearest;
@@ -496,7 +555,14 @@ void HNSW::add_with_locks(
496
555
 
497
556
  for (; level >= 0; level--) {
498
557
  add_links_starting_from(
499
- ptdis, pt_id, nearest, d_nearest, level, locks.data(), vt);
558
+ ptdis,
559
+ pt_id,
560
+ nearest,
561
+ d_nearest,
562
+ level,
563
+ locks.data(),
564
+ vt,
565
+ keep_max_size_level0);
500
566
  }
501
567
 
502
568
  omp_unset_lock(&locks[pt_id]);
@@ -511,12 +577,10 @@ void HNSW::add_with_locks(
511
577
  * Searching
512
578
  **************************************************************/
513
579
 
514
- namespace {
515
580
  using MinimaxHeap = HNSW::MinimaxHeap;
516
581
  using Node = HNSW::Node;
517
582
  using C = HNSW::C;
518
583
  /** Do a BFS on the candidates list */
519
-
520
584
  int search_from_candidates(
521
585
  const HNSW& hnsw,
522
586
  DistanceComputer& qdis,
@@ -525,8 +589,8 @@ int search_from_candidates(
525
589
  VisitedTable& vt,
526
590
  HNSWStats& stats,
527
591
  int level,
528
- int nres_in = 0,
529
- const SearchParametersHNSW* params = nullptr) {
592
+ int nres_in,
593
+ const SearchParametersHNSW* params) {
530
594
  int nres = nres_in;
531
595
  int ndis = 0;
532
596
 
@@ -571,27 +635,7 @@ int search_from_candidates(
571
635
  size_t begin, end;
572
636
  hnsw.neighbor_range(v0, level, &begin, &end);
573
637
 
574
- // // baseline version
575
- // for (size_t j = begin; j < end; j++) {
576
- // int v1 = hnsw.neighbors[j];
577
- // if (v1 < 0)
578
- // break;
579
- // if (vt.get(v1)) {
580
- // continue;
581
- // }
582
- // vt.set(v1);
583
- // ndis++;
584
- // float d = qdis(v1);
585
- // if (!sel || sel->is_member(v1)) {
586
- // if (nres < k) {
587
- // faiss::maxheap_push(++nres, D, I, d, v1);
588
- // } else if (d < D[0]) {
589
- // faiss::maxheap_replace_top(nres, D, I, d, v1);
590
- // }
591
- // }
592
- // candidates.push(v1, d);
593
- // }
594
-
638
+ // a faster version: reference version in unit test test_hnsw.cpp
595
639
  // the following version processes 4 neighbors at a time
596
640
  size_t jmax = begin;
597
641
  for (size_t j = begin; j < end; j++) {
@@ -606,7 +650,6 @@ int search_from_candidates(
606
650
  int counter = 0;
607
651
  size_t saved_j[4];
608
652
 
609
- ndis += jmax - begin;
610
653
  threshold = res.threshold;
611
654
 
612
655
  auto add_to_heap = [&](const size_t idx, const float dis) {
@@ -614,6 +657,7 @@ int search_from_candidates(
614
657
  if (dis < threshold) {
615
658
  if (res.add_result(dis, idx)) {
616
659
  threshold = res.threshold;
660
+ nres += 1;
617
661
  }
618
662
  }
619
663
  }
@@ -644,6 +688,8 @@ int search_from_candidates(
644
688
  add_to_heap(saved_j[id4], dis[id4]);
645
689
  }
646
690
 
691
+ ndis += 4;
692
+
647
693
  counter = 0;
648
694
  }
649
695
  }
@@ -651,6 +697,8 @@ int search_from_candidates(
651
697
  for (size_t icnt = 0; icnt < counter; icnt++) {
652
698
  float dis = qdis(saved_j[icnt]);
653
699
  add_to_heap(saved_j[icnt], dis);
700
+
701
+ ndis += 1;
654
702
  }
655
703
 
656
704
  nstep++;
@@ -664,7 +712,8 @@ int search_from_candidates(
664
712
  if (candidates.size() == 0) {
665
713
  stats.n2++;
666
714
  }
667
- stats.n3 += ndis;
715
+ stats.ndis += ndis;
716
+ stats.nhops += nstep;
668
717
  }
669
718
 
670
719
  return nres;
@@ -700,33 +749,7 @@ std::priority_queue<HNSW::Node> search_from_candidate_unbounded(
700
749
  size_t begin, end;
701
750
  hnsw.neighbor_range(v0, 0, &begin, &end);
702
751
 
703
- // // baseline version
704
- // for (size_t j = begin; j < end; ++j) {
705
- // int v1 = hnsw.neighbors[j];
706
- //
707
- // if (v1 < 0) {
708
- // break;
709
- // }
710
- // if (vt->get(v1)) {
711
- // continue;
712
- // }
713
- //
714
- // vt->set(v1);
715
- //
716
- // float d1 = qdis(v1);
717
- // ++ndis;
718
- //
719
- // if (top_candidates.top().first > d1 ||
720
- // top_candidates.size() < ef) {
721
- // candidates.emplace(d1, v1);
722
- // top_candidates.emplace(d1, v1);
723
- //
724
- // if (top_candidates.size() > ef) {
725
- // top_candidates.pop();
726
- // }
727
- // }
728
- // }
729
-
752
+ // a faster version: reference version in unit test test_hnsw.cpp
730
753
  // the following version processes 4 neighbors at a time
731
754
  size_t jmax = begin;
732
755
  for (size_t j = begin; j < end; j++) {
@@ -741,8 +764,6 @@ std::priority_queue<HNSW::Node> search_from_candidate_unbounded(
741
764
  int counter = 0;
742
765
  size_t saved_j[4];
743
766
 
744
- ndis += jmax - begin;
745
-
746
767
  auto add_to_heap = [&](const size_t idx, const float dis) {
747
768
  if (top_candidates.top().first > dis ||
748
769
  top_candidates.size() < ef) {
@@ -779,6 +800,8 @@ std::priority_queue<HNSW::Node> search_from_candidate_unbounded(
779
800
  add_to_heap(saved_j[id4], dis[id4]);
780
801
  }
781
802
 
803
+ ndis += 4;
804
+
782
805
  counter = 0;
783
806
  }
784
807
  }
@@ -786,18 +809,102 @@ std::priority_queue<HNSW::Node> search_from_candidate_unbounded(
786
809
  for (size_t icnt = 0; icnt < counter; icnt++) {
787
810
  float dis = qdis(saved_j[icnt]);
788
811
  add_to_heap(saved_j[icnt], dis);
812
+
813
+ ndis += 1;
789
814
  }
815
+
816
+ stats.nhops += 1;
790
817
  }
791
818
 
792
819
  ++stats.n1;
793
820
  if (candidates.size() == 0) {
794
821
  ++stats.n2;
795
822
  }
796
- stats.n3 += ndis;
823
+ stats.ndis += ndis;
797
824
 
798
825
  return top_candidates;
799
826
  }
800
827
 
828
+ /// greedily update a nearest vector at a given level
829
+ HNSWStats greedy_update_nearest(
830
+ const HNSW& hnsw,
831
+ DistanceComputer& qdis,
832
+ int level,
833
+ storage_idx_t& nearest,
834
+ float& d_nearest) {
835
+ HNSWStats stats;
836
+
837
+ for (;;) {
838
+ storage_idx_t prev_nearest = nearest;
839
+
840
+ size_t begin, end;
841
+ hnsw.neighbor_range(nearest, level, &begin, &end);
842
+
843
+ size_t ndis = 0;
844
+
845
+ // a faster version: reference version in unit test test_hnsw.cpp
846
+ // the following version processes 4 neighbors at a time
847
+ auto update_with_candidate = [&](const storage_idx_t idx,
848
+ const float dis) {
849
+ if (dis < d_nearest) {
850
+ nearest = idx;
851
+ d_nearest = dis;
852
+ }
853
+ };
854
+
855
+ int n_buffered = 0;
856
+ storage_idx_t buffered_ids[4];
857
+
858
+ for (size_t j = begin; j < end; j++) {
859
+ storage_idx_t v = hnsw.neighbors[j];
860
+ if (v < 0)
861
+ break;
862
+ ndis += 1;
863
+
864
+ buffered_ids[n_buffered] = v;
865
+ n_buffered += 1;
866
+
867
+ if (n_buffered == 4) {
868
+ float dis[4];
869
+ qdis.distances_batch_4(
870
+ buffered_ids[0],
871
+ buffered_ids[1],
872
+ buffered_ids[2],
873
+ buffered_ids[3],
874
+ dis[0],
875
+ dis[1],
876
+ dis[2],
877
+ dis[3]);
878
+
879
+ for (size_t id4 = 0; id4 < 4; id4++) {
880
+ update_with_candidate(buffered_ids[id4], dis[id4]);
881
+ }
882
+
883
+ n_buffered = 0;
884
+ }
885
+ }
886
+
887
+ // process leftovers
888
+ for (size_t icnt = 0; icnt < n_buffered; icnt++) {
889
+ float dis = qdis(buffered_ids[icnt]);
890
+ update_with_candidate(buffered_ids[icnt], dis);
891
+ }
892
+
893
+ // update stats
894
+ stats.ndis += ndis;
895
+ stats.nhops += 1;
896
+
897
+ if (nearest == prev_nearest) {
898
+ return stats;
899
+ }
900
+ }
901
+ }
902
+
903
+ namespace {
904
+ using MinimaxHeap = HNSW::MinimaxHeap;
905
+ using Node = HNSW::Node;
906
+ using C = HNSW::C;
907
+
801
908
  // just used as a lower bound for the minmaxheap, but it is set for heap search
802
909
  int extract_k_from_ResultHandler(ResultHandler<C>& res) {
803
910
  using RH = HeapBlockResultHandler<C>;
@@ -807,7 +914,7 @@ int extract_k_from_ResultHandler(ResultHandler<C>& res) {
807
914
  return 1;
808
915
  }
809
916
 
810
- } // anonymous namespace
917
+ } // namespace
811
918
 
812
919
  HNSWStats HNSW::search(
813
920
  DistanceComputer& qdis,
@@ -820,85 +927,47 @@ HNSWStats HNSW::search(
820
927
  }
821
928
  int k = extract_k_from_ResultHandler(res);
822
929
 
823
- if (upper_beam == 1) {
824
- // greedy search on upper levels
825
- storage_idx_t nearest = entry_point;
826
- float d_nearest = qdis(nearest);
827
-
828
- for (int level = max_level; level >= 1; level--) {
829
- greedy_update_nearest(*this, qdis, level, nearest, d_nearest);
830
- }
831
-
832
- int ef = std::max(params ? params->efSearch : efSearch, k);
833
- if (search_bounded_queue) { // this is the most common branch
834
- MinimaxHeap candidates(ef);
930
+ bool bounded_queue =
931
+ params ? params->bounded_queue : this->search_bounded_queue;
835
932
 
836
- candidates.push(nearest, d_nearest);
933
+ // greedy search on upper levels
934
+ storage_idx_t nearest = entry_point;
935
+ float d_nearest = qdis(nearest);
837
936
 
838
- search_from_candidates(
839
- *this, qdis, res, candidates, vt, stats, 0, 0, params);
840
- } else {
841
- std::priority_queue<Node> top_candidates =
842
- search_from_candidate_unbounded(
843
- *this,
844
- Node(d_nearest, nearest),
845
- qdis,
846
- ef,
847
- &vt,
848
- stats);
849
-
850
- while (top_candidates.size() > k) {
851
- top_candidates.pop();
852
- }
937
+ for (int level = max_level; level >= 1; level--) {
938
+ HNSWStats local_stats =
939
+ greedy_update_nearest(*this, qdis, level, nearest, d_nearest);
940
+ stats.combine(local_stats);
941
+ }
853
942
 
854
- while (!top_candidates.empty()) {
855
- float d;
856
- storage_idx_t label;
857
- std::tie(d, label) = top_candidates.top();
858
- res.add_result(d, label);
859
- top_candidates.pop();
860
- }
861
- }
943
+ int ef = std::max(params ? params->efSearch : efSearch, k);
944
+ if (bounded_queue) { // this is the most common branch
945
+ MinimaxHeap candidates(ef);
862
946
 
863
- vt.advance();
947
+ candidates.push(nearest, d_nearest);
864
948
 
949
+ search_from_candidates(
950
+ *this, qdis, res, candidates, vt, stats, 0, 0, params);
865
951
  } else {
866
- int candidates_size = upper_beam;
867
- MinimaxHeap candidates(candidates_size);
868
-
869
- std::vector<idx_t> I_to_next(candidates_size);
870
- std::vector<float> D_to_next(candidates_size);
871
-
872
- HeapBlockResultHandler<C> block_resh(
873
- 1, D_to_next.data(), I_to_next.data(), candidates_size);
874
- HeapBlockResultHandler<C>::SingleResultHandler resh(block_resh);
875
-
876
- int nres = 1;
877
- I_to_next[0] = entry_point;
878
- D_to_next[0] = qdis(entry_point);
879
-
880
- for (int level = max_level; level >= 0; level--) {
881
- // copy I, D -> candidates
952
+ std::priority_queue<Node> top_candidates =
953
+ search_from_candidate_unbounded(
954
+ *this, Node(d_nearest, nearest), qdis, ef, &vt, stats);
882
955
 
883
- candidates.clear();
884
-
885
- for (int i = 0; i < nres; i++) {
886
- candidates.push(I_to_next[i], D_to_next[i]);
887
- }
956
+ while (top_candidates.size() > k) {
957
+ top_candidates.pop();
958
+ }
888
959
 
889
- if (level == 0) {
890
- nres = search_from_candidates(
891
- *this, qdis, res, candidates, vt, stats, 0);
892
- } else {
893
- resh.begin(0);
894
- nres = search_from_candidates(
895
- *this, qdis, resh, candidates, vt, stats, level);
896
- resh.end();
897
- }
898
- vt.advance();
960
+ while (!top_candidates.empty()) {
961
+ float d;
962
+ storage_idx_t label;
963
+ std::tie(d, label) = top_candidates.top();
964
+ res.add_result(d, label);
965
+ top_candidates.pop();
899
966
  }
900
967
  }
901
968
 
969
+ vt.advance();
970
+
902
971
  return stats;
903
972
  }
904
973
 
@@ -910,9 +979,12 @@ void HNSW::search_level_0(
910
979
  const float* nearest_d,
911
980
  int search_type,
912
981
  HNSWStats& search_stats,
913
- VisitedTable& vt) const {
982
+ VisitedTable& vt,
983
+ const SearchParametersHNSW* params) const {
914
984
  const HNSW& hnsw = *this;
985
+ auto efSearch = params ? params->efSearch : hnsw.efSearch;
915
986
  int k = extract_k_from_ResultHandler(res);
987
+
916
988
  if (search_type == 1) {
917
989
  int nres = 0;
918
990
 
@@ -925,16 +997,25 @@ void HNSW::search_level_0(
925
997
  if (vt.get(cj))
926
998
  continue;
927
999
 
928
- int candidates_size = std::max(hnsw.efSearch, k);
1000
+ int candidates_size = std::max(efSearch, k);
929
1001
  MinimaxHeap candidates(candidates_size);
930
1002
 
931
1003
  candidates.push(cj, nearest_d[j]);
932
1004
 
933
1005
  nres = search_from_candidates(
934
- hnsw, qdis, res, candidates, vt, search_stats, 0, nres);
1006
+ hnsw,
1007
+ qdis,
1008
+ res,
1009
+ candidates,
1010
+ vt,
1011
+ search_stats,
1012
+ 0,
1013
+ nres,
1014
+ params);
1015
+ nres = std::min(nres, candidates_size);
935
1016
  }
936
1017
  } else if (search_type == 2) {
937
- int candidates_size = std::max(hnsw.efSearch, int(k));
1018
+ int candidates_size = std::max(efSearch, int(k));
938
1019
  candidates_size = std::max(candidates_size, int(nprobe));
939
1020
 
940
1021
  MinimaxHeap candidates(candidates_size);
@@ -947,7 +1028,7 @@ void HNSW::search_level_0(
947
1028
  }
948
1029
 
949
1030
  search_from_candidates(
950
- hnsw, qdis, res, candidates, vt, search_stats, 0);
1031
+ hnsw, qdis, res, candidates, vt, search_stats, 0, 0, params);
951
1032
  }
952
1033
  }
953
1034
 
@@ -1013,7 +1094,99 @@ void HNSW::MinimaxHeap::clear() {
1013
1094
  nvalid = k = 0;
1014
1095
  }
1015
1096
 
1016
- #ifdef __AVX2__
1097
+ #ifdef __AVX512F__
1098
+
1099
+ int HNSW::MinimaxHeap::pop_min(float* vmin_out) {
1100
+ assert(k > 0);
1101
+ static_assert(
1102
+ std::is_same<storage_idx_t, int32_t>::value,
1103
+ "This code expects storage_idx_t to be int32_t");
1104
+
1105
+ int32_t min_idx = -1;
1106
+ float min_dis = std::numeric_limits<float>::infinity();
1107
+
1108
+ __m512i min_indices = _mm512_set1_epi32(-1);
1109
+ __m512 min_distances =
1110
+ _mm512_set1_ps(std::numeric_limits<float>::infinity());
1111
+ __m512i current_indices = _mm512_setr_epi32(
1112
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
1113
+ __m512i offset = _mm512_set1_epi32(16);
1114
+
1115
+ // The following loop tracks the rightmost index with the min distance.
1116
+ // -1 index values are ignored.
1117
+ const int k16 = (k / 16) * 16;
1118
+ for (size_t iii = 0; iii < k16; iii += 16) {
1119
+ __m512i indices =
1120
+ _mm512_loadu_si512((const __m512i*)(ids.data() + iii));
1121
+ __m512 distances = _mm512_loadu_ps(dis.data() + iii);
1122
+
1123
+ // This mask filters out -1 values among indices.
1124
+ __mmask16 m1mask =
1125
+ _mm512_cmpgt_epi32_mask(_mm512_setzero_si512(), indices);
1126
+
1127
+ __mmask16 dmask =
1128
+ _mm512_cmp_ps_mask(min_distances, distances, _CMP_LT_OS);
1129
+ __mmask16 finalmask = m1mask | dmask;
1130
+
1131
+ const __m512i min_indices_new = _mm512_mask_blend_epi32(
1132
+ finalmask, current_indices, min_indices);
1133
+ const __m512 min_distances_new =
1134
+ _mm512_mask_blend_ps(finalmask, distances, min_distances);
1135
+
1136
+ min_indices = min_indices_new;
1137
+ min_distances = min_distances_new;
1138
+
1139
+ current_indices = _mm512_add_epi32(current_indices, offset);
1140
+ }
1141
+
1142
+ // leftovers
1143
+ if (k16 != k) {
1144
+ const __mmask16 kmask = (1 << (k - k16)) - 1;
1145
+
1146
+ __m512i indices = _mm512_mask_loadu_epi32(
1147
+ _mm512_set1_epi32(-1), kmask, ids.data() + k16);
1148
+ __m512 distances = _mm512_maskz_loadu_ps(kmask, dis.data() + k16);
1149
+
1150
+ // This mask filters out -1 values among indices.
1151
+ __mmask16 m1mask =
1152
+ _mm512_cmpgt_epi32_mask(_mm512_setzero_si512(), indices);
1153
+
1154
+ __mmask16 dmask =
1155
+ _mm512_cmp_ps_mask(min_distances, distances, _CMP_LT_OS);
1156
+ __mmask16 finalmask = m1mask | dmask;
1157
+
1158
+ const __m512i min_indices_new = _mm512_mask_blend_epi32(
1159
+ finalmask, current_indices, min_indices);
1160
+ const __m512 min_distances_new =
1161
+ _mm512_mask_blend_ps(finalmask, distances, min_distances);
1162
+
1163
+ min_indices = min_indices_new;
1164
+ min_distances = min_distances_new;
1165
+ }
1166
+
1167
+ // grab min distance
1168
+ min_dis = _mm512_reduce_min_ps(min_distances);
1169
+ // blend
1170
+ __mmask16 mindmask =
1171
+ _mm512_cmpeq_ps_mask(min_distances, _mm512_set1_ps(min_dis));
1172
+ // pick the max one
1173
+ min_idx = _mm512_mask_reduce_max_epi32(mindmask, min_indices);
1174
+
1175
+ if (min_idx == -1) {
1176
+ return -1;
1177
+ }
1178
+
1179
+ if (vmin_out) {
1180
+ *vmin_out = min_dis;
1181
+ }
1182
+ int ret = ids[min_idx];
1183
+ ids[min_idx] = -1;
1184
+ --nvalid;
1185
+ return ret;
1186
+ }
1187
+
1188
+ #elif __AVX2__
1189
+
1017
1190
  int HNSW::MinimaxHeap::pop_min(float* vmin_out) {
1018
1191
  assert(k > 0);
1019
1192
  static_assert(