faiss 0.3.1 → 0.3.3

Sign up to get free protection for your applications and to get access to all the features.
Files changed (293) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +8 -0
  3. data/LICENSE.txt +1 -1
  4. data/lib/faiss/version.rb +1 -1
  5. data/vendor/faiss/faiss/AutoTune.cpp +2 -2
  6. data/vendor/faiss/faiss/AutoTune.h +3 -3
  7. data/vendor/faiss/faiss/Clustering.cpp +37 -6
  8. data/vendor/faiss/faiss/Clustering.h +12 -3
  9. data/vendor/faiss/faiss/IVFlib.cpp +6 -3
  10. data/vendor/faiss/faiss/IVFlib.h +2 -2
  11. data/vendor/faiss/faiss/Index.cpp +6 -2
  12. data/vendor/faiss/faiss/Index.h +30 -8
  13. data/vendor/faiss/faiss/Index2Layer.cpp +2 -2
  14. data/vendor/faiss/faiss/Index2Layer.h +2 -2
  15. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +7 -7
  16. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +2 -2
  17. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +14 -16
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +2 -2
  19. data/vendor/faiss/faiss/IndexBinary.cpp +13 -2
  20. data/vendor/faiss/faiss/IndexBinary.h +8 -2
  21. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +2 -3
  22. data/vendor/faiss/faiss/IndexBinaryFlat.h +2 -2
  23. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -2
  24. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -2
  25. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +2 -7
  26. data/vendor/faiss/faiss/IndexBinaryHNSW.h +3 -3
  27. data/vendor/faiss/faiss/IndexBinaryHash.cpp +2 -3
  28. data/vendor/faiss/faiss/IndexBinaryHash.h +2 -2
  29. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +3 -3
  30. data/vendor/faiss/faiss/IndexBinaryIVF.h +2 -2
  31. data/vendor/faiss/faiss/IndexFastScan.cpp +32 -18
  32. data/vendor/faiss/faiss/IndexFastScan.h +11 -2
  33. data/vendor/faiss/faiss/IndexFlat.cpp +13 -10
  34. data/vendor/faiss/faiss/IndexFlat.h +2 -2
  35. data/vendor/faiss/faiss/IndexFlatCodes.cpp +170 -7
  36. data/vendor/faiss/faiss/IndexFlatCodes.h +25 -5
  37. data/vendor/faiss/faiss/IndexHNSW.cpp +156 -96
  38. data/vendor/faiss/faiss/IndexHNSW.h +54 -5
  39. data/vendor/faiss/faiss/IndexIDMap.cpp +19 -3
  40. data/vendor/faiss/faiss/IndexIDMap.h +5 -2
  41. data/vendor/faiss/faiss/IndexIVF.cpp +5 -6
  42. data/vendor/faiss/faiss/IndexIVF.h +13 -4
  43. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +21 -7
  44. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +5 -2
  45. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +3 -14
  46. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +2 -4
  47. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +201 -91
  48. data/vendor/faiss/faiss/IndexIVFFastScan.h +33 -9
  49. data/vendor/faiss/faiss/IndexIVFFlat.cpp +2 -2
  50. data/vendor/faiss/faiss/IndexIVFFlat.h +2 -2
  51. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +2 -2
  52. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +2 -2
  53. data/vendor/faiss/faiss/IndexIVFPQ.cpp +3 -6
  54. data/vendor/faiss/faiss/IndexIVFPQ.h +2 -2
  55. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +7 -14
  56. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +2 -4
  57. data/vendor/faiss/faiss/IndexIVFPQR.cpp +2 -2
  58. data/vendor/faiss/faiss/IndexIVFPQR.h +2 -2
  59. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +2 -3
  60. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +2 -2
  61. data/vendor/faiss/faiss/IndexLSH.cpp +2 -3
  62. data/vendor/faiss/faiss/IndexLSH.h +2 -2
  63. data/vendor/faiss/faiss/IndexLattice.cpp +3 -21
  64. data/vendor/faiss/faiss/IndexLattice.h +5 -24
  65. data/vendor/faiss/faiss/IndexNNDescent.cpp +2 -31
  66. data/vendor/faiss/faiss/IndexNNDescent.h +3 -3
  67. data/vendor/faiss/faiss/IndexNSG.cpp +2 -5
  68. data/vendor/faiss/faiss/IndexNSG.h +3 -3
  69. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +56 -0
  70. data/vendor/faiss/faiss/IndexNeuralNetCodec.h +49 -0
  71. data/vendor/faiss/faiss/IndexPQ.cpp +26 -26
  72. data/vendor/faiss/faiss/IndexPQ.h +2 -2
  73. data/vendor/faiss/faiss/IndexPQFastScan.cpp +2 -5
  74. data/vendor/faiss/faiss/IndexPQFastScan.h +2 -11
  75. data/vendor/faiss/faiss/IndexPreTransform.cpp +2 -2
  76. data/vendor/faiss/faiss/IndexPreTransform.h +3 -3
  77. data/vendor/faiss/faiss/IndexRefine.cpp +46 -9
  78. data/vendor/faiss/faiss/IndexRefine.h +9 -2
  79. data/vendor/faiss/faiss/IndexReplicas.cpp +2 -2
  80. data/vendor/faiss/faiss/IndexReplicas.h +2 -2
  81. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +2 -2
  82. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +2 -2
  83. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +5 -4
  84. data/vendor/faiss/faiss/IndexScalarQuantizer.h +2 -2
  85. data/vendor/faiss/faiss/IndexShards.cpp +2 -2
  86. data/vendor/faiss/faiss/IndexShards.h +2 -2
  87. data/vendor/faiss/faiss/IndexShardsIVF.cpp +2 -2
  88. data/vendor/faiss/faiss/IndexShardsIVF.h +2 -2
  89. data/vendor/faiss/faiss/MatrixStats.cpp +2 -2
  90. data/vendor/faiss/faiss/MatrixStats.h +2 -2
  91. data/vendor/faiss/faiss/MetaIndexes.cpp +2 -3
  92. data/vendor/faiss/faiss/MetaIndexes.h +2 -2
  93. data/vendor/faiss/faiss/MetricType.h +9 -4
  94. data/vendor/faiss/faiss/VectorTransform.cpp +2 -2
  95. data/vendor/faiss/faiss/VectorTransform.h +2 -2
  96. data/vendor/faiss/faiss/clone_index.cpp +2 -2
  97. data/vendor/faiss/faiss/clone_index.h +2 -2
  98. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +2 -2
  99. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +2 -2
  100. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +97 -19
  101. data/vendor/faiss/faiss/cppcontrib/factory_tools.cpp +192 -0
  102. data/vendor/faiss/faiss/cppcontrib/factory_tools.h +29 -0
  103. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2 -2
  104. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +85 -32
  105. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2 -2
  106. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +2 -2
  107. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +2 -2
  108. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +2 -2
  109. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +2 -2
  110. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +2 -2
  111. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +2 -5
  112. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +2 -2
  113. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +45 -13
  114. data/vendor/faiss/faiss/gpu/GpuCloner.h +2 -2
  115. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +12 -6
  116. data/vendor/faiss/faiss/gpu/GpuDistance.h +11 -7
  117. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +3 -3
  118. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +2 -2
  119. data/vendor/faiss/faiss/gpu/GpuIndex.h +10 -15
  120. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -2
  121. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +285 -0
  122. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +2 -2
  123. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +8 -2
  124. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +4 -2
  125. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +3 -3
  126. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +2 -2
  127. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +2 -2
  128. data/vendor/faiss/faiss/gpu/GpuResources.cpp +7 -2
  129. data/vendor/faiss/faiss/gpu/GpuResources.h +11 -4
  130. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +66 -11
  131. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +15 -5
  132. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +2 -2
  133. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +28 -23
  134. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +2 -2
  135. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +2 -2
  136. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +2 -2
  137. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +2 -2
  138. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +2 -2
  139. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +8 -2
  140. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +2 -3
  141. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +2 -2
  142. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +10 -7
  143. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +2 -2
  144. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +54 -54
  145. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +144 -77
  146. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +51 -51
  147. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +2 -2
  148. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +3 -3
  149. data/vendor/faiss/faiss/gpu/test/TestGpuResidualQuantizer.cpp +70 -0
  150. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +74 -4
  151. data/vendor/faiss/faiss/gpu/test/TestUtils.h +2 -2
  152. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +3 -3
  153. data/vendor/faiss/faiss/gpu/utils/{RaftUtils.h → CuvsUtils.h} +12 -11
  154. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +8 -2
  155. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +2 -2
  156. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +2 -2
  157. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +2 -2
  158. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +6 -3
  159. data/vendor/faiss/faiss/gpu/utils/Timer.h +3 -3
  160. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +79 -11
  161. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +17 -5
  162. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +27 -2
  163. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +11 -3
  164. data/vendor/faiss/faiss/impl/CodePacker.cpp +2 -2
  165. data/vendor/faiss/faiss/impl/CodePacker.h +2 -2
  166. data/vendor/faiss/faiss/impl/DistanceComputer.h +48 -2
  167. data/vendor/faiss/faiss/impl/FaissAssert.h +6 -4
  168. data/vendor/faiss/faiss/impl/FaissException.cpp +2 -2
  169. data/vendor/faiss/faiss/impl/FaissException.h +2 -3
  170. data/vendor/faiss/faiss/impl/HNSW.cpp +378 -205
  171. data/vendor/faiss/faiss/impl/HNSW.h +55 -24
  172. data/vendor/faiss/faiss/impl/IDSelector.cpp +2 -2
  173. data/vendor/faiss/faiss/impl/IDSelector.h +2 -2
  174. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +10 -10
  175. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +2 -2
  176. data/vendor/faiss/faiss/impl/LookupTableScaler.h +36 -2
  177. data/vendor/faiss/faiss/impl/NNDescent.cpp +15 -10
  178. data/vendor/faiss/faiss/impl/NNDescent.h +2 -2
  179. data/vendor/faiss/faiss/impl/NSG.cpp +26 -49
  180. data/vendor/faiss/faiss/impl/NSG.h +20 -8
  181. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +2 -2
  182. data/vendor/faiss/faiss/impl/PolysemousTraining.h +2 -2
  183. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +2 -4
  184. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +2 -2
  185. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +2 -2
  186. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +3 -2
  187. data/vendor/faiss/faiss/impl/ProductQuantizer.h +7 -3
  188. data/vendor/faiss/faiss/impl/Quantizer.h +2 -2
  189. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +2 -36
  190. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +3 -13
  191. data/vendor/faiss/faiss/impl/ResultHandler.h +153 -34
  192. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +721 -104
  193. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +5 -2
  194. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +2 -2
  195. data/vendor/faiss/faiss/impl/ThreadedIndex.h +2 -2
  196. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +7 -2
  197. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx512.h +248 -0
  198. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +2 -2
  199. data/vendor/faiss/faiss/impl/code_distance/code_distance-sve.h +440 -0
  200. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +55 -2
  201. data/vendor/faiss/faiss/impl/index_read.cpp +31 -20
  202. data/vendor/faiss/faiss/impl/index_read_utils.h +37 -0
  203. data/vendor/faiss/faiss/impl/index_write.cpp +30 -16
  204. data/vendor/faiss/faiss/impl/io.cpp +15 -7
  205. data/vendor/faiss/faiss/impl/io.h +6 -6
  206. data/vendor/faiss/faiss/impl/io_macros.h +8 -9
  207. data/vendor/faiss/faiss/impl/kmeans1d.cpp +2 -3
  208. data/vendor/faiss/faiss/impl/kmeans1d.h +2 -2
  209. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +2 -3
  210. data/vendor/faiss/faiss/impl/lattice_Zn.h +2 -2
  211. data/vendor/faiss/faiss/impl/platform_macros.h +34 -2
  212. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +13 -2
  213. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +20 -2
  214. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +3 -3
  215. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +450 -3
  216. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +8 -8
  217. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +3 -3
  218. data/vendor/faiss/faiss/impl/simd_result_handlers.h +151 -67
  219. data/vendor/faiss/faiss/index_factory.cpp +51 -34
  220. data/vendor/faiss/faiss/index_factory.h +2 -2
  221. data/vendor/faiss/faiss/index_io.h +14 -7
  222. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +30 -10
  223. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +5 -2
  224. data/vendor/faiss/faiss/invlists/DirectMap.cpp +11 -3
  225. data/vendor/faiss/faiss/invlists/DirectMap.h +2 -2
  226. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +57 -19
  227. data/vendor/faiss/faiss/invlists/InvertedLists.h +20 -11
  228. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +2 -2
  229. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +2 -2
  230. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +23 -9
  231. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +4 -3
  232. data/vendor/faiss/faiss/python/python_callbacks.cpp +5 -5
  233. data/vendor/faiss/faiss/python/python_callbacks.h +2 -2
  234. data/vendor/faiss/faiss/utils/AlignedTable.h +5 -3
  235. data/vendor/faiss/faiss/utils/Heap.cpp +2 -2
  236. data/vendor/faiss/faiss/utils/Heap.h +107 -2
  237. data/vendor/faiss/faiss/utils/NeuralNet.cpp +346 -0
  238. data/vendor/faiss/faiss/utils/NeuralNet.h +147 -0
  239. data/vendor/faiss/faiss/utils/WorkerThread.cpp +2 -2
  240. data/vendor/faiss/faiss/utils/WorkerThread.h +2 -2
  241. data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +2 -2
  242. data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +2 -2
  243. data/vendor/faiss/faiss/utils/approx_topk/generic.h +2 -2
  244. data/vendor/faiss/faiss/utils/approx_topk/mode.h +2 -2
  245. data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +2 -2
  246. data/vendor/faiss/faiss/utils/bf16.h +36 -0
  247. data/vendor/faiss/faiss/utils/distances.cpp +249 -90
  248. data/vendor/faiss/faiss/utils/distances.h +8 -8
  249. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +2 -2
  250. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
  251. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
  252. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +2 -2
  253. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +2 -2
  254. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +2 -2
  255. data/vendor/faiss/faiss/utils/distances_simd.cpp +1543 -56
  256. data/vendor/faiss/faiss/utils/extra_distances-inl.h +72 -2
  257. data/vendor/faiss/faiss/utils/extra_distances.cpp +87 -140
  258. data/vendor/faiss/faiss/utils/extra_distances.h +5 -4
  259. data/vendor/faiss/faiss/utils/fp16-arm.h +2 -2
  260. data/vendor/faiss/faiss/utils/fp16-fp16c.h +2 -2
  261. data/vendor/faiss/faiss/utils/fp16-inl.h +2 -2
  262. data/vendor/faiss/faiss/utils/fp16.h +2 -2
  263. data/vendor/faiss/faiss/utils/hamming-inl.h +2 -2
  264. data/vendor/faiss/faiss/utils/hamming.cpp +3 -4
  265. data/vendor/faiss/faiss/utils/hamming.h +2 -2
  266. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +2 -2
  267. data/vendor/faiss/faiss/utils/hamming_distance/avx512-inl.h +490 -0
  268. data/vendor/faiss/faiss/utils/hamming_distance/common.h +2 -2
  269. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +6 -3
  270. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +7 -3
  271. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +5 -5
  272. data/vendor/faiss/faiss/utils/ordered_key_value.h +2 -2
  273. data/vendor/faiss/faiss/utils/partitioning.cpp +2 -2
  274. data/vendor/faiss/faiss/utils/partitioning.h +2 -2
  275. data/vendor/faiss/faiss/utils/prefetch.h +2 -2
  276. data/vendor/faiss/faiss/utils/quantize_lut.cpp +2 -2
  277. data/vendor/faiss/faiss/utils/quantize_lut.h +2 -2
  278. data/vendor/faiss/faiss/utils/random.cpp +45 -2
  279. data/vendor/faiss/faiss/utils/random.h +27 -2
  280. data/vendor/faiss/faiss/utils/simdlib.h +12 -3
  281. data/vendor/faiss/faiss/utils/simdlib_avx2.h +2 -2
  282. data/vendor/faiss/faiss/utils/simdlib_avx512.h +296 -0
  283. data/vendor/faiss/faiss/utils/simdlib_emulated.h +2 -2
  284. data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -4
  285. data/vendor/faiss/faiss/utils/simdlib_ppc64.h +1084 -0
  286. data/vendor/faiss/faiss/utils/sorting.cpp +2 -2
  287. data/vendor/faiss/faiss/utils/sorting.h +2 -2
  288. data/vendor/faiss/faiss/utils/transpose/transpose-avx2-inl.h +2 -2
  289. data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +176 -0
  290. data/vendor/faiss/faiss/utils/utils.cpp +17 -10
  291. data/vendor/faiss/faiss/utils/utils.h +7 -3
  292. metadata +22 -11
  293. data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +0 -102
@@ -1,5 +1,5 @@
1
- /**
2
- * Copyright (c) Facebook, Inc. and its affiliates.
1
+ /*
2
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3
3
  *
4
4
  * This source code is licensed under the MIT license found in the
5
5
  * LICENSE file in the root directory of this source tree.
@@ -23,6 +23,7 @@
23
23
  #include <faiss/impl/AuxIndexStructures.h>
24
24
  #include <faiss/impl/FaissAssert.h>
25
25
  #include <faiss/impl/IDSelector.h>
26
+ #include <faiss/utils/bf16.h>
26
27
  #include <faiss/utils/fp16.h>
27
28
  #include <faiss/utils/utils.h>
28
29
 
@@ -43,7 +44,9 @@ namespace faiss {
43
44
  * that hides the template mess.
44
45
  ********************************************************************/
45
46
 
46
- #ifdef __AVX2__
47
+ #if defined(__AVX512F__) && defined(__F16C__)
48
+ #define USE_AVX512_F16C
49
+ #elif defined(__AVX2__)
47
50
  #ifdef __F16C__
48
51
  #define USE_F16C
49
52
  #else
@@ -52,6 +55,15 @@ namespace faiss {
52
55
  #endif
53
56
  #endif
54
57
 
58
+ #if defined(__aarch64__)
59
+ #if defined(__GNUC__) && __GNUC__ < 8
60
+ #warning \
61
+ "Cannot enable NEON optimizations in scalar quantizer if the compiler is GCC<8"
62
+ #else
63
+ #define USE_NEON
64
+ #endif
65
+ #endif
66
+
55
67
  namespace {
56
68
 
57
69
  typedef ScalarQuantizer::QuantizerType QuantizerType;
@@ -78,7 +90,17 @@ struct Codec8bit {
78
90
  return (code[i] + 0.5f) / 255.0f;
79
91
  }
80
92
 
81
- #ifdef __AVX2__
93
+ #if defined(__AVX512F__)
94
+ static FAISS_ALWAYS_INLINE __m512
95
+ decode_16_components(const uint8_t* code, int i) {
96
+ const __m128i c16 = _mm_loadu_si128((__m128i*)(code + i));
97
+ const __m512i i32 = _mm512_cvtepu8_epi32(c16);
98
+ const __m512 f16 = _mm512_cvtepi32_ps(i32);
99
+ const __m512 half_one_255 = _mm512_set1_ps(0.5f / 255.f);
100
+ const __m512 one_255 = _mm512_set1_ps(1.f / 255.f);
101
+ return _mm512_fmadd_ps(f16, one_255, half_one_255);
102
+ }
103
+ #elif defined(__AVX2__)
82
104
  static FAISS_ALWAYS_INLINE __m256
83
105
  decode_8_components(const uint8_t* code, int i) {
84
106
  const uint64_t c8 = *(uint64_t*)(code + i);
@@ -92,7 +114,7 @@ struct Codec8bit {
92
114
  }
93
115
  #endif
94
116
 
95
- #ifdef __aarch64__
117
+ #ifdef USE_NEON
96
118
  static FAISS_ALWAYS_INLINE float32x4x2_t
97
119
  decode_8_components(const uint8_t* code, int i) {
98
120
  float32_t result[8] = {};
@@ -101,8 +123,7 @@ struct Codec8bit {
101
123
  }
102
124
  float32x4_t res1 = vld1q_f32(result);
103
125
  float32x4_t res2 = vld1q_f32(result + 4);
104
- float32x4x2_t res = vzipq_f32(res1, res2);
105
- return vuzpq_f32(res.val[0], res.val[1]);
126
+ return {res1, res2};
106
127
  }
107
128
  #endif
108
129
  };
@@ -121,7 +142,26 @@ struct Codec4bit {
121
142
  return (((code[i / 2] >> ((i & 1) << 2)) & 0xf) + 0.5f) / 15.0f;
122
143
  }
123
144
 
124
- #ifdef __AVX2__
145
+ #if defined(__AVX512F__)
146
+ static FAISS_ALWAYS_INLINE __m512
147
+ decode_16_components(const uint8_t* code, int i) {
148
+ uint64_t c8 = *(uint64_t*)(code + (i >> 1));
149
+ uint64_t mask = 0x0f0f0f0f0f0f0f0f;
150
+ uint64_t c8ev = c8 & mask;
151
+ uint64_t c8od = (c8 >> 4) & mask;
152
+
153
+ __m128i c16 =
154
+ _mm_unpacklo_epi8(_mm_set1_epi64x(c8ev), _mm_set1_epi64x(c8od));
155
+ __m256i c8lo = _mm256_cvtepu8_epi32(c16);
156
+ __m256i c8hi = _mm256_cvtepu8_epi32(_mm_srli_si128(c16, 8));
157
+ __m512i i16 = _mm512_castsi256_si512(c8lo);
158
+ i16 = _mm512_inserti32x8(i16, c8hi, 1);
159
+ __m512 f16 = _mm512_cvtepi32_ps(i16);
160
+ const __m512 half_one_255 = _mm512_set1_ps(0.5f / 15.f);
161
+ const __m512 one_255 = _mm512_set1_ps(1.f / 15.f);
162
+ return _mm512_fmadd_ps(f16, one_255, half_one_255);
163
+ }
164
+ #elif defined(__AVX2__)
125
165
  static FAISS_ALWAYS_INLINE __m256
126
166
  decode_8_components(const uint8_t* code, int i) {
127
167
  uint32_t c4 = *(uint32_t*)(code + (i >> 1));
@@ -144,7 +184,7 @@ struct Codec4bit {
144
184
  }
145
185
  #endif
146
186
 
147
- #ifdef __aarch64__
187
+ #ifdef USE_NEON
148
188
  static FAISS_ALWAYS_INLINE float32x4x2_t
149
189
  decode_8_components(const uint8_t* code, int i) {
150
190
  float32_t result[8] = {};
@@ -153,8 +193,7 @@ struct Codec4bit {
153
193
  }
154
194
  float32x4_t res1 = vld1q_f32(result);
155
195
  float32x4_t res2 = vld1q_f32(result + 4);
156
- float32x4x2_t res = vzipq_f32(res1, res2);
157
- return vuzpq_f32(res.val[0], res.val[1]);
196
+ return {res1, res2};
158
197
  }
159
198
  #endif
160
199
  };
@@ -208,7 +247,56 @@ struct Codec6bit {
208
247
  return (bits + 0.5f) / 63.0f;
209
248
  }
210
249
 
211
- #ifdef __AVX2__
250
+ #if defined(__AVX512F__)
251
+
252
+ static FAISS_ALWAYS_INLINE __m512
253
+ decode_16_components(const uint8_t* code, int i) {
254
+ // pure AVX512 implementation (not necessarily the fastest).
255
+ // see:
256
+ // https://github.com/zilliztech/knowhere/blob/main/thirdparty/faiss/faiss/impl/ScalarQuantizerCodec_avx512.h
257
+
258
+ // clang-format off
259
+
260
+ // 16 components, 16x6 bit=12 bytes
261
+ const __m128i bit_6v =
262
+ _mm_maskz_loadu_epi8(0b0000111111111111, code + (i >> 2) * 3);
263
+ const __m256i bit_6v_256 = _mm256_broadcast_i32x4(bit_6v);
264
+
265
+ // 00 01 02 03 04 05 06 07 08 09 0A 0B 0C 0D 0E 0F
266
+ // 00 01 02 03
267
+ const __m256i shuffle_mask = _mm256_setr_epi16(
268
+ 0xFF00, 0x0100, 0x0201, 0xFF02,
269
+ 0xFF03, 0x0403, 0x0504, 0xFF05,
270
+ 0xFF06, 0x0706, 0x0807, 0xFF08,
271
+ 0xFF09, 0x0A09, 0x0B0A, 0xFF0B);
272
+ const __m256i shuffled = _mm256_shuffle_epi8(bit_6v_256, shuffle_mask);
273
+
274
+ // 0: xxxxxxxx xx543210
275
+ // 1: xxxx5432 10xxxxxx
276
+ // 2: xxxxxx54 3210xxxx
277
+ // 3: xxxxxxxx 543210xx
278
+ const __m256i shift_right_v = _mm256_setr_epi16(
279
+ 0x0U, 0x6U, 0x4U, 0x2U,
280
+ 0x0U, 0x6U, 0x4U, 0x2U,
281
+ 0x0U, 0x6U, 0x4U, 0x2U,
282
+ 0x0U, 0x6U, 0x4U, 0x2U);
283
+ __m256i shuffled_shifted = _mm256_srlv_epi16(shuffled, shift_right_v);
284
+
285
+ // remove unneeded bits
286
+ shuffled_shifted =
287
+ _mm256_and_si256(shuffled_shifted, _mm256_set1_epi16(0x003F));
288
+
289
+ // scale
290
+ const __m512 f8 =
291
+ _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(shuffled_shifted));
292
+ const __m512 half_one_255 = _mm512_set1_ps(0.5f / 63.f);
293
+ const __m512 one_255 = _mm512_set1_ps(1.f / 63.f);
294
+ return _mm512_fmadd_ps(f8, one_255, half_one_255);
295
+
296
+ // clang-format on
297
+ }
298
+
299
+ #elif defined(__AVX2__)
212
300
 
213
301
  /* Load 6 bytes that represent 8 6-bit values, return them as a
214
302
  * 8*32 bit vector register */
@@ -257,7 +345,7 @@ struct Codec6bit {
257
345
 
258
346
  #endif
259
347
 
260
- #ifdef __aarch64__
348
+ #ifdef USE_NEON
261
349
  static FAISS_ALWAYS_INLINE float32x4x2_t
262
350
  decode_8_components(const uint8_t* code, int i) {
263
351
  float32_t result[8] = {};
@@ -266,8 +354,7 @@ struct Codec6bit {
266
354
  }
267
355
  float32x4_t res1 = vld1q_f32(result);
268
356
  float32x4_t res2 = vld1q_f32(result + 4);
269
- float32x4x2_t res = vzipq_f32(res1, res2);
270
- return vuzpq_f32(res.val[0], res.val[1]);
357
+ return {res1, res2};
271
358
  }
272
359
  #endif
273
360
  };
@@ -277,11 +364,14 @@ struct Codec6bit {
277
364
  * through a codec
278
365
  *******************************************************************/
279
366
 
280
- template <class Codec, bool uniform, int SIMD>
367
+ enum class QuantizerTemplateScaling { UNIFORM = 0, NON_UNIFORM = 1 };
368
+
369
+ template <class Codec, QuantizerTemplateScaling SCALING, int SIMD>
281
370
  struct QuantizerTemplate {};
282
371
 
283
372
  template <class Codec>
284
- struct QuantizerTemplate<Codec, true, 1> : ScalarQuantizer::SQuantizer {
373
+ struct QuantizerTemplate<Codec, QuantizerTemplateScaling::UNIFORM, 1>
374
+ : ScalarQuantizer::SQuantizer {
285
375
  const size_t d;
286
376
  const float vmin, vdiff;
287
377
 
@@ -318,12 +408,33 @@ struct QuantizerTemplate<Codec, true, 1> : ScalarQuantizer::SQuantizer {
318
408
  }
319
409
  };
320
410
 
321
- #ifdef __AVX2__
411
+ #if defined(__AVX512F__)
412
+
413
+ template <class Codec>
414
+ struct QuantizerTemplate<Codec, QuantizerTemplateScaling::UNIFORM, 16>
415
+ : QuantizerTemplate<Codec, QuantizerTemplateScaling::UNIFORM, 1> {
416
+ QuantizerTemplate(size_t d, const std::vector<float>& trained)
417
+ : QuantizerTemplate<Codec, QuantizerTemplateScaling::UNIFORM, 1>(
418
+ d,
419
+ trained) {}
420
+
421
+ FAISS_ALWAYS_INLINE __m512
422
+ reconstruct_16_components(const uint8_t* code, int i) const {
423
+ __m512 xi = Codec::decode_16_components(code, i);
424
+ return _mm512_fmadd_ps(
425
+ xi, _mm512_set1_ps(this->vdiff), _mm512_set1_ps(this->vmin));
426
+ }
427
+ };
428
+
429
+ #elif defined(__AVX2__)
322
430
 
323
431
  template <class Codec>
324
- struct QuantizerTemplate<Codec, true, 8> : QuantizerTemplate<Codec, true, 1> {
432
+ struct QuantizerTemplate<Codec, QuantizerTemplateScaling::UNIFORM, 8>
433
+ : QuantizerTemplate<Codec, QuantizerTemplateScaling::UNIFORM, 1> {
325
434
  QuantizerTemplate(size_t d, const std::vector<float>& trained)
326
- : QuantizerTemplate<Codec, true, 1>(d, trained) {}
435
+ : QuantizerTemplate<Codec, QuantizerTemplateScaling::UNIFORM, 1>(
436
+ d,
437
+ trained) {}
327
438
 
328
439
  FAISS_ALWAYS_INLINE __m256
329
440
  reconstruct_8_components(const uint8_t* code, int i) const {
@@ -335,33 +446,35 @@ struct QuantizerTemplate<Codec, true, 8> : QuantizerTemplate<Codec, true, 1> {
335
446
 
336
447
  #endif
337
448
 
338
- #ifdef __aarch64__
449
+ #ifdef USE_NEON
339
450
 
340
451
  template <class Codec>
341
- struct QuantizerTemplate<Codec, true, 8> : QuantizerTemplate<Codec, true, 1> {
452
+ struct QuantizerTemplate<Codec, QuantizerTemplateScaling::UNIFORM, 8>
453
+ : QuantizerTemplate<Codec, QuantizerTemplateScaling::UNIFORM, 1> {
342
454
  QuantizerTemplate(size_t d, const std::vector<float>& trained)
343
- : QuantizerTemplate<Codec, true, 1>(d, trained) {}
455
+ : QuantizerTemplate<Codec, QuantizerTemplateScaling::UNIFORM, 1>(
456
+ d,
457
+ trained) {}
344
458
 
345
459
  FAISS_ALWAYS_INLINE float32x4x2_t
346
460
  reconstruct_8_components(const uint8_t* code, int i) const {
347
461
  float32x4x2_t xi = Codec::decode_8_components(code, i);
348
- float32x4x2_t res = vzipq_f32(
349
- vfmaq_f32(
462
+ return {vfmaq_f32(
350
463
  vdupq_n_f32(this->vmin),
351
464
  xi.val[0],
352
465
  vdupq_n_f32(this->vdiff)),
353
466
  vfmaq_f32(
354
467
  vdupq_n_f32(this->vmin),
355
468
  xi.val[1],
356
- vdupq_n_f32(this->vdiff)));
357
- return vuzpq_f32(res.val[0], res.val[1]);
469
+ vdupq_n_f32(this->vdiff))};
358
470
  }
359
471
  };
360
472
 
361
473
  #endif
362
474
 
363
475
  template <class Codec>
364
- struct QuantizerTemplate<Codec, false, 1> : ScalarQuantizer::SQuantizer {
476
+ struct QuantizerTemplate<Codec, QuantizerTemplateScaling::NON_UNIFORM, 1>
477
+ : ScalarQuantizer::SQuantizer {
365
478
  const size_t d;
366
479
  const float *vmin, *vdiff;
367
480
 
@@ -398,12 +511,37 @@ struct QuantizerTemplate<Codec, false, 1> : ScalarQuantizer::SQuantizer {
398
511
  }
399
512
  };
400
513
 
401
- #ifdef __AVX2__
514
+ #if defined(__AVX512F__)
515
+
516
+ template <class Codec>
517
+ struct QuantizerTemplate<Codec, QuantizerTemplateScaling::NON_UNIFORM, 16>
518
+ : QuantizerTemplate<Codec, QuantizerTemplateScaling::NON_UNIFORM, 1> {
519
+ QuantizerTemplate(size_t d, const std::vector<float>& trained)
520
+ : QuantizerTemplate<
521
+ Codec,
522
+ QuantizerTemplateScaling::NON_UNIFORM,
523
+ 1>(d, trained) {}
524
+
525
+ FAISS_ALWAYS_INLINE __m512
526
+ reconstruct_16_components(const uint8_t* code, int i) const {
527
+ __m512 xi = Codec::decode_16_components(code, i);
528
+ return _mm512_fmadd_ps(
529
+ xi,
530
+ _mm512_loadu_ps(this->vdiff + i),
531
+ _mm512_loadu_ps(this->vmin + i));
532
+ }
533
+ };
534
+
535
+ #elif defined(__AVX2__)
402
536
 
403
537
  template <class Codec>
404
- struct QuantizerTemplate<Codec, false, 8> : QuantizerTemplate<Codec, false, 1> {
538
+ struct QuantizerTemplate<Codec, QuantizerTemplateScaling::NON_UNIFORM, 8>
539
+ : QuantizerTemplate<Codec, QuantizerTemplateScaling::NON_UNIFORM, 1> {
405
540
  QuantizerTemplate(size_t d, const std::vector<float>& trained)
406
- : QuantizerTemplate<Codec, false, 1>(d, trained) {}
541
+ : QuantizerTemplate<
542
+ Codec,
543
+ QuantizerTemplateScaling::NON_UNIFORM,
544
+ 1>(d, trained) {}
407
545
 
408
546
  FAISS_ALWAYS_INLINE __m256
409
547
  reconstruct_8_components(const uint8_t* code, int i) const {
@@ -417,12 +555,16 @@ struct QuantizerTemplate<Codec, false, 8> : QuantizerTemplate<Codec, false, 1> {
417
555
 
418
556
  #endif
419
557
 
420
- #ifdef __aarch64__
558
+ #ifdef USE_NEON
421
559
 
422
560
  template <class Codec>
423
- struct QuantizerTemplate<Codec, false, 8> : QuantizerTemplate<Codec, false, 1> {
561
+ struct QuantizerTemplate<Codec, QuantizerTemplateScaling::NON_UNIFORM, 8>
562
+ : QuantizerTemplate<Codec, QuantizerTemplateScaling::NON_UNIFORM, 1> {
424
563
  QuantizerTemplate(size_t d, const std::vector<float>& trained)
425
- : QuantizerTemplate<Codec, false, 1>(d, trained) {}
564
+ : QuantizerTemplate<
565
+ Codec,
566
+ QuantizerTemplateScaling::NON_UNIFORM,
567
+ 1>(d, trained) {}
426
568
 
427
569
  FAISS_ALWAYS_INLINE float32x4x2_t
428
570
  reconstruct_8_components(const uint8_t* code, int i) const {
@@ -431,10 +573,8 @@ struct QuantizerTemplate<Codec, false, 8> : QuantizerTemplate<Codec, false, 1> {
431
573
  float32x4x2_t vmin_8 = vld1q_f32_x2(this->vmin + i);
432
574
  float32x4x2_t vdiff_8 = vld1q_f32_x2(this->vdiff + i);
433
575
 
434
- float32x4x2_t res = vzipq_f32(
435
- vfmaq_f32(vmin_8.val[0], xi.val[0], vdiff_8.val[0]),
436
- vfmaq_f32(vmin_8.val[1], xi.val[1], vdiff_8.val[1]));
437
- return vuzpq_f32(res.val[0], res.val[1]);
576
+ return {vfmaq_f32(vmin_8.val[0], xi.val[0], vdiff_8.val[0]),
577
+ vfmaq_f32(vmin_8.val[1], xi.val[1], vdiff_8.val[1])};
438
578
  }
439
579
  };
440
580
 
@@ -471,7 +611,23 @@ struct QuantizerFP16<1> : ScalarQuantizer::SQuantizer {
471
611
  }
472
612
  };
473
613
 
474
- #ifdef USE_F16C
614
+ #if defined(USE_AVX512_F16C)
615
+
616
+ template <>
617
+ struct QuantizerFP16<16> : QuantizerFP16<1> {
618
+ QuantizerFP16(size_t d, const std::vector<float>& trained)
619
+ : QuantizerFP16<1>(d, trained) {}
620
+
621
+ FAISS_ALWAYS_INLINE __m512
622
+ reconstruct_16_components(const uint8_t* code, int i) const {
623
+ __m256i codei = _mm256_loadu_si256((const __m256i*)(code + 2 * i));
624
+ return _mm512_cvtph_ps(codei);
625
+ }
626
+ };
627
+
628
+ #endif
629
+
630
+ #if defined(USE_F16C)
475
631
 
476
632
  template <>
477
633
  struct QuantizerFP16<8> : QuantizerFP16<1> {
@@ -487,7 +643,7 @@ struct QuantizerFP16<8> : QuantizerFP16<1> {
487
643
 
488
644
  #endif
489
645
 
490
- #ifdef __aarch64__
646
+ #ifdef USE_NEON
491
647
 
492
648
  template <>
493
649
  struct QuantizerFP16<8> : QuantizerFP16<1> {
@@ -496,10 +652,90 @@ struct QuantizerFP16<8> : QuantizerFP16<1> {
496
652
 
497
653
  FAISS_ALWAYS_INLINE float32x4x2_t
498
654
  reconstruct_8_components(const uint8_t* code, int i) const {
499
- uint16x4x2_t codei = vld2_u16((const uint16_t*)(code + 2 * i));
500
- return vzipq_f32(
501
- vcvt_f32_f16(vreinterpret_f16_u16(codei.val[0])),
502
- vcvt_f32_f16(vreinterpret_f16_u16(codei.val[1])));
655
+ uint16x4x2_t codei = vld1_u16_x2((const uint16_t*)(code + 2 * i));
656
+ return {vcvt_f32_f16(vreinterpret_f16_u16(codei.val[0])),
657
+ vcvt_f32_f16(vreinterpret_f16_u16(codei.val[1]))};
658
+ }
659
+ };
660
+ #endif
661
+
662
+ /*******************************************************************
663
+ * BF16 quantizer
664
+ *******************************************************************/
665
+
666
+ template <int SIMDWIDTH>
667
+ struct QuantizerBF16 {};
668
+
669
+ template <>
670
+ struct QuantizerBF16<1> : ScalarQuantizer::SQuantizer {
671
+ const size_t d;
672
+
673
+ QuantizerBF16(size_t d, const std::vector<float>& /* unused */) : d(d) {}
674
+
675
+ void encode_vector(const float* x, uint8_t* code) const final {
676
+ for (size_t i = 0; i < d; i++) {
677
+ ((uint16_t*)code)[i] = encode_bf16(x[i]);
678
+ }
679
+ }
680
+
681
+ void decode_vector(const uint8_t* code, float* x) const final {
682
+ for (size_t i = 0; i < d; i++) {
683
+ x[i] = decode_bf16(((uint16_t*)code)[i]);
684
+ }
685
+ }
686
+
687
+ FAISS_ALWAYS_INLINE float reconstruct_component(const uint8_t* code, int i)
688
+ const {
689
+ return decode_bf16(((uint16_t*)code)[i]);
690
+ }
691
+ };
692
+
693
+ #if defined(__AVX512F__)
694
+
695
+ template <>
696
+ struct QuantizerBF16<16> : QuantizerBF16<1> {
697
+ QuantizerBF16(size_t d, const std::vector<float>& trained)
698
+ : QuantizerBF16<1>(d, trained) {}
699
+ FAISS_ALWAYS_INLINE __m512
700
+ reconstruct_16_components(const uint8_t* code, int i) const {
701
+ __m256i code_256i = _mm256_loadu_si256((const __m256i*)(code + 2 * i));
702
+ __m512i code_512i = _mm512_cvtepu16_epi32(code_256i);
703
+ code_512i = _mm512_slli_epi32(code_512i, 16);
704
+ return _mm512_castsi512_ps(code_512i);
705
+ }
706
+ };
707
+
708
+ #elif defined(__AVX2__)
709
+
710
+ template <>
711
+ struct QuantizerBF16<8> : QuantizerBF16<1> {
712
+ QuantizerBF16(size_t d, const std::vector<float>& trained)
713
+ : QuantizerBF16<1>(d, trained) {}
714
+
715
+ FAISS_ALWAYS_INLINE __m256
716
+ reconstruct_8_components(const uint8_t* code, int i) const {
717
+ __m128i code_128i = _mm_loadu_si128((const __m128i*)(code + 2 * i));
718
+ __m256i code_256i = _mm256_cvtepu16_epi32(code_128i);
719
+ code_256i = _mm256_slli_epi32(code_256i, 16);
720
+ return _mm256_castsi256_ps(code_256i);
721
+ }
722
+ };
723
+
724
+ #endif
725
+
726
+ #ifdef USE_NEON
727
+
728
+ template <>
729
+ struct QuantizerBF16<8> : QuantizerBF16<1> {
730
+ QuantizerBF16(size_t d, const std::vector<float>& trained)
731
+ : QuantizerBF16<1>(d, trained) {}
732
+
733
+ FAISS_ALWAYS_INLINE float32x4x2_t
734
+ reconstruct_8_components(const uint8_t* code, int i) const {
735
+ uint16x4x2_t codei = vld1_u16_x2((const uint16_t*)(code + 2 * i));
736
+ return {vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(codei.val[0]), 16)),
737
+ vreinterpretq_f32_u32(
738
+ vshlq_n_u32(vmovl_u16(codei.val[1]), 16))};
503
739
  }
504
740
  };
505
741
  #endif
@@ -536,7 +772,22 @@ struct Quantizer8bitDirect<1> : ScalarQuantizer::SQuantizer {
536
772
  }
537
773
  };
538
774
 
539
- #ifdef __AVX2__
775
+ #if defined(__AVX512F__)
776
+
777
+ template <>
778
+ struct Quantizer8bitDirect<16> : Quantizer8bitDirect<1> {
779
+ Quantizer8bitDirect(size_t d, const std::vector<float>& trained)
780
+ : Quantizer8bitDirect<1>(d, trained) {}
781
+
782
+ FAISS_ALWAYS_INLINE __m512
783
+ reconstruct_16_components(const uint8_t* code, int i) const {
784
+ __m128i x16 = _mm_loadu_si128((__m128i*)(code + i)); // 16 * int8
785
+ __m512i y16 = _mm512_cvtepu8_epi32(x16); // 16 * int32
786
+ return _mm512_cvtepi32_ps(y16); // 16 * float32
787
+ }
788
+ };
789
+
790
+ #elif defined(__AVX2__)
540
791
 
541
792
  template <>
542
793
  struct Quantizer8bitDirect<8> : Quantizer8bitDirect<1> {
@@ -553,7 +804,7 @@ struct Quantizer8bitDirect<8> : Quantizer8bitDirect<1> {
553
804
 
554
805
  #endif
555
806
 
556
- #ifdef __aarch64__
807
+ #ifdef USE_NEON
557
808
 
558
809
  template <>
559
810
  struct Quantizer8bitDirect<8> : Quantizer8bitDirect<1> {
@@ -562,14 +813,107 @@ struct Quantizer8bitDirect<8> : Quantizer8bitDirect<1> {
562
813
 
563
814
  FAISS_ALWAYS_INLINE float32x4x2_t
564
815
  reconstruct_8_components(const uint8_t* code, int i) const {
565
- float32_t result[8] = {};
566
- for (size_t j = 0; j < 8; j++) {
567
- result[j] = code[i + j];
816
+ uint8x8_t x8 = vld1_u8((const uint8_t*)(code + i));
817
+ uint16x8_t y8 = vmovl_u8(x8);
818
+ uint16x4_t y8_0 = vget_low_u16(y8);
819
+ uint16x4_t y8_1 = vget_high_u16(y8);
820
+
821
+ // convert uint16 -> uint32 -> fp32
822
+ return {vcvtq_f32_u32(vmovl_u16(y8_0)), vcvtq_f32_u32(vmovl_u16(y8_1))};
823
+ }
824
+ };
825
+
826
+ #endif
827
+
828
+ /*******************************************************************
829
+ * 8bit_direct_signed quantizer
830
+ *******************************************************************/
831
+
832
+ template <int SIMDWIDTH>
833
+ struct Quantizer8bitDirectSigned {};
834
+
835
+ template <>
836
+ struct Quantizer8bitDirectSigned<1> : ScalarQuantizer::SQuantizer {
837
+ const size_t d;
838
+
839
+ Quantizer8bitDirectSigned(size_t d, const std::vector<float>& /* unused */)
840
+ : d(d) {}
841
+
842
+ void encode_vector(const float* x, uint8_t* code) const final {
843
+ for (size_t i = 0; i < d; i++) {
844
+ code[i] = (uint8_t)(x[i] + 128);
568
845
  }
569
- float32x4_t res1 = vld1q_f32(result);
570
- float32x4_t res2 = vld1q_f32(result + 4);
571
- float32x4x2_t res = vzipq_f32(res1, res2);
572
- return vuzpq_f32(res.val[0], res.val[1]);
846
+ }
847
+
848
+ void decode_vector(const uint8_t* code, float* x) const final {
849
+ for (size_t i = 0; i < d; i++) {
850
+ x[i] = code[i] - 128;
851
+ }
852
+ }
853
+
854
+ FAISS_ALWAYS_INLINE float reconstruct_component(const uint8_t* code, int i)
855
+ const {
856
+ return code[i] - 128;
857
+ }
858
+ };
859
+
860
+ #if defined(__AVX512F__)
861
+
862
+ template <>
863
+ struct Quantizer8bitDirectSigned<16> : Quantizer8bitDirectSigned<1> {
864
+ Quantizer8bitDirectSigned(size_t d, const std::vector<float>& trained)
865
+ : Quantizer8bitDirectSigned<1>(d, trained) {}
866
+
867
+ FAISS_ALWAYS_INLINE __m512
868
+ reconstruct_16_components(const uint8_t* code, int i) const {
869
+ __m128i x16 = _mm_loadu_si128((__m128i*)(code + i)); // 16 * int8
870
+ __m512i y16 = _mm512_cvtepu8_epi32(x16); // 16 * int32
871
+ __m512i c16 = _mm512_set1_epi32(128);
872
+ __m512i z16 = _mm512_sub_epi32(y16, c16); // subtract 128 from all lanes
873
+ return _mm512_cvtepi32_ps(z16); // 16 * float32
874
+ }
875
+ };
876
+
877
+ #elif defined(__AVX2__)
878
+
879
+ template <>
880
+ struct Quantizer8bitDirectSigned<8> : Quantizer8bitDirectSigned<1> {
881
+ Quantizer8bitDirectSigned(size_t d, const std::vector<float>& trained)
882
+ : Quantizer8bitDirectSigned<1>(d, trained) {}
883
+
884
+ FAISS_ALWAYS_INLINE __m256
885
+ reconstruct_8_components(const uint8_t* code, int i) const {
886
+ __m128i x8 = _mm_loadl_epi64((__m128i*)(code + i)); // 8 * int8
887
+ __m256i y8 = _mm256_cvtepu8_epi32(x8); // 8 * int32
888
+ __m256i c8 = _mm256_set1_epi32(128);
889
+ __m256i z8 = _mm256_sub_epi32(y8, c8); // subtract 128 from all lanes
890
+ return _mm256_cvtepi32_ps(z8); // 8 * float32
891
+ }
892
+ };
893
+
894
+ #endif
895
+
896
+ #ifdef USE_NEON
897
+
898
+ template <>
899
+ struct Quantizer8bitDirectSigned<8> : Quantizer8bitDirectSigned<1> {
900
+ Quantizer8bitDirectSigned(size_t d, const std::vector<float>& trained)
901
+ : Quantizer8bitDirectSigned<1>(d, trained) {}
902
+
903
+ FAISS_ALWAYS_INLINE float32x4x2_t
904
+ reconstruct_8_components(const uint8_t* code, int i) const {
905
+ uint8x8_t x8 = vld1_u8((const uint8_t*)(code + i));
906
+ uint16x8_t y8 = vmovl_u8(x8); // convert uint8 -> uint16
907
+ uint16x4_t y8_0 = vget_low_u16(y8);
908
+ uint16x4_t y8_1 = vget_high_u16(y8);
909
+
910
+ float32x4_t z8_0 = vcvtq_f32_u32(
911
+ vmovl_u16(y8_0)); // convert uint16 -> uint32 -> fp32
912
+ float32x4_t z8_1 = vcvtq_f32_u32(vmovl_u16(y8_1));
913
+
914
+ // subtract 128 to convert into signed numbers
915
+ return {vsubq_f32(z8_0, vmovq_n_f32(128.0)),
916
+ vsubq_f32(z8_1, vmovq_n_f32(128.0))};
573
917
  }
574
918
  };
575
919
 
@@ -582,24 +926,38 @@ ScalarQuantizer::SQuantizer* select_quantizer_1(
582
926
  const std::vector<float>& trained) {
583
927
  switch (qtype) {
584
928
  case ScalarQuantizer::QT_8bit:
585
- return new QuantizerTemplate<Codec8bit, false, SIMDWIDTH>(
586
- d, trained);
929
+ return new QuantizerTemplate<
930
+ Codec8bit,
931
+ QuantizerTemplateScaling::NON_UNIFORM,
932
+ SIMDWIDTH>(d, trained);
587
933
  case ScalarQuantizer::QT_6bit:
588
- return new QuantizerTemplate<Codec6bit, false, SIMDWIDTH>(
589
- d, trained);
934
+ return new QuantizerTemplate<
935
+ Codec6bit,
936
+ QuantizerTemplateScaling::NON_UNIFORM,
937
+ SIMDWIDTH>(d, trained);
590
938
  case ScalarQuantizer::QT_4bit:
591
- return new QuantizerTemplate<Codec4bit, false, SIMDWIDTH>(
592
- d, trained);
939
+ return new QuantizerTemplate<
940
+ Codec4bit,
941
+ QuantizerTemplateScaling::NON_UNIFORM,
942
+ SIMDWIDTH>(d, trained);
593
943
  case ScalarQuantizer::QT_8bit_uniform:
594
- return new QuantizerTemplate<Codec8bit, true, SIMDWIDTH>(
595
- d, trained);
944
+ return new QuantizerTemplate<
945
+ Codec8bit,
946
+ QuantizerTemplateScaling::UNIFORM,
947
+ SIMDWIDTH>(d, trained);
596
948
  case ScalarQuantizer::QT_4bit_uniform:
597
- return new QuantizerTemplate<Codec4bit, true, SIMDWIDTH>(
598
- d, trained);
949
+ return new QuantizerTemplate<
950
+ Codec4bit,
951
+ QuantizerTemplateScaling::UNIFORM,
952
+ SIMDWIDTH>(d, trained);
599
953
  case ScalarQuantizer::QT_fp16:
600
954
  return new QuantizerFP16<SIMDWIDTH>(d, trained);
955
+ case ScalarQuantizer::QT_bf16:
956
+ return new QuantizerBF16<SIMDWIDTH>(d, trained);
601
957
  case ScalarQuantizer::QT_8bit_direct:
602
958
  return new Quantizer8bitDirect<SIMDWIDTH>(d, trained);
959
+ case ScalarQuantizer::QT_8bit_direct_signed:
960
+ return new Quantizer8bitDirectSigned<SIMDWIDTH>(d, trained);
603
961
  }
604
962
  FAISS_THROW_MSG("unknown qtype");
605
963
  }
@@ -816,7 +1174,43 @@ struct SimilarityL2<1> {
816
1174
  }
817
1175
  };
818
1176
 
819
- #ifdef __AVX2__
1177
+ #if defined(__AVX512F__)
1178
+
1179
+ template <>
1180
+ struct SimilarityL2<16> {
1181
+ static constexpr int simdwidth = 16;
1182
+ static constexpr MetricType metric_type = METRIC_L2;
1183
+
1184
+ const float *y, *yi;
1185
+
1186
+ explicit SimilarityL2(const float* y) : y(y) {}
1187
+ __m512 accu16;
1188
+
1189
+ FAISS_ALWAYS_INLINE void begin_16() {
1190
+ accu16 = _mm512_setzero_ps();
1191
+ yi = y;
1192
+ }
1193
+
1194
+ FAISS_ALWAYS_INLINE void add_16_components(__m512 x) {
1195
+ __m512 yiv = _mm512_loadu_ps(yi);
1196
+ yi += 16;
1197
+ __m512 tmp = _mm512_sub_ps(yiv, x);
1198
+ accu16 = _mm512_fmadd_ps(tmp, tmp, accu16);
1199
+ }
1200
+
1201
+ FAISS_ALWAYS_INLINE void add_16_components_2(__m512 x, __m512 y_2) {
1202
+ __m512 tmp = _mm512_sub_ps(y_2, x);
1203
+ accu16 = _mm512_fmadd_ps(tmp, tmp, accu16);
1204
+ }
1205
+
1206
+ FAISS_ALWAYS_INLINE float result_16() {
1207
+ // performs better than dividing into _mm256 and adding
1208
+ return _mm512_reduce_add_ps(accu16);
1209
+ }
1210
+ };
1211
+
1212
+ #elif defined(__AVX2__)
1213
+
820
1214
  template <>
821
1215
  struct SimilarityL2<8> {
822
1216
  static constexpr int simdwidth = 8;
@@ -857,7 +1251,7 @@ struct SimilarityL2<8> {
857
1251
 
858
1252
  #endif
859
1253
 
860
- #ifdef __aarch64__
1254
+ #ifdef USE_NEON
861
1255
  template <>
862
1256
  struct SimilarityL2<8> {
863
1257
  static constexpr int simdwidth = 8;
@@ -868,7 +1262,7 @@ struct SimilarityL2<8> {
868
1262
  float32x4x2_t accu8;
869
1263
 
870
1264
  FAISS_ALWAYS_INLINE void begin_8() {
871
- accu8 = vzipq_f32(vdupq_n_f32(0.0f), vdupq_n_f32(0.0f));
1265
+ accu8 = {vdupq_n_f32(0.0f), vdupq_n_f32(0.0f)};
872
1266
  yi = y;
873
1267
  }
874
1268
 
@@ -882,8 +1276,7 @@ struct SimilarityL2<8> {
882
1276
  float32x4_t accu8_0 = vfmaq_f32(accu8.val[0], sub0, sub0);
883
1277
  float32x4_t accu8_1 = vfmaq_f32(accu8.val[1], sub1, sub1);
884
1278
 
885
- float32x4x2_t accu8_temp = vzipq_f32(accu8_0, accu8_1);
886
- accu8 = vuzpq_f32(accu8_temp.val[0], accu8_temp.val[1]);
1279
+ accu8 = {accu8_0, accu8_1};
887
1280
  }
888
1281
 
889
1282
  FAISS_ALWAYS_INLINE void add_8_components_2(
@@ -895,8 +1288,7 @@ struct SimilarityL2<8> {
895
1288
  float32x4_t accu8_0 = vfmaq_f32(accu8.val[0], sub0, sub0);
896
1289
  float32x4_t accu8_1 = vfmaq_f32(accu8.val[1], sub1, sub1);
897
1290
 
898
- float32x4x2_t accu8_temp = vzipq_f32(accu8_0, accu8_1);
899
- accu8 = vuzpq_f32(accu8_temp.val[0], accu8_temp.val[1]);
1291
+ accu8 = {accu8_0, accu8_1};
900
1292
  }
901
1293
 
902
1294
  FAISS_ALWAYS_INLINE float result_8() {
@@ -941,7 +1333,43 @@ struct SimilarityIP<1> {
941
1333
  }
942
1334
  };
943
1335
 
944
- #ifdef __AVX2__
1336
+ #if defined(__AVX512F__)
1337
+
1338
+ template <>
1339
+ struct SimilarityIP<16> {
1340
+ static constexpr int simdwidth = 16;
1341
+ static constexpr MetricType metric_type = METRIC_INNER_PRODUCT;
1342
+
1343
+ const float *y, *yi;
1344
+
1345
+ float accu;
1346
+
1347
+ explicit SimilarityIP(const float* y) : y(y) {}
1348
+
1349
+ __m512 accu16;
1350
+
1351
+ FAISS_ALWAYS_INLINE void begin_16() {
1352
+ accu16 = _mm512_setzero_ps();
1353
+ yi = y;
1354
+ }
1355
+
1356
+ FAISS_ALWAYS_INLINE void add_16_components(__m512 x) {
1357
+ __m512 yiv = _mm512_loadu_ps(yi);
1358
+ yi += 16;
1359
+ accu16 = _mm512_fmadd_ps(yiv, x, accu16);
1360
+ }
1361
+
1362
+ FAISS_ALWAYS_INLINE void add_16_components_2(__m512 x1, __m512 x2) {
1363
+ accu16 = _mm512_fmadd_ps(x1, x2, accu16);
1364
+ }
1365
+
1366
+ FAISS_ALWAYS_INLINE float result_16() {
1367
+ // performs better than dividing into _mm256 and adding
1368
+ return _mm512_reduce_add_ps(accu16);
1369
+ }
1370
+ };
1371
+
1372
+ #elif defined(__AVX2__)
945
1373
 
946
1374
  template <>
947
1375
  struct SimilarityIP<8> {
@@ -983,7 +1411,7 @@ struct SimilarityIP<8> {
983
1411
  };
984
1412
  #endif
985
1413
 
986
- #ifdef __aarch64__
1414
+ #ifdef USE_NEON
987
1415
 
988
1416
  template <>
989
1417
  struct SimilarityIP<8> {
@@ -996,7 +1424,7 @@ struct SimilarityIP<8> {
996
1424
  float32x4x2_t accu8;
997
1425
 
998
1426
  FAISS_ALWAYS_INLINE void begin_8() {
999
- accu8 = vzipq_f32(vdupq_n_f32(0.0f), vdupq_n_f32(0.0f));
1427
+ accu8 = {vdupq_n_f32(0.0f), vdupq_n_f32(0.0f)};
1000
1428
  yi = y;
1001
1429
  }
1002
1430
 
@@ -1006,8 +1434,7 @@ struct SimilarityIP<8> {
1006
1434
 
1007
1435
  float32x4_t accu8_0 = vfmaq_f32(accu8.val[0], yiv.val[0], x.val[0]);
1008
1436
  float32x4_t accu8_1 = vfmaq_f32(accu8.val[1], yiv.val[1], x.val[1]);
1009
- float32x4x2_t accu8_temp = vzipq_f32(accu8_0, accu8_1);
1010
- accu8 = vuzpq_f32(accu8_temp.val[0], accu8_temp.val[1]);
1437
+ accu8 = {accu8_0, accu8_1};
1011
1438
  }
1012
1439
 
1013
1440
  FAISS_ALWAYS_INLINE void add_8_components_2(
@@ -1015,19 +1442,17 @@ struct SimilarityIP<8> {
1015
1442
  float32x4x2_t x2) {
1016
1443
  float32x4_t accu8_0 = vfmaq_f32(accu8.val[0], x1.val[0], x2.val[0]);
1017
1444
  float32x4_t accu8_1 = vfmaq_f32(accu8.val[1], x1.val[1], x2.val[1]);
1018
- float32x4x2_t accu8_temp = vzipq_f32(accu8_0, accu8_1);
1019
- accu8 = vuzpq_f32(accu8_temp.val[0], accu8_temp.val[1]);
1445
+ accu8 = {accu8_0, accu8_1};
1020
1446
  }
1021
1447
 
1022
1448
  FAISS_ALWAYS_INLINE float result_8() {
1023
- float32x4x2_t sum_tmp = vzipq_f32(
1449
+ float32x4x2_t sum = {
1024
1450
  vpaddq_f32(accu8.val[0], accu8.val[0]),
1025
- vpaddq_f32(accu8.val[1], accu8.val[1]));
1026
- float32x4x2_t sum = vuzpq_f32(sum_tmp.val[0], sum_tmp.val[1]);
1027
- float32x4x2_t sum2_tmp = vzipq_f32(
1451
+ vpaddq_f32(accu8.val[1], accu8.val[1])};
1452
+
1453
+ float32x4x2_t sum2 = {
1028
1454
  vpaddq_f32(sum.val[0], sum.val[0]),
1029
- vpaddq_f32(sum.val[1], sum.val[1]));
1030
- float32x4x2_t sum2 = vuzpq_f32(sum2_tmp.val[0], sum2_tmp.val[1]);
1455
+ vpaddq_f32(sum.val[1], sum.val[1])};
1031
1456
  return vgetq_lane_f32(sum2.val[0], 0) + vgetq_lane_f32(sum2.val[1], 0);
1032
1457
  }
1033
1458
  };
@@ -1086,7 +1511,55 @@ struct DCTemplate<Quantizer, Similarity, 1> : SQDistanceComputer {
1086
1511
  }
1087
1512
  };
1088
1513
 
1089
- #ifdef USE_F16C
1514
+ #if defined(USE_AVX512_F16C)
1515
+
1516
+ template <class Quantizer, class Similarity>
1517
+ struct DCTemplate<Quantizer, Similarity, 16>
1518
+ : SQDistanceComputer { // Update to handle 16 lanes
1519
+ using Sim = Similarity;
1520
+
1521
+ Quantizer quant;
1522
+
1523
+ DCTemplate(size_t d, const std::vector<float>& trained)
1524
+ : quant(d, trained) {}
1525
+
1526
+ float compute_distance(const float* x, const uint8_t* code) const {
1527
+ Similarity sim(x);
1528
+ sim.begin_16();
1529
+ for (size_t i = 0; i < quant.d; i += 16) {
1530
+ __m512 xi = quant.reconstruct_16_components(code, i);
1531
+ sim.add_16_components(xi);
1532
+ }
1533
+ return sim.result_16();
1534
+ }
1535
+
1536
+ float compute_code_distance(const uint8_t* code1, const uint8_t* code2)
1537
+ const {
1538
+ Similarity sim(nullptr);
1539
+ sim.begin_16();
1540
+ for (size_t i = 0; i < quant.d; i += 16) {
1541
+ __m512 x1 = quant.reconstruct_16_components(code1, i);
1542
+ __m512 x2 = quant.reconstruct_16_components(code2, i);
1543
+ sim.add_16_components_2(x1, x2);
1544
+ }
1545
+ return sim.result_16();
1546
+ }
1547
+
1548
+ void set_query(const float* x) final {
1549
+ q = x;
1550
+ }
1551
+
1552
+ float symmetric_dis(idx_t i, idx_t j) override {
1553
+ return compute_code_distance(
1554
+ codes + i * code_size, codes + j * code_size);
1555
+ }
1556
+
1557
+ float query_to_code(const uint8_t* code) const final {
1558
+ return compute_distance(q, code);
1559
+ }
1560
+ };
1561
+
1562
+ #elif defined(USE_F16C)
1090
1563
 
1091
1564
  template <class Quantizer, class Similarity>
1092
1565
  struct DCTemplate<Quantizer, Similarity, 8> : SQDistanceComputer {
@@ -1135,7 +1608,7 @@ struct DCTemplate<Quantizer, Similarity, 8> : SQDistanceComputer {
1135
1608
 
1136
1609
  #endif
1137
1610
 
1138
- #ifdef __aarch64__
1611
+ #ifdef USE_NEON
1139
1612
 
1140
1613
  template <class Quantizer, class Similarity>
1141
1614
  struct DCTemplate<Quantizer, Similarity, 8> : SQDistanceComputer {
@@ -1233,7 +1706,60 @@ struct DistanceComputerByte<Similarity, 1> : SQDistanceComputer {
1233
1706
  }
1234
1707
  };
1235
1708
 
1236
- #ifdef __AVX2__
1709
+ #if defined(__AVX512F__)
1710
+
1711
+ template <class Similarity>
1712
+ struct DistanceComputerByte<Similarity, 16> : SQDistanceComputer {
1713
+ using Sim = Similarity;
1714
+
1715
+ int d;
1716
+ std::vector<uint8_t> tmp;
1717
+
1718
+ DistanceComputerByte(int d, const std::vector<float>&) : d(d), tmp(d) {}
1719
+
1720
+ int compute_code_distance(const uint8_t* code1, const uint8_t* code2)
1721
+ const {
1722
+ __m512i accu = _mm512_setzero_si512();
1723
+ for (int i = 0; i < d; i += 32) { // Process 32 bytes at a time
1724
+ __m512i c1 = _mm512_cvtepu8_epi16(
1725
+ _mm256_loadu_si256((__m256i*)(code1 + i)));
1726
+ __m512i c2 = _mm512_cvtepu8_epi16(
1727
+ _mm256_loadu_si256((__m256i*)(code2 + i)));
1728
+ __m512i prod32;
1729
+ if (Sim::metric_type == METRIC_INNER_PRODUCT) {
1730
+ prod32 = _mm512_madd_epi16(c1, c2);
1731
+ } else {
1732
+ __m512i diff = _mm512_sub_epi16(c1, c2);
1733
+ prod32 = _mm512_madd_epi16(diff, diff);
1734
+ }
1735
+ accu = _mm512_add_epi32(accu, prod32);
1736
+ }
1737
+ // Horizontally add elements of accu
1738
+ return _mm512_reduce_add_epi32(accu);
1739
+ }
1740
+
1741
+ void set_query(const float* x) final {
1742
+ for (int i = 0; i < d; i++) {
1743
+ tmp[i] = int(x[i]);
1744
+ }
1745
+ }
1746
+
1747
+ int compute_distance(const float* x, const uint8_t* code) {
1748
+ set_query(x);
1749
+ return compute_code_distance(tmp.data(), code);
1750
+ }
1751
+
1752
+ float symmetric_dis(idx_t i, idx_t j) override {
1753
+ return compute_code_distance(
1754
+ codes + i * code_size, codes + j * code_size);
1755
+ }
1756
+
1757
+ float query_to_code(const uint8_t* code) const final {
1758
+ return compute_code_distance(tmp.data(), code);
1759
+ }
1760
+ };
1761
+
1762
+ #elif defined(__AVX2__)
1237
1763
 
1238
1764
  template <class Similarity>
1239
1765
  struct DistanceComputerByte<Similarity, 8> : SQDistanceComputer {
@@ -1298,7 +1824,7 @@ struct DistanceComputerByte<Similarity, 8> : SQDistanceComputer {
1298
1824
 
1299
1825
  #endif
1300
1826
 
1301
- #ifdef __aarch64__
1827
+ #ifdef USE_NEON
1302
1828
 
1303
1829
  template <class Similarity>
1304
1830
  struct DistanceComputerByte<Similarity, 8> : SQDistanceComputer {
@@ -1360,31 +1886,46 @@ SQDistanceComputer* select_distance_computer(
1360
1886
  switch (qtype) {
1361
1887
  case ScalarQuantizer::QT_8bit_uniform:
1362
1888
  return new DCTemplate<
1363
- QuantizerTemplate<Codec8bit, true, SIMDWIDTH>,
1889
+ QuantizerTemplate<
1890
+ Codec8bit,
1891
+ QuantizerTemplateScaling::UNIFORM,
1892
+ SIMDWIDTH>,
1364
1893
  Sim,
1365
1894
  SIMDWIDTH>(d, trained);
1366
1895
 
1367
1896
  case ScalarQuantizer::QT_4bit_uniform:
1368
1897
  return new DCTemplate<
1369
- QuantizerTemplate<Codec4bit, true, SIMDWIDTH>,
1898
+ QuantizerTemplate<
1899
+ Codec4bit,
1900
+ QuantizerTemplateScaling::UNIFORM,
1901
+ SIMDWIDTH>,
1370
1902
  Sim,
1371
1903
  SIMDWIDTH>(d, trained);
1372
1904
 
1373
1905
  case ScalarQuantizer::QT_8bit:
1374
1906
  return new DCTemplate<
1375
- QuantizerTemplate<Codec8bit, false, SIMDWIDTH>,
1907
+ QuantizerTemplate<
1908
+ Codec8bit,
1909
+ QuantizerTemplateScaling::NON_UNIFORM,
1910
+ SIMDWIDTH>,
1376
1911
  Sim,
1377
1912
  SIMDWIDTH>(d, trained);
1378
1913
 
1379
1914
  case ScalarQuantizer::QT_6bit:
1380
1915
  return new DCTemplate<
1381
- QuantizerTemplate<Codec6bit, false, SIMDWIDTH>,
1916
+ QuantizerTemplate<
1917
+ Codec6bit,
1918
+ QuantizerTemplateScaling::NON_UNIFORM,
1919
+ SIMDWIDTH>,
1382
1920
  Sim,
1383
1921
  SIMDWIDTH>(d, trained);
1384
1922
 
1385
1923
  case ScalarQuantizer::QT_4bit:
1386
1924
  return new DCTemplate<
1387
- QuantizerTemplate<Codec4bit, false, SIMDWIDTH>,
1925
+ QuantizerTemplate<
1926
+ Codec4bit,
1927
+ QuantizerTemplateScaling::NON_UNIFORM,
1928
+ SIMDWIDTH>,
1388
1929
  Sim,
1389
1930
  SIMDWIDTH>(d, trained);
1390
1931
 
@@ -1392,15 +1933,31 @@ SQDistanceComputer* select_distance_computer(
1392
1933
  return new DCTemplate<QuantizerFP16<SIMDWIDTH>, Sim, SIMDWIDTH>(
1393
1934
  d, trained);
1394
1935
 
1936
+ case ScalarQuantizer::QT_bf16:
1937
+ return new DCTemplate<QuantizerBF16<SIMDWIDTH>, Sim, SIMDWIDTH>(
1938
+ d, trained);
1939
+
1395
1940
  case ScalarQuantizer::QT_8bit_direct:
1941
+ #if defined(__AVX512F__)
1942
+ if (d % 32 == 0) {
1943
+ return new DistanceComputerByte<Sim, SIMDWIDTH>(d, trained);
1944
+ } else
1945
+ #elif defined(__AVX2__)
1396
1946
  if (d % 16 == 0) {
1397
1947
  return new DistanceComputerByte<Sim, SIMDWIDTH>(d, trained);
1398
- } else {
1948
+ } else
1949
+ #endif
1950
+ {
1399
1951
  return new DCTemplate<
1400
1952
  Quantizer8bitDirect<SIMDWIDTH>,
1401
1953
  Sim,
1402
1954
  SIMDWIDTH>(d, trained);
1403
1955
  }
1956
+ case ScalarQuantizer::QT_8bit_direct_signed:
1957
+ return new DCTemplate<
1958
+ Quantizer8bitDirectSigned<SIMDWIDTH>,
1959
+ Sim,
1960
+ SIMDWIDTH>(d, trained);
1404
1961
  }
1405
1962
  FAISS_THROW_MSG("unknown qtype");
1406
1963
  return nullptr;
@@ -1424,6 +1981,7 @@ void ScalarQuantizer::set_derived_sizes() {
1424
1981
  case QT_8bit:
1425
1982
  case QT_8bit_uniform:
1426
1983
  case QT_8bit_direct:
1984
+ case QT_8bit_direct_signed:
1427
1985
  code_size = d;
1428
1986
  bits = 8;
1429
1987
  break;
@@ -1440,6 +1998,10 @@ void ScalarQuantizer::set_derived_sizes() {
1440
1998
  code_size = d * 2;
1441
1999
  bits = 16;
1442
2000
  break;
2001
+ case QT_bf16:
2002
+ code_size = d * 2;
2003
+ bits = 16;
2004
+ break;
1443
2005
  }
1444
2006
  }
1445
2007
 
@@ -1476,13 +2038,19 @@ void ScalarQuantizer::train(size_t n, const float* x) {
1476
2038
  break;
1477
2039
  case QT_fp16:
1478
2040
  case QT_8bit_direct:
2041
+ case QT_bf16:
2042
+ case QT_8bit_direct_signed:
1479
2043
  // no training necessary
1480
2044
  break;
1481
2045
  }
1482
2046
  }
1483
2047
 
1484
2048
  ScalarQuantizer::SQuantizer* ScalarQuantizer::select_quantizer() const {
1485
- #if defined(USE_F16C) || defined(__aarch64__)
2049
+ #if defined(USE_AVX512_F16C)
2050
+ if (d % 16 == 0) {
2051
+ return select_quantizer_1<16>(qtype, d, trained);
2052
+ } else
2053
+ #elif defined(USE_F16C) || defined(USE_NEON)
1486
2054
  if (d % 8 == 0) {
1487
2055
  return select_quantizer_1<8>(qtype, d, trained);
1488
2056
  } else
@@ -1513,7 +2081,17 @@ void ScalarQuantizer::decode(const uint8_t* codes, float* x, size_t n) const {
1513
2081
  SQDistanceComputer* ScalarQuantizer::get_distance_computer(
1514
2082
  MetricType metric) const {
1515
2083
  FAISS_THROW_IF_NOT(metric == METRIC_L2 || metric == METRIC_INNER_PRODUCT);
1516
- #if defined(USE_F16C) || defined(__aarch64__)
2084
+ #if defined(USE_AVX512_F16C)
2085
+ if (d % 16 == 0) {
2086
+ if (metric == METRIC_L2) {
2087
+ return select_distance_computer<SimilarityL2<16>>(
2088
+ qtype, d, trained);
2089
+ } else {
2090
+ return select_distance_computer<SimilarityIP<16>>(
2091
+ qtype, d, trained);
2092
+ }
2093
+ } else
2094
+ #elif defined(USE_F16C) || defined(USE_NEON)
1517
2095
  if (d % 8 == 0) {
1518
2096
  if (metric == METRIC_L2) {
1519
2097
  return select_distance_computer<SimilarityL2<8>>(qtype, d, trained);
@@ -1762,7 +2340,7 @@ InvertedListScanner* sel2_InvertedListScanner(
1762
2340
  }
1763
2341
  }
1764
2342
 
1765
- template <class Similarity, class Codec, bool uniform>
2343
+ template <class Similarity, class Codec, QuantizerTemplateScaling SCALING>
1766
2344
  InvertedListScanner* sel12_InvertedListScanner(
1767
2345
  const ScalarQuantizer* sq,
1768
2346
  const Index* quantizer,
@@ -1770,7 +2348,7 @@ InvertedListScanner* sel12_InvertedListScanner(
1770
2348
  const IDSelector* sel,
1771
2349
  bool r) {
1772
2350
  constexpr int SIMDWIDTH = Similarity::simdwidth;
1773
- using QuantizerClass = QuantizerTemplate<Codec, uniform, SIMDWIDTH>;
2351
+ using QuantizerClass = QuantizerTemplate<Codec, SCALING, SIMDWIDTH>;
1774
2352
  using DCClass = DCTemplate<QuantizerClass, Similarity, SIMDWIDTH>;
1775
2353
  return sel2_InvertedListScanner<DCClass>(
1776
2354
  sq, quantizer, store_pairs, sel, r);
@@ -1786,36 +2364,70 @@ InvertedListScanner* sel1_InvertedListScanner(
1786
2364
  constexpr int SIMDWIDTH = Similarity::simdwidth;
1787
2365
  switch (sq->qtype) {
1788
2366
  case ScalarQuantizer::QT_8bit_uniform:
1789
- return sel12_InvertedListScanner<Similarity, Codec8bit, true>(
2367
+ return sel12_InvertedListScanner<
2368
+ Similarity,
2369
+ Codec8bit,
2370
+ QuantizerTemplateScaling::UNIFORM>(
1790
2371
  sq, quantizer, store_pairs, sel, r);
1791
2372
  case ScalarQuantizer::QT_4bit_uniform:
1792
- return sel12_InvertedListScanner<Similarity, Codec4bit, true>(
2373
+ return sel12_InvertedListScanner<
2374
+ Similarity,
2375
+ Codec4bit,
2376
+ QuantizerTemplateScaling::UNIFORM>(
1793
2377
  sq, quantizer, store_pairs, sel, r);
1794
2378
  case ScalarQuantizer::QT_8bit:
1795
- return sel12_InvertedListScanner<Similarity, Codec8bit, false>(
2379
+ return sel12_InvertedListScanner<
2380
+ Similarity,
2381
+ Codec8bit,
2382
+ QuantizerTemplateScaling::NON_UNIFORM>(
1796
2383
  sq, quantizer, store_pairs, sel, r);
1797
2384
  case ScalarQuantizer::QT_4bit:
1798
- return sel12_InvertedListScanner<Similarity, Codec4bit, false>(
2385
+ return sel12_InvertedListScanner<
2386
+ Similarity,
2387
+ Codec4bit,
2388
+ QuantizerTemplateScaling::NON_UNIFORM>(
1799
2389
  sq, quantizer, store_pairs, sel, r);
1800
2390
  case ScalarQuantizer::QT_6bit:
1801
- return sel12_InvertedListScanner<Similarity, Codec6bit, false>(
2391
+ return sel12_InvertedListScanner<
2392
+ Similarity,
2393
+ Codec6bit,
2394
+ QuantizerTemplateScaling::NON_UNIFORM>(
1802
2395
  sq, quantizer, store_pairs, sel, r);
1803
2396
  case ScalarQuantizer::QT_fp16:
1804
2397
  return sel2_InvertedListScanner<DCTemplate<
1805
2398
  QuantizerFP16<SIMDWIDTH>,
1806
2399
  Similarity,
1807
2400
  SIMDWIDTH>>(sq, quantizer, store_pairs, sel, r);
2401
+ case ScalarQuantizer::QT_bf16:
2402
+ return sel2_InvertedListScanner<DCTemplate<
2403
+ QuantizerBF16<SIMDWIDTH>,
2404
+ Similarity,
2405
+ SIMDWIDTH>>(sq, quantizer, store_pairs, sel, r);
1808
2406
  case ScalarQuantizer::QT_8bit_direct:
2407
+ #if defined(__AVX512F__)
2408
+ if (sq->d % 32 == 0) {
2409
+ return sel2_InvertedListScanner<
2410
+ DistanceComputerByte<Similarity, SIMDWIDTH>>(
2411
+ sq, quantizer, store_pairs, sel, r);
2412
+ } else
2413
+ #elif defined(__AVX2__)
1809
2414
  if (sq->d % 16 == 0) {
1810
2415
  return sel2_InvertedListScanner<
1811
2416
  DistanceComputerByte<Similarity, SIMDWIDTH>>(
1812
2417
  sq, quantizer, store_pairs, sel, r);
1813
- } else {
2418
+ } else
2419
+ #endif
2420
+ {
1814
2421
  return sel2_InvertedListScanner<DCTemplate<
1815
2422
  Quantizer8bitDirect<SIMDWIDTH>,
1816
2423
  Similarity,
1817
2424
  SIMDWIDTH>>(sq, quantizer, store_pairs, sel, r);
1818
2425
  }
2426
+ case ScalarQuantizer::QT_8bit_direct_signed:
2427
+ return sel2_InvertedListScanner<DCTemplate<
2428
+ Quantizer8bitDirectSigned<SIMDWIDTH>,
2429
+ Similarity,
2430
+ SIMDWIDTH>>(sq, quantizer, store_pairs, sel, r);
1819
2431
  }
1820
2432
 
1821
2433
  FAISS_THROW_MSG("unknown qtype");
@@ -1849,7 +2461,12 @@ InvertedListScanner* ScalarQuantizer::select_InvertedListScanner(
1849
2461
  bool store_pairs,
1850
2462
  const IDSelector* sel,
1851
2463
  bool by_residual) const {
1852
- #if defined(USE_F16C) || defined(__aarch64__)
2464
+ #if defined(USE_AVX512_F16C)
2465
+ if (d % 16 == 0) {
2466
+ return sel0_InvertedListScanner<16>(
2467
+ mt, this, quantizer, store_pairs, sel, by_residual);
2468
+ } else
2469
+ #elif defined(USE_F16C) || defined(USE_NEON)
1853
2470
  if (d % 8 == 0) {
1854
2471
  return sel0_InvertedListScanner<8>(
1855
2472
  mt, this, quantizer, store_pairs, sel, by_residual);