faiss 0.5.3 → 0.6.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (379)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +12 -0
  3. data/ext/faiss/ext.cpp +1 -1
  4. data/ext/faiss/extconf.rb +4 -4
  5. data/ext/faiss/index.cpp +63 -45
  6. data/ext/faiss/index_binary.cpp +37 -27
  7. data/ext/faiss/kmeans.cpp +9 -8
  8. data/ext/faiss/pca_matrix.cpp +9 -7
  9. data/ext/faiss/product_quantizer.cpp +13 -11
  10. data/ext/faiss/utils.cpp +4 -2
  11. data/ext/faiss/utils.h +4 -0
  12. data/lib/faiss/version.rb +1 -1
  13. data/lib/faiss.rb +1 -1
  14. data/vendor/faiss/faiss/AutoTune.cpp +214 -82
  15. data/vendor/faiss/faiss/AutoTune.h +14 -1
  16. data/vendor/faiss/faiss/Clustering.cpp +97 -249
  17. data/vendor/faiss/faiss/Clustering.h +18 -0
  18. data/vendor/faiss/faiss/IVFlib.cpp +67 -44
  19. data/vendor/faiss/faiss/Index.cpp +25 -12
  20. data/vendor/faiss/faiss/Index.h +26 -4
  21. data/vendor/faiss/faiss/Index2Layer.cpp +37 -53
  22. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +68 -61
  23. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +36 -34
  24. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +4 -1
  25. data/vendor/faiss/faiss/IndexBinary.cpp +6 -3
  26. data/vendor/faiss/faiss/IndexBinary.h +4 -4
  27. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +1 -1
  28. data/vendor/faiss/faiss/IndexBinaryFlat.h +1 -1
  29. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -4
  30. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +92 -95
  31. data/vendor/faiss/faiss/IndexBinaryHNSW.h +9 -3
  32. data/vendor/faiss/faiss/IndexBinaryHash.cpp +45 -236
  33. data/vendor/faiss/faiss/IndexBinaryHash.h +6 -6
  34. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +120 -414
  35. data/vendor/faiss/faiss/IndexFastScan.cpp +105 -129
  36. data/vendor/faiss/faiss/IndexFastScan.h +35 -24
  37. data/vendor/faiss/faiss/IndexFlat.cpp +216 -152
  38. data/vendor/faiss/faiss/IndexFlat.h +32 -14
  39. data/vendor/faiss/faiss/IndexFlatCodes.cpp +88 -41
  40. data/vendor/faiss/faiss/IndexFlatCodes.h +7 -1
  41. data/vendor/faiss/faiss/IndexHNSW.cpp +299 -187
  42. data/vendor/faiss/faiss/IndexHNSW.h +30 -14
  43. data/vendor/faiss/faiss/IndexIDMap.cpp +26 -22
  44. data/vendor/faiss/faiss/IndexIDMap.h +9 -7
  45. data/vendor/faiss/faiss/IndexIVF.cpp +535 -405
  46. data/vendor/faiss/faiss/IndexIVF.h +47 -16
  47. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +77 -74
  48. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +105 -99
  49. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +6 -3
  50. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +379 -249
  51. data/vendor/faiss/faiss/IndexIVFFastScan.h +65 -60
  52. data/vendor/faiss/faiss/IndexIVFFlat.cpp +41 -124
  53. data/vendor/faiss/faiss/IndexIVFFlat.h +32 -0
  54. data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +89 -138
  55. data/vendor/faiss/faiss/IndexIVFFlatPanorama.h +3 -1
  56. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +18 -15
  57. data/vendor/faiss/faiss/IndexIVFPQ.cpp +77 -907
  58. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +184 -122
  59. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +3 -0
  60. data/vendor/faiss/faiss/IndexIVFPQR.cpp +23 -18
  61. data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +59 -60
  62. data/vendor/faiss/faiss/IndexIVFRaBitQ.h +4 -3
  63. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +564 -416
  64. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +269 -111
  65. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +41 -127
  66. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +1 -1
  67. data/vendor/faiss/faiss/IndexLSH.cpp +44 -25
  68. data/vendor/faiss/faiss/IndexLattice.cpp +41 -36
  69. data/vendor/faiss/faiss/IndexNNDescent.cpp +37 -21
  70. data/vendor/faiss/faiss/IndexNNDescent.h +2 -2
  71. data/vendor/faiss/faiss/IndexNSG.cpp +40 -23
  72. data/vendor/faiss/faiss/IndexNSG.h +0 -2
  73. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +32 -12
  74. data/vendor/faiss/faiss/IndexPQ.cpp +129 -213
  75. data/vendor/faiss/faiss/IndexPQ.h +3 -2
  76. data/vendor/faiss/faiss/IndexPQFastScan.cpp +20 -14
  77. data/vendor/faiss/faiss/IndexPQFastScan.h +3 -0
  78. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -18
  79. data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
  80. data/vendor/faiss/faiss/IndexRaBitQ.cpp +31 -43
  81. data/vendor/faiss/faiss/IndexRaBitQ.h +4 -3
  82. data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +135 -317
  83. data/vendor/faiss/faiss/IndexRaBitQFastScan.h +192 -34
  84. data/vendor/faiss/faiss/IndexRefine.cpp +30 -55
  85. data/vendor/faiss/faiss/IndexRefine.h +4 -4
  86. data/vendor/faiss/faiss/IndexReplicas.cpp +6 -6
  87. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +15 -14
  88. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +1 -1
  89. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +82 -14
  90. data/vendor/faiss/faiss/IndexShards.cpp +13 -13
  91. data/vendor/faiss/faiss/IndexShardsIVF.cpp +21 -15
  92. data/vendor/faiss/faiss/MatrixStats.cpp +5 -4
  93. data/vendor/faiss/faiss/MetaIndexes.cpp +19 -17
  94. data/vendor/faiss/faiss/MetaIndexes.h +1 -1
  95. data/vendor/faiss/faiss/MetricType.h +29 -6
  96. data/vendor/faiss/faiss/SuperKMeans.cpp +656 -0
  97. data/vendor/faiss/faiss/SuperKMeans.h +97 -0
  98. data/vendor/faiss/faiss/VectorTransform.cpp +349 -141
  99. data/vendor/faiss/faiss/VectorTransform.h +39 -16
  100. data/vendor/faiss/faiss/build.cpp +23 -0
  101. data/vendor/faiss/faiss/build.h +15 -0
  102. data/vendor/faiss/faiss/clone_index.cpp +55 -51
  103. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +47 -47
  104. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +11 -0
  105. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +38 -38
  106. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +11 -0
  107. data/vendor/faiss/faiss/{cppcontrib/factory_tools.cpp → factory_tools.cpp} +6 -1
  108. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +1 -1
  109. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +6 -5
  110. data/vendor/faiss/faiss/gpu/GpuResources.h +1 -1
  111. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +9 -9
  112. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +4 -3
  113. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +46 -0
  114. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +56 -0
  115. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +78 -1
  116. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +72 -0
  117. data/vendor/faiss/faiss/gpu/test/TestUtils.h +23 -0
  118. data/vendor/faiss/faiss/gpu/utils/CuvsFilterConvert.h +1 -1
  119. data/vendor/faiss/faiss/gpu/utils/CuvsUtils.h +21 -10
  120. data/vendor/faiss/faiss/gpu_metal/GpuIndexFlat.h +22 -0
  121. data/vendor/faiss/faiss/gpu_metal/MetalCloner.h +35 -0
  122. data/vendor/faiss/faiss/gpu_metal/MetalFlatKernels.h +40 -0
  123. data/vendor/faiss/faiss/gpu_metal/MetalIndex.h +51 -0
  124. data/vendor/faiss/faiss/gpu_metal/MetalIndexFlat.h +65 -0
  125. data/vendor/faiss/faiss/gpu_metal/MetalKernels.h +66 -0
  126. data/vendor/faiss/faiss/gpu_metal/MetalResources.h +79 -0
  127. data/vendor/faiss/faiss/gpu_metal/StandardMetalResources.h +35 -0
  128. data/vendor/faiss/faiss/impl/AdSampling.cpp +103 -0
  129. data/vendor/faiss/faiss/impl/AdSampling.h +35 -0
  130. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +64 -34
  131. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +1 -0
  132. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +10 -9
  133. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +3 -28
  134. data/vendor/faiss/faiss/impl/ClusteringHelpers.cpp +244 -0
  135. data/vendor/faiss/faiss/impl/ClusteringHelpers.h +94 -0
  136. data/vendor/faiss/faiss/impl/ClusteringInitialization.cpp +367 -0
  137. data/vendor/faiss/faiss/impl/ClusteringInitialization.h +107 -0
  138. data/vendor/faiss/faiss/impl/CodePacker.cpp +7 -3
  139. data/vendor/faiss/faiss/impl/CodePacker.h +11 -3
  140. data/vendor/faiss/faiss/impl/CodePackerRaBitQ.cpp +83 -0
  141. data/vendor/faiss/faiss/impl/CodePackerRaBitQ.h +47 -0
  142. data/vendor/faiss/faiss/impl/DistanceComputer.h +8 -8
  143. data/vendor/faiss/faiss/impl/FaissAssert.h +64 -3
  144. data/vendor/faiss/faiss/impl/FaissException.h +50 -3
  145. data/vendor/faiss/faiss/impl/HNSW.cpp +117 -351
  146. data/vendor/faiss/faiss/impl/HNSW.h +21 -40
  147. data/vendor/faiss/faiss/impl/IDSelector.cpp +15 -11
  148. data/vendor/faiss/faiss/impl/IDSelector.h +8 -8
  149. data/vendor/faiss/faiss/impl/InvertedListScannerStats.h +26 -0
  150. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +114 -102
  151. data/vendor/faiss/faiss/impl/NNDescent.cpp +63 -26
  152. data/vendor/faiss/faiss/impl/NNDescent.h +6 -2
  153. data/vendor/faiss/faiss/impl/NSG.cpp +44 -26
  154. data/vendor/faiss/faiss/impl/NSG.h +20 -10
  155. data/vendor/faiss/faiss/impl/Panorama.cpp +76 -52
  156. data/vendor/faiss/faiss/impl/Panorama.h +265 -78
  157. data/vendor/faiss/faiss/impl/PdxLayout.cpp +93 -0
  158. data/vendor/faiss/faiss/impl/PdxLayout.h +41 -0
  159. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +62 -37
  160. data/vendor/faiss/faiss/impl/PolysemousTraining.h +3 -3
  161. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +35 -35
  162. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +21 -16
  163. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +99 -80
  164. data/vendor/faiss/faiss/impl/Quantizer.h +2 -2
  165. data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +135 -37
  166. data/vendor/faiss/faiss/impl/RaBitQUtils.h +148 -21
  167. data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +298 -301
  168. data/vendor/faiss/faiss/impl/RaBitQuantizer.h +3 -10
  169. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.cpp +15 -41
  170. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.h +0 -4
  171. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +40 -32
  172. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +1 -1
  173. data/vendor/faiss/faiss/impl/ResultHandler.h +218 -113
  174. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +119 -2362
  175. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +27 -3
  176. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +14 -11
  177. data/vendor/faiss/faiss/impl/VisitedTable.cpp +42 -0
  178. data/vendor/faiss/faiss/impl/VisitedTable.h +76 -0
  179. data/vendor/faiss/faiss/impl/approx_topk/approx_topk.h +276 -0
  180. data/vendor/faiss/faiss/impl/approx_topk/avx2.cpp +68 -0
  181. data/vendor/faiss/faiss/{utils → impl}/approx_topk/generic.h +15 -8
  182. data/vendor/faiss/faiss/impl/approx_topk/neon.cpp +68 -0
  183. data/vendor/faiss/faiss/impl/approx_topk/rq_beam_search_tab-inl.h +169 -0
  184. data/vendor/faiss/faiss/impl/approx_topk/rq_beam_search_tab.h +117 -0
  185. data/vendor/faiss/faiss/impl/approx_topk/simdlib256-inl.h +146 -0
  186. data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryHNSW_impl.h +73 -0
  187. data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryHash_impl.h +270 -0
  188. data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryIVF_impl.h +460 -0
  189. data/vendor/faiss/faiss/impl/binary_hamming/IndexIVFSpectralHash_impl.h +159 -0
  190. data/vendor/faiss/faiss/impl/binary_hamming/IndexPQ_impl.h +92 -0
  191. data/vendor/faiss/faiss/impl/binary_hamming/avx2.cpp +26 -0
  192. data/vendor/faiss/faiss/impl/binary_hamming/avx512.cpp +26 -0
  193. data/vendor/faiss/faiss/impl/binary_hamming/dispatch.h +143 -0
  194. data/vendor/faiss/faiss/impl/binary_hamming/neon.cpp +26 -0
  195. data/vendor/faiss/faiss/impl/binary_hamming/rvv.cpp +26 -0
  196. data/vendor/faiss/faiss/impl/expanded_scanners.h +163 -0
  197. data/vendor/faiss/faiss/impl/{FastScanDistancePostProcessing.h → fast_scan/FastScanDistancePostProcessing.h} +13 -6
  198. data/vendor/faiss/faiss/impl/{LookupTableScaler.h → fast_scan/LookupTableScaler.h} +16 -5
  199. data/vendor/faiss/faiss/impl/fast_scan/accumulate_loops.h +237 -0
  200. data/vendor/faiss/faiss/impl/fast_scan/accumulate_loops_512.h +185 -0
  201. data/vendor/faiss/faiss/impl/fast_scan/decompose_qbs.h +229 -0
  202. data/vendor/faiss/faiss/impl/fast_scan/dispatching.h +268 -0
  203. data/vendor/faiss/faiss/impl/{pq4_fast_scan.cpp → fast_scan/fast_scan.cpp} +176 -4
  204. data/vendor/faiss/faiss/impl/fast_scan/fast_scan.h +341 -0
  205. data/vendor/faiss/faiss/impl/fast_scan/impl-avx2.cpp +36 -0
  206. data/vendor/faiss/faiss/impl/fast_scan/impl-avx512.cpp +40 -0
  207. data/vendor/faiss/faiss/impl/fast_scan/impl-neon.cpp +120 -0
  208. data/vendor/faiss/faiss/impl/fast_scan/impl-riscv.cpp +104 -0
  209. data/vendor/faiss/faiss/impl/fast_scan/kernels_simd256.h +213 -0
  210. data/vendor/faiss/faiss/impl/{pq4_fast_scan_search_qbs.cpp → fast_scan/kernels_simd512.h} +26 -348
  211. data/vendor/faiss/faiss/impl/fast_scan/rabitq_dispatching.h +90 -0
  212. data/vendor/faiss/faiss/impl/fast_scan/rabitq_result_handler.h +108 -0
  213. data/vendor/faiss/faiss/impl/{simd_result_handlers.h → fast_scan/simd_result_handlers.h} +290 -142
  214. data/vendor/faiss/faiss/impl/hnsw/LockVector.cpp +54 -0
  215. data/vendor/faiss/faiss/impl/hnsw/LockVector.h +64 -0
  216. data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.cpp +91 -0
  217. data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.h +64 -0
  218. data/vendor/faiss/faiss/impl/hnsw/avx2.cpp +104 -0
  219. data/vendor/faiss/faiss/impl/hnsw/avx512.cpp +111 -0
  220. data/vendor/faiss/faiss/impl/index_read.cpp +1950 -505
  221. data/vendor/faiss/faiss/impl/index_read_utils.h +1 -2
  222. data/vendor/faiss/faiss/impl/index_write.cpp +112 -21
  223. data/vendor/faiss/faiss/impl/io.cpp +6 -6
  224. data/vendor/faiss/faiss/impl/io_macros.h +33 -16
  225. data/vendor/faiss/faiss/impl/kmeans1d.cpp +10 -10
  226. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +81 -40
  227. data/vendor/faiss/faiss/impl/lattice_Zn.h +6 -6
  228. data/vendor/faiss/faiss/impl/mapped_io.cpp +15 -8
  229. data/vendor/faiss/faiss/impl/platform_macros.h +11 -4
  230. data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQScanner_impl.h +549 -0
  231. data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQ_QueryTables.cpp +245 -0
  232. data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQ_QueryTables.h +105 -0
  233. data/vendor/faiss/faiss/impl/pq_code_distance/PQDistanceComputer_impl.h +106 -0
  234. data/vendor/faiss/faiss/impl/pq_code_distance/avx2.cpp +21 -0
  235. data/vendor/faiss/faiss/impl/pq_code_distance/avx512.cpp +21 -0
  236. data/vendor/faiss/faiss/impl/pq_code_distance/neon.cpp +21 -0
  237. data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx2.h → pq_code_distance/pq_code_distance-avx2.h} +43 -220
  238. data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx512.h → pq_code_distance/pq_code_distance-avx512.h} +25 -112
  239. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.cpp +59 -0
  240. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.h +96 -0
  241. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-inl.h +256 -0
  242. data/vendor/faiss/faiss/impl/{code_distance/code_distance-sve.h → pq_code_distance/pq_code_distance-sve.cpp} +57 -146
  243. data/vendor/faiss/faiss/impl/pq_code_distance/rvv.cpp +68 -0
  244. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +320 -483
  245. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +1 -1
  246. data/vendor/faiss/faiss/impl/scalar_quantizer/codecs.h +121 -0
  247. data/vendor/faiss/faiss/impl/scalar_quantizer/distance_computers.h +137 -0
  248. data/vendor/faiss/faiss/impl/scalar_quantizer/quantizers.h +371 -0
  249. data/vendor/faiss/faiss/impl/scalar_quantizer/scanners.h +190 -0
  250. data/vendor/faiss/faiss/impl/scalar_quantizer/similarities.h +94 -0
  251. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx2.cpp +603 -0
  252. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512.cpp +597 -0
  253. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-dispatch.h +388 -0
  254. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-neon.cpp +630 -0
  255. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-rvv.cpp +311 -0
  256. data/vendor/faiss/faiss/impl/scalar_quantizer/training.cpp +387 -0
  257. data/vendor/faiss/faiss/impl/scalar_quantizer/training.h +54 -0
  258. data/vendor/faiss/faiss/impl/simd_dispatch.h +173 -0
  259. data/vendor/faiss/faiss/impl/simdlib/simdlib.h +57 -0
  260. data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_avx2.h +274 -171
  261. data/vendor/faiss/faiss/impl/simdlib/simdlib_avx512.h +414 -0
  262. data/vendor/faiss/faiss/impl/simdlib/simdlib_dispatch.h +44 -0
  263. data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_emulated.h +231 -166
  264. data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_neon.h +275 -217
  265. data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_ppc64.h +201 -160
  266. data/vendor/faiss/faiss/impl/svs_io.cpp +12 -3
  267. data/vendor/faiss/faiss/impl/svs_io.h +8 -2
  268. data/vendor/faiss/faiss/index_factory.cpp +115 -28
  269. data/vendor/faiss/faiss/index_io.h +53 -3
  270. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +73 -20
  271. data/vendor/faiss/faiss/invlists/DirectMap.cpp +24 -14
  272. data/vendor/faiss/faiss/invlists/DirectMap.h +4 -3
  273. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +157 -73
  274. data/vendor/faiss/faiss/invlists/InvertedLists.h +86 -23
  275. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +4 -4
  276. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +14 -14
  277. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
  278. data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +9 -19
  279. data/vendor/faiss/faiss/svs/IndexSVSFlat.cpp +2 -2
  280. data/vendor/faiss/faiss/svs/IndexSVSFlat.h +2 -0
  281. data/vendor/faiss/faiss/svs/IndexSVSIVF.cpp +350 -0
  282. data/vendor/faiss/faiss/svs/IndexSVSIVF.h +128 -0
  283. data/vendor/faiss/faiss/svs/IndexSVSIVFLVQ.cpp +40 -0
  284. data/vendor/faiss/faiss/svs/IndexSVSIVFLVQ.h +43 -0
  285. data/vendor/faiss/faiss/svs/IndexSVSIVFLeanVec.cpp +225 -0
  286. data/vendor/faiss/faiss/svs/IndexSVSIVFLeanVec.h +71 -0
  287. data/vendor/faiss/faiss/svs/IndexSVSVamana.cpp +25 -1
  288. data/vendor/faiss/faiss/svs/IndexSVSVamana.h +19 -2
  289. data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.h +1 -1
  290. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +19 -2
  291. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +14 -0
  292. data/vendor/faiss/faiss/utils/Heap.cpp +56 -10
  293. data/vendor/faiss/faiss/utils/Heap.h +21 -0
  294. data/vendor/faiss/faiss/utils/NeuralNet.cpp +54 -40
  295. data/vendor/faiss/faiss/utils/NeuralNet.h +1 -1
  296. data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +10 -4
  297. data/vendor/faiss/faiss/utils/distances.cpp +507 -559
  298. data/vendor/faiss/faiss/utils/distances.h +118 -1
  299. data/vendor/faiss/faiss/utils/distances_dispatch.h +250 -0
  300. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +8 -7
  301. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +33 -14
  302. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +12 -1
  303. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +16 -293
  304. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based_neon.cpp +57 -0
  305. data/vendor/faiss/faiss/utils/distances_fused/simdlib_kernel-inl.h +290 -0
  306. data/vendor/faiss/faiss/utils/distances_simd.cpp +72 -3681
  307. data/vendor/faiss/faiss/utils/extra_distances.cpp +60 -102
  308. data/vendor/faiss/faiss/utils/extra_distances.h +79 -7
  309. data/vendor/faiss/faiss/utils/hamming-inl.h +13 -11
  310. data/vendor/faiss/faiss/utils/hamming.cpp +66 -517
  311. data/vendor/faiss/faiss/utils/hamming.h +92 -2
  312. data/vendor/faiss/faiss/utils/hamming_distance/common.h +287 -10
  313. data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx2.cpp +15 -0
  314. data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx512.cpp +15 -0
  315. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx2.h +142 -0
  316. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx512.h +234 -0
  317. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-generic.h +368 -0
  318. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-neon.h +322 -0
  319. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-rvv.h +39 -0
  320. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer.h +146 -0
  321. data/vendor/faiss/faiss/utils/hamming_distance/hamming_impl.h +481 -0
  322. data/vendor/faiss/faiss/utils/hamming_distance/hamming_neon.cpp +15 -0
  323. data/vendor/faiss/faiss/utils/hamming_distance/hamming_rvv.cpp +15 -0
  324. data/vendor/faiss/faiss/utils/partitioning.cpp +66 -987
  325. data/vendor/faiss/faiss/utils/partitioning.h +31 -0
  326. data/vendor/faiss/faiss/utils/popcount.h +29 -0
  327. data/vendor/faiss/faiss/utils/pq_code_distance.h +251 -0
  328. data/vendor/faiss/faiss/utils/prefetch.h +2 -2
  329. data/vendor/faiss/faiss/utils/quantize_lut.cpp +30 -30
  330. data/vendor/faiss/faiss/utils/quantize_lut.h +1 -1
  331. data/vendor/faiss/faiss/utils/rabitq_simd.h +124 -343
  332. data/vendor/faiss/faiss/utils/random.cpp +6 -6
  333. data/vendor/faiss/faiss/utils/simd_impl/IVFFlatScanner-inl.h +51 -0
  334. data/vendor/faiss/faiss/utils/simd_impl/distances_aarch64.cpp +154 -0
  335. data/vendor/faiss/faiss/utils/simd_impl/distances_arm_sve.cpp +777 -0
  336. data/vendor/faiss/faiss/utils/simd_impl/distances_autovec-inl.h +306 -0
  337. data/vendor/faiss/faiss/utils/simd_impl/distances_avx2.cpp +1431 -0
  338. data/vendor/faiss/faiss/utils/simd_impl/distances_avx512.cpp +1095 -0
  339. data/vendor/faiss/faiss/utils/simd_impl/distances_rvv.cpp +189 -0
  340. data/vendor/faiss/faiss/utils/simd_impl/distances_simdlib256.h +195 -0
  341. data/vendor/faiss/faiss/utils/simd_impl/distances_sse-inl.h +392 -0
  342. data/vendor/faiss/faiss/utils/{distances_fused/simdlib_based.h → simd_impl/exhaustive_L2sqr_blas_cmax.h} +5 -10
  343. data/vendor/faiss/faiss/utils/simd_impl/hamming_impl.h +481 -0
  344. data/vendor/faiss/faiss/utils/simd_impl/partitioning_avx2.cpp +14 -0
  345. data/vendor/faiss/faiss/utils/simd_impl/partitioning_neon.cpp +14 -0
  346. data/vendor/faiss/faiss/utils/simd_impl/partitioning_simdlib256.h +1085 -0
  347. data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx2.cpp +355 -0
  348. data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx512.cpp +477 -0
  349. data/vendor/faiss/faiss/utils/simd_impl/rabitq_neon.cpp +55 -0
  350. data/vendor/faiss/faiss/utils/simd_impl/rabitq_rvv.cpp +55 -0
  351. data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_dispatch.h +32 -0
  352. data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels.h +43 -0
  353. data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels_avx2.cpp +57 -0
  354. data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels_avx512.cpp +45 -0
  355. data/vendor/faiss/faiss/utils/simd_levels.cpp +334 -0
  356. data/vendor/faiss/faiss/utils/simd_levels.h +183 -0
  357. data/vendor/faiss/faiss/utils/sorting.cpp +48 -36
  358. data/vendor/faiss/faiss/utils/utils.cpp +21 -14
  359. data/vendor/faiss/faiss/utils/utils.h +3 -3
  360. metadata +156 -42
  361. data/vendor/faiss/faiss/impl/RaBitQStats.cpp +0 -29
  362. data/vendor/faiss/faiss/impl/RaBitQStats.h +0 -56
  363. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +0 -81
  364. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +0 -186
  365. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +0 -216
  366. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +0 -224
  367. data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +0 -84
  368. data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +0 -196
  369. data/vendor/faiss/faiss/utils/approx_topk/mode.h +0 -34
  370. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +0 -36
  371. data/vendor/faiss/faiss/utils/extra_distances-inl.h +0 -228
  372. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +0 -462
  373. data/vendor/faiss/faiss/utils/hamming_distance/avx512-inl.h +0 -490
  374. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +0 -450
  375. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +0 -87
  376. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +0 -524
  377. data/vendor/faiss/faiss/utils/simdlib.h +0 -42
  378. data/vendor/faiss/faiss/utils/simdlib_avx512.h +0 -296
  379. data/vendor/faiss/faiss/{cppcontrib/factory_tools.h → factory_tools.h} +0 -0
@@ -8,14 +8,17 @@
8
8
  #include <faiss/IndexIVFRaBitQFastScan.h>
9
9
 
10
10
  #include <algorithm>
11
+ #include <array>
11
12
  #include <cstdio>
13
+ #include <memory>
12
14
 
15
+ #include <faiss/impl/CodePackerRaBitQ.h>
13
16
  #include <faiss/impl/FaissAssert.h>
14
- #include <faiss/impl/FastScanDistancePostProcessing.h>
15
17
  #include <faiss/impl/RaBitQUtils.h>
16
18
  #include <faiss/impl/RaBitQuantizerMultiBit.h>
17
- #include <faiss/impl/pq4_fast_scan.h>
18
- #include <faiss/impl/simd_result_handlers.h>
19
+ #include <faiss/impl/ResultHandler.h>
20
+ #include <faiss/impl/fast_scan/FastScanDistancePostProcessing.h>
21
+ #include <faiss/impl/fast_scan/fast_scan.h>
19
22
  #include <faiss/invlists/BlockInvertedLists.h>
20
23
  #include <faiss/utils/distances.h>
21
24
  #include <faiss/utils/utils.h>
@@ -39,31 +42,38 @@ inline size_t roundup(size_t a, size_t b) {
39
42
  IndexIVFRaBitQFastScan::IndexIVFRaBitQFastScan() = default;
40
43
 
41
44
  IndexIVFRaBitQFastScan::IndexIVFRaBitQFastScan(
42
- Index* quantizer,
43
- size_t d,
44
- size_t nlist,
45
+ Index* quantizer_in,
46
+ size_t d_in,
47
+ size_t nlist_in,
45
48
  MetricType metric,
46
- int bbs,
47
- bool own_invlists,
49
+ int bbs_in,
50
+ bool own_invlists_in,
48
51
  uint8_t nb_bits)
49
- : IndexIVFFastScan(quantizer, d, nlist, 0, metric, own_invlists),
50
- rabitq(d, metric, nb_bits) {
51
- FAISS_THROW_IF_NOT_MSG(d > 0, "Dimension must be positive");
52
+ : IndexIVFFastScan(
53
+ quantizer_in,
54
+ d_in,
55
+ nlist_in,
56
+ 0,
57
+ metric,
58
+ own_invlists_in),
59
+ rabitq(d_in, metric, nb_bits) {
60
+ FAISS_THROW_IF_NOT_MSG(d_in > 0, "Dimension must be positive");
52
61
  FAISS_THROW_IF_NOT_MSG(
53
62
  metric == METRIC_L2 || metric == METRIC_INNER_PRODUCT,
54
63
  "RaBitQ only supports L2 and Inner Product metrics");
55
- FAISS_THROW_IF_NOT_MSG(bbs % 32 == 0, "Batch size must be multiple of 32");
56
- FAISS_THROW_IF_NOT_MSG(quantizer != nullptr, "Quantizer cannot be null");
64
+ FAISS_THROW_IF_NOT_MSG(
65
+ bbs_in % 32 == 0, "Batch size must be multiple of 32");
66
+ FAISS_THROW_IF_NOT_MSG(quantizer_in != nullptr, "Quantizer cannot be null");
57
67
 
58
68
  by_residual = true;
59
69
  qb = 8; // RaBitQ quantization bits
60
70
  centered = false;
61
71
 
62
72
  // FastScan-specific parameters: 4 bits per sub-quantizer
63
- const size_t M_fastscan = (d + 3) / 4;
73
+ const size_t M_fastscan = (d_in + 3) / 4;
64
74
  constexpr size_t nbits_fastscan = 4;
65
75
 
66
- this->bbs = bbs;
76
+ this->bbs = bbs_in;
67
77
  this->fine_quantizer = &rabitq;
68
78
  this->M = M_fastscan;
69
79
  this->nbits = nbits_fastscan;
@@ -79,8 +89,6 @@ IndexIVFRaBitQFastScan::IndexIVFRaBitQFastScan(
79
89
  if (own_invlists) {
80
90
  replace_invlists(new BlockInvertedLists(nlist, get_CodePacker()), true);
81
91
  }
82
-
83
- flat_storage.clear();
84
92
  }
85
93
 
86
94
  // Constructor that converts an existing IndexIVFRaBitQ to FastScan format
@@ -97,35 +105,11 @@ IndexIVFRaBitQFastScan::IndexIVFRaBitQFastScan(
97
105
  rabitq(orig.rabitq) {}
98
106
 
99
107
  size_t IndexIVFRaBitQFastScan::compute_per_vector_storage_size() const {
100
- const size_t ex_bits = rabitq.nb_bits - 1;
101
-
102
- if (ex_bits == 0) {
103
- // 1-bit: only SignBitFactors (8 bytes)
104
- return sizeof(SignBitFactors);
105
- } else {
106
- // Multi-bit: SignBitFactorsWithError + ExtraBitsFactors + ex-codes
107
- return sizeof(SignBitFactorsWithError) + sizeof(ExtraBitsFactors) +
108
- (d * ex_bits + 7) / 8;
109
- }
108
+ return rabitq_utils::compute_per_vector_storage_size(rabitq.nb_bits, d);
110
109
  }
111
110
 
112
- void IndexIVFRaBitQFastScan::preprocess_code_metadata(
113
- idx_t n,
114
- const uint8_t* flat_codes,
115
- idx_t start_global_idx) {
116
- // Unified approach: always use flat_storage for both 1-bit and multi-bit
117
- const size_t storage_size = compute_per_vector_storage_size();
118
- flat_storage.resize((start_global_idx + n) * storage_size);
119
-
120
- // Copy factors data directly to flat storage (no reordering needed)
121
- const size_t bit_pattern_size = (d + 7) / 8;
122
- for (idx_t i = 0; i < n; i++) {
123
- const uint8_t* code = flat_codes + i * code_size;
124
- const uint8_t* source_factors_ptr = code + bit_pattern_size;
125
- uint8_t* storage =
126
- flat_storage.data() + (start_global_idx + i) * storage_size;
127
- memcpy(storage, source_factors_ptr, storage_size);
128
- }
111
+ size_t IndexIVFRaBitQFastScan::fast_scan_code_size() const {
112
+ return (d + 7) / 8;
129
113
  }
130
114
 
131
115
  size_t IndexIVFRaBitQFastScan::code_packing_stride() const {
@@ -133,6 +117,45 @@ size_t IndexIVFRaBitQFastScan::code_packing_stride() const {
133
117
  return code_size;
134
118
  }
135
119
 
120
+ CodePacker* IndexIVFRaBitQFastScan::get_CodePacker() const {
121
+ return new CodePackerRaBitQ(M2, bbs, compute_per_vector_storage_size());
122
+ }
123
+
124
+ /*********************************************************
125
+ * postprocess_packed_codes: write auxiliary data into blocks
126
+ *********************************************************/
127
+
128
+ void IndexIVFRaBitQFastScan::postprocess_packed_codes(
129
+ idx_t list_no,
130
+ size_t list_offset,
131
+ size_t n_added,
132
+ const uint8_t* flat_codes) {
133
+ auto* bil = dynamic_cast<BlockInvertedLists*>(invlists);
134
+ FAISS_THROW_IF_NOT(bil);
135
+
136
+ uint8_t* block_data = bil->codes[list_no].data();
137
+ const size_t storage_size = compute_per_vector_storage_size();
138
+ const size_t bit_pattern_size = (d + 7) / 8;
139
+ const size_t packed_block_size = ((M2 + 1) / 2) * bbs;
140
+ const size_t full_block_size = get_block_stride();
141
+
142
+ for (size_t i = 0; i < n_added; i++) {
143
+ const uint8_t* src = flat_codes + i * code_size + bit_pattern_size;
144
+ uint8_t* dst = rabitq_utils::get_block_aux_ptr(
145
+ block_data,
146
+ list_offset + i,
147
+ bbs,
148
+ packed_block_size,
149
+ full_block_size,
150
+ storage_size);
151
+ memcpy(dst, src, storage_size);
152
+ }
153
+ }
154
+
155
+ /*********************************************************
156
+ * train_encoder
157
+ *********************************************************/
158
+
136
159
  void IndexIVFRaBitQFastScan::train_encoder(
137
160
  idx_t n,
138
161
  const float* x,
@@ -183,7 +206,7 @@ void IndexIVFRaBitQFastScan::encode_vectors(
183
206
  const size_t bit_pattern_size = (d + 7) / 8;
184
207
 
185
208
  // Pack sign bits directly into FastScan format (inline)
186
- for (size_t j = 0; j < d; j++) {
209
+ for (size_t j = 0; j < static_cast<size_t>(d); j++) {
187
210
  const float or_minus_c = xi[j] - centroid[j];
188
211
  if (or_minus_c > 0.0f) {
189
212
  rabitq_utils::set_bit_fastscan(fastscan_code, j);
@@ -212,7 +235,7 @@ void IndexIVFRaBitQFastScan::encode_vectors(
212
235
 
213
236
  // Compute residual (needed for quantize_ex_bits)
214
237
  std::vector<float> residual(d);
215
- for (size_t j = 0; j < d; j++) {
238
+ for (size_t j = 0; j < static_cast<size_t>(d); j++) {
216
239
  residual[j] = xi[j] - centroid[j];
217
240
  }
218
241
 
@@ -249,83 +272,133 @@ bool IndexIVFRaBitQFastScan::lookup_table_is_3d() const {
249
272
  return true;
250
273
  }
251
274
 
275
+ // out[code] = base + sum of v_i for each set bit in code.
276
+ inline void write_subset_sum_lut(
277
+ float* out,
278
+ float base,
279
+ float v0,
280
+ float v1,
281
+ float v2,
282
+ float v3) {
283
+ out[0] = base;
284
+ out[1] = base + v0;
285
+ out[2] = base + v1;
286
+ out[3] = base + v0 + v1;
287
+ out[4] = base + v2;
288
+ out[5] = base + v0 + v2;
289
+ out[6] = base + v1 + v2;
290
+ out[7] = base + v0 + v1 + v2;
291
+ out[8] = base + v3;
292
+ out[9] = base + v0 + v3;
293
+ out[10] = base + v1 + v3;
294
+ out[11] = base + v0 + v1 + v3;
295
+ out[12] = base + v2 + v3;
296
+ out[13] = base + v0 + v2 + v3;
297
+ out[14] = base + v1 + v2 + v3;
298
+ out[15] = base + v0 + v1 + v2 + v3;
299
+ }
300
+
252
301
  // Computes lookup table for residual vectors in RaBitQ FastScan format
253
302
  void IndexIVFRaBitQFastScan::compute_residual_LUT(
254
- const float* residual,
303
+ const float* query,
304
+ idx_t centroid_id,
255
305
  QueryFactorsData& query_factors,
256
306
  float* lut_out,
257
- const float* original_query) const {
258
- FAISS_THROW_IF_NOT(qb > 0 && qb <= 8);
259
-
260
- std::vector<float> rotated_q(d);
261
- std::vector<uint8_t> rotated_qq(d);
307
+ uint8_t qb_param,
308
+ bool centered_param,
309
+ std::vector<float>& rotated_q,
310
+ std::vector<float>& centroid_buf) const {
311
+ const size_t d_val = static_cast<size_t>(d);
312
+ FAISS_THROW_IF_NOT(d_val > 0);
313
+ rotated_q.resize(d_val);
314
+ centroid_buf.resize(d_val);
315
+ std::vector<uint8_t> rotated_qq(d_val);
316
+
317
+ // Compute residual
318
+ quantizer->reconstruct(centroid_id, centroid_buf.data());
319
+ for (size_t i = 0; i < d_val; i++) {
320
+ rotated_q[i] = query[i] - centroid_buf[i];
321
+ }
262
322
 
263
- // Use RaBitQUtils to compute query factors - eliminates code duplication
323
+ // Compute query factors using shared utility
264
324
  query_factors = rabitq_utils::compute_query_factors(
265
- residual,
266
- d,
325
+ rotated_q.data(),
326
+ d_val,
267
327
  nullptr,
268
- qb,
269
- centered,
328
+ qb_param,
329
+ centered_param,
270
330
  metric_type,
271
331
  rotated_q,
272
332
  rotated_qq);
273
333
 
274
- // Override query norm for inner product if original query is provided
275
- if (metric_type == MetricType::METRIC_INNER_PRODUCT &&
276
- original_query != nullptr) {
277
- query_factors.qr_norm_L2sqr = fvec_norm_L2sqr(original_query, d);
334
+ if (metric_type == MetricType::METRIC_INNER_PRODUCT) {
335
+ query_factors.qr_norm_L2sqr = fvec_norm_L2sqr(query, d_val);
336
+ query_factors.q_dot_c =
337
+ fvec_inner_product(query, centroid_buf.data(), d_val);
278
338
  }
279
339
 
280
- const size_t ex_bits = rabitq.nb_bits - 1;
281
- if (ex_bits > 0) {
340
+ if (rabitq.nb_bits > 1) {
282
341
  query_factors.rotated_q = rotated_q;
283
342
  }
284
343
 
285
- if (centered) {
286
- const float max_code_value = (1 << qb) - 1;
287
-
288
- for (size_t m = 0; m < M; m++) {
289
- const size_t dim_start = m * 4;
290
-
291
- for (int code_val = 0; code_val < 16; code_val++) {
292
- float xor_contribution = 0.0f;
344
+ // Build LUT using branchless subset-sum construction
345
+ const size_t d_sz = d_val;
293
346
 
294
- for (size_t dim_offset = 0; dim_offset < 4; dim_offset++) {
295
- const size_t dim_idx = dim_start + dim_offset;
296
-
297
- if (dim_idx < d) {
298
- const bool db_bit = (code_val >> dim_offset) & 1;
299
- const float query_value = rotated_qq[dim_idx];
300
-
301
- xor_contribution += db_bit
302
- ? (max_code_value - query_value)
303
- : query_value;
304
- }
305
- }
347
+ if (centered_param) {
348
+ const float mcv = static_cast<float>((1 << qb_param) - 1);
306
349
 
307
- lut_out[m * 16 + code_val] = xor_contribution;
350
+ for (size_t m = 0; m < M; m++) {
351
+ const size_t ds = m * 4;
352
+ float* out = lut_out + m * 16;
353
+
354
+ float base = 0.0f;
355
+ float v0 = 0.0f, v1 = 0.0f, v2 = 0.0f, v3 = 0.0f;
356
+ if (ds + 0 < d_sz) {
357
+ float q = rotated_qq[ds + 0];
358
+ base += q;
359
+ v0 = mcv - 2.0f * q;
360
+ }
361
+ if (ds + 1 < d_sz) {
362
+ float q = rotated_qq[ds + 1];
363
+ base += q;
364
+ v1 = mcv - 2.0f * q;
365
+ }
366
+ if (ds + 2 < d_sz) {
367
+ float q = rotated_qq[ds + 2];
368
+ base += q;
369
+ v2 = mcv - 2.0f * q;
370
+ }
371
+ if (ds + 3 < d_sz) {
372
+ float q = rotated_qq[ds + 3];
373
+ base += q;
374
+ v3 = mcv - 2.0f * q;
308
375
  }
376
+
377
+ write_subset_sum_lut(out, base, v0, v1, v2, v3);
309
378
  }
310
379
  } else {
311
- for (size_t m = 0; m < M; m++) {
312
- const size_t dim_start = m * 4;
380
+ const float c1 = query_factors.c1;
381
+ const float c2 = query_factors.c2;
313
382
 
314
- for (int code_val = 0; code_val < 16; code_val++) {
315
- float inner_product = 0.0f;
316
- int popcount = 0;
317
-
318
- for (size_t dim_offset = 0; dim_offset < 4; dim_offset++) {
319
- const size_t dim_idx = dim_start + dim_offset;
383
+ for (size_t m = 0; m < M; m++) {
384
+ const size_t ds = m * 4;
385
+ float* out = lut_out + m * 16;
320
386
 
321
- if (dim_idx < d && ((code_val >> dim_offset) & 1)) {
322
- inner_product += rotated_qq[dim_idx];
323
- popcount++;
324
- }
325
- }
326
- lut_out[m * 16 + code_val] = query_factors.c1 * inner_product +
327
- query_factors.c2 * popcount;
387
+ float v0 = 0.0f, v1 = 0.0f, v2 = 0.0f, v3 = 0.0f;
388
+ if (ds + 0 < d_sz) {
389
+ v0 = c1 * rotated_qq[ds + 0] + c2;
390
+ }
391
+ if (ds + 1 < d_sz) {
392
+ v1 = c1 * rotated_qq[ds + 1] + c2;
393
+ }
394
+ if (ds + 2 < d_sz) {
395
+ v2 = c1 * rotated_qq[ds + 2] + c2;
328
396
  }
397
+ if (ds + 3 < d_sz) {
398
+ v3 = c1 * rotated_qq[ds + 3] + c2;
399
+ }
400
+
401
+ write_subset_sum_lut(out, 0.0f, v0, v1, v2, v3);
329
402
  }
330
403
  }
331
404
  }
@@ -347,18 +420,27 @@ void IndexIVFRaBitQFastScan::search_preassigned(
347
420
  !store_pairs, "store_pairs not supported for RaBitQFastScan");
348
421
  FAISS_THROW_IF_NOT_MSG(!stats, "stats not supported for this index");
349
422
 
350
- size_t nprobe = this->nprobe;
423
+ size_t cur_nprobe = this->nprobe;
424
+ uint8_t used_qb = qb;
425
+ bool used_centered = centered;
351
426
  if (params) {
352
427
  FAISS_THROW_IF_NOT(params->max_codes == 0);
353
- nprobe = params->nprobe;
428
+ cur_nprobe = params->nprobe;
429
+ if (auto rparams =
430
+ dynamic_cast<const IVFRaBitQSearchParameters*>(params)) {
431
+ used_qb = rparams->qb;
432
+ used_centered = rparams->centered;
433
+ }
354
434
  }
355
435
 
356
- std::vector<QueryFactorsData> query_factors_storage(n * nprobe);
436
+ std::vector<QueryFactorsData> query_factors_storage(n * cur_nprobe);
357
437
  FastScanDistancePostProcessing context;
358
438
  context.query_factors = query_factors_storage.data();
359
- context.nprobe = nprobe;
439
+ context.nprobe = cur_nprobe;
440
+ context.qb = used_qb;
441
+ context.centered = used_centered;
360
442
 
361
- const CoarseQuantized cq = {nprobe, centroid_dis, assign};
443
+ const CoarseQuantized cq = {cur_nprobe, centroid_dis, assign};
362
444
  search_dispatch_implem(n, x, k, distances, labels, cq, context, params);
363
445
  }
364
446
 
@@ -372,44 +454,165 @@ void IndexIVFRaBitQFastScan::compute_LUT(
372
454
  FAISS_THROW_IF_NOT(is_trained);
373
455
  FAISS_THROW_IF_NOT(by_residual);
374
456
 
375
- size_t nprobe = cq.nprobe;
457
+ // Use overridden qb/centered from context if provided, else index defaults
458
+ const uint8_t used_qb = context.qb > 0 ? context.qb : qb;
459
+ const bool used_centered = context.qb > 0 ? context.centered : centered;
460
+
461
+ size_t cq_nprobe = cq.nprobe;
376
462
 
377
463
  size_t dim12 = 16 * M;
378
464
 
379
- dis_tables.resize(n * nprobe * dim12);
380
- biases.resize(n * nprobe);
465
+ dis_tables.resize(n * cq_nprobe * dim12);
466
+ biases.resize(n * cq_nprobe);
381
467
 
382
- if (n * nprobe > 0) {
383
- memset(biases.get(), 0, sizeof(float) * n * nprobe);
468
+ if (n * cq_nprobe > 0) {
469
+ memset(biases.get(), 0, sizeof(float) * n * cq_nprobe);
384
470
  }
385
- std::unique_ptr<float[]> xrel(new float[n * nprobe * d]);
471
+ // Use per-thread buffers instead of one O(n * nprobe * d) allocation.
472
+ // rotated_q / centroid_buf keep their capacity across iterations so the
473
+ // allocator is only hit once per thread.
474
+ #pragma omp parallel if (n * cq_nprobe > 1000)
475
+ {
476
+ std::vector<float> rotated_q(d);
477
+ std::vector<float> centroid_buf(d);
478
+
479
+ #pragma omp for
480
+ for (idx_t ij = 0; ij < static_cast<idx_t>(n * cq_nprobe); ij++) {
481
+ idx_t i = ij / cq_nprobe;
482
+ idx_t cij = cq.ids[ij];
483
+
484
+ if (cij >= 0) {
485
+ QueryFactorsData query_factors_data;
486
+
487
+ compute_residual_LUT(
488
+ x + i * d,
489
+ cij,
490
+ query_factors_data,
491
+ dis_tables.get() + ij * dim12,
492
+ used_qb,
493
+ used_centered,
494
+ rotated_q,
495
+ centroid_buf);
496
+
497
+ if (context.query_factors != nullptr) {
498
+ context.query_factors[ij] = std::move(query_factors_data);
499
+ }
386
500
 
387
- #pragma omp parallel for if (n * nprobe > 1000)
388
- for (idx_t ij = 0; ij < n * nprobe; ij++) {
389
- idx_t i = ij / nprobe;
390
- float* xij = &xrel[ij * d];
391
- idx_t cij = cq.ids[ij];
501
+ } else {
502
+ memset(dis_tables.get() + ij * dim12, 0, sizeof(float) * dim12);
503
+ }
504
+ }
505
+ }
506
+ }
392
507
 
393
- if (cij >= 0) {
394
- quantizer->compute_residual(x + i * d, xij, cij);
508
+ void IndexIVFRaBitQFastScan::compute_LUT_uint8(
509
+ size_t n,
510
+ const float* x,
511
+ const CoarseQuantized& cq,
512
+ AlignedTable<uint8_t>& dis_tables,
513
+ AlignedTable<uint16_t>& biases,
514
+ float* normalizers,
515
+ const FastScanDistancePostProcessing& context) const {
516
+ FAISS_THROW_IF_NOT(is_trained);
517
+ FAISS_THROW_IF_NOT(by_residual);
395
518
 
396
- // Create QueryFactorsData for this query-list combination
397
- QueryFactorsData query_factors_data;
519
+ const uint8_t used_qb = context.qb > 0 ? context.qb : qb;
520
+ const bool used_centered = context.qb > 0 ? context.centered : centered;
521
+ const size_t cur_nprobe = cq.nprobe;
522
+ const size_t dim12 = 16 * M;
523
+ const size_t dim12_2 = 16 * M2;
398
524
 
399
- compute_residual_LUT(
400
- xij,
401
- query_factors_data,
402
- dis_tables.get() + ij * dim12,
403
- x + i * d);
525
+ // Allocate only the uint8 output table (no full float table)
526
+ dis_tables.resize(n * cur_nprobe * dim12_2);
527
+ biases.resize(n * cur_nprobe);
404
528
 
405
- // Store query factors using compact indexing (ij directly)
406
- if (context.query_factors != nullptr) {
407
- context.query_factors[ij] = query_factors_data;
529
+ #pragma omp parallel if (n > 1)
530
+ {
531
+ // Per-thread buffers reused across queries
532
+ AlignedTable<float> lut_float(cur_nprobe * dim12);
533
+ std::vector<float> rotated_q(d);
534
+ std::vector<float> centroid_buf(d);
535
+ std::vector<float> all_mins(cur_nprobe * M);
536
+ std::vector<float> probe_b(cur_nprobe);
537
+
538
+ #pragma omp for schedule(dynamic)
539
+ for (int64_t i = 0; i < static_cast<int64_t>(n); i++) {
540
+ const float* xi = x + i * d;
541
+
542
+ // Compute float LUT for all probes using fused path
543
+ for (size_t j = 0; j < cur_nprobe; j++) {
544
+ const size_t ij = i * cur_nprobe + j;
545
+ idx_t cij = cq.ids[ij];
546
+
547
+ if (cij >= 0) {
548
+ QueryFactorsData qf;
549
+ compute_residual_LUT(
550
+ xi,
551
+ cij,
552
+ qf,
553
+ lut_float.get() + j * dim12,
554
+ used_qb,
555
+ used_centered,
556
+ rotated_q,
557
+ centroid_buf);
558
+
559
+ if (context.query_factors != nullptr) {
560
+ context.query_factors[ij] = qf;
561
+ }
562
+ } else {
563
+ memset(lut_float.get() + j * dim12,
564
+ 0,
565
+ sizeof(float) * dim12);
566
+ }
408
567
  }
409
568
 
410
- } else {
411
- memset(xij, -1, sizeof(float) * d);
412
- memset(dis_tables.get() + ij * dim12, -1, sizeof(float) * dim12);
569
+ // Quantize float LUT to uint8 inline.
570
+ // Mirrors quantize_LUT_and_bias 3D path with zero biases.
571
+ // Single pass: find per-sub-q mins, max span, and per-probe b.
572
+ float glob_max_span = -HUGE_VAL;
573
+ float glob_max_dis = -HUGE_VAL;
574
+ float glob_b = HUGE_VAL;
575
+ for (size_t j2 = 0; j2 < cur_nprobe; j2++) {
576
+ float b_j = 0;
577
+ float span_j = 0;
578
+ for (size_t m = 0; m < M; m++) {
579
+ const float* tab = lut_float.get() + j2 * dim12 + m * ksub;
580
+ float mn = tab[0], mx = tab[0];
581
+ for (size_t s = 1; s < ksub; s++) {
582
+ mn = std::min(mn, tab[s]);
583
+ mx = std::max(mx, tab[s]);
584
+ }
585
+ all_mins[j2 * M + m] = mn;
586
+ float span = mx - mn;
587
+ glob_max_span = std::max(glob_max_span, span);
588
+ b_j += mn;
589
+ span_j += span;
590
+ }
591
+ probe_b[j2] = b_j;
592
+ glob_max_dis = std::max(glob_max_dis, span_j);
593
+ glob_b = std::min(glob_b, b_j);
594
+ }
595
+ float a = std::min(255.0f / glob_max_span, 65535.0f / glob_max_dis);
596
+
597
+ // Second pass: quantize LUT and compute biasq
598
+ uint8_t* out_base = dis_tables.get() + i * cur_nprobe * dim12_2;
599
+ uint16_t* bq = biases.get() + i * cur_nprobe;
600
+ for (size_t j2 = 0; j2 < cur_nprobe; j2++) {
601
+ for (size_t m = 0; m < M; m++) {
602
+ const float* tab = lut_float.get() + j2 * dim12 + m * ksub;
603
+ float mn = all_mins[j2 * M + m];
604
+ uint8_t* out = out_base + j2 * dim12_2 + m * ksub;
605
+ for (size_t s = 0; s < ksub; s++) {
606
+ out[s] = static_cast<uint8_t>(
607
+ std::roundf(a * (tab[s] - mn)));
608
+ }
609
+ }
610
+ memset(out_base + j2 * dim12_2 + M * ksub, 0, (M2 - M) * ksub);
611
+ bq[j2] = static_cast<uint16_t>(
612
+ std::roundf(a * (probe_b[j2] - glob_b)));
613
+ }
614
+ normalizers[2 * i] = a;
615
+ normalizers[2 * i + 1] = glob_b;
413
616
  }
414
617
  }
415
618
  }
@@ -441,23 +644,22 @@ void IndexIVFRaBitQFastScan::reconstruct_from_offset(
441
644
  }
442
645
  }
443
646
 
444
- // Get dp_multiplier directly from flat_storage
445
- InvertedLists::ScopedIds list_ids(invlists, list_no);
446
- idx_t global_id = list_ids[offset];
447
-
448
- float dp_multiplier = 1.0f;
449
- if (global_id >= 0) {
450
- const size_t storage_size = compute_per_vector_storage_size();
451
- const size_t storage_capacity = flat_storage.size() / storage_size;
452
-
453
- if (static_cast<size_t>(global_id) < storage_capacity) {
454
- const uint8_t* base_ptr =
455
- flat_storage.data() + global_id * storage_size;
456
- const auto& base_factors =
457
- *reinterpret_cast<const SignBitFactors*>(base_ptr);
458
- dp_multiplier = base_factors.dp_multiplier;
459
- }
460
- }
647
+ const size_t storage_size = compute_per_vector_storage_size();
648
+ const size_t packed_block_size = ((M2 + 1) / 2) * bbs;
649
+ const size_t full_block_size = get_block_stride();
650
+
651
+ InvertedLists::ScopedCodes list_block_codes(invlists, list_no);
652
+ const uint8_t* aux_ptr = rabitq_utils::get_block_aux_ptr(
653
+ list_block_codes.get(),
654
+ offset,
655
+ bbs,
656
+ packed_block_size,
657
+ full_block_size,
658
+ storage_size);
659
+
660
+ const auto& base_factors =
661
+ *reinterpret_cast<const SignBitFactors*>(aux_ptr);
662
+ const float dp_multiplier = base_factors.dp_multiplier;
461
663
 
462
664
  // Decode residual directly using dp_multiplier
463
665
  std::vector<float> residual(d);
@@ -465,7 +667,7 @@ void IndexIVFRaBitQFastScan::reconstruct_from_offset(
465
667
  fastscan_code.data(), residual.data(), dp_multiplier);
466
668
 
467
669
  // Reconstruct: x = centroid + residual
468
- for (size_t j = 0; j < d; j++) {
670
+ for (size_t j = 0; j < static_cast<size_t>(d); j++) {
469
671
  recons[j] = centroid[j] + residual[j];
470
672
  }
471
673
  }
@@ -490,7 +692,7 @@ void IndexIVFRaBitQFastScan::sa_decode(idx_t n, const uint8_t* bytes, float* x)
490
692
 
491
693
  idx_t list_no = decode_listno(code_i);
492
694
 
493
- if (list_no >= 0 && list_no < nlist) {
695
+ if (list_no >= 0 && list_no < static_cast<idx_t>(nlist)) {
494
696
  quantizer->reconstruct(list_no, centroid.data());
495
697
 
496
698
  const uint8_t* fastscan_code = code_i + coarse_size;
@@ -502,7 +704,7 @@ void IndexIVFRaBitQFastScan::sa_decode(idx_t n, const uint8_t* bytes, float* x)
502
704
  decode_fastscan_to_residual(
503
705
  fastscan_code, residual.data(), base_factors.dp_multiplier);
504
706
 
505
- for (size_t j = 0; j < d; j++) {
707
+ for (size_t j = 0; j < static_cast<size_t>(d); j++) {
506
708
  x_i[j] = centroid[j] + residual[j];
507
709
  }
508
710
  } else {
@@ -519,7 +721,7 @@ void IndexIVFRaBitQFastScan::decode_fastscan_to_residual(
519
721
 
520
722
  const float inv_d_sqrt = (d == 0) ? 1.0f : (1.0f / std::sqrt((float)d));
521
723
 
522
- for (size_t j = 0; j < d; j++) {
724
+ for (size_t j = 0; j < static_cast<size_t>(d); j++) {
523
725
  bool bit_value = rabitq_utils::extract_bit_fastscan(fastscan_code, j);
524
726
 
525
727
  float bit_as_float = bit_value ? 1.0f : 0.0f;
@@ -527,302 +729,248 @@ void IndexIVFRaBitQFastScan::decode_fastscan_to_residual(
527
729
  }
528
730
  }
529
731
 
530
- // Implementation of virtual make_knn_handler method
531
- SIMDResultHandlerToFloat* IndexIVFRaBitQFastScan::make_knn_handler(
732
+ std::unique_ptr<FastScanCodeScanner> IndexIVFRaBitQFastScan::make_knn_scanner(
532
733
  bool is_max,
533
- int /* impl */,
534
734
  idx_t n,
535
735
  idx_t k,
536
736
  float* distances,
537
737
  idx_t* labels,
538
- const IDSelector* /* sel */,
539
- const FastScanDistancePostProcessing& context,
540
- const float* /* normalizers */) const {
541
- const size_t ex_bits = rabitq.nb_bits - 1;
542
- const bool is_multibit = ex_bits > 0;
543
-
544
- if (is_max) {
545
- return new IVFRaBitQHeapHandler<CMax<uint16_t, int64_t>>(
546
- this, n, k, distances, labels, &context, is_multibit);
547
- } else {
548
- return new IVFRaBitQHeapHandler<CMin<uint16_t, int64_t>>(
549
- this, n, k, distances, labels, &context, is_multibit);
550
- }
738
+ const IDSelector* sel,
739
+ int /*impl*/,
740
+ const FastScanDistancePostProcessing& context) const {
741
+ const bool is_multibit = (rabitq.nb_bits - 1) > 0;
742
+ return rabitq_ivf_make_knn_scanner(
743
+ is_max, this, n, k, distances, labels, sel, &context, is_multibit);
551
744
  }
552
745
 
553
746
  /*********************************************************
554
- * IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler implementation
747
+ * IVFRaBitQFastScanScanner implementation
555
748
  *********************************************************/
556
749
 
557
- template <class C>
558
- IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::IVFRaBitQHeapHandler(
559
- const IndexIVFRaBitQFastScan* idx,
560
- size_t nq_val,
561
- size_t k_val,
562
- float* distances,
563
- int64_t* labels,
564
- const FastScanDistancePostProcessing* ctx,
565
- bool multibit)
566
- : simd_result_handlers::ResultHandlerCompare<C, true>(
567
- nq_val,
568
- 0,
569
- nullptr),
570
- index(idx),
571
- heap_distances(distances),
572
- heap_labels(labels),
573
- nq(nq_val),
574
- k(k_val),
575
- context(ctx),
576
- is_multibit(multibit) {
577
- current_list_no = 0;
578
- probe_indices.clear();
579
-
580
- // Initialize heaps in constructor (standard pattern from HeapHandler)
581
- for (int64_t q = 0; q < static_cast<int64_t>(nq); q++) {
582
- float* heap_dis = heap_distances + q * k;
583
- int64_t* heap_ids = heap_labels + q * k;
584
- heap_heapify<Cfloat>(k, heap_dis, heap_ids);
585
- }
586
- }
587
-
588
- template <class C>
589
- void IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::handle(
590
- size_t q,
591
- size_t b,
592
- simd16uint16 d0,
593
- simd16uint16 d1) {
594
- // Store the original local query index before adjust_with_origin changes it
595
- size_t local_q = q;
596
- this->adjust_with_origin(q, d0, d1);
597
-
598
- ALIGNED(32) uint16_t d32tab[32];
599
- d0.store(d32tab);
600
- d1.store(d32tab + 16);
601
-
602
- float* const heap_dis = heap_distances + q * k;
603
- int64_t* const heap_ids = heap_labels + q * k;
604
-
605
- FAISS_THROW_IF_NOT_FMT(
606
- !probe_indices.empty() && local_q < probe_indices.size(),
607
- "set_list_context() must be called before handle() - probe_indices size: %zu, local_q: %zu, global_q: %zu",
608
- probe_indices.size(),
609
- local_q,
610
- q);
611
-
612
- // Access query factors directly from array via ProcessingContext
613
- if (!context || !context->query_factors) {
614
- FAISS_THROW_MSG(
615
- "Query factors not available: FastScanDistancePostProcessing with query_factors required");
616
- }
750
+ namespace {
617
751
 
618
- // Use probe_rank from probe_indices for compact storage indexing
619
- size_t probe_rank = probe_indices[local_q];
620
- size_t nprobe = context->nprobe > 0 ? context->nprobe : index->nprobe;
621
- size_t storage_idx = q * nprobe + probe_rank;
752
+ /// Provides IVF scanner interface using FastScan's SIMD batch processing.
753
+ /// Buffers are allocated once and reused across set_list + scan_codes calls.
754
+ struct IVFRaBitQFastScanScanner : InvertedListScanner {
755
+ using InvertedListScanner::scan_codes;
756
+ static constexpr size_t nq = 1;
622
757
 
623
- const auto& query_factors = context->query_factors[storage_idx];
758
+ const IndexIVFRaBitQFastScan& index;
759
+ const uint8_t qb;
760
+ const bool centered;
624
761
 
625
- const float one_a =
626
- this->normalizers ? (1.0f / this->normalizers[2 * q]) : 1.0f;
627
- const float bias = this->normalizers ? this->normalizers[2 * q + 1] : 0.0f;
762
+ const float* xi = nullptr;
628
763
 
629
- uint64_t idx_base = this->j0 + b * 32;
630
- if (idx_base >= this->ntotal) {
631
- return;
764
+ // Reusable buffers (allocated once in constructor)
765
+ AlignedTable<uint8_t> dis_tables;
766
+ AlignedTable<uint16_t> biases;
767
+ std::array<float, 2> normalizers{};
768
+ AlignedTable<float> lut_float;
769
+ std::vector<float> rotated_q;
770
+ std::vector<float> centroid_buf;
771
+ QueryFactorsData query_factors;
772
+ FastScanDistancePostProcessing context;
773
+ std::vector<int> probe_map;
774
+ std::vector<float> mins_buf;
775
+
776
+ // Distance computer for distance_to_code (created in set_list)
777
+ std::unique_ptr<FlatCodesDistanceComputer> dc;
778
+
779
+ IVFRaBitQFastScanScanner(
780
+ const IndexIVFRaBitQFastScan& index_in,
781
+ bool store_pairs_in,
782
+ const IDSelector* sel_in,
783
+ uint8_t qb_in,
784
+ bool centered_in)
785
+ : InvertedListScanner(store_pairs_in, sel_in),
786
+ index(index_in),
787
+ qb(qb_in),
788
+ centered(centered_in),
789
+ lut_float(16 * index_in.M),
790
+ rotated_q(index_in.d),
791
+ centroid_buf(index_in.d),
792
+ probe_map({0}),
793
+ mins_buf(index_in.M) {
794
+ this->keep_max = is_similarity_metric(index_in.metric_type);
795
+ this->code_size = index_in.code_size;
796
+
797
+ // Pre-allocate output tables for single probe
798
+ dis_tables.resize(16 * index_in.M2);
799
+ biases.resize(1);
800
+
801
+ // Set up context once
802
+ context.query_factors = &query_factors;
803
+ context.nprobe = 1;
804
+ context.qb = qb;
805
+ context.centered = centered;
632
806
  }
633
807
 
634
- size_t max_positions = std::min<size_t>(32, this->ntotal - idx_base);
635
-
636
- // Stats tracking for two-stage search
637
- // n_1bit_evaluations: candidates evaluated using 1-bit lower bound
638
- // n_multibit_evaluations: candidates requiring full multi-bit distance
639
- size_t local_1bit_evaluations = 0;
640
- size_t local_multibit_evaluations = 0;
808
+ void set_query(const float* query) override {
809
+ this->xi = query;
810
+ }
641
811
 
642
- // Process each candidate vector in the SIMD batch
643
- for (size_t j = 0; j < max_positions; j++) {
644
- const int64_t result_id = this->adjust_id(b, j);
812
+ void set_list(idx_t list_no_in, float /*coarse_dis_in*/) override {
813
+ this->list_no = list_no_in;
814
+
815
+ index.compute_residual_LUT(
816
+ xi,
817
+ list_no_in,
818
+ query_factors,
819
+ lut_float.get(),
820
+ qb,
821
+ centered,
822
+ rotated_q,
823
+ centroid_buf);
824
+
825
+ // Single-probe quantization (simplified inline, no OMP, no 3D)
826
+ const size_t M = index.M;
827
+ const size_t M2 = index.M2;
828
+ const size_t ksub = index.ksub;
829
+
830
+ float max_span = -HUGE_VAL;
831
+ float max_dis = 0;
832
+ float b = 0;
833
+ float* mins = mins_buf.data();
645
834
 
646
- if (result_id < 0) {
647
- continue;
835
+ for (size_t m = 0; m < M; m++) {
836
+ const float* tab = lut_float.get() + m * ksub;
837
+ float mn = tab[0], mx = tab[0];
838
+ for (size_t s = 1; s < ksub; s++) {
839
+ mn = std::min(mn, tab[s]);
840
+ mx = std::max(mx, tab[s]);
841
+ }
842
+ mins[m] = mn;
843
+ float span = mx - mn;
844
+ max_span = std::max(max_span, span);
845
+ max_dis += span;
846
+ b += mn;
648
847
  }
649
848
 
650
- const float normalized_distance = d32tab[j] * one_a + bias;
651
-
652
- // Get database factors from flat_storage
653
- const size_t storage_size = index->compute_per_vector_storage_size();
654
- const uint8_t* base_ptr =
655
- index->flat_storage.data() + result_id * storage_size;
656
-
657
- if (is_multibit) {
658
- // Track candidates actually considered for two-stage filtering
659
- local_1bit_evaluations++;
660
-
661
- // Multi-bit: use SignBitFactorsWithError and two-stage search
662
- const SignBitFactorsWithError& full_factors =
663
- *reinterpret_cast<const SignBitFactorsWithError*>(base_ptr);
664
-
665
- // Compute 1-bit adjusted distance using shared helper
666
- float dist_1bit = rabitq_utils::compute_1bit_adjusted_distance(
667
- normalized_distance,
668
- full_factors,
669
- query_factors,
670
- index->centered,
671
- index->qb,
672
- index->d);
673
-
674
- // Compute lower bound using error bound
675
- float lower_bound =
676
- compute_lower_bound(dist_1bit, result_id, local_q, q);
677
-
678
- // Adaptive filtering: decide whether to compute full distance
679
- const bool is_similarity =
680
- index->metric_type == MetricType::METRIC_INNER_PRODUCT;
681
- bool should_refine = is_similarity
682
- ? (lower_bound > heap_dis[0]) // IP: keep if better
683
- : (lower_bound < heap_dis[0]); // L2: keep if better
684
-
685
- if (should_refine) {
686
- local_multibit_evaluations++;
687
-
688
- // Compute local_offset: position within current inverted list
689
- size_t local_offset = this->j0 + b * 32 + j;
690
-
691
- // Compute full multi-bit distance
692
- float dist_full = compute_full_multibit_distance(
693
- result_id, local_q, q, local_offset);
694
-
695
- // Update heap if this distance is better
696
- if (Cfloat::cmp(heap_dis[0], dist_full)) {
697
- heap_replace_top<Cfloat>(
698
- k, heap_dis, heap_ids, dist_full, result_id);
699
- }
700
- }
701
- } else {
702
- const auto& db_factors =
703
- *reinterpret_cast<const SignBitFactors*>(base_ptr);
704
-
705
- // Compute adjusted distance using shared helper
706
- float adjusted_distance =
707
- rabitq_utils::compute_1bit_adjusted_distance(
708
- normalized_distance,
709
- db_factors,
710
- query_factors,
711
- index->centered,
712
- index->qb,
713
- index->d);
714
-
715
- if (Cfloat::cmp(heap_dis[0], adjusted_distance)) {
716
- heap_replace_top<Cfloat>(
717
- k, heap_dis, heap_ids, adjusted_distance, result_id);
849
+ float a = std::min(255.0f / max_span, 65535.0f / max_dis);
850
+ uint8_t* out = dis_tables.get();
851
+ for (size_t m = 0; m < M; m++) {
852
+ const float* tab = lut_float.get() + m * ksub;
853
+ for (size_t s = 0; s < ksub; s++) {
854
+ out[m * ksub + s] = static_cast<uint8_t>(
855
+ std::roundf(a * (tab[s] - mins[m])));
718
856
  }
719
857
  }
858
+ memset(out + M * ksub, 0, (M2 - M) * ksub);
859
+ biases[0] = 0;
860
+ normalizers[0] = a;
861
+ normalizers[1] = b;
862
+
863
+ // Create distance computer (reuses centroid_buf from
864
+ // compute_residual_LUT)
865
+ dc.reset(index.rabitq.get_distance_computer(
866
+ qb, centroid_buf.data(), centered));
867
+ dc->set_query(xi);
720
868
  }
721
869
 
722
- // Update global stats atomically
723
- #pragma omp atomic
724
- rabitq_stats.n_1bit_evaluations += local_1bit_evaluations;
725
- #pragma omp atomic
726
- rabitq_stats.n_multibit_evaluations += local_multibit_evaluations;
727
- }
870
+ float distance_to_code(const uint8_t* code) const override {
871
+ return dc->distance_to_code(code);
872
+ }
728
873
 
729
- template <class C>
730
- void IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::set_list_context(
731
- size_t list_no,
732
- const std::vector<int>& probe_map) {
733
- current_list_no = list_no;
734
- probe_indices = probe_map;
735
- }
874
+ size_t scan_codes(
875
+ size_t ntotal,
876
+ const uint8_t* codes,
877
+ const idx_t* ids,
878
+ ResultHandler& result_handler) const override {
879
+ auto scan_with_heap = [&](auto* heap_handler) -> size_t {
880
+ const size_t k = heap_handler->k;
881
+ if (k == 0) {
882
+ return 0;
883
+ }
736
884
 
737
- template <class C>
738
- void IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::begin(
739
- const float* norms) {
740
- this->normalizers = norms;
741
- }
885
+ std::vector<float> curr_dists(k, result_handler.threshold);
886
+ std::vector<idx_t> curr_labels(k, -1);
887
+
888
+ auto scanner = index.make_knn_scanner(
889
+ !keep_max,
890
+ nq,
891
+ k,
892
+ curr_dists.data(),
893
+ curr_labels.data(),
894
+ sel,
895
+ 0,
896
+ context);
897
+ auto* handler = scanner->handler();
898
+
899
+ int qmap1[1] = {0};
900
+ handler->q_map = qmap1;
901
+ handler->begin(&normalizers[0]);
902
+ handler->dbias = biases.get();
903
+ handler->ntotal = ntotal;
904
+ handler->id_map = ids;
905
+
906
+ handler->set_list_context(list_no, probe_map);
907
+ if (!handler->list_codes_ptr) {
908
+ handler->list_codes_ptr = codes;
909
+ }
742
910
 
743
- template <class C>
744
- void IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::end() {
745
- #pragma omp parallel for
746
- for (int64_t q = 0; q < static_cast<int64_t>(nq); q++) {
747
- float* heap_dis = heap_distances + q * k;
748
- int64_t* heap_ids = heap_labels + q * k;
749
- heap_reorder<Cfloat>(k, heap_dis, heap_ids);
750
- }
751
- }
911
+ scanner->accumulate_loop(
912
+ 1,
913
+ roundup(ntotal, index.bbs),
914
+ index.bbs,
915
+ static_cast<int>(index.M2),
916
+ codes,
917
+ dis_tables.get(),
918
+ 0,
919
+ index.get_block_stride());
920
+
921
+ const size_t scan_cnt = handler->count_scanned_rows();
922
+ handler->end();
923
+
924
+ result_handler.stats.scan_cnt += scan_cnt;
925
+ size_t nup = 0;
926
+ for (size_t j = 0; j < k; j++) {
927
+ if (curr_labels[j] < 0) {
928
+ continue;
929
+ }
930
+ if (result_handler.add_result(curr_dists[j], curr_labels[j])) {
931
+ result_handler.stats.nheap_updates++;
932
+ nup++;
933
+ }
934
+ }
935
+ return nup;
936
+ };
937
+
938
+ if (!keep_max) {
939
+ using C = CMax<float, idx_t>;
940
+ if (auto* heap_handler = dynamic_cast<HeapResultHandler<C, false>*>(
941
+ &result_handler)) {
942
+ return scan_with_heap(heap_handler);
943
+ }
944
+ } else {
945
+ using C = CMin<float, idx_t>;
946
+ if (auto* heap_handler = dynamic_cast<HeapResultHandler<C, false>*>(
947
+ &result_handler)) {
948
+ return scan_with_heap(heap_handler);
949
+ }
950
+ }
752
951
 
753
- template <class C>
754
- float IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::compute_lower_bound(
755
- float dist_1bit,
756
- size_t db_idx,
757
- size_t local_q,
758
- size_t global_q) const {
759
- // Access f_error from SignBitFactorsWithError in flat storage
760
- const size_t storage_size = index->compute_per_vector_storage_size();
761
- const uint8_t* base_ptr =
762
- index->flat_storage.data() + db_idx * storage_size;
763
- const SignBitFactorsWithError& db_factors =
764
- *reinterpret_cast<const SignBitFactorsWithError*>(base_ptr);
765
- float f_error = db_factors.f_error;
766
-
767
- // Get g_error from query factors
768
- // Use local_q to access probe_indices (batch-local), global_q for storage
769
- float g_error = 0.0f;
770
- if (context && context->query_factors) {
771
- size_t probe_rank = probe_indices[local_q];
772
- size_t nprobe = context->nprobe > 0 ? context->nprobe : index->nprobe;
773
- size_t storage_idx = global_q * nprobe + probe_rank;
774
- g_error = context->query_factors[storage_idx].g_error;
952
+ FAISS_THROW_MSG(
953
+ "IVFRaBitQFastScanScanner::scan_codes requires "
954
+ "HeapResultHandler; custom ResultHandler scan is not supported "
955
+ "by this optimized scanner");
775
956
  }
957
+ };
776
958
 
777
- // Compute error adjustment: f_error * g_error
778
- float error_adjustment = f_error * g_error;
959
+ } // anonymous namespace
779
960
 
780
- return dist_1bit - error_adjustment;
781
- }
782
-
783
- template <class C>
784
- float IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::
785
- compute_full_multibit_distance(
786
- size_t db_idx,
787
- size_t local_q,
788
- size_t global_q,
789
- size_t local_offset) const {
790
- const size_t ex_bits = index->rabitq.nb_bits - 1;
791
- const size_t dim = index->d;
792
-
793
- const size_t storage_size = index->compute_per_vector_storage_size();
794
- const uint8_t* base_ptr =
795
- index->flat_storage.data() + db_idx * storage_size;
796
-
797
- const size_t ex_code_size = (dim * ex_bits + 7) / 8;
798
- const uint8_t* ex_code = base_ptr + sizeof(SignBitFactorsWithError);
799
- const ExtraBitsFactors& ex_fac = *reinterpret_cast<const ExtraBitsFactors*>(
800
- base_ptr + sizeof(SignBitFactorsWithError) + ex_code_size);
801
-
802
- // Use local_q to access probe_indices (batch-local), global_q for storage
803
- size_t probe_rank = probe_indices[local_q];
804
- size_t nprobe = context->nprobe > 0 ? context->nprobe : index->nprobe;
805
- size_t storage_idx = global_q * nprobe + probe_rank;
806
- const auto& query_factors = context->query_factors[storage_idx];
807
-
808
- size_t list_no = current_list_no;
809
- InvertedLists::ScopedCodes list_codes(index->invlists, list_no);
810
-
811
- std::vector<uint8_t> unpacked_code(index->code_size);
812
- CodePackerPQ4 packer(index->M2, index->bbs);
813
- packer.unpack_1(list_codes.get(), local_offset, unpacked_code.data());
814
- const uint8_t* sign_bits = unpacked_code.data();
815
-
816
- return rabitq_utils::compute_full_multibit_distance(
817
- sign_bits,
818
- ex_code,
819
- ex_fac,
820
- query_factors.rotated_q.data(),
821
- query_factors.qr_to_c_L2sqr,
822
- query_factors.qr_norm_L2sqr,
823
- dim,
824
- ex_bits,
825
- index->metric_type);
961
+ InvertedListScanner* IndexIVFRaBitQFastScan::get_InvertedListScanner(
962
+ bool store_pairs,
963
+ const IDSelector* sel,
964
+ const IVFSearchParameters* search_params_in) const {
965
+ uint8_t used_qb = qb;
966
+ bool used_centered = centered;
967
+ if (auto params = dynamic_cast<const IVFRaBitQSearchParameters*>(
968
+ search_params_in)) {
969
+ used_qb = params->qb;
970
+ used_centered = params->centered;
971
+ }
972
+ return new IVFRaBitQFastScanScanner(
973
+ *this, store_pairs, sel, used_qb, used_centered);
826
974
  }
827
975
 
828
976
  } // namespace faiss