faiss 0.6.0 → 0.6.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (361) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/ext/faiss/extconf.rb +2 -1
  4. data/ext/faiss/{index_rb.cpp → index.cpp} +1 -1
  5. data/ext/faiss/index_binary.cpp +1 -1
  6. data/ext/faiss/kmeans.cpp +1 -1
  7. data/ext/faiss/pca_matrix.cpp +1 -1
  8. data/ext/faiss/product_quantizer.cpp +1 -1
  9. data/ext/faiss/{utils_rb.cpp → utils.cpp} +1 -1
  10. data/lib/faiss/version.rb +1 -1
  11. data/vendor/faiss/faiss/AutoTune.cpp +93 -80
  12. data/vendor/faiss/faiss/Clustering.cpp +39 -240
  13. data/vendor/faiss/faiss/Clustering.h +6 -0
  14. data/vendor/faiss/faiss/IVFlib.cpp +41 -21
  15. data/vendor/faiss/faiss/Index.cpp +6 -5
  16. data/vendor/faiss/faiss/Index.h +5 -5
  17. data/vendor/faiss/faiss/Index2Layer.cpp +37 -53
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +49 -37
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +36 -34
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +4 -1
  21. data/vendor/faiss/faiss/IndexBinary.cpp +5 -3
  22. data/vendor/faiss/faiss/IndexBinary.h +4 -4
  23. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +1 -1
  24. data/vendor/faiss/faiss/IndexBinaryFlat.h +1 -1
  25. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -4
  26. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +84 -92
  27. data/vendor/faiss/faiss/IndexBinaryHNSW.h +9 -3
  28. data/vendor/faiss/faiss/IndexBinaryHash.cpp +45 -236
  29. data/vendor/faiss/faiss/IndexBinaryHash.h +6 -6
  30. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +87 -415
  31. data/vendor/faiss/faiss/IndexFastScan.cpp +72 -109
  32. data/vendor/faiss/faiss/IndexFastScan.h +25 -23
  33. data/vendor/faiss/faiss/IndexFlat.cpp +27 -20
  34. data/vendor/faiss/faiss/IndexFlat.h +21 -18
  35. data/vendor/faiss/faiss/IndexFlatCodes.cpp +42 -19
  36. data/vendor/faiss/faiss/IndexHNSW.cpp +283 -145
  37. data/vendor/faiss/faiss/IndexHNSW.h +16 -2
  38. data/vendor/faiss/faiss/IndexIDMap.cpp +25 -21
  39. data/vendor/faiss/faiss/IndexIDMap.h +9 -7
  40. data/vendor/faiss/faiss/IndexIVF.cpp +465 -362
  41. data/vendor/faiss/faiss/IndexIVF.h +33 -12
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +77 -74
  43. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +96 -93
  44. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -1
  45. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +357 -238
  46. data/vendor/faiss/faiss/IndexIVFFastScan.h +42 -41
  47. data/vendor/faiss/faiss/IndexIVFFlat.cpp +36 -68
  48. data/vendor/faiss/faiss/IndexIVFFlat.h +32 -0
  49. data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +53 -30
  50. data/vendor/faiss/faiss/IndexIVFFlatPanorama.h +3 -1
  51. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +18 -15
  52. data/vendor/faiss/faiss/IndexIVFPQ.cpp +71 -843
  53. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +151 -121
  54. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +3 -0
  55. data/vendor/faiss/faiss/IndexIVFPQR.cpp +21 -17
  56. data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +26 -39
  57. data/vendor/faiss/faiss/IndexIVFRaBitQ.h +2 -1
  58. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +475 -476
  59. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +248 -93
  60. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +41 -127
  61. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +1 -1
  62. data/vendor/faiss/faiss/IndexLSH.cpp +36 -19
  63. data/vendor/faiss/faiss/IndexLattice.cpp +13 -13
  64. data/vendor/faiss/faiss/IndexNNDescent.cpp +36 -21
  65. data/vendor/faiss/faiss/IndexNNDescent.h +2 -2
  66. data/vendor/faiss/faiss/IndexNSG.cpp +39 -23
  67. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +31 -11
  68. data/vendor/faiss/faiss/IndexPQ.cpp +128 -221
  69. data/vendor/faiss/faiss/IndexPQ.h +3 -2
  70. data/vendor/faiss/faiss/IndexPQFastScan.cpp +20 -14
  71. data/vendor/faiss/faiss/IndexPQFastScan.h +3 -0
  72. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -18
  73. data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
  74. data/vendor/faiss/faiss/IndexRaBitQ.cpp +11 -36
  75. data/vendor/faiss/faiss/IndexRaBitQ.h +2 -1
  76. data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +41 -277
  77. data/vendor/faiss/faiss/IndexRaBitQFastScan.h +183 -27
  78. data/vendor/faiss/faiss/IndexRefine.cpp +30 -25
  79. data/vendor/faiss/faiss/IndexRefine.h +4 -4
  80. data/vendor/faiss/faiss/IndexReplicas.cpp +6 -6
  81. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +15 -14
  82. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +1 -1
  83. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +82 -14
  84. data/vendor/faiss/faiss/IndexShards.cpp +10 -9
  85. data/vendor/faiss/faiss/IndexShardsIVF.cpp +21 -15
  86. data/vendor/faiss/faiss/MatrixStats.cpp +5 -4
  87. data/vendor/faiss/faiss/MetaIndexes.cpp +19 -17
  88. data/vendor/faiss/faiss/MetaIndexes.h +1 -1
  89. data/vendor/faiss/faiss/MetricType.h +14 -7
  90. data/vendor/faiss/faiss/SuperKMeans.cpp +656 -0
  91. data/vendor/faiss/faiss/SuperKMeans.h +97 -0
  92. data/vendor/faiss/faiss/VectorTransform.cpp +237 -149
  93. data/vendor/faiss/faiss/VectorTransform.h +16 -16
  94. data/vendor/faiss/faiss/build.cpp +23 -0
  95. data/vendor/faiss/faiss/build.h +15 -0
  96. data/vendor/faiss/faiss/clone_index.cpp +48 -47
  97. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +47 -47
  98. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +11 -0
  99. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +38 -38
  100. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +11 -0
  101. data/vendor/faiss/faiss/factory_tools.cpp +5 -0
  102. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +6 -5
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +1 -1
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +9 -9
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +4 -3
  106. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +46 -0
  107. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +56 -0
  108. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +78 -1
  109. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +72 -0
  110. data/vendor/faiss/faiss/gpu/test/TestUtils.h +23 -0
  111. data/vendor/faiss/faiss/gpu/utils/CuvsFilterConvert.h +1 -1
  112. data/vendor/faiss/faiss/gpu/utils/CuvsUtils.h +21 -10
  113. data/vendor/faiss/faiss/gpu_metal/GpuIndexFlat.h +22 -0
  114. data/vendor/faiss/faiss/gpu_metal/MetalCloner.h +35 -0
  115. data/vendor/faiss/faiss/gpu_metal/MetalFlatKernels.h +40 -0
  116. data/vendor/faiss/faiss/gpu_metal/MetalIndex.h +51 -0
  117. data/vendor/faiss/faiss/gpu_metal/MetalIndexFlat.h +65 -0
  118. data/vendor/faiss/faiss/gpu_metal/MetalKernels.h +66 -0
  119. data/vendor/faiss/faiss/gpu_metal/MetalResources.h +79 -0
  120. data/vendor/faiss/faiss/gpu_metal/StandardMetalResources.h +35 -0
  121. data/vendor/faiss/faiss/impl/AdSampling.cpp +103 -0
  122. data/vendor/faiss/faiss/impl/AdSampling.h +35 -0
  123. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +29 -25
  124. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +1 -0
  125. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +10 -9
  126. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +3 -0
  127. data/vendor/faiss/faiss/impl/ClusteringHelpers.cpp +244 -0
  128. data/vendor/faiss/faiss/impl/ClusteringHelpers.h +94 -0
  129. data/vendor/faiss/faiss/impl/ClusteringInitialization.cpp +16 -16
  130. data/vendor/faiss/faiss/impl/CodePacker.cpp +3 -3
  131. data/vendor/faiss/faiss/impl/CodePackerRaBitQ.cpp +1 -1
  132. data/vendor/faiss/faiss/impl/DistanceComputer.h +8 -8
  133. data/vendor/faiss/faiss/impl/FaissAssert.h +6 -3
  134. data/vendor/faiss/faiss/impl/FaissException.h +50 -3
  135. data/vendor/faiss/faiss/impl/HNSW.cpp +92 -317
  136. data/vendor/faiss/faiss/impl/HNSW.h +13 -34
  137. data/vendor/faiss/faiss/impl/IDSelector.cpp +15 -11
  138. data/vendor/faiss/faiss/impl/IDSelector.h +8 -8
  139. data/vendor/faiss/faiss/impl/InvertedListScannerStats.h +26 -0
  140. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +82 -77
  141. data/vendor/faiss/faiss/impl/NNDescent.cpp +62 -25
  142. data/vendor/faiss/faiss/impl/NNDescent.h +6 -2
  143. data/vendor/faiss/faiss/impl/NSG.cpp +38 -21
  144. data/vendor/faiss/faiss/impl/NSG.h +4 -4
  145. data/vendor/faiss/faiss/impl/Panorama.cpp +23 -6
  146. data/vendor/faiss/faiss/impl/Panorama.h +258 -87
  147. data/vendor/faiss/faiss/impl/PdxLayout.cpp +93 -0
  148. data/vendor/faiss/faiss/impl/PdxLayout.h +41 -0
  149. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +46 -32
  150. data/vendor/faiss/faiss/impl/PolysemousTraining.h +3 -3
  151. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +35 -35
  152. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +21 -16
  153. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +30 -23
  154. data/vendor/faiss/faiss/impl/Quantizer.h +2 -2
  155. data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +55 -49
  156. data/vendor/faiss/faiss/impl/RaBitQUtils.h +65 -0
  157. data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +296 -283
  158. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +26 -23
  159. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +1 -1
  160. data/vendor/faiss/faiss/impl/ResultHandler.h +99 -75
  161. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +52 -4
  162. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +27 -1
  163. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +14 -11
  164. data/vendor/faiss/faiss/impl/VisitedTable.h +7 -0
  165. data/vendor/faiss/faiss/impl/approx_topk/approx_topk.h +276 -0
  166. data/vendor/faiss/faiss/impl/approx_topk/avx2.cpp +68 -0
  167. data/vendor/faiss/faiss/{utils → impl}/approx_topk/generic.h +15 -8
  168. data/vendor/faiss/faiss/impl/approx_topk/neon.cpp +68 -0
  169. data/vendor/faiss/faiss/impl/approx_topk/rq_beam_search_tab-inl.h +169 -0
  170. data/vendor/faiss/faiss/impl/approx_topk/rq_beam_search_tab.h +117 -0
  171. data/vendor/faiss/faiss/impl/approx_topk/simdlib256-inl.h +146 -0
  172. data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryHNSW_impl.h +73 -0
  173. data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryHash_impl.h +270 -0
  174. data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryIVF_impl.h +460 -0
  175. data/vendor/faiss/faiss/impl/binary_hamming/IndexIVFSpectralHash_impl.h +159 -0
  176. data/vendor/faiss/faiss/impl/binary_hamming/IndexPQ_impl.h +92 -0
  177. data/vendor/faiss/faiss/impl/binary_hamming/avx2.cpp +26 -0
  178. data/vendor/faiss/faiss/impl/binary_hamming/avx512.cpp +26 -0
  179. data/vendor/faiss/faiss/impl/binary_hamming/dispatch.h +143 -0
  180. data/vendor/faiss/faiss/impl/binary_hamming/neon.cpp +26 -0
  181. data/vendor/faiss/faiss/impl/binary_hamming/rvv.cpp +26 -0
  182. data/vendor/faiss/faiss/impl/expanded_scanners.h +8 -3
  183. data/vendor/faiss/faiss/impl/{FastScanDistancePostProcessing.h → fast_scan/FastScanDistancePostProcessing.h} +13 -6
  184. data/vendor/faiss/faiss/impl/{LookupTableScaler.h → fast_scan/LookupTableScaler.h} +16 -5
  185. data/vendor/faiss/faiss/impl/fast_scan/accumulate_loops.h +237 -0
  186. data/vendor/faiss/faiss/impl/fast_scan/accumulate_loops_512.h +185 -0
  187. data/vendor/faiss/faiss/impl/fast_scan/decompose_qbs.h +229 -0
  188. data/vendor/faiss/faiss/impl/fast_scan/dispatching.h +268 -0
  189. data/vendor/faiss/faiss/impl/{pq4_fast_scan.cpp → fast_scan/fast_scan.cpp} +169 -2
  190. data/vendor/faiss/faiss/impl/fast_scan/fast_scan.h +341 -0
  191. data/vendor/faiss/faiss/impl/fast_scan/impl-avx2.cpp +36 -0
  192. data/vendor/faiss/faiss/impl/fast_scan/impl-avx512.cpp +40 -0
  193. data/vendor/faiss/faiss/impl/fast_scan/impl-neon.cpp +120 -0
  194. data/vendor/faiss/faiss/impl/fast_scan/impl-riscv.cpp +104 -0
  195. data/vendor/faiss/faiss/impl/fast_scan/kernels_simd256.h +213 -0
  196. data/vendor/faiss/faiss/impl/{pq4_fast_scan_search_qbs.cpp → fast_scan/kernels_simd512.h} +26 -356
  197. data/vendor/faiss/faiss/impl/fast_scan/rabitq_dispatching.h +90 -0
  198. data/vendor/faiss/faiss/impl/fast_scan/rabitq_result_handler.h +108 -0
  199. data/vendor/faiss/faiss/impl/{simd_result_handlers.h → fast_scan/simd_result_handlers.h} +282 -134
  200. data/vendor/faiss/faiss/impl/hnsw/LockVector.cpp +54 -0
  201. data/vendor/faiss/faiss/impl/hnsw/LockVector.h +64 -0
  202. data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.cpp +91 -0
  203. data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.h +64 -0
  204. data/vendor/faiss/faiss/impl/hnsw/avx2.cpp +104 -0
  205. data/vendor/faiss/faiss/impl/hnsw/avx512.cpp +111 -0
  206. data/vendor/faiss/faiss/impl/index_read.cpp +1132 -45
  207. data/vendor/faiss/faiss/impl/index_read_utils.h +1 -1
  208. data/vendor/faiss/faiss/impl/index_write.cpp +95 -13
  209. data/vendor/faiss/faiss/impl/io.cpp +6 -6
  210. data/vendor/faiss/faiss/impl/io_macros.h +33 -16
  211. data/vendor/faiss/faiss/impl/kmeans1d.cpp +10 -10
  212. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +37 -23
  213. data/vendor/faiss/faiss/impl/lattice_Zn.h +6 -6
  214. data/vendor/faiss/faiss/impl/mapped_io.cpp +6 -6
  215. data/vendor/faiss/faiss/impl/platform_macros.h +11 -4
  216. data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQScanner_impl.h +549 -0
  217. data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQ_QueryTables.cpp +245 -0
  218. data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQ_QueryTables.h +105 -0
  219. data/vendor/faiss/faiss/impl/pq_code_distance/PQDistanceComputer_impl.h +106 -0
  220. data/vendor/faiss/faiss/impl/pq_code_distance/avx2.cpp +21 -0
  221. data/vendor/faiss/faiss/impl/pq_code_distance/avx512.cpp +21 -0
  222. data/vendor/faiss/faiss/impl/pq_code_distance/neon.cpp +21 -0
  223. data/vendor/faiss/faiss/impl/pq_code_distance/{pq_code_distance-avx2.cpp → pq_code_distance-avx2.h} +9 -13
  224. data/vendor/faiss/faiss/impl/pq_code_distance/{pq_code_distance-avx512.cpp → pq_code_distance-avx512.h} +9 -57
  225. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.cpp +29 -111
  226. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.h +96 -0
  227. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-inl.h +238 -5
  228. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-sve.cpp +5 -7
  229. data/vendor/faiss/faiss/impl/pq_code_distance/rvv.cpp +68 -0
  230. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +311 -477
  231. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +1 -1
  232. data/vendor/faiss/faiss/impl/scalar_quantizer/codecs.h +1 -1
  233. data/vendor/faiss/faiss/impl/scalar_quantizer/distance_computers.h +3 -2
  234. data/vendor/faiss/faiss/impl/scalar_quantizer/quantizers.h +102 -11
  235. data/vendor/faiss/faiss/impl/scalar_quantizer/scanners.h +27 -1
  236. data/vendor/faiss/faiss/impl/scalar_quantizer/similarities.h +3 -3
  237. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx2.cpp +148 -0
  238. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512.cpp +167 -0
  239. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-dispatch.h +59 -0
  240. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-neon.cpp +163 -0
  241. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-rvv.cpp +311 -0
  242. data/vendor/faiss/faiss/impl/scalar_quantizer/training.cpp +192 -8
  243. data/vendor/faiss/faiss/impl/scalar_quantizer/training.h +12 -0
  244. data/vendor/faiss/faiss/impl/simd_dispatch.h +100 -66
  245. data/vendor/faiss/faiss/impl/simdlib/simdlib.h +57 -0
  246. data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_avx2.h +264 -172
  247. data/vendor/faiss/faiss/impl/simdlib/simdlib_avx512.h +414 -0
  248. data/vendor/faiss/faiss/impl/simdlib/simdlib_dispatch.h +44 -0
  249. data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_emulated.h +231 -166
  250. data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_neon.h +270 -218
  251. data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_ppc64.h +201 -160
  252. data/vendor/faiss/faiss/impl/svs_io.cpp +12 -3
  253. data/vendor/faiss/faiss/impl/svs_io.h +8 -2
  254. data/vendor/faiss/faiss/index_factory.cpp +86 -18
  255. data/vendor/faiss/faiss/index_io.h +24 -0
  256. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +66 -16
  257. data/vendor/faiss/faiss/invlists/DirectMap.cpp +24 -14
  258. data/vendor/faiss/faiss/invlists/DirectMap.h +4 -3
  259. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +157 -73
  260. data/vendor/faiss/faiss/invlists/InvertedLists.h +86 -23
  261. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +4 -4
  262. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +13 -13
  263. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
  264. data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +1 -1
  265. data/vendor/faiss/faiss/svs/IndexSVSFlat.cpp +2 -2
  266. data/vendor/faiss/faiss/svs/IndexSVSIVF.cpp +350 -0
  267. data/vendor/faiss/faiss/svs/IndexSVSIVF.h +128 -0
  268. data/vendor/faiss/faiss/svs/IndexSVSIVFLVQ.cpp +40 -0
  269. data/vendor/faiss/faiss/svs/IndexSVSIVFLVQ.h +43 -0
  270. data/vendor/faiss/faiss/svs/IndexSVSIVFLeanVec.cpp +225 -0
  271. data/vendor/faiss/faiss/svs/IndexSVSIVFLeanVec.h +71 -0
  272. data/vendor/faiss/faiss/svs/IndexSVSVamana.cpp +25 -1
  273. data/vendor/faiss/faiss/svs/IndexSVSVamana.h +18 -2
  274. data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.h +1 -1
  275. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +12 -3
  276. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +7 -2
  277. data/vendor/faiss/faiss/utils/Heap.cpp +10 -10
  278. data/vendor/faiss/faiss/utils/NeuralNet.cpp +47 -36
  279. data/vendor/faiss/faiss/utils/NeuralNet.h +1 -1
  280. data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +10 -4
  281. data/vendor/faiss/faiss/utils/distances.cpp +390 -560
  282. data/vendor/faiss/faiss/utils/distances.h +20 -1
  283. data/vendor/faiss/faiss/utils/distances_dispatch.h +117 -37
  284. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +8 -7
  285. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +33 -14
  286. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +12 -1
  287. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +16 -293
  288. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based_neon.cpp +57 -0
  289. data/vendor/faiss/faiss/utils/distances_fused/simdlib_kernel-inl.h +290 -0
  290. data/vendor/faiss/faiss/utils/distances_simd.cpp +5 -177
  291. data/vendor/faiss/faiss/utils/extra_distances.cpp +9 -8
  292. data/vendor/faiss/faiss/utils/extra_distances.h +32 -6
  293. data/vendor/faiss/faiss/utils/hamming-inl.h +13 -11
  294. data/vendor/faiss/faiss/utils/hamming.cpp +66 -517
  295. data/vendor/faiss/faiss/utils/hamming.h +92 -2
  296. data/vendor/faiss/faiss/utils/hamming_distance/common.h +287 -10
  297. data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx2.cpp +15 -0
  298. data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx512.cpp +15 -0
  299. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx2.h +142 -0
  300. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx512.h +234 -0
  301. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-generic.h +368 -0
  302. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-neon.h +322 -0
  303. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-rvv.h +39 -0
  304. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer.h +146 -0
  305. data/vendor/faiss/faiss/utils/hamming_distance/hamming_impl.h +481 -0
  306. data/vendor/faiss/faiss/utils/hamming_distance/hamming_neon.cpp +15 -0
  307. data/vendor/faiss/faiss/utils/hamming_distance/hamming_rvv.cpp +15 -0
  308. data/vendor/faiss/faiss/utils/partitioning.cpp +66 -987
  309. data/vendor/faiss/faiss/utils/partitioning.h +31 -0
  310. data/vendor/faiss/faiss/utils/popcount.h +29 -0
  311. data/vendor/faiss/faiss/utils/pq_code_distance.h +2 -2
  312. data/vendor/faiss/faiss/utils/prefetch.h +2 -2
  313. data/vendor/faiss/faiss/utils/quantize_lut.cpp +30 -30
  314. data/vendor/faiss/faiss/utils/quantize_lut.h +1 -1
  315. data/vendor/faiss/faiss/utils/rabitq_simd.h +57 -536
  316. data/vendor/faiss/faiss/utils/random.cpp +6 -6
  317. data/vendor/faiss/faiss/utils/simd_impl/IVFFlatScanner-inl.h +51 -0
  318. data/vendor/faiss/faiss/utils/simd_impl/distances_aarch64.cpp +5 -1
  319. data/vendor/faiss/faiss/utils/simd_impl/distances_arm_sve.cpp +213 -4
  320. data/vendor/faiss/faiss/utils/simd_impl/distances_autovec-inl.h +163 -10
  321. data/vendor/faiss/faiss/utils/simd_impl/distances_avx2.cpp +250 -4
  322. data/vendor/faiss/faiss/utils/simd_impl/distances_avx512.cpp +7 -4
  323. data/vendor/faiss/faiss/utils/simd_impl/distances_rvv.cpp +189 -0
  324. data/vendor/faiss/faiss/utils/simd_impl/distances_simdlib256.h +195 -0
  325. data/vendor/faiss/faiss/utils/simd_impl/distances_sse-inl.h +2 -1
  326. data/vendor/faiss/faiss/utils/{distances_fused/simdlib_based.h → simd_impl/exhaustive_L2sqr_blas_cmax.h} +5 -10
  327. data/vendor/faiss/faiss/utils/simd_impl/hamming_impl.h +481 -0
  328. data/vendor/faiss/faiss/utils/simd_impl/partitioning_avx2.cpp +14 -0
  329. data/vendor/faiss/faiss/utils/simd_impl/partitioning_neon.cpp +14 -0
  330. data/vendor/faiss/faiss/utils/simd_impl/partitioning_simdlib256.h +1085 -0
  331. data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx2.cpp +355 -0
  332. data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx512.cpp +477 -0
  333. data/vendor/faiss/faiss/utils/simd_impl/rabitq_neon.cpp +55 -0
  334. data/vendor/faiss/faiss/utils/simd_impl/rabitq_rvv.cpp +55 -0
  335. data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_dispatch.h +32 -0
  336. data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels.h +43 -0
  337. data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels_avx2.cpp +57 -0
  338. data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels_avx512.cpp +45 -0
  339. data/vendor/faiss/faiss/utils/simd_levels.cpp +17 -5
  340. data/vendor/faiss/faiss/utils/simd_levels.h +93 -1
  341. data/vendor/faiss/faiss/utils/sorting.cpp +48 -36
  342. data/vendor/faiss/faiss/utils/utils.cpp +5 -5
  343. data/vendor/faiss/faiss/utils/utils.h +3 -3
  344. metadata +119 -34
  345. data/vendor/faiss/faiss/impl/RaBitQStats.cpp +0 -29
  346. data/vendor/faiss/faiss/impl/RaBitQStats.h +0 -56
  347. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +0 -224
  348. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +0 -230
  349. data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +0 -84
  350. data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +0 -196
  351. data/vendor/faiss/faiss/utils/approx_topk/mode.h +0 -34
  352. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +0 -36
  353. data/vendor/faiss/faiss/utils/extra_distances-inl.h +0 -235
  354. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +0 -462
  355. data/vendor/faiss/faiss/utils/hamming_distance/avx512-inl.h +0 -490
  356. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +0 -449
  357. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +0 -87
  358. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +0 -524
  359. data/vendor/faiss/faiss/utils/simdlib.h +0 -42
  360. data/vendor/faiss/faiss/utils/simdlib_avx512.h +0 -365
  361. /data/ext/faiss/{utils_rb.h → utils.h} +0 -0
@@ -10,6 +10,7 @@
10
10
  #include <faiss/impl/FaissAssert.h>
11
11
  #include <faiss/impl/RaBitQUtils.h>
12
12
  #include <faiss/impl/RaBitQuantizerMultiBit.h>
13
+ #include <faiss/impl/simd_dispatch.h>
13
14
  #include <faiss/utils/distances.h>
14
15
  #include <faiss/utils/rabitq_simd.h>
15
16
 
@@ -27,10 +28,13 @@ using rabitq_utils::QueryFactorsData;
27
28
  using rabitq_utils::SignBitFactors;
28
29
  using rabitq_utils::SignBitFactorsWithError;
29
30
 
30
- RaBitQuantizer::RaBitQuantizer(size_t d, MetricType metric, size_t nb_bits)
31
- : Quantizer(d, 0), // code_size will be set below
31
+ RaBitQuantizer::RaBitQuantizer(
32
+ size_t d_in,
33
+ MetricType metric,
34
+ size_t nb_bits_in)
35
+ : Quantizer(d_in, 0), // code_size will be set below
32
36
  metric_type{metric},
33
- nb_bits{nb_bits} {
37
+ nb_bits{nb_bits_in} {
34
38
  // Validate nb_bits range
35
39
  FAISS_THROW_IF_NOT(nb_bits >= 1 && nb_bits <= 9);
36
40
 
@@ -38,7 +42,7 @@ RaBitQuantizer::RaBitQuantizer(size_t d, MetricType metric, size_t nb_bits)
38
42
  code_size = compute_code_size(d, nb_bits);
39
43
  }
40
44
 
41
- size_t RaBitQuantizer::compute_code_size(size_t d, size_t num_bits) const {
45
+ size_t RaBitQuantizer::compute_code_size(size_t d_in, size_t num_bits) const {
42
46
  // Validate inputs
43
47
  FAISS_THROW_IF_NOT(num_bits >= 1 && num_bits <= 9);
44
48
 
@@ -50,7 +54,7 @@ size_t RaBitQuantizer::compute_code_size(size_t d, size_t num_bits) const {
50
54
  // Layout for multi-bit: [binary_code: (d+7)/8
51
55
  // bytes][SignBitFactorsWithError: 12 bytes]
52
56
  // factors = or_minus_c_l2sqr (4) + dp_multiplier (4) + f_error (4)
53
- size_t base_size = (d + 7) / 8 +
57
+ size_t base_size = (d_in + 7) / 8 +
54
58
  (ex_bits == 0 ? sizeof(SignBitFactors)
55
59
  : sizeof(SignBitFactorsWithError));
56
60
 
@@ -58,7 +62,7 @@ size_t RaBitQuantizer::compute_code_size(size_t d, size_t num_bits) const {
58
62
  // Layout: [ex_code: (d*ex_bits+7)/8 bytes][ex_factors: 8 bytes]
59
63
  size_t ex_size = 0;
60
64
  if (ex_bits > 0) {
61
- ex_size = (d * ex_bits + 7) / 8 + sizeof(ExtraBitsFactors);
65
+ ex_size = (d_in * ex_bits + 7) / 8 + sizeof(ExtraBitsFactors);
62
66
  }
63
67
 
64
68
  return base_size + ex_size;
@@ -92,7 +96,7 @@ void RaBitQuantizer::compute_codes_core(
92
96
 
93
97
  // Compute codes
94
98
  #pragma omp parallel for if (n > 1000)
95
- for (int64_t i = 0; i < n; i++) {
99
+ for (int64_t i = 0; i < static_cast<int64_t>(n); i++) {
96
100
  // Pointer to this vector's code
97
101
  uint8_t* code = codes + i * code_size;
98
102
 
@@ -186,7 +190,7 @@ void RaBitQuantizer::decode_core(
186
190
  const size_t ex_bits = nb_bits - 1;
187
191
 
188
192
  #pragma omp parallel for if (n > 1000)
189
- for (int64_t i = 0; i < n; i++) {
193
+ for (int64_t i = 0; i < static_cast<int64_t>(n); i++) {
190
194
  const uint8_t* code = codes + i * code_size;
191
195
 
192
196
  // split the code into parts
@@ -218,162 +222,159 @@ void RaBitQuantizer::decode_core(
218
222
 
219
223
  namespace {
220
224
 
225
+ // Distance computers templatized on SIMDLevel to avoid per-call dynamic
226
+ // dispatch. The SIMDLevel is baked in at construction time via
227
+ // get_distance_computer, so virtual calls through the base class go
228
+ // directly to the SIMD-specialized code.
229
+
230
+ template <SIMDLevel SL>
221
231
  struct RaBitQDistanceComputerNotQ : RaBitQDistanceComputer {
222
232
  // the rotated query (qr - c)
223
233
  std::vector<float> rotated_q;
224
234
  // some additional numbers for the query
225
235
  QueryFactorsData query_fac;
226
236
 
227
- RaBitQDistanceComputerNotQ();
237
+ RaBitQDistanceComputerNotQ() = default;
228
238
 
229
239
  // Compute distance using only 1-bit codes (fast)
230
- float distance_to_code_1bit(const uint8_t* code) override;
240
+ float distance_to_code_1bit(const uint8_t* code) override {
241
+ FAISS_ASSERT(code != nullptr);
242
+ FAISS_ASSERT(
243
+ (metric_type == MetricType::METRIC_L2 ||
244
+ metric_type == MetricType::METRIC_INNER_PRODUCT));
245
+ FAISS_ASSERT(rotated_q.size() == d);
231
246
 
232
- // Compute full distance using 1-bit + ex-bits (accurate)
233
- float distance_to_code_full(const uint8_t* code) override;
247
+ // split the code into parts
248
+ const uint8_t* binary_data = code;
234
249
 
235
- void set_query(const float* x) override;
236
- };
250
+ // Cast to appropriate type based on nb_bits
251
+ // For 1-bit: use SignBitFactors (8 bytes)
252
+ // For multi-bit: use SignBitFactorsWithError (12 bytes) which includes
253
+ // f_error
254
+ size_t ex_bits = nb_bits - 1;
255
+ const SignBitFactors* base_fac = (ex_bits == 0)
256
+ ? reinterpret_cast<const SignBitFactors*>(code + (d + 7) / 8)
257
+ : reinterpret_cast<const SignBitFactorsWithError*>(
258
+ code + (d + 7) / 8);
259
+
260
+ // this is the baseline code
261
+ //
262
+ // compute <q,o> using floats
263
+ float dot_qo = 0;
264
+ // It was a willful decision (after the discussion) to not to pre-cache
265
+ // the sum of all bits, just in order to reduce the overhead per
266
+ // vector.
267
+ uint64_t sum_q = 0;
268
+
269
+ for (size_t i = 0; i < d; i++) {
270
+ // Extract i-th bit
271
+ bool bit = rabitq_utils::extract_bit_standard(binary_data, i);
272
+ // accumulate dp
273
+ dot_qo += bit ? rotated_q[i] : 0;
274
+ // accumulate sum-of-bits
275
+ sum_q += bit ? 1 : 0;
276
+ }
237
277
 
238
- RaBitQDistanceComputerNotQ::RaBitQDistanceComputerNotQ() = default;
278
+ // Apply query factors
279
+ float final_dot =
280
+ query_fac.c1 * dot_qo + query_fac.c2 * sum_q - query_fac.c34;
239
281
 
240
- float RaBitQDistanceComputerNotQ::distance_to_code_1bit(const uint8_t* code) {
241
- FAISS_ASSERT(code != nullptr);
242
- FAISS_ASSERT(
243
- (metric_type == MetricType::METRIC_L2 ||
244
- metric_type == MetricType::METRIC_INNER_PRODUCT));
245
- FAISS_ASSERT(rotated_q.size() == d);
246
-
247
- // split the code into parts
248
- const uint8_t* binary_data = code;
249
-
250
- // Cast to appropriate type based on nb_bits
251
- // For 1-bit: use SignBitFactors (8 bytes)
252
- // For multi-bit: use SignBitFactorsWithError (12 bytes) which includes
253
- // f_error
254
- size_t ex_bits = nb_bits - 1;
255
- const SignBitFactors* base_fac = (ex_bits == 0)
256
- ? reinterpret_cast<const SignBitFactors*>(code + (d + 7) / 8)
257
- : reinterpret_cast<const SignBitFactorsWithError*>(
258
- code + (d + 7) / 8);
259
-
260
- // this is the baseline code
261
- //
262
- // compute <q,o> using floats
263
- float dot_qo = 0;
264
- // It was a willful decision (after the discussion) to not to pre-cache
265
- // the sum of all bits, just in order to reduce the overhead per vector.
266
- uint64_t sum_q = 0;
267
-
268
- for (size_t i = 0; i < d; i++) {
269
- // Extract i-th bit
270
- bool bit = rabitq_utils::extract_bit_standard(binary_data, i);
271
- // accumulate dp
272
- dot_qo += bit ? rotated_q[i] : 0;
273
- // accumulate sum-of-bits
274
- sum_q += bit ? 1 : 0;
275
- }
282
+ // pre_dist = ||or - c||^2 + ||qr - c||^2 -
283
+ // 2 * ||or - c|| * ||qr - c|| * <q,o> - (IP ? ||or||^2 : 0)
284
+ float pre_dist = base_fac->or_minus_c_l2sqr + query_fac.qr_to_c_L2sqr -
285
+ 2 * base_fac->dp_multiplier * final_dot;
276
286
 
277
- // Apply query factors
278
- float final_dot =
279
- query_fac.c1 * dot_qo + query_fac.c2 * sum_q - query_fac.c34;
280
-
281
- // pre_dist = ||or - c||^2 + ||qr - c||^2 -
282
- // 2 * ||or - c|| * ||qr - c|| * <q,o> - (IP ? ||or||^2 : 0)
283
- float pre_dist = base_fac->or_minus_c_l2sqr + query_fac.qr_to_c_L2sqr -
284
- 2 * base_fac->dp_multiplier * final_dot;
285
-
286
- if (metric_type == MetricType::METRIC_L2) {
287
- // ||or - q||^ 2
288
- return pre_dist;
289
- } else {
290
- // metric == MetricType::METRIC_INNER_PRODUCT
291
- return -0.5f * (pre_dist - query_fac.qr_norm_L2sqr);
287
+ if (metric_type == MetricType::METRIC_L2) {
288
+ // ||or - q||^ 2
289
+ return pre_dist;
290
+ } else {
291
+ // metric == MetricType::METRIC_INNER_PRODUCT
292
+ // 2 * (or, q) = (||or - q||^2 - ||q||^2 - ||or||^2)
293
+ return -0.5f * (pre_dist - query_fac.qr_norm_L2sqr);
294
+ }
292
295
  }
293
- }
294
296
 
295
- float RaBitQDistanceComputerNotQ::distance_to_code_full(const uint8_t* code) {
296
- FAISS_ASSERT(code != nullptr);
297
- FAISS_ASSERT(
298
- (metric_type == MetricType::METRIC_L2 ||
299
- metric_type == MetricType::METRIC_INNER_PRODUCT));
300
- FAISS_ASSERT(rotated_q.size() == d);
297
+ // Compute full distance using 1-bit + ex-bits (accurate)
298
+ float distance_to_code_full(const uint8_t* code) override {
299
+ FAISS_ASSERT(code != nullptr);
300
+ FAISS_ASSERT(
301
+ (metric_type == MetricType::METRIC_L2 ||
302
+ metric_type == MetricType::METRIC_INNER_PRODUCT));
303
+ FAISS_ASSERT(rotated_q.size() == d);
301
304
 
302
- size_t ex_bits = nb_bits - 1;
305
+ size_t ex_bits = nb_bits - 1;
303
306
 
304
- if (ex_bits == 0) {
305
- // No ex-bits, just return 1-bit distance
306
- return distance_to_code_1bit(code);
307
+ if (ex_bits == 0) {
308
+ // No ex-bits, just return 1-bit distance
309
+ return distance_to_code_1bit(code);
310
+ }
311
+
312
+ // Extract pointers to code sections
313
+ const uint8_t* binary_data = code;
314
+ size_t offset = (d + 7) / 8 + sizeof(SignBitFactorsWithError);
315
+ const uint8_t* ex_code = code + offset;
316
+ const ExtraBitsFactors* ex_fac =
317
+ reinterpret_cast<const ExtraBitsFactors*>(
318
+ ex_code + (d * ex_bits + 7) / 8);
319
+
320
+ float qr_base = (metric_type == MetricType::METRIC_INNER_PRODUCT)
321
+ ? query_fac.q_dot_c
322
+ : query_fac.qr_to_c_L2sqr;
323
+ return rabitq_utils::compute_full_multibit_distance<SL>(
324
+ binary_data,
325
+ ex_code,
326
+ *ex_fac,
327
+ rotated_q.data(),
328
+ qr_base,
329
+ d,
330
+ ex_bits,
331
+ metric_type);
307
332
  }
308
333
 
309
- // Extract pointers to code sections
310
- const uint8_t* binary_data = code;
311
- size_t offset = (d + 7) / 8 + sizeof(SignBitFactorsWithError);
312
- const uint8_t* ex_code = code + offset;
313
- const ExtraBitsFactors* ex_fac = reinterpret_cast<const ExtraBitsFactors*>(
314
- ex_code + (d * ex_bits + 7) / 8);
315
-
316
- // Call shared utility directly with rotated_q pointer
317
- float qr_base = (metric_type == MetricType::METRIC_INNER_PRODUCT)
318
- ? query_fac.q_dot_c
319
- : query_fac.qr_to_c_L2sqr;
320
- return rabitq_utils::compute_full_multibit_distance(
321
- binary_data,
322
- ex_code,
323
- *ex_fac,
324
- rotated_q.data(),
325
- qr_base,
326
- d,
327
- ex_bits,
328
- metric_type);
329
- }
334
+ void set_query(const float* x) override {
335
+ q = x;
336
+ FAISS_ASSERT(x != nullptr);
337
+ FAISS_ASSERT(
338
+ (metric_type == MetricType::METRIC_L2 ||
339
+ metric_type == MetricType::METRIC_INNER_PRODUCT));
330
340
 
331
- void RaBitQDistanceComputerNotQ::set_query(const float* x) {
332
- q = x;
333
- FAISS_ASSERT(x != nullptr);
334
- FAISS_ASSERT(
335
- (metric_type == MetricType::METRIC_L2 ||
336
- metric_type == MetricType::METRIC_INNER_PRODUCT));
337
-
338
- // compute the distance from the query to the centroid
339
- if (centroid != nullptr) {
340
- query_fac.qr_to_c_L2sqr = fvec_L2sqr(x, centroid, d);
341
- } else {
342
- query_fac.qr_to_c_L2sqr = fvec_norm_L2sqr(x, d);
343
- }
341
+ // compute the distance from the query to the centroid
342
+ if (centroid != nullptr) {
343
+ query_fac.qr_to_c_L2sqr = fvec_L2sqr(x, centroid, d);
344
+ } else {
345
+ query_fac.qr_to_c_L2sqr = fvec_norm_L2sqr(x, d);
346
+ }
344
347
 
345
- // subtract c, obtain P^(-1)(qr - c)
346
- rotated_q.resize(d);
347
- for (size_t i = 0; i < d; i++) {
348
- rotated_q[i] = x[i] - ((centroid == nullptr) ? 0 : centroid[i]);
349
- }
348
+ // subtract c, obtain P^(-1)(qr - c)
349
+ rotated_q.resize(d);
350
+ for (size_t i = 0; i < d; i++) {
351
+ rotated_q[i] = x[i] - ((centroid == nullptr) ? 0 : centroid[i]);
352
+ }
350
353
 
351
- // Compute g_error (query norm for lower bound computation)
352
- // g_error = ||qr - c|| (L2 norm of rotated query)
353
- g_error = std::sqrt(query_fac.qr_to_c_L2sqr);
354
+ // Compute g_error = ||qr - c|| (L2 norm of rotated query)
355
+ g_error = std::sqrt(query_fac.qr_to_c_L2sqr);
354
356
 
355
- // compute some numbers
356
- const float inv_d = (d == 0) ? 1.0f : (1.0f / std::sqrt((float)d));
357
+ // compute some numbers — do not quantize the query
358
+ const float inv_d = (d == 0) ? 1.0f : (1.0f / std::sqrt((float)d));
357
359
 
358
- // do not quantize the query
359
- float sum_q = 0;
360
- for (size_t i = 0; i < d; i++) {
361
- sum_q += rotated_q[i];
362
- }
360
+ float sum_q = 0;
361
+ for (size_t i = 0; i < d; i++) {
362
+ sum_q += rotated_q[i];
363
+ }
363
364
 
364
- query_fac.c1 = 2 * inv_d;
365
- query_fac.c2 = 0;
366
- query_fac.c34 = sum_q * inv_d;
365
+ query_fac.c1 = 2 * inv_d;
366
+ query_fac.c2 = 0;
367
+ query_fac.c34 = sum_q * inv_d;
367
368
 
368
- if (metric_type == MetricType::METRIC_INNER_PRODUCT) {
369
- // precompute if needed
370
- query_fac.qr_norm_L2sqr = fvec_norm_L2sqr(x, d);
371
- query_fac.q_dot_c =
372
- centroid ? fvec_inner_product(x, centroid, d) : 0.0f;
369
+ if (metric_type == MetricType::METRIC_INNER_PRODUCT) {
370
+ query_fac.qr_norm_L2sqr = fvec_norm_L2sqr(x, d);
371
+ query_fac.q_dot_c =
372
+ centroid ? fvec_inner_product(x, centroid, d) : 0.0f;
373
+ }
373
374
  }
374
- }
375
+ };
375
376
 
376
- //
377
+ template <SIMDLevel SL>
377
378
  struct RaBitQDistanceComputerQ : RaBitQDistanceComputer {
378
379
  // the rotated and quantized query (qr - c)
379
380
  std::vector<float> rotated_q;
@@ -391,176 +392,188 @@ struct RaBitQDistanceComputerQ : RaBitQDistanceComputer {
391
392
  // the smallest value divisible by 8 that is not smaller than dim
392
393
  size_t popcount_aligned_dim = 0;
393
394
 
394
- RaBitQDistanceComputerQ();
395
+ RaBitQDistanceComputerQ() = default;
395
396
 
396
397
  // Compute distance using only 1-bit codes (fast)
397
- float distance_to_code_1bit(const uint8_t* code) override;
398
-
399
- // Compute full distance using 1-bit + ex-bits (accurate)
400
- float distance_to_code_full(const uint8_t* code) override;
398
+ float distance_to_code_1bit(const uint8_t* code) override {
399
+ FAISS_ASSERT(code != nullptr);
400
+ FAISS_ASSERT(
401
+ (metric_type == MetricType::METRIC_L2 ||
402
+ metric_type == MetricType::METRIC_INNER_PRODUCT));
401
403
 
402
- void set_query(const float* x) override;
403
- };
404
+ // split the code into parts
405
+ size_t size = (d + 7) / 8;
406
+ const uint8_t* binary_data = code;
404
407
 
405
- RaBitQDistanceComputerQ::RaBitQDistanceComputerQ() = default;
408
+ // Cast to appropriate type based on nb_bits
409
+ // For 1-bit: use SignBitFactors (8 bytes)
410
+ // For multi-bit: use SignBitFactorsWithError (12 bytes) which
411
+ // includes f_error
412
+ size_t ex_bits = nb_bits - 1;
413
+ const SignBitFactors* base_fac = (ex_bits == 0)
414
+ ? reinterpret_cast<const SignBitFactors*>(code + size)
415
+ : reinterpret_cast<const SignBitFactorsWithError*>(code + size);
416
+
417
+ // this is ||or - c||^2 - (IP ? ||or||^2 : 0)
418
+ float final_dot = 0;
419
+ if (centered) {
420
+ int64_t int_dot = ((1 << qb) - 1) * d;
421
+ // See RaBitDistanceComputerNotQ::distance_to_code() for
422
+ // baseline code.
423
+ int_dot -= 2 *
424
+ rabitq::bitwise_xor_dot_product<SL>(
425
+ rearranged_rotated_qq.data(),
426
+ binary_data,
427
+ size,
428
+ qb);
429
+ final_dot += int_dot * query_fac.int_dot_scale;
430
+ } else {
431
+ auto dot_qo = rabitq::bitwise_and_dot_product<SL>(
432
+ rearranged_rotated_qq.data(), binary_data, size, qb);
433
+ // It was a willful decision (after the discussion) to not to
434
+ // pre-cache the sum of all bits, just in order to reduce the
435
+ // overhead per vector.
436
+ // process 64-bit popcounts
437
+ auto sum_q = rabitq::popcount<SL>(binary_data, size);
438
+ // dot-product itself
439
+ final_dot += query_fac.c1 * dot_qo;
440
+ // normalizer coefficients
441
+ final_dot += query_fac.c2 * sum_q;
442
+ // normalizer coefficients
443
+ final_dot -= query_fac.c34;
444
+ }
406
445
 
407
- float RaBitQDistanceComputerQ::distance_to_code_1bit(const uint8_t* code) {
408
- FAISS_ASSERT(code != nullptr);
409
- FAISS_ASSERT(
410
- (metric_type == MetricType::METRIC_L2 ||
411
- metric_type == MetricType::METRIC_INNER_PRODUCT));
446
+ const float pre_dist = base_fac->or_minus_c_l2sqr +
447
+ query_fac.qr_to_c_L2sqr -
448
+ 2 * base_fac->dp_multiplier * final_dot;
412
449
 
413
- // split the code into parts
414
- size_t size = (d + 7) / 8;
415
- const uint8_t* binary_data = code;
416
-
417
- // Cast to appropriate type based on nb_bits
418
- // For 1-bit: use SignBitFactors (8 bytes)
419
- // For multi-bit: use SignBitFactorsWithError (12 bytes) which includes
420
- // f_error
421
- size_t ex_bits = nb_bits - 1;
422
- const SignBitFactors* base_fac = (ex_bits == 0)
423
- ? reinterpret_cast<const SignBitFactors*>(code + size)
424
- : reinterpret_cast<const SignBitFactorsWithError*>(code + size);
425
-
426
- // this is ||or - c||^2 - (IP ? ||or||^2 : 0)
427
- float final_dot = 0;
428
- if (centered) {
429
- int64_t int_dot = ((1 << qb) - 1) * d;
430
- // See RaBitDistanceComputerNotQ::distance_to_code() for baseline code.
431
- int_dot -= 2 *
432
- rabitq::bitwise_xor_dot_product(
433
- rearranged_rotated_qq.data(), binary_data, size, qb);
434
- final_dot += int_dot * query_fac.int_dot_scale;
435
- } else {
436
- auto dot_qo = rabitq::bitwise_and_dot_product(
437
- rearranged_rotated_qq.data(), binary_data, size, qb);
438
- // It was a willful decision (after the discussion) to not to pre-cache
439
- // the sum of all bits, just in order to reduce the overhead per vector.
440
- // process 64-bit popcounts
441
- auto sum_q = rabitq::popcount(binary_data, size);
442
- // dot-product itself
443
- final_dot += query_fac.c1 * dot_qo;
444
- // normalizer coefficients
445
- final_dot += query_fac.c2 * sum_q;
446
- // normalizer coefficients
447
- final_dot -= query_fac.c34;
450
+ if (metric_type == MetricType::METRIC_L2) {
451
+ // ||or - q||^ 2
452
+ return pre_dist;
453
+ } else {
454
+ // metric == MetricType::METRIC_INNER_PRODUCT
455
+ // 2 * (or, q) = (||or - q||^2 - ||q||^2 - ||or||^2)
456
+ return -0.5f * (pre_dist - query_fac.qr_norm_L2sqr);
457
+ }
448
458
  }
449
459
 
450
- // pre_dist = ||or - c||^2 + ||qr - c||^2 -
451
- // 2 * ||or - c|| * ||qr - c|| * <q,o> - (IP ? ||or||^2 : 0)
452
- const float pre_dist = base_fac->or_minus_c_l2sqr +
453
- query_fac.qr_to_c_L2sqr - 2 * base_fac->dp_multiplier * final_dot;
454
-
455
- if (metric_type == MetricType::METRIC_L2) {
456
- // ||or - q||^ 2
457
- return pre_dist;
458
- } else {
459
- // metric == MetricType::METRIC_INNER_PRODUCT
460
- // 2 * (or, q) = (||or - q||^2 - ||q||^2 - ||or||^2)
461
- return -0.5f * (pre_dist - query_fac.qr_norm_L2sqr);
462
- }
463
- }
460
+ // Compute full distance using 1-bit + ex-bits (accurate)
461
+ float distance_to_code_full(const uint8_t* code) override {
462
+ FAISS_ASSERT(code != nullptr);
463
+ FAISS_ASSERT(
464
+ (metric_type == MetricType::METRIC_L2 ||
465
+ metric_type == MetricType::METRIC_INNER_PRODUCT));
466
+ FAISS_ASSERT(rotated_q.size() == d);
464
467
 
465
- float RaBitQDistanceComputerQ::distance_to_code_full(const uint8_t* code) {
466
- FAISS_ASSERT(code != nullptr);
467
- FAISS_ASSERT(
468
- (metric_type == MetricType::METRIC_L2 ||
469
- metric_type == MetricType::METRIC_INNER_PRODUCT));
470
- FAISS_ASSERT(rotated_q.size() == d);
468
+ size_t ex_bits = nb_bits - 1;
471
469
 
472
- size_t ex_bits = nb_bits - 1;
470
+ if (ex_bits == 0) {
471
+ // No ex-bits, just return 1-bit distance
472
+ return distance_to_code_1bit(code);
473
+ }
473
474
 
474
- if (ex_bits == 0) {
475
- // No ex-bits, just return 1-bit distance
476
- return distance_to_code_1bit(code);
475
+ // Extract pointers to code sections
476
+ const uint8_t* binary_data = code;
477
+ size_t offset = (d + 7) / 8 + sizeof(SignBitFactorsWithError);
478
+ const uint8_t* ex_code = code + offset;
479
+ const ExtraBitsFactors* ex_fac =
480
+ reinterpret_cast<const ExtraBitsFactors*>(
481
+ ex_code + (d * ex_bits + 7) / 8);
482
+
483
+ float qr_base = (metric_type == MetricType::METRIC_INNER_PRODUCT)
484
+ ? query_fac.q_dot_c
485
+ : query_fac.qr_to_c_L2sqr;
486
+ return rabitq_utils::compute_full_multibit_distance<SL>(
487
+ binary_data,
488
+ ex_code,
489
+ *ex_fac,
490
+ rotated_q.data(),
491
+ qr_base,
492
+ d,
493
+ ex_bits,
494
+ metric_type);
477
495
  }
478
496
 
479
- // Extract pointers to code sections
480
- const uint8_t* binary_data = code;
481
- size_t offset = (d + 7) / 8 + sizeof(SignBitFactorsWithError);
482
- const uint8_t* ex_code = code + offset;
483
- const ExtraBitsFactors* ex_fac = reinterpret_cast<const ExtraBitsFactors*>(
484
- ex_code + (d * ex_bits + 7) / 8);
485
-
486
- // Call shared utility directly with rotated_q pointer
487
- float qr_base = (metric_type == MetricType::METRIC_INNER_PRODUCT)
488
- ? query_fac.q_dot_c
489
- : query_fac.qr_to_c_L2sqr;
490
- return rabitq_utils::compute_full_multibit_distance(
491
- binary_data,
492
- ex_code,
493
- *ex_fac,
494
- rotated_q.data(),
495
- qr_base,
496
- d,
497
- ex_bits,
498
- metric_type);
499
- }
497
+ void set_query(const float* x) override {
498
+ q = x;
499
+ FAISS_ASSERT(x != nullptr);
500
+ FAISS_ASSERT(
501
+ (metric_type == MetricType::METRIC_L2 ||
502
+ metric_type == MetricType::METRIC_INNER_PRODUCT));
503
+ FAISS_THROW_IF_NOT(qb <= 8);
504
+ FAISS_THROW_IF_NOT(qb > 0);
505
+
506
+ // Use shared utilities for core query factor computation
507
+ // rotated_q is populated directly by compute_query_factors as an
508
+ // output parameter
509
+ query_fac = rabitq_utils::compute_query_factors(
510
+ x,
511
+ d,
512
+ centroid,
513
+ qb,
514
+ centered,
515
+ metric_type,
516
+ rotated_q,
517
+ rotated_qq);
518
+
519
+ // Compute g_error (query norm for lower bound computation)
520
+ // g_error = ||qr - c|| (L2 norm of rotated query)
521
+ g_error = std::sqrt(query_fac.qr_to_c_L2sqr);
522
+
523
+ // Rearrange the query vector for SIMD operations
524
+ // (RaBitQuantizer-specific)
525
+ popcount_aligned_dim = ((d + 7) / 8) * 8;
526
+ size_t offset = (d + 7) / 8;
527
+
528
+ rearranged_rotated_qq.resize(offset * qb);
529
+ std::fill(
530
+ rearranged_rotated_qq.begin(), rearranged_rotated_qq.end(), 0);
531
+
532
+ for (size_t idim = 0; idim < d; idim++) {
533
+ for (size_t iv = 0; iv < qb; iv++) {
534
+ const bool bit = ((rotated_qq[idim] & (1 << iv)) != 0);
535
+ rearranged_rotated_qq[iv * offset + idim / 8] |=
536
+ bit ? (1 << (idim % 8)) : 0;
537
+ }
538
+ }
539
+ }
540
+ };
500
541
 
501
542
  // Use shared constant from RaBitQUtils
502
543
  using rabitq_utils::Z_MAX_BY_QB;
503
544
 
504
- void RaBitQDistanceComputerQ::set_query(const float* x) {
505
- q = x;
506
- FAISS_ASSERT(x != nullptr);
507
- FAISS_ASSERT(
508
- (metric_type == MetricType::METRIC_L2 ||
509
- metric_type == MetricType::METRIC_INNER_PRODUCT));
510
- FAISS_THROW_IF_NOT(qb <= 8);
511
- FAISS_THROW_IF_NOT(qb > 0);
512
-
513
- // Use shared utilities for core query factor computation
514
- // rotated_q is populated directly by compute_query_factors as an output
515
- // parameter
516
- query_fac = rabitq_utils::compute_query_factors(
517
- x, d, centroid, qb, centered, metric_type, rotated_q, rotated_qq);
518
-
519
- // Compute g_error (query norm for lower bound computation)
520
- // g_error = ||qr - c|| (L2 norm of rotated query)
521
- g_error = std::sqrt(query_fac.qr_to_c_L2sqr);
522
-
523
- // Rearrange the query vector for SIMD operations (RaBitQuantizer-specific)
524
- popcount_aligned_dim = ((d + 7) / 8) * 8;
525
- size_t offset = (d + 7) / 8;
526
-
527
- rearranged_rotated_qq.resize(offset * qb);
528
- std::fill(rearranged_rotated_qq.begin(), rearranged_rotated_qq.end(), 0);
529
-
530
- for (size_t idim = 0; idim < d; idim++) {
531
- for (size_t iv = 0; iv < qb; iv++) {
532
- const bool bit = ((rotated_qq[idim] & (1 << iv)) != 0);
533
- rearranged_rotated_qq[iv * offset + idim / 8] |=
534
- bit ? (1 << (idim % 8)) : 0;
535
- }
536
- }
537
- }
538
-
539
545
  } // anonymous namespace
540
546
 
541
547
  FlatCodesDistanceComputer* RaBitQuantizer::get_distance_computer(
542
548
  uint8_t qb,
543
549
  const float* centroid_in,
544
550
  bool centered) const {
545
- if (qb == 0) {
546
- auto dc = std::make_unique<RaBitQDistanceComputerNotQ>();
547
- dc->metric_type = metric_type;
548
- dc->d = d;
549
- dc->centroid = centroid_in;
550
- dc->nb_bits = nb_bits;
551
-
552
- return dc.release();
553
- } else {
554
- auto dc = std::make_unique<RaBitQDistanceComputerQ>();
555
- dc->metric_type = metric_type;
556
- dc->d = d;
557
- dc->centroid = centroid_in;
558
- dc->qb = qb;
559
- dc->centered = centered;
560
- dc->nb_bits = nb_bits;
561
-
562
- return dc.release();
563
- }
551
+ // Dispatch on SIMDLevel once here so the distance computer methods
552
+ // call the SIMD-specialized rabitq functions directly (no per-call
553
+ // with_simd_level overhead).
554
+ return with_selected_simd_levels<AVAILABLE_SIMD_LEVELS_A0>(
555
+ [&]<SIMDLevel SL>() -> FlatCodesDistanceComputer* {
556
+ if (qb == 0) {
557
+ auto dc =
558
+ std::make_unique<RaBitQDistanceComputerNotQ<SL>>();
559
+ dc->metric_type = metric_type;
560
+ dc->d = d;
561
+ dc->centroid = centroid_in;
562
+ dc->nb_bits = nb_bits;
563
+
564
+ return dc.release();
565
+ } else {
566
+ auto dc = std::make_unique<RaBitQDistanceComputerQ<SL>>();
567
+ dc->metric_type = metric_type;
568
+ dc->d = d;
569
+ dc->centroid = centroid_in;
570
+ dc->qb = qb;
571
+ dc->centered = centered;
572
+ dc->nb_bits = nb_bits;
573
+
574
+ return dc.release();
575
+ }
576
+ });
564
577
  }
565
578
 
566
579
  } // namespace faiss