faiss 0.5.3 → 0.6.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (379)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +12 -0
  3. data/ext/faiss/ext.cpp +1 -1
  4. data/ext/faiss/extconf.rb +4 -4
  5. data/ext/faiss/index.cpp +63 -45
  6. data/ext/faiss/index_binary.cpp +37 -27
  7. data/ext/faiss/kmeans.cpp +9 -8
  8. data/ext/faiss/pca_matrix.cpp +9 -7
  9. data/ext/faiss/product_quantizer.cpp +13 -11
  10. data/ext/faiss/utils.cpp +4 -2
  11. data/ext/faiss/utils.h +4 -0
  12. data/lib/faiss/version.rb +1 -1
  13. data/lib/faiss.rb +1 -1
  14. data/vendor/faiss/faiss/AutoTune.cpp +214 -82
  15. data/vendor/faiss/faiss/AutoTune.h +14 -1
  16. data/vendor/faiss/faiss/Clustering.cpp +97 -249
  17. data/vendor/faiss/faiss/Clustering.h +18 -0
  18. data/vendor/faiss/faiss/IVFlib.cpp +67 -44
  19. data/vendor/faiss/faiss/Index.cpp +25 -12
  20. data/vendor/faiss/faiss/Index.h +26 -4
  21. data/vendor/faiss/faiss/Index2Layer.cpp +37 -53
  22. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +68 -61
  23. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +36 -34
  24. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +4 -1
  25. data/vendor/faiss/faiss/IndexBinary.cpp +6 -3
  26. data/vendor/faiss/faiss/IndexBinary.h +4 -4
  27. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +1 -1
  28. data/vendor/faiss/faiss/IndexBinaryFlat.h +1 -1
  29. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -4
  30. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +92 -95
  31. data/vendor/faiss/faiss/IndexBinaryHNSW.h +9 -3
  32. data/vendor/faiss/faiss/IndexBinaryHash.cpp +45 -236
  33. data/vendor/faiss/faiss/IndexBinaryHash.h +6 -6
  34. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +120 -414
  35. data/vendor/faiss/faiss/IndexFastScan.cpp +105 -129
  36. data/vendor/faiss/faiss/IndexFastScan.h +35 -24
  37. data/vendor/faiss/faiss/IndexFlat.cpp +216 -152
  38. data/vendor/faiss/faiss/IndexFlat.h +32 -14
  39. data/vendor/faiss/faiss/IndexFlatCodes.cpp +88 -41
  40. data/vendor/faiss/faiss/IndexFlatCodes.h +7 -1
  41. data/vendor/faiss/faiss/IndexHNSW.cpp +299 -187
  42. data/vendor/faiss/faiss/IndexHNSW.h +30 -14
  43. data/vendor/faiss/faiss/IndexIDMap.cpp +26 -22
  44. data/vendor/faiss/faiss/IndexIDMap.h +9 -7
  45. data/vendor/faiss/faiss/IndexIVF.cpp +535 -405
  46. data/vendor/faiss/faiss/IndexIVF.h +47 -16
  47. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +77 -74
  48. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +105 -99
  49. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +6 -3
  50. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +379 -249
  51. data/vendor/faiss/faiss/IndexIVFFastScan.h +65 -60
  52. data/vendor/faiss/faiss/IndexIVFFlat.cpp +41 -124
  53. data/vendor/faiss/faiss/IndexIVFFlat.h +32 -0
  54. data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +89 -138
  55. data/vendor/faiss/faiss/IndexIVFFlatPanorama.h +3 -1
  56. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +18 -15
  57. data/vendor/faiss/faiss/IndexIVFPQ.cpp +77 -907
  58. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +184 -122
  59. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +3 -0
  60. data/vendor/faiss/faiss/IndexIVFPQR.cpp +23 -18
  61. data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +59 -60
  62. data/vendor/faiss/faiss/IndexIVFRaBitQ.h +4 -3
  63. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +564 -416
  64. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +269 -111
  65. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +41 -127
  66. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +1 -1
  67. data/vendor/faiss/faiss/IndexLSH.cpp +44 -25
  68. data/vendor/faiss/faiss/IndexLattice.cpp +41 -36
  69. data/vendor/faiss/faiss/IndexNNDescent.cpp +37 -21
  70. data/vendor/faiss/faiss/IndexNNDescent.h +2 -2
  71. data/vendor/faiss/faiss/IndexNSG.cpp +40 -23
  72. data/vendor/faiss/faiss/IndexNSG.h +0 -2
  73. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +32 -12
  74. data/vendor/faiss/faiss/IndexPQ.cpp +129 -213
  75. data/vendor/faiss/faiss/IndexPQ.h +3 -2
  76. data/vendor/faiss/faiss/IndexPQFastScan.cpp +20 -14
  77. data/vendor/faiss/faiss/IndexPQFastScan.h +3 -0
  78. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -18
  79. data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
  80. data/vendor/faiss/faiss/IndexRaBitQ.cpp +31 -43
  81. data/vendor/faiss/faiss/IndexRaBitQ.h +4 -3
  82. data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +135 -317
  83. data/vendor/faiss/faiss/IndexRaBitQFastScan.h +192 -34
  84. data/vendor/faiss/faiss/IndexRefine.cpp +30 -55
  85. data/vendor/faiss/faiss/IndexRefine.h +4 -4
  86. data/vendor/faiss/faiss/IndexReplicas.cpp +6 -6
  87. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +15 -14
  88. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +1 -1
  89. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +82 -14
  90. data/vendor/faiss/faiss/IndexShards.cpp +13 -13
  91. data/vendor/faiss/faiss/IndexShardsIVF.cpp +21 -15
  92. data/vendor/faiss/faiss/MatrixStats.cpp +5 -4
  93. data/vendor/faiss/faiss/MetaIndexes.cpp +19 -17
  94. data/vendor/faiss/faiss/MetaIndexes.h +1 -1
  95. data/vendor/faiss/faiss/MetricType.h +29 -6
  96. data/vendor/faiss/faiss/SuperKMeans.cpp +656 -0
  97. data/vendor/faiss/faiss/SuperKMeans.h +97 -0
  98. data/vendor/faiss/faiss/VectorTransform.cpp +349 -141
  99. data/vendor/faiss/faiss/VectorTransform.h +39 -16
  100. data/vendor/faiss/faiss/build.cpp +23 -0
  101. data/vendor/faiss/faiss/build.h +15 -0
  102. data/vendor/faiss/faiss/clone_index.cpp +55 -51
  103. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +47 -47
  104. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +11 -0
  105. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +38 -38
  106. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +11 -0
  107. data/vendor/faiss/faiss/{cppcontrib/factory_tools.cpp → factory_tools.cpp} +6 -1
  108. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +1 -1
  109. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +6 -5
  110. data/vendor/faiss/faiss/gpu/GpuResources.h +1 -1
  111. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +9 -9
  112. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +4 -3
  113. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +46 -0
  114. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +56 -0
  115. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +78 -1
  116. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +72 -0
  117. data/vendor/faiss/faiss/gpu/test/TestUtils.h +23 -0
  118. data/vendor/faiss/faiss/gpu/utils/CuvsFilterConvert.h +1 -1
  119. data/vendor/faiss/faiss/gpu/utils/CuvsUtils.h +21 -10
  120. data/vendor/faiss/faiss/gpu_metal/GpuIndexFlat.h +22 -0
  121. data/vendor/faiss/faiss/gpu_metal/MetalCloner.h +35 -0
  122. data/vendor/faiss/faiss/gpu_metal/MetalFlatKernels.h +40 -0
  123. data/vendor/faiss/faiss/gpu_metal/MetalIndex.h +51 -0
  124. data/vendor/faiss/faiss/gpu_metal/MetalIndexFlat.h +65 -0
  125. data/vendor/faiss/faiss/gpu_metal/MetalKernels.h +66 -0
  126. data/vendor/faiss/faiss/gpu_metal/MetalResources.h +79 -0
  127. data/vendor/faiss/faiss/gpu_metal/StandardMetalResources.h +35 -0
  128. data/vendor/faiss/faiss/impl/AdSampling.cpp +103 -0
  129. data/vendor/faiss/faiss/impl/AdSampling.h +35 -0
  130. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +64 -34
  131. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +1 -0
  132. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +10 -9
  133. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +3 -28
  134. data/vendor/faiss/faiss/impl/ClusteringHelpers.cpp +244 -0
  135. data/vendor/faiss/faiss/impl/ClusteringHelpers.h +94 -0
  136. data/vendor/faiss/faiss/impl/ClusteringInitialization.cpp +367 -0
  137. data/vendor/faiss/faiss/impl/ClusteringInitialization.h +107 -0
  138. data/vendor/faiss/faiss/impl/CodePacker.cpp +7 -3
  139. data/vendor/faiss/faiss/impl/CodePacker.h +11 -3
  140. data/vendor/faiss/faiss/impl/CodePackerRaBitQ.cpp +83 -0
  141. data/vendor/faiss/faiss/impl/CodePackerRaBitQ.h +47 -0
  142. data/vendor/faiss/faiss/impl/DistanceComputer.h +8 -8
  143. data/vendor/faiss/faiss/impl/FaissAssert.h +64 -3
  144. data/vendor/faiss/faiss/impl/FaissException.h +50 -3
  145. data/vendor/faiss/faiss/impl/HNSW.cpp +117 -351
  146. data/vendor/faiss/faiss/impl/HNSW.h +21 -40
  147. data/vendor/faiss/faiss/impl/IDSelector.cpp +15 -11
  148. data/vendor/faiss/faiss/impl/IDSelector.h +8 -8
  149. data/vendor/faiss/faiss/impl/InvertedListScannerStats.h +26 -0
  150. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +114 -102
  151. data/vendor/faiss/faiss/impl/NNDescent.cpp +63 -26
  152. data/vendor/faiss/faiss/impl/NNDescent.h +6 -2
  153. data/vendor/faiss/faiss/impl/NSG.cpp +44 -26
  154. data/vendor/faiss/faiss/impl/NSG.h +20 -10
  155. data/vendor/faiss/faiss/impl/Panorama.cpp +76 -52
  156. data/vendor/faiss/faiss/impl/Panorama.h +265 -78
  157. data/vendor/faiss/faiss/impl/PdxLayout.cpp +93 -0
  158. data/vendor/faiss/faiss/impl/PdxLayout.h +41 -0
  159. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +62 -37
  160. data/vendor/faiss/faiss/impl/PolysemousTraining.h +3 -3
  161. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +35 -35
  162. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +21 -16
  163. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +99 -80
  164. data/vendor/faiss/faiss/impl/Quantizer.h +2 -2
  165. data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +135 -37
  166. data/vendor/faiss/faiss/impl/RaBitQUtils.h +148 -21
  167. data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +298 -301
  168. data/vendor/faiss/faiss/impl/RaBitQuantizer.h +3 -10
  169. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.cpp +15 -41
  170. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.h +0 -4
  171. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +40 -32
  172. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +1 -1
  173. data/vendor/faiss/faiss/impl/ResultHandler.h +218 -113
  174. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +119 -2362
  175. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +27 -3
  176. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +14 -11
  177. data/vendor/faiss/faiss/impl/VisitedTable.cpp +42 -0
  178. data/vendor/faiss/faiss/impl/VisitedTable.h +76 -0
  179. data/vendor/faiss/faiss/impl/approx_topk/approx_topk.h +276 -0
  180. data/vendor/faiss/faiss/impl/approx_topk/avx2.cpp +68 -0
  181. data/vendor/faiss/faiss/{utils → impl}/approx_topk/generic.h +15 -8
  182. data/vendor/faiss/faiss/impl/approx_topk/neon.cpp +68 -0
  183. data/vendor/faiss/faiss/impl/approx_topk/rq_beam_search_tab-inl.h +169 -0
  184. data/vendor/faiss/faiss/impl/approx_topk/rq_beam_search_tab.h +117 -0
  185. data/vendor/faiss/faiss/impl/approx_topk/simdlib256-inl.h +146 -0
  186. data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryHNSW_impl.h +73 -0
  187. data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryHash_impl.h +270 -0
  188. data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryIVF_impl.h +460 -0
  189. data/vendor/faiss/faiss/impl/binary_hamming/IndexIVFSpectralHash_impl.h +159 -0
  190. data/vendor/faiss/faiss/impl/binary_hamming/IndexPQ_impl.h +92 -0
  191. data/vendor/faiss/faiss/impl/binary_hamming/avx2.cpp +26 -0
  192. data/vendor/faiss/faiss/impl/binary_hamming/avx512.cpp +26 -0
  193. data/vendor/faiss/faiss/impl/binary_hamming/dispatch.h +143 -0
  194. data/vendor/faiss/faiss/impl/binary_hamming/neon.cpp +26 -0
  195. data/vendor/faiss/faiss/impl/binary_hamming/rvv.cpp +26 -0
  196. data/vendor/faiss/faiss/impl/expanded_scanners.h +163 -0
  197. data/vendor/faiss/faiss/impl/{FastScanDistancePostProcessing.h → fast_scan/FastScanDistancePostProcessing.h} +13 -6
  198. data/vendor/faiss/faiss/impl/{LookupTableScaler.h → fast_scan/LookupTableScaler.h} +16 -5
  199. data/vendor/faiss/faiss/impl/fast_scan/accumulate_loops.h +237 -0
  200. data/vendor/faiss/faiss/impl/fast_scan/accumulate_loops_512.h +185 -0
  201. data/vendor/faiss/faiss/impl/fast_scan/decompose_qbs.h +229 -0
  202. data/vendor/faiss/faiss/impl/fast_scan/dispatching.h +268 -0
  203. data/vendor/faiss/faiss/impl/{pq4_fast_scan.cpp → fast_scan/fast_scan.cpp} +176 -4
  204. data/vendor/faiss/faiss/impl/fast_scan/fast_scan.h +341 -0
  205. data/vendor/faiss/faiss/impl/fast_scan/impl-avx2.cpp +36 -0
  206. data/vendor/faiss/faiss/impl/fast_scan/impl-avx512.cpp +40 -0
  207. data/vendor/faiss/faiss/impl/fast_scan/impl-neon.cpp +120 -0
  208. data/vendor/faiss/faiss/impl/fast_scan/impl-riscv.cpp +104 -0
  209. data/vendor/faiss/faiss/impl/fast_scan/kernels_simd256.h +213 -0
  210. data/vendor/faiss/faiss/impl/{pq4_fast_scan_search_qbs.cpp → fast_scan/kernels_simd512.h} +26 -348
  211. data/vendor/faiss/faiss/impl/fast_scan/rabitq_dispatching.h +90 -0
  212. data/vendor/faiss/faiss/impl/fast_scan/rabitq_result_handler.h +108 -0
  213. data/vendor/faiss/faiss/impl/{simd_result_handlers.h → fast_scan/simd_result_handlers.h} +290 -142
  214. data/vendor/faiss/faiss/impl/hnsw/LockVector.cpp +54 -0
  215. data/vendor/faiss/faiss/impl/hnsw/LockVector.h +64 -0
  216. data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.cpp +91 -0
  217. data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.h +64 -0
  218. data/vendor/faiss/faiss/impl/hnsw/avx2.cpp +104 -0
  219. data/vendor/faiss/faiss/impl/hnsw/avx512.cpp +111 -0
  220. data/vendor/faiss/faiss/impl/index_read.cpp +1950 -505
  221. data/vendor/faiss/faiss/impl/index_read_utils.h +1 -2
  222. data/vendor/faiss/faiss/impl/index_write.cpp +112 -21
  223. data/vendor/faiss/faiss/impl/io.cpp +6 -6
  224. data/vendor/faiss/faiss/impl/io_macros.h +33 -16
  225. data/vendor/faiss/faiss/impl/kmeans1d.cpp +10 -10
  226. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +81 -40
  227. data/vendor/faiss/faiss/impl/lattice_Zn.h +6 -6
  228. data/vendor/faiss/faiss/impl/mapped_io.cpp +15 -8
  229. data/vendor/faiss/faiss/impl/platform_macros.h +11 -4
  230. data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQScanner_impl.h +549 -0
  231. data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQ_QueryTables.cpp +245 -0
  232. data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQ_QueryTables.h +105 -0
  233. data/vendor/faiss/faiss/impl/pq_code_distance/PQDistanceComputer_impl.h +106 -0
  234. data/vendor/faiss/faiss/impl/pq_code_distance/avx2.cpp +21 -0
  235. data/vendor/faiss/faiss/impl/pq_code_distance/avx512.cpp +21 -0
  236. data/vendor/faiss/faiss/impl/pq_code_distance/neon.cpp +21 -0
  237. data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx2.h → pq_code_distance/pq_code_distance-avx2.h} +43 -220
  238. data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx512.h → pq_code_distance/pq_code_distance-avx512.h} +25 -112
  239. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.cpp +59 -0
  240. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.h +96 -0
  241. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-inl.h +256 -0
  242. data/vendor/faiss/faiss/impl/{code_distance/code_distance-sve.h → pq_code_distance/pq_code_distance-sve.cpp} +57 -146
  243. data/vendor/faiss/faiss/impl/pq_code_distance/rvv.cpp +68 -0
  244. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +320 -483
  245. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +1 -1
  246. data/vendor/faiss/faiss/impl/scalar_quantizer/codecs.h +121 -0
  247. data/vendor/faiss/faiss/impl/scalar_quantizer/distance_computers.h +137 -0
  248. data/vendor/faiss/faiss/impl/scalar_quantizer/quantizers.h +371 -0
  249. data/vendor/faiss/faiss/impl/scalar_quantizer/scanners.h +190 -0
  250. data/vendor/faiss/faiss/impl/scalar_quantizer/similarities.h +94 -0
  251. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx2.cpp +603 -0
  252. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512.cpp +597 -0
  253. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-dispatch.h +388 -0
  254. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-neon.cpp +630 -0
  255. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-rvv.cpp +311 -0
  256. data/vendor/faiss/faiss/impl/scalar_quantizer/training.cpp +387 -0
  257. data/vendor/faiss/faiss/impl/scalar_quantizer/training.h +54 -0
  258. data/vendor/faiss/faiss/impl/simd_dispatch.h +173 -0
  259. data/vendor/faiss/faiss/impl/simdlib/simdlib.h +57 -0
  260. data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_avx2.h +274 -171
  261. data/vendor/faiss/faiss/impl/simdlib/simdlib_avx512.h +414 -0
  262. data/vendor/faiss/faiss/impl/simdlib/simdlib_dispatch.h +44 -0
  263. data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_emulated.h +231 -166
  264. data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_neon.h +275 -217
  265. data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_ppc64.h +201 -160
  266. data/vendor/faiss/faiss/impl/svs_io.cpp +12 -3
  267. data/vendor/faiss/faiss/impl/svs_io.h +8 -2
  268. data/vendor/faiss/faiss/index_factory.cpp +115 -28
  269. data/vendor/faiss/faiss/index_io.h +53 -3
  270. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +73 -20
  271. data/vendor/faiss/faiss/invlists/DirectMap.cpp +24 -14
  272. data/vendor/faiss/faiss/invlists/DirectMap.h +4 -3
  273. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +157 -73
  274. data/vendor/faiss/faiss/invlists/InvertedLists.h +86 -23
  275. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +4 -4
  276. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +14 -14
  277. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
  278. data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +9 -19
  279. data/vendor/faiss/faiss/svs/IndexSVSFlat.cpp +2 -2
  280. data/vendor/faiss/faiss/svs/IndexSVSFlat.h +2 -0
  281. data/vendor/faiss/faiss/svs/IndexSVSIVF.cpp +350 -0
  282. data/vendor/faiss/faiss/svs/IndexSVSIVF.h +128 -0
  283. data/vendor/faiss/faiss/svs/IndexSVSIVFLVQ.cpp +40 -0
  284. data/vendor/faiss/faiss/svs/IndexSVSIVFLVQ.h +43 -0
  285. data/vendor/faiss/faiss/svs/IndexSVSIVFLeanVec.cpp +225 -0
  286. data/vendor/faiss/faiss/svs/IndexSVSIVFLeanVec.h +71 -0
  287. data/vendor/faiss/faiss/svs/IndexSVSVamana.cpp +25 -1
  288. data/vendor/faiss/faiss/svs/IndexSVSVamana.h +19 -2
  289. data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.h +1 -1
  290. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +19 -2
  291. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +14 -0
  292. data/vendor/faiss/faiss/utils/Heap.cpp +56 -10
  293. data/vendor/faiss/faiss/utils/Heap.h +21 -0
  294. data/vendor/faiss/faiss/utils/NeuralNet.cpp +54 -40
  295. data/vendor/faiss/faiss/utils/NeuralNet.h +1 -1
  296. data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +10 -4
  297. data/vendor/faiss/faiss/utils/distances.cpp +507 -559
  298. data/vendor/faiss/faiss/utils/distances.h +118 -1
  299. data/vendor/faiss/faiss/utils/distances_dispatch.h +250 -0
  300. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +8 -7
  301. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +33 -14
  302. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +12 -1
  303. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +16 -293
  304. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based_neon.cpp +57 -0
  305. data/vendor/faiss/faiss/utils/distances_fused/simdlib_kernel-inl.h +290 -0
  306. data/vendor/faiss/faiss/utils/distances_simd.cpp +72 -3681
  307. data/vendor/faiss/faiss/utils/extra_distances.cpp +60 -102
  308. data/vendor/faiss/faiss/utils/extra_distances.h +79 -7
  309. data/vendor/faiss/faiss/utils/hamming-inl.h +13 -11
  310. data/vendor/faiss/faiss/utils/hamming.cpp +66 -517
  311. data/vendor/faiss/faiss/utils/hamming.h +92 -2
  312. data/vendor/faiss/faiss/utils/hamming_distance/common.h +287 -10
  313. data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx2.cpp +15 -0
  314. data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx512.cpp +15 -0
  315. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx2.h +142 -0
  316. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx512.h +234 -0
  317. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-generic.h +368 -0
  318. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-neon.h +322 -0
  319. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-rvv.h +39 -0
  320. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer.h +146 -0
  321. data/vendor/faiss/faiss/utils/hamming_distance/hamming_impl.h +481 -0
  322. data/vendor/faiss/faiss/utils/hamming_distance/hamming_neon.cpp +15 -0
  323. data/vendor/faiss/faiss/utils/hamming_distance/hamming_rvv.cpp +15 -0
  324. data/vendor/faiss/faiss/utils/partitioning.cpp +66 -987
  325. data/vendor/faiss/faiss/utils/partitioning.h +31 -0
  326. data/vendor/faiss/faiss/utils/popcount.h +29 -0
  327. data/vendor/faiss/faiss/utils/pq_code_distance.h +251 -0
  328. data/vendor/faiss/faiss/utils/prefetch.h +2 -2
  329. data/vendor/faiss/faiss/utils/quantize_lut.cpp +30 -30
  330. data/vendor/faiss/faiss/utils/quantize_lut.h +1 -1
  331. data/vendor/faiss/faiss/utils/rabitq_simd.h +124 -343
  332. data/vendor/faiss/faiss/utils/random.cpp +6 -6
  333. data/vendor/faiss/faiss/utils/simd_impl/IVFFlatScanner-inl.h +51 -0
  334. data/vendor/faiss/faiss/utils/simd_impl/distances_aarch64.cpp +154 -0
  335. data/vendor/faiss/faiss/utils/simd_impl/distances_arm_sve.cpp +777 -0
  336. data/vendor/faiss/faiss/utils/simd_impl/distances_autovec-inl.h +306 -0
  337. data/vendor/faiss/faiss/utils/simd_impl/distances_avx2.cpp +1431 -0
  338. data/vendor/faiss/faiss/utils/simd_impl/distances_avx512.cpp +1095 -0
  339. data/vendor/faiss/faiss/utils/simd_impl/distances_rvv.cpp +189 -0
  340. data/vendor/faiss/faiss/utils/simd_impl/distances_simdlib256.h +195 -0
  341. data/vendor/faiss/faiss/utils/simd_impl/distances_sse-inl.h +392 -0
  342. data/vendor/faiss/faiss/utils/{distances_fused/simdlib_based.h → simd_impl/exhaustive_L2sqr_blas_cmax.h} +5 -10
  343. data/vendor/faiss/faiss/utils/simd_impl/hamming_impl.h +481 -0
  344. data/vendor/faiss/faiss/utils/simd_impl/partitioning_avx2.cpp +14 -0
  345. data/vendor/faiss/faiss/utils/simd_impl/partitioning_neon.cpp +14 -0
  346. data/vendor/faiss/faiss/utils/simd_impl/partitioning_simdlib256.h +1085 -0
  347. data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx2.cpp +355 -0
  348. data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx512.cpp +477 -0
  349. data/vendor/faiss/faiss/utils/simd_impl/rabitq_neon.cpp +55 -0
  350. data/vendor/faiss/faiss/utils/simd_impl/rabitq_rvv.cpp +55 -0
  351. data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_dispatch.h +32 -0
  352. data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels.h +43 -0
  353. data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels_avx2.cpp +57 -0
  354. data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels_avx512.cpp +45 -0
  355. data/vendor/faiss/faiss/utils/simd_levels.cpp +334 -0
  356. data/vendor/faiss/faiss/utils/simd_levels.h +183 -0
  357. data/vendor/faiss/faiss/utils/sorting.cpp +48 -36
  358. data/vendor/faiss/faiss/utils/utils.cpp +21 -14
  359. data/vendor/faiss/faiss/utils/utils.h +3 -3
  360. metadata +156 -42
  361. data/vendor/faiss/faiss/impl/RaBitQStats.cpp +0 -29
  362. data/vendor/faiss/faiss/impl/RaBitQStats.h +0 -56
  363. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +0 -81
  364. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +0 -186
  365. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +0 -216
  366. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +0 -224
  367. data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +0 -84
  368. data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +0 -196
  369. data/vendor/faiss/faiss/utils/approx_topk/mode.h +0 -34
  370. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +0 -36
  371. data/vendor/faiss/faiss/utils/extra_distances-inl.h +0 -228
  372. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +0 -462
  373. data/vendor/faiss/faiss/utils/hamming_distance/avx512-inl.h +0 -490
  374. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +0 -450
  375. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +0 -87
  376. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +0 -524
  377. data/vendor/faiss/faiss/utils/simdlib.h +0 -42
  378. data/vendor/faiss/faiss/utils/simdlib_avx512.h +0 -296
  379. /data/vendor/faiss/faiss/{cppcontrib/factory_tools.h → factory_tools.h} +0 -0
@@ -0,0 +1,1431 @@
1
+ /*
2
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #include <faiss/utils/distances.h>
9
+
10
+ #include <immintrin.h>
11
+
12
+ #include <faiss/impl/AuxIndexStructures.h>
13
+ #include <faiss/impl/ResultHandler.h>
14
+ #include <faiss/utils/distances_fused/distances_fused.h>
15
+ #include <faiss/utils/simd_impl/exhaustive_L2sqr_blas_cmax.h>
16
+
17
+ #ifndef FINTEGER
18
+ #define FINTEGER long
19
+ #endif
20
+
21
+ extern "C" {
22
+
23
+ int sgemm_(
24
+ const char* transa,
25
+ const char* transb,
26
+ FINTEGER* m,
27
+ FINTEGER* n,
28
+ FINTEGER* k,
29
+ const float* alpha,
30
+ const float* a,
31
+ FINTEGER* lda,
32
+ const float* b,
33
+ FINTEGER* ldb,
34
+ float* beta,
35
+ float* c,
36
+ FINTEGER* ldc);
37
+ }
38
+
39
+ #define THE_SIMD_LEVEL SIMDLevel::AVX2
40
+ // NOLINTNEXTLINE(facebook-hte-InlineHeader)
41
+ #include <faiss/utils/simd_impl/distances_autovec-inl.h>
42
+
43
+ // NOLINTNEXTLINE(facebook-hte-InlineHeader)
44
+ #include <faiss/utils/simd_impl/distances_simdlib256.h>
45
+
46
+ // NOLINTNEXTLINE(facebook-hte-InlineHeader)
47
+ #include <faiss/utils/simd_impl/distances_sse-inl.h>
48
+ // NOLINTNEXTLINE(facebook-hte-InlineHeader)
49
+ #include <faiss/utils/transpose/transpose-avx2-inl.h>
50
+ // NOLINTNEXTLINE(facebook-hte-InlineHeader)
51
+ #include <faiss/utils/simd_impl/IVFFlatScanner-inl.h>
52
+
53
+ namespace faiss {
54
+
55
+ template <>
56
+ void fvec_madd<SIMDLevel::AVX2>(
57
+ const size_t n,
58
+ const float* a,
59
+ const float bf,
60
+ const float* b,
61
+ float* c) {
62
+ //
63
+ const size_t n8 = n / 8;
64
+ const size_t n_for_masking = n % 8;
65
+
66
+ const __m256 bfmm = _mm256_set1_ps(bf);
67
+
68
+ size_t idx = 0;
69
+ for (idx = 0; idx < n8 * 8; idx += 8) {
70
+ const __m256 ax = _mm256_loadu_ps(a + idx);
71
+ const __m256 bx = _mm256_loadu_ps(b + idx);
72
+ const __m256 abmul = _mm256_fmadd_ps(bfmm, bx, ax);
73
+ _mm256_storeu_ps(c + idx, abmul);
74
+ }
75
+
76
+ if (n_for_masking > 0) {
77
+ __m256i mask;
78
+ switch (n_for_masking) {
79
+ case 1:
80
+ mask = _mm256_set_epi32(0, 0, 0, 0, 0, 0, 0, -1);
81
+ break;
82
+ case 2:
83
+ mask = _mm256_set_epi32(0, 0, 0, 0, 0, 0, -1, -1);
84
+ break;
85
+ case 3:
86
+ mask = _mm256_set_epi32(0, 0, 0, 0, 0, -1, -1, -1);
87
+ break;
88
+ case 4:
89
+ mask = _mm256_set_epi32(0, 0, 0, 0, -1, -1, -1, -1);
90
+ break;
91
+ case 5:
92
+ mask = _mm256_set_epi32(0, 0, 0, -1, -1, -1, -1, -1);
93
+ break;
94
+ case 6:
95
+ mask = _mm256_set_epi32(0, 0, -1, -1, -1, -1, -1, -1);
96
+ break;
97
+ case 7:
98
+ mask = _mm256_set_epi32(0, -1, -1, -1, -1, -1, -1, -1);
99
+ break;
100
+ }
101
+
102
+ const __m256 ax = _mm256_maskload_ps(a + idx, mask);
103
+ const __m256 bx = _mm256_maskload_ps(b + idx, mask);
104
+ const __m256 abmul = _mm256_fmadd_ps(bfmm, bx, ax);
105
+ _mm256_maskstore_ps(c + idx, mask, abmul);
106
+ }
107
+ }
108
+
109
/// Squared L2 distances between one DIM-dimensional query and ny vectors
/// stored in transposed form.
///
/// Component j of database vector i is read from y[j * d_offset + i], and
/// y_sqlen[i] is used as the ||y_i||^2 term. Each result is evaluated as
/// ||y_i||^2 - 2 * <x, y_i> + ||x||^2.
///
/// @param distances  output array, ny entries
/// @param x          query vector, DIM floats
/// @param y          transposed database vectors
/// @param y_sqlen    per-vector squared-length terms, ny entries
/// @param d_offset   stride (in floats) between consecutive dimensions of y
/// @param ny         number of database vectors
template <size_t DIM>
void fvec_L2sqr_ny_y_transposed_D(
        float* distances,
        const float* x,
        const float* y,
        const float* y_sqlen,
        const size_t d_offset,
        size_t ny) {
    // ||x||^2, accumulated in dimension order
    float x_norm2 = 0;
    for (size_t k = 0; k < DIM; k++) {
        x_norm2 += x[k] * x[k];
    }

    // number of vectors handled by the 8-wide loop
    const size_t simd_ny = (ny / 8) * 8;
    size_t done = 0;

    if (simd_ny > 0) {
        // twice_x[k] = (2 * x[k], ..., 2 * x[k])
        __m256 twice_x[DIM];
        for (size_t k = 0; k < DIM; k++) {
            const __m256 xk = _mm256_set1_ps(x[k]);
            twice_x[k] = _mm256_add_ps(xk, xk);
        }

        const __m256 x_norm2_v = _mm256_set1_ps(x_norm2);

        while (done < simd_ny) {
            // start from ||x||^2 - 2 * x[0] * y[0] for 8 vectors at once
            __m256 acc = _mm256_fnmadd_ps(
                    twice_x[0], _mm256_loadu_ps(y), x_norm2_v);

            // subtract 2 * x[k] * y[k] for the remaining dimensions
            for (size_t k = 1; k < DIM; k++) {
                const __m256 yk = _mm256_loadu_ps(y + k * d_offset);
                acc = _mm256_fnmadd_ps(twice_x[k], yk, acc);
            }

            // add ||y||^2 to obtain the full squared distance
            const __m256 result = _mm256_add_ps(_mm256_loadu_ps(y_sqlen), acc);
            _mm256_storeu_ps(distances + done, result);

            // advance to the next group of 8 vectors
            y += 8;
            y_sqlen += 8;
            done += 8;
        }
    }

    // scalar path for the last ny % 8 vectors
    for (; done < ny; done++, y++, y_sqlen++) {
        float dot = 0;
        for (size_t k = 0; k < DIM; k++) {
            dot += x[k] * y[k * d_offset];
        }

        // same formula as the vector path: ||y||^2 - 2 * <x, y> + ||x||^2
        distances[done] = y_sqlen[0] - 2 * dot + x_norm2;
    }
}
184
+
185
+ template <>
186
+ void fvec_L2sqr_ny_transposed<SIMDLevel::AVX2>(
187
+ float* dis,
188
+ const float* x,
189
+ const float* y,
190
+ const float* y_sqlen,
191
+ size_t d,
192
+ size_t d_offset,
193
+ size_t ny) {
194
+ // optimized for a few special cases
195
+ #define DISPATCH(dval) \
196
+ case dval: \
197
+ return fvec_L2sqr_ny_y_transposed_D<dval>( \
198
+ dis, x, y, y_sqlen, d_offset, ny);
199
+
200
+ switch (d) {
201
+ DISPATCH(1)
202
+ DISPATCH(2)
203
+ DISPATCH(4)
204
+ DISPATCH(8)
205
+ default:
206
+ return fvec_L2sqr_ny_transposed<SIMDLevel::NONE>(
207
+ dis, x, y, y_sqlen, d, d_offset, ny);
208
+ }
209
+ #undef DISPATCH
210
+ }
211
+
212
+ namespace {
213
+
214
+ struct AVX2ElementOpIP : public ElementOpIP {
215
+ using ElementOpIP::op;
216
+ static __m256 op(__m256 x, __m256 y) {
217
+ return _mm256_mul_ps(x, y);
218
+ }
219
+ };
220
+
221
+ struct AVX2ElementOpL2 : public ElementOpL2 {
222
+ using ElementOpL2::op;
223
+
224
+ static __m256 op(__m256 x, __m256 y) {
225
+ __m256 tmp = _mm256_sub_ps(x, y);
226
+ return _mm256_mul_ps(tmp, tmp);
227
+ }
228
+ };
229
+
230
+ } // namespace
231
+
232
+ /// helper function for AVX2
233
+ inline float horizontal_sum(const __m256 v) {
234
+ // add high and low parts
235
+ const __m128 v0 =
236
+ _mm_add_ps(_mm256_castps256_ps128(v), _mm256_extractf128_ps(v, 1));
237
+ // perform horizontal sum on v0
238
+ return horizontal_sum(v0);
239
+ }
240
+
241
+ template <> // AVX2 kernel: inner product of one 2-D query against ny 2-D vectors
242
+ void fvec_op_ny_D2<AVX2ElementOpIP>(
243
+ float* dis, // out: ny values, dis[i] = x[0] * y[2i] + x[1] * y[2i + 1]
244
+ const float* x, // query, 2 floats
245
+ const float* y, // ny vectors of 2 floats each, stored back to back
246
+ size_t ny) { // number of vectors in y
247
+ const size_t ny8 = ny / 8; // number of full groups of 8 vectors
248
+ size_t i = 0;
249
+
250
+ if (ny8 > 0) {
251
+ // process 8 D2-vectors per loop.
252
+ _mm_prefetch((const char*)y, _MM_HINT_T0);
253
+ _mm_prefetch((const char*)(y + 16), _MM_HINT_T0);
254
+
255
+ const __m256 m0 = _mm256_set1_ps(x[0]); // x[0] broadcast to all 8 lanes
256
+ const __m256 m1 = _mm256_set1_ps(x[1]); // x[1] broadcast to all 8 lanes
257
+
258
+ for (i = 0; i < ny8 * 8; i += 8) {
259
+ _mm_prefetch((const char*)(y + 32), _MM_HINT_T0); // hint for data 32 floats ahead of the current group
260
+
261
+ // load 8x2 matrix and transpose it in registers.
262
+ // the typical bottleneck is memory access, so
263
+ // let's trade instructions for the bandwidth.
264
+
265
+ __m256 v0; // presumably component 0 of the 8 vectors after the transpose (helper defined elsewhere)
266
+ __m256 v1; // presumably component 1 of the 8 vectors
267
+
268
+ transpose_8x2(
269
+ _mm256_loadu_ps(y + 0 * 8),
270
+ _mm256_loadu_ps(y + 1 * 8),
271
+ v0,
272
+ v1);
273
+
274
+ // compute inner products: m0 * v0 + m1 * v1
275
+ __m256 distances = _mm256_mul_ps(m0, v0);
276
+ distances = _mm256_fmadd_ps(m1, v1, distances);
277
+
278
+ // store 8 results at dis[i .. i + 7]
279
+ _mm256_storeu_ps(dis + i, distances);
280
+
281
+ y += 16; // 8 vectors * 2 floats consumed
282
+ }
283
+ }
284
+
285
+ if (i < ny) {
286
+ // process leftovers (fewer than 8 vectors remain) one vector at a time
287
+ float x0 = x[0];
288
+ float x1 = x[1];
289
+
290
+ for (; i < ny; i++) {
291
+ float distance = x0 * y[0] + x1 * y[1];
292
+ y += 2;
293
+ dis[i] = distance;
294
+ }
295
+ }
296
+ }
297
+
298
+ template <> // AVX2 kernel: squared L2 distance from one 2-D query to ny 2-D vectors
299
+ void fvec_op_ny_D2<AVX2ElementOpL2>(
300
+ float* dis, // out: ny values, dis[i] = (x[0] - y[2i])^2 + (x[1] - y[2i + 1])^2
301
+ const float* x, // query, 2 floats
302
+ const float* y, // ny vectors of 2 floats each, stored back to back
303
+ size_t ny) { // number of vectors in y
304
+ const size_t ny8 = ny / 8; // number of full groups of 8 vectors
305
+ size_t i = 0;
306
+
307
+ if (ny8 > 0) {
308
+ // process 8 D2-vectors per loop.
309
+ _mm_prefetch((const char*)y, _MM_HINT_T0);
310
+ _mm_prefetch((const char*)(y + 16), _MM_HINT_T0);
311
+
312
+ const __m256 m0 = _mm256_set1_ps(x[0]); // x[0] broadcast to all 8 lanes
313
+ const __m256 m1 = _mm256_set1_ps(x[1]); // x[1] broadcast to all 8 lanes
314
+
315
+ for (i = 0; i < ny8 * 8; i += 8) {
316
+ _mm_prefetch((const char*)(y + 32), _MM_HINT_T0); // hint for data 32 floats ahead of the current group
317
+
318
+ // load 8x2 matrix and transpose it in registers.
319
+ // the typical bottleneck is memory access, so
320
+ // let's trade instructions for the bandwidth.
321
+
322
+ __m256 v0; // presumably component 0 of the 8 vectors after the transpose (helper defined elsewhere)
323
+ __m256 v1; // presumably component 1 of the 8 vectors
324
+
325
+ transpose_8x2(
326
+ _mm256_loadu_ps(y + 0 * 8),
327
+ _mm256_loadu_ps(y + 1 * 8),
328
+ v0,
329
+ v1);
330
+
331
+ // compute differences
332
+ const __m256 d0 = _mm256_sub_ps(m0, v0);
333
+ const __m256 d1 = _mm256_sub_ps(m1, v1);
334
+
335
+ // compute squares of differences and accumulate: d0 * d0 + d1 * d1
336
+ __m256 distances = _mm256_mul_ps(d0, d0);
337
+ distances = _mm256_fmadd_ps(d1, d1, distances);
338
+
339
+ // store 8 results at dis[i .. i + 7]
340
+ _mm256_storeu_ps(dis + i, distances);
341
+
342
+ y += 16; // 8 vectors * 2 floats consumed
343
+ }
344
+ }
345
+
346
+ if (i < ny) {
347
+ // process leftovers (fewer than 8 vectors remain) one vector at a time
348
+ float x0 = x[0];
349
+ float x1 = x[1];
350
+
351
+ for (; i < ny; i++) {
352
+ float sub0 = x0 - y[0];
353
+ float sub1 = x1 - y[1];
354
+ float distance = sub0 * sub0 + sub1 * sub1;
355
+
356
+ y += 2;
357
+ dis[i] = distance;
358
+ }
359
+ }
360
+ }
361
+
362
+ template <> // AVX2 kernel: inner product of one 4-D query against ny 4-D vectors
363
+ void fvec_op_ny_D4<AVX2ElementOpIP>(
364
+ float* dis, // out: ny values, one inner product per vector of y
365
+ const float* x, // query, 4 floats
366
+ const float* y, // ny vectors of 4 floats each, stored back to back
367
+ size_t ny) { // number of vectors in y
368
+ const size_t ny8 = ny / 8; // number of full groups of 8 vectors
369
+ size_t i = 0;
370
+
371
+ if (ny8 > 0) {
372
+ // process 8 D4-vectors per loop.
373
+ const __m256 m0 = _mm256_set1_ps(x[0]); // each query component broadcast to all 8 lanes
374
+ const __m256 m1 = _mm256_set1_ps(x[1]);
375
+ const __m256 m2 = _mm256_set1_ps(x[2]);
376
+ const __m256 m3 = _mm256_set1_ps(x[3]);
377
+
378
+ for (i = 0; i < ny8 * 8; i += 8) {
379
+ // load 8x4 matrix and transpose it in registers.
380
+ // the typical bottleneck is memory access, so
381
+ // let's trade instructions for the bandwidth.
382
+
383
+ __m256 v0; // presumably vK holds component K of the 8 vectors after the transpose (helper defined elsewhere)
384
+ __m256 v1;
385
+ __m256 v2;
386
+ __m256 v3;
387
+
388
+ transpose_8x4(
389
+ _mm256_loadu_ps(y + 0 * 8),
390
+ _mm256_loadu_ps(y + 1 * 8),
391
+ _mm256_loadu_ps(y + 2 * 8),
392
+ _mm256_loadu_ps(y + 3 * 8),
393
+ v0,
394
+ v1,
395
+ v2,
396
+ v3);
397
+
398
+ // compute inner products: sum over k of mk * vk
399
+ __m256 distances = _mm256_mul_ps(m0, v0);
400
+ distances = _mm256_fmadd_ps(m1, v1, distances);
401
+ distances = _mm256_fmadd_ps(m2, v2, distances);
402
+ distances = _mm256_fmadd_ps(m3, v3, distances);
403
+
404
+ // store 8 results at dis[i .. i + 7]
405
+ _mm256_storeu_ps(dis + i, distances);
406
+
407
+ y += 32; // 8 vectors * 4 floats consumed
408
+ }
409
+ }
410
+
411
+ if (i < ny) {
412
+ // process leftovers (fewer than 8 vectors remain), one 4-float vector per iteration
413
+ __m128 x0 = _mm_loadu_ps(x);
414
+
415
+ for (; i < ny; i++) {
416
+ __m128 accu = AVX2ElementOpIP::op(x0, _mm_loadu_ps(y)); // __m128 overload inherited from ElementOpIP
417
+ y += 4;
418
+ dis[i] = horizontal_sum(accu);
419
+ }
420
+ }
421
+ }
422
+
423
+ template <> // AVX2 kernel: squared L2 distance from one 4-D query to ny 4-D vectors
424
+ void fvec_op_ny_D4<AVX2ElementOpL2>(
425
+ float* dis, // out: ny values, one squared distance per vector of y
426
+ const float* x, // query, 4 floats
427
+ const float* y, // ny vectors of 4 floats each, stored back to back
428
+ size_t ny) { // number of vectors in y
429
+ const size_t ny8 = ny / 8; // number of full groups of 8 vectors
430
+ size_t i = 0;
431
+
432
+ if (ny8 > 0) {
433
+ // process 8 D4-vectors per loop.
434
+ const __m256 m0 = _mm256_set1_ps(x[0]); // each query component broadcast to all 8 lanes
435
+ const __m256 m1 = _mm256_set1_ps(x[1]);
436
+ const __m256 m2 = _mm256_set1_ps(x[2]);
437
+ const __m256 m3 = _mm256_set1_ps(x[3]);
438
+
439
+ for (i = 0; i < ny8 * 8; i += 8) {
440
+ // load 8x4 matrix and transpose it in registers.
441
+ // the typical bottleneck is memory access, so
442
+ // let's trade instructions for the bandwidth.
443
+
444
+ __m256 v0; // presumably vK holds component K of the 8 vectors after the transpose (helper defined elsewhere)
445
+ __m256 v1;
446
+ __m256 v2;
447
+ __m256 v3;
448
+
449
+ transpose_8x4(
450
+ _mm256_loadu_ps(y + 0 * 8),
451
+ _mm256_loadu_ps(y + 1 * 8),
452
+ _mm256_loadu_ps(y + 2 * 8),
453
+ _mm256_loadu_ps(y + 3 * 8),
454
+ v0,
455
+ v1,
456
+ v2,
457
+ v3);
458
+
459
+ // compute differences
460
+ const __m256 d0 = _mm256_sub_ps(m0, v0);
461
+ const __m256 d1 = _mm256_sub_ps(m1, v1);
462
+ const __m256 d2 = _mm256_sub_ps(m2, v2);
463
+ const __m256 d3 = _mm256_sub_ps(m3, v3);
464
+
465
+ // compute squares of differences and accumulate them
466
+ __m256 distances = _mm256_mul_ps(d0, d0);
467
+ distances = _mm256_fmadd_ps(d1, d1, distances);
468
+ distances = _mm256_fmadd_ps(d2, d2, distances);
469
+ distances = _mm256_fmadd_ps(d3, d3, distances);
470
+
471
+ // store 8 results at dis[i .. i + 7]
472
+ _mm256_storeu_ps(dis + i, distances);
473
+
474
+ y += 32; // 8 vectors * 4 floats consumed
475
+ }
476
+ }
477
+
478
+ if (i < ny) {
479
+ // process leftovers (fewer than 8 vectors remain), one 4-float vector per iteration
480
+ __m128 x0 = _mm_loadu_ps(x);
481
+
482
+ for (; i < ny; i++) {
483
+ __m128 accu = AVX2ElementOpL2::op(x0, _mm_loadu_ps(y)); // __m128 overload inherited from ElementOpL2
484
+ y += 4;
485
+ dis[i] = horizontal_sum(accu);
486
+ }
487
+ }
488
+ }
489
+
490
+ template <> // AVX2 kernel: inner product of one 8-D query against ny 8-D vectors
491
+ void fvec_op_ny_D8<AVX2ElementOpIP>(
492
+ float* dis, // out: ny values, one inner product per vector of y
493
+ const float* x, // query, 8 floats
494
+ const float* y, // ny vectors of 8 floats each, stored back to back
495
+ size_t ny) { // number of vectors in y
496
+ const size_t ny8 = ny / 8; // number of full groups of 8 vectors
497
+ size_t i = 0;
498
+
499
+ if (ny8 > 0) {
500
+ // process 8 D8-vectors per loop.
501
+ const __m256 m0 = _mm256_set1_ps(x[0]); // each query component broadcast to all 8 lanes
502
+ const __m256 m1 = _mm256_set1_ps(x[1]);
503
+ const __m256 m2 = _mm256_set1_ps(x[2]);
504
+ const __m256 m3 = _mm256_set1_ps(x[3]);
505
+ const __m256 m4 = _mm256_set1_ps(x[4]);
506
+ const __m256 m5 = _mm256_set1_ps(x[5]);
507
+ const __m256 m6 = _mm256_set1_ps(x[6]);
508
+ const __m256 m7 = _mm256_set1_ps(x[7]);
509
+
510
+ for (i = 0; i < ny8 * 8; i += 8) {
511
+ // load 8x8 matrix and transpose it in registers.
512
+ // the typical bottleneck is memory access, so
513
+ // let's trade instructions for the bandwidth.
514
+
515
+ __m256 v0; // presumably vK holds component K of the 8 vectors after the transpose (helper defined elsewhere)
516
+ __m256 v1;
517
+ __m256 v2;
518
+ __m256 v3;
519
+ __m256 v4;
520
+ __m256 v5;
521
+ __m256 v6;
522
+ __m256 v7;
523
+
524
+ transpose_8x8(
525
+ _mm256_loadu_ps(y + 0 * 8),
526
+ _mm256_loadu_ps(y + 1 * 8),
527
+ _mm256_loadu_ps(y + 2 * 8),
528
+ _mm256_loadu_ps(y + 3 * 8),
529
+ _mm256_loadu_ps(y + 4 * 8),
530
+ _mm256_loadu_ps(y + 5 * 8),
531
+ _mm256_loadu_ps(y + 6 * 8),
532
+ _mm256_loadu_ps(y + 7 * 8),
533
+ v0,
534
+ v1,
535
+ v2,
536
+ v3,
537
+ v4,
538
+ v5,
539
+ v6,
540
+ v7);
541
+
542
+ // compute inner products: sum over k of mk * vk
543
+ __m256 distances = _mm256_mul_ps(m0, v0);
544
+ distances = _mm256_fmadd_ps(m1, v1, distances);
545
+ distances = _mm256_fmadd_ps(m2, v2, distances);
546
+ distances = _mm256_fmadd_ps(m3, v3, distances);
547
+ distances = _mm256_fmadd_ps(m4, v4, distances);
548
+ distances = _mm256_fmadd_ps(m5, v5, distances);
549
+ distances = _mm256_fmadd_ps(m6, v6, distances);
550
+ distances = _mm256_fmadd_ps(m7, v7, distances);
551
+
552
+ // store 8 results at dis[i .. i + 7]
553
+ _mm256_storeu_ps(dis + i, distances);
554
+
555
+ y += 64; // 8 vectors * 8 floats consumed
556
+ }
557
+ }
558
+
559
+ if (i < ny) {
560
+ // process leftovers (fewer than 8 vectors remain), one 8-float vector per iteration
561
+ __m256 x0 = _mm256_loadu_ps(x);
562
+
563
+ for (; i < ny; i++) {
564
+ __m256 accu = AVX2ElementOpIP::op(x0, _mm256_loadu_ps(y)); // per-lane products
565
+ y += 8;
566
+ dis[i] = horizontal_sum(accu);
567
+ }
568
+ }
569
+ }
570
+
571
+ template <> // AVX2 kernel: squared L2 distance from one 8-D query to ny 8-D vectors
572
+ void fvec_op_ny_D8<AVX2ElementOpL2>(
573
+ float* dis, // out: ny values, one squared distance per vector of y
574
+ const float* x, // query, 8 floats
575
+ const float* y, // ny vectors of 8 floats each, stored back to back
576
+ size_t ny) { // number of vectors in y
577
+ const size_t ny8 = ny / 8; // number of full groups of 8 vectors
578
+ size_t i = 0;
579
+
580
+ if (ny8 > 0) {
581
+ // process 8 D8-vectors per loop.
582
+ const __m256 m0 = _mm256_set1_ps(x[0]); // each query component broadcast to all 8 lanes
583
+ const __m256 m1 = _mm256_set1_ps(x[1]);
584
+ const __m256 m2 = _mm256_set1_ps(x[2]);
585
+ const __m256 m3 = _mm256_set1_ps(x[3]);
586
+ const __m256 m4 = _mm256_set1_ps(x[4]);
587
+ const __m256 m5 = _mm256_set1_ps(x[5]);
588
+ const __m256 m6 = _mm256_set1_ps(x[6]);
589
+ const __m256 m7 = _mm256_set1_ps(x[7]);
590
+
591
+ for (i = 0; i < ny8 * 8; i += 8) {
592
+ // load 8x8 matrix and transpose it in registers.
593
+ // the typical bottleneck is memory access, so
594
+ // let's trade instructions for the bandwidth.
595
+
596
+ __m256 v0; // presumably vK holds component K of the 8 vectors after the transpose (helper defined elsewhere)
597
+ __m256 v1;
598
+ __m256 v2;
599
+ __m256 v3;
600
+ __m256 v4;
601
+ __m256 v5;
602
+ __m256 v6;
603
+ __m256 v7;
604
+
605
+ transpose_8x8(
606
+ _mm256_loadu_ps(y + 0 * 8),
607
+ _mm256_loadu_ps(y + 1 * 8),
608
+ _mm256_loadu_ps(y + 2 * 8),
609
+ _mm256_loadu_ps(y + 3 * 8),
610
+ _mm256_loadu_ps(y + 4 * 8),
611
+ _mm256_loadu_ps(y + 5 * 8),
612
+ _mm256_loadu_ps(y + 6 * 8),
613
+ _mm256_loadu_ps(y + 7 * 8),
614
+ v0,
615
+ v1,
616
+ v2,
617
+ v3,
618
+ v4,
619
+ v5,
620
+ v6,
621
+ v7);
622
+
623
+ // compute differences
624
+ const __m256 d0 = _mm256_sub_ps(m0, v0);
625
+ const __m256 d1 = _mm256_sub_ps(m1, v1);
626
+ const __m256 d2 = _mm256_sub_ps(m2, v2);
627
+ const __m256 d3 = _mm256_sub_ps(m3, v3);
628
+ const __m256 d4 = _mm256_sub_ps(m4, v4);
629
+ const __m256 d5 = _mm256_sub_ps(m5, v5);
630
+ const __m256 d6 = _mm256_sub_ps(m6, v6);
631
+ const __m256 d7 = _mm256_sub_ps(m7, v7);
632
+
633
+ // compute squares of differences and accumulate them
634
+ __m256 distances = _mm256_mul_ps(d0, d0);
635
+ distances = _mm256_fmadd_ps(d1, d1, distances);
636
+ distances = _mm256_fmadd_ps(d2, d2, distances);
637
+ distances = _mm256_fmadd_ps(d3, d3, distances);
638
+ distances = _mm256_fmadd_ps(d4, d4, distances);
639
+ distances = _mm256_fmadd_ps(d5, d5, distances);
640
+ distances = _mm256_fmadd_ps(d6, d6, distances);
641
+ distances = _mm256_fmadd_ps(d7, d7, distances);
642
+
643
+ // store 8 results at dis[i .. i + 7]
644
+ _mm256_storeu_ps(dis + i, distances);
645
+
646
+ y += 64; // 8 vectors * 8 floats consumed
647
+ }
648
+ }
649
+
650
+ if (i < ny) {
651
+ // process leftovers (fewer than 8 vectors remain), one 8-float vector per iteration
652
+ __m256 x0 = _mm256_loadu_ps(x);
653
+
654
+ for (; i < ny; i++) {
655
+ __m256 accu = AVX2ElementOpL2::op(x0, _mm256_loadu_ps(y)); // per-lane squared differences
656
+ y += 8;
657
+ dis[i] = horizontal_sum(accu);
658
+ }
659
+ }
660
+ }
661
+
662
+ template <> // AVX2 entry point: forwards to the shared reference routine, parameterized with the AVX2 inner-product operator
663
+ void fvec_inner_products_ny<SIMDLevel::AVX2>(
664
+ float* ip, /* output inner product */
665
+ const float* x, // query vector
666
+ const float* y, // database vectors
667
+ size_t d, // dimension, passed through unchanged
668
+ size_t ny) { // number of vectors in y, passed through unchanged
669
+ fvec_inner_products_ny_ref<AVX2ElementOpIP>(ip, x, y, d, ny); // defined elsewhere; presumably picks the D2/D4/D8 kernels by d
670
+ }
671
+
672
+ template <> // AVX2 entry point: forwards to the shared reference routine, parameterized with the AVX2 squared-L2 operator
673
+ void fvec_L2sqr_ny<SIMDLevel::AVX2>(
674
+ float* dis, // output distances
675
+ const float* x, // query vector
676
+ const float* y, // database vectors
677
+ size_t d, // dimension, passed through unchanged
678
+ size_t ny) { // number of vectors in y, passed through unchanged
679
+ fvec_L2sqr_ny_ref<AVX2ElementOpL2>(dis, x, y, d, ny); // defined elsewhere; presumably picks the D2/D4/D8 kernels by d
680
+ }
681
+
682
+ template <> // AVX2 kernel: index of the 2-D vector in y with the smallest squared L2 distance to x
683
+ size_t fvec_L2sqr_ny_nearest_D2<SIMDLevel::AVX2>(
684
+ float* /*distances_tmp_buffer*/, // unused here
685
+ const float* x, // query, 2 floats
686
+ const float* y, // ny vectors of 2 floats each, stored back to back
687
+ size_t ny) { // number of vectors in y; the function returns 0 when ny == 0
688
+ // this implementation does not use distances_tmp_buffer.
689
+ // current index being processed
690
+ size_t i = 0;
691
+
692
+ // min distance and the index of the closest vector so far
693
+ float current_min_distance = HUGE_VALF;
694
+ size_t current_min_index = 0;
695
+
696
+ // process 8 D2-vectors per loop.
697
+ const size_t ny8 = ny / 8;
698
+ if (ny8 > 0) {
699
+ _mm_prefetch((const char*)y, _MM_HINT_T0);
700
+ _mm_prefetch((const char*)(y + 16), _MM_HINT_T0);
701
+
702
+ // track min distance and the closest vector independently
703
+ // for each of 8 AVX2 components.
704
+ __m256 min_distances = _mm256_set1_ps(HUGE_VALF);
705
+ __m256i min_indices = _mm256_set1_epi32(0); // NOTE(review): indices live in 32-bit lanes, so this assumes ny fits in 32 bits
706
+
707
+ __m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); // lane k handles vectors k, k + 8, k + 16, ...
708
+ const __m256i indices_increment = _mm256_set1_epi32(8);
709
+
710
+ // 1 value per register
711
+ const __m256 m0 = _mm256_set1_ps(x[0]);
712
+ const __m256 m1 = _mm256_set1_ps(x[1]);
713
+
714
+ for (; i < ny8 * 8; i += 8) {
715
+ _mm_prefetch((const char*)(y + 32), _MM_HINT_T0);
716
+
717
+ __m256 v0; // presumably component 0 of the 8 vectors after the transpose (helper defined elsewhere)
718
+ __m256 v1; // presumably component 1 of the 8 vectors
719
+
720
+ transpose_8x2(
721
+ _mm256_loadu_ps(y + 0 * 8),
722
+ _mm256_loadu_ps(y + 1 * 8),
723
+ v0,
724
+ v1);
725
+
726
+ // compute differences
727
+ const __m256 d0 = _mm256_sub_ps(m0, v0);
728
+ const __m256 d1 = _mm256_sub_ps(m1, v1);
729
+
730
+ // compute squares of differences
731
+ __m256 distances = _mm256_mul_ps(d0, d0);
732
+ distances = _mm256_fmadd_ps(d1, d1, distances);
733
+
734
+ // compare the new distances to the min distances
735
+ // for each of 8 AVX2 components.
736
+ __m256 comparison =
737
+ _mm256_cmp_ps(min_distances, distances, _CMP_LT_OS); // lane set where the old minimum is strictly smaller
738
+
739
+ // update min distances and indices with closest vectors if needed.
740
+ min_distances = _mm256_min_ps(distances, min_distances);
741
+ min_indices = _mm256_castps_si256(_mm256_blendv_ps( // keeps min_indices where comparison is set, else takes current_indices
742
+ _mm256_castsi256_ps(current_indices),
743
+ _mm256_castsi256_ps(min_indices),
744
+ comparison)); // NOTE(review): on an exact tie within a lane the newer index replaces the older one
745
+
746
+ // update current indices values. Basically, +8 to each of the
747
+ // 8 AVX2 components.
748
+ current_indices =
749
+ _mm256_add_epi32(current_indices, indices_increment);
750
+
751
+ // scroll y forward (8 vectors 2 DIM each).
752
+ y += 16;
753
+ }
754
+
755
+ // dump values and find the minimum distance / minimum index
756
+ float min_distances_scalar[8];
757
+ uint32_t min_indices_scalar[8];
758
+ _mm256_storeu_ps(min_distances_scalar, min_distances);
759
+ _mm256_storeu_si256((__m256i*)(min_indices_scalar), min_indices);
760
+
761
+ for (size_t j = 0; j < 8; j++) { // strict '>' means the lowest-numbered lane wins ties across lanes
762
+ if (current_min_distance > min_distances_scalar[j]) {
763
+ current_min_distance = min_distances_scalar[j];
764
+ current_min_index = min_indices_scalar[j];
765
+ }
766
+ }
767
+ }
768
+
769
+ if (i < ny) {
770
+ // process leftovers.
771
+ // the following code is not optimal, but it is rarely invoked.
772
+ float x0 = x[0];
773
+ float x1 = x[1];
774
+
775
+ for (; i < ny; i++) {
776
+ float sub0 = x0 - y[0];
777
+ float sub1 = x1 - y[1];
778
+ float distance = sub0 * sub0 + sub1 * sub1;
779
+
780
+ y += 2;
781
+
782
+ if (current_min_distance > distance) { // only a strictly smaller distance replaces the current best
783
+ current_min_distance = distance;
784
+ current_min_index = i;
785
+ }
786
+ }
787
+ }
788
+
789
+ return current_min_index;
790
+ }
791
+
792
+ template <> // AVX2 kernel: index of the 4-D vector in y with the smallest squared L2 distance to x
793
+ size_t fvec_L2sqr_ny_nearest_D4<SIMDLevel::AVX2>(
794
+ float* /*distances_tmp_buffer*/, // unused here
795
+ const float* x, // query, 4 floats
796
+ const float* y, // ny vectors of 4 floats each, stored back to back
797
+ size_t ny) { // number of vectors in y; the function returns 0 when ny == 0
798
+ // this implementation does not use distances_tmp_buffer.
799
+
800
+ // current index being processed
801
+ size_t i = 0;
802
+
803
+ // min distance and the index of the closest vector so far
804
+ float current_min_distance = HUGE_VALF;
805
+ size_t current_min_index = 0;
806
+
807
+ // process 8 D4-vectors per loop.
808
+ const size_t ny8 = ny / 8;
809
+
810
+ if (ny8 > 0) {
811
+ // track min distance and the closest vector independently
812
+ // for each of 8 AVX2 components.
813
+ __m256 min_distances = _mm256_set1_ps(HUGE_VALF);
814
+ __m256i min_indices = _mm256_set1_epi32(0); // NOTE(review): indices live in 32-bit lanes, so this assumes ny fits in 32 bits
815
+
816
+ __m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); // lane k handles vectors k, k + 8, k + 16, ...
817
+ const __m256i indices_increment = _mm256_set1_epi32(8);
818
+
819
+ // 1 value per register
820
+ const __m256 m0 = _mm256_set1_ps(x[0]);
821
+ const __m256 m1 = _mm256_set1_ps(x[1]);
822
+ const __m256 m2 = _mm256_set1_ps(x[2]);
823
+ const __m256 m3 = _mm256_set1_ps(x[3]);
824
+
825
+ for (; i < ny8 * 8; i += 8) {
826
+ __m256 v0; // presumably vK holds component K of the 8 vectors after the transpose (helper defined elsewhere)
827
+ __m256 v1;
828
+ __m256 v2;
829
+ __m256 v3;
830
+
831
+ transpose_8x4(
832
+ _mm256_loadu_ps(y + 0 * 8),
833
+ _mm256_loadu_ps(y + 1 * 8),
834
+ _mm256_loadu_ps(y + 2 * 8),
835
+ _mm256_loadu_ps(y + 3 * 8),
836
+ v0,
837
+ v1,
838
+ v2,
839
+ v3);
840
+
841
+ // compute differences
842
+ const __m256 d0 = _mm256_sub_ps(m0, v0);
843
+ const __m256 d1 = _mm256_sub_ps(m1, v1);
844
+ const __m256 d2 = _mm256_sub_ps(m2, v2);
845
+ const __m256 d3 = _mm256_sub_ps(m3, v3);
846
+
847
+ // compute squares of differences
848
+ __m256 distances = _mm256_mul_ps(d0, d0);
849
+ distances = _mm256_fmadd_ps(d1, d1, distances);
850
+ distances = _mm256_fmadd_ps(d2, d2, distances);
851
+ distances = _mm256_fmadd_ps(d3, d3, distances);
852
+
853
+ // compare the new distances to the min distances
854
+ // for each of 8 AVX2 components.
855
+ __m256 comparison =
856
+ _mm256_cmp_ps(min_distances, distances, _CMP_LT_OS); // lane set where the old minimum is strictly smaller
857
+
858
+ // update min distances and indices with closest vectors if needed.
859
+ min_distances = _mm256_min_ps(distances, min_distances);
860
+ min_indices = _mm256_castps_si256(_mm256_blendv_ps( // keeps min_indices where comparison is set, else takes current_indices
861
+ _mm256_castsi256_ps(current_indices),
862
+ _mm256_castsi256_ps(min_indices),
863
+ comparison)); // NOTE(review): on an exact tie within a lane the newer index replaces the older one
864
+
865
+ // update current indices values. Basically, +8 to each of the
866
+ // 8 AVX2 components.
867
+ current_indices =
868
+ _mm256_add_epi32(current_indices, indices_increment);
869
+
870
+ // scroll y forward (8 vectors 4 DIM each).
871
+ y += 32;
872
+ }
873
+
874
+ // dump values and find the minimum distance / minimum index
875
+ float min_distances_scalar[8];
876
+ uint32_t min_indices_scalar[8];
877
+ _mm256_storeu_ps(min_distances_scalar, min_distances);
878
+ _mm256_storeu_si256((__m256i*)(min_indices_scalar), min_indices);
879
+
880
+ for (size_t j = 0; j < 8; j++) { // strict '>' means the lowest-numbered lane wins ties across lanes
881
+ if (current_min_distance > min_distances_scalar[j]) {
882
+ current_min_distance = min_distances_scalar[j];
883
+ current_min_index = min_indices_scalar[j];
884
+ }
885
+ }
886
+ }
887
+
888
+ if (i < ny) {
889
+ // process leftovers (fewer than 8 vectors remain), one 4-float vector per iteration
890
+ __m128 x0 = _mm_loadu_ps(x);
891
+
892
+ for (; i < ny; i++) {
893
+ __m128 accu = ElementOpL2::op(x0, _mm_loadu_ps(y)); // base-class __m128 overload, named directly here
894
+ y += 4;
895
+ const float distance = horizontal_sum(accu);
896
+
897
+ if (current_min_distance > distance) { // only a strictly smaller distance replaces the current best
898
+ current_min_distance = distance;
899
+ current_min_index = i;
900
+ }
901
+ }
902
+ }
903
+
904
+ return current_min_index;
905
+ }
906
+
907
+ template <> // AVX2 kernel: index of the 8-D vector in y with the smallest squared L2 distance to x
908
+ size_t fvec_L2sqr_ny_nearest_D8<SIMDLevel::AVX2>(
909
+ float* /*distances_tmp_buffer*/, // unused here
910
+ const float* x, // query, 8 floats
911
+ const float* y, // ny vectors of 8 floats each, stored back to back
912
+ size_t ny) { // number of vectors in y; the function returns 0 when ny == 0
913
+ // this implementation does not use distances_tmp_buffer.
914
+
915
+ // current index being processed
916
+ size_t i = 0;
917
+
918
+ // min distance and the index of the closest vector so far
919
+ float current_min_distance = HUGE_VALF;
920
+ size_t current_min_index = 0;
921
+
922
+ // process 8 D8-vectors per loop.
923
+ const size_t ny8 = ny / 8;
924
+ if (ny8 > 0) {
925
+ // track min distance and the closest vector independently
926
+ // for each of 8 AVX2 components.
927
+ __m256 min_distances = _mm256_set1_ps(HUGE_VALF);
928
+ __m256i min_indices = _mm256_set1_epi32(0); // NOTE(review): indices live in 32-bit lanes, so this assumes ny fits in 32 bits
929
+
930
+ __m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); // lane k handles vectors k, k + 8, k + 16, ...
931
+ const __m256i indices_increment = _mm256_set1_epi32(8);
932
+
933
+ // 1 value per register
934
+ const __m256 m0 = _mm256_set1_ps(x[0]);
935
+ const __m256 m1 = _mm256_set1_ps(x[1]);
936
+ const __m256 m2 = _mm256_set1_ps(x[2]);
937
+ const __m256 m3 = _mm256_set1_ps(x[3]);
938
+
939
+ const __m256 m4 = _mm256_set1_ps(x[4]);
940
+ const __m256 m5 = _mm256_set1_ps(x[5]);
941
+ const __m256 m6 = _mm256_set1_ps(x[6]);
942
+ const __m256 m7 = _mm256_set1_ps(x[7]);
943
+
944
+ for (; i < ny8 * 8; i += 8) {
945
+ __m256 v0; // presumably vK holds component K of the 8 vectors after the transpose (helper defined elsewhere)
946
+ __m256 v1;
947
+ __m256 v2;
948
+ __m256 v3;
949
+ __m256 v4;
950
+ __m256 v5;
951
+ __m256 v6;
952
+ __m256 v7;
953
+
954
+ transpose_8x8(
955
+ _mm256_loadu_ps(y + 0 * 8),
956
+ _mm256_loadu_ps(y + 1 * 8),
957
+ _mm256_loadu_ps(y + 2 * 8),
958
+ _mm256_loadu_ps(y + 3 * 8),
959
+ _mm256_loadu_ps(y + 4 * 8),
960
+ _mm256_loadu_ps(y + 5 * 8),
961
+ _mm256_loadu_ps(y + 6 * 8),
962
+ _mm256_loadu_ps(y + 7 * 8),
963
+ v0,
964
+ v1,
965
+ v2,
966
+ v3,
967
+ v4,
968
+ v5,
969
+ v6,
970
+ v7);
971
+
972
+ // compute differences
973
+ const __m256 d0 = _mm256_sub_ps(m0, v0);
974
+ const __m256 d1 = _mm256_sub_ps(m1, v1);
975
+ const __m256 d2 = _mm256_sub_ps(m2, v2);
976
+ const __m256 d3 = _mm256_sub_ps(m3, v3);
977
+ const __m256 d4 = _mm256_sub_ps(m4, v4);
978
+ const __m256 d5 = _mm256_sub_ps(m5, v5);
979
+ const __m256 d6 = _mm256_sub_ps(m6, v6);
980
+ const __m256 d7 = _mm256_sub_ps(m7, v7);
981
+
982
+ // compute squares of differences
983
+ __m256 distances = _mm256_mul_ps(d0, d0);
984
+ distances = _mm256_fmadd_ps(d1, d1, distances);
985
+ distances = _mm256_fmadd_ps(d2, d2, distances);
986
+ distances = _mm256_fmadd_ps(d3, d3, distances);
987
+ distances = _mm256_fmadd_ps(d4, d4, distances);
988
+ distances = _mm256_fmadd_ps(d5, d5, distances);
989
+ distances = _mm256_fmadd_ps(d6, d6, distances);
990
+ distances = _mm256_fmadd_ps(d7, d7, distances);
991
+
992
+ // compare the new distances to the min distances
993
+ // for each of 8 AVX2 components.
994
+ __m256 comparison =
995
+ _mm256_cmp_ps(min_distances, distances, _CMP_LT_OS); // lane set where the old minimum is strictly smaller
996
+
997
+ // update min distances and indices with closest vectors if needed.
998
+ min_distances = _mm256_min_ps(distances, min_distances);
999
+ min_indices = _mm256_castps_si256(_mm256_blendv_ps( // keeps min_indices where comparison is set, else takes current_indices
1000
+ _mm256_castsi256_ps(current_indices),
1001
+ _mm256_castsi256_ps(min_indices),
1002
+ comparison)); // NOTE(review): on an exact tie within a lane the newer index replaces the older one
1003
+
1004
+ // update current indices values. Basically, +8 to each of the
1005
+ // 8 AVX2 components.
1006
+ current_indices =
1007
+ _mm256_add_epi32(current_indices, indices_increment);
1008
+
1009
+ // scroll y forward (8 vectors 8 DIM each).
1010
+ y += 64;
1011
+ }
1012
+
1013
+ // dump values and find the minimum distance / minimum index
1014
+ float min_distances_scalar[8];
1015
+ uint32_t min_indices_scalar[8];
1016
+ _mm256_storeu_ps(min_distances_scalar, min_distances);
1017
+ _mm256_storeu_si256((__m256i*)(min_indices_scalar), min_indices);
1018
+
1019
+ for (size_t j = 0; j < 8; j++) { // strict '>' means the lowest-numbered lane wins ties across lanes
1020
+ if (current_min_distance > min_distances_scalar[j]) {
1021
+ current_min_distance = min_distances_scalar[j];
1022
+ current_min_index = min_indices_scalar[j];
1023
+ }
1024
+ }
1025
+ }
1026
+
1027
+ if (i < ny) {
1028
+ // process leftovers (fewer than 8 vectors remain), one 8-float vector per iteration
1029
+ __m256 x0 = _mm256_loadu_ps(x);
1030
+
1031
+ for (; i < ny; i++) {
1032
+ __m256 accu = AVX2ElementOpL2::op(x0, _mm256_loadu_ps(y)); // per-lane squared differences
1033
+ y += 8;
1034
+ const float distance = horizontal_sum(accu);
1035
+
1036
+ if (current_min_distance > distance) { // only a strictly smaller distance replaces the current best
1037
+ current_min_distance = distance;
1038
+ current_min_index = i;
1039
+ }
1040
+ }
1041
+ }
1042
+
1043
+ return current_min_index;
1044
+ }
1045
+
1046
+ template <> // AVX2 entry point: hands the three fixed-dimension AVX2 kernels to the shared x86 routine
1047
+ size_t fvec_L2sqr_ny_nearest<SIMDLevel::AVX2>(
1048
+ float* distances_tmp_buffer, // scratch buffer, forwarded as-is (the D2/D4/D8 kernels in this file ignore it)
1049
+ const float* x, // query vector
1050
+ const float* y, // database vectors
1051
+ size_t d, // dimension
1052
+ size_t ny) { // number of vectors in y
1053
+ return fvec_L2sqr_ny_nearest_x86<SIMDLevel::AVX2>( // defined elsewhere; presumably selects a kernel by d — confirm against its definition
1054
+ distances_tmp_buffer,
1055
+ x,
1056
+ y,
1057
+ d,
1058
+ ny,
1059
+ &fvec_L2sqr_ny_nearest_D2<SIMDLevel::AVX2>,
1060
+ &fvec_L2sqr_ny_nearest_D4<SIMDLevel::AVX2>,
1061
+ &fvec_L2sqr_ny_nearest_D8<SIMDLevel::AVX2>);
1062
+ }
1063
+
1064
+ template <size_t DIM> // nearest-vector search over a dimension-major ("transposed") y, for a compile-time dimension DIM
1065
+ size_t fvec_L2sqr_ny_nearest_y_transposed_D(
1066
+ float* /*distances_tmp_buffer*/, // unused here
1067
+ const float* x, // query, DIM floats
1068
+ const float* y, // component j of vector i is read from y[j * d_offset + i]
1069
+ const float* y_sqlen, // one float per vector; presumably its precomputed squared norm (the "y^2" term below)
1070
+ const size_t d_offset, // distance, in floats, between consecutive components of the same vector
1071
+ size_t ny) { // number of vectors; the function returns 0 when ny == 0
1072
+ // this implementation does not use distances_tmp_buffer.
1073
+
1074
+ // current index being processed
1075
+ size_t i = 0;
1076
+
1077
+ // min distance and the index of the closest vector so far
1078
+ float current_min_distance = HUGE_VALF;
1079
+ size_t current_min_index = 0;
1080
+
1081
+ // process 8 vectors per loop.
1082
+ const size_t ny8 = ny / 8;
1083
+
1084
+ if (ny8 > 0) {
1085
+ // track min distance and the closest vector independently
1086
+ // for each of 8 AVX2 components.
1087
+ __m256 min_distances = _mm256_set1_ps(HUGE_VALF);
1088
+ __m256i min_indices = _mm256_set1_epi32(0); // NOTE(review): indices live in 32-bit lanes, so this assumes ny fits in 32 bits
1089
+
1090
+ __m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); // lane k handles vectors k, k + 8, k + 16, ...
1091
+ const __m256i indices_increment = _mm256_set1_epi32(8);
1092
+
1093
+ // m[j] = (2 * x[j], ... 2 * x[j])
1094
+ __m256 m[DIM];
1095
+ for (size_t j = 0; j < DIM; j++) {
1096
+ m[j] = _mm256_set1_ps(x[j]);
1097
+ m[j] = _mm256_add_ps(m[j], m[j]);
1098
+ }
1099
+
1100
+ for (; i < ny8 * 8; i += 8) {
1101
+ // collect dim 0 for 8 vectors.
1102
+ const __m256 v0 = _mm256_loadu_ps(y + 0 * d_offset);
1103
+ // compute dot products (already doubled, because m[] holds 2 * x)
1104
+ __m256 dp = _mm256_mul_ps(m[0], v0);
1105
+
1106
+ for (size_t j = 1; j < DIM; j++) {
1107
+ // collect dim j for 8 vectors.
1108
+ const __m256 vj = _mm256_loadu_ps(y + j * d_offset);
1109
+ dp = _mm256_fmadd_ps(m[j], vj, dp);
1110
+ }
1111
+
1112
+ // compute y^2 - (2 * x, y), which is sufficient for looking for the
1113
+ // lowest distance.
1114
+ // x^2 is the constant that can be avoided.
1115
+ const __m256 distances =
1116
+ _mm256_sub_ps(_mm256_loadu_ps(y_sqlen), dp);
1117
+
1118
+ // compare the new distances to the min distances
1119
+ // for each of 8 AVX2 components.
1120
+ const __m256 comparison =
1121
+ _mm256_cmp_ps(min_distances, distances, _CMP_LT_OS); // lane set where the old minimum is strictly smaller
1122
+
1123
+ // update min distances and indices with closest vectors if needed.
1124
+ min_distances =
1125
+ _mm256_blendv_ps(distances, min_distances, comparison); // keeps the old value where comparison is set
1126
+ min_indices = _mm256_castps_si256(_mm256_blendv_ps( // same selection applied to the indices
1127
+ _mm256_castsi256_ps(current_indices),
1128
+ _mm256_castsi256_ps(min_indices),
1129
+ comparison)); // NOTE(review): on an exact tie within a lane the newer index replaces the older one
1130
+
1131
+ // update current indices values. Basically, +8 to each of the
1132
+ // 8 AVX2 components.
1133
+ current_indices =
1134
+ _mm256_add_epi32(current_indices, indices_increment);
1135
+
1136
+ // scroll y and y_sqlen forward.
1137
+ y += 8;
1138
+ y_sqlen += 8;
1139
+ }
1140
+
1141
+ // dump values and find the minimum distance / minimum index
1142
+ float min_distances_scalar[8];
1143
+ uint32_t min_indices_scalar[8];
1144
+ _mm256_storeu_ps(min_distances_scalar, min_distances);
1145
+ _mm256_storeu_si256((__m256i*)(min_indices_scalar), min_indices);
1146
+
1147
+ for (size_t j = 0; j < 8; j++) { // strict '>' means the lowest-numbered lane wins ties across lanes
1148
+ if (current_min_distance > min_distances_scalar[j]) {
1149
+ current_min_distance = min_distances_scalar[j];
1150
+ current_min_index = min_indices_scalar[j];
1151
+ }
1152
+ }
1153
+ }
1154
+
1155
+ if (i < ny) {
1156
+ // process leftovers (fewer than 8 vectors remain) with scalar code
1157
+ for (; i < ny; i++) {
1158
+ float dp = 0;
1159
+ for (size_t j = 0; j < DIM; j++) {
1160
+ dp += x[j] * y[j * d_offset];
1161
+ }
1162
+
1163
+ // compute y^2 - 2 * (x, y), which is sufficient for looking for the
1164
+ // lowest distance.
1165
+ const float distance = y_sqlen[0] - 2 * dp;
1166
+
1167
+ if (current_min_distance > distance) { // only a strictly smaller value replaces the current best
1168
+ current_min_distance = distance;
1169
+ current_min_index = i;
1170
+ }
1171
+
1172
+ y += 1;
1173
+ y_sqlen += 1;
1174
+ }
1175
+ }
1176
+
1177
+ return current_min_index;
1178
+ }
1179
+
1180
+ template <> // AVX2 entry point: routes d in {1, 2, 4, 8} to the fixed-DIM kernel, anything else to the SIMDLevel::NONE version
1181
+ size_t fvec_L2sqr_ny_nearest_y_transposed<SIMDLevel::AVX2>(
1182
+ float* distances_tmp_buffer, // forwarded as-is; the fixed-DIM kernel ignores it
1183
+ const float* x, // query vector
1184
+ const float* y, // database in the dimension-major layout expected by the kernels
1185
+ const float* y_sqlen, // per-vector values subtracted from; forwarded unchanged
1186
+ size_t d, // dimension; selects the kernel
1187
+ size_t d_offset, // forwarded unchanged
1188
+ size_t ny) { // number of vectors
1189
+ // optimized for a few special cases
1190
+ #define DISPATCH(dval) \
1191
+ case dval: \
1192
+ return fvec_L2sqr_ny_nearest_y_transposed_D<dval>( \
1193
+ distances_tmp_buffer, x, y, y_sqlen, d_offset, ny);
1194
+
1195
+ switch (d) {
1196
+ DISPATCH(1)
1197
+ DISPATCH(2)
1198
+ DISPATCH(4)
1199
+ DISPATCH(8)
1200
+ default: // no specialised kernel for this d
1201
+ return fvec_L2sqr_ny_nearest_y_transposed<SIMDLevel::NONE>(
1202
+ distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny);
1203
+ }
1204
+ #undef DISPATCH
1205
+ }
1206
+
1207
+ template <> // AVX2 entry point: no AVX2-specific code here, it forwards to the SSE routine
1208
+ int fvec_madd_and_argmin<SIMDLevel::AVX2>(
1209
+ size_t n, // all arguments are passed through unchanged
1210
+ const float* a,
1211
+ float bf,
1212
+ const float* b,
1213
+ float* c) {
1214
+ return fvec_madd_and_argmin_sse(n, a, bf, b, c); // defined elsewhere
1215
+ }
1216
+
1217
+ template <>
1218
+ void exhaustive_L2sqr_blas_cmax<SIMDLevel::AVX2>(
1219
+ const float* x,
1220
+ const float* y,
1221
+ size_t d,
1222
+ size_t nx,
1223
+ size_t ny,
1224
+ Top1BlockResultHandler<CMax<float, int64_t>>& res,
1225
+ const float* y_norms) {
1226
+ // BLAS does not like empty matrices
1227
+ if (nx == 0 || ny == 0) {
1228
+ return;
1229
+ }
1230
+
1231
+ /* block sizes */
1232
+ const size_t bs_x = distance_compute_blas_query_bs;
1233
+ const size_t bs_y = distance_compute_blas_database_bs;
1234
+ // const size_t bs_x = 16, bs_y = 16;
1235
+ std::unique_ptr<float[]> ip_block(new float[bs_x * bs_y]);
1236
+ std::unique_ptr<float[]> x_norms(new float[nx]);
1237
+ std::unique_ptr<float[]> del2;
1238
+
1239
+ fvec_norms_L2sqr(x_norms.get(), x, d, nx);
1240
+
1241
+ if (!y_norms) {
1242
+ float* y_norms2 = new float[ny];
1243
+ del2.reset(y_norms2);
1244
+ fvec_norms_L2sqr(y_norms2, y, d, ny);
1245
+ y_norms = y_norms2;
1246
+ }
1247
+
1248
+ for (size_t i0 = 0; i0 < nx; i0 += bs_x) {
1249
+ size_t i1 = i0 + bs_x;
1250
+ if (i1 > nx) {
1251
+ i1 = nx;
1252
+ }
1253
+
1254
+ res.begin_multiple(i0, i1);
1255
+
1256
+ for (size_t j0 = 0; j0 < ny; j0 += bs_y) {
1257
+ size_t j1 = j0 + bs_y;
1258
+ if (j1 > ny) {
1259
+ j1 = ny;
1260
+ }
1261
+ /* compute the actual dot products */
1262
+ {
1263
+ float one = 1, zero = 0;
1264
+ FINTEGER nyi = j1 - j0, nxi = i1 - i0, di = d;
1265
+ sgemm_("Transpose",
1266
+ "Not transpose",
1267
+ &nyi,
1268
+ &nxi,
1269
+ &di,
1270
+ &one,
1271
+ y + j0 * d,
1272
+ &di,
1273
+ x + i0 * d,
1274
+ &di,
1275
+ &zero,
1276
+ ip_block.get(),
1277
+ &nyi);
1278
+ }
1279
+ #pragma omp parallel for schedule(static) if ((i1 - i0) >= 16)
1280
+ for (int64_t i = static_cast<int64_t>(i0);
1281
+ i < static_cast<int64_t>(i1);
1282
+ i++) {
1283
+ float* ip_line = ip_block.get() + (i - i0) * (j1 - j0);
1284
+
1285
+ _mm_prefetch((const char*)ip_line, _MM_HINT_NTA);
1286
+ _mm_prefetch((const char*)(ip_line + 16), _MM_HINT_NTA);
1287
+
1288
+ // constant
1289
+ const __m256 mul_minus2 = _mm256_set1_ps(-2);
1290
+
1291
+ // Track 8 min distances + 8 min indices.
1292
+ // All the distances tracked do not take x_norms[i]
1293
+ // into account in order to get rid of extra
1294
+ // _mm256_add_ps(x_norms[i], ...) instructions
1295
+ // in distance computations.
1296
+ __m256 min_distances =
1297
+ _mm256_set1_ps(res.dis_tab[i] - x_norms[i]);
1298
+
1299
+ // these indices are local and are relative to j0.
1300
+ // so, value 0 means j0.
1301
+ __m256i min_indices = _mm256_set1_epi32(0);
1302
+
1303
+ __m256i current_indices =
1304
+ _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
1305
+ const __m256i indices_delta = _mm256_set1_epi32(8);
1306
+
1307
+ // current j index
1308
+ size_t idx_j = 0;
1309
+ size_t count = j1 - j0;
1310
+
1311
+ // process 16 elements per loop
1312
+ for (; idx_j < (count / 16) * 16; idx_j += 16, ip_line += 16) {
1313
+ _mm_prefetch((const char*)(ip_line + 32), _MM_HINT_NTA);
1314
+ _mm_prefetch((const char*)(ip_line + 48), _MM_HINT_NTA);
1315
+
1316
+ // load values for norms
1317
+ const __m256 y_norm_0 =
1318
+ _mm256_loadu_ps(y_norms + idx_j + j0 + 0);
1319
+ const __m256 y_norm_1 =
1320
+ _mm256_loadu_ps(y_norms + idx_j + j0 + 8);
1321
+
1322
+ // load values for dot products
1323
+ const __m256 ip_0 = _mm256_loadu_ps(ip_line + 0);
1324
+ const __m256 ip_1 = _mm256_loadu_ps(ip_line + 8);
1325
+
1326
+ // compute dis = y_norm[j] - 2 * dot(x[i], y[j]).
1327
+ // x_norm[i] was dropped off because it is a constant for a
1328
+ // given i. We'll deal with it later.
1329
+ __m256 distances_0 =
1330
+ _mm256_fmadd_ps(ip_0, mul_minus2, y_norm_0);
1331
+ __m256 distances_1 =
1332
+ _mm256_fmadd_ps(ip_1, mul_minus2, y_norm_1);
1333
+
1334
+ // compare the new distances to the min distances
1335
+ // for each of the first group of 8 AVX2 components.
1336
+ const __m256 comparison_0 = _mm256_cmp_ps(
1337
+ min_distances, distances_0, _CMP_LE_OS);
1338
+
1339
+ // update min distances and indices with closest vectors if
1340
+ // needed.
1341
+ min_distances = _mm256_blendv_ps(
1342
+ distances_0, min_distances, comparison_0);
1343
+ min_indices = _mm256_castps_si256(_mm256_blendv_ps(
1344
+ _mm256_castsi256_ps(current_indices),
1345
+ _mm256_castsi256_ps(min_indices),
1346
+ comparison_0));
1347
+ current_indices =
1348
+ _mm256_add_epi32(current_indices, indices_delta);
1349
+
1350
+ // compare the new distances to the min distances
1351
+ // for each of the second group of 8 AVX2 components.
1352
+ const __m256 comparison_1 = _mm256_cmp_ps(
1353
+ min_distances, distances_1, _CMP_LE_OS);
1354
+
1355
+ // update min distances and indices with closest vectors if
1356
+ // needed.
1357
+ min_distances = _mm256_blendv_ps(
1358
+ distances_1, min_distances, comparison_1);
1359
+ min_indices = _mm256_castps_si256(_mm256_blendv_ps(
1360
+ _mm256_castsi256_ps(current_indices),
1361
+ _mm256_castsi256_ps(min_indices),
1362
+ comparison_1));
1363
+ current_indices =
1364
+ _mm256_add_epi32(current_indices, indices_delta);
1365
+ }
1366
+
1367
+ // dump values and find the minimum distance / minimum index
1368
+ float min_distances_scalar[8];
1369
+ uint32_t min_indices_scalar[8];
1370
+ _mm256_storeu_ps(min_distances_scalar, min_distances);
1371
+ _mm256_storeu_si256(
1372
+ (__m256i*)(min_indices_scalar), min_indices);
1373
+
1374
+ float current_min_distance = res.dis_tab[i];
1375
+ uint32_t current_min_index = res.ids_tab[i];
1376
+
1377
+ // This unusual comparison is needed to maintain the behavior
1378
+ // of the original implementation: if two indices are
1379
+ // represented with equal distance values, then
1380
+ // the index with the min value is returned.
1381
+ for (size_t jv = 0; jv < 8; jv++) {
1382
+ // add missing x_norms[i]
1383
+ float distance_candidate =
1384
+ min_distances_scalar[jv] + x_norms[i];
1385
+
1386
+ // negative values can occur for identical vectors
1387
+ // due to roundoff errors.
1388
+ if (distance_candidate < 0) {
1389
+ distance_candidate = 0;
1390
+ }
1391
+
1392
+ int64_t index_candidate = min_indices_scalar[jv] + j0;
1393
+
1394
+ if (current_min_distance > distance_candidate) {
1395
+ current_min_distance = distance_candidate;
1396
+ current_min_index = index_candidate;
1397
+ } else if (
1398
+ current_min_distance == distance_candidate &&
1399
+ current_min_index > index_candidate) {
1400
+ current_min_index = index_candidate;
1401
+ }
1402
+ }
1403
+
1404
+ // process leftovers
1405
+ for (; idx_j < count; idx_j++, ip_line++) {
1406
+ float ip = *ip_line;
1407
+ float dis = x_norms[i] + y_norms[idx_j + j0] - 2 * ip;
1408
+ // negative values can occur for identical vectors
1409
+ // due to roundoff errors.
1410
+ if (dis < 0) {
1411
+ dis = 0;
1412
+ }
1413
+
1414
+ if (current_min_distance > dis) {
1415
+ current_min_distance = dis;
1416
+ current_min_index = idx_j + j0;
1417
+ }
1418
+ }
1419
+
1420
+ //
1421
+ res.add_result(i, current_min_distance, current_min_index);
1422
+ }
1423
+ }
1424
+ // Does nothing for SingleBestResultHandler, but
1425
+ // keeping the call for the consistency.
1426
+ res.end_multiple();
1427
+ InterruptCallback::check();
1428
+ }
1429
+ }
1430
+
1431
+ } // namespace faiss