faiss 0.5.3 → 0.6.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (379) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +12 -0
  3. data/ext/faiss/ext.cpp +1 -1
  4. data/ext/faiss/extconf.rb +4 -4
  5. data/ext/faiss/index.cpp +63 -45
  6. data/ext/faiss/index_binary.cpp +37 -27
  7. data/ext/faiss/kmeans.cpp +9 -8
  8. data/ext/faiss/pca_matrix.cpp +9 -7
  9. data/ext/faiss/product_quantizer.cpp +13 -11
  10. data/ext/faiss/utils.cpp +4 -2
  11. data/ext/faiss/utils.h +4 -0
  12. data/lib/faiss/version.rb +1 -1
  13. data/lib/faiss.rb +1 -1
  14. data/vendor/faiss/faiss/AutoTune.cpp +214 -82
  15. data/vendor/faiss/faiss/AutoTune.h +14 -1
  16. data/vendor/faiss/faiss/Clustering.cpp +97 -249
  17. data/vendor/faiss/faiss/Clustering.h +18 -0
  18. data/vendor/faiss/faiss/IVFlib.cpp +67 -44
  19. data/vendor/faiss/faiss/Index.cpp +25 -12
  20. data/vendor/faiss/faiss/Index.h +26 -4
  21. data/vendor/faiss/faiss/Index2Layer.cpp +37 -53
  22. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +68 -61
  23. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +36 -34
  24. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +4 -1
  25. data/vendor/faiss/faiss/IndexBinary.cpp +6 -3
  26. data/vendor/faiss/faiss/IndexBinary.h +4 -4
  27. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +1 -1
  28. data/vendor/faiss/faiss/IndexBinaryFlat.h +1 -1
  29. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -4
  30. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +92 -95
  31. data/vendor/faiss/faiss/IndexBinaryHNSW.h +9 -3
  32. data/vendor/faiss/faiss/IndexBinaryHash.cpp +45 -236
  33. data/vendor/faiss/faiss/IndexBinaryHash.h +6 -6
  34. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +120 -414
  35. data/vendor/faiss/faiss/IndexFastScan.cpp +105 -129
  36. data/vendor/faiss/faiss/IndexFastScan.h +35 -24
  37. data/vendor/faiss/faiss/IndexFlat.cpp +216 -152
  38. data/vendor/faiss/faiss/IndexFlat.h +32 -14
  39. data/vendor/faiss/faiss/IndexFlatCodes.cpp +88 -41
  40. data/vendor/faiss/faiss/IndexFlatCodes.h +7 -1
  41. data/vendor/faiss/faiss/IndexHNSW.cpp +299 -187
  42. data/vendor/faiss/faiss/IndexHNSW.h +30 -14
  43. data/vendor/faiss/faiss/IndexIDMap.cpp +26 -22
  44. data/vendor/faiss/faiss/IndexIDMap.h +9 -7
  45. data/vendor/faiss/faiss/IndexIVF.cpp +535 -405
  46. data/vendor/faiss/faiss/IndexIVF.h +47 -16
  47. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +77 -74
  48. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +105 -99
  49. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +6 -3
  50. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +379 -249
  51. data/vendor/faiss/faiss/IndexIVFFastScan.h +65 -60
  52. data/vendor/faiss/faiss/IndexIVFFlat.cpp +41 -124
  53. data/vendor/faiss/faiss/IndexIVFFlat.h +32 -0
  54. data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +89 -138
  55. data/vendor/faiss/faiss/IndexIVFFlatPanorama.h +3 -1
  56. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +18 -15
  57. data/vendor/faiss/faiss/IndexIVFPQ.cpp +77 -907
  58. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +184 -122
  59. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +3 -0
  60. data/vendor/faiss/faiss/IndexIVFPQR.cpp +23 -18
  61. data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +59 -60
  62. data/vendor/faiss/faiss/IndexIVFRaBitQ.h +4 -3
  63. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +564 -416
  64. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +269 -111
  65. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +41 -127
  66. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +1 -1
  67. data/vendor/faiss/faiss/IndexLSH.cpp +44 -25
  68. data/vendor/faiss/faiss/IndexLattice.cpp +41 -36
  69. data/vendor/faiss/faiss/IndexNNDescent.cpp +37 -21
  70. data/vendor/faiss/faiss/IndexNNDescent.h +2 -2
  71. data/vendor/faiss/faiss/IndexNSG.cpp +40 -23
  72. data/vendor/faiss/faiss/IndexNSG.h +0 -2
  73. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +32 -12
  74. data/vendor/faiss/faiss/IndexPQ.cpp +129 -213
  75. data/vendor/faiss/faiss/IndexPQ.h +3 -2
  76. data/vendor/faiss/faiss/IndexPQFastScan.cpp +20 -14
  77. data/vendor/faiss/faiss/IndexPQFastScan.h +3 -0
  78. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -18
  79. data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
  80. data/vendor/faiss/faiss/IndexRaBitQ.cpp +31 -43
  81. data/vendor/faiss/faiss/IndexRaBitQ.h +4 -3
  82. data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +135 -317
  83. data/vendor/faiss/faiss/IndexRaBitQFastScan.h +192 -34
  84. data/vendor/faiss/faiss/IndexRefine.cpp +30 -55
  85. data/vendor/faiss/faiss/IndexRefine.h +4 -4
  86. data/vendor/faiss/faiss/IndexReplicas.cpp +6 -6
  87. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +15 -14
  88. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +1 -1
  89. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +82 -14
  90. data/vendor/faiss/faiss/IndexShards.cpp +13 -13
  91. data/vendor/faiss/faiss/IndexShardsIVF.cpp +21 -15
  92. data/vendor/faiss/faiss/MatrixStats.cpp +5 -4
  93. data/vendor/faiss/faiss/MetaIndexes.cpp +19 -17
  94. data/vendor/faiss/faiss/MetaIndexes.h +1 -1
  95. data/vendor/faiss/faiss/MetricType.h +29 -6
  96. data/vendor/faiss/faiss/SuperKMeans.cpp +656 -0
  97. data/vendor/faiss/faiss/SuperKMeans.h +97 -0
  98. data/vendor/faiss/faiss/VectorTransform.cpp +349 -141
  99. data/vendor/faiss/faiss/VectorTransform.h +39 -16
  100. data/vendor/faiss/faiss/build.cpp +23 -0
  101. data/vendor/faiss/faiss/build.h +15 -0
  102. data/vendor/faiss/faiss/clone_index.cpp +55 -51
  103. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +47 -47
  104. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +11 -0
  105. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +38 -38
  106. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +11 -0
  107. data/vendor/faiss/faiss/{cppcontrib/factory_tools.cpp → factory_tools.cpp} +6 -1
  108. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +1 -1
  109. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +6 -5
  110. data/vendor/faiss/faiss/gpu/GpuResources.h +1 -1
  111. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +9 -9
  112. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +4 -3
  113. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +46 -0
  114. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +56 -0
  115. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +78 -1
  116. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +72 -0
  117. data/vendor/faiss/faiss/gpu/test/TestUtils.h +23 -0
  118. data/vendor/faiss/faiss/gpu/utils/CuvsFilterConvert.h +1 -1
  119. data/vendor/faiss/faiss/gpu/utils/CuvsUtils.h +21 -10
  120. data/vendor/faiss/faiss/gpu_metal/GpuIndexFlat.h +22 -0
  121. data/vendor/faiss/faiss/gpu_metal/MetalCloner.h +35 -0
  122. data/vendor/faiss/faiss/gpu_metal/MetalFlatKernels.h +40 -0
  123. data/vendor/faiss/faiss/gpu_metal/MetalIndex.h +51 -0
  124. data/vendor/faiss/faiss/gpu_metal/MetalIndexFlat.h +65 -0
  125. data/vendor/faiss/faiss/gpu_metal/MetalKernels.h +66 -0
  126. data/vendor/faiss/faiss/gpu_metal/MetalResources.h +79 -0
  127. data/vendor/faiss/faiss/gpu_metal/StandardMetalResources.h +35 -0
  128. data/vendor/faiss/faiss/impl/AdSampling.cpp +103 -0
  129. data/vendor/faiss/faiss/impl/AdSampling.h +35 -0
  130. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +64 -34
  131. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +1 -0
  132. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +10 -9
  133. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +3 -28
  134. data/vendor/faiss/faiss/impl/ClusteringHelpers.cpp +244 -0
  135. data/vendor/faiss/faiss/impl/ClusteringHelpers.h +94 -0
  136. data/vendor/faiss/faiss/impl/ClusteringInitialization.cpp +367 -0
  137. data/vendor/faiss/faiss/impl/ClusteringInitialization.h +107 -0
  138. data/vendor/faiss/faiss/impl/CodePacker.cpp +7 -3
  139. data/vendor/faiss/faiss/impl/CodePacker.h +11 -3
  140. data/vendor/faiss/faiss/impl/CodePackerRaBitQ.cpp +83 -0
  141. data/vendor/faiss/faiss/impl/CodePackerRaBitQ.h +47 -0
  142. data/vendor/faiss/faiss/impl/DistanceComputer.h +8 -8
  143. data/vendor/faiss/faiss/impl/FaissAssert.h +64 -3
  144. data/vendor/faiss/faiss/impl/FaissException.h +50 -3
  145. data/vendor/faiss/faiss/impl/HNSW.cpp +117 -351
  146. data/vendor/faiss/faiss/impl/HNSW.h +21 -40
  147. data/vendor/faiss/faiss/impl/IDSelector.cpp +15 -11
  148. data/vendor/faiss/faiss/impl/IDSelector.h +8 -8
  149. data/vendor/faiss/faiss/impl/InvertedListScannerStats.h +26 -0
  150. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +114 -102
  151. data/vendor/faiss/faiss/impl/NNDescent.cpp +63 -26
  152. data/vendor/faiss/faiss/impl/NNDescent.h +6 -2
  153. data/vendor/faiss/faiss/impl/NSG.cpp +44 -26
  154. data/vendor/faiss/faiss/impl/NSG.h +20 -10
  155. data/vendor/faiss/faiss/impl/Panorama.cpp +76 -52
  156. data/vendor/faiss/faiss/impl/Panorama.h +265 -78
  157. data/vendor/faiss/faiss/impl/PdxLayout.cpp +93 -0
  158. data/vendor/faiss/faiss/impl/PdxLayout.h +41 -0
  159. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +62 -37
  160. data/vendor/faiss/faiss/impl/PolysemousTraining.h +3 -3
  161. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +35 -35
  162. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +21 -16
  163. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +99 -80
  164. data/vendor/faiss/faiss/impl/Quantizer.h +2 -2
  165. data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +135 -37
  166. data/vendor/faiss/faiss/impl/RaBitQUtils.h +148 -21
  167. data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +298 -301
  168. data/vendor/faiss/faiss/impl/RaBitQuantizer.h +3 -10
  169. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.cpp +15 -41
  170. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.h +0 -4
  171. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +40 -32
  172. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +1 -1
  173. data/vendor/faiss/faiss/impl/ResultHandler.h +218 -113
  174. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +119 -2362
  175. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +27 -3
  176. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +14 -11
  177. data/vendor/faiss/faiss/impl/VisitedTable.cpp +42 -0
  178. data/vendor/faiss/faiss/impl/VisitedTable.h +76 -0
  179. data/vendor/faiss/faiss/impl/approx_topk/approx_topk.h +276 -0
  180. data/vendor/faiss/faiss/impl/approx_topk/avx2.cpp +68 -0
  181. data/vendor/faiss/faiss/{utils → impl}/approx_topk/generic.h +15 -8
  182. data/vendor/faiss/faiss/impl/approx_topk/neon.cpp +68 -0
  183. data/vendor/faiss/faiss/impl/approx_topk/rq_beam_search_tab-inl.h +169 -0
  184. data/vendor/faiss/faiss/impl/approx_topk/rq_beam_search_tab.h +117 -0
  185. data/vendor/faiss/faiss/impl/approx_topk/simdlib256-inl.h +146 -0
  186. data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryHNSW_impl.h +73 -0
  187. data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryHash_impl.h +270 -0
  188. data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryIVF_impl.h +460 -0
  189. data/vendor/faiss/faiss/impl/binary_hamming/IndexIVFSpectralHash_impl.h +159 -0
  190. data/vendor/faiss/faiss/impl/binary_hamming/IndexPQ_impl.h +92 -0
  191. data/vendor/faiss/faiss/impl/binary_hamming/avx2.cpp +26 -0
  192. data/vendor/faiss/faiss/impl/binary_hamming/avx512.cpp +26 -0
  193. data/vendor/faiss/faiss/impl/binary_hamming/dispatch.h +143 -0
  194. data/vendor/faiss/faiss/impl/binary_hamming/neon.cpp +26 -0
  195. data/vendor/faiss/faiss/impl/binary_hamming/rvv.cpp +26 -0
  196. data/vendor/faiss/faiss/impl/expanded_scanners.h +163 -0
  197. data/vendor/faiss/faiss/impl/{FastScanDistancePostProcessing.h → fast_scan/FastScanDistancePostProcessing.h} +13 -6
  198. data/vendor/faiss/faiss/impl/{LookupTableScaler.h → fast_scan/LookupTableScaler.h} +16 -5
  199. data/vendor/faiss/faiss/impl/fast_scan/accumulate_loops.h +237 -0
  200. data/vendor/faiss/faiss/impl/fast_scan/accumulate_loops_512.h +185 -0
  201. data/vendor/faiss/faiss/impl/fast_scan/decompose_qbs.h +229 -0
  202. data/vendor/faiss/faiss/impl/fast_scan/dispatching.h +268 -0
  203. data/vendor/faiss/faiss/impl/{pq4_fast_scan.cpp → fast_scan/fast_scan.cpp} +176 -4
  204. data/vendor/faiss/faiss/impl/fast_scan/fast_scan.h +341 -0
  205. data/vendor/faiss/faiss/impl/fast_scan/impl-avx2.cpp +36 -0
  206. data/vendor/faiss/faiss/impl/fast_scan/impl-avx512.cpp +40 -0
  207. data/vendor/faiss/faiss/impl/fast_scan/impl-neon.cpp +120 -0
  208. data/vendor/faiss/faiss/impl/fast_scan/impl-riscv.cpp +104 -0
  209. data/vendor/faiss/faiss/impl/fast_scan/kernels_simd256.h +213 -0
  210. data/vendor/faiss/faiss/impl/{pq4_fast_scan_search_qbs.cpp → fast_scan/kernels_simd512.h} +26 -348
  211. data/vendor/faiss/faiss/impl/fast_scan/rabitq_dispatching.h +90 -0
  212. data/vendor/faiss/faiss/impl/fast_scan/rabitq_result_handler.h +108 -0
  213. data/vendor/faiss/faiss/impl/{simd_result_handlers.h → fast_scan/simd_result_handlers.h} +290 -142
  214. data/vendor/faiss/faiss/impl/hnsw/LockVector.cpp +54 -0
  215. data/vendor/faiss/faiss/impl/hnsw/LockVector.h +64 -0
  216. data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.cpp +91 -0
  217. data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.h +64 -0
  218. data/vendor/faiss/faiss/impl/hnsw/avx2.cpp +104 -0
  219. data/vendor/faiss/faiss/impl/hnsw/avx512.cpp +111 -0
  220. data/vendor/faiss/faiss/impl/index_read.cpp +1950 -505
  221. data/vendor/faiss/faiss/impl/index_read_utils.h +1 -2
  222. data/vendor/faiss/faiss/impl/index_write.cpp +112 -21
  223. data/vendor/faiss/faiss/impl/io.cpp +6 -6
  224. data/vendor/faiss/faiss/impl/io_macros.h +33 -16
  225. data/vendor/faiss/faiss/impl/kmeans1d.cpp +10 -10
  226. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +81 -40
  227. data/vendor/faiss/faiss/impl/lattice_Zn.h +6 -6
  228. data/vendor/faiss/faiss/impl/mapped_io.cpp +15 -8
  229. data/vendor/faiss/faiss/impl/platform_macros.h +11 -4
  230. data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQScanner_impl.h +549 -0
  231. data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQ_QueryTables.cpp +245 -0
  232. data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQ_QueryTables.h +105 -0
  233. data/vendor/faiss/faiss/impl/pq_code_distance/PQDistanceComputer_impl.h +106 -0
  234. data/vendor/faiss/faiss/impl/pq_code_distance/avx2.cpp +21 -0
  235. data/vendor/faiss/faiss/impl/pq_code_distance/avx512.cpp +21 -0
  236. data/vendor/faiss/faiss/impl/pq_code_distance/neon.cpp +21 -0
  237. data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx2.h → pq_code_distance/pq_code_distance-avx2.h} +43 -220
  238. data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx512.h → pq_code_distance/pq_code_distance-avx512.h} +25 -112
  239. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.cpp +59 -0
  240. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.h +96 -0
  241. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-inl.h +256 -0
  242. data/vendor/faiss/faiss/impl/{code_distance/code_distance-sve.h → pq_code_distance/pq_code_distance-sve.cpp} +57 -146
  243. data/vendor/faiss/faiss/impl/pq_code_distance/rvv.cpp +68 -0
  244. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +320 -483
  245. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +1 -1
  246. data/vendor/faiss/faiss/impl/scalar_quantizer/codecs.h +121 -0
  247. data/vendor/faiss/faiss/impl/scalar_quantizer/distance_computers.h +137 -0
  248. data/vendor/faiss/faiss/impl/scalar_quantizer/quantizers.h +371 -0
  249. data/vendor/faiss/faiss/impl/scalar_quantizer/scanners.h +190 -0
  250. data/vendor/faiss/faiss/impl/scalar_quantizer/similarities.h +94 -0
  251. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx2.cpp +603 -0
  252. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512.cpp +597 -0
  253. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-dispatch.h +388 -0
  254. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-neon.cpp +630 -0
  255. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-rvv.cpp +311 -0
  256. data/vendor/faiss/faiss/impl/scalar_quantizer/training.cpp +387 -0
  257. data/vendor/faiss/faiss/impl/scalar_quantizer/training.h +54 -0
  258. data/vendor/faiss/faiss/impl/simd_dispatch.h +173 -0
  259. data/vendor/faiss/faiss/impl/simdlib/simdlib.h +57 -0
  260. data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_avx2.h +274 -171
  261. data/vendor/faiss/faiss/impl/simdlib/simdlib_avx512.h +414 -0
  262. data/vendor/faiss/faiss/impl/simdlib/simdlib_dispatch.h +44 -0
  263. data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_emulated.h +231 -166
  264. data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_neon.h +275 -217
  265. data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_ppc64.h +201 -160
  266. data/vendor/faiss/faiss/impl/svs_io.cpp +12 -3
  267. data/vendor/faiss/faiss/impl/svs_io.h +8 -2
  268. data/vendor/faiss/faiss/index_factory.cpp +115 -28
  269. data/vendor/faiss/faiss/index_io.h +53 -3
  270. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +73 -20
  271. data/vendor/faiss/faiss/invlists/DirectMap.cpp +24 -14
  272. data/vendor/faiss/faiss/invlists/DirectMap.h +4 -3
  273. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +157 -73
  274. data/vendor/faiss/faiss/invlists/InvertedLists.h +86 -23
  275. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +4 -4
  276. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +14 -14
  277. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
  278. data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +9 -19
  279. data/vendor/faiss/faiss/svs/IndexSVSFlat.cpp +2 -2
  280. data/vendor/faiss/faiss/svs/IndexSVSFlat.h +2 -0
  281. data/vendor/faiss/faiss/svs/IndexSVSIVF.cpp +350 -0
  282. data/vendor/faiss/faiss/svs/IndexSVSIVF.h +128 -0
  283. data/vendor/faiss/faiss/svs/IndexSVSIVFLVQ.cpp +40 -0
  284. data/vendor/faiss/faiss/svs/IndexSVSIVFLVQ.h +43 -0
  285. data/vendor/faiss/faiss/svs/IndexSVSIVFLeanVec.cpp +225 -0
  286. data/vendor/faiss/faiss/svs/IndexSVSIVFLeanVec.h +71 -0
  287. data/vendor/faiss/faiss/svs/IndexSVSVamana.cpp +25 -1
  288. data/vendor/faiss/faiss/svs/IndexSVSVamana.h +19 -2
  289. data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.h +1 -1
  290. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +19 -2
  291. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +14 -0
  292. data/vendor/faiss/faiss/utils/Heap.cpp +56 -10
  293. data/vendor/faiss/faiss/utils/Heap.h +21 -0
  294. data/vendor/faiss/faiss/utils/NeuralNet.cpp +54 -40
  295. data/vendor/faiss/faiss/utils/NeuralNet.h +1 -1
  296. data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +10 -4
  297. data/vendor/faiss/faiss/utils/distances.cpp +507 -559
  298. data/vendor/faiss/faiss/utils/distances.h +118 -1
  299. data/vendor/faiss/faiss/utils/distances_dispatch.h +250 -0
  300. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +8 -7
  301. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +33 -14
  302. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +12 -1
  303. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +16 -293
  304. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based_neon.cpp +57 -0
  305. data/vendor/faiss/faiss/utils/distances_fused/simdlib_kernel-inl.h +290 -0
  306. data/vendor/faiss/faiss/utils/distances_simd.cpp +72 -3681
  307. data/vendor/faiss/faiss/utils/extra_distances.cpp +60 -102
  308. data/vendor/faiss/faiss/utils/extra_distances.h +79 -7
  309. data/vendor/faiss/faiss/utils/hamming-inl.h +13 -11
  310. data/vendor/faiss/faiss/utils/hamming.cpp +66 -517
  311. data/vendor/faiss/faiss/utils/hamming.h +92 -2
  312. data/vendor/faiss/faiss/utils/hamming_distance/common.h +287 -10
  313. data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx2.cpp +15 -0
  314. data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx512.cpp +15 -0
  315. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx2.h +142 -0
  316. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx512.h +234 -0
  317. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-generic.h +368 -0
  318. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-neon.h +322 -0
  319. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-rvv.h +39 -0
  320. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer.h +146 -0
  321. data/vendor/faiss/faiss/utils/hamming_distance/hamming_impl.h +481 -0
  322. data/vendor/faiss/faiss/utils/hamming_distance/hamming_neon.cpp +15 -0
  323. data/vendor/faiss/faiss/utils/hamming_distance/hamming_rvv.cpp +15 -0
  324. data/vendor/faiss/faiss/utils/partitioning.cpp +66 -987
  325. data/vendor/faiss/faiss/utils/partitioning.h +31 -0
  326. data/vendor/faiss/faiss/utils/popcount.h +29 -0
  327. data/vendor/faiss/faiss/utils/pq_code_distance.h +251 -0
  328. data/vendor/faiss/faiss/utils/prefetch.h +2 -2
  329. data/vendor/faiss/faiss/utils/quantize_lut.cpp +30 -30
  330. data/vendor/faiss/faiss/utils/quantize_lut.h +1 -1
  331. data/vendor/faiss/faiss/utils/rabitq_simd.h +124 -343
  332. data/vendor/faiss/faiss/utils/random.cpp +6 -6
  333. data/vendor/faiss/faiss/utils/simd_impl/IVFFlatScanner-inl.h +51 -0
  334. data/vendor/faiss/faiss/utils/simd_impl/distances_aarch64.cpp +154 -0
  335. data/vendor/faiss/faiss/utils/simd_impl/distances_arm_sve.cpp +777 -0
  336. data/vendor/faiss/faiss/utils/simd_impl/distances_autovec-inl.h +306 -0
  337. data/vendor/faiss/faiss/utils/simd_impl/distances_avx2.cpp +1431 -0
  338. data/vendor/faiss/faiss/utils/simd_impl/distances_avx512.cpp +1095 -0
  339. data/vendor/faiss/faiss/utils/simd_impl/distances_rvv.cpp +189 -0
  340. data/vendor/faiss/faiss/utils/simd_impl/distances_simdlib256.h +195 -0
  341. data/vendor/faiss/faiss/utils/simd_impl/distances_sse-inl.h +392 -0
  342. data/vendor/faiss/faiss/utils/{distances_fused/simdlib_based.h → simd_impl/exhaustive_L2sqr_blas_cmax.h} +5 -10
  343. data/vendor/faiss/faiss/utils/simd_impl/hamming_impl.h +481 -0
  344. data/vendor/faiss/faiss/utils/simd_impl/partitioning_avx2.cpp +14 -0
  345. data/vendor/faiss/faiss/utils/simd_impl/partitioning_neon.cpp +14 -0
  346. data/vendor/faiss/faiss/utils/simd_impl/partitioning_simdlib256.h +1085 -0
  347. data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx2.cpp +355 -0
  348. data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx512.cpp +477 -0
  349. data/vendor/faiss/faiss/utils/simd_impl/rabitq_neon.cpp +55 -0
  350. data/vendor/faiss/faiss/utils/simd_impl/rabitq_rvv.cpp +55 -0
  351. data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_dispatch.h +32 -0
  352. data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels.h +43 -0
  353. data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels_avx2.cpp +57 -0
  354. data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels_avx512.cpp +45 -0
  355. data/vendor/faiss/faiss/utils/simd_levels.cpp +334 -0
  356. data/vendor/faiss/faiss/utils/simd_levels.h +183 -0
  357. data/vendor/faiss/faiss/utils/sorting.cpp +48 -36
  358. data/vendor/faiss/faiss/utils/utils.cpp +21 -14
  359. data/vendor/faiss/faiss/utils/utils.h +3 -3
  360. metadata +156 -42
  361. data/vendor/faiss/faiss/impl/RaBitQStats.cpp +0 -29
  362. data/vendor/faiss/faiss/impl/RaBitQStats.h +0 -56
  363. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +0 -81
  364. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +0 -186
  365. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +0 -216
  366. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +0 -224
  367. data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +0 -84
  368. data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +0 -196
  369. data/vendor/faiss/faiss/utils/approx_topk/mode.h +0 -34
  370. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +0 -36
  371. data/vendor/faiss/faiss/utils/extra_distances-inl.h +0 -228
  372. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +0 -462
  373. data/vendor/faiss/faiss/utils/hamming_distance/avx512-inl.h +0 -490
  374. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +0 -450
  375. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +0 -87
  376. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +0 -524
  377. data/vendor/faiss/faiss/utils/simdlib.h +0 -42
  378. data/vendor/faiss/faiss/utils/simdlib_avx512.h +0 -296
  379. /data/vendor/faiss/faiss/{cppcontrib/factory_tools.h → factory_tools.h} +0 -0
@@ -10,8 +10,10 @@
10
10
  #include <faiss/impl/FaissAssert.h>
11
11
  #include <faiss/impl/RaBitQUtils.h>
12
12
  #include <faiss/impl/RaBitQuantizerMultiBit.h>
13
+ #include <faiss/impl/simd_dispatch.h>
13
14
  #include <faiss/utils/distances.h>
14
15
  #include <faiss/utils/rabitq_simd.h>
16
+
15
17
  #include <algorithm>
16
18
  #include <cmath>
17
19
  #include <cstring>
@@ -26,10 +28,13 @@ using rabitq_utils::QueryFactorsData;
26
28
  using rabitq_utils::SignBitFactors;
27
29
  using rabitq_utils::SignBitFactorsWithError;
28
30
 
29
- RaBitQuantizer::RaBitQuantizer(size_t d, MetricType metric, size_t nb_bits)
30
- : Quantizer(d, 0), // code_size will be set below
31
+ RaBitQuantizer::RaBitQuantizer(
32
+ size_t d_in,
33
+ MetricType metric,
34
+ size_t nb_bits_in)
35
+ : Quantizer(d_in, 0), // code_size will be set below
31
36
  metric_type{metric},
32
- nb_bits{nb_bits} {
37
+ nb_bits{nb_bits_in} {
33
38
  // Validate nb_bits range
34
39
  FAISS_THROW_IF_NOT(nb_bits >= 1 && nb_bits <= 9);
35
40
 
@@ -37,7 +42,7 @@ RaBitQuantizer::RaBitQuantizer(size_t d, MetricType metric, size_t nb_bits)
37
42
  code_size = compute_code_size(d, nb_bits);
38
43
  }
39
44
 
40
- size_t RaBitQuantizer::compute_code_size(size_t d, size_t num_bits) const {
45
+ size_t RaBitQuantizer::compute_code_size(size_t d_in, size_t num_bits) const {
41
46
  // Validate inputs
42
47
  FAISS_THROW_IF_NOT(num_bits >= 1 && num_bits <= 9);
43
48
 
@@ -49,7 +54,7 @@ size_t RaBitQuantizer::compute_code_size(size_t d, size_t num_bits) const {
49
54
  // Layout for multi-bit: [binary_code: (d+7)/8
50
55
  // bytes][SignBitFactorsWithError: 12 bytes]
51
56
  // factors = or_minus_c_l2sqr (4) + dp_multiplier (4) + f_error (4)
52
- size_t base_size = (d + 7) / 8 +
57
+ size_t base_size = (d_in + 7) / 8 +
53
58
  (ex_bits == 0 ? sizeof(SignBitFactors)
54
59
  : sizeof(SignBitFactorsWithError));
55
60
 
@@ -57,13 +62,13 @@ size_t RaBitQuantizer::compute_code_size(size_t d, size_t num_bits) const {
57
62
  // Layout: [ex_code: (d*ex_bits+7)/8 bytes][ex_factors: 8 bytes]
58
63
  size_t ex_size = 0;
59
64
  if (ex_bits > 0) {
60
- ex_size = (d * ex_bits + 7) / 8 + sizeof(ExtraBitsFactors);
65
+ ex_size = (d_in * ex_bits + 7) / 8 + sizeof(ExtraBitsFactors);
61
66
  }
62
67
 
63
68
  return base_size + ex_size;
64
69
  }
65
70
 
66
- void RaBitQuantizer::train(size_t n, const float* x) {
71
+ void RaBitQuantizer::train(size_t /*n*/, const float* /*x*/) {
67
72
  // does nothing
68
73
  }
69
74
 
@@ -91,7 +96,7 @@ void RaBitQuantizer::compute_codes_core(
91
96
 
92
97
  // Compute codes
93
98
  #pragma omp parallel for if (n > 1000)
94
- for (int64_t i = 0; i < n; i++) {
99
+ for (int64_t i = 0; i < static_cast<int64_t>(n); i++) {
95
100
  // Pointer to this vector's code
96
101
  uint8_t* code = codes + i * code_size;
97
102
 
@@ -185,7 +190,7 @@ void RaBitQuantizer::decode_core(
185
190
  const size_t ex_bits = nb_bits - 1;
186
191
 
187
192
  #pragma omp parallel for if (n > 1000)
188
- for (int64_t i = 0; i < n; i++) {
193
+ for (int64_t i = 0; i < static_cast<int64_t>(n); i++) {
189
194
  const uint8_t* code = codes + i * code_size;
190
195
 
191
196
  // split the code into parts
@@ -215,183 +220,161 @@ void RaBitQuantizer::decode_core(
215
220
  }
216
221
  }
217
222
 
218
- // Implementation of RaBitQDistanceComputer (declared in header)
219
-
220
- float RaBitQDistanceComputer::lower_bound_distance(const uint8_t* code) {
221
- FAISS_ASSERT(code != nullptr);
222
-
223
- // Compute estimated distance using 1-bit codes
224
- float est_distance = distance_to_code_1bit(code);
225
-
226
- // Extract f_error from the code
227
- size_t size = (d + 7) / 8;
228
- const SignBitFactorsWithError* base_fac =
229
- reinterpret_cast<const SignBitFactorsWithError*>(code + size);
230
- float f_error = base_fac->f_error;
231
-
232
- // Compute proper lower bound using RaBitQ error formula:
233
- // lower_bound = est_distance - f_error * g_error
234
- // This guarantees: lower_bound ≤ true_distance
235
- float lower_bound = est_distance - (f_error * g_error);
236
-
237
- // Distance cannot be negative
238
- return std::max(0.0f, lower_bound);
239
- }
240
-
241
223
  namespace {
242
224
 
225
+ // Distance computers templatized on SIMDLevel to avoid per-call dynamic
226
+ // dispatch. The SIMDLevel is baked in at construction time via
227
+ // get_distance_computer, so virtual calls through the base class go
228
+ // directly to the SIMD-specialized code.
229
+
230
+ template <SIMDLevel SL>
243
231
  struct RaBitQDistanceComputerNotQ : RaBitQDistanceComputer {
244
232
  // the rotated query (qr - c)
245
233
  std::vector<float> rotated_q;
246
234
  // some additional numbers for the query
247
235
  QueryFactorsData query_fac;
248
236
 
249
- RaBitQDistanceComputerNotQ();
237
+ RaBitQDistanceComputerNotQ() = default;
250
238
 
251
239
  // Compute distance using only 1-bit codes (fast)
252
- float distance_to_code_1bit(const uint8_t* code) override;
240
+ float distance_to_code_1bit(const uint8_t* code) override {
241
+ FAISS_ASSERT(code != nullptr);
242
+ FAISS_ASSERT(
243
+ (metric_type == MetricType::METRIC_L2 ||
244
+ metric_type == MetricType::METRIC_INNER_PRODUCT));
245
+ FAISS_ASSERT(rotated_q.size() == d);
253
246
 
254
- // Compute full distance using 1-bit + ex-bits (accurate)
255
- float distance_to_code_full(const uint8_t* code) override;
247
+ // split the code into parts
248
+ const uint8_t* binary_data = code;
256
249
 
257
- void set_query(const float* x) override;
258
- };
250
+ // Cast to appropriate type based on nb_bits
251
+ // For 1-bit: use SignBitFactors (8 bytes)
252
+ // For multi-bit: use SignBitFactorsWithError (12 bytes) which includes
253
+ // f_error
254
+ size_t ex_bits = nb_bits - 1;
255
+ const SignBitFactors* base_fac = (ex_bits == 0)
256
+ ? reinterpret_cast<const SignBitFactors*>(code + (d + 7) / 8)
257
+ : reinterpret_cast<const SignBitFactorsWithError*>(
258
+ code + (d + 7) / 8);
259
259
 
260
- RaBitQDistanceComputerNotQ::RaBitQDistanceComputerNotQ() = default;
260
+ // this is the baseline code
261
+ //
262
+ // compute <q,o> using floats
263
+ float dot_qo = 0;
264
+ // It was a willful decision (after the discussion) to not to pre-cache
265
+ // the sum of all bits, just in order to reduce the overhead per
266
+ // vector.
267
+ uint64_t sum_q = 0;
268
+
269
+ for (size_t i = 0; i < d; i++) {
270
+ // Extract i-th bit
271
+ bool bit = rabitq_utils::extract_bit_standard(binary_data, i);
272
+ // accumulate dp
273
+ dot_qo += bit ? rotated_q[i] : 0;
274
+ // accumulate sum-of-bits
275
+ sum_q += bit ? 1 : 0;
276
+ }
261
277
 
262
- float RaBitQDistanceComputerNotQ::distance_to_code_1bit(const uint8_t* code) {
263
- FAISS_ASSERT(code != nullptr);
264
- FAISS_ASSERT(
265
- (metric_type == MetricType::METRIC_L2 ||
266
- metric_type == MetricType::METRIC_INNER_PRODUCT));
267
- FAISS_ASSERT(rotated_q.size() == d);
268
-
269
- // split the code into parts
270
- const uint8_t* binary_data = code;
271
-
272
- // Cast to appropriate type based on nb_bits
273
- // For 1-bit: use SignBitFactors (8 bytes)
274
- // For multi-bit: use SignBitFactorsWithError (12 bytes) which includes
275
- // f_error
276
- size_t ex_bits = nb_bits - 1;
277
- const SignBitFactors* base_fac = (ex_bits == 0)
278
- ? reinterpret_cast<const SignBitFactors*>(code + (d + 7) / 8)
279
- : reinterpret_cast<const SignBitFactorsWithError*>(
280
- code + (d + 7) / 8);
281
-
282
- // this is the baseline code
283
- //
284
- // compute <q,o> using floats
285
- float dot_qo = 0;
286
- // It was a willful decision (after the discussion) to not to pre-cache
287
- // the sum of all bits, just in order to reduce the overhead per vector.
288
- uint64_t sum_q = 0;
289
-
290
- for (size_t i = 0; i < d; i++) {
291
- // Extract i-th bit
292
- bool bit = rabitq_utils::extract_bit_standard(binary_data, i);
293
- // accumulate dp
294
- dot_qo += bit ? rotated_q[i] : 0;
295
- // accumulate sum-of-bits
296
- sum_q += bit ? 1 : 0;
297
- }
278
+ // Apply query factors
279
+ float final_dot =
280
+ query_fac.c1 * dot_qo + query_fac.c2 * sum_q - query_fac.c34;
298
281
 
299
- // Apply query factors
300
- float final_dot =
301
- query_fac.c1 * dot_qo + query_fac.c2 * sum_q - query_fac.c34;
302
-
303
- // pre_dist = ||or - c||^2 + ||qr - c||^2 -
304
- // 2 * ||or - c|| * ||qr - c|| * <q,o> - (IP ? ||or||^2 : 0)
305
- float pre_dist = base_fac->or_minus_c_l2sqr + query_fac.qr_to_c_L2sqr -
306
- 2 * base_fac->dp_multiplier * final_dot;
307
-
308
- if (metric_type == MetricType::METRIC_L2) {
309
- // ||or - q||^ 2
310
- return pre_dist;
311
- } else {
312
- // metric == MetricType::METRIC_INNER_PRODUCT
313
- return -0.5f * (pre_dist - query_fac.qr_norm_L2sqr);
282
+ // pre_dist = ||or - c||^2 + ||qr - c||^2 -
283
+ // 2 * ||or - c|| * ||qr - c|| * <q,o> - (IP ? ||or||^2 : 0)
284
+ float pre_dist = base_fac->or_minus_c_l2sqr + query_fac.qr_to_c_L2sqr -
285
+ 2 * base_fac->dp_multiplier * final_dot;
286
+
287
+ if (metric_type == MetricType::METRIC_L2) {
288
+ // ||or - q||^ 2
289
+ return pre_dist;
290
+ } else {
291
+ // metric == MetricType::METRIC_INNER_PRODUCT
292
+ // 2 * (or, q) = (||or - q||^2 - ||q||^2 - ||or||^2)
293
+ return -0.5f * (pre_dist - query_fac.qr_norm_L2sqr);
294
+ }
314
295
  }
315
- }
316
296
 
317
- float RaBitQDistanceComputerNotQ::distance_to_code_full(const uint8_t* code) {
318
- FAISS_ASSERT(code != nullptr);
319
- FAISS_ASSERT(
320
- (metric_type == MetricType::METRIC_L2 ||
321
- metric_type == MetricType::METRIC_INNER_PRODUCT));
322
- FAISS_ASSERT(rotated_q.size() == d);
297
+ // Compute full distance using 1-bit + ex-bits (accurate)
298
+ float distance_to_code_full(const uint8_t* code) override {
299
+ FAISS_ASSERT(code != nullptr);
300
+ FAISS_ASSERT(
301
+ (metric_type == MetricType::METRIC_L2 ||
302
+ metric_type == MetricType::METRIC_INNER_PRODUCT));
303
+ FAISS_ASSERT(rotated_q.size() == d);
323
304
 
324
- size_t ex_bits = nb_bits - 1;
305
+ size_t ex_bits = nb_bits - 1;
325
306
 
326
- if (ex_bits == 0) {
327
- // No ex-bits, just return 1-bit distance
328
- return distance_to_code_1bit(code);
329
- }
307
+ if (ex_bits == 0) {
308
+ // No ex-bits, just return 1-bit distance
309
+ return distance_to_code_1bit(code);
310
+ }
330
311
 
331
- // Extract pointers to code sections
332
- const uint8_t* binary_data = code;
333
- size_t offset = (d + 7) / 8 + sizeof(SignBitFactorsWithError);
334
- const uint8_t* ex_code = code + offset;
335
- const ExtraBitsFactors* ex_fac = reinterpret_cast<const ExtraBitsFactors*>(
336
- ex_code + (d * ex_bits + 7) / 8);
337
-
338
- // Call shared utility directly with rotated_q pointer
339
- return rabitq_utils::compute_full_multibit_distance(
340
- binary_data,
341
- ex_code,
342
- *ex_fac,
343
- rotated_q.data(),
344
- query_fac.qr_to_c_L2sqr,
345
- query_fac.qr_norm_L2sqr,
346
- d,
347
- ex_bits,
348
- metric_type);
349
- }
312
+ // Extract pointers to code sections
313
+ const uint8_t* binary_data = code;
314
+ size_t offset = (d + 7) / 8 + sizeof(SignBitFactorsWithError);
315
+ const uint8_t* ex_code = code + offset;
316
+ const ExtraBitsFactors* ex_fac =
317
+ reinterpret_cast<const ExtraBitsFactors*>(
318
+ ex_code + (d * ex_bits + 7) / 8);
319
+
320
+ float qr_base = (metric_type == MetricType::METRIC_INNER_PRODUCT)
321
+ ? query_fac.q_dot_c
322
+ : query_fac.qr_to_c_L2sqr;
323
+ return rabitq_utils::compute_full_multibit_distance<SL>(
324
+ binary_data,
325
+ ex_code,
326
+ *ex_fac,
327
+ rotated_q.data(),
328
+ qr_base,
329
+ d,
330
+ ex_bits,
331
+ metric_type);
332
+ }
350
333
 
351
- void RaBitQDistanceComputerNotQ::set_query(const float* x) {
352
- q = x;
353
- FAISS_ASSERT(x != nullptr);
354
- FAISS_ASSERT(
355
- (metric_type == MetricType::METRIC_L2 ||
356
- metric_type == MetricType::METRIC_INNER_PRODUCT));
334
+ void set_query(const float* x) override {
335
+ q = x;
336
+ FAISS_ASSERT(x != nullptr);
337
+ FAISS_ASSERT(
338
+ (metric_type == MetricType::METRIC_L2 ||
339
+ metric_type == MetricType::METRIC_INNER_PRODUCT));
357
340
 
358
- // compute the distance from the query to the centroid
359
- if (centroid != nullptr) {
360
- query_fac.qr_to_c_L2sqr = fvec_L2sqr(x, centroid, d);
361
- } else {
362
- query_fac.qr_to_c_L2sqr = fvec_norm_L2sqr(x, d);
363
- }
341
+ // compute the distance from the query to the centroid
342
+ if (centroid != nullptr) {
343
+ query_fac.qr_to_c_L2sqr = fvec_L2sqr(x, centroid, d);
344
+ } else {
345
+ query_fac.qr_to_c_L2sqr = fvec_norm_L2sqr(x, d);
346
+ }
364
347
 
365
- // subtract c, obtain P^(-1)(qr - c)
366
- rotated_q.resize(d);
367
- for (size_t i = 0; i < d; i++) {
368
- rotated_q[i] = x[i] - ((centroid == nullptr) ? 0 : centroid[i]);
369
- }
348
+ // subtract c, obtain P^(-1)(qr - c)
349
+ rotated_q.resize(d);
350
+ for (size_t i = 0; i < d; i++) {
351
+ rotated_q[i] = x[i] - ((centroid == nullptr) ? 0 : centroid[i]);
352
+ }
370
353
 
371
- // Compute g_error (query norm for lower bound computation)
372
- // g_error = ||qr - c|| (L2 norm of rotated query)
373
- g_error = std::sqrt(query_fac.qr_to_c_L2sqr);
354
+ // Compute g_error = ||qr - c|| (L2 norm of rotated query)
355
+ g_error = std::sqrt(query_fac.qr_to_c_L2sqr);
374
356
 
375
- // compute some numbers
376
- const float inv_d = (d == 0) ? 1.0f : (1.0f / std::sqrt((float)d));
357
+ // compute some numbers — do not quantize the query
358
+ const float inv_d = (d == 0) ? 1.0f : (1.0f / std::sqrt((float)d));
377
359
 
378
- // do not quantize the query
379
- float sum_q = 0;
380
- for (size_t i = 0; i < d; i++) {
381
- sum_q += rotated_q[i];
382
- }
360
+ float sum_q = 0;
361
+ for (size_t i = 0; i < d; i++) {
362
+ sum_q += rotated_q[i];
363
+ }
383
364
 
384
- query_fac.c1 = 2 * inv_d;
385
- query_fac.c2 = 0;
386
- query_fac.c34 = sum_q * inv_d;
365
+ query_fac.c1 = 2 * inv_d;
366
+ query_fac.c2 = 0;
367
+ query_fac.c34 = sum_q * inv_d;
387
368
 
388
- if (metric_type == MetricType::METRIC_INNER_PRODUCT) {
389
- // precompute if needed
390
- query_fac.qr_norm_L2sqr = fvec_norm_L2sqr(x, d);
369
+ if (metric_type == MetricType::METRIC_INNER_PRODUCT) {
370
+ query_fac.qr_norm_L2sqr = fvec_norm_L2sqr(x, d);
371
+ query_fac.q_dot_c =
372
+ centroid ? fvec_inner_product(x, centroid, d) : 0.0f;
373
+ }
391
374
  }
392
- }
375
+ };
393
376
 
394
- //
377
+ template <SIMDLevel SL>
395
378
  struct RaBitQDistanceComputerQ : RaBitQDistanceComputer {
396
379
  // the rotated and quantized query (qr - c)
397
380
  std::vector<float> rotated_q;
@@ -409,174 +392,188 @@ struct RaBitQDistanceComputerQ : RaBitQDistanceComputer {
409
392
  // the smallest value divisible by 8 that is not smaller than dim
410
393
  size_t popcount_aligned_dim = 0;
411
394
 
412
- RaBitQDistanceComputerQ();
395
+ RaBitQDistanceComputerQ() = default;
413
396
 
414
397
  // Compute distance using only 1-bit codes (fast)
415
- float distance_to_code_1bit(const uint8_t* code) override;
416
-
417
- // Compute full distance using 1-bit + ex-bits (accurate)
418
- float distance_to_code_full(const uint8_t* code) override;
398
+ float distance_to_code_1bit(const uint8_t* code) override {
399
+ FAISS_ASSERT(code != nullptr);
400
+ FAISS_ASSERT(
401
+ (metric_type == MetricType::METRIC_L2 ||
402
+ metric_type == MetricType::METRIC_INNER_PRODUCT));
419
403
 
420
- void set_query(const float* x) override;
421
- };
404
+ // split the code into parts
405
+ size_t size = (d + 7) / 8;
406
+ const uint8_t* binary_data = code;
422
407
 
423
- RaBitQDistanceComputerQ::RaBitQDistanceComputerQ() = default;
408
+ // Cast to appropriate type based on nb_bits
409
+ // For 1-bit: use SignBitFactors (8 bytes)
410
+ // For multi-bit: use SignBitFactorsWithError (12 bytes) which
411
+ // includes f_error
412
+ size_t ex_bits = nb_bits - 1;
413
+ const SignBitFactors* base_fac = (ex_bits == 0)
414
+ ? reinterpret_cast<const SignBitFactors*>(code + size)
415
+ : reinterpret_cast<const SignBitFactorsWithError*>(code + size);
416
+
417
+ // this is ||or - c||^2 - (IP ? ||or||^2 : 0)
418
+ float final_dot = 0;
419
+ if (centered) {
420
+ int64_t int_dot = ((1 << qb) - 1) * d;
421
+ // See RaBitDistanceComputerNotQ::distance_to_code() for
422
+ // baseline code.
423
+ int_dot -= 2 *
424
+ rabitq::bitwise_xor_dot_product<SL>(
425
+ rearranged_rotated_qq.data(),
426
+ binary_data,
427
+ size,
428
+ qb);
429
+ final_dot += int_dot * query_fac.int_dot_scale;
430
+ } else {
431
+ auto dot_qo = rabitq::bitwise_and_dot_product<SL>(
432
+ rearranged_rotated_qq.data(), binary_data, size, qb);
433
+ // It was a willful decision (after the discussion) to not to
434
+ // pre-cache the sum of all bits, just in order to reduce the
435
+ // overhead per vector.
436
+ // process 64-bit popcounts
437
+ auto sum_q = rabitq::popcount<SL>(binary_data, size);
438
+ // dot-product itself
439
+ final_dot += query_fac.c1 * dot_qo;
440
+ // normalizer coefficients
441
+ final_dot += query_fac.c2 * sum_q;
442
+ // normalizer coefficients
443
+ final_dot -= query_fac.c34;
444
+ }
424
445
 
425
- float RaBitQDistanceComputerQ::distance_to_code_1bit(const uint8_t* code) {
426
- FAISS_ASSERT(code != nullptr);
427
- FAISS_ASSERT(
428
- (metric_type == MetricType::METRIC_L2 ||
429
- metric_type == MetricType::METRIC_INNER_PRODUCT));
446
+ const float pre_dist = base_fac->or_minus_c_l2sqr +
447
+ query_fac.qr_to_c_L2sqr -
448
+ 2 * base_fac->dp_multiplier * final_dot;
430
449
 
431
- // split the code into parts
432
- size_t size = (d + 7) / 8;
433
- const uint8_t* binary_data = code;
434
-
435
- // Cast to appropriate type based on nb_bits
436
- // For 1-bit: use SignBitFactors (8 bytes)
437
- // For multi-bit: use SignBitFactorsWithError (12 bytes) which includes
438
- // f_error
439
- size_t ex_bits = nb_bits - 1;
440
- const SignBitFactors* base_fac = (ex_bits == 0)
441
- ? reinterpret_cast<const SignBitFactors*>(code + size)
442
- : reinterpret_cast<const SignBitFactorsWithError*>(code + size);
443
-
444
- // this is ||or - c||^2 - (IP ? ||or||^2 : 0)
445
- float final_dot = 0;
446
- if (centered) {
447
- int64_t int_dot = ((1 << qb) - 1) * d;
448
- // See RaBitDistanceComputerNotQ::distance_to_code() for baseline code.
449
- int_dot -= 2 *
450
- rabitq::bitwise_xor_dot_product(
451
- rearranged_rotated_qq.data(), binary_data, size, qb);
452
- final_dot += int_dot * query_fac.int_dot_scale;
453
- } else {
454
- auto dot_qo = rabitq::bitwise_and_dot_product(
455
- rearranged_rotated_qq.data(), binary_data, size, qb);
456
- // It was a willful decision (after the discussion) to not to pre-cache
457
- // the sum of all bits, just in order to reduce the overhead per vector.
458
- // process 64-bit popcounts
459
- auto sum_q = rabitq::popcount(binary_data, size);
460
- // dot-product itself
461
- final_dot += query_fac.c1 * dot_qo;
462
- // normalizer coefficients
463
- final_dot += query_fac.c2 * sum_q;
464
- // normalizer coefficients
465
- final_dot -= query_fac.c34;
450
+ if (metric_type == MetricType::METRIC_L2) {
451
+ // ||or - q||^ 2
452
+ return pre_dist;
453
+ } else {
454
+ // metric == MetricType::METRIC_INNER_PRODUCT
455
+ // 2 * (or, q) = (||or - q||^2 - ||q||^2 - ||or||^2)
456
+ return -0.5f * (pre_dist - query_fac.qr_norm_L2sqr);
457
+ }
466
458
  }
467
459
 
468
- // pre_dist = ||or - c||^2 + ||qr - c||^2 -
469
- // 2 * ||or - c|| * ||qr - c|| * <q,o> - (IP ? ||or||^2 : 0)
470
- const float pre_dist = base_fac->or_minus_c_l2sqr +
471
- query_fac.qr_to_c_L2sqr - 2 * base_fac->dp_multiplier * final_dot;
472
-
473
- if (metric_type == MetricType::METRIC_L2) {
474
- // ||or - q||^ 2
475
- return pre_dist;
476
- } else {
477
- // metric == MetricType::METRIC_INNER_PRODUCT
478
- // 2 * (or, q) = (||or - q||^2 - ||q||^2 - ||or||^2)
479
- return -0.5f * (pre_dist - query_fac.qr_norm_L2sqr);
480
- }
481
- }
460
+ // Compute full distance using 1-bit + ex-bits (accurate)
461
+ float distance_to_code_full(const uint8_t* code) override {
462
+ FAISS_ASSERT(code != nullptr);
463
+ FAISS_ASSERT(
464
+ (metric_type == MetricType::METRIC_L2 ||
465
+ metric_type == MetricType::METRIC_INNER_PRODUCT));
466
+ FAISS_ASSERT(rotated_q.size() == d);
482
467
 
483
- float RaBitQDistanceComputerQ::distance_to_code_full(const uint8_t* code) {
484
- FAISS_ASSERT(code != nullptr);
485
- FAISS_ASSERT(
486
- (metric_type == MetricType::METRIC_L2 ||
487
- metric_type == MetricType::METRIC_INNER_PRODUCT));
488
- FAISS_ASSERT(rotated_q.size() == d);
468
+ size_t ex_bits = nb_bits - 1;
489
469
 
490
- size_t ex_bits = nb_bits - 1;
470
+ if (ex_bits == 0) {
471
+ // No ex-bits, just return 1-bit distance
472
+ return distance_to_code_1bit(code);
473
+ }
491
474
 
492
- if (ex_bits == 0) {
493
- // No ex-bits, just return 1-bit distance
494
- return distance_to_code_1bit(code);
475
+ // Extract pointers to code sections
476
+ const uint8_t* binary_data = code;
477
+ size_t offset = (d + 7) / 8 + sizeof(SignBitFactorsWithError);
478
+ const uint8_t* ex_code = code + offset;
479
+ const ExtraBitsFactors* ex_fac =
480
+ reinterpret_cast<const ExtraBitsFactors*>(
481
+ ex_code + (d * ex_bits + 7) / 8);
482
+
483
+ float qr_base = (metric_type == MetricType::METRIC_INNER_PRODUCT)
484
+ ? query_fac.q_dot_c
485
+ : query_fac.qr_to_c_L2sqr;
486
+ return rabitq_utils::compute_full_multibit_distance<SL>(
487
+ binary_data,
488
+ ex_code,
489
+ *ex_fac,
490
+ rotated_q.data(),
491
+ qr_base,
492
+ d,
493
+ ex_bits,
494
+ metric_type);
495
495
  }
496
496
 
497
- // Extract pointers to code sections
498
- const uint8_t* binary_data = code;
499
- size_t offset = (d + 7) / 8 + sizeof(SignBitFactorsWithError);
500
- const uint8_t* ex_code = code + offset;
501
- const ExtraBitsFactors* ex_fac = reinterpret_cast<const ExtraBitsFactors*>(
502
- ex_code + (d * ex_bits + 7) / 8);
503
-
504
- // Call shared utility directly with rotated_q pointer
505
- return rabitq_utils::compute_full_multibit_distance(
506
- binary_data,
507
- ex_code,
508
- *ex_fac,
509
- rotated_q.data(),
510
- query_fac.qr_to_c_L2sqr,
511
- query_fac.qr_norm_L2sqr,
512
- d,
513
- ex_bits,
514
- metric_type);
515
- }
497
+ void set_query(const float* x) override {
498
+ q = x;
499
+ FAISS_ASSERT(x != nullptr);
500
+ FAISS_ASSERT(
501
+ (metric_type == MetricType::METRIC_L2 ||
502
+ metric_type == MetricType::METRIC_INNER_PRODUCT));
503
+ FAISS_THROW_IF_NOT(qb <= 8);
504
+ FAISS_THROW_IF_NOT(qb > 0);
505
+
506
+ // Use shared utilities for core query factor computation
507
+ // rotated_q is populated directly by compute_query_factors as an
508
+ // output parameter
509
+ query_fac = rabitq_utils::compute_query_factors(
510
+ x,
511
+ d,
512
+ centroid,
513
+ qb,
514
+ centered,
515
+ metric_type,
516
+ rotated_q,
517
+ rotated_qq);
518
+
519
+ // Compute g_error (query norm for lower bound computation)
520
+ // g_error = ||qr - c|| (L2 norm of rotated query)
521
+ g_error = std::sqrt(query_fac.qr_to_c_L2sqr);
522
+
523
+ // Rearrange the query vector for SIMD operations
524
+ // (RaBitQuantizer-specific)
525
+ popcount_aligned_dim = ((d + 7) / 8) * 8;
526
+ size_t offset = (d + 7) / 8;
527
+
528
+ rearranged_rotated_qq.resize(offset * qb);
529
+ std::fill(
530
+ rearranged_rotated_qq.begin(), rearranged_rotated_qq.end(), 0);
531
+
532
+ for (size_t idim = 0; idim < d; idim++) {
533
+ for (size_t iv = 0; iv < qb; iv++) {
534
+ const bool bit = ((rotated_qq[idim] & (1 << iv)) != 0);
535
+ rearranged_rotated_qq[iv * offset + idim / 8] |=
536
+ bit ? (1 << (idim % 8)) : 0;
537
+ }
538
+ }
539
+ }
540
+ };
516
541
 
517
542
  // Use shared constant from RaBitQUtils
518
543
  using rabitq_utils::Z_MAX_BY_QB;
519
544
 
520
- void RaBitQDistanceComputerQ::set_query(const float* x) {
521
- q = x;
522
- FAISS_ASSERT(x != nullptr);
523
- FAISS_ASSERT(
524
- (metric_type == MetricType::METRIC_L2 ||
525
- metric_type == MetricType::METRIC_INNER_PRODUCT));
526
- FAISS_THROW_IF_NOT(qb <= 8);
527
- FAISS_THROW_IF_NOT(qb > 0);
528
-
529
- // Use shared utilities for core query factor computation
530
- // rotated_q is populated directly by compute_query_factors as an output
531
- // parameter
532
- query_fac = rabitq_utils::compute_query_factors(
533
- x, d, centroid, qb, centered, metric_type, rotated_q, rotated_qq);
534
-
535
- // Compute g_error (query norm for lower bound computation)
536
- // g_error = ||qr - c|| (L2 norm of rotated query)
537
- g_error = std::sqrt(query_fac.qr_to_c_L2sqr);
538
-
539
- // Rearrange the query vector for SIMD operations (RaBitQuantizer-specific)
540
- popcount_aligned_dim = ((d + 7) / 8) * 8;
541
- size_t offset = (d + 7) / 8;
542
-
543
- rearranged_rotated_qq.resize(offset * qb);
544
- std::fill(rearranged_rotated_qq.begin(), rearranged_rotated_qq.end(), 0);
545
-
546
- for (size_t idim = 0; idim < d; idim++) {
547
- for (size_t iv = 0; iv < qb; iv++) {
548
- const bool bit = ((rotated_qq[idim] & (1 << iv)) != 0);
549
- rearranged_rotated_qq[iv * offset + idim / 8] |=
550
- bit ? (1 << (idim % 8)) : 0;
551
- }
552
- }
553
- }
554
-
555
545
  } // anonymous namespace
556
546
 
557
547
  FlatCodesDistanceComputer* RaBitQuantizer::get_distance_computer(
558
548
  uint8_t qb,
559
549
  const float* centroid_in,
560
550
  bool centered) const {
561
- if (qb == 0) {
562
- auto dc = std::make_unique<RaBitQDistanceComputerNotQ>();
563
- dc->metric_type = metric_type;
564
- dc->d = d;
565
- dc->centroid = centroid_in;
566
- dc->nb_bits = nb_bits;
567
-
568
- return dc.release();
569
- } else {
570
- auto dc = std::make_unique<RaBitQDistanceComputerQ>();
571
- dc->metric_type = metric_type;
572
- dc->d = d;
573
- dc->centroid = centroid_in;
574
- dc->qb = qb;
575
- dc->centered = centered;
576
- dc->nb_bits = nb_bits;
577
-
578
- return dc.release();
579
- }
551
+ // Dispatch on SIMDLevel once here so the distance computer methods
552
+ // call the SIMD-specialized rabitq functions directly (no per-call
553
+ // with_simd_level overhead).
554
+ return with_selected_simd_levels<AVAILABLE_SIMD_LEVELS_A0>(
555
+ [&]<SIMDLevel SL>() -> FlatCodesDistanceComputer* {
556
+ if (qb == 0) {
557
+ auto dc =
558
+ std::make_unique<RaBitQDistanceComputerNotQ<SL>>();
559
+ dc->metric_type = metric_type;
560
+ dc->d = d;
561
+ dc->centroid = centroid_in;
562
+ dc->nb_bits = nb_bits;
563
+
564
+ return dc.release();
565
+ } else {
566
+ auto dc = std::make_unique<RaBitQDistanceComputerQ<SL>>();
567
+ dc->metric_type = metric_type;
568
+ dc->d = d;
569
+ dc->centroid = centroid_in;
570
+ dc->qb = qb;
571
+ dc->centered = centered;
572
+ dc->nb_bits = nb_bits;
573
+
574
+ return dc.release();
575
+ }
576
+ });
580
577
  }
581
578
 
582
579
  } // namespace faiss