faiss 0.6.0 → 0.6.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (361)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/ext/faiss/extconf.rb +2 -1
  4. data/ext/faiss/{index_rb.cpp → index.cpp} +1 -1
  5. data/ext/faiss/index_binary.cpp +1 -1
  6. data/ext/faiss/kmeans.cpp +1 -1
  7. data/ext/faiss/pca_matrix.cpp +1 -1
  8. data/ext/faiss/product_quantizer.cpp +1 -1
  9. data/ext/faiss/{utils_rb.cpp → utils.cpp} +1 -1
  10. data/lib/faiss/version.rb +1 -1
  11. data/vendor/faiss/faiss/AutoTune.cpp +93 -80
  12. data/vendor/faiss/faiss/Clustering.cpp +39 -240
  13. data/vendor/faiss/faiss/Clustering.h +6 -0
  14. data/vendor/faiss/faiss/IVFlib.cpp +41 -21
  15. data/vendor/faiss/faiss/Index.cpp +6 -5
  16. data/vendor/faiss/faiss/Index.h +5 -5
  17. data/vendor/faiss/faiss/Index2Layer.cpp +37 -53
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +49 -37
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +36 -34
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +4 -1
  21. data/vendor/faiss/faiss/IndexBinary.cpp +5 -3
  22. data/vendor/faiss/faiss/IndexBinary.h +4 -4
  23. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +1 -1
  24. data/vendor/faiss/faiss/IndexBinaryFlat.h +1 -1
  25. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -4
  26. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +84 -92
  27. data/vendor/faiss/faiss/IndexBinaryHNSW.h +9 -3
  28. data/vendor/faiss/faiss/IndexBinaryHash.cpp +45 -236
  29. data/vendor/faiss/faiss/IndexBinaryHash.h +6 -6
  30. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +87 -415
  31. data/vendor/faiss/faiss/IndexFastScan.cpp +72 -109
  32. data/vendor/faiss/faiss/IndexFastScan.h +25 -23
  33. data/vendor/faiss/faiss/IndexFlat.cpp +27 -20
  34. data/vendor/faiss/faiss/IndexFlat.h +21 -18
  35. data/vendor/faiss/faiss/IndexFlatCodes.cpp +42 -19
  36. data/vendor/faiss/faiss/IndexHNSW.cpp +283 -145
  37. data/vendor/faiss/faiss/IndexHNSW.h +16 -2
  38. data/vendor/faiss/faiss/IndexIDMap.cpp +25 -21
  39. data/vendor/faiss/faiss/IndexIDMap.h +9 -7
  40. data/vendor/faiss/faiss/IndexIVF.cpp +465 -362
  41. data/vendor/faiss/faiss/IndexIVF.h +33 -12
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +77 -74
  43. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +96 -93
  44. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -1
  45. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +357 -238
  46. data/vendor/faiss/faiss/IndexIVFFastScan.h +42 -41
  47. data/vendor/faiss/faiss/IndexIVFFlat.cpp +36 -68
  48. data/vendor/faiss/faiss/IndexIVFFlat.h +32 -0
  49. data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +53 -30
  50. data/vendor/faiss/faiss/IndexIVFFlatPanorama.h +3 -1
  51. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +18 -15
  52. data/vendor/faiss/faiss/IndexIVFPQ.cpp +71 -843
  53. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +151 -121
  54. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +3 -0
  55. data/vendor/faiss/faiss/IndexIVFPQR.cpp +21 -17
  56. data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +26 -39
  57. data/vendor/faiss/faiss/IndexIVFRaBitQ.h +2 -1
  58. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +475 -476
  59. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +248 -93
  60. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +41 -127
  61. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +1 -1
  62. data/vendor/faiss/faiss/IndexLSH.cpp +36 -19
  63. data/vendor/faiss/faiss/IndexLattice.cpp +13 -13
  64. data/vendor/faiss/faiss/IndexNNDescent.cpp +36 -21
  65. data/vendor/faiss/faiss/IndexNNDescent.h +2 -2
  66. data/vendor/faiss/faiss/IndexNSG.cpp +39 -23
  67. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +31 -11
  68. data/vendor/faiss/faiss/IndexPQ.cpp +128 -221
  69. data/vendor/faiss/faiss/IndexPQ.h +3 -2
  70. data/vendor/faiss/faiss/IndexPQFastScan.cpp +20 -14
  71. data/vendor/faiss/faiss/IndexPQFastScan.h +3 -0
  72. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -18
  73. data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
  74. data/vendor/faiss/faiss/IndexRaBitQ.cpp +11 -36
  75. data/vendor/faiss/faiss/IndexRaBitQ.h +2 -1
  76. data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +41 -277
  77. data/vendor/faiss/faiss/IndexRaBitQFastScan.h +183 -27
  78. data/vendor/faiss/faiss/IndexRefine.cpp +30 -25
  79. data/vendor/faiss/faiss/IndexRefine.h +4 -4
  80. data/vendor/faiss/faiss/IndexReplicas.cpp +6 -6
  81. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +15 -14
  82. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +1 -1
  83. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +82 -14
  84. data/vendor/faiss/faiss/IndexShards.cpp +10 -9
  85. data/vendor/faiss/faiss/IndexShardsIVF.cpp +21 -15
  86. data/vendor/faiss/faiss/MatrixStats.cpp +5 -4
  87. data/vendor/faiss/faiss/MetaIndexes.cpp +19 -17
  88. data/vendor/faiss/faiss/MetaIndexes.h +1 -1
  89. data/vendor/faiss/faiss/MetricType.h +14 -7
  90. data/vendor/faiss/faiss/SuperKMeans.cpp +656 -0
  91. data/vendor/faiss/faiss/SuperKMeans.h +97 -0
  92. data/vendor/faiss/faiss/VectorTransform.cpp +237 -149
  93. data/vendor/faiss/faiss/VectorTransform.h +16 -16
  94. data/vendor/faiss/faiss/build.cpp +23 -0
  95. data/vendor/faiss/faiss/build.h +15 -0
  96. data/vendor/faiss/faiss/clone_index.cpp +48 -47
  97. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +47 -47
  98. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +11 -0
  99. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +38 -38
  100. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +11 -0
  101. data/vendor/faiss/faiss/factory_tools.cpp +5 -0
  102. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +6 -5
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +1 -1
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +9 -9
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +4 -3
  106. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +46 -0
  107. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +56 -0
  108. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +78 -1
  109. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +72 -0
  110. data/vendor/faiss/faiss/gpu/test/TestUtils.h +23 -0
  111. data/vendor/faiss/faiss/gpu/utils/CuvsFilterConvert.h +1 -1
  112. data/vendor/faiss/faiss/gpu/utils/CuvsUtils.h +21 -10
  113. data/vendor/faiss/faiss/gpu_metal/GpuIndexFlat.h +22 -0
  114. data/vendor/faiss/faiss/gpu_metal/MetalCloner.h +35 -0
  115. data/vendor/faiss/faiss/gpu_metal/MetalFlatKernels.h +40 -0
  116. data/vendor/faiss/faiss/gpu_metal/MetalIndex.h +51 -0
  117. data/vendor/faiss/faiss/gpu_metal/MetalIndexFlat.h +65 -0
  118. data/vendor/faiss/faiss/gpu_metal/MetalKernels.h +66 -0
  119. data/vendor/faiss/faiss/gpu_metal/MetalResources.h +79 -0
  120. data/vendor/faiss/faiss/gpu_metal/StandardMetalResources.h +35 -0
  121. data/vendor/faiss/faiss/impl/AdSampling.cpp +103 -0
  122. data/vendor/faiss/faiss/impl/AdSampling.h +35 -0
  123. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +29 -25
  124. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +1 -0
  125. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +10 -9
  126. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +3 -0
  127. data/vendor/faiss/faiss/impl/ClusteringHelpers.cpp +244 -0
  128. data/vendor/faiss/faiss/impl/ClusteringHelpers.h +94 -0
  129. data/vendor/faiss/faiss/impl/ClusteringInitialization.cpp +16 -16
  130. data/vendor/faiss/faiss/impl/CodePacker.cpp +3 -3
  131. data/vendor/faiss/faiss/impl/CodePackerRaBitQ.cpp +1 -1
  132. data/vendor/faiss/faiss/impl/DistanceComputer.h +8 -8
  133. data/vendor/faiss/faiss/impl/FaissAssert.h +6 -3
  134. data/vendor/faiss/faiss/impl/FaissException.h +50 -3
  135. data/vendor/faiss/faiss/impl/HNSW.cpp +92 -317
  136. data/vendor/faiss/faiss/impl/HNSW.h +13 -34
  137. data/vendor/faiss/faiss/impl/IDSelector.cpp +15 -11
  138. data/vendor/faiss/faiss/impl/IDSelector.h +8 -8
  139. data/vendor/faiss/faiss/impl/InvertedListScannerStats.h +26 -0
  140. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +82 -77
  141. data/vendor/faiss/faiss/impl/NNDescent.cpp +62 -25
  142. data/vendor/faiss/faiss/impl/NNDescent.h +6 -2
  143. data/vendor/faiss/faiss/impl/NSG.cpp +38 -21
  144. data/vendor/faiss/faiss/impl/NSG.h +4 -4
  145. data/vendor/faiss/faiss/impl/Panorama.cpp +23 -6
  146. data/vendor/faiss/faiss/impl/Panorama.h +258 -87
  147. data/vendor/faiss/faiss/impl/PdxLayout.cpp +93 -0
  148. data/vendor/faiss/faiss/impl/PdxLayout.h +41 -0
  149. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +46 -32
  150. data/vendor/faiss/faiss/impl/PolysemousTraining.h +3 -3
  151. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +35 -35
  152. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +21 -16
  153. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +30 -23
  154. data/vendor/faiss/faiss/impl/Quantizer.h +2 -2
  155. data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +55 -49
  156. data/vendor/faiss/faiss/impl/RaBitQUtils.h +65 -0
  157. data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +296 -283
  158. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +26 -23
  159. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +1 -1
  160. data/vendor/faiss/faiss/impl/ResultHandler.h +99 -75
  161. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +52 -4
  162. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +27 -1
  163. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +14 -11
  164. data/vendor/faiss/faiss/impl/VisitedTable.h +7 -0
  165. data/vendor/faiss/faiss/impl/approx_topk/approx_topk.h +276 -0
  166. data/vendor/faiss/faiss/impl/approx_topk/avx2.cpp +68 -0
  167. data/vendor/faiss/faiss/{utils → impl}/approx_topk/generic.h +15 -8
  168. data/vendor/faiss/faiss/impl/approx_topk/neon.cpp +68 -0
  169. data/vendor/faiss/faiss/impl/approx_topk/rq_beam_search_tab-inl.h +169 -0
  170. data/vendor/faiss/faiss/impl/approx_topk/rq_beam_search_tab.h +117 -0
  171. data/vendor/faiss/faiss/impl/approx_topk/simdlib256-inl.h +146 -0
  172. data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryHNSW_impl.h +73 -0
  173. data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryHash_impl.h +270 -0
  174. data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryIVF_impl.h +460 -0
  175. data/vendor/faiss/faiss/impl/binary_hamming/IndexIVFSpectralHash_impl.h +159 -0
  176. data/vendor/faiss/faiss/impl/binary_hamming/IndexPQ_impl.h +92 -0
  177. data/vendor/faiss/faiss/impl/binary_hamming/avx2.cpp +26 -0
  178. data/vendor/faiss/faiss/impl/binary_hamming/avx512.cpp +26 -0
  179. data/vendor/faiss/faiss/impl/binary_hamming/dispatch.h +143 -0
  180. data/vendor/faiss/faiss/impl/binary_hamming/neon.cpp +26 -0
  181. data/vendor/faiss/faiss/impl/binary_hamming/rvv.cpp +26 -0
  182. data/vendor/faiss/faiss/impl/expanded_scanners.h +8 -3
  183. data/vendor/faiss/faiss/impl/{FastScanDistancePostProcessing.h → fast_scan/FastScanDistancePostProcessing.h} +13 -6
  184. data/vendor/faiss/faiss/impl/{LookupTableScaler.h → fast_scan/LookupTableScaler.h} +16 -5
  185. data/vendor/faiss/faiss/impl/fast_scan/accumulate_loops.h +237 -0
  186. data/vendor/faiss/faiss/impl/fast_scan/accumulate_loops_512.h +185 -0
  187. data/vendor/faiss/faiss/impl/fast_scan/decompose_qbs.h +229 -0
  188. data/vendor/faiss/faiss/impl/fast_scan/dispatching.h +268 -0
  189. data/vendor/faiss/faiss/impl/{pq4_fast_scan.cpp → fast_scan/fast_scan.cpp} +169 -2
  190. data/vendor/faiss/faiss/impl/fast_scan/fast_scan.h +341 -0
  191. data/vendor/faiss/faiss/impl/fast_scan/impl-avx2.cpp +36 -0
  192. data/vendor/faiss/faiss/impl/fast_scan/impl-avx512.cpp +40 -0
  193. data/vendor/faiss/faiss/impl/fast_scan/impl-neon.cpp +120 -0
  194. data/vendor/faiss/faiss/impl/fast_scan/impl-riscv.cpp +104 -0
  195. data/vendor/faiss/faiss/impl/fast_scan/kernels_simd256.h +213 -0
  196. data/vendor/faiss/faiss/impl/{pq4_fast_scan_search_qbs.cpp → fast_scan/kernels_simd512.h} +26 -356
  197. data/vendor/faiss/faiss/impl/fast_scan/rabitq_dispatching.h +90 -0
  198. data/vendor/faiss/faiss/impl/fast_scan/rabitq_result_handler.h +108 -0
  199. data/vendor/faiss/faiss/impl/{simd_result_handlers.h → fast_scan/simd_result_handlers.h} +282 -134
  200. data/vendor/faiss/faiss/impl/hnsw/LockVector.cpp +54 -0
  201. data/vendor/faiss/faiss/impl/hnsw/LockVector.h +64 -0
  202. data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.cpp +91 -0
  203. data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.h +64 -0
  204. data/vendor/faiss/faiss/impl/hnsw/avx2.cpp +104 -0
  205. data/vendor/faiss/faiss/impl/hnsw/avx512.cpp +111 -0
  206. data/vendor/faiss/faiss/impl/index_read.cpp +1132 -45
  207. data/vendor/faiss/faiss/impl/index_read_utils.h +1 -1
  208. data/vendor/faiss/faiss/impl/index_write.cpp +95 -13
  209. data/vendor/faiss/faiss/impl/io.cpp +6 -6
  210. data/vendor/faiss/faiss/impl/io_macros.h +33 -16
  211. data/vendor/faiss/faiss/impl/kmeans1d.cpp +10 -10
  212. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +37 -23
  213. data/vendor/faiss/faiss/impl/lattice_Zn.h +6 -6
  214. data/vendor/faiss/faiss/impl/mapped_io.cpp +6 -6
  215. data/vendor/faiss/faiss/impl/platform_macros.h +11 -4
  216. data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQScanner_impl.h +549 -0
  217. data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQ_QueryTables.cpp +245 -0
  218. data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQ_QueryTables.h +105 -0
  219. data/vendor/faiss/faiss/impl/pq_code_distance/PQDistanceComputer_impl.h +106 -0
  220. data/vendor/faiss/faiss/impl/pq_code_distance/avx2.cpp +21 -0
  221. data/vendor/faiss/faiss/impl/pq_code_distance/avx512.cpp +21 -0
  222. data/vendor/faiss/faiss/impl/pq_code_distance/neon.cpp +21 -0
  223. data/vendor/faiss/faiss/impl/pq_code_distance/{pq_code_distance-avx2.cpp → pq_code_distance-avx2.h} +9 -13
  224. data/vendor/faiss/faiss/impl/pq_code_distance/{pq_code_distance-avx512.cpp → pq_code_distance-avx512.h} +9 -57
  225. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.cpp +29 -111
  226. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.h +96 -0
  227. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-inl.h +238 -5
  228. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-sve.cpp +5 -7
  229. data/vendor/faiss/faiss/impl/pq_code_distance/rvv.cpp +68 -0
  230. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +311 -477
  231. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +1 -1
  232. data/vendor/faiss/faiss/impl/scalar_quantizer/codecs.h +1 -1
  233. data/vendor/faiss/faiss/impl/scalar_quantizer/distance_computers.h +3 -2
  234. data/vendor/faiss/faiss/impl/scalar_quantizer/quantizers.h +102 -11
  235. data/vendor/faiss/faiss/impl/scalar_quantizer/scanners.h +27 -1
  236. data/vendor/faiss/faiss/impl/scalar_quantizer/similarities.h +3 -3
  237. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx2.cpp +148 -0
  238. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512.cpp +167 -0
  239. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-dispatch.h +59 -0
  240. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-neon.cpp +163 -0
  241. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-rvv.cpp +311 -0
  242. data/vendor/faiss/faiss/impl/scalar_quantizer/training.cpp +192 -8
  243. data/vendor/faiss/faiss/impl/scalar_quantizer/training.h +12 -0
  244. data/vendor/faiss/faiss/impl/simd_dispatch.h +100 -66
  245. data/vendor/faiss/faiss/impl/simdlib/simdlib.h +57 -0
  246. data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_avx2.h +264 -172
  247. data/vendor/faiss/faiss/impl/simdlib/simdlib_avx512.h +414 -0
  248. data/vendor/faiss/faiss/impl/simdlib/simdlib_dispatch.h +44 -0
  249. data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_emulated.h +231 -166
  250. data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_neon.h +270 -218
  251. data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_ppc64.h +201 -160
  252. data/vendor/faiss/faiss/impl/svs_io.cpp +12 -3
  253. data/vendor/faiss/faiss/impl/svs_io.h +8 -2
  254. data/vendor/faiss/faiss/index_factory.cpp +86 -18
  255. data/vendor/faiss/faiss/index_io.h +24 -0
  256. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +66 -16
  257. data/vendor/faiss/faiss/invlists/DirectMap.cpp +24 -14
  258. data/vendor/faiss/faiss/invlists/DirectMap.h +4 -3
  259. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +157 -73
  260. data/vendor/faiss/faiss/invlists/InvertedLists.h +86 -23
  261. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +4 -4
  262. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +13 -13
  263. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
  264. data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +1 -1
  265. data/vendor/faiss/faiss/svs/IndexSVSFlat.cpp +2 -2
  266. data/vendor/faiss/faiss/svs/IndexSVSIVF.cpp +350 -0
  267. data/vendor/faiss/faiss/svs/IndexSVSIVF.h +128 -0
  268. data/vendor/faiss/faiss/svs/IndexSVSIVFLVQ.cpp +40 -0
  269. data/vendor/faiss/faiss/svs/IndexSVSIVFLVQ.h +43 -0
  270. data/vendor/faiss/faiss/svs/IndexSVSIVFLeanVec.cpp +225 -0
  271. data/vendor/faiss/faiss/svs/IndexSVSIVFLeanVec.h +71 -0
  272. data/vendor/faiss/faiss/svs/IndexSVSVamana.cpp +25 -1
  273. data/vendor/faiss/faiss/svs/IndexSVSVamana.h +18 -2
  274. data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.h +1 -1
  275. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +12 -3
  276. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +7 -2
  277. data/vendor/faiss/faiss/utils/Heap.cpp +10 -10
  278. data/vendor/faiss/faiss/utils/NeuralNet.cpp +47 -36
  279. data/vendor/faiss/faiss/utils/NeuralNet.h +1 -1
  280. data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +10 -4
  281. data/vendor/faiss/faiss/utils/distances.cpp +390 -560
  282. data/vendor/faiss/faiss/utils/distances.h +20 -1
  283. data/vendor/faiss/faiss/utils/distances_dispatch.h +117 -37
  284. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +8 -7
  285. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +33 -14
  286. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +12 -1
  287. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +16 -293
  288. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based_neon.cpp +57 -0
  289. data/vendor/faiss/faiss/utils/distances_fused/simdlib_kernel-inl.h +290 -0
  290. data/vendor/faiss/faiss/utils/distances_simd.cpp +5 -177
  291. data/vendor/faiss/faiss/utils/extra_distances.cpp +9 -8
  292. data/vendor/faiss/faiss/utils/extra_distances.h +32 -6
  293. data/vendor/faiss/faiss/utils/hamming-inl.h +13 -11
  294. data/vendor/faiss/faiss/utils/hamming.cpp +66 -517
  295. data/vendor/faiss/faiss/utils/hamming.h +92 -2
  296. data/vendor/faiss/faiss/utils/hamming_distance/common.h +287 -10
  297. data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx2.cpp +15 -0
  298. data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx512.cpp +15 -0
  299. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx2.h +142 -0
  300. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx512.h +234 -0
  301. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-generic.h +368 -0
  302. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-neon.h +322 -0
  303. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-rvv.h +39 -0
  304. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer.h +146 -0
  305. data/vendor/faiss/faiss/utils/hamming_distance/hamming_impl.h +481 -0
  306. data/vendor/faiss/faiss/utils/hamming_distance/hamming_neon.cpp +15 -0
  307. data/vendor/faiss/faiss/utils/hamming_distance/hamming_rvv.cpp +15 -0
  308. data/vendor/faiss/faiss/utils/partitioning.cpp +66 -987
  309. data/vendor/faiss/faiss/utils/partitioning.h +31 -0
  310. data/vendor/faiss/faiss/utils/popcount.h +29 -0
  311. data/vendor/faiss/faiss/utils/pq_code_distance.h +2 -2
  312. data/vendor/faiss/faiss/utils/prefetch.h +2 -2
  313. data/vendor/faiss/faiss/utils/quantize_lut.cpp +30 -30
  314. data/vendor/faiss/faiss/utils/quantize_lut.h +1 -1
  315. data/vendor/faiss/faiss/utils/rabitq_simd.h +57 -536
  316. data/vendor/faiss/faiss/utils/random.cpp +6 -6
  317. data/vendor/faiss/faiss/utils/simd_impl/IVFFlatScanner-inl.h +51 -0
  318. data/vendor/faiss/faiss/utils/simd_impl/distances_aarch64.cpp +5 -1
  319. data/vendor/faiss/faiss/utils/simd_impl/distances_arm_sve.cpp +213 -4
  320. data/vendor/faiss/faiss/utils/simd_impl/distances_autovec-inl.h +163 -10
  321. data/vendor/faiss/faiss/utils/simd_impl/distances_avx2.cpp +250 -4
  322. data/vendor/faiss/faiss/utils/simd_impl/distances_avx512.cpp +7 -4
  323. data/vendor/faiss/faiss/utils/simd_impl/distances_rvv.cpp +189 -0
  324. data/vendor/faiss/faiss/utils/simd_impl/distances_simdlib256.h +195 -0
  325. data/vendor/faiss/faiss/utils/simd_impl/distances_sse-inl.h +2 -1
  326. data/vendor/faiss/faiss/utils/{distances_fused/simdlib_based.h → simd_impl/exhaustive_L2sqr_blas_cmax.h} +5 -10
  327. data/vendor/faiss/faiss/utils/simd_impl/hamming_impl.h +481 -0
  328. data/vendor/faiss/faiss/utils/simd_impl/partitioning_avx2.cpp +14 -0
  329. data/vendor/faiss/faiss/utils/simd_impl/partitioning_neon.cpp +14 -0
  330. data/vendor/faiss/faiss/utils/simd_impl/partitioning_simdlib256.h +1085 -0
  331. data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx2.cpp +355 -0
  332. data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx512.cpp +477 -0
  333. data/vendor/faiss/faiss/utils/simd_impl/rabitq_neon.cpp +55 -0
  334. data/vendor/faiss/faiss/utils/simd_impl/rabitq_rvv.cpp +55 -0
  335. data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_dispatch.h +32 -0
  336. data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels.h +43 -0
  337. data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels_avx2.cpp +57 -0
  338. data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels_avx512.cpp +45 -0
  339. data/vendor/faiss/faiss/utils/simd_levels.cpp +17 -5
  340. data/vendor/faiss/faiss/utils/simd_levels.h +93 -1
  341. data/vendor/faiss/faiss/utils/sorting.cpp +48 -36
  342. data/vendor/faiss/faiss/utils/utils.cpp +5 -5
  343. data/vendor/faiss/faiss/utils/utils.h +3 -3
  344. metadata +119 -34
  345. data/vendor/faiss/faiss/impl/RaBitQStats.cpp +0 -29
  346. data/vendor/faiss/faiss/impl/RaBitQStats.h +0 -56
  347. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +0 -224
  348. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +0 -230
  349. data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +0 -84
  350. data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +0 -196
  351. data/vendor/faiss/faiss/utils/approx_topk/mode.h +0 -34
  352. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +0 -36
  353. data/vendor/faiss/faiss/utils/extra_distances-inl.h +0 -235
  354. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +0 -462
  355. data/vendor/faiss/faiss/utils/hamming_distance/avx512-inl.h +0 -490
  356. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +0 -449
  357. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +0 -87
  358. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +0 -524
  359. data/vendor/faiss/faiss/utils/simdlib.h +0 -42
  360. data/vendor/faiss/faiss/utils/simdlib_avx512.h +0 -365
  361. data/ext/faiss/{utils_rb.h → utils.h} +0 -0
@@ -13,22 +13,19 @@
13
13
  #include <cstddef>
14
14
  #include <cstdio>
15
15
  #include <cstring>
16
+ #include <vector>
16
17
 
17
18
  #include <omp.h>
18
19
 
19
- #ifdef __AVX2__
20
- #include <immintrin.h>
21
- #elif defined(__ARM_FEATURE_SVE)
22
- #include <arm_sve.h>
23
- #endif
24
-
25
20
  #include <faiss/impl/AuxIndexStructures.h>
26
21
  #include <faiss/impl/FaissAssert.h>
27
22
  #include <faiss/impl/IDSelector.h>
28
23
  #include <faiss/impl/ResultHandler.h>
29
24
 
25
+ #include <faiss/impl/simd_dispatch.h>
30
26
  #include <faiss/utils/distances_dispatch.h>
31
27
  #include <faiss/utils/distances_fused/distances_fused.h>
28
+ #include <faiss/utils/simd_impl/exhaustive_L2sqr_blas_cmax.h>
32
29
 
33
30
  #ifndef FINTEGER
34
31
  #define FINTEGER long
@@ -172,6 +169,30 @@ int fvec_madd_and_argmin(
172
169
  return fvec_madd_and_argmin_dispatch(n, a, bf, b, c);
173
170
  }
174
171
 
172
+ void fvec_sub(size_t d, const float* a, const float* b, float* c) {
173
+ fvec_sub_dispatch(d, a, b, c);
174
+ }
175
+
176
+ void fvec_add(size_t d, const float* a, const float* b, float* c) {
177
+ fvec_add_dispatch(d, a, b, c);
178
+ }
179
+
180
+ void fvec_add(size_t d, const float* a, float b, float* c) {
181
+ fvec_add_scalar_dispatch(d, a, b, c);
182
+ }
183
+
184
+ void compute_PQ_dis_tables_dsub2(
185
+ size_t d,
186
+ size_t ksub,
187
+ const float* all_centroids,
188
+ size_t nx,
189
+ const float* x,
190
+ bool is_inner_product,
191
+ float* dis_tables) {
192
+ compute_PQ_dis_tables_dsub2_dispatch(
193
+ d, ksub, all_centroids, nx, x, is_inner_product, dis_tables);
194
+ }
195
+
175
196
  /***************************************************************************
176
197
  * Matrix/vector ops
177
198
  ***************************************************************************/
@@ -182,10 +203,12 @@ void fvec_norms_L2(
182
203
  const float* __restrict x,
183
204
  size_t d,
184
205
  size_t nx) {
206
+ with_simd_level([&]<SIMDLevel SL>() {
185
207
  #pragma omp parallel for if (nx > 10000)
186
- for (int64_t i = 0; i < nx; i++) {
187
- nr[i] = sqrtf(fvec_norm_L2sqr_dispatch(x + i * d, d));
188
- }
208
+ for (int64_t i = 0; i < static_cast<int64_t>(nx); i++) {
209
+ nr[i] = sqrtf(fvec_norm_L2sqr<SL>(x + i * d, d));
210
+ }
211
+ });
189
212
  }
190
213
 
191
214
  void fvec_norms_L2sqr(
@@ -193,10 +216,12 @@ void fvec_norms_L2sqr(
193
216
  const float* __restrict x,
194
217
  size_t d,
195
218
  size_t nx) {
219
+ with_simd_level([&]<SIMDLevel SL>() {
196
220
  #pragma omp parallel for if (nx > 10000)
197
- for (int64_t i = 0; i < nx; i++) {
198
- nr[i] = fvec_norm_L2sqr_dispatch(x + i * d, d);
199
- }
221
+ for (int64_t i = 0; i < static_cast<int64_t>(nx); i++) {
222
+ nr[i] = fvec_norm_L2sqr<SL>(x + i * d, d);
223
+ }
224
+ });
200
225
  }
201
226
 
202
227
  // The following is a workaround to a problem
@@ -210,29 +235,35 @@ void fvec_norms_L2sqr(
210
235
  // The workaround below is explicitly branching
211
236
  // off to a codepath without omp.
212
237
 
213
- #define FVEC_RENORM_L2_IMPL \
214
- float* __restrict xi = x + i * d; \
215
- \
216
- float nr = fvec_norm_L2sqr_dispatch(xi, d); \
217
- \
218
- if (nr > 0) { \
219
- size_t j; \
220
- const float inv_nr = 1.0 / sqrtf(nr); \
221
- for (j = 0; j < d; j++) \
222
- xi[j] *= inv_nr; \
223
- }
224
-
225
238
  void fvec_renorm_L2_noomp(size_t d, size_t nx, float* __restrict x) {
226
- for (int64_t i = 0; i < nx; i++) {
227
- FVEC_RENORM_L2_IMPL
228
- }
239
+ with_simd_level([&]<SIMDLevel SL>() {
240
+ for (int64_t i = 0; i < static_cast<int64_t>(nx); i++) {
241
+ float* __restrict xi = x + i * d;
242
+ float nr = fvec_norm_L2sqr<SL>(xi, d);
243
+ if (nr > 0) {
244
+ const float inv_nr = 1.0 / sqrtf(nr);
245
+ for (size_t j = 0; j < d; j++) {
246
+ xi[j] *= inv_nr;
247
+ }
248
+ }
249
+ }
250
+ });
229
251
  }
230
252
 
231
253
  void fvec_renorm_L2_omp(size_t d, size_t nx, float* __restrict x) {
254
+ with_simd_level([&]<SIMDLevel SL>() {
232
255
  #pragma omp parallel for if (nx > 10000)
233
- for (int64_t i = 0; i < nx; i++) {
234
- FVEC_RENORM_L2_IMPL
235
- }
256
+ for (int64_t i = 0; i < static_cast<int64_t>(nx); i++) {
257
+ float* __restrict xi = x + i * d;
258
+ float nr = fvec_norm_L2sqr<SL>(xi, d);
259
+ if (nr > 0) {
260
+ const float inv_nr = 1.0 / sqrtf(nr);
261
+ for (size_t j = 0; j < d; j++) {
262
+ xi[j] *= inv_nr;
263
+ }
264
+ }
265
+ }
266
+ });
236
267
  }
237
268
 
238
269
  void fvec_renorm_L2(size_t d, size_t nx, float* __restrict x) {
@@ -265,22 +296,24 @@ void exhaustive_inner_product_seq(
265
296
  #pragma omp parallel num_threads(nt)
266
297
  {
267
298
  SingleResultHandler resi(res);
299
+ with_simd_level([&]<SIMDLevel SL>() {
268
300
  #pragma omp for
269
- for (int64_t i = 0; i < nx; i++) {
270
- const float* x_i = x + i * d;
271
- const float* y_j = y;
301
+ for (int64_t i = 0; i < static_cast<int64_t>(nx); i++) {
302
+ const float* x_i = x + i * d;
303
+ const float* y_j = y;
272
304
 
273
- resi.begin(i);
305
+ resi.begin(i);
274
306
 
275
- for (size_t j = 0; j < ny; j++, y_j += d) {
276
- if (!res.is_in_selection(j)) {
277
- continue;
307
+ for (size_t j = 0; j < ny; j++, y_j += d) {
308
+ if (!res.is_in_selection(j)) {
309
+ continue;
310
+ }
311
+ float ip = fvec_inner_product<SL>(x_i, y_j, d);
312
+ resi.add_result(ip, j);
278
313
  }
279
- float ip = fvec_inner_product_dispatch(x_i, y_j, d);
280
- resi.add_result(ip, j);
314
+ resi.end();
281
315
  }
282
- resi.end();
283
- }
316
+ });
284
317
  }
285
318
  }
286
319
 
@@ -299,20 +332,22 @@ void exhaustive_L2sqr_seq(
299
332
  #pragma omp parallel num_threads(nt)
300
333
  {
301
334
  SingleResultHandler resi(res);
335
+ with_simd_level([&]<SIMDLevel SL>() {
302
336
  #pragma omp for
303
- for (int64_t i = 0; i < nx; i++) {
304
- const float* x_i = x + i * d;
305
- const float* y_j = y;
306
- resi.begin(i);
307
- for (size_t j = 0; j < ny; j++, y_j += d) {
308
- if (!res.is_in_selection(j)) {
309
- continue;
337
+ for (int64_t i = 0; i < static_cast<int64_t>(nx); i++) {
338
+ const float* x_i = x + i * d;
339
+ const float* y_j = y;
340
+ resi.begin(i);
341
+ for (size_t j = 0; j < ny; j++, y_j += d) {
342
+ if (!res.is_in_selection(j)) {
343
+ continue;
344
+ }
345
+ float disij = fvec_L2sqr<SL>(x_i, y_j, d);
346
+ resi.add_result(disij, j);
310
347
  }
311
- float disij = fvec_L2sqr_dispatch(x_i, y_j, d);
312
- resi.add_result(disij, j);
348
+ resi.end();
313
349
  }
314
- resi.end();
315
- }
350
+ });
316
351
  }
317
352
  }
318
353
 
@@ -438,7 +473,7 @@ void exhaustive_L2sqr_blas_default_impl(
438
473
  ip_block.get(),
439
474
  &nyi);
440
475
  }
441
- for (int64_t i = i0; i < i1; i++) {
476
+ for (size_t i = i0; i < i1; i++) {
442
477
  float* ip_line = ip_block.get() + (i - i0) * (j1 - j0);
443
478
 
444
479
  for (size_t j = j0; j < j1; j++) {
@@ -474,396 +509,12 @@ void exhaustive_L2sqr_blas(
474
509
  size_t ny,
475
510
  BlockResultHandler& res,
476
511
  const float* y_norms = nullptr) {
477
- exhaustive_L2sqr_blas_default_impl(x, y, d, nx, ny, res);
512
+ exhaustive_L2sqr_blas_default_impl(x, y, d, nx, ny, res, y_norms);
478
513
  }
479
514
 
480
- #ifdef __AVX2__
481
- void exhaustive_L2sqr_blas_cmax_avx2(
482
- const float* x,
483
- const float* y,
484
- size_t d,
485
- size_t nx,
486
- size_t ny,
487
- Top1BlockResultHandler<CMax<float, int64_t>>& res,
488
- const float* y_norms) {
489
- // BLAS does not like empty matrices
490
- if (nx == 0 || ny == 0) {
491
- return;
492
- }
493
-
494
- /* block sizes */
495
- const size_t bs_x = distance_compute_blas_query_bs;
496
- const size_t bs_y = distance_compute_blas_database_bs;
497
- // const size_t bs_x = 16, bs_y = 16;
498
- std::unique_ptr<float[]> ip_block(new float[bs_x * bs_y]);
499
- std::unique_ptr<float[]> x_norms(new float[nx]);
500
- std::unique_ptr<float[]> del2;
501
-
502
- fvec_norms_L2sqr(x_norms.get(), x, d, nx);
503
-
504
- if (!y_norms) {
505
- float* y_norms2 = new float[ny];
506
- del2.reset(y_norms2);
507
- fvec_norms_L2sqr(y_norms2, y, d, ny);
508
- y_norms = y_norms2;
509
- }
510
-
511
- for (size_t i0 = 0; i0 < nx; i0 += bs_x) {
512
- size_t i1 = i0 + bs_x;
513
- if (i1 > nx) {
514
- i1 = nx;
515
- }
516
-
517
- res.begin_multiple(i0, i1);
518
-
519
- for (size_t j0 = 0; j0 < ny; j0 += bs_y) {
520
- size_t j1 = j0 + bs_y;
521
- if (j1 > ny) {
522
- j1 = ny;
523
- }
524
- /* compute the actual dot products */
525
- {
526
- float one = 1, zero = 0;
527
- FINTEGER nyi = j1 - j0, nxi = i1 - i0, di = d;
528
- sgemm_("Transpose",
529
- "Not transpose",
530
- &nyi,
531
- &nxi,
532
- &di,
533
- &one,
534
- y + j0 * d,
535
- &di,
536
- x + i0 * d,
537
- &di,
538
- &zero,
539
- ip_block.get(),
540
- &nyi);
541
- }
542
- for (int64_t i = i0; i < i1; i++) {
543
- float* ip_line = ip_block.get() + (i - i0) * (j1 - j0);
544
-
545
- _mm_prefetch((const char*)ip_line, _MM_HINT_NTA);
546
- _mm_prefetch((const char*)(ip_line + 16), _MM_HINT_NTA);
547
-
548
- // constant
549
- const __m256 mul_minus2 = _mm256_set1_ps(-2);
550
-
551
- // Track 8 min distances + 8 min indices.
552
- // All the distances tracked do not take x_norms[i]
553
- // into account in order to get rid of extra
554
- // _mm256_add_ps(x_norms[i], ...) instructions
555
- // is distance computations.
556
- __m256 min_distances =
557
- _mm256_set1_ps(res.dis_tab[i] - x_norms[i]);
558
-
559
- // these indices are local and are relative to j0.
560
- // so, value 0 means j0.
561
- __m256i min_indices = _mm256_set1_epi32(0);
562
-
563
- __m256i current_indices =
564
- _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
565
- const __m256i indices_delta = _mm256_set1_epi32(8);
566
-
567
- // current j index
568
- size_t idx_j = 0;
569
- size_t count = j1 - j0;
570
-
571
- // process 16 elements per loop
572
- for (; idx_j < (count / 16) * 16; idx_j += 16, ip_line += 16) {
573
- _mm_prefetch((const char*)(ip_line + 32), _MM_HINT_NTA);
574
- _mm_prefetch((const char*)(ip_line + 48), _MM_HINT_NTA);
575
-
576
- // load values for norms
577
- const __m256 y_norm_0 =
578
- _mm256_loadu_ps(y_norms + idx_j + j0 + 0);
579
- const __m256 y_norm_1 =
580
- _mm256_loadu_ps(y_norms + idx_j + j0 + 8);
581
-
582
- // load values for dot products
583
- const __m256 ip_0 = _mm256_loadu_ps(ip_line + 0);
584
- const __m256 ip_1 = _mm256_loadu_ps(ip_line + 8);
585
-
586
- // compute dis = y_norm[j] - 2 * dot(x_norm[i], y_norm[j]).
587
- // x_norm[i] was dropped off because it is a constant for a
588
- // given i. We'll deal with it later.
589
- __m256 distances_0 =
590
- _mm256_fmadd_ps(ip_0, mul_minus2, y_norm_0);
591
- __m256 distances_1 =
592
- _mm256_fmadd_ps(ip_1, mul_minus2, y_norm_1);
593
-
594
- // compare the new distances to the min distances
595
- // for each of the first group of 8 AVX2 components.
596
- const __m256 comparison_0 = _mm256_cmp_ps(
597
- min_distances, distances_0, _CMP_LE_OS);
598
-
599
- // update min distances and indices with closest vectors if
600
- // needed.
601
- min_distances = _mm256_blendv_ps(
602
- distances_0, min_distances, comparison_0);
603
- min_indices = _mm256_castps_si256(_mm256_blendv_ps(
604
- _mm256_castsi256_ps(current_indices),
605
- _mm256_castsi256_ps(min_indices),
606
- comparison_0));
607
- current_indices =
608
- _mm256_add_epi32(current_indices, indices_delta);
609
-
610
- // compare the new distances to the min distances
611
- // for each of the second group of 8 AVX2 components.
612
- const __m256 comparison_1 = _mm256_cmp_ps(
613
- min_distances, distances_1, _CMP_LE_OS);
614
-
615
- // update min distances and indices with closest vectors if
616
- // needed.
617
- min_distances = _mm256_blendv_ps(
618
- distances_1, min_distances, comparison_1);
619
- min_indices = _mm256_castps_si256(_mm256_blendv_ps(
620
- _mm256_castsi256_ps(current_indices),
621
- _mm256_castsi256_ps(min_indices),
622
- comparison_1));
623
- current_indices =
624
- _mm256_add_epi32(current_indices, indices_delta);
625
- }
626
-
627
- // dump values and find the minimum distance / minimum index
628
- float min_distances_scalar[8];
629
- uint32_t min_indices_scalar[8];
630
- _mm256_storeu_ps(min_distances_scalar, min_distances);
631
- _mm256_storeu_si256(
632
- (__m256i*)(min_indices_scalar), min_indices);
633
-
634
- float current_min_distance = res.dis_tab[i];
635
- uint32_t current_min_index = res.ids_tab[i];
636
-
637
- // This unusual comparison is needed to maintain the behavior
638
- // of the original implementation: if two indices are
639
- // represented with equal distance values, then
640
- // the index with the min value is returned.
641
- for (size_t jv = 0; jv < 8; jv++) {
642
- // add missing x_norms[i]
643
- float distance_candidate =
644
- min_distances_scalar[jv] + x_norms[i];
645
-
646
- // negative values can occur for identical vectors
647
- // due to roundoff errors.
648
- if (distance_candidate < 0) {
649
- distance_candidate = 0;
650
- }
651
-
652
- int64_t index_candidate = min_indices_scalar[jv] + j0;
653
-
654
- if (current_min_distance > distance_candidate) {
655
- current_min_distance = distance_candidate;
656
- current_min_index = index_candidate;
657
- } else if (
658
- current_min_distance == distance_candidate &&
659
- current_min_index > index_candidate) {
660
- current_min_index = index_candidate;
661
- }
662
- }
663
-
664
- // process leftovers
665
- for (; idx_j < count; idx_j++, ip_line++) {
666
- float ip = *ip_line;
667
- float dis = x_norms[i] + y_norms[idx_j + j0] - 2 * ip;
668
- // negative values can occur for identical vectors
669
- // due to roundoff errors.
670
- if (dis < 0) {
671
- dis = 0;
672
- }
673
-
674
- if (current_min_distance > dis) {
675
- current_min_distance = dis;
676
- current_min_index = idx_j + j0;
677
- }
678
- }
679
-
680
- //
681
- res.add_result(i, current_min_distance, current_min_index);
682
- }
683
- }
684
- // Does nothing for SingleBestResultHandler, but
685
- // keeping the call for the consistency.
686
- res.end_multiple();
687
- InterruptCallback::check();
688
- }
689
- }
690
- #elif defined(__ARM_FEATURE_SVE)
691
- void exhaustive_L2sqr_blas_cmax_sve(
692
- const float* x,
693
- const float* y,
694
- size_t d,
695
- size_t nx,
696
- size_t ny,
697
- Top1BlockResultHandler<CMax<float, int64_t>>& res,
698
- const float* y_norms) {
699
- // BLAS does not like empty matrices
700
- if (nx == 0 || ny == 0)
701
- return;
702
-
703
- /* block sizes */
704
- const size_t bs_x = distance_compute_blas_query_bs;
705
- const size_t bs_y = distance_compute_blas_database_bs;
706
- // const size_t bs_x = 16, bs_y = 16;
707
- std::unique_ptr<float[]> ip_block(new float[bs_x * bs_y]);
708
- std::unique_ptr<float[]> x_norms(new float[nx]);
709
- std::unique_ptr<float[]> del2;
710
-
711
- fvec_norms_L2sqr(x_norms.get(), x, d, nx);
712
-
713
- const size_t lanes = svcntw();
714
-
715
- if (!y_norms) {
716
- float* y_norms2 = new float[ny];
717
- del2.reset(y_norms2);
718
- fvec_norms_L2sqr(y_norms2, y, d, ny);
719
- y_norms = y_norms2;
720
- }
721
-
722
- for (size_t i0 = 0; i0 < nx; i0 += bs_x) {
723
- size_t i1 = i0 + bs_x;
724
- if (i1 > nx)
725
- i1 = nx;
726
-
727
- res.begin_multiple(i0, i1);
728
-
729
- for (size_t j0 = 0; j0 < ny; j0 += bs_y) {
730
- size_t j1 = j0 + bs_y;
731
- if (j1 > ny)
732
- j1 = ny;
733
- /* compute the actual dot products */
734
- {
735
- float one = 1, zero = 0;
736
- FINTEGER nyi = j1 - j0, nxi = i1 - i0, di = d;
737
- sgemm_("Transpose",
738
- "Not transpose",
739
- &nyi,
740
- &nxi,
741
- &di,
742
- &one,
743
- y + j0 * d,
744
- &di,
745
- x + i0 * d,
746
- &di,
747
- &zero,
748
- ip_block.get(),
749
- &nyi);
750
- }
751
- for (int64_t i = i0; i < i1; i++) {
752
- const size_t count = j1 - j0;
753
- float* ip_line = ip_block.get() + (i - i0) * count;
754
-
755
- svprfw(svwhilelt_b32_u64(0, count), ip_line, SV_PLDL1KEEP);
756
- svprfw(svwhilelt_b32_u64(lanes, count),
757
- ip_line + lanes,
758
- SV_PLDL1KEEP);
759
-
760
- // Track lanes min distances + lanes min indices.
761
- // All the distances tracked do not take x_norms[i]
762
- // into account in order to get rid of extra
763
- // vaddq_f32(x_norms[i], ...) instructions
764
- // is distance computations.
765
- auto min_distances = svdup_n_f32(res.dis_tab[i] - x_norms[i]);
766
-
767
- // these indices are local and are relative to j0.
768
- // so, value 0 means j0.
769
- auto min_indices = svdup_n_u32(0u);
770
-
771
- auto current_indices = svindex_u32(0u, 1u);
772
-
773
- // process lanes * 2 elements per loop
774
- for (size_t idx_j = 0; idx_j < count;
775
- idx_j += lanes * 2, ip_line += lanes * 2) {
776
- svprfw(svwhilelt_b32_u64(idx_j + lanes * 2, count),
777
- ip_line + lanes * 2,
778
- SV_PLDL1KEEP);
779
- svprfw(svwhilelt_b32_u64(idx_j + lanes * 3, count),
780
- ip_line + lanes * 3,
781
- SV_PLDL1KEEP);
782
-
783
- // mask
784
- const auto mask_0 = svwhilelt_b32_u64(idx_j, count);
785
- const auto mask_1 = svwhilelt_b32_u64(idx_j + lanes, count);
786
-
787
- // load values for norms
788
- const auto y_norm_0 =
789
- svld1_f32(mask_0, y_norms + idx_j + j0 + 0);
790
- const auto y_norm_1 =
791
- svld1_f32(mask_1, y_norms + idx_j + j0 + lanes);
792
-
793
- // load values for dot products
794
- const auto ip_0 = svld1_f32(mask_0, ip_line + 0);
795
- const auto ip_1 = svld1_f32(mask_1, ip_line + lanes);
796
-
797
- // compute dis = y_norm[j] - 2 * dot(x_norm[i], y_norm[j]).
798
- // x_norm[i] was dropped off because it is a constant for a
799
- // given i. We'll deal with it later.
800
- const auto distances_0 =
801
- svmla_n_f32_z(mask_0, y_norm_0, ip_0, -2.f);
802
- const auto distances_1 =
803
- svmla_n_f32_z(mask_1, y_norm_1, ip_1, -2.f);
804
-
805
- // compare the new distances to the min distances
806
- // for each of the first group of 4 ARM SIMD components.
807
- auto comparison =
808
- svcmpgt_f32(mask_0, min_distances, distances_0);
809
-
810
- // update min distances and indices with closest vectors if
811
- // needed.
812
- min_distances =
813
- svsel_f32(comparison, distances_0, min_distances);
814
- min_indices =
815
- svsel_u32(comparison, current_indices, min_indices);
816
- current_indices = svadd_n_u32_x(
817
- mask_0,
818
- current_indices,
819
- static_cast<uint32_t>(lanes));
820
-
821
- // compare the new distances to the min distances
822
- // for each of the second group of 4 ARM SIMD components.
823
- comparison =
824
- svcmpgt_f32(mask_1, min_distances, distances_1);
825
-
826
- // update min distances and indices with closest vectors if
827
- // needed.
828
- min_distances =
829
- svsel_f32(comparison, distances_1, min_distances);
830
- min_indices =
831
- svsel_u32(comparison, current_indices, min_indices);
832
- current_indices = svadd_n_u32_x(
833
- mask_1,
834
- current_indices,
835
- static_cast<uint32_t>(lanes));
836
- }
515
+ } // anonymous namespace
837
516
 
838
- // add missing x_norms[i]
839
- // negative values can occur for identical vectors
840
- // due to roundoff errors.
841
- auto mask = svwhilelt_b32_u64(0, count);
842
- min_distances = svadd_n_f32_z(
843
- svcmpge_n_f32(mask, min_distances, -x_norms[i]),
844
- min_distances,
845
- x_norms[i]);
846
- min_indices = svadd_n_u32_x(
847
- mask, min_indices, static_cast<uint32_t>(j0));
848
- mask = svcmple_n_f32(mask, min_distances, res.dis_tab[i]);
849
- if (svcntp_b32(svptrue_b32(), mask) == 0)
850
- res.add_result(i, res.dis_tab[i], res.ids_tab[i]);
851
- else {
852
- const auto min_distance = svminv_f32(mask, min_distances);
853
- const auto min_index = svminv_u32(
854
- svcmpeq_n_f32(mask, min_distances, min_distance),
855
- min_indices);
856
- res.add_result(i, min_distance, min_index);
857
- }
858
- }
859
- }
860
- // Does nothing for SingleBestResultHandler, but
861
- // keeping the call for the consistency.
862
- res.end_multiple();
863
- InterruptCallback::check();
864
- }
865
- }
866
- #endif
517
+ namespace {
867
518
 
868
519
  // an override if only a single closest point is needed
869
520
  template <>
@@ -875,43 +526,20 @@ void exhaustive_L2sqr_blas<Top1BlockResultHandler<CMax<float, int64_t>>>(
875
526
  size_t ny,
876
527
  Top1BlockResultHandler<CMax<float, int64_t>>& res,
877
528
  const float* y_norms) {
878
- #if defined(__AVX2__)
879
529
  // use a faster fused kernel if available
880
530
  if (exhaustive_L2sqr_fused_cmax(x, y, d, nx, ny, res, y_norms)) {
881
- // the kernel is available and it is complete, we're done.
882
531
  return;
883
532
  }
884
533
 
885
- // run the specialized AVX2 implementation
886
- exhaustive_L2sqr_blas_cmax_avx2(x, y, d, nx, ny, res, y_norms);
887
-
888
- #elif defined(__ARM_FEATURE_SVE)
889
- // use a faster fused kernel if available
890
- if (exhaustive_L2sqr_fused_cmax(x, y, d, nx, ny, res, y_norms)) {
891
- // the kernel is available and it is complete, we're done.
892
- return;
893
- }
894
-
895
- // run the specialized SVE implementation
896
- exhaustive_L2sqr_blas_cmax_sve(x, y, d, nx, ny, res, y_norms);
897
-
898
- #elif defined(__aarch64__)
899
- // use a faster fused kernel if available
900
- if (exhaustive_L2sqr_fused_cmax(x, y, d, nx, ny, res, y_norms)) {
901
- // the kernel is available and it is complete, we're done.
902
- return;
903
- }
904
-
905
- // run the default implementation
906
- exhaustive_L2sqr_blas_default_impl<
907
- Top1BlockResultHandler<CMax<float, int64_t>>>(
908
- x, y, d, nx, ny, res, y_norms);
909
- #else
910
- // run the default implementation
911
- exhaustive_L2sqr_blas_default_impl<
912
- Top1BlockResultHandler<CMax<float, int64_t>>>(
913
- x, y, d, nx, ny, res, y_norms);
914
- #endif
534
+ with_selected_simd_levels<AVAILABLE_SIMD_LEVELS_A2>([&]<SIMDLevel SL>() {
535
+ if constexpr (SL == SIMDLevel::AVX2 || SL == SIMDLevel::ARM_SVE) {
536
+ exhaustive_L2sqr_blas_cmax<SL>(x, y, d, nx, ny, res, y_norms);
537
+ } else {
538
+ exhaustive_L2sqr_blas_default_impl<
539
+ Top1BlockResultHandler<CMax<float, int64_t>>>(
540
+ x, y, d, nx, ny, res, y_norms);
541
+ }
542
+ });
915
543
  }
916
544
 
917
545
  struct Run_search_inner_product {
@@ -923,7 +551,8 @@ struct Run_search_inner_product {
923
551
  size_t d,
924
552
  size_t nx,
925
553
  size_t ny) {
926
- if (res.sel || nx < distance_compute_blas_threshold) {
554
+ if (res.sel ||
555
+ nx * d < static_cast<size_t>(distance_compute_blas_threshold)) {
927
556
  exhaustive_inner_product_seq(x, y, d, nx, ny, res);
928
557
  } else {
929
558
  exhaustive_inner_product_blas(x, y, d, nx, ny, res);
@@ -941,7 +570,8 @@ struct Run_search_L2sqr {
941
570
  size_t nx,
942
571
  size_t ny,
943
572
  const float* y_norm2) {
944
- if (res.sel || nx < distance_compute_blas_threshold) {
573
+ if (res.sel ||
574
+ nx * d < static_cast<size_t>(distance_compute_blas_threshold)) {
945
575
  exhaustive_L2sqr_seq(x, y, d, nx, ny, res);
946
576
  } else {
947
577
  exhaustive_L2sqr_blas(x, y, d, nx, ny, res, y_norm2);
@@ -955,11 +585,174 @@ struct Run_search_L2sqr {
955
585
  * KNN driver functions
956
586
  *******************************************************/
957
587
 
958
- int distance_compute_blas_threshold = 20;
588
+ int distance_compute_blas_threshold = 128000;
959
589
  int distance_compute_blas_query_bs = 4096;
960
590
  int distance_compute_blas_database_bs = 1024;
961
591
  int distance_compute_min_k_reservoir = 100;
962
592
 
593
+ // Database-parallel KNN: parallelizes over database segments instead of
594
+ // queries, for the case where nx < nthreads and the database is large.
595
+ static constexpr size_t kDbParallelMinVectors = 10000;
596
+
597
+ template <class C>
598
+ static void knn_db_parallel_impl(
599
+ const float* x,
600
+ const float* y,
601
+ size_t d,
602
+ size_t nx,
603
+ size_t ny,
604
+ size_t k,
605
+ float* vals,
606
+ int64_t* ids,
607
+ const float* y_norms) {
608
+ using T = typename C::T;
609
+ using TI = typename C::TI;
610
+
611
+ int nt = omp_get_max_threads();
612
+ const size_t bs_y = distance_compute_blas_database_bs;
613
+
614
+ // Per-thread result heaps: nt threads x nx queries x k results
615
+ std::vector<T> all_dis(static_cast<size_t>(nt) * nx * k);
616
+ std::vector<TI> all_ids(static_cast<size_t>(nt) * nx * k);
617
+
618
+ std::unique_ptr<float[]> x_norms_storage;
619
+ std::unique_ptr<float[]> y_norms_storage;
620
+ const float* x_norms = nullptr;
621
+ // C::is_max corresponds to L2 (CMax), not IP (CMin)
622
+ if constexpr (C::is_max) {
623
+ x_norms_storage.reset(new float[nx]);
624
+ fvec_norms_L2sqr(x_norms_storage.get(), x, d, nx);
625
+ x_norms = x_norms_storage.get();
626
+
627
+ if (!y_norms) {
628
+ y_norms_storage.reset(new float[ny]);
629
+ y_norms = y_norms_storage.get();
630
+ }
631
+ }
632
+
633
+ #pragma omp parallel num_threads(nt)
634
+ {
635
+ int tid = omp_get_thread_num();
636
+ size_t j_begin = static_cast<size_t>(tid) * ny / nt;
637
+ size_t j_end = static_cast<size_t>(tid + 1) * ny / nt;
638
+ size_t local_ny = j_end - j_begin;
639
+
640
+ // Compute y_norms for this thread's segment (cache locality)
641
+ if constexpr (C::is_max) {
642
+ if (y_norms_storage && local_ny > 0) {
643
+ fvec_norms_L2sqr(
644
+ y_norms_storage.get() + j_begin,
645
+ y + j_begin * d,
646
+ d,
647
+ local_ny);
648
+ }
649
+ }
650
+
651
+ T* my_dis = all_dis.data() + tid * nx * k;
652
+ TI* my_ids = all_ids.data() + tid * nx * k;
653
+
654
+ // Each thread initializes its own heaps
655
+ for (size_t i = 0; i < nx; i++) {
656
+ heap_heapify<C>(k, my_dis + i * k, my_ids + i * k);
657
+ }
658
+
659
+ if (local_ny > 0) {
660
+ size_t max_block = std::min(bs_y, local_ny);
661
+ std::unique_ptr<float[]> ip_block(new float[nx * max_block]);
662
+
663
+ for (size_t jj0 = 0; jj0 < local_ny; jj0 += bs_y) {
664
+ size_t jj1 = std::min(jj0 + bs_y, local_ny);
665
+ size_t block_ny = jj1 - jj0;
666
+
667
+ {
668
+ float one = 1, zero = 0;
669
+ FINTEGER nyi = static_cast<FINTEGER>(block_ny);
670
+ FINTEGER nxi = static_cast<FINTEGER>(nx);
671
+ FINTEGER di = static_cast<FINTEGER>(d);
672
+ sgemm_("Transpose",
673
+ "Not transpose",
674
+ &nyi,
675
+ &nxi,
676
+ &di,
677
+ &one,
678
+ y + (j_begin + jj0) * d,
679
+ &di,
680
+ x,
681
+ &di,
682
+ &zero,
683
+ ip_block.get(),
684
+ &nyi);
685
+ }
686
+
687
+ for (size_t i = 0; i < nx; i++) {
688
+ T* heap_dis = my_dis + i * k;
689
+ TI* heap_ids = my_ids + i * k;
690
+ const float* ip_line = ip_block.get() + i * block_ny;
691
+ T thresh = heap_dis[0];
692
+
693
+ for (size_t jj = 0; jj < block_ny; jj++) {
694
+ size_t global_j = j_begin + jj0 + jj;
695
+ float ip = ip_line[jj];
696
+ T dis;
697
+
698
+ if constexpr (C::is_max) {
699
+ dis = x_norms[i] + y_norms[global_j] - 2 * ip;
700
+ if (dis < 0) {
701
+ dis = 0;
702
+ }
703
+ } else {
704
+ dis = ip;
705
+ }
706
+
707
+ if (C::cmp(thresh, dis)) {
708
+ heap_replace_top<C>(
709
+ k, heap_dis, heap_ids, dis, global_j);
710
+ thresh = heap_dis[0];
711
+ }
712
+ }
713
+ }
714
+ }
715
+ }
716
+ }
717
+
718
+ // Merge per-thread heaps into output, parallelized over queries
719
+ #pragma omp parallel for
720
+ for (int64_t i = 0; i < static_cast<int64_t>(nx); i++) {
721
+ heap_heapify<C>(k, vals + i * k, ids + i * k);
722
+
723
+ for (int t = 0; t < nt; t++) {
724
+ T* t_dis = all_dis.data() + (t * nx + i) * k;
725
+ TI* t_ids = all_ids.data() + (t * nx + i) * k;
726
+ T* out_dis = vals + i * k;
727
+ TI* out_ids = ids + i * k;
728
+
729
+ for (size_t j = 0; j < k; j++) {
730
+ if (t_ids[j] >= 0 && C::cmp(out_dis[0], t_dis[j])) {
731
+ heap_replace_top<C>(
732
+ k, out_dis, out_ids, t_dis[j], t_ids[j]);
733
+ }
734
+ }
735
+ }
736
+
737
+ heap_reorder<C>(k, vals + i * k, ids + i * k);
738
+ }
739
+ }
740
+
741
+ static bool should_use_db_parallel(
742
+ size_t nx,
743
+ size_t ny,
744
+ const IDSelector* sel) {
745
+ if (sel) {
746
+ return false;
747
+ }
748
+ int nt = omp_get_max_threads();
749
+ size_t min_ny = std::max(
750
+ kDbParallelMinVectors,
751
+ static_cast<size_t>(nt) *
752
+ static_cast<size_t>(distance_compute_blas_database_bs));
753
+ return nt > 1 && nx < static_cast<size_t>(nt) && ny >= min_ny;
754
+ }
755
+
963
756
  void knn_inner_product(
964
757
  const float* x,
965
758
  const float* y,
@@ -984,9 +777,26 @@ void knn_inner_product(
984
777
  return;
985
778
  }
986
779
 
987
- Run_search_inner_product r;
988
- dispatch_knn_ResultHandler(
989
- nx, vals, ids, k, METRIC_INNER_PRODUCT, sel, r, x, y, d, nx, ny);
780
+ if (should_use_db_parallel(nx, ny, sel)) {
781
+ knn_db_parallel_impl<CMin<float, int64_t>>(
782
+ x, y, d, nx, ny, k, vals, ids, nullptr);
783
+ } else {
784
+ Run_search_inner_product r;
785
+ // @lint-ignore CLANGTIDY facebook-hte-NullableDereference
786
+ dispatch_knn_ResultHandler(
787
+ nx,
788
+ vals,
789
+ ids,
790
+ k,
791
+ METRIC_INNER_PRODUCT,
792
+ sel,
793
+ r,
794
+ x,
795
+ y,
796
+ d,
797
+ nx,
798
+ ny);
799
+ }
990
800
 
991
801
  if (imin != 0) {
992
802
  for (size_t i = 0; i < nx * k; i++) {
@@ -1033,9 +843,15 @@ void knn_L2sqr(
1033
843
  return;
1034
844
  }
1035
845
 
1036
- Run_search_L2sqr r;
1037
- dispatch_knn_ResultHandler(
1038
- nx, vals, ids, k, METRIC_L2, sel, r, x, y, d, nx, ny, y_norm2);
846
+ if (should_use_db_parallel(nx, ny, sel)) {
847
+ knn_db_parallel_impl<CMax<float, int64_t>>(
848
+ x, y, d, nx, ny, k, vals, ids, y_norm2);
849
+ } else {
850
+ Run_search_L2sqr r;
851
+ // @lint-ignore CLANGTIDY facebook-hte-NullableDereference
852
+ dispatch_knn_ResultHandler(
853
+ nx, vals, ids, k, METRIC_L2, sel, r, x, y, d, nx, ny, y_norm2);
854
+ }
1039
855
 
1040
856
  if (imin != 0) {
1041
857
  for (size_t i = 0; i < nx * k; i++) {
@@ -1106,19 +922,21 @@ void fvec_inner_products_by_idx(
1106
922
  size_t d,
1107
923
  size_t nx,
1108
924
  size_t ny) {
925
+ with_simd_level([&]<SIMDLevel SL>() {
1109
926
  #pragma omp parallel for
1110
- for (int64_t j = 0; j < nx; j++) {
1111
- const int64_t* __restrict idsj = ids + j * ny;
1112
- const float* xj = x + j * d;
1113
- float* __restrict ipj = ip + j * ny;
1114
- for (size_t i = 0; i < ny; i++) {
1115
- if (idsj[i] < 0) {
1116
- ipj[i] = -INFINITY;
1117
- } else {
1118
- ipj[i] = fvec_inner_product_dispatch(xj, y + d * idsj[i], d);
927
+ for (int64_t j = 0; j < static_cast<int64_t>(nx); j++) {
928
+ const int64_t* __restrict idsj = ids + j * ny;
929
+ const float* xj = x + j * d;
930
+ float* __restrict ipj = ip + j * ny;
931
+ for (size_t i = 0; i < ny; i++) {
932
+ if (idsj[i] < 0) {
933
+ ipj[i] = -INFINITY;
934
+ } else {
935
+ ipj[i] = fvec_inner_product<SL>(xj, y + d * idsj[i], d);
936
+ }
1119
937
  }
1120
938
  }
1121
- }
939
+ });
1122
940
  }
1123
941
 
1124
942
  /* compute the inner product between x and a subset y of ny vectors,
@@ -1131,19 +949,21 @@ void fvec_L2sqr_by_idx(
1131
949
  size_t d,
1132
950
  size_t nx,
1133
951
  size_t ny) {
952
+ with_simd_level([&]<SIMDLevel SL>() {
1134
953
  #pragma omp parallel for
1135
- for (int64_t j = 0; j < nx; j++) {
1136
- const int64_t* __restrict idsj = ids + j * ny;
1137
- const float* xj = x + j * d;
1138
- float* __restrict disj = dis + j * ny;
1139
- for (size_t i = 0; i < ny; i++) {
1140
- if (idsj[i] < 0) {
1141
- disj[i] = INFINITY;
1142
- } else {
1143
- disj[i] = fvec_L2sqr_dispatch(xj, y + d * idsj[i], d);
954
+ for (int64_t j = 0; j < static_cast<int64_t>(nx); j++) {
955
+ const int64_t* __restrict idsj = ids + j * ny;
956
+ const float* xj = x + j * d;
957
+ float* __restrict disj = dis + j * ny;
958
+ for (size_t i = 0; i < ny; i++) {
959
+ if (idsj[i] < 0) {
960
+ disj[i] = INFINITY;
961
+ } else {
962
+ disj[i] = fvec_L2sqr<SL>(xj, y + d * idsj[i], d);
963
+ }
1144
964
  }
1145
965
  }
1146
- }
966
+ });
1147
967
  }
1148
968
 
1149
969
  void pairwise_indexed_L2sqr(
@@ -1154,14 +974,16 @@ void pairwise_indexed_L2sqr(
1154
974
  const float* y,
1155
975
  const int64_t* iy,
1156
976
  float* dis) {
977
+ with_simd_level([&]<SIMDLevel SL>() {
1157
978
  #pragma omp parallel for if (n > 1)
1158
- for (int64_t j = 0; j < n; j++) {
1159
- if (ix[j] >= 0 && iy[j] >= 0) {
1160
- dis[j] = fvec_L2sqr_dispatch(x + d * ix[j], y + d * iy[j], d);
1161
- } else {
1162
- dis[j] = INFINITY;
979
+ for (int64_t j = 0; j < static_cast<int64_t>(n); j++) {
980
+ if (ix[j] >= 0 && iy[j] >= 0) {
981
+ dis[j] = fvec_L2sqr<SL>(x + d * ix[j], y + d * iy[j], d);
982
+ } else {
983
+ dis[j] = INFINITY;
984
+ }
1163
985
  }
1164
- }
986
+ });
1165
987
  }
1166
988
 
1167
989
  void pairwise_indexed_inner_product(
@@ -1172,15 +994,17 @@ void pairwise_indexed_inner_product(
1172
994
  const float* y,
1173
995
  const int64_t* iy,
1174
996
  float* dis) {
997
+ with_simd_level([&]<SIMDLevel SL>() {
1175
998
  #pragma omp parallel for if (n > 1)
1176
- for (int64_t j = 0; j < n; j++) {
1177
- if (ix[j] >= 0 && iy[j] >= 0) {
1178
- dis[j] = fvec_inner_product_dispatch(
1179
- x + d * ix[j], y + d * iy[j], d);
1180
- } else {
1181
- dis[j] = -INFINITY;
999
+ for (int64_t j = 0; j < static_cast<int64_t>(n); j++) {
1000
+ if (ix[j] >= 0 && iy[j] >= 0) {
1001
+ dis[j] =
1002
+ fvec_inner_product<SL>(x + d * ix[j], y + d * iy[j], d);
1003
+ } else {
1004
+ dis[j] = -INFINITY;
1005
+ }
1182
1006
  }
1183
- }
1007
+ });
1184
1008
  }
1185
1009
 
1186
1010
  /* Find the nearest neighbors for nx queries in a set of ny vectors
@@ -1201,27 +1025,29 @@ void knn_inner_products_by_idx(
1201
1025
  ld_ids = ny;
1202
1026
  }
1203
1027
 
1028
+ with_simd_level([&]<SIMDLevel SL>() {
1204
1029
  #pragma omp parallel for if (nx > 100)
1205
- for (int64_t i = 0; i < nx; i++) {
1206
- const float* x_ = x + i * d;
1207
- const int64_t* idsi = ids + i * ld_ids;
1208
- size_t j;
1209
- float* __restrict simi = res_vals + i * k;
1210
- int64_t* __restrict idxi = res_ids + i * k;
1211
- minheap_heapify(k, simi, idxi);
1212
-
1213
- for (j = 0; j < nsubset; j++) {
1214
- if (idsi[j] < 0 || idsi[j] >= ny) {
1215
- break;
1216
- }
1217
- float ip = fvec_inner_product_dispatch(x_, y + d * idsi[j], d);
1030
+ for (int64_t i = 0; i < static_cast<int64_t>(nx); i++) {
1031
+ const float* x_ = x + i * d;
1032
+ const int64_t* idsi = ids + i * ld_ids;
1033
+ size_t j;
1034
+ float* __restrict simi = res_vals + i * k;
1035
+ int64_t* __restrict idxi = res_ids + i * k;
1036
+ minheap_heapify(k, simi, idxi);
1037
+
1038
+ for (j = 0; j < nsubset; j++) {
1039
+ if (idsi[j] < 0 || static_cast<size_t>(idsi[j]) >= ny) {
1040
+ break;
1041
+ }
1042
+ float ip = fvec_inner_product<SL>(x_, y + d * idsi[j], d);
1218
1043
 
1219
- if (ip > simi[0]) {
1220
- minheap_replace_top(k, simi, idxi, ip, idsi[j]);
1044
+ if (ip > simi[0]) {
1045
+ minheap_replace_top(k, simi, idxi, ip, idsi[j]);
1046
+ }
1221
1047
  }
1048
+ minheap_reorder(k, simi, idxi);
1222
1049
  }
1223
- minheap_reorder(k, simi, idxi);
1224
- }
1050
+ });
1225
1051
  }
1226
1052
 
1227
1053
  void knn_L2sqr_by_idx(
@@ -1239,25 +1065,27 @@ void knn_L2sqr_by_idx(
1239
1065
  if (ld_ids < 0) {
1240
1066
  ld_ids = ny;
1241
1067
  }
1068
+ with_simd_level([&]<SIMDLevel SL>() {
1242
1069
  #pragma omp parallel for if (nx > 100)
1243
- for (int64_t i = 0; i < nx; i++) {
1244
- const float* x_ = x + i * d;
1245
- const int64_t* __restrict idsi = ids + i * ld_ids;
1246
- float* __restrict simi = res_vals + i * k;
1247
- int64_t* __restrict idxi = res_ids + i * k;
1248
- maxheap_heapify(k, simi, idxi);
1249
- for (size_t j = 0; j < nsubset; j++) {
1250
- if (idsi[j] < 0 || idsi[j] >= ny) {
1251
- break;
1252
- }
1253
- float disij = fvec_L2sqr_dispatch(x_, y + d * idsi[j], d);
1070
+ for (int64_t i = 0; i < static_cast<int64_t>(nx); i++) {
1071
+ const float* x_ = x + i * d;
1072
+ const int64_t* __restrict idsi = ids + i * ld_ids;
1073
+ float* __restrict simi = res_vals + i * k;
1074
+ int64_t* __restrict idxi = res_ids + i * k;
1075
+ maxheap_heapify(k, simi, idxi);
1076
+ for (size_t j = 0; j < nsubset; j++) {
1077
+ if (idsi[j] < 0 || static_cast<size_t>(idsi[j]) >= ny) {
1078
+ break;
1079
+ }
1080
+ float disij = fvec_L2sqr<SL>(x_, y + d * idsi[j], d);
1254
1081
 
1255
- if (disij < simi[0]) {
1256
- maxheap_replace_top(k, simi, idxi, disij, idsi[j]);
1082
+ if (disij < simi[0]) {
1083
+ maxheap_replace_top(k, simi, idxi, disij, idsi[j]);
1084
+ }
1257
1085
  }
1086
+ maxheap_reorder(k, simi, idxi);
1258
1087
  }
1259
- maxheap_reorder(k, simi, idxi);
1260
- }
1088
+ });
1261
1089
  }
1262
1090
 
1263
1091
  void pairwise_L2sqr(
@@ -1286,25 +1114,27 @@ void pairwise_L2sqr(
1286
1114
  // store in beginning of distance matrix to avoid malloc
1287
1115
  float* b_norms = dis;
1288
1116
 
1117
+ with_simd_level([&]<SIMDLevel SL>() {
1289
1118
  #pragma omp parallel for if (nb > 1)
1290
- for (int64_t i = 0; i < nb; i++) {
1291
- b_norms[i] = fvec_norm_L2sqr_dispatch(xb + i * ldb, d);
1292
- }
1119
+ for (int64_t i = 0; i < nb; i++) {
1120
+ b_norms[i] = fvec_norm_L2sqr<SL>(xb + i * ldb, d);
1121
+ }
1293
1122
 
1294
1123
  #pragma omp parallel for
1295
- for (int64_t i = 1; i < nq; i++) {
1296
- float q_norm = fvec_norm_L2sqr_dispatch(xq + i * ldq, d);
1297
- for (int64_t j = 0; j < nb; j++) {
1298
- dis[i * ldd + j] = q_norm + b_norms[j];
1124
+ for (int64_t i = 1; i < nq; i++) {
1125
+ float q_norm = fvec_norm_L2sqr<SL>(xq + i * ldq, d);
1126
+ for (int64_t j = 0; j < nb; j++) {
1127
+ dis[i * ldd + j] = q_norm + b_norms[j];
1128
+ }
1299
1129
  }
1300
- }
1301
1130
 
1302
- {
1303
- float q_norm = fvec_norm_L2sqr_dispatch(xq, d);
1304
- for (int64_t j = 0; j < nb; j++) {
1305
- dis[j] += q_norm;
1131
+ {
1132
+ float q_norm = fvec_norm_L2sqr<SL>(xq, d);
1133
+ for (int64_t j = 0; j < nb; j++) {
1134
+ dis[j] += q_norm;
1135
+ }
1306
1136
  }
1307
- }
1137
+ });
1308
1138
 
1309
1139
  {
1310
1140
  FINTEGER nbi = nb, nqi = nq, di = d, ldqi = ldq, ldbi = ldb, lddi = ldd;
@@ -1333,7 +1163,7 @@ void inner_product_to_L2sqr(
1333
1163
  size_t n1,
1334
1164
  size_t n2) {
1335
1165
  #pragma omp parallel for
1336
- for (int64_t j = 0; j < n1; j++) {
1166
+ for (int64_t j = 0; j < static_cast<int64_t>(n1); j++) {
1337
1167
  float* disj = dis + j * n2;
1338
1168
  for (size_t i = 0; i < n2; i++) {
1339
1169
  disj[i] = nr1[j] + nr2[i] - 2 * disj[i];