faiss 0.5.3 → 0.6.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (379) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +12 -0
  3. data/ext/faiss/ext.cpp +1 -1
  4. data/ext/faiss/extconf.rb +4 -4
  5. data/ext/faiss/index.cpp +63 -45
  6. data/ext/faiss/index_binary.cpp +37 -27
  7. data/ext/faiss/kmeans.cpp +9 -8
  8. data/ext/faiss/pca_matrix.cpp +9 -7
  9. data/ext/faiss/product_quantizer.cpp +13 -11
  10. data/ext/faiss/utils.cpp +4 -2
  11. data/ext/faiss/utils.h +4 -0
  12. data/lib/faiss/version.rb +1 -1
  13. data/lib/faiss.rb +1 -1
  14. data/vendor/faiss/faiss/AutoTune.cpp +214 -82
  15. data/vendor/faiss/faiss/AutoTune.h +14 -1
  16. data/vendor/faiss/faiss/Clustering.cpp +97 -249
  17. data/vendor/faiss/faiss/Clustering.h +18 -0
  18. data/vendor/faiss/faiss/IVFlib.cpp +67 -44
  19. data/vendor/faiss/faiss/Index.cpp +25 -12
  20. data/vendor/faiss/faiss/Index.h +26 -4
  21. data/vendor/faiss/faiss/Index2Layer.cpp +37 -53
  22. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +68 -61
  23. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +36 -34
  24. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +4 -1
  25. data/vendor/faiss/faiss/IndexBinary.cpp +6 -3
  26. data/vendor/faiss/faiss/IndexBinary.h +4 -4
  27. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +1 -1
  28. data/vendor/faiss/faiss/IndexBinaryFlat.h +1 -1
  29. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -4
  30. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +92 -95
  31. data/vendor/faiss/faiss/IndexBinaryHNSW.h +9 -3
  32. data/vendor/faiss/faiss/IndexBinaryHash.cpp +45 -236
  33. data/vendor/faiss/faiss/IndexBinaryHash.h +6 -6
  34. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +120 -414
  35. data/vendor/faiss/faiss/IndexFastScan.cpp +105 -129
  36. data/vendor/faiss/faiss/IndexFastScan.h +35 -24
  37. data/vendor/faiss/faiss/IndexFlat.cpp +216 -152
  38. data/vendor/faiss/faiss/IndexFlat.h +32 -14
  39. data/vendor/faiss/faiss/IndexFlatCodes.cpp +88 -41
  40. data/vendor/faiss/faiss/IndexFlatCodes.h +7 -1
  41. data/vendor/faiss/faiss/IndexHNSW.cpp +299 -187
  42. data/vendor/faiss/faiss/IndexHNSW.h +30 -14
  43. data/vendor/faiss/faiss/IndexIDMap.cpp +26 -22
  44. data/vendor/faiss/faiss/IndexIDMap.h +9 -7
  45. data/vendor/faiss/faiss/IndexIVF.cpp +535 -405
  46. data/vendor/faiss/faiss/IndexIVF.h +47 -16
  47. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +77 -74
  48. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +105 -99
  49. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +6 -3
  50. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +379 -249
  51. data/vendor/faiss/faiss/IndexIVFFastScan.h +65 -60
  52. data/vendor/faiss/faiss/IndexIVFFlat.cpp +41 -124
  53. data/vendor/faiss/faiss/IndexIVFFlat.h +32 -0
  54. data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +89 -138
  55. data/vendor/faiss/faiss/IndexIVFFlatPanorama.h +3 -1
  56. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +18 -15
  57. data/vendor/faiss/faiss/IndexIVFPQ.cpp +77 -907
  58. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +184 -122
  59. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +3 -0
  60. data/vendor/faiss/faiss/IndexIVFPQR.cpp +23 -18
  61. data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +59 -60
  62. data/vendor/faiss/faiss/IndexIVFRaBitQ.h +4 -3
  63. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +564 -416
  64. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +269 -111
  65. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +41 -127
  66. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +1 -1
  67. data/vendor/faiss/faiss/IndexLSH.cpp +44 -25
  68. data/vendor/faiss/faiss/IndexLattice.cpp +41 -36
  69. data/vendor/faiss/faiss/IndexNNDescent.cpp +37 -21
  70. data/vendor/faiss/faiss/IndexNNDescent.h +2 -2
  71. data/vendor/faiss/faiss/IndexNSG.cpp +40 -23
  72. data/vendor/faiss/faiss/IndexNSG.h +0 -2
  73. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +32 -12
  74. data/vendor/faiss/faiss/IndexPQ.cpp +129 -213
  75. data/vendor/faiss/faiss/IndexPQ.h +3 -2
  76. data/vendor/faiss/faiss/IndexPQFastScan.cpp +20 -14
  77. data/vendor/faiss/faiss/IndexPQFastScan.h +3 -0
  78. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -18
  79. data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
  80. data/vendor/faiss/faiss/IndexRaBitQ.cpp +31 -43
  81. data/vendor/faiss/faiss/IndexRaBitQ.h +4 -3
  82. data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +135 -317
  83. data/vendor/faiss/faiss/IndexRaBitQFastScan.h +192 -34
  84. data/vendor/faiss/faiss/IndexRefine.cpp +30 -55
  85. data/vendor/faiss/faiss/IndexRefine.h +4 -4
  86. data/vendor/faiss/faiss/IndexReplicas.cpp +6 -6
  87. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +15 -14
  88. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +1 -1
  89. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +82 -14
  90. data/vendor/faiss/faiss/IndexShards.cpp +13 -13
  91. data/vendor/faiss/faiss/IndexShardsIVF.cpp +21 -15
  92. data/vendor/faiss/faiss/MatrixStats.cpp +5 -4
  93. data/vendor/faiss/faiss/MetaIndexes.cpp +19 -17
  94. data/vendor/faiss/faiss/MetaIndexes.h +1 -1
  95. data/vendor/faiss/faiss/MetricType.h +29 -6
  96. data/vendor/faiss/faiss/SuperKMeans.cpp +656 -0
  97. data/vendor/faiss/faiss/SuperKMeans.h +97 -0
  98. data/vendor/faiss/faiss/VectorTransform.cpp +349 -141
  99. data/vendor/faiss/faiss/VectorTransform.h +39 -16
  100. data/vendor/faiss/faiss/build.cpp +23 -0
  101. data/vendor/faiss/faiss/build.h +15 -0
  102. data/vendor/faiss/faiss/clone_index.cpp +55 -51
  103. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +47 -47
  104. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +11 -0
  105. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +38 -38
  106. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +11 -0
  107. data/vendor/faiss/faiss/{cppcontrib/factory_tools.cpp → factory_tools.cpp} +6 -1
  108. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +1 -1
  109. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +6 -5
  110. data/vendor/faiss/faiss/gpu/GpuResources.h +1 -1
  111. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +9 -9
  112. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +4 -3
  113. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +46 -0
  114. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +56 -0
  115. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +78 -1
  116. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +72 -0
  117. data/vendor/faiss/faiss/gpu/test/TestUtils.h +23 -0
  118. data/vendor/faiss/faiss/gpu/utils/CuvsFilterConvert.h +1 -1
  119. data/vendor/faiss/faiss/gpu/utils/CuvsUtils.h +21 -10
  120. data/vendor/faiss/faiss/gpu_metal/GpuIndexFlat.h +22 -0
  121. data/vendor/faiss/faiss/gpu_metal/MetalCloner.h +35 -0
  122. data/vendor/faiss/faiss/gpu_metal/MetalFlatKernels.h +40 -0
  123. data/vendor/faiss/faiss/gpu_metal/MetalIndex.h +51 -0
  124. data/vendor/faiss/faiss/gpu_metal/MetalIndexFlat.h +65 -0
  125. data/vendor/faiss/faiss/gpu_metal/MetalKernels.h +66 -0
  126. data/vendor/faiss/faiss/gpu_metal/MetalResources.h +79 -0
  127. data/vendor/faiss/faiss/gpu_metal/StandardMetalResources.h +35 -0
  128. data/vendor/faiss/faiss/impl/AdSampling.cpp +103 -0
  129. data/vendor/faiss/faiss/impl/AdSampling.h +35 -0
  130. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +64 -34
  131. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +1 -0
  132. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +10 -9
  133. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +3 -28
  134. data/vendor/faiss/faiss/impl/ClusteringHelpers.cpp +244 -0
  135. data/vendor/faiss/faiss/impl/ClusteringHelpers.h +94 -0
  136. data/vendor/faiss/faiss/impl/ClusteringInitialization.cpp +367 -0
  137. data/vendor/faiss/faiss/impl/ClusteringInitialization.h +107 -0
  138. data/vendor/faiss/faiss/impl/CodePacker.cpp +7 -3
  139. data/vendor/faiss/faiss/impl/CodePacker.h +11 -3
  140. data/vendor/faiss/faiss/impl/CodePackerRaBitQ.cpp +83 -0
  141. data/vendor/faiss/faiss/impl/CodePackerRaBitQ.h +47 -0
  142. data/vendor/faiss/faiss/impl/DistanceComputer.h +8 -8
  143. data/vendor/faiss/faiss/impl/FaissAssert.h +64 -3
  144. data/vendor/faiss/faiss/impl/FaissException.h +50 -3
  145. data/vendor/faiss/faiss/impl/HNSW.cpp +117 -351
  146. data/vendor/faiss/faiss/impl/HNSW.h +21 -40
  147. data/vendor/faiss/faiss/impl/IDSelector.cpp +15 -11
  148. data/vendor/faiss/faiss/impl/IDSelector.h +8 -8
  149. data/vendor/faiss/faiss/impl/InvertedListScannerStats.h +26 -0
  150. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +114 -102
  151. data/vendor/faiss/faiss/impl/NNDescent.cpp +63 -26
  152. data/vendor/faiss/faiss/impl/NNDescent.h +6 -2
  153. data/vendor/faiss/faiss/impl/NSG.cpp +44 -26
  154. data/vendor/faiss/faiss/impl/NSG.h +20 -10
  155. data/vendor/faiss/faiss/impl/Panorama.cpp +76 -52
  156. data/vendor/faiss/faiss/impl/Panorama.h +265 -78
  157. data/vendor/faiss/faiss/impl/PdxLayout.cpp +93 -0
  158. data/vendor/faiss/faiss/impl/PdxLayout.h +41 -0
  159. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +62 -37
  160. data/vendor/faiss/faiss/impl/PolysemousTraining.h +3 -3
  161. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +35 -35
  162. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +21 -16
  163. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +99 -80
  164. data/vendor/faiss/faiss/impl/Quantizer.h +2 -2
  165. data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +135 -37
  166. data/vendor/faiss/faiss/impl/RaBitQUtils.h +148 -21
  167. data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +298 -301
  168. data/vendor/faiss/faiss/impl/RaBitQuantizer.h +3 -10
  169. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.cpp +15 -41
  170. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.h +0 -4
  171. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +40 -32
  172. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +1 -1
  173. data/vendor/faiss/faiss/impl/ResultHandler.h +218 -113
  174. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +119 -2362
  175. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +27 -3
  176. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +14 -11
  177. data/vendor/faiss/faiss/impl/VisitedTable.cpp +42 -0
  178. data/vendor/faiss/faiss/impl/VisitedTable.h +76 -0
  179. data/vendor/faiss/faiss/impl/approx_topk/approx_topk.h +276 -0
  180. data/vendor/faiss/faiss/impl/approx_topk/avx2.cpp +68 -0
  181. data/vendor/faiss/faiss/{utils → impl}/approx_topk/generic.h +15 -8
  182. data/vendor/faiss/faiss/impl/approx_topk/neon.cpp +68 -0
  183. data/vendor/faiss/faiss/impl/approx_topk/rq_beam_search_tab-inl.h +169 -0
  184. data/vendor/faiss/faiss/impl/approx_topk/rq_beam_search_tab.h +117 -0
  185. data/vendor/faiss/faiss/impl/approx_topk/simdlib256-inl.h +146 -0
  186. data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryHNSW_impl.h +73 -0
  187. data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryHash_impl.h +270 -0
  188. data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryIVF_impl.h +460 -0
  189. data/vendor/faiss/faiss/impl/binary_hamming/IndexIVFSpectralHash_impl.h +159 -0
  190. data/vendor/faiss/faiss/impl/binary_hamming/IndexPQ_impl.h +92 -0
  191. data/vendor/faiss/faiss/impl/binary_hamming/avx2.cpp +26 -0
  192. data/vendor/faiss/faiss/impl/binary_hamming/avx512.cpp +26 -0
  193. data/vendor/faiss/faiss/impl/binary_hamming/dispatch.h +143 -0
  194. data/vendor/faiss/faiss/impl/binary_hamming/neon.cpp +26 -0
  195. data/vendor/faiss/faiss/impl/binary_hamming/rvv.cpp +26 -0
  196. data/vendor/faiss/faiss/impl/expanded_scanners.h +163 -0
  197. data/vendor/faiss/faiss/impl/{FastScanDistancePostProcessing.h → fast_scan/FastScanDistancePostProcessing.h} +13 -6
  198. data/vendor/faiss/faiss/impl/{LookupTableScaler.h → fast_scan/LookupTableScaler.h} +16 -5
  199. data/vendor/faiss/faiss/impl/fast_scan/accumulate_loops.h +237 -0
  200. data/vendor/faiss/faiss/impl/fast_scan/accumulate_loops_512.h +185 -0
  201. data/vendor/faiss/faiss/impl/fast_scan/decompose_qbs.h +229 -0
  202. data/vendor/faiss/faiss/impl/fast_scan/dispatching.h +268 -0
  203. data/vendor/faiss/faiss/impl/{pq4_fast_scan.cpp → fast_scan/fast_scan.cpp} +176 -4
  204. data/vendor/faiss/faiss/impl/fast_scan/fast_scan.h +341 -0
  205. data/vendor/faiss/faiss/impl/fast_scan/impl-avx2.cpp +36 -0
  206. data/vendor/faiss/faiss/impl/fast_scan/impl-avx512.cpp +40 -0
  207. data/vendor/faiss/faiss/impl/fast_scan/impl-neon.cpp +120 -0
  208. data/vendor/faiss/faiss/impl/fast_scan/impl-riscv.cpp +104 -0
  209. data/vendor/faiss/faiss/impl/fast_scan/kernels_simd256.h +213 -0
  210. data/vendor/faiss/faiss/impl/{pq4_fast_scan_search_qbs.cpp → fast_scan/kernels_simd512.h} +26 -348
  211. data/vendor/faiss/faiss/impl/fast_scan/rabitq_dispatching.h +90 -0
  212. data/vendor/faiss/faiss/impl/fast_scan/rabitq_result_handler.h +108 -0
  213. data/vendor/faiss/faiss/impl/{simd_result_handlers.h → fast_scan/simd_result_handlers.h} +290 -142
  214. data/vendor/faiss/faiss/impl/hnsw/LockVector.cpp +54 -0
  215. data/vendor/faiss/faiss/impl/hnsw/LockVector.h +64 -0
  216. data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.cpp +91 -0
  217. data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.h +64 -0
  218. data/vendor/faiss/faiss/impl/hnsw/avx2.cpp +104 -0
  219. data/vendor/faiss/faiss/impl/hnsw/avx512.cpp +111 -0
  220. data/vendor/faiss/faiss/impl/index_read.cpp +1950 -505
  221. data/vendor/faiss/faiss/impl/index_read_utils.h +1 -2
  222. data/vendor/faiss/faiss/impl/index_write.cpp +112 -21
  223. data/vendor/faiss/faiss/impl/io.cpp +6 -6
  224. data/vendor/faiss/faiss/impl/io_macros.h +33 -16
  225. data/vendor/faiss/faiss/impl/kmeans1d.cpp +10 -10
  226. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +81 -40
  227. data/vendor/faiss/faiss/impl/lattice_Zn.h +6 -6
  228. data/vendor/faiss/faiss/impl/mapped_io.cpp +15 -8
  229. data/vendor/faiss/faiss/impl/platform_macros.h +11 -4
  230. data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQScanner_impl.h +549 -0
  231. data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQ_QueryTables.cpp +245 -0
  232. data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQ_QueryTables.h +105 -0
  233. data/vendor/faiss/faiss/impl/pq_code_distance/PQDistanceComputer_impl.h +106 -0
  234. data/vendor/faiss/faiss/impl/pq_code_distance/avx2.cpp +21 -0
  235. data/vendor/faiss/faiss/impl/pq_code_distance/avx512.cpp +21 -0
  236. data/vendor/faiss/faiss/impl/pq_code_distance/neon.cpp +21 -0
  237. data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx2.h → pq_code_distance/pq_code_distance-avx2.h} +43 -220
  238. data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx512.h → pq_code_distance/pq_code_distance-avx512.h} +25 -112
  239. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.cpp +59 -0
  240. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.h +96 -0
  241. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-inl.h +256 -0
  242. data/vendor/faiss/faiss/impl/{code_distance/code_distance-sve.h → pq_code_distance/pq_code_distance-sve.cpp} +57 -146
  243. data/vendor/faiss/faiss/impl/pq_code_distance/rvv.cpp +68 -0
  244. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +320 -483
  245. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +1 -1
  246. data/vendor/faiss/faiss/impl/scalar_quantizer/codecs.h +121 -0
  247. data/vendor/faiss/faiss/impl/scalar_quantizer/distance_computers.h +137 -0
  248. data/vendor/faiss/faiss/impl/scalar_quantizer/quantizers.h +371 -0
  249. data/vendor/faiss/faiss/impl/scalar_quantizer/scanners.h +190 -0
  250. data/vendor/faiss/faiss/impl/scalar_quantizer/similarities.h +94 -0
  251. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx2.cpp +603 -0
  252. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512.cpp +597 -0
  253. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-dispatch.h +388 -0
  254. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-neon.cpp +630 -0
  255. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-rvv.cpp +311 -0
  256. data/vendor/faiss/faiss/impl/scalar_quantizer/training.cpp +387 -0
  257. data/vendor/faiss/faiss/impl/scalar_quantizer/training.h +54 -0
  258. data/vendor/faiss/faiss/impl/simd_dispatch.h +173 -0
  259. data/vendor/faiss/faiss/impl/simdlib/simdlib.h +57 -0
  260. data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_avx2.h +274 -171
  261. data/vendor/faiss/faiss/impl/simdlib/simdlib_avx512.h +414 -0
  262. data/vendor/faiss/faiss/impl/simdlib/simdlib_dispatch.h +44 -0
  263. data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_emulated.h +231 -166
  264. data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_neon.h +275 -217
  265. data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_ppc64.h +201 -160
  266. data/vendor/faiss/faiss/impl/svs_io.cpp +12 -3
  267. data/vendor/faiss/faiss/impl/svs_io.h +8 -2
  268. data/vendor/faiss/faiss/index_factory.cpp +115 -28
  269. data/vendor/faiss/faiss/index_io.h +53 -3
  270. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +73 -20
  271. data/vendor/faiss/faiss/invlists/DirectMap.cpp +24 -14
  272. data/vendor/faiss/faiss/invlists/DirectMap.h +4 -3
  273. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +157 -73
  274. data/vendor/faiss/faiss/invlists/InvertedLists.h +86 -23
  275. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +4 -4
  276. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +14 -14
  277. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
  278. data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +9 -19
  279. data/vendor/faiss/faiss/svs/IndexSVSFlat.cpp +2 -2
  280. data/vendor/faiss/faiss/svs/IndexSVSFlat.h +2 -0
  281. data/vendor/faiss/faiss/svs/IndexSVSIVF.cpp +350 -0
  282. data/vendor/faiss/faiss/svs/IndexSVSIVF.h +128 -0
  283. data/vendor/faiss/faiss/svs/IndexSVSIVFLVQ.cpp +40 -0
  284. data/vendor/faiss/faiss/svs/IndexSVSIVFLVQ.h +43 -0
  285. data/vendor/faiss/faiss/svs/IndexSVSIVFLeanVec.cpp +225 -0
  286. data/vendor/faiss/faiss/svs/IndexSVSIVFLeanVec.h +71 -0
  287. data/vendor/faiss/faiss/svs/IndexSVSVamana.cpp +25 -1
  288. data/vendor/faiss/faiss/svs/IndexSVSVamana.h +19 -2
  289. data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.h +1 -1
  290. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +19 -2
  291. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +14 -0
  292. data/vendor/faiss/faiss/utils/Heap.cpp +56 -10
  293. data/vendor/faiss/faiss/utils/Heap.h +21 -0
  294. data/vendor/faiss/faiss/utils/NeuralNet.cpp +54 -40
  295. data/vendor/faiss/faiss/utils/NeuralNet.h +1 -1
  296. data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +10 -4
  297. data/vendor/faiss/faiss/utils/distances.cpp +507 -559
  298. data/vendor/faiss/faiss/utils/distances.h +118 -1
  299. data/vendor/faiss/faiss/utils/distances_dispatch.h +250 -0
  300. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +8 -7
  301. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +33 -14
  302. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +12 -1
  303. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +16 -293
  304. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based_neon.cpp +57 -0
  305. data/vendor/faiss/faiss/utils/distances_fused/simdlib_kernel-inl.h +290 -0
  306. data/vendor/faiss/faiss/utils/distances_simd.cpp +72 -3681
  307. data/vendor/faiss/faiss/utils/extra_distances.cpp +60 -102
  308. data/vendor/faiss/faiss/utils/extra_distances.h +79 -7
  309. data/vendor/faiss/faiss/utils/hamming-inl.h +13 -11
  310. data/vendor/faiss/faiss/utils/hamming.cpp +66 -517
  311. data/vendor/faiss/faiss/utils/hamming.h +92 -2
  312. data/vendor/faiss/faiss/utils/hamming_distance/common.h +287 -10
  313. data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx2.cpp +15 -0
  314. data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx512.cpp +15 -0
  315. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx2.h +142 -0
  316. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx512.h +234 -0
  317. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-generic.h +368 -0
  318. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-neon.h +322 -0
  319. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-rvv.h +39 -0
  320. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer.h +146 -0
  321. data/vendor/faiss/faiss/utils/hamming_distance/hamming_impl.h +481 -0
  322. data/vendor/faiss/faiss/utils/hamming_distance/hamming_neon.cpp +15 -0
  323. data/vendor/faiss/faiss/utils/hamming_distance/hamming_rvv.cpp +15 -0
  324. data/vendor/faiss/faiss/utils/partitioning.cpp +66 -987
  325. data/vendor/faiss/faiss/utils/partitioning.h +31 -0
  326. data/vendor/faiss/faiss/utils/popcount.h +29 -0
  327. data/vendor/faiss/faiss/utils/pq_code_distance.h +251 -0
  328. data/vendor/faiss/faiss/utils/prefetch.h +2 -2
  329. data/vendor/faiss/faiss/utils/quantize_lut.cpp +30 -30
  330. data/vendor/faiss/faiss/utils/quantize_lut.h +1 -1
  331. data/vendor/faiss/faiss/utils/rabitq_simd.h +124 -343
  332. data/vendor/faiss/faiss/utils/random.cpp +6 -6
  333. data/vendor/faiss/faiss/utils/simd_impl/IVFFlatScanner-inl.h +51 -0
  334. data/vendor/faiss/faiss/utils/simd_impl/distances_aarch64.cpp +154 -0
  335. data/vendor/faiss/faiss/utils/simd_impl/distances_arm_sve.cpp +777 -0
  336. data/vendor/faiss/faiss/utils/simd_impl/distances_autovec-inl.h +306 -0
  337. data/vendor/faiss/faiss/utils/simd_impl/distances_avx2.cpp +1431 -0
  338. data/vendor/faiss/faiss/utils/simd_impl/distances_avx512.cpp +1095 -0
  339. data/vendor/faiss/faiss/utils/simd_impl/distances_rvv.cpp +189 -0
  340. data/vendor/faiss/faiss/utils/simd_impl/distances_simdlib256.h +195 -0
  341. data/vendor/faiss/faiss/utils/simd_impl/distances_sse-inl.h +392 -0
  342. data/vendor/faiss/faiss/utils/{distances_fused/simdlib_based.h → simd_impl/exhaustive_L2sqr_blas_cmax.h} +5 -10
  343. data/vendor/faiss/faiss/utils/simd_impl/hamming_impl.h +481 -0
  344. data/vendor/faiss/faiss/utils/simd_impl/partitioning_avx2.cpp +14 -0
  345. data/vendor/faiss/faiss/utils/simd_impl/partitioning_neon.cpp +14 -0
  346. data/vendor/faiss/faiss/utils/simd_impl/partitioning_simdlib256.h +1085 -0
  347. data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx2.cpp +355 -0
  348. data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx512.cpp +477 -0
  349. data/vendor/faiss/faiss/utils/simd_impl/rabitq_neon.cpp +55 -0
  350. data/vendor/faiss/faiss/utils/simd_impl/rabitq_rvv.cpp +55 -0
  351. data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_dispatch.h +32 -0
  352. data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels.h +43 -0
  353. data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels_avx2.cpp +57 -0
  354. data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels_avx512.cpp +45 -0
  355. data/vendor/faiss/faiss/utils/simd_levels.cpp +334 -0
  356. data/vendor/faiss/faiss/utils/simd_levels.h +183 -0
  357. data/vendor/faiss/faiss/utils/sorting.cpp +48 -36
  358. data/vendor/faiss/faiss/utils/utils.cpp +21 -14
  359. data/vendor/faiss/faiss/utils/utils.h +3 -3
  360. metadata +156 -42
  361. data/vendor/faiss/faiss/impl/RaBitQStats.cpp +0 -29
  362. data/vendor/faiss/faiss/impl/RaBitQStats.h +0 -56
  363. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +0 -81
  364. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +0 -186
  365. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +0 -216
  366. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +0 -224
  367. data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +0 -84
  368. data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +0 -196
  369. data/vendor/faiss/faiss/utils/approx_topk/mode.h +0 -34
  370. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +0 -36
  371. data/vendor/faiss/faiss/utils/extra_distances-inl.h +0 -228
  372. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +0 -462
  373. data/vendor/faiss/faiss/utils/hamming_distance/avx512-inl.h +0 -490
  374. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +0 -450
  375. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +0 -87
  376. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +0 -524
  377. data/vendor/faiss/faiss/utils/simdlib.h +0 -42
  378. data/vendor/faiss/faiss/utils/simdlib_avx512.h +0 -296
  379. /data/vendor/faiss/faiss/{cppcontrib/factory_tools.h → factory_tools.h} +0 -0
@@ -7,14 +7,9 @@
7
7
 
8
8
  #pragma once
9
9
 
10
- #ifdef __AVX2__
11
-
12
10
  #include <immintrin.h>
13
11
 
14
- #include <type_traits>
15
-
16
- #include <faiss/impl/ProductQuantizer.h>
17
- #include <faiss/impl/code_distance/code_distance-generic.h>
12
+ #include <faiss/impl/pq_code_distance/pq_code_distance-inl.h>
18
13
 
19
14
  // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=78782
20
15
  #if defined(__GNUC__) && __GNUC__ < 9
@@ -31,20 +26,17 @@ inline float horizontal_sum(const __m128 v) {
31
26
  return _mm_cvtss_f32(v3);
32
27
  }
33
28
 
34
- // Computes a horizontal sum over an __m256 register
29
+ // Computes a horizontal sum over an __m256 register.
35
30
  inline float horizontal_sum(const __m256 v) {
36
31
  const __m128 v0 =
37
32
  _mm_add_ps(_mm256_castps256_ps128(v), _mm256_extractf128_ps(v, 1));
38
33
  return horizontal_sum(v0);
39
34
  }
40
35
 
41
- // processes a single code for M=4, ksub=256, nbits=8
36
+ // Processes a single code for M=4, ksub=256, nbits=8.
42
37
  float inline distance_single_code_avx2_pqdecoder8_m4(
43
- // precomputed distances, layout (4, 256)
44
38
  const float* sim_table,
45
39
  const uint8_t* code) {
46
- float result = 0;
47
-
48
40
  const float* tab = sim_table;
49
41
  constexpr size_t ksub = 1 << 8;
50
42
 
@@ -52,39 +44,19 @@ float inline distance_single_code_avx2_pqdecoder8_m4(
52
44
  __m128i offsets_0 = _mm_setr_epi32(0, 1, 2, 3);
53
45
  offsets_0 = _mm_mullo_epi32(offsets_0, vksub);
54
46
 
55
- // accumulators of partial sums
56
- __m128 partialSum;
57
-
58
- // load 4 uint8 values
59
47
  const __m128i mm1 = _mm_cvtsi32_si128(*((const int32_t*)code));
60
- {
61
- // convert uint8 values (low part of __m128i) to int32
62
- // values
63
- const __m128i idx1 = _mm_cvtepu8_epi32(mm1);
64
-
65
- // add offsets
66
- const __m128i indices_to_read_from = _mm_add_epi32(idx1, offsets_0);
48
+ const __m128i idx1 = _mm_cvtepu8_epi32(mm1);
49
+ const __m128i indices_to_read_from = _mm_add_epi32(idx1, offsets_0);
50
+ __m128 collected =
51
+ _mm_i32gather_ps(tab, indices_to_read_from, sizeof(float));
67
52
 
68
- // gather 8 values, similar to 8 operations of tab[idx]
69
- __m128 collected =
70
- _mm_i32gather_ps(tab, indices_to_read_from, sizeof(float));
71
-
72
- // collect partial sums
73
- partialSum = collected;
74
- }
75
-
76
- // horizontal sum for partialSum
77
- result = horizontal_sum(partialSum);
78
- return result;
53
+ return horizontal_sum(collected);
79
54
  }
80
55
 
81
- // processes a single code for M=8, ksub=256, nbits=8
56
+ // Processes a single code for M=8, ksub=256, nbits=8.
82
57
  float inline distance_single_code_avx2_pqdecoder8_m8(
83
- // precomputed distances, layout (8, 256)
84
58
  const float* sim_table,
85
59
  const uint8_t* code) {
86
- float result = 0;
87
-
88
60
  const float* tab = sim_table;
89
61
  constexpr size_t ksub = 1 << 8;
90
62
 
@@ -92,42 +64,21 @@ float inline distance_single_code_avx2_pqdecoder8_m8(
92
64
  __m256i offsets_0 = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
93
65
  offsets_0 = _mm256_mullo_epi32(offsets_0, vksub);
94
66
 
95
- // accumulators of partial sums
96
- __m256 partialSum;
97
-
98
- // load 8 uint8 values
99
67
  const __m128i mm1 = _mm_loadu_si64((const __m128i_u*)code);
100
- {
101
- // convert uint8 values (low part of __m128i) to int32
102
- // values
103
- const __m256i idx1 = _mm256_cvtepu8_epi32(mm1);
104
-
105
- // add offsets
106
- const __m256i indices_to_read_from = _mm256_add_epi32(idx1, offsets_0);
107
-
108
- // gather 8 values, similar to 8 operations of tab[idx]
109
- __m256 collected =
110
- _mm256_i32gather_ps(tab, indices_to_read_from, sizeof(float));
68
+ const __m256i idx1 = _mm256_cvtepu8_epi32(mm1);
69
+ const __m256i indices_to_read_from = _mm256_add_epi32(idx1, offsets_0);
70
+ __m256 collected =
71
+ _mm256_i32gather_ps(tab, indices_to_read_from, sizeof(float));
111
72
 
112
- // collect partial sums
113
- partialSum = collected;
114
- }
115
-
116
- // horizontal sum for partialSum
117
- result = horizontal_sum(partialSum);
118
- return result;
73
+ return horizontal_sum(collected);
119
74
  }
120
75
 
121
- // processes four codes for M=4, ksub=256, nbits=8
122
76
  inline void distance_four_codes_avx2_pqdecoder8_m4(
123
- // precomputed distances, layout (4, 256)
124
77
  const float* sim_table,
125
- // codes
126
78
  const uint8_t* __restrict code0,
127
79
  const uint8_t* __restrict code1,
128
80
  const uint8_t* __restrict code2,
129
81
  const uint8_t* __restrict code3,
130
- // computed distances
131
82
  float& result0,
132
83
  float& result1,
133
84
  float& result2,
@@ -137,15 +88,12 @@ inline void distance_four_codes_avx2_pqdecoder8_m4(
137
88
  const float* tab = sim_table;
138
89
  constexpr size_t ksub = 1 << 8;
139
90
 
140
- // process 8 values
141
91
  const __m128i vksub = _mm_set1_epi32(ksub);
142
92
  __m128i offsets_0 = _mm_setr_epi32(0, 1, 2, 3);
143
93
  offsets_0 = _mm_mullo_epi32(offsets_0, vksub);
144
94
 
145
- // accumulators of partial sums
146
95
  __m128 partialSums[N];
147
96
 
148
- // load 4 uint8 values
149
97
  __m128i mm1[N];
150
98
  mm1[0] = _mm_cvtsi32_si128(*((const int32_t*)code0));
151
99
  mm1[1] = _mm_cvtsi32_si128(*((const int32_t*)code1));
@@ -153,38 +101,25 @@ inline void distance_four_codes_avx2_pqdecoder8_m4(
153
101
  mm1[3] = _mm_cvtsi32_si128(*((const int32_t*)code3));
154
102
 
155
103
  for (intptr_t j = 0; j < N; j++) {
156
- // convert uint8 values (low part of __m128i) to int32
157
- // values
158
104
  const __m128i idx1 = _mm_cvtepu8_epi32(mm1[j]);
159
-
160
- // add offsets
161
105
  const __m128i indices_to_read_from = _mm_add_epi32(idx1, offsets_0);
162
-
163
- // gather 4 values, similar to 4 operations of tab[idx]
164
106
  __m128 collected =
165
107
  _mm_i32gather_ps(tab, indices_to_read_from, sizeof(float));
166
-
167
- // collect partial sums
168
108
  partialSums[j] = collected;
169
109
  }
170
110
 
171
- // horizontal sum for partialSum
172
111
  result0 = horizontal_sum(partialSums[0]);
173
112
  result1 = horizontal_sum(partialSums[1]);
174
113
  result2 = horizontal_sum(partialSums[2]);
175
114
  result3 = horizontal_sum(partialSums[3]);
176
115
  }
177
116
 
178
- // processes four codes for M=8, ksub=256, nbits=8
179
117
  inline void distance_four_codes_avx2_pqdecoder8_m8(
180
- // precomputed distances, layout (8, 256)
181
118
  const float* sim_table,
182
- // codes
183
119
  const uint8_t* __restrict code0,
184
120
  const uint8_t* __restrict code1,
185
121
  const uint8_t* __restrict code2,
186
122
  const uint8_t* __restrict code3,
187
- // computed distances
188
123
  float& result0,
189
124
  float& result1,
190
125
  float& result2,
@@ -194,15 +129,12 @@ inline void distance_four_codes_avx2_pqdecoder8_m8(
194
129
  const float* tab = sim_table;
195
130
  constexpr size_t ksub = 1 << 8;
196
131
 
197
- // process 8 values
198
132
  const __m256i vksub = _mm256_set1_epi32(ksub);
199
133
  __m256i offsets_0 = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
200
134
  offsets_0 = _mm256_mullo_epi32(offsets_0, vksub);
201
135
 
202
- // accumulators of partial sums
203
136
  __m256 partialSums[N];
204
137
 
205
- // load 8 uint8 values
206
138
  __m128i mm1[N];
207
139
  mm1[0] = _mm_loadu_si64((const __m128i_u*)code0);
208
140
  mm1[1] = _mm_loadu_si64((const __m128i_u*)code1);
@@ -210,22 +142,13 @@ inline void distance_four_codes_avx2_pqdecoder8_m8(
210
142
  mm1[3] = _mm_loadu_si64((const __m128i_u*)code3);
211
143
 
212
144
  for (intptr_t j = 0; j < N; j++) {
213
- // convert uint8 values (low part of __m128i) to int32
214
- // values
215
145
  const __m256i idx1 = _mm256_cvtepu8_epi32(mm1[j]);
216
-
217
- // add offsets
218
146
  const __m256i indices_to_read_from = _mm256_add_epi32(idx1, offsets_0);
219
-
220
- // gather 8 values, similar to 8 operations of tab[idx]
221
147
  __m256 collected =
222
148
  _mm256_i32gather_ps(tab, indices_to_read_from, sizeof(float));
223
-
224
- // collect partial sums
225
149
  partialSums[j] = collected;
226
150
  }
227
151
 
228
- // horizontal sum for partialSum
229
152
  result0 = horizontal_sum(partialSums[0]);
230
153
  result1 = horizontal_sum(partialSums[1]);
231
154
  result2 = horizontal_sum(partialSums[2]);
@@ -235,31 +158,14 @@ inline void distance_four_codes_avx2_pqdecoder8_m8(
235
158
  } // namespace
236
159
 
237
160
  namespace faiss {
161
+ namespace pq_code_distance {
238
162
 
239
- template <typename PQDecoderT>
240
- typename std::enable_if<!std::is_same<PQDecoderT, PQDecoder8>::value, float>::
241
- type inline distance_single_code_avx2(
242
- // number of subquantizers
243
- const size_t M,
244
- // number of bits per quantization index
245
- const size_t nbits,
246
- // precomputed distances, layout (M, ksub)
247
- const float* sim_table,
248
- const uint8_t* code) {
249
- // default implementation
250
- return distance_single_code_generic<PQDecoderT>(M, nbits, sim_table, code);
251
- }
252
-
253
- template <typename PQDecoderT>
254
- typename std::enable_if<std::is_same<PQDecoderT, PQDecoder8>::value, float>::
255
- type inline distance_single_code_avx2(
256
- // number of subquantizers
257
- const size_t M,
258
- // number of bits per quantization index
259
- const size_t nbits,
260
- // precomputed distances, layout (M, ksub)
261
- const float* sim_table,
262
- const uint8_t* code) {
163
+ // NOLINTNEXTLINE(facebook-hte-MisplacedTemplateSpecialization)
164
+ template <>
165
+ float pq_code_distance_8bit_single_impl<SIMDLevel::AVX2>(
166
+ size_t M,
167
+ const float* sim_table,
168
+ const uint8_t* code) {
263
169
  if (M == 4) {
264
170
  return distance_single_code_avx2_pqdecoder8_m4(sim_table, code);
265
171
  }
@@ -267,6 +173,7 @@ typename std::enable_if<std::is_same<PQDecoderT, PQDecoder8>::value, float>::
267
173
  return distance_single_code_avx2_pqdecoder8_m8(sim_table, code);
268
174
  }
269
175
 
176
+ // Precomputed distances, layout (M, ksub).
270
177
  float result = 0;
271
178
  constexpr size_t ksub = 1 << 8;
272
179
 
@@ -276,67 +183,46 @@ typename std::enable_if<std::is_same<PQDecoderT, PQDecoder8>::value, float>::
276
183
  const float* tab = sim_table;
277
184
 
278
185
  if (pqM16 > 0) {
279
- // process 16 values per loop
280
-
281
186
  const __m256i vksub = _mm256_set1_epi32(ksub);
282
187
  __m256i offsets_0 = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
283
188
  offsets_0 = _mm256_mullo_epi32(offsets_0, vksub);
284
189
 
285
- // accumulators of partial sums
286
190
  __m256 partialSum = _mm256_setzero_ps();
287
191
 
288
- // loop
192
+ // Process 16 values per loop iteration.
289
193
  for (m = 0; m < pqM16 * 16; m += 16) {
290
- // load 16 uint8 values
291
194
  const __m128i mm1 = _mm_loadu_si128((const __m128i_u*)(code + m));
195
+ // Process first 8 codes.
292
196
  {
293
- // convert uint8 values (low part of __m128i) to int32
294
- // values
295
197
  const __m256i idx1 = _mm256_cvtepu8_epi32(mm1);
296
-
297
- // add offsets
298
198
  const __m256i indices_to_read_from =
299
199
  _mm256_add_epi32(idx1, offsets_0);
300
-
301
- // gather 8 values, similar to 8 operations of tab[idx]
302
200
  __m256 collected = _mm256_i32gather_ps(
303
201
  tab, indices_to_read_from, sizeof(float));
304
202
  tab += ksub * 8;
305
-
306
- // collect partial sums
307
203
  partialSum = _mm256_add_ps(partialSum, collected);
308
204
  }
309
205
 
310
- // move high 8 uint8 to low ones
206
+ // Process next 8 codes.
311
207
  const __m128i mm2 = _mm_unpackhi_epi64(mm1, _mm_setzero_si128());
312
208
  {
313
- // convert uint8 values (low part of __m128i) to int32
314
- // values
315
209
  const __m256i idx1 = _mm256_cvtepu8_epi32(mm2);
316
-
317
- // add offsets
318
210
  const __m256i indices_to_read_from =
319
211
  _mm256_add_epi32(idx1, offsets_0);
320
-
321
- // gather 8 values, similar to 8 operations of tab[idx]
322
212
  __m256 collected = _mm256_i32gather_ps(
323
213
  tab, indices_to_read_from, sizeof(float));
324
214
  tab += ksub * 8;
325
-
326
- // collect partial sums
327
215
  partialSum = _mm256_add_ps(partialSum, collected);
328
216
  }
329
217
  }
330
218
 
331
- // horizontal sum for partialSum
219
+ // Horizontal sum for partialSum.
332
220
  result += horizontal_sum(partialSum);
333
221
  }
334
222
 
335
- //
223
+ // Process leftovers.
336
224
  if (m < M) {
337
- // process leftovers
338
- PQDecoder8 decoder(code + m, nbits);
339
-
225
+ PQDecoder8 decoder(code + m, 8);
340
226
  for (; m < M; m++) {
341
227
  result += tab[decoder.decode()];
342
228
  tab += ksub;
@@ -346,56 +232,16 @@ typename std::enable_if<std::is_same<PQDecoderT, PQDecoder8>::value, float>::
346
232
  return result;
347
233
  }
348
234
 
349
- template <typename PQDecoderT>
350
- typename std::enable_if<!std::is_same<PQDecoderT, PQDecoder8>::value, void>::
351
- type
352
- distance_four_codes_avx2(
353
- // number of subquantizers
354
- const size_t M,
355
- // number of bits per quantization index
356
- const size_t nbits,
357
- // precomputed distances, layout (M, ksub)
358
- const float* sim_table,
359
- // codes
360
- const uint8_t* __restrict code0,
361
- const uint8_t* __restrict code1,
362
- const uint8_t* __restrict code2,
363
- const uint8_t* __restrict code3,
364
- // computed distances
365
- float& result0,
366
- float& result1,
367
- float& result2,
368
- float& result3) {
369
- distance_four_codes_generic<PQDecoderT>(
370
- M,
371
- nbits,
372
- sim_table,
373
- code0,
374
- code1,
375
- code2,
376
- code3,
377
- result0,
378
- result1,
379
- result2,
380
- result3);
381
- }
382
-
383
- // Combines 4 operations of distance_single_code()
384
- template <typename PQDecoderT>
385
- typename std::enable_if<std::is_same<PQDecoderT, PQDecoder8>::value, void>::type
386
- distance_four_codes_avx2(
387
- // number of subquantizers
388
- const size_t M,
389
- // number of bits per quantization index
390
- const size_t nbits,
391
- // precomputed distances, layout (M, ksub)
235
+ // Combines 4 operations of pq_code_distance_8bit_single_impl().
236
+ // NOLINTNEXTLINE(facebook-hte-MisplacedTemplateSpecialization)
237
+ template <>
238
+ void pq_code_distance_8bit_four_impl<SIMDLevel::AVX2>(
239
+ size_t M,
392
240
  const float* sim_table,
393
- // codes
394
241
  const uint8_t* __restrict code0,
395
242
  const uint8_t* __restrict code1,
396
243
  const uint8_t* __restrict code2,
397
244
  const uint8_t* __restrict code3,
398
- // computed distances
399
245
  float& result0,
400
246
  float& result1,
401
247
  float& result2,
@@ -427,6 +273,7 @@ distance_four_codes_avx2(
427
273
  return;
428
274
  }
429
275
 
276
+ // Precomputed distances, layout (M, ksub).
430
277
  result0 = 0;
431
278
  result1 = 0;
432
279
  result2 = 0;
@@ -441,84 +288,61 @@ distance_four_codes_avx2(
441
288
  const float* tab = sim_table;
442
289
 
443
290
  if (pqM16 > 0) {
444
- // process 16 values per loop
445
291
  const __m256i vksub = _mm256_set1_epi32(ksub);
446
292
  __m256i offsets_0 = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
447
293
  offsets_0 = _mm256_mullo_epi32(offsets_0, vksub);
448
294
 
449
- // accumulators of partial sums
450
295
  __m256 partialSums[N];
451
296
  for (intptr_t j = 0; j < N; j++) {
452
297
  partialSums[j] = _mm256_setzero_ps();
453
298
  }
454
299
 
455
- // loop
300
+ // Process 16 values per loop iteration.
456
301
  for (m = 0; m < pqM16 * 16; m += 16) {
457
- // load 16 uint8 values
458
302
  __m128i mm1[N];
459
303
  mm1[0] = _mm_loadu_si128((const __m128i_u*)(code0 + m));
460
304
  mm1[1] = _mm_loadu_si128((const __m128i_u*)(code1 + m));
461
305
  mm1[2] = _mm_loadu_si128((const __m128i_u*)(code2 + m));
462
306
  mm1[3] = _mm_loadu_si128((const __m128i_u*)(code3 + m));
463
307
 
464
- // process first 8 codes
308
+ // Process first 8 codes.
465
309
  for (intptr_t j = 0; j < N; j++) {
466
- // convert uint8 values (low part of __m128i) to int32
467
- // values
468
310
  const __m256i idx1 = _mm256_cvtepu8_epi32(mm1[j]);
469
-
470
- // add offsets
471
311
  const __m256i indices_to_read_from =
472
312
  _mm256_add_epi32(idx1, offsets_0);
473
-
474
- // gather 8 values, similar to 8 operations of tab[idx]
475
313
  __m256 collected = _mm256_i32gather_ps(
476
314
  tab, indices_to_read_from, sizeof(float));
477
-
478
- // collect partial sums
479
315
  partialSums[j] = _mm256_add_ps(partialSums[j], collected);
480
316
  }
481
317
  tab += ksub * 8;
482
318
 
483
- // process next 8 codes
319
+ // Process next 8 codes.
484
320
  for (intptr_t j = 0; j < N; j++) {
485
- // move high 8 uint8 to low ones
486
321
  const __m128i mm2 =
487
322
  _mm_unpackhi_epi64(mm1[j], _mm_setzero_si128());
488
-
489
- // convert uint8 values (low part of __m128i) to int32
490
- // values
491
323
  const __m256i idx1 = _mm256_cvtepu8_epi32(mm2);
492
-
493
- // add offsets
494
324
  const __m256i indices_to_read_from =
495
325
  _mm256_add_epi32(idx1, offsets_0);
496
-
497
- // gather 8 values, similar to 8 operations of tab[idx]
498
326
  __m256 collected = _mm256_i32gather_ps(
499
327
  tab, indices_to_read_from, sizeof(float));
500
-
501
- // collect partial sums
502
328
  partialSums[j] = _mm256_add_ps(partialSums[j], collected);
503
329
  }
504
330
 
505
331
  tab += ksub * 8;
506
332
  }
507
333
 
508
- // horizontal sum for partialSum
509
334
  result0 += horizontal_sum(partialSums[0]);
510
335
  result1 += horizontal_sum(partialSums[1]);
511
336
  result2 += horizontal_sum(partialSums[2]);
512
337
  result3 += horizontal_sum(partialSums[3]);
513
338
  }
514
339
 
515
- //
340
+ // Process leftovers.
516
341
  if (m < M) {
517
- // process leftovers
518
- PQDecoder8 decoder0(code0 + m, nbits);
519
- PQDecoder8 decoder1(code1 + m, nbits);
520
- PQDecoder8 decoder2(code2 + m, nbits);
521
- PQDecoder8 decoder3(code3 + m, nbits);
342
+ PQDecoder8 decoder0(code0 + m, 8);
343
+ PQDecoder8 decoder1(code1 + m, 8);
344
+ PQDecoder8 decoder2(code2 + m, 8);
345
+ PQDecoder8 decoder3(code3 + m, 8);
522
346
  for (; m < M; m++) {
523
347
  result0 += tab[decoder0.decode()];
524
348
  result1 += tab[decoder1.decode()];
@@ -529,6 +353,5 @@ distance_four_codes_avx2(
529
353
  }
530
354
  }
531
355
 
356
+ } // namespace pq_code_distance
532
357
  } // namespace faiss
533
-
534
- #endif