faiss 0.5.3 → 0.6.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (379) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +12 -0
  3. data/ext/faiss/ext.cpp +1 -1
  4. data/ext/faiss/extconf.rb +4 -4
  5. data/ext/faiss/index.cpp +63 -45
  6. data/ext/faiss/index_binary.cpp +37 -27
  7. data/ext/faiss/kmeans.cpp +9 -8
  8. data/ext/faiss/pca_matrix.cpp +9 -7
  9. data/ext/faiss/product_quantizer.cpp +13 -11
  10. data/ext/faiss/utils.cpp +4 -2
  11. data/ext/faiss/utils.h +4 -0
  12. data/lib/faiss/version.rb +1 -1
  13. data/lib/faiss.rb +1 -1
  14. data/vendor/faiss/faiss/AutoTune.cpp +214 -82
  15. data/vendor/faiss/faiss/AutoTune.h +14 -1
  16. data/vendor/faiss/faiss/Clustering.cpp +97 -249
  17. data/vendor/faiss/faiss/Clustering.h +18 -0
  18. data/vendor/faiss/faiss/IVFlib.cpp +67 -44
  19. data/vendor/faiss/faiss/Index.cpp +25 -12
  20. data/vendor/faiss/faiss/Index.h +26 -4
  21. data/vendor/faiss/faiss/Index2Layer.cpp +37 -53
  22. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +68 -61
  23. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +36 -34
  24. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +4 -1
  25. data/vendor/faiss/faiss/IndexBinary.cpp +6 -3
  26. data/vendor/faiss/faiss/IndexBinary.h +4 -4
  27. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +1 -1
  28. data/vendor/faiss/faiss/IndexBinaryFlat.h +1 -1
  29. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -4
  30. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +92 -95
  31. data/vendor/faiss/faiss/IndexBinaryHNSW.h +9 -3
  32. data/vendor/faiss/faiss/IndexBinaryHash.cpp +45 -236
  33. data/vendor/faiss/faiss/IndexBinaryHash.h +6 -6
  34. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +120 -414
  35. data/vendor/faiss/faiss/IndexFastScan.cpp +105 -129
  36. data/vendor/faiss/faiss/IndexFastScan.h +35 -24
  37. data/vendor/faiss/faiss/IndexFlat.cpp +216 -152
  38. data/vendor/faiss/faiss/IndexFlat.h +32 -14
  39. data/vendor/faiss/faiss/IndexFlatCodes.cpp +88 -41
  40. data/vendor/faiss/faiss/IndexFlatCodes.h +7 -1
  41. data/vendor/faiss/faiss/IndexHNSW.cpp +299 -187
  42. data/vendor/faiss/faiss/IndexHNSW.h +30 -14
  43. data/vendor/faiss/faiss/IndexIDMap.cpp +26 -22
  44. data/vendor/faiss/faiss/IndexIDMap.h +9 -7
  45. data/vendor/faiss/faiss/IndexIVF.cpp +535 -405
  46. data/vendor/faiss/faiss/IndexIVF.h +47 -16
  47. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +77 -74
  48. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +105 -99
  49. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +6 -3
  50. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +379 -249
  51. data/vendor/faiss/faiss/IndexIVFFastScan.h +65 -60
  52. data/vendor/faiss/faiss/IndexIVFFlat.cpp +41 -124
  53. data/vendor/faiss/faiss/IndexIVFFlat.h +32 -0
  54. data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +89 -138
  55. data/vendor/faiss/faiss/IndexIVFFlatPanorama.h +3 -1
  56. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +18 -15
  57. data/vendor/faiss/faiss/IndexIVFPQ.cpp +77 -907
  58. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +184 -122
  59. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +3 -0
  60. data/vendor/faiss/faiss/IndexIVFPQR.cpp +23 -18
  61. data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +59 -60
  62. data/vendor/faiss/faiss/IndexIVFRaBitQ.h +4 -3
  63. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +564 -416
  64. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +269 -111
  65. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +41 -127
  66. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +1 -1
  67. data/vendor/faiss/faiss/IndexLSH.cpp +44 -25
  68. data/vendor/faiss/faiss/IndexLattice.cpp +41 -36
  69. data/vendor/faiss/faiss/IndexNNDescent.cpp +37 -21
  70. data/vendor/faiss/faiss/IndexNNDescent.h +2 -2
  71. data/vendor/faiss/faiss/IndexNSG.cpp +40 -23
  72. data/vendor/faiss/faiss/IndexNSG.h +0 -2
  73. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +32 -12
  74. data/vendor/faiss/faiss/IndexPQ.cpp +129 -213
  75. data/vendor/faiss/faiss/IndexPQ.h +3 -2
  76. data/vendor/faiss/faiss/IndexPQFastScan.cpp +20 -14
  77. data/vendor/faiss/faiss/IndexPQFastScan.h +3 -0
  78. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -18
  79. data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
  80. data/vendor/faiss/faiss/IndexRaBitQ.cpp +31 -43
  81. data/vendor/faiss/faiss/IndexRaBitQ.h +4 -3
  82. data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +135 -317
  83. data/vendor/faiss/faiss/IndexRaBitQFastScan.h +192 -34
  84. data/vendor/faiss/faiss/IndexRefine.cpp +30 -55
  85. data/vendor/faiss/faiss/IndexRefine.h +4 -4
  86. data/vendor/faiss/faiss/IndexReplicas.cpp +6 -6
  87. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +15 -14
  88. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +1 -1
  89. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +82 -14
  90. data/vendor/faiss/faiss/IndexShards.cpp +13 -13
  91. data/vendor/faiss/faiss/IndexShardsIVF.cpp +21 -15
  92. data/vendor/faiss/faiss/MatrixStats.cpp +5 -4
  93. data/vendor/faiss/faiss/MetaIndexes.cpp +19 -17
  94. data/vendor/faiss/faiss/MetaIndexes.h +1 -1
  95. data/vendor/faiss/faiss/MetricType.h +29 -6
  96. data/vendor/faiss/faiss/SuperKMeans.cpp +656 -0
  97. data/vendor/faiss/faiss/SuperKMeans.h +97 -0
  98. data/vendor/faiss/faiss/VectorTransform.cpp +349 -141
  99. data/vendor/faiss/faiss/VectorTransform.h +39 -16
  100. data/vendor/faiss/faiss/build.cpp +23 -0
  101. data/vendor/faiss/faiss/build.h +15 -0
  102. data/vendor/faiss/faiss/clone_index.cpp +55 -51
  103. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +47 -47
  104. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +11 -0
  105. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +38 -38
  106. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +11 -0
  107. data/vendor/faiss/faiss/{cppcontrib/factory_tools.cpp → factory_tools.cpp} +6 -1
  108. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +1 -1
  109. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +6 -5
  110. data/vendor/faiss/faiss/gpu/GpuResources.h +1 -1
  111. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +9 -9
  112. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +4 -3
  113. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +46 -0
  114. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +56 -0
  115. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +78 -1
  116. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +72 -0
  117. data/vendor/faiss/faiss/gpu/test/TestUtils.h +23 -0
  118. data/vendor/faiss/faiss/gpu/utils/CuvsFilterConvert.h +1 -1
  119. data/vendor/faiss/faiss/gpu/utils/CuvsUtils.h +21 -10
  120. data/vendor/faiss/faiss/gpu_metal/GpuIndexFlat.h +22 -0
  121. data/vendor/faiss/faiss/gpu_metal/MetalCloner.h +35 -0
  122. data/vendor/faiss/faiss/gpu_metal/MetalFlatKernels.h +40 -0
  123. data/vendor/faiss/faiss/gpu_metal/MetalIndex.h +51 -0
  124. data/vendor/faiss/faiss/gpu_metal/MetalIndexFlat.h +65 -0
  125. data/vendor/faiss/faiss/gpu_metal/MetalKernels.h +66 -0
  126. data/vendor/faiss/faiss/gpu_metal/MetalResources.h +79 -0
  127. data/vendor/faiss/faiss/gpu_metal/StandardMetalResources.h +35 -0
  128. data/vendor/faiss/faiss/impl/AdSampling.cpp +103 -0
  129. data/vendor/faiss/faiss/impl/AdSampling.h +35 -0
  130. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +64 -34
  131. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +1 -0
  132. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +10 -9
  133. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +3 -28
  134. data/vendor/faiss/faiss/impl/ClusteringHelpers.cpp +244 -0
  135. data/vendor/faiss/faiss/impl/ClusteringHelpers.h +94 -0
  136. data/vendor/faiss/faiss/impl/ClusteringInitialization.cpp +367 -0
  137. data/vendor/faiss/faiss/impl/ClusteringInitialization.h +107 -0
  138. data/vendor/faiss/faiss/impl/CodePacker.cpp +7 -3
  139. data/vendor/faiss/faiss/impl/CodePacker.h +11 -3
  140. data/vendor/faiss/faiss/impl/CodePackerRaBitQ.cpp +83 -0
  141. data/vendor/faiss/faiss/impl/CodePackerRaBitQ.h +47 -0
  142. data/vendor/faiss/faiss/impl/DistanceComputer.h +8 -8
  143. data/vendor/faiss/faiss/impl/FaissAssert.h +64 -3
  144. data/vendor/faiss/faiss/impl/FaissException.h +50 -3
  145. data/vendor/faiss/faiss/impl/HNSW.cpp +117 -351
  146. data/vendor/faiss/faiss/impl/HNSW.h +21 -40
  147. data/vendor/faiss/faiss/impl/IDSelector.cpp +15 -11
  148. data/vendor/faiss/faiss/impl/IDSelector.h +8 -8
  149. data/vendor/faiss/faiss/impl/InvertedListScannerStats.h +26 -0
  150. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +114 -102
  151. data/vendor/faiss/faiss/impl/NNDescent.cpp +63 -26
  152. data/vendor/faiss/faiss/impl/NNDescent.h +6 -2
  153. data/vendor/faiss/faiss/impl/NSG.cpp +44 -26
  154. data/vendor/faiss/faiss/impl/NSG.h +20 -10
  155. data/vendor/faiss/faiss/impl/Panorama.cpp +76 -52
  156. data/vendor/faiss/faiss/impl/Panorama.h +265 -78
  157. data/vendor/faiss/faiss/impl/PdxLayout.cpp +93 -0
  158. data/vendor/faiss/faiss/impl/PdxLayout.h +41 -0
  159. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +62 -37
  160. data/vendor/faiss/faiss/impl/PolysemousTraining.h +3 -3
  161. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +35 -35
  162. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +21 -16
  163. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +99 -80
  164. data/vendor/faiss/faiss/impl/Quantizer.h +2 -2
  165. data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +135 -37
  166. data/vendor/faiss/faiss/impl/RaBitQUtils.h +148 -21
  167. data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +298 -301
  168. data/vendor/faiss/faiss/impl/RaBitQuantizer.h +3 -10
  169. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.cpp +15 -41
  170. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.h +0 -4
  171. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +40 -32
  172. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +1 -1
  173. data/vendor/faiss/faiss/impl/ResultHandler.h +218 -113
  174. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +119 -2362
  175. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +27 -3
  176. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +14 -11
  177. data/vendor/faiss/faiss/impl/VisitedTable.cpp +42 -0
  178. data/vendor/faiss/faiss/impl/VisitedTable.h +76 -0
  179. data/vendor/faiss/faiss/impl/approx_topk/approx_topk.h +276 -0
  180. data/vendor/faiss/faiss/impl/approx_topk/avx2.cpp +68 -0
  181. data/vendor/faiss/faiss/{utils → impl}/approx_topk/generic.h +15 -8
  182. data/vendor/faiss/faiss/impl/approx_topk/neon.cpp +68 -0
  183. data/vendor/faiss/faiss/impl/approx_topk/rq_beam_search_tab-inl.h +169 -0
  184. data/vendor/faiss/faiss/impl/approx_topk/rq_beam_search_tab.h +117 -0
  185. data/vendor/faiss/faiss/impl/approx_topk/simdlib256-inl.h +146 -0
  186. data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryHNSW_impl.h +73 -0
  187. data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryHash_impl.h +270 -0
  188. data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryIVF_impl.h +460 -0
  189. data/vendor/faiss/faiss/impl/binary_hamming/IndexIVFSpectralHash_impl.h +159 -0
  190. data/vendor/faiss/faiss/impl/binary_hamming/IndexPQ_impl.h +92 -0
  191. data/vendor/faiss/faiss/impl/binary_hamming/avx2.cpp +26 -0
  192. data/vendor/faiss/faiss/impl/binary_hamming/avx512.cpp +26 -0
  193. data/vendor/faiss/faiss/impl/binary_hamming/dispatch.h +143 -0
  194. data/vendor/faiss/faiss/impl/binary_hamming/neon.cpp +26 -0
  195. data/vendor/faiss/faiss/impl/binary_hamming/rvv.cpp +26 -0
  196. data/vendor/faiss/faiss/impl/expanded_scanners.h +163 -0
  197. data/vendor/faiss/faiss/impl/{FastScanDistancePostProcessing.h → fast_scan/FastScanDistancePostProcessing.h} +13 -6
  198. data/vendor/faiss/faiss/impl/{LookupTableScaler.h → fast_scan/LookupTableScaler.h} +16 -5
  199. data/vendor/faiss/faiss/impl/fast_scan/accumulate_loops.h +237 -0
  200. data/vendor/faiss/faiss/impl/fast_scan/accumulate_loops_512.h +185 -0
  201. data/vendor/faiss/faiss/impl/fast_scan/decompose_qbs.h +229 -0
  202. data/vendor/faiss/faiss/impl/fast_scan/dispatching.h +268 -0
  203. data/vendor/faiss/faiss/impl/{pq4_fast_scan.cpp → fast_scan/fast_scan.cpp} +176 -4
  204. data/vendor/faiss/faiss/impl/fast_scan/fast_scan.h +341 -0
  205. data/vendor/faiss/faiss/impl/fast_scan/impl-avx2.cpp +36 -0
  206. data/vendor/faiss/faiss/impl/fast_scan/impl-avx512.cpp +40 -0
  207. data/vendor/faiss/faiss/impl/fast_scan/impl-neon.cpp +120 -0
  208. data/vendor/faiss/faiss/impl/fast_scan/impl-riscv.cpp +104 -0
  209. data/vendor/faiss/faiss/impl/fast_scan/kernels_simd256.h +213 -0
  210. data/vendor/faiss/faiss/impl/{pq4_fast_scan_search_qbs.cpp → fast_scan/kernels_simd512.h} +26 -348
  211. data/vendor/faiss/faiss/impl/fast_scan/rabitq_dispatching.h +90 -0
  212. data/vendor/faiss/faiss/impl/fast_scan/rabitq_result_handler.h +108 -0
  213. data/vendor/faiss/faiss/impl/{simd_result_handlers.h → fast_scan/simd_result_handlers.h} +290 -142
  214. data/vendor/faiss/faiss/impl/hnsw/LockVector.cpp +54 -0
  215. data/vendor/faiss/faiss/impl/hnsw/LockVector.h +64 -0
  216. data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.cpp +91 -0
  217. data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.h +64 -0
  218. data/vendor/faiss/faiss/impl/hnsw/avx2.cpp +104 -0
  219. data/vendor/faiss/faiss/impl/hnsw/avx512.cpp +111 -0
  220. data/vendor/faiss/faiss/impl/index_read.cpp +1950 -505
  221. data/vendor/faiss/faiss/impl/index_read_utils.h +1 -2
  222. data/vendor/faiss/faiss/impl/index_write.cpp +112 -21
  223. data/vendor/faiss/faiss/impl/io.cpp +6 -6
  224. data/vendor/faiss/faiss/impl/io_macros.h +33 -16
  225. data/vendor/faiss/faiss/impl/kmeans1d.cpp +10 -10
  226. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +81 -40
  227. data/vendor/faiss/faiss/impl/lattice_Zn.h +6 -6
  228. data/vendor/faiss/faiss/impl/mapped_io.cpp +15 -8
  229. data/vendor/faiss/faiss/impl/platform_macros.h +11 -4
  230. data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQScanner_impl.h +549 -0
  231. data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQ_QueryTables.cpp +245 -0
  232. data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQ_QueryTables.h +105 -0
  233. data/vendor/faiss/faiss/impl/pq_code_distance/PQDistanceComputer_impl.h +106 -0
  234. data/vendor/faiss/faiss/impl/pq_code_distance/avx2.cpp +21 -0
  235. data/vendor/faiss/faiss/impl/pq_code_distance/avx512.cpp +21 -0
  236. data/vendor/faiss/faiss/impl/pq_code_distance/neon.cpp +21 -0
  237. data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx2.h → pq_code_distance/pq_code_distance-avx2.h} +43 -220
  238. data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx512.h → pq_code_distance/pq_code_distance-avx512.h} +25 -112
  239. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.cpp +59 -0
  240. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.h +96 -0
  241. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-inl.h +256 -0
  242. data/vendor/faiss/faiss/impl/{code_distance/code_distance-sve.h → pq_code_distance/pq_code_distance-sve.cpp} +57 -146
  243. data/vendor/faiss/faiss/impl/pq_code_distance/rvv.cpp +68 -0
  244. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +320 -483
  245. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +1 -1
  246. data/vendor/faiss/faiss/impl/scalar_quantizer/codecs.h +121 -0
  247. data/vendor/faiss/faiss/impl/scalar_quantizer/distance_computers.h +137 -0
  248. data/vendor/faiss/faiss/impl/scalar_quantizer/quantizers.h +371 -0
  249. data/vendor/faiss/faiss/impl/scalar_quantizer/scanners.h +190 -0
  250. data/vendor/faiss/faiss/impl/scalar_quantizer/similarities.h +94 -0
  251. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx2.cpp +603 -0
  252. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512.cpp +597 -0
  253. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-dispatch.h +388 -0
  254. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-neon.cpp +630 -0
  255. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-rvv.cpp +311 -0
  256. data/vendor/faiss/faiss/impl/scalar_quantizer/training.cpp +387 -0
  257. data/vendor/faiss/faiss/impl/scalar_quantizer/training.h +54 -0
  258. data/vendor/faiss/faiss/impl/simd_dispatch.h +173 -0
  259. data/vendor/faiss/faiss/impl/simdlib/simdlib.h +57 -0
  260. data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_avx2.h +274 -171
  261. data/vendor/faiss/faiss/impl/simdlib/simdlib_avx512.h +414 -0
  262. data/vendor/faiss/faiss/impl/simdlib/simdlib_dispatch.h +44 -0
  263. data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_emulated.h +231 -166
  264. data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_neon.h +275 -217
  265. data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_ppc64.h +201 -160
  266. data/vendor/faiss/faiss/impl/svs_io.cpp +12 -3
  267. data/vendor/faiss/faiss/impl/svs_io.h +8 -2
  268. data/vendor/faiss/faiss/index_factory.cpp +115 -28
  269. data/vendor/faiss/faiss/index_io.h +53 -3
  270. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +73 -20
  271. data/vendor/faiss/faiss/invlists/DirectMap.cpp +24 -14
  272. data/vendor/faiss/faiss/invlists/DirectMap.h +4 -3
  273. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +157 -73
  274. data/vendor/faiss/faiss/invlists/InvertedLists.h +86 -23
  275. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +4 -4
  276. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +14 -14
  277. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
  278. data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +9 -19
  279. data/vendor/faiss/faiss/svs/IndexSVSFlat.cpp +2 -2
  280. data/vendor/faiss/faiss/svs/IndexSVSFlat.h +2 -0
  281. data/vendor/faiss/faiss/svs/IndexSVSIVF.cpp +350 -0
  282. data/vendor/faiss/faiss/svs/IndexSVSIVF.h +128 -0
  283. data/vendor/faiss/faiss/svs/IndexSVSIVFLVQ.cpp +40 -0
  284. data/vendor/faiss/faiss/svs/IndexSVSIVFLVQ.h +43 -0
  285. data/vendor/faiss/faiss/svs/IndexSVSIVFLeanVec.cpp +225 -0
  286. data/vendor/faiss/faiss/svs/IndexSVSIVFLeanVec.h +71 -0
  287. data/vendor/faiss/faiss/svs/IndexSVSVamana.cpp +25 -1
  288. data/vendor/faiss/faiss/svs/IndexSVSVamana.h +19 -2
  289. data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.h +1 -1
  290. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +19 -2
  291. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +14 -0
  292. data/vendor/faiss/faiss/utils/Heap.cpp +56 -10
  293. data/vendor/faiss/faiss/utils/Heap.h +21 -0
  294. data/vendor/faiss/faiss/utils/NeuralNet.cpp +54 -40
  295. data/vendor/faiss/faiss/utils/NeuralNet.h +1 -1
  296. data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +10 -4
  297. data/vendor/faiss/faiss/utils/distances.cpp +507 -559
  298. data/vendor/faiss/faiss/utils/distances.h +118 -1
  299. data/vendor/faiss/faiss/utils/distances_dispatch.h +250 -0
  300. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +8 -7
  301. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +33 -14
  302. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +12 -1
  303. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +16 -293
  304. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based_neon.cpp +57 -0
  305. data/vendor/faiss/faiss/utils/distances_fused/simdlib_kernel-inl.h +290 -0
  306. data/vendor/faiss/faiss/utils/distances_simd.cpp +72 -3681
  307. data/vendor/faiss/faiss/utils/extra_distances.cpp +60 -102
  308. data/vendor/faiss/faiss/utils/extra_distances.h +79 -7
  309. data/vendor/faiss/faiss/utils/hamming-inl.h +13 -11
  310. data/vendor/faiss/faiss/utils/hamming.cpp +66 -517
  311. data/vendor/faiss/faiss/utils/hamming.h +92 -2
  312. data/vendor/faiss/faiss/utils/hamming_distance/common.h +287 -10
  313. data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx2.cpp +15 -0
  314. data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx512.cpp +15 -0
  315. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx2.h +142 -0
  316. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx512.h +234 -0
  317. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-generic.h +368 -0
  318. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-neon.h +322 -0
  319. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-rvv.h +39 -0
  320. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer.h +146 -0
  321. data/vendor/faiss/faiss/utils/hamming_distance/hamming_impl.h +481 -0
  322. data/vendor/faiss/faiss/utils/hamming_distance/hamming_neon.cpp +15 -0
  323. data/vendor/faiss/faiss/utils/hamming_distance/hamming_rvv.cpp +15 -0
  324. data/vendor/faiss/faiss/utils/partitioning.cpp +66 -987
  325. data/vendor/faiss/faiss/utils/partitioning.h +31 -0
  326. data/vendor/faiss/faiss/utils/popcount.h +29 -0
  327. data/vendor/faiss/faiss/utils/pq_code_distance.h +251 -0
  328. data/vendor/faiss/faiss/utils/prefetch.h +2 -2
  329. data/vendor/faiss/faiss/utils/quantize_lut.cpp +30 -30
  330. data/vendor/faiss/faiss/utils/quantize_lut.h +1 -1
  331. data/vendor/faiss/faiss/utils/rabitq_simd.h +124 -343
  332. data/vendor/faiss/faiss/utils/random.cpp +6 -6
  333. data/vendor/faiss/faiss/utils/simd_impl/IVFFlatScanner-inl.h +51 -0
  334. data/vendor/faiss/faiss/utils/simd_impl/distances_aarch64.cpp +154 -0
  335. data/vendor/faiss/faiss/utils/simd_impl/distances_arm_sve.cpp +777 -0
  336. data/vendor/faiss/faiss/utils/simd_impl/distances_autovec-inl.h +306 -0
  337. data/vendor/faiss/faiss/utils/simd_impl/distances_avx2.cpp +1431 -0
  338. data/vendor/faiss/faiss/utils/simd_impl/distances_avx512.cpp +1095 -0
  339. data/vendor/faiss/faiss/utils/simd_impl/distances_rvv.cpp +189 -0
  340. data/vendor/faiss/faiss/utils/simd_impl/distances_simdlib256.h +195 -0
  341. data/vendor/faiss/faiss/utils/simd_impl/distances_sse-inl.h +392 -0
  342. data/vendor/faiss/faiss/utils/{distances_fused/simdlib_based.h → simd_impl/exhaustive_L2sqr_blas_cmax.h} +5 -10
  343. data/vendor/faiss/faiss/utils/simd_impl/hamming_impl.h +481 -0
  344. data/vendor/faiss/faiss/utils/simd_impl/partitioning_avx2.cpp +14 -0
  345. data/vendor/faiss/faiss/utils/simd_impl/partitioning_neon.cpp +14 -0
  346. data/vendor/faiss/faiss/utils/simd_impl/partitioning_simdlib256.h +1085 -0
  347. data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx2.cpp +355 -0
  348. data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx512.cpp +477 -0
  349. data/vendor/faiss/faiss/utils/simd_impl/rabitq_neon.cpp +55 -0
  350. data/vendor/faiss/faiss/utils/simd_impl/rabitq_rvv.cpp +55 -0
  351. data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_dispatch.h +32 -0
  352. data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels.h +43 -0
  353. data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels_avx2.cpp +57 -0
  354. data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels_avx512.cpp +45 -0
  355. data/vendor/faiss/faiss/utils/simd_levels.cpp +334 -0
  356. data/vendor/faiss/faiss/utils/simd_levels.h +183 -0
  357. data/vendor/faiss/faiss/utils/sorting.cpp +48 -36
  358. data/vendor/faiss/faiss/utils/utils.cpp +21 -14
  359. data/vendor/faiss/faiss/utils/utils.h +3 -3
  360. metadata +156 -42
  361. data/vendor/faiss/faiss/impl/RaBitQStats.cpp +0 -29
  362. data/vendor/faiss/faiss/impl/RaBitQStats.h +0 -56
  363. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +0 -81
  364. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +0 -186
  365. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +0 -216
  366. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +0 -224
  367. data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +0 -84
  368. data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +0 -196
  369. data/vendor/faiss/faiss/utils/approx_topk/mode.h +0 -34
  370. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +0 -36
  371. data/vendor/faiss/faiss/utils/extra_distances-inl.h +0 -228
  372. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +0 -462
  373. data/vendor/faiss/faiss/utils/hamming_distance/avx512-inl.h +0 -490
  374. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +0 -450
  375. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +0 -87
  376. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +0 -524
  377. data/vendor/faiss/faiss/utils/simdlib.h +0 -42
  378. data/vendor/faiss/faiss/utils/simdlib_avx512.h +0 -296
  379. /data/vendor/faiss/faiss/{cppcontrib/factory_tools.h → factory_tools.h} +0 -0
@@ -13,21 +13,19 @@
13
13
  #include <cstddef>
14
14
  #include <cstdio>
15
15
  #include <cstring>
16
+ #include <vector>
16
17
 
17
18
  #include <omp.h>
18
19
 
19
- #ifdef __AVX2__
20
- #include <immintrin.h>
21
- #elif defined(__ARM_FEATURE_SVE)
22
- #include <arm_sve.h>
23
- #endif
24
-
25
20
  #include <faiss/impl/AuxIndexStructures.h>
26
21
  #include <faiss/impl/FaissAssert.h>
27
22
  #include <faiss/impl/IDSelector.h>
28
23
  #include <faiss/impl/ResultHandler.h>
29
24
 
25
+ #include <faiss/impl/simd_dispatch.h>
26
+ #include <faiss/utils/distances_dispatch.h>
30
27
  #include <faiss/utils/distances_fused/distances_fused.h>
28
+ #include <faiss/utils/simd_impl/exhaustive_L2sqr_blas_cmax.h>
31
29
 
32
30
  #ifndef FINTEGER
33
31
  #define FINTEGER long
@@ -55,6 +53,146 @@ int sgemm_(
55
53
 
56
54
  namespace faiss {
57
55
 
56
+ /***************************************************************************
57
+ * Public API dispatch wrappers
58
+ ***************************************************************************/
59
+
60
+ float fvec_L1(const float* x, const float* y, size_t d) {
61
+ return fvec_L1_dispatch(x, y, d);
62
+ }
63
+
64
+ float fvec_Linf(const float* x, const float* y, size_t d) {
65
+ return fvec_Linf_dispatch(x, y, d);
66
+ }
67
+
68
+ float fvec_norm_L2sqr(const float* x, size_t d) {
69
+ return fvec_norm_L2sqr_dispatch(x, d);
70
+ }
71
+
72
+ float fvec_L2sqr(const float* x, const float* y, size_t d) {
73
+ return fvec_L2sqr_dispatch(x, y, d);
74
+ }
75
+
76
+ float fvec_inner_product(const float* x, const float* y, size_t d) {
77
+ return fvec_inner_product_dispatch(x, y, d);
78
+ }
79
+
80
+ void fvec_inner_product_batch_4(
81
+ const float* x,
82
+ const float* y0,
83
+ const float* y1,
84
+ const float* y2,
85
+ const float* y3,
86
+ const size_t d,
87
+ float& dis0,
88
+ float& dis1,
89
+ float& dis2,
90
+ float& dis3) {
91
+ fvec_inner_product_batch_4_dispatch(
92
+ x, y0, y1, y2, y3, d, dis0, dis1, dis2, dis3);
93
+ }
94
+
95
+ void fvec_L2sqr_batch_4(
96
+ const float* x,
97
+ const float* y0,
98
+ const float* y1,
99
+ const float* y2,
100
+ const float* y3,
101
+ const size_t d,
102
+ float& dis0,
103
+ float& dis1,
104
+ float& dis2,
105
+ float& dis3) {
106
+ fvec_L2sqr_batch_4_dispatch(x, y0, y1, y2, y3, d, dis0, dis1, dis2, dis3);
107
+ }
108
+
109
+ void fvec_L2sqr_ny_transposed(
110
+ float* dis,
111
+ const float* x,
112
+ const float* y,
113
+ const float* y_sqlen,
114
+ size_t d,
115
+ size_t d_offset,
116
+ size_t ny) {
117
+ fvec_L2sqr_ny_transposed_dispatch(dis, x, y, y_sqlen, d, d_offset, ny);
118
+ }
119
+
120
+ void fvec_inner_products_ny(
121
+ float* ip,
122
+ const float* x,
123
+ const float* y,
124
+ size_t d,
125
+ size_t ny) {
126
+ fvec_inner_products_ny_dispatch(ip, x, y, d, ny);
127
+ }
128
+
129
+ void fvec_L2sqr_ny(
130
+ float* dis,
131
+ const float* x,
132
+ const float* y,
133
+ size_t d,
134
+ size_t ny) {
135
+ fvec_L2sqr_ny_dispatch(dis, x, y, d, ny);
136
+ }
137
+
138
+ size_t fvec_L2sqr_ny_nearest(
139
+ float* distances_tmp_buffer,
140
+ const float* x,
141
+ const float* y,
142
+ size_t d,
143
+ size_t ny) {
144
+ return fvec_L2sqr_ny_nearest_dispatch(distances_tmp_buffer, x, y, d, ny);
145
+ }
146
+
147
+ size_t fvec_L2sqr_ny_nearest_y_transposed(
148
+ float* distances_tmp_buffer,
149
+ const float* x,
150
+ const float* y,
151
+ const float* y_sqlen,
152
+ size_t d,
153
+ size_t d_offset,
154
+ size_t ny) {
155
+ return fvec_L2sqr_ny_nearest_y_transposed_dispatch(
156
+ distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny);
157
+ }
158
+
159
+ void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c) {
160
+ fvec_madd_dispatch(n, a, bf, b, c);
161
+ }
162
+
163
+ int fvec_madd_and_argmin(
164
+ size_t n,
165
+ const float* a,
166
+ float bf,
167
+ const float* b,
168
+ float* c) {
169
+ return fvec_madd_and_argmin_dispatch(n, a, bf, b, c);
170
+ }
171
+
172
+ void fvec_sub(size_t d, const float* a, const float* b, float* c) {
173
+ fvec_sub_dispatch(d, a, b, c);
174
+ }
175
+
176
+ void fvec_add(size_t d, const float* a, const float* b, float* c) {
177
+ fvec_add_dispatch(d, a, b, c);
178
+ }
179
+
180
+ void fvec_add(size_t d, const float* a, float b, float* c) {
181
+ fvec_add_scalar_dispatch(d, a, b, c);
182
+ }
183
+
184
+ void compute_PQ_dis_tables_dsub2(
185
+ size_t d,
186
+ size_t ksub,
187
+ const float* all_centroids,
188
+ size_t nx,
189
+ const float* x,
190
+ bool is_inner_product,
191
+ float* dis_tables) {
192
+ compute_PQ_dis_tables_dsub2_dispatch(
193
+ d, ksub, all_centroids, nx, x, is_inner_product, dis_tables);
194
+ }
195
+
58
196
  /***************************************************************************
59
197
  * Matrix/vector ops
60
198
  ***************************************************************************/
@@ -65,10 +203,12 @@ void fvec_norms_L2(
65
203
  const float* __restrict x,
66
204
  size_t d,
67
205
  size_t nx) {
206
+ with_simd_level([&]<SIMDLevel SL>() {
68
207
  #pragma omp parallel for if (nx > 10000)
69
- for (int64_t i = 0; i < nx; i++) {
70
- nr[i] = sqrtf(fvec_norm_L2sqr(x + i * d, d));
71
- }
208
+ for (int64_t i = 0; i < static_cast<int64_t>(nx); i++) {
209
+ nr[i] = sqrtf(fvec_norm_L2sqr<SL>(x + i * d, d));
210
+ }
211
+ });
72
212
  }
73
213
 
74
214
  void fvec_norms_L2sqr(
@@ -76,10 +216,12 @@ void fvec_norms_L2sqr(
76
216
  const float* __restrict x,
77
217
  size_t d,
78
218
  size_t nx) {
219
+ with_simd_level([&]<SIMDLevel SL>() {
79
220
  #pragma omp parallel for if (nx > 10000)
80
- for (int64_t i = 0; i < nx; i++) {
81
- nr[i] = fvec_norm_L2sqr(x + i * d, d);
82
- }
221
+ for (int64_t i = 0; i < static_cast<int64_t>(nx); i++) {
222
+ nr[i] = fvec_norm_L2sqr<SL>(x + i * d, d);
223
+ }
224
+ });
83
225
  }
84
226
 
85
227
  // The following is a workaround to a problem
@@ -93,29 +235,35 @@ void fvec_norms_L2sqr(
93
235
  // The workaround below is explicitly branching
94
236
  // off to a codepath without omp.
95
237
 
96
- #define FVEC_RENORM_L2_IMPL \
97
- float* __restrict xi = x + i * d; \
98
- \
99
- float nr = fvec_norm_L2sqr(xi, d); \
100
- \
101
- if (nr > 0) { \
102
- size_t j; \
103
- const float inv_nr = 1.0 / sqrtf(nr); \
104
- for (j = 0; j < d; j++) \
105
- xi[j] *= inv_nr; \
106
- }
107
-
108
238
  void fvec_renorm_L2_noomp(size_t d, size_t nx, float* __restrict x) {
109
- for (int64_t i = 0; i < nx; i++) {
110
- FVEC_RENORM_L2_IMPL
111
- }
239
+ with_simd_level([&]<SIMDLevel SL>() {
240
+ for (int64_t i = 0; i < static_cast<int64_t>(nx); i++) {
241
+ float* __restrict xi = x + i * d;
242
+ float nr = fvec_norm_L2sqr<SL>(xi, d);
243
+ if (nr > 0) {
244
+ const float inv_nr = 1.0 / sqrtf(nr);
245
+ for (size_t j = 0; j < d; j++) {
246
+ xi[j] *= inv_nr;
247
+ }
248
+ }
249
+ }
250
+ });
112
251
  }
113
252
 
114
253
  void fvec_renorm_L2_omp(size_t d, size_t nx, float* __restrict x) {
254
+ with_simd_level([&]<SIMDLevel SL>() {
115
255
  #pragma omp parallel for if (nx > 10000)
116
- for (int64_t i = 0; i < nx; i++) {
117
- FVEC_RENORM_L2_IMPL
118
- }
256
+ for (int64_t i = 0; i < static_cast<int64_t>(nx); i++) {
257
+ float* __restrict xi = x + i * d;
258
+ float nr = fvec_norm_L2sqr<SL>(xi, d);
259
+ if (nr > 0) {
260
+ const float inv_nr = 1.0 / sqrtf(nr);
261
+ for (size_t j = 0; j < d; j++) {
262
+ xi[j] *= inv_nr;
263
+ }
264
+ }
265
+ }
266
+ });
119
267
  }
120
268
 
121
269
  void fvec_renorm_L2(size_t d, size_t nx, float* __restrict x) {
@@ -148,22 +296,24 @@ void exhaustive_inner_product_seq(
148
296
  #pragma omp parallel num_threads(nt)
149
297
  {
150
298
  SingleResultHandler resi(res);
299
+ with_simd_level([&]<SIMDLevel SL>() {
151
300
  #pragma omp for
152
- for (int64_t i = 0; i < nx; i++) {
153
- const float* x_i = x + i * d;
154
- const float* y_j = y;
301
+ for (int64_t i = 0; i < static_cast<int64_t>(nx); i++) {
302
+ const float* x_i = x + i * d;
303
+ const float* y_j = y;
155
304
 
156
- resi.begin(i);
305
+ resi.begin(i);
157
306
 
158
- for (size_t j = 0; j < ny; j++, y_j += d) {
159
- if (!res.is_in_selection(j)) {
160
- continue;
307
+ for (size_t j = 0; j < ny; j++, y_j += d) {
308
+ if (!res.is_in_selection(j)) {
309
+ continue;
310
+ }
311
+ float ip = fvec_inner_product<SL>(x_i, y_j, d);
312
+ resi.add_result(ip, j);
161
313
  }
162
- float ip = fvec_inner_product(x_i, y_j, d);
163
- resi.add_result(ip, j);
314
+ resi.end();
164
315
  }
165
- resi.end();
166
- }
316
+ });
167
317
  }
168
318
  }
169
319
 
@@ -182,20 +332,22 @@ void exhaustive_L2sqr_seq(
182
332
  #pragma omp parallel num_threads(nt)
183
333
  {
184
334
  SingleResultHandler resi(res);
335
+ with_simd_level([&]<SIMDLevel SL>() {
185
336
  #pragma omp for
186
- for (int64_t i = 0; i < nx; i++) {
187
- const float* x_i = x + i * d;
188
- const float* y_j = y;
189
- resi.begin(i);
190
- for (size_t j = 0; j < ny; j++, y_j += d) {
191
- if (!res.is_in_selection(j)) {
192
- continue;
337
+ for (int64_t i = 0; i < static_cast<int64_t>(nx); i++) {
338
+ const float* x_i = x + i * d;
339
+ const float* y_j = y;
340
+ resi.begin(i);
341
+ for (size_t j = 0; j < ny; j++, y_j += d) {
342
+ if (!res.is_in_selection(j)) {
343
+ continue;
344
+ }
345
+ float disij = fvec_L2sqr<SL>(x_i, y_j, d);
346
+ resi.add_result(disij, j);
193
347
  }
194
- float disij = fvec_L2sqr(x_i, y_j, d);
195
- resi.add_result(disij, j);
348
+ resi.end();
196
349
  }
197
- resi.end();
198
- }
350
+ });
199
351
  }
200
352
  }
201
353
 
@@ -321,7 +473,7 @@ void exhaustive_L2sqr_blas_default_impl(
321
473
  ip_block.get(),
322
474
  &nyi);
323
475
  }
324
- for (int64_t i = i0; i < i1; i++) {
476
+ for (size_t i = i0; i < i1; i++) {
325
477
  float* ip_line = ip_block.get() + (i - i0) * (j1 - j0);
326
478
 
327
479
  for (size_t j = j0; j < j1; j++) {
@@ -357,396 +509,12 @@ void exhaustive_L2sqr_blas(
357
509
  size_t ny,
358
510
  BlockResultHandler& res,
359
511
  const float* y_norms = nullptr) {
360
- exhaustive_L2sqr_blas_default_impl(x, y, d, nx, ny, res);
361
- }
362
-
363
- #ifdef __AVX2__
364
- void exhaustive_L2sqr_blas_cmax_avx2(
365
- const float* x,
366
- const float* y,
367
- size_t d,
368
- size_t nx,
369
- size_t ny,
370
- Top1BlockResultHandler<CMax<float, int64_t>>& res,
371
- const float* y_norms) {
372
- // BLAS does not like empty matrices
373
- if (nx == 0 || ny == 0) {
374
- return;
375
- }
376
-
377
- /* block sizes */
378
- const size_t bs_x = distance_compute_blas_query_bs;
379
- const size_t bs_y = distance_compute_blas_database_bs;
380
- // const size_t bs_x = 16, bs_y = 16;
381
- std::unique_ptr<float[]> ip_block(new float[bs_x * bs_y]);
382
- std::unique_ptr<float[]> x_norms(new float[nx]);
383
- std::unique_ptr<float[]> del2;
384
-
385
- fvec_norms_L2sqr(x_norms.get(), x, d, nx);
386
-
387
- if (!y_norms) {
388
- float* y_norms2 = new float[ny];
389
- del2.reset(y_norms2);
390
- fvec_norms_L2sqr(y_norms2, y, d, ny);
391
- y_norms = y_norms2;
392
- }
393
-
394
- for (size_t i0 = 0; i0 < nx; i0 += bs_x) {
395
- size_t i1 = i0 + bs_x;
396
- if (i1 > nx) {
397
- i1 = nx;
398
- }
399
-
400
- res.begin_multiple(i0, i1);
401
-
402
- for (size_t j0 = 0; j0 < ny; j0 += bs_y) {
403
- size_t j1 = j0 + bs_y;
404
- if (j1 > ny) {
405
- j1 = ny;
406
- }
407
- /* compute the actual dot products */
408
- {
409
- float one = 1, zero = 0;
410
- FINTEGER nyi = j1 - j0, nxi = i1 - i0, di = d;
411
- sgemm_("Transpose",
412
- "Not transpose",
413
- &nyi,
414
- &nxi,
415
- &di,
416
- &one,
417
- y + j0 * d,
418
- &di,
419
- x + i0 * d,
420
- &di,
421
- &zero,
422
- ip_block.get(),
423
- &nyi);
424
- }
425
- for (int64_t i = i0; i < i1; i++) {
426
- float* ip_line = ip_block.get() + (i - i0) * (j1 - j0);
427
-
428
- _mm_prefetch((const char*)ip_line, _MM_HINT_NTA);
429
- _mm_prefetch((const char*)(ip_line + 16), _MM_HINT_NTA);
430
-
431
- // constant
432
- const __m256 mul_minus2 = _mm256_set1_ps(-2);
433
-
434
- // Track 8 min distances + 8 min indices.
435
- // All the distances tracked do not take x_norms[i]
436
- // into account in order to get rid of extra
437
- // _mm256_add_ps(x_norms[i], ...) instructions
438
- // is distance computations.
439
- __m256 min_distances =
440
- _mm256_set1_ps(res.dis_tab[i] - x_norms[i]);
441
-
442
- // these indices are local and are relative to j0.
443
- // so, value 0 means j0.
444
- __m256i min_indices = _mm256_set1_epi32(0);
445
-
446
- __m256i current_indices =
447
- _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
448
- const __m256i indices_delta = _mm256_set1_epi32(8);
449
-
450
- // current j index
451
- size_t idx_j = 0;
452
- size_t count = j1 - j0;
453
-
454
- // process 16 elements per loop
455
- for (; idx_j < (count / 16) * 16; idx_j += 16, ip_line += 16) {
456
- _mm_prefetch((const char*)(ip_line + 32), _MM_HINT_NTA);
457
- _mm_prefetch((const char*)(ip_line + 48), _MM_HINT_NTA);
458
-
459
- // load values for norms
460
- const __m256 y_norm_0 =
461
- _mm256_loadu_ps(y_norms + idx_j + j0 + 0);
462
- const __m256 y_norm_1 =
463
- _mm256_loadu_ps(y_norms + idx_j + j0 + 8);
464
-
465
- // load values for dot products
466
- const __m256 ip_0 = _mm256_loadu_ps(ip_line + 0);
467
- const __m256 ip_1 = _mm256_loadu_ps(ip_line + 8);
468
-
469
- // compute dis = y_norm[j] - 2 * dot(x_norm[i], y_norm[j]).
470
- // x_norm[i] was dropped off because it is a constant for a
471
- // given i. We'll deal with it later.
472
- __m256 distances_0 =
473
- _mm256_fmadd_ps(ip_0, mul_minus2, y_norm_0);
474
- __m256 distances_1 =
475
- _mm256_fmadd_ps(ip_1, mul_minus2, y_norm_1);
476
-
477
- // compare the new distances to the min distances
478
- // for each of the first group of 8 AVX2 components.
479
- const __m256 comparison_0 = _mm256_cmp_ps(
480
- min_distances, distances_0, _CMP_LE_OS);
481
-
482
- // update min distances and indices with closest vectors if
483
- // needed.
484
- min_distances = _mm256_blendv_ps(
485
- distances_0, min_distances, comparison_0);
486
- min_indices = _mm256_castps_si256(_mm256_blendv_ps(
487
- _mm256_castsi256_ps(current_indices),
488
- _mm256_castsi256_ps(min_indices),
489
- comparison_0));
490
- current_indices =
491
- _mm256_add_epi32(current_indices, indices_delta);
492
-
493
- // compare the new distances to the min distances
494
- // for each of the second group of 8 AVX2 components.
495
- const __m256 comparison_1 = _mm256_cmp_ps(
496
- min_distances, distances_1, _CMP_LE_OS);
497
-
498
- // update min distances and indices with closest vectors if
499
- // needed.
500
- min_distances = _mm256_blendv_ps(
501
- distances_1, min_distances, comparison_1);
502
- min_indices = _mm256_castps_si256(_mm256_blendv_ps(
503
- _mm256_castsi256_ps(current_indices),
504
- _mm256_castsi256_ps(min_indices),
505
- comparison_1));
506
- current_indices =
507
- _mm256_add_epi32(current_indices, indices_delta);
508
- }
509
-
510
- // dump values and find the minimum distance / minimum index
511
- float min_distances_scalar[8];
512
- uint32_t min_indices_scalar[8];
513
- _mm256_storeu_ps(min_distances_scalar, min_distances);
514
- _mm256_storeu_si256(
515
- (__m256i*)(min_indices_scalar), min_indices);
516
-
517
- float current_min_distance = res.dis_tab[i];
518
- uint32_t current_min_index = res.ids_tab[i];
519
-
520
- // This unusual comparison is needed to maintain the behavior
521
- // of the original implementation: if two indices are
522
- // represented with equal distance values, then
523
- // the index with the min value is returned.
524
- for (size_t jv = 0; jv < 8; jv++) {
525
- // add missing x_norms[i]
526
- float distance_candidate =
527
- min_distances_scalar[jv] + x_norms[i];
528
-
529
- // negative values can occur for identical vectors
530
- // due to roundoff errors.
531
- if (distance_candidate < 0) {
532
- distance_candidate = 0;
533
- }
534
-
535
- int64_t index_candidate = min_indices_scalar[jv] + j0;
536
-
537
- if (current_min_distance > distance_candidate) {
538
- current_min_distance = distance_candidate;
539
- current_min_index = index_candidate;
540
- } else if (
541
- current_min_distance == distance_candidate &&
542
- current_min_index > index_candidate) {
543
- current_min_index = index_candidate;
544
- }
545
- }
546
-
547
- // process leftovers
548
- for (; idx_j < count; idx_j++, ip_line++) {
549
- float ip = *ip_line;
550
- float dis = x_norms[i] + y_norms[idx_j + j0] - 2 * ip;
551
- // negative values can occur for identical vectors
552
- // due to roundoff errors.
553
- if (dis < 0) {
554
- dis = 0;
555
- }
556
-
557
- if (current_min_distance > dis) {
558
- current_min_distance = dis;
559
- current_min_index = idx_j + j0;
560
- }
561
- }
562
-
563
- //
564
- res.add_result(i, current_min_distance, current_min_index);
565
- }
566
- }
567
- // Does nothing for SingleBestResultHandler, but
568
- // keeping the call for the consistency.
569
- res.end_multiple();
570
- InterruptCallback::check();
571
- }
512
+ exhaustive_L2sqr_blas_default_impl(x, y, d, nx, ny, res, y_norms);
572
513
  }
573
- #elif defined(__ARM_FEATURE_SVE)
574
- void exhaustive_L2sqr_blas_cmax_sve(
575
- const float* x,
576
- const float* y,
577
- size_t d,
578
- size_t nx,
579
- size_t ny,
580
- Top1BlockResultHandler<CMax<float, int64_t>>& res,
581
- const float* y_norms) {
582
- // BLAS does not like empty matrices
583
- if (nx == 0 || ny == 0)
584
- return;
585
514
 
586
- /* block sizes */
587
- const size_t bs_x = distance_compute_blas_query_bs;
588
- const size_t bs_y = distance_compute_blas_database_bs;
589
- // const size_t bs_x = 16, bs_y = 16;
590
- std::unique_ptr<float[]> ip_block(new float[bs_x * bs_y]);
591
- std::unique_ptr<float[]> x_norms(new float[nx]);
592
- std::unique_ptr<float[]> del2;
593
-
594
- fvec_norms_L2sqr(x_norms.get(), x, d, nx);
595
-
596
- const size_t lanes = svcntw();
597
-
598
- if (!y_norms) {
599
- float* y_norms2 = new float[ny];
600
- del2.reset(y_norms2);
601
- fvec_norms_L2sqr(y_norms2, y, d, ny);
602
- y_norms = y_norms2;
603
- }
604
-
605
- for (size_t i0 = 0; i0 < nx; i0 += bs_x) {
606
- size_t i1 = i0 + bs_x;
607
- if (i1 > nx)
608
- i1 = nx;
609
-
610
- res.begin_multiple(i0, i1);
611
-
612
- for (size_t j0 = 0; j0 < ny; j0 += bs_y) {
613
- size_t j1 = j0 + bs_y;
614
- if (j1 > ny)
615
- j1 = ny;
616
- /* compute the actual dot products */
617
- {
618
- float one = 1, zero = 0;
619
- FINTEGER nyi = j1 - j0, nxi = i1 - i0, di = d;
620
- sgemm_("Transpose",
621
- "Not transpose",
622
- &nyi,
623
- &nxi,
624
- &di,
625
- &one,
626
- y + j0 * d,
627
- &di,
628
- x + i0 * d,
629
- &di,
630
- &zero,
631
- ip_block.get(),
632
- &nyi);
633
- }
634
- for (int64_t i = i0; i < i1; i++) {
635
- const size_t count = j1 - j0;
636
- float* ip_line = ip_block.get() + (i - i0) * count;
637
-
638
- svprfw(svwhilelt_b32_u64(0, count), ip_line, SV_PLDL1KEEP);
639
- svprfw(svwhilelt_b32_u64(lanes, count),
640
- ip_line + lanes,
641
- SV_PLDL1KEEP);
642
-
643
- // Track lanes min distances + lanes min indices.
644
- // All the distances tracked do not take x_norms[i]
645
- // into account in order to get rid of extra
646
- // vaddq_f32(x_norms[i], ...) instructions
647
- // is distance computations.
648
- auto min_distances = svdup_n_f32(res.dis_tab[i] - x_norms[i]);
649
-
650
- // these indices are local and are relative to j0.
651
- // so, value 0 means j0.
652
- auto min_indices = svdup_n_u32(0u);
653
-
654
- auto current_indices = svindex_u32(0u, 1u);
655
-
656
- // process lanes * 2 elements per loop
657
- for (size_t idx_j = 0; idx_j < count;
658
- idx_j += lanes * 2, ip_line += lanes * 2) {
659
- svprfw(svwhilelt_b32_u64(idx_j + lanes * 2, count),
660
- ip_line + lanes * 2,
661
- SV_PLDL1KEEP);
662
- svprfw(svwhilelt_b32_u64(idx_j + lanes * 3, count),
663
- ip_line + lanes * 3,
664
- SV_PLDL1KEEP);
665
-
666
- // mask
667
- const auto mask_0 = svwhilelt_b32_u64(idx_j, count);
668
- const auto mask_1 = svwhilelt_b32_u64(idx_j + lanes, count);
669
-
670
- // load values for norms
671
- const auto y_norm_0 =
672
- svld1_f32(mask_0, y_norms + idx_j + j0 + 0);
673
- const auto y_norm_1 =
674
- svld1_f32(mask_1, y_norms + idx_j + j0 + lanes);
675
-
676
- // load values for dot products
677
- const auto ip_0 = svld1_f32(mask_0, ip_line + 0);
678
- const auto ip_1 = svld1_f32(mask_1, ip_line + lanes);
679
-
680
- // compute dis = y_norm[j] - 2 * dot(x_norm[i], y_norm[j]).
681
- // x_norm[i] was dropped off because it is a constant for a
682
- // given i. We'll deal with it later.
683
- const auto distances_0 =
684
- svmla_n_f32_z(mask_0, y_norm_0, ip_0, -2.f);
685
- const auto distances_1 =
686
- svmla_n_f32_z(mask_1, y_norm_1, ip_1, -2.f);
687
-
688
- // compare the new distances to the min distances
689
- // for each of the first group of 4 ARM SIMD components.
690
- auto comparison =
691
- svcmpgt_f32(mask_0, min_distances, distances_0);
692
-
693
- // update min distances and indices with closest vectors if
694
- // needed.
695
- min_distances =
696
- svsel_f32(comparison, distances_0, min_distances);
697
- min_indices =
698
- svsel_u32(comparison, current_indices, min_indices);
699
- current_indices = svadd_n_u32_x(
700
- mask_0,
701
- current_indices,
702
- static_cast<uint32_t>(lanes));
703
-
704
- // compare the new distances to the min distances
705
- // for each of the second group of 4 ARM SIMD components.
706
- comparison =
707
- svcmpgt_f32(mask_1, min_distances, distances_1);
708
-
709
- // update min distances and indices with closest vectors if
710
- // needed.
711
- min_distances =
712
- svsel_f32(comparison, distances_1, min_distances);
713
- min_indices =
714
- svsel_u32(comparison, current_indices, min_indices);
715
- current_indices = svadd_n_u32_x(
716
- mask_1,
717
- current_indices,
718
- static_cast<uint32_t>(lanes));
719
- }
515
+ } // anonymous namespace
720
516
 
721
- // add missing x_norms[i]
722
- // negative values can occur for identical vectors
723
- // due to roundoff errors.
724
- auto mask = svwhilelt_b32_u64(0, count);
725
- min_distances = svadd_n_f32_z(
726
- svcmpge_n_f32(mask, min_distances, -x_norms[i]),
727
- min_distances,
728
- x_norms[i]);
729
- min_indices = svadd_n_u32_x(
730
- mask, min_indices, static_cast<uint32_t>(j0));
731
- mask = svcmple_n_f32(mask, min_distances, res.dis_tab[i]);
732
- if (svcntp_b32(svptrue_b32(), mask) == 0)
733
- res.add_result(i, res.dis_tab[i], res.ids_tab[i]);
734
- else {
735
- const auto min_distance = svminv_f32(mask, min_distances);
736
- const auto min_index = svminv_u32(
737
- svcmpeq_n_f32(mask, min_distances, min_distance),
738
- min_indices);
739
- res.add_result(i, min_distance, min_index);
740
- }
741
- }
742
- }
743
- // Does nothing for SingleBestResultHandler, but
744
- // keeping the call for the consistency.
745
- res.end_multiple();
746
- InterruptCallback::check();
747
- }
748
- }
749
- #endif
517
+ namespace {
750
518
 
751
519
  // an override if only a single closest point is needed
752
520
  template <>
@@ -758,43 +526,20 @@ void exhaustive_L2sqr_blas<Top1BlockResultHandler<CMax<float, int64_t>>>(
758
526
  size_t ny,
759
527
  Top1BlockResultHandler<CMax<float, int64_t>>& res,
760
528
  const float* y_norms) {
761
- #if defined(__AVX2__)
762
- // use a faster fused kernel if available
763
- if (exhaustive_L2sqr_fused_cmax(x, y, d, nx, ny, res, y_norms)) {
764
- // the kernel is available and it is complete, we're done.
765
- return;
766
- }
767
-
768
- // run the specialized AVX2 implementation
769
- exhaustive_L2sqr_blas_cmax_avx2(x, y, d, nx, ny, res, y_norms);
770
-
771
- #elif defined(__ARM_FEATURE_SVE)
772
- // use a faster fused kernel if available
773
- if (exhaustive_L2sqr_fused_cmax(x, y, d, nx, ny, res, y_norms)) {
774
- // the kernel is available and it is complete, we're done.
775
- return;
776
- }
777
-
778
- // run the specialized SVE implementation
779
- exhaustive_L2sqr_blas_cmax_sve(x, y, d, nx, ny, res, y_norms);
780
-
781
- #elif defined(__aarch64__)
782
529
  // use a faster fused kernel if available
783
530
  if (exhaustive_L2sqr_fused_cmax(x, y, d, nx, ny, res, y_norms)) {
784
- // the kernel is available and it is complete, we're done.
785
531
  return;
786
532
  }
787
533
 
788
- // run the default implementation
789
- exhaustive_L2sqr_blas_default_impl<
790
- Top1BlockResultHandler<CMax<float, int64_t>>>(
791
- x, y, d, nx, ny, res, y_norms);
792
- #else
793
- // run the default implementation
794
- exhaustive_L2sqr_blas_default_impl<
795
- Top1BlockResultHandler<CMax<float, int64_t>>>(
796
- x, y, d, nx, ny, res, y_norms);
797
- #endif
534
+ with_selected_simd_levels<AVAILABLE_SIMD_LEVELS_A2>([&]<SIMDLevel SL>() {
535
+ if constexpr (SL == SIMDLevel::AVX2 || SL == SIMDLevel::ARM_SVE) {
536
+ exhaustive_L2sqr_blas_cmax<SL>(x, y, d, nx, ny, res, y_norms);
537
+ } else {
538
+ exhaustive_L2sqr_blas_default_impl<
539
+ Top1BlockResultHandler<CMax<float, int64_t>>>(
540
+ x, y, d, nx, ny, res, y_norms);
541
+ }
542
+ });
798
543
  }
799
544
 
800
545
  struct Run_search_inner_product {
@@ -806,7 +551,8 @@ struct Run_search_inner_product {
806
551
  size_t d,
807
552
  size_t nx,
808
553
  size_t ny) {
809
- if (res.sel || nx < distance_compute_blas_threshold) {
554
+ if (res.sel ||
555
+ nx * d < static_cast<size_t>(distance_compute_blas_threshold)) {
810
556
  exhaustive_inner_product_seq(x, y, d, nx, ny, res);
811
557
  } else {
812
558
  exhaustive_inner_product_blas(x, y, d, nx, ny, res);
@@ -824,7 +570,8 @@ struct Run_search_L2sqr {
824
570
  size_t nx,
825
571
  size_t ny,
826
572
  const float* y_norm2) {
827
- if (res.sel || nx < distance_compute_blas_threshold) {
573
+ if (res.sel ||
574
+ nx * d < static_cast<size_t>(distance_compute_blas_threshold)) {
828
575
  exhaustive_L2sqr_seq(x, y, d, nx, ny, res);
829
576
  } else {
830
577
  exhaustive_L2sqr_blas(x, y, d, nx, ny, res, y_norm2);
@@ -838,11 +585,174 @@ struct Run_search_L2sqr {
838
585
  * KNN driver functions
839
586
  *******************************************************/
840
587
 
841
- int distance_compute_blas_threshold = 20;
588
+ int distance_compute_blas_threshold = 128000;
842
589
  int distance_compute_blas_query_bs = 4096;
843
590
  int distance_compute_blas_database_bs = 1024;
844
591
  int distance_compute_min_k_reservoir = 100;
845
592
 
593
+ // Database-parallel KNN: parallelizes over database segments instead of
594
+ // queries, for the case where nx < nthreads and the database is large.
595
+ static constexpr size_t kDbParallelMinVectors = 10000;
596
+
597
+ template <class C>
598
+ static void knn_db_parallel_impl(
599
+ const float* x,
600
+ const float* y,
601
+ size_t d,
602
+ size_t nx,
603
+ size_t ny,
604
+ size_t k,
605
+ float* vals,
606
+ int64_t* ids,
607
+ const float* y_norms) {
608
+ using T = typename C::T;
609
+ using TI = typename C::TI;
610
+
611
+ int nt = omp_get_max_threads();
612
+ const size_t bs_y = distance_compute_blas_database_bs;
613
+
614
+ // Per-thread result heaps: nt threads x nx queries x k results
615
+ std::vector<T> all_dis(static_cast<size_t>(nt) * nx * k);
616
+ std::vector<TI> all_ids(static_cast<size_t>(nt) * nx * k);
617
+
618
+ std::unique_ptr<float[]> x_norms_storage;
619
+ std::unique_ptr<float[]> y_norms_storage;
620
+ const float* x_norms = nullptr;
621
+ // C::is_max corresponds to L2 (CMax), not IP (CMin)
622
+ if constexpr (C::is_max) {
623
+ x_norms_storage.reset(new float[nx]);
624
+ fvec_norms_L2sqr(x_norms_storage.get(), x, d, nx);
625
+ x_norms = x_norms_storage.get();
626
+
627
+ if (!y_norms) {
628
+ y_norms_storage.reset(new float[ny]);
629
+ y_norms = y_norms_storage.get();
630
+ }
631
+ }
632
+
633
+ #pragma omp parallel num_threads(nt)
634
+ {
635
+ int tid = omp_get_thread_num();
636
+ size_t j_begin = static_cast<size_t>(tid) * ny / nt;
637
+ size_t j_end = static_cast<size_t>(tid + 1) * ny / nt;
638
+ size_t local_ny = j_end - j_begin;
639
+
640
+ // Compute y_norms for this thread's segment (cache locality)
641
+ if constexpr (C::is_max) {
642
+ if (y_norms_storage && local_ny > 0) {
643
+ fvec_norms_L2sqr(
644
+ y_norms_storage.get() + j_begin,
645
+ y + j_begin * d,
646
+ d,
647
+ local_ny);
648
+ }
649
+ }
650
+
651
+ T* my_dis = all_dis.data() + tid * nx * k;
652
+ TI* my_ids = all_ids.data() + tid * nx * k;
653
+
654
+ // Each thread initializes its own heaps
655
+ for (size_t i = 0; i < nx; i++) {
656
+ heap_heapify<C>(k, my_dis + i * k, my_ids + i * k);
657
+ }
658
+
659
+ if (local_ny > 0) {
660
+ size_t max_block = std::min(bs_y, local_ny);
661
+ std::unique_ptr<float[]> ip_block(new float[nx * max_block]);
662
+
663
+ for (size_t jj0 = 0; jj0 < local_ny; jj0 += bs_y) {
664
+ size_t jj1 = std::min(jj0 + bs_y, local_ny);
665
+ size_t block_ny = jj1 - jj0;
666
+
667
+ {
668
+ float one = 1, zero = 0;
669
+ FINTEGER nyi = static_cast<FINTEGER>(block_ny);
670
+ FINTEGER nxi = static_cast<FINTEGER>(nx);
671
+ FINTEGER di = static_cast<FINTEGER>(d);
672
+ sgemm_("Transpose",
673
+ "Not transpose",
674
+ &nyi,
675
+ &nxi,
676
+ &di,
677
+ &one,
678
+ y + (j_begin + jj0) * d,
679
+ &di,
680
+ x,
681
+ &di,
682
+ &zero,
683
+ ip_block.get(),
684
+ &nyi);
685
+ }
686
+
687
+ for (size_t i = 0; i < nx; i++) {
688
+ T* heap_dis = my_dis + i * k;
689
+ TI* heap_ids = my_ids + i * k;
690
+ const float* ip_line = ip_block.get() + i * block_ny;
691
+ T thresh = heap_dis[0];
692
+
693
+ for (size_t jj = 0; jj < block_ny; jj++) {
694
+ size_t global_j = j_begin + jj0 + jj;
695
+ float ip = ip_line[jj];
696
+ T dis;
697
+
698
+ if constexpr (C::is_max) {
699
+ dis = x_norms[i] + y_norms[global_j] - 2 * ip;
700
+ if (dis < 0) {
701
+ dis = 0;
702
+ }
703
+ } else {
704
+ dis = ip;
705
+ }
706
+
707
+ if (C::cmp(thresh, dis)) {
708
+ heap_replace_top<C>(
709
+ k, heap_dis, heap_ids, dis, global_j);
710
+ thresh = heap_dis[0];
711
+ }
712
+ }
713
+ }
714
+ }
715
+ }
716
+ }
717
+
718
+ // Merge per-thread heaps into output, parallelized over queries
719
+ #pragma omp parallel for
720
+ for (int64_t i = 0; i < static_cast<int64_t>(nx); i++) {
721
+ heap_heapify<C>(k, vals + i * k, ids + i * k);
722
+
723
+ for (int t = 0; t < nt; t++) {
724
+ T* t_dis = all_dis.data() + (t * nx + i) * k;
725
+ TI* t_ids = all_ids.data() + (t * nx + i) * k;
726
+ T* out_dis = vals + i * k;
727
+ TI* out_ids = ids + i * k;
728
+
729
+ for (size_t j = 0; j < k; j++) {
730
+ if (t_ids[j] >= 0 && C::cmp(out_dis[0], t_dis[j])) {
731
+ heap_replace_top<C>(
732
+ k, out_dis, out_ids, t_dis[j], t_ids[j]);
733
+ }
734
+ }
735
+ }
736
+
737
+ heap_reorder<C>(k, vals + i * k, ids + i * k);
738
+ }
739
+ }
740
+
741
+ static bool should_use_db_parallel(
742
+ size_t nx,
743
+ size_t ny,
744
+ const IDSelector* sel) {
745
+ if (sel) {
746
+ return false;
747
+ }
748
+ int nt = omp_get_max_threads();
749
+ size_t min_ny = std::max(
750
+ kDbParallelMinVectors,
751
+ static_cast<size_t>(nt) *
752
+ static_cast<size_t>(distance_compute_blas_database_bs));
753
+ return nt > 1 && nx < static_cast<size_t>(nt) && ny >= min_ny;
754
+ }
755
+
846
756
  void knn_inner_product(
847
757
  const float* x,
848
758
  const float* y,
@@ -867,9 +777,26 @@ void knn_inner_product(
867
777
  return;
868
778
  }
869
779
 
870
- Run_search_inner_product r;
871
- dispatch_knn_ResultHandler(
872
- nx, vals, ids, k, METRIC_INNER_PRODUCT, sel, r, x, y, d, nx, ny);
780
+ if (should_use_db_parallel(nx, ny, sel)) {
781
+ knn_db_parallel_impl<CMin<float, int64_t>>(
782
+ x, y, d, nx, ny, k, vals, ids, nullptr);
783
+ } else {
784
+ Run_search_inner_product r;
785
+ // @lint-ignore CLANGTIDY facebook-hte-NullableDereference
786
+ dispatch_knn_ResultHandler(
787
+ nx,
788
+ vals,
789
+ ids,
790
+ k,
791
+ METRIC_INNER_PRODUCT,
792
+ sel,
793
+ r,
794
+ x,
795
+ y,
796
+ d,
797
+ nx,
798
+ ny);
799
+ }
873
800
 
874
801
  if (imin != 0) {
875
802
  for (size_t i = 0; i < nx * k; i++) {
@@ -916,9 +843,15 @@ void knn_L2sqr(
916
843
  return;
917
844
  }
918
845
 
919
- Run_search_L2sqr r;
920
- dispatch_knn_ResultHandler(
921
- nx, vals, ids, k, METRIC_L2, sel, r, x, y, d, nx, ny, y_norm2);
846
+ if (should_use_db_parallel(nx, ny, sel)) {
847
+ knn_db_parallel_impl<CMax<float, int64_t>>(
848
+ x, y, d, nx, ny, k, vals, ids, y_norm2);
849
+ } else {
850
+ Run_search_L2sqr r;
851
+ // @lint-ignore CLANGTIDY facebook-hte-NullableDereference
852
+ dispatch_knn_ResultHandler(
853
+ nx, vals, ids, k, METRIC_L2, sel, r, x, y, d, nx, ny, y_norm2);
854
+ }
922
855
 
923
856
  if (imin != 0) {
924
857
  for (size_t i = 0; i < nx * k; i++) {
@@ -989,19 +922,21 @@ void fvec_inner_products_by_idx(
989
922
  size_t d,
990
923
  size_t nx,
991
924
  size_t ny) {
925
+ with_simd_level([&]<SIMDLevel SL>() {
992
926
  #pragma omp parallel for
993
- for (int64_t j = 0; j < nx; j++) {
994
- const int64_t* __restrict idsj = ids + j * ny;
995
- const float* xj = x + j * d;
996
- float* __restrict ipj = ip + j * ny;
997
- for (size_t i = 0; i < ny; i++) {
998
- if (idsj[i] < 0) {
999
- ipj[i] = -INFINITY;
1000
- } else {
1001
- ipj[i] = fvec_inner_product(xj, y + d * idsj[i], d);
927
+ for (int64_t j = 0; j < static_cast<int64_t>(nx); j++) {
928
+ const int64_t* __restrict idsj = ids + j * ny;
929
+ const float* xj = x + j * d;
930
+ float* __restrict ipj = ip + j * ny;
931
+ for (size_t i = 0; i < ny; i++) {
932
+ if (idsj[i] < 0) {
933
+ ipj[i] = -INFINITY;
934
+ } else {
935
+ ipj[i] = fvec_inner_product<SL>(xj, y + d * idsj[i], d);
936
+ }
1002
937
  }
1003
938
  }
1004
- }
939
+ });
1005
940
  }
1006
941
 
1007
942
  /* compute the inner product between x and a subset y of ny vectors,
@@ -1014,19 +949,21 @@ void fvec_L2sqr_by_idx(
1014
949
  size_t d,
1015
950
  size_t nx,
1016
951
  size_t ny) {
952
+ with_simd_level([&]<SIMDLevel SL>() {
1017
953
  #pragma omp parallel for
1018
- for (int64_t j = 0; j < nx; j++) {
1019
- const int64_t* __restrict idsj = ids + j * ny;
1020
- const float* xj = x + j * d;
1021
- float* __restrict disj = dis + j * ny;
1022
- for (size_t i = 0; i < ny; i++) {
1023
- if (idsj[i] < 0) {
1024
- disj[i] = INFINITY;
1025
- } else {
1026
- disj[i] = fvec_L2sqr(xj, y + d * idsj[i], d);
954
+ for (int64_t j = 0; j < static_cast<int64_t>(nx); j++) {
955
+ const int64_t* __restrict idsj = ids + j * ny;
956
+ const float* xj = x + j * d;
957
+ float* __restrict disj = dis + j * ny;
958
+ for (size_t i = 0; i < ny; i++) {
959
+ if (idsj[i] < 0) {
960
+ disj[i] = INFINITY;
961
+ } else {
962
+ disj[i] = fvec_L2sqr<SL>(xj, y + d * idsj[i], d);
963
+ }
1027
964
  }
1028
965
  }
1029
- }
966
+ });
1030
967
  }
1031
968
 
1032
969
  void pairwise_indexed_L2sqr(
@@ -1037,14 +974,16 @@ void pairwise_indexed_L2sqr(
1037
974
  const float* y,
1038
975
  const int64_t* iy,
1039
976
  float* dis) {
977
+ with_simd_level([&]<SIMDLevel SL>() {
1040
978
  #pragma omp parallel for if (n > 1)
1041
- for (int64_t j = 0; j < n; j++) {
1042
- if (ix[j] >= 0 && iy[j] >= 0) {
1043
- dis[j] = fvec_L2sqr(x + d * ix[j], y + d * iy[j], d);
1044
- } else {
1045
- dis[j] = INFINITY;
979
+ for (int64_t j = 0; j < static_cast<int64_t>(n); j++) {
980
+ if (ix[j] >= 0 && iy[j] >= 0) {
981
+ dis[j] = fvec_L2sqr<SL>(x + d * ix[j], y + d * iy[j], d);
982
+ } else {
983
+ dis[j] = INFINITY;
984
+ }
1046
985
  }
1047
- }
986
+ });
1048
987
  }
1049
988
 
1050
989
  void pairwise_indexed_inner_product(
@@ -1055,14 +994,17 @@ void pairwise_indexed_inner_product(
1055
994
  const float* y,
1056
995
  const int64_t* iy,
1057
996
  float* dis) {
997
+ with_simd_level([&]<SIMDLevel SL>() {
1058
998
  #pragma omp parallel for if (n > 1)
1059
- for (int64_t j = 0; j < n; j++) {
1060
- if (ix[j] >= 0 && iy[j] >= 0) {
1061
- dis[j] = fvec_inner_product(x + d * ix[j], y + d * iy[j], d);
1062
- } else {
1063
- dis[j] = -INFINITY;
999
+ for (int64_t j = 0; j < static_cast<int64_t>(n); j++) {
1000
+ if (ix[j] >= 0 && iy[j] >= 0) {
1001
+ dis[j] =
1002
+ fvec_inner_product<SL>(x + d * ix[j], y + d * iy[j], d);
1003
+ } else {
1004
+ dis[j] = -INFINITY;
1005
+ }
1064
1006
  }
1065
- }
1007
+ });
1066
1008
  }
1067
1009
 
1068
1010
  /* Find the nearest neighbors for nx queries in a set of ny vectors
@@ -1083,27 +1025,29 @@ void knn_inner_products_by_idx(
1083
1025
  ld_ids = ny;
1084
1026
  }
1085
1027
 
1028
+ with_simd_level([&]<SIMDLevel SL>() {
1086
1029
  #pragma omp parallel for if (nx > 100)
1087
- for (int64_t i = 0; i < nx; i++) {
1088
- const float* x_ = x + i * d;
1089
- const int64_t* idsi = ids + i * ld_ids;
1090
- size_t j;
1091
- float* __restrict simi = res_vals + i * k;
1092
- int64_t* __restrict idxi = res_ids + i * k;
1093
- minheap_heapify(k, simi, idxi);
1094
-
1095
- for (j = 0; j < nsubset; j++) {
1096
- if (idsi[j] < 0 || idsi[j] >= ny) {
1097
- break;
1098
- }
1099
- float ip = fvec_inner_product(x_, y + d * idsi[j], d);
1030
+ for (int64_t i = 0; i < static_cast<int64_t>(nx); i++) {
1031
+ const float* x_ = x + i * d;
1032
+ const int64_t* idsi = ids + i * ld_ids;
1033
+ size_t j;
1034
+ float* __restrict simi = res_vals + i * k;
1035
+ int64_t* __restrict idxi = res_ids + i * k;
1036
+ minheap_heapify(k, simi, idxi);
1037
+
1038
+ for (j = 0; j < nsubset; j++) {
1039
+ if (idsi[j] < 0 || static_cast<size_t>(idsi[j]) >= ny) {
1040
+ break;
1041
+ }
1042
+ float ip = fvec_inner_product<SL>(x_, y + d * idsi[j], d);
1100
1043
 
1101
- if (ip > simi[0]) {
1102
- minheap_replace_top(k, simi, idxi, ip, idsi[j]);
1044
+ if (ip > simi[0]) {
1045
+ minheap_replace_top(k, simi, idxi, ip, idsi[j]);
1046
+ }
1103
1047
  }
1048
+ minheap_reorder(k, simi, idxi);
1104
1049
  }
1105
- minheap_reorder(k, simi, idxi);
1106
- }
1050
+ });
1107
1051
  }
1108
1052
 
1109
1053
  void knn_L2sqr_by_idx(
@@ -1121,25 +1065,27 @@ void knn_L2sqr_by_idx(
1121
1065
  if (ld_ids < 0) {
1122
1066
  ld_ids = ny;
1123
1067
  }
1068
+ with_simd_level([&]<SIMDLevel SL>() {
1124
1069
  #pragma omp parallel for if (nx > 100)
1125
- for (int64_t i = 0; i < nx; i++) {
1126
- const float* x_ = x + i * d;
1127
- const int64_t* __restrict idsi = ids + i * ld_ids;
1128
- float* __restrict simi = res_vals + i * k;
1129
- int64_t* __restrict idxi = res_ids + i * k;
1130
- maxheap_heapify(k, simi, idxi);
1131
- for (size_t j = 0; j < nsubset; j++) {
1132
- if (idsi[j] < 0 || idsi[j] >= ny) {
1133
- break;
1134
- }
1135
- float disij = fvec_L2sqr(x_, y + d * idsi[j], d);
1070
+ for (int64_t i = 0; i < static_cast<int64_t>(nx); i++) {
1071
+ const float* x_ = x + i * d;
1072
+ const int64_t* __restrict idsi = ids + i * ld_ids;
1073
+ float* __restrict simi = res_vals + i * k;
1074
+ int64_t* __restrict idxi = res_ids + i * k;
1075
+ maxheap_heapify(k, simi, idxi);
1076
+ for (size_t j = 0; j < nsubset; j++) {
1077
+ if (idsi[j] < 0 || static_cast<size_t>(idsi[j]) >= ny) {
1078
+ break;
1079
+ }
1080
+ float disij = fvec_L2sqr<SL>(x_, y + d * idsi[j], d);
1136
1081
 
1137
- if (disij < simi[0]) {
1138
- maxheap_replace_top(k, simi, idxi, disij, idsi[j]);
1082
+ if (disij < simi[0]) {
1083
+ maxheap_replace_top(k, simi, idxi, disij, idsi[j]);
1084
+ }
1139
1085
  }
1086
+ maxheap_reorder(k, simi, idxi);
1140
1087
  }
1141
- maxheap_reorder(k, simi, idxi);
1142
- }
1088
+ });
1143
1089
  }
1144
1090
 
1145
1091
  void pairwise_L2sqr(
@@ -1168,25 +1114,27 @@ void pairwise_L2sqr(
1168
1114
  // store in beginning of distance matrix to avoid malloc
1169
1115
  float* b_norms = dis;
1170
1116
 
1117
+ with_simd_level([&]<SIMDLevel SL>() {
1171
1118
  #pragma omp parallel for if (nb > 1)
1172
- for (int64_t i = 0; i < nb; i++) {
1173
- b_norms[i] = fvec_norm_L2sqr(xb + i * ldb, d);
1174
- }
1119
+ for (int64_t i = 0; i < nb; i++) {
1120
+ b_norms[i] = fvec_norm_L2sqr<SL>(xb + i * ldb, d);
1121
+ }
1175
1122
 
1176
1123
  #pragma omp parallel for
1177
- for (int64_t i = 1; i < nq; i++) {
1178
- float q_norm = fvec_norm_L2sqr(xq + i * ldq, d);
1179
- for (int64_t j = 0; j < nb; j++) {
1180
- dis[i * ldd + j] = q_norm + b_norms[j];
1124
+ for (int64_t i = 1; i < nq; i++) {
1125
+ float q_norm = fvec_norm_L2sqr<SL>(xq + i * ldq, d);
1126
+ for (int64_t j = 0; j < nb; j++) {
1127
+ dis[i * ldd + j] = q_norm + b_norms[j];
1128
+ }
1181
1129
  }
1182
- }
1183
1130
 
1184
- {
1185
- float q_norm = fvec_norm_L2sqr(xq, d);
1186
- for (int64_t j = 0; j < nb; j++) {
1187
- dis[j] += q_norm;
1131
+ {
1132
+ float q_norm = fvec_norm_L2sqr<SL>(xq, d);
1133
+ for (int64_t j = 0; j < nb; j++) {
1134
+ dis[j] += q_norm;
1135
+ }
1188
1136
  }
1189
- }
1137
+ });
1190
1138
 
1191
1139
  {
1192
1140
  FINTEGER nbi = nb, nqi = nq, di = d, ldqi = ldq, ldbi = ldb, lddi = ldd;
@@ -1215,7 +1163,7 @@ void inner_product_to_L2sqr(
1215
1163
  size_t n1,
1216
1164
  size_t n2) {
1217
1165
  #pragma omp parallel for
1218
- for (int64_t j = 0; j < n1; j++) {
1166
+ for (int64_t j = 0; j < static_cast<int64_t>(n1); j++) {
1219
1167
  float* disj = dis + j * n2;
1220
1168
  for (size_t i = 0; i < n2; i++) {
1221
1169
  disj[i] = nr1[j] + nr2[i] - 2 * disj[i];