faiss 0.3.1 → 0.3.3

Files changed (293)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +8 -0
  3. data/LICENSE.txt +1 -1
  4. data/lib/faiss/version.rb +1 -1
  5. data/vendor/faiss/faiss/AutoTune.cpp +2 -2
  6. data/vendor/faiss/faiss/AutoTune.h +3 -3
  7. data/vendor/faiss/faiss/Clustering.cpp +37 -6
  8. data/vendor/faiss/faiss/Clustering.h +12 -3
  9. data/vendor/faiss/faiss/IVFlib.cpp +6 -3
  10. data/vendor/faiss/faiss/IVFlib.h +2 -2
  11. data/vendor/faiss/faiss/Index.cpp +6 -2
  12. data/vendor/faiss/faiss/Index.h +30 -8
  13. data/vendor/faiss/faiss/Index2Layer.cpp +2 -2
  14. data/vendor/faiss/faiss/Index2Layer.h +2 -2
  15. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +7 -7
  16. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +2 -2
  17. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +14 -16
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +2 -2
  19. data/vendor/faiss/faiss/IndexBinary.cpp +13 -2
  20. data/vendor/faiss/faiss/IndexBinary.h +8 -2
  21. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +2 -3
  22. data/vendor/faiss/faiss/IndexBinaryFlat.h +2 -2
  23. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -2
  24. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -2
  25. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +2 -7
  26. data/vendor/faiss/faiss/IndexBinaryHNSW.h +3 -3
  27. data/vendor/faiss/faiss/IndexBinaryHash.cpp +2 -3
  28. data/vendor/faiss/faiss/IndexBinaryHash.h +2 -2
  29. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +3 -3
  30. data/vendor/faiss/faiss/IndexBinaryIVF.h +2 -2
  31. data/vendor/faiss/faiss/IndexFastScan.cpp +32 -18
  32. data/vendor/faiss/faiss/IndexFastScan.h +11 -2
  33. data/vendor/faiss/faiss/IndexFlat.cpp +13 -10
  34. data/vendor/faiss/faiss/IndexFlat.h +2 -2
  35. data/vendor/faiss/faiss/IndexFlatCodes.cpp +170 -7
  36. data/vendor/faiss/faiss/IndexFlatCodes.h +25 -5
  37. data/vendor/faiss/faiss/IndexHNSW.cpp +156 -96
  38. data/vendor/faiss/faiss/IndexHNSW.h +54 -5
  39. data/vendor/faiss/faiss/IndexIDMap.cpp +19 -3
  40. data/vendor/faiss/faiss/IndexIDMap.h +5 -2
  41. data/vendor/faiss/faiss/IndexIVF.cpp +5 -6
  42. data/vendor/faiss/faiss/IndexIVF.h +13 -4
  43. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +21 -7
  44. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +5 -2
  45. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +3 -14
  46. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +2 -4
  47. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +201 -91
  48. data/vendor/faiss/faiss/IndexIVFFastScan.h +33 -9
  49. data/vendor/faiss/faiss/IndexIVFFlat.cpp +2 -2
  50. data/vendor/faiss/faiss/IndexIVFFlat.h +2 -2
  51. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +2 -2
  52. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +2 -2
  53. data/vendor/faiss/faiss/IndexIVFPQ.cpp +3 -6
  54. data/vendor/faiss/faiss/IndexIVFPQ.h +2 -2
  55. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +7 -14
  56. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +2 -4
  57. data/vendor/faiss/faiss/IndexIVFPQR.cpp +2 -2
  58. data/vendor/faiss/faiss/IndexIVFPQR.h +2 -2
  59. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +2 -3
  60. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +2 -2
  61. data/vendor/faiss/faiss/IndexLSH.cpp +2 -3
  62. data/vendor/faiss/faiss/IndexLSH.h +2 -2
  63. data/vendor/faiss/faiss/IndexLattice.cpp +3 -21
  64. data/vendor/faiss/faiss/IndexLattice.h +5 -24
  65. data/vendor/faiss/faiss/IndexNNDescent.cpp +2 -31
  66. data/vendor/faiss/faiss/IndexNNDescent.h +3 -3
  67. data/vendor/faiss/faiss/IndexNSG.cpp +2 -5
  68. data/vendor/faiss/faiss/IndexNSG.h +3 -3
  69. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +56 -0
  70. data/vendor/faiss/faiss/IndexNeuralNetCodec.h +49 -0
  71. data/vendor/faiss/faiss/IndexPQ.cpp +26 -26
  72. data/vendor/faiss/faiss/IndexPQ.h +2 -2
  73. data/vendor/faiss/faiss/IndexPQFastScan.cpp +2 -5
  74. data/vendor/faiss/faiss/IndexPQFastScan.h +2 -11
  75. data/vendor/faiss/faiss/IndexPreTransform.cpp +2 -2
  76. data/vendor/faiss/faiss/IndexPreTransform.h +3 -3
  77. data/vendor/faiss/faiss/IndexRefine.cpp +46 -9
  78. data/vendor/faiss/faiss/IndexRefine.h +9 -2
  79. data/vendor/faiss/faiss/IndexReplicas.cpp +2 -2
  80. data/vendor/faiss/faiss/IndexReplicas.h +2 -2
  81. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +2 -2
  82. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +2 -2
  83. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +5 -4
  84. data/vendor/faiss/faiss/IndexScalarQuantizer.h +2 -2
  85. data/vendor/faiss/faiss/IndexShards.cpp +2 -2
  86. data/vendor/faiss/faiss/IndexShards.h +2 -2
  87. data/vendor/faiss/faiss/IndexShardsIVF.cpp +2 -2
  88. data/vendor/faiss/faiss/IndexShardsIVF.h +2 -2
  89. data/vendor/faiss/faiss/MatrixStats.cpp +2 -2
  90. data/vendor/faiss/faiss/MatrixStats.h +2 -2
  91. data/vendor/faiss/faiss/MetaIndexes.cpp +2 -3
  92. data/vendor/faiss/faiss/MetaIndexes.h +2 -2
  93. data/vendor/faiss/faiss/MetricType.h +9 -4
  94. data/vendor/faiss/faiss/VectorTransform.cpp +2 -2
  95. data/vendor/faiss/faiss/VectorTransform.h +2 -2
  96. data/vendor/faiss/faiss/clone_index.cpp +2 -2
  97. data/vendor/faiss/faiss/clone_index.h +2 -2
  98. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +2 -2
  99. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +2 -2
  100. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +97 -19
  101. data/vendor/faiss/faiss/cppcontrib/factory_tools.cpp +192 -0
  102. data/vendor/faiss/faiss/cppcontrib/factory_tools.h +29 -0
  103. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2 -2
  104. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +85 -32
  105. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2 -2
  106. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +2 -2
  107. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +2 -2
  108. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +2 -2
  109. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +2 -2
  110. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +2 -2
  111. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +2 -5
  112. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +2 -2
  113. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +45 -13
  114. data/vendor/faiss/faiss/gpu/GpuCloner.h +2 -2
  115. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +12 -6
  116. data/vendor/faiss/faiss/gpu/GpuDistance.h +11 -7
  117. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +3 -3
  118. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +2 -2
  119. data/vendor/faiss/faiss/gpu/GpuIndex.h +10 -15
  120. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -2
  121. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +285 -0
  122. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +2 -2
  123. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +8 -2
  124. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +4 -2
  125. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +3 -3
  126. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +2 -2
  127. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +2 -2
  128. data/vendor/faiss/faiss/gpu/GpuResources.cpp +7 -2
  129. data/vendor/faiss/faiss/gpu/GpuResources.h +11 -4
  130. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +66 -11
  131. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +15 -5
  132. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +2 -2
  133. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +28 -23
  134. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +2 -2
  135. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +2 -2
  136. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +2 -2
  137. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +2 -2
  138. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +2 -2
  139. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +8 -2
  140. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +2 -3
  141. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +2 -2
  142. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +10 -7
  143. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +2 -2
  144. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +54 -54
  145. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +144 -77
  146. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +51 -51
  147. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +2 -2
  148. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +3 -3
  149. data/vendor/faiss/faiss/gpu/test/TestGpuResidualQuantizer.cpp +70 -0
  150. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +74 -4
  151. data/vendor/faiss/faiss/gpu/test/TestUtils.h +2 -2
  152. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +3 -3
  153. data/vendor/faiss/faiss/gpu/utils/{RaftUtils.h → CuvsUtils.h} +12 -11
  154. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +8 -2
  155. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +2 -2
  156. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +2 -2
  157. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +2 -2
  158. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +6 -3
  159. data/vendor/faiss/faiss/gpu/utils/Timer.h +3 -3
  160. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +79 -11
  161. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +17 -5
  162. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +27 -2
  163. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +11 -3
  164. data/vendor/faiss/faiss/impl/CodePacker.cpp +2 -2
  165. data/vendor/faiss/faiss/impl/CodePacker.h +2 -2
  166. data/vendor/faiss/faiss/impl/DistanceComputer.h +48 -2
  167. data/vendor/faiss/faiss/impl/FaissAssert.h +6 -4
  168. data/vendor/faiss/faiss/impl/FaissException.cpp +2 -2
  169. data/vendor/faiss/faiss/impl/FaissException.h +2 -3
  170. data/vendor/faiss/faiss/impl/HNSW.cpp +378 -205
  171. data/vendor/faiss/faiss/impl/HNSW.h +55 -24
  172. data/vendor/faiss/faiss/impl/IDSelector.cpp +2 -2
  173. data/vendor/faiss/faiss/impl/IDSelector.h +2 -2
  174. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +10 -10
  175. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +2 -2
  176. data/vendor/faiss/faiss/impl/LookupTableScaler.h +36 -2
  177. data/vendor/faiss/faiss/impl/NNDescent.cpp +15 -10
  178. data/vendor/faiss/faiss/impl/NNDescent.h +2 -2
  179. data/vendor/faiss/faiss/impl/NSG.cpp +26 -49
  180. data/vendor/faiss/faiss/impl/NSG.h +20 -8
  181. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +2 -2
  182. data/vendor/faiss/faiss/impl/PolysemousTraining.h +2 -2
  183. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +2 -4
  184. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +2 -2
  185. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +2 -2
  186. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +3 -2
  187. data/vendor/faiss/faiss/impl/ProductQuantizer.h +7 -3
  188. data/vendor/faiss/faiss/impl/Quantizer.h +2 -2
  189. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +2 -36
  190. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +3 -13
  191. data/vendor/faiss/faiss/impl/ResultHandler.h +153 -34
  192. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +721 -104
  193. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +5 -2
  194. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +2 -2
  195. data/vendor/faiss/faiss/impl/ThreadedIndex.h +2 -2
  196. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +7 -2
  197. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx512.h +248 -0
  198. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +2 -2
  199. data/vendor/faiss/faiss/impl/code_distance/code_distance-sve.h +440 -0
  200. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +55 -2
  201. data/vendor/faiss/faiss/impl/index_read.cpp +31 -20
  202. data/vendor/faiss/faiss/impl/index_read_utils.h +37 -0
  203. data/vendor/faiss/faiss/impl/index_write.cpp +30 -16
  204. data/vendor/faiss/faiss/impl/io.cpp +15 -7
  205. data/vendor/faiss/faiss/impl/io.h +6 -6
  206. data/vendor/faiss/faiss/impl/io_macros.h +8 -9
  207. data/vendor/faiss/faiss/impl/kmeans1d.cpp +2 -3
  208. data/vendor/faiss/faiss/impl/kmeans1d.h +2 -2
  209. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +2 -3
  210. data/vendor/faiss/faiss/impl/lattice_Zn.h +2 -2
  211. data/vendor/faiss/faiss/impl/platform_macros.h +34 -2
  212. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +13 -2
  213. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +20 -2
  214. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +3 -3
  215. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +450 -3
  216. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +8 -8
  217. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +3 -3
  218. data/vendor/faiss/faiss/impl/simd_result_handlers.h +151 -67
  219. data/vendor/faiss/faiss/index_factory.cpp +51 -34
  220. data/vendor/faiss/faiss/index_factory.h +2 -2
  221. data/vendor/faiss/faiss/index_io.h +14 -7
  222. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +30 -10
  223. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +5 -2
  224. data/vendor/faiss/faiss/invlists/DirectMap.cpp +11 -3
  225. data/vendor/faiss/faiss/invlists/DirectMap.h +2 -2
  226. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +57 -19
  227. data/vendor/faiss/faiss/invlists/InvertedLists.h +20 -11
  228. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +2 -2
  229. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +2 -2
  230. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +23 -9
  231. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +4 -3
  232. data/vendor/faiss/faiss/python/python_callbacks.cpp +5 -5
  233. data/vendor/faiss/faiss/python/python_callbacks.h +2 -2
  234. data/vendor/faiss/faiss/utils/AlignedTable.h +5 -3
  235. data/vendor/faiss/faiss/utils/Heap.cpp +2 -2
  236. data/vendor/faiss/faiss/utils/Heap.h +107 -2
  237. data/vendor/faiss/faiss/utils/NeuralNet.cpp +346 -0
  238. data/vendor/faiss/faiss/utils/NeuralNet.h +147 -0
  239. data/vendor/faiss/faiss/utils/WorkerThread.cpp +2 -2
  240. data/vendor/faiss/faiss/utils/WorkerThread.h +2 -2
  241. data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +2 -2
  242. data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +2 -2
  243. data/vendor/faiss/faiss/utils/approx_topk/generic.h +2 -2
  244. data/vendor/faiss/faiss/utils/approx_topk/mode.h +2 -2
  245. data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +2 -2
  246. data/vendor/faiss/faiss/utils/bf16.h +36 -0
  247. data/vendor/faiss/faiss/utils/distances.cpp +249 -90
  248. data/vendor/faiss/faiss/utils/distances.h +8 -8
  249. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +2 -2
  250. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
  251. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
  252. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +2 -2
  253. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +2 -2
  254. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +2 -2
  255. data/vendor/faiss/faiss/utils/distances_simd.cpp +1543 -56
  256. data/vendor/faiss/faiss/utils/extra_distances-inl.h +72 -2
  257. data/vendor/faiss/faiss/utils/extra_distances.cpp +87 -140
  258. data/vendor/faiss/faiss/utils/extra_distances.h +5 -4
  259. data/vendor/faiss/faiss/utils/fp16-arm.h +2 -2
  260. data/vendor/faiss/faiss/utils/fp16-fp16c.h +2 -2
  261. data/vendor/faiss/faiss/utils/fp16-inl.h +2 -2
  262. data/vendor/faiss/faiss/utils/fp16.h +2 -2
  263. data/vendor/faiss/faiss/utils/hamming-inl.h +2 -2
  264. data/vendor/faiss/faiss/utils/hamming.cpp +3 -4
  265. data/vendor/faiss/faiss/utils/hamming.h +2 -2
  266. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +2 -2
  267. data/vendor/faiss/faiss/utils/hamming_distance/avx512-inl.h +490 -0
  268. data/vendor/faiss/faiss/utils/hamming_distance/common.h +2 -2
  269. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +6 -3
  270. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +7 -3
  271. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +5 -5
  272. data/vendor/faiss/faiss/utils/ordered_key_value.h +2 -2
  273. data/vendor/faiss/faiss/utils/partitioning.cpp +2 -2
  274. data/vendor/faiss/faiss/utils/partitioning.h +2 -2
  275. data/vendor/faiss/faiss/utils/prefetch.h +2 -2
  276. data/vendor/faiss/faiss/utils/quantize_lut.cpp +2 -2
  277. data/vendor/faiss/faiss/utils/quantize_lut.h +2 -2
  278. data/vendor/faiss/faiss/utils/random.cpp +45 -2
  279. data/vendor/faiss/faiss/utils/random.h +27 -2
  280. data/vendor/faiss/faiss/utils/simdlib.h +12 -3
  281. data/vendor/faiss/faiss/utils/simdlib_avx2.h +2 -2
  282. data/vendor/faiss/faiss/utils/simdlib_avx512.h +296 -0
  283. data/vendor/faiss/faiss/utils/simdlib_emulated.h +2 -2
  284. data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -4
  285. data/vendor/faiss/faiss/utils/simdlib_ppc64.h +1084 -0
  286. data/vendor/faiss/faiss/utils/sorting.cpp +2 -2
  287. data/vendor/faiss/faiss/utils/sorting.h +2 -2
  288. data/vendor/faiss/faiss/utils/transpose/transpose-avx2-inl.h +2 -2
  289. data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +176 -0
  290. data/vendor/faiss/faiss/utils/utils.cpp +17 -10
  291. data/vendor/faiss/faiss/utils/utils.h +7 -3
  292. metadata +22 -11
  293. data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +0 -102
data/vendor/faiss/faiss/utils/distances_simd.cpp
@@ -1,5 +1,5 @@
- /**
-  * Copyright (c) Facebook, Inc. and its affiliates.
+ /*
+  * Copyright (c) Meta Platforms, Inc. and affiliates.
   *
   * This source code is licensed under the MIT license found in the
   * LICENSE file in the root directory of this source tree.
@@ -23,10 +23,16 @@
  #include <immintrin.h>
  #endif

- #ifdef __AVX2__
+ #if defined(__AVX512F__)
+ #include <faiss/utils/transpose/transpose-avx512-inl.h>
+ #elif defined(__AVX2__)
  #include <faiss/utils/transpose/transpose-avx2-inl.h>
  #endif

+ #ifdef __ARM_FEATURE_SVE
+ #include <arm_sve.h>
+ #endif
+
  #ifdef __aarch64__
  #include <arm_neon.h>
  #endif
@@ -346,6 +352,14 @@ inline float horizontal_sum(const __m256 v) {
  }
  #endif

+ #ifdef __AVX512F__
+ /// helper function for AVX512
+ inline float horizontal_sum(const __m512 v) {
+     // performs better than adding the high and low parts
+     return _mm512_reduce_add_ps(v);
+ }
+ #endif
+
  /// Function that does a component-wise operation between x and y
  /// to compute L2 distances. ElementOp can then be used in the fvec_op_ny
  /// functions below
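The new AVX-512 overload above leans on `_mm512_reduce_add_ps`, which collapses all 16 lanes of a `__m512` in a single intrinsic, while the AVX2 overload has to extract and add the two 128-bit halves by hand. As a point of reference, a minimal scalar sketch of the same reduction (my illustration, not part of the diff; it can differ from the SIMD result only in floating-point rounding order):

    // scalar equivalent of horizontal_sum over the 16 floats of a __m512
    float horizontal_sum_scalar(const float v[16]) {
        float s = 0.0f;
        for (int i = 0; i < 16; i++) {
            s += v[i]; // same value as _mm512_reduce_add_ps, up to reassociation
        }
        return s;
    }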
@@ -366,6 +380,13 @@ struct ElementOpL2 {
          return _mm256_mul_ps(tmp, tmp);
      }
  #endif
+
+ #ifdef __AVX512F__
+     static __m512 op(__m512 x, __m512 y) {
+         __m512 tmp = _mm512_sub_ps(x, y);
+         return _mm512_mul_ps(tmp, tmp);
+     }
+ #endif
  };

  /// Function that does a component-wise operation between x and y
@@ -384,6 +405,12 @@ struct ElementOpIP {
          return _mm256_mul_ps(x, y);
      }
  #endif
+
+ #ifdef __AVX512F__
+     static __m512 op(__m512 x, __m512 y) {
+         return _mm512_mul_ps(x, y);
+     }
+ #endif
  };

  template <class ElementOp>
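With this change, `ElementOpL2` and `ElementOpIP` carry one `op` overload per register width (`__m128`, `__m256`, and now `__m512`), so the same templated kernels compile against whichever ISA is available. A scalar analogue of the contract the kernels rely on (my illustration; the names below are hypothetical and not from the file):

    // ElementOpL2 computes (x - y)^2 per component, ElementOpIP computes x * y;
    // a kernel applies the op lane-wise, then horizontally sums per vector.
    struct ElementOpL2Scalar {
        static float op(float x, float y) {
            float tmp = x - y;
            return tmp * tmp;
        }
    };

    template <class Op>
    float fvec_op_scalar(const float* x, const float* y, size_t d) {
        float accu = 0;
        for (size_t j = 0; j < d; j++) {
            accu += Op::op(x[j], y[j]);
        }
        return accu; // squared L2 distance when Op is ElementOpL2Scalar
    }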
@@ -426,7 +453,130 @@ void fvec_op_ny_D2(float* dis, const float* x, const float* y, size_t ny) {
      }
  }

- #ifdef __AVX2__
+ #if defined(__AVX512F__)
+
+ template <>
+ void fvec_op_ny_D2<ElementOpIP>(
+         float* dis,
+         const float* x,
+         const float* y,
+         size_t ny) {
+     const size_t ny16 = ny / 16;
+     size_t i = 0;
+
+     if (ny16 > 0) {
+         // process 16 D2-vectors per loop.
+         _mm_prefetch((const char*)y, _MM_HINT_T0);
+         _mm_prefetch((const char*)(y + 32), _MM_HINT_T0);
+
+         const __m512 m0 = _mm512_set1_ps(x[0]);
+         const __m512 m1 = _mm512_set1_ps(x[1]);
+
+         for (i = 0; i < ny16 * 16; i += 16) {
+             _mm_prefetch((const char*)(y + 64), _MM_HINT_T0);
+
+             // load 16x2 matrix and transpose it in registers.
+             // the typical bottleneck is memory access, so
+             // let's trade instructions for the bandwidth.
+
+             __m512 v0;
+             __m512 v1;
+
+             transpose_16x2(
+                     _mm512_loadu_ps(y + 0 * 16),
+                     _mm512_loadu_ps(y + 1 * 16),
+                     v0,
+                     v1);
+
+             // compute distances (dot product)
+             __m512 distances = _mm512_mul_ps(m0, v0);
+             distances = _mm512_fmadd_ps(m1, v1, distances);
+
+             // store
+             _mm512_storeu_ps(dis + i, distances);
+
+             y += 32; // move to the next set of 16x2 elements
+         }
+     }
+
+     if (i < ny) {
+         // process leftovers
+         float x0 = x[0];
+         float x1 = x[1];
+
+         for (; i < ny; i++) {
+             float distance = x0 * y[0] + x1 * y[1];
+             y += 2;
+             dis[i] = distance;
+         }
+     }
+ }
+
+ template <>
+ void fvec_op_ny_D2<ElementOpL2>(
+         float* dis,
+         const float* x,
+         const float* y,
+         size_t ny) {
+     const size_t ny16 = ny / 16;
+     size_t i = 0;
+
+     if (ny16 > 0) {
+         // process 16 D2-vectors per loop.
+         _mm_prefetch((const char*)y, _MM_HINT_T0);
+         _mm_prefetch((const char*)(y + 32), _MM_HINT_T0);
+
+         const __m512 m0 = _mm512_set1_ps(x[0]);
+         const __m512 m1 = _mm512_set1_ps(x[1]);
+
+         for (i = 0; i < ny16 * 16; i += 16) {
+             _mm_prefetch((const char*)(y + 64), _MM_HINT_T0);
+
+             // load 16x2 matrix and transpose it in registers.
+             // the typical bottleneck is memory access, so
+             // let's trade instructions for the bandwidth.
+
+             __m512 v0;
+             __m512 v1;
+
+             transpose_16x2(
+                     _mm512_loadu_ps(y + 0 * 16),
+                     _mm512_loadu_ps(y + 1 * 16),
+                     v0,
+                     v1);
+
+             // compute differences
+             const __m512 d0 = _mm512_sub_ps(m0, v0);
+             const __m512 d1 = _mm512_sub_ps(m1, v1);
+
+             // compute squares of differences
+             __m512 distances = _mm512_mul_ps(d0, d0);
+             distances = _mm512_fmadd_ps(d1, d1, distances);
+
+             // store
+             _mm512_storeu_ps(dis + i, distances);
+
+             y += 32; // move to the next set of 16x2 elements
+         }
+     }
+
+     if (i < ny) {
+         // process leftovers
+         float x0 = x[0];
+         float x1 = x[1];
+
+         for (; i < ny; i++) {
+             float sub0 = x0 - y[0];
+             float sub1 = x1 - y[1];
+             float distance = sub0 * sub0 + sub1 * sub1;
+
+             y += 2;
+             dis[i] = distance;
+         }
+     }
+ }
+
+ #elif defined(__AVX2__)

  template <>
  void fvec_op_ny_D2<ElementOpIP>(
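The D2 kernels above read 32 consecutive floats (16 interleaved 2-d vectors) and deinterleave them in registers, so one FMA handles component 0 of all 16 vectors and another handles component 1. A scalar sketch of the layout contract implied by how `transpose_16x2` is called (an inference from the usage above, not the library's definition):

    // in:  y0[0], y0[1], y1[0], y1[1], ..., y15[0], y15[1]   (array-of-structs)
    // out: v0 holds the 16 first components, v1 the 16 second (struct-of-arrays)
    void transpose_16x2_scalar(const float in[32], float v0[16], float v1[16]) {
        for (int j = 0; j < 16; j++) {
            v0[j] = in[2 * j + 0];
            v1[j] = in[2 * j + 1];
        }
    }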
@@ -562,7 +712,137 @@ void fvec_op_ny_D4(float* dis, const float* x, const float* y, size_t ny) {
      }
  }

- #ifdef __AVX2__
+ #if defined(__AVX512F__)
+
+ template <>
+ void fvec_op_ny_D4<ElementOpIP>(
+         float* dis,
+         const float* x,
+         const float* y,
+         size_t ny) {
+     const size_t ny16 = ny / 16;
+     size_t i = 0;
+
+     if (ny16 > 0) {
+         // process 16 D4-vectors per loop.
+         const __m512 m0 = _mm512_set1_ps(x[0]);
+         const __m512 m1 = _mm512_set1_ps(x[1]);
+         const __m512 m2 = _mm512_set1_ps(x[2]);
+         const __m512 m3 = _mm512_set1_ps(x[3]);
+
+         for (i = 0; i < ny16 * 16; i += 16) {
+             // load 16x4 matrix and transpose it in registers.
+             // the typical bottleneck is memory access, so
+             // let's trade instructions for the bandwidth.
+
+             __m512 v0;
+             __m512 v1;
+             __m512 v2;
+             __m512 v3;
+
+             transpose_16x4(
+                     _mm512_loadu_ps(y + 0 * 16),
+                     _mm512_loadu_ps(y + 1 * 16),
+                     _mm512_loadu_ps(y + 2 * 16),
+                     _mm512_loadu_ps(y + 3 * 16),
+                     v0,
+                     v1,
+                     v2,
+                     v3);
+
+             // compute distances
+             __m512 distances = _mm512_mul_ps(m0, v0);
+             distances = _mm512_fmadd_ps(m1, v1, distances);
+             distances = _mm512_fmadd_ps(m2, v2, distances);
+             distances = _mm512_fmadd_ps(m3, v3, distances);
+
+             // store
+             _mm512_storeu_ps(dis + i, distances);
+
+             y += 64; // move to the next set of 16x4 elements
+         }
+     }
+
+     if (i < ny) {
+         // process leftovers
+         __m128 x0 = _mm_loadu_ps(x);
+
+         for (; i < ny; i++) {
+             __m128 accu = ElementOpIP::op(x0, _mm_loadu_ps(y));
+             y += 4;
+             dis[i] = horizontal_sum(accu);
+         }
+     }
+ }
+
+ template <>
+ void fvec_op_ny_D4<ElementOpL2>(
+         float* dis,
+         const float* x,
+         const float* y,
+         size_t ny) {
+     const size_t ny16 = ny / 16;
+     size_t i = 0;
+
+     if (ny16 > 0) {
+         // process 16 D4-vectors per loop.
+         const __m512 m0 = _mm512_set1_ps(x[0]);
+         const __m512 m1 = _mm512_set1_ps(x[1]);
+         const __m512 m2 = _mm512_set1_ps(x[2]);
+         const __m512 m3 = _mm512_set1_ps(x[3]);
+
+         for (i = 0; i < ny16 * 16; i += 16) {
+             // load 16x4 matrix and transpose it in registers.
+             // the typical bottleneck is memory access, so
+             // let's trade instructions for the bandwidth.
+
+             __m512 v0;
+             __m512 v1;
+             __m512 v2;
+             __m512 v3;
+
+             transpose_16x4(
+                     _mm512_loadu_ps(y + 0 * 16),
+                     _mm512_loadu_ps(y + 1 * 16),
+                     _mm512_loadu_ps(y + 2 * 16),
+                     _mm512_loadu_ps(y + 3 * 16),
+                     v0,
+                     v1,
+                     v2,
+                     v3);
+
+             // compute differences
+             const __m512 d0 = _mm512_sub_ps(m0, v0);
+             const __m512 d1 = _mm512_sub_ps(m1, v1);
+             const __m512 d2 = _mm512_sub_ps(m2, v2);
+             const __m512 d3 = _mm512_sub_ps(m3, v3);
+
+             // compute squares of differences
+             __m512 distances = _mm512_mul_ps(d0, d0);
+             distances = _mm512_fmadd_ps(d1, d1, distances);
+             distances = _mm512_fmadd_ps(d2, d2, distances);
+             distances = _mm512_fmadd_ps(d3, d3, distances);
+
+             // store
+             _mm512_storeu_ps(dis + i, distances);
+
+             y += 64; // move to the next set of 16x4 elements
+         }
+     }
+
+     if (i < ny) {
+         // process leftovers
+         __m128 x0 = _mm_loadu_ps(x);
+
+         for (; i < ny; i++) {
+             __m128 accu = ElementOpL2::op(x0, _mm_loadu_ps(y));
+             y += 4;
+             dis[i] = horizontal_sum(accu);
+         }
+     }
+ }
+
+ #elif defined(__AVX2__)

  template <>
  void fvec_op_ny_D4<ElementOpIP>(
@@ -710,7 +990,181 @@ void fvec_op_ny_D8(float* dis, const float* x, const float* y, size_t ny) {
      }
  }

- #ifdef __AVX2__
+ #if defined(__AVX512F__)
+
+ template <>
+ void fvec_op_ny_D8<ElementOpIP>(
+         float* dis,
+         const float* x,
+         const float* y,
+         size_t ny) {
+     const size_t ny16 = ny / 16;
+     size_t i = 0;
+
+     if (ny16 > 0) {
+         // process 16 D16-vectors per loop.
+         const __m512 m0 = _mm512_set1_ps(x[0]);
+         const __m512 m1 = _mm512_set1_ps(x[1]);
+         const __m512 m2 = _mm512_set1_ps(x[2]);
+         const __m512 m3 = _mm512_set1_ps(x[3]);
+         const __m512 m4 = _mm512_set1_ps(x[4]);
+         const __m512 m5 = _mm512_set1_ps(x[5]);
+         const __m512 m6 = _mm512_set1_ps(x[6]);
+         const __m512 m7 = _mm512_set1_ps(x[7]);
+
+         for (i = 0; i < ny16 * 16; i += 16) {
+             // load 16x8 matrix and transpose it in registers.
+             // the typical bottleneck is memory access, so
+             // let's trade instructions for the bandwidth.
+
+             __m512 v0;
+             __m512 v1;
+             __m512 v2;
+             __m512 v3;
+             __m512 v4;
+             __m512 v5;
+             __m512 v6;
+             __m512 v7;
+
+             transpose_16x8(
+                     _mm512_loadu_ps(y + 0 * 16),
+                     _mm512_loadu_ps(y + 1 * 16),
+                     _mm512_loadu_ps(y + 2 * 16),
+                     _mm512_loadu_ps(y + 3 * 16),
+                     _mm512_loadu_ps(y + 4 * 16),
+                     _mm512_loadu_ps(y + 5 * 16),
+                     _mm512_loadu_ps(y + 6 * 16),
+                     _mm512_loadu_ps(y + 7 * 16),
+                     v0,
+                     v1,
+                     v2,
+                     v3,
+                     v4,
+                     v5,
+                     v6,
+                     v7);
+
+             // compute distances
+             __m512 distances = _mm512_mul_ps(m0, v0);
+             distances = _mm512_fmadd_ps(m1, v1, distances);
+             distances = _mm512_fmadd_ps(m2, v2, distances);
+             distances = _mm512_fmadd_ps(m3, v3, distances);
+             distances = _mm512_fmadd_ps(m4, v4, distances);
+             distances = _mm512_fmadd_ps(m5, v5, distances);
+             distances = _mm512_fmadd_ps(m6, v6, distances);
+             distances = _mm512_fmadd_ps(m7, v7, distances);
+
+             // store
+             _mm512_storeu_ps(dis + i, distances);
+
+             y += 128; // 16 floats * 8 rows
+         }
+     }
+
+     if (i < ny) {
+         // process leftovers
+         __m256 x0 = _mm256_loadu_ps(x);
+
+         for (; i < ny; i++) {
+             __m256 accu = ElementOpIP::op(x0, _mm256_loadu_ps(y));
+             y += 8;
+             dis[i] = horizontal_sum(accu);
+         }
+     }
+ }
+
+ template <>
+ void fvec_op_ny_D8<ElementOpL2>(
+         float* dis,
+         const float* x,
+         const float* y,
+         size_t ny) {
+     const size_t ny16 = ny / 16;
+     size_t i = 0;
+
+     if (ny16 > 0) {
+         // process 16 D16-vectors per loop.
+         const __m512 m0 = _mm512_set1_ps(x[0]);
+         const __m512 m1 = _mm512_set1_ps(x[1]);
+         const __m512 m2 = _mm512_set1_ps(x[2]);
+         const __m512 m3 = _mm512_set1_ps(x[3]);
+         const __m512 m4 = _mm512_set1_ps(x[4]);
+         const __m512 m5 = _mm512_set1_ps(x[5]);
+         const __m512 m6 = _mm512_set1_ps(x[6]);
+         const __m512 m7 = _mm512_set1_ps(x[7]);
+
+         for (i = 0; i < ny16 * 16; i += 16) {
+             // load 16x8 matrix and transpose it in registers.
+             // the typical bottleneck is memory access, so
+             // let's trade instructions for the bandwidth.
+
+             __m512 v0;
+             __m512 v1;
+             __m512 v2;
+             __m512 v3;
+             __m512 v4;
+             __m512 v5;
+             __m512 v6;
+             __m512 v7;
+
+             transpose_16x8(
+                     _mm512_loadu_ps(y + 0 * 16),
+                     _mm512_loadu_ps(y + 1 * 16),
+                     _mm512_loadu_ps(y + 2 * 16),
+                     _mm512_loadu_ps(y + 3 * 16),
+                     _mm512_loadu_ps(y + 4 * 16),
+                     _mm512_loadu_ps(y + 5 * 16),
+                     _mm512_loadu_ps(y + 6 * 16),
+                     _mm512_loadu_ps(y + 7 * 16),
+                     v0,
+                     v1,
+                     v2,
+                     v3,
+                     v4,
+                     v5,
+                     v6,
+                     v7);
+
+             // compute differences
+             const __m512 d0 = _mm512_sub_ps(m0, v0);
+             const __m512 d1 = _mm512_sub_ps(m1, v1);
+             const __m512 d2 = _mm512_sub_ps(m2, v2);
+             const __m512 d3 = _mm512_sub_ps(m3, v3);
+             const __m512 d4 = _mm512_sub_ps(m4, v4);
+             const __m512 d5 = _mm512_sub_ps(m5, v5);
+             const __m512 d6 = _mm512_sub_ps(m6, v6);
+             const __m512 d7 = _mm512_sub_ps(m7, v7);
+
+             // compute squares of differences
+             __m512 distances = _mm512_mul_ps(d0, d0);
+             distances = _mm512_fmadd_ps(d1, d1, distances);
+             distances = _mm512_fmadd_ps(d2, d2, distances);
+             distances = _mm512_fmadd_ps(d3, d3, distances);
+             distances = _mm512_fmadd_ps(d4, d4, distances);
+             distances = _mm512_fmadd_ps(d5, d5, distances);
+             distances = _mm512_fmadd_ps(d6, d6, distances);
+             distances = _mm512_fmadd_ps(d7, d7, distances);
+
+             // store
+             _mm512_storeu_ps(dis + i, distances);
+
+             y += 128; // 16 floats * 8 rows
+         }
+     }
+
+     if (i < ny) {
+         // process leftovers
+         __m256 x0 = _mm256_loadu_ps(x);
+
+         for (; i < ny; i++) {
+             __m256 accu = ElementOpL2::op(x0, _mm256_loadu_ps(y));
+             y += 8;
+             dis[i] = horizontal_sum(accu);
+         }
+     }
+ }
+
+ #elif defined(__AVX2__)

  template <>
  void fvec_op_ny_D8<ElementOpIP>(
@@ -955,7 +1409,83 @@ void fvec_inner_products_ny(
  #undef DISPATCH
  }

- #ifdef __AVX2__
+ #if defined(__AVX512F__)
+
+ template <size_t DIM>
+ void fvec_L2sqr_ny_y_transposed_D(
+         float* distances,
+         const float* x,
+         const float* y,
+         const float* y_sqlen,
+         const size_t d_offset,
+         size_t ny) {
+     // current index being processed
+     size_t i = 0;
+
+     // squared length of x
+     float x_sqlen = 0;
+     for (size_t j = 0; j < DIM; j++) {
+         x_sqlen += x[j] * x[j];
+     }
+
+     // process 16 vectors per loop
+     const size_t ny16 = ny / 16;
+
+     if (ny16 > 0) {
+         // m[i] = (2 * x[i], ... 2 * x[i])
+         __m512 m[DIM];
+         for (size_t j = 0; j < DIM; j++) {
+             m[j] = _mm512_set1_ps(x[j]);
+             m[j] = _mm512_add_ps(m[j], m[j]); // m[j] = 2 * x[j]
+         }
+
+         __m512 x_sqlen_ymm = _mm512_set1_ps(x_sqlen);
+
+         for (; i < ny16 * 16; i += 16) {
+             // Load vectors for 16 dimensions
+             __m512 v[DIM];
+             for (size_t j = 0; j < DIM; j++) {
+                 v[j] = _mm512_loadu_ps(y + j * d_offset);
+             }
+
+             // Compute dot products
+             __m512 dp = _mm512_fnmadd_ps(m[0], v[0], x_sqlen_ymm);
+             for (size_t j = 1; j < DIM; j++) {
+                 dp = _mm512_fnmadd_ps(m[j], v[j], dp);
+             }
+
+             // Compute y^2 - (2 * x, y) + x^2
+             __m512 distances_v = _mm512_add_ps(_mm512_loadu_ps(y_sqlen), dp);
+
+             _mm512_storeu_ps(distances + i, distances_v);
+
+             // Scroll y and y_sqlen forward
+             y += 16;
+             y_sqlen += 16;
+         }
+     }
+
+     if (i < ny) {
+         // Process leftovers
+         for (; i < ny; i++) {
+             float dp = 0;
+             for (size_t j = 0; j < DIM; j++) {
+                 dp += x[j] * y[j * d_offset];
+             }
+
+             // Compute y^2 - 2 * (x, y), which is sufficient for looking for the
+             // lowest distance.
+             const float distance = y_sqlen[0] - 2 * dp + x_sqlen;
+             distances[i] = distance;
+
+             y += 1;
+             y_sqlen += 1;
+         }
+     }
+ }
+
+ #elif defined(__AVX2__)
+
  template <size_t DIM>
  void fvec_L2sqr_ny_y_transposed_D(
          float* distances,
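The transposed-layout kernel added above is a direct vectorization of the usual expansion of the squared Euclidean distance; with y_sqlen[i] = ||y_i||^2 precomputed, only the dot product has to be evaluated per candidate (a note for readers, matching the in-code comments):

    \|x - y_i\|^2 = \|y_i\|^2 - 2\,\langle x, y_i \rangle + \|x\|^2

Doubling x into m[j] = 2 * x[j] lets the `_mm512_fnmadd_ps` chain accumulate ||x||^2 - 2<x, y_i> in a single pass before the stored ||y_i||^2 values are added.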
@@ -1014,58 +1544,368 @@ void fvec_L2sqr_ny_y_transposed_D(
      }

      if (i < ny) {
-         // process leftovers
-         for (; i < ny; i++) {
-             float dp = 0;
-             for (size_t j = 0; j < DIM; j++) {
-                 dp += x[j] * y[j * d_offset];
-             }
+         // process leftovers
+         for (; i < ny; i++) {
+             float dp = 0;
+             for (size_t j = 0; j < DIM; j++) {
+                 dp += x[j] * y[j * d_offset];
+             }
+
+             // compute y^2 - 2 * (x, y), which is sufficient for looking for the
+             // lowest distance.
+             const float distance = y_sqlen[0] - 2 * dp + x_sqlen;
+             distances[i] = distance;
+
+             y += 1;
+             y_sqlen += 1;
+         }
+     }
+ }
+
+ #endif
+
+ void fvec_L2sqr_ny_transposed(
+         float* dis,
+         const float* x,
+         const float* y,
+         const float* y_sqlen,
+         size_t d,
+         size_t d_offset,
+         size_t ny) {
+     // optimized for a few special cases
+
+ #ifdef __AVX2__
+ #define DISPATCH(dval) \
+     case dval: \
+         return fvec_L2sqr_ny_y_transposed_D<dval>( \
+                 dis, x, y, y_sqlen, d_offset, ny);
+
+     switch (d) {
+         DISPATCH(1)
+         DISPATCH(2)
+         DISPATCH(4)
+         DISPATCH(8)
+         default:
+             return fvec_L2sqr_ny_y_transposed_ref(
+                     dis, x, y, y_sqlen, d, d_offset, ny);
+     }
+ #undef DISPATCH
+ #else
+     // non-AVX2 case
+     return fvec_L2sqr_ny_y_transposed_ref(dis, x, y, y_sqlen, d, d_offset, ny);
+ #endif
+ }
+
+ #if defined(__AVX512F__)
+
+ size_t fvec_L2sqr_ny_nearest_D2(
+         float* distances_tmp_buffer,
+         const float* x,
+         const float* y,
+         size_t ny) {
+     // this implementation does not use distances_tmp_buffer.
+
+     size_t i = 0;
+     float current_min_distance = HUGE_VALF;
+     size_t current_min_index = 0;
+
+     const size_t ny16 = ny / 16;
+     if (ny16 > 0) {
+         _mm_prefetch((const char*)y, _MM_HINT_T0);
+         _mm_prefetch((const char*)(y + 32), _MM_HINT_T0);
+
+         __m512 min_distances = _mm512_set1_ps(HUGE_VALF);
+         __m512i min_indices = _mm512_set1_epi32(0);
+
+         __m512i current_indices = _mm512_setr_epi32(
+                 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
+         const __m512i indices_increment = _mm512_set1_epi32(16);
+
+         const __m512 m0 = _mm512_set1_ps(x[0]);
+         const __m512 m1 = _mm512_set1_ps(x[1]);
+
+         for (; i < ny16 * 16; i += 16) {
+             _mm_prefetch((const char*)(y + 64), _MM_HINT_T0);
+
+             __m512 v0;
+             __m512 v1;
+
+             transpose_16x2(
+                     _mm512_loadu_ps(y + 0 * 16),
+                     _mm512_loadu_ps(y + 1 * 16),
+                     v0,
+                     v1);
+
+             const __m512 d0 = _mm512_sub_ps(m0, v0);
+             const __m512 d1 = _mm512_sub_ps(m1, v1);
+
+             __m512 distances = _mm512_mul_ps(d0, d0);
+             distances = _mm512_fmadd_ps(d1, d1, distances);
+
+             __mmask16 comparison =
+                     _mm512_cmp_ps_mask(distances, min_distances, _CMP_LT_OS);
+
+             min_distances = _mm512_min_ps(distances, min_distances);
+             min_indices = _mm512_mask_blend_epi32(
+                     comparison, min_indices, current_indices);
+
+             current_indices =
+                     _mm512_add_epi32(current_indices, indices_increment);
+
+             y += 32;
+         }
+
+         alignas(64) float min_distances_scalar[16];
+         alignas(64) uint32_t min_indices_scalar[16];
+         _mm512_store_ps(min_distances_scalar, min_distances);
+         _mm512_store_epi32(min_indices_scalar, min_indices);
+
+         for (size_t j = 0; j < 16; j++) {
+             if (current_min_distance > min_distances_scalar[j]) {
+                 current_min_distance = min_distances_scalar[j];
+                 current_min_index = min_indices_scalar[j];
+             }
+         }
+     }
+
+     if (i < ny) {
+         float x0 = x[0];
+         float x1 = x[1];
+
+         for (; i < ny; i++) {
+             float sub0 = x0 - y[0];
+             float sub1 = x1 - y[1];
+             float distance = sub0 * sub0 + sub1 * sub1;
+
+             y += 2;
+
+             if (current_min_distance > distance) {
+                 current_min_distance = distance;
+                 current_min_index = i;
+             }
+         }
+     }
+
+     return current_min_index;
+ }
+
+ size_t fvec_L2sqr_ny_nearest_D4(
+         float* distances_tmp_buffer,
+         const float* x,
+         const float* y,
+         size_t ny) {
+     // this implementation does not use distances_tmp_buffer.
+
+     size_t i = 0;
+     float current_min_distance = HUGE_VALF;
+     size_t current_min_index = 0;
+
+     const size_t ny16 = ny / 16;
+
+     if (ny16 > 0) {
+         __m512 min_distances = _mm512_set1_ps(HUGE_VALF);
+         __m512i min_indices = _mm512_set1_epi32(0);
+
+         __m512i current_indices = _mm512_setr_epi32(
+                 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
+         const __m512i indices_increment = _mm512_set1_epi32(16);
+
+         const __m512 m0 = _mm512_set1_ps(x[0]);
+         const __m512 m1 = _mm512_set1_ps(x[1]);
+         const __m512 m2 = _mm512_set1_ps(x[2]);
+         const __m512 m3 = _mm512_set1_ps(x[3]);
+
+         for (; i < ny16 * 16; i += 16) {
+             __m512 v0;
+             __m512 v1;
+             __m512 v2;
+             __m512 v3;
+
+             transpose_16x4(
+                     _mm512_loadu_ps(y + 0 * 16),
+                     _mm512_loadu_ps(y + 1 * 16),
+                     _mm512_loadu_ps(y + 2 * 16),
+                     _mm512_loadu_ps(y + 3 * 16),
+                     v0,
+                     v1,
+                     v2,
+                     v3);
+
+             const __m512 d0 = _mm512_sub_ps(m0, v0);
+             const __m512 d1 = _mm512_sub_ps(m1, v1);
+             const __m512 d2 = _mm512_sub_ps(m2, v2);
+             const __m512 d3 = _mm512_sub_ps(m3, v3);
+
+             __m512 distances = _mm512_mul_ps(d0, d0);
+             distances = _mm512_fmadd_ps(d1, d1, distances);
+             distances = _mm512_fmadd_ps(d2, d2, distances);
+             distances = _mm512_fmadd_ps(d3, d3, distances);
+
+             __mmask16 comparison =
+                     _mm512_cmp_ps_mask(distances, min_distances, _CMP_LT_OS);
+
+             min_distances = _mm512_min_ps(distances, min_distances);
+             min_indices = _mm512_mask_blend_epi32(
+                     comparison, min_indices, current_indices);
+
+             current_indices =
+                     _mm512_add_epi32(current_indices, indices_increment);
+
+             y += 64;
+         }
+
+         alignas(64) float min_distances_scalar[16];
+         alignas(64) uint32_t min_indices_scalar[16];
+         _mm512_store_ps(min_distances_scalar, min_distances);
+         _mm512_store_epi32(min_indices_scalar, min_indices);
+
+         for (size_t j = 0; j < 16; j++) {
+             if (current_min_distance > min_distances_scalar[j]) {
+                 current_min_distance = min_distances_scalar[j];
+                 current_min_index = min_indices_scalar[j];
+             }
+         }
+     }
+
+     if (i < ny) {
+         __m128 x0 = _mm_loadu_ps(x);
+
+         for (; i < ny; i++) {
+             __m128 accu = ElementOpL2::op(x0, _mm_loadu_ps(y));
+             y += 4;
+             const float distance = horizontal_sum(accu);
+
+             if (current_min_distance > distance) {
+                 current_min_distance = distance;
+                 current_min_index = i;
+             }
+         }
+     }
+
+     return current_min_index;
+ }
+
+ size_t fvec_L2sqr_ny_nearest_D8(
+         float* distances_tmp_buffer,
+         const float* x,
+         const float* y,
+         size_t ny) {
+     // this implementation does not use distances_tmp_buffer.
+
+     size_t i = 0;
+     float current_min_distance = HUGE_VALF;
+     size_t current_min_index = 0;
+
+     const size_t ny16 = ny / 16;
+     if (ny16 > 0) {
+         __m512 min_distances = _mm512_set1_ps(HUGE_VALF);
+         __m512i min_indices = _mm512_set1_epi32(0);
+
+         __m512i current_indices = _mm512_setr_epi32(
+                 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
+         const __m512i indices_increment = _mm512_set1_epi32(16);
+
+         const __m512 m0 = _mm512_set1_ps(x[0]);
+         const __m512 m1 = _mm512_set1_ps(x[1]);
+         const __m512 m2 = _mm512_set1_ps(x[2]);
+         const __m512 m3 = _mm512_set1_ps(x[3]);
+
+         const __m512 m4 = _mm512_set1_ps(x[4]);
+         const __m512 m5 = _mm512_set1_ps(x[5]);
+         const __m512 m6 = _mm512_set1_ps(x[6]);
+         const __m512 m7 = _mm512_set1_ps(x[7]);
+
+         for (; i < ny16 * 16; i += 16) {
+             __m512 v0;
+             __m512 v1;
+             __m512 v2;
+             __m512 v3;
+             __m512 v4;
+             __m512 v5;
+             __m512 v6;
+             __m512 v7;
+
+             transpose_16x8(
+                     _mm512_loadu_ps(y + 0 * 16),
+                     _mm512_loadu_ps(y + 1 * 16),
+                     _mm512_loadu_ps(y + 2 * 16),
+                     _mm512_loadu_ps(y + 3 * 16),
+                     _mm512_loadu_ps(y + 4 * 16),
+                     _mm512_loadu_ps(y + 5 * 16),
+                     _mm512_loadu_ps(y + 6 * 16),
+                     _mm512_loadu_ps(y + 7 * 16),
+                     v0,
+                     v1,
+                     v2,
+                     v3,
+                     v4,
+                     v5,
+                     v6,
+                     v7);
+
+             const __m512 d0 = _mm512_sub_ps(m0, v0);
+             const __m512 d1 = _mm512_sub_ps(m1, v1);
+             const __m512 d2 = _mm512_sub_ps(m2, v2);
+             const __m512 d3 = _mm512_sub_ps(m3, v3);
+             const __m512 d4 = _mm512_sub_ps(m4, v4);
+             const __m512 d5 = _mm512_sub_ps(m5, v5);
+             const __m512 d6 = _mm512_sub_ps(m6, v6);
+             const __m512 d7 = _mm512_sub_ps(m7, v7);
+
+             __m512 distances = _mm512_mul_ps(d0, d0);
+             distances = _mm512_fmadd_ps(d1, d1, distances);
+             distances = _mm512_fmadd_ps(d2, d2, distances);
+             distances = _mm512_fmadd_ps(d3, d3, distances);
+             distances = _mm512_fmadd_ps(d4, d4, distances);
+             distances = _mm512_fmadd_ps(d5, d5, distances);
+             distances = _mm512_fmadd_ps(d6, d6, distances);
+             distances = _mm512_fmadd_ps(d7, d7, distances);
+
+             __mmask16 comparison =
+                     _mm512_cmp_ps_mask(distances, min_distances, _CMP_LT_OS);
+
+             min_distances = _mm512_min_ps(distances, min_distances);
+             min_indices = _mm512_mask_blend_epi32(
+                     comparison, min_indices, current_indices);
+
+             current_indices =
+                     _mm512_add_epi32(current_indices, indices_increment);
+
+             y += 128;
+         }
+
+         alignas(64) float min_distances_scalar[16];
+         alignas(64) uint32_t min_indices_scalar[16];
+         _mm512_store_ps(min_distances_scalar, min_distances);
+         _mm512_store_epi32(min_indices_scalar, min_indices);
+
+         for (size_t j = 0; j < 16; j++) {
+             if (current_min_distance > min_distances_scalar[j]) {
+                 current_min_distance = min_distances_scalar[j];
+                 current_min_index = min_indices_scalar[j];
+             }
+         }
+     }
+
+     if (i < ny) {
+         __m256 x0 = _mm256_loadu_ps(x);

-             // compute y^2 - 2 * (x, y), which is sufficient for looking for the
-             // lowest distance.
-             const float distance = y_sqlen[0] - 2 * dp + x_sqlen;
-             distances[i] = distance;
+         for (; i < ny; i++) {
+             __m256 accu = ElementOpL2::op(x0, _mm256_loadu_ps(y));
+             y += 8;
+             const float distance = horizontal_sum(accu);

-             y += 1;
-             y_sqlen += 1;
+             if (current_min_distance > distance) {
+                 current_min_distance = distance;
+                 current_min_index = i;
+             }
          }
      }
- }
- #endif
-
- void fvec_L2sqr_ny_transposed(
-         float* dis,
-         const float* x,
-         const float* y,
-         const float* y_sqlen,
-         size_t d,
-         size_t d_offset,
-         size_t ny) {
-     // optimized for a few special cases
-
- #ifdef __AVX2__
- #define DISPATCH(dval) \
-     case dval: \
-         return fvec_L2sqr_ny_y_transposed_D<dval>( \
-                 dis, x, y, y_sqlen, d_offset, ny);

-     switch (d) {
-         DISPATCH(1)
-         DISPATCH(2)
-         DISPATCH(4)
-         DISPATCH(8)
-         default:
-             return fvec_L2sqr_ny_y_transposed_ref(
-                     dis, x, y, y_sqlen, d, d_offset, ny);
-     }
- #undef DISPATCH
- #else
-     // non-AVX2 case
-     return fvec_L2sqr_ny_y_transposed_ref(dis, x, y, y_sqlen, d, d_offset, ny);
- #endif
+     return current_min_index;
  }

- #ifdef __AVX2__
+ #elif defined(__AVX2__)

  size_t fvec_L2sqr_ny_nearest_D2(
          float* distances_tmp_buffer,
@@ -1476,7 +2316,123 @@ size_t fvec_L2sqr_ny_nearest(
  #undef DISPATCH
  }

- #ifdef __AVX2__
+ #if defined(__AVX512F__)
+
+ template <size_t DIM>
+ size_t fvec_L2sqr_ny_nearest_y_transposed_D(
+         float* distances_tmp_buffer,
+         const float* x,
+         const float* y,
+         const float* y_sqlen,
+         const size_t d_offset,
+         size_t ny) {
+     // This implementation does not use distances_tmp_buffer.
+
+     // Current index being processed
+     size_t i = 0;
+
+     // Min distance and the index of the closest vector so far
+     float current_min_distance = HUGE_VALF;
+     size_t current_min_index = 0;
+
+     // Process 16 vectors per loop
+     const size_t ny16 = ny / 16;
+
+     if (ny16 > 0) {
+         // Track min distance and the closest vector independently
+         // for each of 16 AVX-512 components.
+         __m512 min_distances = _mm512_set1_ps(HUGE_VALF);
+         __m512i min_indices = _mm512_set1_epi32(0);
+
+         __m512i current_indices = _mm512_setr_epi32(
+                 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
+         const __m512i indices_increment = _mm512_set1_epi32(16);
+
+         // m[i] = (2 * x[i], ... 2 * x[i])
+         __m512 m[DIM];
+         for (size_t j = 0; j < DIM; j++) {
+             m[j] = _mm512_set1_ps(x[j]);
+             m[j] = _mm512_add_ps(m[j], m[j]);
+         }
+
+         for (; i < ny16 * 16; i += 16) {
+             // Compute dot products
+             const __m512 v0 = _mm512_loadu_ps(y + 0 * d_offset);
+             __m512 dp = _mm512_mul_ps(m[0], v0);
+             for (size_t j = 1; j < DIM; j++) {
+                 const __m512 vj = _mm512_loadu_ps(y + j * d_offset);
+                 dp = _mm512_fmadd_ps(m[j], vj, dp);
+             }
+
+             // Compute y^2 - (2 * x, y), which is sufficient for looking for the
+             // lowest distance.
+             // x^2 is the constant that can be avoided.
+             const __m512 distances =
+                     _mm512_sub_ps(_mm512_loadu_ps(y_sqlen), dp);
+
+             // Compare the new distances to the min distances
+             __mmask16 comparison =
+                     _mm512_cmp_ps_mask(min_distances, distances, _CMP_LT_OS);
+
+             // Update min distances and indices with closest vectors if needed
+             min_distances =
+                     _mm512_mask_blend_ps(comparison, distances, min_distances);
+             min_indices = _mm512_castps_si512(_mm512_mask_blend_ps(
+                     comparison,
+                     _mm512_castsi512_ps(current_indices),
+                     _mm512_castsi512_ps(min_indices)));
+
+             // Update current indices values. Basically, +16 to each of the 16
+             // AVX-512 components.
+             current_indices =
+                     _mm512_add_epi32(current_indices, indices_increment);
+
+             // Scroll y and y_sqlen forward.
+             y += 16;
+             y_sqlen += 16;
+         }
+
+         // Dump values and find the minimum distance / minimum index
+         float min_distances_scalar[16];
+         uint32_t min_indices_scalar[16];
+         _mm512_storeu_ps(min_distances_scalar, min_distances);
+         _mm512_storeu_si512((__m512i*)(min_indices_scalar), min_indices);
+
+         for (size_t j = 0; j < 16; j++) {
+             if (current_min_distance > min_distances_scalar[j]) {
+                 current_min_distance = min_distances_scalar[j];
+                 current_min_index = min_indices_scalar[j];
+             }
+         }
+     }
+
+     if (i < ny) {
+         // Process leftovers
+         for (; i < ny; i++) {
+             float dp = 0;
+             for (size_t j = 0; j < DIM; j++) {
+                 dp += x[j] * y[j * d_offset];
+             }
+
+             // Compute y^2 - 2 * (x, y), which is sufficient for looking for the
+             // lowest distance.
+             const float distance = y_sqlen[0] - 2 * dp;
+
+             if (current_min_distance > distance) {
+                 current_min_distance = distance;
+                 current_min_index = i;
+             }
+
+             y += 1;
+             y_sqlen += 1;
+         }
+     }
+
+     return current_min_index;
+ }
+
+ #elif defined(__AVX2__)
+
  template <size_t DIM>
  size_t fvec_L2sqr_ny_nearest_y_transposed_D(
          float* distances_tmp_buffer,
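In this nearest-neighbor variant the ||x||^2 term is dropped entirely (the "x^2 is the constant that can be avoided" comment): x is fixed across all candidates, so subtracting a constant from every distance leaves the argmin unchanged:

    \arg\min_i \|x - y_i\|^2 = \arg\min_i \left( \|y_i\|^2 - 2\,\langle x, y_i \rangle \right)

The masked-blend pattern then keeps 16 independent running (min, index) pairs, one per lane, which a final pass reduces to a single winner; a scalar sketch of that last step (my illustration, mirroring the reduction loops above):

    #include <cmath>
    #include <cstddef>
    #include <cstdint>

    struct MinLane { float dist; uint32_t idx; };

    // reduce the per-lane minima to one global (min, index) pair
    size_t argmin_lanes(const MinLane lanes[16]) {
        float best = HUGE_VALF;
        size_t best_idx = 0;
        for (int j = 0; j < 16; j++) {
            if (best > lanes[j].dist) {
                best = lanes[j].dist;
                best_idx = lanes[j].idx;
            }
        }
        return best_idx;
    }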
@@ -1592,6 +2548,7 @@ size_t fvec_L2sqr_ny_nearest_y_transposed_D(

      return current_min_index;
  }
+
  #endif

  size_t fvec_L2sqr_ny_nearest_y_transposed(
@@ -1632,6 +2589,7 @@ size_t fvec_L2sqr_ny_nearest_y_transposed(

  float fvec_L1(const float* x, const float* y, size_t d) {
      __m256 msum1 = _mm256_setzero_ps();
+     // signmask used for absolute value
      __m256 signmask = _mm256_castsi256_ps(_mm256_set1_epi32(0x7fffffffUL));

      while (d >= 8) {
@@ -1639,7 +2597,9 @@ float fvec_L1(const float* x, const float* y, size_t d) {
          x += 8;
          __m256 my = _mm256_loadu_ps(y);
          y += 8;
+         // subtract
          const __m256 a_m_b = _mm256_sub_ps(mx, my);
+         // find sum of absolute value of distances (manhattan distance)
          msum1 = _mm256_add_ps(msum1, _mm256_and_ps(signmask, a_m_b));
          d -= 8;
      }
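The `signmask` used in `fvec_L1` (and in `fvec_Linf` below) exploits the IEEE-754 layout: clearing the top (sign) bit of a float yields its absolute value, which `_mm256_and_ps(signmask, a_m_b)` does for 8 floats at once. A scalar sketch of the same trick (my illustration, not part of the diff):

    #include <cstdint>
    #include <cstring>

    float abs_via_signmask(float v) {
        uint32_t bits;
        std::memcpy(&bits, &v, sizeof(bits)); // type-pun without UB
        bits &= 0x7fffffffu;                  // clear the sign bit
        std::memcpy(&v, &bits, sizeof(v));
        return v; // equals fabsf(v)
    }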
@@ -1672,6 +2632,7 @@ float fvec_L1(const float* x, const float* y, size_t d) {

  float fvec_Linf(const float* x, const float* y, size_t d) {
      __m256 msum1 = _mm256_setzero_ps();
+     // signmask used for absolute value
      __m256 signmask = _mm256_castsi256_ps(_mm256_set1_epi32(0x7fffffffUL));

      while (d >= 8) {
@@ -1679,7 +2640,9 @@ float fvec_Linf(const float* x, const float* y, size_t d) {
          x += 8;
          __m256 my = _mm256_loadu_ps(y);
          y += 8;
+         // subtract
          const __m256 a_m_b = _mm256_sub_ps(mx, my);
+         // find max of absolute value of distances (chebyshev distance)
          msum1 = _mm256_max_ps(msum1, _mm256_and_ps(signmask, a_m_b));
          d -= 8;
      }
@@ -1720,6 +2683,441 @@ float fvec_Linf(const float* x, const float* y, size_t d) {
1720
2683
  return fvec_Linf_ref(x, y, d);
1721
2684
  }
1722
2685
 
2686
+ #elif defined(__ARM_FEATURE_SVE)
2687
+
2688
+ struct ElementOpIP {
2689
+ static svfloat32_t op(svbool_t pg, svfloat32_t x, svfloat32_t y) {
2690
+ return svmul_f32_x(pg, x, y);
2691
+ }
2692
+ static svfloat32_t merge(
2693
+ svbool_t pg,
2694
+ svfloat32_t z,
2695
+ svfloat32_t x,
2696
+ svfloat32_t y) {
2697
+ return svmla_f32_x(pg, z, x, y);
2698
+ }
2699
+ };
2700
+
2701
+ template <typename ElementOp>
2702
+ void fvec_op_ny_sve_d1(float* dis, const float* x, const float* y, size_t ny) {
2703
+ const size_t lanes = svcntw();
2704
+ const size_t lanes2 = lanes * 2;
2705
+ const size_t lanes3 = lanes * 3;
2706
+ const size_t lanes4 = lanes * 4;
2707
+ const svbool_t pg = svptrue_b32();
2708
+ const svfloat32_t x0 = svdup_n_f32(x[0]);
2709
+ size_t i = 0;
2710
+ for (; i + lanes4 < ny; i += lanes4) {
2711
+ svfloat32_t y0 = svld1_f32(pg, y);
2712
+ svfloat32_t y1 = svld1_f32(pg, y + lanes);
2713
+ svfloat32_t y2 = svld1_f32(pg, y + lanes2);
2714
+ svfloat32_t y3 = svld1_f32(pg, y + lanes3);
2715
+ y0 = ElementOp::op(pg, x0, y0);
2716
+ y1 = ElementOp::op(pg, x0, y1);
2717
+ y2 = ElementOp::op(pg, x0, y2);
2718
+ y3 = ElementOp::op(pg, x0, y3);
2719
+ svst1_f32(pg, dis, y0);
2720
+ svst1_f32(pg, dis + lanes, y1);
2721
+ svst1_f32(pg, dis + lanes2, y2);
2722
+ svst1_f32(pg, dis + lanes3, y3);
2723
+ y += lanes4;
2724
+ dis += lanes4;
2725
+ }
2726
+ const svbool_t pg0 = svwhilelt_b32_u64(i, ny);
2727
+ const svbool_t pg1 = svwhilelt_b32_u64(i + lanes, ny);
2728
+ const svbool_t pg2 = svwhilelt_b32_u64(i + lanes2, ny);
2729
+ const svbool_t pg3 = svwhilelt_b32_u64(i + lanes3, ny);
2730
+ svfloat32_t y0 = svld1_f32(pg0, y);
2731
+ svfloat32_t y1 = svld1_f32(pg1, y + lanes);
2732
+ svfloat32_t y2 = svld1_f32(pg2, y + lanes2);
2733
+ svfloat32_t y3 = svld1_f32(pg3, y + lanes3);
2734
+ y0 = ElementOp::op(pg0, x0, y0);
2735
+ y1 = ElementOp::op(pg1, x0, y1);
2736
+ y2 = ElementOp::op(pg2, x0, y2);
2737
+ y3 = ElementOp::op(pg3, x0, y3);
2738
+ svst1_f32(pg0, dis, y0);
2739
+ svst1_f32(pg1, dis + lanes, y1);
2740
+ svst1_f32(pg2, dis + lanes2, y2);
2741
+ svst1_f32(pg3, dis + lanes3, y3);
2742
+ }
2743
+
2744
+ template <typename ElementOp>
2745
+ void fvec_op_ny_sve_d2(float* dis, const float* x, const float* y, size_t ny) {
2746
+ const size_t lanes = svcntw();
2747
+ const size_t lanes2 = lanes * 2;
2748
+ const size_t lanes4 = lanes * 4;
2749
+ const svbool_t pg = svptrue_b32();
2750
+ const svfloat32_t x0 = svdup_n_f32(x[0]);
2751
+ const svfloat32_t x1 = svdup_n_f32(x[1]);
2752
+ size_t i = 0;
2753
+ for (; i + lanes2 < ny; i += lanes2) {
2754
+ const svfloat32x2_t y0 = svld2_f32(pg, y);
2755
+ const svfloat32x2_t y1 = svld2_f32(pg, y + lanes2);
2756
+ svfloat32_t y00 = svget2_f32(y0, 0);
2757
+ const svfloat32_t y01 = svget2_f32(y0, 1);
2758
+ svfloat32_t y10 = svget2_f32(y1, 0);
2759
+ const svfloat32_t y11 = svget2_f32(y1, 1);
2760
+ y00 = ElementOp::op(pg, x0, y00);
2761
+ y10 = ElementOp::op(pg, x0, y10);
2762
+ y00 = ElementOp::merge(pg, y00, x1, y01);
2763
+ y10 = ElementOp::merge(pg, y10, x1, y11);
2764
+ svst1_f32(pg, dis, y00);
2765
+ svst1_f32(pg, dis + lanes, y10);
2766
+ y += lanes4;
2767
+ dis += lanes2;
2768
+ }
2769
+ const svbool_t pg0 = svwhilelt_b32_u64(i, ny);
2770
+ const svbool_t pg1 = svwhilelt_b32_u64(i + lanes, ny);
2771
+ const svfloat32x2_t y0 = svld2_f32(pg0, y);
2772
+ const svfloat32x2_t y1 = svld2_f32(pg1, y + lanes2);
2773
+ svfloat32_t y00 = svget2_f32(y0, 0);
2774
+ const svfloat32_t y01 = svget2_f32(y0, 1);
2775
+ svfloat32_t y10 = svget2_f32(y1, 0);
2776
+ const svfloat32_t y11 = svget2_f32(y1, 1);
2777
+ y00 = ElementOp::op(pg0, x0, y00);
2778
+ y10 = ElementOp::op(pg1, x0, y10);
2779
+ y00 = ElementOp::merge(pg0, y00, x1, y01);
2780
+ y10 = ElementOp::merge(pg1, y10, x1, y11);
2781
+ svst1_f32(pg0, dis, y00);
2782
+ svst1_f32(pg1, dis + lanes, y10);
2783
+ }
2784
+
2785
+ template <typename ElementOp>
+ void fvec_op_ny_sve_d4(float* dis, const float* x, const float* y, size_t ny) {
+     const size_t lanes = svcntw();
+     const size_t lanes4 = lanes * 4;
+     const svbool_t pg = svptrue_b32();
+     const svfloat32_t x0 = svdup_n_f32(x[0]);
+     const svfloat32_t x1 = svdup_n_f32(x[1]);
+     const svfloat32_t x2 = svdup_n_f32(x[2]);
+     const svfloat32_t x3 = svdup_n_f32(x[3]);
+     size_t i = 0;
+     for (; i + lanes < ny; i += lanes) {
+         const svfloat32x4_t y0 = svld4_f32(pg, y);
+         svfloat32_t y00 = svget4_f32(y0, 0);
+         const svfloat32_t y01 = svget4_f32(y0, 1);
+         svfloat32_t y02 = svget4_f32(y0, 2);
+         const svfloat32_t y03 = svget4_f32(y0, 3);
+         y00 = ElementOp::op(pg, x0, y00);
+         y02 = ElementOp::op(pg, x2, y02);
+         y00 = ElementOp::merge(pg, y00, x1, y01);
+         y02 = ElementOp::merge(pg, y02, x3, y03);
+         y00 = svadd_f32_x(pg, y00, y02);
+         svst1_f32(pg, dis, y00);
+         y += lanes4;
+         dis += lanes;
+     }
+     const svbool_t pg0 = svwhilelt_b32_u64(i, ny);
+     const svfloat32x4_t y0 = svld4_f32(pg0, y);
+     svfloat32_t y00 = svget4_f32(y0, 0);
+     const svfloat32_t y01 = svget4_f32(y0, 1);
+     svfloat32_t y02 = svget4_f32(y0, 2);
+     const svfloat32_t y03 = svget4_f32(y0, 3);
+     y00 = ElementOp::op(pg0, x0, y00);
+     y02 = ElementOp::op(pg0, x2, y02);
+     y00 = ElementOp::merge(pg0, y00, x1, y01);
+     y02 = ElementOp::merge(pg0, y02, x3, y03);
+     y00 = svadd_f32_x(pg0, y00, y02);
+     svst1_f32(pg0, dis, y00);
+ }
+
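From d == 4 upward the kernels keep two independent accumulators (`y00` and `y02` above) and combine them only right before the store; splitting the reduction this way shortens the FMA dependency chain. The same association, written scalar-wise as a sketch (`elem` stands in for `ElementOp`):

    // Two partial sums, combined once at the end, mirroring the y00/y02
    // split in the d4 kernel above.
    float d4_term(const float* x, const float* yi, float (*elem)(float, float)) {
        const float acc0 = elem(x[0], yi[0]) + elem(x[1], yi[1]);
        const float acc1 = elem(x[2], yi[2]) + elem(x[3], yi[3]);
        return acc0 + acc1;
    }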
+ template <typename ElementOp>
+ void fvec_op_ny_sve_d8(float* dis, const float* x, const float* y, size_t ny) {
+     const size_t lanes = svcntw();
+     const size_t lanes4 = lanes * 4;
+     const size_t lanes8 = lanes * 8;
+     const svbool_t pg = svptrue_b32();
+     const svfloat32_t x0 = svdup_n_f32(x[0]);
+     const svfloat32_t x1 = svdup_n_f32(x[1]);
+     const svfloat32_t x2 = svdup_n_f32(x[2]);
+     const svfloat32_t x3 = svdup_n_f32(x[3]);
+     const svfloat32_t x4 = svdup_n_f32(x[4]);
+     const svfloat32_t x5 = svdup_n_f32(x[5]);
+     const svfloat32_t x6 = svdup_n_f32(x[6]);
+     const svfloat32_t x7 = svdup_n_f32(x[7]);
+     size_t i = 0;
+     for (; i + lanes < ny; i += lanes) {
+         const svfloat32x4_t ya = svld4_f32(pg, y);
+         const svfloat32x4_t yb = svld4_f32(pg, y + lanes4);
+         const svfloat32_t ya0 = svget4_f32(ya, 0);
+         const svfloat32_t ya1 = svget4_f32(ya, 1);
+         const svfloat32_t ya2 = svget4_f32(ya, 2);
+         const svfloat32_t ya3 = svget4_f32(ya, 3);
+         const svfloat32_t yb0 = svget4_f32(yb, 0);
+         const svfloat32_t yb1 = svget4_f32(yb, 1);
+         const svfloat32_t yb2 = svget4_f32(yb, 2);
+         const svfloat32_t yb3 = svget4_f32(yb, 3);
+         svfloat32_t y0 = svuzp1(ya0, yb0);
+         const svfloat32_t y1 = svuzp1(ya1, yb1);
+         svfloat32_t y2 = svuzp1(ya2, yb2);
+         const svfloat32_t y3 = svuzp1(ya3, yb3);
+         svfloat32_t y4 = svuzp2(ya0, yb0);
+         const svfloat32_t y5 = svuzp2(ya1, yb1);
+         svfloat32_t y6 = svuzp2(ya2, yb2);
+         const svfloat32_t y7 = svuzp2(ya3, yb3);
+         y0 = ElementOp::op(pg, x0, y0);
+         y2 = ElementOp::op(pg, x2, y2);
+         y4 = ElementOp::op(pg, x4, y4);
+         y6 = ElementOp::op(pg, x6, y6);
+         y0 = ElementOp::merge(pg, y0, x1, y1);
+         y2 = ElementOp::merge(pg, y2, x3, y3);
+         y4 = ElementOp::merge(pg, y4, x5, y5);
+         y6 = ElementOp::merge(pg, y6, x7, y7);
+         y0 = svadd_f32_x(pg, y0, y2);
+         y4 = svadd_f32_x(pg, y4, y6);
+         y0 = svadd_f32_x(pg, y0, y4);
+         svst1_f32(pg, dis, y0);
+         y += lanes8;
+         dis += lanes;
+     }
+     const svbool_t pg0 = svwhilelt_b32_u64(i, ny);
+     const svbool_t pga = svwhilelt_b32_u64(i * 2, ny * 2);
+     const svbool_t pgb = svwhilelt_b32_u64(i * 2 + lanes, ny * 2);
+     const svfloat32x4_t ya = svld4_f32(pga, y);
+     const svfloat32x4_t yb = svld4_f32(pgb, y + lanes4);
+     const svfloat32_t ya0 = svget4_f32(ya, 0);
+     const svfloat32_t ya1 = svget4_f32(ya, 1);
+     const svfloat32_t ya2 = svget4_f32(ya, 2);
+     const svfloat32_t ya3 = svget4_f32(ya, 3);
+     const svfloat32_t yb0 = svget4_f32(yb, 0);
+     const svfloat32_t yb1 = svget4_f32(yb, 1);
+     const svfloat32_t yb2 = svget4_f32(yb, 2);
+     const svfloat32_t yb3 = svget4_f32(yb, 3);
+     svfloat32_t y0 = svuzp1(ya0, yb0);
+     const svfloat32_t y1 = svuzp1(ya1, yb1);
+     svfloat32_t y2 = svuzp1(ya2, yb2);
+     const svfloat32_t y3 = svuzp1(ya3, yb3);
+     svfloat32_t y4 = svuzp2(ya0, yb0);
+     const svfloat32_t y5 = svuzp2(ya1, yb1);
+     svfloat32_t y6 = svuzp2(ya2, yb2);
+     const svfloat32_t y7 = svuzp2(ya3, yb3);
+     y0 = ElementOp::op(pg0, x0, y0);
+     y2 = ElementOp::op(pg0, x2, y2);
+     y4 = ElementOp::op(pg0, x4, y4);
+     y6 = ElementOp::op(pg0, x6, y6);
+     y0 = ElementOp::merge(pg0, y0, x1, y1);
+     y2 = ElementOp::merge(pg0, y2, x3, y3);
+     y4 = ElementOp::merge(pg0, y4, x5, y5);
+     y6 = ElementOp::merge(pg0, y6, x7, y7);
+     y0 = svadd_f32_x(pg0, y0, y2);
+     y4 = svadd_f32_x(pg0, y4, y6);
+     y0 = svadd_f32_x(pg0, y0, y4);
+     svst1_f32(pg0, dis, y0);
+     y += lanes8;
+     dis += lanes;
+ }
+
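In the d == 8 kernel a structure load can only split four ways, so two `svld4_f32` loads are combined with `svuzp1`/`svuzp2` to separate all eight components: `uzp1` keeps the even-indexed elements of the concatenated register pair, `uzp2` the odd-indexed ones. A scalar model of that un-zip step (a sketch; `n` stands for the SVE lane count):

    #include <cstddef>

    // even[k] = concat(a, b)[2k], odd[k] = concat(a, b)[2k + 1]
    void uzp_model(const float* a, const float* b, std::size_t n,
                   float* even, float* odd) {
        for (std::size_t k = 0; k < n; ++k) {
            even[k] = (2 * k < n) ? a[2 * k] : b[2 * k - n];
            odd[k] = (2 * k + 1 < n) ? a[2 * k + 1] : b[2 * k + 1 - n];
        }
    }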
+ template <typename ElementOp>
+ void fvec_op_ny_sve_lanes1(
+         float* dis,
+         const float* x,
+         const float* y,
+         size_t ny) {
+     const size_t lanes = svcntw();
+     const size_t lanes2 = lanes * 2;
+     const size_t lanes3 = lanes * 3;
+     const size_t lanes4 = lanes * 4;
+     const svbool_t pg = svptrue_b32();
+     const svfloat32_t x0 = svld1_f32(pg, x);
+     size_t i = 0;
+     for (; i + 3 < ny; i += 4) {
+         svfloat32_t y0 = svld1_f32(pg, y);
+         svfloat32_t y1 = svld1_f32(pg, y + lanes);
+         svfloat32_t y2 = svld1_f32(pg, y + lanes2);
+         svfloat32_t y3 = svld1_f32(pg, y + lanes3);
+         y += lanes4;
+         y0 = ElementOp::op(pg, x0, y0);
+         y1 = ElementOp::op(pg, x0, y1);
+         y2 = ElementOp::op(pg, x0, y2);
+         y3 = ElementOp::op(pg, x0, y3);
+         dis[i] = svaddv_f32(pg, y0);
+         dis[i + 1] = svaddv_f32(pg, y1);
+         dis[i + 2] = svaddv_f32(pg, y2);
+         dis[i + 3] = svaddv_f32(pg, y3);
+     }
+     for (; i < ny; ++i) {
+         svfloat32_t y0 = svld1_f32(pg, y);
+         y += lanes;
+         y0 = ElementOp::op(pg, x0, y0);
+         dis[i] = svaddv_f32(pg, y0);
+     }
+ }
+
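The lanes1 kernel handles d exactly equal to the SVE vector length: the whole query fits in one register, each database vector reduces to a scalar with the horizontal add `svaddv_f32`, and the 4-way unrolling hides the latency of that reduction. The lanes2 through lanes4 variants below extend the same pattern to d equal to two, three, and four vector lengths. All of them compute the same row-wise reduction as this scalar sketch (an illustration, not the vendored reference code; `elem` stands in for `ElementOp`):

    #include <cstddef>

    // dis[i] = sum over j of elem(x[j], y[i * d + j])
    void fvec_op_ny_rowwise_scalar(float* dis, const float* x, const float* y,
                                   std::size_t d, std::size_t ny,
                                   float (*elem)(float, float)) {
        for (std::size_t i = 0; i < ny; ++i) {
            float acc = 0.f;
            for (std::size_t j = 0; j < d; ++j)
                acc += elem(x[j], y[i * d + j]);
            dis[i] = acc;  // svaddv_f32 performs this horizontal sum
        }
    }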
+ template <typename ElementOp>
+ void fvec_op_ny_sve_lanes2(
+         float* dis,
+         const float* x,
+         const float* y,
+         size_t ny) {
+     const size_t lanes = svcntw();
+     const size_t lanes2 = lanes * 2;
+     const size_t lanes3 = lanes * 3;
+     const size_t lanes4 = lanes * 4;
+     const svbool_t pg = svptrue_b32();
+     const svfloat32_t x0 = svld1_f32(pg, x);
+     const svfloat32_t x1 = svld1_f32(pg, x + lanes);
+     size_t i = 0;
+     for (; i + 1 < ny; i += 2) {
+         svfloat32_t y00 = svld1_f32(pg, y);
+         const svfloat32_t y01 = svld1_f32(pg, y + lanes);
+         svfloat32_t y10 = svld1_f32(pg, y + lanes2);
+         const svfloat32_t y11 = svld1_f32(pg, y + lanes3);
+         y += lanes4;
+         y00 = ElementOp::op(pg, x0, y00);
+         y10 = ElementOp::op(pg, x0, y10);
+         y00 = ElementOp::merge(pg, y00, x1, y01);
+         y10 = ElementOp::merge(pg, y10, x1, y11);
+         dis[i] = svaddv_f32(pg, y00);
+         dis[i + 1] = svaddv_f32(pg, y10);
+     }
+     if (i < ny) {
+         svfloat32_t y0 = svld1_f32(pg, y);
+         const svfloat32_t y1 = svld1_f32(pg, y + lanes);
+         y0 = ElementOp::op(pg, x0, y0);
+         y0 = ElementOp::merge(pg, y0, x1, y1);
+         dis[i] = svaddv_f32(pg, y0);
+     }
+ }
+
+ template <typename ElementOp>
+ void fvec_op_ny_sve_lanes3(
+         float* dis,
+         const float* x,
+         const float* y,
+         size_t ny) {
+     const size_t lanes = svcntw();
+     const size_t lanes2 = lanes * 2;
+     const size_t lanes3 = lanes * 3;
+     const svbool_t pg = svptrue_b32();
+     const svfloat32_t x0 = svld1_f32(pg, x);
+     const svfloat32_t x1 = svld1_f32(pg, x + lanes);
+     const svfloat32_t x2 = svld1_f32(pg, x + lanes2);
+     for (size_t i = 0; i < ny; ++i) {
+         svfloat32_t y0 = svld1_f32(pg, y);
+         const svfloat32_t y1 = svld1_f32(pg, y + lanes);
+         svfloat32_t y2 = svld1_f32(pg, y + lanes2);
+         y += lanes3;
+         y0 = ElementOp::op(pg, x0, y0);
+         y0 = ElementOp::merge(pg, y0, x1, y1);
+         y0 = ElementOp::merge(pg, y0, x2, y2);
+         dis[i] = svaddv_f32(pg, y0);
+     }
+ }
+
+ template <typename ElementOp>
+ void fvec_op_ny_sve_lanes4(
+         float* dis,
+         const float* x,
+         const float* y,
+         size_t ny) {
+     const size_t lanes = svcntw();
+     const size_t lanes2 = lanes * 2;
+     const size_t lanes3 = lanes * 3;
+     const size_t lanes4 = lanes * 4;
+     const svbool_t pg = svptrue_b32();
+     const svfloat32_t x0 = svld1_f32(pg, x);
+     const svfloat32_t x1 = svld1_f32(pg, x + lanes);
+     const svfloat32_t x2 = svld1_f32(pg, x + lanes2);
+     const svfloat32_t x3 = svld1_f32(pg, x + lanes3);
+     for (size_t i = 0; i < ny; ++i) {
+         svfloat32_t y0 = svld1_f32(pg, y);
+         const svfloat32_t y1 = svld1_f32(pg, y + lanes);
+         svfloat32_t y2 = svld1_f32(pg, y + lanes2);
+         const svfloat32_t y3 = svld1_f32(pg, y + lanes3);
+         y += lanes4;
+         y0 = ElementOp::op(pg, x0, y0);
+         y2 = ElementOp::op(pg, x2, y2);
+         y0 = ElementOp::merge(pg, y0, x1, y1);
+         y2 = ElementOp::merge(pg, y2, x3, y3);
+         y0 = svadd_f32_x(pg, y0, y2);
+         dis[i] = svaddv_f32(pg, y0);
+     }
+ }
+
+ void fvec_L2sqr_ny(
+         float* dis,
+         const float* x,
+         const float* y,
+         size_t d,
+         size_t ny) {
+     fvec_L2sqr_ny_ref(dis, x, y, d, ny);
+ }
+
+ void fvec_L2sqr_ny_transposed(
+         float* dis,
+         const float* x,
+         const float* y,
+         const float* y_sqlen,
+         size_t d,
+         size_t d_offset,
+         size_t ny) {
+     return fvec_L2sqr_ny_y_transposed_ref(dis, x, y, y_sqlen, d, d_offset, ny);
+ }
+
+ size_t fvec_L2sqr_ny_nearest(
+         float* distances_tmp_buffer,
+         const float* x,
+         const float* y,
+         size_t d,
+         size_t ny) {
+     return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, d, ny);
+ }
+
+ size_t fvec_L2sqr_ny_nearest_y_transposed(
+         float* distances_tmp_buffer,
+         const float* x,
+         const float* y,
+         const float* y_sqlen,
+         size_t d,
+         size_t d_offset,
+         size_t ny) {
+     return fvec_L2sqr_ny_nearest_y_transposed_ref(
+             distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny);
+ }
+
+ float fvec_L1(const float* x, const float* y, size_t d) {
+     return fvec_L1_ref(x, y, d);
+ }
+
+ float fvec_Linf(const float* x, const float* y, size_t d) {
+     return fvec_Linf_ref(x, y, d);
+ }
+
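Only the inner-product path gets dedicated SVE kernels in this change; the L2, L1, and L-infinity entry points above simply delegate to the scalar reference implementations. For reference, the semantics these fallbacks preserve for squared L2 distances (a sketch, not the vendored reference code):

    #include <cstddef>

    // dis[i] = ||x - y_i||^2 for ny contiguous d-dimensional vectors y_i
    void fvec_L2sqr_ny_sketch(float* dis, const float* x, const float* y,
                              std::size_t d, std::size_t ny) {
        for (std::size_t i = 0; i < ny; ++i) {
            float acc = 0.f;
            for (std::size_t j = 0; j < d; ++j) {
                const float diff = x[j] - y[i * d + j];
                acc += diff * diff;
            }
            dis[i] = acc;
        }
    }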
+ void fvec_inner_products_ny(
+         float* dis,
+         const float* x,
+         const float* y,
+         size_t d,
+         size_t ny) {
+     const size_t lanes = svcntw();
+     switch (d) {
+         case 1:
+             fvec_op_ny_sve_d1<ElementOpIP>(dis, x, y, ny);
+             break;
+         case 2:
+             fvec_op_ny_sve_d2<ElementOpIP>(dis, x, y, ny);
+             break;
+         case 4:
+             fvec_op_ny_sve_d4<ElementOpIP>(dis, x, y, ny);
+             break;
+         case 8:
+             fvec_op_ny_sve_d8<ElementOpIP>(dis, x, y, ny);
+             break;
+         default:
+             if (d == lanes)
+                 fvec_op_ny_sve_lanes1<ElementOpIP>(dis, x, y, ny);
+             else if (d == lanes * 2)
+                 fvec_op_ny_sve_lanes2<ElementOpIP>(dis, x, y, ny);
+             else if (d == lanes * 3)
+                 fvec_op_ny_sve_lanes3<ElementOpIP>(dis, x, y, ny);
+             else if (d == lanes * 4)
+                 fvec_op_ny_sve_lanes4<ElementOpIP>(dis, x, y, ny);
+             else
+                 fvec_inner_products_ny_ref(dis, x, y, d, ny);
+             break;
+     }
+ }
+
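The dispatcher picks a specialized kernel when d is one of the small fixed sizes (1, 2, 4, 8) or an exact multiple of the runtime vector length (up to 4x `svcntw()` lanes), and otherwise falls back to the reference loop. A hedged usage sketch (buffer contents and names are illustrative; the entry point itself is declared in faiss/utils/distances.h):

    #include <cstddef>
    #include <vector>

    void inner_products_example() {
        const std::size_t d = 8, ny = 1000;
        std::vector<float> x(d, 1.f);       // one query vector
        std::vector<float> y(d * ny, 2.f);  // ny database vectors, contiguous
        std::vector<float> dis(ny);
        // dis[i] = <x, y_i>; with d == 8 this takes the SVE d8 kernel above.
        fvec_inner_products_ny(dis.data(), x.data(), y.data(), d, ny);
    }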
  #elif defined(__aarch64__)

  // not optimized for ARM
@@ -1858,7 +3256,39 @@ void fvec_inner_products_ny(
          c[i] = a[i] + bf * b[i];
      }

- #ifdef __AVX2__
+ #if defined(__AVX512F__)
+
+ static inline void fvec_madd_avx512(
+         const size_t n,
+         const float* __restrict a,
+         const float bf,
+         const float* __restrict b,
+         float* __restrict c) {
+     const size_t n16 = n / 16;
+     const size_t n_for_masking = n % 16;
+
+     const __m512 bfmm = _mm512_set1_ps(bf);
+
+     size_t idx = 0;
+     for (idx = 0; idx < n16 * 16; idx += 16) {
+         const __m512 ax = _mm512_loadu_ps(a + idx);
+         const __m512 bx = _mm512_loadu_ps(b + idx);
+         const __m512 abmul = _mm512_fmadd_ps(bfmm, bx, ax);
+         _mm512_storeu_ps(c + idx, abmul);
+     }
+
+     if (n_for_masking > 0) {
+         const __mmask16 mask = (1 << n_for_masking) - 1;
+
+         const __m512 ax = _mm512_maskz_loadu_ps(mask, a + idx);
+         const __m512 bx = _mm512_maskz_loadu_ps(mask, b + idx);
+         const __m512 abmul = _mm512_fmadd_ps(bfmm, bx, ax);
+         _mm512_mask_storeu_ps(c + idx, mask, abmul);
+     }
+ }
+
+ #elif defined(__AVX2__)
+
  static inline void fvec_madd_avx2(
          const size_t n,
          const float* __restrict a,
@@ -1911,6 +3341,7 @@ static inline void fvec_madd_avx2(
          _mm256_maskstore_ps(c + idx, mask, abmul);
      }
  }
+
  #endif

  #ifdef __SSE3__
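The AVX-512 variant uses the same split as the AVX2 one: full 16-float iterations, then one masked iteration for the remainder. The k-mask is built arithmetically, `(1 << (n % 16)) - 1` sets exactly the low `n % 16` bits, and the `maskz`/`mask` load and store intrinsics touch only those elements; the `n_for_masking > 0` guard keeps the tail from running with an empty mask. A scalar model of the mask (sketch):

    #include <cstddef>
    #include <cstdint>

    // Bit j of the tail mask is set iff j < n % 16.
    std::uint16_t tail_mask16(std::size_t n) {
        const std::size_t r = n % 16;
        return static_cast<std::uint16_t>((1u << r) - 1u);
    }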
@@ -1936,7 +3367,9 @@ static inline void fvec_madd_avx2(
  }

  void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c) {
- #ifdef __AVX2__
+ #ifdef __AVX512F__
+     fvec_madd_avx512(n, a, bf, b, c);
+ #elif __AVX2__
      fvec_madd_avx2(n, a, bf, b, c);
  #else
      if ((n & 3) == 0 && ((((long)a) | ((long)b) | ((long)c)) & 15) == 0)
@@ -1946,6 +3379,60 @@ void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c) {
  #endif
  }

+ #elif defined(__ARM_FEATURE_SVE)
+
+ void fvec_madd(
+         const size_t n,
+         const float* __restrict a,
+         const float bf,
+         const float* __restrict b,
+         float* __restrict c) {
+     const size_t lanes = static_cast<size_t>(svcntw());
+     const size_t lanes2 = lanes * 2;
+     const size_t lanes3 = lanes * 3;
+     const size_t lanes4 = lanes * 4;
+     size_t i = 0;
+     for (; i + lanes4 < n; i += lanes4) {
+         const auto mask = svptrue_b32();
+         const auto ai0 = svld1_f32(mask, a + i);
+         const auto ai1 = svld1_f32(mask, a + i + lanes);
+         const auto ai2 = svld1_f32(mask, a + i + lanes2);
+         const auto ai3 = svld1_f32(mask, a + i + lanes3);
+         const auto bi0 = svld1_f32(mask, b + i);
+         const auto bi1 = svld1_f32(mask, b + i + lanes);
+         const auto bi2 = svld1_f32(mask, b + i + lanes2);
+         const auto bi3 = svld1_f32(mask, b + i + lanes3);
+         const auto ci0 = svmla_n_f32_x(mask, ai0, bi0, bf);
+         const auto ci1 = svmla_n_f32_x(mask, ai1, bi1, bf);
+         const auto ci2 = svmla_n_f32_x(mask, ai2, bi2, bf);
+         const auto ci3 = svmla_n_f32_x(mask, ai3, bi3, bf);
+         svst1_f32(mask, c + i, ci0);
+         svst1_f32(mask, c + i + lanes, ci1);
+         svst1_f32(mask, c + i + lanes2, ci2);
+         svst1_f32(mask, c + i + lanes3, ci3);
+     }
+     const auto mask0 = svwhilelt_b32_u64(i, n);
+     const auto mask1 = svwhilelt_b32_u64(i + lanes, n);
+     const auto mask2 = svwhilelt_b32_u64(i + lanes2, n);
+     const auto mask3 = svwhilelt_b32_u64(i + lanes3, n);
+     const auto ai0 = svld1_f32(mask0, a + i);
+     const auto ai1 = svld1_f32(mask1, a + i + lanes);
+     const auto ai2 = svld1_f32(mask2, a + i + lanes2);
+     const auto ai3 = svld1_f32(mask3, a + i + lanes3);
+     const auto bi0 = svld1_f32(mask0, b + i);
+     const auto bi1 = svld1_f32(mask1, b + i + lanes);
+     const auto bi2 = svld1_f32(mask2, b + i + lanes2);
+     const auto bi3 = svld1_f32(mask3, b + i + lanes3);
+     const auto ci0 = svmla_n_f32_x(mask0, ai0, bi0, bf);
+     const auto ci1 = svmla_n_f32_x(mask1, ai1, bi1, bf);
+     const auto ci2 = svmla_n_f32_x(mask2, ai2, bi2, bf);
+     const auto ci3 = svmla_n_f32_x(mask3, ai3, bi3, bf);
+     svst1_f32(mask0, c + i, ci0);
+     svst1_f32(mask1, c + i + lanes, ci1);
+     svst1_f32(mask2, c + i + lanes2, ci2);
+     svst1_f32(mask3, c + i + lanes3, ci3);
+ }
+
  #elif defined(__aarch64__)

  void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c) {
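All of the `fvec_madd` variants in this file (AVX-512, AVX2, SVE, and the aarch64 fallback) preserve the reference semantics visible in the context lines earlier in this diff, namely (sketch):

    #include <cstddef>

    // c[i] = a[i] + bf * b[i] for i in [0, n)
    void fvec_madd_scalar(std::size_t n, const float* a, float bf,
                          const float* b, float* c) {
        for (std::size_t i = 0; i < n; ++i)
            c[i] = a[i] + bf * b[i];
    }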
@@ -2278,7 +3765,7 @@ void fvec_add(size_t d, const float* a, float b, float* c) {
      size_t i;
      simd8float32 bv(b);
      for (i = 0; i + 7 < d; i += 8) {
-         simd8float32 ci, ai, bi;
+         simd8float32 ci, ai;
          ai.loadu(a + i);
          ci = ai + bv;
          ci.storeu(c + i);