faiss 0.6.0 → 0.6.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (361) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/ext/faiss/extconf.rb +2 -1
  4. data/ext/faiss/{index_rb.cpp → index.cpp} +1 -1
  5. data/ext/faiss/index_binary.cpp +1 -1
  6. data/ext/faiss/kmeans.cpp +1 -1
  7. data/ext/faiss/pca_matrix.cpp +1 -1
  8. data/ext/faiss/product_quantizer.cpp +1 -1
  9. data/ext/faiss/{utils_rb.cpp → utils.cpp} +1 -1
  10. data/lib/faiss/version.rb +1 -1
  11. data/vendor/faiss/faiss/AutoTune.cpp +93 -80
  12. data/vendor/faiss/faiss/Clustering.cpp +39 -240
  13. data/vendor/faiss/faiss/Clustering.h +6 -0
  14. data/vendor/faiss/faiss/IVFlib.cpp +41 -21
  15. data/vendor/faiss/faiss/Index.cpp +6 -5
  16. data/vendor/faiss/faiss/Index.h +5 -5
  17. data/vendor/faiss/faiss/Index2Layer.cpp +37 -53
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +49 -37
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +36 -34
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +4 -1
  21. data/vendor/faiss/faiss/IndexBinary.cpp +5 -3
  22. data/vendor/faiss/faiss/IndexBinary.h +4 -4
  23. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +1 -1
  24. data/vendor/faiss/faiss/IndexBinaryFlat.h +1 -1
  25. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -4
  26. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +84 -92
  27. data/vendor/faiss/faiss/IndexBinaryHNSW.h +9 -3
  28. data/vendor/faiss/faiss/IndexBinaryHash.cpp +45 -236
  29. data/vendor/faiss/faiss/IndexBinaryHash.h +6 -6
  30. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +87 -415
  31. data/vendor/faiss/faiss/IndexFastScan.cpp +72 -109
  32. data/vendor/faiss/faiss/IndexFastScan.h +25 -23
  33. data/vendor/faiss/faiss/IndexFlat.cpp +27 -20
  34. data/vendor/faiss/faiss/IndexFlat.h +21 -18
  35. data/vendor/faiss/faiss/IndexFlatCodes.cpp +42 -19
  36. data/vendor/faiss/faiss/IndexHNSW.cpp +283 -145
  37. data/vendor/faiss/faiss/IndexHNSW.h +16 -2
  38. data/vendor/faiss/faiss/IndexIDMap.cpp +25 -21
  39. data/vendor/faiss/faiss/IndexIDMap.h +9 -7
  40. data/vendor/faiss/faiss/IndexIVF.cpp +465 -362
  41. data/vendor/faiss/faiss/IndexIVF.h +33 -12
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +77 -74
  43. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +96 -93
  44. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -1
  45. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +357 -238
  46. data/vendor/faiss/faiss/IndexIVFFastScan.h +42 -41
  47. data/vendor/faiss/faiss/IndexIVFFlat.cpp +36 -68
  48. data/vendor/faiss/faiss/IndexIVFFlat.h +32 -0
  49. data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +53 -30
  50. data/vendor/faiss/faiss/IndexIVFFlatPanorama.h +3 -1
  51. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +18 -15
  52. data/vendor/faiss/faiss/IndexIVFPQ.cpp +71 -843
  53. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +151 -121
  54. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +3 -0
  55. data/vendor/faiss/faiss/IndexIVFPQR.cpp +21 -17
  56. data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +26 -39
  57. data/vendor/faiss/faiss/IndexIVFRaBitQ.h +2 -1
  58. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +475 -476
  59. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +248 -93
  60. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +41 -127
  61. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +1 -1
  62. data/vendor/faiss/faiss/IndexLSH.cpp +36 -19
  63. data/vendor/faiss/faiss/IndexLattice.cpp +13 -13
  64. data/vendor/faiss/faiss/IndexNNDescent.cpp +36 -21
  65. data/vendor/faiss/faiss/IndexNNDescent.h +2 -2
  66. data/vendor/faiss/faiss/IndexNSG.cpp +39 -23
  67. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +31 -11
  68. data/vendor/faiss/faiss/IndexPQ.cpp +128 -221
  69. data/vendor/faiss/faiss/IndexPQ.h +3 -2
  70. data/vendor/faiss/faiss/IndexPQFastScan.cpp +20 -14
  71. data/vendor/faiss/faiss/IndexPQFastScan.h +3 -0
  72. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -18
  73. data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
  74. data/vendor/faiss/faiss/IndexRaBitQ.cpp +11 -36
  75. data/vendor/faiss/faiss/IndexRaBitQ.h +2 -1
  76. data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +41 -277
  77. data/vendor/faiss/faiss/IndexRaBitQFastScan.h +183 -27
  78. data/vendor/faiss/faiss/IndexRefine.cpp +30 -25
  79. data/vendor/faiss/faiss/IndexRefine.h +4 -4
  80. data/vendor/faiss/faiss/IndexReplicas.cpp +6 -6
  81. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +15 -14
  82. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +1 -1
  83. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +82 -14
  84. data/vendor/faiss/faiss/IndexShards.cpp +10 -9
  85. data/vendor/faiss/faiss/IndexShardsIVF.cpp +21 -15
  86. data/vendor/faiss/faiss/MatrixStats.cpp +5 -4
  87. data/vendor/faiss/faiss/MetaIndexes.cpp +19 -17
  88. data/vendor/faiss/faiss/MetaIndexes.h +1 -1
  89. data/vendor/faiss/faiss/MetricType.h +14 -7
  90. data/vendor/faiss/faiss/SuperKMeans.cpp +656 -0
  91. data/vendor/faiss/faiss/SuperKMeans.h +97 -0
  92. data/vendor/faiss/faiss/VectorTransform.cpp +237 -149
  93. data/vendor/faiss/faiss/VectorTransform.h +16 -16
  94. data/vendor/faiss/faiss/build.cpp +23 -0
  95. data/vendor/faiss/faiss/build.h +15 -0
  96. data/vendor/faiss/faiss/clone_index.cpp +48 -47
  97. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +47 -47
  98. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +11 -0
  99. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +38 -38
  100. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +11 -0
  101. data/vendor/faiss/faiss/factory_tools.cpp +5 -0
  102. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +6 -5
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +1 -1
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +9 -9
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +4 -3
  106. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +46 -0
  107. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +56 -0
  108. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +78 -1
  109. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +72 -0
  110. data/vendor/faiss/faiss/gpu/test/TestUtils.h +23 -0
  111. data/vendor/faiss/faiss/gpu/utils/CuvsFilterConvert.h +1 -1
  112. data/vendor/faiss/faiss/gpu/utils/CuvsUtils.h +21 -10
  113. data/vendor/faiss/faiss/gpu_metal/GpuIndexFlat.h +22 -0
  114. data/vendor/faiss/faiss/gpu_metal/MetalCloner.h +35 -0
  115. data/vendor/faiss/faiss/gpu_metal/MetalFlatKernels.h +40 -0
  116. data/vendor/faiss/faiss/gpu_metal/MetalIndex.h +51 -0
  117. data/vendor/faiss/faiss/gpu_metal/MetalIndexFlat.h +65 -0
  118. data/vendor/faiss/faiss/gpu_metal/MetalKernels.h +66 -0
  119. data/vendor/faiss/faiss/gpu_metal/MetalResources.h +79 -0
  120. data/vendor/faiss/faiss/gpu_metal/StandardMetalResources.h +35 -0
  121. data/vendor/faiss/faiss/impl/AdSampling.cpp +103 -0
  122. data/vendor/faiss/faiss/impl/AdSampling.h +35 -0
  123. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +29 -25
  124. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +1 -0
  125. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +10 -9
  126. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +3 -0
  127. data/vendor/faiss/faiss/impl/ClusteringHelpers.cpp +244 -0
  128. data/vendor/faiss/faiss/impl/ClusteringHelpers.h +94 -0
  129. data/vendor/faiss/faiss/impl/ClusteringInitialization.cpp +16 -16
  130. data/vendor/faiss/faiss/impl/CodePacker.cpp +3 -3
  131. data/vendor/faiss/faiss/impl/CodePackerRaBitQ.cpp +1 -1
  132. data/vendor/faiss/faiss/impl/DistanceComputer.h +8 -8
  133. data/vendor/faiss/faiss/impl/FaissAssert.h +6 -3
  134. data/vendor/faiss/faiss/impl/FaissException.h +50 -3
  135. data/vendor/faiss/faiss/impl/HNSW.cpp +92 -317
  136. data/vendor/faiss/faiss/impl/HNSW.h +13 -34
  137. data/vendor/faiss/faiss/impl/IDSelector.cpp +15 -11
  138. data/vendor/faiss/faiss/impl/IDSelector.h +8 -8
  139. data/vendor/faiss/faiss/impl/InvertedListScannerStats.h +26 -0
  140. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +82 -77
  141. data/vendor/faiss/faiss/impl/NNDescent.cpp +62 -25
  142. data/vendor/faiss/faiss/impl/NNDescent.h +6 -2
  143. data/vendor/faiss/faiss/impl/NSG.cpp +38 -21
  144. data/vendor/faiss/faiss/impl/NSG.h +4 -4
  145. data/vendor/faiss/faiss/impl/Panorama.cpp +23 -6
  146. data/vendor/faiss/faiss/impl/Panorama.h +258 -87
  147. data/vendor/faiss/faiss/impl/PdxLayout.cpp +93 -0
  148. data/vendor/faiss/faiss/impl/PdxLayout.h +41 -0
  149. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +46 -32
  150. data/vendor/faiss/faiss/impl/PolysemousTraining.h +3 -3
  151. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +35 -35
  152. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +21 -16
  153. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +30 -23
  154. data/vendor/faiss/faiss/impl/Quantizer.h +2 -2
  155. data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +55 -49
  156. data/vendor/faiss/faiss/impl/RaBitQUtils.h +65 -0
  157. data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +296 -283
  158. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +26 -23
  159. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +1 -1
  160. data/vendor/faiss/faiss/impl/ResultHandler.h +99 -75
  161. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +52 -4
  162. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +27 -1
  163. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +14 -11
  164. data/vendor/faiss/faiss/impl/VisitedTable.h +7 -0
  165. data/vendor/faiss/faiss/impl/approx_topk/approx_topk.h +276 -0
  166. data/vendor/faiss/faiss/impl/approx_topk/avx2.cpp +68 -0
  167. data/vendor/faiss/faiss/{utils → impl}/approx_topk/generic.h +15 -8
  168. data/vendor/faiss/faiss/impl/approx_topk/neon.cpp +68 -0
  169. data/vendor/faiss/faiss/impl/approx_topk/rq_beam_search_tab-inl.h +169 -0
  170. data/vendor/faiss/faiss/impl/approx_topk/rq_beam_search_tab.h +117 -0
  171. data/vendor/faiss/faiss/impl/approx_topk/simdlib256-inl.h +146 -0
  172. data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryHNSW_impl.h +73 -0
  173. data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryHash_impl.h +270 -0
  174. data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryIVF_impl.h +460 -0
  175. data/vendor/faiss/faiss/impl/binary_hamming/IndexIVFSpectralHash_impl.h +159 -0
  176. data/vendor/faiss/faiss/impl/binary_hamming/IndexPQ_impl.h +92 -0
  177. data/vendor/faiss/faiss/impl/binary_hamming/avx2.cpp +26 -0
  178. data/vendor/faiss/faiss/impl/binary_hamming/avx512.cpp +26 -0
  179. data/vendor/faiss/faiss/impl/binary_hamming/dispatch.h +143 -0
  180. data/vendor/faiss/faiss/impl/binary_hamming/neon.cpp +26 -0
  181. data/vendor/faiss/faiss/impl/binary_hamming/rvv.cpp +26 -0
  182. data/vendor/faiss/faiss/impl/expanded_scanners.h +8 -3
  183. data/vendor/faiss/faiss/impl/{FastScanDistancePostProcessing.h → fast_scan/FastScanDistancePostProcessing.h} +13 -6
  184. data/vendor/faiss/faiss/impl/{LookupTableScaler.h → fast_scan/LookupTableScaler.h} +16 -5
  185. data/vendor/faiss/faiss/impl/fast_scan/accumulate_loops.h +237 -0
  186. data/vendor/faiss/faiss/impl/fast_scan/accumulate_loops_512.h +185 -0
  187. data/vendor/faiss/faiss/impl/fast_scan/decompose_qbs.h +229 -0
  188. data/vendor/faiss/faiss/impl/fast_scan/dispatching.h +268 -0
  189. data/vendor/faiss/faiss/impl/{pq4_fast_scan.cpp → fast_scan/fast_scan.cpp} +169 -2
  190. data/vendor/faiss/faiss/impl/fast_scan/fast_scan.h +341 -0
  191. data/vendor/faiss/faiss/impl/fast_scan/impl-avx2.cpp +36 -0
  192. data/vendor/faiss/faiss/impl/fast_scan/impl-avx512.cpp +40 -0
  193. data/vendor/faiss/faiss/impl/fast_scan/impl-neon.cpp +120 -0
  194. data/vendor/faiss/faiss/impl/fast_scan/impl-riscv.cpp +104 -0
  195. data/vendor/faiss/faiss/impl/fast_scan/kernels_simd256.h +213 -0
  196. data/vendor/faiss/faiss/impl/{pq4_fast_scan_search_qbs.cpp → fast_scan/kernels_simd512.h} +26 -356
  197. data/vendor/faiss/faiss/impl/fast_scan/rabitq_dispatching.h +90 -0
  198. data/vendor/faiss/faiss/impl/fast_scan/rabitq_result_handler.h +108 -0
  199. data/vendor/faiss/faiss/impl/{simd_result_handlers.h → fast_scan/simd_result_handlers.h} +282 -134
  200. data/vendor/faiss/faiss/impl/hnsw/LockVector.cpp +54 -0
  201. data/vendor/faiss/faiss/impl/hnsw/LockVector.h +64 -0
  202. data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.cpp +91 -0
  203. data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.h +64 -0
  204. data/vendor/faiss/faiss/impl/hnsw/avx2.cpp +104 -0
  205. data/vendor/faiss/faiss/impl/hnsw/avx512.cpp +111 -0
  206. data/vendor/faiss/faiss/impl/index_read.cpp +1132 -45
  207. data/vendor/faiss/faiss/impl/index_read_utils.h +1 -1
  208. data/vendor/faiss/faiss/impl/index_write.cpp +95 -13
  209. data/vendor/faiss/faiss/impl/io.cpp +6 -6
  210. data/vendor/faiss/faiss/impl/io_macros.h +33 -16
  211. data/vendor/faiss/faiss/impl/kmeans1d.cpp +10 -10
  212. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +37 -23
  213. data/vendor/faiss/faiss/impl/lattice_Zn.h +6 -6
  214. data/vendor/faiss/faiss/impl/mapped_io.cpp +6 -6
  215. data/vendor/faiss/faiss/impl/platform_macros.h +11 -4
  216. data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQScanner_impl.h +549 -0
  217. data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQ_QueryTables.cpp +245 -0
  218. data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQ_QueryTables.h +105 -0
  219. data/vendor/faiss/faiss/impl/pq_code_distance/PQDistanceComputer_impl.h +106 -0
  220. data/vendor/faiss/faiss/impl/pq_code_distance/avx2.cpp +21 -0
  221. data/vendor/faiss/faiss/impl/pq_code_distance/avx512.cpp +21 -0
  222. data/vendor/faiss/faiss/impl/pq_code_distance/neon.cpp +21 -0
  223. data/vendor/faiss/faiss/impl/pq_code_distance/{pq_code_distance-avx2.cpp → pq_code_distance-avx2.h} +9 -13
  224. data/vendor/faiss/faiss/impl/pq_code_distance/{pq_code_distance-avx512.cpp → pq_code_distance-avx512.h} +9 -57
  225. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.cpp +29 -111
  226. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.h +96 -0
  227. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-inl.h +238 -5
  228. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-sve.cpp +5 -7
  229. data/vendor/faiss/faiss/impl/pq_code_distance/rvv.cpp +68 -0
  230. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +311 -477
  231. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +1 -1
  232. data/vendor/faiss/faiss/impl/scalar_quantizer/codecs.h +1 -1
  233. data/vendor/faiss/faiss/impl/scalar_quantizer/distance_computers.h +3 -2
  234. data/vendor/faiss/faiss/impl/scalar_quantizer/quantizers.h +102 -11
  235. data/vendor/faiss/faiss/impl/scalar_quantizer/scanners.h +27 -1
  236. data/vendor/faiss/faiss/impl/scalar_quantizer/similarities.h +3 -3
  237. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx2.cpp +148 -0
  238. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512.cpp +167 -0
  239. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-dispatch.h +59 -0
  240. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-neon.cpp +163 -0
  241. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-rvv.cpp +311 -0
  242. data/vendor/faiss/faiss/impl/scalar_quantizer/training.cpp +192 -8
  243. data/vendor/faiss/faiss/impl/scalar_quantizer/training.h +12 -0
  244. data/vendor/faiss/faiss/impl/simd_dispatch.h +100 -66
  245. data/vendor/faiss/faiss/impl/simdlib/simdlib.h +57 -0
  246. data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_avx2.h +264 -172
  247. data/vendor/faiss/faiss/impl/simdlib/simdlib_avx512.h +414 -0
  248. data/vendor/faiss/faiss/impl/simdlib/simdlib_dispatch.h +44 -0
  249. data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_emulated.h +231 -166
  250. data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_neon.h +270 -218
  251. data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_ppc64.h +201 -160
  252. data/vendor/faiss/faiss/impl/svs_io.cpp +12 -3
  253. data/vendor/faiss/faiss/impl/svs_io.h +8 -2
  254. data/vendor/faiss/faiss/index_factory.cpp +86 -18
  255. data/vendor/faiss/faiss/index_io.h +24 -0
  256. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +66 -16
  257. data/vendor/faiss/faiss/invlists/DirectMap.cpp +24 -14
  258. data/vendor/faiss/faiss/invlists/DirectMap.h +4 -3
  259. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +157 -73
  260. data/vendor/faiss/faiss/invlists/InvertedLists.h +86 -23
  261. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +4 -4
  262. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +13 -13
  263. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
  264. data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +1 -1
  265. data/vendor/faiss/faiss/svs/IndexSVSFlat.cpp +2 -2
  266. data/vendor/faiss/faiss/svs/IndexSVSIVF.cpp +350 -0
  267. data/vendor/faiss/faiss/svs/IndexSVSIVF.h +128 -0
  268. data/vendor/faiss/faiss/svs/IndexSVSIVFLVQ.cpp +40 -0
  269. data/vendor/faiss/faiss/svs/IndexSVSIVFLVQ.h +43 -0
  270. data/vendor/faiss/faiss/svs/IndexSVSIVFLeanVec.cpp +225 -0
  271. data/vendor/faiss/faiss/svs/IndexSVSIVFLeanVec.h +71 -0
  272. data/vendor/faiss/faiss/svs/IndexSVSVamana.cpp +25 -1
  273. data/vendor/faiss/faiss/svs/IndexSVSVamana.h +18 -2
  274. data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.h +1 -1
  275. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +12 -3
  276. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +7 -2
  277. data/vendor/faiss/faiss/utils/Heap.cpp +10 -10
  278. data/vendor/faiss/faiss/utils/NeuralNet.cpp +47 -36
  279. data/vendor/faiss/faiss/utils/NeuralNet.h +1 -1
  280. data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +10 -4
  281. data/vendor/faiss/faiss/utils/distances.cpp +390 -560
  282. data/vendor/faiss/faiss/utils/distances.h +20 -1
  283. data/vendor/faiss/faiss/utils/distances_dispatch.h +117 -37
  284. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +8 -7
  285. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +33 -14
  286. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +12 -1
  287. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +16 -293
  288. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based_neon.cpp +57 -0
  289. data/vendor/faiss/faiss/utils/distances_fused/simdlib_kernel-inl.h +290 -0
  290. data/vendor/faiss/faiss/utils/distances_simd.cpp +5 -177
  291. data/vendor/faiss/faiss/utils/extra_distances.cpp +9 -8
  292. data/vendor/faiss/faiss/utils/extra_distances.h +32 -6
  293. data/vendor/faiss/faiss/utils/hamming-inl.h +13 -11
  294. data/vendor/faiss/faiss/utils/hamming.cpp +66 -517
  295. data/vendor/faiss/faiss/utils/hamming.h +92 -2
  296. data/vendor/faiss/faiss/utils/hamming_distance/common.h +287 -10
  297. data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx2.cpp +15 -0
  298. data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx512.cpp +15 -0
  299. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx2.h +142 -0
  300. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx512.h +234 -0
  301. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-generic.h +368 -0
  302. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-neon.h +322 -0
  303. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-rvv.h +39 -0
  304. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer.h +146 -0
  305. data/vendor/faiss/faiss/utils/hamming_distance/hamming_impl.h +481 -0
  306. data/vendor/faiss/faiss/utils/hamming_distance/hamming_neon.cpp +15 -0
  307. data/vendor/faiss/faiss/utils/hamming_distance/hamming_rvv.cpp +15 -0
  308. data/vendor/faiss/faiss/utils/partitioning.cpp +66 -987
  309. data/vendor/faiss/faiss/utils/partitioning.h +31 -0
  310. data/vendor/faiss/faiss/utils/popcount.h +29 -0
  311. data/vendor/faiss/faiss/utils/pq_code_distance.h +2 -2
  312. data/vendor/faiss/faiss/utils/prefetch.h +2 -2
  313. data/vendor/faiss/faiss/utils/quantize_lut.cpp +30 -30
  314. data/vendor/faiss/faiss/utils/quantize_lut.h +1 -1
  315. data/vendor/faiss/faiss/utils/rabitq_simd.h +57 -536
  316. data/vendor/faiss/faiss/utils/random.cpp +6 -6
  317. data/vendor/faiss/faiss/utils/simd_impl/IVFFlatScanner-inl.h +51 -0
  318. data/vendor/faiss/faiss/utils/simd_impl/distances_aarch64.cpp +5 -1
  319. data/vendor/faiss/faiss/utils/simd_impl/distances_arm_sve.cpp +213 -4
  320. data/vendor/faiss/faiss/utils/simd_impl/distances_autovec-inl.h +163 -10
  321. data/vendor/faiss/faiss/utils/simd_impl/distances_avx2.cpp +250 -4
  322. data/vendor/faiss/faiss/utils/simd_impl/distances_avx512.cpp +7 -4
  323. data/vendor/faiss/faiss/utils/simd_impl/distances_rvv.cpp +189 -0
  324. data/vendor/faiss/faiss/utils/simd_impl/distances_simdlib256.h +195 -0
  325. data/vendor/faiss/faiss/utils/simd_impl/distances_sse-inl.h +2 -1
  326. data/vendor/faiss/faiss/utils/{distances_fused/simdlib_based.h → simd_impl/exhaustive_L2sqr_blas_cmax.h} +5 -10
  327. data/vendor/faiss/faiss/utils/simd_impl/hamming_impl.h +481 -0
  328. data/vendor/faiss/faiss/utils/simd_impl/partitioning_avx2.cpp +14 -0
  329. data/vendor/faiss/faiss/utils/simd_impl/partitioning_neon.cpp +14 -0
  330. data/vendor/faiss/faiss/utils/simd_impl/partitioning_simdlib256.h +1085 -0
  331. data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx2.cpp +355 -0
  332. data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx512.cpp +477 -0
  333. data/vendor/faiss/faiss/utils/simd_impl/rabitq_neon.cpp +55 -0
  334. data/vendor/faiss/faiss/utils/simd_impl/rabitq_rvv.cpp +55 -0
  335. data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_dispatch.h +32 -0
  336. data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels.h +43 -0
  337. data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels_avx2.cpp +57 -0
  338. data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels_avx512.cpp +45 -0
  339. data/vendor/faiss/faiss/utils/simd_levels.cpp +17 -5
  340. data/vendor/faiss/faiss/utils/simd_levels.h +93 -1
  341. data/vendor/faiss/faiss/utils/sorting.cpp +48 -36
  342. data/vendor/faiss/faiss/utils/utils.cpp +5 -5
  343. data/vendor/faiss/faiss/utils/utils.h +3 -3
  344. metadata +119 -34
  345. data/vendor/faiss/faiss/impl/RaBitQStats.cpp +0 -29
  346. data/vendor/faiss/faiss/impl/RaBitQStats.h +0 -56
  347. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +0 -224
  348. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +0 -230
  349. data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +0 -84
  350. data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +0 -196
  351. data/vendor/faiss/faiss/utils/approx_topk/mode.h +0 -34
  352. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +0 -36
  353. data/vendor/faiss/faiss/utils/extra_distances-inl.h +0 -235
  354. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +0 -462
  355. data/vendor/faiss/faiss/utils/hamming_distance/avx512-inl.h +0 -490
  356. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +0 -449
  357. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +0 -87
  358. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +0 -524
  359. data/vendor/faiss/faiss/utils/simdlib.h +0 -42
  360. data/vendor/faiss/faiss/utils/simdlib_avx512.h +0 -365
  361. /data/ext/faiss/{utils_rb.h → utils.h} +0 -0
@@ -13,10 +13,10 @@
13
13
  #include <faiss/impl/simd_dispatch.h>
14
14
  #include <faiss/utils/Heap.h>
15
15
  #include <faiss/utils/distances.h>
16
- #include <faiss/utils/simdlib.h>
17
16
  #include <faiss/utils/utils.h>
18
17
 
19
- #include <faiss/utils/approx_topk/approx_topk.h>
18
+ #include <faiss/impl/approx_topk/approx_topk.h>
19
+ #include <faiss/impl/approx_topk/rq_beam_search_tab.h>
20
20
 
21
21
  extern "C" {
22
22
 
@@ -39,190 +39,6 @@ int sgemm_(
39
39
 
40
40
  namespace faiss {
41
41
 
42
- /********************************************************************
43
- * Basic routines
44
- ********************************************************************/
45
-
46
- namespace {
47
-
48
- template <size_t M, size_t NK>
49
- void accum_and_store_tab(
50
- const size_t m_offset,
51
- const float* const __restrict codebook_cross_norms,
52
- const uint64_t* const __restrict codebook_offsets,
53
- const int32_t* const __restrict codes_i,
54
- const size_t b,
55
- const size_t ldc,
56
- const size_t K,
57
- float* const __restrict output) {
58
- // load pointers into registers
59
- const float* cbs[M];
60
- for (size_t ij = 0; ij < M; ij++) {
61
- const size_t code = static_cast<size_t>(codes_i[b * m_offset + ij]);
62
- cbs[ij] = &codebook_cross_norms[(codebook_offsets[ij] + code) * ldc];
63
- }
64
-
65
- // do accumulation in registers using SIMD.
66
- // It is possible that compiler may be smart enough so that
67
- // this manual SIMD unrolling might be unneeded.
68
- #if defined(__AVX2__) || defined(__aarch64__)
69
- const size_t K8 = (K / (8 * NK)) * (8 * NK);
70
-
71
- // process in chunks of size (8 * NK) floats
72
- for (size_t kk = 0; kk < K8; kk += 8 * NK) {
73
- simd8float32 regs[NK];
74
- for (size_t ik = 0; ik < NK; ik++) {
75
- regs[ik].loadu(cbs[0] + kk + ik * 8);
76
- }
77
-
78
- for (size_t ij = 1; ij < M; ij++) {
79
- for (size_t ik = 0; ik < NK; ik++) {
80
- regs[ik] += simd8float32(cbs[ij] + kk + ik * 8);
81
- }
82
- }
83
-
84
- // write the result
85
- for (size_t ik = 0; ik < NK; ik++) {
86
- regs[ik].storeu(output + kk + ik * 8);
87
- }
88
- }
89
- #else
90
- const size_t K8 = 0;
91
- #endif
92
-
93
- // process leftovers
94
- for (size_t kk = K8; kk < K; kk++) {
95
- float reg = cbs[0][kk];
96
- for (size_t ij = 1; ij < M; ij++) {
97
- reg += cbs[ij][kk];
98
- }
99
- output[kk] = reg;
100
- }
101
- }
102
-
103
- template <size_t M, size_t NK>
104
- void accum_and_add_tab(
105
- const size_t m_offset,
106
- const float* const __restrict codebook_cross_norms,
107
- const uint64_t* const __restrict codebook_offsets,
108
- const int32_t* const __restrict codes_i,
109
- const size_t b,
110
- const size_t ldc,
111
- const size_t K,
112
- float* const __restrict output) {
113
- // load pointers into registers
114
- const float* cbs[M];
115
- for (size_t ij = 0; ij < M; ij++) {
116
- const size_t code = static_cast<size_t>(codes_i[b * m_offset + ij]);
117
- cbs[ij] = &codebook_cross_norms[(codebook_offsets[ij] + code) * ldc];
118
- }
119
-
120
- // do accumulation in registers using SIMD.
121
- // It is possible that compiler may be smart enough so that
122
- // this manual SIMD unrolling might be unneeded.
123
- #if defined(__AVX2__) || defined(__aarch64__)
124
- const size_t K8 = (K / (8 * NK)) * (8 * NK);
125
-
126
- // process in chunks of size (8 * NK) floats
127
- for (size_t kk = 0; kk < K8; kk += 8 * NK) {
128
- simd8float32 regs[NK];
129
- for (size_t ik = 0; ik < NK; ik++) {
130
- regs[ik].loadu(cbs[0] + kk + ik * 8);
131
- }
132
-
133
- for (size_t ij = 1; ij < M; ij++) {
134
- for (size_t ik = 0; ik < NK; ik++) {
135
- regs[ik] += simd8float32(cbs[ij] + kk + ik * 8);
136
- }
137
- }
138
-
139
- // write the result
140
- for (size_t ik = 0; ik < NK; ik++) {
141
- simd8float32 existing(output + kk + ik * 8);
142
- existing += regs[ik];
143
- existing.storeu(output + kk + ik * 8);
144
- }
145
- }
146
- #else
147
- const size_t K8 = 0;
148
- #endif
149
-
150
- // process leftovers
151
- for (size_t kk = K8; kk < K; kk++) {
152
- float reg = cbs[0][kk];
153
- for (size_t ij = 1; ij < M; ij++) {
154
- reg += cbs[ij][kk];
155
- }
156
- output[kk] += reg;
157
- }
158
- }
159
-
160
- template <size_t M, size_t NK>
161
- void accum_and_finalize_tab(
162
- const float* const __restrict codebook_cross_norms,
163
- const uint64_t* const __restrict codebook_offsets,
164
- const int32_t* const __restrict codes_i,
165
- const size_t b,
166
- const size_t ldc,
167
- const size_t K,
168
- const float* const __restrict distances_i,
169
- const float* const __restrict cd_common,
170
- float* const __restrict output) {
171
- // load pointers into registers
172
- const float* cbs[M];
173
- for (size_t ij = 0; ij < M; ij++) {
174
- const size_t code = static_cast<size_t>(codes_i[b * M + ij]);
175
- cbs[ij] = &codebook_cross_norms[(codebook_offsets[ij] + code) * ldc];
176
- }
177
-
178
- // do accumulation in registers using SIMD.
179
- // It is possible that compiler may be smart enough so that
180
- // this manual SIMD unrolling might be unneeded.
181
- #if defined(__AVX2__) || defined(__aarch64__)
182
- const size_t K8 = (K / (8 * NK)) * (8 * NK);
183
-
184
- // process in chunks of size (8 * NK) floats
185
- for (size_t kk = 0; kk < K8; kk += 8 * NK) {
186
- simd8float32 regs[NK];
187
- for (size_t ik = 0; ik < NK; ik++) {
188
- regs[ik].loadu(cbs[0] + kk + ik * 8);
189
- }
190
-
191
- for (size_t ij = 1; ij < M; ij++) {
192
- for (size_t ik = 0; ik < NK; ik++) {
193
- regs[ik] += simd8float32(cbs[ij] + kk + ik * 8);
194
- }
195
- }
196
-
197
- simd8float32 two(2.0f);
198
- for (size_t ik = 0; ik < NK; ik++) {
199
- // cent_distances[b * K + k] = distances_i[b] + cd_common[k]
200
- // + 2 * dp[k];
201
-
202
- simd8float32 common_v(cd_common + kk + ik * 8);
203
- common_v = fmadd(two, regs[ik], common_v);
204
-
205
- common_v += simd8float32(distances_i[b]);
206
- common_v.storeu(output + b * K + kk + ik * 8);
207
- }
208
- }
209
- #else
210
- const size_t K8 = 0;
211
- #endif
212
-
213
- // process leftovers
214
- for (size_t kk = K8; kk < K; kk++) {
215
- float reg = cbs[0][kk];
216
- for (size_t ij = 1; ij < M; ij++) {
217
- reg += cbs[ij][kk];
218
- }
219
-
220
- output[b * K + kk] = distances_i[b] + cd_common[kk] + 2 * reg;
221
- }
222
- }
223
-
224
- } // anonymous namespace
225
-
226
42
  /********************************************************************
227
43
  * Single encoding step
228
44
  ********************************************************************/
@@ -250,12 +66,12 @@ void beam_search_encode_step(
250
66
 
251
67
  if (assign_index) {
252
68
  // search beam_size distances per query
253
- FAISS_THROW_IF_NOT(assign_index->d == d);
69
+ FAISS_THROW_IF_NOT(assign_index->d == static_cast<int>(d));
254
70
  cent_distances.resize(n * beam_size * new_beam_size);
255
71
  cent_ids.resize(n * beam_size * new_beam_size);
256
72
  if (assign_index->ntotal != 0) {
257
73
  // then we assume the codebooks are already added to the index
258
- FAISS_THROW_IF_NOT(assign_index->ntotal == K);
74
+ FAISS_THROW_IF_NOT(assign_index->ntotal == static_cast<idx_t>(K));
259
75
  } else {
260
76
  assign_index->add(K, cent);
261
77
  }
@@ -276,110 +92,259 @@ void beam_search_encode_step(
276
92
  }
277
93
  InterruptCallback::check();
278
94
 
95
+ // Resolve SIMD level once, not per iteration of the n-parallel loop.
96
+ with_simd_level_256bit([&]<SIMDLevel SL>() {
279
97
  #pragma omp parallel for if (n > 100)
280
- for (int64_t i = 0; i < n; i++) {
281
- const int32_t* codes_i = codes + i * m * beam_size;
282
- int32_t* new_codes_i = new_codes + i * (m + 1) * new_beam_size;
283
- const float* residuals_i = residuals + i * d * beam_size;
284
- float* new_residuals_i = new_residuals + i * d * new_beam_size;
285
-
286
- float* new_distances_i = new_distances + i * new_beam_size;
287
- using C = CMax<float, int>;
288
-
289
- if (assign_index) {
290
- const float* cent_distances_i =
291
- cent_distances.data() + i * beam_size * new_beam_size;
292
- const idx_t* cent_ids_i =
293
- cent_ids.data() + i * beam_size * new_beam_size;
294
-
295
- // here we could be a tad more efficient by merging sorted arrays
296
- for (int j = 0; j < new_beam_size; j++) {
297
- new_distances_i[j] = C::neutral();
298
- }
299
- std::vector<int> perm(new_beam_size, -1);
300
- heap_addn<C>(
301
- new_beam_size,
302
- new_distances_i,
303
- perm.data(),
304
- cent_distances_i,
305
- nullptr,
306
- beam_size * new_beam_size);
307
- heap_reorder<C>(new_beam_size, new_distances_i, perm.data());
98
+ for (int64_t i = 0; i < static_cast<int64_t>(n); i++) {
99
+ const int32_t* codes_i = codes + i * m * beam_size;
100
+ int32_t* new_codes_i = new_codes + i * (m + 1) * new_beam_size;
101
+ const float* residuals_i = residuals + i * d * beam_size;
102
+ float* new_residuals_i = new_residuals + i * d * new_beam_size;
103
+
104
+ float* new_distances_i = new_distances + i * new_beam_size;
105
+ using C = CMax<float, int>;
106
+
107
+ if (assign_index) {
108
+ const float* cent_distances_i =
109
+ cent_distances.data() + i * beam_size * new_beam_size;
110
+ const idx_t* cent_ids_i =
111
+ cent_ids.data() + i * beam_size * new_beam_size;
112
+
113
+ // here we could be a tad more efficient by merging sorted
114
+ // arrays
115
+ for (size_t j = 0; j < new_beam_size; j++) {
116
+ new_distances_i[j] = C::neutral();
117
+ }
118
+ std::vector<int> perm(new_beam_size, -1);
119
+ heap_addn<C>(
120
+ new_beam_size,
121
+ new_distances_i,
122
+ perm.data(),
123
+ cent_distances_i,
124
+ nullptr,
125
+ beam_size * new_beam_size);
126
+ heap_reorder<C>(new_beam_size, new_distances_i, perm.data());
127
+
128
+ for (size_t j = 0; j < new_beam_size; j++) {
129
+ int js = perm[j] / new_beam_size;
130
+ int ls = cent_ids_i[perm[j]];
131
+ if (m > 0) {
132
+ memcpy(new_codes_i,
133
+ codes_i + js * m,
134
+ sizeof(*codes) * m);
135
+ }
136
+ new_codes_i[m] = ls;
137
+ new_codes_i += m + 1;
138
+ fvec_sub(
139
+ d,
140
+ residuals_i + js * d,
141
+ cent + ls * d,
142
+ new_residuals_i);
143
+ new_residuals_i += d;
144
+ }
308
145
 
309
- for (int j = 0; j < new_beam_size; j++) {
310
- int js = perm[j] / new_beam_size;
311
- int ls = cent_ids_i[perm[j]];
312
- if (m > 0) {
313
- memcpy(new_codes_i, codes_i + js * m, sizeof(*codes) * m);
146
+ } else {
147
+ const float* cent_distances_i =
148
+ cent_distances.data() + i * beam_size * K;
149
+ // then we have to select the best results
150
+ for (size_t j = 0; j < new_beam_size; j++) {
151
+ new_distances_i[j] = C::neutral();
314
152
  }
315
- new_codes_i[m] = ls;
316
- new_codes_i += m + 1;
317
- fvec_sub(
318
- d,
319
- residuals_i + js * d,
320
- cent + ls * d,
321
- new_residuals_i);
322
- new_residuals_i += d;
323
- }
153
+ std::vector<int> perm(new_beam_size, -1);
324
154
 
325
- } else {
326
- const float* cent_distances_i =
327
- cent_distances.data() + i * beam_size * K;
328
- // then we have to select the best results
329
- for (int j = 0; j < new_beam_size; j++) {
330
- new_distances_i[j] = C::neutral();
155
+ approx_topk_by_mode<SL>(
156
+ approx_topk_mode,
157
+ beam_size,
158
+ K,
159
+ cent_distances_i,
160
+ new_beam_size,
161
+ new_distances_i,
162
+ perm.data());
163
+ heap_reorder<C>(new_beam_size, new_distances_i, perm.data());
164
+
165
+ for (size_t j = 0; j < new_beam_size; j++) {
166
+ int js = perm[j] / K;
167
+ int ls = perm[j] % K;
168
+ if (m > 0) {
169
+ memcpy(new_codes_i,
170
+ codes_i + js * m,
171
+ sizeof(*codes) * m);
172
+ }
173
+ new_codes_i[m] = ls;
174
+ new_codes_i += m + 1;
175
+ fvec_sub(
176
+ d,
177
+ residuals_i + js * d,
178
+ cent + ls * d,
179
+ new_residuals_i);
180
+ new_residuals_i += d;
181
+ }
331
182
  }
332
- std::vector<int> perm(new_beam_size, -1);
183
+ }
184
+ });
185
+ }
186
+
187
+ // exposed in the faiss namespace
188
+
189
+ namespace {
190
+
191
+ // Baseline (scalar) implementation for computing cent_distances.
192
+ // Accumulates codebook cross-norms via fvec_add into a temporary buffer.
193
+ // Its primary flaw is that it writes too much to the temporary buffer dp.
194
+ // This code is kept because it is easy to understand what the optimized
195
+ // SIMD version (compute_cent_distances_simd) optimizes exactly.
196
+ void compute_cent_distances_baseline(
197
+ size_t K,
198
+ size_t beam_size,
199
+ const float* codebook_cross_norms,
200
+ size_t ldc,
201
+ const uint64_t* codebook_offsets,
202
+ size_t m,
203
+ const int32_t* codes_i,
204
+ const float* distances_i,
205
+ const float* cd_common,
206
+ float* cent_distances) {
207
+ for (size_t b = 0; b < beam_size; b++) {
208
+ std::vector<float> dp(K);
209
+ for (size_t m1 = 0; m1 < m; m1++) {
210
+ size_t c = codes_i[b * m + m1];
211
+ const float* cb =
212
+ &codebook_cross_norms[(codebook_offsets[m1] + c) * ldc];
213
+ fvec_add(K, cb, dp.data(), dp.data());
214
+ }
215
+ for (size_t k = 0; k < K; k++) {
216
+ cent_distances[b * K + k] =
217
+ distances_i[b] + cd_common[k] + 2 * dp[k];
218
+ }
219
+ }
220
+ }
221
+
222
+ // SIMD-optimized implementation for computing cent_distances.
223
+ // Uses accum_and_finalize_tab / accum_and_store_tab / accum_and_add_tab
224
+ // to accumulate codebook cross-norms in SIMD registers.
225
+ template <SIMDLevel SL>
226
+ void compute_cent_distances_simd(
227
+ size_t K,
228
+ size_t beam_size,
229
+ const float* codebook_cross_norms,
230
+ size_t ldc,
231
+ const uint64_t* codebook_offsets,
232
+ size_t m,
233
+ const int32_t* codes_i,
234
+ const float* distances_i,
235
+ const float* cd_common,
236
+ float* cent_distances) {
237
+ auto do_finalize = [&]<size_t NK>() {
238
+ for (size_t b = 0; b < beam_size; b++) {
239
+ accum_and_finalize_tab<NK, 4, SL>(
240
+ codebook_cross_norms,
241
+ codebook_offsets,
242
+ codes_i,
243
+ b,
244
+ ldc,
245
+ K,
246
+ distances_i,
247
+ cd_common,
248
+ cent_distances);
249
+ }
250
+ };
333
251
 
334
- #define HANDLE_APPROX(NB, BD) \
335
- case ApproxTopK_mode_t::APPROX_TOPK_BUCKETS_B##NB##_D##BD: \
336
- HeapWithBuckets<C, NB, BD>::bs_addn( \
337
- beam_size, \
338
- K, \
339
- cent_distances_i, \
340
- new_beam_size, \
341
- new_distances_i, \
342
- perm.data()); \
343
- break;
344
-
345
- switch (approx_topk_mode) {
346
- HANDLE_APPROX(8, 3)
347
- HANDLE_APPROX(8, 2)
348
- HANDLE_APPROX(16, 2)
349
- HANDLE_APPROX(32, 2)
350
- default:
351
- heap_addn<C>(
352
- new_beam_size,
353
- new_distances_i,
354
- perm.data(),
355
- cent_distances_i,
356
- nullptr,
357
- beam_size * K);
252
+ switch (m) {
253
+ case 0:
254
+ for (size_t b = 0; b < beam_size; b++) {
255
+ for (size_t k = 0; k < K; k++) {
256
+ cent_distances[b * K + k] = distances_i[b] + cd_common[k];
257
+ }
358
258
  }
359
- heap_reorder<C>(new_beam_size, new_distances_i, perm.data());
259
+ break;
260
+ case 1:
261
+ do_finalize.template operator()<1>();
262
+ break;
263
+ case 2:
264
+ do_finalize.template operator()<2>();
265
+ break;
266
+ case 3:
267
+ do_finalize.template operator()<3>();
268
+ break;
269
+ case 4:
270
+ do_finalize.template operator()<4>();
271
+ break;
272
+ case 5:
273
+ do_finalize.template operator()<5>();
274
+ break;
275
+ case 6:
276
+ do_finalize.template operator()<6>();
277
+ break;
278
+ case 7:
279
+ do_finalize.template operator()<7>();
280
+ break;
281
+ default: {
282
+ // m >= 8: accumulate in chunks of 8 into a temporary buffer.
283
+ std::vector<float> dp(K);
360
284
 
361
- #undef HANDLE_APPROX
285
+ for (size_t b = 0; b < beam_size; b++) {
286
+ accum_and_store_tab<8, 4, SL>(
287
+ m,
288
+ codebook_cross_norms,
289
+ codebook_offsets,
290
+ codes_i,
291
+ b,
292
+ ldc,
293
+ K,
294
+ dp.data());
295
+
296
+ for (size_t im = 8; im < ((m + 7) / 8) * 8; im += 8) {
297
+ size_t m_left = std::min(m - im, size_t(8));
298
+ auto do_add = [&]<size_t NK2>() {
299
+ accum_and_add_tab<NK2, 4, SL>(
300
+ m,
301
+ codebook_cross_norms,
302
+ codebook_offsets + im,
303
+ codes_i + im,
304
+ b,
305
+ ldc,
306
+ K,
307
+ dp.data());
308
+ };
309
+ switch (m_left) {
310
+ case 1:
311
+ do_add.template operator()<1>();
312
+ break;
313
+ case 2:
314
+ do_add.template operator()<2>();
315
+ break;
316
+ case 3:
317
+ do_add.template operator()<3>();
318
+ break;
319
+ case 4:
320
+ do_add.template operator()<4>();
321
+ break;
322
+ case 5:
323
+ do_add.template operator()<5>();
324
+ break;
325
+ case 6:
326
+ do_add.template operator()<6>();
327
+ break;
328
+ case 7:
329
+ do_add.template operator()<7>();
330
+ break;
331
+ case 8:
332
+ do_add.template operator()<8>();
333
+ break;
334
+ }
335
+ }
362
336
 
363
- for (int j = 0; j < new_beam_size; j++) {
364
- int js = perm[j] / K;
365
- int ls = perm[j] % K;
366
- if (m > 0) {
367
- memcpy(new_codes_i, codes_i + js * m, sizeof(*codes) * m);
337
+ for (size_t k = 0; k < K; k++) {
338
+ cent_distances[b * K + k] =
339
+ distances_i[b] + cd_common[k] + 2 * dp[k];
368
340
  }
369
- new_codes_i[m] = ls;
370
- new_codes_i += m + 1;
371
- fvec_sub(
372
- d,
373
- residuals_i + js * d,
374
- cent + ls * d,
375
- new_residuals_i);
376
- new_residuals_i += d;
377
341
  }
378
342
  }
379
343
  }
380
344
  }
381
345
 
382
- // exposed in the faiss namespace
346
+ } // anonymous namespace
347
+
383
348
  void beam_search_encode_step_tab(
384
349
  size_t K,
385
350
  size_t n,
@@ -400,211 +365,80 @@ void beam_search_encode_step_tab(
400
365
  {
401
366
  FAISS_THROW_IF_NOT(ldc >= K);
402
367
 
368
+ // Resolve SIMD level once, not per iteration of the n-parallel loop.
369
+ with_simd_level_256bit([&]<SIMDLevel SL>() {
403
370
  #pragma omp parallel for if (n > 100) schedule(dynamic)
404
- for (int64_t i = 0; i < n; i++) {
405
- std::vector<float> cent_distances(beam_size * K);
406
- std::vector<float> cd_common(K);
371
+ for (int64_t i = 0; i < static_cast<int64_t>(n); i++) {
372
+ std::vector<float> cent_distances(beam_size * K);
373
+ std::vector<float> cd_common(K);
407
374
 
408
- const int32_t* codes_i = codes + i * m * beam_size;
409
- const float* query_cp_i = query_cp + i * ldqc;
410
- const float* distances_i = distances + i * beam_size;
375
+ const int32_t* codes_i = codes + i * m * beam_size;
376
+ const float* query_cp_i = query_cp + i * ldqc;
377
+ const float* distances_i = distances + i * beam_size;
411
378
 
412
- for (size_t k = 0; k < K; k++) {
413
- cd_common[k] = cent_norms_i[k] - 2 * query_cp_i[k];
414
- }
415
-
416
- bool use_baseline_implementation = false;
417
-
418
- // This is the baseline implementation. Its primary flaw
419
- // that it writes way too many info to the temporary buffer
420
- // called dp.
421
- //
422
- // This baseline code is kept intentionally because it is easy to
423
- // understand what an optimized version optimizes exactly.
424
- //
425
- if (use_baseline_implementation) {
426
- for (size_t b = 0; b < beam_size; b++) {
427
- std::vector<float> dp(K);
428
-
429
- for (size_t m1 = 0; m1 < m; m1++) {
430
- size_t c = codes_i[b * m + m1];
431
- const float* cb =
432
- &codebook_cross_norms
433
- [(codebook_offsets[m1] + c) * ldc];
434
- fvec_add(K, cb, dp.data(), dp.data());
435
- }
436
-
437
- for (size_t k = 0; k < K; k++) {
438
- cent_distances[b * K + k] =
439
- distances_i[b] + cd_common[k] + 2 * dp[k];
440
- }
379
+ for (size_t k = 0; k < K; k++) {
380
+ cd_common[k] = cent_norms_i[k] - 2 * query_cp_i[k];
441
381
  }
442
382
 
443
- } else {
444
- // An optimized implementation that avoids using a temporary buffer
445
- // and does the accumulation in registers.
446
-
447
- // Compute a sum of NK AQ codes.
448
- #define ACCUM_AND_FINALIZE_TAB(NK) \
449
- case NK: \
450
- for (size_t b = 0; b < beam_size; b++) { \
451
- accum_and_finalize_tab<NK, 4>( \
452
- codebook_cross_norms, \
453
- codebook_offsets, \
454
- codes_i, \
455
- b, \
456
- ldc, \
457
- K, \
458
- distances_i, \
459
- cd_common.data(), \
460
- cent_distances.data()); \
461
- } \
462
- break;
463
-
464
- // this version contains many switch-case scenarios, but
465
- // they won't affect branch predictor.
466
- switch (m) {
467
- case 0:
468
- // trivial case
469
- for (size_t b = 0; b < beam_size; b++) {
470
- for (size_t k = 0; k < K; k++) {
471
- cent_distances[b * K + k] =
472
- distances_i[b] + cd_common[k];
473
- }
474
- }
475
- break;
476
-
477
- ACCUM_AND_FINALIZE_TAB(1)
478
- ACCUM_AND_FINALIZE_TAB(2)
479
- ACCUM_AND_FINALIZE_TAB(3)
480
- ACCUM_AND_FINALIZE_TAB(4)
481
- ACCUM_AND_FINALIZE_TAB(5)
482
- ACCUM_AND_FINALIZE_TAB(6)
483
- ACCUM_AND_FINALIZE_TAB(7)
484
-
485
- default: {
486
- // m >= 8 case.
487
-
488
- // A temporary buffer has to be used due to the lack of
489
- // registers. But we'll try to accumulate up to 8 AQ codes
490
- // in registers and issue a single write operation to the
491
- // buffer, while the baseline does no accumulation. So, the
492
- // number of write operations to the temporary buffer is
493
- // reduced 8x.
494
-
495
- // allocate a temporary buffer
496
- std::vector<float> dp(K);
497
-
498
- for (size_t b = 0; b < beam_size; b++) {
499
- // Initialize it. Compute a sum of first 8 AQ codes
500
- // because m >= 8 .
501
- accum_and_store_tab<8, 4>(
502
- m,
503
- codebook_cross_norms,
504
- codebook_offsets,
505
- codes_i,
506
- b,
507
- ldc,
508
- K,
509
- dp.data());
510
-
511
- #define ACCUM_AND_ADD_TAB(NK) \
512
- case NK: \
513
- accum_and_add_tab<NK, 4>( \
514
- m, \
515
- codebook_cross_norms, \
516
- codebook_offsets + im, \
517
- codes_i + im, \
518
- b, \
519
- ldc, \
520
- K, \
521
- dp.data()); \
522
- break;
523
-
524
- // accumulate up to 8 additional AQ codes into
525
- // a temporary buffer
526
- for (size_t im = 8; im < ((m + 7) / 8) * 8; im += 8) {
527
- size_t m_left = m - im;
528
- if (m_left > 8) {
529
- m_left = 8;
530
- }
531
-
532
- switch (m_left) {
533
- ACCUM_AND_ADD_TAB(1)
534
- ACCUM_AND_ADD_TAB(2)
535
- ACCUM_AND_ADD_TAB(3)
536
- ACCUM_AND_ADD_TAB(4)
537
- ACCUM_AND_ADD_TAB(5)
538
- ACCUM_AND_ADD_TAB(6)
539
- ACCUM_AND_ADD_TAB(7)
540
- ACCUM_AND_ADD_TAB(8)
541
- }
542
- }
543
-
544
- // done. finalize the result
545
- for (size_t k = 0; k < K; k++) {
546
- cent_distances[b * K + k] =
547
- distances_i[b] + cd_common[k] + 2 * dp[k];
548
- }
549
- }
550
- }
383
+ if constexpr (SL == SIMDLevel::NONE) {
384
+ compute_cent_distances_baseline(
385
+ K,
386
+ beam_size,
387
+ codebook_cross_norms,
388
+ ldc,
389
+ codebook_offsets,
390
+ m,
391
+ codes_i,
392
+ distances_i,
393
+ cd_common.data(),
394
+ cent_distances.data());
395
+ } else {
396
+ compute_cent_distances_simd<SL>(
397
+ K,
398
+ beam_size,
399
+ codebook_cross_norms,
400
+ ldc,
401
+ codebook_offsets,
402
+ m,
403
+ codes_i,
404
+ distances_i,
405
+ cd_common.data(),
406
+ cent_distances.data());
551
407
  }
552
408
 
553
- // the optimized implementation ends here
554
- }
555
- using C = CMax<float, int>;
556
- int32_t* new_codes_i = new_codes + i * (m + 1) * new_beam_size;
557
- float* new_distances_i = new_distances + i * new_beam_size;
409
+ using C = CMax<float, int>;
410
+ int32_t* new_codes_i = new_codes + i * (m + 1) * new_beam_size;
411
+ float* new_distances_i = new_distances + i * new_beam_size;
558
412
 
559
- const float* cent_distances_i = cent_distances.data();
413
+ const float* cent_distances_i = cent_distances.data();
560
414
 
561
- // then we have to select the best results
562
- for (int j = 0; j < new_beam_size; j++) {
563
- new_distances_i[j] = C::neutral();
564
- }
565
- std::vector<int> perm(new_beam_size, -1);
566
-
567
- #define HANDLE_APPROX(NB, BD) \
568
- case ApproxTopK_mode_t::APPROX_TOPK_BUCKETS_B##NB##_D##BD: \
569
- HeapWithBuckets<C, NB, BD>::bs_addn( \
570
- beam_size, \
571
- K, \
572
- cent_distances_i, \
573
- new_beam_size, \
574
- new_distances_i, \
575
- perm.data()); \
576
- break;
577
-
578
- switch (approx_topk_mode) {
579
- HANDLE_APPROX(8, 3)
580
- HANDLE_APPROX(8, 2)
581
- HANDLE_APPROX(16, 2)
582
- HANDLE_APPROX(32, 2)
583
- default:
584
- heap_addn<C>(
585
- new_beam_size,
586
- new_distances_i,
587
- perm.data(),
588
- cent_distances_i,
589
- nullptr,
590
- beam_size * K);
591
- break;
592
- }
593
-
594
- heap_reorder<C>(new_beam_size, new_distances_i, perm.data());
415
+ // then we have to select the best results
416
+ for (size_t j = 0; j < new_beam_size; j++) {
417
+ new_distances_i[j] = C::neutral();
418
+ }
419
+ std::vector<int> perm(new_beam_size, -1);
595
420
 
596
- #undef HANDLE_APPROX
421
+ approx_topk_by_mode<SL>(
422
+ approx_topk_mode,
423
+ beam_size,
424
+ K,
425
+ cent_distances_i,
426
+ new_beam_size,
427
+ new_distances_i,
428
+ perm.data());
429
+ heap_reorder<C>(new_beam_size, new_distances_i, perm.data());
597
430
 
598
- for (int j = 0; j < new_beam_size; j++) {
599
- int js = perm[j] / K;
600
- int ls = perm[j] % K;
601
- if (m > 0) {
602
- memcpy(new_codes_i, codes_i + js * m, sizeof(*codes) * m);
431
+ for (size_t j = 0; j < new_beam_size; j++) {
432
+ int js = perm[j] / K;
433
+ int ls = perm[j] % K;
434
+ if (m > 0) {
435
+ memcpy(new_codes_i, codes_i + js * m, sizeof(*codes) * m);
436
+ }
437
+ new_codes_i[m] = ls;
438
+ new_codes_i += m + 1;
603
439
  }
604
- new_codes_i[m] = ls;
605
- new_codes_i += m + 1;
606
440
  }
607
- }
441
+ });
608
442
  }
609
443
 
610
444
  /********************************************************************
@@ -631,7 +465,7 @@ void refine_beam_mp(
631
465
  int max_beam_size = 0;
632
466
  {
633
467
  int tmp_beam_size = cur_beam_size;
634
- for (int m = 0; m < rq.M; m++) {
468
+ for (size_t m = 0; m < rq.M; m++) {
635
469
  int K = 1 << rq.nbits[m];
636
470
  int new_beam_size = std::min(tmp_beam_size * K, out_beam_size);
637
471
  tmp_beam_size = new_beam_size;
@@ -672,7 +506,7 @@ void refine_beam_mp(
672
506
  size_t distances_size = 0;
673
507
  size_t residuals_size = 0;
674
508
 
675
- for (int m = 0; m < rq.M; m++) {
509
+ for (size_t m = 0; m < rq.M; m++) {
676
510
  int K = 1 << rq.nbits[m];
677
511
 
678
512
  const float* __restrict codebooks_m =
@@ -711,14 +545,14 @@ void refine_beam_mp(
711
545
 
712
546
  if (rq.verbose) {
713
547
  float sum_distances = 0;
714
- for (int j = 0; j < distances_size; j++) {
548
+ for (size_t j = 0; j < distances_size; j++) {
715
549
  sum_distances += pool.distances[j];
716
550
  }
717
551
 
718
552
  printf("[%.3f s] encode stage %d, %d bits, "
719
553
  "total error %g, beam_size %d\n",
720
554
  (getmillisecs() - t0) / 1000,
721
- m,
555
+ int(m),
722
556
  int(rq.nbits[m]),
723
557
  sum_distances,
724
558
  cur_beam_size);
@@ -757,7 +591,7 @@ void refine_beam_LUT_mp(
757
591
  int max_beam_size = 0;
758
592
  {
759
593
  int tmp_beam_size = beam_size;
760
- for (int m = 0; m < rq.M; m++) {
594
+ for (size_t m = 0; m < rq.M; m++) {
761
595
  int K = 1 << rq.nbits[m];
762
596
  int new_beam_size = std::min(tmp_beam_size * K, out_beam_size);
763
597
  tmp_beam_size = new_beam_size;
@@ -790,7 +624,7 @@ void refine_beam_LUT_mp(
790
624
  size_t codes_size = 0;
791
625
  size_t distances_size = 0;
792
626
  size_t cross_ofs = 0;
793
- for (int m = 0; m < rq.M; m++) {
627
+ for (size_t m = 0; m < rq.M; m++) {
794
628
  int K = 1 << rq.nbits[m];
795
629
 
796
630
  // it is guaranteed that (new_beam_size <= max_beam_size)
@@ -826,13 +660,13 @@ void refine_beam_LUT_mp(
826
660
 
827
661
  if (rq.verbose) {
828
662
  float sum_distances = 0;
829
- for (int j = 0; j < distances_size; j++) {
663
+ for (size_t j = 0; j < distances_size; j++) {
830
664
  sum_distances += distances_ptr[j];
831
665
  }
832
666
  printf("[%.3f s] encode stage %d, %d bits, "
833
667
  "total error %g, beam_size %d\n",
834
668
  (getmillisecs() - t0) / 1000,
835
- m,
669
+ int(m),
836
670
  int(rq.nbits[m]),
837
671
  sum_distances,
838
672
  beam_size);