faiss 0.5.3 → 0.6.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (379)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +12 -0
  3. data/ext/faiss/ext.cpp +1 -1
  4. data/ext/faiss/extconf.rb +4 -4
  5. data/ext/faiss/index.cpp +63 -45
  6. data/ext/faiss/index_binary.cpp +37 -27
  7. data/ext/faiss/kmeans.cpp +9 -8
  8. data/ext/faiss/pca_matrix.cpp +9 -7
  9. data/ext/faiss/product_quantizer.cpp +13 -11
  10. data/ext/faiss/utils.cpp +4 -2
  11. data/ext/faiss/utils.h +4 -0
  12. data/lib/faiss/version.rb +1 -1
  13. data/lib/faiss.rb +1 -1
  14. data/vendor/faiss/faiss/AutoTune.cpp +214 -82
  15. data/vendor/faiss/faiss/AutoTune.h +14 -1
  16. data/vendor/faiss/faiss/Clustering.cpp +97 -249
  17. data/vendor/faiss/faiss/Clustering.h +18 -0
  18. data/vendor/faiss/faiss/IVFlib.cpp +67 -44
  19. data/vendor/faiss/faiss/Index.cpp +25 -12
  20. data/vendor/faiss/faiss/Index.h +26 -4
  21. data/vendor/faiss/faiss/Index2Layer.cpp +37 -53
  22. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +68 -61
  23. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +36 -34
  24. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +4 -1
  25. data/vendor/faiss/faiss/IndexBinary.cpp +6 -3
  26. data/vendor/faiss/faiss/IndexBinary.h +4 -4
  27. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +1 -1
  28. data/vendor/faiss/faiss/IndexBinaryFlat.h +1 -1
  29. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -4
  30. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +92 -95
  31. data/vendor/faiss/faiss/IndexBinaryHNSW.h +9 -3
  32. data/vendor/faiss/faiss/IndexBinaryHash.cpp +45 -236
  33. data/vendor/faiss/faiss/IndexBinaryHash.h +6 -6
  34. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +120 -414
  35. data/vendor/faiss/faiss/IndexFastScan.cpp +105 -129
  36. data/vendor/faiss/faiss/IndexFastScan.h +35 -24
  37. data/vendor/faiss/faiss/IndexFlat.cpp +216 -152
  38. data/vendor/faiss/faiss/IndexFlat.h +32 -14
  39. data/vendor/faiss/faiss/IndexFlatCodes.cpp +88 -41
  40. data/vendor/faiss/faiss/IndexFlatCodes.h +7 -1
  41. data/vendor/faiss/faiss/IndexHNSW.cpp +299 -187
  42. data/vendor/faiss/faiss/IndexHNSW.h +30 -14
  43. data/vendor/faiss/faiss/IndexIDMap.cpp +26 -22
  44. data/vendor/faiss/faiss/IndexIDMap.h +9 -7
  45. data/vendor/faiss/faiss/IndexIVF.cpp +535 -405
  46. data/vendor/faiss/faiss/IndexIVF.h +47 -16
  47. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +77 -74
  48. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +105 -99
  49. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +6 -3
  50. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +379 -249
  51. data/vendor/faiss/faiss/IndexIVFFastScan.h +65 -60
  52. data/vendor/faiss/faiss/IndexIVFFlat.cpp +41 -124
  53. data/vendor/faiss/faiss/IndexIVFFlat.h +32 -0
  54. data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +89 -138
  55. data/vendor/faiss/faiss/IndexIVFFlatPanorama.h +3 -1
  56. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +18 -15
  57. data/vendor/faiss/faiss/IndexIVFPQ.cpp +77 -907
  58. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +184 -122
  59. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +3 -0
  60. data/vendor/faiss/faiss/IndexIVFPQR.cpp +23 -18
  61. data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +59 -60
  62. data/vendor/faiss/faiss/IndexIVFRaBitQ.h +4 -3
  63. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +564 -416
  64. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +269 -111
  65. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +41 -127
  66. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +1 -1
  67. data/vendor/faiss/faiss/IndexLSH.cpp +44 -25
  68. data/vendor/faiss/faiss/IndexLattice.cpp +41 -36
  69. data/vendor/faiss/faiss/IndexNNDescent.cpp +37 -21
  70. data/vendor/faiss/faiss/IndexNNDescent.h +2 -2
  71. data/vendor/faiss/faiss/IndexNSG.cpp +40 -23
  72. data/vendor/faiss/faiss/IndexNSG.h +0 -2
  73. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +32 -12
  74. data/vendor/faiss/faiss/IndexPQ.cpp +129 -213
  75. data/vendor/faiss/faiss/IndexPQ.h +3 -2
  76. data/vendor/faiss/faiss/IndexPQFastScan.cpp +20 -14
  77. data/vendor/faiss/faiss/IndexPQFastScan.h +3 -0
  78. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -18
  79. data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
  80. data/vendor/faiss/faiss/IndexRaBitQ.cpp +31 -43
  81. data/vendor/faiss/faiss/IndexRaBitQ.h +4 -3
  82. data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +135 -317
  83. data/vendor/faiss/faiss/IndexRaBitQFastScan.h +192 -34
  84. data/vendor/faiss/faiss/IndexRefine.cpp +30 -55
  85. data/vendor/faiss/faiss/IndexRefine.h +4 -4
  86. data/vendor/faiss/faiss/IndexReplicas.cpp +6 -6
  87. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +15 -14
  88. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +1 -1
  89. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +82 -14
  90. data/vendor/faiss/faiss/IndexShards.cpp +13 -13
  91. data/vendor/faiss/faiss/IndexShardsIVF.cpp +21 -15
  92. data/vendor/faiss/faiss/MatrixStats.cpp +5 -4
  93. data/vendor/faiss/faiss/MetaIndexes.cpp +19 -17
  94. data/vendor/faiss/faiss/MetaIndexes.h +1 -1
  95. data/vendor/faiss/faiss/MetricType.h +29 -6
  96. data/vendor/faiss/faiss/SuperKMeans.cpp +656 -0
  97. data/vendor/faiss/faiss/SuperKMeans.h +97 -0
  98. data/vendor/faiss/faiss/VectorTransform.cpp +349 -141
  99. data/vendor/faiss/faiss/VectorTransform.h +39 -16
  100. data/vendor/faiss/faiss/build.cpp +23 -0
  101. data/vendor/faiss/faiss/build.h +15 -0
  102. data/vendor/faiss/faiss/clone_index.cpp +55 -51
  103. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +47 -47
  104. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +11 -0
  105. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +38 -38
  106. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +11 -0
  107. data/vendor/faiss/faiss/{cppcontrib/factory_tools.cpp → factory_tools.cpp} +6 -1
  108. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +1 -1
  109. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +6 -5
  110. data/vendor/faiss/faiss/gpu/GpuResources.h +1 -1
  111. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +9 -9
  112. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +4 -3
  113. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +46 -0
  114. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +56 -0
  115. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +78 -1
  116. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +72 -0
  117. data/vendor/faiss/faiss/gpu/test/TestUtils.h +23 -0
  118. data/vendor/faiss/faiss/gpu/utils/CuvsFilterConvert.h +1 -1
  119. data/vendor/faiss/faiss/gpu/utils/CuvsUtils.h +21 -10
  120. data/vendor/faiss/faiss/gpu_metal/GpuIndexFlat.h +22 -0
  121. data/vendor/faiss/faiss/gpu_metal/MetalCloner.h +35 -0
  122. data/vendor/faiss/faiss/gpu_metal/MetalFlatKernels.h +40 -0
  123. data/vendor/faiss/faiss/gpu_metal/MetalIndex.h +51 -0
  124. data/vendor/faiss/faiss/gpu_metal/MetalIndexFlat.h +65 -0
  125. data/vendor/faiss/faiss/gpu_metal/MetalKernels.h +66 -0
  126. data/vendor/faiss/faiss/gpu_metal/MetalResources.h +79 -0
  127. data/vendor/faiss/faiss/gpu_metal/StandardMetalResources.h +35 -0
  128. data/vendor/faiss/faiss/impl/AdSampling.cpp +103 -0
  129. data/vendor/faiss/faiss/impl/AdSampling.h +35 -0
  130. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +64 -34
  131. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +1 -0
  132. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +10 -9
  133. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +3 -28
  134. data/vendor/faiss/faiss/impl/ClusteringHelpers.cpp +244 -0
  135. data/vendor/faiss/faiss/impl/ClusteringHelpers.h +94 -0
  136. data/vendor/faiss/faiss/impl/ClusteringInitialization.cpp +367 -0
  137. data/vendor/faiss/faiss/impl/ClusteringInitialization.h +107 -0
  138. data/vendor/faiss/faiss/impl/CodePacker.cpp +7 -3
  139. data/vendor/faiss/faiss/impl/CodePacker.h +11 -3
  140. data/vendor/faiss/faiss/impl/CodePackerRaBitQ.cpp +83 -0
  141. data/vendor/faiss/faiss/impl/CodePackerRaBitQ.h +47 -0
  142. data/vendor/faiss/faiss/impl/DistanceComputer.h +8 -8
  143. data/vendor/faiss/faiss/impl/FaissAssert.h +64 -3
  144. data/vendor/faiss/faiss/impl/FaissException.h +50 -3
  145. data/vendor/faiss/faiss/impl/HNSW.cpp +117 -351
  146. data/vendor/faiss/faiss/impl/HNSW.h +21 -40
  147. data/vendor/faiss/faiss/impl/IDSelector.cpp +15 -11
  148. data/vendor/faiss/faiss/impl/IDSelector.h +8 -8
  149. data/vendor/faiss/faiss/impl/InvertedListScannerStats.h +26 -0
  150. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +114 -102
  151. data/vendor/faiss/faiss/impl/NNDescent.cpp +63 -26
  152. data/vendor/faiss/faiss/impl/NNDescent.h +6 -2
  153. data/vendor/faiss/faiss/impl/NSG.cpp +44 -26
  154. data/vendor/faiss/faiss/impl/NSG.h +20 -10
  155. data/vendor/faiss/faiss/impl/Panorama.cpp +76 -52
  156. data/vendor/faiss/faiss/impl/Panorama.h +265 -78
  157. data/vendor/faiss/faiss/impl/PdxLayout.cpp +93 -0
  158. data/vendor/faiss/faiss/impl/PdxLayout.h +41 -0
  159. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +62 -37
  160. data/vendor/faiss/faiss/impl/PolysemousTraining.h +3 -3
  161. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +35 -35
  162. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +21 -16
  163. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +99 -80
  164. data/vendor/faiss/faiss/impl/Quantizer.h +2 -2
  165. data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +135 -37
  166. data/vendor/faiss/faiss/impl/RaBitQUtils.h +148 -21
  167. data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +298 -301
  168. data/vendor/faiss/faiss/impl/RaBitQuantizer.h +3 -10
  169. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.cpp +15 -41
  170. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.h +0 -4
  171. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +40 -32
  172. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +1 -1
  173. data/vendor/faiss/faiss/impl/ResultHandler.h +218 -113
  174. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +119 -2362
  175. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +27 -3
  176. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +14 -11
  177. data/vendor/faiss/faiss/impl/VisitedTable.cpp +42 -0
  178. data/vendor/faiss/faiss/impl/VisitedTable.h +76 -0
  179. data/vendor/faiss/faiss/impl/approx_topk/approx_topk.h +276 -0
  180. data/vendor/faiss/faiss/impl/approx_topk/avx2.cpp +68 -0
  181. data/vendor/faiss/faiss/{utils → impl}/approx_topk/generic.h +15 -8
  182. data/vendor/faiss/faiss/impl/approx_topk/neon.cpp +68 -0
  183. data/vendor/faiss/faiss/impl/approx_topk/rq_beam_search_tab-inl.h +169 -0
  184. data/vendor/faiss/faiss/impl/approx_topk/rq_beam_search_tab.h +117 -0
  185. data/vendor/faiss/faiss/impl/approx_topk/simdlib256-inl.h +146 -0
  186. data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryHNSW_impl.h +73 -0
  187. data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryHash_impl.h +270 -0
  188. data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryIVF_impl.h +460 -0
  189. data/vendor/faiss/faiss/impl/binary_hamming/IndexIVFSpectralHash_impl.h +159 -0
  190. data/vendor/faiss/faiss/impl/binary_hamming/IndexPQ_impl.h +92 -0
  191. data/vendor/faiss/faiss/impl/binary_hamming/avx2.cpp +26 -0
  192. data/vendor/faiss/faiss/impl/binary_hamming/avx512.cpp +26 -0
  193. data/vendor/faiss/faiss/impl/binary_hamming/dispatch.h +143 -0
  194. data/vendor/faiss/faiss/impl/binary_hamming/neon.cpp +26 -0
  195. data/vendor/faiss/faiss/impl/binary_hamming/rvv.cpp +26 -0
  196. data/vendor/faiss/faiss/impl/expanded_scanners.h +163 -0
  197. data/vendor/faiss/faiss/impl/{FastScanDistancePostProcessing.h → fast_scan/FastScanDistancePostProcessing.h} +13 -6
  198. data/vendor/faiss/faiss/impl/{LookupTableScaler.h → fast_scan/LookupTableScaler.h} +16 -5
  199. data/vendor/faiss/faiss/impl/fast_scan/accumulate_loops.h +237 -0
  200. data/vendor/faiss/faiss/impl/fast_scan/accumulate_loops_512.h +185 -0
  201. data/vendor/faiss/faiss/impl/fast_scan/decompose_qbs.h +229 -0
  202. data/vendor/faiss/faiss/impl/fast_scan/dispatching.h +268 -0
  203. data/vendor/faiss/faiss/impl/{pq4_fast_scan.cpp → fast_scan/fast_scan.cpp} +176 -4
  204. data/vendor/faiss/faiss/impl/fast_scan/fast_scan.h +341 -0
  205. data/vendor/faiss/faiss/impl/fast_scan/impl-avx2.cpp +36 -0
  206. data/vendor/faiss/faiss/impl/fast_scan/impl-avx512.cpp +40 -0
  207. data/vendor/faiss/faiss/impl/fast_scan/impl-neon.cpp +120 -0
  208. data/vendor/faiss/faiss/impl/fast_scan/impl-riscv.cpp +104 -0
  209. data/vendor/faiss/faiss/impl/fast_scan/kernels_simd256.h +213 -0
  210. data/vendor/faiss/faiss/impl/{pq4_fast_scan_search_qbs.cpp → fast_scan/kernels_simd512.h} +26 -348
  211. data/vendor/faiss/faiss/impl/fast_scan/rabitq_dispatching.h +90 -0
  212. data/vendor/faiss/faiss/impl/fast_scan/rabitq_result_handler.h +108 -0
  213. data/vendor/faiss/faiss/impl/{simd_result_handlers.h → fast_scan/simd_result_handlers.h} +290 -142
  214. data/vendor/faiss/faiss/impl/hnsw/LockVector.cpp +54 -0
  215. data/vendor/faiss/faiss/impl/hnsw/LockVector.h +64 -0
  216. data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.cpp +91 -0
  217. data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.h +64 -0
  218. data/vendor/faiss/faiss/impl/hnsw/avx2.cpp +104 -0
  219. data/vendor/faiss/faiss/impl/hnsw/avx512.cpp +111 -0
  220. data/vendor/faiss/faiss/impl/index_read.cpp +1950 -505
  221. data/vendor/faiss/faiss/impl/index_read_utils.h +1 -2
  222. data/vendor/faiss/faiss/impl/index_write.cpp +112 -21
  223. data/vendor/faiss/faiss/impl/io.cpp +6 -6
  224. data/vendor/faiss/faiss/impl/io_macros.h +33 -16
  225. data/vendor/faiss/faiss/impl/kmeans1d.cpp +10 -10
  226. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +81 -40
  227. data/vendor/faiss/faiss/impl/lattice_Zn.h +6 -6
  228. data/vendor/faiss/faiss/impl/mapped_io.cpp +15 -8
  229. data/vendor/faiss/faiss/impl/platform_macros.h +11 -4
  230. data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQScanner_impl.h +549 -0
  231. data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQ_QueryTables.cpp +245 -0
  232. data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQ_QueryTables.h +105 -0
  233. data/vendor/faiss/faiss/impl/pq_code_distance/PQDistanceComputer_impl.h +106 -0
  234. data/vendor/faiss/faiss/impl/pq_code_distance/avx2.cpp +21 -0
  235. data/vendor/faiss/faiss/impl/pq_code_distance/avx512.cpp +21 -0
  236. data/vendor/faiss/faiss/impl/pq_code_distance/neon.cpp +21 -0
  237. data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx2.h → pq_code_distance/pq_code_distance-avx2.h} +43 -220
  238. data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx512.h → pq_code_distance/pq_code_distance-avx512.h} +25 -112
  239. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.cpp +59 -0
  240. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.h +96 -0
  241. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-inl.h +256 -0
  242. data/vendor/faiss/faiss/impl/{code_distance/code_distance-sve.h → pq_code_distance/pq_code_distance-sve.cpp} +57 -146
  243. data/vendor/faiss/faiss/impl/pq_code_distance/rvv.cpp +68 -0
  244. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +320 -483
  245. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +1 -1
  246. data/vendor/faiss/faiss/impl/scalar_quantizer/codecs.h +121 -0
  247. data/vendor/faiss/faiss/impl/scalar_quantizer/distance_computers.h +137 -0
  248. data/vendor/faiss/faiss/impl/scalar_quantizer/quantizers.h +371 -0
  249. data/vendor/faiss/faiss/impl/scalar_quantizer/scanners.h +190 -0
  250. data/vendor/faiss/faiss/impl/scalar_quantizer/similarities.h +94 -0
  251. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx2.cpp +603 -0
  252. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512.cpp +597 -0
  253. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-dispatch.h +388 -0
  254. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-neon.cpp +630 -0
  255. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-rvv.cpp +311 -0
  256. data/vendor/faiss/faiss/impl/scalar_quantizer/training.cpp +387 -0
  257. data/vendor/faiss/faiss/impl/scalar_quantizer/training.h +54 -0
  258. data/vendor/faiss/faiss/impl/simd_dispatch.h +173 -0
  259. data/vendor/faiss/faiss/impl/simdlib/simdlib.h +57 -0
  260. data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_avx2.h +274 -171
  261. data/vendor/faiss/faiss/impl/simdlib/simdlib_avx512.h +414 -0
  262. data/vendor/faiss/faiss/impl/simdlib/simdlib_dispatch.h +44 -0
  263. data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_emulated.h +231 -166
  264. data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_neon.h +275 -217
  265. data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_ppc64.h +201 -160
  266. data/vendor/faiss/faiss/impl/svs_io.cpp +12 -3
  267. data/vendor/faiss/faiss/impl/svs_io.h +8 -2
  268. data/vendor/faiss/faiss/index_factory.cpp +115 -28
  269. data/vendor/faiss/faiss/index_io.h +53 -3
  270. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +73 -20
  271. data/vendor/faiss/faiss/invlists/DirectMap.cpp +24 -14
  272. data/vendor/faiss/faiss/invlists/DirectMap.h +4 -3
  273. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +157 -73
  274. data/vendor/faiss/faiss/invlists/InvertedLists.h +86 -23
  275. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +4 -4
  276. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +14 -14
  277. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
  278. data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +9 -19
  279. data/vendor/faiss/faiss/svs/IndexSVSFlat.cpp +2 -2
  280. data/vendor/faiss/faiss/svs/IndexSVSFlat.h +2 -0
  281. data/vendor/faiss/faiss/svs/IndexSVSIVF.cpp +350 -0
  282. data/vendor/faiss/faiss/svs/IndexSVSIVF.h +128 -0
  283. data/vendor/faiss/faiss/svs/IndexSVSIVFLVQ.cpp +40 -0
  284. data/vendor/faiss/faiss/svs/IndexSVSIVFLVQ.h +43 -0
  285. data/vendor/faiss/faiss/svs/IndexSVSIVFLeanVec.cpp +225 -0
  286. data/vendor/faiss/faiss/svs/IndexSVSIVFLeanVec.h +71 -0
  287. data/vendor/faiss/faiss/svs/IndexSVSVamana.cpp +25 -1
  288. data/vendor/faiss/faiss/svs/IndexSVSVamana.h +19 -2
  289. data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.h +1 -1
  290. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +19 -2
  291. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +14 -0
  292. data/vendor/faiss/faiss/utils/Heap.cpp +56 -10
  293. data/vendor/faiss/faiss/utils/Heap.h +21 -0
  294. data/vendor/faiss/faiss/utils/NeuralNet.cpp +54 -40
  295. data/vendor/faiss/faiss/utils/NeuralNet.h +1 -1
  296. data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +10 -4
  297. data/vendor/faiss/faiss/utils/distances.cpp +507 -559
  298. data/vendor/faiss/faiss/utils/distances.h +118 -1
  299. data/vendor/faiss/faiss/utils/distances_dispatch.h +250 -0
  300. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +8 -7
  301. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +33 -14
  302. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +12 -1
  303. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +16 -293
  304. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based_neon.cpp +57 -0
  305. data/vendor/faiss/faiss/utils/distances_fused/simdlib_kernel-inl.h +290 -0
  306. data/vendor/faiss/faiss/utils/distances_simd.cpp +72 -3681
  307. data/vendor/faiss/faiss/utils/extra_distances.cpp +60 -102
  308. data/vendor/faiss/faiss/utils/extra_distances.h +79 -7
  309. data/vendor/faiss/faiss/utils/hamming-inl.h +13 -11
  310. data/vendor/faiss/faiss/utils/hamming.cpp +66 -517
  311. data/vendor/faiss/faiss/utils/hamming.h +92 -2
  312. data/vendor/faiss/faiss/utils/hamming_distance/common.h +287 -10
  313. data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx2.cpp +15 -0
  314. data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx512.cpp +15 -0
  315. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx2.h +142 -0
  316. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx512.h +234 -0
  317. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-generic.h +368 -0
  318. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-neon.h +322 -0
  319. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-rvv.h +39 -0
  320. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer.h +146 -0
  321. data/vendor/faiss/faiss/utils/hamming_distance/hamming_impl.h +481 -0
  322. data/vendor/faiss/faiss/utils/hamming_distance/hamming_neon.cpp +15 -0
  323. data/vendor/faiss/faiss/utils/hamming_distance/hamming_rvv.cpp +15 -0
  324. data/vendor/faiss/faiss/utils/partitioning.cpp +66 -987
  325. data/vendor/faiss/faiss/utils/partitioning.h +31 -0
  326. data/vendor/faiss/faiss/utils/popcount.h +29 -0
  327. data/vendor/faiss/faiss/utils/pq_code_distance.h +251 -0
  328. data/vendor/faiss/faiss/utils/prefetch.h +2 -2
  329. data/vendor/faiss/faiss/utils/quantize_lut.cpp +30 -30
  330. data/vendor/faiss/faiss/utils/quantize_lut.h +1 -1
  331. data/vendor/faiss/faiss/utils/rabitq_simd.h +124 -343
  332. data/vendor/faiss/faiss/utils/random.cpp +6 -6
  333. data/vendor/faiss/faiss/utils/simd_impl/IVFFlatScanner-inl.h +51 -0
  334. data/vendor/faiss/faiss/utils/simd_impl/distances_aarch64.cpp +154 -0
  335. data/vendor/faiss/faiss/utils/simd_impl/distances_arm_sve.cpp +777 -0
  336. data/vendor/faiss/faiss/utils/simd_impl/distances_autovec-inl.h +306 -0
  337. data/vendor/faiss/faiss/utils/simd_impl/distances_avx2.cpp +1431 -0
  338. data/vendor/faiss/faiss/utils/simd_impl/distances_avx512.cpp +1095 -0
  339. data/vendor/faiss/faiss/utils/simd_impl/distances_rvv.cpp +189 -0
  340. data/vendor/faiss/faiss/utils/simd_impl/distances_simdlib256.h +195 -0
  341. data/vendor/faiss/faiss/utils/simd_impl/distances_sse-inl.h +392 -0
  342. data/vendor/faiss/faiss/utils/{distances_fused/simdlib_based.h → simd_impl/exhaustive_L2sqr_blas_cmax.h} +5 -10
  343. data/vendor/faiss/faiss/utils/simd_impl/hamming_impl.h +481 -0
  344. data/vendor/faiss/faiss/utils/simd_impl/partitioning_avx2.cpp +14 -0
  345. data/vendor/faiss/faiss/utils/simd_impl/partitioning_neon.cpp +14 -0
  346. data/vendor/faiss/faiss/utils/simd_impl/partitioning_simdlib256.h +1085 -0
  347. data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx2.cpp +355 -0
  348. data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx512.cpp +477 -0
  349. data/vendor/faiss/faiss/utils/simd_impl/rabitq_neon.cpp +55 -0
  350. data/vendor/faiss/faiss/utils/simd_impl/rabitq_rvv.cpp +55 -0
  351. data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_dispatch.h +32 -0
  352. data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels.h +43 -0
  353. data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels_avx2.cpp +57 -0
  354. data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels_avx512.cpp +45 -0
  355. data/vendor/faiss/faiss/utils/simd_levels.cpp +334 -0
  356. data/vendor/faiss/faiss/utils/simd_levels.h +183 -0
  357. data/vendor/faiss/faiss/utils/sorting.cpp +48 -36
  358. data/vendor/faiss/faiss/utils/utils.cpp +21 -14
  359. data/vendor/faiss/faiss/utils/utils.h +3 -3
  360. metadata +156 -42
  361. data/vendor/faiss/faiss/impl/RaBitQStats.cpp +0 -29
  362. data/vendor/faiss/faiss/impl/RaBitQStats.h +0 -56
  363. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +0 -81
  364. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +0 -186
  365. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +0 -216
  366. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +0 -224
  367. data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +0 -84
  368. data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +0 -196
  369. data/vendor/faiss/faiss/utils/approx_topk/mode.h +0 -34
  370. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +0 -36
  371. data/vendor/faiss/faiss/utils/extra_distances-inl.h +0 -228
  372. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +0 -462
  373. data/vendor/faiss/faiss/utils/hamming_distance/avx512-inl.h +0 -490
  374. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +0 -450
  375. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +0 -87
  376. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +0 -524
  377. data/vendor/faiss/faiss/utils/simdlib.h +0 -42
  378. data/vendor/faiss/faiss/utils/simdlib_avx512.h +0 -296
  379. data/vendor/faiss/faiss/{cppcontrib/factory_tools.h → factory_tools.h} +0 -0
@@ -10,12 +10,13 @@
10
10
  #include <faiss/impl/AuxIndexStructures.h>
11
11
  #include <faiss/impl/FaissAssert.h>
12
12
  #include <faiss/impl/ResidualQuantizer.h>
13
+ #include <faiss/impl/simd_dispatch.h>
13
14
  #include <faiss/utils/Heap.h>
14
15
  #include <faiss/utils/distances.h>
15
- #include <faiss/utils/simdlib.h>
16
16
  #include <faiss/utils/utils.h>
17
17
 
18
- #include <faiss/utils/approx_topk/approx_topk.h>
18
+ #include <faiss/impl/approx_topk/approx_topk.h>
19
+ #include <faiss/impl/approx_topk/rq_beam_search_tab.h>
19
20
 
20
21
  extern "C" {
21
22
 
@@ -38,190 +39,6 @@ int sgemm_(
38
39
 
39
40
  namespace faiss {
40
41
 
41
- /********************************************************************
42
- * Basic routines
43
- ********************************************************************/
44
-
45
- namespace {
46
-
47
- template <size_t M, size_t NK>
48
- void accum_and_store_tab(
49
- const size_t m_offset,
50
- const float* const __restrict codebook_cross_norms,
51
- const uint64_t* const __restrict codebook_offsets,
52
- const int32_t* const __restrict codes_i,
53
- const size_t b,
54
- const size_t ldc,
55
- const size_t K,
56
- float* const __restrict output) {
57
- // load pointers into registers
58
- const float* cbs[M];
59
- for (size_t ij = 0; ij < M; ij++) {
60
- const size_t code = static_cast<size_t>(codes_i[b * m_offset + ij]);
61
- cbs[ij] = &codebook_cross_norms[(codebook_offsets[ij] + code) * ldc];
62
- }
63
-
64
- // do accumulation in registers using SIMD.
65
- // It is possible that compiler may be smart enough so that
66
- // this manual SIMD unrolling might be unneeded.
67
- #if defined(__AVX2__) || defined(__aarch64__)
68
- const size_t K8 = (K / (8 * NK)) * (8 * NK);
69
-
70
- // process in chunks of size (8 * NK) floats
71
- for (size_t kk = 0; kk < K8; kk += 8 * NK) {
72
- simd8float32 regs[NK];
73
- for (size_t ik = 0; ik < NK; ik++) {
74
- regs[ik].loadu(cbs[0] + kk + ik * 8);
75
- }
76
-
77
- for (size_t ij = 1; ij < M; ij++) {
78
- for (size_t ik = 0; ik < NK; ik++) {
79
- regs[ik] += simd8float32(cbs[ij] + kk + ik * 8);
80
- }
81
- }
82
-
83
- // write the result
84
- for (size_t ik = 0; ik < NK; ik++) {
85
- regs[ik].storeu(output + kk + ik * 8);
86
- }
87
- }
88
- #else
89
- const size_t K8 = 0;
90
- #endif
91
-
92
- // process leftovers
93
- for (size_t kk = K8; kk < K; kk++) {
94
- float reg = cbs[0][kk];
95
- for (size_t ij = 1; ij < M; ij++) {
96
- reg += cbs[ij][kk];
97
- }
98
- output[kk] = reg;
99
- }
100
- }
101
-
102
- template <size_t M, size_t NK>
103
- void accum_and_add_tab(
104
- const size_t m_offset,
105
- const float* const __restrict codebook_cross_norms,
106
- const uint64_t* const __restrict codebook_offsets,
107
- const int32_t* const __restrict codes_i,
108
- const size_t b,
109
- const size_t ldc,
110
- const size_t K,
111
- float* const __restrict output) {
112
- // load pointers into registers
113
- const float* cbs[M];
114
- for (size_t ij = 0; ij < M; ij++) {
115
- const size_t code = static_cast<size_t>(codes_i[b * m_offset + ij]);
116
- cbs[ij] = &codebook_cross_norms[(codebook_offsets[ij] + code) * ldc];
117
- }
118
-
119
- // do accumulation in registers using SIMD.
120
- // It is possible that compiler may be smart enough so that
121
- // this manual SIMD unrolling might be unneeded.
122
- #if defined(__AVX2__) || defined(__aarch64__)
123
- const size_t K8 = (K / (8 * NK)) * (8 * NK);
124
-
125
- // process in chunks of size (8 * NK) floats
126
- for (size_t kk = 0; kk < K8; kk += 8 * NK) {
127
- simd8float32 regs[NK];
128
- for (size_t ik = 0; ik < NK; ik++) {
129
- regs[ik].loadu(cbs[0] + kk + ik * 8);
130
- }
131
-
132
- for (size_t ij = 1; ij < M; ij++) {
133
- for (size_t ik = 0; ik < NK; ik++) {
134
- regs[ik] += simd8float32(cbs[ij] + kk + ik * 8);
135
- }
136
- }
137
-
138
- // write the result
139
- for (size_t ik = 0; ik < NK; ik++) {
140
- simd8float32 existing(output + kk + ik * 8);
141
- existing += regs[ik];
142
- existing.storeu(output + kk + ik * 8);
143
- }
144
- }
145
- #else
146
- const size_t K8 = 0;
147
- #endif
148
-
149
- // process leftovers
150
- for (size_t kk = K8; kk < K; kk++) {
151
- float reg = cbs[0][kk];
152
- for (size_t ij = 1; ij < M; ij++) {
153
- reg += cbs[ij][kk];
154
- }
155
- output[kk] += reg;
156
- }
157
- }
158
-
159
- template <size_t M, size_t NK>
160
- void accum_and_finalize_tab(
161
- const float* const __restrict codebook_cross_norms,
162
- const uint64_t* const __restrict codebook_offsets,
163
- const int32_t* const __restrict codes_i,
164
- const size_t b,
165
- const size_t ldc,
166
- const size_t K,
167
- const float* const __restrict distances_i,
168
- const float* const __restrict cd_common,
169
- float* const __restrict output) {
170
- // load pointers into registers
171
- const float* cbs[M];
172
- for (size_t ij = 0; ij < M; ij++) {
173
- const size_t code = static_cast<size_t>(codes_i[b * M + ij]);
174
- cbs[ij] = &codebook_cross_norms[(codebook_offsets[ij] + code) * ldc];
175
- }
176
-
177
- // do accumulation in registers using SIMD.
178
- // It is possible that compiler may be smart enough so that
179
- // this manual SIMD unrolling might be unneeded.
180
- #if defined(__AVX2__) || defined(__aarch64__)
181
- const size_t K8 = (K / (8 * NK)) * (8 * NK);
182
-
183
- // process in chunks of size (8 * NK) floats
184
- for (size_t kk = 0; kk < K8; kk += 8 * NK) {
185
- simd8float32 regs[NK];
186
- for (size_t ik = 0; ik < NK; ik++) {
187
- regs[ik].loadu(cbs[0] + kk + ik * 8);
188
- }
189
-
190
- for (size_t ij = 1; ij < M; ij++) {
191
- for (size_t ik = 0; ik < NK; ik++) {
192
- regs[ik] += simd8float32(cbs[ij] + kk + ik * 8);
193
- }
194
- }
195
-
196
- simd8float32 two(2.0f);
197
- for (size_t ik = 0; ik < NK; ik++) {
198
- // cent_distances[b * K + k] = distances_i[b] + cd_common[k]
199
- // + 2 * dp[k];
200
-
201
- simd8float32 common_v(cd_common + kk + ik * 8);
202
- common_v = fmadd(two, regs[ik], common_v);
203
-
204
- common_v += simd8float32(distances_i[b]);
205
- common_v.storeu(output + b * K + kk + ik * 8);
206
- }
207
- }
208
- #else
209
- const size_t K8 = 0;
210
- #endif
211
-
212
- // process leftovers
213
- for (size_t kk = K8; kk < K; kk++) {
214
- float reg = cbs[0][kk];
215
- for (size_t ij = 1; ij < M; ij++) {
216
- reg += cbs[ij][kk];
217
- }
218
-
219
- output[b * K + kk] = distances_i[b] + cd_common[kk] + 2 * reg;
220
- }
221
- }
222
-
223
- } // anonymous namespace
224
-
225
42
  /********************************************************************
226
43
  * Single encoding step
227
44
  ********************************************************************/
@@ -249,12 +66,12 @@ void beam_search_encode_step(
249
66
 
250
67
  if (assign_index) {
251
68
  // search beam_size distances per query
252
- FAISS_THROW_IF_NOT(assign_index->d == d);
69
+ FAISS_THROW_IF_NOT(assign_index->d == static_cast<int>(d));
253
70
  cent_distances.resize(n * beam_size * new_beam_size);
254
71
  cent_ids.resize(n * beam_size * new_beam_size);
255
72
  if (assign_index->ntotal != 0) {
256
73
  // then we assume the codebooks are already added to the index
257
- FAISS_THROW_IF_NOT(assign_index->ntotal == K);
74
+ FAISS_THROW_IF_NOT(assign_index->ntotal == static_cast<idx_t>(K));
258
75
  } else {
259
76
  assign_index->add(K, cent);
260
77
  }
@@ -275,110 +92,259 @@ void beam_search_encode_step(
275
92
  }
276
93
  InterruptCallback::check();
277
94
 
95
+ // Resolve SIMD level once, not per iteration of the n-parallel loop.
96
+ with_simd_level_256bit([&]<SIMDLevel SL>() {
278
97
  #pragma omp parallel for if (n > 100)
279
- for (int64_t i = 0; i < n; i++) {
280
- const int32_t* codes_i = codes + i * m * beam_size;
281
- int32_t* new_codes_i = new_codes + i * (m + 1) * new_beam_size;
282
- const float* residuals_i = residuals + i * d * beam_size;
283
- float* new_residuals_i = new_residuals + i * d * new_beam_size;
284
-
285
- float* new_distances_i = new_distances + i * new_beam_size;
286
- using C = CMax<float, int>;
287
-
288
- if (assign_index) {
289
- const float* cent_distances_i =
290
- cent_distances.data() + i * beam_size * new_beam_size;
291
- const idx_t* cent_ids_i =
292
- cent_ids.data() + i * beam_size * new_beam_size;
293
-
294
- // here we could be a tad more efficient by merging sorted arrays
295
- for (int j = 0; j < new_beam_size; j++) {
296
- new_distances_i[j] = C::neutral();
297
- }
298
- std::vector<int> perm(new_beam_size, -1);
299
- heap_addn<C>(
300
- new_beam_size,
301
- new_distances_i,
302
- perm.data(),
303
- cent_distances_i,
304
- nullptr,
305
- beam_size * new_beam_size);
306
- heap_reorder<C>(new_beam_size, new_distances_i, perm.data());
98
+ for (int64_t i = 0; i < static_cast<int64_t>(n); i++) {
99
+ const int32_t* codes_i = codes + i * m * beam_size;
100
+ int32_t* new_codes_i = new_codes + i * (m + 1) * new_beam_size;
101
+ const float* residuals_i = residuals + i * d * beam_size;
102
+ float* new_residuals_i = new_residuals + i * d * new_beam_size;
103
+
104
+ float* new_distances_i = new_distances + i * new_beam_size;
105
+ using C = CMax<float, int>;
106
+
107
+ if (assign_index) {
108
+ const float* cent_distances_i =
109
+ cent_distances.data() + i * beam_size * new_beam_size;
110
+ const idx_t* cent_ids_i =
111
+ cent_ids.data() + i * beam_size * new_beam_size;
112
+
113
+ // here we could be a tad more efficient by merging sorted
114
+ // arrays
115
+ for (size_t j = 0; j < new_beam_size; j++) {
116
+ new_distances_i[j] = C::neutral();
117
+ }
118
+ std::vector<int> perm(new_beam_size, -1);
119
+ heap_addn<C>(
120
+ new_beam_size,
121
+ new_distances_i,
122
+ perm.data(),
123
+ cent_distances_i,
124
+ nullptr,
125
+ beam_size * new_beam_size);
126
+ heap_reorder<C>(new_beam_size, new_distances_i, perm.data());
127
+
128
+ for (size_t j = 0; j < new_beam_size; j++) {
129
+ int js = perm[j] / new_beam_size;
130
+ int ls = cent_ids_i[perm[j]];
131
+ if (m > 0) {
132
+ memcpy(new_codes_i,
133
+ codes_i + js * m,
134
+ sizeof(*codes) * m);
135
+ }
136
+ new_codes_i[m] = ls;
137
+ new_codes_i += m + 1;
138
+ fvec_sub(
139
+ d,
140
+ residuals_i + js * d,
141
+ cent + ls * d,
142
+ new_residuals_i);
143
+ new_residuals_i += d;
144
+ }
307
145
 
308
- for (int j = 0; j < new_beam_size; j++) {
309
- int js = perm[j] / new_beam_size;
310
- int ls = cent_ids_i[perm[j]];
311
- if (m > 0) {
312
- memcpy(new_codes_i, codes_i + js * m, sizeof(*codes) * m);
146
+ } else {
147
+ const float* cent_distances_i =
148
+ cent_distances.data() + i * beam_size * K;
149
+ // then we have to select the best results
150
+ for (size_t j = 0; j < new_beam_size; j++) {
151
+ new_distances_i[j] = C::neutral();
313
152
  }
314
- new_codes_i[m] = ls;
315
- new_codes_i += m + 1;
316
- fvec_sub(
317
- d,
318
- residuals_i + js * d,
319
- cent + ls * d,
320
- new_residuals_i);
321
- new_residuals_i += d;
322
- }
153
+ std::vector<int> perm(new_beam_size, -1);
323
154
 
324
- } else {
325
- const float* cent_distances_i =
326
- cent_distances.data() + i * beam_size * K;
327
- // then we have to select the best results
328
- for (int j = 0; j < new_beam_size; j++) {
329
- new_distances_i[j] = C::neutral();
155
+ approx_topk_by_mode<SL>(
156
+ approx_topk_mode,
157
+ beam_size,
158
+ K,
159
+ cent_distances_i,
160
+ new_beam_size,
161
+ new_distances_i,
162
+ perm.data());
163
+ heap_reorder<C>(new_beam_size, new_distances_i, perm.data());
164
+
165
+ for (size_t j = 0; j < new_beam_size; j++) {
166
+ int js = perm[j] / K;
167
+ int ls = perm[j] % K;
168
+ if (m > 0) {
169
+ memcpy(new_codes_i,
170
+ codes_i + js * m,
171
+ sizeof(*codes) * m);
172
+ }
173
+ new_codes_i[m] = ls;
174
+ new_codes_i += m + 1;
175
+ fvec_sub(
176
+ d,
177
+ residuals_i + js * d,
178
+ cent + ls * d,
179
+ new_residuals_i);
180
+ new_residuals_i += d;
181
+ }
330
182
  }
331
- std::vector<int> perm(new_beam_size, -1);
183
+ }
184
+ });
185
+ }
186
+
187
+ // exposed in the faiss namespace
188
+
189
+ namespace {
190
+
191
+ // Baseline (scalar) implementation for computing cent_distances.
192
+ // Accumulates codebook cross-norms via fvec_add into a temporary buffer.
193
+ // Its primary flaw is that it writes too much to the temporary buffer dp.
194
+ // This code is kept because it is easy to understand what the optimized
195
+ // SIMD version (compute_cent_distances_simd) optimizes exactly.
196
+ void compute_cent_distances_baseline(
197
+ size_t K,
198
+ size_t beam_size,
199
+ const float* codebook_cross_norms,
200
+ size_t ldc,
201
+ const uint64_t* codebook_offsets,
202
+ size_t m,
203
+ const int32_t* codes_i,
204
+ const float* distances_i,
205
+ const float* cd_common,
206
+ float* cent_distances) {
207
+ for (size_t b = 0; b < beam_size; b++) {
208
+ std::vector<float> dp(K);
209
+ for (size_t m1 = 0; m1 < m; m1++) {
210
+ size_t c = codes_i[b * m + m1];
211
+ const float* cb =
212
+ &codebook_cross_norms[(codebook_offsets[m1] + c) * ldc];
213
+ fvec_add(K, cb, dp.data(), dp.data());
214
+ }
215
+ for (size_t k = 0; k < K; k++) {
216
+ cent_distances[b * K + k] =
217
+ distances_i[b] + cd_common[k] + 2 * dp[k];
218
+ }
219
+ }
220
+ }
221
+
222
+ // SIMD-optimized implementation for computing cent_distances.
223
+ // Uses accum_and_finalize_tab / accum_and_store_tab / accum_and_add_tab
224
+ // to accumulate codebook cross-norms in SIMD registers.
225
+ template <SIMDLevel SL>
226
+ void compute_cent_distances_simd(
227
+ size_t K,
228
+ size_t beam_size,
229
+ const float* codebook_cross_norms,
230
+ size_t ldc,
231
+ const uint64_t* codebook_offsets,
232
+ size_t m,
233
+ const int32_t* codes_i,
234
+ const float* distances_i,
235
+ const float* cd_common,
236
+ float* cent_distances) {
237
+ auto do_finalize = [&]<size_t NK>() {
238
+ for (size_t b = 0; b < beam_size; b++) {
239
+ accum_and_finalize_tab<NK, 4, SL>(
240
+ codebook_cross_norms,
241
+ codebook_offsets,
242
+ codes_i,
243
+ b,
244
+ ldc,
245
+ K,
246
+ distances_i,
247
+ cd_common,
248
+ cent_distances);
249
+ }
250
+ };
332
251
 
333
- #define HANDLE_APPROX(NB, BD) \
334
- case ApproxTopK_mode_t::APPROX_TOPK_BUCKETS_B##NB##_D##BD: \
335
- HeapWithBuckets<C, NB, BD>::bs_addn( \
336
- beam_size, \
337
- K, \
338
- cent_distances_i, \
339
- new_beam_size, \
340
- new_distances_i, \
341
- perm.data()); \
342
- break;
343
-
344
- switch (approx_topk_mode) {
345
- HANDLE_APPROX(8, 3)
346
- HANDLE_APPROX(8, 2)
347
- HANDLE_APPROX(16, 2)
348
- HANDLE_APPROX(32, 2)
349
- default:
350
- heap_addn<C>(
351
- new_beam_size,
352
- new_distances_i,
353
- perm.data(),
354
- cent_distances_i,
355
- nullptr,
356
- beam_size * K);
252
+ switch (m) {
253
+ case 0:
254
+ for (size_t b = 0; b < beam_size; b++) {
255
+ for (size_t k = 0; k < K; k++) {
256
+ cent_distances[b * K + k] = distances_i[b] + cd_common[k];
257
+ }
357
258
  }
358
- heap_reorder<C>(new_beam_size, new_distances_i, perm.data());
259
+ break;
260
+ case 1:
261
+ do_finalize.template operator()<1>();
262
+ break;
263
+ case 2:
264
+ do_finalize.template operator()<2>();
265
+ break;
266
+ case 3:
267
+ do_finalize.template operator()<3>();
268
+ break;
269
+ case 4:
270
+ do_finalize.template operator()<4>();
271
+ break;
272
+ case 5:
273
+ do_finalize.template operator()<5>();
274
+ break;
275
+ case 6:
276
+ do_finalize.template operator()<6>();
277
+ break;
278
+ case 7:
279
+ do_finalize.template operator()<7>();
280
+ break;
281
+ default: {
282
+ // m >= 8: accumulate in chunks of 8 into a temporary buffer.
283
+ std::vector<float> dp(K);
359
284
 
360
- #undef HANDLE_APPROX
285
+ for (size_t b = 0; b < beam_size; b++) {
286
+ accum_and_store_tab<8, 4, SL>(
287
+ m,
288
+ codebook_cross_norms,
289
+ codebook_offsets,
290
+ codes_i,
291
+ b,
292
+ ldc,
293
+ K,
294
+ dp.data());
295
+
296
+ for (size_t im = 8; im < ((m + 7) / 8) * 8; im += 8) {
297
+ size_t m_left = std::min(m - im, size_t(8));
298
+ auto do_add = [&]<size_t NK2>() {
299
+ accum_and_add_tab<NK2, 4, SL>(
300
+ m,
301
+ codebook_cross_norms,
302
+ codebook_offsets + im,
303
+ codes_i + im,
304
+ b,
305
+ ldc,
306
+ K,
307
+ dp.data());
308
+ };
309
+ switch (m_left) {
310
+ case 1:
311
+ do_add.template operator()<1>();
312
+ break;
313
+ case 2:
314
+ do_add.template operator()<2>();
315
+ break;
316
+ case 3:
317
+ do_add.template operator()<3>();
318
+ break;
319
+ case 4:
320
+ do_add.template operator()<4>();
321
+ break;
322
+ case 5:
323
+ do_add.template operator()<5>();
324
+ break;
325
+ case 6:
326
+ do_add.template operator()<6>();
327
+ break;
328
+ case 7:
329
+ do_add.template operator()<7>();
330
+ break;
331
+ case 8:
332
+ do_add.template operator()<8>();
333
+ break;
334
+ }
335
+ }
361
336
 
362
- for (int j = 0; j < new_beam_size; j++) {
363
- int js = perm[j] / K;
364
- int ls = perm[j] % K;
365
- if (m > 0) {
366
- memcpy(new_codes_i, codes_i + js * m, sizeof(*codes) * m);
337
+ for (size_t k = 0; k < K; k++) {
338
+ cent_distances[b * K + k] =
339
+ distances_i[b] + cd_common[k] + 2 * dp[k];
367
340
  }
368
- new_codes_i[m] = ls;
369
- new_codes_i += m + 1;
370
- fvec_sub(
371
- d,
372
- residuals_i + js * d,
373
- cent + ls * d,
374
- new_residuals_i);
375
- new_residuals_i += d;
376
341
  }
377
342
  }
378
343
  }
379
344
  }
380
345
 
381
- // exposed in the faiss namespace
346
+ } // anonymous namespace
347
+
382
348
  void beam_search_encode_step_tab(
383
349
  size_t K,
384
350
  size_t n,
@@ -399,211 +365,80 @@ void beam_search_encode_step_tab(
399
365
  {
400
366
  FAISS_THROW_IF_NOT(ldc >= K);
401
367
 
368
+ // Resolve SIMD level once, not per iteration of the n-parallel loop.
369
+ with_simd_level_256bit([&]<SIMDLevel SL>() {
402
370
  #pragma omp parallel for if (n > 100) schedule(dynamic)
403
- for (int64_t i = 0; i < n; i++) {
404
- std::vector<float> cent_distances(beam_size * K);
405
- std::vector<float> cd_common(K);
406
-
407
- const int32_t* codes_i = codes + i * m * beam_size;
408
- const float* query_cp_i = query_cp + i * ldqc;
409
- const float* distances_i = distances + i * beam_size;
371
+ for (int64_t i = 0; i < static_cast<int64_t>(n); i++) {
372
+ std::vector<float> cent_distances(beam_size * K);
373
+ std::vector<float> cd_common(K);
410
374
 
411
- for (size_t k = 0; k < K; k++) {
412
- cd_common[k] = cent_norms_i[k] - 2 * query_cp_i[k];
413
- }
414
-
415
- bool use_baseline_implementation = false;
416
-
417
- // This is the baseline implementation. Its primary flaw
418
- // that it writes way too many info to the temporary buffer
419
- // called dp.
420
- //
421
- // This baseline code is kept intentionally because it is easy to
422
- // understand what an optimized version optimizes exactly.
423
- //
424
- if (use_baseline_implementation) {
425
- for (size_t b = 0; b < beam_size; b++) {
426
- std::vector<float> dp(K);
427
-
428
- for (size_t m1 = 0; m1 < m; m1++) {
429
- size_t c = codes_i[b * m + m1];
430
- const float* cb =
431
- &codebook_cross_norms
432
- [(codebook_offsets[m1] + c) * ldc];
433
- fvec_add(K, cb, dp.data(), dp.data());
434
- }
375
+ const int32_t* codes_i = codes + i * m * beam_size;
376
+ const float* query_cp_i = query_cp + i * ldqc;
377
+ const float* distances_i = distances + i * beam_size;
435
378
 
436
- for (size_t k = 0; k < K; k++) {
437
- cent_distances[b * K + k] =
438
- distances_i[b] + cd_common[k] + 2 * dp[k];
439
- }
379
+ for (size_t k = 0; k < K; k++) {
380
+ cd_common[k] = cent_norms_i[k] - 2 * query_cp_i[k];
440
381
  }
441
382
 
442
- } else {
443
- // An optimized implementation that avoids using a temporary buffer
444
- // and does the accumulation in registers.
445
-
446
- // Compute a sum of NK AQ codes.
447
- #define ACCUM_AND_FINALIZE_TAB(NK) \
448
- case NK: \
449
- for (size_t b = 0; b < beam_size; b++) { \
450
- accum_and_finalize_tab<NK, 4>( \
451
- codebook_cross_norms, \
452
- codebook_offsets, \
453
- codes_i, \
454
- b, \
455
- ldc, \
456
- K, \
457
- distances_i, \
458
- cd_common.data(), \
459
- cent_distances.data()); \
460
- } \
461
- break;
462
-
463
- // this version contains many switch-case scenarios, but
464
- // they won't affect branch predictor.
465
- switch (m) {
466
- case 0:
467
- // trivial case
468
- for (size_t b = 0; b < beam_size; b++) {
469
- for (size_t k = 0; k < K; k++) {
470
- cent_distances[b * K + k] =
471
- distances_i[b] + cd_common[k];
472
- }
473
- }
474
- break;
475
-
476
- ACCUM_AND_FINALIZE_TAB(1)
477
- ACCUM_AND_FINALIZE_TAB(2)
478
- ACCUM_AND_FINALIZE_TAB(3)
479
- ACCUM_AND_FINALIZE_TAB(4)
480
- ACCUM_AND_FINALIZE_TAB(5)
481
- ACCUM_AND_FINALIZE_TAB(6)
482
- ACCUM_AND_FINALIZE_TAB(7)
483
-
484
- default: {
485
- // m >= 8 case.
486
-
487
- // A temporary buffer has to be used due to the lack of
488
- // registers. But we'll try to accumulate up to 8 AQ codes
489
- // in registers and issue a single write operation to the
490
- // buffer, while the baseline does no accumulation. So, the
491
- // number of write operations to the temporary buffer is
492
- // reduced 8x.
493
-
494
- // allocate a temporary buffer
495
- std::vector<float> dp(K);
496
-
497
- for (size_t b = 0; b < beam_size; b++) {
498
- // Initialize it. Compute a sum of first 8 AQ codes
499
- // because m >= 8 .
500
- accum_and_store_tab<8, 4>(
501
- m,
502
- codebook_cross_norms,
503
- codebook_offsets,
504
- codes_i,
505
- b,
506
- ldc,
507
- K,
508
- dp.data());
509
-
510
- #define ACCUM_AND_ADD_TAB(NK) \
511
- case NK: \
512
- accum_and_add_tab<NK, 4>( \
513
- m, \
514
- codebook_cross_norms, \
515
- codebook_offsets + im, \
516
- codes_i + im, \
517
- b, \
518
- ldc, \
519
- K, \
520
- dp.data()); \
521
- break;
522
-
523
- // accumulate up to 8 additional AQ codes into
524
- // a temporary buffer
525
- for (size_t im = 8; im < ((m + 7) / 8) * 8; im += 8) {
526
- size_t m_left = m - im;
527
- if (m_left > 8) {
528
- m_left = 8;
529
- }
530
-
531
- switch (m_left) {
532
- ACCUM_AND_ADD_TAB(1)
533
- ACCUM_AND_ADD_TAB(2)
534
- ACCUM_AND_ADD_TAB(3)
535
- ACCUM_AND_ADD_TAB(4)
536
- ACCUM_AND_ADD_TAB(5)
537
- ACCUM_AND_ADD_TAB(6)
538
- ACCUM_AND_ADD_TAB(7)
539
- ACCUM_AND_ADD_TAB(8)
540
- }
541
- }
542
-
543
- // done. finalize the result
544
- for (size_t k = 0; k < K; k++) {
545
- cent_distances[b * K + k] =
546
- distances_i[b] + cd_common[k] + 2 * dp[k];
547
- }
548
- }
549
- }
383
+ if constexpr (SL == SIMDLevel::NONE) {
384
+ compute_cent_distances_baseline(
385
+ K,
386
+ beam_size,
387
+ codebook_cross_norms,
388
+ ldc,
389
+ codebook_offsets,
390
+ m,
391
+ codes_i,
392
+ distances_i,
393
+ cd_common.data(),
394
+ cent_distances.data());
395
+ } else {
396
+ compute_cent_distances_simd<SL>(
397
+ K,
398
+ beam_size,
399
+ codebook_cross_norms,
400
+ ldc,
401
+ codebook_offsets,
402
+ m,
403
+ codes_i,
404
+ distances_i,
405
+ cd_common.data(),
406
+ cent_distances.data());
550
407
  }
551
408
 
552
- // the optimized implementation ends here
553
- }
554
- using C = CMax<float, int>;
555
- int32_t* new_codes_i = new_codes + i * (m + 1) * new_beam_size;
556
- float* new_distances_i = new_distances + i * new_beam_size;
409
+ using C = CMax<float, int>;
410
+ int32_t* new_codes_i = new_codes + i * (m + 1) * new_beam_size;
411
+ float* new_distances_i = new_distances + i * new_beam_size;
557
412
 
558
- const float* cent_distances_i = cent_distances.data();
559
-
560
- // then we have to select the best results
561
- for (int j = 0; j < new_beam_size; j++) {
562
- new_distances_i[j] = C::neutral();
563
- }
564
- std::vector<int> perm(new_beam_size, -1);
565
-
566
- #define HANDLE_APPROX(NB, BD) \
567
- case ApproxTopK_mode_t::APPROX_TOPK_BUCKETS_B##NB##_D##BD: \
568
- HeapWithBuckets<C, NB, BD>::bs_addn( \
569
- beam_size, \
570
- K, \
571
- cent_distances_i, \
572
- new_beam_size, \
573
- new_distances_i, \
574
- perm.data()); \
575
- break;
576
-
577
- switch (approx_topk_mode) {
578
- HANDLE_APPROX(8, 3)
579
- HANDLE_APPROX(8, 2)
580
- HANDLE_APPROX(16, 2)
581
- HANDLE_APPROX(32, 2)
582
- default:
583
- heap_addn<C>(
584
- new_beam_size,
585
- new_distances_i,
586
- perm.data(),
587
- cent_distances_i,
588
- nullptr,
589
- beam_size * K);
590
- break;
591
- }
413
+ const float* cent_distances_i = cent_distances.data();
592
414
 
593
- heap_reorder<C>(new_beam_size, new_distances_i, perm.data());
415
+ // then we have to select the best results
416
+ for (size_t j = 0; j < new_beam_size; j++) {
417
+ new_distances_i[j] = C::neutral();
418
+ }
419
+ std::vector<int> perm(new_beam_size, -1);
594
420
 
595
- #undef HANDLE_APPROX
421
+ approx_topk_by_mode<SL>(
422
+ approx_topk_mode,
423
+ beam_size,
424
+ K,
425
+ cent_distances_i,
426
+ new_beam_size,
427
+ new_distances_i,
428
+ perm.data());
429
+ heap_reorder<C>(new_beam_size, new_distances_i, perm.data());
596
430
 
597
- for (int j = 0; j < new_beam_size; j++) {
598
- int js = perm[j] / K;
599
- int ls = perm[j] % K;
600
- if (m > 0) {
601
- memcpy(new_codes_i, codes_i + js * m, sizeof(*codes) * m);
431
+ for (size_t j = 0; j < new_beam_size; j++) {
432
+ int js = perm[j] / K;
433
+ int ls = perm[j] % K;
434
+ if (m > 0) {
435
+ memcpy(new_codes_i, codes_i + js * m, sizeof(*codes) * m);
436
+ }
437
+ new_codes_i[m] = ls;
438
+ new_codes_i += m + 1;
602
439
  }
603
- new_codes_i[m] = ls;
604
- new_codes_i += m + 1;
605
440
  }
606
- }
441
+ });
607
442
  }
608
443
 
609
444
  /********************************************************************
@@ -630,7 +465,7 @@ void refine_beam_mp(
630
465
  int max_beam_size = 0;
631
466
  {
632
467
  int tmp_beam_size = cur_beam_size;
633
- for (int m = 0; m < rq.M; m++) {
468
+ for (size_t m = 0; m < rq.M; m++) {
634
469
  int K = 1 << rq.nbits[m];
635
470
  int new_beam_size = std::min(tmp_beam_size * K, out_beam_size);
636
471
  tmp_beam_size = new_beam_size;
@@ -671,7 +506,7 @@ void refine_beam_mp(
671
506
  size_t distances_size = 0;
672
507
  size_t residuals_size = 0;
673
508
 
674
- for (int m = 0; m < rq.M; m++) {
509
+ for (size_t m = 0; m < rq.M; m++) {
675
510
  int K = 1 << rq.nbits[m];
676
511
 
677
512
  const float* __restrict codebooks_m =
@@ -710,14 +545,14 @@ void refine_beam_mp(
710
545
 
711
546
  if (rq.verbose) {
712
547
  float sum_distances = 0;
713
- for (int j = 0; j < distances_size; j++) {
548
+ for (size_t j = 0; j < distances_size; j++) {
714
549
  sum_distances += pool.distances[j];
715
550
  }
716
551
 
717
552
  printf("[%.3f s] encode stage %d, %d bits, "
718
553
  "total error %g, beam_size %d\n",
719
554
  (getmillisecs() - t0) / 1000,
720
- m,
555
+ int(m),
721
556
  int(rq.nbits[m]),
722
557
  sum_distances,
723
558
  cur_beam_size);
@@ -756,7 +591,7 @@ void refine_beam_LUT_mp(
756
591
  int max_beam_size = 0;
757
592
  {
758
593
  int tmp_beam_size = beam_size;
759
- for (int m = 0; m < rq.M; m++) {
594
+ for (size_t m = 0; m < rq.M; m++) {
760
595
  int K = 1 << rq.nbits[m];
761
596
  int new_beam_size = std::min(tmp_beam_size * K, out_beam_size);
762
597
  tmp_beam_size = new_beam_size;
@@ -789,7 +624,7 @@ void refine_beam_LUT_mp(
789
624
  size_t codes_size = 0;
790
625
  size_t distances_size = 0;
791
626
  size_t cross_ofs = 0;
792
- for (int m = 0; m < rq.M; m++) {
627
+ for (size_t m = 0; m < rq.M; m++) {
793
628
  int K = 1 << rq.nbits[m];
794
629
 
795
630
  // it is guaranteed that (new_beam_size <= max_beam_size)
@@ -825,13 +660,13 @@ void refine_beam_LUT_mp(
825
660
 
826
661
  if (rq.verbose) {
827
662
  float sum_distances = 0;
828
- for (int j = 0; j < distances_size; j++) {
663
+ for (size_t j = 0; j < distances_size; j++) {
829
664
  sum_distances += distances_ptr[j];
830
665
  }
831
666
  printf("[%.3f s] encode stage %d, %d bits, "
832
667
  "total error %g, beam_size %d\n",
833
668
  (getmillisecs() - t0) / 1000,
834
- m,
669
+ int(m),
835
670
  int(rq.nbits[m]),
836
671
  sum_distances,
837
672
  beam_size);
@@ -877,12 +712,14 @@ void compute_codes_add_centroids_mp_lut0(
877
712
  pool.norms.resize(n);
878
713
  // recover the norms of reconstruction as
879
714
  // || original_vector - residual ||^2
880
- for (size_t i = 0; i < n; i++) {
881
- pool.norms[i] = fvec_L2sqr(
882
- x + i * rq.d,
883
- pool.residuals.data() + i * rq.max_beam_size * rq.d,
884
- rq.d);
885
- }
715
+ with_simd_level([&]<SIMDLevel SL>() {
716
+ for (size_t i = 0; i < n; i++) {
717
+ pool.norms[i] = fvec_L2sqr<SL>(
718
+ x + i * rq.d,
719
+ pool.residuals.data() + i * rq.max_beam_size * rq.d,
720
+ rq.d);
721
+ }
722
+ });
886
723
  }
887
724
 
888
725
  // pack only the first code of the beam