faiss 0.6.0 → 0.6.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (361) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/ext/faiss/extconf.rb +2 -1
  4. data/ext/faiss/{index_rb.cpp → index.cpp} +1 -1
  5. data/ext/faiss/index_binary.cpp +1 -1
  6. data/ext/faiss/kmeans.cpp +1 -1
  7. data/ext/faiss/pca_matrix.cpp +1 -1
  8. data/ext/faiss/product_quantizer.cpp +1 -1
  9. data/ext/faiss/{utils_rb.cpp → utils.cpp} +1 -1
  10. data/lib/faiss/version.rb +1 -1
  11. data/vendor/faiss/faiss/AutoTune.cpp +93 -80
  12. data/vendor/faiss/faiss/Clustering.cpp +39 -240
  13. data/vendor/faiss/faiss/Clustering.h +6 -0
  14. data/vendor/faiss/faiss/IVFlib.cpp +41 -21
  15. data/vendor/faiss/faiss/Index.cpp +6 -5
  16. data/vendor/faiss/faiss/Index.h +5 -5
  17. data/vendor/faiss/faiss/Index2Layer.cpp +37 -53
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +49 -37
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +36 -34
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +4 -1
  21. data/vendor/faiss/faiss/IndexBinary.cpp +5 -3
  22. data/vendor/faiss/faiss/IndexBinary.h +4 -4
  23. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +1 -1
  24. data/vendor/faiss/faiss/IndexBinaryFlat.h +1 -1
  25. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -4
  26. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +84 -92
  27. data/vendor/faiss/faiss/IndexBinaryHNSW.h +9 -3
  28. data/vendor/faiss/faiss/IndexBinaryHash.cpp +45 -236
  29. data/vendor/faiss/faiss/IndexBinaryHash.h +6 -6
  30. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +87 -415
  31. data/vendor/faiss/faiss/IndexFastScan.cpp +72 -109
  32. data/vendor/faiss/faiss/IndexFastScan.h +25 -23
  33. data/vendor/faiss/faiss/IndexFlat.cpp +27 -20
  34. data/vendor/faiss/faiss/IndexFlat.h +21 -18
  35. data/vendor/faiss/faiss/IndexFlatCodes.cpp +42 -19
  36. data/vendor/faiss/faiss/IndexHNSW.cpp +283 -145
  37. data/vendor/faiss/faiss/IndexHNSW.h +16 -2
  38. data/vendor/faiss/faiss/IndexIDMap.cpp +25 -21
  39. data/vendor/faiss/faiss/IndexIDMap.h +9 -7
  40. data/vendor/faiss/faiss/IndexIVF.cpp +465 -362
  41. data/vendor/faiss/faiss/IndexIVF.h +33 -12
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +77 -74
  43. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +96 -93
  44. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -1
  45. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +357 -238
  46. data/vendor/faiss/faiss/IndexIVFFastScan.h +42 -41
  47. data/vendor/faiss/faiss/IndexIVFFlat.cpp +36 -68
  48. data/vendor/faiss/faiss/IndexIVFFlat.h +32 -0
  49. data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +53 -30
  50. data/vendor/faiss/faiss/IndexIVFFlatPanorama.h +3 -1
  51. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +18 -15
  52. data/vendor/faiss/faiss/IndexIVFPQ.cpp +71 -843
  53. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +151 -121
  54. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +3 -0
  55. data/vendor/faiss/faiss/IndexIVFPQR.cpp +21 -17
  56. data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +26 -39
  57. data/vendor/faiss/faiss/IndexIVFRaBitQ.h +2 -1
  58. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +475 -476
  59. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +248 -93
  60. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +41 -127
  61. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +1 -1
  62. data/vendor/faiss/faiss/IndexLSH.cpp +36 -19
  63. data/vendor/faiss/faiss/IndexLattice.cpp +13 -13
  64. data/vendor/faiss/faiss/IndexNNDescent.cpp +36 -21
  65. data/vendor/faiss/faiss/IndexNNDescent.h +2 -2
  66. data/vendor/faiss/faiss/IndexNSG.cpp +39 -23
  67. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +31 -11
  68. data/vendor/faiss/faiss/IndexPQ.cpp +128 -221
  69. data/vendor/faiss/faiss/IndexPQ.h +3 -2
  70. data/vendor/faiss/faiss/IndexPQFastScan.cpp +20 -14
  71. data/vendor/faiss/faiss/IndexPQFastScan.h +3 -0
  72. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -18
  73. data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
  74. data/vendor/faiss/faiss/IndexRaBitQ.cpp +11 -36
  75. data/vendor/faiss/faiss/IndexRaBitQ.h +2 -1
  76. data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +41 -277
  77. data/vendor/faiss/faiss/IndexRaBitQFastScan.h +183 -27
  78. data/vendor/faiss/faiss/IndexRefine.cpp +30 -25
  79. data/vendor/faiss/faiss/IndexRefine.h +4 -4
  80. data/vendor/faiss/faiss/IndexReplicas.cpp +6 -6
  81. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +15 -14
  82. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +1 -1
  83. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +82 -14
  84. data/vendor/faiss/faiss/IndexShards.cpp +10 -9
  85. data/vendor/faiss/faiss/IndexShardsIVF.cpp +21 -15
  86. data/vendor/faiss/faiss/MatrixStats.cpp +5 -4
  87. data/vendor/faiss/faiss/MetaIndexes.cpp +19 -17
  88. data/vendor/faiss/faiss/MetaIndexes.h +1 -1
  89. data/vendor/faiss/faiss/MetricType.h +14 -7
  90. data/vendor/faiss/faiss/SuperKMeans.cpp +656 -0
  91. data/vendor/faiss/faiss/SuperKMeans.h +97 -0
  92. data/vendor/faiss/faiss/VectorTransform.cpp +237 -149
  93. data/vendor/faiss/faiss/VectorTransform.h +16 -16
  94. data/vendor/faiss/faiss/build.cpp +23 -0
  95. data/vendor/faiss/faiss/build.h +15 -0
  96. data/vendor/faiss/faiss/clone_index.cpp +48 -47
  97. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +47 -47
  98. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +11 -0
  99. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +38 -38
  100. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +11 -0
  101. data/vendor/faiss/faiss/factory_tools.cpp +5 -0
  102. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +6 -5
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +1 -1
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +9 -9
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +4 -3
  106. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +46 -0
  107. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +56 -0
  108. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +78 -1
  109. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +72 -0
  110. data/vendor/faiss/faiss/gpu/test/TestUtils.h +23 -0
  111. data/vendor/faiss/faiss/gpu/utils/CuvsFilterConvert.h +1 -1
  112. data/vendor/faiss/faiss/gpu/utils/CuvsUtils.h +21 -10
  113. data/vendor/faiss/faiss/gpu_metal/GpuIndexFlat.h +22 -0
  114. data/vendor/faiss/faiss/gpu_metal/MetalCloner.h +35 -0
  115. data/vendor/faiss/faiss/gpu_metal/MetalFlatKernels.h +40 -0
  116. data/vendor/faiss/faiss/gpu_metal/MetalIndex.h +51 -0
  117. data/vendor/faiss/faiss/gpu_metal/MetalIndexFlat.h +65 -0
  118. data/vendor/faiss/faiss/gpu_metal/MetalKernels.h +66 -0
  119. data/vendor/faiss/faiss/gpu_metal/MetalResources.h +79 -0
  120. data/vendor/faiss/faiss/gpu_metal/StandardMetalResources.h +35 -0
  121. data/vendor/faiss/faiss/impl/AdSampling.cpp +103 -0
  122. data/vendor/faiss/faiss/impl/AdSampling.h +35 -0
  123. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +29 -25
  124. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +1 -0
  125. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +10 -9
  126. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +3 -0
  127. data/vendor/faiss/faiss/impl/ClusteringHelpers.cpp +244 -0
  128. data/vendor/faiss/faiss/impl/ClusteringHelpers.h +94 -0
  129. data/vendor/faiss/faiss/impl/ClusteringInitialization.cpp +16 -16
  130. data/vendor/faiss/faiss/impl/CodePacker.cpp +3 -3
  131. data/vendor/faiss/faiss/impl/CodePackerRaBitQ.cpp +1 -1
  132. data/vendor/faiss/faiss/impl/DistanceComputer.h +8 -8
  133. data/vendor/faiss/faiss/impl/FaissAssert.h +6 -3
  134. data/vendor/faiss/faiss/impl/FaissException.h +50 -3
  135. data/vendor/faiss/faiss/impl/HNSW.cpp +92 -317
  136. data/vendor/faiss/faiss/impl/HNSW.h +13 -34
  137. data/vendor/faiss/faiss/impl/IDSelector.cpp +15 -11
  138. data/vendor/faiss/faiss/impl/IDSelector.h +8 -8
  139. data/vendor/faiss/faiss/impl/InvertedListScannerStats.h +26 -0
  140. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +82 -77
  141. data/vendor/faiss/faiss/impl/NNDescent.cpp +62 -25
  142. data/vendor/faiss/faiss/impl/NNDescent.h +6 -2
  143. data/vendor/faiss/faiss/impl/NSG.cpp +38 -21
  144. data/vendor/faiss/faiss/impl/NSG.h +4 -4
  145. data/vendor/faiss/faiss/impl/Panorama.cpp +23 -6
  146. data/vendor/faiss/faiss/impl/Panorama.h +258 -87
  147. data/vendor/faiss/faiss/impl/PdxLayout.cpp +93 -0
  148. data/vendor/faiss/faiss/impl/PdxLayout.h +41 -0
  149. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +46 -32
  150. data/vendor/faiss/faiss/impl/PolysemousTraining.h +3 -3
  151. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +35 -35
  152. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +21 -16
  153. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +30 -23
  154. data/vendor/faiss/faiss/impl/Quantizer.h +2 -2
  155. data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +55 -49
  156. data/vendor/faiss/faiss/impl/RaBitQUtils.h +65 -0
  157. data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +296 -283
  158. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +26 -23
  159. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +1 -1
  160. data/vendor/faiss/faiss/impl/ResultHandler.h +99 -75
  161. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +52 -4
  162. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +27 -1
  163. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +14 -11
  164. data/vendor/faiss/faiss/impl/VisitedTable.h +7 -0
  165. data/vendor/faiss/faiss/impl/approx_topk/approx_topk.h +276 -0
  166. data/vendor/faiss/faiss/impl/approx_topk/avx2.cpp +68 -0
  167. data/vendor/faiss/faiss/{utils → impl}/approx_topk/generic.h +15 -8
  168. data/vendor/faiss/faiss/impl/approx_topk/neon.cpp +68 -0
  169. data/vendor/faiss/faiss/impl/approx_topk/rq_beam_search_tab-inl.h +169 -0
  170. data/vendor/faiss/faiss/impl/approx_topk/rq_beam_search_tab.h +117 -0
  171. data/vendor/faiss/faiss/impl/approx_topk/simdlib256-inl.h +146 -0
  172. data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryHNSW_impl.h +73 -0
  173. data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryHash_impl.h +270 -0
  174. data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryIVF_impl.h +460 -0
  175. data/vendor/faiss/faiss/impl/binary_hamming/IndexIVFSpectralHash_impl.h +159 -0
  176. data/vendor/faiss/faiss/impl/binary_hamming/IndexPQ_impl.h +92 -0
  177. data/vendor/faiss/faiss/impl/binary_hamming/avx2.cpp +26 -0
  178. data/vendor/faiss/faiss/impl/binary_hamming/avx512.cpp +26 -0
  179. data/vendor/faiss/faiss/impl/binary_hamming/dispatch.h +143 -0
  180. data/vendor/faiss/faiss/impl/binary_hamming/neon.cpp +26 -0
  181. data/vendor/faiss/faiss/impl/binary_hamming/rvv.cpp +26 -0
  182. data/vendor/faiss/faiss/impl/expanded_scanners.h +8 -3
  183. data/vendor/faiss/faiss/impl/{FastScanDistancePostProcessing.h → fast_scan/FastScanDistancePostProcessing.h} +13 -6
  184. data/vendor/faiss/faiss/impl/{LookupTableScaler.h → fast_scan/LookupTableScaler.h} +16 -5
  185. data/vendor/faiss/faiss/impl/fast_scan/accumulate_loops.h +237 -0
  186. data/vendor/faiss/faiss/impl/fast_scan/accumulate_loops_512.h +185 -0
  187. data/vendor/faiss/faiss/impl/fast_scan/decompose_qbs.h +229 -0
  188. data/vendor/faiss/faiss/impl/fast_scan/dispatching.h +268 -0
  189. data/vendor/faiss/faiss/impl/{pq4_fast_scan.cpp → fast_scan/fast_scan.cpp} +169 -2
  190. data/vendor/faiss/faiss/impl/fast_scan/fast_scan.h +341 -0
  191. data/vendor/faiss/faiss/impl/fast_scan/impl-avx2.cpp +36 -0
  192. data/vendor/faiss/faiss/impl/fast_scan/impl-avx512.cpp +40 -0
  193. data/vendor/faiss/faiss/impl/fast_scan/impl-neon.cpp +120 -0
  194. data/vendor/faiss/faiss/impl/fast_scan/impl-riscv.cpp +104 -0
  195. data/vendor/faiss/faiss/impl/fast_scan/kernels_simd256.h +213 -0
  196. data/vendor/faiss/faiss/impl/{pq4_fast_scan_search_qbs.cpp → fast_scan/kernels_simd512.h} +26 -356
  197. data/vendor/faiss/faiss/impl/fast_scan/rabitq_dispatching.h +90 -0
  198. data/vendor/faiss/faiss/impl/fast_scan/rabitq_result_handler.h +108 -0
  199. data/vendor/faiss/faiss/impl/{simd_result_handlers.h → fast_scan/simd_result_handlers.h} +282 -134
  200. data/vendor/faiss/faiss/impl/hnsw/LockVector.cpp +54 -0
  201. data/vendor/faiss/faiss/impl/hnsw/LockVector.h +64 -0
  202. data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.cpp +91 -0
  203. data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.h +64 -0
  204. data/vendor/faiss/faiss/impl/hnsw/avx2.cpp +104 -0
  205. data/vendor/faiss/faiss/impl/hnsw/avx512.cpp +111 -0
  206. data/vendor/faiss/faiss/impl/index_read.cpp +1132 -45
  207. data/vendor/faiss/faiss/impl/index_read_utils.h +1 -1
  208. data/vendor/faiss/faiss/impl/index_write.cpp +95 -13
  209. data/vendor/faiss/faiss/impl/io.cpp +6 -6
  210. data/vendor/faiss/faiss/impl/io_macros.h +33 -16
  211. data/vendor/faiss/faiss/impl/kmeans1d.cpp +10 -10
  212. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +37 -23
  213. data/vendor/faiss/faiss/impl/lattice_Zn.h +6 -6
  214. data/vendor/faiss/faiss/impl/mapped_io.cpp +6 -6
  215. data/vendor/faiss/faiss/impl/platform_macros.h +11 -4
  216. data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQScanner_impl.h +549 -0
  217. data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQ_QueryTables.cpp +245 -0
  218. data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQ_QueryTables.h +105 -0
  219. data/vendor/faiss/faiss/impl/pq_code_distance/PQDistanceComputer_impl.h +106 -0
  220. data/vendor/faiss/faiss/impl/pq_code_distance/avx2.cpp +21 -0
  221. data/vendor/faiss/faiss/impl/pq_code_distance/avx512.cpp +21 -0
  222. data/vendor/faiss/faiss/impl/pq_code_distance/neon.cpp +21 -0
  223. data/vendor/faiss/faiss/impl/pq_code_distance/{pq_code_distance-avx2.cpp → pq_code_distance-avx2.h} +9 -13
  224. data/vendor/faiss/faiss/impl/pq_code_distance/{pq_code_distance-avx512.cpp → pq_code_distance-avx512.h} +9 -57
  225. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.cpp +29 -111
  226. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.h +96 -0
  227. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-inl.h +238 -5
  228. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-sve.cpp +5 -7
  229. data/vendor/faiss/faiss/impl/pq_code_distance/rvv.cpp +68 -0
  230. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +311 -477
  231. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +1 -1
  232. data/vendor/faiss/faiss/impl/scalar_quantizer/codecs.h +1 -1
  233. data/vendor/faiss/faiss/impl/scalar_quantizer/distance_computers.h +3 -2
  234. data/vendor/faiss/faiss/impl/scalar_quantizer/quantizers.h +102 -11
  235. data/vendor/faiss/faiss/impl/scalar_quantizer/scanners.h +27 -1
  236. data/vendor/faiss/faiss/impl/scalar_quantizer/similarities.h +3 -3
  237. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx2.cpp +148 -0
  238. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512.cpp +167 -0
  239. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-dispatch.h +59 -0
  240. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-neon.cpp +163 -0
  241. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-rvv.cpp +311 -0
  242. data/vendor/faiss/faiss/impl/scalar_quantizer/training.cpp +192 -8
  243. data/vendor/faiss/faiss/impl/scalar_quantizer/training.h +12 -0
  244. data/vendor/faiss/faiss/impl/simd_dispatch.h +100 -66
  245. data/vendor/faiss/faiss/impl/simdlib/simdlib.h +57 -0
  246. data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_avx2.h +264 -172
  247. data/vendor/faiss/faiss/impl/simdlib/simdlib_avx512.h +414 -0
  248. data/vendor/faiss/faiss/impl/simdlib/simdlib_dispatch.h +44 -0
  249. data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_emulated.h +231 -166
  250. data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_neon.h +270 -218
  251. data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_ppc64.h +201 -160
  252. data/vendor/faiss/faiss/impl/svs_io.cpp +12 -3
  253. data/vendor/faiss/faiss/impl/svs_io.h +8 -2
  254. data/vendor/faiss/faiss/index_factory.cpp +86 -18
  255. data/vendor/faiss/faiss/index_io.h +24 -0
  256. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +66 -16
  257. data/vendor/faiss/faiss/invlists/DirectMap.cpp +24 -14
  258. data/vendor/faiss/faiss/invlists/DirectMap.h +4 -3
  259. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +157 -73
  260. data/vendor/faiss/faiss/invlists/InvertedLists.h +86 -23
  261. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +4 -4
  262. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +13 -13
  263. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
  264. data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +1 -1
  265. data/vendor/faiss/faiss/svs/IndexSVSFlat.cpp +2 -2
  266. data/vendor/faiss/faiss/svs/IndexSVSIVF.cpp +350 -0
  267. data/vendor/faiss/faiss/svs/IndexSVSIVF.h +128 -0
  268. data/vendor/faiss/faiss/svs/IndexSVSIVFLVQ.cpp +40 -0
  269. data/vendor/faiss/faiss/svs/IndexSVSIVFLVQ.h +43 -0
  270. data/vendor/faiss/faiss/svs/IndexSVSIVFLeanVec.cpp +225 -0
  271. data/vendor/faiss/faiss/svs/IndexSVSIVFLeanVec.h +71 -0
  272. data/vendor/faiss/faiss/svs/IndexSVSVamana.cpp +25 -1
  273. data/vendor/faiss/faiss/svs/IndexSVSVamana.h +18 -2
  274. data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.h +1 -1
  275. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +12 -3
  276. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +7 -2
  277. data/vendor/faiss/faiss/utils/Heap.cpp +10 -10
  278. data/vendor/faiss/faiss/utils/NeuralNet.cpp +47 -36
  279. data/vendor/faiss/faiss/utils/NeuralNet.h +1 -1
  280. data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +10 -4
  281. data/vendor/faiss/faiss/utils/distances.cpp +390 -560
  282. data/vendor/faiss/faiss/utils/distances.h +20 -1
  283. data/vendor/faiss/faiss/utils/distances_dispatch.h +117 -37
  284. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +8 -7
  285. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +33 -14
  286. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +12 -1
  287. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +16 -293
  288. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based_neon.cpp +57 -0
  289. data/vendor/faiss/faiss/utils/distances_fused/simdlib_kernel-inl.h +290 -0
  290. data/vendor/faiss/faiss/utils/distances_simd.cpp +5 -177
  291. data/vendor/faiss/faiss/utils/extra_distances.cpp +9 -8
  292. data/vendor/faiss/faiss/utils/extra_distances.h +32 -6
  293. data/vendor/faiss/faiss/utils/hamming-inl.h +13 -11
  294. data/vendor/faiss/faiss/utils/hamming.cpp +66 -517
  295. data/vendor/faiss/faiss/utils/hamming.h +92 -2
  296. data/vendor/faiss/faiss/utils/hamming_distance/common.h +287 -10
  297. data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx2.cpp +15 -0
  298. data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx512.cpp +15 -0
  299. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx2.h +142 -0
  300. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx512.h +234 -0
  301. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-generic.h +368 -0
  302. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-neon.h +322 -0
  303. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-rvv.h +39 -0
  304. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer.h +146 -0
  305. data/vendor/faiss/faiss/utils/hamming_distance/hamming_impl.h +481 -0
  306. data/vendor/faiss/faiss/utils/hamming_distance/hamming_neon.cpp +15 -0
  307. data/vendor/faiss/faiss/utils/hamming_distance/hamming_rvv.cpp +15 -0
  308. data/vendor/faiss/faiss/utils/partitioning.cpp +66 -987
  309. data/vendor/faiss/faiss/utils/partitioning.h +31 -0
  310. data/vendor/faiss/faiss/utils/popcount.h +29 -0
  311. data/vendor/faiss/faiss/utils/pq_code_distance.h +2 -2
  312. data/vendor/faiss/faiss/utils/prefetch.h +2 -2
  313. data/vendor/faiss/faiss/utils/quantize_lut.cpp +30 -30
  314. data/vendor/faiss/faiss/utils/quantize_lut.h +1 -1
  315. data/vendor/faiss/faiss/utils/rabitq_simd.h +57 -536
  316. data/vendor/faiss/faiss/utils/random.cpp +6 -6
  317. data/vendor/faiss/faiss/utils/simd_impl/IVFFlatScanner-inl.h +51 -0
  318. data/vendor/faiss/faiss/utils/simd_impl/distances_aarch64.cpp +5 -1
  319. data/vendor/faiss/faiss/utils/simd_impl/distances_arm_sve.cpp +213 -4
  320. data/vendor/faiss/faiss/utils/simd_impl/distances_autovec-inl.h +163 -10
  321. data/vendor/faiss/faiss/utils/simd_impl/distances_avx2.cpp +250 -4
  322. data/vendor/faiss/faiss/utils/simd_impl/distances_avx512.cpp +7 -4
  323. data/vendor/faiss/faiss/utils/simd_impl/distances_rvv.cpp +189 -0
  324. data/vendor/faiss/faiss/utils/simd_impl/distances_simdlib256.h +195 -0
  325. data/vendor/faiss/faiss/utils/simd_impl/distances_sse-inl.h +2 -1
  326. data/vendor/faiss/faiss/utils/{distances_fused/simdlib_based.h → simd_impl/exhaustive_L2sqr_blas_cmax.h} +5 -10
  327. data/vendor/faiss/faiss/utils/simd_impl/hamming_impl.h +481 -0
  328. data/vendor/faiss/faiss/utils/simd_impl/partitioning_avx2.cpp +14 -0
  329. data/vendor/faiss/faiss/utils/simd_impl/partitioning_neon.cpp +14 -0
  330. data/vendor/faiss/faiss/utils/simd_impl/partitioning_simdlib256.h +1085 -0
  331. data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx2.cpp +355 -0
  332. data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx512.cpp +477 -0
  333. data/vendor/faiss/faiss/utils/simd_impl/rabitq_neon.cpp +55 -0
  334. data/vendor/faiss/faiss/utils/simd_impl/rabitq_rvv.cpp +55 -0
  335. data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_dispatch.h +32 -0
  336. data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels.h +43 -0
  337. data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels_avx2.cpp +57 -0
  338. data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels_avx512.cpp +45 -0
  339. data/vendor/faiss/faiss/utils/simd_levels.cpp +17 -5
  340. data/vendor/faiss/faiss/utils/simd_levels.h +93 -1
  341. data/vendor/faiss/faiss/utils/sorting.cpp +48 -36
  342. data/vendor/faiss/faiss/utils/utils.cpp +5 -5
  343. data/vendor/faiss/faiss/utils/utils.h +3 -3
  344. metadata +119 -34
  345. data/vendor/faiss/faiss/impl/RaBitQStats.cpp +0 -29
  346. data/vendor/faiss/faiss/impl/RaBitQStats.h +0 -56
  347. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +0 -224
  348. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +0 -230
  349. data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +0 -84
  350. data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +0 -196
  351. data/vendor/faiss/faiss/utils/approx_topk/mode.h +0 -34
  352. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +0 -36
  353. data/vendor/faiss/faiss/utils/extra_distances-inl.h +0 -235
  354. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +0 -462
  355. data/vendor/faiss/faiss/utils/hamming_distance/avx512-inl.h +0 -490
  356. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +0 -449
  357. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +0 -87
  358. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +0 -524
  359. data/vendor/faiss/faiss/utils/simdlib.h +0 -42
  360. data/vendor/faiss/faiss/utils/simdlib_avx512.h +0 -365
  361. /data/ext/faiss/{utils_rb.h → utils.h} +0 -0
@@ -14,11 +14,11 @@
14
14
 
15
15
  #include <faiss/impl/CodePackerRaBitQ.h>
16
16
  #include <faiss/impl/FaissAssert.h>
17
- #include <faiss/impl/FastScanDistancePostProcessing.h>
18
17
  #include <faiss/impl/RaBitQUtils.h>
19
18
  #include <faiss/impl/RaBitQuantizerMultiBit.h>
20
- #include <faiss/impl/pq4_fast_scan.h>
21
- #include <faiss/impl/simd_result_handlers.h>
19
+ #include <faiss/impl/ResultHandler.h>
20
+ #include <faiss/impl/fast_scan/FastScanDistancePostProcessing.h>
21
+ #include <faiss/impl/fast_scan/fast_scan.h>
22
22
  #include <faiss/invlists/BlockInvertedLists.h>
23
23
  #include <faiss/utils/distances.h>
24
24
  #include <faiss/utils/utils.h>
@@ -42,31 +42,38 @@ inline size_t roundup(size_t a, size_t b) {
42
42
  IndexIVFRaBitQFastScan::IndexIVFRaBitQFastScan() = default;
43
43
 
44
44
  IndexIVFRaBitQFastScan::IndexIVFRaBitQFastScan(
45
- Index* quantizer,
46
- size_t d,
47
- size_t nlist,
45
+ Index* quantizer_in,
46
+ size_t d_in,
47
+ size_t nlist_in,
48
48
  MetricType metric,
49
- int bbs,
50
- bool own_invlists,
49
+ int bbs_in,
50
+ bool own_invlists_in,
51
51
  uint8_t nb_bits)
52
- : IndexIVFFastScan(quantizer, d, nlist, 0, metric, own_invlists),
53
- rabitq(d, metric, nb_bits) {
54
- FAISS_THROW_IF_NOT_MSG(d > 0, "Dimension must be positive");
52
+ : IndexIVFFastScan(
53
+ quantizer_in,
54
+ d_in,
55
+ nlist_in,
56
+ 0,
57
+ metric,
58
+ own_invlists_in),
59
+ rabitq(d_in, metric, nb_bits) {
60
+ FAISS_THROW_IF_NOT_MSG(d_in > 0, "Dimension must be positive");
55
61
  FAISS_THROW_IF_NOT_MSG(
56
62
  metric == METRIC_L2 || metric == METRIC_INNER_PRODUCT,
57
63
  "RaBitQ only supports L2 and Inner Product metrics");
58
- FAISS_THROW_IF_NOT_MSG(bbs % 32 == 0, "Batch size must be multiple of 32");
59
- FAISS_THROW_IF_NOT_MSG(quantizer != nullptr, "Quantizer cannot be null");
64
+ FAISS_THROW_IF_NOT_MSG(
65
+ bbs_in % 32 == 0, "Batch size must be multiple of 32");
66
+ FAISS_THROW_IF_NOT_MSG(quantizer_in != nullptr, "Quantizer cannot be null");
60
67
 
61
68
  by_residual = true;
62
69
  qb = 8; // RaBitQ quantization bits
63
70
  centered = false;
64
71
 
65
72
  // FastScan-specific parameters: 4 bits per sub-quantizer
66
- const size_t M_fastscan = (d + 3) / 4;
73
+ const size_t M_fastscan = (d_in + 3) / 4;
67
74
  constexpr size_t nbits_fastscan = 4;
68
75
 
69
- this->bbs = bbs;
76
+ this->bbs = bbs_in;
70
77
  this->fine_quantizer = &rabitq;
71
78
  this->M = M_fastscan;
72
79
  this->nbits = nbits_fastscan;
@@ -101,6 +108,10 @@ size_t IndexIVFRaBitQFastScan::compute_per_vector_storage_size() const {
101
108
  return rabitq_utils::compute_per_vector_storage_size(rabitq.nb_bits, d);
102
109
  }
103
110
 
111
+ size_t IndexIVFRaBitQFastScan::fast_scan_code_size() const {
112
+ return (d + 7) / 8;
113
+ }
114
+
104
115
  size_t IndexIVFRaBitQFastScan::code_packing_stride() const {
105
116
  // Use code_size as stride to skip embedded factor data during packing
106
117
  return code_size;
@@ -195,7 +206,7 @@ void IndexIVFRaBitQFastScan::encode_vectors(
195
206
  const size_t bit_pattern_size = (d + 7) / 8;
196
207
 
197
208
  // Pack sign bits directly into FastScan format (inline)
198
- for (size_t j = 0; j < d; j++) {
209
+ for (size_t j = 0; j < static_cast<size_t>(d); j++) {
199
210
  const float or_minus_c = xi[j] - centroid[j];
200
211
  if (or_minus_c > 0.0f) {
201
212
  rabitq_utils::set_bit_fastscan(fastscan_code, j);
@@ -224,7 +235,7 @@ void IndexIVFRaBitQFastScan::encode_vectors(
224
235
 
225
236
  // Compute residual (needed for quantize_ex_bits)
226
237
  std::vector<float> residual(d);
227
- for (size_t j = 0; j < d; j++) {
238
+ for (size_t j = 0; j < static_cast<size_t>(d); j++) {
228
239
  residual[j] = xi[j] - centroid[j];
229
240
  }
230
241
 
@@ -261,84 +272,133 @@ bool IndexIVFRaBitQFastScan::lookup_table_is_3d() const {
261
272
  return true;
262
273
  }
263
274
 
275
+ // out[code] = base + sum of v_i for each set bit in code.
276
+ inline void write_subset_sum_lut(
277
+ float* out,
278
+ float base,
279
+ float v0,
280
+ float v1,
281
+ float v2,
282
+ float v3) {
283
+ out[0] = base;
284
+ out[1] = base + v0;
285
+ out[2] = base + v1;
286
+ out[3] = base + v0 + v1;
287
+ out[4] = base + v2;
288
+ out[5] = base + v0 + v2;
289
+ out[6] = base + v1 + v2;
290
+ out[7] = base + v0 + v1 + v2;
291
+ out[8] = base + v3;
292
+ out[9] = base + v0 + v3;
293
+ out[10] = base + v1 + v3;
294
+ out[11] = base + v0 + v1 + v3;
295
+ out[12] = base + v2 + v3;
296
+ out[13] = base + v0 + v2 + v3;
297
+ out[14] = base + v1 + v2 + v3;
298
+ out[15] = base + v0 + v1 + v2 + v3;
299
+ }
300
+
264
301
  // Computes lookup table for residual vectors in RaBitQ FastScan format
265
302
  void IndexIVFRaBitQFastScan::compute_residual_LUT(
266
- const float* residual,
303
+ const float* query,
304
+ idx_t centroid_id,
267
305
  QueryFactorsData& query_factors,
268
306
  float* lut_out,
269
- const float* original_query) const {
270
- FAISS_THROW_IF_NOT(qb > 0 && qb <= 8);
271
-
272
- std::vector<float> rotated_q(d);
273
- std::vector<uint8_t> rotated_qq(d);
307
+ uint8_t qb_param,
308
+ bool centered_param,
309
+ std::vector<float>& rotated_q,
310
+ std::vector<float>& centroid_buf) const {
311
+ const size_t d_val = static_cast<size_t>(d);
312
+ FAISS_THROW_IF_NOT(d_val > 0);
313
+ rotated_q.resize(d_val);
314
+ centroid_buf.resize(d_val);
315
+ std::vector<uint8_t> rotated_qq(d_val);
316
+
317
+ // Compute residual
318
+ quantizer->reconstruct(centroid_id, centroid_buf.data());
319
+ for (size_t i = 0; i < d_val; i++) {
320
+ rotated_q[i] = query[i] - centroid_buf[i];
321
+ }
274
322
 
275
- // Use RaBitQUtils to compute query factors - eliminates code duplication
323
+ // Compute query factors using shared utility
276
324
  query_factors = rabitq_utils::compute_query_factors(
277
- residual,
278
- d,
325
+ rotated_q.data(),
326
+ d_val,
279
327
  nullptr,
280
- qb,
281
- centered,
328
+ qb_param,
329
+ centered_param,
282
330
  metric_type,
283
331
  rotated_q,
284
332
  rotated_qq);
285
333
 
286
- if (metric_type == MetricType::METRIC_INNER_PRODUCT &&
287
- original_query != nullptr) {
288
- query_factors.qr_norm_L2sqr = fvec_norm_L2sqr(original_query, d);
289
- query_factors.q_dot_c = query_factors.qr_norm_L2sqr -
290
- fvec_inner_product(original_query, residual, d);
334
+ if (metric_type == MetricType::METRIC_INNER_PRODUCT) {
335
+ query_factors.qr_norm_L2sqr = fvec_norm_L2sqr(query, d_val);
336
+ query_factors.q_dot_c =
337
+ fvec_inner_product(query, centroid_buf.data(), d_val);
291
338
  }
292
339
 
293
- const size_t ex_bits = rabitq.nb_bits - 1;
294
- if (ex_bits > 0) {
340
+ if (rabitq.nb_bits > 1) {
295
341
  query_factors.rotated_q = rotated_q;
296
342
  }
297
343
 
298
- if (centered) {
299
- const float max_code_value = (1 << qb) - 1;
344
+ // Build LUT using branchless subset-sum construction
345
+ const size_t d_sz = d_val;
300
346
 
301
- for (size_t m = 0; m < M; m++) {
302
- const size_t dim_start = m * 4;
303
-
304
- for (int code_val = 0; code_val < 16; code_val++) {
305
- float xor_contribution = 0.0f;
306
-
307
- for (size_t dim_offset = 0; dim_offset < 4; dim_offset++) {
308
- const size_t dim_idx = dim_start + dim_offset;
347
+ if (centered_param) {
348
+ const float mcv = static_cast<float>((1 << qb_param) - 1);
309
349
 
310
- if (dim_idx < d) {
311
- const bool db_bit = (code_val >> dim_offset) & 1;
312
- const float query_value = rotated_qq[dim_idx];
313
-
314
- xor_contribution += db_bit
315
- ? (max_code_value - query_value)
316
- : query_value;
317
- }
318
- }
319
-
320
- lut_out[m * 16 + code_val] = xor_contribution;
350
+ for (size_t m = 0; m < M; m++) {
351
+ const size_t ds = m * 4;
352
+ float* out = lut_out + m * 16;
353
+
354
+ float base = 0.0f;
355
+ float v0 = 0.0f, v1 = 0.0f, v2 = 0.0f, v3 = 0.0f;
356
+ if (ds + 0 < d_sz) {
357
+ float q = rotated_qq[ds + 0];
358
+ base += q;
359
+ v0 = mcv - 2.0f * q;
360
+ }
361
+ if (ds + 1 < d_sz) {
362
+ float q = rotated_qq[ds + 1];
363
+ base += q;
364
+ v1 = mcv - 2.0f * q;
321
365
  }
366
+ if (ds + 2 < d_sz) {
367
+ float q = rotated_qq[ds + 2];
368
+ base += q;
369
+ v2 = mcv - 2.0f * q;
370
+ }
371
+ if (ds + 3 < d_sz) {
372
+ float q = rotated_qq[ds + 3];
373
+ base += q;
374
+ v3 = mcv - 2.0f * q;
375
+ }
376
+
377
+ write_subset_sum_lut(out, base, v0, v1, v2, v3);
322
378
  }
323
379
  } else {
324
- for (size_t m = 0; m < M; m++) {
325
- const size_t dim_start = m * 4;
326
-
327
- for (int code_val = 0; code_val < 16; code_val++) {
328
- float inner_product = 0.0f;
329
- int popcount = 0;
380
+ const float c1 = query_factors.c1;
381
+ const float c2 = query_factors.c2;
330
382
 
331
- for (size_t dim_offset = 0; dim_offset < 4; dim_offset++) {
332
- const size_t dim_idx = dim_start + dim_offset;
383
+ for (size_t m = 0; m < M; m++) {
384
+ const size_t ds = m * 4;
385
+ float* out = lut_out + m * 16;
333
386
 
334
- if (dim_idx < d && ((code_val >> dim_offset) & 1)) {
335
- inner_product += rotated_qq[dim_idx];
336
- popcount++;
337
- }
338
- }
339
- lut_out[m * 16 + code_val] = query_factors.c1 * inner_product +
340
- query_factors.c2 * popcount;
387
+ float v0 = 0.0f, v1 = 0.0f, v2 = 0.0f, v3 = 0.0f;
388
+ if (ds + 0 < d_sz) {
389
+ v0 = c1 * rotated_qq[ds + 0] + c2;
390
+ }
391
+ if (ds + 1 < d_sz) {
392
+ v1 = c1 * rotated_qq[ds + 1] + c2;
341
393
  }
394
+ if (ds + 2 < d_sz) {
395
+ v2 = c1 * rotated_qq[ds + 2] + c2;
396
+ }
397
+ if (ds + 3 < d_sz) {
398
+ v3 = c1 * rotated_qq[ds + 3] + c2;
399
+ }
400
+
401
+ write_subset_sum_lut(out, 0.0f, v0, v1, v2, v3);
342
402
  }
343
403
  }
344
404
  }
@@ -360,18 +420,27 @@ void IndexIVFRaBitQFastScan::search_preassigned(
360
420
  !store_pairs, "store_pairs not supported for RaBitQFastScan");
361
421
  FAISS_THROW_IF_NOT_MSG(!stats, "stats not supported for this index");
362
422
 
363
- size_t nprobe = this->nprobe;
423
+ size_t cur_nprobe = this->nprobe;
424
+ uint8_t used_qb = qb;
425
+ bool used_centered = centered;
364
426
  if (params) {
365
427
  FAISS_THROW_IF_NOT(params->max_codes == 0);
366
- nprobe = params->nprobe;
428
+ cur_nprobe = params->nprobe;
429
+ if (auto rparams =
430
+ dynamic_cast<const IVFRaBitQSearchParameters*>(params)) {
431
+ used_qb = rparams->qb;
432
+ used_centered = rparams->centered;
433
+ }
367
434
  }
368
435
 
369
- std::vector<QueryFactorsData> query_factors_storage(n * nprobe);
436
+ std::vector<QueryFactorsData> query_factors_storage(n * cur_nprobe);
370
437
  FastScanDistancePostProcessing context;
371
438
  context.query_factors = query_factors_storage.data();
372
- context.nprobe = nprobe;
439
+ context.nprobe = cur_nprobe;
440
+ context.qb = used_qb;
441
+ context.centered = used_centered;
373
442
 
374
- const CoarseQuantized cq = {nprobe, centroid_dis, assign};
443
+ const CoarseQuantized cq = {cur_nprobe, centroid_dis, assign};
375
444
  search_dispatch_implem(n, x, k, distances, labels, cq, context, params);
376
445
  }
377
446
 
@@ -385,44 +454,165 @@ void IndexIVFRaBitQFastScan::compute_LUT(
385
454
  FAISS_THROW_IF_NOT(is_trained);
386
455
  FAISS_THROW_IF_NOT(by_residual);
387
456
 
388
- size_t nprobe = cq.nprobe;
457
+ // Use overridden qb/centered from context if provided, else index defaults
458
+ const uint8_t used_qb = context.qb > 0 ? context.qb : qb;
459
+ const bool used_centered = context.qb > 0 ? context.centered : centered;
460
+
461
+ size_t cq_nprobe = cq.nprobe;
389
462
 
390
463
  size_t dim12 = 16 * M;
391
464
 
392
- dis_tables.resize(n * nprobe * dim12);
393
- biases.resize(n * nprobe);
465
+ dis_tables.resize(n * cq_nprobe * dim12);
466
+ biases.resize(n * cq_nprobe);
394
467
 
395
- if (n * nprobe > 0) {
396
- memset(biases.get(), 0, sizeof(float) * n * nprobe);
468
+ if (n * cq_nprobe > 0) {
469
+ memset(biases.get(), 0, sizeof(float) * n * cq_nprobe);
397
470
  }
398
- std::unique_ptr<float[]> xrel(new float[n * nprobe * d]);
471
+ // Use per-thread buffers instead of one O(n * nprobe * d) allocation.
472
+ // rotated_q / centroid_buf keep their capacity across iterations so the
473
+ // allocator is only hit once per thread.
474
+ #pragma omp parallel if (n * cq_nprobe > 1000)
475
+ {
476
+ std::vector<float> rotated_q(d);
477
+ std::vector<float> centroid_buf(d);
478
+
479
+ #pragma omp for
480
+ for (idx_t ij = 0; ij < static_cast<idx_t>(n * cq_nprobe); ij++) {
481
+ idx_t i = ij / cq_nprobe;
482
+ idx_t cij = cq.ids[ij];
483
+
484
+ if (cij >= 0) {
485
+ QueryFactorsData query_factors_data;
486
+
487
+ compute_residual_LUT(
488
+ x + i * d,
489
+ cij,
490
+ query_factors_data,
491
+ dis_tables.get() + ij * dim12,
492
+ used_qb,
493
+ used_centered,
494
+ rotated_q,
495
+ centroid_buf);
496
+
497
+ if (context.query_factors != nullptr) {
498
+ context.query_factors[ij] = std::move(query_factors_data);
499
+ }
399
500
 
400
- #pragma omp parallel for if (n * nprobe > 1000)
401
- for (idx_t ij = 0; ij < n * nprobe; ij++) {
402
- idx_t i = ij / nprobe;
403
- float* xij = &xrel[ij * d];
404
- idx_t cij = cq.ids[ij];
501
+ } else {
502
+ memset(dis_tables.get() + ij * dim12, 0, sizeof(float) * dim12);
503
+ }
504
+ }
505
+ }
506
+ }
405
507
 
406
- if (cij >= 0) {
407
- quantizer->compute_residual(x + i * d, xij, cij);
508
+ void IndexIVFRaBitQFastScan::compute_LUT_uint8(
509
+ size_t n,
510
+ const float* x,
511
+ const CoarseQuantized& cq,
512
+ AlignedTable<uint8_t>& dis_tables,
513
+ AlignedTable<uint16_t>& biases,
514
+ float* normalizers,
515
+ const FastScanDistancePostProcessing& context) const {
516
+ FAISS_THROW_IF_NOT(is_trained);
517
+ FAISS_THROW_IF_NOT(by_residual);
408
518
 
409
- // Create QueryFactorsData for this query-list combination
410
- QueryFactorsData query_factors_data;
519
+ const uint8_t used_qb = context.qb > 0 ? context.qb : qb;
520
+ const bool used_centered = context.qb > 0 ? context.centered : centered;
521
+ const size_t cur_nprobe = cq.nprobe;
522
+ const size_t dim12 = 16 * M;
523
+ const size_t dim12_2 = 16 * M2;
411
524
 
412
- compute_residual_LUT(
413
- xij,
414
- query_factors_data,
415
- dis_tables.get() + ij * dim12,
416
- x + i * d);
525
+ // Allocate only the uint8 output table (no full float table)
526
+ dis_tables.resize(n * cur_nprobe * dim12_2);
527
+ biases.resize(n * cur_nprobe);
417
528
 
418
- // Store query factors using compact indexing (ij directly)
419
- if (context.query_factors != nullptr) {
420
- context.query_factors[ij] = query_factors_data;
529
+ #pragma omp parallel if (n > 1)
530
+ {
531
+ // Per-thread buffers reused across queries
532
+ AlignedTable<float> lut_float(cur_nprobe * dim12);
533
+ std::vector<float> rotated_q(d);
534
+ std::vector<float> centroid_buf(d);
535
+ std::vector<float> all_mins(cur_nprobe * M);
536
+ std::vector<float> probe_b(cur_nprobe);
537
+
538
+ #pragma omp for schedule(dynamic)
539
+ for (int64_t i = 0; i < static_cast<int64_t>(n); i++) {
540
+ const float* xi = x + i * d;
541
+
542
+ // Compute float LUT for all probes using fused path
543
+ for (size_t j = 0; j < cur_nprobe; j++) {
544
+ const size_t ij = i * cur_nprobe + j;
545
+ idx_t cij = cq.ids[ij];
546
+
547
+ if (cij >= 0) {
548
+ QueryFactorsData qf;
549
+ compute_residual_LUT(
550
+ xi,
551
+ cij,
552
+ qf,
553
+ lut_float.get() + j * dim12,
554
+ used_qb,
555
+ used_centered,
556
+ rotated_q,
557
+ centroid_buf);
558
+
559
+ if (context.query_factors != nullptr) {
560
+ context.query_factors[ij] = qf;
561
+ }
562
+ } else {
563
+ memset(lut_float.get() + j * dim12,
564
+ 0,
565
+ sizeof(float) * dim12);
566
+ }
421
567
  }
422
568
 
423
- } else {
424
- memset(xij, -1, sizeof(float) * d);
425
- memset(dis_tables.get() + ij * dim12, -1, sizeof(float) * dim12);
569
+ // Quantize float LUT to uint8 inline.
570
+ // Mirrors quantize_LUT_and_bias 3D path with zero biases.
571
+ // Single pass: find per-sub-q mins, max span, and per-probe b.
572
+ float glob_max_span = -HUGE_VAL;
573
+ float glob_max_dis = -HUGE_VAL;
574
+ float glob_b = HUGE_VAL;
575
+ for (size_t j2 = 0; j2 < cur_nprobe; j2++) {
576
+ float b_j = 0;
577
+ float span_j = 0;
578
+ for (size_t m = 0; m < M; m++) {
579
+ const float* tab = lut_float.get() + j2 * dim12 + m * ksub;
580
+ float mn = tab[0], mx = tab[0];
581
+ for (size_t s = 1; s < ksub; s++) {
582
+ mn = std::min(mn, tab[s]);
583
+ mx = std::max(mx, tab[s]);
584
+ }
585
+ all_mins[j2 * M + m] = mn;
586
+ float span = mx - mn;
587
+ glob_max_span = std::max(glob_max_span, span);
588
+ b_j += mn;
589
+ span_j += span;
590
+ }
591
+ probe_b[j2] = b_j;
592
+ glob_max_dis = std::max(glob_max_dis, span_j);
593
+ glob_b = std::min(glob_b, b_j);
594
+ }
595
+ float a = std::min(255.0f / glob_max_span, 65535.0f / glob_max_dis);
596
+
597
+ // Second pass: quantize LUT and compute biasq
598
+ uint8_t* out_base = dis_tables.get() + i * cur_nprobe * dim12_2;
599
+ uint16_t* bq = biases.get() + i * cur_nprobe;
600
+ for (size_t j2 = 0; j2 < cur_nprobe; j2++) {
601
+ for (size_t m = 0; m < M; m++) {
602
+ const float* tab = lut_float.get() + j2 * dim12 + m * ksub;
603
+ float mn = all_mins[j2 * M + m];
604
+ uint8_t* out = out_base + j2 * dim12_2 + m * ksub;
605
+ for (size_t s = 0; s < ksub; s++) {
606
+ out[s] = static_cast<uint8_t>(
607
+ std::roundf(a * (tab[s] - mn)));
608
+ }
609
+ }
610
+ memset(out_base + j2 * dim12_2 + M * ksub, 0, (M2 - M) * ksub);
611
+ bq[j2] = static_cast<uint16_t>(
612
+ std::roundf(a * (probe_b[j2] - glob_b)));
613
+ }
614
+ normalizers[2 * i] = a;
615
+ normalizers[2 * i + 1] = glob_b;
426
616
  }
427
617
  }
428
618
  }
@@ -477,7 +667,7 @@ void IndexIVFRaBitQFastScan::reconstruct_from_offset(
477
667
  fastscan_code.data(), residual.data(), dp_multiplier);
478
668
 
479
669
  // Reconstruct: x = centroid + residual
480
- for (size_t j = 0; j < d; j++) {
670
+ for (size_t j = 0; j < static_cast<size_t>(d); j++) {
481
671
  recons[j] = centroid[j] + residual[j];
482
672
  }
483
673
  }
@@ -502,7 +692,7 @@ void IndexIVFRaBitQFastScan::sa_decode(idx_t n, const uint8_t* bytes, float* x)
502
692
 
503
693
  idx_t list_no = decode_listno(code_i);
504
694
 
505
- if (list_no >= 0 && list_no < nlist) {
695
+ if (list_no >= 0 && list_no < static_cast<idx_t>(nlist)) {
506
696
  quantizer->reconstruct(list_no, centroid.data());
507
697
 
508
698
  const uint8_t* fastscan_code = code_i + coarse_size;
@@ -514,7 +704,7 @@ void IndexIVFRaBitQFastScan::sa_decode(idx_t n, const uint8_t* bytes, float* x)
514
704
  decode_fastscan_to_residual(
515
705
  fastscan_code, residual.data(), base_factors.dp_multiplier);
516
706
 
517
- for (size_t j = 0; j < d; j++) {
707
+ for (size_t j = 0; j < static_cast<size_t>(d); j++) {
518
708
  x_i[j] = centroid[j] + residual[j];
519
709
  }
520
710
  } else {
@@ -531,7 +721,7 @@ void IndexIVFRaBitQFastScan::decode_fastscan_to_residual(
531
721
 
532
722
  const float inv_d_sqrt = (d == 0) ? 1.0f : (1.0f / std::sqrt((float)d));
533
723
 
534
- for (size_t j = 0; j < d; j++) {
724
+ for (size_t j = 0; j < static_cast<size_t>(d); j++) {
535
725
  bool bit_value = rabitq_utils::extract_bit_fastscan(fastscan_code, j);
536
726
 
537
727
  float bit_as_float = bit_value ? 1.0f : 0.0f;
@@ -539,287 +729,18 @@ void IndexIVFRaBitQFastScan::decode_fastscan_to_residual(
539
729
  }
540
730
  }
541
731
 
542
- // Implementation of virtual make_knn_handler method
543
- SIMDResultHandlerToFloat* IndexIVFRaBitQFastScan::make_knn_handler(
732
+ std::unique_ptr<FastScanCodeScanner> IndexIVFRaBitQFastScan::make_knn_scanner(
544
733
  bool is_max,
545
- int /* impl */,
546
734
  idx_t n,
547
735
  idx_t k,
548
736
  float* distances,
549
737
  idx_t* labels,
550
- const IDSelector* /* sel */,
551
- const FastScanDistancePostProcessing& context,
552
- const float* /* normalizers */) const {
553
- const size_t ex_bits = rabitq.nb_bits - 1;
554
- const bool is_multibit = ex_bits > 0;
555
-
556
- if (is_max) {
557
- return new IVFRaBitQHeapHandler<CMax<uint16_t, int64_t>>(
558
- this, n, k, distances, labels, &context, is_multibit);
559
- } else {
560
- return new IVFRaBitQHeapHandler<CMin<uint16_t, int64_t>>(
561
- this, n, k, distances, labels, &context, is_multibit);
562
- }
563
- }
564
-
565
- /*********************************************************
566
- * IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler implementation
567
- *********************************************************/
568
-
569
- template <class C>
570
- IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::IVFRaBitQHeapHandler(
571
- const IndexIVFRaBitQFastScan* idx,
572
- size_t nq_val,
573
- size_t k_val,
574
- float* distances,
575
- int64_t* labels,
576
- const FastScanDistancePostProcessing* ctx,
577
- bool multibit)
578
- : simd_result_handlers::ResultHandlerCompare<C, true>(
579
- nq_val,
580
- 0,
581
- nullptr),
582
- index(idx),
583
- heap_distances(distances),
584
- heap_labels(labels),
585
- nq(nq_val),
586
- k(k_val),
587
- context(ctx),
588
- is_multibit(multibit),
589
- storage_size(idx->compute_per_vector_storage_size()),
590
- packed_block_size(((idx->M2 + 1) / 2) * idx->bbs),
591
- full_block_size(idx->get_block_stride()),
592
- packer(idx->get_CodePacker()) {
593
- current_list_no = 0;
594
- probe_indices.clear();
595
-
596
- // Initialize heaps in constructor (standard pattern from HeapHandler)
597
- for (int64_t q = 0; q < static_cast<int64_t>(nq); q++) {
598
- float* heap_dis = heap_distances + q * k;
599
- int64_t* heap_ids = heap_labels + q * k;
600
- heap_heapify<Cfloat>(k, heap_dis, heap_ids);
601
- }
602
- }
603
-
604
- template <class C>
605
- void IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::handle(
606
- size_t q,
607
- size_t b,
608
- simd16uint16 d0,
609
- simd16uint16 d1) {
610
- // Store the original local query index before adjust_with_origin changes it
611
- size_t local_q = q;
612
- this->adjust_with_origin(q, d0, d1);
613
-
614
- ALIGNED(32) uint16_t d32tab[32];
615
- d0.store(d32tab);
616
- d1.store(d32tab + 16);
617
-
618
- float* const heap_dis = heap_distances + q * k;
619
- int64_t* const heap_ids = heap_labels + q * k;
620
-
621
- FAISS_THROW_IF_NOT_FMT(
622
- !probe_indices.empty() && local_q < probe_indices.size(),
623
- "set_list_context() must be called before handle() - probe_indices size: %zu, local_q: %zu, global_q: %zu",
624
- probe_indices.size(),
625
- local_q,
626
- q);
627
-
628
- // Access query factors directly from array via ProcessingContext
629
- if (!context || !context->query_factors) {
630
- FAISS_THROW_MSG(
631
- "Query factors not available: FastScanDistancePostProcessing with query_factors required");
632
- }
633
-
634
- // Use probe_rank from probe_indices for compact storage indexing
635
- size_t probe_rank = probe_indices[local_q];
636
- size_t nprobe = context->nprobe > 0 ? context->nprobe : index->nprobe;
637
- size_t storage_idx = q * nprobe + probe_rank;
638
-
639
- const auto& query_factors = context->query_factors[storage_idx];
640
-
641
- const float one_a =
642
- this->normalizers ? (1.0f / this->normalizers[2 * q]) : 1.0f;
643
- const float bias = this->normalizers ? this->normalizers[2 * q + 1] : 0.0f;
644
-
645
- uint64_t idx_base = this->j0 + b * 32;
646
- if (idx_base >= this->ntotal) {
647
- return;
648
- }
649
-
650
- size_t max_positions = std::min<size_t>(32, this->ntotal - idx_base);
651
-
652
- // Stats tracking for two-stage search
653
- // n_1bit_evaluations: candidates evaluated using 1-bit lower bound
654
- // n_multibit_evaluations: candidates requiring full multi-bit distance
655
- size_t local_1bit_evaluations = 0;
656
- size_t local_multibit_evaluations = 0;
657
-
658
- // Process each candidate vector in the SIMD batch
659
- for (size_t j = 0; j < max_positions; j++) {
660
- const int64_t result_id = this->adjust_id(b, j);
661
-
662
- if (result_id < 0) {
663
- continue;
664
- }
665
-
666
- const float normalized_distance = d32tab[j] * one_a + bias;
667
-
668
- const uint8_t* base_ptr = rabitq_utils::get_block_aux_ptr(
669
- list_codes_ptr,
670
- idx_base + j,
671
- index->bbs,
672
- packed_block_size,
673
- full_block_size,
674
- storage_size);
675
-
676
- if (is_multibit) {
677
- // Track candidates actually considered for two-stage filtering
678
- local_1bit_evaluations++;
679
-
680
- // Multi-bit: use SignBitFactorsWithError and two-stage search
681
- const SignBitFactorsWithError& full_factors =
682
- *reinterpret_cast<const SignBitFactorsWithError*>(base_ptr);
683
-
684
- // Compute 1-bit adjusted distance using shared helper
685
- float dist_1bit = rabitq_utils::compute_1bit_adjusted_distance(
686
- normalized_distance,
687
- full_factors,
688
- query_factors,
689
- index->centered,
690
- index->qb,
691
- index->d);
692
-
693
- // Adaptive filtering: decide whether to compute full distance
694
- const bool is_similarity =
695
- index->metric_type == MetricType::METRIC_INNER_PRODUCT;
696
-
697
- float g_error = query_factors.g_error;
698
-
699
- bool should_refine = rabitq_utils::should_refine_candidate(
700
- dist_1bit,
701
- full_factors.f_error,
702
- g_error,
703
- heap_dis[0],
704
- is_similarity);
705
- if (should_refine) {
706
- local_multibit_evaluations++;
707
-
708
- // Compute local_offset: position within current inverted list
709
- size_t local_offset = this->j0 + b * 32 + j;
710
-
711
- // Compute full multi-bit distance
712
- float dist_full = compute_full_multibit_distance(
713
- result_id, local_q, q, local_offset);
714
-
715
- // Update heap if this distance is better
716
- if (Cfloat::cmp(heap_dis[0], dist_full)) {
717
- heap_replace_top<Cfloat>(
718
- k, heap_dis, heap_ids, dist_full, result_id);
719
- nup++;
720
- }
721
- }
722
- } else {
723
- const auto& db_factors =
724
- *reinterpret_cast<const SignBitFactors*>(base_ptr);
725
-
726
- // Compute adjusted distance using shared helper
727
- float adjusted_distance =
728
- rabitq_utils::compute_1bit_adjusted_distance(
729
- normalized_distance,
730
- db_factors,
731
- query_factors,
732
- index->centered,
733
- index->qb,
734
- index->d);
735
-
736
- if (Cfloat::cmp(heap_dis[0], adjusted_distance)) {
737
- heap_replace_top<Cfloat>(
738
- k, heap_dis, heap_ids, adjusted_distance, result_id);
739
- nup++;
740
- }
741
- }
742
- }
743
-
744
- // Update global stats atomically
745
- #pragma omp atomic
746
- rabitq_stats.n_1bit_evaluations += local_1bit_evaluations;
747
- #pragma omp atomic
748
- rabitq_stats.n_multibit_evaluations += local_multibit_evaluations;
749
- }
750
-
751
- template <class C>
752
- void IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::set_list_context(
753
- size_t list_no,
754
- const std::vector<int>& probe_map) {
755
- current_list_no = list_no;
756
- probe_indices = probe_map;
757
- list_codes_ptr = index->invlists->get_codes(list_no);
758
- }
759
-
760
- template <class C>
761
- void IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::begin(
762
- const float* norms) {
763
- this->normalizers = norms;
764
- }
765
-
766
- template <class C>
767
- void IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::end() {
768
- #pragma omp parallel for
769
- for (int64_t q = 0; q < static_cast<int64_t>(nq); q++) {
770
- float* heap_dis = heap_distances + q * k;
771
- int64_t* heap_ids = heap_labels + q * k;
772
- heap_reorder<Cfloat>(k, heap_dis, heap_ids);
773
- }
774
- }
775
-
776
- template <class C>
777
- float IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::
778
- compute_full_multibit_distance(
779
- size_t /*db_idx*/,
780
- size_t local_q,
781
- size_t global_q,
782
- size_t local_offset) const {
783
- const size_t ex_bits = index->rabitq.nb_bits - 1;
784
- const size_t dim = index->d;
785
-
786
- const uint8_t* base_ptr = rabitq_utils::get_block_aux_ptr(
787
- list_codes_ptr,
788
- local_offset,
789
- index->bbs,
790
- packed_block_size,
791
- full_block_size,
792
- storage_size);
793
-
794
- const size_t ex_code_size = (dim * ex_bits + 7) / 8;
795
- const uint8_t* ex_code = base_ptr + sizeof(SignBitFactorsWithError);
796
- const ExtraBitsFactors& ex_fac = *reinterpret_cast<const ExtraBitsFactors*>(
797
- base_ptr + sizeof(SignBitFactorsWithError) + ex_code_size);
798
-
799
- // Use local_q to access probe_indices (batch-local), global_q for storage
800
- size_t probe_rank = probe_indices[local_q];
801
- size_t nprobe = context->nprobe > 0 ? context->nprobe : index->nprobe;
802
- size_t storage_idx = global_q * nprobe + probe_rank;
803
- const auto& query_factors = context->query_factors[storage_idx];
804
-
805
- size_t list_no = current_list_no;
806
- InvertedLists::ScopedCodes list_codes(index->invlists, list_no);
807
-
808
- std::vector<uint8_t> unpacked_code(index->code_size);
809
- packer->unpack_1(list_codes.get(), local_offset, unpacked_code.data());
810
- const uint8_t* sign_bits = unpacked_code.data();
811
-
812
- return rabitq_utils::compute_full_multibit_distance(
813
- sign_bits,
814
- ex_code,
815
- ex_fac,
816
- query_factors.rotated_q.data(),
817
- (index->metric_type == MetricType::METRIC_INNER_PRODUCT)
818
- ? query_factors.q_dot_c
819
- : query_factors.qr_to_c_L2sqr,
820
- dim,
821
- ex_bits,
822
- index->metric_type);
738
+ const IDSelector* sel,
739
+ int /*impl*/,
740
+ const FastScanDistancePostProcessing& context) const {
741
+ const bool is_multibit = (rabitq.nb_bits - 1) > 0;
742
+ return rabitq_ivf_make_knn_scanner(
743
+ is_max, this, n, k, distances, labels, sel, &context, is_multibit);
823
744
  }
824
745
 
825
746
  /*********************************************************
@@ -829,139 +750,209 @@ float IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::
829
750
  namespace {
830
751
 
831
752
  /// Provides IVF scanner interface using FastScan's SIMD batch processing.
753
+ /// Buffers are allocated once and reused across set_list + scan_codes calls.
832
754
  struct IVFRaBitQFastScanScanner : InvertedListScanner {
833
- static constexpr int impl = 10;
755
+ using InvertedListScanner::scan_codes;
834
756
  static constexpr size_t nq = 1;
835
757
 
836
758
  const IndexIVFRaBitQFastScan& index;
759
+ const uint8_t qb;
760
+ const bool centered;
837
761
 
762
+ const float* xi = nullptr;
763
+
764
+ // Reusable buffers (allocated once in constructor)
838
765
  AlignedTable<uint8_t> dis_tables;
839
766
  AlignedTable<uint16_t> biases;
840
- /// [scale, offset] for converting uint16 to float
841
767
  std::array<float, 2> normalizers{};
842
-
843
- const float* xi = nullptr;
844
-
768
+ AlignedTable<float> lut_float;
769
+ std::vector<float> rotated_q;
770
+ std::vector<float> centroid_buf;
845
771
  QueryFactorsData query_factors;
846
772
  FastScanDistancePostProcessing context;
773
+ std::vector<int> probe_map;
774
+ std::vector<float> mins_buf;
847
775
 
776
+ // Distance computer for distance_to_code (created in set_list)
848
777
  std::unique_ptr<FlatCodesDistanceComputer> dc;
849
- std::vector<float> centroid;
850
778
 
851
779
  IVFRaBitQFastScanScanner(
852
- const IndexIVFRaBitQFastScan& index,
853
- bool store_pairs,
854
- const IDSelector* sel)
855
- : InvertedListScanner(store_pairs, sel), index(index) {
856
- this->keep_max = is_similarity_metric(index.metric_type);
780
+ const IndexIVFRaBitQFastScan& index_in,
781
+ bool store_pairs_in,
782
+ const IDSelector* sel_in,
783
+ uint8_t qb_in,
784
+ bool centered_in)
785
+ : InvertedListScanner(store_pairs_in, sel_in),
786
+ index(index_in),
787
+ qb(qb_in),
788
+ centered(centered_in),
789
+ lut_float(16 * index_in.M),
790
+ rotated_q(index_in.d),
791
+ centroid_buf(index_in.d),
792
+ probe_map({0}),
793
+ mins_buf(index_in.M) {
794
+ this->keep_max = is_similarity_metric(index_in.metric_type);
795
+ this->code_size = index_in.code_size;
796
+
797
+ // Pre-allocate output tables for single probe
798
+ dis_tables.resize(16 * index_in.M2);
799
+ biases.resize(1);
800
+
801
+ // Set up context once
802
+ context.query_factors = &query_factors;
803
+ context.nprobe = 1;
804
+ context.qb = qb;
805
+ context.centered = centered;
857
806
  }
858
807
 
859
808
  void set_query(const float* query) override {
860
809
  this->xi = query;
861
810
  }
862
811
 
863
- void set_list(idx_t list_no, float coarse_dis) override {
864
- this->list_no = list_no;
812
+ void set_list(idx_t list_no_in, float /*coarse_dis_in*/) override {
813
+ this->list_no = list_no_in;
814
+
815
+ index.compute_residual_LUT(
816
+ xi,
817
+ list_no_in,
818
+ query_factors,
819
+ lut_float.get(),
820
+ qb,
821
+ centered,
822
+ rotated_q,
823
+ centroid_buf);
824
+
825
+ // Single-probe quantization (simplified inline, no OMP, no 3D)
826
+ const size_t M = index.M;
827
+ const size_t M2 = index.M2;
828
+ const size_t ksub = index.ksub;
829
+
830
+ float max_span = -HUGE_VAL;
831
+ float max_dis = 0;
832
+ float b = 0;
833
+ float* mins = mins_buf.data();
865
834
 
866
- IndexIVFFastScan::CoarseQuantized cq{
867
- .nprobe = 1,
868
- .dis = &coarse_dis,
869
- .ids = &list_no,
870
- };
871
-
872
- // Set up context for use in scan_codes
873
- context = FastScanDistancePostProcessing{};
874
- context.query_factors = &query_factors;
875
- context.nprobe = 1;
835
+ for (size_t m = 0; m < M; m++) {
836
+ const float* tab = lut_float.get() + m * ksub;
837
+ float mn = tab[0], mx = tab[0];
838
+ for (size_t s = 1; s < ksub; s++) {
839
+ mn = std::min(mn, tab[s]);
840
+ mx = std::max(mx, tab[s]);
841
+ }
842
+ mins[m] = mn;
843
+ float span = mx - mn;
844
+ max_span = std::max(max_span, span);
845
+ max_dis += span;
846
+ b += mn;
847
+ }
876
848
 
877
- index.compute_LUT_uint8(
878
- 1, xi, cq, dis_tables, biases, &normalizers[0], context);
849
+ float a = std::min(255.0f / max_span, 65535.0f / max_dis);
850
+ uint8_t* out = dis_tables.get();
851
+ for (size_t m = 0; m < M; m++) {
852
+ const float* tab = lut_float.get() + m * ksub;
853
+ for (size_t s = 0; s < ksub; s++) {
854
+ out[m * ksub + s] = static_cast<uint8_t>(
855
+ std::roundf(a * (tab[s] - mins[m])));
856
+ }
857
+ }
858
+ memset(out + M * ksub, 0, (M2 - M) * ksub);
859
+ biases[0] = 0;
860
+ normalizers[0] = a;
861
+ normalizers[1] = b;
879
862
 
880
- // Set up distance computer for distance_to_code
881
- centroid.resize(index.d);
882
- index.quantizer->reconstruct(list_no, centroid.data());
863
+ // Create distance computer (reuses centroid_buf from
864
+ // compute_residual_LUT)
883
865
  dc.reset(index.rabitq.get_distance_computer(
884
- index.qb, centroid.data(), index.centered));
866
+ qb, centroid_buf.data(), centered));
885
867
  dc->set_query(xi);
886
868
  }
887
869
 
888
870
  float distance_to_code(const uint8_t* code) const override {
889
- FAISS_THROW_IF_NOT_MSG(
890
- dc,
891
- "set_query and set_list must be called before distance_to_code");
892
871
  return dc->distance_to_code(code);
893
872
  }
894
873
 
895
- public:
896
874
  size_t scan_codes(
897
875
  size_t ntotal,
898
876
  const uint8_t* codes,
899
877
  const idx_t* ids,
900
- float* distances,
901
- idx_t* labels,
902
- size_t k) const override {
903
- // initialize the current iteration heap to the worst possible value of
904
- // the prior loop
905
- std::vector<float> curr_dists(k, distances[0]);
906
- std::vector<idx_t> curr_labels(k, labels[0]);
907
-
908
- std::unique_ptr<SIMDResultHandlerToFloat> handler(
909
- index.make_knn_handler(
910
- !keep_max,
911
- impl,
912
- nq,
913
- k,
914
- curr_dists.data(),
915
- curr_labels.data(),
916
- sel,
917
- context,
918
- &normalizers[0]));
919
-
920
- int qmap1[1] = {0};
921
- handler->q_map = qmap1;
922
- handler->begin(&normalizers[0]);
923
-
924
- const uint8_t* LUT = dis_tables.get();
925
- handler->dbias = biases.get();
926
- handler->ntotal = ntotal;
927
- handler->id_map = ids;
928
-
929
- // RaBitQ needs list context for factor lookup
930
- std::vector<int> probe_map = {0};
931
- handler->set_list_context(list_no, probe_map);
932
-
933
- pq4_accumulate_loop(
934
- 1,
935
- roundup(ntotal, index.bbs),
936
- index.bbs,
937
- static_cast<int>(index.M2),
938
- codes,
939
- LUT,
940
- *handler,
941
- nullptr,
942
- index.get_block_stride());
943
-
944
- // Combine results across iterations
945
- handler->end();
946
- if (keep_max) {
947
- minheap_addn(
878
+ ResultHandler& result_handler) const override {
879
+ auto scan_with_heap = [&](auto* heap_handler) -> size_t {
880
+ const size_t k = heap_handler->k;
881
+ if (k == 0) {
882
+ return 0;
883
+ }
884
+
885
+ std::vector<float> curr_dists(k, result_handler.threshold);
886
+ std::vector<idx_t> curr_labels(k, -1);
887
+
888
+ auto scanner = index.make_knn_scanner(
889
+ !keep_max,
890
+ nq,
948
891
  k,
949
- distances,
950
- labels,
951
892
  curr_dists.data(),
952
893
  curr_labels.data(),
953
- k);
894
+ sel,
895
+ 0,
896
+ context);
897
+ auto* handler = scanner->handler();
898
+
899
+ int qmap1[1] = {0};
900
+ handler->q_map = qmap1;
901
+ handler->begin(&normalizers[0]);
902
+ handler->dbias = biases.get();
903
+ handler->ntotal = ntotal;
904
+ handler->id_map = ids;
905
+
906
+ handler->set_list_context(list_no, probe_map);
907
+ if (!handler->list_codes_ptr) {
908
+ handler->list_codes_ptr = codes;
909
+ }
910
+
911
+ scanner->accumulate_loop(
912
+ 1,
913
+ roundup(ntotal, index.bbs),
914
+ index.bbs,
915
+ static_cast<int>(index.M2),
916
+ codes,
917
+ dis_tables.get(),
918
+ 0,
919
+ index.get_block_stride());
920
+
921
+ const size_t scan_cnt = handler->count_scanned_rows();
922
+ handler->end();
923
+
924
+ result_handler.stats.scan_cnt += scan_cnt;
925
+ size_t nup = 0;
926
+ for (size_t j = 0; j < k; j++) {
927
+ if (curr_labels[j] < 0) {
928
+ continue;
929
+ }
930
+ if (result_handler.add_result(curr_dists[j], curr_labels[j])) {
931
+ result_handler.stats.nheap_updates++;
932
+ nup++;
933
+ }
934
+ }
935
+ return nup;
936
+ };
937
+
938
+ if (!keep_max) {
939
+ using C = CMax<float, idx_t>;
940
+ if (auto* heap_handler = dynamic_cast<HeapResultHandler<C, false>*>(
941
+ &result_handler)) {
942
+ return scan_with_heap(heap_handler);
943
+ }
954
944
  } else {
955
- maxheap_addn(
956
- k,
957
- distances,
958
- labels,
959
- curr_dists.data(),
960
- curr_labels.data(),
961
- k);
945
+ using C = CMin<float, idx_t>;
946
+ if (auto* heap_handler = dynamic_cast<HeapResultHandler<C, false>*>(
947
+ &result_handler)) {
948
+ return scan_with_heap(heap_handler);
949
+ }
962
950
  }
963
951
 
964
- return handler->num_updates();
952
+ FAISS_THROW_MSG(
953
+ "IVFRaBitQFastScanScanner::scan_codes requires "
954
+ "HeapResultHandler; custom ResultHandler scan is not supported "
955
+ "by this optimized scanner");
965
956
  }
966
957
  };
967
958
 
@@ -970,8 +961,16 @@ struct IVFRaBitQFastScanScanner : InvertedListScanner {
970
961
  InvertedListScanner* IndexIVFRaBitQFastScan::get_InvertedListScanner(
971
962
  bool store_pairs,
972
963
  const IDSelector* sel,
973
- const IVFSearchParameters*) const {
974
- return new IVFRaBitQFastScanScanner(*this, store_pairs, sel);
964
+ const IVFSearchParameters* search_params_in) const {
965
+ uint8_t used_qb = qb;
966
+ bool used_centered = centered;
967
+ if (auto params = dynamic_cast<const IVFRaBitQSearchParameters*>(
968
+ search_params_in)) {
969
+ used_qb = params->qb;
970
+ used_centered = params->centered;
971
+ }
972
+ return new IVFRaBitQFastScanScanner(
973
+ *this, store_pairs, sel, used_qb, used_centered);
975
974
  }
976
975
 
977
976
  } // namespace faiss