faiss 0.5.3 → 0.6.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (379)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +12 -0
  3. data/ext/faiss/ext.cpp +1 -1
  4. data/ext/faiss/extconf.rb +4 -4
  5. data/ext/faiss/index.cpp +63 -45
  6. data/ext/faiss/index_binary.cpp +37 -27
  7. data/ext/faiss/kmeans.cpp +9 -8
  8. data/ext/faiss/pca_matrix.cpp +9 -7
  9. data/ext/faiss/product_quantizer.cpp +13 -11
  10. data/ext/faiss/utils.cpp +4 -2
  11. data/ext/faiss/utils.h +4 -0
  12. data/lib/faiss/version.rb +1 -1
  13. data/lib/faiss.rb +1 -1
  14. data/vendor/faiss/faiss/AutoTune.cpp +214 -82
  15. data/vendor/faiss/faiss/AutoTune.h +14 -1
  16. data/vendor/faiss/faiss/Clustering.cpp +97 -249
  17. data/vendor/faiss/faiss/Clustering.h +18 -0
  18. data/vendor/faiss/faiss/IVFlib.cpp +67 -44
  19. data/vendor/faiss/faiss/Index.cpp +25 -12
  20. data/vendor/faiss/faiss/Index.h +26 -4
  21. data/vendor/faiss/faiss/Index2Layer.cpp +37 -53
  22. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +68 -61
  23. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +36 -34
  24. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +4 -1
  25. data/vendor/faiss/faiss/IndexBinary.cpp +6 -3
  26. data/vendor/faiss/faiss/IndexBinary.h +4 -4
  27. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +1 -1
  28. data/vendor/faiss/faiss/IndexBinaryFlat.h +1 -1
  29. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -4
  30. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +92 -95
  31. data/vendor/faiss/faiss/IndexBinaryHNSW.h +9 -3
  32. data/vendor/faiss/faiss/IndexBinaryHash.cpp +45 -236
  33. data/vendor/faiss/faiss/IndexBinaryHash.h +6 -6
  34. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +120 -414
  35. data/vendor/faiss/faiss/IndexFastScan.cpp +105 -129
  36. data/vendor/faiss/faiss/IndexFastScan.h +35 -24
  37. data/vendor/faiss/faiss/IndexFlat.cpp +216 -152
  38. data/vendor/faiss/faiss/IndexFlat.h +32 -14
  39. data/vendor/faiss/faiss/IndexFlatCodes.cpp +88 -41
  40. data/vendor/faiss/faiss/IndexFlatCodes.h +7 -1
  41. data/vendor/faiss/faiss/IndexHNSW.cpp +299 -187
  42. data/vendor/faiss/faiss/IndexHNSW.h +30 -14
  43. data/vendor/faiss/faiss/IndexIDMap.cpp +26 -22
  44. data/vendor/faiss/faiss/IndexIDMap.h +9 -7
  45. data/vendor/faiss/faiss/IndexIVF.cpp +535 -405
  46. data/vendor/faiss/faiss/IndexIVF.h +47 -16
  47. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +77 -74
  48. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +105 -99
  49. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +6 -3
  50. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +379 -249
  51. data/vendor/faiss/faiss/IndexIVFFastScan.h +65 -60
  52. data/vendor/faiss/faiss/IndexIVFFlat.cpp +41 -124
  53. data/vendor/faiss/faiss/IndexIVFFlat.h +32 -0
  54. data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +89 -138
  55. data/vendor/faiss/faiss/IndexIVFFlatPanorama.h +3 -1
  56. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +18 -15
  57. data/vendor/faiss/faiss/IndexIVFPQ.cpp +77 -907
  58. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +184 -122
  59. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +3 -0
  60. data/vendor/faiss/faiss/IndexIVFPQR.cpp +23 -18
  61. data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +59 -60
  62. data/vendor/faiss/faiss/IndexIVFRaBitQ.h +4 -3
  63. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +564 -416
  64. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +269 -111
  65. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +41 -127
  66. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +1 -1
  67. data/vendor/faiss/faiss/IndexLSH.cpp +44 -25
  68. data/vendor/faiss/faiss/IndexLattice.cpp +41 -36
  69. data/vendor/faiss/faiss/IndexNNDescent.cpp +37 -21
  70. data/vendor/faiss/faiss/IndexNNDescent.h +2 -2
  71. data/vendor/faiss/faiss/IndexNSG.cpp +40 -23
  72. data/vendor/faiss/faiss/IndexNSG.h +0 -2
  73. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +32 -12
  74. data/vendor/faiss/faiss/IndexPQ.cpp +129 -213
  75. data/vendor/faiss/faiss/IndexPQ.h +3 -2
  76. data/vendor/faiss/faiss/IndexPQFastScan.cpp +20 -14
  77. data/vendor/faiss/faiss/IndexPQFastScan.h +3 -0
  78. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -18
  79. data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
  80. data/vendor/faiss/faiss/IndexRaBitQ.cpp +31 -43
  81. data/vendor/faiss/faiss/IndexRaBitQ.h +4 -3
  82. data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +135 -317
  83. data/vendor/faiss/faiss/IndexRaBitQFastScan.h +192 -34
  84. data/vendor/faiss/faiss/IndexRefine.cpp +30 -55
  85. data/vendor/faiss/faiss/IndexRefine.h +4 -4
  86. data/vendor/faiss/faiss/IndexReplicas.cpp +6 -6
  87. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +15 -14
  88. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +1 -1
  89. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +82 -14
  90. data/vendor/faiss/faiss/IndexShards.cpp +13 -13
  91. data/vendor/faiss/faiss/IndexShardsIVF.cpp +21 -15
  92. data/vendor/faiss/faiss/MatrixStats.cpp +5 -4
  93. data/vendor/faiss/faiss/MetaIndexes.cpp +19 -17
  94. data/vendor/faiss/faiss/MetaIndexes.h +1 -1
  95. data/vendor/faiss/faiss/MetricType.h +29 -6
  96. data/vendor/faiss/faiss/SuperKMeans.cpp +656 -0
  97. data/vendor/faiss/faiss/SuperKMeans.h +97 -0
  98. data/vendor/faiss/faiss/VectorTransform.cpp +349 -141
  99. data/vendor/faiss/faiss/VectorTransform.h +39 -16
  100. data/vendor/faiss/faiss/build.cpp +23 -0
  101. data/vendor/faiss/faiss/build.h +15 -0
  102. data/vendor/faiss/faiss/clone_index.cpp +55 -51
  103. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +47 -47
  104. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +11 -0
  105. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +38 -38
  106. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +11 -0
  107. data/vendor/faiss/faiss/{cppcontrib/factory_tools.cpp → factory_tools.cpp} +6 -1
  108. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +1 -1
  109. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +6 -5
  110. data/vendor/faiss/faiss/gpu/GpuResources.h +1 -1
  111. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +9 -9
  112. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +4 -3
  113. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +46 -0
  114. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +56 -0
  115. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +78 -1
  116. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +72 -0
  117. data/vendor/faiss/faiss/gpu/test/TestUtils.h +23 -0
  118. data/vendor/faiss/faiss/gpu/utils/CuvsFilterConvert.h +1 -1
  119. data/vendor/faiss/faiss/gpu/utils/CuvsUtils.h +21 -10
  120. data/vendor/faiss/faiss/gpu_metal/GpuIndexFlat.h +22 -0
  121. data/vendor/faiss/faiss/gpu_metal/MetalCloner.h +35 -0
  122. data/vendor/faiss/faiss/gpu_metal/MetalFlatKernels.h +40 -0
  123. data/vendor/faiss/faiss/gpu_metal/MetalIndex.h +51 -0
  124. data/vendor/faiss/faiss/gpu_metal/MetalIndexFlat.h +65 -0
  125. data/vendor/faiss/faiss/gpu_metal/MetalKernels.h +66 -0
  126. data/vendor/faiss/faiss/gpu_metal/MetalResources.h +79 -0
  127. data/vendor/faiss/faiss/gpu_metal/StandardMetalResources.h +35 -0
  128. data/vendor/faiss/faiss/impl/AdSampling.cpp +103 -0
  129. data/vendor/faiss/faiss/impl/AdSampling.h +35 -0
  130. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +64 -34
  131. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +1 -0
  132. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +10 -9
  133. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +3 -28
  134. data/vendor/faiss/faiss/impl/ClusteringHelpers.cpp +244 -0
  135. data/vendor/faiss/faiss/impl/ClusteringHelpers.h +94 -0
  136. data/vendor/faiss/faiss/impl/ClusteringInitialization.cpp +367 -0
  137. data/vendor/faiss/faiss/impl/ClusteringInitialization.h +107 -0
  138. data/vendor/faiss/faiss/impl/CodePacker.cpp +7 -3
  139. data/vendor/faiss/faiss/impl/CodePacker.h +11 -3
  140. data/vendor/faiss/faiss/impl/CodePackerRaBitQ.cpp +83 -0
  141. data/vendor/faiss/faiss/impl/CodePackerRaBitQ.h +47 -0
  142. data/vendor/faiss/faiss/impl/DistanceComputer.h +8 -8
  143. data/vendor/faiss/faiss/impl/FaissAssert.h +64 -3
  144. data/vendor/faiss/faiss/impl/FaissException.h +50 -3
  145. data/vendor/faiss/faiss/impl/HNSW.cpp +117 -351
  146. data/vendor/faiss/faiss/impl/HNSW.h +21 -40
  147. data/vendor/faiss/faiss/impl/IDSelector.cpp +15 -11
  148. data/vendor/faiss/faiss/impl/IDSelector.h +8 -8
  149. data/vendor/faiss/faiss/impl/InvertedListScannerStats.h +26 -0
  150. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +114 -102
  151. data/vendor/faiss/faiss/impl/NNDescent.cpp +63 -26
  152. data/vendor/faiss/faiss/impl/NNDescent.h +6 -2
  153. data/vendor/faiss/faiss/impl/NSG.cpp +44 -26
  154. data/vendor/faiss/faiss/impl/NSG.h +20 -10
  155. data/vendor/faiss/faiss/impl/Panorama.cpp +76 -52
  156. data/vendor/faiss/faiss/impl/Panorama.h +265 -78
  157. data/vendor/faiss/faiss/impl/PdxLayout.cpp +93 -0
  158. data/vendor/faiss/faiss/impl/PdxLayout.h +41 -0
  159. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +62 -37
  160. data/vendor/faiss/faiss/impl/PolysemousTraining.h +3 -3
  161. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +35 -35
  162. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +21 -16
  163. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +99 -80
  164. data/vendor/faiss/faiss/impl/Quantizer.h +2 -2
  165. data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +135 -37
  166. data/vendor/faiss/faiss/impl/RaBitQUtils.h +148 -21
  167. data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +298 -301
  168. data/vendor/faiss/faiss/impl/RaBitQuantizer.h +3 -10
  169. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.cpp +15 -41
  170. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.h +0 -4
  171. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +40 -32
  172. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +1 -1
  173. data/vendor/faiss/faiss/impl/ResultHandler.h +218 -113
  174. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +119 -2362
  175. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +27 -3
  176. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +14 -11
  177. data/vendor/faiss/faiss/impl/VisitedTable.cpp +42 -0
  178. data/vendor/faiss/faiss/impl/VisitedTable.h +76 -0
  179. data/vendor/faiss/faiss/impl/approx_topk/approx_topk.h +276 -0
  180. data/vendor/faiss/faiss/impl/approx_topk/avx2.cpp +68 -0
  181. data/vendor/faiss/faiss/{utils → impl}/approx_topk/generic.h +15 -8
  182. data/vendor/faiss/faiss/impl/approx_topk/neon.cpp +68 -0
  183. data/vendor/faiss/faiss/impl/approx_topk/rq_beam_search_tab-inl.h +169 -0
  184. data/vendor/faiss/faiss/impl/approx_topk/rq_beam_search_tab.h +117 -0
  185. data/vendor/faiss/faiss/impl/approx_topk/simdlib256-inl.h +146 -0
  186. data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryHNSW_impl.h +73 -0
  187. data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryHash_impl.h +270 -0
  188. data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryIVF_impl.h +460 -0
  189. data/vendor/faiss/faiss/impl/binary_hamming/IndexIVFSpectralHash_impl.h +159 -0
  190. data/vendor/faiss/faiss/impl/binary_hamming/IndexPQ_impl.h +92 -0
  191. data/vendor/faiss/faiss/impl/binary_hamming/avx2.cpp +26 -0
  192. data/vendor/faiss/faiss/impl/binary_hamming/avx512.cpp +26 -0
  193. data/vendor/faiss/faiss/impl/binary_hamming/dispatch.h +143 -0
  194. data/vendor/faiss/faiss/impl/binary_hamming/neon.cpp +26 -0
  195. data/vendor/faiss/faiss/impl/binary_hamming/rvv.cpp +26 -0
  196. data/vendor/faiss/faiss/impl/expanded_scanners.h +163 -0
  197. data/vendor/faiss/faiss/impl/{FastScanDistancePostProcessing.h → fast_scan/FastScanDistancePostProcessing.h} +13 -6
  198. data/vendor/faiss/faiss/impl/{LookupTableScaler.h → fast_scan/LookupTableScaler.h} +16 -5
  199. data/vendor/faiss/faiss/impl/fast_scan/accumulate_loops.h +237 -0
  200. data/vendor/faiss/faiss/impl/fast_scan/accumulate_loops_512.h +185 -0
  201. data/vendor/faiss/faiss/impl/fast_scan/decompose_qbs.h +229 -0
  202. data/vendor/faiss/faiss/impl/fast_scan/dispatching.h +268 -0
  203. data/vendor/faiss/faiss/impl/{pq4_fast_scan.cpp → fast_scan/fast_scan.cpp} +176 -4
  204. data/vendor/faiss/faiss/impl/fast_scan/fast_scan.h +341 -0
  205. data/vendor/faiss/faiss/impl/fast_scan/impl-avx2.cpp +36 -0
  206. data/vendor/faiss/faiss/impl/fast_scan/impl-avx512.cpp +40 -0
  207. data/vendor/faiss/faiss/impl/fast_scan/impl-neon.cpp +120 -0
  208. data/vendor/faiss/faiss/impl/fast_scan/impl-riscv.cpp +104 -0
  209. data/vendor/faiss/faiss/impl/fast_scan/kernels_simd256.h +213 -0
  210. data/vendor/faiss/faiss/impl/{pq4_fast_scan_search_qbs.cpp → fast_scan/kernels_simd512.h} +26 -348
  211. data/vendor/faiss/faiss/impl/fast_scan/rabitq_dispatching.h +90 -0
  212. data/vendor/faiss/faiss/impl/fast_scan/rabitq_result_handler.h +108 -0
  213. data/vendor/faiss/faiss/impl/{simd_result_handlers.h → fast_scan/simd_result_handlers.h} +290 -142
  214. data/vendor/faiss/faiss/impl/hnsw/LockVector.cpp +54 -0
  215. data/vendor/faiss/faiss/impl/hnsw/LockVector.h +64 -0
  216. data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.cpp +91 -0
  217. data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.h +64 -0
  218. data/vendor/faiss/faiss/impl/hnsw/avx2.cpp +104 -0
  219. data/vendor/faiss/faiss/impl/hnsw/avx512.cpp +111 -0
  220. data/vendor/faiss/faiss/impl/index_read.cpp +1950 -505
  221. data/vendor/faiss/faiss/impl/index_read_utils.h +1 -2
  222. data/vendor/faiss/faiss/impl/index_write.cpp +112 -21
  223. data/vendor/faiss/faiss/impl/io.cpp +6 -6
  224. data/vendor/faiss/faiss/impl/io_macros.h +33 -16
  225. data/vendor/faiss/faiss/impl/kmeans1d.cpp +10 -10
  226. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +81 -40
  227. data/vendor/faiss/faiss/impl/lattice_Zn.h +6 -6
  228. data/vendor/faiss/faiss/impl/mapped_io.cpp +15 -8
  229. data/vendor/faiss/faiss/impl/platform_macros.h +11 -4
  230. data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQScanner_impl.h +549 -0
  231. data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQ_QueryTables.cpp +245 -0
  232. data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQ_QueryTables.h +105 -0
  233. data/vendor/faiss/faiss/impl/pq_code_distance/PQDistanceComputer_impl.h +106 -0
  234. data/vendor/faiss/faiss/impl/pq_code_distance/avx2.cpp +21 -0
  235. data/vendor/faiss/faiss/impl/pq_code_distance/avx512.cpp +21 -0
  236. data/vendor/faiss/faiss/impl/pq_code_distance/neon.cpp +21 -0
  237. data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx2.h → pq_code_distance/pq_code_distance-avx2.h} +43 -220
  238. data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx512.h → pq_code_distance/pq_code_distance-avx512.h} +25 -112
  239. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.cpp +59 -0
  240. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.h +96 -0
  241. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-inl.h +256 -0
  242. data/vendor/faiss/faiss/impl/{code_distance/code_distance-sve.h → pq_code_distance/pq_code_distance-sve.cpp} +57 -146
  243. data/vendor/faiss/faiss/impl/pq_code_distance/rvv.cpp +68 -0
  244. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +320 -483
  245. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +1 -1
  246. data/vendor/faiss/faiss/impl/scalar_quantizer/codecs.h +121 -0
  247. data/vendor/faiss/faiss/impl/scalar_quantizer/distance_computers.h +137 -0
  248. data/vendor/faiss/faiss/impl/scalar_quantizer/quantizers.h +371 -0
  249. data/vendor/faiss/faiss/impl/scalar_quantizer/scanners.h +190 -0
  250. data/vendor/faiss/faiss/impl/scalar_quantizer/similarities.h +94 -0
  251. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx2.cpp +603 -0
  252. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512.cpp +597 -0
  253. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-dispatch.h +388 -0
  254. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-neon.cpp +630 -0
  255. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-rvv.cpp +311 -0
  256. data/vendor/faiss/faiss/impl/scalar_quantizer/training.cpp +387 -0
  257. data/vendor/faiss/faiss/impl/scalar_quantizer/training.h +54 -0
  258. data/vendor/faiss/faiss/impl/simd_dispatch.h +173 -0
  259. data/vendor/faiss/faiss/impl/simdlib/simdlib.h +57 -0
  260. data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_avx2.h +274 -171
  261. data/vendor/faiss/faiss/impl/simdlib/simdlib_avx512.h +414 -0
  262. data/vendor/faiss/faiss/impl/simdlib/simdlib_dispatch.h +44 -0
  263. data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_emulated.h +231 -166
  264. data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_neon.h +275 -217
  265. data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_ppc64.h +201 -160
  266. data/vendor/faiss/faiss/impl/svs_io.cpp +12 -3
  267. data/vendor/faiss/faiss/impl/svs_io.h +8 -2
  268. data/vendor/faiss/faiss/index_factory.cpp +115 -28
  269. data/vendor/faiss/faiss/index_io.h +53 -3
  270. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +73 -20
  271. data/vendor/faiss/faiss/invlists/DirectMap.cpp +24 -14
  272. data/vendor/faiss/faiss/invlists/DirectMap.h +4 -3
  273. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +157 -73
  274. data/vendor/faiss/faiss/invlists/InvertedLists.h +86 -23
  275. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +4 -4
  276. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +14 -14
  277. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
  278. data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +9 -19
  279. data/vendor/faiss/faiss/svs/IndexSVSFlat.cpp +2 -2
  280. data/vendor/faiss/faiss/svs/IndexSVSFlat.h +2 -0
  281. data/vendor/faiss/faiss/svs/IndexSVSIVF.cpp +350 -0
  282. data/vendor/faiss/faiss/svs/IndexSVSIVF.h +128 -0
  283. data/vendor/faiss/faiss/svs/IndexSVSIVFLVQ.cpp +40 -0
  284. data/vendor/faiss/faiss/svs/IndexSVSIVFLVQ.h +43 -0
  285. data/vendor/faiss/faiss/svs/IndexSVSIVFLeanVec.cpp +225 -0
  286. data/vendor/faiss/faiss/svs/IndexSVSIVFLeanVec.h +71 -0
  287. data/vendor/faiss/faiss/svs/IndexSVSVamana.cpp +25 -1
  288. data/vendor/faiss/faiss/svs/IndexSVSVamana.h +19 -2
  289. data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.h +1 -1
  290. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +19 -2
  291. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +14 -0
  292. data/vendor/faiss/faiss/utils/Heap.cpp +56 -10
  293. data/vendor/faiss/faiss/utils/Heap.h +21 -0
  294. data/vendor/faiss/faiss/utils/NeuralNet.cpp +54 -40
  295. data/vendor/faiss/faiss/utils/NeuralNet.h +1 -1
  296. data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +10 -4
  297. data/vendor/faiss/faiss/utils/distances.cpp +507 -559
  298. data/vendor/faiss/faiss/utils/distances.h +118 -1
  299. data/vendor/faiss/faiss/utils/distances_dispatch.h +250 -0
  300. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +8 -7
  301. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +33 -14
  302. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +12 -1
  303. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +16 -293
  304. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based_neon.cpp +57 -0
  305. data/vendor/faiss/faiss/utils/distances_fused/simdlib_kernel-inl.h +290 -0
  306. data/vendor/faiss/faiss/utils/distances_simd.cpp +72 -3681
  307. data/vendor/faiss/faiss/utils/extra_distances.cpp +60 -102
  308. data/vendor/faiss/faiss/utils/extra_distances.h +79 -7
  309. data/vendor/faiss/faiss/utils/hamming-inl.h +13 -11
  310. data/vendor/faiss/faiss/utils/hamming.cpp +66 -517
  311. data/vendor/faiss/faiss/utils/hamming.h +92 -2
  312. data/vendor/faiss/faiss/utils/hamming_distance/common.h +287 -10
  313. data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx2.cpp +15 -0
  314. data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx512.cpp +15 -0
  315. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx2.h +142 -0
  316. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx512.h +234 -0
  317. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-generic.h +368 -0
  318. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-neon.h +322 -0
  319. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-rvv.h +39 -0
  320. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer.h +146 -0
  321. data/vendor/faiss/faiss/utils/hamming_distance/hamming_impl.h +481 -0
  322. data/vendor/faiss/faiss/utils/hamming_distance/hamming_neon.cpp +15 -0
  323. data/vendor/faiss/faiss/utils/hamming_distance/hamming_rvv.cpp +15 -0
  324. data/vendor/faiss/faiss/utils/partitioning.cpp +66 -987
  325. data/vendor/faiss/faiss/utils/partitioning.h +31 -0
  326. data/vendor/faiss/faiss/utils/popcount.h +29 -0
  327. data/vendor/faiss/faiss/utils/pq_code_distance.h +251 -0
  328. data/vendor/faiss/faiss/utils/prefetch.h +2 -2
  329. data/vendor/faiss/faiss/utils/quantize_lut.cpp +30 -30
  330. data/vendor/faiss/faiss/utils/quantize_lut.h +1 -1
  331. data/vendor/faiss/faiss/utils/rabitq_simd.h +124 -343
  332. data/vendor/faiss/faiss/utils/random.cpp +6 -6
  333. data/vendor/faiss/faiss/utils/simd_impl/IVFFlatScanner-inl.h +51 -0
  334. data/vendor/faiss/faiss/utils/simd_impl/distances_aarch64.cpp +154 -0
  335. data/vendor/faiss/faiss/utils/simd_impl/distances_arm_sve.cpp +777 -0
  336. data/vendor/faiss/faiss/utils/simd_impl/distances_autovec-inl.h +306 -0
  337. data/vendor/faiss/faiss/utils/simd_impl/distances_avx2.cpp +1431 -0
  338. data/vendor/faiss/faiss/utils/simd_impl/distances_avx512.cpp +1095 -0
  339. data/vendor/faiss/faiss/utils/simd_impl/distances_rvv.cpp +189 -0
  340. data/vendor/faiss/faiss/utils/simd_impl/distances_simdlib256.h +195 -0
  341. data/vendor/faiss/faiss/utils/simd_impl/distances_sse-inl.h +392 -0
  342. data/vendor/faiss/faiss/utils/{distances_fused/simdlib_based.h → simd_impl/exhaustive_L2sqr_blas_cmax.h} +5 -10
  343. data/vendor/faiss/faiss/utils/simd_impl/hamming_impl.h +481 -0
  344. data/vendor/faiss/faiss/utils/simd_impl/partitioning_avx2.cpp +14 -0
  345. data/vendor/faiss/faiss/utils/simd_impl/partitioning_neon.cpp +14 -0
  346. data/vendor/faiss/faiss/utils/simd_impl/partitioning_simdlib256.h +1085 -0
  347. data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx2.cpp +355 -0
  348. data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx512.cpp +477 -0
  349. data/vendor/faiss/faiss/utils/simd_impl/rabitq_neon.cpp +55 -0
  350. data/vendor/faiss/faiss/utils/simd_impl/rabitq_rvv.cpp +55 -0
  351. data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_dispatch.h +32 -0
  352. data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels.h +43 -0
  353. data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels_avx2.cpp +57 -0
  354. data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels_avx512.cpp +45 -0
  355. data/vendor/faiss/faiss/utils/simd_levels.cpp +334 -0
  356. data/vendor/faiss/faiss/utils/simd_levels.h +183 -0
  357. data/vendor/faiss/faiss/utils/sorting.cpp +48 -36
  358. data/vendor/faiss/faiss/utils/utils.cpp +21 -14
  359. data/vendor/faiss/faiss/utils/utils.h +3 -3
  360. metadata +156 -42
  361. data/vendor/faiss/faiss/impl/RaBitQStats.cpp +0 -29
  362. data/vendor/faiss/faiss/impl/RaBitQStats.h +0 -56
  363. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +0 -81
  364. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +0 -186
  365. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +0 -216
  366. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +0 -224
  367. data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +0 -84
  368. data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +0 -196
  369. data/vendor/faiss/faiss/utils/approx_topk/mode.h +0 -34
  370. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +0 -36
  371. data/vendor/faiss/faiss/utils/extra_distances-inl.h +0 -228
  372. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +0 -462
  373. data/vendor/faiss/faiss/utils/hamming_distance/avx512-inl.h +0 -490
  374. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +0 -450
  375. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +0 -87
  376. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +0 -524
  377. data/vendor/faiss/faiss/utils/simdlib.h +0 -42
  378. data/vendor/faiss/faiss/utils/simdlib_avx512.h +0 -296
  379. data/vendor/faiss/faiss/{cppcontrib/factory_tools.h → factory_tools.h} +0 -0
@@ -9,7 +9,6 @@
9
9
 
10
10
  #include <faiss/IndexIVFPQ.h>
11
11
 
12
- #include <cassert>
13
12
  #include <cinttypes>
14
13
  #include <cmath>
15
14
  #include <cstdint>
@@ -17,22 +16,26 @@
17
16
 
18
17
  #include <algorithm>
19
18
 
20
- #include <faiss/utils/Heap.h>
21
- #include <faiss/utils/distances.h>
19
+ #include <faiss/utils/distances_dispatch.h>
22
20
  #include <faiss/utils/utils.h>
23
21
 
24
22
  #include <faiss/Clustering.h>
25
23
 
26
24
  #include <faiss/utils/hamming.h>
27
25
 
28
- #include <faiss/impl/FaissAssert.h>
29
-
30
26
  #include <faiss/impl/AuxIndexStructures.h>
27
+ #include <faiss/impl/FaissAssert.h>
31
28
  #include <faiss/impl/IDSelector.h>
32
-
33
29
  #include <faiss/impl/ProductQuantizer.h>
30
+ #include <faiss/impl/ResultHandler.h>
31
+ #include <faiss/impl/pq_code_distance/pq_code_distance-generic.h>
32
+ #include <faiss/impl/simd_dispatch.h>
34
33
 
35
- #include <faiss/impl/code_distance/code_distance.h>
34
+ // Scalar (NONE) fallback for dynamic dispatch
35
+ #define THE_SIMD_LEVEL SIMDLevel::NONE
36
+ // NOLINTNEXTLINE(facebook-hte-InlineHeader)
37
+ #include <faiss/impl/pq_code_distance/IVFPQScanner_impl.h>
38
+ #undef THE_SIMD_LEVEL
36
39
 
37
40
  namespace faiss {
38
41
 
@@ -41,17 +44,17 @@ namespace faiss {
41
44
  ******************************************/
42
45
 
43
46
  IndexIVFPQ::IndexIVFPQ(
44
- Index* quantizer,
45
- size_t d,
46
- size_t nlist,
47
+ Index* quantizer_in,
48
+ size_t d_in,
49
+ size_t nlist_in,
47
50
  size_t M,
48
51
  size_t nbits_per_idx,
49
52
  MetricType metric,
50
- bool own_invlists)
51
- : IndexIVF(quantizer, d, nlist, 0, metric, own_invlists),
52
- pq(d, M, nbits_per_idx) {
53
+ bool own_invlists_in)
54
+ : IndexIVF(quantizer_in, d_in, nlist_in, 0, metric, own_invlists_in),
55
+ pq(d_in, M, nbits_per_idx) {
53
56
  code_size = pq.code_size;
54
- if (own_invlists) {
57
+ if (own_invlists_in) {
55
58
  invlists->code_size = code_size;
56
59
  }
57
60
  is_trained = false;
@@ -67,12 +70,16 @@ IndexIVFPQ::IndexIVFPQ(
67
70
  /****************************************************************
68
71
  * training */
69
72
 
70
- void IndexIVFPQ::train_encoder(idx_t n, const float* x, const idx_t* assign) {
73
+ void IndexIVFPQ::train_encoder(
74
+ idx_t n,
75
+ const float* x,
76
+ const idx_t* /*assign*/) {
71
77
  pq.train(n, x);
72
78
 
73
79
  if (do_polysemous_training) {
74
- if (verbose)
80
+ if (verbose) {
75
81
  printf("doing polysemous training for PQ\n");
82
+ }
76
83
  PolysemousTraining default_pt;
77
84
  PolysemousTraining* pt =
78
85
  polysemous_training ? polysemous_training : &default_pt;
@@ -97,8 +104,9 @@ void IndexIVFPQ::encode(idx_t key, const float* x, uint8_t* code) const {
97
104
  std::vector<float> residual_vec(d);
98
105
  quantizer->compute_residual(x, residual_vec.data(), key);
99
106
  pq.compute_code(residual_vec.data(), code);
100
- } else
107
+ } else {
101
108
  pq.compute_code(x, code);
109
+ }
102
110
  }
103
111
 
104
112
  void IndexIVFPQ::encode_multiple(
@@ -107,8 +115,9 @@ void IndexIVFPQ::encode_multiple(
107
115
  const float* x,
108
116
  uint8_t* xcodes,
109
117
  bool compute_keys) const {
110
- if (compute_keys)
118
+ if (compute_keys) {
111
119
  quantizer->assign(n, x, keys);
120
+ }
112
121
 
113
122
  encode_vectors(n, x, keys, xcodes);
114
123
  }
@@ -124,7 +133,7 @@ void IndexIVFPQ::decode_multiple(
124
133
  for (size_t i = 0; i < n; i++) {
125
134
  quantizer->reconstruct(keys[i], centroid.data());
126
135
  float* xi = x + i * d;
127
- for (size_t j = 0; j < d; j++) {
136
+ for (int j = 0; j < d; j++) {
128
137
  xi[j] += centroid[j];
129
138
  }
130
139
  }
@@ -150,13 +159,15 @@ static std::unique_ptr<float[]> compute_residuals(
150
159
  const idx_t* list_nos) {
151
160
  size_t d = quantizer->d;
152
161
  std::unique_ptr<float[]> residuals(new float[n * d]);
153
- // TODO: parallelize?
154
- for (size_t i = 0; i < n; i++) {
155
- if (list_nos[i] < 0)
162
+ // Parallelize with OpenMP (each iteration is independent)
163
+ #pragma omp parallel for if (n > 1000)
164
+ for (idx_t i = 0; i < n; i++) {
165
+ if (list_nos[i] < 0) {
156
166
  memset(residuals.get() + i * d, 0, sizeof(float) * d);
157
- else
167
+ } else {
158
168
  quantizer->compute_residual(
159
169
  x + i * d, residuals.get() + i * d, list_nos[i]);
170
+ }
160
171
  }
161
172
  return residuals;
162
173
  }
@@ -208,7 +219,7 @@ void IndexIVFPQ::sa_decode(idx_t n, const uint8_t* codes, float* x) const {
208
219
  pq.decode(code + coarse_size, xi);
209
220
  if (by_residual) {
210
221
  quantizer->reconstruct(list_no, residual.data());
211
- for (size_t j = 0; j < d; j++) {
222
+ for (int j = 0; j < d; j++) {
212
223
  xi[j] += residual[j];
213
224
  }
214
225
  }
@@ -283,14 +294,15 @@ void IndexIVFPQ::add_core_o(
283
294
  double t2 = getmillisecs();
284
295
  // TODO: parallelize?
285
296
  size_t n_ignore = 0;
286
- for (size_t i = 0; i < n; i++) {
297
+ for (idx_t i = 0; i < n; i++) {
287
298
  idx_t key = idx[i];
288
299
  idx_t id = xids ? xids[i] : ntotal + i;
289
300
  if (key < 0) {
290
301
  direct_map.add_single_id(id, -1, 0);
291
302
  n_ignore++;
292
- if (residuals_2)
303
+ if (residuals_2) {
293
304
  memset(residuals_2, 0, sizeof(*residuals_2) * d);
305
+ }
294
306
  continue;
295
307
  }
296
308
 
@@ -302,8 +314,9 @@ void IndexIVFPQ::add_core_o(
302
314
  float* res2 = residuals_2 + i * d;
303
315
  const float* xi = to_encode + i * d;
304
316
  pq.decode(code, res2);
305
- for (int j = 0; j < d; j++)
317
+ for (int j = 0; j < d; j++) {
306
318
  res2[j] = xi[j] - res2[j];
319
+ }
307
320
  }
308
321
 
309
322
  direct_map.add_single_id(id, key, offset);
@@ -312,8 +325,9 @@ void IndexIVFPQ::add_core_o(
312
325
  double t3 = getmillisecs();
313
326
  if (verbose) {
314
327
  char comment[100] = {0};
315
- if (n_ignore > 0)
328
+ if (n_ignore > 0) {
316
329
  snprintf(comment, 100, "(%zd vectors ignored)", n_ignore);
330
+ }
317
331
  printf(" add_core times: %.3f %.3f %.3f %s\n",
318
332
  t1 - t0,
319
333
  t2 - t1,
@@ -380,6 +394,7 @@ void initialize_IVFPQ_precomputed_table(
380
394
  AlignedTable<float>& precomputed_table,
381
395
  bool by_residual,
382
396
  bool verbose) {
397
+ FAISS_THROW_IF_NOT_MSG(quantizer, "IVF quantizer must not be null");
383
398
  size_t nlist = quantizer->ntotal;
384
399
  size_t d = quantizer->d;
385
400
  FAISS_THROW_IF_NOT(d == pq.d);
@@ -389,6 +404,9 @@ void initialize_IVFPQ_precomputed_table(
389
404
  return;
390
405
  }
391
406
 
407
+ const size_t m_ksub =
408
+ mul_no_overflow(pq.M, pq.ksub, "IVFPQ precomputed_table");
409
+
392
410
  if (use_precomputed_table == 0) { // then choose the type of table
393
411
  if (!(quantizer->metric_type == METRIC_L2 && by_residual)) {
394
412
  if (verbose) {
@@ -400,10 +418,13 @@ void initialize_IVFPQ_precomputed_table(
400
418
  }
401
419
  const MultiIndexQuantizer* miq =
402
420
  dynamic_cast<const MultiIndexQuantizer*>(quantizer);
403
- if (miq && pq.M % miq->pq.M == 0)
421
+ if (miq && pq.M % miq->pq.M == 0) {
404
422
  use_precomputed_table = 2;
405
- else {
406
- size_t table_size = pq.M * pq.ksub * nlist * sizeof(float);
423
+ } else {
424
+ size_t table_size = mul_no_overflow(
425
+ mul_no_overflow(m_ksub, nlist, "IVFPQ precomputed_table"),
426
+ sizeof(float),
427
+ "IVFPQ precomputed_table");
407
428
  if (table_size > precomputed_table_max_bytes) {
408
429
  if (verbose) {
409
430
  printf("IndexIVFPQ::precompute_table: not precomputing table, "
@@ -423,22 +444,25 @@ void initialize_IVFPQ_precomputed_table(
423
444
  }
424
445
 
425
446
  // squared norms of the PQ centroids
426
- std::vector<float> r_norms(pq.M * pq.ksub, NAN);
427
- for (int m = 0; m < pq.M; m++)
428
- for (int j = 0; j < pq.ksub; j++)
447
+ std::vector<float> r_norms(m_ksub, NAN);
448
+ for (size_t m = 0; m < pq.M; m++) {
449
+ for (size_t j = 0; j < pq.ksub; j++) {
429
450
  r_norms[m * pq.ksub + j] =
430
- fvec_norm_L2sqr(pq.get_centroids(m, j), pq.dsub);
451
+ fvec_norm_L2sqr_dispatch(pq.get_centroids(m, j), pq.dsub);
452
+ }
453
+ }
431
454
 
432
455
  if (use_precomputed_table == 1) {
433
- precomputed_table.resize(nlist * pq.M * pq.ksub);
456
+ precomputed_table.resize(
457
+ mul_no_overflow(nlist, m_ksub, "IVFPQ precomputed_table"));
434
458
  std::vector<float> centroid(d);
435
459
 
436
460
  for (size_t i = 0; i < nlist; i++) {
437
461
  quantizer->reconstruct(i, centroid.data());
438
462
 
439
- float* tab = &precomputed_table[i * pq.M * pq.ksub];
463
+ float* tab = &precomputed_table[i * m_ksub];
440
464
  pq.compute_inner_prod_table(centroid.data(), tab);
441
- fvec_madd(pq.M * pq.ksub, r_norms.data(), 2.0, tab, tab);
465
+ fvec_madd_dispatch(m_ksub, r_norms.data(), 2.0, tab, tab);
442
466
  }
443
467
  } else if (use_precomputed_table == 2) {
444
468
  const MultiIndexQuantizer* miq =
@@ -447,12 +471,13 @@ void initialize_IVFPQ_precomputed_table(
447
471
  const ProductQuantizer& cpq = miq->pq;
448
472
  FAISS_THROW_IF_NOT(pq.M % cpq.M == 0);
449
473
 
450
- precomputed_table.resize(cpq.ksub * pq.M * pq.ksub);
474
+ precomputed_table.resize(
475
+ mul_no_overflow(cpq.ksub, m_ksub, "IVFPQ precomputed_table"));
451
476
 
452
477
  // reorder PQ centroid table
453
478
  std::vector<float> centroids(d * cpq.ksub, NAN);
454
479
 
455
- for (int m = 0; m < cpq.M; m++) {
480
+ for (size_t m = 0; m < cpq.M; m++) {
456
481
  for (size_t i = 0; i < cpq.ksub; i++) {
457
482
  memcpy(centroids.data() + i * d + m * cpq.dsub,
458
483
  cpq.get_centroids(m, i),
@@ -464,8 +489,8 @@ void initialize_IVFPQ_precomputed_table(
464
489
  cpq.ksub, centroids.data(), precomputed_table.data());
465
490
 
466
491
  for (size_t i = 0; i < cpq.ksub; i++) {
467
- float* tab = &precomputed_table[i * pq.M * pq.ksub];
468
- fvec_madd(pq.M * pq.ksub, r_norms.data(), 2.0, tab, tab);
492
+ float* tab = &precomputed_table[i * m_ksub];
493
+ fvec_madd_dispatch(m_ksub, r_norms.data(), 2.0, tab, tab);
469
494
  }
470
495
  }
471
496
  }
@@ -480,870 +505,14 @@ void IndexIVFPQ::precompute_table() {
480
505
  verbose);
481
506
  }
482
507
 
483
- namespace {
484
-
485
- #define TIC t0 = get_cycles()
486
- #define TOC get_cycles() - t0
487
-
488
- /** QueryTables manages the various ways of searching an
489
- * IndexIVFPQ. The code contains a lot of branches, depending on:
490
- * - metric_type: are we computing L2 or Inner product similarity?
491
- * - by_residual: do we encode raw vectors or residuals?
492
- * - use_precomputed_table: are x_R|x_C tables precomputed?
493
- * - polysemous_ht: are we filtering with polysemous codes?
494
- */
495
- struct QueryTables {
496
- /*****************************************************
497
- * General data from the IVFPQ
498
- *****************************************************/
499
-
500
- const IndexIVFPQ& ivfpq;
501
- const IVFSearchParameters* params;
502
-
503
- // copied from IndexIVFPQ for easier access
504
- int d;
505
- const ProductQuantizer& pq;
506
- MetricType metric_type;
507
- bool by_residual;
508
- int use_precomputed_table;
509
- int polysemous_ht;
510
-
511
- // pre-allocated data buffers
512
- float *sim_table, *sim_table_2;
513
- float *residual_vec, *decoded_vec;
514
-
515
- // single data buffer
516
- std::vector<float> mem;
517
-
518
- // for table pointers
519
- std::vector<const float*> sim_table_ptrs;
520
-
521
- explicit QueryTables(
522
- const IndexIVFPQ& ivfpq,
523
- const IVFSearchParameters* params)
524
- : ivfpq(ivfpq),
525
- d(ivfpq.d),
526
- pq(ivfpq.pq),
527
- metric_type(ivfpq.metric_type),
528
- by_residual(ivfpq.by_residual),
529
- use_precomputed_table(ivfpq.use_precomputed_table) {
530
- mem.resize(pq.ksub * pq.M * 2 + d * 2);
531
- sim_table = mem.data();
532
- sim_table_2 = sim_table + pq.ksub * pq.M;
533
- residual_vec = sim_table_2 + pq.ksub * pq.M;
534
- decoded_vec = residual_vec + d;
535
-
536
- // for polysemous
537
- polysemous_ht = ivfpq.polysemous_ht;
538
- if (auto ivfpq_params =
539
- dynamic_cast<const IVFPQSearchParameters*>(params)) {
540
- polysemous_ht = ivfpq_params->polysemous_ht;
541
- }
542
- if (polysemous_ht != 0) {
543
- q_code.resize(pq.code_size);
544
- }
545
- init_list_cycles = 0;
546
- sim_table_ptrs.resize(pq.M);
547
- }
548
-
549
- /*****************************************************
550
- * What we do when query is known
551
- *****************************************************/
552
-
553
- // field specific to query
554
- const float* qi;
555
-
556
- // query-specific initialization
557
- void init_query(const float* qi) {
558
- this->qi = qi;
559
- if (metric_type == METRIC_INNER_PRODUCT)
560
- init_query_IP();
561
- else
562
- init_query_L2();
563
- if (!by_residual && polysemous_ht != 0)
564
- pq.compute_code(qi, q_code.data());
565
- }
566
-
567
- void init_query_IP() {
568
- // precompute some tables specific to the query qi
569
- pq.compute_inner_prod_table(qi, sim_table);
570
- }
571
-
572
- void init_query_L2() {
573
- if (!by_residual) {
574
- pq.compute_distance_table(qi, sim_table);
575
- } else if (use_precomputed_table) {
576
- pq.compute_inner_prod_table(qi, sim_table_2);
577
- }
578
- }
579
-
580
- /*****************************************************
581
- * When inverted list is known: prepare computations
582
- *****************************************************/
583
-
584
- // fields specific to list
585
- idx_t key;
586
- float coarse_dis;
587
- std::vector<uint8_t> q_code;
588
-
589
- uint64_t init_list_cycles;
590
-
591
- /// once we know the query and the centroid, we can prepare the
592
- /// sim_table that will be used for accumulation
593
- /// and dis0, the initial value
594
- float precompute_list_tables() {
595
- float dis0 = 0;
596
- uint64_t t0;
597
- TIC;
598
- if (by_residual) {
599
- if (metric_type == METRIC_INNER_PRODUCT)
600
- dis0 = precompute_list_tables_IP();
601
- else
602
- dis0 = precompute_list_tables_L2();
603
- }
604
- init_list_cycles += TOC;
605
- return dis0;
606
- }
607
-
608
- float precompute_list_table_pointers() {
609
- float dis0 = 0;
610
- uint64_t t0;
611
- TIC;
612
- if (by_residual) {
613
- if (metric_type == METRIC_INNER_PRODUCT)
614
- FAISS_THROW_MSG("not implemented");
615
- else
616
- dis0 = precompute_list_table_pointers_L2();
617
- }
618
- init_list_cycles += TOC;
619
- return dis0;
620
- }
621
-
622
- /*****************************************************
623
- * compute tables for inner prod
624
- *****************************************************/
625
-
626
- float precompute_list_tables_IP() {
627
- // prepare the sim_table that will be used for accumulation
628
- // and dis0, the initial value
629
- ivfpq.quantizer->reconstruct(key, decoded_vec);
630
- // decoded_vec = centroid
631
- float dis0 = fvec_inner_product(qi, decoded_vec, d);
632
-
633
- if (polysemous_ht) {
634
- for (int i = 0; i < d; i++) {
635
- residual_vec[i] = qi[i] - decoded_vec[i];
636
- }
637
- pq.compute_code(residual_vec, q_code.data());
638
- }
639
- return dis0;
640
- }
641
-
642
- /*****************************************************
643
- * compute tables for L2 distance
644
- *****************************************************/
645
-
646
- float precompute_list_tables_L2() {
647
- float dis0 = 0;
648
-
649
- if (use_precomputed_table == 0 || use_precomputed_table == -1) {
650
- ivfpq.quantizer->compute_residual(qi, residual_vec, key);
651
- pq.compute_distance_table(residual_vec, sim_table);
652
-
653
- if (polysemous_ht != 0) {
654
- pq.compute_code(residual_vec, q_code.data());
655
- }
656
-
657
- } else if (use_precomputed_table == 1) {
658
- dis0 = coarse_dis;
659
-
660
- fvec_madd(
661
- pq.M * pq.ksub,
662
- ivfpq.precomputed_table.data() + key * pq.ksub * pq.M,
663
- -2.0,
664
- sim_table_2,
665
- sim_table);
666
-
667
- if (polysemous_ht != 0) {
668
- ivfpq.quantizer->compute_residual(qi, residual_vec, key);
669
- pq.compute_code(residual_vec, q_code.data());
670
- }
671
-
672
- } else if (use_precomputed_table == 2) {
673
- dis0 = coarse_dis;
674
-
675
- const MultiIndexQuantizer* miq =
676
- dynamic_cast<const MultiIndexQuantizer*>(ivfpq.quantizer);
677
- FAISS_THROW_IF_NOT(miq);
678
- const ProductQuantizer& cpq = miq->pq;
679
- int Mf = pq.M / cpq.M;
680
-
681
- const float* qtab = sim_table_2; // query-specific table
682
- float* ltab = sim_table; // (output) list-specific table
683
-
684
- long k = key;
685
- for (int cm = 0; cm < cpq.M; cm++) {
686
- // compute PQ index
687
- int ki = k & ((uint64_t(1) << cpq.nbits) - 1);
688
- k >>= cpq.nbits;
689
-
690
- // get corresponding table
691
- const float* pc = ivfpq.precomputed_table.data() +
692
- (ki * pq.M + cm * Mf) * pq.ksub;
693
-
694
- if (polysemous_ht == 0) {
695
- // sum up with query-specific table
696
- fvec_madd(Mf * pq.ksub, pc, -2.0, qtab, ltab);
697
- ltab += Mf * pq.ksub;
698
- qtab += Mf * pq.ksub;
699
- } else {
700
- for (int m = cm * Mf; m < (cm + 1) * Mf; m++) {
701
- q_code[m] = fvec_madd_and_argmin(
702
- pq.ksub, pc, -2, qtab, ltab);
703
- pc += pq.ksub;
704
- ltab += pq.ksub;
705
- qtab += pq.ksub;
706
- }
707
- }
708
- }
709
- }
710
-
711
- return dis0;
712
- }
713
-
714
- float precompute_list_table_pointers_L2() {
715
- float dis0 = 0;
716
-
717
- if (use_precomputed_table == 1) {
718
- dis0 = coarse_dis;
719
-
720
- const float* s =
721
- ivfpq.precomputed_table.data() + key * pq.ksub * pq.M;
722
- for (int m = 0; m < pq.M; m++) {
723
- sim_table_ptrs[m] = s;
724
- s += pq.ksub;
725
- }
726
- } else if (use_precomputed_table == 2) {
727
- dis0 = coarse_dis;
728
-
729
- const MultiIndexQuantizer* miq =
730
- dynamic_cast<const MultiIndexQuantizer*>(ivfpq.quantizer);
731
- FAISS_THROW_IF_NOT(miq);
732
- const ProductQuantizer& cpq = miq->pq;
733
- int Mf = pq.M / cpq.M;
734
-
735
- long k = key;
736
- int m0 = 0;
737
- for (int cm = 0; cm < cpq.M; cm++) {
738
- int ki = k & ((uint64_t(1) << cpq.nbits) - 1);
739
- k >>= cpq.nbits;
740
-
741
- const float* pc = ivfpq.precomputed_table.data() +
742
- (ki * pq.M + cm * Mf) * pq.ksub;
743
-
744
- for (int m = m0; m < m0 + Mf; m++) {
745
- sim_table_ptrs[m] = pc;
746
- pc += pq.ksub;
747
- }
748
- m0 += Mf;
749
- }
750
- } else {
751
- FAISS_THROW_MSG("need precomputed tables");
752
- }
753
-
754
- if (polysemous_ht) {
755
- FAISS_THROW_MSG("not implemented");
756
- // Not clear that it makes sense to implemente this,
757
- // because it costs M * ksub, which is what we wanted to
758
- // avoid with the tables pointers.
759
- }
760
-
761
- return dis0;
762
- }
763
- };
764
-
765
- // This way of handling the selector is not optimal since all distances
766
- // are computed even if the id would filter it out.
767
- template <class C, bool use_sel>
768
- struct KnnSearchResults {
769
- idx_t key;
770
- const idx_t* ids;
771
- const IDSelector* sel;
772
-
773
- // heap params
774
- size_t k;
775
- float* heap_sim;
776
- idx_t* heap_ids;
777
-
778
- size_t nup;
779
-
780
- inline bool skip_entry(idx_t j) {
781
- return use_sel && !sel->is_member(ids[j]);
782
- }
783
-
784
- inline void add(idx_t j, float dis) {
785
- if (C::cmp(heap_sim[0], dis)) {
786
- idx_t id = ids ? ids[j] : lo_build(key, j);
787
- heap_replace_top<C>(k, heap_sim, heap_ids, dis, id);
788
- nup++;
789
- }
790
- }
791
- };
792
-
793
- template <class C, bool use_sel>
794
- struct RangeSearchResults {
795
- idx_t key;
796
- const idx_t* ids;
797
- const IDSelector* sel;
798
-
799
- // wrapped result structure
800
- float radius;
801
- RangeQueryResult& rres;
802
-
803
- inline bool skip_entry(idx_t j) {
804
- return use_sel && !sel->is_member(ids[j]);
805
- }
806
-
807
- inline void add(idx_t j, float dis) {
808
- if (C::cmp(radius, dis)) {
809
- idx_t id = ids ? ids[j] : lo_build(key, j);
810
- rres.add(dis, id);
811
- }
812
- }
813
- };
814
-
815
- /*****************************************************
816
- * Scaning the codes.
817
- * The scanning functions call their favorite precompute_*
818
- * function to precompute the tables they need.
819
- *****************************************************/
820
- template <typename IDType, MetricType METRIC_TYPE, class PQDecoder>
821
- struct IVFPQScannerT : QueryTables {
822
- const uint8_t* list_codes;
823
- const IDType* list_ids;
824
- size_t list_size;
825
-
826
- IVFPQScannerT(const IndexIVFPQ& ivfpq, const IVFSearchParameters* params)
827
- : QueryTables(ivfpq, params) {
828
- assert(METRIC_TYPE == metric_type);
829
- }
830
-
831
- float dis0;
832
-
833
- void init_list(idx_t list_no, float coarse_dis, int mode) {
834
- this->key = list_no;
835
- this->coarse_dis = coarse_dis;
836
-
837
- if (mode == 2) {
838
- dis0 = precompute_list_tables();
839
- } else if (mode == 1) {
840
- dis0 = precompute_list_table_pointers();
841
- }
842
- }
843
-
844
- /*****************************************************
845
- * Scaning the codes: simple PQ scan.
846
- *****************************************************/
847
-
848
- // This is the baseline version of scan_list_with_tables().
849
- // It demonstrates what this function actually does.
850
- //
851
- // /// version of the scan where we use precomputed tables.
852
- // template <class SearchResultType>
853
- // void scan_list_with_table(
854
- // size_t ncode,
855
- // const uint8_t* codes,
856
- // SearchResultType& res) const {
857
- //
858
- // for (size_t j = 0; j < ncode; j++, codes += pq.code_size) {
859
- // if (res.skip_entry(j)) {
860
- // continue;
861
- // }
862
- // float dis = dis0 + distance_single_code<PQDecoder>(
863
- // pq, sim_table, codes);
864
- // res.add(j, dis);
865
- // }
866
- // }
867
-
868
- // This is the modified version of scan_list_with_tables().
869
- // It was observed that doing manual unrolling of the loop that
870
- // utilizes distance_single_code() speeds up the computations.
871
-
872
- /// version of the scan where we use precomputed tables.
873
- template <class SearchResultType>
874
- void scan_list_with_table(
875
- size_t ncode,
876
- const uint8_t* codes,
877
- SearchResultType& res) const {
878
- int counter = 0;
879
-
880
- size_t saved_j[4] = {0, 0, 0, 0};
881
- for (size_t j = 0; j < ncode; j++) {
882
- if (res.skip_entry(j)) {
883
- continue;
884
- }
885
-
886
- saved_j[0] = (counter == 0) ? j : saved_j[0];
887
- saved_j[1] = (counter == 1) ? j : saved_j[1];
888
- saved_j[2] = (counter == 2) ? j : saved_j[2];
889
- saved_j[3] = (counter == 3) ? j : saved_j[3];
890
-
891
- counter += 1;
892
- if (counter == 4) {
893
- float distance_0 = 0;
894
- float distance_1 = 0;
895
- float distance_2 = 0;
896
- float distance_3 = 0;
897
- distance_four_codes<PQDecoder>(
898
- pq.M,
899
- pq.nbits,
900
- sim_table,
901
- codes + saved_j[0] * pq.code_size,
902
- codes + saved_j[1] * pq.code_size,
903
- codes + saved_j[2] * pq.code_size,
904
- codes + saved_j[3] * pq.code_size,
905
- distance_0,
906
- distance_1,
907
- distance_2,
908
- distance_3);
909
-
910
- res.add(saved_j[0], dis0 + distance_0);
911
- res.add(saved_j[1], dis0 + distance_1);
912
- res.add(saved_j[2], dis0 + distance_2);
913
- res.add(saved_j[3], dis0 + distance_3);
914
- counter = 0;
915
- }
916
- }
917
-
918
- if (counter >= 1) {
919
- float dis = dis0 +
920
- distance_single_code<PQDecoder>(
921
- pq.M,
922
- pq.nbits,
923
- sim_table,
924
- codes + saved_j[0] * pq.code_size);
925
- res.add(saved_j[0], dis);
926
- }
927
- if (counter >= 2) {
928
- float dis = dis0 +
929
- distance_single_code<PQDecoder>(
930
- pq.M,
931
- pq.nbits,
932
- sim_table,
933
- codes + saved_j[1] * pq.code_size);
934
- res.add(saved_j[1], dis);
935
- }
936
- if (counter >= 3) {
937
- float dis = dis0 +
938
- distance_single_code<PQDecoder>(
939
- pq.M,
940
- pq.nbits,
941
- sim_table,
942
- codes + saved_j[2] * pq.code_size);
943
- res.add(saved_j[2], dis);
944
- }
945
- }
946
-
947
- /// tables are not precomputed, but pointers are provided to the
948
- /// relevant X_c|x_r tables
949
- template <class SearchResultType>
950
- void scan_list_with_pointer(
951
- size_t ncode,
952
- const uint8_t* codes,
953
- SearchResultType& res) const {
954
- for (size_t j = 0; j < ncode; j++, codes += pq.code_size) {
955
- if (res.skip_entry(j)) {
956
- continue;
957
- }
958
- PQDecoder decoder(codes, pq.nbits);
959
- float dis = dis0;
960
- const float* tab = sim_table_2;
961
-
962
- for (size_t m = 0; m < pq.M; m++) {
963
- int ci = decoder.decode();
964
- dis += sim_table_ptrs[m][ci] - 2 * tab[ci];
965
- tab += pq.ksub;
966
- }
967
- res.add(j, dis);
968
- }
969
- }
970
-
971
- /// nothing is precomputed: access residuals on-the-fly
972
- template <class SearchResultType>
973
- void scan_on_the_fly_dist(
974
- size_t ncode,
975
- const uint8_t* codes,
976
- SearchResultType& res) const {
977
- const float* dvec;
978
- float dis0 = 0;
979
- if (by_residual) {
980
- if (METRIC_TYPE == METRIC_INNER_PRODUCT) {
981
- ivfpq.quantizer->reconstruct(key, residual_vec);
982
- dis0 = fvec_inner_product(residual_vec, qi, d);
983
- } else {
984
- ivfpq.quantizer->compute_residual(qi, residual_vec, key);
985
- }
986
- dvec = residual_vec;
987
- } else {
988
- dvec = qi;
989
- dis0 = 0;
990
- }
991
-
992
- for (size_t j = 0; j < ncode; j++, codes += pq.code_size) {
993
- if (res.skip_entry(j)) {
994
- continue;
995
- }
996
- pq.decode(codes, decoded_vec);
997
-
998
- float dis;
999
- if (METRIC_TYPE == METRIC_INNER_PRODUCT) {
1000
- dis = dis0 + fvec_inner_product(decoded_vec, qi, d);
1001
- } else {
1002
- dis = fvec_L2sqr(decoded_vec, dvec, d);
1003
- }
1004
- res.add(j, dis);
1005
- }
1006
- }
1007
-
1008
- /*****************************************************
1009
- * Scanning codes with polysemous filtering
1010
- *****************************************************/
1011
-
1012
- // This is the baseline version of scan_list_polysemous_hc().
1013
- // It demonstrates what this function actually does.
1014
-
1015
- // template <class HammingComputer, class SearchResultType>
1016
- // void scan_list_polysemous_hc(
1017
- // size_t ncode,
1018
- // const uint8_t* codes,
1019
- // SearchResultType& res) const {
1020
- // int ht = ivfpq.polysemous_ht;
1021
- // size_t n_hamming_pass = 0, nup = 0;
1022
- //
1023
- // int code_size = pq.code_size;
1024
- //
1025
- // HammingComputer hc(q_code.data(), code_size);
1026
- //
1027
- // for (size_t j = 0; j < ncode; j++, codes += code_size) {
1028
- // if (res.skip_entry(j)) {
1029
- // continue;
1030
- // }
1031
- // const uint8_t* b_code = codes;
1032
- // int hd = hc.hamming(b_code);
1033
- // if (hd < ht) {
1034
- // n_hamming_pass++;
1035
- //
1036
- // float dis =
1037
- // dis0 +
1038
- // distance_single_code<PQDecoder>(
1039
- // pq, sim_table, codes);
1040
- //
1041
- // res.add(j, dis);
1042
- // }
1043
- // }
1044
- // #pragma omp critical
1045
- // { indexIVFPQ_stats.n_hamming_pass += n_hamming_pass; }
1046
- // }
1047
-
1048
- // This is the modified version of scan_list_with_tables().
1049
- // It was observed that doing manual unrolling of the loop that
1050
- // utilizes distance_single_code() speeds up the computations.
1051
-
1052
- template <class HammingComputer, class SearchResultType>
1053
- void scan_list_polysemous_hc(
1054
- size_t ncode,
1055
- const uint8_t* codes,
1056
- SearchResultType& res) const {
1057
- int ht = ivfpq.polysemous_ht;
1058
- size_t n_hamming_pass = 0;
1059
-
1060
- int code_size = pq.code_size;
1061
-
1062
- size_t saved_j[8];
1063
- int counter = 0;
1064
-
1065
- HammingComputer hc(q_code.data(), code_size);
1066
-
1067
- for (size_t j = 0; j < (ncode / 4) * 4; j += 4) {
1068
- const uint8_t* b_code = codes + j * code_size;
1069
-
1070
- // Unrolling is a key. Basically, doing multiple popcount
1071
- // operations one after another speeds things up.
1072
-
1073
- // 9999999 is just an arbitrary large number
1074
- int hd0 = (res.skip_entry(j + 0))
1075
- ? 99999999
1076
- : hc.hamming(b_code + 0 * code_size);
1077
- int hd1 = (res.skip_entry(j + 1))
1078
- ? 99999999
1079
- : hc.hamming(b_code + 1 * code_size);
1080
- int hd2 = (res.skip_entry(j + 2))
1081
- ? 99999999
1082
- : hc.hamming(b_code + 2 * code_size);
1083
- int hd3 = (res.skip_entry(j + 3))
1084
- ? 99999999
1085
- : hc.hamming(b_code + 3 * code_size);
1086
-
1087
- saved_j[counter] = j + 0;
1088
- counter = (hd0 < ht) ? (counter + 1) : counter;
1089
- saved_j[counter] = j + 1;
1090
- counter = (hd1 < ht) ? (counter + 1) : counter;
1091
- saved_j[counter] = j + 2;
1092
- counter = (hd2 < ht) ? (counter + 1) : counter;
1093
- saved_j[counter] = j + 3;
1094
- counter = (hd3 < ht) ? (counter + 1) : counter;
1095
-
1096
- if (counter >= 4) {
1097
- // process four codes at the same time
1098
- n_hamming_pass += 4;
1099
-
1100
- float distance_0 = dis0;
1101
- float distance_1 = dis0;
1102
- float distance_2 = dis0;
1103
- float distance_3 = dis0;
1104
- distance_four_codes<PQDecoder>(
1105
- pq.M,
1106
- pq.nbits,
1107
- sim_table,
1108
- codes + saved_j[0] * pq.code_size,
1109
- codes + saved_j[1] * pq.code_size,
1110
- codes + saved_j[2] * pq.code_size,
1111
- codes + saved_j[3] * pq.code_size,
1112
- distance_0,
1113
- distance_1,
1114
- distance_2,
1115
- distance_3);
1116
-
1117
- res.add(saved_j[0], dis0 + distance_0);
1118
- res.add(saved_j[1], dis0 + distance_1);
1119
- res.add(saved_j[2], dis0 + distance_2);
1120
- res.add(saved_j[3], dis0 + distance_3);
1121
-
1122
- //
1123
- counter -= 4;
1124
- saved_j[0] = saved_j[4];
1125
- saved_j[1] = saved_j[5];
1126
- saved_j[2] = saved_j[6];
1127
- saved_j[3] = saved_j[7];
1128
- }
1129
- }
1130
-
1131
- for (size_t kk = 0; kk < counter; kk++) {
1132
- n_hamming_pass++;
1133
-
1134
- float dis = dis0 +
1135
- distance_single_code<PQDecoder>(
1136
- pq.M,
1137
- pq.nbits,
1138
- sim_table,
1139
- codes + saved_j[kk] * pq.code_size);
1140
-
1141
- res.add(saved_j[kk], dis);
1142
- }
1143
-
1144
- // process leftovers
1145
- for (size_t j = (ncode / 4) * 4; j < ncode; j++) {
1146
- if (res.skip_entry(j)) {
1147
- continue;
1148
- }
1149
- const uint8_t* b_code = codes + j * code_size;
1150
- int hd = hc.hamming(b_code);
1151
- if (hd < ht) {
1152
- n_hamming_pass++;
1153
-
1154
- float dis = dis0 +
1155
- distance_single_code<PQDecoder>(
1156
- pq.M,
1157
- pq.nbits,
1158
- sim_table,
1159
- codes + j * code_size);
1160
-
1161
- res.add(j, dis);
1162
- }
1163
- }
1164
-
1165
- #pragma omp critical
1166
- {
1167
- indexIVFPQ_stats.n_hamming_pass += n_hamming_pass;
1168
- }
1169
- }
1170
-
1171
- template <class SearchResultType>
1172
- struct Run_scan_list_polysemous_hc {
1173
- using T = void;
1174
- template <class HammingComputer, class... Types>
1175
- void f(const IVFPQScannerT* scanner, Types... args) {
1176
- scanner->scan_list_polysemous_hc<HammingComputer, SearchResultType>(
1177
- args...);
1178
- }
1179
- };
1180
-
1181
- template <class SearchResultType>
1182
- void scan_list_polysemous(
1183
- size_t ncode,
1184
- const uint8_t* codes,
1185
- SearchResultType& res) const {
1186
- Run_scan_list_polysemous_hc<SearchResultType> r;
1187
- dispatch_HammingComputer(pq.code_size, r, this, ncode, codes, res);
1188
- }
1189
- };
1190
-
1191
- /* We put as many parameters as possible in template. Hopefully the
1192
- * gain in runtime is worth the code bloat.
1193
- *
1194
- * C is the comparator < or >, it is directly related to METRIC_TYPE.
1195
- *
1196
- * precompute_mode is how much we precompute (2 = precompute distance tables,
1197
- * 1 = precompute pointers to distances, 0 = compute distances one by one).
1198
- * Currently only 2 is supported
1199
- *
1200
- * use_sel: store or ignore the IDSelector
1201
- */
1202
- template <MetricType METRIC_TYPE, class C, class PQDecoder, bool use_sel>
1203
- struct IVFPQScanner : IVFPQScannerT<idx_t, METRIC_TYPE, PQDecoder>,
1204
- InvertedListScanner {
1205
- int precompute_mode;
1206
- const IDSelector* sel;
1207
-
1208
- IVFPQScanner(
1209
- const IndexIVFPQ& ivfpq,
1210
- bool store_pairs,
1211
- int precompute_mode,
1212
- const IDSelector* sel)
1213
- : IVFPQScannerT<idx_t, METRIC_TYPE, PQDecoder>(ivfpq, nullptr),
1214
- precompute_mode(precompute_mode),
1215
- sel(sel) {
1216
- this->store_pairs = store_pairs;
1217
- this->keep_max = is_similarity_metric(METRIC_TYPE);
1218
- this->code_size = this->pq.code_size;
1219
- }
1220
-
1221
- void set_query(const float* query) override {
1222
- this->init_query(query);
1223
- }
1224
-
1225
- void set_list(idx_t list_no, float coarse_dis) override {
1226
- this->list_no = list_no;
1227
- this->init_list(list_no, coarse_dis, precompute_mode);
1228
- }
1229
-
1230
- float distance_to_code(const uint8_t* code) const override {
1231
- assert(precompute_mode == 2);
1232
- float dis = this->dis0 +
1233
- distance_single_code<PQDecoder>(
1234
- this->pq.M, this->pq.nbits, this->sim_table, code);
1235
- return dis;
1236
- }
1237
-
1238
- size_t scan_codes(
1239
- size_t ncode,
1240
- const uint8_t* codes,
1241
- const idx_t* ids,
1242
- float* heap_sim,
1243
- idx_t* heap_ids,
1244
- size_t k) const override {
1245
- KnnSearchResults<C, use_sel> res = {
1246
- /* key */ this->key,
1247
- /* ids */ this->store_pairs ? nullptr : ids,
1248
- /* sel */ this->sel,
1249
- /* k */ k,
1250
- /* heap_sim */ heap_sim,
1251
- /* heap_ids */ heap_ids,
1252
- /* nup */ 0};
1253
-
1254
- if (this->polysemous_ht > 0) {
1255
- assert(precompute_mode == 2);
1256
- this->scan_list_polysemous(ncode, codes, res);
1257
- } else if (precompute_mode == 2) {
1258
- this->scan_list_with_table(ncode, codes, res);
1259
- } else if (precompute_mode == 1) {
1260
- this->scan_list_with_pointer(ncode, codes, res);
1261
- } else if (precompute_mode == 0) {
1262
- this->scan_on_the_fly_dist(ncode, codes, res);
1263
- } else {
1264
- FAISS_THROW_MSG("bad precomp mode");
1265
- }
1266
- return res.nup;
1267
- }
1268
-
1269
- void scan_codes_range(
1270
- size_t ncode,
1271
- const uint8_t* codes,
1272
- const idx_t* ids,
1273
- float radius,
1274
- RangeQueryResult& rres) const override {
1275
- RangeSearchResults<C, use_sel> res = {
1276
- /* key */ this->key,
1277
- /* ids */ this->store_pairs ? nullptr : ids,
1278
- /* sel */ this->sel,
1279
- /* radius */ radius,
1280
- /* rres */ rres};
1281
-
1282
- if (this->polysemous_ht > 0) {
1283
- assert(precompute_mode == 2);
1284
- this->scan_list_polysemous(ncode, codes, res);
1285
- } else if (precompute_mode == 2) {
1286
- this->scan_list_with_table(ncode, codes, res);
1287
- } else if (precompute_mode == 1) {
1288
- this->scan_list_with_pointer(ncode, codes, res);
1289
- } else if (precompute_mode == 0) {
1290
- this->scan_on_the_fly_dist(ncode, codes, res);
1291
- } else {
1292
- FAISS_THROW_MSG("bad precomp mode");
1293
- }
1294
- }
1295
- };
1296
-
1297
- template <class PQDecoder, bool use_sel>
1298
- InvertedListScanner* get_InvertedListScanner1(
1299
- const IndexIVFPQ& index,
1300
- bool store_pairs,
1301
- const IDSelector* sel) {
1302
- if (index.metric_type == METRIC_INNER_PRODUCT) {
1303
- return new IVFPQScanner<
1304
- METRIC_INNER_PRODUCT,
1305
- CMin<float, idx_t>,
1306
- PQDecoder,
1307
- use_sel>(index, store_pairs, 2, sel);
1308
- } else if (index.metric_type == METRIC_L2) {
1309
- return new IVFPQScanner<
1310
- METRIC_L2,
1311
- CMax<float, idx_t>,
1312
- PQDecoder,
1313
- use_sel>(index, store_pairs, 2, sel);
1314
- }
1315
- return nullptr;
1316
- }
1317
-
1318
- template <bool use_sel>
1319
- InvertedListScanner* get_InvertedListScanner2(
1320
- const IndexIVFPQ& index,
1321
- bool store_pairs,
1322
- const IDSelector* sel) {
1323
- if (index.pq.nbits == 8) {
1324
- return get_InvertedListScanner1<PQDecoder8, use_sel>(
1325
- index, store_pairs, sel);
1326
- } else if (index.pq.nbits == 16) {
1327
- return get_InvertedListScanner1<PQDecoder16, use_sel>(
1328
- index, store_pairs, sel);
1329
- } else {
1330
- return get_InvertedListScanner1<PQDecoderGeneric, use_sel>(
1331
- index, store_pairs, sel);
1332
- }
1333
- }
1334
-
1335
- } // anonymous namespace
1336
-
1337
508
  InvertedListScanner* IndexIVFPQ::get_InvertedListScanner(
1338
509
  bool store_pairs,
1339
510
  const IDSelector* sel,
1340
511
  const IVFSearchParameters*) const {
1341
- if (sel) {
1342
- return get_InvertedListScanner2<true>(*this, store_pairs, sel);
1343
- } else {
1344
- return get_InvertedListScanner2<false>(*this, store_pairs, sel);
1345
- }
1346
- return nullptr;
512
+ return with_simd_level([&]<SIMDLevel SL>() -> InvertedListScanner* {
513
+ return pq_code_distance::make_IVFPQInvertedListScanner<SL>(
514
+ *this, store_pairs, sel);
515
+ });
1347
516
  }
1348
517
 
1349
518
  IndexIVFPQStats indexIVFPQ_stats;
@@ -1378,25 +547,26 @@ size_t IndexIVFPQ::find_duplicates(idx_t* dup_ids, size_t* lims) const {
1378
547
  for (size_t list_no = 0; list_no < nlist; list_no++) {
1379
548
  size_t n = invlists->list_size(list_no);
1380
549
  std::vector<int> ord(n);
1381
- for (int i = 0; i < n; i++)
1382
- ord[i] = i;
550
+ for (size_t i = 0; i < n; i++) {
551
+ ord[i] = static_cast<int>(i);
552
+ }
1383
553
  InvertedLists::ScopedCodes codes(invlists, list_no);
1384
554
  CodeCmp cs = {codes.get(), code_size};
1385
555
  std::sort(ord.begin(), ord.end(), cs);
1386
556
 
1387
557
  InvertedLists::ScopedIds list_ids(invlists, list_no);
1388
558
  int prev = -1; // all elements from prev to i-1 are equal
1389
- for (int i = 0; i < n; i++) {
559
+ for (size_t i = 0; i < n; i++) {
1390
560
  if (prev >= 0 && cs.cmp(ord[prev], ord[i]) == 0) {
1391
561
  // same as previous => remember
1392
- if (prev + 1 == i) { // start new group
562
+ if (static_cast<size_t>(prev + 1) == i) { // start new group
1393
563
  ngroup++;
1394
564
  lims[ngroup] = lims[ngroup - 1];
1395
565
  dup_ids[lims[ngroup]++] = list_ids[ord[prev]];
1396
566
  }
1397
567
  dup_ids[lims[ngroup]++] = list_ids[ord[i]];
1398
568
  } else { // not same as previous.
1399
- prev = i;
569
+ prev = static_cast<int>(i);
1400
570
  }
1401
571
  }
1402
572
  }