faiss 0.3.1 → 0.3.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (293) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +8 -0
  3. data/LICENSE.txt +1 -1
  4. data/lib/faiss/version.rb +1 -1
  5. data/vendor/faiss/faiss/AutoTune.cpp +2 -2
  6. data/vendor/faiss/faiss/AutoTune.h +3 -3
  7. data/vendor/faiss/faiss/Clustering.cpp +37 -6
  8. data/vendor/faiss/faiss/Clustering.h +12 -3
  9. data/vendor/faiss/faiss/IVFlib.cpp +6 -3
  10. data/vendor/faiss/faiss/IVFlib.h +2 -2
  11. data/vendor/faiss/faiss/Index.cpp +6 -2
  12. data/vendor/faiss/faiss/Index.h +30 -8
  13. data/vendor/faiss/faiss/Index2Layer.cpp +2 -2
  14. data/vendor/faiss/faiss/Index2Layer.h +2 -2
  15. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +7 -7
  16. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +2 -2
  17. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +14 -16
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +2 -2
  19. data/vendor/faiss/faiss/IndexBinary.cpp +13 -2
  20. data/vendor/faiss/faiss/IndexBinary.h +8 -2
  21. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +2 -3
  22. data/vendor/faiss/faiss/IndexBinaryFlat.h +2 -2
  23. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -2
  24. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -2
  25. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +2 -7
  26. data/vendor/faiss/faiss/IndexBinaryHNSW.h +3 -3
  27. data/vendor/faiss/faiss/IndexBinaryHash.cpp +2 -3
  28. data/vendor/faiss/faiss/IndexBinaryHash.h +2 -2
  29. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +3 -3
  30. data/vendor/faiss/faiss/IndexBinaryIVF.h +2 -2
  31. data/vendor/faiss/faiss/IndexFastScan.cpp +32 -18
  32. data/vendor/faiss/faiss/IndexFastScan.h +11 -2
  33. data/vendor/faiss/faiss/IndexFlat.cpp +13 -10
  34. data/vendor/faiss/faiss/IndexFlat.h +2 -2
  35. data/vendor/faiss/faiss/IndexFlatCodes.cpp +170 -7
  36. data/vendor/faiss/faiss/IndexFlatCodes.h +25 -5
  37. data/vendor/faiss/faiss/IndexHNSW.cpp +156 -96
  38. data/vendor/faiss/faiss/IndexHNSW.h +54 -5
  39. data/vendor/faiss/faiss/IndexIDMap.cpp +19 -3
  40. data/vendor/faiss/faiss/IndexIDMap.h +5 -2
  41. data/vendor/faiss/faiss/IndexIVF.cpp +5 -6
  42. data/vendor/faiss/faiss/IndexIVF.h +13 -4
  43. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +21 -7
  44. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +5 -2
  45. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +3 -14
  46. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +2 -4
  47. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +201 -91
  48. data/vendor/faiss/faiss/IndexIVFFastScan.h +33 -9
  49. data/vendor/faiss/faiss/IndexIVFFlat.cpp +2 -2
  50. data/vendor/faiss/faiss/IndexIVFFlat.h +2 -2
  51. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +2 -2
  52. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +2 -2
  53. data/vendor/faiss/faiss/IndexIVFPQ.cpp +3 -6
  54. data/vendor/faiss/faiss/IndexIVFPQ.h +2 -2
  55. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +7 -14
  56. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +2 -4
  57. data/vendor/faiss/faiss/IndexIVFPQR.cpp +2 -2
  58. data/vendor/faiss/faiss/IndexIVFPQR.h +2 -2
  59. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +2 -3
  60. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +2 -2
  61. data/vendor/faiss/faiss/IndexLSH.cpp +2 -3
  62. data/vendor/faiss/faiss/IndexLSH.h +2 -2
  63. data/vendor/faiss/faiss/IndexLattice.cpp +3 -21
  64. data/vendor/faiss/faiss/IndexLattice.h +5 -24
  65. data/vendor/faiss/faiss/IndexNNDescent.cpp +2 -31
  66. data/vendor/faiss/faiss/IndexNNDescent.h +3 -3
  67. data/vendor/faiss/faiss/IndexNSG.cpp +2 -5
  68. data/vendor/faiss/faiss/IndexNSG.h +3 -3
  69. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +56 -0
  70. data/vendor/faiss/faiss/IndexNeuralNetCodec.h +49 -0
  71. data/vendor/faiss/faiss/IndexPQ.cpp +26 -26
  72. data/vendor/faiss/faiss/IndexPQ.h +2 -2
  73. data/vendor/faiss/faiss/IndexPQFastScan.cpp +2 -5
  74. data/vendor/faiss/faiss/IndexPQFastScan.h +2 -11
  75. data/vendor/faiss/faiss/IndexPreTransform.cpp +2 -2
  76. data/vendor/faiss/faiss/IndexPreTransform.h +3 -3
  77. data/vendor/faiss/faiss/IndexRefine.cpp +46 -9
  78. data/vendor/faiss/faiss/IndexRefine.h +9 -2
  79. data/vendor/faiss/faiss/IndexReplicas.cpp +2 -2
  80. data/vendor/faiss/faiss/IndexReplicas.h +2 -2
  81. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +2 -2
  82. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +2 -2
  83. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +5 -4
  84. data/vendor/faiss/faiss/IndexScalarQuantizer.h +2 -2
  85. data/vendor/faiss/faiss/IndexShards.cpp +2 -2
  86. data/vendor/faiss/faiss/IndexShards.h +2 -2
  87. data/vendor/faiss/faiss/IndexShardsIVF.cpp +2 -2
  88. data/vendor/faiss/faiss/IndexShardsIVF.h +2 -2
  89. data/vendor/faiss/faiss/MatrixStats.cpp +2 -2
  90. data/vendor/faiss/faiss/MatrixStats.h +2 -2
  91. data/vendor/faiss/faiss/MetaIndexes.cpp +2 -3
  92. data/vendor/faiss/faiss/MetaIndexes.h +2 -2
  93. data/vendor/faiss/faiss/MetricType.h +9 -4
  94. data/vendor/faiss/faiss/VectorTransform.cpp +2 -2
  95. data/vendor/faiss/faiss/VectorTransform.h +2 -2
  96. data/vendor/faiss/faiss/clone_index.cpp +2 -2
  97. data/vendor/faiss/faiss/clone_index.h +2 -2
  98. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +2 -2
  99. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +2 -2
  100. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +97 -19
  101. data/vendor/faiss/faiss/cppcontrib/factory_tools.cpp +192 -0
  102. data/vendor/faiss/faiss/cppcontrib/factory_tools.h +29 -0
  103. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2 -2
  104. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +85 -32
  105. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2 -2
  106. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +2 -2
  107. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +2 -2
  108. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +2 -2
  109. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +2 -2
  110. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +2 -2
  111. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +2 -5
  112. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +2 -2
  113. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +45 -13
  114. data/vendor/faiss/faiss/gpu/GpuCloner.h +2 -2
  115. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +12 -6
  116. data/vendor/faiss/faiss/gpu/GpuDistance.h +11 -7
  117. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +3 -3
  118. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +2 -2
  119. data/vendor/faiss/faiss/gpu/GpuIndex.h +10 -15
  120. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -2
  121. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +285 -0
  122. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +2 -2
  123. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +8 -2
  124. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +4 -2
  125. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +3 -3
  126. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +2 -2
  127. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +2 -2
  128. data/vendor/faiss/faiss/gpu/GpuResources.cpp +7 -2
  129. data/vendor/faiss/faiss/gpu/GpuResources.h +11 -4
  130. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +66 -11
  131. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +15 -5
  132. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +2 -2
  133. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +28 -23
  134. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +2 -2
  135. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +2 -2
  136. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +2 -2
  137. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +2 -2
  138. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +2 -2
  139. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +8 -2
  140. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +2 -3
  141. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +2 -2
  142. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +10 -7
  143. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +2 -2
  144. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +54 -54
  145. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +144 -77
  146. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +51 -51
  147. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +2 -2
  148. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +3 -3
  149. data/vendor/faiss/faiss/gpu/test/TestGpuResidualQuantizer.cpp +70 -0
  150. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +74 -4
  151. data/vendor/faiss/faiss/gpu/test/TestUtils.h +2 -2
  152. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +3 -3
  153. data/vendor/faiss/faiss/gpu/utils/{RaftUtils.h → CuvsUtils.h} +12 -11
  154. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +8 -2
  155. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +2 -2
  156. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +2 -2
  157. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +2 -2
  158. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +6 -3
  159. data/vendor/faiss/faiss/gpu/utils/Timer.h +3 -3
  160. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +79 -11
  161. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +17 -5
  162. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +27 -2
  163. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +11 -3
  164. data/vendor/faiss/faiss/impl/CodePacker.cpp +2 -2
  165. data/vendor/faiss/faiss/impl/CodePacker.h +2 -2
  166. data/vendor/faiss/faiss/impl/DistanceComputer.h +48 -2
  167. data/vendor/faiss/faiss/impl/FaissAssert.h +6 -4
  168. data/vendor/faiss/faiss/impl/FaissException.cpp +2 -2
  169. data/vendor/faiss/faiss/impl/FaissException.h +2 -3
  170. data/vendor/faiss/faiss/impl/HNSW.cpp +378 -205
  171. data/vendor/faiss/faiss/impl/HNSW.h +55 -24
  172. data/vendor/faiss/faiss/impl/IDSelector.cpp +2 -2
  173. data/vendor/faiss/faiss/impl/IDSelector.h +2 -2
  174. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +10 -10
  175. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +2 -2
  176. data/vendor/faiss/faiss/impl/LookupTableScaler.h +36 -2
  177. data/vendor/faiss/faiss/impl/NNDescent.cpp +15 -10
  178. data/vendor/faiss/faiss/impl/NNDescent.h +2 -2
  179. data/vendor/faiss/faiss/impl/NSG.cpp +26 -49
  180. data/vendor/faiss/faiss/impl/NSG.h +20 -8
  181. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +2 -2
  182. data/vendor/faiss/faiss/impl/PolysemousTraining.h +2 -2
  183. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +2 -4
  184. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +2 -2
  185. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +2 -2
  186. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +3 -2
  187. data/vendor/faiss/faiss/impl/ProductQuantizer.h +7 -3
  188. data/vendor/faiss/faiss/impl/Quantizer.h +2 -2
  189. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +2 -36
  190. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +3 -13
  191. data/vendor/faiss/faiss/impl/ResultHandler.h +153 -34
  192. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +721 -104
  193. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +5 -2
  194. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +2 -2
  195. data/vendor/faiss/faiss/impl/ThreadedIndex.h +2 -2
  196. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +7 -2
  197. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx512.h +248 -0
  198. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +2 -2
  199. data/vendor/faiss/faiss/impl/code_distance/code_distance-sve.h +440 -0
  200. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +55 -2
  201. data/vendor/faiss/faiss/impl/index_read.cpp +31 -20
  202. data/vendor/faiss/faiss/impl/index_read_utils.h +37 -0
  203. data/vendor/faiss/faiss/impl/index_write.cpp +30 -16
  204. data/vendor/faiss/faiss/impl/io.cpp +15 -7
  205. data/vendor/faiss/faiss/impl/io.h +6 -6
  206. data/vendor/faiss/faiss/impl/io_macros.h +8 -9
  207. data/vendor/faiss/faiss/impl/kmeans1d.cpp +2 -3
  208. data/vendor/faiss/faiss/impl/kmeans1d.h +2 -2
  209. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +2 -3
  210. data/vendor/faiss/faiss/impl/lattice_Zn.h +2 -2
  211. data/vendor/faiss/faiss/impl/platform_macros.h +34 -2
  212. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +13 -2
  213. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +20 -2
  214. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +3 -3
  215. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +450 -3
  216. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +8 -8
  217. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +3 -3
  218. data/vendor/faiss/faiss/impl/simd_result_handlers.h +151 -67
  219. data/vendor/faiss/faiss/index_factory.cpp +51 -34
  220. data/vendor/faiss/faiss/index_factory.h +2 -2
  221. data/vendor/faiss/faiss/index_io.h +14 -7
  222. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +30 -10
  223. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +5 -2
  224. data/vendor/faiss/faiss/invlists/DirectMap.cpp +11 -3
  225. data/vendor/faiss/faiss/invlists/DirectMap.h +2 -2
  226. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +57 -19
  227. data/vendor/faiss/faiss/invlists/InvertedLists.h +20 -11
  228. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +2 -2
  229. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +2 -2
  230. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +23 -9
  231. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +4 -3
  232. data/vendor/faiss/faiss/python/python_callbacks.cpp +5 -5
  233. data/vendor/faiss/faiss/python/python_callbacks.h +2 -2
  234. data/vendor/faiss/faiss/utils/AlignedTable.h +5 -3
  235. data/vendor/faiss/faiss/utils/Heap.cpp +2 -2
  236. data/vendor/faiss/faiss/utils/Heap.h +107 -2
  237. data/vendor/faiss/faiss/utils/NeuralNet.cpp +346 -0
  238. data/vendor/faiss/faiss/utils/NeuralNet.h +147 -0
  239. data/vendor/faiss/faiss/utils/WorkerThread.cpp +2 -2
  240. data/vendor/faiss/faiss/utils/WorkerThread.h +2 -2
  241. data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +2 -2
  242. data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +2 -2
  243. data/vendor/faiss/faiss/utils/approx_topk/generic.h +2 -2
  244. data/vendor/faiss/faiss/utils/approx_topk/mode.h +2 -2
  245. data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +2 -2
  246. data/vendor/faiss/faiss/utils/bf16.h +36 -0
  247. data/vendor/faiss/faiss/utils/distances.cpp +249 -90
  248. data/vendor/faiss/faiss/utils/distances.h +8 -8
  249. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +2 -2
  250. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
  251. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
  252. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +2 -2
  253. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +2 -2
  254. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +2 -2
  255. data/vendor/faiss/faiss/utils/distances_simd.cpp +1543 -56
  256. data/vendor/faiss/faiss/utils/extra_distances-inl.h +72 -2
  257. data/vendor/faiss/faiss/utils/extra_distances.cpp +87 -140
  258. data/vendor/faiss/faiss/utils/extra_distances.h +5 -4
  259. data/vendor/faiss/faiss/utils/fp16-arm.h +2 -2
  260. data/vendor/faiss/faiss/utils/fp16-fp16c.h +2 -2
  261. data/vendor/faiss/faiss/utils/fp16-inl.h +2 -2
  262. data/vendor/faiss/faiss/utils/fp16.h +2 -2
  263. data/vendor/faiss/faiss/utils/hamming-inl.h +2 -2
  264. data/vendor/faiss/faiss/utils/hamming.cpp +3 -4
  265. data/vendor/faiss/faiss/utils/hamming.h +2 -2
  266. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +2 -2
  267. data/vendor/faiss/faiss/utils/hamming_distance/avx512-inl.h +490 -0
  268. data/vendor/faiss/faiss/utils/hamming_distance/common.h +2 -2
  269. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +6 -3
  270. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +7 -3
  271. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +5 -5
  272. data/vendor/faiss/faiss/utils/ordered_key_value.h +2 -2
  273. data/vendor/faiss/faiss/utils/partitioning.cpp +2 -2
  274. data/vendor/faiss/faiss/utils/partitioning.h +2 -2
  275. data/vendor/faiss/faiss/utils/prefetch.h +2 -2
  276. data/vendor/faiss/faiss/utils/quantize_lut.cpp +2 -2
  277. data/vendor/faiss/faiss/utils/quantize_lut.h +2 -2
  278. data/vendor/faiss/faiss/utils/random.cpp +45 -2
  279. data/vendor/faiss/faiss/utils/random.h +27 -2
  280. data/vendor/faiss/faiss/utils/simdlib.h +12 -3
  281. data/vendor/faiss/faiss/utils/simdlib_avx2.h +2 -2
  282. data/vendor/faiss/faiss/utils/simdlib_avx512.h +296 -0
  283. data/vendor/faiss/faiss/utils/simdlib_emulated.h +2 -2
  284. data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -4
  285. data/vendor/faiss/faiss/utils/simdlib_ppc64.h +1084 -0
  286. data/vendor/faiss/faiss/utils/sorting.cpp +2 -2
  287. data/vendor/faiss/faiss/utils/sorting.h +2 -2
  288. data/vendor/faiss/faiss/utils/transpose/transpose-avx2-inl.h +2 -2
  289. data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +176 -0
  290. data/vendor/faiss/faiss/utils/utils.cpp +17 -10
  291. data/vendor/faiss/faiss/utils/utils.h +7 -3
  292. metadata +22 -11
  293. data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +0 -102
@@ -1,5 +1,5 @@
1
- /**
2
- * Copyright (c) Facebook, Inc. and its affiliates.
1
+ /*
2
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3
3
  *
4
4
  * This source code is licensed under the MIT license found in the
5
5
  * LICENSE file in the root directory of this source tree.
@@ -31,6 +31,8 @@ namespace {
31
31
  * writes results in a ResultHandler
32
32
  */
33
33
 
34
+ #ifndef __AVX512F__
35
+
34
36
  template <int NQ, class ResultHandler, class Scaler>
35
37
  void kernel_accumulate_block(
36
38
  int nsq,
@@ -111,6 +113,451 @@ void kernel_accumulate_block(
111
113
  }
112
114
  }
113
115
 
116
+ #else
117
+
118
+ // a special version for NQ=1.
119
+ // Despite the function being large in the text form, it compiles to a very
120
+ // compact assembler code.
121
+ FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN
122
+ template <class ResultHandler, class Scaler>
123
+ void kernel_accumulate_block_avx512_nq1(
124
+ int nsq,
125
+ const uint8_t* codes,
126
+ const uint8_t* LUT,
127
+ ResultHandler& res,
128
+ const Scaler& scaler) {
129
+ // NQ is kept in order to match the similarity to baseline function
130
+ constexpr int NQ = 1;
131
+ // distance accumulators. We can accept more for NQ=1
132
+ // layout: accu[q][b]: distance accumulator for vectors 32*b..32*b+15
133
+ simd32uint16 accu[NQ][4];
134
+ // layout: accu[q][b]: distance accumulator for vectors 32*b+16..32*b+31
135
+ simd32uint16 accu1[NQ][4];
136
+
137
+ for (int q = 0; q < NQ; q++) {
138
+ for (int b = 0; b < 4; b++) {
139
+ accu[q][b].clear();
140
+ accu1[q][b].clear();
141
+ }
142
+ }
143
+
144
+ // process "nsq - scaler.nscale" part
145
+ const int nsq_minus_nscale = nsq - scaler.nscale;
146
+ const int nsq_minus_nscale_8 = (nsq_minus_nscale / 8) * 8;
147
+ const int nsq_minus_nscale_4 = (nsq_minus_nscale / 4) * 4;
148
+
149
+ // process in chunks of 8
150
+ for (int sq = 0; sq < nsq_minus_nscale_8; sq += 8) {
151
+ // prefetch
152
+ simd64uint8 c(codes);
153
+ codes += 64;
154
+
155
+ simd64uint8 c1(codes);
156
+ codes += 64;
157
+
158
+ simd64uint8 mask(0xf);
159
+ // shift op does not exist for int8...
160
+ simd64uint8 chi = simd64uint8(simd32uint16(c) >> 4) & mask;
161
+ simd64uint8 clo = c & mask;
162
+
163
+ simd64uint8 c1hi = simd64uint8(simd32uint16(c1) >> 4) & mask;
164
+ simd64uint8 c1lo = c1 & mask;
165
+
166
+ for (int q = 0; q < NQ; q++) {
167
+ // load LUTs for 4 quantizers
168
+ simd64uint8 lut(LUT);
169
+ LUT += 64;
170
+
171
+ {
172
+ simd64uint8 res0 = lut.lookup_4_lanes(clo);
173
+ simd64uint8 res1 = lut.lookup_4_lanes(chi);
174
+
175
+ accu[q][0] += simd32uint16(res0);
176
+ accu[q][1] += simd32uint16(res0) >> 8;
177
+
178
+ accu[q][2] += simd32uint16(res1);
179
+ accu[q][3] += simd32uint16(res1) >> 8;
180
+ }
181
+ }
182
+
183
+ for (int q = 0; q < NQ; q++) {
184
+ // load LUTs for 4 quantizers
185
+ simd64uint8 lut(LUT);
186
+ LUT += 64;
187
+
188
+ {
189
+ simd64uint8 res0 = lut.lookup_4_lanes(c1lo);
190
+ simd64uint8 res1 = lut.lookup_4_lanes(c1hi);
191
+
192
+ accu1[q][0] += simd32uint16(res0);
193
+ accu1[q][1] += simd32uint16(res0) >> 8;
194
+
195
+ accu1[q][2] += simd32uint16(res1);
196
+ accu1[q][3] += simd32uint16(res1) >> 8;
197
+ }
198
+ }
199
+ }
200
+
201
+ // process leftovers: a single chunk of size 4
202
+ if (nsq_minus_nscale_8 != nsq_minus_nscale_4) {
203
+ // prefetch
204
+ simd64uint8 c(codes);
205
+ codes += 64;
206
+
207
+ simd64uint8 mask(0xf);
208
+ // shift op does not exist for int8...
209
+ simd64uint8 chi = simd64uint8(simd32uint16(c) >> 4) & mask;
210
+ simd64uint8 clo = c & mask;
211
+
212
+ for (int q = 0; q < NQ; q++) {
213
+ // load LUTs for 4 quantizers
214
+ simd64uint8 lut(LUT);
215
+ LUT += 64;
216
+
217
+ simd64uint8 res0 = lut.lookup_4_lanes(clo);
218
+ simd64uint8 res1 = lut.lookup_4_lanes(chi);
219
+
220
+ accu[q][0] += simd32uint16(res0);
221
+ accu[q][1] += simd32uint16(res0) >> 8;
222
+
223
+ accu[q][2] += simd32uint16(res1);
224
+ accu[q][3] += simd32uint16(res1) >> 8;
225
+ }
226
+ }
227
+
228
+ // process leftovers: a single chunk of size 2
229
+ if (nsq_minus_nscale_4 != nsq_minus_nscale) {
230
+ // prefetch
231
+ simd32uint8 c(codes);
232
+ codes += 32;
233
+
234
+ simd32uint8 mask(0xf);
235
+ // shift op does not exist for int8...
236
+ simd32uint8 chi = simd32uint8(simd16uint16(c) >> 4) & mask;
237
+ simd32uint8 clo = c & mask;
238
+
239
+ for (int q = 0; q < NQ; q++) {
240
+ // load LUTs for 2 quantizers
241
+ simd32uint8 lut(LUT);
242
+ LUT += 32;
243
+
244
+ simd32uint8 res0 = lut.lookup_2_lanes(clo);
245
+ simd32uint8 res1 = lut.lookup_2_lanes(chi);
246
+
247
+ accu[q][0] += simd32uint16(simd16uint16(res0));
248
+ accu[q][1] += simd32uint16(simd16uint16(res0) >> 8);
249
+
250
+ accu[q][2] += simd32uint16(simd16uint16(res1));
251
+ accu[q][3] += simd32uint16(simd16uint16(res1) >> 8);
252
+ }
253
+ }
254
+
255
+ // process "sq" part
256
+ const int nscale = scaler.nscale;
257
+ const int nscale_8 = (nscale / 8) * 8;
258
+ const int nscale_4 = (nscale / 4) * 4;
259
+
260
+ // process in chunks of 8
261
+ for (int sq = 0; sq < nscale_8; sq += 8) {
262
+ // prefetch
263
+ simd64uint8 c(codes);
264
+ codes += 64;
265
+
266
+ simd64uint8 c1(codes);
267
+ codes += 64;
268
+
269
+ simd64uint8 mask(0xf);
270
+ // shift op does not exist for int8...
271
+ simd64uint8 chi = simd64uint8(simd32uint16(c) >> 4) & mask;
272
+ simd64uint8 clo = c & mask;
273
+
274
+ simd64uint8 c1hi = simd64uint8(simd32uint16(c1) >> 4) & mask;
275
+ simd64uint8 c1lo = c1 & mask;
276
+
277
+ for (int q = 0; q < NQ; q++) {
278
+ // load LUTs for 4 quantizers
279
+ simd64uint8 lut(LUT);
280
+ LUT += 64;
281
+
282
+ {
283
+ simd64uint8 res0 = scaler.lookup(lut, clo);
284
+ accu[q][0] += scaler.scale_lo(res0); // handle vectors 0..15
285
+ accu[q][1] += scaler.scale_hi(res0); // handle vectors 16..31
286
+
287
+ simd64uint8 res1 = scaler.lookup(lut, chi);
288
+ accu[q][2] += scaler.scale_lo(res1); // handle vectors 32..47
289
+ accu[q][3] += scaler.scale_hi(res1); // handle vectors 48..63
290
+ }
291
+ }
292
+
293
+ for (int q = 0; q < NQ; q++) {
294
+ // load LUTs for 4 quantizers
295
+ simd64uint8 lut(LUT);
296
+ LUT += 64;
297
+
298
+ {
299
+ simd64uint8 res0 = scaler.lookup(lut, c1lo);
300
+ accu1[q][0] += scaler.scale_lo(res0); // handle vectors 0..7
301
+ accu1[q][1] += scaler.scale_hi(res0); // handle vectors 8..15
302
+
303
+ simd64uint8 res1 = scaler.lookup(lut, c1hi);
304
+ accu1[q][2] += scaler.scale_lo(res1); // handle vectors 16..23
305
+ accu1[q][3] += scaler.scale_hi(res1); // handle vectors 24..31
306
+ }
307
+ }
308
+ }
309
+
310
+ // process leftovers: a single chunk of size 4
311
+ if (nscale_8 != nscale_4) {
312
+ // prefetch
313
+ simd64uint8 c(codes);
314
+ codes += 64;
315
+
316
+ simd64uint8 mask(0xf);
317
+ // shift op does not exist for int8...
318
+ simd64uint8 chi = simd64uint8(simd32uint16(c) >> 4) & mask;
319
+ simd64uint8 clo = c & mask;
320
+
321
+ for (int q = 0; q < NQ; q++) {
322
+ // load LUTs for 4 quantizers
323
+ simd64uint8 lut(LUT);
324
+ LUT += 64;
325
+
326
+ simd64uint8 res0 = scaler.lookup(lut, clo);
327
+ accu[q][0] += scaler.scale_lo(res0); // handle vectors 0..15
328
+ accu[q][1] += scaler.scale_hi(res0); // handle vectors 16..31
329
+
330
+ simd64uint8 res1 = scaler.lookup(lut, chi);
331
+ accu[q][2] += scaler.scale_lo(res1); // handle vectors 32..47
332
+ accu[q][3] += scaler.scale_hi(res1); // handle vectors 48..63
333
+ }
334
+ }
335
+
336
+ // process leftovers: a single chunk of size 2
337
+ if (nscale_4 != nscale) {
338
+ // prefetch
339
+ simd32uint8 c(codes);
340
+ codes += 32;
341
+
342
+ simd32uint8 mask(0xf);
343
+ // shift op does not exist for int8...
344
+ simd32uint8 chi = simd32uint8(simd16uint16(c) >> 4) & mask;
345
+ simd32uint8 clo = c & mask;
346
+
347
+ for (int q = 0; q < NQ; q++) {
348
+ // load LUTs for 2 quantizers
349
+ simd32uint8 lut(LUT);
350
+ LUT += 32;
351
+
352
+ simd32uint8 res0 = scaler.lookup(lut, clo);
353
+ accu[q][0] +=
354
+ simd32uint16(scaler.scale_lo(res0)); // handle vectors 0..7
355
+ accu[q][1] +=
356
+ simd32uint16(scaler.scale_hi(res0)); // handle vectors 8..15
357
+
358
+ simd32uint8 res1 = scaler.lookup(lut, chi);
359
+ accu[q][2] += simd32uint16(
360
+ scaler.scale_lo(res1)); // handle vectors 16..23
361
+ accu[q][3] += simd32uint16(
362
+ scaler.scale_hi(res1)); // handle vectors 24..31
363
+ }
364
+ }
365
+
366
+ for (int q = 0; q < NQ; q++) {
367
+ for (int b = 0; b < 4; b++) {
368
+ accu[q][b] += accu1[q][b];
369
+ }
370
+ }
371
+
372
+ for (int q = 0; q < NQ; q++) {
373
+ accu[q][0] -= accu[q][1] << 8;
374
+ simd16uint16 dis0 = combine4x2(accu[q][0], accu[q][1]);
375
+ accu[q][2] -= accu[q][3] << 8;
376
+ simd16uint16 dis1 = combine4x2(accu[q][2], accu[q][3]);
377
+ res.handle(q, 0, dis0, dis1);
378
+ }
379
+ }
380
+
381
+ // general-purpose case
382
+ FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN
383
+ template <int NQ, class ResultHandler, class Scaler>
384
+ void kernel_accumulate_block_avx512_nqx(
385
+ int nsq,
386
+ const uint8_t* codes,
387
+ const uint8_t* LUT,
388
+ ResultHandler& res,
389
+ const Scaler& scaler) {
390
+ // dummy alloc to keep the windows compiler happy
391
+ constexpr int NQA = NQ > 0 ? NQ : 1;
392
+ // distance accumulators
393
+ // layout: accu[q][b]: distance accumulator for vectors 8*b..8*b+7
394
+ simd32uint16 accu[NQA][4];
395
+
396
+ for (int q = 0; q < NQ; q++) {
397
+ for (int b = 0; b < 4; b++) {
398
+ accu[q][b].clear();
399
+ }
400
+ }
401
+
402
+ // process "nsq - scaler.nscale" part
403
+ const int nsq_minus_nscale = nsq - scaler.nscale;
404
+ const int nsq_minus_nscale_4 = (nsq_minus_nscale / 4) * 4;
405
+
406
+ // process in chunks of 8
407
+ for (int sq = 0; sq < nsq_minus_nscale_4; sq += 4) {
408
+ // prefetch
409
+ simd64uint8 c(codes);
410
+ codes += 64;
411
+
412
+ simd64uint8 mask(0xf);
413
+ // shift op does not exist for int8...
414
+ simd64uint8 chi = simd64uint8(simd32uint16(c) >> 4) & mask;
415
+ simd64uint8 clo = c & mask;
416
+
417
+ for (int q = 0; q < NQ; q++) {
418
+ // load LUTs for 4 quantizers
419
+ simd32uint8 lut_a(LUT);
420
+ simd32uint8 lut_b(LUT + NQ * 32);
421
+
422
+ simd64uint8 lut(lut_a, lut_b);
423
+ LUT += 32;
424
+
425
+ {
426
+ simd64uint8 res0 = lut.lookup_4_lanes(clo);
427
+ simd64uint8 res1 = lut.lookup_4_lanes(chi);
428
+
429
+ accu[q][0] += simd32uint16(res0);
430
+ accu[q][1] += simd32uint16(res0) >> 8;
431
+
432
+ accu[q][2] += simd32uint16(res1);
433
+ accu[q][3] += simd32uint16(res1) >> 8;
434
+ }
435
+ }
436
+
437
+ LUT += NQ * 32;
438
+ }
439
+
440
+ // process leftovers: a single chunk of size 2
441
+ if (nsq_minus_nscale_4 != nsq_minus_nscale) {
442
+ // prefetch
443
+ simd32uint8 c(codes);
444
+ codes += 32;
445
+
446
+ simd32uint8 mask(0xf);
447
+ // shift op does not exist for int8...
448
+ simd32uint8 chi = simd32uint8(simd16uint16(c) >> 4) & mask;
449
+ simd32uint8 clo = c & mask;
450
+
451
+ for (int q = 0; q < NQ; q++) {
452
+ // load LUTs for 2 quantizers
453
+ simd32uint8 lut(LUT);
454
+ LUT += 32;
455
+
456
+ simd32uint8 res0 = lut.lookup_2_lanes(clo);
457
+ simd32uint8 res1 = lut.lookup_2_lanes(chi);
458
+
459
+ accu[q][0] += simd32uint16(simd16uint16(res0));
460
+ accu[q][1] += simd32uint16(simd16uint16(res0) >> 8);
461
+
462
+ accu[q][2] += simd32uint16(simd16uint16(res1));
463
+ accu[q][3] += simd32uint16(simd16uint16(res1) >> 8);
464
+ }
465
+ }
466
+
467
+ // process "sq" part
468
+ const int nscale = scaler.nscale;
469
+ const int nscale_4 = (nscale / 4) * 4;
470
+
471
+ // process in chunks of 4
472
+ for (int sq = 0; sq < nscale_4; sq += 4) {
473
+ // prefetch
474
+ simd64uint8 c(codes);
475
+ codes += 64;
476
+
477
+ simd64uint8 mask(0xf);
478
+ // shift op does not exist for int8...
479
+ simd64uint8 chi = simd64uint8(simd32uint16(c) >> 4) & mask;
480
+ simd64uint8 clo = c & mask;
481
+
482
+ for (int q = 0; q < NQ; q++) {
483
+ // load LUTs for 4 quantizers
484
+ simd32uint8 lut_a(LUT);
485
+ simd32uint8 lut_b(LUT + NQ * 32);
486
+
487
+ simd64uint8 lut(lut_a, lut_b);
488
+ LUT += 32;
489
+
490
+ {
491
+ simd64uint8 res0 = scaler.lookup(lut, clo);
492
+ accu[q][0] += scaler.scale_lo(res0); // handle vectors 0..7
493
+ accu[q][1] += scaler.scale_hi(res0); // handle vectors 8..15
494
+
495
+ simd64uint8 res1 = scaler.lookup(lut, chi);
496
+ accu[q][2] += scaler.scale_lo(res1); // handle vectors 16..23
497
+ accu[q][3] += scaler.scale_hi(res1); // handle vectors 24..31
498
+ }
499
+ }
500
+
501
+ LUT += NQ * 32;
502
+ }
503
+
504
+ // process leftovers: a single chunk of size 2
505
+ if (nscale_4 != nscale) {
506
+ // prefetch
507
+ simd32uint8 c(codes);
508
+ codes += 32;
509
+
510
+ simd32uint8 mask(0xf);
511
+ // shift op does not exist for int8...
512
+ simd32uint8 chi = simd32uint8(simd16uint16(c) >> 4) & mask;
513
+ simd32uint8 clo = c & mask;
514
+
515
+ for (int q = 0; q < NQ; q++) {
516
+ // load LUTs for 2 quantizers
517
+ simd32uint8 lut(LUT);
518
+ LUT += 32;
519
+
520
+ simd32uint8 res0 = scaler.lookup(lut, clo);
521
+ accu[q][0] +=
522
+ simd32uint16(scaler.scale_lo(res0)); // handle vectors 0..7
523
+ accu[q][1] +=
524
+ simd32uint16(scaler.scale_hi(res0)); // handle vectors 8..15
525
+
526
+ simd32uint8 res1 = scaler.lookup(lut, chi);
527
+ accu[q][2] += simd32uint16(
528
+ scaler.scale_lo(res1)); // handle vectors 16..23
529
+ accu[q][3] += simd32uint16(
530
+ scaler.scale_hi(res1)); // handle vectors 24..31
531
+ }
532
+ }
533
+
534
+ for (int q = 0; q < NQ; q++) {
535
+ accu[q][0] -= accu[q][1] << 8;
536
+ simd16uint16 dis0 = combine4x2(accu[q][0], accu[q][1]);
537
+ accu[q][2] -= accu[q][3] << 8;
538
+ simd16uint16 dis1 = combine4x2(accu[q][2], accu[q][3]);
539
+ res.handle(q, 0, dis0, dis1);
540
+ }
541
+ }
542
+
543
+ template <int NQ, class ResultHandler, class Scaler>
544
+ void kernel_accumulate_block(
545
+ int nsq,
546
+ const uint8_t* codes,
547
+ const uint8_t* LUT,
548
+ ResultHandler& res,
549
+ const Scaler& scaler) {
550
+ if constexpr (NQ == 1) {
551
+ kernel_accumulate_block_avx512_nq1<ResultHandler, Scaler>(
552
+ nsq, codes, LUT, res, scaler);
553
+ } else {
554
+ kernel_accumulate_block_avx512_nqx<NQ, ResultHandler, Scaler>(
555
+ nsq, codes, LUT, res, scaler);
556
+ }
557
+ }
558
+
559
+ #endif
560
+
114
561
  // handle at most 4 blocks of queries
115
562
  template <int QBS, class ResultHandler, class Scaler>
116
563
  void accumulate_q_4step(
@@ -304,7 +751,7 @@ void pq4_accumulate_loop_qbs(
304
751
  SIMDResultHandler& res,
305
752
  const NormTableScaler* scaler) {
306
753
  Run_pq4_accumulate_loop_qbs consumer;
307
- dispatch_SIMDResultHanlder(res, consumer, qbs, nb, nsq, codes, LUT, scaler);
754
+ dispatch_SIMDResultHandler(res, consumer, qbs, nb, nsq, codes, LUT, scaler);
308
755
  }
309
756
 
310
757
  /***************************************************************
@@ -1,5 +1,5 @@
1
- /**
2
- * Copyright (c) Facebook, Inc. and its affiliates.
1
+ /*
2
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3
3
  *
4
4
  * This source code is licensed under the MIT license found in the
5
5
  * LICENSE file in the root directory of this source tree.
@@ -95,7 +95,7 @@ void accum_and_store_tab(
95
95
  for (size_t ij = 1; ij < M; ij++) {
96
96
  reg += cbs[ij][kk];
97
97
  }
98
- output[b * K + kk] = reg;
98
+ output[kk] = reg;
99
99
  }
100
100
  }
101
101
 
@@ -152,7 +152,7 @@ void accum_and_add_tab(
152
152
  for (size_t ij = 1; ij < M; ij++) {
153
153
  reg += cbs[ij][kk];
154
154
  }
155
- output[b * K + kk] += reg;
155
+ output[kk] += reg;
156
156
  }
157
157
  }
158
158
 
@@ -664,8 +664,6 @@ void refine_beam_mp(
664
664
  std::unique_ptr<Index> assign_index;
665
665
  if (rq.assign_index_factory) {
666
666
  assign_index.reset((*rq.assign_index_factory)(rq.d));
667
- } else {
668
- assign_index.reset(new IndexFlatL2(rq.d));
669
667
  }
670
668
 
671
669
  // main loop
@@ -701,7 +699,9 @@ void refine_beam_mp(
701
699
  assign_index.get(),
702
700
  rq.approx_topk_mode);
703
701
 
704
- assign_index->reset();
702
+ if (assign_index != nullptr) {
703
+ assign_index->reset();
704
+ }
705
705
 
706
706
  std::swap(codes_ptr, new_codes_ptr);
707
707
  std::swap(residuals_ptr, new_residuals_ptr);
@@ -809,7 +809,7 @@ void refine_beam_LUT_mp(
809
809
  rq.codebook_offsets.data(),
810
810
  query_cp + rq.codebook_offsets[m],
811
811
  rq.total_codebook_size,
812
- rq.cent_norms.data() + rq.codebook_offsets[m],
812
+ rq.centroid_norms.data() + rq.codebook_offsets[m],
813
813
  m,
814
814
  codes_ptr,
815
815
  distances_ptr,
@@ -1,5 +1,5 @@
1
- /**
2
- * Copyright (c) Facebook, Inc. and its affiliates.
1
+ /*
2
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3
3
  *
4
4
  * This source code is licensed under the MIT license found in the
5
5
  * LICENSE file in the root directory of this source tree.
@@ -25,7 +25,7 @@ namespace faiss {
25
25
  * It allows low-level access to the encoding function, exposed mainly for unit
26
26
  * tests.
27
27
  *
28
- * @param n number of vectors to hanlde
28
+ * @param n number of vectors to handle
29
29
  * @param residuals vectors to encode, size (n, beam_size, d)
30
30
  * @param cent centroids, size (K, d)
31
31
  * @param beam_size input beam size