faiss 0.3.1 → 0.3.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (293) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +8 -0
  3. data/LICENSE.txt +1 -1
  4. data/lib/faiss/version.rb +1 -1
  5. data/vendor/faiss/faiss/AutoTune.cpp +2 -2
  6. data/vendor/faiss/faiss/AutoTune.h +3 -3
  7. data/vendor/faiss/faiss/Clustering.cpp +37 -6
  8. data/vendor/faiss/faiss/Clustering.h +12 -3
  9. data/vendor/faiss/faiss/IVFlib.cpp +6 -3
  10. data/vendor/faiss/faiss/IVFlib.h +2 -2
  11. data/vendor/faiss/faiss/Index.cpp +6 -2
  12. data/vendor/faiss/faiss/Index.h +30 -8
  13. data/vendor/faiss/faiss/Index2Layer.cpp +2 -2
  14. data/vendor/faiss/faiss/Index2Layer.h +2 -2
  15. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +7 -7
  16. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +2 -2
  17. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +14 -16
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +2 -2
  19. data/vendor/faiss/faiss/IndexBinary.cpp +13 -2
  20. data/vendor/faiss/faiss/IndexBinary.h +8 -2
  21. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +2 -3
  22. data/vendor/faiss/faiss/IndexBinaryFlat.h +2 -2
  23. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -2
  24. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -2
  25. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +2 -7
  26. data/vendor/faiss/faiss/IndexBinaryHNSW.h +3 -3
  27. data/vendor/faiss/faiss/IndexBinaryHash.cpp +2 -3
  28. data/vendor/faiss/faiss/IndexBinaryHash.h +2 -2
  29. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +3 -3
  30. data/vendor/faiss/faiss/IndexBinaryIVF.h +2 -2
  31. data/vendor/faiss/faiss/IndexFastScan.cpp +32 -18
  32. data/vendor/faiss/faiss/IndexFastScan.h +11 -2
  33. data/vendor/faiss/faiss/IndexFlat.cpp +13 -10
  34. data/vendor/faiss/faiss/IndexFlat.h +2 -2
  35. data/vendor/faiss/faiss/IndexFlatCodes.cpp +170 -7
  36. data/vendor/faiss/faiss/IndexFlatCodes.h +25 -5
  37. data/vendor/faiss/faiss/IndexHNSW.cpp +156 -96
  38. data/vendor/faiss/faiss/IndexHNSW.h +54 -5
  39. data/vendor/faiss/faiss/IndexIDMap.cpp +19 -3
  40. data/vendor/faiss/faiss/IndexIDMap.h +5 -2
  41. data/vendor/faiss/faiss/IndexIVF.cpp +5 -6
  42. data/vendor/faiss/faiss/IndexIVF.h +13 -4
  43. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +21 -7
  44. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +5 -2
  45. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +3 -14
  46. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +2 -4
  47. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +201 -91
  48. data/vendor/faiss/faiss/IndexIVFFastScan.h +33 -9
  49. data/vendor/faiss/faiss/IndexIVFFlat.cpp +2 -2
  50. data/vendor/faiss/faiss/IndexIVFFlat.h +2 -2
  51. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +2 -2
  52. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +2 -2
  53. data/vendor/faiss/faiss/IndexIVFPQ.cpp +3 -6
  54. data/vendor/faiss/faiss/IndexIVFPQ.h +2 -2
  55. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +7 -14
  56. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +2 -4
  57. data/vendor/faiss/faiss/IndexIVFPQR.cpp +2 -2
  58. data/vendor/faiss/faiss/IndexIVFPQR.h +2 -2
  59. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +2 -3
  60. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +2 -2
  61. data/vendor/faiss/faiss/IndexLSH.cpp +2 -3
  62. data/vendor/faiss/faiss/IndexLSH.h +2 -2
  63. data/vendor/faiss/faiss/IndexLattice.cpp +3 -21
  64. data/vendor/faiss/faiss/IndexLattice.h +5 -24
  65. data/vendor/faiss/faiss/IndexNNDescent.cpp +2 -31
  66. data/vendor/faiss/faiss/IndexNNDescent.h +3 -3
  67. data/vendor/faiss/faiss/IndexNSG.cpp +2 -5
  68. data/vendor/faiss/faiss/IndexNSG.h +3 -3
  69. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +56 -0
  70. data/vendor/faiss/faiss/IndexNeuralNetCodec.h +49 -0
  71. data/vendor/faiss/faiss/IndexPQ.cpp +26 -26
  72. data/vendor/faiss/faiss/IndexPQ.h +2 -2
  73. data/vendor/faiss/faiss/IndexPQFastScan.cpp +2 -5
  74. data/vendor/faiss/faiss/IndexPQFastScan.h +2 -11
  75. data/vendor/faiss/faiss/IndexPreTransform.cpp +2 -2
  76. data/vendor/faiss/faiss/IndexPreTransform.h +3 -3
  77. data/vendor/faiss/faiss/IndexRefine.cpp +46 -9
  78. data/vendor/faiss/faiss/IndexRefine.h +9 -2
  79. data/vendor/faiss/faiss/IndexReplicas.cpp +2 -2
  80. data/vendor/faiss/faiss/IndexReplicas.h +2 -2
  81. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +2 -2
  82. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +2 -2
  83. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +5 -4
  84. data/vendor/faiss/faiss/IndexScalarQuantizer.h +2 -2
  85. data/vendor/faiss/faiss/IndexShards.cpp +2 -2
  86. data/vendor/faiss/faiss/IndexShards.h +2 -2
  87. data/vendor/faiss/faiss/IndexShardsIVF.cpp +2 -2
  88. data/vendor/faiss/faiss/IndexShardsIVF.h +2 -2
  89. data/vendor/faiss/faiss/MatrixStats.cpp +2 -2
  90. data/vendor/faiss/faiss/MatrixStats.h +2 -2
  91. data/vendor/faiss/faiss/MetaIndexes.cpp +2 -3
  92. data/vendor/faiss/faiss/MetaIndexes.h +2 -2
  93. data/vendor/faiss/faiss/MetricType.h +9 -4
  94. data/vendor/faiss/faiss/VectorTransform.cpp +2 -2
  95. data/vendor/faiss/faiss/VectorTransform.h +2 -2
  96. data/vendor/faiss/faiss/clone_index.cpp +2 -2
  97. data/vendor/faiss/faiss/clone_index.h +2 -2
  98. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +2 -2
  99. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +2 -2
  100. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +97 -19
  101. data/vendor/faiss/faiss/cppcontrib/factory_tools.cpp +192 -0
  102. data/vendor/faiss/faiss/cppcontrib/factory_tools.h +29 -0
  103. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2 -2
  104. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +85 -32
  105. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2 -2
  106. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +2 -2
  107. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +2 -2
  108. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +2 -2
  109. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +2 -2
  110. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +2 -2
  111. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +2 -5
  112. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +2 -2
  113. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +45 -13
  114. data/vendor/faiss/faiss/gpu/GpuCloner.h +2 -2
  115. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +12 -6
  116. data/vendor/faiss/faiss/gpu/GpuDistance.h +11 -7
  117. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +3 -3
  118. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +2 -2
  119. data/vendor/faiss/faiss/gpu/GpuIndex.h +10 -15
  120. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -2
  121. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +285 -0
  122. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +2 -2
  123. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +8 -2
  124. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +4 -2
  125. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +3 -3
  126. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +2 -2
  127. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +2 -2
  128. data/vendor/faiss/faiss/gpu/GpuResources.cpp +7 -2
  129. data/vendor/faiss/faiss/gpu/GpuResources.h +11 -4
  130. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +66 -11
  131. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +15 -5
  132. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +2 -2
  133. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +28 -23
  134. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +2 -2
  135. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +2 -2
  136. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +2 -2
  137. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +2 -2
  138. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +2 -2
  139. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +8 -2
  140. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +2 -3
  141. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +2 -2
  142. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +10 -7
  143. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +2 -2
  144. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +54 -54
  145. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +144 -77
  146. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +51 -51
  147. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +2 -2
  148. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +3 -3
  149. data/vendor/faiss/faiss/gpu/test/TestGpuResidualQuantizer.cpp +70 -0
  150. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +74 -4
  151. data/vendor/faiss/faiss/gpu/test/TestUtils.h +2 -2
  152. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +3 -3
  153. data/vendor/faiss/faiss/gpu/utils/{RaftUtils.h → CuvsUtils.h} +12 -11
  154. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +8 -2
  155. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +2 -2
  156. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +2 -2
  157. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +2 -2
  158. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +6 -3
  159. data/vendor/faiss/faiss/gpu/utils/Timer.h +3 -3
  160. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +79 -11
  161. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +17 -5
  162. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +27 -2
  163. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +11 -3
  164. data/vendor/faiss/faiss/impl/CodePacker.cpp +2 -2
  165. data/vendor/faiss/faiss/impl/CodePacker.h +2 -2
  166. data/vendor/faiss/faiss/impl/DistanceComputer.h +48 -2
  167. data/vendor/faiss/faiss/impl/FaissAssert.h +6 -4
  168. data/vendor/faiss/faiss/impl/FaissException.cpp +2 -2
  169. data/vendor/faiss/faiss/impl/FaissException.h +2 -3
  170. data/vendor/faiss/faiss/impl/HNSW.cpp +378 -205
  171. data/vendor/faiss/faiss/impl/HNSW.h +55 -24
  172. data/vendor/faiss/faiss/impl/IDSelector.cpp +2 -2
  173. data/vendor/faiss/faiss/impl/IDSelector.h +2 -2
  174. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +10 -10
  175. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +2 -2
  176. data/vendor/faiss/faiss/impl/LookupTableScaler.h +36 -2
  177. data/vendor/faiss/faiss/impl/NNDescent.cpp +15 -10
  178. data/vendor/faiss/faiss/impl/NNDescent.h +2 -2
  179. data/vendor/faiss/faiss/impl/NSG.cpp +26 -49
  180. data/vendor/faiss/faiss/impl/NSG.h +20 -8
  181. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +2 -2
  182. data/vendor/faiss/faiss/impl/PolysemousTraining.h +2 -2
  183. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +2 -4
  184. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +2 -2
  185. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +2 -2
  186. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +3 -2
  187. data/vendor/faiss/faiss/impl/ProductQuantizer.h +7 -3
  188. data/vendor/faiss/faiss/impl/Quantizer.h +2 -2
  189. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +2 -36
  190. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +3 -13
  191. data/vendor/faiss/faiss/impl/ResultHandler.h +153 -34
  192. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +721 -104
  193. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +5 -2
  194. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +2 -2
  195. data/vendor/faiss/faiss/impl/ThreadedIndex.h +2 -2
  196. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +7 -2
  197. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx512.h +248 -0
  198. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +2 -2
  199. data/vendor/faiss/faiss/impl/code_distance/code_distance-sve.h +440 -0
  200. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +55 -2
  201. data/vendor/faiss/faiss/impl/index_read.cpp +31 -20
  202. data/vendor/faiss/faiss/impl/index_read_utils.h +37 -0
  203. data/vendor/faiss/faiss/impl/index_write.cpp +30 -16
  204. data/vendor/faiss/faiss/impl/io.cpp +15 -7
  205. data/vendor/faiss/faiss/impl/io.h +6 -6
  206. data/vendor/faiss/faiss/impl/io_macros.h +8 -9
  207. data/vendor/faiss/faiss/impl/kmeans1d.cpp +2 -3
  208. data/vendor/faiss/faiss/impl/kmeans1d.h +2 -2
  209. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +2 -3
  210. data/vendor/faiss/faiss/impl/lattice_Zn.h +2 -2
  211. data/vendor/faiss/faiss/impl/platform_macros.h +34 -2
  212. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +13 -2
  213. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +20 -2
  214. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +3 -3
  215. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +450 -3
  216. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +8 -8
  217. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +3 -3
  218. data/vendor/faiss/faiss/impl/simd_result_handlers.h +151 -67
  219. data/vendor/faiss/faiss/index_factory.cpp +51 -34
  220. data/vendor/faiss/faiss/index_factory.h +2 -2
  221. data/vendor/faiss/faiss/index_io.h +14 -7
  222. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +30 -10
  223. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +5 -2
  224. data/vendor/faiss/faiss/invlists/DirectMap.cpp +11 -3
  225. data/vendor/faiss/faiss/invlists/DirectMap.h +2 -2
  226. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +57 -19
  227. data/vendor/faiss/faiss/invlists/InvertedLists.h +20 -11
  228. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +2 -2
  229. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +2 -2
  230. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +23 -9
  231. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +4 -3
  232. data/vendor/faiss/faiss/python/python_callbacks.cpp +5 -5
  233. data/vendor/faiss/faiss/python/python_callbacks.h +2 -2
  234. data/vendor/faiss/faiss/utils/AlignedTable.h +5 -3
  235. data/vendor/faiss/faiss/utils/Heap.cpp +2 -2
  236. data/vendor/faiss/faiss/utils/Heap.h +107 -2
  237. data/vendor/faiss/faiss/utils/NeuralNet.cpp +346 -0
  238. data/vendor/faiss/faiss/utils/NeuralNet.h +147 -0
  239. data/vendor/faiss/faiss/utils/WorkerThread.cpp +2 -2
  240. data/vendor/faiss/faiss/utils/WorkerThread.h +2 -2
  241. data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +2 -2
  242. data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +2 -2
  243. data/vendor/faiss/faiss/utils/approx_topk/generic.h +2 -2
  244. data/vendor/faiss/faiss/utils/approx_topk/mode.h +2 -2
  245. data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +2 -2
  246. data/vendor/faiss/faiss/utils/bf16.h +36 -0
  247. data/vendor/faiss/faiss/utils/distances.cpp +249 -90
  248. data/vendor/faiss/faiss/utils/distances.h +8 -8
  249. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +2 -2
  250. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
  251. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
  252. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +2 -2
  253. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +2 -2
  254. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +2 -2
  255. data/vendor/faiss/faiss/utils/distances_simd.cpp +1543 -56
  256. data/vendor/faiss/faiss/utils/extra_distances-inl.h +72 -2
  257. data/vendor/faiss/faiss/utils/extra_distances.cpp +87 -140
  258. data/vendor/faiss/faiss/utils/extra_distances.h +5 -4
  259. data/vendor/faiss/faiss/utils/fp16-arm.h +2 -2
  260. data/vendor/faiss/faiss/utils/fp16-fp16c.h +2 -2
  261. data/vendor/faiss/faiss/utils/fp16-inl.h +2 -2
  262. data/vendor/faiss/faiss/utils/fp16.h +2 -2
  263. data/vendor/faiss/faiss/utils/hamming-inl.h +2 -2
  264. data/vendor/faiss/faiss/utils/hamming.cpp +3 -4
  265. data/vendor/faiss/faiss/utils/hamming.h +2 -2
  266. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +2 -2
  267. data/vendor/faiss/faiss/utils/hamming_distance/avx512-inl.h +490 -0
  268. data/vendor/faiss/faiss/utils/hamming_distance/common.h +2 -2
  269. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +6 -3
  270. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +7 -3
  271. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +5 -5
  272. data/vendor/faiss/faiss/utils/ordered_key_value.h +2 -2
  273. data/vendor/faiss/faiss/utils/partitioning.cpp +2 -2
  274. data/vendor/faiss/faiss/utils/partitioning.h +2 -2
  275. data/vendor/faiss/faiss/utils/prefetch.h +2 -2
  276. data/vendor/faiss/faiss/utils/quantize_lut.cpp +2 -2
  277. data/vendor/faiss/faiss/utils/quantize_lut.h +2 -2
  278. data/vendor/faiss/faiss/utils/random.cpp +45 -2
  279. data/vendor/faiss/faiss/utils/random.h +27 -2
  280. data/vendor/faiss/faiss/utils/simdlib.h +12 -3
  281. data/vendor/faiss/faiss/utils/simdlib_avx2.h +2 -2
  282. data/vendor/faiss/faiss/utils/simdlib_avx512.h +296 -0
  283. data/vendor/faiss/faiss/utils/simdlib_emulated.h +2 -2
  284. data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -4
  285. data/vendor/faiss/faiss/utils/simdlib_ppc64.h +1084 -0
  286. data/vendor/faiss/faiss/utils/sorting.cpp +2 -2
  287. data/vendor/faiss/faiss/utils/sorting.h +2 -2
  288. data/vendor/faiss/faiss/utils/transpose/transpose-avx2-inl.h +2 -2
  289. data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +176 -0
  290. data/vendor/faiss/faiss/utils/utils.cpp +17 -10
  291. data/vendor/faiss/faiss/utils/utils.h +7 -3
  292. metadata +22 -11
  293. data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +0 -102
@@ -1,5 +1,5 @@
1
- /**
2
- * Copyright (c) Facebook, Inc. and its affiliates.
1
+ /*
2
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3
3
  *
4
4
  * This source code is licensed under the MIT license found in the
5
5
  * LICENSE file in the root directory of this source tree.
@@ -23,6 +23,7 @@
23
23
  #include <faiss/impl/AuxIndexStructures.h>
24
24
  #include <faiss/impl/FaissAssert.h>
25
25
  #include <faiss/impl/IDSelector.h>
26
+ #include <faiss/utils/bf16.h>
26
27
  #include <faiss/utils/fp16.h>
27
28
  #include <faiss/utils/utils.h>
28
29
 
@@ -43,7 +44,9 @@ namespace faiss {
43
44
  * that hides the template mess.
44
45
  ********************************************************************/
45
46
 
46
- #ifdef __AVX2__
47
+ #if defined(__AVX512F__) && defined(__F16C__)
48
+ #define USE_AVX512_F16C
49
+ #elif defined(__AVX2__)
47
50
  #ifdef __F16C__
48
51
  #define USE_F16C
49
52
  #else
@@ -52,6 +55,15 @@ namespace faiss {
52
55
  #endif
53
56
  #endif
54
57
 
58
+ #if defined(__aarch64__)
59
+ #if defined(__GNUC__) && __GNUC__ < 8
60
+ #warning \
61
+ "Cannot enable NEON optimizations in scalar quantizer if the compiler is GCC<8"
62
+ #else
63
+ #define USE_NEON
64
+ #endif
65
+ #endif
66
+
55
67
  namespace {
56
68
 
57
69
  typedef ScalarQuantizer::QuantizerType QuantizerType;
@@ -78,7 +90,17 @@ struct Codec8bit {
78
90
  return (code[i] + 0.5f) / 255.0f;
79
91
  }
80
92
 
81
- #ifdef __AVX2__
93
+ #if defined(__AVX512F__)
94
+ static FAISS_ALWAYS_INLINE __m512
95
+ decode_16_components(const uint8_t* code, int i) {
96
+ const __m128i c16 = _mm_loadu_si128((__m128i*)(code + i));
97
+ const __m512i i32 = _mm512_cvtepu8_epi32(c16);
98
+ const __m512 f16 = _mm512_cvtepi32_ps(i32);
99
+ const __m512 half_one_255 = _mm512_set1_ps(0.5f / 255.f);
100
+ const __m512 one_255 = _mm512_set1_ps(1.f / 255.f);
101
+ return _mm512_fmadd_ps(f16, one_255, half_one_255);
102
+ }
103
+ #elif defined(__AVX2__)
82
104
  static FAISS_ALWAYS_INLINE __m256
83
105
  decode_8_components(const uint8_t* code, int i) {
84
106
  const uint64_t c8 = *(uint64_t*)(code + i);
@@ -92,7 +114,7 @@ struct Codec8bit {
92
114
  }
93
115
  #endif
94
116
 
95
- #ifdef __aarch64__
117
+ #ifdef USE_NEON
96
118
  static FAISS_ALWAYS_INLINE float32x4x2_t
97
119
  decode_8_components(const uint8_t* code, int i) {
98
120
  float32_t result[8] = {};
@@ -101,8 +123,7 @@ struct Codec8bit {
101
123
  }
102
124
  float32x4_t res1 = vld1q_f32(result);
103
125
  float32x4_t res2 = vld1q_f32(result + 4);
104
- float32x4x2_t res = vzipq_f32(res1, res2);
105
- return vuzpq_f32(res.val[0], res.val[1]);
126
+ return {res1, res2};
106
127
  }
107
128
  #endif
108
129
  };
@@ -121,7 +142,26 @@ struct Codec4bit {
121
142
  return (((code[i / 2] >> ((i & 1) << 2)) & 0xf) + 0.5f) / 15.0f;
122
143
  }
123
144
 
124
- #ifdef __AVX2__
145
+ #if defined(__AVX512F__)
146
+ static FAISS_ALWAYS_INLINE __m512
147
+ decode_16_components(const uint8_t* code, int i) {
148
+ uint64_t c8 = *(uint64_t*)(code + (i >> 1));
149
+ uint64_t mask = 0x0f0f0f0f0f0f0f0f;
150
+ uint64_t c8ev = c8 & mask;
151
+ uint64_t c8od = (c8 >> 4) & mask;
152
+
153
+ __m128i c16 =
154
+ _mm_unpacklo_epi8(_mm_set1_epi64x(c8ev), _mm_set1_epi64x(c8od));
155
+ __m256i c8lo = _mm256_cvtepu8_epi32(c16);
156
+ __m256i c8hi = _mm256_cvtepu8_epi32(_mm_srli_si128(c16, 8));
157
+ __m512i i16 = _mm512_castsi256_si512(c8lo);
158
+ i16 = _mm512_inserti32x8(i16, c8hi, 1);
159
+ __m512 f16 = _mm512_cvtepi32_ps(i16);
160
+ const __m512 half_one_255 = _mm512_set1_ps(0.5f / 15.f);
161
+ const __m512 one_255 = _mm512_set1_ps(1.f / 15.f);
162
+ return _mm512_fmadd_ps(f16, one_255, half_one_255);
163
+ }
164
+ #elif defined(__AVX2__)
125
165
  static FAISS_ALWAYS_INLINE __m256
126
166
  decode_8_components(const uint8_t* code, int i) {
127
167
  uint32_t c4 = *(uint32_t*)(code + (i >> 1));
@@ -144,7 +184,7 @@ struct Codec4bit {
144
184
  }
145
185
  #endif
146
186
 
147
- #ifdef __aarch64__
187
+ #ifdef USE_NEON
148
188
  static FAISS_ALWAYS_INLINE float32x4x2_t
149
189
  decode_8_components(const uint8_t* code, int i) {
150
190
  float32_t result[8] = {};
@@ -153,8 +193,7 @@ struct Codec4bit {
153
193
  }
154
194
  float32x4_t res1 = vld1q_f32(result);
155
195
  float32x4_t res2 = vld1q_f32(result + 4);
156
- float32x4x2_t res = vzipq_f32(res1, res2);
157
- return vuzpq_f32(res.val[0], res.val[1]);
196
+ return {res1, res2};
158
197
  }
159
198
  #endif
160
199
  };
@@ -208,7 +247,56 @@ struct Codec6bit {
208
247
  return (bits + 0.5f) / 63.0f;
209
248
  }
210
249
 
211
- #ifdef __AVX2__
250
+ #if defined(__AVX512F__)
251
+
252
+ static FAISS_ALWAYS_INLINE __m512
253
+ decode_16_components(const uint8_t* code, int i) {
254
+ // pure AVX512 implementation (not necessarily the fastest).
255
+ // see:
256
+ // https://github.com/zilliztech/knowhere/blob/main/thirdparty/faiss/faiss/impl/ScalarQuantizerCodec_avx512.h
257
+
258
+ // clang-format off
259
+
260
+ // 16 components, 16x6 bit=12 bytes
261
+ const __m128i bit_6v =
262
+ _mm_maskz_loadu_epi8(0b0000111111111111, code + (i >> 2) * 3);
263
+ const __m256i bit_6v_256 = _mm256_broadcast_i32x4(bit_6v);
264
+
265
+ // 00 01 02 03 04 05 06 07 08 09 0A 0B 0C 0D 0E 0F
266
+ // 00 01 02 03
267
+ const __m256i shuffle_mask = _mm256_setr_epi16(
268
+ 0xFF00, 0x0100, 0x0201, 0xFF02,
269
+ 0xFF03, 0x0403, 0x0504, 0xFF05,
270
+ 0xFF06, 0x0706, 0x0807, 0xFF08,
271
+ 0xFF09, 0x0A09, 0x0B0A, 0xFF0B);
272
+ const __m256i shuffled = _mm256_shuffle_epi8(bit_6v_256, shuffle_mask);
273
+
274
+ // 0: xxxxxxxx xx543210
275
+ // 1: xxxx5432 10xxxxxx
276
+ // 2: xxxxxx54 3210xxxx
277
+ // 3: xxxxxxxx 543210xx
278
+ const __m256i shift_right_v = _mm256_setr_epi16(
279
+ 0x0U, 0x6U, 0x4U, 0x2U,
280
+ 0x0U, 0x6U, 0x4U, 0x2U,
281
+ 0x0U, 0x6U, 0x4U, 0x2U,
282
+ 0x0U, 0x6U, 0x4U, 0x2U);
283
+ __m256i shuffled_shifted = _mm256_srlv_epi16(shuffled, shift_right_v);
284
+
285
+ // remove unneeded bits
286
+ shuffled_shifted =
287
+ _mm256_and_si256(shuffled_shifted, _mm256_set1_epi16(0x003F));
288
+
289
+ // scale
290
+ const __m512 f8 =
291
+ _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(shuffled_shifted));
292
+ const __m512 half_one_255 = _mm512_set1_ps(0.5f / 63.f);
293
+ const __m512 one_255 = _mm512_set1_ps(1.f / 63.f);
294
+ return _mm512_fmadd_ps(f8, one_255, half_one_255);
295
+
296
+ // clang-format on
297
+ }
298
+
299
+ #elif defined(__AVX2__)
212
300
 
213
301
  /* Load 6 bytes that represent 8 6-bit values, return them as a
214
302
  * 8*32 bit vector register */
@@ -257,7 +345,7 @@ struct Codec6bit {
257
345
 
258
346
  #endif
259
347
 
260
- #ifdef __aarch64__
348
+ #ifdef USE_NEON
261
349
  static FAISS_ALWAYS_INLINE float32x4x2_t
262
350
  decode_8_components(const uint8_t* code, int i) {
263
351
  float32_t result[8] = {};
@@ -266,8 +354,7 @@ struct Codec6bit {
266
354
  }
267
355
  float32x4_t res1 = vld1q_f32(result);
268
356
  float32x4_t res2 = vld1q_f32(result + 4);
269
- float32x4x2_t res = vzipq_f32(res1, res2);
270
- return vuzpq_f32(res.val[0], res.val[1]);
357
+ return {res1, res2};
271
358
  }
272
359
  #endif
273
360
  };
@@ -277,11 +364,14 @@ struct Codec6bit {
277
364
  * through a codec
278
365
  *******************************************************************/
279
366
 
280
- template <class Codec, bool uniform, int SIMD>
367
+ enum class QuantizerTemplateScaling { UNIFORM = 0, NON_UNIFORM = 1 };
368
+
369
+ template <class Codec, QuantizerTemplateScaling SCALING, int SIMD>
281
370
  struct QuantizerTemplate {};
282
371
 
283
372
  template <class Codec>
284
- struct QuantizerTemplate<Codec, true, 1> : ScalarQuantizer::SQuantizer {
373
+ struct QuantizerTemplate<Codec, QuantizerTemplateScaling::UNIFORM, 1>
374
+ : ScalarQuantizer::SQuantizer {
285
375
  const size_t d;
286
376
  const float vmin, vdiff;
287
377
 
@@ -318,12 +408,33 @@ struct QuantizerTemplate<Codec, true, 1> : ScalarQuantizer::SQuantizer {
318
408
  }
319
409
  };
320
410
 
321
- #ifdef __AVX2__
411
+ #if defined(__AVX512F__)
412
+
413
+ template <class Codec>
414
+ struct QuantizerTemplate<Codec, QuantizerTemplateScaling::UNIFORM, 16>
415
+ : QuantizerTemplate<Codec, QuantizerTemplateScaling::UNIFORM, 1> {
416
+ QuantizerTemplate(size_t d, const std::vector<float>& trained)
417
+ : QuantizerTemplate<Codec, QuantizerTemplateScaling::UNIFORM, 1>(
418
+ d,
419
+ trained) {}
420
+
421
+ FAISS_ALWAYS_INLINE __m512
422
+ reconstruct_16_components(const uint8_t* code, int i) const {
423
+ __m512 xi = Codec::decode_16_components(code, i);
424
+ return _mm512_fmadd_ps(
425
+ xi, _mm512_set1_ps(this->vdiff), _mm512_set1_ps(this->vmin));
426
+ }
427
+ };
428
+
429
+ #elif defined(__AVX2__)
322
430
 
323
431
  template <class Codec>
324
- struct QuantizerTemplate<Codec, true, 8> : QuantizerTemplate<Codec, true, 1> {
432
+ struct QuantizerTemplate<Codec, QuantizerTemplateScaling::UNIFORM, 8>
433
+ : QuantizerTemplate<Codec, QuantizerTemplateScaling::UNIFORM, 1> {
325
434
  QuantizerTemplate(size_t d, const std::vector<float>& trained)
326
- : QuantizerTemplate<Codec, true, 1>(d, trained) {}
435
+ : QuantizerTemplate<Codec, QuantizerTemplateScaling::UNIFORM, 1>(
436
+ d,
437
+ trained) {}
327
438
 
328
439
  FAISS_ALWAYS_INLINE __m256
329
440
  reconstruct_8_components(const uint8_t* code, int i) const {
@@ -335,33 +446,35 @@ struct QuantizerTemplate<Codec, true, 8> : QuantizerTemplate<Codec, true, 1> {
335
446
 
336
447
  #endif
337
448
 
338
- #ifdef __aarch64__
449
+ #ifdef USE_NEON
339
450
 
340
451
  template <class Codec>
341
- struct QuantizerTemplate<Codec, true, 8> : QuantizerTemplate<Codec, true, 1> {
452
+ struct QuantizerTemplate<Codec, QuantizerTemplateScaling::UNIFORM, 8>
453
+ : QuantizerTemplate<Codec, QuantizerTemplateScaling::UNIFORM, 1> {
342
454
  QuantizerTemplate(size_t d, const std::vector<float>& trained)
343
- : QuantizerTemplate<Codec, true, 1>(d, trained) {}
455
+ : QuantizerTemplate<Codec, QuantizerTemplateScaling::UNIFORM, 1>(
456
+ d,
457
+ trained) {}
344
458
 
345
459
  FAISS_ALWAYS_INLINE float32x4x2_t
346
460
  reconstruct_8_components(const uint8_t* code, int i) const {
347
461
  float32x4x2_t xi = Codec::decode_8_components(code, i);
348
- float32x4x2_t res = vzipq_f32(
349
- vfmaq_f32(
462
+ return {vfmaq_f32(
350
463
  vdupq_n_f32(this->vmin),
351
464
  xi.val[0],
352
465
  vdupq_n_f32(this->vdiff)),
353
466
  vfmaq_f32(
354
467
  vdupq_n_f32(this->vmin),
355
468
  xi.val[1],
356
- vdupq_n_f32(this->vdiff)));
357
- return vuzpq_f32(res.val[0], res.val[1]);
469
+ vdupq_n_f32(this->vdiff))};
358
470
  }
359
471
  };
360
472
 
361
473
  #endif
362
474
 
363
475
  template <class Codec>
364
- struct QuantizerTemplate<Codec, false, 1> : ScalarQuantizer::SQuantizer {
476
+ struct QuantizerTemplate<Codec, QuantizerTemplateScaling::NON_UNIFORM, 1>
477
+ : ScalarQuantizer::SQuantizer {
365
478
  const size_t d;
366
479
  const float *vmin, *vdiff;
367
480
 
@@ -398,12 +511,37 @@ struct QuantizerTemplate<Codec, false, 1> : ScalarQuantizer::SQuantizer {
398
511
  }
399
512
  };
400
513
 
401
- #ifdef __AVX2__
514
+ #if defined(__AVX512F__)
515
+
516
+ template <class Codec>
517
+ struct QuantizerTemplate<Codec, QuantizerTemplateScaling::NON_UNIFORM, 16>
518
+ : QuantizerTemplate<Codec, QuantizerTemplateScaling::NON_UNIFORM, 1> {
519
+ QuantizerTemplate(size_t d, const std::vector<float>& trained)
520
+ : QuantizerTemplate<
521
+ Codec,
522
+ QuantizerTemplateScaling::NON_UNIFORM,
523
+ 1>(d, trained) {}
524
+
525
+ FAISS_ALWAYS_INLINE __m512
526
+ reconstruct_16_components(const uint8_t* code, int i) const {
527
+ __m512 xi = Codec::decode_16_components(code, i);
528
+ return _mm512_fmadd_ps(
529
+ xi,
530
+ _mm512_loadu_ps(this->vdiff + i),
531
+ _mm512_loadu_ps(this->vmin + i));
532
+ }
533
+ };
534
+
535
+ #elif defined(__AVX2__)
402
536
 
403
537
  template <class Codec>
404
- struct QuantizerTemplate<Codec, false, 8> : QuantizerTemplate<Codec, false, 1> {
538
+ struct QuantizerTemplate<Codec, QuantizerTemplateScaling::NON_UNIFORM, 8>
539
+ : QuantizerTemplate<Codec, QuantizerTemplateScaling::NON_UNIFORM, 1> {
405
540
  QuantizerTemplate(size_t d, const std::vector<float>& trained)
406
- : QuantizerTemplate<Codec, false, 1>(d, trained) {}
541
+ : QuantizerTemplate<
542
+ Codec,
543
+ QuantizerTemplateScaling::NON_UNIFORM,
544
+ 1>(d, trained) {}
407
545
 
408
546
  FAISS_ALWAYS_INLINE __m256
409
547
  reconstruct_8_components(const uint8_t* code, int i) const {
@@ -417,12 +555,16 @@ struct QuantizerTemplate<Codec, false, 8> : QuantizerTemplate<Codec, false, 1> {
417
555
 
418
556
  #endif
419
557
 
420
- #ifdef __aarch64__
558
+ #ifdef USE_NEON
421
559
 
422
560
  template <class Codec>
423
- struct QuantizerTemplate<Codec, false, 8> : QuantizerTemplate<Codec, false, 1> {
561
+ struct QuantizerTemplate<Codec, QuantizerTemplateScaling::NON_UNIFORM, 8>
562
+ : QuantizerTemplate<Codec, QuantizerTemplateScaling::NON_UNIFORM, 1> {
424
563
  QuantizerTemplate(size_t d, const std::vector<float>& trained)
425
- : QuantizerTemplate<Codec, false, 1>(d, trained) {}
564
+ : QuantizerTemplate<
565
+ Codec,
566
+ QuantizerTemplateScaling::NON_UNIFORM,
567
+ 1>(d, trained) {}
426
568
 
427
569
  FAISS_ALWAYS_INLINE float32x4x2_t
428
570
  reconstruct_8_components(const uint8_t* code, int i) const {
@@ -431,10 +573,8 @@ struct QuantizerTemplate<Codec, false, 8> : QuantizerTemplate<Codec, false, 1> {
431
573
  float32x4x2_t vmin_8 = vld1q_f32_x2(this->vmin + i);
432
574
  float32x4x2_t vdiff_8 = vld1q_f32_x2(this->vdiff + i);
433
575
 
434
- float32x4x2_t res = vzipq_f32(
435
- vfmaq_f32(vmin_8.val[0], xi.val[0], vdiff_8.val[0]),
436
- vfmaq_f32(vmin_8.val[1], xi.val[1], vdiff_8.val[1]));
437
- return vuzpq_f32(res.val[0], res.val[1]);
576
+ return {vfmaq_f32(vmin_8.val[0], xi.val[0], vdiff_8.val[0]),
577
+ vfmaq_f32(vmin_8.val[1], xi.val[1], vdiff_8.val[1])};
438
578
  }
439
579
  };
440
580
 
@@ -471,7 +611,23 @@ struct QuantizerFP16<1> : ScalarQuantizer::SQuantizer {
471
611
  }
472
612
  };
473
613
 
474
- #ifdef USE_F16C
614
+ #if defined(USE_AVX512_F16C)
615
+
616
+ template <>
617
+ struct QuantizerFP16<16> : QuantizerFP16<1> {
618
+ QuantizerFP16(size_t d, const std::vector<float>& trained)
619
+ : QuantizerFP16<1>(d, trained) {}
620
+
621
+ FAISS_ALWAYS_INLINE __m512
622
+ reconstruct_16_components(const uint8_t* code, int i) const {
623
+ __m256i codei = _mm256_loadu_si256((const __m256i*)(code + 2 * i));
624
+ return _mm512_cvtph_ps(codei);
625
+ }
626
+ };
627
+
628
+ #endif
629
+
630
+ #if defined(USE_F16C)
475
631
 
476
632
  template <>
477
633
  struct QuantizerFP16<8> : QuantizerFP16<1> {
@@ -487,7 +643,7 @@ struct QuantizerFP16<8> : QuantizerFP16<1> {
487
643
 
488
644
  #endif
489
645
 
490
- #ifdef __aarch64__
646
+ #ifdef USE_NEON
491
647
 
492
648
  template <>
493
649
  struct QuantizerFP16<8> : QuantizerFP16<1> {
@@ -496,10 +652,90 @@ struct QuantizerFP16<8> : QuantizerFP16<1> {
496
652
 
497
653
  FAISS_ALWAYS_INLINE float32x4x2_t
498
654
  reconstruct_8_components(const uint8_t* code, int i) const {
499
- uint16x4x2_t codei = vld2_u16((const uint16_t*)(code + 2 * i));
500
- return vzipq_f32(
501
- vcvt_f32_f16(vreinterpret_f16_u16(codei.val[0])),
502
- vcvt_f32_f16(vreinterpret_f16_u16(codei.val[1])));
655
+ uint16x4x2_t codei = vld1_u16_x2((const uint16_t*)(code + 2 * i));
656
+ return {vcvt_f32_f16(vreinterpret_f16_u16(codei.val[0])),
657
+ vcvt_f32_f16(vreinterpret_f16_u16(codei.val[1]))};
658
+ }
659
+ };
660
+ #endif
661
+
662
+ /*******************************************************************
663
+ * BF16 quantizer
664
+ *******************************************************************/
665
+
666
+ template <int SIMDWIDTH>
667
+ struct QuantizerBF16 {};
668
+
669
+ template <>
670
+ struct QuantizerBF16<1> : ScalarQuantizer::SQuantizer {
671
+ const size_t d;
672
+
673
+ QuantizerBF16(size_t d, const std::vector<float>& /* unused */) : d(d) {}
674
+
675
+ void encode_vector(const float* x, uint8_t* code) const final {
676
+ for (size_t i = 0; i < d; i++) {
677
+ ((uint16_t*)code)[i] = encode_bf16(x[i]);
678
+ }
679
+ }
680
+
681
+ void decode_vector(const uint8_t* code, float* x) const final {
682
+ for (size_t i = 0; i < d; i++) {
683
+ x[i] = decode_bf16(((uint16_t*)code)[i]);
684
+ }
685
+ }
686
+
687
+ FAISS_ALWAYS_INLINE float reconstruct_component(const uint8_t* code, int i)
688
+ const {
689
+ return decode_bf16(((uint16_t*)code)[i]);
690
+ }
691
+ };
692
+
693
+ #if defined(__AVX512F__)
694
+
695
+ template <>
696
+ struct QuantizerBF16<16> : QuantizerBF16<1> {
697
+ QuantizerBF16(size_t d, const std::vector<float>& trained)
698
+ : QuantizerBF16<1>(d, trained) {}
699
+ FAISS_ALWAYS_INLINE __m512
700
+ reconstruct_16_components(const uint8_t* code, int i) const {
701
+ __m256i code_256i = _mm256_loadu_si256((const __m256i*)(code + 2 * i));
702
+ __m512i code_512i = _mm512_cvtepu16_epi32(code_256i);
703
+ code_512i = _mm512_slli_epi32(code_512i, 16);
704
+ return _mm512_castsi512_ps(code_512i);
705
+ }
706
+ };
707
+
708
+ #elif defined(__AVX2__)
709
+
710
+ template <>
711
+ struct QuantizerBF16<8> : QuantizerBF16<1> {
712
+ QuantizerBF16(size_t d, const std::vector<float>& trained)
713
+ : QuantizerBF16<1>(d, trained) {}
714
+
715
+ FAISS_ALWAYS_INLINE __m256
716
+ reconstruct_8_components(const uint8_t* code, int i) const {
717
+ __m128i code_128i = _mm_loadu_si128((const __m128i*)(code + 2 * i));
718
+ __m256i code_256i = _mm256_cvtepu16_epi32(code_128i);
719
+ code_256i = _mm256_slli_epi32(code_256i, 16);
720
+ return _mm256_castsi256_ps(code_256i);
721
+ }
722
+ };
723
+
724
+ #endif
725
+
726
+ #ifdef USE_NEON
727
+
728
+ template <>
729
+ struct QuantizerBF16<8> : QuantizerBF16<1> {
730
+ QuantizerBF16(size_t d, const std::vector<float>& trained)
731
+ : QuantizerBF16<1>(d, trained) {}
732
+
733
+ FAISS_ALWAYS_INLINE float32x4x2_t
734
+ reconstruct_8_components(const uint8_t* code, int i) const {
735
+ uint16x4x2_t codei = vld1_u16_x2((const uint16_t*)(code + 2 * i));
736
+ return {vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(codei.val[0]), 16)),
737
+ vreinterpretq_f32_u32(
738
+ vshlq_n_u32(vmovl_u16(codei.val[1]), 16))};
503
739
  }
504
740
  };
505
741
  #endif
@@ -536,7 +772,22 @@ struct Quantizer8bitDirect<1> : ScalarQuantizer::SQuantizer {
536
772
  }
537
773
  };
538
774
 
539
- #ifdef __AVX2__
775
+ #if defined(__AVX512F__)
776
+
777
+ template <>
778
+ struct Quantizer8bitDirect<16> : Quantizer8bitDirect<1> {
779
+ Quantizer8bitDirect(size_t d, const std::vector<float>& trained)
780
+ : Quantizer8bitDirect<1>(d, trained) {}
781
+
782
+ FAISS_ALWAYS_INLINE __m512
783
+ reconstruct_16_components(const uint8_t* code, int i) const {
784
+ __m128i x16 = _mm_loadu_si128((__m128i*)(code + i)); // 16 * int8
785
+ __m512i y16 = _mm512_cvtepu8_epi32(x16); // 16 * int32
786
+ return _mm512_cvtepi32_ps(y16); // 16 * float32
787
+ }
788
+ };
789
+
790
+ #elif defined(__AVX2__)
540
791
 
541
792
  template <>
542
793
  struct Quantizer8bitDirect<8> : Quantizer8bitDirect<1> {
@@ -553,7 +804,7 @@ struct Quantizer8bitDirect<8> : Quantizer8bitDirect<1> {
553
804
 
554
805
  #endif
555
806
 
556
- #ifdef __aarch64__
807
+ #ifdef USE_NEON
557
808
 
558
809
  template <>
559
810
  struct Quantizer8bitDirect<8> : Quantizer8bitDirect<1> {
@@ -562,14 +813,107 @@ struct Quantizer8bitDirect<8> : Quantizer8bitDirect<1> {
562
813
 
563
814
  FAISS_ALWAYS_INLINE float32x4x2_t
564
815
  reconstruct_8_components(const uint8_t* code, int i) const {
565
- float32_t result[8] = {};
566
- for (size_t j = 0; j < 8; j++) {
567
- result[j] = code[i + j];
816
+ uint8x8_t x8 = vld1_u8((const uint8_t*)(code + i));
817
+ uint16x8_t y8 = vmovl_u8(x8);
818
+ uint16x4_t y8_0 = vget_low_u16(y8);
819
+ uint16x4_t y8_1 = vget_high_u16(y8);
820
+
821
+ // convert uint16 -> uint32 -> fp32
822
+ return {vcvtq_f32_u32(vmovl_u16(y8_0)), vcvtq_f32_u32(vmovl_u16(y8_1))};
823
+ }
824
+ };
825
+
826
+ #endif
827
+
828
+ /*******************************************************************
829
+ * 8bit_direct_signed quantizer
830
+ *******************************************************************/
831
+
832
+ template <int SIMDWIDTH>
833
+ struct Quantizer8bitDirectSigned {};
834
+
835
+ template <>
836
+ struct Quantizer8bitDirectSigned<1> : ScalarQuantizer::SQuantizer {
837
+ const size_t d;
838
+
839
+ Quantizer8bitDirectSigned(size_t d, const std::vector<float>& /* unused */)
840
+ : d(d) {}
841
+
842
+ void encode_vector(const float* x, uint8_t* code) const final {
843
+ for (size_t i = 0; i < d; i++) {
844
+ code[i] = (uint8_t)(x[i] + 128);
568
845
  }
569
- float32x4_t res1 = vld1q_f32(result);
570
- float32x4_t res2 = vld1q_f32(result + 4);
571
- float32x4x2_t res = vzipq_f32(res1, res2);
572
- return vuzpq_f32(res.val[0], res.val[1]);
846
+ }
847
+
848
+ void decode_vector(const uint8_t* code, float* x) const final {
849
+ for (size_t i = 0; i < d; i++) {
850
+ x[i] = code[i] - 128;
851
+ }
852
+ }
853
+
854
+ FAISS_ALWAYS_INLINE float reconstruct_component(const uint8_t* code, int i)
855
+ const {
856
+ return code[i] - 128;
857
+ }
858
+ };
859
+
860
+ #if defined(__AVX512F__)
861
+
862
+ template <>
863
+ struct Quantizer8bitDirectSigned<16> : Quantizer8bitDirectSigned<1> {
864
+ Quantizer8bitDirectSigned(size_t d, const std::vector<float>& trained)
865
+ : Quantizer8bitDirectSigned<1>(d, trained) {}
866
+
867
+ FAISS_ALWAYS_INLINE __m512
868
+ reconstruct_16_components(const uint8_t* code, int i) const {
869
+ __m128i x16 = _mm_loadu_si128((__m128i*)(code + i)); // 16 * int8
870
+ __m512i y16 = _mm512_cvtepu8_epi32(x16); // 16 * int32
871
+ __m512i c16 = _mm512_set1_epi32(128);
872
+ __m512i z16 = _mm512_sub_epi32(y16, c16); // subtract 128 from all lanes
873
+ return _mm512_cvtepi32_ps(z16); // 16 * float32
874
+ }
875
+ };
876
+
877
+ #elif defined(__AVX2__)
878
+
879
+ template <>
880
+ struct Quantizer8bitDirectSigned<8> : Quantizer8bitDirectSigned<1> {
881
+ Quantizer8bitDirectSigned(size_t d, const std::vector<float>& trained)
882
+ : Quantizer8bitDirectSigned<1>(d, trained) {}
883
+
884
+ FAISS_ALWAYS_INLINE __m256
885
+ reconstruct_8_components(const uint8_t* code, int i) const {
886
+ __m128i x8 = _mm_loadl_epi64((__m128i*)(code + i)); // 8 * int8
887
+ __m256i y8 = _mm256_cvtepu8_epi32(x8); // 8 * int32
888
+ __m256i c8 = _mm256_set1_epi32(128);
889
+ __m256i z8 = _mm256_sub_epi32(y8, c8); // subtract 128 from all lanes
890
+ return _mm256_cvtepi32_ps(z8); // 8 * float32
891
+ }
892
+ };
893
+
894
+ #endif
895
+
896
+ #ifdef USE_NEON
897
+
898
+ template <>
899
+ struct Quantizer8bitDirectSigned<8> : Quantizer8bitDirectSigned<1> {
900
+ Quantizer8bitDirectSigned(size_t d, const std::vector<float>& trained)
901
+ : Quantizer8bitDirectSigned<1>(d, trained) {}
902
+
903
+ FAISS_ALWAYS_INLINE float32x4x2_t
904
+ reconstruct_8_components(const uint8_t* code, int i) const {
905
+ uint8x8_t x8 = vld1_u8((const uint8_t*)(code + i));
906
+ uint16x8_t y8 = vmovl_u8(x8); // convert uint8 -> uint16
907
+ uint16x4_t y8_0 = vget_low_u16(y8);
908
+ uint16x4_t y8_1 = vget_high_u16(y8);
909
+
910
+ float32x4_t z8_0 = vcvtq_f32_u32(
911
+ vmovl_u16(y8_0)); // convert uint16 -> uint32 -> fp32
912
+ float32x4_t z8_1 = vcvtq_f32_u32(vmovl_u16(y8_1));
913
+
914
+ // subtract 128 to convert into signed numbers
915
+ return {vsubq_f32(z8_0, vmovq_n_f32(128.0)),
916
+ vsubq_f32(z8_1, vmovq_n_f32(128.0))};
573
917
  }
574
918
  };
575
919
 
@@ -582,24 +926,38 @@ ScalarQuantizer::SQuantizer* select_quantizer_1(
582
926
  const std::vector<float>& trained) {
583
927
  switch (qtype) {
584
928
  case ScalarQuantizer::QT_8bit:
585
- return new QuantizerTemplate<Codec8bit, false, SIMDWIDTH>(
586
- d, trained);
929
+ return new QuantizerTemplate<
930
+ Codec8bit,
931
+ QuantizerTemplateScaling::NON_UNIFORM,
932
+ SIMDWIDTH>(d, trained);
587
933
  case ScalarQuantizer::QT_6bit:
588
- return new QuantizerTemplate<Codec6bit, false, SIMDWIDTH>(
589
- d, trained);
934
+ return new QuantizerTemplate<
935
+ Codec6bit,
936
+ QuantizerTemplateScaling::NON_UNIFORM,
937
+ SIMDWIDTH>(d, trained);
590
938
  case ScalarQuantizer::QT_4bit:
591
- return new QuantizerTemplate<Codec4bit, false, SIMDWIDTH>(
592
- d, trained);
939
+ return new QuantizerTemplate<
940
+ Codec4bit,
941
+ QuantizerTemplateScaling::NON_UNIFORM,
942
+ SIMDWIDTH>(d, trained);
593
943
  case ScalarQuantizer::QT_8bit_uniform:
594
- return new QuantizerTemplate<Codec8bit, true, SIMDWIDTH>(
595
- d, trained);
944
+ return new QuantizerTemplate<
945
+ Codec8bit,
946
+ QuantizerTemplateScaling::UNIFORM,
947
+ SIMDWIDTH>(d, trained);
596
948
  case ScalarQuantizer::QT_4bit_uniform:
597
- return new QuantizerTemplate<Codec4bit, true, SIMDWIDTH>(
598
- d, trained);
949
+ return new QuantizerTemplate<
950
+ Codec4bit,
951
+ QuantizerTemplateScaling::UNIFORM,
952
+ SIMDWIDTH>(d, trained);
599
953
  case ScalarQuantizer::QT_fp16:
600
954
  return new QuantizerFP16<SIMDWIDTH>(d, trained);
955
+ case ScalarQuantizer::QT_bf16:
956
+ return new QuantizerBF16<SIMDWIDTH>(d, trained);
601
957
  case ScalarQuantizer::QT_8bit_direct:
602
958
  return new Quantizer8bitDirect<SIMDWIDTH>(d, trained);
959
+ case ScalarQuantizer::QT_8bit_direct_signed:
960
+ return new Quantizer8bitDirectSigned<SIMDWIDTH>(d, trained);
603
961
  }
604
962
  FAISS_THROW_MSG("unknown qtype");
605
963
  }
@@ -816,7 +1174,43 @@ struct SimilarityL2<1> {
816
1174
  }
817
1175
  };
818
1176
 
819
- #ifdef __AVX2__
1177
+ #if defined(__AVX512F__)
1178
+
1179
+ template <>
1180
+ struct SimilarityL2<16> {
1181
+ static constexpr int simdwidth = 16;
1182
+ static constexpr MetricType metric_type = METRIC_L2;
1183
+
1184
+ const float *y, *yi;
1185
+
1186
+ explicit SimilarityL2(const float* y) : y(y) {}
1187
+ __m512 accu16;
1188
+
1189
+ FAISS_ALWAYS_INLINE void begin_16() {
1190
+ accu16 = _mm512_setzero_ps();
1191
+ yi = y;
1192
+ }
1193
+
1194
+ FAISS_ALWAYS_INLINE void add_16_components(__m512 x) {
1195
+ __m512 yiv = _mm512_loadu_ps(yi);
1196
+ yi += 16;
1197
+ __m512 tmp = _mm512_sub_ps(yiv, x);
1198
+ accu16 = _mm512_fmadd_ps(tmp, tmp, accu16);
1199
+ }
1200
+
1201
+ FAISS_ALWAYS_INLINE void add_16_components_2(__m512 x, __m512 y_2) {
1202
+ __m512 tmp = _mm512_sub_ps(y_2, x);
1203
+ accu16 = _mm512_fmadd_ps(tmp, tmp, accu16);
1204
+ }
1205
+
1206
+ FAISS_ALWAYS_INLINE float result_16() {
1207
+ // performs better than dividing into _mm256 and adding
1208
+ return _mm512_reduce_add_ps(accu16);
1209
+ }
1210
+ };
1211
+
1212
+ #elif defined(__AVX2__)
1213
+
820
1214
  template <>
821
1215
  struct SimilarityL2<8> {
822
1216
  static constexpr int simdwidth = 8;
@@ -857,7 +1251,7 @@ struct SimilarityL2<8> {
857
1251
 
858
1252
  #endif
859
1253
 
860
- #ifdef __aarch64__
1254
+ #ifdef USE_NEON
861
1255
  template <>
862
1256
  struct SimilarityL2<8> {
863
1257
  static constexpr int simdwidth = 8;
@@ -868,7 +1262,7 @@ struct SimilarityL2<8> {
868
1262
  float32x4x2_t accu8;
869
1263
 
870
1264
  FAISS_ALWAYS_INLINE void begin_8() {
871
- accu8 = vzipq_f32(vdupq_n_f32(0.0f), vdupq_n_f32(0.0f));
1265
+ accu8 = {vdupq_n_f32(0.0f), vdupq_n_f32(0.0f)};
872
1266
  yi = y;
873
1267
  }
874
1268
 
@@ -882,8 +1276,7 @@ struct SimilarityL2<8> {
882
1276
  float32x4_t accu8_0 = vfmaq_f32(accu8.val[0], sub0, sub0);
883
1277
  float32x4_t accu8_1 = vfmaq_f32(accu8.val[1], sub1, sub1);
884
1278
 
885
- float32x4x2_t accu8_temp = vzipq_f32(accu8_0, accu8_1);
886
- accu8 = vuzpq_f32(accu8_temp.val[0], accu8_temp.val[1]);
1279
+ accu8 = {accu8_0, accu8_1};
887
1280
  }
888
1281
 
889
1282
  FAISS_ALWAYS_INLINE void add_8_components_2(
@@ -895,8 +1288,7 @@ struct SimilarityL2<8> {
895
1288
  float32x4_t accu8_0 = vfmaq_f32(accu8.val[0], sub0, sub0);
896
1289
  float32x4_t accu8_1 = vfmaq_f32(accu8.val[1], sub1, sub1);
897
1290
 
898
- float32x4x2_t accu8_temp = vzipq_f32(accu8_0, accu8_1);
899
- accu8 = vuzpq_f32(accu8_temp.val[0], accu8_temp.val[1]);
1291
+ accu8 = {accu8_0, accu8_1};
900
1292
  }
901
1293
 
902
1294
  FAISS_ALWAYS_INLINE float result_8() {
@@ -941,7 +1333,43 @@ struct SimilarityIP<1> {
941
1333
  }
942
1334
  };
943
1335
 
944
- #ifdef __AVX2__
1336
+ #if defined(__AVX512F__)
1337
+
1338
+ template <>
1339
+ struct SimilarityIP<16> {
1340
+ static constexpr int simdwidth = 16;
1341
+ static constexpr MetricType metric_type = METRIC_INNER_PRODUCT;
1342
+
1343
+ const float *y, *yi;
1344
+
1345
+ float accu;
1346
+
1347
+ explicit SimilarityIP(const float* y) : y(y) {}
1348
+
1349
+ __m512 accu16;
1350
+
1351
+ FAISS_ALWAYS_INLINE void begin_16() {
1352
+ accu16 = _mm512_setzero_ps();
1353
+ yi = y;
1354
+ }
1355
+
1356
+ FAISS_ALWAYS_INLINE void add_16_components(__m512 x) {
1357
+ __m512 yiv = _mm512_loadu_ps(yi);
1358
+ yi += 16;
1359
+ accu16 = _mm512_fmadd_ps(yiv, x, accu16);
1360
+ }
1361
+
1362
+ FAISS_ALWAYS_INLINE void add_16_components_2(__m512 x1, __m512 x2) {
1363
+ accu16 = _mm512_fmadd_ps(x1, x2, accu16);
1364
+ }
1365
+
1366
+ FAISS_ALWAYS_INLINE float result_16() {
1367
+ // performs better than dividing into _mm256 and adding
1368
+ return _mm512_reduce_add_ps(accu16);
1369
+ }
1370
+ };
1371
+
1372
+ #elif defined(__AVX2__)
945
1373
 
946
1374
  template <>
947
1375
  struct SimilarityIP<8> {
@@ -983,7 +1411,7 @@ struct SimilarityIP<8> {
983
1411
  };
984
1412
  #endif
985
1413
 
986
- #ifdef __aarch64__
1414
+ #ifdef USE_NEON
987
1415
 
988
1416
  template <>
989
1417
  struct SimilarityIP<8> {
@@ -996,7 +1424,7 @@ struct SimilarityIP<8> {
996
1424
  float32x4x2_t accu8;
997
1425
 
998
1426
  FAISS_ALWAYS_INLINE void begin_8() {
999
- accu8 = vzipq_f32(vdupq_n_f32(0.0f), vdupq_n_f32(0.0f));
1427
+ accu8 = {vdupq_n_f32(0.0f), vdupq_n_f32(0.0f)};
1000
1428
  yi = y;
1001
1429
  }
1002
1430
 
@@ -1006,8 +1434,7 @@ struct SimilarityIP<8> {
1006
1434
 
1007
1435
  float32x4_t accu8_0 = vfmaq_f32(accu8.val[0], yiv.val[0], x.val[0]);
1008
1436
  float32x4_t accu8_1 = vfmaq_f32(accu8.val[1], yiv.val[1], x.val[1]);
1009
- float32x4x2_t accu8_temp = vzipq_f32(accu8_0, accu8_1);
1010
- accu8 = vuzpq_f32(accu8_temp.val[0], accu8_temp.val[1]);
1437
+ accu8 = {accu8_0, accu8_1};
1011
1438
  }
1012
1439
 
1013
1440
  FAISS_ALWAYS_INLINE void add_8_components_2(
@@ -1015,19 +1442,17 @@ struct SimilarityIP<8> {
1015
1442
  float32x4x2_t x2) {
1016
1443
  float32x4_t accu8_0 = vfmaq_f32(accu8.val[0], x1.val[0], x2.val[0]);
1017
1444
  float32x4_t accu8_1 = vfmaq_f32(accu8.val[1], x1.val[1], x2.val[1]);
1018
- float32x4x2_t accu8_temp = vzipq_f32(accu8_0, accu8_1);
1019
- accu8 = vuzpq_f32(accu8_temp.val[0], accu8_temp.val[1]);
1445
+ accu8 = {accu8_0, accu8_1};
1020
1446
  }
1021
1447
 
1022
1448
  FAISS_ALWAYS_INLINE float result_8() {
1023
- float32x4x2_t sum_tmp = vzipq_f32(
1449
+ float32x4x2_t sum = {
1024
1450
  vpaddq_f32(accu8.val[0], accu8.val[0]),
1025
- vpaddq_f32(accu8.val[1], accu8.val[1]));
1026
- float32x4x2_t sum = vuzpq_f32(sum_tmp.val[0], sum_tmp.val[1]);
1027
- float32x4x2_t sum2_tmp = vzipq_f32(
1451
+ vpaddq_f32(accu8.val[1], accu8.val[1])};
1452
+
1453
+ float32x4x2_t sum2 = {
1028
1454
  vpaddq_f32(sum.val[0], sum.val[0]),
1029
- vpaddq_f32(sum.val[1], sum.val[1]));
1030
- float32x4x2_t sum2 = vuzpq_f32(sum2_tmp.val[0], sum2_tmp.val[1]);
1455
+ vpaddq_f32(sum.val[1], sum.val[1])};
1031
1456
  return vgetq_lane_f32(sum2.val[0], 0) + vgetq_lane_f32(sum2.val[1], 0);
1032
1457
  }
1033
1458
  };
@@ -1086,7 +1511,55 @@ struct DCTemplate<Quantizer, Similarity, 1> : SQDistanceComputer {
1086
1511
  }
1087
1512
  };
1088
1513
 
1089
- #ifdef USE_F16C
1514
+ #if defined(USE_AVX512_F16C)
1515
+
1516
+ template <class Quantizer, class Similarity>
1517
+ struct DCTemplate<Quantizer, Similarity, 16>
1518
+ : SQDistanceComputer { // Update to handle 16 lanes
1519
+ using Sim = Similarity;
1520
+
1521
+ Quantizer quant;
1522
+
1523
+ DCTemplate(size_t d, const std::vector<float>& trained)
1524
+ : quant(d, trained) {}
1525
+
1526
+ float compute_distance(const float* x, const uint8_t* code) const {
1527
+ Similarity sim(x);
1528
+ sim.begin_16();
1529
+ for (size_t i = 0; i < quant.d; i += 16) {
1530
+ __m512 xi = quant.reconstruct_16_components(code, i);
1531
+ sim.add_16_components(xi);
1532
+ }
1533
+ return sim.result_16();
1534
+ }
1535
+
1536
+ float compute_code_distance(const uint8_t* code1, const uint8_t* code2)
1537
+ const {
1538
+ Similarity sim(nullptr);
1539
+ sim.begin_16();
1540
+ for (size_t i = 0; i < quant.d; i += 16) {
1541
+ __m512 x1 = quant.reconstruct_16_components(code1, i);
1542
+ __m512 x2 = quant.reconstruct_16_components(code2, i);
1543
+ sim.add_16_components_2(x1, x2);
1544
+ }
1545
+ return sim.result_16();
1546
+ }
1547
+
1548
+ void set_query(const float* x) final {
1549
+ q = x;
1550
+ }
1551
+
1552
+ float symmetric_dis(idx_t i, idx_t j) override {
1553
+ return compute_code_distance(
1554
+ codes + i * code_size, codes + j * code_size);
1555
+ }
1556
+
1557
+ float query_to_code(const uint8_t* code) const final {
1558
+ return compute_distance(q, code);
1559
+ }
1560
+ };
1561
+
1562
+ #elif defined(USE_F16C)
1090
1563
 
1091
1564
  template <class Quantizer, class Similarity>
1092
1565
  struct DCTemplate<Quantizer, Similarity, 8> : SQDistanceComputer {
@@ -1135,7 +1608,7 @@ struct DCTemplate<Quantizer, Similarity, 8> : SQDistanceComputer {
1135
1608
 
1136
1609
  #endif
1137
1610
 
1138
- #ifdef __aarch64__
1611
+ #ifdef USE_NEON
1139
1612
 
1140
1613
  template <class Quantizer, class Similarity>
1141
1614
  struct DCTemplate<Quantizer, Similarity, 8> : SQDistanceComputer {
@@ -1233,7 +1706,60 @@ struct DistanceComputerByte<Similarity, 1> : SQDistanceComputer {
1233
1706
  }
1234
1707
  };
1235
1708
 
1236
- #ifdef __AVX2__
1709
+ #if defined(__AVX512F__)
1710
+
1711
+ template <class Similarity>
1712
+ struct DistanceComputerByte<Similarity, 16> : SQDistanceComputer {
1713
+ using Sim = Similarity;
1714
+
1715
+ int d;
1716
+ std::vector<uint8_t> tmp;
1717
+
1718
+ DistanceComputerByte(int d, const std::vector<float>&) : d(d), tmp(d) {}
1719
+
1720
+ int compute_code_distance(const uint8_t* code1, const uint8_t* code2)
1721
+ const {
1722
+ __m512i accu = _mm512_setzero_si512();
1723
+ for (int i = 0; i < d; i += 32) { // Process 32 bytes at a time
1724
+ __m512i c1 = _mm512_cvtepu8_epi16(
1725
+ _mm256_loadu_si256((__m256i*)(code1 + i)));
1726
+ __m512i c2 = _mm512_cvtepu8_epi16(
1727
+ _mm256_loadu_si256((__m256i*)(code2 + i)));
1728
+ __m512i prod32;
1729
+ if (Sim::metric_type == METRIC_INNER_PRODUCT) {
1730
+ prod32 = _mm512_madd_epi16(c1, c2);
1731
+ } else {
1732
+ __m512i diff = _mm512_sub_epi16(c1, c2);
1733
+ prod32 = _mm512_madd_epi16(diff, diff);
1734
+ }
1735
+ accu = _mm512_add_epi32(accu, prod32);
1736
+ }
1737
+ // Horizontally add elements of accu
1738
+ return _mm512_reduce_add_epi32(accu);
1739
+ }
1740
+
1741
+ void set_query(const float* x) final {
1742
+ for (int i = 0; i < d; i++) {
1743
+ tmp[i] = int(x[i]);
1744
+ }
1745
+ }
1746
+
1747
+ int compute_distance(const float* x, const uint8_t* code) {
1748
+ set_query(x);
1749
+ return compute_code_distance(tmp.data(), code);
1750
+ }
1751
+
1752
+ float symmetric_dis(idx_t i, idx_t j) override {
1753
+ return compute_code_distance(
1754
+ codes + i * code_size, codes + j * code_size);
1755
+ }
1756
+
1757
+ float query_to_code(const uint8_t* code) const final {
1758
+ return compute_code_distance(tmp.data(), code);
1759
+ }
1760
+ };
1761
+
1762
+ #elif defined(__AVX2__)
1237
1763
 
1238
1764
  template <class Similarity>
1239
1765
  struct DistanceComputerByte<Similarity, 8> : SQDistanceComputer {
@@ -1298,7 +1824,7 @@ struct DistanceComputerByte<Similarity, 8> : SQDistanceComputer {
1298
1824
 
1299
1825
  #endif
1300
1826
 
1301
- #ifdef __aarch64__
1827
+ #ifdef USE_NEON
1302
1828
 
1303
1829
  template <class Similarity>
1304
1830
  struct DistanceComputerByte<Similarity, 8> : SQDistanceComputer {
@@ -1360,31 +1886,46 @@ SQDistanceComputer* select_distance_computer(
1360
1886
  switch (qtype) {
1361
1887
  case ScalarQuantizer::QT_8bit_uniform:
1362
1888
  return new DCTemplate<
1363
- QuantizerTemplate<Codec8bit, true, SIMDWIDTH>,
1889
+ QuantizerTemplate<
1890
+ Codec8bit,
1891
+ QuantizerTemplateScaling::UNIFORM,
1892
+ SIMDWIDTH>,
1364
1893
  Sim,
1365
1894
  SIMDWIDTH>(d, trained);
1366
1895
 
1367
1896
  case ScalarQuantizer::QT_4bit_uniform:
1368
1897
  return new DCTemplate<
1369
- QuantizerTemplate<Codec4bit, true, SIMDWIDTH>,
1898
+ QuantizerTemplate<
1899
+ Codec4bit,
1900
+ QuantizerTemplateScaling::UNIFORM,
1901
+ SIMDWIDTH>,
1370
1902
  Sim,
1371
1903
  SIMDWIDTH>(d, trained);
1372
1904
 
1373
1905
  case ScalarQuantizer::QT_8bit:
1374
1906
  return new DCTemplate<
1375
- QuantizerTemplate<Codec8bit, false, SIMDWIDTH>,
1907
+ QuantizerTemplate<
1908
+ Codec8bit,
1909
+ QuantizerTemplateScaling::NON_UNIFORM,
1910
+ SIMDWIDTH>,
1376
1911
  Sim,
1377
1912
  SIMDWIDTH>(d, trained);
1378
1913
 
1379
1914
  case ScalarQuantizer::QT_6bit:
1380
1915
  return new DCTemplate<
1381
- QuantizerTemplate<Codec6bit, false, SIMDWIDTH>,
1916
+ QuantizerTemplate<
1917
+ Codec6bit,
1918
+ QuantizerTemplateScaling::NON_UNIFORM,
1919
+ SIMDWIDTH>,
1382
1920
  Sim,
1383
1921
  SIMDWIDTH>(d, trained);
1384
1922
 
1385
1923
  case ScalarQuantizer::QT_4bit:
1386
1924
  return new DCTemplate<
1387
- QuantizerTemplate<Codec4bit, false, SIMDWIDTH>,
1925
+ QuantizerTemplate<
1926
+ Codec4bit,
1927
+ QuantizerTemplateScaling::NON_UNIFORM,
1928
+ SIMDWIDTH>,
1388
1929
  Sim,
1389
1930
  SIMDWIDTH>(d, trained);
1390
1931
 
@@ -1392,15 +1933,31 @@ SQDistanceComputer* select_distance_computer(
1392
1933
  return new DCTemplate<QuantizerFP16<SIMDWIDTH>, Sim, SIMDWIDTH>(
1393
1934
  d, trained);
1394
1935
 
1936
+ case ScalarQuantizer::QT_bf16:
1937
+ return new DCTemplate<QuantizerBF16<SIMDWIDTH>, Sim, SIMDWIDTH>(
1938
+ d, trained);
1939
+
1395
1940
  case ScalarQuantizer::QT_8bit_direct:
1941
+ #if defined(__AVX512F__)
1942
+ if (d % 32 == 0) {
1943
+ return new DistanceComputerByte<Sim, SIMDWIDTH>(d, trained);
1944
+ } else
1945
+ #elif defined(__AVX2__)
1396
1946
  if (d % 16 == 0) {
1397
1947
  return new DistanceComputerByte<Sim, SIMDWIDTH>(d, trained);
1398
- } else {
1948
+ } else
1949
+ #endif
1950
+ {
1399
1951
  return new DCTemplate<
1400
1952
  Quantizer8bitDirect<SIMDWIDTH>,
1401
1953
  Sim,
1402
1954
  SIMDWIDTH>(d, trained);
1403
1955
  }
1956
+ case ScalarQuantizer::QT_8bit_direct_signed:
1957
+ return new DCTemplate<
1958
+ Quantizer8bitDirectSigned<SIMDWIDTH>,
1959
+ Sim,
1960
+ SIMDWIDTH>(d, trained);
1404
1961
  }
1405
1962
  FAISS_THROW_MSG("unknown qtype");
1406
1963
  return nullptr;
@@ -1424,6 +1981,7 @@ void ScalarQuantizer::set_derived_sizes() {
1424
1981
  case QT_8bit:
1425
1982
  case QT_8bit_uniform:
1426
1983
  case QT_8bit_direct:
1984
+ case QT_8bit_direct_signed:
1427
1985
  code_size = d;
1428
1986
  bits = 8;
1429
1987
  break;
@@ -1440,6 +1998,10 @@ void ScalarQuantizer::set_derived_sizes() {
1440
1998
  code_size = d * 2;
1441
1999
  bits = 16;
1442
2000
  break;
2001
+ case QT_bf16:
2002
+ code_size = d * 2;
2003
+ bits = 16;
2004
+ break;
1443
2005
  }
1444
2006
  }
1445
2007
 
@@ -1476,13 +2038,19 @@ void ScalarQuantizer::train(size_t n, const float* x) {
1476
2038
  break;
1477
2039
  case QT_fp16:
1478
2040
  case QT_8bit_direct:
2041
+ case QT_bf16:
2042
+ case QT_8bit_direct_signed:
1479
2043
  // no training necessary
1480
2044
  break;
1481
2045
  }
1482
2046
  }
1483
2047
 
1484
2048
  ScalarQuantizer::SQuantizer* ScalarQuantizer::select_quantizer() const {
1485
- #if defined(USE_F16C) || defined(__aarch64__)
2049
+ #if defined(USE_AVX512_F16C)
2050
+ if (d % 16 == 0) {
2051
+ return select_quantizer_1<16>(qtype, d, trained);
2052
+ } else
2053
+ #elif defined(USE_F16C) || defined(USE_NEON)
1486
2054
  if (d % 8 == 0) {
1487
2055
  return select_quantizer_1<8>(qtype, d, trained);
1488
2056
  } else
@@ -1513,7 +2081,17 @@ void ScalarQuantizer::decode(const uint8_t* codes, float* x, size_t n) const {
1513
2081
  SQDistanceComputer* ScalarQuantizer::get_distance_computer(
1514
2082
  MetricType metric) const {
1515
2083
  FAISS_THROW_IF_NOT(metric == METRIC_L2 || metric == METRIC_INNER_PRODUCT);
1516
- #if defined(USE_F16C) || defined(__aarch64__)
2084
+ #if defined(USE_AVX512_F16C)
2085
+ if (d % 16 == 0) {
2086
+ if (metric == METRIC_L2) {
2087
+ return select_distance_computer<SimilarityL2<16>>(
2088
+ qtype, d, trained);
2089
+ } else {
2090
+ return select_distance_computer<SimilarityIP<16>>(
2091
+ qtype, d, trained);
2092
+ }
2093
+ } else
2094
+ #elif defined(USE_F16C) || defined(USE_NEON)
1517
2095
  if (d % 8 == 0) {
1518
2096
  if (metric == METRIC_L2) {
1519
2097
  return select_distance_computer<SimilarityL2<8>>(qtype, d, trained);
@@ -1762,7 +2340,7 @@ InvertedListScanner* sel2_InvertedListScanner(
1762
2340
  }
1763
2341
  }
1764
2342
 
1765
- template <class Similarity, class Codec, bool uniform>
2343
+ template <class Similarity, class Codec, QuantizerTemplateScaling SCALING>
1766
2344
  InvertedListScanner* sel12_InvertedListScanner(
1767
2345
  const ScalarQuantizer* sq,
1768
2346
  const Index* quantizer,
@@ -1770,7 +2348,7 @@ InvertedListScanner* sel12_InvertedListScanner(
1770
2348
  const IDSelector* sel,
1771
2349
  bool r) {
1772
2350
  constexpr int SIMDWIDTH = Similarity::simdwidth;
1773
- using QuantizerClass = QuantizerTemplate<Codec, uniform, SIMDWIDTH>;
2351
+ using QuantizerClass = QuantizerTemplate<Codec, SCALING, SIMDWIDTH>;
1774
2352
  using DCClass = DCTemplate<QuantizerClass, Similarity, SIMDWIDTH>;
1775
2353
  return sel2_InvertedListScanner<DCClass>(
1776
2354
  sq, quantizer, store_pairs, sel, r);
@@ -1786,36 +2364,70 @@ InvertedListScanner* sel1_InvertedListScanner(
1786
2364
  constexpr int SIMDWIDTH = Similarity::simdwidth;
1787
2365
  switch (sq->qtype) {
1788
2366
  case ScalarQuantizer::QT_8bit_uniform:
1789
- return sel12_InvertedListScanner<Similarity, Codec8bit, true>(
2367
+ return sel12_InvertedListScanner<
2368
+ Similarity,
2369
+ Codec8bit,
2370
+ QuantizerTemplateScaling::UNIFORM>(
1790
2371
  sq, quantizer, store_pairs, sel, r);
1791
2372
  case ScalarQuantizer::QT_4bit_uniform:
1792
- return sel12_InvertedListScanner<Similarity, Codec4bit, true>(
2373
+ return sel12_InvertedListScanner<
2374
+ Similarity,
2375
+ Codec4bit,
2376
+ QuantizerTemplateScaling::UNIFORM>(
1793
2377
  sq, quantizer, store_pairs, sel, r);
1794
2378
  case ScalarQuantizer::QT_8bit:
1795
- return sel12_InvertedListScanner<Similarity, Codec8bit, false>(
2379
+ return sel12_InvertedListScanner<
2380
+ Similarity,
2381
+ Codec8bit,
2382
+ QuantizerTemplateScaling::NON_UNIFORM>(
1796
2383
  sq, quantizer, store_pairs, sel, r);
1797
2384
  case ScalarQuantizer::QT_4bit:
1798
- return sel12_InvertedListScanner<Similarity, Codec4bit, false>(
2385
+ return sel12_InvertedListScanner<
2386
+ Similarity,
2387
+ Codec4bit,
2388
+ QuantizerTemplateScaling::NON_UNIFORM>(
1799
2389
  sq, quantizer, store_pairs, sel, r);
1800
2390
  case ScalarQuantizer::QT_6bit:
1801
- return sel12_InvertedListScanner<Similarity, Codec6bit, false>(
2391
+ return sel12_InvertedListScanner<
2392
+ Similarity,
2393
+ Codec6bit,
2394
+ QuantizerTemplateScaling::NON_UNIFORM>(
1802
2395
  sq, quantizer, store_pairs, sel, r);
1803
2396
  case ScalarQuantizer::QT_fp16:
1804
2397
  return sel2_InvertedListScanner<DCTemplate<
1805
2398
  QuantizerFP16<SIMDWIDTH>,
1806
2399
  Similarity,
1807
2400
  SIMDWIDTH>>(sq, quantizer, store_pairs, sel, r);
2401
+ case ScalarQuantizer::QT_bf16:
2402
+ return sel2_InvertedListScanner<DCTemplate<
2403
+ QuantizerBF16<SIMDWIDTH>,
2404
+ Similarity,
2405
+ SIMDWIDTH>>(sq, quantizer, store_pairs, sel, r);
1808
2406
  case ScalarQuantizer::QT_8bit_direct:
2407
+ #if defined(__AVX512F__)
2408
+ if (sq->d % 32 == 0) {
2409
+ return sel2_InvertedListScanner<
2410
+ DistanceComputerByte<Similarity, SIMDWIDTH>>(
2411
+ sq, quantizer, store_pairs, sel, r);
2412
+ } else
2413
+ #elif defined(__AVX2__)
1809
2414
  if (sq->d % 16 == 0) {
1810
2415
  return sel2_InvertedListScanner<
1811
2416
  DistanceComputerByte<Similarity, SIMDWIDTH>>(
1812
2417
  sq, quantizer, store_pairs, sel, r);
1813
- } else {
2418
+ } else
2419
+ #endif
2420
+ {
1814
2421
  return sel2_InvertedListScanner<DCTemplate<
1815
2422
  Quantizer8bitDirect<SIMDWIDTH>,
1816
2423
  Similarity,
1817
2424
  SIMDWIDTH>>(sq, quantizer, store_pairs, sel, r);
1818
2425
  }
2426
+ case ScalarQuantizer::QT_8bit_direct_signed:
2427
+ return sel2_InvertedListScanner<DCTemplate<
2428
+ Quantizer8bitDirectSigned<SIMDWIDTH>,
2429
+ Similarity,
2430
+ SIMDWIDTH>>(sq, quantizer, store_pairs, sel, r);
1819
2431
  }
1820
2432
 
1821
2433
  FAISS_THROW_MSG("unknown qtype");
@@ -1849,7 +2461,12 @@ InvertedListScanner* ScalarQuantizer::select_InvertedListScanner(
1849
2461
  bool store_pairs,
1850
2462
  const IDSelector* sel,
1851
2463
  bool by_residual) const {
1852
- #if defined(USE_F16C) || defined(__aarch64__)
2464
+ #if defined(USE_AVX512_F16C)
2465
+ if (d % 16 == 0) {
2466
+ return sel0_InvertedListScanner<16>(
2467
+ mt, this, quantizer, store_pairs, sel, by_residual);
2468
+ } else
2469
+ #elif defined(USE_F16C) || defined(USE_NEON)
1853
2470
  if (d % 8 == 0) {
1854
2471
  return sel0_InvertedListScanner<8>(
1855
2472
  mt, this, quantizer, store_pairs, sel, by_residual);