faiss 0.3.1 → 0.3.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (293) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +8 -0
  3. data/LICENSE.txt +1 -1
  4. data/lib/faiss/version.rb +1 -1
  5. data/vendor/faiss/faiss/AutoTune.cpp +2 -2
  6. data/vendor/faiss/faiss/AutoTune.h +3 -3
  7. data/vendor/faiss/faiss/Clustering.cpp +37 -6
  8. data/vendor/faiss/faiss/Clustering.h +12 -3
  9. data/vendor/faiss/faiss/IVFlib.cpp +6 -3
  10. data/vendor/faiss/faiss/IVFlib.h +2 -2
  11. data/vendor/faiss/faiss/Index.cpp +6 -2
  12. data/vendor/faiss/faiss/Index.h +30 -8
  13. data/vendor/faiss/faiss/Index2Layer.cpp +2 -2
  14. data/vendor/faiss/faiss/Index2Layer.h +2 -2
  15. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +7 -7
  16. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +2 -2
  17. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +14 -16
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +2 -2
  19. data/vendor/faiss/faiss/IndexBinary.cpp +13 -2
  20. data/vendor/faiss/faiss/IndexBinary.h +8 -2
  21. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +2 -3
  22. data/vendor/faiss/faiss/IndexBinaryFlat.h +2 -2
  23. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -2
  24. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -2
  25. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +2 -7
  26. data/vendor/faiss/faiss/IndexBinaryHNSW.h +3 -3
  27. data/vendor/faiss/faiss/IndexBinaryHash.cpp +2 -3
  28. data/vendor/faiss/faiss/IndexBinaryHash.h +2 -2
  29. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +3 -3
  30. data/vendor/faiss/faiss/IndexBinaryIVF.h +2 -2
  31. data/vendor/faiss/faiss/IndexFastScan.cpp +32 -18
  32. data/vendor/faiss/faiss/IndexFastScan.h +11 -2
  33. data/vendor/faiss/faiss/IndexFlat.cpp +13 -10
  34. data/vendor/faiss/faiss/IndexFlat.h +2 -2
  35. data/vendor/faiss/faiss/IndexFlatCodes.cpp +170 -7
  36. data/vendor/faiss/faiss/IndexFlatCodes.h +25 -5
  37. data/vendor/faiss/faiss/IndexHNSW.cpp +156 -96
  38. data/vendor/faiss/faiss/IndexHNSW.h +54 -5
  39. data/vendor/faiss/faiss/IndexIDMap.cpp +19 -3
  40. data/vendor/faiss/faiss/IndexIDMap.h +5 -2
  41. data/vendor/faiss/faiss/IndexIVF.cpp +5 -6
  42. data/vendor/faiss/faiss/IndexIVF.h +13 -4
  43. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +21 -7
  44. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +5 -2
  45. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +3 -14
  46. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +2 -4
  47. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +201 -91
  48. data/vendor/faiss/faiss/IndexIVFFastScan.h +33 -9
  49. data/vendor/faiss/faiss/IndexIVFFlat.cpp +2 -2
  50. data/vendor/faiss/faiss/IndexIVFFlat.h +2 -2
  51. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +2 -2
  52. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +2 -2
  53. data/vendor/faiss/faiss/IndexIVFPQ.cpp +3 -6
  54. data/vendor/faiss/faiss/IndexIVFPQ.h +2 -2
  55. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +7 -14
  56. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +2 -4
  57. data/vendor/faiss/faiss/IndexIVFPQR.cpp +2 -2
  58. data/vendor/faiss/faiss/IndexIVFPQR.h +2 -2
  59. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +2 -3
  60. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +2 -2
  61. data/vendor/faiss/faiss/IndexLSH.cpp +2 -3
  62. data/vendor/faiss/faiss/IndexLSH.h +2 -2
  63. data/vendor/faiss/faiss/IndexLattice.cpp +3 -21
  64. data/vendor/faiss/faiss/IndexLattice.h +5 -24
  65. data/vendor/faiss/faiss/IndexNNDescent.cpp +2 -31
  66. data/vendor/faiss/faiss/IndexNNDescent.h +3 -3
  67. data/vendor/faiss/faiss/IndexNSG.cpp +2 -5
  68. data/vendor/faiss/faiss/IndexNSG.h +3 -3
  69. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +56 -0
  70. data/vendor/faiss/faiss/IndexNeuralNetCodec.h +49 -0
  71. data/vendor/faiss/faiss/IndexPQ.cpp +26 -26
  72. data/vendor/faiss/faiss/IndexPQ.h +2 -2
  73. data/vendor/faiss/faiss/IndexPQFastScan.cpp +2 -5
  74. data/vendor/faiss/faiss/IndexPQFastScan.h +2 -11
  75. data/vendor/faiss/faiss/IndexPreTransform.cpp +2 -2
  76. data/vendor/faiss/faiss/IndexPreTransform.h +3 -3
  77. data/vendor/faiss/faiss/IndexRefine.cpp +46 -9
  78. data/vendor/faiss/faiss/IndexRefine.h +9 -2
  79. data/vendor/faiss/faiss/IndexReplicas.cpp +2 -2
  80. data/vendor/faiss/faiss/IndexReplicas.h +2 -2
  81. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +2 -2
  82. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +2 -2
  83. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +5 -4
  84. data/vendor/faiss/faiss/IndexScalarQuantizer.h +2 -2
  85. data/vendor/faiss/faiss/IndexShards.cpp +2 -2
  86. data/vendor/faiss/faiss/IndexShards.h +2 -2
  87. data/vendor/faiss/faiss/IndexShardsIVF.cpp +2 -2
  88. data/vendor/faiss/faiss/IndexShardsIVF.h +2 -2
  89. data/vendor/faiss/faiss/MatrixStats.cpp +2 -2
  90. data/vendor/faiss/faiss/MatrixStats.h +2 -2
  91. data/vendor/faiss/faiss/MetaIndexes.cpp +2 -3
  92. data/vendor/faiss/faiss/MetaIndexes.h +2 -2
  93. data/vendor/faiss/faiss/MetricType.h +9 -4
  94. data/vendor/faiss/faiss/VectorTransform.cpp +2 -2
  95. data/vendor/faiss/faiss/VectorTransform.h +2 -2
  96. data/vendor/faiss/faiss/clone_index.cpp +2 -2
  97. data/vendor/faiss/faiss/clone_index.h +2 -2
  98. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +2 -2
  99. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +2 -2
  100. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +97 -19
  101. data/vendor/faiss/faiss/cppcontrib/factory_tools.cpp +192 -0
  102. data/vendor/faiss/faiss/cppcontrib/factory_tools.h +29 -0
  103. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2 -2
  104. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +85 -32
  105. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2 -2
  106. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +2 -2
  107. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +2 -2
  108. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +2 -2
  109. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +2 -2
  110. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +2 -2
  111. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +2 -5
  112. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +2 -2
  113. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +45 -13
  114. data/vendor/faiss/faiss/gpu/GpuCloner.h +2 -2
  115. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +12 -6
  116. data/vendor/faiss/faiss/gpu/GpuDistance.h +11 -7
  117. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +3 -3
  118. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +2 -2
  119. data/vendor/faiss/faiss/gpu/GpuIndex.h +10 -15
  120. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -2
  121. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +285 -0
  122. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +2 -2
  123. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +8 -2
  124. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +4 -2
  125. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +3 -3
  126. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +2 -2
  127. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +2 -2
  128. data/vendor/faiss/faiss/gpu/GpuResources.cpp +7 -2
  129. data/vendor/faiss/faiss/gpu/GpuResources.h +11 -4
  130. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +66 -11
  131. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +15 -5
  132. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +2 -2
  133. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +28 -23
  134. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +2 -2
  135. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +2 -2
  136. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +2 -2
  137. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +2 -2
  138. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +2 -2
  139. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +8 -2
  140. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +2 -3
  141. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +2 -2
  142. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +10 -7
  143. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +2 -2
  144. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +54 -54
  145. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +144 -77
  146. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +51 -51
  147. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +2 -2
  148. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +3 -3
  149. data/vendor/faiss/faiss/gpu/test/TestGpuResidualQuantizer.cpp +70 -0
  150. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +74 -4
  151. data/vendor/faiss/faiss/gpu/test/TestUtils.h +2 -2
  152. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +3 -3
  153. data/vendor/faiss/faiss/gpu/utils/{RaftUtils.h → CuvsUtils.h} +12 -11
  154. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +8 -2
  155. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +2 -2
  156. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +2 -2
  157. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +2 -2
  158. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +6 -3
  159. data/vendor/faiss/faiss/gpu/utils/Timer.h +3 -3
  160. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +79 -11
  161. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +17 -5
  162. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +27 -2
  163. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +11 -3
  164. data/vendor/faiss/faiss/impl/CodePacker.cpp +2 -2
  165. data/vendor/faiss/faiss/impl/CodePacker.h +2 -2
  166. data/vendor/faiss/faiss/impl/DistanceComputer.h +48 -2
  167. data/vendor/faiss/faiss/impl/FaissAssert.h +6 -4
  168. data/vendor/faiss/faiss/impl/FaissException.cpp +2 -2
  169. data/vendor/faiss/faiss/impl/FaissException.h +2 -3
  170. data/vendor/faiss/faiss/impl/HNSW.cpp +378 -205
  171. data/vendor/faiss/faiss/impl/HNSW.h +55 -24
  172. data/vendor/faiss/faiss/impl/IDSelector.cpp +2 -2
  173. data/vendor/faiss/faiss/impl/IDSelector.h +2 -2
  174. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +10 -10
  175. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +2 -2
  176. data/vendor/faiss/faiss/impl/LookupTableScaler.h +36 -2
  177. data/vendor/faiss/faiss/impl/NNDescent.cpp +15 -10
  178. data/vendor/faiss/faiss/impl/NNDescent.h +2 -2
  179. data/vendor/faiss/faiss/impl/NSG.cpp +26 -49
  180. data/vendor/faiss/faiss/impl/NSG.h +20 -8
  181. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +2 -2
  182. data/vendor/faiss/faiss/impl/PolysemousTraining.h +2 -2
  183. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +2 -4
  184. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +2 -2
  185. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +2 -2
  186. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +3 -2
  187. data/vendor/faiss/faiss/impl/ProductQuantizer.h +7 -3
  188. data/vendor/faiss/faiss/impl/Quantizer.h +2 -2
  189. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +2 -36
  190. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +3 -13
  191. data/vendor/faiss/faiss/impl/ResultHandler.h +153 -34
  192. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +721 -104
  193. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +5 -2
  194. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +2 -2
  195. data/vendor/faiss/faiss/impl/ThreadedIndex.h +2 -2
  196. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +7 -2
  197. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx512.h +248 -0
  198. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +2 -2
  199. data/vendor/faiss/faiss/impl/code_distance/code_distance-sve.h +440 -0
  200. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +55 -2
  201. data/vendor/faiss/faiss/impl/index_read.cpp +31 -20
  202. data/vendor/faiss/faiss/impl/index_read_utils.h +37 -0
  203. data/vendor/faiss/faiss/impl/index_write.cpp +30 -16
  204. data/vendor/faiss/faiss/impl/io.cpp +15 -7
  205. data/vendor/faiss/faiss/impl/io.h +6 -6
  206. data/vendor/faiss/faiss/impl/io_macros.h +8 -9
  207. data/vendor/faiss/faiss/impl/kmeans1d.cpp +2 -3
  208. data/vendor/faiss/faiss/impl/kmeans1d.h +2 -2
  209. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +2 -3
  210. data/vendor/faiss/faiss/impl/lattice_Zn.h +2 -2
  211. data/vendor/faiss/faiss/impl/platform_macros.h +34 -2
  212. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +13 -2
  213. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +20 -2
  214. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +3 -3
  215. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +450 -3
  216. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +8 -8
  217. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +3 -3
  218. data/vendor/faiss/faiss/impl/simd_result_handlers.h +151 -67
  219. data/vendor/faiss/faiss/index_factory.cpp +51 -34
  220. data/vendor/faiss/faiss/index_factory.h +2 -2
  221. data/vendor/faiss/faiss/index_io.h +14 -7
  222. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +30 -10
  223. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +5 -2
  224. data/vendor/faiss/faiss/invlists/DirectMap.cpp +11 -3
  225. data/vendor/faiss/faiss/invlists/DirectMap.h +2 -2
  226. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +57 -19
  227. data/vendor/faiss/faiss/invlists/InvertedLists.h +20 -11
  228. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +2 -2
  229. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +2 -2
  230. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +23 -9
  231. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +4 -3
  232. data/vendor/faiss/faiss/python/python_callbacks.cpp +5 -5
  233. data/vendor/faiss/faiss/python/python_callbacks.h +2 -2
  234. data/vendor/faiss/faiss/utils/AlignedTable.h +5 -3
  235. data/vendor/faiss/faiss/utils/Heap.cpp +2 -2
  236. data/vendor/faiss/faiss/utils/Heap.h +107 -2
  237. data/vendor/faiss/faiss/utils/NeuralNet.cpp +346 -0
  238. data/vendor/faiss/faiss/utils/NeuralNet.h +147 -0
  239. data/vendor/faiss/faiss/utils/WorkerThread.cpp +2 -2
  240. data/vendor/faiss/faiss/utils/WorkerThread.h +2 -2
  241. data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +2 -2
  242. data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +2 -2
  243. data/vendor/faiss/faiss/utils/approx_topk/generic.h +2 -2
  244. data/vendor/faiss/faiss/utils/approx_topk/mode.h +2 -2
  245. data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +2 -2
  246. data/vendor/faiss/faiss/utils/bf16.h +36 -0
  247. data/vendor/faiss/faiss/utils/distances.cpp +249 -90
  248. data/vendor/faiss/faiss/utils/distances.h +8 -8
  249. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +2 -2
  250. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
  251. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
  252. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +2 -2
  253. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +2 -2
  254. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +2 -2
  255. data/vendor/faiss/faiss/utils/distances_simd.cpp +1543 -56
  256. data/vendor/faiss/faiss/utils/extra_distances-inl.h +72 -2
  257. data/vendor/faiss/faiss/utils/extra_distances.cpp +87 -140
  258. data/vendor/faiss/faiss/utils/extra_distances.h +5 -4
  259. data/vendor/faiss/faiss/utils/fp16-arm.h +2 -2
  260. data/vendor/faiss/faiss/utils/fp16-fp16c.h +2 -2
  261. data/vendor/faiss/faiss/utils/fp16-inl.h +2 -2
  262. data/vendor/faiss/faiss/utils/fp16.h +2 -2
  263. data/vendor/faiss/faiss/utils/hamming-inl.h +2 -2
  264. data/vendor/faiss/faiss/utils/hamming.cpp +3 -4
  265. data/vendor/faiss/faiss/utils/hamming.h +2 -2
  266. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +2 -2
  267. data/vendor/faiss/faiss/utils/hamming_distance/avx512-inl.h +490 -0
  268. data/vendor/faiss/faiss/utils/hamming_distance/common.h +2 -2
  269. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +6 -3
  270. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +7 -3
  271. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +5 -5
  272. data/vendor/faiss/faiss/utils/ordered_key_value.h +2 -2
  273. data/vendor/faiss/faiss/utils/partitioning.cpp +2 -2
  274. data/vendor/faiss/faiss/utils/partitioning.h +2 -2
  275. data/vendor/faiss/faiss/utils/prefetch.h +2 -2
  276. data/vendor/faiss/faiss/utils/quantize_lut.cpp +2 -2
  277. data/vendor/faiss/faiss/utils/quantize_lut.h +2 -2
  278. data/vendor/faiss/faiss/utils/random.cpp +45 -2
  279. data/vendor/faiss/faiss/utils/random.h +27 -2
  280. data/vendor/faiss/faiss/utils/simdlib.h +12 -3
  281. data/vendor/faiss/faiss/utils/simdlib_avx2.h +2 -2
  282. data/vendor/faiss/faiss/utils/simdlib_avx512.h +296 -0
  283. data/vendor/faiss/faiss/utils/simdlib_emulated.h +2 -2
  284. data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -4
  285. data/vendor/faiss/faiss/utils/simdlib_ppc64.h +1084 -0
  286. data/vendor/faiss/faiss/utils/sorting.cpp +2 -2
  287. data/vendor/faiss/faiss/utils/sorting.h +2 -2
  288. data/vendor/faiss/faiss/utils/transpose/transpose-avx2-inl.h +2 -2
  289. data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +176 -0
  290. data/vendor/faiss/faiss/utils/utils.cpp +17 -10
  291. data/vendor/faiss/faiss/utils/utils.h +7 -3
  292. metadata +22 -11
  293. data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +0 -102
@@ -1,5 +1,5 @@
1
- /**
2
- * Copyright (c) Facebook, Inc. and its affiliates.
1
+ /*
2
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3
3
  *
4
4
  * This source code is licensed under the MIT license found in the
5
5
  * LICENSE file in the root directory of this source tree.
@@ -7,6 +7,7 @@
7
7
 
8
8
  #include <faiss/impl/HNSW.h>
9
9
 
10
+ #include <cstddef>
10
11
  #include <string>
11
12
 
12
13
  #include <faiss/impl/AuxIndexStructures.h>
@@ -110,8 +111,8 @@ void HNSW::print_neighbor_stats(int level) const {
110
111
  level,
111
112
  nb_neighbors(level));
112
113
  size_t tot_neigh = 0, tot_common = 0, tot_reciprocal = 0, n_node = 0;
113
- #pragma omp parallel for reduction(+: tot_neigh) reduction(+: tot_common) \
114
- reduction(+: tot_reciprocal) reduction(+: n_node)
114
+ #pragma omp parallel for reduction(+ : tot_neigh) reduction(+ : tot_common) \
115
+ reduction(+ : tot_reciprocal) reduction(+ : n_node)
115
116
  for (int i = 0; i < levels.size(); i++) {
116
117
  if (levels[i] > level) {
117
118
  n_node++;
@@ -165,10 +166,10 @@ void HNSW::print_neighbor_stats(int level) const {
165
166
  }
166
167
 
167
168
  void HNSW::fill_with_random_links(size_t n) {
168
- int max_level = prepare_level_tab(n);
169
+ int max_level_2 = prepare_level_tab(n);
169
170
  RandomGenerator rng2(456);
170
171
 
171
- for (int level = max_level - 1; level >= 0; --level) {
172
+ for (int level = max_level_2 - 1; level >= 0; --level) {
172
173
  std::vector<int> elts;
173
174
  for (int i = 0; i < n; i++) {
174
175
  if (levels[i] > level) {
@@ -209,16 +210,16 @@ int HNSW::prepare_level_tab(size_t n, bool preset_levels) {
209
210
  }
210
211
  }
211
212
 
212
- int max_level = 0;
213
+ int max_level_2 = 0;
213
214
  for (int i = 0; i < n; i++) {
214
215
  int pt_level = levels[i + n0] - 1;
215
- if (pt_level > max_level)
216
- max_level = pt_level;
216
+ if (pt_level > max_level_2)
217
+ max_level_2 = pt_level;
217
218
  offsets.push_back(offsets.back() + cum_nb_neighbors(pt_level + 1));
218
- neighbors.resize(offsets.back(), -1);
219
219
  }
220
+ neighbors.resize(offsets.back(), -1);
220
221
 
221
- return max_level;
222
+ return max_level_2;
222
223
  }
223
224
 
224
225
  /** Enumerate vertices from nearest to farthest from query, keep a
@@ -229,7 +230,14 @@ void HNSW::shrink_neighbor_list(
229
230
  DistanceComputer& qdis,
230
231
  std::priority_queue<NodeDistFarther>& input,
231
232
  std::vector<NodeDistFarther>& output,
232
- int max_size) {
233
+ int max_size,
234
+ bool keep_max_size_level0) {
235
+ // This prevents number of neighbors at
236
+ // level 0 from being shrunk to less than 2 * M.
237
+ // This is essential in making sure
238
+ // `faiss::gpu::GpuIndexCagra::copyFrom(IndexHNSWCagra*)` is functional
239
+ std::vector<NodeDistFarther> outsiders;
240
+
233
241
  while (input.size() > 0) {
234
242
  NodeDistFarther v1 = input.top();
235
243
  input.pop();
@@ -250,8 +258,15 @@ void HNSW::shrink_neighbor_list(
250
258
  if (output.size() >= max_size) {
251
259
  return;
252
260
  }
261
+ } else if (keep_max_size_level0) {
262
+ outsiders.push_back(v1);
253
263
  }
254
264
  }
265
+ size_t idx = 0;
266
+ while (keep_max_size_level0 && (output.size() < max_size) &&
267
+ (idx < outsiders.size())) {
268
+ output.push_back(outsiders[idx++]);
269
+ }
255
270
  }
256
271
 
257
272
  namespace {
@@ -268,7 +283,8 @@ using NodeDistFarther = HNSW::NodeDistFarther;
268
283
  void shrink_neighbor_list(
269
284
  DistanceComputer& qdis,
270
285
  std::priority_queue<NodeDistCloser>& resultSet1,
271
- int max_size) {
286
+ int max_size,
287
+ bool keep_max_size_level0 = false) {
272
288
  if (resultSet1.size() < max_size) {
273
289
  return;
274
290
  }
@@ -280,7 +296,8 @@ void shrink_neighbor_list(
280
296
  resultSet1.pop();
281
297
  }
282
298
 
283
- HNSW::shrink_neighbor_list(qdis, resultSet, returnlist, max_size);
299
+ HNSW::shrink_neighbor_list(
300
+ qdis, resultSet, returnlist, max_size, keep_max_size_level0);
284
301
 
285
302
  for (NodeDistFarther curen2 : returnlist) {
286
303
  resultSet1.emplace(curen2.d, curen2.id);
@@ -294,7 +311,8 @@ void add_link(
294
311
  DistanceComputer& qdis,
295
312
  storage_idx_t src,
296
313
  storage_idx_t dest,
297
- int level) {
314
+ int level,
315
+ bool keep_max_size_level0 = false) {
298
316
  size_t begin, end;
299
317
  hnsw.neighbor_range(src, level, &begin, &end);
300
318
  if (hnsw.neighbors[end - 1] == -1) {
@@ -319,7 +337,7 @@ void add_link(
319
337
  resultSet.emplace(qdis.symmetric_dis(src, neigh), neigh);
320
338
  }
321
339
 
322
- shrink_neighbor_list(qdis, resultSet, end - begin);
340
+ shrink_neighbor_list(qdis, resultSet, end - begin, keep_max_size_level0);
323
341
 
324
342
  // ...and back
325
343
  size_t i = begin;
@@ -333,6 +351,8 @@ void add_link(
333
351
  }
334
352
  }
335
353
 
354
+ } // namespace
355
+
336
356
  /// search neighbors on a single level, starting from an entry point
337
357
  void search_neighbors_to_add(
338
358
  HNSW& hnsw,
@@ -341,7 +361,8 @@ void search_neighbors_to_add(
341
361
  int entry_point,
342
362
  float d_entry_point,
343
363
  int level,
344
- VisitedTable& vt) {
364
+ VisitedTable& vt,
365
+ bool reference_version) {
345
366
  // top is nearest candidate
346
367
  std::priority_queue<NodeDistFarther> candidates;
347
368
 
@@ -363,62 +384,98 @@ void search_neighbors_to_add(
363
384
  // loop over neighbors
364
385
  size_t begin, end;
365
386
  hnsw.neighbor_range(currNode, level, &begin, &end);
366
- for (size_t i = begin; i < end; i++) {
367
- storage_idx_t nodeId = hnsw.neighbors[i];
368
- if (nodeId < 0)
369
- break;
370
- if (vt.get(nodeId))
371
- continue;
372
- vt.set(nodeId);
373
387
 
374
- float dis = qdis(nodeId);
375
- NodeDistFarther evE1(dis, nodeId);
376
-
377
- if (results.size() < hnsw.efConstruction || results.top().d > dis) {
378
- results.emplace(dis, nodeId);
379
- candidates.emplace(dis, nodeId);
380
- if (results.size() > hnsw.efConstruction) {
381
- results.pop();
388
+ // The reference version is not used, but kept here because:
389
+ // 1. It is easier to switch back if the optimized version has a problem
390
+ // 2. It serves as a starting point for new optimizations
391
+ // 3. It helps understand the code
392
+ // 4. It ensures the reference version is still compilable if the
393
+ // optimized version changes
394
+ // The reference and the optimized versions' results are compared in
395
+ // test_hnsw.cpp
396
+ if (reference_version) {
397
+ // a reference version
398
+ for (size_t i = begin; i < end; i++) {
399
+ storage_idx_t nodeId = hnsw.neighbors[i];
400
+ if (nodeId < 0)
401
+ break;
402
+ if (vt.get(nodeId))
403
+ continue;
404
+ vt.set(nodeId);
405
+
406
+ float dis = qdis(nodeId);
407
+ NodeDistFarther evE1(dis, nodeId);
408
+
409
+ if (results.size() < hnsw.efConstruction ||
410
+ results.top().d > dis) {
411
+ results.emplace(dis, nodeId);
412
+ candidates.emplace(dis, nodeId);
413
+ if (results.size() > hnsw.efConstruction) {
414
+ results.pop();
415
+ }
382
416
  }
383
417
  }
384
- }
385
- }
386
- vt.advance();
387
- }
418
+ } else {
419
+ // a faster version
420
+
421
+ // the following version processes 4 neighbors at a time
422
+ auto update_with_candidate = [&](const storage_idx_t idx,
423
+ const float dis) {
424
+ if (results.size() < hnsw.efConstruction ||
425
+ results.top().d > dis) {
426
+ results.emplace(dis, idx);
427
+ candidates.emplace(dis, idx);
428
+ if (results.size() > hnsw.efConstruction) {
429
+ results.pop();
430
+ }
431
+ }
432
+ };
388
433
 
389
- /**************************************************************
390
- * Searching subroutines
391
- **************************************************************/
434
+ int n_buffered = 0;
435
+ storage_idx_t buffered_ids[4];
392
436
 
393
- /// greedily update a nearest vector at a given level
394
- void greedy_update_nearest(
395
- const HNSW& hnsw,
396
- DistanceComputer& qdis,
397
- int level,
398
- storage_idx_t& nearest,
399
- float& d_nearest) {
400
- for (;;) {
401
- storage_idx_t prev_nearest = nearest;
437
+ for (size_t j = begin; j < end; j++) {
438
+ storage_idx_t nodeId = hnsw.neighbors[j];
439
+ if (nodeId < 0)
440
+ break;
441
+ if (vt.get(nodeId)) {
442
+ continue;
443
+ }
444
+ vt.set(nodeId);
445
+
446
+ buffered_ids[n_buffered] = nodeId;
447
+ n_buffered += 1;
448
+
449
+ if (n_buffered == 4) {
450
+ float dis[4];
451
+ qdis.distances_batch_4(
452
+ buffered_ids[0],
453
+ buffered_ids[1],
454
+ buffered_ids[2],
455
+ buffered_ids[3],
456
+ dis[0],
457
+ dis[1],
458
+ dis[2],
459
+ dis[3]);
460
+
461
+ for (size_t id4 = 0; id4 < 4; id4++) {
462
+ update_with_candidate(buffered_ids[id4], dis[id4]);
463
+ }
402
464
 
403
- size_t begin, end;
404
- hnsw.neighbor_range(nearest, level, &begin, &end);
405
- for (size_t i = begin; i < end; i++) {
406
- storage_idx_t v = hnsw.neighbors[i];
407
- if (v < 0)
408
- break;
409
- float dis = qdis(v);
410
- if (dis < d_nearest) {
411
- nearest = v;
412
- d_nearest = dis;
465
+ n_buffered = 0;
466
+ }
467
+ }
468
+
469
+ // process leftovers
470
+ for (size_t icnt = 0; icnt < n_buffered; icnt++) {
471
+ float dis = qdis(buffered_ids[icnt]);
472
+ update_with_candidate(buffered_ids[icnt], dis);
413
473
  }
414
- }
415
- if (nearest == prev_nearest) {
416
- return;
417
474
  }
418
475
  }
419
- }
420
476
 
421
- } // namespace
477
+ vt.advance();
478
+ }
422
479
 
423
480
  /// Finds neighbors and builds links with them, starting from an entry
424
481
  /// point. The own neighbor list is assumed to be locked.
@@ -429,7 +486,8 @@ void HNSW::add_links_starting_from(
429
486
  float d_nearest,
430
487
  int level,
431
488
  omp_lock_t* locks,
432
- VisitedTable& vt) {
489
+ VisitedTable& vt,
490
+ bool keep_max_size_level0) {
433
491
  std::priority_queue<NodeDistCloser> link_targets;
434
492
 
435
493
  search_neighbors_to_add(
@@ -438,21 +496,21 @@ void HNSW::add_links_starting_from(
438
496
  // but we can afford only this many neighbors
439
497
  int M = nb_neighbors(level);
440
498
 
441
- ::faiss::shrink_neighbor_list(ptdis, link_targets, M);
499
+ ::faiss::shrink_neighbor_list(ptdis, link_targets, M, keep_max_size_level0);
442
500
 
443
- std::vector<storage_idx_t> neighbors;
444
- neighbors.reserve(link_targets.size());
501
+ std::vector<storage_idx_t> neighbors_to_add;
502
+ neighbors_to_add.reserve(link_targets.size());
445
503
  while (!link_targets.empty()) {
446
504
  storage_idx_t other_id = link_targets.top().id;
447
- add_link(*this, ptdis, pt_id, other_id, level);
448
- neighbors.push_back(other_id);
505
+ add_link(*this, ptdis, pt_id, other_id, level, keep_max_size_level0);
506
+ neighbors_to_add.push_back(other_id);
449
507
  link_targets.pop();
450
508
  }
451
509
 
452
510
  omp_unset_lock(&locks[pt_id]);
453
- for (storage_idx_t other_id : neighbors) {
511
+ for (storage_idx_t other_id : neighbors_to_add) {
454
512
  omp_set_lock(&locks[other_id]);
455
- add_link(*this, ptdis, other_id, pt_id, level);
513
+ add_link(*this, ptdis, other_id, pt_id, level, keep_max_size_level0);
456
514
  omp_unset_lock(&locks[other_id]);
457
515
  }
458
516
  omp_set_lock(&locks[pt_id]);
@@ -467,7 +525,8 @@ void HNSW::add_with_locks(
467
525
  int pt_level,
468
526
  int pt_id,
469
527
  std::vector<omp_lock_t>& locks,
470
- VisitedTable& vt) {
528
+ VisitedTable& vt,
529
+ bool keep_max_size_level0) {
471
530
  // greedy search on upper levels
472
531
 
473
532
  storage_idx_t nearest;
@@ -496,7 +555,14 @@ void HNSW::add_with_locks(
496
555
 
497
556
  for (; level >= 0; level--) {
498
557
  add_links_starting_from(
499
- ptdis, pt_id, nearest, d_nearest, level, locks.data(), vt);
558
+ ptdis,
559
+ pt_id,
560
+ nearest,
561
+ d_nearest,
562
+ level,
563
+ locks.data(),
564
+ vt,
565
+ keep_max_size_level0);
500
566
  }
501
567
 
502
568
  omp_unset_lock(&locks[pt_id]);
@@ -511,12 +577,10 @@ void HNSW::add_with_locks(
511
577
  * Searching
512
578
  **************************************************************/
513
579
 
514
- namespace {
515
580
  using MinimaxHeap = HNSW::MinimaxHeap;
516
581
  using Node = HNSW::Node;
517
582
  using C = HNSW::C;
518
583
  /** Do a BFS on the candidates list */
519
-
520
584
  int search_from_candidates(
521
585
  const HNSW& hnsw,
522
586
  DistanceComputer& qdis,
@@ -525,8 +589,8 @@ int search_from_candidates(
525
589
  VisitedTable& vt,
526
590
  HNSWStats& stats,
527
591
  int level,
528
- int nres_in = 0,
529
- const SearchParametersHNSW* params = nullptr) {
592
+ int nres_in,
593
+ const SearchParametersHNSW* params) {
530
594
  int nres = nres_in;
531
595
  int ndis = 0;
532
596
 
@@ -571,27 +635,7 @@ int search_from_candidates(
571
635
  size_t begin, end;
572
636
  hnsw.neighbor_range(v0, level, &begin, &end);
573
637
 
574
- // // baseline version
575
- // for (size_t j = begin; j < end; j++) {
576
- // int v1 = hnsw.neighbors[j];
577
- // if (v1 < 0)
578
- // break;
579
- // if (vt.get(v1)) {
580
- // continue;
581
- // }
582
- // vt.set(v1);
583
- // ndis++;
584
- // float d = qdis(v1);
585
- // if (!sel || sel->is_member(v1)) {
586
- // if (nres < k) {
587
- // faiss::maxheap_push(++nres, D, I, d, v1);
588
- // } else if (d < D[0]) {
589
- // faiss::maxheap_replace_top(nres, D, I, d, v1);
590
- // }
591
- // }
592
- // candidates.push(v1, d);
593
- // }
594
-
638
+ // a faster version: reference version in unit test test_hnsw.cpp
595
639
  // the following version processes 4 neighbors at a time
596
640
  size_t jmax = begin;
597
641
  for (size_t j = begin; j < end; j++) {
@@ -606,7 +650,6 @@ int search_from_candidates(
606
650
  int counter = 0;
607
651
  size_t saved_j[4];
608
652
 
609
- ndis += jmax - begin;
610
653
  threshold = res.threshold;
611
654
 
612
655
  auto add_to_heap = [&](const size_t idx, const float dis) {
@@ -614,6 +657,7 @@ int search_from_candidates(
614
657
  if (dis < threshold) {
615
658
  if (res.add_result(dis, idx)) {
616
659
  threshold = res.threshold;
660
+ nres += 1;
617
661
  }
618
662
  }
619
663
  }
@@ -644,6 +688,8 @@ int search_from_candidates(
644
688
  add_to_heap(saved_j[id4], dis[id4]);
645
689
  }
646
690
 
691
+ ndis += 4;
692
+
647
693
  counter = 0;
648
694
  }
649
695
  }
@@ -651,6 +697,8 @@ int search_from_candidates(
651
697
  for (size_t icnt = 0; icnt < counter; icnt++) {
652
698
  float dis = qdis(saved_j[icnt]);
653
699
  add_to_heap(saved_j[icnt], dis);
700
+
701
+ ndis += 1;
654
702
  }
655
703
 
656
704
  nstep++;
@@ -664,7 +712,8 @@ int search_from_candidates(
664
712
  if (candidates.size() == 0) {
665
713
  stats.n2++;
666
714
  }
667
- stats.n3 += ndis;
715
+ stats.ndis += ndis;
716
+ stats.nhops += nstep;
668
717
  }
669
718
 
670
719
  return nres;
@@ -700,33 +749,7 @@ std::priority_queue<HNSW::Node> search_from_candidate_unbounded(
700
749
  size_t begin, end;
701
750
  hnsw.neighbor_range(v0, 0, &begin, &end);
702
751
 
703
- // // baseline version
704
- // for (size_t j = begin; j < end; ++j) {
705
- // int v1 = hnsw.neighbors[j];
706
- //
707
- // if (v1 < 0) {
708
- // break;
709
- // }
710
- // if (vt->get(v1)) {
711
- // continue;
712
- // }
713
- //
714
- // vt->set(v1);
715
- //
716
- // float d1 = qdis(v1);
717
- // ++ndis;
718
- //
719
- // if (top_candidates.top().first > d1 ||
720
- // top_candidates.size() < ef) {
721
- // candidates.emplace(d1, v1);
722
- // top_candidates.emplace(d1, v1);
723
- //
724
- // if (top_candidates.size() > ef) {
725
- // top_candidates.pop();
726
- // }
727
- // }
728
- // }
729
-
752
+ // a faster version: reference version in unit test test_hnsw.cpp
730
753
  // the following version processes 4 neighbors at a time
731
754
  size_t jmax = begin;
732
755
  for (size_t j = begin; j < end; j++) {
@@ -741,8 +764,6 @@ std::priority_queue<HNSW::Node> search_from_candidate_unbounded(
741
764
  int counter = 0;
742
765
  size_t saved_j[4];
743
766
 
744
- ndis += jmax - begin;
745
-
746
767
  auto add_to_heap = [&](const size_t idx, const float dis) {
747
768
  if (top_candidates.top().first > dis ||
748
769
  top_candidates.size() < ef) {
@@ -779,6 +800,8 @@ std::priority_queue<HNSW::Node> search_from_candidate_unbounded(
779
800
  add_to_heap(saved_j[id4], dis[id4]);
780
801
  }
781
802
 
803
+ ndis += 4;
804
+
782
805
  counter = 0;
783
806
  }
784
807
  }
@@ -786,18 +809,102 @@ std::priority_queue<HNSW::Node> search_from_candidate_unbounded(
786
809
  for (size_t icnt = 0; icnt < counter; icnt++) {
787
810
  float dis = qdis(saved_j[icnt]);
788
811
  add_to_heap(saved_j[icnt], dis);
812
+
813
+ ndis += 1;
789
814
  }
815
+
816
+ stats.nhops += 1;
790
817
  }
791
818
 
792
819
  ++stats.n1;
793
820
  if (candidates.size() == 0) {
794
821
  ++stats.n2;
795
822
  }
796
- stats.n3 += ndis;
823
+ stats.ndis += ndis;
797
824
 
798
825
  return top_candidates;
799
826
  }
800
827
 
828
+ /// greedily update a nearest vector at a given level
829
+ HNSWStats greedy_update_nearest(
830
+ const HNSW& hnsw,
831
+ DistanceComputer& qdis,
832
+ int level,
833
+ storage_idx_t& nearest,
834
+ float& d_nearest) {
835
+ HNSWStats stats;
836
+
837
+ for (;;) {
838
+ storage_idx_t prev_nearest = nearest;
839
+
840
+ size_t begin, end;
841
+ hnsw.neighbor_range(nearest, level, &begin, &end);
842
+
843
+ size_t ndis = 0;
844
+
845
+ // a faster version: reference version in unit test test_hnsw.cpp
846
+ // the following version processes 4 neighbors at a time
847
+ auto update_with_candidate = [&](const storage_idx_t idx,
848
+ const float dis) {
849
+ if (dis < d_nearest) {
850
+ nearest = idx;
851
+ d_nearest = dis;
852
+ }
853
+ };
854
+
855
+ int n_buffered = 0;
856
+ storage_idx_t buffered_ids[4];
857
+
858
+ for (size_t j = begin; j < end; j++) {
859
+ storage_idx_t v = hnsw.neighbors[j];
860
+ if (v < 0)
861
+ break;
862
+ ndis += 1;
863
+
864
+ buffered_ids[n_buffered] = v;
865
+ n_buffered += 1;
866
+
867
+ if (n_buffered == 4) {
868
+ float dis[4];
869
+ qdis.distances_batch_4(
870
+ buffered_ids[0],
871
+ buffered_ids[1],
872
+ buffered_ids[2],
873
+ buffered_ids[3],
874
+ dis[0],
875
+ dis[1],
876
+ dis[2],
877
+ dis[3]);
878
+
879
+ for (size_t id4 = 0; id4 < 4; id4++) {
880
+ update_with_candidate(buffered_ids[id4], dis[id4]);
881
+ }
882
+
883
+ n_buffered = 0;
884
+ }
885
+ }
886
+
887
+ // process leftovers
888
+ for (size_t icnt = 0; icnt < n_buffered; icnt++) {
889
+ float dis = qdis(buffered_ids[icnt]);
890
+ update_with_candidate(buffered_ids[icnt], dis);
891
+ }
892
+
893
+ // update stats
894
+ stats.ndis += ndis;
895
+ stats.nhops += 1;
896
+
897
+ if (nearest == prev_nearest) {
898
+ return stats;
899
+ }
900
+ }
901
+ }
902
+
903
+ namespace {
904
+ using MinimaxHeap = HNSW::MinimaxHeap;
905
+ using Node = HNSW::Node;
906
+ using C = HNSW::C;
907
+
801
908
  // just used as a lower bound for the minmaxheap, but it is set for heap search
802
909
  int extract_k_from_ResultHandler(ResultHandler<C>& res) {
803
910
  using RH = HeapBlockResultHandler<C>;
@@ -807,7 +914,7 @@ int extract_k_from_ResultHandler(ResultHandler<C>& res) {
807
914
  return 1;
808
915
  }
809
916
 
810
- } // anonymous namespace
917
+ } // namespace
811
918
 
812
919
  HNSWStats HNSW::search(
813
920
  DistanceComputer& qdis,
@@ -820,85 +927,47 @@ HNSWStats HNSW::search(
820
927
  }
821
928
  int k = extract_k_from_ResultHandler(res);
822
929
 
823
- if (upper_beam == 1) {
824
- // greedy search on upper levels
825
- storage_idx_t nearest = entry_point;
826
- float d_nearest = qdis(nearest);
827
-
828
- for (int level = max_level; level >= 1; level--) {
829
- greedy_update_nearest(*this, qdis, level, nearest, d_nearest);
830
- }
831
-
832
- int ef = std::max(params ? params->efSearch : efSearch, k);
833
- if (search_bounded_queue) { // this is the most common branch
834
- MinimaxHeap candidates(ef);
930
+ bool bounded_queue =
931
+ params ? params->bounded_queue : this->search_bounded_queue;
835
932
 
836
- candidates.push(nearest, d_nearest);
933
+ // greedy search on upper levels
934
+ storage_idx_t nearest = entry_point;
935
+ float d_nearest = qdis(nearest);
837
936
 
838
- search_from_candidates(
839
- *this, qdis, res, candidates, vt, stats, 0, 0, params);
840
- } else {
841
- std::priority_queue<Node> top_candidates =
842
- search_from_candidate_unbounded(
843
- *this,
844
- Node(d_nearest, nearest),
845
- qdis,
846
- ef,
847
- &vt,
848
- stats);
849
-
850
- while (top_candidates.size() > k) {
851
- top_candidates.pop();
852
- }
937
+ for (int level = max_level; level >= 1; level--) {
938
+ HNSWStats local_stats =
939
+ greedy_update_nearest(*this, qdis, level, nearest, d_nearest);
940
+ stats.combine(local_stats);
941
+ }
853
942
 
854
- while (!top_candidates.empty()) {
855
- float d;
856
- storage_idx_t label;
857
- std::tie(d, label) = top_candidates.top();
858
- res.add_result(d, label);
859
- top_candidates.pop();
860
- }
861
- }
943
+ int ef = std::max(params ? params->efSearch : efSearch, k);
944
+ if (bounded_queue) { // this is the most common branch
945
+ MinimaxHeap candidates(ef);
862
946
 
863
- vt.advance();
947
+ candidates.push(nearest, d_nearest);
864
948
 
949
+ search_from_candidates(
950
+ *this, qdis, res, candidates, vt, stats, 0, 0, params);
865
951
  } else {
866
- int candidates_size = upper_beam;
867
- MinimaxHeap candidates(candidates_size);
868
-
869
- std::vector<idx_t> I_to_next(candidates_size);
870
- std::vector<float> D_to_next(candidates_size);
871
-
872
- HeapBlockResultHandler<C> block_resh(
873
- 1, D_to_next.data(), I_to_next.data(), candidates_size);
874
- HeapBlockResultHandler<C>::SingleResultHandler resh(block_resh);
875
-
876
- int nres = 1;
877
- I_to_next[0] = entry_point;
878
- D_to_next[0] = qdis(entry_point);
879
-
880
- for (int level = max_level; level >= 0; level--) {
881
- // copy I, D -> candidates
952
+ std::priority_queue<Node> top_candidates =
953
+ search_from_candidate_unbounded(
954
+ *this, Node(d_nearest, nearest), qdis, ef, &vt, stats);
882
955
 
883
- candidates.clear();
884
-
885
- for (int i = 0; i < nres; i++) {
886
- candidates.push(I_to_next[i], D_to_next[i]);
887
- }
956
+ while (top_candidates.size() > k) {
957
+ top_candidates.pop();
958
+ }
888
959
 
889
- if (level == 0) {
890
- nres = search_from_candidates(
891
- *this, qdis, res, candidates, vt, stats, 0);
892
- } else {
893
- resh.begin(0);
894
- nres = search_from_candidates(
895
- *this, qdis, resh, candidates, vt, stats, level);
896
- resh.end();
897
- }
898
- vt.advance();
960
+ while (!top_candidates.empty()) {
961
+ float d;
962
+ storage_idx_t label;
963
+ std::tie(d, label) = top_candidates.top();
964
+ res.add_result(d, label);
965
+ top_candidates.pop();
899
966
  }
900
967
  }
901
968
 
969
+ vt.advance();
970
+
902
971
  return stats;
903
972
  }
904
973
 
@@ -910,9 +979,12 @@ void HNSW::search_level_0(
910
979
  const float* nearest_d,
911
980
  int search_type,
912
981
  HNSWStats& search_stats,
913
- VisitedTable& vt) const {
982
+ VisitedTable& vt,
983
+ const SearchParametersHNSW* params) const {
914
984
  const HNSW& hnsw = *this;
985
+ auto efSearch = params ? params->efSearch : hnsw.efSearch;
915
986
  int k = extract_k_from_ResultHandler(res);
987
+
916
988
  if (search_type == 1) {
917
989
  int nres = 0;
918
990
 
@@ -925,16 +997,25 @@ void HNSW::search_level_0(
925
997
  if (vt.get(cj))
926
998
  continue;
927
999
 
928
- int candidates_size = std::max(hnsw.efSearch, k);
1000
+ int candidates_size = std::max(efSearch, k);
929
1001
  MinimaxHeap candidates(candidates_size);
930
1002
 
931
1003
  candidates.push(cj, nearest_d[j]);
932
1004
 
933
1005
  nres = search_from_candidates(
934
- hnsw, qdis, res, candidates, vt, search_stats, 0, nres);
1006
+ hnsw,
1007
+ qdis,
1008
+ res,
1009
+ candidates,
1010
+ vt,
1011
+ search_stats,
1012
+ 0,
1013
+ nres,
1014
+ params);
1015
+ nres = std::min(nres, candidates_size);
935
1016
  }
936
1017
  } else if (search_type == 2) {
937
- int candidates_size = std::max(hnsw.efSearch, int(k));
1018
+ int candidates_size = std::max(efSearch, int(k));
938
1019
  candidates_size = std::max(candidates_size, int(nprobe));
939
1020
 
940
1021
  MinimaxHeap candidates(candidates_size);
@@ -947,7 +1028,7 @@ void HNSW::search_level_0(
947
1028
  }
948
1029
 
949
1030
  search_from_candidates(
950
- hnsw, qdis, res, candidates, vt, search_stats, 0);
1031
+ hnsw, qdis, res, candidates, vt, search_stats, 0, 0, params);
951
1032
  }
952
1033
  }
953
1034
 
@@ -1013,7 +1094,99 @@ void HNSW::MinimaxHeap::clear() {
1013
1094
  nvalid = k = 0;
1014
1095
  }
1015
1096
 
1016
- #ifdef __AVX2__
1097
+ #ifdef __AVX512F__
1098
+
1099
+ int HNSW::MinimaxHeap::pop_min(float* vmin_out) {
1100
+ assert(k > 0);
1101
+ static_assert(
1102
+ std::is_same<storage_idx_t, int32_t>::value,
1103
+ "This code expects storage_idx_t to be int32_t");
1104
+
1105
+ int32_t min_idx = -1;
1106
+ float min_dis = std::numeric_limits<float>::infinity();
1107
+
1108
+ __m512i min_indices = _mm512_set1_epi32(-1);
1109
+ __m512 min_distances =
1110
+ _mm512_set1_ps(std::numeric_limits<float>::infinity());
1111
+ __m512i current_indices = _mm512_setr_epi32(
1112
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
1113
+ __m512i offset = _mm512_set1_epi32(16);
1114
+
1115
+ // The following loop tracks the rightmost index with the min distance.
1116
+ // -1 index values are ignored.
1117
+ const int k16 = (k / 16) * 16;
1118
+ for (size_t iii = 0; iii < k16; iii += 16) {
1119
+ __m512i indices =
1120
+ _mm512_loadu_si512((const __m512i*)(ids.data() + iii));
1121
+ __m512 distances = _mm512_loadu_ps(dis.data() + iii);
1122
+
1123
+ // This mask filters out -1 values among indices.
1124
+ __mmask16 m1mask =
1125
+ _mm512_cmpgt_epi32_mask(_mm512_setzero_si512(), indices);
1126
+
1127
+ __mmask16 dmask =
1128
+ _mm512_cmp_ps_mask(min_distances, distances, _CMP_LT_OS);
1129
+ __mmask16 finalmask = m1mask | dmask;
1130
+
1131
+ const __m512i min_indices_new = _mm512_mask_blend_epi32(
1132
+ finalmask, current_indices, min_indices);
1133
+ const __m512 min_distances_new =
1134
+ _mm512_mask_blend_ps(finalmask, distances, min_distances);
1135
+
1136
+ min_indices = min_indices_new;
1137
+ min_distances = min_distances_new;
1138
+
1139
+ current_indices = _mm512_add_epi32(current_indices, offset);
1140
+ }
1141
+
1142
+ // leftovers
1143
+ if (k16 != k) {
1144
+ const __mmask16 kmask = (1 << (k - k16)) - 1;
1145
+
1146
+ __m512i indices = _mm512_mask_loadu_epi32(
1147
+ _mm512_set1_epi32(-1), kmask, ids.data() + k16);
1148
+ __m512 distances = _mm512_maskz_loadu_ps(kmask, dis.data() + k16);
1149
+
1150
+ // This mask filters out -1 values among indices.
1151
+ __mmask16 m1mask =
1152
+ _mm512_cmpgt_epi32_mask(_mm512_setzero_si512(), indices);
1153
+
1154
+ __mmask16 dmask =
1155
+ _mm512_cmp_ps_mask(min_distances, distances, _CMP_LT_OS);
1156
+ __mmask16 finalmask = m1mask | dmask;
1157
+
1158
+ const __m512i min_indices_new = _mm512_mask_blend_epi32(
1159
+ finalmask, current_indices, min_indices);
1160
+ const __m512 min_distances_new =
1161
+ _mm512_mask_blend_ps(finalmask, distances, min_distances);
1162
+
1163
+ min_indices = min_indices_new;
1164
+ min_distances = min_distances_new;
1165
+ }
1166
+
1167
+ // grab min distance
1168
+ min_dis = _mm512_reduce_min_ps(min_distances);
1169
+ // blend
1170
+ __mmask16 mindmask =
1171
+ _mm512_cmpeq_ps_mask(min_distances, _mm512_set1_ps(min_dis));
1172
+ // pick the max one
1173
+ min_idx = _mm512_mask_reduce_max_epi32(mindmask, min_indices);
1174
+
1175
+ if (min_idx == -1) {
1176
+ return -1;
1177
+ }
1178
+
1179
+ if (vmin_out) {
1180
+ *vmin_out = min_dis;
1181
+ }
1182
+ int ret = ids[min_idx];
1183
+ ids[min_idx] = -1;
1184
+ --nvalid;
1185
+ return ret;
1186
+ }
1187
+
1188
+ #elif __AVX2__
1189
+
1017
1190
  int HNSW::MinimaxHeap::pop_min(float* vmin_out) {
1018
1191
  assert(k > 0);
1019
1192
  static_assert(