faiss 0.5.3 → 0.6.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (379) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +12 -0
  3. data/ext/faiss/ext.cpp +1 -1
  4. data/ext/faiss/extconf.rb +4 -4
  5. data/ext/faiss/index.cpp +63 -45
  6. data/ext/faiss/index_binary.cpp +37 -27
  7. data/ext/faiss/kmeans.cpp +9 -8
  8. data/ext/faiss/pca_matrix.cpp +9 -7
  9. data/ext/faiss/product_quantizer.cpp +13 -11
  10. data/ext/faiss/utils.cpp +4 -2
  11. data/ext/faiss/utils.h +4 -0
  12. data/lib/faiss/version.rb +1 -1
  13. data/lib/faiss.rb +1 -1
  14. data/vendor/faiss/faiss/AutoTune.cpp +214 -82
  15. data/vendor/faiss/faiss/AutoTune.h +14 -1
  16. data/vendor/faiss/faiss/Clustering.cpp +97 -249
  17. data/vendor/faiss/faiss/Clustering.h +18 -0
  18. data/vendor/faiss/faiss/IVFlib.cpp +67 -44
  19. data/vendor/faiss/faiss/Index.cpp +25 -12
  20. data/vendor/faiss/faiss/Index.h +26 -4
  21. data/vendor/faiss/faiss/Index2Layer.cpp +37 -53
  22. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +68 -61
  23. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +36 -34
  24. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +4 -1
  25. data/vendor/faiss/faiss/IndexBinary.cpp +6 -3
  26. data/vendor/faiss/faiss/IndexBinary.h +4 -4
  27. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +1 -1
  28. data/vendor/faiss/faiss/IndexBinaryFlat.h +1 -1
  29. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -4
  30. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +92 -95
  31. data/vendor/faiss/faiss/IndexBinaryHNSW.h +9 -3
  32. data/vendor/faiss/faiss/IndexBinaryHash.cpp +45 -236
  33. data/vendor/faiss/faiss/IndexBinaryHash.h +6 -6
  34. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +120 -414
  35. data/vendor/faiss/faiss/IndexFastScan.cpp +105 -129
  36. data/vendor/faiss/faiss/IndexFastScan.h +35 -24
  37. data/vendor/faiss/faiss/IndexFlat.cpp +216 -152
  38. data/vendor/faiss/faiss/IndexFlat.h +32 -14
  39. data/vendor/faiss/faiss/IndexFlatCodes.cpp +88 -41
  40. data/vendor/faiss/faiss/IndexFlatCodes.h +7 -1
  41. data/vendor/faiss/faiss/IndexHNSW.cpp +299 -187
  42. data/vendor/faiss/faiss/IndexHNSW.h +30 -14
  43. data/vendor/faiss/faiss/IndexIDMap.cpp +26 -22
  44. data/vendor/faiss/faiss/IndexIDMap.h +9 -7
  45. data/vendor/faiss/faiss/IndexIVF.cpp +535 -405
  46. data/vendor/faiss/faiss/IndexIVF.h +47 -16
  47. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +77 -74
  48. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +105 -99
  49. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +6 -3
  50. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +379 -249
  51. data/vendor/faiss/faiss/IndexIVFFastScan.h +65 -60
  52. data/vendor/faiss/faiss/IndexIVFFlat.cpp +41 -124
  53. data/vendor/faiss/faiss/IndexIVFFlat.h +32 -0
  54. data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +89 -138
  55. data/vendor/faiss/faiss/IndexIVFFlatPanorama.h +3 -1
  56. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +18 -15
  57. data/vendor/faiss/faiss/IndexIVFPQ.cpp +77 -907
  58. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +184 -122
  59. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +3 -0
  60. data/vendor/faiss/faiss/IndexIVFPQR.cpp +23 -18
  61. data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +59 -60
  62. data/vendor/faiss/faiss/IndexIVFRaBitQ.h +4 -3
  63. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +564 -416
  64. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +269 -111
  65. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +41 -127
  66. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +1 -1
  67. data/vendor/faiss/faiss/IndexLSH.cpp +44 -25
  68. data/vendor/faiss/faiss/IndexLattice.cpp +41 -36
  69. data/vendor/faiss/faiss/IndexNNDescent.cpp +37 -21
  70. data/vendor/faiss/faiss/IndexNNDescent.h +2 -2
  71. data/vendor/faiss/faiss/IndexNSG.cpp +40 -23
  72. data/vendor/faiss/faiss/IndexNSG.h +0 -2
  73. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +32 -12
  74. data/vendor/faiss/faiss/IndexPQ.cpp +129 -213
  75. data/vendor/faiss/faiss/IndexPQ.h +3 -2
  76. data/vendor/faiss/faiss/IndexPQFastScan.cpp +20 -14
  77. data/vendor/faiss/faiss/IndexPQFastScan.h +3 -0
  78. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -18
  79. data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
  80. data/vendor/faiss/faiss/IndexRaBitQ.cpp +31 -43
  81. data/vendor/faiss/faiss/IndexRaBitQ.h +4 -3
  82. data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +135 -317
  83. data/vendor/faiss/faiss/IndexRaBitQFastScan.h +192 -34
  84. data/vendor/faiss/faiss/IndexRefine.cpp +30 -55
  85. data/vendor/faiss/faiss/IndexRefine.h +4 -4
  86. data/vendor/faiss/faiss/IndexReplicas.cpp +6 -6
  87. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +15 -14
  88. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +1 -1
  89. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +82 -14
  90. data/vendor/faiss/faiss/IndexShards.cpp +13 -13
  91. data/vendor/faiss/faiss/IndexShardsIVF.cpp +21 -15
  92. data/vendor/faiss/faiss/MatrixStats.cpp +5 -4
  93. data/vendor/faiss/faiss/MetaIndexes.cpp +19 -17
  94. data/vendor/faiss/faiss/MetaIndexes.h +1 -1
  95. data/vendor/faiss/faiss/MetricType.h +29 -6
  96. data/vendor/faiss/faiss/SuperKMeans.cpp +656 -0
  97. data/vendor/faiss/faiss/SuperKMeans.h +97 -0
  98. data/vendor/faiss/faiss/VectorTransform.cpp +349 -141
  99. data/vendor/faiss/faiss/VectorTransform.h +39 -16
  100. data/vendor/faiss/faiss/build.cpp +23 -0
  101. data/vendor/faiss/faiss/build.h +15 -0
  102. data/vendor/faiss/faiss/clone_index.cpp +55 -51
  103. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +47 -47
  104. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +11 -0
  105. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +38 -38
  106. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +11 -0
  107. data/vendor/faiss/faiss/{cppcontrib/factory_tools.cpp → factory_tools.cpp} +6 -1
  108. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +1 -1
  109. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +6 -5
  110. data/vendor/faiss/faiss/gpu/GpuResources.h +1 -1
  111. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +9 -9
  112. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +4 -3
  113. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +46 -0
  114. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +56 -0
  115. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +78 -1
  116. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +72 -0
  117. data/vendor/faiss/faiss/gpu/test/TestUtils.h +23 -0
  118. data/vendor/faiss/faiss/gpu/utils/CuvsFilterConvert.h +1 -1
  119. data/vendor/faiss/faiss/gpu/utils/CuvsUtils.h +21 -10
  120. data/vendor/faiss/faiss/gpu_metal/GpuIndexFlat.h +22 -0
  121. data/vendor/faiss/faiss/gpu_metal/MetalCloner.h +35 -0
  122. data/vendor/faiss/faiss/gpu_metal/MetalFlatKernels.h +40 -0
  123. data/vendor/faiss/faiss/gpu_metal/MetalIndex.h +51 -0
  124. data/vendor/faiss/faiss/gpu_metal/MetalIndexFlat.h +65 -0
  125. data/vendor/faiss/faiss/gpu_metal/MetalKernels.h +66 -0
  126. data/vendor/faiss/faiss/gpu_metal/MetalResources.h +79 -0
  127. data/vendor/faiss/faiss/gpu_metal/StandardMetalResources.h +35 -0
  128. data/vendor/faiss/faiss/impl/AdSampling.cpp +103 -0
  129. data/vendor/faiss/faiss/impl/AdSampling.h +35 -0
  130. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +64 -34
  131. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +1 -0
  132. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +10 -9
  133. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +3 -28
  134. data/vendor/faiss/faiss/impl/ClusteringHelpers.cpp +244 -0
  135. data/vendor/faiss/faiss/impl/ClusteringHelpers.h +94 -0
  136. data/vendor/faiss/faiss/impl/ClusteringInitialization.cpp +367 -0
  137. data/vendor/faiss/faiss/impl/ClusteringInitialization.h +107 -0
  138. data/vendor/faiss/faiss/impl/CodePacker.cpp +7 -3
  139. data/vendor/faiss/faiss/impl/CodePacker.h +11 -3
  140. data/vendor/faiss/faiss/impl/CodePackerRaBitQ.cpp +83 -0
  141. data/vendor/faiss/faiss/impl/CodePackerRaBitQ.h +47 -0
  142. data/vendor/faiss/faiss/impl/DistanceComputer.h +8 -8
  143. data/vendor/faiss/faiss/impl/FaissAssert.h +64 -3
  144. data/vendor/faiss/faiss/impl/FaissException.h +50 -3
  145. data/vendor/faiss/faiss/impl/HNSW.cpp +117 -351
  146. data/vendor/faiss/faiss/impl/HNSW.h +21 -40
  147. data/vendor/faiss/faiss/impl/IDSelector.cpp +15 -11
  148. data/vendor/faiss/faiss/impl/IDSelector.h +8 -8
  149. data/vendor/faiss/faiss/impl/InvertedListScannerStats.h +26 -0
  150. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +114 -102
  151. data/vendor/faiss/faiss/impl/NNDescent.cpp +63 -26
  152. data/vendor/faiss/faiss/impl/NNDescent.h +6 -2
  153. data/vendor/faiss/faiss/impl/NSG.cpp +44 -26
  154. data/vendor/faiss/faiss/impl/NSG.h +20 -10
  155. data/vendor/faiss/faiss/impl/Panorama.cpp +76 -52
  156. data/vendor/faiss/faiss/impl/Panorama.h +265 -78
  157. data/vendor/faiss/faiss/impl/PdxLayout.cpp +93 -0
  158. data/vendor/faiss/faiss/impl/PdxLayout.h +41 -0
  159. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +62 -37
  160. data/vendor/faiss/faiss/impl/PolysemousTraining.h +3 -3
  161. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +35 -35
  162. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +21 -16
  163. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +99 -80
  164. data/vendor/faiss/faiss/impl/Quantizer.h +2 -2
  165. data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +135 -37
  166. data/vendor/faiss/faiss/impl/RaBitQUtils.h +148 -21
  167. data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +298 -301
  168. data/vendor/faiss/faiss/impl/RaBitQuantizer.h +3 -10
  169. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.cpp +15 -41
  170. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.h +0 -4
  171. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +40 -32
  172. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +1 -1
  173. data/vendor/faiss/faiss/impl/ResultHandler.h +218 -113
  174. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +119 -2362
  175. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +27 -3
  176. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +14 -11
  177. data/vendor/faiss/faiss/impl/VisitedTable.cpp +42 -0
  178. data/vendor/faiss/faiss/impl/VisitedTable.h +76 -0
  179. data/vendor/faiss/faiss/impl/approx_topk/approx_topk.h +276 -0
  180. data/vendor/faiss/faiss/impl/approx_topk/avx2.cpp +68 -0
  181. data/vendor/faiss/faiss/{utils → impl}/approx_topk/generic.h +15 -8
  182. data/vendor/faiss/faiss/impl/approx_topk/neon.cpp +68 -0
  183. data/vendor/faiss/faiss/impl/approx_topk/rq_beam_search_tab-inl.h +169 -0
  184. data/vendor/faiss/faiss/impl/approx_topk/rq_beam_search_tab.h +117 -0
  185. data/vendor/faiss/faiss/impl/approx_topk/simdlib256-inl.h +146 -0
  186. data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryHNSW_impl.h +73 -0
  187. data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryHash_impl.h +270 -0
  188. data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryIVF_impl.h +460 -0
  189. data/vendor/faiss/faiss/impl/binary_hamming/IndexIVFSpectralHash_impl.h +159 -0
  190. data/vendor/faiss/faiss/impl/binary_hamming/IndexPQ_impl.h +92 -0
  191. data/vendor/faiss/faiss/impl/binary_hamming/avx2.cpp +26 -0
  192. data/vendor/faiss/faiss/impl/binary_hamming/avx512.cpp +26 -0
  193. data/vendor/faiss/faiss/impl/binary_hamming/dispatch.h +143 -0
  194. data/vendor/faiss/faiss/impl/binary_hamming/neon.cpp +26 -0
  195. data/vendor/faiss/faiss/impl/binary_hamming/rvv.cpp +26 -0
  196. data/vendor/faiss/faiss/impl/expanded_scanners.h +163 -0
  197. data/vendor/faiss/faiss/impl/{FastScanDistancePostProcessing.h → fast_scan/FastScanDistancePostProcessing.h} +13 -6
  198. data/vendor/faiss/faiss/impl/{LookupTableScaler.h → fast_scan/LookupTableScaler.h} +16 -5
  199. data/vendor/faiss/faiss/impl/fast_scan/accumulate_loops.h +237 -0
  200. data/vendor/faiss/faiss/impl/fast_scan/accumulate_loops_512.h +185 -0
  201. data/vendor/faiss/faiss/impl/fast_scan/decompose_qbs.h +229 -0
  202. data/vendor/faiss/faiss/impl/fast_scan/dispatching.h +268 -0
  203. data/vendor/faiss/faiss/impl/{pq4_fast_scan.cpp → fast_scan/fast_scan.cpp} +176 -4
  204. data/vendor/faiss/faiss/impl/fast_scan/fast_scan.h +341 -0
  205. data/vendor/faiss/faiss/impl/fast_scan/impl-avx2.cpp +36 -0
  206. data/vendor/faiss/faiss/impl/fast_scan/impl-avx512.cpp +40 -0
  207. data/vendor/faiss/faiss/impl/fast_scan/impl-neon.cpp +120 -0
  208. data/vendor/faiss/faiss/impl/fast_scan/impl-riscv.cpp +104 -0
  209. data/vendor/faiss/faiss/impl/fast_scan/kernels_simd256.h +213 -0
  210. data/vendor/faiss/faiss/impl/{pq4_fast_scan_search_qbs.cpp → fast_scan/kernels_simd512.h} +26 -348
  211. data/vendor/faiss/faiss/impl/fast_scan/rabitq_dispatching.h +90 -0
  212. data/vendor/faiss/faiss/impl/fast_scan/rabitq_result_handler.h +108 -0
  213. data/vendor/faiss/faiss/impl/{simd_result_handlers.h → fast_scan/simd_result_handlers.h} +290 -142
  214. data/vendor/faiss/faiss/impl/hnsw/LockVector.cpp +54 -0
  215. data/vendor/faiss/faiss/impl/hnsw/LockVector.h +64 -0
  216. data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.cpp +91 -0
  217. data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.h +64 -0
  218. data/vendor/faiss/faiss/impl/hnsw/avx2.cpp +104 -0
  219. data/vendor/faiss/faiss/impl/hnsw/avx512.cpp +111 -0
  220. data/vendor/faiss/faiss/impl/index_read.cpp +1950 -505
  221. data/vendor/faiss/faiss/impl/index_read_utils.h +1 -2
  222. data/vendor/faiss/faiss/impl/index_write.cpp +112 -21
  223. data/vendor/faiss/faiss/impl/io.cpp +6 -6
  224. data/vendor/faiss/faiss/impl/io_macros.h +33 -16
  225. data/vendor/faiss/faiss/impl/kmeans1d.cpp +10 -10
  226. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +81 -40
  227. data/vendor/faiss/faiss/impl/lattice_Zn.h +6 -6
  228. data/vendor/faiss/faiss/impl/mapped_io.cpp +15 -8
  229. data/vendor/faiss/faiss/impl/platform_macros.h +11 -4
  230. data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQScanner_impl.h +549 -0
  231. data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQ_QueryTables.cpp +245 -0
  232. data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQ_QueryTables.h +105 -0
  233. data/vendor/faiss/faiss/impl/pq_code_distance/PQDistanceComputer_impl.h +106 -0
  234. data/vendor/faiss/faiss/impl/pq_code_distance/avx2.cpp +21 -0
  235. data/vendor/faiss/faiss/impl/pq_code_distance/avx512.cpp +21 -0
  236. data/vendor/faiss/faiss/impl/pq_code_distance/neon.cpp +21 -0
  237. data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx2.h → pq_code_distance/pq_code_distance-avx2.h} +43 -220
  238. data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx512.h → pq_code_distance/pq_code_distance-avx512.h} +25 -112
  239. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.cpp +59 -0
  240. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.h +96 -0
  241. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-inl.h +256 -0
  242. data/vendor/faiss/faiss/impl/{code_distance/code_distance-sve.h → pq_code_distance/pq_code_distance-sve.cpp} +57 -146
  243. data/vendor/faiss/faiss/impl/pq_code_distance/rvv.cpp +68 -0
  244. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +320 -483
  245. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +1 -1
  246. data/vendor/faiss/faiss/impl/scalar_quantizer/codecs.h +121 -0
  247. data/vendor/faiss/faiss/impl/scalar_quantizer/distance_computers.h +137 -0
  248. data/vendor/faiss/faiss/impl/scalar_quantizer/quantizers.h +371 -0
  249. data/vendor/faiss/faiss/impl/scalar_quantizer/scanners.h +190 -0
  250. data/vendor/faiss/faiss/impl/scalar_quantizer/similarities.h +94 -0
  251. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx2.cpp +603 -0
  252. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512.cpp +597 -0
  253. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-dispatch.h +388 -0
  254. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-neon.cpp +630 -0
  255. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-rvv.cpp +311 -0
  256. data/vendor/faiss/faiss/impl/scalar_quantizer/training.cpp +387 -0
  257. data/vendor/faiss/faiss/impl/scalar_quantizer/training.h +54 -0
  258. data/vendor/faiss/faiss/impl/simd_dispatch.h +173 -0
  259. data/vendor/faiss/faiss/impl/simdlib/simdlib.h +57 -0
  260. data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_avx2.h +274 -171
  261. data/vendor/faiss/faiss/impl/simdlib/simdlib_avx512.h +414 -0
  262. data/vendor/faiss/faiss/impl/simdlib/simdlib_dispatch.h +44 -0
  263. data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_emulated.h +231 -166
  264. data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_neon.h +275 -217
  265. data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_ppc64.h +201 -160
  266. data/vendor/faiss/faiss/impl/svs_io.cpp +12 -3
  267. data/vendor/faiss/faiss/impl/svs_io.h +8 -2
  268. data/vendor/faiss/faiss/index_factory.cpp +115 -28
  269. data/vendor/faiss/faiss/index_io.h +53 -3
  270. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +73 -20
  271. data/vendor/faiss/faiss/invlists/DirectMap.cpp +24 -14
  272. data/vendor/faiss/faiss/invlists/DirectMap.h +4 -3
  273. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +157 -73
  274. data/vendor/faiss/faiss/invlists/InvertedLists.h +86 -23
  275. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +4 -4
  276. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +14 -14
  277. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
  278. data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +9 -19
  279. data/vendor/faiss/faiss/svs/IndexSVSFlat.cpp +2 -2
  280. data/vendor/faiss/faiss/svs/IndexSVSFlat.h +2 -0
  281. data/vendor/faiss/faiss/svs/IndexSVSIVF.cpp +350 -0
  282. data/vendor/faiss/faiss/svs/IndexSVSIVF.h +128 -0
  283. data/vendor/faiss/faiss/svs/IndexSVSIVFLVQ.cpp +40 -0
  284. data/vendor/faiss/faiss/svs/IndexSVSIVFLVQ.h +43 -0
  285. data/vendor/faiss/faiss/svs/IndexSVSIVFLeanVec.cpp +225 -0
  286. data/vendor/faiss/faiss/svs/IndexSVSIVFLeanVec.h +71 -0
  287. data/vendor/faiss/faiss/svs/IndexSVSVamana.cpp +25 -1
  288. data/vendor/faiss/faiss/svs/IndexSVSVamana.h +19 -2
  289. data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.h +1 -1
  290. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +19 -2
  291. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +14 -0
  292. data/vendor/faiss/faiss/utils/Heap.cpp +56 -10
  293. data/vendor/faiss/faiss/utils/Heap.h +21 -0
  294. data/vendor/faiss/faiss/utils/NeuralNet.cpp +54 -40
  295. data/vendor/faiss/faiss/utils/NeuralNet.h +1 -1
  296. data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +10 -4
  297. data/vendor/faiss/faiss/utils/distances.cpp +507 -559
  298. data/vendor/faiss/faiss/utils/distances.h +118 -1
  299. data/vendor/faiss/faiss/utils/distances_dispatch.h +250 -0
  300. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +8 -7
  301. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +33 -14
  302. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +12 -1
  303. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +16 -293
  304. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based_neon.cpp +57 -0
  305. data/vendor/faiss/faiss/utils/distances_fused/simdlib_kernel-inl.h +290 -0
  306. data/vendor/faiss/faiss/utils/distances_simd.cpp +72 -3681
  307. data/vendor/faiss/faiss/utils/extra_distances.cpp +60 -102
  308. data/vendor/faiss/faiss/utils/extra_distances.h +79 -7
  309. data/vendor/faiss/faiss/utils/hamming-inl.h +13 -11
  310. data/vendor/faiss/faiss/utils/hamming.cpp +66 -517
  311. data/vendor/faiss/faiss/utils/hamming.h +92 -2
  312. data/vendor/faiss/faiss/utils/hamming_distance/common.h +287 -10
  313. data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx2.cpp +15 -0
  314. data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx512.cpp +15 -0
  315. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx2.h +142 -0
  316. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx512.h +234 -0
  317. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-generic.h +368 -0
  318. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-neon.h +322 -0
  319. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-rvv.h +39 -0
  320. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer.h +146 -0
  321. data/vendor/faiss/faiss/utils/hamming_distance/hamming_impl.h +481 -0
  322. data/vendor/faiss/faiss/utils/hamming_distance/hamming_neon.cpp +15 -0
  323. data/vendor/faiss/faiss/utils/hamming_distance/hamming_rvv.cpp +15 -0
  324. data/vendor/faiss/faiss/utils/partitioning.cpp +66 -987
  325. data/vendor/faiss/faiss/utils/partitioning.h +31 -0
  326. data/vendor/faiss/faiss/utils/popcount.h +29 -0
  327. data/vendor/faiss/faiss/utils/pq_code_distance.h +251 -0
  328. data/vendor/faiss/faiss/utils/prefetch.h +2 -2
  329. data/vendor/faiss/faiss/utils/quantize_lut.cpp +30 -30
  330. data/vendor/faiss/faiss/utils/quantize_lut.h +1 -1
  331. data/vendor/faiss/faiss/utils/rabitq_simd.h +124 -343
  332. data/vendor/faiss/faiss/utils/random.cpp +6 -6
  333. data/vendor/faiss/faiss/utils/simd_impl/IVFFlatScanner-inl.h +51 -0
  334. data/vendor/faiss/faiss/utils/simd_impl/distances_aarch64.cpp +154 -0
  335. data/vendor/faiss/faiss/utils/simd_impl/distances_arm_sve.cpp +777 -0
  336. data/vendor/faiss/faiss/utils/simd_impl/distances_autovec-inl.h +306 -0
  337. data/vendor/faiss/faiss/utils/simd_impl/distances_avx2.cpp +1431 -0
  338. data/vendor/faiss/faiss/utils/simd_impl/distances_avx512.cpp +1095 -0
  339. data/vendor/faiss/faiss/utils/simd_impl/distances_rvv.cpp +189 -0
  340. data/vendor/faiss/faiss/utils/simd_impl/distances_simdlib256.h +195 -0
  341. data/vendor/faiss/faiss/utils/simd_impl/distances_sse-inl.h +392 -0
  342. data/vendor/faiss/faiss/utils/{distances_fused/simdlib_based.h → simd_impl/exhaustive_L2sqr_blas_cmax.h} +5 -10
  343. data/vendor/faiss/faiss/utils/simd_impl/hamming_impl.h +481 -0
  344. data/vendor/faiss/faiss/utils/simd_impl/partitioning_avx2.cpp +14 -0
  345. data/vendor/faiss/faiss/utils/simd_impl/partitioning_neon.cpp +14 -0
  346. data/vendor/faiss/faiss/utils/simd_impl/partitioning_simdlib256.h +1085 -0
  347. data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx2.cpp +355 -0
  348. data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx512.cpp +477 -0
  349. data/vendor/faiss/faiss/utils/simd_impl/rabitq_neon.cpp +55 -0
  350. data/vendor/faiss/faiss/utils/simd_impl/rabitq_rvv.cpp +55 -0
  351. data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_dispatch.h +32 -0
  352. data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels.h +43 -0
  353. data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels_avx2.cpp +57 -0
  354. data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels_avx512.cpp +45 -0
  355. data/vendor/faiss/faiss/utils/simd_levels.cpp +334 -0
  356. data/vendor/faiss/faiss/utils/simd_levels.h +183 -0
  357. data/vendor/faiss/faiss/utils/sorting.cpp +48 -36
  358. data/vendor/faiss/faiss/utils/utils.cpp +21 -14
  359. data/vendor/faiss/faiss/utils/utils.h +3 -3
  360. metadata +156 -42
  361. data/vendor/faiss/faiss/impl/RaBitQStats.cpp +0 -29
  362. data/vendor/faiss/faiss/impl/RaBitQStats.h +0 -56
  363. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +0 -81
  364. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +0 -186
  365. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +0 -216
  366. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +0 -224
  367. data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +0 -84
  368. data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +0 -196
  369. data/vendor/faiss/faiss/utils/approx_topk/mode.h +0 -34
  370. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +0 -36
  371. data/vendor/faiss/faiss/utils/extra_distances-inl.h +0 -228
  372. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +0 -462
  373. data/vendor/faiss/faiss/utils/hamming_distance/avx512-inl.h +0 -490
  374. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +0 -450
  375. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +0 -87
  376. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +0 -524
  377. data/vendor/faiss/faiss/utils/simdlib.h +0 -42
  378. data/vendor/faiss/faiss/utils/simdlib_avx512.h +0 -296
  379. /data/vendor/faiss/faiss/{cppcontrib/factory_tools.h → factory_tools.h} +0 -0
@@ -0,0 +1,1095 @@
1
+ /*
2
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #include <faiss/utils/distances.h>
9
+
10
+ #include <immintrin.h>
11
+
12
+ #define THE_SIMD_LEVEL SIMDLevel::AVX512
13
+ #include <faiss/utils/simd_impl/distances_autovec-inl.h>
14
+ // NOLINTNEXTLINE(facebook-hte-InlineHeader)
15
+ #include <faiss/utils/simd_impl/IVFFlatScanner-inl.h>
16
+
17
+ #include <faiss/utils/simd_impl/distances_sse-inl.h>
18
+ #include <faiss/utils/transpose/transpose-avx512-inl.h>
19
+
20
+ namespace faiss {
21
+
22
+ template <> // AVX-512 specialization: writes c[i] = a[i] + bf * b[i] for i in [0, n)
23
+ void fvec_madd<SIMDLevel::AVX512>(
24
+ const size_t n, // number of floats to process
25
+ const float* a, // addend input, read with unaligned loads
26
+ const float bf, // scalar factor applied to every element of b
27
+ const float* b, // scaled input, read with unaligned loads
28
+ float* c) { // output, written with unaligned stores
29
+ const size_t n16 = n / 16; // number of full 16-float chunks
30
+ const size_t n_for_masking = n % 16; // tail length (0..15) handled with a lane mask
31
+
32
+ const __m512 bfmm = _mm512_set1_ps(bf); // bf broadcast to all 16 lanes
33
+
34
+ size_t idx = 0;
35
+ for (idx = 0; idx < n16 * 16; idx += 16) { // main loop: 16 floats per iteration
36
+ const __m512 ax = _mm512_loadu_ps(a + idx);
37
+ const __m512 bx = _mm512_loadu_ps(b + idx);
38
+ const __m512 abmul = _mm512_fmadd_ps(bfmm, bx, ax); // bf * b + a in one fused op
39
+ _mm512_storeu_ps(c + idx, abmul);
40
+ }
41
+
42
+ if (n_for_masking > 0) { // idx now equals n16 * 16, the start of the tail
43
+ const __mmask16 mask = (1 << n_for_masking) - 1; // low n_for_masking bits set
44
+
45
+ const __m512 ax = _mm512_maskz_loadu_ps(mask, a + idx); // masked load of the tail of a
46
+ const __m512 bx = _mm512_maskz_loadu_ps(mask, b + idx); // masked load of the tail of b
47
+ const __m512 abmul = _mm512_fmadd_ps(bfmm, bx, ax);
48
+ _mm512_mask_storeu_ps(c + idx, mask, abmul); // masked store of the tail of c
49
+ }
50
+ }
51
+
52
+ template <size_t DIM> // DIM: compile-time dimensionality of x and of each y vector
53
+ void fvec_L2sqr_ny_y_transposed_D( // distances[i] = y_sqlen[i] - 2 * (x, y_i) + |x|^2 for i in [0, ny)
54
+ float* distances, // output, ny floats
55
+ const float* x, // query vector, DIM floats are read
56
+ const float* y, // transposed storage: component j of vector i is read at y[i + j * d_offset]
57
+ const float* y_sqlen, // one value per y vector, added to the result; presumably precomputed squared norms - confirm with caller
58
+ const size_t d_offset, // stride, in floats, between consecutive dimensions of y
59
+ size_t ny) { // number of y vectors
60
+ // current index being processed
61
+ size_t i = 0;
62
+
63
+ // squared length of x
64
+ float x_sqlen = 0;
65
+ for (size_t j = 0; j < DIM; j++) {
66
+ x_sqlen += x[j] * x[j];
67
+ }
68
+
69
+ // process 16 vectors per loop
70
+ const size_t ny16 = ny / 16;
71
+
72
+ if (ny16 > 0) {
73
+ // m[j] = (2 * x[j], ... 2 * x[j]), i.e. 2 * x[j] broadcast to all 16 lanes
74
+ __m512 m[DIM];
75
+ for (size_t j = 0; j < DIM; j++) {
76
+ m[j] = _mm512_set1_ps(x[j]);
77
+ m[j] = _mm512_add_ps(m[j], m[j]); // m[j] = 2 * x[j]
78
+ }
79
+
80
+ __m512 x_sqlen_ymm = _mm512_set1_ps(x_sqlen); // 512-bit broadcast of x_sqlen, despite the _ymm suffix
81
+
82
+ for (; i < ny16 * 16; i += 16) {
83
+ // For each of the DIM dimensions, load that component of 16 consecutive y vectors
84
+ __m512 v[DIM];
85
+ for (size_t j = 0; j < DIM; j++) {
86
+ v[j] = _mm512_loadu_ps(y + j * d_offset);
87
+ }
88
+
89
+ // Accumulate dp = x^2 - 2 * (x, y) with fused negate-multiply-adds
90
+ __m512 dp = _mm512_fnmadd_ps(m[0], v[0], x_sqlen_ymm);
91
+ for (size_t j = 1; j < DIM; j++) {
92
+ dp = _mm512_fnmadd_ps(m[j], v[j], dp);
93
+ }
94
+
95
+ // Compute y^2 - (2 * x, y) + x^2
96
+ __m512 distances_v = _mm512_add_ps(_mm512_loadu_ps(y_sqlen), dp);
97
+
98
+ _mm512_storeu_ps(distances + i, distances_v);
99
+
100
+ // Scroll y and y_sqlen forward
101
+ y += 16;
102
+ y_sqlen += 16;
103
+ }
104
+ }
105
+
106
+ if (i < ny) {
107
+ // Process leftovers (fewer than 16 vectors) one at a time in scalar code
108
+ for (; i < ny; i++) {
109
+ float dp = 0;
110
+ for (size_t j = 0; j < DIM; j++) {
111
+ dp += x[j] * y[j * d_offset];
112
+ }
113
+
114
+ // Compute y^2 - 2 * (x, y) + x^2, i.e. the same full squared distance
115
+ // that the vectorized loop above stores.
116
+ const float distance = y_sqlen[0] - 2 * dp + x_sqlen;
117
+ distances[i] = distance;
118
+
119
+ y += 1;
120
+ y_sqlen += 1;
121
+ }
122
+ }
123
+ }
124
+
125
+ template <> // AVX-512 specialization: dispatches on d to a fixed-dimension kernel
126
+ void fvec_L2sqr_ny_transposed<SIMDLevel::AVX512>(
127
+ float* dis, // output, forwarded unchanged to the selected kernel
128
+ const float* x, // query vector
129
+ const float* y, // transposed database vectors
130
+ const float* y_sqlen, // per-vector values forwarded to the kernel
131
+ size_t d, // runtime dimensionality, used only to pick the kernel
132
+ size_t d_offset, // stride between dimensions in y, forwarded to the kernel
133
+ size_t ny) { // number of y vectors
134
+ // d = 1, 2, 4, 8 use fvec_L2sqr_ny_y_transposed_D<d>; any other d falls back to the SIMDLevel::NONE specialization
135
+ #define DISPATCH(dval) \
136
+ case dval: \
137
+ return fvec_L2sqr_ny_y_transposed_D<dval>( \
138
+ dis, x, y, y_sqlen, d_offset, ny);
139
+
140
+ switch (d) {
141
+ DISPATCH(1)
142
+ DISPATCH(2)
143
+ DISPATCH(4)
144
+ DISPATCH(8)
145
+ default: // no fixed-dimension kernel for this d
146
+ return fvec_L2sqr_ny_transposed<SIMDLevel::NONE>(
147
+ dis, x, y, y_sqlen, d, d_offset, ny);
148
+ }
149
+ #undef DISPATCH
150
+ }
151
+
152
+ struct AVX512ElementOpIP : public ElementOpIP { // inner-product element op with 512-bit and 256-bit vector overloads
153
+ using ElementOpIP::op; // keep the base-class op overloads visible next to the ones added below
154
+ static __m512 op(__m512 x, __m512 y) { // lane-wise product of 16 floats
155
+ return _mm512_mul_ps(x, y);
156
+ }
157
+ static __m256 op(__m256 x, __m256 y) { // lane-wise product of 8 floats
158
+ return _mm256_mul_ps(x, y);
159
+ }
160
+ };
161
+
162
+ struct AVX512ElementOpL2 : public ElementOpL2 { // squared-difference element op with 512-bit and 256-bit vector overloads
163
+ using ElementOpL2::op; // keep the base-class op overloads visible next to the ones added below
164
+ static __m512 op(__m512 x, __m512 y) { // lane-wise (x - y)^2 on 16 floats
165
+ __m512 tmp = _mm512_sub_ps(x, y);
166
+ return _mm512_mul_ps(tmp, tmp);
167
+ }
168
+ static __m256 op(__m256 x, __m256 y) { // lane-wise (x - y)^2 on 8 floats
169
+ __m256 tmp = _mm256_sub_ps(x, y);
170
+ return _mm256_mul_ps(tmp, tmp);
171
+ }
172
+ };
173
+
174
+ /// helper function for AVX512: returns the sum of the 16 float lanes of v
175
+ inline float horizontal_sum(const __m512 v) {
176
+ // performs better than adding the high and low parts
177
+ return _mm512_reduce_add_ps(v);
178
+ }
179
+
180
+ inline float horizontal_sum(const __m256 v) { // returns the sum of the 8 float lanes of v
181
+ // add high and low parts
182
+ const __m128 v0 =
183
+ _mm_add_ps(_mm256_castps256_ps128(v), _mm256_extractf128_ps(v, 1));
184
+ // perform horizontal sum on v0 using the __m128 overload, which is not
+ // defined in this part of the file (presumably distances_sse-inl.h - confirm)
185
+ return horizontal_sum(v0);
186
+ }
187
+
188
+ template <> // inner products of a 2-D query against ny 2-D vectors: dis[i] = x[0] * y[2i] + x[1] * y[2i + 1]
189
+ void fvec_op_ny_D2<AVX512ElementOpIP>(
190
+ float* dis, // output, ny floats
191
+ const float* x, // query, 2 floats are read
192
+ const float* y, // ny vectors stored back to back as (component 0, component 1) pairs
193
+ size_t ny) { // number of y vectors
194
+ const size_t ny16 = ny / 16; // number of full batches of 16 vectors
195
+ size_t i = 0;
196
+
197
+ if (ny16 > 0) {
198
+ // process 16 D2-vectors per loop.
199
+ _mm_prefetch((const char*)y, _MM_HINT_T0);
200
+ _mm_prefetch((const char*)(y + 32), _MM_HINT_T0);
201
+
202
+ const __m512 m0 = _mm512_set1_ps(x[0]); // x[0] broadcast to all 16 lanes
203
+ const __m512 m1 = _mm512_set1_ps(x[1]); // x[1] broadcast to all 16 lanes
204
+
205
+ for (i = 0; i < ny16 * 16; i += 16) {
206
+ _mm_prefetch((const char*)(y + 64), _MM_HINT_T0); // hint for data 64 floats ahead of the current batch
207
+
208
+ // load 16x2 matrix and transpose it in registers.
209
+ // the typical bottleneck is memory access, so
210
+ // let's trade instructions for the bandwidth.
211
+
212
+ __m512 v0;
213
+ __m512 v1;
214
+
215
+ transpose_16x2( // helper not visible here; presumably v0/v1 receive components 0/1 of the 16 vectors - confirm
216
+ _mm512_loadu_ps(y + 0 * 16),
217
+ _mm512_loadu_ps(y + 1 * 16),
218
+ v0,
219
+ v1);
220
+
221
+ // compute distances (dot product): m0 * v0 + m1 * v1
222
+ __m512 distances = _mm512_mul_ps(m0, v0);
223
+ distances = _mm512_fmadd_ps(m1, v1, distances);
224
+
225
+ // store
226
+ _mm512_storeu_ps(dis + i, distances);
227
+
228
+ y += 32; // move to the next set of 16x2 elements
229
+ }
230
+ }
231
+
232
+ if (i < ny) {
233
+ // process leftovers (fewer than 16 vectors) one at a time in scalar code
234
+ float x0 = x[0];
235
+ float x1 = x[1];
236
+
237
+ for (; i < ny; i++) {
238
+ float distance = x0 * y[0] + x1 * y[1];
239
+ y += 2;
240
+ dis[i] = distance;
241
+ }
242
+ }
243
+ }
244
+
245
+ template <> // squared L2 distances of a 2-D query to ny 2-D vectors: dis[i] = (x[0] - y[2i])^2 + (x[1] - y[2i + 1])^2
246
+ void fvec_op_ny_D2<AVX512ElementOpL2>(
247
+ float* dis, // output, ny floats
248
+ const float* x, // query, 2 floats are read
249
+ const float* y, // ny vectors stored back to back as (component 0, component 1) pairs
250
+ size_t ny) { // number of y vectors
251
+ const size_t ny16 = ny / 16; // number of full batches of 16 vectors
252
+ size_t i = 0;
253
+
254
+ if (ny16 > 0) {
255
+ // process 16 D2-vectors per loop.
256
+ _mm_prefetch((const char*)y, _MM_HINT_T0);
257
+ _mm_prefetch((const char*)(y + 32), _MM_HINT_T0);
258
+
259
+ const __m512 m0 = _mm512_set1_ps(x[0]); // x[0] broadcast to all 16 lanes
260
+ const __m512 m1 = _mm512_set1_ps(x[1]); // x[1] broadcast to all 16 lanes
261
+
262
+ for (i = 0; i < ny16 * 16; i += 16) {
263
+ _mm_prefetch((const char*)(y + 64), _MM_HINT_T0); // hint for data 64 floats ahead of the current batch
264
+
265
+ // load 16x2 matrix and transpose it in registers.
266
+ // the typical bottleneck is memory access, so
267
+ // let's trade instructions for the bandwidth.
268
+
269
+ __m512 v0;
270
+ __m512 v1;
271
+
272
+ transpose_16x2( // helper not visible here; presumably v0/v1 receive components 0/1 of the 16 vectors - confirm
273
+ _mm512_loadu_ps(y + 0 * 16),
274
+ _mm512_loadu_ps(y + 1 * 16),
275
+ v0,
276
+ v1);
277
+
278
+ // compute differences
279
+ const __m512 d0 = _mm512_sub_ps(m0, v0);
280
+ const __m512 d1 = _mm512_sub_ps(m1, v1);
281
+
282
+ // compute squares of differences and sum them: d0 * d0 + d1 * d1
283
+ __m512 distances = _mm512_mul_ps(d0, d0);
284
+ distances = _mm512_fmadd_ps(d1, d1, distances);
285
+
286
+ // store
287
+ _mm512_storeu_ps(dis + i, distances);
288
+
289
+ y += 32; // move to the next set of 16x2 elements
290
+ }
291
+ }
292
+
293
+ if (i < ny) {
294
+ // process leftovers (fewer than 16 vectors) one at a time in scalar code
295
+ float x0 = x[0];
296
+ float x1 = x[1];
297
+
298
+ for (; i < ny; i++) {
299
+ float sub0 = x0 - y[0];
300
+ float sub1 = x1 - y[1];
301
+ float distance = sub0 * sub0 + sub1 * sub1;
302
+
303
+ y += 2;
304
+ dis[i] = distance;
305
+ }
306
+ }
307
+ }
308
+ // Inner products of the 4-component query x with ny 4-component vectors packed back to back in y; result i is written to dis[i].
309
+ template <>
310
+ void fvec_op_ny_D4<AVX512ElementOpIP>(
311
+ float* dis, // out: ny values, one per y vector
312
+ const float* x, // query; x[0..3] are read
313
+ const float* y, // ny * 4 floats, vector after vector
314
+ size_t ny) {
315
+ const size_t ny16 = ny / 16; // number of full 16-vector batches
316
+ size_t i = 0;
317
+
318
+ if (ny16 > 0) {
319
+ // process 16 D4-vectors (64 floats) per loop.
320
+ const __m512 m0 = _mm512_set1_ps(x[0]);
321
+ const __m512 m1 = _mm512_set1_ps(x[1]);
322
+ const __m512 m2 = _mm512_set1_ps(x[2]);
323
+ const __m512 m3 = _mm512_set1_ps(x[3]);
324
+
325
+ for (i = 0; i < ny16 * 16; i += 16) {
326
+ // load 16x4 matrix and transpose it in registers.
327
+ // the typical bottleneck is memory access, so
328
+ // let's trade instructions for the bandwidth.
329
+ // NOTE(review): presumably vK receives component K of the 16 vectors, in order; confirm against transpose_16x4.
330
+ __m512 v0;
331
+ __m512 v1;
332
+ __m512 v2;
333
+ __m512 v3;
334
+
335
+ transpose_16x4(
336
+ _mm512_loadu_ps(y + 0 * 16),
337
+ _mm512_loadu_ps(y + 1 * 16),
338
+ _mm512_loadu_ps(y + 2 * 16),
339
+ _mm512_loadu_ps(y + 3 * 16),
340
+ v0,
341
+ v1,
342
+ v2,
343
+ v3);
344
+
345
+ // compute distances (dot product): sum over K of mK * vK
346
+ __m512 distances = _mm512_mul_ps(m0, v0);
347
+ distances = _mm512_fmadd_ps(m1, v1, distances);
348
+ distances = _mm512_fmadd_ps(m2, v2, distances);
349
+ distances = _mm512_fmadd_ps(m3, v3, distances);
350
+
351
+ // store 16 results at dis[i .. i + 15]
352
+ _mm512_storeu_ps(dis + i, distances);
353
+
354
+ y += 64; // move to the next set of 16x4 elements
355
+ }
356
+ }
357
+
358
+ if (i < ny) {
359
+ // process leftovers (fewer than 16 vectors), one 128-bit register per vector
360
+ __m128 x0 = _mm_loadu_ps(x);
361
+
362
+ for (; i < ny; i++) {
363
+ __m128 accu = AVX512ElementOpIP::op(x0, _mm_loadu_ps(y)); // per-component op, reduced by horizontal_sum below
364
+ y += 4;
365
+ dis[i] = horizontal_sum(accu);
366
+ }
367
+ }
368
+ }
369
+ // Squared L2 distances between the 4-component query x and ny 4-component vectors packed back to back in y; result i is written to dis[i].
370
+ template <>
371
+ void fvec_op_ny_D4<AVX512ElementOpL2>(
372
+ float* dis, // out: ny values, one per y vector
373
+ const float* x, // query; x[0..3] are read
374
+ const float* y, // ny * 4 floats, vector after vector
375
+ size_t ny) {
376
+ const size_t ny16 = ny / 16; // number of full 16-vector batches
377
+ size_t i = 0;
378
+
379
+ if (ny16 > 0) {
380
+ // process 16 D4-vectors (64 floats) per loop.
381
+ const __m512 m0 = _mm512_set1_ps(x[0]);
382
+ const __m512 m1 = _mm512_set1_ps(x[1]);
383
+ const __m512 m2 = _mm512_set1_ps(x[2]);
384
+ const __m512 m3 = _mm512_set1_ps(x[3]);
385
+
386
+ for (i = 0; i < ny16 * 16; i += 16) {
387
+ // load 16x4 matrix and transpose it in registers.
388
+ // the typical bottleneck is memory access, so
389
+ // let's trade instructions for the bandwidth.
390
+ // NOTE(review): presumably vK receives component K of the 16 vectors, in order; confirm against transpose_16x4.
391
+ __m512 v0;
392
+ __m512 v1;
393
+ __m512 v2;
394
+ __m512 v3;
395
+
396
+ transpose_16x4(
397
+ _mm512_loadu_ps(y + 0 * 16),
398
+ _mm512_loadu_ps(y + 1 * 16),
399
+ _mm512_loadu_ps(y + 2 * 16),
400
+ _mm512_loadu_ps(y + 3 * 16),
401
+ v0,
402
+ v1,
403
+ v2,
404
+ v3);
405
+
406
+ // compute differences, one register per component
407
+ const __m512 d0 = _mm512_sub_ps(m0, v0);
408
+ const __m512 d1 = _mm512_sub_ps(m1, v1);
409
+ const __m512 d2 = _mm512_sub_ps(m2, v2);
410
+ const __m512 d3 = _mm512_sub_ps(m3, v3);
411
+
412
+ // compute squares of differences and accumulate them
413
+ __m512 distances = _mm512_mul_ps(d0, d0);
414
+ distances = _mm512_fmadd_ps(d1, d1, distances);
415
+ distances = _mm512_fmadd_ps(d2, d2, distances);
416
+ distances = _mm512_fmadd_ps(d3, d3, distances);
417
+
418
+ // store 16 results at dis[i .. i + 15]
419
+ _mm512_storeu_ps(dis + i, distances);
420
+
421
+ y += 64; // move to the next set of 16x4 elements
422
+ }
423
+ }
424
+
425
+ if (i < ny) {
426
+ // process leftovers (fewer than 16 vectors), one 128-bit register per vector
427
+ __m128 x0 = _mm_loadu_ps(x);
428
+
429
+ for (; i < ny; i++) {
430
+ __m128 accu = AVX512ElementOpL2::op(x0, _mm_loadu_ps(y)); // per-component op, reduced by horizontal_sum below
431
+ y += 4;
432
+ dis[i] = horizontal_sum(accu);
433
+ }
434
+ }
435
+ }
436
+ // Inner products of the 8-component query x with ny 8-component vectors packed back to back in y; result i is written to dis[i].
437
+ template <>
438
+ void fvec_op_ny_D8<AVX512ElementOpIP>(
439
+ float* dis, // out: ny values, one per y vector
440
+ const float* x, // query; x[0..7] are read
441
+ const float* y, // ny * 8 floats, vector after vector
442
+ size_t ny) {
443
+ const size_t ny16 = ny / 16; // number of full 16-vector batches
444
+ size_t i = 0;
445
+
446
+ if (ny16 > 0) {
447
+ // process 16 D8-vectors (128 floats) per loop.
448
+ const __m512 m0 = _mm512_set1_ps(x[0]);
449
+ const __m512 m1 = _mm512_set1_ps(x[1]);
450
+ const __m512 m2 = _mm512_set1_ps(x[2]);
451
+ const __m512 m3 = _mm512_set1_ps(x[3]);
452
+ const __m512 m4 = _mm512_set1_ps(x[4]);
453
+ const __m512 m5 = _mm512_set1_ps(x[5]);
454
+ const __m512 m6 = _mm512_set1_ps(x[6]);
455
+ const __m512 m7 = _mm512_set1_ps(x[7]);
456
+
457
+ for (i = 0; i < ny16 * 16; i += 16) {
458
+ // load 16x8 matrix and transpose it in registers.
459
+ // the typical bottleneck is memory access, so
460
+ // let's trade instructions for the bandwidth.
461
+ // NOTE(review): presumably vK receives component K of the 16 vectors, in order; confirm against transpose_16x8.
462
+ __m512 v0;
463
+ __m512 v1;
464
+ __m512 v2;
465
+ __m512 v3;
466
+ __m512 v4;
467
+ __m512 v5;
468
+ __m512 v6;
469
+ __m512 v7;
470
+
471
+ transpose_16x8(
472
+ _mm512_loadu_ps(y + 0 * 16),
473
+ _mm512_loadu_ps(y + 1 * 16),
474
+ _mm512_loadu_ps(y + 2 * 16),
475
+ _mm512_loadu_ps(y + 3 * 16),
476
+ _mm512_loadu_ps(y + 4 * 16),
477
+ _mm512_loadu_ps(y + 5 * 16),
478
+ _mm512_loadu_ps(y + 6 * 16),
479
+ _mm512_loadu_ps(y + 7 * 16),
480
+ v0,
481
+ v1,
482
+ v2,
483
+ v3,
484
+ v4,
485
+ v5,
486
+ v6,
487
+ v7);
488
+
489
+ // compute distances (dot product): sum over K of mK * vK
490
+ __m512 distances = _mm512_mul_ps(m0, v0);
491
+ distances = _mm512_fmadd_ps(m1, v1, distances);
492
+ distances = _mm512_fmadd_ps(m2, v2, distances);
493
+ distances = _mm512_fmadd_ps(m3, v3, distances);
494
+ distances = _mm512_fmadd_ps(m4, v4, distances);
495
+ distances = _mm512_fmadd_ps(m5, v5, distances);
496
+ distances = _mm512_fmadd_ps(m6, v6, distances);
497
+ distances = _mm512_fmadd_ps(m7, v7, distances);
498
+
499
+ // store 16 results at dis[i .. i + 15]
500
+ _mm512_storeu_ps(dis + i, distances);
501
+
502
+ y += 128; // 16 floats * 8 rows
503
+ }
504
+ }
505
+
506
+ if (i < ny) {
507
+ // process leftovers (fewer than 16 vectors), one 256-bit register per vector
508
+ __m256 x0 = _mm256_loadu_ps(x);
509
+
510
+ for (; i < ny; i++) {
511
+ __m256 accu = AVX512ElementOpIP::op(x0, _mm256_loadu_ps(y)); // per-component op, reduced by horizontal_sum below
512
+ y += 8;
513
+ dis[i] = horizontal_sum(accu);
514
+ }
515
+ }
516
+ }
517
+ // Squared L2 distances between the 8-component query x and ny 8-component vectors packed back to back in y; result i is written to dis[i].
518
+ template <>
519
+ void fvec_op_ny_D8<AVX512ElementOpL2>(
520
+ float* dis, // out: ny values, one per y vector
521
+ const float* x, // query; x[0..7] are read
522
+ const float* y, // ny * 8 floats, vector after vector
523
+ size_t ny) {
524
+ const size_t ny16 = ny / 16; // number of full 16-vector batches
525
+ size_t i = 0;
526
+
527
+ if (ny16 > 0) {
528
+ // process 16 D8-vectors (128 floats) per loop.
529
+ const __m512 m0 = _mm512_set1_ps(x[0]);
530
+ const __m512 m1 = _mm512_set1_ps(x[1]);
531
+ const __m512 m2 = _mm512_set1_ps(x[2]);
532
+ const __m512 m3 = _mm512_set1_ps(x[3]);
533
+ const __m512 m4 = _mm512_set1_ps(x[4]);
534
+ const __m512 m5 = _mm512_set1_ps(x[5]);
535
+ const __m512 m6 = _mm512_set1_ps(x[6]);
536
+ const __m512 m7 = _mm512_set1_ps(x[7]);
537
+
538
+ for (i = 0; i < ny16 * 16; i += 16) {
539
+ // load 16x8 matrix and transpose it in registers.
540
+ // the typical bottleneck is memory access, so
541
+ // let's trade instructions for the bandwidth.
542
+ // NOTE(review): presumably vK receives component K of the 16 vectors, in order; confirm against transpose_16x8.
543
+ __m512 v0;
544
+ __m512 v1;
545
+ __m512 v2;
546
+ __m512 v3;
547
+ __m512 v4;
548
+ __m512 v5;
549
+ __m512 v6;
550
+ __m512 v7;
551
+
552
+ transpose_16x8(
553
+ _mm512_loadu_ps(y + 0 * 16),
554
+ _mm512_loadu_ps(y + 1 * 16),
555
+ _mm512_loadu_ps(y + 2 * 16),
556
+ _mm512_loadu_ps(y + 3 * 16),
557
+ _mm512_loadu_ps(y + 4 * 16),
558
+ _mm512_loadu_ps(y + 5 * 16),
559
+ _mm512_loadu_ps(y + 6 * 16),
560
+ _mm512_loadu_ps(y + 7 * 16),
561
+ v0,
562
+ v1,
563
+ v2,
564
+ v3,
565
+ v4,
566
+ v5,
567
+ v6,
568
+ v7);
569
+
570
+ // compute differences, one register per component
571
+ const __m512 d0 = _mm512_sub_ps(m0, v0);
572
+ const __m512 d1 = _mm512_sub_ps(m1, v1);
573
+ const __m512 d2 = _mm512_sub_ps(m2, v2);
574
+ const __m512 d3 = _mm512_sub_ps(m3, v3);
575
+ const __m512 d4 = _mm512_sub_ps(m4, v4);
576
+ const __m512 d5 = _mm512_sub_ps(m5, v5);
577
+ const __m512 d6 = _mm512_sub_ps(m6, v6);
578
+ const __m512 d7 = _mm512_sub_ps(m7, v7);
579
+
580
+ // compute squares of differences and accumulate them
581
+ __m512 distances = _mm512_mul_ps(d0, d0);
582
+ distances = _mm512_fmadd_ps(d1, d1, distances);
583
+ distances = _mm512_fmadd_ps(d2, d2, distances);
584
+ distances = _mm512_fmadd_ps(d3, d3, distances);
585
+ distances = _mm512_fmadd_ps(d4, d4, distances);
586
+ distances = _mm512_fmadd_ps(d5, d5, distances);
587
+ distances = _mm512_fmadd_ps(d6, d6, distances);
588
+ distances = _mm512_fmadd_ps(d7, d7, distances);
589
+
590
+ // store 16 results at dis[i .. i + 15]
591
+ _mm512_storeu_ps(dis + i, distances);
592
+
593
+ y += 128; // 16 floats * 8 rows
594
+ }
595
+ }
596
+
597
+ if (i < ny) {
598
+ // process leftovers (fewer than 16 vectors), one 256-bit register per vector
599
+ __m256 x0 = _mm256_loadu_ps(x);
600
+
601
+ for (; i < ny; i++) {
602
+ __m256 accu = AVX512ElementOpL2::op(x0, _mm256_loadu_ps(y)); // per-component op, reduced by horizontal_sum below
603
+ y += 8;
604
+ dis[i] = horizontal_sum(accu);
605
+ }
606
+ }
607
+ }
608
+ // AVX512 entry point for batched inner products: forwards all arguments to fvec_inner_products_ny_ref instantiated with AVX512ElementOpIP.
609
+ template <>
610
+ void fvec_inner_products_ny<SIMDLevel::AVX512>(
611
+ float* ip, /* output inner product */
612
+ const float* x,
613
+ const float* y,
614
+ size_t d,
615
+ size_t ny) {
616
+ fvec_inner_products_ny_ref<AVX512ElementOpIP>(ip, x, y, d, ny); // NOTE(review): presumably dispatches on d to the fvec_op_ny_D* kernels; confirm in the helper
617
+ }
618
+ // AVX512 entry point for batched squared L2 distances: forwards all arguments to fvec_L2sqr_ny_ref instantiated with AVX512ElementOpL2.
619
+ template <>
620
+ void fvec_L2sqr_ny<SIMDLevel::AVX512>(
621
+ float* dis, /* output distances */
622
+ const float* x,
623
+ const float* y,
624
+ size_t d,
625
+ size_t ny) {
626
+ fvec_L2sqr_ny_ref<AVX512ElementOpL2>(dis, x, y, d, ny); // NOTE(review): presumably dispatches on d to the fvec_op_ny_D* kernels; confirm in the helper
627
+ }
628
+ // Returns the index of the 2-component vector in y (ny vectors, packed back to back) with the smallest squared L2 distance to x; returns 0 when ny == 0.
629
+ template <>
630
+ size_t fvec_L2sqr_ny_nearest_D2<SIMDLevel::AVX512>(
631
+ float* distances_tmp_buffer, // unused, see below
632
+ const float* x, // query; x[0] and x[1] are read
633
+ const float* y, // ny * 2 floats, vector after vector
634
+ size_t ny) {
635
+ // this implementation does not use distances_tmp_buffer.
636
+
637
+ size_t i = 0;
638
+ float current_min_distance = HUGE_VALF;
639
+ size_t current_min_index = 0;
640
+
641
+ const size_t ny16 = ny / 16; // number of full 16-vector batches
642
+ if (ny16 > 0) {
643
+ _mm_prefetch((const char*)y, _MM_HINT_T0);
644
+ _mm_prefetch((const char*)(y + 32), _MM_HINT_T0);
645
+ // per-lane running minimum and the index recorded for it (indices are kept as 32-bit lanes)
646
+ __m512 min_distances = _mm512_set1_ps(HUGE_VALF);
647
+ __m512i min_indices = _mm512_set1_epi32(0);
648
+
649
+ __m512i current_indices = _mm512_setr_epi32(
650
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
651
+ const __m512i indices_increment = _mm512_set1_epi32(16);
652
+
653
+ const __m512 m0 = _mm512_set1_ps(x[0]);
654
+ const __m512 m1 = _mm512_set1_ps(x[1]);
655
+
656
+ for (; i < ny16 * 16; i += 16) {
657
+ _mm_prefetch((const char*)(y + 64), _MM_HINT_T0);
658
+
659
+ __m512 v0;
660
+ __m512 v1;
661
+
662
+ transpose_16x2(
663
+ _mm512_loadu_ps(y + 0 * 16),
664
+ _mm512_loadu_ps(y + 1 * 16),
665
+ v0,
666
+ v1);
667
+
668
+ const __m512 d0 = _mm512_sub_ps(m0, v0);
669
+ const __m512 d1 = _mm512_sub_ps(m1, v1);
670
+
671
+ __m512 distances = _mm512_mul_ps(d0, d0);
672
+ distances = _mm512_fmadd_ps(d1, d1, distances);
673
+ // strict less-than: a lane takes the new index only when the new distance beats its current minimum
674
+ __mmask16 comparison =
675
+ _mm512_cmp_ps_mask(distances, min_distances, _CMP_LT_OS);
676
+
677
+ min_distances = _mm512_min_ps(distances, min_distances);
678
+ min_indices = _mm512_mask_blend_epi32(
679
+ comparison, min_indices, current_indices);
680
+
681
+ current_indices =
682
+ _mm512_add_epi32(current_indices, indices_increment);
683
+
684
+ y += 32; // 16 vectors * 2 floats
685
+ }
686
+ // spill the 16 lanes to aligned scalar arrays and reduce them to a single best candidate
687
+ alignas(64) float min_distances_scalar[16];
688
+ alignas(64) uint32_t min_indices_scalar[16];
689
+ _mm512_store_ps(min_distances_scalar, min_distances);
690
+ _mm512_store_epi32(min_indices_scalar, min_indices);
691
+
692
+ for (size_t j = 0; j < 16; j++) {
693
+ if (current_min_distance > min_distances_scalar[j]) {
694
+ current_min_distance = min_distances_scalar[j];
695
+ current_min_index = min_indices_scalar[j];
696
+ }
697
+ }
698
+ }
699
+ // leftovers (fewer than 16 vectors): scalar scan that continues from the best candidate found so far
700
+ if (i < ny) {
701
+ float x0 = x[0];
702
+ float x1 = x[1];
703
+
704
+ for (; i < ny; i++) {
705
+ float sub0 = x0 - y[0];
706
+ float sub1 = x1 - y[1];
707
+ float distance = sub0 * sub0 + sub1 * sub1;
708
+
709
+ y += 2;
710
+
711
+ if (current_min_distance > distance) {
712
+ current_min_distance = distance;
713
+ current_min_index = i;
714
+ }
715
+ }
716
+ }
717
+
718
+ return current_min_index;
719
+ }
720
+ // Returns the index of the 4-component vector in y (ny vectors, packed back to back) with the smallest squared L2 distance to x; returns 0 when ny == 0.
721
+ template <>
722
+ size_t fvec_L2sqr_ny_nearest_D4<SIMDLevel::AVX512>(
723
+ float* distances_tmp_buffer, // unused, see below
724
+ const float* x, // query; x[0..3] are read
725
+ const float* y, // ny * 4 floats, vector after vector
726
+ size_t ny) {
727
+ // this implementation does not use distances_tmp_buffer.
728
+
729
+ size_t i = 0;
730
+ float current_min_distance = HUGE_VALF;
731
+ size_t current_min_index = 0;
732
+
733
+ const size_t ny16 = ny / 16; // number of full 16-vector batches
734
+
735
+ if (ny16 > 0) {
736
+ __m512 min_distances = _mm512_set1_ps(HUGE_VALF); // per-lane running minimum
737
+ __m512i min_indices = _mm512_set1_epi32(0); // index recorded for each lane's minimum (32-bit lanes)
738
+
739
+ __m512i current_indices = _mm512_setr_epi32(
740
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
741
+ const __m512i indices_increment = _mm512_set1_epi32(16);
742
+
743
+ const __m512 m0 = _mm512_set1_ps(x[0]);
744
+ const __m512 m1 = _mm512_set1_ps(x[1]);
745
+ const __m512 m2 = _mm512_set1_ps(x[2]);
746
+ const __m512 m3 = _mm512_set1_ps(x[3]);
747
+
748
+ for (; i < ny16 * 16; i += 16) {
749
+ __m512 v0;
750
+ __m512 v1;
751
+ __m512 v2;
752
+ __m512 v3;
753
+
754
+ transpose_16x4(
755
+ _mm512_loadu_ps(y + 0 * 16),
756
+ _mm512_loadu_ps(y + 1 * 16),
757
+ _mm512_loadu_ps(y + 2 * 16),
758
+ _mm512_loadu_ps(y + 3 * 16),
759
+ v0,
760
+ v1,
761
+ v2,
762
+ v3);
763
+
764
+ const __m512 d0 = _mm512_sub_ps(m0, v0);
765
+ const __m512 d1 = _mm512_sub_ps(m1, v1);
766
+ const __m512 d2 = _mm512_sub_ps(m2, v2);
767
+ const __m512 d3 = _mm512_sub_ps(m3, v3);
768
+
769
+ __m512 distances = _mm512_mul_ps(d0, d0);
770
+ distances = _mm512_fmadd_ps(d1, d1, distances);
771
+ distances = _mm512_fmadd_ps(d2, d2, distances);
772
+ distances = _mm512_fmadd_ps(d3, d3, distances);
773
+ // strict less-than: a lane takes the new index only when the new distance beats its current minimum
774
+ __mmask16 comparison =
775
+ _mm512_cmp_ps_mask(distances, min_distances, _CMP_LT_OS);
776
+
777
+ min_distances = _mm512_min_ps(distances, min_distances);
778
+ min_indices = _mm512_mask_blend_epi32(
779
+ comparison, min_indices, current_indices);
780
+
781
+ current_indices =
782
+ _mm512_add_epi32(current_indices, indices_increment);
783
+
784
+ y += 64; // 16 vectors * 4 floats
785
+ }
786
+ // spill the 16 lanes to aligned scalar arrays and reduce them to a single best candidate
787
+ alignas(64) float min_distances_scalar[16];
788
+ alignas(64) uint32_t min_indices_scalar[16];
789
+ _mm512_store_ps(min_distances_scalar, min_distances);
790
+ _mm512_store_epi32(min_indices_scalar, min_indices);
791
+
792
+ for (size_t j = 0; j < 16; j++) {
793
+ if (current_min_distance > min_distances_scalar[j]) {
794
+ current_min_distance = min_distances_scalar[j];
795
+ current_min_index = min_indices_scalar[j];
796
+ }
797
+ }
798
+ }
799
+ // leftovers (fewer than 16 vectors): one 128-bit register per vector, continuing from the best candidate so far
800
+ if (i < ny) {
801
+ __m128 x0 = _mm_loadu_ps(x);
802
+
803
+ for (; i < ny; i++) {
804
+ __m128 accu = AVX512ElementOpL2::op(x0, _mm_loadu_ps(y)); // same functor as the other AVX512 kernels in this file
805
+ y += 4;
806
+ const float distance = horizontal_sum(accu);
807
+
808
+ if (current_min_distance > distance) {
809
+ current_min_distance = distance;
810
+ current_min_index = i;
811
+ }
812
+ }
813
+ }
814
+
815
+ return current_min_index;
816
+ }
817
+ // Returns the index of the 8-component vector in y (ny vectors, packed back to back) with the smallest squared L2 distance to x; returns 0 when ny == 0.
818
+ template <>
819
+ size_t fvec_L2sqr_ny_nearest_D8<SIMDLevel::AVX512>(
820
+ float* distances_tmp_buffer, // unused, see below
821
+ const float* x, // query; x[0..7] are read
822
+ const float* y, // ny * 8 floats, vector after vector
823
+ size_t ny) {
824
+ // this implementation does not use distances_tmp_buffer.
825
+
826
+ size_t i = 0;
827
+ float current_min_distance = HUGE_VALF;
828
+ size_t current_min_index = 0;
829
+
830
+ const size_t ny16 = ny / 16; // number of full 16-vector batches
831
+ if (ny16 > 0) {
832
+ __m512 min_distances = _mm512_set1_ps(HUGE_VALF); // per-lane running minimum
833
+ __m512i min_indices = _mm512_set1_epi32(0); // index recorded for each lane's minimum (32-bit lanes)
834
+
835
+ __m512i current_indices = _mm512_setr_epi32(
836
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
837
+ const __m512i indices_increment = _mm512_set1_epi32(16);
838
+
839
+ const __m512 m0 = _mm512_set1_ps(x[0]);
840
+ const __m512 m1 = _mm512_set1_ps(x[1]);
841
+ const __m512 m2 = _mm512_set1_ps(x[2]);
842
+ const __m512 m3 = _mm512_set1_ps(x[3]);
843
+
844
+ const __m512 m4 = _mm512_set1_ps(x[4]);
845
+ const __m512 m5 = _mm512_set1_ps(x[5]);
846
+ const __m512 m6 = _mm512_set1_ps(x[6]);
847
+ const __m512 m7 = _mm512_set1_ps(x[7]);
848
+
849
+ for (; i < ny16 * 16; i += 16) {
850
+ __m512 v0;
851
+ __m512 v1;
852
+ __m512 v2;
853
+ __m512 v3;
854
+ __m512 v4;
855
+ __m512 v5;
856
+ __m512 v6;
857
+ __m512 v7;
858
+
859
+ transpose_16x8(
860
+ _mm512_loadu_ps(y + 0 * 16),
861
+ _mm512_loadu_ps(y + 1 * 16),
862
+ _mm512_loadu_ps(y + 2 * 16),
863
+ _mm512_loadu_ps(y + 3 * 16),
864
+ _mm512_loadu_ps(y + 4 * 16),
865
+ _mm512_loadu_ps(y + 5 * 16),
866
+ _mm512_loadu_ps(y + 6 * 16),
867
+ _mm512_loadu_ps(y + 7 * 16),
868
+ v0,
869
+ v1,
870
+ v2,
871
+ v3,
872
+ v4,
873
+ v5,
874
+ v6,
875
+ v7);
876
+
877
+ const __m512 d0 = _mm512_sub_ps(m0, v0);
878
+ const __m512 d1 = _mm512_sub_ps(m1, v1);
879
+ const __m512 d2 = _mm512_sub_ps(m2, v2);
880
+ const __m512 d3 = _mm512_sub_ps(m3, v3);
881
+ const __m512 d4 = _mm512_sub_ps(m4, v4);
882
+ const __m512 d5 = _mm512_sub_ps(m5, v5);
883
+ const __m512 d6 = _mm512_sub_ps(m6, v6);
884
+ const __m512 d7 = _mm512_sub_ps(m7, v7);
885
+
886
+ __m512 distances = _mm512_mul_ps(d0, d0);
887
+ distances = _mm512_fmadd_ps(d1, d1, distances);
888
+ distances = _mm512_fmadd_ps(d2, d2, distances);
889
+ distances = _mm512_fmadd_ps(d3, d3, distances);
890
+ distances = _mm512_fmadd_ps(d4, d4, distances);
891
+ distances = _mm512_fmadd_ps(d5, d5, distances);
892
+ distances = _mm512_fmadd_ps(d6, d6, distances);
893
+ distances = _mm512_fmadd_ps(d7, d7, distances);
894
+ // strict less-than: a lane takes the new index only when the new distance beats its current minimum
895
+ __mmask16 comparison =
896
+ _mm512_cmp_ps_mask(distances, min_distances, _CMP_LT_OS);
897
+
898
+ min_distances = _mm512_min_ps(distances, min_distances);
899
+ min_indices = _mm512_mask_blend_epi32(
900
+ comparison, min_indices, current_indices);
901
+
902
+ current_indices =
903
+ _mm512_add_epi32(current_indices, indices_increment);
904
+
905
+ y += 128; // 16 vectors * 8 floats
906
+ }
907
+ // spill the 16 lanes to aligned scalar arrays and reduce them to a single best candidate
908
+ alignas(64) float min_distances_scalar[16];
909
+ alignas(64) uint32_t min_indices_scalar[16];
910
+ _mm512_store_ps(min_distances_scalar, min_distances);
911
+ _mm512_store_epi32(min_indices_scalar, min_indices);
912
+
913
+ for (size_t j = 0; j < 16; j++) {
914
+ if (current_min_distance > min_distances_scalar[j]) {
915
+ current_min_distance = min_distances_scalar[j];
916
+ current_min_index = min_indices_scalar[j];
917
+ }
918
+ }
919
+ }
920
+ // leftovers (fewer than 16 vectors): one 256-bit register per vector, continuing from the best candidate so far
921
+ if (i < ny) {
922
+ __m256 x0 = _mm256_loadu_ps(x);
923
+
924
+ for (; i < ny; i++) {
925
+ __m256 accu = AVX512ElementOpL2::op(x0, _mm256_loadu_ps(y));
926
+ y += 8;
927
+ const float distance = horizontal_sum(accu);
928
+
929
+ if (current_min_distance > distance) {
930
+ current_min_distance = distance;
931
+ current_min_index = i;
932
+ }
933
+ }
934
+ }
935
+
936
+ return current_min_index;
937
+ }
938
+ // AVX512 entry point for nearest-vector search: forwards to fvec_L2sqr_ny_nearest_x86, handing it the AVX512 D2 / D4 / D8 kernels as function pointers.
939
+ template <>
940
+ size_t fvec_L2sqr_ny_nearest<SIMDLevel::AVX512>(
941
+ float* distances_tmp_buffer,
942
+ const float* x,
943
+ const float* y,
944
+ size_t d,
945
+ size_t ny) {
946
+ return fvec_L2sqr_ny_nearest_x86<SIMDLevel::AVX512>( // NOTE(review): presumably picks a kernel based on d; confirm in the helper
947
+ distances_tmp_buffer,
948
+ x,
949
+ y,
950
+ d,
951
+ ny,
952
+ &fvec_L2sqr_ny_nearest_D2<SIMDLevel::AVX512>,
953
+ &fvec_L2sqr_ny_nearest_D4<SIMDLevel::AVX512>,
954
+ &fvec_L2sqr_ny_nearest_D8<SIMDLevel::AVX512>);
955
+ }
956
+ // AVX512 entry point for the transposed-y nearest search: no AVX512-specific code is used here, the call is forwarded unchanged to the SIMDLevel::NONE specialization.
957
+ template <>
958
+ size_t fvec_L2sqr_ny_nearest_y_transposed<SIMDLevel::AVX512>(
959
+ float* distances_tmp_buffer,
960
+ const float* x,
961
+ const float* y,
962
+ const float* y_sqlen,
963
+ size_t d,
964
+ size_t d_offset,
965
+ size_t ny) {
966
+ return fvec_L2sqr_ny_nearest_y_transposed<SIMDLevel::NONE>(
967
+ distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny);
968
+ }
969
+
970
+ // TODO: The following function template is not used in the current codebase. Check the
971
+ // AVX2 code, where the respective implementation has been used.
972
+ template <size_t DIM> // DIM: number of components per vector, fixed at compile time
973
+ size_t fvec_L2sqr_ny_nearest_y_transposed_D( // returns the index in [0, ny) minimizing y_sqlen[i] - 2 * <x, y_i>; 0 when ny == 0
974
+ float* /* distances_tmp_buffer */,
975
+ const float* x, // query; x[0 .. DIM - 1] are read
976
+ const float* y, // transposed storage: component j of vector i is read at y[j * d_offset + i]
977
+ const float* y_sqlen, // one value per vector; the comments below treat it as the squared norm of y_i
978
+ const size_t d_offset, // stride, in floats, between consecutive component rows of y
979
+ size_t ny) {
980
+ // This implementation does not use distances_tmp_buffer.
981
+
982
+ // Current index being processed
983
+ size_t i = 0;
984
+
985
+ // Min distance and the index of the closest vector so far
986
+ float current_min_distance = HUGE_VALF;
987
+ size_t current_min_index = 0;
988
+
989
+ // Process 16 vectors per loop
990
+ const size_t ny16 = ny / 16;
991
+
992
+ if (ny16 > 0) {
993
+ // Track min distance and the closest vector independently
994
+ // for each of 16 AVX-512 components.
995
+ __m512 min_distances = _mm512_set1_ps(HUGE_VALF);
996
+ __m512i min_indices = _mm512_set1_epi32(0);
997
+
998
+ __m512i current_indices = _mm512_setr_epi32(
999
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
1000
+ const __m512i indices_increment = _mm512_set1_epi32(16);
1001
+
1002
+ // m[j] = (2 * x[j], ... 2 * x[j])
1003
+ __m512 m[DIM];
1004
+ for (size_t j = 0; j < DIM; j++) {
1005
+ m[j] = _mm512_set1_ps(x[j]);
1006
+ m[j] = _mm512_add_ps(m[j], m[j]);
1007
+ }
1008
+
1009
+ for (; i < ny16 * 16; i += 16) {
1010
+ // Compute dot products (already scaled by 2 through m[])
1011
+ const __m512 v0 = _mm512_loadu_ps(y + 0 * d_offset);
1012
+ __m512 dp = _mm512_mul_ps(m[0], v0);
1013
+ for (size_t j = 1; j < DIM; j++) {
1014
+ const __m512 vj = _mm512_loadu_ps(y + j * d_offset);
1015
+ dp = _mm512_fmadd_ps(m[j], vj, dp);
1016
+ }
1017
+
1018
+ // Compute y^2 - (2 * x, y), which is sufficient for looking for the
1019
+ // lowest distance.
1020
+ // x^2 is the constant that can be avoided.
1021
+ const __m512 distances =
1022
+ _mm512_sub_ps(_mm512_loadu_ps(y_sqlen), dp);
1023
+
1024
+ // Compare the new distances to the min distances (mask bit set where the old minimum is strictly smaller)
1025
+ __mmask16 comparison =
1026
+ _mm512_cmp_ps_mask(min_distances, distances, _CMP_LT_OS);
1027
+
1028
+ // Update min distances and indices with closest vectors if needed (a lane keeps its old entry only where the mask bit is set)
1029
+ min_distances =
1030
+ _mm512_mask_blend_ps(comparison, distances, min_distances);
1031
+ min_indices = _mm512_castps_si512(_mm512_mask_blend_ps(
1032
+ comparison,
1033
+ _mm512_castsi512_ps(current_indices),
1034
+ _mm512_castsi512_ps(min_indices)));
1035
+
1036
+ // Update current indices values. Basically, +16 to each of the 16
1037
+ // AVX-512 components.
1038
+ current_indices =
1039
+ _mm512_add_epi32(current_indices, indices_increment);
1040
+
1041
+ // Scroll y and y_sqlen forward.
1042
+ y += 16;
1043
+ y_sqlen += 16;
1044
+ }
1045
+
1046
+ // Dump values and find the minimum distance / minimum index
1047
+ float min_distances_scalar[16];
1048
+ uint32_t min_indices_scalar[16];
1049
+ _mm512_storeu_ps(min_distances_scalar, min_distances);
1050
+ _mm512_storeu_si512((__m512i*)(min_indices_scalar), min_indices);
1051
+
1052
+ for (size_t j = 0; j < 16; j++) {
1053
+ if (current_min_distance > min_distances_scalar[j]) {
1054
+ current_min_distance = min_distances_scalar[j];
1055
+ current_min_index = min_indices_scalar[j];
1056
+ }
1057
+ }
1058
+ }
1059
+
1060
+ if (i < ny) {
1061
+ // Process leftovers (fewer than 16 vectors) in scalar code
1062
+ for (; i < ny; i++) {
1063
+ float dp = 0;
1064
+ for (size_t j = 0; j < DIM; j++) {
1065
+ dp += x[j] * y[j * d_offset];
1066
+ }
1067
+
1068
+ // Compute y^2 - 2 * (x, y), which is sufficient for looking for the
1069
+ // lowest distance.
1070
+ const float distance = y_sqlen[0] - 2 * dp;
1071
+
1072
+ if (current_min_distance > distance) {
1073
+ current_min_distance = distance;
1074
+ current_min_index = i;
1075
+ }
1076
+
1077
+ y += 1;
1078
+ y_sqlen += 1;
1079
+ }
1080
+ }
1081
+
1082
+ return current_min_index;
1083
+ }
1084
+ // AVX512 entry point for fvec_madd_and_argmin: no AVX512-specific code is used here, the call is forwarded unchanged to fvec_madd_and_argmin_sse.
1085
+ template <>
1086
+ int fvec_madd_and_argmin<SIMDLevel::AVX512>(
1087
+ size_t n,
1088
+ const float* a,
1089
+ float bf,
1090
+ const float* b,
1091
+ float* c) {
1092
+ return fvec_madd_and_argmin_sse(n, a, bf, b, c);
1093
+ }
1094
+
1095
+ } // namespace faiss