faiss 0.5.3 → 0.6.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (379)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +12 -0
  3. data/ext/faiss/ext.cpp +1 -1
  4. data/ext/faiss/extconf.rb +4 -4
  5. data/ext/faiss/index.cpp +63 -45
  6. data/ext/faiss/index_binary.cpp +37 -27
  7. data/ext/faiss/kmeans.cpp +9 -8
  8. data/ext/faiss/pca_matrix.cpp +9 -7
  9. data/ext/faiss/product_quantizer.cpp +13 -11
  10. data/ext/faiss/utils.cpp +4 -2
  11. data/ext/faiss/utils.h +4 -0
  12. data/lib/faiss/version.rb +1 -1
  13. data/lib/faiss.rb +1 -1
  14. data/vendor/faiss/faiss/AutoTune.cpp +214 -82
  15. data/vendor/faiss/faiss/AutoTune.h +14 -1
  16. data/vendor/faiss/faiss/Clustering.cpp +97 -249
  17. data/vendor/faiss/faiss/Clustering.h +18 -0
  18. data/vendor/faiss/faiss/IVFlib.cpp +67 -44
  19. data/vendor/faiss/faiss/Index.cpp +25 -12
  20. data/vendor/faiss/faiss/Index.h +26 -4
  21. data/vendor/faiss/faiss/Index2Layer.cpp +37 -53
  22. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +68 -61
  23. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +36 -34
  24. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +4 -1
  25. data/vendor/faiss/faiss/IndexBinary.cpp +6 -3
  26. data/vendor/faiss/faiss/IndexBinary.h +4 -4
  27. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +1 -1
  28. data/vendor/faiss/faiss/IndexBinaryFlat.h +1 -1
  29. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -4
  30. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +92 -95
  31. data/vendor/faiss/faiss/IndexBinaryHNSW.h +9 -3
  32. data/vendor/faiss/faiss/IndexBinaryHash.cpp +45 -236
  33. data/vendor/faiss/faiss/IndexBinaryHash.h +6 -6
  34. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +120 -414
  35. data/vendor/faiss/faiss/IndexFastScan.cpp +105 -129
  36. data/vendor/faiss/faiss/IndexFastScan.h +35 -24
  37. data/vendor/faiss/faiss/IndexFlat.cpp +216 -152
  38. data/vendor/faiss/faiss/IndexFlat.h +32 -14
  39. data/vendor/faiss/faiss/IndexFlatCodes.cpp +88 -41
  40. data/vendor/faiss/faiss/IndexFlatCodes.h +7 -1
  41. data/vendor/faiss/faiss/IndexHNSW.cpp +299 -187
  42. data/vendor/faiss/faiss/IndexHNSW.h +30 -14
  43. data/vendor/faiss/faiss/IndexIDMap.cpp +26 -22
  44. data/vendor/faiss/faiss/IndexIDMap.h +9 -7
  45. data/vendor/faiss/faiss/IndexIVF.cpp +535 -405
  46. data/vendor/faiss/faiss/IndexIVF.h +47 -16
  47. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +77 -74
  48. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +105 -99
  49. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +6 -3
  50. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +379 -249
  51. data/vendor/faiss/faiss/IndexIVFFastScan.h +65 -60
  52. data/vendor/faiss/faiss/IndexIVFFlat.cpp +41 -124
  53. data/vendor/faiss/faiss/IndexIVFFlat.h +32 -0
  54. data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +89 -138
  55. data/vendor/faiss/faiss/IndexIVFFlatPanorama.h +3 -1
  56. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +18 -15
  57. data/vendor/faiss/faiss/IndexIVFPQ.cpp +77 -907
  58. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +184 -122
  59. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +3 -0
  60. data/vendor/faiss/faiss/IndexIVFPQR.cpp +23 -18
  61. data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +59 -60
  62. data/vendor/faiss/faiss/IndexIVFRaBitQ.h +4 -3
  63. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +564 -416
  64. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +269 -111
  65. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +41 -127
  66. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +1 -1
  67. data/vendor/faiss/faiss/IndexLSH.cpp +44 -25
  68. data/vendor/faiss/faiss/IndexLattice.cpp +41 -36
  69. data/vendor/faiss/faiss/IndexNNDescent.cpp +37 -21
  70. data/vendor/faiss/faiss/IndexNNDescent.h +2 -2
  71. data/vendor/faiss/faiss/IndexNSG.cpp +40 -23
  72. data/vendor/faiss/faiss/IndexNSG.h +0 -2
  73. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +32 -12
  74. data/vendor/faiss/faiss/IndexPQ.cpp +129 -213
  75. data/vendor/faiss/faiss/IndexPQ.h +3 -2
  76. data/vendor/faiss/faiss/IndexPQFastScan.cpp +20 -14
  77. data/vendor/faiss/faiss/IndexPQFastScan.h +3 -0
  78. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -18
  79. data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
  80. data/vendor/faiss/faiss/IndexRaBitQ.cpp +31 -43
  81. data/vendor/faiss/faiss/IndexRaBitQ.h +4 -3
  82. data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +135 -317
  83. data/vendor/faiss/faiss/IndexRaBitQFastScan.h +192 -34
  84. data/vendor/faiss/faiss/IndexRefine.cpp +30 -55
  85. data/vendor/faiss/faiss/IndexRefine.h +4 -4
  86. data/vendor/faiss/faiss/IndexReplicas.cpp +6 -6
  87. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +15 -14
  88. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +1 -1
  89. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +82 -14
  90. data/vendor/faiss/faiss/IndexShards.cpp +13 -13
  91. data/vendor/faiss/faiss/IndexShardsIVF.cpp +21 -15
  92. data/vendor/faiss/faiss/MatrixStats.cpp +5 -4
  93. data/vendor/faiss/faiss/MetaIndexes.cpp +19 -17
  94. data/vendor/faiss/faiss/MetaIndexes.h +1 -1
  95. data/vendor/faiss/faiss/MetricType.h +29 -6
  96. data/vendor/faiss/faiss/SuperKMeans.cpp +656 -0
  97. data/vendor/faiss/faiss/SuperKMeans.h +97 -0
  98. data/vendor/faiss/faiss/VectorTransform.cpp +349 -141
  99. data/vendor/faiss/faiss/VectorTransform.h +39 -16
  100. data/vendor/faiss/faiss/build.cpp +23 -0
  101. data/vendor/faiss/faiss/build.h +15 -0
  102. data/vendor/faiss/faiss/clone_index.cpp +55 -51
  103. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +47 -47
  104. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +11 -0
  105. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +38 -38
  106. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +11 -0
  107. data/vendor/faiss/faiss/{cppcontrib/factory_tools.cpp → factory_tools.cpp} +6 -1
  108. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +1 -1
  109. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +6 -5
  110. data/vendor/faiss/faiss/gpu/GpuResources.h +1 -1
  111. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +9 -9
  112. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +4 -3
  113. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +46 -0
  114. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +56 -0
  115. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +78 -1
  116. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +72 -0
  117. data/vendor/faiss/faiss/gpu/test/TestUtils.h +23 -0
  118. data/vendor/faiss/faiss/gpu/utils/CuvsFilterConvert.h +1 -1
  119. data/vendor/faiss/faiss/gpu/utils/CuvsUtils.h +21 -10
  120. data/vendor/faiss/faiss/gpu_metal/GpuIndexFlat.h +22 -0
  121. data/vendor/faiss/faiss/gpu_metal/MetalCloner.h +35 -0
  122. data/vendor/faiss/faiss/gpu_metal/MetalFlatKernels.h +40 -0
  123. data/vendor/faiss/faiss/gpu_metal/MetalIndex.h +51 -0
  124. data/vendor/faiss/faiss/gpu_metal/MetalIndexFlat.h +65 -0
  125. data/vendor/faiss/faiss/gpu_metal/MetalKernels.h +66 -0
  126. data/vendor/faiss/faiss/gpu_metal/MetalResources.h +79 -0
  127. data/vendor/faiss/faiss/gpu_metal/StandardMetalResources.h +35 -0
  128. data/vendor/faiss/faiss/impl/AdSampling.cpp +103 -0
  129. data/vendor/faiss/faiss/impl/AdSampling.h +35 -0
  130. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +64 -34
  131. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +1 -0
  132. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +10 -9
  133. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +3 -28
  134. data/vendor/faiss/faiss/impl/ClusteringHelpers.cpp +244 -0
  135. data/vendor/faiss/faiss/impl/ClusteringHelpers.h +94 -0
  136. data/vendor/faiss/faiss/impl/ClusteringInitialization.cpp +367 -0
  137. data/vendor/faiss/faiss/impl/ClusteringInitialization.h +107 -0
  138. data/vendor/faiss/faiss/impl/CodePacker.cpp +7 -3
  139. data/vendor/faiss/faiss/impl/CodePacker.h +11 -3
  140. data/vendor/faiss/faiss/impl/CodePackerRaBitQ.cpp +83 -0
  141. data/vendor/faiss/faiss/impl/CodePackerRaBitQ.h +47 -0
  142. data/vendor/faiss/faiss/impl/DistanceComputer.h +8 -8
  143. data/vendor/faiss/faiss/impl/FaissAssert.h +64 -3
  144. data/vendor/faiss/faiss/impl/FaissException.h +50 -3
  145. data/vendor/faiss/faiss/impl/HNSW.cpp +117 -351
  146. data/vendor/faiss/faiss/impl/HNSW.h +21 -40
  147. data/vendor/faiss/faiss/impl/IDSelector.cpp +15 -11
  148. data/vendor/faiss/faiss/impl/IDSelector.h +8 -8
  149. data/vendor/faiss/faiss/impl/InvertedListScannerStats.h +26 -0
  150. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +114 -102
  151. data/vendor/faiss/faiss/impl/NNDescent.cpp +63 -26
  152. data/vendor/faiss/faiss/impl/NNDescent.h +6 -2
  153. data/vendor/faiss/faiss/impl/NSG.cpp +44 -26
  154. data/vendor/faiss/faiss/impl/NSG.h +20 -10
  155. data/vendor/faiss/faiss/impl/Panorama.cpp +76 -52
  156. data/vendor/faiss/faiss/impl/Panorama.h +265 -78
  157. data/vendor/faiss/faiss/impl/PdxLayout.cpp +93 -0
  158. data/vendor/faiss/faiss/impl/PdxLayout.h +41 -0
  159. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +62 -37
  160. data/vendor/faiss/faiss/impl/PolysemousTraining.h +3 -3
  161. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +35 -35
  162. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +21 -16
  163. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +99 -80
  164. data/vendor/faiss/faiss/impl/Quantizer.h +2 -2
  165. data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +135 -37
  166. data/vendor/faiss/faiss/impl/RaBitQUtils.h +148 -21
  167. data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +298 -301
  168. data/vendor/faiss/faiss/impl/RaBitQuantizer.h +3 -10
  169. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.cpp +15 -41
  170. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.h +0 -4
  171. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +40 -32
  172. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +1 -1
  173. data/vendor/faiss/faiss/impl/ResultHandler.h +218 -113
  174. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +119 -2362
  175. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +27 -3
  176. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +14 -11
  177. data/vendor/faiss/faiss/impl/VisitedTable.cpp +42 -0
  178. data/vendor/faiss/faiss/impl/VisitedTable.h +76 -0
  179. data/vendor/faiss/faiss/impl/approx_topk/approx_topk.h +276 -0
  180. data/vendor/faiss/faiss/impl/approx_topk/avx2.cpp +68 -0
  181. data/vendor/faiss/faiss/{utils → impl}/approx_topk/generic.h +15 -8
  182. data/vendor/faiss/faiss/impl/approx_topk/neon.cpp +68 -0
  183. data/vendor/faiss/faiss/impl/approx_topk/rq_beam_search_tab-inl.h +169 -0
  184. data/vendor/faiss/faiss/impl/approx_topk/rq_beam_search_tab.h +117 -0
  185. data/vendor/faiss/faiss/impl/approx_topk/simdlib256-inl.h +146 -0
  186. data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryHNSW_impl.h +73 -0
  187. data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryHash_impl.h +270 -0
  188. data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryIVF_impl.h +460 -0
  189. data/vendor/faiss/faiss/impl/binary_hamming/IndexIVFSpectralHash_impl.h +159 -0
  190. data/vendor/faiss/faiss/impl/binary_hamming/IndexPQ_impl.h +92 -0
  191. data/vendor/faiss/faiss/impl/binary_hamming/avx2.cpp +26 -0
  192. data/vendor/faiss/faiss/impl/binary_hamming/avx512.cpp +26 -0
  193. data/vendor/faiss/faiss/impl/binary_hamming/dispatch.h +143 -0
  194. data/vendor/faiss/faiss/impl/binary_hamming/neon.cpp +26 -0
  195. data/vendor/faiss/faiss/impl/binary_hamming/rvv.cpp +26 -0
  196. data/vendor/faiss/faiss/impl/expanded_scanners.h +163 -0
  197. data/vendor/faiss/faiss/impl/{FastScanDistancePostProcessing.h → fast_scan/FastScanDistancePostProcessing.h} +13 -6
  198. data/vendor/faiss/faiss/impl/{LookupTableScaler.h → fast_scan/LookupTableScaler.h} +16 -5
  199. data/vendor/faiss/faiss/impl/fast_scan/accumulate_loops.h +237 -0
  200. data/vendor/faiss/faiss/impl/fast_scan/accumulate_loops_512.h +185 -0
  201. data/vendor/faiss/faiss/impl/fast_scan/decompose_qbs.h +229 -0
  202. data/vendor/faiss/faiss/impl/fast_scan/dispatching.h +268 -0
  203. data/vendor/faiss/faiss/impl/{pq4_fast_scan.cpp → fast_scan/fast_scan.cpp} +176 -4
  204. data/vendor/faiss/faiss/impl/fast_scan/fast_scan.h +341 -0
  205. data/vendor/faiss/faiss/impl/fast_scan/impl-avx2.cpp +36 -0
  206. data/vendor/faiss/faiss/impl/fast_scan/impl-avx512.cpp +40 -0
  207. data/vendor/faiss/faiss/impl/fast_scan/impl-neon.cpp +120 -0
  208. data/vendor/faiss/faiss/impl/fast_scan/impl-riscv.cpp +104 -0
  209. data/vendor/faiss/faiss/impl/fast_scan/kernels_simd256.h +213 -0
  210. data/vendor/faiss/faiss/impl/{pq4_fast_scan_search_qbs.cpp → fast_scan/kernels_simd512.h} +26 -348
  211. data/vendor/faiss/faiss/impl/fast_scan/rabitq_dispatching.h +90 -0
  212. data/vendor/faiss/faiss/impl/fast_scan/rabitq_result_handler.h +108 -0
  213. data/vendor/faiss/faiss/impl/{simd_result_handlers.h → fast_scan/simd_result_handlers.h} +290 -142
  214. data/vendor/faiss/faiss/impl/hnsw/LockVector.cpp +54 -0
  215. data/vendor/faiss/faiss/impl/hnsw/LockVector.h +64 -0
  216. data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.cpp +91 -0
  217. data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.h +64 -0
  218. data/vendor/faiss/faiss/impl/hnsw/avx2.cpp +104 -0
  219. data/vendor/faiss/faiss/impl/hnsw/avx512.cpp +111 -0
  220. data/vendor/faiss/faiss/impl/index_read.cpp +1950 -505
  221. data/vendor/faiss/faiss/impl/index_read_utils.h +1 -2
  222. data/vendor/faiss/faiss/impl/index_write.cpp +112 -21
  223. data/vendor/faiss/faiss/impl/io.cpp +6 -6
  224. data/vendor/faiss/faiss/impl/io_macros.h +33 -16
  225. data/vendor/faiss/faiss/impl/kmeans1d.cpp +10 -10
  226. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +81 -40
  227. data/vendor/faiss/faiss/impl/lattice_Zn.h +6 -6
  228. data/vendor/faiss/faiss/impl/mapped_io.cpp +15 -8
  229. data/vendor/faiss/faiss/impl/platform_macros.h +11 -4
  230. data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQScanner_impl.h +549 -0
  231. data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQ_QueryTables.cpp +245 -0
  232. data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQ_QueryTables.h +105 -0
  233. data/vendor/faiss/faiss/impl/pq_code_distance/PQDistanceComputer_impl.h +106 -0
  234. data/vendor/faiss/faiss/impl/pq_code_distance/avx2.cpp +21 -0
  235. data/vendor/faiss/faiss/impl/pq_code_distance/avx512.cpp +21 -0
  236. data/vendor/faiss/faiss/impl/pq_code_distance/neon.cpp +21 -0
  237. data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx2.h → pq_code_distance/pq_code_distance-avx2.h} +43 -220
  238. data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx512.h → pq_code_distance/pq_code_distance-avx512.h} +25 -112
  239. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.cpp +59 -0
  240. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.h +96 -0
  241. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-inl.h +256 -0
  242. data/vendor/faiss/faiss/impl/{code_distance/code_distance-sve.h → pq_code_distance/pq_code_distance-sve.cpp} +57 -146
  243. data/vendor/faiss/faiss/impl/pq_code_distance/rvv.cpp +68 -0
  244. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +320 -483
  245. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +1 -1
  246. data/vendor/faiss/faiss/impl/scalar_quantizer/codecs.h +121 -0
  247. data/vendor/faiss/faiss/impl/scalar_quantizer/distance_computers.h +137 -0
  248. data/vendor/faiss/faiss/impl/scalar_quantizer/quantizers.h +371 -0
  249. data/vendor/faiss/faiss/impl/scalar_quantizer/scanners.h +190 -0
  250. data/vendor/faiss/faiss/impl/scalar_quantizer/similarities.h +94 -0
  251. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx2.cpp +603 -0
  252. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512.cpp +597 -0
  253. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-dispatch.h +388 -0
  254. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-neon.cpp +630 -0
  255. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-rvv.cpp +311 -0
  256. data/vendor/faiss/faiss/impl/scalar_quantizer/training.cpp +387 -0
  257. data/vendor/faiss/faiss/impl/scalar_quantizer/training.h +54 -0
  258. data/vendor/faiss/faiss/impl/simd_dispatch.h +173 -0
  259. data/vendor/faiss/faiss/impl/simdlib/simdlib.h +57 -0
  260. data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_avx2.h +274 -171
  261. data/vendor/faiss/faiss/impl/simdlib/simdlib_avx512.h +414 -0
  262. data/vendor/faiss/faiss/impl/simdlib/simdlib_dispatch.h +44 -0
  263. data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_emulated.h +231 -166
  264. data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_neon.h +275 -217
  265. data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_ppc64.h +201 -160
  266. data/vendor/faiss/faiss/impl/svs_io.cpp +12 -3
  267. data/vendor/faiss/faiss/impl/svs_io.h +8 -2
  268. data/vendor/faiss/faiss/index_factory.cpp +115 -28
  269. data/vendor/faiss/faiss/index_io.h +53 -3
  270. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +73 -20
  271. data/vendor/faiss/faiss/invlists/DirectMap.cpp +24 -14
  272. data/vendor/faiss/faiss/invlists/DirectMap.h +4 -3
  273. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +157 -73
  274. data/vendor/faiss/faiss/invlists/InvertedLists.h +86 -23
  275. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +4 -4
  276. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +14 -14
  277. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
  278. data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +9 -19
  279. data/vendor/faiss/faiss/svs/IndexSVSFlat.cpp +2 -2
  280. data/vendor/faiss/faiss/svs/IndexSVSFlat.h +2 -0
  281. data/vendor/faiss/faiss/svs/IndexSVSIVF.cpp +350 -0
  282. data/vendor/faiss/faiss/svs/IndexSVSIVF.h +128 -0
  283. data/vendor/faiss/faiss/svs/IndexSVSIVFLVQ.cpp +40 -0
  284. data/vendor/faiss/faiss/svs/IndexSVSIVFLVQ.h +43 -0
  285. data/vendor/faiss/faiss/svs/IndexSVSIVFLeanVec.cpp +225 -0
  286. data/vendor/faiss/faiss/svs/IndexSVSIVFLeanVec.h +71 -0
  287. data/vendor/faiss/faiss/svs/IndexSVSVamana.cpp +25 -1
  288. data/vendor/faiss/faiss/svs/IndexSVSVamana.h +19 -2
  289. data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.h +1 -1
  290. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +19 -2
  291. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +14 -0
  292. data/vendor/faiss/faiss/utils/Heap.cpp +56 -10
  293. data/vendor/faiss/faiss/utils/Heap.h +21 -0
  294. data/vendor/faiss/faiss/utils/NeuralNet.cpp +54 -40
  295. data/vendor/faiss/faiss/utils/NeuralNet.h +1 -1
  296. data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +10 -4
  297. data/vendor/faiss/faiss/utils/distances.cpp +507 -559
  298. data/vendor/faiss/faiss/utils/distances.h +118 -1
  299. data/vendor/faiss/faiss/utils/distances_dispatch.h +250 -0
  300. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +8 -7
  301. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +33 -14
  302. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +12 -1
  303. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +16 -293
  304. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based_neon.cpp +57 -0
  305. data/vendor/faiss/faiss/utils/distances_fused/simdlib_kernel-inl.h +290 -0
  306. data/vendor/faiss/faiss/utils/distances_simd.cpp +72 -3681
  307. data/vendor/faiss/faiss/utils/extra_distances.cpp +60 -102
  308. data/vendor/faiss/faiss/utils/extra_distances.h +79 -7
  309. data/vendor/faiss/faiss/utils/hamming-inl.h +13 -11
  310. data/vendor/faiss/faiss/utils/hamming.cpp +66 -517
  311. data/vendor/faiss/faiss/utils/hamming.h +92 -2
  312. data/vendor/faiss/faiss/utils/hamming_distance/common.h +287 -10
  313. data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx2.cpp +15 -0
  314. data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx512.cpp +15 -0
  315. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx2.h +142 -0
  316. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx512.h +234 -0
  317. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-generic.h +368 -0
  318. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-neon.h +322 -0
  319. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-rvv.h +39 -0
  320. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer.h +146 -0
  321. data/vendor/faiss/faiss/utils/hamming_distance/hamming_impl.h +481 -0
  322. data/vendor/faiss/faiss/utils/hamming_distance/hamming_neon.cpp +15 -0
  323. data/vendor/faiss/faiss/utils/hamming_distance/hamming_rvv.cpp +15 -0
  324. data/vendor/faiss/faiss/utils/partitioning.cpp +66 -987
  325. data/vendor/faiss/faiss/utils/partitioning.h +31 -0
  326. data/vendor/faiss/faiss/utils/popcount.h +29 -0
  327. data/vendor/faiss/faiss/utils/pq_code_distance.h +251 -0
  328. data/vendor/faiss/faiss/utils/prefetch.h +2 -2
  329. data/vendor/faiss/faiss/utils/quantize_lut.cpp +30 -30
  330. data/vendor/faiss/faiss/utils/quantize_lut.h +1 -1
  331. data/vendor/faiss/faiss/utils/rabitq_simd.h +124 -343
  332. data/vendor/faiss/faiss/utils/random.cpp +6 -6
  333. data/vendor/faiss/faiss/utils/simd_impl/IVFFlatScanner-inl.h +51 -0
  334. data/vendor/faiss/faiss/utils/simd_impl/distances_aarch64.cpp +154 -0
  335. data/vendor/faiss/faiss/utils/simd_impl/distances_arm_sve.cpp +777 -0
  336. data/vendor/faiss/faiss/utils/simd_impl/distances_autovec-inl.h +306 -0
  337. data/vendor/faiss/faiss/utils/simd_impl/distances_avx2.cpp +1431 -0
  338. data/vendor/faiss/faiss/utils/simd_impl/distances_avx512.cpp +1095 -0
  339. data/vendor/faiss/faiss/utils/simd_impl/distances_rvv.cpp +189 -0
  340. data/vendor/faiss/faiss/utils/simd_impl/distances_simdlib256.h +195 -0
  341. data/vendor/faiss/faiss/utils/simd_impl/distances_sse-inl.h +392 -0
  342. data/vendor/faiss/faiss/utils/{distances_fused/simdlib_based.h → simd_impl/exhaustive_L2sqr_blas_cmax.h} +5 -10
  343. data/vendor/faiss/faiss/utils/simd_impl/hamming_impl.h +481 -0
  344. data/vendor/faiss/faiss/utils/simd_impl/partitioning_avx2.cpp +14 -0
  345. data/vendor/faiss/faiss/utils/simd_impl/partitioning_neon.cpp +14 -0
  346. data/vendor/faiss/faiss/utils/simd_impl/partitioning_simdlib256.h +1085 -0
  347. data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx2.cpp +355 -0
  348. data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx512.cpp +477 -0
  349. data/vendor/faiss/faiss/utils/simd_impl/rabitq_neon.cpp +55 -0
  350. data/vendor/faiss/faiss/utils/simd_impl/rabitq_rvv.cpp +55 -0
  351. data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_dispatch.h +32 -0
  352. data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels.h +43 -0
  353. data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels_avx2.cpp +57 -0
  354. data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels_avx512.cpp +45 -0
  355. data/vendor/faiss/faiss/utils/simd_levels.cpp +334 -0
  356. data/vendor/faiss/faiss/utils/simd_levels.h +183 -0
  357. data/vendor/faiss/faiss/utils/sorting.cpp +48 -36
  358. data/vendor/faiss/faiss/utils/utils.cpp +21 -14
  359. data/vendor/faiss/faiss/utils/utils.h +3 -3
  360. metadata +156 -42
  361. data/vendor/faiss/faiss/impl/RaBitQStats.cpp +0 -29
  362. data/vendor/faiss/faiss/impl/RaBitQStats.h +0 -56
  363. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +0 -81
  364. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +0 -186
  365. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +0 -216
  366. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +0 -224
  367. data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +0 -84
  368. data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +0 -196
  369. data/vendor/faiss/faiss/utils/approx_topk/mode.h +0 -34
  370. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +0 -36
  371. data/vendor/faiss/faiss/utils/extra_distances-inl.h +0 -228
  372. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +0 -462
  373. data/vendor/faiss/faiss/utils/hamming_distance/avx512-inl.h +0 -490
  374. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +0 -450
  375. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +0 -87
  376. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +0 -524
  377. data/vendor/faiss/faiss/utils/simdlib.h +0 -42
  378. data/vendor/faiss/faiss/utils/simdlib_avx512.h +0 -296
  379. /data/vendor/faiss/faiss/{cppcontrib/factory_tools.h → factory_tools.h} +0 -0
@@ -6,10 +6,11 @@
6
6
  */
7
7
 
8
8
  #include <faiss/IndexRaBitQFastScan.h>
9
- #include <faiss/impl/FastScanDistancePostProcessing.h>
9
+ #include <faiss/impl/CodePackerRaBitQ.h>
10
10
  #include <faiss/impl/RaBitQUtils.h>
11
11
  #include <faiss/impl/RaBitQuantizerMultiBit.h>
12
- #include <faiss/impl/pq4_fast_scan.h>
12
+ #include <faiss/impl/fast_scan/FastScanDistancePostProcessing.h>
13
+ #include <faiss/impl/fast_scan/fast_scan.h>
13
14
  #include <faiss/utils/utils.h>
14
15
  #include <algorithm>
15
16
  #include <cmath>
@@ -21,29 +22,19 @@ static inline size_t roundup(size_t a, size_t b) {
21
22
  }
22
23
 
23
24
  size_t IndexRaBitQFastScan::compute_per_vector_storage_size() const {
24
- const size_t ex_bits = rabitq.nb_bits - 1;
25
-
26
- if (ex_bits == 0) {
27
- // 1-bit: only SignBitFactors
28
- return sizeof(rabitq_utils::SignBitFactors);
29
- } else {
30
- // Multi-bit: SignBitFactorsWithError + ExtraBitsFactors +
31
- // mag-codes
32
- return sizeof(SignBitFactorsWithError) + sizeof(ExtraBitsFactors) +
33
- (d * ex_bits + 7) / 8;
34
- }
25
+ return rabitq_utils::compute_per_vector_storage_size(rabitq.nb_bits, d);
35
26
  }
36
27
 
37
28
  IndexRaBitQFastScan::IndexRaBitQFastScan() = default;
38
29
 
39
30
  IndexRaBitQFastScan::IndexRaBitQFastScan(
40
- idx_t d,
31
+ idx_t d_in,
41
32
  MetricType metric,
42
- int bbs,
33
+ int bbs_in,
43
34
  uint8_t nb_bits)
44
- : rabitq(d, metric, nb_bits) {
35
+ : rabitq(d_in, metric, nb_bits) {
45
36
  // RaBitQ-specific validation
46
- FAISS_THROW_IF_NOT_MSG(d > 0, "Dimension must be positive");
37
+ FAISS_THROW_IF_NOT_MSG(d_in > 0, "Dimension must be positive");
47
38
  FAISS_THROW_IF_NOT_MSG(
48
39
  metric == METRIC_L2 || metric == METRIC_INNER_PRODUCT,
49
40
  "RaBitQ FastScan only supports L2 and Inner Product metrics");
@@ -52,24 +43,67 @@ IndexRaBitQFastScan::IndexRaBitQFastScan(
52
43
 
53
44
  // RaBitQ uses 1 bit per dimension packed into 4-bit FastScan sub-quantizers
54
45
  // Each FastScan sub-quantizer handles 4 RaBitQ dimensions
55
- const size_t M_fastscan = (d + 3) / 4;
46
+ const size_t M_fastscan = (d_in + 3) / 4;
56
47
  constexpr size_t nbits_fastscan = 4;
57
48
 
58
49
  // init_fastscan will validate bbs % 32 == 0 and nbits_fastscan == 4
59
- init_fastscan(static_cast<int>(d), M_fastscan, nbits_fastscan, metric, bbs);
50
+ init_fastscan(
51
+ static_cast<int>(d_in), M_fastscan, nbits_fastscan, metric, bbs_in);
60
52
 
61
53
  // Compute code_size directly using RaBitQuantizer
62
- code_size = rabitq.compute_code_size(d, nb_bits);
54
+ code_size = rabitq.compute_code_size(d_in, nb_bits);
63
55
 
64
56
  // Set RaBitQ-specific parameters
65
57
  qb = 8;
66
- center.resize(d, 0.0f);
58
+ center.resize(d_in, 0.0f);
59
+ }
60
+
61
+ CodePacker* IndexRaBitQFastScan::get_CodePacker() const {
62
+ return new CodePackerRaBitQ(M2, bbs, compute_per_vector_storage_size());
63
+ }
64
+
65
+ size_t IndexRaBitQFastScan::remove_ids(const IDSelector& sel) {
66
+ const size_t block_stride = get_block_stride();
67
67
 
68
- // Initialize empty flat storage
69
- flat_storage.clear();
68
+ idx_t j = 0;
69
+ std::vector<uint8_t> buffer(code_size);
70
+ std::unique_ptr<CodePacker> packer(get_CodePacker());
71
+ for (idx_t i = 0; i < ntotal; i++) {
72
+ if (sel.is_member(i)) {
73
+ } else {
74
+ if (i > j) {
75
+ packer->unpack_1(codes.data(), i, buffer.data());
76
+ packer->pack_1(buffer.data(), j, codes.data());
77
+ }
78
+ j++;
79
+ }
80
+ }
81
+ size_t nremove = ntotal - j;
82
+ if (nremove > 0) {
83
+ ntotal = j;
84
+ ntotal2 = roundup(ntotal, bbs);
85
+ size_t new_size = ntotal2 / bbs * block_stride;
86
+
87
+ // Zero out stale data in the last block beyond the retained vectors.
88
+ // This is necessary because pq4_pack_codes_range uses |= to write
89
+ // new codes, so any stale non-zero nibbles would corrupt future adds.
90
+ // pack_1 with a zero buffer zeroes both PQ4 codes and aux data.
91
+ const size_t last_pos = ntotal % bbs;
92
+ if (last_pos > 0) {
93
+ const size_t last_block = ntotal / bbs;
94
+ std::vector<uint8_t> zero_code(code_size, 0);
95
+ for (size_t pos = last_pos; pos < bbs; pos++) {
96
+ packer->pack_1(
97
+ zero_code.data(), last_block * bbs + pos, codes.data());
98
+ }
99
+ }
100
+
101
+ codes.resize(new_size);
102
+ }
103
+ return nremove;
70
104
  }
71
105
 
72
- IndexRaBitQFastScan::IndexRaBitQFastScan(const IndexRaBitQ& orig, int bbs)
106
+ IndexRaBitQFastScan::IndexRaBitQFastScan(const IndexRaBitQ& orig, int bbs_in)
73
107
  : rabitq(orig.rabitq) {
74
108
  // RaBitQ-specific validation
75
109
  FAISS_THROW_IF_NOT_MSG(orig.d > 0, "Dimension must be positive");
@@ -89,7 +123,7 @@ IndexRaBitQFastScan::IndexRaBitQFastScan(const IndexRaBitQ& orig, int bbs)
89
123
  M_fastscan,
90
124
  nbits_fastscan,
91
125
  orig.metric_type,
92
- bbs);
126
+ bbs_in);
93
127
 
94
128
  code_size = rabitq.compute_code_size(d, rabitq.nb_bits);
95
129
 
@@ -104,58 +138,59 @@ IndexRaBitQFastScan::IndexRaBitQFastScan(const IndexRaBitQ& orig, int bbs)
104
138
 
105
139
  // If the original index has data, extract factors and pack codes
106
140
  if (ntotal > 0) {
107
- // Compute per-vector storage size for flat storage
108
141
  const size_t storage_size = compute_per_vector_storage_size();
109
-
110
- // Allocate flat storage
111
- flat_storage.resize(ntotal * storage_size);
112
-
113
- // Copy factors directly from original codes
114
142
  const size_t bit_pattern_size = (d + 7) / 8;
115
- for (idx_t i = 0; i < ntotal; i++) {
116
- const uint8_t* orig_code = orig.codes.data() + i * orig.code_size;
117
- const uint8_t* source_factors_ptr = orig_code + bit_pattern_size;
118
- uint8_t* storage = flat_storage.data() + i * storage_size;
119
- memcpy(storage, source_factors_ptr, storage_size);
120
- }
121
143
 
122
144
  // Convert RaBitQ bit format to FastScan 4-bit sub-quantizer format
123
- // This follows the same pattern as IndexPQFastScan constructor
124
145
  AlignedTable<uint8_t> fastscan_codes(ntotal * code_size);
125
146
  memset(fastscan_codes.get(), 0, ntotal * code_size);
126
147
 
127
- // Convert from RaBitQ 1-bit-per-dimension to FastScan
128
- // 4-bit-per-sub-quantizer
129
148
  for (idx_t i = 0; i < ntotal; i++) {
130
149
  const uint8_t* orig_code = orig.codes.data() + i * orig.code_size;
131
150
  uint8_t* fs_code = fastscan_codes.get() + i * code_size;
132
151
 
133
- // Convert each dimension's bit (same logic as compute_codes)
134
- for (size_t j = 0; j < orig.d; j++) {
135
- // Extract bit from original RaBitQ format
152
+ for (size_t j = 0; j < static_cast<size_t>(orig.d); j++) {
136
153
  const size_t orig_byte_idx = j / 8;
137
154
  const size_t orig_bit_offset = j % 8;
138
155
  const bool bit_value =
139
156
  (orig_code[orig_byte_idx] >> orig_bit_offset) & 1;
140
157
 
141
- // Use RaBitQUtils for consistent bit setting
142
158
  if (bit_value) {
143
159
  rabitq_utils::set_bit_fastscan(fs_code, j);
144
160
  }
145
161
  }
146
162
  }
147
163
 
148
- // Pack the converted codes using pq4_pack_codes with custom stride
149
- codes.resize(ntotal2 * M2 / 2);
150
- pq4_pack_codes(
164
+ // Pack the converted codes using enlarged block layout
165
+ const size_t block_stride = get_block_stride();
166
+ const size_t n_blocks = ntotal2 / bbs;
167
+ codes.resize(n_blocks * block_stride);
168
+ memset(codes.get(), 0, n_blocks * block_stride);
169
+ pq4_pack_codes_range(
151
170
  fastscan_codes.get(),
152
- ntotal,
153
171
  M,
154
- ntotal2,
172
+ 0,
173
+ ntotal,
155
174
  bbs,
156
175
  M2,
157
176
  codes.get(),
158
- code_size);
177
+ code_size,
178
+ block_stride);
179
+
180
+ // Copy auxiliary data from original codes into block aux region
181
+ const size_t packed_block_size = ((M2 + 1) / 2) * bbs;
182
+ for (idx_t i = 0; i < ntotal; i++) {
183
+ const uint8_t* src =
184
+ orig.codes.data() + i * orig.code_size + bit_pattern_size;
185
+ uint8_t* dst = rabitq_utils::get_block_aux_ptr(
186
+ codes.get(),
187
+ i,
188
+ bbs,
189
+ packed_block_size,
190
+ block_stride,
191
+ storage_size);
192
+ memcpy(dst, src, storage_size);
193
+ }
159
194
  }
160
195
  }
161
196
 
@@ -163,13 +198,13 @@ void IndexRaBitQFastScan::train(idx_t n, const float* x) {
163
198
  // compute a centroid
164
199
  std::vector<float> centroid(d, 0);
165
200
  for (int64_t i = 0; i < static_cast<int64_t>(n); i++) {
166
- for (size_t j = 0; j < d; j++) {
201
+ for (size_t j = 0; j < static_cast<size_t>(d); j++) {
167
202
  centroid[j] += x[i * d + j];
168
203
  }
169
204
  }
170
205
 
171
206
  if (n != 0) {
172
- for (size_t j = 0; j < d; j++) {
207
+ for (size_t j = 0; j < static_cast<size_t>(d); j++) {
173
208
  centroid[j] /= (float)n;
174
209
  }
175
210
  }
@@ -204,23 +239,13 @@ void IndexRaBitQFastScan::add(idx_t n, const float* x) {
204
239
  compute_codes(tmp_codes.get(), n, x);
205
240
 
206
241
  const size_t storage_size = compute_per_vector_storage_size();
207
- flat_storage.resize((ntotal + n) * storage_size);
208
-
209
- // Populate flat storage (no sign bits copying needed!)
210
242
  const size_t bit_pattern_size = (d + 7) / 8;
211
- for (idx_t i = 0; i < n; i++) {
212
- const uint8_t* code = tmp_codes.get() + i * code_size;
213
- const idx_t vec_idx = ntotal + i;
214
-
215
- // Copy factors data directly to flat storage (no reordering needed)
216
- const uint8_t* source_factors_ptr = code + bit_pattern_size;
217
- uint8_t* storage = flat_storage.data() + vec_idx * storage_size;
218
- memcpy(storage, source_factors_ptr, storage_size);
219
- }
220
243
 
221
- // Resize main storage (same logic as parent)
244
+ // Resize main storage with enlarged block layout
222
245
  ntotal2 = roundup(ntotal + n, bbs);
223
- size_t new_size = ntotal2 * M2 / 2; // assume nbits = 4
246
+ const size_t block_stride = get_block_stride();
247
+ const size_t n_blocks = ntotal2 / bbs;
248
+ size_t new_size = n_blocks * block_stride;
224
249
  size_t old_size = codes.size();
225
250
  if (new_size > old_size) {
226
251
  codes.resize(new_size);
@@ -230,20 +255,36 @@ void IndexRaBitQFastScan::add(idx_t n, const float* x) {
230
255
  // Use our custom packing function with correct stride
231
256
  pq4_pack_codes_range(
232
257
  tmp_codes.get(),
233
- M, // Number of sub-quantizers (bit patterns only)
258
+ M,
234
259
  ntotal,
235
- ntotal + n, // Range to pack
260
+ ntotal + n,
236
261
  bbs,
237
- M2, // Block parameters
238
- codes.get(), // Output
239
- code_size); // CUSTOM STRIDE: includes factor space
262
+ M2,
263
+ codes.get(),
264
+ code_size,
265
+ block_stride);
266
+
267
+ const size_t packed_block_size = ((M2 + 1) / 2) * bbs;
268
+ for (idx_t i = 0; i < n; i++) {
269
+ const uint8_t* src = tmp_codes.get() + i * code_size + bit_pattern_size;
270
+ uint8_t* dst = rabitq_utils::get_block_aux_ptr(
271
+ codes.get(),
272
+ ntotal + i,
273
+ bbs,
274
+ packed_block_size,
275
+ block_stride,
276
+ storage_size);
277
+ memcpy(dst, src, storage_size);
278
+ }
240
279
 
241
280
  ntotal += n;
242
281
  }
243
282
 
244
- void IndexRaBitQFastScan::compute_codes(uint8_t* codes, idx_t n, const float* x)
245
- const {
246
- FAISS_ASSERT(codes != nullptr);
283
+ void IndexRaBitQFastScan::compute_codes(
284
+ uint8_t* out_codes,
285
+ idx_t n,
286
+ const float* x) const {
287
+ FAISS_ASSERT(out_codes != nullptr);
247
288
  FAISS_ASSERT(x != nullptr);
248
289
  FAISS_ASSERT(
249
290
  (metric_type == MetricType::METRIC_L2 ||
@@ -258,23 +299,23 @@ void IndexRaBitQFastScan::compute_codes(uint8_t* codes, idx_t n, const float* x)
258
299
  const size_t ex_bits = rabitq.nb_bits - 1;
259
300
  const size_t ex_code_size = (d * ex_bits + 7) / 8;
260
301
 
261
- memset(codes, 0, n * code_size);
302
+ memset(out_codes, 0, n * code_size);
262
303
 
263
304
  #pragma omp parallel for if (n > 1000)
264
305
  for (int64_t i = 0; i < n; i++) {
265
- uint8_t* const code = codes + i * code_size;
306
+ uint8_t* const code = out_codes + i * code_size;
266
307
  const float* const x_row = x + i * d;
267
308
 
268
309
  // Compute residual once, reuse for both sign bits and ex-bits
269
310
  std::vector<float> residual(d);
270
- for (size_t j = 0; j < d; j++) {
311
+ for (size_t j = 0; j < static_cast<size_t>(d); j++) {
271
312
  const float centroid_val = centroid_data ? centroid_data[j] : 0.0f;
272
313
  residual[j] = x_row[j] - centroid_val;
273
314
  }
274
315
 
275
316
  // Pack sign bits directly into FastScan format using precomputed
276
317
  // residual
277
- for (size_t j = 0; j < d; j++) {
318
+ for (size_t j = 0; j < static_cast<size_t>(d); j++) {
278
319
  if (residual[j] > 0.0f) {
279
320
  rabitq_utils::set_bit_fastscan(code, j);
280
321
  }
@@ -374,7 +415,7 @@ void IndexRaBitQFastScan::compute_float_LUT(
374
415
  for (size_t dim_offset = 0; dim_offset < 4; dim_offset++) {
375
416
  const size_t dim_idx = dim_start + dim_offset;
376
417
 
377
- if (dim_idx < d) {
418
+ if (dim_idx < static_cast<size_t>(d)) {
378
419
  const bool db_bit = (code_val >> dim_offset) & 1;
379
420
  const float query_value = rotated_qq[dim_idx];
380
421
 
@@ -409,7 +450,8 @@ void IndexRaBitQFastScan::compute_float_LUT(
409
450
  for (size_t dim_offset = 0; dim_offset < 4; dim_offset++) {
410
451
  const size_t dim_idx = dim_start + dim_offset;
411
452
 
412
- if (dim_idx < d && ((code_val >> dim_offset) & 1)) {
453
+ if (dim_idx < static_cast<size_t>(d) &&
454
+ ((code_val >> dim_offset) & 1)) {
413
455
  inner_product += rotated_qq[dim_idx];
414
456
  popcount++;
415
457
  }
@@ -425,12 +467,16 @@ void IndexRaBitQFastScan::compute_float_LUT(
425
467
  }
426
468
  }
427
469
 
470
+ size_t IndexRaBitQFastScan::fast_scan_code_size() const {
471
+ return (d + 7) / 8;
472
+ }
473
+
428
474
  void IndexRaBitQFastScan::sa_decode(idx_t n, const uint8_t* bytes, float* x)
429
475
  const {
430
476
  const float* centroid_in =
431
477
  (center.data() == nullptr) ? nullptr : center.data();
432
- const uint8_t* codes = bytes;
433
- FAISS_ASSERT(codes != nullptr);
478
+ const uint8_t* input_codes = bytes;
479
+ FAISS_ASSERT(input_codes != nullptr);
434
480
  FAISS_ASSERT(x != nullptr);
435
481
 
436
482
  const float inv_d_sqrt = (d == 0) ? 1.0f : (1.0f / std::sqrt((float)d));
@@ -439,7 +485,7 @@ void IndexRaBitQFastScan::sa_decode(idx_t n, const uint8_t* bytes, float* x)
439
485
  #pragma omp parallel for if (n > 1000)
440
486
  for (int64_t i = 0; i < n; i++) {
441
487
  // Access code using correct FastScan format
442
- const uint8_t* code = codes + i * code_size;
488
+ const uint8_t* code = input_codes + i * code_size;
443
489
 
444
490
  // Extract factors directly from embedded codes
445
491
  const uint8_t* factors_ptr = code + bit_pattern_size;
@@ -447,7 +493,7 @@ void IndexRaBitQFastScan::sa_decode(idx_t n, const uint8_t* bytes, float* x)
447
493
  reinterpret_cast<const rabitq_utils::SignBitFactors*>(
448
494
  factors_ptr);
449
495
 
450
- for (size_t j = 0; j < d; j++) {
496
+ for (size_t j = 0; j < static_cast<size_t>(d); j++) {
451
497
  // Use RaBitQUtils for consistent bit extraction
452
498
  bool bit_value = rabitq_utils::extract_bit_fastscan(code, j);
453
499
  float bit = bit_value ? 1.0f : 0.0f;
@@ -484,248 +530,20 @@ void IndexRaBitQFastScan::search(
484
530
  }
485
531
  }
486
532
 
487
- // Template implementations for RaBitQHeapHandler
488
- template <class C, bool with_id_map>
489
- RaBitQHeapHandler<C, with_id_map>::RaBitQHeapHandler(
490
- const IndexRaBitQFastScan* index,
491
- size_t nq_val,
492
- size_t k_val,
493
- float* distances,
494
- int64_t* labels,
495
- const IDSelector* sel_in,
496
- const FastScanDistancePostProcessing& ctx,
497
- bool multi_bit)
498
- : RHC(nq_val, index->ntotal, sel_in),
499
- rabitq_index(index),
500
- heap_distances(distances),
501
- heap_labels(labels),
502
- nq(nq_val),
503
- k(k_val),
504
- context(ctx),
505
- is_multi_bit(multi_bit) {
506
- // Initialize heaps for all queries in constructor
507
- // This allows us to support direct normalizer assignment
508
- #pragma omp parallel for if (nq > 100)
509
- for (int64_t q = 0; q < static_cast<int64_t>(nq); q++) {
510
- float* heap_dis = heap_distances + q * k;
511
- int64_t* heap_ids = heap_labels + q * k;
512
- heap_heapify<Cfloat>(k, heap_dis, heap_ids);
513
- }
514
- }
515
-
516
- template <class C, bool with_id_map>
517
- void RaBitQHeapHandler<C, with_id_map>::handle(
518
- size_t q,
519
- size_t b,
520
- simd16uint16 d0,
521
- simd16uint16 d1) {
522
- ALIGNED(32) uint16_t d32tab[32];
523
- d0.store(d32tab);
524
- d1.store(d32tab + 16);
525
-
526
- // Get heap pointers and query factors (computed once per batch)
527
- float* const heap_dis = heap_distances + q * k;
528
- int64_t* const heap_ids = heap_labels + q * k;
529
-
530
- // Access query factors from query_factors pointer
531
- rabitq_utils::QueryFactorsData query_factors_data = {};
532
- if (context.query_factors != nullptr) {
533
- query_factors_data = context.query_factors[q];
534
- }
535
-
536
- // Compute normalizers once per batch
537
- const float one_a = normalizers ? (1.0f / normalizers[2 * q]) : 1.0f;
538
- const float bias = normalizers ? normalizers[2 * q + 1] : 0.0f;
539
-
540
- // Compute loop bounds to avoid redundant bounds checking
541
- const size_t base_db_idx = this->j0 + b * 32;
542
- const size_t max_vectors = (base_db_idx < rabitq_index->ntotal)
543
- ? std::min<size_t>(32, rabitq_index->ntotal - base_db_idx)
544
- : 0;
545
-
546
- // Get storage size once
547
- const size_t storage_size = rabitq_index->compute_per_vector_storage_size();
548
-
549
- // Stats tracking for multi-bit two-stage search only
550
- // n_1bit_evaluations: candidates evaluated using 1-bit lower bound
551
- // n_multibit_evaluations: candidates requiring full multi-bit distance
552
- size_t local_1bit_evaluations = 0;
553
- size_t local_multibit_evaluations = 0;
554
-
555
- // Process distances in batch
556
- for (size_t i = 0; i < max_vectors; i++) {
557
- const size_t db_idx = base_db_idx + i;
558
-
559
- // Normalize distance from LUT lookup
560
- const float normalized_distance = d32tab[i] * one_a + bias;
561
-
562
- // Access factors from flat storage
563
- const uint8_t* base_ptr =
564
- rabitq_index->flat_storage.data() + db_idx * storage_size;
565
-
566
- if (is_multi_bit) {
567
- // Track candidates actually considered for two-stage filtering
568
- local_1bit_evaluations++;
569
-
570
- const SignBitFactorsWithError& full_factors =
571
- *reinterpret_cast<const SignBitFactorsWithError*>(base_ptr);
572
-
573
- float dist_1bit = rabitq_utils::compute_1bit_adjusted_distance(
574
- normalized_distance,
575
- full_factors,
576
- query_factors_data,
577
- rabitq_index->centered,
578
- rabitq_index->qb,
579
- rabitq_index->d);
580
-
581
- float lower_bound = compute_lower_bound(dist_1bit, db_idx, q);
582
-
583
- // Adaptive filtering: decide whether to compute full distance
584
- const bool is_similarity = rabitq_index->metric_type ==
585
- MetricType::METRIC_INNER_PRODUCT;
586
- bool should_refine = is_similarity
587
- ? (lower_bound > heap_dis[0]) // IP: keep if better
588
- : (lower_bound < heap_dis[0]); // L2: keep if better
589
-
590
- if (should_refine) {
591
- local_multibit_evaluations++;
592
- float dist_full = compute_full_multibit_distance(db_idx, q);
593
-
594
- if (Cfloat::cmp(heap_dis[0], dist_full)) {
595
- heap_replace_top<Cfloat>(
596
- k, heap_dis, heap_ids, dist_full, db_idx);
597
- }
598
- }
599
- } else {
600
- const rabitq_utils::SignBitFactors& db_factors =
601
- *reinterpret_cast<const rabitq_utils::SignBitFactors*>(
602
- base_ptr);
603
-
604
- float adjusted_distance =
605
- rabitq_utils::compute_1bit_adjusted_distance(
606
- normalized_distance,
607
- db_factors,
608
- query_factors_data,
609
- rabitq_index->centered,
610
- rabitq_index->qb,
611
- rabitq_index->d);
612
-
613
- // Add to heap if better than current worst
614
- if (Cfloat::cmp(heap_dis[0], adjusted_distance)) {
615
- heap_replace_top<Cfloat>(
616
- k, heap_dis, heap_ids, adjusted_distance, db_idx);
617
- }
618
- }
619
- }
620
-
621
- // Update global stats atomically
622
- #pragma omp atomic
623
- rabitq_stats.n_1bit_evaluations += local_1bit_evaluations;
624
- #pragma omp atomic
625
- rabitq_stats.n_multibit_evaluations += local_multibit_evaluations;
626
- }
627
-
628
- template <class C, bool with_id_map>
629
- void RaBitQHeapHandler<C, with_id_map>::begin(const float* norms) {
630
- normalizers = norms;
631
- // Heap initialization is now done in constructor
632
- }
633
-
634
- template <class C, bool with_id_map>
635
- void RaBitQHeapHandler<C, with_id_map>::end() {
636
- // Reorder final results
637
- #pragma omp parallel for if (nq > 100)
638
- for (int64_t q = 0; q < static_cast<int64_t>(nq); q++) {
639
- float* heap_dis = heap_distances + q * k;
640
- int64_t* heap_ids = heap_labels + q * k;
641
- heap_reorder<Cfloat>(k, heap_dis, heap_ids);
642
- }
643
- }
644
-
645
- template <class C, bool with_id_map>
646
- float RaBitQHeapHandler<C, with_id_map>::compute_lower_bound(
647
- float dist_1bit,
648
- size_t db_idx,
649
- size_t q) const {
650
- // Access f_error directly from SignBitFactorsWithError in flat storage
651
- const size_t storage_size = rabitq_index->compute_per_vector_storage_size();
652
- const uint8_t* base_ptr =
653
- rabitq_index->flat_storage.data() + db_idx * storage_size;
654
- const SignBitFactorsWithError& db_factors =
655
- *reinterpret_cast<const SignBitFactorsWithError*>(base_ptr);
656
- float f_error = db_factors.f_error;
657
-
658
- // Get g_error from query factors (query-dependent error term)
659
- float g_error = 0.0f;
660
- if (context.query_factors != nullptr) {
661
- g_error = context.query_factors[q].g_error;
662
- }
663
-
664
- // Compute error adjustment: f_error * g_error
665
- float error_adjustment = f_error * g_error;
666
-
667
- return dist_1bit - error_adjustment;
668
- }
669
-
670
- template <class C, bool with_id_map>
671
- float RaBitQHeapHandler<C, with_id_map>::compute_full_multibit_distance(
672
- size_t db_idx,
673
- size_t q) const {
674
- const size_t ex_bits = rabitq_index->rabitq.nb_bits - 1;
675
- const size_t dim = rabitq_index->d;
676
-
677
- const size_t storage_size = rabitq_index->compute_per_vector_storage_size();
678
- const uint8_t* base_ptr =
679
- rabitq_index->flat_storage.data() + db_idx * storage_size;
680
-
681
- const size_t ex_code_size = (dim * ex_bits + 7) / 8;
682
- const uint8_t* ex_code = base_ptr + sizeof(SignBitFactorsWithError);
683
- const ExtraBitsFactors& ex_fac = *reinterpret_cast<const ExtraBitsFactors*>(
684
- base_ptr + sizeof(SignBitFactorsWithError) + ex_code_size);
685
-
686
- // Get query factors reference (avoid copying)
687
- const rabitq_utils::QueryFactorsData& query_factors =
688
- context.query_factors[q];
689
-
690
- // Get sign bits from FastScan packed format
691
- std::vector<uint8_t> unpacked_code(rabitq_index->code_size);
692
- CodePackerPQ4 packer(rabitq_index->M2, rabitq_index->bbs);
693
- packer.unpack_1(rabitq_index->codes.get(), db_idx, unpacked_code.data());
694
- const uint8_t* sign_bits = unpacked_code.data();
695
-
696
- return rabitq_utils::compute_full_multibit_distance(
697
- sign_bits,
698
- ex_code,
699
- ex_fac,
700
- query_factors.rotated_q.data(),
701
- query_factors.qr_to_c_L2sqr,
702
- query_factors.qr_norm_L2sqr,
703
- dim,
704
- ex_bits,
705
- rabitq_index->metric_type);
706
- }
533
+ std::unique_ptr<FastScanCodeScanner> IndexRaBitQFastScan::make_knn_scanner(
707
534
 
708
- // Implementation of virtual make_knn_handler method
709
- SIMDResultHandlerToFloat* IndexRaBitQFastScan::make_knn_handler(
710
535
  bool is_max,
711
- int /*impl*/,
712
536
  idx_t n,
713
537
  idx_t k,
714
538
  size_t /*ntotal*/,
715
539
  float* distances,
716
540
  idx_t* labels,
717
541
  const IDSelector* sel,
542
+ int /*impl*/,
718
543
  const FastScanDistancePostProcessing& context) const {
719
- // Use runtime boolean for multi-bit mode
720
- const bool multi_bit = rabitq.nb_bits > 1;
721
-
722
- if (is_max) {
723
- return new RaBitQHeapHandler<CMax<uint16_t, int>, false>(
724
- this, n, k, distances, labels, sel, context, multi_bit);
725
- } else {
726
- return new RaBitQHeapHandler<CMin<uint16_t, int>, false>(
727
- this, n, k, distances, labels, sel, context, multi_bit);
728
- }
544
+ const bool is_multi_bit = rabitq.nb_bits > 1;
545
+ return rabitq_make_knn_scanner(
546
+ this, is_max, n, k, distances, labels, sel, context, is_multi_bit);
729
547
  }
730
548
 
731
549
  } // namespace faiss